Coverage Report

Created: 2020-06-26 05:44

/home/arjun/llvm-project/mlir/include/mlir/IR/DialectHooks.h
Line
Count
Source (jump to first uncovered line)
1
//===- DialectHooks.h - MLIR DialectHooks mechanism -------------*- C++ -*-===//
2
//
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
// See https://llvm.org/LICENSE.txt for license information.
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
//
7
//===----------------------------------------------------------------------===//
8
//
9
// This file defines the abstraction and registration mechanism for dialect hooks.
10
//
11
//===----------------------------------------------------------------------===//
12
13
#ifndef MLIR_IR_DIALECT_HOOKS_H
14
#define MLIR_IR_DIALECT_HOOKS_H
15
16
#include "mlir/IR/Dialect.h"
17
#include "llvm/Support/raw_ostream.h"
18
19
namespace mlir {
20
using DialectHooksSetter = std::function<void(MLIRContext *)>;
21
22
/// Dialect hooks allow external components to register their functions to
23
/// be called for specific tasks specialized per dialect, such as decoding
24
/// of opaque constants. To register concrete dialect hooks, one should
25
/// define a DialectHooks subclass and use it as the template
26
/// argument to DialectHooksRegistration. For example,
27
///     class MyHooks : public DialectHooks {...};
28
///     static DialectHooksRegistration<MyHooks> hooksReg("dialect_namespace");
29
/// The subclass should redefine the DialectHooks getters for the hooks it
/// supports. The getters are non-virtual; each default returns nullptr, and
/// registerDialectHooks only installs hooks whose getter returns non-null.
30
class DialectHooks {
31
public:
32
  // Returns the hook used to constant fold an operation (default: nullptr).
33
0
  DialectConstantFoldHook getConstantFoldHook() { return nullptr; }
34
  // Returns the hook used to decode an opaque constant tensor (default: nullptr).
35
0
  DialectConstantDecodeHook getDecodeHook() { return nullptr; }
36
  // Returns the hook used to extract an element of an opaque constant tensor
  // (default: nullptr).
37
0
  DialectExtractElementHook getExtractElementHook() { return nullptr; }
38
39
private:
40
  /// Registers a function that will set hooks in the registered dialects.
41
  /// Registrations are deduplicated by `typeID` (the hooks class's TypeID when
42
  /// called via registerDialectHooks); only the first registration is used.
43
  static void registerDialectHooksSetter(TypeID typeID,
44
                                         const DialectHooksSetter &function);
45
  template <typename ConcreteHooks>
46
  friend void registerDialectHooks(StringRef dialectName);
47
};
48
49
// Namespace-scope declaration with the same parameter list as the private
// static DialectHooks::registerDialectHooksSetter.
// NOTE(review): registerDialectHooks in this header calls the class-scope
// static, not this function -- confirm this declaration is still needed.
void registerDialectHooksSetter(TypeID typeID,
50
                                const DialectHooksSetter &function);
51
52
/// Utility to register dialect hooks. Clients register their dialect hooks
53
/// with the global registry by calling
54
///   registerDialectHooks<MyHooks>("dialect_namespace");
/// `ConcreteHooks` must be default-constructible and provide the DialectHooks
/// getters. The registered setter looks up `dialectName` in the context it is
/// given, prints an error and aborts if no such dialect is registered, and
/// otherwise installs every non-null hook returned by `ConcreteHooks`.
55
template <typename ConcreteHooks>
56
void registerDialectHooks(StringRef dialectName) {
57
  DialectHooks::registerDialectHooksSetter(
58
      TypeID::get<ConcreteHooks>(),
      // Capture an owned copy of the name: StringRef does not own its
      // characters, and the setter is handed off as a std::function that is
      // presumably invoked after this call (and the caller's buffer) is gone.
      [name = dialectName.str()](MLIRContext *ctx) {
59
        Dialect *dialect = ctx->getRegisteredDialect(name);
60
        if (!dialect) {
61
          llvm::errs() << "error: cannot register hooks for unknown dialect '"
62
                       << name << "'\n";
63
          abort();
64
        }
65
        // Install only the hooks the subclass provides; a null getter result
        // leaves the dialect's existing hook untouched.
66
        ConcreteHooks hooks;
67
        if (auto h = hooks.getConstantFoldHook())
68
          dialect->constantFoldHook = h;
69
        if (auto h = hooks.getDecodeHook())
70
          dialect->decodeHook = h;
71
        if (auto h = hooks.getExtractElementHook())
72
          dialect->extractElementHook = h;
73
      });
74
}
75
76
/// DialectHooksRegistration registers the hooks setter for `ConcreteHooks`
77
/// from its constructor, so defining a static instance performs the registration.
78
/// Usage:
79
///
80
///   // At namespace scope.
81
///   static DialectHooksRegistration<MyHooks> Unused("dialect_namespace");
82
template <typename ConcreteHooks> struct DialectHooksRegistration {
83
  // Forwards `dialectName` to registerDialectHooks<ConcreteHooks>.
  DialectHooksRegistration(StringRef dialectName) {
84
    registerDialectHooks<ConcreteHooks>(dialectName);
85
  }
86
};
87
88
} // namespace mlir
89
90
#endif