Coverage Report

Created: 2020-06-26 05:44

/home/arjun/llvm-project/mlir/lib/Analysis/AffineStructures.cpp
Line
Count
Source (jump to first uncovered line)
1
//===- AffineStructures.cpp - MLIR Affine Structures Class-----------------===//
2
//
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
// See https://llvm.org/LICENSE.txt for license information.
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
//
7
//===----------------------------------------------------------------------===//
8
//
9
// Structures for affine/polyhedral analysis of affine dialect ops.
10
//
11
//===----------------------------------------------------------------------===//
12
13
#include "mlir/Analysis/AffineStructures.h"
14
#include "mlir/Analysis/Presburger/Simplex.h"
15
#include "mlir/Dialect/Affine/IR/AffineOps.h"
16
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
17
#include "mlir/Dialect/StandardOps/IR/Ops.h"
18
#include "mlir/IR/AffineExprVisitor.h"
19
#include "mlir/IR/IntegerSet.h"
20
#include "mlir/Support/LLVM.h"
21
#include "mlir/Support/MathExtras.h"
22
#include "llvm/ADT/SmallPtrSet.h"
23
#include "llvm/Support/Debug.h"
24
#include "llvm/Support/raw_ostream.h"
25
26
#define DEBUG_TYPE "affine-structures"
27
28
using namespace mlir;
29
using llvm::SmallDenseMap;
30
using llvm::SmallDenseSet;
31
32
namespace {
33
34
// See comments for SimpleAffineExprFlattener.
35
// An AffineExprFlattener extends a SimpleAffineExprFlattener by recording
36
// constraint information associated with mod's, floordiv's, and ceildiv's
37
// in FlatAffineConstraints 'localVarCst'.
38
struct AffineExprFlattener : public SimpleAffineExprFlattener {
39
public:
40
  // Constraints connecting newly introduced local variables (for mod's and
41
  // div's) to existing (dimensional and symbolic) ones. These are always
42
  // inequalities.
43
  FlatAffineConstraints localVarCst;
44
45
  AffineExprFlattener(unsigned nDims, unsigned nSymbols, MLIRContext *ctx)
46
0
      : SimpleAffineExprFlattener(nDims, nSymbols) {
47
0
    localVarCst.reset(nDims, nSymbols, /*numLocals=*/0);
48
0
  }
49
50
private:
51
  // Add a local identifier (needed to flatten a mod, floordiv, ceildiv expr).
52
  // The local identifier added is always a floordiv of a pure add/mul affine
53
  // function of other identifiers, coefficients of which are specified in
54
  // `dividend' and with respect to the positive constant `divisor'. localExpr
55
  // is the simplified tree expression (AffineExpr) corresponding to the
56
  // quantifier.
57
  void addLocalFloorDivId(ArrayRef<int64_t> dividend, int64_t divisor,
58
0
                          AffineExpr localExpr) override {
59
0
    SimpleAffineExprFlattener::addLocalFloorDivId(dividend, divisor, localExpr);
60
0
    // Update localVarCst.
61
0
    localVarCst.addLocalFloorDiv(dividend, divisor);
62
0
  }
63
};
64
65
} // end anonymous namespace
66
67
// Flattens the expressions in map. Returns failure if 'expr' was unable to be
68
// flattened (i.e., semi-affine expressions not handled yet).
69
static LogicalResult
70
getFlattenedAffineExprs(ArrayRef<AffineExpr> exprs, unsigned numDims,
71
                        unsigned numSymbols,
72
                        std::vector<SmallVector<int64_t, 8>> *flattenedExprs,
73
0
                        FlatAffineConstraints *localVarCst) {
74
0
  if (exprs.empty()) {
75
0
    localVarCst->reset(numDims, numSymbols);
76
0
    return success();
77
0
  }
78
0
79
0
  AffineExprFlattener flattener(numDims, numSymbols, exprs[0].getContext());
80
0
  // Use the same flattener to simplify each expression successively. This way
81
0
  // local identifiers / expressions are shared.
82
0
  for (auto expr : exprs) {
83
0
    if (!expr.isPureAffine())
84
0
      return failure();
85
0
86
0
    flattener.walkPostOrder(expr);
87
0
  }
88
0
89
0
  assert(flattener.operandExprStack.size() == exprs.size());
90
0
  flattenedExprs->clear();
91
0
  flattenedExprs->assign(flattener.operandExprStack.begin(),
92
0
                         flattener.operandExprStack.end());
93
0
94
0
  if (localVarCst)
95
0
    localVarCst->clearAndCopyFrom(flattener.localVarCst);
96
0
97
0
  return success();
98
0
}
99
100
// Flattens 'expr' into 'flattenedExpr'. Returns failure if 'expr' was unable to
101
// be flattened (semi-affine expressions not handled yet).
102
LogicalResult
103
mlir::getFlattenedAffineExpr(AffineExpr expr, unsigned numDims,
104
                             unsigned numSymbols,
105
                             SmallVectorImpl<int64_t> *flattenedExpr,
106
0
                             FlatAffineConstraints *localVarCst) {
107
0
  std::vector<SmallVector<int64_t, 8>> flattenedExprs;
108
0
  LogicalResult ret = ::getFlattenedAffineExprs({expr}, numDims, numSymbols,
109
0
                                                &flattenedExprs, localVarCst);
110
0
  *flattenedExpr = flattenedExprs[0];
111
0
  return ret;
112
0
}
113
114
/// Flattens the expressions in map. Returns failure if 'expr' was unable to be
115
/// flattened (i.e., semi-affine expressions not handled yet).
116
LogicalResult mlir::getFlattenedAffineExprs(
117
    AffineMap map, std::vector<SmallVector<int64_t, 8>> *flattenedExprs,
118
0
    FlatAffineConstraints *localVarCst) {
119
0
  if (map.getNumResults() == 0) {
120
0
    localVarCst->reset(map.getNumDims(), map.getNumSymbols());
121
0
    return success();
122
0
  }
123
0
  return ::getFlattenedAffineExprs(map.getResults(), map.getNumDims(),
124
0
                                   map.getNumSymbols(), flattenedExprs,
125
0
                                   localVarCst);
126
0
}
127
128
/// Flattens the constraint expressions of 'set' into 'flattenedExprs'. Returns
/// failure if a constraint was unable to be flattened (semi-affine expressions
/// not handled yet). 'localVarCst', when non-null, receives the constraints on
/// the local identifiers introduced during flattening.
LogicalResult mlir::getFlattenedAffineExprs(
    IntegerSet set, std::vector<SmallVector<int64_t, 8>> *flattenedExprs,
    FlatAffineConstraints *localVarCst) {
  if (set.getNumConstraints() == 0) {
    // Nothing to flatten. The helper below treats 'localVarCst' as optional,
    // so guard the dereference here as well.
    if (localVarCst)
      localVarCst->reset(set.getNumDims(), set.getNumSymbols());
    return success();
  }
  return ::getFlattenedAffineExprs(set.getConstraints(), set.getNumDims(),
                                   set.getNumSymbols(), flattenedExprs,
                                   localVarCst);
}
139
140
//===----------------------------------------------------------------------===//
141
// FlatAffineConstraints.
142
//===----------------------------------------------------------------------===//
143
144
// Copy constructor.
145
FlatAffineConstraints::FlatAffineConstraints(
146
0
    const FlatAffineConstraints &other) {
147
0
  numReservedCols = other.numReservedCols;
148
0
  numDims = other.getNumDimIds();
149
0
  numSymbols = other.getNumSymbolIds();
150
0
  numIds = other.getNumIds();
151
0
152
0
  auto otherIds = other.getIds();
153
0
  ids.reserve(numReservedCols);
154
0
  ids.append(otherIds.begin(), otherIds.end());
155
0
156
0
  unsigned numReservedEqualities = other.getNumReservedEqualities();
157
0
  unsigned numReservedInequalities = other.getNumReservedInequalities();
158
0
159
0
  equalities.reserve(numReservedEqualities * numReservedCols);
160
0
  inequalities.reserve(numReservedInequalities * numReservedCols);
161
0
162
0
  for (unsigned r = 0, e = other.getNumInequalities(); r < e; r++) {
163
0
    addInequality(other.getInequality(r));
164
0
  }
165
0
  for (unsigned r = 0, e = other.getNumEqualities(); r < e; r++) {
166
0
    addEquality(other.getEquality(r));
167
0
  }
168
0
}
169
170
// Clones this object.
171
0
std::unique_ptr<FlatAffineConstraints> FlatAffineConstraints::clone() const {
172
0
  return std::make_unique<FlatAffineConstraints>(*this);
173
0
}
174
175
// Construct from an IntegerSet.
176
FlatAffineConstraints::FlatAffineConstraints(IntegerSet set)
177
    : numReservedCols(set.getNumInputs() + 1),
178
      numIds(set.getNumDims() + set.getNumSymbols()), numDims(set.getNumDims()),
