From 797c073c0fe38438b4a223d87c7e468ce59ee865 Mon Sep 17 00:00:00 2001 From: Fabian Mora Date: Wed, 17 Dec 2025 16:35:45 +0000 Subject: [PATCH] [mlir] Add dialect hooks for registering custom type and attribute alias printers This patch introduces a mechanism for dialects to register custom alias printers for types and attributes via the `OpAsmDialectInterface`. This allows dialects to provide alternative printed representations for types and attributes based on their TypeID, including types/attributes from other dialects. The new `registerAttrAliasPrinter` and `registerTypeAliasPrinter` virtual methods accept callbacks that register printers for specific TypeIDs. When printing, these custom printers are invoked in registration order, and the first one to produce output is used. The precedence for alias resolution is: 1. Explicit type/attribute aliases returned by `getAlias` 2. Dialect-specific alias printers registered via the new hooks 3. Default type/attribute printers Signed-off-by: Fabian Mora --- mlir/include/mlir/IR/OpImplementation.h | 37 +++ mlir/lib/IR/AsmPrinter.cpp | 215 +++++++++++++++++- .../IR/print-attr-type-dialect-aliases.mlir | 14 ++ .../Dialect/Test/TestDialectInterfaces.cpp | 48 ++++ 4 files changed, 310 insertions(+), 4 deletions(-) create mode 100644 mlir/test/IR/print-attr-type-dialect-aliases.mlir diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h index d70aa346eaa1f..a6d39c4d17a8d 100644 --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -17,6 +17,8 @@ #include "mlir/IR/DialectInterface.h" #include "mlir/IR/OpAsmSupport.h" #include "mlir/IR/OpDefinition.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/TypeID.h" #include "llvm/ADT/Twine.h" #include "llvm/Support/SMLoc.h" #include @@ -169,6 +171,9 @@ class AsmPrinter { if (succeeded(printAlias(attrOrType))) return; + if (succeeded(printDialectAlias(attrOrType, /*printStripped=*/true))) + return; + raw_ostream &os = getStream(); uint64_t posPrior = os.tell(); attrOrType.print(*this); @@ -218,6 +223,14 @@ class AsmPrinter { /// be printed. virtual LogicalResult printAlias(Type type); + /// Print the alias for the given attribute, return failure if no alias could + /// be printed. + virtual LogicalResult printDialectAlias(Attribute attr, bool printStripped); + + /// Print the alias for the given type, return failure if no alias could + /// be printed. + virtual LogicalResult printDialectAlias(Type type, bool printStripped); + /// Print the given string as a keyword, or a quoted and escaped string if it /// has any special or non-printable characters in it. virtual void printKeywordOrString(StringRef keyword); @@ -1799,6 +1812,30 @@ class OpAsmDialectInterface return AliasResult::NoAlias; } + /// Hooks for registering alias printers for types and attributes. These + /// printers are invoked when printing types or attributes of the given + /// TypeID. Printers are invoked in the order they are registered, and the + /// first one to print an alias is used. + /// The precedence of these printers is as follow: + /// 1. The type and attribute aliases returned by `getAlias`. + /// 2. Dialect-specific alias printers registered here. + /// 3. The type and attribute printers. + /// The boolean argument to the printer indicates whether the stripped form + /// of the type or attribute is being printed. + /// NOTE: This mechanism caches the printed object, therefore the printer + /// must always produce the same output for the same input. + using AttributeAliasPrinter = + llvm::function_ref; + using InsertAttrAliasPrinter = + llvm::function_ref; + virtual void registerAttrAliasPrinter(InsertAttrAliasPrinter insertFn) const { + } + using TypeAliasPrinter = llvm::function_ref; + using InsertTypeAliasPrinter = + llvm::function_ref; + virtual void registerTypeAliasPrinter(InsertTypeAliasPrinter insertFn) const { + } + //===--------------------------------------------------------------------===// // Resources //===--------------------------------------------------------------------===// diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 7d991cea6c468..86fa754b3bb54 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -32,13 +32,16 @@ #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/MapVector.h" +#include "llvm/ADT/PointerIntPair.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/ScopedHashTable.h" #include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringSet.h" #include "llvm/ADT/TypeSwitch.h" +#include "llvm/ADT/iterator_range.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/DebugLog.h" #include "llvm/Support/Endian.h" @@ -47,6 +50,7 @@ #include "llvm/Support/SaveAndRestore.h" #include "llvm/Support/Threading.h" #include "llvm/Support/raw_ostream.h" +#include #include #include @@ -414,6 +418,7 @@ class AsmPrinter::Impl { public: Impl(raw_ostream &os, AsmStateImpl &state); explicit Impl(Impl &other) : Impl(other.os, other.state) {} + explicit Impl(raw_ostream &os, Impl &other) : Impl(os, other.state) {} /// Returns the output stream of the printer. raw_ostream &getStream() { return os; } @@ -447,6 +452,10 @@ class AsmPrinter::Impl { /// be printed. LogicalResult printAlias(Attribute attr); + /// Print the dialect alias for the given attribute, return failure if no + /// alias could be printed. + LogicalResult printDialectAlias(Attribute attr, bool printStripped); + /// Print the given type or an alias. void printType(Type type); /// Print the given type. @@ -456,6 +465,10 @@ class AsmPrinter::Impl { /// be printed. LogicalResult printAlias(Type type); + /// Print the dialect alias for the given type, return failure if no alias + /// could be printed. + LogicalResult printDialectAlias(Type type, bool printStripped); + /// Print the given location to the stream. If `allowAlias` is true, this /// allows for the internal location to use an attribute alias. void printLocation(LocationAttr loc, bool allowAlias = false); @@ -812,6 +825,12 @@ class DummyAliasOperationPrinter : private OpAsmPrinter { initializer.visit(type); return success(); } + LogicalResult printDialectAlias(Attribute attr, bool printStripped) override { + return failure(); + } + LogicalResult printDialectAlias(Type type, bool printStripped) override { + return failure(); + } /// Consider the given location to be printed for an alias. void printOptionalLocationSpecifier(Location loc) override { @@ -991,6 +1010,11 @@ class DummyAliasDialectAsmPrinter : public DialectAsmPrinter { return success(); } + LogicalResult printDialectAlias(Attribute, bool) override { + return failure(); + } + LogicalResult printDialectAlias(Type, bool) override { return failure(); } + /// Record the alias result of a child element. void recordAliasResult(std::pair aliasDepthAndIndex) { childIndices.push_back(aliasDepthAndIndex.second); @@ -1251,7 +1275,10 @@ namespace { /// This class manages the state for type and attribute aliases. class AliasState { public: - // Initialize the internal aliases. + /// Initialize the alias state for custom dialect aliases. + AliasState(DialectInterfaceCollection &interfaces); + + /// Initialize the internal aliases. void initialize(Operation *op, const OpPrintingFlags &printerFlags, DialectInterfaceCollection &interfaces); @@ -1275,20 +1302,99 @@ class AliasState { printAliases(p, newLine, /*isDeferred=*/true); } + /// Get an attribute alias if it exists. Returns the alias if found, + /// or a default constructed pair otherwise. + std::pair + getAttrAlias(AsmPrinter::Impl &p, Attribute attr, bool printStripped) { + return getAlias(p, attr.getTypeID(), attr.getAsOpaquePointer(), + printStripped); + } + + /// Get a type alias if it exists. Returns the alias if found, + /// or a default constructed pair otherwise. + std::pair + getTypeAlias(AsmPrinter::Impl &p, Type type, bool printStripped) { + return getAlias(p, type.getTypeID(), type.getAsOpaquePointer(), + printStripped); + } + private: + using TypeIDPrinter = + std::tuple>; + using PrinterIterator = SmallVectorImpl::iterator; + /// Print all of the referenced aliases that support the provided resolution /// behavior. void printAliases(AsmPrinter::Impl &p, NewLineCounter &newLine, bool isDeferred); + /// Comparison function for TypeIDPrinter. + static bool comparePrinters(const TypeIDPrinter &lhs, + const TypeIDPrinter &rhs); + + /// Find custom printers for the given TypeID. + llvm::iterator_range findPrinters(TypeID typeID); + + /// Get an attribute or type alias if it exists. Returns the alias if found, + /// or a default constructed pair otherwise. + std::pair getAlias(AsmPrinter::Impl &p, + TypeID typeID, + const void *opaqueAttrType, + bool printStripped); + /// Mapping between attribute/type and alias. llvm::MapVector attrTypeToAlias; /// An allocator used for alias names. llvm::BumpPtrAllocator aliasAllocator; + + /// Mapping between attribute/type ID and custom printers for them. + SmallVector attrTypePrinters; + + /// Cache for custom printed attributes/types. + DenseMap, + std::pair> + printCache; }; } // namespace +AliasState::AliasState( + DialectInterfaceCollection &interfaces) { + // Collect all of the custom alias printers. + for (const OpAsmDialectInterface &interface : interfaces) { + auto insertAliasAttrFn = + [&](TypeID typeID, + OpAsmDialectInterface::AttributeAliasPrinter printer) { + if (!printer) + return; + attrTypePrinters.emplace_back( + typeID, interface.getDialect(), + [printer](const void *attr, AsmPrinter &p, bool printStripped) { + printer(Attribute::getFromOpaquePointer(attr), p, + printStripped); + }); + }; + auto insertAliasTypeFn = + [&](TypeID typeID, OpAsmDialectInterface::TypeAliasPrinter printer) { + if (!printer) + return; + attrTypePrinters.emplace_back( + typeID, interface.getDialect(), + [printer](const void *attr, AsmPrinter &p, bool printStripped) { + printer(Type::getFromOpaquePointer(attr), p, printStripped); + }); + }; + interface.registerAttrAliasPrinter(insertAliasAttrFn); + interface.registerTypeAliasPrinter(insertAliasTypeFn); + } + + // Sort the printers by TypeID for efficient lookup. + // Stable sort guarantees that the order of registration is preserved. + std::stable_sort(attrTypePrinters.begin(), attrTypePrinters.end(), + comparePrinters); +} + void AliasState::initialize( Operation *op, const OpPrintingFlags &printerFlags, DialectInterfaceCollection &interfaces) { @@ -1315,6 +1421,24 @@ LogicalResult AliasState::getAlias(Type ty, raw_ostream &os) const { return success(); } +bool AliasState::comparePrinters(const TypeIDPrinter &lhs, + const TypeIDPrinter &rhs) { + return std::get<0>(lhs).getAsOpaquePointer() < + std::get<0>(rhs).getAsOpaquePointer(); +} + +llvm::iterator_range +AliasState::findPrinters(TypeID typeID) { + TypeIDPrinter key = std::make_tuple( + typeID, /*unused*/ nullptr, + /*unused*/ std::function()); + PrinterIterator lb = std::lower_bound( + attrTypePrinters.begin(), attrTypePrinters.end(), key, comparePrinters); + PrinterIterator ub = std::upper_bound( + attrTypePrinters.begin(), attrTypePrinters.end(), key, comparePrinters); + return llvm::make_range(lb, ub); +} + void AliasState::printAliases(AsmPrinter::Impl &p, NewLineCounter &newLine, bool isDeferred) { auto filterFn = [=](const auto &aliasIt) { @@ -1342,6 +1466,40 @@ void AliasState::printAliases(AsmPrinter::Impl &p, NewLineCounter &newLine, } } +std::pair +AliasState::getAlias(AsmPrinter::Impl &p, TypeID typeID, + const void *opaqueAttrType, bool printStripped) { + llvm::PointerIntPair key(opaqueAttrType, + printStripped); + // Check the cache first. + if (auto it = printCache.find(key); it != printCache.end()) + return it->second; + + // Try to get the alias using custom printers. + std::string buffer; + llvm::raw_string_ostream os(buffer); + AsmPrinter::Impl printImpl(os, p); + DialectAsmPrinter printer(printImpl); + for (const auto &printInfo : findPrinters(typeID)) { + // Invoke the printer. + std::get<2>(printInfo)(opaqueAttrType, printer, printStripped); + + // Trim any whitespace. + if (StringRef str = StringRef(buffer).trim(); str != buffer) + buffer = str.str(); + + // If we printed something, cache and return. + if (!buffer.empty()) { + StringRef alias = (printCache[key] = std::make_pair( + std::get<1>(printInfo), std::move(buffer))) + .second; + return std::make_pair(std::get<1>(printInfo), alias); + } + buffer.clear(); + } + return {nullptr, StringRef()}; +} + //===----------------------------------------------------------------------===// // SSANameState //===----------------------------------------------------------------------===// @@ -1948,11 +2106,13 @@ class AsmStateImpl { public: explicit AsmStateImpl(Operation *op, const OpPrintingFlags &printerFlags, AsmState::LocationMap *locationMap) - : interfaces(op->getContext()), nameState(op, printerFlags), - printerFlags(printerFlags), locationMap(locationMap) {} + : interfaces(op->getContext()), aliasState(interfaces), + nameState(op, printerFlags), printerFlags(printerFlags), + locationMap(locationMap) {} explicit AsmStateImpl(MLIRContext *ctx, const OpPrintingFlags &printerFlags, AsmState::LocationMap *locationMap) - : interfaces(ctx), printerFlags(printerFlags), locationMap(locationMap) {} + : interfaces(ctx), aliasState(interfaces), printerFlags(printerFlags), + locationMap(locationMap) {} /// Initialize the alias state to enable the printing of aliases. void initializeAliases(Operation *op) { @@ -2377,6 +2537,38 @@ LogicalResult AsmPrinter::Impl::printAlias(Type type) { return state.getAliasState().getAlias(type, os); } +LogicalResult AsmPrinter::Impl::printDialectAlias(Attribute attr, + bool printStripped) { + // Check to see if there is a dialect alias for this attribute. + auto [aliasDialect, alias] = + state.getAliasState().getAttrAlias(*this, attr, printStripped); + if (aliasDialect && !alias.empty()) { + if (printStripped) { + os << alias; + return success(); + } + printDialectSymbol(os, "!", aliasDialect->getNamespace(), alias); + return success(); + } + return failure(); +} + +LogicalResult AsmPrinter::Impl::printDialectAlias(Type type, + bool printStripped) { + // Check to see if there is a dialect alias for this type. + auto [aliasDialect, alias] = + state.getAliasState().getTypeAlias(*this, type, printStripped); + if (aliasDialect && !alias.empty()) { + if (printStripped) { + os << alias; + return success(); + } + printDialectSymbol(os, "!", aliasDialect->getNamespace(), alias); + return success(); + } + return failure(); +} + void AsmPrinter::Impl::printAttribute(Attribute attr, AttrTypeElision typeElision) { if (!attr) { @@ -2387,6 +2579,8 @@ void AsmPrinter::Impl::printAttribute(Attribute attr, // Try to print an alias for this attribute. if (succeeded(printAlias(attr))) return; + if (succeeded(printDialectAlias(attr, /*printStripped=*/false))) + return; return printAttributeImpl(attr, typeElision); } void AsmPrinter::Impl::printAttributeImpl(Attribute attr, @@ -2715,6 +2909,8 @@ void AsmPrinter::Impl::printType(Type type) { // Try to print an alias for this type. if (succeeded(printAlias(type))) return; + if (succeeded(printDialectAlias(type, /*printStripped=*/false))) + return; return printTypeImpl(type); } @@ -2987,6 +3183,17 @@ LogicalResult AsmPrinter::printAlias(Type type) { return impl->printAlias(type); } +LogicalResult AsmPrinter::printDialectAlias(Attribute attr, + bool printStripped) { + assert(impl && "expected AsmPrinter::printDialectAlias to be overriden"); + return impl->printAlias(attr); +} + +LogicalResult AsmPrinter::printDialectAlias(Type type, bool printStripped) { + assert(impl && "expected AsmPrinter::printDialectAlias to be overriden"); + return impl->printAlias(type); +} + void AsmPrinter::printAttributeWithoutType(Attribute attr) { assert(impl && "expected AsmPrinter::printAttributeWithoutType to be overriden"); diff --git a/mlir/test/IR/print-attr-type-dialect-aliases.mlir b/mlir/test/IR/print-attr-type-dialect-aliases.mlir new file mode 100644 index 0000000000000..95adb49d0a59a --- /dev/null +++ b/mlir/test/IR/print-attr-type-dialect-aliases.mlir @@ -0,0 +1,14 @@ +// RUN: mlir-opt %s | FileCheck %s + +// Check that attr and type aliases are properly printed. + +// CHECK: {types = [!test.tuple_i7_from_attr, !test.tuple_i6_from_attr, tuple>]} : () -> (!test.tuple_i7, !test.tuple_i6, tuple>) +"test.op"() {types = [ + tuple>, + tuple>, + tuple> + ]} : () -> ( + tuple>, + tuple>, + tuple> + ) diff --git a/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp b/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp index 3d4aa23ebe78a..8061fadb1da95 100644 --- a/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp @@ -8,6 +8,8 @@ #include "TestDialect.h" #include "TestOps.h" +#include "TestTypes.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/Interfaces/FoldInterfaces.h" #include "mlir/Reducer/ReductionPatternInterface.h" #include "mlir/Transforms/InliningUtils.h" @@ -262,6 +264,52 @@ struct TestOpAsmInterface : public OpAsmDialectInterface { return AliasResult::NoAlias; } + void registerAttrAliasPrinter(InsertAttrAliasPrinter insertFn) const final { + insertFn(TypeID::get(), + [](Attribute attr, AsmPrinter &printer, bool printStripped) { + auto tTy = dyn_cast(cast(attr).getValue()); + if (!tTy || tTy.size() != 1) + return; + if (auto iTy = dyn_cast(tTy.getType(0)); + !iTy || iTy.getWidth() != 7) + return; + printer.getStream() << "tuple_i7_from_attr"; + }); + insertFn(TypeID::get(), + [](Attribute attr, AsmPrinter &printer, bool printStripped) { + auto tTy = dyn_cast(cast(attr).getValue()); + if (!tTy || tTy.size() != 1) + return; + if (auto iTy = dyn_cast(tTy.getType(0)); + !iTy || iTy.getWidth() != 6) + return; + printer.getStream() << "tuple_i6_from_attr"; + }); + } + + void registerTypeAliasPrinter(InsertTypeAliasPrinter insertFn) const final { + insertFn(TypeID::get(), + [](Type type, AsmPrinter &printer, bool printStripped) { + auto tTy = dyn_cast(type); + if (!tTy || tTy.size() != 1) + return; + if (auto iTy = dyn_cast(tTy.getType(0)); + !iTy || iTy.getWidth() != 7) + return; + printer.getStream() << "tuple_i7"; + }); + insertFn(TypeID::get(), + [](Type type, AsmPrinter &printer, bool printStripped) { + auto tTy = dyn_cast(type); + if (!tTy || tTy.size() != 1) + return; + if (auto iTy = dyn_cast(tTy.getType(0)); + !iTy || iTy.getWidth() != 6) + return; + printer.getStream() << "tuple_i6"; + }); + } + //===------------------------------------------------------------------===// // Resources //===------------------------------------------------------------------===//