/home/arjun/llvm-project/mlir/lib/IR/AttributeDetail.h
Line | Count | Source (jump to first uncovered line) |
1 | | //===- AttributeDetail.h - MLIR Attribute details Class ---------*- C++ -*-===// |
2 | | // |
3 | | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | | // See https://llvm.org/LICENSE.txt for license information. |
5 | | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | | // |
7 | | //===----------------------------------------------------------------------===// |
8 | | // |
9 | | // This holds implementation details of Attribute. |
10 | | // |
11 | | //===----------------------------------------------------------------------===// |
12 | | |
13 | | #ifndef ATTRIBUTEDETAIL_H_ |
14 | | #define ATTRIBUTEDETAIL_H_ |
15 | | |
16 | | #include "mlir/IR/AffineMap.h" |
17 | | #include "mlir/IR/Attributes.h" |
18 | | #include "mlir/IR/Identifier.h" |
19 | | #include "mlir/IR/IntegerSet.h" |
20 | | #include "mlir/IR/MLIRContext.h" |
21 | | #include "mlir/IR/StandardTypes.h" |
22 | | #include "mlir/Support/StorageUniquer.h" |
23 | | #include "llvm/ADT/APFloat.h" |
24 | | #include "llvm/ADT/PointerIntPair.h" |
25 | | #include "llvm/Support/TrailingObjects.h" |
26 | | |
27 | | namespace mlir { |
28 | | namespace detail { |
29 | | // An attribute representing a reference to an affine map. |
30 | | struct AffineMapAttributeStorage : public AttributeStorage { |
31 | | using KeyTy = AffineMap; |
32 | | |
33 | | AffineMapAttributeStorage(AffineMap value) |
34 | 0 | : AttributeStorage(IndexType::get(value.getContext())), value(value) {} |
35 | | |
36 | | /// Key equality function. |
37 | 0 | bool operator==(const KeyTy &key) const { return key == value; } |
38 | | |
39 | | /// Construct a new storage instance. |
40 | | static AffineMapAttributeStorage * |
41 | 0 | construct(AttributeStorageAllocator &allocator, KeyTy key) { |
42 | 0 | return new (allocator.allocate<AffineMapAttributeStorage>()) |
43 | 0 | AffineMapAttributeStorage(key); |
44 | 0 | } |
45 | | |
46 | | AffineMap value; |
47 | | }; |
48 | | |
49 | | /// An attribute representing an array of other attributes. |
50 | | struct ArrayAttributeStorage : public AttributeStorage { |
51 | | using KeyTy = ArrayRef<Attribute>; |
52 | | |
53 | 0 | ArrayAttributeStorage(ArrayRef<Attribute> value) : value(value) {} |
54 | | |
55 | | /// Key equality function. |
56 | 0 | bool operator==(const KeyTy &key) const { return key == value; } |
57 | | |
58 | | /// Construct a new storage instance. |
59 | | static ArrayAttributeStorage *construct(AttributeStorageAllocator &allocator, |
60 | 0 | const KeyTy &key) { |
61 | 0 | return new (allocator.allocate<ArrayAttributeStorage>()) |
62 | 0 | ArrayAttributeStorage(allocator.copyInto(key)); |
63 | 0 | } |
64 | | |
65 | | ArrayRef<Attribute> value; |
66 | | }; |
67 | | |
68 | | /// An attribute representing a boolean value. |
69 | | struct BoolAttributeStorage : public AttributeStorage { |
70 | | using KeyTy = std::pair<MLIRContext *, bool>; |
71 | | |
72 | | BoolAttributeStorage(Type type, bool value) |
73 | 0 | : AttributeStorage(type), value(value) {} |
74 | | |
75 | | /// We only check equality for and hash with the boolean key parameter. |
76 | 0 | bool operator==(const KeyTy &key) const { return key.second == value; } |
77 | 0 | static unsigned hashKey(const KeyTy &key) { |
78 | 0 | return llvm::hash_value(key.second); |
79 | 0 | } |
80 | | |
81 | | static BoolAttributeStorage *construct(AttributeStorageAllocator &allocator, |
82 | 0 | const KeyTy &key) { |
83 | 0 | return new (allocator.allocate<BoolAttributeStorage>()) |
84 | 0 | BoolAttributeStorage(IntegerType::get(1, key.first), key.second); |
85 | 0 | } |
86 | | |
87 | | bool value; |
88 | | }; |
89 | | |
90 | | /// An attribute representing a dictionary of sorted named attributes. |
91 | | struct DictionaryAttributeStorage final |
92 | | : public AttributeStorage, |
93 | | private llvm::TrailingObjects<DictionaryAttributeStorage, |
94 | | NamedAttribute> { |
95 | | using KeyTy = ArrayRef<NamedAttribute>; |
96 | | |
97 | | /// Given a list of NamedAttribute's, canonicalize the list (sorting |
98 | | /// by name) and return the unique'd result. |
99 | | static DictionaryAttributeStorage *get(ArrayRef<NamedAttribute> attrs); |
100 | | |
101 | | /// Key equality function. |
102 | 0 | bool operator==(const KeyTy &key) const { return key == getElements(); } |
103 | | |
104 | | /// Construct a new storage instance. |
105 | | static DictionaryAttributeStorage * |
106 | 0 | construct(AttributeStorageAllocator &allocator, const KeyTy &key) { |
107 | 0 | auto size = DictionaryAttributeStorage::totalSizeToAlloc<NamedAttribute>( |
108 | 0 | key.size()); |
109 | 0 | auto rawMem = allocator.allocate(size, alignof(NamedAttribute)); |
110 | 0 |
|
111 | 0 | // Initialize the storage and trailing attribute list. |
112 | 0 | auto result = ::new (rawMem) DictionaryAttributeStorage(key.size()); |
113 | 0 | std::uninitialized_copy(key.begin(), key.end(), |
114 | 0 | result->getTrailingObjects<NamedAttribute>()); |
115 | 0 | return result; |
116 | 0 | } |
117 | | |
118 | | /// Return the elements of this dictionary attribute. |
119 | 0 | ArrayRef<NamedAttribute> getElements() const { |
120 | 0 | return {getTrailingObjects<NamedAttribute>(), numElements}; |
121 | 0 | } |
122 | | |
123 | | private: |
124 | | friend class llvm::TrailingObjects<DictionaryAttributeStorage, |
125 | | NamedAttribute>; |
126 | | |
127 | | // This is used by the llvm::TrailingObjects base class. |
128 | 0 | size_t numTrailingObjects(OverloadToken<NamedAttribute>) const { |
129 | 0 | return numElements; |
130 | 0 | } |
131 | 0 | DictionaryAttributeStorage(unsigned numElements) : numElements(numElements) {} |
132 | | |
133 | | /// This is the number of attributes. |
134 | | const unsigned numElements; |
135 | | }; |
136 | | |
137 | | /// An attribute representing a floating point value. |
138 | | struct FloatAttributeStorage final |
139 | | : public AttributeStorage, |
140 | | public llvm::TrailingObjects<FloatAttributeStorage, uint64_t> { |
141 | | using KeyTy = std::pair<Type, APFloat>; |
142 | | |
143 | | FloatAttributeStorage(const llvm::fltSemantics &semantics, Type type, |
144 | | size_t numObjects) |
145 | 0 | : AttributeStorage(type), semantics(semantics), numObjects(numObjects) {} |
146 | | |
147 | | /// Key equality and hash functions. |
148 | 0 | bool operator==(const KeyTy &key) const { |
149 | 0 | return key.first == getType() && key.second.bitwiseIsEqual(getValue()); |
150 | 0 | } |
151 | 0 | static unsigned hashKey(const KeyTy &key) { |
152 | 0 | return llvm::hash_combine(key.first, llvm::hash_value(key.second)); |
153 | 0 | } |
154 | | |
155 | | /// Construct a key with a type and double. |
156 | 0 | static KeyTy getKey(Type type, double value) { |
157 | 0 | // Treat BF16 as double because it is not supported in LLVM's APFloat. |
158 | 0 | // TODO(b/121118307): add BF16 support to APFloat? |
159 | 0 | if (type.isBF16() || type.isF64()) |
160 | 0 | return KeyTy(type, APFloat(value)); |
161 | 0 | |
162 | 0 | // This handles, e.g., F16 because there is no APFloat constructor for it. |
163 | 0 | bool unused; |
164 | 0 | APFloat val(value); |
165 | 0 | val.convert(type.cast<FloatType>().getFloatSemantics(), |
166 | 0 | APFloat::rmNearestTiesToEven, &unused); |
167 | 0 | return KeyTy(type, val); |
168 | 0 | } |
169 | | |
170 | | /// Construct a new storage instance. |
171 | | static FloatAttributeStorage *construct(AttributeStorageAllocator &allocator, |
172 | 0 | const KeyTy &key) { |
173 | 0 | const auto &apint = key.second.bitcastToAPInt(); |
174 | 0 |
|
175 | 0 | // Here one word's bitwidth equals to that of uint64_t. |
176 | 0 | auto elements = ArrayRef<uint64_t>(apint.getRawData(), apint.getNumWords()); |
177 | 0 |
|
178 | 0 | auto byteSize = |
179 | 0 | FloatAttributeStorage::totalSizeToAlloc<uint64_t>(elements.size()); |
180 | 0 | auto rawMem = allocator.allocate(byteSize, alignof(FloatAttributeStorage)); |
181 | 0 | auto result = ::new (rawMem) FloatAttributeStorage( |
182 | 0 | key.second.getSemantics(), key.first, elements.size()); |
183 | 0 | std::uninitialized_copy(elements.begin(), elements.end(), |
184 | 0 | result->getTrailingObjects<uint64_t>()); |
185 | 0 | return result; |
186 | 0 | } |
187 | | |
188 | | /// Returns an APFloat representing the stored value. |
189 | 0 | APFloat getValue() const { |
190 | 0 | auto val = APInt(APFloat::getSizeInBits(semantics), |
191 | 0 | {getTrailingObjects<uint64_t>(), numObjects}); |
192 | 0 | return APFloat(semantics, val); |
193 | 0 | } |
194 | | |
195 | | const llvm::fltSemantics &semantics; |
196 | | size_t numObjects; |
197 | | }; |
198 | | |
199 | | /// An attribute representing an integral value. |
200 | | struct IntegerAttributeStorage final |
201 | | : public AttributeStorage, |
202 | | public llvm::TrailingObjects<IntegerAttributeStorage, uint64_t> { |
203 | | using KeyTy = std::pair<Type, APInt>; |
204 | | |
205 | | IntegerAttributeStorage(Type type, size_t numObjects) |
206 | 0 | : AttributeStorage(type), numObjects(numObjects) { |
207 | 0 | assert((type.isIndex() || type.isa<IntegerType>()) && "invalid type"); |
208 | 0 | } |
209 | | |
210 | | /// Key equality and hash functions. |
211 | 0 | bool operator==(const KeyTy &key) const { |
212 | 0 | return key == KeyTy(getType(), getValue()); |
213 | 0 | } |
214 | 0 | static unsigned hashKey(const KeyTy &key) { |
215 | 0 | return llvm::hash_combine(key.first, llvm::hash_value(key.second)); |
216 | 0 | } |
217 | | |
218 | | /// Construct a new storage instance. |
219 | | static IntegerAttributeStorage * |
220 | 0 | construct(AttributeStorageAllocator &allocator, const KeyTy &key) { |
221 | 0 | Type type; |
222 | 0 | APInt value; |
223 | 0 | std::tie(type, value) = key; |
224 | 0 |
|
225 | 0 | auto elements = ArrayRef<uint64_t>(value.getRawData(), value.getNumWords()); |
226 | 0 | auto size = |
227 | 0 | IntegerAttributeStorage::totalSizeToAlloc<uint64_t>(elements.size()); |
228 | 0 | auto rawMem = allocator.allocate(size, alignof(IntegerAttributeStorage)); |
229 | 0 | auto result = ::new (rawMem) IntegerAttributeStorage(type, elements.size()); |
230 | 0 | std::uninitialized_copy(elements.begin(), elements.end(), |
231 | 0 | result->getTrailingObjects<uint64_t>()); |
232 | 0 | return result; |
233 | 0 | } |
234 | | |
235 | | /// Returns an APInt representing the stored value. |
236 | 0 | APInt getValue() const { |
237 | 0 | if (getType().isIndex()) |
238 | 0 | return APInt(64, {getTrailingObjects<uint64_t>(), numObjects}); |
239 | 0 | return APInt(getType().getIntOrFloatBitWidth(), |
240 | 0 | {getTrailingObjects<uint64_t>(), numObjects}); |
241 | 0 | } |
242 | | |
243 | | size_t numObjects; |
244 | | }; |
245 | | |
246 | | // An attribute representing a reference to an integer set. |
247 | | struct IntegerSetAttributeStorage : public AttributeStorage { |
248 | | using KeyTy = IntegerSet; |
249 | | |
250 | 0 | IntegerSetAttributeStorage(IntegerSet value) : value(value) {} |
251 | | |
252 | | /// Key equality function. |
253 | 0 | bool operator==(const KeyTy &key) const { return key == value; } |
254 | | |
255 | | /// Construct a new storage instance. |
256 | | static IntegerSetAttributeStorage * |
257 | 0 | construct(AttributeStorageAllocator &allocator, KeyTy key) { |
258 | 0 | return new (allocator.allocate<IntegerSetAttributeStorage>()) |
259 | 0 | IntegerSetAttributeStorage(key); |
260 | 0 | } |
261 | | |
262 | | IntegerSet value; |
263 | | }; |
264 | | |
265 | | /// Opaque Attribute Storage and Uniquing. |
266 | | struct OpaqueAttributeStorage : public AttributeStorage { |
267 | | OpaqueAttributeStorage(Identifier dialectNamespace, StringRef attrData, |
268 | | Type type) |
269 | | : AttributeStorage(type), dialectNamespace(dialectNamespace), |
270 | 0 | attrData(attrData) {} |
271 | | |
272 | | /// The hash key used for uniquing. |
273 | | using KeyTy = std::tuple<Identifier, StringRef, Type>; |
274 | 0 | bool operator==(const KeyTy &key) const { |
275 | 0 | return key == KeyTy(dialectNamespace, attrData, getType()); |
276 | 0 | } |
277 | | |
278 | | static OpaqueAttributeStorage *construct(AttributeStorageAllocator &allocator, |
279 | 0 | const KeyTy &key) { |
280 | 0 | return new (allocator.allocate<OpaqueAttributeStorage>()) |
281 | 0 | OpaqueAttributeStorage(std::get<0>(key), |
282 | 0 | allocator.copyInto(std::get<1>(key)), |
283 | 0 | std::get<2>(key)); |
284 | 0 | } |
285 | | |
286 | | // The dialect namespace. |
287 | | Identifier dialectNamespace; |
288 | | |
289 | | // The parser attribute data for this opaque attribute. |
290 | | StringRef attrData; |
291 | | }; |
292 | | |
293 | | /// An attribute representing a string value. |
294 | | struct StringAttributeStorage : public AttributeStorage { |
295 | | using KeyTy = std::pair<StringRef, Type>; |
296 | | |
297 | | StringAttributeStorage(StringRef value, Type type) |
298 | 0 | : AttributeStorage(type), value(value) {} |
299 | | |
300 | | /// Key equality function. |
301 | 0 | bool operator==(const KeyTy &key) const { |
302 | 0 | return key == KeyTy(value, getType()); |
303 | 0 | } |
304 | | |
305 | | /// Construct a new storage instance. |
306 | | static StringAttributeStorage *construct(AttributeStorageAllocator &allocator, |
307 | 0 | const KeyTy &key) { |
308 | 0 | return new (allocator.allocate<StringAttributeStorage>()) |
309 | 0 | StringAttributeStorage(allocator.copyInto(key.first), key.second); |
310 | 0 | } |
311 | | |
312 | | StringRef value; |
313 | | }; |
314 | | |
315 | | /// An attribute representing a symbol reference. |
316 | | struct SymbolRefAttributeStorage final |
317 | | : public AttributeStorage, |
318 | | public llvm::TrailingObjects<SymbolRefAttributeStorage, |
319 | | FlatSymbolRefAttr> { |
320 | | using KeyTy = std::pair<StringRef, ArrayRef<FlatSymbolRefAttr>>; |
321 | | |
322 | | SymbolRefAttributeStorage(StringRef value, size_t numNestedRefs) |
323 | 0 | : value(value), numNestedRefs(numNestedRefs) {} |
324 | | |
325 | | /// Key equality function. |
326 | 0 | bool operator==(const KeyTy &key) const { |
327 | 0 | return key == KeyTy(value, getNestedRefs()); |
328 | 0 | } |
329 | | |
330 | | /// Construct a new storage instance. |
331 | | static SymbolRefAttributeStorage * |
332 | 0 | construct(AttributeStorageAllocator &allocator, const KeyTy &key) { |
333 | 0 | auto size = SymbolRefAttributeStorage::totalSizeToAlloc<FlatSymbolRefAttr>( |
334 | 0 | key.second.size()); |
335 | 0 | auto rawMem = allocator.allocate(size, alignof(SymbolRefAttributeStorage)); |
336 | 0 | auto result = ::new (rawMem) SymbolRefAttributeStorage( |
337 | 0 | allocator.copyInto(key.first), key.second.size()); |
338 | 0 | std::uninitialized_copy(key.second.begin(), key.second.end(), |
339 | 0 | result->getTrailingObjects<FlatSymbolRefAttr>()); |
340 | 0 | return result; |
341 | 0 | } |
342 | | |
343 | | /// Returns the set of nested references. |
344 | 0 | ArrayRef<FlatSymbolRefAttr> getNestedRefs() const { |
345 | 0 | return {getTrailingObjects<FlatSymbolRefAttr>(), numNestedRefs}; |
346 | 0 | } |
347 | | |
348 | | StringRef value; |
349 | | size_t numNestedRefs; |
350 | | }; |
351 | | |
352 | | /// An attribute representing a reference to a type. |
353 | | struct TypeAttributeStorage : public AttributeStorage { |
354 | | using KeyTy = Type; |
355 | | |
356 | 0 | TypeAttributeStorage(Type value) : value(value) {} |
357 | | |
358 | | /// Key equality function. |
359 | 0 | bool operator==(const KeyTy &key) const { return key == value; } |
360 | | |
361 | | /// Construct a new storage instance. |
362 | | static TypeAttributeStorage *construct(AttributeStorageAllocator &allocator, |
363 | 0 | KeyTy key) { |
364 | 0 | return new (allocator.allocate<TypeAttributeStorage>()) |
365 | 0 | TypeAttributeStorage(key); |
366 | 0 | } |
367 | | |
368 | | Type value; |
369 | | }; |
370 | | |
371 | | //===----------------------------------------------------------------------===// |
372 | | // Elements Attributes |
373 | | //===----------------------------------------------------------------------===// |
374 | | |
375 | | /// Return the bit width which DenseElementsAttr should use for this type. |
376 | 0 | inline size_t getDenseElementBitWidth(Type eltType) { |
377 | 0 | // Align the width for complex to 8 to make storage and interpretation easier. |
378 | 0 | if (ComplexType comp = eltType.dyn_cast<ComplexType>()) |
379 | 0 | return llvm::alignTo<8>(getDenseElementBitWidth(comp.getElementType())) * 2; |
380 | 0 | // FIXME(b/121118307): using 64 bits for BF16 because it is currently stored |
381 | 0 | // with double semantics. |
382 | 0 | if (eltType.isBF16()) |
383 | 0 | return 64; |
384 | 0 | if (eltType.isIndex()) |
385 | 0 | return IndexType::kInternalStorageBitWidth; |
386 | 0 | return eltType.getIntOrFloatBitWidth(); |
387 | 0 | } |
388 | | |
389 | | /// An attribute representing a reference to a dense vector or tensor object. |
390 | | struct DenseElementsAttributeStorage : public AttributeStorage { |
391 | | public: |
392 | | DenseElementsAttributeStorage(ShapedType ty, bool isSplat) |
393 | 0 | : AttributeStorage(ty), isSplat(isSplat) {} |
394 | | |
395 | | bool isSplat; |
396 | | }; |
397 | | |
398 | | /// An attribute representing a reference to a dense vector or tensor object. |
399 | | struct DenseIntOrFPElementsAttributeStorage |
400 | | : public DenseElementsAttributeStorage { |
401 | | DenseIntOrFPElementsAttributeStorage(ShapedType ty, ArrayRef<char> data, |
402 | | bool isSplat = false) |
403 | 0 | : DenseElementsAttributeStorage(ty, isSplat), data(data) {} |
404 | | |
405 | | struct KeyTy { |
406 | | KeyTy(ShapedType type, ArrayRef<char> data, llvm::hash_code hashCode, |
407 | | bool isSplat = false) |
408 | 0 | : type(type), data(data), hashCode(hashCode), isSplat(isSplat) {} |
409 | | |
410 | | /// The type of the dense elements. |
411 | | ShapedType type; |
412 | | |
413 | | /// The raw buffer for the data storage. |
414 | | ArrayRef<char> data; |
415 | | |
416 | | /// The computed hash code for the storage data. |
417 | | llvm::hash_code hashCode; |
418 | | |
419 | | /// A boolean that indicates if this data is a splat or not. |
420 | | bool isSplat; |
421 | | }; |
422 | | |
423 | | /// Compare this storage instance with the provided key. |
424 | 0 | bool operator==(const KeyTy &key) const { |
425 | 0 | if (key.type != getType()) |
426 | 0 | return false; |
427 | 0 | |
428 | 0 | // For boolean splats we need to explicitly check that the first bit is the |
429 | 0 | // same. Boolean values are packed at the bit level, and even though a splat |
430 | 0 | // is detected the rest of the bits in the first byte may differ from the |
431 | 0 | // splat value. |
432 | 0 | if (key.type.getElementType().isInteger(1)) { |
433 | 0 | if (key.isSplat != isSplat) |
434 | 0 | return false; |
435 | 0 | if (isSplat) |
436 | 0 | return (key.data.front() & 1) == data.front(); |
437 | 0 | } |
438 | 0 | |
439 | 0 | // Otherwise, we can default to just checking the data. |
440 | 0 | return key.data == data; |
441 | 0 | } |
442 | | |
443 | | /// Construct a key from a shaped type, raw data buffer, and a flag that |
444 | | /// signals if the data is already known to be a splat. Callers to this |
445 | | /// function are expected to tag preknown splat values when possible, e.g. one |
446 | | /// element shapes. |
447 | 0 | static KeyTy getKey(ShapedType ty, ArrayRef<char> data, bool isKnownSplat) { |
448 | 0 | // Handle an empty storage instance. |
449 | 0 | if (data.empty()) |
450 | 0 | return KeyTy(ty, data, 0); |
451 | 0 | |
452 | 0 | // If the data is already known to be a splat, the key hash value is |
453 | 0 | // directly the data buffer. |
454 | 0 | if (isKnownSplat) |
455 | 0 | return KeyTy(ty, data, llvm::hash_value(data), isKnownSplat); |
456 | 0 | |
457 | 0 | // Otherwise, we need to check if the data corresponds to a splat or not. |
458 | 0 | |
459 | 0 | // Handle the simple case of only one element. |
460 | 0 | size_t numElements = ty.getNumElements(); |
461 | 0 | assert(numElements != 1 && "splat of 1 element should already be detected"); |
462 | 0 |
|
463 | 0 | // Handle boolean values directly as they are packed to 1-bit. |
464 | 0 | if (ty.getElementType().isInteger(1) == 1) |
465 | 0 | return getKeyForBoolData(ty, data, numElements); |
466 | 0 | |
467 | 0 | size_t elementWidth = getDenseElementBitWidth(ty.getElementType()); |
468 | 0 | // Non 1-bit dense elements are padded to 8-bits. |
469 | 0 | size_t storageSize = llvm::divideCeil(elementWidth, CHAR_BIT); |
470 | 0 | assert(((data.size() / storageSize) == numElements) && |
471 | 0 | "data does not hold expected number of elements"); |
472 | 0 |
|
473 | 0 | // Create the initial hash value with just the first element. |
474 | 0 | auto firstElt = data.take_front(storageSize); |
475 | 0 | auto hashVal = llvm::hash_value(firstElt); |
476 | 0 |
|
477 | 0 | // Check to see if this storage represents a splat. If it doesn't then |
478 | 0 | // combine the hash for the data starting with the first non splat element. |
479 | 0 | for (size_t i = storageSize, e = data.size(); i != e; i += storageSize) |
480 | 0 | if (memcmp(data.data(), &data[i], storageSize)) |
481 | 0 | return KeyTy(ty, data, llvm::hash_combine(hashVal, data.drop_front(i))); |
482 | 0 |
|
483 | 0 | // Otherwise, this is a splat so just return the hash of the first element. |
484 | 0 | return KeyTy(ty, firstElt, hashVal, /*isSplat=*/true); |
485 | 0 | } |
486 | | |
487 | | /// Construct a key with a set of boolean data. |
488 | | static KeyTy getKeyForBoolData(ShapedType ty, ArrayRef<char> data, |
489 | 0 | size_t numElements) { |
490 | 0 | ArrayRef<char> splatData = data; |
491 | 0 | bool splatValue = splatData.front() & 1; |
492 | 0 |
|
493 | 0 | // Helper functor to generate a KeyTy for a boolean splat value. |
494 | 0 | auto generateSplatKey = [=] { |
495 | 0 | return KeyTy(ty, data.take_front(1), |
496 | 0 | llvm::hash_value(ArrayRef<char>(splatValue ? 1 : 0)), |
497 | 0 | /*isSplat=*/true); |
498 | 0 | }; |
499 | 0 |
|
500 | 0 | // Handle the case where the potential splat value is 1 and the number of |
501 | 0 | // elements is non 8-bit aligned. |
502 | 0 | size_t numOddElements = numElements % CHAR_BIT; |
503 | 0 | if (splatValue && numOddElements != 0) { |
504 | 0 | // Check that all bits are set in the last value. |
505 | 0 | char lastElt = splatData.back(); |
506 | 0 | if (lastElt != llvm::maskTrailingOnes<unsigned char>(numOddElements)) |
507 | 0 | return KeyTy(ty, data, llvm::hash_value(data)); |
508 | 0 | |
509 | 0 | // If this is the only element, the data is known to be a splat. |
510 | 0 | if (splatData.size() == 1) |
511 | 0 | return generateSplatKey(); |
512 | 0 | splatData = splatData.drop_back(); |
513 | 0 | } |
514 | 0 |
|
515 | 0 | // Check that the data buffer corresponds to a splat of the proper mask. |
516 | 0 | char mask = splatValue ? ~0 : 0; |
517 | 0 | return llvm::all_of(splatData, [mask](char c) { return c == mask; }) |
518 | 0 | ? generateSplatKey() |
519 | 0 | : KeyTy(ty, data, llvm::hash_value(data)); |
520 | 0 | } |
521 | | |
522 | | /// Hash the key for the storage. |
523 | 0 | static llvm::hash_code hashKey(const KeyTy &key) { |
524 | 0 | return llvm::hash_combine(key.type, key.hashCode); |
525 | 0 | } |
526 | | |
527 | | /// Construct a new storage instance. |
528 | | static DenseIntOrFPElementsAttributeStorage * |
529 | 0 | construct(AttributeStorageAllocator &allocator, KeyTy key) { |
530 | 0 | // If the data buffer is non-empty, we copy it into the allocator with a |
531 | 0 | // 64-bit alignment. |
532 | 0 | ArrayRef<char> copy, data = key.data; |
533 | 0 | if (!data.empty()) { |
534 | 0 | char *rawData = reinterpret_cast<char *>( |
535 | 0 | allocator.allocate(data.size(), alignof(uint64_t))); |
536 | 0 | std::memcpy(rawData, data.data(), data.size()); |
537 | 0 |
|
538 | 0 | // If this is a boolean splat, make sure only the first bit is used. |
539 | 0 | if (key.isSplat && key.type.getElementType().isInteger(1)) |
540 | 0 | rawData[0] &= 1; |
541 | 0 | copy = ArrayRef<char>(rawData, data.size()); |
542 | 0 | } |
543 | 0 |
|
544 | 0 | return new (allocator.allocate<DenseIntOrFPElementsAttributeStorage>()) |
545 | 0 | DenseIntOrFPElementsAttributeStorage(key.type, copy, key.isSplat); |
546 | 0 | } |
547 | | |
548 | | ArrayRef<char> data; |
549 | | }; |
550 | | |
551 | | /// An attribute representing a reference to a dense vector or tensor object |
552 | | /// containing strings. |
553 | | struct DenseStringElementsAttributeStorage |
554 | | : public DenseElementsAttributeStorage { |
555 | | DenseStringElementsAttributeStorage(ShapedType ty, ArrayRef<StringRef> data, |
556 | | bool isSplat = false) |
557 | 0 | : DenseElementsAttributeStorage(ty, isSplat), data(data) {} |
558 | | |
559 | | struct KeyTy { |
560 | | KeyTy(ShapedType type, ArrayRef<StringRef> data, llvm::hash_code hashCode, |
561 | | bool isSplat = false) |
562 | 0 | : type(type), data(data), hashCode(hashCode), isSplat(isSplat) {} |
563 | | |
564 | | /// The type of the dense elements. |
565 | | ShapedType type; |
566 | | |
567 | | /// The raw buffer for the data storage. |
568 | | ArrayRef<StringRef> data; |
569 | | |
570 | | /// The computed hash code for the storage data. |
571 | | llvm::hash_code hashCode; |
572 | | |
573 | | /// A boolean that indicates if this data is a splat or not. |
574 | | bool isSplat; |
575 | | }; |
576 | | |
577 | | /// Compare this storage instance with the provided key. |
578 | 0 | bool operator==(const KeyTy &key) const { |
579 | 0 | if (key.type != getType()) |
580 | 0 | return false; |
581 | 0 | |
582 | 0 | // Otherwise, we can default to just checking the data. StringRefs compare |
583 | 0 | // by contents. |
584 | 0 | return key.data == data; |
585 | 0 | } |
586 | | |
587 | | /// Construct a key from a shaped type, StringRef data buffer, and a flag that |
588 | | /// signals if the data is already known to be a splat. Callers to this |
589 | | /// function are expected to tag preknown splat values when possible, e.g. one |
590 | | /// element shapes. |
591 | | static KeyTy getKey(ShapedType ty, ArrayRef<StringRef> data, |
592 | 0 | bool isKnownSplat) { |
593 | 0 | // Handle an empty storage instance. |
594 | 0 | if (data.empty()) |
595 | 0 | return KeyTy(ty, data, 0); |
596 | 0 | |
597 | 0 | // If the data is already known to be a splat, the key hash value is |
598 | 0 | // directly the data buffer. |
599 | 0 | if (isKnownSplat) |
600 | 0 | return KeyTy(ty, data, llvm::hash_value(data.front()), isKnownSplat); |
601 | 0 | |
602 | 0 | // Handle the simple case of only one element. |
603 | 0 | assert(ty.getNumElements() != 1 && |
604 | 0 | "splat of 1 element should already be detected"); |
605 | 0 |
|
606 | 0 | // Create the initial hash value with just the first element. |
607 | 0 | const auto &firstElt = data.front(); |
608 | 0 | auto hashVal = llvm::hash_value(firstElt); |
609 | 0 |
|
610 | 0 | // Check to see if this storage represents a splat. If it doesn't then |
611 | 0 | // combine the hash for the data starting with the first non splat element. |
612 | 0 | for (size_t i = 1, e = data.size(); i != e; i++) |
613 | 0 | if (!firstElt.equals(data[i])) |
614 | 0 | return KeyTy(ty, data, llvm::hash_combine(hashVal, data.drop_front(i))); |
615 | 0 |
|
616 | 0 | // Otherwise, this is a splat so just return the hash of the first element. |
617 | 0 | return KeyTy(ty, data.take_front(), hashVal, /*isSplat=*/true); |
618 | 0 | } |
619 | | |
620 | | /// Hash the key for the storage. |
621 | 0 | static llvm::hash_code hashKey(const KeyTy &key) { |
622 | 0 | return llvm::hash_combine(key.type, key.hashCode); |
623 | 0 | } |
624 | | |
625 | | /// Construct a new storage instance. |
626 | | static DenseStringElementsAttributeStorage * |
627 | 0 | construct(AttributeStorageAllocator &allocator, KeyTy key) { |
628 | 0 | // If the data buffer is non-empty, we copy it into the allocator with a |
629 | 0 | // 64-bit alignment. |
630 | 0 | ArrayRef<StringRef> copy, data = key.data; |
631 | 0 | if (data.empty()) { |
632 | 0 | return new (allocator.allocate<DenseStringElementsAttributeStorage>()) |
633 | 0 | DenseStringElementsAttributeStorage(key.type, copy, key.isSplat); |
634 | 0 | } |
635 | 0 | |
636 | 0 | int numEntries = key.isSplat ? 1 : data.size(); |
637 | 0 |
|
638 | 0 | // Compute the amount data needed to store the ArrayRef and StringRef |
639 | 0 | // contents. |
640 | 0 | size_t dataSize = sizeof(StringRef) * numEntries; |
641 | 0 | for (int i = 0; i < numEntries; i++) |
642 | 0 | dataSize += data[i].size(); |
643 | 0 |
|
644 | 0 | char *rawData = reinterpret_cast<char *>( |
645 | 0 | allocator.allocate(dataSize, alignof(uint64_t))); |
646 | 0 |
|
647 | 0 | // Setup a mutable array ref of our string refs so that we can update their |
648 | 0 | // contents. |
649 | 0 | auto mutableCopy = MutableArrayRef<StringRef>( |
650 | 0 | reinterpret_cast<StringRef *>(rawData), numEntries); |
651 | 0 | auto stringData = rawData + numEntries * sizeof(StringRef); |
652 | 0 |
|
653 | 0 | for (int i = 0; i < numEntries; i++) { |
654 | 0 | memcpy(stringData, data[i].data(), data[i].size()); |
655 | 0 | mutableCopy[i] = StringRef(stringData, data[i].size()); |
656 | 0 | stringData += data[i].size(); |
657 | 0 | } |
658 | 0 |
|
659 | 0 | copy = |
660 | 0 | ArrayRef<StringRef>(reinterpret_cast<StringRef *>(rawData), numEntries); |
661 | 0 |
|
662 | 0 | return new (allocator.allocate<DenseStringElementsAttributeStorage>()) |
663 | 0 | DenseStringElementsAttributeStorage(key.type, copy, key.isSplat); |
664 | 0 | } |
665 | | |
666 | | ArrayRef<StringRef> data; |
667 | | }; |
668 | | |
669 | | /// An attribute representing a reference to a tensor constant with opaque |
670 | | /// content. |
671 | | struct OpaqueElementsAttributeStorage : public AttributeStorage { |
672 | | using KeyTy = std::tuple<Type, Dialect *, StringRef>; |
673 | | |
674 | | OpaqueElementsAttributeStorage(Type type, Dialect *dialect, StringRef bytes) |
675 | 0 | : AttributeStorage(type), dialect(dialect), bytes(bytes) {} |
676 | | |
677 | | /// Key equality and hash functions. |
678 | 0 | bool operator==(const KeyTy &key) const { |
679 | 0 | return key == std::make_tuple(getType(), dialect, bytes); |
680 | 0 | } |
681 | 0 | static unsigned hashKey(const KeyTy &key) { |
682 | 0 | return llvm::hash_combine(std::get<0>(key), std::get<1>(key), |
683 | 0 | std::get<2>(key)); |
684 | 0 | } |
685 | | |
686 | | /// Construct a new storage instance. |
687 | | static OpaqueElementsAttributeStorage * |
688 | 0 | construct(AttributeStorageAllocator &allocator, KeyTy key) { |
689 | 0 | // TODO(b/131468830): Provide a way to avoid copying content of large opaque |
690 | 0 | // tensors This will likely require a new reference attribute kind. |
691 | 0 | return new (allocator.allocate<OpaqueElementsAttributeStorage>()) |
692 | 0 | OpaqueElementsAttributeStorage(std::get<0>(key), std::get<1>(key), |
693 | 0 | allocator.copyInto(std::get<2>(key))); |
694 | 0 | } |
695 | | |
696 | | Dialect *dialect; |
697 | | StringRef bytes; |
698 | | }; |
699 | | |
700 | | /// An attribute representing a reference to a sparse vector or tensor object. |
701 | | struct SparseElementsAttributeStorage : public AttributeStorage { |
702 | | using KeyTy = std::tuple<Type, DenseIntElementsAttr, DenseElementsAttr>; |
703 | | |
704 | | SparseElementsAttributeStorage(Type type, DenseIntElementsAttr indices, |
705 | | DenseElementsAttr values) |
706 | 0 | : AttributeStorage(type), indices(indices), values(values) {} |
707 | | |
708 | | /// Key equality and hash functions. |
709 | 0 | bool operator==(const KeyTy &key) const { |
710 | 0 | return key == std::make_tuple(getType(), indices, values); |
711 | 0 | } |
712 | 0 | static unsigned hashKey(const KeyTy &key) { |
713 | 0 | return llvm::hash_combine(std::get<0>(key), std::get<1>(key), |
714 | 0 | std::get<2>(key)); |
715 | 0 | } |
716 | | |
717 | | /// Construct a new storage instance. |
718 | | static SparseElementsAttributeStorage * |
719 | 0 | construct(AttributeStorageAllocator &allocator, KeyTy key) { |
720 | 0 | return new (allocator.allocate<SparseElementsAttributeStorage>()) |
721 | 0 | SparseElementsAttributeStorage(std::get<0>(key), std::get<1>(key), |
722 | 0 | std::get<2>(key)); |
723 | 0 | } |
724 | | |
725 | | DenseIntElementsAttr indices; |
726 | | DenseElementsAttr values; |
727 | | }; |
728 | | } // namespace detail |
729 | | } // namespace mlir |
730 | | |
731 | | #endif // ATTRIBUTEDETAIL_H_ |