179
0
      numSymbols(set.getNumSymbols()) {
180
0
  equalities.reserve(set.getNumEqualities() * numReservedCols);
181
0
  inequalities.reserve(set.getNumInequalities() * numReservedCols);
182
0
  ids.resize(numIds, None);
183
0
184
0
  // Flatten expressions and add them to the constraint system.
185
0
  std::vector<SmallVector<int64_t, 8>> flatExprs;
186
0
  FlatAffineConstraints localVarCst;
187
0
  if (failed(getFlattenedAffineExprs(set, &flatExprs, &localVarCst))) {
188
0
    assert(false && "flattening unimplemented for semi-affine integer sets");
189
0
    return;
190
0
  }
191
0
  assert(flatExprs.size() == set.getNumConstraints());
192
0
  for (unsigned l = 0, e = localVarCst.getNumLocalIds(); l < e; l++) {
193
0
    addLocalId(getNumLocalIds());
194
0
  }
195
0
196
0
  for (unsigned i = 0, e = flatExprs.size(); i < e; ++i) {
197
0
    const auto &flatExpr = flatExprs[i];
198
0
    assert(flatExpr.size() == getNumCols());
199
0
    if (set.getEqFlags()[i]) {
200
0
      addEquality(flatExpr);
201
0
    } else {
202
0
      addInequality(flatExpr);
203
0
    }
204
0
  }
205
0
  // Add the other constraints involving local id's from flattening.
206
0
  append(localVarCst);
207
0
}
208
209
void FlatAffineConstraints::reset(unsigned numReservedInequalities,
210
                                  unsigned numReservedEqualities,
211
                                  unsigned newNumReservedCols,
212
                                  unsigned newNumDims, unsigned newNumSymbols,
213
                                  unsigned newNumLocals,
214
0
                                  ArrayRef<Value> idArgs) {
215
0
  assert(newNumReservedCols >= newNumDims + newNumSymbols + newNumLocals + 1 &&
216
0
         "minimum 1 column");
217
0
  numReservedCols = newNumReservedCols;
218
0
  numDims = newNumDims;
219
0
  numSymbols = newNumSymbols;
220
0
  numIds = numDims + numSymbols + newNumLocals;
221
0
  assert(idArgs.empty() || idArgs.size() == numIds);
222
0
223
0
  clearConstraints();
224
0
  if (numReservedEqualities >= 1)
225
0
    equalities.reserve(newNumReservedCols * numReservedEqualities);
226
0
  if (numReservedInequalities >= 1)
227
0
    inequalities.reserve(newNumReservedCols * numReservedInequalities);
228
0
  if (idArgs.empty()) {
229
0
    ids.resize(numIds, None);
230
0
  } else {
231
0
    ids.assign(idArgs.begin(), idArgs.end());
232
0
  }
233
0
}
234
235
void FlatAffineConstraints::reset(unsigned newNumDims, unsigned newNumSymbols,
236
                                  unsigned newNumLocals,
237
0
                                  ArrayRef<Value> idArgs) {
238
0
  reset(0, 0, newNumDims + newNumSymbols + newNumLocals + 1, newNumDims,
239
0
        newNumSymbols, newNumLocals, idArgs);
240
0
}
241
242
0
void FlatAffineConstraints::append(const FlatAffineConstraints &other) {
243
0
  assert(other.getNumCols() == getNumCols());
244
0
  assert(other.getNumDimIds() == getNumDimIds());
245
0
  assert(other.getNumSymbolIds() == getNumSymbolIds());
246
0
247
0
  inequalities.reserve(inequalities.size() +
248
0
                       other.getNumInequalities() * numReservedCols);
249
0
  equalities.reserve(equalities.size() +
250
0
                     other.getNumEqualities() * numReservedCols);
251
0
252
0
  for (unsigned r = 0, e = other.getNumInequalities(); r < e; r++) {
253
0
    addInequality(other.getInequality(r));
254
0
  }
255
0
  for (unsigned r = 0, e = other.getNumEqualities(); r < e; r++) {
256
0
    addEquality(other.getEquality(r));
257
0
  }
258
0
}
259
260
0
void FlatAffineConstraints::addLocalId(unsigned pos) {
261
0
  addId(IdKind::Local, pos);
262
0
}
263
264
0
void FlatAffineConstraints::addDimId(unsigned pos, Value id) {
265
0
  addId(IdKind::Dimension, pos, id);
266
0
}
267
268
0
void FlatAffineConstraints::addSymbolId(unsigned pos, Value id) {
269
0
  addId(IdKind::Symbol, pos, id);
270
0
}
271
272
/// Adds a dimensional identifier. The added column is initialized to
273
/// zero.
274
0
void FlatAffineConstraints::addId(IdKind kind, unsigned pos, Value id) {
275
0
  if (kind == IdKind::Dimension)
276
0
    assert(pos <= getNumDimIds());
277
0
  else if (kind == IdKind::Symbol)
278
0
    assert(pos <= getNumSymbolIds());
279
0
  else
280
0
    assert(pos <= getNumLocalIds());
281
0
282
0
  unsigned oldNumReservedCols = numReservedCols;
283
0
284
0
  // Check if a resize is necessary.
285
0
  if (getNumCols() + 1 > numReservedCols) {
286
0
    equalities.resize(getNumEqualities() * (getNumCols() + 1));
287
0
    inequalities.resize(getNumInequalities() * (getNumCols() + 1));
288
0
    numReservedCols++;
289
0
  }
290
0
291
0
  int absolutePos;
292
0
293
0
  if (kind == IdKind::Dimension) {
294
0
    absolutePos = pos;
295
0
    numDims++;
296
0
  } else if (kind == IdKind::Symbol) {
297
0
    absolutePos = pos + getNumDimIds();
298
0
    numSymbols++;
299
0
  } else {
300
0
    absolutePos = pos + getNumDimIds() + getNumSymbolIds();
301
0
  }
302
0
  numIds++;
303
0
304
0
  // Note that getNumCols() now will already return the new size, which will be
305
0
  // at least one.
306
0
  int numInequalities = static_cast<int>(getNumInequalities());
307
0
  int numEqualities = static_cast<int>(getNumEqualities());
308
0
  int numCols = static_cast<int>(getNumCols());
309
0
  for (int r = numInequalities - 1; r >= 0; r--) {
310
0
    for (int c = numCols - 2; c >= 0; c--) {
311
0
      if (c < absolutePos)
312
0
        atIneq(r, c) = inequalities[r * oldNumReservedCols + c];
313
0
      else
314
0
        atIneq(r, c + 1) = inequalities[r * oldNumReservedCols + c];
315
0
    }
316
0
    atIneq(r, absolutePos) = 0;
317
0
  }
318
0
319
0
  for (int r = numEqualities - 1; r >= 0; r--) {
320
0
    for (int c = numCols - 2; c >= 0; c--) {
321
0
      // All values in column absolutePositions < absolutePos have the same
322
0
      // coordinates in the 2-d view of the coefficient buffer.
323
0
      if (c < absolutePos)
324
0
        atEq(r, c) = equalities[r * oldNumReservedCols + c];
325
0
      else
326
0
        // Those at absolutePosition >= absolutePos, get a shifted
327
0
        // absolutePosition.
328
0
        atEq(r, c + 1) = equalities[r * oldNumReservedCols + c];
329
0
    }
330
0
    // Initialize added dimension to zero.
331
0
    atEq(r, absolutePos) = 0;
332
0
  }
333
0
334
0
  // If an 'id' is provided, insert it; otherwise use None.
335
0
  if (id)
336
0
    ids.insert(ids.begin() + absolutePos, id);
337
0
  else
338
0
    ids.insert(ids.begin() + absolutePos, None);
339
0
  assert(ids.size() == getNumIds());
340
0
}
341
342
/// Checks if two constraint systems are in the same space, i.e., if they are
343
/// associated with the same set of identifiers, appearing in the same order.
344
static bool areIdsAligned(const FlatAffineConstraints &A,
345
0
                          const FlatAffineConstraints &B) {
346
0
  return A.getNumDimIds() == B.getNumDimIds() &&
347
0
         A.getNumSymbolIds() == B.getNumSymbolIds() &&
348
0
         A.getNumIds() == B.getNumIds() && A.getIds().equals(B.getIds());
349
0
}
350
351
/// Calls areIdsAligned to check if two constraint systems have the same set
352
/// of identifiers in the same order.
353
bool FlatAffineConstraints::areIdsAlignedWithOther(
354
0
    const FlatAffineConstraints &other) {
355
0
  return areIdsAligned(*this, other);
356
0
}
357
358
/// Checks if the SSA values associated with `cst''s identifiers are unique.
359
static bool LLVM_ATTRIBUTE_UNUSED
360
0
areIdsUnique(const FlatAffineConstraints &cst) {
361
0
  SmallPtrSet<Value, 8> uniqueIds;
362
0
  for (auto id : cst.getIds()) {
363
0
    if (id.hasValue() && !uniqueIds.insert(id.getValue()).second)
364
0
      return false;
365
0
  }
366
0
  return true;
367
0
}
368
369
// Swap the posA^th identifier with the posB^th identifier.
370
0
static void swapId(FlatAffineConstraints *A, unsigned posA, unsigned posB) {
371
0
  assert(posA < A->getNumIds() && "invalid position A");
372
0
  assert(posB < A->getNumIds() && "invalid position B");
373
0
374
0
  if (posA == posB)
375
0
    return;
376
0
377
0
  for (unsigned r = 0, e = A->getNumInequalities(); r < e; r++) {
378
0
    std::swap(A->atIneq(r, posA), A->atIneq(r, posB));
379
0
  }
380
0
  for (unsigned r = 0, e = A->getNumEqualities(); r < e; r++) {
381
0
    std::swap(A->atEq(r, posA), A->atEq(r, posB));
382
0
  }
383
0
  std::swap(A->getId(posA), A->getId(posB));
384
0
}
385
386
/// Merge and align the identifiers of A and B starting at 'offset', so that
387
/// both constraint systems get the union of the contained identifiers that is
388
/// dimension-wise and symbol-wise unique; both constraint systems are updated
389
/// so that they have the union of all identifiers, with A's original
390
/// identifiers appearing first followed by any of B's identifiers that didn't
391
/// appear in A. Local identifiers of each system are by design separate/local
392
/// and are placed one after other (A's followed by B's).
393
//  Eg: Input: A has ((%i %j) [%M %N]) and B has (%k, %j) [%P, %N, %M])
394
//      Output: both A, B have (%i, %j, %k) [%M, %N, %P]
395
//
396
static void mergeAndAlignIds(unsigned offset, FlatAffineConstraints *A,
397
0
                             FlatAffineConstraints *B) {
398
0
  assert(offset <= A->getNumDimIds() && offset <= B->getNumDimIds());
399
0
  // A merge/align isn't meaningful if a cst's ids aren't distinct.
400
0
  assert(areIdsUnique(*A) && "A's id values aren't unique");
401
0
  assert(areIdsUnique(*B) && "B's id values aren't unique");
402
0
403
0
  assert(std::all_of(A->getIds().begin() + offset,
404
0
                     A->getIds().begin() + A->getNumDimAndSymbolIds(),
405
0
                     [](Optional<Value> id) { return id.hasValue(); }));
406
0
407
0
  assert(std::all_of(B->getIds().begin() + offset,
408
0
                     B->getIds().begin() + B->getNumDimAndSymbolIds(),
409
0
                     [](Optional<Value> id) { return id.hasValue(); }));
410
0
411
0
  // Place local id's of A after local id's of B.
412
0
  for (unsigned l = 0, e = A->getNumLocalIds(); l < e; l++) {
413
0
    B->addLocalId(0);
414
0
  }
415
0
  for (unsigned t = 0, e = B->getNumLocalIds() - A->getNumLocalIds(); t < e;
416
0
       t++) {
417
0
    A->addLocalId(A->getNumLocalIds());
418
0
  }
419
0
420
0
  SmallVector<Value, 4> aDimValues, aSymValues;
421
0
  A->getIdValues(offset, A->getNumDimIds(), &aDimValues);
422
0
  A->getIdValues(A->getNumDimIds(), A->getNumDimAndSymbolIds(), &aSymValues);
423
0
  {
424
0
    // Merge dims from A into B.
425
0
    unsigned d = offset;
426
0
    for (auto aDimValue : aDimValues) {
427
0
      unsigned loc;
428
0
      if (B->findId(aDimValue, &loc)) {
429
0
        assert(loc >= offset && "A's dim appears in B's aligned range");
430
0
        assert(loc < B->getNumDimIds() &&
431
0
               "A's dim appears in B's non-dim position");
432
0
        swapId(B, d, loc);
433
0
      } else {
434
0
        B->addDimId(d);
435
0
        B->setIdValue(d, aDimValue);
436
0
      }
437
0
      d++;
438
0
    }
439
0
440
0
    // Dimensions that are in B, but not in A, are added at the end.
441
0
    for (unsigned t = A->getNumDimIds(), e = B->getNumDimIds(); t < e; t++) {
442
0
      A->addDimId(A->getNumDimIds());
443
0
      A->setIdValue(A->getNumDimIds() - 1, B->getIdValue(t));
444
0
    }
445
0
  }
446
0
  {
447
0
    // Merge symbols: merge A's symbols into B first.
448
0
    unsigned s = B->getNumDimIds();
449
0
    for (auto aSymValue : aSymValues) {
450
0
      unsigned loc;
451
0
      if (B->findId(aSymValue, &loc)) {
452
0
        assert(loc >= B->getNumDimIds() && loc < B->getNumDimAndSymbolIds() &&
453
0
               "A's symbol appears in B's non-symbol position");
454
0
        swapId(B, s, loc);
455
0
      } else {
456
0
        B->addSymbolId(s - B->getNumDimIds());
457
0
        B->setIdValue(s, aSymValue);
458
0
      }
459
0
      s++;
460
0
    }
461
0
    // Symbols that are in B, but not in A, are added at the end.
462
0
    for (unsigned t = A->getNumDimAndSymbolIds(),
463
0
                  e = B->getNumDimAndSymbolIds();
464
0
         t < e; t++) {
465
0
      A->addSymbolId(A->getNumSymbolIds());
466
0
      A->setIdValue(A->getNumDimAndSymbolIds() - 1, B->getIdValue(t));
467
0
    }
468
0
  }
469
0
  assert(areIdsAligned(*A, *B) && "IDs expected to be aligned");
470
0
}
471
472
// Call 'mergeAndAlignIds' to align constraint systems of 'this' and 'other'.
473
void FlatAffineConstraints::mergeAndAlignIdsWithOther(
474
0
    unsigned offset, FlatAffineConstraints *other) {
475
0
  mergeAndAlignIds(offset, this, other);
476
0
}
477
478
// This routine may add additional local variables if the flattened expression
479
// corresponding to the map has such variables due to mod's, ceildiv's, and
480
// floordiv's in it.
481
0
LogicalResult FlatAffineConstraints::composeMap(const AffineValueMap *vMap) {
482
0
  std::vector<SmallVector<int64_t, 8>> flatExprs;
483
0
  FlatAffineConstraints localCst;
484
0
  if (failed(getFlattenedAffineExprs(vMap->getAffineMap(), &flatExprs,
485
0
                                     &localCst))) {
486
0
    LLVM_DEBUG(llvm::dbgs()
487
0
               << "composition unimplemented for semi-affine maps\n");
488
0
    return failure();
489
0
  }
490
0
  assert(flatExprs.size() == vMap->getNumResults());
491
0
492
0
  // Add localCst information.
493
0
  if (localCst.getNumLocalIds() > 0) {
494
0
    localCst.setIdValues(0, /*end=*/localCst.getNumDimAndSymbolIds(),
495
0
                         /*values=*/vMap->getOperands());
496
0
    // Align localCst and this.
497
0
    mergeAndAlignIds(/*offset=*/0, &localCst, this);
498
0
    // Finally, append localCst to this constraint set.
499
0
    append(localCst);
500
0
  }
501
0
502
0
  // Add dimensions corresponding to the map's results.
503
0
  for (unsigned t = 0, e = vMap->getNumResults(); t < e; t++) {
504
0
    // TODO: Consider using a batched version to add a range of IDs.
505
0
    addDimId(0);
506
0
  }
507
0
508
0
  // We add one equality for each result connecting the result dim of the map to
509
0
  // the other identifiers.
510
0
  // For eg: if the expression is 16*i0 + i1, and this is the r^th
511
0
  // iteration/result of the value map, we are adding the equality:
512
0
  //  d_r - 16*i0 - i1 = 0. Hence, when flattening say (i0 + 1, i0 + 8*i2), we
513
0
  //  add two equalities overall: d_0 - i0 - 1 == 0, d1 - i0 - 8*i2 == 0.
514
0
  for (unsigned r = 0, e = flatExprs.size(); r < e; r++) {
515
0
    const auto &flatExpr = flatExprs[r];
516
0
    assert(flatExpr.size() >= vMap->getNumOperands() + 1);
517
0
518
0
    // eqToAdd is the equality corresponding to the flattened affine expression.
519
0
    SmallVector<int64_t, 8> eqToAdd(getNumCols(), 0);
520
0
    // Set the coefficient for this result to one.
521
0
    eqToAdd[r] = 1;
522
0
523
0
    // Dims and symbols.
524
0
    for (unsigned i = 0, e = vMap->getNumOperands(); i < e; i++) {
525
0
      unsigned loc;
526
0
      bool ret = findId(vMap->getOperand(i), &loc);
527
0
      assert(ret && "value map's id can't be found");
528
0
      (void)ret;
529
0
      // Negate 'eq[r]' since the newly added dimension will be set to this one.
530
0
      eqToAdd[loc] = -flatExpr[i];
531
0
    }
532
0
    // Local vars common to eq and localCst are at the beginning.
533
0
    unsigned j = getNumDimIds() + getNumSymbolIds();
534
0
    unsigned end = flatExpr.size() - 1;
535
0
    for (unsigned i = vMap->getNumOperands(); i < end; i++, j++) {
536
0
      eqToAdd[j] = -flatExpr[i];
537
0
    }
538
0
539
0
    // Constant term.
540
0
    eqToAdd[getNumCols() - 1] = -flatExpr[flatExpr.size() - 1];
541
0
542
0
    // Add the equality connecting the result of the map to this constraint set.
543
0
    addEquality(eqToAdd);
544
0
  }
545
0
546
0
  return success();
547
0
}
548
549
// Similar to composeMap except that no Value's need be associated with the
550
// constraint system nor are they looked at -- since the dimensions and
551
// symbols of 'other' are expected to correspond 1:1 to 'this' system. It
552
// is thus not convenient to share code with composeMap.
553
0
LogicalResult FlatAffineConstraints::composeMatchingMap(AffineMap other) {
554
0
  assert(other.getNumDims() == getNumDimIds() && "dim mismatch");
555
0
  assert(other.getNumSymbols() == getNumSymbolIds() && "symbol mismatch");
556
0
557
0
  std::vector<SmallVector<int64_t, 8>> flatExprs;
558
0
  FlatAffineConstraints localCst;
559
0
  if (failed(getFlattenedAffineExprs(other, &flatExprs, &localCst))) {
560
0
    LLVM_DEBUG(llvm::dbgs()
561
0
               << "composition unimplemented for semi-affine maps\n");
562
0
    return failure();
563
0
  }
564
0
  assert(flatExprs.size() == other.getNumResults());
565
0
566
0
  // Add localCst information.
567
0
  if (localCst.getNumLocalIds() > 0) {
568
0
    // Place local id's of A after local id's of B.
569
0
    for (unsigned l = 0, e = localCst.getNumLocalIds(); l < e; l++) {
570
0
      addLocalId(0);
571
0
    }
572
0
    // Finally, append localCst to this constraint set.
573
0
    append(localCst);
574
0
  }
575
0
576
0
  // Add dimensions corresponding to the map's results.
577
0
  for (unsigned t = 0, e = other.getNumResults(); t < e; t++) {
578
0
    addDimId(0);
579
0
  }
580
0
581
0
  // We add one equality for each result connecting the result dim of the map to
582
0
  // the other identifiers.
583
0
  // For eg: if the expression is 16*i0 + i1, and this is the r^th
584
0
  // iteration/result of the value map, we are adding the equality:
585
0
  //  d_r - 16*i0 - i1 = 0. Hence, when flattening say (i0 + 1, i0 + 8*i2), we
586
0
  //  add two equalities overall: d_0 - i0 - 1 == 0, d1 - i0 - 8*i2 == 0.
587
0
  for (unsigned r = 0, e = flatExprs.size(); r < e; r++) {
588
0
    const auto &flatExpr = flatExprs[r];
589
0
    assert(flatExpr.size() >= other.getNumInputs() + 1);
590
0
591
0
    // eqToAdd is the equality corresponding to the flattened affine expression.
592
0
    SmallVector<int64_t, 8> eqToAdd(getNumCols(), 0);
593
0
    // Set the coefficient for this result to one.
594
0
    eqToAdd[r] = 1;
595
0
596
0
    // Dims and symbols.
597
0
    for (unsigned i = 0, f = other.getNumInputs(); i < f; i++) {
598
0
      // Negate 'eq[r]' since the newly added dimension will be set to this one.
599
0
      eqToAdd[e + i] = -flatExpr[i];
600
0
    }
601
0
    // Local vars common to eq and localCst are at the beginning.
602
0
    unsigned j = getNumDimIds() + getNumSymbolIds();
603
0
    unsigned end = flatExpr.size() - 1;
604
0
    for (unsigned i = other.getNumInputs(); i < end; i++, j++) {
605
0
      eqToAdd[j] = -flatExpr[i];
606
0
    }
607
0
608
0
    // Constant term.
609
0
    eqToAdd[getNumCols() - 1] = -flatExpr[flatExpr.size() - 1];
610
0
611
0
    // Add the equality connecting the result of the map to this constraint set.
612
0
    addEquality(eqToAdd);
613
0
  }
614
0
615
0
  return success();
616
0
}
617
618
// Turn a dimension into a symbol.
619
0
static void turnDimIntoSymbol(FlatAffineConstraints *cst, Value id) {
620
0
  unsigned pos;
621
0
  if (cst->findId(id, &pos) && pos < cst->getNumDimIds()) {
622
0
    swapId(cst, pos, cst->getNumDimIds() - 1);
623
0
    cst->setDimSymbolSeparation(cst->getNumSymbolIds() + 1);
624
0
  }
625
0
}
626
627
// Turn a symbol into a dimension.
628
0
static void turnSymbolIntoDim(FlatAffineConstraints *cst, Value id) {
629
0
  unsigned pos;
630
0
  if (cst->findId(id, &pos) && pos >= cst->getNumDimIds() &&
631
0
      pos < cst->getNumDimAndSymbolIds()) {
632
0
    swapId(cst, pos, cst->getNumDimIds());
633
0
    cst->setDimSymbolSeparation(cst->getNumSymbolIds() - 1);
634
0
  }
635
0
}
636
637
// Changes all symbol identifiers which are loop IVs to dim identifiers.
638
0
void FlatAffineConstraints::convertLoopIVSymbolsToDims() {
639
0
  // Gather all symbols which are loop IVs.
640
0
  SmallVector<Value, 4> loopIVs;
641
0
  for (unsigned i = getNumDimIds(), e = getNumDimAndSymbolIds(); i < e; i++) {
642
0
    if (ids[i].hasValue() && getForInductionVarOwner(ids[i].getValue()))
643
0
      loopIVs.push_back(ids[i].getValue());
644
0
  }
645
0
  // Turn each symbol in 'loopIVs' into a dim identifier.
646
0
  for (auto iv : loopIVs) {
647
0
    turnSymbolIntoDim(this, iv);
648
0
  }
649
0
}
650
651
0
// Adds 'id' to this system unless it is already present. A loop induction
// variable becomes a new trailing dimension, along with its loop's domain
// constraints; any other value becomes a new trailing symbol, fixed to its
// constant when it is defined by a ConstantIndexOp.
void FlatAffineConstraints::addInductionVarOrTerminalSymbol(Value id) {
  if (containsId(id))
    return;

  // Caller is expected to fully compose map/operands if necessary.
  assert((isTopLevelValue(id) || isForInductionVar(id)) &&
         "non-terminal symbol / loop IV expected");

  auto loop = getForInductionVarOwner(id);
  if (!loop) {
    // Add top level symbol.
    addSymbolId(getNumSymbolIds(), id);
    // Check if the symbol is a constant.
    if (auto constOp = id.getDefiningOp<ConstantIndexOp>())
      setIdToConstant(id, constOp.getValue());
    return;
  }

  // Outer loop IVs could be used in forOp's bounds.
  addDimId(getNumDimIds(), id);
  if (failed(this->addAffineForOpDomain(loop))) {
    LLVM_DEBUG(
        loop.emitWarning("failed to add domain info to constraint system"));
  }
}
672
673
0
// Adds constraints describing the iteration domain of 'forOp' on its
// induction variable, which must already be an identifier of this system:
// a stride constraint for non-unit steps (only when the lower bound is
// constant), then the lower bound and the upper bound. Returns failure if the
// induction variable is not found or a non-constant bound cannot be added.
LogicalResult FlatAffineConstraints::addAffineForOpDomain(AffineForOp forOp) {
  unsigned pos;
  // Pre-condition for this method.
  if (!findId(forOp.getInductionVar(), &pos)) {
    assert(false && "Value not found");
    return failure();
  }

  const int64_t step = forOp.getStep();
  if (step != 1 && !forOp.hasConstantLowerBound()) {
    // No stride constraint is added in this case.
    forOp.emitWarning("domain conservatively approximated");
  } else if (step != 1) {
    // Add constraints for the stride.
    // (iv - lb) % step = 0 can be written as:
    // (iv - lb) - step * q = 0 where q = (iv - lb) / step.
    // Add local variable 'q' and add the above equality.
    const int64_t lb = forOp.getConstantLowerBound();

    // The first constraint is q = (iv - lb) floordiv step
    SmallVector<int64_t, 8> dividend(getNumCols(), 0);
    dividend[pos] = 1;
    dividend.back() -= lb;
    addLocalFloorDiv(dividend, step);

    // Second constraint: (iv - lb) - step * q = 0. The row is sized after the
    // local was added, so it includes the column for 'q'.
    SmallVector<int64_t, 8> strideEq(getNumCols(), 0);
    strideEq[pos] = 1;
    strideEq.back() -= lb;
    // For the local var just added above.
    strideEq[getNumCols() - 2] = -step;
    addEquality(strideEq);
  }

  // Lower bound.
  if (!forOp.hasConstantLowerBound()) {
    // Non-constant lower bound case.
    if (failed(addLowerOrUpperBound(pos, forOp.getLowerBoundMap(),
                                    forOp.getLowerBoundOperands(),
                                    /*eq=*/false, /*lower=*/true)))
      return failure();
  } else {
    addConstantLowerBound(pos, forOp.getConstantLowerBound());
  }

  // Upper bound.
  if (!forOp.hasConstantUpperBound()) {
    // Non-constant upper bound case.
    return addLowerOrUpperBound(pos, forOp.getUpperBoundMap(),
                                forOp.getUpperBoundOperands(),
                                /*eq=*/false, /*lower=*/false);
  }
  // The bound added is the loop's constant upper bound minus one.
  addConstantUpperBound(pos, forOp.getConstantUpperBound() - 1);
  return success();
}
725
726
// Searches for a constraint with a non-zero coefficient at 'colIdx' in
727
// equality (isEq=true) or inequality (isEq=false) constraints.
728
// Returns true and sets row found in search in 'rowIdx'.
729
// Returns false otherwise.
730
static bool findConstraintWithNonZeroAt(const FlatAffineConstraints &cst,
731
                                        unsigned colIdx, bool isEq,
732
0
                                        unsigned *rowIdx) {
733
0
  assert(colIdx < cst.getNumCols() && "position out of bounds");
734
0
  auto at = [&](unsigned rowIdx) -> int64_t {
735
0
    return isEq ? cst.atEq(rowIdx, colIdx) : cst.atIneq(rowIdx, colIdx);
736
0
  };
737
0
  unsigned e = isEq ? cst.getNumEqualities() : cst.getNumInequalities();
738
0
  for (*rowIdx = 0; *rowIdx < e; ++(*rowIdx)) {
739
0
    if (at(*rowIdx) != 0) {
740
0
      return true;
741
0
    }
742
0
  }
743
0
  return false;
744
0
}
745
746
// Normalizes the coefficient values across all columns in 'rowIDx' by their
747
// GCD in equality or inequality constraints as specified by 'isEq'.
748
template <bool isEq>
749
static void normalizeConstraintByGCD(FlatAffineConstraints *constraints,
750
0
                                     unsigned rowIdx) {
751
0
  auto at = [&](unsigned colIdx) -> int64_t {
752
0
    return isEq ? constraints->atEq(rowIdx, colIdx)
753
0
                : constraints->atIneq(rowIdx, colIdx);
754
0
  };
Unexecuted instantiation: AffineStructures.cpp:_ZZL24normalizeConstraintByGCDILb1EEvPN4mlir21FlatAffineConstraintsEjENKUljE_clEj
Unexecuted instantiation: AffineStructures.cpp:_ZZL24normalizeConstraintByGCDILb0EEvPN4mlir21FlatAffineConstraintsEjENKUljE_clEj
755
0
  uint64_t gcd = std::abs(at(0));
756
0
  for (unsigned j = 1, e = constraints->getNumCols(); j < e; ++j) {
757
0
    gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(at(j)));
758
0
  }
759
0
  if (gcd > 0 && gcd != 1) {
760
0
    for (unsigned j = 0, e = constraints->getNumCols(); j < e; ++j) {
761
0
      int64_t v = at(j) / static_cast<int64_t>(gcd);
762
0
      isEq ? constraints->atEq(rowIdx, j) = v
763
0
           : constraints->atIneq(rowIdx, j) = v;
764
0
    }
765
0
  }
