/home/arjun/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
Line | Count | Source (jump to first uncovered line) |
1 | | //===- Ops.h - Standard MLIR Operations -------------------------*- C++ -*-===// |
2 | | // |
3 | | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | | // See https://llvm.org/LICENSE.txt for license information. |
5 | | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | | // |
7 | | //===----------------------------------------------------------------------===// |
8 | | // |
9 | | // This file defines convenience types for working with standard operations |
10 | | // in the MLIR operation set. |
11 | | // |
12 | | //===----------------------------------------------------------------------===// |
13 | | |
14 | | #ifndef MLIR_DIALECT_STANDARDOPS_IR_OPS_H |
15 | | #define MLIR_DIALECT_STANDARDOPS_IR_OPS_H |
16 | | |
17 | | #include "mlir/IR/Builders.h" |
18 | | #include "mlir/IR/Dialect.h" |
19 | | #include "mlir/IR/OpImplementation.h" |
20 | | #include "mlir/IR/StandardTypes.h" |
21 | | #include "mlir/Interfaces/CallInterfaces.h" |
22 | | #include "mlir/Interfaces/ControlFlowInterfaces.h" |
23 | | #include "mlir/Interfaces/SideEffectInterfaces.h" |
24 | | #include "mlir/Interfaces/ViewLikeInterface.h" |
25 | | |
26 | | // Pull in all enum type definitions and utility function declarations. |
27 | | #include "mlir/Dialect/StandardOps/IR/OpsEnums.h.inc" |
28 | | |
29 | | namespace mlir { |
30 | | class AffineMap; |
31 | | class Builder; |
32 | | class FuncOp; |
33 | | class OpBuilder; |
34 | | |
35 | | #define GET_OP_CLASSES |
36 | | #include "mlir/Dialect/StandardOps/IR/Ops.h.inc" |
37 | | |
38 | | #include "mlir/Dialect/StandardOps/IR/OpsDialect.h.inc" |
39 | | |
40 | | /// This is a refinement of the "constant" op for the case where it is |
41 | | /// returning a float value of FloatType. |
42 | | /// |
43 | | /// %1 = "std.constant"(){value: 42.0} : bf16 |
44 | | /// |
45 | | class ConstantFloatOp : public ConstantOp { |
46 | | public: |
47 | | using ConstantOp::ConstantOp; |
48 | | |
49 | | /// Builds a constant float op producing a float of the specified type. |
50 | | static void build(OpBuilder &builder, OperationState &result, |
51 | | const APFloat &value, FloatType type); |
52 | | |
53 | 0 | APFloat getValue() { return getAttrOfType<FloatAttr>("value").getValue(); } |
54 | | |
55 | | static bool classof(Operation *op); |
56 | | }; |
57 | | |
58 | | /// This is a refinement of the "constant" op for the case where it is |
59 | | /// returning an integer value of IntegerType. |
60 | | /// |
61 | | /// %1 = "std.constant"(){value: 42} : i32 |
62 | | /// |
63 | | class ConstantIntOp : public ConstantOp { |
64 | | public: |
65 | | using ConstantOp::ConstantOp; |
66 | | /// Build a constant int op producing an integer of the specified width. |
67 | | static void build(OpBuilder &builder, OperationState &result, int64_t value, |
68 | | unsigned width); |
69 | | |
70 | | /// Build a constant int op producing an integer with the specified type, |
71 | | /// which must be an integer type. |
72 | | static void build(OpBuilder &builder, OperationState &result, int64_t value, |
73 | | Type type); |
74 | | |
75 | 0 | int64_t getValue() { return getAttrOfType<IntegerAttr>("value").getInt(); } |
76 | | |
77 | | static bool classof(Operation *op); |
78 | | }; |
79 | | |
80 | | /// This is a refinement of the "constant" op for the case where it is |
81 | | /// returning an integer value of Index type. |
82 | | /// |
83 | | /// %1 = "std.constant"(){value: 99} : () -> index |
84 | | /// |
85 | | class ConstantIndexOp : public ConstantOp { |
86 | | public: |
87 | | using ConstantOp::ConstantOp; |
88 | | |
89 | | /// Build a constant int op producing an index. |
90 | | static void build(OpBuilder &builder, OperationState &result, int64_t value); |
91 | | |
92 | 0 | int64_t getValue() { return getAttrOfType<IntegerAttr>("value").getInt(); } |
93 | | |
94 | | static bool classof(Operation *op); |
95 | | }; |
96 | | |
97 | | // DmaStartOp starts a non-blocking DMA operation that transfers data from a |
98 | | // source memref to a destination memref. The source and destination memref need |
99 | | // not be of the same dimensionality, but need to have the same elemental type. |
100 | | // The operands include the source and destination memref's each followed by its |
101 | | // indices, size of the data transfer in terms of the number of elements (of the |
102 | | // elemental type of the memref), a tag memref with its indices, and optionally |
103 | | // at the end, a stride and a number_of_elements_per_stride arguments. The tag |
104 | | // location is used by a DmaWaitOp to check for completion. The indices of the |
105 | | // source memref, destination memref, and the tag memref have the same |
106 | | // restrictions as any load/store. The optional stride arguments should be of |
107 | | // 'index' type, and specify a stride for the slower memory space (memory space |
108 | | // with a lower memory space id), transferring chunks of |
109 | | // number_of_elements_per_stride every stride until %num_elements are |
110 | | // transferred. Either both or no stride arguments should be specified. |
111 | | // |
112 | | // For example, a DmaStartOp operation that transfers 256 elements of a memref |
113 | | // '%src' in memory space 0 at indices [%i, %j] to memref '%dst' in memory space |
114 | | // 1 at indices [%k, %l], would be specified as follows: |
115 | | // |
116 | | // %num_elements = constant 256 |
117 | | // %idx = constant 0 : index |
118 | | // %tag = alloc() : memref<1 x i32, (d0) -> (d0), 4> |
119 | | // dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%idx] : |
120 | | // memref<40 x 128 x f32>, (d0) -> (d0), 0>, |
121 | | // memref<2 x 1024 x f32>, (d0) -> (d0), 1>, |
122 | | // memref<1 x i32>, (d0) -> (d0), 2> |
123 | | // |
124 | | // If %stride and %num_elt_per_stride are specified, the DMA is expected to |
125 | | // transfer %num_elt_per_stride elements every %stride elements apart from |
126 | | // memory space 0 until %num_elements are transferred. |
127 | | // |
128 | | // dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%idx], %stride, |
129 | | // %num_elt_per_stride : |
130 | | // |
131 | | // TODO(mlir-team): add additional operands to allow source and destination |
132 | | // striding, and multiple stride levels. |
133 | | // TODO(andydavis) Consider replacing src/dst memref indices with view memrefs. |
134 | | class DmaStartOp |
135 | | : public Op<DmaStartOp, OpTrait::VariadicOperands, OpTrait::ZeroResult> { |
136 | | public: |
137 | | using Op::Op; |
138 | | |
139 | | static void build(OpBuilder &builder, OperationState &result, Value srcMemRef, |
140 | | ValueRange srcIndices, Value destMemRef, |
141 | | ValueRange destIndices, Value numElements, Value tagMemRef, |
142 | | ValueRange tagIndices, Value stride = nullptr, |
143 | | Value elementsPerStride = nullptr); |
144 | | |
145 | | // Returns the source MemRefType for this DMA operation. |
146 | 0 | Value getSrcMemRef() { return getOperand(0); } |
147 | | // Returns the rank (number of indices) of the source MemRefType. |
148 | 0 | unsigned getSrcMemRefRank() { |
149 | 0 | return getSrcMemRef().getType().cast<MemRefType>().getRank(); |
150 | 0 | } |
151 | | // Returns the source memref indices for this DMA operation. |
152 | 0 | operand_range getSrcIndices() { |
153 | 0 | return {getOperation()->operand_begin() + 1, |
154 | 0 | getOperation()->operand_begin() + 1 + getSrcMemRefRank()}; |
155 | 0 | } |
156 | | |
157 | | // Returns the destination MemRefType for this DMA operations. |
158 | 0 | Value getDstMemRef() { return getOperand(1 + getSrcMemRefRank()); } |
159 | | // Returns the rank (number of indices) of the destination MemRefType. |
160 | 0 | unsigned getDstMemRefRank() { |
161 | 0 | return getDstMemRef().getType().cast<MemRefType>().getRank(); |
162 | 0 | } |
163 | 0 | unsigned getSrcMemorySpace() { |
164 | 0 | return getSrcMemRef().getType().cast<MemRefType>().getMemorySpace(); |
165 | 0 | } |
166 | 0 | unsigned getDstMemorySpace() { |
167 | 0 | return getDstMemRef().getType().cast<MemRefType>().getMemorySpace(); |
168 | 0 | } |
169 | | |
170 | | // Returns the destination memref indices for this DMA operation. |
171 | 0 | operand_range getDstIndices() { |
172 | 0 | return {getOperation()->operand_begin() + 1 + getSrcMemRefRank() + 1, |
173 | 0 | getOperation()->operand_begin() + 1 + getSrcMemRefRank() + 1 + |
174 | 0 | getDstMemRefRank()}; |
175 | 0 | } |
176 | | |
177 | | // Returns the number of elements being transferred by this DMA operation. |
178 | 0 | Value getNumElements() { |
179 | 0 | return getOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank()); |
180 | 0 | } |
181 | | |
182 | | // Returns the Tag MemRef for this DMA operation. |
183 | 0 | Value getTagMemRef() { |
184 | 0 | return getOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1); |
185 | 0 | } |
186 | | // Returns the rank (number of indices) of the tag MemRefType. |
187 | 0 | unsigned getTagMemRefRank() { |
188 | 0 | return getTagMemRef().getType().cast<MemRefType>().getRank(); |
189 | 0 | } |
190 | | |
191 | | // Returns the tag memref index for this DMA operation. |
192 | 0 | operand_range getTagIndices() { |
193 | 0 | unsigned tagIndexStartPos = |
194 | 0 | 1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1 + 1; |
195 | 0 | return {getOperation()->operand_begin() + tagIndexStartPos, |
196 | 0 | getOperation()->operand_begin() + tagIndexStartPos + |
197 | 0 | getTagMemRefRank()}; |
198 | 0 | } |
199 | | |
200 | | /// Returns true if this is a DMA from a faster memory space to a slower one. |
201 | 0 | bool isDestMemorySpaceFaster() { |
202 | 0 | return (getSrcMemorySpace() < getDstMemorySpace()); |
203 | 0 | } |
204 | | |
205 | | /// Returns true if this is a DMA from a slower memory space to a faster one. |
206 | 0 | bool isSrcMemorySpaceFaster() { |
207 | 0 | // Assumes that a lower number is for a slower memory space. |
208 | 0 | return (getDstMemorySpace() < getSrcMemorySpace()); |
209 | 0 | } |
210 | | |
211 | | /// Given a DMA start operation, returns the operand position of either the |
212 | | /// source or destination memref depending on the one that is at the higher |
213 | | /// level of the memory hierarchy. Asserts failure if neither is true. |
214 | 0 | unsigned getFasterMemPos() { |
215 | 0 | assert(isSrcMemorySpaceFaster() || isDestMemorySpaceFaster()); |
216 | 0 | return isSrcMemorySpaceFaster() ? 0 : getSrcMemRefRank() + 1; |
217 | 0 | } |
218 | | |
219 | 0 | static StringRef getOperationName() { return "std.dma_start"; } |
220 | | static ParseResult parse(OpAsmParser &parser, OperationState &result); |
221 | | void print(OpAsmPrinter &p); |
222 | | LogicalResult verify(); |
223 | | |
224 | | LogicalResult fold(ArrayRef<Attribute> cstOperands, |
225 | | SmallVectorImpl<OpFoldResult> &results); |
226 | | |
227 | 0 | bool isStrided() { |
228 | 0 | return getNumOperands() != 1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + |
229 | 0 | 1 + 1 + getTagMemRefRank(); |
230 | 0 | } |
231 | | |
232 | 0 | Value getStride() { |
233 | 0 | if (!isStrided()) |
234 | 0 | return nullptr; |
235 | 0 | return getOperand(getNumOperands() - 1 - 1); |
236 | 0 | } |
237 | | |
238 | 0 | Value getNumElementsPerStride() { |
239 | 0 | if (!isStrided()) |
240 | 0 | return nullptr; |
241 | 0 | return getOperand(getNumOperands() - 1); |
242 | 0 | } |
243 | | }; |
244 | | |
245 | | // DmaWaitOp blocks until the completion of a DMA operation associated with the |
246 | | // tag element '%tag[%index]'. %tag is a memref, and %index has to be an index |
247 | | // with the same restrictions as any load/store index. %num_elements is the |
248 | | // number of elements associated with the DMA operation. For example: |
249 | | // |
250 | | // dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%index] : |
251 | | // memref<2048 x f32>, (d0) -> (d0), 0>, |
252 | | // memref<256 x f32>, (d0) -> (d0), 1> |
253 | | // memref<1 x i32>, (d0) -> (d0), 2> |
254 | | // ... |
255 | | // ... |
256 | | // dma_wait %tag[%index], %num_elements : memref<1 x i32, (d0) -> (d0), 2> |
257 | | // |
258 | | class DmaWaitOp |
259 | | : public Op<DmaWaitOp, OpTrait::VariadicOperands, OpTrait::ZeroResult> { |
260 | | public: |
261 | | using Op::Op; |
262 | | |
263 | | static void build(OpBuilder &builder, OperationState &result, Value tagMemRef, |
264 | | ValueRange tagIndices, Value numElements); |
265 | | |
266 | 0 | static StringRef getOperationName() { return "std.dma_wait"; } |
267 | | |
268 | | // Returns the Tag MemRef associated with the DMA operation being waited on. |
269 | 0 | Value getTagMemRef() { return getOperand(0); } |
270 | | |
271 | | // Returns the tag memref index for this DMA operation. |
272 | 0 | operand_range getTagIndices() { |
273 | 0 | return {getOperation()->operand_begin() + 1, |
274 | 0 | getOperation()->operand_begin() + 1 + getTagMemRefRank()}; |
275 | 0 | } |
276 | | |
277 | | // Returns the rank (number of indices) of the tag memref. |
278 | 0 | unsigned getTagMemRefRank() { |
279 | 0 | return getTagMemRef().getType().cast<MemRefType>().getRank(); |
280 | 0 | } |
281 | | |
282 | | // Returns the number of elements transferred in the associated DMA operation. |
283 | 0 | Value getNumElements() { return getOperand(1 + getTagMemRefRank()); } |
284 | | |
285 | | static ParseResult parse(OpAsmParser &parser, OperationState &result); |
286 | | void print(OpAsmPrinter &p); |
287 | | LogicalResult fold(ArrayRef<Attribute> cstOperands, |
288 | | SmallVectorImpl<OpFoldResult> &results); |
289 | | LogicalResult verify(); |
290 | | }; |
291 | | |
292 | | /// Prints dimension and symbol list. |
293 | | void printDimAndSymbolList(Operation::operand_iterator begin, |
294 | | Operation::operand_iterator end, unsigned numDims, |
295 | | OpAsmPrinter &p); |
296 | | |
297 | | /// Parses dimension and symbol list and returns true if parsing failed. |
298 | | ParseResult parseDimAndSymbolList(OpAsmParser &parser, |
299 | | SmallVectorImpl<Value> &operands, |
300 | | unsigned &numDims); |
301 | | |
302 | | raw_ostream &operator<<(raw_ostream &os, SubViewOp::Range &range); |
303 | | |
304 | | /// Determines whether MemRefCastOp casts to a more dynamic version of the |
305 | | /// source memref. This is useful to to fold a memref_cast into a consuming op |
306 | | /// and implement canonicalization patterns for ops in different dialects that |
307 | | /// may consume the results of memref_cast operations. Such foldable memref_cast |
308 | | /// operations are typically inserted as `view` and `subview` ops are |
309 | | /// canonicalized, to preserve the type compatibility of their uses. |
310 | | /// |
311 | | /// Returns true when all conditions are met: |
312 | | /// 1. source and result are ranked memrefs with strided semantics and same |
313 | | /// element type and rank. |
314 | | /// 2. each of the source's size, offset or stride has more static information |
315 | | /// than the corresponding result's size, offset or stride. |
316 | | /// |
317 | | /// Example 1: |
318 | | /// ```mlir |
319 | | /// %1 = memref_cast %0 : memref<8x16xf32> to memref<?x?xf32> |
320 | | /// %2 = consumer %1 ... : memref<?x?xf32> ... |
321 | | /// ``` |
322 | | /// |
323 | | /// may fold into: |
324 | | /// |
325 | | /// ```mlir |
326 | | /// %2 = consumer %0 ... : memref<8x16xf32> ... |
327 | | /// ``` |
328 | | /// |
329 | | /// Example 2: |
330 | | /// ``` |
331 | | /// %1 = memref_cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>> |
332 | | /// to memref<?x?xf32> |
333 | | /// consumer %1 : memref<?x?xf32> ... |
334 | | /// ``` |
335 | | /// |
336 | | /// may fold into: |
337 | | /// |
338 | | /// ``` |
339 | | /// consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>> |
340 | | /// ``` |
341 | | bool canFoldIntoConsumerOp(MemRefCastOp castOp); |
342 | | } // end namespace mlir |
343 | | |
344 | | #endif // MLIR_DIALECT_IR_STANDARDOPS_IR_OPS_H |