Coverage Report

Created: 2020-06-26 05:44

/home/arjun/llvm-project/mlir/lib/IR/AffineExpr.cpp
Line
Count
Source (jump to first uncovered line)
1
//===- AffineExpr.cpp - MLIR Affine Expr Classes --------------------------===//
2
//
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
// See https://llvm.org/LICENSE.txt for license information.
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
//
7
//===----------------------------------------------------------------------===//
8
9
#include "mlir/IR/AffineExpr.h"
10
#include "AffineExprDetail.h"
11
#include "mlir/IR/AffineExprVisitor.h"
12
#include "mlir/IR/AffineMap.h"
13
#include "mlir/IR/IntegerSet.h"
14
#include "mlir/Support/MathExtras.h"
15
#include "llvm/ADT/STLExtras.h"
16
17
using namespace mlir;
18
using namespace mlir::detail;
19
20
0
MLIRContext *AffineExpr::getContext() const { return expr->context; }
21
22
0
AffineExprKind AffineExpr::getKind() const {
23
0
  return static_cast<AffineExprKind>(expr->getKind());
24
0
}
25
26
/// Walk all of the AffineExprs in this subgraph in postorder.
27
0
void AffineExpr::walk(std::function<void(AffineExpr)> callback) const {
28
0
  struct AffineExprWalker : public AffineExprVisitor<AffineExprWalker> {
29
0
    std::function<void(AffineExpr)> callback;
30
0
31
0
    AffineExprWalker(std::function<void(AffineExpr)> callback)
32
0
        : callback(callback) {}
33
0
34
0
    void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) { callback(expr); }
35
0
    void visitConstantExpr(AffineConstantExpr expr) { callback(expr); }
36
0
    void visitDimExpr(AffineDimExpr expr) { callback(expr); }
37
0
    void visitSymbolExpr(AffineSymbolExpr expr) { callback(expr); }
38
0
  };
39
0
40
0
  AffineExprWalker(callback).walkPostOrder(*this);
