Coverage Report

Created: 2020-06-26 05:44

/home/arjun/llvm-project/mlir/include/mlir/IR/StandardTypes.h
Line
Count
Source (jump to first uncovered line)
1
//===- StandardTypes.h - MLIR Standard Type Classes -------------*- C++ -*-===//
2
//
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
// See https://llvm.org/LICENSE.txt for license information.
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
//
7
//===----------------------------------------------------------------------===//
8
9
#ifndef MLIR_IR_STANDARDTYPES_H
10
#define MLIR_IR_STANDARDTYPES_H
11
12
#include "mlir/IR/Types.h"
13
14
namespace llvm {
15
struct fltSemantics;
16
} // namespace llvm
17
18
namespace mlir {
19
class AffineExpr;
20
class AffineMap;
21
class FloatType;
22
class IndexType;
23
class IntegerType;
24
class Location;
25
class MLIRContext;
26
27
namespace detail {
28
29
struct IntegerTypeStorage;
30
struct ShapedTypeStorage;
31
struct VectorTypeStorage;
32
struct RankedTensorTypeStorage;
33
struct UnrankedTensorTypeStorage;
34
struct MemRefTypeStorage;
35
struct UnrankedMemRefTypeStorage;
36
struct ComplexTypeStorage;
37
struct TupleTypeStorage;
38
39
} // namespace detail
40
41
namespace StandardTypes {
42
/// Kind identifiers for the standard (builtin) types declared in this header.
/// Numbering starts at Type::Kind::FIRST_STANDARD_TYPE. Enumerators without
/// an explicit initializer take the previous enumerator's value plus one, so
/// the declaration order below determines every value and must not change.
enum Kind {
43
  // Floating point.
44
  BF16 = Type::Kind::FIRST_STANDARD_TYPE,
45
  F16,
46
  F32,
47
  F64,
48
  // Range markers (aliases of existing kinds, not new kinds). FloatType::kindof
  // accepts any kind in [FIRST_FLOATING_POINT_TYPE, LAST_FLOATING_POINT_TYPE],
  // so new floating point kinds must stay inside this contiguous range.
  FIRST_FLOATING_POINT_TYPE = BF16,
49
  LAST_FLOATING_POINT_TYPE = F64,
50
51
  // Target pointer sized integer, used (e.g.) in affine mappings.
  // Follows LAST_FLOATING_POINT_TYPE (== F64), so its value is F64 + 1.
52
  Index,
53
54
  // Derived types.
55
  Integer,
56
  Vector,
57
  RankedTensor,
58
  UnrankedTensor,
59
  MemRef,
60
  UnrankedMemRef,
61
  Complex,
62
  Tuple,
63
  // Kind of the unit type (see NoneType::kindof).
  None,
64
};
65
66
} // namespace StandardTypes
67
68
//===----------------------------------------------------------------------===//
69
// ComplexType
70
//===----------------------------------------------------------------------===//
71
72
/// The 'complex' type represents a complex number with a parameterized element
73
/// type, which is composed of a real and imaginary value of that element type.
74
///
75
/// The element must be a floating point or integer scalar type.
76
///
77
class ComplexType
78
    : public Type::TypeBase<ComplexType, Type, detail::ComplexTypeStorage> {
79
public:
80
  using Base::Base;
81
82
  /// Get or create a ComplexType with the provided element type.
83
  static ComplexType get(Type elementType);
84
85
  /// Get or create a ComplexType with the provided element type.  This emits
86
  /// and error at the specified location and returns null if the element type
87
  /// isn't supported.
88
  static ComplexType getChecked(Type elementType, Location location);
89
90
  /// Verify the construction of an integer type.
91
  static LogicalResult verifyConstructionInvariants(Location loc,
92
                                                    Type elementType);
93
94
  Type getElementType();
95
96
0
  static bool kindof(unsigned kind) { return kind == StandardTypes::Complex; }
97
};
98
99
//===----------------------------------------------------------------------===//
100
// IndexType
101
//===----------------------------------------------------------------------===//
102
103
/// Index is a special integer-like type with unknown platform-dependent bit
104
/// width.
105
class IndexType : public Type::TypeBase<IndexType, Type> {
106
public:
107
  using Base::Base;
108
109
  /// Get an instance of the IndexType.
110
  static IndexType get(MLIRContext *context);
111
112
  /// Support method to enable LLVM-style type casting.
113
0
  static bool kindof(unsigned kind) { return kind == StandardTypes::Index; }
114
115
  /// Storage bit width used for IndexType by internal compiler data structures.
116
  static constexpr unsigned kInternalStorageBitWidth = 64;
117
};
118
119
//===----------------------------------------------------------------------===//
120
// IntegerType
121
//===----------------------------------------------------------------------===//
122
123
/// Integer types can have arbitrary bitwidth up to a large fixed limit.
124
class IntegerType
125
    : public Type::TypeBase<IntegerType, Type, detail::IntegerTypeStorage> {
126
public:
127
  using Base::Base;
128
129
  /// Signedness semantics.
130
  enum SignednessSemantics {
131
    Signless, /// No signedness semantics
132
    Signed,   /// Signed integer
133
    Unsigned, /// Unsigned integer
134
  };
135
136
  /// Get or create a new IntegerType of the given width within the context.
137
  /// The created IntegerType is signless (i.e., no signedness semantics).
138
  /// Assume the width is within the allowed range and assert on failures. Use
139
  /// getChecked to handle failures gracefully.
140
  static IntegerType get(unsigned width, MLIRContext *context);
141
142
  /// Get or create a new IntegerType of the given width within the context.
143
  /// The created IntegerType has signedness semantics as indicated via
144
  /// `signedness`. Assume the width is within the allowed range and assert on
145
  /// failures. Use getChecked to handle failures gracefully.
146
  static IntegerType get(unsigned width, SignednessSemantics signedness,
147
                         MLIRContext *context);
148
149
  /// Get or create a new IntegerType of the given width within the context,
150
  /// defined at the given, potentially unknown, location.  The created
151
  /// IntegerType is signless (i.e., no signedness semantics). If the width is
152
  /// outside the allowed range, emit errors and return a null type.
153
  static IntegerType getChecked(unsigned width, Location location);
154
155
  /// Get or create a new IntegerType of the given width within the context,
156
  /// defined at the given, potentially unknown, location. The created
157
  /// IntegerType has signedness semantics as indicated via `signedness`. If the
158
  /// width is outside the allowed range, emit errors and return a null type.
159
  static IntegerType getChecked(unsigned width, SignednessSemantics signedness,
160
                                Location location);
161
162
  /// Verify the construction of an integer type.
163
  static LogicalResult
164
  verifyConstructionInvariants(Location loc, unsigned width,
165
                               SignednessSemantics signedness);
166
167
  /// Return the bitwidth of this integer type.
168
  unsigned getWidth() const;
169
170
  /// Return the signedness semantics of this integer type.
171
  SignednessSemantics getSignedness() const;
172
173
  /// Return true if this is a signless integer type.
174
0
  bool isSignless() const { return getSignedness() == Signless; }
175
  /// Return true if this is a signed integer type.
176
0
  bool isSigned() const { return getSignedness() == Signed; }
177
  /// Return true if this is an unsigned integer type.
178
0
  bool isUnsigned() const { return getSignedness() == Unsigned; }
179
180
  /// Methods for support type inquiry through isa, cast, and dyn_cast.
181
0
  static bool kindof(unsigned kind) { return kind == StandardTypes::Integer; }
182
183
  /// Integer representation maximal bitwidth.
184
  static constexpr unsigned kMaxWidth = 4096;
185
};
186
187
//===----------------------------------------------------------------------===//
188
// FloatType
189
//===----------------------------------------------------------------------===//
190
191
/// FloatType covers the standard floating point kinds BF16, F16, F32 and F64,
/// i.e. every kind in [FIRST_FLOATING_POINT_TYPE, LAST_FLOATING_POINT_TYPE].
class FloatType : public Type::TypeBase<FloatType, Type> {
public:
  using Base::Base;

