Coverage Report

Created: 2020-06-26 05:44

/home/arjun/llvm-project/mlir/include/mlir/Dialect/CommonFolders.h
Line
Count
Source (jump to first uncovered line)
1
//===- CommonFolders.h - Common Operation Folders----------------*- C++ -*-===//
2
//
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
// See https://llvm.org/LICENSE.txt for license information.
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
//
7
//===----------------------------------------------------------------------===//
8
//
9
// This header file declares various common operation folders. These folders
10
// are intended to be used by dialects to support common folding behavior
11
// without requiring each dialect to provide its own implementation.
12
//
13
//===----------------------------------------------------------------------===//
14
15
#ifndef MLIR_DIALECT_COMMONFOLDERS_H
16
#define MLIR_DIALECT_COMMONFOLDERS_H
17
18
#include "mlir/IR/Attributes.h"
19
#include "mlir/IR/StandardTypes.h"
20
#include "llvm/ADT/ArrayRef.h"
21
#include "llvm/ADT/STLExtras.h"
22
23
namespace mlir {
24
/// Performs constant folding `calculate` with element-wise behavior on the two
25
/// attributes in `operands` and returns the result if possible.
26
template <class AttrElementT,
27
          class ElementValueT = typename AttrElementT::ValueType,
28
          class CalculationT =
29
              function_ref<ElementValueT(ElementValueT, ElementValueT)>>
30
Attribute constFoldBinaryOp(ArrayRef<Attribute> operands,
31
0
                            const CalculationT &calculate) {
32
0
  assert(operands.size() == 2 && "binary op takes two operands");
33
0
  if (!operands[0] || !operands[1])
34
0
    return {};
35
0
  if (operands[0].getType() != operands[1].getType())
36
0
    return {};
37
0
38
0
  if (operands[0].isa<AttrElementT>() && operands[1].isa<AttrElementT>()) {
39
0
    auto lhs = operands[0].cast<AttrElementT>();
40
0
    auto rhs = operands[1].cast<AttrElementT>();
41
0
42
0
    return AttrElementT::get(lhs.getType(),
43
0
                             calculate(lhs.getValue(), rhs.getValue()));
44
0
  } else if (operands[0].isa<SplatElementsAttr>() &&
45
0
             operands[1].isa<SplatElementsAttr>()) {
46
0
    // Both operands are splats so we can avoid expanding the values out and
47
0
    // just fold based on the splat value.
48
0
    auto lhs = operands[0].cast<SplatElementsAttr>();
49
0
    auto rhs = operands[1].cast<SplatElementsAttr>();
50
0
51
0
    auto elementResult = calculate(lhs.getSplatValue<ElementValueT>(),
52
0
                                   rhs.getSplatValue<ElementValueT>());
53
0
    return DenseElementsAttr::get(lhs.getType(), elementResult);
54
0
  } else if (operands[0].isa<ElementsAttr>() &&
55
0
             operands[1].isa<ElementsAttr>()) {
56
0
    // Operands are ElementsAttr-derived; perform an element-wise fold by
57
0
    // expanding the values.
58
0
    auto lhs = operands[0].cast<ElementsAttr>();
59
0
    auto rhs = operands[1].cast<ElementsAttr>();
60
0
61
0
    auto lhsIt = lhs.getValues<ElementValueT>().begin();
62
0
    auto rhsIt = rhs.getValues<ElementValueT>().begin();
63
0
    SmallVector<ElementValueT, 4> elementResults;
64
0
    elementResults.reserve(lhs.getNumElements());
65
0
    for (size_t i = 0, e = lhs.getNumElements(); i < e; ++i, ++lhsIt, ++rhsIt)
66
0
      elementResults.push_back(calculate(*lhsIt, *rhsIt));
67
0
    return DenseElementsAttr::get(lhs.getType(), elementResults);
68
0
  }
69
0
  return {};
70
0
}
Unexecuted instantiation: Ops.cpp:_ZN4mlir17constFoldBinaryOpINS_9FloatAttrEN4llvm7APFloatEZNS_6AddFOp4foldENS2_8ArrayRefINS_9AttributeEEEE3$_0EES6_S7_RKT1_
Unexecuted instantiation: Ops.cpp:_ZN4mlir17constFoldBinaryOpINS_11IntegerAttrEN4llvm5APIntEZNS_6AddIOp4foldENS2_8ArrayRefINS_9AttributeEEEE3$_1EES6_S7_RKT1_
Unexecuted instantiation: Ops.cpp:_ZN4mlir17constFoldBinaryOpINS_11IntegerAttrEN4llvm5APIntEZNS_5AndOp4foldENS2_8ArrayRefINS_9AttributeEEEE3$_2EES6_S7_RKT1_
Unexecuted instantiation: Ops.cpp:_ZN4mlir17constFoldBinaryOpINS_9FloatAttrEN4llvm7APFloatEZNS_6MulFOp4foldENS2_8ArrayRefINS_9AttributeEEEE3$_8EES6_S7_RKT1_
Unexecuted instantiation: Ops.cpp:_ZN4mlir17constFoldBinaryOpINS_11IntegerAttrEN4llvm5APIntEZNS_6MulIOp4foldENS2_8ArrayRefINS_9AttributeEEEE3$_9EES6_S7_RKT1_
Unexecuted instantiation: Ops.cpp:_ZN4mlir17constFoldBinaryOpINS_11IntegerAttrEN4llvm5APIntEZNS_4OrOp4foldENS2_8ArrayRefINS_9AttributeEEEE4$_10EES6_S7_RKT1_
Unexecuted instantiation: Ops.cpp:_ZN4mlir17constFoldBinaryOpINS_11IntegerAttrEN4llvm5APIntEZNS_12SignedDivIOp4foldENS2_8ArrayRefINS_9AttributeEEEE4$_11EES6_S7_RKT1_
Unexecuted instantiation: Ops.cpp:_ZN4mlir17constFoldBinaryOpINS_9FloatAttrEN4llvm7APFloatEZNS_6SubFOp4foldENS2_8ArrayRefINS_9AttributeEEEE4$_12EES6_S7_RKT1_
Unexecuted instantiation: Ops.cpp:_ZN4mlir17constFoldBinaryOpINS_11IntegerAttrEN4llvm5APIntEZNS_6SubIOp4foldENS2_8ArrayRefINS_9AttributeEEEE4$_13EES6_S7_RKT1_
Unexecuted instantiation: Ops.cpp:_ZN4mlir17constFoldBinaryOpINS_11IntegerAttrEN4llvm5APIntEZNS_14UnsignedDivIOp4foldENS2_8ArrayRefINS_9AttributeEEEE4$_19EES6_S7_RKT1_
Unexecuted instantiation: Ops.cpp:_ZN4mlir17constFoldBinaryOpINS_11IntegerAttrEN4llvm5APIntEZNS_5XOrOp4foldENS2_8ArrayRefINS_9AttributeEEEE4$_20EES6_S7_RKT1_
71
} // namespace mlir
72
73
#endif // MLIR_DIALECT_COMMONFOLDERS_H