41
0
}
42
43
// Dispatch affine expression construction based on kind.
44
AffineExpr mlir::getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs,
45
0
                                       AffineExpr rhs) {
46
0
  if (kind == AffineExprKind::Add)
47
0
    return lhs + rhs;
48
0
  if (kind == AffineExprKind::Mul)
49
0
    return lhs * rhs;
50
0
  if (kind == AffineExprKind::FloorDiv)
51
0
    return lhs.floorDiv(rhs);
52
0
  if (kind == AffineExprKind::CeilDiv)
53
0
    return lhs.ceilDiv(rhs);
54
0
  if (kind == AffineExprKind::Mod)
55
0
    return lhs % rhs;
56
0
57
0
  llvm_unreachable("unknown binary operation on affine expressions");
58
0
}
59
60
/// This method substitutes any uses of dimensions and symbols (e.g.
61
/// dim#0 with dimReplacements[0]) and returns the modified expression tree.
62
AffineExpr
63
AffineExpr::replaceDimsAndSymbols(ArrayRef<AffineExpr> dimReplacements,
64
0
                                  ArrayRef<AffineExpr> symReplacements) const {
65
0
  switch (getKind()) {
66
0
  case AffineExprKind::Constant:
67
0
    return *this;
68
0
  case AffineExprKind::DimId: {
69
0
    unsigned dimId = cast<AffineDimExpr>().getPosition();
70
0
    if (dimId >= dimReplacements.size())
71
0
      return *this;
72
0
    return dimReplacements[dimId];
73
0
  }
74
0
  case AffineExprKind::SymbolId: {
75
0
    unsigned symId = cast<AffineSymbolExpr>().getPosition();
76
0
    if (symId >= symReplacements.size())
77
0
      return *this;
78
0
    return symReplacements[symId];
79
0
  }
80
0
  case AffineExprKind::Add:
81
0
  case AffineExprKind::Mul:
82
0
  case AffineExprKind::FloorDiv:
83
0
  case AffineExprKind::CeilDiv:
84
0
  case AffineExprKind::Mod:
85
0
    auto binOp = cast<AffineBinaryOpExpr>();
86
0
    auto lhs = binOp.getLHS(), rhs = binOp.getRHS();
87
0
    auto newLHS = lhs.replaceDimsAndSymbols(dimReplacements, symReplacements);
88
0
    auto newRHS = rhs.replaceDimsAndSymbols(dimReplacements, symReplacements);
89
0
    if (newLHS == lhs && newRHS == rhs)
90
0
      return *this;
91
0
    return getAffineBinaryOpExpr(getKind(), newLHS, newRHS);
92
0
  }
93
0
  llvm_unreachable("Unknown AffineExpr");
94
0
}
95
96
/// Returns true if this expression is made out of only symbols and
97
/// constants (no dimensional identifiers).
98
0
bool AffineExpr::isSymbolicOrConstant() const {
99
0
  switch (getKind()) {
100
0
  case AffineExprKind::Constant:
101
0
    return true;
102
0
  case AffineExprKind::DimId:
103
0
    return false;
104
0
  case AffineExprKind::SymbolId:
105
0
    return true;
106
0
107
0
  case AffineExprKind::Add:
108
0
  case AffineExprKind::Mul:
109
0
  case AffineExprKind::FloorDiv:
110
0
  case AffineExprKind::CeilDiv:
111
0
  case AffineExprKind::Mod: {
112
0
    auto expr = this->cast<AffineBinaryOpExpr>();
113
0
    return expr.getLHS().isSymbolicOrConstant() &&
114
0
           expr.getRHS().isSymbolicOrConstant();
115
0
  }
116
0
  }
117
0
  llvm_unreachable("Unknown AffineExpr");
118
0
}
119
120
/// Returns true if this is a pure affine expression, i.e., multiplication,
121
/// floordiv, ceildiv, and mod is only allowed w.r.t constants.
122
0
bool AffineExpr::isPureAffine() const {
123
0
  switch (getKind()) {
124
0
  case AffineExprKind::SymbolId:
125
0
  case AffineExprKind::DimId:
126
0
  case AffineExprKind::Constant:
127
0
    return true;
128
0
  case AffineExprKind::Add: {
129
0
    auto op = cast<AffineBinaryOpExpr>();
130
0
    return op.getLHS().isPureAffine() && op.getRHS().isPureAffine();
131
0
  }
132
0
133
0
  case AffineExprKind::Mul: {
134
0
    // TODO: Canonicalize the constants in binary operators to the RHS when
135
0
    // possible, allowing this to merge into the next case.
136
0
    auto op = cast<AffineBinaryOpExpr>();
137
0
    return op.getLHS().isPureAffine() && op.getRHS().isPureAffine() &&
138
0
           (op.getLHS().template isa<AffineConstantExpr>() ||
139
0
            op.getRHS().template isa<AffineConstantExpr>());
140
0
  }
141
0
  case AffineExprKind::FloorDiv:
142
0
  case AffineExprKind::CeilDiv:
143
0
  case AffineExprKind::Mod: {
144
0
    auto op = cast<AffineBinaryOpExpr>();
145
0
    return op.getLHS().isPureAffine() &&
146
0
           op.getRHS().template isa<AffineConstantExpr>();
147
0
  }
148
0
  }
149
0
  llvm_unreachable("Unknown AffineExpr");
150
0
}
151
152
// Returns the greatest known integral divisor of this affine expression.
153
0
int64_t AffineExpr::getLargestKnownDivisor() const {
154
0
  AffineBinaryOpExpr binExpr(nullptr);
155
0
  switch (getKind()) {
156
0
  case AffineExprKind::SymbolId:
157
0
    LLVM_FALLTHROUGH;
158
0
  case AffineExprKind::DimId:
159
0
    return 1;
160
0
  case AffineExprKind::Constant:
161
0
    return std::abs(this->cast<AffineConstantExpr>().getValue());
162
0
  case AffineExprKind::Mul: {
163
0
    binExpr = this->cast<AffineBinaryOpExpr>();
164
0
    return binExpr.getLHS().getLargestKnownDivisor() *
165
0
           binExpr.getRHS().getLargestKnownDivisor();
166
0
  }
167
0
  case AffineExprKind::Add:
168
0
    LLVM_FALLTHROUGH;
169
0
  case AffineExprKind::FloorDiv:
170
0
  case AffineExprKind::CeilDiv:
171
0
  case AffineExprKind::Mod: {
172
0
    binExpr = cast<AffineBinaryOpExpr>();
173
0
    return llvm::GreatestCommonDivisor64(
174
0
        binExpr.getLHS().getLargestKnownDivisor(),
175
0
        binExpr.getRHS().getLargestKnownDivisor());
176
0
  }
177
0
  }
178
0
  llvm_unreachable("Unknown AffineExpr");
179
0
}
180
181
0
bool AffineExpr::isMultipleOf(int64_t factor) const {
182
0
  AffineBinaryOpExpr binExpr(nullptr);
183
0
  uint64_t l, u;
184
0
  switch (getKind()) {
185
0
  case AffineExprKind::SymbolId:
186
0
    LLVM_FALLTHROUGH;
187
0
  case AffineExprKind::DimId:
188
0
    return factor * factor == 1;
189
0
  case AffineExprKind::Constant:
190
0
    return cast<AffineConstantExpr>().getValue() % factor == 0;
191
0
  case AffineExprKind::Mul: {
192
0
    binExpr = cast<AffineBinaryOpExpr>();
193
0
    // It's probably not worth optimizing this further (to not traverse the
194
0
    // whole sub-tree under - it that would require a version of isMultipleOf
195
0
    // that on a 'false' return also returns the largest known divisor).
196
0
    return (l = binExpr.getLHS().getLargestKnownDivisor()) % factor == 0 ||
197
0
           (u = binExpr.getRHS().getLargestKnownDivisor()) % factor == 0 ||
198
0
           (l * u) % factor == 0;
199
0
  }
200
0
  case AffineExprKind::Add:
201
0
  case AffineExprKind::FloorDiv:
202
0
  case AffineExprKind::CeilDiv:
203
0
  case AffineExprKind::Mod: {
204
0
    binExpr = cast<AffineBinaryOpExpr>();
205
0
    return llvm::GreatestCommonDivisor64(
206
0
               binExpr.getLHS().getLargestKnownDivisor(),
207
0
               binExpr.getRHS().getLargestKnownDivisor()) %
208
0
               factor ==
209
0
           0;
210
0
  }
211
0
  }
212
0
  llvm_unreachable("Unknown AffineExpr");
213
0
}
214
215
0
bool AffineExpr::isFunctionOfDim(unsigned position) const {
216
0
  if (getKind() == AffineExprKind::DimId) {
217
0
    return *this == mlir::getAffineDimExpr(position, getContext());
218
0
  }
219
0
  if (auto expr = this->dyn_cast<AffineBinaryOpExpr>()) {
220
0
    return expr.getLHS().isFunctionOfDim(position) ||
221
0
           expr.getRHS().isFunctionOfDim(position);
222
0
  }
223
0
  return false;
224
0
}
225
226
AffineBinaryOpExpr::AffineBinaryOpExpr(AffineExpr::ImplType *ptr)
227
0
    : AffineExpr(ptr) {}
