Coverage Report

Created: 2020-06-26 05:44

/home/arjun/llvm-project/mlir/lib/IR/AsmPrinter.cpp
Line
Count
Source (jump to first uncovered line)
1
//===- AsmPrinter.cpp - MLIR Assembly Printer Implementation --------------===//
2
//
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
// See https://llvm.org/LICENSE.txt for license information.
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
//
7
//===----------------------------------------------------------------------===//
8
//
9
// This file implements the MLIR AsmPrinter class, which is used to implement
10
// the various print() methods on the core IR objects.
11
//
12
//===----------------------------------------------------------------------===//
13
14
#include "mlir/IR/AffineExpr.h"
15
#include "mlir/IR/AffineMap.h"
16
#include "mlir/IR/AsmState.h"
17
#include "mlir/IR/Attributes.h"
18
#include "mlir/IR/Dialect.h"
19
#include "mlir/IR/DialectImplementation.h"
20
#include "mlir/IR/Function.h"
21
#include "mlir/IR/IntegerSet.h"
22
#include "mlir/IR/MLIRContext.h"
23
#include "mlir/IR/Module.h"
24
#include "mlir/IR/OpImplementation.h"
25
#include "mlir/IR/Operation.h"
26
#include "mlir/IR/StandardTypes.h"
27
#include "llvm/ADT/APFloat.h"
28
#include "llvm/ADT/DenseMap.h"
29
#include "llvm/ADT/MapVector.h"
30
#include "llvm/ADT/STLExtras.h"
31
#include "llvm/ADT/ScopedHashTable.h"
32
#include "llvm/ADT/SetVector.h"
33
#include "llvm/ADT/SmallString.h"
34
#include "llvm/ADT/StringExtras.h"
35
#include "llvm/ADT/StringSet.h"
36
#include "llvm/Support/CommandLine.h"
37
#include "llvm/Support/Regex.h"
38
#include "llvm/Support/SaveAndRestore.h"
39
using namespace mlir;
40
using namespace mlir::detail;
41
42
0
/// Write the string form of this identifier to `os`.
void Identifier::print(raw_ostream &os) const {
  os << str();
}
43
44
0
/// Print this identifier to the LLVM error stream.
void Identifier::dump() const {
  print(llvm::errs());
}
45
46
0
/// Write the textual name of this operation to `os`.
void OperationName::print(raw_ostream &os) const {
  os << getStringRef();
}
47
48
0
/// Print this operation name to the LLVM error stream.
void OperationName::dump() const {
  print(llvm::errs());
}
49
50
0
// Defined out of line so the destructor is emitted in this translation unit.
DialectAsmPrinter::~DialectAsmPrinter() = default;
51
52
0
// Defined out of line so the destructor is emitted in this translation unit.
OpAsmPrinter::~OpAsmPrinter() = default;
53
54
//===--------------------------------------------------------------------===//
55
// Operation OpAsm interface.
56
//===--------------------------------------------------------------------===//
57
58
/// The OpAsmOpInterface, see OpAsmInterface.td for more details.
59
#include "mlir/IR/OpAsmInterface.cpp.inc"
60
61
//===----------------------------------------------------------------------===//
62
// OpPrintingFlags
63
//===----------------------------------------------------------------------===//
64
65
namespace {
66
/// This struct contains command line options that can be used to initialize
67
/// various bits of the AsmPrinter. This uses a struct wrapper to avoid the need
68
/// for global command line options.
69
struct AsmPrinterOptions {
70
  /// Element-count threshold above which DenseElementsAttrs are printed as a
  /// hex string; -1 disables hex printing. The value is only consulted when the
  /// flag was passed explicitly (see shouldPrintElementsAttrWithHex).
  llvm::cl::opt<int64_t> printElementsAttrWithHexIfLarger{
71
      "mlir-print-elementsattrs-with-hex-if-larger",
72
      llvm::cl::desc(
73
          "Print DenseElementsAttrs with a hex string that have "
74
          "more elements than the given upper limit (use -1 to disable)")};
75
76
  /// Element-count threshold above which ElementsAttrs are elided with "...".
  /// The OpPrintingFlags constructor only picks this up when the flag was
  /// passed explicitly.
  llvm::cl::opt<unsigned> elideElementsAttrIfLarger{
77
      "mlir-elide-elementsattrs-if-larger",
78
      llvm::cl::desc("Elide ElementsAttrs with \"...\" that have "
79
                     "more elements than the given upper limit")};
80
81
  /// Whether debug info is printed; off by default.
  llvm::cl::opt<bool> printDebugInfoOpt{
82
      "mlir-print-debuginfo", llvm::cl::init(false),
83
      llvm::cl::desc("Print debug info in MLIR output")};
84
85
  /// Whether debug info is printed in its pretty form; off by default.
  llvm::cl::opt<bool> printPrettyDebugInfoOpt{
86
      "mlir-pretty-debuginfo", llvm::cl::init(false),
87
      llvm::cl::desc("Print pretty debug info in MLIR output")};
88
89
  // Use the generic op output form in the operation printer even if the custom
90
  // form is defined.
91
  llvm::cl::opt<bool> printGenericOpFormOpt{
92
      "mlir-print-op-generic", llvm::cl::init(false),
93
      llvm::cl::desc("Print the generic op form"), llvm::cl::Hidden};
94
95
  /// Whether printing assumes a local scope by default; off by default.
  llvm::cl::opt<bool> printLocalScopeOpt{
96
      "mlir-print-local-scope", llvm::cl::init(false),
97
      llvm::cl::desc("Print assuming in local scope by default"),
98
      llvm::cl::Hidden};
99
};
100
} // end anonymous namespace
101
102
/// File-wide holder for the AsmPrinter command line options. Dereferencing it
/// (as registerAsmPrinterCLOptions does) makes sure the options struct has
/// been initialized; readers that must not force that check isConstructed()
/// first.
static llvm::ManagedStatic<AsmPrinterOptions> clOptions;
103
104
/// Register a set of useful command-line options that can be used to configure
105
/// various flags within the AsmPrinter.
106
0
void mlir::registerAsmPrinterCLOptions() {
107
0
  // Make sure that the options struct has been initialized.
108
0
  *clOptions;
109
0
}
110
111
/// Initialize the printing flags with default supplied by the cl::opts above.
112
OpPrintingFlags::OpPrintingFlags()
    : printDebugInfoFlag(false), printDebugInfoPrettyFormFlag(false),
      printGenericOpFormFlag(false), printLocalScope(false) {
  // The defaults above stand unless the command line options were registered.
  if (clOptions.isConstructed()) {
    // Only an explicitly passed elision limit is taken over.
    if (clOptions->elideElementsAttrIfLarger.getNumOccurrences())
      elementsAttrElementLimit = clOptions->elideElementsAttrIfLarger;

    // The boolean options always carry a value, so copy them unconditionally.
    printDebugInfoFlag = clOptions->printDebugInfoOpt;
    printDebugInfoPrettyFormFlag = clOptions->printPrettyDebugInfoOpt;
    printGenericOpFormFlag = clOptions->printGenericOpFormOpt;
    printLocalScope = clOptions->printLocalScopeOpt;
  }
}
125
126
/// Enable the elision of large elements attributes, by printing a '...'
127
/// instead of the element data, when the number of elements is greater than
128
/// `largeElementLimit`. Note: The IR generated with this option is not
129
/// parsable.
130
OpPrintingFlags &
OpPrintingFlags::elideLargeElementsAttrs(int64_t largeElementLimit) {
  // Remember the limit; shouldElideElementsAttr() compares against it later.
  elementsAttrElementLimit = largeElementLimit;
  // Hand back the flags object so calls can be chained.
  return *this;
}
135
136
/// Enable printing of debug information. If 'prettyForm' is set to true,
137
/// debug information is printed in a more readable 'pretty' form.
138
0
OpPrintingFlags &OpPrintingFlags::enableDebugInfo(bool prettyForm) {
  // The pretty-form choice is recorded together with turning debug info on.
  printDebugInfoPrettyFormFlag = prettyForm;
  printDebugInfoFlag = true;
  // Hand back the flags object so calls can be chained.
  return *this;
}
143
144
/// Always print operations in the generic form.
145
0
OpPrintingFlags &OpPrintingFlags::printGenericOpForm() {
  // Once set there is no way to unset this through the builder interface.
  printGenericOpFormFlag = true;
  // Hand back the flags object so calls can be chained.
  return *this;
}
149
150
/// Use local scope when printing the operation. This allows for using the
151
/// printer in a more localized and thread-safe setting, but may not necessarily
152
/// be identical of what the IR will look like when dumping the full module.
153
0
OpPrintingFlags &OpPrintingFlags::useLocalScope() {
  // Record the request; shouldUseLocalScope() reports it to the printer.
  printLocalScope = true;
  // Hand back the flags object so calls can be chained.
  return *this;
}
157
158
/// Return if the given ElementsAttr should be elided.
159
0
bool OpPrintingFlags::shouldElideElementsAttr(ElementsAttr attr) const {
  // Without a configured limit nothing is ever elided.
  if (!elementsAttrElementLimit.hasValue())
    return false;
  // Elide only when the attribute holds strictly more elements than the limit.
  return int64_t(attr.getNumElements()) > *elementsAttrElementLimit;
}
163
164
/// Return the size limit for printing large ElementsAttr.
165
0
Optional<int64_t> OpPrintingFlags::getLargeElementsAttrLimit() const {
  // Assigned by elideLargeElementsAttrs() or, when passed explicitly, from
  // -mlir-elide-elementsattrs-if-larger.
  return elementsAttrElementLimit;
}
168
169
/// Return if debug information should be printed.
170
0
bool OpPrintingFlags::shouldPrintDebugInfo() const {
171
0
  return printDebugInfoFlag;
172
0
}
173
174
/// Return if debug information should be printed in the pretty form.
175
0
bool OpPrintingFlags::shouldPrintDebugInfoPrettyForm() const {
176
0
  return printDebugInfoPrettyFormFlag;
177
0
}
178
179
/// Return if operations should be printed in the generic form.
180
0
bool OpPrintingFlags::shouldPrintGenericOpForm() const {
181
0
  return printGenericOpFormFlag;
182
0
}
183
184
/// Return if the printer should use local scope when dumping the IR.
185
0
bool OpPrintingFlags::shouldUseLocalScope() const {
  // Turned on by useLocalScope() or by -mlir-print-local-scope.
  return printLocalScope;
}
186
187
/// Returns true if an ElementsAttr with the given number of elements should be
188
/// printed with hex.
189
0
static bool shouldPrintElementsAttrWithHex(int64_t numElements) {
190
0
  // Check to see if a command line option was provided for the limit.
191
0
  if (clOptions.isConstructed()) {
192
0
    if (clOptions->printElementsAttrWithHexIfLarger.getNumOccurrences()) {
193
0
      // -1 is used to disable hex printing.
194
0
      if (clOptions->printElementsAttrWithHexIfLarger == -1)
195
0
        return false;
196
0
      return numElements > clOptions->printElementsAttrWithHexIfLarger;
197
0
    }
198
0
  }
199
0
200
0
  // Otherwise, default to printing with hex if the number of elements is >100.
201
0
  return numElements > 100;
202
0
}
203
204
//===----------------------------------------------------------------------===//
205
// NewLineCounter
206
//===----------------------------------------------------------------------===//
207
208
namespace {
209
/// This class is a simple formatter that emits a new line when inputted into a
210
/// stream, that enables counting the number of newlines emitted. This class
211
/// should be used whenever emitting newlines in the printer.
212
struct NewLineCounter {
  /// Line currently being written, starting at 1; bumped once per newline
  /// streamed through this counter.
  unsigned curLine{1};
};
215
} // end anonymous namespace
216
217
0
/// Emit one '\n' to `os` and account for it in `newLine`.
static raw_ostream &operator<<(raw_ostream &os, NewLineCounter &newLine) {
  os << '\n';
  ++newLine.curLine;
  return os;
}
221
222
//===----------------------------------------------------------------------===//
223
// AliasState
224
//===----------------------------------------------------------------------===//
225
226
namespace {
227
/// This class manages the state for type and attribute aliases.
228
class AliasState {
229
public:
230
  /// Collect the aliases offered by `interfaces` (plus the builtin "map" and
  /// "set" kind aliases) and walk `op` to record the attributes and types it
  /// uses.
231
  void
232
  initialize(Operation *op,
233
             DialectInterfaceCollection<OpAsmDialectInterface> &interfaces);
234
235
  /// Return a name used for an attribute alias, or empty if there is no alias.
236
  Twine getAttributeAlias(Attribute attr) const;
237
238
  /// Print all of the referenced attribute aliases.
239
  void printAttributeAliases(raw_ostream &os, NewLineCounter &newLine) const;
240
241
  /// Return a string to use as an alias for the given type, or empty if there
242
  /// is no alias recorded.
243
  StringRef getTypeAlias(Type ty) const;
244
245
  /// Print all of the referenced type aliases.
246
  void printTypeAliases(raw_ostream &os, NewLineCounter &newLine) const;
247
248
private:
249
  /// A special index constant used for non-kind attribute aliases.
250
  enum { NonAttrKindAlias = -1 };
251
252
  /// Record a reference to the given attribute.
253
  void recordAttributeReference(Attribute attr);
254
255
  /// Record a reference to the given type.
256
  void recordTypeReference(Type ty);
257
258
  // Visit functions. Each records the entity it is given and then recurses
  // into the types/attributes nested inside it.
259
  void visitOperation(Operation *op);
260
  void visitType(Type type);
261
  void visitAttribute(Attribute attr);
262
263
  /// Set of attributes known to be used within the module.
264
  llvm::SetVector<Attribute> usedAttributes;
265
266
  /// Mapping between attribute and a pair comprised of a base alias name and a
267
  /// count suffix. If the suffix is set to -1, it is not displayed.
268
  llvm::MapVector<Attribute, std::pair<StringRef, int>> attrToAlias;
269
270
  /// Mapping between attribute kind and a pair comprised of a base alias name
271
  /// and a unique list of attributes belonging to this kind sorted by location
272
  /// seen in the module.
273
  llvm::MapVector<unsigned, std::pair<StringRef, std::vector<Attribute>>>
274
      attrKindToAlias;
275
276
  /// Set of types known to be used within the module.
277
  llvm::SetVector<Type> usedTypes;
278
279
  /// A mapping between a type and a given alias.
280
  DenseMap<Type, StringRef> typeToAlias;
281
};
282
} // end anonymous namespace
283
284
// Utility to generate a function to register a symbol alias.
285
static bool canRegisterAlias(StringRef name, llvm::StringSet<> &usedAliases) {
286
  assert(!name.empty() && "expected alias name to be non-empty");
287
  // TODO(riverriddle) Assert that the provided alias name can be lexed as
288
  // an identifier.
289
290
  // Check that the alias doesn't contain a '.' character and the name is not
291
  // already in use.
292
  return !name.contains('.') && usedAliases.insert(name).second;
293
}
294
295
void AliasState::initialize(
296
    Operation *op,
297
0
    DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) {
298
0
  // Track the identifiers in use for each symbol so that the same identifier
299
0
  // isn't used twice.
300
0
  llvm::StringSet<> usedAliases;
301
0
302
0
  // Collect the set of aliases from each dialect.
303
0
  SmallVector<std::pair<unsigned, StringRef>, 8> attributeKindAliases;
304
0
  SmallVector<std::pair<Attribute, StringRef>, 8> attributeAliases;
305
0
  SmallVector<std::pair<Type, StringRef>, 16> typeAliases;
306
0
307
0
  // AffineMap/Integer set have specific kind aliases.
308
0
  attributeKindAliases.emplace_back(StandardAttributes::AffineMap, "map");
309
0
  attributeKindAliases.emplace_back(StandardAttributes::IntegerSet, "set");
310
0
311
0
  for (auto &interface : interfaces) {
312
0
    interface.getAttributeKindAliases(attributeKindAliases);
313
0
    interface.getAttributeAliases(attributeAliases);
314
0
    interface.getTypeAliases(typeAliases);
315
0
  }
316
0
317
0
  // Setup the attribute kind aliases.
318
0
  StringRef alias;
319
0
  unsigned attrKind;
320
0
  for (auto &attrAliasPair : attributeKindAliases) {
321
0
    std::tie(attrKind, alias) = attrAliasPair;
322
0
    assert(!alias.empty() && "expected non-empty alias string");
323
0
    if (!usedAliases.count(alias) && !alias.contains('.'))
324
0
      attrKindToAlias.insert({attrKind, {alias, {}}});
325
0
  }
326
0
327
0
  // Clear the set of used identifiers so that the attribute kind aliases are
328
0
  // just a prefix and not the full alias, i.e. there may be some overlap.
329
0
  usedAliases.clear();
330
0
331
0
  // Register the attribute aliases.
332
0
  // Create a regex for the attribute kind alias names, these have a prefix with
333
0
  // a counter appended to the end. We prevent normal aliases from having these
334
0
  // names to avoid collisions.
335
0
  llvm::Regex reservedAttrNames("[0-9]+$");
336
0
337
0
  // Attribute value aliases.
338
0
  Attribute attr;
339
0
  for (auto &attrAliasPair : attributeAliases) {
340
0
    std::tie(attr, alias) = attrAliasPair;
341
0
    if (!reservedAttrNames.match(alias) && canRegisterAlias(alias, usedAliases))
342
0
      attrToAlias.insert({attr, {alias, NonAttrKindAlias}});
343
0
  }
344
0
345
0
  // Clear the set of used identifiers as types can have the same identifiers as
346
0
  // affine structures.
347
0
  usedAliases.clear();
348
0
349
0
  // Type aliases.
350
0
  for (auto &typeAliasPair : typeAliases)
351
0
    if (canRegisterAlias(typeAliasPair.second, usedAliases))
352
0
      typeToAlias.insert(typeAliasPair);
353
0
354
0
  // Traverse the given IR to generate the set of used attributes/types.
355
0
  op->walk([&](Operation *op) { visitOperation(op); });
356
0
}
357
358
/// Return a name used for an attribute alias, or empty if there is no alias.
359
0
Twine AliasState::getAttributeAlias(Attribute attr) const {
360
0
  auto alias = attrToAlias.find(attr);
361
0
  // Attributes without a registered alias yield an empty Twine.
  if (alias == attrToAlias.end())
362
0
    return Twine();
363
0
364
0
  // Return the alias for this attribute, along with the index if this was
365
0
  // generated by a kind alias.
366
0
  // NOTE(review): the result is built directly from the StringRef stored in
  // attrToAlias rather than from a local copy; presumably the returned Twine
  // must be consumed while that map entry is still alive -- confirm against
  // the Twine lifetime rules before restructuring this expression.
  int kindIndex = alias->second.second;
367
0
  return alias->second.first +
368
0
         (kindIndex == NonAttrKindAlias ? Twine() : Twine(kindIndex));
369
0
}
370
371
/// Print all of the referenced attribute aliases.
372
void AliasState::printAttributeAliases(raw_ostream &os,
373
0
                                       NewLineCounter &newLine) const {
374
0
  auto printAlias = [&](StringRef alias, Attribute attr, int index) {
375
0
    os << '#' << alias;
376
0
    if (index != NonAttrKindAlias)
377
0
      os << index;
378
0
    os << " = " << attr << newLine;
379
0
  };
380
0
381
0
  // Print all of the attribute kind aliases.
382
0
  for (auto &kindAlias : attrKindToAlias) {
383
0
    auto &aliasAttrsPair = kindAlias.second;
384
0
    for (unsigned i = 0, e = aliasAttrsPair.second.size(); i != e; ++i)
385
0
      printAlias(aliasAttrsPair.first, aliasAttrsPair.second[i], i);
386
0
    os << newLine;
387
0
  }
388
0
389
0
  // In a second pass print all of the remaining attribute aliases that aren't
390
0
  // kind aliases.
391
0
  for (Attribute attr : usedAttributes) {
392
0
    auto alias = attrToAlias.find(attr);
393
0
    if (alias != attrToAlias.end() && alias->second.second == NonAttrKindAlias)
394
0
      printAlias(alias->second.first, attr, alias->second.second);
395
0
  }
396
0
}
397
398
/// Return a string to use as an alias for the given type, or empty if there
399
/// is no alias recorded.
400
0
StringRef AliasState::getTypeAlias(Type ty) const {
  // Types without an entry yield a default-constructed (empty) StringRef.
  auto aliasIt = typeToAlias.find(ty);
  if (aliasIt == typeToAlias.end())
    return StringRef();
  return aliasIt->second;
}
403
404
/// Print all of the referenced type aliases.
405
void AliasState::printTypeAliases(raw_ostream &os,
406
0
                                  NewLineCounter &newLine) const {
407
0
  for (Type type : usedTypes) {
408
0
    auto alias = typeToAlias.find(type);
409
0
    if (alias != typeToAlias.end())
410
0
      os << '!' << alias->second << " = type " << type << newLine;
411
0
  }
412
0
}
413
414
/// Record a reference to the given attribute.
415
0
void AliasState::recordAttributeReference(Attribute attr) {
  // Always mark the attribute as used. Stop if it was seen before or if it
  // already carries an alias.
  bool firstSighting = usedAttributes.insert(attr);
  if (!firstSighting || attrToAlias.count(attr))
    return;

  // Only attributes whose kind has a registered alias prefix get a generated
  // alias.
  auto kindIt = attrKindToAlias.find(static_cast<unsigned>(attr.getKind()));
  if (kindIt == attrKindToAlias.end())
    return;

  // The suffix is the attribute's position in the per-kind list, i.e. the
  // list's size before the attribute is appended.
  auto &prefixAndAttrs = kindIt->second;
  int suffix = prefixAndAttrs.second.size();
  attrToAlias.insert({attr, {prefixAndAttrs.first, suffix}});
  prefixAndAttrs.second.push_back(attr);
}
430
431
/// Record a reference to the given type.
432
0
void AliasState::recordTypeReference(Type ty) {
  // Inserting a type that is already recorded leaves the set unchanged.
  usedTypes.insert(ty);
}
433
434
// TODO Support visiting other types/operations when implemented.
435
0
void AliasState::visitType(Type type) {
  recordTypeReference(type);

  // Function types: recurse into every input and result type.
  if (auto funcType = type.dyn_cast<FunctionType>()) {
    for (Type input : funcType.getInputs())
      visitType(input);
    for (Type result : funcType.getResults())
      visitType(result);
    return;
  }

  // Everything else is only interesting when it is a shaped type.
  auto shapedType = type.dyn_cast<ShapedType>();
  if (!shapedType)
    return;
  visitType(shapedType.getElementType());

  // Memrefs additionally reference the affine maps of their layout.
  if (auto memref = type.dyn_cast<MemRefType>())
    for (AffineMap map : memref.getAffineMaps())
      recordAttributeReference(AffineMapAttr::get(map));
}
453
454
0
/// Record `attr` and recurse into the attributes/types nested inside it.
void AliasState::visitAttribute(Attribute attr) {
  recordAttributeReference(attr);

  // Array attributes: visit every element.
  if (auto arrayAttr = attr.dyn_cast<ArrayAttr>()) {
    for (Attribute element : arrayAttr.getValue())
      visitAttribute(element);
    return;
  }

  // Type attributes: visit the wrapped type.
  if (auto typeAttr = attr.dyn_cast<TypeAttr>())
    visitType(typeAttr.getValue());
}
464
465
0
/// Record every type and attribute directly referenced by `op`: operand and
/// result types, the argument types of the blocks in its regions, and its
/// attributes.
void AliasState::visitOperation(Operation *op) {
  for (Type operandType : op->getOperandTypes())
    visitType(operandType);
  for (Type resultType : op->getResultTypes())
    visitType(resultType);

  // Block arguments of all blocks held by this op's regions.
  for (Region &region : op->getRegions()) {
    for (Block &block : region) {
      for (BlockArgument arg : block.getArguments())
        visitType(arg.getType());
    }
  }

  // Attributes are (name, value) pairs; only the value is visited.
  for (auto namedAttr : op->getAttrs())
    visitAttribute(namedAttr.second);
}
480
481
//===----------------------------------------------------------------------===//
482
// SSANameState
483
//===----------------------------------------------------------------------===//
484
485
namespace {
486
/// This class manages the state of SSA value names.
487
class SSANameState {
488
public:
489
  /// A sentinel value used for values with names set.
490
  enum : unsigned { NameSentinel = ~0U };
491
492
  /// Number the results of `op` and all values in the regions nested under it.
  SSANameState(Operation *op,
493
               DialectInterfaceCollection<OpAsmDialectInterface> &interfaces);
494
495
  /// Print the SSA identifier for the given value to 'stream'. If
496
  /// 'printResultNo' is true, it also presents the result number ('#' number)
497
  /// of this value.
498
  void printValueID(Value value, bool printResultNo, raw_ostream &stream) const;
499
500
  /// Return the result indices for each of the result groups registered by this
501
  /// operation, or empty if none exist.
502
  ArrayRef<int> getOpResultGroups(Operation *op);
503
504
  /// Get the ID for the given block, or NameSentinel if the block was never
  /// numbered.
505
  unsigned getBlockID(Block *block);
506
507
  /// Renumber the arguments for the specified region to the same names as the
508
  /// SSA values in namesToUse. See OperationPrinter::shadowRegionArgs for
509
  /// details.
510
  void shadowRegionArgs(Region &region, ValueRange namesToUse);
511
512
private:
513
  /// Number the SSA values within the given IR unit.
514
  void numberValuesInRegion(
515
      Region &region,
516
      DialectInterfaceCollection<OpAsmDialectInterface> &interfaces);
517
  void numberValuesInBlock(
518
      Block &block,
519
      DialectInterfaceCollection<OpAsmDialectInterface> &interfaces);
520
  void numberValuesInOp(
521
      Operation &op,
522
      DialectInterfaceCollection<OpAsmDialectInterface> &interfaces);
523
524
  /// Given a result of an operation 'result', find the result group head
525
  /// 'lookupValue' and the result of 'result' within that group in
526
  /// 'lookupResultNo'. 'lookupResultNo' is only filled in if the result group
527
  /// has more than 1 result.
528
  void getResultIDAndNumber(OpResult result, Value &lookupValue,
529
                            Optional<int> &lookupResultNo) const;
530
531
  /// Set a special value name for the given value. An empty name assigns the
  /// next default numeric ID instead.
532
  void setValueName(Value value, StringRef name);
533
534
  /// Uniques the given value name within the printer. If the given name
535
  /// conflicts, it is automatically renamed.
536
  StringRef uniqueValueName(StringRef name);
537
538
  /// This is the value ID for each SSA value. If this returns NameSentinel,
539
  /// then the valueID has an entry in valueNames.
540
  DenseMap<Value, unsigned> valueIDs;
541
  DenseMap<Value, StringRef> valueNames;
542
543
  /// This is a map of operations that contain multiple named result groups,
544
  /// i.e. there may be multiple names for the results of the operation. The
545
  /// value of this map are the result numbers that start a result group.
546
  DenseMap<Operation *, SmallVector<int, 1>> opResultGroups;
547
548
  /// This is the block ID for each numbered block. IDs are assigned per
  /// region, restarting at 0 for every region.
549
  DenseMap<Block *, unsigned> blockIDs;
550
551
  /// This keeps track of all of the non-numeric names that are in flight,
552
  /// allowing us to check for duplicates.
553
  /// Note: the value of the map is unused.
554
  llvm::ScopedHashTable<StringRef, char> usedNames;
555
  llvm::BumpPtrAllocator usedNameAllocator;
556
557
  /// This is the next value ID to assign in numbering.
558
  unsigned nextValueID = 0;
559
  /// This is the next ID to assign to a region entry block argument.
560
  unsigned nextArgumentID = 0;
561
  /// This is the next ID to assign when a name conflict is detected.
562
  unsigned nextConflictID = 0;
563
};
564
} // end anonymous namespace
565
566
/// Number the results of `op` and every value in the regions nested under it.
SSANameState::SSANameState(
    Operation *op,
    DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) {
  // Root scope for the used names. It only lives for the duration of the
  // constructor.
  llvm::ScopedHashTable<StringRef, char>::ScopeTy usedNamesScope(usedNames);

  // The root operation's own results come first, then its regions in order.
  numberValuesInOp(*op, interfaces);
  for (Region &region : op->getRegions())
    numberValuesInRegion(region, interfaces);
}
575
576
void SSANameState::printValueID(Value value, bool printResultNo,
577
0
                                raw_ostream &stream) const {
578
0
  if (!value) {
579
0
    stream << "<<NULL>>";
580
0
    return;
581
0
  }
582
0
583
0
  Optional<int> resultNo;
584
0
  auto lookupValue = value;
585
0
586
0
  // If this is an operation result, collect the head lookup value of the result
587
0
  // group and the result number of 'result' within that group.
588
0
  if (OpResult result = value.dyn_cast<OpResult>())
589
0
    getResultIDAndNumber(result, lookupValue, resultNo);
590
0
591
0
  auto it = valueIDs.find(lookupValue);
592
0
  if (it == valueIDs.end()) {
593
0
    stream << "<<UNKNOWN SSA VALUE>>";
594
0
    return;
595
0
  }
596
0
597
0
  stream << '%';
598
0
  if (it->second != NameSentinel) {
599
0
    stream << it->second;
600
0
  } else {
601
0
    auto nameIt = valueNames.find(lookupValue);
602
0
    assert(nameIt != valueNames.end() && "Didn't have a name entry?");
603
0
    stream << nameIt->second;
604
0
  }
605
0
606
0
  if (resultNo.hasValue() && printResultNo)
607
0
    stream << '#' << resultNo;
608
0
}
609
610
0
ArrayRef<int> SSANameState::getOpResultGroups(Operation *op) {
  // Only operations with more than one result group have an entry.
  auto groupsIt = opResultGroups.find(op);
  if (groupsIt == opResultGroups.end())
    return ArrayRef<int>();
  return groupsIt->second;
}
614
615
0
unsigned SSANameState::getBlockID(Block *block) {
  // Blocks that were never numbered are reported as NameSentinel.
  auto idIt = blockIDs.find(block);
  if (idIt == blockIDs.end())
    return NameSentinel;
  return idIt->second;
}
619
620
0
/// Give each entry block argument of `region` the printed name of the
/// corresponding value in `namesToUse`; null entries leave the argument's name
/// untouched. The region must be non-empty, have as many entry arguments as
/// there are names, and belong to an op known to be isolated from above.
void SSANameState::shadowRegionArgs(Region &region, ValueRange namesToUse) {
  assert(!region.empty() && "cannot shadow arguments of an empty region");
  assert(region.front().getNumArguments() == namesToUse.size() &&
         "incorrect number of names passed in");
  assert(region.getParentOp()->isKnownIsolatedFromAbove() &&
         "only KnownIsolatedFromAbove ops can shadow names");

  // Scratch buffer reused for every argument.
  SmallVector<char, 16> nameStr;
  for (unsigned i = 0, e = namesToUse.size(); i != e; ++i) {
    auto nameToUse = namesToUse[i];
    if (nameToUse == nullptr)
      continue;
    auto nameToReplace = region.front().getArgument(i);

    nameStr.clear();
    llvm::raw_svector_ostream nameStream(nameStr);
    printValueID(nameToUse, /*printResultNo=*/true, nameStream);

    // Entry block arguments should already have a pretty "arg" name. `lookup`
    // is used instead of operator[] so that the check itself never inserts
    // into the map (which would make assert and non-assert builds diverge).
    assert(valueIDs.lookup(nameToReplace) == NameSentinel &&
           "expected entry block argument to already have a name");

    // Use the name without the leading %.
    auto name = StringRef(nameStream.str()).drop_front();

    // Overwrite the name with a copy owned by the name allocator; the scratch
    // buffer is reused on the next iteration.
    valueNames[nameToReplace] = name.copy(usedNameAllocator);
  }
}
648
649
void SSANameState::numberValuesInRegion(
650
    Region &region,
651
0
    DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) {
652
0
  // Save the current value ids to allow for numbering values in sibling regions
653
0
  // the same.
654
0
  llvm::SaveAndRestore<unsigned> valueIDSaver(nextValueID);
655
0
  llvm::SaveAndRestore<unsigned> argumentIDSaver(nextArgumentID);
656
0
  llvm::SaveAndRestore<unsigned> conflictIDSaver(nextConflictID);
657
0
658
0
  // Push a new used names scope.
659
0
  llvm::ScopedHashTable<StringRef, char>::ScopeTy usedNamesScope(usedNames);
660
0
661
0
  // Number the values within this region in a breadth-first order.
662
0
  unsigned nextBlockID = 0;
663
0
  for (auto &block : region) {
664
0
    // Each block gets a unique ID, and all of the operations within it get
665
0
    // numbered as well.
666
0
    blockIDs[&block] = nextBlockID++;
667
0
    numberValuesInBlock(block, interfaces);
668
0
  }
669
0
670
0
  // After that we traverse the nested regions.
671
0
  // TODO: Rework this loop to not use recursion.
672
0
  for (auto &block : region) {
673
0
    for (auto &op : block)
674
0
      for (auto &nestedRegion : op.getRegions())
675
0
        numberValuesInRegion(nestedRegion, interfaces);
676
0
  }
677
0
}
678
679
void SSANameState::numberValuesInBlock(
680
    Block &block,
681
0
    DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) {
682
0
  auto setArgNameFn = [&](Value arg, StringRef name) {
683
0
    assert(!valueIDs.count(arg) && "arg numbered multiple times");
684
0
    assert(arg.cast<BlockArgument>().getOwner() == &block &&
685
0
           "arg not defined in 'block'");
686
0
    setValueName(arg, name);
687
0
  };
688
0
689
0
  bool isEntryBlock = block.isEntryBlock();
690
0
  if (isEntryBlock) {
691
0
    if (auto *op = block.getParentOp()) {
692
0
      if (auto asmInterface = interfaces.getInterfaceFor(op->getDialect()))
693
0
        asmInterface->getAsmBlockArgumentNames(&block, setArgNameFn);
694
0
    }
695
0
  }
696
0
697
0
  // Number the block arguments. We give entry block arguments a special name
698
0
  // 'arg'.
699
0
  SmallString<32> specialNameBuffer(isEntryBlock ? "arg" : "");
700
0
  llvm::raw_svector_ostream specialName(specialNameBuffer);
701
0
  for (auto arg : block.getArguments()) {
702
0
    if (valueIDs.count(arg))
703
0
      continue;
704
0
    if (isEntryBlock) {
705
0
      specialNameBuffer.resize(strlen("arg"));
706
0
      specialName << nextArgumentID++;
707
0
    }
708
0
    setValueName(arg, specialName.str());
709
0
  }
710
0
711
0
  // Number the operations in this block.
712
0
  for (auto &op : block)
713
0
    numberValuesInOp(op, interfaces);
714
0
}
715
716
void SSANameState::numberValuesInOp(
717
    Operation &op,
718
0
    DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) {
719
0
  unsigned numResults = op.getNumResults();
720
0
  if (numResults == 0)
721
0
    return;
722
0
  Value resultBegin = op.getResult(0);
723
0
724
0
  // Function used to set the special result names for the operation.
725
0
  SmallVector<int, 2> resultGroups(/*Size=*/1, /*Value=*/0);
726
0
  auto setResultNameFn = [&](Value result, StringRef name) {
727
0
    assert(!valueIDs.count(result) && "result numbered multiple times");
728
0
    assert(result.getDefiningOp() == &op && "result not defined by 'op'");
729
0
    setValueName(result, name);
730
0
731
0
    // Record the result number for groups not anchored at 0.
732
0
    if (int resultNo = result.cast<OpResult>().getResultNumber())
733
0
      resultGroups.push_back(resultNo);
734
0
  };
735
0
  if (OpAsmOpInterface asmInterface = dyn_cast<OpAsmOpInterface>(&op))
736
0
    asmInterface.getAsmResultNames(setResultNameFn);
737
0
  else if (auto *asmInterface = interfaces.getInterfaceFor(op.getDialect()))
738
0
    asmInterface->getAsmResultNames(&op, setResultNameFn);
739
0
740
0
  // If the first result wasn't numbered, give it a default number.
741
0
  if (valueIDs.try_emplace(resultBegin, nextValueID).second)
742
0
    ++nextValueID;
743
0
744
0
  // If this operation has multiple result groups, mark it.
745
0
  if (resultGroups.size() != 1) {
746
0
    llvm::array_pod_sort(resultGroups.begin(), resultGroups.end());
747
0
    opResultGroups.try_emplace(&op, std::move(resultGroups));
748
0
  }
749
0
}
750
751
void SSANameState::getResultIDAndNumber(OpResult result, Value &lookupValue,
752
0
                                        Optional<int> &lookupResultNo) const {
753
0
  Operation *owner = result.getOwner();
754
0
  if (owner->getNumResults() == 1)
755
0
    return;
756
0
  int resultNo = result.getResultNumber();
757
0
758
0
  // If this operation has multiple result groups, we will need to find the
759
0
  // one corresponding to this result.
760
0
  auto resultGroupIt = opResultGroups.find(owner);
761
0
  if (resultGroupIt == opResultGroups.end()) {
762
0
    // If not, just use the first result.
763
0
    lookupResultNo = resultNo;
764
0
    lookupValue = owner->getResult(0);
765
0
    return;
766
0
  }
767
0
768
0
  // Find the correct index using a binary search, as the groups are ordered.
769
0
  ArrayRef<int> resultGroups = resultGroupIt->second;
770
0
  auto it = llvm::upper_bound(resultGroups, resultNo);
771
0
  int groupResultNo = 0, groupSize = 0;
772
0
773
0
  // If there are no smaller elements, the last result group is the lookup.
774
0
  if (it == resultGroups.end()) {
775
0
    groupResultNo = resultGroups.back();
776
0
    groupSize = static_cast<int>(owner->getNumResults()) - resultGroups.back();
777
0
  } else {
778
0
    // Otherwise, the previous element is the lookup.
779
0
    groupResultNo = *std::prev(it);
780
0
    groupSize = *it - groupResultNo;
781
0
  }
782
0
783
0
  // We only record the result number for a group of size greater than 1.
784
0
  if (groupSize != 1)
785
0
    lookupResultNo = resultNo - groupResultNo;
786
0
  lookupValue = owner->getResult(groupResultNo);
787
0
}
788
789
0
/// Record the name for 'value'. A non-empty 'name' is uniqued and stored in
/// 'valueNames', with the ID marked as NameSentinel; an empty 'name' assigns
/// the next default numeric ID instead.
void SSANameState::setValueName(Value value, StringRef name) {
  if (!name.empty()) {
    valueIDs[value] = NameSentinel;
    valueNames[value] = uniqueValueName(name);
    return;
  }

  // No special name was provided: fall back to the default numbering.
  valueIDs[value] = nextValueID++;
}
799
800
/// Returns true if 'c' is one of the punctuation characters accepted here,
/// i.e. one of [$._-], and false for every other character.
static bool isPunct(char c) {
  switch (c) {
  case '$':
  case '.':
  case '_':
  case '-':
    return true;
  default:
    return false;
  }
}
805
806
0
/// Return a sanitized and unique version of 'name'. The returned string is
/// copied into 'usedNameAllocator' and registered in 'usedNames'.
///
/// Sanitization happens in up to three steps:
///  * a leading digit gets an underscore prefix,
///  * characters other than alphanumerics and [$._-] are escaped (a space
///    becomes '_', anything else its uppercase hexadecimal byte value),
///  * a name that is already in use gets a "_<N>" suffix.
StringRef SSANameState::uniqueValueName(StringRef name) {
  assert(!name.empty() && "Shouldn't have an empty name here");

  // Check to see if this name is valid.  If it starts with a digit, then it
  // could conflict with the autogenerated numeric ID's (we unique them in a
  // different map), so add an underscore prefix to avoid problems.
  // Note: llvm::isDigit/llvm::isAlnum are used instead of the <cctype>
  // functions because the latter have undefined behavior when handed a
  // negative 'char' (e.g. a byte of a UTF-8 multibyte sequence).
  if (llvm::isDigit(name[0])) {
    SmallString<16> tmpName("_");
    tmpName += name;
    return uniqueValueName(tmpName);
  }

  // Check to see if the name consists of all-valid identifiers.  If not, we
  // need to escape them.
  auto isValidChar = [](char ch) { return llvm::isAlnum(ch) || isPunct(ch); };
  if (!llvm::all_of(name, isValidChar)) {
    SmallString<16> tmpName;
    for (char ch : name) {
      if (isValidChar(ch))
        tmpName += ch;
      else if (ch == ' ')
        tmpName += '_';
      else
        tmpName += llvm::utohexstr(static_cast<unsigned char>(ch));
    }
    // The escaped name may now start with a digit, so run it through the
    // full check again.
    return uniqueValueName(tmpName);
  }

  // Check to see if this name is already unique.
  if (!usedNames.count(name)) {
    name = name.copy(usedNameAllocator);
  } else {
    // Otherwise, we had a conflict - probe until we find a unique name. This
    // is guaranteed to terminate (and usually in a single iteration) because it
    // generates new names by incrementing nextConflictID.
    SmallString<64> probeName(name);
    probeName.push_back('_');
    while (true) {
      // Drop the suffix of the previous attempt, keeping "<name>_".
      probeName.resize(name.size() + 1);
      probeName += llvm::utostr(nextConflictID++);
      if (!usedNames.count(probeName)) {
        name = StringRef(probeName).copy(usedNameAllocator);
        break;
      }
    }
  }

  usedNames.insert(name, char());
  return name;
}
859
860
//===----------------------------------------------------------------------===//
861
// AsmState
862
//===----------------------------------------------------------------------===//
863
864
namespace mlir {
865
namespace detail {
866
class AsmStateImpl {
867
public:
868
  explicit AsmStateImpl(Operation *op, AsmState::LocationMap *locationMap)
869
      : interfaces(op->getContext()), nameState(op, interfaces),
870
0
        locationMap(locationMap) {}
871
872
  /// Initialize the alias state to enable the printing of aliases.
873
0
  void initializeAliases(Operation *op) {
874
0
    aliasState.initialize(op, interfaces);
875
0
  }
876
877
  /// Get an instance of the OpAsmDialectInterface for the given dialect, or
878
  /// null if one wasn't registered.
879
0
  const OpAsmDialectInterface *getOpAsmInterface(Dialect *dialect) {
880
0
    return interfaces.getInterfaceFor(dialect);
881
0
  }
882
883
  /// Get the state used for aliases.
884
0
  AliasState &getAliasState() { return aliasState; }
885
886
  /// Get the state used for SSA names.
887
0
  SSANameState &getSSANameState() { return nameState; }
888
889
  /// Register the location, line and column, within the buffer that the given
890
  /// operation was printed at.
891
0
  void registerOperationLocation(Operation *op, unsigned line, unsigned col) {
892
0
    if (locationMap)
893
0
      (*locationMap)[op] = std::make_pair(line, col);
894
0
  }
895
896
private:
897
  /// Collection of OpAsm interfaces implemented in the context.
898
  DialectInterfaceCollection<OpAsmDialectInterface> interfaces;
899
900
  /// The state used for attribute and type aliases.
901
  AliasState aliasState;
902
903
  /// The state used for SSA value names.
904
  SSANameState nameState;
905
906
  /// An optional location map to be populated.
907
  AsmState::LocationMap *locationMap;
908
};
909
} // end namespace detail
910
} // end namespace mlir
911
912
/// Create the printing state for 'op'; 'locationMap' is forwarded to the
/// implementation and may be null.
AsmState::AsmState(Operation *op, LocationMap *locationMap)
    : impl(std::make_unique<AsmStateImpl>(op, locationMap)) {}

