Coverage Report

Created: 2020-06-26 05:44

/home/arjun/llvm-project/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
Line
Count
Source (jump to first uncovered line)
1
//===- AffineOps.cpp - MLIR Affine Operations -----------------------------===//
2
//
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
// See https://llvm.org/LICENSE.txt for license information.
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
//
7
//===----------------------------------------------------------------------===//
8
9
#include "mlir/Dialect/Affine/IR/AffineOps.h"
10
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
11
#include "mlir/Dialect/StandardOps/IR/Ops.h"
12
#include "mlir/IR/Function.h"
13
#include "mlir/IR/IntegerSet.h"
14
#include "mlir/IR/Matchers.h"
15
#include "mlir/IR/OpImplementation.h"
16
#include "mlir/IR/PatternMatch.h"
17
#include "mlir/Transforms/InliningUtils.h"
18
#include "llvm/ADT/SetVector.h"
19
#include "llvm/ADT/SmallBitVector.h"
20
#include "llvm/Support/Debug.h"
21
22
using namespace mlir;
23
using llvm::dbgs;
24
25
#define DEBUG_TYPE "affine-analysis"
26
27
//===----------------------------------------------------------------------===//
28
// AffineDialect Interfaces
29
//===----------------------------------------------------------------------===//
30
31
namespace {
32
/// This class defines the interface for handling inlining with affine
33
/// operations.
34
struct AffineInlinerInterface : public DialectInlinerInterface {
35
  using DialectInlinerInterface::DialectInlinerInterface;
36
37
  //===--------------------------------------------------------------------===//
38
  // Analysis Hooks
39
  //===--------------------------------------------------------------------===//
40
41
  /// Returns true if the given region 'src' can be inlined into the region
42
  /// 'dest' that is attached to an operation registered to the current dialect.
43
  bool isLegalToInline(Region *dest, Region *src,
44
0
                       BlockAndValueMapping &valueMapping) const final {
45
0
    // Conservatively don't allow inlining into affine structures.
46
0
    return false;
47
0
  }
48
49
  /// Returns true if the given operation 'op', that is registered to this
50
  /// dialect, can be inlined into the given region, false otherwise.
51
  bool isLegalToInline(Operation *op, Region *region,
52
0
                       BlockAndValueMapping &valueMapping) const final {
53
0
    // Always allow inlining affine operations into the top-level region of a
54
0
    // function. There are some edge cases when inlining *into* affine
55
0
    // structures, but that is handled in the other 'isLegalToInline' hook
56
0
    // above.
57
0
    // TODO: We should be able to inline into other regions than functions.
58
0
    return isa<FuncOp>(region->getParentOp());
59
0
  }
60
61
  /// Affine regions should be analyzed recursively.
62
0
  bool shouldAnalyzeRecursively(Operation *op) const final { return true; }
63
};
64
} // end anonymous namespace
65
66
//===----------------------------------------------------------------------===//
67
// AffineDialect
68
//===----------------------------------------------------------------------===//
69
70
/// Constructs the affine dialect: registers its operations and attaches the
/// AffineInlinerInterface.
AffineDialect::AffineDialect(MLIRContext *context)
    : Dialect(getDialectNamespace(), context) {
  // AffineDmaStartOp and AffineDmaWaitOp are named explicitly; the remaining
  // operations are spliced into this template argument list by including the
  // generated op list (AffineOps.cpp.inc) with GET_OP_LIST defined.
  addOperations<AffineDmaStartOp, AffineDmaWaitOp,
#define GET_OP_LIST
#include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"
                >();
  // Inlining policy for affine operations and regions.
  addInterfaces<AffineInlinerInterface>();
}
78
79
/// Materialize a single constant operation from a given attribute value with
80
/// the desired resultant type.
81
Operation *AffineDialect::materializeConstant(OpBuilder &builder,
82
                                              Attribute value, Type type,
83
0
                                              Location loc) {
84
0
  return builder.create<ConstantOp>(loc, type, value);
85
0
}
86
87
/// A utility function to check if a value is defined at the top level of an
88
/// op with trait `AffineScope`. If the value is defined in an unlinked region,
89
/// conservatively assume it is not top-level. A value of index type defined at
90
/// the top level is always a valid symbol.
91
0
bool mlir::isTopLevelValue(Value value) {
92
0
  if (auto arg = value.dyn_cast<BlockArgument>()) {
93
0
    // The block owning the argument may be unlinked, e.g. when the surrounding
94
0
    // region has not yet been attached to an Op, at which point the parent Op
95
0
    // is null.
96
0
    Operation *parentOp = arg.getOwner()->getParentOp();
97
0
    return parentOp && parentOp->hasTrait<OpTrait::AffineScope>();
98
0
  }
99
0
  // The defining Op may live in an unlinked block so its parent Op may be null.
100
0
  Operation *parentOp = value.getDefiningOp()->getParentOp();
101
0
  return parentOp && parentOp->hasTrait<OpTrait::AffineScope>();
102
0
}
103
104
/// A utility function to check if a value is defined at the top level of
105
/// `region` or is an argument of `region`. A value of index type defined at the
106
/// top level of a `AffineScope` region is always a valid symbol for all
107
/// uses in that region.
108
0
static bool isTopLevelValue(Value value, Region *region) {
109
0
  if (auto arg = value.dyn_cast<BlockArgument>())
110
0
    return arg.getParentRegion() == region;
111
0
  return value.getDefiningOp()->getParentRegion() == region;
112
0
}
113
114
/// Returns the closest region enclosing `op` that is held by an operation with
115
/// trait `AffineScope`.
116
//  TODO: getAffineScope should be publicly exposed for affine passes/utilities.
117
0
static Region *getAffineScope(Operation *op) {
118
0
  auto *curOp = op;
119
0
  while (auto *parentOp = curOp->getParentOp()) {
120
0
    if (parentOp->hasTrait<OpTrait::AffineScope>())
121
0
      return curOp->getParentRegion();
122
0
    curOp = parentOp;
123
0
  }
124
0
  llvm_unreachable("op doesn't have an enclosing polyhedral scope");
125
0
}
126
127
// A Value can be used as a dimension id iff it meets one of the following
128
// conditions:
129
// *) It is valid as a symbol.
130
// *) It is an induction variable.
131
// *) It is the result of affine apply operation with dimension id arguments.
132
0
bool mlir::isValidDim(Value value) {
133
0
  // The value must be an index type.
134
0
  if (!value.getType().isIndex())
135
0
    return false;
136
0
137
0
  if (auto *defOp = value.getDefiningOp())
138
0
    return isValidDim(value, getAffineScope(defOp));
139
0
140
0
  // This value has to be a block argument for an op that has the
141
0
  // `AffineScope` trait or for an affine.for or affine.parallel.
142
0
  auto *parentOp = value.cast<BlockArgument>().getOwner()->getParentOp();
143
0
  return parentOp &&
144
0
         (parentOp->hasTrait<OpTrait::AffineScope>() ||
145
0
          isa<AffineForOp>(parentOp) || isa<AffineParallelOp>(parentOp));
146
0
}
147
148
// Value can be used as a dimension id iff it meets one of the following
149
// conditions:
150
// *) It is valid as a symbol.
151
// *) It is an induction variable.
152
// *) It is the result of an affine apply operation with dimension id operands.
153
0
bool mlir::isValidDim(Value value, Region *region) {
154
0
  // The value must be an index type.
155
0
  if (!value.getType().isIndex())
156
0
    return false;
157
0
158
0
  // All valid symbols are okay.
159
0
  if (isValidSymbol(value, region))
160
0
    return true;
161
0
162
0
  auto *op = value.getDefiningOp();
163
0
  if (!op) {
164
0
    // This value has to be a block argument for an affine.for or an
165
0
    // affine.parallel.
166
0
    auto *parentOp = value.cast<BlockArgument>().getOwner()->getParentOp();
167
0
    return isa<AffineForOp>(parentOp) || isa<AffineParallelOp>(parentOp);
168
0
  }
169
0
170
0
  // Affine apply operation is ok if all of its operands are ok.
171
0
  if (auto applyOp = dyn_cast<AffineApplyOp>(op))
172
0
    return applyOp.isValidDim(region);
173
0
  // The dim op is okay if its operand memref/tensor is defined at the top
174
0
  // level.
175
0
  if (auto dimOp = dyn_cast<DimOp>(op))
176
0
    return isTopLevelValue(dimOp.getOperand());
177
0
  return false;
178
0
}
179
180
/// Returns true if the 'index' dimension of the `memref` defined by
181
/// `memrefDefOp` is a statically  shaped one or defined using a valid symbol
182
/// for `region`.
183
template <typename AnyMemRefDefOp>
184
static bool isMemRefSizeValidSymbol(AnyMemRefDefOp memrefDefOp, unsigned index,
185
0
                                    Region *region) {
186
0
  auto memRefType = memrefDefOp.getType();
187
0
  // Statically shaped.
188
0
  if (!memRefType.isDynamicDim(index))
189
0
    return true;
190
0
  // Get the position of the dimension among dynamic dimensions;
191
0
  unsigned dynamicDimPos = memRefType.getDynamicDimIndex(index);
192
0
  return isValidSymbol(*(memrefDefOp.getDynamicSizes().begin() + dynamicDimPos),
193
0
                       region);
194
0
}
Unexecuted instantiation: AffineOps.cpp:_ZL23isMemRefSizeValidSymbolIN4mlir6ViewOpEEbT_jPNS0_6RegionE
Unexecuted instantiation: AffineOps.cpp:_ZL23isMemRefSizeValidSymbolIN4mlir9SubViewOpEEbT_jPNS0_6RegionE
Unexecuted instantiation: AffineOps.cpp:_ZL23isMemRefSizeValidSymbolIN4mlir7AllocOpEEbT_jPNS0_6RegionE
195
196
/// Returns true if the result of the dim op is a valid symbol for `region`.
197
0
static bool isDimOpValidSymbol(DimOp dimOp, Region *region) {
198
0
  // The dim op is okay if its operand memref/tensor is defined at the top
199
0
  // level.
200
0
  if (isTopLevelValue(dimOp.getOperand()))
201
0
    return true;
202
0
203
0
  // The dim op is also okay if its operand memref/tensor is a view/subview
204
0
  // whose corresponding size is a valid symbol.
205
0
  unsigned index = dimOp.getIndex();
206
0
  if (auto viewOp = dyn_cast<ViewOp>(dimOp.getOperand().getDefiningOp()))
207
0
    return isMemRefSizeValidSymbol<ViewOp>(viewOp, index, region);
208
0
  if (auto subViewOp = dyn_cast<SubViewOp>(dimOp.getOperand().getDefiningOp()))
209
0
    return isMemRefSizeValidSymbol<SubViewOp>(subViewOp, index, region);
210
0
  if (auto allocOp = dyn_cast<AllocOp>(dimOp.getOperand().getDefiningOp()))
211
0
    return isMemRefSizeValidSymbol<AllocOp>(allocOp, index, region);
212
0
  return false;
213
0
}
214
215
// A value can be used as a symbol (at all its use sites) iff it meets one of
216
// the following conditions:
217
// *) It is a constant.
218
// *) Its defining op or block arg appearance is immediately enclosed by an op
219
//    with `AffineScope` trait.
220
// *) It is the result of an affine.apply operation with symbol operands.
221
// *) It is a result of the dim op on a memref whose corresponding size is a
222
//    valid symbol.
223
0
bool mlir::isValidSymbol(Value value) {
224
0
  // The value must be an index type.
225
0
  if (!value.getType().isIndex())
226
0
    return false;
227
0
228
0
  // Check that the value is a top level value.
229
0
  if (isTopLevelValue(value))
230
0
    return true;
231
0
232
0
  if (auto *defOp = value.getDefiningOp())
233
0
    return isValidSymbol(value, getAffineScope(defOp));
234
0
235
0
  return false;
236
0
}
237
238
// A value can be used as a symbol for `region` iff it meets onf of the the
239
// following conditions:
240
// *) It is a constant.
241
// *) It is defined at the top level of 'region' or is its argument.
242
// *) It dominates `region`'s parent op.
243
// *) It is the result of an affine apply operation with symbol arguments.
244
// *) It is a result of the dim op on a memref whose corresponding size is
245
//    a valid symbol.
246
0
bool mlir::isValidSymbol(Value value, Region *region) {
247
0
  // The value must be an index type.
248
0
  if (!value.getType().isIndex())
249
0
    return false;
250
0
251
0
  // A top-level value is a valid symbol.
252
0
  if (::isTopLevelValue(value, region))
253
0
    return true;
254
0
255
0
  auto *defOp = value.getDefiningOp();
256
0
  if (!defOp) {
257
0
    // A block argument that is not a top-level value is a valid symbol if it
258
0
    // dominates region's parent op.
259
0
    if (!region->getParentOp()->isKnownIsolatedFromAbove())
260
0
      if (auto *parentOpRegion = region->getParentOp()->getParentRegion())
261
0
        return isValidSymbol(value, parentOpRegion);
262
0
    return false;
263
0
  }
264
0
265
0
  // Constant operation is ok.
266
0
  Attribute operandCst;
267
0
  if (matchPattern(defOp, m_Constant(&operandCst)))
268
0
    return true;
269
0
270
0
  // Affine apply operation is ok if all of its operands are ok.
271
0
  if (auto applyOp = dyn_cast<AffineApplyOp>(defOp))
272
0
    return applyOp.isValidSymbol(region);
273
0
274
0
  // Dim op results could be valid symbols at any level.
275
0
  if (auto dimOp = dyn_cast<DimOp>(defOp))
276
0
    return isDimOpValidSymbol(dimOp, region);
277
0
278
0
  // Check for values dominating `region`'s parent op.
279
0
  if (!region->getParentOp()->isKnownIsolatedFromAbove())
280
0
    if (auto *parentRegion = region->getParentOp()->getParentRegion())
281
0
      return isValidSymbol(value, parentRegion);
282
0
283
0
  return false;
284
0
}
285
286
// Returns true if 'value' is a valid index to an affine operation (e.g.
287
// affine.load, affine.store, affine.dma_start, affine.dma_wait) where
288
// `region` provides the polyhedral symbol scope. Returns false otherwise.
289
0
static bool isValidAffineIndexOperand(Value value, Region *region) {
290
0
  return isValidDim(value, region) || isValidSymbol(value, region);
291
0
}
292
293
/// Utility function to verify that a set of operands are valid dimension and
294
/// symbol identifiers. The operands should be laid out such that the dimension
295
/// operands are before the symbol operands. This function returns failure if
296
/// there was an invalid operand. An operation is provided to emit any necessary
297
/// errors.
298
template <typename OpTy>
299
static LogicalResult
300
verifyDimAndSymbolIdentifiers(OpTy &op, Operation::operand_range operands,
301
0
                              unsigned numDims) {
302
0
  unsigned opIt = 0;
303
0
  for (auto operand : operands) {
304
0
    if (opIt++ < numDims) {
305
0
      if (!isValidDim(operand, getAffineScope(op)))
306
0
        return op.emitOpError("operand cannot be used as a dimension id");
307
0
    } else if (!isValidSymbol(operand, getAffineScope(op))) {
308
0
      return op.emitOpError("operand cannot be used as a symbol");
309
0
    }
310
0
  }
311
0
  return success();
312
0
}
Unexecuted instantiation: AffineOps.cpp:_ZL29verifyDimAndSymbolIdentifiersIN4mlir11AffineForOpEENS0_13LogicalResultERT_NS0_12OperandRangeEj
Unexecuted instantiation: AffineOps.cpp:_ZL29verifyDimAndSymbolIdentifiersIN4mlir10AffineIfOpEENS0_13LogicalResultERT_NS0_12OperandRangeEj
Unexecuted instantiation: AffineOps.cpp:_ZL29verifyDimAndSymbolIdentifiersIN4mlir16AffineParallelOpEENS0_13LogicalResultERT_NS0_12OperandRangeEj
313
314
//===----------------------------------------------------------------------===//
315
// AffineApplyOp
316
//===----------------------------------------------------------------------===//
317
318
0
/// Bundles this op's affine map, its operands and its result into an
/// AffineValueMap.
AffineValueMap AffineApplyOp::getAffineValueMap() {
  AffineMap map = getAffineMap();
  return AffineValueMap(map, getOperands(), getResult());
}
321
322
/// Parses an affine.apply: the "map" attribute, then the dimension/symbol
/// operand list, then an optional attribute dictionary. The map must declare
/// exactly as many dims as were parsed, and dims plus symbols must equal the
/// operand count. One index-typed result is added per map result.
static ParseResult parseAffineApplyOp(OpAsmParser &parser,
                                      OperationState &result) {
  Type indexTy = parser.getBuilder().getIndexType();

  AffineMapAttr mapAttr;
  unsigned numDims;
  if (parser.parseAttribute(mapAttr, "map", result.attributes))
    return failure();
  if (parseDimAndSymbolList(parser, result.operands, numDims))
    return failure();
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  AffineMap map = mapAttr.getValue();
  bool dimsMatch = map.getNumDims() == numDims;
  bool inputsMatch = numDims + map.getNumSymbols() == result.operands.size();
  if (!dimsMatch || !inputsMatch)
    return parser.emitError(parser.getNameLoc(),
                            "dimension or symbol index mismatch");

  result.types.append(map.getNumResults(), indexTy);
  return success();
}
344
345
0
/// Prints an affine.apply as its operation name, the map attribute, the
/// dimension/symbol operand list, and the remaining attributes (the "map"
/// attribute is elided from that dictionary).
static void print(OpAsmPrinter &p, AffineApplyOp op) {
  unsigned numDims = op.getAffineMap().getNumDims();
  p << AffineApplyOp::getOperationName() << " " << op.mapAttr();
  printDimAndSymbolList(op.operand_begin(), op.operand_end(), numDims, p);
  p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"map"});
}
351
352
0
/// Verifies an affine.apply. The operand count must equal the map's dim plus
/// symbol count, and the map must yield exactly one result.
static LogicalResult verify(AffineApplyOp op) {
  AffineMap map = op.map();
  unsigned expectedOperands = map.getNumDims() + map.getNumSymbols();

  if (op.getNumOperands() != expectedOperands)
    return op.emitOpError(
        "operand count and affine map dimension and symbol count must match");

  if (map.getNumResults() != 1)
    return op.emitOpError("mapping must produce one value");

  return success();
}
367
368
// The result of the affine apply operation can be used as a dimension id if all
369
// its operands are valid dimension ids.
370
0
bool AffineApplyOp::isValidDim() {
371
0
  return llvm::all_of(getOperands(),
372
0
                      [](Value op) { return mlir::isValidDim(op); });
373
0
}
374
375
// The result of the affine apply operation can be used as a dimension id if all
376
// its operands are valid dimension ids with the parent operation of `region`
377
// defining the polyhedral scope for symbols.
378
0
bool AffineApplyOp::isValidDim(Region *region) {
379
0
  return llvm::all_of(getOperands(),
380
0
                      [&](Value op) { return ::isValidDim(op, region); });
381
0
}
382
383
// The result of the affine apply operation can be used as a symbol if all its
384
// operands are symbols.
385
0
bool AffineApplyOp::isValidSymbol() {
386
0
  return llvm::all_of(getOperands(),
387
0
                      [](Value op) { return mlir::isValidSymbol(op); });
388
0
}
389
390
// The result of the affine apply operation can be used as a symbol in `region`
391
// if all its operands are symbols in `region`.
392
0
bool AffineApplyOp::isValidSymbol(Region *region) {
393
0
  return llvm::all_of(getOperands(), [&](Value operand) {
394
0
    return mlir::isValidSymbol(operand, region);
395
0
  });
396
0
}
397
398
0
/// Folds an affine.apply. A map whose single result is a bare dim or symbol
/// folds to the corresponding existing operand. Otherwise the map is
/// constant-folded over `operands` when possible.
OpFoldResult AffineApplyOp::fold(ArrayRef<Attribute> operands) {
  AffineMap map = getAffineMap();
  AffineExpr resultExpr = map.getResult(0);

  // Pure forwarding of a dim (d_i) or a symbol (s_j) operand.
  if (auto dimExpr = resultExpr.dyn_cast<AffineDimExpr>())
    return getOperand(dimExpr.getPosition());
  if (auto symExpr = resultExpr.dyn_cast<AffineSymbolExpr>())
    return getOperand(map.getNumDims() + symExpr.getPosition());

  // General case: evaluate the map on the constant operand attributes.
  SmallVector<Attribute, 1> folded;
  if (succeeded(map.constantFold(operands, folded)))
    return folded[0];
  return {};
}
414
415
0
/// Returns the dim expression assigned to `v`. The first time `v` is seen it
/// gets the next free position and is appended to `reorderedDims`.
AffineDimExpr AffineApplyNormalizer::renumberOneDim(Value v) {
  auto insertion =
      dimValueToPosition.insert(std::make_pair(v, dimValueToPosition.size()));
  if (insertion.second)
    reorderedDims.push_back(v);
  unsigned position = insertion.first->second;
  return getAffineDimExpr(position, v.getContext()).cast<AffineDimExpr>();
}
426
427
0
/// Merges `other`'s dims and symbols into this normalizer and returns
/// `other.affineMap` rewritten against the merged numbering.
/// - Each of `other`'s dims is renumbered through renumberOneDim, so values
///   shared with this normalizer reuse their position.
/// - `other`'s symbols are appended after this normalizer's existing symbols.
AffineMap AffineApplyNormalizer::renumber(const AffineApplyNormalizer &other) {
  // dimRemapping[p] is the new dim expression for other's dim at position p.
  SmallVector<AffineExpr, 8> dimRemapping;
  for (Value dimValue : other.reorderedDims) {
    auto entry = other.dimValueToPosition.find(dimValue);
    unsigned otherPos = entry->second;
    if (otherPos >= dimRemapping.size())
      dimRemapping.resize(otherPos + 1);
    dimRemapping[otherPos] = renumberOneDim(entry->first);
  }

  // other's symbol j becomes symbol (symbolOffset + j) once concatenated.
  unsigned symbolOffset = concatenatedSymbols.size();
  unsigned numOtherSymbols = other.concatenatedSymbols.size();
  SmallVector<AffineExpr, 8> symRemapping;
  symRemapping.reserve(numOtherSymbols);
  for (unsigned j = 0; j < numOtherSymbols; ++j)
    symRemapping.push_back(
        getAffineSymbolExpr(symbolOffset + j, other.affineMap.getContext()));

  concatenatedSymbols.insert(concatenatedSymbols.end(),
                             other.concatenatedSymbols.begin(),
                             other.concatenatedSymbols.end());

  AffineMap otherMap = other.affineMap;
  return otherMap.replaceDimsAndSymbols(dimRemapping, symRemapping,
                                        reorderedDims.size(),
                                        concatenatedSymbols.size());
}
450
451
// Gather the positions of the operands that are produced by an AffineApplyOp.
452
static llvm::SetVector<unsigned>
453
0
indicesFromAffineApplyOp(ArrayRef<Value> operands) {
454
0
  llvm::SetVector<unsigned> res;
455
0
  for (auto en : llvm::enumerate(operands))
456
0
    if (isa_and_nonnull<AffineApplyOp>(en.value().getDefiningOp()))
457
0
      res.insert(en.index());
458
0
  return res;
459
0
}
460
461
// Support the special case of a symbol coming from an AffineApplyOp that needs
462
// to be composed into the current AffineApplyOp.
463
// This case is handled by rewriting all such symbols into dims for the purpose
464
// of allowing mathematical AffineMap composition.
465
// Returns an AffineMap where symbols that come from an AffineApplyOp have been
466
// rewritten as dims and are ordered after the original dims.
467
// TODO(andydavis,ntv): This promotion makes AffineMap lose track of which
468
// symbols are represented as dims. This loss is static but can still be
469
// recovered dynamically (with `isValidSymbol`). Still this is annoying for the
470
// semi-affine map case. A dynamic canonicalization of all dims that are valid
471
// symbols (a.k.a `canonicalizePromotedSymbols`) into symbols helps and even
472
// results in better simplifications and foldings. But we should evaluate
473
// whether this behavior is what we really want after using more.
474
static AffineMap promoteComposedSymbolsAsDims(AffineMap map,
475
0
                                              ArrayRef<Value> symbols) {
476
0
  if (symbols.empty()) {
477
0
    return map;
478
0
  }
479
0
480
0
  // Sanity check on symbols.
481
0
  for (auto sym : symbols) {
482
0
    assert(isValidSymbol(sym) && "Expected only valid symbols");
483
0
    (void)sym;
484
0
  }
485
0
486
0
  // Extract the symbol positions that come from an AffineApplyOp and
487
0
  // needs to be rewritten as dims.
488
0
  auto symPositions = indicesFromAffineApplyOp(symbols);
489
0
  if (symPositions.empty()) {
490
0
    return map;
491
0
  }
492
0
493
0
  // Create the new map by replacing each symbol at pos by the next new dim.
494
0
  unsigned numDims = map.getNumDims();
495
0
  unsigned numSymbols = map.getNumSymbols();
496
0
  unsigned numNewDims = 0;
497
0
  unsigned numNewSymbols = 0;
498
0
  SmallVector<AffineExpr, 8> symReplacements(numSymbols);
499
0
  for (unsigned i = 0; i < numSymbols; ++i) {
500
0
    symReplacements[i] =
501
0
        symPositions.count(i) > 0
502
0
            ? getAffineDimExpr(numDims + numNewDims++, map.getContext())
503
0
            : getAffineSymbolExpr(numNewSymbols++, map.getContext());
504
0
  }
505
0
  assert(numSymbols >= numNewDims);
506
0
  AffineMap newMap = map.replaceDimsAndSymbols(
507
0
      {}, symReplacements, numDims + numNewDims, numNewSymbols);
508
0
509
0
  return newMap;
510
0
}
511
512
/// The AffineNormalizer composes AffineApplyOp recursively. Its purpose is to
/// keep a correspondence between the mathematical `map` and the `operands` of
/// a given AffineApplyOp. This correspondence is maintained by iterating over
/// the operands and forming an `auxiliaryMap` that can be composed
/// mathematically with `map`. To keep this correspondence in cases where
/// symbols are produced by affine.apply operations, we perform a local rewrite
/// of symbols as dims.
///
/// On completion, `affineMap` holds the composed (and simplified) map. Its
/// dims correspond to the values in `reorderedDims`, and its symbols to the
/// values in `concatenatedSymbols`.
///
/// Rationale for locally rewriting symbols as dims:
/// ================================================
/// The mathematical composition of AffineMap must always concatenate symbols
/// because it does not have enough information to do otherwise. For example,
/// composing `(d0)[s0] -> (d0 + s0)` with itself must produce
/// `(d0)[s0, s1] -> (d0 + s0 + s1)`.
///
/// The result is only equivalent to `(d0)[s0] -> (d0 + 2 * s0)` when
/// applied to the same mlir::Value for both s0 and s1.
/// As a consequence mathematical composition of AffineMap always concatenates
/// symbols.
///
/// When AffineMaps are used in AffineApplyOp however, they may specify
/// composition via symbols, which is ambiguous mathematically. This corner case
/// is handled by locally rewriting such symbols that come from AffineApplyOp
/// into dims and composing through dims.
/// TODO(andydavis, ntv): Composition via symbols comes at a significant code
/// complexity. Alternatively we should investigate whether we want to
/// explicitly disallow symbols coming from affine.apply and instead force the
/// user to compose symbols beforehand. The annoyances may be small (i.e. 1 or 2
/// extra API calls for such uses, which haven't popped up until now) and the
/// benefit potentially big: simpler and more maintainable code for a
/// non-trivial, recursive, procedure.
AffineApplyNormalizer::AffineApplyNormalizer(AffineMap map,
                                             ArrayRef<Value> operands)
    : AffineApplyNormalizer() {
  static_assert(kMaxAffineApplyDepth > 0, "kMaxAffineApplyDepth must be > 0");
  assert(map.getNumInputs() == operands.size() &&
         "number of operands does not match the number of map inputs");

  LLVM_DEBUG(map.print(dbgs() << "\nInput map: "));

  // Promote symbols that come from an AffineApplyOp to dims by rewriting the
  // map to always refer to:
  //   (dims, symbols coming from AffineApplyOp, other symbols).
  // The order of operands can remain unchanged.
  // This is a simplification that relies on 2 ordering properties:
  //   1. rewritten symbols always appear after the original dims in the map;
  //   2. operands are traversed in order and either dispatched to:
  //      a. auxiliaryExprs (dims and symbols rewritten as dims);
  //      b. concatenatedSymbols (all other symbols)
  // This allows operand order to remain unchanged.
  // The pre-rewrite dim count is saved because, after promotion,
  // map.getNumDims() also counts the promoted symbols.
  unsigned numDimsBeforeRewrite = map.getNumDims();
  map = promoteComposedSymbolsAsDims(map,
                                     operands.take_back(map.getNumSymbols()));

  LLVM_DEBUG(map.print(dbgs() << "\nRewritten map: "));

  // One expression per dim of the rewritten map, expressed over this
  // normalizer's renumbered dims and concatenated symbols.
  SmallVector<AffineExpr, 8> auxiliaryExprs;
  // Composition through affine.apply producers happens only while
  // affineApplyDepth() does not exceed kMaxAffineApplyDepth. affineApplyDepth()
  // is declared elsewhere; presumably it tracks the current recursion depth
  // (confirm).
  bool furtherCompose = (affineApplyDepth() <= kMaxAffineApplyDepth);
  // We fully spell out the 2 cases below. In this particular instance a little
  // code duplication greatly improves readability.
  // Note that the first branch would disappear if we only supported full
  // composition (i.e. infinite kMaxAffineApplyDepth).
  if (!furtherCompose) {
    // 1. Only dispatch dims or symbols.
    for (auto en : llvm::enumerate(operands)) {
      auto t = en.value();
      assert(t.getType().isIndex());
      // Uses the post-promotion dim count, so symbols produced by affine.apply
      // are renumbered as dims too, consistent with the rewritten map.
      bool isDim = (en.index() < map.getNumDims());
      if (isDim) {
        // a. The mathematical composition of AffineMap composes dims.
        auxiliaryExprs.push_back(renumberOneDim(t));
      } else {
        // b. The mathematical composition of AffineMap concatenates symbols.
        //    We do the same for symbol operands.
        concatenatedSymbols.push_back(t);
      }
    }
  } else {
    assert(numDimsBeforeRewrite <= operands.size());
    // 2. Compose AffineApplyOps and dispatch dims or symbols.
    for (unsigned i = 0, e = operands.size(); i < e; ++i) {
      auto t = operands[i];
      auto affineApply = t.getDefiningOp<AffineApplyOp>();
      if (affineApply) {
        // a. Compose affine.apply operations.
        LLVM_DEBUG(affineApply.getOperation()->print(
            dbgs() << "\nCompose AffineApplyOp recursively: "));
        AffineMap affineApplyMap = affineApply.getAffineMap();
        SmallVector<Value, 8> affineApplyOperands(
            affineApply.getOperands().begin(), affineApply.getOperands().end());
        // Recursively normalize the producer, then merge its dims and symbols
        // into this normalizer via renumber().
        AffineApplyNormalizer normalizer(affineApplyMap, affineApplyOperands);

        LLVM_DEBUG(normalizer.affineMap.print(
            dbgs() << "\nRenumber into current normalizer: "));

        auto renumberedMap = renumber(normalizer);

        LLVM_DEBUG(
            renumberedMap.print(dbgs() << "\nRecursive composition yields: "));

        // affine.apply maps have exactly one result (enforced by the verifier).
        auxiliaryExprs.push_back(renumberedMap.getResult(0));
      } else {
        if (i < numDimsBeforeRewrite) {
          // b. The mathematical composition of AffineMap composes dims.
          auxiliaryExprs.push_back(renumberOneDim(t));
        } else {
          // c. The mathematical composition of AffineMap concatenates symbols.
          //    Note that the map composition will put symbols already present
          //    in the map before any symbols coming from the auxiliary map, so
          //    we insert them before any symbols that are due to renumbering,
          //    and after the proper symbols we have seen already.
          concatenatedSymbols.insert(
              std::next(concatenatedSymbols.begin(), numProperSymbols++), t);
        }
      }
    }
  }

  // Early exit if `map` is already composed: no operand was a dim or an
  // affine.apply result, so every operand went, in order, to
  // concatenatedSymbols.
  if (auxiliaryExprs.empty()) {
    affineMap = map;
    return;
  }

  assert(concatenatedSymbols.size() >= map.getNumSymbols() &&
         "Unexpected number of concatenated symbols");
  // The auxiliary map's dims are all renumbered dims. Its symbols are the ones
  // contributed by renumbered producers; `map`'s own symbols occupy the front
  // of concatenatedSymbols and are supplied by `map` itself in the composition.
  auto numDims = dimValueToPosition.size();
  auto numSymbols = concatenatedSymbols.size() - map.getNumSymbols();
  auto auxiliaryMap =
      AffineMap::get(numDims, numSymbols, auxiliaryExprs, map.getContext());

  LLVM_DEBUG(map.print(dbgs() << "\nCompose map: "));
  LLVM_DEBUG(auxiliaryMap.print(dbgs() << "\nWith map: "));
  LLVM_DEBUG(map.compose(auxiliaryMap).print(dbgs() << "\nResult: "));

  // TODO(andydavis,ntv): Disabling simplification results in major speed gains.
  // Another option is to cache the results as it is expected a lot of redundant
  // work is performed in practice.
  affineMap = simplifyAffineMap(map.compose(auxiliaryMap));

  LLVM_DEBUG(affineMap.print(dbgs() << "\nSimplified result: "));
  LLVM_DEBUG(dbgs() << "\n");
}
655
656
void AffineApplyNormalizer::normalize(AffineMap *otherMap,
657
0
                                      SmallVectorImpl<Value> *otherOperands) {
658
0
  AffineApplyNormalizer other(*otherMap, *otherOperands);
659
0
  *otherMap = renumber(other);
660
0
661
0
  otherOperands->reserve(reorderedDims.size() + concatenatedSymbols.size());
662
0
  otherOperands->assign(reorderedDims.begin(), reorderedDims.end());
663
0
  otherOperands->append(concatenatedSymbols.begin(), concatenatedSymbols.end());
664
0
}
665
666
/// Implements `map` and `operands` composition and simplification to support
667
/// `makeComposedAffineApply`. This can be called to achieve the same effects
668
/// on `map` and `operands` without creating an AffineApplyOp that needs to be
669
/// immediately deleted.
670
static void composeAffineMapAndOperands(AffineMap *map,
671
0
                                        SmallVectorImpl<Value> *operands) {
672
0
  AffineApplyNormalizer normalizer(*map, *operands);
673
0
  auto normalizedMap = normalizer.getAffineMap();
674
0
  auto normalizedOperands = normalizer.getOperands();
675
0
  canonicalizeMapAndOperands(&normalizedMap, &normalizedOperands);
676
0
  *map = normalizedMap;
677
0
  *operands = normalizedOperands;
678
0
  assert(*map);
679
0
}
680
681
void mlir::fullyComposeAffineMapAndOperands(AffineMap *map,
682
0
                                            SmallVectorImpl<Value> *operands) {
683
0
  while (llvm::any_of(*operands, [](Value v) {
684
0
    return isa_and_nonnull<AffineApplyOp>(v.getDefiningOp());
685
0
  })) {
686
0
    composeAffineMapAndOperands(map, operands);
687
0
  }
688
0
}
689
690
/// Creates an AffineApplyOp at `loc` whose map and operands are the composed
/// and canonicalized form of `map` applied to `operands`.
AffineApplyOp mlir::makeComposedAffineApply(OpBuilder &b, Location loc,
                                            AffineMap map,
                                            ArrayRef<Value> operands) {
  // `map` is a by-value copy; the operands are copied into a mutable vector
  // so composition can rewrite both in place.
  SmallVector<Value, 8> composedOperands(operands.begin(), operands.end());
  composeAffineMapAndOperands(&map, &composedOperands);
  assert(map);
  return b.create<AffineApplyOp>(loc, map, composedOperands);
}
699
700
// A symbol may appear as a dim in affine.apply operations. This function
701
// canonicalizes dims that are valid symbols into actual symbols.
702
template <class MapOrSet>
703
static void canonicalizePromotedSymbols(MapOrSet *mapOrSet,
704
0
                                        SmallVectorImpl<Value> *operands) {
705
0
  if (!mapOrSet || operands->empty())
706
0
    return;
707
0
708
0
  assert(mapOrSet->getNumInputs() == operands->size() &&
709
0
         "map/set inputs must match number of operands");
710
0
711
0
  auto *context = mapOrSet->getContext();
712
0
  SmallVector<Value, 8> resultOperands;
713
0
  resultOperands.reserve(operands->size());
714
0
  SmallVector<Value, 8> remappedSymbols;
715
0
  remappedSymbols.reserve(operands->size());
716
0
  unsigned nextDim = 0;
717
0
  unsigned nextSym = 0;
718
0
  unsigned oldNumSyms = mapOrSet->getNumSymbols();
719
0
  SmallVector<AffineExpr, 8> dimRemapping(mapOrSet->getNumDims());
720
0
  for (unsigned i = 0, e = mapOrSet->getNumInputs(); i != e; ++i) {
721
0
    if (i < mapOrSet->getNumDims()) {
722
0
      if (isValidSymbol((*operands)[i])) {
723
0
        // This is a valid symbol that appears as a dim, canonicalize it.
724
0
        dimRemapping[i] = getAffineSymbolExpr(oldNumSyms + nextSym++, context);
725
0
        remappedSymbols.push_back((*operands)[i]);
726
0
      } else {
727
0
        dimRemapping[i] = getAffineDimExpr(nextDim++, context);
728
0
        resultOperands.push_back((*operands)[i]);
729
0
      }
730
0
    } else {
731
0
      resultOperands.push_back((*operands)[i]);
732
0
    }
733
0
  }
734
0
735
0
  resultOperands.append(remappedSymbols.begin(), remappedSymbols.end());
736
0
  *operands = resultOperands;
737
0
  *mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, {}, nextDim,
738
0
                                              oldNumSyms + nextSym);
