/home/arjun/llvm-project/mlir/lib/IR/Operation.cpp
Line | Count | Source (jump to first uncovered line) |
1 | | //===- Operation.cpp - Operation support code -----------------------------===// |
2 | | // |
3 | | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | | // See https://llvm.org/LICENSE.txt for license information. |
5 | | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | | // |
7 | | //===----------------------------------------------------------------------===// |
8 | | |
9 | | #include "mlir/IR/Operation.h" |
10 | | #include "mlir/IR/BlockAndValueMapping.h" |
11 | | #include "mlir/IR/Dialect.h" |
12 | | #include "mlir/IR/OpImplementation.h" |
13 | | #include "mlir/IR/PatternMatch.h" |
14 | | #include "mlir/IR/StandardTypes.h" |
15 | | #include "mlir/IR/TypeUtilities.h" |
16 | | #include <numeric> |
17 | | |
18 | | using namespace mlir; |
19 | | |
20 | 0 | OpAsmParser::~OpAsmParser() {} |
21 | | |
22 | | //===----------------------------------------------------------------------===// |
23 | | // OperationName |
24 | | //===----------------------------------------------------------------------===// |
25 | | |
26 | | /// Form the OperationName for an op with the specified string. This either is |
27 | | /// a reference to an AbstractOperation if one is known, or a uniqued Identifier |
28 | | /// if not. |
29 | 0 | OperationName::OperationName(StringRef name, MLIRContext *context) { |
30 | 0 | if (auto *op = AbstractOperation::lookup(name, context)) |
31 | 0 | representation = op; |
32 | 0 | else |
33 | 0 | representation = Identifier::get(name, context); |
34 | 0 | } |
35 | | |
36 | | /// Return the name of the dialect this operation is registered to. |
37 | 0 | StringRef OperationName::getDialect() const { |
38 | 0 | return getStringRef().split('.').first; |
39 | 0 | } |
40 | | |
41 | | /// Return the name of this operation. This always succeeds. |
42 | 0 | StringRef OperationName::getStringRef() const { |
43 | 0 | if (auto *op = representation.dyn_cast<const AbstractOperation *>()) |
44 | 0 | return op->name; |
45 | 0 | return representation.get<Identifier>().strref(); |
46 | 0 | } |
47 | | |
48 | 0 | const AbstractOperation *OperationName::getAbstractOperation() const { |
49 | 0 | return representation.dyn_cast<const AbstractOperation *>(); |
50 | 0 | } |
51 | | |
52 | 0 | OperationName OperationName::getFromOpaquePointer(void *pointer) { |
53 | 0 | return OperationName(RepresentationUnion::getFromOpaqueValue(pointer)); |
54 | 0 | } |
55 | | |
56 | | //===----------------------------------------------------------------------===// |
57 | | // Operation |
58 | | //===----------------------------------------------------------------------===// |
59 | | |
60 | | /// Create a new Operation with the specific fields. |
61 | | Operation *Operation::create(Location location, OperationName name, |
62 | | ArrayRef<Type> resultTypes, |
63 | | ArrayRef<Value> operands, |
64 | | ArrayRef<NamedAttribute> attributes, |
65 | | ArrayRef<Block *> successors, |
66 | 0 | unsigned numRegions) { |
67 | 0 | return create(location, name, resultTypes, operands, |
68 | 0 | MutableDictionaryAttr(attributes), successors, numRegions); |
69 | 0 | } |
70 | | |
71 | | /// Create a new Operation from operation state. |
72 | 0 | Operation *Operation::create(const OperationState &state) { |
73 | 0 | return Operation::create(state.location, state.name, state.types, |
74 | 0 | state.operands, state.attributes, state.successors, |
75 | 0 | state.regions); |
76 | 0 | } |
77 | | |
78 | | /// Create a new Operation with the specific fields. |
79 | | Operation *Operation::create(Location location, OperationName name, |
80 | | ArrayRef<Type> resultTypes, |
81 | | ArrayRef<Value> operands, |
82 | | MutableDictionaryAttr attributes, |
83 | | ArrayRef<Block *> successors, |
84 | 0 | RegionRange regions) { |
85 | 0 | unsigned numRegions = regions.size(); |
86 | 0 | Operation *op = create(location, name, resultTypes, operands, attributes, |
87 | 0 | successors, numRegions); |
88 | 0 | for (unsigned i = 0; i < numRegions; ++i) |
89 | 0 | if (regions[i]) |
90 | 0 | op->getRegion(i).takeBody(*regions[i]); |
91 | 0 | return op; |
92 | 0 | } |
93 | | |
94 | | /// Overload of create that takes an existing MutableDictionaryAttr to avoid |
95 | | /// unnecessarily uniquing a list of attributes. |
96 | | Operation *Operation::create(Location location, OperationName name, |
97 | | ArrayRef<Type> resultTypes, |
98 | | ArrayRef<Value> operands, |
99 | | MutableDictionaryAttr attributes, |
100 | | ArrayRef<Block *> successors, |
101 | 0 | unsigned numRegions) { |
102 | 0 | // We only need to allocate additional memory for a subset of results. |
103 | 0 | unsigned numTrailingResults = OpResult::getNumTrailing(resultTypes.size()); |
104 | 0 | unsigned numInlineResults = OpResult::getNumInline(resultTypes.size()); |
105 | 0 | unsigned numSuccessors = successors.size(); |
106 | 0 | unsigned numOperands = operands.size(); |
107 | 0 |
|
108 | 0 | // If the operation is known to have no operands, don't allocate an operand |
109 | 0 | // storage. |
110 | 0 | bool needsOperandStorage = true; |
111 | 0 | if (operands.empty()) { |
112 | 0 | if (const AbstractOperation *abstractOp = name.getAbstractOperation()) |
113 | 0 | needsOperandStorage = !abstractOp->hasTrait<OpTrait::ZeroOperands>(); |
114 | 0 | } |
115 | 0 |
|
116 | 0 | // Compute the byte size for the operation and the operand storage. |
117 | 0 | auto byteSize = |
118 | 0 | totalSizeToAlloc<detail::InLineOpResult, detail::TrailingOpResult, |
119 | 0 | BlockOperand, Region, detail::OperandStorage>( |
120 | 0 | numInlineResults, numTrailingResults, numSuccessors, numRegions, |
121 | 0 | needsOperandStorage ? 1 : 0); |
122 | 0 | byteSize += |
123 | 0 | llvm::alignTo(detail::OperandStorage::additionalAllocSize(numOperands), |
124 | 0 | alignof(Operation)); |
125 | 0 | void *rawMem = malloc(byteSize); |
126 | 0 |
|
127 | 0 | // Create the new Operation. |
128 | 0 | Operation *op = |
129 | 0 | ::new (rawMem) Operation(location, name, resultTypes, numSuccessors, |
130 | 0 | numRegions, attributes, needsOperandStorage); |
131 | 0 |
|
132 | 0 | assert((numSuccessors == 0 || !op->isKnownNonTerminator()) && |
133 | 0 | "unexpected successors in a non-terminator operation"); |
134 | 0 |
|
135 | 0 | // Initialize the results. |
136 | 0 | for (unsigned i = 0; i < numInlineResults; ++i) |
137 | 0 | new (op->getInlineResult(i)) detail::InLineOpResult(); |
138 | 0 | for (unsigned i = 0; i < numTrailingResults; ++i) |
139 | 0 | new (op->getTrailingResult(i)) detail::TrailingOpResult(i); |
140 | 0 |
|
141 | 0 | // Initialize the regions. |
142 | 0 | for (unsigned i = 0; i != numRegions; ++i) |
143 | 0 | new (&op->getRegion(i)) Region(op); |
144 | 0 |
|
145 | 0 | // Initialize the operands. |
146 | 0 | if (needsOperandStorage) |
147 | 0 | new (&op->getOperandStorage()) detail::OperandStorage(op, operands); |
148 | 0 |
|
149 | 0 | // Initialize the successors. |
150 | 0 | auto blockOperands = op->getBlockOperands(); |
151 | 0 | for (unsigned i = 0; i != numSuccessors; ++i) |
152 | 0 | new (&blockOperands[i]) BlockOperand(op, successors[i]); |
153 | 0 |
|
154 | 0 | return op; |
155 | 0 | } |
156 | | |
157 | | Operation::Operation(Location location, OperationName name, |
158 | | ArrayRef<Type> resultTypes, unsigned numSuccessors, |
159 | | unsigned numRegions, |
160 | | const MutableDictionaryAttr &attributes, |
161 | | bool hasOperandStorage) |
162 | | : location(location), numSuccs(numSuccessors), numRegions(numRegions), |
163 | | hasOperandStorage(hasOperandStorage), hasSingleResult(false), name(name), |
164 | 0 | attrs(attributes) { |
165 | 0 | if (!resultTypes.empty()) { |
166 | 0 | // If there is a single result it is stored in-place, otherwise use a tuple. |
167 | 0 | hasSingleResult = resultTypes.size() == 1; |
168 | 0 | if (hasSingleResult) |
169 | 0 | resultType = resultTypes.front(); |
170 | 0 | else |
171 | 0 | resultType = TupleType::get(resultTypes, location->getContext()); |
172 | 0 | } |
173 | 0 | } |
174 | | |
175 | | // Operations are deleted through the destroy() member because they are |
176 | | // allocated via malloc. |
177 | 0 | Operation::~Operation() { |
178 | 0 | assert(block == nullptr && "operation destroyed but still in a block"); |
179 | 0 |
|
180 | 0 | // Explicitly run the destructors for the operands. |
181 | 0 | if (hasOperandStorage) |
182 | 0 | getOperandStorage().~OperandStorage(); |
183 | 0 |
|
184 | 0 | // Explicitly run the destructors for the successors. |
185 | 0 | for (auto &successor : getBlockOperands()) |
186 | 0 | successor.~BlockOperand(); |
187 | 0 |
|
188 | 0 | // Explicitly destroy the regions. |
189 | 0 | for (auto ®ion : getRegions()) |
190 | 0 | region.~Region(); |
191 | 0 | } |
192 | | |
193 | | /// Destroy this operation or one of its subclasses. |
194 | 0 | void Operation::destroy() { |
195 | 0 | this->~Operation(); |
196 | 0 | free(this); |
197 | 0 | } |
198 | | |
199 | | /// Return the context this operation is associated with. |
200 | 0 | MLIRContext *Operation::getContext() { return location->getContext(); } |
201 | | |
202 | | /// Return the dialect this operation is associated with, or nullptr if the |
203 | | /// associated dialect is not registered. |
204 | 0 | Dialect *Operation::getDialect() { |
205 | 0 | if (auto *abstractOp = getAbstractOperation()) |
206 | 0 | return &abstractOp->dialect; |
207 | 0 | |
208 | 0 | // If this operation hasn't been registered or doesn't have abstract |
209 | 0 | // operation, try looking up the dialect name in the context. |
210 | 0 | return getContext()->getRegisteredDialect(getName().getDialect()); |
211 | 0 | } |
212 | | |
213 | 0 | Region *Operation::getParentRegion() { |
214 | 0 | return block ? block->getParent() : nullptr; |
215 | 0 | } |
216 | | |
217 | 0 | Operation *Operation::getParentOp() { |
218 | 0 | return block ? block->getParentOp() : nullptr; |
219 | 0 | } |
220 | | |
221 | | /// Return true if this operation is a proper ancestor of the `other` |
222 | | /// operation. |
223 | 0 | bool Operation::isProperAncestor(Operation *other) { |
224 | 0 | while ((other = other->getParentOp())) |
225 | 0 | if (this == other) |
226 | 0 | return true; |
227 | 0 | return false; |
228 | 0 | } |
229 | | |
230 | | /// Replace any uses of 'from' with 'to' within this operation. |
231 | 0 | void Operation::replaceUsesOfWith(Value from, Value to) { |
232 | 0 | if (from == to) |
233 | 0 | return; |
234 | 0 | for (auto &operand : getOpOperands()) |
235 | 0 | if (operand.get() == from) |
236 | 0 | operand.set(to); |
237 | 0 | } |
238 | | |
239 | | /// Replace the current operands of this operation with the ones provided in |
240 | | /// 'operands'. |
241 | 0 | void Operation::setOperands(ValueRange operands) { |
242 | 0 | if (LLVM_LIKELY(hasOperandStorage)) |
243 | 0 | return getOperandStorage().setOperands(this, operands); |
244 | 0 | assert(operands.empty() && "setting operands without an operand storage"); |
245 | 0 | } |
246 | | |
247 | | /// Replace the operands beginning at 'start' and ending at 'start' + 'length' |
248 | | /// with the ones provided in 'operands'. 'operands' may be smaller or larger |
249 | | /// than the range pointed to by 'start'+'length'. |
250 | | void Operation::setOperands(unsigned start, unsigned length, |
251 | 0 | ValueRange operands) { |
252 | 0 | assert((start + length) <= getNumOperands() && |
253 | 0 | "invalid operand range specified"); |
254 | 0 | if (LLVM_LIKELY(hasOperandStorage)) |
255 | 0 | return getOperandStorage().setOperands(this, start, length, operands); |
256 | 0 | assert(operands.empty() && "setting operands without an operand storage"); |
257 | 0 | } |
258 | | |
259 | | /// Insert the given operands into the operand list at the given 'index'. |
260 | 0 | void Operation::insertOperands(unsigned index, ValueRange operands) { |
261 | 0 | if (LLVM_LIKELY(hasOperandStorage)) |
262 | 0 | return setOperands(index, /*length=*/0, operands); |
263 | 0 | assert(operands.empty() && "inserting operands without an operand storage"); |
264 | 0 | } |
265 | | |
266 | | //===----------------------------------------------------------------------===// |
267 | | // Diagnostics |
268 | | //===----------------------------------------------------------------------===// |
269 | | |
270 | | /// Emit an error about fatal conditions with this operation, reporting up to |
271 | | /// any diagnostic handlers that may be listening. |
272 | 0 | InFlightDiagnostic Operation::emitError(const Twine &message) { |
273 | 0 | InFlightDiagnostic diag = mlir::emitError(getLoc(), message); |
274 | 0 | if (getContext()->shouldPrintOpOnDiagnostic()) { |
275 | 0 | // Print out the operation explicitly here so that we can print the generic |
276 | 0 | // form. |
277 | 0 | // TODO(riverriddle) It would be nice if we could instead provide the |
278 | 0 | // specific printing flags when adding the operation as an argument to the |
279 | 0 | // diagnostic. |
280 | 0 | std::string printedOp; |
281 | 0 | { |
282 | 0 | llvm::raw_string_ostream os(printedOp); |
283 | 0 | print(os, OpPrintingFlags().printGenericOpForm().useLocalScope()); |
284 | 0 | } |
285 | 0 | diag.attachNote(getLoc()) << "see current operation: " << printedOp; |
286 | 0 | } |
287 | 0 | return diag; |
288 | 0 | } |
289 | | |
290 | | /// Emit a warning about this operation, reporting up to any diagnostic |
291 | | /// handlers that may be listening. |
292 | 0 | InFlightDiagnostic Operation::emitWarning(const Twine &message) { |
293 | 0 | InFlightDiagnostic diag = mlir::emitWarning(getLoc(), message); |
294 | 0 | if (getContext()->shouldPrintOpOnDiagnostic()) |
295 | 0 | diag.attachNote(getLoc()) << "see current operation: " << *this; |
296 | 0 | return diag; |
297 | 0 | } |
298 | | |
299 | | /// Emit a remark about this operation, reporting up to any diagnostic |
300 | | /// handlers that may be listening. |
301 | 0 | InFlightDiagnostic Operation::emitRemark(const Twine &message) { |
302 | 0 | InFlightDiagnostic diag = mlir::emitRemark(getLoc(), message); |
303 | 0 | if (getContext()->shouldPrintOpOnDiagnostic()) |
304 | 0 | diag.attachNote(getLoc()) << "see current operation: " << *this; |
305 | 0 | return diag; |
306 | 0 | } |
307 | | |
308 | | //===----------------------------------------------------------------------===// |
309 | | // Operation Ordering |
310 | | //===----------------------------------------------------------------------===// |
311 | | |
312 | | constexpr unsigned Operation::kInvalidOrderIdx; |
313 | | constexpr unsigned Operation::kOrderStride; |
314 | | |
315 | | /// Given an operation 'other' that is within the same parent block, return |
316 | | /// whether the current operation is before 'other' in the operation list |
317 | | /// of the parent block. |
318 | | /// Note: This function has an average complexity of O(1), but worst case may |
319 | | /// take O(N) where N is the number of operations within the parent block. |
320 | | bool Operation::isBeforeInBlock(Operation *other) { |
321 | | assert(block && "Operations without parent blocks have no order."); |
322 | | assert(other && other->block == block && |
323 | | "Expected other operation to have the same parent block."); |
324 | | // If the order of the block is already invalid, directly recompute the |
325 | | // parent. |
326 | | if (!block->isOpOrderValid()) { |
327 | | block->recomputeOpOrder(); |
328 | | } else { |
329 | | // Update the order either operation if necessary. |
330 | | updateOrderIfNecessary(); |
331 | | other->updateOrderIfNecessary(); |
332 | | } |
333 | | |
334 | | return orderIndex < other->orderIndex; |
335 | | } |
336 | | |
337 | | /// Update the order index of this operation of this operation if necessary, |
338 | | /// potentially recomputing the order of the parent block. |
339 | 0 | void Operation::updateOrderIfNecessary() { |
340 | 0 | assert(block && "expected valid parent"); |
341 | 0 |
|
342 | 0 | // If the order is valid for this operation there is nothing to do. |
343 | 0 | if (hasValidOrder()) |
344 | 0 | return; |
345 | 0 | Operation *blockFront = &block->front(); |
346 | 0 | Operation *blockBack = &block->back(); |
347 | 0 |
|
348 | 0 | // This method is expected to only be invoked on blocks with more than one |
349 | 0 | // operation. |
350 | 0 | assert(blockFront != blockBack && "expected more than one operation"); |
351 | 0 |
|
352 | 0 | // If the operation is at the end of the block. |
353 | 0 | if (this == blockBack) { |
354 | 0 | Operation *prevNode = getPrevNode(); |
355 | 0 | if (!prevNode->hasValidOrder()) |
356 | 0 | return block->recomputeOpOrder(); |
357 | 0 | |
358 | 0 | // Add the stride to the previous operation. |
359 | 0 | orderIndex = prevNode->orderIndex + kOrderStride; |
360 | 0 | return; |
361 | 0 | } |
362 | 0 | |
363 | 0 | // If this is the first operation try to use the next operation to compute the |
364 | 0 | // ordering. |
365 | 0 | if (this == blockFront) { |
366 | 0 | Operation *nextNode = getNextNode(); |
367 | 0 | if (!nextNode->hasValidOrder()) |
368 | 0 | return block->recomputeOpOrder(); |
369 | 0 | // There is no order to give this operation. |
370 | 0 | if (nextNode->orderIndex == 0) |
371 | 0 | return block->recomputeOpOrder(); |
372 | 0 | |
373 | 0 | // If we can't use the stride, just take the middle value left. This is safe |
374 | 0 | // because we know there is at least one valid index to assign to. |
375 | 0 | if (nextNode->orderIndex <= kOrderStride) |
376 | 0 | orderIndex = (nextNode->orderIndex / 2); |
377 | 0 | else |
378 | 0 | orderIndex = kOrderStride; |
379 | 0 | return; |
380 | 0 | } |
381 | 0 |
|
382 | 0 | // Otherwise, this operation is between two others. Place this operation in |
383 | 0 | // the middle of the previous and next if possible. |
384 | 0 | Operation *prevNode = getPrevNode(), *nextNode = getNextNode(); |
385 | 0 | if (!prevNode->hasValidOrder() || !nextNode->hasValidOrder()) |
386 | 0 | return block->recomputeOpOrder(); |
387 | 0 | unsigned prevOrder = prevNode->orderIndex, nextOrder = nextNode->orderIndex; |
388 | 0 |
|
389 | 0 | // Check to see if there is a valid order between the two. |
390 | 0 | if (prevOrder + 1 == nextOrder) |
391 | 0 | return block->recomputeOpOrder(); |
392 | 0 | orderIndex = prevOrder + 1 + ((nextOrder - prevOrder) / 2); |
393 | 0 | } |
394 | | |
395 | | //===----------------------------------------------------------------------===// |
396 | | // ilist_traits for Operation |
397 | | //===----------------------------------------------------------------------===// |
398 | | |
399 | | auto llvm::ilist_detail::SpecificNodeAccess< |
400 | | typename llvm::ilist_detail::compute_node_options< |
401 | 0 | ::mlir::Operation>::type>::getNodePtr(pointer N) -> node_type * { |
402 | 0 | return NodeAccess::getNodePtr<OptionsT>(N); |
403 | 0 | } |
404 | | |
405 | | auto llvm::ilist_detail::SpecificNodeAccess< |
406 | | typename llvm::ilist_detail::compute_node_options< |
407 | | ::mlir::Operation>::type>::getNodePtr(const_pointer N) |
408 | 0 | -> const node_type * { |
409 | 0 | return NodeAccess::getNodePtr<OptionsT>(N); |
410 | 0 | } |
411 | | |
412 | | auto llvm::ilist_detail::SpecificNodeAccess< |
413 | | typename llvm::ilist_detail::compute_node_options< |
414 | 0 | ::mlir::Operation>::type>::getValuePtr(node_type *N) -> pointer { |
415 | 0 | return NodeAccess::getValuePtr<OptionsT>(N); |
416 | 0 | } |
417 | | |
418 | | auto llvm::ilist_detail::SpecificNodeAccess< |
419 | | typename llvm::ilist_detail::compute_node_options< |
420 | | ::mlir::Operation>::type>::getValuePtr(const node_type *N) |
421 | 0 | -> const_pointer { |
422 | 0 | return NodeAccess::getValuePtr<OptionsT>(N); |
423 | 0 | } |
424 | | |
425 | 0 | void llvm::ilist_traits<::mlir::Operation>::deleteNode(Operation *op) { |
426 | 0 | op->destroy(); |
427 | 0 | } |
428 | | |
429 | 0 | Block *llvm::ilist_traits<::mlir::Operation>::getContainingBlock() { |
430 | 0 | size_t Offset(size_t(&((Block *)nullptr->*Block::getSublistAccess(nullptr)))); |
431 | 0 | iplist<Operation> *Anchor(static_cast<iplist<Operation> *>(this)); |
432 | 0 | return reinterpret_cast<Block *>(reinterpret_cast<char *>(Anchor) - Offset); |
433 | 0 | } |
434 | | |
435 | | /// This is a trait method invoked when an operation is added to a block. We |
436 | | /// keep the block pointer up to date. |
437 | 0 | void llvm::ilist_traits<::mlir::Operation>::addNodeToList(Operation *op) { |
438 | 0 | assert(!op->getBlock() && "already in an operation block!"); |
439 | 0 | op->block = getContainingBlock(); |
440 | 0 |
|
441 | 0 | // Invalidate the order on the operation. |
442 | 0 | op->orderIndex = Operation::kInvalidOrderIdx; |
443 | 0 | } |
444 | | |
445 | | /// This is a trait method invoked when an operation is removed from a block. |
446 | | /// We keep the block pointer up to date. |
447 | 0 | void llvm::ilist_traits<::mlir::Operation>::removeNodeFromList(Operation *op) { |
448 | 0 | assert(op->block && "not already in an operation block!"); |
449 | 0 | op->block = nullptr; |
450 | 0 | } |
451 | | |
452 | | /// This is a trait method invoked when an operation is moved from one block |
453 | | /// to another. We keep the block pointer up to date. |
454 | | void llvm::ilist_traits<::mlir::Operation>::transferNodesFromList( |
455 | 0 | ilist_traits<Operation> &otherList, op_iterator first, op_iterator last) { |
456 | 0 | Block *curParent = getContainingBlock(); |
457 | 0 |
|
458 | 0 | // Invalidate the ordering of the parent block. |
459 | 0 | curParent->invalidateOpOrder(); |
460 | 0 |
|
461 | 0 | // If we are transferring operations within the same block, the block |
462 | 0 | // pointer doesn't need to be updated. |
463 | 0 | if (curParent == otherList.getContainingBlock()) |
464 | 0 | return; |
465 | 0 | |
466 | 0 | // Update the 'block' member of each operation. |
467 | 0 | for (; first != last; ++first) |
468 | 0 | first->block = curParent; |
469 | 0 | } |
470 | | |
471 | | /// Remove this operation (and its descendants) from its Block and delete |
472 | | /// all of them. |
473 | 0 | void Operation::erase() { |
474 | 0 | if (auto *parent = getBlock()) |
475 | 0 | parent->getOperations().erase(this); |
476 | 0 | else |
477 | 0 | destroy(); |
478 | 0 | } |
479 | | |
480 | | /// Unlink this operation from its current block and insert it right before |
481 | | /// `existingOp` which may be in the same or another block in the same |
482 | | /// function. |
483 | 0 | void Operation::moveBefore(Operation *existingOp) { |
484 | 0 | moveBefore(existingOp->getBlock(), existingOp->getIterator()); |
485 | 0 | } |
486 | | |
487 | | /// Unlink this operation from its current basic block and insert it right |
488 | | /// before `iterator` in the specified basic block. |
489 | | void Operation::moveBefore(Block *block, |
490 | 0 | llvm::iplist<Operation>::iterator iterator) { |
491 | 0 | block->getOperations().splice(iterator, getBlock()->getOperations(), |
492 | 0 | getIterator()); |
493 | 0 | } |
494 | | |
495 | | /// Unlink this operation from its current block and insert it right after |
496 | | /// `existingOp` which may be in the same or another block in the same function. |
497 | 0 | void Operation::moveAfter(Operation *existingOp) { |
498 | 0 | moveAfter(existingOp->getBlock(), existingOp->getIterator()); |
499 | 0 | } |
500 | | |
501 | | /// Unlink this operation from its current block and insert it right after |
502 | | /// `iterator` in the specified block. |
503 | | void Operation::moveAfter(Block *block, |
504 | 0 | llvm::iplist<Operation>::iterator iterator) { |
505 | 0 | assert(iterator != block->end() && "cannot move after end of block"); |
506 | 0 | moveBefore(&*std::next(iterator)); |
507 | 0 | } |
508 | | |
509 | | /// This drops all operand uses from this operation, which is an essential |
510 | | /// step in breaking cyclic dependences between references when they are to |
511 | | /// be deleted. |
512 | 0 | void Operation::dropAllReferences() { |
513 | 0 | for (auto &op : getOpOperands()) |
514 | 0 | op.drop(); |
515 | 0 |
|
516 | 0 | for (auto ®ion : getRegions()) |
517 | 0 | region.dropAllReferences(); |
518 | 0 |
|
519 | 0 | for (auto &dest : getBlockOperands()) |
520 | 0 | dest.drop(); |
521 | 0 | } |
522 | | |
523 | | /// This drops all uses of any values defined by this operation or its nested |
524 | | /// regions, wherever they are located. |
525 | 0 | void Operation::dropAllDefinedValueUses() { |
526 | 0 | dropAllUses(); |
527 | 0 |
|
528 | 0 | for (auto ®ion : getRegions()) |
529 | 0 | for (auto &block : region) |
530 | 0 | block.dropAllDefinedValueUses(); |
531 | 0 | } |
532 | | |
533 | | /// Return the number of results held by this operation. |
534 | 0 | unsigned Operation::getNumResults() { |
535 | 0 | if (!resultType) |
536 | 0 | return 0; |
537 | 0 | return hasSingleResult ? 1 : resultType.cast<TupleType>().size(); |
538 | 0 | } |
539 | | |
540 | 0 | auto Operation::getResultTypes() -> result_type_range { |
541 | 0 | if (!resultType) |
542 | 0 | return llvm::None; |
543 | 0 | if (hasSingleResult) |
544 | 0 | return resultType; |
545 | 0 | return resultType.cast<TupleType>().getTypes(); |
546 | 0 | } |
547 | | |
548 | 0 | void Operation::setSuccessor(Block *block, unsigned index) { |
549 | 0 | assert(index < getNumSuccessors()); |
550 | 0 | getBlockOperands()[index].set(block); |
551 | 0 | } |
552 | | |
553 | | /// Attempt to fold this operation using the Op's registered foldHook. |
554 | | LogicalResult Operation::fold(ArrayRef<Attribute> operands, |
555 | 0 | SmallVectorImpl<OpFoldResult> &results) { |
556 | 0 | // If we have a registered operation definition matching this one, use it to |
557 | 0 | // try to constant fold the operation. |
558 | 0 | auto *abstractOp = getAbstractOperation(); |
559 | 0 | if (abstractOp && succeeded(abstractOp->foldHook(this, operands, results))) |
560 | 0 | return success(); |
561 | 0 | |
562 | 0 | // Otherwise, fall back on the dialect hook to handle it. |
563 | 0 | Dialect *dialect = getDialect(); |
564 | 0 | if (!dialect) |
565 | 0 | return failure(); |
566 | 0 | |
567 | 0 | SmallVector<Attribute, 8> constants; |
568 | 0 | if (failed(dialect->constantFoldHook(this, operands, constants))) |
569 | 0 | return failure(); |
570 | 0 | results.assign(constants.begin(), constants.end()); |
571 | 0 | return success(); |
572 | 0 | } |
573 | | |
574 | | /// Emit an error with the op name prefixed, like "'dim' op " which is |
575 | | /// convenient for verifiers. |
576 | 0 | InFlightDiagnostic Operation::emitOpError(const Twine &message) { |
577 | 0 | return emitError() << "'" << getName() << "' op " << message; |
578 | 0 | } |
579 | | |
580 | | //===----------------------------------------------------------------------===// |
581 | | // Operation Cloning |
582 | | //===----------------------------------------------------------------------===// |
583 | | |
584 | | /// Create a deep copy of this operation but keep the operation regions empty. |
585 | | /// Operands are remapped using `mapper` (if present), and `mapper` is updated |
586 | | /// to contain the results. |
587 | 0 | Operation *Operation::cloneWithoutRegions(BlockAndValueMapping &mapper) { |
588 | 0 | SmallVector<Value, 8> operands; |
589 | 0 | SmallVector<Block *, 2> successors; |
590 | 0 |
|
591 | 0 | // Remap the operands. |
592 | 0 | operands.reserve(getNumOperands()); |
593 | 0 | for (auto opValue : getOperands()) |
594 | 0 | operands.push_back(mapper.lookupOrDefault(opValue)); |
595 | 0 |
|
596 | 0 | // Remap the successors. |
597 | 0 | successors.reserve(getNumSuccessors()); |
598 | 0 | for (Block *successor : getSuccessors()) |
599 | 0 | successors.push_back(mapper.lookupOrDefault(successor)); |
600 | 0 |
|
601 | 0 | // Create the new operation. |
602 | 0 | auto *newOp = Operation::create(getLoc(), getName(), getResultTypes(), |
603 | 0 | operands, attrs, successors, getNumRegions()); |
604 | 0 |
|
605 | 0 | // Remember the mapping of any results. |
606 | 0 | for (unsigned i = 0, e = getNumResults(); i != e; ++i) |
607 | 0 | mapper.map(getResult(i), newOp->getResult(i)); |
608 | 0 |
|
609 | 0 | return newOp; |
610 | 0 | } |
611 | | |
612 | 0 | Operation *Operation::cloneWithoutRegions() { |
613 | 0 | BlockAndValueMapping mapper; |
614 | 0 | return cloneWithoutRegions(mapper); |
615 | 0 | } |
616 | | |
617 | | /// Create a deep copy of this operation, remapping any operands that use |
618 | | /// values outside of the operation using the map that is provided (leaving |
619 | | /// them alone if no entry is present). Replaces references to cloned |
620 | | /// sub-operations to the corresponding operation that is copied, and adds |
621 | | /// those mappings to the map. |
622 | 0 | Operation *Operation::clone(BlockAndValueMapping &mapper) { |
623 | 0 | auto *newOp = cloneWithoutRegions(mapper); |
624 | 0 |
|
625 | 0 | // Clone the regions. |
626 | 0 | for (unsigned i = 0; i != numRegions; ++i) |
627 | 0 | getRegion(i).cloneInto(&newOp->getRegion(i), mapper); |
628 | 0 |
|
629 | 0 | return newOp; |
630 | 0 | } |
631 | | |
632 | 0 | Operation *Operation::clone() { |
633 | 0 | BlockAndValueMapping mapper; |
634 | 0 | return clone(mapper); |
635 | 0 | } |
636 | | |
637 | | //===----------------------------------------------------------------------===// |
638 | | // OpState trait class. |
639 | | //===----------------------------------------------------------------------===// |
640 | | |
641 | | // The fallback for the parser is to reject the custom assembly form. |
642 | 0 | ParseResult OpState::parse(OpAsmParser &parser, OperationState &result) { |
643 | 0 | return parser.emitError(parser.getNameLoc(), "has no custom assembly form"); |
644 | 0 | } |
645 | | |
646 | | // The fallback for the printer is to print in the generic assembly form. |
647 | 0 | void OpState::print(OpAsmPrinter &p) { p.printGenericOp(getOperation()); } |
648 | | |
649 | | /// Emit an error about fatal conditions with this operation, reporting up to |
650 | | /// any diagnostic handlers that may be listening. |
651 | 0 | InFlightDiagnostic OpState::emitError(const Twine &message) { |
652 | 0 | return getOperation()->emitError(message); |
653 | 0 | } |
654 | | |
655 | | /// Emit an error with the op name prefixed, like "'dim' op " which is |
656 | | /// convenient for verifiers. |
657 | 0 | InFlightDiagnostic OpState::emitOpError(const Twine &message) { |
658 | 0 | return getOperation()->emitOpError(message); |
659 | 0 | } |
660 | | |
661 | | /// Emit a warning about this operation, reporting up to any diagnostic |
662 | | /// handlers that may be listening. |
663 | 0 | InFlightDiagnostic OpState::emitWarning(const Twine &message) { |
664 | 0 | return getOperation()->emitWarning(message); |
665 | 0 | } |
666 | | |
667 | | /// Emit a remark about this operation, reporting up to any diagnostic |
668 | | /// handlers that may be listening. |
669 | 0 | InFlightDiagnostic OpState::emitRemark(const Twine &message) { |
670 | 0 | return getOperation()->emitRemark(message); |
671 | 0 | } |
672 | | |
673 | | //===----------------------------------------------------------------------===// |
674 | | // Op Trait implementations |
675 | | //===----------------------------------------------------------------------===// |
676 | | |
677 | 0 | LogicalResult OpTrait::impl::verifyZeroOperands(Operation *op) { |
678 | 0 | if (op->getNumOperands() != 0) |
679 | 0 | return op->emitOpError() << "requires zero operands"; |
680 | 0 | return success(); |
681 | 0 | } |
682 | | |
683 | 0 | LogicalResult OpTrait::impl::verifyOneOperand(Operation *op) { |
684 | 0 | if (op->getNumOperands() != 1) |
685 | 0 | return op->emitOpError() << "requires a single operand"; |
686 | 0 | return success(); |
687 | 0 | } |
688 | | |
689 | | LogicalResult OpTrait::impl::verifyNOperands(Operation *op, |
690 | 0 | unsigned numOperands) { |
691 | 0 | if (op->getNumOperands() != numOperands) { |
692 | 0 | return op->emitOpError() << "expected " << numOperands |
693 | 0 | << " operands, but found " << op->getNumOperands(); |
694 | 0 | } |
695 | 0 | return success(); |
696 | 0 | } |
697 | | |
698 | | LogicalResult OpTrait::impl::verifyAtLeastNOperands(Operation *op, |
699 | 0 | unsigned numOperands) { |
700 | 0 | if (op->getNumOperands() < numOperands) |
701 | 0 | return op->emitOpError() |
702 | 0 | << "expected " << numOperands << " or more operands"; |
703 | 0 | return success(); |
704 | 0 | } |
705 | | |
706 | | /// If this is a vector type, or a tensor type, return the scalar element type |
707 | | /// that it is built around, otherwise return the type unmodified. |
708 | 0 | static Type getTensorOrVectorElementType(Type type) { |
709 | 0 | if (auto vec = type.dyn_cast<VectorType>()) |
710 | 0 | return vec.getElementType(); |
711 | 0 | |
712 | 0 | // Look through tensor<vector<...>> to find the underlying element type. |
713 | 0 | if (auto tensor = type.dyn_cast<TensorType>()) |
714 | 0 | return getTensorOrVectorElementType(tensor.getElementType()); |
715 | 0 | return type; |
716 | 0 | } |
717 | | |
718 | | LogicalResult |
719 | 0 | OpTrait::impl::verifyOperandsAreSignlessIntegerLike(Operation *op) { |
720 | 0 | for (auto opType : op->getOperandTypes()) { |
721 | 0 | auto type = getTensorOrVectorElementType(opType); |
722 | 0 | if (!type.isSignlessIntOrIndex()) |
723 | 0 | return op->emitOpError() << "requires an integer or index type"; |
724 | 0 | } |
725 | 0 | return success(); |
726 | 0 | } |
727 | | |
728 | 0 | LogicalResult OpTrait::impl::verifyOperandsAreFloatLike(Operation *op) { |
729 | 0 | for (auto opType : op->getOperandTypes()) { |
730 | 0 | auto type = getTensorOrVectorElementType(opType); |
731 | 0 | if (!type.isa<FloatType>()) |
732 | 0 | return op->emitOpError("requires a float type"); |
733 | 0 | } |
734 | 0 | return success(); |
735 | 0 | } |
736 | | |
737 | 0 | LogicalResult OpTrait::impl::verifySameTypeOperands(Operation *op) { |
738 | 0 | // Zero or one operand always have the "same" type. |
739 | 0 | unsigned nOperands = op->getNumOperands(); |
740 | 0 | if (nOperands < 2) |
741 | 0 | return success(); |
742 | 0 | |
743 | 0 | auto type = op->getOperand(0).getType(); |
744 | 0 | for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1)) |
745 | 0 | if (opType != type) |
746 | 0 | return op->emitOpError() << "requires all operands to have the same type"; |
747 | 0 | return success(); |
748 | 0 | } |
749 | | |
750 | 0 | LogicalResult OpTrait::impl::verifyZeroRegion(Operation *op) { |
751 | 0 | if (op->getNumRegions() != 0) |
752 | 0 | return op->emitOpError() << "requires zero regions"; |
753 | 0 | return success(); |
754 | 0 | } |
755 | | |
756 | 0 | LogicalResult OpTrait::impl::verifyOneRegion(Operation *op) { |
757 | 0 | if (op->getNumRegions() != 1) |
758 | 0 | return op->emitOpError() << "requires one region"; |
759 | 0 | return success(); |
760 | 0 | } |
761 | | |
762 | | LogicalResult OpTrait::impl::verifyNRegions(Operation *op, |
763 | 0 | unsigned numRegions) { |
764 | 0 | if (op->getNumRegions() != numRegions) |
765 | 0 | return op->emitOpError() << "expected " << numRegions << " regions"; |
766 | 0 | return success(); |
767 | 0 | } |
768 | | |
769 | | LogicalResult OpTrait::impl::verifyAtLeastNRegions(Operation *op, |
770 | 0 | unsigned numRegions) { |
771 | 0 | if (op->getNumRegions() < numRegions) |
772 | 0 | return op->emitOpError() << "expected " << numRegions << " or more regions"; |
773 | 0 | return success(); |
774 | 0 | } |
775 | | |
776 | 0 | LogicalResult OpTrait::impl::verifyZeroResult(Operation *op) { |
777 | 0 | if (op->getNumResults() != 0) |
778 | 0 | return op->emitOpError() << "requires zero results"; |
779 | 0 | return success(); |
780 | 0 | } |
781 | | |
782 | 0 | LogicalResult OpTrait::impl::verifyOneResult(Operation *op) { |
783 | 0 | if (op->getNumResults() != 1) |
784 | 0 | return op->emitOpError() << "requires one result"; |
785 | 0 | return success(); |
786 | 0 | } |
787 | | |
788 | | LogicalResult OpTrait::impl::verifyNResults(Operation *op, |
789 | 0 | unsigned numOperands) { |
790 | 0 | if (op->getNumResults() != numOperands) |
791 | 0 | return op->emitOpError() << "expected " << numOperands << " results"; |
792 | 0 | return success(); |
793 | 0 | } |
794 | | |
795 | | LogicalResult OpTrait::impl::verifyAtLeastNResults(Operation *op, |
796 | 0 | unsigned numOperands) { |
797 | 0 | if (op->getNumResults() < numOperands) |
798 | 0 | return op->emitOpError() |
799 | 0 | << "expected " << numOperands << " or more results"; |
800 | 0 | return success(); |
801 | 0 | } |
802 | | |
803 | 0 | LogicalResult OpTrait::impl::verifySameOperandsShape(Operation *op) { |
804 | 0 | if (failed(verifyAtLeastNOperands(op, 1))) |
805 | 0 | return failure(); |
806 | 0 | |
807 | 0 | auto type = op->getOperand(0).getType(); |
808 | 0 | for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1)) { |
809 | 0 | if (failed(verifyCompatibleShape(opType, type))) |
810 | 0 | return op->emitOpError() << "requires the same shape for all operands"; |
811 | 0 | } |
812 | 0 | return success(); |
813 | 0 | } |
814 | | |
815 | 0 | LogicalResult OpTrait::impl::verifySameOperandsAndResultShape(Operation *op) { |
816 | 0 | if (failed(verifyAtLeastNOperands(op, 1)) || |
817 | 0 | failed(verifyAtLeastNResults(op, 1))) |
818 | 0 | return failure(); |
819 | 0 | |
820 | 0 | auto type = op->getOperand(0).getType(); |
821 | 0 | for (auto resultType : op->getResultTypes()) { |
822 | 0 | if (failed(verifyCompatibleShape(resultType, type))) |
823 | 0 | return op->emitOpError() |
824 | 0 | << "requires the same shape for all operands and results"; |
825 | 0 | } |
826 | 0 | for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1)) { |
827 | 0 | if (failed(verifyCompatibleShape(opType, type))) |
828 | 0 | return op->emitOpError() |
829 | 0 | << "requires the same shape for all operands and results"; |
830 | 0 | } |
831 | 0 | return success(); |
832 | 0 | } |
833 | | |
834 | 0 | LogicalResult OpTrait::impl::verifySameOperandsElementType(Operation *op) { |
835 | 0 | if (failed(verifyAtLeastNOperands(op, 1))) |
836 | 0 | return failure(); |
837 | 0 | auto elementType = getElementTypeOrSelf(op->getOperand(0)); |
838 | 0 |
|
839 | 0 | for (auto operand : llvm::drop_begin(op->getOperands(), 1)) { |
840 | 0 | if (getElementTypeOrSelf(operand) != elementType) |
841 | 0 | return op->emitOpError("requires the same element type for all operands"); |
842 | 0 | } |
843 | 0 |
|
844 | 0 | return success(); |
845 | 0 | } |
846 | | |
847 | | LogicalResult |
848 | 0 | OpTrait::impl::verifySameOperandsAndResultElementType(Operation *op) { |
849 | 0 | if (failed(verifyAtLeastNOperands(op, 1)) || |
850 | 0 | failed(verifyAtLeastNResults(op, 1))) |
851 | 0 | return failure(); |
852 | 0 | |
853 | 0 | auto elementType = getElementTypeOrSelf(op->getResult(0)); |
854 | 0 |
|
855 | 0 | // Verify result element type matches first result's element type. |
856 | 0 | for (auto result : llvm::drop_begin(op->getResults(), 1)) { |
857 | 0 | if (getElementTypeOrSelf(result) != elementType) |
858 | 0 | return op->emitOpError( |
859 | 0 | "requires the same element type for all operands and results"); |
860 | 0 | } |
861 | 0 |
|
862 | 0 | // Verify operand's element type matches first result's element type. |
863 | 0 | for (auto operand : op->getOperands()) { |
864 | 0 | if (getElementTypeOrSelf(operand) != elementType) |
865 | 0 | return op->emitOpError( |
866 | 0 | "requires the same element type for all operands and results"); |
867 | 0 | } |
868 | 0 |
|
869 | 0 | return success(); |
870 | 0 | } |
871 | | |
872 | 0 | LogicalResult OpTrait::impl::verifySameOperandsAndResultType(Operation *op) { |
873 | 0 | if (failed(verifyAtLeastNOperands(op, 1)) || |
874 | 0 | failed(verifyAtLeastNResults(op, 1))) |
875 | 0 | return failure(); |
876 | 0 | |
877 | 0 | auto type = op->getResult(0).getType(); |
878 | 0 | auto elementType = getElementTypeOrSelf(type); |
879 | 0 | for (auto resultType : op->getResultTypes().drop_front(1)) { |
880 | 0 | if (getElementTypeOrSelf(resultType) != elementType || |
881 | 0 | failed(verifyCompatibleShape(resultType, type))) |
882 | 0 | return op->emitOpError() |
883 | 0 | << "requires the same type for all operands and results"; |
884 | 0 | } |
885 | 0 | for (auto opType : op->getOperandTypes()) { |
886 | 0 | if (getElementTypeOrSelf(opType) != elementType || |
887 | 0 | failed(verifyCompatibleShape(opType, type))) |
888 | 0 | return op->emitOpError() |
889 | 0 | << "requires the same type for all operands and results"; |
890 | 0 | } |
891 | 0 | return success(); |
892 | 0 | } |
893 | | |
894 | 0 | LogicalResult OpTrait::impl::verifyIsTerminator(Operation *op) { |
895 | 0 | Block *block = op->getBlock(); |
896 | 0 | // Verify that the operation is at the end of the respective parent block. |
897 | 0 | if (!block || &block->back() != op) |
898 | 0 | return op->emitOpError("must be the last operation in the parent block"); |
899 | 0 | return success(); |
900 | 0 | } |
901 | | |
902 | 0 | static LogicalResult verifyTerminatorSuccessors(Operation *op) { |
903 | 0 | auto *parent = op->getParentRegion(); |
904 | 0 |
|
905 | 0 | // Verify that the operands lines up with the BB arguments in the successor. |
906 | 0 | for (Block *succ : op->getSuccessors()) |
907 | 0 | if (succ->getParent() != parent) |
908 | 0 | return op->emitError("reference to block defined in another region"); |
909 | 0 | return success(); |
910 | 0 | } |
911 | | |
912 | 0 | LogicalResult OpTrait::impl::verifyZeroSuccessor(Operation *op) { |
913 | 0 | if (op->getNumSuccessors() != 0) { |
914 | 0 | return op->emitOpError("requires 0 successors but found ") |
915 | 0 | << op->getNumSuccessors(); |
916 | 0 | } |
917 | 0 | return success(); |
918 | 0 | } |
919 | | |
920 | 0 | LogicalResult OpTrait::impl::verifyOneSuccessor(Operation *op) { |
921 | 0 | if (op->getNumSuccessors() != 1) { |
922 | 0 | return op->emitOpError("requires 1 successor but found ") |
923 | 0 | << op->getNumSuccessors(); |
924 | 0 | } |
925 | 0 | return verifyTerminatorSuccessors(op); |
926 | 0 | } |
927 | | LogicalResult OpTrait::impl::verifyNSuccessors(Operation *op, |
928 | 0 | unsigned numSuccessors) { |
929 | 0 | if (op->getNumSuccessors() != numSuccessors) { |
930 | 0 | return op->emitOpError("requires ") |
931 | 0 | << numSuccessors << " successors but found " |
932 | 0 | << op->getNumSuccessors(); |
933 | 0 | } |
934 | 0 | return verifyTerminatorSuccessors(op); |
935 | 0 | } |
936 | | LogicalResult OpTrait::impl::verifyAtLeastNSuccessors(Operation *op, |
937 | 0 | unsigned numSuccessors) { |
938 | 0 | if (op->getNumSuccessors() < numSuccessors) { |
939 | 0 | return op->emitOpError("requires at least ") |
940 | 0 | << numSuccessors << " successors but found " |
941 | 0 | << op->getNumSuccessors(); |
942 | 0 | } |
943 | 0 | return verifyTerminatorSuccessors(op); |
944 | 0 | } |
945 | | |
946 | 0 | LogicalResult OpTrait::impl::verifyResultsAreBoolLike(Operation *op) { |
947 | 0 | for (auto resultType : op->getResultTypes()) { |
948 | 0 | auto elementType = getTensorOrVectorElementType(resultType); |
949 | 0 | bool isBoolType = elementType.isInteger(1); |
950 | 0 | if (!isBoolType) |
951 | 0 | return op->emitOpError() << "requires a bool result type"; |
952 | 0 | } |
953 | 0 |
|
954 | 0 | return success(); |
955 | 0 | } |
956 | | |
957 | 0 | LogicalResult OpTrait::impl::verifyResultsAreFloatLike(Operation *op) { |
958 | 0 | for (auto resultType : op->getResultTypes()) |
959 | 0 | if (!getTensorOrVectorElementType(resultType).isa<FloatType>()) |
960 | 0 | return op->emitOpError() << "requires a floating point type"; |
961 | 0 |
|
962 | 0 | return success(); |
963 | 0 | } |
964 | | |
965 | | LogicalResult |
966 | 0 | OpTrait::impl::verifyResultsAreSignlessIntegerLike(Operation *op) { |
967 | 0 | for (auto resultType : op->getResultTypes()) |
968 | 0 | if (!getTensorOrVectorElementType(resultType).isSignlessIntOrIndex()) |
969 | 0 | return op->emitOpError() << "requires an integer or index type"; |
970 | 0 | return success(); |
971 | 0 | } |
972 | | |
973 | | static LogicalResult verifyValueSizeAttr(Operation *op, StringRef attrName, |
974 | 0 | bool isOperand) { |
975 | 0 | auto sizeAttr = op->getAttrOfType<DenseIntElementsAttr>(attrName); |
976 | 0 | if (!sizeAttr) |
977 | 0 | return op->emitOpError("requires 1D vector attribute '") << attrName << "'"; |
978 | 0 | |
979 | 0 | auto sizeAttrType = sizeAttr.getType().dyn_cast<VectorType>(); |
980 | 0 | if (!sizeAttrType || sizeAttrType.getRank() != 1) |
981 | 0 | return op->emitOpError("requires 1D vector attribute '") << attrName << "'"; |
982 | 0 | |
983 | 0 | if (llvm::any_of(sizeAttr.getIntValues(), [](const APInt &element) { |
984 | 0 | return !element.isNonNegative(); |
985 | 0 | })) |
986 | 0 | return op->emitOpError("'") |
987 | 0 | << attrName << "' attribute cannot have negative elements"; |
988 | 0 | |
989 | 0 | size_t totalCount = std::accumulate( |
990 | 0 | sizeAttr.begin(), sizeAttr.end(), 0, |
991 | 0 | [](unsigned all, APInt one) { return all + one.getZExtValue(); }); |
992 | 0 |
|
993 | 0 | if (isOperand && totalCount != op->getNumOperands()) |
994 | 0 | return op->emitOpError("operand count (") |
995 | 0 | << op->getNumOperands() << ") does not match with the total size (" |
996 | 0 | << totalCount << ") specified in attribute '" << attrName << "'"; |
997 | 0 | else if (!isOperand && totalCount != op->getNumResults()) |
998 | 0 | return op->emitOpError("result count (") |
999 | 0 | << op->getNumResults() << ") does not match with the total size (" |
1000 | 0 | << totalCount << ") specified in attribute '" << attrName << "'"; |
1001 | 0 | return success(); |
1002 | 0 | } |
1003 | | |
1004 | | LogicalResult OpTrait::impl::verifyOperandSizeAttr(Operation *op, |
1005 | 0 | StringRef attrName) { |
1006 | 0 | return verifyValueSizeAttr(op, attrName, /*isOperand=*/true); |
1007 | 0 | } |
1008 | | |
1009 | | LogicalResult OpTrait::impl::verifyResultSizeAttr(Operation *op, |
1010 | 0 | StringRef attrName) { |
1011 | 0 | return verifyValueSizeAttr(op, attrName, /*isOperand=*/false); |
1012 | 0 | } |
1013 | | |
1014 | | //===----------------------------------------------------------------------===// |
1015 | | // BinaryOp implementation |
1016 | | //===----------------------------------------------------------------------===// |
1017 | | |
1018 | | // These functions are out-of-line implementations of the methods in BinaryOp, |
1019 | | // which avoids them being template instantiated/duplicated. |
1020 | | |
1021 | | void impl::buildBinaryOp(OpBuilder &builder, OperationState &result, Value lhs, |
1022 | 0 | Value rhs) { |
1023 | 0 | assert(lhs.getType() == rhs.getType()); |
1024 | 0 | result.addOperands({lhs, rhs}); |
1025 | 0 | result.types.push_back(lhs.getType()); |
1026 | 0 | } |
1027 | | |
1028 | | ParseResult impl::parseOneResultSameOperandTypeOp(OpAsmParser &parser, |
1029 | 0 | OperationState &result) { |
1030 | 0 | SmallVector<OpAsmParser::OperandType, 2> ops; |
1031 | 0 | Type type; |
1032 | 0 | return failure(parser.parseOperandList(ops) || |
1033 | 0 | parser.parseOptionalAttrDict(result.attributes) || |
1034 | 0 | parser.parseColonType(type) || |
1035 | 0 | parser.resolveOperands(ops, type, result.operands) || |
1036 | 0 | parser.addTypeToList(type, result.types)); |
1037 | 0 | } |
1038 | | |
1039 | 0 | void impl::printOneResultOp(Operation *op, OpAsmPrinter &p) { |
1040 | 0 | assert(op->getNumResults() == 1 && "op should have one result"); |
1041 | 0 |
|
1042 | 0 | // If not all the operand and result types are the same, just use the |
1043 | 0 | // generic assembly form to avoid omitting information in printing. |
1044 | 0 | auto resultType = op->getResult(0).getType(); |
1045 | 0 | if (llvm::any_of(op->getOperandTypes(), |
1046 | 0 | [&](Type type) { return type != resultType; })) { |
1047 | 0 | p.printGenericOp(op); |
1048 | 0 | return; |
1049 | 0 | } |
1050 | 0 | |
1051 | 0 | p << op->getName() << ' '; |
1052 | 0 | p.printOperands(op->getOperands()); |
1053 | 0 | p.printOptionalAttrDict(op->getAttrs()); |
1054 | 0 | // Now we can output only one type for all operands and the result. |
1055 | 0 | p << " : " << resultType; |
1056 | 0 | } |
1057 | | |
1058 | | //===----------------------------------------------------------------------===// |
1059 | | // CastOp implementation |
1060 | | //===----------------------------------------------------------------------===// |
1061 | | |
1062 | | void impl::buildCastOp(OpBuilder &builder, OperationState &result, Value source, |
1063 | 0 | Type destType) { |
1064 | 0 | result.addOperands(source); |
1065 | 0 | result.addTypes(destType); |
1066 | 0 | } |
1067 | | |
1068 | 0 | ParseResult impl::parseCastOp(OpAsmParser &parser, OperationState &result) { |
1069 | 0 | OpAsmParser::OperandType srcInfo; |
1070 | 0 | Type srcType, dstType; |
1071 | 0 | return failure(parser.parseOperand(srcInfo) || |
1072 | 0 | parser.parseOptionalAttrDict(result.attributes) || |
1073 | 0 | parser.parseColonType(srcType) || |
1074 | 0 | parser.resolveOperand(srcInfo, srcType, result.operands) || |
1075 | 0 | parser.parseKeywordType("to", dstType) || |
1076 | 0 | parser.addTypeToList(dstType, result.types)); |
1077 | 0 | } |
1078 | | |
1079 | 0 | void impl::printCastOp(Operation *op, OpAsmPrinter &p) { |
1080 | 0 | p << op->getName() << ' ' << op->getOperand(0); |
1081 | 0 | p.printOptionalAttrDict(op->getAttrs()); |
1082 | 0 | p << " : " << op->getOperand(0).getType() << " to " |
1083 | 0 | << op->getResult(0).getType(); |
1084 | 0 | } |
1085 | | |
1086 | 0 | Value impl::foldCastOp(Operation *op) { |
1087 | 0 | // Identity cast |
1088 | 0 | if (op->getOperand(0).getType() == op->getResult(0).getType()) |
1089 | 0 | return op->getOperand(0); |
1090 | 0 | return nullptr; |
1091 | 0 | } |
1092 | | |
1093 | | //===----------------------------------------------------------------------===// |
1094 | | // Misc. utils |
1095 | | //===----------------------------------------------------------------------===// |
1096 | | |
1097 | | /// Insert an operation, generated by `buildTerminatorOp`, at the end of the |
1098 | | /// region's only block if it does not have a terminator already. If the region |
1099 | | /// is empty, insert a new block first. `buildTerminatorOp` should return the |
1100 | | /// terminator operation to insert. |
1101 | | void impl::ensureRegionTerminator( |
1102 | | Region ®ion, OpBuilder &builder, Location loc, |
1103 | 0 | function_ref<Operation *(OpBuilder &, Location)> buildTerminatorOp) { |
1104 | 0 | OpBuilder::InsertionGuard guard(builder); |
1105 | 0 | if (region.empty()) |
1106 | 0 | builder.createBlock(®ion); |
1107 | 0 |
|
1108 | 0 | Block &block = region.back(); |
1109 | 0 | if (!block.empty() && block.back().isKnownTerminator()) |
1110 | 0 | return; |
1111 | 0 | |
1112 | 0 | builder.setInsertionPointToEnd(&block); |
1113 | 0 | builder.insert(buildTerminatorOp(builder, loc)); |
1114 | 0 | } |
1115 | | |
1116 | | /// Create a simple OpBuilder and forward to the OpBuilder version of this |
1117 | | /// function. |
1118 | | void impl::ensureRegionTerminator( |
1119 | | Region ®ion, Builder &builder, Location loc, |
1120 | 0 | function_ref<Operation *(OpBuilder &, Location)> buildTerminatorOp) { |
1121 | 0 | OpBuilder opBuilder(builder.getContext()); |
1122 | 0 | ensureRegionTerminator(region, opBuilder, loc, buildTerminatorOp); |
1123 | 0 | } |
1124 | | |
1125 | | //===----------------------------------------------------------------------===// |
1126 | | // UseIterator |
1127 | | //===----------------------------------------------------------------------===// |
1128 | | |
1129 | | Operation::UseIterator::UseIterator(Operation *op, bool end) |
1130 | 0 | : op(op), res(end ? op->result_end() : op->result_begin()) { |
1131 | 0 | // Only initialize current use if there are results/can be uses. |
1132 | 0 | if (op->getNumResults()) |
1133 | 0 | skipOverResultsWithNoUsers(); |
1134 | 0 | } |
1135 | | |
1136 | 0 | Operation::UseIterator &Operation::UseIterator::operator++() { |
1137 | 0 | // We increment over uses, if we reach the last use then move to next |
1138 | 0 | // result. |
1139 | 0 | if (use != (*res).use_end()) |
1140 | 0 | ++use; |
1141 | 0 | if (use == (*res).use_end()) { |
1142 | 0 | ++res; |
1143 | 0 | skipOverResultsWithNoUsers(); |
1144 | 0 | } |
1145 | 0 | return *this; |
1146 | 0 | } |
1147 | | |
1148 | 0 | void Operation::UseIterator::skipOverResultsWithNoUsers() { |
1149 | 0 | while (res != op->result_end() && (*res).use_empty()) |
1150 | 0 | ++res; |
1151 | 0 |
|
1152 | 0 | // If we are at the last result, then set use to first use of |
1153 | 0 | // first result (sentinel value used for end). |
1154 | 0 | if (res == op->result_end()) |
1155 | 0 | use = {}; |
1156 | 0 | else |
1157 | 0 | use = (*res).use_begin(); |
1158 | 0 | } |