/home/arjun/llvm-project/mlir/lib/IR/PatternMatch.cpp
Line | Count | Source (jump to first uncovered line) |
1 | | //===- PatternMatch.cpp - Base classes for pattern match ------------------===// |
2 | | // |
3 | | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | | // See https://llvm.org/LICENSE.txt for license information. |
5 | | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | | // |
7 | | //===----------------------------------------------------------------------===// |
8 | | |
9 | | #include "mlir/IR/PatternMatch.h" |
10 | | #include "mlir/IR/BlockAndValueMapping.h" |
11 | | #include "mlir/IR/Operation.h" |
12 | | #include "mlir/IR/Value.h" |
13 | | |
14 | | using namespace mlir; |
15 | | |
16 | 0 | PatternBenefit::PatternBenefit(unsigned benefit) : representation(benefit) { |
17 | 0 | assert(representation == benefit && benefit != ImpossibleToMatchSentinel && |
18 | 0 | "This pattern match benefit is too large to represent"); |
19 | 0 | } |
20 | | |
21 | 0 | unsigned short PatternBenefit::getBenefit() const { |
22 | 0 | assert(representation != ImpossibleToMatchSentinel && |
23 | 0 | "Pattern doesn't match"); |
24 | 0 | return representation; |
25 | 0 | } |
26 | | |
27 | | //===----------------------------------------------------------------------===// |
28 | | // Pattern implementation |
29 | | //===----------------------------------------------------------------------===// |
30 | | |
31 | | Pattern::Pattern(StringRef rootName, PatternBenefit benefit, |
32 | | MLIRContext *context) |
33 | 0 | : rootKind(OperationName(rootName, context)), benefit(benefit) {} |
34 | | |
35 | | // Out-of-line vtable anchor. |
36 | 0 | void Pattern::anchor() {} |
37 | | |
38 | | //===----------------------------------------------------------------------===// |
39 | | // RewritePattern and PatternRewriter implementation |
40 | | //===----------------------------------------------------------------------===// |
41 | | |
42 | 0 | void RewritePattern::rewrite(Operation *op, PatternRewriter &rewriter) const { |
43 | 0 | llvm_unreachable("need to implement either matchAndRewrite or one of the " |
44 | 0 | "rewrite functions!"); |
45 | 0 | } |
46 | | |
47 | 0 | LogicalResult RewritePattern::match(Operation *op) const { |
48 | 0 | llvm_unreachable("need to implement either match or matchAndRewrite!"); |
49 | 0 | } |
50 | | |
51 | | /// Patterns must specify the root operation name they match against, and can |
52 | | /// also specify the benefit of the pattern matching. They can also specify the |
53 | | /// names of operations that may be generated during a successful rewrite. |
54 | | RewritePattern::RewritePattern(StringRef rootName, |
55 | | ArrayRef<StringRef> generatedNames, |
56 | | PatternBenefit benefit, MLIRContext *context) |
57 | 0 | : Pattern(rootName, benefit, context) { |
58 | 0 | generatedOps.reserve(generatedNames.size()); |
59 | 0 | std::transform(generatedNames.begin(), generatedNames.end(), |
60 | 0 | std::back_inserter(generatedOps), [context](StringRef name) { |
61 | 0 | return OperationName(name, context); |
62 | 0 | }); |
63 | 0 | } |
64 | | |
65 | 0 | PatternRewriter::~PatternRewriter() { |
66 | 0 | // Out of line to provide a vtable anchor for the class. |
67 | 0 | } |
68 | | |
69 | | /// This method performs the final replacement for a pattern, where the |
70 | | /// results of the operation are updated to use the specified list of SSA |
71 | | /// values. |
72 | 0 | void PatternRewriter::replaceOp(Operation *op, ValueRange newValues) { |
73 | 0 | // Notify the rewriter subclass that we're about to replace this root. |
74 | 0 | notifyRootReplaced(op); |
75 | 0 |
|
76 | 0 | assert(op->getNumResults() == newValues.size() && |
77 | 0 | "incorrect # of replacement values"); |
78 | 0 | op->replaceAllUsesWith(newValues); |
79 | 0 |
|
80 | 0 | notifyOperationRemoved(op); |
81 | 0 | op->erase(); |
82 | 0 | } |
83 | | |
84 | | /// This method erases an operation that is known to have no uses. The uses of |
85 | | /// the given operation *must* be known to be dead. |
86 | 0 | void PatternRewriter::eraseOp(Operation *op) { |
87 | 0 | assert(op->use_empty() && "expected 'op' to have no uses"); |
88 | 0 | notifyOperationRemoved(op); |
89 | 0 | op->erase(); |
90 | 0 | } |
91 | | |
92 | 0 | void PatternRewriter::eraseBlock(Block *block) { |
93 | 0 | for (auto &op : llvm::make_early_inc_range(llvm::reverse(*block))) { |
94 | 0 | assert(op.use_empty() && "expected 'op' to have no uses"); |
95 | 0 | eraseOp(&op); |
96 | 0 | } |
97 | 0 | block->erase(); |
98 | 0 | } |
99 | | |
100 | | /// Merge the operations of block 'source' into the end of block 'dest'. |
101 | | /// 'source's predecessors must be empty or only contain 'dest`. |
102 | | /// 'argValues' is used to replace the block arguments of 'source' after |
103 | | /// merging. |
104 | | void PatternRewriter::mergeBlocks(Block *source, Block *dest, |
105 | | ValueRange argValues) { |
106 | | assert(llvm::all_of(source->getPredecessors(), |
107 | | [dest](Block *succ) { return succ == dest; }) && |
108 | | "expected 'source' to have no predecessors or only 'dest'"); |
109 | | assert(argValues.size() == source->getNumArguments() && |
110 | | "incorrect # of argument replacement values"); |
111 | | |
112 | | // Replace all of the successor arguments with the provided values. |
113 | | for (auto it : llvm::zip(source->getArguments(), argValues)) |
114 | | std::get<0>(it).replaceAllUsesWith(std::get<1>(it)); |
115 | | |
116 | | // Splice the operations of the 'source' block into the 'dest' block and erase |
117 | | // it. |
118 | | dest->getOperations().splice(dest->end(), source->getOperations()); |
119 | | source->dropAllUses(); |
120 | | source->erase(); |
121 | | } |
122 | | |
123 | | /// Split the operations starting at "before" (inclusive) out of the given |
124 | | /// block into a new block, and return it. |
125 | 0 | Block *PatternRewriter::splitBlock(Block *block, Block::iterator before) { |
126 | 0 | return block->splitBlock(before); |
127 | 0 | } |
128 | | |
129 | | /// 'op' and 'newOp' are known to have the same number of results, replace the |
130 | | /// uses of op with uses of newOp |
131 | | void PatternRewriter::replaceOpWithResultsOfAnotherOp(Operation *op, |
132 | 0 | Operation *newOp) { |
133 | 0 | assert(op->getNumResults() == newOp->getNumResults() && |
134 | 0 | "replacement op doesn't match results of original op"); |
135 | 0 | if (op->getNumResults() == 1) |
136 | 0 | return replaceOp(op, newOp->getResult(0)); |
137 | 0 | return replaceOp(op, newOp->getResults()); |
138 | 0 | } |
139 | | |
140 | | /// Move the blocks that belong to "region" before the given position in |
141 | | /// another region. The two regions must be different. The caller is in |
142 | | /// charge to update create the operation transferring the control flow to the |
143 | | /// region and pass it the correct block arguments. |
144 | | void PatternRewriter::inlineRegionBefore(Region ®ion, Region &parent, |
145 | 0 | Region::iterator before) { |
146 | 0 | parent.getBlocks().splice(before, region.getBlocks()); |
147 | 0 | } |
148 | 0 | void PatternRewriter::inlineRegionBefore(Region ®ion, Block *before) { |
149 | 0 | inlineRegionBefore(region, *before->getParent(), before->getIterator()); |
150 | 0 | } |
151 | | |
152 | | /// Clone the blocks that belong to "region" before the given position in |
153 | | /// another region "parent". The two regions must be different. The caller is |
154 | | /// responsible for creating or updating the operation transferring flow of |
155 | | /// control to the region and passing it the correct block arguments. |
156 | | void PatternRewriter::cloneRegionBefore(Region ®ion, Region &parent, |
157 | | Region::iterator before, |
158 | 0 | BlockAndValueMapping &mapping) { |
159 | 0 | region.cloneInto(&parent, before, mapping); |
160 | 0 | } |
161 | | void PatternRewriter::cloneRegionBefore(Region ®ion, Region &parent, |
162 | 0 | Region::iterator before) { |
163 | 0 | BlockAndValueMapping mapping; |
164 | 0 | cloneRegionBefore(region, parent, before, mapping); |
165 | 0 | } |
166 | 0 | void PatternRewriter::cloneRegionBefore(Region ®ion, Block *before) { |
167 | 0 | cloneRegionBefore(region, *before->getParent(), before->getIterator()); |
168 | 0 | } |
169 | | |
170 | | //===----------------------------------------------------------------------===// |
171 | | // PatternMatcher implementation |
172 | | //===----------------------------------------------------------------------===// |
173 | | |
174 | | RewritePatternMatcher::RewritePatternMatcher( |
175 | 0 | const OwningRewritePatternList &patterns) { |
176 | 0 | for (auto &pattern : patterns) |
177 | 0 | this->patterns.push_back(pattern.get()); |
178 | 0 |
|
179 | 0 | // Sort the patterns by benefit to simplify the matching logic. |
180 | 0 | std::stable_sort(this->patterns.begin(), this->patterns.end(), |
181 | 0 | [](RewritePattern *l, RewritePattern *r) { |
182 | 0 | return r->getBenefit() < l->getBenefit(); |
183 | 0 | }); |
184 | 0 | } |
185 | | |
186 | | /// Try to match the given operation to a pattern and rewrite it. |
187 | | bool RewritePatternMatcher::matchAndRewrite(Operation *op, |
188 | 0 | PatternRewriter &rewriter) { |
189 | 0 | for (auto *pattern : patterns) { |
190 | 0 | // Ignore patterns that are for the wrong root or are impossible to match. |
191 | 0 | if (pattern->getRootKind() != op->getName() || |
192 | 0 | pattern->getBenefit().isImpossibleToMatch()) |
193 | 0 | continue; |
194 | 0 | |
195 | 0 | // Try to match and rewrite this pattern. The patterns are sorted by |
196 | 0 | // benefit, so if we match we can immediately rewrite and return. |
197 | 0 | if (succeeded(pattern->matchAndRewrite(op, rewriter))) |
198 | 0 | return true; |
199 | 0 | } |
200 | 0 | return false; |
201 | 0 | } |