739
0
740
0
  assert(mapOrSet->getNumInputs() == operands->size() &&
741
0
         "map/set inputs must match number of operands");
742
0
}
Unexecuted instantiation: AffineOps.cpp:_ZL27canonicalizePromotedSymbolsIN4mlir9AffineMapEEvPT_PN4llvm15SmallVectorImplINS0_5ValueEEE
Unexecuted instantiation: AffineOps.cpp:_ZL27canonicalizePromotedSymbolsIN4mlir10IntegerSetEEvPT_PN4llvm15SmallVectorImplINS0_5ValueEEE
743
744
// Works for either an affine map or an integer set.
745
template <class MapOrSet>
746
static void canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet,
747
0
                                            SmallVectorImpl<Value> *operands) {
748
0
  static_assert(llvm::is_one_of<MapOrSet, AffineMap, IntegerSet>::value,
749
0
                "Argument must be either of AffineMap or IntegerSet type");
750
0
751
0
  if (!mapOrSet || operands->empty())
752
0
    return;
753
0
754
0
  assert(mapOrSet->getNumInputs() == operands->size() &&
755
0
         "map/set inputs must match number of operands");
756
0
757
0
  canonicalizePromotedSymbols<MapOrSet>(mapOrSet, operands);
758
0
759
0
  // Check to see what dims are used.
760
0
  llvm::SmallBitVector usedDims(mapOrSet->getNumDims());
761
0
  llvm::SmallBitVector usedSyms(mapOrSet->getNumSymbols());
762
0
  mapOrSet->walkExprs([&](AffineExpr expr) {
763
0
    if (auto dimExpr = expr.dyn_cast<AffineDimExpr>())
764
0
      usedDims[dimExpr.getPosition()] = true;
765
0
    else if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>())
766
0
      usedSyms[symExpr.getPosition()] = true;
767
0
  });
