/home/arjun/llvm-project/mlir/lib/IR/TypeUtilities.cpp
Line | Count | Source (jump to first uncovered line) |
1 | | //===- TypeUtilities.cpp - Helper function for type queries ---------------===// |
2 | | // |
3 | | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | | // See https://llvm.org/LICENSE.txt for license information. |
5 | | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | | // |
7 | | //===----------------------------------------------------------------------===// |
8 | | // |
9 | | // This file defines generic type utilities. |
10 | | // |
11 | | //===----------------------------------------------------------------------===// |
12 | | |
13 | | #include "mlir/IR/TypeUtilities.h" |
14 | | #include "mlir/IR/Attributes.h" |
15 | | #include "mlir/IR/StandardTypes.h" |
16 | | #include "mlir/IR/Types.h" |
17 | | #include "mlir/IR/Value.h" |
18 | | |
19 | | using namespace mlir; |
20 | | |
21 | 0 | Type mlir::getElementTypeOrSelf(Type type) { |
22 | 0 | if (auto st = type.dyn_cast<ShapedType>()) |
23 | 0 | return st.getElementType(); |
24 | 0 | return type; |
25 | 0 | } |
26 | | |
27 | 0 | Type mlir::getElementTypeOrSelf(Value val) { |
28 | 0 | return getElementTypeOrSelf(val.getType()); |
29 | 0 | } |
30 | | |
31 | 0 | Type mlir::getElementTypeOrSelf(Attribute attr) { |
32 | 0 | return getElementTypeOrSelf(attr.getType()); |
33 | 0 | } |
34 | | |
35 | 0 | SmallVector<Type, 10> mlir::getFlattenedTypes(TupleType t) { |
36 | 0 | SmallVector<Type, 10> fTypes; |
37 | 0 | t.getFlattenedTypes(fTypes); |
38 | 0 | return fTypes; |
39 | 0 | } |
40 | | |
41 | | /// Return true if the specified type is an opaque type with the specified |
42 | | /// dialect and typeData. |
43 | | bool mlir::isOpaqueTypeWithName(Type type, StringRef dialect, |
44 | 0 | StringRef typeData) { |
45 | 0 | if (auto opaque = type.dyn_cast<mlir::OpaqueType>()) |
46 | 0 | return opaque.getDialectNamespace() == dialect && |
47 | 0 | opaque.getTypeData() == typeData; |
48 | 0 | return false; |
49 | 0 | } |
50 | | |
51 | | /// Returns success if the given two shapes are compatible. That is, they have |
52 | | /// the same size and each pair of the elements are equal or one of them is |
53 | | /// dynamic. |
54 | | LogicalResult mlir::verifyCompatibleShape(ArrayRef<int64_t> shape1, |
55 | 0 | ArrayRef<int64_t> shape2) { |
56 | 0 | if (shape1.size() != shape2.size()) |
57 | 0 | return failure(); |
58 | 0 | for (auto dims : llvm::zip(shape1, shape2)) { |
59 | 0 | int64_t dim1 = std::get<0>(dims); |
60 | 0 | int64_t dim2 = std::get<1>(dims); |
61 | 0 | if (!ShapedType::isDynamic(dim1) && !ShapedType::isDynamic(dim2) && |
62 | 0 | dim1 != dim2) |
63 | 0 | return failure(); |
64 | 0 | } |
65 | 0 | return success(); |
66 | 0 | } |
67 | | |
68 | | /// Returns success if the given two types have compatible shape. That is, |
69 | | /// they are both scalars (not shaped), or they are both shaped types and at |
70 | | /// least one is unranked or they have compatible dimensions. Dimensions are |
71 | | /// compatible if at least one is dynamic or both are equal. The element type |
72 | | /// does not matter. |
73 | 0 | LogicalResult mlir::verifyCompatibleShape(Type type1, Type type2) { |
74 | 0 | auto sType1 = type1.dyn_cast<ShapedType>(); |
75 | 0 | auto sType2 = type2.dyn_cast<ShapedType>(); |
76 | 0 |
|
77 | 0 | // Either both or neither type should be shaped. |
78 | 0 | if (!sType1) |
79 | 0 | return success(!sType2); |
80 | 0 | if (!sType2) |
81 | 0 | return failure(); |
82 | 0 | |
83 | 0 | if (!sType1.hasRank() || !sType2.hasRank()) |
84 | 0 | return success(); |
85 | 0 | |
86 | 0 | return verifyCompatibleShape(sType1.getShape(), sType2.getShape()); |
87 | 0 | } |
88 | | |
89 | | OperandElementTypeIterator::OperandElementTypeIterator( |
90 | | Operation::operand_iterator it) |
91 | | : llvm::mapped_iterator<Operation::operand_iterator, Type (*)(Value)>( |
92 | 0 | it, &unwrap) {} |
93 | | |
94 | 0 | Type OperandElementTypeIterator::unwrap(Value value) { |
95 | 0 | return value.getType().cast<ShapedType>().getElementType(); |
96 | 0 | } |
97 | | |
98 | | ResultElementTypeIterator::ResultElementTypeIterator( |
99 | | Operation::result_iterator it) |
100 | | : llvm::mapped_iterator<Operation::result_iterator, Type (*)(Value)>( |
101 | 0 | it, &unwrap) {} |
102 | | |
103 | 0 | Type ResultElementTypeIterator::unwrap(Value value) { |
104 | 0 | return value.getType().cast<ShapedType>().getElementType(); |
105 | 0 | } |