Coverage Report

Created: 2020-06-26 05:44

/home/arjun/llvm-project/mlir/include/mlir/IR/Dialect.h
Line
Count
Source (jump to first uncovered line)
1
//===- Dialect.h - IR Dialect Description -----------------------*- C++ -*-===//
2
//
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
// See https://llvm.org/LICENSE.txt for license information.
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
//
7
//===----------------------------------------------------------------------===//
8
//
9
// This file defines the 'dialect' abstraction.
10
//
11
//===----------------------------------------------------------------------===//
12
13
#ifndef MLIR_IR_DIALECT_H
14
#define MLIR_IR_DIALECT_H
15
16
#include "mlir/IR/OperationSupport.h"
17
18
namespace mlir {
19
class DialectAsmParser;
20
class DialectAsmPrinter;
21
class DialectInterface;
22
class OpBuilder;
23
class Type;
24
25
using DialectConstantDecodeHook =
26
    std::function<bool(const OpaqueElementsAttr, ElementsAttr &)>;
27
using DialectConstantFoldHook = std::function<LogicalResult(
28
    Operation *, ArrayRef<Attribute>, SmallVectorImpl<Attribute> &)>;
29
using DialectExtractElementHook =
30
    std::function<Attribute(const OpaqueElementsAttr, ArrayRef<uint64_t>)>;
31
using DialectAllocatorFunction = std::function<void(MLIRContext *)>;
32
33
/// Dialects are groups of MLIR operations and behavior associated with the
34
/// entire group.  For example, hooks into other systems for constant folding,
35
/// default named types for asm printing, etc.
36
///
37
/// Instances of the dialect object are global across all MLIRContext's that may
38
/// be active in the process.
39
///
40
class Dialect {
41
public:
42
  virtual ~Dialect();
43
44
  /// Utility function that returns if the given string is a valid dialect
45
  /// namespace.
46
  static bool isValidNamespace(StringRef str);
47
48
0
  MLIRContext *getContext() const { return context; }
49
50
0
  StringRef getNamespace() const { return name; }
51
52
  /// Returns true if this dialect allows for unregistered operations, i.e.
53
  /// operations prefixed with the dialect namespace but not registered with
54
  /// addOperation.
55
0
  bool allowsUnknownOperations() const { return unknownOpsAllowed; }
56
57
  /// Return true if this dialect allows for unregistered types, i.e., types
58
  /// prefixed with the dialect namespace but not registered with addType.
59
  /// These are represented with OpaqueType.
60
0
  bool allowsUnknownTypes() const { return unknownTypesAllowed; }
61
62
  //===--------------------------------------------------------------------===//
63
  // Constant Hooks
64
  //===--------------------------------------------------------------------===//
65
66
  /// Registered fallback constant fold hook for the dialect. Like the constant
67
  /// fold hook of each operation, it attempts to constant fold the operation
68
  /// with the specified constant operand values - the elements in "operands"
69
  /// will correspond directly to the operands of the operation, but may be null
70
  /// if non-constant.  If constant folding is successful, this fills in the
71
  /// `results` vector.  If not, this returns failure and `results` is
72
  /// unspecified.
73
  DialectConstantFoldHook constantFoldHook =
74
      [](Operation *op, ArrayRef<Attribute> operands,
75
0
         SmallVectorImpl<Attribute> &results) { return failure(); };
76
77
  /// Registered hook to decode opaque constants associated with this
78
  /// dialect. The hook function attempts to decode an opaque constant tensor
79
  /// into a tensor with non-opaque content. If decoding is successful, this
80
  /// method returns false and sets 'output' attribute. If not, it returns true
81
  /// and leaves 'output' unspecified. The default hook fails to decode.
82
  DialectConstantDecodeHook decodeHook =
83
0
      [](const OpaqueElementsAttr input, ElementsAttr &output) { return true; };
84
85
  /// Registered hook to extract an element from an opaque constant associated
86
  /// with this dialect. If element has been successfully extracted, this
87
  /// method returns that element. If not, it returns an empty attribute.
88
  /// The default hook fails to extract an element.
89
  DialectExtractElementHook extractElementHook =
90
0
      [](const OpaqueElementsAttr input, ArrayRef<uint64_t> index) {
91
0
        return Attribute();
92
0
      };
93
94
  /// Registered hook to materialize a single constant operation from a given
95
  /// attribute value with the desired resultant type. This method should use
96
  /// the provided builder to create the operation without changing the
97
  /// insertion position. The generated operation is expected to be constant
98
  /// like, i.e. single result, zero operands, non side-effecting, etc. On
99
  /// success, this hook should return the value generated to represent the
100
  /// constant value. Otherwise, it should return null on failure.
101
  virtual Operation *materializeConstant(OpBuilder &builder, Attribute value,
102
0
                                         Type type, Location loc) {
103
0
    return nullptr;
104
0
  }
105
106
  //===--------------------------------------------------------------------===//
107
  // Parsing Hooks
108
  //===--------------------------------------------------------------------===//
109
110
  /// Parse an attribute registered to this dialect. If 'type' is nonnull, it
111
  /// refers to the expected type of the attribute.
112
  virtual Attribute parseAttribute(DialectAsmParser &parser, Type type) const;
113
114
  /// Print an attribute registered to this dialect. Note: The type of the
115
  /// attribute need not be printed by this method as it is always printed by
116
  /// the caller.
117
0
  virtual void printAttribute(Attribute, DialectAsmPrinter &) const {
118
0
    llvm_unreachable("dialect has no registered attribute printing hook");
119
0
  }
120
121
  /// Parse a type registered to this dialect.
122
  virtual Type parseType(DialectAsmParser &parser) const;
123
124
  /// Print a type registered to this dialect.
125
0
  virtual void printType(Type, DialectAsmPrinter &) const {
126
0
    llvm_unreachable("dialect has no registered type printing hook");
127
0
  }
128
129
  //===--------------------------------------------------------------------===//
130
  // Verification Hooks
131
  //===--------------------------------------------------------------------===//
132
133
  /// Verify an attribute from this dialect on the argument at 'argIndex' for
134
  /// the region at 'regionIndex' on the given operation. Returns failure if
135
  /// the verification failed, success otherwise. This hook may optionally be
136
  /// invoked from any operation containing a region.
137
  virtual LogicalResult verifyRegionArgAttribute(Operation *,
138
                                                 unsigned regionIndex,
139
                                                 unsigned argIndex,
140
                                                 NamedAttribute);
141
142
  /// Verify an attribute from this dialect on the result at 'resultIndex' for
143
  /// the region at 'regionIndex' on the given operation. Returns failure if
144
  /// the verification failed, success otherwise. This hook may optionally be
145
  /// invoked from any operation containing a region.
146
  virtual LogicalResult verifyRegionResultAttribute(Operation *,
147
                                                    unsigned regionIndex,
148
                                                    unsigned resultIndex,
149
                                                    NamedAttribute);
150
151
  /// Verify an attribute from this dialect on the given operation. Returns
152
  /// failure if the verification failed, success otherwise.
153
0
  virtual LogicalResult verifyOperationAttribute(Operation *, NamedAttribute) {
154
0
    return success();
155
0
  }
156
157
  //===--------------------------------------------------------------------===//
158
  // Interfaces
159
  //===--------------------------------------------------------------------===//
160
161
  /// Lookup an interface for the given ID if one is registered, otherwise
162
  /// nullptr.
163
0
  const DialectInterface *getRegisteredInterface(TypeID interfaceID) {
164
0
    auto it = registeredInterfaces.find(interfaceID);
165
0
    return it != registeredInterfaces.end() ? it->getSecond().get() : nullptr;
166
0
  }
167
  template <typename InterfaceT> const InterfaceT *getRegisteredInterface() {
168
    return static_cast<const InterfaceT *>(
169
        getRegisteredInterface(InterfaceT::getInterfaceID()));
170
  }
171
172
protected:
173
  /// The constructor takes a unique namespace for this dialect as well as the
174
  /// context to bind to.
175
  /// Note: The namespace must not contain '.' characters.
176
  /// Note: All operations belonging to this dialect must have names starting
177
  ///       with the namespace followed by '.'.
178
  /// Example:
179
  ///       - "tf" for the TensorFlow ops like "tf.add".
180
  Dialect(StringRef name, MLIRContext *context);
181
182
  /// This method is used by derived classes to add their operations to the set.
183
  ///
184
0
  template <typename... Args> void addOperations() {
185
0
    (void)std::initializer_list<int>{
186
0
        0, (addOperation(AbstractOperation::get<Args>(*this)), 0)...};
187
0
  }
Unexecuted instantiation: _ZN4mlir7Dialect13addOperationsIJNS_16AffineDmaStartOpENS_15AffineDmaWaitOpENS_13AffineApplyOpENS_11AffineForOpENS_10AffineIfOpENS_12AffineLoadOpENS_11AffineMaxOpENS_11AffineMinOpENS_16AffineParallelOpENS_16AffinePrefetchOpENS_13AffineStoreOpENS_18AffineTerminatorOpENS_18AffineVectorLoadOpENS_19AffineVectorStoreOpEEEEvv
Unexecuted instantiation: _ZN4mlir7Dialect13addOperationsIJNS_10DmaStartOpENS_9DmaWaitOpENS_6AbsFOpENS_7AddCFOpENS_6AddFOpENS_6AddIOpENS_7AllocOpENS_8AllocaOpENS_5AndOpENS_17AssumeAlignmentOpENS_11AtomicRMWOpENS_13AtomicYieldOpENS_8BranchOpENS_14CallIndirectOpENS_6CallOpENS_7CeilFOpENS_6CmpFOpENS_6CmpIOpENS_12CondBranchOpENS_10ConstantOpENS_10CopySignOpENS_5CosOpENS_15CreateComplexOpENS_9DeallocOpENS_5DimOpENS_6DivFOpENS_6Exp2OpENS_5ExpOpENS_16ExtractElementOpENS_7FPExtOpENS_8FPToSIOpENS_9FPTruncOpENS_18GenericAtomicRMWOpENS_4ImOpENS_11IndexCastOpENS_6LoadOpENS_7Log10OpENS_6Log2OpENS_5LogOpENS_12MemRefCastOpENS_6MulFOpENS_6MulIOpENS_6NegFOpENS_4OrOpENS_10PrefetchOpENS_6RankOpENS_4ReOpENS_6RemFOpENS_8ReturnOpENS_7RsqrtOpENS_8SIToFPOpENS_8SelectOpENS_11ShiftLeftOpENS_13SignExtendIOpENS_12SignedDivIOpENS_12SignedRemIOpENS_18SignedShiftRightOpENS_5SinOpENS_7SplatOpENS_6SqrtOpENS_7StoreOpENS_7SubCFOpENS_6SubFOpENS_6SubIOpENS_9SubViewOpENS_6TanhOpENS_12TensorCastOpENS_20TensorFromElementsOpENS_12TensorLoadOpENS_13TensorStoreOpENS_11TruncateIOpENS_14UnsignedDivIOpENS_14UnsignedRemIOpENS_20UnsignedShiftRightOpENS_6ViewOpENS_5XOrOpENS_13ZeroExtendIOpEEEEvv
Unexecuted instantiation: _ZN4mlir7Dialect13addOperationsIJNS_6FuncOpENS_8ModuleOpENS_18ModuleTerminatorOpEEEEvv
188
189
  void addOperation(AbstractOperation opInfo);
190
191
  /// This method is used by derived classes to add their types to the set.
192
0
  template <typename... Args> void addTypes() {
193
0
    (void)std::initializer_list<int>{0, (addSymbol(Args::getTypeID()), 0)...};
194
0
  }
195
196
  /// This method is used by derived classes to add their attributes to the set.
197
0
  template <typename... Args> void addAttributes() {
198
0
    (void)std::initializer_list<int>{0, (addSymbol(Args::getTypeID()), 0)...};
199
0
  }
Unexecuted instantiation: _ZN4mlir7Dialect13addAttributesIJNS_13AffineMapAttrENS_9ArrayAttrENS_8BoolAttrENS_24DenseIntOrFPElementsAttrENS_23DenseStringElementsAttrENS_14DictionaryAttrENS_9FloatAttrENS_13SymbolRefAttrENS_11IntegerAttrENS_14IntegerSetAttrENS_10OpaqueAttrENS_18OpaqueElementsAttrENS_18SparseElementsAttrENS_10StringAttrENS_8TypeAttrENS_8UnitAttrEEEEvv
Unexecuted instantiation: _ZN4mlir7Dialect13addAttributesIJNS_11CallSiteLocENS_14FileLineColLocENS_8FusedLocENS_7NameLocENS_9OpaqueLocENS_10UnknownLocEEEEvv
200
201
  /// Enable support for unregistered operations.
202
0
  void allowUnknownOperations(bool allow = true) { unknownOpsAllowed = allow; }
203
204
  /// Enable support for unregistered types.
205
0
  void allowUnknownTypes(bool allow = true) { unknownTypesAllowed = allow; }
206
207
  /// Register a dialect interface with this dialect instance.
208
  void addInterface(std::unique_ptr<DialectInterface> interface);
209
210
  /// Register a set of dialect interfaces with this dialect instance.
211
0
  template <typename... Args> void addInterfaces() {
212
0
    (void)std::initializer_list<int>{
213
0
        0, (addInterface(std::make_unique<Args>(this)), 0)...};
214
0
  }
Unexecuted instantiation: AffineOps.cpp:_ZN4mlir7Dialect13addInterfacesIJN12_GLOBAL__N_122AffineInlinerInterfaceEEEEvv
Unexecuted instantiation: Ops.cpp:_ZN4mlir7Dialect13addInterfacesIJN12_GLOBAL__N_119StdInlinerInterfaceEEEEvv
215
216
private:
217
  // Register a symbol(e.g. type) with its given unique class identifier.
218
  void addSymbol(TypeID typeID);
219
220
  Dialect(const Dialect &) = delete;
221
  void operator=(Dialect &) = delete;
222
223
  /// Register this dialect object with the specified context.  The context
224
  /// takes ownership of the heap allocated dialect.
225
  void registerDialect(MLIRContext *context);
226
227
  /// The namespace of this dialect.
228
  StringRef name;
229
230
  /// This is the context that owns this Dialect object.
231
  MLIRContext *context;
232
233
  /// Flag that specifies whether this dialect supports unregistered operations,
234
  /// i.e. operations prefixed with the dialect namespace but not registered
235
  /// with addOperation.
236
  bool unknownOpsAllowed = false;
237
238
  /// Flag that specifies whether this dialect allows unregistered types, i.e.
239
  /// types prefixed with the dialect namespace but not registered with addType.
240
  /// These types are represented with OpaqueType.
241
  bool unknownTypesAllowed = false;
242
243
  /// A collection of registered dialect interfaces.
244
  DenseMap<TypeID, std::unique_ptr<DialectInterface>> registeredInterfaces;
245
246
  /// Registers a specific dialect creation function with the global registry.
247
  /// Used through the registerDialect template.
248
  /// Registrations are deduplicated by dialect TypeID and only the first
249
  /// registration will be used.
250
  static void
251
  registerDialectAllocator(TypeID typeID,
252
                           const DialectAllocatorFunction &function);
253
  template <typename ConcreteDialect>
254
  friend void registerDialect();
255
};
256
/// Registers all dialects and hooks from the global registries with the
257
/// specified MLIRContext.
258
void registerAllDialects(MLIRContext *context);
259
260
/// Utility to register a dialect. Client can register their dialect with the
261
/// global registry by calling registerDialect<MyDialect>();
262
template <typename ConcreteDialect> void registerDialect() {
263
  Dialect::registerDialectAllocator(TypeID::get<ConcreteDialect>(),
264
                                    [](MLIRContext *ctx) {
265
                                      // Just allocate the dialect, the context
266
                                      // takes ownership of it.
267
                                      new ConcreteDialect(ctx);
268
                                    });
269
}
270
271
/// DialectRegistration provides a global initializer that registers a Dialect
272
/// allocation routine.
273
///
274
/// Usage:
275
///
276
///   // At namespace scope.
277
///   static DialectRegistration<MyDialect> Unused;
278
template <typename ConcreteDialect> struct DialectRegistration {
279
  DialectRegistration() { registerDialect<ConcreteDialect>(); }
280
};
281
282
} // namespace mlir
283
284
namespace llvm {
285
/// Provide isa functionality for Dialects.
286
template <typename T>
287
struct isa_impl<T, ::mlir::Dialect> {
288
  static inline bool doit(const ::mlir::Dialect &dialect) {
289
    return T::getDialectNamespace() == dialect.getNamespace();
290
  }
291
};
292
} // namespace llvm
293
294
#endif