Coverage Report

Created: 2020-06-26 05:44

/home/arjun/llvm-project/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
Line
Count
Source (jump to first uncovered line)
1
//===- Ops.cpp - Standard MLIR Operations ---------------------------------===//
2
//
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
// See https://llvm.org/LICENSE.txt for license information.
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
//
7
//===----------------------------------------------------------------------===//
8
9
#include "mlir/Dialect/StandardOps/IR/Ops.h"
10
11
#include "mlir/Dialect/CommonFolders.h"
12
#include "mlir/IR/AffineExpr.h"
13
#include "mlir/IR/AffineMap.h"
14
#include "mlir/IR/Builders.h"
15
#include "mlir/IR/Function.h"
16
#include "mlir/IR/Matchers.h"
17
#include "mlir/IR/Module.h"
18
#include "mlir/IR/OpImplementation.h"
19
#include "mlir/IR/PatternMatch.h"
20
#include "mlir/IR/StandardTypes.h"
21
#include "mlir/IR/TypeUtilities.h"
22
#include "mlir/IR/Value.h"
23
#include "mlir/Support/MathExtras.h"
24
#include "mlir/Transforms/InliningUtils.h"
25
#include "llvm/ADT/StringSwitch.h"
26
#include "llvm/Support/FormatVariadic.h"
27
#include "llvm/Support/raw_ostream.h"
28
29
// Pull in all enum type definitions and utility function declarations.
30
#include "mlir/Dialect/StandardOps/IR/OpsEnums.cpp.inc"
31
32
using namespace mlir;
33
34
//===----------------------------------------------------------------------===//
35
// StandardOpsDialect Interfaces
36
//===----------------------------------------------------------------------===//
37
namespace {
38
/// This class defines the interface for handling inlining with standard
39
/// operations.
40
struct StdInlinerInterface : public DialectInlinerInterface {
41
  using DialectInlinerInterface::DialectInlinerInterface;
42
43
  //===--------------------------------------------------------------------===//
44
  // Analysis Hooks
45
  //===--------------------------------------------------------------------===//
46
47
  /// All operations within standard ops can be inlined.
48
  bool isLegalToInline(Operation *, Region *,
49
0
                       BlockAndValueMapping &) const final {
50
0
    return true;
51
0
  }
52
53
  //===--------------------------------------------------------------------===//
54
  // Transformation Hooks
55
  //===--------------------------------------------------------------------===//
56
57
  /// Handle the given inlined terminator by replacing it with a new operation
58
  /// as necessary.
59
0
  void handleTerminator(Operation *op, Block *newDest) const final {
60
0
    // Only "std.return" needs to be handled here.
61
0
    auto returnOp = dyn_cast<ReturnOp>(op);
62
0
    if (!returnOp)
63
0
      return;
64
0
65
0
    // Replace the return with a branch to the dest.
66
0
    OpBuilder builder(op);
67
0
    builder.create<BranchOp>(op->getLoc(), newDest, returnOp.getOperands());
68
0
    op->erase();
69
0
  }
70
71
  /// Handle the given inlined terminator by replacing it with a new operation
72
  /// as necessary.
73
  void handleTerminator(Operation *op,
74
0
                        ArrayRef<Value> valuesToRepl) const final {
75
0
    // Only "std.return" needs to be handled here.
76
0
    auto returnOp = cast<ReturnOp>(op);
77
0
78
0
    // Replace the values directly with the return operands.
79
0
    assert(returnOp.getNumOperands() == valuesToRepl.size());
80
0
    for (const auto &it : llvm::enumerate(returnOp.getOperands()))
81
0
      valuesToRepl[it.index()].replaceAllUsesWith(it.value());
82
0
  }
83
};
84
} // end anonymous namespace
85
86
//===----------------------------------------------------------------------===//
87
// StandardOpsDialect
88
//===----------------------------------------------------------------------===//
89
90
/// A custom unary operation printer that omits the "std." prefix from the
91
/// operation names.
92
0
static void printStandardUnaryOp(Operation *op, OpAsmPrinter &p) {
93
0
  assert(op->getNumOperands() == 1 && "unary op should have one operand");
94
0
  assert(op->getNumResults() == 1 && "unary op should have one result");
95
0
96
0
  int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
97
0
  p << op->getName().getStringRef().drop_front(stdDotLen) << ' '
98
0
    << op->getOperand(0);
99
0
  p.printOptionalAttrDict(op->getAttrs());
100
0
  p << " : " << op->getOperand(0).getType();
101
0
}
102
103
/// A custom binary operation printer that omits the "std." prefix from the
104
/// operation names.
105
0
static void printStandardBinaryOp(Operation *op, OpAsmPrinter &p) {
106
0
  assert(op->getNumOperands() == 2 && "binary op should have two operands");
107
0
  assert(op->getNumResults() == 1 && "binary op should have one result");
108
0
109
0
  // If not all the operand and result types are the same, just use the
110
0
  // generic assembly form to avoid omitting information in printing.
111
0
  auto resultType = op->getResult(0).getType();
112
0
  if (op->getOperand(0).getType() != resultType ||
113
0
      op->getOperand(1).getType() != resultType) {
114
0
    p.printGenericOp(op);
115
0
    return;
116
0
  }
117
0
118
0
  int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
119
0
  p << op->getName().getStringRef().drop_front(stdDotLen) << ' '
120
0
    << op->getOperand(0) << ", " << op->getOperand(1);
121
0
  p.printOptionalAttrDict(op->getAttrs());
122
0
123
0
  // Now we can output only one type for all operands and the result.
124
0
  p << " : " << op->getResult(0).getType();
125
0
}
126
127
/// A custom cast operation printer that omits the "std." prefix from the
128
/// operation names.
129
0
static void printStandardCastOp(Operation *op, OpAsmPrinter &p) {
130
0
  int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
131
0
  p << op->getName().getStringRef().drop_front(stdDotLen) << ' '
132
0
    << op->getOperand(0) << " : " << op->getOperand(0).getType() << " to "
133
0
    << op->getResult(0).getType();
134
0
}
135
136
/// A custom cast operation verifier.
137
template <typename T>
138
0
static LogicalResult verifyCastOp(T op) {
139
0
  auto opType = op.getOperand().getType();
140
0
  auto resType = op.getType();
141
0
  if (!T::areCastCompatible(opType, resType))
142
0
    return op.emitError("operand type ") << opType << " and result type "
143
0
                                         << resType << " are cast incompatible";
144
0
145
0
  return success();
146
0
}
Unexecuted instantiation: Ops.cpp:_ZL12verifyCastOpIN4mlir7FPExtOpEENS0_13LogicalResultET_
Unexecuted instantiation: Ops.cpp:_ZL12verifyCastOpIN4mlir8FPToSIOpEENS0_13LogicalResultET_
Unexecuted instantiation: Ops.cpp:_ZL12verifyCastOpIN4mlir9FPTruncOpEENS0_13LogicalResultET_
Unexecuted instantiation: Ops.cpp:_ZL12verifyCastOpIN4mlir11IndexCastOpEENS0_13LogicalResultET_
Unexecuted instantiation: Ops.cpp:_ZL12verifyCastOpIN4mlir12MemRefCastOpEENS0_13LogicalResultET_
Unexecuted instantiation: Ops.cpp:_ZL12verifyCastOpIN4mlir8SIToFPOpEENS0_13LogicalResultET_
Unexecuted instantiation: Ops.cpp:_ZL12verifyCastOpIN4mlir12TensorCastOpEENS0_13LogicalResultET_
147
148
StandardOpsDialect::StandardOpsDialect(MLIRContext *context)
149
0
    : Dialect(getDialectNamespace(), context) {
150
0
  addOperations<DmaStartOp, DmaWaitOp,
151
0
#define GET_OP_LIST
152
0
#include "mlir/Dialect/StandardOps/IR/Ops.cpp.inc"
153
0
                >();
154
0
  addInterfaces<StdInlinerInterface>();
155
0
}
156
157
/// Materialize a single constant operation from a given attribute value with
158
/// the desired resultant type.
159
Operation *StandardOpsDialect::materializeConstant(OpBuilder &builder,
160
                                                   Attribute value, Type type,
161
0
                                                   Location loc) {
162
0
  return builder.create<ConstantOp>(loc, type, value);
163
0
}
164
165
void mlir::printDimAndSymbolList(Operation::operand_iterator begin,
166
                                 Operation::operand_iterator end,
167
0
                                 unsigned numDims, OpAsmPrinter &p) {
168
0
  Operation::operand_range operands(begin, end);
169
0
  p << '(' << operands.take_front(numDims) << ')';
170
0
  if (operands.size() != numDims)
171
0
    p << '[' << operands.drop_front(numDims) << ']';
172
0
}
173
174
// Parses dimension and symbol list, and sets 'numDims' to the number of
175
// dimension operands parsed.
176
// Returns 'false' on success and 'true' on error.
177
ParseResult mlir::parseDimAndSymbolList(OpAsmParser &parser,
178
                                        SmallVectorImpl<Value> &operands,
179
0
                                        unsigned &numDims) {
180
0
  SmallVector<OpAsmParser::OperandType, 8> opInfos;
181
0
  if (parser.parseOperandList(opInfos, OpAsmParser::Delimiter::Paren))
182
0
    return failure();
183
0
  // Store number of dimensions for validation by caller.
184
0
  numDims = opInfos.size();
185
0
186
0
  // Parse the optional symbol operands.
187
0
  auto indexTy = parser.getBuilder().getIndexType();
188
0
  if (parser.parseOperandList(opInfos,
189
0
                              OpAsmParser::Delimiter::OptionalSquare) ||
190
0
      parser.resolveOperands(opInfos, indexTy, operands))
191
0
    return failure();
192
0
  return success();
193
0
}
194
195
/// Matches a ConstantIndexOp.
196
/// TODO: This should probably just be a general matcher that uses m_Constant
197
/// and checks the operation for an index type.
198
0
static detail::op_matcher<ConstantIndexOp> m_ConstantIndex() {
199
0
  return detail::op_matcher<ConstantIndexOp>();
200
0
}
201
202
//===----------------------------------------------------------------------===//
203
// Common canonicalization pattern support logic
204
//===----------------------------------------------------------------------===//
205
206
/// This is a common class used for patterns of the form
207
/// "someop(memrefcast) -> someop".  It folds the source of any memref_cast
208
/// into the root operation directly.
209
0
static LogicalResult foldMemRefCast(Operation *op) {
210
0
  bool folded = false;
211
0
  for (OpOperand &operand : op->getOpOperands()) {
212
0
    auto cast = operand.get().getDefiningOp<MemRefCastOp>();
213
0
    if (cast && !cast.getOperand().getType().isa<UnrankedMemRefType>()) {
214
0
      operand.set(cast.getOperand());
215
0
      folded = true;
216
0
    }
217
0
  }
218
0
  return success(folded);
219
0
}
220
221
//===----------------------------------------------------------------------===//
222
// AddFOp
223
//===----------------------------------------------------------------------===//
224
225
0
OpFoldResult AddFOp::fold(ArrayRef<Attribute> operands) {
226
0
  return constFoldBinaryOp<FloatAttr>(
227
0
      operands, [](APFloat a, APFloat b) { return a + b; });
228
0
}
229
230
//===----------------------------------------------------------------------===//
231
// AddIOp
232
//===----------------------------------------------------------------------===//
233
234
0
OpFoldResult AddIOp::fold(ArrayRef<Attribute> operands) {
235
0
  /// addi(x, 0) -> x
236
0
  if (matchPattern(rhs(), m_Zero()))
237
0
    return lhs();
238
0
239
0
  return constFoldBinaryOp<IntegerAttr>(operands,
240
0
                                        [](APInt a, APInt b) { return a + b; });
241
0
}
242
243
//===----------------------------------------------------------------------===//
244
// AllocOp / AllocaOp
245
//===----------------------------------------------------------------------===//
246
247
template <typename AllocLikeOp>
248
0
static void printAllocLikeOp(OpAsmPrinter &p, AllocLikeOp op, StringRef name) {
249
0
  static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value,
250
0
                "applies to only alloc or alloca");
251
0
  p << name;
252
0
253
0
  // Print dynamic dimension operands.
254
0
  MemRefType type = op.getType();
255
0
  printDimAndSymbolList(op.operand_begin(), op.operand_end(),
256
0
                        type.getNumDynamicDims(), p);
257
0
  p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"map"});
258
0
  p << " : " << type;
259
0
}
Unexecuted instantiation: Ops.cpp:_ZL16printAllocLikeOpIN4mlir7AllocOpEEvRNS0_12OpAsmPrinterET_N4llvm9StringRefE
Unexecuted instantiation: Ops.cpp:_ZL16printAllocLikeOpIN4mlir8AllocaOpEEvRNS0_12OpAsmPrinterET_N4llvm9StringRefE
260
261
0
static void print(OpAsmPrinter &p, AllocOp op) {
262
0
  printAllocLikeOp(p, op, "alloc");
263
0
}
264
265
0
static void print(OpAsmPrinter &p, AllocaOp op) {
266
0
  printAllocLikeOp(p, op, "alloca");
267
0
}
268
269
static ParseResult parseAllocLikeOp(OpAsmParser &parser,
270
0
                                    OperationState &result) {
271
0
  MemRefType type;
272
0
273
0
  // Parse the dimension operands and optional symbol operands, followed by a
274
0
  // memref type.
275
0
  unsigned numDimOperands;
276
0
  if (parseDimAndSymbolList(parser, result.operands, numDimOperands) ||
277
0
      parser.parseOptionalAttrDict(result.attributes) ||
278
0
      parser.parseColonType(type))
279
0
    return failure();
280
0
281
0
  // Check numDynamicDims against number of question marks in memref type.
282
0
  // Note: this check remains here (instead of in verify()), because the
283
0
  // partition between dim operands and symbol operands is lost after parsing.
284
0
  // Verification still checks that the total number of operands matches
285
0
  // the number of symbols in the affine map, plus the number of dynamic
286
0
  // dimensions in the memref.
287
0
  if (numDimOperands != type.getNumDynamicDims())
288
0
    return parser.emitError(parser.getNameLoc())
289
0
           << "dimension operand count does not equal memref dynamic dimension "
290
0
              "count";
291
0
  result.types.push_back(type);
292
0
  return success();
293
0
}
294
295
template <typename AllocLikeOp>
296
0
static LogicalResult verify(AllocLikeOp op) {
297
0
  static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value,
298
0
                "applies to only alloc or alloca");
299
0
  auto memRefType = op.getResult().getType().template dyn_cast<MemRefType>();
300
0
  if (!memRefType)
301
0
    return op.emitOpError("result must be a memref");
302
0
303
0
  unsigned numSymbols = 0;
304
0
  if (!memRefType.getAffineMaps().empty()) {
305
0
    // Store number of symbols used in affine map (used in subsequent check).
306
0
    AffineMap affineMap = memRefType.getAffineMaps()[0];
307
0
    numSymbols = affineMap.getNumSymbols();
308
0
  }
309
0
310
0
  // Check that the total number of operands matches the number of symbols in
311
0
  // the affine map, plus the number of dynamic dimensions specified in the
312
0
  // memref type.
313
0
  unsigned numDynamicDims = memRefType.getNumDynamicDims();
314
0
  if (op.getNumOperands() != numDynamicDims + numSymbols)
315
0
    return op.emitOpError(
316
0
        "operand count does not equal dimension plus symbol operand count");
317
0
318
0
  // Verify that all operands are of type Index.
319
0
  for (auto operandType : op.getOperandTypes())
320
0
    if (!operandType.isIndex())
321
0
      return op.emitOpError("requires operands to be of type Index");
322
0
323
0
  if (std::is_same<AllocLikeOp, AllocOp>::value)
324
0
    return success();
325
0
326
0
  // An alloca op needs to have an ancestor with an allocation scope trait.
327
0
  if (!op.template getParentWithTrait<OpTrait::AutomaticAllocationScope>())
328
0
    return op.emitOpError(
329
0
        "requires an ancestor op with AutomaticAllocationScope trait");
330
0
331
0
  return success();
332
0
}
Unexecuted instantiation: Ops.cpp:_ZL6verifyIN4mlir7AllocOpEENS0_13LogicalResultET_
Unexecuted instantiation: Ops.cpp:_ZL6verifyIN4mlir8AllocaOpEENS0_13LogicalResultET_
333
334
namespace {
335
/// Fold constant dimensions into an alloc like operation.
336
template <typename AllocLikeOp>
337
struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
338
  using OpRewritePattern<AllocLikeOp>::OpRewritePattern;
339
340
  LogicalResult matchAndRewrite(AllocLikeOp alloc,
341
0
                                PatternRewriter &rewriter) const override {
342
0
    // Check to see if any dimensions operands are constants.  If so, we can
343
0
    // substitute and drop them.
344
0
    if (llvm::none_of(alloc.getOperands(), [](Value operand) {
345
0
          return matchPattern(operand, m_ConstantIndex());
346
0
        }))
Unexecuted instantiation: Ops.cpp:_ZZNK12_GLOBAL__N_118SimplifyAllocConstIN4mlir7AllocOpEE15matchAndRewriteES2_RNS1_15PatternRewriterEENKUlNS1_5ValueEE_clES6_
Unexecuted instantiation: Ops.cpp:_ZZNK12_GLOBAL__N_118SimplifyAllocConstIN4mlir8AllocaOpEE15matchAndRewriteES2_RNS1_15PatternRewriterEENKUlNS1_5ValueEE_clES6_
347
0
      return failure();
348
0
349
0
    auto memrefType = alloc.getType();
350
0
351
0
    // Ok, we have one or more constant operands.  Collect the non-constant ones
352
0
    // and keep track of the resultant memref type to build.
353
0
    SmallVector<int64_t, 4> newShapeConstants;
354
0
    newShapeConstants.reserve(memrefType.getRank());
355
0
    SmallVector<Value, 4> newOperands;
356
0
357
0
    unsigned dynamicDimPos = 0;
358
0
    for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
359
0
      int64_t dimSize = memrefType.getDimSize(dim);
360
0
      // If this is already static dimension, keep it.
361
0
      if (dimSize != -1) {
362
0
        newShapeConstants.push_back(dimSize);
363
0
        continue;
364
0
      }
365
0
      auto *defOp = alloc.getOperand(dynamicDimPos).getDefiningOp();
366
0
      if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
367
0
        // Dynamic shape dimension will be folded.
368
0
        newShapeConstants.push_back(constantIndexOp.getValue());
369
0
      } else {
370
0
        // Dynamic shape dimension not folded; copy operand from old memref.
371
0
        newShapeConstants.push_back(-1);
372
0
        newOperands.push_back(alloc.getOperand(dynamicDimPos));
373
0
      }
374
0
      dynamicDimPos++;
375
0
    }
376
0
377
0
    // Create new memref type (which will have fewer dynamic dimensions).
378
0
    MemRefType newMemRefType =
379
0
        MemRefType::Builder(memrefType).setShape(newShapeConstants);
380
0
    assert(static_cast<int64_t>(newOperands.size()) ==
381
0
           newMemRefType.getNumDynamicDims());
382
0
383
0
    // Create and insert the alloc op for the new memref.
384
0
    auto newAlloc = rewriter.create<AllocLikeOp>(alloc.getLoc(), newMemRefType,
385
0
                                                 newOperands, IntegerAttr());
386
0
    // Insert a cast so we have the same type as the old alloc.
