/home/arjun/llvm-project/mlir/include/mlir/IR/Matchers.h
Line | Count | Source (jump to first uncovered line) |
1 | | //===- Matchers.h - Various common matchers ---------------------*- C++ -*-===// |
2 | | // |
3 | | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | | // See https://llvm.org/LICENSE.txt for license information. |
5 | | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | | // |
7 | | //===----------------------------------------------------------------------===// |
8 | | // |
9 | | // This file provides a simple and efficient mechanism for performing general |
10 | | // tree-based pattern matching over MLIR. This mechanism is inspired by LLVM's |
11 | | // include/llvm/IR/PatternMatch.h. |
12 | | // |
13 | | //===----------------------------------------------------------------------===// |
14 | | |
15 | | #ifndef MLIR_MATCHERS_H |
16 | | #define MLIR_MATCHERS_H |
17 | | |
18 | | #include "mlir/IR/OpDefinition.h" |
19 | | #include "mlir/IR/StandardTypes.h" |
20 | | |
21 | | namespace mlir { |
22 | | |
23 | | namespace detail { |
24 | | |
25 | | /// The matcher that matches a certain kind of Attribute and binds the value |
26 | | /// inside the Attribute. |
27 | | template < |
28 | | typename AttrClass, |
29 | | // Require AttrClass to be a derived class from Attribute and get its |
30 | | // value type |
31 | | typename ValueType = |
32 | | typename std::enable_if<std::is_base_of<Attribute, AttrClass>::value, |
33 | | AttrClass>::type::ValueType, |
34 | | // Require the ValueType is not void |
35 | | typename = typename std::enable_if<!std::is_void<ValueType>::value>::type> |
36 | | struct attr_value_binder { |
37 | | ValueType *bind_value; |
38 | | |
39 | | /// Creates a matcher instance that binds the value to bv if match succeeds. |
40 | 0 | attr_value_binder(ValueType *bv) : bind_value(bv) {} |
41 | | |
42 | 0 | bool match(const Attribute &attr) { |
43 | 0 | if (auto intAttr = attr.dyn_cast<AttrClass>()) { |
44 | 0 | *bind_value = intAttr.getValue(); |
45 | 0 | return true; |
46 | 0 | } |
47 | 0 | return false; |
48 | 0 | } |
49 | | }; |
50 | | |
51 | | /// The matcher that matches operations that have the `ConstantLike` trait. |
52 | | struct constant_op_matcher { |
53 | 0 | bool match(Operation *op) { return op->hasTrait<OpTrait::ConstantLike>(); } |
54 | | }; |
55 | | |
56 | | /// The matcher that matches operations that have the `ConstantLike` trait, and |
57 | | /// binds the folded attribute value. |
58 | | template <typename AttrT> struct constant_op_binder { |
59 | | AttrT *bind_value; |
60 | | |
61 | | /// Creates a matcher instance that binds the constant attribute value to |
62 | | /// bind_value if match succeeds. |
63 | 0 | constant_op_binder(AttrT *bind_value) : bind_value(bind_value) {} Unexecuted instantiation: _ZN4mlir6detail18constant_op_binderINS_9AttributeEEC2EPS2_ Unexecuted instantiation: _ZN4mlir6detail18constant_op_binderINS_11IntegerAttrEEC2EPS2_ Unexecuted instantiation: _ZN4mlir6detail18constant_op_binderINS_13SymbolRefAttrEEC2EPS2_ |
64 | | /// Creates a matcher instance that doesn't bind if match succeeds. |
65 | | constant_op_binder() : bind_value(nullptr) {} |
66 | | |
67 | 0 | bool match(Operation *op) { |
68 | 0 | if (!op->hasTrait<OpTrait::ConstantLike>()) |
69 | 0 | return false; |
70 | 0 | |
71 | 0 | // Fold the constant to an attribute. |
72 | 0 | SmallVector<OpFoldResult, 1> foldedOp; |
73 | 0 | LogicalResult result = op->fold(/*operands=*/llvm::None, foldedOp); |
74 | 0 | (void)result; |
75 | 0 | assert(succeeded(result) && "expected ConstantLike op to be foldable"); |
76 | 0 |
|
77 | 0 | if (auto attr = foldedOp.front().get<Attribute>().dyn_cast<AttrT>()) { |
78 | 0 | if (bind_value) |
79 | 0 | *bind_value = attr; |
80 | 0 | return true; |
81 | 0 | } |
82 | 0 | return false; |
83 | 0 | } Unexecuted instantiation: _ZN4mlir6detail18constant_op_binderINS_9AttributeEE5matchEPNS_9OperationE Unexecuted instantiation: _ZN4mlir6detail18constant_op_binderINS_11IntegerAttrEE5matchEPNS_9OperationE Unexecuted instantiation: _ZN4mlir6detail18constant_op_binderINS_13SymbolRefAttrEE5matchEPNS_9OperationE |
84 | | }; |
85 | | |
86 | | /// The matcher that matches a constant scalar / vector splat / tensor splat |
87 | | /// integer operation and binds the constant integer value. |
88 | | struct constant_int_op_binder { |
89 | | IntegerAttr::ValueType *bind_value; |
90 | | |
91 | | /// Creates a matcher instance that binds the value to bv if match succeeds. |
92 | 0 | constant_int_op_binder(IntegerAttr::ValueType *bv) : bind_value(bv) {} |
93 | | |
94 | 0 | bool match(Operation *op) { |
95 | 0 | Attribute attr; |
96 | 0 | if (!constant_op_binder<Attribute>(&attr).match(op)) |
97 | 0 | return false; |
98 | 0 | auto type = op->getResult(0).getType(); |
99 | 0 |
|
100 | 0 | if (type.isa<IntegerType>() || type.isa<IndexType>()) |
101 | 0 | return attr_value_binder<IntegerAttr>(bind_value).match(attr); |
102 | 0 | if (type.isa<VectorType>() || type.isa<RankedTensorType>()) { |
103 | 0 | if (auto splatAttr = attr.dyn_cast<SplatElementsAttr>()) { |
104 | 0 | return attr_value_binder<IntegerAttr>(bind_value) |
105 | 0 | .match(splatAttr.getSplatValue()); |
106 | 0 | } |
107 | 0 | } |
108 | 0 | return false; |
109 | 0 | } |
110 | | }; |
111 | | |
112 | | /// The matcher that matches a given target constant scalar / vector splat / |
113 | | /// tensor splat integer value. |
114 | | template <int64_t TargetValue> struct constant_int_value_matcher { |
115 | 0 | bool match(Operation *op) { |
116 | 0 | APInt value; |
117 | 0 | return constant_int_op_binder(&value).match(op) && TargetValue == value; |
118 | 0 | } Unexecuted instantiation: _ZN4mlir6detail26constant_int_value_matcherILl0EE5matchEPNS_9OperationE Unexecuted instantiation: _ZN4mlir6detail26constant_int_value_matcherILl1EE5matchEPNS_9OperationE |
119 | | }; |
120 | | |
121 | | /// The matcher that matches anything except the given target constant scalar / |
122 | | /// vector splat / tensor splat integer value. |
123 | | template <int64_t TargetNotValue> struct constant_int_not_value_matcher { |
124 | 0 | bool match(Operation *op) { |
125 | 0 | APInt value; |
126 | 0 | return constant_int_op_binder(&value).match(op) && TargetNotValue != value; |
127 | 0 | } |
128 | | }; |
129 | | |
130 | | /// The matcher that matches a certain kind of op. |
131 | | template <typename OpClass> struct op_matcher { |
132 | 0 | bool match(Operation *op) { return isa<OpClass>(op); } |
133 | | }; |
134 | | |
135 | | /// Trait to check whether T provides a 'match' method with type |
136 | | /// `OperationOrValue`. |
137 | | template <typename T, typename OperationOrValue> |
138 | | using has_operation_or_value_matcher_t = |
139 | | decltype(std::declval<T>().match(std::declval<OperationOrValue>())); |
140 | | |
141 | | /// Statically switch to a Value matcher. |
142 | | template <typename MatcherClass> |
143 | | typename std::enable_if_t< |
144 | | llvm::is_detected<detail::has_operation_or_value_matcher_t, MatcherClass, |
145 | | Value>::value, |
146 | | bool> |
147 | | matchOperandOrValueAtIndex(Operation *op, unsigned idx, MatcherClass &matcher) { |
148 | | return matcher.match(op->getOperand(idx)); |
149 | | } |
150 | | |
151 | | /// Statically switch to an Operation matcher. |
152 | | template <typename MatcherClass> |
153 | | typename std::enable_if_t< |
154 | | llvm::is_detected<detail::has_operation_or_value_matcher_t, MatcherClass, |
155 | | Operation *>::value, |
156 | | bool> |
157 | | matchOperandOrValueAtIndex(Operation *op, unsigned idx, MatcherClass &matcher) { |
158 | | if (auto defOp = op->getOperand(idx).getDefiningOp()) |
159 | | return matcher.match(defOp); |
160 | | return false; |
161 | | } |
162 | | |
163 | | /// Terminal matcher, always returns true. |
164 | | struct AnyValueMatcher { |
165 | 0 | bool match(Value op) const { return true; } |
166 | | }; |
167 | | |
168 | | /// Binds to a specific value and matches it. |
169 | | struct PatternMatcherValue { |
170 | 0 | PatternMatcherValue(Value val) : value(val) {} |
171 | 0 | bool match(Value val) const { return val == value; } |
172 | | Value value; |
173 | | }; |
174 | | |
175 | | template <typename TupleT, class CallbackT, std::size_t... Is> |
176 | | constexpr void enumerateImpl(TupleT &&tuple, CallbackT &&callback, |
177 | | std::index_sequence<Is...>) { |
178 | | (void)std::initializer_list<int>{ |
179 | | 0, |
180 | | (callback(std::integral_constant<std::size_t, Is>{}, std::get<Is>(tuple)), |
181 | | 0)...}; |
182 | | } |
183 | | |
184 | | template <typename... Tys, typename CallbackT> |
185 | | constexpr void enumerate(std::tuple<Tys...> &tuple, CallbackT &&callback) { |
186 | | detail::enumerateImpl(tuple, std::forward<CallbackT>(callback), |
187 | | std::make_index_sequence<sizeof...(Tys)>{}); |
188 | | } |
189 | | |
190 | | /// RecursivePatternMatcher that composes. |
191 | | template <typename OpType, typename... OperandMatchers> |
192 | | struct RecursivePatternMatcher { |
193 | | RecursivePatternMatcher(OperandMatchers... matchers) |
194 | | : operandMatchers(matchers...) {} |
195 | | bool match(Operation *op) { |
196 | | if (!isa<OpType>(op) || op->getNumOperands() != sizeof...(OperandMatchers)) |
197 | | return false; |
198 | | bool res = true; |
199 | | enumerate(operandMatchers, [&](size_t index, auto &matcher) { |
200 | | res &= matchOperandOrValueAtIndex(op, index, matcher); |
201 | | }); |
202 | | return res; |
203 | | } |
204 | | std::tuple<OperandMatchers...> operandMatchers; |
205 | | }; |
206 | | |
207 | | } // end namespace detail |
208 | | |
209 | | /// Matches a constant foldable operation. |
210 | 0 | inline detail::constant_op_matcher m_Constant() { |
211 | 0 | return detail::constant_op_matcher(); |
212 | 0 | } |
213 | | |
214 | | /// Matches a value from a constant foldable operation and writes the value to |
215 | | /// bind_value. |
216 | | template <typename AttrT> |
217 | 0 | inline detail::constant_op_binder<AttrT> m_Constant(AttrT *bind_value) { |
218 | 0 | return detail::constant_op_binder<AttrT>(bind_value); |
219 | 0 | } Unexecuted instantiation: _ZN4mlir10m_ConstantINS_9AttributeEEENS_6detail18constant_op_binderIT_EEPS4_ Unexecuted instantiation: _ZN4mlir10m_ConstantINS_11IntegerAttrEEENS_6detail18constant_op_binderIT_EEPS4_ Unexecuted instantiation: _ZN4mlir10m_ConstantINS_13SymbolRefAttrEEENS_6detail18constant_op_binderIT_EEPS4_ |
220 | | |
221 | | /// Matches a constant scalar / vector splat / tensor splat integer one. |
222 | 0 | inline detail::constant_int_value_matcher<1> m_One() { |
223 | 0 | return detail::constant_int_value_matcher<1>(); |
224 | 0 | } |
225 | | |
226 | | /// Matches the given OpClass. |
227 | | template <typename OpClass> inline detail::op_matcher<OpClass> m_Op() { |
228 | | return detail::op_matcher<OpClass>(); |
229 | | } |
230 | | |
231 | | /// Matches a constant scalar / vector splat / tensor splat integer zero. |
232 | 0 | inline detail::constant_int_value_matcher<0> m_Zero() { |
233 | 0 | return detail::constant_int_value_matcher<0>(); |
234 | 0 | } |
235 | | |
236 | | /// Matches a constant scalar / vector splat / tensor splat integer that is any |
237 | | /// non-zero value. |
238 | 0 | inline detail::constant_int_not_value_matcher<0> m_NonZero() { |
239 | 0 | return detail::constant_int_not_value_matcher<0>(); |
240 | 0 | } |
241 | | |
242 | | /// Entry point for matching a pattern over a Value. |
243 | | template <typename Pattern> |
244 | 0 | inline bool matchPattern(Value value, const Pattern &pattern) { |
245 | 0 | // TODO: handle other cases |
246 | 0 | if (auto *op = value.getDefiningOp()) |
247 | 0 | return const_cast<Pattern &>(pattern).match(op); |
248 | 0 | return false; |
249 | 0 | } Unexecuted instantiation: _ZN4mlir12matchPatternINS_6detail18constant_op_binderINS_9AttributeEEEEEbNS_5ValueERKT_ Unexecuted instantiation: _ZN4mlir12matchPatternINS_6detail18constant_op_binderINS_11IntegerAttrEEEEEbNS_5ValueERKT_ Unexecuted instantiation: _ZN4mlir12matchPatternINS_6detail26constant_int_value_matcherILl0EEEEEbNS_5ValueERKT_ Unexecuted instantiation: _ZN4mlir12matchPatternINS_6detail10op_matcherINS_15ConstantIndexOpEEEEEbNS_5ValueERKT_ Unexecuted instantiation: _ZN4mlir12matchPatternINS_6detail18constant_op_binderINS_13SymbolRefAttrEEEEEbNS_5ValueERKT_ Unexecuted instantiation: _ZN4mlir12matchPatternINS_6detail30constant_int_not_value_matcherILl0EEEEEbNS_5ValueERKT_ Unexecuted instantiation: _ZN4mlir12matchPatternINS_6detail22constant_int_op_binderEEEbNS_5ValueERKT_ Unexecuted instantiation: _ZN4mlir12matchPatternINS_6detail26constant_int_value_matcherILl1EEEEEbNS_5ValueERKT_ |
250 | | |
251 | | /// Entry point for matching a pattern over an Operation. |
252 | | template <typename Pattern> |
253 | 0 | inline bool matchPattern(Operation *op, const Pattern &pattern) { |
254 | 0 | return const_cast<Pattern &>(pattern).match(op); |
255 | 0 | } Unexecuted instantiation: _ZN4mlir12matchPatternINS_6detail18constant_op_binderINS_9AttributeEEEEEbPNS_9OperationERKT_ Unexecuted instantiation: _ZN4mlir12matchPatternINS_6detail19constant_op_matcherEEEbPNS_9OperationERKT_ |
256 | | |
257 | | /// Matches a constant holding a scalar/vector/tensor integer (splat) and |
258 | | /// writes the integer value to bind_value. |
259 | | inline detail::constant_int_op_binder |
260 | 0 | m_ConstantInt(IntegerAttr::ValueType *bind_value) { |
261 | 0 | return detail::constant_int_op_binder(bind_value); |
262 | 0 | } |
263 | | |
264 | | template <typename OpType, typename... Matchers> |
265 | | auto m_Op(Matchers... matchers) { |
266 | | return detail::RecursivePatternMatcher<OpType, Matchers...>(matchers...); |
267 | | } |
268 | | |
269 | | namespace matchers { |
270 | 0 | inline auto m_Any() { return detail::AnyValueMatcher(); } |
271 | 0 | inline auto m_Val(Value v) { return detail::PatternMatcherValue(v); } |
272 | | } // namespace matchers |
273 | | |
274 | | } // end namespace mlir |
275 | | |
276 | | #endif // MLIR_MATCHERS_H |