  /// Returns the floating point type of the given `kind` within `context`.
  static FloatType get(StandardTypes::Kind kind, MLIRContext *context);

  /// Shorthand factories, one per floating point kind.
  static FloatType getBF16(MLIRContext *context) {
    return FloatType::get(StandardTypes::BF16, context);
  }
  static FloatType getF16(MLIRContext *context) {
    return FloatType::get(StandardTypes::F16, context);
  }
  static FloatType getF32(MLIRContext *context) {
    return FloatType::get(StandardTypes::F32, context);
  }
  static FloatType getF64(MLIRContext *context) {
    return FloatType::get(StandardTypes::F64, context);
  }

  /// LLVM-style RTTI hook: true for any kind inside the floating point range.
  static bool kindof(unsigned kind) {
    if (kind < StandardTypes::FIRST_FLOATING_POINT_TYPE)
      return false;
    return kind <= StandardTypes::LAST_FLOATING_POINT_TYPE;
  }

  /// Returns the width, in bits, of this float type.
  unsigned getWidth();

  /// Returns the LLVM floating point semantics corresponding to this type.
  const llvm::fltSemantics &getFloatSemantics();
};
223
224
//===----------------------------------------------------------------------===//
225
// NoneType
226
//===----------------------------------------------------------------------===//
227
228
/// NoneType is a unit type, i.e. a type with exactly one possible value, where
229
/// its value does not have a defined dynamic representation.
230
class NoneType : public Type::TypeBase<NoneType, Type> {
231
public:
232
  using Base::Base;
233
234
  /// Get an instance of the NoneType.
235
  static NoneType get(MLIRContext *context);
236
237
0
  static bool kindof(unsigned kind) { return kind == StandardTypes::None; }
238
};
239
240
//===----------------------------------------------------------------------===//
241
// ShapedType
242
//===----------------------------------------------------------------------===//
243
244
/// This is a common base class between Vector, UnrankedTensor, RankedTensor,
245
/// and MemRef types because they share behavior and semantics around shape,
246
/// rank, and fixed element type. Any type with these semantics should inherit
247
/// from ShapedType.
248
class ShapedType : public Type {
249
public:
250
  using ImplType = detail::ShapedTypeStorage;
251
  using Type::Type;
252
253
  // TODO(ntv): merge these two special values in a single one used everywhere.
254
  // Unfortunately, uses of `-1` have crept deep into the codebase now and are
255
  // hard to track.
256
  static constexpr int64_t kDynamicSize = -1;
257
  static constexpr int64_t kDynamicStrideOrOffset =
258
      std::numeric_limits<int64_t>::min();
259
260
  /// Return the element type.
261
  Type getElementType() const;
262
263
  /// If an element type is an integer or a float, return its width. Otherwise,
264
  /// abort.
265
  unsigned getElementTypeBitWidth() const;
266
267
  /// If it has static shape, return the number of elements. Otherwise, abort.
268
  int64_t getNumElements() const;
269
270
  /// If this is a ranked type, return the rank. Otherwise, abort.
271
  int64_t getRank() const;
272
273
  /// Whether or not this is a ranked type. Memrefs, vectors and ranked tensors
274
  /// have a rank, while unranked tensors do not.
275
  bool hasRank() const;
276
277
  /// If this is a ranked type, return the shape. Otherwise, abort.
278
  ArrayRef<int64_t> getShape() const;
279
280
  /// If this is unranked type or any dimension has unknown size (<0), it
281
  /// doesn't have static shape. If all dimensions have known size (>= 0), it
282
  /// has static shape.
283
  bool hasStaticShape() const;
284
285
  /// If this has a static shape and the shape is equal to `shape` return true.
286
  bool hasStaticShape(ArrayRef<int64_t> shape) const;
287
288
  /// If this is a ranked type, return the number of dimensions with dynamic
289
  /// size. Otherwise, abort.
290
  int64_t getNumDynamicDims() const;
291
292
  /// If this is ranked type, return the size of the specified dimension.
293
  /// Otherwise, abort.
294
  int64_t getDimSize(unsigned idx) const;
295
296
  /// Returns true if this dimension has a dynamic size (for ranked types);
297
  /// aborts for unranked types.
298
  bool isDynamicDim(unsigned idx) const;
299
300
  /// Returns the position of the dynamic dimension relative to just the dynamic
301
  /// dimensions, given its `index` within the shape.
302
  unsigned getDynamicDimIndex(unsigned index) const;
303
304
  /// Get the total amount of bits occupied by a value of this type.  This does
305
  /// not take into account any memory layout or widening constraints, e.g. a
306
  /// vector<3xi57> is reported to occupy 3x57=171 bit, even though in practice
307
  /// it will likely be stored as in a 4xi64 vector register.  Fail an assertion
308
  /// if the size cannot be computed statically, i.e. if the type has a dynamic
309
  /// shape or if its elemental type does not have a known bit width.
310
  int64_t getSizeInBits() const;
311
312
  /// Methods for support type inquiry through isa, cast, and dyn_cast.
313
0
  static bool classof(Type type) {
314
0
    return type.getKind() == StandardTypes::Vector ||
315
0
           type.getKind() == StandardTypes::RankedTensor ||
316
0
           type.getKind() == StandardTypes::UnrankedTensor ||
317
0
           type.getKind() == StandardTypes::UnrankedMemRef ||
318
0
           type.getKind() == StandardTypes::MemRef;
319
0
  }
320
321
  /// Whether the given dimension size indicates a dynamic dimension.
322
0
  static constexpr bool isDynamic(int64_t dSize) {
323
0
    return dSize == kDynamicSize;
324
0
  }
325
0
  static constexpr bool isDynamicStrideOrOffset(int64_t dStrideOrOffset) {
326
0
    return dStrideOrOffset == kDynamicStrideOrOffset;
327
0
  }
328
};
329
330
//===----------------------------------------------------------------------===//
331
// VectorType
332
//===----------------------------------------------------------------------===//
333
334
/// Vector types represent multi-dimensional SIMD vectors, and have a fixed
335
/// known constant shape with one or more dimension.
336
class VectorType
337
    : public Type::TypeBase<VectorType, ShapedType, detail::VectorTypeStorage> {
338
public:
339
  using Base::Base;
340
341
  /// Get or create a new VectorType of the provided shape and element type.
342
  /// Assumes the arguments define a well-formed VectorType.
343
  static VectorType get(ArrayRef<int64_t> shape, Type elementType);
344
345
  /// Get or create a new VectorType of the provided shape and element type
346
  /// declared at the given, potentially unknown, location.  If the VectorType
347
  /// defined by the arguments would be ill-formed, emit errors and return
348
  /// nullptr-wrapping type.
349
  static VectorType getChecked(ArrayRef<int64_t> shape, Type elementType,
350
                               Location location);
351
352
  /// Verify the construction of a vector type.
353
  static LogicalResult verifyConstructionInvariants(Location loc,
354
                                                    ArrayRef<int64_t> shape,
355
                                                    Type elementType);
356
357
  /// Returns true of the given type can be used as an element of a vector type.
358
  /// In particular, vectors can consist of integer or float primitives.
359
0
  static bool isValidElementType(Type t) {
360
0
    return t.isa<IntegerType>() || t.isa<FloatType>();
361
0
  }
362
363
  ArrayRef<int64_t> getShape() const;
364
365
  /// Methods for support type inquiry through isa, cast, and dyn_cast.
366
0
  static bool kindof(unsigned kind) { return kind == StandardTypes::Vector; }
367
};
368
369
//===----------------------------------------------------------------------===//
370
// TensorType
371
//===----------------------------------------------------------------------===//
372
373
/// Tensor types represent multi-dimensional arrays, and have two variants:
374
/// RankedTensorType and UnrankedTensorType.
375
class TensorType : public ShapedType {
376
public:
377
  using ShapedType::ShapedType;
378
379
  /// Return true if the specified element type is ok in a tensor.
380
0
  static bool isValidElementType(Type type) {
381
0
    // Note: Non standard/builtin types are allowed to exist within tensor
382
0
    // types. Dialects are expected to verify that tensor types have a valid
383
0
    // element type within that dialect.
384
0
    return type.isa<ComplexType>() || type.isa<FloatType>() ||
385
0
           type.isa<IntegerType>() || type.isa<OpaqueType>() ||
386
0
           type.isa<VectorType>() || type.isa<IndexType>() ||
387
0
           (type.getKind() > Type::Kind::LAST_STANDARD_TYPE);
388
0
  }
389
390
  /// Methods for support type inquiry through isa, cast, and dyn_cast.
391
0
  static bool classof(Type type) {
392
0
    return type.getKind() == StandardTypes::RankedTensor ||
393
0
           type.getKind() == StandardTypes::UnrankedTensor;
394
0
  }
395
};
396
397
//===----------------------------------------------------------------------===//
398
// RankedTensorType
399
400
/// Ranked tensor types represent multi-dimensional arrays that have a shape
401
/// with a fixed number of dimensions. Each shape element can be a non-negative
402
/// integer or unknown (represented by -1).
403
class RankedTensorType
404
    : public Type::TypeBase<RankedTensorType, TensorType,
405
                            detail::RankedTensorTypeStorage> {
406
public:
407
  using Base::Base;
408
409
  /// Get or create a new RankedTensorType of the provided shape and element
410
  /// type. Assumes the arguments define a well-formed type.
411
  static RankedTensorType get(ArrayRef<int64_t> shape, Type elementType);
412
413
  /// Get or create a new RankedTensorType of the provided shape and element
414
  /// type declared at the given, potentially unknown, location.  If the
415
  /// RankedTensorType defined by the arguments would be ill-formed, emit errors
416
  /// and return a nullptr-wrapping type.
417
  static RankedTensorType getChecked(ArrayRef<int64_t> shape, Type elementType,
418
                                     Location location);
419
420
  /// Verify the construction of a ranked tensor type.
421
  static LogicalResult verifyConstructionInvariants(Location loc,
422
                                                    ArrayRef<int64_t> shape,
423
                                                    Type elementType);
424
425
  ArrayRef<int64_t> getShape() const;
426
427
0
  static bool kindof(unsigned kind) {
428
0
    return kind == StandardTypes::RankedTensor;
429
0
  }
430
};
431
432
//===----------------------------------------------------------------------===//
433
// UnrankedTensorType
434
435
/// Unranked tensor types represent multi-dimensional arrays that have an
436
/// unknown shape.
437
class UnrankedTensorType
438
    : public Type::TypeBase<UnrankedTensorType, TensorType,
439
                            detail::UnrankedTensorTypeStorage> {
440
public:
441
  using Base::Base;
442
443
  /// Get or create a new UnrankedTensorType of the provided shape and element
444
  /// type. Assumes the arguments define a well-formed type.
445
  static UnrankedTensorType get(Type elementType);
446
447
  /// Get or create a new UnrankedTensorType of the provided shape and element
448
  /// type declared at the given, potentially unknown, location.  If the
449
  /// UnrankedTensorType defined by the arguments would be ill-formed, emit
450
  /// errors and return a nullptr-wrapping type.
451
  static UnrankedTensorType getChecked(Type elementType, Location location);
452
453
  /// Verify the construction of a unranked tensor type.
454
  static LogicalResult verifyConstructionInvariants(Location loc,
455
                                                    Type elementType);
456
457
0
  ArrayRef<int64_t> getShape() const { return llvm::None; }
458
459
0
  static bool kindof(unsigned kind) {
460
0
    return kind == StandardTypes::UnrankedTensor;
461
0
  }
462
};
463
464
//===----------------------------------------------------------------------===//
465
// BaseMemRefType
466
//===----------------------------------------------------------------------===//
467
468
/// Base MemRef for Ranked and Unranked variants
469
class BaseMemRefType : public ShapedType {
470
public:
471
  using ShapedType::ShapedType;
472
473
  /// Methods for support type inquiry through isa, cast, and dyn_cast.
474
0
  static bool classof(Type type) {
475
0
    return type.getKind() == StandardTypes::MemRef ||
476
0
           type.getKind() == StandardTypes::UnrankedMemRef;
477
0
  }
478
};
479
480
//===----------------------------------------------------------------------===//
481
// MemRefType
482
483
/// MemRef types represent a region of memory that have a shape with a fixed
484
/// number of dimensions. Each shape element can be a non-negative integer or
485
/// unknown (represented by -1). MemRef types also have an affine map
486
/// composition, represented as an array AffineMap pointers.
487
class MemRefType : public Type::TypeBase<MemRefType, BaseMemRefType,
488
                                         detail::MemRefTypeStorage> {
489
public:
490
  /// This is a builder type that keeps local references to arguments. Arguments
491
  /// that are passed into the builder must out-live the builder.
492
  class Builder {
493
  public:
494
    // Build from another MemRefType.
495
    explicit Builder(MemRefType other)
496
        : shape(other.getShape()), elementType(other.getElementType()),
497
          affineMaps(other.getAffineMaps()),
498
0
          memorySpace(other.getMemorySpace()) {}
499
500
    // Build from scratch.
501
    Builder(ArrayRef<int64_t> shape, Type elementType)
502
0
        : shape(shape), elementType(elementType), affineMaps(), memorySpace(0) {
503
0
    }
504
505
0
    Builder &setShape(ArrayRef<int64_t> newShape) {
506
0
      shape = newShape;
507
0
      return *this;
508
0
    }
509
510
0
    Builder &setElementType(Type newElementType) {
511
0
      elementType = newElementType;
512
0
      return *this;
513
0
    }
514
515
0
    Builder &setAffineMaps(ArrayRef<AffineMap> newAffineMaps) {
516
0
      affineMaps = newAffineMaps;
517
0
      return *this;
518
0
    }
519
520
0
    Builder &setMemorySpace(unsigned newMemorySpace) {
521
0
      memorySpace = newMemorySpace;
522
0
      return *this;
523
0
    }
524
525
0
    operator MemRefType() {
526
0
      return MemRefType::get(shape, elementType, affineMaps, memorySpace);
527
0
    }
528
529
  private:
530
    ArrayRef<int64_t> shape;
531
    Type elementType;
532
    ArrayRef<AffineMap> affineMaps;
533
    unsigned memorySpace;
534
  };
535
536
  using Base::Base;
537
538
  /// Get or create a new MemRefType based on shape, element type, affine
539
  /// map composition, and memory space.  Assumes the arguments define a
540
  /// well-formed MemRef type.  Use getChecked to gracefully handle MemRefType
541
  /// construction failures.
542
  static MemRefType get(ArrayRef<int64_t> shape, Type elementType,
543
                        ArrayRef<AffineMap> affineMapComposition = {},
544
                        unsigned memorySpace = 0);
545
546
  /// Get or create a new MemRefType based on shape, element type, affine
547
  /// map composition, and memory space declared at the given location.
548
  /// If the location is unknown, the last argument should be an instance of
549
  /// UnknownLoc.  If the MemRefType defined by the arguments would be
550
  /// ill-formed, emits errors (to the handler registered with the context or to
551
  /// the error stream) and returns nullptr.
552
  static MemRefType getChecked(ArrayRef<int64_t> shape, Type elementType,
553
                               ArrayRef<AffineMap> affineMapComposition,
554
                               unsigned memorySpace, Location location);
555
556
  ArrayRef<int64_t> getShape() const;
557
558
  /// Returns an array of affine map pointers representing the memref affine
559
  /// map composition.
560
  ArrayRef<AffineMap> getAffineMaps() const;
561
562
  /// Returns the memory space in which data referred to by this memref resides.
563
  unsigned getMemorySpace() const;
564
565
  // TODO(ntv): merge these two special values in a single one used everywhere.
566
  // Unfortunately, uses of `-1` have crept deep into the codebase now and are
567
  // hard to track.
568
  static constexpr int64_t kDynamicSize = -1;
569
0
  static int64_t getDynamicStrideOrOffset() {
570
0
    return ShapedType::kDynamicStrideOrOffset;
571
0
  }
572
573
0
  static bool kindof(unsigned kind) { return kind == StandardTypes::MemRef; }
574
575
private:
576
  /// Get or create a new MemRefType defined by the arguments.  If the resulting
577
  /// type would be ill-formed, return nullptr.  If the location is provided,
578
  /// emit detailed error messages.
579
  static MemRefType getImpl(ArrayRef<int64_t> shape, Type elementType,
580
                            ArrayRef<AffineMap> affineMapComposition,
581
                            unsigned memorySpace, Optional<Location> location);
582
  using Base::getImpl;
583
};
584
585
//===----------------------------------------------------------------------===//
586
// UnrankedMemRefType
587
588
/// Unranked MemRef type represent multi-dimensional MemRefs that
589
/// have an unknown rank.
590
class UnrankedMemRefType
591
    : public Type::TypeBase<UnrankedMemRefType, BaseMemRefType,
592
                            detail::UnrankedMemRefTypeStorage> {
593
public:
594
  using Base::Base;
595
596
  /// Get or create a new UnrankedMemRefType of the provided element
597
  /// type and memory space
598
  static UnrankedMemRefType get(Type elementType, unsigned memorySpace);
599
600
  /// Get or create a new UnrankedMemRefType of the provided element
601
  /// type and memory space declared at the given, potentially unknown,
602
  /// location. If the UnrankedMemRefType defined by the arguments would be
603
  /// ill-formed, emit errors and return a nullptr-wrapping type.
604
  static UnrankedMemRefType getChecked(Type elementType, unsigned memorySpace,
605
                                       Location location);
606
607
  /// Verify the construction of a unranked memref type.
608
  static LogicalResult verifyConstructionInvariants(Location loc,
609
                                                    Type elementType,
610
                                                    unsigned memorySpace);
611
612
0
  ArrayRef<int64_t> getShape() const { return llvm::None; }
613
614
  /// Returns the memory space in which data referred to by this memref resides.
615
  unsigned getMemorySpace() const;
616
0
  static bool kindof(unsigned kind) {
617
0
    return kind == StandardTypes::UnrankedMemRef;
618
0
  }
619
};
620
621
//===----------------------------------------------------------------------===//
622
// TupleType
623
//===----------------------------------------------------------------------===//
624
625
/// Tuple types represent a collection of other types. Note: This type merely
626
/// provides a common mechanism for representing tuples in MLIR. It is up to
627
/// dialect authors to provides operations for manipulating them, e.g.
628
/// extract_tuple_element. When possible, users should prefer multi-result
629
/// operations in the place of tuples.
630
class TupleType
631
    : public Type::TypeBase<TupleType, Type, detail::TupleTypeStorage> {
632
public:
633
  using Base::Base;
634
635
  /// Get or create a new TupleType with the provided element types. Assumes the
636
  /// arguments define a well-formed type.
637
  static TupleType get(ArrayRef<Type> elementTypes, MLIRContext *context);
638
639
  /// Get or create an empty tuple type.
640
0
  static TupleType get(MLIRContext *context) { return get({}, context); }
641
642
  /// Return the elements types for this tuple.
643
  ArrayRef<Type> getTypes() const;
644
645
  /// Accumulate the types contained in this tuple and tuples nested within it.
646
  /// Note that this only flattens nested tuples, not any other container type,
647
  /// e.g. a tuple<i32, tensor<i32>, tuple<f32, tuple<i64>>> is flattened to
648
  /// (i32, tensor<i32>, f32, i64)
649
  void getFlattenedTypes(SmallVectorImpl<Type> &types);
650
651
  /// Return the number of held types.
652
  size_t size() const;
653
654
  /// Iterate over the held elements.
655
  using iterator = ArrayRef<Type>::iterator;
656
0
  iterator begin() const { return getTypes().begin(); }
657
0
  iterator end() const { return getTypes().end(); }
658
659
  /// Return the element type at index 'index'.
660
0
  Type getType(size_t index) const {
661
0
    assert(index < size() && "invalid index for tuple type");
662
0
    return getTypes()[index];
663
0
  }
664
665
0
  static bool kindof(unsigned kind) { return kind == StandardTypes::Tuple; }
666
};
667
668
//===----------------------------------------------------------------------===//
669
// Type Utilities
670
//===----------------------------------------------------------------------===//
671
672
/// Returns the strides of the MemRef if the layout map is in strided form.
673
/// MemRefs with layout maps in strided form include:
674
///   1. empty or identity layout map, in which case the stride information is
675
///      the canonical form computed from sizes;
676
///   2. single affine map layout of the form `K + k0 * d0 + ... kn * dn`,
677
///      where K and ki's are constants or symbols.
678
///
679
/// A stride specification is a list of integer values that are either static
680
/// or dynamic (encoded with getDynamicStrideOrOffset()). Strides encode the
681
/// distance in the number of elements between successive entries along a
682
/// particular dimension. For example, `memref<42x16xf32, (64 * d0 + d1)>`
683
/// specifies a view into a non-contiguous memory region of `42` by `16` `f32`
684
/// elements in which the distance between two consecutive elements along the
685
/// outer dimension is `1` and the distance between two consecutive elements
686
/// along the inner dimension is `64`.
687
///
688
/// If a simple strided form cannot be extracted from the composition of the
689
/// layout map, returns llvm::None.
690
///
691
/// The convention is that the strides for dimensions d0, .. dn appear in
692
/// order to make indexing intuitive into the result.
693
LogicalResult getStridesAndOffset(MemRefType t,
694
                                  SmallVectorImpl<int64_t> &strides,
695
                                  int64_t &offset);
696
LogicalResult getStridesAndOffset(MemRefType t,
697
                                  SmallVectorImpl<AffineExpr> &strides,
698
                                  AffineExpr &offset);
699
700
/// Given a list of strides (in which MemRefType::getDynamicStrideOrOffset()
701
/// represents a dynamic value), return the single result AffineMap which
702
/// represents the linearized strided layout map. Dimensions correspond to the
703
/// offset followed by the strides in order. Symbols are inserted for each
704
/// dynamic dimension in order. A stride cannot take value `0`.
705
///
706
/// Examples:
707
/// =========
708
///
709
///   1. For offset: 0 strides: ?, ?, 1 return
710
///         (i, j, k)[M, N]->(M * i + N * j + k)
711
///
712
///   2. For offset: 3 strides: 32, ?, 16 return
713
///         (i, j, k)[M]->(3 + 32 * i + M * j + 16 * k)
714
///
715
///   3. For offset: ? strides: ?, ?, ? return
716
///         (i, j, k)[off, M, N, P]->(off + M * i + N * j + P * k)
717
AffineMap makeStridedLinearLayoutMap(ArrayRef<int64_t> strides, int64_t offset,
718
                                     MLIRContext *context);
719
720
/// Return a version of `t` with identity layout if it can be determined
721
/// statically that the layout is the canonical contiguous strided layout.
722
/// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of
723
/// `t` with simplified layout.
724
MemRefType canonicalizeStridedLayout(MemRefType t);
725
726
/// Return a version of `t` with a layout that has all dynamic offset and
727
/// strides. This is used to erase the static layout.
728
MemRefType eraseStridedLayout(MemRefType t);
729
730
/// Given MemRef `sizes` that are either static or dynamic, returns the
731
/// canonical "contiguous" strides AffineExpr. Strides are multiplicative and
732
/// once a dynamic dimension is encountered, all canonical strides become
733
/// dynamic and need to be encoded with a different symbol.
734
/// For canonical strides expressions, the offset is always 0 and and fastest
735
/// varying stride is always `1`.
736
///
737
/// Examples:
738
///   - memref<3x4x5xf32> has canonical stride expression
739
///         `20*exprs[0] + 5*exprs[1] + exprs[2]`.
740
///   - memref<3x?x5xf32> has canonical stride expression
741
///         `s0*exprs[0] + 5*exprs[1] + exprs[2]`.
742
///   - memref<3x4x?xf32> has canonical stride expression
743
///         `s1*exprs[0] + s0*exprs[1] + exprs[2]`.
744
AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
745
                                          ArrayRef<AffineExpr> exprs,
746
                                          MLIRContext *context);
747
748
/// Return the result of makeCanonicalStrudedLayoutExpr for the common case
749
/// where `exprs` is {d0, d1, .., d_(sizes.size()-1)}
750
AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
751
                                          MLIRContext *context);
752
753
/// Return true if the layout for `t` is compatible with strided semantics.
754
bool isStrided(MemRefType t);
755
756
} // end namespace mlir
757
758
#endif // MLIR_IR_STANDARDTYPES_H