/home/arjun/llvm-project/mlir/include/mlir/IR/StandardTypes.h
Line | Count | Source (jump to first uncovered line) |
1 | | //===- StandardTypes.h - MLIR Standard Type Classes -------------*- C++ -*-===// |
2 | | // |
3 | | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | | // See https://llvm.org/LICENSE.txt for license information. |
5 | | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | | // |
7 | | //===----------------------------------------------------------------------===// |
8 | | |
9 | | #ifndef MLIR_IR_STANDARDTYPES_H |
10 | | #define MLIR_IR_STANDARDTYPES_H |
11 | | |
12 | | #include "mlir/IR/Types.h" |
13 | | |
14 | | namespace llvm { |
15 | | struct fltSemantics; |
16 | | } // namespace llvm |
17 | | |
18 | | namespace mlir { |
19 | | class AffineExpr; |
20 | | class AffineMap; |
21 | | class FloatType; |
22 | | class IndexType; |
23 | | class IntegerType; |
24 | | class Location; |
25 | | class MLIRContext; |
26 | | |
27 | | namespace detail { |
28 | | |
29 | | struct IntegerTypeStorage; |
30 | | struct ShapedTypeStorage; |
31 | | struct VectorTypeStorage; |
32 | | struct RankedTensorTypeStorage; |
33 | | struct UnrankedTensorTypeStorage; |
34 | | struct MemRefTypeStorage; |
35 | | struct UnrankedMemRefTypeStorage; |
36 | | struct ComplexTypeStorage; |
37 | | struct TupleTypeStorage; |
38 | | |
39 | | } // namespace detail |
40 | | |
41 | | namespace StandardTypes { |
42 | | enum Kind { |
43 | | // Floating point kinds; the range starts at Type::Kind::FIRST_STANDARD_TYPE. |
44 | | BF16 = Type::Kind::FIRST_STANDARD_TYPE, |
45 | | F16, |
46 | | F32, |
47 | | F64, |
48 | | FIRST_FLOATING_POINT_TYPE = BF16, |
49 | | LAST_FLOATING_POINT_TYPE = F64, |
50 | | |
51 | | // Target pointer sized integer, used (e.g.) in affine mappings. |
52 | | Index, |
53 | | |
54 | | // Derived types, followed by the unit type `None`. |
55 | | Integer, |
56 | | Vector, |
57 | | RankedTensor, |
58 | | UnrankedTensor, |
59 | | MemRef, |
60 | | UnrankedMemRef, |
61 | | Complex, |
62 | | Tuple, |
63 | | None, |
64 | | }; |
65 | | |
66 | | } // namespace StandardTypes |
67 | | |
68 | | //===----------------------------------------------------------------------===// |
69 | | // ComplexType |
70 | | //===----------------------------------------------------------------------===// |
71 | | |
72 | | /// The 'complex' type represents a complex number with a parameterized element |
73 | | /// type, which is composed of a real and imaginary value of that element type. |
74 | | /// |
75 | | /// The element must be a floating point or integer scalar type. |
76 | | /// |
77 | | class ComplexType
78 | | : public Type::TypeBase<ComplexType, Type, detail::ComplexTypeStorage> { |
79 | | public: |
80 | | using Base::Base; |
81 | | |
82 | | /// Get or create a ComplexType with the provided element type. |
83 | | static ComplexType get(Type elementType); |
84 | | |
85 | | /// Get or create a ComplexType with the provided element type. This emits |
86 | | /// an error at the specified location and returns null if the element type |
87 | | /// isn't supported. |
88 | | static ComplexType getChecked(Type elementType, Location location); |
89 | | |
90 | | /// Verify the construction of a complex type. |
91 | | static LogicalResult verifyConstructionInvariants(Location loc,
92 | | Type elementType); |
93 | | |
94 | | Type getElementType(); |
95 | | |
96 | 0 | static bool kindof(unsigned kind) { return kind == StandardTypes::Complex; } |
97 | | }; |
98 | | |
99 | | //===----------------------------------------------------------------------===// |
100 | | // IndexType |
101 | | //===----------------------------------------------------------------------===// |
102 | | |
103 | | /// Index is a special integer-like type with unknown platform-dependent bit |
104 | | /// width. |
105 | | class IndexType : public Type::TypeBase<IndexType, Type> { |
106 | | public: |
107 | | using Base::Base; |
108 | | |
109 | | /// Get the IndexType instance for the given context. |
110 | | static IndexType get(MLIRContext *context); |
111 | | |
112 | | /// Support method to enable LLVM-style type casting: true iff `kind` is StandardTypes::Index. |
113 | 0 | static bool kindof(unsigned kind) { return kind == StandardTypes::Index; } |
114 | | |
115 | | /// Storage bit width (64) used for IndexType by internal compiler data structures. |
116 | | static constexpr unsigned kInternalStorageBitWidth = 64; |
117 | | }; |
118 | | |
119 | | //===----------------------------------------------------------------------===// |
120 | | // IntegerType |
121 | | //===----------------------------------------------------------------------===// |
122 | | |
123 | | /// Integer types can have arbitrary bitwidth up to a large fixed limit. |
124 | | class IntegerType
125 | | : public Type::TypeBase<IntegerType, Type, detail::IntegerTypeStorage> { |
126 | | public: |
127 | | using Base::Base; |
128 | | |
129 | | /// Signedness semantics. |
130 | | enum SignednessSemantics { |
131 | | Signless, ///< No signedness semantics |
132 | | Signed, ///< Signed integer |
133 | | Unsigned, ///< Unsigned integer |
134 | | }; |
135 | | |
136 | | /// Get or create a new IntegerType of the given width within the context. |
137 | | /// The created IntegerType is signless (i.e., no signedness semantics). |
138 | | /// Assume the width is within the allowed range and assert on failures. Use |
139 | | /// getChecked to handle failures gracefully. |
140 | | static IntegerType get(unsigned width, MLIRContext *context); |
141 | | |
142 | | /// Get or create a new IntegerType of the given width within the context. |
143 | | /// The created IntegerType has signedness semantics as indicated via |
144 | | /// `signedness`. Assume the width is within the allowed range and assert on |
145 | | /// failures. Use getChecked to handle failures gracefully. |
146 | | static IntegerType get(unsigned width, SignednessSemantics signedness,
147 | | MLIRContext *context); |
148 | | |
149 | | /// Get or create a new IntegerType of the given width within the context, |
150 | | /// defined at the given, potentially unknown, location. The created |
151 | | /// IntegerType is signless (i.e., no signedness semantics). If the width is |
152 | | /// outside the allowed range, emit errors and return a null type. |
153 | | static IntegerType getChecked(unsigned width, Location location); |
154 | | |
155 | | /// Get or create a new IntegerType of the given width within the context, |
156 | | /// defined at the given, potentially unknown, location. The created |
157 | | /// IntegerType has signedness semantics as indicated via `signedness`. If the |
158 | | /// width is outside the allowed range, emit errors and return a null type. |
159 | | static IntegerType getChecked(unsigned width, SignednessSemantics signedness,
160 | | Location location); |
161 | | |
162 | | /// Verify the construction of an integer type. |
163 | | static LogicalResult
164 | | verifyConstructionInvariants(Location loc, unsigned width,
165 | | SignednessSemantics signedness); |
166 | | |
167 | | /// Return the bitwidth of this integer type. |
168 | | unsigned getWidth() const; |
169 | | |
170 | | /// Return the signedness semantics of this integer type. |
171 | | SignednessSemantics getSignedness() const; |
172 | | |
173 | | /// Return true if this is a signless integer type. |
174 | 0 | bool isSignless() const { return getSignedness() == Signless; } |
175 | | /// Return true if this is a signed integer type. |
176 | 0 | bool isSigned() const { return getSignedness() == Signed; } |
177 | | /// Return true if this is an unsigned integer type. |
178 | 0 | bool isUnsigned() const { return getSignedness() == Unsigned; } |
179 | | |
180 | | /// Methods to support type inquiry through isa, cast, and dyn_cast. |
181 | 0 | static bool kindof(unsigned kind) { return kind == StandardTypes::Integer; } |
182 | | |
183 | | /// Integer representation maximal bitwidth. |
184 | | static constexpr unsigned kMaxWidth = 4096; |
185 | | }; |
186 | | |
187 | | //===----------------------------------------------------------------------===// |
188 | | // FloatType |
189 | | //===----------------------------------------------------------------------===// |
190 | | |
191 | | class FloatType : public Type::TypeBase<FloatType, Type> { |
192 | | public: |
193 | | using Base::Base; |
194 | | |
195 | | static FloatType get(StandardTypes::Kind kind, MLIRContext *context); |
196 | | |
197 | | // Convenience factories: each forwards to get() with the matching kind. |
198 | 0 | static FloatType getBF16(MLIRContext *ctx) { |
199 | 0 | return get(StandardTypes::BF16, ctx); |
200 | 0 | } |
201 | 0 | static FloatType getF16(MLIRContext *ctx) { |
202 | 0 | return get(StandardTypes::F16, ctx); |
203 | 0 | } |
204 | 0 | static FloatType getF32(MLIRContext *ctx) { |
205 | 0 | return get(StandardTypes::F32, ctx); |
206 | 0 | } |
207 | 0 | static FloatType getF64(MLIRContext *ctx) { |
208 | 0 | return get(StandardTypes::F64, ctx); |
209 | 0 | } |
210 | | |
211 | | /// Methods to support type inquiry through isa, cast, and dyn_cast. Accepts any kind in [FIRST_FLOATING_POINT_TYPE, LAST_FLOATING_POINT_TYPE]. |
212 | 0 | static bool kindof(unsigned kind) { |
213 | 0 | return kind >= StandardTypes::FIRST_FLOATING_POINT_TYPE &&
214 | 0 | kind <= StandardTypes::LAST_FLOATING_POINT_TYPE; |
215 | 0 | } |
216 | | |
217 | | /// Return the bitwidth of this float type. |
218 | | unsigned getWidth(); |
219 | | |
220 | | /// Return the floating point semantics of this float type. |
221 | | const llvm::fltSemantics &getFloatSemantics(); |
222 | | }; |
223 | | |
224 | | //===----------------------------------------------------------------------===// |
225 | | // NoneType |
226 | | //===----------------------------------------------------------------------===// |
227 | | |
228 | | /// NoneType is a unit type, i.e. a type with exactly one possible value, where |
229 | | /// its value does not have a defined dynamic representation. |
230 | | class NoneType : public Type::TypeBase<NoneType, Type> { |
231 | | public: |
232 | | using Base::Base; |
233 | | |
234 | | /// Get the NoneType instance for the given context. |
235 | | static NoneType get(MLIRContext *context); |
236 | | |
237 | 0 | static bool kindof(unsigned kind) { return kind == StandardTypes::None; } |
238 | | }; |
239 | | |
240 | | //===----------------------------------------------------------------------===// |
241 | | // ShapedType |
242 | | //===----------------------------------------------------------------------===// |
243 | | |
244 | | /// This is a common base class between Vector, UnrankedTensor, RankedTensor, |
245 | | /// MemRef, and UnrankedMemRef types because they share behavior and semantics |
246 | | /// around shape, rank, and fixed element type. Any type with these semantics |
247 | | /// should inherit from ShapedType. |
248 | | class ShapedType : public Type { |
249 | | public: |
250 | | using ImplType = detail::ShapedTypeStorage; |
251 | | using Type::Type; |
252 | | |
253 | | // TODO(ntv): merge these two special values in a single one used everywhere. |
254 | | // Unfortunately, uses of `-1` have crept deep into the codebase now and are |
255 | | // hard to track. NOTE(review): std::numeric_limits below needs <limits>, which this header does not include directly. |
256 | | static constexpr int64_t kDynamicSize = -1; |
257 | | static constexpr int64_t kDynamicStrideOrOffset =
258 | | std::numeric_limits<int64_t>::min(); |
259 | | |
260 | | /// Return the element type. |
261 | | Type getElementType() const; |
262 | | |
263 | | /// If an element type is an integer or a float, return its width. Otherwise, |
264 | | /// abort. |
265 | | unsigned getElementTypeBitWidth() const; |
266 | | |
267 | | /// If it has static shape, return the number of elements. Otherwise, abort. |
268 | | int64_t getNumElements() const; |
269 | | |
270 | | /// If this is a ranked type, return the rank. Otherwise, abort. |
271 | | int64_t getRank() const; |
272 | | |
273 | | /// Whether or not this is a ranked type. Ranked memrefs, vectors and ranked tensors |
274 | | /// have a rank, while unranked tensors (and presumably unranked memrefs; confirm) do not. |
275 | | bool hasRank() const; |
276 | | |
277 | | /// If this is a ranked type, return the shape. Otherwise, abort. |
278 | | ArrayRef<int64_t> getShape() const; |
279 | | |
280 | | /// If this is an unranked type or any dimension has unknown size (<0), it |
281 | | /// doesn't have static shape. If all dimensions have known size (>= 0), it |
282 | | /// has static shape. |
283 | | bool hasStaticShape() const; |
284 | | |
285 | | /// If this has a static shape and the shape is equal to `shape` return true. |
286 | | bool hasStaticShape(ArrayRef<int64_t> shape) const; |
287 | | |
288 | | /// If this is a ranked type, return the number of dimensions with dynamic |
289 | | /// size. Otherwise, abort. |
290 | | int64_t getNumDynamicDims() const; |
291 | | |
292 | | /// If this is a ranked type, return the size of the specified dimension. |
293 | | /// Otherwise, abort. |
294 | | int64_t getDimSize(unsigned idx) const; |
295 | | |
296 | | /// Returns true if this dimension has a dynamic size (for ranked types); |
297 | | /// aborts for unranked types. |
298 | | bool isDynamicDim(unsigned idx) const; |
299 | | |
300 | | /// Returns the position of the dynamic dimension relative to just the dynamic |
301 | | /// dimensions, given its `index` within the shape. |
302 | | unsigned getDynamicDimIndex(unsigned index) const; |
303 | | |
304 | | /// Get the total amount of bits occupied by a value of this type. This does |
305 | | /// not take into account any memory layout or widening constraints, e.g. a |
306 | | /// vector<3xi57> is reported to occupy 3x57=171 bits, even though in practice |
307 | | /// it will likely be stored in a 4xi64 vector register. Fail an assertion |
308 | | /// if the size cannot be computed statically, i.e. if the type has a dynamic |
309 | | /// shape or if its elemental type does not have a known bit width. |
310 | | int64_t getSizeInBits() const; |
311 | | |
312 | | /// Methods to support type inquiry through isa, cast, and dyn_cast. |
313 | 0 | static bool classof(Type type) { |
314 | 0 | return type.getKind() == StandardTypes::Vector ||
315 | 0 | type.getKind() == StandardTypes::RankedTensor ||
316 | 0 | type.getKind() == StandardTypes::UnrankedTensor ||
317 | 0 | type.getKind() == StandardTypes::UnrankedMemRef ||
318 | 0 | type.getKind() == StandardTypes::MemRef; |
319 | 0 | } |
320 | | |
321 | | /// Whether the given dimension size indicates a dynamic dimension (equals kDynamicSize). |
322 | 0 | static constexpr bool isDynamic(int64_t dSize) { |
323 | 0 | return dSize == kDynamicSize; |
324 | 0 | } |
325 | 0 | static constexpr bool isDynamicStrideOrOffset(int64_t dStrideOrOffset) { |
326 | 0 | return dStrideOrOffset == kDynamicStrideOrOffset; |
327 | 0 | } |
328 | | }; |
329 | | |
330 | | //===----------------------------------------------------------------------===// |
331 | | // VectorType |
332 | | //===----------------------------------------------------------------------===// |
333 | | |
334 | | /// Vector types represent multi-dimensional SIMD vectors, and have a fixed |
335 | | /// known constant shape with one or more dimensions. |
336 | | class VectorType
337 | | : public Type::TypeBase<VectorType, ShapedType, detail::VectorTypeStorage> { |
338 | | public: |
339 | | using Base::Base; |
340 | | |
341 | | /// Get or create a new VectorType of the provided shape and element type. |
342 | | /// Assumes the arguments define a well-formed VectorType. |
343 | | static VectorType get(ArrayRef<int64_t> shape, Type elementType); |
344 | | |
345 | | /// Get or create a new VectorType of the provided shape and element type |
346 | | /// declared at the given, potentially unknown, location. If the VectorType |
347 | | /// defined by the arguments would be ill-formed, emit errors and return |
348 | | /// a nullptr-wrapping type. |
349 | | static VectorType getChecked(ArrayRef<int64_t> shape, Type elementType,
350 | | Location location); |
351 | | |
352 | | /// Verify the construction of a vector type. |
353 | | static LogicalResult verifyConstructionInvariants(Location loc,
354 | | ArrayRef<int64_t> shape,
355 | | Type elementType); |
356 | | |
357 | | /// Returns true if the given type can be used as an element of a vector type. |
358 | | /// In particular, vectors can consist of integer or float primitives. |
359 | 0 | static bool isValidElementType(Type t) { |
360 | 0 | return t.isa<IntegerType>() || t.isa<FloatType>(); |
361 | 0 | } |
362 | | |
363 | | ArrayRef<int64_t> getShape() const; |
364 | | |
365 | | /// Methods to support type inquiry through isa, cast, and dyn_cast. |
366 | 0 | static bool kindof(unsigned kind) { return kind == StandardTypes::Vector; } |
367 | | }; |
368 | | |
369 | | //===----------------------------------------------------------------------===// |
370 | | // TensorType |
371 | | //===----------------------------------------------------------------------===// |
372 | | |
373 | | /// Tensor types represent multi-dimensional arrays, and have two variants: |
374 | | /// RankedTensorType and UnrankedTensorType. |
375 | | class TensorType : public ShapedType { |
376 | | public: |
377 | | using ShapedType::ShapedType; |
378 | | |
379 | | /// Return true if the specified element type is valid in a tensor. |
380 | 0 | static bool isValidElementType(Type type) { |
381 | 0 | // Note: Non standard/builtin types are allowed to exist within tensor |
382 | 0 | // types. Dialects are expected to verify that tensor types have a valid |
383 | 0 | // element type within that dialect. |
384 | 0 | return type.isa<ComplexType>() || type.isa<FloatType>() ||
385 | 0 | type.isa<IntegerType>() || type.isa<OpaqueType>() ||
386 | 0 | type.isa<VectorType>() || type.isa<IndexType>() ||
387 | 0 | (type.getKind() > Type::Kind::LAST_STANDARD_TYPE); |
388 | 0 | } |
389 | | |
390 | | /// Methods to support type inquiry through isa, cast, and dyn_cast. |
391 | 0 | static bool classof(Type type) { |
392 | 0 | return type.getKind() == StandardTypes::RankedTensor ||
393 | 0 | type.getKind() == StandardTypes::UnrankedTensor; |
394 | 0 | } |
395 | | }; |
396 | | |
397 | | //===----------------------------------------------------------------------===// |
398 | | // RankedTensorType |
399 | | |
400 | | /// Ranked tensor types represent multi-dimensional arrays that have a shape |
401 | | /// with a fixed number of dimensions. Each shape element can be a non-negative |
402 | | /// integer or unknown (represented by -1, i.e. ShapedType::kDynamicSize). |
403 | | class RankedTensorType
404 | | : public Type::TypeBase<RankedTensorType, TensorType,
405 | | detail::RankedTensorTypeStorage> { |
406 | | public: |
407 | | using Base::Base; |
408 | | |
409 | | /// Get or create a new RankedTensorType of the provided shape and element |
410 | | /// type. Assumes the arguments define a well-formed type. |
411 | | static RankedTensorType get(ArrayRef<int64_t> shape, Type elementType); |
412 | | |
413 | | /// Get or create a new RankedTensorType of the provided shape and element |
414 | | /// type declared at the given, potentially unknown, location. If the |
415 | | /// RankedTensorType defined by the arguments would be ill-formed, emit errors |
416 | | /// and return a nullptr-wrapping type. |
417 | | static RankedTensorType getChecked(ArrayRef<int64_t> shape, Type elementType,
418 | | Location location); |
419 | | |
420 | | /// Verify the construction of a ranked tensor type. |
421 | | static LogicalResult verifyConstructionInvariants(Location loc,
422 | | ArrayRef<int64_t> shape,
423 | | Type elementType); |
424 | | |
425 | | ArrayRef<int64_t> getShape() const; |
426 | | |
427 | 0 | static bool kindof(unsigned kind) { |
428 | 0 | return kind == StandardTypes::RankedTensor; |
429 | 0 | } |
430 | | }; |
431 | | |
432 | | //===----------------------------------------------------------------------===// |
433 | | // UnrankedTensorType |
434 | | |
435 | | /// Unranked tensor types represent multi-dimensional arrays that have an |
436 | | /// unknown shape. |
437 | | class UnrankedTensorType
438 | | : public Type::TypeBase<UnrankedTensorType, TensorType,
439 | | detail::UnrankedTensorTypeStorage> { |
440 | | public: |
441 | | using Base::Base; |
442 | | |
443 | | /// Get or create a new UnrankedTensorType of the provided element type. |
444 | | /// Assumes the argument defines a well-formed type. |
445 | | static UnrankedTensorType get(Type elementType); |
446 | | |
447 | | /// Get or create a new UnrankedTensorType of the provided element type |
448 | | /// declared at the given, potentially unknown, location. If the |
449 | | /// UnrankedTensorType defined by the arguments would be ill-formed, emit |
450 | | /// errors and return a nullptr-wrapping type. |
451 | | static UnrankedTensorType getChecked(Type elementType, Location location); |
452 | | |
453 | | /// Verify the construction of an unranked tensor type. |
454 | | static LogicalResult verifyConstructionInvariants(Location loc,
455 | | Type elementType); |
456 | | |
457 | 0 | ArrayRef<int64_t> getShape() const { return llvm::None; } |
458 | | |
459 | 0 | static bool kindof(unsigned kind) { |
460 | 0 | return kind == StandardTypes::UnrankedTensor; |
461 | 0 | } |
462 | | }; |
463 | | |
464 | | //===----------------------------------------------------------------------===// |
465 | | // BaseMemRefType |
466 | | //===----------------------------------------------------------------------===// |
467 | | |
468 | | /// Common base class for the ranked (MemRefType) and unranked (UnrankedMemRefType) memref variants. |
469 | | class BaseMemRefType : public ShapedType { |
470 | | public: |
471 | | using ShapedType::ShapedType; |
472 | | |
473 | | /// Methods to support type inquiry through isa, cast, and dyn_cast. |
474 | 0 | static bool classof(Type type) { |
475 | 0 | return type.getKind() == StandardTypes::MemRef ||
476 | 0 | type.getKind() == StandardTypes::UnrankedMemRef; |
477 | 0 | } |
478 | | }; |
479 | | |
480 | | //===----------------------------------------------------------------------===// |
481 | | // MemRefType |
482 | | |
483 | | /// MemRef types represent a region of memory that has a shape with a fixed |
484 | | /// number of dimensions. Each shape element can be a non-negative integer or |
485 | | /// unknown (represented by -1). MemRef types also have an affine map |
486 | | /// composition, represented as an array of AffineMaps. |
487 | | class MemRefType : public Type::TypeBase<MemRefType, BaseMemRefType,
488 | | detail::MemRefTypeStorage> { |
489 | | public: |
490 | | /// This is a builder type that keeps local references to arguments (it stores |
491 | | /// `ArrayRef`s, not copies). Arguments passed into the builder must outlive the builder. |
492 | | class Builder { |
493 | | public: |
494 | | // Build from another MemRefType. |
495 | | explicit Builder(MemRefType other)
496 | | : shape(other.getShape()), elementType(other.getElementType()),
497 | | affineMaps(other.getAffineMaps()),
498 | 0 | memorySpace(other.getMemorySpace()) {} |
499 | | |
500 | | // Build from scratch: no affine maps, memory space 0. |
501 | | Builder(ArrayRef<int64_t> shape, Type elementType)
502 | 0 | : shape(shape), elementType(elementType), affineMaps(), memorySpace(0) { |
503 | 0 | } |
504 | | |
505 | 0 | Builder &setShape(ArrayRef<int64_t> newShape) { |
506 | 0 | shape = newShape; |
507 | 0 | return *this; |
508 | 0 | } |
509 | | |
510 | 0 | Builder &setElementType(Type newElementType) { |
511 | 0 | elementType = newElementType; |
512 | 0 | return *this; |
513 | 0 | } |
514 | | |
515 | 0 | Builder &setAffineMaps(ArrayRef<AffineMap> newAffineMaps) { |
516 | 0 | affineMaps = newAffineMaps; |
517 | 0 | return *this; |
518 | 0 | } |
519 | | |
520 | 0 | Builder &setMemorySpace(unsigned newMemorySpace) { |
521 | 0 | memorySpace = newMemorySpace; |
522 | 0 | return *this; |
523 | 0 | } |
524 | | |
525 | 0 | operator MemRefType() { |
526 | 0 | return MemRefType::get(shape, elementType, affineMaps, memorySpace); |
527 | 0 | } |
528 | | |
529 | | private: |
530 | | ArrayRef<int64_t> shape; |
531 | | Type elementType; |
532 | | ArrayRef<AffineMap> affineMaps; |
533 | | unsigned memorySpace; |
534 | | }; |
535 | | |
536 | | using Base::Base; |
537 | | |
538 | | /// Get or create a new MemRefType based on shape, element type, affine |
539 | | /// map composition, and memory space. Assumes the arguments define a |
540 | | /// well-formed MemRef type. Use getChecked to gracefully handle MemRefType |
541 | | /// construction failures. |
542 | | static MemRefType get(ArrayRef<int64_t> shape, Type elementType,
543 | | ArrayRef<AffineMap> affineMapComposition = {},
544 | | unsigned memorySpace = 0); |
545 | | |
546 | | /// Get or create a new MemRefType based on shape, element type, affine |
547 | | /// map composition, and memory space declared at the given location. |
548 | | /// If the location is unknown, the last argument should be an instance of |
549 | | /// UnknownLoc. If the MemRefType defined by the arguments would be |
550 | | /// ill-formed, emits errors (to the handler registered with the context or to |
551 | | /// the error stream) and returns nullptr. |
552 | | static MemRefType getChecked(ArrayRef<int64_t> shape, Type elementType,
553 | | ArrayRef<AffineMap> affineMapComposition,
554 | | unsigned memorySpace, Location location); |
555 | | |
556 | | ArrayRef<int64_t> getShape() const; |
557 | | |
558 | | /// Returns the array of AffineMaps representing the memref affine |
559 | | /// map composition. |
560 | | ArrayRef<AffineMap> getAffineMaps() const; |
561 | | |
562 | | /// Returns the memory space in which data referred to by this memref resides. |
563 | | unsigned getMemorySpace() const; |
564 | | |
565 | | // TODO(ntv): merge these two special values in a single one used everywhere. |
566 | | // Unfortunately, uses of `-1` have crept deep into the codebase now and are |
567 | | // hard to track. Note this redeclares ShapedType::kDynamicSize with the same value (-1). |
568 | | static constexpr int64_t kDynamicSize = -1; |
569 | 0 | static int64_t getDynamicStrideOrOffset() { |
570 | 0 | return ShapedType::kDynamicStrideOrOffset; |
571 | 0 | } |
572 | | |
573 | 0 | static bool kindof(unsigned kind) { return kind == StandardTypes::MemRef; } |
574 | | |
575 | | private: |
576 | | /// Get or create a new MemRefType defined by the arguments. If the resulting |
577 | | /// type would be ill-formed, return nullptr. If the location is provided, |
578 | | /// emit detailed error messages. |
579 | | static MemRefType getImpl(ArrayRef<int64_t> shape, Type elementType,
580 | | ArrayRef<AffineMap> affineMapComposition,
581 | | unsigned memorySpace, Optional<Location> location); |
582 | | using Base::getImpl; |
583 | | }; |
584 | | |
585 | | //===----------------------------------------------------------------------===// |
586 | | // UnrankedMemRefType |
587 | | |
588 | | /// Unranked MemRef types represent multi-dimensional MemRefs that |
589 | | /// have an unknown rank. |
590 | | class UnrankedMemRefType
591 | | : public Type::TypeBase<UnrankedMemRefType, BaseMemRefType,
592 | | detail::UnrankedMemRefTypeStorage> { |
593 | | public: |
594 | | using Base::Base; |
595 | | |
596 | | /// Get or create a new UnrankedMemRefType of the provided element |
597 | | /// type and memory space. |
598 | | static UnrankedMemRefType get(Type elementType, unsigned memorySpace); |
599 | | |
600 | | /// Get or create a new UnrankedMemRefType of the provided element |
601 | | /// type and memory space declared at the given, potentially unknown, |
602 | | /// location. If the UnrankedMemRefType defined by the arguments would be |
603 | | /// ill-formed, emit errors and return a nullptr-wrapping type. |
604 | | static UnrankedMemRefType getChecked(Type elementType, unsigned memorySpace,
605 | | Location location); |
606 | | |
607 | | /// Verify the construction of an unranked memref type. |
608 | | static LogicalResult verifyConstructionInvariants(Location loc,
609 | | Type elementType,
610 | | unsigned memorySpace); |
611 | | |
612 | 0 | ArrayRef<int64_t> getShape() const { return llvm::None; } |
613 | | |
614 | | /// Returns the memory space in which data referred to by this memref resides. |
615 | | unsigned getMemorySpace() const; |
616 | 0 | static bool kindof(unsigned kind) { |
617 | 0 | return kind == StandardTypes::UnrankedMemRef; |
618 | 0 | } |
619 | | }; |
620 | | |
621 | | //===----------------------------------------------------------------------===// |
622 | | // TupleType |
623 | | //===----------------------------------------------------------------------===// |
624 | | |
625 | | /// Tuple types represent a collection of other types. Note: This type merely |
626 | | /// provides a common mechanism for representing tuples in MLIR. It is up to |
627 | | /// dialect authors to provide operations for manipulating them, e.g. |
628 | | /// extract_tuple_element. When possible, users should prefer multi-result |
629 | | /// operations in the place of tuples. |
630 | | class TupleType
631 | | : public Type::TypeBase<TupleType, Type, detail::TupleTypeStorage> { |
632 | | public: |
633 | | using Base::Base; |
634 | | |
635 | | /// Get or create a new TupleType with the provided element types. Assumes the |
636 | | /// arguments define a well-formed type. |
637 | | static TupleType get(ArrayRef<Type> elementTypes, MLIRContext *context); |
638 | | |
639 | | /// Get or create an empty tuple type. |
640 | 0 | static TupleType get(MLIRContext *context) { return get({}, context); } |
641 | | |
642 | | /// Return the element types for this tuple. |
643 | | ArrayRef<Type> getTypes() const; |
644 | | |
645 | | /// Accumulate the types contained in this tuple and tuples nested within it. |
646 | | /// Note that this only flattens nested tuples, not any other container type, |
647 | | /// e.g. a tuple<i32, tensor<i32>, tuple<f32, tuple<i64>>> is flattened to |
648 | | /// (i32, tensor<i32>, f32, i64) |
649 | | void getFlattenedTypes(SmallVectorImpl<Type> &types); |
650 | | |
651 | | /// Return the number of held types. |
652 | | size_t size() const; |
653 | | |
654 | | /// Iterate over the held elements. |
655 | | using iterator = ArrayRef<Type>::iterator; |
656 | 0 | iterator begin() const { return getTypes().begin(); } |
657 | 0 | iterator end() const { return getTypes().end(); } |
658 | | |
659 | | /// Return the element type at index 'index'; asserts that the index is in range. |
660 | 0 | Type getType(size_t index) const { |
661 | 0 | assert(index < size() && "invalid index for tuple type"); |
662 | 0 | return getTypes()[index]; |
663 | 0 | } |
664 | | |
665 | 0 | static bool kindof(unsigned kind) { return kind == StandardTypes::Tuple; } |
666 | | }; |
667 | | |
668 | | //===----------------------------------------------------------------------===// |
669 | | // Type Utilities |
670 | | //===----------------------------------------------------------------------===// |
671 | | |
672 | | /// Returns the strides of the MemRef if the layout map is in strided form. |
673 | | /// MemRefs with layout maps in strided form include: |
674 | | /// 1. empty or identity layout map, in which case the stride information is |
675 | | /// the canonical form computed from sizes; |
676 | | /// 2. single affine map layout of the form `K + k0 * d0 + ... + kn * dn`, |
677 | | /// where K and ki's are constants or symbols. |
678 | | /// |
679 | | /// A stride specification is a list of integer values that are either static |
680 | | /// or dynamic (encoded with MemRefType::getDynamicStrideOrOffset()). Strides encode the |
681 | | /// distance in the number of elements between successive entries along a |
682 | | /// particular dimension. For example, `memref<42x16xf32, (64 * d0 + d1)>` |
683 | | /// specifies a view into a non-contiguous memory region of `42` by `16` `f32` |
684 | | /// elements in which the distance between two consecutive elements along the |
685 | | /// outer dimension is `64` and the distance between two consecutive elements |
686 | | /// along the inner dimension is `1`. |
687 | | /// |
688 | | /// If a simple strided form cannot be extracted from the composition of the |
689 | | /// layout map, returns failure. |
690 | | /// |
691 | | /// The convention is that the strides for dimensions d0, .. dn appear in |
692 | | /// order to make indexing intuitive into the result. |
693 | | LogicalResult getStridesAndOffset(MemRefType t, |
694 | | SmallVectorImpl<int64_t> &strides, |
695 | | int64_t &offset); |
696 | | LogicalResult getStridesAndOffset(MemRefType t, |
697 | | SmallVectorImpl<AffineExpr> &strides, |
698 | | AffineExpr &offset); |
699 | | |
700 | | /// Given a list of strides (in which MemRefType::getDynamicStrideOrOffset() |
701 | | /// represents a dynamic value), return the single result AffineMap which |
702 | | /// represents the linearized strided layout map. There is one dimension per |
703 | | /// stride, in order. Symbols are inserted for each dynamic value in order: the |
704 | | /// offset first (if dynamic), then the dynamic strides. A stride cannot take value `0`. |
705 | | /// |
706 | | /// Examples: |
707 | | /// ========= |
708 | | /// |
709 | | /// 1. For offset: 0 strides: ?, ?, 1 return |
710 | | /// (i, j, k)[M, N]->(M * i + N * j + k) |
711 | | /// |
712 | | /// 2. For offset: 3 strides: 32, ?, 16 return |
713 | | /// (i, j, k)[M]->(3 + 32 * i + M * j + 16 * k) |
714 | | /// |
715 | | /// 3. For offset: ? strides: ?, ?, ? return |
716 | | /// (i, j, k)[off, M, N, P]->(off + M * i + N * j + P * k) |
717 | | AffineMap makeStridedLinearLayoutMap(ArrayRef<int64_t> strides, int64_t offset, |
718 | | MLIRContext *context); |
719 | | |
720 | | /// Return a version of `t` with identity layout if it can be determined |
721 | | /// statically that the layout is the canonical contiguous strided layout. |
722 | | /// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of |
723 | | /// `t` with simplified layout. |
724 | | MemRefType canonicalizeStridedLayout(MemRefType t); |
725 | | |
726 | | /// Return a version of `t` with a layout that has all dynamic offset and |
727 | | /// strides. This is used to erase the static layout. |
728 | | MemRefType eraseStridedLayout(MemRefType t); |
729 | | |
730 | | /// Given MemRef `sizes` that are either static or dynamic, returns the |
731 | | /// canonical "contiguous" strides AffineExpr. Strides are multiplicative and |
732 | | /// once a dynamic dimension is encountered, all canonical strides become |
733 | | /// dynamic and need to be encoded with a different symbol. |
734 | | /// For canonical strides expressions, the offset is always 0 and the fastest |
735 | | /// varying stride is always `1`. |
736 | | /// |
737 | | /// Examples: |
738 | | /// - memref<3x4x5xf32> has canonical stride expression |
739 | | /// `20*exprs[0] + 5*exprs[1] + exprs[2]`. |
740 | | /// - memref<3x?x5xf32> has canonical stride expression |
741 | | /// `s0*exprs[0] + 5*exprs[1] + exprs[2]`. |
742 | | /// - memref<3x4x?xf32> has canonical stride expression |
743 | | /// `s1*exprs[0] + s0*exprs[1] + exprs[2]`. |
744 | | AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes, |
745 | | ArrayRef<AffineExpr> exprs, |
746 | | MLIRContext *context); |
747 | | |
748 | | /// Return the result of makeCanonicalStridedLayoutExpr for the common case |
749 | | /// where `exprs` is {d0, d1, .., d_(sizes.size()-1)}. |
750 | | AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes, |
751 | | MLIRContext *context); |
752 | | |
753 | | /// Return true if the layout for `t` is compatible with strided semantics. |
754 | | bool isStrided(MemRefType t); |
755 | | |
756 | | } // end namespace mlir |
757 | | |
758 | | #endif // MLIR_IR_STANDARDTYPES_H |