387
0
    auto resultCast = rewriter.create<MemRefCastOp>(alloc.getLoc(), newAlloc,
388
0
                                                    alloc.getType());
389
0
390
0
    rewriter.replaceOp(alloc, {resultCast});
391
0
    return success();
392
0
  }
Unexecuted instantiation: Ops.cpp:_ZNK12_GLOBAL__N_118SimplifyAllocConstIN4mlir7AllocOpEE15matchAndRewriteES2_RNS1_15PatternRewriterE
Unexecuted instantiation: Ops.cpp:_ZNK12_GLOBAL__N_118SimplifyAllocConstIN4mlir8AllocaOpEE15matchAndRewriteES2_RNS1_15PatternRewriterE
393
};
394
395
/// Fold alloc operations with no uses. Alloc has side effects on the heap,
396
/// but can still be deleted if it has zero uses.
397
struct SimplifyDeadAlloc : public OpRewritePattern<AllocOp> {
398
  using OpRewritePattern<AllocOp>::OpRewritePattern;
399
400
  LogicalResult matchAndRewrite(AllocOp alloc,
401
0
                                PatternRewriter &rewriter) const override {
402
0
    if (alloc.use_empty()) {
403
0
      rewriter.eraseOp(alloc);
404
0
      return success();
405
0
    }
406
0
    return failure();
407
0
  }
408
};
409
} // end anonymous namespace.
410
411
void AllocOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
412
0
                                          MLIRContext *context) {
413
0
  results.insert<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc>(context);
414
0
}
415
416
void AllocaOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
417
0
                                           MLIRContext *context) {
418
0
  results.insert<SimplifyAllocConst<AllocaOp>>(context);
419
0
}
420
421
//===----------------------------------------------------------------------===//
422
// AndOp
423
//===----------------------------------------------------------------------===//
424
425
0
OpFoldResult AndOp::fold(ArrayRef<Attribute> operands) {
426
0
  /// and(x, 0) -> 0
427
0
  if (matchPattern(rhs(), m_Zero()))
428
0
    return rhs();
429
0
  /// and(x,x) -> x
430
0
  if (lhs() == rhs())
431
0
    return rhs();
432
0
433
0
  return constFoldBinaryOp<IntegerAttr>(operands,
434
0
                                        [](APInt a, APInt b) { return a & b; });
435
0
}
436
437
//===----------------------------------------------------------------------===//
438
// AssumeAlignmentOp
439
//===----------------------------------------------------------------------===//
440
441
0
static LogicalResult verify(AssumeAlignmentOp op) {
442
0
  unsigned alignment = op.alignment().getZExtValue();
443
0
  if (!llvm::isPowerOf2_32(alignment))
444
0
    return op.emitOpError("alignment must be power of 2");
445
0
  return success();
446
0
}
447
448
//===----------------------------------------------------------------------===//
449
// AtomicRMWOp
450
//===----------------------------------------------------------------------===//
451
452
0
static LogicalResult verify(AtomicRMWOp op) {
453
0
  if (op.getMemRefType().getRank() != op.getNumOperands() - 2)
454
0
    return op.emitOpError(
455
0
        "expects the number of subscripts to be equal to memref rank");
456
0
  switch (op.kind()) {
457
0
  case AtomicRMWKind::addf:
458
0
  case AtomicRMWKind::maxf:
459
0
  case AtomicRMWKind::minf:
460
0
  case AtomicRMWKind::mulf:
461
0
    if (!op.value().getType().isa<FloatType>())
462
0
      return op.emitOpError()
463
0
             << "with kind '" << stringifyAtomicRMWKind(op.kind())
464
0
             << "' expects a floating-point type";
465
0
    break;
466
0
  case AtomicRMWKind::addi:
467
0
  case AtomicRMWKind::maxs:
468
0
  case AtomicRMWKind::maxu:
469
0
  case AtomicRMWKind::mins:
470
0
  case AtomicRMWKind::minu:
471
0
  case AtomicRMWKind::muli:
472
0
    if (!op.value().getType().isa<IntegerType>())
473
0
      return op.emitOpError()
474
0
             << "with kind '" << stringifyAtomicRMWKind(op.kind())
475
0
             << "' expects an integer type";
476
0
    break;
477
0
  default:
478
0
    break;
479
0
  }
480
0
  return success();
481
0
}
482
483
//===----------------------------------------------------------------------===//
484
// GenericAtomicRMWOp
485
//===----------------------------------------------------------------------===//
486
487
void GenericAtomicRMWOp::build(OpBuilder &builder, OperationState &result,
488
0
                               Value memref, ValueRange ivs) {
489
0
  result.addOperands(memref);
490
0
  result.addOperands(ivs);
491
0
492
0
  if (auto memrefType = memref.getType().dyn_cast<MemRefType>()) {
493
0
    Type elementType = memrefType.getElementType();
494
0
    result.addTypes(elementType);
495
0
496
0
    Region *bodyRegion = result.addRegion();
497
0
    bodyRegion->push_back(new Block());
498
0
    bodyRegion->front().addArgument(elementType);
499
0
  }
500
0
}
501
502
0
static LogicalResult verify(GenericAtomicRMWOp op) {
503
0
  auto &block = op.body().front();
504
0
  if (block.getNumArguments() != 1)
505
0
    return op.emitOpError("expected single number of entry block arguments");
506
0
507
0
  if (op.getResult().getType() != block.getArgument(0).getType())
508
0
    return op.emitOpError(
509
0
        "expected block argument of the same type result type");
510
0
511
0
  bool hasSideEffects =
512
0
      op.body()
513
0
          .walk([&](Operation *nestedOp) {
514
0
            if (MemoryEffectOpInterface::hasNoEffect(nestedOp))
515
0
              return WalkResult::advance();
516
0
            nestedOp->emitError("body of 'generic_atomic_rmw' should contain "
517
0
                                "only operations with no side effects");
518
0
            return WalkResult::interrupt();
519
0
          })
520
0
          .wasInterrupted();
521
0
  return hasSideEffects ? failure() : success();
522
0
}
523
524
static ParseResult parseGenericAtomicRMWOp(OpAsmParser &parser,
525
0
                                           OperationState &result) {
526
0
  OpAsmParser::OperandType memref;
527
0
  Type memrefType;
528
0
  SmallVector<OpAsmParser::OperandType, 4> ivs;
529
0
530
0
  Type indexType = parser.getBuilder().getIndexType();
531
0
  if (parser.parseOperand(memref) ||
532
0
      parser.parseOperandList(ivs, OpAsmParser::Delimiter::Square) ||
533
0
      parser.parseColonType(memrefType) ||
534
0
      parser.resolveOperand(memref, memrefType, result.operands) ||
535
0
      parser.resolveOperands(ivs, indexType, result.operands))
536
0
    return failure();
537
0
538
0
  Region *body = result.addRegion();
539
0
  if (parser.parseRegion(*body, llvm::None, llvm::None))
540
0
    return failure();
541
0
  result.types.push_back(memrefType.cast<MemRefType>().getElementType());
542
0
  return success();
543
0
}
544
545
0
static void print(OpAsmPrinter &p, GenericAtomicRMWOp op) {
546
0
  p << op.getOperationName() << ' ' << op.memref() << "[" << op.indices()
547
0
    << "] : " << op.memref().getType();
548
0
  p.printRegion(op.body());
549
0
  p.printOptionalAttrDict(op.getAttrs());
550
0
}
551
552
//===----------------------------------------------------------------------===//
553
// AtomicYieldOp
554
//===----------------------------------------------------------------------===//
555
556
0
static LogicalResult verify(AtomicYieldOp op) {
557
0
  Type parentType = op.getParentOp()->getResultTypes().front();
558
0
  Type resultType = op.result().getType();
559
0
  if (parentType != resultType)
560
0
    return op.emitOpError() << "types mismatch between yield op: " << resultType
561
0
                            << " and its parent: " << parentType;
562
0
  return success();
563
0
}
564
565
//===----------------------------------------------------------------------===//
566
// BranchOp
567
//===----------------------------------------------------------------------===//
568
569
/// Given a successor, try to collapse it to a new destination if it only
570
/// contains a passthrough unconditional branch. If the successor is
571
/// collapsable, `successor` and `successorOperands` are updated to reference
572
/// the new destination and values. `argStorage` is an optional storage to use
573
/// if operands to the collapsed successor need to be remapped.
574
static LogicalResult collapseBranch(Block *&successor,
575
                                    ValueRange &successorOperands,
576
0
                                    SmallVectorImpl<Value> &argStorage) {
577
0
  // Check that the successor only contains a unconditional branch.
578
0
  if (std::next(successor->begin()) != successor->end())
579
0
    return failure();
580
0
  // Check that the terminator is an unconditional branch.
581
0
  BranchOp successorBranch = dyn_cast<BranchOp>(successor->getTerminator());
582
0
  if (!successorBranch)
583
0
    return failure();
584
0
  // Check that the arguments are only used within the terminator.
585
0
  for (BlockArgument arg : successor->getArguments()) {
586
0
    for (Operation *user : arg.getUsers())
587
0
      if (user != successorBranch)
588
0
        return failure();
589
0
  }
590
0
  // Don't try to collapse branches to infinite loops.
591
0
  Block *successorDest = successorBranch.getDest();
592
0
  if (successorDest == successor)
593
0
    return failure();
594
0
595
0
  // Update the operands to the successor. If the branch parent has no
596
0
  // arguments, we can use the branch operands directly.
597
0
  OperandRange operands = successorBranch.getOperands();
598
0
  if (successor->args_empty()) {
599
0
    successor = successorDest;
600
0
    successorOperands = operands;
601
0
    return success();
602
0
  }
603
0
604
0
  // Otherwise, we need to remap any argument operands.
605
0
  for (Value operand : operands) {
606
0
    BlockArgument argOperand = operand.dyn_cast<BlockArgument>();
607
0
    if (argOperand && argOperand.getOwner() == successor)
608
0
      argStorage.push_back(successorOperands[argOperand.getArgNumber()]);
609
0
    else
610
0
      argStorage.push_back(operand);
611
0
  }
612
0
  successor = successorDest;
613
0
  successorOperands = argStorage;
614
0
  return success();
615
0
}
616
617
namespace {
618
/// Simplify a branch to a block that has a single predecessor. This effectively
619
/// merges the two blocks.
620
struct SimplifyBrToBlockWithSinglePred : public OpRewritePattern<BranchOp> {
621
  using OpRewritePattern<BranchOp>::OpRewritePattern;
622
623
  LogicalResult matchAndRewrite(BranchOp op,
624
0
                                PatternRewriter &rewriter) const override {
625
0
    // Check that the successor block has a single predecessor.
626
0
    Block *succ = op.getDest();
627
0
    Block *opParent = op.getOperation()->getBlock();
628
0
    if (succ == opParent || !llvm::hasSingleElement(succ->getPredecessors()))
629
0
      return failure();
630
0
631
0
    // Merge the successor into the current block and erase the branch.
632
0
    rewriter.mergeBlocks(succ, opParent, op.getOperands());
633
0
    rewriter.eraseOp(op);
634
0
    return success();
635
0
  }
636
};
637
638
///   br ^bb1
639
/// ^bb1
640
///   br ^bbN(...)
641
///
642
///  -> br ^bbN(...)
643
///
644
struct SimplifyPassThroughBr : public OpRewritePattern<BranchOp> {
645
  using OpRewritePattern<BranchOp>::OpRewritePattern;
646
647
  LogicalResult matchAndRewrite(BranchOp op,
648
0
                                PatternRewriter &rewriter) const override {
649
0
    Block *dest = op.getDest();
650
0
    ValueRange destOperands = op.getOperands();
651
0
    SmallVector<Value, 4> destOperandStorage;
652
0
653
0
    // Try to collapse the successor if it points somewhere other than this
654
0
    // block.
655
0
    if (dest == op.getOperation()->getBlock() ||
656
0
        failed(collapseBranch(dest, destOperands, destOperandStorage)))
657
0
      return failure();
658
0
659
0
    // Create a new branch with the collapsed successor.
660
0
    rewriter.replaceOpWithNewOp<BranchOp>(op, dest, destOperands);
661
0
    return success();
662
0
  }
663
};
664
} // end anonymous namespace.
665
666
0
Block *BranchOp::getDest() { return getSuccessor(); }
667
668
0
void BranchOp::setDest(Block *block) { return setSuccessor(block); }
669
670
0
void BranchOp::eraseOperand(unsigned index) {
671
0
  getOperation()->eraseOperand(index);
672
0
}
673
674
void BranchOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
675
0
                                           MLIRContext *context) {
676
0
  results.insert<SimplifyBrToBlockWithSinglePred, SimplifyPassThroughBr>(
677
0
      context);
678
0
}
679
680
Optional<MutableOperandRange>
681
0
BranchOp::getMutableSuccessorOperands(unsigned index) {
682
0
  assert(index == 0 && "invalid successor index");
683
0
  return destOperandsMutable();
684
0
}
685
686
0
Block *BranchOp::getSuccessorForOperands(ArrayRef<Attribute>) { return dest(); }
687
688
//===----------------------------------------------------------------------===//
689
// CallOp
690
//===----------------------------------------------------------------------===//
691
692
0
/// Verifies a `call`: it must carry a flat `callee` symbol that resolves to a
/// FuncOp in the enclosing module, and its operand and result types must match
/// that function's signature exactly and in order.
static LogicalResult verify(CallOp op) {
  auto calleeAttr = op.getAttrOfType<FlatSymbolRefAttr>("callee");
  if (!calleeAttr)
    return op.emitOpError("requires a 'callee' symbol reference attribute");

  StringRef calleeName = calleeAttr.getValue();
  auto callee =
      op.getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(calleeName);
  if (!callee)
    return op.emitOpError() << "'" << calleeName
                            << "' does not reference a valid function";

  auto calleeType = callee.getType();

  // Operands must line up one-for-one with the callee's inputs.
  unsigned numInputs = calleeType.getNumInputs();
  if (numInputs != op.getNumOperands())
    return op.emitOpError("incorrect number of operands for callee");
  for (unsigned idx = 0; idx < numInputs; ++idx) {
    if (op.getOperand(idx).getType() != calleeType.getInput(idx))
      return op.emitOpError("operand type mismatch");
  }

  // Results must line up one-for-one with the callee's results.
  unsigned numResults = calleeType.getNumResults();
  if (numResults != op.getNumResults())
    return op.emitOpError("incorrect number of results for callee");
  for (unsigned idx = 0; idx < numResults; ++idx) {
    if (op.getResult(idx).getType() != calleeType.getResult(idx))
      return op.emitOpError("result type mismatch");
  }

  return success();
}
721
722
0
/// Reconstructs the callee's function type from this call's own operand and
/// result types.
FunctionType CallOp::getCalleeType() {
  // Copy operand types into contiguous storage for FunctionType::get.
  SmallVector<Type, 8> inputTypes(getOperandTypes());
  MLIRContext *ctx = getContext();
  return FunctionType::get(inputTypes, getResultTypes(), ctx);
}
726
727
//===----------------------------------------------------------------------===//
728
// CallIndirectOp
729
//===----------------------------------------------------------------------===//
730
namespace {
731
/// Fold indirect calls that have a constant function as the callee operand.
732
struct SimplifyIndirectCallWithKnownCallee
733
    : public OpRewritePattern<CallIndirectOp> {
734
  using OpRewritePattern<CallIndirectOp>::OpRewritePattern;
735
736
  LogicalResult matchAndRewrite(CallIndirectOp indirectCall,
737
0
                                PatternRewriter &rewriter) const override {
738
0
    // Check that the callee is a constant callee.
739
0
    SymbolRefAttr calledFn;
740
0
    if (!matchPattern(indirectCall.getCallee(), m_Constant(&calledFn)))
741
0
      return failure();
742
0
743
0
    // Replace with a direct call.
744
0
    rewriter.replaceOpWithNewOp<CallOp>(indirectCall, calledFn,
745
0
                                        indirectCall.getResultTypes(),
746
0
                                        indirectCall.getArgOperands());
747
0
    return success();
748
0
  }
749
};
750
} // end anonymous namespace.
751
752
void CallIndirectOp::getCanonicalizationPatterns(
753
0
    OwningRewritePatternList &results, MLIRContext *context) {
754
0
  results.insert<SimplifyIndirectCallWithKnownCallee>(context);
755
0
}
756
757
//===----------------------------------------------------------------------===//
758
// General helpers for comparison ops
759
//===----------------------------------------------------------------------===//
760
761
// Return the type of the same shape (scalar, vector or tensor) containing i1.
762
0
static Type getI1SameShape(Type type) {
763
0
  auto i1Type = IntegerType::get(1, type.getContext());
764
0
  if (auto tensorType = type.dyn_cast<RankedTensorType>())
765
0
    return RankedTensorType::get(tensorType.getShape(), i1Type);
766
0
  if (type.isa<UnrankedTensorType>())
767
0
    return UnrankedTensorType::get(i1Type);
768
0
  if (auto vectorType = type.dyn_cast<VectorType>())
769
0
    return VectorType::get(vectorType.getShape(), i1Type);
770
0
  return i1Type;
771
0
}
772
773
//===----------------------------------------------------------------------===//
774
// CmpIOp
775
//===----------------------------------------------------------------------===//
776
777
static void buildCmpIOp(OpBuilder &build, OperationState &result,
778
0
                        CmpIPredicate predicate, Value lhs, Value rhs) {
779
0
  result.addOperands({lhs, rhs});
780
0
  result.types.push_back(getI1SameShape(lhs.getType()));
781
0
  result.addAttribute(CmpIOp::getPredicateAttrName(),
782
0
                      build.getI64IntegerAttr(static_cast<int64_t>(predicate)));
783
0
}
784
785
// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
786
// comparison predicates.
787
static bool applyCmpPredicate(CmpIPredicate predicate, const APInt &lhs,
788
0
                              const APInt &rhs) {
789
0
  switch (predicate) {
790
0
  case CmpIPredicate::eq:
791
0
    return lhs.eq(rhs);
792
0
  case CmpIPredicate::ne:
793
0
    return lhs.ne(rhs);
794
0
  case CmpIPredicate::slt:
795
0
    return lhs.slt(rhs);
796
0
  case CmpIPredicate::sle:
797
0
    return lhs.sle(rhs);
798
0
  case CmpIPredicate::sgt:
799
0
    return lhs.sgt(rhs);
800
0
  case CmpIPredicate::sge:
801
0
    return lhs.sge(rhs);
802
0
  case CmpIPredicate::ult:
803
0
    return lhs.ult(rhs);
804
0
  case CmpIPredicate::ule:
805
0
    return lhs.ule(rhs);
806
0
  case CmpIPredicate::ugt:
807
0
    return lhs.ugt(rhs);
808
0
  case CmpIPredicate::uge:
809
0
    return lhs.uge(rhs);
810
0
  }
811
0
  llvm_unreachable("unknown comparison predicate");
812
0
}
813
814
// Constant folding hook for comparisons.
815
0
OpFoldResult CmpIOp::fold(ArrayRef<Attribute> operands) {
816
0
  assert(operands.size() == 2 && "cmpi takes two arguments");
817
0
818
0
  auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
819
0
  auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
820
0
  if (!lhs || !rhs)
821
0
    return {};
822
0
823
0
  auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
824
0
  return IntegerAttr::get(IntegerType::get(1, getContext()), APInt(1, val));
825
0
}
826
827
//===----------------------------------------------------------------------===//
828
// CmpFOp
829
//===----------------------------------------------------------------------===//
830
831
static void buildCmpFOp(OpBuilder &build, OperationState &result,
832
0
                        CmpFPredicate predicate, Value lhs, Value rhs) {
833
0
  result.addOperands({lhs, rhs});
834
0
  result.types.push_back(getI1SameShape(lhs.getType()));
835
0
  result.addAttribute(CmpFOp::getPredicateAttrName(),
836
0
                      build.getI64IntegerAttr(static_cast<int64_t>(predicate)));
837
0
}
838
839
/// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point
840
/// comparison predicates.
841
static bool applyCmpPredicate(CmpFPredicate predicate, const APFloat &lhs,
842
0
                              const APFloat &rhs) {
843
0
  auto cmpResult = lhs.compare(rhs);
844
0
  switch (predicate) {
845
0
  case CmpFPredicate::AlwaysFalse:
846
0
    return false;
847
0
  case CmpFPredicate::OEQ:
848
0
    return cmpResult == APFloat::cmpEqual;
849
0
  case CmpFPredicate::OGT:
850
0
    return cmpResult == APFloat::cmpGreaterThan;
851
0
  case CmpFPredicate::OGE:
852
0
    return cmpResult == APFloat::cmpGreaterThan ||
853
0
           cmpResult == APFloat::cmpEqual;
854
0
  case CmpFPredicate::OLT:
855
0
    return cmpResult == APFloat::cmpLessThan;
856
0
  case CmpFPredicate::OLE:
857
0
    return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
858
0
  case CmpFPredicate::ONE:
859
0
    return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
860
0
  case CmpFPredicate::ORD:
861
0
    return cmpResult != APFloat::cmpUnordered;
862
0
  case CmpFPredicate::UEQ:
863
0
    return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
864
0
  case CmpFPredicate::UGT:
865
0
    return cmpResult == APFloat::cmpUnordered ||
866
0
           cmpResult == APFloat::cmpGreaterThan;
867
0
  case CmpFPredicate::UGE:
868
0
    return cmpResult == APFloat::cmpUnordered ||
869
0
           cmpResult == APFloat::cmpGreaterThan ||
870
0
           cmpResult == APFloat::cmpEqual;
871
0
  case CmpFPredicate::ULT:
872
0
    return cmpResult == APFloat::cmpUnordered ||
873
0
           cmpResult == APFloat::cmpLessThan;
874
0
  case CmpFPredicate::ULE:
875
0
    return cmpResult == APFloat::cmpUnordered ||
876
0
           cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
877
0
  case CmpFPredicate::UNE:
878
0
    return cmpResult != APFloat::cmpEqual;
879
0
  case CmpFPredicate::UNO:
880
0
    return cmpResult == APFloat::cmpUnordered;
881
0
  case CmpFPredicate::AlwaysTrue:
882
0
    return true;
883
0
  }
884
0
  llvm_unreachable("unknown comparison predicate");
885
0
}
886
887
// Constant folding hook for comparisons.
888
0
OpFoldResult CmpFOp::fold(ArrayRef<Attribute> operands) {
889
0
  assert(operands.size() == 2 && "cmpf takes two arguments");
890
0
891
0
  auto lhs = operands.front().dyn_cast_or_null<FloatAttr>();
892
0
  auto rhs = operands.back().dyn_cast_or_null<FloatAttr>();
893
0
894
0
  // TODO(gcmn) We could actually do some intelligent things if we know only one
895
0
  // of the operands, but it's inf or nan.
896
0
  if (!lhs || !rhs)
897
0
    return {};
898
0
899
0
  auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
900
0
  return IntegerAttr::get(IntegerType::get(1, getContext()), APInt(1, val));
901
0
}
902
903
//===----------------------------------------------------------------------===//
904
// CondBranchOp
905
//===----------------------------------------------------------------------===//
906
907
namespace {
908
/// cond_br true, ^bb1, ^bb2
909
///  -> br ^bb1
910
/// cond_br false, ^bb1, ^bb2
911
///  -> br ^bb2
912
///
913
struct SimplifyConstCondBranchPred : public OpRewritePattern<CondBranchOp> {
914
  using OpRewritePattern<CondBranchOp>::OpRewritePattern;
915
916
  LogicalResult matchAndRewrite(CondBranchOp condbr,
917
0
                                PatternRewriter &rewriter) const override {
918
0
    if (matchPattern(condbr.getCondition(), m_NonZero())) {
919
0
      // True branch taken.
920
0
      rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(),
921
0
                                            condbr.getTrueOperands());
922
0
      return success();
923
0
    } else if (matchPattern(condbr.getCondition(), m_Zero())) {
924
0
      // False branch taken.
925
0
      rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(),
926
0
                                            condbr.getFalseOperands());
927
0
      return success();
928
0
    }
929
0
    return failure();
930
0
  }
