Coverage Report

Created: 2020-06-26 05:44

/home/arjun/llvm-project/mlir/include/mlir/IR/Matchers.h
Line
Count
Source (jump to first uncovered line)
1
//===- Matchers.h - Various common matchers ---------------------*- C++ -*-===//
2
//
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
// See https://llvm.org/LICENSE.txt for license information.
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
//
7
//===----------------------------------------------------------------------===//
8
//
9
// This file provides a simple and efficient mechanism for performing general
10
// tree-based pattern matching over MLIR. This mechanism is inspired by LLVM's
11
// include/llvm/IR/PatternMatch.h.
12
//
13
//===----------------------------------------------------------------------===//
14
15
#ifndef MLIR_MATCHERS_H
16
#define MLIR_MATCHERS_H
17
18
#include "mlir/IR/OpDefinition.h"
19
#include "mlir/IR/StandardTypes.h"
20
21
namespace mlir {
22
23
namespace detail {
24
25
/// The matcher that matches a certain kind of Attribute and binds the value
26
/// inside the Attribute.
27
template <
28
    typename AttrClass,
29
    // Require AttrClass to be a derived class from Attribute and get its
30
    // value type
31
    typename ValueType =
32
        typename std::enable_if<std::is_base_of<Attribute, AttrClass>::value,
33
                                AttrClass>::type::ValueType,
34
    // Require the ValueType is not void
35
    typename = typename std::enable_if<!std::is_void<ValueType>::value>::type>
36
struct attr_value_binder {
37
  ValueType *bind_value;
38
39
  /// Creates a matcher instance that binds the value to bv if match succeeds.
40
0
  attr_value_binder(ValueType *bv) : bind_value(bv) {}
41
42
0
  bool match(const Attribute &attr) {
43
0
    if (auto intAttr = attr.dyn_cast<AttrClass>()) {
44
0
      *bind_value = intAttr.getValue();
45
0
      return true;
46
0
    }
47
0
    return false;
48
0
  }
49
};
50
51
/// The matcher that matches operations that have the `ConstantLike` trait.
52
struct constant_op_matcher {
53
0
  bool match(Operation *op) { return op->hasTrait<OpTrait::ConstantLike>(); }
54
};
55
56
/// The matcher that matches operations that have the `ConstantLike` trait, and
57
/// binds the folded attribute value.
58
template <typename AttrT> struct constant_op_binder {
59
  AttrT *bind_value;
60
61
  /// Creates a matcher instance that binds the constant attribute value to
62
  /// bind_value if match succeeds.
63
0
  constant_op_binder(AttrT *bind_value) : bind_value(bind_value) {}
Unexecuted instantiation: _ZN4mlir6detail18constant_op_binderINS_9AttributeEEC2EPS2_
Unexecuted instantiation: _ZN4mlir6detail18constant_op_binderINS_11IntegerAttrEEC2EPS2_
Unexecuted instantiation: _ZN4mlir6detail18constant_op_binderINS_13SymbolRefAttrEEC2EPS2_
64
  /// Creates a matcher instance that doesn't bind if match succeeds.
65
  constant_op_binder() : bind_value(nullptr) {}
66
67
0
  bool match(Operation *op) {
68
0
    if (!op->hasTrait<OpTrait::ConstantLike>())
69
0
      return false;
70
0
71
0
    // Fold the constant to an attribute.
72
0
    SmallVector<OpFoldResult, 1> foldedOp;
73
0
    LogicalResult result = op->fold(/*operands=*/llvm::None, foldedOp);
74
0
    (void)result;
75
0
    assert(succeeded(result) && "expected ConstantLike op to be foldable");
76
0
77
0
    if (auto attr = foldedOp.front().get<Attribute>().dyn_cast<AttrT>()) {
78
0
      if (bind_value)
79
0
        *bind_value = attr;
80
0
      return true;
81
0
    }
82
0
    return false;
83
0
  }
Unexecuted instantiation: _ZN4mlir6detail18constant_op_binderINS_9AttributeEE5matchEPNS_9OperationE
Unexecuted instantiation: _ZN4mlir6detail18constant_op_binderINS_11IntegerAttrEE5matchEPNS_9OperationE
Unexecuted instantiation: _ZN4mlir6detail18constant_op_binderINS_13SymbolRefAttrEE5matchEPNS_9OperationE
84
};
85
86
/// The matcher that matches a constant scalar / vector splat / tensor splat
87
/// integer operation and binds the constant integer value.
88
struct constant_int_op_binder {
89
  IntegerAttr::ValueType *bind_value;
90
91
  /// Creates a matcher instance that binds the value to bv if match succeeds.
92
0
  constant_int_op_binder(IntegerAttr::ValueType *bv) : bind_value(bv) {}
93
94
0
  bool match(Operation *op) {
95
0
    Attribute attr;
96
0
    if (!constant_op_binder<Attribute>(&attr).match(op))
97
0
      return false;
98
0
    auto type = op->getResult(0).getType();
99
0
100
0
    if (type.isa<IntegerType>() || type.isa<IndexType>())
101
0
      return attr_value_binder<IntegerAttr>(bind_value).match(attr);
102
0
    if (type.isa<VectorType>() || type.isa<RankedTensorType>()) {
103
0
      if (auto splatAttr = attr.dyn_cast<SplatElementsAttr>()) {
104
0
        return attr_value_binder<IntegerAttr>(bind_value)
105
0
            .match(splatAttr.getSplatValue());
106
0
      }
107
0
    }
108
0
    return false;
109
0
  }
110
};
111
112
/// The matcher that matches a given target constant scalar / vector splat /
113
/// tensor splat integer value.
114
template <int64_t TargetValue> struct constant_int_value_matcher {
115
0
  bool match(Operation *op) {
116
0
    APInt value;
117
0
    return constant_int_op_binder(&value).match(op) && TargetValue == value;
118
0
  }
Unexecuted instantiation: _ZN4mlir6detail26constant_int_value_matcherILl0EE5matchEPNS_9OperationE
Unexecuted instantiation: _ZN4mlir6detail26constant_int_value_matcherILl1EE5matchEPNS_9OperationE
119
};
120
121
/// The matcher that matches anything except the given target constant scalar /
122
/// vector splat / tensor splat integer value.
123
template <int64_t TargetNotValue> struct constant_int_not_value_matcher {
124
0
  bool match(Operation *op) {
125
0
    APInt value;
126
0
    return constant_int_op_binder(&value).match(op) && TargetNotValue != value;
127
0
  }
128
};
129
130
/// The matcher that matches a certain kind of op.
131
template <typename OpClass> struct op_matcher {
132
0
  bool match(Operation *op) { return isa<OpClass>(op); }
133
};
134
135
/// Detection alias for use with `llvm::is_detected`. It names the result type
/// of calling `match` with an `ArgT` argument on an rvalue `MatcherT`, and is
/// ill-formed (so "not detected") when that call does not compile.
template <class MatcherT, class ArgT>
using has_operation_or_value_matcher_t =
    decltype(std::declval<MatcherT>().match(std::declval<ArgT>()));
140
141
/// Statically switch to a Value matcher.
142
template <typename MatcherClass>
143
typename std::enable_if_t<
144
    llvm::is_detected<detail::has_operation_or_value_matcher_t, MatcherClass,
145
                      Value>::value,
146
    bool>
147
matchOperandOrValueAtIndex(Operation *op, unsigned idx, MatcherClass &matcher) {
148
  return matcher.match(op->getOperand(idx));
149
}
150
151
/// Statically switch to an Operation matcher.
152
template <typename MatcherClass>
153
typename std::enable_if_t<
154
    llvm::is_detected<detail::has_operation_or_value_matcher_t, MatcherClass,
155
                      Operation *>::value,
156
    bool>
157
matchOperandOrValueAtIndex(Operation *op, unsigned idx, MatcherClass &matcher) {
158
  if (auto defOp = op->getOperand(idx).getDefiningOp())
159
    return matcher.match(defOp);
160
  return false;
161
}
162
163
/// Terminal matcher, always returns true.
164
struct AnyValueMatcher {
165
0
  bool match(Value op) const { return true; }
166
};
167
168
/// Binds to a specific value and matches it.
169
struct PatternMatcherValue {
170
0
  PatternMatcherValue(Value val) : value(val) {}
171
0
  bool match(Value val) const { return val == value; }
172
  Value value;
173
};
174
175
/// Calls `callback(std::integral_constant<std::size_t, I>{}, std::get<I>(tuple))`
/// for each index `I` in `Is...`, in increasing index order. Elements are
/// passed as lvalues because `tuple` is a named parameter.
template <typename TupleT, class CallbackT, std::size_t... Is>
constexpr void enumerateImpl(TupleT &&tuple, CallbackT &&callback,
                             std::index_sequence<Is...>) {
  // Expanding inside a braced initializer list guarantees left-to-right
  // evaluation. Each callback result is cast to void so that a returned type
  // overloading operator, cannot hijack the expansion (or make the element
  // non-convertible to int).
  (void)std::initializer_list<int>{
      0, ((void)callback(std::integral_constant<std::size_t, Is>{},
                         std::get<Is>(tuple)),
          0)...};
}
183
184
template <typename... Tys, typename CallbackT>
185
constexpr void enumerate(std::tuple<Tys...> &tuple, CallbackT &&callback) {
186
  detail::enumerateImpl(tuple, std::forward<CallbackT>(callback),
187
                        std::make_index_sequence<sizeof...(Tys)>{});
188
}
189
190
/// RecursivePatternMatcher that composes.
191
template <typename OpType, typename... OperandMatchers>
192
struct RecursivePatternMatcher {
193
  RecursivePatternMatcher(OperandMatchers... matchers)
194
      : operandMatchers(matchers...) {}
195
  bool match(Operation *op) {
196
    if (!isa<OpType>(op) || op->getNumOperands() != sizeof...(OperandMatchers))
197
      return false;
198
    bool res = true;
199
    enumerate(operandMatchers, [&](size_t index, auto &matcher) {
200
      res &= matchOperandOrValueAtIndex(op, index, matcher);
201
    });
202
    return res;
203
  }
204
  std::tuple<OperandMatchers...> operandMatchers;
205
};
206
207
} // end namespace detail
208
209
/// Matches a constant foldable operation.
210
0
inline detail::constant_op_matcher m_Constant() {
211
0
  return detail::constant_op_matcher();
212
0
}
213
214
/// Matches a value from a constant foldable operation and writes the value to
215
/// bind_value.
216
template <typename AttrT>
217
0
inline detail::constant_op_binder<AttrT> m_Constant(AttrT *bind_value) {
218
0
  return detail::constant_op_binder<AttrT>(bind_value);
219
0
}
Unexecuted instantiation: _ZN4mlir10m_ConstantINS_9AttributeEEENS_6detail18constant_op_binderIT_EEPS4_
Unexecuted instantiation: _ZN4mlir10m_ConstantINS_11IntegerAttrEEENS_6detail18constant_op_binderIT_EEPS4_
Unexecuted instantiation: _ZN4mlir10m_ConstantINS_13SymbolRefAttrEEENS_6detail18constant_op_binderIT_EEPS4_
220
221
/// Matches a constant scalar / vector splat / tensor splat integer one.
222
0
inline detail::constant_int_value_matcher<1> m_One() {
223
0
  return detail::constant_int_value_matcher<1>();
224
0
}
225
226
/// Matches the given OpClass.
227
template <typename OpClass> inline detail::op_matcher<OpClass> m_Op() {
228
  return detail::op_matcher<OpClass>();
229
}
230
231
/// Matches a constant scalar / vector splat / tensor splat integer zero.
232
0
inline detail::constant_int_value_matcher<0> m_Zero() {
233
0
  return detail::constant_int_value_matcher<0>();
234
0
}
235
236
/// Matches a constant scalar / vector splat / tensor splat integer that is any
237
/// non-zero value.
238
0
inline detail::constant_int_not_value_matcher<0> m_NonZero() {
239
0
  return detail::constant_int_not_value_matcher<0>();
240
0
}
241
242
/// Entry point for matching a pattern over a Value.
243
template <typename Pattern>
244
0
inline bool matchPattern(Value value, const Pattern &pattern) {
245
0
  // TODO: handle other cases
246
0
  if (auto *op = value.getDefiningOp())
247
0
    return const_cast<Pattern &>(pattern).match(op);
248
0
  return false;
249
0
}
Unexecuted instantiation: _ZN4mlir12matchPatternINS_6detail18constant_op_binderINS_9AttributeEEEEEbNS_5ValueERKT_
Unexecuted instantiation: _ZN4mlir12matchPatternINS_6detail18constant_op_binderINS_11IntegerAttrEEEEEbNS_5ValueERKT_
Unexecuted instantiation: _ZN4mlir12matchPatternINS_6detail26constant_int_value_matcherILl0EEEEEbNS_5ValueERKT_
Unexecuted instantiation: _ZN4mlir12matchPatternINS_6detail10op_matcherINS_15ConstantIndexOpEEEEEbNS_5ValueERKT_
Unexecuted instantiation: _ZN4mlir12matchPatternINS_6detail18constant_op_binderINS_13SymbolRefAttrEEEEEbNS_5ValueERKT_
Unexecuted instantiation: _ZN4mlir12matchPatternINS_6detail30constant_int_not_value_matcherILl0EEEEEbNS_5ValueERKT_
Unexecuted instantiation: _ZN4mlir12matchPatternINS_6detail22constant_int_op_binderEEEbNS_5ValueERKT_
Unexecuted instantiation: _ZN4mlir12matchPatternINS_6detail26constant_int_value_matcherILl1EEEEEbNS_5ValueERKT_
250
251
/// Entry point for matching a pattern over an Operation.
252
template <typename Pattern>
253
0
inline bool matchPattern(Operation *op, const Pattern &pattern) {
254
0
  return const_cast<Pattern &>(pattern).match(op);
255
0
}
Unexecuted instantiation: _ZN4mlir12matchPatternINS_6detail18constant_op_binderINS_9AttributeEEEEEbPNS_9OperationERKT_
Unexecuted instantiation: _ZN4mlir12matchPatternINS_6detail19constant_op_matcherEEEbPNS_9OperationERKT_
256
257
/// Matches a constant holding a scalar/vector/tensor integer (splat) and
258
/// writes the integer value to bind_value.
259
inline detail::constant_int_op_binder
260
0
m_ConstantInt(IntegerAttr::ValueType *bind_value) {
261
0
  return detail::constant_int_op_binder(bind_value);
262
0
}
263
264
template <typename OpType, typename... Matchers>
265
auto m_Op(Matchers... matchers) {
266
  return detail::RecursivePatternMatcher<OpType, Matchers...>(matchers...);
267
}
268
269
namespace matchers {
270
0
inline auto m_Any() { return detail::AnyValueMatcher(); }
271
0
inline auto m_Val(Value v) { return detail::PatternMatcherValue(v); }
272
} // namespace matchers
273
274
} // end namespace mlir
275
276
#endif // MLIR_MATCHERS_H