diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h b/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
index 1776a209d0bf1..80ea1e3407058 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
@@ -9,6 +9,8 @@
 #ifndef MLIR_DIALECT_XEGPU_TRANSFORMS_TRANSFORMS_H
 #define MLIR_DIALECT_XEGPU_TRANSFORMS_TRANSFORMS_H
 
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/IR/Builders.h"
 #include "mlir/IR/Operation.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/Support/LogicalResult.h"
@@ -91,6 +93,12 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns);
 void populateXeGPUUnrollPatterns(RewritePatternSet &patterns,
                                  const UnrollOptions &options);
 
+enum class LayoutKind { Lane, InstData, Subgroup };
+LogicalResult propagateLayouts(OpBuilder &builder, Operation *target,
+                               LayoutKind layoutKind, bool printOnly = false);
+
+LogicalResult resolveLayoutConflicts(Operation *target);
+
 } // namespace xegpu
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 7fc75e7294ea3..9c35b14be0bd5 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/Dialect/XeGPU/Transforms/Passes.h"
+#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
 #include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
@@ -37,6 +38,7 @@
 #include "llvm/Support/raw_ostream.h"
 
 #include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
+#include "mlir/Support/WalkResult.h"
 
 namespace mlir {
 namespace xegpu {
@@ -53,8 +55,6 @@ using namespace mlir::dataflow;
 
 namespace {
 
-enum class LayoutKind { Lane, InstData, Subgroup };
-
 //===----------------------------------------------------------------------===//
 // LayoutInfo
 //===----------------------------------------------------------------------===//
@@ -380,7 +380,7 @@ getSIMTLayoutInfoForDPASOperand(VectorType vectorTy, unsigned operandNum,
 class LayoutInfoPropagation
     : public SparseBackwardDataFlowAnalysis<LayoutInfoLattice> {
 private:
-  LayoutKind layoutKind;
+  xegpu::LayoutKind layoutKind;
 
   void visitDpasOp(xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
                    ArrayRef<const LayoutInfoLattice *> results);
@@ -436,7 +436,7 @@ class LayoutInfoPropagation
 public:
   LayoutInfoPropagation(DataFlowSolver &solver,
                         SymbolTableCollection &symbolTable,
-                        LayoutKind layoutKind)
+                        xegpu::LayoutKind layoutKind)
       : SparseBackwardDataFlowAnalysis(solver, symbolTable),
         layoutKind(layoutKind) {}
   using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis;
@@ -526,12 +526,12 @@ bool LayoutInfoPropagation::hasParamsOfLayoutKind(
   if (anchorLayout == nullptr) {
     return false;
   }
-  if (layoutKind == LayoutKind::InstData) {
+  if (layoutKind == xegpu::LayoutKind::InstData) {
    return !(anchorLayout.getEffectiveInstDataAsInt().empty());
-  } else if (layoutKind == LayoutKind::Lane) {
+  } else if (layoutKind == xegpu::LayoutKind::Lane) {
    return !(anchorLayout.getEffectiveLaneLayoutAsInt().empty() ||
             anchorLayout.getEffectiveLaneDataAsInt().empty());
-  } else if (layoutKind == LayoutKind::Subgroup) {
+  } else if (layoutKind == xegpu::LayoutKind::Subgroup) {
    return !(anchorLayout.getEffectiveSgLayoutAsInt().empty() ||
             anchorLayout.getEffectiveSgDataAsInt().empty());
  }
@@ -579,7 +579,7 @@ void LayoutInfoPropagation::visitPrefetchNdOp(
     instData = {instHeight, instWidth};
   }
 
-  if (layoutKind == LayoutKind::InstData)
+  if (layoutKind == xegpu::LayoutKind::InstData)
     prefetchLayout =
         LayoutInfo(xegpu::LayoutAttr::get(tdescTy.getContext(), instData));
   else
@@ -748,7 +748,7 @@ void LayoutInfoPropagation::visitDpasOp(
   SmallVector<int32_t> instDataA = {maxALen, subgroupSize};
   SmallVector<int32_t> instDataB = {subgroupSize, maxBLen};
 
-  if (layoutKind == LayoutKind::InstData) {
+  if (layoutKind == xegpu::LayoutKind::InstData) {
     dpasALayout =
         LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataA));
     dpasBLayout =
@@ -762,7 +762,7 @@ void LayoutInfoPropagation::visitDpasOp(
   if (operands.size() > 2) {
     VectorType cTy = dpas.getAccType();
 
-    if (layoutKind == LayoutKind::InstData) {
+    if (layoutKind == xegpu::LayoutKind::InstData) {
       const unsigned dataCLen = bTy.getShape().back();
       auto supportedCLen =
           uArchInstruction->getSupportedN(bTy.getElementType());
@@ -832,7 +832,7 @@ void LayoutInfoPropagation::visitStoreNdOp(
     instData = {instHeight, instWidth};
   }
 
-  if (layoutKind == LayoutKind::InstData)
+  if (layoutKind == xegpu::LayoutKind::InstData)
     storeLayout =
         LayoutInfo(xegpu::LayoutAttr::get(dataTy.getContext(), instData));
   else
@@ -992,7 +992,7 @@ void LayoutInfoPropagation::visitLoadGatherOp(
     instData.push_back(chunkSize);
   }
 
-  if (layoutKind == LayoutKind::InstData)
+  if (layoutKind == xegpu::LayoutKind::InstData)
     loadLayout =
         LayoutInfo(xegpu::LayoutAttr::get(load.getContext(), instData));
   else
@@ -1055,7 +1055,7 @@ void LayoutInfoPropagation::visitStoreScatterOp(
   auto uArch = getUArch(getChipStr(storeScatter).value_or(""));
   const int subgroupSize = uArch->getSubgroupSize();
 
-  if (layoutKind == LayoutKind::InstData) {
+  if (layoutKind == xegpu::LayoutKind::InstData) {
     SmallVector<int32_t> instData{subgroupSize};
     if (auto chunkSize = storeScatter.getChunkSize().value_or(0);
         chunkSize > 1)
@@ -1106,7 +1106,8 @@ class RunLayoutInfoPropagation {
 public:
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(RunLayoutInfoPropagation)
 
-  RunLayoutInfoPropagation(Operation *op, LayoutKind layoutKind) : target(op) {
+  RunLayoutInfoPropagation(Operation *op, xegpu::LayoutKind layoutKind)
+      : target(op) {
     SymbolTableCollection symbolTable;
     loadBaselineAnalyses(solver);
     solver.load<LayoutInfoPropagation>(symbolTable, layoutKind);
@@ -1180,6 +1181,77 @@ void RunLayoutInfoPropagation::printAnalysisResult(llvm::raw_ostream &os) {
     printFunctionResult(funcOp);
 }
 
+namespace {
+
+//===----------------------------------------------------------------------===//
+// ResolveLayoutConflicts
+//===----------------------------------------------------------------------===//
+struct ResolveLayoutConflicts {
+  ResolveLayoutConflicts(Operation *parentOp)
+      : parentOp(parentOp), builder(parentOp->getContext()) {}
+  LogicalResult run();
+
+private:
+  Operation *parentOp;
+  OpBuilder builder;
+  LogicalResult resolveLoadNdOp(xegpu::LoadNdOp loadNdOp);
+};
+
+} // namespace
+
+LogicalResult ResolveLayoutConflicts::run() {
+  auto r = parentOp->walk([&](Operation *op) -> WalkResult {
+    TypeSwitch<Operation *, WalkResult>(op).Case<xegpu::LoadNdOp>(
+        [&](xegpu::LoadNdOp loadNdOp) {
+          return failed(resolveLoadNdOp(loadNdOp)) ? WalkResult::interrupt()
+                                                   : WalkResult::advance();
+        });
+    // TODO: Add other layout conflict resolution methods as needed.
+    return WalkResult::advance();
+  });
+
+  return r.wasInterrupted() ? failure() : success();
+}
+
+/// LoadNd has a conflict if the tensor descriptor layout is different from the
+/// load's anchor layout.
+LogicalResult
+ResolveLayoutConflicts::resolveLoadNdOp(xegpu::LoadNdOp loadNdOp) {
+  Attribute anchorLayout = loadNdOp.getLayoutAttr();
+  Attribute tdescLayout = loadNdOp.getTensorDescType().getLayout();
+
+  if (anchorLayout && tdescLayout && anchorLayout != tdescLayout) {
+    // Try to get the defining CreateNdDescOp of the tensor descriptor.
+    auto conflictingCreateNdOp =
+        loadNdOp.getTensorDesc().getDefiningOp<xegpu::CreateNdDescOp>();
+    if (!conflictingCreateNdOp) {
+      DBGS() << "Unable to find defining CreateNdDescOp for tensor descriptor: "
+             << loadNdOp.getTensorDesc() << "\n";
+      return failure();
+    }
+    // Duplicate the CreateNdDescOp with the expected layout.
+    builder.setInsertionPointAfter(conflictingCreateNdOp);
+    xegpu::TensorDescType tdescType = loadNdOp.getTensorDescType();
+    auto expectedLayout = anchorLayout;
+    auto newTensorDescType = xegpu::TensorDescType::get(
+        conflictingCreateNdOp.getContext(), tdescType.getShape(),
+        tdescType.getElementType(), tdescType.getEncoding(), expectedLayout);
+    xegpu::CreateNdDescOp newOp = xegpu::CreateNdDescOp::create(
+        builder, loadNdOp.getLoc(), newTensorDescType,
+        conflictingCreateNdOp->getOperands(),
+        conflictingCreateNdOp->getAttrs());
+    // Replace only the conflicting uses of the createNdOp that can be
+    // resolved using the new layout.
+    conflictingCreateNdOp->replaceUsesWithIf(
+        ArrayRef<Value>(newOp.getResult()), [&](OpOperand &opnd) {
+          auto userLoadNdOp = dyn_cast<xegpu::LoadNdOp>(opnd.getOwner());
+          if (!userLoadNdOp)
+            return false;
+          return userLoadNdOp.getLayoutAttr() == expectedLayout;
+        });
+  }
+  return success();
+}
+
 using GetLayoutFnTy = function_ref<xegpu::DistributeLayoutAttr(Value)>;
 /// Update an operation with the layout of its results. If the result type is
 /// a vector type, a temporary layout attribute is added to the operation. If
@@ -1348,26 +1420,14 @@ struct XeGPUPropagateLayoutPass final
 } // namespace
 
-void XeGPUPropagateLayoutPass::runOnOperation() {
-  LayoutKind layoutKind;
-  if (this->layoutKind == "lane") {
-    layoutKind = LayoutKind::Lane;
-  } else if (this->layoutKind == "inst") {
-    layoutKind = LayoutKind::InstData;
-  } else if (this->layoutKind == "subgroup") {
-    layoutKind = LayoutKind::Subgroup;
-  } else {
-    getOperation()->emitError("Unsupported layout kind option: " +
-                              this->layoutKind);
-    signalPassFailure();
-    return;
-  }
-  RunLayoutInfoPropagation analysis(getOperation(), layoutKind);
+LogicalResult xegpu::propagateLayouts(OpBuilder &builder, Operation *target,
+                                      LayoutKind layoutKind, bool printOnly) {
+  RunLayoutInfoPropagation analysis(target, layoutKind);
   // Print the analysis result and exit. (for debugging purposes)
   if (printOnly) {
     auto &os = llvm::outs();
     analysis.printAnalysisResult(os);
-    return;
+    return success();
   }
 
   // Helper to convert LayoutInfo to xegpu::LayoutAttr.
 auto getXeGPULayoutForValue = [&](Value val) -> xegpu::DistributeLayoutAttr {
@@ -1381,8 +1441,7 @@ void XeGPUPropagateLayoutPass::runOnOperation() {
     return cast<xegpu::DistributeLayoutAttr>(layoutAttr);
   };
 
-  mlir::OpBuilder builder(&getContext());
-  Operation *op = getOperation();
+  Operation *op = target;
   auto walkResult = op->walk([&](mlir::Block *block) -> WalkResult {
     for (mlir::Operation &op : llvm::reverse(block->getOperations())) {
       LogicalResult r = success();
@@ -1407,7 +1466,39 @@ void XeGPUPropagateLayoutPass::runOnOperation() {
     }
     return WalkResult::advance();
   });
-  if (walkResult.wasInterrupted()) {
+  if (walkResult.wasInterrupted())
+    return failure();
+
+  return success();
+}
+
+LogicalResult xegpu::resolveLayoutConflicts(Operation *target) {
+  ResolveLayoutConflicts resolver(target);
+  return resolver.run();
+}
+
+void XeGPUPropagateLayoutPass::runOnOperation() {
+  xegpu::LayoutKind layoutKind;
+  if (this->layoutKind == "lane") {
+    layoutKind = xegpu::LayoutKind::Lane;
+  } else if (this->layoutKind == "inst") {
+    layoutKind = xegpu::LayoutKind::InstData;
+  } else if (this->layoutKind == "subgroup") {
+    layoutKind = xegpu::LayoutKind::Subgroup;
+  } else {
+    getOperation()->emitError("Unsupported layout kind option: " +
+                              this->layoutKind);
+    signalPassFailure();
+    return;
+  }
+  OpBuilder builder(&getContext());
+  if (failed(xegpu::propagateLayouts(builder, getOperation(), layoutKind,
+                                     this->printOnly))) {
+    signalPassFailure();
+    return;
+  }
+  // Resolve layout conflicts if any.
+  if (failed(xegpu::resolveLayoutConflicts(getOperation()))) {
     signalPassFailure();
     return;
   }
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
index 5f70831f45e97..5e095fe0df89e 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -xevm-attach-target='chip=pvc' -xegpu-propagate-layout="layout-kind=inst" -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -xevm-attach-target='chip=pvc' -test-xegpu-propagate-layouts="layout-kind=inst" -split-input-file %s | FileCheck %s
 
 
 // CHECK-LABEL: func.func @load_store_no_array_len(
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir b/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
index 092a4cf442782..7675c44be1c61 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -xevm-attach-target='chip=pvc' -xegpu-propagate-layout="layout-kind=subgroup" -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -xevm-attach-target='chip=pvc' -test-xegpu-propagate-layouts="layout-kind=subgroup" -split-input-file %s | FileCheck %s
 
 gpu.module @test {
 // CHECK-LABEL: store_nd
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout.mlir b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
index b88d8e1a78a26..3e7f3d5156d62 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -xevm-attach-target='chip=pvc' -xegpu-propagate-layout="layout-kind=lane" -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -xevm-attach-target='chip=pvc' -test-xegpu-propagate-layouts="layout-kind=lane" -split-input-file %s | FileCheck %s
 
 gpu.module @test {
 // CHECK-LABEL: func.func @dpas_f16(
@@ -32,7 +32,7 @@ func.func @dpas_f16(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: me
 
 gpu.module @test {
 // CHECK-LABEL: func.func @dpas_i8(
 // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: vector<8x32xi8>, %[[ARG1:[0-9a-zA-Z]+]]: vector<32x16xi8>, %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xi32>) {
-// CHECK: %[[T0:.*]] = xegpu.dpas %[[ARG0]], %[[ARG1]] {layout_a = #xegpu.layout, layout_b = #xegpu.layout, layout_cd = #xegpu.layout}
+// CHECK: %[[T0:.*]] = xegpu.dpas %[[ARG0]], %[[ARG1]] {layout_a = #xegpu.layout, layout_b = #xegpu.layout, layout_cd = #xegpu.layout}
 func.func @dpas_i8(%arg0: vector<8x32xi8>, %arg1: vector<32x16xi8>, %arg2: memref<8x16xi32>) {
   %c0 = arith.constant 0 : index
@@ -109,7 +109,7 @@ gpu.module @test {
 // CHECK-NEXT: %[[CST0:.*]] = arith.constant {layout_result_0 = #xegpu.layout} dense<true> : vector<16xi1>
 // CHECK-NEXT: %[[T2:.*]] = xegpu.create_tdesc %[[ARG1]], %[[CST]] : memref<256xf16>, vector<16xindex> ->
 // CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr, #xegpu.layout>
-// CHECK-NEXT: %{{.*}} = xegpu.load %[[T2]], %[[CST0]] <{layout = #xegpu.layout}>
+// CHECK-NEXT: %{{.*}} = xegpu.load %[[T2]], %[[CST0]] <{layout = #xegpu.layout}>
 // CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr, #xegpu.layout>, vector<16xi1> -> vector<16x16xf16>
 func.func @load_gather_with_chunksize(%arg0: memref<8x16xf16>, %arg1: memref<256xf16>, %arg2: memref<8x16xf32>) {
   %c0 = arith.constant 0 : index
@@ -240,7 +240,7 @@ gpu.module @test {
 // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<256xf16>) {
 // CHECK: %[[MASK:.*]] = arith.constant {layout_result_0 = #xegpu.layout} dense<true> : vector<16xi1>
 // CHECK: %[[OFFSETS:.*]] = arith.constant {layout_result_0 = #xegpu.layout} dense<12> : vector<16xindex>
-// CHECK: %[[LOAD_VEC:.*]] = xegpu.load %[[ARG0]][%[[OFFSETS]]], %[[MASK]]
+// CHECK: %[[LOAD_VEC:.*]] = xegpu.load %[[ARG0]][%[[OFFSETS]]], %[[MASK]]
 // CHECK-SAME: memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
 // CHECK: %[[ADD_RES:.*]] = arith.addf %[[LOAD_VEC]], %[[LOAD_VEC]] {layout_result_0 = #xegpu.layout} : vector<16xf16>
 // CHECK: xegpu.store %[[ADD_RES]], %[[ARG0]][%[[OFFSETS]]], %[[MASK]]
@@ -697,4 +697,4 @@ func.func @vector_broadcast_scalar_to_vector(%arg0: !xegpu.tensor_desc<16x16xf16
   xegpu.store_nd %6, %arg0 : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16>
   return
 }
-}
\ No newline at end of file
+}
diff --git a/mlir/test/Dialect/XeGPU/resolve-layout-conflicts.mlir b/mlir/test/Dialect/XeGPU/resolve-layout-conflicts.mlir
new file mode 100644
index 0000000000000..dd3f3c8bdc29e
--- /dev/null
+++ b/mlir/test/Dialect/XeGPU/resolve-layout-conflicts.mlir
@@ -0,0 +1,23 @@
+// RUN: mlir-opt --test-xegpu-resolve-layout-conflicts -split-input-file %s | FileCheck %s
+
+#load_lo = #xegpu.layout
+#prefetch_lo = #xegpu.layout
+gpu.module @test {
+
+// CHECK-LABEL: func.func @load_nd_with_conflicting_tensor_desc
+// CHECK: %{{.*}} = xegpu.create_nd_tdesc %{{.*}} : memref<64x64xf16>
+// CHECK-SAME: -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout>
+// CHECK-NEXT: %[[T1:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<64x64xf16>
+// CHECK-SAME: -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout>
+// CHECK-NEXT: %{{.*}} = xegpu.load_nd %[[T1]][%{{.*}}, %{{.*}}] <{layout = #xegpu.layout}> :
+// CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.layout> -> vector<16x16xf16>
+func.func @load_nd_with_conflicting_tensor_desc(%arg0: memref<64x64xf16>) -> vector<16x16xf16> {
+  %c0 = arith.constant 0 : index
+  %0 = xegpu.create_nd_tdesc %arg0 : memref<64x64xf16>
+    -> !xegpu.tensor_desc<16x16xf16, #prefetch_lo>
#prefetch_lo> + -> vector<16x16xf16> + xegpu.prefetch_nd %0 [%c0, %c0] {layout = #prefetch_lo} : !xegpu.tensor_desc<16x16xf16, #prefetch_lo> + return %1 : vector<16x16xf16> +} +} diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp index 1a1520dfa975d..c8a6a6d7b8eb8 100644 --- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp +++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp @@ -277,6 +277,80 @@ struct TestXeGPUMoveFuncBodyToWarpOp } }; +struct TestXeGPUPropagateLayouts + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestXeGPUPropagateLayouts) + + StringRef getArgument() const final { return "test-xegpu-propagate-layouts"; } + + StringRef getDescription() const final { + return "Test the implementation of XeGPU propagate layouts."; + } + + void getDependentDialects(::mlir::DialectRegistry ®istry) const override { + registry.insert(); + registry.insert(); + } + + TestXeGPUPropagateLayouts() = default; + TestXeGPUPropagateLayouts(const TestXeGPUPropagateLayouts &pass) + : PassWrapper(pass) {} + + Option layoutKind{ + *this, "layout-kind", + llvm::cl::desc("Propagate `subgroup` / `inst` / `lane` level of xegpu " + "layouts."), + llvm::cl::init("lane")}; + + void runOnOperation() override { + OpBuilder builder(getOperation()); + LayoutKind kind; + if (layoutKind == "subgroup") + kind = LayoutKind::Subgroup; + else if (layoutKind == "inst") + kind = LayoutKind::InstData; + else if (layoutKind == "lane") + kind = LayoutKind::Lane; + else { + signalPassFailure(); + return; + } + if (failed(xegpu::propagateLayouts(builder, getOperation(), kind))) { + signalPassFailure(); + } + } +}; + +struct TestXeGPUResolveLayoutConflicts + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestXeGPUResolveLayoutConflicts) + + StringRef getArgument() const final { + return "test-xegpu-resolve-layout-conflicts"; + } + + StringRef getDescription() const final { + return "Test the implementation of XeGPU layout conflict resolution."; + } + + void getDependentDialects(::mlir::DialectRegistry ®istry) const override { + registry.insert(); + registry.insert(); + } + + TestXeGPUResolveLayoutConflicts() = default; + TestXeGPUResolveLayoutConflicts(const TestXeGPUResolveLayoutConflicts &pass) = + default; + + void runOnOperation() override { + if (failed(xegpu::resolveLayoutConflicts(getOperation()))) { + signalPassFailure(); + } + } +}; + struct TestXeGPULayoutInterface : public PassWrapper> { @@ -342,6 +416,8 @@ void registerTestXeGPULowerings() { PassRegistration(); PassRegistration(); PassRegistration(); + PassRegistration(); + PassRegistration(); } } // namespace test } // namespace mlir
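
Usage note (not part of the patch): the sketch below shows how a downstream pass might drive the two new entry points exposed in Transforms.h, mirroring the updated XeGPUPropagateLayoutPass::runOnOperation above. The pass name, the choice to anchor on gpu::GPUModuleOp, and the Lane layout kind are illustrative assumptions only.

// Hypothetical driver pass; names and anchor op are assumptions.
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "mlir/Pass/Pass.h"

namespace {
struct AssignXeGPULayouts
    : public mlir::PassWrapper<AssignXeGPULayouts,
                               mlir::OperationPass<mlir::gpu::GPUModuleOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AssignXeGPULayouts)

  void runOnOperation() override {
    mlir::OpBuilder builder(&getContext());
    // Assign lane-level layouts first, then repair producer/consumer
    // mismatches left behind (e.g. one tensor_desc feeding both a load and a
    // prefetch that were assigned different layouts).
    if (mlir::failed(mlir::xegpu::propagateLayouts(
            builder, getOperation(), mlir::xegpu::LayoutKind::Lane,
            /*printOnly=*/false)) ||
        mlir::failed(mlir::xegpu::resolveLayoutConflicts(getOperation())))
      signalPassFailure();
  }
};
} // namespace

Registering this with PassRegistration<AssignXeGPULayouts>() would give a pipeline-usable equivalent of running the two test passes back to back.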