228
0
AffineExpr AffineBinaryOpExpr::getLHS() const {
229
0
  return static_cast<ImplType *>(expr)->lhs;
230
0
}
231
0
AffineExpr AffineBinaryOpExpr::getRHS() const {
232
0
  return static_cast<ImplType *>(expr)->rhs;
233
0
}
234
235
0
AffineDimExpr::AffineDimExpr(AffineExpr::ImplType *ptr) : AffineExpr(ptr) {}
236
0
unsigned AffineDimExpr::getPosition() const {
237
0
  return static_cast<ImplType *>(expr)->position;
238
0
}
239
240
/// Shared implementation of getAffineDimExpr and getAffineSymbolExpr. Both
/// kinds use AffineDimExprStorage keyed on (kind, position) in the context's
/// affine uniquer; the init callback records the owning context.
static AffineExpr getAffineDimOrSymbol(AffineExprKind kind, unsigned position,
                                       MLIRContext *context) {
  StorageUniquer &uniquer = context->getAffineUniquer();
  auto initFn = [context](AffineDimExprStorage *storage) {
    storage->context = context;
  };
  return uniquer.get<AffineDimExprStorage>(initFn, static_cast<unsigned>(kind),
                                           position);
}
250
251
0
/// Returns the uniqued expression for dimension number `position`.
AffineExpr mlir::getAffineDimExpr(unsigned position, MLIRContext *context) {
  const AffineExprKind dimKind = AffineExprKind::DimId;
  return getAffineDimOrSymbol(dimKind, position, context);
}
254
255
AffineSymbolExpr::AffineSymbolExpr(AffineExpr::ImplType *ptr)
256
0
    : AffineExpr(ptr) {}
257
0
unsigned AffineSymbolExpr::getPosition() const {
258
0
  return static_cast<ImplType *>(expr)->position;
259
0
}
260
261
0
/// Returns the uniqued expression for symbol number `position`.
AffineExpr mlir::getAffineSymbolExpr(unsigned position, MLIRContext *context) {
  return getAffineDimOrSymbol(AffineExprKind::SymbolId, position, context);
}
265
266
AffineConstantExpr::AffineConstantExpr(AffineExpr::ImplType *ptr)
267
0
    : AffineExpr(ptr) {}
268
0
int64_t AffineConstantExpr::getValue() const {
269
0
  return static_cast<ImplType *>(expr)->constant;
270
0
}
271
272
0
/// Returns true if this expression is the constant `v`.
bool AffineExpr::operator==(int64_t v) const {
  // Constants are uniqued per (kind, value) in the context, so this matches
  // comparing against getAffineConstantExpr(v, ctx) without inserting a new
  // constant into the uniquer just to perform the comparison.
  auto constExpr = dyn_cast<AffineConstantExpr>();
  return constExpr && constExpr.getValue() == v;
}
275
276
0
/// Returns the uniqued constant expression with value `constant`; the init
/// callback records the owning context on newly created storage.
AffineExpr mlir::getAffineConstantExpr(int64_t constant, MLIRContext *context) {
  StorageUniquer &uniquer = context->getAffineUniquer();
  auto initFn = [context](AffineConstantExprStorage *storage) {
    storage->context = context;
  };
  const auto constantKind = static_cast<unsigned>(AffineExprKind::Constant);
  return uniquer.get<AffineConstantExprStorage>(initFn, constantKind, constant);
}
285
286
/// Simplify add expression. Return nullptr if it can't be simplified.
287
0
static AffineExpr simplifyAdd(AffineExpr lhs, AffineExpr rhs) {
288
0
  auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
289
0
  auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();
290
0
  // Fold if both LHS, RHS are a constant.
291
0
  if (lhsConst && rhsConst)
292
0
    return getAffineConstantExpr(lhsConst.getValue() + rhsConst.getValue(),
293
0
                                 lhs.getContext());
294
0
295
0
  // Canonicalize so that only the RHS is a constant. (4 + d0 becomes d0 + 4).
296
0
  // If only one of them is a symbolic expressions, make it the RHS.
297
0
  if (lhs.isa<AffineConstantExpr>() ||
298
0
      (lhs.isSymbolicOrConstant() && !rhs.isSymbolicOrConstant())) {
299
0
    return rhs + lhs;
300
0
  }
301
0
302
0
  // At this point, if there was a constant, it would be on the right.
303
0
304
0
  // Addition with a zero is a noop, return the other input.
305
0
  if (rhsConst) {
306
0
    if (rhsConst.getValue() == 0)
307
0
      return lhs;
308
0
  }
309
0
  // Fold successive additions like (d0 + 2) + 3 into d0 + 5.
310
0
  auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>();
311
0
  if (lBin && rhsConst && lBin.getKind() == AffineExprKind::Add) {
312
0
    if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>())
313
0
      return lBin.getLHS() + (lrhs.getValue() + rhsConst.getValue());
314
0
  }
315
0
316
0
  // Detect "c1 * expr + c_2 * expr" as "(c1 + c2) * expr".
317
0
  // c1 is rRhsConst, c2 is rLhsConst; firstExpr, secondExpr are their
318
0
  // respective multiplicands.
319
0
  Optional<int64_t> rLhsConst, rRhsConst;
320
0
  AffineExpr firstExpr, secondExpr;
321
0
  AffineConstantExpr rLhsConstExpr;
322
0
  auto lBinOpExpr = lhs.dyn_cast<AffineBinaryOpExpr>();
323
0
  if (lBinOpExpr && lBinOpExpr.getKind() == AffineExprKind::Mul &&
324
0
      (rLhsConstExpr = lBinOpExpr.getRHS().dyn_cast<AffineConstantExpr>())) {
325
0
    rLhsConst = rLhsConstExpr.getValue();
326
0
    firstExpr = lBinOpExpr.getLHS();
327
0
  } else {
328
0
    rLhsConst = 1;
329
0
    firstExpr = lhs;
330
0
  }
331
0
332
0
  auto rBinOpExpr = rhs.dyn_cast<AffineBinaryOpExpr>();
