Coverage Report

Created: 2020-06-26 05:44

/home/arjun/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
Line
Count
Source (jump to first uncovered line)
1
//===- Ops.h - Standard MLIR Operations -------------------------*- C++ -*-===//
2
//
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
// See https://llvm.org/LICENSE.txt for license information.
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
//
7
//===----------------------------------------------------------------------===//
8
//
9
// This file defines convenience types for working with standard operations
10
// in the MLIR operation set.
11
//
12
//===----------------------------------------------------------------------===//
13
14
#ifndef MLIR_DIALECT_STANDARDOPS_IR_OPS_H
15
#define MLIR_DIALECT_STANDARDOPS_IR_OPS_H
16
17
#include "mlir/IR/Builders.h"
18
#include "mlir/IR/Dialect.h"
19
#include "mlir/IR/OpImplementation.h"
20
#include "mlir/IR/StandardTypes.h"
21
#include "mlir/Interfaces/CallInterfaces.h"
22
#include "mlir/Interfaces/ControlFlowInterfaces.h"
23
#include "mlir/Interfaces/SideEffectInterfaces.h"
24
#include "mlir/Interfaces/ViewLikeInterface.h"
25
26
// Pull in all enum type definitions and utility function declarations.
27
#include "mlir/Dialect/StandardOps/IR/OpsEnums.h.inc"
28
29
namespace mlir {
30
class AffineMap;
31
class Builder;
32
class FuncOp;
33
class OpBuilder;
34
35
#define GET_OP_CLASSES
36
#include "mlir/Dialect/StandardOps/IR/Ops.h.inc"
37
38
#include "mlir/Dialect/StandardOps/IR/OpsDialect.h.inc"
39
40
/// This is a refinement of the "constant" op for the case where it is
41
/// returning a float value of FloatType.
42
///
43
///   %1 = "std.constant"(){value: 42.0} : bf16
44
///
45
class ConstantFloatOp : public ConstantOp {
46
public:
47
  using ConstantOp::ConstantOp;
48
49
  /// Builds a constant float op producing a float of the specified type.
50
  static void build(OpBuilder &builder, OperationState &result,
51
                    const APFloat &value, FloatType type);
52
  /// Returns the constant's value, read from its "value" FloatAttr.
53
0
  APFloat getValue() { return getAttrOfType<FloatAttr>("value").getValue(); }
54
  /// LLVM-style RTTI hook used by isa<>/dyn_cast<>. NOTE(review): presumably
  /// matches constant ops whose result is a FloatType; defined out of line.
55
  static bool classof(Operation *op);
56
};
57
58
/// This is a refinement of the "constant" op for the case where it is
59
/// returning an integer value of IntegerType.
60
///
61
///   %1 = "std.constant"(){value: 42} : i32
62
///
63
class ConstantIntOp : public ConstantOp {
64
public:
65
  using ConstantOp::ConstantOp;
66
  /// Build a constant int op producing an integer of the specified width.
67
  static void build(OpBuilder &builder, OperationState &result, int64_t value,
68
                    unsigned width);
69
70
  /// Build a constant int op producing an integer with the specified type,
71
  /// which must be an integer type.
72
  static void build(OpBuilder &builder, OperationState &result, int64_t value,
73
                    Type type);
74
  /// Returns the constant's value as an int64_t, read from its "value"
  /// IntegerAttr via getInt().
75
0
  int64_t getValue() { return getAttrOfType<IntegerAttr>("value").getInt(); }
76
  /// LLVM-style RTTI hook used by isa<>/dyn_cast<>. NOTE(review): presumably
  /// matches constant ops whose result is an IntegerType; defined out of line.
77
  static bool classof(Operation *op);
78
};
79
80
/// This is a refinement of the "constant" op for the case where it is
81
/// returning an integer value of Index type.
82
///
83
///   %1 = "std.constant"(){value: 99} : () -> index
84
///
85
class ConstantIndexOp : public ConstantOp {
86
public:
87
  using ConstantOp::ConstantOp;
88
89
  /// Build a constant op producing `value` as a constant of `index` type.
90
  static void build(OpBuilder &builder, OperationState &result, int64_t value);
91
  /// Returns the constant's value as an int64_t, read from its "value"
  /// IntegerAttr via getInt().
92
0
  int64_t getValue() { return getAttrOfType<IntegerAttr>("value").getInt(); }
93
  /// LLVM-style RTTI hook used by isa<>/dyn_cast<>. NOTE(review): presumably
  /// matches constant ops whose result is of index type; defined out of line.
94
  static bool classof(Operation *op);
95
};
96
97
// DmaStartOp starts a non-blocking DMA operation that transfers data from a
98
// source memref to a destination memref. The source and destination memref need
99
// not be of the same dimensionality, but need to have the same elemental type.
100
// The operands include the source and destination memrefs, each followed by its
101
// indices, size of the data transfer in terms of the number of elements (of the
102
// elemental type of the memref), a tag memref with its indices, and optionally
103
// at the end, a stride and a number_of_elements_per_stride arguments. The tag
104
// location is used by a DmaWaitOp to check for completion. The indices of the
105
// source memref, destination memref, and the tag memref have the same
106
// restrictions as any load/store. The optional stride arguments should be of
107
// 'index' type, and specify a stride for the slower memory space (memory space
108
// with a lower memory space id), transferring chunks of
109
// number_of_elements_per_stride every stride until %num_elements are
110
// transferred. Either both or no stride arguments should be specified.
111
//
112
// For example, a DmaStartOp operation that transfers 256 elements of a memref
113
// '%src' in memory space 0 at indices [%i, %j] to memref '%dst' in memory space
114
// 1 at indices [%k, %l], would be specified as follows:
115
//
116
//   %num_elements = constant 256 : index
117
//   %idx = constant 0 : index
118
//   %tag = alloc() : memref<1 x i32, (d0) -> (d0), 2>
119
//   dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%idx] :
120
//     memref<40 x 128 x f32, (d0, d1) -> (d0, d1), 0>,
121
//     memref<2 x 1024 x f32, (d0, d1) -> (d0, d1), 1>,
122
//     memref<1 x i32, (d0) -> (d0), 2>
123
//
124
//   If %stride and %num_elt_per_stride are specified, the DMA is expected to
125
//   transfer %num_elt_per_stride elements every %stride elements apart from
126
//   memory space 0 until %num_elements are transferred.
127
//
128
//   dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%idx], %stride,
129
//             %num_elt_per_stride :
130
//
131
// TODO(mlir-team): add additional operands to allow source and destination
132
// striding, and multiple stride levels.
133
// TODO(andydavis) Consider replacing src/dst memref indices with view memrefs.
134
class DmaStartOp
135
    : public Op<DmaStartOp, OpTrait::VariadicOperands, OpTrait::ZeroResult> {
136
public:
137
  using Op::Op;
138
  // Operand layout (positions are computed from the memref ranks):
  //   srcMemRef, srcIndices..., dstMemRef, dstIndices..., numElements,
  //   tagMemRef, tagIndices..., [stride, elementsPerStride]
  // `stride` and `elementsPerStride` default to a null Value in build(); per
  // the op documentation, either both or neither should be supplied.
139
  static void build(OpBuilder &builder, OperationState &result, Value srcMemRef,
140
                    ValueRange srcIndices, Value destMemRef,
141
                    ValueRange destIndices, Value numElements, Value tagMemRef,
142
                    ValueRange tagIndices, Value stride = nullptr,
143
                    Value elementsPerStride = nullptr);
144
145
  // Returns the source memref operand (operand 0) of this DMA operation.
146
0
  Value getSrcMemRef() { return getOperand(0); }
147
  // Returns the rank (number of indices) of the source MemRefType.
148
0
  unsigned getSrcMemRefRank() {
149
0
    return getSrcMemRef().getType().cast<MemRefType>().getRank();
150
0
  }
151
  // Returns the source memref indices for this DMA operation.
152
0
  operand_range getSrcIndices() {
153
0
    return {getOperation()->operand_begin() + 1,
154
0
            getOperation()->operand_begin() + 1 + getSrcMemRefRank()};
155
0
  }
156
157
  // Returns the destination memref operand, which directly follows the source
  // memref and its indices.
158
0
  Value getDstMemRef() { return getOperand(1 + getSrcMemRefRank()); }
159
  // Returns the rank (number of indices) of the destination MemRefType.
160
0
  unsigned getDstMemRefRank() {
161
0
    return getDstMemRef().getType().cast<MemRefType>().getRank();
162
0
  }
  // Returns the memory space of the source memref's type.
163
0
  unsigned getSrcMemorySpace() {
164
0
    return getSrcMemRef().getType().cast<MemRefType>().getMemorySpace();
165
0
  }
  // Returns the memory space of the destination memref's type.
166
0
  unsigned getDstMemorySpace() {
167
0
    return getDstMemRef().getType().cast<MemRefType>().getMemorySpace();
168
0
  }
169
170
  // Returns the destination memref indices for this DMA operation.
171
0
  operand_range getDstIndices() {
172
0
    return {getOperation()->operand_begin() + 1 + getSrcMemRefRank() + 1,
173
0
            getOperation()->operand_begin() + 1 + getSrcMemRefRank() + 1 +
174
0
                getDstMemRefRank()};
175
0
  }
176
177
  // Returns the operand holding the number of elements being transferred.
178
0
  Value getNumElements() {
179
0
    return getOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank());
180
0
  }
181
182
  // Returns the tag memref operand, which follows the number of elements.
183
0
  Value getTagMemRef() {
184
0
    return getOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1);
185
0
  }