Unexecuted instantiation: AffineOps.cpp:_ZZL31canonicalizeMapOrSetAndOperandsIN4mlir9AffineMapEEvPT_PN4llvm15SmallVectorImplINS0_5ValueEEEENKUlNS0_10AffineExprEE_clES9_
Unexecuted instantiation: AffineOps.cpp:_ZZL31canonicalizeMapOrSetAndOperandsIN4mlir10IntegerSetEEvPT_PN4llvm15SmallVectorImplINS0_5ValueEEEENKUlNS0_10AffineExprEE_clES9_
768
0
769
0
  auto *context = mapOrSet->getContext();
770
0
771
0
  SmallVector<Value, 8> resultOperands;
772
0
  resultOperands.reserve(operands->size());
773
0
774
0
  llvm::SmallDenseMap<Value, AffineExpr, 8> seenDims;
775
0
  SmallVector<AffineExpr, 8> dimRemapping(mapOrSet->getNumDims());
776
0
  unsigned nextDim = 0;
777
0
  for (unsigned i = 0, e = mapOrSet->getNumDims(); i != e; ++i) {
778
0
    if (usedDims[i]) {
779
0
      // Remap dim positions for duplicate operands.
780
0
      auto it = seenDims.find((*operands)[i]);
781
0
      if (it == seenDims.end()) {
782
0
        dimRemapping[i] = getAffineDimExpr(nextDim++, context);
783
0
        resultOperands.push_back((*operands)[i]);
784
0
        seenDims.insert(std::make_pair((*operands)[i], dimRemapping[i]));
785
0
      } else {
786
0
        dimRemapping[i] = it->second;
787
0
      }
788
0
    }
789
0
  }
790
0
  llvm::SmallDenseMap<Value, AffineExpr, 8> seenSymbols;
791
0
  SmallVector<AffineExpr, 8> symRemapping(mapOrSet->getNumSymbols());
792
0
  unsigned nextSym = 0;
793
0
  for (unsigned i = 0, e = mapOrSet->getNumSymbols(); i != e; ++i) {
794
0
    if (!usedSyms[i])
795
0
      continue;
796
0
    // Handle constant operands (only needed for symbolic operands since
797
0
    // constant operands in dimensional positions would have already been
798
0
    // promoted to symbolic positions above).
799
0
    IntegerAttr operandCst;
800
0
    if (matchPattern((*operands)[i + mapOrSet->getNumDims()],
801
0
                     m_Constant(&operandCst))) {
802
0
      symRemapping[i] =
803
0
          getAffineConstantExpr(operandCst.getValue().getSExtValue(), context);
804
0
      continue;
805
0
    }
806
0
    // Remap symbol positions for duplicate operands.
807
0
    auto it = seenSymbols.find((*operands)[i + mapOrSet->getNumDims()]);
808
0
    if (it == seenSymbols.end()) {
809
0
      symRemapping[i] = getAffineSymbolExpr(nextSym++, context);
810
0
      resultOperands.push_back((*operands)[i + mapOrSet->getNumDims()]);
811
0
      seenSymbols.insert(std::make_pair((*operands)[i + mapOrSet->getNumDims()],
812
0
                                        symRemapping[i]));
813
0
    } else {
814
0
      symRemapping[i] = it->second;
815
0
    }
816
0
  }
817
0
  *mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, symRemapping,
818
0
                                              nextDim, nextSym);
819
0
  *operands = resultOperands;
820
0
}
Unexecuted instantiation: AffineOps.cpp:_ZL31canonicalizeMapOrSetAndOperandsIN4mlir9AffineMapEEvPT_PN4llvm15SmallVectorImplINS0_5ValueEEE
Unexecuted instantiation: AffineOps.cpp:_ZL31canonicalizeMapOrSetAndOperandsIN4mlir10IntegerSetEEvPT_PN4llvm15SmallVectorImplINS0_5ValueEEE
821
822
void mlir::canonicalizeMapAndOperands(AffineMap *map,
823
0
                                      SmallVectorImpl<Value> *operands) {
824
0
  canonicalizeMapOrSetAndOperands<AffineMap>(map, operands);
825
0
}
826
827
void mlir::canonicalizeSetAndOperands(IntegerSet *set,
828
0
                                      SmallVectorImpl<Value> *operands) {
829
0
  canonicalizeMapOrSetAndOperands<IntegerSet>(set, operands);
830
0
}
831
832
namespace {
833
/// Simplify AffineApply, AffineLoad, and AffineStore operations by composing
834
/// maps that supply results into them.
835
///
836
template <typename AffineOpTy>
837
struct SimplifyAffineOp : public OpRewritePattern<AffineOpTy> {
838
  using OpRewritePattern<AffineOpTy>::OpRewritePattern;
839
840
  /// Replace the affine op with another instance of it with the supplied
841
  /// map and mapOperands.
842
  void replaceAffineOp(PatternRewriter &rewriter, AffineOpTy affineOp,
843
                       AffineMap map, ArrayRef<Value> mapOperands) const;
844
845
  LogicalResult matchAndRewrite(AffineOpTy affineOp,
846
0
                                PatternRewriter &rewriter) const override {
847
0
    static_assert(llvm::is_one_of<AffineOpTy, AffineLoadOp, AffinePrefetchOp,
848
0
                                  AffineStoreOp, AffineApplyOp, AffineMinOp,
849
0
                                  AffineMaxOp>::value,
850
0
                  "affine load/store/apply/prefetch/min/max op expected");
851
0
    auto map = affineOp.getAffineMap();
852
0
    AffineMap oldMap = map;
853
0
    auto oldOperands = affineOp.getMapOperands();
854
0
    SmallVector<Value, 8> resultOperands(oldOperands);
855
0
    composeAffineMapAndOperands(&map, &resultOperands);
856
0
    if (map == oldMap && std::equal(oldOperands.begin(), oldOperands.end(),
857
0
                                    resultOperands.begin()))
858
0
      return failure();
859
0
860
0
    replaceAffineOp(rewriter, affineOp, map, resultOperands);
861
0
    return success();
862
0
  }
Unexecuted instantiation: AffineOps.cpp:_ZNK12_GLOBAL__N_116SimplifyAffineOpIN4mlir13AffineApplyOpEE15matchAndRewriteES2_RNS1_15PatternRewriterE
Unexecuted instantiation: AffineOps.cpp:_ZNK12_GLOBAL__N_116SimplifyAffineOpIN4mlir12AffineLoadOpEE15matchAndRewriteES2_RNS1_15PatternRewriterE
Unexecuted instantiation: AffineOps.cpp:_ZNK12_GLOBAL__N_116SimplifyAffineOpIN4mlir13AffineStoreOpEE15matchAndRewriteES2_RNS1_15PatternRewriterE
Unexecuted instantiation: AffineOps.cpp:_ZNK12_GLOBAL__N_116SimplifyAffineOpIN4mlir11AffineMinOpEE15matchAndRewriteES2_RNS1_15PatternRewriterE
Unexecuted instantiation: AffineOps.cpp:_ZNK12_GLOBAL__N_116SimplifyAffineOpIN4mlir11AffineMaxOpEE15matchAndRewriteES2_RNS1_15PatternRewriterE
Unexecuted instantiation: AffineOps.cpp:_ZNK12_GLOBAL__N_116SimplifyAffineOpIN4mlir16AffinePrefetchOpEE15matchAndRewriteES2_RNS1_15PatternRewriterE
863
};
864
865
// Specialize the template to account for the different build signatures for
866
// affine load, store, and apply ops.
867
template <>
868
void SimplifyAffineOp<AffineLoadOp>::replaceAffineOp(
869
    PatternRewriter &rewriter, AffineLoadOp load, AffineMap map,
870
0
    ArrayRef<Value> mapOperands) const {
871
0
  rewriter.replaceOpWithNewOp<AffineLoadOp>(load, load.getMemRef(), map,
872
0
                                            mapOperands);
873
0
}
874
template <>
875
void SimplifyAffineOp<AffinePrefetchOp>::replaceAffineOp(
876
    PatternRewriter &rewriter, AffinePrefetchOp prefetch, AffineMap map,
877
0
    ArrayRef<Value> mapOperands) const {
878
0
  rewriter.replaceOpWithNewOp<AffinePrefetchOp>(
879
0
      prefetch, prefetch.memref(), map, mapOperands,
880
0
      prefetch.localityHint().getZExtValue(), prefetch.isWrite(),
881
0
      prefetch.isDataCache());
882
0
}
883
template <>
884
void SimplifyAffineOp<AffineStoreOp>::replaceAffineOp(
885
    PatternRewriter &rewriter, AffineStoreOp store, AffineMap map,
886
0
    ArrayRef<Value> mapOperands) const {
887
0
  rewriter.replaceOpWithNewOp<AffineStoreOp>(
888
0
      store, store.getValueToStore(), store.getMemRef(), map, mapOperands);
889
0
}
890
891
// Generic version for ops that don't have extra operands.
892
template <typename AffineOpTy>
893
void SimplifyAffineOp<AffineOpTy>::replaceAffineOp(
894
    PatternRewriter &rewriter, AffineOpTy op, AffineMap map,
895
0
    ArrayRef<Value> mapOperands) const {
896
0
  rewriter.replaceOpWithNewOp<AffineOpTy>(op, map, mapOperands);
897
0
}
Unexecuted instantiation: AffineOps.cpp:_ZNK12_GLOBAL__N_116SimplifyAffineOpIN4mlir13AffineApplyOpEE15replaceAffineOpERNS1_15PatternRewriterES2_NS1_9AffineMapEN4llvm8ArrayRefINS1_5ValueEEE
Unexecuted instantiation: AffineOps.cpp:_ZNK12_GLOBAL__N_116SimplifyAffineOpIN4mlir11AffineMinOpEE15replaceAffineOpERNS1_15PatternRewriterES2_NS1_9AffineMapEN4llvm8ArrayRefINS1_5ValueEEE
Unexecuted instantiation: AffineOps.cpp:_ZNK12_GLOBAL__N_116SimplifyAffineOpIN4mlir11AffineMaxOpEE15replaceAffineOpERNS1_15PatternRewriterES2_NS1_9AffineMapEN4llvm8ArrayRefINS1_5ValueEEE
898
} // end anonymous namespace.
899
900
void AffineApplyOp::getCanonicalizationPatterns(
901
0
    OwningRewritePatternList &results, MLIRContext *context) {
902
0
  results.insert<SimplifyAffineOp<AffineApplyOp>>(context);
903
0
}
904
905
//===----------------------------------------------------------------------===//
906
// Common canonicalization pattern support logic
907
//===----------------------------------------------------------------------===//
908
909
/// This is a common class used for patterns of the form
910
/// "someop(memrefcast) -> someop".  It folds the source of any memref_cast
911
/// into the root operation directly.
912
0
static LogicalResult foldMemRefCast(Operation *op) {
913
0
  bool folded = false;
914
0
  for (OpOperand &operand : op->getOpOperands()) {
915
0
    auto cast = operand.get().getDefiningOp<MemRefCastOp>();
916
0
    if (cast && !cast.getOperand().getType().isa<UnrankedMemRefType>()) {
917
0
      operand.set(cast.getOperand());
918
0
      folded = true;
919
0
    }
920
0
  }
921
0
  return success(folded);
922
0
}
923
924
//===----------------------------------------------------------------------===//
925
// AffineDmaStartOp
926
//===----------------------------------------------------------------------===//
927
928
// TODO(b/133776335) Check that map operands are loop IVs or symbols.
929
void AffineDmaStartOp::build(OpBuilder &builder, OperationState &result,
930
                             Value srcMemRef, AffineMap srcMap,
931
                             ValueRange srcIndices, Value destMemRef,
932
                             AffineMap dstMap, ValueRange destIndices,
933
                             Value tagMemRef, AffineMap tagMap,
934
                             ValueRange tagIndices, Value numElements,
935
0
                             Value stride, Value elementsPerStride) {
936
0
  result.addOperands(srcMemRef);
937
0
  result.addAttribute(getSrcMapAttrName(), AffineMapAttr::get(srcMap));
938
0
  result.addOperands(srcIndices);
939
0
  result.addOperands(destMemRef);
940
0
  result.addAttribute(getDstMapAttrName(), AffineMapAttr::get(dstMap));
941
0
  result.addOperands(destIndices);
942
0
  result.addOperands(tagMemRef);
943
0
  result.addAttribute(getTagMapAttrName(), AffineMapAttr::get(tagMap));
944
0
  result.addOperands(tagIndices);
945
0
  result.addOperands(numElements);
946
0
  if (stride) {
947
0
    result.addOperands({stride, elementsPerStride});
948
0
  }
949
0
}
950
951
0
/// Prints the op as
///   affine.dma_start %src[map], %dst[map], %tag[map], %n[, %stride, %eps]
///     : src-type, dst-type, tag-type
void AffineDmaStartOp::print(OpAsmPrinter &p) {
  p << "affine.dma_start " << getSrcMemRef() << '[';
  p.printAffineMapOfSSAIds(getSrcMapAttr(), getSrcIndices());

  p << "], " << getDstMemRef() << '[';
  p.printAffineMapOfSSAIds(getDstMapAttr(), getDstIndices());

  p << "], " << getTagMemRef() << '[';
  p.printAffineMapOfSSAIds(getTagMapAttr(), getTagIndices());

  p << "], " << getNumElements();
  if (isStrided())
    p << ", " << getStride() << ", " << getNumElementsPerStride();

  p << " : " << getSrcMemRefType() << ", " << getDstMemRefType() << ", "
    << getTagMemRefType();
}
966
967
// Parse AffineDmaStartOp.
968
// Ex:
969
//   affine.dma_start %src[%i, %j], %dst[%k, %l], %tag[%index], %size,
970
//     %stride, %num_elt_per_stride
971
//       : memref<3076 x f32, 0>, memref<1024 x f32, 2>, memref<1 x i32>
972
//
973
ParseResult AffineDmaStartOp::parse(OpAsmParser &parser,
974
0
                                    OperationState &result) {
975
0
  OpAsmParser::OperandType srcMemRefInfo;
976
0
  AffineMapAttr srcMapAttr;
977
0
  SmallVector<OpAsmParser::OperandType, 4> srcMapOperands;
978
0
  OpAsmParser::OperandType dstMemRefInfo;
979
0
  AffineMapAttr dstMapAttr;
980
0
  SmallVector<OpAsmParser::OperandType, 4> dstMapOperands;
981
0
  OpAsmParser::OperandType tagMemRefInfo;
982
0
  AffineMapAttr tagMapAttr;
983
0
  SmallVector<OpAsmParser::OperandType, 4> tagMapOperands;
984
0
  OpAsmParser::OperandType numElementsInfo;
985
0
  SmallVector<OpAsmParser::OperandType, 2> strideInfo;
986
0
987
0
  SmallVector<Type, 3> types;
988
0
  auto indexType = parser.getBuilder().getIndexType();
989
0
990
0
  // Parse and resolve the following list of operands:
991
0
  // *) dst memref followed by its affine maps operands (in square brackets).
992
0
  // *) src memref followed by its affine map operands (in square brackets).
993
0
  // *) tag memref followed by its affine map operands (in square brackets).
994
0
  // *) number of elements transferred by DMA operation.
995
0
  if (parser.parseOperand(srcMemRefInfo) ||
996
0
      parser.parseAffineMapOfSSAIds(srcMapOperands, srcMapAttr,
997
0
                                    getSrcMapAttrName(), result.attributes) ||
998
0
      parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
999
0
      parser.parseAffineMapOfSSAIds(dstMapOperands, dstMapAttr,
1000
0
                                    getDstMapAttrName(), result.attributes) ||
1001
0
      parser.parseComma() || parser.parseOperand(tagMemRefInfo) ||
1002
0
      parser.parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr,
1003
0
                                    getTagMapAttrName(), result.attributes) ||
1004
0
      parser.parseComma() || parser.parseOperand(numElementsInfo))
