Coverage Report

Created: 2020-06-26 05:44

/home/arjun/llvm-project/mlir/lib/IR/PatternMatch.cpp
Line
Count
Source (jump to first uncovered line)
1
//===- PatternMatch.cpp - Base classes for pattern match ------------------===//
2
//
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
// See https://llvm.org/LICENSE.txt for license information.
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
//
7
//===----------------------------------------------------------------------===//
8
9
#include "mlir/IR/PatternMatch.h"
10
#include "mlir/IR/BlockAndValueMapping.h"
11
#include "mlir/IR/Operation.h"
12
#include "mlir/IR/Value.h"
13
14
using namespace mlir;
15
16
0
/// Build a benefit from a raw numeric value.
///
/// The value must survive being stored in `representation` unchanged. The
/// storage is presumably narrower than `unsigned`, since getBenefit() returns
/// `unsigned short`. The value also must not equal the sentinel reserved for
/// "impossible to match".
PatternBenefit::PatternBenefit(unsigned benefit) : representation(benefit) {
  // Reading the stored value back and comparing it with the input detects
  // truncation on the way in.
  assert(benefit != ImpossibleToMatchSentinel &&
         representation == benefit &&
         "This pattern match benefit is too large to represent");
}
20
21
0
unsigned short PatternBenefit::getBenefit() const {
22
0
  assert(representation != ImpossibleToMatchSentinel &&
23
0
         "Pattern doesn't match");
24
0
  return representation;
25
0
}
26
27
//===----------------------------------------------------------------------===//
28
// Pattern implementation
29
//===----------------------------------------------------------------------===//
30
31
/// Create a pattern whose root is the operation named `rootName`. The name is
/// resolved to an OperationName within `context`, and `benefit` is recorded
/// as-is.
Pattern::Pattern(StringRef rootName, PatternBenefit benefit,
                 MLIRContext *context)
    : rootKind(OperationName(rootName, context)),
      benefit(benefit) {}
34
35
// Out-of-line vtable anchor.
36
0
void Pattern::anchor() {}
37
38
//===----------------------------------------------------------------------===//
39
// RewritePattern and PatternRewriter implementation
40
//===----------------------------------------------------------------------===//
41
42
0
void RewritePattern::rewrite(Operation *op, PatternRewriter &rewriter) const {
43
0
  llvm_unreachable("need to implement either matchAndRewrite or one of the "
44
0
                   "rewrite functions!");
45
0
}
46
47
0
/// Fallback `match` hook.
///
/// Reaching this default means the concrete pattern overrode neither `match`
/// nor `matchAndRewrite`, so this aborts via llvm_unreachable.
LogicalResult RewritePattern::match(Operation * /*op*/) const {
  llvm_unreachable("need to implement either match or matchAndRewrite!");
}
50
51
/// Patterns must specify the root operation name they match against, and can
52
/// also specify the benefit of the pattern matching. They can also specify the
53
/// names of operations that may be generated during a successful rewrite.
54
RewritePattern::RewritePattern(StringRef rootName,
55
                               ArrayRef<StringRef> generatedNames,
56
                               PatternBenefit benefit, MLIRContext *context)
