/home/arjun/llvm-project/mlir/lib/IR/StandardTypes.cpp
Line | Count | Source (jump to first uncovered line) |
1 | | //===- StandardTypes.cpp - MLIR Standard Type Classes ---------------------===// |
2 | | // |
3 | | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | | // See https://llvm.org/LICENSE.txt for license information. |
5 | | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | | // |
7 | | //===----------------------------------------------------------------------===// |
8 | | |
9 | | #include "mlir/IR/StandardTypes.h" |
10 | | #include "TypeDetail.h" |
11 | | #include "mlir/IR/AffineExpr.h" |
12 | | #include "mlir/IR/AffineMap.h" |
13 | | #include "mlir/IR/Diagnostics.h" |
14 | | #include "llvm/ADT/APFloat.h" |
15 | | #include "llvm/ADT/Twine.h" |
16 | | |
17 | | using namespace mlir; |
18 | | using namespace mlir::detail; |
19 | | |
20 | | //===----------------------------------------------------------------------===// |
21 | | // Type |
22 | | //===----------------------------------------------------------------------===// |
23 | | |
24 | 0 | bool Type::isBF16() { return getKind() == StandardTypes::BF16; } |
25 | 0 | bool Type::isF16() { return getKind() == StandardTypes::F16; } |
26 | 0 | bool Type::isF32() { return getKind() == StandardTypes::F32; } |
27 | 0 | bool Type::isF64() { return getKind() == StandardTypes::F64; } |
28 | | |
29 | 0 | bool Type::isIndex() { return isa<IndexType>(); } |
30 | | |
31 | | /// Return true if this is an integer type with the specified width. |
32 | 0 | bool Type::isInteger(unsigned width) { |
33 | 0 | if (auto intTy = dyn_cast<IntegerType>()) |
34 | 0 | return intTy.getWidth() == width; |
35 | 0 | return false; |
36 | 0 | } |
37 | | |
38 | 0 | bool Type::isSignlessInteger() { |
39 | 0 | if (auto intTy = dyn_cast<IntegerType>()) |
40 | 0 | return intTy.isSignless(); |
41 | 0 | return false; |
42 | 0 | } |
43 | | |
44 | 0 | bool Type::isSignlessInteger(unsigned width) { |
45 | 0 | if (auto intTy = dyn_cast<IntegerType>()) |
46 | 0 | return intTy.isSignless() && intTy.getWidth() == width; |
47 | 0 | return false; |
48 | 0 | } |
49 | | |
50 | 0 | bool Type::isSignedInteger() { |
51 | 0 | if (auto intTy = dyn_cast<IntegerType>()) |
52 | 0 | return intTy.isSigned(); |
53 | 0 | return false; |
54 | 0 | } |
55 | | |
56 | 0 | bool Type::isSignedInteger(unsigned width) { |
57 | 0 | if (auto intTy = dyn_cast<IntegerType>()) |
58 | 0 | return intTy.isSigned() && intTy.getWidth() == width; |
59 | 0 | return false; |
60 | 0 | } |
61 | | |
62 | 0 | bool Type::isUnsignedInteger() { |
63 | 0 | if (auto intTy = dyn_cast<IntegerType>()) |
64 | 0 | return intTy.isUnsigned(); |
65 | 0 | return false; |
66 | 0 | } |
67 | | |
68 | 0 | bool Type::isUnsignedInteger(unsigned width) { |
69 | 0 | if (auto intTy = dyn_cast<IntegerType>()) |
70 | 0 | return intTy.isUnsigned() && intTy.getWidth() == width; |
71 | 0 | return false; |
72 | 0 | } |
73 | | |
74 | 0 | bool Type::isSignlessIntOrIndex() { |
75 | 0 | return isa<IndexType>() || isSignlessInteger(); |
76 | 0 | } |
77 | | |
78 | 0 | bool Type::isSignlessIntOrIndexOrFloat() { |
79 | 0 | return isa<IndexType>() || isSignlessInteger() || isa<FloatType>(); |
80 | 0 | } |
81 | | |
82 | 0 | bool Type::isSignlessIntOrFloat() { |
83 | 0 | return isSignlessInteger() || isa<FloatType>(); |
84 | 0 | } |
85 | | |
86 | 0 | bool Type::isIntOrIndex() { return isa<IntegerType>() || isIndex(); } |
87 | | |
88 | 0 | bool Type::isIntOrFloat() { return isa<IntegerType>() || isa<FloatType>(); } |
89 | | |
90 | 0 | bool Type::isIntOrIndexOrFloat() { return isIntOrFloat() || isIndex(); } |
91 | | |
92 | | //===----------------------------------------------------------------------===// |
93 | | /// ComplexType |
94 | | //===----------------------------------------------------------------------===// |
95 | | |
96 | 0 | ComplexType ComplexType::get(Type elementType) { |
97 | 0 | return Base::get(elementType.getContext(), StandardTypes::Complex, |
98 | 0 | elementType); |
99 | 0 | } |
100 | | |
101 | 0 | ComplexType ComplexType::getChecked(Type elementType, Location location) { |
102 | 0 | return Base::getChecked(location, StandardTypes::Complex, elementType); |
103 | 0 | } |
104 | | |
105 | | /// Verify the construction of a complex type. |
106 | | LogicalResult ComplexType::verifyConstructionInvariants(Location loc, |
107 | 0 | Type elementType) { |
108 | 0 | if (!elementType.isIntOrFloat()) |
109 | 0 | return emitError(loc, "invalid element type for complex"); |
110 | 0 | return success(); |
111 | 0 | } |
112 | | |
113 | 0 | Type ComplexType::getElementType() { return getImpl()->elementType; } |
114 | | |
115 | | //===----------------------------------------------------------------------===// |
116 | | // Integer Type |
117 | | //===----------------------------------------------------------------------===// |
118 | | |
119 | | // static constexpr must have a definition (until C++17 and inline variables). |
120 | | constexpr unsigned IntegerType::kMaxWidth; |
121 | | |
122 | | /// Verify the construction of an integer type. |
123 | | LogicalResult |
124 | | IntegerType::verifyConstructionInvariants(Location loc, unsigned width, |
125 | 0 | SignednessSemantics signedness) { |
126 | 0 | if (width > IntegerType::kMaxWidth) { |
127 | 0 | return emitError(loc) << "integer bitwidth is limited to " |
128 | 0 | << IntegerType::kMaxWidth << " bits"; |
129 | 0 | } |
130 | 0 | return success(); |
131 | 0 | } |
132 | | |
133 | 0 | unsigned IntegerType::getWidth() const { return getImpl()->getWidth(); } |
134 | | |
135 | 0 | IntegerType::SignednessSemantics IntegerType::getSignedness() const { |
136 | 0 | return getImpl()->getSignedness(); |
137 | 0 | } |
138 | | |
139 | | //===----------------------------------------------------------------------===// |
140 | | // Float Type |
141 | | //===----------------------------------------------------------------------===// |
142 | | |
143 | 0 | unsigned FloatType::getWidth() { |
144 | 0 | switch (getKind()) { |
145 | 0 | case StandardTypes::BF16: |
146 | 0 | case StandardTypes::F16: |
147 | 0 | return 16; |
148 | 0 | case StandardTypes::F32: |
149 | 0 | return 32; |
150 | 0 | case StandardTypes::F64: |
151 | 0 | return 64; |
152 | 0 | default: |
153 | 0 | llvm_unreachable("unexpected type"); |
154 | 0 | } |
155 | 0 | } |
156 | | |
157 | | /// Returns the floating semantics for the given type. |
158 | 0 | const llvm::fltSemantics &FloatType::getFloatSemantics() { |
159 | 0 | if (isBF16()) |
160 | 0 | // Treat BF16 like a double. This is unfortunate but BF16 fltSemantics is |
161 | 0 | // not defined in LLVM. |
162 | 0 | // TODO(jpienaar): add BF16 to LLVM? fltSemantics are internal to APFloat.cc |
163 | 0 | // else one could add it. |
164 | 0 | // static const fltSemantics semBF16 = {127, -126, 8, 16}; |
165 | 0 | return APFloat::IEEEdouble(); |
166 | 0 | if (isF16()) |
167 | 0 | return APFloat::IEEEhalf(); |
168 | 0 | if (isF32()) |
169 | 0 | return APFloat::IEEEsingle(); |
170 | 0 | if (isF64()) |
171 | 0 | return APFloat::IEEEdouble(); |
172 | 0 | llvm_unreachable("non-floating point type used"); |
173 | 0 | } |
174 | | |
175 | 0 | unsigned Type::getIntOrFloatBitWidth() { |
176 | 0 | assert(isIntOrFloat() && "only integers and floats have a bitwidth"); |
177 | 0 | if (auto intType = dyn_cast<IntegerType>()) |
178 | 0 | return intType.getWidth(); |
179 | 0 | return cast<FloatType>().getWidth(); |
180 | 0 | } |
181 | | |
182 | | //===----------------------------------------------------------------------===// |
183 | | // ShapedType |
184 | | //===----------------------------------------------------------------------===// |
185 | | constexpr int64_t ShapedType::kDynamicSize; |
186 | | constexpr int64_t ShapedType::kDynamicStrideOrOffset; |
187 | | |
188 | 0 | Type ShapedType::getElementType() const { |
189 | 0 | return static_cast<ImplType *>(impl)->elementType; |
190 | 0 | } |
191 | | |
192 | 0 | unsigned ShapedType::getElementTypeBitWidth() const { |
193 | 0 | return getElementType().getIntOrFloatBitWidth(); |
194 | 0 | } |
195 | | |
196 | | int64_t ShapedType::getNumElements() const { |
197 | | assert(hasStaticShape() && "cannot get element count of dynamic shaped type"); |
198 | | auto shape = getShape(); |
199 | | int64_t num = 1; |
200 | | for (auto dim : shape) |
201 | | num *= dim; |
202 | | return num; |
203 | | } |
204 | | |
205 | 0 | int64_t ShapedType::getRank() const { return getShape().size(); } |
206 | | |
207 | 0 | bool ShapedType::hasRank() const { return !isa<UnrankedTensorType>(); } |
208 | | |
209 | 0 | int64_t ShapedType::getDimSize(unsigned idx) const { |
210 | 0 | assert(idx < getRank() && "invalid index for shaped type"); |
211 | 0 | return getShape()[idx]; |
212 | 0 | } |
213 | | |
214 | 0 | bool ShapedType::isDynamicDim(unsigned idx) const { |
215 | 0 | assert(idx < getRank() && "invalid index for shaped type"); |
216 | 0 | return isDynamic(getShape()[idx]); |
217 | 0 | } |
218 | | |
219 | 0 | unsigned ShapedType::getDynamicDimIndex(unsigned index) const { |
220 | 0 | assert(index < getRank() && "invalid index"); |
221 | 0 | assert(ShapedType::isDynamic(getDimSize(index)) && "invalid index"); |
222 | 0 | return llvm::count_if(getShape().take_front(index), ShapedType::isDynamic); |
223 | 0 | } |
224 | | |
225 | | /// Get the number of bits required to store a value of the given shaped type. |
226 | | /// Compute the value recursively since tensors are allowed to have vectors as |
227 | | /// elements. |
228 | 0 | int64_t ShapedType::getSizeInBits() const { |
229 | 0 | assert(hasStaticShape() && |
230 | 0 | "cannot get the bit size of an aggregate with a dynamic shape"); |
231 | 0 |
|
232 | 0 | auto elementType = getElementType(); |
233 | 0 | if (elementType.isIntOrFloat()) |
234 | 0 | return elementType.getIntOrFloatBitWidth() * getNumElements(); |
235 | 0 | |
236 | 0 | // Tensors can have vectors and other tensors as elements, other shaped types |
237 | 0 | // cannot. |
238 | 0 | assert(isa<TensorType>() && "unsupported element type"); |
239 | 0 | assert((elementType.isa<VectorType>() || elementType.isa<TensorType>()) && |
240 | 0 | "unsupported tensor element type"); |
241 | 0 | return getNumElements() * elementType.cast<ShapedType>().getSizeInBits(); |
242 | 0 | } |
243 | | |
244 | 0 | ArrayRef<int64_t> ShapedType::getShape() const { |
245 | 0 | switch (getKind()) { |
246 | 0 | case StandardTypes::Vector: |
247 | 0 | return cast<VectorType>().getShape(); |
248 | 0 | case StandardTypes::RankedTensor: |
249 | 0 | return cast<RankedTensorType>().getShape(); |
250 | 0 | case StandardTypes::MemRef: |
251 | 0 | return cast<MemRefType>().getShape(); |
252 | 0 | default: |
253 | 0 | llvm_unreachable("not a ShapedType or not ranked"); |
254 | 0 | } |
255 | 0 | } |
256 | | |
257 | 0 | int64_t ShapedType::getNumDynamicDims() const { |
258 | 0 | return llvm::count_if(getShape(), isDynamic); |
259 | 0 | } |
260 | | |
261 | 0 | bool ShapedType::hasStaticShape() const { |
262 | 0 | return hasRank() && llvm::none_of(getShape(), isDynamic); |
263 | 0 | } |
264 | | |
265 | 0 | bool ShapedType::hasStaticShape(ArrayRef<int64_t> shape) const { |
266 | 0 | return hasStaticShape() && getShape() == shape; |
267 | 0 | } |
268 | | |
269 | | //===----------------------------------------------------------------------===// |
270 | | // VectorType |
271 | | //===----------------------------------------------------------------------===// |
272 | | |
273 | 0 | VectorType VectorType::get(ArrayRef<int64_t> shape, Type elementType) { |
274 | 0 | return Base::get(elementType.getContext(), StandardTypes::Vector, shape, |
275 | 0 | elementType); |
276 | 0 | } |
277 | | |
278 | | VectorType VectorType::getChecked(ArrayRef<int64_t> shape, Type elementType, |
279 | 0 | Location location) { |
280 | 0 | return Base::getChecked(location, StandardTypes::Vector, shape, elementType); |
281 | 0 | } |
282 | | |
283 | | LogicalResult VectorType::verifyConstructionInvariants(Location loc, |
284 | | ArrayRef<int64_t> shape, |
285 | 0 | Type elementType) { |
286 | 0 | if (shape.empty()) |
287 | 0 | return emitError(loc, "vector types must have at least one dimension"); |
288 | 0 | |
289 | 0 | if (!isValidElementType(elementType)) |
290 | 0 | return emitError(loc, "vector elements must be int or float type"); |
291 | 0 | |
292 | 0 | if (any_of(shape, [](int64_t i) { return i <= 0; })) |
293 | 0 | return emitError(loc, "vector types must have positive constant sizes"); |
294 | 0 | |
295 | 0 | return success(); |
296 | 0 | } |
297 | | |
298 | 0 | ArrayRef<int64_t> VectorType::getShape() const { return getImpl()->getShape(); } |
299 | | |
300 | | //===----------------------------------------------------------------------===// |
301 | | // TensorType |
302 | | //===----------------------------------------------------------------------===// |
303 | | |
304 | | // Check if "elementType" can be an element type of a tensor. If it cannot, |
305 | | // emit an error at `location` and return failure. |
306 | | static inline LogicalResult checkTensorElementType(Location location, |
307 | 0 | Type elementType) { |
308 | 0 | if (!TensorType::isValidElementType(elementType)) |
309 | 0 | return emitError(location, "invalid tensor element type"); |
310 | 0 | return success(); |
311 | 0 | } |
312 | | |
313 | | //===----------------------------------------------------------------------===// |
314 | | // RankedTensorType |
315 | | //===----------------------------------------------------------------------===// |
316 | | |
317 | | RankedTensorType RankedTensorType::get(ArrayRef<int64_t> shape, |
318 | 0 | Type elementType) { |
319 | 0 | return Base::get(elementType.getContext(), StandardTypes::RankedTensor, shape, |
320 | 0 | elementType); |
321 | 0 | } |
322 | | |
323 | | RankedTensorType RankedTensorType::getChecked(ArrayRef<int64_t> shape, |
324 | | Type elementType, |
325 | 0 | Location location) { |
326 | 0 | return Base::getChecked(location, StandardTypes::RankedTensor, shape, |
327 | 0 | elementType); |
328 | 0 | } |
329 | | |
330 | | LogicalResult RankedTensorType::verifyConstructionInvariants( |
331 | 0 | Location loc, ArrayRef<int64_t> shape, Type elementType) { |
332 | 0 | for (int64_t s : shape) { |
333 | 0 | if (s < -1) |
334 | 0 | return emitError(loc, "invalid tensor dimension size"); |
335 | 0 | } |
336 | 0 | return checkTensorElementType(loc, elementType); |
337 | 0 | } |
338 | | |
339 | 0 | ArrayRef<int64_t> RankedTensorType::getShape() const { |
340 | 0 | return getImpl()->getShape(); |
341 | 0 | } |
342 | | |
343 | | //===----------------------------------------------------------------------===// |
344 | | // UnrankedTensorType |
345 | | //===----------------------------------------------------------------------===// |
346 | | |
347 | 0 | UnrankedTensorType UnrankedTensorType::get(Type elementType) { |
348 | 0 | return Base::get(elementType.getContext(), StandardTypes::UnrankedTensor, |
349 | 0 | elementType); |
350 | 0 | } |
351 | | |
352 | | UnrankedTensorType UnrankedTensorType::getChecked(Type elementType, |
353 | 0 | Location location) { |
354 | 0 | return Base::getChecked(location, StandardTypes::UnrankedTensor, elementType); |
355 | 0 | } |
356 | | |
357 | | LogicalResult |
358 | | UnrankedTensorType::verifyConstructionInvariants(Location loc, |
359 | 0 | Type elementType) { |
360 | 0 | return checkTensorElementType(loc, elementType); |
361 | 0 | } |
362 | | |
363 | | //===----------------------------------------------------------------------===// |
364 | | // MemRefType |
365 | | //===----------------------------------------------------------------------===// |
366 | | |
367 | | /// Get or create a new MemRefType based on shape, element type, affine |
368 | | /// map composition, and memory space. Assumes the arguments define a |
369 | | /// well-formed MemRef type. Use getChecked to gracefully handle MemRefType |
370 | | /// construction failures. |
371 | | MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType, |
372 | | ArrayRef<AffineMap> affineMapComposition, |
373 | 0 | unsigned memorySpace) { |
374 | 0 | auto result = getImpl(shape, elementType, affineMapComposition, memorySpace, |
375 | 0 | /*location=*/llvm::None); |
376 | 0 | assert(result && "Failed to construct instance of MemRefType."); |
377 | 0 | return result; |
378 | 0 | } |
379 | | |
380 | | /// Get or create a new MemRefType based on shape, element type, affine |
381 | | /// map composition, and memory space declared at the given location. |
382 | | /// If the location is unknown, the last argument should be an instance of |
383 | | /// UnknownLoc. If the MemRefType defined by the arguments would be |
384 | | /// ill-formed, emits errors (to the handler registered with the context or to |
385 | | /// the error stream) and returns nullptr. |
386 | | MemRefType MemRefType::getChecked(ArrayRef<int64_t> shape, Type elementType, |
387 | | ArrayRef<AffineMap> affineMapComposition, |
388 | 0 | unsigned memorySpace, Location location) { |
389 | 0 | return getImpl(shape, elementType, affineMapComposition, memorySpace, |
390 | 0 | location); |
391 | 0 | } |
392 | | |
393 | | /// Get or create a new MemRefType defined by the arguments. If the resulting |
394 | | /// type would be ill-formed, return nullptr. If the location is provided, |
395 | | /// emit detailed error messages. To emit errors when the location is unknown, |
396 | | /// pass in an instance of UnknownLoc. |
397 | | MemRefType MemRefType::getImpl(ArrayRef<int64_t> shape, Type elementType, |
398 | | ArrayRef<AffineMap> affineMapComposition, |
399 | | unsigned memorySpace, |
400 | 0 | Optional<Location> location) { |
401 | 0 | auto *context = elementType.getContext(); |
402 | 0 |
|
403 | 0 | // Check that memref is formed from allowed types. |
404 | 0 | if (!elementType.isIntOrFloat() && !elementType.isa<VectorType>() && |
405 | 0 | !elementType.isa<ComplexType>()) |
406 | 0 | return emitOptionalError(location, "invalid memref element type"), |
407 | 0 | MemRefType(); |
408 | 0 | |
409 | 0 | for (int64_t s : shape) { |
410 | 0 | // Negative sizes are not allowed except for `-1` that means dynamic size. |
411 | 0 | if (s < -1) |
412 | 0 | return emitOptionalError(location, "invalid memref size"), MemRefType(); |
413 | 0 | } |
414 | 0 |
|
415 | 0 | // Check that the structure of the composition is valid, i.e. that each |
416 | 0 | // subsequent affine map has as many inputs as the previous map has results. |
417 | 0 | // Take the dimensionality of the MemRef for the first map. |
418 | 0 | auto dim = shape.size(); |
419 | 0 | unsigned i = 0; |
420 | 0 | for (const auto &affineMap : affineMapComposition) { |
421 | 0 | if (affineMap.getNumDims() != dim) { |
422 | 0 | if (location) |
423 | 0 | emitError(*location) |
424 | 0 | << "memref affine map dimension mismatch between " |
425 | 0 | << (i == 0 ? Twine("memref rank") : "affine map " + Twine(i)) |
426 | 0 | << " and affine map" << i + 1 << ": " << dim |
427 | 0 | << " != " << affineMap.getNumDims(); |
428 | 0 | return nullptr; |
429 | 0 | } |
430 | 0 |
|
431 | 0 | dim = affineMap.getNumResults(); |
432 | 0 | ++i; |
433 | 0 | } |
434 | 0 |
|
435 | 0 | // Drop identity maps from the composition. |
436 | 0 | // This may lead to the composition becoming empty, which is interpreted as an |
437 | 0 | // implicit identity. |
438 | 0 | SmallVector<AffineMap, 2> cleanedAffineMapComposition; |
439 | 0 | for (const auto &map : affineMapComposition) { |
440 | 0 | if (map.isIdentity()) |
441 | 0 | continue; |
442 | 0 | cleanedAffineMapComposition.push_back(map); |
443 | 0 | } |
444 | 0 |
|
445 | 0 | return Base::get(context, StandardTypes::MemRef, shape, elementType, |
446 | 0 | cleanedAffineMapComposition, memorySpace); |
447 | 0 | } |
448 | | |
449 | 0 | ArrayRef<int64_t> MemRefType::getShape() const { return getImpl()->getShape(); } |
450 | | |
451 | 0 | ArrayRef<AffineMap> MemRefType::getAffineMaps() const { |
452 | 0 | return getImpl()->getAffineMaps(); |
453 | 0 | } |
454 | | |
455 | 0 | unsigned MemRefType::getMemorySpace() const { return getImpl()->memorySpace; } |
456 | | |
457 | | //===----------------------------------------------------------------------===// |
458 | | // UnrankedMemRefType |
459 | | //===----------------------------------------------------------------------===// |
460 | | |
461 | | UnrankedMemRefType UnrankedMemRefType::get(Type elementType, |
462 | 0 | unsigned memorySpace) { |
463 | 0 | return Base::get(elementType.getContext(), StandardTypes::UnrankedMemRef, |
464 | 0 | elementType, memorySpace); |
465 | 0 | } |
466 | | |
467 | | UnrankedMemRefType UnrankedMemRefType::getChecked(Type elementType, |
468 | | unsigned memorySpace, |
469 | 0 | Location location) { |
470 | 0 | return Base::getChecked(location, StandardTypes::UnrankedMemRef, elementType, |
471 | 0 | memorySpace); |
472 | 0 | } |
473 | | |
474 | 0 | unsigned UnrankedMemRefType::getMemorySpace() const { |
475 | 0 | return getImpl()->memorySpace; |
476 | 0 | } |
477 | | |
478 | | LogicalResult |
479 | | UnrankedMemRefType::verifyConstructionInvariants(Location loc, Type elementType, |
480 | 0 | unsigned memorySpace) { |
481 | 0 | // Check that memref is formed from allowed types. |
482 | 0 | if (!elementType.isIntOrFloat() && !elementType.isa<VectorType>() && |
483 | 0 | !elementType.isa<ComplexType>()) |
484 | 0 | return emitError(loc, "invalid memref element type"); |
485 | 0 | return success(); |
486 | 0 | } |
487 | | |
488 | | // Fallback cases for terminal dim/sym/cst that are not part of a binary op |
489 | | // (i.e. single term). Accumulate the AffineExpr into the existing one. |
490 | | static void extractStridesFromTerm(AffineExpr e, |
491 | | AffineExpr multiplicativeFactor, |
492 | | MutableArrayRef<AffineExpr> strides, |
493 | 0 | AffineExpr &offset) { |
494 | 0 | if (auto dim = e.dyn_cast<AffineDimExpr>()) |
495 | 0 | strides[dim.getPosition()] = |
496 | 0 | strides[dim.getPosition()] + multiplicativeFactor; |
497 | 0 | else |
498 | 0 | offset = offset + e * multiplicativeFactor; |
499 | 0 | } |
500 | | |
501 | | /// Takes a single AffineExpr `e` and populates the `strides` array with the |
502 | | /// strides expressions for each dim position. |
503 | | /// The convention is that the strides for dimensions d0, .. dn appear in |
504 | | /// order to make indexing intuitive into the result. |
505 | | static LogicalResult extractStrides(AffineExpr e, |
506 | | AffineExpr multiplicativeFactor, |
507 | | MutableArrayRef<AffineExpr> strides, |
508 | 0 | AffineExpr &offset) { |
509 | 0 | auto bin = e.dyn_cast<AffineBinaryOpExpr>(); |
510 | 0 | if (!bin) { |
511 | 0 | extractStridesFromTerm(e, multiplicativeFactor, strides, offset); |
512 | 0 | return success(); |
513 | 0 | } |
514 | 0 | |
515 | 0 | if (bin.getKind() == AffineExprKind::CeilDiv || |
516 | 0 | bin.getKind() == AffineExprKind::FloorDiv || |
517 | 0 | bin.getKind() == AffineExprKind::Mod) |
518 | 0 | return failure(); |
519 | 0 | |
520 | 0 | if (bin.getKind() == AffineExprKind::Mul) { |
521 | 0 | auto dim = bin.getLHS().dyn_cast<AffineDimExpr>(); |
522 | 0 | if (dim) { |
523 | 0 | strides[dim.getPosition()] = |
524 | 0 | strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor; |
525 | 0 | return success(); |
526 | 0 | } |
527 | 0 | // LHS and RHS may both contain complex expressions of dims. Try one path |
528 | 0 | // and if it fails try the other. This is guaranteed to succeed because |
529 | 0 | // only one path may have a `dim`, otherwise this is not an AffineExpr in |
530 | 0 | // the first place. |
531 | 0 | if (bin.getLHS().isSymbolicOrConstant()) |
532 | 0 | return extractStrides(bin.getRHS(), multiplicativeFactor * bin.getLHS(), |
533 | 0 | strides, offset); |
534 | 0 | return extractStrides(bin.getLHS(), multiplicativeFactor * bin.getRHS(), |
535 | 0 | strides, offset); |
536 | 0 | } |
537 | 0 | |
538 | 0 | if (bin.getKind() == AffineExprKind::Add) { |
539 | 0 | auto res1 = |
540 | 0 | extractStrides(bin.getLHS(), multiplicativeFactor, strides, offset); |
541 | 0 | auto res2 = |
542 | 0 | extractStrides(bin.getRHS(), multiplicativeFactor, strides, offset); |
543 | 0 | return success(succeeded(res1) && succeeded(res2)); |
544 | 0 | } |
545 | 0 |
|
546 | 0 | llvm_unreachable("unexpected binary operation"); |
547 | 0 | } |
548 | | |
549 | | LogicalResult mlir::getStridesAndOffset(MemRefType t, |
550 | | SmallVectorImpl<AffineExpr> &strides, |
551 | 0 | AffineExpr &offset) { |
552 | 0 | auto affineMaps = t.getAffineMaps(); |
553 | 0 | // For now strides are only computed on a single affine map with a single |
554 | 0 | // result (i.e. the closed subset of linearization maps that are compatible |
555 | 0 | // with striding semantics). |
556 | 0 | // TODO(ntv): support more forms on a per-need basis. |
557 | 0 | if (affineMaps.size() > 1) |
558 | 0 | return failure(); |
559 | 0 | if (affineMaps.size() == 1 && affineMaps[0].getNumResults() != 1) |
560 | 0 | return failure(); |
561 | 0 | |
562 | 0 | auto zero = getAffineConstantExpr(0, t.getContext()); |
563 | 0 | auto one = getAffineConstantExpr(1, t.getContext()); |
564 | 0 | offset = zero; |
565 | 0 | strides.assign(t.getRank(), zero); |
566 | 0 |
|
567 | 0 | AffineMap m; |
568 | 0 | if (!affineMaps.empty()) { |
569 | 0 | m = affineMaps.front(); |
570 | 0 | assert(!m.isIdentity() && "unexpected identity map"); |
571 | 0 | } |
572 | 0 |
|
573 | 0 | // Canonical case for empty map. |
574 | 0 | if (!m) { |
575 | 0 | // 0-D corner case, offset is already 0. |
576 | 0 | if (t.getRank() == 0) |
577 | 0 | return success(); |
578 | 0 | auto stridedExpr = |
579 | 0 | makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext()); |
580 | 0 | if (succeeded(extractStrides(stridedExpr, one, strides, offset))) |
581 | 0 | return success(); |
582 | 0 | assert(false && "unexpected failure: extract strides in canonical layout"); |
583 | 0 | } |
584 | 0 |
|
585 | 0 | // Non-canonical case requires more work. |
586 | 0 | auto stridedExpr = |
587 | 0 | simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols()); |
588 | 0 | if (failed(extractStrides(stridedExpr, one, strides, offset))) { |
589 | 0 | offset = AffineExpr(); |
590 | 0 | strides.clear(); |
591 | 0 | return failure(); |
592 | 0 | } |
593 | 0 | |
594 | 0 | // Simplify results to allow folding to constants and simple checks. |
595 | 0 | unsigned numDims = m.getNumDims(); |
596 | 0 | unsigned numSymbols = m.getNumSymbols(); |
597 | 0 | offset = simplifyAffineExpr(offset, numDims, numSymbols); |
598 | 0 | for (auto &stride : strides) |
599 | 0 | stride = simplifyAffineExpr(stride, numDims, numSymbols); |
600 | 0 |
|
601 | 0 | /// In practice, a strided memref must be internally non-aliasing. Test |
602 | 0 | /// against 0 as a proxy. |
603 | 0 | /// TODO(ntv) static cases can have more advanced checks. |
604 | 0 | /// TODO(ntv) dynamic cases would require a way to compare symbolic |
605 | 0 | /// expressions and would probably need an affine set context propagated |
606 | 0 | /// everywhere. |
607 | 0 | if (llvm::any_of(strides, [](AffineExpr e) { |
608 | 0 | return e == getAffineConstantExpr(0, e.getContext()); |
609 | 0 | })) { |
610 | 0 | offset = AffineExpr(); |
611 | 0 | strides.clear(); |
612 | 0 | return failure(); |
613 | 0 | } |
614 | 0 | |
615 | 0 | return success(); |
616 | 0 | } |
617 | | |
618 | | LogicalResult mlir::getStridesAndOffset(MemRefType t, |
619 | | SmallVectorImpl<int64_t> &strides, |
620 | 0 | int64_t &offset) { |
621 | 0 | AffineExpr offsetExpr; |
622 | 0 | SmallVector<AffineExpr, 4> strideExprs; |
623 | 0 | if (failed(::getStridesAndOffset(t, strideExprs, offsetExpr))) |
624 | 0 | return failure(); |
625 | 0 | if (auto cst = offsetExpr.dyn_cast<AffineConstantExpr>()) |
626 | 0 | offset = cst.getValue(); |
627 | 0 | else |
628 | 0 | offset = ShapedType::kDynamicStrideOrOffset; |
629 | 0 | for (auto e : strideExprs) { |
630 | 0 | if (auto c = e.dyn_cast<AffineConstantExpr>()) |
631 | 0 | strides.push_back(c.getValue()); |
632 | 0 | else |
633 | 0 | strides.push_back(ShapedType::kDynamicStrideOrOffset); |
634 | 0 | } |
635 | 0 | return success(); |
636 | 0 | } |
637 | | |
638 | | //===----------------------------------------------------------------------===// |
639 | | /// TupleType |
640 | | //===----------------------------------------------------------------------===// |
641 | | |
642 | | /// Get or create a new TupleType with the provided element types. Assumes the |
643 | | /// arguments define a well-formed type. |
644 | 0 | TupleType TupleType::get(ArrayRef<Type> elementTypes, MLIRContext *context) { |
645 | 0 | return Base::get(context, StandardTypes::Tuple, elementTypes); |
646 | 0 | } |
647 | | |
648 | | /// Return the element types for this tuple. |
649 | 0 | ArrayRef<Type> TupleType::getTypes() const { return getImpl()->getTypes(); } |
650 | | |
651 | | /// Accumulate the types contained in this tuple and tuples nested within it. |
652 | | /// Note that this only flattens nested tuples, not any other container type, |
653 | | /// e.g. a tuple<i32, tensor<i32>, tuple<f32, tuple<i64>>> is flattened to |
654 | | /// (i32, tensor<i32>, f32, i64) |
655 | 0 | void TupleType::getFlattenedTypes(SmallVectorImpl<Type> &types) { |
656 | 0 | for (Type type : getTypes()) { |
657 | 0 | if (auto nestedTuple = type.dyn_cast<TupleType>()) |
658 | 0 | nestedTuple.getFlattenedTypes(types); |
659 | 0 | else |
660 | 0 | types.push_back(type); |
661 | 0 | } |
662 | 0 | } |
663 | | |
664 | | /// Return the number of element types. |
665 | 0 | size_t TupleType::size() const { return getImpl()->size(); } |
666 | | |
667 | | AffineMap mlir::makeStridedLinearLayoutMap(ArrayRef<int64_t> strides, |
668 | | int64_t offset, |
669 | 0 | MLIRContext *context) { |
670 | 0 | AffineExpr expr; |
671 | 0 | unsigned nSymbols = 0; |
672 | 0 |
|
673 | 0 | // AffineExpr for offset. |
674 | 0 | // Static case. |
675 | 0 | if (offset != MemRefType::getDynamicStrideOrOffset()) { |
676 | 0 | auto cst = getAffineConstantExpr(offset, context); |
677 | 0 | expr = cst; |
678 | 0 | } else { |
679 | 0 | // Dynamic case, new symbol for the offset. |
680 | 0 | auto sym = getAffineSymbolExpr(nSymbols++, context); |
681 | 0 | expr = sym; |
682 | 0 | } |
683 | 0 |
|
684 | 0 | // AffineExpr for strides. |
685 | 0 | for (auto en : llvm::enumerate(strides)) { |
686 | 0 | auto dim = en.index(); |
687 | 0 | auto stride = en.value(); |
688 | 0 | assert(stride != 0 && "Invalid stride specification"); |
689 | 0 | auto d = getAffineDimExpr(dim, context); |
690 | 0 | AffineExpr mult; |
691 | 0 | // Static case. |
692 | 0 | if (stride != MemRefType::getDynamicStrideOrOffset()) |
693 | 0 | mult = getAffineConstantExpr(stride, context); |
694 | 0 | else |
695 | 0 | // Dynamic case, new symbol for each new stride. |
696 | 0 | mult = getAffineSymbolExpr(nSymbols++, context); |
697 | 0 | expr = expr + d * mult; |
698 | 0 | } |
699 | 0 |
|
700 | 0 | return AffineMap::get(strides.size(), nSymbols, expr); |
701 | 0 | } |
702 | | |
703 | | /// Return a version of `t` with identity layout if it can be determined |
704 | | /// statically that the layout is the canonical contiguous strided layout. |
705 | | /// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of |
706 | | /// `t` with simplified layout. |
707 | | /// If `t` has multiple layout maps or a multi-result layout, just return `t`. |
708 | 0 | MemRefType mlir::canonicalizeStridedLayout(MemRefType t) { |
709 | 0 | auto affineMaps = t.getAffineMaps(); |
710 | 0 | // Already in canonical form. |
711 | 0 | if (affineMaps.empty()) |
712 | 0 | return t; |
713 | 0 | |
714 | 0 | // Can't reduce to canonical identity form, return in canonical form. |
715 | 0 | if (affineMaps.size() > 1 || affineMaps[0].getNumResults() > 1) |
716 | 0 | return t; |
717 | 0 | |
718 | 0 | // If the canonical strided layout for the sizes of `t` is equal to the |
719 | 0 | // simplified layout of `t` we can just return an empty layout. Otherwise, |
720 | 0 | // just simplify the existing layout. |
721 | 0 | AffineExpr expr = |
722 | 0 | makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext()); |
723 | 0 | auto m = affineMaps[0]; |
724 | 0 | auto simplifiedLayoutExpr = |
725 | 0 | simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols()); |
726 | 0 | if (expr != simplifiedLayoutExpr) |
727 | 0 | return MemRefType::Builder(t).setAffineMaps({AffineMap::get( |
728 | 0 | m.getNumDims(), m.getNumSymbols(), simplifiedLayoutExpr)}); |
729 | 0 | return MemRefType::Builder(t).setAffineMaps({}); |
730 | 0 | } |
731 | | |
732 | | AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes, |
733 | | ArrayRef<AffineExpr> exprs, |
734 | 0 | MLIRContext *context) { |
735 | 0 | AffineExpr expr; |
736 | 0 | bool dynamicPoisonBit = false; |
737 | 0 | unsigned numDims = 0; |
738 | 0 | unsigned nSymbols = 0; |
739 | 0 | // Compute the number of symbols and dimensions of the passed exprs. |
740 | 0 | for (AffineExpr expr : exprs) { |
741 | 0 | expr.walk([&numDims, &nSymbols](AffineExpr d) { |
742 | 0 | if (AffineDimExpr dim = d.dyn_cast<AffineDimExpr>()) |
743 | 0 | numDims = std::max(numDims, dim.getPosition() + 1); |
744 | 0 | else if (AffineSymbolExpr symbol = d.dyn_cast<AffineSymbolExpr>()) |
745 | 0 | nSymbols = std::max(nSymbols, symbol.getPosition() + 1); |
746 | 0 | }); |
747 | 0 | } |
748 | 0 | int64_t runningSize = 1; |
749 | 0 | for (auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) { |
750 | 0 | int64_t size = std::get<1>(en); |
751 | 0 | // Degenerate case, no size =-> no stride |
752 | 0 | if (size == 0) |
753 | 0 | continue; |
754 | 0 | AffineExpr dimExpr = std::get<0>(en); |
755 | 0 | AffineExpr stride = dynamicPoisonBit |
756 | 0 | ? getAffineSymbolExpr(nSymbols++, context) |
757 | 0 | : getAffineConstantExpr(runningSize, context); |
758 | 0 | expr = expr ? expr + dimExpr * stride : dimExpr * stride; |
759 | 0 | if (size > 0) |
760 | 0 | runningSize *= size; |
761 | 0 | else |
762 | 0 | dynamicPoisonBit = true; |
763 | 0 | } |
764 | 0 | return simplifyAffineExpr(expr, numDims, nSymbols); |
765 | 0 | } |
766 | | |
767 | | /// Return a version of `t` with a layout that has all dynamic offset and |
768 | | /// strides. This is used to erase the static layout. |
769 | 0 | MemRefType mlir::eraseStridedLayout(MemRefType t) { |
770 | 0 | auto val = ShapedType::kDynamicStrideOrOffset; |
771 | 0 | return MemRefType::Builder(t).setAffineMaps(makeStridedLinearLayoutMap( |
772 | 0 | SmallVector<int64_t, 4>(t.getRank(), val), val, t.getContext())); |
773 | 0 | } |
774 | | |
775 | | AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes, |
776 | 0 | MLIRContext *context) { |
777 | 0 | SmallVector<AffineExpr, 4> exprs; |
778 | 0 | exprs.reserve(sizes.size()); |
779 | 0 | for (auto dim : llvm::seq<unsigned>(0, sizes.size())) |
780 | 0 | exprs.push_back(getAffineDimExpr(dim, context)); |
781 | 0 | return makeCanonicalStridedLayoutExpr(sizes, exprs, context); |
782 | 0 | } |
783 | | |
784 | | /// Return true if the layout for `t` is compatible with strided semantics. |
785 | 0 | bool mlir::isStrided(MemRefType t) { |
786 | 0 | int64_t offset; |
787 | 0 | SmallVector<int64_t, 4> stridesAndOffset; |
788 | 0 | auto res = getStridesAndOffset(t, stridesAndOffset, offset); |
789 | 0 | return succeeded(res); |
790 | 0 | } |