Coverage Report

Created: 2020-06-26 05:44

/home/arjun/llvm-project/mlir/lib/IR/SymbolTable.cpp
Line
Count
Source (jump to first uncovered line)
1
//===- SymbolTable.cpp - MLIR Symbol Table Class --------------------------===//
2
//
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
// See https://llvm.org/LICENSE.txt for license information.
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
//
7
//===----------------------------------------------------------------------===//
8
9
#include "mlir/IR/SymbolTable.h"
10
#include "llvm/ADT/SetVector.h"
11
#include "llvm/ADT/SmallPtrSet.h"
12
#include "llvm/ADT/SmallString.h"
13
#include "llvm/ADT/StringSwitch.h"
14
15
using namespace mlir;
16
17
/// Return true if the given operation is unknown and may potentially define a
18
/// symbol table.
19
0
static bool isPotentiallyUnknownSymbolTable(Operation *op) {
20
0
  return !op->getDialect() && op->getNumRegions() == 1;
21
0
}
22
23
/// Returns the string name of the given symbol, or None if this is not a
24
/// symbol.
25
0
static Optional<StringRef> getNameIfSymbol(Operation *symbol) {
26
0
  auto nameAttr =
27
0
      symbol->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
28
0
  return nameAttr ? nameAttr.getValue() : Optional<StringRef>();
29
0
}
30
31
/// Computes the nested symbol reference attributes for the symbol 'symbolName'
/// that are usable within the symbol table operations from 'symbol' as far up
/// as the given operation 'within', where 'within' is an ancestor of 'symbol'.
///
/// References are appended to 'results' innermost first:
///  * results[0] is the flat reference `@symbolName`, used inside the direct
///    parent of 'symbol';
///  * results[1] is `@parent::@symbolName`, used inside the parent's parent;
///  * and so on, up to the operation directly nested in 'within'.
///
/// Returns success if all references up to 'within' could be computed. Returns
/// failure as soon as an intermediate parent is not a symbol table or has no
/// symbol name. In that case 'results' still holds the references pushed
/// before the failure.
static LogicalResult
collectValidReferencesFor(Operation *symbol, StringRef symbolName,
                          Operation *within,
                          SmallVectorImpl<SymbolRefAttr> &results) {
  assert(within->isAncestor(symbol) && "expected 'within' to be an ancestor");
  MLIRContext *ctx = symbol->getContext();

  // The flat reference is how 'symbol' is named inside its direct parent.
  auto leafRef = FlatSymbolRefAttr::get(symbolName, ctx);
  results.push_back(leafRef);

  // Early exit for when 'within' is the parent of 'symbol'.
  Operation *symbolTableOp = symbol->getParentOp();
  if (within == symbolTableOp)
    return success();

  // Collect references until 'symbolTableOp' reaches 'within'. 'nestedRefs'
  // holds the chain of references nested under the current 'symbolTableOp'.
  // The chain grows at the front as the walk moves up one parent at a time.
  SmallVector<FlatSymbolRefAttr, 1> nestedRefs(1, leafRef);
  do {
    // Each parent of 'symbol' should define a symbol table.
    if (!symbolTableOp->hasTrait<OpTrait::SymbolTable>())
      return failure();
    // Each parent of 'symbol' should also be a symbol.
    Optional<StringRef> symbolTableName = getNameIfSymbol(symbolTableOp);
    if (!symbolTableName)
      return failure();
    // Reference rooted at the current table, usable inside that table's
    // parent.
    results.push_back(SymbolRefAttr::get(*symbolTableName, nestedRefs, ctx));

    symbolTableOp = symbolTableOp->getParentOp();
    if (symbolTableOp == within)
      break;
    // Prepend the current table's name so that the next (outer) iteration
    // produces `@outer::@current::...::@symbolName`.
    nestedRefs.insert(nestedRefs.begin(),
                      FlatSymbolRefAttr::get(*symbolTableName, ctx));
  } while (true);
  return success();
}
70
71
//===----------------------------------------------------------------------===//
72
// SymbolTable
73
//===----------------------------------------------------------------------===//
74
75
/// Build a symbol table with the symbols within the given operation.
76
SymbolTable::SymbolTable(Operation *symbolTableOp)
77
    : symbolTableOp(symbolTableOp) {
78
  assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>() &&
79
         "expected operation to have SymbolTable trait");
80
  assert(symbolTableOp->getNumRegions() == 1 &&
81
         "expected operation to have a single region");
82
  assert(llvm::hasSingleElement(symbolTableOp->getRegion(0)) &&
83
         "expected operation to have a single block");
84
85
  for (auto &op : symbolTableOp->getRegion(0).front()) {
86
    Optional<StringRef> name = getNameIfSymbol(&op);
87
    if (!name)
88
      continue;
89
90
    auto inserted = symbolTable.insert({*name, &op});
91
    (void)inserted;
92
    assert(inserted.second &&
93
           "expected region to contain uniquely named symbol operations");
94
  }
95
}
96
97
/// Look up a symbol with the specified name, returning null if no such name
98
/// exists. Names never include the @ on them.
99
0
Operation *SymbolTable::lookup(StringRef name) const {
100
0
  return symbolTable.lookup(name);
101
0
}
102
103
/// Erase the given symbol from the table.
104
void SymbolTable::erase(Operation *symbol) {
105
  Optional<StringRef> name = getNameIfSymbol(symbol);
106
  assert(name && "expected valid 'name' attribute");
107
  assert(symbol->getParentOp() == symbolTableOp &&
108
         "expected this operation to be inside of the operation with this "
109
         "SymbolTable");
110
111
  auto it = symbolTable.find(*name);
112
  if (it != symbolTable.end() && it->second == symbol) {
113
    symbolTable.erase(it);
114
    symbol->erase();
115
  }
116
}
117
118
/// Insert a new symbol into the table and associated operation, and rename it
119
/// as necessary to avoid collisions.
120
0
void SymbolTable::insert(Operation *symbol, Block::iterator insertPt) {
121
0
  auto &body = symbolTableOp->getRegion(0).front();
122
0
  if (insertPt == Block::iterator() || insertPt == body.end())
123
0
    insertPt = Block::iterator(body.getTerminator());
124
0
125
0
  assert(insertPt->getParentOp() == symbolTableOp &&
126
0
         "expected insertPt to be in the associated module operation");
127
0
128
0
  body.getOperations().insert(insertPt, symbol);
129
0
130
0
  // Add this symbol to the symbol table, uniquing the name if a conflict is
131
0
  // detected.
132
0
  StringRef name = getSymbolName(symbol);
133
0
  if (symbolTable.insert({name, symbol}).second)
134
0
    return;
135
0
  // If a conflict was detected, then the symbol will not have been added to
136
0
  // the symbol table. Try suffixes until we get to a unique name that works.
137
0
  SmallString<128> nameBuffer(name);
138
0
  unsigned originalLength = nameBuffer.size();
139
0
140
0
  // Iteratively try suffixes until we find one that isn't used.
141
0
  do {
142
0
    nameBuffer.resize(originalLength);
143
0
    nameBuffer += '_';
144
0
    nameBuffer += std::to_string(uniquingCounter++);
145
0
  } while (!symbolTable.insert({nameBuffer, symbol}).second);
146
0
  setSymbolName(symbol, nameBuffer);
147
0
}
148
149
/// Returns the name of the given symbol operation.
150
0
StringRef SymbolTable::getSymbolName(Operation *symbol) {
151
0
  Optional<StringRef> name = getNameIfSymbol(symbol);
152
0
  assert(name && "expected valid symbol name");
153
0
  return *name;
154
0
}
155
/// Sets the name of the given symbol operation.
156
0
void SymbolTable::setSymbolName(Operation *symbol, StringRef name) {
157
0
  symbol->setAttr(getSymbolAttrName(),
158
0
                  StringAttr::get(name, symbol->getContext()));
159
0
}
160
161
/// Returns the visibility of the given symbol operation.
162
0
SymbolTable::Visibility SymbolTable::getSymbolVisibility(Operation *symbol) {
163
0
  // If the attribute doesn't exist, assume public.
164
0
  StringAttr vis = symbol->getAttrOfType<StringAttr>(getVisibilityAttrName());
165
0
  if (!vis)
166
0
    return Visibility::Public;
167
0
168
0
  // Otherwise, switch on the string value.
169
0
  return llvm::StringSwitch<Visibility>(vis.getValue())
170
0
      .Case("private", Visibility::Private)
171
0
      .Case("nested", Visibility::Nested)
172
0
      .Case("public", Visibility::Public);
173
0
}
174
/// Sets the visibility of the given symbol operation.
175
0
void SymbolTable::setSymbolVisibility(Operation *symbol, Visibility vis) {
176
0
  MLIRContext *ctx = symbol->getContext();
177
0
178
0
  // If the visibility is public, just drop the attribute as this is the
179
0
  // default.
180
0
  if (vis == Visibility::Public) {
181
0
    symbol->removeAttr(Identifier::get(getVisibilityAttrName(), ctx));
182
0
    return;
183
0
  }
184
0
185
0
  // Otherwise, update the attribute.
186
0
  assert((vis == Visibility::Private || vis == Visibility::Nested) &&
187
0
         "unknown symbol visibility kind");
188
0
189
0
  StringRef visName = vis == Visibility::Private ? "private" : "nested";
190
0
  symbol->setAttr(getVisibilityAttrName(), StringAttr::get(visName, ctx));
191
0
}
192
193
/// Returns the nearest symbol table from a given operation `from`. Returns
194
/// nullptr if no valid parent symbol table could be found.
195
0
Operation *SymbolTable::getNearestSymbolTable(Operation *from) {
196
0
  assert(from && "expected valid operation");
197
0
  if (isPotentiallyUnknownSymbolTable(from))
198
0
    return nullptr;
199
0
200
0
  while (!from->hasTrait<OpTrait::SymbolTable>()) {
201
0
    from = from->getParentOp();
202
0
203
0
    // Check that this is a valid op and isn't an unknown symbol table.
204
0
    if (!from || isPotentiallyUnknownSymbolTable(from))
205
0
      return nullptr;
206
0
  }
207
0
  return from;
208
0
}
209
210
/// Walks all symbol table operations nested within, and including, `op`. For
211
/// each symbol table operation, the provided callback is invoked with the op
212
/// and a boolean signifying if the symbols within that symbol table can be
213
/// treated as if all uses are visible. `allSymUsesVisible` identifies whether
214
/// all of the symbol uses of symbols within `op` are visible.
215
void SymbolTable::walkSymbolTables(
216
    Operation *op, bool allSymUsesVisible,
217
0
    function_ref<void(Operation *, bool)> callback) {
218
0
  bool isSymbolTable = op->hasTrait<OpTrait::SymbolTable>();
219
0
  if (isSymbolTable) {
220
0
    SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op);
221
0
    allSymUsesVisible |= !symbol || symbol.isPrivate();
222
0
  } else {
223
0
    // Otherwise if 'op' is not a symbol table, any nested symbols are
224
0
    // guaranteed to be hidden.
225
0
    allSymUsesVisible = true;
226
0
  }