57
0
    : Pattern(rootName, benefit, context) {
58
0
  generatedOps.reserve(generatedNames.size());
59
0
  std::transform(generatedNames.begin(), generatedNames.end(),
60
0
                 std::back_inserter(generatedOps), [context](StringRef name) {
61
0
                   return OperationName(name, context);
62
0
                 });
63
0
}
64
65
0
// Intentionally empty. The destructor is defined here, rather than in the
// header, so that it anchors PatternRewriter's vtable to this translation
// unit.
PatternRewriter::~PatternRewriter() {}
68
69
/// This method performs the final replacement for a pattern, where the
70
/// results of the operation are updated to use the specified list of SSA
71
/// values.
72
0
void PatternRewriter::replaceOp(Operation *op, ValueRange newValues) {
73
0
  // Notify the rewriter subclass that we're about to replace this root.
74
0
  notifyRootReplaced(op);
75
0
76
0
  assert(op->getNumResults() == newValues.size() &&
77
0
         "incorrect # of replacement values");
78
0
  op->replaceAllUsesWith(newValues);
79
0
80
0
  notifyOperationRemoved(op);
81
0
  op->erase();
82
0
}
83
84
/// This method erases an operation that is known to have no uses. The uses of
85
/// the given operation *must* be known to be dead.
86
0
void PatternRewriter::eraseOp(Operation *op) {
87
0
  assert(op->use_empty() && "expected 'op' to have no uses");
88
0
  notifyOperationRemoved(op);
89
0
  op->erase();
90
0
}
91
92
0
/// Delete `block` along with every operation it contains.
///
/// Operations are removed back to front, each through eraseOp(). An op used
/// only by later ops in the same block therefore has no remaining uses by the
/// time its turn comes. Every op must be unused when it is reached.
void PatternRewriter::eraseBlock(Block *block) {
  // early_inc_range advances past the current op before it is erased, so
  // iteration stays valid.
  for (Operation &nested : llvm::make_early_inc_range(llvm::reverse(*block))) {
    // Checked here as well as in eraseOp: a subclass may override eraseOp
    // without keeping that check.
    assert(nested.use_empty() && "expected 'op' to have no uses");
    eraseOp(&nested);
  }
  block->erase();
}
99
100
/// Merge the operations of block 'source' into the end of block 'dest'.
101
/// 'source's predecessors must be empty or only contain 'dest`.
102
/// 'argValues' is used to replace the block arguments of 'source' after
103
/// merging.
104
void PatternRewriter::mergeBlocks(Block *source, Block *dest,
105
                                  ValueRange argValues) {
106
  assert(llvm::all_of(source->getPredecessors(),
107
                      [dest](Block *succ) { return succ == dest; }) &&
108
         "expected 'source' to have no predecessors or only 'dest'");
109
  assert(argValues.size() == source->getNumArguments() &&
110
         "incorrect # of argument replacement values");
111
112
  // Replace all of the successor arguments with the provided values.
113
  for (auto it : llvm::zip(source->getArguments(), argValues))
114
    std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
115
116
  // Splice the operations of the 'source' block into the 'dest' block and erase
117
  // it.
118
  dest->getOperations().splice(dest->end(), source->getOperations());
119
  source->dropAllUses();
120
  source->erase();
121
}
122
123
/// Split the operations starting at "before" (inclusive) out of the given
124
/// block into a new block, and return it.
125
0
Block *PatternRewriter::splitBlock(Block *block, Block::iterator before) {
126
0
  return block->splitBlock(before);
127
0
}
128
129
/// 'op' and 'newOp' are known to have the same number of results, replace the
130
/// uses of op with uses of newOp
131
void PatternRewriter::replaceOpWithResultsOfAnotherOp(Operation *op,
132
0
                                                      Operation *newOp) {
133
0
  assert(op->getNumResults() == newOp->getNumResults() &&
134
0
         "replacement op doesn't match results of original op");
135
0
  if (op->getNumResults() == 1)
136
0
    return replaceOp(op, newOp->getResult(0));
137
0
  return replaceOp(op, newOp->getResults());
138
0
}
139
140
/// Move the blocks that belong to "region" before the given position in
141
/// another region.  The two regions must be different.  The caller is in
142
/// charge to update create the operation transferring the control flow to the
143
/// region and pass it the correct block arguments.
144
void PatternRewriter::inlineRegionBefore(Region &region, Region &parent,
145
0
                                         Region::iterator before) {
146
0
  parent.getBlocks().splice(before, region.getBlocks());
147
0
}
148
0
/// Convenience overload: move the blocks of `region` into the region that
/// owns `before`, placing them immediately ahead of `before`.
void PatternRewriter::inlineRegionBefore(Region &region, Block *before) {
  Region &owner = *before->getParent();
  inlineRegionBefore(region, owner, before->getIterator());
}
151
152
/// Clone the blocks that belong to "region" before the given position in
153
/// another region "parent". The two regions must be different. The caller is
154
/// responsible for creating or updating the operation transferring flow of
155
/// control to the region and passing it the correct block arguments.
156
void PatternRewriter::cloneRegionBefore(Region &region, Region &parent,
157
                                        Region::iterator before,
158
0
                                        BlockAndValueMapping &mapping) {
159
0
  region.cloneInto(&parent, before, mapping);
160
0
}
161
void PatternRewriter::cloneRegionBefore(Region &region, Region &parent,
162
0
                                        Region::iterator before) {
163
0
  BlockAndValueMapping mapping;
164
0
  cloneRegionBefore(region, parent, before, mapping);
165
0
}
166
0
/// Convenience overload: clone the blocks of `region` into the region that
/// owns `before`, placing the copies immediately ahead of `before`.
void PatternRewriter::cloneRegionBefore(Region &region, Block *before) {
  Region &owner = *before->getParent();
  cloneRegionBefore(region, owner, before->getIterator());
}
169
170
//===----------------------------------------------------------------------===//
171
// RewritePatternMatcher implementation
172
//===----------------------------------------------------------------------===//
173
174
/// Record a non-owning pointer to every pattern in `patterns`, ordered from
/// highest to lowest benefit.
///
/// The sort is stable, so patterns with equal benefit keep the relative order
/// they had in `patterns`.
RewritePatternMatcher::RewritePatternMatcher(
    const OwningRewritePatternList &patterns) {
  for (const auto &owned : patterns)
    this->patterns.push_back(owned.get());

  // Descending order: `lhs` goes first when it has the strictly larger
  // benefit.
  auto higherBenefitFirst = [](RewritePattern *lhs, RewritePattern *rhs) {
    return rhs->getBenefit() < lhs->getBenefit();
  };
  std::stable_sort(this->patterns.begin(), this->patterns.end(),
                   higherBenefitFirst);
}
185
186
/// Try to match the given operation to a pattern and rewrite it.
187
bool RewritePatternMatcher::matchAndRewrite(Operation *op,
188
0
                                            PatternRewriter &rewriter) {
189
0
  for (auto *pattern : patterns) {
190
0
    // Ignore patterns that are for the wrong root or are impossible to match.
191
0
    if (pattern->getRootKind() != op->getName() ||
192
0
        pattern->getBenefit().isImpossibleToMatch())
193
0
      continue;
194
0
195
0
    // Try to match and rewrite this pattern. The patterns are sorted by
196
0
    // benefit, so if we match we can immediately rewrite and return.
197
0
    if (succeeded(pattern->matchAndRewrite(op, rewriter)))
198
0
      return true;
199
0
  }
200
0
  return false;
201
0
}