931
};
932
933
///   cond_br %cond, ^bb1, ^bb2
934
/// ^bb1
935
///   br ^bbN(...)
936
/// ^bb2
937
///   br ^bbK(...)
938
///
939
///  -> cond_br %cond, ^bbN(...), ^bbK(...)
940
///
941
struct SimplifyPassThroughCondBranch : public OpRewritePattern<CondBranchOp> {
942
  using OpRewritePattern<CondBranchOp>::OpRewritePattern;
943
944
  LogicalResult matchAndRewrite(CondBranchOp condbr,
945
0
                                PatternRewriter &rewriter) const override {
946
0
    Block *trueDest = condbr.trueDest(), *falseDest = condbr.falseDest();
947
0
    ValueRange trueDestOperands = condbr.getTrueOperands();
948
0
    ValueRange falseDestOperands = condbr.getFalseOperands();
949
0
    SmallVector<Value, 4> trueDestOperandStorage, falseDestOperandStorage;
950
0
951
0
    // Try to collapse one of the current successors.
952
0
    LogicalResult collapsedTrue =
953
0
        collapseBranch(trueDest, trueDestOperands, trueDestOperandStorage);
954
0
    LogicalResult collapsedFalse =
955
0
        collapseBranch(falseDest, falseDestOperands, falseDestOperandStorage);
956
0
    if (failed(collapsedTrue) && failed(collapsedFalse))
957
0
      return failure();
958
0
959
0
    // Create a new branch with the collapsed successors.
960
0
    rewriter.replaceOpWithNewOp<CondBranchOp>(condbr, condbr.getCondition(),
961
0
                                              trueDest, trueDestOperands,
962
0
                                              falseDest, falseDestOperands);
963
0
    return success();
964
0
  }
965
};
966
967
/// cond_br %cond, ^bb1(A, ..., N), ^bb1(A, ..., N)
968
///  -> br ^bb1(A, ..., N)
969
///
970
/// cond_br %cond, ^bb1(A), ^bb1(B)
971
///  -> %select = select %cond, A, B
972
///     br ^bb1(%select)
973
///
974
struct SimplifyCondBranchIdenticalSuccessors
975
    : public OpRewritePattern<CondBranchOp> {
976
  using OpRewritePattern<CondBranchOp>::OpRewritePattern;
977
978
  LogicalResult matchAndRewrite(CondBranchOp condbr,
979
0
                                PatternRewriter &rewriter) const override {
980
0
    // Check that the true and false destinations are the same and have the same
981
0
    // operands.
982
0
    Block *trueDest = condbr.trueDest();
983
0
    if (trueDest != condbr.falseDest())
984
0
      return failure();
985
0
986
0
    // If all of the operands match, no selects need to be generated.
987
0
    OperandRange trueOperands = condbr.getTrueOperands();
988
0
    OperandRange falseOperands = condbr.getFalseOperands();
989
0
    if (trueOperands == falseOperands) {
990
0
      rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, trueOperands);
991
0
      return success();
992
0
    }
993
0
994
0
    // Otherwise, if the current block is the only predecessor insert selects
995
0
    // for any mismatched branch operands.
996
0
    if (trueDest->getUniquePredecessor() != condbr.getOperation()->getBlock())
997
0
      return failure();
998
0
999
0
    // Generate a select for any operands that differ between the two.
1000
0
    SmallVector<Value, 8> mergedOperands;
1001
0
    mergedOperands.reserve(trueOperands.size());
1002
0
    Value condition = condbr.getCondition();
1003
0
    for (auto it : llvm::zip(trueOperands, falseOperands)) {
1004
0
      if (std::get<0>(it) == std::get<1>(it))
1005
0
        mergedOperands.push_back(std::get<0>(it));
1006
0
      else
1007
0
        mergedOperands.push_back(rewriter.create<SelectOp>(
1008
0
            condbr.getLoc(), condition, std::get<0>(it), std::get<1>(it)));
1009
0
    }
1010
0
1011
0
    rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, mergedOperands);
1012
0
    return success();
1013
0
  }
1014
};
1015
} // end anonymous namespace
1016
1017
void CondBranchOp::getCanonicalizationPatterns(
1018
0
    OwningRewritePatternList &results, MLIRContext *context) {
1019
0
  results.insert<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch,
1020
0
                 SimplifyCondBranchIdenticalSuccessors>(context);
1021
0
}
1022
1023
/// Returns the mutable operand range forwarded to successor `index`: the true
/// operands for `trueIndex`, the false operands otherwise.
Optional<MutableOperandRange>
CondBranchOp::getMutableSuccessorOperands(unsigned index) {
  assert(index < getNumSuccessors() && "invalid successor index");
  if (index == trueIndex)
    return trueDestOperandsMutable();
  return falseDestOperandsMutable();
}
1029
1030
0
/// Returns the successor selected by a constant condition, or nullptr when
/// the condition is not a known BoolAttr/IntegerAttr. An IntegerAttr selects
/// the true edge only when its value is exactly one.
Block *CondBranchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
  Attribute condition = operands.front();
  bool takeTrueEdge;
  if (auto boolAttr = condition.dyn_cast_or_null<BoolAttr>())
    takeTrueEdge = boolAttr.getValue();
  else if (auto intAttr = condition.dyn_cast_or_null<IntegerAttr>())
    takeTrueEdge = intAttr.getValue().isOneValue();
  else
    return nullptr;
  return takeTrueEdge ? trueDest() : falseDest();
}
1037
1038
//===----------------------------------------------------------------------===//
1039
// Constant*Op
1040
//===----------------------------------------------------------------------===//
1041
1042
0
/// Prints `constant {attrs} value`, adding a trailing `: type` when the value
/// is a symbol reference.
static void print(OpAsmPrinter &p, ConstantOp &op) {
  Attribute value = op.getValue();

  p << "constant ";
  // `value` is printed inline, so keep it out of the attribute dictionary.
  p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"value"});
  if (op.getAttrs().size() > 1)
    p << ' ';
  p << value;

  if (value.isa<SymbolRefAttr>())
    p << " : " << op.getType();
}
1054
1055
/// Parses `constant {attrs} value [: type]`. The explicit type is required
/// only when the value is a symbol reference. Otherwise the result type is the
/// attribute's own type.
static ParseResult parseConstantOp(OpAsmParser &parser,
                                   OperationState &result) {
  Attribute valueAttr;
  bool parseFailed =
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseAttribute(valueAttr, "value", result.attributes);
  if (parseFailed)
    return failure();

  Type type;
  if (valueAttr.isa<SymbolRefAttr>()) {
    if (parser.parseColonType(type))
      return failure();
  } else {
    type = valueAttr.getType();
  }

  return parser.addTypeToList(type, result.types);
}
1072
1073
/// The constant op requires an attribute, and furthermore requires that it
1074
/// matches the return type.
1075
0
static LogicalResult verify(ConstantOp &op) {
1076
0
  auto value = op.getValue();
1077
0
  if (!value)
1078
0
    return op.emitOpError("requires a 'value' attribute");
1079
0
1080
0
  auto type = op.getType();
1081
0
  if (!value.getType().isa<NoneType>() && type != value.getType())
1082
0
    return op.emitOpError() << "requires attribute's type (" << value.getType()
1083
0
                            << ") to match op's return type (" << type << ")";
1084
0
1085
0
  if (type.isa<IndexType>() || value.isa<BoolAttr>())
1086
0
    return success();
1087
0
1088
0
  if (auto intAttr = value.dyn_cast<IntegerAttr>()) {
1089
0
    // If the type has a known bitwidth we verify that the value can be
1090
0
    // represented with the given bitwidth.
1091
0
    auto bitwidth = type.cast<IntegerType>().getWidth();
1092
0
    auto intVal = intAttr.getValue();
1093
0
    if (!intVal.isSignedIntN(bitwidth) && !intVal.isIntN(bitwidth))
1094
0
      return op.emitOpError("requires 'value' to be an integer within the "
1095
0
                            "range of the integer result type");
1096
0
    return success();
1097
0
  }
1098
0
1099
0
  if (type.isa<FloatType>()) {
1100
0
    if (!value.isa<FloatAttr>())
1101
0
      return op.emitOpError("requires 'value' to be a floating point constant");
1102
0
    return success();
1103
0
  }
1104
0
1105
0
  if (type.isa<ShapedType>()) {
1106
0
    if (!value.isa<ElementsAttr>())
1107
0
      return op.emitOpError("requires 'value' to be a shaped constant");
1108
0
    return success();
1109
0
  }
1110
0
1111
0
  if (type.isa<FunctionType>()) {
1112
0
    auto fnAttr = value.dyn_cast<FlatSymbolRefAttr>();
1113
0
    if (!fnAttr)
1114
0
      return op.emitOpError("requires 'value' to be a function reference");
1115
0
1116
0
    // Try to find the referenced function.
1117
0
    auto fn =
1118
0
        op.getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(fnAttr.getValue());
1119
0
    if (!fn)
1120
0
      return op.emitOpError("reference to undefined function 'bar'");
1121
0
1122
0
    // Check that the referenced function has the correct type.
1123
0
    if (fn.getType() != type)
1124
0
      return op.emitOpError("reference to function with mismatched type");
1125
0
1126
0
    return success();
1127
0
  }
1128
0
1129
0
  if (type.isa<NoneType>() && value.isa<UnitAttr>())
1130
0
    return success();
1131
0
1132
0
  return op.emitOpError("unsupported 'value' attribute: ") << value;
1133
0
}
1134
1135
0
/// A constant always folds to its own `value` attribute.
OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.empty() && "constant has no operands");
  Attribute value = getValue();
  return value;
}
1139
1140
void ConstantOp::getAsmResultNames(
1141
0
    function_ref<void(Value, StringRef)> setNameFn) {
1142
0
  Type type = getType();
1143
0
  if (auto intCst = getValue().dyn_cast<IntegerAttr>()) {
1144
0
    IntegerType intTy = type.dyn_cast<IntegerType>();
1145
0
1146
0
    // Sugar i1 constants with 'true' and 'false'.
1147
0
    if (intTy && intTy.getWidth() == 1)
1148
0
      return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
1149
0
1150
0
    // Otherwise, build a complex name with the value and type.
1151
0
    SmallString<32> specialNameBuffer;
1152
0
    llvm::raw_svector_ostream specialName(specialNameBuffer);
1153
0
    specialName << 'c' << intCst.getInt();
1154
0
    if (intTy)
1155
0
      specialName << '_' << type;
1156
0
    setNameFn(getResult(), specialName.str());
1157
0
1158
0
  } else if (type.isa<FunctionType>()) {
1159
0
    setNameFn(getResult(), "f");
1160
0
  } else {
1161
0
    setNameFn(getResult(), "cst");
1162
0
  }
1163
0
}
1164
1165
/// Returns true if a constant operation can be built with the given value and
1166
/// result type.
1167
0
bool ConstantOp::isBuildableWith(Attribute value, Type type) {
1168
0
  // SymbolRefAttr can only be used with a function type.
1169
0
  if (value.isa<SymbolRefAttr>())
1170
0
    return type.isa<FunctionType>();
1171
0
  // Otherwise, the attribute must have the same type as 'type'.
1172
0
  if (value.getType() != type)
1173
0
    return false;
1174
0
  // Finally, check that the attribute kind is handled.
1175
0
  return value.isa<BoolAttr>() || value.isa<IntegerAttr>() ||
1176
0
         value.isa<FloatAttr>() || value.isa<ElementsAttr>() ||
1177
0
         value.isa<UnitAttr>();
1178
0
}
1179
1180
void ConstantFloatOp::build(OpBuilder &builder, OperationState &result,
1181
0
                            const APFloat &value, FloatType type) {
1182
0
  ConstantOp::build(builder, result, type, builder.getFloatAttr(type, value));
1183
0
}
1184
1185
0
/// ConstantFloatOp matches `constant` ops whose result type is a FloatType.
bool ConstantFloatOp::classof(Operation *op) {
  if (!ConstantOp::classof(op))
    return false;
  Type resultType = op->getResult(0).getType();
  return resultType.isa<FloatType>();
}
1188
1189
/// ConstantIntOp only matches values whose result type is an IntegerType.
1190
0
bool ConstantIntOp::classof(Operation *op) {
1191
0
  return ConstantOp::classof(op) &&
1192
0
         op->getResult(0).getType().isSignlessInteger();
1193
0
}
1194
1195
void ConstantIntOp::build(OpBuilder &builder, OperationState &result,
1196
0
                          int64_t value, unsigned width) {
1197
0
  Type type = builder.getIntegerType(width);
1198
0
  ConstantOp::build(builder, result, type, builder.getIntegerAttr(type, value));
1199
0
}
1200
1201
/// Build a constant int op producing an integer with the specified type,
1202
/// which must be an integer type.
1203
void ConstantIntOp::build(OpBuilder &builder, OperationState &result,
1204
0
                          int64_t value, Type type) {
1205
0
  assert(type.isSignlessInteger() &&
1206
0
         "ConstantIntOp can only have signless integer type");
1207
0
  ConstantOp::build(builder, result, type, builder.getIntegerAttr(type, value));
1208
0
}
1209
1210
/// ConstantIndexOp only matches values whose result type is Index.
1211
0
bool ConstantIndexOp::classof(Operation *op) {
1212
0
  return ConstantOp::classof(op) && op->getResult(0).getType().isIndex();
1213
0
}
1214
1215
void ConstantIndexOp::build(OpBuilder &builder, OperationState &result,
1216
0
                            int64_t value) {
1217
0
  Type type = builder.getIndexType();
1218
0
  ConstantOp::build(builder, result, type, builder.getIntegerAttr(type, value));
1219
0
}
1220
1221
//===----------------------------------------------------------------------===//
1222
// DeallocOp
1223
//===----------------------------------------------------------------------===//
1224
namespace {
1225
/// Fold Dealloc operations that are deallocating an AllocOp that is only used
1226
/// by other Dealloc operations.
1227
struct SimplifyDeadDealloc : public OpRewritePattern<DeallocOp> {
1228
  using OpRewritePattern<DeallocOp>::OpRewritePattern;
1229
1230
  LogicalResult matchAndRewrite(DeallocOp dealloc,
1231
0
                                PatternRewriter &rewriter) const override {
1232
0
    // Check that the memref operand's defining operation is an AllocOp.
1233
0
    Value memref = dealloc.memref();
1234
0
    if (!isa_and_nonnull<AllocOp>(memref.getDefiningOp()))
1235
0
      return failure();
1236
0
1237
0
    // Check that all of the uses of the AllocOp are other DeallocOps.
1238
0
    for (auto *user : memref.getUsers())
1239
0
      if (!isa<DeallocOp>(user))
1240
0
        return failure();
1241
0
1242
0
    // Erase the dealloc operation.
1243
0
    rewriter.eraseOp(dealloc);
1244
0
    return success();
1245
0
  }
1246
};
1247
} // end anonymous namespace.
1248
1249
0
/// Verifies that `dealloc` is applied to a memref-typed operand.
static LogicalResult verify(DeallocOp op) {
  Type operandType = op.memref().getType();
  if (operandType.isa<MemRefType>())
    return success();
  return op.emitOpError("operand must be a memref");
}
1254
1255
void DeallocOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
1256
0
                                            MLIRContext *context) {
1257
0
  results.insert<SimplifyDeadDealloc>(context);
1258
0
}
1259
1260
/// In-place fold: dealloc(memref_cast(x)) -> dealloc(x), delegated to the
/// shared `foldMemRefCast` helper.
LogicalResult DeallocOp::fold(ArrayRef<Attribute> cstOperands,
                              SmallVectorImpl<OpFoldResult> &results) {
  LogicalResult folded = foldMemRefCast(*this);
  return folded;
}
1265
1266
//===----------------------------------------------------------------------===//
1267
// DimOp
1268
//===----------------------------------------------------------------------===//
1269
1270
0
/// Prints `dim %source, index {attrs} : type`.
static void print(OpAsmPrinter &p, DimOp op) {
  Value source = op.getOperand();
  p << "dim " << source << ", " << op.getIndex();
  // The index is printed inline, so keep it out of the attribute dictionary.
  p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"index"});
  p << " : " << source.getType();
}
1275
1276
0
/// Parses `dim %source, index {attrs} : type`. The index is stored as an
/// index-typed `index` attribute, and the result is always `index`.
static ParseResult parseDimOp(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::OperandType source;
  IntegerAttr indexAttr;
  Type sourceType;
  Type indexType = parser.getBuilder().getIndexType();

  if (parser.parseOperand(source) || parser.parseComma() ||
      parser.parseAttribute(indexAttr, indexType, "index",
                            result.attributes) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(sourceType) ||
      parser.resolveOperand(source, sourceType, result.operands))
    return failure();

  return parser.addTypeToList(indexType, result.types);
}
1290
1291
0
/// Verifies a `dim`. It needs an integer `index` attribute and a ranked
/// tensor, memref, or unranked tensor operand. For ranked operands the index
/// must lie in [0, rank). For unranked tensors only non-negativity can be
/// checked.
static LogicalResult verify(DimOp op) {
  // Check that we have an integer index operand.
  auto indexAttr = op.getAttrOfType<IntegerAttr>("index");
  if (!indexAttr)
    return op.emitOpError("requires an integer attribute named 'index'");
  int64_t index = indexAttr.getInt();

  // Negative indices were previously accepted because only the upper bound
  // was checked. They are now rejected for every operand type.
  auto type = op.getOperand().getType();
  if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
    if (index < 0 || index >= tensorType.getRank())
      return op.emitOpError("index is out of range");
  } else if (auto memrefType = type.dyn_cast<MemRefType>()) {
    if (index < 0 || index >= memrefType.getRank())
      return op.emitOpError("index is out of range");
  } else if (type.isa<UnrankedTensorType>()) {
    // The rank is unknown, so only the lower bound can be checked here.
    if (index < 0)
      return op.emitOpError("index is out of range");
  } else {
    return op.emitOpError("requires an operand with tensor or memref type");
  }

  return success();
}
1314
1315
0
/// Folds `dim` as follows:
///   - a static dimension of a ranked shaped operand becomes an index
///     attribute;
///   - a dynamic memref dimension produced by alloc/view/subview becomes the
///     corresponding dynamic size operand;
///   - dim(memref_cast(x)) is rewritten in place via `foldMemRefCast`.
OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
  // Constant fold dim when the size along the index referred to is a constant.
  // Unranked operands (e.g. the unranked tensors that verify() accepts) have
  // no static shape, and ShapedType shape queries require a ranked type, so
  // they must be skipped here.
  auto opType = memrefOrTensor().getType();
  if (auto shapedType = opType.dyn_cast<ShapedType>())
    if (shapedType.hasRank() && !shapedType.isDynamicDim(getIndex()))
      return IntegerAttr::get(IndexType::get(getContext()),
                              shapedType.getShape()[getIndex()]);

  // Fold dim to the size argument for an AllocOp/ViewOp/SubViewOp.
  auto memrefType = opType.dyn_cast<MemRefType>();
  if (!memrefType)
    return {};

  // The size at getIndex() is now known to be a dynamic size of a memref.
  auto memref = memrefOrTensor().getDefiningOp();
  if (auto alloc = dyn_cast_or_null<AllocOp>(memref))
    return *(alloc.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(getIndex()));

  if (auto view = dyn_cast_or_null<ViewOp>(memref))
    return *(view.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(getIndex()));

  if (auto subview = dyn_cast_or_null<SubViewOp>(memref)) {
    assert(subview.isDynamicSize(getIndex()) &&
           "Expected dynamic subview size");
    return subview.getDynamicSize(getIndex());
  }

  /// dim(memrefcast) -> dim
  if (succeeded(foldMemRefCast(*this)))
    return getResult();

  return {};
}
1350
1351
// ---------------------------------------------------------------------------
1352
// DmaStartOp
1353
// ---------------------------------------------------------------------------
1354
1355
/// Builds a `dma_start`. Operands are appended in this order: source memref
/// and its indices, destination memref and its indices, element count, tag
/// memref and its indices. When `stride` is non-null, `stride` and
/// `elementsPerStride` follow.
void DmaStartOp::build(OpBuilder &builder, OperationState &result,
                       Value srcMemRef, ValueRange srcIndices, Value destMemRef,
                       ValueRange destIndices, Value numElements,
                       Value tagMemRef, ValueRange tagIndices, Value stride,
                       Value elementsPerStride) {
  result.addOperands(srcMemRef);
  result.addOperands(srcIndices);
  result.addOperands(destMemRef);
  result.addOperands(destIndices);
  result.addOperands(numElements);
  result.addOperands(tagMemRef);
  result.addOperands(tagIndices);
  if (!stride)
    return;
  result.addOperands(stride);
  result.addOperands(elementsPerStride);
}
1369
1370
0
/// Prints `dma_start %src[...], %dst[...], %n, %tag[...] [, %stride, %eps]
/// {attrs} : src-type, dst-type, tag-type`.
void DmaStartOp::print(OpAsmPrinter &p) {
  p << "dma_start ";
  p << getSrcMemRef() << '[' << getSrcIndices() << "], ";
  p << getDstMemRef() << '[' << getDstIndices() << "], ";
  p << getNumElements() << ", ";
  p << getTagMemRef() << '[' << getTagIndices() << ']';
  if (isStrided())
    p << ", " << getStride() << ", " << getNumElementsPerStride();

  p.printOptionalAttrDict(getAttrs());
  p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType()
    << ", " << getTagMemRef().getType();
}
1381
1382
// Parse DmaStartOp.
1383
// Ex:
1384
//   %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size,
1385
//                       %tag[%index], %stride, %num_elt_per_stride
1386
//                     : memref<3076 x f32, 0>,
1387
//                       memref<1024 x f32, 2>,
1388
//                       memref<1 x i32>
1389
//
1390
0
ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) {
1391
0
  OpAsmParser::OperandType srcMemRefInfo;
1392
0
  SmallVector<OpAsmParser::OperandType, 4> srcIndexInfos;
1393
0
  OpAsmParser::OperandType dstMemRefInfo;
1394
0
  SmallVector<OpAsmParser::OperandType, 4> dstIndexInfos;
1395
0
  OpAsmParser::OperandType numElementsInfo;
1396
0
  OpAsmParser::OperandType tagMemrefInfo;
1397
0
  SmallVector<OpAsmParser::OperandType, 4> tagIndexInfos;
1398
0
  SmallVector<OpAsmParser::OperandType, 2> strideInfo;
1399
0
1400
0
  SmallVector<Type, 3> types;
1401
0
  auto indexType = parser.getBuilder().getIndexType();
1402
0
1403
0
  // Parse and resolve the following list of operands:
1404
0
  // *) source memref followed by its indices (in square brackets).
1405
0
  // *) destination memref followed by its indices (in square brackets).
1406
0
  // *) dma size in KiB.
1407
0
  if (parser.parseOperand(srcMemRefInfo) ||
1408
0
      parser.parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) ||
1409
0
      parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
1410
0
      parser.parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) ||
