Coverage Report

Created: 2020-06-26 05:44

/home/arjun/llvm-project/mlir/lib/IR/Builders.cpp
Line
Count
Source (jump to first uncovered line)
1
//===- Builders.cpp - Helpers for constructing MLIR Classes ---------------===//
2
//
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
// See https://llvm.org/LICENSE.txt for license information.
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
//
7
//===----------------------------------------------------------------------===//
8
9
#include "mlir/IR/Builders.h"
10
#include "mlir/IR/AffineExpr.h"
11
#include "mlir/IR/AffineMap.h"
12
#include "mlir/IR/Dialect.h"
13
#include "mlir/IR/IntegerSet.h"
14
#include "mlir/IR/Matchers.h"
15
#include "mlir/IR/Module.h"
16
#include "mlir/IR/StandardTypes.h"
17
#include "llvm/Support/raw_ostream.h"
18
using namespace mlir;
19
20
0
Builder::Builder(ModuleOp module) : context(module.getContext()) {}
21
22
0
Identifier Builder::getIdentifier(StringRef str) {
23
0
  return Identifier::get(str, context);
24
0
}
25
26
//===----------------------------------------------------------------------===//
27
// Locations.
28
//===----------------------------------------------------------------------===//
29
30
0
/// Returns the unknown location of this builder's context.
Location Builder::getUnknownLoc() {
  return UnknownLoc::get(context);
}

/// Returns a FileLineColLoc for `filename` at `line`:`column`.
Location Builder::getFileLineColLoc(Identifier filename, unsigned line,
                                    unsigned column) {
  return FileLineColLoc::get(filename, line, column, context);
}

/// Returns a FusedLoc combining `locs`, tagged with `metadata`.
Location Builder::getFusedLoc(ArrayRef<Location> locs, Attribute metadata) {
  return FusedLoc::get(locs, metadata, context);
}
40
41
//===----------------------------------------------------------------------===//
42
// Types.
43
//===----------------------------------------------------------------------===//
44
45
0
FloatType Builder::getBF16Type() { return FloatType::getBF16(context); }
46
47
0
FloatType Builder::getF16Type() { return FloatType::getF16(context); }
48
49
0
FloatType Builder::getF32Type() { return FloatType::getF32(context); }
50
51
0
FloatType Builder::getF64Type() { return FloatType::getF64(context); }
52
53
0
IndexType Builder::getIndexType() { return IndexType::get(context); }
54
55
0
IntegerType Builder::getI1Type() { return IntegerType::get(1, context); }
56
57
0
IntegerType Builder::getI32Type() { return IntegerType::get(32, context); }
58
59
0
IntegerType Builder::getI64Type() { return IntegerType::get(64, context); }
60
61
0
IntegerType Builder::getIntegerType(unsigned width) {
62
0
  return IntegerType::get(width, context);
63
0
}
64
65
0
IntegerType Builder::getIntegerType(unsigned width, bool isSigned) {
66
0
  return IntegerType::get(
67
0
      width, isSigned ? IntegerType::Signed : IntegerType::Unsigned, context);
68
0
}
69
70
FunctionType Builder::getFunctionType(ArrayRef<Type> inputs,
71
0
                                      ArrayRef<Type> results) {
72
0
  return FunctionType::get(inputs, results, context);
73
0
}
74
75
0
TupleType Builder::getTupleType(ArrayRef<Type> elementTypes) {
76
0
  return TupleType::get(elementTypes, context);
77
0
}
78
79
0
NoneType Builder::getNoneType() { return NoneType::get(context); }
80
81
//===----------------------------------------------------------------------===//
82
// Attributes.
83
//===----------------------------------------------------------------------===//
84
85
0
NamedAttribute Builder::getNamedAttr(StringRef name, Attribute val) {
86
0
  return NamedAttribute(getIdentifier(name), val);
87
0
}
88
89
0
UnitAttr Builder::getUnitAttr() { return UnitAttr::get(context); }
90
91
0
BoolAttr Builder::getBoolAttr(bool value) {
92
0
  return BoolAttr::get(value, context);
93
0
}
94
95
0
DictionaryAttr Builder::getDictionaryAttr(ArrayRef<NamedAttribute> value) {
96
0
  return DictionaryAttr::get(value, context);
97
0
}
98
99
0
IntegerAttr Builder::getIndexAttr(int64_t value) {
100
0
  return IntegerAttr::get(getIndexType(), APInt(64, value));
101
0
}
102
103
0
IntegerAttr Builder::getI64IntegerAttr(int64_t value) {
104
0
  return IntegerAttr::get(getIntegerType(64), APInt(64, value));
105
0
}
106
107
0
/// Returns a dense i32 attribute of 1-D vector type with values.size() elements.
DenseIntElementsAttr Builder::getI32VectorAttr(ArrayRef<int32_t> values) {
  int64_t length = static_cast<int64_t>(values.size());
  VectorType vectorType = VectorType::get(length, getI32Type());
  return DenseIntElementsAttr::get(vectorType, values);
}