1005
0
    return failure();
1006
0
1007
0
  // Parse optional stride and elements per stride.
1008
0
  if (parser.parseTrailingOperandList(strideInfo)) {
1009
0
    return failure();
1010
0
  }
1011
0
  if (!strideInfo.empty() && strideInfo.size() != 2) {
1012
0
    return parser.emitError(parser.getNameLoc(),
1013
0
                            "expected two stride related operands");
1014
0
  }
1015
0
  bool isStrided = strideInfo.size() == 2;
1016
0
1017
0
  if (parser.parseColonTypeList(types))
1018
0
    return failure();
1019
0
1020
0
  if (types.size() != 3)
1021
0
    return parser.emitError(parser.getNameLoc(), "expected three types");
1022
0
1023
0
  if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
1024
0
      parser.resolveOperands(srcMapOperands, indexType, result.operands) ||
1025
0
      parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
1026
0
      parser.resolveOperands(dstMapOperands, indexType, result.operands) ||
1027
0
      parser.resolveOperand(tagMemRefInfo, types[2], result.operands) ||
1028
0
      parser.resolveOperands(tagMapOperands, indexType, result.operands) ||
1029
0
      parser.resolveOperand(numElementsInfo, indexType, result.operands))
1030
0
    return failure();
1031
0
1032
0
  if (isStrided) {
1033
0
    if (parser.resolveOperands(strideInfo, indexType, result.operands))
1034
0
      return failure();
1035
0
  }
1036
0
1037
0
  // Check that src/dst/tag operand counts match their map.numInputs.
1038
0
  if (srcMapOperands.size() != srcMapAttr.getValue().getNumInputs() ||
1039
0
      dstMapOperands.size() != dstMapAttr.getValue().getNumInputs() ||
1040
0
      tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
1041
0
    return parser.emitError(parser.getNameLoc(),
1042
0
                            "memref operand count not equal to map.numInputs");
1043
0
  return success();
1044
0
}
1045
1046
0
/// Verifies the operand types and counts of an affine.dma_start, the
/// memory-space constraint, and that every map operand is an index-typed
/// dimension or symbol identifier.
LogicalResult AffineDmaStartOp::verify() {
  if (!getOperand(getSrcMemRefOperandIndex()).getType().isa<MemRefType>())
    return emitOpError("expected DMA source to be of memref type");
  if (!getOperand(getDstMemRefOperandIndex()).getType().isa<MemRefType>())
    return emitOpError("expected DMA destination to be of memref type");
  if (!getOperand(getTagMemRefOperandIndex()).getType().isa<MemRefType>())
    return emitOpError("expected DMA tag to be of memref type");

  // Only DMAs between different memory spaces are supported.
  if (getSrcMemorySpace() == getDstMemorySpace())
    return emitOpError("DMA should be between different memory spaces");

  // Three memrefs, all map operands and the element count, optionally followed
  // by the stride pair.
  const unsigned numRequired = getSrcMap().getNumInputs() +
                               getDstMap().getNumInputs() +
                               getTagMap().getNumInputs() + /*memrefs=*/3 +
                               /*numElements=*/1;
  const unsigned numOperands = getNumOperands();
  if (numOperands != numRequired &&
      numOperands != numRequired + /*stride operands=*/2)
    return emitOpError("incorrect number of operands");

  Region *scope = getAffineScope(*this);
  for (Value idx : getSrcIndices()) {
    if (!idx.getType().isIndex())
      return emitOpError("src index to dma_start must have 'index' type");
    if (!isValidAffineIndexOperand(idx, scope))
      return emitOpError("src index must be a dimension or symbol identifier");
  }
  for (Value idx : getDstIndices()) {
    if (!idx.getType().isIndex())
      return emitOpError("dst index to dma_start must have 'index' type");
    if (!isValidAffineIndexOperand(idx, scope))
      return emitOpError("dst index must be a dimension or symbol identifier");
  }
  for (Value idx : getTagIndices()) {
    if (!idx.getType().isIndex())
      return emitOpError("tag index to dma_start must have 'index' type");
    if (!isValidAffineIndexOperand(idx, scope))
      return emitOpError("tag index must be a dimension or symbol identifier");
  }
  return success();
}
1087
1088
/// Folds dma_start(memrefcast) -> dma_start by bypassing ranked memref casts
/// on the operands.
LogicalResult AffineDmaStartOp::fold(ArrayRef<Attribute> cstOperands,
                                     SmallVectorImpl<OpFoldResult> &results) {
  return foldMemRefCast(getOperation());
}
1093
1094
//===----------------------------------------------------------------------===//
1095
// AffineDmaWaitOp
1096
//===----------------------------------------------------------------------===//
1097
1098
// TODO(b/133776335) Check that map operands are loop IVs or symbols.
1099
void AffineDmaWaitOp::build(OpBuilder &builder, OperationState &result,
1100
                            Value tagMemRef, AffineMap tagMap,
1101
0
                            ValueRange tagIndices, Value numElements) {
1102
0
  result.addOperands(tagMemRef);
1103
0
  result.addAttribute(getTagMapAttrName(), AffineMapAttr::get(tagMap));
1104
0
  result.addOperands(tagIndices);
1105
0
  result.addOperands(numElements);
1106
0
}
1107
1108
0
/// Prints the op as
///   affine.dma_wait %tag[map], %num_elements : tag-type
void AffineDmaWaitOp::print(OpAsmPrinter &p) {
  p << "affine.dma_wait " << getTagMemRef() << '[';
  // Pass the index range straight through, as AffineDmaStartOp::print does,
  // instead of copying it into a temporary SmallVector first.
  p.printAffineMapOfSSAIds(getTagMapAttr(), getTagIndices());
  p << "], ";
  p.printOperand(getNumElements());
  p << " : " << getTagMemRef().getType();
}
1116
1117
// Parse AffineDmaWaitOp.
1118
// Eg:
1119
//   affine.dma_wait %tag[%index], %num_elements
1120
//     : memref<1 x i32, (d0) -> (d0), 4>
1121
//
1122
ParseResult AffineDmaWaitOp::parse(OpAsmParser &parser,
1123
0
                                   OperationState &result) {
1124
0
  OpAsmParser::OperandType tagMemRefInfo;
1125
0
  AffineMapAttr tagMapAttr;
1126
0
  SmallVector<OpAsmParser::OperandType, 2> tagMapOperands;
1127
0
  Type type;
1128
0
  auto indexType = parser.getBuilder().getIndexType();
1129
0
  OpAsmParser::OperandType numElementsInfo;
1130
0
1131
0
  // Parse tag memref, its map operands, and dma size.
1132
0
  if (parser.parseOperand(tagMemRefInfo) ||
1133
0
      parser.parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr,
1134
0
                                    getTagMapAttrName(), result.attributes) ||
1135
0
      parser.parseComma() || parser.parseOperand(numElementsInfo) ||
1136
0
      parser.parseColonType(type) ||
1137
0
      parser.resolveOperand(tagMemRefInfo, type, result.operands) ||
1138
0
      parser.resolveOperands(tagMapOperands, indexType, result.operands) ||
1139
0
      parser.resolveOperand(numElementsInfo, indexType, result.operands))
1140
0
    return failure();
1141
0
1142
0
  if (!type.isa<MemRefType>())
1143
0
    return parser.emitError(parser.getNameLoc(),
1144
0
                            "expected tag to be of memref type");
1145
0
1146
0
  if (tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
1147
0
    return parser.emitError(parser.getNameLoc(),
1148
0
                            "tag memref operand count != to map.numInputs");
1149
0
  return success();
1150
0
}
1151
1152
0
/// Verifies that the tag is a memref and that each tag index is an
/// index-typed dimension or symbol identifier.
LogicalResult AffineDmaWaitOp::verify() {
  if (!getOperand(0).getType().isa<MemRefType>())
    return emitOpError("expected DMA tag to be of memref type");

  Region *scope = getAffineScope(*this);
  for (Value idx : getTagIndices()) {
    if (!idx.getType().isIndex())
      return emitOpError("index to dma_wait must have 'index' type");
    if (!isValidAffineIndexOperand(idx, scope))
      return emitOpError("index must be a dimension or symbol identifier");
  }
  return success();
}
1164
1165
/// Folds dma_wait(memrefcast) -> dma_wait by bypassing ranked memref casts on
/// the operands.
LogicalResult AffineDmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
                                    SmallVectorImpl<OpFoldResult> &results) {
  return foldMemRefCast(getOperation());
}
1170
1171
//===----------------------------------------------------------------------===//
1172
// AffineForOp
1173
//===----------------------------------------------------------------------===//
1174
1175
void AffineForOp::build(OpBuilder &builder, OperationState &result,
1176
                        ValueRange lbOperands, AffineMap lbMap,
1177
0
                        ValueRange ubOperands, AffineMap ubMap, int64_t step) {
1178
0
  assert(((!lbMap && lbOperands.empty()) ||
1179
0
          lbOperands.size() == lbMap.getNumInputs()) &&
1180
0
         "lower bound operand count does not match the affine map");
1181
0
  assert(((!ubMap && ubOperands.empty()) ||
1182
0
          ubOperands.size() == ubMap.getNumInputs()) &&
1183
0
         "upper bound operand count does not match the affine map");
1184
0
  assert(step > 0 && "step has to be a positive integer constant");
1185
0
1186
0
  // Add an attribute for the step.
1187
0
  result.addAttribute(getStepAttrName(),
1188
0
                      builder.getIntegerAttr(builder.getIndexType(), step));
1189
0
1190
0
  // Add the lower bound.
1191
0
  result.addAttribute(getLowerBoundAttrName(), AffineMapAttr::get(lbMap));
1192
0
  result.addOperands(lbOperands);
1193
0
1194
0
  // Add the upper bound.
1195
0
  result.addAttribute(getUpperBoundAttrName(), AffineMapAttr::get(ubMap));
1196
0
  result.addOperands(ubOperands);
1197
0
1198
0
  // Create a region and a block for the body.  The argument of the region is
1199
0
  // the loop induction variable.
1200
0
  Region *bodyRegion = result.addRegion();
1201
0
  Block *body = new Block();
1202
0
  body->addArgument(IndexType::get(builder.getContext()));
1203
0
  bodyRegion->push_back(body);
1204
0
  ensureTerminator(*bodyRegion, builder, result.location);
1205
0
}
1206
1207
/// Builds an affine.for with constant bounds [lb, ub) and the given step by
/// delegating to the map-based builder with operand-free constant maps.
void AffineForOp::build(OpBuilder &builder, OperationState &result, int64_t lb,
                        int64_t ub, int64_t step) {
  MLIRContext *ctx = builder.getContext();
  AffineMap lowerMap = AffineMap::getConstantMap(lb, ctx);
  AffineMap upperMap = AffineMap::getConstantMap(ub, ctx);
  build(builder, result, /*lbOperands=*/{}, lowerMap, /*ubOperands=*/{},
        upperMap, step);
}
1213
1214
0
/// Verifies an affine.for op:
///  - the body defines a single index-typed block argument, the induction
///    variable;
///  - the operand count equals the inputs of the lower and upper bound maps;
///  - each bound's operands are valid dimension/symbol identifiers.
static LogicalResult verify(AffineForOp op) {
  auto *body = op.getBody();
  if (body->getNumArguments() != 1 || !body->getArgument(0).getType().isIndex())
    return op.emitOpError(
        "expected body to have a single index argument for the "
        "induction variable");

  AffineMap lowerBoundMap = op.getLowerBoundMap();
  AffineMap upperBoundMap = op.getUpperBoundMap();
  const unsigned expectedOperands =
      lowerBoundMap.getNumInputs() + upperBoundMap.getNumInputs();
  if (op.getNumOperands() != expectedOperands)
    return op.emitOpError(
        "operand count must match with affine map dimension and symbol count");

  if (failed(verifyDimAndSymbolIdentifiers(op, op.getLowerBoundOperands(),
                                           lowerBoundMap.getNumDims())))
    return failure();
  if (failed(verifyDimAndSymbolIdentifiers(op, op.getUpperBoundOperands(),
                                           upperBoundMap.getNumDims())))
    return failure();
  return success();
}
1242
1243
/// Parse a for operation loop bounds.
1244
static ParseResult parseBound(bool isLower, OperationState &result,
1245
0
                              OpAsmParser &p) {
1246
0
  // 'min' / 'max' prefixes are generally syntactic sugar, but are required if
1247
0
  // the map has multiple results.
1248
0
  bool failedToParsedMinMax =
1249
0
      failed(p.parseOptionalKeyword(isLower ? "max" : "min"));
1250
0
1251
0
  auto &builder = p.getBuilder();
1252
0
  auto boundAttrName = isLower ? AffineForOp::getLowerBoundAttrName()
1253
0
                               : AffineForOp::getUpperBoundAttrName();
1254
0
1255
0
  // Parse ssa-id as identity map.
1256
0
  SmallVector<OpAsmParser::OperandType, 1> boundOpInfos;
1257
0
  if (p.parseOperandList(boundOpInfos))
1258
0
    return failure();
1259
0
1260
0
  if (!boundOpInfos.empty()) {
1261
0
    // Check that only one operand was parsed.
1262
0
    if (boundOpInfos.size() > 1)
1263
0
      return p.emitError(p.getNameLoc(),
1264
0
                         "expected only one loop bound operand");
1265
0
1266
0
    // TODO: improve error message when SSA value is not of index type.
1267
0
    // Currently it is 'use of value ... expects different type than prior uses'
1268
0
    if (p.resolveOperand(boundOpInfos.front(), builder.getIndexType(),
1269
0
                         result.operands))
1270
0
      return failure();
1271
0
1272
0
    // Create an identity map using symbol id. This representation is optimized
1273
0
    // for storage. Analysis passes may expand it into a multi-dimensional map
1274
0
    // if desired.
1275
0
    AffineMap map = builder.getSymbolIdentityMap();
1276
0
    result.addAttribute(boundAttrName, AffineMapAttr::get(map));
1277
0
    return success();
1278
0
  }
1279
0
1280
0
  // Get the attribute location.
1281
0
  llvm::SMLoc attrLoc = p.getCurrentLocation();
1282
0
1283
0
  Attribute boundAttr;
1284
0
  if (p.parseAttribute(boundAttr, builder.getIndexType(), boundAttrName,
1285
0
                       result.attributes))
1286
0
    return failure();
1287
0
1288
0
  // Parse full form - affine map followed by dim and symbol list.
1289
0
  if (auto affineMapAttr = boundAttr.dyn_cast<AffineMapAttr>()) {
1290
0
    unsigned currentNumOperands = result.operands.size();
1291
0
    unsigned numDims;
1292
0
    if (parseDimAndSymbolList(p, result.operands, numDims))
1293
0
      return failure();
1294
0
1295
0
    auto map = affineMapAttr.getValue();
1296
0
    if (map.getNumDims() != numDims)
1297
0
      return p.emitError(
1298
0
          p.getNameLoc(),
1299
0
          "dim operand count and affine map dim count must match");
1300
0
1301
0
    unsigned numDimAndSymbolOperands =
1302
0
        result.operands.size() - currentNumOperands;
1303
0
    if (numDims + map.getNumSymbols() != numDimAndSymbolOperands)
1304
0
      return p.emitError(
1305
0
          p.getNameLoc(),
1306
0
          "symbol operand count and affine map symbol count must match");
1307
0
1308
0
    // If the map has multiple results, make sure that we parsed the min/max
1309
0
    // prefix.
1310
0
    if (map.getNumResults() > 1 && failedToParsedMinMax) {
1311
0
      if (isLower) {
1312
0
        return p.emitError(attrLoc, "lower loop bound affine map with "
1313
0
                                    "multiple results requires 'max' prefix");
1314
0
      }
1315
0
      return p.emitError(attrLoc, "upper loop bound affine map with multiple "
1316
0
                                  "results requires 'min' prefix");
1317
0
    }
1318
0
    return success();
1319
0
  }
1320
0
1321
0
  // Parse custom assembly form.
1322
0
  if (auto integerAttr = boundAttr.dyn_cast<IntegerAttr>()) {
1323
0
    result.attributes.pop_back();
1324
0
    result.addAttribute(
1325
0
        boundAttrName,
1326
0
        AffineMapAttr::get(builder.getConstantAffineMap(integerAttr.getInt())));
1327
0
    return success();
1328
0
  }
1329
0
1330
0
  return p.emitError(
1331
0
      p.getNameLoc(),
1332
0
      "expected valid affine map representation for loop bounds");
1333
0
}
1334
1335
/// Parses an affine.for operation of the form
///   affine.for %iv = <lower-bound> to <upper-bound> [step <int>] <region>
///       [attr-dict]
/// The step defaults to 1 when omitted. An explicit step must be strictly
/// positive.
static ParseResult parseAffineForOp(OpAsmParser &parser,
                                    OperationState &result) {
  auto &builder = parser.getBuilder();
  OpAsmParser::OperandType inductionVariable;
  // Parse the induction variable followed by '='.
  if (parser.parseRegionArgument(inductionVariable) || parser.parseEqual())
    return failure();

  // Parse loop bounds.
  if (parseBound(/*isLower=*/true, result, parser) ||
      parser.parseKeyword("to", " between bounds") ||
      parseBound(/*isLower=*/false, result, parser))
    return failure();

  // Parse the optional loop step, we default to 1 if one is not present.
  if (failed(parser.parseOptionalKeyword("step"))) {
    result.addAttribute(
        AffineForOp::getStepAttrName(),
        builder.getIntegerAttr(builder.getIndexType(), /*value=*/1));
  } else {
    llvm::SMLoc stepLoc = parser.getCurrentLocation();
    IntegerAttr stepAttr;
    if (parser.parseAttribute(stepAttr, builder.getIndexType(),
                              AffineForOp::getStepAttrName().data(),
                              result.attributes))
      return failure();

    // The diagnostic promises a *positive* step. A step of zero would never
    // advance the induction variable, so reject it along with negative
    // values (the previous `< 0` check let `step 0` through).
    if (stepAttr.getValue().getSExtValue() <= 0)
      return parser.emitError(
          stepLoc,
          "expected step to be representable as a positive signed integer");
  }

  // Parse the body region. Its single argument is the index-typed induction
  // variable.
  Region *body = result.addRegion();
  if (parser.parseRegion(*body, inductionVariable, builder.getIndexType()))
    return failure();

  AffineForOp::ensureTerminator(*body, builder, result.location);

  // Parse the optional attribute list.
  return parser.parseOptionalAttrDict(result.attributes);
}
1378
1379
static void printBound(AffineMapAttr boundMap,
1380
                       Operation::operand_range boundOperands,
1381
0
                       const char *prefix, OpAsmPrinter &p) {
1382
0
  AffineMap map = boundMap.getValue();
1383
0
1384
0
  // Check if this bound should be printed using custom assembly form.
1385
0
  // The decision to restrict printing custom assembly form to trivial cases
1386
0
  // comes from the will to roundtrip MLIR binary -> text -> binary in a
1387
0
  // lossless way.
1388
0
  // Therefore, custom assembly form parsing and printing is only supported for
1389
0
  // zero-operand constant maps and single symbol operand identity maps.
1390
0
  if (map.getNumResults() == 1) {
1391
0
    AffineExpr expr = map.getResult(0);
1392
0
1393
0
    // Print constant bound.
1394
0
    if (map.getNumDims() == 0 && map.getNumSymbols() == 0) {
1395
0
      if (auto constExpr = expr.dyn_cast<AffineConstantExpr>()) {
1396
0
        p << constExpr.getValue();
1397
0
        return;
1398
0
      }
1399
0
    }
1400
0
1401
0
    // Print bound that consists of a single SSA symbol if the map is over a
1402
0
    // single symbol.
1403
0
    if (map.getNumDims() == 0 && map.getNumSymbols() == 1) {
1404
0
      if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>()) {
1405
0
        p.printOperand(*boundOperands.begin());
1406
0
        return;
1407
0
      }
1408
0
    }
1409
0
  } else {
1410
0
    // Map has multiple results. Print 'min' or 'max' prefix.
1411
0
    p << prefix << ' ';
1412
0
  }
