/home/arjun/llvm-project/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
Line | Count | Source |
1 | | //===- Ops.cpp - Standard MLIR Operations ---------------------------------===// |
2 | | // |
3 | | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | | // See https://llvm.org/LICENSE.txt for license information. |
5 | | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | | // |
7 | | //===----------------------------------------------------------------------===// |
8 | | |
9 | | #include "mlir/Dialect/StandardOps/IR/Ops.h" |
10 | | |
11 | | #include "mlir/Dialect/CommonFolders.h" |
12 | | #include "mlir/IR/AffineExpr.h" |
13 | | #include "mlir/IR/AffineMap.h" |
14 | | #include "mlir/IR/Builders.h" |
15 | | #include "mlir/IR/Function.h" |
16 | | #include "mlir/IR/Matchers.h" |
17 | | #include "mlir/IR/Module.h" |
18 | | #include "mlir/IR/OpImplementation.h" |
19 | | #include "mlir/IR/PatternMatch.h" |
20 | | #include "mlir/IR/StandardTypes.h" |
21 | | #include "mlir/IR/TypeUtilities.h" |
22 | | #include "mlir/IR/Value.h" |
23 | | #include "mlir/Support/MathExtras.h" |
24 | | #include "mlir/Transforms/InliningUtils.h" |
25 | | #include "llvm/ADT/StringSwitch.h" |
26 | | #include "llvm/Support/FormatVariadic.h" |
27 | | #include "llvm/Support/raw_ostream.h" |
28 | | |
29 | | // Pull in all enum type definitions and utility function declarations. |
30 | | #include "mlir/Dialect/StandardOps/IR/OpsEnums.cpp.inc" |
31 | | |
32 | | using namespace mlir; |
33 | | |
34 | | //===----------------------------------------------------------------------===// |
35 | | // StandardOpsDialect Interfaces |
36 | | //===----------------------------------------------------------------------===// |
37 | | namespace { |
38 | | /// This class defines the interface for handling inlining with standard |
39 | | /// operations. |
40 | | struct StdInlinerInterface : public DialectInlinerInterface { |
41 | | using DialectInlinerInterface::DialectInlinerInterface; |
42 | | |
43 | | //===--------------------------------------------------------------------===// |
44 | | // Analysis Hooks |
45 | | //===--------------------------------------------------------------------===// |
46 | | |
47 | | /// All operations within standard ops can be inlined. |
48 | | bool isLegalToInline(Operation *, Region *, |
49 | 0 | BlockAndValueMapping &) const final { |
50 | 0 | return true; |
51 | 0 | } |
52 | | |
53 | | //===--------------------------------------------------------------------===// |
54 | | // Transformation Hooks |
55 | | //===--------------------------------------------------------------------===// |
56 | | |
57 | | /// Handle the given inlined terminator by replacing it with a new operation |
58 | | /// as necessary. |
59 | 0 | void handleTerminator(Operation *op, Block *newDest) const final { |
60 | 0 | // Only "std.return" needs to be handled here. |
61 | 0 | auto returnOp = dyn_cast<ReturnOp>(op); |
62 | 0 | if (!returnOp) |
63 | 0 | return; |
64 | 0 | |
65 | 0 | // Replace the return with a branch to the dest. |
66 | 0 | OpBuilder builder(op); |
67 | 0 | builder.create<BranchOp>(op->getLoc(), newDest, returnOp.getOperands()); |
68 | 0 | op->erase(); |
69 | 0 | } |
70 | | |
71 | | /// Handle the given inlined terminator by replacing it with a new operation |
72 | | /// as necessary. |
73 | | void handleTerminator(Operation *op, |
74 | 0 | ArrayRef<Value> valuesToRepl) const final { |
75 | 0 | // Only "std.return" needs to be handled here. |
76 | 0 | auto returnOp = cast<ReturnOp>(op); |
77 | 0 |
78 | 0 | // Replace the values directly with the return operands. |
79 | 0 | assert(returnOp.getNumOperands() == valuesToRepl.size()); |
80 | 0 | for (const auto &it : llvm::enumerate(returnOp.getOperands())) |
81 | 0 | valuesToRepl[it.index()].replaceAllUsesWith(it.value()); |
82 | 0 | } |
83 | | }; |
84 | | } // end anonymous namespace |
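// A sketch of what the two terminator hooks above accomplish during
// inlining (function, value, and block names here are hypothetical):
//   std.return %v : i32          // terminator of the inlined region
//   ==> br ^cont(%v : i32)       // when the callee spans multiple blocks
// and, when the callee is inlined into a single block, uses of the
// original call results are simply replaced with %v directly.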
85 | | |
86 | | //===----------------------------------------------------------------------===// |
87 | | // StandardOpsDialect |
88 | | //===----------------------------------------------------------------------===// |
89 | | |
90 | | /// A custom unary operation printer that omits the "std." prefix from the |
91 | | /// operation names. |
92 | 0 | static void printStandardUnaryOp(Operation *op, OpAsmPrinter &p) { |
93 | 0 | assert(op->getNumOperands() == 1 && "unary op should have one operand"); |
94 | 0 | assert(op->getNumResults() == 1 && "unary op should have one result"); |
95 | 0 |
96 | 0 | int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1; |
97 | 0 | p << op->getName().getStringRef().drop_front(stdDotLen) << ' ' |
98 | 0 | << op->getOperand(0); |
99 | 0 | p.printOptionalAttrDict(op->getAttrs()); |
100 | 0 | p << " : " << op->getOperand(0).getType(); |
101 | 0 | } |
102 | | |
103 | | /// A custom binary operation printer that omits the "std." prefix from the |
104 | | /// operation names. |
105 | 0 | static void printStandardBinaryOp(Operation *op, OpAsmPrinter &p) { |
106 | 0 | assert(op->getNumOperands() == 2 && "binary op should have two operands"); |
107 | 0 | assert(op->getNumResults() == 1 && "binary op should have one result"); |
108 | 0 |
109 | 0 | // If not all the operand and result types are the same, just use the |
110 | 0 | // generic assembly form to avoid omitting information in printing. |
111 | 0 | auto resultType = op->getResult(0).getType(); |
112 | 0 | if (op->getOperand(0).getType() != resultType || |
113 | 0 | op->getOperand(1).getType() != resultType) { |
114 | 0 | p.printGenericOp(op); |
115 | 0 | return; |
116 | 0 | } |
117 | 0 | |
118 | 0 | int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1; |
119 | 0 | p << op->getName().getStringRef().drop_front(stdDotLen) << ' ' |
120 | 0 | << op->getOperand(0) << ", " << op->getOperand(1); |
121 | 0 | p.printOptionalAttrDict(op->getAttrs()); |
122 | 0 |
123 | 0 | // Now we can output only one type for all operands and the result. |
124 | 0 | p << " : " << op->getResult(0).getType(); |
125 | 0 | } |
126 | | |
127 | | /// A custom cast operation printer that omits the "std." prefix from the |
128 | | /// operation names. |
129 | 0 | static void printStandardCastOp(Operation *op, OpAsmPrinter &p) { |
130 | 0 | int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1; |
131 | 0 | p << op->getName().getStringRef().drop_front(stdDotLen) << ' ' |
132 | 0 | << op->getOperand(0) << " : " << op->getOperand(0).getType() << " to " |
133 | 0 | << op->getResult(0).getType(); |
134 | 0 | } |
135 | | |
136 | | /// A custom cast operation verifier. |
137 | | template <typename T> |
138 | 0 | static LogicalResult verifyCastOp(T op) { |
139 | 0 | auto opType = op.getOperand().getType(); |
140 | 0 | auto resType = op.getType(); |
141 | 0 | if (!T::areCastCompatible(opType, resType)) |
142 | 0 | return op.emitError("operand type ") << opType << " and result type " |
143 | 0 | << resType << " are cast incompatible"; |
144 | 0 | |
145 | 0 | return success(); |
146 | 0 | }
Unexecuted instantiation: Ops.cpp:_ZL12verifyCastOpIN4mlir7FPExtOpEENS0_13LogicalResultET_
Unexecuted instantiation: Ops.cpp:_ZL12verifyCastOpIN4mlir8FPToSIOpEENS0_13LogicalResultET_
Unexecuted instantiation: Ops.cpp:_ZL12verifyCastOpIN4mlir9FPTruncOpEENS0_13LogicalResultET_
Unexecuted instantiation: Ops.cpp:_ZL12verifyCastOpIN4mlir11IndexCastOpEENS0_13LogicalResultET_
Unexecuted instantiation: Ops.cpp:_ZL12verifyCastOpIN4mlir12MemRefCastOpEENS0_13LogicalResultET_
Unexecuted instantiation: Ops.cpp:_ZL12verifyCastOpIN4mlir8SIToFPOpEENS0_13LogicalResultET_
Unexecuted instantiation: Ops.cpp:_ZL12verifyCastOpIN4mlir12TensorCastOpEENS0_13LogicalResultET_
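// For illustration (a hypothetical example, assuming T = IndexCastOp and
// its index<->integer compatibility rule):
//   %0 = index_cast %i : index to i64    // verifies
// whereas "index_cast %i : index to f32" would emit the cast-incompatible
// diagnostic above.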
147 | | |
148 | | StandardOpsDialect::StandardOpsDialect(MLIRContext *context) |
149 | 0 | : Dialect(getDialectNamespace(), context) { |
150 | 0 | addOperations<DmaStartOp, DmaWaitOp, |
151 | 0 | #define GET_OP_LIST |
152 | 0 | #include "mlir/Dialect/StandardOps/IR/Ops.cpp.inc" |
153 | 0 | >(); |
154 | 0 | addInterfaces<StdInlinerInterface>(); |
155 | 0 | } |
156 | | |
157 | | /// Materialize a single constant operation from a given attribute value with |
158 | | /// the desired resultant type. |
159 | | Operation *StandardOpsDialect::materializeConstant(OpBuilder &builder, |
160 | | Attribute value, Type type, |
161 | 0 | Location loc) { |
162 | 0 | return builder.create<ConstantOp>(loc, type, value); |
163 | 0 | } |
164 | | |
165 | | void mlir::printDimAndSymbolList(Operation::operand_iterator begin, |
166 | | Operation::operand_iterator end, |
167 | 0 | unsigned numDims, OpAsmPrinter &p) { |
168 | 0 | Operation::operand_range operands(begin, end); |
169 | 0 | p << '(' << operands.take_front(numDims) << ')'; |
170 | 0 | if (operands.size() != numDims) |
171 | 0 | p << '[' << operands.drop_front(numDims) << ']'; |
172 | 0 | } |
173 | | |
174 | | // Parses dimension and symbol list, and sets 'numDims' to the number of |
175 | | // dimension operands parsed. |
176 | | // Returns failure on a parse error, success otherwise.
177 | | ParseResult mlir::parseDimAndSymbolList(OpAsmParser &parser, |
178 | | SmallVectorImpl<Value> &operands, |
179 | 0 | unsigned &numDims) { |
180 | 0 | SmallVector<OpAsmParser::OperandType, 8> opInfos; |
181 | 0 | if (parser.parseOperandList(opInfos, OpAsmParser::Delimiter::Paren)) |
182 | 0 | return failure(); |
183 | 0 | // Store number of dimensions for validation by caller. |
184 | 0 | numDims = opInfos.size(); |
185 | 0 |
186 | 0 | // Parse the optional symbol operands. |
187 | 0 | auto indexTy = parser.getBuilder().getIndexType(); |
188 | 0 | if (parser.parseOperandList(opInfos, |
189 | 0 | OpAsmParser::Delimiter::OptionalSquare) || |
190 | 0 | parser.resolveOperands(opInfos, indexTy, operands)) |
191 | 0 | return failure(); |
192 | 0 | return success(); |
193 | 0 | } |
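// Concretely, this printer/parser pair round-trips operand lists of the
// form "(%d0, %d1)[%s0]" (operand names hypothetical): dimension operands
// in parentheses, followed by optional symbol operands in square brackets.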
194 | | |
195 | | /// Matches a ConstantIndexOp. |
196 | | /// TODO: This should probably just be a general matcher that uses m_Constant |
197 | | /// and checks the operation for an index type. |
198 | 0 | static detail::op_matcher<ConstantIndexOp> m_ConstantIndex() { |
199 | 0 | return detail::op_matcher<ConstantIndexOp>(); |
200 | 0 | } |
201 | | |
202 | | //===----------------------------------------------------------------------===// |
203 | | // Common canonicalization pattern support logic |
204 | | //===----------------------------------------------------------------------===// |
205 | | |
206 | | /// This is a common utility used for patterns of the form
207 | | /// "someop(memrefcast) -> someop". It folds the source of any memref_cast |
208 | | /// into the root operation directly. |
209 | 0 | static LogicalResult foldMemRefCast(Operation *op) { |
210 | 0 | bool folded = false; |
211 | 0 | for (OpOperand &operand : op->getOpOperands()) { |
212 | 0 | auto cast = operand.get().getDefiningOp<MemRefCastOp>(); |
213 | 0 | if (cast && !cast.getOperand().getType().isa<UnrankedMemRefType>()) { |
214 | 0 | operand.set(cast.getOperand()); |
215 | 0 | folded = true; |
216 | 0 | } |
217 | 0 | } |
218 | 0 | return success(folded); |
219 | 0 | } |
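// Sketch of the rewrite this helper enables (illustrative IR, names
// hypothetical):
//   %0 = memref_cast %m : memref<4xf32> to memref<?xf32>
//   dealloc %0 : memref<?xf32>
// becomes "dealloc %m : memref<4xf32>": the cast is bypassed so the user
// operates on the ranked source directly.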
220 | | |
221 | | //===----------------------------------------------------------------------===// |
222 | | // AddFOp |
223 | | //===----------------------------------------------------------------------===// |
224 | | |
225 | 0 | OpFoldResult AddFOp::fold(ArrayRef<Attribute> operands) { |
226 | 0 | return constFoldBinaryOp<FloatAttr>( |
227 | 0 | operands, [](APFloat a, APFloat b) { return a + b; }); |
228 | 0 | } |
229 | | |
230 | | //===----------------------------------------------------------------------===// |
231 | | // AddIOp |
232 | | //===----------------------------------------------------------------------===// |
233 | | |
234 | 0 | OpFoldResult AddIOp::fold(ArrayRef<Attribute> operands) { |
235 | 0 | /// addi(x, 0) -> x |
236 | 0 | if (matchPattern(rhs(), m_Zero())) |
237 | 0 | return lhs(); |
238 | 0 | |
239 | 0 | return constFoldBinaryOp<IntegerAttr>(operands, |
240 | 0 | [](APInt a, APInt b) { return a + b; }); |
241 | 0 | } |
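// For example (hypothetical values): "addi %x, %c0" folds to %x by the
// addi(x, 0) rule above, and addi of two constants, say 2 and 3, folds to
// the constant 5 via constFoldBinaryOp.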
242 | | |
243 | | //===----------------------------------------------------------------------===// |
244 | | // AllocOp / AllocaOp |
245 | | //===----------------------------------------------------------------------===// |
246 | | |
247 | | template <typename AllocLikeOp> |
248 | 0 | static void printAllocLikeOp(OpAsmPrinter &p, AllocLikeOp op, StringRef name) { |
249 | 0 | static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value, |
250 | 0 | "applies to only alloc or alloca"); |
251 | 0 | p << name; |
252 | 0 |
253 | 0 | // Print dynamic dimension operands. |
254 | 0 | MemRefType type = op.getType(); |
255 | 0 | printDimAndSymbolList(op.operand_begin(), op.operand_end(), |
256 | 0 | type.getNumDynamicDims(), p); |
257 | 0 | p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"map"}); |
258 | 0 | p << " : " << type; |
259 | 0 | }
Unexecuted instantiation: Ops.cpp:_ZL16printAllocLikeOpIN4mlir7AllocOpEEvRNS0_12OpAsmPrinterET_N4llvm9StringRefE
Unexecuted instantiation: Ops.cpp:_ZL16printAllocLikeOpIN4mlir8AllocaOpEEvRNS0_12OpAsmPrinterET_N4llvm9StringRefE
260 | | |
261 | 0 | static void print(OpAsmPrinter &p, AllocOp op) { |
262 | 0 | printAllocLikeOp(p, op, "alloc"); |
263 | 0 | } |
264 | | |
265 | 0 | static void print(OpAsmPrinter &p, AllocaOp op) { |
266 | 0 | printAllocLikeOp(p, op, "alloca"); |
267 | 0 | } |
268 | | |
269 | | static ParseResult parseAllocLikeOp(OpAsmParser &parser, |
270 | 0 | OperationState &result) { |
271 | 0 | MemRefType type; |
272 | 0 |
273 | 0 | // Parse the dimension operands and optional symbol operands, followed by a |
274 | 0 | // memref type. |
275 | 0 | unsigned numDimOperands; |
276 | 0 | if (parseDimAndSymbolList(parser, result.operands, numDimOperands) || |
277 | 0 | parser.parseOptionalAttrDict(result.attributes) || |
278 | 0 | parser.parseColonType(type)) |
279 | 0 | return failure(); |
280 | 0 | |
281 | 0 | // Check numDynamicDims against number of question marks in memref type. |
282 | 0 | // Note: this check remains here (instead of in verify()), because the |
283 | 0 | // partition between dim operands and symbol operands is lost after parsing. |
284 | 0 | // Verification still checks that the total number of operands matches |
285 | 0 | // the number of symbols in the affine map, plus the number of dynamic |
286 | 0 | // dimensions in the memref. |
287 | 0 | if (numDimOperands != type.getNumDynamicDims()) |
288 | 0 | return parser.emitError(parser.getNameLoc()) |
289 | 0 | << "dimension operand count does not equal memref dynamic dimension " |
290 | 0 | "count"; |
291 | 0 | result.types.push_back(type); |
292 | 0 | return success(); |
293 | 0 | } |
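// E.g. (hypothetical IR), "alloc(%w) : memref<4x?xf32>" parses cleanly:
// one dimension operand for the single '?' in the type. Writing
// "alloc() : memref<4x?xf32>" trips the dimension-count error above.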
294 | | |
295 | | template <typename AllocLikeOp> |
296 | 0 | static LogicalResult verify(AllocLikeOp op) { |
297 | 0 | static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value, |
298 | 0 | "applies to only alloc or alloca"); |
299 | 0 | auto memRefType = op.getResult().getType().template dyn_cast<MemRefType>(); |
300 | 0 | if (!memRefType) |
301 | 0 | return op.emitOpError("result must be a memref"); |
302 | 0 | |
303 | 0 | unsigned numSymbols = 0; |
304 | 0 | if (!memRefType.getAffineMaps().empty()) { |
305 | 0 | // Store number of symbols used in affine map (used in subsequent check). |
306 | 0 | AffineMap affineMap = memRefType.getAffineMaps()[0]; |
307 | 0 | numSymbols = affineMap.getNumSymbols(); |
308 | 0 | } |
309 | 0 |
310 | 0 | // Check that the total number of operands matches the number of symbols in |
311 | 0 | // the affine map, plus the number of dynamic dimensions specified in the |
312 | 0 | // memref type. |
313 | 0 | unsigned numDynamicDims = memRefType.getNumDynamicDims(); |
314 | 0 | if (op.getNumOperands() != numDynamicDims + numSymbols) |
315 | 0 | return op.emitOpError( |
316 | 0 | "operand count does not equal dimension plus symbol operand count"); |
317 | 0 | |
318 | 0 | // Verify that all operands are of type Index. |
319 | 0 | for (auto operandType : op.getOperandTypes()) |
320 | 0 | if (!operandType.isIndex()) |
321 | 0 | return op.emitOpError("requires operands to be of type Index"); |
322 | 0 |
323 | 0 | if (std::is_same<AllocLikeOp, AllocOp>::value) |
324 | 0 | return success(); |
325 | 0 | |
326 | 0 | // An alloca op needs to have an ancestor with an allocation scope trait. |
327 | 0 | if (!op.template getParentWithTrait<OpTrait::AutomaticAllocationScope>()) |
328 | 0 | return op.emitOpError( |
329 | 0 | "requires an ancestor op with AutomaticAllocationScope trait"); |
330 | 0 | |
331 | 0 | return success(); |
332 | 0 | }
Unexecuted instantiation: Ops.cpp:_ZL6verifyIN4mlir7AllocOpEENS0_13LogicalResultET_
Unexecuted instantiation: Ops.cpp:_ZL6verifyIN4mlir8AllocaOpEENS0_13LogicalResultET_
333 | | |
334 | | namespace { |
335 | | /// Fold constant dimensions into an alloc like operation. |
336 | | template <typename AllocLikeOp> |
337 | | struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> { |
338 | | using OpRewritePattern<AllocLikeOp>::OpRewritePattern; |
339 | | |
340 | | LogicalResult matchAndRewrite(AllocLikeOp alloc, |
341 | 0 | PatternRewriter &rewriter) const override { |
342 | 0 | // Check to see if any dimensions operands are constants. If so, we can |
343 | 0 | // substitute and drop them. |
344 | 0 | if (llvm::none_of(alloc.getOperands(), [](Value operand) { |
345 | 0 | return matchPattern(operand, m_ConstantIndex()); |
346 | 0 | }))
Unexecuted instantiation: Ops.cpp:_ZZNK12_GLOBAL__N_118SimplifyAllocConstIN4mlir7AllocOpEE15matchAndRewriteES2_RNS1_15PatternRewriterEENKUlNS1_5ValueEE_clES6_
Unexecuted instantiation: Ops.cpp:_ZZNK12_GLOBAL__N_118SimplifyAllocConstIN4mlir8AllocaOpEE15matchAndRewriteES2_RNS1_15PatternRewriterEENKUlNS1_5ValueEE_clES6_
347 | 0 | return failure(); |
348 | 0 | |
349 | 0 | auto memrefType = alloc.getType(); |
350 | 0 |
351 | 0 | // Ok, we have one or more constant operands. Collect the non-constant ones |
352 | 0 | // and keep track of the resultant memref type to build. |
353 | 0 | SmallVector<int64_t, 4> newShapeConstants; |
354 | 0 | newShapeConstants.reserve(memrefType.getRank()); |
355 | 0 | SmallVector<Value, 4> newOperands; |
356 | 0 |
357 | 0 | unsigned dynamicDimPos = 0; |
358 | 0 | for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) { |
359 | 0 | int64_t dimSize = memrefType.getDimSize(dim); |
360 | 0 | // If this is already a static dimension, keep it.
361 | 0 | if (dimSize != -1) { |
362 | 0 | newShapeConstants.push_back(dimSize); |
363 | 0 | continue; |
364 | 0 | } |
365 | 0 | auto *defOp = alloc.getOperand(dynamicDimPos).getDefiningOp(); |
366 | 0 | if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) { |
367 | 0 | // Dynamic shape dimension will be folded. |
368 | 0 | newShapeConstants.push_back(constantIndexOp.getValue()); |
369 | 0 | } else { |
370 | 0 | // Dynamic shape dimension not folded; copy operand from old memref. |
371 | 0 | newShapeConstants.push_back(-1); |
372 | 0 | newOperands.push_back(alloc.getOperand(dynamicDimPos)); |
373 | 0 | } |
374 | 0 | dynamicDimPos++; |
375 | 0 | } |
376 | 0 |
377 | 0 | // Create new memref type (which will have fewer dynamic dimensions). |
378 | 0 | MemRefType newMemRefType = |
379 | 0 | MemRefType::Builder(memrefType).setShape(newShapeConstants); |
380 | 0 | assert(static_cast<int64_t>(newOperands.size()) == |
381 | 0 | newMemRefType.getNumDynamicDims()); |
382 | 0 |
383 | 0 | // Create and insert the alloc op for the new memref. |
384 | 0 | auto newAlloc = rewriter.create<AllocLikeOp>(alloc.getLoc(), newMemRefType, |
385 | 0 | newOperands, IntegerAttr()); |
386 | 0 | // Insert a cast so we have the same type as the old alloc. |
387 | 0 | auto resultCast = rewriter.create<MemRefCastOp>(alloc.getLoc(), newAlloc, |
388 | 0 | alloc.getType()); |
389 | 0 |
390 | 0 | rewriter.replaceOp(alloc, {resultCast}); |
391 | 0 | return success(); |
392 | 0 | }
Unexecuted instantiation: Ops.cpp:_ZNK12_GLOBAL__N_118SimplifyAllocConstIN4mlir7AllocOpEE15matchAndRewriteES2_RNS1_15PatternRewriterE
Unexecuted instantiation: Ops.cpp:_ZNK12_GLOBAL__N_118SimplifyAllocConstIN4mlir8AllocaOpEE15matchAndRewriteES2_RNS1_15PatternRewriterE
393 | | }; |
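// Sketch of this pattern's effect (illustrative IR, names hypothetical):
//   %c4 = constant 4 : index
//   %0 = alloc(%c4, %n) : memref<?x?xf32>
// becomes
//   %1 = alloc(%n) : memref<4x?xf32>
//   %0 = memref_cast %1 : memref<4x?xf32> to memref<?x?xf32>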
394 | | |
395 | | /// Fold alloc operations with no uses. Alloc has side effects on the heap, |
396 | | /// but can still be deleted if it has zero uses. |
397 | | struct SimplifyDeadAlloc : public OpRewritePattern<AllocOp> { |
398 | | using OpRewritePattern<AllocOp>::OpRewritePattern; |
399 | | |
400 | | LogicalResult matchAndRewrite(AllocOp alloc, |
401 | 0 | PatternRewriter &rewriter) const override { |
402 | 0 | if (alloc.use_empty()) { |
403 | 0 | rewriter.eraseOp(alloc); |
404 | 0 | return success(); |
405 | 0 | } |
406 | 0 | return failure(); |
407 | 0 | } |
408 | | }; |
409 | | } // end anonymous namespace. |
410 | | |
411 | | void AllocOp::getCanonicalizationPatterns(OwningRewritePatternList &results, |
412 | 0 | MLIRContext *context) { |
413 | 0 | results.insert<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc>(context); |
414 | 0 | } |
415 | | |
416 | | void AllocaOp::getCanonicalizationPatterns(OwningRewritePatternList &results, |
417 | 0 | MLIRContext *context) { |
418 | 0 | results.insert<SimplifyAllocConst<AllocaOp>>(context); |
419 | 0 | } |
420 | | |
421 | | //===----------------------------------------------------------------------===// |
422 | | // AndOp |
423 | | //===----------------------------------------------------------------------===// |
424 | | |
425 | 0 | OpFoldResult AndOp::fold(ArrayRef<Attribute> operands) { |
426 | 0 | /// and(x, 0) -> 0 |
427 | 0 | if (matchPattern(rhs(), m_Zero())) |
428 | 0 | return rhs(); |
429 | 0 | /// and(x,x) -> x |
430 | 0 | if (lhs() == rhs()) |
431 | 0 | return rhs(); |
432 | 0 | |
433 | 0 | return constFoldBinaryOp<IntegerAttr>(operands, |
434 | 0 | [](APInt a, APInt b) { return a & b; }); |
435 | 0 | } |
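// For example (hypothetical values): "and %x, %c0" folds to the zero
// operand, "and %x, %x" folds to %x, and two constants fold bitwise,
// e.g. 0b1100 & 0b1010 == 0b1000.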
436 | | |
437 | | //===----------------------------------------------------------------------===// |
438 | | // AssumeAlignmentOp |
439 | | //===----------------------------------------------------------------------===// |
440 | | |
441 | 0 | static LogicalResult verify(AssumeAlignmentOp op) { |
442 | 0 | unsigned alignment = op.alignment().getZExtValue(); |
443 | 0 | if (!llvm::isPowerOf2_32(alignment)) |
444 | 0 | return op.emitOpError("alignment must be power of 2"); |
445 | 0 | return success(); |
446 | 0 | } |
447 | | |
448 | | //===----------------------------------------------------------------------===// |
449 | | // AtomicRMWOp |
450 | | //===----------------------------------------------------------------------===// |
451 | | |
452 | 0 | static LogicalResult verify(AtomicRMWOp op) { |
453 | 0 | if (op.getMemRefType().getRank() != op.getNumOperands() - 2) |
454 | 0 | return op.emitOpError( |
455 | 0 | "expects the number of subscripts to be equal to memref rank"); |
456 | 0 | switch (op.kind()) { |
457 | 0 | case AtomicRMWKind::addf: |
458 | 0 | case AtomicRMWKind::maxf: |
459 | 0 | case AtomicRMWKind::minf: |
460 | 0 | case AtomicRMWKind::mulf: |
461 | 0 | if (!op.value().getType().isa<FloatType>()) |
462 | 0 | return op.emitOpError() |
463 | 0 | << "with kind '" << stringifyAtomicRMWKind(op.kind()) |
464 | 0 | << "' expects a floating-point type"; |
465 | 0 | break; |
466 | 0 | case AtomicRMWKind::addi: |
467 | 0 | case AtomicRMWKind::maxs: |
468 | 0 | case AtomicRMWKind::maxu: |
469 | 0 | case AtomicRMWKind::mins: |
470 | 0 | case AtomicRMWKind::minu: |
471 | 0 | case AtomicRMWKind::muli: |
472 | 0 | if (!op.value().getType().isa<IntegerType>()) |
473 | 0 | return op.emitOpError() |
474 | 0 | << "with kind '" << stringifyAtomicRMWKind(op.kind()) |
475 | 0 | << "' expects an integer type"; |
476 | 0 | break; |
477 | 0 | default: |
478 | 0 | break; |
479 | 0 | } |
480 | 0 | return success(); |
481 | 0 | } |
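// E.g. (hypothetical IR), "atomic_rmw "addf" %v, %m[%i] : (f32,
// memref<8xf32>) -> f32" passes: one subscript for a rank-1 memref, and
// 'addf' paired with a floating-point value. Using kind 'addi' with the
// same f32 value would be rejected by the switch above.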
482 | | |
483 | | //===----------------------------------------------------------------------===// |
484 | | // GenericAtomicRMWOp |
485 | | //===----------------------------------------------------------------------===// |
486 | | |
487 | | void GenericAtomicRMWOp::build(OpBuilder &builder, OperationState &result, |
488 | 0 | Value memref, ValueRange ivs) { |
489 | 0 | result.addOperands(memref); |
490 | 0 | result.addOperands(ivs); |
491 | 0 |
|
492 | 0 | if (auto memrefType = memref.getType().dyn_cast<MemRefType>()) { |
493 | 0 | Type elementType = memrefType.getElementType(); |
494 | 0 | result.addTypes(elementType); |
495 | 0 |
496 | 0 | Region *bodyRegion = result.addRegion(); |
497 | 0 | bodyRegion->push_back(new Block()); |
498 | 0 | bodyRegion->front().addArgument(elementType); |
499 | 0 | } |
500 | 0 | } |
501 | | |
502 | 0 | static LogicalResult verify(GenericAtomicRMWOp op) { |
503 | 0 | auto &block = op.body().front(); |
504 | 0 | if (block.getNumArguments() != 1) |
505 | 0 | return op.emitOpError("expected single number of entry block arguments"); |
506 | 0 | |
507 | 0 | if (op.getResult().getType() != block.getArgument(0).getType()) |
508 | 0 | return op.emitOpError( |
509 | 0 | "expected block argument of the same type result type"); |
510 | 0 | |
511 | 0 | bool hasSideEffects = |
512 | 0 | op.body() |
513 | 0 | .walk([&](Operation *nestedOp) { |
514 | 0 | if (MemoryEffectOpInterface::hasNoEffect(nestedOp)) |
515 | 0 | return WalkResult::advance(); |
516 | 0 | nestedOp->emitError("body of 'generic_atomic_rmw' should contain " |
517 | 0 | "only operations with no side effects"); |
518 | 0 | return WalkResult::interrupt(); |
519 | 0 | }) |
520 | 0 | .wasInterrupted(); |
521 | 0 | return hasSideEffects ? failure() : success(); |
522 | 0 | } |
523 | | |
524 | | static ParseResult parseGenericAtomicRMWOp(OpAsmParser &parser, |
525 | 0 | OperationState &result) { |
526 | 0 | OpAsmParser::OperandType memref; |
527 | 0 | Type memrefType; |
528 | 0 | SmallVector<OpAsmParser::OperandType, 4> ivs; |
529 | 0 |
530 | 0 | Type indexType = parser.getBuilder().getIndexType(); |
531 | 0 | if (parser.parseOperand(memref) || |
532 | 0 | parser.parseOperandList(ivs, OpAsmParser::Delimiter::Square) || |
533 | 0 | parser.parseColonType(memrefType) || |
534 | 0 | parser.resolveOperand(memref, memrefType, result.operands) || |
535 | 0 | parser.resolveOperands(ivs, indexType, result.operands)) |
536 | 0 | return failure(); |
537 | 0 | |
538 | 0 | Region *body = result.addRegion(); |
539 | 0 | if (parser.parseRegion(*body, llvm::None, llvm::None)) |
540 | 0 | return failure(); |
541 | 0 | result.types.push_back(memrefType.cast<MemRefType>().getElementType()); |
542 | 0 | return success(); |
543 | 0 | } |
544 | | |
545 | 0 | static void print(OpAsmPrinter &p, GenericAtomicRMWOp op) { |
546 | 0 | p << op.getOperationName() << ' ' << op.memref() << "[" << op.indices() |
547 | 0 | << "] : " << op.memref().getType(); |
548 | 0 | p.printRegion(op.body()); |
549 | 0 | p.printOptionalAttrDict(op.getAttrs()); |
550 | 0 | } |
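// The custom form printed/parsed above looks like (names hypothetical):
//   %x = generic_atomic_rmw %m[%i] : memref<10xf32> {
//   ^bb0(%current : f32):
//     %inc = addf %current, %cst : f32
//     atomic_yield %inc : f32
//   }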
551 | | |
552 | | //===----------------------------------------------------------------------===// |
553 | | // AtomicYieldOp |
554 | | //===----------------------------------------------------------------------===// |
555 | | |
556 | 0 | static LogicalResult verify(AtomicYieldOp op) { |
557 | 0 | Type parentType = op.getParentOp()->getResultTypes().front(); |
558 | 0 | Type resultType = op.result().getType(); |
559 | 0 | if (parentType != resultType) |
560 | 0 | return op.emitOpError() << "types mismatch between yield op: " << resultType |
561 | 0 | << " and its parent: " << parentType; |
562 | 0 | return success(); |
563 | 0 | } |
564 | | |
565 | | //===----------------------------------------------------------------------===// |
566 | | // BranchOp |
567 | | //===----------------------------------------------------------------------===// |
568 | | |
569 | | /// Given a successor, try to collapse it to a new destination if it only |
570 | | /// contains a passthrough unconditional branch. If the successor is |
571 | | /// collapsable, `successor` and `successorOperands` are updated to reference |
572 | | /// the new destination and values. `argStorage` is an optional storage to use |
573 | | /// if operands to the collapsed successor need to be remapped. |
574 | | static LogicalResult collapseBranch(Block *&successor, |
575 | | ValueRange &successorOperands, |
576 | 0 | SmallVectorImpl<Value> &argStorage) { |
577 | 0 | // Check that the successor only contains an unconditional branch.
578 | 0 | if (std::next(successor->begin()) != successor->end()) |
579 | 0 | return failure(); |
580 | 0 | // Check that the terminator is an unconditional branch. |
581 | 0 | BranchOp successorBranch = dyn_cast<BranchOp>(successor->getTerminator()); |
582 | 0 | if (!successorBranch) |
583 | 0 | return failure(); |
584 | 0 | // Check that the arguments are only used within the terminator. |
585 | 0 | for (BlockArgument arg : successor->getArguments()) { |
586 | 0 | for (Operation *user : arg.getUsers()) |
587 | 0 | if (user != successorBranch) |
588 | 0 | return failure(); |
589 | 0 | } |
590 | 0 | // Don't try to collapse branches to infinite loops. |
591 | 0 | Block *successorDest = successorBranch.getDest(); |
592 | 0 | if (successorDest == successor) |
593 | 0 | return failure(); |
594 | 0 | |
595 | 0 | // Update the operands to the successor. If the branch parent has no |
596 | 0 | // arguments, we can use the branch operands directly. |
597 | 0 | OperandRange operands = successorBranch.getOperands(); |
598 | 0 | if (successor->args_empty()) { |
599 | 0 | successor = successorDest; |
600 | 0 | successorOperands = operands; |
601 | 0 | return success(); |
602 | 0 | } |
603 | 0 | |
604 | 0 | // Otherwise, we need to remap any argument operands. |
605 | 0 | for (Value operand : operands) { |
606 | 0 | BlockArgument argOperand = operand.dyn_cast<BlockArgument>(); |
607 | 0 | if (argOperand && argOperand.getOwner() == successor) |
608 | 0 | argStorage.push_back(successorOperands[argOperand.getArgNumber()]); |
609 | 0 | else |
610 | 0 | argStorage.push_back(operand); |
611 | 0 | } |
612 | 0 | successor = successorDest; |
613 | 0 | successorOperands = argStorage; |
614 | 0 | return success(); |
615 | 0 | } |
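// Remapping sketch (hypothetical blocks): collapsing through
//   ^bb1(%a : i32):
//     br ^bb2(%a, %c : i32, i32)
// for a predecessor branching "br ^bb1(%x : i32)" yields successor ^bb2
// with operands (%x, %c): block arguments are substituted with the
// predecessor's operands, and everything else passes through unchanged.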
616 | | |
617 | | namespace { |
618 | | /// Simplify a branch to a block that has a single predecessor. This effectively |
619 | | /// merges the two blocks. |
620 | | struct SimplifyBrToBlockWithSinglePred : public OpRewritePattern<BranchOp> { |
621 | | using OpRewritePattern<BranchOp>::OpRewritePattern; |
622 | | |
623 | | LogicalResult matchAndRewrite(BranchOp op, |
624 | 0 | PatternRewriter &rewriter) const override { |
625 | 0 | // Check that the successor block has a single predecessor. |
626 | 0 | Block *succ = op.getDest(); |
627 | 0 | Block *opParent = op.getOperation()->getBlock(); |
628 | 0 | if (succ == opParent || !llvm::hasSingleElement(succ->getPredecessors())) |
629 | 0 | return failure(); |
630 | 0 | |
631 | 0 | // Merge the successor into the current block and erase the branch. |
632 | 0 | rewriter.mergeBlocks(succ, opParent, op.getOperands()); |
633 | 0 | rewriter.eraseOp(op); |
634 | 0 | return success(); |
635 | 0 | } |
636 | | }; |
637 | | |
638 | | /// br ^bb1 |
639 | | /// ^bb1 |
640 | | /// br ^bbN(...) |
641 | | /// |
642 | | /// -> br ^bbN(...) |
643 | | /// |
644 | | struct SimplifyPassThroughBr : public OpRewritePattern<BranchOp> { |
645 | | using OpRewritePattern<BranchOp>::OpRewritePattern; |
646 | | |
647 | | LogicalResult matchAndRewrite(BranchOp op, |
648 | 0 | PatternRewriter &rewriter) const override { |
649 | 0 | Block *dest = op.getDest(); |
650 | 0 | ValueRange destOperands = op.getOperands(); |
651 | 0 | SmallVector<Value, 4> destOperandStorage; |
652 | 0 |
653 | 0 | // Try to collapse the successor if it points somewhere other than this |
654 | 0 | // block. |
655 | 0 | if (dest == op.getOperation()->getBlock() || |
656 | 0 | failed(collapseBranch(dest, destOperands, destOperandStorage))) |
657 | 0 | return failure(); |
658 | 0 | |
659 | 0 | // Create a new branch with the collapsed successor. |
660 | 0 | rewriter.replaceOpWithNewOp<BranchOp>(op, dest, destOperands); |
661 | 0 | return success(); |
662 | 0 | } |
663 | | }; |
664 | | } // end anonymous namespace. |
665 | | |
666 | 0 | Block *BranchOp::getDest() { return getSuccessor(); } |
667 | | |
668 | 0 | void BranchOp::setDest(Block *block) { return setSuccessor(block); } |
669 | | |
670 | 0 | void BranchOp::eraseOperand(unsigned index) { |
671 | 0 | getOperation()->eraseOperand(index); |
672 | 0 | } |
673 | | |
674 | | void BranchOp::getCanonicalizationPatterns(OwningRewritePatternList &results, |
675 | 0 | MLIRContext *context) { |
676 | 0 | results.insert<SimplifyBrToBlockWithSinglePred, SimplifyPassThroughBr>( |
677 | 0 | context); |
678 | 0 | } |
679 | | |
680 | | Optional<MutableOperandRange> |
681 | 0 | BranchOp::getMutableSuccessorOperands(unsigned index) { |
682 | 0 | assert(index == 0 && "invalid successor index"); |
683 | 0 | return destOperandsMutable(); |
684 | 0 | } |
685 | | |
686 | 0 | Block *BranchOp::getSuccessorForOperands(ArrayRef<Attribute>) { return dest(); } |
687 | | |
688 | | //===----------------------------------------------------------------------===// |
689 | | // CallOp |
690 | | //===----------------------------------------------------------------------===// |
691 | | |
692 | 0 | static LogicalResult verify(CallOp op) { |
693 | 0 | // Check that the callee attribute was specified. |
694 | 0 | auto fnAttr = op.getAttrOfType<FlatSymbolRefAttr>("callee"); |
695 | 0 | if (!fnAttr) |
696 | 0 | return op.emitOpError("requires a 'callee' symbol reference attribute"); |
697 | 0 | auto fn = |
698 | 0 | op.getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(fnAttr.getValue()); |
699 | 0 | if (!fn) |
700 | 0 | return op.emitOpError() << "'" << fnAttr.getValue() |
701 | 0 | << "' does not reference a valid function"; |
702 | 0 | |
703 | 0 | // Verify that the operand and result types match the callee. |
704 | 0 | auto fnType = fn.getType(); |
705 | 0 | if (fnType.getNumInputs() != op.getNumOperands()) |
706 | 0 | return op.emitOpError("incorrect number of operands for callee"); |
707 | 0 | |
708 | 0 | for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) |
709 | 0 | if (op.getOperand(i).getType() != fnType.getInput(i)) |
710 | 0 | return op.emitOpError("operand type mismatch"); |
711 | 0 |
712 | 0 | if (fnType.getNumResults() != op.getNumResults()) |
713 | 0 | return op.emitOpError("incorrect number of results for callee"); |
714 | 0 | |
715 | 0 | for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) |
716 | 0 | if (op.getResult(i).getType() != fnType.getResult(i)) |
717 | 0 | return op.emitOpError("result type mismatch"); |
718 | 0 |
719 | 0 | return success(); |
720 | 0 | } |
721 | | |
722 | 0 | FunctionType CallOp::getCalleeType() { |
723 | 0 | SmallVector<Type, 8> argTypes(getOperandTypes()); |
724 | 0 | return FunctionType::get(argTypes, getResultTypes(), getContext()); |
725 | 0 | } |
726 | | |
727 | | //===----------------------------------------------------------------------===// |
728 | | // CallIndirectOp |
729 | | //===----------------------------------------------------------------------===// |
730 | | namespace { |
731 | | /// Fold indirect calls that have a constant function as the callee operand. |
732 | | struct SimplifyIndirectCallWithKnownCallee |
733 | | : public OpRewritePattern<CallIndirectOp> { |
734 | | using OpRewritePattern<CallIndirectOp>::OpRewritePattern; |
735 | | |
736 | | LogicalResult matchAndRewrite(CallIndirectOp indirectCall, |
737 | 0 | PatternRewriter &rewriter) const override { |
738 | 0 | // Check that the callee is a constant callee. |
739 | 0 | SymbolRefAttr calledFn; |
740 | 0 | if (!matchPattern(indirectCall.getCallee(), m_Constant(&calledFn))) |
741 | 0 | return failure(); |
742 | 0 | |
743 | 0 | // Replace with a direct call. |
744 | 0 | rewriter.replaceOpWithNewOp<CallOp>(indirectCall, calledFn, |
745 | 0 | indirectCall.getResultTypes(), |
746 | 0 | indirectCall.getArgOperands()); |
747 | 0 | return success(); |
748 | 0 | } |
749 | | }; |
750 | | } // end anonymous namespace. |
751 | | |
752 | | void CallIndirectOp::getCanonicalizationPatterns( |
753 | 0 | OwningRewritePatternList &results, MLIRContext *context) { |
754 | 0 | results.insert<SimplifyIndirectCallWithKnownCallee>(context); |
755 | 0 | } |
756 | | |
757 | | //===----------------------------------------------------------------------===// |
758 | | // General helpers for comparison ops |
759 | | //===----------------------------------------------------------------------===// |
760 | | |
761 | | // Return the type of the same shape (scalar, vector or tensor) containing i1. |
762 | 0 | static Type getI1SameShape(Type type) { |
763 | 0 | auto i1Type = IntegerType::get(1, type.getContext()); |
764 | 0 | if (auto tensorType = type.dyn_cast<RankedTensorType>()) |
765 | 0 | return RankedTensorType::get(tensorType.getShape(), i1Type); |
766 | 0 | if (type.isa<UnrankedTensorType>()) |
767 | 0 | return UnrankedTensorType::get(i1Type); |
768 | 0 | if (auto vectorType = type.dyn_cast<VectorType>()) |
769 | 0 | return VectorType::get(vectorType.getShape(), i1Type); |
770 | 0 | return i1Type; |
771 | 0 | } |
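// E.g. getI1SameShape(tensor<4x2xf32>) is tensor<4x2xi1>,
// getI1SameShape(vector<8xi32>) is vector<8xi1>, and any scalar type maps
// to plain i1.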
772 | | |
773 | | //===----------------------------------------------------------------------===// |
774 | | // CmpIOp |
775 | | //===----------------------------------------------------------------------===// |
776 | | |
777 | | static void buildCmpIOp(OpBuilder &build, OperationState &result, |
778 | 0 | CmpIPredicate predicate, Value lhs, Value rhs) { |
779 | 0 | result.addOperands({lhs, rhs}); |
780 | 0 | result.types.push_back(getI1SameShape(lhs.getType())); |
781 | 0 | result.addAttribute(CmpIOp::getPredicateAttrName(), |
782 | 0 | build.getI64IntegerAttr(static_cast<int64_t>(predicate))); |
783 | 0 | } |
784 | | |
785 | | // Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer |
786 | | // comparison predicates. |
787 | | static bool applyCmpPredicate(CmpIPredicate predicate, const APInt &lhs, |
788 | 0 | const APInt &rhs) { |
789 | 0 | switch (predicate) { |
790 | 0 | case CmpIPredicate::eq: |
791 | 0 | return lhs.eq(rhs); |
792 | 0 | case CmpIPredicate::ne: |
793 | 0 | return lhs.ne(rhs); |
794 | 0 | case CmpIPredicate::slt: |
795 | 0 | return lhs.slt(rhs); |
796 | 0 | case CmpIPredicate::sle: |
797 | 0 | return lhs.sle(rhs); |
798 | 0 | case CmpIPredicate::sgt: |
799 | 0 | return lhs.sgt(rhs); |
800 | 0 | case CmpIPredicate::sge: |
801 | 0 | return lhs.sge(rhs); |
802 | 0 | case CmpIPredicate::ult: |
803 | 0 | return lhs.ult(rhs); |
804 | 0 | case CmpIPredicate::ule: |
805 | 0 | return lhs.ule(rhs); |
806 | 0 | case CmpIPredicate::ugt: |
807 | 0 | return lhs.ugt(rhs); |
808 | 0 | case CmpIPredicate::uge: |
809 | 0 | return lhs.uge(rhs); |
810 | 0 | } |
811 | 0 | llvm_unreachable("unknown comparison predicate"); |
812 | 0 | } |
813 | | |
814 | | // Constant folding hook for comparisons. |
815 | 0 | OpFoldResult CmpIOp::fold(ArrayRef<Attribute> operands) { |
816 | 0 | assert(operands.size() == 2 && "cmpi takes two arguments"); |
817 | 0 |
818 | 0 | auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>(); |
819 | 0 | auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>(); |
820 | 0 | if (!lhs || !rhs) |
821 | 0 | return {}; |
822 | 0 | |
823 | 0 | auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue()); |
824 | 0 | return IntegerAttr::get(IntegerType::get(1, getContext()), APInt(1, val)); |
825 | 0 | } |
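// For example (hypothetical constants): folding cmpi "slt" over the
// constants 3 and 5 evaluates 3 <s 5 through applyCmpPredicate and yields
// the i1 constant true.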
826 | | |
827 | | //===----------------------------------------------------------------------===// |
828 | | // CmpFOp |
829 | | //===----------------------------------------------------------------------===// |
830 | | |
831 | | static void buildCmpFOp(OpBuilder &build, OperationState &result, |
832 | 0 | CmpFPredicate predicate, Value lhs, Value rhs) { |
833 | 0 | result.addOperands({lhs, rhs}); |
834 | 0 | result.types.push_back(getI1SameShape(lhs.getType())); |
835 | 0 | result.addAttribute(CmpFOp::getPredicateAttrName(), |
836 | 0 | build.getI64IntegerAttr(static_cast<int64_t>(predicate))); |
837 | 0 | } |
838 | | |
839 | | /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point |
840 | | /// comparison predicates. |
841 | | static bool applyCmpPredicate(CmpFPredicate predicate, const APFloat &lhs, |
842 | 0 | const APFloat &rhs) { |
843 | 0 | auto cmpResult = lhs.compare(rhs); |
844 | 0 | switch (predicate) { |
845 | 0 | case CmpFPredicate::AlwaysFalse: |
846 | 0 | return false; |
847 | 0 | case CmpFPredicate::OEQ: |
848 | 0 | return cmpResult == APFloat::cmpEqual; |
849 | 0 | case CmpFPredicate::OGT: |
850 | 0 | return cmpResult == APFloat::cmpGreaterThan; |
851 | 0 | case CmpFPredicate::OGE: |
852 | 0 | return cmpResult == APFloat::cmpGreaterThan || |
853 | 0 | cmpResult == APFloat::cmpEqual; |
854 | 0 | case CmpFPredicate::OLT: |
855 | 0 | return cmpResult == APFloat::cmpLessThan; |
856 | 0 | case CmpFPredicate::OLE: |
857 | 0 | return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual; |
858 | 0 | case CmpFPredicate::ONE: |
859 | 0 | return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual; |
860 | 0 | case CmpFPredicate::ORD: |
861 | 0 | return cmpResult != APFloat::cmpUnordered; |
862 | 0 | case CmpFPredicate::UEQ: |
863 | 0 | return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual; |
864 | 0 | case CmpFPredicate::UGT: |
865 | 0 | return cmpResult == APFloat::cmpUnordered || |
866 | 0 | cmpResult == APFloat::cmpGreaterThan; |
867 | 0 | case CmpFPredicate::UGE: |
868 | 0 | return cmpResult == APFloat::cmpUnordered || |
869 | 0 | cmpResult == APFloat::cmpGreaterThan || |
870 | 0 | cmpResult == APFloat::cmpEqual; |
871 | 0 | case CmpFPredicate::ULT: |
872 | 0 | return cmpResult == APFloat::cmpUnordered || |
873 | 0 | cmpResult == APFloat::cmpLessThan; |
874 | 0 | case CmpFPredicate::ULE: |
875 | 0 | return cmpResult == APFloat::cmpUnordered || |
876 | 0 | cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual; |
877 | 0 | case CmpFPredicate::UNE: |
878 | 0 | return cmpResult != APFloat::cmpEqual; |
879 | 0 | case CmpFPredicate::UNO: |
880 | 0 | return cmpResult == APFloat::cmpUnordered; |
881 | 0 | case CmpFPredicate::AlwaysTrue: |
882 | 0 | return true; |
883 | 0 | } |
884 | 0 | llvm_unreachable("unknown comparison predicate"); |
885 | 0 | } |
886 | | |
887 | | // Constant folding hook for comparisons. |
888 | 0 | OpFoldResult CmpFOp::fold(ArrayRef<Attribute> operands) { |
889 | 0 | assert(operands.size() == 2 && "cmpf takes two arguments"); |
890 | 0 |
891 | 0 | auto lhs = operands.front().dyn_cast_or_null<FloatAttr>(); |
892 | 0 | auto rhs = operands.back().dyn_cast_or_null<FloatAttr>(); |
893 | 0 |
894 | 0 | // TODO(gcmn) We could actually do some intelligent things if we know only one |
895 | 0 | // of the operands and it's inf or nan.
896 | 0 | if (!lhs || !rhs) |
897 | 0 | return {}; |
898 | 0 | |
899 | 0 | auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue()); |
900 | 0 | return IntegerAttr::get(IntegerType::get(1, getContext()), APInt(1, val)); |
901 | 0 | } |
902 | | |
903 | | //===----------------------------------------------------------------------===// |
904 | | // CondBranchOp |
905 | | //===----------------------------------------------------------------------===// |
906 | | |
907 | | namespace { |
908 | | /// cond_br true, ^bb1, ^bb2 |
909 | | /// -> br ^bb1 |
910 | | /// cond_br false, ^bb1, ^bb2 |
911 | | /// -> br ^bb2 |
912 | | /// |
913 | | struct SimplifyConstCondBranchPred : public OpRewritePattern<CondBranchOp> { |
914 | | using OpRewritePattern<CondBranchOp>::OpRewritePattern; |
915 | | |
916 | | LogicalResult matchAndRewrite(CondBranchOp condbr, |
917 | 0 | PatternRewriter &rewriter) const override { |
918 | 0 | if (matchPattern(condbr.getCondition(), m_NonZero())) { |
919 | 0 | // True branch taken. |
920 | 0 | rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(), |
921 | 0 | condbr.getTrueOperands()); |
922 | 0 | return success(); |
923 | 0 | } else if (matchPattern(condbr.getCondition(), m_Zero())) { |
924 | 0 | // False branch taken. |
925 | 0 | rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(), |
926 | 0 | condbr.getFalseOperands()); |
927 | 0 | return success(); |
928 | 0 | } |
929 | 0 | return failure(); |
930 | 0 | } |
931 | | }; |
932 | | |
933 | | /// cond_br %cond, ^bb1, ^bb2 |
934 | | /// ^bb1 |
935 | | /// br ^bbN(...) |
936 | | /// ^bb2 |
937 | | /// br ^bbK(...) |
938 | | /// |
939 | | /// -> cond_br %cond, ^bbN(...), ^bbK(...) |
940 | | /// |
941 | | struct SimplifyPassThroughCondBranch : public OpRewritePattern<CondBranchOp> { |
942 | | using OpRewritePattern<CondBranchOp>::OpRewritePattern; |
943 | | |
944 | | LogicalResult matchAndRewrite(CondBranchOp condbr, |
945 | 0 | PatternRewriter &rewriter) const override { |
946 | 0 | Block *trueDest = condbr.trueDest(), *falseDest = condbr.falseDest(); |
947 | 0 | ValueRange trueDestOperands = condbr.getTrueOperands(); |
948 | 0 | ValueRange falseDestOperands = condbr.getFalseOperands(); |
949 | 0 | SmallVector<Value, 4> trueDestOperandStorage, falseDestOperandStorage; |
950 | 0 |
951 | 0 | // Try to collapse one of the current successors. |
952 | 0 | LogicalResult collapsedTrue = |
953 | 0 | collapseBranch(trueDest, trueDestOperands, trueDestOperandStorage); |
954 | 0 | LogicalResult collapsedFalse = |
955 | 0 | collapseBranch(falseDest, falseDestOperands, falseDestOperandStorage); |
956 | 0 | if (failed(collapsedTrue) && failed(collapsedFalse)) |
957 | 0 | return failure(); |
958 | 0 | |
959 | 0 | // Create a new branch with the collapsed successors. |
960 | 0 | rewriter.replaceOpWithNewOp<CondBranchOp>(condbr, condbr.getCondition(), |
961 | 0 | trueDest, trueDestOperands, |
962 | 0 | falseDest, falseDestOperands); |
963 | 0 | return success(); |
964 | 0 | } |
965 | | }; |
966 | | |
967 | | /// cond_br %cond, ^bb1(A, ..., N), ^bb1(A, ..., N) |
968 | | /// -> br ^bb1(A, ..., N) |
969 | | /// |
970 | | /// cond_br %cond, ^bb1(A), ^bb1(B) |
971 | | /// -> %select = select %cond, A, B |
972 | | /// br ^bb1(%select) |
973 | | /// |
974 | | struct SimplifyCondBranchIdenticalSuccessors |
975 | | : public OpRewritePattern<CondBranchOp> { |
976 | | using OpRewritePattern<CondBranchOp>::OpRewritePattern; |
977 | | |
978 | | LogicalResult matchAndRewrite(CondBranchOp condbr, |
979 | 0 | PatternRewriter &rewriter) const override { |
980 | 0 | // Check that the true and false destinations are the same and have the same |
981 | 0 | // operands. |
982 | 0 | Block *trueDest = condbr.trueDest(); |
983 | 0 | if (trueDest != condbr.falseDest()) |
984 | 0 | return failure(); |
985 | 0 | |
986 | 0 | // If all of the operands match, no selects need to be generated. |
987 | 0 | OperandRange trueOperands = condbr.getTrueOperands(); |
988 | 0 | OperandRange falseOperands = condbr.getFalseOperands(); |
989 | 0 | if (trueOperands == falseOperands) { |
990 | 0 | rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, trueOperands); |
991 | 0 | return success(); |
992 | 0 | } |
993 | 0 | |
994 | 0 | // Otherwise, if the current block is the only predecessor, insert selects
995 | 0 | // for any mismatched branch operands. |
996 | 0 | if (trueDest->getUniquePredecessor() != condbr.getOperation()->getBlock()) |
997 | 0 | return failure(); |
998 | 0 | |
999 | 0 | // Generate a select for any operands that differ between the two. |
1000 | 0 | SmallVector<Value, 8> mergedOperands; |
1001 | 0 | mergedOperands.reserve(trueOperands.size()); |
1002 | 0 | Value condition = condbr.getCondition(); |
1003 | 0 | for (auto it : llvm::zip(trueOperands, falseOperands)) { |
1004 | 0 | if (std::get<0>(it) == std::get<1>(it)) |
1005 | 0 | mergedOperands.push_back(std::get<0>(it)); |
1006 | 0 | else |
1007 | 0 | mergedOperands.push_back(rewriter.create<SelectOp>( |
1008 | 0 | condbr.getLoc(), condition, std::get<0>(it), std::get<1>(it))); |
1009 | 0 | } |
1010 | 0 |
1011 | 0 | rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, mergedOperands); |
1012 | 0 | return success(); |
1013 | 0 | } |
1014 | | }; |
1015 | | } // end anonymous namespace |
1016 | | |
1017 | | void CondBranchOp::getCanonicalizationPatterns( |
1018 | 0 | OwningRewritePatternList &results, MLIRContext *context) { |
1019 | 0 | results.insert<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch, |
1020 | 0 | SimplifyCondBranchIdenticalSuccessors>(context); |
1021 | 0 | } |
1022 | | |
1023 | | Optional<MutableOperandRange> |
1024 | 0 | CondBranchOp::getMutableSuccessorOperands(unsigned index) { |
1025 | 0 | assert(index < getNumSuccessors() && "invalid successor index"); |
1026 | 0 | return index == trueIndex ? trueDestOperandsMutable() |
1027 | 0 | : falseDestOperandsMutable(); |
1028 | 0 | } |
1029 | | |
1030 | 0 | Block *CondBranchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) { |
1031 | 0 | if (BoolAttr condAttr = operands.front().dyn_cast_or_null<BoolAttr>()) |
1032 | 0 | return condAttr.getValue() ? trueDest() : falseDest(); |
1033 | 0 | if (IntegerAttr condAttr = operands.front().dyn_cast_or_null<IntegerAttr>()) |
1034 | 0 | return condAttr.getValue().isOneValue() ? trueDest() : falseDest(); |
1035 | 0 | return nullptr; |
1036 | 0 | } |
1037 | | |
1038 | | //===----------------------------------------------------------------------===// |
1039 | | // Constant*Op |
1040 | | //===----------------------------------------------------------------------===// |
1041 | | |
1042 | 0 | static void print(OpAsmPrinter &p, ConstantOp &op) { |
1043 | 0 | p << "constant "; |
1044 | 0 | p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"value"}); |
1045 | 0 |
1046 | 0 | if (op.getAttrs().size() > 1) |
1047 | 0 | p << ' '; |
1048 | 0 | p << op.getValue(); |
1049 | 0 |
1050 | 0 | // If the value is a symbol reference, print a trailing type. |
1051 | 0 | if (op.getValue().isa<SymbolRefAttr>()) |
1052 | 0 | p << " : " << op.getType(); |
1053 | 0 | } |
1054 | | |
1055 | | static ParseResult parseConstantOp(OpAsmParser &parser, |
1056 | 0 | OperationState &result) { |
1057 | 0 | Attribute valueAttr; |
1058 | 0 | if (parser.parseOptionalAttrDict(result.attributes) || |
1059 | 0 | parser.parseAttribute(valueAttr, "value", result.attributes)) |
1060 | 0 | return failure(); |
1061 | 0 | |
1062 | 0 | // If the attribute is a symbol reference, then we expect a trailing type. |
1063 | 0 | Type type; |
1064 | 0 | if (!valueAttr.isa<SymbolRefAttr>()) |
1065 | 0 | type = valueAttr.getType(); |
1066 | 0 | else if (parser.parseColonType(type)) |
1067 | 0 | return failure(); |
1068 | 0 | |
1069 | 0 | // Add the attribute type to the list. |
1070 | 0 | return parser.addTypeToList(type, result.types); |
1071 | 0 | } |
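// Round-trip examples of the custom form (SSA names hypothetical):
//   %c42 = constant 42 : i32            // type taken from the attribute
//   %f = constant @foo : (i32) -> i32   // symbol ref: trailing type required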
1072 | | |
1073 | | /// The constant op requires an attribute, and furthermore requires that it |
1074 | | /// matches the return type. |
1075 | 0 | static LogicalResult verify(ConstantOp &op) { |
1076 | 0 | auto value = op.getValue(); |
1077 | 0 | if (!value) |
1078 | 0 | return op.emitOpError("requires a 'value' attribute"); |
1079 | 0 | |
1080 | 0 | auto type = op.getType(); |
1081 | 0 | if (!value.getType().isa<NoneType>() && type != value.getType()) |
1082 | 0 | return op.emitOpError() << "requires attribute's type (" << value.getType() |
1083 | 0 | << ") to match op's return type (" << type << ")"; |
1084 | 0 | |
1085 | 0 | if (type.isa<IndexType>() || value.isa<BoolAttr>()) |
1086 | 0 | return success(); |
1087 | 0 | |
1088 | 0 | if (auto intAttr = value.dyn_cast<IntegerAttr>()) { |
1089 | 0 | // If the type has a known bitwidth we verify that the value can be |
1090 | 0 | // represented with the given bitwidth. |
1091 | 0 | auto bitwidth = type.cast<IntegerType>().getWidth(); |
1092 | 0 | auto intVal = intAttr.getValue(); |
1093 | 0 | if (!intVal.isSignedIntN(bitwidth) && !intVal.isIntN(bitwidth)) |
1094 | 0 | return op.emitOpError("requires 'value' to be an integer within the " |
1095 | 0 | "range of the integer result type"); |
1096 | 0 | return success(); |
1097 | 0 | } |
1098 | 0 | |
1099 | 0 | if (type.isa<FloatType>()) { |
1100 | 0 | if (!value.isa<FloatAttr>()) |
1101 | 0 | return op.emitOpError("requires 'value' to be a floating point constant"); |
1102 | 0 | return success(); |
1103 | 0 | } |
1104 | 0 | |
1105 | 0 | if (type.isa<ShapedType>()) { |
1106 | 0 | if (!value.isa<ElementsAttr>()) |
1107 | 0 | return op.emitOpError("requires 'value' to be a shaped constant"); |
1108 | 0 | return success(); |
1109 | 0 | } |
1110 | 0 | |
1111 | 0 | if (type.isa<FunctionType>()) { |
1112 | 0 | auto fnAttr = value.dyn_cast<FlatSymbolRefAttr>(); |
1113 | 0 | if (!fnAttr) |
1114 | 0 | return op.emitOpError("requires 'value' to be a function reference"); |
1115 | 0 | |
1116 | 0 | // Try to find the referenced function. |
1117 | 0 | auto fn = |
1118 | 0 | op.getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(fnAttr.getValue()); |
1119 | 0 | if (!fn) |
1120 | 0 | return op.emitOpError() << "reference to undefined function '" << fnAttr.getValue() << "'";
1121 | 0 | |
1122 | 0 | // Check that the referenced function has the correct type. |
1123 | 0 | if (fn.getType() != type) |
1124 | 0 | return op.emitOpError("reference to function with mismatched type"); |
1125 | 0 | |
1126 | 0 | return success(); |
1127 | 0 | } |
1128 | 0 | |
1129 | 0 | if (type.isa<NoneType>() && value.isa<UnitAttr>()) |
1130 | 0 | return success(); |
1131 | 0 | |
1132 | 0 | return op.emitOpError("unsupported 'value' attribute: ") << value; |
1133 | 0 | } |
1134 | | |
1135 | 0 | OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) { |
1136 | 0 | assert(operands.empty() && "constant has no operands"); |
1137 | 0 | return getValue(); |
1138 | 0 | } |
1139 | | |
1140 | | void ConstantOp::getAsmResultNames( |
1141 | 0 | function_ref<void(Value, StringRef)> setNameFn) { |
1142 | 0 | Type type = getType(); |
1143 | 0 | if (auto intCst = getValue().dyn_cast<IntegerAttr>()) { |
1144 | 0 | IntegerType intTy = type.dyn_cast<IntegerType>(); |
1145 | 0 |
1146 | 0 | // Sugar i1 constants with 'true' and 'false'. |
1147 | 0 | if (intTy && intTy.getWidth() == 1) |
1148 | 0 | return setNameFn(getResult(), (intCst.getInt() ? "true" : "false")); |
1149 | 0 |
1150 | 0 | // Otherwise, build a complex name with the value and type. |
1151 | 0 | SmallString<32> specialNameBuffer; |
1152 | 0 | llvm::raw_svector_ostream specialName(specialNameBuffer); |
1153 | 0 | specialName << 'c' << intCst.getInt(); |
1154 | 0 | if (intTy) |
1155 | 0 | specialName << '_' << type; |
1156 | 0 | setNameFn(getResult(), specialName.str()); |
1157 | 0 |
1158 | 0 | } else if (type.isa<FunctionType>()) { |
1159 | 0 | setNameFn(getResult(), "f"); |
1160 | 0 | } else { |
1161 | 0 | setNameFn(getResult(), "cst"); |
1162 | 0 | } |
1163 | 0 | } |
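// Resulting names (illustrative): an i1 constant prints as %true or
// %false, "constant 42 : i32" gets %c42_i32, an index constant gets plain
// %c42, a function reference gets %f, and everything else falls back to
// %cst.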
1164 | | |
1165 | | /// Returns true if a constant operation can be built with the given value and |
1166 | | /// result type. |
1167 | 0 | bool ConstantOp::isBuildableWith(Attribute value, Type type) { |
1168 | 0 | // SymbolRefAttr can only be used with a function type. |
1169 | 0 | if (value.isa<SymbolRefAttr>()) |
1170 | 0 | return type.isa<FunctionType>(); |
1171 | 0 | // Otherwise, the attribute must have the same type as 'type'. |
1172 | 0 | if (value.getType() != type) |
1173 | 0 | return false; |
1174 | 0 | // Finally, check that the attribute kind is handled. |
1175 | 0 | return value.isa<BoolAttr>() || value.isa<IntegerAttr>() || |
1176 | 0 | value.isa<FloatAttr>() || value.isa<ElementsAttr>() || |
1177 | 0 | value.isa<UnitAttr>(); |
1178 | 0 | } |
1179 | | |
1180 | | void ConstantFloatOp::build(OpBuilder &builder, OperationState &result, |
1181 | 0 | const APFloat &value, FloatType type) { |
1182 | 0 | ConstantOp::build(builder, result, type, builder.getFloatAttr(type, value)); |
1183 | 0 | } |
1184 | | |
1185 | 0 | bool ConstantFloatOp::classof(Operation *op) { |
1186 | 0 | return ConstantOp::classof(op) && op->getResult(0).getType().isa<FloatType>(); |
1187 | 0 | } |
1188 | | |
1189 | | /// ConstantIntOp only matches values whose result type is an IntegerType. |
1190 | 0 | bool ConstantIntOp::classof(Operation *op) { |
1191 | 0 | return ConstantOp::classof(op) && |
1192 | 0 | op->getResult(0).getType().isSignlessInteger(); |
1193 | 0 | } |
1194 | | |
1195 | | void ConstantIntOp::build(OpBuilder &builder, OperationState &result, |
1196 | 0 | int64_t value, unsigned width) { |
1197 | 0 | Type type = builder.getIntegerType(width); |
1198 | 0 | ConstantOp::build(builder, result, type, builder.getIntegerAttr(type, value)); |
1199 | 0 | } |
1200 | | |
1201 | | /// Build a constant int op producing an integer with the specified type, |
1202 | | /// which must be an integer type. |
1203 | | void ConstantIntOp::build(OpBuilder &builder, OperationState &result, |
1204 | 0 | int64_t value, Type type) { |
1205 | 0 | assert(type.isSignlessInteger() && |
1206 | 0 | "ConstantIntOp can only have signless integer type"); |
1207 | 0 | ConstantOp::build(builder, result, type, builder.getIntegerAttr(type, value)); |
1208 | 0 | } |
1209 | | |
1210 | | /// ConstantIndexOp only matches values whose result type is Index. |
1211 | 0 | bool ConstantIndexOp::classof(Operation *op) { |
1212 | 0 | return ConstantOp::classof(op) && op->getResult(0).getType().isIndex(); |
1213 | 0 | } |
1214 | | |
1215 | | void ConstantIndexOp::build(OpBuilder &builder, OperationState &result, |
1216 | 0 | int64_t value) { |
1217 | 0 | Type type = builder.getIndexType(); |
1218 | 0 | ConstantOp::build(builder, result, type, builder.getIntegerAttr(type, value)); |
1219 | 0 | } |
1220 | | |
1221 | | //===----------------------------------------------------------------------===// |
1222 | | // DeallocOp |
1223 | | //===----------------------------------------------------------------------===// |
1224 | | namespace { |
1225 | | /// Fold Dealloc operations that are deallocating an AllocOp that is only used |
1226 | | /// by other Dealloc operations. |
1227 | | struct SimplifyDeadDealloc : public OpRewritePattern<DeallocOp> { |
1228 | | using OpRewritePattern<DeallocOp>::OpRewritePattern; |
1229 | | |
1230 | | LogicalResult matchAndRewrite(DeallocOp dealloc, |
1231 | 0 | PatternRewriter &rewriter) const override { |
1232 | 0 | // Check that the memref operand's defining operation is an AllocOp. |
1233 | 0 | Value memref = dealloc.memref(); |
1234 | 0 | if (!isa_and_nonnull<AllocOp>(memref.getDefiningOp())) |
1235 | 0 | return failure(); |
1236 | 0 | |
1237 | 0 | // Check that all of the uses of the AllocOp are other DeallocOps. |
1238 | 0 | for (auto *user : memref.getUsers()) |
1239 | 0 | if (!isa<DeallocOp>(user)) |
1240 | 0 | return failure(); |
1241 | 0 | 
1242 | 0 | // Erase the dealloc operation. |
1243 | 0 | rewriter.eraseOp(dealloc); |
1244 | 0 | return success(); |
1245 | 0 | } |
1246 | | }; |
1247 | | } // end anonymous namespace. |
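 | | 
 | | // Hypothetical IR for the pattern above (editorial illustration):
 | | //
 | | //   %0 = alloc() : memref<8xf32>
 | | //   dealloc %0 : memref<8xf32>   // erased: %0 is only ever deallocated
 | | //
 | | // Once all deallocs are erased, the now-unused alloc can itself be removed
 | | // by dead-code elimination.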
1248 | | |
1249 | 0 | static LogicalResult verify(DeallocOp op) { |
1250 | 0 | if (!op.memref().getType().isa<MemRefType>()) |
1251 | 0 | return op.emitOpError("operand must be a memref"); |
1252 | 0 | return success(); |
1253 | 0 | } |
1254 | | |
1255 | | void DeallocOp::getCanonicalizationPatterns(OwningRewritePatternList &results, |
1256 | 0 | MLIRContext *context) { |
1257 | 0 | results.insert<SimplifyDeadDealloc>(context); |
1258 | 0 | } |
1259 | | |
1260 | | LogicalResult DeallocOp::fold(ArrayRef<Attribute> cstOperands, |
1261 | 0 | SmallVectorImpl<OpFoldResult> &results) { |
1262 | 0 | /// dealloc(memrefcast) -> dealloc |
1263 | 0 | return foldMemRefCast(*this); |
1264 | 0 | } |
1265 | | |
1266 | | //===----------------------------------------------------------------------===// |
1267 | | // DimOp |
1268 | | //===----------------------------------------------------------------------===// |
1269 | | |
1270 | 0 | static void print(OpAsmPrinter &p, DimOp op) { |
1271 | 0 | p << "dim " << op.getOperand() << ", " << op.getIndex(); |
1272 | 0 | p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"index"}); |
1273 | 0 | p << " : " << op.getOperand().getType(); |
1274 | 0 | } |
1275 | | |
1276 | 0 | static ParseResult parseDimOp(OpAsmParser &parser, OperationState &result) { |
1277 | 0 | OpAsmParser::OperandType operandInfo; |
1278 | 0 | IntegerAttr indexAttr; |
1279 | 0 | Type type; |
1280 | 0 | Type indexType = parser.getBuilder().getIndexType(); |
1281 | 0 | 
1282 | 0 | return failure( |
1283 | 0 | parser.parseOperand(operandInfo) || parser.parseComma() || |
1284 | 0 | parser.parseAttribute(indexAttr, indexType, "index", result.attributes) || |
1285 | 0 | parser.parseOptionalAttrDict(result.attributes) || |
1286 | 0 | parser.parseColonType(type) || |
1287 | 0 | parser.resolveOperand(operandInfo, type, result.operands) || |
1288 | 0 | parser.addTypeToList(indexType, result.types)); |
1289 | 0 | } |
1290 | | |
1291 | 0 | static LogicalResult verify(DimOp op) { |
1292 | 0 | // Check that we have an integer index operand. |
1293 | 0 | auto indexAttr = op.getAttrOfType<IntegerAttr>("index"); |
1294 | 0 | if (!indexAttr) |
1295 | 0 | return op.emitOpError("requires an integer attribute named 'index'"); |
1296 | 0 | int64_t index = indexAttr.getInt(); |
1297 | 0 | 
1298 | 0 | auto type = op.getOperand().getType(); |
1299 | 0 | if (auto tensorType = type.dyn_cast<RankedTensorType>()) { |
1300 | 0 | if (index >= tensorType.getRank()) |
1301 | 0 | return op.emitOpError("index is out of range"); |
1302 | 0 | } else if (auto memrefType = type.dyn_cast<MemRefType>()) { |
1303 | 0 | if (index >= memrefType.getRank()) |
1304 | 0 | return op.emitOpError("index is out of range"); |
1305 | 0 | |
1306 | 0 | } else if (type.isa<UnrankedTensorType>()) { |
1307 | 0 | // ok, assumed to be in-range. |
1308 | 0 | } else { |
1309 | 0 | return op.emitOpError("requires an operand with tensor or memref type"); |
1310 | 0 | } |
1311 | 0 | |
1312 | 0 | return success(); |
1313 | 0 | } |
1314 | | |
1315 | 0 | OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) { |
1316 | 0 | // Constant fold dim when the size along the index referred to is a constant. |
1317 | 0 | auto opType = memrefOrTensor().getType(); |
1318 | 0 | if (auto shapedType = opType.dyn_cast<ShapedType>()) |
1319 | 0 | if (!shapedType.isDynamicDim(getIndex())) |
1320 | 0 | return IntegerAttr::get(IndexType::get(getContext()), |
1321 | 0 | shapedType.getShape()[getIndex()]); |
1322 | 0 | |
1323 | 0 | // Fold dim to the size argument for an AllocOp/ViewOp/SubViewOp. |
1324 | 0 | auto memrefType = opType.dyn_cast<MemRefType>(); |
1325 | 0 | if (!memrefType) |
1326 | 0 | return {}; |
1327 | 0 | |
1328 | 0 | // The size at getIndex() is now known to be a dynamic size of a memref. |
1329 | 0 | auto memref = memrefOrTensor().getDefiningOp(); |
1330 | 0 | if (auto alloc = dyn_cast_or_null<AllocOp>(memref)) |
1331 | 0 | return *(alloc.getDynamicSizes().begin() + |
1332 | 0 | memrefType.getDynamicDimIndex(getIndex())); |
1333 | 0 | |
1334 | 0 | if (auto view = dyn_cast_or_null<ViewOp>(memref)) |
1335 | 0 | return *(view.getDynamicSizes().begin() + |
1336 | 0 | memrefType.getDynamicDimIndex(getIndex())); |
1337 | 0 | |
1338 | 0 | if (auto subview = dyn_cast_or_null<SubViewOp>(memref)) { |
1339 | 0 | assert(subview.isDynamicSize(getIndex()) && |
1340 | 0 | "Expected dynamic subview size"); |
1341 | 0 | return subview.getDynamicSize(getIndex()); |
1342 | 0 | } |
1343 | 0 | |
1344 | 0 | /// dim(memrefcast) -> dim |
1345 | 0 | if (succeeded(foldMemRefCast(*this))) |
1346 | 0 | return getResult(); |
1347 | 0 | |
1348 | 0 | return {}; |
1349 | 0 | } |
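 | | 
 | | // Editorial illustration of the folds above (hypothetical values):
 | | //
 | | //   %0 = alloc(%n) : memref<4x?xf32>
 | | //   %d0 = dim %0, 0 : memref<4x?xf32>   // folds to the index constant 4
 | | //   %d1 = dim %0, 1 : memref<4x?xf32>   // folds to %n, the alloc's size operand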
1350 | | |
1351 | | // --------------------------------------------------------------------------- |
1352 | | // DmaStartOp |
1353 | | // --------------------------------------------------------------------------- |
1354 | | |
1355 | | void DmaStartOp::build(OpBuilder &builder, OperationState &result, |
1356 | | Value srcMemRef, ValueRange srcIndices, Value destMemRef, |
1357 | | ValueRange destIndices, Value numElements, |
1358 | | Value tagMemRef, ValueRange tagIndices, Value stride, |
1359 | 0 | Value elementsPerStride) { |
1360 | 0 | result.addOperands(srcMemRef); |
1361 | 0 | result.addOperands(srcIndices); |
1362 | 0 | result.addOperands(destMemRef); |
1363 | 0 | result.addOperands(destIndices); |
1364 | 0 | result.addOperands({numElements, tagMemRef}); |
1365 | 0 | result.addOperands(tagIndices); |
1366 | 0 | if (stride) |
1367 | 0 | result.addOperands({stride, elementsPerStride}); |
1368 | 0 | } |
1369 | | |
1370 | 0 | void DmaStartOp::print(OpAsmPrinter &p) { |
1371 | 0 | p << "dma_start " << getSrcMemRef() << '[' << getSrcIndices() << "], " |
1372 | 0 | << getDstMemRef() << '[' << getDstIndices() << "], " << getNumElements() |
1373 | 0 | << ", " << getTagMemRef() << '[' << getTagIndices() << ']'; |
1374 | 0 | if (isStrided()) |
1375 | 0 | p << ", " << getStride() << ", " << getNumElementsPerStride(); |
1376 | 0 | 
1377 | 0 | p.printOptionalAttrDict(getAttrs()); |
1378 | 0 | p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType() |
1379 | 0 | << ", " << getTagMemRef().getType(); |
1380 | 0 | } |
1381 | | |
1382 | | // Parse DmaStartOp. |
1383 | | // Ex: |
1384 | | // %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size, |
1385 | | //                       %tag[%index], %stride, %num_elt_per_stride
1386 | | // : memref<3076 x f32, 0>, |
1387 | | // memref<1024 x f32, 2>, |
1388 | | // memref<1 x i32> |
1389 | | // |
1390 | 0 | ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) { |
1391 | 0 | OpAsmParser::OperandType srcMemRefInfo; |
1392 | 0 | SmallVector<OpAsmParser::OperandType, 4> srcIndexInfos; |
1393 | 0 | OpAsmParser::OperandType dstMemRefInfo; |
1394 | 0 | SmallVector<OpAsmParser::OperandType, 4> dstIndexInfos; |
1395 | 0 | OpAsmParser::OperandType numElementsInfo; |
1396 | 0 | OpAsmParser::OperandType tagMemrefInfo; |
1397 | 0 | SmallVector<OpAsmParser::OperandType, 4> tagIndexInfos; |
1398 | 0 | SmallVector<OpAsmParser::OperandType, 2> strideInfo; |
1399 | 0 | 
1400 | 0 | SmallVector<Type, 3> types; |
1401 | 0 | auto indexType = parser.getBuilder().getIndexType(); |
1402 | 0 | 
1403 | 0 | // Parse and resolve the following list of operands: |
1404 | 0 | // *) source memref followed by its indices (in square brackets). |
1405 | 0 | // *) destination memref followed by its indices (in square brackets). |
1406 | 0 | // *) number of elements to transfer.
1407 | 0 | if (parser.parseOperand(srcMemRefInfo) || |
1408 | 0 | parser.parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) || |
1409 | 0 | parser.parseComma() || parser.parseOperand(dstMemRefInfo) || |
1410 | 0 | parser.parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) || |
1411 | 0 | parser.parseComma() || parser.parseOperand(numElementsInfo) || |
1412 | 0 | parser.parseComma() || parser.parseOperand(tagMemrefInfo) || |
1413 | 0 | parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square)) |
1414 | 0 | return failure(); |
1415 | 0 | |
1416 | 0 | // Parse optional stride and elements per stride. |
1417 | 0 | if (parser.parseTrailingOperandList(strideInfo)) |
1418 | 0 | return failure(); |
1419 | 0 | |
1420 | 0 | bool isStrided = strideInfo.size() == 2; |
1421 | 0 | if (!strideInfo.empty() && !isStrided) { |
1422 | 0 | return parser.emitError(parser.getNameLoc(), |
1423 | 0 | "expected two stride related operands"); |
1424 | 0 | } |
1425 | 0 | |
1426 | 0 | if (parser.parseColonTypeList(types)) |
1427 | 0 | return failure(); |
1428 | 0 | if (types.size() != 3) |
1429 | 0 | return parser.emitError(parser.getNameLoc(), "expected three types");
1430 | 0 | |
1431 | 0 | if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) || |
1432 | 0 | parser.resolveOperands(srcIndexInfos, indexType, result.operands) || |
1433 | 0 | parser.resolveOperand(dstMemRefInfo, types[1], result.operands) || |
1434 | 0 | parser.resolveOperands(dstIndexInfos, indexType, result.operands) || |
1435 | 0 | // size should be an index. |
1436 | 0 | parser.resolveOperand(numElementsInfo, indexType, result.operands) || |
1437 | 0 | parser.resolveOperand(tagMemrefInfo, types[2], result.operands) || |
1438 | 0 | // tag indices should be index. |
1439 | 0 | parser.resolveOperands(tagIndexInfos, indexType, result.operands)) |
1440 | 0 | return failure(); |
1441 | 0 | |
1442 | 0 | if (isStrided) { |
1443 | 0 | if (parser.resolveOperands(strideInfo, indexType, result.operands)) |
1444 | 0 | return failure(); |
1445 | 0 | } |
1446 | 0 | |
1447 | 0 | return success(); |
1448 | 0 | } |
1449 | | |
1450 | 0 | LogicalResult DmaStartOp::verify() { |
1451 | 0 | unsigned numOperands = getNumOperands(); |
1452 | 0 | 
1453 | 0 | // Mandatory non-variadic operands are: src memref, dst memref, tag memref and |
1454 | 0 | // the number of elements. |
1455 | 0 | if (numOperands < 4) |
1456 | 0 | return emitOpError("expected at least 4 operands"); |
1457 | 0 | |
1458 | 0 | // Check types of operands. The order of these calls is important: the later |
1459 | 0 | // calls rely on some type properties to compute the operand position. |
1460 | 0 | // 1. Source memref. |
1461 | 0 | if (!getSrcMemRef().getType().isa<MemRefType>()) |
1462 | 0 | return emitOpError("expected source to be of memref type"); |
1463 | 0 | if (numOperands < getSrcMemRefRank() + 4) |
1464 | 0 | return emitOpError() << "expected at least " << getSrcMemRefRank() + 4 |
1465 | 0 | << " operands"; |
1466 | 0 | if (!getSrcIndices().empty() && |
1467 | 0 | !llvm::all_of(getSrcIndices().getTypes(), |
1468 | 0 | [](Type t) { return t.isIndex(); })) |
1469 | 0 | return emitOpError("expected source indices to be of index type"); |
1470 | 0 | |
1471 | 0 | // 2. Destination memref. |
1472 | 0 | if (!getDstMemRef().getType().isa<MemRefType>()) |
1473 | 0 | return emitOpError("expected destination to be of memref type"); |
1474 | 0 | unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4; |
1475 | 0 | if (numOperands < numExpectedOperands) |
1476 | 0 | return emitOpError() << "expected at least " << numExpectedOperands |
1477 | 0 | << " operands"; |
1478 | 0 | if (!getDstIndices().empty() && |
1479 | 0 | !llvm::all_of(getDstIndices().getTypes(), |
1480 | 0 | [](Type t) { return t.isIndex(); })) |
1481 | 0 | return emitOpError("expected destination indices to be of index type"); |
1482 | 0 | |
1483 | 0 | // 3. Number of elements. |
1484 | 0 | if (!getNumElements().getType().isIndex()) |
1485 | 0 | return emitOpError("expected num elements to be of index type"); |
1486 | 0 | |
1487 | 0 | // 4. Tag memref. |
1488 | 0 | if (!getTagMemRef().getType().isa<MemRefType>()) |
1489 | 0 | return emitOpError("expected tag to be of memref type"); |
1490 | 0 | numExpectedOperands += getTagMemRefRank(); |
1491 | 0 | if (numOperands < numExpectedOperands) |
1492 | 0 | return emitOpError() << "expected at least " << numExpectedOperands |
1493 | 0 | << " operands"; |
1494 | 0 | if (!getTagIndices().empty() && |
1495 | 0 | !llvm::all_of(getTagIndices().getTypes(), |
1496 | 0 | [](Type t) { return t.isIndex(); })) |
1497 | 0 | return emitOpError("expected tag indices to be of index type"); |
1498 | 0 | |
1499 | 0 | // Only DMAs between different memory spaces are supported.
1500 | 0 | if (getSrcMemorySpace() == getDstMemorySpace()) |
1501 | 0 | return emitOpError("DMA should be between different memory spaces"); |
1502 | 0 | |
1503 | 0 | // Optional stride-related operands must be either both present or both |
1504 | 0 | // absent. |
1505 | 0 | if (numOperands != numExpectedOperands && |
1506 | 0 | numOperands != numExpectedOperands + 2) |
1507 | 0 | return emitOpError("incorrect number of operands"); |
1508 | 0 | |
1509 | 0 | // 5. Strides. |
1510 | 0 | if (isStrided()) { |
1511 | 0 | if (!getStride().getType().isIndex() || |
1512 | 0 | !getNumElementsPerStride().getType().isIndex()) |
1513 | 0 | return emitOpError( |
1514 | 0 | "expected stride and num elements per stride to be of type index"); |
1515 | 0 | } |
1516 | 0 | |
1517 | 0 | return success(); |
1518 | 0 | } |
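 | | 
 | | // Editorial worked example of the operand accounting above: a rank-2
 | | // source, rank-2 destination and rank-1 tag require 2 + 2 + 1 + 4 = 9
 | | // operands, or 11 when the optional stride pair is present.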
1519 | | |
1520 | | LogicalResult DmaStartOp::fold(ArrayRef<Attribute> cstOperands, |
1521 | 0 | SmallVectorImpl<OpFoldResult> &results) { |
1522 | 0 | /// dma_start(memrefcast) -> dma_start |
1523 | 0 | return foldMemRefCast(*this); |
1524 | 0 | } |
1525 | | |
1526 | | // --------------------------------------------------------------------------- |
1527 | | // DmaWaitOp |
1528 | | // --------------------------------------------------------------------------- |
1529 | | |
1530 | | void DmaWaitOp::build(OpBuilder &builder, OperationState &result, |
1531 | | Value tagMemRef, ValueRange tagIndices, |
1532 | 0 | Value numElements) { |
1533 | 0 | result.addOperands(tagMemRef); |
1534 | 0 | result.addOperands(tagIndices); |
1535 | 0 | result.addOperands(numElements); |
1536 | 0 | } |
1537 | | |
1538 | 0 | void DmaWaitOp::print(OpAsmPrinter &p) { |
1539 | 0 | p << "dma_wait " << getTagMemRef() << '[' << getTagIndices() << "], " |
1540 | 0 | << getNumElements(); |
1541 | 0 | p.printOptionalAttrDict(getAttrs()); |
1542 | 0 | p << " : " << getTagMemRef().getType(); |
1543 | 0 | } |
1544 | | |
1545 | | // Parse DmaWaitOp. |
1546 | | // Eg: |
1547 | | // dma_wait %tag[%index], %num_elements : memref<1 x i32, (d0) -> (d0), 4> |
1548 | | // |
1549 | 0 | ParseResult DmaWaitOp::parse(OpAsmParser &parser, OperationState &result) { |
1550 | 0 | OpAsmParser::OperandType tagMemrefInfo; |
1551 | 0 | SmallVector<OpAsmParser::OperandType, 2> tagIndexInfos; |
1552 | 0 | Type type; |
1553 | 0 | auto indexType = parser.getBuilder().getIndexType(); |
1554 | 0 | OpAsmParser::OperandType numElementsInfo; |
1555 | 0 | 
1556 | 0 | // Parse tag memref, its indices, and dma size. |
1557 | 0 | if (parser.parseOperand(tagMemrefInfo) || |
1558 | 0 | parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square) || |
1559 | 0 | parser.parseComma() || parser.parseOperand(numElementsInfo) || |
1560 | 0 | parser.parseColonType(type) || |
1561 | 0 | parser.resolveOperand(tagMemrefInfo, type, result.operands) || |
1562 | 0 | parser.resolveOperands(tagIndexInfos, indexType, result.operands) || |
1563 | 0 | parser.resolveOperand(numElementsInfo, indexType, result.operands)) |
1564 | 0 | return failure(); |
1565 | 0 | |
1566 | 0 | return success(); |
1567 | 0 | } |
1568 | | |
1569 | | LogicalResult DmaWaitOp::fold(ArrayRef<Attribute> cstOperands, |
1570 | 0 | SmallVectorImpl<OpFoldResult> &results) { |
1571 | 0 | /// dma_wait(memrefcast) -> dma_wait |
1572 | 0 | return foldMemRefCast(*this); |
1573 | 0 | } |
1574 | | |
1575 | 0 | LogicalResult DmaWaitOp::verify() { |
1576 | 0 | // Mandatory non-variadic operands are tag and the number of elements. |
1577 | 0 | if (getNumOperands() < 2) |
1578 | 0 | return emitOpError() << "expected at least 2 operands"; |
1579 | 0 | |
1580 | 0 | // Check types of operands. The order of these calls is important: the later |
1581 | 0 | // calls rely on some type properties to compute the operand position. |
1582 | 0 | if (!getTagMemRef().getType().isa<MemRefType>()) |
1583 | 0 | return emitOpError() << "expected tag to be of memref type"; |
1584 | 0 | |
1585 | 0 | if (getNumOperands() != 2 + getTagMemRefRank()) |
1586 | 0 | return emitOpError() << "expected " << 2 + getTagMemRefRank() |
1587 | 0 | << " operands"; |
1588 | 0 | |
1589 | 0 | if (!getTagIndices().empty() && |
1590 | 0 | !llvm::all_of(getTagIndices().getTypes(), |
1591 | 0 | [](Type t) { return t.isIndex(); })) |
1592 | 0 | return emitOpError() << "expected tag indices to be of index type"; |
1593 | 0 | |
1594 | 0 | if (!getNumElements().getType().isIndex()) |
1595 | 0 | return emitOpError() |
1596 | 0 | << "expected the number of elements to be of index type"; |
1597 | 0 | |
1598 | 0 | return success(); |
1599 | 0 | } |
1600 | | |
1601 | | //===----------------------------------------------------------------------===// |
1602 | | // ExtractElementOp |
1603 | | //===----------------------------------------------------------------------===// |
1604 | | |
1605 | 0 | static LogicalResult verify(ExtractElementOp op) { |
1606 | 0 | // Verify the # indices match if we have a ranked type. |
1607 | 0 | auto aggregateType = op.getAggregate().getType().cast<ShapedType>(); |
1608 | 0 | if (aggregateType.hasRank() && |
1609 | 0 | aggregateType.getRank() != op.getNumOperands() - 1) |
1610 | 0 | return op.emitOpError("incorrect number of indices for extract_element"); |
1611 | 0 | |
1612 | 0 | return success(); |
1613 | 0 | } |
1614 | | |
1615 | 0 | OpFoldResult ExtractElementOp::fold(ArrayRef<Attribute> operands) { |
1616 | 0 | assert(!operands.empty() && "extract_element takes at least one operand"); |
1617 | 0 | 
1618 | 0 | // The aggregate operand must be a known constant. |
1619 | 0 | Attribute aggregate = operands.front(); |
1620 | 0 | if (!aggregate) |
1621 | 0 | return {}; |
1622 | 0 | |
1623 | 0 | // If this is a splat elements attribute, simply return the value. All of the |
1624 | 0 | // elements of a splat attribute are the same. |
1625 | 0 | if (auto splatAggregate = aggregate.dyn_cast<SplatElementsAttr>()) |
1626 | 0 | return splatAggregate.getSplatValue(); |
1627 | 0 | |
1628 | 0 | // Otherwise, collect the constant indices into the aggregate. |
1629 | 0 | SmallVector<uint64_t, 8> indices; |
1630 | 0 | for (Attribute index : llvm::drop_begin(operands, 1)) {
1631 | 0 | if (!index || !index.isa<IntegerAttr>())
1632 | 0 | return {};
1633 | 0 | indices.push_back(index.cast<IntegerAttr>().getInt());
1634 | 0 | } |
1635 | 0 | 
1636 | 0 | // If this is an elements attribute, query the value at the given indices. |
1637 | 0 | auto elementsAttr = aggregate.dyn_cast<ElementsAttr>(); |
1638 | 0 | if (elementsAttr && elementsAttr.isValidIndex(indices)) |
1639 | 0 | return elementsAttr.getValue(indices); |
1640 | 0 | return {}; |
1641 | 0 | } |
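 | | 
 | | // Editorial illustration (hypothetical constants):
 | | //
 | | //   %t = constant dense<[10, 20, 30]> : tensor<3xi32>
 | | //   %e = extract_element %t[%c1] : tensor<3xi32>   // folds to constant 20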
1642 | | |
1643 | | //===----------------------------------------------------------------------===// |
1644 | | // TensorFromElementsOp |
1645 | | //===----------------------------------------------------------------------===// |
1646 | | |
1647 | | static ParseResult parseTensorFromElementsOp(OpAsmParser &parser, |
1648 | 0 | OperationState &result) { |
1649 | 0 | SmallVector<OpAsmParser::OperandType, 4> elementsOperands; |
1650 | 0 | Type resultType; |
1651 | 0 | if (parser.parseLParen() || parser.parseOperandList(elementsOperands) || |
1652 | 0 | parser.parseRParen() || parser.parseOptionalAttrDict(result.attributes) || |
1653 | 0 | parser.parseColon() || parser.parseType(resultType)) |
1654 | 0 | return failure(); |
1655 | 0 | |
1656 | 0 | if (parser.resolveOperands(elementsOperands, |
1657 | 0 | resultType.cast<ShapedType>().getElementType(), |
1658 | 0 | result.operands)) |
1659 | 0 | return failure(); |
1660 | 0 | |
1661 | 0 | result.addTypes(resultType); |
1662 | 0 | return success(); |
1663 | 0 | } |
1664 | | |
1665 | 0 | static void print(OpAsmPrinter &p, TensorFromElementsOp op) { |
1666 | 0 | p << "tensor_from_elements(" << op.elements() << ')'; |
1667 | 0 | p.printOptionalAttrDict(op.getAttrs()); |
1668 | 0 | p << " : " << op.result().getType(); |
1669 | 0 | } |
1670 | | |
1671 | 0 | static LogicalResult verify(TensorFromElementsOp op) { |
1672 | 0 | auto resultTensorType = op.result().getType().dyn_cast<RankedTensorType>(); |
1673 | 0 | if (!resultTensorType) |
1674 | 0 | return op.emitOpError("expected result type to be a ranked tensor"); |
1675 | 0 | |
1676 | 0 | int64_t elementsCount = static_cast<int64_t>(op.elements().size()); |
1677 | 0 | if (resultTensorType.getRank() != 1 || |
1678 | 0 | resultTensorType.getShape().front() != elementsCount) |
1679 | 0 | return op.emitOpError() |
1680 | 0 | << "expected result type to be a 1D tensor with " << elementsCount |
1681 | 0 | << (elementsCount == 1 ? " element" : " elements"); |
1682 | 0 | return success(); |
1683 | 0 | } |
1684 | | |
1685 | | namespace { |
1686 | | |
1687 | | // Canonicalizes the pattern of the form |
1688 | | // |
1689 | | // %tensor = "tensor_from_elements"(%element) : (i32) -> tensor<1xi32>
1690 | | // %extracted_element = extract_element %tensor[%c0] : tensor<1xi32> |
1691 | | // |
1692 | | // to just %element. |
1693 | | struct ExtractElementFromTensorFromElements |
1694 | | : public OpRewritePattern<ExtractElementOp> { |
1695 | | using OpRewritePattern<ExtractElementOp>::OpRewritePattern; |
1696 | | |
1697 | | LogicalResult matchAndRewrite(ExtractElementOp extract, |
1698 | 0 | PatternRewriter &rewriter) const final { |
1699 | 0 | if (extract.indices().size() != 1) |
1700 | 0 | return failure(); |
1701 | 0 | |
1702 | 0 | auto tensorFromElements = dyn_cast_or_null<TensorFromElementsOp>(
1703 | 0 | extract.aggregate().getDefiningOp());
1704 | 0 | if (!tensorFromElements)
1705 | 0 | return failure();
1706 | 0 | |
1707 | 0 | APInt index; |
1708 | 0 | if (!matchPattern(*extract.indices().begin(), m_ConstantInt(&index))) |
1709 | 0 | return failure(); |
1710 | 0 | rewriter.replaceOp(extract, |
1711 | 0 | tensorFromElements.getOperand(index.getZExtValue()));
1712 | 0 | return success(); |
1713 | 0 | } |
1714 | | }; |
1715 | | |
1716 | | } // namespace |
1717 | | |
1718 | | void TensorFromElementsOp::getCanonicalizationPatterns( |
1719 | 0 | OwningRewritePatternList &results, MLIRContext *context) { |
1720 | 0 | results.insert<ExtractElementFromTensorFromElements>(context); |
1721 | 0 | } |
1722 | | |
1723 | | //===----------------------------------------------------------------------===// |
1724 | | // FPExtOp |
1725 | | //===----------------------------------------------------------------------===// |
1726 | | |
1727 | 0 | bool FPExtOp::areCastCompatible(Type a, Type b) { |
1728 | 0 | if (auto fa = a.dyn_cast<FloatType>()) |
1729 | 0 | if (auto fb = b.dyn_cast<FloatType>()) |
1730 | 0 | return fa.getWidth() < fb.getWidth(); |
1731 | 0 | if (auto va = a.dyn_cast<VectorType>()) |
1732 | 0 | if (auto vb = b.dyn_cast<VectorType>()) |
1733 | 0 | return va.getShape().equals(vb.getShape()) && |
1734 | 0 | areCastCompatible(va.getElementType(), vb.getElementType()); |
1735 | 0 | return false; |
1736 | 0 | } |
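 | | 
 | | // Editorial summary of the predicate above:
 | | //   f16 -> f32                     : compatible (strictly widening)
 | | //   vector<4xf16> -> vector<4xf32> : compatible (same shape, wider element)
 | | //   f32 -> f32, f64 -> f32        : rejected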
1737 | | |
1738 | | //===----------------------------------------------------------------------===// |
1739 | | // FPToSIOp |
1740 | | //===----------------------------------------------------------------------===// |
1741 | | |
1742 | 0 | bool FPToSIOp::areCastCompatible(Type a, Type b) { |
1743 | 0 | return a.isa<FloatType>() && b.isSignlessInteger(); |
1744 | 0 | } |
1745 | | |
1746 | | //===----------------------------------------------------------------------===// |
1747 | | // FPTruncOp |
1748 | | //===----------------------------------------------------------------------===// |
1749 | | |
1750 | 0 | bool FPTruncOp::areCastCompatible(Type a, Type b) { |
1751 | 0 | if (auto fa = a.dyn_cast<FloatType>()) |
1752 | 0 | if (auto fb = b.dyn_cast<FloatType>()) |
1753 | 0 | return fa.getWidth() > fb.getWidth(); |
1754 | 0 | if (auto va = a.dyn_cast<VectorType>()) |
1755 | 0 | if (auto vb = b.dyn_cast<VectorType>()) |
1756 | 0 | return va.getShape().equals(vb.getShape()) && |
1757 | 0 | areCastCompatible(va.getElementType(), vb.getElementType()); |
1758 | 0 | return false; |
1759 | 0 | } |
1760 | | |
1761 | | //===----------------------------------------------------------------------===// |
1762 | | // IndexCastOp |
1763 | | //===----------------------------------------------------------------------===// |
1764 | | |
1765 | | // Index cast is applicable from index to integer and backwards. |
1766 | 0 | bool IndexCastOp::areCastCompatible(Type a, Type b) { |
1767 | 0 | return (a.isIndex() && b.isSignlessInteger()) || |
1768 | 0 | (a.isSignlessInteger() && b.isIndex()); |
1769 | 0 | } |
1770 | | |
1771 | 0 | OpFoldResult IndexCastOp::fold(ArrayRef<Attribute> cstOperands) { |
1772 | 0 | // Fold IndexCast(IndexCast(x)) -> x |
1773 | 0 | auto cast = getOperand().getDefiningOp<IndexCastOp>(); |
1774 | 0 | if (cast && cast.getOperand().getType() == getType()) |
1775 | 0 | return cast.getOperand(); |
1776 | 0 | |
1777 | 0 | // Fold IndexCast(constant) -> constant |
1778 | 0 | // A little hack because we go through int. Otherwise, the size |
1779 | 0 | // of the constant might need to change. |
1780 | 0 | if (auto value = cstOperands[0].dyn_cast_or_null<IntegerAttr>()) |
1781 | 0 | return IntegerAttr::get(getType(), value.getInt()); |
1782 | 0 | |
1783 | 0 | return {}; |
1784 | 0 | } |
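 | | 
 | | // Editorial illustration of the round-trip fold above:
 | | //
 | | //   %i = index_cast %x : i32 to index
 | | //   %j = index_cast %i : index to i32   // folds to %x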
1785 | | |
1786 | | //===----------------------------------------------------------------------===// |
1787 | | // LoadOp |
1788 | | //===----------------------------------------------------------------------===// |
1789 | | |
1790 | 0 | static LogicalResult verify(LoadOp op) { |
1791 | 0 | if (op.getNumOperands() != 1 + op.getMemRefType().getRank()) |
1792 | 0 | return op.emitOpError("incorrect number of indices for load"); |
1793 | 0 | return success(); |
1794 | 0 | } |
1795 | | |
1796 | 0 | OpFoldResult LoadOp::fold(ArrayRef<Attribute> cstOperands) { |
1797 | 0 | /// load(memrefcast) -> load |
1798 | 0 | if (succeeded(foldMemRefCast(*this))) |
1799 | 0 | return getResult(); |
1800 | 0 | return OpFoldResult(); |
1801 | 0 | } |
1802 | | |
1803 | | //===----------------------------------------------------------------------===// |
1804 | | // MemRefCastOp |
1805 | | //===----------------------------------------------------------------------===// |
1806 | | |
1807 | 0 | bool MemRefCastOp::areCastCompatible(Type a, Type b) { |
1808 | 0 | auto aT = a.dyn_cast<MemRefType>(); |
1809 | 0 | auto bT = b.dyn_cast<MemRefType>(); |
1810 | 0 | 
1811 | 0 | auto uaT = a.dyn_cast<UnrankedMemRefType>(); |
1812 | 0 | auto ubT = b.dyn_cast<UnrankedMemRefType>(); |
1813 | 0 | 
1814 | 0 | if (aT && bT) { |
1815 | 0 | if (aT.getElementType() != bT.getElementType()) |
1816 | 0 | return false; |
1817 | 0 | if (aT.getAffineMaps() != bT.getAffineMaps()) { |
1818 | 0 | int64_t aOffset, bOffset; |
1819 | 0 | SmallVector<int64_t, 4> aStrides, bStrides; |
1820 | 0 | if (failed(getStridesAndOffset(aT, aStrides, aOffset)) || |
1821 | 0 | failed(getStridesAndOffset(bT, bStrides, bOffset)) || |
1822 | 0 | aStrides.size() != bStrides.size()) |
1823 | 0 | return false; |
1824 | 0 | |
1825 | 0 | // Strides along a dimension/offset are compatible if the value in the |
1826 | 0 | // source memref is static and the value in the target memref is the |
1827 | 0 | // same. They are also compatible if either one is dynamic (see |
1828 | 0 | // description of MemRefCastOp for details). |
1829 | 0 | auto checkCompatible = [](int64_t a, int64_t b) { |
1830 | 0 | return (a == MemRefType::getDynamicStrideOrOffset() || |
1831 | 0 | b == MemRefType::getDynamicStrideOrOffset() || a == b); |
1832 | 0 | }; |
1833 | 0 | if (!checkCompatible(aOffset, bOffset)) |
1834 | 0 | return false; |
1835 | 0 | for (auto aStride : enumerate(aStrides)) |
1836 | 0 | if (!checkCompatible(aStride.value(), bStrides[aStride.index()])) |
1837 | 0 | return false; |
1838 | 0 | } |
1839 | 0 | if (aT.getMemorySpace() != bT.getMemorySpace()) |
1840 | 0 | return false; |
1841 | 0 | |
1842 | 0 | // They must have the same rank, and any specified dimensions must match. |
1843 | 0 | if (aT.getRank() != bT.getRank()) |
1844 | 0 | return false; |
1845 | 0 | |
1846 | 0 | for (unsigned i = 0, e = aT.getRank(); i != e; ++i) { |
1847 | 0 | int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i); |
1848 | 0 | if (aDim != -1 && bDim != -1 && aDim != bDim) |
1849 | 0 | return false; |
1850 | 0 | } |
1851 | 0 | return true; |
1852 | 0 | } else { |
1853 | 0 | if (!aT && !uaT) |
1854 | 0 | return false; |
1855 | 0 | if (!bT && !ubT) |
1856 | 0 | return false; |
1857 | 0 | // Unranked to unranked casting is unsupported.
1858 | 0 | if (uaT && ubT) |
1859 | 0 | return false; |
1860 | 0 | |
1861 | 0 | auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType(); |
1862 | 0 | auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType(); |
1863 | 0 | if (aEltType != bEltType) |
1864 | 0 | return false; |
1865 | 0 | |
1866 | 0 | auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace(); |
1867 | 0 | auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace(); |
1868 | 0 | if (aMemSpace != bMemSpace) |
1869 | 0 | return false; |
1870 | 0 | |
1871 | 0 | return true; |
1872 | 0 | } |
1873 | 0 | |
1874 | 0 | return false; |
1875 | 0 | } |
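 | | 
 | | // Editorial illustration of the ranked rules above (identity layouts):
 | | //
 | | //   memref<4x?xf32> to memref<?x8xf32>   // compatible: dims agree wherever
 | | //                                        // both are static
 | | //   memref<4xf32> to memref<8xf32>       // incompatible: 4 != 8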
1876 | | |
1877 | 0 | OpFoldResult MemRefCastOp::fold(ArrayRef<Attribute> operands) { |
1878 | 0 | return impl::foldCastOp(*this); |
1879 | 0 | } |
1880 | | |
1881 | | //===----------------------------------------------------------------------===// |
1882 | | // MulFOp |
1883 | | //===----------------------------------------------------------------------===// |
1884 | | |
1885 | 0 | OpFoldResult MulFOp::fold(ArrayRef<Attribute> operands) { |
1886 | 0 | return constFoldBinaryOp<FloatAttr>( |
1887 | 0 | operands, [](APFloat a, APFloat b) { return a * b; }); |
1888 | 0 | } |
1889 | | |
1890 | | //===----------------------------------------------------------------------===// |
1891 | | // MulIOp |
1892 | | //===----------------------------------------------------------------------===// |
1893 | | |
1894 | 0 | OpFoldResult MulIOp::fold(ArrayRef<Attribute> operands) { |
1895 | 0 | /// muli(x, 0) -> 0 |
1896 | 0 | if (matchPattern(rhs(), m_Zero())) |
1897 | 0 | return rhs(); |
1898 | 0 | /// muli(x, 1) -> x |
1899 | 0 | if (matchPattern(rhs(), m_One())) |
1900 | 0 | return lhs();
1901 | 0 | |
1902 | 0 | // TODO: Handle the overflow case. |
1903 | 0 | return constFoldBinaryOp<IntegerAttr>(operands, |
1904 | 0 | [](APInt a, APInt b) { return a * b; }); |
1905 | 0 | } |
1906 | | |
1907 | | //===----------------------------------------------------------------------===// |
1908 | | // OrOp |
1909 | | //===----------------------------------------------------------------------===// |
1910 | | |
1911 | 0 | OpFoldResult OrOp::fold(ArrayRef<Attribute> operands) { |
1912 | 0 | /// or(x, 0) -> x |
1913 | 0 | if (matchPattern(rhs(), m_Zero())) |
1914 | 0 | return lhs(); |
1915 | 0 | /// or(x,x) -> x |
1916 | 0 | if (lhs() == rhs()) |
1917 | 0 | return rhs(); |
1918 | 0 | |
1919 | 0 | return constFoldBinaryOp<IntegerAttr>(operands, |
1920 | 0 | [](APInt a, APInt b) { return a | b; }); |
1921 | 0 | } |
1922 | | |
1923 | | //===----------------------------------------------------------------------===// |
1924 | | // PrefetchOp |
1925 | | //===----------------------------------------------------------------------===// |
1926 | | |
1927 | 0 | static void print(OpAsmPrinter &p, PrefetchOp op) { |
1928 | 0 | p << PrefetchOp::getOperationName() << " " << op.memref() << '['; |
1929 | 0 | p.printOperands(op.indices()); |
1930 | 0 | p << ']' << ", " << (op.isWrite() ? "write" : "read"); |
1931 | 0 | p << ", locality<" << op.localityHint(); |
1932 | 0 | p << ">, " << (op.isDataCache() ? "data" : "instr"); |
1933 | 0 | p.printOptionalAttrDict( |
1934 | 0 | op.getAttrs(), |
1935 | 0 | /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"}); |
1936 | 0 | p << " : " << op.getMemRefType(); |
1937 | 0 | } |
1938 | | |
1939 | | static ParseResult parsePrefetchOp(OpAsmParser &parser, |
1940 | 0 | OperationState &result) { |
1941 | 0 | OpAsmParser::OperandType memrefInfo; |
1942 | 0 | SmallVector<OpAsmParser::OperandType, 4> indexInfo; |
1943 | 0 | IntegerAttr localityHint; |
1944 | 0 | MemRefType type; |
1945 | 0 | StringRef readOrWrite, cacheType; |
1946 | 0 |
|
1947 | 0 | auto indexTy = parser.getBuilder().getIndexType(); |
1948 | 0 | auto i32Type = parser.getBuilder().getIntegerType(32); |
1949 | 0 | if (parser.parseOperand(memrefInfo) || |
1950 | 0 | parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) || |
1951 | 0 | parser.parseComma() || parser.parseKeyword(&readOrWrite) || |
1952 | 0 | parser.parseComma() || parser.parseKeyword("locality") || |
1953 | 0 | parser.parseLess() || |
1954 | 0 | parser.parseAttribute(localityHint, i32Type, "localityHint", |
1955 | 0 | result.attributes) || |
1956 | 0 | parser.parseGreater() || parser.parseComma() || |
1957 | 0 | parser.parseKeyword(&cacheType) || parser.parseColonType(type) || |
1958 | 0 | parser.resolveOperand(memrefInfo, type, result.operands) || |
1959 | 0 | parser.resolveOperands(indexInfo, indexTy, result.operands)) |
1960 | 0 | return failure(); |
1961 | 0 | |
1962 | 0 | if (!readOrWrite.equals("read") && !readOrWrite.equals("write")) |
1963 | 0 | return parser.emitError(parser.getNameLoc(), |
1964 | 0 | "rw specifier has to be 'read' or 'write'"); |
1965 | 0 | result.addAttribute( |
1966 | 0 | PrefetchOp::getIsWriteAttrName(), |
1967 | 0 | parser.getBuilder().getBoolAttr(readOrWrite.equals("write"))); |
1968 | 0 | 
1969 | 0 | if (!cacheType.equals("data") && !cacheType.equals("instr")) |
1970 | 0 | return parser.emitError(parser.getNameLoc(), |
1971 | 0 | "cache type has to be 'data' or 'instr'"); |
1972 | 0 | |
1973 | 0 | result.addAttribute( |
1974 | 0 | PrefetchOp::getIsDataCacheAttrName(), |
1975 | 0 | parser.getBuilder().getBoolAttr(cacheType.equals("data"))); |
1976 | 0 | 
1977 | 0 | return success(); |
1978 | 0 | } |
1979 | | |
1980 | 0 | static LogicalResult verify(PrefetchOp op) { |
1981 | 0 | if (op.getNumOperands() != 1 + op.getMemRefType().getRank()) |
1982 | 0 | return op.emitOpError("incorrect number of indices for prefetch");
1983 | 0 | |
1984 | 0 | return success(); |
1985 | 0 | } |
1986 | | |
1987 | | LogicalResult PrefetchOp::fold(ArrayRef<Attribute> cstOperands, |
1988 | 0 | SmallVectorImpl<OpFoldResult> &results) { |
1989 | 0 | // prefetch(memrefcast) -> prefetch |
1990 | 0 | return foldMemRefCast(*this); |
1991 | 0 | } |
1992 | | |
1993 | | //===----------------------------------------------------------------------===// |
1994 | | // RankOp |
1995 | | //===----------------------------------------------------------------------===// |
1996 | | |
1997 | 0 | OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) { |
1998 | 0 | // Constant fold rank when the rank of the tensor is known. |
1999 | 0 | auto type = getOperand().getType(); |
2000 | 0 | if (auto tensorType = type.dyn_cast<RankedTensorType>()) |
2001 | 0 | return IntegerAttr::get(IndexType::get(getContext()), tensorType.getRank()); |
2002 | 0 | return IntegerAttr(); |
2003 | 0 | } |
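 | | 
 | | // Editorial note: `rank %t : tensor<4x?xf32>` folds to the index constant
 | | // 2, while an unranked operand leaves the op in place.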
2004 | | |
2005 | | //===----------------------------------------------------------------------===// |
2006 | | // ReturnOp |
2007 | | //===----------------------------------------------------------------------===// |
2008 | | |
2009 | 0 | static LogicalResult verify(ReturnOp op) { |
2010 | 0 | auto function = cast<FuncOp>(op.getParentOp()); |
2011 | 0 | 
2012 | 0 | // The operand number and types must match the function signature. |
2013 | 0 | const auto &results = function.getType().getResults(); |
2014 | 0 | if (op.getNumOperands() != results.size()) |
2015 | 0 | return op.emitOpError("has ") |
2016 | 0 | << op.getNumOperands() |
2017 | 0 | << " operands, but enclosing function returns " << results.size(); |
2018 | 0 | |
2019 | 0 | for (unsigned i = 0, e = results.size(); i != e; ++i) |
2020 | 0 | if (op.getOperand(i).getType() != results[i]) |
2021 | 0 | return op.emitError() |
2022 | 0 | << "type of return operand " << i << " (" |
2023 | 0 | << op.getOperand(i).getType() |
2024 | 0 | << ") doesn't match function result type (" << results[i] << ")"; |
2025 | 0 | 
2026 | 0 | return success(); |
2027 | 0 | } |
2028 | | |
2029 | | //===----------------------------------------------------------------------===// |
2030 | | // SelectOp |
2031 | | //===----------------------------------------------------------------------===// |
2032 | | |
2033 | 0 | OpFoldResult SelectOp::fold(ArrayRef<Attribute> operands) { |
2034 | 0 | auto condition = getCondition(); |
2035 | 0 | 
2036 | 0 | // select true, %0, %1 => %0 |
2037 | 0 | if (matchPattern(condition, m_One())) |
2038 | 0 | return getTrueValue(); |
2039 | 0 | |
2040 | 0 | // select false, %0, %1 => %1 |
2041 | 0 | if (matchPattern(condition, m_Zero())) |
2042 | 0 | return getFalseValue(); |
2043 | 0 | return nullptr; |
2044 | 0 | } |
2045 | | |
2046 | 0 | static void print(OpAsmPrinter &p, SelectOp op) { |
2047 | 0 | p << "select " << op.getOperands(); |
2048 | 0 | p.printOptionalAttrDict(op.getAttrs()); |
2049 | 0 | p << " : "; |
2050 | 0 | if (ShapedType condType = op.getCondition().getType().dyn_cast<ShapedType>()) |
2051 | 0 | p << condType << ", "; |
2052 | 0 | p << op.getType(); |
2053 | 0 | } |
2054 | | |
2055 | 0 | static ParseResult parseSelectOp(OpAsmParser &parser, OperationState &result) { |
2056 | 0 | Type conditionType, resultType; |
2057 | 0 | SmallVector<OpAsmParser::OperandType, 3> operands; |
2058 | 0 | if (parser.parseOperandList(operands, /*requiredOperandCount=*/3) || |
2059 | 0 | parser.parseOptionalAttrDict(result.attributes) || |
2060 | 0 | parser.parseColonType(resultType)) |
2061 | 0 | return failure(); |
2062 | 0 | |
2063 | 0 | // Check for the explicit condition type if this is a masked tensor or vector. |
2064 | 0 | if (succeeded(parser.parseOptionalComma())) { |
2065 | 0 | conditionType = resultType; |
2066 | 0 | if (parser.parseType(resultType)) |
2067 | 0 | return failure(); |
2068 | 0 | } else { |
2069 | 0 | conditionType = parser.getBuilder().getI1Type(); |
2070 | 0 | } |
2071 | 0 | 
2072 | 0 | result.addTypes(resultType); |
2073 | 0 | return parser.resolveOperands(operands, |
2074 | 0 | {conditionType, resultType, resultType}, |
2075 | 0 | parser.getNameLoc(), result.operands); |
2076 | 0 | } |
2077 | | |
2078 | 0 | static LogicalResult verify(SelectOp op) { |
2079 | 0 | Type conditionType = op.getCondition().getType(); |
2080 | 0 | if (conditionType.isSignlessInteger(1)) |
2081 | 0 | return success(); |
2082 | 0 | |
2083 | 0 | // If the result type is a vector or tensor, the type can be a mask with the |
2084 | 0 | // same elements. |
2085 | 0 | Type resultType = op.getType(); |
2086 | 0 | if (!resultType.isa<TensorType>() && !resultType.isa<VectorType>()) |
2087 | 0 | return op.emitOpError() |
2088 | 0 | << "expected condition to be a signless i1, but got " |
2089 | 0 | << conditionType; |
2090 | 0 | Type shapedConditionType = getI1SameShape(resultType); |
2091 | 0 | if (conditionType != shapedConditionType) |
2092 | 0 | return op.emitOpError() |
2093 | 0 | << "expected condition type to have the same shape " |
2094 | 0 | "as the result type, expected " |
2095 | 0 | << shapedConditionType << ", but got " << conditionType; |
2096 | 0 | return success(); |
2097 | 0 | } |
2098 | | |
2099 | | //===----------------------------------------------------------------------===// |
2100 | | // SignExtendIOp |
2101 | | //===----------------------------------------------------------------------===// |
2102 | | |
2103 | 0 | static LogicalResult verify(SignExtendIOp op) { |
2104 | 0 | // Get the scalar type (which is either directly the type of the operand |
2105 | 0 | // or the vector's/tensor's element type. |
2106 | 0 | auto srcType = getElementTypeOrSelf(op.getOperand().getType()); |
2107 | 0 | auto dstType = getElementTypeOrSelf(op.getType()); |
2108 | 0 | 
2109 | 0 | // For now, index is forbidden for the source and the destination type. |
2110 | 0 | if (srcType.isa<IndexType>()) |
2111 | 0 | return op.emitError() << srcType << " is not a valid operand type"; |
2112 | 0 | if (dstType.isa<IndexType>()) |
2113 | 0 | return op.emitError() << dstType << " is not a valid result type"; |
2114 | 0 | |
2115 | 0 | if (srcType.cast<IntegerType>().getWidth() >= |
2116 | 0 | dstType.cast<IntegerType>().getWidth()) |
2117 | 0 | return op.emitError("result type ") |
2118 | 0 | << dstType << " must be wider than operand type " << srcType; |
2119 | 0 | |
2120 | 0 | return success(); |
2121 | 0 | } |
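 | | 
 | | // Editorial examples: `sexti %x : i8 to i32` satisfies the width check
 | | // above; `sexti %x : i32 to i32` and any index operand are rejected.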
2122 | | |
2123 | | //===----------------------------------------------------------------------===// |
2124 | | // SignedDivIOp |
2125 | | //===----------------------------------------------------------------------===// |
2126 | | |
2127 | 0 | OpFoldResult SignedDivIOp::fold(ArrayRef<Attribute> operands) { |
2128 | 0 | assert(operands.size() == 2 && "binary operation takes two operands"); |
2129 | 0 |
|
2130 | 0 | // Don't fold if it would overflow or if it requires a division by zero. |
2131 | 0 | bool overflowOrDiv0 = false; |
2132 | 0 | auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) { |
2133 | 0 | if (overflowOrDiv0 || !b) { |
2134 | 0 | overflowOrDiv0 = true; |
2135 | 0 | return a; |
2136 | 0 | } |
2137 | 0 | return a.sdiv_ov(b, overflowOrDiv0); |
2138 | 0 | }); |
2139 | 0 | 
2140 | 0 | // Fold out division by one. Assumes all tensors of all ones are splats. |
2141 | 0 | if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) { |
2142 | 0 | if (rhs.getValue() == 1) |
2143 | 0 | return lhs(); |
2144 | 0 | } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) { |
2145 | 0 | if (rhs.getSplatValue<IntegerAttr>().getValue() == 1) |
2146 | 0 | return lhs(); |
2147 | 0 | } |
2148 | 0 | |
2149 | 0 | return overflowOrDiv0 ? Attribute() : result; |
2150 | 0 | } |
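 | | 
 | | // Editorial note on the overflow guard above: e.g. for i8, (-128) / (-1)
 | | // would be +128, which is unrepresentable; APInt::sdiv_ov reports this and
 | | // the fold is abandoned instead of producing a wrong constant.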
2151 | | |
2152 | | //===----------------------------------------------------------------------===// |
2153 | | // SignedRemIOp |
2154 | | //===----------------------------------------------------------------------===// |
2155 | | |
2156 | 0 | OpFoldResult SignedRemIOp::fold(ArrayRef<Attribute> operands) { |
2157 | 0 | assert(operands.size() == 2 && "remi_signed takes two operands"); |
2158 | 0 | 
2159 | 0 | auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>(); |
2160 | 0 | if (!rhs) |
2161 | 0 | return {}; |
2162 | 0 | auto rhsValue = rhs.getValue(); |
2163 | 0 | 
2164 | 0 | // x % 1 = 0 |
2165 | 0 | if (rhsValue.isOneValue()) |
2166 | 0 | return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0)); |
2167 | 0 | |
2168 | 0 | // Don't fold if it requires division by zero. |
2169 | 0 | if (rhsValue.isNullValue()) |
2170 | 0 | return {}; |
2171 | 0 | |
2172 | 0 | auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>(); |
2173 | 0 | if (!lhs) |
2174 | 0 | return {}; |
2175 | 0 | return IntegerAttr::get(lhs.getType(), lhs.getValue().srem(rhsValue)); |
2176 | 0 | } |
2177 | | |
2178 | | //===----------------------------------------------------------------------===// |
2179 | | // SIToFPOp |
2180 | | //===----------------------------------------------------------------------===// |
2181 | | |
2182 | | // sitofp is applicable from integer types to float types. |
2183 | 0 | bool SIToFPOp::areCastCompatible(Type a, Type b) { |
2184 | 0 | return a.isSignlessInteger() && b.isa<FloatType>(); |
2185 | 0 | } |
2186 | | |
2187 | | //===----------------------------------------------------------------------===// |
2188 | | // SplatOp |
2189 | | //===----------------------------------------------------------------------===// |
2190 | | |
2191 | 0 | static LogicalResult verify(SplatOp op) { |
2192 | 0 | // TODO: we could replace this by a trait. |
2193 | 0 | if (op.getOperand().getType() != |
2194 | 0 | op.getType().cast<ShapedType>().getElementType()) |
2195 | 0 | return op.emitError("operand should be of elemental type of result type"); |
2196 | 0 | |
2197 | 0 | return success(); |
2198 | 0 | } |
2199 | | |
2200 | | // Constant folding hook for SplatOp. |
2201 | 0 | OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) { |
2202 | 0 | assert(operands.size() == 1 && "splat takes one operand"); |
2203 | 0 | 
2204 | 0 | auto constOperand = operands.front(); |
2205 | 0 | if (!constOperand || |
2206 | 0 | (!constOperand.isa<IntegerAttr>() && !constOperand.isa<FloatAttr>())) |
2207 | 0 | return {}; |
2208 | 0 | |
2209 | 0 | auto shapedType = getType().cast<ShapedType>(); |
2210 | 0 | assert(shapedType.getElementType() == constOperand.getType() && |
2211 | 0 | "incorrect input attribute type for folding"); |
2212 | 0 | 
2213 | 0 | // SplatElementsAttr::get treats a single value for the second arg as a splat.
2214 | 0 | return SplatElementsAttr::get(shapedType, {constOperand}); |
2215 | 0 | } |
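 | | 
 | | // Editorial illustration (hypothetical operand):
 | | //
 | | //   %c = constant 1.0 : f32
 | | //   %v = splat %c : vector<4xf32>   // folds to dense<1.0> : vector<4xf32>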
2216 | | |
2217 | | //===----------------------------------------------------------------------===// |
2218 | | // StoreOp |
2219 | | //===----------------------------------------------------------------------===// |
2220 | | |
2221 | 0 | static LogicalResult verify(StoreOp op) { |
2222 | 0 | if (op.getNumOperands() != 2 + op.getMemRefType().getRank()) |
2223 | 0 | return op.emitOpError("store index operand count not equal to memref rank"); |
2224 | 0 | |
2225 | 0 | return success(); |
2226 | 0 | } |
2227 | | |
2228 | | LogicalResult StoreOp::fold(ArrayRef<Attribute> cstOperands, |
2229 | 0 | SmallVectorImpl<OpFoldResult> &results) { |
2230 | 0 | /// store(memrefcast) -> store |
2231 | 0 | return foldMemRefCast(*this); |
2232 | 0 | } |
2233 | | |
2234 | | //===----------------------------------------------------------------------===// |
2235 | | // SubFOp |
2236 | | //===----------------------------------------------------------------------===// |
2237 | | |
2238 | 0 | OpFoldResult SubFOp::fold(ArrayRef<Attribute> operands) { |
2239 | 0 | return constFoldBinaryOp<FloatAttr>( |
2240 | 0 | operands, [](APFloat a, APFloat b) { return a - b; }); |
2241 | 0 | } |
2242 | | |
2243 | | //===----------------------------------------------------------------------===// |
2244 | | // SubIOp |
2245 | | //===----------------------------------------------------------------------===// |
2246 | | |
2247 | 0 | OpFoldResult SubIOp::fold(ArrayRef<Attribute> operands) { |
2248 | 0 | // subi(x,x) -> 0 |
2249 | 0 | if (getOperand(0) == getOperand(1)) |
2250 | 0 | return Builder(getContext()).getZeroAttr(getType()); |
2251 | 0 | |
2252 | 0 | return constFoldBinaryOp<IntegerAttr>(operands, |
2253 | 0 | [](APInt a, APInt b) { return a - b; }); |
2254 | 0 | } |
2255 | | |
2256 | | //===----------------------------------------------------------------------===// |
2257 | | // SubViewOp |
2258 | | //===----------------------------------------------------------------------===// |
2259 | | |
2260 | | /// Print a list with either (1) the static integer value in `arrayAttr` if |
2261 | | /// `isDynamic` evaluates to false or (2) the next value otherwise. |
2262 | | /// This allows idiomatic printing of mixed value and integer attributes in a |
2263 | | /// list. E.g. `[%arg0, 7, 42, %arg42]`. |
2264 | | static void printSubViewListOfOperandsOrIntegers( |
2265 | | OpAsmPrinter &p, ValueRange values, ArrayAttr arrayAttr, |
2266 | 0 | llvm::function_ref<bool(int64_t)> isDynamic) { |
2267 | 0 | p << "["; |
2268 | 0 | unsigned idx = 0; |
2269 | 0 | llvm::interleaveComma(arrayAttr, p, [&](Attribute a) { |
2270 | 0 | int64_t val = a.cast<IntegerAttr>().getInt(); |
2271 | 0 | if (isDynamic(val)) |
2272 | 0 | p << values[idx++]; |
2273 | 0 | else |
2274 | 0 | p << val; |
2275 | 0 | }); |
2276 | 0 | p << "] "; |
2277 | 0 | } |
2278 | | |
2279 | | /// Parse a mixed list with either (1) static integer values or (2) SSA values. |
2280 | | /// Fill `result` with the integer ArrayAttr named `attrName` where `dynVal` |
2281 | | /// encode the position of SSA values. Add the parsed SSA values to `ssa` |
2282 | | /// in-order. |
2283 | | ///
2284 | | /// E.g. after parsing "[%arg0, 7, 42, %arg42]": |
2285 | | /// 1. `result` is filled with the i64 ArrayAttr "[`dynVal`, 7, 42, `dynVal`]" |
2286 | | /// 2. `ssa` is filled with "[%arg0, %arg1]". |
2287 | | static ParseResult |
2288 | | parseListOfOperandsOrIntegers(OpAsmParser &parser, OperationState &result, |
2289 | | StringRef attrName, int64_t dynVal, |
2290 | 0 | SmallVectorImpl<OpAsmParser::OperandType> &ssa) { |
2291 | 0 | if (failed(parser.parseLSquare())) |
2292 | 0 | return failure(); |
2293 | 0 | // 0-D. |
2294 | 0 | if (succeeded(parser.parseOptionalRSquare())) |
2295 | 0 | return success(); |
2296 | 0 | |
2297 | 0 | SmallVector<int64_t, 4> attrVals; |
2298 | 0 | while (true) { |
2299 | 0 | OpAsmParser::OperandType operand; |
2300 | 0 | auto res = parser.parseOptionalOperand(operand); |
2301 | 0 | if (res.hasValue() && succeeded(res.getValue())) { |
2302 | 0 | ssa.push_back(operand); |
2303 | 0 | attrVals.push_back(dynVal); |
2304 | 0 | } else { |
2305 | 0 | Attribute attr; |
2306 | 0 | NamedAttrList placeholder; |
2307 | 0 | if (failed(parser.parseAttribute(attr, "_", placeholder)) || |
2308 | 0 | !attr.isa<IntegerAttr>()) |
2309 | 0 | return parser.emitError(parser.getNameLoc()) |
2310 | 0 | << "expected SSA value or integer"; |
2311 | 0 | attrVals.push_back(attr.cast<IntegerAttr>().getInt()); |
2312 | 0 | } |
2313 | 0 |
|
2314 | 0 | if (succeeded(parser.parseOptionalComma())) |
2315 | 0 | continue; |
2316 | 0 | if (failed(parser.parseRSquare()))
2317 | 0 | return failure();
2318 | 0 | break;
2320 | 0 | } |
2321 | 0 | 
2322 | 0 | auto arrayAttr = parser.getBuilder().getI64ArrayAttr(attrVals); |
2323 | 0 | result.addAttribute(attrName, arrayAttr); |
2324 | 0 | return success(); |
2325 | 0 | } |
2326 | | |
2327 | | namespace { |
2328 | | /// Helpers to write more idiomatic operations. |
2329 | | namespace saturated_arith { |
2330 | | struct Wrapper { |
2331 | 0 | explicit Wrapper(int64_t v) : v(v) {} |
2332 | 0 | operator int64_t() { return v; } |
2333 | | int64_t v; |
2334 | | }; |
2335 | 0 | Wrapper operator+(Wrapper a, int64_t b) { |
2336 | 0 | if (ShapedType::isDynamicStrideOrOffset(a) || |
2337 | 0 | ShapedType::isDynamicStrideOrOffset(b)) |
2338 | 0 | return Wrapper(ShapedType::kDynamicStrideOrOffset); |
2339 | 0 | return Wrapper(a.v + b); |
2340 | 0 | } |
2341 | 0 | Wrapper operator*(Wrapper a, int64_t b) { |
2342 | 0 | if (ShapedType::isDynamicStrideOrOffset(a) || |
2343 | 0 | ShapedType::isDynamicStrideOrOffset(b)) |
2344 | 0 | return Wrapper(ShapedType::kDynamicStrideOrOffset); |
2345 | 0 | return Wrapper(a.v * b); |
2346 | 0 | } |
2347 | | } // end namespace saturated_arith |
2348 | | } // end namespace |
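 | | 
 | | // Editorial examples of the saturating semantics above: any dynamic input
 | | // keeps the result dynamic, e.g.
 | | //   Wrapper(kDynamicStrideOrOffset) + 4 == kDynamicStrideOrOffset
 | | //   Wrapper(8) * kDynamicStrideOrOffset == kDynamicStrideOrOffset
 | | //   Wrapper(8) + 4 == 12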
2349 | | |
2350 | | /// A subview result type can be fully inferred from the source type and the |
2351 | | /// static representation of offsets, sizes and strides. Special sentinels |
2352 | | /// encode the dynamic case. |
2353 | | Type SubViewOp::inferSubViewResultType(MemRefType sourceMemRefType, |
2354 | | ArrayRef<int64_t> staticOffsets, |
2355 | | ArrayRef<int64_t> staticSizes, |
2356 | | ArrayRef<int64_t> staticStrides) { |
2357 | | unsigned rank = sourceMemRefType.getRank(); |
2358 | | (void)rank; |
2359 | | assert(staticOffsets.size() == rank && |
2360 | | "unexpected staticOffsets size mismatch"); |
2361 | | assert(staticSizes.size() == rank && "unexpected staticSizes size mismatch"); |
2362 | | assert(staticStrides.size() == rank && |
2363 | | "unexpected staticStrides size mismatch"); |
2364 | | |
2365 | | // Extract source offset and strides. |
2366 | | int64_t sourceOffset; |
2367 | | SmallVector<int64_t, 4> sourceStrides; |
2368 | | auto res = getStridesAndOffset(sourceMemRefType, sourceStrides, sourceOffset); |
2369 | | assert(succeeded(res) && "SubViewOp expected strided memref type"); |
2370 | | (void)res; |
2371 | | |
2372 | | // Compute target offset whose value is: |
2373 | | // `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`. |
2374 | | int64_t targetOffset = sourceOffset; |
2375 | | for (auto it : llvm::zip(staticOffsets, sourceStrides)) { |
2376 | | auto staticOffset = std::get<0>(it), sourceStride = std::get<1>(it);
2377 | | using namespace saturated_arith;
2378 | | targetOffset = Wrapper(targetOffset) + Wrapper(staticOffset) * sourceStride;
2379 | | } |
2380 | | |
2381 | | // Compute target stride whose value is: |
2382 | | // `sourceStrides_i * staticStrides_i`. |
2383 | | SmallVector<int64_t, 4> targetStrides; |
2384 | | targetStrides.reserve(staticOffsets.size()); |
2385 | | for (auto it : llvm::zip(sourceStrides, staticStrides)) { |
2386 | | auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it); |
2387 | | using namespace saturated_arith; |
2388 | | targetStrides.push_back(Wrapper(sourceStride) * staticStride); |
2389 | | } |
2390 | | |
2391 | | // The type is now known. |
2392 | | return MemRefType::get( |
2393 | | staticSizes, sourceMemRefType.getElementType(), |
2394 | | makeStridedLinearLayoutMap(targetStrides, targetOffset, |
2395 | | sourceMemRefType.getContext()), |
2396 | | sourceMemRefType.getMemorySpace()); |
2397 | | } |
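 | | 
 | | // Editorial worked example: for a source memref<8x16xf32> with strides
 | | // [16, 1] and offset 0, static offsets [2, 3], sizes [4, 4] and strides
 | | // [1, 2] give
 | | //   targetOffset  = 0 + 2 * 16 + 3 * 1 = 35
 | | //   targetStrides = [16 * 1, 1 * 2] = [16, 2]
 | | // i.e. a 4x4 result whose strided layout is d0 * 16 + d1 * 2 + 35.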
2398 | | |
2399 | | /// Print SubViewOp in the form: |
2400 | | /// ``` |
2401 | | /// subview ssa-name `[` offset-list `]` `[` size-list `]` `[` stride-list `]` |
2402 | | /// `:` strided-memref-type `to` strided-memref-type |
2403 | | /// ``` |
2404 | 0 | static void print(OpAsmPrinter &p, SubViewOp op) { |
2405 | 0 | int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1; |
2406 | 0 | p << op.getOperation()->getName().getStringRef().drop_front(stdDotLen) << ' '; |
2407 | 0 | p << op.getOperand(0); |
2408 | 0 | printSubViewListOfOperandsOrIntegers(p, op.offsets(), op.static_offsets(), |
2409 | 0 | ShapedType::isDynamicStrideOrOffset); |
2410 | 0 | printSubViewListOfOperandsOrIntegers(p, op.sizes(), op.static_sizes(), |
2411 | 0 | ShapedType::isDynamic); |
2412 | 0 | printSubViewListOfOperandsOrIntegers(p, op.strides(), op.static_strides(), |
2413 | 0 | ShapedType::isDynamicStrideOrOffset); |
2414 | 0 | p.printOptionalAttrDict(op.getAttrs(), |
2415 | 0 | /*elidedAttrs=*/{SubViewOp::getSpecialAttrNames()}); |
2416 | 0 | p << " : " << op.getOperand(0).getType() << " to " << op.getType(); |
2417 | 0 | } |
2418 | | |
2419 | | /// Parse SubViewOp of the form: |
2420 | | /// ``` |
2421 | | /// subview ssa-name `[` offset-list `]` `[` size-list `]` `[` stride-list `]` |
2422 | | /// `:` strided-memref-type `to` strided-memref-type |
2423 | | /// ``` |
2424 | 0 | static ParseResult parseSubViewOp(OpAsmParser &parser, OperationState &result) { |
2425 | 0 | OpAsmParser::OperandType srcInfo; |
2426 | 0 | SmallVector<OpAsmParser::OperandType, 4> offsetsInfo, sizesInfo, stridesInfo; |
2427 | 0 | auto indexType = parser.getBuilder().getIndexType(); |
2428 | 0 | Type srcType, dstType; |
2429 | 0 | if (parser.parseOperand(srcInfo)) |
2430 | 0 | return failure(); |
2431 | 0 | if (parseListOfOperandsOrIntegers( |
2432 | 0 | parser, result, SubViewOp::getStaticOffsetsAttrName(), |
2433 | 0 | ShapedType::kDynamicStrideOrOffset, offsetsInfo) || |
2434 | 0 | parseListOfOperandsOrIntegers(parser, result, |
2435 | 0 | SubViewOp::getStaticSizesAttrName(), |
2436 | 0 | ShapedType::kDynamicSize, sizesInfo) || |
2437 | 0 | parseListOfOperandsOrIntegers( |
2438 | 0 | parser, result, SubViewOp::getStaticStridesAttrName(), |
2439 | 0 | ShapedType::kDynamicStrideOrOffset, stridesInfo)) |
2440 | 0 | return failure(); |
2441 | 0 | |
2442 | 0 | auto b = parser.getBuilder(); |
2443 | 0 | SmallVector<int, 4> segmentSizes{1, static_cast<int>(offsetsInfo.size()), |
2444 | 0 | static_cast<int>(sizesInfo.size()), |
2445 | 0 | static_cast<int>(stridesInfo.size())}; |
2446 | 0 | result.addAttribute(SubViewOp::getOperandSegmentSizeAttr(), |
2447 | 0 | b.getI32VectorAttr(segmentSizes)); |
2448 | 0 |
2449 | 0 | return failure( |
2450 | 0 | parser.parseOptionalAttrDict(result.attributes) || |
2451 | 0 | parser.parseColonType(srcType) || |
2452 | 0 | parser.resolveOperand(srcInfo, srcType, result.operands) || |
2453 | 0 | parser.resolveOperands(offsetsInfo, indexType, result.operands) || |
2454 | 0 | parser.resolveOperands(sizesInfo, indexType, result.operands) || |
2455 | 0 | parser.resolveOperands(stridesInfo, indexType, result.operands) || |
2456 | 0 | parser.parseKeywordType("to", dstType) || |
2457 | 0 | parser.addTypeToList(dstType, result.types)); |
2458 | 0 | } |
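Concretely, each bracketed list mixes bare integers (the static entries) with SSA operands (the dynamic entries). An illustrative form accepted by this parser and emitted by the printer above (SSA names are hypothetical):

```mlir
%1 = subview %0[%off0, 0][4, %sz1][1, 1]
  : memref<?x?xf32> to memref<4x?xf32, offset: ?, strides: [?, 1]>
```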
2459 | | |
2460 | | void mlir::SubViewOp::build(OpBuilder &b, OperationState &result, Value source, |
2461 | | ArrayRef<int64_t> staticOffsets, |
2462 | | ArrayRef<int64_t> staticSizes, |
2463 | | ArrayRef<int64_t> staticStrides, ValueRange offsets, |
2464 | | ValueRange sizes, ValueRange strides, |
2465 | 0 | ArrayRef<NamedAttribute> attrs) { |
2466 | 0 | auto sourceMemRefType = source.getType().cast<MemRefType>(); |
2467 | 0 | auto resultType = inferSubViewResultType(sourceMemRefType, staticOffsets, |
2468 | 0 | staticSizes, staticStrides); |
2469 | 0 | build(b, result, resultType, source, offsets, sizes, strides, |
2470 | 0 | b.getI64ArrayAttr(staticOffsets), b.getI64ArrayAttr(staticSizes), |
2471 | 0 | b.getI64ArrayAttr(staticStrides)); |
2472 | 0 | result.addAttributes(attrs); |
2473 | 0 | } |
2474 | | |
2475 | | /// Build a SubViewOp with all dynamic entries: `staticOffsets`, `staticSizes` |
2476 | | /// and `staticStrides` are automatically filled with rank-many sentinel
2477 | | /// values that encode dynamic entries.
2478 | | void mlir::SubViewOp::build(OpBuilder &b, OperationState &result, Value source, |
2479 | | ValueRange offsets, ValueRange sizes, |
2480 | | ValueRange strides, |
2481 | 0 | ArrayRef<NamedAttribute> attrs) { |
2482 | 0 | auto sourceMemRefType = source.getType().cast<MemRefType>(); |
2483 | 0 | unsigned rank = sourceMemRefType.getRank(); |
2484 | 0 | SmallVector<int64_t, 4> staticOffsetsVector; |
2485 | 0 | staticOffsetsVector.assign(rank, ShapedType::kDynamicStrideOrOffset); |
2486 | 0 | SmallVector<int64_t, 4> staticSizesVector; |
2487 | 0 | staticSizesVector.assign(rank, ShapedType::kDynamicSize); |
2488 | 0 | SmallVector<int64_t, 4> staticStridesVector; |
2489 | 0 | staticStridesVector.assign(rank, ShapedType::kDynamicStrideOrOffset); |
2490 | 0 | build(b, result, source, staticOffsetsVector, staticSizesVector, |
2491 | 0 | staticStridesVector, offsets, sizes, strides, attrs); |
2492 | 0 | } |
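A subview built through this overload carries every offset, size, and stride as an operand, so the inferred result type is fully dynamic. An illustrative sketch of the resulting form (names are hypothetical):

```mlir
%1 = subview %0[%o0, %o1][%s0, %s1][%st0, %st1]
  : memref<?x?xf32> to memref<?x?xf32, offset: ?, strides: [?, ?]>
```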
2493 | | |
2494 | | /// Verify that a particular offset/size/stride static attribute is well-formed. |
2495 | | static LogicalResult |
2496 | | verifySubViewOpPart(SubViewOp op, StringRef name, StringRef attrName, |
2497 | | ArrayAttr attr, llvm::function_ref<bool(int64_t)> isDynamic, |
2498 | 0 | ValueRange values) { |
2499 | 0 | // Check static and dynamic offsets/sizes/strides breakdown.
2500 | 0 | if (attr.size() != op.getRank()) |
2501 | 0 | return op.emitError("expected ") |
2502 | 0 | << op.getRank() << " " << name << " values"; |
2503 | 0 | unsigned expectedNumDynamicEntries = |
2504 | 0 | llvm::count_if(attr.getValue(), [&](Attribute attr) { |
2505 | 0 | return isDynamic(attr.cast<IntegerAttr>().getInt()); |
2506 | 0 | }); |
2507 | 0 | if (values.size() != expectedNumDynamicEntries) |
2508 | 0 | return op.emitError("expected ") |
2509 | 0 | << expectedNumDynamicEntries << " dynamic " << name << " values"; |
2510 | 0 | return success(); |
2511 | 0 | } |
2512 | | |
2513 | | /// Helper that extracts int64_t values from an ArrayAttr assumed to contain IntegerAttrs.
2514 | 0 | static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) { |
2515 | 0 | return llvm::to_vector<4>( |
2516 | 0 | llvm::map_range(attr.cast<ArrayAttr>(), [](Attribute a) -> int64_t { |
2517 | 0 | return a.cast<IntegerAttr>().getInt(); |
2518 | 0 | })); |
2519 | 0 | } |
2520 | | |
2521 | | /// Verifier for SubViewOp. |
2522 | 0 | static LogicalResult verify(SubViewOp op) { |
2523 | 0 | auto baseType = op.getBaseMemRefType().cast<MemRefType>(); |
2524 | 0 | auto subViewType = op.getType(); |
2525 | 0 |
2526 | 0 | // The base memref and the view memref should be in the same memory space. |
2527 | 0 | if (baseType.getMemorySpace() != subViewType.getMemorySpace()) |
2528 | 0 | return op.emitError("different memory spaces specified for base memref " |
2529 | 0 | "type ") |
2530 | 0 | << baseType << " and subview memref type " << subViewType; |
2531 | 0 | |
2532 | 0 | // Verify that the base memref type has a strided layout map. |
2533 | 0 | if (!isStrided(baseType)) |
2534 | 0 | return op.emitError("base type ") << baseType << " is not strided"; |
2535 | 0 | |
2536 | 0 | // Verify static attributes offsets/sizes/strides. |
2537 | 0 | if (failed(verifySubViewOpPart( |
2538 | 0 | op, "offset", op.getStaticOffsetsAttrName(), op.static_offsets(), |
2539 | 0 | ShapedType::isDynamicStrideOrOffset, op.offsets()))) |
2540 | 0 | return failure(); |
2541 | 0 | |
2542 | 0 | if (failed(verifySubViewOpPart(op, "size", op.getStaticSizesAttrName(), |
2543 | 0 | op.static_sizes(), ShapedType::isDynamic, |
2544 | 0 | op.sizes()))) |
2545 | 0 | return failure(); |
2546 | 0 | if (failed(verifySubViewOpPart( |
2547 | 0 | op, "stride", op.getStaticStridesAttrName(), op.static_strides(), |
2548 | 0 | ShapedType::isDynamicStrideOrOffset, op.strides()))) |
2549 | 0 | return failure(); |
2550 | 0 | |
2551 | 0 | // Verify result type against inferred type. |
2552 | 0 | auto expectedType = SubViewOp::inferSubViewResultType( |
2553 | 0 | op.getBaseMemRefType(), extractFromI64ArrayAttr(op.static_offsets()), |
2554 | 0 | extractFromI64ArrayAttr(op.static_sizes()), |
2555 | 0 | extractFromI64ArrayAttr(op.static_strides())); |
2556 | 0 | if (op.getType() != expectedType) |
2557 | 0 | return op.emitError("expected result type to be ") << expectedType; |
2558 | 0 | |
2559 | 0 | return success(); |
2560 | 0 | } |
2561 | | |
2562 | 0 | raw_ostream &mlir::operator<<(raw_ostream &os, SubViewOp::Range &range) { |
2563 | 0 | return os << "range " << range.offset << ":" << range.size << ":" |
2564 | 0 | << range.stride; |
2565 | 0 | } |
2566 | | |
2567 | | static unsigned getNumDynamicEntriesUpToIdx( |
2568 | 0 | ArrayAttr attr, llvm::function_ref<bool(int64_t)> isDynamic, unsigned idx) { |
2569 | 0 | return std::count_if(attr.getValue().begin(), attr.getValue().begin() + idx, |
2570 | 0 | [&](Attribute attr) { |
2571 | 0 | return isDynamic(attr.cast<IntegerAttr>().getInt()); |
2572 | 0 | }); |
2573 | 0 | } |
2574 | | |
2575 | 0 | bool SubViewOp::isDynamicOffset(unsigned idx) { |
2576 | 0 | return ShapedType::isDynamicStrideOrOffset( |
2577 | 0 | extractFromI64ArrayAttr(static_offsets())[idx]); |
2578 | 0 | } |
2579 | 0 | bool SubViewOp::isDynamicSize(unsigned idx) { |
2580 | 0 | return ShapedType::isDynamic(extractFromI64ArrayAttr(static_sizes())[idx]); |
2581 | 0 | } |
2582 | 0 | bool SubViewOp::isDynamicStride(unsigned idx) { |
2583 | 0 | return ShapedType::isDynamicStrideOrOffset( |
2584 | 0 | extractFromI64ArrayAttr(static_strides())[idx]); |
2585 | 0 | } |
2586 | | |
2587 | 0 | unsigned SubViewOp::getIndexOfDynamicOffset(unsigned idx) { |
2588 | 0 | assert(isDynamicOffset(idx) && "expected dynamic offset");
2589 | 0 | auto numDynamic = |
2590 | 0 | getNumDynamicEntriesUpToIdx(static_offsets().cast<ArrayAttr>(), |
2591 | 0 | ShapedType::isDynamicStrideOrOffset, idx); |
2592 | 0 | return 1 + numDynamic; |
2593 | 0 | } |
2594 | 0 | unsigned SubViewOp::getIndexOfDynamicSize(unsigned idx) { |
2595 | 0 | assert(isDynamicSize(idx) && "expected dynamic size");
2596 | 0 | auto numDynamic = getNumDynamicEntriesUpToIdx( |
2597 | 0 | static_sizes().cast<ArrayAttr>(), ShapedType::isDynamic, idx); |
2598 | 0 | return 1 + offsets().size() + numDynamic; |
2599 | 0 | } |
2600 | 0 | unsigned SubViewOp::getIndexOfDynamicStride(unsigned idx) { |
2601 | 0 | assert(isDynamicStride(idx) && "expected dynamic stride");
2602 | 0 | auto numDynamic = |
2603 | 0 | getNumDynamicEntriesUpToIdx(static_strides().cast<ArrayAttr>(), |
2604 | 0 | ShapedType::isDynamicStrideOrOffset, idx); |
2605 | 0 | return 1 + offsets().size() + sizes().size() + numDynamic; |
2606 | 0 | } |
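These index computations mirror the operand layout produced by the parser: operand 0 is the source memref, followed by the dynamic offsets, then the dynamic sizes, then the dynamic strides. As an illustrative example, for a rank-2 subview with static_offsets = [0, ?], static_sizes = [?, 4], and static_strides = [1, 1], the operand list is [%src, %offset1, %size0]: getIndexOfDynamicOffset(1) returns 1 + 0 = 1, and getIndexOfDynamicSize(0) returns 1 + 1 + 0 = 2.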
2607 | | |
2608 | | /// Return the list of SubViewOp::Range (i.e. offset, size, stride). Each Range |
2609 | | /// entry contains either the dynamic value or a ConstantIndexOp constructed |
2610 | | /// with `b` at location `loc`. |
2611 | | SmallVector<SubViewOp::Range, 8> SubViewOp::getOrCreateRanges(OpBuilder &b, |
2612 | 0 | Location loc) { |
2613 | 0 | SmallVector<Range, 8> res; |
2614 | 0 | unsigned rank = getType().getRank(); |
2615 | 0 | res.reserve(rank); |
2616 | 0 | for (unsigned idx = 0; idx < rank; ++idx) { |
2617 | 0 | auto offset = isDynamicOffset(idx) |
2618 | 0 | ? getDynamicOffset(idx) |
2619 | 0 | : b.create<ConstantIndexOp>(loc, getStaticOffset(idx)); |
2620 | 0 | auto size = isDynamicSize(idx) |
2621 | 0 | ? getDynamicSize(idx) |
2622 | 0 | : b.create<ConstantIndexOp>(loc, getStaticSize(idx)); |
2623 | 0 | auto stride = isDynamicStride(idx) |
2624 | 0 | ? getDynamicStride(idx) |
2625 | 0 | : b.create<ConstantIndexOp>(loc, getStaticStride(idx)); |
2626 | 0 | res.emplace_back(Range{offset, size, stride}); |
2627 | 0 | } |
2628 | 0 | return res; |
2629 | 0 | } |
2630 | | |
2631 | | SmallVector<Value, 4> SubViewOp::getOrCreateOffsets(OpBuilder &b, |
2632 | 0 | Location loc) { |
2633 | 0 | unsigned dynamicIdx = 1; |
2634 | 0 | return llvm::to_vector<4>(llvm::map_range( |
2635 | 0 | static_offsets().cast<ArrayAttr>(), [&](Attribute a) -> Value { |
2636 | 0 | int64_t staticOffset = a.cast<IntegerAttr>().getInt(); |
2637 | 0 | if (ShapedType::isDynamicStrideOrOffset(staticOffset)) |
2638 | 0 | return getOperand(dynamicIdx++); |
2639 | 0 | else |
2640 | 0 | return b.create<ConstantIndexOp>(loc, staticOffset); |
2641 | 0 | })); |
2642 | 0 | } |
2643 | | |
2644 | 0 | SmallVector<Value, 4> SubViewOp::getOrCreateSizes(OpBuilder &b, Location loc) { |
2645 | 0 | unsigned dynamicIdx = 1 + offsets().size(); |
2646 | 0 | return llvm::to_vector<4>(llvm::map_range( |
2647 | 0 | static_sizes().cast<ArrayAttr>(), [&](Attribute a) -> Value { |
2648 | 0 | int64_t staticSize = a.cast<IntegerAttr>().getInt(); |
2649 | 0 | if (ShapedType::isDynamic(staticSize)) |
2650 | 0 | return getOperand(dynamicIdx++); |
2651 | 0 | else |
2652 | 0 | return b.create<ConstantIndexOp>(loc, staticSize); |
2653 | 0 | })); |
2654 | 0 | } |
2655 | | |
2656 | | SmallVector<Value, 4> SubViewOp::getOrCreateStrides(OpBuilder &b, |
2657 | 0 | Location loc) { |
2658 | 0 | unsigned dynamicIdx = 1 + offsets().size() + sizes().size(); |
2659 | 0 | return llvm::to_vector<4>(llvm::map_range( |
2660 | 0 | static_strides().cast<ArrayAttr>(), [&](Attribute a) -> Value { |
2661 | 0 | int64_t staticStride = a.cast<IntegerAttr>().getInt(); |
2662 | 0 | if (ShapedType::isDynamicStrideOrOffset(staticStride)) |
2663 | 0 | return getOperand(dynamicIdx++); |
2664 | 0 | else |
2665 | 0 | return b.create<ConstantIndexOp>(loc, staticStride); |
2666 | 0 | })); |
2667 | 0 | } |
2668 | | |
2669 | | LogicalResult |
2670 | 0 | SubViewOp::getStaticStrides(SmallVectorImpl<int64_t> &staticStrides) { |
2671 | 0 | if (!strides().empty()) |
2672 | 0 | return failure(); |
2673 | 0 | staticStrides = extractFromI64ArrayAttr(static_strides()); |
2674 | 0 | return success(); |
2675 | 0 | } |
2676 | | |
2677 | 0 | Value SubViewOp::getViewSource() { return source(); } |
2678 | | |
2679 | | namespace { |
2680 | | |
2681 | | /// Take a list of `values` with potential new constants to extract and a list
2682 | | /// of `constantValues` containing `values.size()` sentinels that evaluate to
2683 | | /// true under `isDynamic`.
2684 | | /// Detect the entries of `values` produced by a ConstantIndexOp and place each
2685 | | /// new constant in place of the corresponding sentinel value.
2686 | | void canonicalizeSubViewPart(SmallVectorImpl<Value> &values, |
2687 | | SmallVectorImpl<int64_t> &constantValues, |
2688 | 0 | llvm::function_ref<bool(int64_t)> isDynamic) { |
2689 | 0 | bool hasNewStaticValue = llvm::any_of( |
2690 | 0 | values, [](Value val) { return matchPattern(val, m_ConstantIndex()); }); |
2691 | 0 | if (hasNewStaticValue) { |
2692 | 0 | for (unsigned cstIdx = 0, valIdx = 0, e = constantValues.size(); |
2693 | 0 | cstIdx != e; ++cstIdx) { |
2694 | 0 | // Was already static, skip. |
2695 | 0 | if (!isDynamic(constantValues[cstIdx])) |
2696 | 0 | continue; |
2697 | 0 | // Newly static, move from Value to constant. |
2698 | 0 | if (matchPattern(values[valIdx], m_ConstantIndex())) { |
2699 | 0 | constantValues[cstIdx] = |
2700 | 0 | cast<ConstantIndexOp>(values[valIdx].getDefiningOp()).getValue(); |
2701 | 0 | // Erase for implementation simplicity; use reverse iteration if it ever matters.
2702 | 0 | values.erase(std::next(values.begin(), valIdx)); |
2703 | 0 | continue; |
2704 | 0 | } |
2705 | 0 | // Remains dynamic; move to the next value.
2706 | 0 | ++valIdx; |
2707 | 0 | } |
2708 | 0 | } |
2709 | 0 | } |
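An illustrative run: with values = [%c2, %i, %c3], where %c2 and %c3 are produced by ConstantIndexOp and %i is not, and constantValues = [?, ?, ?] (all dynamic sentinels), the routine rewrites constantValues to [2, ?, 3] and shrinks values to [%i]. The two lists stay aligned because a value is erased exactly when its sentinel turns static.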
2710 | | |
2711 | | /// Pattern to rewrite a subview op with constant arguments. |
2712 | | class SubViewOpConstantArgumentFolder final |
2713 | | : public OpRewritePattern<SubViewOp> { |
2714 | | public: |
2715 | | using OpRewritePattern<SubViewOp>::OpRewritePattern; |
2716 | | |
2717 | | LogicalResult matchAndRewrite(SubViewOp subViewOp, |
2718 | 0 | PatternRewriter &rewriter) const override { |
2719 | 0 | // No constant operand, just return.
2720 | 0 | if (llvm::none_of(subViewOp.getOperands(), [](Value operand) { |
2721 | 0 | return matchPattern(operand, m_ConstantIndex()); |
2722 | 0 | })) |
2723 | 0 | return failure(); |
2724 | 0 | |
2725 | 0 | // At least one of offsets/sizes/strides is a new constant. |
2726 | 0 | // Form the new list of operands and constant attributes from the existing. |
2727 | 0 | SmallVector<Value, 8> newOffsets(subViewOp.offsets()); |
2728 | 0 | SmallVector<int64_t, 8> newStaticOffsets = |
2729 | 0 | extractFromI64ArrayAttr(subViewOp.static_offsets()); |
2730 | 0 | assert(newStaticOffsets.size() == subViewOp.getRank()); |
2731 | 0 | canonicalizeSubViewPart(newOffsets, newStaticOffsets, |
2732 | 0 | ShapedType::isDynamicStrideOrOffset); |
2733 | 0 |
2734 | 0 | SmallVector<Value, 8> newSizes(subViewOp.sizes()); |
2735 | 0 | SmallVector<int64_t, 8> newStaticSizes = |
2736 | 0 | extractFromI64ArrayAttr(subViewOp.static_sizes()); |
2737 | 0 | assert(newStaticSizes.size() == subViewOp.getRank());
2738 | 0 | canonicalizeSubViewPart(newSizes, newStaticSizes, ShapedType::isDynamic); |
2739 | 0 |
2740 | 0 | SmallVector<Value, 8> newStrides(subViewOp.strides()); |
2741 | 0 | SmallVector<int64_t, 8> newStaticStrides = |
2742 | 0 | extractFromI64ArrayAttr(subViewOp.static_strides()); |
2743 | 0 | assert(newStaticStrides.size() == subViewOp.getRank());
2744 | 0 | canonicalizeSubViewPart(newStrides, newStaticStrides, |
2745 | 0 | ShapedType::isDynamicStrideOrOffset); |
2746 | 0 |
2747 | 0 | // Create the new op in canonical form. |
2748 | 0 | auto newSubViewOp = rewriter.create<SubViewOp>( |
2749 | 0 | subViewOp.getLoc(), subViewOp.source(), newStaticOffsets, |
2750 | 0 | newStaticSizes, newStaticStrides, newOffsets, newSizes, newStrides); |
2751 | 0 |
2752 | 0 | // Insert a memref_cast for compatibility of the uses of the op. |
2753 | 0 | rewriter.replaceOpWithNewOp<MemRefCastOp>(subViewOp, newSubViewOp, |
2754 | 0 | subViewOp.getType()); |
2755 | 0 |
2756 | 0 | return success(); |
2757 | 0 | } |
2758 | | }; |
2759 | | |
2760 | | } // end anonymous namespace |
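Taken together, the pattern rewrites a subview with constant index operands into one with static attributes, then casts back to the original type. A sketch with illustrative types (the exact layout in the result depends on the source):

```mlir
%c0 = constant 0 : index
%1 = subview %arg0[%c0, %c0][4, 4][1, 1]
  : memref<8x8xf32> to memref<4x4xf32, offset: ?, strides: [8, 1]>
```

becomes:

```mlir
%1 = subview %arg0[0, 0][4, 4][1, 1]
  : memref<8x8xf32> to memref<4x4xf32, offset: 0, strides: [8, 1]>
%2 = memref_cast %1 : memref<4x4xf32, offset: 0, strides: [8, 1]>
  to memref<4x4xf32, offset: ?, strides: [8, 1]>
```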
2761 | | |
2762 | | /// Determines whether MemRefCastOp casts to a more dynamic version of the |
2763 | | /// source memref. This is useful to fold a memref_cast into a consuming op
2764 | | /// and implement canonicalization patterns for ops in different dialects that |
2765 | | /// may consume the results of memref_cast operations. Such foldable memref_cast |
2766 | | /// operations are typically inserted as `view` and `subview` ops are |
2767 | | /// canonicalized, to preserve the type compatibility of their uses. |
2768 | | /// |
2769 | | /// Returns true when all conditions are met: |
2770 | | /// 1. source and result are ranked memrefs with strided semantics and same |
2771 | | /// element type and rank. |
2772 | | /// 2. each of the source's size, offset or stride has more static information |
2773 | | /// than the corresponding result's size, offset or stride. |
2774 | | /// |
2775 | | /// Example 1: |
2776 | | /// ```mlir |
2777 | | /// %1 = memref_cast %0 : memref<8x16xf32> to memref<?x?xf32> |
2778 | | /// %2 = consumer %1 ... : memref<?x?xf32> ... |
2779 | | /// ``` |
2780 | | /// |
2781 | | /// may fold into: |
2782 | | /// |
2783 | | /// ```mlir |
2784 | | /// %2 = consumer %0 ... : memref<8x16xf32> ... |
2785 | | /// ``` |
2786 | | /// |
2787 | | /// Example 2: |
2788 | | /// ``` |
2789 | | /// %1 = memref_cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>> |
2790 | | /// to memref<?x?xf32> |
2791 | | /// consumer %1 : memref<?x?xf32> ... |
2792 | | /// ``` |
2793 | | /// |
2794 | | /// may fold into: |
2795 | | /// |
2796 | | /// ``` |
2797 | | /// consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>> |
2798 | | /// ``` |
2799 | 0 | bool mlir::canFoldIntoConsumerOp(MemRefCastOp castOp) { |
2800 | 0 | MemRefType sourceType = castOp.source().getType().dyn_cast<MemRefType>(); |
2801 | 0 | MemRefType resultType = castOp.getType().dyn_cast<MemRefType>(); |
2802 | 0 |
2803 | 0 | // Requires ranked MemRefType. |
2804 | 0 | if (!sourceType || !resultType) |
2805 | 0 | return false; |
2806 | 0 | |
2807 | 0 | // Requires same element type.
2808 | 0 | if (sourceType.getElementType() != resultType.getElementType()) |
2809 | 0 | return false; |
2810 | 0 | |
2811 | 0 | // Requires same rank. |
2812 | 0 | if (sourceType.getRank() != resultType.getRank()) |
2813 | 0 | return false; |
2814 | 0 | |
2815 | 0 | // Only fold casts between strided memref forms. |
2816 | 0 | int64_t sourceOffset, resultOffset; |
2817 | 0 | SmallVector<int64_t, 4> sourceStrides, resultStrides; |
2818 | 0 | if (failed(getStridesAndOffset(sourceType, sourceStrides, sourceOffset)) || |
2819 | 0 | failed(getStridesAndOffset(resultType, resultStrides, resultOffset))) |
2820 | 0 | return false; |
2821 | 0 | |
2822 | 0 | // If cast is towards more static sizes along any dimension, don't fold. |
2823 | 0 | for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) { |
2824 | 0 | auto ss = std::get<0>(it), st = std::get<1>(it); |
2825 | 0 | if (ss != st) |
2826 | 0 | if (MemRefType::isDynamic(ss) && !MemRefType::isDynamic(st)) |
2827 | 0 | return false; |
2828 | 0 | } |
2829 | 0 |
2830 | 0 | // If cast is towards more static offset along any dimension, don't fold. |
2831 | 0 | if (sourceOffset != resultOffset) |
2832 | 0 | if (MemRefType::isDynamicStrideOrOffset(sourceOffset) && |
2833 | 0 | !MemRefType::isDynamicStrideOrOffset(resultOffset)) |
2834 | 0 | return false; |
2835 | 0 | |
2836 | 0 | // If cast is towards more static strides along any dimension, don't fold. |
2837 | 0 | for (auto it : llvm::zip(sourceStrides, resultStrides)) { |
2838 | 0 | auto ss = std::get<0>(it), st = std::get<1>(it); |
2839 | 0 | if (ss != st) |
2840 | 0 | if (MemRefType::isDynamicStrideOrOffset(ss) && |
2841 | 0 | !MemRefType::isDynamicStrideOrOffset(st)) |
2842 | 0 | return false; |
2843 | 0 | } |
2844 | 0 |
2845 | 0 | return true; |
2846 | 0 | } |
2847 | | |
2848 | | namespace { |
2849 | | /// Pattern to rewrite a subview op with MemRefCast arguments. |
2850 | | /// This essentially pushes memref_cast past its consuming subview when |
2851 | | /// `canFoldIntoConsumerOp` is true. |
2852 | | /// |
2853 | | /// Example: |
2854 | | /// ``` |
2855 | | /// %0 = memref_cast %V : memref<16x16xf32> to memref<?x?xf32> |
2856 | | /// %1 = subview %0[0, 0][3, 4][1, 1] : |
2857 | | /// memref<?x?xf32> to memref<3x4xf32, offset:?, strides:[?, 1]> |
2858 | | /// ``` |
2859 | | /// is rewritten into: |
2860 | | /// ``` |
2861 | | /// %0 = subview %V: memref<16x16xf32> to memref<3x4xf32, #[[map0]]> |
2862 | | /// %1 = memref_cast %0: memref<3x4xf32, offset:0, strides:[16, 1]> to |
2863 | | /// memref<3x4xf32, offset:?, strides:[?, 1]> |
2864 | | /// ``` |
2865 | | class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> { |
2866 | | public: |
2867 | | using OpRewritePattern<SubViewOp>::OpRewritePattern; |
2868 | | |
2869 | | LogicalResult matchAndRewrite(SubViewOp subViewOp, |
2870 | 0 | PatternRewriter &rewriter) const override { |
2871 | 0 | // If any operand is constant, bail and let SubViewOpConstantArgumentFolder kick in.
2872 | 0 | if (llvm::any_of(subViewOp.getOperands(), [](Value operand) { |
2873 | 0 | return matchPattern(operand, m_ConstantIndex()); |
2874 | 0 | })) |
2875 | 0 | return failure(); |
2876 | 0 | |
2877 | 0 | auto castOp = subViewOp.source().getDefiningOp<MemRefCastOp>(); |
2878 | 0 | if (!castOp) |
2879 | 0 | return failure(); |
2880 | 0 | |
2881 | 0 | if (!canFoldIntoConsumerOp(castOp)) |
2882 | 0 | return failure(); |
2883 | 0 | |
2884 | 0 | // Deduce the resultType of the SubViewOp using `inferSubViewResultType` on
2885 | 0 | // the cast source operand type and the SubViewOp static information. This
2886 | 0 | // is the resulting type if the MemRefCastOp were folded.
2887 | 0 | Type resultType = SubViewOp::inferSubViewResultType( |
2888 | 0 | castOp.source().getType().cast<MemRefType>(), |
2889 | 0 | extractFromI64ArrayAttr(subViewOp.static_offsets()), |
2890 | 0 | extractFromI64ArrayAttr(subViewOp.static_sizes()), |
2891 | 0 | extractFromI64ArrayAttr(subViewOp.static_strides())); |
2892 | 0 | Value newSubView = rewriter.create<SubViewOp>( |
2893 | 0 | subViewOp.getLoc(), resultType, castOp.source(), subViewOp.offsets(), |
2894 | 0 | subViewOp.sizes(), subViewOp.strides(), subViewOp.static_offsets(), |
2895 | 0 | subViewOp.static_sizes(), subViewOp.static_strides()); |
2896 | 0 | rewriter.replaceOpWithNewOp<MemRefCastOp>(subViewOp, subViewOp.getType(), |
2897 | 0 | newSubView); |
2898 | 0 | return success(); |
2899 | 0 | } |
2900 | | }; |
2901 | | } // namespace |
2902 | | |
2903 | | void SubViewOp::getCanonicalizationPatterns(OwningRewritePatternList &results, |
2904 | 0 | MLIRContext *context) { |
2905 | 0 | results.insert<SubViewOpConstantArgumentFolder, SubViewOpMemRefCastFolder>( |
2906 | 0 | context); |
2907 | 0 | } |
2908 | | |
2909 | | //===----------------------------------------------------------------------===// |
2910 | | // TensorCastOp |
2911 | | //===----------------------------------------------------------------------===// |
2912 | | |
2913 | 0 | bool TensorCastOp::areCastCompatible(Type a, Type b) { |
2914 | 0 | auto aT = a.dyn_cast<TensorType>(); |
2915 | 0 | auto bT = b.dyn_cast<TensorType>(); |
2916 | 0 | if (!aT || !bT) |
2917 | 0 | return false; |
2918 | 0 | |
2919 | 0 | if (aT.getElementType() != bT.getElementType()) |
2920 | 0 | return false; |
2921 | 0 | |
2922 | 0 | return succeeded(verifyCompatibleShape(aT, bT)); |
2923 | 0 | } |
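That is, two tensor types are cast compatible when their element types match and their shapes agree wherever both are static. An illustrative sketch:

```mlir
%0 = tensor_cast %t : tensor<4xf32> to tensor<?xf32>  // compatible
// tensor<4xf32> to tensor<5xf32>: incompatible, static sizes differ.
// tensor<4xf32> to tensor<4xi32>: incompatible, element types differ.
```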
2924 | | |
2925 | 0 | OpFoldResult TensorCastOp::fold(ArrayRef<Attribute> operands) { |
2926 | 0 | return impl::foldCastOp(*this); |
2927 | 0 | } |
2928 | | |
2929 | | //===----------------------------------------------------------------------===// |
2930 | | // Helpers for Tensor[Load|Store]Op |
2931 | | //===----------------------------------------------------------------------===// |
2932 | | |
2933 | 0 | static Type getTensorTypeFromMemRefType(Type type) { |
2934 | 0 | if (auto memref = type.dyn_cast<MemRefType>()) |
2935 | 0 | return RankedTensorType::get(memref.getShape(), memref.getElementType()); |
2936 | 0 | return NoneType::get(type.getContext()); |
2937 | 0 | } |
2938 | | |
2939 | | //===----------------------------------------------------------------------===// |
2940 | | // TruncateIOp |
2941 | | //===----------------------------------------------------------------------===// |
2942 | | |
2943 | 0 | static LogicalResult verify(TruncateIOp op) { |
2944 | 0 | auto srcType = getElementTypeOrSelf(op.getOperand().getType()); |
2945 | 0 | auto dstType = getElementTypeOrSelf(op.getType()); |
2946 | 0 |
2947 | 0 | if (srcType.isa<IndexType>()) |
2948 | 0 | return op.emitError() << srcType << " is not a valid operand type"; |
2949 | 0 | if (dstType.isa<IndexType>()) |
2950 | 0 | return op.emitError() << dstType << " is not a valid result type"; |
2951 | 0 | |
2952 | 0 | if (srcType.cast<IntegerType>().getWidth() <= |
2953 | 0 | dstType.cast<IntegerType>().getWidth()) |
2954 | 0 | return op.emitError("operand type ") |
2955 | 0 | << srcType << " must be wider than result type " << dstType; |
2956 | 0 | |
2957 | 0 | return success(); |
2958 | 0 | } |
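In short, both types must be integers (not index) and the operand must be strictly wider than the result. An illustrative sketch:

```mlir
%0 = trunci %a : i32 to i8   // ok: operand width 32 > result width 8
// trunci %b : i8 to i32 would be rejected: operand must be wider than result.
```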
2959 | | |
2960 | | //===----------------------------------------------------------------------===// |
2961 | | // UnsignedDivIOp |
2962 | | //===----------------------------------------------------------------------===// |
2963 | | |
2964 | 0 | OpFoldResult UnsignedDivIOp::fold(ArrayRef<Attribute> operands) { |
2965 | 0 | assert(operands.size() == 2 && "binary operation takes two operands"); |
2966 | 0 |
|
2967 | 0 | // Don't fold if it would require a division by zero. |
2968 | 0 | bool div0 = false; |
2969 | 0 | auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) { |
2970 | 0 | if (div0 || !b) { |
2971 | 0 | div0 = true; |
2972 | 0 | return a; |
2973 | 0 | } |
2974 | 0 | return a.udiv(b); |
2975 | 0 | }); |
2976 | 0 |
2977 | 0 | // Fold out division by one. Assumes all tensors of all ones are splats. |
2978 | 0 | if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) { |
2979 | 0 | if (rhs.getValue() == 1) |
2980 | 0 | return lhs(); |
2981 | 0 | } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) { |
2982 | 0 | if (rhs.getSplatValue<IntegerAttr>().getValue() == 1) |
2983 | 0 | return lhs(); |
2984 | 0 | } |
2985 | 0 | |
2986 | 0 | return div0 ? Attribute() : result; |
2987 | 0 | } |
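Illustrative folds (constant and SSA names are hypothetical):

```mlir
%c0 = constant 0 : i32
%c1 = constant 1 : i32
%c2 = constant 2 : i32
%c6 = constant 6 : i32
%0 = divi_unsigned %c6, %c2 : i32   // folds to constant 3
%1 = divi_unsigned %x, %c1 : i32    // folds to %x
%2 = divi_unsigned %x, %c0 : i32    // not folded: would divide by zero
```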
2988 | | |
2989 | | //===----------------------------------------------------------------------===// |
2990 | | // UnsignedRemIOp |
2991 | | //===----------------------------------------------------------------------===// |
2992 | | |
2993 | 0 | OpFoldResult UnsignedRemIOp::fold(ArrayRef<Attribute> operands) { |
2994 | 0 | assert(operands.size() == 2 && "remi_unsigned takes two operands"); |
2995 | 0 |
2996 | 0 | auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>(); |
2997 | 0 | if (!rhs) |
2998 | 0 | return {}; |
2999 | 0 | auto rhsValue = rhs.getValue(); |
3000 | 0 |
3001 | 0 | // x % 1 = 0 |
3002 | 0 | if (rhsValue.isOneValue()) |
3003 | 0 | return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0)); |
3004 | 0 | |
3005 | 0 | // Don't fold if it requires division by zero. |
3006 | 0 | if (rhsValue.isNullValue()) |
3007 | 0 | return {}; |
3008 | 0 | |
3009 | 0 | auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>(); |
3010 | 0 | if (!lhs) |
3011 | 0 | return {}; |
3012 | 0 | return IntegerAttr::get(lhs.getType(), lhs.getValue().urem(rhsValue)); |
3013 | 0 | } |
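Illustrative folds (constant and SSA names are hypothetical):

```mlir
%c1 = constant 1 : i32
%c4 = constant 4 : i32
%c7 = constant 7 : i32
%0 = remi_unsigned %x, %c1 : i32    // folds to constant 0 for any %x
%1 = remi_unsigned %c7, %c4 : i32   // folds to constant 3
```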
3014 | | |
3015 | | //===----------------------------------------------------------------------===// |
3016 | | // ViewOp |
3017 | | //===----------------------------------------------------------------------===// |
3018 | | |
3019 | 0 | static ParseResult parseViewOp(OpAsmParser &parser, OperationState &result) { |
3020 | 0 | OpAsmParser::OperandType srcInfo; |
3021 | 0 | SmallVector<OpAsmParser::OperandType, 1> offsetInfo; |
3022 | 0 | SmallVector<OpAsmParser::OperandType, 4> sizesInfo; |
3023 | 0 | auto indexType = parser.getBuilder().getIndexType(); |
3024 | 0 | Type srcType, dstType; |
3025 | 0 | llvm::SMLoc offsetLoc; |
3026 | 0 | if (parser.parseOperand(srcInfo) || parser.getCurrentLocation(&offsetLoc) || |
3027 | 0 | parser.parseOperandList(offsetInfo, OpAsmParser::Delimiter::Square)) |
3028 | 0 | return failure(); |
3029 | 0 | |
3030 | 0 | if (offsetInfo.size() != 1) |
3031 | 0 | return parser.emitError(offsetLoc) << "expects 1 offset operand"; |
3032 | 0 | |
3033 | 0 | return failure( |
3034 | 0 | parser.parseOperandList(sizesInfo, OpAsmParser::Delimiter::Square) || |
3035 | 0 | parser.parseOptionalAttrDict(result.attributes) || |
3036 | 0 | parser.parseColonType(srcType) || |
3037 | 0 | parser.resolveOperand(srcInfo, srcType, result.operands) || |
3038 | 0 | parser.resolveOperands(offsetInfo, indexType, result.operands) || |
3039 | 0 | parser.resolveOperands(sizesInfo, indexType, result.operands) || |
3040 | 0 | parser.parseKeywordType("to", dstType) || |
3041 | 0 | parser.addTypeToList(dstType, result.types)); |
3042 | 0 | } |
3043 | | |
3044 | 0 | static void print(OpAsmPrinter &p, ViewOp op) { |
3045 | 0 | p << op.getOperationName() << ' ' << op.getOperand(0) << '['; |
3046 | 0 | p.printOperand(op.byte_shift()); |
3047 | 0 | p << "][" << op.sizes() << ']'; |
3048 | 0 | p.printOptionalAttrDict(op.getAttrs()); |
3049 | 0 | p << " : " << op.getOperand(0).getType() << " to " << op.getType(); |
3050 | 0 | } |
3051 | | |
3052 | 0 | static LogicalResult verify(ViewOp op) { |
3053 | 0 | auto baseType = op.getOperand(0).getType().cast<MemRefType>(); |
3054 | 0 | auto viewType = op.getType(); |
3055 | 0 |
3056 | 0 | // The base memref should have identity layout map (or none). |
3057 | 0 | if (baseType.getAffineMaps().size() > 1 || |
3058 | 0 | (baseType.getAffineMaps().size() == 1 && |
3059 | 0 | !baseType.getAffineMaps()[0].isIdentity())) |
3060 | 0 | return op.emitError("unsupported map for base memref type ") << baseType; |
3061 | 0 | |
3062 | 0 | // The result memref should have identity layout map (or none). |
3063 | 0 | if (viewType.getAffineMaps().size() > 1 || |
3064 | 0 | (viewType.getAffineMaps().size() == 1 && |
3065 | 0 | !viewType.getAffineMaps()[0].isIdentity())) |
3066 | 0 | return op.emitError("unsupported map for result memref type ") << viewType; |
3067 | 0 | |
3068 | 0 | // The base memref and the view memref should be in the same memory space. |
3069 | 0 | if (baseType.getMemorySpace() != viewType.getMemorySpace()) |
3070 | 0 | return op.emitError("different memory spaces specified for base memref " |
3071 | 0 | "type ") |
3072 | 0 | << baseType << " and view memref type " << viewType; |
3073 | 0 | |
3074 | 0 | // Verify that we have the correct number of sizes for the result type. |
3075 | 0 | unsigned numDynamicDims = viewType.getNumDynamicDims(); |
3076 | 0 | if (op.sizes().size() != numDynamicDims) |
3077 | 0 | return op.emitError("incorrect number of size operands for type ") |
3078 | 0 | << viewType; |
3079 | 0 | |
3080 | 0 | return success(); |
3081 | 0 | } |
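A well-formed op thus takes one byte-shift operand plus one size operand per dynamic dimension of the result. An illustrative sketch (names are hypothetical):

```mlir
%0 = alloc() : memref<2048xi8>
%1 = view %0[%byte_shift][%size]
  : memref<2048xi8> to memref<?x4xf32>  // one dynamic dim, one size operand
```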
3082 | | |
3083 | 0 | Value ViewOp::getViewSource() { return source(); } |
3084 | | |
3085 | | namespace { |
3086 | | |
3087 | | struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> { |
3088 | | using OpRewritePattern<ViewOp>::OpRewritePattern; |
3089 | | |
3090 | | LogicalResult matchAndRewrite(ViewOp viewOp, |
3091 | 0 | PatternRewriter &rewriter) const override { |
3092 | 0 | // Return if none of the operands are constants. |
3093 | 0 | if (llvm::none_of(viewOp.getOperands(), [](Value operand) { |
3094 | 0 | return matchPattern(operand, m_ConstantIndex()); |
3095 | 0 | })) |
3096 | 0 | return failure(); |
3097 | 0 | |
3098 | 0 | // Get result memref type. |
3099 | 0 | auto memrefType = viewOp.getType(); |
3100 | 0 |
3101 | 0 | // Get offset from old memref view type 'memrefType'.
3102 | 0 | int64_t oldOffset; |
3103 | 0 | SmallVector<int64_t, 4> oldStrides; |
3104 | 0 | if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset))) |
3105 | 0 | return failure(); |
3106 | 0 | assert(oldOffset == 0 && "Expected 0 offset"); |
3107 | 0 |
3108 | 0 | SmallVector<Value, 4> newOperands; |
3109 | 0 |
|
3110 | 0 | // Offset cannot be folded into result type. |
3111 | 0 |
|
3112 | 0 | // Fold any dynamic dim operands which are produced by a constant. |
3113 | 0 | SmallVector<int64_t, 4> newShapeConstants; |
3114 | 0 | newShapeConstants.reserve(memrefType.getRank()); |
3115 | 0 |
3116 | 0 | unsigned dynamicDimPos = 0; |
3117 | 0 | unsigned rank = memrefType.getRank(); |
3118 | 0 | for (unsigned dim = 0, e = rank; dim < e; ++dim) { |
3119 | 0 | int64_t dimSize = memrefType.getDimSize(dim); |
3121 | 0 | // If this is already a static dimension, keep it.
3121 | 0 | if (!ShapedType::isDynamic(dimSize)) { |
3122 | 0 | newShapeConstants.push_back(dimSize); |
3123 | 0 | continue; |
3124 | 0 | } |
3125 | 0 | auto *defOp = viewOp.sizes()[dynamicDimPos].getDefiningOp(); |
3126 | 0 | if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) { |
3127 | 0 | // Dynamic shape dimension will be folded. |
3128 | 0 | newShapeConstants.push_back(constantIndexOp.getValue()); |
3129 | 0 | } else { |
3130 | 0 | // Dynamic shape dimension not folded; copy operand from old memref. |
3131 | 0 | newShapeConstants.push_back(dimSize); |
3132 | 0 | newOperands.push_back(viewOp.sizes()[dynamicDimPos]); |
3133 | 0 | } |
3134 | 0 | dynamicDimPos++; |
3135 | 0 | } |
3136 | 0 |
|
3137 | 0 | // Create new memref type with constant folded dims. |
3138 | 0 | MemRefType newMemRefType = |
3139 | 0 | MemRefType::Builder(memrefType).setShape(newShapeConstants); |
3140 | 0 | // Nothing new, don't fold. |
3141 | 0 | if (newMemRefType == memrefType) |
3142 | 0 | return failure(); |
3143 | 0 | |
3144 | 0 | // Create new ViewOp. |
3145 | 0 | auto newViewOp = rewriter.create<ViewOp>(viewOp.getLoc(), newMemRefType, |
3146 | 0 | viewOp.getOperand(0), |
3147 | 0 | viewOp.byte_shift(), newOperands); |
3148 | 0 | // Insert a cast so we have the same type as the old memref type. |
3149 | 0 | rewriter.replaceOpWithNewOp<MemRefCastOp>(viewOp, newViewOp, |
3150 | 0 | viewOp.getType()); |
3151 | 0 | return success(); |
3152 | 0 | } |
3153 | | }; |
3154 | | |
3155 | | struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> { |
3156 | | using OpRewritePattern<ViewOp>::OpRewritePattern; |
3157 | | |
3158 | | LogicalResult matchAndRewrite(ViewOp viewOp, |
3159 | 0 | PatternRewriter &rewriter) const override { |
3160 | 0 | Value memrefOperand = viewOp.getOperand(0); |
3161 | 0 | MemRefCastOp memrefCastOp = memrefOperand.getDefiningOp<MemRefCastOp>(); |
3162 | 0 | if (!memrefCastOp) |
3163 | 0 | return failure(); |
3164 | 0 | Value allocOperand = memrefCastOp.getOperand(); |
3165 | 0 | AllocOp allocOp = allocOperand.getDefiningOp<AllocOp>(); |
3166 | 0 | if (!allocOp) |
3167 | 0 | return failure(); |
3168 | 0 | rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand, |
3169 | 0 | viewOp.byte_shift(), viewOp.sizes()); |
3170 | 0 | return success(); |
3171 | 0 | } |
3172 | | }; |
3173 | | |
3174 | | } // end anonymous namespace |
3175 | | |
3176 | | void ViewOp::getCanonicalizationPatterns(OwningRewritePatternList &results, |
3177 | 0 | MLIRContext *context) { |
3178 | 0 | results.insert<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context); |
3179 | 0 | } |
3180 | | |
3181 | | //===----------------------------------------------------------------------===// |
3182 | | // XOrOp |
3183 | | //===----------------------------------------------------------------------===// |
3184 | | |
3185 | 0 | OpFoldResult XOrOp::fold(ArrayRef<Attribute> operands) { |
3186 | 0 | /// xor(x, 0) -> x |
3187 | 0 | if (matchPattern(rhs(), m_Zero())) |
3188 | 0 | return lhs(); |
3189 | 0 | /// xor(x,x) -> 0 |
3190 | 0 | if (lhs() == rhs()) |
3191 | 0 | return Builder(getContext()).getZeroAttr(getType()); |
3192 | 0 | |
3193 | 0 | return constFoldBinaryOp<IntegerAttr>(operands, |
3194 | 0 | [](APInt a, APInt b) { return a ^ b; }); |
3195 | 0 | } |
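Illustrative folds (names are hypothetical):

```mlir
%c0 = constant 0 : i32
%0 = xor %x, %c0 : i32   // folds to %x
%1 = xor %x, %x : i32    // folds to constant 0
```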
3196 | | |
3197 | | //===----------------------------------------------------------------------===// |
3198 | | // ZeroExtendIOp |
3199 | | //===----------------------------------------------------------------------===// |
3200 | | |
3201 | 0 | static LogicalResult verify(ZeroExtendIOp op) { |
3202 | 0 | auto srcType = getElementTypeOrSelf(op.getOperand().getType()); |
3203 | 0 | auto dstType = getElementTypeOrSelf(op.getType()); |
3204 | 0 |
3205 | 0 | if (srcType.isa<IndexType>()) |
3206 | 0 | return op.emitError() << srcType << " is not a valid operand type"; |
3207 | 0 | if (dstType.isa<IndexType>()) |
3208 | 0 | return op.emitError() << dstType << " is not a valid result type"; |
3209 | 0 | |
3210 | 0 | if (srcType.cast<IntegerType>().getWidth() >= |
3211 | 0 | dstType.cast<IntegerType>().getWidth()) |
3212 | 0 | return op.emitError("result type ") |
3213 | 0 | << dstType << " must be wider than operand type " << srcType; |
3214 | 0 | |
3215 | 0 | return success(); |
3216 | 0 | } |
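Mirroring TruncateIOp, the result must be strictly wider than the operand. An illustrative sketch:

```mlir
%0 = zexti %a : i8 to i32   // ok: result width 32 > operand width 8
// zexti %b : i32 to i16 would be rejected: result must be wider than operand.
```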
3217 | | |
3218 | | //===----------------------------------------------------------------------===// |
3219 | | // TableGen'd op method definitions |
3220 | | //===----------------------------------------------------------------------===// |
3221 | | |
3222 | | #define GET_OP_CLASSES |
3223 | | #include "mlir/Dialect/StandardOps/IR/Ops.cpp.inc" |