1411
0
      parser.parseComma() || parser.parseOperand(numElementsInfo) ||
1412
0
      parser.parseComma() || parser.parseOperand(tagMemrefInfo) ||
1413
0
      parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square))
1414
0
    return failure();
1415
0
1416
0
  // Parse optional stride and elements per stride.
1417
0
  if (parser.parseTrailingOperandList(strideInfo))
1418
0
    return failure();
1419
0
1420
0
  bool isStrided = strideInfo.size() == 2;
1421
0
  if (!strideInfo.empty() && !isStrided) {
1422
0
    return parser.emitError(parser.getNameLoc(),
1423
0
                            "expected two stride related operands");
1424
0
  }
1425
0
1426
0
  if (parser.parseColonTypeList(types))
1427
0
    return failure();
1428
0
  if (types.size() != 3)
1429
0
    return parser.emitError(parser.getNameLoc(), "fewer/more types expected");
1430
0
1431
0
  if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
1432
0
      parser.resolveOperands(srcIndexInfos, indexType, result.operands) ||
1433
0
      parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
1434
0
      parser.resolveOperands(dstIndexInfos, indexType, result.operands) ||
1435
0
      // size should be an index.
1436
0
      parser.resolveOperand(numElementsInfo, indexType, result.operands) ||
1437
0
      parser.resolveOperand(tagMemrefInfo, types[2], result.operands) ||
1438
0
      // tag indices should be index.
1439
0
      parser.resolveOperands(tagIndexInfos, indexType, result.operands))
1440
0
    return failure();
1441
0
1442
0
  if (isStrided) {
1443
0
    if (parser.resolveOperands(strideInfo, indexType, result.operands))
1444
0
      return failure();
1445
0
  }
1446
0
1447
0
  return success();
1448
0
}
1449
1450
0
/// Verifies a dma_start op. The operand list is laid out as
///   src memref, src indices, dst memref, dst indices, num elements,
///   tag memref, tag indices, [stride, num elements per stride]
/// where each index group has as many entries as the rank of its memref.
/// Later operand positions depend on the ranks of earlier memrefs, so the
/// checks below must run in this order.
LogicalResult DmaStartOp::verify() {
  unsigned numOperands = getNumOperands();

  // Mandatory non-variadic operands are: src memref, dst memref, tag memref and
  // the number of elements.
  if (numOperands < 4)
    return emitOpError("expected at least 4 operands");

  // `llvm::all_of` is vacuously true for the empty index list of a rank-0
  // memref, so the index-type checks need no separate emptiness test.
  auto isIndex = [](Type t) { return t.isIndex(); };

  // 1. Source memref.
  if (!getSrcMemRef().getType().isa<MemRefType>())
    return emitOpError("expected source to be of memref type");
  if (numOperands < getSrcMemRefRank() + 4)
    return emitOpError() << "expected at least " << getSrcMemRefRank() + 4
                         << " operands";
  if (!llvm::all_of(getSrcIndices().getTypes(), isIndex))
    return emitOpError("expected source indices to be of index type");

  // 2. Destination memref.
  if (!getDstMemRef().getType().isa<MemRefType>())
    return emitOpError("expected destination to be of memref type");
  unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4;
  if (numOperands < numExpectedOperands)
    return emitOpError() << "expected at least " << numExpectedOperands
                         << " operands";
  if (!llvm::all_of(getDstIndices().getTypes(), isIndex))
    return emitOpError("expected destination indices to be of index type");

  // 3. Number of elements.
  if (!getNumElements().getType().isIndex())
    return emitOpError("expected num elements to be of index type");

  // 4. Tag memref.
  if (!getTagMemRef().getType().isa<MemRefType>())
    return emitOpError("expected tag to be of memref type");
  numExpectedOperands += getTagMemRefRank();
  if (numOperands < numExpectedOperands)
    return emitOpError() << "expected at least " << numExpectedOperands
                         << " operands";
  if (!llvm::all_of(getTagIndices().getTypes(), isIndex))
    return emitOpError("expected tag indices to be of index type");

  // Only DMAs between different memory spaces are supported.
  if (getSrcMemorySpace() == getDstMemorySpace())
    return emitOpError("DMA should be between different memory spaces");

  // Optional stride-related operands must be either both present or both
  // absent.
  if (numOperands != numExpectedOperands &&
      numOperands != numExpectedOperands + 2)
    return emitOpError("incorrect number of operands");

  // 5. Strides.
  if (isStrided()) {
    if (!getStride().getType().isIndex() ||
        !getNumElementsPerStride().getType().isIndex())
      return emitOpError(
          "expected stride and num elements per stride to be of type index");
  }

  return success();
}
1519
1520
/// Folds memref_cast producers of the operands into this op in place:
///   dma_start(memrefcast) -> dma_start
/// Neither `cstOperands` nor `results` is consulted; the returned status is
/// exactly what foldMemRefCast reports.
LogicalResult DmaStartOp::fold(ArrayRef<Attribute> cstOperands,
                               SmallVectorImpl<OpFoldResult> &results) {
  LogicalResult castFolded = foldMemRefCast(*this);
  return castFolded;
}
1525
1526
// ---------------------------------------------------------------------------
1527
// DmaWaitOp
1528
// ---------------------------------------------------------------------------
1529
1530
void DmaWaitOp::build(OpBuilder &builder, OperationState &result,
1531
                      Value tagMemRef, ValueRange tagIndices,
1532
0
                      Value numElements) {
1533
0
  result.addOperands(tagMemRef);
1534
0
  result.addOperands(tagIndices);
1535
0
  result.addOperands(numElements);
1536
0
}
1537
1538
0
/// Prints `dma_wait %tag[indices], %num attr-dict : tag-type`.
void DmaWaitOp::print(OpAsmPrinter &p) {
  Value tag = getTagMemRef();
  p << "dma_wait " << tag << '[' << getTagIndices() << "], ";
  p << getNumElements();
  p.printOptionalAttrDict(getAttrs());
  p << " : " << tag.getType();
}
1544
1545
// Parse DmaWaitOp.
1546
// Eg:
1547
//   dma_wait %tag[%index], %num_elements : memref<1 x i32, (d0) -> (d0), 4>
1548
//
1549
0
ParseResult DmaWaitOp::parse(OpAsmParser &parser, OperationState &result) {
1550
0
  OpAsmParser::OperandType tagMemrefInfo;
1551
0
  SmallVector<OpAsmParser::OperandType, 2> tagIndexInfos;
1552
0
  Type type;
1553
0
  auto indexType = parser.getBuilder().getIndexType();
1554
0
  OpAsmParser::OperandType numElementsInfo;
1555
0
1556
0
  // Parse tag memref, its indices, and dma size.
1557
0
  if (parser.parseOperand(tagMemrefInfo) ||
1558
0
      parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square) ||
1559
0
      parser.parseComma() || parser.parseOperand(numElementsInfo) ||
1560
0
      parser.parseColonType(type) ||
1561
0
      parser.resolveOperand(tagMemrefInfo, type, result.operands) ||
1562
0
      parser.resolveOperands(tagIndexInfos, indexType, result.operands) ||
1563
0
      parser.resolveOperand(numElementsInfo, indexType, result.operands))
1564
0
    return failure();
1565
0
1566
0
  return success();