/// Returns a dense i64 attribute of 1-D vector type with values.size() elements.
DenseIntElementsAttr Builder::getI64VectorAttr(ArrayRef<int64_t> values) {
  int64_t length = static_cast<int64_t>(values.size());
  VectorType vectorType = VectorType::get(length, getI64Type());
  return DenseIntElementsAttr::get(vectorType, values);
}

/// Returns a dense i32 attribute of 1-D ranked tensor type.
DenseIntElementsAttr Builder::getI32TensorAttr(ArrayRef<int32_t> values) {
  int64_t length = static_cast<int64_t>(values.size());
  RankedTensorType tensorType = RankedTensorType::get(length, getI32Type());
  return DenseIntElementsAttr::get(tensorType, values);
}

/// Returns a dense i64 attribute of 1-D ranked tensor type.
DenseIntElementsAttr Builder::getI64TensorAttr(ArrayRef<int64_t> values) {
  int64_t length = static_cast<int64_t>(values.size());
  RankedTensorType tensorType = RankedTensorType::get(length, getI64Type());
  return DenseIntElementsAttr::get(tensorType, values);
}

/// Returns a dense `index` attribute of 1-D ranked tensor type.
DenseIntElementsAttr Builder::getIndexTensorAttr(ArrayRef<int64_t> values) {
  int64_t length = static_cast<int64_t>(values.size());
  RankedTensorType tensorType = RankedTensorType::get(length, getIndexType());
  return DenseIntElementsAttr::get(tensorType, values);
}
139
140
0
/// Returns a 32-bit IntegerAttr (type from getIntegerType(32)).
IntegerAttr Builder::getI32IntegerAttr(int32_t value) {
  APInt bits(32, value);
  return IntegerAttr::get(getI32Type(), bits);
}

/// Returns a 32-bit IntegerAttr of explicitly signed integer type.
IntegerAttr Builder::getSI32IntegerAttr(int32_t value) {
  IntegerType si32Type = getIntegerType(32, /*isSigned=*/true);
  APInt bits(32, value, /*isSigned=*/true);
  return IntegerAttr::get(si32Type, bits);
}

/// Returns a 32-bit IntegerAttr of explicitly unsigned integer type.
IntegerAttr Builder::getUI32IntegerAttr(uint32_t value) {
  IntegerType ui32Type = getIntegerType(32, /*isSigned=*/false);
  APInt bits(32, static_cast<uint64_t>(value), /*isSigned=*/false);
  return IntegerAttr::get(ui32Type, bits);
}

/// Returns a 16-bit IntegerAttr (type from getIntegerType(16)).
IntegerAttr Builder::getI16IntegerAttr(int16_t value) {
  APInt bits(16, value);
  return IntegerAttr::get(getIntegerType(16), bits);
}