766
0
}
Unexecuted instantiation: AffineStructures.cpp:_ZL24normalizeConstraintByGCDILb1EEvPN4mlir21FlatAffineConstraintsEj
Unexecuted instantiation: AffineStructures.cpp:_ZL24normalizeConstraintByGCDILb0EEvPN4mlir21FlatAffineConstraintsEj
767
768
0
void FlatAffineConstraints::normalizeConstraintsByGCD() {
769
0
  for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) {
770
0
    normalizeConstraintByGCD</*isEq=*/true>(this, i);
771
0
  }
772
0
  for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) {
773
0
    normalizeConstraintByGCD</*isEq=*/false>(this, i);
774
0
  }
775
0
}
776
777
6
bool FlatAffineConstraints::hasConsistentState() const {
778
6
  if (inequalities.size() != getNumInequalities() * numReservedCols)
779
0
    return false;
780
6
  if (equalities.size() != getNumEqualities() * numReservedCols)
781
0
    return false;
782
6
  if (ids.size() != getNumIds())
783
0
    return false;
784
6
785
6
  // Catches errors where numDims, numSymbols, numIds aren't consistent.
786
6
  if (numDims > numIds || numSymbols > numIds || numDims + numSymbols > numIds)
787
0
    return false;
788
6
789
6
  return true;
790
6
}
791
792
/// Checks all rows of equality/inequality constraints for trivial
793
/// contradictions (for example: 1 == 0, 0 >= 1), which may have surfaced
794
/// after elimination. Returns 'true' if an invalid constraint is found;
795
/// 'false' otherwise.
796
0
bool FlatAffineConstraints::hasInvalidConstraint() const {
797
0
  assert(hasConsistentState());
798
0
  auto check = [&](bool isEq) -> bool {
799
0
    unsigned numCols = getNumCols();
800
0
    unsigned numRows = isEq ? getNumEqualities() : getNumInequalities();
801
0
    for (unsigned i = 0, e = numRows; i < e; ++i) {
802
0
      unsigned j;
803
0
      for (j = 0; j < numCols - 1; ++j) {
804
0
        int64_t v = isEq ? atEq(i, j) : atIneq(i, j);
805
0
        // Skip rows with non-zero variable coefficients.
806
0
        if (v != 0)
807
0
          break;
808
0
      }
809
0
      if (j < numCols - 1) {
810
0
        continue;
811
0
      }
812
0
      // Check validity of constant term at 'numCols - 1' w.r.t 'isEq'.
813
0
      // Example invalid constraints include: '1 == 0' or '-1 >= 0'
814
0
      int64_t v = isEq ? atEq(i, numCols - 1) : atIneq(i, numCols - 1);
815
0
      if ((isEq && v != 0) || (!isEq && v < 0)) {
816
0
        return true;
817
0
      }
818
0
    }
819
0
    return false;
820
0
  };
821
0
  if (check(/*isEq=*/true))
822
0
    return true;
823
0
  return check(/*isEq=*/false);
824
0
}
825
826
// Eliminate identifier from constraint at 'rowIdx' based on coefficient at
827
// pivotRow, pivotCol. Columns in range [elimColStart, pivotCol) will not be
828
// updated as they have already been eliminated.
829
static void eliminateFromConstraint(FlatAffineConstraints *constraints,
830
                                    unsigned rowIdx, unsigned pivotRow,
831
                                    unsigned pivotCol, unsigned elimColStart,
832
0
                                    bool isEq) {
833
0
  // Skip if equality 'rowIdx' if same as 'pivotRow'.
834
0
  if (isEq && rowIdx == pivotRow)
835
0
    return;
836
0
  auto at = [&](unsigned i, unsigned j) -> int64_t {
837
0
    return isEq ? constraints->atEq(i, j) : constraints->atIneq(i, j);
838
0
  };
839
0
  int64_t leadCoeff = at(rowIdx, pivotCol);
840
0
  // Skip if leading coefficient at 'rowIdx' is already zero.
841
0
  if (leadCoeff == 0)
842
0
    return;
843
0
  int64_t pivotCoeff = constraints->atEq(pivotRow, pivotCol);
844
0
  int64_t sign = (leadCoeff * pivotCoeff > 0) ? -1 : 1;
845
0
  int64_t lcm = mlir::lcm(pivotCoeff, leadCoeff);
846
0
  int64_t pivotMultiplier = sign * (lcm / std::abs(pivotCoeff));
847
0
  int64_t rowMultiplier = lcm / std::abs(leadCoeff);
848
0
849
0
  unsigned numCols = constraints->getNumCols();
850
0
  for (unsigned j = 0; j < numCols; ++j) {
851
0
    // Skip updating column 'j' if it was just eliminated.
852
0
    if (j >= elimColStart && j < pivotCol)
853
0
      continue;
854
0
    int64_t v = pivotMultiplier * constraints->atEq(pivotRow, j) +
855
0
                rowMultiplier * at(rowIdx, j);
856
0
    isEq ? constraints->atEq(rowIdx, j) = v
857
0
         : constraints->atIneq(rowIdx, j) = v;
858
0
  }
859
0
}
860
861
// Remove coefficients in column range [colStart, colLimit) in place.
862
// This removes in data in the specified column range, and copies any
863
// remaining valid data into place.
864
static void shiftColumnsToLeft(FlatAffineConstraints *constraints,
865
                               unsigned colStart, unsigned colLimit,
866
0
                               bool isEq) {
867
0
  assert(colLimit <= constraints->getNumIds());
868
0
  if (colLimit <= colStart)
869
0
    return;
870
0
871
0
  unsigned numCols = constraints->getNumCols();
872
0
  unsigned numRows = isEq ? constraints->getNumEqualities()
873
0
                          : constraints->getNumInequalities();
874
0
  unsigned numToEliminate = colLimit - colStart;
875
0
  for (unsigned r = 0, e = numRows; r < e; ++r) {
876
0
    for (unsigned c = colLimit; c < numCols; ++c) {
877
0
      if (isEq) {
878
0
        constraints->atEq(r, c - numToEliminate) = constraints->atEq(r, c);
879
0
      } else {
880
0
        constraints->atIneq(r, c - numToEliminate) = constraints->atIneq(r, c);
881
0
      }
882
0
    }
883
0
  }
884
0
}
885
886
// Removes identifiers in column range [idStart, idLimit), and copies any
887
// remaining valid data into place, and updates member variables.
888
0
void FlatAffineConstraints::removeIdRange(unsigned idStart, unsigned idLimit) {
889
0
  assert(idLimit < getNumCols() && "invalid id limit");
890
0
891
0
  if (idStart >= idLimit)
892
0
    return;
893
0
894
0
  // We are going to be removing one or more identifiers from the range.
895
0
  assert(idStart < numIds && "invalid idStart position");
896
0
897
0
  // TODO(andydavis) Make 'removeIdRange' a lambda called from here.
898
0
  // Remove eliminated identifiers from equalities.
899
0
  shiftColumnsToLeft(this, idStart, idLimit, /*isEq=*/true);
900
0
901
0
  // Remove eliminated identifiers from inequalities.
902
0
  shiftColumnsToLeft(this, idStart, idLimit, /*isEq=*/false);
903
0
904
0
  // Update members numDims, numSymbols and numIds.
905
0
  unsigned numDimsEliminated = 0;
906
0
  unsigned numLocalsEliminated = 0;
907
0
  unsigned numColsEliminated = idLimit - idStart;
908
0
  if (idStart < numDims) {
909
0
    numDimsEliminated = std::min(numDims, idLimit) - idStart;
910
0
  }
911
0
  // Check how many local id's were removed. Note that our identifier order is
912
0
  // [dims, symbols, locals]. Local id start at position numDims + numSymbols.
913
0
  if (idLimit > numDims + numSymbols) {
914
0
    numLocalsEliminated = std::min(
915
0
        idLimit - std::max(idStart, numDims + numSymbols), getNumLocalIds());
916
0
  }
917
0
  unsigned numSymbolsEliminated =
918
0
      numColsEliminated - numDimsEliminated - numLocalsEliminated;
919
0
920
0
  numDims -= numDimsEliminated;
921
0
  numSymbols -= numSymbolsEliminated;
922
0
  numIds = numIds - numColsEliminated;
923
0
924
0
  ids.erase(ids.begin() + idStart, ids.begin() + idLimit);
925
0
926
0
  // No resize necessary. numReservedCols remains the same.
927
0
}
928
929
/// Returns the position of the identifier that has the minimum <number of lower
930
/// bounds> times <number of upper bounds> from the specified range of
931
/// identifiers [start, end). It is often best to eliminate in the increasing
932
/// order of these counts when doing Fourier-Motzkin elimination since FM adds
933
/// that many new constraints.
934
static unsigned getBestIdToEliminate(const FlatAffineConstraints &cst,
935
0
                                     unsigned start, unsigned end) {
936
0
  assert(start < cst.getNumIds() && end < cst.getNumIds() + 1);
937
0
938
0
  auto getProductOfNumLowerUpperBounds = [&](unsigned pos) {
939
0
    unsigned numLb = 0;
940
0
    unsigned numUb = 0;
941
0
    for (unsigned r = 0, e = cst.getNumInequalities(); r < e; r++) {
942
0
      if (cst.atIneq(r, pos) > 0) {
943
0
        ++numLb;
944
0
      } else if (cst.atIneq(r, pos) < 0) {
945
0
        ++numUb;
946
0
      }
947
0
    }
948
0
    return numLb * numUb;
949
0
  };
950
0
951
0
  unsigned minLoc = start;
952
0
  unsigned min = getProductOfNumLowerUpperBounds(start);
953
0
  for (unsigned c = start + 1; c < end; c++) {
954
0
    unsigned numLbUbProduct = getProductOfNumLowerUpperBounds(c);
955
0
    if (numLbUbProduct < min) {
956
0
      min = numLbUbProduct;
957
0
      minLoc = c;
958
0
    }
959
0
  }
960
0
  return minLoc;
961
0
}
962
963
// Checks for emptiness of the set by eliminating identifiers successively and
964
// using the GCD test (on all equality constraints) and checking for trivially
965
// invalid constraints. Returns 'true' if the constraint system is found to be
966
// empty; false otherwise.
967
0
bool FlatAffineConstraints::isEmpty() const {
968
0
  if (isEmptyByGCDTest() || hasInvalidConstraint())
969
0
    return true;
970
0
971
0
  // First, eliminate as many identifiers as possible using Gaussian
972
0
  // elimination.
973
0
  FlatAffineConstraints tmpCst(*this);
974
0
  unsigned currentPos = 0;
975
0
  while (currentPos < tmpCst.getNumIds()) {
976
0
    tmpCst.gaussianEliminateIds(currentPos, tmpCst.getNumIds());
977
0
    ++currentPos;
978
0
    // We check emptiness through trivial checks after eliminating each ID to
979
0
    // detect emptiness early. Since the checks isEmptyByGCDTest() and
980
0
    // hasInvalidConstraint() are linear time and single sweep on the constraint
981
0
    // buffer, this appears reasonable - but can optimize in the future.
982
0
    if (tmpCst.hasInvalidConstraint() || tmpCst.isEmptyByGCDTest())
983
0
      return true;
984
0
  }
985
0
986
0
  // Eliminate the remaining using FM.
987
0
  for (unsigned i = 0, e = tmpCst.getNumIds(); i < e; i++) {
988
0
    tmpCst.FourierMotzkinEliminate(
989
0
        getBestIdToEliminate(tmpCst, 0, tmpCst.getNumIds()));
990
0
    // Check for a constraint explosion. This rarely happens in practice, but
991
0
    // this check exists as a safeguard against improperly constructed
992
0
    // constraint systems or artificially created arbitrarily complex systems
993
0
    // that aren't the intended use case for FlatAffineConstraints. This is
994
0
    // needed since FM has a worst case exponential complexity in theory.
995
0
    if (tmpCst.getNumConstraints() >= kExplosionFactor * getNumIds()) {
996
0
      LLVM_DEBUG(llvm::dbgs() << "FM constraint explosion detected\n");
997
0
      return false;
998
0
    }
999
0
1000
0
    // FM wouldn't have modified the equalities in any way. So no need to again
1001
0
    // run GCD test. Check for trivial invalid constraints.
1002
0
    if (tmpCst.hasInvalidConstraint())
1003
0
      return true;
1004
0
  }
1005
0
  return false;
1006
0
}
1007
1008
// Runs the GCD test on all equality constraints. Returns 'true' if this test
1009
// fails on any equality. Returns 'false' otherwise.
1010
// This test can be used to disprove the existence of a solution. If it returns
1011
// true, no integer solution to the equality constraints can exist.
1012
//
1013
// GCD test definition:
1014
//
1015
// The equality constraint:
1016
//
1017
//  c_1*x_1 + c_2*x_2 + ... + c_n*x_n = c_0
1018
//
1019
// has an integer solution iff:
1020
//
1021
//  GCD of c_1, c_2, ..., c_n divides c_0.
1022
//
1023
6
bool FlatAffineConstraints::isEmptyByGCDTest() const {
1024
6
  assert(hasConsistentState());
1025
6
  unsigned numCols = getNumCols();
1026
13
  for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) {
1027
7
    uint64_t gcd = std::abs(atEq(i, 0));
1028
23
    for (unsigned j = 1; j < numCols - 1; ++j) {
1029
16
      gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(atEq(i, j)));
1030
16
    }
1031
7
    int64_t v = std::abs(atEq(i, numCols - 1));
1032
7
    if (gcd > 0 && (v % gcd != 0)) {
1033
0
      return true;
1034
0
    }
1035
7
  }
1036
6
  return false;
1037
6
}
1038
1039
// First, try the GCD test heuristic.
1040
//
1041
// If that doesn't find the set empty, check if the set is unbounded. If it is,
1042
// we cannot use the GBR algorithm and we conservatively return false.
1043
//
1044
// If the set is bounded, we use the complete emptiness check for this case
1045
// provided by Simplex::findIntegerSample(), which gives a definitive answer.
1046
6
bool FlatAffineConstraints::isIntegerEmpty() const {
1047
6
  if (isEmptyByGCDTest())
1048
0
    return true;
1049
6
1050
6
  Simplex simplex(*this);
1051
6
  if (simplex.isUnbounded())
1052
1
    return false;
1053
5
  return !simplex.findIntegerSample().hasValue();
1054
5
}
1055
1056
// Builds a Simplex from this constraint system and returns the result of its
// integer sample search.
Optional<SmallVector<int64_t, 8>>
FlatAffineConstraints::findIntegerSample() const {
  Simplex simplex(*this);
  return simplex.findIntegerSample();
}
1060
1061
/// Tightens inequalities given that we are dealing with integer spaces. This is
1062
/// analogous to the GCD test but applied to inequalities. The constant term can
1063
/// be reduced to the preceding multiple of the GCD of the coefficients, i.e.,
1064
///  64*i - 100 >= 0  =>  64*i - 128 >= 0 (since 'i' is an integer). This is a
1065
/// fast method - linear in the number of coefficients.
1066
// Example on how this affects practical cases: consider the scenario:
1067
// 64*i >= 100, j = 64*i; without a tightening, elimination of i would yield
1068
// j >= 100 instead of the tighter (exact) j >= 128.
1069
0
void FlatAffineConstraints::GCDTightenInequalities() {
1070
0
  unsigned numCols = getNumCols();
1071
0
  for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) {
1072
0
    uint64_t gcd = std::abs(atIneq(i, 0));
1073
0
    for (unsigned j = 1; j < numCols - 1; ++j) {
1074
0
      gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(atIneq(i, j)));
1075
0
    }
1076
0
    if (gcd > 0 && gcd != 1) {
1077
0
      int64_t gcdI = static_cast<int64_t>(gcd);
1078
0
      // Tighten the constant term and normalize the constraint by the GCD.
1079
0
      atIneq(i, numCols - 1) = mlir::floorDiv(atIneq(i, numCols - 1), gcdI);
1080
0
      for (unsigned j = 0, e = numCols - 1; j < e; ++j)
1081
0
        atIneq(i, j) /= gcdI;
1082
0
    }
1083
0
  }
1084
0
}
1085
1086
// Eliminates all identifier variables in column range [posStart, posLimit).
1087
// Returns the number of variables eliminated.
1088
unsigned FlatAffineConstraints::gaussianEliminateIds(unsigned posStart,
1089
0
                                                     unsigned posLimit) {
1090
0
  // Return if identifier positions to eliminate are out of range.
1091
0
  assert(posLimit <= numIds);
1092
0
  assert(hasConsistentState());
1093
0
1094
0
  if (posStart >= posLimit)
1095
0
    return 0;
1096
0
1097
0
  GCDTightenInequalities();
1098
0
1099
0
  unsigned pivotCol = 0;
1100
0
  for (pivotCol = posStart; pivotCol < posLimit; ++pivotCol) {
1101
0
    // Find a row which has a non-zero coefficient in column 'j'.
1102
0
    unsigned pivotRow;
1103
0
    if (!findConstraintWithNonZeroAt(*this, pivotCol, /*isEq=*/true,
1104
0
                                     &pivotRow)) {
1105
0
      // No pivot row in equalities with non-zero at 'pivotCol'.
1106
0
      if (!findConstraintWithNonZeroAt(*this, pivotCol, /*isEq=*/false,
1107
0
                                       &pivotRow)) {
1108
0
        // If inequalities are also non-zero in 'pivotCol', it can be
1109
0
        // eliminated.
1110
0
        continue;
1111
0
      }
1112
0
      break;
1113
0
    }
1114
0
1115
0
    // Eliminate identifier at 'pivotCol' from each equality row.
1116
0
    for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) {
1117
0
      eliminateFromConstraint(this, i, pivotRow, pivotCol, posStart,
1118
0
                              /*isEq=*/true);
1119
0
      normalizeConstraintByGCD</*isEq=*/true>(this, i);
1120
0
    }
1121
0
1122
0
    // Eliminate identifier at 'pivotCol' from each inequality row.
1123
0
    for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) {
1124
0
      eliminateFromConstraint(this, i, pivotRow, pivotCol, posStart,
1125
0
                              /*isEq=*/false);
1126
0
      normalizeConstraintByGCD</*isEq=*/false>(this, i);
1127
0
    }
1128
0
    removeEquality(pivotRow);
1129
0
    GCDTightenInequalities();
1130
0
  }
1131
0
  // Update position limit based on number eliminated.
1132
0
  posLimit = pivotCol;
1133
0
  // Remove eliminated columns from all constraints.
1134
0
  removeIdRange(posStart, posLimit);
1135
0
  return posLimit - posStart;