/// Defined out of line, in the file where AsmStateImpl is a complete type.
AsmState::~AsmState() = default;
915
916
//===----------------------------------------------------------------------===//
917
// ModulePrinter
918
//===----------------------------------------------------------------------===//
919
920
namespace {
921
class ModulePrinter {
922
public:
923
  ModulePrinter(raw_ostream &os, OpPrintingFlags flags = llvm::None,
924
                AsmStateImpl *state = nullptr)
925
0
      : os(os), printerFlags(flags), state(state) {}
926
  explicit ModulePrinter(ModulePrinter &printer)
927
      : os(printer.os), printerFlags(printer.printerFlags),
928
0
        state(printer.state) {}
929
930
  /// Returns the output stream of the printer.
931
0
  raw_ostream &getStream() { return os; }
932
933
  template <typename Container, typename UnaryFunctor>
934
0
  inline void interleaveComma(const Container &c, UnaryFunctor each_fn) const {
935
0
    llvm::interleaveComma(c, os, each_fn);
936
0
  }
Unexecuted instantiation: AsmPrinter.cpp:_ZNK12_GLOBAL__N_113ModulePrinter15interleaveCommaIN4llvm8ArrayRefISt4pairIN4mlir10IdentifierENS5_9AttributeEEEEZNS0_14printAttributeES7_NS0_15AttrTypeElisionEE3$_9EEvRKT_T0_
Unexecuted instantiation: AsmPrinter.cpp:_ZNK12_GLOBAL__N_113ModulePrinter15interleaveCommaIN4llvm8ArrayRefIN4mlir9AttributeEEEZNS0_14printAttributeES5_NS0_15AttrTypeElisionEE4$_10EEvRKT_T0_
Unexecuted instantiation: AsmPrinter.cpp:_ZNK12_GLOBAL__N_113ModulePrinter15interleaveCommaIN4llvm8ArrayRefIN4mlir4TypeEEEZNS0_9printTypeES5_E4$_14EEvRKT_T0_
Unexecuted instantiation: AsmPrinter.cpp:_ZNK12_GLOBAL__N_113ModulePrinter15interleaveCommaIN4llvm8ArrayRefIN4mlir4TypeEEEZNS0_9printTypeES5_E4$_15EEvRKT_T0_
Unexecuted instantiation: AsmPrinter.cpp:_ZNK12_GLOBAL__N_113ModulePrinter15interleaveCommaIN4llvm8ArrayRefIN4mlir4TypeEEEZNS0_9printTypeES5_E4$_16EEvRKT_T0_
Unexecuted instantiation: AsmPrinter.cpp:_ZNK12_GLOBAL__N_113ModulePrinter15interleaveCommaIN4llvm8ArrayRefIN4mlir10AffineExprEEEZNS0_14printAffineMapENS4_9AffineMapEE4$_19EEvRKT_T0_
Unexecuted instantiation: AsmPrinter.cpp:_ZNK12_GLOBAL__N_113ModulePrinter15interleaveCommaIN4mlir10ValueRangeEZNS_16OperationPrinter24printSuccessorAndUseListEPNS2_5BlockES3_E4$_26EEvRKT_T0_
Unexecuted instantiation: AsmPrinter.cpp:_ZNK12_GLOBAL__N_113ModulePrinter15interleaveCommaIN4mlir10ValueRangeEZNS_16OperationPrinter24printSuccessorAndUseListEPNS2_5BlockES3_E4$_27EEvRKT_T0_
Unexecuted instantiation: AsmPrinter.cpp:_ZNK12_GLOBAL__N_113ModulePrinter15interleaveCommaIN4llvm11SmallVectorISt4pairIN4mlir10IdentifierENS5_9AttributeEELj8EEEZNS0_21printOptionalAttrDictENS2_8ArrayRefIS8_EENSA_INS2_9StringRefEEEbE4$_18EEvRKT_T0_
Unexecuted instantiation: AsmPrinter.cpp:_ZNK12_GLOBAL__N_113ModulePrinter15interleaveCommaIN4mlir12OperandRangeEZNS_16OperationPrinter14printGenericOpEPNS2_9OperationEE4$_21EEvRKT_T0_
Unexecuted instantiation: AsmPrinter.cpp:_ZNK12_GLOBAL__N_113ModulePrinter15interleaveCommaIN4mlir14SuccessorRangeEZNS_16OperationPrinter14printGenericOpEPNS2_9OperationEE4$_22EEvRKT_T0_
Unexecuted instantiation: AsmPrinter.cpp:_ZNK12_GLOBAL__N_113ModulePrinter15interleaveCommaIN4llvm15MutableArrayRefIN4mlir6RegionEEEZNS_16OperationPrinter14printGenericOpEPNS4_9OperationEE4$_23EEvRKT_T0_
Unexecuted instantiation: AsmPrinter.cpp:_ZNK12_GLOBAL__N_113ModulePrinter15interleaveCommaIN4llvm8ArrayRefIN4mlir10AffineExprEEEZNS_16OperationPrinter22printAffineMapOfSSAIdsENS4_13AffineMapAttrENS4_10ValueRangeEE4$_29EEvRKT_T0_
Unexecuted instantiation: AsmPrinter.cpp:_ZNK12_GLOBAL__N_113ModulePrinter15interleaveCommaIN4llvm14iterator_rangeINS2_6detail23value_sequence_iteratorIiEEEEZNS_16OperationPrinter14printOperationEPN4mlir9OperationEE4$_20EEvRKT_T0_
Unexecuted instantiation: AsmPrinter.cpp:_ZNK12_GLOBAL__N_113ModulePrinter15interleaveCommaIN4llvm15MutableArrayRefIN4mlir13BlockArgumentEEEZNS_16OperationPrinter5printEPNS4_5BlockEbbE4$_24EEvRKT_T0_
Unexecuted instantiation: AsmPrinter.cpp:_ZNK12_GLOBAL__N_113ModulePrinter15interleaveCommaIN4llvm11SmallVectorISt4pairIjPN4mlir5BlockEELj4EEEZNS_16OperationPrinter5printES7_bbE4$_25EEvRKT_T0_
937
938
  /// This enum describes the different kinds of elision for the type of an
939
  /// attribute when printing it.
940
  enum class AttrTypeElision {
941
    /// The type must not be elided,
942
    Never,
943
    /// The type may be elided when it matches the default used in the parser
944
    /// (for example i64 is the default for integer attributes).
945
    May,
946
    /// The type must be elided.
947
    Must
948
  };
949
950
  /// Print the given attribute.
951
  void printAttribute(Attribute attr,
952
                      AttrTypeElision typeElision = AttrTypeElision::Never);
953
954
  void printType(Type type);
955
  void printLocation(LocationAttr loc);
956
957
  void printAffineMap(AffineMap map);
958
  void
959
  printAffineExpr(AffineExpr expr,
960
                  function_ref<void(unsigned, bool)> printValueName = nullptr);
961
  void printAffineConstraint(AffineExpr expr, bool isEq);
962
  void printIntegerSet(IntegerSet set);
963
964
protected:
965
  void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
966
                             ArrayRef<StringRef> elidedAttrs = {},
967
                             bool withKeyword = false);
968
  void printNamedAttribute(NamedAttribute attr);
969
  void printTrailingLocation(Location loc);
970
  void printLocationInternal(LocationAttr loc, bool pretty = false);
971
972
  /// Print a dense elements attribute. If 'allowHex' is true, a hex string is
973
  /// used instead of individual elements when the elements attr is large.
974
  void printDenseElementsAttr(DenseElementsAttr attr, bool allowHex);
975
976
  /// Print a dense string elements attribute.
977
  void printDenseStringElementsAttr(DenseStringElementsAttr attr);
978
979
  /// Print a dense elements attribute. If 'allowHex' is true, a hex string is
980
  /// used instead of individual elements when the elements attr is large.
981
  void printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr,
982
                                     bool allowHex);