/// Returns an 8-bit IntegerAttr (type from getIntegerType(8)).
IntegerAttr Builder::getI8IntegerAttr(int8_t value) {
  APInt bits(8, value);
  return IntegerAttr::get(getIntegerType(8), bits);
}
161
162
0
/// Returns an IntegerAttr of `type` holding `value`.
///
/// For `index`, the value is stored as an unsigned-constructed 64-bit APInt.
/// For any other type, the APInt takes the type's bit width and is built as
/// signed exactly when the type is a signed integer.
IntegerAttr Builder::getIntegerAttr(Type type, int64_t value) {
  const bool isIndex = type.isIndex();
  // getIntOrFloatBitWidth() is only queried for non-index types.
  const unsigned width = isIndex ? 64u : type.getIntOrFloatBitWidth();
  const bool isSigned = isIndex ? false : type.isSignedInteger();
  return IntegerAttr::get(type, APInt(width, value, isSigned));
}

/// Returns an IntegerAttr of `type` holding the given APInt unchanged.
IntegerAttr Builder::getIntegerAttr(Type type, const APInt &value) {
  return IntegerAttr::get(type, value);
}
172
173
0
/// Returns an f64 FloatAttr holding `value`.
FloatAttr Builder::getF64FloatAttr(double value) {
  APFloat apValue(value);
  return FloatAttr::get(getF64Type(), apValue);
}

/// Returns an f32 FloatAttr holding `value`.
FloatAttr Builder::getF32FloatAttr(float value) {
  APFloat apValue(value);
  return FloatAttr::get(getF32Type(), apValue);
}

/// Returns an f16 FloatAttr for `value`.
///
/// The float is widened to double and passed to the double-taking
/// FloatAttr::get overload. It is not wrapped in an APFloat here (presumably so
/// that FloatAttr::get performs the conversion to f16).
FloatAttr Builder::getF16FloatAttr(float value) {
  return FloatAttr::get(getF16Type(), static_cast<double>(value));
}

/// Returns a FloatAttr of `type` built from a double.
FloatAttr Builder::getFloatAttr(Type type, double value) {
  return FloatAttr::get(type, value);
}

/// Returns a FloatAttr of `type` built from an APFloat.
FloatAttr Builder::getFloatAttr(Type type, const APFloat &value) {
  return FloatAttr::get(type, value);
}
192
193
0
StringAttr Builder::getStringAttr(StringRef bytes) {
194
0
  return StringAttr::get(bytes, context);
195
0
}
196
197
0
ArrayAttr Builder::getArrayAttr(ArrayRef<Attribute> value) {
198
0
  return ArrayAttr::get(value, context);
199
0
}
200
201
0
FlatSymbolRefAttr Builder::getSymbolRefAttr(Operation *value) {
202
0
  auto symName =
203
0
      value->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
204
0
  assert(symName && "value does not have a valid symbol name");
205
0
  return getSymbolRefAttr(symName.getValue());
206
0
}
207
0
FlatSymbolRefAttr Builder::getSymbolRefAttr(StringRef value) {
208
0
  return SymbolRefAttr::get(value, getContext());
209
0
}
210
SymbolRefAttr
211
Builder::getSymbolRefAttr(StringRef value,
212
0
                          ArrayRef<FlatSymbolRefAttr> nestedReferences) {
213
0
  return SymbolRefAttr::get(value, nestedReferences, getContext());
214
0
}
215
216
0
/// Returns an ArrayAttr with one BoolAttr per element of `values`.
ArrayAttr Builder::getBoolArrayAttr(ArrayRef<bool> values) {
  SmallVector<Attribute, 8> attrs;
  attrs.reserve(values.size());
  for (bool element : values)
    attrs.push_back(getBoolAttr(element));
  return getArrayAttr(attrs);
}

/// Returns an ArrayAttr with one 32-bit IntegerAttr per element of `values`.
ArrayAttr Builder::getI32ArrayAttr(ArrayRef<int32_t> values) {
  SmallVector<Attribute, 8> attrs;
  attrs.reserve(values.size());
  for (int32_t element : values)
    attrs.push_back(getI32IntegerAttr(element));
  return getArrayAttr(attrs);
}