227
0
228
0
  for (Region &region : op->getRegions())
229
0
    for (Block &block : region)
230
0
      for (Operation &nestedOp : block)
231
0
        walkSymbolTables(&nestedOp, allSymUsesVisible, callback);
232
0
233
0
  // If 'op' had the symbol table trait, visit it after any nested symbol
234
0
  // tables.
235
0
  if (isSymbolTable)
236
0
    callback(op, allSymUsesVisible);
237
0
}
238
239
/// Returns the operation registered with the given symbol name with the
240
/// regions of 'symbolTableOp'. 'symbolTableOp' is required to be an operation
241
/// with the 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol
242
/// was found.
243
Operation *SymbolTable::lookupSymbolIn(Operation *symbolTableOp,
244
0
                                       StringRef symbol) {
245
0
  assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>());
246
0
247
0
  // Look for a symbol with the given name.
248
0
  for (auto &op : symbolTableOp->getRegion(0).front().without_terminator())
249
0
    if (getNameIfSymbol(&op) == symbol)
250
0
      return &op;
251
0
  return nullptr;
252
0
}
253
/// Resolves 'symbol', following any nested references, starting from
/// 'symbolTableOp'. Returns the operation the full reference names, or nullptr
/// if resolution fails.
Operation *SymbolTable::lookupSymbolIn(Operation *symbolTableOp,
                                       SymbolRefAttr symbol) {
  SmallVector<Operation *, 4> resolved;
  if (succeeded(lookupSymbolIn(symbolTableOp, symbol, resolved)))
    return resolved.back();
  return nullptr;
}
260
261
/// Resolves 'symbol' starting from 'symbolTableOp'. Each operation along the
/// reference chain is appended to 'symbols': the root, each intermediate
/// nested reference, and finally the leaf.
///
/// Returns failure if the root cannot be found, or if an operation that must
/// contain further references is not a symbol table. Also returns failure if
/// the leaf cannot be found; in that case a nullptr is appended to 'symbols'
/// before failing.
LogicalResult
SymbolTable::lookupSymbolIn(Operation *symbolTableOp, SymbolRefAttr symbol,
                            SmallVectorImpl<Operation *> &symbols) {
  assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>());

  Operation *current =
      lookupSymbolIn(symbolTableOp, symbol.getRootReference());
  if (!current)
    return failure();
  symbols.push_back(current);

  // A flat reference resolves to the root symbol itself.
  ArrayRef<FlatSymbolRefAttr> nestedRefs = symbol.getNestedReferences();
  if (nestedRefs.empty())
    return success();

  // Every non-leaf link in the chain, starting with the root, must name a
  // symbol table so that the next reference can be looked up inside it.
  if (!current->hasTrait<OpTrait::SymbolTable>())
    return failure();
  for (FlatSymbolRefAttr ref : nestedRefs.drop_back()) {
    current = lookupSymbolIn(current, ref.getValue());
    if (!current || !current->hasTrait<OpTrait::SymbolTable>())
      return failure();
    symbols.push_back(current);
  }

  Operation *leaf = lookupSymbolIn(current, symbol.getLeafReference());
  symbols.push_back(leaf);
  return success(leaf != nullptr);
}
292
293
/// Returns the operation registered with the given symbol name within the
294
/// closes parent operation with the 'OpTrait::SymbolTable' trait. Returns
295
/// nullptr if no valid symbol was found.
296
Operation *SymbolTable::lookupNearestSymbolFrom(Operation *from,
297
0
                                                StringRef symbol) {
298
0
  Operation *symbolTableOp = getNearestSymbolTable(from);
299
0
  return symbolTableOp ? lookupSymbolIn(symbolTableOp, symbol) : nullptr;
300
0
}
301
/// Resolves the (possibly nested) reference 'symbol' within the closest
/// operation with the 'OpTrait::SymbolTable' trait, searching `from` itself
/// and then its parents. Returns nullptr if no valid symbol was found.
Operation *SymbolTable::lookupNearestSymbolFrom(Operation *from,
                                                SymbolRefAttr symbol) {
  if (Operation *symbolTableOp = getNearestSymbolTable(from))
    return lookupSymbolIn(symbolTableOp, symbol);
  return nullptr;
}
306
307
//===----------------------------------------------------------------------===//
308
// SymbolTable Trait Types
309
//===----------------------------------------------------------------------===//
310
311
0
/// Verifies the structural invariants of an operation with the 'SymbolTable'
/// trait:
///  * it must have exactly one region;
///  * that region must contain exactly one block;
///  * the string symbol name attributes (`SymbolTable::getSymbolAttrName()`)
///    of the operations in that block must be unique.
/// On a duplicate name, an error is emitted on the redefinition with a note
/// pointing at the first definition.
LogicalResult detail::verifySymbolTable(Operation *op) {
  if (op->getNumRegions() != 1)
    return op->emitOpError()
           << "Operations with a 'SymbolTable' must have exactly one region";
  if (!llvm::hasSingleElement(op->getRegion(0)))
    return op->emitOpError()
           << "Operations with a 'SymbolTable' must have exactly one block";

  // Check that all symbols are uniquely named within child regions. Each name
  // maps to the location of its first definition so that a redefinition can
  // point back at it.
  DenseMap<Attribute, Location> nameToOrigLoc;
  for (Block &block : op->getRegion(0)) {
    // 'nestedOp' is deliberately distinct from the parameter 'op' so the
    // operation being verified is never shadowed inside the loop.
    for (Operation &nestedOp : block) {
      // Check for a symbol name attribute.
      auto nameAttr = nestedOp.getAttrOfType<StringAttr>(
          mlir::SymbolTable::getSymbolAttrName());
      if (!nameAttr)
        continue;

      // Try to insert this symbol into the table.
      auto it = nameToOrigLoc.try_emplace(nameAttr, nestedOp.getLoc());
      if (!it.second)
        return nestedOp.emitError()
            .append("redefinition of symbol named '", nameAttr.getValue(), "'")
            .attachNote(it.first->second)
            .append("see existing symbol definition here");
    }
  }
  return success();
}
340
341
0
/// Verifies an operation with the 'Symbol' trait:
///  * it must carry a string symbol name attribute;
///  * if it has a visibility attribute, that attribute must be a string equal
///    to "public", "private", or "nested".
LogicalResult detail::verifySymbol(Operation *op) {
  StringRef nameAttrName = mlir::SymbolTable::getSymbolAttrName();
  if (!op->getAttrOfType<StringAttr>(nameAttrName))
    return op->emitOpError() << "requires string attribute '" << nameAttrName
                             << "'";

  // The visibility attribute is optional.
  StringRef visAttrName = mlir::SymbolTable::getVisibilityAttrName();
  Attribute vis = op->getAttr(visAttrName);
  if (!vis)
    return success();

  auto visStrAttr = vis.dyn_cast<StringAttr>();
  if (!visStrAttr)
    return op->emitOpError() << "requires visibility attribute '"
                             << visAttrName
                             << "' to be a string attribute, but got " << vis;

  StringRef visValue = visStrAttr.getValue();
  bool isKnownVisibility =
      visValue == "public" || visValue == "private" || visValue == "nested";
  if (!isKnownVisibility)
    return op->emitOpError()
           << "visibility expected to be one of [\"public\", \"private\", "
              "\"nested\"], but got "
           << visStrAttr;
  return success();
}
364
365
//===----------------------------------------------------------------------===//
366
// Symbol Use Lists
367
//===----------------------------------------------------------------------===//
368
369
/// Walk all of the symbol references within the given operation, invoking the
/// provided callback for each found use. The callback takes as arguments: the
/// use of the symbol, and the nested access chain to the attribute within the
/// operation dictionary. An access chain is a set of indices into nested
/// container attributes. For example, a symbol use in an attribute dictionary
/// that looks like the following:
///
///    {use = [{other_attr, @symbol}]}
///
/// May have the following access chain:
///
///     [0, 0, 1]
///
/// The walk is iterative. Nested ArrayAttr / DictionaryAttr containers are
/// pushed onto an explicit worklist instead of being recursed into. The access
/// chain passed to the callback refers to a local vector that is mutated as
/// the walk proceeds, so it is only meaningful during that call. Returns
/// interrupt as soon as the callback interrupts, and advance otherwise.
static WalkResult walkSymbolRefs(
    Operation *op,
    function_ref<WalkResult(SymbolTable::SymbolUse, ArrayRef<int>)> callback) {
  // Check to see if the operation has any attributes.
  if (op->getMutableAttrDict().empty())
    return WalkResult::advance();
  DictionaryAttr attrDict = op->getAttrDictionary();

  // A worklist of a container attribute and the current index into the held
  // attribute list. 'attrWorklist' and 'curAccessChain' are kept in lockstep:
  // curAccessChain[i] is the position being visited within attrWorklist[i].
  // Each index starts at -1 and is pre-incremented before its container is
  // (re)processed.
  SmallVector<Attribute, 1> attrWorklist(1, attrDict);
  SmallVector<int, 1> curAccessChain(1, /*Value=*/-1);

  // Process the symbol references within the given nested attribute range,
  // resuming at position 'index'. When a nested container is found, it is
  // pushed onto the worklist and processing of the current container is
  // suspended, leaving 'index' on the nested container. The outer loop later
  // resumes it, and the pre-increment there skips past the finished container.
  // Once the range is exhausted, the container is popped from the worklist.
  auto processAttrs = [&](int &index, auto attrRange) -> WalkResult {
    for (Attribute attr : llvm::drop_begin(attrRange, index)) {
      // Check for a nested container attribute, these will also need to be
      // walked. Note: 'index' aliases an element of 'curAccessChain', so it
      // must not be touched after the push_back below, which may reallocate.
      // Returning immediately guarantees that.
      if (attr.isa<ArrayAttr>() || attr.isa<DictionaryAttr>()) {
        attrWorklist.push_back(attr);
        curAccessChain.push_back(-1);
        return WalkResult::advance();
      }

      // Invoke the provided callback if we find a symbol use and check for a
      // requested interrupt.
      if (auto symbolRef = attr.dyn_cast<SymbolRefAttr>())
        if (callback({op, symbolRef}, curAccessChain).wasInterrupted())
          return WalkResult::interrupt();

      // Make sure to keep the index counter in sync.
      ++index;
    }

    // Pop this container attribute from the worklist.
    attrWorklist.pop_back();
    curAccessChain.pop_back();
    return WalkResult::advance();
  };

  WalkResult result = WalkResult::advance();
  do {
    Attribute attr = attrWorklist.back();
    int &index = curAccessChain.back();
    ++index;

    // Process the given attribute, which is guaranteed to be a container.
    if (auto dict = attr.dyn_cast<DictionaryAttr>())
      result = processAttrs(index, make_second_range(dict.getValue()));
    else
      result = processAttrs(index, attr.cast<ArrayAttr>().getValue());
  } while (!attrWorklist.empty() && !result.wasInterrupted());
  return result;
}
436
437
/// Walk all of the uses, for any symbol, that are nested within the given
438
/// regions, invoking the provided callback for each. This does not traverse
439
/// into any nested symbol tables.
440
static Optional<WalkResult> walkSymbolUses(
441
    MutableArrayRef<Region> regions,
442
0
    function_ref<WalkResult(SymbolTable::SymbolUse, ArrayRef<int>)> callback) {
443
0
  SmallVector<Region *, 1> worklist(llvm::make_pointer_range(regions));
444
0
  while (!worklist.empty()) {
445
0
    for (Operation &op : worklist.pop_back_val()->getOps()) {
446
0
      if (walkSymbolRefs(&op, callback).wasInterrupted())
447
0
        return WalkResult::interrupt();
448
0
449
0
      // Check that this isn't a potentially unknown symbol table.
450
0
      if (isPotentiallyUnknownSymbolTable(&op))
451
0
        return llvm::None;
452
0
453
0
      // If this op defines a new symbol table scope, we can't traverse. Any
454
0
      // symbol references nested within 'op' are different semantically.
455
0
      if (!op.hasTrait<OpTrait::SymbolTable>()) {
456
0
        for (Region &region : op.getRegions())
457
0
          worklist.push_back(&region);
458
0
      }
459
0
    }
460
0
  }
461
0
  return WalkResult::advance();
462
0
}
463
/// Walk all of the uses, for any symbol, that are nested within the given
464
/// operation 'from', invoking the provided callback for each. This does not
465
/// traverse into any nested symbol tables.
466
static Optional<WalkResult> walkSymbolUses(
467
    Operation *from,
468
0
    function_ref<WalkResult(SymbolTable::SymbolUse, ArrayRef<int>)> callback) {
469
0
  // If this operation has regions, and it, as well as its dialect, isn't
470
0
  // registered then conservatively fail. The operation may define a
471
0
  // symbol table, so we can't opaquely know if we should traverse to find
472
0
  // nested uses.
473
0
  if (isPotentiallyUnknownSymbolTable(from))
474
0
    return llvm::None;
475
0
476
0
  // Walk the uses on this operation.
477
0
  if (walkSymbolRefs(from, callback).wasInterrupted())
478
0
    return WalkResult::interrupt();
479
0
480
0
  // Only recurse if this operation is not a symbol table. A symbol table
481
0
  // defines a new scope, so we can't walk the attributes from within the symbol
482
0
  // table op.
483
0
  if (!from->hasTrait<OpTrait::SymbolTable>())
484
0
    return walkSymbolUses(from->getRegions(), callback);
485
0
  return WalkResult::advance();
486
0
}
487
488
namespace {
489
/// This class represents a single symbol scope. A symbol scope represents the
490
/// set of operations nested within a symbol table that may reference symbols
491
/// within that table. A symbol scope does not contain the symbol table
492
/// operation itself, just its contained operations. A scope ends at leaf
493
/// operations or another symbol table operation.
494
struct SymbolScope {
495
  /// Walk the symbol uses within this scope, invoking the given callback.
496
  /// This variant is used when the callback type matches that expected by
497
  /// 'walkSymbolUses'.
498
  template <typename CallbackT,
499
            typename std::enable_if_t<!std::is_same<
500
                typename llvm::function_traits<CallbackT>::result_t,
501
                void>::value> * = nullptr>
502
0
  Optional<WalkResult> walk(CallbackT cback) {
503
0
    if (Region *region = limit.dyn_cast<Region *>())
504
0
      return walkSymbolUses(*region, cback);
505
0
    return walkSymbolUses(limit.get<Operation *>(), cback);
506
0
  }
Unexecuted instantiation: SymbolTable.cpp:_ZN12_GLOBAL__N_111SymbolScope4walkIZNS0_4walkIZL17getSymbolUsesImplIN4llvm9StringRefEN4mlir9OperationEENS4_8OptionalINS6_11SymbolTable8UseRangeEEET_PT0_EUlNS9_9SymbolUseEE_LPv0EEENS8_INS6_10WalkResultEEESC_EUlSF_NS4_8ArrayRefIiEEE_LSH_0EEESJ_SC_
Unexecuted instantiation: SymbolTable.cpp:_ZN12_GLOBAL__N_111SymbolScope4walkIZNS0_4walkIZL17getSymbolUsesImplIPN4mlir9OperationES5_EN4llvm8OptionalINS4_11SymbolTable8UseRangeEEET_PT0_EUlNS9_9SymbolUseEE_LPv0EEENS8_INS4_10WalkResultEEESC_EUlSF_NS7_8ArrayRefIiEEE_LSH_0EEESJ_SC_
Unexecuted instantiation: SymbolTable.cpp:_ZN12_GLOBAL__N_111SymbolScope4walkIZNS0_4walkIZL17getSymbolUsesImplIN4llvm9StringRefEN4mlir6RegionEENS4_8OptionalINS6_11SymbolTable8UseRangeEEET_PT0_EUlNS9_9SymbolUseEE_LPv0EEENS8_INS6_10WalkResultEEESC_EUlSF_NS4_8ArrayRefIiEEE_LSH_0EEESJ_SC_
Unexecuted instantiation: SymbolTable.cpp:_ZN12_GLOBAL__N_111SymbolScope4walkIZNS0_4walkIZL17getSymbolUsesImplIPN4mlir9OperationENS4_6RegionEEN4llvm8OptionalINS4_11SymbolTable8UseRangeEEET_PT0_EUlNSA_9SymbolUseEE_LPv0EEENS9_INS4_10WalkResultEEESD_EUlSG_NS8_8ArrayRefIiEEE_LSI_0EEESK_SD_
Unexecuted instantiation: SymbolTable.cpp:_ZN12_GLOBAL__N_111SymbolScope4walkIZL23symbolKnownUseEmptyImplIN4llvm9StringRefEN4mlir9OperationEEbT_PT0_EUlNS5_11SymbolTable9SymbolUseENS3_8ArrayRefIiEEE_LPv0EEENS3_8OptionalINS5_10WalkResultEEES7_
Unexecuted instantiation: SymbolTable.cpp:_ZN12_GLOBAL__N_111SymbolScope4walkIZL23symbolKnownUseEmptyImplIPN4mlir9OperationES4_EbT_PT0_EUlNS3_11SymbolTable9SymbolUseEN4llvm8ArrayRefIiEEE_LPv0EEENSB_8OptionalINS3_10WalkResultEEES6_
Unexecuted instantiation: SymbolTable.cpp:_ZN12_GLOBAL__N_111SymbolScope4walkIZL23symbolKnownUseEmptyImplIN4llvm9StringRefEN4mlir6RegionEEbT_PT0_EUlNS5_11SymbolTable9SymbolUseENS3_8ArrayRefIiEEE_LPv0EEENS3_8OptionalINS5_10WalkResultEEES7_
Unexecuted instantiation: SymbolTable.cpp:_ZN12_GLOBAL__N_111SymbolScope4walkIZL23symbolKnownUseEmptyImplIPN4mlir9OperationENS3_6RegionEEbT_PT0_EUlNS3_11SymbolTable9SymbolUseEN4llvm8ArrayRefIiEEE_LPv0EEENSC_8OptionalINS3_10WalkResultEEES7_
Unexecuted instantiation: SymbolTable.cpp:_ZN12_GLOBAL__N_111SymbolScope4walkIZL24replaceAllSymbolUsesImplIN4llvm9StringRefEN4mlir9OperationEENS5_13LogicalResultET_S4_PT0_EUlNS5_11SymbolTable9SymbolUseENS3_8ArrayRefIiEEE_LPv0EEENS3_8OptionalINS5_10WalkResultEEES8_
Unexecuted instantiation: SymbolTable.cpp:_ZN12_GLOBAL__N_111SymbolScope4walkIZL24replaceAllSymbolUsesImplIPN4mlir9OperationES4_ENS3_13LogicalResultET_N4llvm9StringRefEPT0_EUlNS3_11SymbolTable9SymbolUseENS8_8ArrayRefIiEEE_LPv0EEENS8_8OptionalINS3_10WalkResultEEES7_
Unexecuted instantiation: SymbolTable.cpp:_ZN12_GLOBAL__N_111SymbolScope4walkIZL24replaceAllSymbolUsesImplIN4llvm9StringRefEN4mlir6RegionEENS5_13LogicalResultET_S4_PT0_EUlNS5_11SymbolTable9SymbolUseENS3_8ArrayRefIiEEE_LPv0EEENS3_8OptionalINS5_10WalkResultEEES8_
Unexecuted instantiation: SymbolTable.cpp:_ZN12_GLOBAL__N_111SymbolScope4walkIZL24replaceAllSymbolUsesImplIPN4mlir9OperationENS3_6RegionEENS3_13LogicalResultET_N4llvm9StringRefEPT0_EUlNS3_11SymbolTable9SymbolUseENS9_8ArrayRefIiEEE_LPv0EEENS9_8OptionalINS3_10WalkResultEEES8_
507
  /// This variant is used when the callback type matches a stripped down type:
508
  /// void(SymbolTable::SymbolUse use)
509
  template <typename CallbackT,
510
            typename std::enable_if_t<std::is_same<
511
                typename llvm::function_traits<CallbackT>::result_t,
512
                void>::value> * = nullptr>
513
0
  Optional<WalkResult> walk(CallbackT cback) {
514
0
    return walk([=](SymbolTable::SymbolUse use, ArrayRef<int>) {
515
0
      return cback(use), WalkResult::advance();
516
0
    });
Unexecuted instantiation: SymbolTable.cpp:_ZZN12_GLOBAL__N_111SymbolScope4walkIZL17getSymbolUsesImplIN4llvm9StringRefEN4mlir9OperationEENS3_8OptionalINS5_11SymbolTable8UseRangeEEET_PT0_EUlNS8_9SymbolUseEE_LPv0EEENS7_INS5_10WalkResultEEESB_ENKUlSE_NS3_8ArrayRefIiEEE_clESE_SK_
Unexecuted instantiation: SymbolTable.cpp:_ZZN12_GLOBAL__N_111SymbolScope4walkIZL17getSymbolUsesImplIPN4mlir9OperationES4_EN4llvm8OptionalINS3_11SymbolTable8UseRangeEEET_PT0_EUlNS8_9SymbolUseEE_LPv0EEENS7_INS3_10WalkResultEEESB_ENKUlSE_NS6_8ArrayRefIiEEE_clESE_SK_
Unexecuted instantiation: SymbolTable.cpp:_ZZN12_GLOBAL__N_111SymbolScope4walkIZL17getSymbolUsesImplIN4llvm9StringRefEN4mlir6RegionEENS3_8OptionalINS5_11SymbolTable8UseRangeEEET_PT0_EUlNS8_9SymbolUseEE_LPv0EEENS7_INS5_10WalkResultEEESB_ENKUlSE_NS3_8ArrayRefIiEEE_clESE_SK_
Unexecuted instantiation: SymbolTable.cpp:_ZZN12_GLOBAL__N_111SymbolScope4walkIZL17getSymbolUsesImplIPN4mlir9OperationENS3_6RegionEEN4llvm8OptionalINS3_11SymbolTable8UseRangeEEET_PT0_EUlNS9_9SymbolUseEE_LPv0EEENS8_INS3_10WalkResultEEESC_ENKUlSF_NS7_8ArrayRefIiEEE_clESF_SL_
517
0
  }
Unexecuted instantiation: SymbolTable.cpp:_ZN12_GLOBAL__N_111SymbolScope4walkIZL17getSymbolUsesImplIN4llvm9StringRefEN4mlir9OperationEENS3_8OptionalINS5_11SymbolTable8UseRangeEEET_PT0_EUlNS8_9SymbolUseEE_LPv0EEENS7_INS5_10WalkResultEEESB_
Unexecuted instantiation: SymbolTable.cpp:_ZN12_GLOBAL__N_111SymbolScope4walkIZL17getSymbolUsesImplIPN4mlir9OperationES4_EN4llvm8OptionalINS3_11SymbolTable8UseRangeEEET_PT0_EUlNS8_9SymbolUseEE_LPv0EEENS7_INS3_10WalkResultEEESB_
Unexecuted instantiation: SymbolTable.cpp:_ZN12_GLOBAL__N_111SymbolScope4walkIZL17getSymbolUsesImplIN4llvm9StringRefEN4mlir6RegionEENS3_8OptionalINS5_11SymbolTable8UseRangeEEET_PT0_EUlNS8_9SymbolUseEE_LPv0EEENS7_INS5_10WalkResultEEESB_
Unexecuted instantiation: SymbolTable.cpp:_ZN12_GLOBAL__N_111SymbolScope4walkIZL17getSymbolUsesImplIPN4mlir9OperationENS3_6RegionEEN4llvm8OptionalINS3_11SymbolTable8UseRangeEEET_PT0_EUlNS9_9SymbolUseEE_LPv0EEENS8_INS3_10WalkResultEEESC_
518
519
  /// The representation of the symbol within this scope.
520
  SymbolRefAttr symbol;
521
522
  /// The IR unit representing this scope.
523
  llvm::PointerUnion<Operation *, Region *> limit;
524
};
525
} // end anonymous namespace
526
527
/// Collect all of the symbol scopes from 'symbol' to (inclusive) 'limit'.
/// Each returned scope pairs the reference by which 'symbol' is named inside
/// that scope with the IR unit bounding it:
///  * If 'symbol' is 'limit' or one of its ancestors, a single scope bounded
///    by 'limit' with a flat reference is returned. This happens only when the
///    nearest symbol table above 'limit' is the parent of 'symbol'; otherwise
///    the result is empty.
///  * If the first common ancestor of 'symbol' and 'limit' is 'limit' itself,
///    one scope is returned per reference computed by
///    collectValidReferencesFor. Each is bounded by region 0 of the matching
///    ancestor of 'symbol'. This holds even if not all references could be
///    computed.
///  * Otherwise, a single scope bounded by 'limit' with the outermost computed
///    reference is returned. The result is empty instead if the references
///    could not all be computed.
/// NOTE(review): the common-ancestor search dereferences the parent of
/// 'symbol' without a null check, so 'symbol' is presumably expected to be
/// nested inside some operation.
static SmallVector<SymbolScope, 2> collectSymbolScopes(Operation *symbol,
                                                       Operation *limit) {
  StringRef symName = SymbolTable::getSymbolName(symbol);
  assert(!symbol->hasTrait<OpTrait::SymbolTable>() || symbol != limit);

  // Compute the ancestors of 'limit' (including 'limit' itself).
  llvm::SetVector<Operation *, SmallVector<Operation *, 4>,
                  SmallPtrSet<Operation *, 4>>
      limitAncestors;
  Operation *limitAncestor = limit;
  do {
    // Check to see if 'symbol' is an ancestor of 'limit'.
    if (limitAncestor == symbol) {
      // Check that the nearest symbol table is 'symbol's parent. SymbolRefAttr
      // doesn't support parent references.
      if (SymbolTable::getNearestSymbolTable(limit->getParentOp()) ==
          symbol->getParentOp())
        return {{SymbolRefAttr::get(symName, symbol->getContext()), limit}};
      return {};
    }

    limitAncestors.insert(limitAncestor);
  } while ((limitAncestor = limitAncestor->getParentOp()));

  // Try to find the first ancestor of 'symbol' that is an ancestor of 'limit'.
  Operation *commonAncestor = symbol->getParentOp();
  do {
    if (limitAncestors.count(commonAncestor))
      break;
  } while ((commonAncestor = commonAncestor->getParentOp()));
  assert(commonAncestor && "'limit' and 'symbol' have no common ancestor");

  // Compute the set of valid nested references for 'symbol' as far up to the
  // common ancestor as possible. references[i] is the name of 'symbol' inside
  // the (i+1)-th ancestor of 'symbol'.
  SmallVector<SymbolRefAttr, 2> references;
  bool collectedAllReferences = succeeded(
      collectValidReferencesFor(symbol, symName, commonAncestor, references));

  // Handle the case where the common ancestor is 'limit'.
  if (commonAncestor == limit) {
    SmallVector<SymbolScope, 2> scopes;

    // Walk each of the ancestors of 'symbol', calling the compute function for
    // each one. 'limitIt' advances in step with 'i', so scope i is bounded by
    // the region of the (i+1)-th ancestor of 'symbol'.
    Operation *limitIt = symbol->getParentOp();
    for (size_t i = 0, e = references.size(); i != e;
         ++i, limitIt = limitIt->getParentOp()) {
      assert(limitIt->hasTrait<OpTrait::SymbolTable>());
      scopes.push_back({references[i], &limitIt->getRegion(0)});
    }
    return scopes;
  }

  // Otherwise, we just need the symbol reference for 'symbol' that will be
  // used within 'limit'. This is the last reference in the list we computed
  // above if we were able to collect all references.
  if (!collectedAllReferences)
    return {};
  return {{references.back(), limit}};
}
588
/// Collects the scopes needed to find uses of 'symbol' within the region
/// 'limit'. The scopes are first computed up to the operation that owns
/// 'limit'; the entry for that operation (the last one) is then narrowed from
/// the whole operation down to the requested region.
static SmallVector<SymbolScope, 2> collectSymbolScopes(Operation *symbol,
                                                       Region *limit) {
  SmallVector<SymbolScope, 2> scopes =
      collectSymbolScopes(symbol, limit->getParentOp());

  // Nothing to narrow when no scopes could be computed.
  if (scopes.empty())
    return scopes;

  scopes.back().limit = limit;
  return scopes;
}
598
template <typename IRUnit>
599
static SmallVector<SymbolScope, 1> collectSymbolScopes(StringRef symbol,
600
0
                                                       IRUnit *limit) {
601
0
  return {{SymbolRefAttr::get(symbol, limit->getContext()), limit}};
602
0
}
Unexecuted instantiation: SymbolTable.cpp:_ZL19collectSymbolScopesIN4mlir9OperationEEN4llvm11SmallVectorIN12_GLOBAL__N_111SymbolScopeELj1EEENS2_9StringRefEPT_
Unexecuted instantiation: SymbolTable.cpp:_ZL19collectSymbolScopesIN4mlir6RegionEEN4llvm11SmallVectorIN12_GLOBAL__N_111SymbolScopeELj1EEENS2_9StringRefEPT_
603
604
/// Returns true if the given reference 'SubRef' is a sub reference of the
605
/// reference 'ref', i.e. 'ref' is a further qualified reference.
606
0
static bool isReferencePrefixOf(SymbolRefAttr subRef, SymbolRefAttr ref) {
607
0
  if (ref == subRef)
608
0
    return true;
609
0
610
0
  // If the references are not pointer equal, check to see if `subRef` is a
611
0
  // prefix of `ref`.
612
0
  if (ref.isa<FlatSymbolRefAttr>() ||
613
0
      ref.getRootReference() != subRef.getRootReference())
614
0
    return false;
615
0
616
0
  auto refLeafs = ref.getNestedReferences();
617
0
  auto subRefLeafs = subRef.getNestedReferences();
618
0
  return subRefLeafs.size() < refLeafs.size() &&
619
0
         subRefLeafs == refLeafs.take_front(subRefLeafs.size());
620
0
}
621
622
//===----------------------------------------------------------------------===//
623
// SymbolTable::getSymbolUses
624
625
/// The implementation of SymbolTable::getSymbolUses below.
626
template <typename FromT>
627
0
static Optional<SymbolTable::UseRange> getSymbolUsesImpl(FromT from) {
628
0
  std::vector<SymbolTable::SymbolUse> uses;
629
0
  auto walkFn = [&](SymbolTable::SymbolUse symbolUse, ArrayRef<int>) {
630
0
    uses.push_back(symbolUse);
631
0
    return WalkResult::advance();
632
0
  };
Unexecuted instantiation: SymbolTable.cpp:_ZZL17getSymbolUsesImplIPN4mlir9OperationEEN4llvm8OptionalINS0_11SymbolTable8UseRangeEEET_ENKUlNS5_9SymbolUseENS3_8ArrayRefIiEEE_clES9_SB_
Unexecuted instantiation: SymbolTable.cpp:_ZZL17getSymbolUsesImplIN4llvm15MutableArrayRefIN4mlir6RegionEEEENS0_8OptionalINS2_11SymbolTable8UseRangeEEET_ENKUlNS6_9SymbolUseENS0_8ArrayRefIiEEE_clESA_SC_
633
0
  auto result = walkSymbolUses(from, walkFn);
634
0
  return result ? Optional<SymbolTable::UseRange>(std::move(uses)) : llvm::None;
635
0
}
Unexecuted instantiation: SymbolTable.cpp:_ZL17getSymbolUsesImplIPN4mlir9OperationEEN4llvm8OptionalINS0_11SymbolTable8UseRangeEEET_
Unexecuted instantiation: SymbolTable.cpp:_ZL17getSymbolUsesImplIN4llvm15MutableArrayRefIN4mlir6RegionEEEENS0_8OptionalINS2_11SymbolTable8UseRangeEEET_
636
637
/// Get an iterator range for all of the uses, for any symbol, that are nested
638
/// within the given operation 'from'. This does not traverse into any nested
639
/// symbol tables, and will also only return uses on 'from' if it does not
640
/// also define a symbol table. This is because we treat the region as the
641
/// boundary of the symbol table, and not the op itself. This function returns
642
/// None if there are any unknown operations that may potentially be symbol
643
/// tables.
644
0
auto SymbolTable::getSymbolUses(Operation *from) -> Optional<UseRange> {
645
0
  return getSymbolUsesImpl(from);
646
0
}
647
0
auto SymbolTable::getSymbolUses(Region *from) -> Optional<UseRange> {
648
0
  return getSymbolUsesImpl(MutableArrayRef<Region>(*from));
649
0
}
650
651
//===----------------------------------------------------------------------===//
652
// SymbolTable::getSymbolUses
653
654
/// The implementation of SymbolTable::getSymbolUses below.
655
template <typename SymbolT, typename IRUnitT>
656
static Optional<SymbolTable::UseRange> getSymbolUsesImpl(SymbolT symbol,
657
0
                                                         IRUnitT *limit) {
658
0
  std::vector<SymbolTable::SymbolUse> uses;
659
0
  for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) {
660
0
    if (!scope.walk([&](SymbolTable::SymbolUse symbolUse) {
661
0
          if (isReferencePrefixOf(scope.symbol, symbolUse.getSymbolRef()))
662
0
            uses.push_back(symbolUse);
663
0
        }))
Unexecuted instantiation: SymbolTable.cpp:_ZZL17getSymbolUsesImplIN4llvm9StringRefEN4mlir9OperationEENS0_8OptionalINS2_11SymbolTable8UseRangeEEET_PT0_ENKUlNS5_9SymbolUseEE_clESB_
Unexecuted instantiation: SymbolTable.cpp:_ZZL17getSymbolUsesImplIPN4mlir9OperationES1_EN4llvm8OptionalINS0_11SymbolTable8UseRangeEEET_PT0_ENKUlNS5_9SymbolUseEE_clESB_
Unexecuted instantiation: SymbolTable.cpp:_ZZL17getSymbolUsesImplIN4llvm9StringRefEN4mlir6RegionEENS0_8OptionalINS2_11SymbolTable8UseRangeEEET_PT0_ENKUlNS5_9SymbolUseEE_clESB_
Unexecuted instantiation: SymbolTable.cpp:_ZZL17getSymbolUsesImplIPN4mlir9OperationENS0_6RegionEEN4llvm8OptionalINS0_11SymbolTable8UseRangeEEET_PT0_ENKUlNS6_9SymbolUseEE_clESC_
664
0
      return llvm::None;
665
0
  }