1136
0
}
1137
1138
// Detect the identifier at 'pos' (say id_r) as modulo of another identifier
1139
// (say id_n) w.r.t a constant. When this happens, another identifier (say id_q)
1140
// could be detected as the floordiv of n. For eg:
1141
// id_n - 4*id_q - id_r = 0, 0 <= id_r <= 3    <=>
1142
//                          id_r = id_n mod 4, id_q = id_n floordiv 4.
1143
// lbConst and ubConst are the constant lower and upper bounds for 'pos' -
1144
// pre-detected at the caller.
1145
static bool detectAsMod(const FlatAffineConstraints &cst, unsigned pos,
1146
                        int64_t lbConst, int64_t ubConst,
1147
0
                        SmallVectorImpl<AffineExpr> *memo) {
1148
0
  assert(pos < cst.getNumIds() && "invalid position");
1149
0
1150
0
  // Check if 0 <= id_r <= divisor - 1 and if id_r is equal to
1151
0
  // id_n - divisor * id_q. If these are true, then id_n becomes the dividend
1152
0
  // and id_q the quotient when dividing id_n by the divisor.
1153
0
1154
0
  if (lbConst != 0 || ubConst < 1)
1155
0
    return false;
1156
0
1157
0
  int64_t divisor = ubConst + 1;
1158
0
1159
0
  // Now check for: id_r =  id_n - divisor * id_q. As an example, we
1160
0
  // are looking r = d - 4q, i.e., either r - d + 4q = 0 or -r + d - 4q = 0.
1161
0
  unsigned seenQuotient = 0, seenDividend = 0;
1162
0
  int quotientPos = -1, dividendPos = -1;
1163
0
  for (unsigned r = 0, e = cst.getNumEqualities(); r < e; r++) {
1164
0
    // id_n should have coeff 1 or -1.
1165
0
    if (std::abs(cst.atEq(r, pos)) != 1)
1166
0
      continue;
1167
0
    // constant term should be 0.
1168
0
    if (cst.atEq(r, cst.getNumCols() - 1) != 0)
1169
0
      continue;
1170
0
    unsigned c, f;
1171
0
    int quotientSign = 1, dividendSign = 1;
1172
0
    for (c = 0, f = cst.getNumDimAndSymbolIds(); c < f; c++) {
1173
0
      if (c == pos)
1174
0
        continue;
1175
0
      // The coefficient of the quotient should be +/-divisor.
1176
0
      // TODO(bondhugula): could be extended to detect an affine function for
1177
0
      // the quotient (i.e., the coeff could be a non-zero multiple of divisor).
1178
0
      int64_t v = cst.atEq(r, c) * cst.atEq(r, pos);
1179
0
      if (v == divisor || v == -divisor) {
1180
0
        seenQuotient++;
1181
0
        quotientPos = c;
1182
0
        quotientSign = v > 0 ? 1 : -1;
1183
0
      }
1184
0
      // The coefficient of the dividend should be +/-1.
1185
0
      // TODO(bondhugula): could be extended to detect an affine function of
1186
0
      // the other identifiers as the dividend.
1187
0
      else if (v == -1 || v == 1) {
1188
0
        seenDividend++;
1189
0
        dividendPos = c;
1190
0
        dividendSign = v < 0 ? 1 : -1;
1191
0
      } else if (cst.atEq(r, c) != 0) {
1192
0
        // Cannot be inferred as a mod since the constraint has a coefficient
1193
0
        // for an identifier that's neither a unit nor the divisor (see TODOs
1194
0
        // above).
1195
0
        break;
1196
0
      }
1197
0
    }
1198
0
    if (c < f)
1199
0
      // Cannot be inferred as a mod since the constraint has a coefficient for
1200
0
      // an identifier that's neither a unit nor the divisor (see TODOs above).
1201
0
      continue;
1202
0
1203
0
    // We are looking for exactly one identifier as the dividend.
1204
0
    if (seenDividend == 1 && seenQuotient >= 1) {
1205
0
      if (!(*memo)[dividendPos])
1206
0
        return false;
1207
0
      // Successfully detected a mod.
1208
0
      (*memo)[pos] = (*memo)[dividendPos] % divisor * dividendSign;
1209
0
      auto ub = cst.getConstantUpperBound(dividendPos);
1210
0
      if (ub.hasValue() && ub.getValue() < divisor)
1211
0
        // The mod can be optimized away.
1212
0
        (*memo)[pos] = (*memo)[dividendPos] * dividendSign;
1213
0
      else
1214
0
        (*memo)[pos] = (*memo)[dividendPos] % divisor * dividendSign;
1215
0
1216
0
      if (seenQuotient == 1 && !(*memo)[quotientPos])
1217
0
        // Successfully detected a floordiv as well.
1218
0
        (*memo)[quotientPos] =
1219
0
            (*memo)[dividendPos].floorDiv(divisor) * quotientSign;
1220
0
      return true;
1221
0
    }
1222
0
  }
1223
0
  return false;
1224
0
}
1225
1226
/// Gather all lower and upper bounds of the identifier at `pos`, and
1227
/// optionally any equalities on it. In addition, the bounds are to be
1228
/// independent of identifiers in position range [`offset`, `offset` + `num`).
1229
void FlatAffineConstraints::getLowerAndUpperBoundIndices(
1230
    unsigned pos, SmallVectorImpl<unsigned> *lbIndices,
1231
    SmallVectorImpl<unsigned> *ubIndices, SmallVectorImpl<unsigned> *eqIndices,
1232
0
    unsigned offset, unsigned num) const {
1233
0
  assert(pos < getNumIds() && "invalid position");
1234
0
  assert(offset + num < getNumCols() && "invalid range");
1235
0
1236
0
  // Checks for a constraint that has a non-zero coeff for the identifiers in
1237
0
  // the position range [offset, offset + num) while ignoring `pos`.
1238
0
  auto containsConstraintDependentOnRange = [&](unsigned r, bool isEq) {
1239
0
    unsigned c, f;
1240
0
    auto cst = isEq ? getEquality(r) : getInequality(r);
1241
0
    for (c = offset, f = offset + num; c < f; ++c) {
1242
0
      if (c == pos)
1243
0
        continue;
1244
0
      if (cst[c] != 0)
1245
0
        break;
1246
0
    }
1247
0
    return c < f;
1248
0
  };
1249
0
1250
0
  // Gather all lower bounds and upper bounds of the variable. Since the
1251
0
  // canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint is a lower
1252
0
  // bound for x_i if c_i >= 1, and an upper bound if c_i <= -1.
1253
0
  for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
1254
0
    // The bounds are to be independent of [offset, offset + num) columns.
1255
0
    if (containsConstraintDependentOnRange(r, /*isEq=*/false))
1256
0
      continue;
1257
0
    if (atIneq(r, pos) >= 1) {
1258
0
      // Lower bound.
1259
0
      lbIndices->push_back(r);
1260
0
    } else if (atIneq(r, pos) <= -1) {
1261
0
      // Upper bound.
1262
0
      ubIndices->push_back(r);
1263
0
    }
1264
0
  }
1265
0
1266
0
  // An equality is both a lower and upper bound. Record any equalities
1267
0
  // involving the pos^th identifier.
1268
0
  if (!eqIndices)
1269
0
    return;
1270
0
1271
0
  for (unsigned r = 0, e = getNumEqualities(); r < e; r++) {
1272
0
    if (atEq(r, pos) == 0)
1273
0
      continue;
1274
0
    if (containsConstraintDependentOnRange(r, /*isEq=*/true))
1275
0
      continue;
1276
0
    eqIndices->push_back(r);
1277
0
  }
1278
0
}
1279
1280
/// Check if the pos^th identifier can be expressed as a floordiv of an affine
1281
/// function of other identifiers (where the divisor is a positive constant)
1282
/// given the initial set of expressions in `exprs`. If it can be, the
1283
/// corresponding position in `exprs` is set as the detected affine expr. For
1284
/// eg: 4q <= i + j <= 4q + 3   <=>   q = (i + j) floordiv 4. An equality can
1285
/// also yield a floordiv: eg.  4q = i + j <=> q = (i + j) floordiv 4. 32q + 28
1286
/// <= i <= 32q + 31 => q = i floordiv 32.
1287
static bool detectAsFloorDiv(const FlatAffineConstraints &cst, unsigned pos,
1288
                             MLIRContext *context,
1289
0
                             SmallVectorImpl<AffineExpr> &exprs) {
1290
0
  assert(pos < cst.getNumIds() && "invalid position");
1291
0
1292
0
  SmallVector<unsigned, 4> lbIndices, ubIndices;
1293
0
  cst.getLowerAndUpperBoundIndices(pos, &lbIndices, &ubIndices);
1294
0
1295
0
  // Check if any lower bound, upper bound pair is of the form:
1296
0
  // divisor * id >=  expr - (divisor - 1)    <-- Lower bound for 'id'
1297
0
  // divisor * id <=  expr                    <-- Upper bound for 'id'
1298
0
  // Then, 'id' is equivalent to 'expr floordiv divisor'.  (where divisor > 1).
1299
0
  //
1300
0
  // For example, if -32*k + 16*i + j >= 0
1301
0
  //                  32*k - 16*i - j + 31 >= 0   <=>
1302
0
  //             k = ( 16*i + j ) floordiv 32
1303
0
  unsigned seenDividends = 0;
1304
0
  for (auto ubPos : ubIndices) {
1305
0
    for (auto lbPos : lbIndices) {
1306
0
      // Check if the lower bound's constant term is divisor - 1. The
1307
0
      // 'divisor' here is cst.atIneq(lbPos, pos) and we already know that it's
1308
0
      // positive (since cst.Ineq(lbPos, ...) is a lower bound expr for 'pos'.
1309
0
      int64_t divisor = cst.atIneq(lbPos, pos);
1310
0
      int64_t lbConstTerm = cst.atIneq(lbPos, cst.getNumCols() - 1);
1311
0
      if (lbConstTerm != divisor - 1)
1312
0
        continue;
1313
0
      // Check if upper bound's constant term is 0.
1314
0
      if (cst.atIneq(ubPos, cst.getNumCols() - 1) != 0)
1315
0
        continue;
1316
0
      // For the remaining part, check if the lower bound expr's coeff's are
1317
0
      // negations of corresponding upper bound ones'.
1318
0
      unsigned c, f;
1319
0
      for (c = 0, f = cst.getNumCols() - 1; c < f; c++) {
1320
0
        if (cst.atIneq(lbPos, c) != -cst.atIneq(ubPos, c))
1321
0
          break;
1322
0
        if (c != pos && cst.atIneq(lbPos, c) != 0)
1323
0
          seenDividends++;
1324
0
      }
1325
0
      // Lb coeff's aren't negative of ub coeff's (for the non constant term
1326
0
      // part).
1327
0
      if (c < f)
1328
0
        continue;
1329
0
      if (seenDividends >= 1) {
1330
0
        // Construct the dividend expression.
1331
0
        auto dividendExpr = getAffineConstantExpr(0, context);
1332
0
        unsigned c, f;
1333
0
        for (c = 0, f = cst.getNumCols() - 1; c < f; c++) {
1334
0
          if (c == pos)
1335
0
            continue;
1336
0
          int64_t ubVal = cst.atIneq(ubPos, c);
1337
0
          if (ubVal == 0)
1338
0
            continue;
1339
0
          if (!exprs[c])
1340
0
            break;
1341
0
          dividendExpr = dividendExpr + ubVal * exprs[c];
1342
0
        }
1343
0
        // Expression can't be constructed as it depends on a yet unknown
1344
0
        // identifier.
1345
0
        // TODO(mlir-team): Visit/compute the identifiers in an order so that
1346
0
        // this doesn't happen. More complex but much more efficient.
1347
0
        if (c < f)
1348
0
          continue;
1349
0
        // Successfully detected the floordiv.
1350
0
        exprs[pos] = dividendExpr.floorDiv(divisor);
1351
0
        return true;
1352
0
      }
1353
0
    }
1354
0
  }
1355
0
  return false;
1356
0
}
1357
1358
// Fills an inequality row with the value 'val'.
1359
static inline void fillInequality(FlatAffineConstraints *cst, unsigned r,
1360
0
                                  int64_t val) {
1361
0
  for (unsigned c = 0, f = cst->getNumCols(); c < f; c++) {
1362
0
    cst->atIneq(r, c) = val;
1363
0
  }
1364
0
}
1365
1366
// Negates an inequality.
1367
0
static inline void negateInequality(FlatAffineConstraints *cst, unsigned r) {
1368
0
  for (unsigned c = 0, f = cst->getNumCols(); c < f; c++) {
1369
0
    cst->atIneq(r, c) = -cst->atIneq(r, c);
1370
0
  }
1371
0
}
1372
1373
// A more complex check to eliminate redundant inequalities. Uses FourierMotzkin
1374
// to check if a constraint is redundant.
1375
0
void FlatAffineConstraints::removeRedundantInequalities() {
1376
0
  SmallVector<bool, 32> redun(getNumInequalities(), false);
1377
0
  // To check if an inequality is redundant, we replace the inequality by its
1378
0
  // complement (for eg., i - 1 >= 0 by i <= 0), and check if the resulting
1379
0
  // system is empty. If it is, the inequality is redundant.
1380
0
  FlatAffineConstraints tmpCst(*this);
1381
0
  for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
1382
0
    // Change the inequality to its complement.
1383
0
    negateInequality(&tmpCst, r);
1384
0
    tmpCst.atIneq(r, tmpCst.getNumCols() - 1)--;
1385
0
    if (tmpCst.isEmpty()) {
1386
0
      redun[r] = true;
1387
0
      // Zero fill the redundant inequality.
1388
0
      fillInequality(this, r, /*val=*/0);
1389
0
      fillInequality(&tmpCst, r, /*val=*/0);
1390
0
    } else {
1391
0
      // Reverse the change (to avoid recreating tmpCst each time).
1392
0
      tmpCst.atIneq(r, tmpCst.getNumCols() - 1)++;
1393
0
      negateInequality(&tmpCst, r);
1394
0
    }
1395
0
  }
1396
0
1397
0
  // Scan to get rid of all rows marked redundant, in-place.
1398
0
  auto copyRow = [&](unsigned src, unsigned dest) {
1399
0
    if (src == dest)
1400
0
      return;
1401
0
    for (unsigned c = 0, e = getNumCols(); c < e; c++) {
1402
0
      atIneq(dest, c) = atIneq(src, c);
1403
0
    }
1404
0
  };
1405
0
  unsigned pos = 0;
1406
0
  for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
1407
0
    if (!redun[r])
1408
0
      copyRow(r, pos++);
1409
0
  }
1410
0
  inequalities.resize(numReservedCols * pos);
