/home/arjun/llvm-project/mlir/include/mlir/Dialect/CommonFolders.h
Line | Count | Source (jump to first uncovered line) |
1 | | //===- CommonFolders.h - Common Operation Folders----------------*- C++ -*-===// |
2 | | // |
3 | | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | | // See https://llvm.org/LICENSE.txt for license information. |
5 | | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | | // |
7 | | //===----------------------------------------------------------------------===// |
8 | | // |
9 | | // This header file declares various common operation folders. These folders |
10 | | // are intended to be used by dialects to support common folding behavior |
11 | | // without requiring each dialect to provide its own implementation. |
12 | | // |
13 | | //===----------------------------------------------------------------------===// |
14 | | |
15 | | #ifndef MLIR_DIALECT_COMMONFOLDERS_H |
16 | | #define MLIR_DIALECT_COMMONFOLDERS_H |
17 | | |
18 | | #include "mlir/IR/Attributes.h" |
19 | | #include "mlir/IR/StandardTypes.h" |
20 | | #include "llvm/ADT/ArrayRef.h" |
21 | | #include "llvm/ADT/STLExtras.h" |
22 | | |
23 | | namespace mlir { |
24 | | /// Performs constant folding `calculate` with element-wise behavior on the two |
25 | | /// attributes in `operands` and returns the result if possible. |
26 | | template <class AttrElementT, |
27 | | class ElementValueT = typename AttrElementT::ValueType, |
28 | | class CalculationT = |
29 | | function_ref<ElementValueT(ElementValueT, ElementValueT)>> |
30 | | Attribute constFoldBinaryOp(ArrayRef<Attribute> operands, |
31 | 0 | const CalculationT &calculate) { |
32 | 0 | assert(operands.size() == 2 && "binary op takes two operands"); |
33 | 0 | if (!operands[0] || !operands[1]) |
34 | 0 | return {}; |
35 | 0 | if (operands[0].getType() != operands[1].getType()) |
36 | 0 | return {}; |
37 | 0 | |
38 | 0 | if (operands[0].isa<AttrElementT>() && operands[1].isa<AttrElementT>()) { |
39 | 0 | auto lhs = operands[0].cast<AttrElementT>(); |
40 | 0 | auto rhs = operands[1].cast<AttrElementT>(); |
41 | 0 |
|
42 | 0 | return AttrElementT::get(lhs.getType(), |
43 | 0 | calculate(lhs.getValue(), rhs.getValue())); |
44 | 0 | } else if (operands[0].isa<SplatElementsAttr>() && |
45 | 0 | operands[1].isa<SplatElementsAttr>()) { |
46 | 0 | // Both operands are splats so we can avoid expanding the values out and |
47 | 0 | // just fold based on the splat value. |
48 | 0 | auto lhs = operands[0].cast<SplatElementsAttr>(); |
49 | 0 | auto rhs = operands[1].cast<SplatElementsAttr>(); |
50 | 0 |
|
51 | 0 | auto elementResult = calculate(lhs.getSplatValue<ElementValueT>(), |
52 | 0 | rhs.getSplatValue<ElementValueT>()); |
53 | 0 | return DenseElementsAttr::get(lhs.getType(), elementResult); |
54 | 0 | } else if (operands[0].isa<ElementsAttr>() && |
55 | 0 | operands[1].isa<ElementsAttr>()) { |
56 | 0 | // Operands are ElementsAttr-derived; perform an element-wise fold by |
57 | 0 | // expanding the values. |
58 | 0 | auto lhs = operands[0].cast<ElementsAttr>(); |
59 | 0 | auto rhs = operands[1].cast<ElementsAttr>(); |
60 | 0 |
|
61 | 0 | auto lhsIt = lhs.getValues<ElementValueT>().begin(); |
62 | 0 | auto rhsIt = rhs.getValues<ElementValueT>().begin(); |
63 | 0 | SmallVector<ElementValueT, 4> elementResults; |
64 | 0 | elementResults.reserve(lhs.getNumElements()); |
65 | 0 | for (size_t i = 0, e = lhs.getNumElements(); i < e; ++i, ++lhsIt, ++rhsIt) |
66 | 0 | elementResults.push_back(calculate(*lhsIt, *rhsIt)); |
67 | 0 | return DenseElementsAttr::get(lhs.getType(), elementResults); |
68 | 0 | } |
69 | 0 | return {}; |
70 | 0 | } Unexecuted instantiation: Ops.cpp:_ZN4mlir17constFoldBinaryOpINS_9FloatAttrEN4llvm7APFloatEZNS_6AddFOp4foldENS2_8ArrayRefINS_9AttributeEEEE3$_0EES6_S7_RKT1_ Unexecuted instantiation: Ops.cpp:_ZN4mlir17constFoldBinaryOpINS_11IntegerAttrEN4llvm5APIntEZNS_6AddIOp4foldENS2_8ArrayRefINS_9AttributeEEEE3$_1EES6_S7_RKT1_ Unexecuted instantiation: Ops.cpp:_ZN4mlir17constFoldBinaryOpINS_11IntegerAttrEN4llvm5APIntEZNS_5AndOp4foldENS2_8ArrayRefINS_9AttributeEEEE3$_2EES6_S7_RKT1_ Unexecuted instantiation: Ops.cpp:_ZN4mlir17constFoldBinaryOpINS_9FloatAttrEN4llvm7APFloatEZNS_6MulFOp4foldENS2_8ArrayRefINS_9AttributeEEEE3$_8EES6_S7_RKT1_ Unexecuted instantiation: Ops.cpp:_ZN4mlir17constFoldBinaryOpINS_11IntegerAttrEN4llvm5APIntEZNS_6MulIOp4foldENS2_8ArrayRefINS_9AttributeEEEE3$_9EES6_S7_RKT1_ Unexecuted instantiation: Ops.cpp:_ZN4mlir17constFoldBinaryOpINS_11IntegerAttrEN4llvm5APIntEZNS_4OrOp4foldENS2_8ArrayRefINS_9AttributeEEEE4$_10EES6_S7_RKT1_ Unexecuted instantiation: Ops.cpp:_ZN4mlir17constFoldBinaryOpINS_11IntegerAttrEN4llvm5APIntEZNS_12SignedDivIOp4foldENS2_8ArrayRefINS_9AttributeEEEE4$_11EES6_S7_RKT1_ Unexecuted instantiation: Ops.cpp:_ZN4mlir17constFoldBinaryOpINS_9FloatAttrEN4llvm7APFloatEZNS_6SubFOp4foldENS2_8ArrayRefINS_9AttributeEEEE4$_12EES6_S7_RKT1_ Unexecuted instantiation: Ops.cpp:_ZN4mlir17constFoldBinaryOpINS_11IntegerAttrEN4llvm5APIntEZNS_6SubIOp4foldENS2_8ArrayRefINS_9AttributeEEEE4$_13EES6_S7_RKT1_ Unexecuted instantiation: Ops.cpp:_ZN4mlir17constFoldBinaryOpINS_11IntegerAttrEN4llvm5APIntEZNS_14UnsignedDivIOp4foldENS2_8ArrayRefINS_9AttributeEEEE4$_19EES6_S7_RKT1_ Unexecuted instantiation: Ops.cpp:_ZN4mlir17constFoldBinaryOpINS_11IntegerAttrEN4llvm5APIntEZNS_5XOrOp4foldENS2_8ArrayRefINS_9AttributeEEEE4$_20EES6_S7_RKT1_ |
71 | | } // namespace mlir |
72 | | |
73 | | #endif // MLIR_DIALECT_COMMONFOLDERS_H |