1413
0
1414
0
  // Print the map and its operands.
1415
0
  p << boundMap;
1416
0
  printDimAndSymbolList(boundOperands.begin(), boundOperands.end(),
1417
0
                        map.getNumDims(), p);
1418
0
}
1419
1420
0
/// Prints an affine.for in its custom form:
///   affine.for %iv = <lb> to <ub> [step <n>] <region> [attr-dict]
/// The bound and step attributes are elided from the attribute dictionary
/// because the custom syntax already encodes them.
static void print(OpAsmPrinter &p, AffineForOp op) {
  Value inductionVar = op.getBody()->getArgument(0);
  p << op.getOperationName() << ' ';
  p.printOperand(inductionVar);

  p << " = ";
  printBound(op.getLowerBoundMapAttr(), op.getLowerBoundOperands(), "max", p);
  p << " to ";
  printBound(op.getUpperBoundMapAttr(), op.getUpperBoundOperands(), "min", p);

  // The unit step is the default and is left implicit.
  auto step = op.getStep();
  if (step != 1)
    p << " step " << step;

  p.printRegion(op.region(), /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/false);
  p.printOptionalAttrDict(op.getAttrs(),
                          /*elidedAttrs=*/{op.getLowerBoundAttrName(),
                                           op.getUpperBoundAttrName(),
                                           op.getStepAttrName()});
}
1438
1439
/// Fold the constant bounds of a loop.
1440
0
static LogicalResult foldLoopBounds(AffineForOp forOp) {
1441
0
  auto foldLowerOrUpperBound = [&forOp](bool lower) {
1442
0
    // Check to see if each of the operands is the result of a constant.  If
1443
0
    // so, get the value.  If not, ignore it.
1444
0
    SmallVector<Attribute, 8> operandConstants;
1445
0
    auto boundOperands =
1446
0
        lower ? forOp.getLowerBoundOperands() : forOp.getUpperBoundOperands();
1447
0
    for (auto operand : boundOperands) {
1448
0
      Attribute operandCst;
1449
0
      matchPattern(operand, m_Constant(&operandCst));
1450
0
      operandConstants.push_back(operandCst);
1451
0
    }
1452
0
1453
0
    AffineMap boundMap =
1454
0
        lower ? forOp.getLowerBoundMap() : forOp.getUpperBoundMap();
1455
0
    assert(boundMap.getNumResults() >= 1 &&
1456
0
           "bound maps should have at least one result");
1457
0
    SmallVector<Attribute, 4> foldedResults;
1458
0
    if (failed(boundMap.constantFold(operandConstants, foldedResults)))
1459
0
      return failure();
1460
0
1461
0
    // Compute the max or min as applicable over the results.
1462
0
    assert(!foldedResults.empty() && "bounds should have at least one result");
1463
0
    auto maxOrMin = foldedResults[0].cast<IntegerAttr>().getValue();
1464
0
    for (unsigned i = 1, e = foldedResults.size(); i < e; i++) {
1465
0
      auto foldedResult = foldedResults[i].cast<IntegerAttr>().getValue();
1466
0
      maxOrMin = lower ? llvm::APIntOps::smax(maxOrMin, foldedResult)
1467
0
                       : llvm::APIntOps::smin(maxOrMin, foldedResult);
1468
0
    }
1469
0
    lower ? forOp.setConstantLowerBound(maxOrMin.getSExtValue())
1470
0
          : forOp.setConstantUpperBound(maxOrMin.getSExtValue());
1471
0
    return success();
1472
0
  };
1473
0
1474
0
  // Try to fold the lower bound.
1475
0
  bool folded = false;
1476
0
  if (!forOp.hasConstantLowerBound())
1477
0
    folded |= succeeded(foldLowerOrUpperBound(/*lower=*/true));
1478
0
1479
0
  // Try to fold the upper bound.
1480
0
  if (!forOp.hasConstantUpperBound())
1481
0
    folded |= succeeded(foldLowerOrUpperBound(/*lower=*/false));
1482
0
  return success(folded);
1483
0
}
1484
1485
/// Canonicalize the bounds of the given loop.
1486
0
static LogicalResult canonicalizeLoopBounds(AffineForOp forOp) {
1487
0
  SmallVector<Value, 4> lbOperands(forOp.getLowerBoundOperands());
1488
0
  SmallVector<Value, 4> ubOperands(forOp.getUpperBoundOperands());
1489
0
1490
0
  auto lbMap = forOp.getLowerBoundMap();
1491
0
  auto ubMap = forOp.getUpperBoundMap();
1492
0
  auto prevLbMap = lbMap;
1493
0
  auto prevUbMap = ubMap;
1494
0
1495
0
  canonicalizeMapAndOperands(&lbMap, &lbOperands);
1496
0
  lbMap = removeDuplicateExprs(lbMap);
1497
0
1498
0
  canonicalizeMapAndOperands(&ubMap, &ubOperands);
1499
0
  ubMap = removeDuplicateExprs(ubMap);
1500
0
1501
0
  // Any canonicalization change always leads to updated map(s).
1502
0
  if (lbMap == prevLbMap && ubMap == prevUbMap)
1503
0
    return failure();
1504
0
1505
0
  if (lbMap != prevLbMap)
1506
0
    forOp.setLowerBound(lbOperands, lbMap);
1507
0
  if (ubMap != prevUbMap)
1508
0
    forOp.setUpperBound(ubOperands, ubMap);
1509
0
  return success();
1510
0
}
1511
1512
namespace {
1513
/// This is a pattern to fold trivially empty loops.
1514
struct AffineForEmptyLoopFolder : public OpRewritePattern<AffineForOp> {
1515
  using OpRewritePattern<AffineForOp>::OpRewritePattern;
1516
1517
  LogicalResult matchAndRewrite(AffineForOp forOp,
1518
0
                                PatternRewriter &rewriter) const override {
1519
0
    // Check that the body only contains a terminator.
1520
0
    if (!llvm::hasSingleElement(*forOp.getBody()))
1521
0
      return failure();
1522
0
    rewriter.eraseOp(forOp);
1523
0
    return success();
1524
0
  }
1525
};
1526
} // end anonymous namespace
1527
1528
void AffineForOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
1529
0
                                              MLIRContext *context) {
1530
0
  results.insert<AffineForEmptyLoopFolder>(context);
1531
0
}
1532
1533
/// In-place fold of affine.for. It first folds constant bounds, then
/// canonicalizes the bound maps. Both steps always run, and the fold succeeds
/// if either changed the op.
LogicalResult AffineForOp::fold(ArrayRef<Attribute> operands,
                                SmallVectorImpl<OpFoldResult> &results) {
  bool boundsFolded = succeeded(foldLoopBounds(*this));
  bool boundsCanonicalized = succeeded(canonicalizeLoopBounds(*this));
  return success(boundsFolded || boundsCanonicalized);
}
1539
1540
0
/// Returns the lower bound. It covers the leading `getNumInputs()` operands
/// and uses the lower bound map.
AffineBound AffineForOp::getLowerBound() {
  AffineMap lbMap = getLowerBoundMap();
  unsigned numLbOperands = lbMap.getNumInputs();
  return AffineBound(AffineForOp(*this), 0, numLbOperands, lbMap);
}
1544
1545
0
/// Returns the upper bound. It covers the operands that follow the
/// lower-bound operands, through the end of the operand list, and uses the
/// upper bound map.
AffineBound AffineForOp::getUpperBound() {
  unsigned ubStart = getLowerBoundMap().getNumInputs();
  AffineMap ubMap = getUpperBoundMap();
  return AffineBound(AffineForOp(*this), ubStart, getNumOperands(), ubMap);
}
1551
1552
0
/// Replaces the lower bound with `map` applied to `lbOperands`. The current
/// upper-bound operands are kept and placed after the new lower-bound ones.
void AffineForOp::setLowerBound(ValueRange lbOperands, AffineMap map) {
  assert(lbOperands.size() == map.getNumInputs());
  assert(map.getNumResults() >= 1 && "bound map has at least one result");

  // Build the full operand list (new lb operands + existing ub operands)
  // before touching the op. The ub range is computed from the current
  // lower-bound map.
  SmallVector<Value, 4> allOperands(lbOperands.begin(), lbOperands.end());
  auto currentUbOperands = getUpperBoundOperands();
  allOperands.insert(allOperands.end(), currentUbOperands.begin(),
                     currentUbOperands.end());
  getOperation()->setOperands(allOperands);

  setAttr(getLowerBoundAttrName(), AffineMapAttr::get(map));
}
1564
1565
0
/// Replaces the upper bound with `map` applied to `ubOperands`. The current
/// lower-bound operands are kept in front of the new upper-bound ones.
void AffineForOp::setUpperBound(ValueRange ubOperands, AffineMap map) {
  assert(ubOperands.size() == map.getNumInputs());
  assert(map.getNumResults() >= 1 && "bound map has at least one result");

  SmallVector<Value, 4> allOperands(getLowerBoundOperands());
  allOperands.insert(allOperands.end(), ubOperands.begin(), ubOperands.end());
  getOperation()->setOperands(allOperands);

  setAttr(getUpperBoundAttrName(), AffineMapAttr::get(map));
}
1575
1576
0
/// Swaps in a new lower bound map. The operands are unchanged, so `map` must
/// have the same dim and symbol counts as the current map.
void AffineForOp::setLowerBoundMap(AffineMap map) {
  AffineMap currentMap = getLowerBoundMap();
  assert(currentMap.getNumDims() == map.getNumDims() &&
         currentMap.getNumSymbols() == map.getNumSymbols());
  assert(map.getNumResults() >= 1 && "bound map has at least one result");
  (void)currentMap; // Only used by the asserts.
  setAttr(getLowerBoundAttrName(), AffineMapAttr::get(map));
}
1584
1585
0
/// Swaps in a new upper bound map. The operands are unchanged, so `map` must
/// have the same dim and symbol counts as the current map.
void AffineForOp::setUpperBoundMap(AffineMap map) {
  AffineMap currentMap = getUpperBoundMap();
  assert(currentMap.getNumDims() == map.getNumDims() &&
         currentMap.getNumSymbols() == map.getNumSymbols());
  assert(map.getNumResults() >= 1 && "bound map has at least one result");
  (void)currentMap; // Only used by the asserts.
  setAttr(getUpperBoundAttrName(), AffineMapAttr::get(map));
}
1593
1594
0
bool AffineForOp::hasConstantLowerBound() {
1595
0
  return getLowerBoundMap().isSingleConstant();
1596
0
}
1597
1598
0
bool AffineForOp::hasConstantUpperBound() {
1599
0
  return getUpperBoundMap().isSingleConstant();
1600
0
}
1601
1602
0
/// Returns the constant lower bound. Only meaningful when
/// hasConstantLowerBound() holds.
int64_t AffineForOp::getConstantLowerBound() {
  AffineMap lbMap = getLowerBoundMap();
  return lbMap.getSingleConstantResult();
}
1605
1606
0
/// Returns the constant upper bound. Only meaningful when
/// hasConstantUpperBound() holds.
int64_t AffineForOp::getConstantUpperBound() {
  AffineMap ubMap = getUpperBoundMap();
  return ubMap.getSingleConstantResult();
}
1609
1610
0
/// Makes the lower bound the constant `value`, dropping its operands.
void AffineForOp::setConstantLowerBound(int64_t value) {
  AffineMap constantMap = AffineMap::getConstantMap(value, getContext());
  setLowerBound({}, constantMap);
}
1613
1614
0
/// Makes the upper bound the constant `value`, dropping its operands.
void AffineForOp::setConstantUpperBound(int64_t value) {
  AffineMap constantMap = AffineMap::getConstantMap(value, getContext());
  setUpperBound({}, constantMap);
}
1617
1618
0
/// The lower-bound operands are the first `getNumInputs()` operands of the
/// lower bound map.
AffineForOp::operand_range AffineForOp::getLowerBoundOperands() {
  auto first = operand_begin();
  unsigned numLbOperands = getLowerBoundMap().getNumInputs();
  return {first, first + numLbOperands};
}
1621
1622
0
/// The upper-bound operands are all operands after the lower-bound ones.
AffineForOp::operand_range AffineForOp::getUpperBoundOperands() {
  unsigned numLbOperands = getLowerBoundMap().getNumInputs();
  return {operand_begin() + numLbOperands, operand_end()};
}
1625
1626
0
bool AffineForOp::matchingBoundOperandList() {
1627
0
  auto lbMap = getLowerBoundMap();
1628
0
  auto ubMap = getUpperBoundMap();
1629
0
  if (lbMap.getNumDims() != ubMap.getNumDims() ||
1630
0
      lbMap.getNumSymbols() != ubMap.getNumSymbols())
1631
0
    return false;
1632
0
1633
0
  unsigned numOperands = lbMap.getNumInputs();
1634
0
  for (unsigned i = 0, e = lbMap.getNumInputs(); i < e; i++) {
1635
0
    // Compare Value 's.
1636
0
    if (getOperand(i) != getOperand(numOperands + i))
1637
0
      return false;
1638
0
  }
1639
0
  return true;
1640
0
}
1641
1642
0
/// Returns the region that holds the loop body.
Region &AffineForOp::getLoopBody() {
  Region &body = region();
  return body;
}
1643
1644
0
/// Returns true when the loop's region is not an ancestor of the region in
/// which `value` is defined.
bool AffineForOp::isDefinedOutsideOfLoop(Value value) {
  Region *definingRegion = value.getParentRegion();
  bool insideLoop = region().isAncestor(definingRegion);
  return !insideLoop;
}
1647
1648
0
/// Moves every op in `ops`, in order, to immediately before this loop.
/// Always succeeds.
LogicalResult AffineForOp::moveOutOfLoop(ArrayRef<Operation *> ops) {
  for (Operation *hoisted : ops)
    hoisted->moveBefore(*this);
  return success();
}
1653
1654
/// Returns if the provided value is the induction variable of a AffineForOp.
1655
0
bool mlir::isForInductionVar(Value val) {
1656
0
  return getForInductionVarOwner(val) != AffineForOp();
1657
0
}
1658
1659
/// Returns the loop parent of an induction variable. If the provided value is
1660
/// not an induction variable, then return nullptr.
1661
0
AffineForOp mlir::getForInductionVarOwner(Value val) {
1662
0
  auto ivArg = val.dyn_cast<BlockArgument>();
1663
0
  if (!ivArg || !ivArg.getOwner())
1664
0
    return AffineForOp();
1665
0
  auto *containingInst = ivArg.getOwner()->getParent()->getParentOp();
1666
0
  return dyn_cast<AffineForOp>(containingInst);
1667
0
}
1668
1669
/// Extracts the induction variables from a list of AffineForOps and returns
1670
/// them.
1671
void mlir::extractForInductionVars(ArrayRef<AffineForOp> forInsts,
1672
0
                                   SmallVectorImpl<Value> *ivs) {
1673
0
  ivs->reserve(forInsts.size());
1674
0
  for (auto forInst : forInsts)
1675
0
    ivs->push_back(forInst.getInductionVar());
1676
0
}
1677
1678
//===----------------------------------------------------------------------===//
1679
// AffineIfOp
1680
//===----------------------------------------------------------------------===//
1681
1682
namespace {
1683
/// Remove else blocks that have nothing other than the terminator.
1684
struct SimplifyDeadElse : public OpRewritePattern<AffineIfOp> {
1685
  using OpRewritePattern<AffineIfOp>::OpRewritePattern;
1686
1687
  LogicalResult matchAndRewrite(AffineIfOp ifOp,
1688
0
                                PatternRewriter &rewriter) const override {
1689
0
    if (ifOp.elseRegion().empty() ||
1690
0
        !llvm::hasSingleElement(*ifOp.getElseBlock()))
1691
0
      return failure();
1692
0
1693
0
    rewriter.startRootUpdate(ifOp);
1694
0
    rewriter.eraseBlock(ifOp.getElseBlock());
1695
0
    rewriter.finalizeRootUpdate(ifOp);
1696
0
    return success();
1697
0
  }
1698
};
1699
} // end anonymous namespace.
1700
1701
0
/// Verifies an affine.if:
///  - an IntegerSetAttr named by getConditionAttrName() is present;
///  - the operand count equals the set's number of inputs;
///  - every operand is a valid affine dimension/symbol identifier;
///  - no block in any of the op's regions has arguments.
static LogicalResult verify(AffineIfOp op) {
  auto conditionAttr =
      op.getAttrOfType<IntegerSetAttr>(op.getConditionAttrName());
  if (!conditionAttr)
    return op.emitOpError(
        "requires an integer set attribute named 'condition'");

  IntegerSet condition = conditionAttr.getValue();
  if (op.getNumOperands() != condition.getNumInputs())
    return op.emitOpError("operand count and condition integer set dimension and "
                          "symbol count must match");

  if (failed(verifyDimAndSymbolIdentifiers(op, op.getOperands(),
                                           condition.getNumDims())))
    return failure();

  // Note: every block is checked, not just the entry block of each region.
  for (Region &region : op.getOperation()->getRegions())
    for (Block &block : region)
      if (block.getNumArguments() != 0)
        return op.emitOpError(
            "requires that child entry blocks have no arguments");
  return success();
}
1730
1731
/// Parses an affine.if of the form
///   affine.if <integer-set>(dims)[symbols] <then-region>
///       [else <else-region>] [attr-dict]
/// Both regions are always created. The else region stays empty when no
/// 'else' keyword is present.
static ParseResult parseAffineIfOp(OpAsmParser &parser,
                                   OperationState &result) {
  // The condition: an integer set attribute and its dim/symbol operands.
  IntegerSetAttr conditionAttr;
  unsigned numDims;
  if (parser.parseAttribute(conditionAttr, AffineIfOp::getConditionAttrName(),
                            result.attributes) ||
      parseDimAndSymbolList(parser, result.operands, numDims))
    return failure();

  IntegerSet set = conditionAttr.getValue();
  if (set.getNumDims() != numDims)
    return parser.emitError(
        parser.getNameLoc(),
        "dim operand count and integer set dim count must match");
  if (numDims + set.getNumSymbols() != result.operands.size())
    return parser.emitError(
        parser.getNameLoc(),
        "symbol operand count and integer set symbol count must match");

  // The else region must exist even if it remains empty, for the op to be
  // valid.
  result.regions.reserve(2);
  Region *thenRegion = result.addRegion();
  Region *elseRegion = result.addRegion();
  auto &builder = parser.getBuilder();

  if (parser.parseRegion(*thenRegion, {}, {}))
    return failure();
  AffineIfOp::ensureTerminator(*thenRegion, builder, result.location);

  bool hasElse = succeeded(parser.parseOptionalKeyword("else"));
  if (hasElse) {
    if (parser.parseRegion(*elseRegion, {}, {}))
      return failure();
    AffineIfOp::ensureTerminator(*elseRegion, builder, result.location);
  }

  return parser.parseOptionalAttrDict(result.attributes);
}
1778
1779
0
/// Prints an affine.if in its custom form. The else region is printed only
/// when it has blocks, and the condition attribute is elided from the
/// attribute dictionary.
static void print(OpAsmPrinter &p, AffineIfOp op) {
  auto conditionAttr =
      op.getAttrOfType<IntegerSetAttr>(op.getConditionAttrName());
  unsigned numDims = conditionAttr.getValue().getNumDims();

  p << "affine.if " << conditionAttr;
  printDimAndSymbolList(op.operand_begin(), op.operand_end(), numDims, p);
  p.printRegion(op.thenRegion(), /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/false);

  Region &elseRegion = op.elseRegion();
  if (!elseRegion.empty()) {
    p << " else";
    p.printRegion(elseRegion, /*printEntryBlockArgs=*/false,
                  /*printBlockTerminators=*/false);
  }

  p.printOptionalAttrDict(op.getAttrs(),
                          /*elidedAttrs=*/op.getConditionAttrName());
}
1802
1803
0
/// Returns the integer set stored in the condition attribute.
IntegerSet AffineIfOp::getIntegerSet() {
  auto conditionAttr = getAttrOfType<IntegerSetAttr>(getConditionAttrName());
  return conditionAttr.getValue();
}
1806
0
/// Overwrites the condition attribute with `newSet`. Operands are untouched.
void AffineIfOp::setIntegerSet(IntegerSet newSet) {
  auto conditionAttr = IntegerSetAttr::get(newSet);
  setAttr(getConditionAttrName(), conditionAttr);
}
1809
1810
0
/// Replaces both the condition set and the operands it is applied to.
void AffineIfOp::setConditional(IntegerSet set, ValueRange operands) {
  setIntegerSet(set);
  Operation *op = getOperation();
  op->setOperands(operands);
}
1814
1815
void AffineIfOp::build(OpBuilder &builder, OperationState &result,
1816
0
                       IntegerSet set, ValueRange args, bool withElseRegion) {
1817
0
  result.addOperands(args);
1818
0
  result.addAttribute(getConditionAttrName(), IntegerSetAttr::get(set));
1819
0
  Region *thenRegion = result.addRegion();
1820
0
  Region *elseRegion = result.addRegion();
1821
0
  AffineIfOp::ensureTerminator(*thenRegion, builder, result.location);
1822
0
  if (withElseRegion)
1823
0
    AffineIfOp::ensureTerminator(*elseRegion, builder, result.location);
1824
0
}
1825
1826
/// Canonicalize an affine if op's conditional (integer set + operands).
1827
LogicalResult AffineIfOp::fold(ArrayRef<Attribute>,
1828
0
                               SmallVectorImpl<OpFoldResult> &) {
1829
0
  auto set = getIntegerSet();
1830
0
  SmallVector<Value, 4> operands(getOperands());
1831
0
  canonicalizeSetAndOperands(&set, &operands);
1832
0
1833
0
  // Any canonicalization change always leads to either a reduction in the
1834
0
  // number of operands or a change in the number of symbolic operands
1835
0
  // (promotion of dims to symbols).
1836
0
  if (operands.size() < getIntegerSet().getNumInputs() ||
1837
0
      set.getNumSymbols() > getIntegerSet().getNumSymbols()) {
1838
0
    setConditional(set, operands);
1839
0
    return success();
1840
0
  }
1841
0
1842
0
  return failure();
1843
0
}
1844
1845
void AffineIfOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
1846
0
                                             MLIRContext *context) {
1847
0
  results.insert<SimplifyDeadElse>(context);
1848
0
}
1849
1850
//===----------------------------------------------------------------------===//
1851
// AffineLoadOp
1852
//===----------------------------------------------------------------------===//
1853
1854
void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
1855
0
                         AffineMap map, ValueRange operands) {
1856
0
  assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands");
1857
0
  result.addOperands(operands);
1858
0
  if (map)
1859
0
    result.addAttribute(getMapAttrName(), AffineMapAttr::get(map));
1860
0
  auto memrefType = operands[0].getType().cast<MemRefType>();
1861
0
  result.types.push_back(memrefType.getElementType());
1862
0
}
1863
1864
void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
1865
0
                         Value memref, AffineMap map, ValueRange mapOperands) {
1866
0
  assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
1867
0
  result.addOperands(memref);
1868
0
  result.addOperands(mapOperands);
1869
0
  auto memrefType = memref.getType().cast<MemRefType>();
1870
0
  result.addAttribute(getMapAttrName(), AffineMapAttr::get(map));
1871
0
  result.types.push_back(memrefType.getElementType());
1872
0
}
1873
1874
void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
1875
0
                         Value memref, ValueRange indices) {
1876
0
  auto memrefType = memref.getType().cast<MemRefType>();
1877
0
  auto rank = memrefType.getRank();
1878
0
  // Create identity map for memrefs with at least one dimension or () -> ()
1879
0
  // for zero-dimensional memrefs.
1880
0
  auto map =
1881
0
      rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
1882
0
  build(builder, result, memref, map, indices);
1883
0
}
1884
1885
/// Parses `affine.load %memref[<affine-map-of-ssa-ids>] [attr-dict] : type`.
/// The map operands are resolved as index values, and the result type is the
/// memref's element type.
static ParseResult parseAffineLoadOp(OpAsmParser &parser,
                                     OperationState &result) {
  auto indexTy = parser.getBuilder().getIndexType();

  OpAsmParser::OperandType memrefInfo;
  AffineMapAttr mapAttr;
  SmallVector<OpAsmParser::OperandType, 1> mapOperands;
  MemRefType type;

  bool parseFailed =
      parser.parseOperand(memrefInfo) ||
      parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
                                    AffineLoadOp::getMapAttrName(),
                                    result.attributes) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(type) ||
      parser.resolveOperand(memrefInfo, type, result.operands) ||
      parser.resolveOperands(mapOperands, indexTy, result.operands) ||
      parser.addTypeToList(type.getElementType(), result.types);
  return failure(parseFailed);
}
1905
1906
0
/// Prints `affine.load %memref[<map of operands>] [attr-dict] : type`.
/// The map attribute is elided from the dictionary because it is printed
/// inside the brackets.
static void print(OpAsmPrinter &p, AffineLoadOp op) {
  p << "affine.load " << op.getMemRef() << '[';
  auto mapAttr = op.getAttrOfType<AffineMapAttr>(op.getMapAttrName());
  if (mapAttr)
    p.printAffineMapOfSSAIds(mapAttr, op.getMapOperands());
  p << ']';

  p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{op.getMapAttrName()});
  p << " : " << op.getMemRefType();
}
1915
1916
/// Verify common indexing invariants of affine.load, affine.store,
1917
/// affine.vector_load and affine.vector_store.
1918
static LogicalResult
1919
verifyMemoryOpIndexing(Operation *op, AffineMapAttr mapAttr,
1920
                       Operation::operand_range mapOperands,
1921
0
                       MemRefType memrefType, unsigned numIndexOperands) {
1922
0
  if (mapAttr) {
1923
0
    AffineMap map = mapAttr.getValue();
1924
0
    if (map.getNumResults() != memrefType.getRank())
1925
0
      return op->emitOpError("affine map num results must equal memref rank");
1926
0
    if (map.getNumInputs() != numIndexOperands)
1927
0
      return op->emitOpError("expects as many subscripts as affine map inputs");
1928
0
  } else {
1929
0
    if (memrefType.getRank() != numIndexOperands)
1930
0
      return op->emitOpError(
1931
0
          "expects the number of subscripts to be equal to memref rank");
1932
0
  }
1933
0
1934
0
  Region *scope = getAffineScope(op);
1935
0
  for (auto idx : mapOperands) {
1936
0
    if (!idx.getType().isIndex())
1937
0
      return op->emitOpError("index to load must have 'index' type");
1938
0
    if (!isValidAffineIndexOperand(idx, scope))
1939
0
      return op->emitOpError("index must be a dimension or symbol identifier");
1940
0
  }
1941
0
1942
0
  return success();
1943
0
}
1944
1945
0
/// Verifies an affine.load. The result type must be the memref's element
/// type, and the index operands (all operands after the memref) must satisfy
/// the shared memory-op indexing checks.
LogicalResult verify(AffineLoadOp op) {
  auto memrefType = op.getMemRefType();
  if (op.getType() != memrefType.getElementType())
    return op.emitOpError("result type must match element type of memref");

  unsigned numIndexOperands = op.getNumOperands() - 1;
  auto mapAttr = op.getAttrOfType<AffineMapAttr>(op.getMapAttrName());
  return verifyMemoryOpIndexing(op.getOperation(), mapAttr, op.getMapOperands(),
                                memrefType, numIndexOperands);
}
1959
1960
void AffineLoadOp::getCanonicalizationPatterns(
1961
0
    OwningRewritePatternList &results, MLIRContext *context) {
1962
0
  results.insert<SimplifyAffineOp<AffineLoadOp>>(context);
1963
0
}
1964
1965
0
/// Folds load(memref_cast) -> load. When foldMemRefCast succeeds, the op's own
/// result is returned, which presumably signals an in-place fold (TODO
/// confirm); otherwise a null result signals no fold.
OpFoldResult AffineLoadOp::fold(ArrayRef<Attribute> cstOperands) {
  if (failed(foldMemRefCast(*this)))
    return OpFoldResult();
  return getResult();
}
1971
1972
//===----------------------------------------------------------------------===//
1973
// AffineStoreOp
1974
//===----------------------------------------------------------------------===//
1975
1976
void AffineStoreOp::build(OpBuilder &builder, OperationState &result,
1977
                          Value valueToStore, Value memref, AffineMap map,
1978
0
                          ValueRange mapOperands) {
1979
0
  assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
1980
0
  result.addOperands(valueToStore);
1981
0
  result.addOperands(memref);
1982
0
  result.addOperands(mapOperands);
1983
0
  result.addAttribute(getMapAttrName(), AffineMapAttr::get(map));
1984
0
}
1985
1986
// Use identity map.
1987
void AffineStoreOp::build(OpBuilder &builder, OperationState &result,
1988
                          Value valueToStore, Value memref,
1989
0
                          ValueRange indices) {
1990
0
  auto memrefType = memref.getType().cast<MemRefType>();
1991
0
  auto rank = memrefType.getRank();
1992
0
  // Create identity map for memrefs with at least one dimension or () -> ()
1993
0
  // for zero-dimensional memrefs.
1994
0
  auto map =
1995
0
      rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
1996
0
  build(builder, result, valueToStore, memref, map, indices);
1997
0
}
1998
1999
/// Parses `affine.store %value, %memref[<affine map of SSA ids>] attr-dict
/// : memref-type`. The stored value is resolved to the memref element type,
/// and every subscript operand is resolved to `index`.
static ParseResult parseAffineStoreOp(OpAsmParser &parser,
                                      OperationState &result) {
  OpAsmParser::OperandType valueOperand;
  OpAsmParser::OperandType memrefOperand;
  SmallVector<OpAsmParser::OperandType, 1> subscripts;
  AffineMapAttr mapAttr;
  MemRefType memrefType;

  if (parser.parseOperand(valueOperand) || parser.parseComma() ||
      parser.parseOperand(memrefOperand))
    return failure();
  if (parser.parseAffineMapOfSSAIds(subscripts, mapAttr,
                                    AffineStoreOp::getMapAttrName(),
                                    result.attributes))
    return failure();
  if (parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(memrefType))
    return failure();

  Type indexType = parser.getBuilder().getIndexType();
  if (parser.resolveOperand(valueOperand, memrefType.getElementType(),
                            result.operands) ||
      parser.resolveOperand(memrefOperand, memrefType, result.operands) ||
      parser.resolveOperands(subscripts, indexType, result.operands))
    return failure();
  return success();
}
2020
2021
0
/// Prints an affine.store in the custom form accepted by parseAffineStoreOp.
static void print(OpAsmPrinter &p, AffineStoreOp op) {
  StringRef mapAttrName = op.getMapAttrName();
  p << "affine.store " << op.getValueToStore() << ", " << op.getMemRef()
    << '[';
  if (auto mapAttr = op.getAttrOfType<AffineMapAttr>(mapAttrName))
    p.printAffineMapOfSSAIds(mapAttr, op.getMapOperands());
  p << ']';
  // The map is already printed inside the brackets; keep it out of attr-dict.
  p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{mapAttrName});
  p << " : " << op.getMemRefType();
}
2031
2032
0
/// Verifies an affine.store: the stored value's type must equal the memref
/// element type, and the subscripts must pass the shared indexing checks.
LogicalResult verify(AffineStoreOp op) {
  MemRefType memrefType = op.getMemRefType();
  Type storedType = op.getValueToStore().getType();
  if (storedType != memrefType.getElementType())
    return op.emitOpError(
        "first operand must have same type memref element type");

  // Two operands (the stored value and the memref) are not subscripts.
  unsigned numSubscripts = op.getNumOperands() - 2;
  AffineMapAttr mapAttr = op.getAttrOfType<AffineMapAttr>(op.getMapAttrName());
  if (failed(verifyMemoryOpIndexing(op.getOperation(), mapAttr,
                                    op.getMapOperands(), memrefType,
                                    numSubscripts)))
    return failure();
  return success();
}
2048
2049
void AffineStoreOp::getCanonicalizationPatterns(
2050
0
    OwningRewritePatternList &results, MLIRContext *context) {
2051
0
  results.insert<SimplifyAffineOp<AffineStoreOp>>(context);
2052
0
}
2053
2054
/// Folds a cast feeding the memref operand into this store
/// (store(memrefcast) -> store). `results` is never populated here; the
/// returned status is whatever foldMemRefCast reports.
LogicalResult AffineStoreOp::fold(ArrayRef<Attribute> cstOperands,
                                  SmallVectorImpl<OpFoldResult> &results) {
  LogicalResult castFolded = foldMemRefCast(*this);
  return castFolded;
}
2059
2060
//===----------------------------------------------------------------------===//
2061
// AffineMinMaxOpBase
2062
//===----------------------------------------------------------------------===//
2063
2064
template <typename T>
2065
0
static LogicalResult verifyAffineMinMaxOp(T op) {
2066
0
  // Verify that operand count matches affine map dimension and symbol count.
2067
0
  if (op.getNumOperands() != op.map().getNumDims() + op.map().getNumSymbols())
2068
0
    return op.emitOpError(
2069
0
        "operand count and affine map dimension and symbol count must match");
2070
0
  return success();
2071
0
}
Unexecuted instantiation: AffineOps.cpp:_ZL20verifyAffineMinMaxOpIN4mlir11AffineMaxOpEENS0_13LogicalResultET_
Unexecuted instantiation: AffineOps.cpp:_ZL20verifyAffineMinMaxOpIN4mlir11AffineMinOpEENS0_13LogicalResultET_
2072
2073
template <typename T>
2074
0
static void printAffineMinMaxOp(OpAsmPrinter &p, T op) {
2075
0
  p << op.getOperationName() << ' ' << op.getAttr(T::getMapAttrName());
2076
0
  auto operands = op.getOperands();
2077
0
  unsigned numDims = op.map().getNumDims();
2078
0
  p << '(' << operands.take_front(numDims) << ')';
2079
0
2080
0
  if (operands.size() != numDims)
2081
0
    p << '[' << operands.drop_front(numDims) << ']';
2082
0
  p.printOptionalAttrDict(op.getAttrs(),
2083
0
                          /*elidedAttrs=*/{T::getMapAttrName()});
2084
0
}
Unexecuted instantiation: AffineOps.cpp:_ZL19printAffineMinMaxOpIN4mlir11AffineMaxOpEEvRNS0_12OpAsmPrinterET_
Unexecuted instantiation: AffineOps.cpp:_ZL19printAffineMinMaxOpIN4mlir11AffineMinOpEEvRNS0_12OpAsmPrinterET_
2085
2086
template <typename T>
2087
static ParseResult parseAffineMinMaxOp(OpAsmParser &parser,
2088
0
                                       OperationState &result) {
2089
0
  auto &builder = parser.getBuilder();
2090
0
  auto indexType = builder.getIndexType();
2091
0
  SmallVector<OpAsmParser::OperandType, 8> dim_infos;
2092
0
  SmallVector<OpAsmParser::OperandType, 8> sym_infos;
2093
0
  AffineMapAttr mapAttr;
2094
0
  return failure(
2095
0
      parser.parseAttribute(mapAttr, T::getMapAttrName(), result.attributes) ||
2096
0
      parser.parseOperandList(dim_infos, OpAsmParser::Delimiter::Paren) ||
2097
0
      parser.parseOperandList(sym_infos,
2098
0
                              OpAsmParser::Delimiter::OptionalSquare) ||
2099
0
      parser.parseOptionalAttrDict(result.attributes) ||
2100
0
      parser.resolveOperands(dim_infos, indexType, result.operands) ||
2101
0
      parser.resolveOperands(sym_infos, indexType, result.operands) ||
2102
0
      parser.addTypeToList(indexType, result.types));
2103
0
}
Unexecuted instantiation: AffineOps.cpp:_ZL19parseAffineMinMaxOpIN4mlir11AffineMaxOpEENS0_11ParseResultERNS0_11OpAsmParserERNS0_14OperationStateE
Unexecuted instantiation: AffineOps.cpp:_ZL19parseAffineMinMaxOpIN4mlir11AffineMinOpEENS0_11ParseResultERNS0_11OpAsmParserERNS0_14OperationStateE
2104
2105
/// Fold an affine min or max operation with the given operands. The operand
2106
/// list may contain nulls, which are interpreted as the operand not being a
2107
/// constant.
2108
template <typename T>
2109
0
static OpFoldResult foldMinMaxOp(T op, ArrayRef<Attribute> operands) {
2110
0
  static_assert(llvm::is_one_of<T, AffineMinOp, AffineMaxOp>::value,
2111
0
                "expected affine min or max op");
2112
0
2113
0
  // Fold the affine map.
2114
0
  // TODO(andydavis, ntv) Fold more cases:
2115
0
  // min(some_affine, some_affine + constant, ...), etc.
2116
0
  SmallVector<int64_t, 2> results;
2117
0
  auto foldedMap = op.map().partialConstantFold(operands, &results);
2118
0
2119
0
  // If some of the map results are not constant, try changing the map in-place.
2120
0
  if (results.empty()) {
2121
0
    // If the map is the same, report that folding did not happen.
2122
0
    if (foldedMap == op.map())
2123
0
      return {};
2124
0
    op.setAttr("map", AffineMapAttr::get(foldedMap));
2125
0
    return op.getResult();
2126
0
  }
2127
0
2128
0
  // Otherwise, completely fold the op into a constant.
2129
0
  auto resultIt = std::is_same<T, AffineMinOp>::value
2130
0
                      ? std::min_element(results.begin(), results.end())
2131
0
                      : std::max_element(results.begin(), results.end());
2132
0
  if (resultIt == results.end())
2133
0
    return {};
2134
0
  return IntegerAttr::get(IndexType::get(op.getContext()), *resultIt);
2135
0
}
Unexecuted instantiation: AffineOps.cpp:_ZL12foldMinMaxOpIN4mlir11AffineMinOpEENS0_12OpFoldResultET_N4llvm8ArrayRefINS0_9AttributeEEE
Unexecuted instantiation: AffineOps.cpp:_ZL12foldMinMaxOpIN4mlir11AffineMaxOpEENS0_12OpFoldResultET_N4llvm8ArrayRefINS0_9AttributeEEE
2136
2137
//===----------------------------------------------------------------------===//
2138
// AffineMinOp
2139
//===----------------------------------------------------------------------===//
2140
//
2141
//   %0 = affine.min (d0) -> (1000, d0 + 512) (%i0)
2142
//
2143
2144
0
/// Folds affine.min through the shared min/max folding helper.
OpFoldResult AffineMinOp::fold(ArrayRef<Attribute> operands) {
  return foldMinMaxOp<AffineMinOp>(*this, operands);
}
2147
2148
void AffineMinOp::getCanonicalizationPatterns(
2149
0
    OwningRewritePatternList &patterns, MLIRContext *context) {
2150
0
  patterns.insert<SimplifyAffineOp<AffineMinOp>>(context);
2151
0
}
2152
2153
//===----------------------------------------------------------------------===//
2154
// AffineMaxOp
2155
//===----------------------------------------------------------------------===//
2156
//
2157
//   %0 = affine.max (d0) -> (1000, d0 + 512) (%i0)
2158
//
2159
2160
0
/// Folds affine.max through the shared min/max folding helper.
OpFoldResult AffineMaxOp::fold(ArrayRef<Attribute> operands) {
  return foldMinMaxOp<AffineMaxOp>(*this, operands);
}
2163
2164
void AffineMaxOp::getCanonicalizationPatterns(
2165
0
    OwningRewritePatternList &patterns, MLIRContext *context) {
2166
0
  patterns.insert<SimplifyAffineOp<AffineMaxOp>>(context);
2167
0
}
2168
2169
//===----------------------------------------------------------------------===//
2170
// AffinePrefetchOp
2171
//===----------------------------------------------------------------------===//
2172
2173
//
2174
// affine.prefetch %0[%i, %j + 5], read, locality<3>, data : memref<400x400xi32>
2175
//
2176
static ParseResult parseAffinePrefetchOp(OpAsmParser &parser,
2177
0
                                         OperationState &result) {
2178
0
  auto &builder = parser.getBuilder();
2179
0
  auto indexTy = builder.getIndexType();
2180
0
2181
0
  MemRefType type;
2182
0
  OpAsmParser::OperandType memrefInfo;
2183
0
  IntegerAttr hintInfo;
2184
0
  auto i32Type = parser.getBuilder().getIntegerType(32);
2185
0
  StringRef readOrWrite, cacheType;
2186
0
2187
0
  AffineMapAttr mapAttr;
2188
0
  SmallVector<OpAsmParser::OperandType, 1> mapOperands;
2189
0
  if (parser.parseOperand(memrefInfo) ||
2190
0
      parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
2191
0
                                    AffinePrefetchOp::getMapAttrName(),
2192
0
                                    result.attributes) ||
2193
0
      parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
2194
0
      parser.parseComma() || parser.parseKeyword("locality") ||
2195
0
      parser.parseLess() ||
2196
0
      parser.parseAttribute(hintInfo, i32Type,
2197
0
                            AffinePrefetchOp::getLocalityHintAttrName(),
2198
0
                            result.attributes) ||
2199
0
      parser.parseGreater() || parser.parseComma() ||
2200
0
      parser.parseKeyword(&cacheType) ||
2201
0
      parser.parseOptionalAttrDict(result.attributes) ||
2202
0
      parser.parseColonType(type) ||
2203
0
      parser.resolveOperand(memrefInfo, type, result.operands) ||
2204
0
      parser.resolveOperands(mapOperands, indexTy, result.operands))
