/home/arjun/llvm-project/mlir/include/mlir/IR/SymbolTable.h
Line | Count | Source (jump to first uncovered line) |
1 | | //===- SymbolTable.h - MLIR Symbol Table Class ------------------*- C++ -*-===// |
2 | | // |
3 | | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | | // See https://llvm.org/LICENSE.txt for license information. |
5 | | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | | // |
7 | | //===----------------------------------------------------------------------===// |
8 | | |
9 | | #ifndef MLIR_IR_SYMBOLTABLE_H |
10 | | #define MLIR_IR_SYMBOLTABLE_H |
11 | | |
12 | | #include "mlir/IR/Attributes.h" |
13 | | #include "mlir/IR/OpDefinition.h" |
14 | | #include "llvm/ADT/StringMap.h" |
15 | | |
16 | | namespace mlir { |
17 | | class Identifier; |
18 | | class Operation; |
19 | | |
20 | | /// This class allows for representing and managing the symbol table used by |
21 | | /// operations with the 'SymbolTable' trait. Inserting into and erasing from |
22 | | /// this SymbolTable will also insert and erase from the Operation given to it |
23 | | /// at construction. |
24 | | class SymbolTable { |
25 | | public: |
26 | | /// Build a symbol table with the symbols within the given operation. |
27 | | SymbolTable(Operation *symbolTableOp); |
28 | | |
29 | | /// Look up a symbol with the specified name, returning null if no such |
30 | | /// name exists. Names never include the @ on them. |
31 | | Operation *lookup(StringRef name) const; |
32 | | template <typename T> T lookup(StringRef name) const { |
33 | | return dyn_cast_or_null<T>(lookup(name)); |
34 | | } |
35 | | |
36 | | /// Erase the given symbol from the table. |
37 | | void erase(Operation *symbol); |
38 | | |
39 | | /// Insert a new symbol into the table, and rename it as necessary to avoid |
40 | | /// collisions. Also insert at the specified location in the body of the |
41 | | /// associated operation. |
42 | | void insert(Operation *symbol, Block::iterator insertPt = {}); |
43 | | |
44 | | /// Return the name of the attribute used for symbol names. |
45 | 0 | static StringRef getSymbolAttrName() { return "sym_name"; } |
46 | | |
47 | | /// Returns the associated operation. |
48 | 0 | Operation *getOp() const { return symbolTableOp; } |
49 | | |
50 | | /// Return the name of the attribute used for symbol visibility. |
51 | 0 | static StringRef getVisibilityAttrName() { return "sym_visibility"; } |
52 | | |
53 | | //===--------------------------------------------------------------------===// |
54 | | // Symbol Utilities |
55 | | //===--------------------------------------------------------------------===// |
56 | | |
57 | | /// An enumeration detailing the different visibility types that a symbol may |
58 | | /// have. |
59 | | enum class Visibility { |
60 | | /// The symbol is public and may be referenced anywhere internal or external |
61 | | /// to the visible references in the IR. |
62 | | Public, |
63 | | |
64 | | /// The symbol is private and may only be referenced by SymbolRefAttrs local |
65 | | /// to the operations within the current symbol table. |
66 | | Private, |
67 | | |
68 | | /// The symbol is visible to the current IR, which may include operations in |
69 | | /// symbol tables above the one that owns the current symbol. `Nested` |
70 | | /// visibility allows for referencing a symbol outside of its current symbol |
71 | | /// table, while retaining the ability to observe all uses. |
72 | | Nested, |
73 | | }; |
74 | | |
75 | | /// Returns the name of the given symbol operation. |
76 | | static StringRef getSymbolName(Operation *symbol); |
77 | | /// Sets the name of the given symbol operation. |
78 | | static void setSymbolName(Operation *symbol, StringRef name); |
79 | | |
80 | | /// Returns the visibility of the given symbol operation. |
81 | | static Visibility getSymbolVisibility(Operation *symbol); |
82 | | /// Sets the visibility of the given symbol operation. |
83 | | static void setSymbolVisibility(Operation *symbol, Visibility vis); |
84 | | |
85 | | /// Returns the nearest symbol table from a given operation `from`. Returns |
86 | | /// nullptr if no valid parent symbol table could be found. |
87 | | static Operation *getNearestSymbolTable(Operation *from); |
88 | | |
89 | | /// Walks all symbol table operations nested within, and including, `op`. For |
90 | | /// each symbol table operation, the provided callback is invoked with the op |
91 | | /// and a boolean signifying if the symbols within that symbol table can be |
92 | | /// treated as if all uses within the IR are visible to the caller. |
93 | | /// `allSymUsesVisible` identifies whether all of the symbol uses of symbols |
94 | | /// within `op` are visible. |
95 | | static void walkSymbolTables(Operation *op, bool allSymUsesVisible, |
96 | | function_ref<void(Operation *, bool)> callback); |
97 | | |
98 | | /// Returns the operation registered with the given symbol name with the |
99 | | /// regions of 'symbolTableOp'. 'symbolTableOp' is required to be an operation |
100 | | /// with the 'OpTrait::SymbolTable' trait. |
101 | | static Operation *lookupSymbolIn(Operation *op, StringRef symbol); |
102 | | static Operation *lookupSymbolIn(Operation *op, SymbolRefAttr symbol); |
103 | | /// A variant of 'lookupSymbolIn' that returns all of the symbols referenced |
104 | | /// by a given SymbolRefAttr. Returns failure if any of the nested references |
105 | | /// could not be resolved. |
106 | | static LogicalResult lookupSymbolIn(Operation *op, SymbolRefAttr symbol, |
107 | | SmallVectorImpl<Operation *> &symbols); |
108 | | |
109 | | /// Returns the operation registered with the given symbol name within the |
110 | | /// closest parent operation of, or including, 'from' with the |
111 | | /// 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol was |
112 | | /// found. |
113 | | static Operation *lookupNearestSymbolFrom(Operation *from, StringRef symbol); |
114 | | static Operation *lookupNearestSymbolFrom(Operation *from, |
115 | | SymbolRefAttr symbol); |
116 | | template <typename T> |
117 | | static T lookupNearestSymbolFrom(Operation *from, StringRef symbol) { |
118 | | return dyn_cast_or_null<T>(lookupNearestSymbolFrom(from, symbol)); |
119 | | } |
120 | | template <typename T> |
121 | | static T lookupNearestSymbolFrom(Operation *from, SymbolRefAttr symbol) { |
122 | | return dyn_cast_or_null<T>(lookupNearestSymbolFrom(from, symbol)); |
123 | | } |
124 | | |
125 | | /// This class represents a specific symbol use. |
126 | | class SymbolUse { |
127 | | public: |
128 | | SymbolUse(Operation *op, SymbolRefAttr symbolRef) |
129 | 0 | : owner(op), symbolRef(symbolRef) {} |
130 | | |
131 | | /// Return the operation user of this symbol reference. |
132 | 0 | Operation *getUser() const { return owner; } |
133 | | |
134 | | /// Return the symbol reference that this use represents. |
135 | 0 | SymbolRefAttr getSymbolRef() const { return symbolRef; } |
136 | | |
137 | | private: |
138 | | /// The operation that this access is held by. |
139 | | Operation *owner; |
140 | | |
141 | | /// The symbol reference that this use represents. |
142 | | SymbolRefAttr symbolRef; |
143 | | }; |
144 | | |
145 | | /// This class implements a range of SymbolRef uses. |
146 | | class UseRange { |
147 | | public: |
148 | 0 | UseRange(std::vector<SymbolUse> &&uses) : uses(std::move(uses)) {} |
149 | | |
150 | | using iterator = std::vector<SymbolUse>::const_iterator; |
151 | 0 | iterator begin() const { return uses.begin(); } |
152 | 0 | iterator end() const { return uses.end(); } |
153 | | |
154 | | private: |
155 | | std::vector<SymbolUse> uses; |
156 | | }; |
157 | | |
158 | | /// Get an iterator range for all of the uses, for any symbol, that are nested |
159 | | /// within the given operation 'from'. This does not traverse into any nested |
160 | | /// symbol tables. This function returns None if there are any unknown |
161 | | /// operations that may potentially be symbol tables. |
162 | | static Optional<UseRange> getSymbolUses(Operation *from); |
163 | | static Optional<UseRange> getSymbolUses(Region *from); |
164 | | |
165 | | /// Get all of the uses of the given symbol that are nested within the given |
166 | | /// operation 'from'. This does not traverse into any nested symbol tables. |
167 | | /// This function returns None if there are any unknown operations that may |
168 | | /// potentially be symbol tables. |
169 | | static Optional<UseRange> getSymbolUses(StringRef symbol, Operation *from); |
170 | | static Optional<UseRange> getSymbolUses(Operation *symbol, Operation *from); |
171 | | static Optional<UseRange> getSymbolUses(StringRef symbol, Region *from); |
172 | | static Optional<UseRange> getSymbolUses(Operation *symbol, Region *from); |
173 | | |
174 | | /// Return if the given symbol is known to have no uses that are nested |
175 | | /// within the given operation 'from'. This does not traverse into any nested |
176 | | /// symbol tables. This function will also return false if there are any |
177 | | /// unknown operations that may potentially be symbol tables. This doesn't |
178 | | /// necessarily mean that there are no uses, we just can't conservatively |
179 | | /// prove it. |
180 | | static bool symbolKnownUseEmpty(StringRef symbol, Operation *from); |
181 | | static bool symbolKnownUseEmpty(Operation *symbol, Operation *from); |
182 | | static bool symbolKnownUseEmpty(StringRef symbol, Region *from); |
183 | | static bool symbolKnownUseEmpty(Operation *symbol, Region *from); |
184 | | |
185 | | /// Attempt to replace all uses of the given symbol 'oldSymbol' with the |
186 | | /// provided symbol 'newSymbol' that are nested within the given operation |
187 | | /// 'from'. This does not traverse into any nested symbol tables. If there are |
188 | | /// any unknown operations that may potentially be symbol tables, no uses are |
189 | | /// replaced and failure is returned. |
190 | | LLVM_NODISCARD static LogicalResult replaceAllSymbolUses(StringRef oldSymbol, |
191 | | StringRef newSymbol, |
192 | | Operation *from); |
193 | | LLVM_NODISCARD static LogicalResult |
194 | | replaceAllSymbolUses(Operation *oldSymbol, StringRef newSymbolName, |
195 | | Operation *from); |
196 | | LLVM_NODISCARD static LogicalResult |
197 | | replaceAllSymbolUses(StringRef oldSymbol, StringRef newSymbol, Region *from); |
198 | | LLVM_NODISCARD static LogicalResult |
199 | | replaceAllSymbolUses(Operation *oldSymbol, StringRef newSymbolName, |
200 | | Region *from); |
201 | | |
202 | | private: |
203 | | Operation *symbolTableOp; |
204 | | |
205 | | /// This is a mapping from a name to the symbol with that name. |
206 | | llvm::StringMap<Operation *> symbolTable; |
207 | | |
208 | | /// This is used when name conflicts are detected. |
209 | | unsigned uniquingCounter = 0; |
210 | | }; |
211 | | |
212 | | //===----------------------------------------------------------------------===// |
213 | | // SymbolTable Trait Types |
214 | | //===----------------------------------------------------------------------===// |
215 | | |
216 | | namespace detail { |
217 | | LogicalResult verifySymbolTable(Operation *op); |
218 | | LogicalResult verifySymbol(Operation *op); |
219 | | } // namespace detail |
220 | | |
221 | | namespace OpTrait { |
222 | | /// A trait used to provide symbol table functionalities to a region operation. |
223 | | /// This operation must hold exactly 1 region. Once attached, all operations |
224 | | /// that are directly within the region, i.e not including those within child |
225 | | /// regions, that contain a 'SymbolTable::getSymbolAttrName()' StringAttr will |
226 | | /// be verified to ensure that the names are uniqued. These operations must also |
227 | | /// adhere to the constraints defined by the `Symbol` trait, even if they do not |
228 | | /// inherit from it. |
229 | | template <typename ConcreteType> |
230 | | class SymbolTable : public TraitBase<ConcreteType, SymbolTable> { |
231 | | public: |
232 | 0 | static LogicalResult verifyTrait(Operation *op) { |
233 | 0 | return ::mlir::detail::verifySymbolTable(op); |
234 | 0 | } |
235 | | |
236 | | /// Look up a symbol with the specified name, returning null if no such |
237 | | /// name exists. Symbol names never include the @ on them. Note: This |
238 | | /// performs a linear scan of held symbols. |
239 | 0 | Operation *lookupSymbol(StringRef name) { |
240 | 0 | return mlir::SymbolTable::lookupSymbolIn(this->getOperation(), name); |
241 | 0 | } |
242 | 0 | template <typename T> T lookupSymbol(StringRef name) { |
243 | 0 | return dyn_cast_or_null<T>(lookupSymbol(name)); |
244 | 0 | } |
245 | | Operation *lookupSymbol(SymbolRefAttr symbol) { |
246 | | return mlir::SymbolTable::lookupSymbolIn(this->getOperation(), symbol); |
247 | | } |
248 | | template <typename T> |
249 | | T lookupSymbol(SymbolRefAttr symbol) { |
250 | | return dyn_cast_or_null<T>(lookupSymbol(symbol)); |
251 | | } |
252 | | }; |
253 | | |
254 | | } // end namespace OpTrait |
255 | | |
256 | | /// Include the generated symbol interfaces. |
257 | | #include "mlir/IR/SymbolInterfaces.h.inc" |
258 | | |
259 | | } // end namespace mlir |
260 | | |
261 | | #endif // MLIR_IR_SYMBOLTABLE_H |