/home/arjun/llvm-project/mlir/lib/IR/Builders.cpp
Line | Count | Source (jump to first uncovered line) |
1 | | //===- Builders.cpp - Helpers for constructing MLIR Classes ---------------===// |
2 | | // |
3 | | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | | // See https://llvm.org/LICENSE.txt for license information. |
5 | | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | | // |
7 | | //===----------------------------------------------------------------------===// |
8 | | |
9 | | #include "mlir/IR/Builders.h" |
10 | | #include "mlir/IR/AffineExpr.h" |
11 | | #include "mlir/IR/AffineMap.h" |
12 | | #include "mlir/IR/Dialect.h" |
13 | | #include "mlir/IR/IntegerSet.h" |
14 | | #include "mlir/IR/Matchers.h" |
15 | | #include "mlir/IR/Module.h" |
16 | | #include "mlir/IR/StandardTypes.h" |
17 | | #include "llvm/Support/raw_ostream.h" |
18 | | using namespace mlir; |
19 | | |
20 | 0 | Builder::Builder(ModuleOp module) : context(module.getContext()) {} |
21 | | |
22 | 0 | Identifier Builder::getIdentifier(StringRef str) { |
23 | 0 | return Identifier::get(str, context); |
24 | 0 | } |
25 | | |
26 | | //===----------------------------------------------------------------------===// |
27 | | // Locations. |
28 | | //===----------------------------------------------------------------------===// |
29 | | |
30 | 0 | Location Builder::getUnknownLoc() { return UnknownLoc::get(context); } |
31 | | |
32 | | Location Builder::getFileLineColLoc(Identifier filename, unsigned line, |
33 | 0 | unsigned column) { |
34 | 0 | return FileLineColLoc::get(filename, line, column, context); |
35 | 0 | } |
36 | | |
37 | 0 | Location Builder::getFusedLoc(ArrayRef<Location> locs, Attribute metadata) { |
38 | 0 | return FusedLoc::get(locs, metadata, context); |
39 | 0 | } |
40 | | |
41 | | //===----------------------------------------------------------------------===// |
42 | | // Types. |
43 | | //===----------------------------------------------------------------------===// |
44 | | |
45 | 0 | FloatType Builder::getBF16Type() { return FloatType::getBF16(context); } |
46 | | |
47 | 0 | FloatType Builder::getF16Type() { return FloatType::getF16(context); } |
48 | | |
49 | 0 | FloatType Builder::getF32Type() { return FloatType::getF32(context); } |
50 | | |
51 | 0 | FloatType Builder::getF64Type() { return FloatType::getF64(context); } |
52 | | |
53 | 0 | IndexType Builder::getIndexType() { return IndexType::get(context); } |
54 | | |
55 | 0 | IntegerType Builder::getI1Type() { return IntegerType::get(1, context); } |
56 | | |
57 | 0 | IntegerType Builder::getI32Type() { return IntegerType::get(32, context); } |
58 | | |
59 | 0 | IntegerType Builder::getI64Type() { return IntegerType::get(64, context); } |
60 | | |
61 | 0 | IntegerType Builder::getIntegerType(unsigned width) { |
62 | 0 | return IntegerType::get(width, context); |
63 | 0 | } |
64 | | |
65 | 0 | IntegerType Builder::getIntegerType(unsigned width, bool isSigned) { |
66 | 0 | return IntegerType::get( |
67 | 0 | width, isSigned ? IntegerType::Signed : IntegerType::Unsigned, context); |
68 | 0 | } |
69 | | |
70 | | FunctionType Builder::getFunctionType(ArrayRef<Type> inputs, |
71 | 0 | ArrayRef<Type> results) { |
72 | 0 | return FunctionType::get(inputs, results, context); |
73 | 0 | } |
74 | | |
75 | 0 | TupleType Builder::getTupleType(ArrayRef<Type> elementTypes) { |
76 | 0 | return TupleType::get(elementTypes, context); |
77 | 0 | } |
78 | | |
79 | 0 | NoneType Builder::getNoneType() { return NoneType::get(context); } |
80 | | |
81 | | //===----------------------------------------------------------------------===// |
82 | | // Attributes. |
83 | | //===----------------------------------------------------------------------===// |
84 | | |
85 | 0 | NamedAttribute Builder::getNamedAttr(StringRef name, Attribute val) { |
86 | 0 | return NamedAttribute(getIdentifier(name), val); |
87 | 0 | } |
88 | | |
89 | 0 | UnitAttr Builder::getUnitAttr() { return UnitAttr::get(context); } |
90 | | |
91 | 0 | BoolAttr Builder::getBoolAttr(bool value) { |
92 | 0 | return BoolAttr::get(value, context); |
93 | 0 | } |
94 | | |
95 | 0 | DictionaryAttr Builder::getDictionaryAttr(ArrayRef<NamedAttribute> value) { |
96 | 0 | return DictionaryAttr::get(value, context); |
97 | 0 | } |
98 | | |
99 | 0 | IntegerAttr Builder::getIndexAttr(int64_t value) { |
100 | 0 | return IntegerAttr::get(getIndexType(), APInt(64, value)); |
101 | 0 | } |
102 | | |
103 | 0 | IntegerAttr Builder::getI64IntegerAttr(int64_t value) { |
104 | 0 | return IntegerAttr::get(getIntegerType(64), APInt(64, value)); |
105 | 0 | } |
106 | | |
107 | 0 | DenseIntElementsAttr Builder::getI32VectorAttr(ArrayRef<int32_t> values) { |
108 | 0 | return DenseIntElementsAttr::get( |
109 | 0 | VectorType::get(static_cast<int64_t>(values.size()), getIntegerType(32)), |
110 | 0 | values); |
111 | 0 | } |
112 | | |
113 | 0 | DenseIntElementsAttr Builder::getI64VectorAttr(ArrayRef<int64_t> values) { |
114 | 0 | return DenseIntElementsAttr::get( |
115 | 0 | VectorType::get(static_cast<int64_t>(values.size()), getIntegerType(64)), |
116 | 0 | values); |
117 | 0 | } |
118 | | |
119 | 0 | DenseIntElementsAttr Builder::getI32TensorAttr(ArrayRef<int32_t> values) { |
120 | 0 | return DenseIntElementsAttr::get( |
121 | 0 | RankedTensorType::get(static_cast<int64_t>(values.size()), |
122 | 0 | getIntegerType(32)), |
123 | 0 | values); |
124 | 0 | } |
125 | | |
126 | 0 | DenseIntElementsAttr Builder::getI64TensorAttr(ArrayRef<int64_t> values) { |
127 | 0 | return DenseIntElementsAttr::get( |
128 | 0 | RankedTensorType::get(static_cast<int64_t>(values.size()), |
129 | 0 | getIntegerType(64)), |
130 | 0 | values); |
131 | 0 | } |
132 | | |
133 | 0 | DenseIntElementsAttr Builder::getIndexTensorAttr(ArrayRef<int64_t> values) { |
134 | 0 | return DenseIntElementsAttr::get( |
135 | 0 | RankedTensorType::get(static_cast<int64_t>(values.size()), |
136 | 0 | getIndexType()), |
137 | 0 | values); |
138 | 0 | } |
139 | | |
140 | 0 | IntegerAttr Builder::getI32IntegerAttr(int32_t value) { |
141 | 0 | return IntegerAttr::get(getIntegerType(32), APInt(32, value)); |
142 | 0 | } |
143 | | |
144 | 0 | IntegerAttr Builder::getSI32IntegerAttr(int32_t value) { |
145 | 0 | return IntegerAttr::get(getIntegerType(32, /*isSigned=*/true), |
146 | 0 | APInt(32, value, /*isSigned=*/true)); |
147 | 0 | } |
148 | | |
149 | 0 | IntegerAttr Builder::getUI32IntegerAttr(uint32_t value) { |
150 | 0 | return IntegerAttr::get(getIntegerType(32, /*isSigned=*/false), |
151 | 0 | APInt(32, (uint64_t)value, /*isSigned=*/false)); |
152 | 0 | } |
153 | | |
154 | 0 | IntegerAttr Builder::getI16IntegerAttr(int16_t value) { |
155 | 0 | return IntegerAttr::get(getIntegerType(16), APInt(16, value)); |
156 | 0 | } |
157 | | |
158 | 0 | IntegerAttr Builder::getI8IntegerAttr(int8_t value) { |
159 | 0 | return IntegerAttr::get(getIntegerType(8), APInt(8, value)); |
160 | 0 | } |
161 | | |
162 | 0 | IntegerAttr Builder::getIntegerAttr(Type type, int64_t value) { |
163 | 0 | if (type.isIndex()) |
164 | 0 | return IntegerAttr::get(type, APInt(64, value)); |
165 | 0 | return IntegerAttr::get( |
166 | 0 | type, APInt(type.getIntOrFloatBitWidth(), value, type.isSignedInteger())); |
167 | 0 | } |
168 | | |
169 | 0 | IntegerAttr Builder::getIntegerAttr(Type type, const APInt &value) { |
170 | 0 | return IntegerAttr::get(type, value); |
171 | 0 | } |
172 | | |
173 | 0 | FloatAttr Builder::getF64FloatAttr(double value) { |
174 | 0 | return FloatAttr::get(getF64Type(), APFloat(value)); |
175 | 0 | } |
176 | | |
177 | 0 | FloatAttr Builder::getF32FloatAttr(float value) { |
178 | 0 | return FloatAttr::get(getF32Type(), APFloat(value)); |
179 | 0 | } |
180 | | |
181 | 0 | FloatAttr Builder::getF16FloatAttr(float value) { |
182 | 0 | return FloatAttr::get(getF16Type(), value); |
183 | 0 | } |
184 | | |
185 | 0 | FloatAttr Builder::getFloatAttr(Type type, double value) { |
186 | 0 | return FloatAttr::get(type, value); |
187 | 0 | } |
188 | | |
189 | 0 | FloatAttr Builder::getFloatAttr(Type type, const APFloat &value) { |
190 | 0 | return FloatAttr::get(type, value); |
191 | 0 | } |
192 | | |
193 | 0 | StringAttr Builder::getStringAttr(StringRef bytes) { |
194 | 0 | return StringAttr::get(bytes, context); |
195 | 0 | } |
196 | | |
197 | 0 | ArrayAttr Builder::getArrayAttr(ArrayRef<Attribute> value) { |
198 | 0 | return ArrayAttr::get(value, context); |
199 | 0 | } |
200 | | |
201 | 0 | FlatSymbolRefAttr Builder::getSymbolRefAttr(Operation *value) { |
202 | 0 | auto symName = |
203 | 0 | value->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName()); |
204 | 0 | assert(symName && "value does not have a valid symbol name"); |
205 | 0 | return getSymbolRefAttr(symName.getValue()); |
206 | 0 | } |
207 | 0 | FlatSymbolRefAttr Builder::getSymbolRefAttr(StringRef value) { |
208 | 0 | return SymbolRefAttr::get(value, getContext()); |
209 | 0 | } |
210 | | SymbolRefAttr |
211 | | Builder::getSymbolRefAttr(StringRef value, |
212 | 0 | ArrayRef<FlatSymbolRefAttr> nestedReferences) { |
213 | 0 | return SymbolRefAttr::get(value, nestedReferences, getContext()); |
214 | 0 | } |
215 | | |
216 | 0 | ArrayAttr Builder::getBoolArrayAttr(ArrayRef<bool> values) { |
217 | 0 | auto attrs = llvm::to_vector<8>(llvm::map_range( |
218 | 0 | values, [this](bool v) -> Attribute { return getBoolAttr(v); })); |
219 | 0 | return getArrayAttr(attrs); |
220 | 0 | } |
221 | | |
222 | 0 | ArrayAttr Builder::getI32ArrayAttr(ArrayRef<int32_t> values) { |
223 | 0 | auto attrs = llvm::to_vector<8>(llvm::map_range( |
224 | 0 | values, [this](int32_t v) -> Attribute { return getI32IntegerAttr(v); })); |
225 | 0 | return getArrayAttr(attrs); |
226 | 0 | } |
227 | 0 | ArrayAttr Builder::getI64ArrayAttr(ArrayRef<int64_t> values) { |
228 | 0 | auto attrs = llvm::to_vector<8>(llvm::map_range( |
229 | 0 | values, [this](int64_t v) -> Attribute { return getI64IntegerAttr(v); })); |
230 | 0 | return getArrayAttr(attrs); |
231 | 0 | } |
232 | | |
233 | 0 | ArrayAttr Builder::getIndexArrayAttr(ArrayRef<int64_t> values) { |
234 | 0 | auto attrs = llvm::to_vector<8>( |
235 | 0 | llvm::map_range(values, [this](int64_t v) -> Attribute { |
236 | 0 | return getIntegerAttr(IndexType::get(getContext()), v); |
237 | 0 | })); |
238 | 0 | return getArrayAttr(attrs); |
239 | 0 | } |
240 | | |
241 | 0 | ArrayAttr Builder::getF32ArrayAttr(ArrayRef<float> values) { |
242 | 0 | auto attrs = llvm::to_vector<8>(llvm::map_range( |
243 | 0 | values, [this](float v) -> Attribute { return getF32FloatAttr(v); })); |
244 | 0 | return getArrayAttr(attrs); |
245 | 0 | } |
246 | | |
247 | 0 | ArrayAttr Builder::getF64ArrayAttr(ArrayRef<double> values) { |
248 | 0 | auto attrs = llvm::to_vector<8>(llvm::map_range( |
249 | 0 | values, [this](double v) -> Attribute { return getF64FloatAttr(v); })); |
250 | 0 | return getArrayAttr(attrs); |
251 | 0 | } |
252 | | |
253 | 0 | ArrayAttr Builder::getStrArrayAttr(ArrayRef<StringRef> values) { |
254 | 0 | auto attrs = llvm::to_vector<8>(llvm::map_range( |
255 | 0 | values, [this](StringRef v) -> Attribute { return getStringAttr(v); })); |
256 | 0 | return getArrayAttr(attrs); |
257 | 0 | } |
258 | | |
259 | 0 | ArrayAttr Builder::getAffineMapArrayAttr(ArrayRef<AffineMap> values) { |
260 | 0 | auto attrs = llvm::to_vector<8>(llvm::map_range( |
261 | 0 | values, [](AffineMap v) -> Attribute { return AffineMapAttr::get(v); })); |
262 | 0 | return getArrayAttr(attrs); |
263 | 0 | } |
264 | | |
265 | 0 | Attribute Builder::getZeroAttr(Type type) { |
266 | 0 | switch (type.getKind()) { |
267 | 0 | case StandardTypes::BF16: |
268 | 0 | case StandardTypes::F16: |
269 | 0 | case StandardTypes::F32: |
270 | 0 | case StandardTypes::F64: |
271 | 0 | return getFloatAttr(type, 0.0); |
272 | 0 | case StandardTypes::Integer: { |
273 | 0 | auto width = type.cast<IntegerType>().getWidth(); |
274 | 0 | if (width == 1) |
275 | 0 | return getBoolAttr(false); |
276 | 0 | return getIntegerAttr(type, APInt(width, 0)); |
277 | 0 | } |
278 | 0 | case StandardTypes::Vector: |
279 | 0 | case StandardTypes::RankedTensor: { |
280 | 0 | auto vtType = type.cast<ShapedType>(); |
281 | 0 | auto element = getZeroAttr(vtType.getElementType()); |
282 | 0 | if (!element) |
283 | 0 | return {}; |
284 | 0 | return DenseElementsAttr::get(vtType, element); |
285 | 0 | } |
286 | 0 | default: |
287 | 0 | break; |
288 | 0 | } |
289 | 0 | return {}; |
290 | 0 | } |
291 | | |
292 | | //===----------------------------------------------------------------------===// |
293 | | // Affine Expressions, Affine Maps, and Integer Sets. |
294 | | //===----------------------------------------------------------------------===// |
295 | | |
296 | 0 | AffineExpr Builder::getAffineDimExpr(unsigned position) { |
297 | 0 | return mlir::getAffineDimExpr(position, context); |
298 | 0 | } |
299 | | |
300 | 0 | AffineExpr Builder::getAffineSymbolExpr(unsigned position) { |
301 | 0 | return mlir::getAffineSymbolExpr(position, context); |
302 | 0 | } |
303 | | |
304 | 0 | AffineExpr Builder::getAffineConstantExpr(int64_t constant) { |
305 | 0 | return mlir::getAffineConstantExpr(constant, context); |
306 | 0 | } |
307 | | |
308 | 0 | AffineMap Builder::getEmptyAffineMap() { return AffineMap::get(context); } |
309 | | |
310 | 0 | AffineMap Builder::getConstantAffineMap(int64_t val) { |
311 | 0 | return AffineMap::get(/*dimCount=*/0, /*symbolCount=*/0, |
312 | 0 | getAffineConstantExpr(val)); |
313 | 0 | } |
314 | | |
315 | 0 | AffineMap Builder::getDimIdentityMap() { |
316 | 0 | return AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, getAffineDimExpr(0)); |
317 | 0 | } |
318 | | |
319 | 0 | AffineMap Builder::getMultiDimIdentityMap(unsigned rank) { |
320 | 0 | SmallVector<AffineExpr, 4> dimExprs; |
321 | 0 | dimExprs.reserve(rank); |
322 | 0 | for (unsigned i = 0; i < rank; ++i) |
323 | 0 | dimExprs.push_back(getAffineDimExpr(i)); |
324 | 0 | return AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, dimExprs, |
325 | 0 | context); |
326 | 0 | } |
327 | | |
328 | 0 | AffineMap Builder::getSymbolIdentityMap() { |
329 | 0 | return AffineMap::get(/*dimCount=*/0, /*symbolCount=*/1, |
330 | 0 | getAffineSymbolExpr(0)); |
331 | 0 | } |
332 | | |
333 | 0 | AffineMap Builder::getSingleDimShiftAffineMap(int64_t shift) { |
334 | 0 | // expr = d0 + shift. |
335 | 0 | auto expr = getAffineDimExpr(0) + shift; |
336 | 0 | return AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr); |
337 | 0 | } |
338 | | |
339 | 0 | AffineMap Builder::getShiftedAffineMap(AffineMap map, int64_t shift) { |
340 | 0 | SmallVector<AffineExpr, 4> shiftedResults; |
341 | 0 | shiftedResults.reserve(map.getNumResults()); |
342 | 0 | for (auto resultExpr : map.getResults()) |
343 | 0 | shiftedResults.push_back(resultExpr + shift); |
344 | 0 | return AffineMap::get(map.getNumDims(), map.getNumSymbols(), shiftedResults, |
345 | 0 | context); |
346 | 0 | } |
347 | | |
348 | | //===----------------------------------------------------------------------===// |
349 | | // OpBuilder |
350 | | //===----------------------------------------------------------------------===// |
351 | | |
352 | 0 | OpBuilder::Listener::~Listener() {} |
353 | | |
354 | | /// Insert the given operation at the current insertion point and return it. |
355 | 0 | Operation *OpBuilder::insert(Operation *op) { |
356 | 0 | if (block) |
357 | 0 | block->getOperations().insert(insertPoint, op); |
358 | 0 |
|
359 | 0 | if (listener) |
360 | 0 | listener->notifyOperationInserted(op); |
361 | 0 | return op; |
362 | 0 | } |
363 | | |
364 | | /// Add new block with 'argTypes' arguments and set the insertion point to the |
365 | | /// end of it. The block is inserted at the provided insertion point of |
366 | | /// 'parent'. |
367 | | Block *OpBuilder::createBlock(Region *parent, Region::iterator insertPt, |
368 | 0 | TypeRange argTypes) { |
369 | 0 | assert(parent && "expected valid parent region"); |
370 | 0 | if (insertPt == Region::iterator()) |
371 | 0 | insertPt = parent->end(); |
372 | 0 |
|
373 | 0 | Block *b = new Block(); |
374 | 0 | b->addArguments(argTypes); |
375 | 0 | parent->getBlocks().insert(insertPt, b); |
376 | 0 | setInsertionPointToEnd(b); |
377 | 0 |
|
378 | 0 | if (listener) |
379 | 0 | listener->notifyBlockCreated(b); |
380 | 0 | return b; |
381 | 0 | } |
382 | | |
383 | | /// Add new block with 'argTypes' arguments and set the insertion point to the |
384 | | /// end of it. The block is placed before 'insertBefore'. |
385 | 0 | Block *OpBuilder::createBlock(Block *insertBefore, TypeRange argTypes) { |
386 | 0 | assert(insertBefore && "expected valid insertion block"); |
387 | 0 | return createBlock(insertBefore->getParent(), Region::iterator(insertBefore), |
388 | 0 | argTypes); |
389 | 0 | } |
390 | | |
391 | | /// Create an operation given the fields represented as an OperationState. |
392 | 0 | Operation *OpBuilder::createOperation(const OperationState &state) { |
393 | 0 | return insert(Operation::create(state)); |
394 | 0 | } |
395 | | |
396 | | /// Attempts to fold the given operation and places new results within |
397 | | /// 'results'. Returns success if the operation was folded, failure otherwise. |
398 | | /// Note: This function does not erase the operation on a successful fold. |
399 | | LogicalResult OpBuilder::tryFold(Operation *op, |
400 | 0 | SmallVectorImpl<Value> &results) { |
401 | 0 | results.reserve(op->getNumResults()); |
402 | 0 | auto cleanupFailure = [&] { |
403 | 0 | results.assign(op->result_begin(), op->result_end()); |
404 | 0 | return failure(); |
405 | 0 | }; |
406 | 0 |
|
407 | 0 | // If this operation is already a constant, there is nothing to do. |
408 | 0 | if (matchPattern(op, m_Constant())) |
409 | 0 | return cleanupFailure(); |
410 | 0 | |
411 | 0 | // Check to see if any operands to the operation is constant and whether |
412 | 0 | // the operation knows how to constant fold itself. |
413 | 0 | SmallVector<Attribute, 4> constOperands(op->getNumOperands()); |
414 | 0 | for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i) |
415 | 0 | matchPattern(op->getOperand(i), m_Constant(&constOperands[i])); |
416 | 0 |
|
417 | 0 | // Try to fold the operation. |
418 | 0 | SmallVector<OpFoldResult, 4> foldResults; |
419 | 0 | if (failed(op->fold(constOperands, foldResults)) || foldResults.empty()) |
420 | 0 | return cleanupFailure(); |
421 | 0 | |
422 | 0 | // A temporary builder used for creating constants during folding. |
423 | 0 | OpBuilder cstBuilder(context); |
424 | 0 | SmallVector<Operation *, 1> generatedConstants; |
425 | 0 |
|
426 | 0 | // Populate the results with the folded results. |
427 | 0 | Dialect *dialect = op->getDialect(); |
428 | 0 | for (auto &it : llvm::enumerate(foldResults)) { |
429 | 0 | // Normal values get pushed back directly. |
430 | 0 | if (auto value = it.value().dyn_cast<Value>()) { |
431 | 0 | results.push_back(value); |
432 | 0 | continue; |
433 | 0 | } |
434 | 0 | |
435 | 0 | // Otherwise, try to materialize a constant operation. |
436 | 0 | if (!dialect) |
437 | 0 | return cleanupFailure(); |
438 | 0 | |
439 | 0 | // Ask the dialect to materialize a constant operation for this value. |
440 | 0 | Attribute attr = it.value().get<Attribute>(); |
441 | 0 | auto *constOp = dialect->materializeConstant( |
442 | 0 | cstBuilder, attr, op->getResult(it.index()).getType(), op->getLoc()); |
443 | 0 | if (!constOp) { |
444 | 0 | // Erase any generated constants. |
445 | 0 | for (Operation *cst : generatedConstants) |
446 | 0 | cst->erase(); |
447 | 0 | return cleanupFailure(); |
448 | 0 | } |
449 | 0 | assert(matchPattern(constOp, m_Constant())); |
450 | 0 |
|
451 | 0 | generatedConstants.push_back(constOp); |
452 | 0 | results.push_back(constOp->getResult(0)); |
453 | 0 | } |
454 | 0 |
|
455 | 0 | // If we were successful, insert any generated constants. |
456 | 0 | for (Operation *cst : generatedConstants) |
457 | 0 | insert(cst); |
458 | 0 |
|
459 | 0 | return success(); |
460 | 0 | } |