/// Returns an ArrayAttr with one 64-bit IntegerAttr per element of `values`.
ArrayAttr Builder::getI64ArrayAttr(ArrayRef<int64_t> values) {
  SmallVector<Attribute, 8> attrs;
  attrs.reserve(values.size());
  for (int64_t element : values)
    attrs.push_back(getI64IntegerAttr(element));
  return getArrayAttr(attrs);
}

/// Returns an ArrayAttr with one `index`-typed IntegerAttr per element.
ArrayAttr Builder::getIndexArrayAttr(ArrayRef<int64_t> values) {
  Type indexType = getIndexType();
  SmallVector<Attribute, 8> attrs;
  attrs.reserve(values.size());
  for (int64_t element : values)
    attrs.push_back(getIntegerAttr(indexType, element));
  return getArrayAttr(attrs);
}

/// Returns an ArrayAttr with one f32 FloatAttr per element of `values`.
ArrayAttr Builder::getF32ArrayAttr(ArrayRef<float> values) {
  SmallVector<Attribute, 8> attrs;
  attrs.reserve(values.size());
  for (float element : values)
    attrs.push_back(getF32FloatAttr(element));
  return getArrayAttr(attrs);
}

/// Returns an ArrayAttr with one f64 FloatAttr per element of `values`.
ArrayAttr Builder::getF64ArrayAttr(ArrayRef<double> values) {
  SmallVector<Attribute, 8> attrs;
  attrs.reserve(values.size());
  for (double element : values)
    attrs.push_back(getF64FloatAttr(element));
  return getArrayAttr(attrs);
}

/// Returns an ArrayAttr with one StringAttr per element of `values`.
ArrayAttr Builder::getStrArrayAttr(ArrayRef<StringRef> values) {
  SmallVector<Attribute, 8> attrs;
  attrs.reserve(values.size());
  for (StringRef element : values)
    attrs.push_back(getStringAttr(element));
  return getArrayAttr(attrs);
}

/// Returns an ArrayAttr with one AffineMapAttr per map in `values`.
ArrayAttr Builder::getAffineMapArrayAttr(ArrayRef<AffineMap> values) {
  SmallVector<Attribute, 8> attrs;
  attrs.reserve(values.size());
  for (AffineMap map : values)
    attrs.push_back(AffineMapAttr::get(map));
  return getArrayAttr(attrs);
}
264
265
0
/// Returns a "zero" attribute for `type`. Returns a null Attribute when this
/// function defines no zero value for the type.
///
///  * bf16 / f16 / f32 / f64: a FloatAttr holding 0.0.
///  * index: an index-typed IntegerAttr holding 0 (as a 64-bit APInt).
///  * i1: BoolAttr `false`; any other integer width: an IntegerAttr of 0.
///  * vector / ranked tensor: a DenseElementsAttr built from the element
///    type's zero attribute. Returns null if the element type has none.
Attribute Builder::getZeroAttr(Type type) {
  switch (type.getKind()) {
  case StandardTypes::BF16:
  case StandardTypes::F16:
  case StandardTypes::F32:
  case StandardTypes::F64:
    return getFloatAttr(type, 0.0);
  case StandardTypes::Index:
    // Handled explicitly so that index scalars get a zero attribute instead of
    // falling through to `default`. Vectors and tensors of index also benefit,
    // through the recursive element call below.
    // The 64-bit width matches getIntegerAttr's handling of index.
    return getIntegerAttr(type, APInt(64, 0));
  case StandardTypes::Integer: {
    auto width = type.cast<IntegerType>().getWidth();
    if (width == 1)
      return getBoolAttr(false);
    return getIntegerAttr(type, APInt(width, 0));
  }
  case StandardTypes::Vector:
  case StandardTypes::RankedTensor: {
    auto vtType = type.cast<ShapedType>();
    auto element = getZeroAttr(vtType.getElementType());
    if (!element)
      return {};
    return DenseElementsAttr::get(vtType, element);
  }
  default:
    break;
  }
  return {};
}
291
292
//===----------------------------------------------------------------------===//
293
// Affine Expressions, Affine Maps, and Integer Sets.
294
//===----------------------------------------------------------------------===//
295
296
0
/// Returns the affine dimension expression at `position` (d<position>).
AffineExpr Builder::getAffineDimExpr(unsigned position) {
  // Qualified so the call resolves to the free function, not this member.
  return mlir::getAffineDimExpr(position, context);
}