2205
0
    return failure();
2206
0
2207
0
  if (!readOrWrite.equals("read") && !readOrWrite.equals("write"))
2208
0
    return parser.emitError(parser.getNameLoc(),
2209
0
                            "rw specifier has to be 'read' or 'write'");
2210
0
  result.addAttribute(
2211
0
      AffinePrefetchOp::getIsWriteAttrName(),
2212
0
      parser.getBuilder().getBoolAttr(readOrWrite.equals("write")));
2213
0
2214
0
  if (!cacheType.equals("data") && !cacheType.equals("instr"))
2215
0
    return parser.emitError(parser.getNameLoc(),
2216
0
                            "cache type has to be 'data' or 'instr'");
2217
0
2218
0
  result.addAttribute(
2219
0
      AffinePrefetchOp::getIsDataCacheAttrName(),
2220
0
      parser.getBuilder().getBoolAttr(cacheType.equals("data")));
2221
0
2222
0
  return success();
2223
0
}
2224
2225
0
/// Prints an affine.prefetch in the custom form accepted by
/// parseAffinePrefetchOp. Attributes already encoded in that syntax (map,
/// locality hint, is-data-cache, is-write) are elided from the attr-dict.
static void print(OpAsmPrinter &p, AffinePrefetchOp op) {
  p << AffinePrefetchOp::getOperationName() << " " << op.memref() << '[';
  // Pass the operand range straight through, as the other affine memory-op
  // printers do, instead of copying it into a temporary vector.
  if (AffineMapAttr mapAttr =
          op.getAttrOfType<AffineMapAttr>(op.getMapAttrName()))
    p.printAffineMapOfSSAIds(mapAttr, op.getMapOperands());
  p << ']' << ", " << (op.isWrite() ? "write" : "read") << ", "
    << "locality<" << op.localityHint() << ">, "
    << (op.isDataCache() ? "data" : "instr");
  p.printOptionalAttrDict(
      op.getAttrs(),
      /*elidedAttrs=*/{op.getMapAttrName(), op.getLocalityHintAttrName(),
                       op.getIsDataCacheAttrName(), op.getIsWriteAttrName()});
  p << " : " << op.getMemRefType();
}
2241
2242
0
/// Verifies an affine.prefetch.
///
/// With a map, the map's result count must equal the memref rank, and the op
/// must have one operand per map input plus one more. Without a map, the op
/// must have exactly one operand. In both cases every map operand must be a
/// valid affine index operand in the enclosing affine scope.
static LogicalResult verify(AffinePrefetchOp op) {
  unsigned expectedOperands = 1;
  if (auto mapAttr = op.getAttrOfType<AffineMapAttr>(op.getMapAttrName())) {
    AffineMap map = mapAttr.getValue();
    if (map.getNumResults() != op.getMemRefType().getRank())
      return op.emitOpError("affine.prefetch affine map num results must equal"
                            " memref rank");
    expectedOperands += map.getNumInputs();
  }
  if (op.getNumOperands() != expectedOperands)
    return op.emitOpError("too few operands");

  Region *scope = getAffineScope(op);
  for (Value idx : op.getMapOperands()) {
    if (!isValidAffineIndexOperand(idx, scope))
      return op.emitOpError("index must be a dimension or symbol identifier");
  }
  return success();
}
2263
2264
void AffinePrefetchOp::getCanonicalizationPatterns(
2265
0
    OwningRewritePatternList &results, MLIRContext *context) {
2266
0
  // prefetch(memrefcast) -> prefetch
2267
0
  results.insert<SimplifyAffineOp<AffinePrefetchOp>>(context);
2268
0
}
2269
2270
/// Folds a cast feeding the memref operand into this prefetch
/// (prefetch(memrefcast) -> prefetch). `results` is never populated here; the
/// returned status is whatever foldMemRefCast reports.
LogicalResult AffinePrefetchOp::fold(ArrayRef<Attribute> cstOperands,
                                     SmallVectorImpl<OpFoldResult> &results) {
  LogicalResult castFolded = foldMemRefCast(*this);
  return castFolded;
}
2275
2276
//===----------------------------------------------------------------------===//
2277
// AffineParallelOp
2278
//===----------------------------------------------------------------------===//
2279
2280
void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
2281
0
                             ArrayRef<int64_t> ranges) {
2282
0
  SmallVector<AffineExpr, 8> lbExprs(ranges.size(),
2283
0
                                     builder.getAffineConstantExpr(0));
2284
0
  auto lbMap = AffineMap::get(0, 0, lbExprs, builder.getContext());
2285
0
  SmallVector<AffineExpr, 8> ubExprs;
2286
0
  for (int64_t range : ranges)
2287
0
    ubExprs.push_back(builder.getAffineConstantExpr(range));
2288
0
  auto ubMap = AffineMap::get(0, 0, ubExprs, builder.getContext());
2289
0
  build(builder, result, lbMap, {}, ubMap, {});
2290
0
}
2291
2292
void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
2293
                             AffineMap lbMap, ValueRange lbArgs,
