Coverage Report

Created: 2020-06-26 05:44

/home/arjun/llvm-project/mlir/include/mlir/IR/SymbolTable.h
Line
Count
Source (jump to first uncovered line)
1
//===- SymbolTable.h - MLIR Symbol Table Class ------------------*- C++ -*-===//
2
//
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
// See https://llvm.org/LICENSE.txt for license information.
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
//
7
//===----------------------------------------------------------------------===//
8
9
#ifndef MLIR_IR_SYMBOLTABLE_H
10
#define MLIR_IR_SYMBOLTABLE_H
11
12
#include "mlir/IR/Attributes.h"
13
#include "mlir/IR/OpDefinition.h"
14
#include "llvm/ADT/StringMap.h"
15
16
namespace mlir {
17
class Identifier;
18
class Operation;
19
20
/// This class allows for representing and managing the symbol table used by
21
/// operations with the 'SymbolTable' trait. Inserting into and erasing from
22
/// this SymbolTable will also insert and erase from the Operation given to it
23
/// at construction.
24
class SymbolTable {
25
public:
26
  /// Build a symbol table with the symbols within the given operation.
27
  SymbolTable(Operation *symbolTableOp);
28
29
  /// Look up the symbol with the given name, returning null if this table has
30
  /// no symbol of that name. Pass the name without its leading '@'.
31
  Operation *lookup(StringRef name) const;
32
  template <typename T> T lookup(StringRef name) const { // typed overload
33
    return dyn_cast_or_null<T>(lookup(name)); // cast the untyped result to T
34
  }
35
36
  /// Erase the given symbol from the table.
37
  void erase(Operation *symbol);
38
39
  /// Insert a new symbol into the table, and rename it as necessary to avoid
40
  /// collisions. Also insert at the specified location in the body of the
41
  /// associated operation.
42
  void insert(Operation *symbol, Block::iterator insertPt = {});
43
44
  /// Return the name of the attribute that stores a symbol's name.
45
0
  static StringRef getSymbolAttrName() { return "sym_name"; }
46
47
  /// Returns the operation this symbol table is associated with.
48
0
  Operation *getOp() const { return symbolTableOp; }
49
50
  /// Return the name of the attribute that stores a symbol's visibility.
51
0
  static StringRef getVisibilityAttrName() { return "sym_visibility"; }
52
53
  //===--------------------------------------------------------------------===//
54
  // Symbol Utilities
55
  //===--------------------------------------------------------------------===//
56
57
  /// The kinds of visibility a symbol may have; see getSymbolVisibility and
58
  /// setSymbolVisibility.
59
  enum class Visibility {
60
    /// The symbol is public and may be referenced anywhere internal or external
61
    /// to the visible references in the IR.
62
    Public,
63
64
    /// The symbol is private and may only be referenced by SymbolRefAttrs local
65
    /// to the operations within the current symbol table.
66
    Private,
67
68
    /// The symbol is visible to the current IR, which may include operations in
69
    /// symbol tables above the one that owns the current symbol. `Nested`
70
    /// visibility allows for referencing a symbol outside of its current symbol
71
    /// table, while retaining the ability to observe all uses.
72
    Nested,
73
  };
74
75
  /// Returns the name of the given symbol operation.
76
  static StringRef getSymbolName(Operation *symbol);
77
  /// Sets the name of the given symbol operation.
78
  static void setSymbolName(Operation *symbol, StringRef name);
79
80
  /// Returns the visibility of the given symbol operation.
81
  static Visibility getSymbolVisibility(Operation *symbol);
82
  /// Sets the visibility of the given symbol operation.
83
  static void setSymbolVisibility(Operation *symbol, Visibility vis);
84
85
  /// Returns the nearest symbol table from a given operation `from`. Returns
86
  /// nullptr if no valid parent symbol table could be found.
87
  static Operation *getNearestSymbolTable(Operation *from);
88
89
  /// Walks all symbol table operations nested within, and including, `op`. For
90
  /// each symbol table operation, the provided callback is invoked with the op
91
  /// and a boolean signifying if the symbols within that symbol table can be
92
  /// treated as if all uses within the IR are visible to the caller.
93
  /// `allSymUsesVisible` identifies whether all of the symbol uses of symbols
94
  /// within `op` are visible.
95
  static void walkSymbolTables(Operation *op, bool allSymUsesVisible,
96
                               function_ref<void(Operation *, bool)> callback);
97
98
  /// Returns the operation registered with the given symbol name within the
99
  /// regions of 'op'. 'op' is required to be an operation
100
  /// with the 'OpTrait::SymbolTable' trait.
101
  static Operation *lookupSymbolIn(Operation *op, StringRef symbol);
102
  static Operation *lookupSymbolIn(Operation *op, SymbolRefAttr symbol);
103
  /// A variant of 'lookupSymbolIn' that returns all of the symbols referenced
104
  /// by a given SymbolRefAttr. Returns failure if any of the nested references
105
  /// could not be resolved.
106
  static LogicalResult lookupSymbolIn(Operation *op, SymbolRefAttr symbol,
107
                                      SmallVectorImpl<Operation *> &symbols);
108
109
  /// Returns the operation registered with the given symbol name within the
110
  /// closest operation that has the 'OpTrait::SymbolTable' trait, considering
111
  /// 'from' itself and then its parents. Returns nullptr if no valid symbol
112
  /// was found. The templated overloads cast the result to T.
113
  static Operation *lookupNearestSymbolFrom(Operation *from, StringRef symbol);
114
  static Operation *lookupNearestSymbolFrom(Operation *from,
115
                                            SymbolRefAttr symbol);
116
  template <typename T> // typed form of the StringRef overload
117
  static T lookupNearestSymbolFrom(Operation *from, StringRef symbol) {
118
    return dyn_cast_or_null<T>(lookupNearestSymbolFrom(from, symbol));
119
  }
120
  template <typename T> // typed form of the SymbolRefAttr overload
121
  static T lookupNearestSymbolFrom(Operation *from, SymbolRefAttr symbol) {
122
    return dyn_cast_or_null<T>(lookupNearestSymbolFrom(from, symbol));
123
  }
124
125
  /// A single symbol use: the user operation plus the SymbolRefAttr it holds.
126
  class SymbolUse {
127
  public:
128
    SymbolUse(Operation *op, SymbolRefAttr symbolRef) // 'op' is the user
129
0
        : owner(op), symbolRef(symbolRef) {}
130
131
    /// Return the operation that uses (holds) this symbol reference.
132
0
    Operation *getUser() const { return owner; }
133
134
    /// Return the symbol reference that this use represents.
135
0
    SymbolRefAttr getSymbolRef() const { return symbolRef; }
136
137
  private:
138
    /// The operation that this access is held by; returned by getUser().
139
    Operation *owner;
140
141
    /// The symbol reference that this use represents.
142
    SymbolRefAttr symbolRef;
143
  };
144
145
  /// An iterable range of SymbolUse entries, held by value and exposed const.
146
  class UseRange {
147
  public:
148
0
    UseRange(std::vector<SymbolUse> &&uses) : uses(std::move(uses)) {} // moved in
149
150
    using iterator = std::vector<SymbolUse>::const_iterator; // read-only access
151
0
    iterator begin() const { return uses.begin(); }
152
0
    iterator end() const { return uses.end(); }
153
154
  private:
155
    std::vector<SymbolUse> uses; // the uses traversed by begin()/end()
156
  };
157
158
  /// Get an iterator range for all of the uses, for any symbol, that are nested
159
  /// within the given operation 'from'. This does not traverse into any nested
160
  /// symbol tables. This function returns None if there are any unknown
161
  /// operations that may potentially be symbol tables.
162
  static Optional<UseRange> getSymbolUses(Operation *from);
163
  static Optional<UseRange> getSymbolUses(Region *from);
164
165
  /// Get all of the uses of the given symbol that are nested within the given
166
  /// operation 'from'. This does not traverse into any nested symbol tables.
167
  /// This function returns None if there are any unknown operations that may
168
  /// potentially be symbol tables.
169
  static Optional<UseRange> getSymbolUses(StringRef symbol, Operation *from);
170
  static Optional<UseRange> getSymbolUses(Operation *symbol, Operation *from);
171
  static Optional<UseRange> getSymbolUses(StringRef symbol, Region *from);
172
  static Optional<UseRange> getSymbolUses(Operation *symbol, Region *from);
173
174
  /// Return true if the given symbol is known to have no uses that are nested
175
  /// within the given operation 'from'. This does not traverse into any nested
176
  /// symbol tables. This function will also return false if there are any
177
  /// unknown operations that may potentially be symbol tables. A false result
178
  /// doesn't necessarily mean that there are uses; we just can't conservatively
179
  /// prove that there are none.
180
  static bool symbolKnownUseEmpty(StringRef symbol, Operation *from);
181
  static bool symbolKnownUseEmpty(Operation *symbol, Operation *from);
182
  static bool symbolKnownUseEmpty(StringRef symbol, Region *from);
183
  static bool symbolKnownUseEmpty(Operation *symbol, Region *from);
184
185
  /// Attempt to replace all uses of the given symbol 'oldSymbol' with the
186
  /// provided symbol 'newSymbol' that are nested within the given operation
187
  /// 'from'. This does not traverse into any nested symbol tables. If there are
188
  /// any unknown operations that may potentially be symbol tables, no uses are
189
  /// replaced and failure is returned.
190
  LLVM_NODISCARD static LogicalResult replaceAllSymbolUses(StringRef oldSymbol,
191
                                                           StringRef newSymbol,
192
                                                           Operation *from);
193
  LLVM_NODISCARD static LogicalResult
194
  replaceAllSymbolUses(Operation *oldSymbol, StringRef newSymbolName,
195
                       Operation *from);
196
  LLVM_NODISCARD static LogicalResult
197
  replaceAllSymbolUses(StringRef oldSymbol, StringRef newSymbol, Region *from);
198
  LLVM_NODISCARD static LogicalResult
199
  replaceAllSymbolUses(Operation *oldSymbol, StringRef newSymbolName,
200
                       Region *from);
201
202
private:
203
  Operation *symbolTableOp; // the associated operation; returned by getOp()
203
204
  /// This is a mapping from a name to the symbol with that name.
205
  llvm::StringMap<Operation *> symbolTable;
206
207
  /// Counter used when name conflicts are detected; starts at 0.
208
  unsigned uniquingCounter = 0;
210
};
211
212
//===----------------------------------------------------------------------===//
213
// SymbolTable Trait Types
214
//===----------------------------------------------------------------------===//
215
216
namespace detail {
217
LogicalResult verifySymbolTable(Operation *op); // called by OpTrait::SymbolTable::verifyTrait
218
LogicalResult verifySymbol(Operation *op); // presumably the `Symbol` trait verifier -- confirm
219
} // namespace detail
220
221
namespace OpTrait {
222
/// A trait used to provide symbol table functionalities to a region operation.
223
/// This operation must hold exactly 1 region. Once attached, all operations
224
/// that are directly within the region, i.e., not including those within child
225
/// regions, that contain a 'SymbolTable::getSymbolAttrName()' StringAttr will
226
/// be verified to ensure that the names are uniqued. These operations must also
227
/// adhere to the constraints defined by the `Symbol` trait, even if they do not
228
/// inherit from it.
229
template <typename ConcreteType>
230
class SymbolTable : public TraitBase<ConcreteType, SymbolTable> {
231
public:
232
0
  static LogicalResult verifyTrait(Operation *op) { // trait verification hook
233
0
    return ::mlir::detail::verifySymbolTable(op); // forwards to the detail verifier
234
0
  }
235
236
  /// Look up a symbol with the specified name in this operation, returning null
237
  /// if no such name exists. Symbol names never include the @ on them. Note:
238
  /// This performs a linear scan of held symbols.
239
0
  Operation *lookupSymbol(StringRef name) {
240
0
    return mlir::SymbolTable::lookupSymbolIn(this->getOperation(), name);
241
0
  }
242
0
  template <typename T> T lookupSymbol(StringRef name) { // typed overload
243
0
    return dyn_cast_or_null<T>(lookupSymbol(name));
244
0
  }
245
  Operation *lookupSymbol(SymbolRefAttr symbol) { // SymbolRefAttr overload
246
    return mlir::SymbolTable::lookupSymbolIn(this->getOperation(), symbol);
247
  }
248
  template <typename T> // typed form of the SymbolRefAttr overload
249
  T lookupSymbol(SymbolRefAttr symbol) {
250
    return dyn_cast_or_null<T>(lookupSymbol(symbol));
251
  }
252
};
253
254
} // end namespace OpTrait
255
256
/// Include the generated symbol interfaces.
257
#include "mlir/IR/SymbolInterfaces.h.inc"
258
259
} // end namespace mlir
260
261
#endif // MLIR_IR_SYMBOLTABLE_H