1567
0
}
1568
1569
/// Folds memref_cast producers of the operands into this op in place:
///   dma_wait(memrefcast) -> dma_wait
/// Neither `cstOperands` nor `results` is consulted; the returned status is
/// exactly what foldMemRefCast reports.
LogicalResult DmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
                              SmallVectorImpl<OpFoldResult> &results) {
  LogicalResult castFolded = foldMemRefCast(*this);
  return castFolded;
}
1574
1575
0
/// Verifies a dma_wait op, whose operands are the tag memref, one index per
/// tag dimension, and the number of elements.
LogicalResult DmaWaitOp::verify() {
  // Mandatory non-variadic operands are tag and the number of elements.
  if (getNumOperands() < 2)
    return emitOpError() << "expected at least 2 operands";

  // Check types of operands. The order of these calls is important: the later
  // calls rely on some type properties to compute the operand position.
  if (!getTagMemRef().getType().isa<MemRefType>())
    return emitOpError() << "expected tag to be of memref type";

  if (getNumOperands() != 2 + getTagMemRefRank())
    return emitOpError() << "expected " << 2 + getTagMemRefRank()
                         << " operands";

  // `llvm::all_of` is vacuously true for a rank-0 tag (no indices), so no
  // separate emptiness test is needed.
  if (!llvm::all_of(getTagIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError() << "expected tag indices to be of index type";

  if (!getNumElements().getType().isIndex())
    return emitOpError()
           << "expected the number of elements to be of index type";

  return success();
}
1600
1601
//===----------------------------------------------------------------------===//
1602
// ExtractElementOp
1603
//===----------------------------------------------------------------------===//
1604
1605
0
/// For a ranked aggregate, requires exactly one index operand per dimension.
/// Unranked aggregates are accepted without an index-count check.
static LogicalResult verify(ExtractElementOp op) {
  auto aggregateType = op.getAggregate().getType().cast<ShapedType>();
  if (!aggregateType.hasRank())
    return success();

  // Every operand after the aggregate is an index.
  int64_t numIndices = op.getNumOperands() - 1;
  if (aggregateType.getRank() == numIndices)
    return success();
  return op.emitOpError("incorrect number of indices for extract_element");
}
1614
1615
0
/// Constant-folds extract_element when the aggregate operand is constant.
/// A splat folds to its single value regardless of the indices. Otherwise
/// every index must be a constant IntegerAttr and together they must form a
/// valid position in the ElementsAttr; if not, nothing is folded.
OpFoldResult ExtractElementOp::fold(ArrayRef<Attribute> operands) {
  assert(!operands.empty() && "extract_element takes at least one operand");

  Attribute aggregate = operands.front();
  if (!aggregate)
    return {};

  // Splats hold the same value everywhere, so the indices are irrelevant.
  if (auto splat = aggregate.dyn_cast<SplatElementsAttr>())
    return splat.getSplatValue();

  SmallVector<uint64_t, 8> position;
  for (Attribute idxAttr : operands.drop_front()) {
    auto intAttr = idxAttr.dyn_cast_or_null<IntegerAttr>();
    if (!intAttr)
      return {};
    position.push_back(intAttr.getInt());
  }

  auto elements = aggregate.dyn_cast<ElementsAttr>();
  if (!elements || !elements.isValidIndex(position))
    return {};
  return elements.getValue(position);
}
1642
1643
//===----------------------------------------------------------------------===//
1644
// TensorFromElementsOp
1645
//===----------------------------------------------------------------------===//
1646
1647
/// Parses `tensor_from_elements(%a, %b, ...) attr-dict : result-type`.
/// The element operands are resolved against the element type of the result,
/// so the result must be a shaped type. Any other type is reported as a parse
/// error rather than tripping the assertion inside `cast<ShapedType>()`.
static ParseResult parseTensorFromElementsOp(OpAsmParser &parser,
                                             OperationState &result) {
  SmallVector<OpAsmParser::OperandType, 4> elementsOperands;
  Type resultType;
  if (parser.parseLParen() || parser.parseOperandList(elementsOperands) ||
      parser.parseRParen() || parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColon() || parser.parseType(resultType))
    return failure();

  auto shapedResultType = resultType.dyn_cast<ShapedType>();
  if (!shapedResultType)
    return parser.emitError(parser.getNameLoc(),
                            "expected result type to be a shaped type");

  if (parser.resolveOperands(elementsOperands,
                             shapedResultType.getElementType(),
                             result.operands))
    return failure();

  result.addTypes(resultType);
  return success();
}
1664
1665
0
/// Prints `tensor_from_elements(%a, %b, ...) attr-dict : result-type`.
static void print(OpAsmPrinter &p, TensorFromElementsOp op) {
  p << "tensor_from_elements(";
  p << op.elements();
  p << ')';
  p.printOptionalAttrDict(op.getAttrs());
  p << " : ";
  p << op.result().getType();
}
1670
1671
0
/// Requires the result to be a rank-1 ranked tensor whose single dimension
/// equals the number of element operands.
static LogicalResult verify(TensorFromElementsOp op) {
  auto tensorType = op.result().getType().dyn_cast<RankedTensorType>();
  if (!tensorType)
    return op.emitOpError("expected result type to be a ranked tensor");

  const int64_t numElements = static_cast<int64_t>(op.elements().size());
  // The rank test short-circuits, so front() is only read on a 1-D shape.
  bool shapeMatches = tensorType.getRank() == 1 &&
                      tensorType.getShape().front() == numElements;
  if (shapeMatches)
    return success();

  const char *noun = numElements == 1 ? " element" : " elements";
  return op.emitOpError() << "expected result type to be a 1D tensor with "
                          << numElements << noun;
}
1684
1685
namespace {
1686
1687
// Canonicalizes the pattern of the form
1688
//
1689
// %tensor = "tensor_from_elements(%element) : (i32) -> tensor<1xi32>
1690
// %extracted_element = extract_element %tensor[%c0] : tensor<1xi32>
1691
//
1692
// to just %element.
1693
struct ExtractElementFromTensorFromElements
1694
    : public OpRewritePattern<ExtractElementOp> {
1695
  using OpRewritePattern<ExtractElementOp>::OpRewritePattern;
1696
1697
  LogicalResult matchAndRewrite(ExtractElementOp extract,
1698
0
                                PatternRewriter &rewriter) const final {
1699
0
    if (extract.indices().size() != 1)
1700
0
      return failure();
1701
0
1702
0
    auto tensor_from_elements =
1703
0
        dyn_cast<TensorFromElementsOp>(extract.aggregate().getDefiningOp());
1704
0
    if (tensor_from_elements == nullptr)
1705
0
      return failure();
1706
0
1707
0
    APInt index;
1708
0
    if (!matchPattern(*extract.indices().begin(), m_ConstantInt(&index)))
1709
0
      return failure();
1710
0
    rewriter.replaceOp(extract,
1711
0
                       tensor_from_elements.getOperand(index.getZExtValue()));
1712
0
    return success();
1713
0
  }
1714
};
1715
1716
} // namespace
1717
1718
void TensorFromElementsOp::getCanonicalizationPatterns(
1719
0
    OwningRewritePatternList &results, MLIRContext *context) {
1720
0
  results.insert<ExtractElementFromTensorFromElements>(context);
1721
0
}
1722
1723
//===----------------------------------------------------------------------===//
1724
// FPExtOp
1725
//===----------------------------------------------------------------------===//
1726
1727
0
/// fpext is valid from a float to a strictly wider float. It is also valid
/// elementwise between two vectors of identical shape whose element types
/// satisfy the scalar rule.
bool FPExtOp::areCastCompatible(Type a, Type b) {
  auto srcFloat = a.dyn_cast<FloatType>();
  auto dstFloat = b.dyn_cast<FloatType>();
  if (srcFloat && dstFloat)
    return srcFloat.getWidth() < dstFloat.getWidth();

  auto srcVec = a.dyn_cast<VectorType>();
  auto dstVec = b.dyn_cast<VectorType>();
  if (!srcVec || !dstVec)
    return false;
  return srcVec.getShape().equals(dstVec.getShape()) &&
         areCastCompatible(srcVec.getElementType(), dstVec.getElementType());
}
1737
1738
//===----------------------------------------------------------------------===//
1739
// FPToSIOp
1740
//===----------------------------------------------------------------------===//
1741
1742
0
/// fptosi converts any float type to a signless integer type.
bool FPToSIOp::areCastCompatible(Type a, Type b) {
  if (!a.isa<FloatType>())
    return false;
  return b.isSignlessInteger();
}
1745
1746
//===----------------------------------------------------------------------===//
1747
// FPTruncOp
1748
//===----------------------------------------------------------------------===//
1749
1750
0
/// fptrunc is valid from a float to a strictly narrower float. It is also
/// valid elementwise between two vectors of identical shape whose element
/// types satisfy the scalar rule.
bool FPTruncOp::areCastCompatible(Type a, Type b) {
  auto srcFloat = a.dyn_cast<FloatType>();
  auto dstFloat = b.dyn_cast<FloatType>();
  if (srcFloat && dstFloat)
    return srcFloat.getWidth() > dstFloat.getWidth();

  auto srcVec = a.dyn_cast<VectorType>();
  auto dstVec = b.dyn_cast<VectorType>();
  if (!srcVec || !dstVec)
    return false;
  return srcVec.getShape().equals(dstVec.getShape()) &&
         areCastCompatible(srcVec.getElementType(), dstVec.getElementType());
}
1760
1761
//===----------------------------------------------------------------------===//
1762
// IndexCastOp
1763
//===----------------------------------------------------------------------===//
1764
1765
// Index cast is applicable from index to integer and backwards.
1766
0
bool IndexCastOp::areCastCompatible(Type a, Type b) {
1767
0
  return (a.isIndex() && b.isSignlessInteger()) ||
1768
0
         (a.isSignlessInteger() && b.isIndex());
1769
0
}
1770
1771
0
/// Folds
///   index_cast(index_cast(x)) -> x      when x's type equals this result type
///   index_cast(constant)      -> constant of this op's result type
OpFoldResult IndexCastOp::fold(ArrayRef<Attribute> cstOperands) {
  // Fold IndexCast(IndexCast(x)) -> x
  auto cast = getOperand().getDefiningOp<IndexCastOp>();
  if (cast && cast.getOperand().getType() == getType())
    return cast.getOperand();

  // Fold IndexCast(constant) -> constant
  // A little hack because we go through int.  Otherwise, the size
  // of the constant might need to change.
  // Going through int64_t only works when the constant fits in 64 signed
  // bits. A wider integer constant (e.g. a large i128) would make
  // IntegerAttr::getInt() assert, so such constants are left unfolded.
  if (auto value = cstOperands[0].dyn_cast_or_null<IntegerAttr>())
    if (value.getValue().getMinSignedBits() <= 64)
      return IntegerAttr::get(getType(), value.getInt());

  return {};
}
1785
1786
//===----------------------------------------------------------------------===//
1787
// LoadOp
1788
//===----------------------------------------------------------------------===//
1789
1790
0
/// A load takes its memref plus exactly one index per memref dimension.
static LogicalResult verify(LoadOp op) {
  int64_t expectedNumOperands = 1 + op.getMemRefType().getRank();
  if (op.getNumOperands() == expectedNumOperands)
    return success();
  return op.emitOpError("incorrect number of indices for load");
}
1795
1796
0
/// load(memrefcast) -> load. When foldMemRefCast succeeds, the op's own
/// result is returned, which is how MLIR signals an in-place fold. Otherwise
/// an empty OpFoldResult is returned.
OpFoldResult LoadOp::fold(ArrayRef<Attribute> cstOperands) {
  return succeeded(foldMemRefCast(*this)) ? OpFoldResult(getResult())
                                          : OpFoldResult();
}
1802
1803
//===----------------------------------------------------------------------===//
1804
// MemRefCastOp
1805
//===----------------------------------------------------------------------===//
1806
1807
0
/// Returns true if a memref_cast from `a` to `b` is valid.
///
/// Both types must be memrefs (ranked or unranked) with the same element
/// type and memory space. For two ranked memrefs, the ranks must match and
/// static dimension sizes must agree. When their layout maps differ, the
/// strides and offset must also be compatible: equal, or dynamic on either
/// side. Unranked-to-unranked casts are rejected.
bool MemRefCastOp::areCastCompatible(Type a, Type b) {
  auto aT = a.dyn_cast<MemRefType>();
  auto bT = b.dyn_cast<MemRefType>();

  auto uaT = a.dyn_cast<UnrankedMemRefType>();
  auto ubT = b.dyn_cast<UnrankedMemRefType>();

  if (aT && bT) {
    if (aT.getElementType() != bT.getElementType())
      return false;
    if (aT.getAffineMaps() != bT.getAffineMaps()) {
      int64_t aOffset, bOffset;
      SmallVector<int64_t, 4> aStrides, bStrides;
      if (failed(getStridesAndOffset(aT, aStrides, aOffset)) ||
          failed(getStridesAndOffset(bT, bStrides, bOffset)) ||
          aStrides.size() != bStrides.size())
        return false;

      // Strides along a dimension/offset are compatible if the value in the
      // source memref is static and the value in the target memref is the
      // same. They are also compatible if either one is dynamic (see
      // description of MemRefCastOp for details).
      auto checkCompatible = [](int64_t lhs, int64_t rhs) {
        return lhs == MemRefType::getDynamicStrideOrOffset() ||
               rhs == MemRefType::getDynamicStrideOrOffset() || lhs == rhs;
      };
      if (!checkCompatible(aOffset, bOffset))
        return false;
      for (auto aStride : llvm::enumerate(aStrides))
        if (!checkCompatible(aStride.value(), bStrides[aStride.index()]))
          return false;
    }
    if (aT.getMemorySpace() != bT.getMemorySpace())
      return false;

    // They must have the same rank, and any specified dimensions must match.
    if (aT.getRank() != bT.getRank())
      return false;

    for (unsigned i = 0, e = aT.getRank(); i != e; ++i) {
      int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
      if (aDim != ShapedType::kDynamicSize &&
          bDim != ShapedType::kDynamicSize && aDim != bDim)
        return false;
    }
    return true;
  }

  // At least one side is not a ranked memref. Both must still be memrefs of
  // some kind, and at most one of them may be unranked.
  if (!aT && !uaT)
    return false;
  if (!bT && !ubT)
    return false;
  // Unranked to unranked casting is unsupported
  if (uaT && ubT)
    return false;

  auto aEltType = aT ? aT.getElementType() : uaT.getElementType();
  auto bEltType = bT ? bT.getElementType() : ubT.getElementType();
  if (aEltType != bEltType)
    return false;

  auto aMemSpace = aT ? aT.getMemorySpace() : uaT.getMemorySpace();
  auto bMemSpace = bT ? bT.getMemorySpace() : ubT.getMemorySpace();
  return aMemSpace == bMemSpace;
}
1876
1877
0
/// Delegates folding to the shared `impl::foldCastOp` helper; the constant
/// `operands` are not consulted here.
OpFoldResult MemRefCastOp::fold(ArrayRef<Attribute> operands) {
  auto folded = impl::foldCastOp(*this);
  return folded;
}
1880
1881
//===----------------------------------------------------------------------===//
1882
// MulFOp
1883
//===----------------------------------------------------------------------===//
1884
1885
0
/// Constant-folds floating-point multiplication via constFoldBinaryOp,
/// multiplying the APFloat values of constant operands.
OpFoldResult MulFOp::fold(ArrayRef<Attribute> operands) {
  auto multiply = [](APFloat lhsVal, APFloat rhsVal) { return lhsVal * rhsVal; };
  return constFoldBinaryOp<FloatAttr>(operands, multiply);
}
1889
1890
//===----------------------------------------------------------------------===//
1891
// MulIOp
1892
//===----------------------------------------------------------------------===//
1893
1894
0
/// Folds integer multiplication:
///   muli(x, 0) -> 0   (the zero operand itself is returned)
///   muli(x, 1) -> x
/// Otherwise defers to constFoldBinaryOp, which multiplies APInt values
/// (APInt multiplication wraps modulo 2^width).
OpFoldResult MulIOp::fold(ArrayRef<Attribute> operands) {
  Value factor = rhs();
  if (matchPattern(factor, m_Zero()))
    return factor;
  if (matchPattern(factor, m_One()))
    return lhs();

  // TODO: Handle the overflow case.
  auto product = [](APInt x, APInt y) { return x * y; };
  return constFoldBinaryOp<IntegerAttr>(operands, product);
}
1906
1907
//===----------------------------------------------------------------------===//
1908
// OrOp
1909
//===----------------------------------------------------------------------===//
1910
1911
0
/// Folds bitwise or:
///   or(x, 0) -> x
///   or(x, x) -> x
/// Otherwise defers to constFoldBinaryOp for constant operands.
OpFoldResult OrOp::fold(ArrayRef<Attribute> operands) {
  Value x = lhs();
  Value y = rhs();
  if (matchPattern(y, m_Zero()))
    return x;
  if (x == y)
    return y;

  auto bitwiseOr = [](APInt l, APInt r) { return l | r; };
  return constFoldBinaryOp<IntegerAttr>(operands, bitwiseOr);
}
1922
1923
//===----------------------------------------------------------------------===//
1924
// PrefetchOp
1925
//===----------------------------------------------------------------------===//
1926
1927
0
/// Prints
///   prefetch %memref[indices], read|write, locality<N>, data|instr
///     attr-dict : memref-type
/// The three attributes encoded in the custom syntax are elided from the
/// attribute dictionary.
static void print(OpAsmPrinter &p, PrefetchOp op) {
  StringRef accessKind = op.isWrite() ? "write" : "read";
  StringRef cacheKind = op.isDataCache() ? "data" : "instr";

  p << PrefetchOp::getOperationName() << " " << op.memref() << '[';
  p.printOperands(op.indices());
  p << "], " << accessKind << ", locality<" << op.localityHint() << ">, "
    << cacheKind;
  p.printOptionalAttrDict(
      op.getAttrs(),
      /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"});
  p << " : " << op.getMemRefType();
}
1938
1939
/// Parses
///   prefetch %memref[indices], read|write, locality<N>, data|instr
///     : memref-type
/// The locality hint is parsed as an i32 attribute. The read/write and
/// data/instr keywords become the boolean isWrite and isDataCache
/// attributes; any other keyword is a parse error.
static ParseResult parsePrefetchOp(OpAsmParser &parser,
                                   OperationState &result) {
  Builder &builder = parser.getBuilder();
  Type indexTy = builder.getIndexType();
  Type i32Type = builder.getIntegerType(32);

  OpAsmParser::OperandType memrefOperand;
  SmallVector<OpAsmParser::OperandType, 4> indexOperands;
  IntegerAttr localityHint;
  MemRefType memrefType;
  StringRef rwSpec, cacheSpec;

  if (parser.parseOperand(memrefOperand) ||
      parser.parseOperandList(indexOperands, OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseKeyword(&rwSpec) ||
      parser.parseComma() || parser.parseKeyword("locality") ||
      parser.parseLess() ||
      parser.parseAttribute(localityHint, i32Type, "localityHint",
                            result.attributes) ||
      parser.parseGreater() || parser.parseComma() ||
      parser.parseKeyword(&cacheSpec) || parser.parseColonType(memrefType) ||
      parser.resolveOperand(memrefOperand, memrefType, result.operands) ||
      parser.resolveOperands(indexOperands, indexTy, result.operands))
    return failure();

  bool isWrite = rwSpec == "write";
  if (!isWrite && rwSpec != "read")
    return parser.emitError(parser.getNameLoc(),
                            "rw specifier has to be 'read' or 'write'");
  result.addAttribute(PrefetchOp::getIsWriteAttrName(),
                      builder.getBoolAttr(isWrite));

  bool isDataCache = cacheSpec == "data";
  if (!isDataCache && cacheSpec != "instr")
    return parser.emitError(parser.getNameLoc(),
                            "cache type has to be 'data' or 'instr'");
  result.addAttribute(PrefetchOp::getIsDataCacheAttrName(),
                      builder.getBoolAttr(isDataCache));

  return success();
}
1979
1980
0
/// A prefetch takes its memref plus exactly one index per memref dimension.
/// The original check used `!=` but always reported "too few indices". The
/// two failure directions are now distinguished; the existing message is
/// kept for the too-few case.
static LogicalResult verify(PrefetchOp op) {
  int64_t rank = op.getMemRefType().getRank();
  int64_t numIndices = static_cast<int64_t>(op.getNumOperands()) - 1;
  if (numIndices < rank)
    return op.emitOpError("too few indices");
  if (numIndices > rank)
    return op.emitOpError("too many indices");

  return success();
}
1986
1987
/// Folds memref_cast producers of the operands into this op in place:
///   prefetch(memrefcast) -> prefetch
/// Neither `cstOperands` nor `results` is consulted; the returned status is
/// exactly what foldMemRefCast reports.
LogicalResult PrefetchOp::fold(ArrayRef<Attribute> cstOperands,
                               SmallVectorImpl<OpFoldResult> &results) {
  LogicalResult castFolded = foldMemRefCast(*this);
  return castFolded;
}
1992
1993
//===----------------------------------------------------------------------===//
1994
// RankOp
1995
//===----------------------------------------------------------------------===//
1996
1997
0
/// Folds `rank` to an index-typed constant when the operand is a ranked
/// tensor; any other operand type yields no fold.
OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
  auto rankedType = getOperand().getType().dyn_cast<RankedTensorType>();
  if (!rankedType)
    return IntegerAttr();
  return IntegerAttr::get(IndexType::get(getContext()), rankedType.getRank());
}
2004
2005
//===----------------------------------------------------------------------===//
2006
// ReturnOp
2007
//===----------------------------------------------------------------------===//
2008
2009
0
/// A return must match the enclosing FuncOp's result types exactly, both in
/// count and position by position.
static LogicalResult verify(ReturnOp op) {
  auto function = cast<FuncOp>(op.getParentOp());
  ArrayRef<Type> resultTypes = function.getType().getResults();

  unsigned numOperands = op.getNumOperands();
  if (numOperands != resultTypes.size())
    return op.emitOpError("has ")
           << numOperands << " operands, but enclosing function returns "
           << resultTypes.size();

  for (auto it : llvm::enumerate(resultTypes)) {
    Type operandType = op.getOperand(it.index()).getType();
    if (operandType == it.value())
      continue;
    return op.emitError() << "type of return operand " << it.index() << " ("
                          << operandType
                          << ") doesn't match function result type ("
                          << it.value() << ")";
  }

  return success();
}
2028
2029
//===----------------------------------------------------------------------===//
2030
// SelectOp
2031
//===----------------------------------------------------------------------===//
2032
2033
0
/// Folds a select whose condition matches a constant:
///   select true, %a, %b  -> %a
///   select false, %a, %b -> %b
OpFoldResult SelectOp::fold(ArrayRef<Attribute> operands) {
  Value cond = getCondition();
  if (matchPattern(cond, m_One()))
    return getTrueValue();
  if (matchPattern(cond, m_Zero()))
    return getFalseValue();
  return {};
}
2045
2046
0
/// Prints `select %cond, %t, %f attr-dict : [cond-type, ] result-type`.
/// The condition type is only spelled out when it is a shaped mask type.
static void print(OpAsmPrinter &p, SelectOp op) {
  p << "select " << op.getOperands();
  p.printOptionalAttrDict(op.getAttrs());
  p << " : ";
  Type condType = op.getCondition().getType();
  if (condType.isa<ShapedType>())
    p << condType << ", ";
  p << op.getType();
}
2054
2055
0
/// Parses `select %cond, %t, %f attr-dict : [cond-type, ] result-type`.
/// Without an explicit condition type the condition is resolved as i1.
static ParseResult parseSelectOp(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::OperandType, 3> operands;
  Type resultType;
  if (parser.parseOperandList(operands, /*requiredOperandCount=*/3) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(resultType))
    return failure();

  Type conditionType;
  if (failed(parser.parseOptionalComma())) {
    conditionType = parser.getBuilder().getI1Type();
  } else {
    // A comma means the type just parsed was the (masked) condition type and
    // the result type follows it.
    conditionType = resultType;
    if (parser.parseType(resultType))
      return failure();
  }

  result.addTypes(resultType);
  return parser.resolveOperands(operands,
                                {conditionType, resultType, resultType},
                                parser.getNameLoc(), result.operands);
}
2077
2078
0
/// The condition is either a signless i1 (valid for any result type), or,
/// for vector/tensor results, an i1 mask with the same shape as the result.
static LogicalResult verify(SelectOp op) {
  Type conditionType = op.getCondition().getType();
  if (conditionType.isSignlessInteger(1))
    return success();

  Type resultType = op.getType();
  bool isShapedResult =
      resultType.isa<TensorType>() || resultType.isa<VectorType>();
  if (!isShapedResult)
    return op.emitOpError()
           << "expected condition to be a signless i1, but got "
           << conditionType;

  Type expectedMask = getI1SameShape(resultType);
  if (conditionType == expectedMask)
    return success();
  return op.emitOpError()
         << "expected condition type to have the same shape "
            "as the result type, expected "
         << expectedMask << ", but got " << conditionType;
}
2098
2099
//===----------------------------------------------------------------------===//
2100
// SignExtendIOp
2101
//===----------------------------------------------------------------------===//
2102
2103
0
/// Sign extension must strictly widen the scalar (element) integer type.
/// `index` is rejected on either side.
static LogicalResult verify(SignExtendIOp op) {
  // Compare scalar types: the type itself, or the element type of a
  // vector/tensor.
  Type srcType = getElementTypeOrSelf(op.getOperand().getType());
  Type dstType = getElementTypeOrSelf(op.getType());

  // For now, index is forbidden for the source and the destination type.
  if (srcType.isa<IndexType>())
    return op.emitError() << srcType << " is not a valid operand type";
  if (dstType.isa<IndexType>())
    return op.emitError() << dstType << " is not a valid result type";

  unsigned srcWidth = srcType.cast<IntegerType>().getWidth();
  unsigned dstWidth = dstType.cast<IntegerType>().getWidth();
  if (srcWidth < dstWidth)
    return success();
  return op.emitError("result type ")
         << dstType << " must be wider than operand type " << srcType;
}
2122
2123
//===----------------------------------------------------------------------===//
2124
// SignedDivIOp
2125
//===----------------------------------------------------------------------===//
2126
2127
0
/// Folds signed integer division.
///  - With constant operands, divides elementwise; if any element division
///    would overflow or divide by zero, no constant is produced.
///  - Division by a constant one (a scalar, or a splat; tensors of all ones
///    are assumed to be splats) folds to the dividend.
OpFoldResult SignedDivIOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.size() == 2 && "binary operation takes two operands");

  // Becomes true as soon as one element division is invalid; later elements
  // are then skipped.
  bool invalid = false;
  auto folded = constFoldBinaryOp<IntegerAttr>(
      operands, [&](APInt dividend, APInt divisor) {
        if (invalid || !divisor) {
          invalid = true;
          return dividend;
        }
        return dividend.sdiv_ov(divisor, invalid);
      });

  Attribute divisorAttr = operands[1];
  bool divisorIsOne = false;
  if (auto scalar = divisorAttr.dyn_cast_or_null<IntegerAttr>())
    divisorIsOne = scalar.getValue() == 1;
  else if (auto splat = divisorAttr.dyn_cast_or_null<SplatElementsAttr>())
    divisorIsOne = splat.getSplatValue<IntegerAttr>().getValue() == 1;
  if (divisorIsOne)
    return lhs();

  if (invalid)
    return Attribute();
  return folded;
}
2151
2152
//===----------------------------------------------------------------------===//
2153
// SignedRemIOp
2154
//===----------------------------------------------------------------------===//
2155
2156
0
/// Folds signed integer remainder when the divisor is a constant scalar:
/// `x % 1` becomes 0 for any `x`, and a constant dividend with a nonzero
/// divisor is evaluated with APInt::srem. Division by zero is never folded.
OpFoldResult SignedRemIOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.size() == 2 && "remi_signed takes two operands");

  IntegerAttr divisorAttr = operands.back().dyn_cast_or_null<IntegerAttr>();
  if (!divisorAttr)
    return {};
  APInt divisor = divisorAttr.getValue();

  // Anything modulo 1 is 0, even when the dividend is unknown.
  if (divisor.isOneValue())
    return IntegerAttr::get(divisorAttr.getType(),
                            APInt(divisor.getBitWidth(), 0));

  // Beyond that, fold only a constant dividend with a nonzero divisor.
  IntegerAttr dividendAttr = operands.front().dyn_cast_or_null<IntegerAttr>();
  if (divisor.isNullValue() || !dividendAttr)
    return {};
  return IntegerAttr::get(dividendAttr.getType(),
                          dividendAttr.getValue().srem(divisor));
}
2177
2178
//===----------------------------------------------------------------------===//
2179
// SIToFPOp
2180
//===----------------------------------------------------------------------===//
2181
2182
// sitofp is applicable from integer types to float types.
2183
0
bool SIToFPOp::areCastCompatible(Type a, Type b) {
2184
0
  return a.isSignlessInteger() && b.isa<FloatType>();
2185
0
}
2186
2187
//===----------------------------------------------------------------------===//
2188
// SplatOp
2189
//===----------------------------------------------------------------------===//
2190
2191
0
/// Verifies that the splatted operand has exactly the element type of the
/// shaped result type.
static LogicalResult verify(SplatOp op) {
  // TODO: we could replace this by a trait.
  Type elementType = op.getType().cast<ShapedType>().getElementType();
  if (op.getOperand().getType() == elementType)
    return success();
  return op.emitError("operand should be of elemental type of result type");
}
2199
2200
// Constant folding hook for SplatOp.
2201
0
OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
2202
0
  assert(operands.size() == 1 && "splat takes one operand");