333
0
  AffineConstantExpr rRhsConstExpr;
334
0
  if (rBinOpExpr && rBinOpExpr.getKind() == AffineExprKind::Mul &&
335
0
      (rRhsConstExpr = rBinOpExpr.getRHS().dyn_cast<AffineConstantExpr>())) {
336
0
    rRhsConst = rRhsConstExpr.getValue();
337
0
    secondExpr = rBinOpExpr.getLHS();
338
0
  } else {
339
0
    rRhsConst = 1;
340
0
    secondExpr = rhs;
341
0
  }
342
0
343
0
  if (rLhsConst && rRhsConst && firstExpr == secondExpr)
344
0
    return getAffineBinaryOpExpr(
345
0
        AffineExprKind::Mul, firstExpr,
346
0
        getAffineConstantExpr(rLhsConst.getValue() + rRhsConst.getValue(),
347
0
                              lhs.getContext()));
348
0
349
0
  // When doing successive additions, bring constant to the right: turn (d0 + 2)
350
0
  // + d1 into (d0 + d1) + 2.
351
0
  if (lBin && lBin.getKind() == AffineExprKind::Add) {
352
0
    if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) {
353
0
      return lBin.getLHS() + rhs + lrhs;
354
0
    }
355
0
  }
356
0
357
0
  // Detect and transform "expr - c * (expr floordiv c)" to "expr mod c". This
358
0
  // leads to a much more efficient form when 'c' is a power of two, and in
359
0
  // general a more compact and readable form.
360
0
361
0
  // Process '(expr floordiv c) * (-c)'.
362
0
  if (!rBinOpExpr)
363
0
    return nullptr;
364
0
365
0
  auto lrhs = rBinOpExpr.getLHS();
366
0
  auto rrhs = rBinOpExpr.getRHS();
367
0
368
0
  // Process lrhs, which is 'expr floordiv c'.
369
0
  AffineBinaryOpExpr lrBinOpExpr = lrhs.dyn_cast<AffineBinaryOpExpr>();
370
0
  if (!lrBinOpExpr || lrBinOpExpr.getKind() != AffineExprKind::FloorDiv)
371
0
    return nullptr;
372
0
373
0
  auto llrhs = lrBinOpExpr.getLHS();
374
0
  auto rlrhs = lrBinOpExpr.getRHS();
375
0
376
0
  if (lhs == llrhs && rlrhs == -rrhs) {
377
0
    return lhs % rlrhs;
378
0
  }
379
0
  return nullptr;
380
0
}
381
382
0
/// Adds the constant `v` to this expression.
AffineExpr AffineExpr::operator+(int64_t v) const {
  AffineExpr addend = getAffineConstantExpr(v, getContext());
  return *this + addend;
}

/// Builds `*this + other`: the simplified form when simplifyAdd finds one,
/// otherwise a uniqued Add node.
AffineExpr AffineExpr::operator+(AffineExpr other) const {
  AffineExpr simplified = simplifyAdd(*this, other);
  if (simplified)
    return simplified;
  return getContext()->getAffineUniquer().get<AffineBinaryOpExprStorage>(
      /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Add), *this, other);
}
393
394
/// Simplify a multiply expression. Return nullptr if it can't be simplified.
395
0
static AffineExpr simplifyMul(AffineExpr lhs, AffineExpr rhs) {
396
0
  auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
397
0
  auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();
398
0
399
0
  if (lhsConst && rhsConst)
400
0
    return getAffineConstantExpr(lhsConst.getValue() * rhsConst.getValue(),
401
0
                                 lhs.getContext());
402
0
403
0
  assert(lhs.isSymbolicOrConstant() || rhs.isSymbolicOrConstant());
404
0
405
0
  // Canonicalize the mul expression so that the constant/symbolic term is the
406
0
  // RHS. If both the lhs and rhs are symbolic, swap them if the lhs is a
407
0
  // constant. (Note that a constant is trivially symbolic).
408
0
  if (!rhs.isSymbolicOrConstant() || lhs.isa<AffineConstantExpr>()) {
409
0
    // At least one of them has to be symbolic.
410
0
    return rhs * lhs;
411
0
  }
412
0
413
0
  // At this point, if there was a constant, it would be on the right.
414
0
415
0
  // Multiplication with a one is a noop, return the other input.
416
0
  if (rhsConst) {
417
0
    if (rhsConst.getValue() == 1)
418
0
      return lhs;
419
0
    // Multiplication with zero.
420
0
    if (rhsConst.getValue() == 0)
421
0
      return rhsConst;
422
0
  }
423
0
424
0
  // Fold successive multiplications: eg: (d0 * 2) * 3 into d0 * 6.
425
0
  auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>();
426
0
  if (lBin && rhsConst && lBin.getKind() == AffineExprKind::Mul) {
427
0
    if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>())
428
0
      return lBin.getLHS() * (lrhs.getValue() * rhsConst.getValue());
429
0
  }
430
0
431
0
  // When doing successive multiplication, bring constant to the right: turn (d0
432
0
  // * 2) * d1 into (d0 * d1) * 2.
433
0
  if (lBin && lBin.getKind() == AffineExprKind::Mul) {
434
0
    if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) {
435
0
      return (lBin.getLHS() * rhs) * lrhs;
436
0
    }
437
0
  }
438
0
439
0
  return nullptr;
440
0
}
441
442
0
/// Multiplies this expression by the constant `v`.
AffineExpr AffineExpr::operator*(int64_t v) const {
  AffineExpr multiplier = getAffineConstantExpr(v, getContext());
  return *this * multiplier;
}