666
0
  return SymbolTable::UseRange(std::move(uses));
667
0
}
Unexecuted instantiation: SymbolTable.cpp:_ZL17getSymbolUsesImplIN4llvm9StringRefEN4mlir9OperationEENS0_8OptionalINS2_11SymbolTable8UseRangeEEET_PT0_
Unexecuted instantiation: SymbolTable.cpp:_ZL17getSymbolUsesImplIPN4mlir9OperationES1_EN4llvm8OptionalINS0_11SymbolTable8UseRangeEEET_PT0_
Unexecuted instantiation: SymbolTable.cpp:_ZL17getSymbolUsesImplIN4llvm9StringRefEN4mlir6RegionEENS0_8OptionalINS2_11SymbolTable8UseRangeEEET_PT0_
Unexecuted instantiation: SymbolTable.cpp:_ZL17getSymbolUsesImplIPN4mlir9OperationENS0_6RegionEEN4llvm8OptionalINS0_11SymbolTable8UseRangeEEET_PT0_
668
669
/// Get all of the uses of the given symbol that are nested within the given
670
/// operation 'from', invoking the provided callback for each. This does not
671
/// traverse into any nested symbol tables. This function returns None if there
672
/// are any unknown operations that may potentially be symbol tables.
673
auto SymbolTable::getSymbolUses(StringRef symbol, Operation *from)
674
0
    -> Optional<UseRange> {
675
0
  return getSymbolUsesImpl(symbol, from);
676
0
}
677
auto SymbolTable::getSymbolUses(Operation *symbol, Operation *from)
678
0
    -> Optional<UseRange> {
679
0
  return getSymbolUsesImpl(symbol, from);
680
0
}
681
auto SymbolTable::getSymbolUses(StringRef symbol, Region *from)
682
0
    -> Optional<UseRange> {
683
0
  return getSymbolUsesImpl(symbol, from);
684
0
}
685
auto SymbolTable::getSymbolUses(Operation *symbol, Region *from)
686
0
    -> Optional<UseRange> {
687
0
  return getSymbolUsesImpl(symbol, from);
688
0
}
689
690
//===----------------------------------------------------------------------===//
691
// SymbolTable::symbolKnownUseEmpty
692
693
/// The implementation of SymbolTable::symbolKnownUseEmpty below.
694
template <typename SymbolT, typename IRUnitT>
695
0
static bool symbolKnownUseEmptyImpl(SymbolT symbol, IRUnitT *limit) {
696
0
  for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) {
697
0
    // Walk all of the symbol uses looking for a reference to 'symbol'.
698
0
    if (scope.walk([&](SymbolTable::SymbolUse symbolUse, ArrayRef<int>) {
699
0
          return isReferencePrefixOf(scope.symbol, symbolUse.getSymbolRef())
700
0
                     ? WalkResult::interrupt()
701
0
                     : WalkResult::advance();
702
0
        }) != WalkResult::advance())
Unexecuted instantiation: SymbolTable.cpp:_ZZL23symbolKnownUseEmptyImplIN4llvm9StringRefEN4mlir9OperationEEbT_PT0_ENKUlNS2_11SymbolTable9SymbolUseENS0_8ArrayRefIiEEE_clES8_SA_
Unexecuted instantiation: SymbolTable.cpp:_ZZL23symbolKnownUseEmptyImplIPN4mlir9OperationES1_EbT_PT0_ENKUlNS0_11SymbolTable9SymbolUseEN4llvm8ArrayRefIiEEE_clES7_SA_
Unexecuted instantiation: SymbolTable.cpp:_ZZL23symbolKnownUseEmptyImplIN4llvm9StringRefEN4mlir6RegionEEbT_PT0_ENKUlNS2_11SymbolTable9SymbolUseENS0_8ArrayRefIiEEE_clES8_SA_
Unexecuted instantiation: SymbolTable.cpp:_ZZL23symbolKnownUseEmptyImplIPN4mlir9OperationENS0_6RegionEEbT_PT0_ENKUlNS0_11SymbolTable9SymbolUseEN4llvm8ArrayRefIiEEE_clES8_SB_
703
0
      return false;
704
0
  }
