/home/arjun/llvm-project/mlir/lib/IR/AsmPrinter.cpp
Line | Count | Source (jump to first uncovered line) |
1 | | //===- AsmPrinter.cpp - MLIR Assembly Printer Implementation --------------===// |
2 | | // |
3 | | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | | // See https://llvm.org/LICENSE.txt for license information. |
5 | | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | | // |
7 | | //===----------------------------------------------------------------------===// |
8 | | // |
9 | | // This file implements the MLIR AsmPrinter class, which is used to implement |
10 | | // the various print() methods on the core IR objects. |
11 | | // |
12 | | //===----------------------------------------------------------------------===// |
13 | | |
14 | | #include "mlir/IR/AffineExpr.h" |
15 | | #include "mlir/IR/AffineMap.h" |
16 | | #include "mlir/IR/AsmState.h" |
17 | | #include "mlir/IR/Attributes.h" |
18 | | #include "mlir/IR/Dialect.h" |
19 | | #include "mlir/IR/DialectImplementation.h" |
20 | | #include "mlir/IR/Function.h" |
21 | | #include "mlir/IR/IntegerSet.h" |
22 | | #include "mlir/IR/MLIRContext.h" |
23 | | #include "mlir/IR/Module.h" |
24 | | #include "mlir/IR/OpImplementation.h" |
25 | | #include "mlir/IR/Operation.h" |
26 | | #include "mlir/IR/StandardTypes.h" |
27 | | #include "llvm/ADT/APFloat.h" |
28 | | #include "llvm/ADT/DenseMap.h" |
29 | | #include "llvm/ADT/MapVector.h" |
30 | | #include "llvm/ADT/STLExtras.h" |
31 | | #include "llvm/ADT/ScopedHashTable.h" |
32 | | #include "llvm/ADT/SetVector.h" |
33 | | #include "llvm/ADT/SmallString.h" |
34 | | #include "llvm/ADT/StringExtras.h" |
35 | | #include "llvm/ADT/StringSet.h" |
36 | | #include "llvm/Support/CommandLine.h" |
37 | | #include "llvm/Support/Regex.h" |
38 | | #include "llvm/Support/SaveAndRestore.h" |
39 | | using namespace mlir; |
40 | | using namespace mlir::detail; |
41 | | |
42 | 0 | void Identifier::print(raw_ostream &os) const { os << str(); } |
43 | | |
44 | 0 | void Identifier::dump() const { print(llvm::errs()); } |
45 | | |
46 | 0 | void OperationName::print(raw_ostream &os) const { os << getStringRef(); } |
47 | | |
48 | 0 | void OperationName::dump() const { print(llvm::errs()); } |
49 | | |
50 | 0 | DialectAsmPrinter::~DialectAsmPrinter() {} |
51 | | |
52 | 0 | OpAsmPrinter::~OpAsmPrinter() {} |
53 | | |
54 | | //===--------------------------------------------------------------------===// |
55 | | // Operation OpAsm interface. |
56 | | //===--------------------------------------------------------------------===// |
57 | | |
58 | | /// The OpAsmOpInterface, see OpAsmInterface.td for more details. |
59 | | #include "mlir/IR/OpAsmInterface.cpp.inc" |
60 | | |
61 | | //===----------------------------------------------------------------------===// |
62 | | // OpPrintingFlags |
63 | | //===----------------------------------------------------------------------===// |
64 | | |
namespace {
/// This struct contains command line options that can be used to initialize
/// various bits of the AsmPrinter. This uses a struct wrapper to avoid the need
/// for global command line options: the options only exist once the struct
/// has been constructed (it is held in a lazily-constructed ManagedStatic).
struct AsmPrinterOptions {
  /// Element-count threshold above which DenseElementsAttrs are printed as a
  /// hex string; -1 disables hex printing. Only consulted when given
  /// explicitly on the command line.
  llvm::cl::opt<int64_t> printElementsAttrWithHexIfLarger{
      "mlir-print-elementsattrs-with-hex-if-larger",
      llvm::cl::desc(
          "Print DenseElementsAttrs with a hex string that have "
          "more elements than the given upper limit (use -1 to disable)")};

  /// Element-count threshold above which ElementsAttrs are elided as "...".
  /// Only consulted when given explicitly on the command line.
  llvm::cl::opt<unsigned> elideElementsAttrIfLarger{
      "mlir-elide-elementsattrs-if-larger",
      llvm::cl::desc("Elide ElementsAttrs with \"...\" that have "
                     "more elements than the given upper limit")};

  /// Print debug info in the output.
  llvm::cl::opt<bool> printDebugInfoOpt{
      "mlir-print-debuginfo", llvm::cl::init(false),
      llvm::cl::desc("Print debug info in MLIR output")};

  /// Print debug info in its pretty form.
  llvm::cl::opt<bool> printPrettyDebugInfoOpt{
      "mlir-pretty-debuginfo", llvm::cl::init(false),
      llvm::cl::desc("Print pretty debug info in MLIR output")};

  // Use the generic op output form in the operation printer even if the custom
  // form is defined.
  llvm::cl::opt<bool> printGenericOpFormOpt{
      "mlir-print-op-generic", llvm::cl::init(false),
      llvm::cl::desc("Print the generic op form"), llvm::cl::Hidden};

  /// Print assuming local scope by default (see OpPrintingFlags::useLocalScope).
  llvm::cl::opt<bool> printLocalScopeOpt{
      "mlir-print-local-scope", llvm::cl::init(false),
      llvm::cl::desc("Print assuming in local scope by default"),
      llvm::cl::Hidden};
};
} // end anonymous namespace
101 | | |
102 | | static llvm::ManagedStatic<AsmPrinterOptions> clOptions; |
103 | | |
104 | | /// Register a set of useful command-line options that can be used to configure |
105 | | /// various flags within the AsmPrinter. |
106 | 0 | void mlir::registerAsmPrinterCLOptions() { |
107 | 0 | // Make sure that the options struct has been initialized. |
108 | 0 | *clOptions; |
109 | 0 | } |
110 | | |
111 | | /// Initialize the printing flags with default supplied by the cl::opts above. |
112 | | OpPrintingFlags::OpPrintingFlags() |
113 | | : printDebugInfoFlag(false), printDebugInfoPrettyFormFlag(false), |
114 | 0 | printGenericOpFormFlag(false), printLocalScope(false) { |
115 | 0 | // Initialize based upon command line options, if they are available. |
116 | 0 | if (!clOptions.isConstructed()) |
117 | 0 | return; |
118 | 0 | if (clOptions->elideElementsAttrIfLarger.getNumOccurrences()) |
119 | 0 | elementsAttrElementLimit = clOptions->elideElementsAttrIfLarger; |
120 | 0 | printDebugInfoFlag = clOptions->printDebugInfoOpt; |
121 | 0 | printDebugInfoPrettyFormFlag = clOptions->printPrettyDebugInfoOpt; |
122 | 0 | printGenericOpFormFlag = clOptions->printGenericOpFormOpt; |
123 | 0 | printLocalScope = clOptions->printLocalScopeOpt; |
124 | 0 | } |
125 | | |
126 | | /// Enable the elision of large elements attributes, by printing a '...' |
127 | | /// instead of the element data, when the number of elements is greater than |
128 | | /// `largeElementLimit`. Note: The IR generated with this option is not |
129 | | /// parsable. |
130 | | OpPrintingFlags & |
131 | 0 | OpPrintingFlags::elideLargeElementsAttrs(int64_t largeElementLimit) { |
132 | 0 | elementsAttrElementLimit = largeElementLimit; |
133 | 0 | return *this; |
134 | 0 | } |
135 | | |
136 | | /// Enable printing of debug information. If 'prettyForm' is set to true, |
137 | | /// debug information is printed in a more readable 'pretty' form. |
138 | 0 | OpPrintingFlags &OpPrintingFlags::enableDebugInfo(bool prettyForm) { |
139 | 0 | printDebugInfoFlag = true; |
140 | 0 | printDebugInfoPrettyFormFlag = prettyForm; |
141 | 0 | return *this; |
142 | 0 | } |
143 | | |
144 | | /// Always print operations in the generic form. |
145 | 0 | OpPrintingFlags &OpPrintingFlags::printGenericOpForm() { |
146 | 0 | printGenericOpFormFlag = true; |
147 | 0 | return *this; |
148 | 0 | } |
149 | | |
150 | | /// Use local scope when printing the operation. This allows for using the |
151 | | /// printer in a more localized and thread-safe setting, but may not necessarily |
152 | | /// be identical of what the IR will look like when dumping the full module. |
153 | 0 | OpPrintingFlags &OpPrintingFlags::useLocalScope() { |
154 | 0 | printLocalScope = true; |
155 | 0 | return *this; |
156 | 0 | } |
157 | | |
158 | | /// Return if the given ElementsAttr should be elided. |
159 | 0 | bool OpPrintingFlags::shouldElideElementsAttr(ElementsAttr attr) const { |
160 | 0 | return elementsAttrElementLimit.hasValue() && |
161 | 0 | *elementsAttrElementLimit < int64_t(attr.getNumElements()); |
162 | 0 | } |
163 | | |
164 | | /// Return the size limit for printing large ElementsAttr. |
165 | 0 | Optional<int64_t> OpPrintingFlags::getLargeElementsAttrLimit() const { |
166 | 0 | return elementsAttrElementLimit; |
167 | 0 | } |
168 | | |
169 | | /// Return if debug information should be printed. |
170 | 0 | bool OpPrintingFlags::shouldPrintDebugInfo() const { |
171 | 0 | return printDebugInfoFlag; |
172 | 0 | } |
173 | | |
174 | | /// Return if debug information should be printed in the pretty form. |
175 | 0 | bool OpPrintingFlags::shouldPrintDebugInfoPrettyForm() const { |
176 | 0 | return printDebugInfoPrettyFormFlag; |
177 | 0 | } |
178 | | |
179 | | /// Return if operations should be printed in the generic form. |
180 | 0 | bool OpPrintingFlags::shouldPrintGenericOpForm() const { |
181 | 0 | return printGenericOpFormFlag; |
182 | 0 | } |
183 | | |
184 | | /// Return if the printer should use local scope when dumping the IR. |
185 | 0 | bool OpPrintingFlags::shouldUseLocalScope() const { return printLocalScope; } |
186 | | |
187 | | /// Returns true if an ElementsAttr with the given number of elements should be |
188 | | /// printed with hex. |
189 | 0 | static bool shouldPrintElementsAttrWithHex(int64_t numElements) { |
190 | 0 | // Check to see if a command line option was provided for the limit. |
191 | 0 | if (clOptions.isConstructed()) { |
192 | 0 | if (clOptions->printElementsAttrWithHexIfLarger.getNumOccurrences()) { |
193 | 0 | // -1 is used to disable hex printing. |
194 | 0 | if (clOptions->printElementsAttrWithHexIfLarger == -1) |
195 | 0 | return false; |
196 | 0 | return numElements > clOptions->printElementsAttrWithHexIfLarger; |
197 | 0 | } |
198 | 0 | } |
199 | 0 |
|
200 | 0 | // Otherwise, default to printing with hex if the number of elements is >100. |
201 | 0 | return numElements > 100; |
202 | 0 | } |
203 | | |
204 | | //===----------------------------------------------------------------------===// |
205 | | // NewLineCounter |
206 | | //===----------------------------------------------------------------------===// |
207 | | |
208 | | namespace { |
209 | | /// This class is a simple formatter that emits a new line when inputted into a |
210 | | /// stream, that enables counting the number of newlines emitted. This class |
211 | | /// should be used whenever emitting newlines in the printer. |
212 | | struct NewLineCounter { |
213 | | unsigned curLine = 1; |
214 | | }; |
215 | | } // end anonymous namespace |
216 | | |
217 | 0 | static raw_ostream &operator<<(raw_ostream &os, NewLineCounter &newLine) { |
218 | 0 | ++newLine.curLine; |
219 | 0 | return os << '\n'; |
220 | 0 | } |
221 | | |
222 | | //===----------------------------------------------------------------------===// |
223 | | // AliasState |
224 | | //===----------------------------------------------------------------------===// |
225 | | |
namespace {
/// This class manages the state for type and attribute aliases: it collects
/// the alias names offered by dialect interfaces, records which attributes and
/// types are referenced by the IR being printed, and prints alias definitions.
class AliasState {
public:
  // Initialize the internal aliases: collect alias names from `interfaces` and
  // record the attributes/types referenced within `op`.
  void
  initialize(Operation *op,
             DialectInterfaceCollection<OpAsmDialectInterface> &interfaces);

  /// Return a name used for an attribute alias, or empty if there is no alias.
  /// NOTE(review): the result is an llvm::Twine, which is not meant to be
  /// stored; consume it immediately.
  Twine getAttributeAlias(Attribute attr) const;

  /// Print all of the referenced attribute aliases.
  void printAttributeAliases(raw_ostream &os, NewLineCounter &newLine) const;

  /// Return a string to use as an alias for the given type, or empty if there
  /// is no alias recorded.
  StringRef getTypeAlias(Type ty) const;

  /// Print all of the referenced type aliases.
  void printTypeAliases(raw_ostream &os, NewLineCounter &newLine) const;

private:
  /// A special index constant used for non-kind attribute aliases. It is
  /// stored as the count suffix in `attrToAlias` and means "print no suffix".
  enum { NonAttrKindAlias = -1 };

  /// Record a reference to the given attribute.
  void recordAttributeReference(Attribute attr);

  /// Record a reference to the given type.
  void recordTypeReference(Type ty);

  // Visit functions: record what an operation/type/attribute references and
  // recurse into nested types and attributes.
  void visitOperation(Operation *op);
  void visitType(Type type);
  void visitAttribute(Attribute attr);

  /// Set of attributes known to be used within the module, in first-use order.
  llvm::SetVector<Attribute> usedAttributes;

  /// Mapping between attribute and a pair comprised of a base alias name and a
  /// count suffix. If the suffix is set to -1, it is not displayed.
  llvm::MapVector<Attribute, std::pair<StringRef, int>> attrToAlias;

  /// Mapping between attribute kind and a pair comprised of a base alias name
  /// and a unique list of attributes belonging to this kind sorted by location
  /// seen in the module.
  llvm::MapVector<unsigned, std::pair<StringRef, std::vector<Attribute>>>
      attrKindToAlias;

  /// Set of types known to be used within the module, in first-use order.
  llvm::SetVector<Type> usedTypes;