2294
0
                             AffineMap ubMap, ValueRange ubArgs) {
2295
0
  auto numDims = lbMap.getNumResults();
2296
0
  // Verify that the dimensionality of both maps are the same.
2297
0
  assert(numDims == ubMap.getNumResults() &&
2298
0
         "num dims and num results mismatch");
2299
0
  // Make default step sizes of 1.
2300
0
  SmallVector<int64_t, 8> steps(numDims, 1);
2301
0
  build(builder, result, lbMap, lbArgs, ubMap, ubArgs, steps);
2302
0
}
2303
2304
void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
2305
                             AffineMap lbMap, ValueRange lbArgs,
2306
                             AffineMap ubMap, ValueRange ubArgs,
2307
                             ArrayRef<int64_t> steps) {
2308
  auto numDims = lbMap.getNumResults();
2309
  // Verify that the dimensionality of the maps matches the number of steps.
2310
  assert(numDims == ubMap.getNumResults() &&
2311
         "num dims and num results mismatch");
2312
  assert(numDims == steps.size() && "num dims and num steps mismatch");
2313
  result.addAttribute(getLowerBoundsMapAttrName(), AffineMapAttr::get(lbMap));
2314
  result.addAttribute(getUpperBoundsMapAttrName(), AffineMapAttr::get(ubMap));
2315
  result.addAttribute(getStepsAttrName(), builder.getI64ArrayAttr(steps));
2316
  result.addOperands(lbArgs);
2317
  result.addOperands(ubArgs);
2318
  // Create a region and a block for the body.
2319
  auto bodyRegion = result.addRegion();
2320
  auto body = new Block();
2321
  // Add all the block arguments.
2322
  for (unsigned i = 0; i < numDims; ++i)
2323
    body->addArgument(IndexType::get(builder.getContext()));
2324
  bodyRegion->push_back(body);
2325
  ensureTerminator(*bodyRegion, builder, result.location);
2326
}
2327
2328
0
/// Returns the number of parallel dimensions, taken as the number of steps.
unsigned AffineParallelOp::getNumDims() {
  auto stepAttrs = steps();
  return stepAttrs.size();
}
2329
2330
0
/// Returns the leading operands, one for each input of the lower-bounds map.
AffineParallelOp::operand_range AffineParallelOp::getLowerBoundsOperands() {
  unsigned numLowerInputs = lowerBoundsMap().getNumInputs();
  return getOperands().take_front(numLowerInputs);
}
2333
2334
0
/// Returns every operand after those consumed by the lower-bounds map.
AffineParallelOp::operand_range AffineParallelOp::getUpperBoundsOperands() {
  unsigned numLowerInputs = lowerBoundsMap().getNumInputs();
  return getOperands().drop_front(numLowerInputs);
}
2337
2338
0
/// Pairs the lower-bounds map with its operands.
AffineValueMap AffineParallelOp::getLowerBoundsValueMap() {
  AffineMap map = lowerBoundsMap();
  return AffineValueMap(map, getLowerBoundsOperands());
}
2341
2342
0
/// Pairs the upper-bounds map with its operands.
AffineValueMap AffineParallelOp::getUpperBoundsValueMap() {
  AffineMap map = upperBoundsMap();
  return AffineValueMap(map, getUpperBoundsOperands());
}
2345
2346
0
/// Returns AffineValueMap::difference of the upper-bounds value map and the
/// lower-bounds value map (in that order).
AffineValueMap AffineParallelOp::getRangesValueMap() {
  AffineValueMap upper = getUpperBoundsValueMap();
  AffineValueMap lower = getLowerBoundsValueMap();
  AffineValueMap ranges;
  AffineValueMap::difference(upper, lower, &ranges);
  return ranges;
}
2352
2353
0
/// Returns the value of every result of getRangesValueMap() if all of them are
/// affine constant expressions, and llvm::None otherwise.
Optional<SmallVector<int64_t, 8>> AffineParallelOp::getConstantRanges() {
  AffineValueMap rangesValueMap = getRangesValueMap();
  unsigned numRanges = rangesValueMap.getNumResults();
  SmallVector<int64_t, 8> constantRanges;
  constantRanges.reserve(numRanges);
  for (unsigned pos = 0; pos < numRanges; ++pos) {
    AffineExpr rangeExpr = rangesValueMap.getResult(pos);
    // Give up as soon as one range is not a constant.
    auto constant = rangeExpr.dyn_cast<AffineConstantExpr>();
    if (!constant)
      return llvm::None;
    constantRanges.push_back(constant.getValue());
  }
  return constantRanges;
}
2367
2368
0
/// Returns the first block of the op's region.
Block *AffineParallelOp::getBody() {
  Region &bodyRegion = region();
  return &bodyRegion.front();
}
2369
2370
0
/// Returns a builder that inserts before the last operation of the body block.
/// The body block must not be empty.
OpBuilder AffineParallelOp::getBodyBuilder() {
  Block *body = getBody();
  Block::iterator lastOp = std::prev(body->end());
  return OpBuilder(body, lastOp);
}
2373
2374
0
/// Replaces the per-dimension steps with `newSteps`, which must have exactly
/// one entry per dimension.
void AffineParallelOp::setSteps(ArrayRef<int64_t> newSteps) {
  assert(newSteps.size() == getNumDims() && "steps & num dims mismatch");
  // Building the attribute only needs the context. Going through
  // getBodyBuilder() would needlessly require a non-empty body block.
  setAttr(getStepsAttrName(), Builder(getContext()).getI64ArrayAttr(newSteps));
}
2378
2379
0
/// Verifies an affine.parallel:
///  - the bound maps, the steps and the body block arguments must all agree
///    on the number of dimensions;
///  - the operand list must consist of exactly the inputs of the lower- and
///    upper-bound maps;
///  - every bound operand must be a valid dimension or symbol identifier.
static LogicalResult verify(AffineParallelOp op) {
  auto numDims = op.getNumDims();
  if (op.lowerBoundsMap().getNumResults() != numDims ||
      op.upperBoundsMap().getNumResults() != numDims ||
      op.steps().size() != numDims ||
      op.getBody()->getNumArguments() != numDims) {
    return op.emitOpError("region argument count and num results of upper "
                          "bounds, lower bounds, and steps must all match");
  }

  // getLowerBoundsOperands/getUpperBoundsOperands split the operand list at
  // the lower-bounds map's input count. That split is only meaningful when
  // the operand count equals the total number of map inputs, so check it
  // before using either accessor.
  unsigned expectedOperands = op.lowerBoundsMap().getNumInputs() +
                              op.upperBoundsMap().getNumInputs();
  if (op.getNumOperands() != expectedOperands)
    return op.emitOpError("operand count must equal the number of lower and "
                          "upper bound map inputs");

  // Verify that the bound operands are valid dimension/symbols.
  /// Lower bounds.
  if (failed(verifyDimAndSymbolIdentifiers(op, op.getLowerBoundsOperands(),
                                           op.lowerBoundsMap().getNumDims())))
    return failure();
  /// Upper bounds.
  if (failed(verifyDimAndSymbolIdentifiers(op, op.getUpperBoundsOperands(),
                                           op.upperBoundsMap().getNumDims())))
    return failure();
  return success();
}
2399
2400
0
/// Prints an affine.parallel in the custom form accepted by
/// parseAffineParallelOp. The `step (...)` clause is omitted when every step
/// is 1. The region is printed without entry block arguments or terminators.
static void print(OpAsmPrinter &p, AffineParallelOp op) {
  p << op.getOperationName() << " (" << op.getBody()->getArguments() << ") = (";
  p.printAffineMapOfSSAIds(op.lowerBoundsMapAttr(),
                           op.getLowerBoundsOperands());
  p << ") to (";
  p.printAffineMapOfSSAIds(op.upperBoundsMapAttr(),
                           op.getUpperBoundsOperands());
  p << ')';

  SmallVector<int64_t, 4> steps;
  for (Attribute stepAttr : op.steps())
    steps.push_back(stepAttr.cast<IntegerAttr>().getInt());
  bool allUnitSteps =
      llvm::all_of(steps, [](int64_t step) { return step == 1; });
  if (!allUnitSteps) {
    p << " step (";
    llvm::interleaveComma(steps, p);
    p << ')';
  }

  p.printRegion(op.region(), /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/false);
  p.printOptionalAttrDict(
      op.getAttrs(),
      /*elidedAttrs=*/{AffineParallelOp::getLowerBoundsMapAttrName(),
                       AffineParallelOp::getUpperBoundsMapAttrName(),
                       AffineParallelOp::getStepsAttrName()});
}
2428
2429
//
2430
// operation ::= `affine.parallel` `(` ssa-ids `)` `=` `(` map-of-ssa-ids `)`
2431
//               `to` `(` map-of-ssa-ids `)` steps? region attr-dict?
2432
// steps     ::= `steps` `(` integer-literals `)`
2433
//
2434
static ParseResult parseAffineParallelOp(OpAsmParser &parser,
2435
0
                                         OperationState &result) {
2436
0
  auto &builder = parser.getBuilder();
2437
0
  auto indexType = builder.getIndexType();
2438
0
  AffineMapAttr lowerBoundsAttr, upperBoundsAttr;
2439
0
  SmallVector<OpAsmParser::OperandType, 4> ivs;
2440
0
  SmallVector<OpAsmParser::OperandType, 4> lowerBoundsMapOperands;
2441
0
  SmallVector<OpAsmParser::OperandType, 4> upperBoundsMapOperands;
2442
0
  if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1,
2443
0
                                     OpAsmParser::Delimiter::Paren) ||