/// Builds `*this * other`: the simplified form when simplifyMul finds one,
/// otherwise a uniqued Mul node.
AffineExpr AffineExpr::operator*(AffineExpr other) const {
  AffineExpr simplified = simplifyMul(*this, other);
  if (simplified)
    return simplified;
  return getContext()->getAffineUniquer().get<AffineBinaryOpExprStorage>(
      /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Mul), *this, other);
}
453
454
// Unary minus, delegate to operator*.
455
0
AffineExpr AffineExpr::operator-() const {
456
0
  return *this * getAffineConstantExpr(-1, getContext());
457
0
}
458
459
// Delegate to operator+.
460
0
AffineExpr AffineExpr::operator-(int64_t v) const { return *this + (-v); }
461
0
AffineExpr AffineExpr::operator-(AffineExpr other) const {
462
0
  return *this + (-other);
463
0
}
464
465
0
/// Simplify a floordiv expression. Returns nullptr if it can't be simplified.
static AffineExpr simplifyFloorDiv(AffineExpr lhs, AffineExpr rhs) {
  auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
  auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();

  // mlir floordiv by zero or negative numbers is undefined and preserved as is.
  if (!rhsConst || rhsConst.getValue() < 1)
    return nullptr;

  if (lhsConst)
    return getAffineConstantExpr(
        floorDiv(lhsConst.getValue(), rhsConst.getValue()), lhs.getContext());

  // Division by one is a noop. Compare the value directly, as
  // simplifyCeilDiv does.
  if (rhsConst.getValue() == 1)
    return lhs;

  // Fold floordiv of a multiply with a constant that is a multiple of the
  // divisor. Eg: (i * 128) floordiv 64 = i * 2.
  auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>();
  if (lBin && lBin.getKind() == AffineExprKind::Mul) {
    if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) {
      // rhsConst is known to be a positive constant.
      if (lrhs.getValue() % rhsConst.getValue() == 0)
        return lBin.getLHS() * (lrhs.getValue() / rhsConst.getValue());
    }
  }

  // Simplify (expr1 + expr2) floordiv divConst when either expr1 or expr2 is
  // known to be a multiple of divConst: the multiple divides exactly, so the
  // floordiv distributes over the sum.
  if (lBin && lBin.getKind() == AffineExprKind::Add) {
    int64_t llhsDiv = lBin.getLHS().getLargestKnownDivisor();
    int64_t lrhsDiv = lBin.getRHS().getLargestKnownDivisor();
    // rhsConst is known to be a positive constant.
    if (llhsDiv % rhsConst.getValue() == 0 ||
        lrhsDiv % rhsConst.getValue() == 0)
      return lBin.getLHS().floorDiv(rhsConst.getValue()) +
             lBin.getRHS().floorDiv(rhsConst.getValue());
  }

  return nullptr;
}
507
508
0
/// Builds `*this floordiv v` for a constant divisor `v`.
AffineExpr AffineExpr::floorDiv(uint64_t v) const {
  AffineExpr divisor = getAffineConstantExpr(v, getContext());
  return floorDiv(divisor);
}

/// Builds `*this floordiv other`: the simplified form when simplifyFloorDiv
/// finds one, otherwise a uniqued FloorDiv node.
AffineExpr AffineExpr::floorDiv(AffineExpr other) const {
  AffineExpr simplified = simplifyFloorDiv(*this, other);
  if (simplified)
    return simplified;
  const auto kind = static_cast<unsigned>(AffineExprKind::FloorDiv);
  return getContext()->getAffineUniquer().get<AffineBinaryOpExprStorage>(
      /*initFn=*/{}, kind, *this, other);
}
520
521
0
/// Simplify a ceildiv expression. Returns nullptr if it can't be simplified.
static AffineExpr simplifyCeilDiv(AffineExpr lhs, AffineExpr rhs) {
  auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
  auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();

  // Only a positive constant divisor is simplified; anything else is kept.
  if (!rhsConst || rhsConst.getValue() < 1)
    return nullptr;
  const int64_t divisor = rhsConst.getValue();

  // Constant dividend: fold completely.
  if (lhsConst)
    return getAffineConstantExpr(ceilDiv(lhsConst.getValue(), divisor),
                                 lhs.getContext());

  // x ceildiv 1 == x.
  if (divisor == 1)
    return lhs;

  // (expr * k) ceildiv divisor == expr * (k / divisor) when divisor divides
  // k, e.g. (i * 128) ceildiv 64 == i * 2.
  auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>();
  if (!lBin || lBin.getKind() != AffineExprKind::Mul)
    return nullptr;
  auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>();
  if (lrhs && lrhs.getValue() % divisor == 0)
    return lBin.getLHS() * (lrhs.getValue() / divisor);
  return nullptr;
}
550
551
0
/// Builds `*this ceildiv v` for a constant divisor `v`.
AffineExpr AffineExpr::ceilDiv(uint64_t v) const {
  AffineExpr divisor = getAffineConstantExpr(v, getContext());
  return ceilDiv(divisor);
}

