Coverage Report

Created: 2020-06-26 05:44

/home/arjun/llvm-project/mlir/include/mlir/IR/AffineExprVisitor.h
Line
Count
Source (jump to first uncovered line)
1
//===- AffineExprVisitor.h - MLIR AffineExpr Visitor Class ------*- C++ -*-===//
2
//
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
// See https://llvm.org/LICENSE.txt for license information.
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
//
7
//===----------------------------------------------------------------------===//
8
//
9
// This file defines the AffineExpr visitor class.
10
//
11
//===----------------------------------------------------------------------===//
12
13
#ifndef MLIR_IR_AFFINE_EXPR_VISITOR_H
14
#define MLIR_IR_AFFINE_EXPR_VISITOR_H
15
16
#include "mlir/IR/AffineExpr.h"
17
18
namespace mlir {
19
20
/// Base class for AffineExpr visitors/walkers.
21
///
22
/// AffineExpr visitors are used when you want to perform different actions
23
/// for different kinds of AffineExprs without having to use lots of casts
24
/// and a big switch instruction.
25
///
26
/// To define your own visitor, inherit from this class, specifying your
27
/// new type for the 'SubClass' template parameter, and "override" visitXXX
28
/// functions in your class. This class is defined in terms of statically
29
/// resolved overloading, not virtual functions.
30
///
31
/// For example, here is a visitor that counts the number of for AffineDimExprs
32
/// in an AffineExpr.
33
///
34
///  /// Declare the class.  Note that we derive from AffineExprVisitor
35
///  /// instantiated with our new subclasses_ type.
36
///
37
///  struct DimExprCounter : public AffineExprVisitor<DimExprCounter> {
38
///    unsigned numDimExprs;
39
///    DimExprCounter() : numDimExprs(0) {}
40
///    void visitDimExpr(AffineDimExpr expr) { ++numDimExprs; }
41
///  };
42
///
43
///  And this class would be used like this:
44
///    DimExprCounter dec;
45
///    dec.visit(affineExpr);
46
///    numDimExprs = dec.numDimExprs;
47
///
48
/// AffineExprVisitor provides visit methods for the following binary affine
49
/// op expressions:
50
/// AffineBinaryAddOpExpr, AffineBinaryMulOpExpr,
51
/// AffineBinaryModOpExpr, AffineBinaryFloorDivOpExpr,
52
/// AffineBinaryCeilDivOpExpr. Note that default implementations of these
53
/// methods will call the general AffineBinaryOpExpr method.
54
///
55
/// In addition, visit methods are provided for the following affine
56
//  expressions: AffineConstantExpr, AffineDimExpr, and
57
//  AffineSymbolExpr.
58
///
59
/// Note that if you don't implement visitXXX for some affine expression type,
60
/// the visitXXX method for Instruction superclass will be invoked.
61
///
62
/// Note that this class is specifically designed as a template to avoid
63
/// virtual function call overhead. Defining and using a AffineExprVisitor is
64
/// just as efficient as having your own switch instruction over the instruction
65
/// opcode.
66
67
template <typename SubClass, typename RetTy = void> class AffineExprVisitor {
68
  //===--------------------------------------------------------------------===//
69
  // Interface code - This is the public interface of the AffineExprVisitor
70
  // that you use to visit affine expressions...
71
public:
72
  // Function to walk an AffineExpr (in post order).
73
  RetTy walkPostOrder(AffineExpr expr) {
74
    static_assert(std::is_base_of<AffineExprVisitor, SubClass>::value,
75
                  "Must instantiate with a derived type of AffineExprVisitor");
76
    switch (expr.getKind()) {
77
    case AffineExprKind::Add: {
78
      auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
79
      walkOperandsPostOrder(binOpExpr);
80
      return static_cast<SubClass *>(this)->visitAddExpr(binOpExpr);
81
    }
82
    case AffineExprKind::Mul: {
83
      auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
84
      walkOperandsPostOrder(binOpExpr);
85
      return static_cast<SubClass *>(this)->visitMulExpr(binOpExpr);
86
    }
87
    case AffineExprKind::Mod: {
88
      auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
89
      walkOperandsPostOrder(binOpExpr);
90
      return static_cast<SubClass *>(this)->visitModExpr(binOpExpr);
91
    }
92
    case AffineExprKind::FloorDiv: {
93
      auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
94
      walkOperandsPostOrder(binOpExpr);
95
      return static_cast<SubClass *>(this)->visitFloorDivExpr(binOpExpr);
96
    }
97
    case AffineExprKind::CeilDiv: {
98
      auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
99
      walkOperandsPostOrder(binOpExpr);
100
      return static_cast<SubClass *>(this)->visitCeilDivExpr(binOpExpr);
101
    }
102
    case AffineExprKind::Constant:
103
      return static_cast<SubClass *>(this)->visitConstantExpr(
104
          expr.cast<AffineConstantExpr>());
105
    case AffineExprKind::DimId:
106
      return static_cast<SubClass *>(this)->visitDimExpr(
107
          expr.cast<AffineDimExpr>());
108
    case AffineExprKind::SymbolId:
109
      return static_cast<SubClass *>(this)->visitSymbolExpr(
110
          expr.cast<AffineSymbolExpr>());
111
    }
112
  }
113
114
  // Function to visit an AffineExpr.
115
  RetTy visit(AffineExpr expr) {
116
    static_assert(std::is_base_of<AffineExprVisitor, SubClass>::value,
117
                  "Must instantiate with a derived type of AffineExprVisitor");
118
    switch (expr.getKind()) {
119
    case AffineExprKind::Add: {
120
      auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
121
      return static_cast<SubClass *>(this)->visitAddExpr(binOpExpr);
122
    }
123
    case AffineExprKind::Mul: {
124
      auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
125
      return static_cast<SubClass *>(this)->visitMulExpr(binOpExpr);
126
    }
127
    case AffineExprKind::Mod: {
128
      auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
129
      return static_cast<SubClass *>(this)->visitModExpr(binOpExpr);
130
    }
131
    case AffineExprKind::FloorDiv: {
132
      auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
133
      return static_cast<SubClass *>(this)->visitFloorDivExpr(binOpExpr);
134
    }
135
    case AffineExprKind::CeilDiv: {
136
      auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
137
      return static_cast<SubClass *>(this)->visitCeilDivExpr(binOpExpr);
138
    }
139
    case AffineExprKind::Constant:
140
      return static_cast<SubClass *>(this)->visitConstantExpr(
141
          expr.cast<AffineConstantExpr>());
142
    case AffineExprKind::DimId:
143
      return static_cast<SubClass *>(this)->visitDimExpr(
144
          expr.cast<AffineDimExpr>());
145
    case AffineExprKind::SymbolId:
146
      return static_cast<SubClass *>(this)->visitSymbolExpr(
147
          expr.cast<AffineSymbolExpr>());
148
    }
149
    llvm_unreachable("Unknown AffineExpr");
150
  }
151
152
  //===--------------------------------------------------------------------===//
153
  // Visitation functions... these functions provide default fallbacks in case
154
  // the user does not specify what to do for a particular instruction type.
155
  // The default behavior is to generalize the instruction type to its subtype
156
  // and try visiting the subtype.  All of this should be inlined perfectly,
157
  // because there are no virtual functions to get in the way.
158
  //
159
160
  // Default visit methods. Note that the default op-specific binary op visit
161
  // methods call the general visitAffineBinaryOpExpr visit method.
162
  void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) {}
163
0
  void visitAddExpr(AffineBinaryOpExpr expr) {
164
0
    static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
165
0
  }
166
0
  void visitMulExpr(AffineBinaryOpExpr expr) {
167
0
    static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
168
0
  }
169
0
  void visitModExpr(AffineBinaryOpExpr expr) {
170
0
    static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
171
0
  }
172
0
  void visitFloorDivExpr(AffineBinaryOpExpr expr) {
173
0
    static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
174
0
  }
175
0
  void visitCeilDivExpr(AffineBinaryOpExpr expr) {
176
0
    static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
177
0
  }
178
  void visitConstantExpr(AffineConstantExpr expr) {}
179
  void visitDimExpr(AffineDimExpr expr) {}
180
  void visitSymbolExpr(AffineSymbolExpr expr) {}
181
182
private:
183
  // Walk the operands - each operand is itself walked in post order.
184
0
  void walkOperandsPostOrder(AffineBinaryOpExpr expr) {
185
0
    walkPostOrder(expr.getLHS());
186
0
    walkPostOrder(expr.getRHS());
187
0
  }
Unexecuted instantiation: _ZN4mlir17AffineExprVisitorINS_25SimpleAffineExprFlattenerEvE21walkOperandsPostOrderENS_18AffineBinaryOpExprE
Unexecuted instantiation: AffineExpr.cpp:_ZN4mlir17AffineExprVisitorIZNKS_10AffineExpr4walkESt8functionIFvS1_EEE16AffineExprWalkervE21walkOperandsPostOrderENS_18AffineBinaryOpExprE
188
};
189
190
// This class is used to flatten a pure affine expression (AffineExpr,
191
// which is in a tree form) into a sum of products (w.r.t constants) when
192
// possible, and in that process simplifying the expression. For a modulo,
193
// floordiv, or a ceildiv expression, an additional identifier, called a local
194
// identifier, is introduced to rewrite the expression as a sum of product
195
// affine expression. Each local identifier is always and by construction a
196
// floordiv of a pure add/mul affine function of dimensional, symbolic, and
197
// other local identifiers, in a non-mutually recursive way. Hence, every local
198
// identifier can ultimately always be recovered as an affine function of
199
// dimensional and symbolic identifiers (involving floordiv's); note however
200
// that by AffineExpr construction, some floordiv combinations are converted to
201
// mod's. The result of the flattening is a flattened expression and a set of
202
// constraints involving just the local variables.
203
//
204
// d2 + (d0 + d1) floordiv 4  is flattened to d2 + q where 'q' is the local
205
// variable introduced, with localVarCst containing 4*q <= d0 + d1 <= 4*q + 3.
206
//
207
// The simplification performed includes the accumulation of contributions for
208
// each dimensional and symbolic identifier together, the simplification of
209
// floordiv/ceildiv/mod expressions and other simplifications that in turn
210
// happen as a result. A simplification that this flattening naturally performs
211
// is of simplifying the numerator and denominator of floordiv/ceildiv, and
212
// folding a modulo expression to a zero, if possible. Three examples are below:
213
//
214
// (d0 + 3 * d1) + d0) - 2 * d1) - d0    simplified to     d0 + d1
215
// (d0 - d0 mod 4 + 4) mod 4             simplified to     0
216
// (3*d0 + 2*d1 + d0) floordiv 2 + d1    simplified to     2*d0 + 2*d1
217
//
218
// The way the flattening works for the second example is as follows: d0 % 4 is
219
// replaced by d0 - 4*q with q being introduced: the expression then simplifies
220
// to: (d0 - (d0 - 4q) + 4) = 4q + 4, modulo of which w.r.t 4 simplifies to
221
// zero. Note that an affine expression may not always be expressible purely as
222
// a sum of products involving just the original dimensional and symbolic
223
// identifiers due to the presence of modulo/floordiv/ceildiv expressions that
224
// may not be eliminated after simplification; in such cases, the final
225
// expression can be reconstructed by replacing the local identifiers with their
226
// corresponding explicit form stored in 'localExprs' (note that each of the
227
// explicit forms itself would have been simplified).
228
//
229
// The expression walk method here performs a linear time post order walk that
230
// performs the above simplifications through visit methods, with partial
231
// results being stored in 'operandExprStack'. When a parent expr is visited,
232
// the flattened expressions corresponding to its two operands would already be
233
// on the stack - the parent expression looks at the two flattened expressions
234
// and combines the two. It pops off the operand expressions and pushes the
235
// combined result (although this is done in-place on its LHS operand expr).
236
// When the walk is completed, the flattened form of the top-level expression
237
// would be left on the stack.
238
//
239
// A flattener can be repeatedly used for multiple affine expressions that bind
240
// to the same operands, for example, for all result expressions of an
241
// AffineMap or AffineValueMap. In such cases, using it for multiple expressions
242
// is more efficient than creating a new flattener for each expression since
243
// common identical div and mod expressions appearing across different
244
// expressions are mapped to the same local identifier (same column position in
245
// 'localVarCst').
246
class SimpleAffineExprFlattener
247
    : public AffineExprVisitor<SimpleAffineExprFlattener> {
248
public:
249
  // Flattend expression layout: [dims, symbols, locals, constant]
250
  // Stack that holds the LHS and RHS operands while visiting a binary op expr.
251
  // In future, consider adding a prepass to determine how big the SmallVector's
252
  // will be, and linearize this to std::vector<int64_t> to prevent
253
  // SmallVector moves on re-allocation.
254
  std::vector<SmallVector<int64_t, 8>> operandExprStack;
255
256
  unsigned numDims;
257
  unsigned numSymbols;
258
259
  // Number of newly introduced identifiers to flatten mod/floordiv/ceildiv's.
260
  unsigned numLocals;
261
262
  // AffineExpr's corresponding to the floordiv/ceildiv/mod expressions for
263
  // which new identifiers were introduced; if the latter do not get canceled
264
  // out, these expressions can be readily used to reconstruct the AffineExpr
265
  // (tree) form. Note that these expressions themselves would have been
266
  // simplified (recursively) by this pass. Eg. d0 + (d0 + 2*d1 + d0) ceildiv 4
267
  // will be simplified to d0 + q, where q = (d0 + d1) ceildiv 2. (d0 + d1)
268
  // ceildiv 2 would be the local expression stored for q.
269
  SmallVector<AffineExpr, 4> localExprs;
270
271
  SimpleAffineExprFlattener(unsigned numDims, unsigned numSymbols);
272
273
0
  virtual ~SimpleAffineExprFlattener() = default;
274
275
  // Visitor method overrides.
276
  void visitMulExpr(AffineBinaryOpExpr expr);
277
  void visitAddExpr(AffineBinaryOpExpr expr);
278
  void visitDimExpr(AffineDimExpr expr);
279
  void visitSymbolExpr(AffineSymbolExpr expr);
280
  void visitConstantExpr(AffineConstantExpr expr);
281
  void visitCeilDivExpr(AffineBinaryOpExpr expr);
282
  void visitFloorDivExpr(AffineBinaryOpExpr expr);
283
284
  //
285
  // t = expr mod c   <=>  t = expr - c*q and c*q <= expr <= c*q + c - 1
286
  //
287
  // A mod expression "expr mod c" is thus flattened by introducing a new local
288
  // variable q (= expr floordiv c), such that expr mod c is replaced with
289
  // 'expr - c * q' and c * q <= expr <= c * q + c - 1 are added to localVarCst.
290
  void visitModExpr(AffineBinaryOpExpr expr);
291
292
protected:
293
  // Add a local identifier (needed to flatten a mod, floordiv, ceildiv expr).
294
  // The local identifier added is always a floordiv of a pure add/mul affine
295
  // function of other identifiers, coefficients of which are specified in
296
  // dividend and with respect to a positive constant divisor. localExpr is the
297
  // simplified tree expression (AffineExpr) corresponding to the quantifier.
298
  virtual void addLocalFloorDivId(ArrayRef<int64_t> dividend, int64_t divisor,
299
                                  AffineExpr localExpr);
300
301
private:
302
  // t = expr floordiv c   <=> t = q, c * q <= expr <= c * q + c - 1
303
  // A floordiv is thus flattened by introducing a new local variable q, and
304
  // replacing that expression with 'q' while adding the constraints
305
  // c * q <= expr <= c * q + c - 1 to localVarCst (done by
306
  // FlatAffineConstraints::addLocalFloorDiv).
307
  //
308
  // A ceildiv is similarly flattened:
309
  // t = expr ceildiv c   <=> t =  (expr + c - 1) floordiv c
310
  void visitDivExpr(AffineBinaryOpExpr expr, bool isCeil);
311
312
  int findLocalId(AffineExpr localExpr);
313
314
0
  inline unsigned getNumCols() const {
315
0
    return numDims + numSymbols + numLocals + 1;
316
0
  }
317
0
  inline unsigned getConstantIndex() const { return getNumCols() - 1; }
318
0
  inline unsigned getLocalVarStartIndex() const { return numDims + numSymbols; }
319
0
  inline unsigned getSymbolStartIndex() const { return numDims; }
320
0
  inline unsigned getDimStartIndex() const { return 0; }
321
};
322
323
} // end namespace mlir
324
325
#endif // MLIR_IR_AFFINE_EXPR_VISITOR_H