/home/arjun/llvm-project/mlir/include/mlir/IR/Dialect.h
Line | Count | Source (jump to first uncovered line) |
1 | | //===- Dialect.h - IR Dialect Description -----------------------*- C++ -*-===// |
2 | | // |
3 | | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | | // See https://llvm.org/LICENSE.txt for license information. |
5 | | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | | // |
7 | | //===----------------------------------------------------------------------===// |
8 | | // |
9 | | // This file defines the 'dialect' abstraction. |
10 | | // |
11 | | //===----------------------------------------------------------------------===// |
12 | | |
13 | | #ifndef MLIR_IR_DIALECT_H |
14 | | #define MLIR_IR_DIALECT_H |
15 | | |
16 | | #include "mlir/IR/OperationSupport.h" |
17 | | |
18 | | namespace mlir { |
19 | | class DialectAsmParser; |
20 | | class DialectAsmPrinter; |
21 | | class DialectInterface; |
22 | | class OpBuilder; |
23 | | class Type; |
24 | | |
25 | | using DialectConstantDecodeHook = |
26 | | std::function<bool(const OpaqueElementsAttr, ElementsAttr &)>; |
27 | | using DialectConstantFoldHook = std::function<LogicalResult( |
28 | | Operation *, ArrayRef<Attribute>, SmallVectorImpl<Attribute> &)>; |
29 | | using DialectExtractElementHook = |
30 | | std::function<Attribute(const OpaqueElementsAttr, ArrayRef<uint64_t>)>; |
31 | | using DialectAllocatorFunction = std::function<void(MLIRContext *)>; |
32 | | |
33 | | /// Dialects are groups of MLIR operations and behavior associated with the |
34 | | /// entire group. For example, hooks into other systems for constant folding, |
35 | | /// default named types for asm printing, etc. |
36 | | /// |
37 | | /// Instances of the dialect object are global across all MLIRContext's that may |
38 | | /// be active in the process. |
39 | | /// |
40 | | class Dialect { |
41 | | public: |
42 | | virtual ~Dialect(); |
43 | | |
44 | | /// Utility function that returns if the given string is a valid dialect |
45 | | /// namespace. |
46 | | static bool isValidNamespace(StringRef str); |
47 | | |
48 | 0 | MLIRContext *getContext() const { return context; } |
49 | | |
50 | 0 | StringRef getNamespace() const { return name; } |
51 | | |
52 | | /// Returns true if this dialect allows for unregistered operations, i.e. |
53 | | /// operations prefixed with the dialect namespace but not registered with |
54 | | /// addOperation. |
55 | 0 | bool allowsUnknownOperations() const { return unknownOpsAllowed; } |
56 | | |
57 | | /// Return true if this dialect allows for unregistered types, i.e., types |
58 | | /// prefixed with the dialect namespace but not registered with addType. |
59 | | /// These are represented with OpaqueType. |
60 | 0 | bool allowsUnknownTypes() const { return unknownTypesAllowed; } |
61 | | |
62 | | //===--------------------------------------------------------------------===// |
63 | | // Constant Hooks |
64 | | //===--------------------------------------------------------------------===// |
65 | | |
66 | | /// Registered fallback constant fold hook for the dialect. Like the constant |
67 | | /// fold hook of each operation, it attempts to constant fold the operation |
68 | | /// with the specified constant operand values - the elements in "operands" |
69 | | /// will correspond directly to the operands of the operation, but may be null |
70 | | /// if non-constant. If constant folding is successful, this fills in the |
71 | | /// `results` vector. If not, this returns failure and `results` is |
72 | | /// unspecified. |
73 | | DialectConstantFoldHook constantFoldHook = |
74 | | [](Operation *op, ArrayRef<Attribute> operands, |
75 | 0 | SmallVectorImpl<Attribute> &results) { return failure(); }; |
76 | | |
77 | | /// Registered hook to decode opaque constants associated with this |
78 | | /// dialect. The hook function attempts to decode an opaque constant tensor |
79 | | /// into a tensor with non-opaque content. If decoding is successful, this |
80 | | /// method returns false and sets 'output' attribute. If not, it returns true |
81 | | /// and leaves 'output' unspecified. The default hook fails to decode. |
82 | | DialectConstantDecodeHook decodeHook = |
83 | 0 | [](const OpaqueElementsAttr input, ElementsAttr &output) { return true; }; |
84 | | |
85 | | /// Registered hook to extract an element from an opaque constant associated |
86 | | /// with this dialect. If element has been successfully extracted, this |
87 | | /// method returns that element. If not, it returns an empty attribute. |
88 | | /// The default hook fails to extract an element. |
89 | | DialectExtractElementHook extractElementHook = |
90 | 0 | [](const OpaqueElementsAttr input, ArrayRef<uint64_t> index) { |
91 | 0 | return Attribute(); |
92 | 0 | }; |
93 | | |
94 | | /// Registered hook to materialize a single constant operation from a given |
95 | | /// attribute value with the desired resultant type. This method should use |
96 | | /// the provided builder to create the operation without changing the |
97 | | /// insertion position. The generated operation is expected to be constant |
98 | | /// like, i.e. single result, zero operands, non side-effecting, etc. On |
99 | | /// success, this hook should return the value generated to represent the |
100 | | /// constant value. Otherwise, it should return null on failure. |
101 | | virtual Operation *materializeConstant(OpBuilder &builder, Attribute value, |
102 | 0 | Type type, Location loc) { |
103 | 0 | return nullptr; |
104 | 0 | } |
105 | | |
106 | | //===--------------------------------------------------------------------===// |
107 | | // Parsing Hooks |
108 | | //===--------------------------------------------------------------------===// |
109 | | |
110 | | /// Parse an attribute registered to this dialect. If 'type' is nonnull, it |
111 | | /// refers to the expected type of the attribute. |
112 | | virtual Attribute parseAttribute(DialectAsmParser &parser, Type type) const; |
113 | | |
114 | | /// Print an attribute registered to this dialect. Note: The type of the |
115 | | /// attribute need not be printed by this method as it is always printed by |
116 | | /// the caller. |
117 | 0 | virtual void printAttribute(Attribute, DialectAsmPrinter &) const { |
118 | 0 | llvm_unreachable("dialect has no registered attribute printing hook"); |
119 | 0 | } |
120 | | |
121 | | /// Parse a type registered to this dialect. |
122 | | virtual Type parseType(DialectAsmParser &parser) const; |
123 | | |
124 | | /// Print a type registered to this dialect. |
125 | 0 | virtual void printType(Type, DialectAsmPrinter &) const { |
126 | 0 | llvm_unreachable("dialect has no registered type printing hook"); |
127 | 0 | } |
128 | | |
129 | | //===--------------------------------------------------------------------===// |
130 | | // Verification Hooks |
131 | | //===--------------------------------------------------------------------===// |
132 | | |
133 | | /// Verify an attribute from this dialect on the argument at 'argIndex' for |
134 | | /// the region at 'regionIndex' on the given operation. Returns failure if |
135 | | /// the verification failed, success otherwise. This hook may optionally be |
136 | | /// invoked from any operation containing a region. |
137 | | virtual LogicalResult verifyRegionArgAttribute(Operation *, |
138 | | unsigned regionIndex, |
139 | | unsigned argIndex, |
140 | | NamedAttribute); |
141 | | |
142 | | /// Verify an attribute from this dialect on the result at 'resultIndex' for |
143 | | /// the region at 'regionIndex' on the given operation. Returns failure if |
144 | | /// the verification failed, success otherwise. This hook may optionally be |
145 | | /// invoked from any operation containing a region. |
146 | | virtual LogicalResult verifyRegionResultAttribute(Operation *, |
147 | | unsigned regionIndex, |
148 | | unsigned resultIndex, |
149 | | NamedAttribute); |
150 | | |
151 | | /// Verify an attribute from this dialect on the given operation. Returns |
152 | | /// failure if the verification failed, success otherwise. |
153 | 0 | virtual LogicalResult verifyOperationAttribute(Operation *, NamedAttribute) { |
154 | 0 | return success(); |
155 | 0 | } |
156 | | |
157 | | //===--------------------------------------------------------------------===// |
158 | | // Interfaces |
159 | | //===--------------------------------------------------------------------===// |
160 | | |
161 | | /// Lookup an interface for the given ID if one is registered, otherwise |
162 | | /// nullptr. |
163 | 0 | const DialectInterface *getRegisteredInterface(TypeID interfaceID) { |
164 | 0 | auto it = registeredInterfaces.find(interfaceID); |
165 | 0 | return it != registeredInterfaces.end() ? it->getSecond().get() : nullptr; |
166 | 0 | } |
167 | | template <typename InterfaceT> const InterfaceT *getRegisteredInterface() { |
168 | | return static_cast<const InterfaceT *>( |
169 | | getRegisteredInterface(InterfaceT::getInterfaceID())); |
170 | | } |
171 | | |
172 | | protected: |
173 | | /// The constructor takes a unique namespace for this dialect as well as the |
174 | | /// context to bind to. |
175 | | /// Note: The namespace must not contain '.' characters. |
176 | | /// Note: All operations belonging to this dialect must have names starting |
177 | | /// with the namespace followed by '.'. |
178 | | /// Example: |
179 | | /// - "tf" for the TensorFlow ops like "tf.add". |
180 | | Dialect(StringRef name, MLIRContext *context); |
181 | | |
182 | | /// This method is used by derived classes to add their operations to the set. |
183 | | /// |
184 | 0 | template <typename... Args> void addOperations() { |
185 | 0 | (void)std::initializer_list<int>{ |
186 | 0 | 0, (addOperation(AbstractOperation::get<Args>(*this)), 0)...}; |
187 | 0 | } Unexecuted instantiation: _ZN4mlir7Dialect13addOperationsIJNS_16AffineDmaStartOpENS_15AffineDmaWaitOpENS_13AffineApplyOpENS_11AffineForOpENS_10AffineIfOpENS_12AffineLoadOpENS_11AffineMaxOpENS_11AffineMinOpENS_16AffineParallelOpENS_16AffinePrefetchOpENS_13AffineStoreOpENS_18AffineTerminatorOpENS_18AffineVectorLoadOpENS_19AffineVectorStoreOpEEEEvv Unexecuted instantiation: _ZN4mlir7Dialect13addOperationsIJNS_10DmaStartOpENS_9DmaWaitOpENS_6AbsFOpENS_7AddCFOpENS_6AddFOpENS_6AddIOpENS_7AllocOpENS_8AllocaOpENS_5AndOpENS_17AssumeAlignmentOpENS_11AtomicRMWOpENS_13AtomicYieldOpENS_8BranchOpENS_14CallIndirectOpENS_6CallOpENS_7CeilFOpENS_6CmpFOpENS_6CmpIOpENS_12CondBranchOpENS_10ConstantOpENS_10CopySignOpENS_5CosOpENS_15CreateComplexOpENS_9DeallocOpENS_5DimOpENS_6DivFOpENS_6Exp2OpENS_5ExpOpENS_16ExtractElementOpENS_7FPExtOpENS_8FPToSIOpENS_9FPTruncOpENS_18GenericAtomicRMWOpENS_4ImOpENS_11IndexCastOpENS_6LoadOpENS_7Log10OpENS_6Log2OpENS_5LogOpENS_12MemRefCastOpENS_6MulFOpENS_6MulIOpENS_6NegFOpENS_4OrOpENS_10PrefetchOpENS_6RankOpENS_4ReOpENS_6RemFOpENS_8ReturnOpENS_7RsqrtOpENS_8SIToFPOpENS_8SelectOpENS_11ShiftLeftOpENS_13SignExtendIOpENS_12SignedDivIOpENS_12SignedRemIOpENS_18SignedShiftRightOpENS_5SinOpENS_7SplatOpENS_6SqrtOpENS_7StoreOpENS_7SubCFOpENS_6SubFOpENS_6SubIOpENS_9SubViewOpENS_6TanhOpENS_12TensorCastOpENS_20TensorFromElementsOpENS_12TensorLoadOpENS_13TensorStoreOpENS_11TruncateIOpENS_14UnsignedDivIOpENS_14UnsignedRemIOpENS_20UnsignedShiftRightOpENS_6ViewOpENS_5XOrOpENS_13ZeroExtendIOpEEEEvv Unexecuted instantiation: _ZN4mlir7Dialect13addOperationsIJNS_6FuncOpENS_8ModuleOpENS_18ModuleTerminatorOpEEEEvv |
188 | | |
189 | | void addOperation(AbstractOperation opInfo); |
190 | | |
191 | | /// This method is used by derived classes to add their types to the set. |
192 | 0 | template <typename... Args> void addTypes() { |
193 | 0 | (void)std::initializer_list<int>{0, (addSymbol(Args::getTypeID()), 0)...}; |
194 | 0 | } |
195 | | |
196 | | /// This method is used by derived classes to add their attributes to the set. |
197 | 0 | template <typename... Args> void addAttributes() { |
198 | 0 | (void)std::initializer_list<int>{0, (addSymbol(Args::getTypeID()), 0)...}; |
199 | 0 | } Unexecuted instantiation: _ZN4mlir7Dialect13addAttributesIJNS_13AffineMapAttrENS_9ArrayAttrENS_8BoolAttrENS_24DenseIntOrFPElementsAttrENS_23DenseStringElementsAttrENS_14DictionaryAttrENS_9FloatAttrENS_13SymbolRefAttrENS_11IntegerAttrENS_14IntegerSetAttrENS_10OpaqueAttrENS_18OpaqueElementsAttrENS_18SparseElementsAttrENS_10StringAttrENS_8TypeAttrENS_8UnitAttrEEEEvv Unexecuted instantiation: _ZN4mlir7Dialect13addAttributesIJNS_11CallSiteLocENS_14FileLineColLocENS_8FusedLocENS_7NameLocENS_9OpaqueLocENS_10UnknownLocEEEEvv |
200 | | |
201 | | /// Enable support for unregistered operations. |
202 | 0 | void allowUnknownOperations(bool allow = true) { unknownOpsAllowed = allow; } |
203 | | |
204 | | /// Enable support for unregistered types. |
205 | 0 | void allowUnknownTypes(bool allow = true) { unknownTypesAllowed = allow; } |
206 | | |
207 | | /// Register a dialect interface with this dialect instance. |
208 | | void addInterface(std::unique_ptr<DialectInterface> interface); |
209 | | |
210 | | /// Register a set of dialect interfaces with this dialect instance. |
211 | 0 | template <typename... Args> void addInterfaces() { |
212 | 0 | (void)std::initializer_list<int>{ |
213 | 0 | 0, (addInterface(std::make_unique<Args>(this)), 0)...}; |
214 | 0 | } Unexecuted instantiation: AffineOps.cpp:_ZN4mlir7Dialect13addInterfacesIJN12_GLOBAL__N_122AffineInlinerInterfaceEEEEvv Unexecuted instantiation: Ops.cpp:_ZN4mlir7Dialect13addInterfacesIJN12_GLOBAL__N_119StdInlinerInterfaceEEEEvv |
215 | | |
216 | | private: |
217 | | // Register a symbol(e.g. type) with its given unique class identifier. |
218 | | void addSymbol(TypeID typeID); |
219 | | |
220 | | Dialect(const Dialect &) = delete; |
221 | | void operator=(Dialect &) = delete; |
222 | | |
223 | | /// Register this dialect object with the specified context. The context |
224 | | /// takes ownership of the heap allocated dialect. |
225 | | void registerDialect(MLIRContext *context); |
226 | | |
227 | | /// The namespace of this dialect. |
228 | | StringRef name; |
229 | | |
230 | | /// This is the context that owns this Dialect object. |
231 | | MLIRContext *context; |
232 | | |
233 | | /// Flag that specifies whether this dialect supports unregistered operations, |
234 | | /// i.e. operations prefixed with the dialect namespace but not registered |
235 | | /// with addOperation. |
236 | | bool unknownOpsAllowed = false; |
237 | | |
238 | | /// Flag that specifies whether this dialect allows unregistered types, i.e. |
239 | | /// types prefixed with the dialect namespace but not registered with addType. |
240 | | /// These types are represented with OpaqueType. |
241 | | bool unknownTypesAllowed = false; |
242 | | |
243 | | /// A collection of registered dialect interfaces. |
244 | | DenseMap<TypeID, std::unique_ptr<DialectInterface>> registeredInterfaces; |
245 | | |
246 | | /// Registers a specific dialect creation function with the global registry. |
247 | | /// Used through the registerDialect template. |
248 | | /// Registrations are deduplicated by dialect TypeID and only the first |
249 | | /// registration will be used. |
250 | | static void |
251 | | registerDialectAllocator(TypeID typeID, |
252 | | const DialectAllocatorFunction &function); |
253 | | template <typename ConcreteDialect> |
254 | | friend void registerDialect(); |
255 | | }; |
256 | | /// Registers all dialects and hooks from the global registries with the |
257 | | /// specified MLIRContext. |
258 | | void registerAllDialects(MLIRContext *context); |
259 | | |
260 | | /// Utility to register a dialect. Client can register their dialect with the |
261 | | /// global registry by calling registerDialect<MyDialect>(); |
262 | | template <typename ConcreteDialect> void registerDialect() { |
263 | | Dialect::registerDialectAllocator(TypeID::get<ConcreteDialect>(), |
264 | | [](MLIRContext *ctx) { |
265 | | // Just allocate the dialect, the context |
266 | | // takes ownership of it. |
267 | | new ConcreteDialect(ctx); |
268 | | }); |
269 | | } |
270 | | |
271 | | /// DialectRegistration provides a global initializer that registers a Dialect |
272 | | /// allocation routine. |
273 | | /// |
274 | | /// Usage: |
275 | | /// |
276 | | /// // At namespace scope. |
277 | | /// static DialectRegistration<MyDialect> Unused; |
278 | | template <typename ConcreteDialect> struct DialectRegistration { |
279 | | DialectRegistration() { registerDialect<ConcreteDialect>(); } |
280 | | }; |
281 | | |
282 | | } // namespace mlir |
283 | | |
284 | | namespace llvm { |
285 | | /// Provide isa functionality for Dialects. |
286 | | template <typename T> |
287 | | struct isa_impl<T, ::mlir::Dialect> { |
288 | | static inline bool doit(const ::mlir::Dialect &dialect) { |
289 | | return T::getDialectNamespace() == dialect.getNamespace(); |
290 | | } |
291 | | }; |
292 | | } // namespace llvm |
293 | | |
294 | | #endif |