2203
0
2204
0
  auto constOperand = operands.front();
2205
0
  if (!constOperand ||
2206
0
      (!constOperand.isa<IntegerAttr>() && !constOperand.isa<FloatAttr>()))
2207
0
    return {};
2208
0
2209
0
  auto shapedType = getType().cast<ShapedType>();
2210
0
  assert(shapedType.getElementType() == constOperand.getType() &&
2211
0
         "incorrect input attribute type for folding");
2212
0
2213
0
  // SplatElementsAttr::get treats single value for second arg as being a splat.
2214
0
  return SplatElementsAttr::get(shapedType, {constOperand});
2215
0
}
2216
2217
//===----------------------------------------------------------------------===//
2218
// StoreOp
2219
//===----------------------------------------------------------------------===//
2220
2221
0
/// Verifies that, besides its two non-index operands, a store carries exactly
/// one index operand per dimension of the memref type.
static LogicalResult verify(StoreOp op) {
  int64_t expectedNumOperands = 2 + op.getMemRefType().getRank();
  if (op.getNumOperands() == expectedNumOperands)
    return success();
  return op.emitOpError("store index operand count not equal to memref rank");
}
2227
2228
/// Folding hook for StoreOp. It only delegates to foldMemRefCast to fold the
/// store(memref_cast) -> store pattern; constant operands are not used.
LogicalResult StoreOp::fold(ArrayRef<Attribute> cstOperands,
                            SmallVectorImpl<OpFoldResult> &results) {
  return foldMemRefCast(*this);
}
2233
2234
//===----------------------------------------------------------------------===//
2235
// SubFOp
2236
//===----------------------------------------------------------------------===//
2237
2238
0
/// Constant-folds floating-point subtraction when both operands are constant.
OpFoldResult SubFOp::fold(ArrayRef<Attribute> operands) {
  auto subtract = [](APFloat minuend, APFloat subtrahend) {
    return minuend - subtrahend;
  };
  return constFoldBinaryOp<FloatAttr>(operands, subtract);
}
2242
2243
//===----------------------------------------------------------------------===//
2244
// SubIOp
2245
//===----------------------------------------------------------------------===//
2246
2247
0
/// Folds integer subtraction: `x - x` becomes zero for any `x`, and constant
/// operands are subtracted directly.
OpFoldResult SubIOp::fold(ArrayRef<Attribute> operands) {
  Value first = getOperand(0);
  Value second = getOperand(1);
  if (first == second)
    return Builder(getContext()).getZeroAttr(getType());

  auto subtract = [](APInt minuend, APInt subtrahend) {
    return minuend - subtrahend;
  };
  return constFoldBinaryOp<IntegerAttr>(operands, subtract);
}
2255
2256
//===----------------------------------------------------------------------===//
2257
// SubViewOp
2258
//===----------------------------------------------------------------------===//
2259
2260
/// Print a list with either (1) the static integer value in `arrayAttr` if
2261
/// `isDynamic` evaluates to false or (2) the next value otherwise.
2262
/// This allows idiomatic printing of mixed value and integer attributes in a
2263
/// list. E.g. `[%arg0, 7, 42, %arg42]`.
2264
static void printSubViewListOfOperandsOrIntegers(
2265
    OpAsmPrinter &p, ValueRange values, ArrayAttr arrayAttr,
2266
0
    llvm::function_ref<bool(int64_t)> isDynamic) {
2267
0
  p << "[";
2268
0
  unsigned idx = 0;
2269
0
  llvm::interleaveComma(arrayAttr, p, [&](Attribute a) {
2270
0
    int64_t val = a.cast<IntegerAttr>().getInt();
2271
0
    if (isDynamic(val))
2272
0
      p << values[idx++];
2273
0
    else
2274
0
      p << val;
2275
0
  });
2276
0
  p << "] ";
2277
0
}
2278
2279
/// Parse a mixed list with either (1) static integer values or (2) SSA values.
2280
/// Fill `result` with the integer ArrayAttr named `attrName` where `dynVal`
2281
/// encode the position of SSA values. Add the parsed SSA values to `ssa`
2282
/// in-order.
2283
//
2284
/// E.g. after parsing "[%arg0, 7, 42, %arg42]":
2285
///   1. `result` is filled with the i64 ArrayAttr "[`dynVal`, 7, 42, `dynVal`]"
2286
///   2. `ssa` is filled with "[%arg0, %arg1]".
2287
static ParseResult
2288
parseListOfOperandsOrIntegers(OpAsmParser &parser, OperationState &result,
2289
                              StringRef attrName, int64_t dynVal,
2290
0
                              SmallVectorImpl<OpAsmParser::OperandType> &ssa) {
2291
0
  if (failed(parser.parseLSquare()))
2292
0
    return failure();
2293
0
  // 0-D.
2294
0
  if (succeeded(parser.parseOptionalRSquare()))
2295
0
    return success();
2296
0
2297
0
  SmallVector<int64_t, 4> attrVals;
2298
0
  while (true) {
2299
0
    OpAsmParser::OperandType operand;
2300
0
    auto res = parser.parseOptionalOperand(operand);
2301
0
    if (res.hasValue() && succeeded(res.getValue())) {
2302
0
      ssa.push_back(operand);
2303
0
      attrVals.push_back(dynVal);
2304
0
    } else {
2305
0
      Attribute attr;
2306
0
      NamedAttrList placeholder;
2307
0
      if (failed(parser.parseAttribute(attr, "_", placeholder)) ||
2308
0
          !attr.isa<IntegerAttr>())
2309
0
        return parser.emitError(parser.getNameLoc())
2310
0
               << "expected SSA value or integer";
2311
0
      attrVals.push_back(attr.cast<IntegerAttr>().getInt());
2312
0
    }
2313
0
2314
0
    if (succeeded(parser.parseOptionalComma()))
2315
0
      continue;
2316
0
    if (failed(parser.parseRSquare()))
2317
0
      return failure();
2318
0
    else
2319
0
      break;
2320
0
  }
2321
0
2322
0
  auto arrayAttr = parser.getBuilder().getI64ArrayAttr(attrVals);
2323
0
  result.addAttribute(attrName, arrayAttr);
2324
0
  return success();
2325
0
}
2326
2327
namespace {
2328
/// Helpers to write more idiomatic operations.
2329
namespace saturated_arith {
2330
struct Wrapper {
2331
0
  explicit Wrapper(int64_t v) : v(v) {}
2332
0
  operator int64_t() { return v; }
2333
  int64_t v;
2334
};
2335
0
Wrapper operator+(Wrapper a, int64_t b) {
2336
0
  if (ShapedType::isDynamicStrideOrOffset(a) ||
2337
0
      ShapedType::isDynamicStrideOrOffset(b))
2338
0
    return Wrapper(ShapedType::kDynamicStrideOrOffset);
2339
0
  return Wrapper(a.v + b);
2340
0
}
2341
0
Wrapper operator*(Wrapper a, int64_t b) {
2342
0
  if (ShapedType::isDynamicStrideOrOffset(a) ||
2343
0
      ShapedType::isDynamicStrideOrOffset(b))
2344
0
    return Wrapper(ShapedType::kDynamicStrideOrOffset);
2345
0
  return Wrapper(a.v * b);
2346
0
}
2347
} // end namespace saturated_arith
2348
} // end namespace
2349
2350
/// A subview result type can be fully inferred from the source type and the
2351
/// static representation of offsets, sizes and strides. Special sentinels
2352
/// encode the dynamic case.
2353
Type SubViewOp::inferSubViewResultType(MemRefType sourceMemRefType,
2354
                                       ArrayRef<int64_t> staticOffsets,
2355
                                       ArrayRef<int64_t> staticSizes,
2356
                                       ArrayRef<int64_t> staticStrides) {
2357
  unsigned rank = sourceMemRefType.getRank();
2358
  (void)rank;
2359
  assert(staticOffsets.size() == rank &&
2360
         "unexpected staticOffsets size mismatch");
2361
  assert(staticSizes.size() == rank && "unexpected staticSizes size mismatch");
2362
  assert(staticStrides.size() == rank &&
2363
         "unexpected staticStrides size mismatch");
2364
2365
  // Extract source offset and strides.
2366
  int64_t sourceOffset;
2367
  SmallVector<int64_t, 4> sourceStrides;
2368
  auto res = getStridesAndOffset(sourceMemRefType, sourceStrides, sourceOffset);
2369
  assert(succeeded(res) && "SubViewOp expected strided memref type");
2370
  (void)res;
2371
2372
  // Compute target offset whose value is:
2373
  //   `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`.
2374
  int64_t targetOffset = sourceOffset;
2375
  for (auto it : llvm::zip(staticOffsets, sourceStrides)) {
2376
    auto staticOffset = std::get<0>(it), targetStride = std::get<1>(it);
2377
    using namespace saturated_arith;
2378
    targetOffset = Wrapper(targetOffset) + Wrapper(staticOffset) * targetStride;
2379
  }
2380
2381
  // Compute target stride whose value is:
2382
  //   `sourceStrides_i * staticStrides_i`.
2383
  SmallVector<int64_t, 4> targetStrides;
2384
  targetStrides.reserve(staticOffsets.size());
2385
  for (auto it : llvm::zip(sourceStrides, staticStrides)) {
2386
    auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it);
2387
    using namespace saturated_arith;
2388
    targetStrides.push_back(Wrapper(sourceStride) * staticStride);
2389
  }
2390
2391
  // The type is now known.
2392
  return MemRefType::get(
2393
      staticSizes, sourceMemRefType.getElementType(),
2394
      makeStridedLinearLayoutMap(targetStrides, targetOffset,
2395
                                 sourceMemRefType.getContext()),
2396
      sourceMemRefType.getMemorySpace());
2397
}
2398
2399
/// Print SubViewOp in the form:
2400
/// ```
2401
///   subview ssa-name `[` offset-list `]` `[` size-list `]` `[` stride-list `]`
2402
///     `:` strided-memref-type `to` strided-memref-type
2403
/// ```
2404
0
static void print(OpAsmPrinter &p, SubViewOp op) {
2405
0
  int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
2406
0
  p << op.getOperation()->getName().getStringRef().drop_front(stdDotLen) << ' ';
2407
0
  p << op.getOperand(0);
2408
0
  printSubViewListOfOperandsOrIntegers(p, op.offsets(), op.static_offsets(),
2409
0
                                       ShapedType::isDynamicStrideOrOffset);
2410
0
  printSubViewListOfOperandsOrIntegers(p, op.sizes(), op.static_sizes(),
2411
0
                                       ShapedType::isDynamic);
2412
0
  printSubViewListOfOperandsOrIntegers(p, op.strides(), op.static_strides(),
2413
0
                                       ShapedType::isDynamicStrideOrOffset);
2414
0
  p.printOptionalAttrDict(op.getAttrs(),
2415
0
                          /*elidedAttrs=*/{SubViewOp::getSpecialAttrNames()});
2416
0
  p << " : " << op.getOperand(0).getType() << " to " << op.getType();
2417
0
}
2418
2419
/// Parse SubViewOp of the form:
2420
/// ```
2421
///   subview ssa-name `[` offset-list `]` `[` size-list `]` `[` stride-list `]`
2422
///     `:` strided-memref-type `to` strided-memref-type
2423
/// ```
2424
0
static ParseResult parseSubViewOp(OpAsmParser &parser, OperationState &result) {
2425
0
  OpAsmParser::OperandType srcInfo;
2426
0
  SmallVector<OpAsmParser::OperandType, 4> offsetsInfo, sizesInfo, stridesInfo;
2427
0
  auto indexType = parser.getBuilder().getIndexType();
2428
0
  Type srcType, dstType;
2429
0
  if (parser.parseOperand(srcInfo))
2430
0
    return failure();
2431
0
  if (parseListOfOperandsOrIntegers(
2432
0
          parser, result, SubViewOp::getStaticOffsetsAttrName(),
2433
0
          ShapedType::kDynamicStrideOrOffset, offsetsInfo) ||
2434
0
      parseListOfOperandsOrIntegers(parser, result,
2435
0
                                    SubViewOp::getStaticSizesAttrName(),
2436
0
                                    ShapedType::kDynamicSize, sizesInfo) ||
2437
0
      parseListOfOperandsOrIntegers(
2438
0
          parser, result, SubViewOp::getStaticStridesAttrName(),
2439
0
          ShapedType::kDynamicStrideOrOffset, stridesInfo))
2440
0
    return failure();
2441
0
2442
0
  auto b = parser.getBuilder();
2443
0
  SmallVector<int, 4> segmentSizes{1, static_cast<int>(offsetsInfo.size()),
2444
0
                                   static_cast<int>(sizesInfo.size()),
2445
0
                                   static_cast<int>(stridesInfo.size())};
2446
0
  result.addAttribute(SubViewOp::getOperandSegmentSizeAttr(),
2447
0
                      b.getI32VectorAttr(segmentSizes));
2448
0
2449
0
  return failure(
2450
0
      parser.parseOptionalAttrDict(result.attributes) ||
2451
0
      parser.parseColonType(srcType) ||
2452
0
      parser.resolveOperand(srcInfo, srcType, result.operands) ||
2453
0
      parser.resolveOperands(offsetsInfo, indexType, result.operands) ||
2454
0
      parser.resolveOperands(sizesInfo, indexType, result.operands) ||
2455
0
      parser.resolveOperands(stridesInfo, indexType, result.operands) ||
2456
0
      parser.parseKeywordType("to", dstType) ||
2457
0
      parser.addTypeToList(dstType, result.types));
2458
0
}
2459
2460
void mlir::SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
2461
                            ArrayRef<int64_t> staticOffsets,