983
984
  void printDialectAttribute(Attribute attr);
985
  void printDialectType(Type type);
986
987
  /// This enum is used to represent the binding strength of the enclosing
988
  /// context that an AffineExprStorage is being printed in, so we can
989
  /// intelligently produce parens.
990
  enum class BindingStrength {
991
    Weak,   // + and -
992
    Strong, // All other binary operators.
993
  };
994
  void printAffineExprInternal(
995
      AffineExpr expr, BindingStrength enclosingTightness,
996
      function_ref<void(unsigned, bool)> printValueName = nullptr);
997
998
  /// The output stream for the printer.
999
  raw_ostream &os;
1000
1001
  /// A set of flags to control the printer's behavior.
1002
  OpPrintingFlags printerFlags;
1003
1004
  /// An optional printer state for the module.
1005
  AsmStateImpl *state;
1006
1007
  /// A tracker for the number of new lines emitted during printing.
1008
  NewLineCounter newLine;
1009
};
1010
} // end anonymous namespace
1011
1012
0
/// Print 'loc' preceded by a single space, but only when the printer flags
/// request debug information.
void ModulePrinter::printTrailingLocation(Location loc) {
  if (printerFlags.shouldPrintDebugInfo()) {
    os << " ";
    printLocation(loc);
  }
}
1020
1021
0
/// Print the body of 'loc' without the enclosing "loc(...)". In 'pretty' mode
/// quotes and the "callsite"/"fused" keywords are dropped and call sites may
/// be split across lines.
void ModulePrinter::printLocationInternal(LocationAttr loc, bool pretty) {
  switch (loc.getKind()) {
  case StandardAttributes::OpaqueLocation:
    // Opaque locations are printed through their fallback location.
    printLocationInternal(loc.cast<OpaqueLoc>().getFallbackLocation(), pretty);
    break;
  case StandardAttributes::UnknownLocation:
    os << (pretty ? "[unknown]" : "unknown");
    break;
  case StandardAttributes::FileLineColLocation: {
    auto fileLoc = loc.cast<FileLineColLoc>();
    // The file name is only quoted in the non-pretty form.
    if (!pretty)
      os << "\"";
    os << fileLoc.getFilename();
    if (!pretty)
      os << "\"";
    os << ':' << fileLoc.getLine() << ':' << fileLoc.getColumn();
    break;
  }
  case StandardAttributes::NameLocation: {
    auto nameLoc = loc.cast<NameLoc>();
    os << '\"' << nameLoc.getName() << '\"';

    // The child location is only shown when it carries information.
    auto childLoc = nameLoc.getChildLoc();
    if (childLoc.isa<UnknownLoc>())
      break;
    os << '(';
    printLocationInternal(childLoc, pretty);
    os << ')';
    break;
  }
  case StandardAttributes::CallSiteLocation: {
    auto callLocation = loc.cast<CallSiteLoc>();
    auto caller = callLocation.getCaller();
    auto callee = callLocation.getCallee();
    if (pretty) {
      printLocationInternal(callee, pretty);
      // Stay on the same line only for a named callee followed directly by a
      // file/line/column caller; otherwise start a new line first.
      if (!(callee.isa<NameLoc>() && caller.isa<FileLineColLoc>()))
        os << newLine;
      os << " at ";
      printLocationInternal(caller, pretty);
    } else {
      os << "callsite(";
      printLocationInternal(callee, pretty);
      os << " at ";
      printLocationInternal(caller, pretty);
      os << ")";
    }
    break;
  }
  case StandardAttributes::FusedLocation: {
    auto fusedLoc = loc.cast<FusedLoc>();
    if (!pretty)
      os << "fused";
    if (auto metadata = fusedLoc.getMetadata())
      os << '<' << metadata << '>';
    os << '[';
    interleave(
        fusedLoc.getLocations(),
        [&](Location nested) { printLocationInternal(nested, pretty); },
        [&]() { os << ", "; });
    os << ']';
    break;
  }
  }
}
1093
1094
/// Print a floating point value in a way that the parser will be able to
1095
/// round-trip losslessly.
1096
0
static void printFloatValue(const APFloat &apValue, raw_ostream &os) {
1097
0
  // We would like to output the FP constant value in exponential notation,
1098
0
  // but we cannot do this if doing so will lose precision.  Check here to
1099
0
  // make sure that we only output it in exponential format if we can parse
1100
0
  // the value back and get the same value.
1101
0
  bool isInf = apValue.isInfinity();
1102
0
  bool isNaN = apValue.isNaN();
1103
0
  if (!isInf && !isNaN) {
1104
0
    SmallString<128> strValue;
1105
0
    apValue.toString(strValue, /*FormatPrecision=*/6, /*FormatMaxPadding=*/0,
1106
0
                     /*TruncateZero=*/false);
1107
0
1108
0
    // Check to make sure that the stringized number is not some string like
1109
0
    // "Inf" or NaN, that atof will accept, but the lexer will not.  Check
1110
0
    // that the string matches the "[-+]?[0-9]" regex.
1111
0
    assert(((strValue[0] >= '0' && strValue[0] <= '9') ||
1112
0
            ((strValue[0] == '-' || strValue[0] == '+') &&
1113
0
             (strValue[1] >= '0' && strValue[1] <= '9'))) &&
1114
0
           "[-+]?[0-9] regex does not match!");
1115
0
1116
0
    // Parse back the stringized version and check that the value is equal
1117
0
    // (i.e., there is no precision loss).
1118
0
    if (APFloat(apValue.getSemantics(), strValue).bitwiseIsEqual(apValue)) {
1119
0
      os << strValue;
1120
0
      return;
1121
0
    }
1122
0
1123
0
    // If it is not, use the default format of APFloat instead of the
1124
0
    // exponential notation.
1125
0
    strValue.clear();
1126
0
    apValue.toString(strValue);
1127
0
1128
0
    // Make sure that we can parse the default form as a float.
1129
0
    if (StringRef(strValue).contains('.')) {
1130
0
      os << strValue;
1131
0
      return;
1132
0
    }
1133
0
  }
1134
0
1135
0
  // Print special values in hexadecimal format. The sign bit should be included
1136
0
  // in the literal.
1137
0
  SmallVector<char, 16> str;
1138
0
  APInt apInt = apValue.bitcastToAPInt();
1139
0
  apInt.toString(str, /*Radix=*/16, /*Signed=*/false,
1140
0
                 /*formatAsCLiteral=*/true);
1141
0
  os << str;
1142
0
}
1143
1144
0
/// Print 'loc' either in the pretty form or wrapped in "loc(...)", depending
/// on the printer flags.
void ModulePrinter::printLocation(LocationAttr loc) {
  if (!printerFlags.shouldPrintDebugInfoPrettyForm()) {
    os << "loc(";
    printLocationInternal(loc);
    os << ')';
    return;
  }
  printLocationInternal(loc, /*pretty=*/true);
}
1153
1154
/// Returns if the given dialect symbol data is simple enough to print in the
1155
/// pretty form, i.e. without the enclosing "".
1156
0
static bool isDialectSymbolSimpleEnoughForPrettyForm(StringRef symName) {
1157
0
  // The name must start with an identifier.
1158
0
  if (symName.empty() || !isalpha(symName.front()))
1159
0
    return false;
1160
0
1161
0
  // Ignore all the characters that are valid in an identifier in the symbol
1162
0
  // name.
1163
0
  symName = symName.drop_while(
1164
0
      [](char c) { return llvm::isAlnum(c) || c == '.' || c == '_'; });
1165
0
  if (symName.empty())
1166
0
    return true;
1167
0
1168
0
  // If we got to an unexpected character, then it must be a <>.  Check those
1169
0
  // recursively.
1170
0
  if (symName.front() != '<' || symName.back() != '>')
1171
0
    return false;
1172
0
1173
0
  SmallVector<char, 8> nestedPunctuation;
1174
0
  do {
1175
0
    // If we ran out of characters, then we had a punctuation mismatch.
1176
0
    if (symName.empty())
1177
0
      return false;
1178
0
1179
0
    auto c = symName.front();
1180
0
    symName = symName.drop_front();
1181
0
1182
0
    switch (c) {
1183
0
    // We never allow null characters. This is an EOF indicator for the lexer
1184
0
    // which we could handle, but isn't important for any known dialect.
1185
0
    case '\0':
1186
0
      return false;
1187
0
    case '<':
1188
0
    case '[':
1189
0
    case '(':
1190
0
    case '{':
1191
0
      nestedPunctuation.push_back(c);
1192
0
      continue;
1193
0
    case '-':
1194
0
      // Treat `->` as a special token.
1195
0
      if (!symName.empty() && symName.front() == '>') {
1196
0
        symName = symName.drop_front();
1197
0
        continue;
1198
0
      }
1199
0
      break;
1200
0
    // Reject types with mismatched brackets.
1201
0
    case '>':
1202
0
      if (nestedPunctuation.pop_back_val() != '<')
1203
0
        return false;
1204
0
      break;
1205
0
    case ']':
1206
0
      if (nestedPunctuation.pop_back_val() != '[')
1207
0
        return false;
1208
0
      break;
1209
0
    case ')':
1210
0
      if (nestedPunctuation.pop_back_val() != '(')
1211
0
        return false;
1212
0
      break;
1213
0
    case '}':
1214
0
      if (nestedPunctuation.pop_back_val() != '{')
1215
0
        return false;
1216
0
      break;
1217
0
    default:
1218
0
      continue;
1219
0
    }
1220
0
1221
0
    // We're done when the punctuation is fully matched.
1222
0
  } while (!nestedPunctuation.empty());
1223
0
1224
0
  // If there were extra characters, then we failed.
1225
0
  return symName.empty();
1226
0
}
1227
1228
/// Print the given dialect symbol to the stream.
1229
static void printDialectSymbol(raw_ostream &os, StringRef symPrefix,
1230
0
                               StringRef dialectName, StringRef symString) {
1231
0
  os << symPrefix << dialectName;
1232
0
1233
0
  // If this symbol name is simple enough, print it directly in pretty form,
1234
0
  // otherwise, we print it as an escaped string.
1235
0
  if (isDialectSymbolSimpleEnoughForPrettyForm(symString)) {
1236
0
    os << '.' << symString;
1237
0
    return;
1238
0
  }
1239
0
1240
0
  // TODO: escape the symbol name, it could contain " characters.
1241
0
  os << "<\"" << symString << "\">";
1242
0
}
1243
1244
/// Returns if the given string can be represented as a bare identifier.
1245
0
static bool isBareIdentifier(StringRef name) {
1246
0
  assert(!name.empty() && "invalid name");
1247
0
1248
0
  // By making this unsigned, the value passed in to isalnum will always be
1249
0
  // in the range 0-255. This is important when building with MSVC because
1250
0
  // its implementation will assert. This situation can arise when dealing
1251
0
  // with UTF-8 multibyte characters.
1252
0
  unsigned char firstChar = static_cast<unsigned char>(name[0]);
1253
0
  if (!isalpha(firstChar) && firstChar != '_')
1254
0
    return false;
1255
0
  return llvm::all_of(name.drop_front(), [](unsigned char c) {
1256
0
    return isalnum(c) || c == '_' || c == '$' || c == '.';
1257
0
  });
1258
0
}
1259
1260
/// Print the given string as a symbol reference. A symbol reference is
1261
/// represented as a string prefixed with '@'. The reference is surrounded with
1262
/// ""'s and escaped if it has any special or non-printable characters in it.
1263
0
static void printSymbolReference(StringRef symbolRef, raw_ostream &os) {
1264
0
  assert(!symbolRef.empty() && "expected valid symbol reference");
1265
0
1266
0
  // If the symbol can be represented as a bare identifier, write it directly.
1267
0
  if (isBareIdentifier(symbolRef)) {
1268
0
    os << '@' << symbolRef;
1269
0
    return;
1270
0
  }
1271
0
1272
0
  // Otherwise, output the reference wrapped in quotes with proper escaping.
1273
0
  os << "@\"";
1274
0
  printEscapedString(symbolRef, os);
1275
0
  os << '"';
1276
0
}
1277
1278
// Print out a valid ElementsAttr that is succinct and can represent any
1279
// potential shape/type, for use when eliding a large ElementsAttr.
1280
//
1281
// We choose to use an opaque ElementsAttr literal with conspicuous content to
1282
// hopefully alert readers to the fact that this has been elided.
1283
//
1284
// Unfortunately, neither of the strings of an opaque ElementsAttr literal will
1285
// accept the string "elided". The first string must be a registered dialect
1286
// name and the latter must be a hex constant.
1287
0
static void printElidedElementsAttr(raw_ostream &os) {
1288
0
  os << R"(opaque<"", "0xDEADBEEF">)";
1289
0
}
1290
1291
void ModulePrinter::printAttribute(Attribute attr,
1292
0
                                   AttrTypeElision typeElision) {
1293
0
  if (!attr) {
1294
0
    os << "<<NULL ATTRIBUTE>>";
1295
0
    return;
1296
0
  }
1297
0
1298
0
  // Check for an alias for this attribute.
1299
0
  if (state) {
1300
0
    Twine alias = state->getAliasState().getAttributeAlias(attr);
1301
0
    if (!alias.isTriviallyEmpty()) {
1302
0
      os << '#' << alias;
1303
0
      return;
1304
0
    }
1305
0
  }
1306
0
1307
0
  auto attrType = attr.getType();
1308
0
  switch (attr.getKind()) {
1309
0
  default:
1310
0
    return printDialectAttribute(attr);
1311
0
1312
0
  case StandardAttributes::Opaque: {
1313
0
    auto opaqueAttr = attr.cast<OpaqueAttr>();
1314
0
    printDialectSymbol(os, "#", opaqueAttr.getDialectNamespace(),
1315
0
                       opaqueAttr.getAttrData());
1316
0
    break;
1317
0
  }
1318
0
  case StandardAttributes::Unit:
1319
0
    os << "unit";
1320
0
    break;
1321
0
  case StandardAttributes::Bool:
1322
0
    os << (attr.cast<BoolAttr>().getValue() ? "true" : "false");
1323
0
1324
0
    // BoolAttr always elides the type.
1325
0
    return;
1326
0
  case StandardAttributes::Dictionary:
1327
0
    os << '{';
1328
0
    interleaveComma(attr.cast<DictionaryAttr>().getValue(),
1329
0
                    [&](NamedAttribute attr) { printNamedAttribute(attr); });
1330
0
    os << '}';
1331
0
    break;
1332
0
  case StandardAttributes::Integer: {
1333
0
    auto intAttr = attr.cast<IntegerAttr>();
1334
0
    // Only print attributes as unsigned if they are explicitly unsigned or are
1335
0
    // signless 1-bit values.  Indexes, signed values, and multi-bit signless
1336
0
    // values print as signed.
1337
0
    bool isUnsigned =
1338
0
        attrType.isUnsignedInteger() || attrType.isSignlessInteger(1);
1339
0
    intAttr.getValue().print(os, !isUnsigned);
1340
0
1341
0
    // IntegerAttr elides the type if I64.
1342
0
    if (typeElision == AttrTypeElision::May && attrType.isSignlessInteger(64))
1343
0
      return;
1344
0
    break;
1345
0
  }
1346
0
  case StandardAttributes::Float: {
1347
0
    auto floatAttr = attr.cast<FloatAttr>();
1348
0
    printFloatValue(floatAttr.getValue(), os);
1349
0
1350
0
    // FloatAttr elides the type if F64.
1351
0
    if (typeElision == AttrTypeElision::May && attrType.isF64())
1352
0
      return;
1353
0
    break;
1354
0
  }
1355
0
  case StandardAttributes::String:
1356
0
    os << '"';
1357
0
    printEscapedString(attr.cast<StringAttr>().getValue(), os);
1358
0
    os << '"';
1359
0
    break;
1360
0
  case StandardAttributes::Array:
1361
0
    os << '[';
1362
0
    interleaveComma(attr.cast<ArrayAttr>().getValue(), [&](Attribute attr) {
1363
0
      printAttribute(attr, AttrTypeElision::May);
1364
0
    });
1365
0
    os << ']';
1366
0
    break;
1367
0
  case StandardAttributes::AffineMap:
1368
0
    os << "affine_map<";
1369
0
    attr.cast<AffineMapAttr>().getValue().print(os);
1370
0
    os << '>';
1371
0
1372
0
    // AffineMap always elides the type.
1373
0
    return;
1374
0
  case StandardAttributes::IntegerSet:
1375
0
    os << "affine_set<";
1376
0
    attr.cast<IntegerSetAttr>().getValue().print(os);
1377
0
    os << '>';
1378
0
1379
0
    // IntegerSet always elides the type.
1380
0
    return;
1381
0
  case StandardAttributes::Type:
1382
0
    printType(attr.cast<TypeAttr>().getValue());
1383
0
    break;
1384
0
  case StandardAttributes::SymbolRef: {
1385
0
    auto refAttr = attr.dyn_cast<SymbolRefAttr>();
1386
0
    printSymbolReference(refAttr.getRootReference(), os);
1387
0
    for (FlatSymbolRefAttr nestedRef : refAttr.getNestedReferences()) {
1388
0
      os << "::";
1389
0
      printSymbolReference(nestedRef.getValue(), os);
1390
0
    }
1391
0
    break;
1392
0
  }
1393
0
  case StandardAttributes::OpaqueElements: {
1394
0
    auto eltsAttr = attr.cast<OpaqueElementsAttr>();
1395
0
    if (printerFlags.shouldElideElementsAttr(eltsAttr)) {
1396
0
      printElidedElementsAttr(os);
1397
0
      break;
1398
0
    }
1399
0
    os << "opaque<\"" << eltsAttr.getDialect()->getNamespace() << "\", ";
1400
0
    os << '"' << "0x" << llvm::toHex(eltsAttr.getValue()) << "\">";
1401
0
    break;
1402
0
  }
1403
0
  case StandardAttributes::DenseIntOrFPElements: {
1404
0
    auto eltsAttr = attr.cast<DenseIntOrFPElementsAttr>();
1405
0
    if (printerFlags.shouldElideElementsAttr(eltsAttr)) {
1406
0
      printElidedElementsAttr(os);
1407
0
      break;
1408
0
    }
1409
0
    os << "dense<";
1410
0
    printDenseIntOrFPElementsAttr(eltsAttr, /*allowHex=*/true);
1411
0
    os << '>';
1412
0
    break;
1413
0
  }
1414
0
  case StandardAttributes::DenseStringElements: {
1415
0
    auto eltsAttr = attr.cast<DenseStringElementsAttr>();
1416
0
    if (printerFlags.shouldElideElementsAttr(eltsAttr)) {
1417
0
      printElidedElementsAttr(os);
1418
0
      break;
1419
0
    }
1420
0
    os << "dense<";
1421
0
    printDenseStringElementsAttr(eltsAttr);
1422
0
    os << '>';
1423
0
    break;
1424
0
  }
1425
0
  case StandardAttributes::SparseElements: {
1426
0
    auto elementsAttr = attr.cast<SparseElementsAttr>();
1427
0
    if (printerFlags.shouldElideElementsAttr(elementsAttr.getIndices()) ||
1428
0
        printerFlags.shouldElideElementsAttr(elementsAttr.getValues())) {
1429
0
      printElidedElementsAttr(os);
1430
0
      break;
1431
0
    }
1432
0
    os << "sparse<";
1433
0
    printDenseIntOrFPElementsAttr(elementsAttr.getIndices(),
1434
0
                                  /*allowHex=*/false);
1435
0
    os << ", ";
1436
0
    printDenseElementsAttr(elementsAttr.getValues(), /*allowHex=*/true);
1437
0
    os << '>';
1438
0
    break;
1439
0
  }
1440
0
1441
0
  // Location attributes.
1442
0
  case StandardAttributes::CallSiteLocation:
1443
0
  case StandardAttributes::FileLineColLocation:
1444
0
  case StandardAttributes::FusedLocation:
1445
0
  case StandardAttributes::NameLocation:
1446
0
  case StandardAttributes::OpaqueLocation:
1447
0
  case StandardAttributes::UnknownLocation:
1448
0
    printLocation(attr.cast<LocationAttr>());
1449
0
    break;
1450
0
  }
1451
0
1452
0
  // Don't print the type if we must elide it, or if it is a None type.
1453
0
  if (typeElision != AttrTypeElision::Must && !attrType.isa<NoneType>()) {
1454
0
    os << " : ";
1455
0
    printType(attrType);
1456
0
  }
1457
0
}
1458
1459
/// Print the integer element of a DenseElementsAttr.
1460
static void printDenseIntElement(const APInt &value, raw_ostream &os,
1461
0
                                 bool isSigned) {
1462
0
  if (value.getBitWidth() == 1)
1463
0
    os << (value.getBoolValue() ? "true" : "false");
1464
0
  else
1465
0
    value.print(os, isSigned);
1466
0
}
1467
1468
static void
1469
printDenseElementsAttrImpl(bool isSplat, ShapedType type, raw_ostream &os,
1470
0
                           function_ref<void(unsigned)> printEltFn) {
1471
0
  // Special case for 0-d and splat tensors.
1472
0
  if (isSplat)
1473
0
    return printEltFn(0);
1474
0
1475
0
  // Special case for degenerate tensors.
1476
0
  auto numElements = type.getNumElements();
1477
0
  int64_t rank = type.getRank();
1478
0
  if (numElements == 0) {
1479
0
    for (int i = 0; i < rank; ++i)
1480
0
      os << '[';
1481
0
    for (int i = 0; i < rank; ++i)
1482
0
      os << ']';
1483
0
    return;
1484
0
  }
1485
0
1486
0
  // We use a mixed-radix counter to iterate through the shape. When we bump a
1487
0
  // non-least-significant digit, we emit a close bracket. When we next emit an
1488
0
  // element we re-open all closed brackets.
1489
0
1490
0
  // The mixed-radix counter, with radices in 'shape'.
1491
0
  SmallVector<unsigned, 4> counter(rank, 0);
1492
0
  // The number of brackets that have been opened and not closed.
1493
0
  unsigned openBrackets = 0;
1494
0
1495
0
  auto shape = type.getShape();
1496
0
  auto bumpCounter = [&] {
1497
0
    // Bump the least significant digit.
1498
0
    ++counter[rank - 1];
1499
0
    // Iterate backwards bubbling back the increment.
1500
0
    for (unsigned i = rank - 1; i > 0; --i)
1501
0
      if (counter[i] >= shape[i]) {
1502
0
        // Index 'i' is rolled over. Bump (i-1) and close a bracket.
1503
0
        counter[i] = 0;
1504
0
        ++counter[i - 1];
1505
0
        --openBrackets;
1506
0
        os << ']';
1507
0
      }
1508
0
  };
1509
0
1510
0
  for (unsigned idx = 0, e = numElements; idx != e; ++idx) {
1511
0
    if (idx != 0)
1512
0
      os << ", ";
1513
0
    while (openBrackets++ < rank)
1514
0
      os << '[';
1515
0
    openBrackets = rank;
1516
0
    printEltFn(idx);
1517
0
    bumpCounter();
1518
0
  }
1519
0
  while (openBrackets-- > 0)
1520
0
    os << ']';
1521
0
}
1522
1523
void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr,
1524
0
                                           bool allowHex) {
1525
0
  if (auto stringAttr = attr.dyn_cast<DenseStringElementsAttr>())
1526
0
    return printDenseStringElementsAttr(stringAttr);
1527
0
1528
0
  printDenseIntOrFPElementsAttr(attr.cast<DenseIntOrFPElementsAttr>(),
1529
0
                                allowHex);
1530
0
}
1531
1532
void ModulePrinter::printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr,
1533
0
                                                  bool allowHex) {
1534
0
  auto type = attr.getType();
1535
0
  auto elementType = type.getElementType();
1536
0
1537
0
  // Check to see if we should format this attribute as a hex string.
1538
0
  auto numElements = type.getNumElements();
1539
0
  if (!attr.isSplat() && allowHex &&
1540
0
      shouldPrintElementsAttrWithHex(numElements)) {
1541
0
    ArrayRef<char> rawData = attr.getRawData();
1542
0
    os << '"' << "0x" << llvm::toHex(StringRef(rawData.data(), rawData.size()))
1543
0
       << "\"";
1544
0
    return;
1545
0
  }
1546
0
1547
0
  if (ComplexType complexTy = elementType.dyn_cast<ComplexType>()) {
1548
0
    auto printComplexValue = [&](auto complexValues, auto printFn,
1549
0
                                 raw_ostream &os, auto &&... params) {
1550
0
      printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
1551
0
        auto complexValue = *(complexValues.begin() + index);
1552
0
        os << "(";
1553
0
        printFn(complexValue.real(), os, params...);
1554
0
        os << ",";
1555
0
        printFn(complexValue.imag(), os, params...);
1556
0
        os << ")";
1557
0
      });
Unexecuted instantiation: AsmPrinter.cpp:_ZZZN12_GLOBAL__N_113ModulePrinter29printDenseIntOrFPElementsAttrEN4mlir24DenseIntOrFPElementsAttrEbENK3$_1clIN4llvm14iterator_rangeINS1_17DenseElementsAttr25ComplexIntElementIteratorEEEPFvRKNS5_5APIntERNS5_11raw_ostreamEbEJbEEEDaT_T0_SE_DpOT1_ENKUljE_clEj
Unexecuted instantiation: AsmPrinter.cpp:_ZZZN12_GLOBAL__N_113ModulePrinter29printDenseIntOrFPElementsAttrEN4mlir24DenseIntOrFPElementsAttrEbENK3$_1clIN4llvm14iterator_rangeINS1_17DenseElementsAttr27ComplexFloatElementIteratorEEEPFvRKNS5_7APFloatERNS5_11raw_ostreamEEJEEEDaT_T0_SE_DpOT1_ENKUljE_clEj
1558
0
    };