186
  // Returns the rank (number of indices) of the tag MemRefType.
187
0
  unsigned getTagMemRefRank() {
188
0
    return getTagMemRef().getType().cast<MemRefType>().getRank();
189
0
  }
190
191
  // Returns the tag memref indices for this DMA operation.
192
0
  operand_range getTagIndices() {
193
0
    unsigned tagIndexStartPos =
194
0
        1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1 + 1;
195
0
    return {getOperation()->operand_begin() + tagIndexStartPos,
196
0
            getOperation()->operand_begin() + tagIndexStartPos +
197
0
                getTagMemRefRank()};
198
0
  }
199
200
  /// Returns true if this is a DMA from a slower memory space to a faster one,
  /// i.e. the destination's memory space id is greater than the source's
  /// (a lower memory space id is assumed to be slower).
201
0
  bool isDestMemorySpaceFaster() {
202
0
    return (getSrcMemorySpace() < getDstMemorySpace());
203
0
  }
204
205
  /// Returns true if this is a DMA from a faster memory space to a slower one,
  /// i.e. the source's memory space id is greater than the destination's.
206
0
  bool isSrcMemorySpaceFaster() {
207
0
    // Assumes that a lower number is for a slower memory space.
208
0
    return (getDstMemorySpace() < getSrcMemorySpace());
209
0
  }