705
0
  return true;
706
0
}
Unexecuted instantiation: SymbolTable.cpp:_ZL23symbolKnownUseEmptyImplIN4llvm9StringRefEN4mlir9OperationEEbT_PT0_
Unexecuted instantiation: SymbolTable.cpp:_ZL23symbolKnownUseEmptyImplIPN4mlir9OperationES1_EbT_PT0_
Unexecuted instantiation: SymbolTable.cpp:_ZL23symbolKnownUseEmptyImplIN4llvm9StringRefEN4mlir6RegionEEbT_PT0_
Unexecuted instantiation: SymbolTable.cpp:_ZL23symbolKnownUseEmptyImplIPN4mlir9OperationENS0_6RegionEEbT_PT0_
707
708
/// Return if the given symbol is known to have no uses that are nested within
709
/// the given operation 'from'. This does not traverse into any nested symbol
710
/// tables. This function will also return false if there are any unknown
711
/// operations that may potentially be symbol tables.
712
0
bool SymbolTable::symbolKnownUseEmpty(StringRef symbol, Operation *from) {
713
0
  return symbolKnownUseEmptyImpl(symbol, from);
714
0
}
715
0
bool SymbolTable::symbolKnownUseEmpty(Operation *symbol, Operation *from) {
716
0
  return symbolKnownUseEmptyImpl(symbol, from);
717
0
}
718
0
bool SymbolTable::symbolKnownUseEmpty(StringRef symbol, Region *from) {
719
0
  return symbolKnownUseEmptyImpl(symbol, from);
720
0
}
721
0
bool SymbolTable::symbolKnownUseEmpty(Operation *symbol, Region *from) {
722
0
  return symbolKnownUseEmptyImpl(symbol, from);
723
0
}
724
725
//===----------------------------------------------------------------------===//
726
// SymbolTable::replaceAllSymbolUses
727
728
/// Rebuild the given attribute container after replacing all references to a
729
/// symbol with the updated attribute in 'accesses'.
730
static Attribute rebuildAttrAfterRAUW(
731
    Attribute container,
732
    ArrayRef<std::pair<SmallVector<int, 1>, SymbolRefAttr>> accesses,
733
0
    unsigned depth) {
734
0
  // Given a range of Attributes, update the ones referred to by the given
735
0
  // access chains to point to the new symbol attribute.
736
0
  auto updateAttrs = [&](auto &&attrRange) {
737
0
    auto attrBegin = std::begin(attrRange);
738
0
    for (unsigned i = 0, e = accesses.size(); i != e;) {
739
0
      ArrayRef<int> access = accesses[i].first;
740
0
      Attribute &attr = *std::next(attrBegin, access[depth]);
741
0
742
0
      // Check to see if this is a leaf access, i.e. a SymbolRef.
743
0
      if (access.size() == depth + 1) {
744
0
        attr = accesses[i].second;
745
0
        ++i;
746
0
        continue;
747
0
      }
748
0
749
0
      // Otherwise, this is a container. Collect all of the accesses for this
750
0
      // index and recurse. The recursion here is bounded by the size of the
751
0
      // largest access array.
752
0
      auto nestedAccesses = accesses.drop_front(i).take_while([&](auto &it) {
753
0
        ArrayRef<int> nextAccess = it.first;
754
0
        return nextAccess.size() > depth + 1 &&
755
0
               nextAccess[depth] == access[depth];
756
0
      });
Unexecuted instantiation: SymbolTable.cpp:_ZZZL20rebuildAttrAfterRAUWN4mlir9AttributeEN4llvm8ArrayRefISt4pairINS1_11SmallVectorIiLj1EEENS_13SymbolRefAttrEEEEjENK3$_0clINS1_14iterator_rangeINS1_15mapped_iteratorIPS3_INS_10IdentifierES0_EZNS1_17make_second_rangeIRNS4_ISE_Lj4EEEEEDaOT_EUlRSE_E_RS0_EEEEEEDaSK_ENKUlTyRSJ_E_clIKS7_EEDaSQ_
Unexecuted instantiation: SymbolTable.cpp:_ZZZL20rebuildAttrAfterRAUWN4mlir9AttributeEN4llvm8ArrayRefISt4pairINS1_11SmallVectorIiLj1EEENS_13SymbolRefAttrEEEEjENK3$_0clIRNS4_IS0_Lj4EEEEEDaOT_ENKUlTyRSD_E_clIKS7_EEDaSF_
757
0
      attr = rebuildAttrAfterRAUW(attr, nestedAccesses, depth + 1);
758
0
759
0
      // Skip over all of the accesses that refer to the nested container.
760
0
      i += nestedAccesses.size();
761
0
    }
762
0
  };
Unexecuted instantiation: SymbolTable.cpp:_ZZL20rebuildAttrAfterRAUWN4mlir9AttributeEN4llvm8ArrayRefISt4pairINS1_11SmallVectorIiLj1EEENS_13SymbolRefAttrEEEEjENK3$_0clINS1_14iterator_rangeINS1_15mapped_iteratorIPS3_INS_10IdentifierES0_EZNS1_17make_second_rangeIRNS4_ISE_Lj4EEEEEDaOT_EUlRSE_E_RS0_EEEEEEDaSK_
Unexecuted instantiation: SymbolTable.cpp:_ZZL20rebuildAttrAfterRAUWN4mlir9AttributeEN4llvm8ArrayRefISt4pairINS1_11SmallVectorIiLj1EEENS_13SymbolRefAttrEEEEjENK3$_0clIRNS4_IS0_Lj4EEEEEDaOT_
763
0
764
0
  if (auto dictAttr = container.dyn_cast<DictionaryAttr>()) {
765
0
    auto newAttrs = llvm::to_vector<4>(dictAttr.getValue());
766
0
    updateAttrs(make_second_range(newAttrs));
767
0
    return DictionaryAttr::get(newAttrs, dictAttr.getContext());
768
0
  }
769
0
  auto newAttrs = llvm::to_vector<4>(container.cast<ArrayAttr>().getValue());
770
0
  updateAttrs(newAttrs);
771
0
  return ArrayAttr::get(newAttrs, container.getContext());
772
0
}
773
774
/// Generates a new symbol reference attribute with a new leaf reference.
775
static SymbolRefAttr generateNewRefAttr(SymbolRefAttr oldAttr,
776
0
                                        FlatSymbolRefAttr newLeafAttr) {
777
0
  if (oldAttr.isa<FlatSymbolRefAttr>())
778
0
    return newLeafAttr;
779
0
  auto nestedRefs = llvm::to_vector<2>(oldAttr.getNestedReferences());
780
0
  nestedRefs.back() = newLeafAttr;
781
0
  return SymbolRefAttr::get(oldAttr.getRootReference(), nestedRefs,
782
0
                            oldAttr.getContext());
783
0
}
784
785
/// The implementation of SymbolTable::replaceAllSymbolUses below.
786
template <typename SymbolT, typename IRUnitT>
787
static LogicalResult
788
0
replaceAllSymbolUsesImpl(SymbolT symbol, StringRef newSymbol, IRUnitT *limit) {
789
0
  // A collection of operations along with their new attribute dictionary.
790
0
  std::vector<std::pair<Operation *, DictionaryAttr>> updatedAttrDicts;
791
0
792
0
  // The current operation being processed.
793
0
  Operation *curOp = nullptr;
794
0
795
0
  // The set of access chains into the attribute dictionary of the current
796
0
  // operation, as well as the replacement attribute to use.
797
0
  SmallVector<std::pair<SmallVector<int, 1>, SymbolRefAttr>, 1> accessChains;
798
0
799
0
  // Generate a new attribute dictionary for the current operation by replacing
800
0
  // references to the old symbol.
801
0
  auto generateNewAttrDict = [&] {
802
0
    auto oldDict = curOp->getAttrDictionary();
803
0
    auto newDict = rebuildAttrAfterRAUW(oldDict, accessChains, /*depth=*/0);
804
0
    return newDict.cast<DictionaryAttr>();
805
0
  };
Unexecuted instantiation: SymbolTable.cpp:_ZZL24replaceAllSymbolUsesImplIN4llvm9StringRefEN4mlir9OperationEENS2_13LogicalResultET_S1_PT0_ENKUlvE_clEv
Unexecuted instantiation: SymbolTable.cpp:_ZZL24replaceAllSymbolUsesImplIPN4mlir9OperationES1_ENS0_13LogicalResultET_N4llvm9StringRefEPT0_ENKUlvE_clEv
Unexecuted instantiation: SymbolTable.cpp:_ZZL24replaceAllSymbolUsesImplIN4llvm9StringRefEN4mlir6RegionEENS2_13LogicalResultET_S1_PT0_ENKUlvE_clEv
Unexecuted instantiation: SymbolTable.cpp:_ZZL24replaceAllSymbolUsesImplIPN4mlir9OperationENS0_6RegionEENS0_13LogicalResultET_N4llvm9StringRefEPT0_ENKUlvE_clEv
806
0
807
0
  // Generate a new attribute to replace the given attribute.
808
0
  MLIRContext *ctx = limit->getContext();
809
0
  FlatSymbolRefAttr newLeafAttr = FlatSymbolRefAttr::get(newSymbol, ctx);
810
0
  for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) {
811
0
    SymbolRefAttr newAttr = generateNewRefAttr(scope.symbol, newLeafAttr);
812
0
    auto walkFn = [&](SymbolTable::SymbolUse symbolUse,
813
0
                      ArrayRef<int> accessChain) {
814
0
      SymbolRefAttr useRef = symbolUse.getSymbolRef();
815
0
      if (!isReferencePrefixOf(scope.symbol, useRef))
816
0
        return WalkResult::advance();
817
0
818
0
      // If we have a valid match, check to see if this is a proper
819
0
      // subreference. If it is, then we will need to generate a different new
820
0
      // attribute specifically for this use.
821
0
      SymbolRefAttr replacementRef = newAttr;
822
0
      if (useRef != scope.symbol) {
823
0
        if (scope.symbol.isa<FlatSymbolRefAttr>()) {
824
0
          replacementRef =
825
0
              SymbolRefAttr::get(newSymbol, useRef.getNestedReferences(), ctx);
826
0
        } else {
827
0
          auto nestedRefs = llvm::to_vector<4>(useRef.getNestedReferences());
828
0
          nestedRefs[scope.symbol.getNestedReferences().size() - 1] =
829
0
              newLeafAttr;
830
0
          replacementRef =
831
0
              SymbolRefAttr::get(useRef.getRootReference(), nestedRefs, ctx);
832
0
        }
833
0
      }
834
0
835
0
      // If there was a previous operation, generate a new attribute dict
836
0
      // for it. This means that we've finished processing the current
837
0
      // operation, so generate a new dictionary for it.
838
0
      if (curOp && symbolUse.getUser() != curOp) {
839
0
        updatedAttrDicts.push_back({curOp, generateNewAttrDict()});
840
0
        accessChains.clear();
841
0
      }
842
0
843
0
      // Record this access.
844
0
      curOp = symbolUse.getUser();
845
0
      accessChains.push_back({llvm::to_vector<1>(accessChain), replacementRef});
846
0
      return WalkResult::advance();
847
0
    };
Unexecuted instantiation: SymbolTable.cpp:_ZZL24replaceAllSymbolUsesImplIN4llvm9StringRefEN4mlir9OperationEENS2_13LogicalResultET_S1_PT0_ENKUlNS2_11SymbolTable9SymbolUseENS0_8ArrayRefIiEEE_clES9_SB_
Unexecuted instantiation: SymbolTable.cpp:_ZZL24replaceAllSymbolUsesImplIPN4mlir9OperationES1_ENS0_13LogicalResultET_N4llvm9StringRefEPT0_ENKUlNS0_11SymbolTable9SymbolUseENS5_8ArrayRefIiEEE_clESA_SC_
Unexecuted instantiation: SymbolTable.cpp:_ZZL24replaceAllSymbolUsesImplIN4llvm9StringRefEN4mlir6RegionEENS2_13LogicalResultET_S1_PT0_ENKUlNS2_11SymbolTable9SymbolUseENS0_8ArrayRefIiEEE_clES9_SB_
Unexecuted instantiation: SymbolTable.cpp:_ZZL24replaceAllSymbolUsesImplIPN4mlir9OperationENS0_6RegionEENS0_13LogicalResultET_N4llvm9StringRefEPT0_ENKUlNS0_11SymbolTable9SymbolUseENS6_8ArrayRefIiEEE_clESB_SD_
848
0
    if (!scope.walk(walkFn))
849
0
      return failure();
850
0
851
0
    // Check to see if we have a dangling op that needs to be processed.
852
0
    if (curOp) {
853
0
      updatedAttrDicts.push_back({curOp, generateNewAttrDict()});
854
0
      curOp = nullptr;
855
0
    }
856
0
  }