Unexecuted instantiation: AsmPrinter.cpp:_ZZN12_GLOBAL__N_113ModulePrinter29printDenseIntOrFPElementsAttrEN4mlir24DenseIntOrFPElementsAttrEbENK3$_1clIN4llvm14iterator_rangeINS1_17DenseElementsAttr25ComplexIntElementIteratorEEEPFvRKNS5_5APIntERNS5_11raw_ostreamEbEJbEEEDaT_T0_SE_DpOT1_
Unexecuted instantiation: AsmPrinter.cpp:_ZZN12_GLOBAL__N_113ModulePrinter29printDenseIntOrFPElementsAttrEN4mlir24DenseIntOrFPElementsAttrEbENK3$_1clIN4llvm14iterator_rangeINS1_17DenseElementsAttr27ComplexFloatElementIteratorEEEPFvRKNS5_7APFloatERNS5_11raw_ostreamEEJEEEDaT_T0_SE_DpOT1_
1559
0
1560
0
    Type complexElementType = complexTy.getElementType();
1561
0
    if (complexElementType.isa<IntegerType>())
1562
0
      printComplexValue(attr.getComplexIntValues(), printDenseIntElement, os,
1563
0
                        /*isSigned=*/!complexElementType.isUnsignedInteger());
1564
0
    else