1411
0
}
1412
1413
std::pair<AffineMap, AffineMap> FlatAffineConstraints::getLowerAndUpperBound(
1414
    unsigned pos, unsigned offset, unsigned num, unsigned symStartPos,
1415
    ArrayRef<AffineExpr> localExprs, MLIRContext *context) const {
1416
  assert(pos + offset < getNumDimIds() && "invalid dim start pos");
1417
  assert(symStartPos >= (pos + offset) && "invalid sym start pos");
1418
  assert(getNumLocalIds() == localExprs.size() &&
1419
         "incorrect local exprs count");
1420
1421
  SmallVector<unsigned, 4> lbIndices, ubIndices, eqIndices;
1422
  getLowerAndUpperBoundIndices(pos + offset, &lbIndices, &ubIndices, &eqIndices,
1423
                               offset, num);
1424
1425
  /// Add to 'b' from 'a' in set [0, offset) U [offset + num, symbStartPos).
1426
0
  auto addCoeffs = [&](ArrayRef<int64_t> a, SmallVectorImpl<int64_t> &b) {
1427
0
    b.clear();
1428
0
    for (unsigned i = 0, e = a.size(); i < e; ++i) {
1429
0
      if (i < offset || i >= offset + num)
1430
0
        b.push_back(a[i]);
1431
0
    }
1432
0
  };
1433
1434
  SmallVector<int64_t, 8> lb, ub;
1435
  SmallVector<AffineExpr, 4> lbExprs;
1436
  unsigned dimCount = symStartPos - num;
1437
  unsigned symCount = getNumDimAndSymbolIds() - symStartPos;
1438
  lbExprs.reserve(lbIndices.size() + eqIndices.size());
1439
  // Lower bound expressions.
1440
  for (auto idx : lbIndices) {
1441
    auto ineq = getInequality(idx);
1442
    // Extract the lower bound (in terms of other coeff's + const), i.e., if
1443
    // i - j + 1 >= 0 is the constraint, 'pos' is for i the lower bound is j
1444
    // - 1.
1445
    addCoeffs(ineq, lb);
1446
    std::transform(lb.begin(), lb.end(), lb.begin(), std::negate<int64_t>());
1447
    auto expr =
1448
        getAffineExprFromFlatForm(lb, dimCount, symCount, localExprs, context);
1449
    // expr ceildiv divisor is (expr + divisor - 1) floordiv divisor
1450
    int64_t divisor = std::abs(ineq[pos + offset]);
1451
    expr = (expr + divisor - 1).floorDiv(divisor);
1452
    lbExprs.push_back(expr);
1453
  }
1454
1455
  SmallVector<AffineExpr, 4> ubExprs;
1456
  ubExprs.reserve(ubIndices.size() + eqIndices.size());
1457
  // Upper bound expressions.
1458
  for (auto idx : ubIndices) {
1459
    auto ineq = getInequality(idx);
1460
    // Extract the upper bound (in terms of other coeff's + const).
1461
    addCoeffs(ineq, ub);
1462
    auto expr =
1463
        getAffineExprFromFlatForm(ub, dimCount, symCount, localExprs, context);
1464
    expr = expr.floorDiv(std::abs(ineq[pos + offset]));
1465
    // Upper bound is exclusive.
1466
    ubExprs.push_back(expr + 1);
1467
  }
1468
1469
  // Equalities. It's both a lower and a upper bound.
1470
  SmallVector<int64_t, 4> b;
1471
  for (auto idx : eqIndices) {
1472
    auto eq = getEquality(idx);
1473
    addCoeffs(eq, b);
1474
    if (eq[pos + offset] > 0)
1475
      std::transform(b.begin(), b.end(), b.begin(), std::negate<int64_t>());
1476
1477
    // Extract the upper bound (in terms of other coeff's + const).
1478
    auto expr =
1479
        getAffineExprFromFlatForm(b, dimCount, symCount, localExprs, context);
1480
    expr = expr.floorDiv(std::abs(eq[pos + offset]));
1481
    // Upper bound is exclusive.
1482
    ubExprs.push_back(expr + 1);
1483
    // Lower bound.
1484
    expr =
1485
        getAffineExprFromFlatForm(b, dimCount, symCount, localExprs, context);
1486
    expr = expr.ceilDiv(std::abs(eq[pos + offset]));
1487
    lbExprs.push_back(expr);
1488
  }
1489
1490
  auto lbMap = AffineMap::get(dimCount, symCount, lbExprs, context);
1491
  auto ubMap = AffineMap::get(dimCount, symCount, ubExprs, context);
1492
1493
  return {lbMap, ubMap};
1494
}
1495
1496
/// Computes the lower and upper bounds of the first 'num' dimensional
1497
/// identifiers (starting at 'offset') as affine maps of the remaining
1498
/// identifiers (dimensional and symbolic identifiers). Local identifiers are
1499
/// themselves explicitly computed as affine functions of other identifiers in
1500
/// this process if needed.
1501
void FlatAffineConstraints::getSliceBounds(unsigned offset, unsigned num,
1502
                                           MLIRContext *context,
1503
                                           SmallVectorImpl<AffineMap> *lbMaps,
1504
0
                                           SmallVectorImpl<AffineMap> *ubMaps) {
1505
0
  assert(num < getNumDimIds() && "invalid range");
1506
0
1507
0
  // Basic simplification.
1508
0
  normalizeConstraintsByGCD();
1509
0
1510
0
  LLVM_DEBUG(llvm::dbgs() << "getSliceBounds for first " << num
1511
0
                          << " identifiers\n");
1512
0
  LLVM_DEBUG(dump());
1513
0
1514
0
  // Record computed/detected identifiers.
1515
0
  SmallVector<AffineExpr, 8> memo(getNumIds());
1516
0
  // Initialize dimensional and symbolic identifiers.
1517
0
  for (unsigned i = 0, e = getNumDimIds(); i < e; i++) {
1518
0
    if (i < offset)
1519
0
      memo[i] = getAffineDimExpr(i, context);
1520
0
    else if (i >= offset + num)
1521
0
      memo[i] = getAffineDimExpr(i - num, context);
1522
0
  }
1523
0
  for (unsigned i = getNumDimIds(), e = getNumDimAndSymbolIds(); i < e; i++)
1524
0
    memo[i] = getAffineSymbolExpr(i - getNumDimIds(), context);
1525
0
1526
0
  bool changed;
1527
0
  do {
1528
0
    changed = false;
1529
0
    // Identify yet unknown identifiers as constants or mod's / floordiv's of
1530
0
    // other identifiers if possible.
1531
0
    for (unsigned pos = 0; pos < getNumIds(); pos++) {
1532
0
      if (memo[pos])
1533
0
        continue;
1534
0
1535
0
      auto lbConst = getConstantLowerBound(pos);
1536
0
      auto ubConst = getConstantUpperBound(pos);
1537
0
      if (lbConst.hasValue() && ubConst.hasValue()) {
1538
0
        // Detect equality to a constant.
1539
0
        if (lbConst.getValue() == ubConst.getValue()) {
1540
0
          memo[pos] = getAffineConstantExpr(lbConst.getValue(), context);
1541
0
          changed = true;
1542
0
          continue;
1543
0
        }
1544
0
1545
0
        // Detect an identifier as modulo of another identifier w.r.t a
1546
0
        // constant.
1547
0
        if (detectAsMod(*this, pos, lbConst.getValue(), ubConst.getValue(),
1548
0
                        &memo)) {
1549
0
          changed = true;
1550
0
          continue;
1551
0
        }
1552
0
      }
1553
0
1554
0
      // Detect an identifier as a floordiv of an affine function of other
1555
0
      // identifiers (divisor is a positive constant).
1556
0
      if (detectAsFloorDiv(*this, pos, context, memo)) {
1557
0
        changed = true;
1558
0
        continue;
1559
0
      }
1560
0
1561
0
      // Detect an identifier as an expression of other identifiers.
1562
0
      unsigned idx;
1563
0
      if (!findConstraintWithNonZeroAt(*this, pos, /*isEq=*/true, &idx)) {
1564
0
        continue;
1565
0
      }
1566
0
1567
0
      // Build AffineExpr solving for identifier 'pos' in terms of all others.
1568
0
      auto expr = getAffineConstantExpr(0, context);
1569
0
      unsigned j, e;
1570
0
      for (j = 0, e = getNumIds(); j < e; ++j) {
1571
0
        if (j == pos)
1572
0
          continue;
1573
0
        int64_t c = atEq(idx, j);
1574
0
        if (c == 0)
1575
0
          continue;
1576
0
        // If any of the involved IDs hasn't been found yet, we can't proceed.
1577
0
        if (!memo[j])
1578
0
          break;
1579
0
        expr = expr + memo[j] * c;
1580
0
      }
1581
0
      if (j < e)
1582
0
        // Can't construct expression as it depends on a yet uncomputed
1583
0
        // identifier.
1584
0
        continue;
1585
0
1586
0
      // Add constant term to AffineExpr.
1587
0
      expr = expr + atEq(idx, getNumIds());
1588
0
      int64_t vPos = atEq(idx, pos);
1589
0
      assert(vPos != 0 && "expected non-zero here");
1590
0
      if (vPos > 0)
1591
0
        expr = (-expr).floorDiv(vPos);
1592
0
      else
1593
0
        // vPos < 0.
1594
0
        expr = expr.floorDiv(-vPos);
1595
0
      // Successfully constructed expression.
1596
0
      memo[pos] = expr;
1597
0
      changed = true;
1598
0
    }
1599
0
    // This loop is guaranteed to reach a fixed point - since once an
1600
0
    // identifier's explicit form is computed (in memo[pos]), it's not updated
1601
0
    // again.
1602
0
  } while (changed);
1603
0
1604
0
  // Set the lower and upper bound maps for all the identifiers that were
1605
0
  // computed as affine expressions of the rest as the "detected expr" and
1606
0
  // "detected expr + 1" respectively; set the undetected ones to null.
1607
0
  Optional<FlatAffineConstraints> tmpClone;
1608
0
  for (unsigned pos = 0; pos < num; pos++) {
1609
0
    unsigned numMapDims = getNumDimIds() - num;
1610
0
    unsigned numMapSymbols = getNumSymbolIds();
1611
0
    AffineExpr expr = memo[pos + offset];
1612
0
    if (expr)
1613
0
      expr = simplifyAffineExpr(expr, numMapDims, numMapSymbols);
1614
0
1615
0
    AffineMap &lbMap = (*lbMaps)[pos];
1616
0
    AffineMap &ubMap = (*ubMaps)[pos];
1617
0
1618
0
    if (expr) {
1619
0
      lbMap = AffineMap::get(numMapDims, numMapSymbols, expr);
1620
0
      ubMap = AffineMap::get(numMapDims, numMapSymbols, expr + 1);
1621
0
    } else {
1622
0
      // TODO(bondhugula): Whenever there are local identifiers in the
1623
0
      // dependence constraints, we'll conservatively over-approximate, since we
1624
0
      // don't always explicitly compute them above (in the while loop).
1625
0
      if (getNumLocalIds() == 0) {
1626
0
        // Work on a copy so that we don't update this constraint system.
1627
0
        if (!tmpClone) {
1628
0
          tmpClone.emplace(FlatAffineConstraints(*this));
1629
0
          // Removing redundant inequalities is necessary so that we don't get
1630
0
          // redundant loop bounds.
1631
0
          tmpClone->removeRedundantInequalities();
1632
0
        }
1633
0
        std::tie(lbMap, ubMap) = tmpClone->getLowerAndUpperBound(
1634
0
            pos, offset, num, getNumDimIds(), /*localExprs=*/{}, context);
1635
0
      }
1636
0
1637
0
      // If the above fails, we'll just use the constant lower bound and the
1638
0
      // constant upper bound (if they exist) as the slice bounds.
1639
0
      // TODO(b/126426796): being conservative for the moment in cases that
1640
0
      // lead to multiple bounds - until getConstDifference in LoopFusion.cpp is
1641
0
      // fixed (b/126426796).
1642
0
      if (!lbMap || lbMap.getNumResults() > 1) {
1643
0
        LLVM_DEBUG(llvm::dbgs()
1644
0
                   << "WARNING: Potentially over-approximating slice lb\n");
1645
0
        auto lbConst = getConstantLowerBound(pos + offset);
1646
0
        if (lbConst.hasValue()) {
1647
0
          lbMap = AffineMap::get(
1648
0
              numMapDims, numMapSymbols,
1649
0
              getAffineConstantExpr(lbConst.getValue(), context));
1650
0
        }
1651
0
      }
1652
0
      if (!ubMap || ubMap.getNumResults() > 1) {
1653
0
        LLVM_DEBUG(llvm::dbgs()
1654
0
                   << "WARNING: Potentially over-approximating slice ub\n");
1655
0
        auto ubConst = getConstantUpperBound(pos + offset);
1656
0
        if (ubConst.hasValue()) {
1657
0
          (ubMap) = AffineMap::get(
1658
0
              numMapDims, numMapSymbols,
1659
0
              getAffineConstantExpr(ubConst.getValue() + 1, context));
1660
0
        }
1661
0
      }
1662
0
    }
1663
0
    LLVM_DEBUG(llvm::dbgs()
1664
0
               << "lb map for pos = " << Twine(pos + offset) << ", expr: ");
1665
0
    LLVM_DEBUG(lbMap.dump(););
1666
0
    LLVM_DEBUG(llvm::dbgs()
1667
0
               << "ub map for pos = " << Twine(pos + offset) << ", expr: ");
1668
0
    LLVM_DEBUG(ubMap.dump(););
1669
0
  }
1670
0
}
1671
1672
/// Adds the results of 'boundMap' (applied to 'boundOperands') as bounds on
/// the identifier at position 'pos': lower bounds if 'lower' is set, exclusive
/// upper bounds otherwise. If 'eq' is set, the single result of the map is
/// added as an equality instead. The map is first fully composed with its
/// operands; operands are added to this system through
/// addInductionVarOrTerminalSymbol, and local identifiers produced while
/// flattening the map are merged in. Returns failure if the map cannot be
/// flattened.
LogicalResult
FlatAffineConstraints::addLowerOrUpperBound(unsigned pos, AffineMap boundMap,
                                            ValueRange boundOperands, bool eq,
                                            bool lower) {
  assert(pos < getNumDimAndSymbolIds() && "invalid position");
  assert((!eq || boundMap.getNumResults() == 1) && "single result expected");
  // An equality follows the lower bound logic, except that an equality row is
  // added instead of an inequality row.
  lower = lower || eq;

  // Fully compose map and operands; canonicalize and simplify so that we
  // transitively get to terminal symbols or loop IVs.
  AffineMap map = boundMap;
  SmallVector<Value, 4> operands(boundOperands.begin(), boundOperands.end());
  fullyComposeAffineMapAndOperands(&map, &operands);
  map = simplifyAffineMap(map);
  canonicalizeMapAndOperands(&map, &operands);
  for (Value operand : operands)
    addInductionVarOrTerminalSymbol(operand);

  FlatAffineConstraints localVarCst;
  std::vector<SmallVector<int64_t, 8>> flatExprs;
  if (failed(getFlattenedAffineExprs(map, &flatExprs, &localVarCst))) {
    LLVM_DEBUG(llvm::dbgs() << "semi-affine expressions not yet supported\n");
    return failure();
  }

  // Merge and align with localVarCst.
  if (localVarCst.getNumLocalIds() > 0) {
    // Associate the operands with the dims and symbols of localVarCst.
    localVarCst.setIdValues(0, localVarCst.getNumDimAndSymbolIds(), operands);
    // Make each operand have the same kind (dim or symbol) in localVarCst as
    // it has in this system.
    for (Value operand : operands) {
      unsigned operandPos;
      if (!findId(operand, &operandPos))
        continue;
      if (operandPos < getNumDimIds())
        turnSymbolIntoDim(&localVarCst, operand);
      else if (operandPos < getNumDimAndSymbolIds())
        turnDimIntoSymbol(&localVarCst, operand);
    }
    mergeAndAlignIds(/*offset=*/0, this, &localVarCst);
    append(localVarCst);
  }

  // Record positions of the operands in the constraint system. Need to do
  // this here since the constraint system changes after a bound is added.
  SmallVector<unsigned, 8> positions;
  const unsigned numOperands = operands.size();
  for (Value operand : operands) {
    unsigned operandPos;
    bool found = findId(operand, &operandPos);
    assert(found && "expected to be found");
    (void)found;
    positions.push_back(operandPos);
  }

  // Lower bound rows encode  id - expr >= 0;  upper bound rows encode
  // -id + expr - 1 >= 0  (the bound given by the map is exclusive).
  const int64_t exprSign = lower ? -1 : 1;
  for (const auto &flatExpr : flatExprs) {
    SmallVector<int64_t, 4> row(getNumCols(), 0);
    row[pos] = lower ? 1 : -1;

    // Dims and symbols.
    for (unsigned j = 0, e = map.getNumInputs(); j < e; j++)
      row[positions[j]] = exprSign * flatExpr[j];

    // Local ids: they occupy the last 'numLocalIds' identifier columns.
    const unsigned numLocalIds = flatExpr.size() - 1 - numOperands;
    const unsigned firstLocalCol = getNumIds() - numLocalIds;
    for (unsigned k = 0; k < numLocalIds; ++k)
      row[firstLocalCol + k] = exprSign * flatExpr[numOperands + k];

    // Constant term.
    const int64_t constTerm = flatExpr[flatExpr.size() - 1];
    row[getNumCols() - 1] = lower ? -constTerm : constTerm - 1;

    if (eq)
      addEquality(row);
    else
      addInequality(row);
  }
  return success();
}
1754
1755
// Adds slice lower bounds represented by lower bounds in 'lbMaps' and upper
1756
// bounds in 'ubMaps' to each value in `values' that appears in the constraint
1757
// system. Note that both lower/upper bounds share the same operand list
1758
// 'operands'.
1759
// This function assumes 'values.size' == 'lbMaps.size' == 'ubMaps.size', and
1760
// skips any null AffineMaps in 'lbMaps' or 'ubMaps'.
1761
// Note that both lower/upper bounds use operands from 'operands'.
1762
// Returns failure for unimplemented cases such as semi-affine expressions or
1763
// expressions with mod/floordiv.
1764
LogicalResult FlatAffineConstraints::addSliceBounds(ArrayRef<Value> values,
1765
                                                    ArrayRef<AffineMap> lbMaps,
1766
                                                    ArrayRef<AffineMap> ubMaps,
1767
0
                                                    ArrayRef<Value> operands) {
1768
0
  assert(values.size() == lbMaps.size());
1769
0
  assert(lbMaps.size() == ubMaps.size());
1770
0
1771
0
  for (unsigned i = 0, e = lbMaps.size(); i < e; ++i) {
1772
0
    unsigned pos;
1773
0
    if (!findId(values[i], &pos))
1774
0
      continue;
1775
0
1776
0
    AffineMap lbMap = lbMaps[i];
1777
0
    AffineMap ubMap = ubMaps[i];
1778
0
    assert(!lbMap || lbMap.getNumInputs() == operands.size());
1779
0
    assert(!ubMap || ubMap.getNumInputs() == operands.size());
1780
0
1781
0
    // Check if this slice is just an equality along this dimension.
1782
0
    if (lbMap && ubMap && lbMap.getNumResults() == 1 &&
1783
0
        ubMap.getNumResults() == 1 &&
1784
0
        lbMap.getResult(0) + 1 == ubMap.getResult(0)) {
1785
0
      if (failed(addLowerOrUpperBound(pos, lbMap, operands, /*eq=*/true,
1786
0
                                      /*lower=*/true)))
1787
0
        return failure();
1788
0
      continue;
1789
0
    }
1790
0
1791
0
    if (lbMap && failed(addLowerOrUpperBound(pos, lbMap, operands, /*eq=*/false,
1792
0
                                             /*lower=*/true)))
1793
0
      return failure();
1794
0
1795
0
    if (ubMap && failed(addLowerOrUpperBound(pos, ubMap, operands, /*eq=*/false,
1796
0
                                             /*lower=*/false)))
1797
0
      return failure();
1798
0
  }
1799
0
  return success();
1800
0
}
1801
1802
16
/// Appends 'eq' (coefficients followed by the constant term) as a new equality
/// row. Rows are stored flat with a stride of 'numReservedCols'.
void FlatAffineConstraints::addEquality(ArrayRef<int64_t> eq) {
  assert(eq.size() == getNumCols());
  const unsigned rowStart = equalities.size();
  equalities.resize(rowStart + numReservedCols);
  auto row = equalities.begin() + rowStart;
  std::copy(eq.begin(), eq.end(), row);
}
1808
1809
102
/// Appends 'inEq' (coefficients followed by the constant term) as a new
/// inequality row. Rows are stored flat with a stride of 'numReservedCols'.
void FlatAffineConstraints::addInequality(ArrayRef<int64_t> inEq) {
  assert(inEq.size() == getNumCols());
  const unsigned rowStart = inequalities.size();
  inequalities.resize(rowStart + numReservedCols);
  auto row = inequalities.begin() + rowStart;
  std::copy(inEq.begin(), inEq.end(), row);
}
1815
1816
0
void FlatAffineConstraints::addConstantLowerBound(unsigned pos, int64_t lb) {
1817
0
  assert(pos < getNumCols());
1818
0
  unsigned offset = inequalities.size();
1819
0
  inequalities.resize(inequalities.size() + numReservedCols);
1820
0
  std::fill(inequalities.begin() + offset,
1821
0
            inequalities.begin() + offset + getNumCols(), 0);
1822
0
  inequalities[offset + pos] = 1;
1823
0
  inequalities[offset + getNumCols() - 1] = -lb;
1824
0
}
1825
1826
0
void FlatAffineConstraints::addConstantUpperBound(unsigned pos, int64_t ub) {
1827
0
  assert(pos < getNumCols());
1828
0
  unsigned offset = inequalities.size();
1829
0
  inequalities.resize(inequalities.size() + numReservedCols);
1830
0
  std::fill(inequalities.begin() + offset,
1831
0
            inequalities.begin() + offset + getNumCols(), 0);
1832
0
  inequalities[offset + pos] = -1;
1833
0
  inequalities[offset + getNumCols() - 1] = ub;
1834
0
}
1835
1836
void FlatAffineConstraints::addConstantLowerBound(ArrayRef<int64_t> expr,
1837
0
                                                  int64_t lb) {
1838
0
  assert(expr.size() == getNumCols());
1839
0
  unsigned offset = inequalities.size();
1840
0
  inequalities.resize(inequalities.size() + numReservedCols);
1841
0
  std::fill(inequalities.begin() + offset,
1842
0
            inequalities.begin() + offset + getNumCols(), 0);
1843
0
  std::copy(expr.begin(), expr.end(), inequalities.begin() + offset);
1844
0
  inequalities[offset + getNumCols() - 1] += -lb;
1845
0
}
1846
1847
void FlatAffineConstraints::addConstantUpperBound(ArrayRef<int64_t> expr,
1848
0
                                                  int64_t ub) {
1849
0
  assert(expr.size() == getNumCols());
1850
0
  unsigned offset = inequalities.size();
1851
0
  inequalities.resize(inequalities.size() + numReservedCols);
1852
0
  std::fill(inequalities.begin() + offset,
1853
0
            inequalities.begin() + offset + getNumCols(), 0);
1854
0
  for (unsigned i = 0, e = getNumCols(); i < e; i++) {
1855
0
    inequalities[offset + i] = -expr[i];
1856
0
  }
1857
0
  inequalities[offset + getNumCols() - 1] += ub;
1858
0
}
1859
1860
/// Adds a new local identifier as the floordiv of an affine function of other
1861
/// identifiers, the coefficients of which are provided in 'dividend' and with
1862
/// respect to a positive constant 'divisor'. Two constraints are added to the
1863
/// system to capture equivalence with the floordiv.
1864
///      q = expr floordiv c    <=>   c*q <= expr <= c*q + c - 1.
1865
void FlatAffineConstraints::addLocalFloorDiv(ArrayRef<int64_t> dividend,
1866
0
                                             int64_t divisor) {
1867
0
  assert(dividend.size() == getNumCols() && "incorrect dividend size");
1868
0
  assert(divisor > 0 && "positive divisor expected");
1869
0
1870
0
  addLocalId(getNumLocalIds());
1871
0
1872
0
  // Add two constraints for this new identifier 'q'.
1873
0
  SmallVector<int64_t, 8> bound(dividend.size() + 1);
1874
0
1875
0
  // dividend - q * divisor >= 0
1876
0
  std::copy(dividend.begin(), dividend.begin() + dividend.size() - 1,
1877
0
            bound.begin());
1878
0
  bound.back() = dividend.back();
1879
0
  bound[getNumIds() - 1] = -divisor;
1880
0
  addInequality(bound);
1881
0
1882
0
  // -dividend +qdivisor * q + divisor - 1 >= 0
1883
0
  std::transform(bound.begin(), bound.end(), bound.begin(),
1884
0
                 std::negate<int64_t>());
1885
0
  bound[bound.size() - 1] += divisor - 1;
1886
0
  addInequality(bound);
1887
0
}
1888
1889
0
/// Looks for 'id' among the values attached to the identifiers. On success,
/// stores the position of the first match in '*pos' and returns true;
/// otherwise returns false and leaves '*pos' untouched.
bool FlatAffineConstraints::findId(Value id, unsigned *pos) const {
  for (unsigned i = 0, e = ids.size(); i < e; ++i) {
    const auto &candidate = ids[i];
    if (candidate.hasValue() && candidate.getValue() == id) {
      *pos = i;
      return true;
    }
  }
  return false;
}
1900
1901
0
/// Returns true if 'id' is the value attached to one of the identifiers.
bool FlatAffineConstraints::containsId(Value id) const {
  for (const Optional<Value> &candidate : ids) {
    if (candidate.hasValue() && candidate.getValue() == id)
      return true;
  }
  return false;
}
1906
1907
0
void FlatAffineConstraints::setDimSymbolSeparation(unsigned newSymbolCount) {
1908
0
  assert(newSymbolCount <= numDims + numSymbols &&
1909
0
         "invalid separation position");
1910
0
  numDims = numDims + numSymbols - newSymbolCount;
1911
0
  numSymbols = newSymbolCount;
1912
0
}
1913
1914
/// Sets the specified identifier to a constant value.
1915
0
void FlatAffineConstraints::setIdToConstant(unsigned pos, int64_t val) {
1916
0
  unsigned offset = equalities.size();
1917
0
  equalities.resize(equalities.size() + numReservedCols);
1918
0
  std::fill(equalities.begin() + offset,
1919
0
            equalities.begin() + offset + getNumCols(), 0);
1920
0
  equalities[offset + pos] = 1;
1921
0
  equalities[offset + getNumCols() - 1] = -val;
1922
0
}
1923
1924
/// Sets the specified identifier to a constant value; asserts if the id is not
1925
/// found.
1926
0
void FlatAffineConstraints::setIdToConstant(Value id, int64_t val) {
1927
0
  unsigned pos;
1928
0
  if (!findId(id, &pos))
1929
0
    // This is a pre-condition for this method.
1930
0
    assert(0 && "id not found");
1931
0
  setIdToConstant(pos, val);
1932
0
}
1933
1934
0
void FlatAffineConstraints::removeEquality(unsigned pos) {
1935
0
  unsigned numEqualities = getNumEqualities();
1936
0
  assert(pos < numEqualities);
1937
0
  unsigned outputIndex = pos * numReservedCols;
1938
0
  unsigned inputIndex = (pos + 1) * numReservedCols;
1939
0
  unsigned numElemsToCopy = (numEqualities - pos - 1) * numReservedCols;
1940
0
  std::copy(equalities.begin() + inputIndex,
1941
0
            equalities.begin() + inputIndex + numElemsToCopy,
1942
0
            equalities.begin() + outputIndex);
1943
0
  assert(equalities.size() >= numReservedCols);
1944
0
  equalities.resize(equalities.size() - numReservedCols);
1945
0
}
1946
1947
0
void FlatAffineConstraints::removeInequality(unsigned pos) {
1948
0
  unsigned numInequalities = getNumInequalities();
1949
0
  assert(pos < numInequalities && "invalid position");
1950
0
  unsigned outputIndex = pos * numReservedCols;
1951
0
  unsigned inputIndex = (pos + 1) * numReservedCols;
1952
0
  unsigned numElemsToCopy = (numInequalities - pos - 1) * numReservedCols;
1953
0
  std::copy(inequalities.begin() + inputIndex,
1954
0
            inequalities.begin() + inputIndex + numElemsToCopy,
1955
0
            inequalities.begin() + outputIndex);
1956
0
  assert(inequalities.size() >= numReservedCols);
1957
0
  inequalities.resize(inequalities.size() - numReservedCols);
1958
0
}
1959
1960
/// Finds an equality that equates the specified identifier to a constant.
1961
/// Returns the position of the equality row. If 'symbolic' is set to true,
1962
/// symbols are also treated like a constant, i.e., an affine function of the
1963
/// symbols is also treated like a constant. Returns -1 if such an equality
1964
/// could not be found.
1965
static int findEqualityToConstant(const FlatAffineConstraints &cst,
1966
0
                                  unsigned pos, bool symbolic = false) {
1967
0
  assert(pos < cst.getNumIds() && "invalid position");
1968
0
  for (unsigned r = 0, e = cst.getNumEqualities(); r < e; r++) {
1969
0
    int64_t v = cst.atEq(r, pos);
1970
0
    if (v * v != 1)
1971
0
      continue;
1972
0
    unsigned c;
1973
0
    unsigned f = symbolic ? cst.getNumDimIds() : cst.getNumIds();
1974
0
    // This checks for zeros in all positions other than 'pos' in [0, f)
1975
0
    for (c = 0; c < f; c++) {
1976
0
      if (c == pos)
1977
0
        continue;
1978
0
      if (cst.atEq(r, c) != 0) {
1979
0
        // Dependent on another identifier.
1980
0
        break;
1981
0
      }
1982
0
    }
1983
0
    if (c == f)
1984
0
      // Equality is free of other identifiers.
1985
0
      return r;
1986
0
  }
1987
0
  return -1;
1988
0
}
1989
1990
0
void FlatAffineConstraints::setAndEliminate(unsigned pos, int64_t constVal) {
1991
0
  assert(pos < getNumIds() && "invalid position");
1992
0
  for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
1993
0
    atIneq(r, getNumCols() - 1) += atIneq(r, pos) * constVal;
1994
0
  }
