/home/arjun/llvm-project/mlir/include/mlir/IR/DialectHooks.h
Line | Count | Source (jump to first uncovered line) |
1 | | //===- DialectHooks.h - MLIR DialectHooks mechanism -------------*- C++ -*-===// |
2 | | // |
3 | | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | | // See https://llvm.org/LICENSE.txt for license information. |
5 | | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | | // |
7 | | //===----------------------------------------------------------------------===// |
8 | | // |
9 | | // This file defines abstraction and registration mechanism for dialect hooks. |
10 | | // |
11 | | //===----------------------------------------------------------------------===// |
12 | | |
13 | | #ifndef MLIR_IR_DIALECT_HOOKS_H |
14 | | #define MLIR_IR_DIALECT_HOOKS_H |
15 | | |
#include "mlir/IR/Dialect.h"
#include "llvm/Support/raw_ostream.h"
#include <cstdlib>
#include <functional>
18 | | |
19 | | namespace mlir { |
20 | | using DialectHooksSetter = std::function<void(MLIRContext *)>; |
21 | | |
22 | | /// Dialect hooks allow external components to register their functions to |
23 | | /// be called for specific tasks specialized per dialect, such as decoding |
24 | | /// of opaque constants. To register concrete dialect hooks, one should |
25 | | /// define a DialectHooks subclass and use it as a template |
26 | | /// argument to DialectHooksRegistration. For example, |
27 | | /// class MyHooks : public DialectHooks {...}; |
28 | | /// static DialectHooksRegistration<MyHooks, MyDialect> hooksReg; |
29 | | /// The subclass should override DialectHook methods for supported hooks. |
30 | | class DialectHooks { |
31 | | public: |
32 | | // Returns hook to constant fold an operation. |
33 | 0 | DialectConstantFoldHook getConstantFoldHook() { return nullptr; } |
34 | | // Returns hook to decode opaque constant tensor. |
35 | 0 | DialectConstantDecodeHook getDecodeHook() { return nullptr; } |
36 | | // Returns hook to extract an element of an opaque constant tensor. |
37 | 0 | DialectExtractElementHook getExtractElementHook() { return nullptr; } |
38 | | |
39 | | private: |
40 | | /// Registers a function that will set hooks in the registered dialects. |
41 | | /// Registrations are deduplicated by dialect TypeID and only the first |
42 | | /// registration will be used. |
43 | | static void registerDialectHooksSetter(TypeID typeID, |
44 | | const DialectHooksSetter &function); |
45 | | template <typename ConcreteHooks> |
46 | | friend void registerDialectHooks(StringRef dialectName); |
47 | | }; |
48 | | |
49 | | void registerDialectHooksSetter(TypeID typeID, |
50 | | const DialectHooksSetter &function); |
51 | | |
52 | | /// Utility to register dialect hooks. Client can register their dialect hooks |
53 | | /// with the global registry by calling |
54 | | /// registerDialectHooks<MyHooks>("dialect_namespace"); |
55 | | template <typename ConcreteHooks> |
56 | | void registerDialectHooks(StringRef dialectName) { |
57 | | DialectHooks::registerDialectHooksSetter( |
58 | | TypeID::get<ConcreteHooks>(), [dialectName](MLIRContext *ctx) { |
59 | | Dialect *dialect = ctx->getRegisteredDialect(dialectName); |
60 | | if (!dialect) { |
61 | | llvm::errs() << "error: cannot register hooks for unknown dialect '" |
62 | | << dialectName << "'\n"; |
63 | | abort(); |
64 | | } |
65 | | // Set hooks. |
66 | | ConcreteHooks hooks; |
67 | | if (auto h = hooks.getConstantFoldHook()) |
68 | | dialect->constantFoldHook = h; |
69 | | if (auto h = hooks.getDecodeHook()) |
70 | | dialect->decodeHook = h; |
71 | | if (auto h = hooks.getExtractElementHook()) |
72 | | dialect->extractElementHook = h; |
73 | | }); |
74 | | } |
75 | | |
76 | | /// DialectHooksRegistration provides a global initializer that registers |
77 | | /// a dialect hooks setter routine. |
78 | | /// Usage: |
79 | | /// |
80 | | /// // At namespace scope. |
81 | | /// static DialectHooksRegistration<MyHooks> Unused("dialect_namespace"); |
82 | | template <typename ConcreteHooks> struct DialectHooksRegistration { |
83 | | DialectHooksRegistration(StringRef dialectName) { |
84 | | registerDialectHooks<ConcreteHooks>(dialectName); |
85 | | } |
86 | | }; |
87 | | |
88 | | } // namespace mlir |
89 | | |
90 | | #endif |