857
0
858
0
  // Update the attribute dictionaries as necessary.
859
0
  for (auto &it : updatedAttrDicts)
860
0
    it.first->setAttrs(it.second);
861
0
  return success();
862
0
}
Unexecuted instantiation: SymbolTable.cpp:_ZL24replaceAllSymbolUsesImplIN4llvm9StringRefEN4mlir9OperationEENS2_13LogicalResultET_S1_PT0_
Unexecuted instantiation: SymbolTable.cpp:_ZL24replaceAllSymbolUsesImplIPN4mlir9OperationES1_ENS0_13LogicalResultET_N4llvm9StringRefEPT0_
Unexecuted instantiation: SymbolTable.cpp:_ZL24replaceAllSymbolUsesImplIN4llvm9StringRefEN4mlir6RegionEENS2_13LogicalResultET_S1_PT0_
Unexecuted instantiation: SymbolTable.cpp:_ZL24replaceAllSymbolUsesImplIPN4mlir9OperationENS0_6RegionEENS0_13LogicalResultET_N4llvm9StringRefEPT0_
863
864
/// Attempt to replace all uses of the given symbol 'oldSymbol' with the
865
/// provided symbol 'newSymbol' that are nested within the given operation
866
/// 'from'. This does not traverse into any nested symbol tables. If there are
867
/// any unknown operations that may potentially be symbol tables, no uses are
868
/// replaced and failure is returned.
869
LogicalResult SymbolTable::replaceAllSymbolUses(StringRef oldSymbol,
870
                                                StringRef newSymbol,
871
0
                                                Operation *from) {
872
0
  return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from);
873
0
}
874
LogicalResult SymbolTable::replaceAllSymbolUses(Operation *oldSymbol,
875
                                                StringRef newSymbol,
876
0
                                                Operation *from) {
877
0
  return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from);
878
0
}
879
LogicalResult SymbolTable::replaceAllSymbolUses(StringRef oldSymbol,
880
                                                StringRef newSymbol,
881
0
                                                Region *from) {
882
0
  return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from);
883
0
}
884
LogicalResult SymbolTable::replaceAllSymbolUses(Operation *oldSymbol,
885
                                                StringRef newSymbol,
886
0
                                                Region *from) {
887
0
  return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from);
888
0
}
889
890
//===----------------------------------------------------------------------===//
891
// Symbol Interfaces
892
//===----------------------------------------------------------------------===//
893
894
/// Include the generated symbol interfaces.
895
#include "mlir/IR/SymbolInterfaces.cpp.inc"