  /// A mapping between a type and a given alias.
  DenseMap<Type, StringRef> typeToAlias;
};
} // end anonymous namespace
283 | | |
284 | | // Utility to generate a function to register a symbol alias. |
285 | | static bool canRegisterAlias(StringRef name, llvm::StringSet<> &usedAliases) { |
286 | | assert(!name.empty() && "expected alias name to be non-empty"); |
287 | | // TODO(riverriddle) Assert that the provided alias name can be lexed as |
288 | | // an identifier. |
289 | | |
290 | | // Check that the alias doesn't contain a '.' character and the name is not |
291 | | // already in use. |
292 | | return !name.contains('.') && usedAliases.insert(name).second; |
293 | | } |
294 | | |
295 | | void AliasState::initialize( |
296 | | Operation *op, |
297 | 0 | DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) { |
298 | 0 | // Track the identifiers in use for each symbol so that the same identifier |
299 | 0 | // isn't used twice. |
300 | 0 | llvm::StringSet<> usedAliases; |
301 | 0 |
|
302 | 0 | // Collect the set of aliases from each dialect. |
303 | 0 | SmallVector<std::pair<unsigned, StringRef>, 8> attributeKindAliases; |
304 | 0 | SmallVector<std::pair<Attribute, StringRef>, 8> attributeAliases; |
305 | 0 | SmallVector<std::pair<Type, StringRef>, 16> typeAliases; |
306 | 0 |
|
307 | 0 | // AffineMap/Integer set have specific kind aliases. |
308 | 0 | attributeKindAliases.emplace_back(StandardAttributes::AffineMap, "map"); |
309 | 0 | attributeKindAliases.emplace_back(StandardAttributes::IntegerSet, "set"); |
310 | 0 |
|
311 | 0 | for (auto &interface : interfaces) { |
312 | 0 | interface.getAttributeKindAliases(attributeKindAliases); |
313 | 0 | interface.getAttributeAliases(attributeAliases); |
314 | 0 | interface.getTypeAliases(typeAliases); |
315 | 0 | } |
316 | 0 |
|
317 | 0 | // Setup the attribute kind aliases. |
318 | 0 | StringRef alias; |
319 | 0 | unsigned attrKind; |
320 | 0 | for (auto &attrAliasPair : attributeKindAliases) { |
321 | 0 | std::tie(attrKind, alias) = attrAliasPair; |
322 | 0 | assert(!alias.empty() && "expected non-empty alias string"); |
323 | 0 | if (!usedAliases.count(alias) && !alias.contains('.')) |
324 | 0 | attrKindToAlias.insert({attrKind, {alias, {}}}); |
325 | 0 | } |
326 | 0 |
|
327 | 0 | // Clear the set of used identifiers so that the attribute kind aliases are |
328 | 0 | // just a prefix and not the full alias, i.e. there may be some overlap. |
329 | 0 | usedAliases.clear(); |
330 | 0 |
|
331 | 0 | // Register the attribute aliases. |
332 | 0 | // Create a regex for the attribute kind alias names, these have a prefix with |
333 | 0 | // a counter appended to the end. We prevent normal aliases from having these |
334 | 0 | // names to avoid collisions. |
335 | 0 | llvm::Regex reservedAttrNames("[0-9]+$"); |
336 | 0 |
|
337 | 0 | // Attribute value aliases. |
338 | 0 | Attribute attr; |
339 | 0 | for (auto &attrAliasPair : attributeAliases) { |
340 | 0 | std::tie(attr, alias) = attrAliasPair; |
341 | 0 | if (!reservedAttrNames.match(alias) && canRegisterAlias(alias, usedAliases)) |
342 | 0 | attrToAlias.insert({attr, {alias, NonAttrKindAlias}}); |
343 | 0 | } |
344 | 0 |
|
345 | 0 | // Clear the set of used identifiers as types can have the same identifiers as |
346 | 0 | // affine structures. |
347 | 0 | usedAliases.clear(); |
348 | 0 |
|
349 | 0 | // Type aliases. |
350 | 0 | for (auto &typeAliasPair : typeAliases) |
351 | 0 | if (canRegisterAlias(typeAliasPair.second, usedAliases)) |
352 | 0 | typeToAlias.insert(typeAliasPair); |
353 | 0 |
|
354 | 0 | // Traverse the given IR to generate the set of used attributes/types. |
355 | 0 | op->walk([&](Operation *op) { visitOperation(op); }); |
356 | 0 | } |
357 | | |
/// Return a name used for an attribute alias, or an empty Twine if there is no
/// alias. For aliases generated from an attribute kind, the attribute's index
/// within that kind is appended to the base name.
///
/// NOTE(review): llvm::Twine is a non-owning rope. The result presumably refers
/// to the StringRef stored in `attrToAlias` (the integer index is held by
/// value), so it is only valid while this AliasState is alive and unmodified.
/// Callers should consume it immediately. Confirm against Twine's storage
/// rules before changing this code.
Twine AliasState::getAttributeAlias(Attribute attr) const {
  auto alias = attrToAlias.find(attr);
  if (alias == attrToAlias.end())
    return Twine();

  // Return the alias for this attribute, along with the index if this was
  // generated by a kind alias.
  int kindIndex = alias->second.second;
  return alias->second.first +
         (kindIndex == NonAttrKindAlias ? Twine() : Twine(kindIndex));
}
370 | | |
371 | | /// Print all of the referenced attribute aliases. |
372 | | void AliasState::printAttributeAliases(raw_ostream &os, |
373 | 0 | NewLineCounter &newLine) const { |
374 | 0 | auto printAlias = [&](StringRef alias, Attribute attr, int index) { |
375 | 0 | os << '#' << alias; |
376 | 0 | if (index != NonAttrKindAlias) |
377 | 0 | os << index; |
378 | 0 | os << " = " << attr << newLine; |
379 | 0 | }; |
380 | 0 |
|
381 | 0 | // Print all of the attribute kind aliases. |
382 | 0 | for (auto &kindAlias : attrKindToAlias) { |
383 | 0 | auto &aliasAttrsPair = kindAlias.second; |
384 | 0 | for (unsigned i = 0, e = aliasAttrsPair.second.size(); i != e; ++i) |
385 | 0 | printAlias(aliasAttrsPair.first, aliasAttrsPair.second[i], i); |
386 | 0 | os << newLine; |
387 | 0 | } |
388 | 0 |
|
389 | 0 | // In a second pass print all of the remaining attribute aliases that aren't |
390 | 0 | // kind aliases. |
391 | 0 | for (Attribute attr : usedAttributes) { |
392 | 0 | auto alias = attrToAlias.find(attr); |
393 | 0 | if (alias != attrToAlias.end() && alias->second.second == NonAttrKindAlias) |
394 | 0 | printAlias(alias->second.first, attr, alias->second.second); |
395 | 0 | } |
396 | 0 | } |
397 | | |
398 | | /// Return a string to use as an alias for the given type, or empty if there |
399 | | /// is no alias recorded. |
400 | 0 | StringRef AliasState::getTypeAlias(Type ty) const { |
401 | 0 | return typeToAlias.lookup(ty); |
402 | 0 | } |
403 | | |
404 | | /// Print all of the referenced type aliases. |
405 | | void AliasState::printTypeAliases(raw_ostream &os, |
406 | 0 | NewLineCounter &newLine) const { |
407 | 0 | for (Type type : usedTypes) { |
408 | 0 | auto alias = typeToAlias.find(type); |
409 | 0 | if (alias != typeToAlias.end()) |
410 | 0 | os << '!' << alias->second << " = type " << type << newLine; |
411 | 0 | } |
412 | 0 | } |
413 | | |
414 | | /// Record a reference to the given attribute. |
415 | 0 | void AliasState::recordAttributeReference(Attribute attr) { |
416 | 0 | // Don't recheck attributes that have already been seen or those that |
417 | 0 | // already have an alias. |
418 | 0 | if (!usedAttributes.insert(attr) || attrToAlias.count(attr)) |
419 | 0 | return; |
420 | 0 | |
421 | 0 | // If this attribute kind has an alias, then record one for this attribute. |
422 | 0 | auto alias = attrKindToAlias.find(static_cast<unsigned>(attr.getKind())); |
423 | 0 | if (alias == attrKindToAlias.end()) |
424 | 0 | return; |
425 | 0 | std::pair<StringRef, int> attrAlias(alias->second.first, |
426 | 0 | alias->second.second.size()); |
427 | 0 | attrToAlias.insert({attr, attrAlias}); |
428 | 0 | alias->second.second.push_back(attr); |
429 | 0 | } |
430 | | |
431 | | /// Record a reference to the given type. |
432 | 0 | void AliasState::recordTypeReference(Type ty) { usedTypes.insert(ty); } |
433 | | |
434 | | // TODO Support visiting other types/operations when implemented. |
435 | 0 | void AliasState::visitType(Type type) { |
436 | 0 | recordTypeReference(type); |
437 | 0 |
|
438 | 0 | if (auto funcType = type.dyn_cast<FunctionType>()) { |
439 | 0 | // Visit input and result types for functions. |
440 | 0 | for (auto input : funcType.getInputs()) |
441 | 0 | visitType(input); |
442 | 0 | for (auto result : funcType.getResults()) |
443 | 0 | visitType(result); |
444 | 0 | } else if (auto shapedType = type.dyn_cast<ShapedType>()) { |
445 | 0 | visitType(shapedType.getElementType()); |
446 | 0 |
|
447 | 0 | // Visit affine maps in memref type. |
448 | 0 | if (auto memref = type.dyn_cast<MemRefType>()) |
449 | 0 | for (auto map : memref.getAffineMaps()) |
450 | 0 | recordAttributeReference(AffineMapAttr::get(map)); |
451 | 0 | } |
452 | 0 | } |
453 | | |
454 | 0 | void AliasState::visitAttribute(Attribute attr) { |
455 | 0 | recordAttributeReference(attr); |
456 | 0 |
|
457 | 0 | if (auto arrayAttr = attr.dyn_cast<ArrayAttr>()) { |
458 | 0 | for (auto elt : arrayAttr.getValue()) |
459 | 0 | visitAttribute(elt); |
460 | 0 | } else if (auto typeAttr = attr.dyn_cast<TypeAttr>()) { |
461 | 0 | visitType(typeAttr.getValue()); |
462 | 0 | } |
463 | 0 | } |
464 | | |
465 | 0 | void AliasState::visitOperation(Operation *op) { |
466 | 0 | // Visit all the types used in the operation. |
467 | 0 | for (auto type : op->getOperandTypes()) |
468 | 0 | visitType(type); |
469 | 0 | for (auto type : op->getResultTypes()) |
470 | 0 | visitType(type); |
471 | 0 | for (auto ®ion : op->getRegions()) |
472 | 0 | for (auto &block : region) |
473 | 0 | for (auto arg : block.getArguments()) |
474 | 0 | visitType(arg.getType()); |
475 | 0 |
|
476 | 0 | // Visit each of the attributes. |
477 | 0 | for (auto elt : op->getAttrs()) |
478 | 0 | visitAttribute(elt.second); |
479 | 0 | } |
480 | | |
481 | | //===----------------------------------------------------------------------===// |
482 | | // SSANameState |
483 | | //===----------------------------------------------------------------------===// |
484 | | |
485 | | namespace { |
486 | | /// This class manages the state of SSA value names. |
487 | | class SSANameState { |
488 | | public: |
489 | | /// A sentinel value used for values with names set. |
490 | | enum : unsigned { NameSentinel = ~0U }; |
491 | | |
492 | | SSANameState(Operation *op, |
493 | | DialectInterfaceCollection<OpAsmDialectInterface> &interfaces); |
494 | | |
495 | | /// Print the SSA identifier for the given value to 'stream'. If |
496 | | /// 'printResultNo' is true, it also presents the result number ('#' number) |
497 | | /// of this value. |
498 | | void printValueID(Value value, bool printResultNo, raw_ostream &stream) const; |
499 | | |
500 | | /// Return the result indices for each of the result groups registered by this |
501 | | /// operation, or empty if none exist. |
502 | | ArrayRef<int> getOpResultGroups(Operation *op); |
503 | | |
504 | | /// Get the ID for the given block. |
505 | | unsigned getBlockID(Block *block); |
506 | | |
507 | | /// Renumber the arguments for the specified region to the same names as the |
508 | | /// SSA values in namesToUse. See OperationPrinter::shadowRegionArgs for |
509 | | /// details. |
510 | | void shadowRegionArgs(Region ®ion, ValueRange namesToUse); |
511 | | |
512 | | private: |
513 | | /// Number the SSA values within the given IR unit. |
514 | | void numberValuesInRegion( |
515 | | Region ®ion, |
516 | | DialectInterfaceCollection<OpAsmDialectInterface> &interfaces); |
517 | | void numberValuesInBlock( |
518 | | Block &block, |
519 | | DialectInterfaceCollection<OpAsmDialectInterface> &interfaces); |
520 | | void numberValuesInOp( |
521 | | Operation &op, |
522 | | DialectInterfaceCollection<OpAsmDialectInterface> &interfaces); |
523 | | |
524 | | /// Given a result of an operation 'result', find the result group head |
525 | | /// 'lookupValue' and the result of 'result' within that group in |
526 | | /// 'lookupResultNo'. 'lookupResultNo' is only filled in if the result group |
527 | | /// has more than 1 result. |
528 | | void getResultIDAndNumber(OpResult result, Value &lookupValue, |
529 | | Optional<int> &lookupResultNo) const; |
530 | | |
531 | | /// Set a special value name for the given value. |
532 | | void setValueName(Value value, StringRef name); |
533 | | |
534 | | /// Uniques the given value name within the printer. If the given name |
535 | | /// conflicts, it is automatically renamed. |
536 | | StringRef uniqueValueName(StringRef name); |
537 | | |
538 | | /// This is the value ID for each SSA value. If this returns NameSentinel, |
539 | | /// then the valueID has an entry in valueNames. |
540 | | DenseMap<Value, unsigned> valueIDs; |
541 | | DenseMap<Value, StringRef> valueNames; |
542 | | |
543 | | /// This is a map of operations that contain multiple named result groups, |
544 | | /// i.e. there may be multiple names for the results of the operation. The |
545 | | /// value of this map are the result numbers that start a result group. |
546 | | DenseMap<Operation *, SmallVector<int, 1>> opResultGroups; |
547 | | |
548 | | /// This is the block ID for each block in the current. |
549 | | DenseMap<Block *, unsigned> blockIDs; |
550 | | |
551 | | /// This keeps track of all of the non-numeric names that are in flight, |
552 | | /// allowing us to check for duplicates. |
553 | | /// Note: the value of the map is unused. |
554 | | llvm::ScopedHashTable<StringRef, char> usedNames; |
555 | | llvm::BumpPtrAllocator usedNameAllocator; |
556 | | |
557 | | /// This is the next value ID to assign in numbering. |
558 | | unsigned nextValueID = 0; |
559 | | /// This is the next ID to assign to a region entry block argument. |
560 | | unsigned nextArgumentID = 0; |
561 | | /// This is the next ID to assign when a name conflict is detected. |
562 | | unsigned nextConflictID = 0; |
563 | | }; |
564 | | } // end anonymous namespace |
565 | | |
566 | | SSANameState::SSANameState( |
567 | | Operation *op, |
568 | 0 | DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) { |
569 | 0 | llvm::ScopedHashTable<StringRef, char>::ScopeTy usedNamesScope(usedNames); |
570 | 0 | numberValuesInOp(*op, interfaces); |
571 | 0 |
|
572 | 0 | for (auto ®ion : op->getRegions()) |
573 | 0 | numberValuesInRegion(region, interfaces); |
574 | 0 | } |
575 | | |
576 | | void SSANameState::printValueID(Value value, bool printResultNo, |
577 | 0 | raw_ostream &stream) const { |
578 | 0 | if (!value) { |
579 | 0 | stream << "<<NULL>>"; |
580 | 0 | return; |
581 | 0 | } |
582 | 0 | |
583 | 0 | Optional<int> resultNo; |
584 | 0 | auto lookupValue = value; |
585 | 0 |
|
586 | 0 | // If this is an operation result, collect the head lookup value of the result |
587 | 0 | // group and the result number of 'result' within that group. |
588 | 0 | if (OpResult result = value.dyn_cast<OpResult>()) |
589 | 0 | getResultIDAndNumber(result, lookupValue, resultNo); |
590 | 0 |
|
591 | 0 | auto it = valueIDs.find(lookupValue); |
592 | 0 | if (it == valueIDs.end()) { |
593 | 0 | stream << "<<UNKNOWN SSA VALUE>>"; |
594 | 0 | return; |
595 | 0 | } |
596 | 0 | |
597 | 0 | stream << '%'; |
598 | 0 | if (it->second != NameSentinel) { |
599 | 0 | stream << it->second; |
600 | 0 | } else { |
601 | 0 | auto nameIt = valueNames.find(lookupValue); |
602 | 0 | assert(nameIt != valueNames.end() && "Didn't have a name entry?"); |
603 | 0 | stream << nameIt->second; |
604 | 0 | } |
605 | 0 |
|
606 | 0 | if (resultNo.hasValue() && printResultNo) |
607 | 0 | stream << '#' << resultNo; |
608 | 0 | } |
609 | | |
610 | 0 | ArrayRef<int> SSANameState::getOpResultGroups(Operation *op) { |
611 | 0 | auto it = opResultGroups.find(op); |
612 | 0 | return it == opResultGroups.end() ? ArrayRef<int>() : it->second; |
613 | 0 | } |
614 | | |
615 | 0 | unsigned SSANameState::getBlockID(Block *block) { |
616 | 0 | auto it = blockIDs.find(block); |
617 | 0 | return it != blockIDs.end() ? it->second : NameSentinel; |
618 | 0 | } |
619 | | |
620 | 0 | void SSANameState::shadowRegionArgs(Region ®ion, ValueRange namesToUse) { |
621 | 0 | assert(!region.empty() && "cannot shadow arguments of an empty region"); |
622 | 0 | assert(region.front().getNumArguments() == namesToUse.size() && |
623 | 0 | "incorrect number of names passed in"); |
624 | 0 | assert(region.getParentOp()->isKnownIsolatedFromAbove() && |
625 | 0 | "only KnownIsolatedFromAbove ops can shadow names"); |
626 | 0 |
|
627 | 0 | SmallVector<char, 16> nameStr; |
628 | 0 | for (unsigned i = 0, e = namesToUse.size(); i != e; ++i) { |
629 | 0 | auto nameToUse = namesToUse[i]; |
630 | 0 | if (nameToUse == nullptr) |
631 | 0 | continue; |
632 | 0 | auto nameToReplace = region.front().getArgument(i); |
633 | 0 |
|
634 | 0 | nameStr.clear(); |
635 | 0 | llvm::raw_svector_ostream nameStream(nameStr); |
636 | 0 | printValueID(nameToUse, /*printResultNo=*/true, nameStream); |
637 | 0 |
|
638 | 0 | // Entry block arguments should already have a pretty "arg" name. |
639 | 0 | assert(valueIDs[nameToReplace] == NameSentinel); |
640 | 0 |
|
641 | 0 | // Use the name without the leading %. |
642 | 0 | auto name = StringRef(nameStream.str()).drop_front(); |
643 | 0 |
|
644 | 0 | // Overwrite the name. |
645 | 0 | valueNames[nameToReplace] = name.copy(usedNameAllocator); |
646 | 0 | } |
647 | 0 | } |
648 | | |
649 | | void SSANameState::numberValuesInRegion( |
650 | | Region ®ion, |
651 | 0 | DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) { |
652 | 0 | // Save the current value ids to allow for numbering values in sibling regions |
653 | 0 | // the same. |
654 | 0 | llvm::SaveAndRestore<unsigned> valueIDSaver(nextValueID); |
655 | 0 | llvm::SaveAndRestore<unsigned> argumentIDSaver(nextArgumentID); |
656 | 0 | llvm::SaveAndRestore<unsigned> conflictIDSaver(nextConflictID); |
657 | 0 |
|
658 | 0 | // Push a new used names scope. |
659 | 0 | llvm::ScopedHashTable<StringRef, char>::ScopeTy usedNamesScope(usedNames); |
660 | 0 |
|
661 | 0 | // Number the values within this region in a breadth-first order. |
662 | 0 | unsigned nextBlockID = 0; |
663 | 0 | for (auto &block : region) { |
664 | 0 | // Each block gets a unique ID, and all of the operations within it get |
665 | 0 | // numbered as well. |
666 | 0 | blockIDs[&block] = nextBlockID++; |
667 | 0 | numberValuesInBlock(block, interfaces); |
668 | 0 | } |
669 | 0 |
|
670 | 0 | // After that we traverse the nested regions. |
671 | 0 | // TODO: Rework this loop to not use recursion. |
672 | 0 | for (auto &block : region) { |
673 | 0 | for (auto &op : block) |
674 | 0 | for (auto &nestedRegion : op.getRegions()) |
675 | 0 | numberValuesInRegion(nestedRegion, interfaces); |
676 | 0 | } |
677 | 0 | } |
678 | | |
679 | | void SSANameState::numberValuesInBlock( |
680 | | Block &block, |
681 | 0 | DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) { |
682 | 0 | auto setArgNameFn = [&](Value arg, StringRef name) { |
683 | 0 | assert(!valueIDs.count(arg) && "arg numbered multiple times"); |
684 | 0 | assert(arg.cast<BlockArgument>().getOwner() == &block && |
685 | 0 | "arg not defined in 'block'"); |
686 | 0 | setValueName(arg, name); |
687 | 0 | }; |
688 | 0 |
|
689 | 0 | bool isEntryBlock = block.isEntryBlock(); |
690 | 0 | if (isEntryBlock) { |
691 | 0 | if (auto *op = block.getParentOp()) { |
692 | 0 | if (auto asmInterface = interfaces.getInterfaceFor(op->getDialect())) |
693 | 0 | asmInterface->getAsmBlockArgumentNames(&block, setArgNameFn); |
694 | 0 | } |
695 | 0 | } |
696 | 0 |
|
697 | 0 | // Number the block arguments. We give entry block arguments a special name |
698 | 0 | // 'arg'. |
699 | 0 | SmallString<32> specialNameBuffer(isEntryBlock ? "arg" : ""); |
700 | 0 | llvm::raw_svector_ostream specialName(specialNameBuffer); |
701 | 0 | for (auto arg : block.getArguments()) { |
702 | 0 | if (valueIDs.count(arg)) |
703 | 0 | continue; |
704 | 0 | if (isEntryBlock) { |
705 | 0 | specialNameBuffer.resize(strlen("arg")); |
706 | 0 | specialName << nextArgumentID++; |
707 | 0 | } |
708 | 0 | setValueName(arg, specialName.str()); |
709 | 0 | } |
710 | 0 |
|
711 | 0 | // Number the operations in this block. |
712 | 0 | for (auto &op : block) |
713 | 0 | numberValuesInOp(op, interfaces); |
714 | 0 | } |
715 | | |
716 | | void SSANameState::numberValuesInOp( |
717 | | Operation &op, |
718 | 0 | DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) { |
719 | 0 | unsigned numResults = op.getNumResults(); |
720 | 0 | if (numResults == 0) |
721 | 0 | return; |
722 | 0 | Value resultBegin = op.getResult(0); |
723 | 0 |
|
724 | 0 | // Function used to set the special result names for the operation. |
725 | 0 | SmallVector<int, 2> resultGroups(/*Size=*/1, /*Value=*/0); |
726 | 0 | auto setResultNameFn = [&](Value result, StringRef name) { |
727 | 0 | assert(!valueIDs.count(result) && "result numbered multiple times"); |
728 | 0 | assert(result.getDefiningOp() == &op && "result not defined by 'op'"); |
729 | 0 | setValueName(result, name); |
730 | 0 |
|
731 | 0 | // Record the result number for groups not anchored at 0. |
732 | 0 | if (int resultNo = result.cast<OpResult>().getResultNumber()) |
733 | 0 | resultGroups.push_back(resultNo); |
734 | 0 | }; |
735 | 0 | if (OpAsmOpInterface asmInterface = dyn_cast<OpAsmOpInterface>(&op)) |
736 | 0 | asmInterface.getAsmResultNames(setResultNameFn); |
737 | 0 | else if (auto *asmInterface = interfaces.getInterfaceFor(op.getDialect())) |
738 | 0 | asmInterface->getAsmResultNames(&op, setResultNameFn); |
739 | 0 |
|
740 | 0 | // If the first result wasn't numbered, give it a default number. |
741 | 0 | if (valueIDs.try_emplace(resultBegin, nextValueID).second) |
742 | 0 | ++nextValueID; |
743 | 0 |
|
744 | 0 | // If this operation has multiple result groups, mark it. |
745 | 0 | if (resultGroups.size() != 1) { |
746 | 0 | llvm::array_pod_sort(resultGroups.begin(), resultGroups.end()); |
747 | 0 | opResultGroups.try_emplace(&op, std::move(resultGroups)); |
748 | 0 | } |
749 | 0 | } |
750 | | |
751 | | void SSANameState::getResultIDAndNumber(OpResult result, Value &lookupValue, |
752 | 0 | Optional<int> &lookupResultNo) const { |
753 | 0 | Operation *owner = result.getOwner(); |
754 | 0 | if (owner->getNumResults() == 1) |
755 | 0 | return; |
756 | 0 | int resultNo = result.getResultNumber(); |
757 | 0 |
|
758 | 0 | // If this operation has multiple result groups, we will need to find the |
759 | 0 | // one corresponding to this result. |
760 | 0 | auto resultGroupIt = opResultGroups.find(owner); |
761 | 0 | if (resultGroupIt == opResultGroups.end()) { |
762 | 0 | // If not, just use the first result. |
763 | 0 | lookupResultNo = resultNo; |
764 | 0 | lookupValue = owner->getResult(0); |
765 | 0 | return; |
766 | 0 | } |
767 | 0 | |
768 | 0 | // Find the correct index using a binary search, as the groups are ordered. |
769 | 0 | ArrayRef<int> resultGroups = resultGroupIt->second; |
770 | 0 | auto it = llvm::upper_bound(resultGroups, resultNo); |
771 | 0 | int groupResultNo = 0, groupSize = 0; |
772 | 0 |
|
773 | 0 | // If there are no smaller elements, the last result group is the lookup. |
774 | 0 | if (it == resultGroups.end()) { |
775 | 0 | groupResultNo = resultGroups.back(); |
776 | 0 | groupSize = static_cast<int>(owner->getNumResults()) - resultGroups.back(); |
777 | 0 | } else { |
778 | 0 | // Otherwise, the previous element is the lookup. |
779 | 0 | groupResultNo = *std::prev(it); |
780 | 0 | groupSize = *it - groupResultNo; |
781 | 0 | } |
782 | 0 |
|
783 | 0 | // We only record the result number for a group of size greater than 1. |
784 | 0 | if (groupSize != 1) |
785 | 0 | lookupResultNo = resultNo - groupResultNo; |
786 | 0 | lookupValue = owner->getResult(groupResultNo); |
787 | 0 | } |
788 | | |
789 | 0 | void SSANameState::setValueName(Value value, StringRef name) { |
790 | 0 | // If the name is empty, the value uses the default numbering. |
791 | 0 | if (name.empty()) { |
792 | 0 | valueIDs[value] = nextValueID++; |
793 | 0 | return; |
794 | 0 | } |
795 | 0 | |
796 | 0 | valueIDs[value] = NameSentinel; |
797 | 0 | valueNames[value] = uniqueValueName(name); |
798 | 0 | } |
799 | | |
800 | | /// Returns true if 'c' is an allowable punctuation character: [$._-] |
801 | | /// Returns false otherwise. |
802 | 0 | static bool isPunct(char c) { |
803 | 0 | return c == '$' || c == '.' || c == '_' || c == '-'; |
804 | 0 | } |
805 | | |
806 | 0 | StringRef SSANameState::uniqueValueName(StringRef name) { |
807 | 0 | assert(!name.empty() && "Shouldn't have an empty name here"); |
808 | 0 |
|
809 | 0 | // Check to see if this name is valid. If it starts with a digit, then it |
810 | 0 | // could conflict with the autogenerated numeric ID's (we unique them in a |
811 | 0 | // different map), so add an underscore prefix to avoid problems. |
812 | 0 | if (isdigit(name[0])) { |
813 | 0 | SmallString<16> tmpName("_"); |
814 | 0 | tmpName += name; |
815 | 0 | return uniqueValueName(tmpName); |
816 | 0 | } |
817 | 0 | |
818 | 0 | // Check to see if the name consists of all-valid identifiers. If not, we |
819 | 0 | // need to escape them. |
820 | 0 | for (char ch : name) { |
821 | 0 | if (isalpha(ch) || isPunct(ch) || isdigit(ch)) |
822 | 0 | continue; |
823 | 0 | |
824 | 0 | SmallString<16> tmpName; |
825 | 0 | for (char ch : name) { |
826 | 0 | if (isalpha(ch) || isPunct(ch) || isdigit(ch)) |
827 | 0 | tmpName += ch; |
828 | 0 | else if (ch == ' ') |
829 | 0 | tmpName += '_'; |
830 | 0 | else { |
831 | 0 | tmpName += llvm::utohexstr((unsigned char)ch); |
832 | 0 | } |
833 | 0 | } |
834 | 0 | return uniqueValueName(tmpName); |
835 | 0 | } |
836 | 0 |
|
837 | 0 | // Check to see if this name is already unique. |
838 | 0 | if (!usedNames.count(name)) { |
839 | 0 | name = name.copy(usedNameAllocator); |
840 | 0 | } else { |
841 | 0 | // Otherwise, we had a conflict - probe until we find a unique name. This |
842 | 0 | // is guaranteed to terminate (and usually in a single iteration) because it |
843 | 0 | // generates new names by incrementing nextConflictID. |
844 | 0 | SmallString<64> probeName(name); |
845 | 0 | probeName.push_back('_'); |
846 | 0 | while (true) { |
847 | 0 | probeName.resize(name.size() + 1); |
848 | 0 | probeName += llvm::utostr(nextConflictID++); |
849 | 0 | if (!usedNames.count(probeName)) { |
850 | 0 | name = StringRef(probeName).copy(usedNameAllocator); |
851 | 0 | break; |
852 | 0 | } |
853 | 0 | } |
854 | 0 | } |
855 | 0 |
|
856 | 0 | usedNames.insert(name, char()); |
857 | 0 | return name; |
858 | 0 | } |
859 | | |
860 | | //===----------------------------------------------------------------------===// |
861 | | // AsmState |
862 | | //===----------------------------------------------------------------------===// |
863 | | |
864 | | namespace mlir { |
865 | | namespace detail { |
866 | | class AsmStateImpl { |
867 | | public: |
868 | | explicit AsmStateImpl(Operation *op, AsmState::LocationMap *locationMap) |
869 | | : interfaces(op->getContext()), nameState(op, interfaces), |
870 | 0 | locationMap(locationMap) {} |
871 | | |
872 | | /// Initialize the alias state to enable the printing of aliases. |
873 | 0 | void initializeAliases(Operation *op) { |
874 | 0 | aliasState.initialize(op, interfaces); |
875 | 0 | } |
876 | | |
877 | | /// Get an instance of the OpAsmDialectInterface for the given dialect, or |
878 | | /// null if one wasn't registered. |
879 | 0 | const OpAsmDialectInterface *getOpAsmInterface(Dialect *dialect) { |
880 | 0 | return interfaces.getInterfaceFor(dialect); |
881 | 0 | } |
882 | | |
883 | | /// Get the state used for aliases. |
884 | 0 | AliasState &getAliasState() { return aliasState; } |
885 | | |
886 | | /// Get the state used for SSA names. |
887 | 0 | SSANameState &getSSANameState() { return nameState; } |
888 | | |
889 | | /// Register the location, line and column, within the buffer that the given |
890 | | /// operation was printed at. |
891 | 0 | void registerOperationLocation(Operation *op, unsigned line, unsigned col) { |
892 | 0 | if (locationMap) |
893 | 0 | (*locationMap)[op] = std::make_pair(line, col); |
894 | 0 | } |
895 | | |
896 | | private: |
897 | | /// Collection of OpAsm interfaces implemented in the context. |
898 | | DialectInterfaceCollection<OpAsmDialectInterface> interfaces; |
899 | | |
900 | | /// The state used for attribute and type aliases. |
901 | | AliasState aliasState; |
902 | | |
903 | | /// The state used for SSA value names. |
904 | | SSANameState nameState; |
905 | | |
906 | | /// An optional location map to be populated. |
907 | | AsmState::LocationMap *locationMap; |
908 | | }; |
909 | | } // end namespace detail |
910 | | } // end namespace mlir |
911 | | |
912 | | AsmState::AsmState(Operation *op, LocationMap *locationMap) |
913 | 0 | : impl(std::make_unique<AsmStateImpl>(op, locationMap)) {} |
914 | 0 | AsmState::~AsmState() {} |
915 | | |
916 | | //===----------------------------------------------------------------------===// |
917 | | // ModulePrinter |
918 | | //===----------------------------------------------------------------------===// |
919 | | |
920 | | namespace { |
921 | | class ModulePrinter { |
922 | | public: |
923 | | ModulePrinter(raw_ostream &os, OpPrintingFlags flags = llvm::None, |
924 | | AsmStateImpl *state = nullptr) |
925 | 0 | : os(os), printerFlags(flags), state(state) {} |
926 | | explicit ModulePrinter(ModulePrinter &printer) |
927 | | : os(printer.os), printerFlags(printer.printerFlags), |
928 | 0 | state(printer.state) {} |
929 | | |
930 | | /// Returns the output stream of the printer. |
931 | 0 | raw_ostream &getStream() { return os; } |
932 | | |
933 | | template <typename Container, typename UnaryFunctor> |
934 | 0 | inline void interleaveComma(const Container &c, UnaryFunctor each_fn) const { |
935 | 0 | llvm::interleaveComma(c, os, each_fn); |
936 | 0 | } Unexecuted instantiation: AsmPrinter.cpp:_ZNK12_GLOBAL__N_113ModulePrinter15interleaveCommaIN4llvm8ArrayRefISt4pairIN4mlir10IdentifierENS5_9AttributeEEEEZNS0_14printAttributeES7_NS0_15AttrTypeElisionEE3$_9EEvRKT_T0_ Unexecuted instantiation: AsmPrinter.cpp:_ZNK12_GLOBAL__N_113ModulePrinter15interleaveCommaIN4llvm8ArrayRefIN4mlir9AttributeEEEZNS0_14printAttributeES5_NS0_15AttrTypeElisionEE4$_10EEvRKT_T0_ Unexecuted instantiation: AsmPrinter.cpp:_ZNK12_GLOBAL__N_113ModulePrinter15interleaveCommaIN4llvm8ArrayRefIN4mlir4TypeEEEZNS0_9printTypeES5_E4$_14EEvRKT_T0_ Unexecuted instantiation: AsmPrinter.cpp:_ZNK12_GLOBAL__N_113ModulePrinter15interleaveCommaIN4llvm8ArrayRefIN4mlir4TypeEEEZNS0_9printTypeES5_E4$_15EEvRKT_T0_ Unexecuted instantiation: AsmPrinter.cpp:_ZNK12_GLOBAL__N_113ModulePrinter15interleaveCommaIN4llvm8ArrayRefIN4mlir4TypeEEEZNS0_9printTypeES5_E4$_16EEvRKT_T0_ Unexecuted instantiation: AsmPrinter.cpp:_ZNK12_GLOBAL__N_113ModulePrinter15interleaveCommaIN4llvm8ArrayRefIN4mlir10AffineExprEEEZNS0_14printAffineMapENS4_9AffineMapEE4$_19EEvRKT_T0_ Unexecuted instantiation: AsmPrinter.cpp:_ZNK12_GLOBAL__N_113ModulePrinter15interleaveCommaIN4mlir10ValueRangeEZNS_16OperationPrinter24printSuccessorAndUseListEPNS2_5BlockES3_E4$_26EEvRKT_T0_ Unexecuted instantiation: AsmPrinter.cpp:_ZNK12_GLOBAL__N_113ModulePrinter15interleaveCommaIN4mlir10ValueRangeEZNS_16OperationPrinter24printSuccessorAndUseListEPNS2_5BlockES3_E4$_27EEvRKT_T0_ Unexecuted instantiation: AsmPrinter.cpp:_ZNK12_GLOBAL__N_113ModulePrinter15interleaveCommaIN4llvm11SmallVectorISt4pairIN4mlir10IdentifierENS5_9AttributeEELj8EEEZNS0_21printOptionalAttrDictENS2_8ArrayRefIS8_EENSA_INS2_9StringRefEEEbE4$_18EEvRKT_T0_ Unexecuted instantiation: AsmPrinter.cpp:_ZNK12_GLOBAL__N_113ModulePrinter15interleaveCommaIN4mlir12OperandRangeEZNS_16OperationPrinter14printGenericOpEPNS2_9OperationEE4$_21EEvRKT_T0_ Unexecuted instantiation: 
AsmPrinter.cpp:_ZNK12_GLOBAL__N_113ModulePrinter15interleaveCommaIN4mlir14SuccessorRangeEZNS_16OperationPrinter14printGenericOpEPNS2_9OperationEE4$_22EEvRKT_T0_ Unexecuted instantiation: AsmPrinter.cpp:_ZNK12_GLOBAL__N_113ModulePrinter15interleaveCommaIN4llvm15MutableArrayRefIN4mlir6RegionEEEZNS_16OperationPrinter14printGenericOpEPNS4_9OperationEE4$_23EEvRKT_T0_ Unexecuted instantiation: AsmPrinter.cpp:_ZNK12_GLOBAL__N_113ModulePrinter15interleaveCommaIN4llvm8ArrayRefIN4mlir10AffineExprEEEZNS_16OperationPrinter22printAffineMapOfSSAIdsENS4_13AffineMapAttrENS4_10ValueRangeEE4$_29EEvRKT_T0_ Unexecuted instantiation: AsmPrinter.cpp:_ZNK12_GLOBAL__N_113ModulePrinter15interleaveCommaIN4llvm14iterator_rangeINS2_6detail23value_sequence_iteratorIiEEEEZNS_16OperationPrinter14printOperationEPN4mlir9OperationEE4$_20EEvRKT_T0_ Unexecuted instantiation: AsmPrinter.cpp:_ZNK12_GLOBAL__N_113ModulePrinter15interleaveCommaIN4llvm15MutableArrayRefIN4mlir13BlockArgumentEEEZNS_16OperationPrinter5printEPNS4_5BlockEbbE4$_24EEvRKT_T0_ Unexecuted instantiation: AsmPrinter.cpp:_ZNK12_GLOBAL__N_113ModulePrinter15interleaveCommaIN4llvm11SmallVectorISt4pairIjPN4mlir5BlockEELj4EEEZNS_16OperationPrinter5printES7_bbE4$_25EEvRKT_T0_ |
937 | | |
938 | | /// This enum describes the different kinds of elision for the type of an |
939 | | /// attribute when printing it. |
940 | | enum class AttrTypeElision { |
941 | | /// The type must not be elided, |
942 | | Never, |
943 | | /// The type may be elided when it matches the default used in the parser |
944 | | /// (for example i64 is the default for integer attributes). |
945 | | May, |
946 | | /// The type must be elided. |
947 | | Must |
948 | | }; |
949 | | |
950 | | /// Print the given attribute. |
951 | | void printAttribute(Attribute attr, |
952 | | AttrTypeElision typeElision = AttrTypeElision::Never); |
953 | | |
954 | | void printType(Type type); |
955 | | void printLocation(LocationAttr loc); |
956 | | |
957 | | void printAffineMap(AffineMap map); |
958 | | void |
959 | | printAffineExpr(AffineExpr expr, |
960 | | function_ref<void(unsigned, bool)> printValueName = nullptr); |
961 | | void printAffineConstraint(AffineExpr expr, bool isEq); |
962 | | void printIntegerSet(IntegerSet set); |
963 | | |
964 | | protected: |
965 | | void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs, |
966 | | ArrayRef<StringRef> elidedAttrs = {}, |
967 | | bool withKeyword = false); |
968 | | void printNamedAttribute(NamedAttribute attr); |
969 | | void printTrailingLocation(Location loc); |
970 | | void printLocationInternal(LocationAttr loc, bool pretty = false); |
971 | | |
972 | | /// Print a dense elements attribute. If 'allowHex' is true, a hex string is |
973 | | /// used instead of individual elements when the elements attr is large. |
974 | | void printDenseElementsAttr(DenseElementsAttr attr, bool allowHex); |
975 | | |
976 | | /// Print a dense string elements attribute. |
977 | | void printDenseStringElementsAttr(DenseStringElementsAttr attr); |
978 | | |
979 | | /// Print a dense elements attribute. If 'allowHex' is true, a hex string is |
980 | | /// used instead of individual elements when the elements attr is large. |
981 | | void printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr, |
982 | | bool allowHex); |
983 | | |
984 | | void printDialectAttribute(Attribute attr); |
985 | | void printDialectType(Type type); |
986 | | |
987 | | /// This enum is used to represent the binding strength of the enclosing |
988 | | /// context that an AffineExprStorage is being printed in, so we can |
989 | | /// intelligently produce parens. |
990 | | enum class BindingStrength { |
991 | | Weak, // + and - |
992 | | Strong, // All other binary operators. |
993 | | }; |
994 | | void printAffineExprInternal( |
995 | | AffineExpr expr, BindingStrength enclosingTightness, |
996 | | function_ref<void(unsigned, bool)> printValueName = nullptr); |
997 | | |
998 | | /// The output stream for the printer. |
999 | | raw_ostream &os; |
1000 | | |
1001 | | /// A set of flags to control the printer's behavior. |
1002 | | OpPrintingFlags printerFlags; |
1003 | | |
1004 | | /// An optional printer state for the module. |
1005 | | AsmStateImpl *state; |
1006 | | |
1007 | | /// A tracker for the number of new lines emitted during printing. |
1008 | | NewLineCounter newLine; |
1009 | | }; |
1010 | | } // end anonymous namespace |
1011 | | |
1012 | 0 | void ModulePrinter::printTrailingLocation(Location loc) { |
1013 | 0 | // Check to see if we are printing debug information. |
1014 | 0 | if (!printerFlags.shouldPrintDebugInfo()) |
1015 | 0 | return; |
1016 | 0 | |
1017 | 0 | os << " "; |
1018 | 0 | printLocation(loc); |
1019 | 0 | } |
1020 | | |
1021 | 0 | void ModulePrinter::printLocationInternal(LocationAttr loc, bool pretty) { |
1022 | 0 | switch (loc.getKind()) { |
1023 | 0 | case StandardAttributes::OpaqueLocation: |
1024 | 0 | printLocationInternal(loc.cast<OpaqueLoc>().getFallbackLocation(), pretty); |
1025 | 0 | break; |
1026 | 0 | case StandardAttributes::UnknownLocation: |
1027 | 0 | if (pretty) |
1028 | 0 | os << "[unknown]"; |
1029 | 0 | else |
1030 | 0 | os << "unknown"; |
1031 | 0 | break; |
1032 | 0 | case StandardAttributes::FileLineColLocation: { |
1033 | 0 | auto fileLoc = loc.cast<FileLineColLoc>(); |
1034 | 0 | auto mayQuote = pretty ? "" : "\""; |
1035 | 0 | os << mayQuote << fileLoc.getFilename() << mayQuote << ':' |
1036 | 0 | << fileLoc.getLine() << ':' << fileLoc.getColumn(); |
1037 | 0 | break; |
1038 | 0 | } |
1039 | 0 | case StandardAttributes::NameLocation: { |
1040 | 0 | auto nameLoc = loc.cast<NameLoc>(); |
1041 | 0 | os << '\"' << nameLoc.getName() << '\"'; |
1042 | 0 |
|
1043 | 0 | // Print the child if it isn't unknown. |
1044 | 0 | auto childLoc = nameLoc.getChildLoc(); |
1045 | 0 | if (!childLoc.isa<UnknownLoc>()) { |
1046 | 0 | os << '('; |
1047 | 0 | printLocationInternal(childLoc, pretty); |
1048 | 0 | os << ')'; |
1049 | 0 | } |
1050 | 0 | break; |
1051 | 0 | } |
1052 | 0 | case StandardAttributes::CallSiteLocation: { |
1053 | 0 | auto callLocation = loc.cast<CallSiteLoc>(); |
1054 | 0 | auto caller = callLocation.getCaller(); |
1055 | 0 | auto callee = callLocation.getCallee(); |
1056 | 0 | if (!pretty) |
1057 | 0 | os << "callsite("; |
1058 | 0 | printLocationInternal(callee, pretty); |
1059 | 0 | if (pretty) { |
1060 | 0 | if (callee.isa<NameLoc>()) { |
1061 | 0 | if (caller.isa<FileLineColLoc>()) { |
1062 | 0 | os << " at "; |
1063 | 0 | } else { |
1064 | 0 | os << newLine << " at "; |
1065 | 0 | } |
1066 | 0 | } else { |
1067 | 0 | os << newLine << " at "; |
1068 | 0 | } |
1069 | 0 | } else { |
1070 | 0 | os << " at "; |
1071 | 0 | } |
1072 | 0 | printLocationInternal(caller, pretty); |
1073 | 0 | if (!pretty) |
1074 | 0 | os << ")"; |
1075 | 0 | break; |
1076 | 0 | } |
1077 | 0 | case StandardAttributes::FusedLocation: { |
1078 | 0 | auto fusedLoc = loc.cast<FusedLoc>(); |
1079 | 0 | if (!pretty) |
1080 | 0 | os << "fused"; |
1081 | 0 | if (auto metadata = fusedLoc.getMetadata()) |
1082 | 0 | os << '<' << metadata << '>'; |
1083 | 0 | os << '['; |
1084 | 0 | interleave( |
1085 | 0 | fusedLoc.getLocations(), |
1086 | 0 | [&](Location loc) { printLocationInternal(loc, pretty); }, |
1087 | 0 | [&]() { os << ", "; }); |
1088 | 0 | os << ']'; |
1089 | 0 | break; |
1090 | 0 | } |
1091 | 0 | } |
1092 | 0 | } |
1093 | | |
1094 | | /// Print a floating point value in a way that the parser will be able to |
1095 | | /// round-trip losslessly. |
1096 | 0 | static void printFloatValue(const APFloat &apValue, raw_ostream &os) { |
1097 | 0 | // We would like to output the FP constant value in exponential notation, |
1098 | 0 | // but we cannot do this if doing so will lose precision. Check here to |
1099 | 0 | // make sure that we only output it in exponential format if we can parse |
1100 | 0 | // the value back and get the same value. |
1101 | 0 | bool isInf = apValue.isInfinity(); |
1102 | 0 | bool isNaN = apValue.isNaN(); |
1103 | 0 | if (!isInf && !isNaN) { |
1104 | 0 | SmallString<128> strValue; |
1105 | 0 | apValue.toString(strValue, /*FormatPrecision=*/6, /*FormatMaxPadding=*/0, |
1106 | 0 | /*TruncateZero=*/false); |
1107 | 0 |
|
1108 | 0 | // Check to make sure that the stringized number is not some string like |
1109 | 0 | // "Inf" or NaN, that atof will accept, but the lexer will not. Check |
1110 | 0 | // that the string matches the "[-+]?[0-9]" regex. |
1111 | 0 | assert(((strValue[0] >= '0' && strValue[0] <= '9') || |
1112 | 0 | ((strValue[0] == '-' || strValue[0] == '+') && |
1113 | 0 | (strValue[1] >= '0' && strValue[1] <= '9'))) && |
1114 | 0 | "[-+]?[0-9] regex does not match!"); |
1115 | 0 |
|
1116 | 0 | // Parse back the stringized version and check that the value is equal |
1117 | 0 | // (i.e., there is no precision loss). |
1118 | 0 | if (APFloat(apValue.getSemantics(), strValue).bitwiseIsEqual(apValue)) { |
1119 | 0 | os << strValue; |
1120 | 0 | return; |
1121 | 0 | } |
1122 | 0 | |
1123 | 0 | // If it is not, use the default format of APFloat instead of the |
1124 | 0 | // exponential notation. |
1125 | 0 | strValue.clear(); |
1126 | 0 | apValue.toString(strValue); |
1127 | 0 |
|
1128 | 0 | // Make sure that we can parse the default form as a float. |
1129 | 0 | if (StringRef(strValue).contains('.')) { |
1130 | 0 | os << strValue; |
1131 | 0 | return; |
1132 | 0 | } |
1133 | 0 | } |
1134 | 0 | |
1135 | 0 | // Print special values in hexadecimal format. The sign bit should be included |
1136 | 0 | // in the literal. |
1137 | 0 | SmallVector<char, 16> str; |
1138 | 0 | APInt apInt = apValue.bitcastToAPInt(); |
1139 | 0 | apInt.toString(str, /*Radix=*/16, /*Signed=*/false, |
1140 | 0 | /*formatAsCLiteral=*/true); |
1141 | 0 | os << str; |
1142 | 0 | } |
1143 | | |
1144 | 0 | void ModulePrinter::printLocation(LocationAttr loc) { |
1145 | 0 | if (printerFlags.shouldPrintDebugInfoPrettyForm()) { |
1146 | 0 | printLocationInternal(loc, /*pretty=*/true); |
1147 | 0 | } else { |
1148 | 0 | os << "loc("; |
1149 | 0 | printLocationInternal(loc); |
1150 | 0 | os << ')'; |
1151 | 0 | } |
1152 | 0 | } |
1153 | | |
1154 | | /// Returns if the given dialect symbol data is simple enough to print in the |
1155 | | /// pretty form, i.e. without the enclosing "". |
1156 | 0 | static bool isDialectSymbolSimpleEnoughForPrettyForm(StringRef symName) { |
1157 | 0 | // The name must start with an identifier. |
1158 | 0 | if (symName.empty() || !isalpha(symName.front())) |
1159 | 0 | return false; |
1160 | 0 | |
1161 | 0 | // Ignore all the characters that are valid in an identifier in the symbol |
1162 | 0 | // name. |
1163 | 0 | symName = symName.drop_while( |
1164 | 0 | [](char c) { return llvm::isAlnum(c) || c == '.' || c == '_'; }); |
1165 | 0 | if (symName.empty()) |
1166 | 0 | return true; |
1167 | 0 | |
1168 | 0 | // If we got to an unexpected character, then it must be a <>. Check those |
1169 | 0 | // recursively. |
1170 | 0 | if (symName.front() != '<' || symName.back() != '>') |
1171 | 0 | return false; |
1172 | 0 | |
1173 | 0 | SmallVector<char, 8> nestedPunctuation; |
1174 | 0 | do { |
1175 | 0 | // If we ran out of characters, then we had a punctuation mismatch. |
1176 | 0 | if (symName.empty()) |
1177 | 0 | return false; |
1178 | 0 | |
1179 | 0 | auto c = symName.front(); |
1180 | 0 | symName = symName.drop_front(); |
1181 | 0 |
|
1182 | 0 | switch (c) { |
1183 | 0 | // We never allow null characters. This is an EOF indicator for the lexer |
1184 | 0 | // which we could handle, but isn't important for any known dialect. |
1185 | 0 | case '\0': |
1186 | 0 | return false; |
1187 | 0 | case '<': |
1188 | 0 | case '[': |
1189 | 0 | case '(': |
1190 | 0 | case '{': |
1191 | 0 | nestedPunctuation.push_back(c); |
1192 | 0 | continue; |
1193 | 0 | case '-': |
1194 | 0 | // Treat `->` as a special token. |
1195 | 0 | if (!symName.empty() && symName.front() == '>') { |
1196 | 0 | symName = symName.drop_front(); |
1197 | 0 | continue; |
1198 | 0 | } |
1199 | 0 | break; |
1200 | 0 | // Reject types with mismatched brackets. |
1201 | 0 | case '>': |
1202 | 0 | if (nestedPunctuation.pop_back_val() != '<') |
1203 | 0 | return false; |
1204 | 0 | break; |
1205 | 0 | case ']': |
1206 | 0 | if (nestedPunctuation.pop_back_val() != '[') |
1207 | 0 | return false; |
1208 | 0 | break; |
1209 | 0 | case ')': |
1210 | 0 | if (nestedPunctuation.pop_back_val() != '(') |
1211 | 0 | return false; |
1212 | 0 | break; |
1213 | 0 | case '}': |
1214 | 0 | if (nestedPunctuation.pop_back_val() != '{') |
1215 | 0 | return false; |
1216 | 0 | break; |
1217 | 0 | default: |
1218 | 0 | continue; |
1219 | 0 | } |
1220 | 0 | |
1221 | 0 | // We're done when the punctuation is fully matched. |
1222 | 0 | } while (!nestedPunctuation.empty()); |
1223 | 0 |
|
1224 | 0 | // If there were extra characters, then we failed. |
1225 | 0 | return symName.empty(); |
1226 | 0 | } |
1227 | | |
1228 | | /// Print the given dialect symbol to the stream. |
1229 | | static void printDialectSymbol(raw_ostream &os, StringRef symPrefix, |
1230 | 0 | StringRef dialectName, StringRef symString) { |
1231 | 0 | os << symPrefix << dialectName; |
1232 | 0 |
|
1233 | 0 | // If this symbol name is simple enough, print it directly in pretty form, |
1234 | 0 | // otherwise, we print it as an escaped string. |
1235 | 0 | if (isDialectSymbolSimpleEnoughForPrettyForm(symString)) { |
1236 | 0 | os << '.' << symString; |
1237 | 0 | return; |
1238 | 0 | } |
1239 | 0 | |
1240 | 0 | // TODO: escape the symbol name, it could contain " characters. |
1241 | 0 | os << "<\"" << symString << "\">"; |
1242 | 0 | } |
1243 | | |
1244 | | /// Returns if the given string can be represented as a bare identifier. |
1245 | 0 | static bool isBareIdentifier(StringRef name) { |
1246 | 0 | assert(!name.empty() && "invalid name"); |
1247 | 0 |
|
1248 | 0 | // By making this unsigned, the value passed in to isalnum will always be |
1249 | 0 | // in the range 0-255. This is important when building with MSVC because |
1250 | 0 | // its implementation will assert. This situation can arise when dealing |
1251 | 0 | // with UTF-8 multibyte characters. |
1252 | 0 | unsigned char firstChar = static_cast<unsigned char>(name[0]); |
1253 | 0 | if (!isalpha(firstChar) && firstChar != '_') |
1254 | 0 | return false; |
1255 | 0 | return llvm::all_of(name.drop_front(), [](unsigned char c) { |
1256 | 0 | return isalnum(c) || c == '_' || c == '$' || c == '.'; |
1257 | 0 | }); |
1258 | 0 | } |
1259 | | |
1260 | | /// Print the given string as a symbol reference. A symbol reference is |
1261 | | /// represented as a string prefixed with '@'. The reference is surrounded with |
1262 | | /// ""'s and escaped if it has any special or non-printable characters in it. |
1263 | 0 | static void printSymbolReference(StringRef symbolRef, raw_ostream &os) { |
1264 | 0 | assert(!symbolRef.empty() && "expected valid symbol reference"); |
1265 | 0 |
|
1266 | 0 | // If the symbol can be represented as a bare identifier, write it directly. |
1267 | 0 | if (isBareIdentifier(symbolRef)) { |
1268 | 0 | os << '@' << symbolRef; |
1269 | 0 | return; |
1270 | 0 | } |
1271 | 0 | |
1272 | 0 | // Otherwise, output the reference wrapped in quotes with proper escaping. |
1273 | 0 | os << "@\""; |
1274 | 0 | printEscapedString(symbolRef, os); |
1275 | 0 | os << '"'; |
1276 | 0 | } |
1277 | | |
1278 | | // Print out a valid ElementsAttr that is succinct and can represent any |
1279 | | // potential shape/type, for use when eliding a large ElementsAttr. |
1280 | | // |
1281 | | // We choose to use an opaque ElementsAttr literal with conspicuous content to |
1282 | | // hopefully alert readers to the fact that this has been elided. |
1283 | | // |
1284 | | // Unfortunately, neither of the strings of an opaque ElementsAttr literal will |
1285 | | // accept the string "elided". The first string must be a registered dialect |
1286 | | // name and the latter must be a hex constant. |
1287 | 0 | static void printElidedElementsAttr(raw_ostream &os) { |
1288 | 0 | os << R"(opaque<"", "0xDEADBEEF">)"; |
1289 | 0 | } |
1290 | | |
1291 | | void ModulePrinter::printAttribute(Attribute attr, |
1292 | 0 | AttrTypeElision typeElision) { |
1293 | 0 | if (!attr) { |
1294 | 0 | os << "<<NULL ATTRIBUTE>>"; |
1295 | 0 | return; |
1296 | 0 | } |
1297 | 0 | |
1298 | 0 | // Check for an alias for this attribute. |
1299 | 0 | if (state) { |
1300 | 0 | Twine alias = state->getAliasState().getAttributeAlias(attr); |
1301 | 0 | if (!alias.isTriviallyEmpty()) { |
1302 | 0 | os << '#' << alias; |
1303 | 0 | return; |
1304 | 0 | } |
1305 | 0 | } |
1306 | 0 | |
1307 | 0 | auto attrType = attr.getType(); |
1308 | 0 | switch (attr.getKind()) { |
1309 | 0 | default: |
1310 | 0 | return printDialectAttribute(attr); |
1311 | 0 |
|
1312 | 0 | case StandardAttributes::Opaque: { |
1313 | 0 | auto opaqueAttr = attr.cast<OpaqueAttr>(); |
1314 | 0 | printDialectSymbol(os, "#", opaqueAttr.getDialectNamespace(), |
1315 | 0 | opaqueAttr.getAttrData()); |
1316 | 0 | break; |
1317 | 0 | } |
1318 | 0 | case StandardAttributes::Unit: |
1319 | 0 | os << "unit"; |
1320 | 0 | break; |
1321 | 0 | case StandardAttributes::Bool: |
1322 | 0 | os << (attr.cast<BoolAttr>().getValue() ? "true" : "false"); |
1323 | 0 |
|
1324 | 0 | // BoolAttr always elides the type. |
1325 | 0 | return; |
1326 | 0 | case StandardAttributes::Dictionary: |
1327 | 0 | os << '{'; |
1328 | 0 | interleaveComma(attr.cast<DictionaryAttr>().getValue(), |
1329 | 0 | [&](NamedAttribute attr) { printNamedAttribute(attr); }); |
1330 | 0 | os << '}'; |
1331 | 0 | break; |
1332 | 0 | case StandardAttributes::Integer: { |
1333 | 0 | auto intAttr = attr.cast<IntegerAttr>(); |
1334 | 0 | // Only print attributes as unsigned if they are explicitly unsigned or are |
1335 | 0 | // signless 1-bit values. Indexes, signed values, and multi-bit signless |
1336 | 0 | // values print as signed. |
1337 | 0 | bool isUnsigned = |
1338 | 0 | attrType.isUnsignedInteger() || attrType.isSignlessInteger(1); |
1339 | 0 | intAttr.getValue().print(os, !isUnsigned); |
1340 | 0 |
|
1341 | 0 | // IntegerAttr elides the type if I64. |
1342 | 0 | if (typeElision == AttrTypeElision::May && attrType.isSignlessInteger(64)) |
1343 | 0 | return; |
1344 | 0 | break; |
1345 | 0 | } |
1346 | 0 | case StandardAttributes::Float: { |
1347 | 0 | auto floatAttr = attr.cast<FloatAttr>(); |
1348 | 0 | printFloatValue(floatAttr.getValue(), os); |
1349 | 0 |
|
1350 | 0 | // FloatAttr elides the type if F64. |
1351 | 0 | if (typeElision == AttrTypeElision::May && attrType.isF64()) |
1352 | 0 | return; |
1353 | 0 | break; |
1354 | 0 | } |
1355 | 0 | case StandardAttributes::String: |
1356 | 0 | os << '"'; |
1357 | 0 | printEscapedString(attr.cast<StringAttr>().getValue(), os); |
1358 | 0 | os << '"'; |
1359 | 0 | break; |
1360 | 0 | case StandardAttributes::Array: |
1361 | 0 | os << '['; |
1362 | 0 | interleaveComma(attr.cast<ArrayAttr>().getValue(), [&](Attribute attr) { |
1363 | 0 | printAttribute(attr, AttrTypeElision::May); |
1364 | 0 | }); |
1365 | 0 | os << ']'; |
1366 | 0 | break; |
1367 | 0 | case StandardAttributes::AffineMap: |
1368 | 0 | os << "affine_map<"; |
1369 | 0 | attr.cast<AffineMapAttr>().getValue().print(os); |
1370 | 0 | os << '>'; |
1371 | 0 |
|
1372 | 0 | // AffineMap always elides the type. |
1373 | 0 | return; |
1374 | 0 | case StandardAttributes::IntegerSet: |
1375 | 0 | os << "affine_set<"; |
1376 | 0 | attr.cast<IntegerSetAttr>().getValue().print(os); |
1377 | 0 | os << '>'; |
1378 | 0 |
|
1379 | 0 | // IntegerSet always elides the type. |
1380 | 0 | return; |
1381 | 0 | case StandardAttributes::Type: |
1382 | 0 | printType(attr.cast<TypeAttr>().getValue()); |
1383 | 0 | break; |
1384 | 0 | case StandardAttributes::SymbolRef: { |
1385 | 0 | auto refAttr = attr.dyn_cast<SymbolRefAttr>(); |
1386 | 0 | printSymbolReference(refAttr.getRootReference(), os); |
1387 | 0 | for (FlatSymbolRefAttr nestedRef : refAttr.getNestedReferences()) { |
1388 | 0 | os << "::"; |
1389 | 0 | printSymbolReference(nestedRef.getValue(), os); |
1390 | 0 | } |
1391 | 0 | break; |
1392 | 0 | } |
1393 | 0 | case StandardAttributes::OpaqueElements: { |
1394 | 0 | auto eltsAttr = attr.cast<OpaqueElementsAttr>(); |
1395 | 0 | if (printerFlags.shouldElideElementsAttr(eltsAttr)) { |
1396 | 0 | printElidedElementsAttr(os); |
1397 | 0 | break; |
1398 | 0 | } |
1399 | 0 | os << "opaque<\"" << eltsAttr.getDialect()->getNamespace() << "\", "; |
1400 | 0 | os << '"' << "0x" << llvm::toHex(eltsAttr.getValue()) << "\">"; |
1401 | 0 | break; |
1402 | 0 | } |
1403 | 0 | case StandardAttributes::DenseIntOrFPElements: { |
1404 | 0 | auto eltsAttr = attr.cast<DenseIntOrFPElementsAttr>(); |
1405 | 0 | if (printerFlags.shouldElideElementsAttr(eltsAttr)) { |
1406 | 0 | printElidedElementsAttr(os); |
1407 | 0 | break; |
1408 | 0 | } |
1409 | 0 | os << "dense<"; |
1410 | 0 | printDenseIntOrFPElementsAttr(eltsAttr, /*allowHex=*/true); |
1411 | 0 | os << '>'; |
1412 | 0 | break; |
1413 | 0 | } |
1414 | 0 | case StandardAttributes::DenseStringElements: { |
1415 | 0 | auto eltsAttr = attr.cast<DenseStringElementsAttr>(); |
1416 | 0 | if (printerFlags.shouldElideElementsAttr(eltsAttr)) { |
1417 | 0 | printElidedElementsAttr(os); |
1418 | 0 | break; |
1419 | 0 | } |
1420 | 0 | os << "dense<"; |
1421 | 0 | printDenseStringElementsAttr(eltsAttr); |
1422 | 0 | os << '>'; |
1423 | 0 | break; |
1424 | 0 | } |
1425 | 0 | case StandardAttributes::SparseElements: { |
1426 | 0 | auto elementsAttr = attr.cast<SparseElementsAttr>(); |
1427 | 0 | if (printerFlags.shouldElideElementsAttr(elementsAttr.getIndices()) || |
1428 | 0 | printerFlags.shouldElideElementsAttr(elementsAttr.getValues())) { |
1429 | 0 | printElidedElementsAttr(os); |
1430 | 0 | break; |
1431 | 0 | } |
1432 | 0 | os << "sparse<"; |
1433 | 0 | printDenseIntOrFPElementsAttr(elementsAttr.getIndices(), |
1434 | 0 | /*allowHex=*/false); |
1435 | 0 | os << ", "; |
1436 | 0 | printDenseElementsAttr(elementsAttr.getValues(), /*allowHex=*/true); |
1437 | 0 | os << '>'; |
1438 | 0 | break; |
1439 | 0 | } |
1440 | 0 |
|
1441 | 0 | // Location attributes. |
1442 | 0 | case StandardAttributes::CallSiteLocation: |
1443 | 0 | case StandardAttributes::FileLineColLocation: |
1444 | 0 | case StandardAttributes::FusedLocation: |
1445 | 0 | case StandardAttributes::NameLocation: |
1446 | 0 | case StandardAttributes::OpaqueLocation: |
1447 | 0 | case StandardAttributes::UnknownLocation: |
1448 | 0 | printLocation(attr.cast<LocationAttr>()); |
1449 | 0 | break; |
1450 | 0 | } |
1451 | 0 | |
1452 | 0 | // Don't print the type if we must elide it, or if it is a None type. |
1453 | 0 | if (typeElision != AttrTypeElision::Must && !attrType.isa<NoneType>()) { |
1454 | 0 | os << " : "; |
1455 | 0 | printType(attrType); |
1456 | 0 | } |
1457 | 0 | } |
1458 | | |
1459 | | /// Print the integer element of a DenseElementsAttr. |
1460 | | static void printDenseIntElement(const APInt &value, raw_ostream &os, |
1461 | 0 | bool isSigned) { |
1462 | 0 | if (value.getBitWidth() == 1) |
1463 | 0 | os << (value.getBoolValue() ? "true" : "false"); |
1464 | 0 | else |
1465 | 0 | value.print(os, isSigned); |
1466 | 0 | } |
1467 | | |
1468 | | static void |
1469 | | printDenseElementsAttrImpl(bool isSplat, ShapedType type, raw_ostream &os, |
1470 | 0 | function_ref<void(unsigned)> printEltFn) { |
1471 | 0 | // Special case for 0-d and splat tensors. |
1472 | 0 | if (isSplat) |
1473 | 0 | return printEltFn(0); |
1474 | 0 | |
1475 | 0 | // Special case for degenerate tensors. |
1476 | 0 | auto numElements = type.getNumElements(); |
1477 | 0 | int64_t rank = type.getRank(); |
1478 | 0 | if (numElements == 0) { |
1479 | 0 | for (int i = 0; i < rank; ++i) |
1480 | 0 | os << '['; |
1481 | 0 | for (int i = 0; i < rank; ++i) |
1482 | 0 | os << ']'; |
1483 | 0 | return; |
1484 | 0 | } |
1485 | 0 |
|
1486 | 0 | // We use a mixed-radix counter to iterate through the shape. When we bump a |
1487 | 0 | // non-least-significant digit, we emit a close bracket. When we next emit an |
1488 | 0 | // element we re-open all closed brackets. |
1489 | 0 |
|
1490 | 0 | // The mixed-radix counter, with radices in 'shape'. |
1491 | 0 | SmallVector<unsigned, 4> counter(rank, 0); |
1492 | 0 | // The number of brackets that have been opened and not closed. |
1493 | 0 | unsigned openBrackets = 0; |
1494 | 0 |
|
1495 | 0 | auto shape = type.getShape(); |
1496 | 0 | auto bumpCounter = [&] { |
1497 | 0 | // Bump the least significant digit. |
1498 | 0 | ++counter[rank - 1]; |
1499 | 0 | // Iterate backwards bubbling back the increment. |
1500 | 0 | for (unsigned i = rank - 1; i > 0; --i) |
1501 | 0 | if (counter[i] >= shape[i]) { |
1502 | 0 | // Index 'i' is rolled over. Bump (i-1) and close a bracket. |
1503 | 0 | counter[i] = 0; |
1504 | 0 | ++counter[i - 1]; |
1505 | 0 | --openBrackets; |
1506 | 0 | os << ']'; |
1507 | 0 | } |
1508 | 0 | }; |
1509 | 0 |
|
1510 | 0 | for (unsigned idx = 0, e = numElements; idx != e; ++idx) { |
1511 | 0 | if (idx != 0) |
1512 | 0 | os << ", "; |
1513 | 0 | while (openBrackets++ < rank) |
1514 | 0 | os << '['; |
1515 | 0 | openBrackets = rank; |
1516 | 0 | printEltFn(idx); |
1517 | 0 | bumpCounter(); |
1518 | 0 | } |
1519 | 0 | while (openBrackets-- > 0) |
1520 | 0 | os << ']'; |
1521 | 0 | } |
1522 | | |
1523 | | void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr, |
1524 | 0 | bool allowHex) { |
1525 | 0 | if (auto stringAttr = attr.dyn_cast<DenseStringElementsAttr>()) |
1526 | 0 | return printDenseStringElementsAttr(stringAttr); |
1527 | 0 | |
1528 | 0 | printDenseIntOrFPElementsAttr(attr.cast<DenseIntOrFPElementsAttr>(), |
1529 | 0 | allowHex); |
1530 | 0 | } |
1531 | | |
1532 | | void ModulePrinter::printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr, |
1533 | 0 | bool allowHex) { |
1534 | 0 | auto type = attr.getType(); |
1535 | 0 | auto elementType = type.getElementType(); |
1536 | 0 |
|
1537 | 0 | // Check to see if we should format this attribute as a hex string. |
1538 | 0 | auto numElements = type.getNumElements(); |
1539 | 0 | if (!attr.isSplat() && allowHex && |
1540 | 0 | shouldPrintElementsAttrWithHex(numElements)) { |
1541 | 0 | ArrayRef<char> rawData = attr.getRawData(); |
1542 | 0 | os << '"' << "0x" << llvm::toHex(StringRef(rawData.data(), rawData.size())) |
1543 | 0 | << "\""; |
1544 | 0 | return; |
1545 | 0 | } |
1546 | 0 | |
1547 | 0 | if (ComplexType complexTy = elementType.dyn_cast<ComplexType>()) { |
1548 | 0 | auto printComplexValue = [&](auto complexValues, auto printFn, |
1549 | 0 | raw_ostream &os, auto &&... params) { |
1550 | 0 | printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) { |
1551 | 0 | auto complexValue = *(complexValues.begin() + index); |
1552 | 0 | os << "("; |
1553 | 0 | printFn(complexValue.real(), os, params...); |
1554 | 0 | os << ","; |
1555 | 0 | printFn(complexValue.imag(), os, params...); |
1556 | 0 | os << ")"; |
1557 | 0 | }); Unexecuted instantiation: AsmPrinter.cpp:_ZZZN12_GLOBAL__N_113ModulePrinter29printDenseIntOrFPElementsAttrEN4mlir24DenseIntOrFPElementsAttrEbENK3$_1clIN4llvm14iterator_rangeINS1_17DenseElementsAttr25ComplexIntElementIteratorEEEPFvRKNS5_5APIntERNS5_11raw_ostreamEbEJbEEEDaT_T0_SE_DpOT1_ENKUljE_clEj Unexecuted instantiation: AsmPrinter.cpp:_ZZZN12_GLOBAL__N_113ModulePrinter29printDenseIntOrFPElementsAttrEN4mlir24DenseIntOrFPElementsAttrEbENK3$_1clIN4llvm14iterator_rangeINS1_17DenseElementsAttr27ComplexFloatElementIteratorEEEPFvRKNS5_7APFloatERNS5_11raw_ostreamEEJEEEDaT_T0_SE_DpOT1_ENKUljE_clEj |
1558 | 0 | }; Unexecuted instantiation: AsmPrinter.cpp:_ZZN12_GLOBAL__N_113ModulePrinter29printDenseIntOrFPElementsAttrEN4mlir24DenseIntOrFPElementsAttrEbENK3$_1clIN4llvm14iterator_rangeINS1_17DenseElementsAttr25ComplexIntElementIteratorEEEPFvRKNS5_5APIntERNS5_11raw_ostreamEbEJbEEEDaT_T0_SE_DpOT1_ Unexecuted instantiation: AsmPrinter.cpp:_ZZN12_GLOBAL__N_113ModulePrinter29printDenseIntOrFPElementsAttrEN4mlir24DenseIntOrFPElementsAttrEbENK3$_1clIN4llvm14iterator_rangeINS1_17DenseElementsAttr27ComplexFloatElementIteratorEEEPFvRKNS5_7APFloatERNS5_11raw_ostreamEEJEEEDaT_T0_SE_DpOT1_ |
1559 | 0 |
|
1560 | 0 | Type complexElementType = complexTy.getElementType(); |
1561 | 0 | if (complexElementType.isa<IntegerType>()) |
1562 | 0 | printComplexValue(attr.getComplexIntValues(), printDenseIntElement, os, |
1563 | 0 | /*isSigned=*/!complexElementType.isUnsignedInteger()); |
1564 | 0 | else |
1565 | 0 | printComplexValue(attr.getComplexFloatValues(), printFloatValue, os); |
1566 | 0 | } else if (elementType.isIntOrIndex()) { |
1567 | 0 | bool isSigned = !elementType.isUnsignedInteger(); |
1568 | 0 | auto intValues = attr.getIntValues(); |
1569 | 0 | printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) { |
1570 | 0 | printDenseIntElement(*(intValues.begin() + index), os, isSigned); |
1571 | 0 | }); |
1572 | 0 | } else { |
1573 | 0 | assert(elementType.isa<FloatType>() && "unexpected element type"); |
1574 | 0 | auto floatValues = attr.getFloatValues(); |
1575 | 0 | printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) { |
1576 | 0 | printFloatValue(*(floatValues.begin() + index), os); |
1577 | 0 | }); |
1578 | 0 | } |
1579 | 0 | } |
1580 | | |
1581 | 0 | void ModulePrinter::printDenseStringElementsAttr(DenseStringElementsAttr attr) { |
1582 | 0 | ArrayRef<StringRef> data = attr.getRawStringData(); |
1583 | 0 | auto printFn = [&](unsigned index) { |
1584 | 0 | os << "\""; |
1585 | 0 | printEscapedString(data[index], os); |
1586 | 0 | os << "\""; |
1587 | 0 | }; |
1588 | 0 | printDenseElementsAttrImpl(attr.isSplat(), attr.getType(), os, printFn); |
1589 | 0 | } |
1590 | | |
1591 | 0 | void ModulePrinter::printType(Type type) { |
1592 | 0 | if (!type) { |
1593 | 0 | os << "<<NULL TYPE>>"; |
1594 | 0 | return; |
1595 | 0 | } |
1596 | 0 | |
1597 | 0 | // Check for an alias for this type. |
1598 | 0 | if (state) { |
1599 | 0 | StringRef alias = state->getAliasState().getTypeAlias(type); |
1600 | 0 | if (!alias.empty()) { |
1601 | 0 | os << '!' << alias; |
1602 | 0 | return; |
1603 | 0 | } |
1604 | 0 | } |
1605 | 0 | |
1606 | 0 | switch (type.getKind()) { |
1607 | 0 | default: |
1608 | 0 | return printDialectType(type); |
1609 | 0 |
|
1610 | 0 | case Type::Kind::Opaque: { |
1611 | 0 | auto opaqueTy = type.cast<OpaqueType>(); |
1612 | 0 | printDialectSymbol(os, "!", opaqueTy.getDialectNamespace(), |
1613 | 0 | opaqueTy.getTypeData()); |
1614 | 0 | return; |
1615 | 0 | } |
1616 | 0 | case StandardTypes::Index: |
1617 | 0 | os << "index"; |
1618 | 0 | return; |
1619 | 0 | case StandardTypes::BF16: |
1620 | 0 | os << "bf16"; |
1621 | 0 | return; |
1622 | 0 | case StandardTypes::F16: |
1623 | 0 | os << "f16"; |
1624 | 0 | return; |
1625 | 0 | case StandardTypes::F32: |
1626 | 0 | os << "f32"; |
1627 | 0 | return; |
1628 | 0 | case StandardTypes::F64: |
1629 | 0 | os << "f64"; |
1630 | 0 | return; |
1631 | 0 |
|
1632 | 0 | case StandardTypes::Integer: { |
1633 | 0 | auto integer = type.cast<IntegerType>(); |
1634 | 0 | if (integer.isSigned()) |
1635 | 0 | os << 's'; |
1636 | 0 | else if (integer.isUnsigned()) |
1637 | 0 | os << 'u'; |
1638 | 0 | os << 'i' << integer.getWidth(); |
1639 | 0 | return; |
1640 | 0 | } |
1641 | 0 | case Type::Kind::Function: { |
1642 | 0 | auto func = type.cast<FunctionType>(); |
1643 | 0 | os << '('; |
1644 | 0 | interleaveComma(func.getInputs(), [&](Type type) { printType(type); }); |
1645 | 0 | os << ") -> "; |
1646 | 0 | auto results = func.getResults(); |
1647 | 0 | if (results.size() == 1 && !results[0].isa<FunctionType>()) |
1648 | 0 | os << results[0]; |
1649 | 0 | else { |
1650 | 0 | os << '('; |
1651 | 0 | interleaveComma(results, [&](Type type) { printType(type); }); |
1652 | 0 | os << ')'; |
1653 | 0 | } |
1654 | 0 | return; |
1655 | 0 | } |
1656 | 0 | case StandardTypes::Vector: { |
1657 | 0 | auto v = type.cast<VectorType>(); |
1658 | 0 | os << "vector<"; |
1659 | 0 | for (auto dim : v.getShape()) |
1660 | 0 | os << dim << 'x'; |
1661 | 0 | os << v.getElementType() << '>'; |
1662 | 0 | return; |
1663 | 0 | } |
1664 | 0 | case StandardTypes::RankedTensor: { |
1665 | 0 | auto v = type.cast<RankedTensorType>(); |
1666 | 0 | os << "tensor<"; |
1667 | 0 | for (auto dim : v.getShape()) { |
1668 | 0 | if (dim < 0) |
1669 | 0 | os << '?'; |
1670 | 0 | else |
1671 | 0 | os << dim; |
1672 | 0 | os << 'x'; |
1673 | 0 | } |
1674 | 0 | os << v.getElementType() << '>'; |
1675 | 0 | return; |
1676 | 0 | } |
1677 | 0 | case StandardTypes::UnrankedTensor: { |
1678 | 0 | auto v = type.cast<UnrankedTensorType>(); |
1679 | 0 | os << "tensor<*x"; |
1680 | 0 | printType(v.getElementType()); |
1681 | 0 | os << '>'; |
1682 | 0 | return; |
1683 | 0 | } |
1684 | 0 | case StandardTypes::MemRef: { |
1685 | 0 | auto v = type.cast<MemRefType>(); |
1686 | 0 | os << "memref<"; |
1687 | 0 | for (auto dim : v.getShape()) { |
1688 | 0 | if (dim < 0) |
1689 | 0 | os << '?'; |
1690 | 0 | else |
1691 | 0 | os << dim; |
1692 | 0 | os << 'x'; |
1693 | 0 | } |
1694 | 0 | printType(v.getElementType()); |
1695 | 0 | for (auto map : v.getAffineMaps()) { |
1696 | 0 | os << ", "; |
1697 | 0 | printAttribute(AffineMapAttr::get(map)); |
1698 | 0 | } |
1699 | 0 | // Only print the memory space if it is the non-default one. |
1700 | 0 | if (v.getMemorySpace()) |
1701 | 0 | os << ", " << v.getMemorySpace(); |
1702 | 0 | os << '>'; |
1703 | 0 | return; |
1704 | 0 | } |
1705 | 0 | case StandardTypes::UnrankedMemRef: { |
1706 | 0 | auto v = type.cast<UnrankedMemRefType>(); |
1707 | 0 | os << "memref<*x"; |
1708 | 0 | printType(v.getElementType()); |
1709 | 0 | os << '>'; |
1710 | 0 | return; |
1711 | 0 | } |
1712 | 0 | case StandardTypes::Complex: |
1713 | 0 | os << "complex<"; |
1714 | 0 | printType(type.cast<ComplexType>().getElementType()); |
1715 | 0 | os << '>'; |
1716 | 0 | return; |
1717 | 0 | case StandardTypes::Tuple: { |
1718 | 0 | auto tuple = type.cast<TupleType>(); |
1719 | 0 | os << "tuple<"; |
1720 | 0 | interleaveComma(tuple.getTypes(), [&](Type type) { printType(type); }); |
1721 | 0 | os << '>'; |
1722 | 0 | return; |
1723 | 0 | } |
1724 | 0 | case StandardTypes::None: |
1725 | 0 | os << "none"; |
1726 | 0 | return; |
1727 | 0 | } |
1728 | 0 | } |
1729 | | |
1730 | | void ModulePrinter::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs, |
1731 | | ArrayRef<StringRef> elidedAttrs, |
1732 | 0 | bool withKeyword) { |
1733 | 0 | // If there are no attributes, then there is nothing to be done. |
1734 | 0 | if (attrs.empty()) |
1735 | 0 | return; |
1736 | 0 | |
1737 | 0 | // Filter out any attributes that shouldn't be included. |
1738 | 0 | SmallVector<NamedAttribute, 8> filteredAttrs( |
1739 | 0 | llvm::make_filter_range(attrs, [&](NamedAttribute attr) { |
1740 | 0 | return !llvm::is_contained(elidedAttrs, attr.first.strref()); |
1741 | 0 | })); |
1742 | 0 |
|
1743 | 0 | // If there are no attributes left to print after filtering, then we're done. |
1744 | 0 | if (filteredAttrs.empty()) |
1745 | 0 | return; |
1746 | 0 | |
1747 | 0 | // Print the 'attributes' keyword if necessary. |
1748 | 0 | if (withKeyword) |
1749 | 0 | os << " attributes"; |
1750 | 0 |
|
1751 | 0 | // Otherwise, print them all out in braces. |
1752 | 0 | os << " {"; |
1753 | 0 | interleaveComma(filteredAttrs, |
1754 | 0 | [&](NamedAttribute attr) { printNamedAttribute(attr); }); |
1755 | 0 | os << '}'; |
1756 | 0 | } |
1757 | | |
1758 | 0 | void ModulePrinter::printNamedAttribute(NamedAttribute attr) { |
1759 | 0 | if (isBareIdentifier(attr.first)) { |
1760 | 0 | os << attr.first; |
1761 | 0 | } else { |
1762 | 0 | os << '"'; |
1763 | 0 | printEscapedString(attr.first.strref(), os); |
1764 | 0 | os << '"'; |
1765 | 0 | } |
1766 | 0 |
|
1767 | 0 | // Pretty printing elides the attribute value for unit attributes. |
1768 | 0 | if (attr.second.isa<UnitAttr>()) |
1769 | 0 | return; |
1770 | 0 | |
1771 | 0 | os << " = "; |
1772 | 0 | printAttribute(attr.second); |
1773 | 0 | } |
1774 | | |
1775 | | //===----------------------------------------------------------------------===// |
1776 | | // CustomDialectAsmPrinter |
1777 | | //===----------------------------------------------------------------------===// |
1778 | | |
1779 | | namespace { |
1780 | | /// This class provides the main specialization of the DialectAsmPrinter that is |
1781 | | /// used to provide support for print attributes and types. This hooks allows |
1782 | | /// for dialects to hook into the main ModulePrinter. |
1783 | | struct CustomDialectAsmPrinter : public DialectAsmPrinter { |
1784 | | public: |
1785 | 0 | CustomDialectAsmPrinter(ModulePrinter &printer) : printer(printer) {} |
1786 | 0 | ~CustomDialectAsmPrinter() override {} |
1787 | | |
1788 | 0 | raw_ostream &getStream() const override { return printer.getStream(); } |
1789 | | |
1790 | | /// Print the given attribute to the stream. |
1791 | 0 | void printAttribute(Attribute attr) override { printer.printAttribute(attr); } |
1792 | | |
1793 | | /// Print the given floating point value in a stablized form. |
1794 | 0 | void printFloat(const APFloat &value) override { |
1795 | 0 | printFloatValue(value, getStream()); |
1796 | 0 | } |
1797 | | |
1798 | | /// Print the given type to the stream. |
1799 | 0 | void printType(Type type) override { printer.printType(type); } |
1800 | | |
1801 | | /// The main module printer. |
1802 | | ModulePrinter &printer; |
1803 | | }; |
1804 | | } // end anonymous namespace |
1805 | | |
1806 | 0 | void ModulePrinter::printDialectAttribute(Attribute attr) { |
1807 | 0 | auto &dialect = attr.getDialect(); |
1808 | 0 |
|
1809 | 0 | // Ask the dialect to serialize the attribute to a string. |
1810 | 0 | std::string attrName; |
1811 | 0 | { |
1812 | 0 | llvm::raw_string_ostream attrNameStr(attrName); |
1813 | 0 | ModulePrinter subPrinter(attrNameStr, printerFlags, state); |
1814 | 0 | CustomDialectAsmPrinter printer(subPrinter); |
1815 | 0 | dialect.printAttribute(attr, printer); |
1816 | 0 | } |
1817 | 0 | printDialectSymbol(os, "#", dialect.getNamespace(), attrName); |
1818 | 0 | } |
1819 | | |
1820 | 0 | void ModulePrinter::printDialectType(Type type) { |
1821 | 0 | auto &dialect = type.getDialect(); |
1822 | 0 |
|
1823 | 0 | // Ask the dialect to serialize the type to a string. |
1824 | 0 | std::string typeName; |
1825 | 0 | { |
1826 | 0 | llvm::raw_string_ostream typeNameStr(typeName); |
1827 | 0 | ModulePrinter subPrinter(typeNameStr, printerFlags, state); |
1828 | 0 | CustomDialectAsmPrinter printer(subPrinter); |
1829 | 0 | dialect.printType(type, printer); |
1830 | 0 | } |
1831 | 0 | printDialectSymbol(os, "!", dialect.getNamespace(), typeName); |
1832 | 0 | } |
1833 | | |
1834 | | //===----------------------------------------------------------------------===// |
1835 | | // Affine expressions and maps |
1836 | | //===----------------------------------------------------------------------===// |
1837 | | |
1838 | | void ModulePrinter::printAffineExpr( |
1839 | 0 | AffineExpr expr, function_ref<void(unsigned, bool)> printValueName) { |
1840 | 0 | printAffineExprInternal(expr, BindingStrength::Weak, printValueName); |
1841 | 0 | } |
1842 | | |
1843 | | void ModulePrinter::printAffineExprInternal( |
1844 | | AffineExpr expr, BindingStrength enclosingTightness, |
1845 | 0 | function_ref<void(unsigned, bool)> printValueName) { |
1846 | 0 | const char *binopSpelling = nullptr; |
1847 | 0 | switch (expr.getKind()) { |
1848 | 0 | case AffineExprKind::SymbolId: { |
1849 | 0 | unsigned pos = expr.cast<AffineSymbolExpr>().getPosition(); |
1850 | 0 | if (printValueName) |
1851 | 0 | printValueName(pos, /*isSymbol=*/true); |
1852 | 0 | else |
1853 | 0 | os << 's' << pos; |
1854 | 0 | return; |
1855 | 0 | } |
1856 | 0 | case AffineExprKind::DimId: { |
1857 | 0 | unsigned pos = expr.cast<AffineDimExpr>().getPosition(); |
1858 | 0 | if (printValueName) |
1859 | 0 | printValueName(pos, /*isSymbol=*/false); |
1860 | 0 | else |
1861 | 0 | os << 'd' << pos; |
1862 | 0 | return; |
1863 | 0 | } |
1864 | 0 | case AffineExprKind::Constant: |
1865 | 0 | os << expr.cast<AffineConstantExpr>().getValue(); |
1866 | 0 | return; |
1867 | 0 | case AffineExprKind::Add: |
1868 | 0 | binopSpelling = " + "; |
1869 | 0 | break; |
1870 | 0 | case AffineExprKind::Mul: |
1871 | 0 | binopSpelling = " * "; |
1872 | 0 | break; |
1873 | 0 | case AffineExprKind::FloorDiv: |
1874 | 0 | binopSpelling = " floordiv "; |
1875 | 0 | break; |
1876 | 0 | case AffineExprKind::CeilDiv: |
1877 | 0 | binopSpelling = " ceildiv "; |
1878 | 0 | break; |
1879 | 0 | case AffineExprKind::Mod: |
1880 | 0 | binopSpelling = " mod "; |
1881 | 0 | break; |
1882 | 0 | } |
1883 | 0 | |
1884 | 0 | auto binOp = expr.cast<AffineBinaryOpExpr>(); |
1885 | 0 | AffineExpr lhsExpr = binOp.getLHS(); |
1886 | 0 | AffineExpr rhsExpr = binOp.getRHS(); |
1887 | 0 |
|
1888 | 0 | // Handle tightly binding binary operators. |
1889 | 0 | if (binOp.getKind() != AffineExprKind::Add) { |
1890 | 0 | if (enclosingTightness == BindingStrength::Strong) |
1891 | 0 | os << '('; |
1892 | 0 |
|
1893 | 0 | // Pretty print multiplication with -1. |
1894 | 0 | auto rhsConst = rhsExpr.dyn_cast<AffineConstantExpr>(); |
1895 | 0 | if (rhsConst && binOp.getKind() == AffineExprKind::Mul && |
1896 | 0 | rhsConst.getValue() == -1) { |
1897 | 0 | os << "-"; |
1898 | 0 | printAffineExprInternal(lhsExpr, BindingStrength::Strong, printValueName); |
1899 | 0 | if (enclosingTightness == BindingStrength::Strong) |
1900 | 0 | os << ')'; |
1901 | 0 | return; |
1902 | 0 | } |
1903 | 0 |
|
1904 | 0 | printAffineExprInternal(lhsExpr, BindingStrength::Strong, printValueName); |
1905 | 0 |
|
1906 | 0 | os << binopSpelling; |
1907 | 0 | printAffineExprInternal(rhsExpr, BindingStrength::Strong, printValueName); |
1908 | 0 |
|
1909 | 0 | if (enclosingTightness == BindingStrength::Strong) |
1910 | 0 | os << ')'; |
1911 | 0 | return; |
1912 | 0 | } |
1913 | 0 |
|
1914 | 0 | // Print out special "pretty" forms for add. |
1915 | 0 | if (enclosingTightness == BindingStrength::Strong) |
1916 | 0 | os << '('; |
1917 | 0 |
|
1918 | 0 | // Pretty print addition to a product that has a negative operand as a |
1919 | 0 | // subtraction. |
1920 | 0 | if (auto rhs = rhsExpr.dyn_cast<AffineBinaryOpExpr>()) { |
1921 | 0 | if (rhs.getKind() == AffineExprKind::Mul) { |
1922 | 0 | AffineExpr rrhsExpr = rhs.getRHS(); |
1923 | 0 | if (auto rrhs = rrhsExpr.dyn_cast<AffineConstantExpr>()) { |
1924 | 0 | if (rrhs.getValue() == -1) { |
1925 | 0 | printAffineExprInternal(lhsExpr, BindingStrength::Weak, |
1926 | 0 | printValueName); |
1927 | 0 | os << " - "; |
1928 | 0 | if (rhs.getLHS().getKind() == AffineExprKind::Add) { |
1929 | 0 | printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong, |
1930 | 0 | printValueName); |
1931 | 0 | } else { |
1932 | 0 | printAffineExprInternal(rhs.getLHS(), BindingStrength::Weak, |
1933 | 0 | printValueName); |
1934 | 0 | } |
1935 | 0 |
|
1936 | 0 | if (enclosingTightness == BindingStrength::Strong) |
1937 | 0 | os << ')'; |
1938 | 0 | return; |
1939 | 0 | } |
1940 | 0 |
|
1941 | 0 | if (rrhs.getValue() < -1) { |
1942 | 0 | printAffineExprInternal(lhsExpr, BindingStrength::Weak, |
1943 | 0 | printValueName); |
1944 | 0 | os << " - "; |
1945 | 0 | printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong, |
1946 | 0 | printValueName); |
1947 | 0 | os << " * " << -rrhs.getValue(); |
1948 | 0 | if (enclosingTightness == BindingStrength::Strong) |
1949 | 0 | os << ')'; |
1950 | 0 | return; |
1951 | 0 | } |
1952 | 0 | } |
1953 | 0 | } |
1954 | 0 | } |
1955 | 0 |
|
1956 | 0 | // Pretty print addition to a negative number as a subtraction. |
1957 | 0 | if (auto rhsConst = rhsExpr.dyn_cast<AffineConstantExpr>()) { |
1958 | 0 | if (rhsConst.getValue() < 0) { |
1959 | 0 | printAffineExprInternal(lhsExpr, BindingStrength::Weak, printValueName); |
1960 | 0 | os << " - " << -rhsConst.getValue(); |
1961 | 0 | if (enclosingTightness == BindingStrength::Strong) |
1962 | 0 | os << ')'; |
1963 | 0 | return; |
1964 | 0 | } |
1965 | 0 | } |
1966 | 0 |
|
1967 | 0 | printAffineExprInternal(lhsExpr, BindingStrength::Weak, printValueName); |
1968 | 0 |
|
1969 | 0 | os << " + "; |
1970 | 0 | printAffineExprInternal(rhsExpr, BindingStrength::Weak, printValueName); |
1971 | 0 |
|
1972 | 0 | if (enclosingTightness == BindingStrength::Strong) |
1973 | 0 | os << ')'; |
1974 | 0 | } |
1975 | | |
1976 | 0 | void ModulePrinter::printAffineConstraint(AffineExpr expr, bool isEq) { |
1977 | 0 | printAffineExprInternal(expr, BindingStrength::Weak); |
1978 | 0 | isEq ? os << " == 0" : os << " >= 0"; |
1979 | 0 | } |
1980 | | |
1981 | 0 | void ModulePrinter::printAffineMap(AffineMap map) { |
1982 | 0 | // Dimension identifiers. |
1983 | 0 | os << '('; |
1984 | 0 | for (int i = 0; i < (int)map.getNumDims() - 1; ++i) |
1985 | 0 | os << 'd' << i << ", "; |
1986 | 0 | if (map.getNumDims() >= 1) |
1987 | 0 | os << 'd' << map.getNumDims() - 1; |
1988 | 0 | os << ')'; |
1989 | 0 |
|
1990 | 0 | // Symbolic identifiers. |
1991 | 0 | if (map.getNumSymbols() != 0) { |
1992 | 0 | os << '['; |
1993 | 0 | for (unsigned i = 0; i < map.getNumSymbols() - 1; ++i) |
1994 | 0 | os << 's' << i << ", "; |
1995 | 0 | if (map.getNumSymbols() >= 1) |
1996 | 0 | os << 's' << map.getNumSymbols() - 1; |
1997 | 0 | os << ']'; |
1998 | 0 | } |
1999 | 0 |
|
2000 | 0 | // Result affine expressions. |
2001 | 0 | os << " -> ("; |
2002 | 0 | interleaveComma(map.getResults(), |
2003 | 0 | [&](AffineExpr expr) { printAffineExpr(expr); }); |
2004 | 0 | os << ')'; |
2005 | 0 | } |
2006 | | |
2007 | 0 | void ModulePrinter::printIntegerSet(IntegerSet set) { |
2008 | 0 | // Dimension identifiers. |
2009 | 0 | os << '('; |
2010 | 0 | for (unsigned i = 1; i < set.getNumDims(); ++i) |
2011 | 0 | os << 'd' << i - 1 << ", "; |
2012 | 0 | if (set.getNumDims() >= 1) |
2013 | 0 | os << 'd' << set.getNumDims() - 1; |
2014 | 0 | os << ')'; |
2015 | 0 |
|
2016 | 0 | // Symbolic identifiers. |
2017 | 0 | if (set.getNumSymbols() != 0) { |
2018 | 0 | os << '['; |
2019 | 0 | for (unsigned i = 0; i < set.getNumSymbols() - 1; ++i) |
2020 | 0 | os << 's' << i << ", "; |
2021 | 0 | if (set.getNumSymbols() >= 1) |
2022 | 0 | os << 's' << set.getNumSymbols() - 1; |
2023 | 0 | os << ']'; |
2024 | 0 | } |
2025 | 0 |
|
2026 | 0 | // Print constraints. |
2027 | 0 | os << " : ("; |
2028 | 0 | int numConstraints = set.getNumConstraints(); |
2029 | 0 | for (int i = 1; i < numConstraints; ++i) { |
2030 | 0 | printAffineConstraint(set.getConstraint(i - 1), set.isEq(i - 1)); |
2031 | 0 | os << ", "; |
2032 | 0 | } |
2033 | 0 | if (numConstraints >= 1) |
2034 | 0 | printAffineConstraint(set.getConstraint(numConstraints - 1), |
2035 | 0 | set.isEq(numConstraints - 1)); |
2036 | 0 | os << ')'; |
2037 | 0 | } |
2038 | | |
2039 | | //===----------------------------------------------------------------------===// |
2040 | | // OperationPrinter |
2041 | | //===----------------------------------------------------------------------===// |
2042 | | |
2043 | | namespace { |
2044 | | /// This class contains the logic for printing operations, regions, and blocks. |
2045 | | class OperationPrinter : public ModulePrinter, private OpAsmPrinter { |
2046 | | public: |
2047 | | explicit OperationPrinter(raw_ostream &os, OpPrintingFlags flags, |
2048 | | AsmStateImpl &state) |
2049 | 0 | : ModulePrinter(os, flags, &state) {} |
2050 | | |
2051 | | /// Print the given top-level module. |
2052 | | void print(ModuleOp op); |
2053 | | /// Print the given operation with its indent and location. |
2054 | | void print(Operation *op); |
2055 | | /// Print the bare location, not including indentation/location/etc. |
2056 | | void printOperation(Operation *op); |
2057 | | /// Print the given operation in the generic form. |
2058 | | void printGenericOp(Operation *op) override; |
2059 | | |
2060 | | /// Print the name of the given block. |
2061 | | void printBlockName(Block *block); |
2062 | | |
2063 | | /// Print the given block. If 'printBlockArgs' is false, the arguments of the |
2064 | | /// block are not printed. If 'printBlockTerminator' is false, the terminator |
2065 | | /// operation of the block is not printed. |
2066 | | void print(Block *block, bool printBlockArgs = true, |
2067 | | bool printBlockTerminator = true); |
2068 | | |
2069 | | /// Print the ID of the given value, optionally with its result number. |
2070 | | void printValueID(Value value, bool printResultNo = true, |
2071 | | raw_ostream *streamOverride = nullptr) const; |
2072 | | |
2073 | | //===--------------------------------------------------------------------===// |
2074 | | // OpAsmPrinter methods |
2075 | | //===--------------------------------------------------------------------===// |
2076 | | |
2077 | | /// Return the current stream of the printer. |
2078 | 0 | raw_ostream &getStream() const override { return os; } |
2079 | | |
2080 | | /// Print the given type. |
2081 | 0 | void printType(Type type) override { ModulePrinter::printType(type); } |
2082 | | |
2083 | | /// Print the given attribute. |
2084 | 0 | void printAttribute(Attribute attr) override { |
2085 | 0 | ModulePrinter::printAttribute(attr); |
2086 | 0 | } |
2087 | | |
2088 | | /// Print the given attribute without its type. The corresponding parser must |
2089 | | /// provide a valid type for the attribute. |
2090 | 0 | void printAttributeWithoutType(Attribute attr) override { |
2091 | 0 | ModulePrinter::printAttribute(attr, AttrTypeElision::Must); |
2092 | 0 | } |
2093 | | |
2094 | | /// Print the ID for the given value. |
2095 | 0 | void printOperand(Value value) override { printValueID(value); } |
2096 | 0 | void printOperand(Value value, raw_ostream &os) override { |
2097 | 0 | printValueID(value, /*printResultNo=*/true, &os); |
2098 | 0 | } |
2099 | | |
2100 | | /// Print an optional attribute dictionary with a given set of elided values. |
2101 | | void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs, |
2102 | 0 | ArrayRef<StringRef> elidedAttrs = {}) override { |
2103 | 0 | ModulePrinter::printOptionalAttrDict(attrs, elidedAttrs); |
2104 | 0 | } |
2105 | | void printOptionalAttrDictWithKeyword( |
2106 | | ArrayRef<NamedAttribute> attrs, |
2107 | 0 | ArrayRef<StringRef> elidedAttrs = {}) override { |
2108 | 0 | ModulePrinter::printOptionalAttrDict(attrs, elidedAttrs, |
2109 | 0 | /*withKeyword=*/true); |
2110 | 0 | } |
2111 | | |
2112 | | /// Print the given successor. |
2113 | | void printSuccessor(Block *successor) override; |
2114 | | |
2115 | | /// Print an operation successor with the operands used for the block |
2116 | | /// arguments. |
2117 | | void printSuccessorAndUseList(Block *successor, |
2118 | | ValueRange succOperands) override; |
2119 | | |
2120 | | /// Print the given region. |
2121 | | void printRegion(Region ®ion, bool printEntryBlockArgs, |
2122 | | bool printBlockTerminators) override; |
2123 | | |
2124 | | /// Renumber the arguments for the specified region to the same names as the |
2125 | | /// SSA values in namesToUse. This may only be used for IsolatedFromAbove |
2126 | | /// operations. If any entry in namesToUse is null, the corresponding |
2127 | | /// argument name is left alone. |
2128 | 0 | void shadowRegionArgs(Region ®ion, ValueRange namesToUse) override { |
2129 | 0 | state->getSSANameState().shadowRegionArgs(region, namesToUse); |
2130 | 0 | } |
2131 | | |
2132 | | /// Print the given affine map with the symbol and dimension operands printed |
2133 | | /// inline with the map. |
2134 | | void printAffineMapOfSSAIds(AffineMapAttr mapAttr, |
2135 | | ValueRange operands) override; |
2136 | | |
2137 | | /// Print the given string as a symbol reference. |
2138 | 0 | void printSymbolName(StringRef symbolRef) override { |
2139 | 0 | ::printSymbolReference(symbolRef, os); |
2140 | 0 | } |
2141 | | |
2142 | | private: |
2143 | | /// The number of spaces used for indenting nested operations. |
2144 | | const static unsigned indentWidth = 2; |
2145 | | |
2146 | | // This is the current indentation level for nested structures. |
2147 | | unsigned currentIndent = 0; |
2148 | | }; |
2149 | | } // end anonymous namespace |
2150 | | |
2151 | 0 | void OperationPrinter::print(ModuleOp op) { |
2152 | 0 | // Output the aliases at the top level. |
2153 | 0 | state->getAliasState().printAttributeAliases(os, newLine); |
2154 | 0 | state->getAliasState().printTypeAliases(os, newLine); |
2155 | 0 |
|
2156 | 0 | // Print the module. |
2157 | 0 | print(op.getOperation()); |
2158 | 0 | } |
2159 | | |
2160 | 0 | void OperationPrinter::print(Operation *op) { |
2161 | 0 | // Track the location of this operation. |
2162 | 0 | state->registerOperationLocation(op, newLine.curLine, currentIndent); |
2163 | 0 |
|
2164 | 0 | os.indent(currentIndent); |
2165 | 0 | printOperation(op); |
2166 | 0 | printTrailingLocation(op->getLoc()); |
2167 | 0 | } |
2168 | | |
2169 | 0 | void OperationPrinter::printOperation(Operation *op) { |
2170 | 0 | if (size_t numResults = op->getNumResults()) { |
2171 | 0 | auto printResultGroup = [&](size_t resultNo, size_t resultCount) { |
2172 | 0 | printValueID(op->getResult(resultNo), /*printResultNo=*/false); |
2173 | 0 | if (resultCount > 1) |
2174 | 0 | os << ':' << resultCount; |
2175 | 0 | }; |
2176 | 0 |
|
2177 | 0 | // Check to see if this operation has multiple result groups. |
2178 | 0 | ArrayRef<int> resultGroups = state->getSSANameState().getOpResultGroups(op); |
2179 | 0 | if (!resultGroups.empty()) { |
2180 | 0 | // Interleave the groups excluding the last one, this one will be handled |
2181 | 0 | // separately. |
2182 | 0 | interleaveComma(llvm::seq<int>(0, resultGroups.size() - 1), [&](int i) { |
2183 | 0 | printResultGroup(resultGroups[i], |
2184 | 0 | resultGroups[i + 1] - resultGroups[i]); |
2185 | 0 | }); |
2186 | 0 | os << ", "; |
2187 | 0 | printResultGroup(resultGroups.back(), numResults - resultGroups.back()); |
2188 | 0 |
|
2189 | 0 | } else { |
2190 | 0 | printResultGroup(/*resultNo=*/0, /*resultCount=*/numResults); |
2191 | 0 | } |
2192 | 0 |
|
2193 | 0 | os << " = "; |
2194 | 0 | } |
2195 | 0 |
|
2196 | 0 | // If requested, always print the generic form. |
2197 | 0 | if (!printerFlags.shouldPrintGenericOpForm()) { |
2198 | 0 | // Check to see if this is a known operation. If so, use the registered |
2199 | 0 | // custom printer hook. |
2200 | 0 | if (auto *opInfo = op->getAbstractOperation()) { |
2201 | 0 | opInfo->printAssembly(op, *this); |
2202 | 0 | return; |
2203 | 0 | } |
2204 | 0 | } |
2205 | 0 | |
2206 | 0 | // Otherwise print with the generic assembly form. |
2207 | 0 | printGenericOp(op); |
2208 | 0 | } |
2209 | | |
2210 | 0 | void OperationPrinter::printGenericOp(Operation *op) { |
2211 | 0 | os << '"'; |
2212 | 0 | printEscapedString(op->getName().getStringRef(), os); |
2213 | 0 | os << "\"("; |
2214 | 0 | interleaveComma(op->getOperands(), [&](Value value) { printValueID(value); }); |
2215 | 0 | os << ')'; |
2216 | 0 |
|
2217 | 0 | // For terminators, print the list of successors and their operands. |
2218 | 0 | if (op->getNumSuccessors() != 0) { |
2219 | 0 | os << '['; |
2220 | 0 | interleaveComma(op->getSuccessors(), |
2221 | 0 | [&](Block *successor) { printBlockName(successor); }); |
2222 | 0 | os << ']'; |
2223 | 0 | } |
2224 | 0 |
|
2225 | 0 | // Print regions. |
2226 | 0 | if (op->getNumRegions() != 0) { |
2227 | 0 | os << " ("; |
2228 | 0 | interleaveComma(op->getRegions(), [&](Region ®ion) { |
2229 | 0 | printRegion(region, /*printEntryBlockArgs=*/true, |
2230 | 0 | /*printBlockTerminators=*/true); |
2231 | 0 | }); |
2232 | 0 | os << ')'; |
2233 | 0 | } |
2234 | 0 |
|
2235 | 0 | auto attrs = op->getAttrs(); |
2236 | 0 | printOptionalAttrDict(attrs); |
2237 | 0 |
|
2238 | 0 | // Print the type signature of the operation. |
2239 | 0 | os << " : "; |
2240 | 0 | printFunctionalType(op); |
2241 | 0 | } |
2242 | | |
2243 | 0 | void OperationPrinter::printBlockName(Block *block) { |
2244 | 0 | auto id = state->getSSANameState().getBlockID(block); |
2245 | 0 | if (id != SSANameState::NameSentinel) |
2246 | 0 | os << "^bb" << id; |
2247 | 0 | else |
2248 | 0 | os << "^INVALIDBLOCK"; |
2249 | 0 | } |
2250 | | |
2251 | | void OperationPrinter::print(Block *block, bool printBlockArgs, |
2252 | 0 | bool printBlockTerminator) { |
2253 | 0 | // Print the block label and argument list if requested. |
2254 | 0 | if (printBlockArgs) { |
2255 | 0 | os.indent(currentIndent); |
2256 | 0 | printBlockName(block); |
2257 | 0 |
|
2258 | 0 | // Print the argument list if non-empty. |
2259 | 0 | if (!block->args_empty()) { |
2260 | 0 | os << '('; |
2261 | 0 | interleaveComma(block->getArguments(), [&](BlockArgument arg) { |
2262 | 0 | printValueID(arg); |
2263 | 0 | os << ": "; |
2264 | 0 | printType(arg.getType()); |
2265 | 0 | }); |
2266 | 0 | os << ')'; |
2267 | 0 | } |
2268 | 0 | os << ':'; |
2269 | 0 |
|
2270 | 0 | // Print out some context information about the predecessors of this block. |
2271 | 0 | if (!block->getParent()) { |
2272 | 0 | os << " // block is not in a region!"; |
2273 | 0 | } else if (block->hasNoPredecessors()) { |
2274 | 0 | os << " // no predecessors"; |
2275 | 0 | } else if (auto *pred = block->getSinglePredecessor()) { |
2276 | 0 | os << " // pred: "; |
2277 | 0 | printBlockName(pred); |
2278 | 0 | } else { |
2279 | 0 | // We want to print the predecessors in increasing numeric order, not in |
2280 | 0 | // whatever order the use-list is in, so gather and sort them. |
2281 | 0 | SmallVector<std::pair<unsigned, Block *>, 4> predIDs; |
2282 | 0 | for (auto *pred : block->getPredecessors()) |
2283 | 0 | predIDs.push_back({state->getSSANameState().getBlockID(pred), pred}); |
2284 | 0 | llvm::array_pod_sort(predIDs.begin(), predIDs.end()); |
2285 | 0 |
|
2286 | 0 | os << " // " << predIDs.size() << " preds: "; |
2287 | 0 |
|
2288 | 0 | interleaveComma(predIDs, [&](std::pair<unsigned, Block *> pred) { |
2289 | 0 | printBlockName(pred.second); |
2290 | 0 | }); |
2291 | 0 | } |
2292 | 0 | os << newLine; |
2293 | 0 | } |
2294 | 0 |
|
2295 | 0 | currentIndent += indentWidth; |
2296 | 0 | auto range = llvm::make_range( |
2297 | 0 | block->getOperations().begin(), |
2298 | 0 | std::prev(block->getOperations().end(), printBlockTerminator ? 0 : 1)); |
2299 | 0 | for (auto &op : range) { |
2300 | 0 | print(&op); |
2301 | 0 | os << newLine; |
2302 | 0 | } |
2303 | 0 | currentIndent -= indentWidth; |
2304 | 0 | } |
2305 | | |
2306 | | void OperationPrinter::printValueID(Value value, bool printResultNo, |
2307 | 0 | raw_ostream *streamOverride) const { |
2308 | 0 | state->getSSANameState().printValueID(value, printResultNo, |
2309 | 0 | streamOverride ? *streamOverride : os); |
2310 | 0 | } |
2311 | | |
2312 | 0 | void OperationPrinter::printSuccessor(Block *successor) { |
2313 | 0 | printBlockName(successor); |
2314 | 0 | } |
2315 | | |
2316 | | void OperationPrinter::printSuccessorAndUseList(Block *successor, |
2317 | 0 | ValueRange succOperands) { |
2318 | 0 | printBlockName(successor); |
2319 | 0 | if (succOperands.empty()) |
2320 | 0 | return; |
2321 | 0 | |
2322 | 0 | os << '('; |
2323 | 0 | interleaveComma(succOperands, |
2324 | 0 | [this](Value operand) { printValueID(operand); }); |
2325 | 0 | os << " : "; |
2326 | 0 | interleaveComma(succOperands, |
2327 | 0 | [this](Value operand) { printType(operand.getType()); }); |
2328 | 0 | os << ')'; |
2329 | 0 | } |
2330 | | |
2331 | | void OperationPrinter::printRegion(Region ®ion, bool printEntryBlockArgs, |
2332 | 0 | bool printBlockTerminators) { |
2333 | 0 | os << " {" << newLine; |
2334 | 0 | if (!region.empty()) { |
2335 | 0 | auto *entryBlock = ®ion.front(); |
2336 | 0 | print(entryBlock, printEntryBlockArgs && entryBlock->getNumArguments() != 0, |
2337 | 0 | printBlockTerminators); |
2338 | 0 | for (auto &b : llvm::drop_begin(region.getBlocks(), 1)) |
2339 | 0 | print(&b); |
2340 | 0 | } |
2341 | 0 | os.indent(currentIndent) << "}"; |
2342 | 0 | } |
2343 | | |
2344 | | void OperationPrinter::printAffineMapOfSSAIds(AffineMapAttr mapAttr, |
2345 | 0 | ValueRange operands) { |
2346 | 0 | AffineMap map = mapAttr.getValue(); |
2347 | 0 | unsigned numDims = map.getNumDims(); |
2348 | 0 | auto printValueName = [&](unsigned pos, bool isSymbol) { |
2349 | 0 | unsigned index = isSymbol ? numDims + pos : pos; |
2350 | 0 | assert(index < operands.size()); |
2351 | 0 | if (isSymbol) |
2352 | 0 | os << "symbol("; |
2353 | 0 | printValueID(operands[index]); |
2354 | 0 | if (isSymbol) |
2355 | 0 | os << ')'; |
2356 | 0 | }; |
2357 | 0 |
|
2358 | 0 | interleaveComma(map.getResults(), [&](AffineExpr expr) { |
2359 | 0 | printAffineExpr(expr, printValueName); |
2360 | 0 | }); |
2361 | 0 | } |
2362 | | |
2363 | | //===----------------------------------------------------------------------===// |
2364 | | // print and dump methods |
2365 | | //===----------------------------------------------------------------------===// |
2366 | | |
2367 | 0 | void Attribute::print(raw_ostream &os) const { |
2368 | 0 | ModulePrinter(os).printAttribute(*this); |
2369 | 0 | } |
2370 | | |
2371 | 0 | void Attribute::dump() const { |
2372 | 0 | print(llvm::errs()); |
2373 | 0 | llvm::errs() << "\n"; |
2374 | 0 | } |
2375 | | |
2376 | 0 | void Type::print(raw_ostream &os) { ModulePrinter(os).printType(*this); } |
2377 | | |
2378 | 0 | void Type::dump() { print(llvm::errs()); } |
2379 | | |
2380 | 0 | void AffineMap::dump() const { |
2381 | 0 | print(llvm::errs()); |
2382 | 0 | llvm::errs() << "\n"; |
2383 | 0 | } |
2384 | | |
2385 | 0 | void IntegerSet::dump() const { |
2386 | 0 | print(llvm::errs()); |
2387 | 0 | llvm::errs() << "\n"; |
2388 | 0 | } |
2389 | | |
2390 | 0 | void AffineExpr::print(raw_ostream &os) const { |
2391 | 0 | if (!expr) { |
2392 | 0 | os << "<<NULL AFFINE EXPR>>"; |
2393 | 0 | return; |
2394 | 0 | } |
2395 | 0 | ModulePrinter(os).printAffineExpr(*this); |
2396 | 0 | } |
2397 | | |
2398 | 0 | void AffineExpr::dump() const { |
2399 | 0 | print(llvm::errs()); |
2400 | 0 | llvm::errs() << "\n"; |
2401 | 0 | } |
2402 | | |
2403 | 0 | void AffineMap::print(raw_ostream &os) const { |
2404 | 0 | if (!map) { |
2405 | 0 | os << "<<NULL AFFINE MAP>>"; |
2406 | 0 | return; |
2407 | 0 | } |
2408 | 0 | ModulePrinter(os).printAffineMap(*this); |
2409 | 0 | } |
2410 | | |
2411 | 0 | void IntegerSet::print(raw_ostream &os) const { |
2412 | 0 | ModulePrinter(os).printIntegerSet(*this); |
2413 | 0 | } |
2414 | | |
2415 | 0 | void Value::print(raw_ostream &os) { |
2416 | 0 | if (auto *op = getDefiningOp()) |
2417 | 0 | return op->print(os); |
2418 | 0 | // TODO: Improve this. |
2419 | 0 | assert(isa<BlockArgument>()); |
2420 | 0 | os << "<block argument>\n"; |
2421 | 0 | } |
2422 | 0 | void Value::print(raw_ostream &os, AsmState &state) { |
2423 | 0 | if (auto *op = getDefiningOp()) |
2424 | 0 | return op->print(os, state); |
2425 | 0 | |
2426 | 0 | // TODO: Improve this. |
2427 | 0 | assert(isa<BlockArgument>()); |
2428 | 0 | os << "<block argument>\n"; |
2429 | 0 | } |
2430 | | |
2431 | 0 | void Value::dump() { |
2432 | 0 | print(llvm::errs()); |
2433 | 0 | llvm::errs() << "\n"; |
2434 | 0 | } |
2435 | | |
2436 | 0 | void Value::printAsOperand(raw_ostream &os, AsmState &state) { |
2437 | 0 | // TODO(riverriddle) This doesn't necessarily capture all potential cases. |
2438 | 0 | // Currently, region arguments can be shadowed when printing the main |
2439 | 0 | // operation. If the IR hasn't been printed, this will produce the old SSA |
2440 | 0 | // name and not the shadowed name. |
2441 | 0 | state.getImpl().getSSANameState().printValueID(*this, /*printResultNo=*/true, |
2442 | 0 | os); |
2443 | 0 | } |
2444 | | |
2445 | 0 | void Operation::print(raw_ostream &os, OpPrintingFlags flags) { |
2446 | 0 | // Find the operation to number from based upon the provided flags. |
2447 | 0 | Operation *printedOp = this; |
2448 | 0 | bool shouldUseLocalScope = flags.shouldUseLocalScope(); |
2449 | 0 | do { |
2450 | 0 | // If we are printing local scope, stop at the first operation that is |
2451 | 0 | // isolated from above. |
2452 | 0 | if (shouldUseLocalScope && printedOp->isKnownIsolatedFromAbove()) |
2453 | 0 | break; |
2454 | 0 | |
2455 | 0 | // Otherwise, traverse up to the next parent. |
2456 | 0 | Operation *parentOp = printedOp->getParentOp(); |
2457 | 0 | if (!parentOp) |
2458 | 0 | break; |
2459 | 0 | printedOp = parentOp; |
2460 | 0 | } while (true); |
2461 | 0 |
|
2462 | 0 | AsmState state(printedOp); |
2463 | 0 | print(os, state, flags); |
2464 | 0 | } |
2465 | 0 | void Operation::print(raw_ostream &os, AsmState &state, OpPrintingFlags flags) { |
2466 | 0 | OperationPrinter(os, flags, state.getImpl()).print(this); |
2467 | 0 | } |
2468 | | |
2469 | 0 | void Operation::dump() { |
2470 | 0 | print(llvm::errs(), OpPrintingFlags().useLocalScope()); |
2471 | 0 | llvm::errs() << "\n"; |
2472 | 0 | } |
2473 | | |
2474 | 0 | void Block::print(raw_ostream &os) { |
2475 | 0 | Operation *parentOp = getParentOp(); |
2476 | 0 | if (!parentOp) { |
2477 | 0 | os << "<<UNLINKED BLOCK>>\n"; |
2478 | 0 | return; |
2479 | 0 | } |
2480 | 0 | // Get the top-level op. |
2481 | 0 | while (auto *nextOp = parentOp->getParentOp()) |
2482 | 0 | parentOp = nextOp; |
2483 | 0 |
|
2484 | 0 | AsmState state(parentOp); |
2485 | 0 | print(os, state); |
2486 | 0 | } |
2487 | 0 | void Block::print(raw_ostream &os, AsmState &state) { |
2488 | 0 | OperationPrinter(os, /*flags=*/llvm::None, state.getImpl()).print(this); |
2489 | 0 | } |
2490 | | |
2491 | 0 | void Block::dump() { print(llvm::errs()); } |
2492 | | |
2493 | | /// Print out the name of the block without printing its body. |
2494 | 0 | void Block::printAsOperand(raw_ostream &os, bool printType) { |
2495 | 0 | Operation *parentOp = getParentOp(); |
2496 | 0 | if (!parentOp) { |
2497 | 0 | os << "<<UNLINKED BLOCK>>\n"; |
2498 | 0 | return; |
2499 | 0 | } |
2500 | 0 | AsmState state(parentOp); |
2501 | 0 | printAsOperand(os, state); |
2502 | 0 | } |
2503 | 0 | void Block::printAsOperand(raw_ostream &os, AsmState &state) { |
2504 | 0 | OperationPrinter printer(os, /*flags=*/llvm::None, state.getImpl()); |
2505 | 0 | printer.printBlockName(this); |
2506 | 0 | } |
2507 | | |
2508 | 0 | void ModuleOp::print(raw_ostream &os, OpPrintingFlags flags) { |
2509 | 0 | AsmState state(*this); |
2510 | 0 |
|
2511 | 0 | // Don't populate aliases when printing at local scope. |
2512 | 0 | if (!flags.shouldUseLocalScope()) |
2513 | 0 | state.getImpl().initializeAliases(*this); |
2514 | 0 | print(os, state, flags); |
2515 | 0 | } |
2516 | 0 | void ModuleOp::print(raw_ostream &os, AsmState &state, OpPrintingFlags flags) { |
2517 | 0 | OperationPrinter(os, flags, state.getImpl()).print(*this); |
2518 | 0 | } |
2519 | | |
2520 | 0 | void ModuleOp::dump() { print(llvm::errs()); } |