/home/arjun/llvm-project/mlir/lib/IR/PatternMatch.cpp
| Line | Count | Source (jump to first uncovered line) | 
| 1 |  | //===- PatternMatch.cpp - Base classes for pattern match ------------------===// | 
| 2 |  | // | 
| 3 |  | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | 
| 4 |  | // See https://llvm.org/LICENSE.txt for license information. | 
| 5 |  | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | 
| 6 |  | // | 
| 7 |  | //===----------------------------------------------------------------------===// | 
| 8 |  |  | 
| 9 |  | #include "mlir/IR/PatternMatch.h" | 
| 10 |  | #include "mlir/IR/BlockAndValueMapping.h" | 
| 11 |  | #include "mlir/IR/Operation.h" | 
| 12 |  | #include "mlir/IR/Value.h" | 
| 13 |  |  | 
| 14 |  | using namespace mlir; | 
| 15 |  |  | 
| 16 | 0 | PatternBenefit::PatternBenefit(unsigned benefit) : representation(benefit) { | 
| 17 | 0 |   assert(representation == benefit && benefit != ImpossibleToMatchSentinel && | 
| 18 | 0 |          "This pattern match benefit is too large to represent"); | 
| 19 | 0 | } | 
| 20 |  |  | 
| 21 | 0 | unsigned short PatternBenefit::getBenefit() const { | 
| 22 | 0 |   assert(representation != ImpossibleToMatchSentinel && | 
| 23 | 0 |          "Pattern doesn't match"); | 
| 24 | 0 |   return representation; | 
| 25 | 0 | } | 
| 26 |  |  | 
| 27 |  | //===----------------------------------------------------------------------===// | 
| 28 |  | // Pattern implementation | 
| 29 |  | //===----------------------------------------------------------------------===// | 
| 30 |  |  | 
| 31 |  | Pattern::Pattern(StringRef rootName, PatternBenefit benefit, | 
| 32 |  |                  MLIRContext *context) | 
| 33 | 0 |     : rootKind(OperationName(rootName, context)), benefit(benefit) {} | 
| 34 |  |  | 
| 35 |  | // Out-of-line vtable anchor. | 
| 36 | 0 | void Pattern::anchor() {} | 
| 37 |  |  | 
| 38 |  | //===----------------------------------------------------------------------===// | 
| 39 |  | // RewritePattern and PatternRewriter implementation | 
| 40 |  | //===----------------------------------------------------------------------===// | 
| 41 |  |  | 
| 42 | 0 | void RewritePattern::rewrite(Operation *op, PatternRewriter &rewriter) const { | 
| 43 | 0 |   llvm_unreachable("need to implement either matchAndRewrite or one of the " | 
| 44 | 0 |                    "rewrite functions!"); | 
| 45 | 0 | } | 
| 46 |  |  | 
| 47 | 0 | LogicalResult RewritePattern::match(Operation *op) const { | 
| 48 | 0 |   llvm_unreachable("need to implement either match or matchAndRewrite!"); | 
| 49 | 0 | } | 
| 50 |  |  | 
| 51 |  | /// Patterns must specify the root operation name they match against, and can | 
| 52 |  | /// also specify the benefit of the pattern matching. They can also specify the | 
| 53 |  | /// names of operations that may be generated during a successful rewrite. | 
| 54 |  | RewritePattern::RewritePattern(StringRef rootName, | 
| 55 |  |                                ArrayRef<StringRef> generatedNames, | 
| 56 |  |                                PatternBenefit benefit, MLIRContext *context) | 
| 57 | 0 |     : Pattern(rootName, benefit, context) { | 
| 58 | 0 |   generatedOps.reserve(generatedNames.size()); | 
| 59 | 0 |   std::transform(generatedNames.begin(), generatedNames.end(), | 
| 60 | 0 |                  std::back_inserter(generatedOps), [context](StringRef name) { | 
| 61 | 0 |                    return OperationName(name, context); | 
| 62 | 0 |                  }); | 
| 63 | 0 | } | 
| 64 |  |  | 
| 65 | 0 | PatternRewriter::~PatternRewriter() { | 
| 66 | 0 |   // Out of line to provide a vtable anchor for the class. | 
| 67 | 0 | } | 
| 68 |  |  | 
| 69 |  | /// This method performs the final replacement for a pattern, where the | 
| 70 |  | /// results of the operation are updated to use the specified list of SSA | 
| 71 |  | /// values. | 
| 72 | 0 | void PatternRewriter::replaceOp(Operation *op, ValueRange newValues) { | 
| 73 | 0 |   // Notify the rewriter subclass that we're about to replace this root. | 
| 74 | 0 |   notifyRootReplaced(op); | 
| 75 | 0 | 
 | 
| 76 | 0 |   assert(op->getNumResults() == newValues.size() && | 
| 77 | 0 |          "incorrect # of replacement values"); | 
| 78 | 0 |   op->replaceAllUsesWith(newValues); | 
| 79 | 0 | 
 | 
| 80 | 0 |   notifyOperationRemoved(op); | 
| 81 | 0 |   op->erase(); | 
| 82 | 0 | } | 
| 83 |  |  | 
| 84 |  | /// This method erases an operation that is known to have no uses. The uses of | 
| 85 |  | /// the given operation *must* be known to be dead. | 
| 86 | 0 | void PatternRewriter::eraseOp(Operation *op) { | 
| 87 | 0 |   assert(op->use_empty() && "expected 'op' to have no uses"); | 
| 88 | 0 |   notifyOperationRemoved(op); | 
| 89 | 0 |   op->erase(); | 
| 90 | 0 | } | 
| 91 |  |  | 
| 92 | 0 | void PatternRewriter::eraseBlock(Block *block) { | 
| 93 | 0 |   for (auto &op : llvm::make_early_inc_range(llvm::reverse(*block))) { | 
| 94 | 0 |     assert(op.use_empty() && "expected 'op' to have no uses"); | 
| 95 | 0 |     eraseOp(&op); | 
| 96 | 0 |   } | 
| 97 | 0 |   block->erase(); | 
| 98 | 0 | } | 
| 99 |  |  | 
| 100 |  | /// Merge the operations of block 'source' into the end of block 'dest'. | 
| 101 |  | /// 'source's predecessors must be empty or only contain 'dest`. | 
| 102 |  | /// 'argValues' is used to replace the block arguments of 'source' after | 
| 103 |  | /// merging. | 
| 104 |  | void PatternRewriter::mergeBlocks(Block *source, Block *dest, | 
| 105 |  |                                   ValueRange argValues) { | 
| 106 |  |   assert(llvm::all_of(source->getPredecessors(), | 
| 107 |  |                       [dest](Block *succ) { return succ == dest; }) && | 
| 108 |  |          "expected 'source' to have no predecessors or only 'dest'"); | 
| 109 |  |   assert(argValues.size() == source->getNumArguments() && | 
| 110 |  |          "incorrect # of argument replacement values"); | 
| 111 |  |  | 
| 112 |  |   // Replace all of the successor arguments with the provided values. | 
| 113 |  |   for (auto it : llvm::zip(source->getArguments(), argValues)) | 
| 114 |  |     std::get<0>(it).replaceAllUsesWith(std::get<1>(it)); | 
| 115 |  |  | 
| 116 |  |   // Splice the operations of the 'source' block into the 'dest' block and erase | 
| 117 |  |   // it. | 
| 118 |  |   dest->getOperations().splice(dest->end(), source->getOperations()); | 
| 119 |  |   source->dropAllUses(); | 
| 120 |  |   source->erase(); | 
| 121 |  | } | 
| 122 |  |  | 
| 123 |  | /// Split the operations starting at "before" (inclusive) out of the given | 
| 124 |  | /// block into a new block, and return it. | 
| 125 | 0 | Block *PatternRewriter::splitBlock(Block *block, Block::iterator before) { | 
| 126 | 0 |   return block->splitBlock(before); | 
| 127 | 0 | } | 
| 128 |  |  | 
| 129 |  | /// 'op' and 'newOp' are known to have the same number of results, replace the | 
| 130 |  | /// uses of op with uses of newOp | 
| 131 |  | void PatternRewriter::replaceOpWithResultsOfAnotherOp(Operation *op, | 
| 132 | 0 |                                                       Operation *newOp) { | 
| 133 | 0 |   assert(op->getNumResults() == newOp->getNumResults() && | 
| 134 | 0 |          "replacement op doesn't match results of original op"); | 
| 135 | 0 |   if (op->getNumResults() == 1) | 
| 136 | 0 |     return replaceOp(op, newOp->getResult(0)); | 
| 137 | 0 |   return replaceOp(op, newOp->getResults()); | 
| 138 | 0 | } | 
| 139 |  |  | 
| 140 |  | /// Move the blocks that belong to "region" before the given position in | 
| 141 |  | /// another region.  The two regions must be different.  The caller is in | 
| 142 |  | /// charge to update create the operation transferring the control flow to the | 
| 143 |  | /// region and pass it the correct block arguments. | 
| 144 |  | void PatternRewriter::inlineRegionBefore(Region ®ion, Region &parent, | 
| 145 | 0 |                                          Region::iterator before) { | 
| 146 | 0 |   parent.getBlocks().splice(before, region.getBlocks()); | 
| 147 | 0 | } | 
| 148 | 0 | void PatternRewriter::inlineRegionBefore(Region ®ion, Block *before) { | 
| 149 | 0 |   inlineRegionBefore(region, *before->getParent(), before->getIterator()); | 
| 150 | 0 | } | 
| 151 |  |  | 
| 152 |  | /// Clone the blocks that belong to "region" before the given position in | 
| 153 |  | /// another region "parent". The two regions must be different. The caller is | 
| 154 |  | /// responsible for creating or updating the operation transferring flow of | 
| 155 |  | /// control to the region and passing it the correct block arguments. | 
| 156 |  | void PatternRewriter::cloneRegionBefore(Region ®ion, Region &parent, | 
| 157 |  |                                         Region::iterator before, | 
| 158 | 0 |                                         BlockAndValueMapping &mapping) { | 
| 159 | 0 |   region.cloneInto(&parent, before, mapping); | 
| 160 | 0 | } | 
| 161 |  | void PatternRewriter::cloneRegionBefore(Region ®ion, Region &parent, | 
| 162 | 0 |                                         Region::iterator before) { | 
| 163 | 0 |   BlockAndValueMapping mapping; | 
| 164 | 0 |   cloneRegionBefore(region, parent, before, mapping); | 
| 165 | 0 | } | 
| 166 | 0 | void PatternRewriter::cloneRegionBefore(Region ®ion, Block *before) { | 
| 167 | 0 |   cloneRegionBefore(region, *before->getParent(), before->getIterator()); | 
| 168 | 0 | } | 
| 169 |  |  | 
| 170 |  | //===----------------------------------------------------------------------===// | 
| 171 |  | // PatternMatcher implementation | 
| 172 |  | //===----------------------------------------------------------------------===// | 
| 173 |  |  | 
| 174 |  | RewritePatternMatcher::RewritePatternMatcher( | 
| 175 | 0 |     const OwningRewritePatternList &patterns) { | 
| 176 | 0 |   for (auto &pattern : patterns) | 
| 177 | 0 |     this->patterns.push_back(pattern.get()); | 
| 178 | 0 | 
 | 
| 179 | 0 |   // Sort the patterns by benefit to simplify the matching logic. | 
| 180 | 0 |   std::stable_sort(this->patterns.begin(), this->patterns.end(), | 
| 181 | 0 |                    [](RewritePattern *l, RewritePattern *r) { | 
| 182 | 0 |                      return r->getBenefit() < l->getBenefit(); | 
| 183 | 0 |                    }); | 
| 184 | 0 | } | 
| 185 |  |  | 
| 186 |  | /// Try to match the given operation to a pattern and rewrite it. | 
| 187 |  | bool RewritePatternMatcher::matchAndRewrite(Operation *op, | 
| 188 | 0 |                                             PatternRewriter &rewriter) { | 
| 189 | 0 |   for (auto *pattern : patterns) { | 
| 190 | 0 |     // Ignore patterns that are for the wrong root or are impossible to match. | 
| 191 | 0 |     if (pattern->getRootKind() != op->getName() || | 
| 192 | 0 |         pattern->getBenefit().isImpossibleToMatch()) | 
| 193 | 0 |       continue; | 
| 194 | 0 |  | 
| 195 | 0 |     // Try to match and rewrite this pattern. The patterns are sorted by | 
| 196 | 0 |     // benefit, so if we match we can immediately rewrite and return. | 
| 197 | 0 |     if (succeeded(pattern->matchAndRewrite(op, rewriter))) | 
| 198 | 0 |       return true; | 
| 199 | 0 |   } | 
| 200 | 0 |   return false; | 
| 201 | 0 | } |