210
211
  /// Returns the operand position of whichever memref lives in the faster
212
  /// memory space: 0 for the source memref, or getSrcMemRefRank() + 1 for the
213
  /// destination memref. Asserts if both memrefs share the same memory space.
214
0
  unsigned getFasterMemPos() {
215
0
    assert(isSrcMemorySpaceFaster() || isDestMemorySpaceFaster());
216
0
    return isSrcMemorySpaceFaster() ? 0 : getSrcMemRefRank() + 1;
217
0
  }
218
219
0
  static StringRef getOperationName() { return "std.dma_start"; }
220
  static ParseResult parse(OpAsmParser &parser, OperationState &result);
221
  void print(OpAsmPrinter &p);
222
  LogicalResult verify();
223
224
  LogicalResult fold(ArrayRef<Attribute> cstOperands,
225
                     SmallVectorImpl<OpFoldResult> &results);
226
  // Returns true if the optional trailing stride and elements-per-stride
  // operands are present, i.e. the operand count differs from what the
  // non-strided layout (memrefs, indices, numElements) requires.
227
0
  bool isStrided() {
228
0
    return getNumOperands() != 1 + getSrcMemRefRank() + 1 + getDstMemRefRank() +
229
0
                                   1 + 1 + getTagMemRefRank();
230
0
  }
231
  // Returns the stride operand (second-to-last operand), or a null Value if
  // this DMA is not strided.
232
0
  Value getStride() {
233
0
    if (!isStrided())
234
0
      return nullptr;
235
0
    return getOperand(getNumOperands() - 1 - 1);
236
0
  }
237
  // Returns the elements-per-stride operand (last operand), or a null Value
  // if this DMA is not strided.
238
0
  Value getNumElementsPerStride() {
239
0
    if (!isStrided())
240
0
      return nullptr;
241
0
    return getOperand(getNumOperands() - 1);
242
0
  }