1565
0
      printComplexValue(attr.getComplexFloatValues(), printFloatValue, os);
1566
0
  } else if (elementType.isIntOrIndex()) {
1567
0
    bool isSigned = !elementType.isUnsignedInteger();
1568
0
    auto intValues = attr.getIntValues();
1569
0
    printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
1570
0
      printDenseIntElement(*(intValues.begin() + index), os, isSigned);
1571
0
    });
1572
0
  } else {
1573
0
    assert(elementType.isa<FloatType>() && "unexpected element type");
1574
0
    auto floatValues = attr.getFloatValues();
1575
0
    printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
1576
0
      printFloatValue(*(floatValues.begin() + index), os);
1577
0
    });
1578
0
  }
1579
0
}
1580
1581
0
void ModulePrinter::printDenseStringElementsAttr(DenseStringElementsAttr attr) {
1582
0
  ArrayRef<StringRef> data = attr.getRawStringData();
1583
0
  auto printFn = [&](unsigned index) {
1584
0
    os << "\"";
1585
0
    printEscapedString(data[index], os);
1586
0
    os << "\"";
1587
0
  };
1588
0
  printDenseElementsAttrImpl(attr.isSplat(), attr.getType(), os, printFn);
1589
0
}
1590
1591
0
void ModulePrinter::printType(Type type) {
1592
0
  if (!type) {
1593
0
    os << "<<NULL TYPE>>";
1594
0
    return;
1595
0
  }
1596
0
1597
0
  // Check for an alias for this type.
1598
0
  if (state) {
1599
0
    StringRef alias = state->getAliasState().getTypeAlias(type);
1600
0
    if (!alias.empty()) {
1601
0
      os << '!' << alias;
1602
0
      return;
1603
0
    }
1604
0
  }
1605
0
1606
0
  switch (type.getKind()) {
1607
0
  default:
1608
0
    return printDialectType(type);
1609
0
1610
0
  case Type::Kind::Opaque: {
1611
0
    auto opaqueTy = type.cast<OpaqueType>();
1612
0
    printDialectSymbol(os, "!", opaqueTy.getDialectNamespace(),
1613
0
                       opaqueTy.getTypeData());
1614
0
    return;
1615
0
  }
1616
0
  case StandardTypes::Index:
1617
0
    os << "index";
1618
0
    return;
1619
0
  case StandardTypes::BF16:
1620
0
    os << "bf16";
1621
0
    return;
1622
0
  case StandardTypes::F16:
1623
0
    os << "f16";
1624
0
    return;
1625
0
  case StandardTypes::F32:
1626
0
    os << "f32";
1627
0
    return;
1628
0
  case StandardTypes::F64:
1629
0
    os << "f64";
1630
0
    return;
1631
0
1632
0
  case StandardTypes::Integer: {
1633
0
    auto integer = type.cast<IntegerType>();
1634
0
    if (integer.isSigned())
1635
0
      os << 's';
1636
0
    else if (integer.isUnsigned())
1637
0
      os << 'u';
1638
0
    os << 'i' << integer.getWidth();
1639
0
    return;
1640
0
  }
1641
0
  case Type::Kind::Function: {
1642
0
    auto func = type.cast<FunctionType>();
1643
0
    os << '(';
1644
0
    interleaveComma(func.getInputs(), [&](Type type) { printType(type); });
1645
0
    os << ") -> ";
1646
0
    auto results = func.getResults();
1647
0
    if (results.size() == 1 && !results[0].isa<FunctionType>())
1648
0
      os << results[0];
1649
0
    else {
1650
0
      os << '(';
1651
0
      interleaveComma(results, [&](Type type) { printType(type); });
1652
0
      os << ')';
1653
0
    }
1654
0
    return;
1655
0
  }
1656
0
  case StandardTypes::Vector: {
1657
0
    auto v = type.cast<VectorType>();
1658
0
    os << "vector<";
1659
0
    for (auto dim : v.getShape())
1660
0
      os << dim << 'x';
1661
0
    os << v.getElementType() << '>';
1662
0
    return;
1663
0
  }
1664
0
  case StandardTypes::RankedTensor: {
1665
0
    auto v = type.cast<RankedTensorType>();
1666
0
    os << "tensor<";
1667
0
    for (auto dim : v.getShape()) {
1668
0
      if (dim < 0)
1669
0
        os << '?';
1670
0
      else
1671
0
        os << dim;
1672
0
      os << 'x';
1673
0
    }
1674
0
    os << v.getElementType() << '>';
1675
0
    return;
1676
0
  }
1677
0
  case StandardTypes::UnrankedTensor: {
1678
0
    auto v = type.cast<UnrankedTensorType>();
1679
0
    os << "tensor<*x";
1680
0
    printType(v.getElementType());
1681
0
    os << '>';
1682
0
    return;
1683
0
  }
1684
0
  case StandardTypes::MemRef: {
1685
0
    auto v = type.cast<MemRefType>();
1686
0
    os << "memref<";
1687
0
    for (auto dim : v.getShape()) {
1688
0
      if (dim < 0)
1689
0
        os << '?';
1690
0
      else
1691
0
        os << dim;
1692
0
      os << 'x';
1693
0
    }
1694
0
    printType(v.getElementType());
1695
0
    for (auto map : v.getAffineMaps()) {
1696
0
      os << ", ";
1697
0
      printAttribute(AffineMapAttr::get(map));
1698
0
    }
1699
0
    // Only print the memory space if it is the non-default one.
1700
0
    if (v.getMemorySpace())
1701
0
      os << ", " << v.getMemorySpace();
1702
0
    os << '>';
1703
0
    return;
1704
0
  }
1705
0
  case StandardTypes::UnrankedMemRef: {
1706
0
    auto v = type.cast<UnrankedMemRefType>();
1707
0
    os << "memref<*x";
1708
0
    printType(v.getElementType());
1709
0
    os << '>';
1710
0
    return;
1711
0
  }
1712
0
  case StandardTypes::Complex:
1713
0
    os << "complex<";
1714
0
    printType(type.cast<ComplexType>().getElementType());
1715
0
    os << '>';
1716
0
    return;
1717
0
  case StandardTypes::Tuple: {
1718
0
    auto tuple = type.cast<TupleType>();
1719
0
    os << "tuple<";
1720
0
    interleaveComma(tuple.getTypes(), [&](Type type) { printType(type); });
1721
0
    os << '>';
1722
0
    return;
1723
0
  }
1724
0
  case StandardTypes::None:
1725
0
    os << "none";
1726
0
    return;
1727
0
  }
1728
0
}
1729
1730
void ModulePrinter::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
1731
                                          ArrayRef<StringRef> elidedAttrs,