2444
0
      parser.parseEqual() ||
2445
0
      parser.parseAffineMapOfSSAIds(
2446
0
          lowerBoundsMapOperands, lowerBoundsAttr,
2447
0
          AffineParallelOp::getLowerBoundsMapAttrName(), result.attributes,
2448
0
          OpAsmParser::Delimiter::Paren) ||
2449
0
      parser.resolveOperands(lowerBoundsMapOperands, indexType,
2450
0
                             result.operands) ||
2451
0
      parser.parseKeyword("to") ||
2452
0
      parser.parseAffineMapOfSSAIds(
2453
0
          upperBoundsMapOperands, upperBoundsAttr,
2454
0
          AffineParallelOp::getUpperBoundsMapAttrName(), result.attributes,
2455
0
          OpAsmParser::Delimiter::Paren) ||
2456
0
      parser.resolveOperands(upperBoundsMapOperands, indexType,
2457
0
                             result.operands))
2458
0
    return failure();
2459
0
2460
0
  AffineMapAttr stepsMapAttr;
2461
0
  NamedAttrList stepsAttrs;
2462
0
  SmallVector<OpAsmParser::OperandType, 4> stepsMapOperands;
2463
0
  if (failed(parser.parseOptionalKeyword("step"))) {
2464
0
    SmallVector<int64_t, 4> steps(ivs.size(), 1);
2465
0
    result.addAttribute(AffineParallelOp::getStepsAttrName(),
2466
0
                        builder.getI64ArrayAttr(steps));
2467
0
  } else {
2468
0
    if (parser.parseAffineMapOfSSAIds(stepsMapOperands, stepsMapAttr,
2469
0
                                      AffineParallelOp::getStepsAttrName(),
2470
0
                                      stepsAttrs,
2471
0
                                      OpAsmParser::Delimiter::Paren))
2472
0
      return failure();
2473
0
2474
0
    // Convert steps from an AffineMap into an I64ArrayAttr.
2475
0
    SmallVector<int64_t, 4> steps;
2476
0
    auto stepsMap = stepsMapAttr.getValue();
2477
0
    for (const auto &result : stepsMap.getResults()) {
2478
0
      auto constExpr = result.dyn_cast<AffineConstantExpr>();
2479
0
      if (!constExpr)
2480
0
        return parser.emitError(parser.getNameLoc(),
2481
0
                                "steps must be constant integers");
2482
0
      steps.push_back(constExpr.getValue());
2483
0
    }
2484
0
    result.addAttribute(AffineParallelOp::getStepsAttrName(),
2485
0
                        builder.getI64ArrayAttr(steps));
2486
0
  }
2487
0
2488
0
  // Now parse the body.
2489
0
  Region *body = result.addRegion();
2490
0
  SmallVector<Type, 4> types(ivs.size(), indexType);
2491
0
  if (parser.parseRegion(*body, ivs, types) ||
2492
0
      parser.parseOptionalAttrDict(result.attributes))
2493
0
    return failure();
2494
0
2495
0
  // Add a terminator if none was parsed.
2496
0
  AffineParallelOp::ensureTerminator(*body, builder, result.location);
2497
0
  return success();
2498
0
}
2499
2500
//===----------------------------------------------------------------------===//
2501
// AffineVectorLoadOp
2502
//===----------------------------------------------------------------------===//
2503
2504
/// Parses `affine.vector_load %memref[<affine map of SSA ids>] attr-dict
/// : memref-type, vector-type`. Subscripts are resolved to `index`, and the
/// vector type becomes the op's result type.
static ParseResult parseAffineVectorLoadOp(OpAsmParser &parser,
                                           OperationState &result) {
  OpAsmParser::OperandType memrefOperand;
  SmallVector<OpAsmParser::OperandType, 1> subscripts;
  AffineMapAttr mapAttr;
  MemRefType memrefType;
  VectorType vectorType;

  if (parser.parseOperand(memrefOperand) ||
      parser.parseAffineMapOfSSAIds(subscripts, mapAttr,
                                    AffineVectorLoadOp::getMapAttrName(),
                                    result.attributes) ||
      parser.parseOptionalAttrDict(result.attributes))
    return failure();
  if (parser.parseColonType(memrefType) || parser.parseComma() ||
      parser.parseType(vectorType))
    return failure();

  Type indexType = parser.getBuilder().getIndexType();
  if (parser.resolveOperand(memrefOperand, memrefType, result.operands) ||
      parser.resolveOperands(subscripts, indexType, result.operands))
    return failure();
  return parser.addTypeToList(vectorType, result.types);
}
2526
2527
0
/// Prints an affine.vector_load in the custom form accepted by
/// parseAffineVectorLoadOp.
static void print(OpAsmPrinter &p, AffineVectorLoadOp op) {
  StringRef mapAttrName = op.getMapAttrName();
  p << "affine.vector_load " << op.getMemRef() << '[';
  if (auto mapAttr = op.getAttrOfType<AffineMapAttr>(mapAttrName))
    p.printAffineMapOfSSAIds(mapAttr, op.getMapOperands());
  p << ']';
  // The map is already printed inside the brackets; keep it out of attr-dict.
  p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{mapAttrName});
  p << " : " << op.getMemRefType() << ", " << op.getType();
}
2536
2537
/// Verify common invariants of affine.vector_load and affine.vector_store.
2538
static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType,
2539
0
                                          VectorType vectorType) {
2540
0
  // Check that memref and vector element types match.
2541
0
  if (memrefType.getElementType() != vectorType.getElementType())
2542
0
    return op->emitOpError(
2543
0
        "requires memref and vector types of the same elemental type");
2544
0
2545
0
  return success();
2546
0
}
2547
2548
0
/// Verifies an affine.vector_load. It applies the shared affine indexing
/// checks, then the vector/memref element type check.
static LogicalResult verify(AffineVectorLoadOp op) {
  MemRefType memrefType = op.getMemRefType();
  AffineMapAttr mapAttr = op.getAttrOfType<AffineMapAttr>(op.getMapAttrName());
  // Exactly one operand (the memref) is not a subscript.
  unsigned numSubscripts = op.getNumOperands() - 1;
  if (failed(verifyMemoryOpIndexing(op.getOperation(), mapAttr,
                                    op.getMapOperands(), memrefType,
                                    numSubscripts)))
    return failure();
  return verifyVectorMemoryOp(op.getOperation(), memrefType,
                              op.getVectorType());
}
2563
2564
//===----------------------------------------------------------------------===//
2565
// AffineVectorStoreOp
2566
//===----------------------------------------------------------------------===//
2567
2568
/// Parses `affine.vector_store %value, %memref[<affine map of SSA ids>]
/// attr-dict : memref-type, vector-type`. The stored value is resolved to the
/// vector type, and subscripts are resolved to `index`. No result types are
/// added.
static ParseResult parseAffineVectorStoreOp(OpAsmParser &parser,
                                            OperationState &result) {
  OpAsmParser::OperandType valueOperand;
  OpAsmParser::OperandType memrefOperand;
  SmallVector<OpAsmParser::OperandType, 1> subscripts;
  AffineMapAttr mapAttr;
  MemRefType memrefType;
  // Type of the stored value (not a result type).
  VectorType vectorType;

  if (parser.parseOperand(valueOperand) || parser.parseComma() ||
      parser.parseOperand(memrefOperand) ||
      parser.parseAffineMapOfSSAIds(subscripts, mapAttr,
                                    AffineVectorStoreOp::getMapAttrName(),
                                    result.attributes) ||
      parser.parseOptionalAttrDict(result.attributes))
    return failure();
  if (parser.parseColonType(memrefType) || parser.parseComma() ||
      parser.parseType(vectorType))
    return failure();

  Type indexType = parser.getBuilder().getIndexType();
  if (parser.resolveOperand(valueOperand, vectorType, result.operands) ||
      parser.resolveOperand(memrefOperand, memrefType, result.operands) ||
      parser.resolveOperands(subscripts, indexType, result.operands))
    return failure();
  return success();
}
2591
2592
0
/// Prints an affine.vector_store in the custom form accepted by
/// parseAffineVectorStoreOp.
static void print(OpAsmPrinter &p, AffineVectorStoreOp op) {
  Value stored = op.getValueToStore();
  StringRef mapAttrName = op.getMapAttrName();
  p << "affine.vector_store " << stored << ", " << op.getMemRef() << '[';
  if (auto mapAttr = op.getAttrOfType<AffineMapAttr>(mapAttrName))
    p.printAffineMapOfSSAIds(mapAttr, op.getMapOperands());
  p << ']';
  // The map is already printed inside the brackets; keep it out of attr-dict.
  p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{mapAttrName});
  p << " : " << op.getMemRefType() << ", " << stored.getType();
}
2602
2603
0
/// Verifies an affine.vector_store. It applies the shared affine indexing
/// checks, then the vector/memref element type check.
static LogicalResult verify(AffineVectorStoreOp op) {
  MemRefType memrefType = op.getMemRefType();
  AffineMapAttr mapAttr = op.getAttrOfType<AffineMapAttr>(op.getMapAttrName());
  // Two operands (the stored value and the memref) are not subscripts.
  unsigned numSubscripts = op.getNumOperands() - 2;
  if (failed(verifyMemoryOpIndexing(op.getOperation(), mapAttr,
                                    op.getMapOperands(), memrefType,
                                    numSubscripts)))
    return failure();
  return verifyVectorMemoryOp(op.getOperation(), memrefType,
                              op.getVectorType());
}
2618
2619
//===----------------------------------------------------------------------===//
2620
// TableGen'd op method definitions
2621
//===----------------------------------------------------------------------===//
2622
2623
#define GET_OP_CLASSES
2624
#include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"