/home/arjun/llvm-project/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
Line | Count | Source (jump to first uncovered line) |
1 | | //===- AffineOps.cpp - MLIR Affine Operations -----------------------------===// |
2 | | // |
3 | | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | | // See https://llvm.org/LICENSE.txt for license information. |
5 | | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | | // |
7 | | //===----------------------------------------------------------------------===// |
8 | | |
9 | | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
10 | | #include "mlir/Dialect/Affine/IR/AffineValueMap.h" |
11 | | #include "mlir/Dialect/StandardOps/IR/Ops.h" |
12 | | #include "mlir/IR/Function.h" |
13 | | #include "mlir/IR/IntegerSet.h" |
14 | | #include "mlir/IR/Matchers.h" |
15 | | #include "mlir/IR/OpImplementation.h" |
16 | | #include "mlir/IR/PatternMatch.h" |
17 | | #include "mlir/Transforms/InliningUtils.h" |
18 | | #include "llvm/ADT/SetVector.h" |
19 | | #include "llvm/ADT/SmallBitVector.h" |
20 | | #include "llvm/Support/Debug.h" |
21 | | |
22 | | using namespace mlir; |
23 | | using llvm::dbgs; |
24 | | |
25 | | #define DEBUG_TYPE "affine-analysis" |
26 | | |
27 | | //===----------------------------------------------------------------------===// |
28 | | // AffineDialect Interfaces |
29 | | //===----------------------------------------------------------------------===// |
30 | | |
31 | | namespace { |
32 | | /// This class defines the interface for handling inlining with affine |
33 | | /// operations. |
34 | | struct AffineInlinerInterface : public DialectInlinerInterface { |
35 | | using DialectInlinerInterface::DialectInlinerInterface; |
36 | | |
37 | | //===--------------------------------------------------------------------===// |
38 | | // Analysis Hooks |
39 | | //===--------------------------------------------------------------------===// |
40 | | |
41 | | /// Returns true if the given region 'src' can be inlined into the region |
42 | | /// 'dest' that is attached to an operation registered to the current dialect. |
43 | | bool isLegalToInline(Region *dest, Region *src, |
44 | 0 | BlockAndValueMapping &valueMapping) const final { |
45 | 0 | // Conservatively don't allow inlining into affine structures. |
46 | 0 | return false; |
47 | 0 | } |
48 | | |
49 | | /// Returns true if the given operation 'op', that is registered to this |
50 | | /// dialect, can be inlined into the given region, false otherwise. |
51 | | bool isLegalToInline(Operation *op, Region *region, |
52 | 0 | BlockAndValueMapping &valueMapping) const final { |
53 | 0 | // Always allow inlining affine operations into the top-level region of a |
54 | 0 | // function. There are some edge cases when inlining *into* affine |
55 | 0 | // structures, but that is handled in the other 'isLegalToInline' hook |
56 | 0 | // above. |
57 | 0 | // TODO: We should be able to inline into other regions than functions. |
58 | 0 | return isa<FuncOp>(region->getParentOp()); |
59 | 0 | } |
60 | | |
61 | | /// Affine regions should be analyzed recursively. |
62 | 0 | bool shouldAnalyzeRecursively(Operation *op) const final { return true; } |
63 | | }; |
64 | | } // end anonymous namespace |
65 | | |
66 | | //===----------------------------------------------------------------------===// |
67 | | // AffineDialect |
68 | | //===----------------------------------------------------------------------===// |
69 | | |
70 | | AffineDialect::AffineDialect(MLIRContext *context) |
71 | 0 | : Dialect(getDialectNamespace(), context) { |
72 | 0 | addOperations<AffineDmaStartOp, AffineDmaWaitOp, |
73 | 0 | #define GET_OP_LIST |
74 | 0 | #include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc" |
75 | 0 | >(); |
76 | 0 | addInterfaces<AffineInlinerInterface>(); |
77 | 0 | } |
78 | | |
79 | | /// Materialize a single constant operation from a given attribute value with |
80 | | /// the desired resultant type. |
81 | | Operation *AffineDialect::materializeConstant(OpBuilder &builder, |
82 | | Attribute value, Type type, |
83 | 0 | Location loc) { |
84 | 0 | return builder.create<ConstantOp>(loc, type, value); |
85 | 0 | } |
86 | | |
87 | | /// A utility function to check if a value is defined at the top level of an |
88 | | /// op with trait `AffineScope`. If the value is defined in an unlinked region, |
89 | | /// conservatively assume it is not top-level. A value of index type defined at |
90 | | /// the top level is always a valid symbol. |
91 | 0 | bool mlir::isTopLevelValue(Value value) { |
92 | 0 | if (auto arg = value.dyn_cast<BlockArgument>()) { |
93 | 0 | // The block owning the argument may be unlinked, e.g. when the surrounding |
94 | 0 | // region has not yet been attached to an Op, at which point the parent Op |
95 | 0 | // is null. |
96 | 0 | Operation *parentOp = arg.getOwner()->getParentOp(); |
97 | 0 | return parentOp && parentOp->hasTrait<OpTrait::AffineScope>(); |
98 | 0 | } |
99 | 0 | // The defining Op may live in an unlinked block so its parent Op may be null. |
100 | 0 | Operation *parentOp = value.getDefiningOp()->getParentOp(); |
101 | 0 | return parentOp && parentOp->hasTrait<OpTrait::AffineScope>(); |
102 | 0 | } |
103 | | |
104 | | /// A utility function to check if a value is defined at the top level of |
105 | | /// `region` or is an argument of `region`. A value of index type defined at the |
106 | | /// top level of a `AffineScope` region is always a valid symbol for all |
107 | | /// uses in that region. |
108 | 0 | static bool isTopLevelValue(Value value, Region *region) { |
109 | 0 | if (auto arg = value.dyn_cast<BlockArgument>()) |
110 | 0 | return arg.getParentRegion() == region; |
111 | 0 | return value.getDefiningOp()->getParentRegion() == region; |
112 | 0 | } |
113 | | |
114 | | /// Returns the closest region enclosing `op` that is held by an operation with |
115 | | /// trait `AffineScope`. |
116 | | // TODO: getAffineScope should be publicly exposed for affine passes/utilities. |
117 | 0 | static Region *getAffineScope(Operation *op) { |
118 | 0 | auto *curOp = op; |
119 | 0 | while (auto *parentOp = curOp->getParentOp()) { |
120 | 0 | if (parentOp->hasTrait<OpTrait::AffineScope>()) |
121 | 0 | return curOp->getParentRegion(); |
122 | 0 | curOp = parentOp; |
123 | 0 | } |
124 | 0 | llvm_unreachable("op doesn't have an enclosing polyhedral scope"); |
125 | 0 | } |
126 | | |
127 | | // A Value can be used as a dimension id iff it meets one of the following |
128 | | // conditions: |
129 | | // *) It is valid as a symbol. |
130 | | // *) It is an induction variable. |
131 | | // *) It is the result of affine apply operation with dimension id arguments. |
132 | 0 | bool mlir::isValidDim(Value value) { |
133 | 0 | // The value must be an index type. |
134 | 0 | if (!value.getType().isIndex()) |
135 | 0 | return false; |
136 | 0 | |
137 | 0 | if (auto *defOp = value.getDefiningOp()) |
138 | 0 | return isValidDim(value, getAffineScope(defOp)); |
139 | 0 | |
140 | 0 | // This value has to be a block argument for an op that has the |
141 | 0 | // `AffineScope` trait or for an affine.for or affine.parallel. |
142 | 0 | auto *parentOp = value.cast<BlockArgument>().getOwner()->getParentOp(); |
143 | 0 | return parentOp && |
144 | 0 | (parentOp->hasTrait<OpTrait::AffineScope>() || |
145 | 0 | isa<AffineForOp>(parentOp) || isa<AffineParallelOp>(parentOp)); |
146 | 0 | } |
147 | | |
148 | | // Value can be used as a dimension id iff it meets one of the following |
149 | | // conditions: |
150 | | // *) It is valid as a symbol. |
151 | | // *) It is an induction variable. |
152 | | // *) It is the result of an affine apply operation with dimension id operands. |
153 | 0 | bool mlir::isValidDim(Value value, Region *region) { |
154 | 0 | // The value must be an index type. |
155 | 0 | if (!value.getType().isIndex()) |
156 | 0 | return false; |
157 | 0 | |
158 | 0 | // All valid symbols are okay. |
159 | 0 | if (isValidSymbol(value, region)) |
160 | 0 | return true; |
161 | 0 | |
162 | 0 | auto *op = value.getDefiningOp(); |
163 | 0 | if (!op) { |
164 | 0 | // This value has to be a block argument for an affine.for or an |
165 | 0 | // affine.parallel. |
166 | 0 | auto *parentOp = value.cast<BlockArgument>().getOwner()->getParentOp(); |
167 | 0 | return isa<AffineForOp>(parentOp) || isa<AffineParallelOp>(parentOp); |
168 | 0 | } |
169 | 0 |
|
170 | 0 | // Affine apply operation is ok if all of its operands are ok. |
171 | 0 | if (auto applyOp = dyn_cast<AffineApplyOp>(op)) |
172 | 0 | return applyOp.isValidDim(region); |
173 | 0 | // The dim op is okay if its operand memref/tensor is defined at the top |
174 | 0 | // level. |
175 | 0 | if (auto dimOp = dyn_cast<DimOp>(op)) |
176 | 0 | return isTopLevelValue(dimOp.getOperand()); |
177 | 0 | return false; |
178 | 0 | } |
179 | | |
180 | | /// Returns true if the 'index' dimension of the `memref` defined by |
181 | | /// `memrefDefOp` is a statically shaped one or defined using a valid symbol |
182 | | /// for `region`. |
183 | | template <typename AnyMemRefDefOp> |
184 | | static bool isMemRefSizeValidSymbol(AnyMemRefDefOp memrefDefOp, unsigned index, |
185 | 0 | Region *region) { |
186 | 0 | auto memRefType = memrefDefOp.getType(); |
187 | 0 | // Statically shaped. |
188 | 0 | if (!memRefType.isDynamicDim(index)) |
189 | 0 | return true; |
190 | 0 | // Get the position of the dimension among dynamic dimensions; |
191 | 0 | unsigned dynamicDimPos = memRefType.getDynamicDimIndex(index); |
192 | 0 | return isValidSymbol(*(memrefDefOp.getDynamicSizes().begin() + dynamicDimPos), |
193 | 0 | region); |
194 | 0 | } Unexecuted instantiation: AffineOps.cpp:_ZL23isMemRefSizeValidSymbolIN4mlir6ViewOpEEbT_jPNS0_6RegionE Unexecuted instantiation: AffineOps.cpp:_ZL23isMemRefSizeValidSymbolIN4mlir9SubViewOpEEbT_jPNS0_6RegionE Unexecuted instantiation: AffineOps.cpp:_ZL23isMemRefSizeValidSymbolIN4mlir7AllocOpEEbT_jPNS0_6RegionE |
195 | | |
196 | | /// Returns true if the result of the dim op is a valid symbol for `region`. |
197 | 0 | static bool isDimOpValidSymbol(DimOp dimOp, Region *region) { |
198 | 0 | // The dim op is okay if its operand memref/tensor is defined at the top |
199 | 0 | // level. |
200 | 0 | if (isTopLevelValue(dimOp.getOperand())) |
201 | 0 | return true; |
202 | 0 | |
203 | 0 | // The dim op is also okay if its operand memref/tensor is a view/subview |
204 | 0 | // whose corresponding size is a valid symbol. |
205 | 0 | unsigned index = dimOp.getIndex(); |
206 | 0 | if (auto viewOp = dyn_cast<ViewOp>(dimOp.getOperand().getDefiningOp())) |
207 | 0 | return isMemRefSizeValidSymbol<ViewOp>(viewOp, index, region); |
208 | 0 | if (auto subViewOp = dyn_cast<SubViewOp>(dimOp.getOperand().getDefiningOp())) |
209 | 0 | return isMemRefSizeValidSymbol<SubViewOp>(subViewOp, index, region); |
210 | 0 | if (auto allocOp = dyn_cast<AllocOp>(dimOp.getOperand().getDefiningOp())) |
211 | 0 | return isMemRefSizeValidSymbol<AllocOp>(allocOp, index, region); |
212 | 0 | return false; |
213 | 0 | } |
214 | | |
215 | | // A value can be used as a symbol (at all its use sites) iff it meets one of |
216 | | // the following conditions: |
217 | | // *) It is a constant. |
218 | | // *) Its defining op or block arg appearance is immediately enclosed by an op |
219 | | // with `AffineScope` trait. |
220 | | // *) It is the result of an affine.apply operation with symbol operands. |
221 | | // *) It is a result of the dim op on a memref whose corresponding size is a |
222 | | // valid symbol. |
223 | 0 | bool mlir::isValidSymbol(Value value) { |
224 | 0 | // The value must be an index type. |
225 | 0 | if (!value.getType().isIndex()) |
226 | 0 | return false; |
227 | 0 | |
228 | 0 | // Check that the value is a top level value. |
229 | 0 | if (isTopLevelValue(value)) |
230 | 0 | return true; |
231 | 0 | |
232 | 0 | if (auto *defOp = value.getDefiningOp()) |
233 | 0 | return isValidSymbol(value, getAffineScope(defOp)); |
234 | 0 | |
235 | 0 | return false; |
236 | 0 | } |
237 | | |
238 | | // A value can be used as a symbol for `region` iff it meets onf of the the |
239 | | // following conditions: |
240 | | // *) It is a constant. |
241 | | // *) It is defined at the top level of 'region' or is its argument. |
242 | | // *) It dominates `region`'s parent op. |
243 | | // *) It is the result of an affine apply operation with symbol arguments. |
244 | | // *) It is a result of the dim op on a memref whose corresponding size is |
245 | | // a valid symbol. |
246 | 0 | bool mlir::isValidSymbol(Value value, Region *region) { |
247 | 0 | // The value must be an index type. |
248 | 0 | if (!value.getType().isIndex()) |
249 | 0 | return false; |
250 | 0 | |
251 | 0 | // A top-level value is a valid symbol. |
252 | 0 | if (::isTopLevelValue(value, region)) |
253 | 0 | return true; |
254 | 0 | |
255 | 0 | auto *defOp = value.getDefiningOp(); |
256 | 0 | if (!defOp) { |
257 | 0 | // A block argument that is not a top-level value is a valid symbol if it |
258 | 0 | // dominates region's parent op. |
259 | 0 | if (!region->getParentOp()->isKnownIsolatedFromAbove()) |
260 | 0 | if (auto *parentOpRegion = region->getParentOp()->getParentRegion()) |
261 | 0 | return isValidSymbol(value, parentOpRegion); |
262 | 0 | return false; |
263 | 0 | } |
264 | 0 | |
265 | 0 | // Constant operation is ok. |
266 | 0 | Attribute operandCst; |
267 | 0 | if (matchPattern(defOp, m_Constant(&operandCst))) |
268 | 0 | return true; |
269 | 0 | |
270 | 0 | // Affine apply operation is ok if all of its operands are ok. |
271 | 0 | if (auto applyOp = dyn_cast<AffineApplyOp>(defOp)) |
272 | 0 | return applyOp.isValidSymbol(region); |
273 | 0 | |
274 | 0 | // Dim op results could be valid symbols at any level. |
275 | 0 | if (auto dimOp = dyn_cast<DimOp>(defOp)) |
276 | 0 | return isDimOpValidSymbol(dimOp, region); |
277 | 0 | |
278 | 0 | // Check for values dominating `region`'s parent op. |
279 | 0 | if (!region->getParentOp()->isKnownIsolatedFromAbove()) |
280 | 0 | if (auto *parentRegion = region->getParentOp()->getParentRegion()) |
281 | 0 | return isValidSymbol(value, parentRegion); |
282 | 0 | |
283 | 0 | return false; |
284 | 0 | } |
285 | | |
286 | | // Returns true if 'value' is a valid index to an affine operation (e.g. |
287 | | // affine.load, affine.store, affine.dma_start, affine.dma_wait) where |
288 | | // `region` provides the polyhedral symbol scope. Returns false otherwise. |
289 | 0 | static bool isValidAffineIndexOperand(Value value, Region *region) { |
290 | 0 | return isValidDim(value, region) || isValidSymbol(value, region); |
291 | 0 | } |
292 | | |
293 | | /// Utility function to verify that a set of operands are valid dimension and |
294 | | /// symbol identifiers. The operands should be laid out such that the dimension |
295 | | /// operands are before the symbol operands. This function returns failure if |
296 | | /// there was an invalid operand. An operation is provided to emit any necessary |
297 | | /// errors. |
298 | | template <typename OpTy> |
299 | | static LogicalResult |
300 | | verifyDimAndSymbolIdentifiers(OpTy &op, Operation::operand_range operands, |
301 | 0 | unsigned numDims) { |
302 | 0 | unsigned opIt = 0; |
303 | 0 | for (auto operand : operands) { |
304 | 0 | if (opIt++ < numDims) { |
305 | 0 | if (!isValidDim(operand, getAffineScope(op))) |
306 | 0 | return op.emitOpError("operand cannot be used as a dimension id"); |
307 | 0 | } else if (!isValidSymbol(operand, getAffineScope(op))) { |
308 | 0 | return op.emitOpError("operand cannot be used as a symbol"); |
309 | 0 | } |
310 | 0 | } |
311 | 0 | return success(); |
312 | 0 | } Unexecuted instantiation: AffineOps.cpp:_ZL29verifyDimAndSymbolIdentifiersIN4mlir11AffineForOpEENS0_13LogicalResultERT_NS0_12OperandRangeEj Unexecuted instantiation: AffineOps.cpp:_ZL29verifyDimAndSymbolIdentifiersIN4mlir10AffineIfOpEENS0_13LogicalResultERT_NS0_12OperandRangeEj Unexecuted instantiation: AffineOps.cpp:_ZL29verifyDimAndSymbolIdentifiersIN4mlir16AffineParallelOpEENS0_13LogicalResultERT_NS0_12OperandRangeEj |
313 | | |
314 | | //===----------------------------------------------------------------------===// |
315 | | // AffineApplyOp |
316 | | //===----------------------------------------------------------------------===// |
317 | | |
318 | 0 | AffineValueMap AffineApplyOp::getAffineValueMap() { |
319 | 0 | return AffineValueMap(getAffineMap(), getOperands(), getResult()); |
320 | 0 | } |
321 | | |
322 | | static ParseResult parseAffineApplyOp(OpAsmParser &parser, |
323 | 0 | OperationState &result) { |
324 | 0 | auto &builder = parser.getBuilder(); |
325 | 0 | auto indexTy = builder.getIndexType(); |
326 | 0 |
|
327 | 0 | AffineMapAttr mapAttr; |
328 | 0 | unsigned numDims; |
329 | 0 | if (parser.parseAttribute(mapAttr, "map", result.attributes) || |
330 | 0 | parseDimAndSymbolList(parser, result.operands, numDims) || |
331 | 0 | parser.parseOptionalAttrDict(result.attributes)) |
332 | 0 | return failure(); |
333 | 0 | auto map = mapAttr.getValue(); |
334 | 0 |
|
335 | 0 | if (map.getNumDims() != numDims || |
336 | 0 | numDims + map.getNumSymbols() != result.operands.size()) { |
337 | 0 | return parser.emitError(parser.getNameLoc(), |
338 | 0 | "dimension or symbol index mismatch"); |
339 | 0 | } |
340 | 0 | |
341 | 0 | result.types.append(map.getNumResults(), indexTy); |
342 | 0 | return success(); |
343 | 0 | } |
344 | | |
345 | 0 | static void print(OpAsmPrinter &p, AffineApplyOp op) { |
346 | 0 | p << AffineApplyOp::getOperationName() << " " << op.mapAttr(); |
347 | 0 | printDimAndSymbolList(op.operand_begin(), op.operand_end(), |
348 | 0 | op.getAffineMap().getNumDims(), p); |
349 | 0 | p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"map"}); |
350 | 0 | } |
351 | | |
352 | 0 | static LogicalResult verify(AffineApplyOp op) { |
353 | 0 | // Check input and output dimensions match. |
354 | 0 | auto map = op.map(); |
355 | 0 |
|
356 | 0 | // Verify that operand count matches affine map dimension and symbol count. |
357 | 0 | if (op.getNumOperands() != map.getNumDims() + map.getNumSymbols()) |
358 | 0 | return op.emitOpError( |
359 | 0 | "operand count and affine map dimension and symbol count must match"); |
360 | 0 | |
361 | 0 | // Verify that the map only produces one result. |
362 | 0 | if (map.getNumResults() != 1) |
363 | 0 | return op.emitOpError("mapping must produce one value"); |
364 | 0 | |
365 | 0 | return success(); |
366 | 0 | } |
367 | | |
368 | | // The result of the affine apply operation can be used as a dimension id if all |
369 | | // its operands are valid dimension ids. |
370 | 0 | bool AffineApplyOp::isValidDim() { |
371 | 0 | return llvm::all_of(getOperands(), |
372 | 0 | [](Value op) { return mlir::isValidDim(op); }); |
373 | 0 | } |
374 | | |
375 | | // The result of the affine apply operation can be used as a dimension id if all |
376 | | // its operands are valid dimension ids with the parent operation of `region` |
377 | | // defining the polyhedral scope for symbols. |
378 | 0 | bool AffineApplyOp::isValidDim(Region *region) { |
379 | 0 | return llvm::all_of(getOperands(), |
380 | 0 | [&](Value op) { return ::isValidDim(op, region); }); |
381 | 0 | } |
382 | | |
383 | | // The result of the affine apply operation can be used as a symbol if all its |
384 | | // operands are symbols. |
385 | 0 | bool AffineApplyOp::isValidSymbol() { |
386 | 0 | return llvm::all_of(getOperands(), |
387 | 0 | [](Value op) { return mlir::isValidSymbol(op); }); |
388 | 0 | } |
389 | | |
390 | | // The result of the affine apply operation can be used as a symbol in `region` |
391 | | // if all its operands are symbols in `region`. |
392 | 0 | bool AffineApplyOp::isValidSymbol(Region *region) { |
393 | 0 | return llvm::all_of(getOperands(), [&](Value operand) { |
394 | 0 | return mlir::isValidSymbol(operand, region); |
395 | 0 | }); |
396 | 0 | } |
397 | | |
398 | 0 | OpFoldResult AffineApplyOp::fold(ArrayRef<Attribute> operands) { |
399 | 0 | auto map = getAffineMap(); |
400 | 0 |
|
401 | 0 | // Fold dims and symbols to existing values. |
402 | 0 | auto expr = map.getResult(0); |
403 | 0 | if (auto dim = expr.dyn_cast<AffineDimExpr>()) |
404 | 0 | return getOperand(dim.getPosition()); |
405 | 0 | if (auto sym = expr.dyn_cast<AffineSymbolExpr>()) |
406 | 0 | return getOperand(map.getNumDims() + sym.getPosition()); |
407 | 0 | |
408 | 0 | // Otherwise, default to folding the map. |
409 | 0 | SmallVector<Attribute, 1> result; |
410 | 0 | if (failed(map.constantFold(operands, result))) |
411 | 0 | return {}; |
412 | 0 | return result[0]; |
413 | 0 | } |
414 | | |
415 | 0 | AffineDimExpr AffineApplyNormalizer::renumberOneDim(Value v) { |
416 | 0 | DenseMap<Value, unsigned>::iterator iterPos; |
417 | 0 | bool inserted = false; |
418 | 0 | std::tie(iterPos, inserted) = |
419 | 0 | dimValueToPosition.insert(std::make_pair(v, dimValueToPosition.size())); |
420 | 0 | if (inserted) { |
421 | 0 | reorderedDims.push_back(v); |
422 | 0 | } |
423 | 0 | return getAffineDimExpr(iterPos->second, v.getContext()) |
424 | 0 | .cast<AffineDimExpr>(); |
425 | 0 | } |
426 | | |
427 | 0 | AffineMap AffineApplyNormalizer::renumber(const AffineApplyNormalizer &other) { |
428 | 0 | SmallVector<AffineExpr, 8> dimRemapping; |
429 | 0 | for (auto v : other.reorderedDims) { |
430 | 0 | auto kvp = other.dimValueToPosition.find(v); |
431 | 0 | if (dimRemapping.size() <= kvp->second) |
432 | 0 | dimRemapping.resize(kvp->second + 1); |
433 | 0 | dimRemapping[kvp->second] = renumberOneDim(kvp->first); |
434 | 0 | } |
435 | 0 | unsigned numSymbols = concatenatedSymbols.size(); |
436 | 0 | unsigned numOtherSymbols = other.concatenatedSymbols.size(); |
437 | 0 | SmallVector<AffineExpr, 8> symRemapping(numOtherSymbols); |
438 | 0 | for (unsigned idx = 0; idx < numOtherSymbols; ++idx) { |
439 | 0 | symRemapping[idx] = |
440 | 0 | getAffineSymbolExpr(idx + numSymbols, other.affineMap.getContext()); |
441 | 0 | } |
442 | 0 | concatenatedSymbols.insert(concatenatedSymbols.end(), |
443 | 0 | other.concatenatedSymbols.begin(), |
444 | 0 | other.concatenatedSymbols.end()); |
445 | 0 | auto map = other.affineMap; |
446 | 0 | return map.replaceDimsAndSymbols(dimRemapping, symRemapping, |
447 | 0 | reorderedDims.size(), |
448 | 0 | concatenatedSymbols.size()); |
449 | 0 | } |
450 | | |
451 | | // Gather the positions of the operands that are produced by an AffineApplyOp. |
452 | | static llvm::SetVector<unsigned> |
453 | 0 | indicesFromAffineApplyOp(ArrayRef<Value> operands) { |
454 | 0 | llvm::SetVector<unsigned> res; |
455 | 0 | for (auto en : llvm::enumerate(operands)) |
456 | 0 | if (isa_and_nonnull<AffineApplyOp>(en.value().getDefiningOp())) |
457 | 0 | res.insert(en.index()); |
458 | 0 | return res; |
459 | 0 | } |
460 | | |
461 | | // Support the special case of a symbol coming from an AffineApplyOp that needs |
462 | | // to be composed into the current AffineApplyOp. |
463 | | // This case is handled by rewriting all such symbols into dims for the purpose |
464 | | // of allowing mathematical AffineMap composition. |
465 | | // Returns an AffineMap where symbols that come from an AffineApplyOp have been |
466 | | // rewritten as dims and are ordered after the original dims. |
467 | | // TODO(andydavis,ntv): This promotion makes AffineMap lose track of which |
468 | | // symbols are represented as dims. This loss is static but can still be |
469 | | // recovered dynamically (with `isValidSymbol`). Still this is annoying for the |
470 | | // semi-affine map case. A dynamic canonicalization of all dims that are valid |
471 | | // symbols (a.k.a `canonicalizePromotedSymbols`) into symbols helps and even |
472 | | // results in better simplifications and foldings. But we should evaluate |
473 | | // whether this behavior is what we really want after using more. |
474 | | static AffineMap promoteComposedSymbolsAsDims(AffineMap map, |
475 | 0 | ArrayRef<Value> symbols) { |
476 | 0 | if (symbols.empty()) { |
477 | 0 | return map; |
478 | 0 | } |
479 | 0 | |
480 | 0 | // Sanity check on symbols. |
481 | 0 | for (auto sym : symbols) { |
482 | 0 | assert(isValidSymbol(sym) && "Expected only valid symbols"); |
483 | 0 | (void)sym; |
484 | 0 | } |
485 | 0 |
|
486 | 0 | // Extract the symbol positions that come from an AffineApplyOp and |
487 | 0 | // needs to be rewritten as dims. |
488 | 0 | auto symPositions = indicesFromAffineApplyOp(symbols); |
489 | 0 | if (symPositions.empty()) { |
490 | 0 | return map; |
491 | 0 | } |
492 | 0 | |
493 | 0 | // Create the new map by replacing each symbol at pos by the next new dim. |
494 | 0 | unsigned numDims = map.getNumDims(); |
495 | 0 | unsigned numSymbols = map.getNumSymbols(); |
496 | 0 | unsigned numNewDims = 0; |
497 | 0 | unsigned numNewSymbols = 0; |
498 | 0 | SmallVector<AffineExpr, 8> symReplacements(numSymbols); |
499 | 0 | for (unsigned i = 0; i < numSymbols; ++i) { |
500 | 0 | symReplacements[i] = |
501 | 0 | symPositions.count(i) > 0 |
502 | 0 | ? getAffineDimExpr(numDims + numNewDims++, map.getContext()) |
503 | 0 | : getAffineSymbolExpr(numNewSymbols++, map.getContext()); |
504 | 0 | } |
505 | 0 | assert(numSymbols >= numNewDims); |
506 | 0 | AffineMap newMap = map.replaceDimsAndSymbols( |
507 | 0 | {}, symReplacements, numDims + numNewDims, numNewSymbols); |
508 | 0 |
|
509 | 0 | return newMap; |
510 | 0 | } |
511 | | |
512 | | /// The AffineNormalizer composes AffineApplyOp recursively. Its purpose is to |
513 | | /// keep a correspondence between the mathematical `map` and the `operands` of |
514 | | /// a given AffineApplyOp. This correspondence is maintained by iterating over |
515 | | /// the operands and forming an `auxiliaryMap` that can be composed |
516 | | /// mathematically with `map`. To keep this correspondence in cases where |
517 | | /// symbols are produced by affine.apply operations, we perform a local rewrite |
518 | | /// of symbols as dims. |
519 | | /// |
520 | | /// Rationale for locally rewriting symbols as dims: |
521 | | /// ================================================ |
522 | | /// The mathematical composition of AffineMap must always concatenate symbols |
523 | | /// because it does not have enough information to do otherwise. For example, |
524 | | /// composing `(d0)[s0] -> (d0 + s0)` with itself must produce |
525 | | /// `(d0)[s0, s1] -> (d0 + s0 + s1)`. |
526 | | /// |
527 | | /// The result is only equivalent to `(d0)[s0] -> (d0 + 2 * s0)` when |
528 | | /// applied to the same mlir::Value for both s0 and s1. |
529 | | /// As a consequence mathematical composition of AffineMap always concatenates |
530 | | /// symbols. |
531 | | /// |
532 | | /// When AffineMaps are used in AffineApplyOp however, they may specify |
533 | | /// composition via symbols, which is ambiguous mathematically. This corner case |
534 | | /// is handled by locally rewriting such symbols that come from AffineApplyOp |
535 | | /// into dims and composing through dims. |
536 | | /// TODO(andydavis, ntv): Composition via symbols comes at a significant code |
537 | | /// complexity. Alternatively we should investigate whether we want to |
538 | | /// explicitly disallow symbols coming from affine.apply and instead force the |
539 | | /// user to compose symbols beforehand. The annoyances may be small (i.e. 1 or 2 |
540 | | /// extra API calls for such uses, which haven't popped up until now) and the |
541 | | /// benefit potentially big: simpler and more maintainable code for a |
542 | | /// non-trivial, recursive, procedure. |
543 | | AffineApplyNormalizer::AffineApplyNormalizer(AffineMap map, |
544 | | ArrayRef<Value> operands) |
545 | 0 | : AffineApplyNormalizer() { |
546 | 0 | static_assert(kMaxAffineApplyDepth > 0, "kMaxAffineApplyDepth must be > 0"); |
547 | 0 | assert(map.getNumInputs() == operands.size() && |
548 | 0 | "number of operands does not match the number of map inputs"); |
549 | 0 |
|
550 | 0 | LLVM_DEBUG(map.print(dbgs() << "\nInput map: ")); |
551 | 0 |
|
552 | 0 | // Promote symbols that come from an AffineApplyOp to dims by rewriting the |
553 | 0 | // map to always refer to: |
554 | 0 | // (dims, symbols coming from AffineApplyOp, other symbols). |
555 | 0 | // The order of operands can remain unchanged. |
556 | 0 | // This is a simplification that relies on 2 ordering properties: |
557 | 0 | // 1. rewritten symbols always appear after the original dims in the map; |
558 | 0 | // 2. operands are traversed in order and either dispatched to: |
559 | 0 | // a. auxiliaryExprs (dims and symbols rewritten as dims); |
560 | 0 | // b. concatenatedSymbols (all other symbols) |
561 | 0 | // This allows operand order to remain unchanged. |
562 | 0 | unsigned numDimsBeforeRewrite = map.getNumDims(); |
563 | 0 | map = promoteComposedSymbolsAsDims(map, |
564 | 0 | operands.take_back(map.getNumSymbols())); |
565 | 0 |
|
566 | 0 | LLVM_DEBUG(map.print(dbgs() << "\nRewritten map: ")); |
567 | 0 |
|
568 | 0 | SmallVector<AffineExpr, 8> auxiliaryExprs; |
569 | 0 | bool furtherCompose = (affineApplyDepth() <= kMaxAffineApplyDepth); |
570 | 0 | // We fully spell out the 2 cases below. In this particular instance a little |
571 | 0 | // code duplication greatly improves readability. |
572 | 0 | // Note that the first branch would disappear if we only supported full |
573 | 0 | // composition (i.e. infinite kMaxAffineApplyDepth). |
574 | 0 | if (!furtherCompose) { |
575 | 0 | // 1. Only dispatch dims or symbols. |
576 | 0 | for (auto en : llvm::enumerate(operands)) { |
577 | 0 | auto t = en.value(); |
578 | 0 | assert(t.getType().isIndex()); |
579 | 0 | bool isDim = (en.index() < map.getNumDims()); |
580 | 0 | if (isDim) { |
581 | 0 | // a. The mathematical composition of AffineMap composes dims. |
582 | 0 | auxiliaryExprs.push_back(renumberOneDim(t)); |
583 | 0 | } else { |
584 | 0 | // b. The mathematical composition of AffineMap concatenates symbols. |
585 | 0 | // We do the same for symbol operands. |
586 | 0 | concatenatedSymbols.push_back(t); |
587 | 0 | } |
588 | 0 | } |
589 | 0 | } else { |
590 | 0 | assert(numDimsBeforeRewrite <= operands.size()); |
591 | 0 | // 2. Compose AffineApplyOps and dispatch dims or symbols. |
592 | 0 | for (unsigned i = 0, e = operands.size(); i < e; ++i) { |
593 | 0 | auto t = operands[i]; |
594 | 0 | auto affineApply = t.getDefiningOp<AffineApplyOp>(); |
595 | 0 | if (affineApply) { |
596 | 0 | // a. Compose affine.apply operations. |
597 | 0 | LLVM_DEBUG(affineApply.getOperation()->print( |
598 | 0 | dbgs() << "\nCompose AffineApplyOp recursively: ")); |
599 | 0 | AffineMap affineApplyMap = affineApply.getAffineMap(); |
600 | 0 | SmallVector<Value, 8> affineApplyOperands( |
601 | 0 | affineApply.getOperands().begin(), affineApply.getOperands().end()); |
602 | 0 | AffineApplyNormalizer normalizer(affineApplyMap, affineApplyOperands); |
603 | 0 |
|
604 | 0 | LLVM_DEBUG(normalizer.affineMap.print( |
605 | 0 | dbgs() << "\nRenumber into current normalizer: ")); |
606 | 0 |
|
607 | 0 | auto renumberedMap = renumber(normalizer); |
608 | 0 |
|
609 | 0 | LLVM_DEBUG( |
610 | 0 | renumberedMap.print(dbgs() << "\nRecursive composition yields: ")); |
611 | 0 |
|
612 | 0 | auxiliaryExprs.push_back(renumberedMap.getResult(0)); |
613 | 0 | } else { |
614 | 0 | if (i < numDimsBeforeRewrite) { |
615 | 0 | // b. The mathematical composition of AffineMap composes dims. |
616 | 0 | auxiliaryExprs.push_back(renumberOneDim(t)); |
617 | 0 | } else { |
618 | 0 | // c. The mathematical composition of AffineMap concatenates symbols. |
619 | 0 | // Note that the map composition will put symbols already present |
620 | 0 | // in the map before any symbols coming from the auxiliary map, so |
621 | 0 | // we insert them before any symbols that are due to renumbering, |
622 | 0 | // and after the proper symbols we have seen already. |
623 | 0 | concatenatedSymbols.insert( |
624 | 0 | std::next(concatenatedSymbols.begin(), numProperSymbols++), t); |
625 | 0 | } |
626 | 0 | } |
627 | 0 | } |
628 | 0 | } |
629 | 0 |
|
630 | 0 | // Early exit if `map` is already composed. |
631 | 0 | if (auxiliaryExprs.empty()) { |
632 | 0 | affineMap = map; |
633 | 0 | return; |
634 | 0 | } |
635 | 0 | |
636 | 0 | assert(concatenatedSymbols.size() >= map.getNumSymbols() && |
637 | 0 | "Unexpected number of concatenated symbols"); |
638 | 0 | auto numDims = dimValueToPosition.size(); |
639 | 0 | auto numSymbols = concatenatedSymbols.size() - map.getNumSymbols(); |
640 | 0 | auto auxiliaryMap = |
641 | 0 | AffineMap::get(numDims, numSymbols, auxiliaryExprs, map.getContext()); |
642 | 0 |
|
643 | 0 | LLVM_DEBUG(map.print(dbgs() << "\nCompose map: ")); |
644 | 0 | LLVM_DEBUG(auxiliaryMap.print(dbgs() << "\nWith map: ")); |
645 | 0 | LLVM_DEBUG(map.compose(auxiliaryMap).print(dbgs() << "\nResult: ")); |
646 | 0 |
|
647 | 0 | // TODO(andydavis,ntv): Disabling simplification results in major speed gains. |
648 | 0 | // Another option is to cache the results as it is expected a lot of redundant |
649 | 0 | // work is performed in practice. |
650 | 0 | affineMap = simplifyAffineMap(map.compose(auxiliaryMap)); |
651 | 0 |
|
652 | 0 | LLVM_DEBUG(affineMap.print(dbgs() << "\nSimplified result: ")); |
653 | 0 | LLVM_DEBUG(dbgs() << "\n"); |
654 | 0 | } |
655 | | |
656 | | void AffineApplyNormalizer::normalize(AffineMap *otherMap, |
657 | 0 | SmallVectorImpl<Value> *otherOperands) { |
658 | 0 | AffineApplyNormalizer other(*otherMap, *otherOperands); |
659 | 0 | *otherMap = renumber(other); |
660 | 0 |
|
661 | 0 | otherOperands->reserve(reorderedDims.size() + concatenatedSymbols.size()); |
662 | 0 | otherOperands->assign(reorderedDims.begin(), reorderedDims.end()); |
663 | 0 | otherOperands->append(concatenatedSymbols.begin(), concatenatedSymbols.end()); |
664 | 0 | } |
665 | | |
666 | | /// Implements `map` and `operands` composition and simplification to support |
667 | | /// `makeComposedAffineApply`. This can be called to achieve the same effects |
668 | | /// on `map` and `operands` without creating an AffineApplyOp that needs to be |
669 | | /// immediately deleted. |
670 | | static void composeAffineMapAndOperands(AffineMap *map, |
671 | 0 | SmallVectorImpl<Value> *operands) { |
672 | 0 | AffineApplyNormalizer normalizer(*map, *operands); |
673 | 0 | auto normalizedMap = normalizer.getAffineMap(); |
674 | 0 | auto normalizedOperands = normalizer.getOperands(); |
675 | 0 | canonicalizeMapAndOperands(&normalizedMap, &normalizedOperands); |
676 | 0 | *map = normalizedMap; |
677 | 0 | *operands = normalizedOperands; |
678 | 0 | assert(*map); |
679 | 0 | } |
680 | | |
681 | | void mlir::fullyComposeAffineMapAndOperands(AffineMap *map, |
682 | 0 | SmallVectorImpl<Value> *operands) { |
683 | 0 | while (llvm::any_of(*operands, [](Value v) { |
684 | 0 | return isa_and_nonnull<AffineApplyOp>(v.getDefiningOp()); |
685 | 0 | })) { |
686 | 0 | composeAffineMapAndOperands(map, operands); |
687 | 0 | } |
688 | 0 | } |
689 | | |
690 | | AffineApplyOp mlir::makeComposedAffineApply(OpBuilder &b, Location loc, |
691 | | AffineMap map, |
692 | 0 | ArrayRef<Value> operands) { |
693 | 0 | AffineMap normalizedMap = map; |
694 | 0 | SmallVector<Value, 8> normalizedOperands(operands.begin(), operands.end()); |
695 | 0 | composeAffineMapAndOperands(&normalizedMap, &normalizedOperands); |
696 | 0 | assert(normalizedMap); |
697 | 0 | return b.create<AffineApplyOp>(loc, normalizedMap, normalizedOperands); |
698 | 0 | } |
699 | | |
700 | | // A symbol may appear as a dim in affine.apply operations. This function |
701 | | // canonicalizes dims that are valid symbols into actual symbols. |
702 | | template <class MapOrSet> |
703 | | static void canonicalizePromotedSymbols(MapOrSet *mapOrSet, |
704 | 0 | SmallVectorImpl<Value> *operands) { |
705 | 0 | if (!mapOrSet || operands->empty()) |
706 | 0 | return; |
707 | 0 | |
708 | 0 | assert(mapOrSet->getNumInputs() == operands->size() && |
709 | 0 | "map/set inputs must match number of operands"); |
710 | 0 |
|
711 | 0 | auto *context = mapOrSet->getContext(); |
712 | 0 | SmallVector<Value, 8> resultOperands; |
713 | 0 | resultOperands.reserve(operands->size()); |
714 | 0 | SmallVector<Value, 8> remappedSymbols; |
715 | 0 | remappedSymbols.reserve(operands->size()); |
716 | 0 | unsigned nextDim = 0; |
717 | 0 | unsigned nextSym = 0; |
718 | 0 | unsigned oldNumSyms = mapOrSet->getNumSymbols(); |
719 | 0 | SmallVector<AffineExpr, 8> dimRemapping(mapOrSet->getNumDims()); |
720 | 0 | for (unsigned i = 0, e = mapOrSet->getNumInputs(); i != e; ++i) { |
721 | 0 | if (i < mapOrSet->getNumDims()) { |
722 | 0 | if (isValidSymbol((*operands)[i])) { |
723 | 0 | // This is a valid symbol that appears as a dim, canonicalize it. |
724 | 0 | dimRemapping[i] = getAffineSymbolExpr(oldNumSyms + nextSym++, context); |
725 | 0 | remappedSymbols.push_back((*operands)[i]); |
726 | 0 | } else { |
727 | 0 | dimRemapping[i] = getAffineDimExpr(nextDim++, context); |
728 | 0 | resultOperands.push_back((*operands)[i]); |
729 | 0 | } |
730 | 0 | } else { |
731 | 0 | resultOperands.push_back((*operands)[i]); |
732 | 0 | } |
733 | 0 | } |
734 | 0 |
|
735 | 0 | resultOperands.append(remappedSymbols.begin(), remappedSymbols.end()); |
736 | 0 | *operands = resultOperands; |
737 | 0 | *mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, {}, nextDim, |
738 | 0 | oldNumSyms + nextSym); |
739 | 0 |
|
740 | 0 | assert(mapOrSet->getNumInputs() == operands->size() && |
741 | 0 | "map/set inputs must match number of operands"); |
742 | 0 | } Unexecuted instantiation: AffineOps.cpp:_ZL27canonicalizePromotedSymbolsIN4mlir9AffineMapEEvPT_PN4llvm15SmallVectorImplINS0_5ValueEEE Unexecuted instantiation: AffineOps.cpp:_ZL27canonicalizePromotedSymbolsIN4mlir10IntegerSetEEvPT_PN4llvm15SmallVectorImplINS0_5ValueEEE |
743 | | |
744 | | // Works for either an affine map or an integer set. |
745 | | template <class MapOrSet> |
746 | | static void canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet, |
747 | 0 | SmallVectorImpl<Value> *operands) { |
748 | 0 | static_assert(llvm::is_one_of<MapOrSet, AffineMap, IntegerSet>::value, |
749 | 0 | "Argument must be either of AffineMap or IntegerSet type"); |
750 | 0 |
|
751 | 0 | if (!mapOrSet || operands->empty()) |
752 | 0 | return; |
753 | 0 | |
754 | 0 | assert(mapOrSet->getNumInputs() == operands->size() && |
755 | 0 | "map/set inputs must match number of operands"); |
756 | 0 |
|
757 | 0 | canonicalizePromotedSymbols<MapOrSet>(mapOrSet, operands); |
758 | 0 |
|
759 | 0 | // Check to see what dims are used. |
760 | 0 | llvm::SmallBitVector usedDims(mapOrSet->getNumDims()); |
761 | 0 | llvm::SmallBitVector usedSyms(mapOrSet->getNumSymbols()); |
762 | 0 | mapOrSet->walkExprs([&](AffineExpr expr) { |
763 | 0 | if (auto dimExpr = expr.dyn_cast<AffineDimExpr>()) |
764 | 0 | usedDims[dimExpr.getPosition()] = true; |
765 | 0 | else if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>()) |
766 | 0 | usedSyms[symExpr.getPosition()] = true; |
767 | 0 | }); Unexecuted instantiation: AffineOps.cpp:_ZZL31canonicalizeMapOrSetAndOperandsIN4mlir9AffineMapEEvPT_PN4llvm15SmallVectorImplINS0_5ValueEEEENKUlNS0_10AffineExprEE_clES9_ Unexecuted instantiation: AffineOps.cpp:_ZZL31canonicalizeMapOrSetAndOperandsIN4mlir10IntegerSetEEvPT_PN4llvm15SmallVectorImplINS0_5ValueEEEENKUlNS0_10AffineExprEE_clES9_ |
768 | 0 |
|
769 | 0 | auto *context = mapOrSet->getContext(); |
770 | 0 |
|
771 | 0 | SmallVector<Value, 8> resultOperands; |
772 | 0 | resultOperands.reserve(operands->size()); |
773 | 0 |
|
774 | 0 | llvm::SmallDenseMap<Value, AffineExpr, 8> seenDims; |
775 | 0 | SmallVector<AffineExpr, 8> dimRemapping(mapOrSet->getNumDims()); |
776 | 0 | unsigned nextDim = 0; |
777 | 0 | for (unsigned i = 0, e = mapOrSet->getNumDims(); i != e; ++i) { |
778 | 0 | if (usedDims[i]) { |
779 | 0 | // Remap dim positions for duplicate operands. |
780 | 0 | auto it = seenDims.find((*operands)[i]); |
781 | 0 | if (it == seenDims.end()) { |
782 | 0 | dimRemapping[i] = getAffineDimExpr(nextDim++, context); |
783 | 0 | resultOperands.push_back((*operands)[i]); |
784 | 0 | seenDims.insert(std::make_pair((*operands)[i], dimRemapping[i])); |
785 | 0 | } else { |
786 | 0 | dimRemapping[i] = it->second; |
787 | 0 | } |
788 | 0 | } |
789 | 0 | } |
790 | 0 | llvm::SmallDenseMap<Value, AffineExpr, 8> seenSymbols; |
791 | 0 | SmallVector<AffineExpr, 8> symRemapping(mapOrSet->getNumSymbols()); |
792 | 0 | unsigned nextSym = 0; |
793 | 0 | for (unsigned i = 0, e = mapOrSet->getNumSymbols(); i != e; ++i) { |
794 | 0 | if (!usedSyms[i]) |
795 | 0 | continue; |
796 | 0 | // Handle constant operands (only needed for symbolic operands since |
797 | 0 | // constant operands in dimensional positions would have already been |
798 | 0 | // promoted to symbolic positions above). |
799 | 0 | IntegerAttr operandCst; |
800 | 0 | if (matchPattern((*operands)[i + mapOrSet->getNumDims()], |
801 | 0 | m_Constant(&operandCst))) { |
802 | 0 | symRemapping[i] = |
803 | 0 | getAffineConstantExpr(operandCst.getValue().getSExtValue(), context); |
804 | 0 | continue; |
805 | 0 | } |
806 | 0 | // Remap symbol positions for duplicate operands. |
807 | 0 | auto it = seenSymbols.find((*operands)[i + mapOrSet->getNumDims()]); |
808 | 0 | if (it == seenSymbols.end()) { |
809 | 0 | symRemapping[i] = getAffineSymbolExpr(nextSym++, context); |
810 | 0 | resultOperands.push_back((*operands)[i + mapOrSet->getNumDims()]); |
811 | 0 | seenSymbols.insert(std::make_pair((*operands)[i + mapOrSet->getNumDims()], |
812 | 0 | symRemapping[i])); |
813 | 0 | } else { |
814 | 0 | symRemapping[i] = it->second; |
815 | 0 | } |
816 | 0 | } |
817 | 0 | *mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, symRemapping, |
818 | 0 | nextDim, nextSym); |
819 | 0 | *operands = resultOperands; |
820 | 0 | } Unexecuted instantiation: AffineOps.cpp:_ZL31canonicalizeMapOrSetAndOperandsIN4mlir9AffineMapEEvPT_PN4llvm15SmallVectorImplINS0_5ValueEEE Unexecuted instantiation: AffineOps.cpp:_ZL31canonicalizeMapOrSetAndOperandsIN4mlir10IntegerSetEEvPT_PN4llvm15SmallVectorImplINS0_5ValueEEE |
821 | | |
822 | | void mlir::canonicalizeMapAndOperands(AffineMap *map, |
823 | 0 | SmallVectorImpl<Value> *operands) { |
824 | 0 | canonicalizeMapOrSetAndOperands<AffineMap>(map, operands); |
825 | 0 | } |
826 | | |
827 | | void mlir::canonicalizeSetAndOperands(IntegerSet *set, |
828 | 0 | SmallVectorImpl<Value> *operands) { |
829 | 0 | canonicalizeMapOrSetAndOperands<IntegerSet>(set, operands); |
830 | 0 | } |
831 | | |
832 | | namespace { |
833 | | /// Simplify AffineApply, AffineLoad, and AffineStore operations by composing |
834 | | /// maps that supply results into them. |
835 | | /// |
836 | | template <typename AffineOpTy> |
837 | | struct SimplifyAffineOp : public OpRewritePattern<AffineOpTy> { |
838 | | using OpRewritePattern<AffineOpTy>::OpRewritePattern; |
839 | | |
840 | | /// Replace the affine op with another instance of it with the supplied |
841 | | /// map and mapOperands. |
842 | | void replaceAffineOp(PatternRewriter &rewriter, AffineOpTy affineOp, |
843 | | AffineMap map, ArrayRef<Value> mapOperands) const; |
844 | | |
845 | | LogicalResult matchAndRewrite(AffineOpTy affineOp, |
846 | 0 | PatternRewriter &rewriter) const override { |
847 | 0 | static_assert(llvm::is_one_of<AffineOpTy, AffineLoadOp, AffinePrefetchOp, |
848 | 0 | AffineStoreOp, AffineApplyOp, AffineMinOp, |
849 | 0 | AffineMaxOp>::value, |
850 | 0 | "affine load/store/apply/prefetch/min/max op expected"); |
851 | 0 | auto map = affineOp.getAffineMap(); |
852 | 0 | AffineMap oldMap = map; |
853 | 0 | auto oldOperands = affineOp.getMapOperands(); |
854 | 0 | SmallVector<Value, 8> resultOperands(oldOperands); |
855 | 0 | composeAffineMapAndOperands(&map, &resultOperands); |
856 | 0 | if (map == oldMap && std::equal(oldOperands.begin(), oldOperands.end(), |
857 | 0 | resultOperands.begin())) |
858 | 0 | return failure(); |
859 | 0 | |
860 | 0 | replaceAffineOp(rewriter, affineOp, map, resultOperands); |
861 | 0 | return success(); |
862 | 0 | } Unexecuted instantiation: AffineOps.cpp:_ZNK12_GLOBAL__N_116SimplifyAffineOpIN4mlir13AffineApplyOpEE15matchAndRewriteES2_RNS1_15PatternRewriterE Unexecuted instantiation: AffineOps.cpp:_ZNK12_GLOBAL__N_116SimplifyAffineOpIN4mlir12AffineLoadOpEE15matchAndRewriteES2_RNS1_15PatternRewriterE Unexecuted instantiation: AffineOps.cpp:_ZNK12_GLOBAL__N_116SimplifyAffineOpIN4mlir13AffineStoreOpEE15matchAndRewriteES2_RNS1_15PatternRewriterE Unexecuted instantiation: AffineOps.cpp:_ZNK12_GLOBAL__N_116SimplifyAffineOpIN4mlir11AffineMinOpEE15matchAndRewriteES2_RNS1_15PatternRewriterE Unexecuted instantiation: AffineOps.cpp:_ZNK12_GLOBAL__N_116SimplifyAffineOpIN4mlir11AffineMaxOpEE15matchAndRewriteES2_RNS1_15PatternRewriterE Unexecuted instantiation: AffineOps.cpp:_ZNK12_GLOBAL__N_116SimplifyAffineOpIN4mlir16AffinePrefetchOpEE15matchAndRewriteES2_RNS1_15PatternRewriterE |
863 | | }; |
864 | | |
865 | | // Specialize the template to account for the different build signatures for |
866 | | // affine load, store, and apply ops. |
867 | | template <> |
868 | | void SimplifyAffineOp<AffineLoadOp>::replaceAffineOp( |
869 | | PatternRewriter &rewriter, AffineLoadOp load, AffineMap map, |
870 | 0 | ArrayRef<Value> mapOperands) const { |
871 | 0 | rewriter.replaceOpWithNewOp<AffineLoadOp>(load, load.getMemRef(), map, |
872 | 0 | mapOperands); |
873 | 0 | } |
874 | | template <> |
875 | | void SimplifyAffineOp<AffinePrefetchOp>::replaceAffineOp( |
876 | | PatternRewriter &rewriter, AffinePrefetchOp prefetch, AffineMap map, |
877 | 0 | ArrayRef<Value> mapOperands) const { |
878 | 0 | rewriter.replaceOpWithNewOp<AffinePrefetchOp>( |
879 | 0 | prefetch, prefetch.memref(), map, mapOperands, |
880 | 0 | prefetch.localityHint().getZExtValue(), prefetch.isWrite(), |
881 | 0 | prefetch.isDataCache()); |
882 | 0 | } |
883 | | template <> |
884 | | void SimplifyAffineOp<AffineStoreOp>::replaceAffineOp( |
885 | | PatternRewriter &rewriter, AffineStoreOp store, AffineMap map, |
886 | 0 | ArrayRef<Value> mapOperands) const { |
887 | 0 | rewriter.replaceOpWithNewOp<AffineStoreOp>( |
888 | 0 | store, store.getValueToStore(), store.getMemRef(), map, mapOperands); |
889 | 0 | } |
890 | | |
891 | | // Generic version for ops that don't have extra operands. |
892 | | template <typename AffineOpTy> |
893 | | void SimplifyAffineOp<AffineOpTy>::replaceAffineOp( |
894 | | PatternRewriter &rewriter, AffineOpTy op, AffineMap map, |
895 | 0 | ArrayRef<Value> mapOperands) const { |
896 | 0 | rewriter.replaceOpWithNewOp<AffineOpTy>(op, map, mapOperands); |
897 | 0 | } Unexecuted instantiation: AffineOps.cpp:_ZNK12_GLOBAL__N_116SimplifyAffineOpIN4mlir13AffineApplyOpEE15replaceAffineOpERNS1_15PatternRewriterES2_NS1_9AffineMapEN4llvm8ArrayRefINS1_5ValueEEE Unexecuted instantiation: AffineOps.cpp:_ZNK12_GLOBAL__N_116SimplifyAffineOpIN4mlir11AffineMinOpEE15replaceAffineOpERNS1_15PatternRewriterES2_NS1_9AffineMapEN4llvm8ArrayRefINS1_5ValueEEE Unexecuted instantiation: AffineOps.cpp:_ZNK12_GLOBAL__N_116SimplifyAffineOpIN4mlir11AffineMaxOpEE15replaceAffineOpERNS1_15PatternRewriterES2_NS1_9AffineMapEN4llvm8ArrayRefINS1_5ValueEEE |
898 | | } // end anonymous namespace. |
899 | | |
900 | | void AffineApplyOp::getCanonicalizationPatterns( |
901 | 0 | OwningRewritePatternList &results, MLIRContext *context) { |
902 | 0 | results.insert<SimplifyAffineOp<AffineApplyOp>>(context); |
903 | 0 | } |
904 | | |
905 | | //===----------------------------------------------------------------------===// |
906 | | // Common canonicalization pattern support logic |
907 | | //===----------------------------------------------------------------------===// |
908 | | |
909 | | /// This is a common class used for patterns of the form |
910 | | /// "someop(memrefcast) -> someop". It folds the source of any memref_cast |
911 | | /// into the root operation directly. |
912 | 0 | static LogicalResult foldMemRefCast(Operation *op) { |
913 | 0 | bool folded = false; |
914 | 0 | for (OpOperand &operand : op->getOpOperands()) { |
915 | 0 | auto cast = operand.get().getDefiningOp<MemRefCastOp>(); |
916 | 0 | if (cast && !cast.getOperand().getType().isa<UnrankedMemRefType>()) { |
917 | 0 | operand.set(cast.getOperand()); |
918 | 0 | folded = true; |
919 | 0 | } |
920 | 0 | } |
921 | 0 | return success(folded); |
922 | 0 | } |
923 | | |
924 | | //===----------------------------------------------------------------------===// |
925 | | // AffineDmaStartOp |
926 | | //===----------------------------------------------------------------------===// |
927 | | |
928 | | // TODO(b/133776335) Check that map operands are loop IVs or symbols. |
929 | | void AffineDmaStartOp::build(OpBuilder &builder, OperationState &result, |
930 | | Value srcMemRef, AffineMap srcMap, |
931 | | ValueRange srcIndices, Value destMemRef, |
932 | | AffineMap dstMap, ValueRange destIndices, |
933 | | Value tagMemRef, AffineMap tagMap, |
934 | | ValueRange tagIndices, Value numElements, |
935 | 0 | Value stride, Value elementsPerStride) { |
936 | 0 | result.addOperands(srcMemRef); |
937 | 0 | result.addAttribute(getSrcMapAttrName(), AffineMapAttr::get(srcMap)); |
938 | 0 | result.addOperands(srcIndices); |
939 | 0 | result.addOperands(destMemRef); |
940 | 0 | result.addAttribute(getDstMapAttrName(), AffineMapAttr::get(dstMap)); |
941 | 0 | result.addOperands(destIndices); |
942 | 0 | result.addOperands(tagMemRef); |
943 | 0 | result.addAttribute(getTagMapAttrName(), AffineMapAttr::get(tagMap)); |
944 | 0 | result.addOperands(tagIndices); |
945 | 0 | result.addOperands(numElements); |
946 | 0 | if (stride) { |
947 | 0 | result.addOperands({stride, elementsPerStride}); |
948 | 0 | } |
949 | 0 | } |
950 | | |
951 | 0 | void AffineDmaStartOp::print(OpAsmPrinter &p) { |
952 | 0 | p << "affine.dma_start " << getSrcMemRef() << '['; |
953 | 0 | p.printAffineMapOfSSAIds(getSrcMapAttr(), getSrcIndices()); |
954 | 0 | p << "], " << getDstMemRef() << '['; |
955 | 0 | p.printAffineMapOfSSAIds(getDstMapAttr(), getDstIndices()); |
956 | 0 | p << "], " << getTagMemRef() << '['; |
957 | 0 | p.printAffineMapOfSSAIds(getTagMapAttr(), getTagIndices()); |
958 | 0 | p << "], " << getNumElements(); |
959 | 0 | if (isStrided()) { |
960 | 0 | p << ", " << getStride(); |
961 | 0 | p << ", " << getNumElementsPerStride(); |
962 | 0 | } |
963 | 0 | p << " : " << getSrcMemRefType() << ", " << getDstMemRefType() << ", " |
964 | 0 | << getTagMemRefType(); |
965 | 0 | } |
966 | | |
967 | | // Parse AffineDmaStartOp. |
968 | | // Ex: |
969 | | // affine.dma_start %src[%i, %j], %dst[%k, %l], %tag[%index], %size, |
970 | | // %stride, %num_elt_per_stride |
971 | | // : memref<3076 x f32, 0>, memref<1024 x f32, 2>, memref<1 x i32> |
972 | | // |
973 | | ParseResult AffineDmaStartOp::parse(OpAsmParser &parser, |
974 | 0 | OperationState &result) { |
975 | 0 | OpAsmParser::OperandType srcMemRefInfo; |
976 | 0 | AffineMapAttr srcMapAttr; |
977 | 0 | SmallVector<OpAsmParser::OperandType, 4> srcMapOperands; |
978 | 0 | OpAsmParser::OperandType dstMemRefInfo; |
979 | 0 | AffineMapAttr dstMapAttr; |
980 | 0 | SmallVector<OpAsmParser::OperandType, 4> dstMapOperands; |
981 | 0 | OpAsmParser::OperandType tagMemRefInfo; |
982 | 0 | AffineMapAttr tagMapAttr; |
983 | 0 | SmallVector<OpAsmParser::OperandType, 4> tagMapOperands; |
984 | 0 | OpAsmParser::OperandType numElementsInfo; |
985 | 0 | SmallVector<OpAsmParser::OperandType, 2> strideInfo; |
986 | 0 |
|
987 | 0 | SmallVector<Type, 3> types; |
988 | 0 | auto indexType = parser.getBuilder().getIndexType(); |
989 | 0 |
|
990 | 0 | // Parse and resolve the following list of operands: |
991 | 0 | // *) dst memref followed by its affine maps operands (in square brackets). |
992 | 0 | // *) src memref followed by its affine map operands (in square brackets). |
993 | 0 | // *) tag memref followed by its affine map operands (in square brackets). |
994 | 0 | // *) number of elements transferred by DMA operation. |
995 | 0 | if (parser.parseOperand(srcMemRefInfo) || |
996 | 0 | parser.parseAffineMapOfSSAIds(srcMapOperands, srcMapAttr, |
997 | 0 | getSrcMapAttrName(), result.attributes) || |
998 | 0 | parser.parseComma() || parser.parseOperand(dstMemRefInfo) || |
999 | 0 | parser.parseAffineMapOfSSAIds(dstMapOperands, dstMapAttr, |
1000 | 0 | getDstMapAttrName(), result.attributes) || |
1001 | 0 | parser.parseComma() || parser.parseOperand(tagMemRefInfo) || |
1002 | 0 | parser.parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr, |
1003 | 0 | getTagMapAttrName(), result.attributes) || |
1004 | 0 | parser.parseComma() || parser.parseOperand(numElementsInfo)) |
1005 | 0 | return failure(); |
1006 | 0 | |
1007 | 0 | // Parse optional stride and elements per stride. |
1008 | 0 | if (parser.parseTrailingOperandList(strideInfo)) { |
1009 | 0 | return failure(); |
1010 | 0 | } |
1011 | 0 | if (!strideInfo.empty() && strideInfo.size() != 2) { |
1012 | 0 | return parser.emitError(parser.getNameLoc(), |
1013 | 0 | "expected two stride related operands"); |
1014 | 0 | } |
1015 | 0 | bool isStrided = strideInfo.size() == 2; |
1016 | 0 |
|
1017 | 0 | if (parser.parseColonTypeList(types)) |
1018 | 0 | return failure(); |
1019 | 0 | |
1020 | 0 | if (types.size() != 3) |
1021 | 0 | return parser.emitError(parser.getNameLoc(), "expected three types"); |
1022 | 0 | |
1023 | 0 | if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) || |
1024 | 0 | parser.resolveOperands(srcMapOperands, indexType, result.operands) || |
1025 | 0 | parser.resolveOperand(dstMemRefInfo, types[1], result.operands) || |
1026 | 0 | parser.resolveOperands(dstMapOperands, indexType, result.operands) || |
1027 | 0 | parser.resolveOperand(tagMemRefInfo, types[2], result.operands) || |
1028 | 0 | parser.resolveOperands(tagMapOperands, indexType, result.operands) || |
1029 | 0 | parser.resolveOperand(numElementsInfo, indexType, result.operands)) |
1030 | 0 | return failure(); |
1031 | 0 | |
1032 | 0 | if (isStrided) { |
1033 | 0 | if (parser.resolveOperands(strideInfo, indexType, result.operands)) |
1034 | 0 | return failure(); |
1035 | 0 | } |
1036 | 0 | |
1037 | 0 | // Check that src/dst/tag operand counts match their map.numInputs. |
1038 | 0 | if (srcMapOperands.size() != srcMapAttr.getValue().getNumInputs() || |
1039 | 0 | dstMapOperands.size() != dstMapAttr.getValue().getNumInputs() || |
1040 | 0 | tagMapOperands.size() != tagMapAttr.getValue().getNumInputs()) |
1041 | 0 | return parser.emitError(parser.getNameLoc(), |
1042 | 0 | "memref operand count not equal to map.numInputs"); |
1043 | 0 | return success(); |
1044 | 0 | } |
1045 | | |
1046 | 0 | LogicalResult AffineDmaStartOp::verify() { |
1047 | 0 | if (!getOperand(getSrcMemRefOperandIndex()).getType().isa<MemRefType>()) |
1048 | 0 | return emitOpError("expected DMA source to be of memref type"); |
1049 | 0 | if (!getOperand(getDstMemRefOperandIndex()).getType().isa<MemRefType>()) |
1050 | 0 | return emitOpError("expected DMA destination to be of memref type"); |
1051 | 0 | if (!getOperand(getTagMemRefOperandIndex()).getType().isa<MemRefType>()) |
1052 | 0 | return emitOpError("expected DMA tag to be of memref type"); |
1053 | 0 | |
1054 | 0 | // DMAs from different memory spaces supported. |
1055 | 0 | if (getSrcMemorySpace() == getDstMemorySpace()) { |
1056 | 0 | return emitOpError("DMA should be between different memory spaces"); |
1057 | 0 | } |
1058 | 0 | unsigned numInputsAllMaps = getSrcMap().getNumInputs() + |
1059 | 0 | getDstMap().getNumInputs() + |
1060 | 0 | getTagMap().getNumInputs(); |
1061 | 0 | if (getNumOperands() != numInputsAllMaps + 3 + 1 && |
1062 | 0 | getNumOperands() != numInputsAllMaps + 3 + 1 + 2) { |
1063 | 0 | return emitOpError("incorrect number of operands"); |
1064 | 0 | } |
1065 | 0 | |
1066 | 0 | Region *scope = getAffineScope(*this); |
1067 | 0 | for (auto idx : getSrcIndices()) { |
1068 | 0 | if (!idx.getType().isIndex()) |
1069 | 0 | return emitOpError("src index to dma_start must have 'index' type"); |
1070 | 0 | if (!isValidAffineIndexOperand(idx, scope)) |
1071 | 0 | return emitOpError("src index must be a dimension or symbol identifier"); |
1072 | 0 | } |
1073 | 0 | for (auto idx : getDstIndices()) { |
1074 | 0 | if (!idx.getType().isIndex()) |
1075 | 0 | return emitOpError("dst index to dma_start must have 'index' type"); |
1076 | 0 | if (!isValidAffineIndexOperand(idx, scope)) |
1077 | 0 | return emitOpError("dst index must be a dimension or symbol identifier"); |
1078 | 0 | } |
1079 | 0 | for (auto idx : getTagIndices()) { |
1080 | 0 | if (!idx.getType().isIndex()) |
1081 | 0 | return emitOpError("tag index to dma_start must have 'index' type"); |
1082 | 0 | if (!isValidAffineIndexOperand(idx, scope)) |
1083 | 0 | return emitOpError("tag index must be a dimension or symbol identifier"); |
1084 | 0 | } |
1085 | 0 | return success(); |
1086 | 0 | } |
1087 | | |
1088 | | LogicalResult AffineDmaStartOp::fold(ArrayRef<Attribute> cstOperands, |
1089 | 0 | SmallVectorImpl<OpFoldResult> &results) { |
1090 | 0 | /// dma_start(memrefcast) -> dma_start |
1091 | 0 | return foldMemRefCast(*this); |
1092 | 0 | } |
1093 | | |
1094 | | //===----------------------------------------------------------------------===// |
1095 | | // AffineDmaWaitOp |
1096 | | //===----------------------------------------------------------------------===// |
1097 | | |
1098 | | // TODO(b/133776335) Check that map operands are loop IVs or symbols. |
1099 | | void AffineDmaWaitOp::build(OpBuilder &builder, OperationState &result, |
1100 | | Value tagMemRef, AffineMap tagMap, |
1101 | 0 | ValueRange tagIndices, Value numElements) { |
1102 | 0 | result.addOperands(tagMemRef); |
1103 | 0 | result.addAttribute(getTagMapAttrName(), AffineMapAttr::get(tagMap)); |
1104 | 0 | result.addOperands(tagIndices); |
1105 | 0 | result.addOperands(numElements); |
1106 | 0 | } |
1107 | | |
1108 | 0 | void AffineDmaWaitOp::print(OpAsmPrinter &p) { |
1109 | 0 | p << "affine.dma_wait " << getTagMemRef() << '['; |
1110 | 0 | SmallVector<Value, 2> operands(getTagIndices()); |
1111 | 0 | p.printAffineMapOfSSAIds(getTagMapAttr(), operands); |
1112 | 0 | p << "], "; |
1113 | 0 | p.printOperand(getNumElements()); |
1114 | 0 | p << " : " << getTagMemRef().getType(); |
1115 | 0 | } |
1116 | | |
1117 | | // Parse AffineDmaWaitOp. |
1118 | | // Eg: |
1119 | | // affine.dma_wait %tag[%index], %num_elements |
1120 | | // : memref<1 x i32, (d0) -> (d0), 4> |
1121 | | // |
1122 | | ParseResult AffineDmaWaitOp::parse(OpAsmParser &parser, |
1123 | 0 | OperationState &result) { |
1124 | 0 | OpAsmParser::OperandType tagMemRefInfo; |
1125 | 0 | AffineMapAttr tagMapAttr; |
1126 | 0 | SmallVector<OpAsmParser::OperandType, 2> tagMapOperands; |
1127 | 0 | Type type; |
1128 | 0 | auto indexType = parser.getBuilder().getIndexType(); |
1129 | 0 | OpAsmParser::OperandType numElementsInfo; |
1130 | 0 |
|
1131 | 0 | // Parse tag memref, its map operands, and dma size. |
1132 | 0 | if (parser.parseOperand(tagMemRefInfo) || |
1133 | 0 | parser.parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr, |
1134 | 0 | getTagMapAttrName(), result.attributes) || |
1135 | 0 | parser.parseComma() || parser.parseOperand(numElementsInfo) || |
1136 | 0 | parser.parseColonType(type) || |
1137 | 0 | parser.resolveOperand(tagMemRefInfo, type, result.operands) || |
1138 | 0 | parser.resolveOperands(tagMapOperands, indexType, result.operands) || |
1139 | 0 | parser.resolveOperand(numElementsInfo, indexType, result.operands)) |
1140 | 0 | return failure(); |
1141 | 0 | |
1142 | 0 | if (!type.isa<MemRefType>()) |
1143 | 0 | return parser.emitError(parser.getNameLoc(), |
1144 | 0 | "expected tag to be of memref type"); |
1145 | 0 | |
1146 | 0 | if (tagMapOperands.size() != tagMapAttr.getValue().getNumInputs()) |
1147 | 0 | return parser.emitError(parser.getNameLoc(), |
1148 | 0 | "tag memref operand count != to map.numInputs"); |
1149 | 0 | return success(); |
1150 | 0 | } |
1151 | | |
1152 | 0 | LogicalResult AffineDmaWaitOp::verify() { |
1153 | 0 | if (!getOperand(0).getType().isa<MemRefType>()) |
1154 | 0 | return emitOpError("expected DMA tag to be of memref type"); |
1155 | 0 | Region *scope = getAffineScope(*this); |
1156 | 0 | for (auto idx : getTagIndices()) { |
1157 | 0 | if (!idx.getType().isIndex()) |
1158 | 0 | return emitOpError("index to dma_wait must have 'index' type"); |
1159 | 0 | if (!isValidAffineIndexOperand(idx, scope)) |
1160 | 0 | return emitOpError("index must be a dimension or symbol identifier"); |
1161 | 0 | } |
1162 | 0 | return success(); |
1163 | 0 | } |
1164 | | |
1165 | | LogicalResult AffineDmaWaitOp::fold(ArrayRef<Attribute> cstOperands, |
1166 | 0 | SmallVectorImpl<OpFoldResult> &results) { |
1167 | 0 | /// dma_wait(memrefcast) -> dma_wait |
1168 | 0 | return foldMemRefCast(*this); |
1169 | 0 | } |
1170 | | |
1171 | | //===----------------------------------------------------------------------===// |
1172 | | // AffineForOp |
1173 | | //===----------------------------------------------------------------------===// |
1174 | | |
1175 | | void AffineForOp::build(OpBuilder &builder, OperationState &result, |
1176 | | ValueRange lbOperands, AffineMap lbMap, |
1177 | 0 | ValueRange ubOperands, AffineMap ubMap, int64_t step) { |
1178 | 0 | assert(((!lbMap && lbOperands.empty()) || |
1179 | 0 | lbOperands.size() == lbMap.getNumInputs()) && |
1180 | 0 | "lower bound operand count does not match the affine map"); |
1181 | 0 | assert(((!ubMap && ubOperands.empty()) || |
1182 | 0 | ubOperands.size() == ubMap.getNumInputs()) && |
1183 | 0 | "upper bound operand count does not match the affine map"); |
1184 | 0 | assert(step > 0 && "step has to be a positive integer constant"); |
1185 | 0 |
|
1186 | 0 | // Add an attribute for the step. |
1187 | 0 | result.addAttribute(getStepAttrName(), |
1188 | 0 | builder.getIntegerAttr(builder.getIndexType(), step)); |
1189 | 0 |
|
1190 | 0 | // Add the lower bound. |
1191 | 0 | result.addAttribute(getLowerBoundAttrName(), AffineMapAttr::get(lbMap)); |
1192 | 0 | result.addOperands(lbOperands); |
1193 | 0 |
|
1194 | 0 | // Add the upper bound. |
1195 | 0 | result.addAttribute(getUpperBoundAttrName(), AffineMapAttr::get(ubMap)); |
1196 | 0 | result.addOperands(ubOperands); |
1197 | 0 |
|
1198 | 0 | // Create a region and a block for the body. The argument of the region is |
1199 | 0 | // the loop induction variable. |
1200 | 0 | Region *bodyRegion = result.addRegion(); |
1201 | 0 | Block *body = new Block(); |
1202 | 0 | body->addArgument(IndexType::get(builder.getContext())); |
1203 | 0 | bodyRegion->push_back(body); |
1204 | 0 | ensureTerminator(*bodyRegion, builder, result.location); |
1205 | 0 | } |
1206 | | |
1207 | | void AffineForOp::build(OpBuilder &builder, OperationState &result, int64_t lb, |
1208 | 0 | int64_t ub, int64_t step) { |
1209 | 0 | auto lbMap = AffineMap::getConstantMap(lb, builder.getContext()); |
1210 | 0 | auto ubMap = AffineMap::getConstantMap(ub, builder.getContext()); |
1211 | 0 | return build(builder, result, {}, lbMap, {}, ubMap, step); |
1212 | 0 | } |
1213 | | |
1214 | 0 | static LogicalResult verify(AffineForOp op) { |
1215 | 0 | // Check that the body defines as single block argument for the induction |
1216 | 0 | // variable. |
1217 | 0 | auto *body = op.getBody(); |
1218 | 0 | if (body->getNumArguments() != 1 || !body->getArgument(0).getType().isIndex()) |
1219 | 0 | return op.emitOpError( |
1220 | 0 | "expected body to have a single index argument for the " |
1221 | 0 | "induction variable"); |
1222 | 0 | |
1223 | 0 | // Verify that there are enough operands for the bounds. |
1224 | 0 | AffineMap lowerBoundMap = op.getLowerBoundMap(), |
1225 | 0 | upperBoundMap = op.getUpperBoundMap(); |
1226 | 0 | if (op.getNumOperands() != |
1227 | 0 | (lowerBoundMap.getNumInputs() + upperBoundMap.getNumInputs())) |
1228 | 0 | return op.emitOpError( |
1229 | 0 | "operand count must match with affine map dimension and symbol count"); |
1230 | 0 | |
1231 | 0 | // Verify that the bound operands are valid dimension/symbols. |
1232 | 0 | /// Lower bound. |
1233 | 0 | if (failed(verifyDimAndSymbolIdentifiers(op, op.getLowerBoundOperands(), |
1234 | 0 | op.getLowerBoundMap().getNumDims()))) |
1235 | 0 | return failure(); |
1236 | 0 | /// Upper bound. |
1237 | 0 | if (failed(verifyDimAndSymbolIdentifiers(op, op.getUpperBoundOperands(), |
1238 | 0 | op.getUpperBoundMap().getNumDims()))) |
1239 | 0 | return failure(); |
1240 | 0 | return success(); |
1241 | 0 | } |
1242 | | |
1243 | | /// Parse a for operation loop bounds. |
1244 | | static ParseResult parseBound(bool isLower, OperationState &result, |
1245 | 0 | OpAsmParser &p) { |
1246 | 0 | // 'min' / 'max' prefixes are generally syntactic sugar, but are required if |
1247 | 0 | // the map has multiple results. |
1248 | 0 | bool failedToParsedMinMax = |
1249 | 0 | failed(p.parseOptionalKeyword(isLower ? "max" : "min")); |
1250 | 0 |
|
1251 | 0 | auto &builder = p.getBuilder(); |
1252 | 0 | auto boundAttrName = isLower ? AffineForOp::getLowerBoundAttrName() |
1253 | 0 | : AffineForOp::getUpperBoundAttrName(); |
1254 | 0 |
|
1255 | 0 | // Parse ssa-id as identity map. |
1256 | 0 | SmallVector<OpAsmParser::OperandType, 1> boundOpInfos; |
1257 | 0 | if (p.parseOperandList(boundOpInfos)) |
1258 | 0 | return failure(); |
1259 | 0 | |
1260 | 0 | if (!boundOpInfos.empty()) { |
1261 | 0 | // Check that only one operand was parsed. |
1262 | 0 | if (boundOpInfos.size() > 1) |
1263 | 0 | return p.emitError(p.getNameLoc(), |
1264 | 0 | "expected only one loop bound operand"); |
1265 | 0 | |
1266 | 0 | // TODO: improve error message when SSA value is not of index type. |
1267 | 0 | // Currently it is 'use of value ... expects different type than prior uses' |
1268 | 0 | if (p.resolveOperand(boundOpInfos.front(), builder.getIndexType(), |
1269 | 0 | result.operands)) |
1270 | 0 | return failure(); |
1271 | 0 | |
1272 | 0 | // Create an identity map using symbol id. This representation is optimized |
1273 | 0 | // for storage. Analysis passes may expand it into a multi-dimensional map |
1274 | 0 | // if desired. |
1275 | 0 | AffineMap map = builder.getSymbolIdentityMap(); |
1276 | 0 | result.addAttribute(boundAttrName, AffineMapAttr::get(map)); |
1277 | 0 | return success(); |
1278 | 0 | } |
1279 | 0 | |
1280 | 0 | // Get the attribute location. |
1281 | 0 | llvm::SMLoc attrLoc = p.getCurrentLocation(); |
1282 | 0 |
|
1283 | 0 | Attribute boundAttr; |
1284 | 0 | if (p.parseAttribute(boundAttr, builder.getIndexType(), boundAttrName, |
1285 | 0 | result.attributes)) |
1286 | 0 | return failure(); |
1287 | 0 | |
1288 | 0 | // Parse full form - affine map followed by dim and symbol list. |
1289 | 0 | if (auto affineMapAttr = boundAttr.dyn_cast<AffineMapAttr>()) { |
1290 | 0 | unsigned currentNumOperands = result.operands.size(); |
1291 | 0 | unsigned numDims; |
1292 | 0 | if (parseDimAndSymbolList(p, result.operands, numDims)) |
1293 | 0 | return failure(); |
1294 | 0 | |
1295 | 0 | auto map = affineMapAttr.getValue(); |
1296 | 0 | if (map.getNumDims() != numDims) |
1297 | 0 | return p.emitError( |
1298 | 0 | p.getNameLoc(), |
1299 | 0 | "dim operand count and affine map dim count must match"); |
1300 | 0 | |
1301 | 0 | unsigned numDimAndSymbolOperands = |
1302 | 0 | result.operands.size() - currentNumOperands; |
1303 | 0 | if (numDims + map.getNumSymbols() != numDimAndSymbolOperands) |
1304 | 0 | return p.emitError( |
1305 | 0 | p.getNameLoc(), |
1306 | 0 | "symbol operand count and affine map symbol count must match"); |
1307 | 0 | |
1308 | 0 | // If the map has multiple results, make sure that we parsed the min/max |
1309 | 0 | // prefix. |
1310 | 0 | if (map.getNumResults() > 1 && failedToParsedMinMax) { |
1311 | 0 | if (isLower) { |
1312 | 0 | return p.emitError(attrLoc, "lower loop bound affine map with " |
1313 | 0 | "multiple results requires 'max' prefix"); |
1314 | 0 | } |
1315 | 0 | return p.emitError(attrLoc, "upper loop bound affine map with multiple " |
1316 | 0 | "results requires 'min' prefix"); |
1317 | 0 | } |
1318 | 0 | return success(); |
1319 | 0 | } |
1320 | 0 | |
1321 | 0 | // Parse custom assembly form. |
1322 | 0 | if (auto integerAttr = boundAttr.dyn_cast<IntegerAttr>()) { |
1323 | 0 | result.attributes.pop_back(); |
1324 | 0 | result.addAttribute( |
1325 | 0 | boundAttrName, |
1326 | 0 | AffineMapAttr::get(builder.getConstantAffineMap(integerAttr.getInt()))); |
1327 | 0 | return success(); |
1328 | 0 | } |
1329 | 0 | |
1330 | 0 | return p.emitError( |
1331 | 0 | p.getNameLoc(), |
1332 | 0 | "expected valid affine map representation for loop bounds"); |
1333 | 0 | } |
1334 | | |
1335 | | static ParseResult parseAffineForOp(OpAsmParser &parser, |
1336 | 0 | OperationState &result) { |
1337 | 0 | auto &builder = parser.getBuilder(); |
1338 | 0 | OpAsmParser::OperandType inductionVariable; |
1339 | 0 | // Parse the induction variable followed by '='. |
1340 | 0 | if (parser.parseRegionArgument(inductionVariable) || parser.parseEqual()) |
1341 | 0 | return failure(); |
1342 | 0 | |
1343 | 0 | // Parse loop bounds. |
1344 | 0 | if (parseBound(/*isLower=*/true, result, parser) || |
1345 | 0 | parser.parseKeyword("to", " between bounds") || |
1346 | 0 | parseBound(/*isLower=*/false, result, parser)) |
1347 | 0 | return failure(); |
1348 | 0 | |
1349 | 0 | // Parse the optional loop step, we default to 1 if one is not present. |
1350 | 0 | if (parser.parseOptionalKeyword("step")) { |
1351 | 0 | result.addAttribute( |
1352 | 0 | AffineForOp::getStepAttrName(), |
1353 | 0 | builder.getIntegerAttr(builder.getIndexType(), /*value=*/1)); |
1354 | 0 | } else { |
1355 | 0 | llvm::SMLoc stepLoc = parser.getCurrentLocation(); |
1356 | 0 | IntegerAttr stepAttr; |
1357 | 0 | if (parser.parseAttribute(stepAttr, builder.getIndexType(), |
1358 | 0 | AffineForOp::getStepAttrName().data(), |
1359 | 0 | result.attributes)) |
1360 | 0 | return failure(); |
1361 | 0 | |
1362 | 0 | if (stepAttr.getValue().getSExtValue() < 0) |
1363 | 0 | return parser.emitError( |
1364 | 0 | stepLoc, |
1365 | 0 | "expected step to be representable as a positive signed integer"); |
1366 | 0 | } |
1367 | 0 | |
1368 | 0 | // Parse the body region. |
1369 | 0 | Region *body = result.addRegion(); |
1370 | 0 | if (parser.parseRegion(*body, inductionVariable, builder.getIndexType())) |
1371 | 0 | return failure(); |
1372 | 0 | |
1373 | 0 | AffineForOp::ensureTerminator(*body, builder, result.location); |
1374 | 0 |
|
1375 | 0 | // Parse the optional attribute list. |
1376 | 0 | return parser.parseOptionalAttrDict(result.attributes); |
1377 | 0 | } |
1378 | | |
1379 | | static void printBound(AffineMapAttr boundMap, |
1380 | | Operation::operand_range boundOperands, |
1381 | 0 | const char *prefix, OpAsmPrinter &p) { |
1382 | 0 | AffineMap map = boundMap.getValue(); |
1383 | 0 |
|
1384 | 0 | // Check if this bound should be printed using custom assembly form. |
1385 | 0 | // The decision to restrict printing custom assembly form to trivial cases |
1386 | 0 | // comes from the will to roundtrip MLIR binary -> text -> binary in a |
1387 | 0 | // lossless way. |
1388 | 0 | // Therefore, custom assembly form parsing and printing is only supported for |
1389 | 0 | // zero-operand constant maps and single symbol operand identity maps. |
1390 | 0 | if (map.getNumResults() == 1) { |
1391 | 0 | AffineExpr expr = map.getResult(0); |
1392 | 0 |
|
1393 | 0 | // Print constant bound. |
1394 | 0 | if (map.getNumDims() == 0 && map.getNumSymbols() == 0) { |
1395 | 0 | if (auto constExpr = expr.dyn_cast<AffineConstantExpr>()) { |
1396 | 0 | p << constExpr.getValue(); |
1397 | 0 | return; |
1398 | 0 | } |
1399 | 0 | } |
1400 | 0 | |
1401 | 0 | // Print bound that consists of a single SSA symbol if the map is over a |
1402 | 0 | // single symbol. |
1403 | 0 | if (map.getNumDims() == 0 && map.getNumSymbols() == 1) { |
1404 | 0 | if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>()) { |
1405 | 0 | p.printOperand(*boundOperands.begin()); |
1406 | 0 | return; |
1407 | 0 | } |
1408 | 0 | } |
1409 | 0 | } else { |
1410 | 0 | // Map has multiple results. Print 'min' or 'max' prefix. |
1411 | 0 | p << prefix << ' '; |
1412 | 0 | } |
1413 | 0 |
|
1414 | 0 | // Print the map and its operands. |
1415 | 0 | p << boundMap; |
1416 | 0 | printDimAndSymbolList(boundOperands.begin(), boundOperands.end(), |
1417 | 0 | map.getNumDims(), p); |
1418 | 0 | } |
1419 | | |
1420 | 0 | static void print(OpAsmPrinter &p, AffineForOp op) { |
1421 | 0 | p << op.getOperationName() << ' '; |
1422 | 0 | p.printOperand(op.getBody()->getArgument(0)); |
1423 | 0 | p << " = "; |
1424 | 0 | printBound(op.getLowerBoundMapAttr(), op.getLowerBoundOperands(), "max", p); |
1425 | 0 | p << " to "; |
1426 | 0 | printBound(op.getUpperBoundMapAttr(), op.getUpperBoundOperands(), "min", p); |
1427 | 0 |
|
1428 | 0 | if (op.getStep() != 1) |
1429 | 0 | p << " step " << op.getStep(); |
1430 | 0 | p.printRegion(op.region(), |
1431 | 0 | /*printEntryBlockArgs=*/false, |
1432 | 0 | /*printBlockTerminators=*/false); |
1433 | 0 | p.printOptionalAttrDict(op.getAttrs(), |
1434 | 0 | /*elidedAttrs=*/{op.getLowerBoundAttrName(), |
1435 | 0 | op.getUpperBoundAttrName(), |
1436 | 0 | op.getStepAttrName()}); |
1437 | 0 | } |
1438 | | |
1439 | | /// Fold the constant bounds of a loop. |
1440 | 0 | static LogicalResult foldLoopBounds(AffineForOp forOp) { |
1441 | 0 | auto foldLowerOrUpperBound = [&forOp](bool lower) { |
1442 | 0 | // Check to see if each of the operands is the result of a constant. If |
1443 | 0 | // so, get the value. If not, ignore it. |
1444 | 0 | SmallVector<Attribute, 8> operandConstants; |
1445 | 0 | auto boundOperands = |
1446 | 0 | lower ? forOp.getLowerBoundOperands() : forOp.getUpperBoundOperands(); |
1447 | 0 | for (auto operand : boundOperands) { |
1448 | 0 | Attribute operandCst; |
1449 | 0 | matchPattern(operand, m_Constant(&operandCst)); |
1450 | 0 | operandConstants.push_back(operandCst); |
1451 | 0 | } |
1452 | 0 |
|
1453 | 0 | AffineMap boundMap = |
1454 | 0 | lower ? forOp.getLowerBoundMap() : forOp.getUpperBoundMap(); |
1455 | 0 | assert(boundMap.getNumResults() >= 1 && |
1456 | 0 | "bound maps should have at least one result"); |
1457 | 0 | SmallVector<Attribute, 4> foldedResults; |
1458 | 0 | if (failed(boundMap.constantFold(operandConstants, foldedResults))) |
1459 | 0 | return failure(); |
1460 | 0 | |
1461 | 0 | // Compute the max or min as applicable over the results. |
1462 | 0 | assert(!foldedResults.empty() && "bounds should have at least one result"); |
1463 | 0 | auto maxOrMin = foldedResults[0].cast<IntegerAttr>().getValue(); |
1464 | 0 | for (unsigned i = 1, e = foldedResults.size(); i < e; i++) { |
1465 | 0 | auto foldedResult = foldedResults[i].cast<IntegerAttr>().getValue(); |
1466 | 0 | maxOrMin = lower ? llvm::APIntOps::smax(maxOrMin, foldedResult) |
1467 | 0 | : llvm::APIntOps::smin(maxOrMin, foldedResult); |
1468 | 0 | } |
1469 | 0 | lower ? forOp.setConstantLowerBound(maxOrMin.getSExtValue()) |
1470 | 0 | : forOp.setConstantUpperBound(maxOrMin.getSExtValue()); |
1471 | 0 | return success(); |
1472 | 0 | }; |
1473 | 0 |
|
1474 | 0 | // Try to fold the lower bound. |
1475 | 0 | bool folded = false; |
1476 | 0 | if (!forOp.hasConstantLowerBound()) |
1477 | 0 | folded |= succeeded(foldLowerOrUpperBound(/*lower=*/true)); |
1478 | 0 |
|
1479 | 0 | // Try to fold the upper bound. |
1480 | 0 | if (!forOp.hasConstantUpperBound()) |
1481 | 0 | folded |= succeeded(foldLowerOrUpperBound(/*lower=*/false)); |
1482 | 0 | return success(folded); |
1483 | 0 | } |
1484 | | |
1485 | | /// Canonicalize the bounds of the given loop. |
1486 | 0 | static LogicalResult canonicalizeLoopBounds(AffineForOp forOp) { |
1487 | 0 | SmallVector<Value, 4> lbOperands(forOp.getLowerBoundOperands()); |
1488 | 0 | SmallVector<Value, 4> ubOperands(forOp.getUpperBoundOperands()); |
1489 | 0 |
|
1490 | 0 | auto lbMap = forOp.getLowerBoundMap(); |
1491 | 0 | auto ubMap = forOp.getUpperBoundMap(); |
1492 | 0 | auto prevLbMap = lbMap; |
1493 | 0 | auto prevUbMap = ubMap; |
1494 | 0 |
|
1495 | 0 | canonicalizeMapAndOperands(&lbMap, &lbOperands); |
1496 | 0 | lbMap = removeDuplicateExprs(lbMap); |
1497 | 0 |
|
1498 | 0 | canonicalizeMapAndOperands(&ubMap, &ubOperands); |
1499 | 0 | ubMap = removeDuplicateExprs(ubMap); |
1500 | 0 |
|
1501 | 0 | // Any canonicalization change always leads to updated map(s). |
1502 | 0 | if (lbMap == prevLbMap && ubMap == prevUbMap) |
1503 | 0 | return failure(); |
1504 | 0 | |
1505 | 0 | if (lbMap != prevLbMap) |
1506 | 0 | forOp.setLowerBound(lbOperands, lbMap); |
1507 | 0 | if (ubMap != prevUbMap) |
1508 | 0 | forOp.setUpperBound(ubOperands, ubMap); |
1509 | 0 | return success(); |
1510 | 0 | } |
1511 | | |
1512 | | namespace { |
1513 | | /// This is a pattern to fold trivially empty loops. |
1514 | | struct AffineForEmptyLoopFolder : public OpRewritePattern<AffineForOp> { |
1515 | | using OpRewritePattern<AffineForOp>::OpRewritePattern; |
1516 | | |
1517 | | LogicalResult matchAndRewrite(AffineForOp forOp, |
1518 | 0 | PatternRewriter &rewriter) const override { |
1519 | 0 | // Check that the body only contains a terminator. |
1520 | 0 | if (!llvm::hasSingleElement(*forOp.getBody())) |
1521 | 0 | return failure(); |
1522 | 0 | rewriter.eraseOp(forOp); |
1523 | 0 | return success(); |
1524 | 0 | } |
1525 | | }; |
1526 | | } // end anonymous namespace |
1527 | | |
1528 | | void AffineForOp::getCanonicalizationPatterns(OwningRewritePatternList &results, |
1529 | 0 | MLIRContext *context) { |
1530 | 0 | results.insert<AffineForEmptyLoopFolder>(context); |
1531 | 0 | } |
1532 | | |
1533 | | LogicalResult AffineForOp::fold(ArrayRef<Attribute> operands, |
1534 | 0 | SmallVectorImpl<OpFoldResult> &results) { |
1535 | 0 | bool folded = succeeded(foldLoopBounds(*this)); |
1536 | 0 | folded |= succeeded(canonicalizeLoopBounds(*this)); |
1537 | 0 | return success(folded); |
1538 | 0 | } |
1539 | | |
1540 | 0 | AffineBound AffineForOp::getLowerBound() { |
1541 | 0 | auto lbMap = getLowerBoundMap(); |
1542 | 0 | return AffineBound(AffineForOp(*this), 0, lbMap.getNumInputs(), lbMap); |
1543 | 0 | } |
1544 | | |
1545 | 0 | AffineBound AffineForOp::getUpperBound() { |
1546 | 0 | auto lbMap = getLowerBoundMap(); |
1547 | 0 | auto ubMap = getUpperBoundMap(); |
1548 | 0 | return AffineBound(AffineForOp(*this), lbMap.getNumInputs(), getNumOperands(), |
1549 | 0 | ubMap); |
1550 | 0 | } |
1551 | | |
1552 | 0 | void AffineForOp::setLowerBound(ValueRange lbOperands, AffineMap map) { |
1553 | 0 | assert(lbOperands.size() == map.getNumInputs()); |
1554 | 0 | assert(map.getNumResults() >= 1 && "bound map has at least one result"); |
1555 | 0 |
|
1556 | 0 | SmallVector<Value, 4> newOperands(lbOperands.begin(), lbOperands.end()); |
1557 | 0 |
|
1558 | 0 | auto ubOperands = getUpperBoundOperands(); |
1559 | 0 | newOperands.append(ubOperands.begin(), ubOperands.end()); |
1560 | 0 | getOperation()->setOperands(newOperands); |
1561 | 0 |
|
1562 | 0 | setAttr(getLowerBoundAttrName(), AffineMapAttr::get(map)); |
1563 | 0 | } |
1564 | | |
1565 | 0 | void AffineForOp::setUpperBound(ValueRange ubOperands, AffineMap map) { |
1566 | 0 | assert(ubOperands.size() == map.getNumInputs()); |
1567 | 0 | assert(map.getNumResults() >= 1 && "bound map has at least one result"); |
1568 | 0 |
|
1569 | 0 | SmallVector<Value, 4> newOperands(getLowerBoundOperands()); |
1570 | 0 | newOperands.append(ubOperands.begin(), ubOperands.end()); |
1571 | 0 | getOperation()->setOperands(newOperands); |
1572 | 0 |
|
1573 | 0 | setAttr(getUpperBoundAttrName(), AffineMapAttr::get(map)); |
1574 | 0 | } |
1575 | | |
1576 | 0 | void AffineForOp::setLowerBoundMap(AffineMap map) { |
1577 | 0 | auto lbMap = getLowerBoundMap(); |
1578 | 0 | assert(lbMap.getNumDims() == map.getNumDims() && |
1579 | 0 | lbMap.getNumSymbols() == map.getNumSymbols()); |
1580 | 0 | assert(map.getNumResults() >= 1 && "bound map has at least one result"); |
1581 | 0 | (void)lbMap; |
1582 | 0 | setAttr(getLowerBoundAttrName(), AffineMapAttr::get(map)); |
1583 | 0 | } |
1584 | | |
1585 | 0 | void AffineForOp::setUpperBoundMap(AffineMap map) { |
1586 | 0 | auto ubMap = getUpperBoundMap(); |
1587 | 0 | assert(ubMap.getNumDims() == map.getNumDims() && |
1588 | 0 | ubMap.getNumSymbols() == map.getNumSymbols()); |
1589 | 0 | assert(map.getNumResults() >= 1 && "bound map has at least one result"); |
1590 | 0 | (void)ubMap; |
1591 | 0 | setAttr(getUpperBoundAttrName(), AffineMapAttr::get(map)); |
1592 | 0 | } |
1593 | | |
1594 | 0 | bool AffineForOp::hasConstantLowerBound() { |
1595 | 0 | return getLowerBoundMap().isSingleConstant(); |
1596 | 0 | } |
1597 | | |
1598 | 0 | bool AffineForOp::hasConstantUpperBound() { |
1599 | 0 | return getUpperBoundMap().isSingleConstant(); |
1600 | 0 | } |
1601 | | |
1602 | 0 | int64_t AffineForOp::getConstantLowerBound() { |
1603 | 0 | return getLowerBoundMap().getSingleConstantResult(); |
1604 | 0 | } |
1605 | | |
1606 | 0 | int64_t AffineForOp::getConstantUpperBound() { |
1607 | 0 | return getUpperBoundMap().getSingleConstantResult(); |
1608 | 0 | } |
1609 | | |
1610 | 0 | void AffineForOp::setConstantLowerBound(int64_t value) { |
1611 | 0 | setLowerBound({}, AffineMap::getConstantMap(value, getContext())); |
1612 | 0 | } |
1613 | | |
1614 | 0 | void AffineForOp::setConstantUpperBound(int64_t value) { |
1615 | 0 | setUpperBound({}, AffineMap::getConstantMap(value, getContext())); |
1616 | 0 | } |
1617 | | |
1618 | 0 | AffineForOp::operand_range AffineForOp::getLowerBoundOperands() { |
1619 | 0 | return {operand_begin(), operand_begin() + getLowerBoundMap().getNumInputs()}; |
1620 | 0 | } |
1621 | | |
1622 | 0 | AffineForOp::operand_range AffineForOp::getUpperBoundOperands() { |
1623 | 0 | return {operand_begin() + getLowerBoundMap().getNumInputs(), operand_end()}; |
1624 | 0 | } |
1625 | | |
1626 | 0 | bool AffineForOp::matchingBoundOperandList() { |
1627 | 0 | auto lbMap = getLowerBoundMap(); |
1628 | 0 | auto ubMap = getUpperBoundMap(); |
1629 | 0 | if (lbMap.getNumDims() != ubMap.getNumDims() || |
1630 | 0 | lbMap.getNumSymbols() != ubMap.getNumSymbols()) |
1631 | 0 | return false; |
1632 | 0 | |
1633 | 0 | unsigned numOperands = lbMap.getNumInputs(); |
1634 | 0 | for (unsigned i = 0, e = lbMap.getNumInputs(); i < e; i++) { |
1635 | 0 | // Compare Value 's. |
1636 | 0 | if (getOperand(i) != getOperand(numOperands + i)) |
1637 | 0 | return false; |
1638 | 0 | } |
1639 | 0 | return true; |
1640 | 0 | } |
1641 | | |
1642 | 0 | Region &AffineForOp::getLoopBody() { return region(); } |
1643 | | |
1644 | 0 | bool AffineForOp::isDefinedOutsideOfLoop(Value value) { |
1645 | 0 | return !region().isAncestor(value.getParentRegion()); |
1646 | 0 | } |
1647 | | |
1648 | 0 | LogicalResult AffineForOp::moveOutOfLoop(ArrayRef<Operation *> ops) { |
1649 | 0 | for (auto *op : ops) |
1650 | 0 | op->moveBefore(*this); |
1651 | 0 | return success(); |
1652 | 0 | } |
1653 | | |
1654 | | /// Returns if the provided value is the induction variable of a AffineForOp. |
1655 | 0 | bool mlir::isForInductionVar(Value val) { |
1656 | 0 | return getForInductionVarOwner(val) != AffineForOp(); |
1657 | 0 | } |
1658 | | |
1659 | | /// Returns the loop parent of an induction variable. If the provided value is |
1660 | | /// not an induction variable, then return nullptr. |
1661 | 0 | AffineForOp mlir::getForInductionVarOwner(Value val) { |
1662 | 0 | auto ivArg = val.dyn_cast<BlockArgument>(); |
1663 | 0 | if (!ivArg || !ivArg.getOwner()) |
1664 | 0 | return AffineForOp(); |
1665 | 0 | auto *containingInst = ivArg.getOwner()->getParent()->getParentOp(); |
1666 | 0 | return dyn_cast<AffineForOp>(containingInst); |
1667 | 0 | } |
1668 | | |
1669 | | /// Extracts the induction variables from a list of AffineForOps and returns |
1670 | | /// them. |
1671 | | void mlir::extractForInductionVars(ArrayRef<AffineForOp> forInsts, |
1672 | 0 | SmallVectorImpl<Value> *ivs) { |
1673 | 0 | ivs->reserve(forInsts.size()); |
1674 | 0 | for (auto forInst : forInsts) |
1675 | 0 | ivs->push_back(forInst.getInductionVar()); |
1676 | 0 | } |
1677 | | |
1678 | | //===----------------------------------------------------------------------===// |
1679 | | // AffineIfOp |
1680 | | //===----------------------------------------------------------------------===// |
1681 | | |
1682 | | namespace { |
1683 | | /// Remove else blocks that have nothing other than the terminator. |
1684 | | struct SimplifyDeadElse : public OpRewritePattern<AffineIfOp> { |
1685 | | using OpRewritePattern<AffineIfOp>::OpRewritePattern; |
1686 | | |
1687 | | LogicalResult matchAndRewrite(AffineIfOp ifOp, |
1688 | 0 | PatternRewriter &rewriter) const override { |
1689 | 0 | if (ifOp.elseRegion().empty() || |
1690 | 0 | !llvm::hasSingleElement(*ifOp.getElseBlock())) |
1691 | 0 | return failure(); |
1692 | 0 | |
1693 | 0 | rewriter.startRootUpdate(ifOp); |
1694 | 0 | rewriter.eraseBlock(ifOp.getElseBlock()); |
1695 | 0 | rewriter.finalizeRootUpdate(ifOp); |
1696 | 0 | return success(); |
1697 | 0 | } |
1698 | | }; |
1699 | | } // end anonymous namespace. |
1700 | | |
1701 | 0 | static LogicalResult verify(AffineIfOp op) { |
1702 | 0 | // Verify that we have a condition attribute. |
1703 | 0 | auto conditionAttr = |
1704 | 0 | op.getAttrOfType<IntegerSetAttr>(op.getConditionAttrName()); |
1705 | 0 | if (!conditionAttr) |
1706 | 0 | return op.emitOpError( |
1707 | 0 | "requires an integer set attribute named 'condition'"); |
1708 | 0 | |
1709 | 0 | // Verify that there are enough operands for the condition. |
1710 | 0 | IntegerSet condition = conditionAttr.getValue(); |
1711 | 0 | if (op.getNumOperands() != condition.getNumInputs()) |
1712 | 0 | return op.emitOpError( |
1713 | 0 | "operand count and condition integer set dimension and " |
1714 | 0 | "symbol count must match"); |
1715 | 0 | |
1716 | 0 | // Verify that the operands are valid dimension/symbols. |
1717 | 0 | if (failed(verifyDimAndSymbolIdentifiers(op, op.getOperands(), |
1718 | 0 | condition.getNumDims()))) |
1719 | 0 | return failure(); |
1720 | 0 | |
1721 | 0 | // Verify that the entry of each child region does not have arguments. |
1722 | 0 | for (auto ®ion : op.getOperation()->getRegions()) { |
1723 | 0 | for (auto &b : region) |
1724 | 0 | if (b.getNumArguments() != 0) |
1725 | 0 | return op.emitOpError( |
1726 | 0 | "requires that child entry blocks have no arguments"); |
1727 | 0 | } |
1728 | 0 | return success(); |
1729 | 0 | } |
1730 | | |
1731 | | static ParseResult parseAffineIfOp(OpAsmParser &parser, |
1732 | 0 | OperationState &result) { |
1733 | 0 | // Parse the condition attribute set. |
1734 | 0 | IntegerSetAttr conditionAttr; |
1735 | 0 | unsigned numDims; |
1736 | 0 | if (parser.parseAttribute(conditionAttr, AffineIfOp::getConditionAttrName(), |
1737 | 0 | result.attributes) || |
1738 | 0 | parseDimAndSymbolList(parser, result.operands, numDims)) |
1739 | 0 | return failure(); |
1740 | 0 | |
1741 | 0 | // Verify the condition operands. |
1742 | 0 | auto set = conditionAttr.getValue(); |
1743 | 0 | if (set.getNumDims() != numDims) |
1744 | 0 | return parser.emitError( |
1745 | 0 | parser.getNameLoc(), |
1746 | 0 | "dim operand count and integer set dim count must match"); |
1747 | 0 | if (numDims + set.getNumSymbols() != result.operands.size()) |
1748 | 0 | return parser.emitError( |
1749 | 0 | parser.getNameLoc(), |
1750 | 0 | "symbol operand count and integer set symbol count must match"); |
1751 | 0 | |
1752 | 0 | // Create the regions for 'then' and 'else'. The latter must be created even |
1753 | 0 | // if it remains empty for the validity of the operation. |
1754 | 0 | result.regions.reserve(2); |
1755 | 0 | Region *thenRegion = result.addRegion(); |
1756 | 0 | Region *elseRegion = result.addRegion(); |
1757 | 0 |
|
1758 | 0 | // Parse the 'then' region. |
1759 | 0 | if (parser.parseRegion(*thenRegion, {}, {})) |
1760 | 0 | return failure(); |
1761 | 0 | AffineIfOp::ensureTerminator(*thenRegion, parser.getBuilder(), |
1762 | 0 | result.location); |
1763 | 0 |
|
1764 | 0 | // If we find an 'else' keyword then parse the 'else' region. |
1765 | 0 | if (!parser.parseOptionalKeyword("else")) { |
1766 | 0 | if (parser.parseRegion(*elseRegion, {}, {})) |
1767 | 0 | return failure(); |
1768 | 0 | AffineIfOp::ensureTerminator(*elseRegion, parser.getBuilder(), |
1769 | 0 | result.location); |
1770 | 0 | } |
1771 | 0 |
|
1772 | 0 | // Parse the optional attribute list. |
1773 | 0 | if (parser.parseOptionalAttrDict(result.attributes)) |
1774 | 0 | return failure(); |
1775 | 0 | |
1776 | 0 | return success(); |
1777 | 0 | } |
1778 | | |
1779 | 0 | static void print(OpAsmPrinter &p, AffineIfOp op) { |
1780 | 0 | auto conditionAttr = |
1781 | 0 | op.getAttrOfType<IntegerSetAttr>(op.getConditionAttrName()); |
1782 | 0 | p << "affine.if " << conditionAttr; |
1783 | 0 | printDimAndSymbolList(op.operand_begin(), op.operand_end(), |
1784 | 0 | conditionAttr.getValue().getNumDims(), p); |
1785 | 0 | p.printRegion(op.thenRegion(), |
1786 | 0 | /*printEntryBlockArgs=*/false, |
1787 | 0 | /*printBlockTerminators=*/false); |
1788 | 0 |
|
1789 | 0 | // Print the 'else' regions if it has any blocks. |
1790 | 0 | auto &elseRegion = op.elseRegion(); |
1791 | 0 | if (!elseRegion.empty()) { |
1792 | 0 | p << " else"; |
1793 | 0 | p.printRegion(elseRegion, |
1794 | 0 | /*printEntryBlockArgs=*/false, |
1795 | 0 | /*printBlockTerminators=*/false); |
1796 | 0 | } |
1797 | 0 |
|
1798 | 0 | // Print the attribute list. |
1799 | 0 | p.printOptionalAttrDict(op.getAttrs(), |
1800 | 0 | /*elidedAttrs=*/op.getConditionAttrName()); |
1801 | 0 | } |
1802 | | |
1803 | 0 | IntegerSet AffineIfOp::getIntegerSet() { |
1804 | 0 | return getAttrOfType<IntegerSetAttr>(getConditionAttrName()).getValue(); |
1805 | 0 | } |
1806 | 0 | void AffineIfOp::setIntegerSet(IntegerSet newSet) { |
1807 | 0 | setAttr(getConditionAttrName(), IntegerSetAttr::get(newSet)); |
1808 | 0 | } |
1809 | | |
1810 | 0 | void AffineIfOp::setConditional(IntegerSet set, ValueRange operands) { |
1811 | 0 | setIntegerSet(set); |
1812 | 0 | getOperation()->setOperands(operands); |
1813 | 0 | } |
1814 | | |
1815 | | void AffineIfOp::build(OpBuilder &builder, OperationState &result, |
1816 | 0 | IntegerSet set, ValueRange args, bool withElseRegion) { |
1817 | 0 | result.addOperands(args); |
1818 | 0 | result.addAttribute(getConditionAttrName(), IntegerSetAttr::get(set)); |
1819 | 0 | Region *thenRegion = result.addRegion(); |
1820 | 0 | Region *elseRegion = result.addRegion(); |
1821 | 0 | AffineIfOp::ensureTerminator(*thenRegion, builder, result.location); |
1822 | 0 | if (withElseRegion) |
1823 | 0 | AffineIfOp::ensureTerminator(*elseRegion, builder, result.location); |
1824 | 0 | } |
1825 | | |
1826 | | /// Canonicalize an affine if op's conditional (integer set + operands). |
1827 | | LogicalResult AffineIfOp::fold(ArrayRef<Attribute>, |
1828 | 0 | SmallVectorImpl<OpFoldResult> &) { |
1829 | 0 | auto set = getIntegerSet(); |
1830 | 0 | SmallVector<Value, 4> operands(getOperands()); |
1831 | 0 | canonicalizeSetAndOperands(&set, &operands); |
1832 | 0 |
|
1833 | 0 | // Any canonicalization change always leads to either a reduction in the |
1834 | 0 | // number of operands or a change in the number of symbolic operands |
1835 | 0 | // (promotion of dims to symbols). |
1836 | 0 | if (operands.size() < getIntegerSet().getNumInputs() || |
1837 | 0 | set.getNumSymbols() > getIntegerSet().getNumSymbols()) { |
1838 | 0 | setConditional(set, operands); |
1839 | 0 | return success(); |
1840 | 0 | } |
1841 | 0 | |
1842 | 0 | return failure(); |
1843 | 0 | } |
1844 | | |
1845 | | void AffineIfOp::getCanonicalizationPatterns(OwningRewritePatternList &results, |
1846 | 0 | MLIRContext *context) { |
1847 | 0 | results.insert<SimplifyDeadElse>(context); |
1848 | 0 | } |
1849 | | |
1850 | | //===----------------------------------------------------------------------===// |
1851 | | // AffineLoadOp |
1852 | | //===----------------------------------------------------------------------===// |
1853 | | |
1854 | | void AffineLoadOp::build(OpBuilder &builder, OperationState &result, |
1855 | 0 | AffineMap map, ValueRange operands) { |
1856 | 0 | assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands"); |
1857 | 0 | result.addOperands(operands); |
1858 | 0 | if (map) |
1859 | 0 | result.addAttribute(getMapAttrName(), AffineMapAttr::get(map)); |
1860 | 0 | auto memrefType = operands[0].getType().cast<MemRefType>(); |
1861 | 0 | result.types.push_back(memrefType.getElementType()); |
1862 | 0 | } |
1863 | | |
1864 | | void AffineLoadOp::build(OpBuilder &builder, OperationState &result, |
1865 | 0 | Value memref, AffineMap map, ValueRange mapOperands) { |
1866 | 0 | assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info"); |
1867 | 0 | result.addOperands(memref); |
1868 | 0 | result.addOperands(mapOperands); |
1869 | 0 | auto memrefType = memref.getType().cast<MemRefType>(); |
1870 | 0 | result.addAttribute(getMapAttrName(), AffineMapAttr::get(map)); |
1871 | 0 | result.types.push_back(memrefType.getElementType()); |
1872 | 0 | } |
1873 | | |
1874 | | void AffineLoadOp::build(OpBuilder &builder, OperationState &result, |
1875 | 0 | Value memref, ValueRange indices) { |
1876 | 0 | auto memrefType = memref.getType().cast<MemRefType>(); |
1877 | 0 | auto rank = memrefType.getRank(); |
1878 | 0 | // Create identity map for memrefs with at least one dimension or () -> () |
1879 | 0 | // for zero-dimensional memrefs. |
1880 | 0 | auto map = |
1881 | 0 | rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap(); |
1882 | 0 | build(builder, result, memref, map, indices); |
1883 | 0 | } |
1884 | | |
1885 | | static ParseResult parseAffineLoadOp(OpAsmParser &parser, |
1886 | 0 | OperationState &result) { |
1887 | 0 | auto &builder = parser.getBuilder(); |
1888 | 0 | auto indexTy = builder.getIndexType(); |
1889 | 0 |
|
1890 | 0 | MemRefType type; |
1891 | 0 | OpAsmParser::OperandType memrefInfo; |
1892 | 0 | AffineMapAttr mapAttr; |
1893 | 0 | SmallVector<OpAsmParser::OperandType, 1> mapOperands; |
1894 | 0 | return failure( |
1895 | 0 | parser.parseOperand(memrefInfo) || |
1896 | 0 | parser.parseAffineMapOfSSAIds(mapOperands, mapAttr, |
1897 | 0 | AffineLoadOp::getMapAttrName(), |
1898 | 0 | result.attributes) || |
1899 | 0 | parser.parseOptionalAttrDict(result.attributes) || |
1900 | 0 | parser.parseColonType(type) || |
1901 | 0 | parser.resolveOperand(memrefInfo, type, result.operands) || |
1902 | 0 | parser.resolveOperands(mapOperands, indexTy, result.operands) || |
1903 | 0 | parser.addTypeToList(type.getElementType(), result.types)); |
1904 | 0 | } |
1905 | | |
1906 | 0 | static void print(OpAsmPrinter &p, AffineLoadOp op) { |
1907 | 0 | p << "affine.load " << op.getMemRef() << '['; |
1908 | 0 | if (AffineMapAttr mapAttr = |
1909 | 0 | op.getAttrOfType<AffineMapAttr>(op.getMapAttrName())) |
1910 | 0 | p.printAffineMapOfSSAIds(mapAttr, op.getMapOperands()); |
1911 | 0 | p << ']'; |
1912 | 0 | p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{op.getMapAttrName()}); |
1913 | 0 | p << " : " << op.getMemRefType(); |
1914 | 0 | } |
1915 | | |
1916 | | /// Verify common indexing invariants of affine.load, affine.store, |
1917 | | /// affine.vector_load and affine.vector_store. |
1918 | | static LogicalResult |
1919 | | verifyMemoryOpIndexing(Operation *op, AffineMapAttr mapAttr, |
1920 | | Operation::operand_range mapOperands, |
1921 | 0 | MemRefType memrefType, unsigned numIndexOperands) { |
1922 | 0 | if (mapAttr) { |
1923 | 0 | AffineMap map = mapAttr.getValue(); |
1924 | 0 | if (map.getNumResults() != memrefType.getRank()) |
1925 | 0 | return op->emitOpError("affine map num results must equal memref rank"); |
1926 | 0 | if (map.getNumInputs() != numIndexOperands) |
1927 | 0 | return op->emitOpError("expects as many subscripts as affine map inputs"); |
1928 | 0 | } else { |
1929 | 0 | if (memrefType.getRank() != numIndexOperands) |
1930 | 0 | return op->emitOpError( |
1931 | 0 | "expects the number of subscripts to be equal to memref rank"); |
1932 | 0 | } |
1933 | 0 | |
1934 | 0 | Region *scope = getAffineScope(op); |
1935 | 0 | for (auto idx : mapOperands) { |
1936 | 0 | if (!idx.getType().isIndex()) |
1937 | 0 | return op->emitOpError("index to load must have 'index' type"); |
1938 | 0 | if (!isValidAffineIndexOperand(idx, scope)) |
1939 | 0 | return op->emitOpError("index must be a dimension or symbol identifier"); |
1940 | 0 | } |
1941 | 0 |
|
1942 | 0 | return success(); |
1943 | 0 | } |
1944 | | |
1945 | 0 | LogicalResult verify(AffineLoadOp op) { |
1946 | 0 | auto memrefType = op.getMemRefType(); |
1947 | 0 | if (op.getType() != memrefType.getElementType()) |
1948 | 0 | return op.emitOpError("result type must match element type of memref"); |
1949 | 0 | |
1950 | 0 | if (failed(verifyMemoryOpIndexing( |
1951 | 0 | op.getOperation(), |
1952 | 0 | op.getAttrOfType<AffineMapAttr>(op.getMapAttrName()), |
1953 | 0 | op.getMapOperands(), memrefType, |
1954 | 0 | /*numIndexOperands=*/op.getNumOperands() - 1))) |
1955 | 0 | return failure(); |
1956 | 0 | |
1957 | 0 | return success(); |
1958 | 0 | } |
1959 | | |
1960 | | void AffineLoadOp::getCanonicalizationPatterns( |
1961 | 0 | OwningRewritePatternList &results, MLIRContext *context) { |
1962 | 0 | results.insert<SimplifyAffineOp<AffineLoadOp>>(context); |
1963 | 0 | } |
1964 | | |
1965 | 0 | OpFoldResult AffineLoadOp::fold(ArrayRef<Attribute> cstOperands) { |
1966 | 0 | /// load(memrefcast) -> load |
1967 | 0 | if (succeeded(foldMemRefCast(*this))) |
1968 | 0 | return getResult(); |
1969 | 0 | return OpFoldResult(); |
1970 | 0 | } |
1971 | | |
1972 | | //===----------------------------------------------------------------------===// |
1973 | | // AffineStoreOp |
1974 | | //===----------------------------------------------------------------------===// |
1975 | | |
1976 | | void AffineStoreOp::build(OpBuilder &builder, OperationState &result, |
1977 | | Value valueToStore, Value memref, AffineMap map, |
1978 | 0 | ValueRange mapOperands) { |
1979 | 0 | assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info"); |
1980 | 0 | result.addOperands(valueToStore); |
1981 | 0 | result.addOperands(memref); |
1982 | 0 | result.addOperands(mapOperands); |
1983 | 0 | result.addAttribute(getMapAttrName(), AffineMapAttr::get(map)); |
1984 | 0 | } |
1985 | | |
1986 | | // Use identity map. |
1987 | | void AffineStoreOp::build(OpBuilder &builder, OperationState &result, |
1988 | | Value valueToStore, Value memref, |
1989 | 0 | ValueRange indices) { |
1990 | 0 | auto memrefType = memref.getType().cast<MemRefType>(); |
1991 | 0 | auto rank = memrefType.getRank(); |
1992 | 0 | // Create identity map for memrefs with at least one dimension or () -> () |
1993 | 0 | // for zero-dimensional memrefs. |
1994 | 0 | auto map = |
1995 | 0 | rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap(); |
1996 | 0 | build(builder, result, valueToStore, memref, map, indices); |
1997 | 0 | } |
1998 | | |
1999 | | static ParseResult parseAffineStoreOp(OpAsmParser &parser, |
2000 | 0 | OperationState &result) { |
2001 | 0 | auto indexTy = parser.getBuilder().getIndexType(); |
2002 | 0 |
|
2003 | 0 | MemRefType type; |
2004 | 0 | OpAsmParser::OperandType storeValueInfo; |
2005 | 0 | OpAsmParser::OperandType memrefInfo; |
2006 | 0 | AffineMapAttr mapAttr; |
2007 | 0 | SmallVector<OpAsmParser::OperandType, 1> mapOperands; |
2008 | 0 | return failure(parser.parseOperand(storeValueInfo) || parser.parseComma() || |
2009 | 0 | parser.parseOperand(memrefInfo) || |
2010 | 0 | parser.parseAffineMapOfSSAIds(mapOperands, mapAttr, |
2011 | 0 | AffineStoreOp::getMapAttrName(), |
2012 | 0 | result.attributes) || |
2013 | 0 | parser.parseOptionalAttrDict(result.attributes) || |
2014 | 0 | parser.parseColonType(type) || |
2015 | 0 | parser.resolveOperand(storeValueInfo, type.getElementType(), |
2016 | 0 | result.operands) || |
2017 | 0 | parser.resolveOperand(memrefInfo, type, result.operands) || |
2018 | 0 | parser.resolveOperands(mapOperands, indexTy, result.operands)); |
2019 | 0 | } |
2020 | | |
2021 | 0 | static void print(OpAsmPrinter &p, AffineStoreOp op) { |
2022 | 0 | p << "affine.store " << op.getValueToStore(); |
2023 | 0 | p << ", " << op.getMemRef() << '['; |
2024 | 0 | if (AffineMapAttr mapAttr = |
2025 | 0 | op.getAttrOfType<AffineMapAttr>(op.getMapAttrName())) |
2026 | 0 | p.printAffineMapOfSSAIds(mapAttr, op.getMapOperands()); |
2027 | 0 | p << ']'; |
2028 | 0 | p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{op.getMapAttrName()}); |
2029 | 0 | p << " : " << op.getMemRefType(); |
2030 | 0 | } |
2031 | | |
2032 | 0 | LogicalResult verify(AffineStoreOp op) { |
2033 | 0 | // First operand must have same type as memref element type. |
2034 | 0 | auto memrefType = op.getMemRefType(); |
2035 | 0 | if (op.getValueToStore().getType() != memrefType.getElementType()) |
2036 | 0 | return op.emitOpError( |
2037 | 0 | "first operand must have same type memref element type"); |
2038 | 0 | |
2039 | 0 | if (failed(verifyMemoryOpIndexing( |
2040 | 0 | op.getOperation(), |
2041 | 0 | op.getAttrOfType<AffineMapAttr>(op.getMapAttrName()), |
2042 | 0 | op.getMapOperands(), memrefType, |
2043 | 0 | /*numIndexOperands=*/op.getNumOperands() - 2))) |
2044 | 0 | return failure(); |
2045 | 0 | |
2046 | 0 | return success(); |
2047 | 0 | } |
2048 | | |
2049 | | void AffineStoreOp::getCanonicalizationPatterns( |
2050 | 0 | OwningRewritePatternList &results, MLIRContext *context) { |
2051 | 0 | results.insert<SimplifyAffineOp<AffineStoreOp>>(context); |
2052 | 0 | } |
2053 | | |
2054 | | LogicalResult AffineStoreOp::fold(ArrayRef<Attribute> cstOperands, |
2055 | 0 | SmallVectorImpl<OpFoldResult> &results) { |
2056 | 0 | /// store(memrefcast) -> store |
2057 | 0 | return foldMemRefCast(*this); |
2058 | 0 | } |
2059 | | |
2060 | | //===----------------------------------------------------------------------===// |
2061 | | // AffineMinMaxOpBase |
2062 | | //===----------------------------------------------------------------------===// |
2063 | | |
2064 | | template <typename T> |
2065 | 0 | static LogicalResult verifyAffineMinMaxOp(T op) { |
2066 | 0 | // Verify that operand count matches affine map dimension and symbol count. |
2067 | 0 | if (op.getNumOperands() != op.map().getNumDims() + op.map().getNumSymbols()) |
2068 | 0 | return op.emitOpError( |
2069 | 0 | "operand count and affine map dimension and symbol count must match"); |
2070 | 0 | return success(); |
2071 | 0 | } Unexecuted instantiation: AffineOps.cpp:_ZL20verifyAffineMinMaxOpIN4mlir11AffineMaxOpEENS0_13LogicalResultET_ Unexecuted instantiation: AffineOps.cpp:_ZL20verifyAffineMinMaxOpIN4mlir11AffineMinOpEENS0_13LogicalResultET_ |
2072 | | |
2073 | | template <typename T> |
2074 | 0 | static void printAffineMinMaxOp(OpAsmPrinter &p, T op) { |
2075 | 0 | p << op.getOperationName() << ' ' << op.getAttr(T::getMapAttrName()); |
2076 | 0 | auto operands = op.getOperands(); |
2077 | 0 | unsigned numDims = op.map().getNumDims(); |
2078 | 0 | p << '(' << operands.take_front(numDims) << ')'; |
2079 | 0 |
|
2080 | 0 | if (operands.size() != numDims) |
2081 | 0 | p << '[' << operands.drop_front(numDims) << ']'; |
2082 | 0 | p.printOptionalAttrDict(op.getAttrs(), |
2083 | 0 | /*elidedAttrs=*/{T::getMapAttrName()}); |
2084 | 0 | } Unexecuted instantiation: AffineOps.cpp:_ZL19printAffineMinMaxOpIN4mlir11AffineMaxOpEEvRNS0_12OpAsmPrinterET_ Unexecuted instantiation: AffineOps.cpp:_ZL19printAffineMinMaxOpIN4mlir11AffineMinOpEEvRNS0_12OpAsmPrinterET_ |
2085 | | |
2086 | | template <typename T> |
2087 | | static ParseResult parseAffineMinMaxOp(OpAsmParser &parser, |
2088 | 0 | OperationState &result) { |
2089 | 0 | auto &builder = parser.getBuilder(); |
2090 | 0 | auto indexType = builder.getIndexType(); |
2091 | 0 | SmallVector<OpAsmParser::OperandType, 8> dim_infos; |
2092 | 0 | SmallVector<OpAsmParser::OperandType, 8> sym_infos; |
2093 | 0 | AffineMapAttr mapAttr; |
2094 | 0 | return failure( |
2095 | 0 | parser.parseAttribute(mapAttr, T::getMapAttrName(), result.attributes) || |
2096 | 0 | parser.parseOperandList(dim_infos, OpAsmParser::Delimiter::Paren) || |
2097 | 0 | parser.parseOperandList(sym_infos, |
2098 | 0 | OpAsmParser::Delimiter::OptionalSquare) || |
2099 | 0 | parser.parseOptionalAttrDict(result.attributes) || |
2100 | 0 | parser.resolveOperands(dim_infos, indexType, result.operands) || |
2101 | 0 | parser.resolveOperands(sym_infos, indexType, result.operands) || |
2102 | 0 | parser.addTypeToList(indexType, result.types)); |
2103 | 0 | } Unexecuted instantiation: AffineOps.cpp:_ZL19parseAffineMinMaxOpIN4mlir11AffineMaxOpEENS0_11ParseResultERNS0_11OpAsmParserERNS0_14OperationStateE Unexecuted instantiation: AffineOps.cpp:_ZL19parseAffineMinMaxOpIN4mlir11AffineMinOpEENS0_11ParseResultERNS0_11OpAsmParserERNS0_14OperationStateE |
2104 | | |
2105 | | /// Fold an affine min or max operation with the given operands. The operand |
2106 | | /// list may contain nulls, which are interpreted as the operand not being a |
2107 | | /// constant. |
2108 | | template <typename T> |
2109 | 0 | static OpFoldResult foldMinMaxOp(T op, ArrayRef<Attribute> operands) { |
2110 | 0 | static_assert(llvm::is_one_of<T, AffineMinOp, AffineMaxOp>::value, |
2111 | 0 | "expected affine min or max op"); |
2112 | 0 |
|
2113 | 0 | // Fold the affine map. |
2114 | 0 | // TODO(andydavis, ntv) Fold more cases: |
2115 | 0 | // min(some_affine, some_affine + constant, ...), etc. |
2116 | 0 | SmallVector<int64_t, 2> results; |
2117 | 0 | auto foldedMap = op.map().partialConstantFold(operands, &results); |
2118 | 0 |
|
2119 | 0 | // If some of the map results are not constant, try changing the map in-place. |
2120 | 0 | if (results.empty()) { |
2121 | 0 | // If the map is the same, report that folding did not happen. |
2122 | 0 | if (foldedMap == op.map()) |
2123 | 0 | return {}; |
2124 | 0 | op.setAttr("map", AffineMapAttr::get(foldedMap)); |
2125 | 0 | return op.getResult(); |
2126 | 0 | } |
2127 | 0 | |
2128 | 0 | // Otherwise, completely fold the op into a constant. |
2129 | 0 | auto resultIt = std::is_same<T, AffineMinOp>::value |
2130 | 0 | ? std::min_element(results.begin(), results.end()) |
2131 | 0 | : std::max_element(results.begin(), results.end()); |
2132 | 0 | if (resultIt == results.end()) |
2133 | 0 | return {}; |
2134 | 0 | return IntegerAttr::get(IndexType::get(op.getContext()), *resultIt); |
2135 | 0 | } Unexecuted instantiation: AffineOps.cpp:_ZL12foldMinMaxOpIN4mlir11AffineMinOpEENS0_12OpFoldResultET_N4llvm8ArrayRefINS0_9AttributeEEE Unexecuted instantiation: AffineOps.cpp:_ZL12foldMinMaxOpIN4mlir11AffineMaxOpEENS0_12OpFoldResultET_N4llvm8ArrayRefINS0_9AttributeEEE |
2136 | | |
2137 | | //===----------------------------------------------------------------------===// |
2138 | | // AffineMinOp |
2139 | | //===----------------------------------------------------------------------===// |
2140 | | // |
2141 | | // %0 = affine.min (d0) -> (1000, d0 + 512) (%i0) |
2142 | | // |
2143 | | |
2144 | 0 | OpFoldResult AffineMinOp::fold(ArrayRef<Attribute> operands) { |
2145 | 0 | return foldMinMaxOp(*this, operands); |
2146 | 0 | } |
2147 | | |
2148 | | void AffineMinOp::getCanonicalizationPatterns( |
2149 | 0 | OwningRewritePatternList &patterns, MLIRContext *context) { |
2150 | 0 | patterns.insert<SimplifyAffineOp<AffineMinOp>>(context); |
2151 | 0 | } |
2152 | | |
2153 | | //===----------------------------------------------------------------------===// |
2154 | | // AffineMaxOp |
2155 | | //===----------------------------------------------------------------------===// |
2156 | | // |
2157 | | // %0 = affine.max (d0) -> (1000, d0 + 512) (%i0) |
2158 | | // |
2159 | | |
2160 | 0 | OpFoldResult AffineMaxOp::fold(ArrayRef<Attribute> operands) { |
2161 | 0 | return foldMinMaxOp(*this, operands); |
2162 | 0 | } |
2163 | | |
2164 | | void AffineMaxOp::getCanonicalizationPatterns( |
2165 | 0 | OwningRewritePatternList &patterns, MLIRContext *context) { |
2166 | 0 | patterns.insert<SimplifyAffineOp<AffineMaxOp>>(context); |
2167 | 0 | } |
2168 | | |
2169 | | //===----------------------------------------------------------------------===// |
2170 | | // AffinePrefetchOp |
2171 | | //===----------------------------------------------------------------------===// |
2172 | | |
2173 | | // |
2174 | | // affine.prefetch %0[%i, %j + 5], read, locality<3>, data : memref<400x400xi32> |
2175 | | // |
2176 | | static ParseResult parseAffinePrefetchOp(OpAsmParser &parser, |
2177 | 0 | OperationState &result) { |
2178 | 0 | auto &builder = parser.getBuilder(); |
2179 | 0 | auto indexTy = builder.getIndexType(); |
2180 | 0 |
|
2181 | 0 | MemRefType type; |
2182 | 0 | OpAsmParser::OperandType memrefInfo; |
2183 | 0 | IntegerAttr hintInfo; |
2184 | 0 | auto i32Type = parser.getBuilder().getIntegerType(32); |
2185 | 0 | StringRef readOrWrite, cacheType; |
2186 | 0 |
|
2187 | 0 | AffineMapAttr mapAttr; |
2188 | 0 | SmallVector<OpAsmParser::OperandType, 1> mapOperands; |
2189 | 0 | if (parser.parseOperand(memrefInfo) || |
2190 | 0 | parser.parseAffineMapOfSSAIds(mapOperands, mapAttr, |
2191 | 0 | AffinePrefetchOp::getMapAttrName(), |
2192 | 0 | result.attributes) || |
2193 | 0 | parser.parseComma() || parser.parseKeyword(&readOrWrite) || |
2194 | 0 | parser.parseComma() || parser.parseKeyword("locality") || |
2195 | 0 | parser.parseLess() || |
2196 | 0 | parser.parseAttribute(hintInfo, i32Type, |
2197 | 0 | AffinePrefetchOp::getLocalityHintAttrName(), |
2198 | 0 | result.attributes) || |
2199 | 0 | parser.parseGreater() || parser.parseComma() || |
2200 | 0 | parser.parseKeyword(&cacheType) || |
2201 | 0 | parser.parseOptionalAttrDict(result.attributes) || |
2202 | 0 | parser.parseColonType(type) || |
2203 | 0 | parser.resolveOperand(memrefInfo, type, result.operands) || |
2204 | 0 | parser.resolveOperands(mapOperands, indexTy, result.operands)) |
2205 | 0 | return failure(); |
2206 | 0 | |
2207 | 0 | if (!readOrWrite.equals("read") && !readOrWrite.equals("write")) |
2208 | 0 | return parser.emitError(parser.getNameLoc(), |
2209 | 0 | "rw specifier has to be 'read' or 'write'"); |
2210 | 0 | result.addAttribute( |
2211 | 0 | AffinePrefetchOp::getIsWriteAttrName(), |
2212 | 0 | parser.getBuilder().getBoolAttr(readOrWrite.equals("write"))); |
2213 | 0 |
|
2214 | 0 | if (!cacheType.equals("data") && !cacheType.equals("instr")) |
2215 | 0 | return parser.emitError(parser.getNameLoc(), |
2216 | 0 | "cache type has to be 'data' or 'instr'"); |
2217 | 0 | |
2218 | 0 | result.addAttribute( |
2219 | 0 | AffinePrefetchOp::getIsDataCacheAttrName(), |
2220 | 0 | parser.getBuilder().getBoolAttr(cacheType.equals("data"))); |
2221 | 0 |
|
2222 | 0 | return success(); |
2223 | 0 | } |
2224 | | |
2225 | 0 | static void print(OpAsmPrinter &p, AffinePrefetchOp op) { |
2226 | 0 | p << AffinePrefetchOp::getOperationName() << " " << op.memref() << '['; |
2227 | 0 | AffineMapAttr mapAttr = op.getAttrOfType<AffineMapAttr>(op.getMapAttrName()); |
2228 | 0 | if (mapAttr) { |
2229 | 0 | SmallVector<Value, 2> operands(op.getMapOperands()); |
2230 | 0 | p.printAffineMapOfSSAIds(mapAttr, operands); |
2231 | 0 | } |
2232 | 0 | p << ']' << ", " << (op.isWrite() ? "write" : "read") << ", " |
2233 | 0 | << "locality<" << op.localityHint() << ">, " |
2234 | 0 | << (op.isDataCache() ? "data" : "instr"); |
2235 | 0 | p.printOptionalAttrDict( |
2236 | 0 | op.getAttrs(), |
2237 | 0 | /*elidedAttrs=*/{op.getMapAttrName(), op.getLocalityHintAttrName(), |
2238 | 0 | op.getIsDataCacheAttrName(), op.getIsWriteAttrName()}); |
2239 | 0 | p << " : " << op.getMemRefType(); |
2240 | 0 | } |
2241 | | |
2242 | 0 | static LogicalResult verify(AffinePrefetchOp op) { |
2243 | 0 | auto mapAttr = op.getAttrOfType<AffineMapAttr>(op.getMapAttrName()); |
2244 | 0 | if (mapAttr) { |
2245 | 0 | AffineMap map = mapAttr.getValue(); |
2246 | 0 | if (map.getNumResults() != op.getMemRefType().getRank()) |
2247 | 0 | return op.emitOpError("affine.prefetch affine map num results must equal" |
2248 | 0 | " memref rank"); |
2249 | 0 | if (map.getNumInputs() + 1 != op.getNumOperands()) |
2250 | 0 | return op.emitOpError("too few operands"); |
2251 | 0 | } else { |
2252 | 0 | if (op.getNumOperands() != 1) |
2253 | 0 | return op.emitOpError("too few operands"); |
2254 | 0 | } |
2255 | 0 | |
2256 | 0 | Region *scope = getAffineScope(op); |
2257 | 0 | for (auto idx : op.getMapOperands()) { |
2258 | 0 | if (!isValidAffineIndexOperand(idx, scope)) |
2259 | 0 | return op.emitOpError("index must be a dimension or symbol identifier"); |
2260 | 0 | } |
2261 | 0 | return success(); |
2262 | 0 | } |
2263 | | |
2264 | | void AffinePrefetchOp::getCanonicalizationPatterns( |
2265 | 0 | OwningRewritePatternList &results, MLIRContext *context) { |
2266 | 0 | // prefetch(memrefcast) -> prefetch |
2267 | 0 | results.insert<SimplifyAffineOp<AffinePrefetchOp>>(context); |
2268 | 0 | } |
2269 | | |
2270 | | LogicalResult AffinePrefetchOp::fold(ArrayRef<Attribute> cstOperands, |
2271 | 0 | SmallVectorImpl<OpFoldResult> &results) { |
2272 | 0 | /// prefetch(memrefcast) -> prefetch |
2273 | 0 | return foldMemRefCast(*this); |
2274 | 0 | } |
2275 | | |
2276 | | //===----------------------------------------------------------------------===// |
2277 | | // AffineParallelOp |
2278 | | //===----------------------------------------------------------------------===// |
2279 | | |
2280 | | void AffineParallelOp::build(OpBuilder &builder, OperationState &result, |
2281 | 0 | ArrayRef<int64_t> ranges) { |
2282 | 0 | SmallVector<AffineExpr, 8> lbExprs(ranges.size(), |
2283 | 0 | builder.getAffineConstantExpr(0)); |
2284 | 0 | auto lbMap = AffineMap::get(0, 0, lbExprs, builder.getContext()); |
2285 | 0 | SmallVector<AffineExpr, 8> ubExprs; |
2286 | 0 | for (int64_t range : ranges) |
2287 | 0 | ubExprs.push_back(builder.getAffineConstantExpr(range)); |
2288 | 0 | auto ubMap = AffineMap::get(0, 0, ubExprs, builder.getContext()); |
2289 | 0 | build(builder, result, lbMap, {}, ubMap, {}); |
2290 | 0 | } |
2291 | | |
2292 | | void AffineParallelOp::build(OpBuilder &builder, OperationState &result, |
2293 | | AffineMap lbMap, ValueRange lbArgs, |
2294 | 0 | AffineMap ubMap, ValueRange ubArgs) { |
2295 | 0 | auto numDims = lbMap.getNumResults(); |
2296 | 0 | // Verify that the dimensionality of both maps are the same. |
2297 | 0 | assert(numDims == ubMap.getNumResults() && |
2298 | 0 | "num dims and num results mismatch"); |
2299 | 0 | // Make default step sizes of 1. |
2300 | 0 | SmallVector<int64_t, 8> steps(numDims, 1); |
2301 | 0 | build(builder, result, lbMap, lbArgs, ubMap, ubArgs, steps); |
2302 | 0 | } |
2303 | | |
2304 | | void AffineParallelOp::build(OpBuilder &builder, OperationState &result, |
2305 | | AffineMap lbMap, ValueRange lbArgs, |
2306 | | AffineMap ubMap, ValueRange ubArgs, |
2307 | | ArrayRef<int64_t> steps) { |
2308 | | auto numDims = lbMap.getNumResults(); |
2309 | | // Verify that the dimensionality of the maps matches the number of steps. |
2310 | | assert(numDims == ubMap.getNumResults() && |
2311 | | "num dims and num results mismatch"); |
2312 | | assert(numDims == steps.size() && "num dims and num steps mismatch"); |
2313 | | result.addAttribute(getLowerBoundsMapAttrName(), AffineMapAttr::get(lbMap)); |
2314 | | result.addAttribute(getUpperBoundsMapAttrName(), AffineMapAttr::get(ubMap)); |
2315 | | result.addAttribute(getStepsAttrName(), builder.getI64ArrayAttr(steps)); |
2316 | | result.addOperands(lbArgs); |
2317 | | result.addOperands(ubArgs); |
2318 | | // Create a region and a block for the body. |
2319 | | auto bodyRegion = result.addRegion(); |
2320 | | auto body = new Block(); |
2321 | | // Add all the block arguments. |
2322 | | for (unsigned i = 0; i < numDims; ++i) |
2323 | | body->addArgument(IndexType::get(builder.getContext())); |
2324 | | bodyRegion->push_back(body); |
2325 | | ensureTerminator(*bodyRegion, builder, result.location); |
2326 | | } |
2327 | | |
2328 | 0 | unsigned AffineParallelOp::getNumDims() { return steps().size(); } |
2329 | | |
2330 | 0 | AffineParallelOp::operand_range AffineParallelOp::getLowerBoundsOperands() { |
2331 | 0 | return getOperands().take_front(lowerBoundsMap().getNumInputs()); |
2332 | 0 | } |
2333 | | |
2334 | 0 | AffineParallelOp::operand_range AffineParallelOp::getUpperBoundsOperands() { |
2335 | 0 | return getOperands().drop_front(lowerBoundsMap().getNumInputs()); |
2336 | 0 | } |
2337 | | |
2338 | 0 | AffineValueMap AffineParallelOp::getLowerBoundsValueMap() { |
2339 | 0 | return AffineValueMap(lowerBoundsMap(), getLowerBoundsOperands()); |
2340 | 0 | } |
2341 | | |
2342 | 0 | AffineValueMap AffineParallelOp::getUpperBoundsValueMap() { |
2343 | 0 | return AffineValueMap(upperBoundsMap(), getUpperBoundsOperands()); |
2344 | 0 | } |
2345 | | |
2346 | 0 | AffineValueMap AffineParallelOp::getRangesValueMap() { |
2347 | 0 | AffineValueMap out; |
2348 | 0 | AffineValueMap::difference(getUpperBoundsValueMap(), getLowerBoundsValueMap(), |
2349 | 0 | &out); |
2350 | 0 | return out; |
2351 | 0 | } |
2352 | | |
2353 | 0 | Optional<SmallVector<int64_t, 8>> AffineParallelOp::getConstantRanges() { |
2354 | 0 | // Try to convert all the ranges to constant expressions. |
2355 | 0 | SmallVector<int64_t, 8> out; |
2356 | 0 | AffineValueMap rangesValueMap = getRangesValueMap(); |
2357 | 0 | out.reserve(rangesValueMap.getNumResults()); |
2358 | 0 | for (unsigned i = 0, e = rangesValueMap.getNumResults(); i < e; ++i) { |
2359 | 0 | auto expr = rangesValueMap.getResult(i); |
2360 | 0 | auto cst = expr.dyn_cast<AffineConstantExpr>(); |
2361 | 0 | if (!cst) |
2362 | 0 | return llvm::None; |
2363 | 0 | out.push_back(cst.getValue()); |
2364 | 0 | } |
2365 | 0 | return out; |
2366 | 0 | } |
2367 | | |
2368 | 0 | Block *AffineParallelOp::getBody() { return ®ion().front(); } |
2369 | | |
2370 | 0 | OpBuilder AffineParallelOp::getBodyBuilder() { |
2371 | 0 | return OpBuilder(getBody(), std::prev(getBody()->end())); |
2372 | 0 | } |
2373 | | |
2374 | 0 | void AffineParallelOp::setSteps(ArrayRef<int64_t> newSteps) { |
2375 | 0 | assert(newSteps.size() == getNumDims() && "steps & num dims mismatch"); |
2376 | 0 | setAttr(getStepsAttrName(), getBodyBuilder().getI64ArrayAttr(newSteps)); |
2377 | 0 | } |
2378 | | |
2379 | 0 | static LogicalResult verify(AffineParallelOp op) { |
2380 | 0 | auto numDims = op.getNumDims(); |
2381 | 0 | if (op.lowerBoundsMap().getNumResults() != numDims || |
2382 | 0 | op.upperBoundsMap().getNumResults() != numDims || |
2383 | 0 | op.steps().size() != numDims || |
2384 | 0 | op.getBody()->getNumArguments() != numDims) { |
2385 | 0 | return op.emitOpError("region argument count and num results of upper " |
2386 | 0 | "bounds, lower bounds, and steps must all match"); |
2387 | 0 | } |
2388 | 0 | // Verify that the bound operands are valid dimension/symbols. |
2389 | 0 | /// Lower bounds. |
2390 | 0 | if (failed(verifyDimAndSymbolIdentifiers(op, op.getLowerBoundsOperands(), |
2391 | 0 | op.lowerBoundsMap().getNumDims()))) |
2392 | 0 | return failure(); |
2393 | 0 | /// Upper bounds. |
2394 | 0 | if (failed(verifyDimAndSymbolIdentifiers(op, op.getUpperBoundsOperands(), |
2395 | 0 | op.upperBoundsMap().getNumDims()))) |
2396 | 0 | return failure(); |
2397 | 0 | return success(); |
2398 | 0 | } |
2399 | | |
2400 | 0 | static void print(OpAsmPrinter &p, AffineParallelOp op) { |
2401 | 0 | p << op.getOperationName() << " (" << op.getBody()->getArguments() << ") = ("; |
2402 | 0 | p.printAffineMapOfSSAIds(op.lowerBoundsMapAttr(), |
2403 | 0 | op.getLowerBoundsOperands()); |
2404 | 0 | p << ") to ("; |
2405 | 0 | p.printAffineMapOfSSAIds(op.upperBoundsMapAttr(), |
2406 | 0 | op.getUpperBoundsOperands()); |
2407 | 0 | p << ')'; |
2408 | 0 | SmallVector<int64_t, 4> steps; |
2409 | 0 | bool elideSteps = true; |
2410 | 0 | for (auto attr : op.steps()) { |
2411 | 0 | auto step = attr.cast<IntegerAttr>().getInt(); |
2412 | 0 | elideSteps &= (step == 1); |
2413 | 0 | steps.push_back(step); |
2414 | 0 | } |
2415 | 0 | if (!elideSteps) { |
2416 | 0 | p << " step ("; |
2417 | 0 | llvm::interleaveComma(steps, p); |
2418 | 0 | p << ')'; |
2419 | 0 | } |
2420 | 0 | p.printRegion(op.region(), /*printEntryBlockArgs=*/false, |
2421 | 0 | /*printBlockTerminators=*/false); |
2422 | 0 | p.printOptionalAttrDict( |
2423 | 0 | op.getAttrs(), |
2424 | 0 | /*elidedAttrs=*/{AffineParallelOp::getLowerBoundsMapAttrName(), |
2425 | 0 | AffineParallelOp::getUpperBoundsMapAttrName(), |
2426 | 0 | AffineParallelOp::getStepsAttrName()}); |
2427 | 0 | } |
2428 | | |
2429 | | // |
2430 | | // operation ::= `affine.parallel` `(` ssa-ids `)` `=` `(` map-of-ssa-ids `)` |
2431 | | // `to` `(` map-of-ssa-ids `)` steps? region attr-dict? |
2432 | | // steps ::= `steps` `(` integer-literals `)` |
2433 | | // |
2434 | | static ParseResult parseAffineParallelOp(OpAsmParser &parser, |
2435 | 0 | OperationState &result) { |
2436 | 0 | auto &builder = parser.getBuilder(); |
2437 | 0 | auto indexType = builder.getIndexType(); |
2438 | 0 | AffineMapAttr lowerBoundsAttr, upperBoundsAttr; |
2439 | 0 | SmallVector<OpAsmParser::OperandType, 4> ivs; |
2440 | 0 | SmallVector<OpAsmParser::OperandType, 4> lowerBoundsMapOperands; |
2441 | 0 | SmallVector<OpAsmParser::OperandType, 4> upperBoundsMapOperands; |
2442 | 0 | if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1, |
2443 | 0 | OpAsmParser::Delimiter::Paren) || |
2444 | 0 | parser.parseEqual() || |
2445 | 0 | parser.parseAffineMapOfSSAIds( |
2446 | 0 | lowerBoundsMapOperands, lowerBoundsAttr, |
2447 | 0 | AffineParallelOp::getLowerBoundsMapAttrName(), result.attributes, |
2448 | 0 | OpAsmParser::Delimiter::Paren) || |
2449 | 0 | parser.resolveOperands(lowerBoundsMapOperands, indexType, |
2450 | 0 | result.operands) || |
2451 | 0 | parser.parseKeyword("to") || |
2452 | 0 | parser.parseAffineMapOfSSAIds( |
2453 | 0 | upperBoundsMapOperands, upperBoundsAttr, |
2454 | 0 | AffineParallelOp::getUpperBoundsMapAttrName(), result.attributes, |
2455 | 0 | OpAsmParser::Delimiter::Paren) || |
2456 | 0 | parser.resolveOperands(upperBoundsMapOperands, indexType, |
2457 | 0 | result.operands)) |
2458 | 0 | return failure(); |
2459 | 0 | |
2460 | 0 | AffineMapAttr stepsMapAttr; |
2461 | 0 | NamedAttrList stepsAttrs; |
2462 | 0 | SmallVector<OpAsmParser::OperandType, 4> stepsMapOperands; |
2463 | 0 | if (failed(parser.parseOptionalKeyword("step"))) { |
2464 | 0 | SmallVector<int64_t, 4> steps(ivs.size(), 1); |
2465 | 0 | result.addAttribute(AffineParallelOp::getStepsAttrName(), |
2466 | 0 | builder.getI64ArrayAttr(steps)); |
2467 | 0 | } else { |
2468 | 0 | if (parser.parseAffineMapOfSSAIds(stepsMapOperands, stepsMapAttr, |
2469 | 0 | AffineParallelOp::getStepsAttrName(), |
2470 | 0 | stepsAttrs, |
2471 | 0 | OpAsmParser::Delimiter::Paren)) |
2472 | 0 | return failure(); |
2473 | 0 | |
2474 | 0 | // Convert steps from an AffineMap into an I64ArrayAttr. |
2475 | 0 | SmallVector<int64_t, 4> steps; |
2476 | 0 | auto stepsMap = stepsMapAttr.getValue(); |
2477 | 0 | for (const auto &result : stepsMap.getResults()) { |
2478 | 0 | auto constExpr = result.dyn_cast<AffineConstantExpr>(); |
2479 | 0 | if (!constExpr) |
2480 | 0 | return parser.emitError(parser.getNameLoc(), |
2481 | 0 | "steps must be constant integers"); |
2482 | 0 | steps.push_back(constExpr.getValue()); |
2483 | 0 | } |
2484 | 0 | result.addAttribute(AffineParallelOp::getStepsAttrName(), |
2485 | 0 | builder.getI64ArrayAttr(steps)); |
2486 | 0 | } |
2487 | 0 |
|
2488 | 0 | // Now parse the body. |
2489 | 0 | Region *body = result.addRegion(); |
2490 | 0 | SmallVector<Type, 4> types(ivs.size(), indexType); |
2491 | 0 | if (parser.parseRegion(*body, ivs, types) || |
2492 | 0 | parser.parseOptionalAttrDict(result.attributes)) |
2493 | 0 | return failure(); |
2494 | 0 | |
2495 | 0 | // Add a terminator if none was parsed. |
2496 | 0 | AffineParallelOp::ensureTerminator(*body, builder, result.location); |
2497 | 0 | return success(); |
2498 | 0 | } |
2499 | | |
2500 | | //===----------------------------------------------------------------------===// |
2501 | | // AffineVectorLoadOp |
2502 | | //===----------------------------------------------------------------------===// |
2503 | | |
2504 | | static ParseResult parseAffineVectorLoadOp(OpAsmParser &parser, |
2505 | 0 | OperationState &result) { |
2506 | 0 | auto &builder = parser.getBuilder(); |
2507 | 0 | auto indexTy = builder.getIndexType(); |
2508 | 0 |
|
2509 | 0 | MemRefType memrefType; |
2510 | 0 | VectorType resultType; |
2511 | 0 | OpAsmParser::OperandType memrefInfo; |
2512 | 0 | AffineMapAttr mapAttr; |
2513 | 0 | SmallVector<OpAsmParser::OperandType, 1> mapOperands; |
2514 | 0 | return failure( |
2515 | 0 | parser.parseOperand(memrefInfo) || |
2516 | 0 | parser.parseAffineMapOfSSAIds(mapOperands, mapAttr, |
2517 | 0 | AffineVectorLoadOp::getMapAttrName(), |
2518 | 0 | result.attributes) || |
2519 | 0 | parser.parseOptionalAttrDict(result.attributes) || |
2520 | 0 | parser.parseColonType(memrefType) || parser.parseComma() || |
2521 | 0 | parser.parseType(resultType) || |
2522 | 0 | parser.resolveOperand(memrefInfo, memrefType, result.operands) || |
2523 | 0 | parser.resolveOperands(mapOperands, indexTy, result.operands) || |
2524 | 0 | parser.addTypeToList(resultType, result.types)); |
2525 | 0 | } |
2526 | | |
2527 | 0 | static void print(OpAsmPrinter &p, AffineVectorLoadOp op) { |
2528 | 0 | p << "affine.vector_load " << op.getMemRef() << '['; |
2529 | 0 | if (AffineMapAttr mapAttr = |
2530 | 0 | op.getAttrOfType<AffineMapAttr>(op.getMapAttrName())) |
2531 | 0 | p.printAffineMapOfSSAIds(mapAttr, op.getMapOperands()); |
2532 | 0 | p << ']'; |
2533 | 0 | p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{op.getMapAttrName()}); |
2534 | 0 | p << " : " << op.getMemRefType() << ", " << op.getType(); |
2535 | 0 | } |
2536 | | |
2537 | | /// Verify common invariants of affine.vector_load and affine.vector_store. |
2538 | | static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, |
2539 | 0 | VectorType vectorType) { |
2540 | 0 | // Check that memref and vector element types match. |
2541 | 0 | if (memrefType.getElementType() != vectorType.getElementType()) |
2542 | 0 | return op->emitOpError( |
2543 | 0 | "requires memref and vector types of the same elemental type"); |
2544 | 0 | |
2545 | 0 | return success(); |
2546 | 0 | } |
2547 | | |
2548 | 0 | static LogicalResult verify(AffineVectorLoadOp op) { |
2549 | 0 | MemRefType memrefType = op.getMemRefType(); |
2550 | 0 | if (failed(verifyMemoryOpIndexing( |
2551 | 0 | op.getOperation(), |
2552 | 0 | op.getAttrOfType<AffineMapAttr>(op.getMapAttrName()), |
2553 | 0 | op.getMapOperands(), memrefType, |
2554 | 0 | /*numIndexOperands=*/op.getNumOperands() - 1))) |
2555 | 0 | return failure(); |
2556 | 0 | |
2557 | 0 | if (failed(verifyVectorMemoryOp(op.getOperation(), memrefType, |
2558 | 0 | op.getVectorType()))) |
2559 | 0 | return failure(); |
2560 | 0 | |
2561 | 0 | return success(); |
2562 | 0 | } |
2563 | | |
2564 | | //===----------------------------------------------------------------------===// |
2565 | | // AffineVectorStoreOp |
2566 | | //===----------------------------------------------------------------------===// |
2567 | | |
2568 | | static ParseResult parseAffineVectorStoreOp(OpAsmParser &parser, |
2569 | 0 | OperationState &result) { |
2570 | 0 | auto indexTy = parser.getBuilder().getIndexType(); |
2571 | 0 |
|
2572 | 0 | MemRefType memrefType; |
2573 | 0 | VectorType resultType; |
2574 | 0 | OpAsmParser::OperandType storeValueInfo; |
2575 | 0 | OpAsmParser::OperandType memrefInfo; |
2576 | 0 | AffineMapAttr mapAttr; |
2577 | 0 | SmallVector<OpAsmParser::OperandType, 1> mapOperands; |
2578 | 0 | return failure( |
2579 | 0 | parser.parseOperand(storeValueInfo) || parser.parseComma() || |
2580 | 0 | parser.parseOperand(memrefInfo) || |
2581 | 0 | parser.parseAffineMapOfSSAIds(mapOperands, mapAttr, |
2582 | 0 | AffineVectorStoreOp::getMapAttrName(), |
2583 | 0 | result.attributes) || |
2584 | 0 | parser.parseOptionalAttrDict(result.attributes) || |
2585 | 0 | parser.parseColonType(memrefType) || parser.parseComma() || |
2586 | 0 | parser.parseType(resultType) || |
2587 | 0 | parser.resolveOperand(storeValueInfo, resultType, result.operands) || |
2588 | 0 | parser.resolveOperand(memrefInfo, memrefType, result.operands) || |
2589 | 0 | parser.resolveOperands(mapOperands, indexTy, result.operands)); |
2590 | 0 | } |
2591 | | |
2592 | 0 | static void print(OpAsmPrinter &p, AffineVectorStoreOp op) { |
2593 | 0 | p << "affine.vector_store " << op.getValueToStore(); |
2594 | 0 | p << ", " << op.getMemRef() << '['; |
2595 | 0 | if (AffineMapAttr mapAttr = |
2596 | 0 | op.getAttrOfType<AffineMapAttr>(op.getMapAttrName())) |
2597 | 0 | p.printAffineMapOfSSAIds(mapAttr, op.getMapOperands()); |
2598 | 0 | p << ']'; |
2599 | 0 | p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{op.getMapAttrName()}); |
2600 | 0 | p << " : " << op.getMemRefType() << ", " << op.getValueToStore().getType(); |
2601 | 0 | } |
2602 | | |
2603 | 0 | static LogicalResult verify(AffineVectorStoreOp op) { |
2604 | 0 | MemRefType memrefType = op.getMemRefType(); |
2605 | 0 | if (failed(verifyMemoryOpIndexing( |
2606 | 0 | op.getOperation(), |
2607 | 0 | op.getAttrOfType<AffineMapAttr>(op.getMapAttrName()), |
2608 | 0 | op.getMapOperands(), memrefType, |
2609 | 0 | /*numIndexOperands=*/op.getNumOperands() - 2))) |
2610 | 0 | return failure(); |
2611 | 0 | |
2612 | 0 | if (failed(verifyVectorMemoryOp(op.getOperation(), memrefType, |
2613 | 0 | op.getVectorType()))) |
2614 | 0 | return failure(); |
2615 | 0 | |
2616 | 0 | return success(); |
2617 | 0 | } |
2618 | | |
2619 | | //===----------------------------------------------------------------------===// |
2620 | | // TableGen'd op method definitions |
2621 | | //===----------------------------------------------------------------------===// |
2622 | | |
2623 | | #define GET_OP_CLASSES |
2624 | | #include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc" |