/home/arjun/llvm-project/mlir/include/mlir/IR/AffineExprVisitor.h
Line | Count | Source (jump to first uncovered line) |
1 | | //===- AffineExprVisitor.h - MLIR AffineExpr Visitor Class ------*- C++ -*-===// |
2 | | // |
3 | | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | | // See https://llvm.org/LICENSE.txt for license information. |
5 | | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | | // |
7 | | //===----------------------------------------------------------------------===// |
8 | | // |
9 | | // This file defines the AffineExpr visitor class. |
10 | | // |
11 | | //===----------------------------------------------------------------------===// |
12 | | |
13 | | #ifndef MLIR_IR_AFFINE_EXPR_VISITOR_H |
14 | | #define MLIR_IR_AFFINE_EXPR_VISITOR_H |
15 | | |
16 | | #include "mlir/IR/AffineExpr.h" |
17 | | |
18 | | namespace mlir { |
19 | | |
20 | | /// Base class for AffineExpr visitors/walkers. |
21 | | /// |
22 | | /// AffineExpr visitors are used when you want to perform different actions |
23 | | /// for different kinds of AffineExprs without having to use lots of casts |
24 | | /// and a big switch instruction. |
25 | | /// |
26 | | /// To define your own visitor, inherit from this class, specifying your |
27 | | /// new type for the 'SubClass' template parameter, and "override" visitXXX |
28 | | /// functions in your class. This class is defined in terms of statically |
29 | | /// resolved overloading, not virtual functions. |
30 | | /// |
31 | | /// For example, here is a visitor that counts the number of for AffineDimExprs |
32 | | /// in an AffineExpr. |
33 | | /// |
34 | | /// /// Declare the class. Note that we derive from AffineExprVisitor |
35 | | /// /// instantiated with our new subclasses_ type. |
36 | | /// |
37 | | /// struct DimExprCounter : public AffineExprVisitor<DimExprCounter> { |
38 | | /// unsigned numDimExprs; |
39 | | /// DimExprCounter() : numDimExprs(0) {} |
40 | | /// void visitDimExpr(AffineDimExpr expr) { ++numDimExprs; } |
41 | | /// }; |
42 | | /// |
43 | | /// And this class would be used like this: |
44 | | /// DimExprCounter dec; |
45 | | /// dec.visit(affineExpr); |
46 | | /// numDimExprs = dec.numDimExprs; |
47 | | /// |
48 | | /// AffineExprVisitor provides visit methods for the following binary affine |
49 | | /// op expressions: |
50 | | /// AffineBinaryAddOpExpr, AffineBinaryMulOpExpr, |
51 | | /// AffineBinaryModOpExpr, AffineBinaryFloorDivOpExpr, |
52 | | /// AffineBinaryCeilDivOpExpr. Note that default implementations of these |
53 | | /// methods will call the general AffineBinaryOpExpr method. |
54 | | /// |
55 | | /// In addition, visit methods are provided for the following affine |
56 | | // expressions: AffineConstantExpr, AffineDimExpr, and |
57 | | // AffineSymbolExpr. |
58 | | /// |
59 | | /// Note that if you don't implement visitXXX for some affine expression type, |
60 | | /// the visitXXX method for Instruction superclass will be invoked. |
61 | | /// |
62 | | /// Note that this class is specifically designed as a template to avoid |
63 | | /// virtual function call overhead. Defining and using a AffineExprVisitor is |
64 | | /// just as efficient as having your own switch instruction over the instruction |
65 | | /// opcode. |
66 | | |
67 | | template <typename SubClass, typename RetTy = void> class AffineExprVisitor { |
68 | | //===--------------------------------------------------------------------===// |
69 | | // Interface code - This is the public interface of the AffineExprVisitor |
70 | | // that you use to visit affine expressions... |
71 | | public: |
72 | | // Function to walk an AffineExpr (in post order). |
73 | | RetTy walkPostOrder(AffineExpr expr) { |
74 | | static_assert(std::is_base_of<AffineExprVisitor, SubClass>::value, |
75 | | "Must instantiate with a derived type of AffineExprVisitor"); |
76 | | switch (expr.getKind()) { |
77 | | case AffineExprKind::Add: { |
78 | | auto binOpExpr = expr.cast<AffineBinaryOpExpr>(); |
79 | | walkOperandsPostOrder(binOpExpr); |
80 | | return static_cast<SubClass *>(this)->visitAddExpr(binOpExpr); |
81 | | } |
82 | | case AffineExprKind::Mul: { |
83 | | auto binOpExpr = expr.cast<AffineBinaryOpExpr>(); |
84 | | walkOperandsPostOrder(binOpExpr); |
85 | | return static_cast<SubClass *>(this)->visitMulExpr(binOpExpr); |
86 | | } |
87 | | case AffineExprKind::Mod: { |
88 | | auto binOpExpr = expr.cast<AffineBinaryOpExpr>(); |
89 | | walkOperandsPostOrder(binOpExpr); |
90 | | return static_cast<SubClass *>(this)->visitModExpr(binOpExpr); |
91 | | } |
92 | | case AffineExprKind::FloorDiv: { |
93 | | auto binOpExpr = expr.cast<AffineBinaryOpExpr>(); |
94 | | walkOperandsPostOrder(binOpExpr); |
95 | | return static_cast<SubClass *>(this)->visitFloorDivExpr(binOpExpr); |
96 | | } |
97 | | case AffineExprKind::CeilDiv: { |
98 | | auto binOpExpr = expr.cast<AffineBinaryOpExpr>(); |
99 | | walkOperandsPostOrder(binOpExpr); |
100 | | return static_cast<SubClass *>(this)->visitCeilDivExpr(binOpExpr); |
101 | | } |
102 | | case AffineExprKind::Constant: |
103 | | return static_cast<SubClass *>(this)->visitConstantExpr( |
104 | | expr.cast<AffineConstantExpr>()); |
105 | | case AffineExprKind::DimId: |
106 | | return static_cast<SubClass *>(this)->visitDimExpr( |
107 | | expr.cast<AffineDimExpr>()); |
108 | | case AffineExprKind::SymbolId: |
109 | | return static_cast<SubClass *>(this)->visitSymbolExpr( |
110 | | expr.cast<AffineSymbolExpr>()); |
111 | | } |
112 | | } |
113 | | |
114 | | // Function to visit an AffineExpr. |
115 | | RetTy visit(AffineExpr expr) { |
116 | | static_assert(std::is_base_of<AffineExprVisitor, SubClass>::value, |
117 | | "Must instantiate with a derived type of AffineExprVisitor"); |
118 | | switch (expr.getKind()) { |
119 | | case AffineExprKind::Add: { |
120 | | auto binOpExpr = expr.cast<AffineBinaryOpExpr>(); |
121 | | return static_cast<SubClass *>(this)->visitAddExpr(binOpExpr); |
122 | | } |
123 | | case AffineExprKind::Mul: { |
124 | | auto binOpExpr = expr.cast<AffineBinaryOpExpr>(); |
125 | | return static_cast<SubClass *>(this)->visitMulExpr(binOpExpr); |
126 | | } |
127 | | case AffineExprKind::Mod: { |
128 | | auto binOpExpr = expr.cast<AffineBinaryOpExpr>(); |
129 | | return static_cast<SubClass *>(this)->visitModExpr(binOpExpr); |
130 | | } |
131 | | case AffineExprKind::FloorDiv: { |
132 | | auto binOpExpr = expr.cast<AffineBinaryOpExpr>(); |
133 | | return static_cast<SubClass *>(this)->visitFloorDivExpr(binOpExpr); |
134 | | } |
135 | | case AffineExprKind::CeilDiv: { |
136 | | auto binOpExpr = expr.cast<AffineBinaryOpExpr>(); |
137 | | return static_cast<SubClass *>(this)->visitCeilDivExpr(binOpExpr); |
138 | | } |
139 | | case AffineExprKind::Constant: |
140 | | return static_cast<SubClass *>(this)->visitConstantExpr( |
141 | | expr.cast<AffineConstantExpr>()); |
142 | | case AffineExprKind::DimId: |
143 | | return static_cast<SubClass *>(this)->visitDimExpr( |
144 | | expr.cast<AffineDimExpr>()); |
145 | | case AffineExprKind::SymbolId: |
146 | | return static_cast<SubClass *>(this)->visitSymbolExpr( |
147 | | expr.cast<AffineSymbolExpr>()); |
148 | | } |
149 | | llvm_unreachable("Unknown AffineExpr"); |
150 | | } |
151 | | |
152 | | //===--------------------------------------------------------------------===// |
153 | | // Visitation functions... these functions provide default fallbacks in case |
154 | | // the user does not specify what to do for a particular instruction type. |
155 | | // The default behavior is to generalize the instruction type to its subtype |
156 | | // and try visiting the subtype. All of this should be inlined perfectly, |
157 | | // because there are no virtual functions to get in the way. |
158 | | // |
159 | | |
160 | | // Default visit methods. Note that the default op-specific binary op visit |
161 | | // methods call the general visitAffineBinaryOpExpr visit method. |
162 | | void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) {} |
163 | 0 | void visitAddExpr(AffineBinaryOpExpr expr) { |
164 | 0 | static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr); |
165 | 0 | } |
166 | 0 | void visitMulExpr(AffineBinaryOpExpr expr) { |
167 | 0 | static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr); |
168 | 0 | } |
169 | 0 | void visitModExpr(AffineBinaryOpExpr expr) { |
170 | 0 | static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr); |
171 | 0 | } |
172 | 0 | void visitFloorDivExpr(AffineBinaryOpExpr expr) { |
173 | 0 | static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr); |
174 | 0 | } |
175 | 0 | void visitCeilDivExpr(AffineBinaryOpExpr expr) { |
176 | 0 | static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr); |
177 | 0 | } |
178 | | void visitConstantExpr(AffineConstantExpr expr) {} |
179 | | void visitDimExpr(AffineDimExpr expr) {} |
180 | | void visitSymbolExpr(AffineSymbolExpr expr) {} |
181 | | |
182 | | private: |
183 | | // Walk the operands - each operand is itself walked in post order. |
184 | 0 | void walkOperandsPostOrder(AffineBinaryOpExpr expr) { |
185 | 0 | walkPostOrder(expr.getLHS()); |
186 | 0 | walkPostOrder(expr.getRHS()); |
187 | 0 | } Unexecuted instantiation: _ZN4mlir17AffineExprVisitorINS_25SimpleAffineExprFlattenerEvE21walkOperandsPostOrderENS_18AffineBinaryOpExprE Unexecuted instantiation: AffineExpr.cpp:_ZN4mlir17AffineExprVisitorIZNKS_10AffineExpr4walkESt8functionIFvS1_EEE16AffineExprWalkervE21walkOperandsPostOrderENS_18AffineBinaryOpExprE |
188 | | }; |
189 | | |
190 | | // This class is used to flatten a pure affine expression (AffineExpr, |
191 | | // which is in a tree form) into a sum of products (w.r.t constants) when |
192 | | // possible, and in that process simplifying the expression. For a modulo, |
193 | | // floordiv, or a ceildiv expression, an additional identifier, called a local |
194 | | // identifier, is introduced to rewrite the expression as a sum of product |
195 | | // affine expression. Each local identifier is always and by construction a |
196 | | // floordiv of a pure add/mul affine function of dimensional, symbolic, and |
197 | | // other local identifiers, in a non-mutually recursive way. Hence, every local |
198 | | // identifier can ultimately always be recovered as an affine function of |
199 | | // dimensional and symbolic identifiers (involving floordiv's); note however |
200 | | // that by AffineExpr construction, some floordiv combinations are converted to |
201 | | // mod's. The result of the flattening is a flattened expression and a set of |
202 | | // constraints involving just the local variables. |
203 | | // |
204 | | // d2 + (d0 + d1) floordiv 4 is flattened to d2 + q where 'q' is the local |
205 | | // variable introduced, with localVarCst containing 4*q <= d0 + d1 <= 4*q + 3. |
206 | | // |
207 | | // The simplification performed includes the accumulation of contributions for |
208 | | // each dimensional and symbolic identifier together, the simplification of |
209 | | // floordiv/ceildiv/mod expressions and other simplifications that in turn |
210 | | // happen as a result. A simplification that this flattening naturally performs |
211 | | // is of simplifying the numerator and denominator of floordiv/ceildiv, and |
212 | | // folding a modulo expression to a zero, if possible. Three examples are below: |
213 | | // |
214 | | // (d0 + 3 * d1) + d0) - 2 * d1) - d0 simplified to d0 + d1 |
215 | | // (d0 - d0 mod 4 + 4) mod 4 simplified to 0 |
216 | | // (3*d0 + 2*d1 + d0) floordiv 2 + d1 simplified to 2*d0 + 2*d1 |
217 | | // |
218 | | // The way the flattening works for the second example is as follows: d0 % 4 is |
219 | | // replaced by d0 - 4*q with q being introduced: the expression then simplifies |
220 | | // to: (d0 - (d0 - 4q) + 4) = 4q + 4, modulo of which w.r.t 4 simplifies to |
221 | | // zero. Note that an affine expression may not always be expressible purely as |
222 | | // a sum of products involving just the original dimensional and symbolic |
223 | | // identifiers due to the presence of modulo/floordiv/ceildiv expressions that |
224 | | // may not be eliminated after simplification; in such cases, the final |
225 | | // expression can be reconstructed by replacing the local identifiers with their |
226 | | // corresponding explicit form stored in 'localExprs' (note that each of the |
227 | | // explicit forms itself would have been simplified). |
228 | | // |
229 | | // The expression walk method here performs a linear time post order walk that |
230 | | // performs the above simplifications through visit methods, with partial |
231 | | // results being stored in 'operandExprStack'. When a parent expr is visited, |
232 | | // the flattened expressions corresponding to its two operands would already be |
233 | | // on the stack - the parent expression looks at the two flattened expressions |
234 | | // and combines the two. It pops off the operand expressions and pushes the |
235 | | // combined result (although this is done in-place on its LHS operand expr). |
236 | | // When the walk is completed, the flattened form of the top-level expression |
237 | | // would be left on the stack. |
238 | | // |
239 | | // A flattener can be repeatedly used for multiple affine expressions that bind |
240 | | // to the same operands, for example, for all result expressions of an |
241 | | // AffineMap or AffineValueMap. In such cases, using it for multiple expressions |
242 | | // is more efficient than creating a new flattener for each expression since |
243 | | // common identical div and mod expressions appearing across different |
244 | | // expressions are mapped to the same local identifier (same column position in |
245 | | // 'localVarCst'). |
246 | | class SimpleAffineExprFlattener |
247 | | : public AffineExprVisitor<SimpleAffineExprFlattener> { |
248 | | public: |
249 | | // Flattend expression layout: [dims, symbols, locals, constant] |
250 | | // Stack that holds the LHS and RHS operands while visiting a binary op expr. |
251 | | // In future, consider adding a prepass to determine how big the SmallVector's |
252 | | // will be, and linearize this to std::vector<int64_t> to prevent |
253 | | // SmallVector moves on re-allocation. |
254 | | std::vector<SmallVector<int64_t, 8>> operandExprStack; |
255 | | |
256 | | unsigned numDims; |
257 | | unsigned numSymbols; |
258 | | |
259 | | // Number of newly introduced identifiers to flatten mod/floordiv/ceildiv's. |
260 | | unsigned numLocals; |
261 | | |
262 | | // AffineExpr's corresponding to the floordiv/ceildiv/mod expressions for |
263 | | // which new identifiers were introduced; if the latter do not get canceled |
264 | | // out, these expressions can be readily used to reconstruct the AffineExpr |
265 | | // (tree) form. Note that these expressions themselves would have been |
266 | | // simplified (recursively) by this pass. Eg. d0 + (d0 + 2*d1 + d0) ceildiv 4 |
267 | | // will be simplified to d0 + q, where q = (d0 + d1) ceildiv 2. (d0 + d1) |
268 | | // ceildiv 2 would be the local expression stored for q. |
269 | | SmallVector<AffineExpr, 4> localExprs; |
270 | | |
271 | | SimpleAffineExprFlattener(unsigned numDims, unsigned numSymbols); |
272 | | |
273 | 0 | virtual ~SimpleAffineExprFlattener() = default; |
274 | | |
275 | | // Visitor method overrides. |
276 | | void visitMulExpr(AffineBinaryOpExpr expr); |
277 | | void visitAddExpr(AffineBinaryOpExpr expr); |
278 | | void visitDimExpr(AffineDimExpr expr); |
279 | | void visitSymbolExpr(AffineSymbolExpr expr); |
280 | | void visitConstantExpr(AffineConstantExpr expr); |
281 | | void visitCeilDivExpr(AffineBinaryOpExpr expr); |
282 | | void visitFloorDivExpr(AffineBinaryOpExpr expr); |
283 | | |
284 | | // |
285 | | // t = expr mod c <=> t = expr - c*q and c*q <= expr <= c*q + c - 1 |
286 | | // |
287 | | // A mod expression "expr mod c" is thus flattened by introducing a new local |
288 | | // variable q (= expr floordiv c), such that expr mod c is replaced with |
289 | | // 'expr - c * q' and c * q <= expr <= c * q + c - 1 are added to localVarCst. |
290 | | void visitModExpr(AffineBinaryOpExpr expr); |
291 | | |
292 | | protected: |
293 | | // Add a local identifier (needed to flatten a mod, floordiv, ceildiv expr). |
294 | | // The local identifier added is always a floordiv of a pure add/mul affine |
295 | | // function of other identifiers, coefficients of which are specified in |
296 | | // dividend and with respect to a positive constant divisor. localExpr is the |
297 | | // simplified tree expression (AffineExpr) corresponding to the quantifier. |
298 | | virtual void addLocalFloorDivId(ArrayRef<int64_t> dividend, int64_t divisor, |
299 | | AffineExpr localExpr); |
300 | | |
301 | | private: |
302 | | // t = expr floordiv c <=> t = q, c * q <= expr <= c * q + c - 1 |
303 | | // A floordiv is thus flattened by introducing a new local variable q, and |
304 | | // replacing that expression with 'q' while adding the constraints |
305 | | // c * q <= expr <= c * q + c - 1 to localVarCst (done by |
306 | | // FlatAffineConstraints::addLocalFloorDiv). |
307 | | // |
308 | | // A ceildiv is similarly flattened: |
309 | | // t = expr ceildiv c <=> t = (expr + c - 1) floordiv c |
310 | | void visitDivExpr(AffineBinaryOpExpr expr, bool isCeil); |
311 | | |
312 | | int findLocalId(AffineExpr localExpr); |
313 | | |
314 | 0 | inline unsigned getNumCols() const { |
315 | 0 | return numDims + numSymbols + numLocals + 1; |
316 | 0 | } |
317 | 0 | inline unsigned getConstantIndex() const { return getNumCols() - 1; } |
318 | 0 | inline unsigned getLocalVarStartIndex() const { return numDims + numSymbols; } |
319 | 0 | inline unsigned getSymbolStartIndex() const { return numDims; } |
320 | 0 | inline unsigned getDimStartIndex() const { return 0; } |
321 | | }; |
322 | | |
323 | | } // end namespace mlir |
324 | | |
325 | | #endif // MLIR_IR_AFFINE_EXPR_VISITOR_H |