/home/arjun/llvm-project/mlir/include/mlir/IR/DialectImplementation.h
Line | Count | Source (jump to first uncovered line) |
1 | | //===- DialectImplementation.h ----------------------------------*- C++ -*-===// |
2 | | // |
3 | | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | | // See https://llvm.org/LICENSE.txt for license information. |
5 | | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | | // |
7 | | //===----------------------------------------------------------------------===// |
8 | | // |
9 | | // This file contains utilities classes for implementing dialect attributes and |
10 | | // types. |
11 | | // |
12 | | //===----------------------------------------------------------------------===// |
13 | | |
14 | | #ifndef MLIR_IR_DIALECTIMPLEMENTATION_H |
15 | | #define MLIR_IR_DIALECTIMPLEMENTATION_H |
16 | | |
17 | | #include "mlir/IR/OpImplementation.h" |
18 | | #include "llvm/ADT/Twine.h" |
19 | | #include "llvm/Support/SMLoc.h" |
20 | | #include "llvm/Support/raw_ostream.h" |
21 | | |
22 | | namespace mlir { |
23 | | |
24 | | class Builder; |
25 | | |
26 | | //===----------------------------------------------------------------------===// |
27 | | // DialectAsmPrinter |
28 | | //===----------------------------------------------------------------------===// |
29 | | |
30 | | /// This is a pure-virtual base class that exposes the asmprinter hooks |
31 | | /// necessary to implement a custom printAttribute/printType() method on a |
32 | | /// dialect. |
33 | | class DialectAsmPrinter { |
34 | | public: |
35 | 0 | DialectAsmPrinter() {} |
36 | | virtual ~DialectAsmPrinter(); |
37 | | virtual raw_ostream &getStream() const = 0; |
38 | | |
39 | | /// Print the given attribute to the stream. |
40 | | virtual void printAttribute(Attribute attr) = 0; |
41 | | |
42 | | /// Print the given floating point value in a stabilized form that can be |
43 | | /// roundtripped through the IR. This is the companion to the 'parseFloat' |
44 | | /// hook on the DialectAsmParser. |
45 | | virtual void printFloat(const APFloat &value) = 0; |
46 | | |
47 | | /// Print the given type to the stream. |
48 | | virtual void printType(Type type) = 0; |
49 | | |
50 | | private: |
51 | | DialectAsmPrinter(const DialectAsmPrinter &) = delete; |
52 | | void operator=(const DialectAsmPrinter &) = delete; |
53 | | }; |
54 | | |
55 | | // Make the implementations convenient to use. |
56 | 0 | inline DialectAsmPrinter &operator<<(DialectAsmPrinter &p, Attribute attr) { |
57 | 0 | p.printAttribute(attr); |
58 | 0 | return p; |
59 | 0 | } |
60 | | |
61 | | inline DialectAsmPrinter &operator<<(DialectAsmPrinter &p, |
62 | 0 | const APFloat &value) { |
63 | 0 | p.printFloat(value); |
64 | 0 | return p; |
65 | 0 | } |
66 | 0 | inline DialectAsmPrinter &operator<<(DialectAsmPrinter &p, float value) { |
67 | 0 | return p << APFloat(value); |
68 | 0 | } |
69 | 0 | inline DialectAsmPrinter &operator<<(DialectAsmPrinter &p, double value) { |
70 | 0 | return p << APFloat(value); |
71 | 0 | } |
72 | | |
73 | 0 | inline DialectAsmPrinter &operator<<(DialectAsmPrinter &p, Type type) { |
74 | 0 | p.printType(type); |
75 | 0 | return p; |
76 | 0 | } |
77 | | |
78 | | // Support printing anything that isn't convertible to one of the above types, |
79 | | // even if it isn't exactly one of them. For example, we want to print |
80 | | // FunctionType with the Type version above, not have it match this. |
81 | | template <typename T, typename std::enable_if< |
82 | | !std::is_convertible<T &, Attribute &>::value && |
83 | | !std::is_convertible<T &, Type &>::value && |
84 | | !std::is_convertible<T &, APFloat &>::value && |
85 | | !llvm::is_one_of<T, double, float>::value, |
86 | | T>::type * = nullptr> |
87 | | inline DialectAsmPrinter &operator<<(DialectAsmPrinter &p, const T &other) { |
88 | | p.getStream() << other; |
89 | | return p; |
90 | | } |
91 | | |
92 | | //===----------------------------------------------------------------------===// |
93 | | // DialectAsmParser |
94 | | //===----------------------------------------------------------------------===// |
95 | | |
96 | | /// The DialectAsmParser has methods for interacting with the asm parser: |
97 | | /// parsing things from it, emitting errors etc. It has an intentionally |
98 | | /// high-level API that is designed to reduce/constrain syntax innovation in |
99 | | /// individual attributes or types. |
100 | | class DialectAsmParser { |
101 | | public: |
102 | | virtual ~DialectAsmParser(); |
103 | | |
104 | | /// Emit a diagnostic at the specified location and return failure. |
105 | | virtual InFlightDiagnostic emitError(llvm::SMLoc loc, |
106 | | const Twine &message = {}) = 0; |
107 | | |
108 | | /// Return a builder which provides useful access to MLIRContext, global |
109 | | /// objects like types and attributes. |
110 | | virtual Builder &getBuilder() const = 0; |
111 | | |
112 | | /// Get the location of the next token and store it into the argument. This |
113 | | /// always succeeds. |
114 | | virtual llvm::SMLoc getCurrentLocation() = 0; |
115 | 0 | ParseResult getCurrentLocation(llvm::SMLoc *loc) { |
116 | 0 | *loc = getCurrentLocation(); |
117 | 0 | return success(); |
118 | 0 | } |
119 | | |
120 | | /// Return the location of the original name token. |
121 | | virtual llvm::SMLoc getNameLoc() const = 0; |
122 | | |
123 | | /// Re-encode the given source location as an MLIR location and return it. |
124 | | virtual Location getEncodedSourceLoc(llvm::SMLoc loc) = 0; |
125 | | |
126 | | /// Returns the full specification of the symbol being parsed. This allows for |
127 | | /// using a separate parser if necessary. |
128 | | virtual StringRef getFullSymbolSpec() const = 0; |
129 | | |
130 | | // These methods emit an error and return failure or success. This allows |
131 | | // these to be chained together into a linear sequence of || expressions in |
132 | | // many cases. |
133 | | |
134 | | /// Parse a floating point value from the stream. |
135 | | virtual ParseResult parseFloat(double &result) = 0; |
136 | | |
137 | | /// Parse an integer value from the stream. |
138 | | template <typename IntT> ParseResult parseInteger(IntT &result) { |
139 | | auto loc = getCurrentLocation(); |
140 | | OptionalParseResult parseResult = parseOptionalInteger(result); |
141 | | if (!parseResult.hasValue()) |
142 | | return emitError(loc, "expected integer value"); |
143 | | return *parseResult; |
144 | | } |
145 | | |
146 | | /// Parse an optional integer value from the stream. |
147 | | virtual OptionalParseResult parseOptionalInteger(uint64_t &result) = 0; |
148 | | |
149 | | template <typename IntT> |
150 | | OptionalParseResult parseOptionalInteger(IntT &result) { |
151 | | auto loc = getCurrentLocation(); |
152 | | |
153 | | // Parse the unsigned variant. |
154 | | uint64_t uintResult; |
155 | | OptionalParseResult parseResult = parseOptionalInteger(uintResult); |
156 | | if (!parseResult.hasValue() || failed(*parseResult)) |
157 | | return parseResult; |
158 | | |
159 | | // Try to convert to the provided integer type. |
160 | | result = IntT(uintResult); |
161 | | if (uint64_t(result) != uintResult) |
162 | | return emitError(loc, "integer value too large"); |
163 | | return success(); |
164 | | } |
165 | | |
166 | | //===--------------------------------------------------------------------===// |
167 | | // Token Parsing |
168 | | //===--------------------------------------------------------------------===// |
169 | | |
170 | | /// Parse a '->' token. |
171 | | virtual ParseResult parseArrow() = 0; |
172 | | |
173 | | /// Parse a '->' token if present |
174 | | virtual ParseResult parseOptionalArrow() = 0; |
175 | | |
176 | | /// Parse a '{' token. |
177 | | virtual ParseResult parseLBrace() = 0; |
178 | | |
179 | | /// Parse a '{' token if present |
180 | | virtual ParseResult parseOptionalLBrace() = 0; |
181 | | |
182 | | /// Parse a `}` token. |
183 | | virtual ParseResult parseRBrace() = 0; |
184 | | |
185 | | /// Parse a `}` token if present |
186 | | virtual ParseResult parseOptionalRBrace() = 0; |
187 | | |
188 | | /// Parse a `:` token. |
189 | | virtual ParseResult parseColon() = 0; |
190 | | |
191 | | /// Parse a `:` token if present. |
192 | | virtual ParseResult parseOptionalColon() = 0; |
193 | | |
194 | | /// Parse a `,` token. |
195 | | virtual ParseResult parseComma() = 0; |
196 | | |
197 | | /// Parse a `,` token if present. |
198 | | virtual ParseResult parseOptionalComma() = 0; |
199 | | |
200 | | /// Parse a `=` token. |
201 | | virtual ParseResult parseEqual() = 0; |
202 | | |
203 | | /// Parse a given keyword. |
204 | 0 | ParseResult parseKeyword(StringRef keyword, const Twine &msg = "") { |
205 | 0 | auto loc = getCurrentLocation(); |
206 | 0 | if (parseOptionalKeyword(keyword)) |
207 | 0 | return emitError(loc, "expected '") << keyword << "'" << msg; |
208 | 0 | return success(); |
209 | 0 | } |
210 | | |
211 | | /// Parse a keyword into 'keyword'. |
212 | 0 | ParseResult parseKeyword(StringRef *keyword) { |
213 | 0 | auto loc = getCurrentLocation(); |
214 | 0 | if (parseOptionalKeyword(keyword)) |
215 | 0 | return emitError(loc, "expected valid keyword"); |
216 | 0 | return success(); |
217 | 0 | } |
218 | | |
219 | | /// Parse the given keyword if present. |
220 | | virtual ParseResult parseOptionalKeyword(StringRef keyword) = 0; |
221 | | |
222 | | /// Parse a keyword, if present, into 'keyword'. |
223 | | virtual ParseResult parseOptionalKeyword(StringRef *keyword) = 0; |
224 | | |
225 | | /// Parse a '<' token. |
226 | | virtual ParseResult parseLess() = 0; |
227 | | |
228 | | /// Parse a `<` token if present. |
229 | | virtual ParseResult parseOptionalLess() = 0; |
230 | | |
231 | | /// Parse a '>' token. |
232 | | virtual ParseResult parseGreater() = 0; |
233 | | |
234 | | /// Parse a `>` token if present. |
235 | | virtual ParseResult parseOptionalGreater() = 0; |
236 | | |
237 | | /// Parse a `(` token. |
238 | | virtual ParseResult parseLParen() = 0; |
239 | | |
240 | | /// Parse a `(` token if present. |
241 | | virtual ParseResult parseOptionalLParen() = 0; |
242 | | |
243 | | /// Parse a `)` token. |
244 | | virtual ParseResult parseRParen() = 0; |
245 | | |
246 | | /// Parse a `)` token if present. |
247 | | virtual ParseResult parseOptionalRParen() = 0; |
248 | | |
249 | | /// Parse a `[` token. |
250 | | virtual ParseResult parseLSquare() = 0; |
251 | | |
252 | | /// Parse a `[` token if present. |
253 | | virtual ParseResult parseOptionalLSquare() = 0; |
254 | | |
255 | | /// Parse a `]` token. |
256 | | virtual ParseResult parseRSquare() = 0; |
257 | | |
258 | | /// Parse a `]` token if present. |
259 | | virtual ParseResult parseOptionalRSquare() = 0; |
260 | | |
261 | | /// Parse a `...` token if present; |
262 | | virtual ParseResult parseOptionalEllipsis() = 0; |
263 | | |
264 | | /// Parse a `?` token. |
265 | | virtual ParseResult parseOptionalQuestion() = 0; |
266 | | |
267 | | /// Parse a `*` token. |
268 | | virtual ParseResult parseOptionalStar() = 0; |
269 | | |
270 | | //===--------------------------------------------------------------------===// |
271 | | // Attribute Parsing |
272 | | //===--------------------------------------------------------------------===// |
273 | | |
274 | | /// Parse an arbitrary attribute and return it in result. |
275 | | virtual ParseResult parseAttribute(Attribute &result, Type type = {}) = 0; |
276 | | |
277 | | /// Parse an attribute of a specific kind and type. |
278 | | template <typename AttrType> |
279 | | ParseResult parseAttribute(AttrType &result, Type type = {}) { |
280 | | llvm::SMLoc loc = getCurrentLocation(); |
281 | | |
282 | | // Parse any kind of attribute. |
283 | | Attribute attr; |
284 | | if (parseAttribute(attr)) |
285 | | return failure(); |
286 | | |
287 | | // Check for the right kind of attribute. |
288 | | result = attr.dyn_cast<AttrType>(); |
289 | | if (!result) |
290 | | return emitError(loc, "invalid kind of attribute specified"); |
291 | | return success(); |
292 | | } |
293 | | |
294 | | /// Parse an affine map instance into 'map'. |
295 | | virtual ParseResult parseAffineMap(AffineMap &map) = 0; |
296 | | |
297 | | /// Parse an integer set instance into 'set'. |
298 | | virtual ParseResult printIntegerSet(IntegerSet &set) = 0; |
299 | | |
300 | | //===--------------------------------------------------------------------===// |
301 | | // Type Parsing |
302 | | //===--------------------------------------------------------------------===// |
303 | | |
304 | | /// Parse a type. |
305 | | virtual ParseResult parseType(Type &result) = 0; |
306 | | |
307 | | /// Parse a type of a specific kind, e.g. a FunctionType. |
308 | | template <typename TypeType> ParseResult parseType(TypeType &result) { |
309 | | llvm::SMLoc loc = getCurrentLocation(); |
310 | | |
311 | | // Parse any kind of type. |
312 | | Type type; |
313 | | if (parseType(type)) |
314 | | return failure(); |
315 | | |
316 | | // Check for the right kind of attribute. |
317 | | result = type.dyn_cast<TypeType>(); |
318 | | if (!result) |
319 | | return emitError(loc, "invalid kind of type specified"); |
320 | | return success(); |
321 | | } |
322 | | |
323 | | /// Parse a 'x' separated dimension list. This populates the dimension list, |
324 | | /// using -1 for the `?` dimensions if `allowDynamic` is set and errors out on |
325 | | /// `?` otherwise. |
326 | | /// |
327 | | /// dimension-list ::= (dimension `x`)* |
328 | | /// dimension ::= `?` | integer |
329 | | /// |
330 | | /// When `allowDynamic` is not set, this is used to parse: |
331 | | /// |
332 | | /// static-dimension-list ::= (integer `x`)* |
333 | | virtual ParseResult parseDimensionList(SmallVectorImpl<int64_t> &dimensions, |
334 | | bool allowDynamic = true) = 0; |
335 | | }; |
336 | | |
337 | | } // end namespace mlir |
338 | | |
339 | | #endif |