Coverage Report

Created: 2020-06-26 05:44

/home/arjun/llvm-project/mlir/include/mlir/IR/Builders.h
Line
Count
Source (jump to first uncovered line)
1
//===- Builders.h - Helpers for constructing MLIR Classes -------*- C++ -*-===//
2
//
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
// See https://llvm.org/LICENSE.txt for license information.
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
//
7
//===----------------------------------------------------------------------===//
8
9
#ifndef MLIR_IR_BUILDERS_H
10
#define MLIR_IR_BUILDERS_H
11
12
#include "mlir/IR/OpDefinition.h"
13
14
namespace mlir {
15
16
class AffineExpr;
17
class BlockAndValueMapping;
18
class ModuleOp;
19
class UnknownLoc;
20
class FileLineColLoc;
21
class Type;
22
class PrimitiveType;
23
class IntegerType;
24
class FunctionType;
25
class MemRefType;
26
class VectorType;
27
class RankedTensorType;
28
class UnrankedTensorType;
29
class TupleType;
30
class NoneType;
31
class BoolAttr;
32
class IntegerAttr;
33
class FloatAttr;
34
class StringAttr;
35
class TypeAttr;
36
class ArrayAttr;
37
class SymbolRefAttr;
38
class ElementsAttr;
39
class DenseElementsAttr;
40
class DenseIntElementsAttr;
41
class AffineMapAttr;
42
class AffineMap;
43
class UnitAttr;
44
45
/// This class is a general helper class for creating context-global objects
46
/// like types, attributes, and affine expressions.
47
class Builder {
48
public:
49
0
  explicit Builder(MLIRContext *context) : context(context) {}
50
  explicit Builder(ModuleOp module);
51
52
0
  MLIRContext *getContext() const { return context; }
53
54
  Identifier getIdentifier(StringRef str);
55
56
  // Locations.
57
  Location getUnknownLoc();
58
  Location getFileLineColLoc(Identifier filename, unsigned line,
59
                             unsigned column);
60
  Location getFusedLoc(ArrayRef<Location> locs,
61
                       Attribute metadata = Attribute());
62
63
  // Types.
64
  FloatType getBF16Type();
65
  FloatType getF16Type();
66
  FloatType getF32Type();
67
  FloatType getF64Type();
68
69
  IndexType getIndexType();
70
71
  IntegerType getI1Type();
72
  IntegerType getI32Type();
73
  IntegerType getI64Type();
74
  IntegerType getIntegerType(unsigned width);
75
  IntegerType getIntegerType(unsigned width, bool isSigned);
76
  FunctionType getFunctionType(ArrayRef<Type> inputs, ArrayRef<Type> results);
77
  TupleType getTupleType(ArrayRef<Type> elementTypes);
78
  NoneType getNoneType();
79
80
  /// Get or construct an instance of the type 'Ty' with provided arguments.
81
0
  template <typename Ty, typename... Args> Ty getType(Args... args) {
82
0
    return Ty::get(context, args...);
83
0
  }
84
85
  // Attributes.
86
  NamedAttribute getNamedAttr(StringRef name, Attribute val);
87
88
  UnitAttr getUnitAttr();
89
  BoolAttr getBoolAttr(bool value);
90
  DictionaryAttr getDictionaryAttr(ArrayRef<NamedAttribute> value);
91
  IntegerAttr getIntegerAttr(Type type, int64_t value);
92
  IntegerAttr getIntegerAttr(Type type, const APInt &value);
93
  FloatAttr getFloatAttr(Type type, double value);
94
  FloatAttr getFloatAttr(Type type, const APFloat &value);
95
  StringAttr getStringAttr(StringRef bytes);
96
  ArrayAttr getArrayAttr(ArrayRef<Attribute> value);
97
  FlatSymbolRefAttr getSymbolRefAttr(Operation *value);
98
  FlatSymbolRefAttr getSymbolRefAttr(StringRef value);
99
  SymbolRefAttr getSymbolRefAttr(StringRef value,
100
                                 ArrayRef<FlatSymbolRefAttr> nestedReferences);
101
102
  // Returns a 0-valued attribute of the given `type`. This function only
103
  // supports boolean, integer, and 16-/32-/64-bit float types, and vector or
104
  // ranked tensor of them. Returns null attribute otherwise.
105
  Attribute getZeroAttr(Type type);
106
107
  // Convenience methods for fixed types.
108
  FloatAttr getF16FloatAttr(float value);
109
  FloatAttr getF32FloatAttr(float value);
110
  FloatAttr getF64FloatAttr(double value);
111
112
  IntegerAttr getI8IntegerAttr(int8_t value);
113
  IntegerAttr getI16IntegerAttr(int16_t value);
114
  IntegerAttr getI32IntegerAttr(int32_t value);
115
  IntegerAttr getI64IntegerAttr(int64_t value);
116
  IntegerAttr getIndexAttr(int64_t value);
117
118
  /// Signed and unsigned integer attribute getters.
119
  IntegerAttr getSI32IntegerAttr(int32_t value);
120
  IntegerAttr getUI32IntegerAttr(uint32_t value);
121
122
  /// Vector-typed DenseIntElementsAttr getters. `values` must not be empty.
123
  DenseIntElementsAttr getI32VectorAttr(ArrayRef<int32_t> values);
124
  DenseIntElementsAttr getI64VectorAttr(ArrayRef<int64_t> values);
125
126
  /// Tensor-typed DenseIntElementsAttr getters. `values` can be empty.
127
  /// These are generally preferable for representing general lists of integers
128
  /// as attributes.
129
  DenseIntElementsAttr getI32TensorAttr(ArrayRef<int32_t> values);
130
  DenseIntElementsAttr getI64TensorAttr(ArrayRef<int64_t> values);
131
  DenseIntElementsAttr getIndexTensorAttr(ArrayRef<int64_t> values);
132
133
  ArrayAttr getAffineMapArrayAttr(ArrayRef<AffineMap> values);
134
  ArrayAttr getBoolArrayAttr(ArrayRef<bool> values);
135
  ArrayAttr getI32ArrayAttr(ArrayRef<int32_t> values);
136
  ArrayAttr getI64ArrayAttr(ArrayRef<int64_t> values);
137
  ArrayAttr getIndexArrayAttr(ArrayRef<int64_t> values);
138
  ArrayAttr getF32ArrayAttr(ArrayRef<float> values);
139
  ArrayAttr getF64ArrayAttr(ArrayRef<double> values);
140
  ArrayAttr getStrArrayAttr(ArrayRef<StringRef> values);
141
142
  // Affine expressions and affine maps.
143
  AffineExpr getAffineDimExpr(unsigned position);
144
  AffineExpr getAffineSymbolExpr(unsigned position);
145
  AffineExpr getAffineConstantExpr(int64_t constant);
146
147
  // Special cases of affine maps and integer sets
148
  /// Returns a zero result affine map with no dimensions or symbols: () -> ().
149
  AffineMap getEmptyAffineMap();
150
  /// Returns a single constant result affine map with 0 dimensions and 0
151
  /// symbols.  One constant result: () -> (val).
152
  AffineMap getConstantAffineMap(int64_t val);
153
  // One dimension id identity map: (i) -> (i).
154
  AffineMap getDimIdentityMap();
155
  // Multi-dimensional identity map: (d0, d1, d2) -> (d0, d1, d2).
156
  AffineMap getMultiDimIdentityMap(unsigned rank);
157
  // One symbol identity map: ()[s] -> (s).
158
  AffineMap getSymbolIdentityMap();
159
160
  /// Returns a map that shifts its (single) input dimension by 'shift'.
161
  /// (d0) -> (d0 + shift)
162
  AffineMap getSingleDimShiftAffineMap(int64_t shift);
163
164
  /// Returns an affine map that is a translation (shift) of all result
165
  /// expressions in 'map' by 'shift'.
166
  /// Eg: input: (d0, d1)[s0] -> (d0, d1 + s0), shift = 2
167
  ///   returns:    (d0, d1)[s0] -> (d0 + 2, d1 + s0 + 2)
168
  AffineMap getShiftedAffineMap(AffineMap map, int64_t shift);
169
170
protected:
171
  MLIRContext *context;
172
};
173
174
/// This class helps build Operations. Operations that are created are
175
/// automatically inserted at an insertion point. The builder is copyable.
176
class OpBuilder : public Builder {
177
public:
178
  struct Listener;
179
180
  /// Create a builder with the given context.
181
  explicit OpBuilder(MLIRContext *ctx, Listener *listener = nullptr)
182
0
      : Builder(ctx), listener(listener) {}
183
184
  /// Create a builder and set the insertion point to the start of the region.
185
  explicit OpBuilder(Region *region, Listener *listener = nullptr)
186
0
      : OpBuilder(region->getContext(), listener) {
187
0
    if (!region->empty())
188
0
      setInsertionPoint(&region->front(), region->front().begin());
189
0
  }
190
  explicit OpBuilder(Region &region, Listener *listener = nullptr)
191
0
      : OpBuilder(&region, listener) {}
192
193
  /// Create a builder and set insertion point to the given operation, which
194
  /// will cause subsequent insertions to go right before it.
195
  explicit OpBuilder(Operation *op, Listener *listener = nullptr)
196
0
      : OpBuilder(op->getContext(), listener) {
197
0
    setInsertionPoint(op);
198
0
  }
199
200
  OpBuilder(Block *block, Block::iterator insertPoint,
201
            Listener *listener = nullptr)
202
0
      : OpBuilder(block->getParent()->getContext(), listener) {
203
0
    setInsertionPoint(block, insertPoint);
204
0
  }
205
206
  /// Create a builder and set the insertion point to before the first operation
207
  /// in the block but still inside the block.
208
0
  static OpBuilder atBlockBegin(Block *block, Listener *listener = nullptr) {
209
0
    return OpBuilder(block, block->begin(), listener);
210
0
  }
211
212
  /// Create a builder and set the insertion point to after the last operation
213
  /// in the block but still inside the block.
214
0
  static OpBuilder atBlockEnd(Block *block, Listener *listener = nullptr) {
215
0
    return OpBuilder(block, block->end(), listener);
216
0
  }
217
218
  /// Create a builder and set the insertion point to before the block
219
  /// terminator.
220
  static OpBuilder atBlockTerminator(Block *block,
221
0
                                     Listener *listener = nullptr) {
222
0
    auto *terminator = block->getTerminator();
223
0
    assert(terminator != nullptr && "the block has no terminator");
224
0
    return OpBuilder(block, Block::iterator(terminator), listener);
225
0
  }
226
227
  //===--------------------------------------------------------------------===//
228
  // Listeners
229
  //===--------------------------------------------------------------------===//
230
231
  /// This class represents a listener that may be used to hook into various
232
  /// actions within an OpBuilder.
233
  struct Listener {
234
    virtual ~Listener();
235
236
    /// Notification handler for when an operation is inserted into the builder.
237
    /// `op` is the operation that was inserted.
238
0
    virtual void notifyOperationInserted(Operation *op) {}
239
240
    /// Notification handler for when a block is created using the builder.
241
    /// `block` is the block that was created.
242
0
    virtual void notifyBlockCreated(Block *block) {}
243
  };
244
245
  /// Sets the listener of this builder to the one provided.
246
0
  void setListener(Listener *newListener) { listener = newListener; }
247
248
  /// Returns the current listener of this builder, or nullptr if this builder
249
  /// doesn't have a listener.
250
0
  Listener *getListener() const { return listener; }
251
252
  //===--------------------------------------------------------------------===//
253
  // Insertion Point Management
254
  //===--------------------------------------------------------------------===//
255
256
  /// This class represents a saved insertion point.
257
  class InsertPoint {
258
  public:
259
    /// Creates a new insertion point which doesn't point to anything.
260
    InsertPoint() = default;
261
262
    /// Creates a new insertion point at the given location.
263
    InsertPoint(Block *insertBlock, Block::iterator insertPt)
264
0
        : block(insertBlock), point(insertPt) {}
265
266
    /// Returns true if this insert point is set.
267
0
    bool isSet() const { return (block != nullptr); }
268
269
0
    Block *getBlock() const { return block; }
270
0
    Block::iterator getPoint() const { return point; }
271
272
  private:
273
    Block *block = nullptr;
274
    Block::iterator point;
275
  };
276
277
  /// RAII guard to reset the insertion point of the builder when destroyed.
278
  class InsertionGuard {
279
  public:
280
    InsertionGuard(OpBuilder &builder)
281
0
        : builder(builder), ip(builder.saveInsertionPoint()) {}
282
0
    ~InsertionGuard() { builder.restoreInsertionPoint(ip); }
283
284
  private:
285
    OpBuilder &builder;
286
    OpBuilder::InsertPoint ip;
287
  };
288
289
  /// Reset the insertion point to no location.  Creating an operation without a
290
  /// set insertion point is an error, but this can still be useful when the
291
  /// current insertion point a builder refers to is being removed.
292
0
  void clearInsertionPoint() {
293
0
    this->block = nullptr;
294
0
    insertPoint = Block::iterator();
295
0
  }
296
297
  /// Return a saved insertion point.
298
0
  InsertPoint saveInsertionPoint() const {
299
0
    return InsertPoint(getInsertionBlock(), getInsertionPoint());
300
0
  }
301
302
  /// Restore the insert point to a previously saved point.
303
0
  void restoreInsertionPoint(InsertPoint ip) {
304
0
    if (ip.isSet())
305
0
      setInsertionPoint(ip.getBlock(), ip.getPoint());
306
0
    else
307
0
      clearInsertionPoint();
308
0
  }
309
310
  /// Set the insertion point to the specified location.
311
0
  void setInsertionPoint(Block *block, Block::iterator insertPoint) {
312
0
    // TODO: check that insertPoint is in 'block' rather than some other block.
313
0
    this->block = block;
314
0
    this->insertPoint = insertPoint;
315
0
  }
316
317
  /// Sets the insertion point to the specified operation, which will cause
318
  /// subsequent insertions to go right before it.
319
0
  void setInsertionPoint(Operation *op) {
320
0
    setInsertionPoint(op->getBlock(), Block::iterator(op));
321
0
  }
322
323
  /// Sets the insertion point to the node after the specified operation, which
324
  /// will cause subsequent insertions to go right after it.
325
0
  void setInsertionPointAfter(Operation *op) {
326
0
    setInsertionPoint(op->getBlock(), ++Block::iterator(op));
327
0
  }
328
329
  /// Sets the insertion point to the start of the specified block.
330
0
  void setInsertionPointToStart(Block *block) {
331
0
    setInsertionPoint(block, block->begin());
332
0
  }
333
334
  /// Sets the insertion point to the end of the specified block.
335
0
  void setInsertionPointToEnd(Block *block) {
336
0
    setInsertionPoint(block, block->end());
337
0
  }
338
339
  /// Return the block the current insertion point belongs to.  Note that the
340
  /// insertion point is not necessarily the end of the block.
341
0
  Block *getInsertionBlock() const { return block; }
342
343
  /// Returns the current insertion point of the builder.
344
0
  Block::iterator getInsertionPoint() const { return insertPoint; }
345
346
  /// Returns the current block of the builder.
347
0
  Block *getBlock() const { return block; }
348
349
  //===--------------------------------------------------------------------===//
350
  // Block Creation
351
  //===--------------------------------------------------------------------===//
352
353
  /// Add new block with 'argTypes' arguments and set the insertion point to the
354
  /// end of it. The block is inserted at the provided insertion point of
355
  /// 'parent'.
356
  Block *createBlock(Region *parent, Region::iterator insertPt = {},
357
                     TypeRange argTypes = llvm::None);
358
359
  /// Add new block with 'argTypes' arguments and set the insertion point to the
360
  /// end of it. The block is placed before 'insertBefore'.
361
  Block *createBlock(Block *insertBefore, TypeRange argTypes = llvm::None);
362
363
  //===--------------------------------------------------------------------===//
364
  // Operation Creation
365
  //===--------------------------------------------------------------------===//
366
367
  /// Insert the given operation at the current insertion point and return it.
368
  Operation *insert(Operation *op);
369
370
  /// Creates an operation given the fields represented as an OperationState.
371
  Operation *createOperation(const OperationState &state);
372
373
  /// Create an operation of specific op type at the current insertion point.
374
  template <typename OpTy, typename... Args>
375
0
  OpTy create(Location location, Args &&... args) {
376
0
    OperationState state(location, OpTy::getOperationName());
377
0
    if (!state.name.getAbstractOperation())
378
0
      llvm::report_fatal_error("Building op `" +
379
0
                               state.name.getStringRef().str() +
380
0
                               "` but it isn't registered in this MLIRContext");
381
0
    OpTy::build(*this, state, std::forward<Args>(args)...);
382
0
    auto *op = createOperation(state);
383
0
    auto result = dyn_cast<OpTy>(op);
384
0
    assert(result && "builder didn't return the right type");
385
0
    return result;
386
0
  }
Unexecuted instantiation: _ZN4mlir9OpBuilder6createINS_10ConstantOpEJRNS_4TypeERNS_9AttributeEEEET_NS_8LocationEDpOT0_
Unexecuted instantiation: _ZN4mlir9OpBuilder6createINS_13AffineApplyOpEJRNS_9AffineMapERN4llvm11SmallVectorINS_5ValueELj8EEEEEET_NS_8LocationEDpOT0_
Unexecuted instantiation: _ZN4mlir9OpBuilder6createINS_8BranchOpEJRPNS_5BlockENS_12OperandRangeEEEET_NS_8LocationEDpOT0_
Unexecuted instantiation: _ZN4mlir9OpBuilder6createINS_15ConstantIndexOpEJlEEET_NS_8LocationEDpOT0_
Unexecuted instantiation: _ZN4mlir9OpBuilder6createINS_15ConstantIndexOpEJRlEEET_NS_8LocationEDpOT0_
387
388
  /// Create an operation of specific op type at the current insertion point,
389
  /// and immediately try to fold it. This function populates 'results' with
390
  /// the results after folding the operation.
391
  template <typename OpTy, typename... Args>
392
  void createOrFold(SmallVectorImpl<Value> &results, Location location,
393
                    Args &&... args) {
394
    // Create the operation without using 'createOperation' as we don't want to
395
    // insert it yet.
396
    OperationState state(location, OpTy::getOperationName());
397
    if (!state.name.getAbstractOperation())
398
      llvm::report_fatal_error("Building op `" +
399
                               state.name.getStringRef().str() +
400
                               "` but it isn't registered in this MLIRContext");
401
    OpTy::build(*this, state, std::forward<Args>(args)...);
402
    Operation *op = Operation::create(state);
403
404
    // Fold the operation. If successful destroy it, otherwise insert it.
405
    if (succeeded(tryFold(op, results)))
406
      op->destroy();
407
    else
408
      insert(op);
409
  }
410
411
  /// Overload to create or fold a single result operation.
412
  template <typename OpTy, typename... Args>
413
  typename std::enable_if<OpTy::template hasTrait<OpTrait::OneResult>(),
414
                          Value>::type
415
  createOrFold(Location location, Args &&... args) {
416
    SmallVector<Value, 1> results;
417
    createOrFold<OpTy>(results, location, std::forward<Args>(args)...);
418
    return results.front();
419
  }
420
421
  /// Overload to create or fold a zero result operation.
422
  template <typename OpTy, typename... Args>
423
  typename std::enable_if<OpTy::template hasTrait<OpTrait::ZeroResult>(),
424
                          OpTy>::type
425
  createOrFold(Location location, Args &&... args) {
426
    auto op = create<OpTy>(location, std::forward<Args>(args)...);
427
    SmallVector<Value, 0> unused;
428
    tryFold(op.getOperation(), unused);
429
430
    // Folding cannot remove a zero-result operation, so for convenience we
431
    // continue to return it.
432
    return op;
433
  }
434
435
  /// Attempts to fold the given operation and places new results within
436
  /// 'results'. Returns success if the operation was folded, failure otherwise.
437
  /// Note: This function does not erase the operation on a successful fold.
438
  LogicalResult tryFold(Operation *op, SmallVectorImpl<Value> &results);
439
440
  /// Creates a deep copy of the specified operation, remapping any operands
441
  /// that use values outside of the operation using the map that is provided
442
  /// (leaving them alone if no entry is present).  Replaces references to
443
  /// cloned sub-operations to the corresponding operation that is copied,
444
  /// and adds those mappings to the map.
445
0
  Operation *clone(Operation &op, BlockAndValueMapping &mapper) {
446
0
    return insert(op.clone(mapper));
447
0
  }
448
0
  Operation *clone(Operation &op) { return insert(op.clone()); }
449
450
  /// Creates a deep copy of this operation but keeps the operation regions
451
  /// empty. Operands are remapped using `mapper` (if present), and `mapper` is
452
  /// updated to contain the results.
453
0
  Operation *cloneWithoutRegions(Operation &op, BlockAndValueMapping &mapper) {
454
0
    return insert(op.cloneWithoutRegions(mapper));
455
0
  }
456
0
  Operation *cloneWithoutRegions(Operation &op) {
457
0
    return insert(op.cloneWithoutRegions());
458
0
  }
459
  template <typename OpT> OpT cloneWithoutRegions(OpT op) {
460
    return cast<OpT>(cloneWithoutRegions(*op.getOperation()));
461
  }
462
463
private:
464
  /// The current block this builder is inserting into.
465
  Block *block = nullptr;
466
  /// The insertion point within the block that this builder is inserting
467
  /// before.
468
  Block::iterator insertPoint;
469
  /// The optional listener for events of this builder.
470
  Listener *listener;
471
};
472
473
} // namespace mlir
474
475
#endif