2462
                            ArrayRef<int64_t> staticSizes,
2463
                            ArrayRef<int64_t> staticStrides, ValueRange offsets,
2464
                            ValueRange sizes, ValueRange strides,
2465
0
                            ArrayRef<NamedAttribute> attrs) {
2466
0
  auto sourceMemRefType = source.getType().cast<MemRefType>();
2467
0
  auto resultType = inferSubViewResultType(sourceMemRefType, staticOffsets,
2468
0
                                           staticSizes, staticStrides);
2469
0
  build(b, result, resultType, source, offsets, sizes, strides,
2470
0
        b.getI64ArrayAttr(staticOffsets), b.getI64ArrayAttr(staticSizes),
2471
0
        b.getI64ArrayAttr(staticStrides));
2472
0
  result.addAttributes(attrs);
2473
0
}
2474
2475
/// Build a SubViewOp with all dynamic entries: `staticOffsets`, `staticSizes`
2476
/// and `staticStrides` are  automatically filled with source-memref-rank
2477
/// sentinel values that encode dynamic entries.
2478
void mlir::SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
2479
                            ValueRange offsets, ValueRange sizes,
2480
                            ValueRange strides,
2481
0
                            ArrayRef<NamedAttribute> attrs) {
2482
0
  auto sourceMemRefType = source.getType().cast<MemRefType>();
2483
0
  unsigned rank = sourceMemRefType.getRank();
2484
0
  SmallVector<int64_t, 4> staticOffsetsVector;
2485
0
  staticOffsetsVector.assign(rank, ShapedType::kDynamicStrideOrOffset);
2486
0
  SmallVector<int64_t, 4> staticSizesVector;
2487
0
  staticSizesVector.assign(rank, ShapedType::kDynamicSize);
2488
0
  SmallVector<int64_t, 4> staticStridesVector;
2489
0
  staticStridesVector.assign(rank, ShapedType::kDynamicStrideOrOffset);
2490
0
  build(b, result, source, staticOffsetsVector, staticSizesVector,
2491
0
        staticStridesVector, offsets, sizes, strides, attrs);
2492
0
}
2493
2494
/// Verify that a particular offset/size/stride static attribute is well-formed.
2495
static LogicalResult
2496
verifySubViewOpPart(SubViewOp op, StringRef name, StringRef attrName,
2497
                    ArrayAttr attr, llvm::function_ref<bool(int64_t)> isDynamic,
2498
0
                    ValueRange values) {
2499
0
  /// Check static and dynamic offsets/sizes/strides breakdown.
2500
0
  if (attr.size() != op.getRank())
2501
0
    return op.emitError("expected ")
2502
0
           << op.getRank() << " " << name << " values";
2503
0
  unsigned expectedNumDynamicEntries =
2504
0
      llvm::count_if(attr.getValue(), [&](Attribute attr) {
2505
0
        return isDynamic(attr.cast<IntegerAttr>().getInt());
2506
0
      });
2507
0
  if (values.size() != expectedNumDynamicEntries)
2508
0
    return op.emitError("expected ")
2509
0
           << expectedNumDynamicEntries << " dynamic " << name << " values";
2510
0
  return success();
2511
0
}
2512
2513
/// Helper function extracts int64_t from the assumedArrayAttr of IntegerAttr.
2514
0
static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
2515
0
  return llvm::to_vector<4>(
2516
0
      llvm::map_range(attr.cast<ArrayAttr>(), [](Attribute a) -> int64_t {
2517
0
        return a.cast<IntegerAttr>().getInt();
2518
0
      }));
2519
0
}
2520
2521
/// Verifier for SubViewOp.
2522
0
static LogicalResult verify(SubViewOp op) {
2523
0
  auto baseType = op.getBaseMemRefType().cast<MemRefType>();
2524
0
  auto subViewType = op.getType();
2525
0
2526
0
  // The base memref and the view memref should be in the same memory space.
2527
0
  if (baseType.getMemorySpace() != subViewType.getMemorySpace())
2528
0
    return op.emitError("different memory spaces specified for base memref "
2529
0
                        "type ")
2530
0
           << baseType << " and subview memref type " << subViewType;
2531
0
2532
0
  // Verify that the base memref type has a strided layout map.
2533
0
  if (!isStrided(baseType))
2534
0
    return op.emitError("base type ") << baseType << " is not strided";
2535
0
2536
0
  // Verify static attributes offsets/sizes/strides.
2537
0
  if (failed(verifySubViewOpPart(
2538
0
          op, "offset", op.getStaticOffsetsAttrName(), op.static_offsets(),
2539
0
          ShapedType::isDynamicStrideOrOffset, op.offsets())))
2540
0
    return failure();
2541
0
2542
0
  if (failed(verifySubViewOpPart(op, "size", op.getStaticSizesAttrName(),
2543
0
                                 op.static_sizes(), ShapedType::isDynamic,
2544
0
                                 op.sizes())))
2545
0
    return failure();
2546
0
  if (failed(verifySubViewOpPart(
2547
0
          op, "stride", op.getStaticStridesAttrName(), op.static_strides(),
2548
0
          ShapedType::isDynamicStrideOrOffset, op.strides())))
2549
0
    return failure();
2550
0
2551
0
  // Verify result type against inferred type.
2552
0
  auto expectedType = SubViewOp::inferSubViewResultType(
2553
0
      op.getBaseMemRefType(), extractFromI64ArrayAttr(op.static_offsets()),
2554
0
      extractFromI64ArrayAttr(op.static_sizes()),
2555
0
      extractFromI64ArrayAttr(op.static_strides()));
2556
0
  if (op.getType() != expectedType)
2557
0
    return op.emitError("expected result type to be ") << expectedType;
2558
0
2559
0
  return success();
2560
0
}
2561
2562
0
/// Prints a range as `range <offset>:<size>:<stride>`.
raw_ostream &mlir::operator<<(raw_ostream &os, SubViewOp::Range &range) {
  os << "range " << range.offset << ":" << range.size << ":" << range.stride;
  return os;
}
2566
2567
static unsigned getNumDynamicEntriesUpToIdx(
2568
0
    ArrayAttr attr, llvm::function_ref<bool(int64_t)> isDynamic, unsigned idx) {
2569
0
  return std::count_if(attr.getValue().begin(), attr.getValue().begin() + idx,
2570
0
                       [&](Attribute attr) {
2571
0
                         return isDynamic(attr.cast<IntegerAttr>().getInt());
2572
0
                       });
2573
0
}
2574
2575
0
bool SubViewOp::isDynamicOffset(unsigned idx) {
2576
0
  return ShapedType::isDynamicStrideOrOffset(
2577
0
      extractFromI64ArrayAttr(static_offsets())[idx]);
2578
0
}
2579
0
bool SubViewOp::isDynamicSize(unsigned idx) {
2580
0
  return ShapedType::isDynamic(extractFromI64ArrayAttr(static_sizes())[idx]);
2581
0
}
2582
0
bool SubViewOp::isDynamicStride(unsigned idx) {
2583
0
  return ShapedType::isDynamicStrideOrOffset(
2584
0
      extractFromI64ArrayAttr(static_strides())[idx]);
2585
0
}
2586
2587
0
unsigned SubViewOp::getIndexOfDynamicOffset(unsigned idx) {
2588
0
  assert(isDynamicOffset(idx) && "expected static offset");
2589
0
  auto numDynamic =
2590
0
      getNumDynamicEntriesUpToIdx(static_offsets().cast<ArrayAttr>(),
2591
0
                                  ShapedType::isDynamicStrideOrOffset, idx);
2592
0
  return 1 + numDynamic;
2593
0
}
2594
0
unsigned SubViewOp::getIndexOfDynamicSize(unsigned idx) {
2595
0
  assert(isDynamicSize(idx) && "expected static size");
2596
0
  auto numDynamic = getNumDynamicEntriesUpToIdx(
2597
0
      static_sizes().cast<ArrayAttr>(), ShapedType::isDynamic, idx);
2598
0
  return 1 + offsets().size() + numDynamic;
2599
0
}
2600
0
unsigned SubViewOp::getIndexOfDynamicStride(unsigned idx) {
2601
0
  assert(isDynamicStride(idx) && "expected static stride");
2602
0
  auto numDynamic =
2603
0
      getNumDynamicEntriesUpToIdx(static_strides().cast<ArrayAttr>(),
2604
0
                                  ShapedType::isDynamicStrideOrOffset, idx);
2605
0
  return 1 + offsets().size() + sizes().size() + numDynamic;
2606
0
}
2607
2608
/// Return the list of SubViewOp::Range (i.e. offset, size, stride). Each Range
2609
/// entry contains either the dynamic value or a ConstantIndexOp constructed
2610
/// with `b` at location `loc`.
2611
SmallVector<SubViewOp::Range, 8> SubViewOp::getOrCreateRanges(OpBuilder &b,
2612
0
                                                              Location loc) {
2613
0
  SmallVector<Range, 8> res;
2614
0
  unsigned rank = getType().getRank();
2615
0
  res.reserve(rank);
2616
0
  for (unsigned idx = 0; idx < rank; ++idx) {
2617
0
    auto offset = isDynamicOffset(idx)
2618
0
                      ? getDynamicOffset(idx)
2619
0
                      : b.create<ConstantIndexOp>(loc, getStaticOffset(idx));
2620
0
    auto size = isDynamicSize(idx)
2621
0
                    ? getDynamicSize(idx)
2622
0
                    : b.create<ConstantIndexOp>(loc, getStaticSize(idx));
2623
0
    auto stride = isDynamicStride(idx)
2624
0
                      ? getDynamicStride(idx)
2625
0
                      : b.create<ConstantIndexOp>(loc, getStaticStride(idx));
2626
0
    res.emplace_back(Range{offset, size, stride});
2627
0
  }
2628
0
  return res;
2629
0
}
2630
2631
/// Returns one offset value per static offset entry. Dynamic entries use the
/// op's offset operands (which follow the source operand at index 0); static
/// entries are materialized as ConstantIndexOps built with `b` at `loc`.
SmallVector<Value, 4> SubViewOp::getOrCreateOffsets(OpBuilder &b,
                                                    Location loc) {
  unsigned nextOperand = 1;
  SmallVector<Value, 4> offsetValues;
  for (Attribute entry : static_offsets().cast<ArrayAttr>()) {
    int64_t staticOffset = entry.cast<IntegerAttr>().getInt();
    if (ShapedType::isDynamicStrideOrOffset(staticOffset)) {
      offsetValues.push_back(getOperand(nextOperand++));
      continue;
    }
    Value constant = b.create<ConstantIndexOp>(loc, staticOffset);
    offsetValues.push_back(constant);
  }
  return offsetValues;
}
2643
2644
0
/// Returns one size value per static size entry. Dynamic entries use the op's
/// size operands (which follow the source and the dynamic offsets); static
/// entries are materialized as ConstantIndexOps built with `b` at `loc`.
SmallVector<Value, 4> SubViewOp::getOrCreateSizes(OpBuilder &b, Location loc) {
  unsigned nextOperand = 1 + offsets().size();
  SmallVector<Value, 4> sizeValues;
  for (Attribute entry : static_sizes().cast<ArrayAttr>()) {
    int64_t staticSize = entry.cast<IntegerAttr>().getInt();
    if (ShapedType::isDynamic(staticSize)) {
      sizeValues.push_back(getOperand(nextOperand++));
      continue;
    }
    Value constant = b.create<ConstantIndexOp>(loc, staticSize);
    sizeValues.push_back(constant);
  }
  return sizeValues;
}
2655
2656
/// Returns one stride value per static stride entry. Dynamic entries use the
/// op's stride operands (which follow the source, the dynamic offsets and the
/// dynamic sizes); static entries are materialized as ConstantIndexOps built
/// with `b` at `loc`.
SmallVector<Value, 4> SubViewOp::getOrCreateStrides(OpBuilder &b,
                                                    Location loc) {
  unsigned nextOperand = 1 + offsets().size() + sizes().size();
  SmallVector<Value, 4> strideValues;
  for (Attribute entry : static_strides().cast<ArrayAttr>()) {
    int64_t staticStride = entry.cast<IntegerAttr>().getInt();
    if (ShapedType::isDynamicStrideOrOffset(staticStride)) {
      strideValues.push_back(getOperand(nextOperand++));
      continue;
    }
    Value constant = b.create<ConstantIndexOp>(loc, staticStride);
    strideValues.push_back(constant);
  }
  return strideValues;
}
2668
2669
/// Fills `staticStrides` with the op's static strides. Fails, leaving
/// `staticStrides` untouched, when the op has any dynamic stride operand.
LogicalResult
SubViewOp::getStaticStrides(SmallVectorImpl<int64_t> &staticStrides) {
  if (strides().empty()) {
    staticStrides = extractFromI64ArrayAttr(static_strides());
    return success();
  }
  return failure();
}
2676
2677
0
/// Returns the memref viewed by this op, i.e. its `source` operand.
Value SubViewOp::getViewSource() {
  return source();
}
2678
2679
namespace {
2680
2681
/// Take a list of `values` with potential new constant to extract and a list
2682
/// of `constantValues` with`values.size()` sentinel that evaluate to true by
2683
/// applying `isDynamic`.
2684
/// Detects the `values` produced by a ConstantIndexOp and places the new
2685
/// constant in place of the corresponding sentinel value.
2686
void canonicalizeSubViewPart(SmallVectorImpl<Value> &values,
2687
                             SmallVectorImpl<int64_t> &constantValues,
2688
0
                             llvm::function_ref<bool(int64_t)> isDynamic) {
2689
0
  bool hasNewStaticValue = llvm::any_of(
2690
0
      values, [](Value val) { return matchPattern(val, m_ConstantIndex()); });
2691
0
  if (hasNewStaticValue) {
2692
0
    for (unsigned cstIdx = 0, valIdx = 0, e = constantValues.size();
2693
0
         cstIdx != e; ++cstIdx) {
2694
0
      // Was already static, skip.
2695
0
      if (!isDynamic(constantValues[cstIdx]))
2696
0
        continue;
2697
0
      // Newly static, move from Value to constant.
2698
0
      if (matchPattern(values[valIdx], m_ConstantIndex())) {
2699
0
        constantValues[cstIdx] =
2700
0
            cast<ConstantIndexOp>(values[valIdx].getDefiningOp()).getValue();
2701
0
        // Erase for impl. simplicity. Reverse iterator if we really must.
2702
0
        values.erase(std::next(values.begin(), valIdx));
2703
0
        continue;
2704
0
      }
2705
0
      // Remains dynamic move to next value.
2706
0
      ++valIdx;
2707
0
    }
2708
0
  }
2709
0
}
2710
2711
/// Pattern to rewrite a subview op with constant arguments.
2712
class SubViewOpConstantArgumentFolder final
2713
    : public OpRewritePattern<SubViewOp> {
2714
public:
2715
  using OpRewritePattern<SubViewOp>::OpRewritePattern;
2716
2717
  LogicalResult matchAndRewrite(SubViewOp subViewOp,
2718
0
                                PatternRewriter &rewriter) const override {
2719
0
    // No constant operand, just return;
2720
0
    if (llvm::none_of(subViewOp.getOperands(), [](Value operand) {
2721
0
          return matchPattern(operand, m_ConstantIndex());
2722
0
        }))
2723
0
      return failure();
2724
0
2725
0
    // At least one of offsets/sizes/strides is a new constant.
2726
0
    // Form the new list of operands and constant attributes from the existing.
2727
0
    SmallVector<Value, 8> newOffsets(subViewOp.offsets());
2728
0
    SmallVector<int64_t, 8> newStaticOffsets =
2729
0
        extractFromI64ArrayAttr(subViewOp.static_offsets());
2730
0
    assert(newStaticOffsets.size() == subViewOp.getRank());
2731
0
    canonicalizeSubViewPart(newOffsets, newStaticOffsets,
2732
0
                            ShapedType::isDynamicStrideOrOffset);
2733
0
2734
0
    SmallVector<Value, 8> newSizes(subViewOp.sizes());
2735
0
    SmallVector<int64_t, 8> newStaticSizes =
2736
0
        extractFromI64ArrayAttr(subViewOp.static_sizes());
2737
0
    assert(newStaticOffsets.size() == subViewOp.getRank());
2738
0
    canonicalizeSubViewPart(newSizes, newStaticSizes, ShapedType::isDynamic);
2739
0
2740
0
    SmallVector<Value, 8> newStrides(subViewOp.strides());
2741
0
    SmallVector<int64_t, 8> newStaticStrides =
2742
0
        extractFromI64ArrayAttr(subViewOp.static_strides());
2743
0
    assert(newStaticOffsets.size() == subViewOp.getRank());
2744
0
    canonicalizeSubViewPart(newStrides, newStaticStrides,
2745
0
                            ShapedType::isDynamicStrideOrOffset);
2746
0
2747
0
    // Create the new op in canonical form.
2748
0
    auto newSubViewOp = rewriter.create<SubViewOp>(
2749
0
        subViewOp.getLoc(), subViewOp.source(), newStaticOffsets,
2750
0
        newStaticSizes, newStaticStrides, newOffsets, newSizes, newStrides);
2751
0
2752
0
    // Insert a memref_cast for compatibility of the uses of the op.
2753
0
    rewriter.replaceOpWithNewOp<MemRefCastOp>(subViewOp, newSubViewOp,
2754
0
                                              subViewOp.getType());
2755
0
2756
0
    return success();
2757
0
  }
2758
};
2759
2760
} // end anonymous namespace
2761
2762
/// Determines whether MemRefCastOp casts to a more dynamic version of the
2763
/// source memref. This is useful to fold a memref_cast into a consuming op
2764
/// and implement canonicalization patterns for ops in different dialects that
2765
/// may consume the results of memref_cast operations. Such foldable memref_cast
2766
/// operations are typically inserted as `view` and `subview` ops are
2767
/// canonicalized, to preserve the type compatibility of their uses.
2768
///
2769
/// Returns true when all conditions are met:
2770
/// 1. source and result are ranked memrefs with strided semantics and same
2771
/// element type and rank.
2772
/// 2. each of the source's size, offset or stride has more static information
2773
/// than the corresponding result's size, offset or stride.
2774
///
2775
/// Example 1:
2776
/// ```mlir
2777
///   %1 = memref_cast %0 : memref<8x16xf32> to memref<?x?xf32>
2778
///   %2 = consumer %1 ... : memref<?x?xf32> ...
2779
/// ```
2780
///
2781
/// may fold into:
2782
///
2783
/// ```mlir
2784
///   %2 = consumer %0 ... : memref<8x16xf32> ...
2785
/// ```
2786
///
2787
/// Example 2:
2788
/// ```
2789
///   %1 = memref_cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
2790
///          to memref<?x?xf32>
2791
///   consumer %1 : memref<?x?xf32> ...
2792
/// ```
2793
///
2794
/// may fold into:
2795
///
2796
/// ```
2797
///   consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
2798
/// ```
2799
0
bool mlir::canFoldIntoConsumerOp(MemRefCastOp castOp) {
2800
0
  MemRefType sourceType = castOp.source().getType().dyn_cast<MemRefType>();
2801
0
  MemRefType resultType = castOp.getType().dyn_cast<MemRefType>();
2802
0
2803
0
  // Requires ranked MemRefType.
2804
0
  if (!sourceType || !resultType)
2805
0
    return false;
2806
0
2807
0
  // Requires same elemental type.
2808
0
  if (sourceType.getElementType() != resultType.getElementType())
2809
0
    return false;
2810
0
2811
0
  // Requires same rank.
2812
0
  if (sourceType.getRank() != resultType.getRank())
2813
0
    return false;
2814
0
2815
0
  // Only fold casts between strided memref forms.
2816
0
  int64_t sourceOffset, resultOffset;
2817
0
  SmallVector<int64_t, 4> sourceStrides, resultStrides;
2818
0
  if (failed(getStridesAndOffset(sourceType, sourceStrides, sourceOffset)) ||
2819
0
      failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
2820
0
    return false;
2821
0
2822
0
  // If cast is towards more static sizes along any dimension, don't fold.
2823
0
  for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
2824
0
    auto ss = std::get<0>(it), st = std::get<1>(it);
2825
0
    if (ss != st)
2826
0
      if (MemRefType::isDynamic(ss) && !MemRefType::isDynamic(st))
2827
0
        return false;
2828
0
  }
2829
0
2830
0
  // If cast is towards more static offset along any dimension, don't fold.
2831
0
  if (sourceOffset != resultOffset)
2832
0
    if (MemRefType::isDynamicStrideOrOffset(sourceOffset) &&
2833
0
        !MemRefType::isDynamicStrideOrOffset(resultOffset))
2834
0
      return false;
2835
0
2836
0
  // If cast is towards more static strides along any dimension, don't fold.
2837
0
  for (auto it : llvm::zip(sourceStrides, resultStrides)) {
2838
0
    auto ss = std::get<0>(it), st = std::get<1>(it);
2839
0
    if (ss != st)
2840
0
      if (MemRefType::isDynamicStrideOrOffset(ss) &&
2841
0
          !MemRefType::isDynamicStrideOrOffset(st))
2842
0
        return false;
2843
0
  }
2844
0
2845
0
  return true;
2846
0
}
2847
2848
namespace {
2849
/// Pattern to rewrite a subview op with MemRefCast arguments.
2850
/// This essentially pushes memref_cast past its consuming subview when
2851
/// `canFoldIntoConsumerOp` is true.
2852
///
2853
/// Example:
2854
/// ```
2855
///   %0 = memref_cast %V : memref<16x16xf32> to memref<?x?xf32>
2856
///   %1 = subview %0[0, 0][3, 4][1, 1] :
2857
///     memref<?x?xf32> to memref<3x4xf32, offset:?, strides:[?, 1]>
2858
/// ```
2859
/// is rewritten into:
2860
/// ```
2861
///   %0 = subview %V: memref<16x16xf32> to memref<3x4xf32, #[[map0]]>
2862
///   %1 = memref_cast %0: memref<3x4xf32, offset:0, strides:[16, 1]> to
2863
///     memref<3x4xf32, offset:?, strides:[?, 1]>
2864
/// ```
2865
class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
2866
public:
2867
  using OpRewritePattern<SubViewOp>::OpRewritePattern;
2868
2869
  LogicalResult matchAndRewrite(SubViewOp subViewOp,
2870
0
                                PatternRewriter &rewriter) const override {
2871
0
    // Any constant operand, just return to let SubViewOpConstantFolder kick in.
2872
0
    if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
2873
0
          return matchPattern(operand, m_ConstantIndex());
2874
0
        }))