1732
0
                                          bool withKeyword) {
1733
0
  // If there are no attributes, then there is nothing to be done.
1734
0
  if (attrs.empty())
1735
0
    return;
1736
0
1737
0
  // Filter out any attributes that shouldn't be included.
1738
0
  SmallVector<NamedAttribute, 8> filteredAttrs(
1739
0
      llvm::make_filter_range(attrs, [&](NamedAttribute attr) {
1740
0
        return !llvm::is_contained(elidedAttrs, attr.first.strref());
1741
0
      }));
1742
0
1743
0
  // If there are no attributes left to print after filtering, then we're done.
1744
0
  if (filteredAttrs.empty())
1745
0
    return;
1746
0
1747
0
  // Print the 'attributes' keyword if necessary.
1748
0
  if (withKeyword)
1749
0
    os << " attributes";
1750
0
1751
0
  // Otherwise, print them all out in braces.
1752
0
  os << " {";
1753
0
  interleaveComma(filteredAttrs,
1754
0
                  [&](NamedAttribute attr) { printNamedAttribute(attr); });
1755
0
  os << '}';
1756
0
}
1757
1758
0
void ModulePrinter::printNamedAttribute(NamedAttribute attr) {
1759
0
  if (isBareIdentifier(attr.first)) {
1760
0
    os << attr.first;
1761
0
  } else {
1762
0
    os << '"';
1763
0
    printEscapedString(attr.first.strref(), os);
1764
0
    os << '"';
1765
0
  }
1766
0
1767
0
  // Pretty printing elides the attribute value for unit attributes.
1768
0
  if (attr.second.isa<UnitAttr>())
1769
0
    return;
1770
0
1771
0
  os << " = ";
1772
0
  printAttribute(attr.second);
1773
0
}
1774
1775
//===----------------------------------------------------------------------===//
1776
// CustomDialectAsmPrinter
1777
//===----------------------------------------------------------------------===//
1778
1779
namespace {
1780
/// This class provides the main specialization of the DialectAsmPrinter that is
1781
/// used to provide support for printing attributes and types. This hook allows
1782
/// for dialects to hook into the main ModulePrinter.
1783
struct CustomDialectAsmPrinter : public DialectAsmPrinter {
1784
public:
1785
0
  CustomDialectAsmPrinter(ModulePrinter &printer) : printer(printer) {}
1786
0
  ~CustomDialectAsmPrinter() override {}
1787
1788
0
  raw_ostream &getStream() const override { return printer.getStream(); }
1789
1790
  /// Print the given attribute to the stream.
1791
0
  void printAttribute(Attribute attr) override { printer.printAttribute(attr); }
1792
1793
  /// Print the given floating point value in a stabilized form.
1794
0
  void printFloat(const APFloat &value) override {
1795
0
    printFloatValue(value, getStream());
1796
0
  }
1797
1798
  /// Print the given type to the stream.
1799
0
  void printType(Type type) override { printer.printType(type); }
1800
1801
  /// The main module printer.
1802
  ModulePrinter &printer;
1803
};
1804
} // end anonymous namespace
1805
1806
0
void ModulePrinter::printDialectAttribute(Attribute attr) {
1807
0
  auto &dialect = attr.getDialect();
1808
0
1809
0
  // Ask the dialect to serialize the attribute to a string.
1810
0
  std::string attrName;
1811
0
  {
1812
0
    llvm::raw_string_ostream attrNameStr(attrName);
1813
0
    ModulePrinter subPrinter(attrNameStr, printerFlags, state);
1814
0
    CustomDialectAsmPrinter printer(subPrinter);
1815
0
    dialect.printAttribute(attr, printer);
1816
0
  }
1817
0
  printDialectSymbol(os, "#", dialect.getNamespace(), attrName);
1818
0
}
1819
1820
0
void ModulePrinter::printDialectType(Type type) {
1821
0
  auto &dialect = type.getDialect();
1822
0
1823
0
  // Ask the dialect to serialize the type to a string.
1824
0
  std::string typeName;
1825
0
  {
1826
0
    llvm::raw_string_ostream typeNameStr(typeName);
1827
0
    ModulePrinter subPrinter(typeNameStr, printerFlags, state);
1828
0
    CustomDialectAsmPrinter printer(subPrinter);
1829
0
    dialect.printType(type, printer);
1830
0
  }
1831
0
  printDialectSymbol(os, "!", dialect.getNamespace(), typeName);
1832
0
}
1833
1834
//===----------------------------------------------------------------------===//
1835
// Affine expressions and maps
1836
//===----------------------------------------------------------------------===//
1837
1838
void ModulePrinter::printAffineExpr(
1839
0
    AffineExpr expr, function_ref<void(unsigned, bool)> printValueName) {
1840
0
  printAffineExprInternal(expr, BindingStrength::Weak, printValueName);
1841
0
}
1842
1843
void ModulePrinter::printAffineExprInternal(
1844
    AffineExpr expr, BindingStrength enclosingTightness,
1845
0
    function_ref<void(unsigned, bool)> printValueName) {
1846
0
  const char *binopSpelling = nullptr;
1847
0
  switch (expr.getKind()) {
1848
0
  case AffineExprKind::SymbolId: {
1849
0
    unsigned pos = expr.cast<AffineSymbolExpr>().getPosition();
1850
0
    if (printValueName)
1851
0
      printValueName(pos, /*isSymbol=*/true);
1852
0
    else
1853
0
      os << 's' << pos;
1854
0
    return;
1855
0
  }
1856
0
  case AffineExprKind::DimId: {
1857
0
    unsigned pos = expr.cast<AffineDimExpr>().getPosition();
1858
0
    if (printValueName)
1859
0
      printValueName(pos, /*isSymbol=*/false);
1860
0
    else
1861
0
      os << 'd' << pos;
1862
0
    return;
1863
0
  }
1864
0
  case AffineExprKind::Constant:
1865
0
    os << expr.cast<AffineConstantExpr>().getValue();
1866
0
    return;
1867
0
  case AffineExprKind::Add:
1868
0
    binopSpelling = " + ";
1869
0
    break;
1870
0
  case AffineExprKind::Mul:
1871
0
    binopSpelling = " * ";
1872
0
    break;
1873
0
  case AffineExprKind::FloorDiv:
1874
0
    binopSpelling = " floordiv ";
1875
0
    break;
1876
0
  case AffineExprKind::CeilDiv:
1877
0
    binopSpelling = " ceildiv ";
1878
0
    break;
1879
0
  case AffineExprKind::Mod:
1880
0
    binopSpelling = " mod ";
1881
0
    break;
1882
0
  }
1883
0
1884
0
  auto binOp = expr.cast<AffineBinaryOpExpr>();
1885
0
  AffineExpr lhsExpr = binOp.getLHS();
1886
0
  AffineExpr rhsExpr = binOp.getRHS();
1887
0
1888
0
  // Handle tightly binding binary operators.
1889
0
  if (binOp.getKind() != AffineExprKind::Add) {
1890
0
    if (enclosingTightness == BindingStrength::Strong)
1891
0
      os << '(';
1892
0
1893
0
    // Pretty print multiplication with -1.
1894
0
    auto rhsConst = rhsExpr.dyn_cast<AffineConstantExpr>();
1895
0
    if (rhsConst && binOp.getKind() == AffineExprKind::Mul &&
1896
0
        rhsConst.getValue() == -1) {
1897
0
      os << "-";
1898
0
      printAffineExprInternal(lhsExpr, BindingStrength::Strong, printValueName);
1899
0
      if (enclosingTightness == BindingStrength::Strong)
1900
0
        os << ')';
1901
0
      return;
1902
0
    }
1903
0
1904
0
    printAffineExprInternal(lhsExpr, BindingStrength::Strong, printValueName);
1905
0
1906
0
    os << binopSpelling;
1907
0
    printAffineExprInternal(rhsExpr, BindingStrength::Strong, printValueName);
1908
0
1909
0
    if (enclosingTightness == BindingStrength::Strong)
1910
0
      os << ')';
1911
0
    return;
1912
0
  }
1913
0
1914
0
  // Print out special "pretty" forms for add.
1915
0
  if (enclosingTightness == BindingStrength::Strong)
1916
0
    os << '(';
1917
0
1918
0
  // Pretty print addition to a product that has a negative operand as a
1919
0
  // subtraction.
1920
0
  if (auto rhs = rhsExpr.dyn_cast<AffineBinaryOpExpr>()) {
1921
0
    if (rhs.getKind() == AffineExprKind::Mul) {
1922
0
      AffineExpr rrhsExpr = rhs.getRHS();
1923
0
      if (auto rrhs = rrhsExpr.dyn_cast<AffineConstantExpr>()) {
1924
0
        if (rrhs.getValue() == -1) {
1925
0
          printAffineExprInternal(lhsExpr, BindingStrength::Weak,
1926
0
                                  printValueName);
1927
0
          os << " - ";
1928
0
          if (rhs.getLHS().getKind() == AffineExprKind::Add) {
1929
0
            printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong,
1930
0
                                    printValueName);
1931
0
          } else {
1932
0
            printAffineExprInternal(rhs.getLHS(), BindingStrength::Weak,
1933
0
                                    printValueName);
1934
0
          }
1935
0
1936
0
          if (enclosingTightness == BindingStrength::Strong)
1937
0
            os << ')';
1938
0
          return;
1939
0
        }
1940
0
1941
0
        if (rrhs.getValue() < -1) {
1942
0
          printAffineExprInternal(lhsExpr, BindingStrength::Weak,
1943
0
                                  printValueName);
1944
0
          os << " - ";
1945
0
          printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong,
1946
0
                                  printValueName);
1947
0
          os << " * " << -rrhs.getValue();
1948
0
          if (enclosingTightness == BindingStrength::Strong)
1949
0
            os << ')';
1950
0
          return;
1951
0
        }
1952
0
      }
1953
0
    }
