/home/arjun/llvm-project/mlir/lib/IR/FunctionImplementation.cpp
| Line | Count | Source (jump to first uncovered line) | 
| 1 |  | //===- FunctionImplementation.cpp - Utilities for function-like ops -------===// | 
| 2 |  | // | 
| 3 |  | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | 
| 4 |  | // See https://llvm.org/LICENSE.txt for license information. | 
| 5 |  | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | 
| 6 |  | // | 
| 7 |  | //===----------------------------------------------------------------------===// | 
| 8 |  |  | 
| 9 |  | #include "mlir/IR/FunctionImplementation.h" | 
| 10 |  | #include "mlir/IR/Builders.h" | 
| 11 |  | #include "mlir/IR/FunctionSupport.h" | 
| 12 |  | #include "mlir/IR/SymbolTable.h" | 
| 13 |  |  | 
| 14 |  | using namespace mlir; | 
| 15 |  |  | 
| 16 |  | static ParseResult | 
| 17 |  | parseArgumentList(OpAsmParser &parser, bool allowVariadic, | 
| 18 |  |                   SmallVectorImpl<Type> &argTypes, | 
| 19 |  |                   SmallVectorImpl<OpAsmParser::OperandType> &argNames, | 
| 20 | 0 |                   SmallVectorImpl<NamedAttrList> &argAttrs, bool &isVariadic) { | 
| 21 | 0 |   if (parser.parseLParen()) | 
| 22 | 0 |     return failure(); | 
| 23 | 0 |  | 
| 24 | 0 |   // The argument list either has to consistently have ssa-id's followed by | 
| 25 | 0 |   // types, or just be a type list.  It isn't ok to sometimes have SSA ID's and | 
| 26 | 0 |   // sometimes not. | 
| 27 | 0 |   auto parseArgument = [&]() -> ParseResult { | 
| 28 | 0 |     llvm::SMLoc loc = parser.getCurrentLocation(); | 
| 29 | 0 | 
 | 
| 30 | 0 |     // Parse argument name if present. | 
| 31 | 0 |     OpAsmParser::OperandType argument; | 
| 32 | 0 |     Type argumentType; | 
| 33 | 0 |     if (succeeded(parser.parseOptionalRegionArgument(argument)) && | 
| 34 | 0 |         !argument.name.empty()) { | 
| 35 | 0 |       // Reject this if the preceding argument was missing a name. | 
| 36 | 0 |       if (argNames.empty() && !argTypes.empty()) | 
| 37 | 0 |         return parser.emitError(loc, "expected type instead of SSA identifier"); | 
| 38 | 0 |       argNames.push_back(argument); | 
| 39 | 0 | 
 | 
| 40 | 0 |       if (parser.parseColonType(argumentType)) | 
| 41 | 0 |         return failure(); | 
| 42 | 0 |     } else if (allowVariadic && succeeded(parser.parseOptionalEllipsis())) { | 
| 43 | 0 |       isVariadic = true; | 
| 44 | 0 |       return success(); | 
| 45 | 0 |     } else if (!argNames.empty()) { | 
| 46 | 0 |       // Reject this if the preceding argument had a name. | 
| 47 | 0 |       return parser.emitError(loc, "expected SSA identifier"); | 
| 48 | 0 |     } else if (parser.parseType(argumentType)) { | 
| 49 | 0 |       return failure(); | 
| 50 | 0 |     } | 
| 51 | 0 |  | 
| 52 | 0 |     // Add the argument type. | 
| 53 | 0 |     argTypes.push_back(argumentType); | 
| 54 | 0 | 
 | 
| 55 | 0 |     // Parse any argument attributes. | 
| 56 | 0 |     NamedAttrList attrs; | 
| 57 | 0 |     if (parser.parseOptionalAttrDict(attrs)) | 
| 58 | 0 |       return failure(); | 
| 59 | 0 |     argAttrs.push_back(attrs); | 
| 60 | 0 |     return success(); | 
| 61 | 0 |   }; | 
| 62 | 0 | 
 | 
| 63 | 0 |   // Parse the function arguments. | 
| 64 | 0 |   isVariadic = false; | 
| 65 | 0 |   if (failed(parser.parseOptionalRParen())) { | 
| 66 | 0 |     do { | 
| 67 | 0 |       unsigned numTypedArguments = argTypes.size(); | 
| 68 | 0 |       if (parseArgument()) | 
| 69 | 0 |         return failure(); | 
| 70 | 0 |  | 
| 71 | 0 |       llvm::SMLoc loc = parser.getCurrentLocation(); | 
| 72 | 0 |       if (argTypes.size() == numTypedArguments && | 
| 73 | 0 |           succeeded(parser.parseOptionalComma())) | 
| 74 | 0 |         return parser.emitError( | 
| 75 | 0 |             loc, "variadic arguments must be in the end of the argument list"); | 
| 76 | 0 |     } while (succeeded(parser.parseOptionalComma())); | 
| 77 | 0 |     parser.parseRParen(); | 
| 78 | 0 |   } | 
| 79 | 0 | 
 | 
| 80 | 0 |   return success(); | 
| 81 | 0 | } | 
| 82 |  |  | 
| 83 |  | /// Parse a function result list. | 
| 84 |  | /// | 
| 85 |  | ///   function-result-list ::= function-result-list-parens | 
| 86 |  | ///                          | non-function-type | 
| 87 |  | ///   function-result-list-parens ::= `(` `)` | 
| 88 |  | ///                                 | `(` function-result-list-no-parens `)` | 
| 89 |  | ///   function-result-list-no-parens ::= function-result (`,` function-result)* | 
| 90 |  | ///   function-result ::= type attribute-dict? | 
| 91 |  | /// | 
| 92 |  | static ParseResult | 
| 93 |  | parseFunctionResultList(OpAsmParser &parser, SmallVectorImpl<Type> &resultTypes, | 
| 94 | 0 |                         SmallVectorImpl<NamedAttrList> &resultAttrs) { | 
| 95 | 0 |   if (failed(parser.parseOptionalLParen())) { | 
| 96 | 0 |     // We already know that there is no `(`, so parse a type. | 
| 97 | 0 |     // Because there is no `(`, it cannot be a function type. | 
| 98 | 0 |     Type ty; | 
| 99 | 0 |     if (parser.parseType(ty)) | 
| 100 | 0 |       return failure(); | 
| 101 | 0 |     resultTypes.push_back(ty); | 
| 102 | 0 |     resultAttrs.emplace_back(); | 
| 103 | 0 |     return success(); | 
| 104 | 0 |   } | 
| 105 | 0 |  | 
| 106 | 0 |   // Special case for an empty set of parens. | 
| 107 | 0 |   if (succeeded(parser.parseOptionalRParen())) | 
| 108 | 0 |     return success(); | 
| 109 | 0 |  | 
| 110 | 0 |   // Parse individual function results. | 
| 111 | 0 |   do { | 
| 112 | 0 |     resultTypes.emplace_back(); | 
| 113 | 0 |     resultAttrs.emplace_back(); | 
| 114 | 0 |     if (parser.parseType(resultTypes.back()) || | 
| 115 | 0 |         parser.parseOptionalAttrDict(resultAttrs.back())) { | 
| 116 | 0 |       return failure(); | 
| 117 | 0 |     } | 
| 118 | 0 |   } while (succeeded(parser.parseOptionalComma())); | 
| 119 | 0 |   return parser.parseRParen(); | 
| 120 | 0 | } | 
| 121 |  |  | 
| 122 |  | /// Parses a function signature using `parser`. The `allowVariadic` argument | 
| 123 |  | /// indicates whether functions with variadic arguments are supported. The | 
| 124 |  | /// trailing arguments are populated by this function with names, types and | 
| 125 |  | /// attributes of the arguments and those of the results. | 
| 126 |  | ParseResult mlir::impl::parseFunctionSignature( | 
| 127 |  |     OpAsmParser &parser, bool allowVariadic, | 
| 128 |  |     SmallVectorImpl<OpAsmParser::OperandType> &argNames, | 
| 129 |  |     SmallVectorImpl<Type> &argTypes, SmallVectorImpl<NamedAttrList> &argAttrs, | 
| 130 |  |     bool &isVariadic, SmallVectorImpl<Type> &resultTypes, | 
| 131 | 0 |     SmallVectorImpl<NamedAttrList> &resultAttrs) { | 
| 132 | 0 |   if (parseArgumentList(parser, allowVariadic, argTypes, argNames, argAttrs, | 
| 133 | 0 |                         isVariadic)) | 
| 134 | 0 |     return failure(); | 
| 135 | 0 |   if (succeeded(parser.parseOptionalArrow())) | 
| 136 | 0 |     return parseFunctionResultList(parser, resultTypes, resultAttrs); | 
| 137 | 0 |   return success(); | 
| 138 | 0 | } | 
| 139 |  |  | 
| 140 |  | void mlir::impl::addArgAndResultAttrs(Builder &builder, OperationState &result, | 
| 141 |  |                                       ArrayRef<NamedAttrList> argAttrs, | 
| 142 | 0 |                                       ArrayRef<NamedAttrList> resultAttrs) { | 
| 143 | 0 |   // Add the attributes to the function arguments. | 
| 144 | 0 |   SmallString<8> attrNameBuf; | 
| 145 | 0 |   for (unsigned i = 0, e = argAttrs.size(); i != e; ++i) | 
| 146 | 0 |     if (!argAttrs[i].empty()) | 
| 147 | 0 |       result.addAttribute(getArgAttrName(i, attrNameBuf), | 
| 148 | 0 |                           builder.getDictionaryAttr(argAttrs[i])); | 
| 149 | 0 | 
 | 
| 150 | 0 |   // Add the attributes to the function results. | 
| 151 | 0 |   for (unsigned i = 0, e = resultAttrs.size(); i != e; ++i) | 
| 152 | 0 |     if (!resultAttrs[i].empty()) | 
| 153 | 0 |       result.addAttribute(getResultAttrName(i, attrNameBuf), | 
| 154 | 0 |                           builder.getDictionaryAttr(resultAttrs[i])); | 
| 155 | 0 | } | 
| 156 |  |  | 
| 157 |  | /// Parser implementation for function-like operations.  Uses `funcTypeBuilder` | 
| 158 |  | /// to construct the custom function type given lists of input and output types. | 
| 159 |  | ParseResult | 
| 160 |  | mlir::impl::parseFunctionLikeOp(OpAsmParser &parser, OperationState &result, | 
| 161 |  |                                 bool allowVariadic, | 
| 162 | 0 |                                 mlir::impl::FuncTypeBuilder funcTypeBuilder) { | 
| 163 | 0 |   SmallVector<OpAsmParser::OperandType, 4> entryArgs; | 
| 164 | 0 |   SmallVector<NamedAttrList, 4> argAttrs; | 
| 165 | 0 |   SmallVector<NamedAttrList, 4> resultAttrs; | 
| 166 | 0 |   SmallVector<Type, 4> argTypes; | 
| 167 | 0 |   SmallVector<Type, 4> resultTypes; | 
| 168 | 0 |   auto &builder = parser.getBuilder(); | 
| 169 | 0 | 
 | 
| 170 | 0 |   // Parse the name as a symbol. | 
| 171 | 0 |   StringAttr nameAttr; | 
| 172 | 0 |   if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(), | 
| 173 | 0 |                              result.attributes)) | 
| 174 | 0 |     return failure(); | 
| 175 | 0 |  | 
| 176 | 0 |   // Parse the function signature. | 
| 177 | 0 |   auto signatureLocation = parser.getCurrentLocation(); | 
| 178 | 0 |   bool isVariadic = false; | 
| 179 | 0 |   if (parseFunctionSignature(parser, allowVariadic, entryArgs, argTypes, | 
| 180 | 0 |                              argAttrs, isVariadic, resultTypes, resultAttrs)) | 
| 181 | 0 |     return failure(); | 
| 182 | 0 |  | 
| 183 | 0 |   std::string errorMessage; | 
| 184 | 0 |   if (auto type = funcTypeBuilder(builder, argTypes, resultTypes, | 
| 185 | 0 |                                   impl::VariadicFlag(isVariadic), errorMessage)) | 
| 186 | 0 |     result.addAttribute(getTypeAttrName(), TypeAttr::get(type)); | 
| 187 | 0 |   else | 
| 188 | 0 |     return parser.emitError(signatureLocation) | 
| 189 | 0 |            << "failed to construct function type" | 
| 190 | 0 |            << (errorMessage.empty() ? "" : ": ") << errorMessage; | 
| 191 | 0 | 
 | 