2875
0
      return failure();
2876
0
2877
0
    auto castOp = subViewOp.source().getDefiningOp<MemRefCastOp>();
2878
0
    if (!castOp)
2879
0
      return failure();
2880
0
2881
0
    if (!canFoldIntoConsumerOp(castOp))
2882
0
      return failure();
2883
0
2884
0
    /// Deduce the resultType of the SubViewOp using `inferSubViewResultType` on
2885
0
    /// the cast source operand type and the SubViewOp static information. This
2886
0
    /// is the resulting type if the MemRefCastOp were folded.
2887
0
    Type resultType = SubViewOp::inferSubViewResultType(
2888
0
        castOp.source().getType().cast<MemRefType>(),
2889
0
        extractFromI64ArrayAttr(subViewOp.static_offsets()),
2890
0
        extractFromI64ArrayAttr(subViewOp.static_sizes()),
2891
0
        extractFromI64ArrayAttr(subViewOp.static_strides()));
2892
0
    Value newSubView = rewriter.create<SubViewOp>(
2893
0
        subViewOp.getLoc(), resultType, castOp.source(), subViewOp.offsets(),
2894
0
        subViewOp.sizes(), subViewOp.strides(), subViewOp.static_offsets(),
2895
0
        subViewOp.static_sizes(), subViewOp.static_strides());
2896
0
    rewriter.replaceOpWithNewOp<MemRefCastOp>(subViewOp, subViewOp.getType(),
2897
0
                                              newSubView);
2898
0
    return success();
2899
0
  }
2900
};
2901
} // namespace
2902
2903
void SubViewOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
2904
0
                                            MLIRContext *context) {
2905
0
  results.insert<SubViewOpConstantArgumentFolder, SubViewOpMemRefCastFolder>(
2906
0
      context);
2907
0
}
2908
2909
//===----------------------------------------------------------------------===//
2910
// TensorCastOp
2911
//===----------------------------------------------------------------------===//
2912
2913
0
bool TensorCastOp::areCastCompatible(Type a, Type b) {
2914
0
  auto aT = a.dyn_cast<TensorType>();
2915
0
  auto bT = b.dyn_cast<TensorType>();
2916
0
  if (!aT || !bT)
2917
0
    return false;
2918
0
2919
0
  if (aT.getElementType() != bT.getElementType())
2920
0
    return false;
2921
0
2922
0
  return succeeded(verifyCompatibleShape(aT, bT));
2923
0
}
2924
2925
0
OpFoldResult TensorCastOp::fold(ArrayRef<Attribute> operands) {
2926
0
  return impl::foldCastOp(*this);
2927
0
}
2928
2929
//===----------------------------------------------------------------------===//
2930
// Helpers for Tensor[Load|Store]Op
2931
//===----------------------------------------------------------------------===//
2932
2933
0
static Type getTensorTypeFromMemRefType(Type type) {
2934
0
  if (auto memref = type.dyn_cast<MemRefType>())
2935
0
    return RankedTensorType::get(memref.getShape(), memref.getElementType());
2936
0
  return NoneType::get(type.getContext());
2937
0
}
2938
2939
//===----------------------------------------------------------------------===//
2940
// TruncateIOp
2941
//===----------------------------------------------------------------------===//
2942
2943
0
static LogicalResult verify(TruncateIOp op) {
2944
0
  auto srcType = getElementTypeOrSelf(op.getOperand().getType());
2945
0
  auto dstType = getElementTypeOrSelf(op.getType());
2946
0
2947
0
  if (srcType.isa<IndexType>())
2948
0
    return op.emitError() << srcType << " is not a valid operand type";
2949
0
  if (dstType.isa<IndexType>())
2950
0
    return op.emitError() << dstType << " is not a valid result type";
2951
0
2952
0
  if (srcType.cast<IntegerType>().getWidth() <=
2953
0
      dstType.cast<IntegerType>().getWidth())
2954
0
    return op.emitError("operand type ")
2955
0
           << srcType << " must be wider than result type " << dstType;
2956
0
2957
0
  return success();
2958
0
}
2959
2960
//===----------------------------------------------------------------------===//
2961
// UnsignedDivIOp
2962
//===----------------------------------------------------------------------===//
2963
2964
0
OpFoldResult UnsignedDivIOp::fold(ArrayRef<Attribute> operands) {
2965
0
  assert(operands.size() == 2 && "binary operation takes two operands");
2966
0
2967
0
  // Don't fold if it would require a division by zero.
2968
0
  bool div0 = false;
2969
0
  auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) {
2970
0
    if (div0 || !b) {
2971
0
      div0 = true;
2972
0
      return a;
2973
0
    }
2974
0
    return a.udiv(b);
2975
0
  });
2976
0
2977
0
  // Fold out division by one. Assumes all tensors of all ones are splats.
2978
0
  if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
2979
0
    if (rhs.getValue() == 1)
2980
0
      return lhs();
2981
0
  } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
2982
0
    if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
2983
0
      return lhs();
2984
0
  }
2985
0
2986
0
  return div0 ? Attribute() : result;
2987
0
}
2988
2989
//===----------------------------------------------------------------------===//
2990
// UnsignedRemIOp
2991
//===----------------------------------------------------------------------===//
2992
2993
0
OpFoldResult UnsignedRemIOp::fold(ArrayRef<Attribute> operands) {
2994
0
  assert(operands.size() == 2 && "remi_unsigned takes two operands");
2995
0
2996
0
  auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
2997
0
  if (!rhs)
2998
0
    return {};
2999
0
  auto rhsValue = rhs.getValue();
3000
0
3001
0
  // x % 1 = 0
3002
0
  if (rhsValue.isOneValue())
3003
0
    return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0));
3004
0
3005
0
  // Don't fold if it requires division by zero.
3006
0
  if (rhsValue.isNullValue())
3007
0
    return {};
3008
0
3009
0
  auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
3010
0
  if (!lhs)
3011
0
    return {};
3012
0
  return IntegerAttr::get(lhs.getType(), lhs.getValue().urem(rhsValue));
3013
0
}
3014
3015
//===----------------------------------------------------------------------===//
3016
// ViewOp
3017
//===----------------------------------------------------------------------===//
3018
3019
0
static ParseResult parseViewOp(OpAsmParser &parser, OperationState &result) {
3020
0
  OpAsmParser::OperandType srcInfo;
3021
0
  SmallVector<OpAsmParser::OperandType, 1> offsetInfo;
3022
0
  SmallVector<OpAsmParser::OperandType, 4> sizesInfo;
3023
0
  auto indexType = parser.getBuilder().getIndexType();
3024
0
  Type srcType, dstType;
3025
0
  llvm::SMLoc offsetLoc;
3026
0
  if (parser.parseOperand(srcInfo) || parser.getCurrentLocation(&offsetLoc) ||
3027
0
      parser.parseOperandList(offsetInfo, OpAsmParser::Delimiter::Square))
3028
0
    return failure();
3029
0
3030
0
  if (offsetInfo.size() != 1)
3031
0
    return parser.emitError(offsetLoc) << "expects 1 offset operand";
3032
0
3033
0
  return failure(
3034
0
      parser.parseOperandList(sizesInfo, OpAsmParser::Delimiter::Square) ||
3035
0
      parser.parseOptionalAttrDict(result.attributes) ||
3036
0
      parser.parseColonType(srcType) ||
3037
0
      parser.resolveOperand(srcInfo, srcType, result.operands) ||
3038
0
      parser.resolveOperands(offsetInfo, indexType, result.operands) ||
3039
0
      parser.resolveOperands(sizesInfo, indexType, result.operands) ||
3040
0
      parser.parseKeywordType("to", dstType) ||
3041
0
      parser.addTypeToList(dstType, result.types));
3042
0
}
3043
3044
0
static void print(OpAsmPrinter &p, ViewOp op) {
3045
0
  p << op.getOperationName() << ' ' << op.getOperand(0) << '[';
3046
0
  p.printOperand(op.byte_shift());
3047
0
  p << "][" << op.sizes() << ']';
3048
0
  p.printOptionalAttrDict(op.getAttrs());
3049
0
  p << " : " << op.getOperand(0).getType() << " to " << op.getType();
3050
0
}
3051
3052
0
static LogicalResult verify(ViewOp op) {
3053
0
  auto baseType = op.getOperand(0).getType().cast<MemRefType>();
3054
0
  auto viewType = op.getType();
3055
0
3056
0
  // The base memref should have identity layout map (or none).
3057
0
  if (baseType.getAffineMaps().size() > 1 ||
3058
0
      (baseType.getAffineMaps().size() == 1 &&
3059
0
       !baseType.getAffineMaps()[0].isIdentity()))
3060
0
    return op.emitError("unsupported map for base memref type ") << baseType;
3061
0
3062
0
  // The result memref should have identity layout map (or none).
3063
0
  if (viewType.getAffineMaps().size() > 1 ||
3064
0
      (viewType.getAffineMaps().size() == 1 &&
3065
0
       !viewType.getAffineMaps()[0].isIdentity()))
3066
0
    return op.emitError("unsupported map for result memref type ") << viewType;
3067
0
3068
0
  // The base memref and the view memref should be in the same memory space.
3069
0
  if (baseType.getMemorySpace() != viewType.getMemorySpace())
3070
0
    return op.emitError("different memory spaces specified for base memref "
3071
0
                        "type ")
3072
0
           << baseType << " and view memref type " << viewType;
3073
0
3074
0
  // Verify that we have the correct number of sizes for the result type.
3075
0
  unsigned numDynamicDims = viewType.getNumDynamicDims();
3076
0
  if (op.sizes().size() != numDynamicDims)
3077
0
    return op.emitError("incorrect number of size operands for type ")
3078
0
           << viewType;
3079
0
3080
0
  return success();
3081
0
}
3082
3083
0
Value ViewOp::getViewSource() { return source(); }
3084
3085
namespace {
3086
3087
struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
3088
  using OpRewritePattern<ViewOp>::OpRewritePattern;
3089
3090
  LogicalResult matchAndRewrite(ViewOp viewOp,
3091
0
                                PatternRewriter &rewriter) const override {
3092
0
    // Return if none of the operands are constants.
3093
0
    if (llvm::none_of(viewOp.getOperands(), [](Value operand) {
3094
0
          return matchPattern(operand, m_ConstantIndex());
3095
0
        }))
3096
0
      return failure();
3097
0
3098
0
    // Get result memref type.
3099
0
    auto memrefType = viewOp.getType();
3100
0
3101
0
    // Get offset from old memref view type 'memRefType'.
3102
0
    int64_t oldOffset;
3103
0
    SmallVector<int64_t, 4> oldStrides;
3104
0
    if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset)))
3105
0
      return failure();
3106
0
    assert(oldOffset == 0 && "Expected 0 offset");
3107
0
3108
0
    SmallVector<Value, 4> newOperands;
3109
0
3110
0
    // Offset cannot be folded into result type.
3111
0
3112
0
    // Fold any dynamic dim operands which are produced by a constant.
3113
0
    SmallVector<int64_t, 4> newShapeConstants;
3114
0
    newShapeConstants.reserve(memrefType.getRank());
3115
0
3116
0
    unsigned dynamicDimPos = 0;
3117
0
    unsigned rank = memrefType.getRank();
3118
0
    for (unsigned dim = 0, e = rank; dim < e; ++dim) {
3119
0
      int64_t dimSize = memrefType.getDimSize(dim);
3120
0
      // If this is already static dimension, keep it.
3121
0
      if (!ShapedType::isDynamic(dimSize)) {
3122
0
        newShapeConstants.push_back(dimSize);
3123
0
        continue;
3124
0
      }
3125
0
      auto *defOp = viewOp.sizes()[dynamicDimPos].getDefiningOp();
3126
0
      if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
3127
0
        // Dynamic shape dimension will be folded.
3128
0
        newShapeConstants.push_back(constantIndexOp.getValue());
3129
0
      } else {
3130
0
        // Dynamic shape dimension not folded; copy operand from old memref.
3131
0
        newShapeConstants.push_back(dimSize);
3132
0
        newOperands.push_back(viewOp.sizes()[dynamicDimPos]);
3133
0
      }
3134
0
      dynamicDimPos++;
3135
0
    }
3136
0
3137
0
    // Create new memref type with constant folded dims.
3138
0
    MemRefType newMemRefType =
3139
0
        MemRefType::Builder(memrefType).setShape(newShapeConstants);
3140
0
    // Nothing new, don't fold.
3141
0
    if (newMemRefType == memrefType)
3142
0
      return failure();
3143
0
3144
0
    // Create new ViewOp.
3145
0
    auto newViewOp = rewriter.create<ViewOp>(viewOp.getLoc(), newMemRefType,
3146
0
                                             viewOp.getOperand(0),
3147
0
                                             viewOp.byte_shift(), newOperands);
3148
0
    // Insert a cast so we have the same type as the old memref type.
3149
0
    rewriter.replaceOpWithNewOp<MemRefCastOp>(viewOp, newViewOp,
3150
0
                                              viewOp.getType());
3151
0
    return success();
3152
0
  }
3153
};
3154
3155
struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
3156
  using OpRewritePattern<ViewOp>::OpRewritePattern;
3157
3158
  LogicalResult matchAndRewrite(ViewOp viewOp,
3159
0
                                PatternRewriter &rewriter) const override {
3160
0
    Value memrefOperand = viewOp.getOperand(0);
3161
0
    MemRefCastOp memrefCastOp = memrefOperand.getDefiningOp<MemRefCastOp>();
3162
0
    if (!memrefCastOp)
3163
0
      return failure();
3164
0
    Value allocOperand = memrefCastOp.getOperand();
3165
0
    AllocOp allocOp = allocOperand.getDefiningOp<AllocOp>();
3166
0
    if (!allocOp)
3167
0
      return failure();
3168
0
    rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand,
3169
0
                                        viewOp.byte_shift(), viewOp.sizes());
3170
0
    return success();
3171
0
  }
3172
};
3173
3174
} // end anonymous namespace
3175
3176
void ViewOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
3177
0
                                         MLIRContext *context) {
3178
0
  results.insert<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
3179
0
}
3180
3181
//===----------------------------------------------------------------------===//
3182
// XOrOp
3183
//===----------------------------------------------------------------------===//
3184
3185
0
OpFoldResult XOrOp::fold(ArrayRef<Attribute> operands) {
3186
0
  /// xor(x, 0) -> x
3187
0
  if (matchPattern(rhs(), m_Zero()))
3188
0
    return lhs();
3189
0
  /// xor(x,x) -> 0
3190
0
  if (lhs() == rhs())
3191
0
    return Builder(getContext()).getZeroAttr(getType());
3192
0
3193
0
  return constFoldBinaryOp<IntegerAttr>(operands,
3194
0
                                        [](APInt a, APInt b) { return a ^ b; });
3195
0
}
3196
3197
//===----------------------------------------------------------------------===//
3198
// ZeroExtendIOp
3199
//===----------------------------------------------------------------------===//
3200
3201
0
static LogicalResult verify(ZeroExtendIOp op) {
3202
0
  auto srcType = getElementTypeOrSelf(op.getOperand().getType());
3203
0
  auto dstType = getElementTypeOrSelf(op.getType());
3204
0
3205
0
  if (srcType.isa<IndexType>())
3206
0
    return op.emitError() << srcType << " is not a valid operand type";
3207
0
  if (dstType.isa<IndexType>())
3208
0
    return op.emitError() << dstType << " is not a valid result type";
3209
0
3210
0
  if (srcType.cast<IntegerType>().getWidth() >=
3211
0
      dstType.cast<IntegerType>().getWidth())
3212
0
    return op.emitError("result type ")
3213
0
           << dstType << " must be wider than operand type " << srcType;
3214
0
3215
0
  return success();
3216
0
}
3217
3218
//===----------------------------------------------------------------------===//
3219
// TableGen'd op method definitions
3220
//===----------------------------------------------------------------------===//
3221
3222
#define GET_OP_CLASSES
3223
#include "mlir/Dialect/StandardOps/IR/Ops.cpp.inc"