1995
0
  for (unsigned r = 0, e = getNumEqualities(); r < e; r++) {
1996
0
    atEq(r, getNumCols() - 1) += atEq(r, pos) * constVal;
1997
0
  }
1998
0
  removeId(pos);
1999
0
}
2000
2001
0
/// If some equality pins the identifier at position 'pos' to a constant,
/// substitutes that constant for it, eliminates the identifier and returns
/// success; returns failure (leaving the system unchanged) otherwise.
LogicalResult FlatAffineConstraints::constantFoldId(unsigned pos) {
  assert(pos < getNumIds() && "invalid position");
  const int rowIdx = findEqualityToConstant(*this, pos);
  if (rowIdx == -1)
    return failure();

  // The identifier's coefficient in that row is either -1 or 1.
  const int64_t idCoeff = atEq(rowIdx, pos);
  assert(idCoeff * idCoeff == 1);
  const int64_t constVal = -atEq(rowIdx, getNumCols() - 1) / idCoeff;
  setAndEliminate(pos, constVal);
  return success();
}
2013
2014
0
void FlatAffineConstraints::constantFoldIdRange(unsigned pos, unsigned num) {
2015
0
  for (unsigned s = pos, t = pos, e = pos + num; s < e; s++) {
2016
0
    if (failed(constantFoldId(t)))
2017
0
      t++;
2018
0
  }
2019
0
}
2020
2021
/// Returns the extent (upper bound - lower bound) of the specified
2022
/// identifier if it is found to be a constant; returns None if it's not a
2023
/// constant. This method treats symbolic identifiers specially, i.e.,
2024
/// it looks for constant differences between affine expressions involving
2025
/// only the symbolic identifiers. See comments at function definition for
2026
/// example. 'lb', if provided, is set to the lower bound associated with the
2027
/// constant difference. Note that 'lb' is purely symbolic and thus will contain
2028
/// the coefficients of the symbolic identifiers and the constant coefficient.
2029
//  Egs: 0 <= i <= 15, return 16.
2030
//       s0 + 2 <= i <= s0 + 17, returns 16. (s0 has to be a symbol)
2031
//       s0 + s1 + 16 <= d0 <= s0 + s1 + 31, returns 16.
2032
//       s0 - 7 <= 8*j <= s0 returns 1 with lb = s0, lbDivisor = 8 (since lb =
2033
//       ceil((s0 - 7) / 8) = floor(s0 / 8)).
2034
Optional<int64_t> FlatAffineConstraints::getConstantBoundOnDimSize(
2035
    unsigned pos, SmallVectorImpl<int64_t> *lb, int64_t *boundFloorDivisor,
2036
    SmallVectorImpl<int64_t> *ub, unsigned *minLbPos,
2037
0
    unsigned *minUbPos) const {
2038
0
  assert(pos < getNumDimIds() && "Invalid identifier position");
2039
0
2040
0
  // Find an equality for 'pos'^th identifier that equates it to some function
2041
0
  // of the symbolic identifiers (+ constant).
2042
0
  int eqPos = findEqualityToConstant(*this, pos, /*symbolic=*/true);
2043
0
  if (eqPos != -1) {
2044
0
    auto eq = getEquality(eqPos);
2045
0
    // If the equality involves a local var, punt for now.
2046
0
    // TODO: this can be handled in the future by using the explicit
2047
0
    // representation of the local vars.
2048
0
    if (!std::all_of(eq.begin() + getNumDimAndSymbolIds(), eq.end() - 1,
2049
0
                     [](int64_t coeff) { return coeff == 0; }))
2050
0
      return None;
2051
0
2052
0
    // This identifier can only take a single value.
2053
0
    if (lb) {
2054
0
      // Set lb to that symbolic value.
2055
0
      lb->resize(getNumSymbolIds() + 1);
2056
0
      if (ub)
2057
0
        ub->resize(getNumSymbolIds() + 1);
2058
0
      for (unsigned c = 0, f = getNumSymbolIds() + 1; c < f; c++) {
2059
0
        int64_t v = atEq(eqPos, pos);
2060
0
        // atEq(eqPos, pos) is either -1 or 1.
2061
0
        assert(v * v == 1);
2062
0
        (*lb)[c] = v < 0 ? atEq(eqPos, getNumDimIds() + c) / -v
2063
0
                         : -atEq(eqPos, getNumDimIds() + c) / v;
2064
0
        // Since this is an equality, ub = lb.
2065
0
        if (ub)
2066
0
          (*ub)[c] = (*lb)[c];
2067
0
      }
2068
0
      assert(boundFloorDivisor &&
2069
0
             "both lb and divisor or none should be provided");
2070
0
      *boundFloorDivisor = 1;
2071
0
    }
2072
0
    if (minLbPos)
2073
0
      *minLbPos = eqPos;
2074
0
    if (minUbPos)
2075
0
      *minUbPos = eqPos;
2076
0
    return 1;
2077
0
  }
2078
0
2079
0
  // Check if the identifier appears at all in any of the inequalities.
2080
0
  unsigned r, e;
2081
0
  for (r = 0, e = getNumInequalities(); r < e; r++) {
2082
0
    if (atIneq(r, pos) != 0)
2083
0
      break;
2084
0
  }
2085
0
  if (r == e)
2086
0
    // If it doesn't, there isn't a bound on it.
2087
0
    return None;
2088
0
2089
0
  // Positions of constraints that are lower/upper bounds on the variable.
2090
0
  SmallVector<unsigned, 4> lbIndices, ubIndices;
2091
0
2092
0
  // Gather all symbolic lower bounds and upper bounds of the variable, i.e.,
2093
0
  // the bounds can only involve symbolic (and local) identifiers. Since the
2094
0
  // canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint is a lower
2095
0
  // bound for x_i if c_i >= 1, and an upper bound if c_i <= -1.
2096
0
  getLowerAndUpperBoundIndices(pos, &lbIndices, &ubIndices,
2097
0
                               /*eqIndices=*/nullptr, /*offset=*/0,
2098
0
                               /*num=*/getNumDimIds());
2099
0
2100
0
  Optional<int64_t> minDiff = None;
2101
0
  unsigned minLbPosition = 0, minUbPosition = 0;
2102
0
  for (auto ubPos : ubIndices) {
2103
0
    for (auto lbPos : lbIndices) {
2104
0
      // Look for a lower bound and an upper bound that only differ by a
2105
0
      // constant, i.e., pairs of the form  0 <= c_pos - f(c_i's) <= diffConst.
2106
0
      // For example, if ii is the pos^th variable, we are looking for
2107
0
      // constraints like ii >= i, ii <= i + 50, 50 being the difference. The
2108
0
      // minimum among all such constant differences is kept since that's the
2109
0
      // constant bounding the extent of the pos^th variable.
2110
0
      unsigned j, e;
2111
0
      for (j = 0, e = getNumCols() - 1; j < e; j++)
2112
0
        if (atIneq(ubPos, j) != -atIneq(lbPos, j)) {
2113
0
          break;
2114
0
        }
2115
0
      if (j < getNumCols() - 1)
2116
0
        continue;
2117
0
      int64_t diff = ceilDiv(atIneq(ubPos, getNumCols() - 1) +
2118
0
                                 atIneq(lbPos, getNumCols() - 1) + 1,
2119
0
                             atIneq(lbPos, pos));
2120
0
      if (minDiff == None || diff < minDiff) {
2121
0
        minDiff = diff;
2122
0
        minLbPosition = lbPos;
2123
0
        minUbPosition = ubPos;
2124
0
      }
2125
0
    }
2126
0
  }
2127
0
  if (lb && minDiff.hasValue()) {
2128
0
    // Set lb to the symbolic lower bound.
2129
0
    lb->resize(getNumSymbolIds() + 1);
2130
0
    if (ub)
2131
0
      ub->resize(getNumSymbolIds() + 1);
2132
0
    // The lower bound is the ceildiv of the lb constraint over the coefficient
2133
0
    // of the variable at 'pos'. We express the ceildiv equivalently as a floor
2134
0
    // for uniformity. For eg., if the lower bound constraint was: 32*d0 - N +
2135
0
    // 31 >= 0, the lower bound for d0 is ceil(N - 31, 32), i.e., floor(N, 32).
2136
0
    *boundFloorDivisor = atIneq(minLbPosition, pos);
2137
0
    assert(*boundFloorDivisor == -atIneq(minUbPosition, pos));
2138
0
    for (unsigned c = 0, e = getNumSymbolIds() + 1; c < e; c++) {
2139
0
      (*lb)[c] = -atIneq(minLbPosition, getNumDimIds() + c);
2140
0
    }
2141
0
    if (ub) {
2142
0
      for (unsigned c = 0, e = getNumSymbolIds() + 1; c < e; c++)
2143
0
        (*ub)[c] = atIneq(minUbPosition, getNumDimIds() + c);
2144
0
    }
2145
0
    // The lower bound leads to a ceildiv while the upper bound is a floordiv
2146
0
    // whenever the coefficient at pos != 1. ceildiv(val, d) = floordiv(val +
2147
0
    // d - 1, d); hence, the addition of 'atIneq(minLbPosition, pos) - 1' to
2148
0
    // the constant term for the lower bound.
2149
0
    (*lb)[getNumSymbolIds()] += atIneq(minLbPosition, pos) - 1;
2150
0
  }
2151
0
  if (minLbPos)
2152
0
    *minLbPos = minLbPosition;
2153
0
  if (minUbPos)
2154
0
    *minUbPos = minUbPosition;
2155
0
  return minDiff;
2156
0
}
2157
2158
template <bool isLower>
2159
Optional<int64_t>
2160
0
FlatAffineConstraints::computeConstantLowerOrUpperBound(unsigned pos) {
2161
0
  assert(pos < getNumIds() && "invalid position");
2162
0
  // Project to 'pos'.
2163
0
  projectOut(0, pos);
2164
0
  projectOut(1, getNumIds() - 1);
2165
0
  // Check if there's an equality equating the '0'^th identifier to a constant.
2166
0
  int eqRowIdx = findEqualityToConstant(*this, 0, /*symbolic=*/false);
2167
0
  if (eqRowIdx != -1)
2168
0
    // atEq(eqRowIdx, 0) is either -1 or 1.
2169
0
    return -atEq(eqRowIdx, getNumCols() - 1) / atEq(eqRowIdx, 0);
2170
0
2171
0
  // Check if the identifier appears at all in any of the inequalities.
2172
0
  unsigned r, e;
2173
0
  for (r = 0, e = getNumInequalities(); r < e; r++) {
2174
0
    if (atIneq(r, 0) != 0)
2175
0
      break;
2176
0
  }
2177
0
  if (r == e)
2178
0
    // If it doesn't, there isn't a bound on it.
2179
0
    return None;
2180
0
2181
0
  Optional<int64_t> minOrMaxConst = None;
2182
0
2183
0
  // Take the max across all const lower bounds (or min across all constant
2184
0
  // upper bounds).
2185
0
  for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
2186
0
    if (isLower) {
2187
0
      if (atIneq(r, 0) <= 0)
2188
0
        // Not a lower bound.
2189
0
        continue;
2190
0
    } else if (atIneq(r, 0) >= 0) {
2191
0
      // Not an upper bound.
2192
0
      continue;
2193
0
    }
2194
0
    unsigned c, f;
2195
0
    for (c = 0, f = getNumCols() - 1; c < f; c++)
2196
0
      if (c != 0 && atIneq(r, c) != 0)
2197
0
        break;
2198
0
    if (c < getNumCols() - 1)
2199
0
      // Not a constant bound.
2200
0
      continue;
2201
0
2202
0
    int64_t boundConst =
2203
0
        isLower ? mlir::ceilDiv(-atIneq(r, getNumCols() - 1), atIneq(r, 0))
2204
0
                : mlir::floorDiv(atIneq(r, getNumCols() - 1), -atIneq(r, 0));
2205
0
    if (isLower) {
2206
0
      if (minOrMaxConst == None || boundConst > minOrMaxConst)
2207
0
        minOrMaxConst = boundConst;
2208
0
    } else {
2209
0
      if (minOrMaxConst == None || boundConst < minOrMaxConst)
2210
0
        minOrMaxConst = boundConst;
2211
0
    }
2212
0
  }
2213
0
  return minOrMaxConst;
2214
0
}
Unexecuted instantiation: _ZN4mlir21FlatAffineConstraints32computeConstantLowerOrUpperBoundILb1EEEN4llvm8OptionalIlEEj
Unexecuted instantiation: _ZN4mlir21FlatAffineConstraints32computeConstantLowerOrUpperBoundILb0EEEN4llvm8OptionalIlEEj
2215
2216
Optional<int64_t>
2217
0
FlatAffineConstraints::getConstantLowerBound(unsigned pos) const {
2218
0
  FlatAffineConstraints tmpCst(*this);
2219
0
  return tmpCst.computeConstantLowerOrUpperBound</*isLower=*/true>(pos);
2220
0
}
2221
2222
Optional<int64_t>
2223
0
FlatAffineConstraints::getConstantUpperBound(unsigned pos) const {
2224
0
  FlatAffineConstraints tmpCst(*this);
2225
0
  return tmpCst.computeConstantLowerOrUpperBound</*isLower=*/false>(pos);
2226
0
}
2227
2228
// A simple (naive and conservative) check for hyper-rectangularity.
2229
bool FlatAffineConstraints::isHyperRectangular(unsigned pos,
2230
0
                                               unsigned num) const {
2231
0
  assert(pos < getNumCols() - 1);
2232
0
  // Check for two non-zero coefficients in the range [pos, pos + num).
2233
0
  for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
2234
0
    unsigned sum = 0;
2235
0
    for (unsigned c = pos; c < pos + num; c++) {
2236
0
      if (atIneq(r, c) != 0)
2237
0
        sum++;
2238
0
    }
2239
0
    if (sum > 1)
2240
0
      return false;
2241
0
  }
2242
0
  for (unsigned r = 0, e = getNumEqualities(); r < e; r++) {
2243
0
    unsigned sum = 0;
2244
0
    for (unsigned c = pos; c < pos + num; c++) {
2245
0
      if (atEq(r, c) != 0)
2246
0
        sum++;
2247
0
    }
2248
0
    if (sum > 1)
2249
0
      return false;
2250
0
  }
2251
0
  return true;