/// Builds `*this ceildiv other`: the simplified form when simplifyCeilDiv
/// finds one, otherwise a uniqued CeilDiv node.
AffineExpr AffineExpr::ceilDiv(AffineExpr other) const {
  AffineExpr simplified = simplifyCeilDiv(*this, other);
  if (simplified)
    return simplified;
  const auto kind = static_cast<unsigned>(AffineExprKind::CeilDiv);
  return getContext()->getAffineUniquer().get<AffineBinaryOpExprStorage>(
      /*initFn=*/{}, kind, *this, other);
}
563
564
0
/// Simplify a mod expression. Returns nullptr if it can't be simplified.
static AffineExpr simplifyMod(AffineExpr lhs, AffineExpr rhs) {
  auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
  auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();

  // mod w.r.t zero or negative numbers is undefined and preserved as is.
  if (!rhsConst || rhsConst.getValue() < 1)
    return nullptr;
  const int64_t modulus = rhsConst.getValue();

  // Constant operand: fold completely.
  if (lhsConst)
    return getAffineConstantExpr(mod(lhsConst.getValue(), modulus),
                                 lhs.getContext());

  // If the LHS is known to be a multiple of the modulus, the result is zero.
  // E.g. (i * 128) mod 64 folds to 0, and less trivially
  // (i*(j*4*(k*32))) mod 128 folds to 0.
  if (lhs.getLargestKnownDivisor() % modulus == 0)
    return getAffineConstantExpr(0, lhs.getContext());

  // (expr1 + expr2) mod c drops whichever addend is a known multiple of c.
  auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>();
  if (!lBin || lBin.getKind() != AffineExprKind::Add)
    return nullptr;
  int64_t llhsDiv = lBin.getLHS().getLargestKnownDivisor();
  int64_t lrhsDiv = lBin.getRHS().getLargestKnownDivisor();
  if (llhsDiv % modulus == 0)
    return lBin.getRHS() % modulus;
  if (lrhsDiv % modulus == 0)
    return lBin.getLHS() % modulus;
  return nullptr;
}
597
598
0
/// Builds `*this mod v` for a constant modulus `v`.
AffineExpr AffineExpr::operator%(uint64_t v) const {
  AffineExpr modulus = getAffineConstantExpr(v, getContext());
  return *this % modulus;
}

/// Builds `*this mod other`: the simplified form when simplifyMod finds one,
/// otherwise a uniqued Mod node.
AffineExpr AffineExpr::operator%(AffineExpr other) const {
  AffineExpr simplified = simplifyMod(*this, other);
  if (simplified)
    return simplified;
  return getContext()->getAffineUniquer().get<AffineBinaryOpExprStorage>(
      /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Mod), *this, other);
}
609
610
0
AffineExpr AffineExpr::compose(AffineMap map) const {
611
0
  SmallVector<AffineExpr, 8> dimReplacements(map.getResults().begin(),
612
0
                                             map.getResults().end());
613
0
  return replaceDimsAndSymbols(dimReplacements, {});
614
0
}
615
0
raw_ostream &mlir::operator<<(raw_ostream &os, AffineExpr expr) {
616
0
  expr.print(os);
617
0
  return os;
618
0
}
619
620
/// Constructs an affine expression from a flat ArrayRef. If there are local
621
/// identifiers (neither dimensional nor symbolic) that appear in the sum of
622
/// products expression, `localExprs` is expected to have the AffineExpr
623
/// for it, and is substituted into. The ArrayRef `flatExprs` is expected to be
624
/// in the format [dims, symbols, locals, constant term].
625
AffineExpr mlir::getAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs,
626
                                           unsigned numDims,
627
                                           unsigned numSymbols,
628
                                           ArrayRef<AffineExpr> localExprs,
629
0
                                           MLIRContext *context) {
630
0
  // Assert expected numLocals = flatExprs.size() - numDims - numSymbols - 1.
631
0
  assert(flatExprs.size() - numDims - numSymbols - 1 == localExprs.size() &&
632
0
         "unexpected number of local expressions");
633
0
634
0
  auto expr = getAffineConstantExpr(0, context);
635
0
  // Dimensions and symbols.
636
0
  for (unsigned j = 0; j < numDims + numSymbols; j++) {
637
0
    if (flatExprs[j] == 0)
638
0
      continue;
639
0
    auto id = j < numDims ? getAffineDimExpr(j, context)
640
0
                          : getAffineSymbolExpr(j - numDims, context);
641
0
    expr = expr + id * flatExprs[j];
642
0
  }
643
0
644
0
  // Local identifiers.
645
0
  for (unsigned j = numDims + numSymbols, e = flatExprs.size() - 1; j < e;
646
0
       j++) {
647
0
    if (flatExprs[j] == 0)
648
0
      continue;
649
0
    auto term = localExprs[j - numDims - numSymbols] * flatExprs[j];
650
0
    expr = expr + term;
651
0
  }
652
0
653
0
  // Constant term.
654
0
  int64_t constTerm = flatExprs[flatExprs.size() - 1];
655
0
  if (constTerm != 0)
656
0
    expr = expr + constTerm;
657
0
  return expr;
658
0
}
659
660
SimpleAffineExprFlattener::SimpleAffineExprFlattener(unsigned numDims,
                                                     unsigned numSymbols)
    : numDims(numDims), numSymbols(numSymbols), numLocals(0) {
  // Reserve room for a few operand rows up front to avoid early regrowth.
  const unsigned initialStackCapacity = 8;
  operandExprStack.reserve(initialStackCapacity);
}
665
666
0
// Flattens `lhs * rhs`. The RHS is a constant here, so the flattened LHS row
// on the operand stack is scaled by it in place after the RHS row is popped.
void SimpleAffineExprFlattener::visitMulExpr(AffineBinaryOpExpr expr) {
  assert(operandExprStack.size() >= 2);
  // This is a pure affine expr; the RHS will be a constant.
  assert(expr.getRHS().isa<AffineConstantExpr>());
  // Only the constant column of the flattened RHS is meaningful.
  auto scale = operandExprStack.back()[getConstantIndex()];
  operandExprStack.pop_back();
  for (auto &coeff : operandExprStack.back())
    coeff *= scale;
}
679
680
0
// Flattens `lhs + rhs` by accumulating the RHS row into the LHS row (the
// second entry from the top of the operand stack) and popping the RHS.
void SimpleAffineExprFlattener::visitAddExpr(AffineBinaryOpExpr expr) {
  assert(operandExprStack.size() >= 2);
  auto &lhs = operandExprStack[operandExprStack.size() - 2];
  const auto &rhs = operandExprStack.back();
  assert(lhs.size() == rhs.size());
  for (unsigned col = 0, numCols = rhs.size(); col != numCols; ++col)
    lhs[col] += rhs[col];
  operandExprStack.pop_back();
}
692
693
//
694
// t = expr mod c   <=>  t = expr - c*q and c*q <= expr <= c*q + c - 1
695
//
696
// A mod expression "expr mod c" is thus flattened by introducing a new local
697
// variable q (= expr floordiv c), such that expr mod c is replaced with
698
// 'expr - c * q' and c * q <= expr <= c * q + c - 1 are added to localVarCst.
699
0
void SimpleAffineExprFlattener::visitModExpr(AffineBinaryOpExpr expr) {
700
0
  assert(operandExprStack.size() >= 2);
701
0
  // This is a pure affine expr; the RHS will be a constant.
702
0
  assert(expr.getRHS().isa<AffineConstantExpr>());
703
0
  auto rhsConst = operandExprStack.back()[getConstantIndex()];
704
0
  operandExprStack.pop_back();
705
0
  auto &lhs = operandExprStack.back();
706
0
  // TODO(bondhugula): handle modulo by zero case when this issue is fixed
707
0
  // at the other places in the IR.
708
0
  assert(rhsConst > 0 && "RHS constant has to be positive");
709
0
710
0
  // Check if the LHS expression is a multiple of modulo factor.
711
0
  unsigned i, e;
712
0
  for (i = 0, e = lhs.size(); i < e; i++)
713
0
    if (lhs[i] % rhsConst != 0)
714
0
      break;
715
0
  // If yes, modulo expression here simplifies to zero.
716
0
  if (i == lhs.size()) {
717
0
    std::fill(lhs.begin(), lhs.end(), 0);
718
0
    return;
719
0
  }
720
0
721
0
  // Add a local variable for the quotient, i.e., expr % c is replaced by
722
0
  // (expr - q * c) where q = expr floordiv c. Do this while canceling out
723
0
  // the GCD of expr and c.
724
0
  SmallVector<int64_t, 8> floorDividend(lhs);
725
0
  uint64_t gcd = rhsConst;
726
0
  for (unsigned i = 0, e = lhs.size(); i < e; i++)
727
0
    gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(lhs[i]));
