Coverage Report

Created: 2020-06-26 05:44

/home/arjun/llvm-project/mlir/lib/IR/AttributeDetail.h
Line
Count
Source (jump to first uncovered line)
1
//===- AttributeDetail.h - MLIR Attribute details Class ---------*- C++ -*-===//
2
//
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
// See https://llvm.org/LICENSE.txt for license information.
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
//
7
//===----------------------------------------------------------------------===//
8
//
9
// This holds implementation details of Attribute.
10
//
11
//===----------------------------------------------------------------------===//
12
13
#ifndef ATTRIBUTEDETAIL_H_
14
#define ATTRIBUTEDETAIL_H_
15
16
#include "mlir/IR/AffineMap.h"
17
#include "mlir/IR/Attributes.h"
18
#include "mlir/IR/Identifier.h"
19
#include "mlir/IR/IntegerSet.h"
20
#include "mlir/IR/MLIRContext.h"
21
#include "mlir/IR/StandardTypes.h"
22
#include "mlir/Support/StorageUniquer.h"
23
#include "llvm/ADT/APFloat.h"
24
#include "llvm/ADT/PointerIntPair.h"
25
#include "llvm/Support/TrailingObjects.h"
26
27
namespace mlir {
28
namespace detail {
29
// An attribute representing a reference to an affine map.
30
struct AffineMapAttributeStorage : public AttributeStorage {
31
  using KeyTy = AffineMap;
32
33
  AffineMapAttributeStorage(AffineMap value)
34
0
      : AttributeStorage(IndexType::get(value.getContext())), value(value) {}
35
36
  /// Key equality function.
37
0
  bool operator==(const KeyTy &key) const { return key == value; }
38
39
  /// Construct a new storage instance.
40
  static AffineMapAttributeStorage *
41
0
  construct(AttributeStorageAllocator &allocator, KeyTy key) {
42
0
    return new (allocator.allocate<AffineMapAttributeStorage>())
43
0
        AffineMapAttributeStorage(key);
44
0
  }
45
46
  AffineMap value;
47
};
48
49
/// An attribute representing an array of other attributes.
50
struct ArrayAttributeStorage : public AttributeStorage {
51
  using KeyTy = ArrayRef<Attribute>;
52
53
0
  ArrayAttributeStorage(ArrayRef<Attribute> value) : value(value) {}
54
55
  /// Key equality function.
56
0
  bool operator==(const KeyTy &key) const { return key == value; }
57
58
  /// Construct a new storage instance.
59
  static ArrayAttributeStorage *construct(AttributeStorageAllocator &allocator,
60
0
                                          const KeyTy &key) {
61
0
    return new (allocator.allocate<ArrayAttributeStorage>())
62
0
        ArrayAttributeStorage(allocator.copyInto(key));
63
0
  }
64
65
  ArrayRef<Attribute> value;
66
};
67
68
/// An attribute representing a boolean value.
69
struct BoolAttributeStorage : public AttributeStorage {
70
  using KeyTy = std::pair<MLIRContext *, bool>;
71
72
  BoolAttributeStorage(Type type, bool value)
73
0
      : AttributeStorage(type), value(value) {}
74
75
  /// We only check equality for and hash with the boolean key parameter.
76
0
  bool operator==(const KeyTy &key) const { return key.second == value; }
77
0
  static unsigned hashKey(const KeyTy &key) {
78
0
    return llvm::hash_value(key.second);
79
0
  }
80
81
  static BoolAttributeStorage *construct(AttributeStorageAllocator &allocator,
82
0
                                         const KeyTy &key) {
83
0
    return new (allocator.allocate<BoolAttributeStorage>())
84
0
        BoolAttributeStorage(IntegerType::get(1, key.first), key.second);
85
0
  }
86
87
  bool value;
88
};
89
90
/// An attribute representing a dictionary of sorted named attributes.
91
struct DictionaryAttributeStorage final
92
    : public AttributeStorage,
93
      private llvm::TrailingObjects<DictionaryAttributeStorage,
94
                                    NamedAttribute> {
95
  using KeyTy = ArrayRef<NamedAttribute>;
96
97
  /// Given a list of NamedAttribute's, canonicalize the list (sorting
98
  /// by name) and return the unique'd result.
99
  static DictionaryAttributeStorage *get(ArrayRef<NamedAttribute> attrs);
100
101
  /// Key equality function.
102
0
  bool operator==(const KeyTy &key) const { return key == getElements(); }
103
104
  /// Construct a new storage instance.
105
  static DictionaryAttributeStorage *
106
0
  construct(AttributeStorageAllocator &allocator, const KeyTy &key) {
107
0
    auto size = DictionaryAttributeStorage::totalSizeToAlloc<NamedAttribute>(
108
0
        key.size());
109
0
    auto rawMem = allocator.allocate(size, alignof(NamedAttribute));
110
0
111
0
    // Initialize the storage and trailing attribute list.
112
0
    auto result = ::new (rawMem) DictionaryAttributeStorage(key.size());
113
0
    std::uninitialized_copy(key.begin(), key.end(),
114
0
                            result->getTrailingObjects<NamedAttribute>());
115
0
    return result;
116
0
  }
117
118
  /// Return the elements of this dictionary attribute.
119
0
  ArrayRef<NamedAttribute> getElements() const {
120
0
    return {getTrailingObjects<NamedAttribute>(), numElements};
121
0
  }
122
123
private:
124
  friend class llvm::TrailingObjects<DictionaryAttributeStorage,
125
                                     NamedAttribute>;
126
127
  // This is used by the llvm::TrailingObjects base class.
128
0
  size_t numTrailingObjects(OverloadToken<NamedAttribute>) const {
129
0
    return numElements;
130
0
  }
131
0
  DictionaryAttributeStorage(unsigned numElements) : numElements(numElements) {}
132
133
  /// This is the number of attributes.
134
  const unsigned numElements;
135
};
136
137
/// An attribute representing a floating point value.
138
struct FloatAttributeStorage final
139
    : public AttributeStorage,
140
      public llvm::TrailingObjects<FloatAttributeStorage, uint64_t> {
141
  using KeyTy = std::pair<Type, APFloat>;
142
143
  FloatAttributeStorage(const llvm::fltSemantics &semantics, Type type,
144
                        size_t numObjects)
145
0
      : AttributeStorage(type), semantics(semantics), numObjects(numObjects) {}
146
147
  /// Key equality and hash functions.
148
0
  bool operator==(const KeyTy &key) const {
149
0
    return key.first == getType() && key.second.bitwiseIsEqual(getValue());
150
0
  }
151
0
  static unsigned hashKey(const KeyTy &key) {
152
0
    return llvm::hash_combine(key.first, llvm::hash_value(key.second));
153
0
  }
154
155
  /// Construct a key with a type and double.
156
0
  static KeyTy getKey(Type type, double value) {
157
0
    // Treat BF16 as double because it is not supported in LLVM's APFloat.
158
0
    // TODO(b/121118307): add BF16 support to APFloat?
159
0
    if (type.isBF16() || type.isF64())
160
0
      return KeyTy(type, APFloat(value));
161
0
162
0
    // This handles, e.g., F16 because there is no APFloat constructor for it.
163
0
    bool unused;
164
0
    APFloat val(value);
165
0
    val.convert(type.cast<FloatType>().getFloatSemantics(),
166
0
                APFloat::rmNearestTiesToEven, &unused);
167
0
    return KeyTy(type, val);
168
0
  }
169
170
  /// Construct a new storage instance.
171
  static FloatAttributeStorage *construct(AttributeStorageAllocator &allocator,
172
0
                                          const KeyTy &key) {
173
0
    const auto &apint = key.second.bitcastToAPInt();
174
0
175
0
    // Here one word's bitwidth equals to that of uint64_t.
176
0
    auto elements = ArrayRef<uint64_t>(apint.getRawData(), apint.getNumWords());
177
0
178
0
    auto byteSize =
179
0
        FloatAttributeStorage::totalSizeToAlloc<uint64_t>(elements.size());
180
0
    auto rawMem = allocator.allocate(byteSize, alignof(FloatAttributeStorage));
181
0
    auto result = ::new (rawMem) FloatAttributeStorage(
182
0
        key.second.getSemantics(), key.first, elements.size());
183
0
    std::uninitialized_copy(elements.begin(), elements.end(),
184
0
                            result->getTrailingObjects<uint64_t>());
185
0
    return result;
186
0
  }
187
188
  /// Returns an APFloat representing the stored value.
189
0
  APFloat getValue() const {
190
0
    auto val = APInt(APFloat::getSizeInBits(semantics),
191
0
                     {getTrailingObjects<uint64_t>(), numObjects});
192
0
    return APFloat(semantics, val);
193
0
  }
194
195
  const llvm::fltSemantics &semantics;
196
  size_t numObjects;
197
};
198
199
/// An attribute representing an integral value.
200
struct IntegerAttributeStorage final
201
    : public AttributeStorage,
202
      public llvm::TrailingObjects<IntegerAttributeStorage, uint64_t> {
203
  using KeyTy = std::pair<Type, APInt>;
204
205
  IntegerAttributeStorage(Type type, size_t numObjects)
206
0
      : AttributeStorage(type), numObjects(numObjects) {
207
0
    assert((type.isIndex() || type.isa<IntegerType>()) && "invalid type");
208
0
  }
209
210
  /// Key equality and hash functions.
211
0
  bool operator==(const KeyTy &key) const {
212
0
    return key == KeyTy(getType(), getValue());
213
0
  }
214
0
  static unsigned hashKey(const KeyTy &key) {
215
0
    return llvm::hash_combine(key.first, llvm::hash_value(key.second));
216
0
  }
217
218
  /// Construct a new storage instance.
219
  static IntegerAttributeStorage *
220
0
  construct(AttributeStorageAllocator &allocator, const KeyTy &key) {
221
0
    Type type;
222
0
    APInt value;
223
0
    std::tie(type, value) = key;
224
0
225
0
    auto elements = ArrayRef<uint64_t>(value.getRawData(), value.getNumWords());
226
0
    auto size =
227
0
        IntegerAttributeStorage::totalSizeToAlloc<uint64_t>(elements.size());
228
0
    auto rawMem = allocator.allocate(size, alignof(IntegerAttributeStorage));
229
0
    auto result = ::new (rawMem) IntegerAttributeStorage(type, elements.size());
230
0
    std::uninitialized_copy(elements.begin(), elements.end(),
231
0
                            result->getTrailingObjects<uint64_t>());
232
0
    return result;
233
0
  }
234
235
  /// Returns an APInt representing the stored value.
236
0
  APInt getValue() const {
237
0
    if (getType().isIndex())
238
0
      return APInt(64, {getTrailingObjects<uint64_t>(), numObjects});
239
0
    return APInt(getType().getIntOrFloatBitWidth(),
240
0
                 {getTrailingObjects<uint64_t>(), numObjects});
241
0
  }
242
243
  size_t numObjects;
244
};
245
246
// An attribute representing a reference to an integer set.
247
struct IntegerSetAttributeStorage : public AttributeStorage {
248
  using KeyTy = IntegerSet;
249
250
0
  IntegerSetAttributeStorage(IntegerSet value) : value(value) {}
251
252
  /// Key equality function.
253
0
  bool operator==(const KeyTy &key) const { return key == value; }
254
255
  /// Construct a new storage instance.
256
  static IntegerSetAttributeStorage *
257
0
  construct(AttributeStorageAllocator &allocator, KeyTy key) {
258
0
    return new (allocator.allocate<IntegerSetAttributeStorage>())
259
0
        IntegerSetAttributeStorage(key);
260
0
  }
261
262
  IntegerSet value;
263
};
264
265
/// Opaque Attribute Storage and Uniquing.
266
struct OpaqueAttributeStorage : public AttributeStorage {
267
  OpaqueAttributeStorage(Identifier dialectNamespace, StringRef attrData,
268
                         Type type)
269
      : AttributeStorage(type), dialectNamespace(dialectNamespace),
270
0
        attrData(attrData) {}
271
272
  /// The hash key used for uniquing.
273
  using KeyTy = std::tuple<Identifier, StringRef, Type>;
274
0
  bool operator==(const KeyTy &key) const {
275
0
    return key == KeyTy(dialectNamespace, attrData, getType());
276
0
  }
277
278
  static OpaqueAttributeStorage *construct(AttributeStorageAllocator &allocator,
279
0
                                           const KeyTy &key) {
280
0
    return new (allocator.allocate<OpaqueAttributeStorage>())
281
0
        OpaqueAttributeStorage(std::get<0>(key),
282
0
                               allocator.copyInto(std::get<1>(key)),
283
0
                               std::get<2>(key));
284
0
  }
285
286
  // The dialect namespace.
287
  Identifier dialectNamespace;
288
289
  // The parser attribute data for this opaque attribute.
290
  StringRef attrData;
291
};
292
293
/// An attribute representing a string value.
294
struct StringAttributeStorage : public AttributeStorage {
295
  using KeyTy = std::pair<StringRef, Type>;
296
297
  StringAttributeStorage(StringRef value, Type type)
298
0
      : AttributeStorage(type), value(value) {}
299
300
  /// Key equality function.
301
0
  bool operator==(const KeyTy &key) const {
302
0
    return key == KeyTy(value, getType());
303
0
  }
304
305
  /// Construct a new storage instance.
306
  static StringAttributeStorage *construct(AttributeStorageAllocator &allocator,
307
0
                                           const KeyTy &key) {
308
0
    return new (allocator.allocate<StringAttributeStorage>())
309
0
        StringAttributeStorage(allocator.copyInto(key.first), key.second);
310
0
  }
311
312
  StringRef value;
313
};
314
315
/// An attribute representing a symbol reference.
316
struct SymbolRefAttributeStorage final
317
    : public AttributeStorage,
318
      public llvm::TrailingObjects<SymbolRefAttributeStorage,
319
                                   FlatSymbolRefAttr> {
320
  using KeyTy = std::pair<StringRef, ArrayRef<FlatSymbolRefAttr>>;
321
322
  SymbolRefAttributeStorage(StringRef value, size_t numNestedRefs)
323
0
      : value(value), numNestedRefs(numNestedRefs) {}
324
325
  /// Key equality function.
326
0
  bool operator==(const KeyTy &key) const {
327
0
    return key == KeyTy(value, getNestedRefs());
328
0
  }
329
330
  /// Construct a new storage instance.
331
  static SymbolRefAttributeStorage *
332
0
  construct(AttributeStorageAllocator &allocator, const KeyTy &key) {
333
0
    auto size = SymbolRefAttributeStorage::totalSizeToAlloc<FlatSymbolRefAttr>(
334
0
        key.second.size());
335
0
    auto rawMem = allocator.allocate(size, alignof(SymbolRefAttributeStorage));
336
0
    auto result = ::new (rawMem) SymbolRefAttributeStorage(
337
0
        allocator.copyInto(key.first), key.second.size());
338
0
    std::uninitialized_copy(key.second.begin(), key.second.end(),
339
0
                            result->getTrailingObjects<FlatSymbolRefAttr>());
340
0
    return result;
341
0
  }
342
343
  /// Returns the set of nested references.
344
0
  ArrayRef<FlatSymbolRefAttr> getNestedRefs() const {
345
0
    return {getTrailingObjects<FlatSymbolRefAttr>(), numNestedRefs};
346
0
  }
347
348
  StringRef value;
349
  size_t numNestedRefs;
350
};
351
352
/// An attribute representing a reference to a type.
353
struct TypeAttributeStorage : public AttributeStorage {
354
  using KeyTy = Type;
355
356
0
  TypeAttributeStorage(Type value) : value(value) {}
357
358
  /// Key equality function.
359
0
  bool operator==(const KeyTy &key) const { return key == value; }
360
361
  /// Construct a new storage instance.
362
  static TypeAttributeStorage *construct(AttributeStorageAllocator &allocator,
363
0
                                         KeyTy key) {
364
0
    return new (allocator.allocate<TypeAttributeStorage>())
365
0
        TypeAttributeStorage(key);
366
0
  }
367
368
  Type value;
369
};
370
371
//===----------------------------------------------------------------------===//
372
// Elements Attributes
373
//===----------------------------------------------------------------------===//
374
375
/// Return the bit width which DenseElementsAttr should use for this type.
376
0
inline size_t getDenseElementBitWidth(Type eltType) {
377
0
  // Align the width for complex to 8 to make storage and interpretation easier.
378
0
  if (ComplexType comp = eltType.dyn_cast<ComplexType>())
379
0
    return llvm::alignTo<8>(getDenseElementBitWidth(comp.getElementType())) * 2;
380
0
  // FIXME(b/121118307): using 64 bits for BF16 because it is currently stored
381
0
  // with double semantics.
382
0
  if (eltType.isBF16())
383
0
    return 64;
384
0
  if (eltType.isIndex())
385
0
    return IndexType::kInternalStorageBitWidth;
386
0
  return eltType.getIntOrFloatBitWidth();
387
0
}
388
389
/// An attribute representing a reference to a dense vector or tensor object.
390
struct DenseElementsAttributeStorage : public AttributeStorage {
391
public:
392
  DenseElementsAttributeStorage(ShapedType ty, bool isSplat)
393
0
      : AttributeStorage(ty), isSplat(isSplat) {}
394
395
  bool isSplat;
396
};
397
398
/// An attribute representing a reference to a dense vector or tensor object.
399
struct DenseIntOrFPElementsAttributeStorage
400
    : public DenseElementsAttributeStorage {
401
  DenseIntOrFPElementsAttributeStorage(ShapedType ty, ArrayRef<char> data,
402
                                       bool isSplat = false)
403
0
      : DenseElementsAttributeStorage(ty, isSplat), data(data) {}
404
405
  struct KeyTy {
406
    KeyTy(ShapedType type, ArrayRef<char> data, llvm::hash_code hashCode,
407
          bool isSplat = false)
408
0
        : type(type), data(data), hashCode(hashCode), isSplat(isSplat) {}
409
410
    /// The type of the dense elements.
411
    ShapedType type;
412
413
    /// The raw buffer for the data storage.
414
    ArrayRef<char> data;
415
416
    /// The computed hash code for the storage data.
417
    llvm::hash_code hashCode;
418
419
    /// A boolean that indicates if this data is a splat or not.
420
    bool isSplat;
421
  };
422
423
  /// Compare this storage instance with the provided key.
424
0
  bool operator==(const KeyTy &key) const {
425
0
    if (key.type != getType())
426
0
      return false;
427
0
428
0
    // For boolean splats we need to explicitly check that the first bit is the
429
0
    // same. Boolean values are packed at the bit level, and even though a splat
430
0
    // is detected the rest of the bits in the first byte may differ from the
431
0
    // splat value.
432
0
    if (key.type.getElementType().isInteger(1)) {
433
0
      if (key.isSplat != isSplat)
434
0
        return false;
435
0
      if (isSplat)
436
0
        return (key.data.front() & 1) == data.front();
437
0
    }
438
0
439
0
    // Otherwise, we can default to just checking the data.
440
0
    return key.data == data;
441
0
  }
442
443
  /// Construct a key from a shaped type, raw data buffer, and a flag that
444
  /// signals if the data is already known to be a splat. Callers to this
445
  /// function are expected to tag preknown splat values when possible, e.g. one
446
  /// element shapes.
447
0
  static KeyTy getKey(ShapedType ty, ArrayRef<char> data, bool isKnownSplat) {
448
0
    // Handle an empty storage instance.
449
0
    if (data.empty())
450
0
      return KeyTy(ty, data, 0);
451
0
452
0
    // If the data is already known to be a splat, the key hash value is
453
0
    // directly the data buffer.
454
0
    if (isKnownSplat)
455
0
      return KeyTy(ty, data, llvm::hash_value(data), isKnownSplat);
456
0
457
0
    // Otherwise, we need to check if the data corresponds to a splat or not.
458
0
459
0
    // Handle the simple case of only one element.
460
0
    size_t numElements = ty.getNumElements();
461
0
    assert(numElements != 1 && "splat of 1 element should already be detected");
462
0
463
0
    // Handle boolean values directly as they are packed to 1-bit.
464
0
    if (ty.getElementType().isInteger(1) == 1)
465
0
      return getKeyForBoolData(ty, data, numElements);
466
0
467
0
    size_t elementWidth = getDenseElementBitWidth(ty.getElementType());
468
0
    // Non 1-bit dense elements are padded to 8-bits.
469
0
    size_t storageSize = llvm::divideCeil(elementWidth, CHAR_BIT);
470
0
    assert(((data.size() / storageSize) == numElements) &&
471
0
           "data does not hold expected number of elements");
472
0
473
0
    // Create the initial hash value with just the first element.
474
0
    auto firstElt = data.take_front(storageSize);
475
0
    auto hashVal = llvm::hash_value(firstElt);
476
0
477
0
    // Check to see if this storage represents a splat. If it doesn't then
478
0
    // combine the hash for the data starting with the first non splat element.
479
0
    for (size_t i = storageSize, e = data.size(); i != e; i += storageSize)
480
0
      if (memcmp(data.data(), &data[i], storageSize))
481
0
        return KeyTy(ty, data, llvm::hash_combine(hashVal, data.drop_front(i)));
482
0
483
0
    // Otherwise, this is a splat so just return the hash of the first element.
484
0
    return KeyTy(ty, firstElt, hashVal, /*isSplat=*/true);
485
0
  }
486
487
  /// Construct a key with a set of boolean data.
488
  static KeyTy getKeyForBoolData(ShapedType ty, ArrayRef<char> data,
489
0
                                 size_t numElements) {
490
0
    ArrayRef<char> splatData = data;
491
0
    bool splatValue = splatData.front() & 1;
492
0
493
0
    // Helper functor to generate a KeyTy for a boolean splat value.
494
0
    auto generateSplatKey = [=] {
495
0
      return KeyTy(ty, data.take_front(1),
496
0
                   llvm::hash_value(ArrayRef<char>(splatValue ? 1 : 0)),
497
0
                   /*isSplat=*/true);
498
0
    };
499
0
500
0
    // Handle the case where the potential splat value is 1 and the number of
501
0
    // elements is non 8-bit aligned.
502
0
    size_t numOddElements = numElements % CHAR_BIT;
503
0
    if (splatValue && numOddElements != 0) {
504
0
      // Check that all bits are set in the last value.
505
0
      char lastElt = splatData.back();
506
0
      if (lastElt != llvm::maskTrailingOnes<unsigned char>(numOddElements))
507
0
        return KeyTy(ty, data, llvm::hash_value(data));
508
0
509
0
      // If this is the only element, the data is known to be a splat.
510
0
      if (splatData.size() == 1)
511
0
        return generateSplatKey();
512
0
      splatData = splatData.drop_back();
513
0
    }
514
0
515
0
    // Check that the data buffer corresponds to a splat of the proper mask.
516
0
    char mask = splatValue ? ~0 : 0;
517
0
    return llvm::all_of(splatData, [mask](char c) { return c == mask; })
518
0
               ? generateSplatKey()
519
0
               : KeyTy(ty, data, llvm::hash_value(data));
520
0
  }
521
522
  /// Hash the key for the storage.
523
0
  static llvm::hash_code hashKey(const KeyTy &key) {
524
0
    return llvm::hash_combine(key.type, key.hashCode);
525
0
  }
526
527
  /// Construct a new storage instance.
528
  static DenseIntOrFPElementsAttributeStorage *
529
0
  construct(AttributeStorageAllocator &allocator, KeyTy key) {
530
0
    // If the data buffer is non-empty, we copy it into the allocator with a
531
0
    // 64-bit alignment.
532
0
    ArrayRef<char> copy, data = key.data;
533
0
    if (!data.empty()) {
534
0
      char *rawData = reinterpret_cast<char *>(
535
0
          allocator.allocate(data.size(), alignof(uint64_t)));
536
0
      std::memcpy(rawData, data.data(), data.size());
537
0
538
0
      // If this is a boolean splat, make sure only the first bit is used.
539
0
      if (key.isSplat && key.type.getElementType().isInteger(1))
540
0
        rawData[0] &= 1;
541
0
      copy = ArrayRef<char>(rawData, data.size());
542
0
    }
543
0
544
0
    return new (allocator.allocate<DenseIntOrFPElementsAttributeStorage>())
545
0
        DenseIntOrFPElementsAttributeStorage(key.type, copy, key.isSplat);
546
0
  }
547
548
  ArrayRef<char> data;
549
};
550
551
/// An attribute representing a reference to a dense vector or tensor object
552
/// containing strings.
553
struct DenseStringElementsAttributeStorage
554
    : public DenseElementsAttributeStorage {
555
  DenseStringElementsAttributeStorage(ShapedType ty, ArrayRef<StringRef> data,
556
                                      bool isSplat = false)
557
0
      : DenseElementsAttributeStorage(ty, isSplat), data(data) {}
558
559
  struct KeyTy {
560
    KeyTy(ShapedType type, ArrayRef<StringRef> data, llvm::hash_code hashCode,
561
          bool isSplat = false)
562
0
        : type(type), data(data), hashCode(hashCode), isSplat(isSplat) {}
563
564
    /// The type of the dense elements.
565
    ShapedType type;
566
567
    /// The raw buffer for the data storage.
568
    ArrayRef<StringRef> data;
569
570
    /// The computed hash code for the storage data.
571
    llvm::hash_code hashCode;
572
573
    /// A boolean that indicates if this data is a splat or not.
574
    bool isSplat;
575
  };
576
577
  /// Compare this storage instance with the provided key.
578
0
  bool operator==(const KeyTy &key) const {
579
0
    if (key.type != getType())
580
0
      return false;
581
0
582
0
    // Otherwise, we can default to just checking the data. StringRefs compare
583
0
    // by contents.
584
0
    return key.data == data;
585
0
  }
586
587
  /// Construct a key from a shaped type, StringRef data buffer, and a flag that
588
  /// signals if the data is already known to be a splat. Callers to this
589
  /// function are expected to tag preknown splat values when possible, e.g. one
590
  /// element shapes.
591
  static KeyTy getKey(ShapedType ty, ArrayRef<StringRef> data,
592
0
                      bool isKnownSplat) {
593
0
    // Handle an empty storage instance.
594
0
    if (data.empty())
595
0
      return KeyTy(ty, data, 0);
596
0
597
0
    // If the data is already known to be a splat, the key hash value is
598
0
    // directly the data buffer.
599
0
    if (isKnownSplat)
600
0
      return KeyTy(ty, data, llvm::hash_value(data.front()), isKnownSplat);
601
0
602
0
    // Handle the simple case of only one element.
603
0
    assert(ty.getNumElements() != 1 &&
604
0
           "splat of 1 element should already be detected");
605
0
606
0
    // Create the initial hash value with just the first element.
607
0
    const auto &firstElt = data.front();
608
0
    auto hashVal = llvm::hash_value(firstElt);
609
0
610
0
    // Check to see if this storage represents a splat. If it doesn't then
611
0
    // combine the hash for the data starting with the first non splat element.
612
0
    for (size_t i = 1, e = data.size(); i != e; i++)
613
0
      if (!firstElt.equals(data[i]))
614
0
        return KeyTy(ty, data, llvm::hash_combine(hashVal, data.drop_front(i)));
615
0
616
0
    // Otherwise, this is a splat so just return the hash of the first element.
617
0
    return KeyTy(ty, data.take_front(), hashVal, /*isSplat=*/true);
618
0
  }
619
620
  /// Hash the key for the storage.
621
0
  static llvm::hash_code hashKey(const KeyTy &key) {
622
0
    return llvm::hash_combine(key.type, key.hashCode);
623
0
  }
624
625
  /// Construct a new storage instance.
626
  static DenseStringElementsAttributeStorage *
627
0
  construct(AttributeStorageAllocator &allocator, KeyTy key) {
628
0
    // If the data buffer is non-empty, we copy it into the allocator with a
629
0
    // 64-bit alignment.
630
0
    ArrayRef<StringRef> copy, data = key.data;
631
0
    if (data.empty()) {
632
0
      return new (allocator.allocate<DenseStringElementsAttributeStorage>())
633
0
          DenseStringElementsAttributeStorage(key.type, copy, key.isSplat);
634
0
    }
635
0
636
0
    int numEntries = key.isSplat ? 1 : data.size();
637
0
638
0
    // Compute the amount data needed to store the ArrayRef and StringRef
639
0
    // contents.
640
0
    size_t dataSize = sizeof(StringRef) * numEntries;
641
0
    for (int i = 0; i < numEntries; i++)
642
0
      dataSize += data[i].size();
643
0
644
0
    char *rawData = reinterpret_cast<char *>(
645
0
        allocator.allocate(dataSize, alignof(uint64_t)));
646
0
647
0
    // Setup a mutable array ref of our string refs so that we can update their
648
0
    // contents.
649
0
    auto mutableCopy = MutableArrayRef<StringRef>(
650
0
        reinterpret_cast<StringRef *>(rawData), numEntries);
651
0
    auto stringData = rawData + numEntries * sizeof(StringRef);
652
0
653
0
    for (int i = 0; i < numEntries; i++) {
654
0
      memcpy(stringData, data[i].data(), data[i].size());
655
0
      mutableCopy[i] = StringRef(stringData, data[i].size());
656
0
      stringData += data[i].size();
657
0
    }
658
0
659
0
    copy =
660
0
        ArrayRef<StringRef>(reinterpret_cast<StringRef *>(rawData), numEntries);
661
0
662
0
    return new (allocator.allocate<DenseStringElementsAttributeStorage>())
663
0
        DenseStringElementsAttributeStorage(key.type, copy, key.isSplat);
664
0
  }
665
666
  ArrayRef<StringRef> data;
667
};
668
669
/// An attribute representing a reference to a tensor constant with opaque
670
/// content.
671
struct OpaqueElementsAttributeStorage : public AttributeStorage {
672
  using KeyTy = std::tuple<Type, Dialect *, StringRef>;
673
674
  OpaqueElementsAttributeStorage(Type type, Dialect *dialect, StringRef bytes)
675
0
      : AttributeStorage(type), dialect(dialect), bytes(bytes) {}
676
677
  /// Key equality and hash functions.
678
0
  bool operator==(const KeyTy &key) const {
679
0
    return key == std::make_tuple(getType(), dialect, bytes);
680
0
  }
681
0
  static unsigned hashKey(const KeyTy &key) {
682
0
    return llvm::hash_combine(std::get<0>(key), std::get<1>(key),
683
0
                              std::get<2>(key));
684
0
  }
685
686
  /// Construct a new storage instance.
687
  static OpaqueElementsAttributeStorage *
688
0
  construct(AttributeStorageAllocator &allocator, KeyTy key) {
689
0
    // TODO(b/131468830): Provide a way to avoid copying content of large opaque
690
0
    // tensors This will likely require a new reference attribute kind.
691
0
    return new (allocator.allocate<OpaqueElementsAttributeStorage>())
692
0
        OpaqueElementsAttributeStorage(std::get<0>(key), std::get<1>(key),
693
0
                                       allocator.copyInto(std::get<2>(key)));
694
0
  }
695
696
  Dialect *dialect;
697
  StringRef bytes;
698
};
699
700
/// An attribute representing a reference to a sparse vector or tensor object.
701
struct SparseElementsAttributeStorage : public AttributeStorage {
702
  using KeyTy = std::tuple<Type, DenseIntElementsAttr, DenseElementsAttr>;
703
704
  SparseElementsAttributeStorage(Type type, DenseIntElementsAttr indices,
705
                                 DenseElementsAttr values)
706
0
      : AttributeStorage(type), indices(indices), values(values) {}
707
708
  /// Key equality and hash functions.
709
0
  bool operator==(const KeyTy &key) const {
710
0
    return key == std::make_tuple(getType(), indices, values);
711
0
  }
712
0
  static unsigned hashKey(const KeyTy &key) {
713
0
    return llvm::hash_combine(std::get<0>(key), std::get<1>(key),
714
0
                              std::get<2>(key));
715
0
  }
716
717
  /// Construct a new storage instance.
718
  static SparseElementsAttributeStorage *
719
0
  construct(AttributeStorageAllocator &allocator, KeyTy key) {
720
0
    return new (allocator.allocate<SparseElementsAttributeStorage>())
721
0
        SparseElementsAttributeStorage(std::get<0>(key), std::get<1>(key),
722
0
                                       std::get<2>(key));
723
0
  }
724
725
  DenseIntElementsAttr indices;
726
  DenseElementsAttr values;
727
};
728
} // namespace detail
729
} // namespace mlir
730
731
#endif // ATTRIBUTEDETAIL_H_