2252
0
}
2253
2254
0
void FlatAffineConstraints::print(raw_ostream &os) const {
2255
0
  assert(hasConsistentState());
2256
0
  os << "\nConstraints (" << getNumDimIds() << " dims, " << getNumSymbolIds()
2257
0
     << " symbols, " << getNumLocalIds() << " locals), (" << getNumConstraints()
2258
0
     << " constraints)\n";
2259
0
  os << "(";
2260
0
  for (unsigned i = 0, e = getNumIds(); i < e; i++) {
2261
0
    if (ids[i] == None)
2262
0
      os << "None ";
2263
0
    else
2264
0
      os << "Value ";
2265
0
  }
2266
0
  os << " const)\n";
2267
0
  for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) {
2268
0
    for (unsigned j = 0, f = getNumCols(); j < f; ++j) {
2269
0
      os << atEq(i, j) << " ";
2270
0
    }
2271
0
    os << "= 0\n";
2272
0
  }
2273
0
  for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) {
2274
0
    for (unsigned j = 0, f = getNumCols(); j < f; ++j) {
2275
0
      os << atIneq(i, j) << " ";
2276
0
    }
2277
0
    os << ">= 0\n";
2278
0
  }
2279
0
  os << '\n';
2280
0
}
2281
2282
0
void FlatAffineConstraints::dump() const { print(llvm::errs()); }
2283
2284
/// Removes duplicate constraints, trivially true constraints, and constraints
2285
/// that can be detected as redundant as a result of differing only in their
2286
/// constant term part. A constraint of the form <non-negative constant> >= 0 is
2287
/// considered trivially true.
2288
//  Uses a DenseSet to hash and detect duplicates followed by a linear scan to
2289
//  remove duplicates in place.
2290
0
void FlatAffineConstraints::removeTrivialRedundancy() {
2291
0
  GCDTightenInequalities();
2292
0
  normalizeConstraintsByGCD();
2293
0
2294
0
  // A map used to detect redundancy stemming from constraints that only differ
2295
0
  // in their constant term. The value stored is <row position, const term>
2296
0
  // for a given row.
2297
0
  SmallDenseMap<ArrayRef<int64_t>, std::pair<unsigned, int64_t>>
2298
0
      rowsWithoutConstTerm;
2299
0
  // To unique rows.
2300
0
  SmallDenseSet<ArrayRef<int64_t>, 8> rowSet;
2301
0
2302
0
  // Check if constraint is of the form <non-negative-constant> >= 0.
2303
0
  auto isTriviallyValid = [&](unsigned r) -> bool {
2304
0
    for (unsigned c = 0, e = getNumCols() - 1; c < e; c++) {
2305
0
      if (atIneq(r, c) != 0)
2306
0
        return false;
2307
0
    }
2308
0
    return atIneq(r, getNumCols() - 1) >= 0;
2309
0
  };
2310
0
2311
0
  // Detect and mark redundant constraints.
2312
0
  SmallVector<bool, 256> redunIneq(getNumInequalities(), false);
2313
0
  for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
2314
0
    int64_t *rowStart = inequalities.data() + numReservedCols * r;
2315
0
    auto row = ArrayRef<int64_t>(rowStart, getNumCols());
2316
0
    if (isTriviallyValid(r) || !rowSet.insert(row).second) {
2317
0
      redunIneq[r] = true;
2318
0
      continue;
2319
0
    }
2320
0
2321
0
    // Among constraints that only differ in the constant term part, mark
2322
0
    // everything other than the one with the smallest constant term redundant.
2323
0
    // (eg: among i - 16j - 5 >= 0, i - 16j - 1 >=0, i - 16j - 7 >= 0, the
2324
0
    // former two are redundant).
2325
0
    int64_t constTerm = atIneq(r, getNumCols() - 1);
2326
0
    auto rowWithoutConstTerm = ArrayRef<int64_t>(rowStart, getNumCols() - 1);
2327
0
    const auto &ret =
2328
0
        rowsWithoutConstTerm.insert({rowWithoutConstTerm, {r, constTerm}});
2329
0
    if (!ret.second) {
2330
0
      // Check if the other constraint has a higher constant term.
2331
0
      auto &val = ret.first->second;
2332
0
      if (val.second > constTerm) {
2333
0
        // The stored row is redundant. Mark it so, and update with this one.
2334
0
        redunIneq[val.first] = true;
2335
0
        val = {r, constTerm};
2336
0
      } else {
2337
0
        // The one stored makes this one redundant.
2338
0
        redunIneq[r] = true;
2339
0
      }
2340
0
    }
2341
0
  }
2342
0
2343
0
  auto copyRow = [&](unsigned src, unsigned dest) {
2344
0
    if (src == dest)
2345
0
      return;
2346
0
    for (unsigned c = 0, e = getNumCols(); c < e; c++) {
2347
0
      atIneq(dest, c) = atIneq(src, c);
2348
0
    }
2349
0
  };
2350
0
2351
0
  // Scan to get rid of all rows marked redundant, in-place.
2352
0
  unsigned pos = 0;
2353
0
  for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
2354
0
    if (!redunIneq[r])
2355
0
      copyRow(r, pos++);
2356
0
  }
2357
0
  inequalities.resize(numReservedCols * pos);
2358
0
2359
0
  // TODO(bondhugula): consider doing this for equalities as well, but probably
2360
0
  // not worth the savings.
2361
0
}
2362
2363
void FlatAffineConstraints::clearAndCopyFrom(
2364
0
    const FlatAffineConstraints &other) {
2365
0
  FlatAffineConstraints copy(other);
2366
0
  std::swap(*this, copy);
2367
0
  assert(copy.getNumIds() == copy.getIds().size());
2368
0
}
2369
2370
0
void FlatAffineConstraints::removeId(unsigned pos) {
2371
0
  removeIdRange(pos, pos + 1);
2372
0
}
2373
2374
static std::pair<unsigned, unsigned>
2375
0
getNewNumDimsSymbols(unsigned pos, const FlatAffineConstraints &cst) {
2376
0
  unsigned numDims = cst.getNumDimIds();
2377
0
  unsigned numSymbols = cst.getNumSymbolIds();
2378
0
  unsigned newNumDims, newNumSymbols;
2379
0
  if (pos < numDims) {
2380
0
    newNumDims = numDims - 1;
2381
0
    newNumSymbols = numSymbols;
2382
0
  } else if (pos < numDims + numSymbols) {
2383
0
    assert(numSymbols >= 1);
2384
0
    newNumDims = numDims;
2385
0
    newNumSymbols = numSymbols - 1;
2386
0
  } else {
2387
0
    newNumDims = numDims;
2388
0
    newNumSymbols = numSymbols;
2389
0
  }
2390
0
  return {newNumDims, newNumSymbols};
2391
0
}
2392
2393
#undef DEBUG_TYPE
2394
#define DEBUG_TYPE "fm"
2395
2396
/// Eliminates identifier at the specified position using Fourier-Motzkin
2397
/// variable elimination. This technique is exact for rational spaces but
2398
/// conservative (in "rare" cases) for integer spaces. The operation corresponds
2399
/// to a projection operation yielding the (convex) set of integer points
2400
/// contained in the rational shadow of the set. An emptiness test that relies
2401
/// on this method will guarantee emptiness, i.e., it disproves the existence of
2402
/// a solution if it says it's empty.
2403
/// If a non-null isResultIntegerExact is passed, it is set to true if the
2404
/// result is also integer exact. If it's set to false, the obtained solution
2405
/// *may* not be exact, i.e., it may contain integer points that do not have an
2406
/// integer pre-image in the original set.
2407
///
2408
/// Eg:
2409
/// j >= 0, j <= i + 1
2410
/// i >= 0, i <= N + 1
2411
/// Eliminating i yields,
2412
///   j >= 0, 0 <= N + 1, j - 1 <= N + 1
2413
///
2414
/// If darkShadow = true, this method computes the dark shadow on elimination;
2415
/// the dark shadow is a convex integer subset of the exact integer shadow. A
2416
/// non-empty dark shadow proves the existence of an integer solution. The
2417
/// elimination in such a case could however be an under-approximation, and thus
2418
/// should not be used for scanning sets or used by itself for dependence
2419
/// checking.
2420
///
2421
/// Eg: 2-d set, * represents grid points, 'o' represents a point in the set.
2422
///            ^
2423
///            |
2424
///            | * * * * o o
2425
///         i  | * * o o o o
2426
///            | o * * * * *
2427
///            --------------->
2428
///                 j ->
2429
///
2430
/// Eliminating i from this system (projecting on the j dimension):
2431
/// rational shadow / integer light shadow:  1 <= j <= 6
2432
/// dark shadow:                             3 <= j <= 6
2433
/// exact integer shadow:                    j = 1 \union  3 <= j <= 6
2434
/// holes/splinters:                         j = 2
2435
///
2436
/// darkShadow = false, isResultIntegerExact = nullptr are default values.
2437
// TODO(bondhugula): a slight modification to yield dark shadow version of FM
2438
// (tightened), which can prove the existence of a solution if there is one.
2439
void FlatAffineConstraints::FourierMotzkinEliminate(
2440
0
    unsigned pos, bool darkShadow, bool *isResultIntegerExact) {
2441
0
  LLVM_DEBUG(llvm::dbgs() << "FM input (eliminate pos " << pos << "):\n");
2442
0
  LLVM_DEBUG(dump());
2443
0
  assert(pos < getNumIds() && "invalid position");
2444
0
  assert(hasConsistentState());
2445
0
2446
0
  // Check if this identifier can be eliminated through a substitution.
2447
0
  for (unsigned r = 0, e = getNumEqualities(); r < e; r++) {
2448
0
    if (atEq(r, pos) != 0) {
2449
0
      // Use Gaussian elimination here (since we have an equality).
2450
0
      LogicalResult ret = gaussianEliminateId(pos);
2451
0
      (void)ret;
2452
0
      assert(succeeded(ret) && "Gaussian elimination guaranteed to succeed");
2453
0
      LLVM_DEBUG(llvm::dbgs() << "FM output (through Gaussian elimination):\n");
2454
0
      LLVM_DEBUG(dump());
2455
0
      return;
2456
0
    }
2457
0
  }
2458
0
2459
0
  // A fast linear time tightening.
2460
0
  GCDTightenInequalities();
2461
0
2462
0
  // Check if the identifier appears at all in any of the inequalities.
2463
0
  unsigned r, e;
2464
0
  for (r = 0, e = getNumInequalities(); r < e; r++) {
2465
0
    if (atIneq(r, pos) != 0)
2466
0
      break;
2467
0
  }
2468
0
  if (r == getNumInequalities()) {
2469
0
    // If it doesn't appear, just remove the column and return.
2470
0
    // TODO(andydavis,bondhugula): refactor removeColumns to use it from here.
2471
0
    removeId(pos);
2472
0
    LLVM_DEBUG(llvm::dbgs() << "FM output:\n");
2473
0
    LLVM_DEBUG(dump());
2474
0
    return;
2475
0
  }
2476
0
2477
0
  // Positions of constraints that are lower bounds on the variable.
2478
0
  SmallVector<unsigned, 4> lbIndices;
2479
0
  // Positions of constraints that are upper bounds on the variable.
2480
0
  SmallVector<unsigned, 4> ubIndices;
2481
0
  // Positions of constraints that do not involve the variable.
2482
0
  std::vector<unsigned> nbIndices;
2483
0
  nbIndices.reserve(getNumInequalities());
2484
0
2485
0
  // Gather all lower bounds and upper bounds of the variable. Since the
2486
0
  // canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint is a lower
2487
0
  // bound for x_i if c_i >= 1, and an upper bound if c_i <= -1.
2488
0
  for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
2489
0
    if (atIneq(r, pos) == 0) {
2490
0
      // Id does not appear in bound.
2491
0
      nbIndices.push_back(r);
2492
0
    } else if (atIneq(r, pos) >= 1) {
2493
0
      // Lower bound.
2494
0
      lbIndices.push_back(r);
2495
0
    } else {
2496
0
      // Upper bound.
2497
0
      ubIndices.push_back(r);
2498
0
    }
2499
0
  }
2500
0
2501
0
  // Set the number of dimensions, symbols in the resulting system.
2502
0
  const auto &dimsSymbols = getNewNumDimsSymbols(pos, *this);
2503
0
  unsigned newNumDims = dimsSymbols.first;
2504
0
  unsigned newNumSymbols = dimsSymbols.second;
2505
0
2506
0
  SmallVector<Optional<Value>, 8> newIds;
2507
0
  newIds.reserve(numIds - 1);
2508
0
  newIds.append(ids.begin(), ids.begin() + pos);
2509
0
  newIds.append(ids.begin() + pos + 1, ids.end());
2510
0
2511
0
  /// Create the new system which has one identifier less.
2512
0
  FlatAffineConstraints newFac(
2513
0
      lbIndices.size() * ubIndices.size() + nbIndices.size(),
2514
0
      getNumEqualities(), getNumCols() - 1, newNumDims, newNumSymbols,
2515
0
      /*numLocals=*/getNumIds() - 1 - newNumDims - newNumSymbols, newIds);
2516
0
2517
0
  assert(newFac.getIds().size() == newFac.getNumIds());
2518
0
2519
0
  // This will be used to check if the elimination was integer exact.
2520
0
  unsigned lcmProducts = 1;
2521
0
2522
0
  // Let x be the variable we are eliminating.
2523
0
  // For each lower bound, lb <= c_l*x, and each upper bound c_u*x <= ub, (note
2524
0
  // that c_l, c_u >= 1) we have:
2525
0
  // lb*lcm(c_l, c_u)/c_l <= lcm(c_l, c_u)*x <= ub*lcm(c_l, c_u)/c_u
2526
0
  // We thus generate a constraint:
2527
0
  // lcm(c_l, c_u)/c_l*lb <= lcm(c_l, c_u)/c_u*ub.
2528
0
  // Note if c_l = c_u = 1, all integer points captured by the resulting
2529
0
  // constraint correspond to integer points in the original system (i.e., they
2530
0
  // have integer pre-images). Hence, if the lcm's are all 1, the elimination is
2531
0
  // integer exact.
2532
0
  for (auto ubPos : ubIndices) {
2533
0
    for (auto lbPos : lbIndices) {
2534
0
      SmallVector<int64_t, 4> ineq;
2535
0
      ineq.reserve(newFac.getNumCols());
2536
0
      int64_t lbCoeff = atIneq(lbPos, pos);
2537
0
      // Note that in the comments above, ubCoeff is the negation of the
2538
0
      // coefficient in the canonical form as the view taken here is that of the
2539
0
      // term being moved to the other side of '>='.
2540
0
      int64_t ubCoeff = -atIneq(ubPos, pos);
2541
0
      // TODO(bondhugula): refactor this loop to avoid all branches inside.
2542
0
      for (unsigned l = 0, e = getNumCols(); l < e; l++) {
2543
0
        if (l == pos)
2544
0
          continue;
2545
0
        assert(lbCoeff >= 1 && ubCoeff >= 1 && "bounds wrongly identified");
2546
0
        int64_t lcm = mlir::lcm(lbCoeff, ubCoeff);
2547
0
        ineq.push_back(atIneq(ubPos, l) * (lcm / ubCoeff) +
2548
0
                       atIneq(lbPos, l) * (lcm / lbCoeff));
2549
0
        lcmProducts *= lcm;
2550
0
      }
2551
0
      if (darkShadow) {
2552
0
        // The dark shadow is a convex subset of the exact integer shadow. If
2553
0
        // there is a point here, it proves the existence of a solution.
2554
0
        ineq[ineq.size() - 1] += lbCoeff * ubCoeff - lbCoeff - ubCoeff + 1;
2555
0
      }
2556
0
      // TODO: we need to have a way to add inequalities in-place in
2557
0
      // FlatAffineConstraints instead of creating and copying over.
2558
0
      newFac.addInequality(ineq);
2559
0
    }
2560
0
  }
2561
0
2562
0
  LLVM_DEBUG(llvm::dbgs() << "FM isResultIntegerExact: " << (lcmProducts == 1)
2563
0
                          << "\n");
2564
0
  if (lcmProducts == 1 && isResultIntegerExact)
2565
0
    *isResultIntegerExact = true;
2566
0
2567
0
  // Copy over the constraints not involving this variable.
2568
0
  for (auto nbPos : nbIndices) {
2569
0
    SmallVector<int64_t, 4> ineq;
2570
0
    ineq.reserve(getNumCols() - 1);
2571
0
    for (unsigned l = 0, e = getNumCols(); l < e; l++) {
2572
0
      if (l == pos)
2573
0
        continue;
2574
0
      ineq.push_back(atIneq(nbPos, l));
2575
0
    }
2576
0
    newFac.addInequality(ineq);
2577
0
  }
2578
0
2579
0
  assert(newFac.getNumConstraints() ==
2580
0
         lbIndices.size() * ubIndices.size() + nbIndices.size());
2581
0
2582
0
  // Copy over the equalities.
2583
0
  for (unsigned r = 0, e = getNumEqualities(); r < e; r++) {
2584
0
    SmallVector<int64_t, 4> eq;
2585
0
    eq.reserve(newFac.getNumCols());
2586
0
    for (unsigned l = 0, e = getNumCols(); l < e; l++) {
2587
0
      if (l == pos)
2588
0
        continue;
2589
0
      eq.push_back(atEq(r, l));
2590
0
    }
2591
0
    newFac.addEquality(eq);
2592
0
  }
2593
0
2594
0
  // GCD tightening and normalization allows detection of more trivially
2595
0
  // redundant constraints.
2596
0
  newFac.GCDTightenInequalities();
2597
0
  newFac.normalizeConstraintsByGCD();
2598
0
  newFac.removeTrivialRedundancy();
2599
0
  clearAndCopyFrom(newFac);
2600
0
  LLVM_DEBUG(llvm::dbgs() << "FM output:\n");
2601
0
  LLVM_DEBUG(dump());