728
0
  // Simplify the numerator and the denominator.
729
0
  if (gcd != 1) {
730
0
    for (unsigned i = 0, e = floorDividend.size(); i < e; i++)
731
0
      floorDividend[i] = floorDividend[i] / static_cast<int64_t>(gcd);
732
0
  }
733
0
  int64_t floorDivisor = rhsConst / static_cast<int64_t>(gcd);
734
0
735
0
  // Construct the AffineExpr form of the floordiv to store in localExprs.
736
0
  MLIRContext *context = expr.getContext();
737
0
  auto dividendExpr = getAffineExprFromFlatForm(
738
0
      floorDividend, numDims, numSymbols, localExprs, context);
739
0
  auto divisorExpr = getAffineConstantExpr(floorDivisor, context);
740
0
  auto floorDivExpr = dividendExpr.floorDiv(divisorExpr);
741
0
  int loc;
742
0
  if ((loc = findLocalId(floorDivExpr)) == -1) {
743
0
    addLocalFloorDivId(floorDividend, floorDivisor, floorDivExpr);
744
0
    // Set result at top of stack to "lhs - rhsConst * q".
745
0
    lhs[getLocalVarStartIndex() + numLocals - 1] = -rhsConst;
746
0
  } else {
747
0
    // Reuse the existing local id.
748
0
    lhs[getLocalVarStartIndex() + loc] = -rhsConst;
749
0
  }
750
0
}
751
752
0
void SimpleAffineExprFlattener::visitCeilDivExpr(AffineBinaryOpExpr expr) {
753
0
  visitDivExpr(expr, /*isCeil=*/true);
754
0
}
755
0
/// Flattens a floordiv node by forwarding it to the shared division handler
/// with round-down (floor) semantics selected.
void SimpleAffineExprFlattener::visitFloorDivExpr(AffineBinaryOpExpr expr) {
  const bool roundUp = false;
  visitDivExpr(expr, roundUp);
}
758
759
0
/// Pushes the flattened form of a dimension identifier onto the operand
/// stack: a row of getNumCols() zeros with a single 1 in that dimension's
/// column.
void SimpleAffineExprFlattener::visitDimExpr(AffineDimExpr expr) {
  auto dimPos = expr.getPosition();
  operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
  assert(dimPos < numDims && "Inconsistent number of dims");
  operandExprStack.back()[getDimStartIndex() + dimPos] = 1;
}
765
766
0
/// Pushes the flattened form of a symbol identifier onto the operand stack:
/// a row of getNumCols() zeros with a single 1 in that symbol's column.
void SimpleAffineExprFlattener::visitSymbolExpr(AffineSymbolExpr expr) {
  auto symPos = expr.getPosition();
  operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
  assert(symPos < numSymbols && "inconsistent number of symbols");
  operandExprStack.back()[getSymbolStartIndex() + symPos] = 1;
}
772
773
0
/// Pushes the flattened form of a constant onto the operand stack. Every
/// coefficient is zero except the constant-term column, which holds the
/// constant's value.
void SimpleAffineExprFlattener::visitConstantExpr(AffineConstantExpr expr) {
  operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
  operandExprStack.back()[getConstantIndex()] = expr.getValue();
}
778
779
// t = expr floordiv c   <=> t = q, c * q <= expr <= c * q + c - 1
780
// A floordiv is thus flattened by introducing a new local variable q, and
781
// replacing that expression with 'q' while adding the constraints
782
// c * q <= expr <= c * q + c - 1 to localVarCst (done by
783
// FlatAffineConstraints::addLocalFloorDiv).
784
//
785
// A ceildiv is similarly flattened:
786
// t = expr ceildiv c   <=> t =  (expr + c - 1) floordiv c
787
void SimpleAffineExprFlattener::visitDivExpr(AffineBinaryOpExpr expr,
788
0
                                             bool isCeil) {
789
0
  assert(operandExprStack.size() >= 2);
790
0
  assert(expr.getRHS().isa<AffineConstantExpr>());
791
0
792
0
  // This is a pure affine expr; the RHS is a positive constant.
793
0
  int64_t rhsConst = operandExprStack.back()[getConstantIndex()];
794
0
  // TODO(bondhugula): handle division by zero at the same time the issue is
795
0
  // fixed at other places.
796
0
  assert(rhsConst > 0 && "RHS constant has to be positive");
797
0
  operandExprStack.pop_back();
798
0
  auto &lhs = operandExprStack.back();
799
0
800
0
  // Simplify the floordiv, ceildiv if possible by canceling out the greatest
801
0
  // common divisors of the numerator and denominator.
802
0
  uint64_t gcd = std::abs(rhsConst);
803
0
  for (unsigned i = 0, e = lhs.size(); i < e; i++)
804
0
    gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(lhs[i]));
