/home/arjun/llvm-project/mlir/lib/IR/Function.cpp
Line | Count | Source (jump to first uncovered line) |
1 | | //===- Function.cpp - MLIR Function Classes -------------------------------===// |
2 | | // |
3 | | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | | // See https://llvm.org/LICENSE.txt for license information. |
5 | | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | | // |
7 | | //===----------------------------------------------------------------------===// |
8 | | |
9 | | #include "mlir/IR/Function.h" |
10 | | #include "mlir/IR/BlockAndValueMapping.h" |
11 | | #include "mlir/IR/Builders.h" |
12 | | #include "mlir/IR/FunctionImplementation.h" |
13 | | #include "llvm/ADT/BitVector.h" |
14 | | #include "llvm/ADT/MapVector.h" |
15 | | #include "llvm/ADT/SmallString.h" |
16 | | #include "llvm/ADT/Twine.h" |
17 | | |
18 | | using namespace mlir; |
19 | | |
20 | | //===----------------------------------------------------------------------===// |
21 | | // Function Operation. |
22 | | //===----------------------------------------------------------------------===// |
23 | | |
24 | | FuncOp FuncOp::create(Location location, StringRef name, FunctionType type, |
25 | 0 | ArrayRef<NamedAttribute> attrs) { |
26 | 0 | OperationState state(location, "func"); |
27 | 0 | OpBuilder builder(location->getContext()); |
28 | 0 | FuncOp::build(builder, state, name, type, attrs); |
29 | 0 | return cast<FuncOp>(Operation::create(state)); |
30 | 0 | } |
31 | | FuncOp FuncOp::create(Location location, StringRef name, FunctionType type, |
32 | 0 | iterator_range<dialect_attr_iterator> attrs) { |
33 | 0 | SmallVector<NamedAttribute, 8> attrRef(attrs); |
34 | 0 | return create(location, name, type, llvm::makeArrayRef(attrRef)); |
35 | 0 | } |
36 | | FuncOp FuncOp::create(Location location, StringRef name, FunctionType type, |
37 | | ArrayRef<NamedAttribute> attrs, |
38 | 0 | ArrayRef<MutableDictionaryAttr> argAttrs) { |
39 | 0 | FuncOp func = create(location, name, type, attrs); |
40 | 0 | func.setAllArgAttrs(argAttrs); |
41 | 0 | return func; |
42 | 0 | } |
43 | | |
44 | | void FuncOp::build(OpBuilder &builder, OperationState &result, StringRef name, |
45 | 0 | FunctionType type, ArrayRef<NamedAttribute> attrs) { |
46 | 0 | result.addAttribute(SymbolTable::getSymbolAttrName(), |
47 | 0 | builder.getStringAttr(name)); |
48 | 0 | result.addAttribute(getTypeAttrName(), TypeAttr::get(type)); |
49 | 0 | result.attributes.append(attrs.begin(), attrs.end()); |
50 | 0 | result.addRegion(); |
51 | 0 | } |
52 | | |
53 | | void FuncOp::build(OpBuilder &builder, OperationState &result, StringRef name, |
54 | | FunctionType type, ArrayRef<NamedAttribute> attrs, |
55 | 0 | ArrayRef<MutableDictionaryAttr> argAttrs) { |
56 | 0 | build(builder, result, name, type, attrs); |
57 | 0 | assert(type.getNumInputs() == argAttrs.size()); |
58 | 0 | SmallString<8> argAttrName; |
59 | 0 | for (unsigned i = 0, e = type.getNumInputs(); i != e; ++i) |
60 | 0 | if (auto argDict = argAttrs[i].getDictionary(builder.getContext())) |
61 | 0 | result.addAttribute(getArgAttrName(i, argAttrName), argDict); |
62 | 0 | } |
63 | | |
64 | | /// Parsing/Printing methods. |
65 | | |
66 | 0 | ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { |
67 | 0 | auto buildFuncType = [](Builder &builder, ArrayRef<Type> argTypes, |
68 | 0 | ArrayRef<Type> results, impl::VariadicFlag, |
69 | 0 | std::string &) { |
70 | 0 | return builder.getFunctionType(argTypes, results); |
71 | 0 | }; |
72 | 0 |
|
73 | 0 | return impl::parseFunctionLikeOp(parser, result, /*allowVariadic=*/false, |
74 | 0 | buildFuncType); |
75 | 0 | } |
76 | | |
77 | 0 | void FuncOp::print(OpAsmPrinter &p) { |
78 | 0 | FunctionType fnType = getType(); |
79 | 0 | impl::printFunctionLikeOp(p, *this, fnType.getInputs(), /*isVariadic=*/false, |
80 | 0 | fnType.getResults()); |
81 | 0 | } |
82 | | |
83 | 0 | LogicalResult FuncOp::verify() { |
84 | 0 | // If this function is external there is nothing to do. |
85 | 0 | if (isExternal()) |
86 | 0 | return success(); |
87 | 0 | |
88 | 0 | // Verify that the argument list of the function and the arg list of the entry |
89 | 0 | // block line up. The trait already verified that the number of arguments is |
90 | 0 | // the same between the signature and the block. |
91 | 0 | auto fnInputTypes = getType().getInputs(); |
92 | 0 | Block &entryBlock = front(); |
93 | 0 | for (unsigned i = 0, e = entryBlock.getNumArguments(); i != e; ++i) |
94 | 0 | if (fnInputTypes[i] != entryBlock.getArgument(i).getType()) |
95 | 0 | return emitOpError("type of entry block argument #") |
96 | 0 | << i << '(' << entryBlock.getArgument(i).getType() |
97 | 0 | << ") must match the type of the corresponding argument in " |
98 | 0 | << "function signature(" << fnInputTypes[i] << ')'; |
99 | 0 |
|
100 | 0 | return success(); |
101 | 0 | } |
102 | | |
103 | 0 | void FuncOp::eraseArguments(ArrayRef<unsigned> argIndices) { |
104 | 0 | auto oldType = getType(); |
105 | 0 | int originalNumArgs = oldType.getNumInputs(); |
106 | 0 | llvm::BitVector eraseIndices(originalNumArgs); |
107 | 0 | for (auto index : argIndices) |
108 | 0 | eraseIndices.set(index); |
109 | 0 | auto shouldEraseArg = [&](int i) { return eraseIndices.test(i); }; |
110 | 0 |
|
111 | 0 | // There are 3 things that need to be updated: |
112 | 0 | // - Function type. |
113 | 0 | // - Arg attrs. |
114 | 0 | // - Block arguments of entry block. |
115 | 0 |
|
116 | 0 | // Update the function type and arg attrs. |
117 | 0 | SmallVector<Type, 4> newInputTypes; |
118 | 0 | SmallVector<MutableDictionaryAttr, 4> newArgAttrs; |
119 | 0 | for (int i = 0; i < originalNumArgs; i++) { |
120 | 0 | if (shouldEraseArg(i)) |
121 | 0 | continue; |
122 | 0 | newInputTypes.emplace_back(oldType.getInput(i)); |
123 | 0 | newArgAttrs.emplace_back(getArgAttrDict(i)); |
124 | 0 | } |
125 | 0 | setType(FunctionType::get(newInputTypes, oldType.getResults(), getContext())); |
126 | 0 | setAllArgAttrs(newArgAttrs); |
127 | 0 |
|
128 | 0 | // Update the entry block's arguments. |
129 | 0 | // We do this in reverse so that we erase later indices before earlier |
130 | 0 | // indices, to avoid shifting the later indices. |
131 | 0 | Block &entry = front(); |
132 | 0 | for (int i = 0; i < originalNumArgs; i++) |
133 | 0 | if (shouldEraseArg(originalNumArgs - i - 1)) |
134 | 0 | entry.eraseArgument(originalNumArgs - i - 1); |
135 | 0 | } |
136 | | |
137 | | /// Clone the internal blocks from this function into dest and all attributes |
138 | | /// from this function to dest. |
139 | 0 | void FuncOp::cloneInto(FuncOp dest, BlockAndValueMapping &mapper) { |
140 | 0 | // Add the attributes of this function to dest. |
141 | 0 | llvm::MapVector<Identifier, Attribute> newAttrs; |
142 | 0 | for (auto &attr : dest.getAttrs()) |
143 | 0 | newAttrs.insert(attr); |
144 | 0 | for (auto &attr : getAttrs()) |
145 | 0 | newAttrs.insert(attr); |
146 | 0 | dest.getOperation()->setAttrs( |
147 | 0 | DictionaryAttr::get(newAttrs.takeVector(), getContext())); |
148 | 0 |
|
149 | 0 | // Clone the body. |
150 | 0 | getBody().cloneInto(&dest.getBody(), mapper); |
151 | 0 | } |
152 | | |
153 | | /// Create a deep copy of this function and all of its blocks, remapping |
154 | | /// any operands that use values outside of the function using the map that is |
155 | | /// provided (leaving them alone if no entry is present). Replaces references |
156 | | /// to cloned sub-values with the corresponding value that is copied, and adds |
157 | | /// those mappings to the mapper. |
158 | 0 | FuncOp FuncOp::clone(BlockAndValueMapping &mapper) { |
159 | 0 | FunctionType newType = getType(); |
160 | 0 |
|
161 | 0 | // If the function has a body, then the user might be deleting arguments to |
162 | 0 | // the function by specifying them in the mapper. If so, we don't add the |
163 | 0 | // argument to the input type vector. |
164 | 0 | bool isExternalFn = isExternal(); |
165 | 0 | if (!isExternalFn) { |
166 | 0 | SmallVector<Type, 4> inputTypes; |
167 | 0 | inputTypes.reserve(newType.getNumInputs()); |
168 | 0 | for (unsigned i = 0, e = getNumArguments(); i != e; ++i) |
169 | 0 | if (!mapper.contains(getArgument(i))) |
170 | 0 | inputTypes.push_back(newType.getInput(i)); |
171 | 0 | newType = FunctionType::get(inputTypes, newType.getResults(), getContext()); |
172 | 0 | } |
173 | 0 |
|
174 | 0 | // Create the new function. |
175 | 0 | FuncOp newFunc = cast<FuncOp>(getOperation()->cloneWithoutRegions()); |
176 | 0 | newFunc.setType(newType); |
177 | 0 |
|
178 | 0 | /// Set the argument attributes for arguments that aren't being replaced. |
179 | 0 | for (unsigned i = 0, e = getNumArguments(), destI = 0; i != e; ++i) |
180 | 0 | if (isExternalFn || !mapper.contains(getArgument(i))) |
181 | 0 | newFunc.setArgAttrs(destI++, getArgAttrs(i)); |
182 | 0 |
|
183 | 0 | /// Clone the current function into the new one and return it. |
184 | 0 | cloneInto(newFunc, mapper); |
185 | 0 | return newFunc; |
186 | 0 | } |
187 | 0 | FuncOp FuncOp::clone() { |
188 | 0 | BlockAndValueMapping mapper; |
189 | 0 | return clone(mapper); |
190 | 0 | } |