2602
0
}
2603
2604
#undef DEBUG_TYPE
2605
#define DEBUG_TYPE "affine-structures"
2606
2607
0
void FlatAffineConstraints::projectOut(unsigned pos, unsigned num) {
2608
0
  if (num == 0)
2609
0
    return;
2610
0
2611
0
  // 'pos' can be at most getNumCols() - 2 if num > 0.
2612
0
  assert((getNumCols() < 2 || pos <= getNumCols() - 2) && "invalid position");
2613
0
  assert(pos + num < getNumCols() && "invalid range");
2614
0
2615
0
  // Eliminate as many identifiers as possible using Gaussian elimination.
2616
0
  unsigned currentPos = pos;
2617
0
  unsigned numToEliminate = num;
2618
0
  unsigned numGaussianEliminated = 0;
2619
0
2620
0
  while (currentPos < getNumIds()) {
2621
0
    unsigned curNumEliminated =
2622
0
        gaussianEliminateIds(currentPos, currentPos + numToEliminate);
2623
0
    ++currentPos;
2624
0
    numToEliminate -= curNumEliminated + 1;
2625
0
    numGaussianEliminated += curNumEliminated;
2626
0
  }
2627
0
2628
0
  // Eliminate the remaining using Fourier-Motzkin.
2629
0
  for (unsigned i = 0; i < num - numGaussianEliminated; i++) {
2630
0
    unsigned numToEliminate = num - numGaussianEliminated - i;
2631
0
    FourierMotzkinEliminate(
2632
0
        getBestIdToEliminate(*this, pos, pos + numToEliminate));
2633
0
  }
2634
0
2635
0
  // Fast/trivial simplifications.
2636
0
  GCDTightenInequalities();
2637
0
  // Normalize constraints after tightening since the latter impacts this, but
2638
0
  // not the other way round.
2639
0
  normalizeConstraintsByGCD();
2640
0
}
2641
2642
0
void FlatAffineConstraints::projectOut(Value id) {
2643
0
  unsigned pos;
2644
0
  bool ret = findId(id, &pos);
2645
0
  assert(ret);
2646
0
  (void)ret;
2647
0
  FourierMotzkinEliminate(pos);
2648
0
}
2649
2650
0
void FlatAffineConstraints::clearConstraints() {
2651
0
  equalities.clear();
2652
0
  inequalities.clear();
2653
0
}
2654
2655
namespace {
2656
2657
enum BoundCmpResult { Greater, Less, Equal, Unknown };
2658
2659
/// Compares two affine bounds whose coefficients are provided in 'a' and
2660
/// 'b'. The last coefficient is the constant term.
2661
0
static BoundCmpResult compareBounds(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
2662
0
  assert(a.size() == b.size());
2663
0
2664
0
  // For the bounds to be comparable, their corresponding identifier
2665
0
  // coefficients should be equal; the constant terms are then compared to
2666
0
  // determine less/greater/equal.
2667
0
2668
0
  if (!std::equal(a.begin(), a.end() - 1, b.begin()))
2669
0
    return Unknown;
2670
0
2671
0
  if (a.back() == b.back())
2672
0
    return Equal;
2673
0
2674
0
  return a.back() < b.back() ? Less : Greater;
2675
0
}
2676
} // namespace
2677
2678
// Returns constraints that are common to both A & B.
2679
static void getCommonConstraints(const FlatAffineConstraints &A,
2680
                                 const FlatAffineConstraints &B,
2681
0
                                 FlatAffineConstraints &C) {
2682
0
  C.reset(A.getNumDimIds(), A.getNumSymbolIds(), A.getNumLocalIds());
2683
0
  // A naive O(n^2) check should be enough here given the input sizes.
2684
0
  for (unsigned r = 0, e = A.getNumInequalities(); r < e; ++r) {
2685
0
    for (unsigned s = 0, f = B.getNumInequalities(); s < f; ++s) {
2686
0
      if (A.getInequality(r) == B.getInequality(s)) {
2687
0
        C.addInequality(A.getInequality(r));
2688
0
        break;
2689
0
      }
2690
0
    }
2691
0
  }
2692
0
  for (unsigned r = 0, e = A.getNumEqualities(); r < e; ++r) {
2693
0
    for (unsigned s = 0, f = B.getNumEqualities(); s < f; ++s) {
2694
0
      if (A.getEquality(r) == B.getEquality(s)) {
2695
0
        C.addEquality(A.getEquality(r));
2696
0
        break;
2697
0
      }
2698
0
    }
2699
0
  }
2700
0
}
2701
2702
// Computes the bounding box with respect to 'other' by finding the min of the
2703
// lower bounds and the max of the upper bounds along each of the dimensions.
2704
LogicalResult
2705
0
FlatAffineConstraints::unionBoundingBox(const FlatAffineConstraints &otherCst) {
2706
0
  assert(otherCst.getNumDimIds() == numDims && "dims mismatch");
2707
0
  assert(otherCst.getIds()
2708
0
             .slice(0, getNumDimIds())
2709
0
             .equals(getIds().slice(0, getNumDimIds())) &&
2710
0
         "dim values mismatch");
2711
0
  assert(otherCst.getNumLocalIds() == 0 && "local ids not supported here");
2712
0
  assert(getNumLocalIds() == 0 && "local ids not supported yet here");
2713
0
2714
0
  // Align `other` to this.
2715
0
  Optional<FlatAffineConstraints> otherCopy;
2716
0
  if (!areIdsAligned(*this, otherCst)) {
2717
0
    otherCopy.emplace(FlatAffineConstraints(otherCst));
2718
0
    mergeAndAlignIds(/*offset=*/numDims, this, &otherCopy.getValue());
2719
0
  }
2720
0
2721
0
  const auto &otherAligned = otherCopy ? *otherCopy : otherCst;
2722
0
2723
0
  // Get the constraints common to both systems; these will be added as is to
2724
0
  // the union.
2725
0
  FlatAffineConstraints commonCst;
2726
0
  getCommonConstraints(*this, otherAligned, commonCst);
2727
0
2728
0
  std::vector<SmallVector<int64_t, 8>> boundingLbs;
2729
0
  std::vector<SmallVector<int64_t, 8>> boundingUbs;
2730
0
  boundingLbs.reserve(2 * getNumDimIds());
2731
0
  boundingUbs.reserve(2 * getNumDimIds());
2732
0
2733
0
  // To hold lower and upper bounds for each dimension.
2734
0
  SmallVector<int64_t, 4> lb, otherLb, ub, otherUb;
2735
0
  // To compute min of lower bounds and max of upper bounds for each dimension.
2736
0
  SmallVector<int64_t, 4> minLb(getNumSymbolIds() + 1);
2737
0
  SmallVector<int64_t, 4> maxUb(getNumSymbolIds() + 1);
2738
0
  // To compute final new lower and upper bounds for the union.
2739
0
  SmallVector<int64_t, 8> newLb(getNumCols()), newUb(getNumCols());
2740
0
2741
0
  int64_t lbFloorDivisor, otherLbFloorDivisor;
2742
0
  for (unsigned d = 0, e = getNumDimIds(); d < e; ++d) {
2743
0
    auto extent = getConstantBoundOnDimSize(d, &lb, &lbFloorDivisor, &ub);
2744
0
    if (!extent.hasValue())
2745
0
      // TODO(bondhugula): symbolic extents when necessary.
2746
0
      // TODO(bondhugula): handle union if a dimension is unbounded.
2747
0
      return failure();
2748
0
2749
0
    auto otherExtent = otherAligned.getConstantBoundOnDimSize(
2750
0
        d, &otherLb, &otherLbFloorDivisor, &otherUb);
2751
0
    if (!otherExtent.hasValue() || lbFloorDivisor != otherLbFloorDivisor)
2752
0
      // TODO(bondhugula): symbolic extents when necessary.
2753
0
      return failure();
2754
0
2755
0
    assert(lbFloorDivisor > 0 && "divisor always expected to be positive");
2756
0
2757
0
    auto res = compareBounds(lb, otherLb);
2758
0
    // Identify min.
2759
0
    if (res == BoundCmpResult::Less || res == BoundCmpResult::Equal) {
2760
0
      minLb = lb;
2761
0
      // Since the divisor is for a floordiv, we need to convert to ceildiv,
2762
0
      // i.e., i >= expr floordiv div <=> i >= (expr - div + 1) ceildiv div <=>
2763
0
      // div * i >= expr - div + 1.
2764
0
      minLb.back() -= lbFloorDivisor - 1;
2765
0
    } else if (res == BoundCmpResult::Greater) {
2766
0
      minLb = otherLb;
2767
0
      minLb.back() -= otherLbFloorDivisor - 1;
2768
0
    } else {
2769
0
      // Uncomparable - check for constant lower/upper bounds.
2770
0
      auto constLb = getConstantLowerBound(d);
2771
0
      auto constOtherLb = otherAligned.getConstantLowerBound(d);
2772
0
      if (!constLb.hasValue() || !constOtherLb.hasValue())
2773
0
        return failure();
2774
0
      std::fill(minLb.begin(), minLb.end(), 0);
2775
0
      minLb.back() = std::min(constLb.getValue(), constOtherLb.getValue());
2776
0
    }
2777
0
2778
0
    // Do the same for ub's but max of upper bounds. Identify max.
2779
0
    auto uRes = compareBounds(ub, otherUb);
2780
0
    if (uRes == BoundCmpResult::Greater || uRes == BoundCmpResult::Equal) {
2781
0
      maxUb = ub;
2782
0
    } else if (uRes == BoundCmpResult::Less) {
2783
0
      maxUb = otherUb;
2784
0
    } else {
2785
0
      // Uncomparable - check for constant lower/upper bounds.
2786
0
      auto constUb = getConstantUpperBound(d);
2787
0
      auto constOtherUb = otherAligned.getConstantUpperBound(d);
2788
0
      if (!constUb.hasValue() || !constOtherUb.hasValue())
2789
0
        return failure();
2790
0
      std::fill(maxUb.begin(), maxUb.end(), 0);
2791
0
      maxUb.back() = std::max(constUb.getValue(), constOtherUb.getValue());
2792
0
    }
2793
0
2794
0
    std::fill(newLb.begin(), newLb.end(), 0);
2795
0
    std::fill(newUb.begin(), newUb.end(), 0);
2796
0
2797
0
    // The divisor for lb, ub, otherLb, otherUb at this point is lbDivisor,
2798
0
    // and so it's the divisor for newLb and newUb as well.
2799
0
    newLb[d] = lbFloorDivisor;
2800
0
    newUb[d] = -lbFloorDivisor;
2801
0
    // Copy over the symbolic part + constant term.
2802
0
    std::copy(minLb.begin(), minLb.end(), newLb.begin() + getNumDimIds());
2803
0
    std::transform(newLb.begin() + getNumDimIds(), newLb.end(),
2804
0
                   newLb.begin() + getNumDimIds(), std::negate<int64_t>());
2805
0
    std::copy(maxUb.begin(), maxUb.end(), newUb.begin() + getNumDimIds());
2806
0
2807
0
    boundingLbs.push_back(newLb);
2808
0
    boundingUbs.push_back(newUb);
2809
0
  }
2810
0
2811
0
  // Clear all constraints and add the lower/upper bounds for the bounding box.
2812
0
  clearConstraints();
2813
0
  for (unsigned d = 0, e = getNumDimIds(); d < e; ++d) {
2814
0
    addInequality(boundingLbs[d]);
2815
0
    addInequality(boundingUbs[d]);
2816
0
  }
2817
0
2818
0
  // Add the constraints that were common to both systems.
2819
0
  append(commonCst);
2820
0
  removeTrivialRedundancy();
2821
0
2822
0
  // TODO(mlir-team): copy over pure symbolic constraints from this and 'other'
2823
0
  // over to the union (since the above are just the union along dimensions); we
2824
0
  // shouldn't be discarding any other constraints on the symbols.
2825
0
2826
0
  return success();
2827
0
}
2828
2829
/// Compute an explicit representation for local vars. For all systems coming
2830
/// from MLIR integer sets, maps, or expressions where local vars were
2831
/// introduced to model floordivs and mods, this always succeeds.
2832
static LogicalResult computeLocalVars(const FlatAffineConstraints &cst,
2833
                                      SmallVectorImpl<AffineExpr> &memo,
2834
0
                                      MLIRContext *context) {
2835
0
  unsigned numDims = cst.getNumDimIds();
2836
0
  unsigned numSyms = cst.getNumSymbolIds();
2837
0
2838
0
  // Initialize dimensional and symbolic identifiers.
2839
0
  for (unsigned i = 0; i < numDims; i++)
2840
0
    memo[i] = getAffineDimExpr(i, context);
2841
0
  for (unsigned i = numDims, e = numDims + numSyms; i < e; i++)
2842
0
    memo[i] = getAffineSymbolExpr(i - numDims, context);
2843
0
2844
0
  bool changed;
2845
0
  do {
2846
0
    // Each time `changed` is true at the end of this iteration, one or more
2847
0
    // local vars would have been detected as floordivs and set in memo; so the
2848
0
    // number of null entries in memo[...] strictly reduces; so this converges.
2849
0
    changed = false;
2850
0
    for (unsigned i = 0, e = cst.getNumLocalIds(); i < e; ++i)
2851
0
      if (!memo[numDims + numSyms + i] &&
2852
0
          detectAsFloorDiv(cst, /*pos=*/numDims + numSyms + i, context, memo))
2853
0
        changed = true;
2854
0
  } while (changed);
2855
0
2856
0
  ArrayRef<AffineExpr> localExprs =
2857
0
      ArrayRef<AffineExpr>(memo).take_back(cst.getNumLocalIds());
2858
0
  return success(
2859
0
      llvm::all_of(localExprs, [](AffineExpr expr) { return expr; }));
2860
0
}
2861
2862
void FlatAffineConstraints::getIneqAsAffineValueMap(
2863
    unsigned pos, unsigned ineqPos, AffineValueMap &vmap,
2864
    MLIRContext *context) const {
2865
  unsigned numDims = getNumDimIds();
2866
  unsigned numSyms = getNumSymbolIds();
2867
2868
  assert(pos < numDims && "invalid position");
2869
  assert(ineqPos < getNumInequalities() && "invalid inequality position");
2870
2871
  // Get expressions for local vars.
2872
  SmallVector<AffineExpr, 8> memo(getNumIds(), AffineExpr());
2873
  if (failed(computeLocalVars(*this, memo, context)))
2874
    assert(false &&
2875
           "one or more local exprs do not have an explicit representation");
2876
  auto localExprs = ArrayRef<AffineExpr>(memo).take_back(getNumLocalIds());
2877
2878
  // Compute the AffineExpr lower/upper bound for this inequality.
2879
  ArrayRef<int64_t> inequality = getInequality(ineqPos);
2880
  SmallVector<int64_t, 8> bound;
2881
  bound.reserve(getNumCols() - 1);
2882
  // Everything other than the coefficient at `pos`.
2883
  bound.append(inequality.begin(), inequality.begin() + pos);
2884
  bound.append(inequality.begin() + pos + 1, inequality.end());
2885
2886
  if (inequality[pos] > 0)
2887
    // Lower bound.
2888
    std::transform(bound.begin(), bound.end(), bound.begin(),
2889
                   std::negate<int64_t>());
2890
  else
2891
    // Upper bound (which is exclusive).
2892
    bound.back() += 1;
2893
2894
  // Convert to AffineExpr (tree) form.
2895
  auto boundExpr = getAffineExprFromFlatForm(bound, numDims - 1, numSyms,
2896
                                             localExprs, context);
2897
2898
  // Get the values to bind to this affine expr (all dims and symbols).
2899
  SmallVector<Value, 4> operands;
2900
  getIdValues(0, pos, &operands);
2901
  SmallVector<Value, 4> trailingOperands;
2902
  getIdValues(pos + 1, getNumDimAndSymbolIds(), &trailingOperands);
2903
  operands.append(trailingOperands.begin(), trailingOperands.end());
2904
  vmap.reset(AffineMap::get(numDims - 1, numSyms, boundExpr), operands);
2905
}
2906
2907
/// Returns true if the pos^th column is all zero for both inequalities and
2908
/// equalities..
2909
0
static bool isColZero(const FlatAffineConstraints &cst, unsigned pos) {
2910
0
  unsigned rowPos;
2911
0
  return !findConstraintWithNonZeroAt(cst, pos, /*isEq=*/false, &rowPos) &&
2912
0
         !findConstraintWithNonZeroAt(cst, pos, /*isEq=*/true, &rowPos);
2913
0
}
2914
2915
0
IntegerSet FlatAffineConstraints::getAsIntegerSet(MLIRContext *context) const {
  // With no constraints at all, return the universal set (always true):
  // 0 == 0.
  if (getNumConstraints() == 0)
    return IntegerSet::get(getNumDimIds(), getNumSymbolIds(),
                           getAffineConstantExpr(/*constant=*/0, context),
                           /*eqFlags=*/true);

  // One expression per identifier; local ones get filled in below.
  SmallVector<AffineExpr, 8> memo(getNumIds(), AffineExpr());

  if (failed(computeLocalVars(*this, memo, context))) {
    // A local identifier without an explicit representation is acceptable
    // only if its column is zero in every constraint; otherwise give up and
    // return a default-constructed IntegerSet.
    const unsigned idEnd = getNumIds();
    for (unsigned id = getNumDimAndSymbolIds(); id < idEnd; ++id) {
      if (memo[id] || isColZero(*this, /*pos=*/id))
        continue;
      LLVM_DEBUG(llvm::dbgs() << "one or more local exprs do not have an "
                                 "explicit representation");
      return IntegerSet();
    }
  }

  ArrayRef<AffineExpr> locals =
      ArrayRef<AffineExpr>(memo).take_back(getNumLocalIds());

  const unsigned dimCount = getNumDimIds();
  const unsigned symCount = getNumSymbolIds();
  const unsigned eqCount = getNumEqualities();
  const unsigned ineqCount = getNumInequalities();

  // Equalities come first, followed by the inequalities; the flags mark which
  // entries are equalities.
  SmallVector<bool, 16> eqFlags(getNumConstraints(), false);
  std::fill(eqFlags.begin(), eqFlags.begin() + eqCount, true);

  SmallVector<AffineExpr, 8> exprs;
  exprs.reserve(getNumConstraints());

  for (unsigned row = 0; row < eqCount; ++row)
    exprs.push_back(getAffineExprFromFlatForm(getEquality(row), dimCount,
                                              symCount, locals, context));
  for (unsigned row = 0; row < ineqCount; ++row)
    exprs.push_back(getAffineExprFromFlatForm(getInequality(row), dimCount,
                                              symCount, locals, context));

  return IntegerSet::get(dimCount, symCount, exprs, eqFlags);
}
2959
2960
/// Find positions of inequalities and equalities that do not have a coefficient
2961
/// for [pos, pos + num) identifiers.
2962
static void getIndependentConstraints(const FlatAffineConstraints &cst,
2963
                                      unsigned pos, unsigned num,
2964
                                      SmallVectorImpl<unsigned> &nbIneqIndices,
2965
0
                                      SmallVectorImpl<unsigned> &nbEqIndices) {
2966
0
  assert(pos < cst.getNumIds() && "invalid start position");
2967
0
  assert(pos + num <= cst.getNumIds() && "invalid limit");
2968
0
2969
0
  for (unsigned r = 0, e = cst.getNumInequalities(); r < e; r++) {
2970
0
    // The bounds are to be independent of [offset, offset + num) columns.
2971
0
    unsigned c;
2972
0
    for (c = pos; c < pos + num; ++c) {
2973
0
      if (cst.atIneq(r, c) != 0)
2974
0
        break;
2975
0
    }
2976
0
    if (c == pos + num)
2977
0
      nbIneqIndices.push_back(r);
2978
0
  }
2979
0
2980
0
  for (unsigned r = 0, e = cst.getNumEqualities(); r < e; r++) {
2981
0
    // The bounds are to be independent of [offset, offset + num) columns.
2982
0
    unsigned c;
2983
0
    for (c = pos; c < pos + num; ++c) {
2984
0
      if (cst.atEq(r, c) != 0)
2985
0
        break;
2986
0
    }
2987
0
    if (c == pos + num)
2988
0
      nbEqIndices.push_back(r);
2989
0
  }
2990
0
}
2991
2992
void FlatAffineConstraints::removeIndependentConstraints(unsigned pos,
2993
0
                                                         unsigned num) {
2994
0
  assert(pos + num <= getNumIds() && "invalid range");
2995
0
2996
0
  // Remove constraints that are independent of these identifiers.
2997
0
  SmallVector<unsigned, 4> nbIneqIndices, nbEqIndices;
2998
0
  getIndependentConstraints(*this, /*pos=*/0, num, nbIneqIndices, nbEqIndices);
2999
0
3000
0
  // Iterate in reverse so that indices don't have to be updated.
3001
0
  // TODO: This method can be made more efficient (because removal of each
3002
0
  // inequality leads to much shifting/copying in the underlying buffer).
3003
0
  for (auto nbIndex : llvm::reverse(nbIneqIndices))
3004
0
    removeInequality(nbIndex);
3005
0
  for (auto nbIndex : llvm::reverse(nbEqIndices))
3006
0
    removeEquality(nbIndex);
3007
0
}