1954
0
  }
1955
0
1956
0
  // Pretty print addition to a negative number as a subtraction.
1957
0
  if (auto rhsConst = rhsExpr.dyn_cast<AffineConstantExpr>()) {
1958
0
    if (rhsConst.getValue() < 0) {
1959
0
      printAffineExprInternal(lhsExpr, BindingStrength::Weak, printValueName);
1960
0
      os << " - " << -rhsConst.getValue();
1961
0
      if (enclosingTightness == BindingStrength::Strong)
1962
0
        os << ')';
1963
0
      return;
1964
0
    }
1965
0
  }
1966
0
1967
0
  printAffineExprInternal(lhsExpr, BindingStrength::Weak, printValueName);
1968
0
1969
0
  os << " + ";
1970
0
  printAffineExprInternal(rhsExpr, BindingStrength::Weak, printValueName);
1971
0
1972
0
  if (enclosingTightness == BindingStrength::Strong)
1973
0
    os << ')';
1974
0
}
1975
1976
0
void ModulePrinter::printAffineConstraint(AffineExpr expr, bool isEq) {
1977
0
  printAffineExprInternal(expr, BindingStrength::Weak);
1978
0
  isEq ? os << " == 0" : os << " >= 0";
1979
0
}
1980
1981
0
void ModulePrinter::printAffineMap(AffineMap map) {
1982
0
  // Dimension identifiers.
1983
0
  os << '(';
1984
0
  for (int i = 0; i < (int)map.getNumDims() - 1; ++i)
1985
0
    os << 'd' << i << ", ";
1986
0
  if (map.getNumDims() >= 1)
1987
0
    os << 'd' << map.getNumDims() - 1;
1988
0
  os << ')';
1989
0
1990
0
  // Symbolic identifiers.
1991
0
  if (map.getNumSymbols() != 0) {
1992
0
    os << '[';
1993
0
    for (unsigned i = 0; i < map.getNumSymbols() - 1; ++i)
1994
0
      os << 's' << i << ", ";
1995
0
    if (map.getNumSymbols() >= 1)
1996
0
      os << 's' << map.getNumSymbols() - 1;
1997
0
    os << ']';
1998
0
  }
1999
0
2000
0
  // Result affine expressions.
2001
0
  os << " -> (";
2002
0
  interleaveComma(map.getResults(),
2003
0
                  [&](AffineExpr expr) { printAffineExpr(expr); });
2004
0
  os << ')';
2005
0
}
2006
2007
0
void ModulePrinter::printIntegerSet(IntegerSet set) {
2008
0
  // Dimension identifiers.
2009
0
  os << '(';
2010
0
  for (unsigned i = 1; i < set.getNumDims(); ++i)
2011
0
    os << 'd' << i - 1 << ", ";
2012
0
  if (set.getNumDims() >= 1)
2013
0
    os << 'd' << set.getNumDims() - 1;
2014
0
  os << ')';
2015
0
2016
0
  // Symbolic identifiers.
2017
0
  if (set.getNumSymbols() != 0) {
2018
0
    os << '[';
2019
0
    for (unsigned i = 0; i < set.getNumSymbols() - 1; ++i)
2020
0
      os << 's' << i << ", ";
2021
0
    if (set.getNumSymbols() >= 1)
2022
0
      os << 's' << set.getNumSymbols() - 1;
2023
0
    os << ']';
2024
0
  }
2025
0
2026
0
  // Print constraints.
2027
0
  os << " : (";
2028
0
  int numConstraints = set.getNumConstraints();
2029
0
  for (int i = 1; i < numConstraints; ++i) {
2030
0
    printAffineConstraint(set.getConstraint(i - 1), set.isEq(i - 1));
2031
0
    os << ", ";
2032
0
  }
2033
0
  if (numConstraints >= 1)
2034
0
    printAffineConstraint(set.getConstraint(numConstraints - 1),
2035
0
                          set.isEq(numConstraints - 1));
2036
0
  os << ')';
2037
0
}
2038
2039
//===----------------------------------------------------------------------===//
2040
// OperationPrinter
2041
//===----------------------------------------------------------------------===//
2042
2043
namespace {
2044
/// This class contains the logic for printing operations, regions, and blocks.
2045
class OperationPrinter : public ModulePrinter, private OpAsmPrinter {
2046
public:
2047
  explicit OperationPrinter(raw_ostream &os, OpPrintingFlags flags,
2048
                            AsmStateImpl &state)
2049
0
      : ModulePrinter(os, flags, &state) {}
2050
2051
  /// Print the given top-level module.
2052
  void print(ModuleOp op);
2053
  /// Print the given operation with its indent and location.
2054
  void print(Operation *op);
2055
  /// Print the bare operation, not including indentation/location/etc.
2056
  void printOperation(Operation *op);
2057
  /// Print the given operation in the generic form.
2058
  void printGenericOp(Operation *op) override;
2059
2060
  /// Print the name of the given block.
2061
  void printBlockName(Block *block);
2062
2063
  /// Print the given block. If 'printBlockArgs' is false, the arguments of the
2064
  /// block are not printed. If 'printBlockTerminator' is false, the terminator
2065
  /// operation of the block is not printed.
2066
  void print(Block *block, bool printBlockArgs = true,
2067
             bool printBlockTerminator = true);
2068
2069
  /// Print the ID of the given value, optionally with its result number.
2070
  void printValueID(Value value, bool printResultNo = true,
2071
                    raw_ostream *streamOverride = nullptr) const;
2072
2073
  //===--------------------------------------------------------------------===//
2074
  // OpAsmPrinter methods
2075
  //===--------------------------------------------------------------------===//
2076
2077
  /// Return the current stream of the printer.
2078
0
  raw_ostream &getStream() const override { return os; }
2079
2080
  /// Print the given type.
2081
0
  void printType(Type type) override { ModulePrinter::printType(type); }
2082
2083
  /// Print the given attribute.
2084
0
  void printAttribute(Attribute attr) override {
2085
0
    ModulePrinter::printAttribute(attr);
2086
0
  }
2087
2088
  /// Print the given attribute without its type. The corresponding parser must
2089
  /// provide a valid type for the attribute.
2090
0
  void printAttributeWithoutType(Attribute attr) override {
2091
0
    ModulePrinter::printAttribute(attr, AttrTypeElision::Must);
2092
0
  }
2093
2094
  /// Print the ID for the given value.
2095
0
  void printOperand(Value value) override { printValueID(value); }
2096
0
  void printOperand(Value value, raw_ostream &os) override {
2097
0
    printValueID(value, /*printResultNo=*/true, &os);
2098
0
  }
2099
2100
  /// Print an optional attribute dictionary with a given set of elided values.
2101
  void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
2102
0
                             ArrayRef<StringRef> elidedAttrs = {}) override {
2103
0
    ModulePrinter::printOptionalAttrDict(attrs, elidedAttrs);
2104
0
  }
2105
  void printOptionalAttrDictWithKeyword(
2106
      ArrayRef<NamedAttribute> attrs,
2107
0
      ArrayRef<StringRef> elidedAttrs = {}) override {
2108
0
    ModulePrinter::printOptionalAttrDict(attrs, elidedAttrs,
2109
0
                                         /*withKeyword=*/true);
2110
0
  }
2111
2112
  /// Print the given successor.
2113
  void printSuccessor(Block *successor) override;
2114
2115
  /// Print an operation successor with the operands used for the block
2116
  /// arguments.
2117
  void printSuccessorAndUseList(Block *successor,
2118
                                ValueRange succOperands) override;
2119
2120
  /// Print the given region.
2121
  void printRegion(Region &region, bool printEntryBlockArgs,
2122
                   bool printBlockTerminators) override;
2123
2124
  /// Renumber the arguments for the specified region to the same names as the
2125
  /// SSA values in namesToUse. This may only be used for IsolatedFromAbove
2126
  /// operations. If any entry in namesToUse is null, the corresponding
2127
  /// argument name is left alone.
2128
0
  void shadowRegionArgs(Region &region, ValueRange namesToUse) override {
2129
0
    state->getSSANameState().shadowRegionArgs(region, namesToUse);
2130
0
  }
2131
2132
  /// Print the given affine map with the symbol and dimension operands printed
2133
  /// inline with the map.
2134
  void printAffineMapOfSSAIds(AffineMapAttr mapAttr,
2135
                              ValueRange operands) override;
2136
2137
  /// Print the given string as a symbol reference.
2138
0
  void printSymbolName(StringRef symbolRef) override {
2139
0
    ::printSymbolReference(symbolRef, os);
2140
0
  }
2141
2142
private:
2143
  /// The number of spaces used for indenting nested operations.
2144
  const static unsigned indentWidth = 2;
2145
2146
  // This is the current indentation level for nested structures.
2147
  unsigned currentIndent = 0;
2148
};
2149
} // end anonymous namespace
2150
2151
0
void OperationPrinter::print(ModuleOp op) {
2152
0
  // Output the aliases at the top level.
2153
0
  state->getAliasState().printAttributeAliases(os, newLine);
2154
0
  state->getAliasState().printTypeAliases(os, newLine);
2155
0
2156
0
  // Print the module.
2157
0
  print(op.getOperation());
2158
0
}
2159
2160
0
void OperationPrinter::print(Operation *op) {
2161
0
  // Track the location of this operation.
2162
0
  state->registerOperationLocation(op, newLine.curLine, currentIndent);
2163
0
2164
0
  os.indent(currentIndent);
2165
0
  printOperation(op);
2166
0
  printTrailingLocation(op->getLoc());
2167
0
}
2168
2169
0
void OperationPrinter::printOperation(Operation *op) {
2170
0
  if (size_t numResults = op->getNumResults()) {
2171
0
    auto printResultGroup = [&](size_t resultNo, size_t resultCount) {
2172
0
      printValueID(op->getResult(resultNo), /*printResultNo=*/false);
2173
0
      if (resultCount > 1)
2174
0
        os << ':' << resultCount;
2175
0
    };
2176
0
2177
0
    // Check to see if this operation has multiple result groups.
2178
0
    ArrayRef<int> resultGroups = state->getSSANameState().getOpResultGroups(op);
2179
0
    if (!resultGroups.empty()) {
2180
0
      // Interleave the groups excluding the last one, this one will be handled
2181
0
      // separately.
2182
0
      interleaveComma(llvm::seq<int>(0, resultGroups.size() - 1), [&](int i) {
2183
0
        printResultGroup(resultGroups[i],
2184
0
                         resultGroups[i + 1] - resultGroups[i]);
2185
0
      });
2186
0
      os << ", ";
2187
0
      printResultGroup(resultGroups.back(), numResults - resultGroups.back());
2188
0
2189
0
    } else {
2190
0
      printResultGroup(/*resultNo=*/0, /*resultCount=*/numResults);
2191
0
    }
2192
0
2193
0
    os << " = ";
2194
0
  }
2195
0
2196
0
  // If requested, always print the generic form.
2197
0
  if (!printerFlags.shouldPrintGenericOpForm()) {
2198
0
    // Check to see if this is a known operation.  If so, use the registered
2199
0
    // custom printer hook.
2200
0
    if (auto *opInfo = op->getAbstractOperation()) {
2201
0
      opInfo->printAssembly(op, *this);
2202
0
      return;
2203
0
    }
2204
0
  }
2205
0
2206
0
  // Otherwise print with the generic assembly form.
2207
0
  printGenericOp(op);
2208
0
}
2209
2210
0
void OperationPrinter::printGenericOp(Operation *op) {
2211
0
  os << '"';
2212
0
  printEscapedString(op->getName().getStringRef(), os);
2213
0
  os << "\"(";
2214
0
  interleaveComma(op->getOperands(), [&](Value value) { printValueID(value); });
2215
0
  os << ')';
2216
0
2217
0
  // For terminators, print the list of successor blocks.
2218
0
  if (op->getNumSuccessors() != 0) {
2219
0
    os << '[';
2220
0
    interleaveComma(op->getSuccessors(),
2221
0
                    [&](Block *successor) { printBlockName(successor); });
2222
0
    os << ']';
2223
0
  }
2224
0
2225
0
  // Print regions.
2226
0
  if (op->getNumRegions() != 0) {
2227
0
    os << " (";
2228
0
    interleaveComma(op->getRegions(), [&](Region &region) {
2229
0
      printRegion(region, /*printEntryBlockArgs=*/true,
2230
0
                  /*printBlockTerminators=*/true);
2231
0
    });
2232
0
    os << ')';
2233
0
  }
2234
0
2235
0
  auto attrs = op->getAttrs();
2236
0
  printOptionalAttrDict(attrs);
2237
0
2238
0
  // Print the type signature of the operation.
2239
0
  os << " : ";
2240
0
  printFunctionalType(op);
2241
0
}
2242
2243
0
void OperationPrinter::printBlockName(Block *block) {
2244
0
  auto id = state->getSSANameState().getBlockID(block);
2245
0
  if (id != SSANameState::NameSentinel)
2246
0
    os << "^bb" << id;
2247
0
  else
2248
0
    os << "^INVALIDBLOCK";
2249
0
}
2250
2251
/// Print `block`, optionally preceded by its label line.
///
/// When `printBlockArgs` is true, the block name, its argument list (when
/// non-empty) and a trailing comment describing the block's predecessors are
/// printed first. The operations of the block are then printed one per line
/// at an increased indent. When `printBlockTerminator` is false, the last
/// operation of the block is not printed.
void OperationPrinter::print(Block *block, bool printBlockArgs,
                             bool printBlockTerminator) {
  // Print the block label and argument list if requested.
  if (printBlockArgs) {
    os.indent(currentIndent);
    printBlockName(block);

    // Print the argument list if non-empty.
    if (!block->args_empty()) {
      os << '(';
      interleaveComma(block->getArguments(), [&](BlockArgument arg) {
        printValueID(arg);
        os << ": ";
        printType(arg.getType());
      });
      os << ')';
    }
    os << ':';

    // Print out some context information about the predecessors of this block.
    if (!block->getParent()) {
      os << "  // block is not in a region!";
    } else if (block->hasNoPredecessors()) {
      os << "  // no predecessors";
    } else if (auto *pred = block->getSinglePredecessor()) {
      os << "  // pred: ";
      printBlockName(pred);
    } else {
      // We want to print the predecessors in increasing numeric order, not in
      // whatever order the use-list is in, so gather and sort them.
      SmallVector<std::pair<unsigned, Block *>, 4> predIDs;
      for (auto *pred : block->getPredecessors())
        predIDs.push_back({state->getSSANameState().getBlockID(pred), pred});
      llvm::array_pod_sort(predIDs.begin(), predIDs.end());

      os << "  // " << predIDs.size() << " preds: ";

      interleaveComma(predIDs, [&](std::pair<unsigned, Block *> pred) {
        printBlockName(pred.second);
      });
    }
    os << newLine;
  }

  currentIndent += indentWidth;

  // Determine the range of operations to print. When the terminator is
  // elided the last operation is dropped; an empty block is guarded against
  // explicitly because stepping back from end() of an empty list is invalid.
  auto &ops = block->getOperations();
  auto opsEnd = ops.end();
  if (!printBlockTerminator && !ops.empty())
    opsEnd = std::prev(opsEnd);

  for (auto &op : llvm::make_range(ops.begin(), opsEnd)) {
    print(&op);
    os << newLine;
  }
  currentIndent -= indentWidth;
}
2305
2306
void OperationPrinter::printValueID(Value value, bool printResultNo,
2307
0
                                    raw_ostream *streamOverride) const {
2308
0
  state->getSSANameState().printValueID(value, printResultNo,
2309
0
                                        streamOverride ? *streamOverride : os);
2310
0
}
2311
2312
0
/// Print a successor reference; this is exactly the name of the successor
/// block.
void OperationPrinter::printSuccessor(Block *successor) { printBlockName(successor); }
2315
2316
void OperationPrinter::printSuccessorAndUseList(Block *successor,
2317
0
                                                ValueRange succOperands) {
2318
0
  printBlockName(successor);
2319
0
  if (succOperands.empty())
2320
0
    return;
2321
0
2322
0
  os << '(';
2323
0
  interleaveComma(succOperands,
2324
0
                  [this](Value operand) { printValueID(operand); });
2325
0
  os << " : ";
2326
0
  interleaveComma(succOperands,
2327
0
                  [this](Value operand) { printType(operand.getType()); });
2328
0
  os << ')';
2329
0
}
2330
2331
/// Print `region` enclosed in braces. The entry block label is only printed
/// when `printEntryBlockArgs` is set and the entry block actually has
/// arguments; every remaining block is printed with the default settings of
/// the block printer.
void OperationPrinter::printRegion(Region &region, bool printEntryBlockArgs,
                                   bool printBlockTerminators) {
  os << " {" << newLine;
  if (!region.empty()) {
    // The entry block is handled separately so that its label can be elided.
    Block &entry = region.front();
    const bool showEntryLabel =
        printEntryBlockArgs && entry.getNumArguments() != 0;
    print(&entry, showEntryLabel, printBlockTerminators);

    // All blocks after the entry block.
    for (auto &remaining : llvm::drop_begin(region.getBlocks(), 1))
      print(&remaining);
  }
  os.indent(currentIndent) << "}";
}
2343
2344
void OperationPrinter::printAffineMapOfSSAIds(AffineMapAttr mapAttr,
2345
0
                                              ValueRange operands) {
2346
0
  AffineMap map = mapAttr.getValue();
2347
0
  unsigned numDims = map.getNumDims();
2348
0
  auto printValueName = [&](unsigned pos, bool isSymbol) {
2349
0
    unsigned index = isSymbol ? numDims + pos : pos;
2350
0
    assert(index < operands.size());
2351
0
    if (isSymbol)
2352
0
      os << "symbol(";
2353
0
    printValueID(operands[index]);
2354
0
    if (isSymbol)
2355
0
      os << ')';
2356
0
  };
2357
0
2358
0
  interleaveComma(map.getResults(), [&](AffineExpr expr) {
2359
0
    printAffineExpr(expr, printValueName);
2360
0
  });
2361
0
}
2362
2363
//===----------------------------------------------------------------------===//
2364
// print and dump methods
2365
//===----------------------------------------------------------------------===//
2366
2367
0
/// Print this attribute to `os` using a freshly constructed ModulePrinter.
void Attribute::print(raw_ostream &os) const {
  ModulePrinter printer(os);
  printer.printAttribute(*this);
}
2370
2371
0
void Attribute::dump() const {
2372
0
  print(llvm::errs());
2373
0
  llvm::errs() << "\n";
2374
0
}
2375
2376
0
/// Print this type to `os` using a freshly constructed ModulePrinter.
void Type::print(raw_ostream &os) {
  ModulePrinter printer(os);
  printer.printType(*this);
}
2377
2378
0
/// Print this type to the error stream, followed by a newline. The trailing
/// newline matches the other dump() methods in this file (Attribute,
/// AffineMap, IntegerSet, AffineExpr, Value) so that consecutive dumps do not
/// run together on a single line.
void Type::dump() {
  print(llvm::errs());
  llvm::errs() << "\n";
}
2379
2380
0
void AffineMap::dump() const {
2381
0
  print(llvm::errs());
2382
0
  llvm::errs() << "\n";
2383
0
}
2384
2385
0
void IntegerSet::dump() const {
2386
0
  print(llvm::errs());
2387
0
  llvm::errs() << "\n";
2388
0
}
2389
2390
0
/// Print this affine expression to `os`. A null expression is printed as the
/// placeholder `<<NULL AFFINE EXPR>>`.
void AffineExpr::print(raw_ostream &os) const {
  if (expr) {
    ModulePrinter printer(os);
    printer.printAffineExpr(*this);
    return;
  }
  os << "<<NULL AFFINE EXPR>>";
}
2397
2398
0
void AffineExpr::dump() const {
2399
0
  print(llvm::errs());
2400
0
  llvm::errs() << "\n";
2401
0
}
2402
2403
0
/// Print this affine map to `os`. A null map is printed as the placeholder
/// `<<NULL AFFINE MAP>>`.
void AffineMap::print(raw_ostream &os) const {
  if (map) {
    ModulePrinter printer(os);
    printer.printAffineMap(*this);
    return;
  }
  os << "<<NULL AFFINE MAP>>";
}
2410
2411
0
/// Print this integer set to `os` using a freshly constructed ModulePrinter.
void IntegerSet::print(raw_ostream &os) const {
  ModulePrinter printer(os);
  printer.printIntegerSet(*this);
}
2414
2415
0
/// Print this value to `os`. A value with a defining operation is printed by
/// printing that operation; anything else is asserted to be a block argument
/// and printed as a fixed placeholder.
void Value::print(raw_ostream &os) {
  Operation *definingOp = getDefiningOp();
  if (definingOp) {
    definingOp->print(os);
    return;
  }
  // TODO: Improve this.
  assert(isa<BlockArgument>());
  os << "<block argument>\n";
}
2422
0
/// Print this value to `os` using the provided `state`. A value with a
/// defining operation is printed by printing that operation; anything else is
/// asserted to be a block argument and printed as a fixed placeholder.
void Value::print(raw_ostream &os, AsmState &state) {
  Operation *definingOp = getDefiningOp();
  if (definingOp) {
    definingOp->print(os, state);
    return;
  }

  // TODO: Improve this.
  assert(isa<BlockArgument>());
  os << "<block argument>\n";
}
2430
2431
0
void Value::dump() {
2432
0
  print(llvm::errs());
2433
0
  llvm::errs() << "\n";
2434
0
}
2435
2436
0
/// Print the SSA identifier of this value (including the result number) to
/// `os`, using the names recorded in `state`.
void Value::printAsOperand(raw_ostream &os, AsmState &state) {
  // TODO(riverriddle) This doesn't necessarily capture all potential cases.
  // Currently, region arguments can be shadowed when printing the main
  // operation. If the IR hasn't been printed, this will produce the old SSA
  // name and not the shadowed name.
  auto &nameState = state.getImpl().getSSANameState();
  nameState.printValueID(*this, /*printResultNo=*/true, os);
}
2444
2445
0
/// Print this operation to `os`. An AsmState is built from an enclosing
/// operation chosen according to `flags`: the outermost ancestor by default,
/// or, when local scope is requested, the closest operation (starting with
/// this one) that is known to be isolated from above.
void Operation::print(raw_ostream &os, OpPrintingFlags flags) {
  // Walk up the parent chain to find the operation to number from.
  Operation *scopeOp = this;
  const bool useLocalScope = flags.shouldUseLocalScope();
  while (!(useLocalScope && scopeOp->isKnownIsolatedFromAbove())) {
    // Not stopping here; move to the parent, if there is one.
    Operation *parentOp = scopeOp->getParentOp();
    if (!parentOp)
      break;
    scopeOp = parentOp;
  }

  AsmState state(scopeOp);
  print(os, state, flags);
}
2465
0
/// Print this operation to `os` using the given `state` and `flags`.
void Operation::print(raw_ostream &os, AsmState &state, OpPrintingFlags flags) {
  OperationPrinter printer(os, flags, state.getImpl());
  printer.print(this);
}
2468
2469
0
void Operation::dump() {
2470
0
  print(llvm::errs(), OpPrintingFlags().useLocalScope());
2471
0
  llvm::errs() << "\n";
2472
0
}
2473
2474
0
/// Print this block to `os`. The AsmState is built from the outermost
/// operation enclosing the block; a block without a parent operation is
/// printed as the placeholder `<<UNLINKED BLOCK>>`.
void Block::print(raw_ostream &os) {
  Operation *topLevelOp = getParentOp();
  if (!topLevelOp) {
    os << "<<UNLINKED BLOCK>>\n";
    return;
  }

  // Climb the parent chain until the top-level operation is reached.
  for (Operation *next = topLevelOp->getParentOp(); next;
       next = next->getParentOp())
    topLevelOp = next;

  AsmState state(topLevelOp);
  print(os, state);
}
2487
0
/// Print this block to `os` using the given `state` and no printing flags.
void Block::print(raw_ostream &os, AsmState &state) {
  OperationPrinter printer(os, /*flags=*/llvm::None, state.getImpl());
  printer.print(this);
}
2490
2491
0
/// Print this block to the error stream.
void Block::dump() {
  print(llvm::errs());
}
2492
2493
/// Print out the name of the block without printing its body.
2494
0
void Block::printAsOperand(raw_ostream &os, bool printType) {
2495
0
  Operation *parentOp = getParentOp();
2496
0
  if (!parentOp) {
2497
0
    os << "<<UNLINKED BLOCK>>\n";
2498
0
    return;
2499
0
  }
2500
0
  AsmState state(parentOp);
2501
0
  printAsOperand(os, state);
2502
0
}
2503
0
/// Print the name of this block to `os`, using the block names recorded in
/// `state`.
void Block::printAsOperand(raw_ostream &os, AsmState &state) {
  OperationPrinter(os, /*flags=*/llvm::None, state.getImpl())
      .printBlockName(this);
}
2507
2508
0
/// Print this module to `os` with the given `flags`, using an AsmState built
/// from the module itself. Aliases are initialized in that state unless local
/// scope printing was requested.
void ModuleOp::print(raw_ostream &os, OpPrintingFlags flags) {
  AsmState state(*this);

  // Aliases are only populated when not printing at local scope.
  const bool wantAliases = !flags.shouldUseLocalScope();
  if (wantAliases)
    state.getImpl().initializeAliases(*this);

  print(os, state, flags);
}
2516
0
/// Print this module to `os` using the given `state` and `flags`.
void ModuleOp::print(raw_ostream &os, AsmState &state, OpPrintingFlags flags) {
  OperationPrinter printer(os, flags, state.getImpl());
  printer.print(*this);
}
2519
2520
0
/// Print this module to the error stream.
void ModuleOp::dump() {
  print(llvm::errs());
}