diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h index d70aa346eaa1f..a6d39c4d17a8d 100644 --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -17,6 +17,8 @@ #include "mlir/IR/DialectInterface.h" #include "mlir/IR/OpAsmSupport.h" #include "mlir/IR/OpDefinition.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/TypeID.h" #include "llvm/ADT/Twine.h" #include "llvm/Support/SMLoc.h" #include @@ -169,6 +171,9 @@ class AsmPrinter { if (succeeded(printAlias(attrOrType))) return; + if (succeeded(printDialectAlias(attrOrType, /*printStripped=*/true))) + return; + raw_ostream &os = getStream(); uint64_t posPrior = os.tell(); attrOrType.print(*this); @@ -218,6 +223,14 @@ class AsmPrinter { /// be printed. virtual LogicalResult printAlias(Type type); + /// Print the alias for the given attribute, return failure if no alias could + /// be printed. + virtual LogicalResult printDialectAlias(Attribute attr, bool printStripped); + + /// Print the alias for the given type, return failure if no alias could + /// be printed. + virtual LogicalResult printDialectAlias(Type type, bool printStripped); + /// Print the given string as a keyword, or a quoted and escaped string if it /// has any special or non-printable characters in it. virtual void printKeywordOrString(StringRef keyword); @@ -1799,6 +1812,30 @@ class OpAsmDialectInterface return AliasResult::NoAlias; } + /// Hooks for registering alias printers for types and attributes. These + /// printers are invoked when printing types or attributes of the given + /// TypeID. Printers are invoked in the order they are registered, and the + /// first one to print an alias is used. + /// The precedence of these printers is as follow: + /// 1. The type and attribute aliases returned by `getAlias`. + /// 2. Dialect-specific alias printers registered here. + /// 3. The type and attribute printers. + /// The boolean argument to the printer indicates whether the stripped form + /// of the type or attribute is being printed. + /// NOTE: This mechanism caches the printed object, therefore the printer + /// must always produce the same output for the same input. + using AttributeAliasPrinter = + llvm::function_ref; + using InsertAttrAliasPrinter = + llvm::function_ref; + virtual void registerAttrAliasPrinter(InsertAttrAliasPrinter insertFn) const { + } + using TypeAliasPrinter = llvm::function_ref; + using InsertTypeAliasPrinter = + llvm::function_ref; + virtual void registerTypeAliasPrinter(InsertTypeAliasPrinter insertFn) const { + } + //===--------------------------------------------------------------------===// // Resources //===--------------------------------------------------------------------===// diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 7d991cea6c468..86fa754b3bb54 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -32,13 +32,16 @@ #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/MapVector.h" +#include "llvm/ADT/PointerIntPair.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/ScopedHashTable.h" #include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringSet.h" #include "llvm/ADT/TypeSwitch.h" +#include "llvm/ADT/iterator_range.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/DebugLog.h" #include "llvm/Support/Endian.h" @@ -47,6 +50,7 @@ #include "llvm/Support/SaveAndRestore.h" #include "llvm/Support/Threading.h" #include "llvm/Support/raw_ostream.h" +#include #include #include @@ -414,6 +418,7 @@ class AsmPrinter::Impl { public: Impl(raw_ostream &os, AsmStateImpl &state); explicit Impl(Impl &other) : Impl(other.os, other.state) {} + explicit Impl(raw_ostream &os, Impl &other) : Impl(os, other.state) {} /// Returns the output stream of the printer. raw_ostream &getStream() { return os; } @@ -447,6 +452,10 @@ class AsmPrinter::Impl { /// be printed. LogicalResult printAlias(Attribute attr); + /// Print the dialect alias for the given attribute, return failure if no + /// alias could be printed. + LogicalResult printDialectAlias(Attribute attr, bool printStripped); + /// Print the given type or an alias. void printType(Type type); /// Print the given type. @@ -456,6 +465,10 @@ class AsmPrinter::Impl { /// be printed. LogicalResult printAlias(Type type); + /// Print the dialect alias for the given type, return failure if no alias + /// could be printed. + LogicalResult printDialectAlias(Type type, bool printStripped); + /// Print the given location to the stream. If `allowAlias` is true, this /// allows for the internal location to use an attribute alias. void printLocation(LocationAttr loc, bool allowAlias = false); @@ -812,6 +825,12 @@ class DummyAliasOperationPrinter : private OpAsmPrinter { initializer.visit(type); return success(); } + LogicalResult printDialectAlias(Attribute attr, bool printStripped) override { + return failure(); + } + LogicalResult printDialectAlias(Type type, bool printStripped) override { + return failure(); + } /// Consider the given location to be printed for an alias. void printOptionalLocationSpecifier(Location loc) override { @@ -991,6 +1010,11 @@ class DummyAliasDialectAsmPrinter : public DialectAsmPrinter { return success(); } + LogicalResult printDialectAlias(Attribute, bool) override { + return failure(); + } + LogicalResult printDialectAlias(Type, bool) override { return failure(); } + /// Record the alias result of a child element. void recordAliasResult(std::pair aliasDepthAndIndex) { childIndices.push_back(aliasDepthAndIndex.second); @@ -1251,7 +1275,10 @@ namespace { /// This class manages the state for type and attribute aliases. class AliasState { public: - // Initialize the internal aliases. + /// Initialize the alias state for custom dialect aliases. + AliasState(DialectInterfaceCollection &interfaces); + + /// Initialize the internal aliases. void initialize(Operation *op, const OpPrintingFlags &printerFlags, DialectInterfaceCollection &interfaces); @@ -1275,20 +1302,99 @@ class AliasState { printAliases(p, newLine, /*isDeferred=*/true); } + /// Get an attribute alias if it exists. Returns the alias if found, + /// or a default constructed pair otherwise. + std::pair + getAttrAlias(AsmPrinter::Impl &p, Attribute attr, bool printStripped) { + return getAlias(p, attr.getTypeID(), attr.getAsOpaquePointer(), + printStripped); + } + + /// Get a type alias if it exists. Returns the alias if found, + /// or a default constructed pair otherwise. + std::pair + getTypeAlias(AsmPrinter::Impl &p, Type type, bool printStripped) { + return getAlias(p, type.getTypeID(), type.getAsOpaquePointer(), + printStripped); + } + private: + using TypeIDPrinter = + std::tuple>; + using PrinterIterator = SmallVectorImpl::iterator; + /// Print all of the referenced aliases that support the provided resolution /// behavior. void printAliases(AsmPrinter::Impl &p, NewLineCounter &newLine, bool isDeferred); + /// Comparison function for TypeIDPrinter. + static bool comparePrinters(const TypeIDPrinter &lhs, + const TypeIDPrinter &rhs); + + /// Find custom printers for the given TypeID. + llvm::iterator_range findPrinters(TypeID typeID); + + /// Get an attribute or type alias if it exists. Returns the alias if found, + /// or a default constructed pair otherwise. + std::pair getAlias(AsmPrinter::Impl &p, + TypeID typeID, + const void *opaqueAttrType, + bool printStripped); + /// Mapping between attribute/type and alias. llvm::MapVector attrTypeToAlias; /// An allocator used for alias names. llvm::BumpPtrAllocator aliasAllocator; + + /// Mapping between attribute/type ID and custom printers for them. + SmallVector attrTypePrinters; + + /// Cache for custom printed attributes/types. + DenseMap, + std::pair> + printCache; }; } // namespace +AliasState::AliasState( + DialectInterfaceCollection &interfaces) { + // Collect all of the custom alias printers. + for (const OpAsmDialectInterface &interface : interfaces) { + auto insertAliasAttrFn = + [&](TypeID typeID, + OpAsmDialectInterface::AttributeAliasPrinter printer) { + if (!printer) + return; + attrTypePrinters.emplace_back( + typeID, interface.getDialect(), + [printer](const void *attr, AsmPrinter &p, bool printStripped) { + printer(Attribute::getFromOpaquePointer(attr), p, + printStripped); + }); + }; + auto insertAliasTypeFn = + [&](TypeID typeID, OpAsmDialectInterface::TypeAliasPrinter printer) { + if (!printer) + return; + attrTypePrinters.emplace_back( + typeID, interface.getDialect(), + [printer](const void *attr, AsmPrinter &p, bool printStripped) { + printer(Type::getFromOpaquePointer(attr), p, printStripped); + }); + }; + interface.registerAttrAliasPrinter(insertAliasAttrFn); + interface.registerTypeAliasPrinter(insertAliasTypeFn); + } + + // Sort the printers by TypeID for efficient lookup. + // Stable sort guarantees that the order of registration is preserved. + std::stable_sort(attrTypePrinters.begin(), attrTypePrinters.end(), + comparePrinters); +} + void AliasState::initialize( Operation *op, const OpPrintingFlags &printerFlags, DialectInterfaceCollection &interfaces) { @@ -1315,6 +1421,24 @@ LogicalResult AliasState::getAlias(Type ty, raw_ostream &os) const { return success(); } +bool AliasState::comparePrinters(const TypeIDPrinter &lhs, + const TypeIDPrinter &rhs) { + return std::get<0>(lhs).getAsOpaquePointer() < + std::get<0>(rhs).getAsOpaquePointer(); +} + +llvm::iterator_range +AliasState::findPrinters(TypeID typeID) { + TypeIDPrinter key = std::make_tuple( + typeID, /*unused*/ nullptr, + /*unused*/ std::function()); + PrinterIterator lb = std::lower_bound( + attrTypePrinters.begin(), attrTypePrinters.end(), key, comparePrinters); + PrinterIterator ub = std::upper_bound( + attrTypePrinters.begin(), attrTypePrinters.end(), key, comparePrinters); + return llvm::make_range(lb, ub); +} + void AliasState::printAliases(AsmPrinter::Impl &p, NewLineCounter &newLine, bool isDeferred) { auto filterFn = [=](const auto &aliasIt) { @@ -1342,6 +1466,40 @@ void AliasState::printAliases(AsmPrinter::Impl &p, NewLineCounter &newLine, } } +std::pair +AliasState::getAlias(AsmPrinter::Impl &p, TypeID typeID, + const void *opaqueAttrType, bool printStripped) { + llvm::PointerIntPair key(opaqueAttrType, + printStripped); + // Check the cache first. + if (auto it = printCache.find(key); it != printCache.end()) + return it->second; + + // Try to get the alias using custom printers. + std::string buffer; + llvm::raw_string_ostream os(buffer); + AsmPrinter::Impl printImpl(os, p); + DialectAsmPrinter printer(printImpl); + for (const auto &printInfo : findPrinters(typeID)) { + // Invoke the printer. + std::get<2>(printInfo)(opaqueAttrType, printer, printStripped); + + // Trim any whitespace. + if (StringRef str = StringRef(buffer).trim(); str != buffer) + buffer = str.str(); + + // If we printed something, cache and return. + if (!buffer.empty()) { + StringRef alias = (printCache[key] = std::make_pair( + std::get<1>(printInfo), std::move(buffer))) + .second; + return std::make_pair(std::get<1>(printInfo), alias); + } + buffer.clear(); + } + return {nullptr, StringRef()}; +} + //===----------------------------------------------------------------------===// // SSANameState //===----------------------------------------------------------------------===// @@ -1948,11 +2106,13 @@ class AsmStateImpl { public: explicit AsmStateImpl(Operation *op, const OpPrintingFlags &printerFlags, AsmState::LocationMap *locationMap) - : interfaces(op->getContext()), nameState(op, printerFlags), - printerFlags(printerFlags), locationMap(locationMap) {} + : interfaces(op->getContext()), aliasState(interfaces), + nameState(op, printerFlags), printerFlags(printerFlags), + locationMap(locationMap) {} explicit AsmStateImpl(MLIRContext *ctx, const OpPrintingFlags &printerFlags, AsmState::LocationMap *locationMap) - : interfaces(ctx), printerFlags(printerFlags), locationMap(locationMap) {} + : interfaces(ctx), aliasState(interfaces), printerFlags(printerFlags), + locationMap(locationMap) {} /// Initialize the alias state to enable the printing of aliases. void initializeAliases(Operation *op) { @@ -2377,6 +2537,38 @@ LogicalResult AsmPrinter::Impl::printAlias(Type type) { return state.getAliasState().getAlias(type, os); } +LogicalResult AsmPrinter::Impl::printDialectAlias(Attribute attr, + bool printStripped) { + // Check to see if there is a dialect alias for this attribute. + auto [aliasDialect, alias] = + state.getAliasState().getAttrAlias(*this, attr, printStripped); + if (aliasDialect && !alias.empty()) { + if (printStripped) { + os << alias; + return success(); + } + printDialectSymbol(os, "!", aliasDialect->getNamespace(), alias); + return success(); + } + return failure(); +} + +LogicalResult AsmPrinter::Impl::printDialectAlias(Type type, + bool printStripped) { + // Check to see if there is a dialect alias for this type. + auto [aliasDialect, alias] = + state.getAliasState().getTypeAlias(*this, type, printStripped); + if (aliasDialect && !alias.empty()) { + if (printStripped) { + os << alias; + return success(); + } + printDialectSymbol(os, "!", aliasDialect->getNamespace(), alias); + return success(); + } + return failure(); +} + void AsmPrinter::Impl::printAttribute(Attribute attr, AttrTypeElision typeElision) { if (!attr) { @@ -2387,6 +2579,8 @@ void AsmPrinter::Impl::printAttribute(Attribute attr, // Try to print an alias for this attribute. if (succeeded(printAlias(attr))) return; + if (succeeded(printDialectAlias(attr, /*printStripped=*/false))) + return; return printAttributeImpl(attr, typeElision); } void AsmPrinter::Impl::printAttributeImpl(Attribute attr, @@ -2715,6 +2909,8 @@ void AsmPrinter::Impl::printType(Type type) { // Try to print an alias for this type. if (succeeded(printAlias(type))) return; + if (succeeded(printDialectAlias(type, /*printStripped=*/false))) + return; return printTypeImpl(type); } @@ -2987,6 +3183,17 @@ LogicalResult AsmPrinter::printAlias(Type type) { return impl->printAlias(type); } +LogicalResult AsmPrinter::printDialectAlias(Attribute attr, + bool printStripped) { + assert(impl && "expected AsmPrinter::printDialectAlias to be overriden"); + return impl->printAlias(attr); +} + +LogicalResult AsmPrinter::printDialectAlias(Type type, bool printStripped) { + assert(impl && "expected AsmPrinter::printDialectAlias to be overriden"); + return impl->printAlias(type); +} + void AsmPrinter::printAttributeWithoutType(Attribute attr) { assert(impl && "expected AsmPrinter::printAttributeWithoutType to be overriden"); diff --git a/mlir/test/IR/print-attr-type-dialect-aliases.mlir b/mlir/test/IR/print-attr-type-dialect-aliases.mlir new file mode 100644 index 0000000000000..95adb49d0a59a --- /dev/null +++ b/mlir/test/IR/print-attr-type-dialect-aliases.mlir @@ -0,0 +1,14 @@ +// RUN: mlir-opt %s | FileCheck %s + +// Check that attr and type aliases are properly printed. + +// CHECK: {types = [!test.tuple_i7_from_attr, !test.tuple_i6_from_attr, tuple>]} : () -> (!test.tuple_i7, !test.tuple_i6, tuple>) +"test.op"() {types = [ + tuple>, + tuple>, + tuple> + ]} : () -> ( + tuple>, + tuple>, + tuple> + ) diff --git a/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp b/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp index 3d4aa23ebe78a..8061fadb1da95 100644 --- a/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp @@ -8,6 +8,8 @@ #include "TestDialect.h" #include "TestOps.h" +#include "TestTypes.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/Interfaces/FoldInterfaces.h" #include "mlir/Reducer/ReductionPatternInterface.h" #include "mlir/Transforms/InliningUtils.h" @@ -262,6 +264,52 @@ struct TestOpAsmInterface : public OpAsmDialectInterface { return AliasResult::NoAlias; } + void registerAttrAliasPrinter(InsertAttrAliasPrinter insertFn) const final { + insertFn(TypeID::get(), + [](Attribute attr, AsmPrinter &printer, bool printStripped) { + auto tTy = dyn_cast(cast(attr).getValue()); + if (!tTy || tTy.size() != 1) + return; + if (auto iTy = dyn_cast(tTy.getType(0)); + !iTy || iTy.getWidth() != 7) + return; + printer.getStream() << "tuple_i7_from_attr"; + }); + insertFn(TypeID::get(), + [](Attribute attr, AsmPrinter &printer, bool printStripped) { + auto tTy = dyn_cast(cast(attr).getValue()); + if (!tTy || tTy.size() != 1) + return; + if (auto iTy = dyn_cast(tTy.getType(0)); + !iTy || iTy.getWidth() != 6) + return; + printer.getStream() << "tuple_i6_from_attr"; + }); + } + + void registerTypeAliasPrinter(InsertTypeAliasPrinter insertFn) const final { + insertFn(TypeID::get(), + [](Type type, AsmPrinter &printer, bool printStripped) { + auto tTy = dyn_cast(type); + if (!tTy || tTy.size() != 1) + return; + if (auto iTy = dyn_cast(tTy.getType(0)); + !iTy || iTy.getWidth() != 7) + return; + printer.getStream() << "tuple_i7"; + }); + insertFn(TypeID::get(), + [](Type type, AsmPrinter &printer, bool printStripped) { + auto tTy = dyn_cast(type); + if (!tTy || tTy.size() != 1) + return; + if (auto iTy = dyn_cast(tTy.getType(0)); + !iTy || iTy.getWidth() != 6) + return; + printer.getStream() << "tuple_i6"; + }); + } + //===------------------------------------------------------------------===// // Resources //===------------------------------------------------------------------===//