Coverage Report

Created: 2020-06-26 05:44

/home/arjun/llvm-project/mlir/lib/IR/StandardTypes.cpp
Every executable line in this file has an execution count of 0; the file is entirely uncovered.
//===- StandardTypes.cpp - MLIR Standard Type Classes ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/IR/StandardTypes.h"
#include "TypeDetail.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Diagnostics.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/Twine.h"

using namespace mlir;
using namespace mlir::detail;

//===----------------------------------------------------------------------===//
// Type
//===----------------------------------------------------------------------===//

bool Type::isBF16() { return getKind() == StandardTypes::BF16; }
bool Type::isF16() { return getKind() == StandardTypes::F16; }
bool Type::isF32() { return getKind() == StandardTypes::F32; }
bool Type::isF64() { return getKind() == StandardTypes::F64; }

bool Type::isIndex() { return isa<IndexType>(); }

/// Return true if this is an integer type with the specified width.
bool Type::isInteger(unsigned width) {
  if (auto intTy = dyn_cast<IntegerType>())
    return intTy.getWidth() == width;
  return false;
}

bool Type::isSignlessInteger() {
  if (auto intTy = dyn_cast<IntegerType>())
    return intTy.isSignless();
  return false;
}

bool Type::isSignlessInteger(unsigned width) {
  if (auto intTy = dyn_cast<IntegerType>())
    return intTy.isSignless() && intTy.getWidth() == width;
  return false;
}

bool Type::isSignedInteger() {
  if (auto intTy = dyn_cast<IntegerType>())
    return intTy.isSigned();
  return false;
}

bool Type::isSignedInteger(unsigned width) {
  if (auto intTy = dyn_cast<IntegerType>())
    return intTy.isSigned() && intTy.getWidth() == width;
  return false;
}

bool Type::isUnsignedInteger() {
  if (auto intTy = dyn_cast<IntegerType>())
    return intTy.isUnsigned();
  return false;
}

bool Type::isUnsignedInteger(unsigned width) {
  if (auto intTy = dyn_cast<IntegerType>())
    return intTy.isUnsigned() && intTy.getWidth() == width;
  return false;
}

bool Type::isSignlessIntOrIndex() {
  return isa<IndexType>() || isSignlessInteger();
}

bool Type::isSignlessIntOrIndexOrFloat() {
  return isa<IndexType>() || isSignlessInteger() || isa<FloatType>();
}

bool Type::isSignlessIntOrFloat() {
  return isSignlessInteger() || isa<FloatType>();
}

bool Type::isIntOrIndex() { return isa<IntegerType>() || isIndex(); }

bool Type::isIntOrFloat() { return isa<IntegerType>() || isa<FloatType>(); }

bool Type::isIntOrIndexOrFloat() { return isIntOrFloat() || isIndex(); }

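// [Editor's sketch, not part of the original StandardTypes.cpp.] The
// predicates above are typically chained to classify a type before choosing
// a lowering; the helper below is hypothetical.
static StringRef classifyIntegerSignedness(Type type) {
  if (type.isSignedInteger())
    return "signed";
  if (type.isUnsignedInteger())
    return "unsigned";
  if (type.isSignlessInteger())
    return "signless";
  return "not an integer";
}
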
//===----------------------------------------------------------------------===//
/// ComplexType
//===----------------------------------------------------------------------===//

ComplexType ComplexType::get(Type elementType) {
  return Base::get(elementType.getContext(), StandardTypes::Complex,
                   elementType);
}

ComplexType ComplexType::getChecked(Type elementType, Location location) {
  return Base::getChecked(location, StandardTypes::Complex, elementType);
}

/// Verify the construction of a complex type.
LogicalResult ComplexType::verifyConstructionInvariants(Location loc,
                                                        Type elementType) {
  if (!elementType.isIntOrFloat())
    return emitError(loc, "invalid element type for complex");
  return success();
}

Type ComplexType::getElementType() { return getImpl()->elementType; }

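// [Editor's sketch, not part of the original StandardTypes.cpp.] get()
// asserts on invalid input, while getChecked() reports through the
// diagnostics engine and returns a null type; the helper below is
// hypothetical.
static ComplexType buildComplexOrEmitError(Type elementType, Location loc) {
  // Emits "invalid element type for complex" and returns null if elementType
  // is not an integer or float, per verifyConstructionInvariants above.
  return ComplexType::getChecked(elementType, loc);
}
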
//===----------------------------------------------------------------------===//
// Integer Type
//===----------------------------------------------------------------------===//

// static constexpr must have a definition (until C++17, which introduced
// inline variables).
constexpr unsigned IntegerType::kMaxWidth;

/// Verify the construction of an integer type.
LogicalResult
IntegerType::verifyConstructionInvariants(Location loc, unsigned width,
                                          SignednessSemantics signedness) {
  if (width > IntegerType::kMaxWidth) {
    return emitError(loc) << "integer bitwidth is limited to "
                          << IntegerType::kMaxWidth << " bits";
  }
  return success();
}

unsigned IntegerType::getWidth() const { return getImpl()->getWidth(); }

IntegerType::SignednessSemantics IntegerType::getSignedness() const {
  return getImpl()->getSignedness();
}

//===----------------------------------------------------------------------===//
// Float Type
//===----------------------------------------------------------------------===//

unsigned FloatType::getWidth() {
  switch (getKind()) {
  case StandardTypes::BF16:
  case StandardTypes::F16:
    return 16;
  case StandardTypes::F32:
    return 32;
  case StandardTypes::F64:
    return 64;
  default:
    llvm_unreachable("unexpected type");
  }
}

/// Returns the floating semantics for the given type.
const llvm::fltSemantics &FloatType::getFloatSemantics() {
  if (isBF16())
    // Treat BF16 like a double. This is unfortunate but BF16 fltSemantics is
    // not defined in LLVM.
    // TODO(jpienaar): add BF16 to LLVM? fltSemantics are internal to APFloat.cc
    // else one could add it.
    //  static const fltSemantics semBF16 = {127, -126, 8, 16};
    return APFloat::IEEEdouble();
  if (isF16())
    return APFloat::IEEEhalf();
  if (isF32())
    return APFloat::IEEEsingle();
  if (isF64())
    return APFloat::IEEEdouble();
  llvm_unreachable("non-floating point type used");
}

unsigned Type::getIntOrFloatBitWidth() {
  assert(isIntOrFloat() && "only integers and floats have a bitwidth");
  if (auto intType = dyn_cast<IntegerType>())
    return intType.getWidth();
  return cast<FloatType>().getWidth();
}

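// [Editor's sketch, not part of the original StandardTypes.cpp.] The two
// queries above compose when materializing constants: getFloatSemantics
// selects the APFloat semantics matching the type (with the BF16 caveat
// noted above). The helper name is hypothetical.
static APFloat buildFloatOne(FloatType type) {
  return APFloat(type.getFloatSemantics(), 1);
}
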
//===----------------------------------------------------------------------===//
// ShapedType
//===----------------------------------------------------------------------===//
constexpr int64_t ShapedType::kDynamicSize;
constexpr int64_t ShapedType::kDynamicStrideOrOffset;

Type ShapedType::getElementType() const {
  return static_cast<ImplType *>(impl)->elementType;
}

unsigned ShapedType::getElementTypeBitWidth() const {
  return getElementType().getIntOrFloatBitWidth();
}

int64_t ShapedType::getNumElements() const {
  assert(hasStaticShape() && "cannot get element count of dynamic shaped type");
  auto shape = getShape();
  int64_t num = 1;
  for (auto dim : shape)
    num *= dim;
  return num;
}

int64_t ShapedType::getRank() const { return getShape().size(); }

bool ShapedType::hasRank() const { return !isa<UnrankedTensorType>(); }

int64_t ShapedType::getDimSize(unsigned idx) const {
  assert(idx < getRank() && "invalid index for shaped type");
  return getShape()[idx];
}

bool ShapedType::isDynamicDim(unsigned idx) const {
  assert(idx < getRank() && "invalid index for shaped type");
  return isDynamic(getShape()[idx]);
}

unsigned ShapedType::getDynamicDimIndex(unsigned index) const {
  assert(index < getRank() && "invalid index");
  assert(ShapedType::isDynamic(getDimSize(index)) && "invalid index");
  return llvm::count_if(getShape().take_front(index), ShapedType::isDynamic);
}

/// Get the number of bits required to store a value of the given shaped type.
/// Compute the value recursively since tensors are allowed to have vectors as
/// elements.
int64_t ShapedType::getSizeInBits() const {
  assert(hasStaticShape() &&
         "cannot get the bit size of an aggregate with a dynamic shape");

  auto elementType = getElementType();
  if (elementType.isIntOrFloat())
    return elementType.getIntOrFloatBitWidth() * getNumElements();

  // Tensors can have vectors and other tensors as elements, other shaped types
  // cannot.
  assert(isa<TensorType>() && "unsupported element type");
  assert((elementType.isa<VectorType>() || elementType.isa<TensorType>()) &&
         "unsupported tensor element type");
  return getNumElements() * elementType.cast<ShapedType>().getSizeInBits();
}

ArrayRef<int64_t> ShapedType::getShape() const {
  switch (getKind()) {
  case StandardTypes::Vector:
    return cast<VectorType>().getShape();
  case StandardTypes::RankedTensor:
    return cast<RankedTensorType>().getShape();
  case StandardTypes::MemRef:
    return cast<MemRefType>().getShape();
  default:
    llvm_unreachable("not a ShapedType or not ranked");
  }
}

int64_t ShapedType::getNumDynamicDims() const {
  return llvm::count_if(getShape(), isDynamic);
}

bool ShapedType::hasStaticShape() const {
  return hasRank() && llvm::none_of(getShape(), isDynamic);
}

bool ShapedType::hasStaticShape(ArrayRef<int64_t> shape) const {
  return hasStaticShape() && getShape() == shape;
}

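// [Editor's sketch, not part of the original StandardTypes.cpp.]
// getNumElements and getSizeInBits assert on dynamic shapes, so callers
// usually guard with hasStaticShape; the helper below is hypothetical and
// uses kDynamicSize as its "unknown" sentinel.
static int64_t getNumElementsOrDynamic(ShapedType type) {
  if (!type.hasStaticShape())
    return ShapedType::kDynamicSize;
  return type.getNumElements();
}
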
//===----------------------------------------------------------------------===//
// VectorType
//===----------------------------------------------------------------------===//

VectorType VectorType::get(ArrayRef<int64_t> shape, Type elementType) {
  return Base::get(elementType.getContext(), StandardTypes::Vector, shape,
                   elementType);
}

VectorType VectorType::getChecked(ArrayRef<int64_t> shape, Type elementType,
                                  Location location) {
  return Base::getChecked(location, StandardTypes::Vector, shape, elementType);
}

LogicalResult VectorType::verifyConstructionInvariants(Location loc,
                                                       ArrayRef<int64_t> shape,
                                                       Type elementType) {
  if (shape.empty())
    return emitError(loc, "vector types must have at least one dimension");

  if (!isValidElementType(elementType))
    return emitError(loc, "vector elements must be int or float type");

  if (any_of(shape, [](int64_t i) { return i <= 0; }))
    return emitError(loc, "vector types must have positive constant sizes");

  return success();
}

ArrayRef<int64_t> VectorType::getShape() const { return getImpl()->getShape(); }

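// [Editor's sketch, not part of the original StandardTypes.cpp.] A shape that
// passes every check in verifyConstructionInvariants above: non-empty, all
// sizes positive and static, and an int/float element type.
static VectorType buildVector4xF32(MLIRContext *context) {
  return VectorType::get({4}, FloatType::getF32(context));
}
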
//===----------------------------------------------------------------------===//
// TensorType
//===----------------------------------------------------------------------===//

// Check if "elementType" can be an element type of a tensor. Emits errors at
// the given location and returns failure if the check fails.
static inline LogicalResult checkTensorElementType(Location location,
                                                   Type elementType) {
  if (!TensorType::isValidElementType(elementType))
    return emitError(location, "invalid tensor element type");
  return success();
}

//===----------------------------------------------------------------------===//
// RankedTensorType
//===----------------------------------------------------------------------===//

RankedTensorType RankedTensorType::get(ArrayRef<int64_t> shape,
                                       Type elementType) {
  return Base::get(elementType.getContext(), StandardTypes::RankedTensor, shape,
                   elementType);
}

RankedTensorType RankedTensorType::getChecked(ArrayRef<int64_t> shape,
                                              Type elementType,
                                              Location location) {
  return Base::getChecked(location, StandardTypes::RankedTensor, shape,
                          elementType);
}

LogicalResult RankedTensorType::verifyConstructionInvariants(
    Location loc, ArrayRef<int64_t> shape, Type elementType) {
  for (int64_t s : shape) {
    if (s < -1)
      return emitError(loc, "invalid tensor dimension size");
  }
  return checkTensorElementType(loc, elementType);
}

ArrayRef<int64_t> RankedTensorType::getShape() const {
  return getImpl()->getShape();
}

//===----------------------------------------------------------------------===//
// UnrankedTensorType
//===----------------------------------------------------------------------===//

UnrankedTensorType UnrankedTensorType::get(Type elementType) {
  return Base::get(elementType.getContext(), StandardTypes::UnrankedTensor,
                   elementType);
}

UnrankedTensorType UnrankedTensorType::getChecked(Type elementType,
                                                  Location location) {
  return Base::getChecked(location, StandardTypes::UnrankedTensor, elementType);
}

LogicalResult
UnrankedTensorType::verifyConstructionInvariants(Location loc,
                                                 Type elementType) {
  return checkTensorElementType(loc, elementType);
}

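// [Editor's sketch, not part of the original StandardTypes.cpp.] Ranked and
// unranked tensors share checkTensorElementType above; only the ranked form
// carries a shape, where -1 encodes a dynamic dimension.
static TensorType buildExampleTensor(MLIRContext *context) {
  auto f32 = FloatType::getF32(context);
  // tensor<2x?xf32>: the second dimension is dynamic.
  (void)RankedTensorType::get({2, -1}, f32);
  // tensor<*xf32>: no rank at all.
  return UnrankedTensorType::get(f32);
}
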
//===----------------------------------------------------------------------===//
// MemRefType
//===----------------------------------------------------------------------===//

/// Get or create a new MemRefType based on shape, element type, affine
/// map composition, and memory space.  Assumes the arguments define a
/// well-formed MemRef type.  Use getChecked to gracefully handle MemRefType
/// construction failures.
MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
                           ArrayRef<AffineMap> affineMapComposition,
                           unsigned memorySpace) {
  auto result = getImpl(shape, elementType, affineMapComposition, memorySpace,
                        /*location=*/llvm::None);
  assert(result && "Failed to construct instance of MemRefType.");
  return result;
}

/// Get or create a new MemRefType based on shape, element type, affine
/// map composition, and memory space declared at the given location.
/// If the location is unknown, the last argument should be an instance of
/// UnknownLoc.  If the MemRefType defined by the arguments would be
/// ill-formed, emits errors (to the handler registered with the context or to
/// the error stream) and returns nullptr.
MemRefType MemRefType::getChecked(ArrayRef<int64_t> shape, Type elementType,
                                  ArrayRef<AffineMap> affineMapComposition,
                                  unsigned memorySpace, Location location) {
  return getImpl(shape, elementType, affineMapComposition, memorySpace,
                 location);
}

/// Get or create a new MemRefType defined by the arguments.  If the resulting
/// type would be ill-formed, return nullptr.  If the location is provided,
/// emit detailed error messages.  To emit errors when the location is unknown,
/// pass in an instance of UnknownLoc.
MemRefType MemRefType::getImpl(ArrayRef<int64_t> shape, Type elementType,
                               ArrayRef<AffineMap> affineMapComposition,
                               unsigned memorySpace,
                               Optional<Location> location) {
  auto *context = elementType.getContext();

  // Check that memref is formed from allowed types.
  if (!elementType.isIntOrFloat() && !elementType.isa<VectorType>() &&
      !elementType.isa<ComplexType>())
    return emitOptionalError(location, "invalid memref element type"),
           MemRefType();

  for (int64_t s : shape) {
    // Negative sizes are not allowed except for `-1`, which means a dynamic
    // size.
    if (s < -1)
      return emitOptionalError(location, "invalid memref size"), MemRefType();
  }

  // Check that the structure of the composition is valid, i.e. that each
  // subsequent affine map has as many inputs as the previous map has results.
  // Take the dimensionality of the MemRef for the first map.
  auto dim = shape.size();
  unsigned i = 0;
  for (const auto &affineMap : affineMapComposition) {
    if (affineMap.getNumDims() != dim) {
      if (location)
        emitError(*location)
            << "memref affine map dimension mismatch between "
            << (i == 0 ? Twine("memref rank") : "affine map " + Twine(i))
            << " and affine map" << i + 1 << ": " << dim
            << " != " << affineMap.getNumDims();
      return nullptr;
    }

    dim = affineMap.getNumResults();
    ++i;
  }

  // Drop identity maps from the composition.
  // This may lead to the composition becoming empty, which is interpreted as an
  // implicit identity.
  SmallVector<AffineMap, 2> cleanedAffineMapComposition;
  for (const auto &map : affineMapComposition) {
    if (map.isIdentity())
      continue;
    cleanedAffineMapComposition.push_back(map);
  }

  return Base::get(context, StandardTypes::MemRef, shape, elementType,
                   cleanedAffineMapComposition, memorySpace);
}

ArrayRef<int64_t> MemRefType::getShape() const { return getImpl()->getShape(); }

ArrayRef<AffineMap> MemRefType::getAffineMaps() const {
  return getImpl()->getAffineMaps();
}

unsigned MemRefType::getMemorySpace() const { return getImpl()->memorySpace; }

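// [Editor's sketch, not part of the original StandardTypes.cpp.] Building a
// memref with an explicit strided layout through the builders above; since
// the composition is a single map, the dimension-matching check is trivial.
static MemRefType buildStridedMemRef(MLIRContext *context) {
  // Row-major 4x8 layout: (d0, d1) -> (d0 * 8 + d1).
  AffineMap layout =
      makeStridedLinearLayoutMap({8, 1}, /*offset=*/0, context);
  return MemRefType::get({4, 8}, FloatType::getF32(context), layout,
                         /*memorySpace=*/0);
}
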
//===----------------------------------------------------------------------===//
// UnrankedMemRefType
//===----------------------------------------------------------------------===//

UnrankedMemRefType UnrankedMemRefType::get(Type elementType,
                                           unsigned memorySpace) {
  return Base::get(elementType.getContext(), StandardTypes::UnrankedMemRef,
                   elementType, memorySpace);
}

UnrankedMemRefType UnrankedMemRefType::getChecked(Type elementType,
                                                  unsigned memorySpace,
                                                  Location location) {
  return Base::getChecked(location, StandardTypes::UnrankedMemRef, elementType,
                          memorySpace);
}

unsigned UnrankedMemRefType::getMemorySpace() const {
  return getImpl()->memorySpace;
}

LogicalResult
UnrankedMemRefType::verifyConstructionInvariants(Location loc, Type elementType,
                                                 unsigned memorySpace) {
  // Check that memref is formed from allowed types.
  if (!elementType.isIntOrFloat() && !elementType.isa<VectorType>() &&
      !elementType.isa<ComplexType>())
    return emitError(loc, "invalid memref element type");
  return success();
}

// Fallback cases for terminal dim/sym/cst that are not part of a binary op
// (i.e. a single term). Accumulate the AffineExpr into the existing one.
static void extractStridesFromTerm(AffineExpr e,
                                   AffineExpr multiplicativeFactor,
                                   MutableArrayRef<AffineExpr> strides,
                                   AffineExpr &offset) {
  if (auto dim = e.dyn_cast<AffineDimExpr>())
    strides[dim.getPosition()] =
        strides[dim.getPosition()] + multiplicativeFactor;
  else
    offset = offset + e * multiplicativeFactor;
}

/// Takes a single AffineExpr `e` and populates the `strides` array with the
/// stride expressions for each dim position.
/// The convention is that the strides for dimensions d0, .., dn appear in
/// order, to make indexing into the result intuitive.
static LogicalResult extractStrides(AffineExpr e,
                                    AffineExpr multiplicativeFactor,
                                    MutableArrayRef<AffineExpr> strides,
                                    AffineExpr &offset) {
  auto bin = e.dyn_cast<AffineBinaryOpExpr>();
  if (!bin) {
    extractStridesFromTerm(e, multiplicativeFactor, strides, offset);
    return success();
  }

  if (bin.getKind() == AffineExprKind::CeilDiv ||
      bin.getKind() == AffineExprKind::FloorDiv ||
      bin.getKind() == AffineExprKind::Mod)
    return failure();

  if (bin.getKind() == AffineExprKind::Mul) {
    auto dim = bin.getLHS().dyn_cast<AffineDimExpr>();
    if (dim) {
      strides[dim.getPosition()] =
          strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor;
      return success();
    }
    // LHS and RHS may both contain complex expressions of dims. Try one path
    // and if it fails try the other. This is guaranteed to succeed because
    // only one path may have a `dim`, otherwise this is not an AffineExpr in
    // the first place.
    if (bin.getLHS().isSymbolicOrConstant())
      return extractStrides(bin.getRHS(), multiplicativeFactor * bin.getLHS(),
                            strides, offset);
    return extractStrides(bin.getLHS(), multiplicativeFactor * bin.getRHS(),
                          strides, offset);
  }

  if (bin.getKind() == AffineExprKind::Add) {
    auto res1 =
        extractStrides(bin.getLHS(), multiplicativeFactor, strides, offset);
    auto res2 =
        extractStrides(bin.getRHS(), multiplicativeFactor, strides, offset);
    return success(succeeded(res1) && succeeded(res2));
  }

  llvm_unreachable("unexpected binary operation");
}

LogicalResult mlir::getStridesAndOffset(MemRefType t,
                                        SmallVectorImpl<AffineExpr> &strides,
                                        AffineExpr &offset) {
  auto affineMaps = t.getAffineMaps();
  // For now strides are only computed on a single affine map with a single
  // result (i.e. the closed subset of linearization maps that are compatible
  // with striding semantics).
  // TODO(ntv): support more forms on a per-need basis.
  if (affineMaps.size() > 1)
    return failure();
  if (affineMaps.size() == 1 && affineMaps[0].getNumResults() != 1)
    return failure();

  auto zero = getAffineConstantExpr(0, t.getContext());
  auto one = getAffineConstantExpr(1, t.getContext());
  offset = zero;
  strides.assign(t.getRank(), zero);

  AffineMap m;
  if (!affineMaps.empty()) {
    m = affineMaps.front();
    assert(!m.isIdentity() && "unexpected identity map");
  }

  // Canonical case for empty map.
  if (!m) {
    // 0-D corner case, offset is already 0.
    if (t.getRank() == 0)
      return success();
    auto stridedExpr =
        makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
    if (succeeded(extractStrides(stridedExpr, one, strides, offset)))
      return success();
    assert(false && "unexpected failure: extract strides in canonical layout");
  }

  // Non-canonical case requires more work.
  auto stridedExpr =
      simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
  if (failed(extractStrides(stridedExpr, one, strides, offset))) {
    offset = AffineExpr();
    strides.clear();
    return failure();
  }

  // Simplify results to allow folding to constants and simple checks.
  unsigned numDims = m.getNumDims();
  unsigned numSymbols = m.getNumSymbols();
  offset = simplifyAffineExpr(offset, numDims, numSymbols);
  for (auto &stride : strides)
    stride = simplifyAffineExpr(stride, numDims, numSymbols);

  /// In practice, a strided memref must be internally non-aliasing. Test
  /// against 0 as a proxy.
  /// TODO(ntv) static cases can have more advanced checks.
  /// TODO(ntv) dynamic cases would require a way to compare symbolic
  /// expressions and would probably need an affine set context propagated
  /// everywhere.
  if (llvm::any_of(strides, [](AffineExpr e) {
        return e == getAffineConstantExpr(0, e.getContext());
      })) {
    offset = AffineExpr();
    strides.clear();
    return failure();
  }

  return success();
}

LogicalResult mlir::getStridesAndOffset(MemRefType t,
                                        SmallVectorImpl<int64_t> &strides,
                                        int64_t &offset) {
  AffineExpr offsetExpr;
  SmallVector<AffineExpr, 4> strideExprs;
  if (failed(::getStridesAndOffset(t, strideExprs, offsetExpr)))
    return failure();
  if (auto cst = offsetExpr.dyn_cast<AffineConstantExpr>())
    offset = cst.getValue();
  else
    offset = ShapedType::kDynamicStrideOrOffset;
  for (auto e : strideExprs) {
    if (auto c = e.dyn_cast<AffineConstantExpr>())
      strides.push_back(c.getValue());
    else
      strides.push_back(ShapedType::kDynamicStrideOrOffset);
  }
  return success();
}

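// [Editor's sketch, not part of the original StandardTypes.cpp.] Typical use
// of the integer overload above: dynamic strides and offsets come back as
// kDynamicStrideOrOffset instead of failing the whole query, so a fully
// static layout can be detected as follows (hypothetical helper).
static bool hasFullyStaticLayout(MemRefType t) {
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  if (failed(getStridesAndOffset(t, strides, offset)))
    return false; // Not a strided layout at all.
  return offset != ShapedType::kDynamicStrideOrOffset &&
         llvm::none_of(strides, [](int64_t s) {
           return s == ShapedType::kDynamicStrideOrOffset;
         });
}
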
//===----------------------------------------------------------------------===//
/// TupleType
//===----------------------------------------------------------------------===//

/// Get or create a new TupleType with the provided element types. Assumes the
/// arguments define a well-formed type.
TupleType TupleType::get(ArrayRef<Type> elementTypes, MLIRContext *context) {
  return Base::get(context, StandardTypes::Tuple, elementTypes);
}

/// Return the element types for this tuple.
ArrayRef<Type> TupleType::getTypes() const { return getImpl()->getTypes(); }

/// Accumulate the types contained in this tuple and tuples nested within it.
/// Note that this only flattens nested tuples, not any other container type,
/// e.g. a tuple<i32, tensor<i32>, tuple<f32, tuple<i64>>> is flattened to
/// (i32, tensor<i32>, f32, i64).
void TupleType::getFlattenedTypes(SmallVectorImpl<Type> &types) {
  for (Type type : getTypes()) {
    if (auto nestedTuple = type.dyn_cast<TupleType>())
      nestedTuple.getFlattenedTypes(types);
    else
      types.push_back(type);
  }
}

/// Return the number of element types.
size_t TupleType::size() const { return getImpl()->size(); }

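// [Editor's sketch, not part of the original StandardTypes.cpp.] The
// flattening documented above inlines nested tuples but leaves other
// containers opaque, e.g. tuple<i32, tuple<f32, i32>> -> (i32, f32, i32).
// Assumes the 2020-era builder signatures with a trailing MLIRContext.
static void exampleFlattenTuple(MLIRContext *context) {
  Type i32 = IntegerType::get(32, context);
  Type f32 = FloatType::getF32(context);
  TupleType nested = TupleType::get({f32, i32}, context);
  TupleType outer = TupleType::get({i32, nested}, context);
  SmallVector<Type, 4> flat;
  outer.getFlattenedTypes(flat);
  assert(flat.size() == 3 && "expected three leaf types");
  (void)flat;
}
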
AffineMap mlir::makeStridedLinearLayoutMap(ArrayRef<int64_t> strides,
                                           int64_t offset,
                                           MLIRContext *context) {
  AffineExpr expr;
  unsigned nSymbols = 0;

  // AffineExpr for offset.
  // Static case.
  if (offset != MemRefType::getDynamicStrideOrOffset()) {
    auto cst = getAffineConstantExpr(offset, context);
    expr = cst;
  } else {
    // Dynamic case, new symbol for the offset.
    auto sym = getAffineSymbolExpr(nSymbols++, context);
    expr = sym;
  }

  // AffineExpr for strides.
  for (auto en : llvm::enumerate(strides)) {
    auto dim = en.index();
    auto stride = en.value();
    assert(stride != 0 && "Invalid stride specification");
    auto d = getAffineDimExpr(dim, context);
    AffineExpr mult;
    // Static case.
    if (stride != MemRefType::getDynamicStrideOrOffset())
      mult = getAffineConstantExpr(stride, context);
    else
      // Dynamic case, new symbol for each new stride.
      mult = getAffineSymbolExpr(nSymbols++, context);
    expr = expr + d * mult;
  }

  return AffineMap::get(strides.size(), nSymbols, expr);
}

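// [Editor's sketch, not part of the original StandardTypes.cpp.] Dynamic
// entries each become a fresh symbol, offset first: strides {?, 1} with a
// dynamic offset produce (d0, d1)[s0, s1] -> (s0 + d0 * s1 + d1).
static AffineMap buildDynamicStridedMap(MLIRContext *context) {
  int64_t dyn = MemRefType::getDynamicStrideOrOffset();
  return makeStridedLinearLayoutMap({dyn, 1}, /*offset=*/dyn, context);
}
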
/// Return a version of `t` with identity layout if it can be determined
/// statically that the layout is the canonical contiguous strided layout.
/// Otherwise pass `t`'s layout into `simplifyAffineExpr` and return a copy of
/// `t` with simplified layout.
/// If `t` has multiple layout maps or a multi-result layout, just return `t`.
MemRefType mlir::canonicalizeStridedLayout(MemRefType t) {
  auto affineMaps = t.getAffineMaps();
  // Already in canonical form.
  if (affineMaps.empty())
    return t;

  // Can't reduce to the canonical identity form; return `t` unchanged.
  if (affineMaps.size() > 1 || affineMaps[0].getNumResults() > 1)
    return t;

  // If the canonical strided layout for the sizes of `t` is equal to the
  // simplified layout of `t` we can just return an empty layout. Otherwise,
  // just simplify the existing layout.
  AffineExpr expr =
      makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
  auto m = affineMaps[0];
  auto simplifiedLayoutExpr =
      simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
  if (expr != simplifiedLayoutExpr)
    return MemRefType::Builder(t).setAffineMaps({AffineMap::get(
        m.getNumDims(), m.getNumSymbols(), simplifiedLayoutExpr)});
  return MemRefType::Builder(t).setAffineMaps({});
}

AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
                                                ArrayRef<AffineExpr> exprs,
                                                MLIRContext *context) {
  AffineExpr expr;
  bool dynamicPoisonBit = false;
  unsigned numDims = 0;
  unsigned nSymbols = 0;
  // Compute the number of symbols and dimensions of the passed exprs.
  for (AffineExpr expr : exprs) {
    expr.walk([&numDims, &nSymbols](AffineExpr d) {
      if (AffineDimExpr dim = d.dyn_cast<AffineDimExpr>())
        numDims = std::max(numDims, dim.getPosition() + 1);
      else if (AffineSymbolExpr symbol = d.dyn_cast<AffineSymbolExpr>())
        nSymbols = std::max(nSymbols, symbol.getPosition() + 1);
    });
  }
  int64_t runningSize = 1;
  for (auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) {
    int64_t size = std::get<1>(en);
    // Degenerate case: no size -> no stride.
    if (size == 0)
      continue;
    AffineExpr dimExpr = std::get<0>(en);
    AffineExpr stride = dynamicPoisonBit
                            ? getAffineSymbolExpr(nSymbols++, context)
                            : getAffineConstantExpr(runningSize, context);
    expr = expr ? expr + dimExpr * stride : dimExpr * stride;
    if (size > 0)
      runningSize *= size;
    else
      dynamicPoisonBit = true;
  }
  return simplifyAffineExpr(expr, numDims, nSymbols);
}

/// Return a version of `t` whose layout has a fully dynamic offset and
/// strides. This is used to erase the static layout.
MemRefType mlir::eraseStridedLayout(MemRefType t) {
  auto val = ShapedType::kDynamicStrideOrOffset;
  return MemRefType::Builder(t).setAffineMaps(makeStridedLinearLayoutMap(
      SmallVector<int64_t, 4>(t.getRank(), val), val, t.getContext()));
}

AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
                                                MLIRContext *context) {
  SmallVector<AffineExpr, 4> exprs;
  exprs.reserve(sizes.size());
  for (auto dim : llvm::seq<unsigned>(0, sizes.size()))
    exprs.push_back(getAffineDimExpr(dim, context));
  return makeCanonicalStridedLayoutExpr(sizes, exprs, context);
}

/// Return true if the layout for `t` is compatible with strided semantics.
bool mlir::isStrided(MemRefType t) {
  int64_t offset;
  SmallVector<int64_t, 4> stridesAndOffset;
  auto res = getStridesAndOffset(t, stridesAndOffset, offset);
  return succeeded(res);
}
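
// [Editor's sketch, not part of the original StandardTypes.cpp.] An
// identity-layout memref is trivially strided: the empty composition is
// canonicalized through makeCanonicalStridedLayoutExpr above.
static bool exampleIsStrided(MLIRContext *context) {
  auto t = MemRefType::get({4, 8}, FloatType::getF32(context), {}, 0);
  return isStrided(t); // true: canonical contiguous layout.
}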