243
};
244
245
// DmaWaitOp blocks until the completion of a DMA operation associated with the
246
// tag element '%tag[%index]'. %tag is a memref, and %index has to be an index
247
// with the same restrictions as any load/store index. %num_elements is the
248
// number of elements associated with the DMA operation. For example:
249
//
250
//   dma_start %src[%i], %dst[%k], %num_elements, %tag[%index] :
251
//     memref<2048 x f32, (d0) -> (d0), 0>,
252
//     memref<256 x f32, (d0) -> (d0), 1>,
253
//     memref<1 x i32, (d0) -> (d0), 2>
254
//   ...
255
//   ...
256
//   dma_wait %tag[%index], %num_elements : memref<1 x i32, (d0) -> (d0), 2>
257
//
258
class DmaWaitOp
259
    : public Op<DmaWaitOp, OpTrait::VariadicOperands, OpTrait::ZeroResult> {
260
public:
261
  using Op::Op;
262
  // Operand layout: tagMemRef, tagIndices..., numElements.
263
  static void build(OpBuilder &builder, OperationState &result, Value tagMemRef,
264
                    ValueRange tagIndices, Value numElements);
265
266
0
  static StringRef getOperationName() { return "std.dma_wait"; }
267
268
  // Returns the tag memref operand (operand 0) of the DMA being waited on.
269
0
  Value getTagMemRef() { return getOperand(0); }
270
271
  // Returns the tag memref indices for this DMA wait operation.
272
0
  operand_range getTagIndices() {
273
0
    return {getOperation()->operand_begin() + 1,
274
0
            getOperation()->operand_begin() + 1 + getTagMemRefRank()};
275
0
  }
276
277
  // Returns the rank (number of indices) of the tag memref.
278
0
  unsigned getTagMemRefRank() {
279
0
    return getTagMemRef().getType().cast<MemRefType>().getRank();
280
0
  }
281
282
  // Returns the operand holding the number of elements transferred by the
  // associated DMA operation; it follows the tag memref and its indices.
283
0
  Value getNumElements() { return getOperand(1 + getTagMemRefRank()); }
284
285
  static ParseResult parse(OpAsmParser &parser, OperationState &result);
286
  void print(OpAsmPrinter &p);
287
  LogicalResult fold(ArrayRef<Attribute> cstOperands,
288
                     SmallVectorImpl<OpFoldResult> &results);
289
  LogicalResult verify();
290
};
291
292
/// Prints the operands in [begin, end) to `p` as a dimension and symbol list.
/// NOTE(review): presumably the first `numDims` operands are printed as
/// dimensions and the remainder as symbols; confirm against the definition.
293
void printDimAndSymbolList(Operation::operand_iterator begin,
294
                           Operation::operand_iterator end, unsigned numDims,
295
                           OpAsmPrinter &p);
296
297
/// Parses a dimension and symbol list into `operands`. NOTE(review): `numDims`
/// is presumably set to the number of dimension operands parsed; confirm.
/// Returns failure (a ParseResult that converts to `true`) if parsing failed.
298
ParseResult parseDimAndSymbolList(OpAsmParser &parser,
299
                                  SmallVectorImpl<Value> &operands,
300
                                  unsigned &numDims);
301
/// Streams a textual form of `range` to `os`.
302
raw_ostream &operator<<(raw_ostream &os, SubViewOp::Range &range);
303
304
/// Determines whether MemRefCastOp casts to a more dynamic version of the
305
/// source memref. This is useful to fold a memref_cast into a consuming op
306
/// and implement canonicalization patterns for ops in different dialects that
307
/// may consume the results of memref_cast operations. Such foldable memref_cast
308
/// operations are typically inserted as `view` and `subview` ops are
309
/// canonicalized, to preserve the type compatibility of their uses.
310
///
311
/// Returns true when all conditions are met:
312
/// 1. source and result are ranked memrefs with strided semantics and same
313
/// element type and rank.
314
/// 2. each of the source's size, offset or stride has more static information
315
/// than the corresponding result's size, offset or stride.
316
///
317
/// Example 1:
318
/// ```mlir
319
///   %1 = memref_cast %0 : memref<8x16xf32> to memref<?x?xf32>
320
///   %2 = consumer %1 ... : memref<?x?xf32> ...
321
/// ```
322
///
323
/// may fold into:
324
///
325
/// ```mlir
326
///   %2 = consumer %0 ... : memref<8x16xf32> ...
327
/// ```
328
///
329
/// Example 2:
330
/// ```mlir
331
///   %1 = memref_cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
332
///          to memref<?x?xf32>
333
///   consumer %1 ... : memref<?x?xf32> ...
334
/// ```
335
///
336
/// may fold into:
337
///
338
/// ```mlir
339
///   consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
340
/// ```
341
bool canFoldIntoConsumerOp(MemRefCastOp castOp);
342
} // end namespace mlir
343
344
#endif // MLIR_DIALECT_STANDARDOPS_IR_OPS_H