| 192 | 0 |   // If function attributes are present, parse them. | 
| 193 | 0 |   if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) | 
| 194 | 0 |     return failure(); | 
| 195 | 0 |  | 
| 196 | 0 |   // Add the attributes to the function arguments. | 
| 197 | 0 |   assert(argAttrs.size() == argTypes.size()); | 
| 198 | 0 |   assert(resultAttrs.size() == resultTypes.size()); | 
| 199 | 0 |   addArgAndResultAttrs(builder, result, argAttrs, resultAttrs); | 
| 200 | 0 | 
 | 
| 201 | 0 |   // Parse the optional function body. | 
| 202 | 0 |   auto *body = result.addRegion(); | 
| 203 | 0 |   return parser.parseOptionalRegion( | 
| 204 | 0 |       *body, entryArgs, entryArgs.empty() ? ArrayRef<Type>() : argTypes); | 
| 205 | 0 | } | 
| 206 |  |  | 
| 207 |  | // Print a function result list. | 
| 208 |  | static void printFunctionResultList(OpAsmPrinter &p, ArrayRef<Type> types, | 
| 209 | 0 |                                     ArrayRef<ArrayRef<NamedAttribute>> attrs) { | 
| 210 | 0 |   assert(!types.empty() && "Should not be called for empty result list."); | 
| 211 | 0 |   auto &os = p.getStream(); | 
| 212 | 0 |   bool needsParens = | 
| 213 | 0 |       types.size() > 1 || types[0].isa<FunctionType>() || !attrs[0].empty(); | 
| 214 | 0 |   if (needsParens) | 
| 215 | 0 |     os << '('; | 
| 216 | 0 |   llvm::interleaveComma( | 
| 217 | 0 |       llvm::zip(types, attrs), os, | 
| 218 | 0 |       [&](const std::tuple<Type, ArrayRef<NamedAttribute>> &t) { | 
| 219 | 0 |         p.printType(std::get<0>(t)); | 
| 220 | 0 |         p.printOptionalAttrDict(std::get<1>(t)); | 
| 221 | 0 |       }); | 
| 222 | 0 |   if (needsParens) | 
| 223 | 0 |     os << ')'; | 
| 224 | 0 | } | 
| 225 |  |  | 
| 226 |  | /// Print the signature of the function-like operation `op`.  Assumes `op` has | 
| 227 |  | /// the FunctionLike trait and passed the verification. | 
| 228 |  | void mlir::impl::printFunctionSignature(OpAsmPrinter &p, Operation *op, | 
| 229 |  |                                         ArrayRef<Type> argTypes, | 
| 230 |  |                                         bool isVariadic, | 
| 231 | 0 |                                         ArrayRef<Type> resultTypes) { | 
| 232 | 0 |   Region &body = op->getRegion(0); | 
| 233 | 0 |   bool isExternal = body.empty(); | 
| 234 | 0 | 
 | 
| 235 | 0 |   p << '('; | 
| 236 | 0 |   for (unsigned i = 0, e = argTypes.size(); i < e; ++i) { | 
| 237 | 0 |     if (i > 0) | 
| 238 | 0 |       p << ", "; | 
| 239 | 0 | 
 | 
| 240 | 0 |     if (!isExternal) { | 
| 241 | 0 |       p.printOperand(body.front().getArgument(i)); | 
| 242 | 0 |       p << ": "; | 
| 243 | 0 |     } | 
| 244 | 0 | 
 | 
| 245 | 0 |     p.printType(argTypes[i]); | 
| 246 | 0 |     p.printOptionalAttrDict(::mlir::impl::getArgAttrs(op, i)); | 
| 247 | 0 |   } | 
| 248 | 0 | 
 | 
| 249 | 0 |   if (isVariadic) { | 
| 250 | 0 |     if (!argTypes.empty()) | 
| 251 | 0 |       p << ", "; | 
| 252 | 0 |     p << "..."; | 
| 253 | 0 |   } | 
| 254 | 0 | 
 | 
| 255 | 0 |   p << ')'; | 
| 256 | 0 | 
 | 
| 257 | 0 |   if (!resultTypes.empty()) { | 
| 258 | 0 |     p.getStream() << " -> "; | 
| 259 | 0 |     SmallVector<ArrayRef<NamedAttribute>, 4> resultAttrs; | 
| 260 | 0 |     for (int i = 0, e = resultTypes.size(); i < e; ++i) | 
| 261 | 0 |       resultAttrs.push_back(::mlir::impl::getResultAttrs(op, i)); | 
| 262 | 0 |     printFunctionResultList(p, resultTypes, resultAttrs); | 
| 263 | 0 |   } | 
| 264 | 0 | } | 
| 265 |  |  | 
| 266 |  | /// Prints the list of function prefixed with the "attributes" keyword. The | 
| 267 |  | /// attributes with names listed in "elided" as well as those used by the | 
| 268 |  | /// function-like operation internally are not printed. Nothing is printed | 
| 269 |  | /// if all attributes are elided. Assumes `op` has the `FunctionLike` trait and | 
| 270 |  | /// passed the verification. | 
| 271 |  | void mlir::impl::printFunctionAttributes(OpAsmPrinter &p, Operation *op, | 
| 272 |  |                                          unsigned numInputs, | 
| 273 |  |                                          unsigned numResults, | 
| 274 | 0 |                                          ArrayRef<StringRef> elided) { | 
| 275 | 0 |   // Print out function attributes, if present. | 
| 276 | 0 |   SmallVector<StringRef, 2> ignoredAttrs = { | 
| 277 | 0 |       ::mlir::SymbolTable::getSymbolAttrName(), getTypeAttrName()}; | 
| 278 | 0 |   ignoredAttrs.append(elided.begin(), elided.end()); | 
| 279 | 0 | 
 | 
| 280 | 0 |   SmallString<8> attrNameBuf; | 
| 281 | 0 | 
 | 
| 282 | 0 |   // Ignore any argument attributes. | 
| 283 | 0 |   std::vector<SmallString<8>> argAttrStorage; | 
| 284 | 0 |   for (unsigned i = 0; i != numInputs; ++i) | 
| 285 | 0 |     if (op->getAttr(getArgAttrName(i, attrNameBuf))) | 
| 286 | 0 |       argAttrStorage.emplace_back(attrNameBuf); | 
| 287 | 0 |   ignoredAttrs.append(argAttrStorage.begin(), argAttrStorage.end()); | 
| 288 | 0 | 
 | 
| 289 | 0 |   // Ignore any result attributes. | 
| 290 | 0 |   std::vector<SmallString<8>> resultAttrStorage; | 
| 291 | 0 |   for (unsigned i = 0; i != numResults; ++i) | 
| 292 | 0 |     if (op->getAttr(getResultAttrName(i, attrNameBuf))) | 
| 293 | 0 |       resultAttrStorage.emplace_back(attrNameBuf); | 
| 294 | 0 |   ignoredAttrs.append(resultAttrStorage.begin(), resultAttrStorage.end()); | 
| 295 | 0 | 
 | 
| 296 | 0 |   p.printOptionalAttrDictWithKeyword(op->getAttrs(), ignoredAttrs); | 
| 297 | 0 | } | 
| 298 |  |  | 
| 299 |  | /// Printer implementation for function-like operations.  Accepts lists of | 
| 300 |  | /// argument and result types to use while printing. | 
| 301 |  | void mlir::impl::printFunctionLikeOp(OpAsmPrinter &p, Operation *op, | 
| 302 |  |                                      ArrayRef<Type> argTypes, bool isVariadic, | 
| 303 | 0 |                                      ArrayRef<Type> resultTypes) { | 
| 304 | 0 |   // Print the operation and the function name. | 
| 305 | 0 |   auto funcName = | 
| 306 | 0 |       op->getAttrOfType<StringAttr>(::mlir::SymbolTable::getSymbolAttrName()) | 
| 307 | 0 |           .getValue(); | 
| 308 | 0 |   p << op->getName() << ' '; | 
| 309 | 0 |   p.printSymbolName(funcName); | 
| 310 | 0 | 
 | 
| 311 | 0 |   printFunctionSignature(p, op, argTypes, isVariadic, resultTypes); | 
| 312 | 0 |   printFunctionAttributes(p, op, argTypes.size(), resultTypes.size()); | 
| 313 | 0 | 
 | 
| 314 | 0 |   // Print the body if this is not an external function. | 
| 315 | 0 |   Region &body = op->getRegion(0); | 
| 316 | 0 |   if (!body.empty()) | 
| 317 | 0 |     p.printRegion(body, /*printEntryBlockArgs=*/false, | 
| 318 | 0 |                   /*printBlockTerminators=*/true); | 
| 319 | 0 | } |