/home/arjun/llvm-project/mlir/lib/IR/Attributes.cpp
Line | Count | Source (jump to first uncovered line) |
1 | | //===- Attributes.cpp - MLIR Attribute Classes ----------------------------===// |
2 | | // |
3 | | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | | // See https://llvm.org/LICENSE.txt for license information. |
5 | | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | | // |
7 | | //===----------------------------------------------------------------------===// |
8 | | |
9 | | #include "mlir/IR/Attributes.h" |
10 | | #include "AttributeDetail.h" |
11 | | #include "mlir/IR/AffineMap.h" |
12 | | #include "mlir/IR/Diagnostics.h" |
13 | | #include "mlir/IR/Dialect.h" |
14 | | #include "mlir/IR/Function.h" |
15 | | #include "mlir/IR/IntegerSet.h" |
16 | | #include "mlir/IR/Types.h" |
17 | | #include "llvm/ADT/Sequence.h" |
18 | | #include "llvm/ADT/Twine.h" |
19 | | #include "llvm/Support/Endian.h" |
20 | | |
21 | | using namespace mlir; |
22 | | using namespace mlir::detail; |
23 | | |
24 | | //===----------------------------------------------------------------------===// |
25 | | // AttributeStorage |
26 | | //===----------------------------------------------------------------------===// |
27 | | |
28 | | AttributeStorage::AttributeStorage(Type type) |
29 | 0 | : type(type.getAsOpaquePointer()) {} |
30 | 0 | AttributeStorage::AttributeStorage() : type(nullptr) {} |
31 | | |
32 | 0 | Type AttributeStorage::getType() const { |
33 | 0 | return Type::getFromOpaquePointer(type); |
34 | 0 | } |
35 | 0 | void AttributeStorage::setType(Type newType) { |
36 | 0 | type = newType.getAsOpaquePointer(); |
37 | 0 | } |
38 | | |
39 | | //===----------------------------------------------------------------------===// |
40 | | // Attribute |
41 | | //===----------------------------------------------------------------------===// |
42 | | |
43 | | /// Return the type of this attribute. |
44 | 0 | Type Attribute::getType() const { return impl->getType(); } |
45 | | |
46 | | /// Return the context this attribute belongs to. |
47 | 0 | MLIRContext *Attribute::getContext() const { return getType().getContext(); } |
48 | | |
49 | | /// Get the dialect this attribute is registered to. |
50 | 0 | Dialect &Attribute::getDialect() const { return impl->getDialect(); } |
51 | | |
52 | | //===----------------------------------------------------------------------===// |
53 | | // AffineMapAttr |
54 | | //===----------------------------------------------------------------------===// |
55 | | |
56 | 0 | AffineMapAttr AffineMapAttr::get(AffineMap value) { |
57 | 0 | return Base::get(value.getContext(), StandardAttributes::AffineMap, value); |
58 | 0 | } |
59 | | |
60 | 0 | AffineMap AffineMapAttr::getValue() const { return getImpl()->value; } |
61 | | |
62 | | //===----------------------------------------------------------------------===// |
63 | | // ArrayAttr |
64 | | //===----------------------------------------------------------------------===// |
65 | | |
66 | 0 | ArrayAttr ArrayAttr::get(ArrayRef<Attribute> value, MLIRContext *context) { |
67 | 0 | return Base::get(context, StandardAttributes::Array, value); |
68 | 0 | } |
69 | | |
70 | 0 | ArrayRef<Attribute> ArrayAttr::getValue() const { return getImpl()->value; } |
71 | | |
72 | 0 | Attribute ArrayAttr::operator[](unsigned idx) const { |
73 | 0 | assert(idx < size() && "index out of bounds"); |
74 | 0 | return getValue()[idx]; |
75 | 0 | } |
76 | | |
77 | | //===----------------------------------------------------------------------===// |
78 | | // BoolAttr |
79 | | //===----------------------------------------------------------------------===// |
80 | | |
81 | 0 | bool BoolAttr::getValue() const { return getImpl()->value; } |
82 | | |
83 | | //===----------------------------------------------------------------------===// |
84 | | // DictionaryAttr |
85 | | //===----------------------------------------------------------------------===// |
86 | | |
87 | | /// Helper function that does either an in place sort or sorts from source array |
88 | | /// into destination. If inPlace then storage is both the source and the |
89 | | /// destination, else value is the source and storage destination. Returns |
90 | | /// whether source was sorted. |
91 | | template <bool inPlace> |
92 | | static bool dictionaryAttrSort(ArrayRef<NamedAttribute> value, |
93 | 0 | SmallVectorImpl<NamedAttribute> &storage) { |
94 | 0 | // Specialize for the common case. |
95 | 0 | switch (value.size()) { |
96 | 0 | case 0: |
97 | 0 | // Zero already sorted. |
98 | 0 | break; |
99 | 0 | case 1: |
100 | 0 | // One already sorted but may need to be copied. |
101 | 0 | if (!inPlace) |
102 | 0 | storage.assign({value[0]}); |
103 | 0 | break; |
104 | 0 | case 2: { |
105 | 0 | assert(value[0].first != value[1].first && |
106 | 0 | "DictionaryAttr element names must be unique"); |
107 | 0 | bool isSorted = value[0] < value[1]; |
108 | 0 | if (inPlace) { |
109 | 0 | if (!isSorted) |
110 | 0 | std::swap(storage[0], storage[1]); |
111 | 0 | } else if (isSorted) { |
112 | 0 | storage.assign({value[0], value[1]}); |
113 | 0 | } else { |
114 | 0 | storage.assign({value[1], value[0]}); |
115 | 0 | } |
116 | 0 | return !isSorted; |
117 | 0 | } |
118 | 0 | default: |
119 | 0 | if (!inPlace) |
120 | 0 | storage.assign(value.begin(), value.end()); |
121 | 0 | // Check to see they are sorted already. |
122 | 0 | bool isSorted = llvm::is_sorted(value); |
123 | 0 | if (!isSorted) { |
124 | 0 | // If not, do a general sort. |
125 | 0 | llvm::array_pod_sort(storage.begin(), storage.end()); |
126 | 0 | value = storage; |
127 | 0 | } |
128 | 0 |
|
129 | 0 | // Ensure that the attribute elements are unique. |
130 | 0 | assert(std::adjacent_find(value.begin(), value.end(), |
131 | 0 | [](NamedAttribute l, NamedAttribute r) { |
132 | 0 | return l.first == r.first; |
133 | 0 | }) == value.end() && |
134 | 0 | "DictionaryAttr element names must be unique"); |
135 | 0 | return !isSorted; |
136 | 0 | } |
137 | 0 | return false; |
138 | 0 | } Unexecuted instantiation: Attributes.cpp:_ZL18dictionaryAttrSortILb0EEbN4llvm8ArrayRefISt4pairIN4mlir10IdentifierENS3_9AttributeEEEERNS0_15SmallVectorImplIS6_EE Unexecuted instantiation: Attributes.cpp:_ZL18dictionaryAttrSortILb1EEbN4llvm8ArrayRefISt4pairIN4mlir10IdentifierENS3_9AttributeEEEERNS0_15SmallVectorImplIS6_EE |
139 | | |
140 | | bool DictionaryAttr::sort(ArrayRef<NamedAttribute> value, |
141 | 0 | SmallVectorImpl<NamedAttribute> &storage) { |
142 | 0 | return dictionaryAttrSort</*inPlace=*/false>(value, storage); |
143 | 0 | } |
144 | | |
145 | 0 | bool DictionaryAttr::sortInPlace(SmallVectorImpl<NamedAttribute> &array) { |
146 | 0 | return dictionaryAttrSort</*inPlace=*/true>(array, array); |
147 | 0 | } |
148 | | |
149 | | DictionaryAttr DictionaryAttr::get(ArrayRef<NamedAttribute> value, |
150 | 0 | MLIRContext *context) { |
151 | 0 | if (value.empty()) |
152 | 0 | return DictionaryAttr::getEmpty(context); |
153 | 0 | assert(llvm::all_of(value, |
154 | 0 | [](const NamedAttribute &attr) { return attr.second; }) && |
155 | 0 | "value cannot have null entries"); |
156 | 0 |
|
157 | 0 | // We need to sort the element list to canonicalize it. |
158 | 0 | SmallVector<NamedAttribute, 8> storage; |
159 | 0 | if (dictionaryAttrSort</*inPlace=*/false>(value, storage)) |
160 | 0 | value = storage; |
161 | 0 |
|
162 | 0 | return Base::get(context, StandardAttributes::Dictionary, value); |
163 | 0 | } |
164 | | /// Construct a dictionary with an array of values that is known to already be |
165 | | /// sorted by name and uniqued. |
166 | | DictionaryAttr DictionaryAttr::getWithSorted(ArrayRef<NamedAttribute> value, |
167 | 0 | MLIRContext *context) { |
168 | 0 | if (value.empty()) |
169 | 0 | return DictionaryAttr::getEmpty(context); |
170 | 0 | // Ensure that the attribute elements are unique and sorted. |
171 | 0 | assert(llvm::is_sorted(value, |
172 | 0 | [](NamedAttribute l, NamedAttribute r) { |
173 | 0 | return l.first.strref() < r.first.strref(); |
174 | 0 | }) && |
175 | 0 | "expected attribute values to be sorted"); |
176 | 0 | assert(std::adjacent_find(value.begin(), value.end(), |
177 | 0 | [](NamedAttribute l, NamedAttribute r) { |
178 | 0 | return l.first == r.first; |
179 | 0 | }) == value.end() && |
180 | 0 | "DictionaryAttr element names must be unique"); |
181 | 0 | return Base::get(context, StandardAttributes::Dictionary, value); |
182 | 0 | } |
183 | | |
184 | 0 | ArrayRef<NamedAttribute> DictionaryAttr::getValue() const { |
185 | 0 | return getImpl()->getElements(); |
186 | 0 | } |
187 | | |
188 | | /// Return the specified attribute if present, null otherwise. |
189 | 0 | Attribute DictionaryAttr::get(StringRef name) const { |
190 | 0 | Optional<NamedAttribute> attr = getNamed(name); |
191 | 0 | return attr ? attr->second : nullptr; |
192 | 0 | } |
193 | 0 | Attribute DictionaryAttr::get(Identifier name) const { |
194 | 0 | Optional<NamedAttribute> attr = getNamed(name); |
195 | 0 | return attr ? attr->second : nullptr; |
196 | 0 | } |
197 | | |
198 | | /// Return the specified named attribute if present, None otherwise. |
199 | 0 | Optional<NamedAttribute> DictionaryAttr::getNamed(StringRef name) const { |
200 | 0 | ArrayRef<NamedAttribute> values = getValue(); |
201 | 0 | const auto *it = llvm::lower_bound(values, name); |
202 | 0 | return it != values.end() && it->first == name ? *it |
203 | 0 | : Optional<NamedAttribute>(); |
204 | 0 | } |
205 | 0 | Optional<NamedAttribute> DictionaryAttr::getNamed(Identifier name) const { |
206 | 0 | for (auto elt : getValue()) |
207 | 0 | if (elt.first == name) |
208 | 0 | return elt; |
209 | 0 | return llvm::None; |
210 | 0 | } |
211 | | |
212 | 0 | DictionaryAttr::iterator DictionaryAttr::begin() const { |
213 | 0 | return getValue().begin(); |
214 | 0 | } |
215 | 0 | DictionaryAttr::iterator DictionaryAttr::end() const { |
216 | 0 | return getValue().end(); |
217 | 0 | } |
218 | 0 | size_t DictionaryAttr::size() const { return getValue().size(); } |
219 | | |
220 | | //===----------------------------------------------------------------------===// |
221 | | // FloatAttr |
222 | | //===----------------------------------------------------------------------===// |
223 | | |
224 | 0 | FloatAttr FloatAttr::get(Type type, double value) { |
225 | 0 | return Base::get(type.getContext(), StandardAttributes::Float, type, value); |
226 | 0 | } |
227 | | |
228 | 0 | FloatAttr FloatAttr::getChecked(Type type, double value, Location loc) { |
229 | 0 | return Base::getChecked(loc, StandardAttributes::Float, type, value); |
230 | 0 | } |
231 | | |
232 | 0 | FloatAttr FloatAttr::get(Type type, const APFloat &value) { |
233 | 0 | return Base::get(type.getContext(), StandardAttributes::Float, type, value); |
234 | 0 | } |
235 | | |
236 | 0 | FloatAttr FloatAttr::getChecked(Type type, const APFloat &value, Location loc) { |
237 | 0 | return Base::getChecked(loc, StandardAttributes::Float, type, value); |
238 | 0 | } |
239 | | |
240 | 0 | APFloat FloatAttr::getValue() const { return getImpl()->getValue(); } |
241 | | |
242 | 0 | double FloatAttr::getValueAsDouble() const { |
243 | 0 | return getValueAsDouble(getValue()); |
244 | 0 | } |
245 | 0 | double FloatAttr::getValueAsDouble(APFloat value) { |
246 | 0 | if (&value.getSemantics() != &APFloat::IEEEdouble()) { |
247 | 0 | bool losesInfo = false; |
248 | 0 | value.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven, |
249 | 0 | &losesInfo); |
250 | 0 | } |
251 | 0 | return value.convertToDouble(); |
252 | 0 | } |
253 | | |
254 | | /// Verify construction invariants. |
255 | 0 | static LogicalResult verifyFloatTypeInvariants(Location loc, Type type) { |
256 | 0 | if (!type.isa<FloatType>()) |
257 | 0 | return emitError(loc, "expected floating point type"); |
258 | 0 | return success(); |
259 | 0 | } |
260 | | |
261 | | LogicalResult FloatAttr::verifyConstructionInvariants(Location loc, Type type, |
262 | 0 | double value) { |
263 | 0 | return verifyFloatTypeInvariants(loc, type); |
264 | 0 | } |
265 | | |
266 | | LogicalResult FloatAttr::verifyConstructionInvariants(Location loc, Type type, |
267 | 0 | const APFloat &value) { |
268 | 0 | // Verify that the type is correct. |
269 | 0 | if (failed(verifyFloatTypeInvariants(loc, type))) |
270 | 0 | return failure(); |
271 | 0 | |
272 | 0 | // Verify that the type semantics match that of the value. |
273 | 0 | if (&type.cast<FloatType>().getFloatSemantics() != &value.getSemantics()) { |
274 | 0 | return emitError( |
275 | 0 | loc, "FloatAttr type doesn't match the type implied by its value"); |
276 | 0 | } |
277 | 0 | return success(); |
278 | 0 | } |
279 | | |
280 | | //===----------------------------------------------------------------------===// |
281 | | // SymbolRefAttr |
282 | | //===----------------------------------------------------------------------===// |
283 | | |
284 | 0 | FlatSymbolRefAttr SymbolRefAttr::get(StringRef value, MLIRContext *ctx) { |
285 | 0 | return Base::get(ctx, StandardAttributes::SymbolRef, value, llvm::None) |
286 | 0 | .cast<FlatSymbolRefAttr>(); |
287 | 0 | } |
288 | | |
289 | | SymbolRefAttr SymbolRefAttr::get(StringRef value, |
290 | | ArrayRef<FlatSymbolRefAttr> nestedReferences, |
291 | 0 | MLIRContext *ctx) { |
292 | 0 | return Base::get(ctx, StandardAttributes::SymbolRef, value, nestedReferences); |
293 | 0 | } |
294 | | |
295 | 0 | StringRef SymbolRefAttr::getRootReference() const { return getImpl()->value; } |
296 | | |
297 | 0 | StringRef SymbolRefAttr::getLeafReference() const { |
298 | 0 | ArrayRef<FlatSymbolRefAttr> nestedRefs = getNestedReferences(); |
299 | 0 | return nestedRefs.empty() ? getRootReference() : nestedRefs.back().getValue(); |
300 | 0 | } |
301 | | |
302 | 0 | ArrayRef<FlatSymbolRefAttr> SymbolRefAttr::getNestedReferences() const { |
303 | 0 | return getImpl()->getNestedRefs(); |
304 | 0 | } |
305 | | |
306 | | //===----------------------------------------------------------------------===// |
307 | | // IntegerAttr |
308 | | //===----------------------------------------------------------------------===// |
309 | | |
310 | 0 | IntegerAttr IntegerAttr::get(Type type, const APInt &value) { |
311 | 0 | return Base::get(type.getContext(), StandardAttributes::Integer, type, value); |
312 | 0 | } |
313 | | |
314 | 0 | IntegerAttr IntegerAttr::get(Type type, int64_t value) { |
315 | 0 | // This uses 64 bit APInts by default for index type. |
316 | 0 | if (type.isIndex()) |
317 | 0 | return get(type, APInt(IndexType::kInternalStorageBitWidth, value)); |
318 | 0 | |
319 | 0 | auto intType = type.cast<IntegerType>(); |
320 | 0 | return get(type, APInt(intType.getWidth(), value, intType.isSignedInteger())); |
321 | 0 | } |
322 | | |
323 | 0 | APInt IntegerAttr::getValue() const { return getImpl()->getValue(); } |
324 | | |
325 | 0 | int64_t IntegerAttr::getInt() const { |
326 | 0 | assert((getImpl()->getType().isIndex() || |
327 | 0 | getImpl()->getType().isSignlessInteger()) && |
328 | 0 | "must be signless integer"); |
329 | 0 | return getValue().getSExtValue(); |
330 | 0 | } |
331 | | |
332 | 0 | int64_t IntegerAttr::getSInt() const { |
333 | 0 | assert(getImpl()->getType().isSignedInteger() && "must be signed integer"); |
334 | 0 | return getValue().getSExtValue(); |
335 | 0 | } |
336 | | |
337 | 0 | uint64_t IntegerAttr::getUInt() const { |
338 | 0 | assert(getImpl()->getType().isUnsignedInteger() && |
339 | 0 | "must be unsigned integer"); |
340 | 0 | return getValue().getZExtValue(); |
341 | 0 | } |
342 | | |
343 | 0 | static LogicalResult verifyIntegerTypeInvariants(Location loc, Type type) { |
344 | 0 | if (type.isa<IntegerType>() || type.isa<IndexType>()) |
345 | 0 | return success(); |
346 | 0 | return emitError(loc, "expected integer or index type"); |
347 | 0 | } |
348 | | |
349 | | LogicalResult IntegerAttr::verifyConstructionInvariants(Location loc, Type type, |
350 | 0 | int64_t value) { |
351 | 0 | return verifyIntegerTypeInvariants(loc, type); |
352 | 0 | } |
353 | | |
354 | | LogicalResult IntegerAttr::verifyConstructionInvariants(Location loc, Type type, |
355 | 0 | const APInt &value) { |
356 | 0 | if (failed(verifyIntegerTypeInvariants(loc, type))) |
357 | 0 | return failure(); |
358 | 0 | if (auto integerType = type.dyn_cast<IntegerType>()) |
359 | 0 | if (integerType.getWidth() != value.getBitWidth()) |
360 | 0 | return emitError(loc, "integer type bit width (") |
361 | 0 | << integerType.getWidth() << ") doesn't match value bit width (" |
362 | 0 | << value.getBitWidth() << ")"; |
363 | 0 | return success(); |
364 | 0 | } |
365 | | |
366 | | //===----------------------------------------------------------------------===// |
367 | | // IntegerSetAttr |
368 | | //===----------------------------------------------------------------------===// |
369 | | |
370 | 0 | IntegerSetAttr IntegerSetAttr::get(IntegerSet value) { |
371 | 0 | return Base::get(value.getConstraint(0).getContext(), |
372 | 0 | StandardAttributes::IntegerSet, value); |
373 | 0 | } |
374 | | |
375 | 0 | IntegerSet IntegerSetAttr::getValue() const { return getImpl()->value; } |
376 | | |
377 | | //===----------------------------------------------------------------------===// |
378 | | // OpaqueAttr |
379 | | //===----------------------------------------------------------------------===// |
380 | | |
381 | | OpaqueAttr OpaqueAttr::get(Identifier dialect, StringRef attrData, Type type, |
382 | 0 | MLIRContext *context) { |
383 | 0 | return Base::get(context, StandardAttributes::Opaque, dialect, attrData, |
384 | 0 | type); |
385 | 0 | } |
386 | | |
387 | | OpaqueAttr OpaqueAttr::getChecked(Identifier dialect, StringRef attrData, |
388 | 0 | Type type, Location location) { |
389 | 0 | return Base::getChecked(location, StandardAttributes::Opaque, dialect, |
390 | 0 | attrData, type); |
391 | 0 | } |
392 | | |
393 | | /// Returns the dialect namespace of the opaque attribute. |
394 | 0 | Identifier OpaqueAttr::getDialectNamespace() const { |
395 | 0 | return getImpl()->dialectNamespace; |
396 | 0 | } |
397 | | |
398 | | /// Returns the raw attribute data of the opaque attribute. |
399 | 0 | StringRef OpaqueAttr::getAttrData() const { return getImpl()->attrData; } |
400 | | |
401 | | /// Verify the construction of an opaque attribute. |
402 | | LogicalResult OpaqueAttr::verifyConstructionInvariants(Location loc, |
403 | | Identifier dialect, |
404 | | StringRef attrData, |
405 | 0 | Type type) { |
406 | 0 | if (!Dialect::isValidNamespace(dialect.strref())) |
407 | 0 | return emitError(loc, "invalid dialect namespace '") << dialect << "'"; |
408 | 0 | return success(); |
409 | 0 | } |
410 | | |
411 | | //===----------------------------------------------------------------------===// |
412 | | // StringAttr |
413 | | //===----------------------------------------------------------------------===// |
414 | | |
415 | 0 | StringAttr StringAttr::get(StringRef bytes, MLIRContext *context) { |
416 | 0 | return get(bytes, NoneType::get(context)); |
417 | 0 | } |
418 | | |
419 | | /// Get an instance of a StringAttr with the given string and Type. |
420 | 0 | StringAttr StringAttr::get(StringRef bytes, Type type) { |
421 | 0 | return Base::get(type.getContext(), StandardAttributes::String, bytes, type); |
422 | 0 | } |
423 | | |
424 | 0 | StringRef StringAttr::getValue() const { return getImpl()->value; } |
425 | | |
426 | | //===----------------------------------------------------------------------===// |
427 | | // TypeAttr |
428 | | //===----------------------------------------------------------------------===// |
429 | | |
430 | 0 | TypeAttr TypeAttr::get(Type value) { |
431 | 0 | return Base::get(value.getContext(), StandardAttributes::Type, value); |
432 | 0 | } |
433 | | |
434 | 0 | Type TypeAttr::getValue() const { return getImpl()->value; } |
435 | | |
436 | | //===----------------------------------------------------------------------===// |
437 | | // ElementsAttr |
438 | | //===----------------------------------------------------------------------===// |
439 | | |
440 | 0 | ShapedType ElementsAttr::getType() const { |
441 | 0 | return Attribute::getType().cast<ShapedType>(); |
442 | 0 | } |
443 | | |
444 | | /// Returns the number of elements held by this attribute. |
445 | 0 | int64_t ElementsAttr::getNumElements() const { |
446 | 0 | return getType().getNumElements(); |
447 | 0 | } |
448 | | |
449 | | /// Return the value at the given index. If index does not refer to a valid |
450 | | /// element, then a null attribute is returned. |
451 | 0 | Attribute ElementsAttr::getValue(ArrayRef<uint64_t> index) const { |
452 | 0 | switch (getKind()) { |
453 | 0 | case StandardAttributes::DenseIntOrFPElements: |
454 | 0 | return cast<DenseElementsAttr>().getValue(index); |
455 | 0 | case StandardAttributes::OpaqueElements: |
456 | 0 | return cast<OpaqueElementsAttr>().getValue(index); |
457 | 0 | case StandardAttributes::SparseElements: |
458 | 0 | return cast<SparseElementsAttr>().getValue(index); |
459 | 0 | default: |
460 | 0 | llvm_unreachable("unknown ElementsAttr kind"); |
461 | 0 | } |
462 | 0 | } |
463 | | |
464 | | /// Return if the given 'index' refers to a valid element in this attribute. |
465 | 0 | bool ElementsAttr::isValidIndex(ArrayRef<uint64_t> index) const { |
466 | 0 | auto type = getType(); |
467 | 0 |
|
468 | 0 | // Verify that the rank of the indices matches the held type. |
469 | 0 | auto rank = type.getRank(); |
470 | 0 | if (rank != static_cast<int64_t>(index.size())) |
471 | 0 | return false; |
472 | 0 | |
473 | 0 | // Verify that all of the indices are within the shape dimensions. |
474 | 0 | auto shape = type.getShape(); |
475 | 0 | return llvm::all_of(llvm::seq<int>(0, rank), [&](int i) { |
476 | 0 | return static_cast<int64_t>(index[i]) < shape[i]; |
477 | 0 | }); |
478 | 0 | } |
479 | | |
480 | | ElementsAttr |
481 | | ElementsAttr::mapValues(Type newElementType, |
482 | 0 | function_ref<APInt(const APInt &)> mapping) const { |
483 | 0 | switch (getKind()) { |
484 | 0 | case StandardAttributes::DenseIntOrFPElements: |
485 | 0 | return cast<DenseElementsAttr>().mapValues(newElementType, mapping); |
486 | 0 | default: |
487 | 0 | llvm_unreachable("unsupported ElementsAttr subtype"); |
488 | 0 | } |
489 | 0 | } |
490 | | |
491 | | ElementsAttr |
492 | | ElementsAttr::mapValues(Type newElementType, |
493 | 0 | function_ref<APInt(const APFloat &)> mapping) const { |
494 | 0 | switch (getKind()) { |
495 | 0 | case StandardAttributes::DenseIntOrFPElements: |
496 | 0 | return cast<DenseElementsAttr>().mapValues(newElementType, mapping); |
497 | 0 | default: |
498 | 0 | llvm_unreachable("unsupported ElementsAttr subtype"); |
499 | 0 | } |
500 | 0 | } |
501 | | |
502 | | /// Returns the 1 dimensional flattened row-major index from the given |
503 | | /// multi-dimensional index. |
504 | 0 | uint64_t ElementsAttr::getFlattenedIndex(ArrayRef<uint64_t> index) const { |
505 | 0 | assert(isValidIndex(index) && "expected valid multi-dimensional index"); |
506 | 0 | auto type = getType(); |
507 | 0 |
|
508 | 0 | // Reduce the provided multidimensional index into a flattended 1D row-major |
509 | 0 | // index. |
510 | 0 | auto rank = type.getRank(); |
511 | 0 | auto shape = type.getShape(); |
512 | 0 | uint64_t valueIndex = 0; |
513 | 0 | uint64_t dimMultiplier = 1; |
514 | 0 | for (int i = rank - 1; i >= 0; --i) { |
515 | 0 | valueIndex += index[i] * dimMultiplier; |
516 | 0 | dimMultiplier *= shape[i]; |
517 | 0 | } |
518 | 0 | return valueIndex; |
519 | 0 | } |
520 | | |
521 | | //===----------------------------------------------------------------------===// |
522 | | // DenseElementAttr Utilities |
523 | | //===----------------------------------------------------------------------===// |
524 | | |
525 | | /// Get the bitwidth of a dense element type within the buffer. |
526 | | /// DenseElementsAttr requires bitwidths greater than 1 to be aligned by 8. |
527 | 0 | static size_t getDenseElementStorageWidth(size_t origWidth) { |
528 | 0 | return origWidth == 1 ? origWidth : llvm::alignTo<8>(origWidth); |
529 | 0 | } |
530 | 0 | static size_t getDenseElementStorageWidth(Type elementType) { |
531 | 0 | return getDenseElementStorageWidth(getDenseElementBitWidth(elementType)); |
532 | 0 | } |
533 | | |
/// Set a bit to a specific value. `bitPos` is an absolute bit offset into
/// `rawData`; bits are numbered from the least significant bit of each byte.
static void setBit(char *rawData, size_t bitPos, bool value) {
  char &targetByte = rawData[bitPos / CHAR_BIT];
  const int mask = 1 << (bitPos % CHAR_BIT);
  if (value)
    targetByte |= mask;
  else
    targetByte &= ~mask;
}
541 | | |
/// Return the value of the specified bit. `bitPos` is an absolute bit offset
/// into `rawData`; bits are numbered from the least significant bit of each
/// byte.
static bool getBit(const char *rawData, size_t bitPos) {
  const char sourceByte = rawData[bitPos / CHAR_BIT];
  const int mask = 1 << (bitPos % CHAR_BIT);
  return (sourceByte & mask) != 0;
}
546 | | |
547 | | /// Get start position of actual data in `value`. Actual data is |
548 | | /// stored in last `bitWidth`/CHAR_BIT bytes in big endian. |
549 | 0 | static char *getAPIntDataPos(APInt &value, size_t bitWidth) { |
550 | 0 | char *dataPos = |
551 | 0 | const_cast<char *>(reinterpret_cast<const char *>(value.getRawData())); |
552 | 0 | if (llvm::support::endian::system_endianness() == |
553 | 0 | llvm::support::endianness::big) |
554 | 0 | dataPos = dataPos + 8 - llvm::divideCeil(bitWidth, CHAR_BIT); |
555 | 0 | return dataPos; |
556 | 0 | } |
557 | | |
558 | | /// Read APInt `value` from appropriate position. |
559 | 0 | static void readAPInt(APInt &value, size_t bitWidth, char *outData) { |
560 | 0 | char *dataPos = getAPIntDataPos(value, bitWidth); |
561 | 0 | std::copy_n(dataPos, llvm::divideCeil(bitWidth, CHAR_BIT), outData); |
562 | 0 | } |
563 | | |
564 | | /// Write `inData` to appropriate position of APInt `value`. |
565 | 0 | static void writeAPInt(const char *inData, size_t bitWidth, APInt &value) { |
566 | 0 | char *dataPos = getAPIntDataPos(value, bitWidth); |
567 | 0 | std::copy_n(inData, llvm::divideCeil(bitWidth, CHAR_BIT), dataPos); |
568 | 0 | } |
569 | | |
570 | | /// Writes value to the bit position `bitPos` in array `rawData`. |
571 | 0 | static void writeBits(char *rawData, size_t bitPos, APInt value) { |
572 | 0 | size_t bitWidth = value.getBitWidth(); |
573 | 0 |
|
574 | 0 | // If the bitwidth is 1 we just toggle the specific bit. |
575 | 0 | if (bitWidth == 1) |
576 | 0 | return setBit(rawData, bitPos, value.isOneValue()); |
577 | 0 | |
578 | 0 | // Otherwise, the bit position is guaranteed to be byte aligned. |
579 | 0 | assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned"); |
580 | 0 | readAPInt(value, bitWidth, rawData + (bitPos / CHAR_BIT)); |
581 | 0 | } |
582 | | |
583 | | /// Reads the next `bitWidth` bits from the bit position `bitPos` in array |
584 | | /// `rawData`. |
585 | 0 | static APInt readBits(const char *rawData, size_t bitPos, size_t bitWidth) { |
586 | 0 | // Handle a boolean bit position. |
587 | 0 | if (bitWidth == 1) |
588 | 0 | return APInt(1, getBit(rawData, bitPos) ? 1 : 0); |
589 | 0 |
|
590 | 0 | // Otherwise, the bit position must be 8-bit aligned. |
591 | 0 | assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned"); |
592 | 0 | APInt result(bitWidth, 0); |
593 | 0 | writeAPInt(rawData + (bitPos / CHAR_BIT), bitWidth, result); |
594 | 0 | return result; |
595 | 0 | } |
596 | | |
597 | | /// Returns if 'values' corresponds to a splat, i.e. one element, or has the |
598 | | /// same element count as 'type'. |
599 | | template <typename Values> |
600 | 0 | static bool hasSameElementsOrSplat(ShapedType type, const Values &values) { |
601 | 0 | return (values.size() == 1) || |
602 | 0 | (type.getNumElements() == static_cast<int64_t>(values.size())); |
603 | 0 | } Unexecuted instantiation: Attributes.cpp:_ZL22hasSameElementsOrSplatIN4llvm8ArrayRefIN4mlir9AttributeEEEEbNS2_10ShapedTypeERKT_ Unexecuted instantiation: Attributes.cpp:_ZL22hasSameElementsOrSplatIN4llvm8ArrayRefIbEEEbN4mlir10ShapedTypeERKT_ Unexecuted instantiation: Attributes.cpp:_ZL22hasSameElementsOrSplatIN4llvm8ArrayRefINS0_5APIntEEEEbN4mlir10ShapedTypeERKT_ Unexecuted instantiation: Attributes.cpp:_ZL22hasSameElementsOrSplatIN4llvm8ArrayRefISt7complexINS0_5APIntEEEEEbN4mlir10ShapedTypeERKT_ Unexecuted instantiation: Attributes.cpp:_ZL22hasSameElementsOrSplatIN4llvm8ArrayRefINS0_7APFloatEEEEbN4mlir10ShapedTypeERKT_ Unexecuted instantiation: Attributes.cpp:_ZL22hasSameElementsOrSplatIN4llvm8ArrayRefISt7complexINS0_7APFloatEEEEEbN4mlir10ShapedTypeERKT_ |
604 | | |
605 | | //===----------------------------------------------------------------------===// |
606 | | // DenseElementAttr Iterators |
607 | | //===----------------------------------------------------------------------===// |
608 | | |
609 | | //===----------------------------------------------------------------------===// |
610 | | // AttributeElementIterator |
611 | | |
612 | | DenseElementsAttr::AttributeElementIterator::AttributeElementIterator( |
613 | | DenseElementsAttr attr, size_t index) |
614 | | : llvm::indexed_accessor_iterator<AttributeElementIterator, const void *, |
615 | | Attribute, Attribute, Attribute>( |
616 | 0 | attr.getAsOpaquePointer(), index) {} |
617 | | |
618 | 0 | Attribute DenseElementsAttr::AttributeElementIterator::operator*() const { |
619 | 0 | auto owner = getFromOpaquePointer(base).cast<DenseElementsAttr>(); |
620 | 0 | Type eltTy = owner.getType().getElementType(); |
621 | 0 | if (auto intEltTy = eltTy.dyn_cast<IntegerType>()) { |
622 | 0 | if (intEltTy.getWidth() == 1) |
623 | 0 | return BoolAttr::get((*IntElementIterator(owner, index)).isOneValue(), |
624 | 0 | owner.getContext()); |
625 | 0 | return IntegerAttr::get(eltTy, *IntElementIterator(owner, index)); |
626 | 0 | } |
627 | 0 | if (eltTy.isa<IndexType>()) |
628 | 0 | return IntegerAttr::get(eltTy, *IntElementIterator(owner, index)); |
629 | 0 | if (auto floatEltTy = eltTy.dyn_cast<FloatType>()) { |
630 | 0 | IntElementIterator intIt(owner, index); |
631 | 0 | FloatElementIterator floatIt(floatEltTy.getFloatSemantics(), intIt); |
632 | 0 | return FloatAttr::get(eltTy, *floatIt); |
633 | 0 | } |
634 | 0 | if (owner.isa<DenseStringElementsAttr>()) { |
635 | 0 | ArrayRef<StringRef> vals = owner.getRawStringData(); |
636 | 0 | return StringAttr::get(owner.isSplat() ? vals.front() : vals[index], eltTy); |
637 | 0 | } |
638 | 0 | llvm_unreachable("unexpected element type"); |
639 | 0 | } |
640 | | |
641 | | //===----------------------------------------------------------------------===// |
642 | | // BoolElementIterator |
643 | | |
644 | | DenseElementsAttr::BoolElementIterator::BoolElementIterator( |
645 | | DenseElementsAttr attr, size_t dataIndex) |
646 | | : DenseElementIndexedIteratorImpl<BoolElementIterator, bool, bool, bool>( |
647 | 0 | attr.getRawData().data(), attr.isSplat(), dataIndex) {} |
648 | | |
649 | 0 | bool DenseElementsAttr::BoolElementIterator::operator*() const { |
650 | 0 | return getBit(getData(), getDataIndex()); |
651 | 0 | } |
652 | | |
653 | | //===----------------------------------------------------------------------===// |
654 | | // IntElementIterator |
655 | | |
656 | | DenseElementsAttr::IntElementIterator::IntElementIterator( |
657 | | DenseElementsAttr attr, size_t dataIndex) |
658 | | : DenseElementIndexedIteratorImpl<IntElementIterator, APInt, APInt, APInt>( |
659 | | attr.getRawData().data(), attr.isSplat(), dataIndex), |
660 | 0 | bitWidth(getDenseElementBitWidth(attr.getType().getElementType())) {} |
661 | | |
662 | 0 | APInt DenseElementsAttr::IntElementIterator::operator*() const { |
663 | 0 | return readBits(getData(), |
664 | 0 | getDataIndex() * getDenseElementStorageWidth(bitWidth), |
665 | 0 | bitWidth); |
666 | 0 | } |
667 | | |
668 | | //===----------------------------------------------------------------------===// |
669 | | // ComplexIntElementIterator |
670 | | |
671 | | DenseElementsAttr::ComplexIntElementIterator::ComplexIntElementIterator( |
672 | | DenseElementsAttr attr, size_t dataIndex) |
673 | | : DenseElementIndexedIteratorImpl<ComplexIntElementIterator, |
674 | | std::complex<APInt>, std::complex<APInt>, |
675 | | std::complex<APInt>>( |
676 | 0 | attr.getRawData().data(), attr.isSplat(), dataIndex) { |
677 | 0 | auto complexType = attr.getType().getElementType().cast<ComplexType>(); |
678 | 0 | bitWidth = getDenseElementBitWidth(complexType.getElementType()); |
679 | 0 | } |
680 | | |
681 | | std::complex<APInt> |
682 | 0 | DenseElementsAttr::ComplexIntElementIterator::operator*() const { |
683 | 0 | size_t storageWidth = getDenseElementStorageWidth(bitWidth); |
684 | 0 | size_t offset = getDataIndex() * storageWidth * 2; |
685 | 0 | return {readBits(getData(), offset, bitWidth), |
686 | 0 | readBits(getData(), offset + storageWidth, bitWidth)}; |
687 | 0 | } |
688 | | |
689 | | //===----------------------------------------------------------------------===// |
690 | | // FloatElementIterator |
691 | | |
692 | | DenseElementsAttr::FloatElementIterator::FloatElementIterator( |
693 | | const llvm::fltSemantics &smt, IntElementIterator it) |
694 | | : llvm::mapped_iterator<IntElementIterator, |
695 | | std::function<APFloat(const APInt &)>>( |
696 | 0 | it, [&](const APInt &val) { return APFloat(smt, val); }) {} |
697 | | |
698 | | //===----------------------------------------------------------------------===// |
699 | | // ComplexFloatElementIterator |
700 | | |
701 | | DenseElementsAttr::ComplexFloatElementIterator::ComplexFloatElementIterator( |
702 | | const llvm::fltSemantics &smt, ComplexIntElementIterator it) |
703 | | : llvm::mapped_iterator< |
704 | | ComplexIntElementIterator, |
705 | | std::function<std::complex<APFloat>(const std::complex<APInt> &)>>( |
706 | 0 | it, [&](const std::complex<APInt> &val) -> std::complex<APFloat> { |
707 | 0 | return {APFloat(smt, val.real()), APFloat(smt, val.imag())}; |
708 | 0 | }) {} |
709 | | |
710 | | //===----------------------------------------------------------------------===// |
711 | | // DenseElementsAttr |
712 | | //===----------------------------------------------------------------------===// |
713 | | |
714 | | DenseElementsAttr DenseElementsAttr::get(ShapedType type, |
715 | 0 | ArrayRef<Attribute> values) { |
716 | 0 | assert(hasSameElementsOrSplat(type, values)); |
717 | 0 |
|
718 | 0 | // If the element type is not based on int/float/index, assume it is a string |
719 | 0 | // type. |
720 | 0 | auto eltType = type.getElementType(); |
721 | 0 | if (!type.getElementType().isIntOrIndexOrFloat()) { |
722 | 0 | SmallVector<StringRef, 8> stringValues; |
723 | 0 | stringValues.reserve(values.size()); |
724 | 0 | for (Attribute attr : values) { |
725 | 0 | assert(attr.isa<StringAttr>() && |
726 | 0 | "expected string value for non integer/index/float element"); |
727 | 0 | stringValues.push_back(attr.cast<StringAttr>().getValue()); |
728 | 0 | } |
729 | 0 | return get(type, stringValues); |
730 | 0 | } |
731 | 0 |
|
732 | 0 | // Otherwise, get the raw storage width to use for the allocation. |
733 | 0 | size_t bitWidth = getDenseElementBitWidth(eltType); |
734 | 0 | size_t storageBitWidth = getDenseElementStorageWidth(bitWidth); |
735 | 0 |
|
736 | 0 | // Compress the attribute values into a character buffer. |
737 | 0 | SmallVector<char, 8> data(llvm::divideCeil(storageBitWidth, CHAR_BIT) * |
738 | 0 | values.size()); |
739 | 0 | APInt intVal; |
740 | 0 | for (unsigned i = 0, e = values.size(); i < e; ++i) { |
741 | 0 | assert(eltType == values[i].getType() && |
742 | 0 | "expected attribute value to have element type"); |
743 | 0 |
|
744 | 0 | switch (eltType.getKind()) { |
745 | 0 | case StandardTypes::BF16: |
746 | 0 | case StandardTypes::F16: |
747 | 0 | case StandardTypes::F32: |
748 | 0 | case StandardTypes::F64: |
749 | 0 | intVal = values[i].cast<FloatAttr>().getValue().bitcastToAPInt(); |
750 | 0 | break; |
751 | 0 | case StandardTypes::Integer: |
752 | 0 | case StandardTypes::Index: |
753 | 0 | intVal = values[i].isa<BoolAttr>() |
754 | 0 | ? APInt(1, values[i].cast<BoolAttr>().getValue() ? 1 : 0) |
755 | 0 | : values[i].cast<IntegerAttr>().getValue(); |
756 | 0 | break; |
757 | 0 | default: |
758 | 0 | llvm_unreachable("unexpected element type"); |
759 | 0 | } |
760 | 0 | assert(intVal.getBitWidth() == bitWidth && |
761 | 0 | "expected value to have same bitwidth as element type"); |
762 | 0 | writeBits(data.data(), i * storageBitWidth, intVal); |
763 | 0 | } |
764 | 0 | return DenseIntOrFPElementsAttr::getRaw(type, data, |
765 | 0 | /*isSplat=*/(values.size() == 1)); |
766 | 0 | } |
767 | | |
768 | | DenseElementsAttr DenseElementsAttr::get(ShapedType type, |
769 | 0 | ArrayRef<bool> values) { |
770 | 0 | assert(hasSameElementsOrSplat(type, values)); |
771 | 0 | assert(type.getElementType().isInteger(1)); |
772 | 0 |
|
773 | 0 | std::vector<char> buff(llvm::divideCeil(values.size(), CHAR_BIT)); |
774 | 0 | for (int i = 0, e = values.size(); i != e; ++i) |
775 | 0 | setBit(buff.data(), i, values[i]); |
776 | 0 | return DenseIntOrFPElementsAttr::getRaw(type, buff, |
777 | 0 | /*isSplat=*/(values.size() == 1)); |
778 | 0 | } |
779 | | |
780 | | DenseElementsAttr DenseElementsAttr::get(ShapedType type, |
781 | 0 | ArrayRef<StringRef> values) { |
782 | 0 | assert(!type.getElementType().isIntOrFloat()); |
783 | 0 | return DenseStringElementsAttr::get(type, values); |
784 | 0 | } |
785 | | |
786 | | /// Constructs a dense integer elements attribute from an array of APInt |
787 | | /// values. Each APInt value is expected to have the same bitwidth as the |
788 | | /// element type of 'type'. |
789 | | DenseElementsAttr DenseElementsAttr::get(ShapedType type, |
790 | 0 | ArrayRef<APInt> values) { |
791 | 0 | assert(type.getElementType().isIntOrIndex()); |
792 | 0 | assert(hasSameElementsOrSplat(type, values)); |
793 | 0 | size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType()); |
794 | 0 | return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values, |
795 | 0 | /*isSplat=*/(values.size() == 1)); |
796 | 0 | } |
797 | | DenseElementsAttr DenseElementsAttr::get(ShapedType type, |
798 | 0 | ArrayRef<std::complex<APInt>> values) { |
799 | 0 | ComplexType complex = type.getElementType().cast<ComplexType>(); |
800 | 0 | assert(complex.getElementType().isa<IntegerType>()); |
801 | 0 | assert(hasSameElementsOrSplat(type, values)); |
802 | 0 | size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2; |
803 | 0 | ArrayRef<APInt> intVals(reinterpret_cast<const APInt *>(values.data()), |
804 | 0 | values.size() * 2); |
805 | 0 | return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, intVals, |
806 | 0 | /*isSplat=*/(values.size() == 1)); |
807 | 0 | } |
808 | | |
809 | | // Constructs a dense float elements attribute from an array of APFloat |
810 | | // values. Each APFloat value is expected to have the same bitwidth as the |
811 | | // element type of 'type'. |
812 | | DenseElementsAttr DenseElementsAttr::get(ShapedType type, |
813 | 0 | ArrayRef<APFloat> values) { |
814 | 0 | assert(type.getElementType().isa<FloatType>()); |
815 | 0 | assert(hasSameElementsOrSplat(type, values)); |
816 | 0 | size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType()); |
817 | 0 | return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values, |
818 | 0 | /*isSplat=*/(values.size() == 1)); |
819 | 0 | } |
820 | | DenseElementsAttr |
821 | | DenseElementsAttr::get(ShapedType type, |
822 | 0 | ArrayRef<std::complex<APFloat>> values) { |
823 | 0 | ComplexType complex = type.getElementType().cast<ComplexType>(); |
824 | 0 | assert(complex.getElementType().isa<FloatType>()); |
825 | 0 | assert(hasSameElementsOrSplat(type, values)); |
826 | 0 | ArrayRef<APFloat> apVals(reinterpret_cast<const APFloat *>(values.data()), |
827 | 0 | values.size() * 2); |
828 | 0 | size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2; |
829 | 0 | return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, apVals, |
830 | 0 | /*isSplat=*/(values.size() == 1)); |
831 | 0 | } |
832 | | |
833 | | /// Construct a dense elements attribute from a raw buffer representing the |
834 | | /// data for this attribute. Users should generally not use this methods as |
835 | | /// the expected buffer format may not be a form the user expects. |
836 | | DenseElementsAttr DenseElementsAttr::getFromRawBuffer(ShapedType type, |
837 | | ArrayRef<char> rawBuffer, |
838 | 0 | bool isSplatBuffer) { |
839 | 0 | return DenseIntOrFPElementsAttr::getRaw(type, rawBuffer, isSplatBuffer); |
840 | 0 | } |
841 | | |
842 | | /// Returns true if the given buffer is a valid raw buffer for the given type. |
843 | | bool DenseElementsAttr::isValidRawBuffer(ShapedType type, |
844 | | ArrayRef<char> rawBuffer, |
845 | 0 | bool &detectedSplat) { |
846 | 0 | size_t storageWidth = getDenseElementStorageWidth(type.getElementType()); |
847 | 0 | size_t rawBufferWidth = rawBuffer.size() * CHAR_BIT; |
848 | 0 |
|
849 | 0 | // Storage width of 1 is special as it is packed by the bit. |
850 | 0 | if (storageWidth == 1) { |
851 | 0 | // Check for a splat, or a buffer equal to the number of elements. |
852 | 0 | if ((detectedSplat = rawBuffer.size() == 1)) |
853 | 0 | return true; |
854 | 0 | return rawBufferWidth == llvm::alignTo<8>(type.getNumElements()); |
855 | 0 | } |
856 | 0 | // All other types are 8-bit aligned. |
857 | 0 | if ((detectedSplat = rawBufferWidth == storageWidth)) |
858 | 0 | return true; |
859 | 0 | return rawBufferWidth == (storageWidth * type.getNumElements()); |
860 | 0 | } |
861 | | |
862 | | /// Check the information for a C++ data type, check if this type is valid for |
863 | | /// the current attribute. This method is used to verify specific type |
864 | | /// invariants that the templatized 'getValues' method cannot. |
865 | | static bool isValidIntOrFloat(Type type, int64_t dataEltSize, bool isInt, |
866 | 0 | bool isSigned) { |
867 | 0 | // Make sure that the data element size is the same as the type element width. |
868 | 0 | if (getDenseElementBitWidth(type) != |
869 | 0 | static_cast<size_t>(dataEltSize * CHAR_BIT)) |
870 | 0 | return false; |
871 | 0 | |
872 | 0 | // Check that the element type is either float or integer or index. |
873 | 0 | if (!isInt) |
874 | 0 | return type.isa<FloatType>(); |
875 | 0 | if (type.isIndex()) |
876 | 0 | return true; |
877 | 0 | |
878 | 0 | auto intType = type.dyn_cast<IntegerType>(); |
879 | 0 | if (!intType) |
880 | 0 | return false; |
881 | 0 | |
882 | 0 | // Make sure signedness semantics is consistent. |
883 | 0 | if (intType.isSignless()) |
884 | 0 | return true; |
885 | 0 | return intType.isSigned() ? isSigned : !isSigned; |
886 | 0 | } |
887 | | |
888 | | /// Defaults down the subclass implementation. |
889 | | DenseElementsAttr DenseElementsAttr::getRawComplex(ShapedType type, |
890 | | ArrayRef<char> data, |
891 | | int64_t dataEltSize, |
892 | 0 | bool isInt, bool isSigned) { |
893 | 0 | return DenseIntOrFPElementsAttr::getRawComplex(type, data, dataEltSize, isInt, |
894 | 0 | isSigned); |
895 | 0 | } |
896 | | DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type, |
897 | | ArrayRef<char> data, |
898 | | int64_t dataEltSize, |
899 | | bool isInt, |
900 | 0 | bool isSigned) { |
901 | 0 | return DenseIntOrFPElementsAttr::getRawIntOrFloat(type, data, dataEltSize, |
902 | 0 | isInt, isSigned); |
903 | 0 | } |
904 | | |
905 | | /// A method used to verify specific type invariants that the templatized 'get' |
906 | | /// method cannot. |
907 | | bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize, bool isInt, |
908 | 0 | bool isSigned) const { |
909 | 0 | return ::isValidIntOrFloat(getType().getElementType(), dataEltSize, isInt, |
910 | 0 | isSigned); |
911 | 0 | } |
912 | | |
913 | | /// Check the information for a C++ data type, check if this type is valid for |
914 | | /// the current attribute. |
915 | | bool DenseElementsAttr::isValidComplex(int64_t dataEltSize, bool isInt, |
916 | 0 | bool isSigned) const { |
917 | 0 | return ::isValidIntOrFloat( |
918 | 0 | getType().getElementType().cast<ComplexType>().getElementType(), |
919 | 0 | dataEltSize / 2, isInt, isSigned); |
920 | 0 | } |
921 | | |
922 | | /// Returns if this attribute corresponds to a splat, i.e. if all element |
923 | | /// values are the same. |
924 | 0 | bool DenseElementsAttr::isSplat() const { |
925 | 0 | return static_cast<DenseElementsAttributeStorage *>(impl)->isSplat; |
926 | 0 | } |
927 | | |
928 | | /// Return the held element values as a range of Attributes. |
929 | | auto DenseElementsAttr::getAttributeValues() const |
930 | 0 | -> llvm::iterator_range<AttributeElementIterator> { |
931 | 0 | return {attr_value_begin(), attr_value_end()}; |
932 | 0 | } |
933 | 0 | auto DenseElementsAttr::attr_value_begin() const -> AttributeElementIterator { |
934 | 0 | return AttributeElementIterator(*this, 0); |
935 | 0 | } |
936 | 0 | auto DenseElementsAttr::attr_value_end() const -> AttributeElementIterator { |
937 | 0 | return AttributeElementIterator(*this, getNumElements()); |
938 | 0 | } |
939 | | |
940 | | /// Return the held element values as a range of bool. The element type of |
941 | | /// this attribute must be of integer type of bitwidth 1. |
942 | | auto DenseElementsAttr::getBoolValues() const |
943 | 0 | -> llvm::iterator_range<BoolElementIterator> { |
944 | 0 | auto eltType = getType().getElementType().dyn_cast<IntegerType>(); |
945 | 0 | assert(eltType && eltType.getWidth() == 1 && "expected i1 integer type"); |
946 | 0 | (void)eltType; |
947 | 0 | return {BoolElementIterator(*this, 0), |
948 | 0 | BoolElementIterator(*this, getNumElements())}; |
949 | 0 | } |
950 | | |
951 | | /// Return the held element values as a range of APInts. The element type of |
952 | | /// this attribute must be of integer type. |
953 | | auto DenseElementsAttr::getIntValues() const |
954 | 0 | -> llvm::iterator_range<IntElementIterator> { |
955 | 0 | assert(getType().getElementType().isIntOrIndex() && "expected integral type"); |
956 | 0 | return {raw_int_begin(), raw_int_end()}; |
957 | 0 | } |
958 | 0 | auto DenseElementsAttr::int_value_begin() const -> IntElementIterator { |
959 | 0 | assert(getType().getElementType().isIntOrIndex() && "expected integral type"); |
960 | 0 | return raw_int_begin(); |
961 | 0 | } |
962 | 0 | auto DenseElementsAttr::int_value_end() const -> IntElementIterator { |
963 | 0 | assert(getType().getElementType().isIntOrIndex() && "expected integral type"); |
964 | 0 | return raw_int_end(); |
965 | 0 | } |
966 | | auto DenseElementsAttr::getComplexIntValues() const |
967 | 0 | -> llvm::iterator_range<ComplexIntElementIterator> { |
968 | 0 | Type eltTy = getType().getElementType().cast<ComplexType>().getElementType(); |
969 | 0 | (void)eltTy; |
970 | 0 | assert(eltTy.isa<IntegerType>() && "expected complex integral type"); |
971 | 0 | return {ComplexIntElementIterator(*this, 0), |
972 | 0 | ComplexIntElementIterator(*this, getNumElements())}; |
973 | 0 | } |
974 | | |
975 | | /// Return the held element values as a range of APFloat. The element type of |
976 | | /// this attribute must be of float type. |
977 | | auto DenseElementsAttr::getFloatValues() const |
978 | 0 | -> llvm::iterator_range<FloatElementIterator> { |
979 | 0 | auto elementType = getType().getElementType().cast<FloatType>(); |
980 | 0 | const auto &elementSemantics = elementType.getFloatSemantics(); |
981 | 0 | return {FloatElementIterator(elementSemantics, raw_int_begin()), |
982 | 0 | FloatElementIterator(elementSemantics, raw_int_end())}; |
983 | 0 | } |
984 | 0 | auto DenseElementsAttr::float_value_begin() const -> FloatElementIterator { |
985 | 0 | return getFloatValues().begin(); |
986 | 0 | } |
987 | 0 | auto DenseElementsAttr::float_value_end() const -> FloatElementIterator { |
988 | 0 | return getFloatValues().end(); |
989 | 0 | } |
990 | | auto DenseElementsAttr::getComplexFloatValues() const |
991 | 0 | -> llvm::iterator_range<ComplexFloatElementIterator> { |
992 | 0 | Type eltTy = getType().getElementType().cast<ComplexType>().getElementType(); |
993 | 0 | assert(eltTy.isa<FloatType>() && "expected complex float type"); |
994 | 0 | const auto &semantics = eltTy.cast<FloatType>().getFloatSemantics(); |
995 | 0 | return {{semantics, {*this, 0}}, |
996 | 0 | {semantics, {*this, static_cast<size_t>(getNumElements())}}}; |
997 | 0 | } |
998 | | |
999 | | /// Return the raw storage data held by this attribute. |
1000 | 0 | ArrayRef<char> DenseElementsAttr::getRawData() const { |
1001 | 0 | return static_cast<DenseIntOrFPElementsAttributeStorage *>(impl)->data; |
1002 | 0 | } |
1003 | | |
1004 | 0 | ArrayRef<StringRef> DenseElementsAttr::getRawStringData() const { |
1005 | 0 | return static_cast<DenseStringElementsAttributeStorage *>(impl)->data; |
1006 | 0 | } |
1007 | | |
1008 | | /// Return a new DenseElementsAttr that has the same data as the current |
1009 | | /// attribute, but has been reshaped to 'newType'. The new type must have the |
1010 | | /// same total number of elements as well as element type. |
1011 | 0 | DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) { |
1012 | 0 | ShapedType curType = getType(); |
1013 | 0 | if (curType == newType) |
1014 | 0 | return *this; |
1015 | 0 | |
1016 | 0 | (void)curType; |
1017 | 0 | assert(newType.getElementType() == curType.getElementType() && |
1018 | 0 | "expected the same element type"); |
1019 | 0 | assert(newType.getNumElements() == curType.getNumElements() && |
1020 | 0 | "expected the same number of elements"); |
1021 | 0 | return DenseIntOrFPElementsAttr::getRaw(newType, getRawData(), isSplat()); |
1022 | 0 | } |
1023 | | |
1024 | | DenseElementsAttr |
1025 | | DenseElementsAttr::mapValues(Type newElementType, |
1026 | 0 | function_ref<APInt(const APInt &)> mapping) const { |
1027 | 0 | return cast<DenseIntElementsAttr>().mapValues(newElementType, mapping); |
1028 | 0 | } |
1029 | | |
1030 | | DenseElementsAttr DenseElementsAttr::mapValues( |
1031 | 0 | Type newElementType, function_ref<APInt(const APFloat &)> mapping) const { |
1032 | 0 | return cast<DenseFPElementsAttr>().mapValues(newElementType, mapping); |
1033 | 0 | } |
1034 | | |
1035 | | //===----------------------------------------------------------------------===// |
1036 | | // DenseStringElementsAttr |
1037 | | //===----------------------------------------------------------------------===// |
1038 | | |
1039 | | DenseStringElementsAttr |
1040 | 0 | DenseStringElementsAttr::get(ShapedType type, ArrayRef<StringRef> values) { |
1041 | 0 | return Base::get(type.getContext(), StandardAttributes::DenseStringElements, |
1042 | 0 | type, values, (values.size() == 1)); |
1043 | 0 | } |
1044 | | |
1045 | | //===----------------------------------------------------------------------===// |
1046 | | // DenseIntOrFPElementsAttr |
1047 | | //===----------------------------------------------------------------------===// |
1048 | | |
1049 | | /// Utility method to write a range of APInt values to a buffer. |
1050 | | template <typename APRangeT> |
1051 | | static void writeAPIntsToBuffer(size_t storageWidth, std::vector<char> &data, |
1052 | 0 | APRangeT &&values) { |
1053 | 0 | data.resize(llvm::divideCeil(storageWidth, CHAR_BIT) * llvm::size(values)); |
1054 | 0 | size_t offset = 0; |
1055 | 0 | for (auto it = values.begin(), e = values.end(); it != e; |
1056 | 0 | ++it, offset += storageWidth) { |
1057 | 0 | assert((*it).getBitWidth() <= storageWidth); |
1058 | 0 | writeBits(data.data(), offset, *it); |
1059 | 0 | } |
1060 | 0 | } Unexecuted instantiation: Attributes.cpp:_ZL19writeAPIntsToBufferIN4llvm14iterator_rangeINS0_15mapped_iteratorIPKNS0_7APFloatEZN4mlir24DenseIntOrFPElementsAttr6getRawENS6_10ShapedTypeEmNS0_8ArrayRefIS3_EEbE3$_6NS0_5APIntEEEEEEvmRSt6vectorIcSaIcEEOT_ Unexecuted instantiation: Attributes.cpp:_ZL19writeAPIntsToBufferIRN4llvm8ArrayRefINS0_5APIntEEEEvmRSt6vectorIcSaIcEEOT_ |
1061 | | |
1062 | | /// Constructs a dense elements attribute from an array of raw APFloat values. |
1063 | | /// Each APFloat value is expected to have the same bitwidth as the element |
1064 | | /// type of 'type'. 'type' must be a vector or tensor with static shape. |
1065 | | DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, |
1066 | | size_t storageWidth, |
1067 | | ArrayRef<APFloat> values, |
1068 | 0 | bool isSplat) { |
1069 | 0 | std::vector<char> data; |
1070 | 0 | auto unwrapFloat = [](const APFloat &val) { return val.bitcastToAPInt(); }; |
1071 | 0 | writeAPIntsToBuffer(storageWidth, data, llvm::map_range(values, unwrapFloat)); |
1072 | 0 | return DenseIntOrFPElementsAttr::getRaw(type, data, isSplat); |
1073 | 0 | } |
1074 | | |
1075 | | /// Constructs a dense elements attribute from an array of raw APInt values. |
1076 | | /// Each APInt value is expected to have the same bitwidth as the element type |
1077 | | /// of 'type'. |
1078 | | DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, |
1079 | | size_t storageWidth, |
1080 | | ArrayRef<APInt> values, |
1081 | 0 | bool isSplat) { |
1082 | 0 | std::vector<char> data; |
1083 | 0 | writeAPIntsToBuffer(storageWidth, data, values); |
1084 | 0 | return DenseIntOrFPElementsAttr::getRaw(type, data, isSplat); |
1085 | 0 | } |
1086 | | |
1087 | | DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, |
1088 | | ArrayRef<char> data, |
1089 | 0 | bool isSplat) { |
1090 | 0 | assert((type.isa<RankedTensorType>() || type.isa<VectorType>()) && |
1091 | 0 | "type must be ranked tensor or vector"); |
1092 | 0 | assert(type.hasStaticShape() && "type must have static shape"); |
1093 | 0 | return Base::get(type.getContext(), StandardAttributes::DenseIntOrFPElements, |
1094 | 0 | type, data, isSplat); |
1095 | 0 | } |
1096 | | |
1097 | | /// Overload of the raw 'get' method that asserts that the given type is of |
1098 | | /// complex type. This method is used to verify type invariants that the |
1099 | | /// templatized 'get' method cannot. |
1100 | | DenseElementsAttr DenseIntOrFPElementsAttr::getRawComplex(ShapedType type, |
1101 | | ArrayRef<char> data, |
1102 | | int64_t dataEltSize, |
1103 | | bool isInt, |
1104 | 0 | bool isSigned) { |
1105 | 0 | assert(::isValidIntOrFloat( |
1106 | 0 | type.getElementType().cast<ComplexType>().getElementType(), |
1107 | 0 | dataEltSize / 2, isInt, isSigned)); |
1108 | 0 |
|
1109 | 0 | int64_t numElements = data.size() / dataEltSize; |
1110 | 0 | assert(numElements == 1 || numElements == type.getNumElements()); |
1111 | 0 | return getRaw(type, data, /*isSplat=*/numElements == 1); |
1112 | 0 | } |
1113 | | |
1114 | | /// Overload of the 'getRaw' method that asserts that the given type is of |
1115 | | /// integer type. This method is used to verify type invariants that the |
1116 | | /// templatized 'get' method cannot. |
1117 | | DenseElementsAttr |
1118 | | DenseIntOrFPElementsAttr::getRawIntOrFloat(ShapedType type, ArrayRef<char> data, |
1119 | | int64_t dataEltSize, bool isInt, |
1120 | 0 | bool isSigned) { |
1121 | 0 | assert( |
1122 | 0 | ::isValidIntOrFloat(type.getElementType(), dataEltSize, isInt, isSigned)); |
1123 | 0 |
|
1124 | 0 | int64_t numElements = data.size() / dataEltSize; |
1125 | 0 | assert(numElements == 1 || numElements == type.getNumElements()); |
1126 | 0 | return getRaw(type, data, /*isSplat=*/numElements == 1); |
1127 | 0 | } |
1128 | | |
1129 | | //===----------------------------------------------------------------------===// |
1130 | | // DenseFPElementsAttr |
1131 | | //===----------------------------------------------------------------------===// |
1132 | | |
1133 | | template <typename Fn, typename Attr> |
1134 | | static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType, |
1135 | | Type newElementType, |
1136 | 0 | llvm::SmallVectorImpl<char> &data) { |
1137 | 0 | size_t bitWidth = getDenseElementBitWidth(newElementType); |
1138 | 0 | size_t storageBitWidth = getDenseElementStorageWidth(bitWidth); |
1139 | 0 |
|
1140 | 0 | ShapedType newArrayType; |
1141 | 0 | if (inType.isa<RankedTensorType>()) |
1142 | 0 | newArrayType = RankedTensorType::get(inType.getShape(), newElementType); |
1143 | 0 | else if (inType.isa<UnrankedTensorType>()) |
1144 | 0 | newArrayType = RankedTensorType::get(inType.getShape(), newElementType); |
1145 | 0 | else if (inType.isa<VectorType>()) |
1146 | 0 | newArrayType = VectorType::get(inType.getShape(), newElementType); |
1147 | 0 | else |
1148 | 0 | assert(newArrayType && "Unhandled tensor type"); |
1149 | 0 |
|
1150 | 0 | size_t numRawElements = attr.isSplat() ? 1 : newArrayType.getNumElements(); |
1151 | 0 | data.resize(llvm::divideCeil(storageBitWidth, CHAR_BIT) * numRawElements); |
1152 | 0 |
|
1153 | 0 | // Functor used to process a single element value of the attribute. |
1154 | 0 | auto processElt = [&](decltype(*attr.begin()) value, size_t index) { |
1155 | 0 | auto newInt = mapping(value); |
1156 | 0 | assert(newInt.getBitWidth() == bitWidth); |
1157 | 0 | writeBits(data.data(), index * storageBitWidth, newInt); |
1158 | 0 | }; Unexecuted instantiation: Attributes.cpp:_ZZL13mappingHelperIN4llvm12function_refIFNS0_5APIntERKNS0_7APFloatEEEEKN4mlir19DenseFPElementsAttrEENS8_10ShapedTypeET_RT0_SB_NS8_4TypeERNS0_15SmallVectorImplIcEEENKUlS3_mE_clES3_m Unexecuted instantiation: Attributes.cpp:_ZZL13mappingHelperIN4llvm12function_refIFNS0_5APIntERKS2_EEEKN4mlir20DenseIntElementsAttrEENS7_10ShapedTypeET_RT0_SA_NS7_4TypeERNS0_15SmallVectorImplIcEEENKUlS2_mE_clES2_m |
1159 | 0 |
|
1160 | 0 | // Check for the splat case. |
1161 | 0 | if (attr.isSplat()) { |
1162 | 0 | processElt(*attr.begin(), /*index=*/0); |
1163 | 0 | return newArrayType; |
1164 | 0 | } |
1165 | 0 | |
1166 | 0 | // Otherwise, process all of the element values. |
1167 | 0 | uint64_t elementIdx = 0; |
1168 | 0 | for (auto value : attr) |
1169 | 0 | processElt(value, elementIdx++); |
1170 | 0 | return newArrayType; |
1171 | 0 | } Unexecuted instantiation: Attributes.cpp:_ZL13mappingHelperIN4llvm12function_refIFNS0_5APIntERKNS0_7APFloatEEEEKN4mlir19DenseFPElementsAttrEENS8_10ShapedTypeET_RT0_SB_NS8_4TypeERNS0_15SmallVectorImplIcEE Unexecuted instantiation: Attributes.cpp:_ZL13mappingHelperIN4llvm12function_refIFNS0_5APIntERKS2_EEEKN4mlir20DenseIntElementsAttrEENS7_10ShapedTypeET_RT0_SA_NS7_4TypeERNS0_15SmallVectorImplIcEE |
1172 | | |
1173 | | DenseElementsAttr DenseFPElementsAttr::mapValues( |
1174 | 0 | Type newElementType, function_ref<APInt(const APFloat &)> mapping) const { |
1175 | 0 | llvm::SmallVector<char, 8> elementData; |
1176 | 0 | auto newArrayType = |
1177 | 0 | mappingHelper(mapping, *this, getType(), newElementType, elementData); |
1178 | 0 |
|
1179 | 0 | return getRaw(newArrayType, elementData, isSplat()); |
1180 | 0 | } |
1181 | | |
1182 | | /// Method for supporting type inquiry through isa, cast and dyn_cast. |
1183 | 0 | bool DenseFPElementsAttr::classof(Attribute attr) { |
1184 | 0 | return attr.isa<DenseElementsAttr>() && |
1185 | 0 | attr.getType().cast<ShapedType>().getElementType().isa<FloatType>(); |
1186 | 0 | } |
1187 | | |
1188 | | //===----------------------------------------------------------------------===// |
1189 | | // DenseIntElementsAttr |
1190 | | //===----------------------------------------------------------------------===// |
1191 | | |
1192 | | DenseElementsAttr DenseIntElementsAttr::mapValues( |
1193 | 0 | Type newElementType, function_ref<APInt(const APInt &)> mapping) const { |
1194 | 0 | llvm::SmallVector<char, 8> elementData; |
1195 | 0 | auto newArrayType = |
1196 | 0 | mappingHelper(mapping, *this, getType(), newElementType, elementData); |
1197 | 0 |
|
1198 | 0 | return getRaw(newArrayType, elementData, isSplat()); |
1199 | 0 | } |
1200 | | |
1201 | | /// Method for supporting type inquiry through isa, cast and dyn_cast. |
1202 | 0 | bool DenseIntElementsAttr::classof(Attribute attr) { |
1203 | 0 | return attr.isa<DenseElementsAttr>() && |
1204 | 0 | attr.getType().cast<ShapedType>().getElementType().isIntOrIndex(); |
1205 | 0 | } |
1206 | | |
1207 | | //===----------------------------------------------------------------------===// |
1208 | | // OpaqueElementsAttr |
1209 | | //===----------------------------------------------------------------------===// |
1210 | | |
1211 | | OpaqueElementsAttr OpaqueElementsAttr::get(Dialect *dialect, ShapedType type, |
1212 | 0 | StringRef bytes) { |
1213 | 0 | assert(TensorType::isValidElementType(type.getElementType()) && |
1214 | 0 | "Input element type should be a valid tensor element type"); |
1215 | 0 | return Base::get(type.getContext(), StandardAttributes::OpaqueElements, type, |
1216 | 0 | dialect, bytes); |
1217 | 0 | } |
1218 | | |
1219 | 0 | StringRef OpaqueElementsAttr::getValue() const { return getImpl()->bytes; } |
1220 | | |
1221 | | /// Return the value at the given index. If index does not refer to a valid |
1222 | | /// element, then a null attribute is returned. |
1223 | 0 | Attribute OpaqueElementsAttr::getValue(ArrayRef<uint64_t> index) const { |
1224 | 0 | assert(isValidIndex(index) && "expected valid multi-dimensional index"); |
1225 | 0 | if (Dialect *dialect = getDialect()) |
1226 | 0 | return dialect->extractElementHook(*this, index); |
1227 | 0 | return Attribute(); |
1228 | 0 | } |
1229 | | |
1230 | 0 | Dialect *OpaqueElementsAttr::getDialect() const { return getImpl()->dialect; } |
1231 | | |
1232 | 0 | bool OpaqueElementsAttr::decode(ElementsAttr &result) { |
1233 | 0 | if (auto *d = getDialect()) |
1234 | 0 | return d->decodeHook(*this, result); |
1235 | 0 | return true; |
1236 | 0 | } |
1237 | | |
1238 | | //===----------------------------------------------------------------------===// |
1239 | | // SparseElementsAttr |
1240 | | //===----------------------------------------------------------------------===// |
1241 | | |
1242 | | SparseElementsAttr SparseElementsAttr::get(ShapedType type, |
1243 | | DenseElementsAttr indices, |
1244 | 0 | DenseElementsAttr values) { |
1245 | 0 | assert(indices.getType().getElementType().isInteger(64) && |
1246 | 0 | "expected sparse indices to be 64-bit integer values"); |
1247 | 0 | assert((type.isa<RankedTensorType>() || type.isa<VectorType>()) && |
1248 | 0 | "type must be ranked tensor or vector"); |
1249 | 0 | assert(type.hasStaticShape() && "type must have static shape"); |
1250 | 0 | return Base::get(type.getContext(), StandardAttributes::SparseElements, type, |
1251 | 0 | indices.cast<DenseIntElementsAttr>(), values); |
1252 | 0 | } |
1253 | | |
1254 | 0 | DenseIntElementsAttr SparseElementsAttr::getIndices() const { |
1255 | 0 | return getImpl()->indices; |
1256 | 0 | } |
1257 | | |
1258 | 0 | DenseElementsAttr SparseElementsAttr::getValues() const { |
1259 | 0 | return getImpl()->values; |
1260 | 0 | } |
1261 | | |
1262 | | /// Return the value of the element at the given index. |
1263 | 0 | Attribute SparseElementsAttr::getValue(ArrayRef<uint64_t> index) const { |
1264 | 0 | assert(isValidIndex(index) && "expected valid multi-dimensional index"); |
1265 | 0 | auto type = getType(); |
1266 | 0 |
|
1267 | 0 | // The sparse indices are 64-bit integers, so we can reinterpret the raw data |
1268 | 0 | // as a 1-D index array. |
1269 | 0 | auto sparseIndices = getIndices(); |
1270 | 0 | auto sparseIndexValues = sparseIndices.getValues<uint64_t>(); |
1271 | 0 |
|
1272 | 0 | // Check to see if the indices are a splat. |
1273 | 0 | if (sparseIndices.isSplat()) { |
1274 | 0 | // If the index is also not a splat of the index value, we know that the |
1275 | 0 | // value is zero. |
1276 | 0 | auto splatIndex = *sparseIndexValues.begin(); |
1277 | 0 | if (llvm::any_of(index, [=](uint64_t i) { return i != splatIndex; })) |
1278 | 0 | return getZeroAttr(); |
1279 | 0 | |
1280 | 0 | // If the indices are a splat, we also expect the values to be a splat. |
1281 | 0 | assert(getValues().isSplat() && "expected splat values"); |
1282 | 0 | return getValues().getSplatValue(); |
1283 | 0 | } |
1284 | 0 | |
1285 | 0 | // Build a mapping between known indices and the offset of the stored element. |
1286 | 0 | llvm::SmallDenseMap<llvm::ArrayRef<uint64_t>, size_t> mappedIndices; |
1287 | 0 | auto numSparseIndices = sparseIndices.getType().getDimSize(0); |
1288 | 0 | size_t rank = type.getRank(); |
1289 | 0 | for (size_t i = 0, e = numSparseIndices; i != e; ++i) |
1290 | 0 | mappedIndices.try_emplace( |
1291 | 0 | {&*std::next(sparseIndexValues.begin(), i * rank), rank}, i); |
1292 | 0 |
|
1293 | 0 | // Look for the provided index key within the mapped indices. If the provided |
1294 | 0 | // index is not found, then return a zero attribute. |
1295 | 0 | auto it = mappedIndices.find(index); |
1296 | 0 | if (it == mappedIndices.end()) |
1297 | 0 | return getZeroAttr(); |
1298 | 0 | |
1299 | 0 | // Otherwise, return the held sparse value element. |
1300 | 0 | return getValues().getValue(it->second); |
1301 | 0 | } |
1302 | | |
1303 | | /// Get a zero APFloat for the given sparse attribute. |
1304 | 0 | APFloat SparseElementsAttr::getZeroAPFloat() const { |
1305 | 0 | auto eltType = getType().getElementType().cast<FloatType>(); |
1306 | 0 | return APFloat(eltType.getFloatSemantics()); |
1307 | 0 | } |
1308 | | |
1309 | | /// Get a zero APInt for the given sparse attribute. |
1310 | 0 | APInt SparseElementsAttr::getZeroAPInt() const { |
1311 | 0 | auto eltType = getType().getElementType().cast<IntegerType>(); |
1312 | 0 | return APInt::getNullValue(eltType.getWidth()); |
1313 | 0 | } |
1314 | | |
1315 | | /// Get a zero attribute for the given attribute type. |
1316 | 0 | Attribute SparseElementsAttr::getZeroAttr() const { |
1317 | 0 | auto eltType = getType().getElementType(); |
1318 | 0 |
|
1319 | 0 | // Handle floating point elements. |
1320 | 0 | if (eltType.isa<FloatType>()) |
1321 | 0 | return FloatAttr::get(eltType, 0); |
1322 | 0 | |
1323 | 0 | // Otherwise, this is an integer. |
1324 | 0 | auto intEltTy = eltType.cast<IntegerType>(); |
1325 | 0 | if (intEltTy.getWidth() == 1) |
1326 | 0 | return BoolAttr::get(false, eltType.getContext()); |
1327 | 0 | return IntegerAttr::get(eltType, 0); |
1328 | 0 | } |
1329 | | |
1330 | | /// Flatten, and return, all of the sparse indices in this attribute in |
1331 | | /// row-major order. |
1332 | 0 | std::vector<ptrdiff_t> SparseElementsAttr::getFlattenedSparseIndices() const { |
1333 | 0 | std::vector<ptrdiff_t> flatSparseIndices; |
1334 | 0 |
|
1335 | 0 | // The sparse indices are 64-bit integers, so we can reinterpret the raw data |
1336 | 0 | // as a 1-D index array. |
1337 | 0 | auto sparseIndices = getIndices(); |
1338 | 0 | auto sparseIndexValues = sparseIndices.getValues<uint64_t>(); |
1339 | 0 | if (sparseIndices.isSplat()) { |
1340 | 0 | SmallVector<uint64_t, 8> indices(getType().getRank(), |
1341 | 0 | *sparseIndexValues.begin()); |
1342 | 0 | flatSparseIndices.push_back(getFlattenedIndex(indices)); |
1343 | 0 | return flatSparseIndices; |
1344 | 0 | } |
1345 | 0 | |
1346 | 0 | // Otherwise, reinterpret each index as an ArrayRef when flattening. |
1347 | 0 | auto numSparseIndices = sparseIndices.getType().getDimSize(0); |
1348 | 0 | size_t rank = getType().getRank(); |
1349 | 0 | for (size_t i = 0, e = numSparseIndices; i != e; ++i) |
1350 | 0 | flatSparseIndices.push_back(getFlattenedIndex( |
1351 | 0 | {&*std::next(sparseIndexValues.begin(), i * rank), rank})); |
1352 | 0 | return flatSparseIndices; |
1353 | 0 | } |
1354 | | |
1355 | | //===----------------------------------------------------------------------===// |
1356 | | // MutableDictionaryAttr |
1357 | | //===----------------------------------------------------------------------===// |
1358 | | |
1359 | | MutableDictionaryAttr::MutableDictionaryAttr( |
1360 | 0 | ArrayRef<NamedAttribute> attributes) { |
1361 | 0 | setAttrs(attributes); |
1362 | 0 | } |
1363 | | |
1364 | | /// Return the underlying dictionary attribute. |
1365 | | DictionaryAttr |
1366 | 0 | MutableDictionaryAttr::getDictionary(MLIRContext *context) const { |
1367 | 0 | // Construct empty DictionaryAttr if needed. |
1368 | 0 | if (!attrs) |
1369 | 0 | return DictionaryAttr::get({}, context); |
1370 | 0 | return attrs; |
1371 | 0 | } |
1372 | | |
1373 | 0 | ArrayRef<NamedAttribute> MutableDictionaryAttr::getAttrs() const { |
1374 | 0 | return attrs ? attrs.getValue() : llvm::None; |
1375 | 0 | } |
1376 | | |
1377 | | /// Replace the held attributes with ones provided in 'newAttrs'. |
1378 | 0 | void MutableDictionaryAttr::setAttrs(ArrayRef<NamedAttribute> attributes) { |
1379 | 0 | // Don't create an attribute list if there are no attributes. |
1380 | 0 | if (attributes.empty()) |
1381 | 0 | attrs = nullptr; |
1382 | 0 | else |
1383 | 0 | attrs = DictionaryAttr::get(attributes, attributes[0].second.getContext()); |
1384 | 0 | } |
1385 | | |
1386 | | /// Return the specified attribute if present, null otherwise. |
1387 | 0 | Attribute MutableDictionaryAttr::get(StringRef name) const { |
1388 | 0 | return attrs ? attrs.get(name) : nullptr; |
1389 | 0 | } |
1390 | | |
1391 | | /// Return the specified attribute if present, null otherwise. |
1392 | 0 | Attribute MutableDictionaryAttr::get(Identifier name) const { |
1393 | 0 | return attrs ? attrs.get(name) : nullptr; |
1394 | 0 | } |
1395 | | |
1396 | | /// Return the specified named attribute if present, None otherwise. |
1397 | 0 | Optional<NamedAttribute> MutableDictionaryAttr::getNamed(StringRef name) const { |
1398 | 0 | return attrs ? attrs.getNamed(name) : Optional<NamedAttribute>(); |
1399 | 0 | } |
1400 | | Optional<NamedAttribute> |
1401 | 0 | MutableDictionaryAttr::getNamed(Identifier name) const { |
1402 | 0 | return attrs ? attrs.getNamed(name) : Optional<NamedAttribute>(); |
1403 | 0 | } |
1404 | | |
1405 | | /// If the an attribute exists with the specified name, change it to the new |
1406 | | /// value. Otherwise, add a new attribute with the specified name/value. |
1407 | 0 | void MutableDictionaryAttr::set(Identifier name, Attribute value) { |
1408 | 0 | assert(value && "attributes may never be null"); |
1409 | 0 |
|
1410 | 0 | // Look for an existing value for the given name, and set it in-place. |
1411 | 0 | ArrayRef<NamedAttribute> values = getAttrs(); |
1412 | 0 | const auto *it = llvm::find_if( |
1413 | 0 | values, [name](NamedAttribute attr) { return attr.first == name; }); |
1414 | 0 | if (it != values.end()) { |
1415 | 0 | // Bail out early if the value is the same as what we already have. |
1416 | 0 | if (it->second == value) |
1417 | 0 | return; |
1418 | 0 | |
1419 | 0 | SmallVector<NamedAttribute, 8> newAttrs(values.begin(), values.end()); |
1420 | 0 | newAttrs[it - values.begin()].second = value; |
1421 | 0 | attrs = DictionaryAttr::getWithSorted(newAttrs, value.getContext()); |
1422 | 0 | return; |
1423 | 0 | } |
1424 | 0 | |
1425 | 0 | // Otherwise, insert the new attribute into its sorted position. |
1426 | 0 | it = llvm::lower_bound(values, name); |
1427 | 0 | SmallVector<NamedAttribute, 8> newAttrs; |
1428 | 0 | newAttrs.reserve(values.size() + 1); |
1429 | 0 | newAttrs.append(values.begin(), it); |
1430 | 0 | newAttrs.push_back({name, value}); |
1431 | 0 | newAttrs.append(it, values.end()); |
1432 | 0 | attrs = DictionaryAttr::getWithSorted(newAttrs, value.getContext()); |
1433 | 0 | } |
1434 | | |
1435 | | /// Remove the attribute with the specified name if it exists. The return |
1436 | | /// value indicates whether the attribute was present or not. |
1437 | 0 | auto MutableDictionaryAttr::remove(Identifier name) -> RemoveResult { |
1438 | 0 | auto origAttrs = getAttrs(); |
1439 | 0 | for (unsigned i = 0, e = origAttrs.size(); i != e; ++i) { |
1440 | 0 | if (origAttrs[i].first == name) { |
1441 | 0 | // Handle the simple case of removing the only attribute in the list. |
1442 | 0 | if (e == 1) { |
1443 | 0 | attrs = nullptr; |
1444 | 0 | return RemoveResult::Removed; |
1445 | 0 | } |
1446 | 0 | |
1447 | 0 | SmallVector<NamedAttribute, 8> newAttrs; |
1448 | 0 | newAttrs.reserve(origAttrs.size() - 1); |
1449 | 0 | newAttrs.append(origAttrs.begin(), origAttrs.begin() + i); |
1450 | 0 | newAttrs.append(origAttrs.begin() + i + 1, origAttrs.end()); |
1451 | 0 | attrs = DictionaryAttr::getWithSorted(newAttrs, |
1452 | 0 | newAttrs[0].second.getContext()); |
1453 | 0 | return RemoveResult::Removed; |
1454 | 0 | } |
1455 | 0 | } |
1456 | 0 | return RemoveResult::NotFound; |
1457 | 0 | } |
1458 | | |
1459 | 0 | bool mlir::operator<(const NamedAttribute &lhs, const NamedAttribute &rhs) { |
1460 | 0 | return strcmp(lhs.first.data(), rhs.first.data()) < 0; |
1461 | 0 | } |
1462 | 0 | bool mlir::operator<(const NamedAttribute &lhs, StringRef rhs) { |
1463 | 0 | // This is correct even when attr.first.data()[name.size()] is not a zero |
1464 | 0 | // string terminator, because we only care about a less than comparison. |
1465 | 0 | // This can't use memcmp, because it doesn't guarantee that it will stop |
1466 | 0 | // reading both buffers if one is shorter than the other, even if there is |
1467 | 0 | // a difference. |
1468 | 0 | return strncmp(lhs.first.data(), rhs.data(), rhs.size()) < 0; |
1469 | 0 | } |