/home/arjun/llvm-project/mlir/lib/IR/FunctionImplementation.cpp
Line | Count | Source (jump to first uncovered line) |
1 | | //===- FunctionImplementation.cpp - Utilities for function-like ops -------===// |
2 | | // |
3 | | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | | // See https://llvm.org/LICENSE.txt for license information. |
5 | | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | | // |
7 | | //===----------------------------------------------------------------------===// |
8 | | |
9 | | #include "mlir/IR/FunctionImplementation.h" |
10 | | #include "mlir/IR/Builders.h" |
11 | | #include "mlir/IR/FunctionSupport.h" |
12 | | #include "mlir/IR/SymbolTable.h" |
13 | | |
14 | | using namespace mlir; |
15 | | |
16 | | static ParseResult |
17 | | parseArgumentList(OpAsmParser &parser, bool allowVariadic, |
18 | | SmallVectorImpl<Type> &argTypes, |
19 | | SmallVectorImpl<OpAsmParser::OperandType> &argNames, |
20 | 0 | SmallVectorImpl<NamedAttrList> &argAttrs, bool &isVariadic) { |
21 | 0 | if (parser.parseLParen()) |
22 | 0 | return failure(); |
23 | 0 | |
24 | 0 | // The argument list either has to consistently have ssa-id's followed by |
25 | 0 | // types, or just be a type list. It isn't ok to sometimes have SSA ID's and |
26 | 0 | // sometimes not. |
27 | 0 | auto parseArgument = [&]() -> ParseResult { |
28 | 0 | llvm::SMLoc loc = parser.getCurrentLocation(); |
29 | 0 |
|
30 | 0 | // Parse argument name if present. |
31 | 0 | OpAsmParser::OperandType argument; |
32 | 0 | Type argumentType; |
33 | 0 | if (succeeded(parser.parseOptionalRegionArgument(argument)) && |
34 | 0 | !argument.name.empty()) { |
35 | 0 | // Reject this if the preceding argument was missing a name. |
36 | 0 | if (argNames.empty() && !argTypes.empty()) |
37 | 0 | return parser.emitError(loc, "expected type instead of SSA identifier"); |
38 | 0 | argNames.push_back(argument); |
39 | 0 |
|
40 | 0 | if (parser.parseColonType(argumentType)) |
41 | 0 | return failure(); |
42 | 0 | } else if (allowVariadic && succeeded(parser.parseOptionalEllipsis())) { |
43 | 0 | isVariadic = true; |
44 | 0 | return success(); |
45 | 0 | } else if (!argNames.empty()) { |
46 | 0 | // Reject this if the preceding argument had a name. |
47 | 0 | return parser.emitError(loc, "expected SSA identifier"); |
48 | 0 | } else if (parser.parseType(argumentType)) { |
49 | 0 | return failure(); |
50 | 0 | } |
51 | 0 | |
52 | 0 | // Add the argument type. |
53 | 0 | argTypes.push_back(argumentType); |
54 | 0 |
|
55 | 0 | // Parse any argument attributes. |
56 | 0 | NamedAttrList attrs; |
57 | 0 | if (parser.parseOptionalAttrDict(attrs)) |
58 | 0 | return failure(); |
59 | 0 | argAttrs.push_back(attrs); |
60 | 0 | return success(); |
61 | 0 | }; |
62 | 0 |
|
63 | 0 | // Parse the function arguments. |
64 | 0 | isVariadic = false; |
65 | 0 | if (failed(parser.parseOptionalRParen())) { |
66 | 0 | do { |
67 | 0 | unsigned numTypedArguments = argTypes.size(); |
68 | 0 | if (parseArgument()) |
69 | 0 | return failure(); |
70 | 0 | |
71 | 0 | llvm::SMLoc loc = parser.getCurrentLocation(); |
72 | 0 | if (argTypes.size() == numTypedArguments && |
73 | 0 | succeeded(parser.parseOptionalComma())) |
74 | 0 | return parser.emitError( |
75 | 0 | loc, "variadic arguments must be in the end of the argument list"); |
76 | 0 | } while (succeeded(parser.parseOptionalComma())); |
77 | 0 | parser.parseRParen(); |
78 | 0 | } |
79 | 0 |
|
80 | 0 | return success(); |
81 | 0 | } |
82 | | |
83 | | /// Parse a function result list. |
84 | | /// |
85 | | /// function-result-list ::= function-result-list-parens |
86 | | /// | non-function-type |
87 | | /// function-result-list-parens ::= `(` `)` |
88 | | /// | `(` function-result-list-no-parens `)` |
89 | | /// function-result-list-no-parens ::= function-result (`,` function-result)* |
90 | | /// function-result ::= type attribute-dict? |
91 | | /// |
92 | | static ParseResult |
93 | | parseFunctionResultList(OpAsmParser &parser, SmallVectorImpl<Type> &resultTypes, |
94 | 0 | SmallVectorImpl<NamedAttrList> &resultAttrs) { |
95 | 0 | if (failed(parser.parseOptionalLParen())) { |
96 | 0 | // We already know that there is no `(`, so parse a type. |
97 | 0 | // Because there is no `(`, it cannot be a function type. |
98 | 0 | Type ty; |
99 | 0 | if (parser.parseType(ty)) |
100 | 0 | return failure(); |
101 | 0 | resultTypes.push_back(ty); |
102 | 0 | resultAttrs.emplace_back(); |
103 | 0 | return success(); |
104 | 0 | } |
105 | 0 | |
106 | 0 | // Special case for an empty set of parens. |
107 | 0 | if (succeeded(parser.parseOptionalRParen())) |
108 | 0 | return success(); |
109 | 0 | |
110 | 0 | // Parse individual function results. |
111 | 0 | do { |
112 | 0 | resultTypes.emplace_back(); |
113 | 0 | resultAttrs.emplace_back(); |
114 | 0 | if (parser.parseType(resultTypes.back()) || |
115 | 0 | parser.parseOptionalAttrDict(resultAttrs.back())) { |
116 | 0 | return failure(); |
117 | 0 | } |
118 | 0 | } while (succeeded(parser.parseOptionalComma())); |
119 | 0 | return parser.parseRParen(); |
120 | 0 | } |
121 | | |
122 | | /// Parses a function signature using `parser`. The `allowVariadic` argument |
123 | | /// indicates whether functions with variadic arguments are supported. The |
124 | | /// trailing arguments are populated by this function with names, types and |
125 | | /// attributes of the arguments and those of the results. |
126 | | ParseResult mlir::impl::parseFunctionSignature( |
127 | | OpAsmParser &parser, bool allowVariadic, |
128 | | SmallVectorImpl<OpAsmParser::OperandType> &argNames, |
129 | | SmallVectorImpl<Type> &argTypes, SmallVectorImpl<NamedAttrList> &argAttrs, |
130 | | bool &isVariadic, SmallVectorImpl<Type> &resultTypes, |
131 | 0 | SmallVectorImpl<NamedAttrList> &resultAttrs) { |
132 | 0 | if (parseArgumentList(parser, allowVariadic, argTypes, argNames, argAttrs, |
133 | 0 | isVariadic)) |
134 | 0 | return failure(); |
135 | 0 | if (succeeded(parser.parseOptionalArrow())) |
136 | 0 | return parseFunctionResultList(parser, resultTypes, resultAttrs); |
137 | 0 | return success(); |
138 | 0 | } |
139 | | |
140 | | void mlir::impl::addArgAndResultAttrs(Builder &builder, OperationState &result, |
141 | | ArrayRef<NamedAttrList> argAttrs, |
142 | 0 | ArrayRef<NamedAttrList> resultAttrs) { |
143 | 0 | // Add the attributes to the function arguments. |
144 | 0 | SmallString<8> attrNameBuf; |
145 | 0 | for (unsigned i = 0, e = argAttrs.size(); i != e; ++i) |
146 | 0 | if (!argAttrs[i].empty()) |
147 | 0 | result.addAttribute(getArgAttrName(i, attrNameBuf), |
148 | 0 | builder.getDictionaryAttr(argAttrs[i])); |
149 | 0 |
|
150 | 0 | // Add the attributes to the function results. |
151 | 0 | for (unsigned i = 0, e = resultAttrs.size(); i != e; ++i) |
152 | 0 | if (!resultAttrs[i].empty()) |
153 | 0 | result.addAttribute(getResultAttrName(i, attrNameBuf), |
154 | 0 | builder.getDictionaryAttr(resultAttrs[i])); |
155 | 0 | } |
156 | | |
157 | | /// Parser implementation for function-like operations. Uses `funcTypeBuilder` |
158 | | /// to construct the custom function type given lists of input and output types. |
159 | | ParseResult |
160 | | mlir::impl::parseFunctionLikeOp(OpAsmParser &parser, OperationState &result, |
161 | | bool allowVariadic, |
162 | 0 | mlir::impl::FuncTypeBuilder funcTypeBuilder) { |
163 | 0 | SmallVector<OpAsmParser::OperandType, 4> entryArgs; |
164 | 0 | SmallVector<NamedAttrList, 4> argAttrs; |
165 | 0 | SmallVector<NamedAttrList, 4> resultAttrs; |
166 | 0 | SmallVector<Type, 4> argTypes; |
167 | 0 | SmallVector<Type, 4> resultTypes; |
168 | 0 | auto &builder = parser.getBuilder(); |
169 | 0 |
|
170 | 0 | // Parse the name as a symbol. |
171 | 0 | StringAttr nameAttr; |
172 | 0 | if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(), |
173 | 0 | result.attributes)) |
174 | 0 | return failure(); |
175 | 0 | |
176 | 0 | // Parse the function signature. |
177 | 0 | auto signatureLocation = parser.getCurrentLocation(); |
178 | 0 | bool isVariadic = false; |
179 | 0 | if (parseFunctionSignature(parser, allowVariadic, entryArgs, argTypes, |
180 | 0 | argAttrs, isVariadic, resultTypes, resultAttrs)) |
181 | 0 | return failure(); |
182 | 0 | |
183 | 0 | std::string errorMessage; |
184 | 0 | if (auto type = funcTypeBuilder(builder, argTypes, resultTypes, |
185 | 0 | impl::VariadicFlag(isVariadic), errorMessage)) |
186 | 0 | result.addAttribute(getTypeAttrName(), TypeAttr::get(type)); |
187 | 0 | else |
188 | 0 | return parser.emitError(signatureLocation) |
189 | 0 | << "failed to construct function type" |
190 | 0 | << (errorMessage.empty() ? "" : ": ") << errorMessage; |
191 | 0 |
|
192 | 0 | // If function attributes are present, parse them. |
193 | 0 | if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) |
194 | 0 | return failure(); |
195 | 0 | |
196 | 0 | // Add the attributes to the function arguments. |
197 | 0 | assert(argAttrs.size() == argTypes.size()); |
198 | 0 | assert(resultAttrs.size() == resultTypes.size()); |
199 | 0 | addArgAndResultAttrs(builder, result, argAttrs, resultAttrs); |
200 | 0 |
|
201 | 0 | // Parse the optional function body. |
202 | 0 | auto *body = result.addRegion(); |
203 | 0 | return parser.parseOptionalRegion( |
204 | 0 | *body, entryArgs, entryArgs.empty() ? ArrayRef<Type>() : argTypes); |
205 | 0 | } |
206 | | |
207 | | // Print a function result list. |
208 | | static void printFunctionResultList(OpAsmPrinter &p, ArrayRef<Type> types, |
209 | 0 | ArrayRef<ArrayRef<NamedAttribute>> attrs) { |
210 | 0 | assert(!types.empty() && "Should not be called for empty result list."); |
211 | 0 | auto &os = p.getStream(); |
212 | 0 | bool needsParens = |
213 | 0 | types.size() > 1 || types[0].isa<FunctionType>() || !attrs[0].empty(); |
214 | 0 | if (needsParens) |
215 | 0 | os << '('; |
216 | 0 | llvm::interleaveComma( |
217 | 0 | llvm::zip(types, attrs), os, |
218 | 0 | [&](const std::tuple<Type, ArrayRef<NamedAttribute>> &t) { |
219 | 0 | p.printType(std::get<0>(t)); |
220 | 0 | p.printOptionalAttrDict(std::get<1>(t)); |
221 | 0 | }); |
222 | 0 | if (needsParens) |
223 | 0 | os << ')'; |
224 | 0 | } |
225 | | |
226 | | /// Print the signature of the function-like operation `op`. Assumes `op` has |
227 | | /// the FunctionLike trait and passed the verification. |
228 | | void mlir::impl::printFunctionSignature(OpAsmPrinter &p, Operation *op, |
229 | | ArrayRef<Type> argTypes, |
230 | | bool isVariadic, |
231 | 0 | ArrayRef<Type> resultTypes) { |
232 | 0 | Region &body = op->getRegion(0); |
233 | 0 | bool isExternal = body.empty(); |
234 | 0 |
|
235 | 0 | p << '('; |
236 | 0 | for (unsigned i = 0, e = argTypes.size(); i < e; ++i) { |
237 | 0 | if (i > 0) |
238 | 0 | p << ", "; |
239 | 0 |
|
240 | 0 | if (!isExternal) { |
241 | 0 | p.printOperand(body.front().getArgument(i)); |
242 | 0 | p << ": "; |
243 | 0 | } |
244 | 0 |
|
245 | 0 | p.printType(argTypes[i]); |
246 | 0 | p.printOptionalAttrDict(::mlir::impl::getArgAttrs(op, i)); |
247 | 0 | } |
248 | 0 |
|
249 | 0 | if (isVariadic) { |
250 | 0 | if (!argTypes.empty()) |
251 | 0 | p << ", "; |
252 | 0 | p << "..."; |
253 | 0 | } |
254 | 0 |
|
255 | 0 | p << ')'; |
256 | 0 |
|
257 | 0 | if (!resultTypes.empty()) { |
258 | 0 | p.getStream() << " -> "; |
259 | 0 | SmallVector<ArrayRef<NamedAttribute>, 4> resultAttrs; |
260 | 0 | for (int i = 0, e = resultTypes.size(); i < e; ++i) |
261 | 0 | resultAttrs.push_back(::mlir::impl::getResultAttrs(op, i)); |
262 | 0 | printFunctionResultList(p, resultTypes, resultAttrs); |
263 | 0 | } |
264 | 0 | } |
265 | | |
266 | | /// Prints the list of function prefixed with the "attributes" keyword. The |
267 | | /// attributes with names listed in "elided" as well as those used by the |
268 | | /// function-like operation internally are not printed. Nothing is printed |
269 | | /// if all attributes are elided. Assumes `op` has the `FunctionLike` trait and |
270 | | /// passed the verification. |
271 | | void mlir::impl::printFunctionAttributes(OpAsmPrinter &p, Operation *op, |
272 | | unsigned numInputs, |
273 | | unsigned numResults, |
274 | 0 | ArrayRef<StringRef> elided) { |
275 | 0 | // Print out function attributes, if present. |
276 | 0 | SmallVector<StringRef, 2> ignoredAttrs = { |
277 | 0 | ::mlir::SymbolTable::getSymbolAttrName(), getTypeAttrName()}; |
278 | 0 | ignoredAttrs.append(elided.begin(), elided.end()); |
279 | 0 |
|
280 | 0 | SmallString<8> attrNameBuf; |
281 | 0 |
|
282 | 0 | // Ignore any argument attributes. |
283 | 0 | std::vector<SmallString<8>> argAttrStorage; |
284 | 0 | for (unsigned i = 0; i != numInputs; ++i) |
285 | 0 | if (op->getAttr(getArgAttrName(i, attrNameBuf))) |
286 | 0 | argAttrStorage.emplace_back(attrNameBuf); |
287 | 0 | ignoredAttrs.append(argAttrStorage.begin(), argAttrStorage.end()); |
288 | 0 |
|
289 | 0 | // Ignore any result attributes. |
290 | 0 | std::vector<SmallString<8>> resultAttrStorage; |
291 | 0 | for (unsigned i = 0; i != numResults; ++i) |
292 | 0 | if (op->getAttr(getResultAttrName(i, attrNameBuf))) |
293 | 0 | resultAttrStorage.emplace_back(attrNameBuf); |
294 | 0 | ignoredAttrs.append(resultAttrStorage.begin(), resultAttrStorage.end()); |
295 | 0 |
|
296 | 0 | p.printOptionalAttrDictWithKeyword(op->getAttrs(), ignoredAttrs); |
297 | 0 | } |
298 | | |
299 | | /// Printer implementation for function-like operations. Accepts lists of |
300 | | /// argument and result types to use while printing. |
301 | | void mlir::impl::printFunctionLikeOp(OpAsmPrinter &p, Operation *op, |
302 | | ArrayRef<Type> argTypes, bool isVariadic, |
303 | 0 | ArrayRef<Type> resultTypes) { |
304 | 0 | // Print the operation and the function name. |
305 | 0 | auto funcName = |
306 | 0 | op->getAttrOfType<StringAttr>(::mlir::SymbolTable::getSymbolAttrName()) |
307 | 0 | .getValue(); |
308 | 0 | p << op->getName() << ' '; |
309 | 0 | p.printSymbolName(funcName); |
310 | 0 |
|
311 | 0 | printFunctionSignature(p, op, argTypes, isVariadic, resultTypes); |
312 | 0 | printFunctionAttributes(p, op, argTypes.size(), resultTypes.size()); |
313 | 0 |
|
314 | 0 | // Print the body if this is not an external function. |
315 | 0 | Region &body = op->getRegion(0); |
316 | 0 | if (!body.empty()) |
317 | 0 | p.printRegion(body, /*printEntryBlockArgs=*/false, |
318 | 0 | /*printBlockTerminators=*/true); |
319 | 0 | } |