/// Returns the affine symbol expression at `position` (s<position>).
AffineExpr Builder::getAffineSymbolExpr(unsigned position) {
  // Qualified so the call resolves to the free function, not this member.
  return mlir::getAffineSymbolExpr(position, context);
}

/// Returns the affine constant expression `constant`.
AffineExpr Builder::getAffineConstantExpr(int64_t constant) {
  // Qualified so the call resolves to the free function, not this member.
  return mlir::getAffineConstantExpr(constant, context);
}
307
308
0
/// Returns the map produced by AffineMap::get(context).
AffineMap Builder::getEmptyAffineMap() { return AffineMap::get(context); }

/// Returns `() -> (val)`: no dims, no symbols, one constant result.
AffineMap Builder::getConstantAffineMap(int64_t val) {
  AffineExpr constant = getAffineConstantExpr(val);
  return AffineMap::get(/*dimCount=*/0, /*symbolCount=*/0, constant);
}

/// Returns `(d0) -> (d0)`.
AffineMap Builder::getDimIdentityMap() {
  AffineExpr d0 = getAffineDimExpr(0);
  return AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, d0);
}

/// Returns the identity map over `rank` dimensions:
/// `(d0, ..., d{rank-1}) -> (d0, ..., d{rank-1})`.
AffineMap Builder::getMultiDimIdentityMap(unsigned rank) {
  SmallVector<AffineExpr, 4> identityResults;
  identityResults.reserve(rank);
  for (unsigned dim = 0; dim != rank; ++dim)
    identityResults.push_back(getAffineDimExpr(dim));
  return AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, identityResults,
                        context);
}

/// Returns `()[s0] -> (s0)`.
AffineMap Builder::getSymbolIdentityMap() {
  AffineExpr s0 = getAffineSymbolExpr(0);
  return AffineMap::get(/*dimCount=*/0, /*symbolCount=*/1, s0);
}

/// Returns `(d0) -> (d0 + shift)`.
AffineMap Builder::getSingleDimShiftAffineMap(int64_t shift) {
  AffineExpr shifted = getAffineDimExpr(0) + shift;
  return AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, shifted);
}