805
0
  // Simplify the numerator and the denominator.
806
0
  if (gcd != 1) {
807
0
    for (unsigned i = 0, e = lhs.size(); i < e; i++)
808
0
      lhs[i] = lhs[i] / static_cast<int64_t>(gcd);
809
0
  }
810
0
  int64_t divisor = rhsConst / static_cast<int64_t>(gcd);
811
0
  // If the divisor becomes 1, the updated LHS is the result. (The
812
0
  // divisor can't be negative since rhsConst is positive).
813
0
  if (divisor == 1)
814
0
    return;
815
0
816
0
  // If the divisor cannot be simplified to one, we will have to retain
817
0
  // the ceil/floor expr (simplified up until here). Add an existential
818
0
  // quantifier to express its result, i.e., expr1 div expr2 is replaced
819
0
  // by a new identifier, q.
820
0
  MLIRContext *context = expr.getContext();
821
0
  auto a =
822
0
      getAffineExprFromFlatForm(lhs, numDims, numSymbols, localExprs, context);
823
0
  auto b = getAffineConstantExpr(divisor, context);
824
0
825
0
  int loc;
826
0
  auto divExpr = isCeil ? a.ceilDiv(b) : a.floorDiv(b);
827
0
  if ((loc = findLocalId(divExpr)) == -1) {
828
0
    if (!isCeil) {
829
0
      SmallVector<int64_t, 8> dividend(lhs);
830
0
      addLocalFloorDivId(dividend, divisor, divExpr);
831
0
    } else {
832
0
      // lhs ceildiv c <=>  (lhs + c - 1) floordiv c
833
0
      SmallVector<int64_t, 8> dividend(lhs);
834
0
      dividend.back() += divisor - 1;
835
0
      addLocalFloorDivId(dividend, divisor, divExpr);
836
0
    }
837
0
  }
838
0
  // Set the expression on stack to the local var introduced to capture the
839
0
  // result of the division (floor or ceil).
840
0
  std::fill(lhs.begin(), lhs.end(), 0);
841
0
  if (loc == -1)
842
0
    lhs[getLocalVarStartIndex() + numLocals - 1] = 1;
843
0
  else
844
0
    lhs[getLocalVarStartIndex() + loc] = 1;
845
0
}
846
847
// Add a local identifier (needed to flatten a mod, floordiv, ceildiv expr).
848
// The local identifier added is always a floordiv of a pure add/mul affine
849
// function of other identifiers, coefficients of which are specified in
850
// dividend and with respect to a positive constant divisor. localExpr is the
851
// simplified tree expression (AffineExpr) corresponding to the quantifier.
852
void SimpleAffineExprFlattener::addLocalFloorDivId(ArrayRef<int64_t> dividend,
853
                                                   int64_t divisor,
854
                                                   AffineExpr localExpr) {
855
  assert(divisor > 0 && "positive constant divisor expected");
856
  for (auto &subExpr : operandExprStack)
857
    subExpr.insert(subExpr.begin() + getLocalVarStartIndex() + numLocals, 0);
858
  localExprs.push_back(localExpr);
859
  numLocals++;
860
  // dividend and divisor are not used here; an override of this method uses it.
861
}
862
863
0
/// Returns the index of `localExpr` within the local expressions introduced
/// so far (localExprs), or -1 if it has not been introduced.
int SimpleAffineExprFlattener::findLocalId(AffineExpr localExpr) {
  auto match = llvm::find(localExprs, localExpr);
  return match == localExprs.end() ? -1 : match - localExprs.begin();
}
869
870
/// Simplify the affine expression by flattening it and reconstructing it.
871
AffineExpr mlir::simplifyAffineExpr(AffineExpr expr, unsigned numDims,
872
0
                                    unsigned numSymbols) {
873
0
  // TODO(bondhugula): only pure affine for now. The simplification here can
874
0
  // be extended to semi-affine maps in the future.
875
0
  if (!expr.isPureAffine())
876
0
    return expr;
877
0
878
0
  SimpleAffineExprFlattener flattener(numDims, numSymbols);
879
0
  flattener.walkPostOrder(expr);
880
0
  ArrayRef<int64_t> flattenedExpr = flattener.operandExprStack.back();
881
0
  auto simplifiedExpr =
882
0
      getAffineExprFromFlatForm(flattenedExpr, numDims, numSymbols,
883
0
                                flattener.localExprs, expr.getContext());
884
0
  flattener.operandExprStack.pop_back();
885
0
  assert(flattener.operandExprStack.empty());
886
0
887
0
  return simplifiedExpr;
888
0
}