/home/arjun/llvm-project/mlir/include/mlir/IR/Builders.h
Line | Count | Source (jump to first uncovered line) |
1 | | //===- Builders.h - Helpers for constructing MLIR Classes -------*- C++ -*-===// |
2 | | // |
3 | | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | | // See https://llvm.org/LICENSE.txt for license information. |
5 | | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | | // |
7 | | //===----------------------------------------------------------------------===// |
8 | | |
9 | | #ifndef MLIR_IR_BUILDERS_H |
10 | | #define MLIR_IR_BUILDERS_H |
11 | | |
12 | | #include "mlir/IR/OpDefinition.h" |
13 | | |
14 | | namespace mlir { |
15 | | |
16 | | class AffineExpr; |
17 | | class BlockAndValueMapping; |
18 | | class ModuleOp; |
19 | | class UnknownLoc; |
20 | | class FileLineColLoc; |
21 | | class Type; |
22 | | class PrimitiveType; |
23 | | class IntegerType; |
24 | | class FunctionType; |
25 | | class MemRefType; |
26 | | class VectorType; |
27 | | class RankedTensorType; |
28 | | class UnrankedTensorType; |
29 | | class TupleType; |
30 | | class NoneType; |
31 | | class BoolAttr; |
32 | | class IntegerAttr; |
33 | | class FloatAttr; |
34 | | class StringAttr; |
35 | | class TypeAttr; |
36 | | class ArrayAttr; |
37 | | class SymbolRefAttr; |
38 | | class ElementsAttr; |
39 | | class DenseElementsAttr; |
40 | | class DenseIntElementsAttr; |
41 | | class AffineMapAttr; |
42 | | class AffineMap; |
43 | | class UnitAttr; |
44 | | |
45 | | /// This class is a general helper class for creating context-global objects |
46 | | /// like types, attributes, and affine expressions. |
47 | | class Builder { |
48 | | public: |
49 | 0 | explicit Builder(MLIRContext *context) : context(context) {} |
50 | | explicit Builder(ModuleOp module); |
51 | | |
52 | 0 | MLIRContext *getContext() const { return context; } |
53 | | |
54 | | Identifier getIdentifier(StringRef str); |
55 | | |
56 | | // Locations. |
57 | | Location getUnknownLoc(); |
58 | | Location getFileLineColLoc(Identifier filename, unsigned line, |
59 | | unsigned column); |
60 | | Location getFusedLoc(ArrayRef<Location> locs, |
61 | | Attribute metadata = Attribute()); |
62 | | |
63 | | // Types. |
64 | | FloatType getBF16Type(); |
65 | | FloatType getF16Type(); |
66 | | FloatType getF32Type(); |
67 | | FloatType getF64Type(); |
68 | | |
69 | | IndexType getIndexType(); |
70 | | |
71 | | IntegerType getI1Type(); |
72 | | IntegerType getI32Type(); |
73 | | IntegerType getI64Type(); |
74 | | IntegerType getIntegerType(unsigned width); |
75 | | IntegerType getIntegerType(unsigned width, bool isSigned); |
76 | | FunctionType getFunctionType(ArrayRef<Type> inputs, ArrayRef<Type> results); |
77 | | TupleType getTupleType(ArrayRef<Type> elementTypes); |
78 | | NoneType getNoneType(); |
79 | | |
80 | | /// Get or construct an instance of the type 'ty' with provided arguments. |
81 | 0 | template <typename Ty, typename... Args> Ty getType(Args... args) { |
82 | 0 | return Ty::get(context, args...); |
83 | 0 | } |
84 | | |
85 | | // Attributes. |
86 | | NamedAttribute getNamedAttr(StringRef name, Attribute val); |
87 | | |
88 | | UnitAttr getUnitAttr(); |
89 | | BoolAttr getBoolAttr(bool value); |
90 | | DictionaryAttr getDictionaryAttr(ArrayRef<NamedAttribute> value); |
91 | | IntegerAttr getIntegerAttr(Type type, int64_t value); |
92 | | IntegerAttr getIntegerAttr(Type type, const APInt &value); |
93 | | FloatAttr getFloatAttr(Type type, double value); |
94 | | FloatAttr getFloatAttr(Type type, const APFloat &value); |
95 | | StringAttr getStringAttr(StringRef bytes); |
96 | | ArrayAttr getArrayAttr(ArrayRef<Attribute> value); |
97 | | FlatSymbolRefAttr getSymbolRefAttr(Operation *value); |
98 | | FlatSymbolRefAttr getSymbolRefAttr(StringRef value); |
99 | | SymbolRefAttr getSymbolRefAttr(StringRef value, |
100 | | ArrayRef<FlatSymbolRefAttr> nestedReferences); |
101 | | |
102 | | // Returns a 0-valued attribute of the given `type`. This function only |
103 | | // supports boolean, integer, and 16-/32-/64-bit float types, and vector or |
104 | | // ranked tensor of them. Returns null attribute otherwise. |
105 | | Attribute getZeroAttr(Type type); |
106 | | |
107 | | // Convenience methods for fixed types. |
108 | | FloatAttr getF16FloatAttr(float value); |
109 | | FloatAttr getF32FloatAttr(float value); |
110 | | FloatAttr getF64FloatAttr(double value); |
111 | | |
112 | | IntegerAttr getI8IntegerAttr(int8_t value); |
113 | | IntegerAttr getI16IntegerAttr(int16_t value); |
114 | | IntegerAttr getI32IntegerAttr(int32_t value); |
115 | | IntegerAttr getI64IntegerAttr(int64_t value); |
116 | | IntegerAttr getIndexAttr(int64_t value); |
117 | | |
118 | | /// Signed and unsigned integer attribute getters. |
119 | | IntegerAttr getSI32IntegerAttr(int32_t value); |
120 | | IntegerAttr getUI32IntegerAttr(uint32_t value); |
121 | | |
122 | | /// Vector-typed DenseIntElementsAttr getters. `values` must not be empty. |
123 | | DenseIntElementsAttr getI32VectorAttr(ArrayRef<int32_t> values); |
124 | | DenseIntElementsAttr getI64VectorAttr(ArrayRef<int64_t> values); |
125 | | |
126 | | /// Tensor-typed DenseIntElementsAttr getters. `values` can be empty. |
127 | | /// These are generally preferable for representing general lists of integers |
128 | | /// as attributes. |
129 | | DenseIntElementsAttr getI32TensorAttr(ArrayRef<int32_t> values); |
130 | | DenseIntElementsAttr getI64TensorAttr(ArrayRef<int64_t> values); |
131 | | DenseIntElementsAttr getIndexTensorAttr(ArrayRef<int64_t> values); |
132 | | |
133 | | ArrayAttr getAffineMapArrayAttr(ArrayRef<AffineMap> values); |
134 | | ArrayAttr getBoolArrayAttr(ArrayRef<bool> values); |
135 | | ArrayAttr getI32ArrayAttr(ArrayRef<int32_t> values); |
136 | | ArrayAttr getI64ArrayAttr(ArrayRef<int64_t> values); |
137 | | ArrayAttr getIndexArrayAttr(ArrayRef<int64_t> values); |
138 | | ArrayAttr getF32ArrayAttr(ArrayRef<float> values); |
139 | | ArrayAttr getF64ArrayAttr(ArrayRef<double> values); |
140 | | ArrayAttr getStrArrayAttr(ArrayRef<StringRef> values); |
141 | | |
142 | | // Affine expressions and affine maps. |
143 | | AffineExpr getAffineDimExpr(unsigned position); |
144 | | AffineExpr getAffineSymbolExpr(unsigned position); |
145 | | AffineExpr getAffineConstantExpr(int64_t constant); |
146 | | |
147 | | // Special cases of affine maps and integer sets |
148 | | /// Returns a zero result affine map with no dimensions or symbols: () -> (). |
149 | | AffineMap getEmptyAffineMap(); |
150 | | /// Returns a single constant result affine map with 0 dimensions and 0 |
151 | | /// symbols. One constant result: () -> (val). |
152 | | AffineMap getConstantAffineMap(int64_t val); |
153 | | // One dimension id identity map: (i) -> (i). |
154 | | AffineMap getDimIdentityMap(); |
155 | | // Multi-dimensional identity map: (d0, d1, d2) -> (d0, d1, d2). |
156 | | AffineMap getMultiDimIdentityMap(unsigned rank); |
157 | | // One symbol identity map: ()[s] -> (s). |
158 | | AffineMap getSymbolIdentityMap(); |
159 | | |
160 | | /// Returns a map that shifts its (single) input dimension by 'shift'. |
161 | | /// (d0) -> (d0 + shift) |
162 | | AffineMap getSingleDimShiftAffineMap(int64_t shift); |
163 | | |
164 | | /// Returns an affine map that is a translation (shift) of all result |
165 | | /// expressions in 'map' by 'shift'. |
166 | | /// Eg: input: (d0, d1)[s0] -> (d0, d1 + s0), shift = 2 |
167 | | /// returns: (d0, d1)[s0] -> (d0 + 2, d1 + s0 + 2) |
168 | | AffineMap getShiftedAffineMap(AffineMap map, int64_t shift); |
169 | | |
170 | | protected: |
171 | | MLIRContext *context; |
172 | | }; |
173 | | |
174 | | /// This class helps build Operations. Operations that are created are |
175 | | /// automatically inserted at an insertion point. The builder is copyable. |
176 | | class OpBuilder : public Builder { |
177 | | public: |
178 | | struct Listener; |
179 | | |
180 | | /// Create a builder with the given context. |
181 | | explicit OpBuilder(MLIRContext *ctx, Listener *listener = nullptr) |
182 | 0 | : Builder(ctx), listener(listener) {} |
183 | | |
184 | | /// Create a builder and set the insertion point to the start of the region. |
185 | | explicit OpBuilder(Region *region, Listener *listener = nullptr) |
186 | 0 | : OpBuilder(region->getContext(), listener) { |
187 | 0 | if (!region->empty()) |
188 | 0 | setInsertionPoint(®ion->front(), region->front().begin()); |
189 | 0 | } |
190 | | explicit OpBuilder(Region ®ion, Listener *listener = nullptr) |
191 | 0 | : OpBuilder(®ion, listener) {} |
192 | | |
193 | | /// Create a builder and set insertion point to the given operation, which |
194 | | /// will cause subsequent insertions to go right before it. |
195 | | explicit OpBuilder(Operation *op, Listener *listener = nullptr) |
196 | 0 | : OpBuilder(op->getContext(), listener) { |
197 | 0 | setInsertionPoint(op); |
198 | 0 | } |
199 | | |
200 | | OpBuilder(Block *block, Block::iterator insertPoint, |
201 | | Listener *listener = nullptr) |
202 | 0 | : OpBuilder(block->getParent()->getContext(), listener) { |
203 | 0 | setInsertionPoint(block, insertPoint); |
204 | 0 | } |
205 | | |
206 | | /// Create a builder and set the insertion point to before the first operation |
207 | | /// in the block but still inside the block. |
208 | 0 | static OpBuilder atBlockBegin(Block *block, Listener *listener = nullptr) { |
209 | 0 | return OpBuilder(block, block->begin(), listener); |
210 | 0 | } |
211 | | |
212 | | /// Create a builder and set the insertion point to after the last operation |
213 | | /// in the block but still inside the block. |
214 | 0 | static OpBuilder atBlockEnd(Block *block, Listener *listener = nullptr) { |
215 | 0 | return OpBuilder(block, block->end(), listener); |
216 | 0 | } |
217 | | |
218 | | /// Create a builder and set the insertion point to before the block |
219 | | /// terminator. |
220 | | static OpBuilder atBlockTerminator(Block *block, |
221 | 0 | Listener *listener = nullptr) { |
222 | 0 | auto *terminator = block->getTerminator(); |
223 | 0 | assert(terminator != nullptr && "the block has no terminator"); |
224 | 0 | return OpBuilder(block, Block::iterator(terminator), listener); |
225 | 0 | } |
226 | | |
227 | | //===--------------------------------------------------------------------===// |
228 | | // Listeners |
229 | | //===--------------------------------------------------------------------===// |
230 | | |
231 | | /// This class represents a listener that may be used to hook into various |
232 | | /// actions within an OpBuilder. |
233 | | struct Listener { |
234 | | virtual ~Listener(); |
235 | | |
236 | | /// Notification handler for when an operation is inserted into the builder. |
237 | | /// `op` is the operation that was inserted. |
238 | 0 | virtual void notifyOperationInserted(Operation *op) {} |
239 | | |
240 | | /// Notification handler for when a block is created using the builder. |
241 | | /// `block` is the block that was created. |
242 | 0 | virtual void notifyBlockCreated(Block *block) {} |
243 | | }; |
244 | | |
245 | | /// Sets the listener of this builder to the one provided. |
246 | 0 | void setListener(Listener *newListener) { listener = newListener; } |
247 | | |
248 | | /// Returns the current listener of this builder, or nullptr if this builder |
249 | | /// doesn't have a listener. |
250 | 0 | Listener *getListener() const { return listener; } |
251 | | |
252 | | //===--------------------------------------------------------------------===// |
253 | | // Insertion Point Management |
254 | | //===--------------------------------------------------------------------===// |
255 | | |
256 | | /// This class represents a saved insertion point. |
257 | | class InsertPoint { |
258 | | public: |
259 | | /// Creates a new insertion point which doesn't point to anything. |
260 | | InsertPoint() = default; |
261 | | |
262 | | /// Creates a new insertion point at the given location. |
263 | | InsertPoint(Block *insertBlock, Block::iterator insertPt) |
264 | 0 | : block(insertBlock), point(insertPt) {} |
265 | | |
266 | | /// Returns true if this insert point is set. |
267 | 0 | bool isSet() const { return (block != nullptr); } |
268 | | |
269 | 0 | Block *getBlock() const { return block; } |
270 | 0 | Block::iterator getPoint() const { return point; } |
271 | | |
272 | | private: |
273 | | Block *block = nullptr; |
274 | | Block::iterator point; |
275 | | }; |
276 | | |
277 | | /// RAII guard to reset the insertion point of the builder when destroyed. |
278 | | class InsertionGuard { |
279 | | public: |
280 | | InsertionGuard(OpBuilder &builder) |
281 | 0 | : builder(builder), ip(builder.saveInsertionPoint()) {} |
282 | 0 | ~InsertionGuard() { builder.restoreInsertionPoint(ip); } |
283 | | |
284 | | private: |
285 | | OpBuilder &builder; |
286 | | OpBuilder::InsertPoint ip; |
287 | | }; |
288 | | |
289 | | /// Reset the insertion point to no location. Creating an operation without a |
290 | | /// set insertion point is an error, but this can still be useful when the |
291 | | /// current insertion point a builder refers to is being removed. |
292 | 0 | void clearInsertionPoint() { |
293 | 0 | this->block = nullptr; |
294 | 0 | insertPoint = Block::iterator(); |
295 | 0 | } |
296 | | |
297 | | /// Return a saved insertion point. |
298 | 0 | InsertPoint saveInsertionPoint() const { |
299 | 0 | return InsertPoint(getInsertionBlock(), getInsertionPoint()); |
300 | 0 | } |
301 | | |
302 | | /// Restore the insert point to a previously saved point. |
303 | 0 | void restoreInsertionPoint(InsertPoint ip) { |
304 | 0 | if (ip.isSet()) |
305 | 0 | setInsertionPoint(ip.getBlock(), ip.getPoint()); |
306 | 0 | else |
307 | 0 | clearInsertionPoint(); |
308 | 0 | } |
309 | | |
310 | | /// Set the insertion point to the specified location. |
311 | 0 | void setInsertionPoint(Block *block, Block::iterator insertPoint) { |
312 | 0 | // TODO: check that insertPoint is in this rather than some other block. |
313 | 0 | this->block = block; |
314 | 0 | this->insertPoint = insertPoint; |
315 | 0 | } |
316 | | |
317 | | /// Sets the insertion point to the specified operation, which will cause |
318 | | /// subsequent insertions to go right before it. |
319 | 0 | void setInsertionPoint(Operation *op) { |
320 | 0 | setInsertionPoint(op->getBlock(), Block::iterator(op)); |
321 | 0 | } |
322 | | |
323 | | /// Sets the insertion point to the node after the specified operation, which |
324 | | /// will cause subsequent insertions to go right after it. |
325 | 0 | void setInsertionPointAfter(Operation *op) { |
326 | 0 | setInsertionPoint(op->getBlock(), ++Block::iterator(op)); |
327 | 0 | } |
328 | | |
329 | | /// Sets the insertion point to the start of the specified block. |
330 | 0 | void setInsertionPointToStart(Block *block) { |
331 | 0 | setInsertionPoint(block, block->begin()); |
332 | 0 | } |
333 | | |
334 | | /// Sets the insertion point to the end of the specified block. |
335 | 0 | void setInsertionPointToEnd(Block *block) { |
336 | 0 | setInsertionPoint(block, block->end()); |
337 | 0 | } |
338 | | |
339 | | /// Return the block the current insertion point belongs to. Note that the |
340 | | /// the insertion point is not necessarily the end of the block. |
341 | 0 | Block *getInsertionBlock() const { return block; } |
342 | | |
343 | | /// Returns the current insertion point of the builder. |
344 | 0 | Block::iterator getInsertionPoint() const { return insertPoint; } |
345 | | |
346 | | /// Returns the current block of the builder. |
347 | 0 | Block *getBlock() const { return block; } |
348 | | |
349 | | //===--------------------------------------------------------------------===// |
350 | | // Block Creation |
351 | | //===--------------------------------------------------------------------===// |
352 | | |
353 | | /// Add new block with 'argTypes' arguments and set the insertion point to the |
354 | | /// end of it. The block is inserted at the provided insertion point of |
355 | | /// 'parent'. |
356 | | Block *createBlock(Region *parent, Region::iterator insertPt = {}, |
357 | | TypeRange argTypes = llvm::None); |
358 | | |
359 | | /// Add new block with 'argTypes' arguments and set the insertion point to the |
360 | | /// end of it. The block is placed before 'insertBefore'. |
361 | | Block *createBlock(Block *insertBefore, TypeRange argTypes = llvm::None); |
362 | | |
363 | | //===--------------------------------------------------------------------===// |
364 | | // Operation Creation |
365 | | //===--------------------------------------------------------------------===// |
366 | | |
367 | | /// Insert the given operation at the current insertion point and return it. |
368 | | Operation *insert(Operation *op); |
369 | | |
370 | | /// Creates an operation given the fields represented as an OperationState. |
371 | | Operation *createOperation(const OperationState &state); |
372 | | |
373 | | /// Create an operation of specific op type at the current insertion point. |
374 | | template <typename OpTy, typename... Args> |
375 | 0 | OpTy create(Location location, Args &&... args) { |
376 | 0 | OperationState state(location, OpTy::getOperationName()); |
377 | 0 | if (!state.name.getAbstractOperation()) |
378 | 0 | llvm::report_fatal_error("Building op `" + |
379 | 0 | state.name.getStringRef().str() + |
380 | 0 | "` but it isn't registered in this MLIRContext"); |
381 | 0 | OpTy::build(*this, state, std::forward<Args>(args)...); |
382 | 0 | auto *op = createOperation(state); |
383 | 0 | auto result = dyn_cast<OpTy>(op); |
384 | 0 | assert(result && "builder didn't return the right type"); |
385 | 0 | return result; |
386 | 0 | } Unexecuted instantiation: _ZN4mlir9OpBuilder6createINS_10ConstantOpEJRNS_4TypeERNS_9AttributeEEEET_NS_8LocationEDpOT0_ Unexecuted instantiation: _ZN4mlir9OpBuilder6createINS_13AffineApplyOpEJRNS_9AffineMapERN4llvm11SmallVectorINS_5ValueELj8EEEEEET_NS_8LocationEDpOT0_ Unexecuted instantiation: _ZN4mlir9OpBuilder6createINS_8BranchOpEJRPNS_5BlockENS_12OperandRangeEEEET_NS_8LocationEDpOT0_ Unexecuted instantiation: _ZN4mlir9OpBuilder6createINS_15ConstantIndexOpEJlEEET_NS_8LocationEDpOT0_ Unexecuted instantiation: _ZN4mlir9OpBuilder6createINS_15ConstantIndexOpEJRlEEET_NS_8LocationEDpOT0_ |
387 | | |
388 | | /// Create an operation of specific op type at the current insertion point, |
389 | | /// and immediately try to fold it. This functions populates 'results' with |
390 | | /// the results after folding the operation. |
391 | | template <typename OpTy, typename... Args> |
392 | | void createOrFold(SmallVectorImpl<Value> &results, Location location, |
393 | | Args &&... args) { |
394 | | // Create the operation without using 'createOperation' as we don't want to |
395 | | // insert it yet. |
396 | | OperationState state(location, OpTy::getOperationName()); |
397 | | if (!state.name.getAbstractOperation()) |
398 | | llvm::report_fatal_error("Building op `" + |
399 | | state.name.getStringRef().str() + |
400 | | "` but it isn't registered in this MLIRContext"); |
401 | | OpTy::build(*this, state, std::forward<Args>(args)...); |
402 | | Operation *op = Operation::create(state); |
403 | | |
404 | | // Fold the operation. If successful destroy it, otherwise insert it. |
405 | | if (succeeded(tryFold(op, results))) |
406 | | op->destroy(); |
407 | | else |
408 | | insert(op); |
409 | | } |
410 | | |
411 | | /// Overload to create or fold a single result operation. |
412 | | template <typename OpTy, typename... Args> |
413 | | typename std::enable_if<OpTy::template hasTrait<OpTrait::OneResult>(), |
414 | | Value>::type |
415 | | createOrFold(Location location, Args &&... args) { |
416 | | SmallVector<Value, 1> results; |
417 | | createOrFold<OpTy>(results, location, std::forward<Args>(args)...); |
418 | | return results.front(); |
419 | | } |
420 | | |
421 | | /// Overload to create or fold a zero result operation. |
422 | | template <typename OpTy, typename... Args> |
423 | | typename std::enable_if<OpTy::template hasTrait<OpTrait::ZeroResult>(), |
424 | | OpTy>::type |
425 | | createOrFold(Location location, Args &&... args) { |
426 | | auto op = create<OpTy>(location, std::forward<Args>(args)...); |
427 | | SmallVector<Value, 0> unused; |
428 | | tryFold(op.getOperation(), unused); |
429 | | |
430 | | // Folding cannot remove a zero-result operation, so for convenience we |
431 | | // continue to return it. |
432 | | return op; |
433 | | } |
434 | | |
435 | | /// Attempts to fold the given operation and places new results within |
436 | | /// 'results'. Returns success if the operation was folded, failure otherwise. |
437 | | /// Note: This function does not erase the operation on a successful fold. |
438 | | LogicalResult tryFold(Operation *op, SmallVectorImpl<Value> &results); |
439 | | |
440 | | /// Creates a deep copy of the specified operation, remapping any operands |
441 | | /// that use values outside of the operation using the map that is provided |
442 | | /// ( leaving them alone if no entry is present). Replaces references to |
443 | | /// cloned sub-operations to the corresponding operation that is copied, |
444 | | /// and adds those mappings to the map. |
445 | 0 | Operation *clone(Operation &op, BlockAndValueMapping &mapper) { |
446 | 0 | return insert(op.clone(mapper)); |
447 | 0 | } |
448 | 0 | Operation *clone(Operation &op) { return insert(op.clone()); } |
449 | | |
450 | | /// Creates a deep copy of this operation but keep the operation regions |
451 | | /// empty. Operands are remapped using `mapper` (if present), and `mapper` is |
452 | | /// updated to contain the results. |
453 | 0 | Operation *cloneWithoutRegions(Operation &op, BlockAndValueMapping &mapper) { |
454 | 0 | return insert(op.cloneWithoutRegions(mapper)); |
455 | 0 | } |
456 | 0 | Operation *cloneWithoutRegions(Operation &op) { |
457 | 0 | return insert(op.cloneWithoutRegions()); |
458 | 0 | } |
459 | | template <typename OpT> OpT cloneWithoutRegions(OpT op) { |
460 | | return cast<OpT>(cloneWithoutRegions(*op.getOperation())); |
461 | | } |
462 | | |
463 | | private: |
464 | | /// The current block this builder is inserting into. |
465 | | Block *block = nullptr; |
466 | | /// The insertion point within the block that this builder is inserting |
467 | | /// before. |
468 | | Block::iterator insertPoint; |
469 | | /// The optional listener for events of this builder. |
470 | | Listener *listener; |
471 | | }; |
472 | | |
473 | | } // namespace mlir |
474 | | |
475 | | #endif |