/// Returns a copy of `map` with `shift` added to every result expression. The
/// dim and symbol counts of `map` are kept unchanged.
AffineMap Builder::getShiftedAffineMap(AffineMap map, int64_t shift) {
  SmallVector<AffineExpr, 4> shiftedExprs;
  shiftedExprs.reserve(map.getNumResults());
  for (AffineExpr expr : map.getResults())
    shiftedExprs.push_back(expr + shift);
  return AffineMap::get(map.getNumDims(), map.getNumSymbols(), shiftedExprs,
                        context);
}
347
348
//===----------------------------------------------------------------------===//
349
// OpBuilder
350
//===----------------------------------------------------------------------===//
351
352
0
// Out-of-line definition of the listener destructor.
OpBuilder::Listener::~Listener() = default;
353
354
/// Insert the given operation at the current insertion point and return it.
355
0
Operation *OpBuilder::insert(Operation *op) {
  // With no insertion block set, `op` is not placed anywhere. The listener, if
  // any, is notified in both cases.
  if (block != nullptr)
    block->getOperations().insert(insertPoint, op);
  if (listener != nullptr)
    listener->notifyOperationInserted(op);
  return op;
}
363
364
/// Add new block with 'argTypes' arguments and set the insertion point to the
365
/// end of it. The block is inserted at the provided insertion point of
366
/// 'parent'.
367
Block *OpBuilder::createBlock(Region *parent, Region::iterator insertPt,
                              TypeRange argTypes) {
  assert(parent && "expected valid parent region");
  // A default-constructed iterator is treated as "append at the region's end".
  Region::iterator position =
      insertPt == Region::iterator() ? parent->end() : insertPt;

  Block *newBlock = new Block();
  newBlock->addArguments(argTypes);
  parent->getBlocks().insert(position, newBlock);
  setInsertionPointToEnd(newBlock);

  if (listener)
    listener->notifyBlockCreated(newBlock);
  return newBlock;
}
382
383
/// Add new block with 'argTypes' arguments and set the insertion point to the
384
/// end of it.  The block is placed before 'insertBefore'.
385
0
Block *OpBuilder::createBlock(Block *insertBefore, TypeRange argTypes) {
  assert(insertBefore && "expected valid insertion block");
  // Delegate to the region-based overload, positioned at `insertBefore`.
  Region *parentRegion = insertBefore->getParent();
  Region::iterator position(insertBefore);
  return createBlock(parentRegion, position, argTypes);
}
390
391
/// Create an operation given the fields represented as an OperationState.
392
0
Operation *OpBuilder::createOperation(const OperationState &state) {
  // Build the operation from `state`, then place it via insert().
  Operation *created = Operation::create(state);
  return insert(created);
}
395
396
/// Attempts to fold the given operation and places new results within
397
/// 'results'. Returns success if the operation was folded, failure otherwise.
398
/// Note: This function does not erase the operation on a successful fold.
399
LogicalResult OpBuilder::tryFold(Operation *op,
                                 SmallVectorImpl<Value> &results) {
  results.reserve(op->getNumResults());

  // On every failure path, `results` is overwritten with op's own results.
  auto fallBackToOriginalResults = [&] {
    results.assign(op->result_begin(), op->result_end());
    return failure();
  };

  // An operation that already matches m_Constant() is not folded further.
  if (matchPattern(op, m_Constant()))
    return fallBackToOriginalResults();

  // Record the constant value of each operand. Operands that do not match
  // m_Constant() leave a null Attribute in their slot.
  unsigned numOperands = op->getNumOperands();
  SmallVector<Attribute, 4> operandConstants(numOperands);
  for (unsigned idx = 0; idx != numOperands; ++idx)
    matchPattern(op->getOperand(idx), m_Constant(&operandConstants[idx]));

  // Ask the operation to fold itself. A successful fold that produced no
  // results is also treated as failure here.
  SmallVector<OpFoldResult, 4> folded;
  if (failed(op->fold(operandConstants, folded)) || folded.empty())
    return fallBackToOriginalResults();

  // Constants are materialized through a builder created from the context
  // alone. They are only inserted via this builder after every result
  // succeeds.
  OpBuilder constantBuilder(context);
  SmallVector<Operation *, 1> pendingConstants;
  Dialect *dialect = op->getDialect();

  for (unsigned idx = 0, end = folded.size(); idx != end; ++idx) {
    OpFoldResult &foldResult = folded[idx];

    // A folded SSA value is forwarded directly.
    if (Value forwarded = foldResult.dyn_cast<Value>()) {
      results.push_back(forwarded);
      continue;
    }

    // An attribute result needs the op's dialect to build a constant.
    if (dialect == nullptr)
      return fallBackToOriginalResults();

    Attribute constValue = foldResult.get<Attribute>();
    Type resultType = op->getResult(idx).getType();
    Operation *constOp = dialect->materializeConstant(
        constantBuilder, constValue, resultType, op->getLoc());
    if (constOp == nullptr) {
      // Discard the constants produced so far before reporting failure.
      for (Operation *pending : pendingConstants)
        pending->erase();
      return fallBackToOriginalResults();
    }
    assert(matchPattern(constOp, m_Constant()));

    pendingConstants.push_back(constOp);
    results.push_back(constOp->getResult(0));
  }

  // Every result was produced; now insert the materialized constants.
  for (Operation *pending : pendingConstants)
    insert(pending);
  return success();
}