From 1f707d71c48a60706c1f5f6f3110fe81ff6054e3 Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Sat, 25 Apr 2026 17:16:42 +0000 Subject: [PATCH 01/13] initial commit for CK Tile MXFP8 integration for gfx1250 --- .../ck_mx_grouped_gemm/ck_mx_grouped_gemm.cpp | 492 ++++++++++++++++++ 1 file changed, 492 insertions(+) create mode 100644 transformer_engine/common/gemm/ck_mx_grouped_gemm/ck_mx_grouped_gemm.cpp diff --git a/transformer_engine/common/gemm/ck_mx_grouped_gemm/ck_mx_grouped_gemm.cpp b/transformer_engine/common/gemm/ck_mx_grouped_gemm/ck_mx_grouped_gemm.cpp new file mode 100644 index 000000000..c7746b563 --- /dev/null +++ b/transformer_engine/common/gemm/ck_mx_grouped_gemm/ck_mx_grouped_gemm.cpp @@ -0,0 +1,492 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. See LICENSE for more information + ************************************************************************/ + +#include +#include "../../common.h" +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/ops/gemm/kernel/mx_grouped_gemm_kernel.hpp" +#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" + +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using mx_grouped_gemm_kargs = ck_tile::MxGroupedGemmHostArgs<>; + +template struct TETypeToCKType; +template <> struct TETypeToCKType { using type = ck_tile::fp8_t; }; +template <> struct TETypeToCKType { using type = ck_tile::bf8_t; }; +template <> struct TETypeToCKType { using type = ck_tile::e8m0_t; }; +template <> struct TETypeToCKType { using type = ck_tile::half_t; }; +template <> struct TETypeToCKType { using type = ck_tile::bfloat16_t; }; +template <> struct TETypeToCKType { using type = float; }; + +struct GroupedGemmRunContext { + const NVTETensor* A = nullptr; + const NVTETensor* B = nullptr; + NVTETensor* D = nullptr; + + int group_num = 0; + bool transA = false; + bool transB = false; + + void* workspace = nullptr; + size_t workspace_bytes = 0; + hipStream_t stream = nullptr; + +}; + +static constexpr ck_tile::index_t ScaleBlockSize = 32; + +enum struct MxGemmPipelineType +{ + CompTDMV1, + CompTDMV2 +}; + +template +struct MxGemmPipelineTypeSelector; +template +struct MxGemmPipelineTypeSelector +{ + using base_pipeline = ck_tile::BaseGemmPipelineAgBgCrCompTDM; + using pipeline = ck_tile::GemmPipelineAgBgCrCompTDMV1; + static constexpr auto GetName() { return "GemmPipelineAgBgCrCompTDMV1"; } +}; + +template +struct MxGemmPipelineTypeSelector +{ + using base_pipeline = ck_tile::BaseGemmPipelineAgBgCrCompTDM; + using pipeline = ck_tile::GemmPipelineAgBgCrCompTDMV2; + static constexpr auto GetName() { return "GemmPipelineAgBgCrCompTDMV2"; } +}; + +static inline const transformer_engine::SimpleTensor& data_view(const transformer_engine::Tensor& t) { + return t.data; +} + +static inline const transformer_engine::SimpleTensor& scale_inv_view(const transformer_engine::Tensor& t) { + return t.scale_inv; +} + +template +static inline bool has_sufficient_workspace(const GroupedGemmRunContext& ctx) { + const size_t needed = Kernel::GetWorkSpaceSize(ctx.group_num); + if (!ctx.workspace || ctx.workspace_bytes < needed) { + NVTE_WARN("ck_tile_grouped_gemm: insufficient workspace for CK path. 
Needed bytes=", needed, + ", available bytes=", ctx.workspace_bytes, ". Falling back."); + return false; + } + return true; +} + +struct GroupedGemKernelParam_Wmma +{ + static const bool kPadM = false; + static const bool kPadN = false; + static const bool kPadK = false; + static const int kBlockPerCu = 1; + static const ck_tile::index_t M_Tile = 64; + static const ck_tile::index_t N_Tile = 64; + static const ck_tile::index_t K_Tile = 128; + static const ck_tile::index_t M_Warp = 2; + static const ck_tile::index_t N_Warp = 2; + static const ck_tile::index_t K_Warp = 1; + static const ck_tile::index_t M_Warp_Tile = 32; + static const ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = 128; +}; + +template +__global__ void preshuffle_scale_gfx1250_kernel(const ScaleType* __restrict__ src, + ScaleType* __restrict__ dst, + int actual_rows, + int output_rows, + int KScale) +{ + static_assert(ScaleBlockSize == 32 && sizeof(ScaleType) == 1, + "gfx1250 scale preshuffle only supports 8-bit scale with ScaleBlockSize=32"); + constexpr int MPerXdlops = 16; + constexpr int KPerXdlops = 128; + constexpr int MNPack = 2; + constexpr int KPack = 1; + constexpr int MNStep = MPerXdlops; // 16 + constexpr int KStep = KPerXdlops / ScaleBlockSize; // 4 + const int K0 = KScale / (KPack * KStep); + const int linear = blockIdx.x * blockDim.x + threadIdx.x; + const int total = output_rows * KScale; + if(linear >= total) + return; + const int mn = linear / KScale; + const int k = linear % KScale; + const int iMNRepeat = mn / (MNStep * MNPack); + const int tempmn = mn % (MNStep * MNPack); + const int iKRepeat = k / (KStep * KPack); + const int tempk = k % (KStep * KPack); + const int outputIndex = + (iMNRepeat * MNPack * MNStep) * (KStep * KPack * K0) + + (iKRepeat * KStep * KPack) * (MNStep * MNPack) + + tempmn * (KStep * KPack) + + tempk; + ScaleType value{}; + if(mn < actual_rows) + { + if constexpr(KStride) + value = src[mn * KScale + k]; + else + value = src[k * actual_rows + mn]; + } + dst[outputIndex] = value; +} + +template +void preShuffleScaleBuffer_gfx1250(const ScaleType* src, + ScaleType* dst, + int actual_rows, + int output_rows, + int KScale, + hipStream_t stream) +{ + constexpr int KPerXdlops = 128; + constexpr int KStep = KPerXdlops / ScaleBlockSize; // 4 + if(KScale % KStep != 0) + { + NVTE_ERROR("preshuffle_scale_gfx1250: KScale must be a multiple of 4, " + "i.e. 
original K must be a multiple of 128 for ScaleBlockSize=32."); + } + const int total = output_rows * KScale; + constexpr int block_size = 256; + const int grid_size = (total + block_size - 1) / block_size; + hipLaunchKernelGGL((preshuffle_scale_gfx1250_kernel), + dim3(grid_size), + dim3(block_size), + 0, + stream, + src, + dst, + actual_rows, + output_rows, + KScale); + NVTE_CHECK_CUDA(hipGetLastError()); +} + +template +bool invoke_mx_grouped_gemm(const std::vector& descs, const GroupedGemmRunContext& ctx, const ck_tile::stream_config& stream_cfg) +{ + // check hardware WMMA support for the warp tile + static constexpr bool has_wmma_support = + ck_tile::has_wmma_traits_v; + + NVTE_CHECK(has_wmma_support, + "ck_tile_mx_grouped_gemm: unsupported gfx125 WMMA traits for " + "AType/BType/AccType with warp tile shape ", + MXFP8GemmConfig::M_Warp_Tile, "x", + MXFP8GemmConfig::N_Warp_Tile, "x", + MXFP8GemmConfig::K_Warp_Tile); + + using CLayout = RowMajor; + constexpr bool preshuffle = false; + constexpr bool DoubleSmemBuffer = true; // TDM pipeline requires double smem buffer + constexpr bool TransposeC = + std::is_same_v && + MXFP8GemmConfig::M_Warp_Tile == MXFP8GemmConfig::N_Warp_Tile; + static constexpr bool StructuredSparsity = false; + static constexpr bool NumWaveGroup = 1; + constexpr ck_tile::index_t TileParitionerGroupNum = 8; + constexpr ck_tile::index_t TileParitionerM01 = 4; + using GemmShape = + ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence>; + using TilePartitioner = ck_tile:: + GemmSpatiallyLocalTilePartitioner; + TRANSFORMER_ENGINE_SWITCH_CONDITION(ctx.transA, kTransA, { + using ALayout = std::conditional_t; + TRANSFORMER_ENGINE_SWITCH_CONDITION(ctx.transB, kTransB, { + using BLayout = std::conditional_t; + using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; + using UniversalGemmProblem = + ck_tile::MxGemmPipelineProblem; + using PipelineType = MxGemmPipelineType::CompTDMV1; + /* make pipeline selective */ + using GemmPipeline = + typename MxGemmPipelineTypeSelector::pipeline; + using GemmEpilogue = ck_tile::TdmEpilogue< + ck_tile::CShuffleEpilogueProblem,//DsDataType + float, + CType, + ck_tile::tuple<>,//DsLayout + CLayout, + ck_tile::element_wise::PassThrough, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + MXFP8GemmConfig::M_Warp, + MXFP8GemmConfig::N_Warp, + MXFP8GemmConfig::M_Warp_Tile, + MXFP8GemmConfig::N_Warp_Tile, + MXFP8GemmConfig::K_Warp_Tile, + UniversalGemmProblem::TransposeC, + 1, /*kNumWaveGroups_*/ + false, /*FixedVectorSize_*/ + 1, /*VectorSizeC_*/ + false, /*TiledMMAPermuteN_*/ + 1, /*BlockedXDLN_PerWarp_*/ + DoubleSmemBuffer, /*DoubleSmemBuffer*/ + AType, /*AType_*/ + BType /*BType_*/>>; + using Kernel = ck_tile::MxGroupedGemmKernel; + + if (!has_sufficient_workspace(ctx)) { + return false; + } + + auto kargs = Kernel::MakeKargs(descs); + if(!Kernel::IsSupportedArgument(kargs)) + { + NVTE_WARN("ck_tile_mx_grouped_gemm: CK_Tile kernel arguments not supported for this config. 
" + "Falling back."); + return false; + } + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(kargs); + NVTE_CHECK_CUDA(hipMemcpyAsync(ctx.workspace, + kargs.data(), + kargs.size() * sizeof(typename decltype(kargs)::value_type), + hipMemcpyHostToDevice, + ctx.stream)); + ck_tile::ignore = ck_tile::launch_kernel( + stream_cfg, ck_tile::make_kernel( + Kernel{}, grids, blocks, 0, + ck_tile::cast_pointer_to_constant_address_space(ctx.workspace), + kargs.size())); + return true; + }); + }); + return false; +} + +bool ck_tile_mx_grouped_gemm(const NVTETensor* A, + const NVTETensor* B, + NVTETensor* D, + int group_num, + bool transA, + bool transB, + NVTETensor* workspace, + bool accumulate,//ignored for now + hipStream_t stream) { + if (group_num <= 0) { + return true; + } + + // Normalize input mats + // I.e., swap A and B, as well as transa and transb. + const NVTETensor* A_use = B; + const NVTETensor* B_use = A; + bool transA_use = transB; + bool transB_use = transA; + + // Validate scale type / data type combination + // Expected input data format: fp8/bf8 (e4m3/e5m2) + // Expected scale data format: e8m0 + const auto* A0 = convertNVTETensorCheck(A_use[0]); + const auto* B0 = convertNVTETensorCheck(B_use[0]); + const auto* D0 = convertNVTETensorCheck(D[0]); + NVTE_CHECK(A0->scale_inv.has_data(), "ck_tile_mx_grouped_gemm: A[0] scale_inv is not initialized"); + NVTE_CHECK(B0->scale_inv.has_data(), "ck_tile_mx_grouped_gemm: B[0] scale_inv is not initialized"); + + const auto a_scale_dtype = A0->scale_inv.dtype; + const auto b_scale_dtype = B0->scale_inv.dtype; + NVTE_CHECK(a_scale_dtype == DType::kFloat8E8M0, + "ck_tile_mx_grouped_gemm: A scale_inv dtype must be Float8E8M0, got ", + static_cast(a_scale_dtype)); + + NVTE_CHECK(b_scale_dtype == DType::kFloat8E8M0, + "ck_tile_mx_grouped_gemm: B scale_inv dtype must be Float8E8M0, got ", + static_cast(b_scale_dtype)); + + const auto a_dtype = A0->dtype(); + const auto b_dtype = B0->dtype(); + const auto d_dtype = D0->dtype(); + NVTE_CHECK(is_fp8_dtype(a_dtype), "ck_tile_mx_grouped_gemm: A dtype must be FP8"); + NVTE_CHECK(is_fp8_dtype(b_dtype), "ck_tile_mx_grouped_gemm: B dtype must be FP8"); + + using AScaleType = typename TETypeToCKType::type; + using BScaleType = typename TETypeToCKType::type; + + void* ws_ptr = nullptr; + size_t ws_bytes = 0; + if (workspace) { + auto* ws_te = convertNVTETensorCheck(*workspace); + ws_ptr = ws_te->data.dptr; + ws_bytes = ws_te->data.numel() * typeToSize(ws_te->data.dtype); + } + + GroupedGemmRunContext ctx = { + A_use, + B_use, + D, + group_num, + transA_use, + transB_use, + ws_ptr, + ws_bytes, + stream}; + + const ck_tile::stream_config s{ctx.stream}; + + std::vector descs; + descs.reserve(group_num); + + std::vector> a_scale_shuffled_bufs; + std::vector> b_scale_shuffled_bufs; + a_scale_shuffled_bufs.reserve(group_num); + b_scale_shuffled_bufs.reserve(group_num); + + for (int i = 0; i < group_num; i++) { + const transformer_engine::Tensor* const A_te = + transformer_engine::convertNVTETensorCheck(ctx.A[i]); + const transformer_engine::Tensor* const B_te = + transformer_engine::convertNVTETensorCheck(ctx.B[i]); + transformer_engine::Tensor* D_te = + transformer_engine::convertNVTETensorCheck(ctx.D[i]); + const auto& a = data_view(*A_te); + const auto& b = data_view(*B_te); + const auto& d = data_view(*D_te); + int64_t Ad0 = 0, Ad1 = 0, Bd0 = 0, Bd1 = 0, Dd0 = 0, Dd1 = 0; + if (!get_flat_2d_dims(*A_te, Ad0, Ad1) || + !get_flat_2d_dims(*B_te, Bd0, Bd1) || + !get_flat_2d_dims(*D_te, 
Dd0, Dd1)) { + NVTE_ERROR("ck_tile_mx_grouped_gemm: expected all groups to be rank>=2."); + } + const auto& a_scales = scale_inv_view(*A_te); + const auto& b_scales = scale_inv_view(*B_te); + if (a_scales.shape.size() != 2 || b_scales.shape.size() != 2) { + NVTE_ERROR("ck_tile_mx_grouped_gemm: expected A/B scale_inv tensors to be rank-2."); + } + const int64_t M = ctx.transA ? Ad1 : Ad0; + const int64_t K = ctx.transA ? Ad0 : Ad1; + const int64_t N = ctx.transB ? Bd0 : Bd1; + const int64_t Kb = ctx.transB ? Bd1 : Bd0; + if (K % ScaleBlockSize != 0) { + NVTE_ERROR("ck_tile_mx_grouped_gemm: K must be a multiple of ScaleBlockSize for MX GEMM", i); + } + const int KScale = static_cast(K / ScaleBlockSize); + if (Kb != K) { + NVTE_ERROR("ck_tile_mx_grouped_gemm: K mismatch between A and B in group ", i); + } + if (Dd0 != M || Dd1 != N) { + NVTE_ERROR("ck_tile_mx_grouped_gemm: D shape mismatch in group ", i); + } + const ck_tile::index_t stride_A = static_cast(Ad1); + const ck_tile::index_t stride_B = static_cast(Bd1); + const ck_tile::index_t stride_E = static_cast(Dd1); + // Pre-shuffle scale buffers for the hardware + const int a_scale_actual_rows = static_cast(M); + const int a_scale_output_rows = + ck_tile::integer_least_multiple( + static_cast(M), + static_cast(GroupedGemKernelParam_Wmma::M_Warp_Tile)); + const int b_scale_actual_rows = static_cast(N); + const int b_scale_output_rows = static_cast(N); + const size_t a_scale_shuffled_bytes = + static_cast(a_scale_output_rows) * + static_cast(KScale) * + sizeof(AScaleType); + const size_t b_scale_shuffled_bytes = + static_cast(b_scale_output_rows) * + static_cast(KScale) * + sizeof(BScaleType); + a_scale_shuffled_bufs.push_back( + std::make_unique(a_scale_shuffled_bytes)); + b_scale_shuffled_bufs.push_back( + std::make_unique(b_scale_shuffled_bytes)); + void* a_scale_shuffled_ptr = a_scale_shuffled_bufs.back()->GetDeviceBuffer(); + void* b_scale_shuffled_ptr = b_scale_shuffled_bufs.back()->GetDeviceBuffer(); + preShuffleScaleBuffer_gfx1250( + reinterpret_cast(a_scales.dptr), + reinterpret_cast(a_scale_shuffled_ptr), + a_scale_actual_rows, + a_scale_output_rows, + KScale, + stream); + preShuffleScaleBuffer_gfx1250( + reinterpret_cast(b_scales.dptr), + reinterpret_cast(b_scale_shuffled_ptr), + b_scale_actual_rows, + b_scale_output_rows, + KScale, + stream); + descs.emplace_back(mx_grouped_gemm_kargs( + a.dptr, + a_scale_shuffled_ptr, + b.dptr, + b_scale_shuffled_ptr, + {/*ds_ptr*/}, + d.dptr, + 1,//kbatch + M, + N, + K, + stride_A, + stride_B, + {/*stride_Ds*/}, + stride_E)); + } + // invoke gemm + bool ok = false; + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(a_dtype, a_te_type, { + using AType = typename TETypeToCKType::type; + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(b_dtype, b_te_type, { + using BType = typename TETypeToCKType::type; + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(d_dtype, d_te_type, { + using CType = typename TETypeToCKType::type; + ok = invoke_mx_grouped_gemm(descs,ctx,s); + }); + }); + }); + return ok; +} From e102f00a1e3c73386a9c79a0b990c52203a7274f Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Sat, 25 Apr 2026 17:38:06 +0000 Subject: [PATCH 02/13] ck mxfp8 gfx1250 integration builds successfully --- .../ck_mx_grouped_gemm/ck_mx_grouped_gemm.cpp | 33 +++++++++++++++---- 1 file changed, 26 insertions(+), 7 deletions(-) diff --git a/transformer_engine/common/gemm/ck_mx_grouped_gemm/ck_mx_grouped_gemm.cpp b/transformer_engine/common/gemm/ck_mx_grouped_gemm/ck_mx_grouped_gemm.cpp index c7746b563..1ced0d61e 100644 --- 
a/transformer_engine/common/gemm/ck_mx_grouped_gemm/ck_mx_grouped_gemm.cpp +++ b/transformer_engine/common/gemm/ck_mx_grouped_gemm/ck_mx_grouped_gemm.cpp @@ -14,6 +14,9 @@ #include "ck_tile/ops/gemm/kernel/mx_grouped_gemm_kernel.hpp" #include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" +namespace transformer_engine { +namespace mx_grouped_gemm { + using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; using ColMajor = ck_tile::tensor_layout::gemm::ColumnMajor; using mx_grouped_gemm_kargs = ck_tile::MxGroupedGemmHostArgs<>; @@ -21,7 +24,6 @@ using mx_grouped_gemm_kargs = ck_tile::MxGroupedGemmHostArgs<>; template struct TETypeToCKType; template <> struct TETypeToCKType { using type = ck_tile::fp8_t; }; template <> struct TETypeToCKType { using type = ck_tile::bf8_t; }; -template <> struct TETypeToCKType { using type = ck_tile::e8m0_t; }; template <> struct TETypeToCKType { using type = ck_tile::half_t; }; template <> struct TETypeToCKType { using type = ck_tile::bfloat16_t; }; template <> struct TETypeToCKType { using type = float; }; @@ -41,6 +43,18 @@ struct GroupedGemmRunContext { }; +// Treat TE tensors as generalized 2D matrices by flattening: +// (D1, D2, ..., Dn) -> (D1*...*D(n-1), Dn), consistent with TE Tensor::flat_*_dim. +static inline bool get_flat_2d_dims(const transformer_engine::Tensor& t, + int64_t& d0, int64_t& d1) { + if (t.shape().size() < 2) { + return false; + } + d0 = static_cast(t.flat_first_dim()); + d1 = static_cast(t.flat_last_dim()); + return true; +} + static constexpr ck_tile::index_t ScaleBlockSize = 32; enum struct MxGemmPipelineType @@ -247,10 +261,10 @@ bool invoke_mx_grouped_gemm(const std::vector& descs, con BType, AScaleType, BScaleType>; - using PipelineType = MxGemmPipelineType::CompTDMV1; /* make pipeline selective */ using GemmPipeline = - typename MxGemmPipelineTypeSelector::pipeline; + typename MxGemmPipelineTypeSelector::pipeline; using GemmEpilogue = ck_tile::TdmEpilogue< ck_tile::CShuffleEpilogueProblemscale_inv.has_data(), "ck_tile_mx_grouped_gemm: A[0] scale_inv is not initialized"); - NVTE_CHECK(B0->scale_inv.has_data(), "ck_tile_mx_grouped_gemm: B[0] scale_inv is not initialized"); + NVTE_CHECK(A0->scale_inv.dptr != nullptr, + "ck_tile_mx_grouped_gemm: A[0] scale_inv is not initialized"); + NVTE_CHECK(B0->scale_inv.dptr != nullptr, + "ck_tile_mx_grouped_gemm: B[0] scale_inv is not initialized"); const auto a_scale_dtype = A0->scale_inv.dtype; const auto b_scale_dtype = B0->scale_inv.dtype; @@ -352,8 +368,8 @@ bool ck_tile_mx_grouped_gemm(const NVTETensor* A, NVTE_CHECK(is_fp8_dtype(a_dtype), "ck_tile_mx_grouped_gemm: A dtype must be FP8"); NVTE_CHECK(is_fp8_dtype(b_dtype), "ck_tile_mx_grouped_gemm: B dtype must be FP8"); - using AScaleType = typename TETypeToCKType::type; - using BScaleType = typename TETypeToCKType::type; + using AScaleType = ck_tile::e8m0_t; + using BScaleType = ck_tile::e8m0_t; void* ws_ptr = nullptr; size_t ws_bytes = 0; @@ -490,3 +506,6 @@ bool ck_tile_mx_grouped_gemm(const NVTETensor* A, }); return ok; } + +} // namespace mx_grouped_gemm +} // namespace transformer_engine \ No newline at end of file From 52a28875302c38dcfbc6ae128f4024be4d56a2e9 Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Sat, 25 Apr 2026 18:11:08 +0000 Subject: [PATCH 03/13] add entrypoint to ck mx group gemm in caller --- .../ck_mx_grouped_gemm/ck_mx_grouped_gemm.cpp | 15 ++++++++- .../ck_mx_grouped_gemm/ck_mx_grouped_gemm.hpp | 16 ++++++++++ .../common/gemm/cublaslt_gemm.cu | 32 +++++++++++-------- 3 files changed, 49 
insertions(+), 14 deletions(-) create mode 100644 transformer_engine/common/gemm/ck_mx_grouped_gemm/ck_mx_grouped_gemm.hpp diff --git a/transformer_engine/common/gemm/ck_mx_grouped_gemm/ck_mx_grouped_gemm.cpp b/transformer_engine/common/gemm/ck_mx_grouped_gemm/ck_mx_grouped_gemm.cpp index 1ced0d61e..b05b844f8 100644 --- a/transformer_engine/common/gemm/ck_mx_grouped_gemm/ck_mx_grouped_gemm.cpp +++ b/transformer_engine/common/gemm/ck_mx_grouped_gemm/ck_mx_grouped_gemm.cpp @@ -508,4 +508,17 @@ bool ck_tile_mx_grouped_gemm(const NVTETensor* A, } } // namespace mx_grouped_gemm -} // namespace transformer_engine \ No newline at end of file +} // namespace transformer_engine + +bool ck_tile_mx_grouped_gemm(const NVTETensor* A, + const NVTETensor* B, + NVTETensor* D, + int group_num, + bool transA, + bool transB, + NVTETensor* workspace, + bool accumulate, + hipStream_t stream) { + return transformer_engine::mx_grouped_gemm::ck_tile_mx_grouped_gemm( + A, B, D, group_num, transA, transB, workspace, accumulate, stream); +} diff --git a/transformer_engine/common/gemm/ck_mx_grouped_gemm/ck_mx_grouped_gemm.hpp b/transformer_engine/common/gemm/ck_mx_grouped_gemm/ck_mx_grouped_gemm.hpp new file mode 100644 index 000000000..96d3cd11b --- /dev/null +++ b/transformer_engine/common/gemm/ck_mx_grouped_gemm/ck_mx_grouped_gemm.hpp @@ -0,0 +1,16 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. See LICENSE for more information + ************************************************************************/ + +bool ck_tile_mx_grouped_gemm(const NVTETensor* A, + const NVTETensor* B, + NVTETensor* D, + int group_num, + bool transA, + bool transB, + NVTETensor* workspace, + bool accumulate, + hipStream_t stream); + diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 7326f330f..1d030686b 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -33,7 +33,8 @@ #include "./cutlass_grouped_gemm.cuh" #else #include "ck_grouped_gemm/ck_grouped_gemm.h" -#endif +#include "ck_mx_grouped_gemm/ck_mx_grouped_gemm.hpp" + #ifndef __HIP_PLATFORM_AMD__ namespace { @@ -1163,13 +1164,7 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor auto A_dt = inputA->data.dtype; auto B_dt = inputB->data.dtype; auto D_dt = OutputD->data.dtype; - return ( - (is_fp8_dtype(A_dt) && is_fp8_dtype(B_dt)) - ) || - ( - (A_dt == B_dt) && (A_dt == D_dt) && - (is_fp16_dtype(A_dt)) - ); + return (is_fp8_dtype(A_dt) && is_fp8_dtype(B_dt)); #else auto A_type = get_cuda_dtype(inputA->data.dtype); auto B_type = get_cuda_dtype(inputB->data.dtype); @@ -1192,11 +1187,22 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor if (is_empty_arr(bias) && is_empty_arr(pre_gelu_out) && is_supported_dtype() && #ifdef __HIP_PLATFORM_AMD__ true) { - if (!ck_tile_grouped_gemm(A, B, D, num_gemms, transa, transb, workspace, accumulate, stream)) { - if (warn_fallback) { - NVTE_WARN("Fallback to cuBLAS grouped GEMM."); - } - cublas_path(); + const bool mxfp8_gemm = !use_fp4 && is_mxfp8_scaling(inputA->scaling_mode); + + bool handled_by_ck = false; + if (mxfp8_gemm) { + handled_by_ck = ck_tile_mx_grouped_gemm( + A, B, D, num_gemms, transa, transb, workspace, accumulate, stream); + } else { + handled_by_ck = ck_tile_grouped_gemm( + A, B, D, num_gemms, 
transa, transb, workspace, accumulate, stream); + } + + if (!handled_by_ck) { + if (warn_fallback) { + NVTE_WARN("Fallback to cuBLAS grouped GEMM."); + } + cublas_path(); } #else all_groups_uniform_k128(B, transb)) { From 80227775d296c6d140957110582b2a46d07de1dd Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Sat, 25 Apr 2026 18:25:46 +0000 Subject: [PATCH 04/13] temporary hacky change to test_numerics for bringup testing --- tests/pytorch/test_numerics.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index d9c7d1fb0..642999ef7 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -2253,6 +2253,7 @@ def test_grouped_linear_accuracy_cutlass( delay_wgrad_compute, ): os.environ["NVTE_USE_CUTLASS_GROUPED_GEMM"] = "1" + os.environ["NVTE_ROCM_ENABLE_MXFP8"] = "1" test_grouped_linear_accuracy( dtype, num_gemms, @@ -2268,6 +2269,7 @@ def test_grouped_linear_accuracy_cutlass( use_cutlass=True, ) os.environ.pop("NVTE_USE_CUTLASS_GROUPED_GEMM", None) + os.environ.pop("NVTE_ROCM_ENABLE_MXFP8", None) @pytest.mark.parametrize("dtype", param_types, ids=str) From bc6253d013e0dbc19ff22539cb426def9531521f Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Sat, 25 Apr 2026 19:34:12 +0000 Subject: [PATCH 05/13] add warning print to confirm we are in fallback --- transformer_engine/common/gemm/cublaslt_gemm.cu | 3 +++ 1 file changed, 3 insertions(+) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 1d030686b..9a45b6e10 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -1125,6 +1125,9 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor // Currently only support cutlass group gemm on Hopper Arch if (!(is_hopper && use_cutlass)) { #endif + if (warn_fallback) { + NVTE_WARN("Fallback to cuBLAS grouped GEMM."); + } cublas_path(); return; } From d26f52e9715d81bffa5cb6769e418ebb568445f0 Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Sat, 2 May 2026 16:34:30 +0000 Subject: [PATCH 06/13] MXFP8 grouped fwd/bwd now reaches CK path and runs without fallback/crash; remaining issue is numerical validation vs BF16 sequential reference. --- .../ck_mx_grouped_gemm/ck_mx_grouped_gemm.cpp | 168 +++++++++++++----- .../common/gemm/cublaslt_gemm.cu | 20 ++- 2 files changed, 140 insertions(+), 48 deletions(-) diff --git a/transformer_engine/common/gemm/ck_mx_grouped_gemm/ck_mx_grouped_gemm.cpp b/transformer_engine/common/gemm/ck_mx_grouped_gemm/ck_mx_grouped_gemm.cpp index b05b844f8..59ae7c0ff 100644 --- a/transformer_engine/common/gemm/ck_mx_grouped_gemm/ck_mx_grouped_gemm.cpp +++ b/transformer_engine/common/gemm/ck_mx_grouped_gemm/ck_mx_grouped_gemm.cpp @@ -41,6 +41,8 @@ struct GroupedGemmRunContext { size_t workspace_bytes = 0; hipStream_t stream = nullptr; + bool use_a_colwise_data = false; + bool use_b_colwise_data = false; }; // Treat TE tensors as generalized 2D matrices by flattening: @@ -55,6 +57,19 @@ static inline bool get_flat_2d_dims(const transformer_engine::Tensor& t, return true; } +// Columnwise storage is the physical transposed view used to rewrite a +// normalized GEMM into CK's preferred NT presentation. Interpret its +// 2D shape consistently with the FP8 grouped GEMM path. 
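+// Illustrative example (shapes assumed for clarity, not taken from a caller):
+// if a group's columnwise storage has shape {K, M}, this helper reports
+// d0 = shape[1] = M and d1 = shape[0] = K, i.e. the logical row/column
+// extents rather than the physical storage order.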
+static inline bool get_columnwise_storage_2d_dims(const transformer_engine::SimpleTensor& t, + int64_t& d0, int64_t& d1) { + if (t.shape.size() != 2) { + return false; + } + d0 = static_cast(t.shape[1]); + d1 = static_cast(t.shape[0]); + return true; +} + static constexpr ck_tile::index_t ScaleBlockSize = 32; enum struct MxGemmPipelineType @@ -81,19 +96,11 @@ struct MxGemmPipelineTypeSelector static constexpr auto GetName() { return "GemmPipelineAgBgCrCompTDMV2"; } }; -static inline const transformer_engine::SimpleTensor& data_view(const transformer_engine::Tensor& t) { - return t.data; -} - -static inline const transformer_engine::SimpleTensor& scale_inv_view(const transformer_engine::Tensor& t) { - return t.scale_inv; -} - template static inline bool has_sufficient_workspace(const GroupedGemmRunContext& ctx) { const size_t needed = Kernel::GetWorkSpaceSize(ctx.group_num); if (!ctx.workspace || ctx.workspace_bytes < needed) { - NVTE_WARN("ck_tile_grouped_gemm: insufficient workspace for CK path. Needed bytes=", needed, + NVTE_WARN("ck_tile_mx_grouped_gemm: insufficient workspace for CK path. Needed bytes=", needed, ", available bytes=", ctx.workspace_bytes, ". Falling back."); return false; } @@ -161,11 +168,11 @@ __global__ void preshuffle_scale_gfx1250_kernel(const ScaleType* __restrict__ sr template void preShuffleScaleBuffer_gfx1250(const ScaleType* src, - ScaleType* dst, - int actual_rows, - int output_rows, - int KScale, - hipStream_t stream) + ScaleType* dst, + int actual_rows, + int output_rows, + int KScale, + hipStream_t stream) { constexpr int KPerXdlops = 128; constexpr int KStep = KPerXdlops / ScaleBlockSize; // 4 @@ -334,39 +341,75 @@ bool ck_tile_mx_grouped_gemm(const NVTETensor* A, return true; } - // Normalize input mats + // Normalize input mats similar to the FP8 grouped path. // I.e., swap A and B, as well as transa and transb. const NVTETensor* A_use = B; const NVTETensor* B_use = A; bool transA_use = transB; bool transB_use = transA; - // Validate scale type / data type combination + bool use_a_colwise_data = false; + bool use_b_colwise_data = false; + + Tensor* A0_te = convertNVTETensorCheck(A_use[0]); + Tensor* B0_te = convertNVTETensorCheck(B_use[0]); + + // CK MX grouped GEMM is presented as normalized NT, matching the FP8 grouped path. + // Selecting columnwise_data rewrites the physical storage and effective dims used by CK + // while preserving the original math. + if (transA_use) { + if (!A0_te->has_columnwise_data() || A0_te->columnwise_scale_inv.dptr == nullptr) { + NVTE_WARN("ck_tile_mx_grouped_gemm: missing A columnwise MXFP8 view for NT rewrite; falling back."); + return false; + } + use_a_colwise_data = true; + transA_use = false; + } + + if (!transB_use) { + if (!B0_te->has_columnwise_data() || B0_te->columnwise_scale_inv.dptr == nullptr) { + NVTE_WARN("ck_tile_mx_grouped_gemm: missing B columnwise MXFP8 view for NT rewrite; falling back."); + return false; + } + use_b_colwise_data = true; + transB_use = true; + } + + // Validate scale type / data type combination using the effective storage + // selected by the NT canonicalization above. 
// Expected input data format: fp8/bf8 (e4m3/e5m2) // Expected scale data format: e8m0 - const auto* A0 = convertNVTETensorCheck(A_use[0]); - const auto* B0 = convertNVTETensorCheck(B_use[0]); const auto* D0 = convertNVTETensorCheck(D[0]); - NVTE_CHECK(A0->scale_inv.dptr != nullptr, - "ck_tile_mx_grouped_gemm: A[0] scale_inv is not initialized"); - NVTE_CHECK(B0->scale_inv.dptr != nullptr, - "ck_tile_mx_grouped_gemm: B[0] scale_inv is not initialized"); - const auto a_scale_dtype = A0->scale_inv.dtype; - const auto b_scale_dtype = B0->scale_inv.dtype; + const auto& A0_data = use_a_colwise_data ? A0_te->columnwise_data : A0_te->data; + const auto& B0_data = use_b_colwise_data ? B0_te->columnwise_data : B0_te->data; + const auto& A0_scale = use_a_colwise_data ? A0_te->columnwise_scale_inv : A0_te->scale_inv; + const auto& B0_scale = use_b_colwise_data ? B0_te->columnwise_scale_inv : B0_te->scale_inv; + + NVTE_CHECK(A0_data.dptr != nullptr, + "ck_tile_mx_grouped_gemm: effective A[0] data is not initialized"); + NVTE_CHECK(B0_data.dptr != nullptr, + "ck_tile_mx_grouped_gemm: effective B[0] data is not initialized"); + NVTE_CHECK(A0_scale.dptr != nullptr, + "ck_tile_mx_grouped_gemm: effective A[0] scale_inv is not initialized"); + NVTE_CHECK(B0_scale.dptr != nullptr, + "ck_tile_mx_grouped_gemm: effective B[0] scale_inv is not initialized"); + + const auto a_scale_dtype = A0_scale.dtype; + const auto b_scale_dtype = B0_scale.dtype; NVTE_CHECK(a_scale_dtype == DType::kFloat8E8M0, "ck_tile_mx_grouped_gemm: A scale_inv dtype must be Float8E8M0, got ", static_cast(a_scale_dtype)); - + NVTE_CHECK(b_scale_dtype == DType::kFloat8E8M0, "ck_tile_mx_grouped_gemm: B scale_inv dtype must be Float8E8M0, got ", static_cast(b_scale_dtype)); - - const auto a_dtype = A0->dtype(); - const auto b_dtype = B0->dtype(); + + const auto a_dtype = A0_data.dtype; + const auto b_dtype = B0_data.dtype; const auto d_dtype = D0->dtype(); - NVTE_CHECK(is_fp8_dtype(a_dtype), "ck_tile_mx_grouped_gemm: A dtype must be FP8"); - NVTE_CHECK(is_fp8_dtype(b_dtype), "ck_tile_mx_grouped_gemm: B dtype must be FP8"); + NVTE_CHECK(is_fp8_dtype(a_dtype), "ck_tile_mx_grouped_gemm: effective A dtype must be FP8"); + NVTE_CHECK(is_fp8_dtype(b_dtype), "ck_tile_mx_grouped_gemm: effective B dtype must be FP8"); using AScaleType = ck_tile::e8m0_t; using BScaleType = ck_tile::e8m0_t; @@ -378,7 +421,7 @@ bool ck_tile_mx_grouped_gemm(const NVTETensor* A, ws_ptr = ws_te->data.dptr; ws_bytes = ws_te->data.numel() * typeToSize(ws_te->data.dtype); } - + GroupedGemmRunContext ctx = { A_use, B_use, @@ -388,7 +431,9 @@ bool ck_tile_mx_grouped_gemm(const NVTETensor* A, transB_use, ws_ptr, ws_bytes, - stream}; + stream, + use_a_colwise_data, + use_b_colwise_data}; const ck_tile::stream_config s{ctx.stream}; @@ -407,20 +452,48 @@ bool ck_tile_mx_grouped_gemm(const NVTETensor* A, transformer_engine::convertNVTETensorCheck(ctx.B[i]); transformer_engine::Tensor* D_te = transformer_engine::convertNVTETensorCheck(ctx.D[i]); - const auto& a = data_view(*A_te); - const auto& b = data_view(*B_te); - const auto& d = data_view(*D_te); + + const auto& a = ctx.use_a_colwise_data ? A_te->columnwise_data : A_te->data; + const auto& b = ctx.use_b_colwise_data ? B_te->columnwise_data : B_te->data; + const auto& d = D_te->data; + const auto& a_scales = + ctx.use_a_colwise_data ? A_te->columnwise_scale_inv : A_te->scale_inv; + const auto& b_scales = + ctx.use_b_colwise_data ? 
B_te->columnwise_scale_inv : B_te->scale_inv; + int64_t Ad0 = 0, Ad1 = 0, Bd0 = 0, Bd1 = 0, Dd0 = 0, Dd1 = 0; - if (!get_flat_2d_dims(*A_te, Ad0, Ad1) || - !get_flat_2d_dims(*B_te, Bd0, Bd1) || - !get_flat_2d_dims(*D_te, Dd0, Dd1)) { - NVTE_ERROR("ck_tile_mx_grouped_gemm: expected all groups to be rank>=2."); + + if (ctx.use_a_colwise_data) { + if (!get_columnwise_storage_2d_dims(A_te->columnwise_data, Ad0, Ad1)) { + NVTE_ERROR("ck_tile_mx_grouped_gemm: expected 2D columnwise_data for A in group ", i); + } + } else { + if (!get_flat_2d_dims(*A_te, Ad0, Ad1)) { + NVTE_ERROR("ck_tile_mx_grouped_gemm: expected rank>=2 for normalized A in group ", i); + } + } + + if (ctx.use_b_colwise_data) { + if (!get_columnwise_storage_2d_dims(B_te->columnwise_data, Bd0, Bd1)) { + NVTE_ERROR("ck_tile_mx_grouped_gemm: expected 2D columnwise_data for B in group ", i); + } + } else { + if (!get_flat_2d_dims(*B_te, Bd0, Bd1)) { + NVTE_ERROR("ck_tile_mx_grouped_gemm: expected rank>=2 for normalized B in group ", i); + } + } + + if (!get_flat_2d_dims(*D_te, Dd0, Dd1)) { + NVTE_ERROR("ck_tile_mx_grouped_gemm: expected rank>=2 for normalized D in group ", i); + } + if (a.dptr == nullptr || b.dptr == nullptr || a_scales.dptr == nullptr || + b_scales.dptr == nullptr) { + NVTE_ERROR("ck_tile_mx_grouped_gemm: effective A/B data or scale_inv is missing."); } - const auto& a_scales = scale_inv_view(*A_te); - const auto& b_scales = scale_inv_view(*B_te); if (a_scales.shape.size() != 2 || b_scales.shape.size() != 2) { - NVTE_ERROR("ck_tile_mx_grouped_gemm: expected A/B scale_inv tensors to be rank-2."); + NVTE_ERROR("ck_tile_mx_grouped_gemm: expected effective A/B scale_inv tensors to be rank-2."); } + const int64_t M = ctx.transA ? Ad1 : Ad0; const int64_t K = ctx.transA ? Ad0 : Ad1; const int64_t N = ctx.transB ? Bd0 : Bd1; @@ -430,15 +503,20 @@ bool ck_tile_mx_grouped_gemm(const NVTETensor* A, } const int KScale = static_cast(K / ScaleBlockSize); if (Kb != K) { - NVTE_ERROR("ck_tile_mx_grouped_gemm: K mismatch between A and B in group ", i); + NVTE_ERROR("ck_tile_mx_grouped_gemm: K mismatch between A and B in group ", i, + ". op(A)=", M, "x", K, ", op(B)=", Kb, "x", N); } if (Dd0 != M || Dd1 != N) { - NVTE_ERROR("ck_tile_mx_grouped_gemm: D shape mismatch in group ", i); + NVTE_ERROR("ck_tile_mx_grouped_gemm: D shape mismatch in group ", i, + ". D=", Dd0, "x", Dd1, ", expected=", M, "x", N); } + const ck_tile::index_t stride_A = static_cast(Ad1); const ck_tile::index_t stride_B = static_cast(Bd1); const ck_tile::index_t stride_E = static_cast(Dd1); - // Pre-shuffle scale buffers for the hardware + + // Pre-shuffle scale buffers for the hardware. + // For the NT-normalized presentation, A scales are MxKScale and B scales are NxKScale. 
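+    // Illustrative sizing example (numbers assumed for clarity only): with
+    // M = 100 and K = 4096, KScale = 4096 / 32 = 128 and the A scale rows are
+    // padded from 100 up to the next multiple of M_Warp_Tile (32), i.e. 128,
+    // so the shuffled A scale buffer holds 128 * 128 one-byte e8m0 entries.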
const int a_scale_actual_rows = static_cast(M); const int a_scale_output_rows = ck_tile::integer_least_multiple( diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 9a45b6e10..b3863350e 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -1159,15 +1159,29 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor }; #endif +#ifdef __HIP_PLATFORM_AMD__ + auto effective_dtype = [](const transformer_engine::Tensor *t) { + if (t->has_data()) { + return t->data.dtype; + } + if (t->has_columnwise_data()) { + return t->columnwise_data.dtype; + } + return t->data.dtype; + }; +#endif + auto is_supported_dtype = [&]() -> bool { auto *inputA = transformer_engine::convertNVTETensorCheck(A[0]); auto *inputB = transformer_engine::convertNVTETensorCheck(B[0]); auto *OutputD = transformer_engine::convertNVTETensorCheck(D[0]); #ifdef __HIP_PLATFORM_AMD__ - auto A_dt = inputA->data.dtype; - auto B_dt = inputB->data.dtype; + auto A_dt = effective_dtype(inputA); + auto B_dt = effective_dtype(inputB); auto D_dt = OutputD->data.dtype; - return (is_fp8_dtype(A_dt) && is_fp8_dtype(B_dt)); + + return ((is_fp8_dtype(A_dt) && is_fp8_dtype(B_dt)) || + ((A_dt == B_dt) && (A_dt == D_dt) && is_fp16_dtype(A_dt))); #else auto A_type = get_cuda_dtype(inputA->data.dtype); auto B_type = get_cuda_dtype(inputB->data.dtype); From e295e745518927a44b784816f75f25ae9c21a1bc Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Sat, 2 May 2026 20:09:45 +0000 Subject: [PATCH 07/13] add cpp test for ck tile group mxfp8 gemm forward --- tests/cpp/operator/CMakeLists.txt | 6 +- .../test_te_ck_grouped_mxfp8_forward_refs.cu | 554 ++++++++++++++++++ .../ck_mx_grouped_gemm/ck_mx_grouped_gemm.cpp | 20 +- 3 files changed, 578 insertions(+), 2 deletions(-) create mode 100644 tests/cpp/operator/test_te_ck_grouped_mxfp8_forward_refs.cu diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index 0ebd7fdfe..fa9f9a542 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -16,6 +16,7 @@ list(APPEND test_cuda_sources test_dequantize_mxfp8.cu test_dequantize_nvfp4.cu test_cast_nvfp4_transpose.cu + test_te_ck_grouped_mxfp8_forward_refs.cu test_transpose.cu test_cast_transpose.cu test_cast_transpose_current_scaling.cu @@ -31,7 +32,8 @@ list(APPEND test_cuda_sources test_multi_unpadding.cu test_causal_softmax.cu test_swap_first_dims.cu - ../test_common.cu) + ../test_common.cu) + if(USE_CUDA) list(APPEND test_cuda_sources test_cast_float8blockwise.cu @@ -54,12 +56,14 @@ endif() # Find required packages find_package(OpenMP REQUIRED) + if(USE_CUDA) list(APPEND test_operator_LINKER_LIBS CUDA::cudart GTest::gtest_main ${TE_LIB} CUDA::nvrtc CUDNN::cudnn) target_link_libraries(test_operator PUBLIC ${test_operator_LINKER_LIBS} OpenMP::OpenMP_CXX) else() target_link_libraries(test_operator PUBLIC hip::host hip::device GTest::gtest_main ${TE_LIB} OpenMP::OpenMP_CXX hiprand) endif() + target_compile_options(test_operator PRIVATE -O2 -fopenmp) include(GoogleTest) diff --git a/tests/cpp/operator/test_te_ck_grouped_mxfp8_forward_refs.cu b/tests/cpp/operator/test_te_ck_grouped_mxfp8_forward_refs.cu new file mode 100644 index 000000000..0872b1640 --- /dev/null +++ b/tests/cpp/operator/test_te_ck_grouped_mxfp8_forward_refs.cu @@ -0,0 +1,554 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced 
Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. See LICENSE for more information + ************************************************************************/ + +// Forward-only TE CK grouped MXFP8 validation. +// +// Compares three paths for grouped MXFP8 forward GEMM: +// 1. TE nvte_multi_tensor_gemm grouped forward path (CK backend selected by env) +// 2. ck_tile::reference_mx_gemm host reference, using exact quantized operands/scales +// 3. TE HIP reference kernel adapted from test_cublaslt_gemm.cu compute_ref_kernel +// +// Intended drop-in location: +// TransformerEngine/tests/cpp/operator/test_te_ck_grouped_mxfp8_forward_refs.cu + +#ifndef CK_TILE_USE_OCP_FP8 +#define CK_TILE_USE_OCP_FP8 1 +#endif + +#include +#include +#include + +#include +#include +#include + +#include "../test_common.h" + +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace transformer_engine; +using namespace test; + +using fp8 = fp8e4m3; +using bf16_t = bf16; +using e8m0_t_te = fp8e8m0; + +namespace { + +struct CaseConfig { + size_t m_total; + size_t n; + size_t k; + int experts; + float scale; + int seed; + int ck_ref_groups; +}; + +static std::string case_name(const testing::TestParamInfo& info) { + const auto& c = info.param; + std::ostringstream os; + os << "M" << c.m_total << "_N" << c.n << "_K" << c.k + << "_E" << c.experts; + return os.str(); +} + +static void set_env_defaults() { + setenv("NVTE_USE_CUTLASS_GROUPED_GEMM", "1", 1); + setenv("NVTE_CUTLASS_GROUPED_GEMM_WARN_FALLBACK", "1", 0); + setenv("NVTE_ROCM_ENABLE_MXFP8", "1", 0); +} + +static float to_float(float x) { return x; } +static float to_float(const bf16_t& x) { return static_cast(x); } +static float to_float(const ck_tile::bfloat16_t& x) { return static_cast(x); } + +__device__ __host__ __forceinline__ float ref_gelu_unused(float x) { + float cdf = 0.5f * (1.0f + tanhf((0.7978845608028654f * (x + 0.044715f * x * x * x)))); + return x * cdf; +} + +template +__global__ void compute_ref_kernel( + const A_Type* __restrict__ a_data, + const B_Type* __restrict__ b_data, + float a_scale_inv_scalar, + float b_scale_inv_scalar, + const e8m0_t_te* __restrict__ a_scale_inv_mxfp8, + const e8m0_t_te* __restrict__ b_scale_inv_mxfp8, + size_t a_scale_ld, + size_t b_scale_ld, + bool a_scale_is_colwise, + bool b_scale_is_colwise, + const Bias_Type* __restrict__ bias_data, + float d_scale, + size_t m, size_t k, size_t n, + D_Type* __restrict__ d_data, + float* __restrict__ d_amax, + Gelu_Type* __restrict__ gelu_data, + bool transa, + bool transb, + bool is_fp8_output, + bool a_is_colwise, + bool b_is_colwise, + bool use_mxfp8) { + const size_t jj = blockIdx.x * blockDim.x + threadIdx.x; + const size_t ii = blockIdx.y * blockDim.y + threadIdx.y; + const bool in_range = (ii < m) && (jj < n); + + float val = 0.0f; + if (in_range) { + for (size_t kk = 0; kk < k; ++kk) { + size_t a_idx = 0; + size_t b_idx = 0; + + if (use_mxfp8) { + a_idx = transa ? (ii * k + kk) : (kk * m + ii); + b_idx = transb ? (kk * n + jj) : (jj * k + kk); + } else { + a_idx = a_is_colwise ? (ii * k + kk) + : (transa ? (ii * k + kk) : (kk * m + ii)); + b_idx = b_is_colwise ? (jj * k + kk) + : (transb ? 
(kk * n + jj) : (jj * k + kk)); + } + + float a_scale_inv_val = a_scale_inv_scalar; + float b_scale_inv_val = b_scale_inv_scalar; + + if (a_scale_inv_mxfp8) { + const size_t kc = kk / 32; + const size_t a_scale_idx = + a_scale_is_colwise ? (kc * a_scale_ld + ii) : (ii * a_scale_ld + kc); + const size_t b_scale_idx = + b_scale_is_colwise ? (kc * b_scale_ld + jj) : (jj * b_scale_ld + kc); + a_scale_inv_val = exp2f(a_scale_inv_mxfp8[a_scale_idx] - 127.0f); + b_scale_inv_val = exp2f(b_scale_inv_mxfp8[b_scale_idx] - 127.0f); + } + + const float a_val = static_cast(a_data[a_idx]); + const float b_val = static_cast(b_data[b_idx]); + val += a_scale_inv_val * a_val * b_scale_inv_val * b_val; + } + + if (bias_data) val += static_cast(bias_data[ii]); + if (gelu_data) { + gelu_data[ii + jj * m] = static_cast(val); + val = ref_gelu_unused(val); + } + + const float scaled = val * d_scale; + d_data[ii + jj * m] = static_cast(scaled); + } + + if (is_fp8_output && d_amax) { + const int tid = threadIdx.y * blockDim.x + threadIdx.x; + const int nthreads = blockDim.x * blockDim.y; + extern __shared__ float s_amax[]; + s_amax[tid] = in_range ? fabsf(val) : 0.0f; + __syncthreads(); + for (int offset = nthreads / 2; offset > 0; offset /= 2) { + if (tid < offset) s_amax[tid] = fmaxf(s_amax[tid], s_amax[tid + offset]); + __syncthreads(); + } + if (tid == 0) atomicMax(d_amax, s_amax[0]); + } +} + +template +static void fill_randn_cpu(Tensor* t, float scale, int seed) { + std::mt19937 gen(seed); + std::normal_distribution dist(0.0f, scale); + const size_t n = product(t->rowwise_shape()); + T* ptr = t->rowwise_cpu_dptr(); + for (size_t i = 0; i < n; ++i) ptr[i] = static_cast(dist(gen)); + t->from_cpu(); +} + +static std::vector split_even(size_t m_total, int experts) { + NVTE_CHECK(experts > 0, "experts must be > 0"); + NVTE_CHECK(m_total % static_cast(experts) == 0, + "m_total must be divisible by experts"); + return std::vector(experts, m_total / static_cast(experts)); +} + +struct ErrorStats { + size_t count = 0; + double sum_abs = 0.0; + double sum_rel = 0.0; + double sum_ref_abs = 0.0; + double sum_got_abs = 0.0; + float max_abs = 0.0f; + float max_rel = 0.0f; + std::vector abs_errs; +}; + +static void add_err(ErrorStats& s, float got, float ref) { + const float abs_err = std::abs(got - ref); + const float rel_err = abs_err / std::max(std::abs(ref), 1.0e-12f); + s.count++; + s.sum_abs += abs_err; + s.sum_rel += rel_err; + s.sum_ref_abs += std::abs(ref); + s.sum_got_abs += std::abs(got); + s.max_abs = std::max(s.max_abs, abs_err); + s.max_rel = std::max(s.max_rel, rel_err); + s.abs_errs.push_back(abs_err); +} + +static float quantile(std::vector& values, double q) { + if (values.empty()) return 0.0f; + const size_t pos = std::min(static_cast(q * (values.size() - 1)), values.size() - 1); + std::nth_element(values.begin(), values.begin() + pos, values.end()); + return values[pos]; +} + +static void print_stats(const std::string& label, ErrorStats s) { + std::vector v50 = s.abs_errs; + std::vector v90 = s.abs_errs; + std::vector v99 = s.abs_errs; + const double denom = static_cast(std::max(s.count, 1)); + std::cout << std::fixed << std::setprecision(6) + << label + << " count=" << s.count + << " max_abs=" << s.max_abs + << " mean_abs=" << (s.sum_abs / denom) + << " p50_abs=" << quantile(v50, 0.50) + << " p90_abs=" << quantile(v90, 0.90) + << " p99_abs=" << quantile(v99, 0.99) + << " max_rel=" << s.max_rel + << " mean_rel=" << (s.sum_rel / denom) + << " ref_abs_mean=" << (s.sum_ref_abs / denom) + << " got_abs_mean=" 
<< (s.sum_got_abs / denom) + << std::endl; +} + +static void expect_reference_match(const std::string& label, + const ErrorStats& stats, + float max_abs_limit, + float mean_abs_limit) { + print_stats(label, stats); + EXPECT_LE(stats.max_abs, max_abs_limit) << label; + EXPECT_LE(stats.sum_abs / static_cast(std::max(stats.count, 1)), + static_cast(mean_abs_limit)) << label; +} + +static void run_te_grouped_mxfp8_forward(const std::vector& weights_mx, + const std::vector& inputs_mx, + std::vector* outputs, + Tensor* workspace, + int math_sm_count) { + const size_t groups = weights_mx.size(); + std::vector A(groups), B(groups), D(groups), Bias(groups), PreGelu(groups); + std::vector empty_bias(groups), empty_pregelu(groups); + + // Match GroupedLinear forward / te_general_grouped_gemm: + // A = weight [N,K], transa=true + // B = input [M,K], transb=false + // D = output [M,N] + for (size_t i = 0; i < groups; ++i) { + A[i] = const_cast(weights_mx[i]).data(); + B[i] = const_cast(inputs_mx[i]).data(); + D[i] = (*outputs)[i].data(); + Bias[i] = empty_bias[i].data(); + PreGelu[i] = empty_pregelu[i].data(); + } + + std::vector Workspaces(1); + Workspaces[0] = workspace->data(); + + nvte_multi_tensor_gemm(A.data(), + B.data(), + D.data(), + Bias.data(), + PreGelu.data(), + groups, + true, // transa: weight [N,K] -> op(A) [K,N] + false, // transb: input [M,K] -> op(B) [M,K] + false, // grad + Workspaces.data(), + false, // accumulate + false, // use_split_accumulator + math_sm_count, + 0); + NVTE_CHECK_CUDA(cudaDeviceSynchronize()); +} + +template +static void run_hip_ref_for_group(const Tensor& input_mx, + const Tensor& weight_mx, + Tensor* ref_d_colmajor, + size_t m, + size_t k, + size_t n) { + // compute_ref_kernel expects A=input [M,K], B=weight [N,K], transa=true, transb=false, + // and writes D as column-major MxN into rowwise storage shaped [N,M]. 
+ const auto a_s = input_mx.rowwise_scale_inv_shape(); + const auto b_s = weight_mx.rowwise_scale_inv_shape(); + NVTE_CHECK(a_s.ndim == 2 && b_s.ndim == 2, "Expected 2D MXFP8 scale_inv tensors"); + const size_t a_scale_ld = a_s.data[1]; + const size_t b_scale_ld = b_s.data[1]; + + dim3 block(16, 16); + dim3 grid(static_cast((n + block.x - 1) / block.x), + static_cast((m + block.y - 1) / block.y)); + const size_t shmem_bytes = size_t(block.x) * size_t(block.y) * sizeof(float); + + compute_ref_kernel + <<>>( + static_cast(input_mx.rowwise_dptr()), + static_cast(weight_mx.rowwise_dptr()), + 1.0f, + 1.0f, + static_cast(input_mx.rowwise_scale_inv_dptr()), + static_cast(weight_mx.rowwise_scale_inv_dptr()), + a_scale_ld, + b_scale_ld, + false, // input scale rowwise [M,K/32] + false, // weight scale rowwise [N,K/32] + nullptr, + 1.0f, + m, k, n, + static_cast(ref_d_colmajor->rowwise_dptr()), + nullptr, + nullptr, + true, // transa for A=input in this reference-kernel convention + false, // transb for B=weight + false, + false, + false, + true); + NVTE_CHECK_CUDA(cudaGetLastError()); + NVTE_CHECK_CUDA(cudaDeviceSynchronize()); +} + +static ck_tile::HostTensor run_ck_tile_reference_for_group( + const Tensor& input_mx, + const Tensor& weight_mx, + size_t m, + size_t k, + size_t n) { + using namespace ck_tile::literals; + using AType = ck_tile::fp8_t; + using BType = ck_tile::fp8_t; + using CType = ck_tile::bfloat16_t; + using ScaleType = ck_tile::e8m0_t; + + const size_t kscale = k / 32; + + ck_tile::HostTensor a_host( + ck_tile::HostTensorDescriptor({m, k}, {k, 1_uz})); + ck_tile::HostTensor b_host( + ck_tile::HostTensorDescriptor({k, n}, {1_uz, k})); + ck_tile::HostTensor c_ref( + ck_tile::HostTensorDescriptor({m, n}, {n, 1_uz})); + ck_tile::HostTensor a_scale_ref( + ck_tile::HostTensorDescriptor({m, kscale}, {kscale, 1_uz})); + ck_tile::HostTensor b_scale_ref( + ck_tile::HostTensorDescriptor({kscale, n}, {1_uz, kscale})); + + c_ref.SetZero(); + + NVTE_CHECK_CUDA(cudaMemcpy(a_host.data(), + input_mx.rowwise_dptr(), + a_host.get_element_space_size_in_bytes(), + cudaMemcpyDeviceToHost)); + NVTE_CHECK_CUDA(cudaMemcpy(b_host.data(), + weight_mx.rowwise_dptr(), + b_host.get_element_space_size_in_bytes(), + cudaMemcpyDeviceToHost)); + NVTE_CHECK_CUDA(cudaMemcpy(a_scale_ref.data(), + input_mx.rowwise_scale_inv_dptr(), + a_scale_ref.get_element_space_size_in_bytes(), + cudaMemcpyDeviceToHost)); + NVTE_CHECK_CUDA(cudaMemcpy(b_scale_ref.data(), + weight_mx.rowwise_scale_inv_dptr(), + b_scale_ref.get_element_space_size_in_bytes(), + cudaMemcpyDeviceToHost)); + + ck_tile::reference_mx_gemm( + a_host, b_host, c_ref, a_scale_ref, b_scale_ref); + return c_ref; +} + +static ErrorStats compare_te_vs_hip(const Tensor& te_out_rowmajor, + const Tensor& hip_ref_colmajor, + size_t m, + size_t n) { + ErrorStats stats; + const bf16_t* te = te_out_rowmajor.rowwise_cpu_dptr(); + const bf16_t* hip = hip_ref_colmajor.rowwise_cpu_dptr(); + for (size_t i = 0; i < m; ++i) { + for (size_t j = 0; j < n; ++j) { + add_err(stats, to_float(te[i * n + j]), to_float(hip[j * m + i])); + } + } + return stats; +} + +static ErrorStats compare_te_vs_ck(const Tensor& te_out_rowmajor, + const ck_tile::HostTensor& ck_ref, + size_t m, + size_t n) { + ErrorStats stats; + const bf16_t* te = te_out_rowmajor.rowwise_cpu_dptr(); + for (size_t i = 0; i < m; ++i) { + for (size_t j = 0; j < n; ++j) { + add_err(stats, to_float(te[i * n + j]), to_float(ck_ref(i, j))); + } + } + return stats; +} + +static ErrorStats compare_ck_vs_hip(const 
ck_tile::HostTensor& ck_ref, + const Tensor& hip_ref_colmajor, + size_t m, + size_t n) { + ErrorStats stats; + const bf16_t* hip = hip_ref_colmajor.rowwise_cpu_dptr(); + for (size_t i = 0; i < m; ++i) { + for (size_t j = 0; j < n; ++j) { + add_err(stats, to_float(ck_ref(i, j)), to_float(hip[j * m + i])); + } + } + return stats; +} + +static void run_case(const CaseConfig& cfg) { + set_env_defaults(); + + ASSERT_EQ(cfg.k % 128, 0UL) << "K must be a multiple of 128 for MXFP8"; + ASSERT_EQ(cfg.m_total % static_cast(cfg.experts), 0UL); + + cudaDeviceProp prop; + NVTE_CHECK_CUDA(cudaGetDeviceProperties(&prop, 0)); +#ifdef __HIP_PLATFORM_AMD__ + const bool is_gfx950_or_newer_cdna = (prop.major == 9 && prop.minor >= 5); + const bool is_gfx1250 = (prop.major == 12 && prop.minor == 5); + + if (!is_gfx950_or_newer_cdna && !is_gfx1250) { + GTEST_SKIP() << "MXFP8 requires gfx950+ or gfx1250 in this test. GPU=" << prop.name + << " major=" << prop.major << " minor=" << prop.minor; + } +#endif + + const auto m_splits = split_even(cfg.m_total, cfg.experts); + const size_t per_m = m_splits[0]; + const int groups_to_ck = std::min(cfg.ck_ref_groups, cfg.experts); + + std::cout << "\n=== TE CK grouped MXFP8 forward reference comparison ===\n" + << "M_total=" << cfg.m_total << " N=" << cfg.n << " K=" << cfg.k + << " experts=" << cfg.experts << " per_expert_M=" << per_m + << " scale=" << cfg.scale << " seed=" << cfg.seed << "\n" + << "NVTE_USE_CUTLASS_GROUPED_GEMM=" << std::getenv("NVTE_USE_CUTLASS_GROUPED_GEMM") << "\n" + << "NVTE_ROCM_ENABLE_MXFP8=" << std::getenv("NVTE_ROCM_ENABLE_MXFP8") << "\n" + << "CK_TILE_USE_OCP_FP8=" << CK_TILE_USE_OCP_FP8 << "\n" + << "GPU=" << prop.name << " SM/CU count=" << prop.multiProcessorCount << "\n"; + + std::vector input_src; + std::vector weight_src; + std::vector input_mx; + std::vector weight_mx; + std::vector output_te; + std::vector output_hip_colmajor; + input_src.reserve(cfg.experts); + weight_src.reserve(cfg.experts); + input_mx.reserve(cfg.experts); + weight_mx.reserve(cfg.experts); + output_te.reserve(cfg.experts); + output_hip_colmajor.reserve(cfg.experts); + + for (int g = 0; g < cfg.experts; ++g) { + const size_t m = m_splits[g]; + input_src.emplace_back("input_src", std::vector{m, cfg.k}, DType::kBFloat16); + weight_src.emplace_back("weight_src", std::vector{cfg.n, cfg.k}, DType::kBFloat16); + + fill_randn_cpu(&input_src.back(), cfg.scale, cfg.seed + 1009 * g + 17); + fill_randn_cpu(&weight_src.back(), cfg.scale, cfg.seed + 1009 * g + 29); + + input_mx.emplace_back("input_mx", std::vector{m, cfg.k}, DType::kFloat8E4M3, + true, false, NVTEScalingMode::NVTE_MXFP8_1D_SCALING); + weight_mx.emplace_back("weight_mx", std::vector{cfg.n, cfg.k}, DType::kFloat8E4M3, + true, false, NVTEScalingMode::NVTE_MXFP8_1D_SCALING); + + nvte_quantize(input_src.back().data(), input_mx.back().data(), 0); + nvte_quantize(weight_src.back().data(), weight_mx.back().data(), 0); + + output_te.emplace_back("output_te", std::vector{m, cfg.n}, DType::kBFloat16); + output_hip_colmajor.emplace_back("output_hip_colmajor", std::vector{cfg.n, m}, DType::kBFloat16); + } + NVTE_CHECK_CUDA(cudaDeviceSynchronize()); + + Tensor workspace("workspace", std::vector{67108864}, DType::kByte); + + run_te_grouped_mxfp8_forward(weight_mx, input_mx, &output_te, &workspace, + prop.multiProcessorCount); + for (auto& out : output_te) out.to_cpu(); + + for (int g = 0; g < cfg.experts; ++g) { + run_hip_ref_for_group(input_mx[g], weight_mx[g], &output_hip_colmajor[g], + m_splits[g], cfg.k, cfg.n); + 
output_hip_colmajor[g].to_cpu(); + expect_reference_match("group " + std::to_string(g) + " TE_vs_HIP_REF", + compare_te_vs_hip(output_te[g], output_hip_colmajor[g], + m_splits[g], cfg.n), + 0.25f, + 0.03f); + } + + for (int g = 0; g < groups_to_ck; ++g) { + auto ck_ref = run_ck_tile_reference_for_group(input_mx[g], weight_mx[g], + m_splits[g], cfg.k, cfg.n); + expect_reference_match("group " + std::to_string(g) + " TE_vs_CK_REF ", + compare_te_vs_ck(output_te[g], ck_ref, m_splits[g], cfg.n), + 0.25f, + 0.03f); + expect_reference_match("group " + std::to_string(g) + " CK_vs_HIP_REF", + compare_ck_vs_hip(ck_ref, output_hip_colmajor[g], + m_splits[g], cfg.n), + 0.25f, + 0.03f); + } +} + +} // namespace + +class GroupedMXFP8ForwardRefsTestSuite : public ::testing::TestWithParam {}; + +TEST_P(GroupedMXFP8ForwardRefsTestSuite, MatchesCKTileAndHIPReferences) { + run_case(GetParam()); +} + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, + GroupedMXFP8ForwardRefsTestSuite, + ::testing::Values( + // Small enough for quick CI-style sanity. + CaseConfig{1024, 1024, 1024, 2, 0.25f, 1234, 1}, + // Reproduces the earlier forward-only "failure" scale/shape regime, but + // validates against true MXFP8 references instead of BF16. + CaseConfig{1536, 4096, 4096, 3, 0.25f, 1234, 1}, + // Llama-ish suspicious path. CK reference only group 0 to keep runtime sane; + // HIP reference checks all groups. + CaseConfig{4096, 12288, 4096, 4, 0.25f, 1234, 1}), + case_name); diff --git a/transformer_engine/common/gemm/ck_mx_grouped_gemm/ck_mx_grouped_gemm.cpp b/transformer_engine/common/gemm/ck_mx_grouped_gemm/ck_mx_grouped_gemm.cpp index 59ae7c0ff..932f22ecc 100644 --- a/transformer_engine/common/gemm/ck_mx_grouped_gemm/ck_mx_grouped_gemm.cpp +++ b/transformer_engine/common/gemm/ck_mx_grouped_gemm/ck_mx_grouped_gemm.cpp @@ -6,14 +6,17 @@ #include #include "../../common.h" + #include "ck_tile/core.hpp" -#include "ck_tile/host.hpp" #include "ck_tile/host/kernel_launch.hpp" #include "ck_tile/ops/epilogue.hpp" #include "ck_tile/ops/gemm.hpp" #include "ck_tile/ops/gemm/kernel/mx_grouped_gemm_kernel.hpp" #include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" +#include +#include + namespace transformer_engine { namespace mx_grouped_gemm { @@ -511,6 +514,21 @@ bool ck_tile_mx_grouped_gemm(const NVTETensor* A, ". D=", Dd0, "x", Dd1, ", expected=", M, "x", N); } + if (i == 0) { + printf("[MX CK] transA=%d transB=%d use_a_col=%d use_b_col=%d " + "M=%ld N=%ld K=%ld Ad=[%ld,%ld] Bd=[%ld,%ld] " + "a_scale_shape=[%zu,%zu] b_scale_shape=[%zu,%zu]\n", + static_cast(ctx.transA), + static_cast(ctx.transB), + static_cast(ctx.use_a_colwise_data), + static_cast(ctx.use_b_colwise_data), + M, N, K, Ad0, Ad1, Bd0, Bd1, + a_scales.shape.size() > 0 ? a_scales.shape[0] : 0, + a_scales.shape.size() > 1 ? a_scales.shape[1] : 0, + b_scales.shape.size() > 0 ? b_scales.shape[0] : 0, + b_scales.shape.size() > 1 ? 
b_scales.shape[1] : 0); + } + const ck_tile::index_t stride_A = static_cast(Ad1); const ck_tile::index_t stride_B = static_cast(Bd1); const ck_tile::index_t stride_E = static_cast(Dd1); From 1784045d88d240a4cc3eff6aceec1a9fc88c51be Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Sun, 3 May 2026 15:33:02 +0000 Subject: [PATCH 08/13] Fix MXFP8 grouped GEMM scale handling for NN/TN/NT --- .../ck_mx_grouped_gemm/ck_mx_grouped_gemm.cpp | 121 ++++++++---------- 1 file changed, 53 insertions(+), 68 deletions(-) diff --git a/transformer_engine/common/gemm/ck_mx_grouped_gemm/ck_mx_grouped_gemm.cpp b/transformer_engine/common/gemm/ck_mx_grouped_gemm/ck_mx_grouped_gemm.cpp index 932f22ecc..4a5a4aaa9 100644 --- a/transformer_engine/common/gemm/ck_mx_grouped_gemm/ck_mx_grouped_gemm.cpp +++ b/transformer_engine/common/gemm/ck_mx_grouped_gemm/ck_mx_grouped_gemm.cpp @@ -15,6 +15,7 @@ #include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" #include +#include #include namespace transformer_engine { @@ -60,19 +61,6 @@ static inline bool get_flat_2d_dims(const transformer_engine::Tensor& t, return true; } -// Columnwise storage is the physical transposed view used to rewrite a -// normalized GEMM into CK's preferred NT presentation. Interpret its -// 2D shape consistently with the FP8 grouped GEMM path. -static inline bool get_columnwise_storage_2d_dims(const transformer_engine::SimpleTensor& t, - int64_t& d0, int64_t& d1) { - if (t.shape.size() != 2) { - return false; - } - d0 = static_cast(t.shape[1]); - d1 = static_cast(t.shape[0]); - return true; -} - static constexpr ck_tile::index_t ScaleBlockSize = 32; enum struct MxGemmPipelineType @@ -351,35 +339,13 @@ bool ck_tile_mx_grouped_gemm(const NVTETensor* A, bool transA_use = transB; bool transB_use = transA; - bool use_a_colwise_data = false; - bool use_b_colwise_data = false; + const bool use_a_colwise_data = transA_use; + const bool use_b_colwise_data = !transB_use; Tensor* A0_te = convertNVTETensorCheck(A_use[0]); Tensor* B0_te = convertNVTETensorCheck(B_use[0]); - // CK MX grouped GEMM is presented as normalized NT, matching the FP8 grouped path. - // Selecting columnwise_data rewrites the physical storage and effective dims used by CK - // while preserving the original math. - if (transA_use) { - if (!A0_te->has_columnwise_data() || A0_te->columnwise_scale_inv.dptr == nullptr) { - NVTE_WARN("ck_tile_mx_grouped_gemm: missing A columnwise MXFP8 view for NT rewrite; falling back."); - return false; - } - use_a_colwise_data = true; - transA_use = false; - } - - if (!transB_use) { - if (!B0_te->has_columnwise_data() || B0_te->columnwise_scale_inv.dptr == nullptr) { - NVTE_WARN("ck_tile_mx_grouped_gemm: missing B columnwise MXFP8 view for NT rewrite; falling back."); - return false; - } - use_b_colwise_data = true; - transB_use = true; - } - - // Validate scale type / data type combination using the effective storage - // selected by the NT canonicalization above. + // Validate scale type / data type combination. 
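+  // A scaling block spans ScaleBlockSize = 32 elements sharing one e8m0 scale;
+  // the gfx1250 scale preshuffle further requires KScale to be a multiple of 4,
+  // i.e. K % 128 == 0.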
// Expected input data format: fp8/bf8 (e4m3/e5m2) // Expected scale data format: e8m0 const auto* D0 = convertNVTETensorCheck(D[0]); @@ -466,24 +432,17 @@ bool ck_tile_mx_grouped_gemm(const NVTETensor* A, int64_t Ad0 = 0, Ad1 = 0, Bd0 = 0, Bd1 = 0, Dd0 = 0, Dd1 = 0; - if (ctx.use_a_colwise_data) { - if (!get_columnwise_storage_2d_dims(A_te->columnwise_data, Ad0, Ad1)) { - NVTE_ERROR("ck_tile_mx_grouped_gemm: expected 2D columnwise_data for A in group ", i); - } - } else { - if (!get_flat_2d_dims(*A_te, Ad0, Ad1)) { - NVTE_ERROR("ck_tile_mx_grouped_gemm: expected rank>=2 for normalized A in group ", i); - } + // MXFP8 columnwise_data is not a physical transpose. It has the same + // logical tensor shape as rowwise data, but is quantized with scales + // along the other dimension. Therefore dims/strides must always be + // derived from the TE tensor shape, not from columnwise_data.shape + // interpreted as a transposed storage view. + if (!get_flat_2d_dims(*A_te, Ad0, Ad1)) { + NVTE_ERROR("ck_tile_mx_grouped_gemm: expected rank>=2 for normalized A in group ", i); } - if (ctx.use_b_colwise_data) { - if (!get_columnwise_storage_2d_dims(B_te->columnwise_data, Bd0, Bd1)) { - NVTE_ERROR("ck_tile_mx_grouped_gemm: expected 2D columnwise_data for B in group ", i); - } - } else { - if (!get_flat_2d_dims(*B_te, Bd0, Bd1)) { - NVTE_ERROR("ck_tile_mx_grouped_gemm: expected rank>=2 for normalized B in group ", i); - } + if (!get_flat_2d_dims(*B_te, Bd0, Bd1)) { + NVTE_ERROR("ck_tile_mx_grouped_gemm: expected rank>=2 for normalized B in group ", i); } if (!get_flat_2d_dims(*D_te, Dd0, Dd1)) { @@ -556,20 +515,46 @@ bool ck_tile_mx_grouped_gemm(const NVTETensor* A, std::make_unique(b_scale_shuffled_bytes)); void* a_scale_shuffled_ptr = a_scale_shuffled_bufs.back()->GetDeviceBuffer(); void* b_scale_shuffled_ptr = b_scale_shuffled_bufs.back()->GetDeviceBuffer(); - preShuffleScaleBuffer_gfx1250( - reinterpret_cast(a_scales.dptr), - reinterpret_cast(a_scale_shuffled_ptr), - a_scale_actual_rows, - a_scale_output_rows, - KScale, - stream); - preShuffleScaleBuffer_gfx1250( - reinterpret_cast(b_scales.dptr), - reinterpret_cast(b_scale_shuffled_ptr), - b_scale_actual_rows, - b_scale_output_rows, - KScale, - stream); + // CK expects canonical pre-shuffled scale buffers laid out as + // A: [M, KScale] and B: [N, KScale], independent of A/B data layouts. + // TE rowwise MXFP8 scale_inv is [rows, KScale] and can be read with + // KStride=true. TE columnwise_scale_inv is [KScale, rows] and must be + // read with KStride=false before writing CK's canonical shuffled layout. 
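+      // Illustrative example (numbers for exposition only): with rows = 64 and
+      // K = 128, KScale = K / 32 = 4. The rowwise layout [rows, KScale] holds the
+      // scale for row r = 2, k-block c = 3 at src[r * KScale + c] = src[11]; the
+      // columnwise layout [KScale, rows] holds the same logical value at
+      // src[c * rows + r] = src[194]. KStride only selects the read pattern; the
+      // canonical shuffled destination index is the same in both cases.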
+ if (ctx.use_a_colwise_data) { + preShuffleScaleBuffer_gfx1250( + reinterpret_cast(a_scales.dptr), + reinterpret_cast(a_scale_shuffled_ptr), + a_scale_actual_rows, + a_scale_output_rows, + KScale, + stream); + } else { + preShuffleScaleBuffer_gfx1250( + reinterpret_cast(a_scales.dptr), + reinterpret_cast(a_scale_shuffled_ptr), + a_scale_actual_rows, + a_scale_output_rows, + KScale, + stream); + } + + if (ctx.use_b_colwise_data) { + preShuffleScaleBuffer_gfx1250( + reinterpret_cast(b_scales.dptr), + reinterpret_cast(b_scale_shuffled_ptr), + b_scale_actual_rows, + b_scale_output_rows, + KScale, + stream); + } else { + preShuffleScaleBuffer_gfx1250( + reinterpret_cast(b_scales.dptr), + reinterpret_cast(b_scale_shuffled_ptr), + b_scale_actual_rows, + b_scale_output_rows, + KScale, + stream); + } descs.emplace_back(mx_grouped_gemm_kargs( a.dptr, a_scale_shuffled_ptr, From fe99bf30d1b99c701217ef37b223f09da14e2384 Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Sun, 3 May 2026 15:59:03 +0000 Subject: [PATCH 09/13] update ck mxfp8 group gemm gtest to exercise mixed dtypes --- tests/cpp/operator/CMakeLists.txt | 2 +- .../test_te_ck_grouped_mxfp8_forward_refs.cu | 554 ------------------ 2 files changed, 1 insertion(+), 555 deletions(-) delete mode 100644 tests/cpp/operator/test_te_ck_grouped_mxfp8_forward_refs.cu diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index fa9f9a542..c81ab1e62 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -16,7 +16,7 @@ list(APPEND test_cuda_sources test_dequantize_mxfp8.cu test_dequantize_nvfp4.cu test_cast_nvfp4_transpose.cu - test_te_ck_grouped_mxfp8_forward_refs.cu + test_te_ck_grouped_mxfp8.cu test_transpose.cu test_cast_transpose.cu test_cast_transpose_current_scaling.cu diff --git a/tests/cpp/operator/test_te_ck_grouped_mxfp8_forward_refs.cu b/tests/cpp/operator/test_te_ck_grouped_mxfp8_forward_refs.cu deleted file mode 100644 index 0872b1640..000000000 --- a/tests/cpp/operator/test_te_ck_grouped_mxfp8_forward_refs.cu +++ /dev/null @@ -1,554 +0,0 @@ -/************************************************************************* - * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. - * - * License for AMD contributions = MIT. See LICENSE for more information - ************************************************************************/ - -// Forward-only TE CK grouped MXFP8 validation. -// -// Compares three paths for grouped MXFP8 forward GEMM: -// 1. TE nvte_multi_tensor_gemm grouped forward path (CK backend selected by env) -// 2. ck_tile::reference_mx_gemm host reference, using exact quantized operands/scales -// 3. 
TE HIP reference kernel adapted from test_cublaslt_gemm.cu compute_ref_kernel -// -// Intended drop-in location: -// TransformerEngine/tests/cpp/operator/test_te_ck_grouped_mxfp8_forward_refs.cu - -#ifndef CK_TILE_USE_OCP_FP8 -#define CK_TILE_USE_OCP_FP8 1 -#endif - -#include -#include -#include - -#include -#include -#include - -#include "../test_common.h" - -#include "ck_tile/core.hpp" -#include "ck_tile/host.hpp" -#include "ck_tile/ops/gemm.hpp" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -using namespace transformer_engine; -using namespace test; - -using fp8 = fp8e4m3; -using bf16_t = bf16; -using e8m0_t_te = fp8e8m0; - -namespace { - -struct CaseConfig { - size_t m_total; - size_t n; - size_t k; - int experts; - float scale; - int seed; - int ck_ref_groups; -}; - -static std::string case_name(const testing::TestParamInfo& info) { - const auto& c = info.param; - std::ostringstream os; - os << "M" << c.m_total << "_N" << c.n << "_K" << c.k - << "_E" << c.experts; - return os.str(); -} - -static void set_env_defaults() { - setenv("NVTE_USE_CUTLASS_GROUPED_GEMM", "1", 1); - setenv("NVTE_CUTLASS_GROUPED_GEMM_WARN_FALLBACK", "1", 0); - setenv("NVTE_ROCM_ENABLE_MXFP8", "1", 0); -} - -static float to_float(float x) { return x; } -static float to_float(const bf16_t& x) { return static_cast(x); } -static float to_float(const ck_tile::bfloat16_t& x) { return static_cast(x); } - -__device__ __host__ __forceinline__ float ref_gelu_unused(float x) { - float cdf = 0.5f * (1.0f + tanhf((0.7978845608028654f * (x + 0.044715f * x * x * x)))); - return x * cdf; -} - -template -__global__ void compute_ref_kernel( - const A_Type* __restrict__ a_data, - const B_Type* __restrict__ b_data, - float a_scale_inv_scalar, - float b_scale_inv_scalar, - const e8m0_t_te* __restrict__ a_scale_inv_mxfp8, - const e8m0_t_te* __restrict__ b_scale_inv_mxfp8, - size_t a_scale_ld, - size_t b_scale_ld, - bool a_scale_is_colwise, - bool b_scale_is_colwise, - const Bias_Type* __restrict__ bias_data, - float d_scale, - size_t m, size_t k, size_t n, - D_Type* __restrict__ d_data, - float* __restrict__ d_amax, - Gelu_Type* __restrict__ gelu_data, - bool transa, - bool transb, - bool is_fp8_output, - bool a_is_colwise, - bool b_is_colwise, - bool use_mxfp8) { - const size_t jj = blockIdx.x * blockDim.x + threadIdx.x; - const size_t ii = blockIdx.y * blockDim.y + threadIdx.y; - const bool in_range = (ii < m) && (jj < n); - - float val = 0.0f; - if (in_range) { - for (size_t kk = 0; kk < k; ++kk) { - size_t a_idx = 0; - size_t b_idx = 0; - - if (use_mxfp8) { - a_idx = transa ? (ii * k + kk) : (kk * m + ii); - b_idx = transb ? (kk * n + jj) : (jj * k + kk); - } else { - a_idx = a_is_colwise ? (ii * k + kk) - : (transa ? (ii * k + kk) : (kk * m + ii)); - b_idx = b_is_colwise ? (jj * k + kk) - : (transb ? (kk * n + jj) : (jj * k + kk)); - } - - float a_scale_inv_val = a_scale_inv_scalar; - float b_scale_inv_val = b_scale_inv_scalar; - - if (a_scale_inv_mxfp8) { - const size_t kc = kk / 32; - const size_t a_scale_idx = - a_scale_is_colwise ? (kc * a_scale_ld + ii) : (ii * a_scale_ld + kc); - const size_t b_scale_idx = - b_scale_is_colwise ? 
(kc * b_scale_ld + jj) : (jj * b_scale_ld + kc); - a_scale_inv_val = exp2f(a_scale_inv_mxfp8[a_scale_idx] - 127.0f); - b_scale_inv_val = exp2f(b_scale_inv_mxfp8[b_scale_idx] - 127.0f); - } - - const float a_val = static_cast(a_data[a_idx]); - const float b_val = static_cast(b_data[b_idx]); - val += a_scale_inv_val * a_val * b_scale_inv_val * b_val; - } - - if (bias_data) val += static_cast(bias_data[ii]); - if (gelu_data) { - gelu_data[ii + jj * m] = static_cast(val); - val = ref_gelu_unused(val); - } - - const float scaled = val * d_scale; - d_data[ii + jj * m] = static_cast(scaled); - } - - if (is_fp8_output && d_amax) { - const int tid = threadIdx.y * blockDim.x + threadIdx.x; - const int nthreads = blockDim.x * blockDim.y; - extern __shared__ float s_amax[]; - s_amax[tid] = in_range ? fabsf(val) : 0.0f; - __syncthreads(); - for (int offset = nthreads / 2; offset > 0; offset /= 2) { - if (tid < offset) s_amax[tid] = fmaxf(s_amax[tid], s_amax[tid + offset]); - __syncthreads(); - } - if (tid == 0) atomicMax(d_amax, s_amax[0]); - } -} - -template -static void fill_randn_cpu(Tensor* t, float scale, int seed) { - std::mt19937 gen(seed); - std::normal_distribution dist(0.0f, scale); - const size_t n = product(t->rowwise_shape()); - T* ptr = t->rowwise_cpu_dptr(); - for (size_t i = 0; i < n; ++i) ptr[i] = static_cast(dist(gen)); - t->from_cpu(); -} - -static std::vector split_even(size_t m_total, int experts) { - NVTE_CHECK(experts > 0, "experts must be > 0"); - NVTE_CHECK(m_total % static_cast(experts) == 0, - "m_total must be divisible by experts"); - return std::vector(experts, m_total / static_cast(experts)); -} - -struct ErrorStats { - size_t count = 0; - double sum_abs = 0.0; - double sum_rel = 0.0; - double sum_ref_abs = 0.0; - double sum_got_abs = 0.0; - float max_abs = 0.0f; - float max_rel = 0.0f; - std::vector abs_errs; -}; - -static void add_err(ErrorStats& s, float got, float ref) { - const float abs_err = std::abs(got - ref); - const float rel_err = abs_err / std::max(std::abs(ref), 1.0e-12f); - s.count++; - s.sum_abs += abs_err; - s.sum_rel += rel_err; - s.sum_ref_abs += std::abs(ref); - s.sum_got_abs += std::abs(got); - s.max_abs = std::max(s.max_abs, abs_err); - s.max_rel = std::max(s.max_rel, rel_err); - s.abs_errs.push_back(abs_err); -} - -static float quantile(std::vector& values, double q) { - if (values.empty()) return 0.0f; - const size_t pos = std::min(static_cast(q * (values.size() - 1)), values.size() - 1); - std::nth_element(values.begin(), values.begin() + pos, values.end()); - return values[pos]; -} - -static void print_stats(const std::string& label, ErrorStats s) { - std::vector v50 = s.abs_errs; - std::vector v90 = s.abs_errs; - std::vector v99 = s.abs_errs; - const double denom = static_cast(std::max(s.count, 1)); - std::cout << std::fixed << std::setprecision(6) - << label - << " count=" << s.count - << " max_abs=" << s.max_abs - << " mean_abs=" << (s.sum_abs / denom) - << " p50_abs=" << quantile(v50, 0.50) - << " p90_abs=" << quantile(v90, 0.90) - << " p99_abs=" << quantile(v99, 0.99) - << " max_rel=" << s.max_rel - << " mean_rel=" << (s.sum_rel / denom) - << " ref_abs_mean=" << (s.sum_ref_abs / denom) - << " got_abs_mean=" << (s.sum_got_abs / denom) - << std::endl; -} - -static void expect_reference_match(const std::string& label, - const ErrorStats& stats, - float max_abs_limit, - float mean_abs_limit) { - print_stats(label, stats); - EXPECT_LE(stats.max_abs, max_abs_limit) << label; - EXPECT_LE(stats.sum_abs / static_cast(std::max(stats.count, 1)), - 
static_cast(mean_abs_limit)) << label; -} - -static void run_te_grouped_mxfp8_forward(const std::vector& weights_mx, - const std::vector& inputs_mx, - std::vector* outputs, - Tensor* workspace, - int math_sm_count) { - const size_t groups = weights_mx.size(); - std::vector A(groups), B(groups), D(groups), Bias(groups), PreGelu(groups); - std::vector empty_bias(groups), empty_pregelu(groups); - - // Match GroupedLinear forward / te_general_grouped_gemm: - // A = weight [N,K], transa=true - // B = input [M,K], transb=false - // D = output [M,N] - for (size_t i = 0; i < groups; ++i) { - A[i] = const_cast(weights_mx[i]).data(); - B[i] = const_cast(inputs_mx[i]).data(); - D[i] = (*outputs)[i].data(); - Bias[i] = empty_bias[i].data(); - PreGelu[i] = empty_pregelu[i].data(); - } - - std::vector Workspaces(1); - Workspaces[0] = workspace->data(); - - nvte_multi_tensor_gemm(A.data(), - B.data(), - D.data(), - Bias.data(), - PreGelu.data(), - groups, - true, // transa: weight [N,K] -> op(A) [K,N] - false, // transb: input [M,K] -> op(B) [M,K] - false, // grad - Workspaces.data(), - false, // accumulate - false, // use_split_accumulator - math_sm_count, - 0); - NVTE_CHECK_CUDA(cudaDeviceSynchronize()); -} - -template -static void run_hip_ref_for_group(const Tensor& input_mx, - const Tensor& weight_mx, - Tensor* ref_d_colmajor, - size_t m, - size_t k, - size_t n) { - // compute_ref_kernel expects A=input [M,K], B=weight [N,K], transa=true, transb=false, - // and writes D as column-major MxN into rowwise storage shaped [N,M]. - const auto a_s = input_mx.rowwise_scale_inv_shape(); - const auto b_s = weight_mx.rowwise_scale_inv_shape(); - NVTE_CHECK(a_s.ndim == 2 && b_s.ndim == 2, "Expected 2D MXFP8 scale_inv tensors"); - const size_t a_scale_ld = a_s.data[1]; - const size_t b_scale_ld = b_s.data[1]; - - dim3 block(16, 16); - dim3 grid(static_cast((n + block.x - 1) / block.x), - static_cast((m + block.y - 1) / block.y)); - const size_t shmem_bytes = size_t(block.x) * size_t(block.y) * sizeof(float); - - compute_ref_kernel - <<>>( - static_cast(input_mx.rowwise_dptr()), - static_cast(weight_mx.rowwise_dptr()), - 1.0f, - 1.0f, - static_cast(input_mx.rowwise_scale_inv_dptr()), - static_cast(weight_mx.rowwise_scale_inv_dptr()), - a_scale_ld, - b_scale_ld, - false, // input scale rowwise [M,K/32] - false, // weight scale rowwise [N,K/32] - nullptr, - 1.0f, - m, k, n, - static_cast(ref_d_colmajor->rowwise_dptr()), - nullptr, - nullptr, - true, // transa for A=input in this reference-kernel convention - false, // transb for B=weight - false, - false, - false, - true); - NVTE_CHECK_CUDA(cudaGetLastError()); - NVTE_CHECK_CUDA(cudaDeviceSynchronize()); -} - -static ck_tile::HostTensor run_ck_tile_reference_for_group( - const Tensor& input_mx, - const Tensor& weight_mx, - size_t m, - size_t k, - size_t n) { - using namespace ck_tile::literals; - using AType = ck_tile::fp8_t; - using BType = ck_tile::fp8_t; - using CType = ck_tile::bfloat16_t; - using ScaleType = ck_tile::e8m0_t; - - const size_t kscale = k / 32; - - ck_tile::HostTensor a_host( - ck_tile::HostTensorDescriptor({m, k}, {k, 1_uz})); - ck_tile::HostTensor b_host( - ck_tile::HostTensorDescriptor({k, n}, {1_uz, k})); - ck_tile::HostTensor c_ref( - ck_tile::HostTensorDescriptor({m, n}, {n, 1_uz})); - ck_tile::HostTensor a_scale_ref( - ck_tile::HostTensorDescriptor({m, kscale}, {kscale, 1_uz})); - ck_tile::HostTensor b_scale_ref( - ck_tile::HostTensorDescriptor({kscale, n}, {1_uz, kscale})); - - c_ref.SetZero(); - - NVTE_CHECK_CUDA(cudaMemcpy(a_host.data(), 
- input_mx.rowwise_dptr(), - a_host.get_element_space_size_in_bytes(), - cudaMemcpyDeviceToHost)); - NVTE_CHECK_CUDA(cudaMemcpy(b_host.data(), - weight_mx.rowwise_dptr(), - b_host.get_element_space_size_in_bytes(), - cudaMemcpyDeviceToHost)); - NVTE_CHECK_CUDA(cudaMemcpy(a_scale_ref.data(), - input_mx.rowwise_scale_inv_dptr(), - a_scale_ref.get_element_space_size_in_bytes(), - cudaMemcpyDeviceToHost)); - NVTE_CHECK_CUDA(cudaMemcpy(b_scale_ref.data(), - weight_mx.rowwise_scale_inv_dptr(), - b_scale_ref.get_element_space_size_in_bytes(), - cudaMemcpyDeviceToHost)); - - ck_tile::reference_mx_gemm( - a_host, b_host, c_ref, a_scale_ref, b_scale_ref); - return c_ref; -} - -static ErrorStats compare_te_vs_hip(const Tensor& te_out_rowmajor, - const Tensor& hip_ref_colmajor, - size_t m, - size_t n) { - ErrorStats stats; - const bf16_t* te = te_out_rowmajor.rowwise_cpu_dptr(); - const bf16_t* hip = hip_ref_colmajor.rowwise_cpu_dptr(); - for (size_t i = 0; i < m; ++i) { - for (size_t j = 0; j < n; ++j) { - add_err(stats, to_float(te[i * n + j]), to_float(hip[j * m + i])); - } - } - return stats; -} - -static ErrorStats compare_te_vs_ck(const Tensor& te_out_rowmajor, - const ck_tile::HostTensor& ck_ref, - size_t m, - size_t n) { - ErrorStats stats; - const bf16_t* te = te_out_rowmajor.rowwise_cpu_dptr(); - for (size_t i = 0; i < m; ++i) { - for (size_t j = 0; j < n; ++j) { - add_err(stats, to_float(te[i * n + j]), to_float(ck_ref(i, j))); - } - } - return stats; -} - -static ErrorStats compare_ck_vs_hip(const ck_tile::HostTensor& ck_ref, - const Tensor& hip_ref_colmajor, - size_t m, - size_t n) { - ErrorStats stats; - const bf16_t* hip = hip_ref_colmajor.rowwise_cpu_dptr(); - for (size_t i = 0; i < m; ++i) { - for (size_t j = 0; j < n; ++j) { - add_err(stats, to_float(ck_ref(i, j)), to_float(hip[j * m + i])); - } - } - return stats; -} - -static void run_case(const CaseConfig& cfg) { - set_env_defaults(); - - ASSERT_EQ(cfg.k % 128, 0UL) << "K must be a multiple of 128 for MXFP8"; - ASSERT_EQ(cfg.m_total % static_cast(cfg.experts), 0UL); - - cudaDeviceProp prop; - NVTE_CHECK_CUDA(cudaGetDeviceProperties(&prop, 0)); -#ifdef __HIP_PLATFORM_AMD__ - const bool is_gfx950_or_newer_cdna = (prop.major == 9 && prop.minor >= 5); - const bool is_gfx1250 = (prop.major == 12 && prop.minor == 5); - - if (!is_gfx950_or_newer_cdna && !is_gfx1250) { - GTEST_SKIP() << "MXFP8 requires gfx950+ or gfx1250 in this test. 
GPU=" << prop.name - << " major=" << prop.major << " minor=" << prop.minor; - } -#endif - - const auto m_splits = split_even(cfg.m_total, cfg.experts); - const size_t per_m = m_splits[0]; - const int groups_to_ck = std::min(cfg.ck_ref_groups, cfg.experts); - - std::cout << "\n=== TE CK grouped MXFP8 forward reference comparison ===\n" - << "M_total=" << cfg.m_total << " N=" << cfg.n << " K=" << cfg.k - << " experts=" << cfg.experts << " per_expert_M=" << per_m - << " scale=" << cfg.scale << " seed=" << cfg.seed << "\n" - << "NVTE_USE_CUTLASS_GROUPED_GEMM=" << std::getenv("NVTE_USE_CUTLASS_GROUPED_GEMM") << "\n" - << "NVTE_ROCM_ENABLE_MXFP8=" << std::getenv("NVTE_ROCM_ENABLE_MXFP8") << "\n" - << "CK_TILE_USE_OCP_FP8=" << CK_TILE_USE_OCP_FP8 << "\n" - << "GPU=" << prop.name << " SM/CU count=" << prop.multiProcessorCount << "\n"; - - std::vector input_src; - std::vector weight_src; - std::vector input_mx; - std::vector weight_mx; - std::vector output_te; - std::vector output_hip_colmajor; - input_src.reserve(cfg.experts); - weight_src.reserve(cfg.experts); - input_mx.reserve(cfg.experts); - weight_mx.reserve(cfg.experts); - output_te.reserve(cfg.experts); - output_hip_colmajor.reserve(cfg.experts); - - for (int g = 0; g < cfg.experts; ++g) { - const size_t m = m_splits[g]; - input_src.emplace_back("input_src", std::vector{m, cfg.k}, DType::kBFloat16); - weight_src.emplace_back("weight_src", std::vector{cfg.n, cfg.k}, DType::kBFloat16); - - fill_randn_cpu(&input_src.back(), cfg.scale, cfg.seed + 1009 * g + 17); - fill_randn_cpu(&weight_src.back(), cfg.scale, cfg.seed + 1009 * g + 29); - - input_mx.emplace_back("input_mx", std::vector{m, cfg.k}, DType::kFloat8E4M3, - true, false, NVTEScalingMode::NVTE_MXFP8_1D_SCALING); - weight_mx.emplace_back("weight_mx", std::vector{cfg.n, cfg.k}, DType::kFloat8E4M3, - true, false, NVTEScalingMode::NVTE_MXFP8_1D_SCALING); - - nvte_quantize(input_src.back().data(), input_mx.back().data(), 0); - nvte_quantize(weight_src.back().data(), weight_mx.back().data(), 0); - - output_te.emplace_back("output_te", std::vector{m, cfg.n}, DType::kBFloat16); - output_hip_colmajor.emplace_back("output_hip_colmajor", std::vector{cfg.n, m}, DType::kBFloat16); - } - NVTE_CHECK_CUDA(cudaDeviceSynchronize()); - - Tensor workspace("workspace", std::vector{67108864}, DType::kByte); - - run_te_grouped_mxfp8_forward(weight_mx, input_mx, &output_te, &workspace, - prop.multiProcessorCount); - for (auto& out : output_te) out.to_cpu(); - - for (int g = 0; g < cfg.experts; ++g) { - run_hip_ref_for_group(input_mx[g], weight_mx[g], &output_hip_colmajor[g], - m_splits[g], cfg.k, cfg.n); - output_hip_colmajor[g].to_cpu(); - expect_reference_match("group " + std::to_string(g) + " TE_vs_HIP_REF", - compare_te_vs_hip(output_te[g], output_hip_colmajor[g], - m_splits[g], cfg.n), - 0.25f, - 0.03f); - } - - for (int g = 0; g < groups_to_ck; ++g) { - auto ck_ref = run_ck_tile_reference_for_group(input_mx[g], weight_mx[g], - m_splits[g], cfg.k, cfg.n); - expect_reference_match("group " + std::to_string(g) + " TE_vs_CK_REF ", - compare_te_vs_ck(output_te[g], ck_ref, m_splits[g], cfg.n), - 0.25f, - 0.03f); - expect_reference_match("group " + std::to_string(g) + " CK_vs_HIP_REF", - compare_ck_vs_hip(ck_ref, output_hip_colmajor[g], - m_splits[g], cfg.n), - 0.25f, - 0.03f); - } -} - -} // namespace - -class GroupedMXFP8ForwardRefsTestSuite : public ::testing::TestWithParam {}; - -TEST_P(GroupedMXFP8ForwardRefsTestSuite, MatchesCKTileAndHIPReferences) { - run_case(GetParam()); -} - -INSTANTIATE_TEST_SUITE_P( 
- OperatorTest, - GroupedMXFP8ForwardRefsTestSuite, - ::testing::Values( - // Small enough for quick CI-style sanity. - CaseConfig{1024, 1024, 1024, 2, 0.25f, 1234, 1}, - // Reproduces the earlier forward-only "failure" scale/shape regime, but - // validates against true MXFP8 references instead of BF16. - CaseConfig{1536, 4096, 4096, 3, 0.25f, 1234, 1}, - // Llama-ish suspicious path. CK reference only group 0 to keep runtime sane; - // HIP reference checks all groups. - CaseConfig{4096, 12288, 4096, 4, 0.25f, 1234, 1}), - case_name); From e7159c495b53b733a5a42a0892dd739ec5bfc89b Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Sun, 3 May 2026 16:00:35 +0000 Subject: [PATCH 10/13] include renamed test file --- .../cpp/operator/test_te_ck_grouped_mxfp8.cu | 629 ++++++++++++++++++ 1 file changed, 629 insertions(+) create mode 100644 tests/cpp/operator/test_te_ck_grouped_mxfp8.cu diff --git a/tests/cpp/operator/test_te_ck_grouped_mxfp8.cu b/tests/cpp/operator/test_te_ck_grouped_mxfp8.cu new file mode 100644 index 000000000..1ad32557c --- /dev/null +++ b/tests/cpp/operator/test_te_ck_grouped_mxfp8.cu @@ -0,0 +1,629 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. See LICENSE for more information + ************************************************************************/ + +// TE CK grouped MXFP8 validation. +// +// Compares three paths for grouped MXFP8 GEMM across NN/NT/TN transpose layouts: +// 1. TE nvte_multi_tensor_gemm grouped path (CK backend selected by env) +// 2. ck_tile::reference_mx_gemm host reference, using exact quantized operands/scales +// 3. TE HIP reference kernel adapted from test_cublaslt_gemm.cu compute_ref_kernel +// +// Intended drop-in location: +// TransformerEngine/tests/cpp/operator/test_te_ck_grouped_mxfp8.cu + +#ifndef CK_TILE_USE_OCP_FP8 +#define CK_TILE_USE_OCP_FP8 1 +#endif + +#include +#include +#include + +#include +#include +#include + +#include "../test_common.h" + +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace transformer_engine; +using namespace test; + +using fp8 = fp8e4m3; +using bf8 = fp8e5m2; +using bf16_t = bf16; +using e8m0_t_te = fp8e8m0; + +namespace { + +enum class MXOperandDType { + FP8, + BF8, +}; + +struct DTypeConfig { + const char* name; + MXOperandDType a; + MXOperandDType b; +}; + +static DType te_dtype(MXOperandDType t) { + return t == MXOperandDType::FP8 ? 
DType::kFloat8E4M3 : DType::kFloat8E5M2; +} + +struct LayoutConfig { + const char* name; + bool transa; + bool transb; +}; + +struct CaseConfig { + size_t m_total; + size_t n; + size_t k; + int experts; + float scale; + int seed; + LayoutConfig layout; + DTypeConfig dtype; +}; + +static std::string case_name(const testing::TestParamInfo& info) { + const auto& c = info.param; + std::ostringstream os; + os << "M" << c.m_total << "_N" << c.n << "_K" << c.k + << "_E" << c.experts << "_" << c.layout.name << "_" << c.dtype.name; + return os.str(); +} + +static void set_env_defaults() { + setenv("NVTE_USE_CUTLASS_GROUPED_GEMM", "1", 1); + setenv("NVTE_CUTLASS_GROUPED_GEMM_WARN_FALLBACK", "1", 0); + setenv("NVTE_ROCM_ENABLE_MXFP8", "1", 0); +} + +static float to_float(float x) { return x; } +static float to_float(const bf16_t& x) { return static_cast(x); } +static float to_float(const ck_tile::bfloat16_t& x) { return static_cast(x); } + +__device__ __host__ __forceinline__ float ref_gelu_unused(float x) { + float cdf = 0.5f * (1.0f + tanhf((0.7978845608028654f * (x + 0.044715f * x * x * x)))); + return x * cdf; +} + +template +__global__ void compute_ref_kernel( + const A_Type* __restrict__ a_data, + const B_Type* __restrict__ b_data, + float a_scale_inv_scalar, + float b_scale_inv_scalar, + const e8m0_t_te* __restrict__ a_scale_inv_mxfp8, + const e8m0_t_te* __restrict__ b_scale_inv_mxfp8, + size_t a_scale_ld, + size_t b_scale_ld, + bool a_scale_is_colwise, + bool b_scale_is_colwise, + const Bias_Type* __restrict__ bias_data, + float d_scale, + size_t m, size_t k, size_t n, + D_Type* __restrict__ d_data, + float* __restrict__ d_amax, + Gelu_Type* __restrict__ gelu_data, + bool transa, + bool transb, + bool is_fp8_output, + bool a_is_colwise, + bool b_is_colwise, + bool use_mxfp8) { + const size_t jj = blockIdx.x * blockDim.x + threadIdx.x; + const size_t ii = blockIdx.y * blockDim.y + threadIdx.y; + const bool in_range = (ii < m) && (jj < n); + + float val = 0.0f; + if (in_range) { + for (size_t kk = 0; kk < k; ++kk) { + size_t a_idx = 0; + size_t b_idx = 0; + + if (use_mxfp8) { + a_idx = transa ? (ii * k + kk) : (kk * m + ii); + b_idx = transb ? (kk * n + jj) : (jj * k + kk); + } else { + a_idx = a_is_colwise ? (ii * k + kk) + : (transa ? (ii * k + kk) : (kk * m + ii)); + b_idx = b_is_colwise ? (jj * k + kk) + : (transb ? (kk * n + jj) : (jj * k + kk)); + } + + float a_scale_inv_val = a_scale_inv_scalar; + float b_scale_inv_val = b_scale_inv_scalar; + + if (a_scale_inv_mxfp8) { + const size_t kc = kk / 32; + const size_t a_scale_idx = + a_scale_is_colwise ? (kc * a_scale_ld + ii) : (ii * a_scale_ld + kc); + const size_t b_scale_idx = + b_scale_is_colwise ? (kc * b_scale_ld + jj) : (jj * b_scale_ld + kc); + a_scale_inv_val = exp2f(a_scale_inv_mxfp8[a_scale_idx] - 127.0f); + b_scale_inv_val = exp2f(b_scale_inv_mxfp8[b_scale_idx] - 127.0f); + } + + const float a_val = static_cast(a_data[a_idx]); + const float b_val = static_cast(b_data[b_idx]); + val += a_scale_inv_val * a_val * b_scale_inv_val * b_val; + } + + if (bias_data) val += static_cast(bias_data[ii]); + if (gelu_data) { + gelu_data[ii + jj * m] = static_cast(val); + val = ref_gelu_unused(val); + } + + const float scaled = val * d_scale; + d_data[ii + jj * m] = static_cast(scaled); + } + + if (is_fp8_output && d_amax) { + const int tid = threadIdx.y * blockDim.x + threadIdx.x; + const int nthreads = blockDim.x * blockDim.y; + extern __shared__ float s_amax[]; + s_amax[tid] = in_range ? 
fabsf(val) : 0.0f; + __syncthreads(); + for (int offset = nthreads / 2; offset > 0; offset /= 2) { + if (tid < offset) s_amax[tid] = fmaxf(s_amax[tid], s_amax[tid + offset]); + __syncthreads(); + } + if (tid == 0) atomicMax(d_amax, s_amax[0]); + } +} + +template +static void fill_randn_cpu(Tensor* t, float scale, int seed) { + std::mt19937 gen(seed); + std::normal_distribution dist(0.0f, scale); + const size_t n = product(t->rowwise_shape()); + T* ptr = t->rowwise_cpu_dptr(); + for (size_t i = 0; i < n; ++i) ptr[i] = static_cast(dist(gen)); + t->from_cpu(); +} + +static std::vector split_even(size_t m_total, int experts) { + NVTE_CHECK(experts > 0, "experts must be > 0"); + NVTE_CHECK(m_total % static_cast(experts) == 0, + "m_total must be divisible by experts"); + return std::vector(experts, m_total / static_cast(experts)); +} + +static std::vector a_shape_for_te(size_t n, size_t k, bool transa) { + // TE grouped GEMM computes output shape [M,N]. A contributes the N dimension. + // transa=true means physical A is [N,K]; transa=false means physical A is [K,N]. + return transa ? std::vector{n, k} : std::vector{k, n}; +} + +static std::vector b_shape_for_te(size_t m, size_t k, bool transb) { + // B contributes the M dimension. + // transb=false means physical B is [M,K]; transb=true means physical B is [K,M]. + return transb ? std::vector{k, m} : std::vector{m, k}; +} + +struct ErrorStats { + size_t count = 0; + double sum_abs = 0.0; + double sum_rel = 0.0; + double sum_ref_abs = 0.0; + double sum_got_abs = 0.0; + float max_abs = 0.0f; + float max_rel = 0.0f; + std::vector abs_errs; +}; + +static void add_err(ErrorStats& s, float got, float ref) { + const float abs_err = std::abs(got - ref); + const float rel_err = abs_err / std::max(std::abs(ref), 1.0e-12f); + s.count++; + s.sum_abs += abs_err; + s.sum_rel += rel_err; + s.sum_ref_abs += std::abs(ref); + s.sum_got_abs += std::abs(got); + s.max_abs = std::max(s.max_abs, abs_err); + s.max_rel = std::max(s.max_rel, rel_err); + s.abs_errs.push_back(abs_err); +} + + +static void expect_reference_match(const std::string& label, + const ErrorStats& stats, + float max_abs_limit, + float mean_abs_limit) { + EXPECT_LE(stats.max_abs, max_abs_limit) << label; + EXPECT_LE(stats.sum_abs / static_cast(std::max(stats.count, 1)), + static_cast(mean_abs_limit)) << label; +} + +static void run_te_grouped_mxfp8(const std::vector& a_mx, + const std::vector& b_mx, + std::vector* outputs, + Tensor* workspace, + bool transa, + bool transb, + int math_sm_count) { + const size_t groups = a_mx.size(); + std::vector A(groups), B(groups), D(groups), Bias(groups), PreGelu(groups); + std::vector empty_bias(groups), empty_pregelu(groups); + + for (size_t i = 0; i < groups; ++i) { + A[i] = const_cast(a_mx[i]).data(); + B[i] = const_cast(b_mx[i]).data(); + D[i] = (*outputs)[i].data(); + Bias[i] = empty_bias[i].data(); + PreGelu[i] = empty_pregelu[i].data(); + } + + std::vector Workspaces(1); + Workspaces[0] = workspace->data(); + + nvte_multi_tensor_gemm(A.data(), + B.data(), + D.data(), + Bias.data(), + PreGelu.data(), + groups, + transa, + transb, + false, // grad + Workspaces.data(), + false, // accumulate + false, // use_split_accumulator + math_sm_count, + 0); + NVTE_CHECK_CUDA(cudaDeviceSynchronize()); +} + +template +static void run_hip_ref_for_group(const Tensor& a_mx, + const Tensor& b_mx, + Tensor* ref_d_colmajor, + size_t m, + size_t k, + size_t n, + bool transa, + bool transb) { + // TE grouped GEMM output is op(B) [M,K] * op(A) [K,N] -> [M,N]. 
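+  // (op(X) denotes X as stored when its trans flag is false and the transposed
+  // view when the flag is true.)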
+ // compute_ref_kernel convention is A_left [M,K] * B_right [K,N]. + // Therefore left operand is TE B and right operand is TE A. + const bool left_transa = !transb; + const bool right_transb = !transa; + + const bool left_use_colwise = !left_transa; // Same rule as test_cublaslt_gemm run_reference. + const bool right_use_colwise = right_transb; // Same rule as test_cublaslt_gemm run_reference. + + const auto left_s = left_use_colwise ? b_mx.columnwise_scale_inv_shape() + : b_mx.rowwise_scale_inv_shape(); + const auto right_s = right_use_colwise ? a_mx.columnwise_scale_inv_shape() + : a_mx.rowwise_scale_inv_shape(); + NVTE_CHECK(left_s.ndim == 2 && right_s.ndim == 2, "Expected 2D MXFP8 scale_inv tensors"); + const size_t left_scale_ld = left_s.data[1]; + const size_t right_scale_ld = right_s.data[1]; + + dim3 block(16, 16); + dim3 grid(static_cast((n + block.x - 1) / block.x), + static_cast((m + block.y - 1) / block.y)); + const size_t shmem_bytes = size_t(block.x) * size_t(block.y) * sizeof(float); + + compute_ref_kernel + <<>>( + static_cast(left_use_colwise ? b_mx.columnwise_dptr() : b_mx.rowwise_dptr()), + static_cast(right_use_colwise ? a_mx.columnwise_dptr() : a_mx.rowwise_dptr()), + 1.0f, + 1.0f, + static_cast(left_use_colwise ? b_mx.columnwise_scale_inv_dptr() + : b_mx.rowwise_scale_inv_dptr()), + static_cast(right_use_colwise ? a_mx.columnwise_scale_inv_dptr() + : a_mx.rowwise_scale_inv_dptr()), + left_scale_ld, + right_scale_ld, + left_use_colwise, + right_use_colwise, + nullptr, + 1.0f, + m, k, n, + static_cast(ref_d_colmajor->rowwise_dptr()), + nullptr, + nullptr, + left_transa, + right_transb, + false, + left_use_colwise, + right_use_colwise, + true); + NVTE_CHECK_CUDA(cudaGetLastError()); + NVTE_CHECK_CUDA(cudaDeviceSynchronize()); +} + +template +static ck_tile::HostTensor run_ck_tile_reference_for_group( + const Tensor& a_mx, + const Tensor& b_mx, + size_t m, + size_t k, + size_t n, + bool transa, + bool transb) { + using namespace ck_tile::literals; + using AType = CkAType; + using BType = CkBType; + using CType = ck_tile::bfloat16_t; + using ScaleType = ck_tile::e8m0_t; + + const size_t kscale = k / 32; + + const bool left_transa = !transb; + const bool right_transb = !transa; + const bool left_use_colwise = !left_transa; + const bool right_use_colwise = right_transb; + + ck_tile::HostTensor a_left( + left_transa ? ck_tile::HostTensorDescriptor({m, k}, {k, 1_uz}) + : ck_tile::HostTensorDescriptor({m, k}, {1_uz, m})); + ck_tile::HostTensor b_right( + right_transb ? ck_tile::HostTensorDescriptor({k, n}, {n, 1_uz}) + : ck_tile::HostTensorDescriptor({k, n}, {1_uz, k})); + ck_tile::HostTensor c_ref( + ck_tile::HostTensorDescriptor({m, n}, {n, 1_uz})); + + ck_tile::HostTensor a_scale_ref( + left_use_colwise ? ck_tile::HostTensorDescriptor({m, kscale}, {1_uz, m}) + : ck_tile::HostTensorDescriptor({m, kscale}, {kscale, 1_uz})); + ck_tile::HostTensor b_scale_ref( + right_use_colwise ? ck_tile::HostTensorDescriptor({kscale, n}, {n, 1_uz}) + : ck_tile::HostTensorDescriptor({kscale, n}, {1_uz, kscale})); + + c_ref.SetZero(); + + NVTE_CHECK_CUDA(cudaMemcpy(a_left.data(), + left_use_colwise ? b_mx.columnwise_dptr() : b_mx.rowwise_dptr(), + a_left.get_element_space_size_in_bytes(), + cudaMemcpyDeviceToHost)); + NVTE_CHECK_CUDA(cudaMemcpy(b_right.data(), + right_use_colwise ? a_mx.columnwise_dptr() : a_mx.rowwise_dptr(), + b_right.get_element_space_size_in_bytes(), + cudaMemcpyDeviceToHost)); + NVTE_CHECK_CUDA(cudaMemcpy(a_scale_ref.data(), + left_use_colwise ? 
b_mx.columnwise_scale_inv_dptr() + : b_mx.rowwise_scale_inv_dptr(), + a_scale_ref.get_element_space_size_in_bytes(), + cudaMemcpyDeviceToHost)); + NVTE_CHECK_CUDA(cudaMemcpy(b_scale_ref.data(), + right_use_colwise ? a_mx.columnwise_scale_inv_dptr() + : a_mx.rowwise_scale_inv_dptr(), + b_scale_ref.get_element_space_size_in_bytes(), + cudaMemcpyDeviceToHost)); + + ck_tile::reference_mx_gemm( + a_left, b_right, c_ref, a_scale_ref, b_scale_ref); + return c_ref; +} + +static ErrorStats compare_te_vs_hip(const Tensor& te_out_rowmajor, + const Tensor& hip_ref_colmajor, + size_t m, + size_t n) { + ErrorStats stats; + const bf16_t* te = te_out_rowmajor.rowwise_cpu_dptr(); + const bf16_t* hip = hip_ref_colmajor.rowwise_cpu_dptr(); + for (size_t i = 0; i < m; ++i) { + for (size_t j = 0; j < n; ++j) { + add_err(stats, to_float(te[i * n + j]), to_float(hip[j * m + i])); + } + } + return stats; +} + +static ErrorStats compare_te_vs_ck(const Tensor& te_out_rowmajor, + const ck_tile::HostTensor& ck_ref, + size_t m, + size_t n) { + ErrorStats stats; + const bf16_t* te = te_out_rowmajor.rowwise_cpu_dptr(); + for (size_t i = 0; i < m; ++i) { + for (size_t j = 0; j < n; ++j) { + add_err(stats, to_float(te[i * n + j]), to_float(ck_ref(i, j))); + } + } + return stats; +} + +static ErrorStats compare_ck_vs_hip(const ck_tile::HostTensor& ck_ref, + const Tensor& hip_ref_colmajor, + size_t m, + size_t n) { + ErrorStats stats; + const bf16_t* hip = hip_ref_colmajor.rowwise_cpu_dptr(); + for (size_t i = 0; i < m; ++i) { + for (size_t j = 0; j < n; ++j) { + add_err(stats, to_float(ck_ref(i, j)), to_float(hip[j * m + i])); + } + } + return stats; +} + +template +static void run_case_typed(const CaseConfig& cfg) { + set_env_defaults(); + + ASSERT_EQ(cfg.k % 128, 0UL) << "K must be a multiple of 128 for MXFP8"; + ASSERT_EQ(cfg.m_total % static_cast(cfg.experts), 0UL); + + cudaDeviceProp prop; + NVTE_CHECK_CUDA(cudaGetDeviceProperties(&prop, 0)); +#ifdef __HIP_PLATFORM_AMD__ + const bool is_gfx950_or_newer_cdna = (prop.major == 9 && prop.minor >= 5); + const bool is_gfx1250 = (prop.major == 12 && prop.minor == 5); + + if (!is_gfx950_or_newer_cdna && !is_gfx1250) { + GTEST_SKIP() << "MXFP8 requires gfx950+ or gfx1250 in this test. GPU=" << prop.name + << " major=" << prop.major << " minor=" << prop.minor; + } +#endif + + const auto m_splits = split_even(cfg.m_total, cfg.experts); + + std::vector a_src; + std::vector b_src; + std::vector a_mx; + std::vector b_mx; + std::vector output_te; + std::vector output_hip_colmajor; + a_src.reserve(cfg.experts); + b_src.reserve(cfg.experts); + a_mx.reserve(cfg.experts); + b_mx.reserve(cfg.experts); + output_te.reserve(cfg.experts); + output_hip_colmajor.reserve(cfg.experts); + + for (int g = 0; g < cfg.experts; ++g) { + const size_t m = m_splits[g]; + const auto a_shape = a_shape_for_te(cfg.n, cfg.k, cfg.layout.transa); + const auto b_shape = b_shape_for_te(m, cfg.k, cfg.layout.transb); + + a_src.emplace_back("a_src", a_shape, DType::kBFloat16); + b_src.emplace_back("b_src", b_shape, DType::kBFloat16); + + fill_randn_cpu(&a_src.back(), cfg.scale, cfg.seed + 1009 * g + 17); + fill_randn_cpu(&b_src.back(), cfg.scale, cfg.seed + 1009 * g + 29); + + // Allocate both rowwise and columnwise MX views so the backend can canonicalize NN/NT/TN. 
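+    // (The two boolean constructor arguments below request the rowwise and the
+    // columnwise usage respectively, so a single nvte_quantize call fills the
+    // FP8 data and e8m0 scale_inv for both orientations from the BF16 source.)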
+ a_mx.emplace_back("a_mx", a_shape, te_dtype(cfg.dtype.a), + true, true, NVTEScalingMode::NVTE_MXFP8_1D_SCALING); + b_mx.emplace_back("b_mx", b_shape, te_dtype(cfg.dtype.b), + true, true, NVTEScalingMode::NVTE_MXFP8_1D_SCALING); + + nvte_quantize(a_src.back().data(), a_mx.back().data(), 0); + nvte_quantize(b_src.back().data(), b_mx.back().data(), 0); + + output_te.emplace_back("output_te", std::vector{m, cfg.n}, DType::kBFloat16); + output_hip_colmajor.emplace_back("output_hip_colmajor", std::vector{cfg.n, m}, DType::kBFloat16); + } + NVTE_CHECK_CUDA(cudaDeviceSynchronize()); + + Tensor workspace("workspace", std::vector{67108864}, DType::kByte); + + run_te_grouped_mxfp8(a_mx, b_mx, &output_te, &workspace, + cfg.layout.transa, cfg.layout.transb, + prop.multiProcessorCount); + for (auto& out : output_te) out.to_cpu(); + + for (int g = 0; g < cfg.experts; ++g) { + run_hip_ref_for_group(a_mx[g], b_mx[g], &output_hip_colmajor[g], + m_splits[g], cfg.k, cfg.n, + cfg.layout.transa, cfg.layout.transb); + output_hip_colmajor[g].to_cpu(); + expect_reference_match("group " + std::to_string(g) + " TE_vs_HIP_REF", + compare_te_vs_hip(output_te[g], output_hip_colmajor[g], + m_splits[g], cfg.n), + 0.25f, + 0.03f); + } + + for (int g = 0; g < cfg.experts; ++g) { + auto ck_ref = run_ck_tile_reference_for_group(a_mx[g], b_mx[g], + m_splits[g], cfg.k, cfg.n, + cfg.layout.transa, cfg.layout.transb); + expect_reference_match("group " + std::to_string(g) + " TE_vs_CK_REF ", + compare_te_vs_ck(output_te[g], ck_ref, m_splits[g], cfg.n), + 0.25f, + 0.03f); + expect_reference_match("group " + std::to_string(g) + " CK_vs_HIP_REF", + compare_ck_vs_hip(ck_ref, output_hip_colmajor[g], + m_splits[g], cfg.n), + 0.25f, + 0.03f); + } +} + +static void run_case(const CaseConfig& cfg) { + if (cfg.dtype.a == MXOperandDType::FP8 && cfg.dtype.b == MXOperandDType::FP8) { + run_case_typed(cfg); + } else if (cfg.dtype.a == MXOperandDType::FP8 && cfg.dtype.b == MXOperandDType::BF8) { + run_case_typed(cfg); + } else if (cfg.dtype.a == MXOperandDType::BF8 && cfg.dtype.b == MXOperandDType::FP8) { + run_case_typed(cfg); + } else { + run_case_typed(cfg); + } +} + +} // namespace + +class GroupedMXFP8TestSuite : public ::testing::TestWithParam {}; + +TEST_P(GroupedMXFP8TestSuite, MatchesCKTileAndHIPReferences) { + run_case(GetParam()); +} + +static constexpr LayoutConfig kNN{"NN", false, false}; +static constexpr LayoutConfig kNT{"NT", false, true}; +static constexpr LayoutConfig kTN{"TN", true, false}; + +static constexpr DTypeConfig kFP8FP8{"FP8xFP8", MXOperandDType::FP8, MXOperandDType::FP8}; +static constexpr DTypeConfig kFP8BF8{"FP8xBF8", MXOperandDType::FP8, MXOperandDType::BF8}; +static constexpr DTypeConfig kBF8FP8{"BF8xFP8", MXOperandDType::BF8, MXOperandDType::FP8}; +static constexpr DTypeConfig kBF8BF8{"BF8xBF8", MXOperandDType::BF8, MXOperandDType::BF8}; + +static std::vector make_cases() { + const std::vector dtypes = {kFP8FP8, kFP8BF8, kBF8FP8, kBF8BF8}; + const std::vector base_cases = { + // Small sanity across NN/NT/TN. + CaseConfig{1024, 1024, 1024, 2, 0.25f, 1234, kNN, kFP8FP8}, + CaseConfig{1024, 1024, 1024, 2, 0.25f, 1234, kNT, kFP8FP8}, + CaseConfig{1024, 1024, 1024, 2, 0.25f, 1234, kTN, kFP8FP8}, + // Earlier failure regime across NN/NT/TN. + CaseConfig{1536, 4096, 4096, 3, 0.25f, 1234, kNN, kFP8FP8}, + CaseConfig{1536, 4096, 4096, 3, 0.25f, 1234, kNT, kFP8FP8}, + CaseConfig{1536, 4096, 4096, 3, 0.25f, 1234, kTN, kFP8FP8}, + // Llama-ish suspicious path across NN/NT/TN. 
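+      // CaseConfig aggregate order: {m_total, n, k, experts, scale, seed, layout, dtype},
+      // so each expert here runs an M=1024, N=12288, K=4096 problem.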
+ CaseConfig{4096, 12288, 4096, 4, 0.25f, 1234, kNN, kFP8FP8}, + CaseConfig{4096, 12288, 4096, 4, 0.25f, 1234, kNT, kFP8FP8}, + CaseConfig{4096, 12288, 4096, 4, 0.25f, 1234, kTN, kFP8FP8}, + }; + + std::vector cases; + cases.reserve(base_cases.size() * dtypes.size()); + for (const auto& base : base_cases) { + for (const auto& dtype : dtypes) { + CaseConfig c = base; + c.dtype = dtype; + cases.push_back(c); + } + } + return cases; +} + +static const std::vector kCases = make_cases(); + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, + GroupedMXFP8TestSuite, + ::testing::ValuesIn(kCases), + case_name); From 972cea3035c8c45d68a8e86cd512a8bd8f7badc1 Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Sun, 3 May 2026 17:22:45 +0000 Subject: [PATCH 11/13] clean up code --- tests/cpp/operator/CMakeLists.txt | 2 +- ...uped_mxfp8.cu => test_ck_grouped_mxfp8.cu} | 10 ++--- .../ck_mx_grouped_gemm/ck_mx_grouped_gemm.cpp | 43 ++++++------------- 3 files changed, 17 insertions(+), 38 deletions(-) rename tests/cpp/operator/{test_te_ck_grouped_mxfp8.cu => test_ck_grouped_mxfp8.cu} (98%) diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index c81ab1e62..4f87a9091 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -16,7 +16,7 @@ list(APPEND test_cuda_sources test_dequantize_mxfp8.cu test_dequantize_nvfp4.cu test_cast_nvfp4_transpose.cu - test_te_ck_grouped_mxfp8.cu + test_ck_grouped_mxfp8.cu test_transpose.cu test_cast_transpose.cu test_cast_transpose_current_scaling.cu diff --git a/tests/cpp/operator/test_te_ck_grouped_mxfp8.cu b/tests/cpp/operator/test_ck_grouped_mxfp8.cu similarity index 98% rename from tests/cpp/operator/test_te_ck_grouped_mxfp8.cu rename to tests/cpp/operator/test_ck_grouped_mxfp8.cu index 1ad32557c..7ea939320 100644 --- a/tests/cpp/operator/test_te_ck_grouped_mxfp8.cu +++ b/tests/cpp/operator/test_ck_grouped_mxfp8.cu @@ -10,9 +10,6 @@ // 1. TE nvte_multi_tensor_gemm grouped path (CK backend selected by env) // 2. ck_tile::reference_mx_gemm host reference, using exact quantized operands/scales // 3. TE HIP reference kernel adapted from test_cublaslt_gemm.cu compute_ref_kernel -// -// Intended drop-in location: -// TransformerEngine/tests/cpp/operator/test_te_ck_grouped_mxfp8.cu #ifndef CK_TILE_USE_OCP_FP8 #define CK_TILE_USE_OCP_FP8 1 @@ -478,12 +475,11 @@ static void run_case_typed(const CaseConfig& cfg) { cudaDeviceProp prop; NVTE_CHECK_CUDA(cudaGetDeviceProperties(&prop, 0)); #ifdef __HIP_PLATFORM_AMD__ - const bool is_gfx950_or_newer_cdna = (prop.major == 9 && prop.minor >= 5); const bool is_gfx1250 = (prop.major == 12 && prop.minor == 5); - if (!is_gfx950_or_newer_cdna && !is_gfx1250) { - GTEST_SKIP() << "MXFP8 requires gfx950+ or gfx1250 in this test. GPU=" << prop.name - << " major=" << prop.major << " minor=" << prop.minor; + if (!is_gfx1250) { + GTEST_SKIP() << "This MXFP8 grouped GEMM test currently exercises the gfx1250-compatible CK pipeline only. 
GPU=" + << prop.name << " major=" << prop.major << " minor=" << prop.minor; } #endif diff --git a/transformer_engine/common/gemm/ck_mx_grouped_gemm/ck_mx_grouped_gemm.cpp b/transformer_engine/common/gemm/ck_mx_grouped_gemm/ck_mx_grouped_gemm.cpp index 4a5a4aaa9..4d7323be3 100644 --- a/transformer_engine/common/gemm/ck_mx_grouped_gemm/ck_mx_grouped_gemm.cpp +++ b/transformer_engine/common/gemm/ck_mx_grouped_gemm/ck_mx_grouped_gemm.cpp @@ -188,7 +188,7 @@ void preShuffleScaleBuffer_gfx1250(const ScaleType* src, NVTE_CHECK_CUDA(hipGetLastError()); } -template +template bool invoke_mx_grouped_gemm(const std::vector& descs, const GroupedGemmRunContext& ctx, const ck_tile::stream_config& stream_cfg) { // check hardware WMMA support for the warp tile @@ -261,7 +261,7 @@ bool invoke_mx_grouped_gemm(const std::vector& descs, con BScaleType>; /* make pipeline selective */ using GemmPipeline = - typename MxGemmPipelineTypeSelector::pipeline; using GemmEpilogue = ck_tile::TdmEpilogue< ck_tile::CShuffleEpilogueProblemcolumnwise_scale_inv : B0_te->scale_inv; NVTE_CHECK(A0_data.dptr != nullptr, - "ck_tile_mx_grouped_gemm: effective A[0] data is not initialized"); + "ck_tile_mx_grouped_gemm: A[0] data is not initialized"); NVTE_CHECK(B0_data.dptr != nullptr, - "ck_tile_mx_grouped_gemm: effective B[0] data is not initialized"); + "ck_tile_mx_grouped_gemm: B[0] data is not initialized"); NVTE_CHECK(A0_scale.dptr != nullptr, - "ck_tile_mx_grouped_gemm: effective A[0] scale_inv is not initialized"); + "ck_tile_mx_grouped_gemm: A[0] scale_inv is not initialized"); NVTE_CHECK(B0_scale.dptr != nullptr, - "ck_tile_mx_grouped_gemm: effective B[0] scale_inv is not initialized"); + "ck_tile_mx_grouped_gemm: B[0] scale_inv is not initialized"); const auto a_scale_dtype = A0_scale.dtype; const auto b_scale_dtype = B0_scale.dtype; @@ -377,8 +381,8 @@ bool ck_tile_mx_grouped_gemm(const NVTETensor* A, const auto a_dtype = A0_data.dtype; const auto b_dtype = B0_data.dtype; const auto d_dtype = D0->dtype(); - NVTE_CHECK(is_fp8_dtype(a_dtype), "ck_tile_mx_grouped_gemm: effective A dtype must be FP8"); - NVTE_CHECK(is_fp8_dtype(b_dtype), "ck_tile_mx_grouped_gemm: effective B dtype must be FP8"); + NVTE_CHECK(is_fp8_dtype(a_dtype), "ck_tile_mx_grouped_gemm: A dtype must be FP8"); + NVTE_CHECK(is_fp8_dtype(b_dtype), "ck_tile_mx_grouped_gemm: B dtype must be FP8"); using AScaleType = ck_tile::e8m0_t; using BScaleType = ck_tile::e8m0_t; @@ -432,11 +436,6 @@ bool ck_tile_mx_grouped_gemm(const NVTETensor* A, int64_t Ad0 = 0, Ad1 = 0, Bd0 = 0, Bd1 = 0, Dd0 = 0, Dd1 = 0; - // MXFP8 columnwise_data is not a physical transpose. It has the same - // logical tensor shape as rowwise data, but is quantized with scales - // along the other dimension. Therefore dims/strides must always be - // derived from the TE tensor shape, not from columnwise_data.shape - // interpreted as a transposed storage view. if (!get_flat_2d_dims(*A_te, Ad0, Ad1)) { NVTE_ERROR("ck_tile_mx_grouped_gemm: expected rank>=2 for normalized A in group ", i); } @@ -473,27 +472,11 @@ bool ck_tile_mx_grouped_gemm(const NVTETensor* A, ". D=", Dd0, "x", Dd1, ", expected=", M, "x", N); } - if (i == 0) { - printf("[MX CK] transA=%d transB=%d use_a_col=%d use_b_col=%d " - "M=%ld N=%ld K=%ld Ad=[%ld,%ld] Bd=[%ld,%ld] " - "a_scale_shape=[%zu,%zu] b_scale_shape=[%zu,%zu]\n", - static_cast(ctx.transA), - static_cast(ctx.transB), - static_cast(ctx.use_a_colwise_data), - static_cast(ctx.use_b_colwise_data), - M, N, K, Ad0, Ad1, Bd0, Bd1, - a_scales.shape.size() > 0 ? 
a_scales.shape[0] : 0, - a_scales.shape.size() > 1 ? a_scales.shape[1] : 0, - b_scales.shape.size() > 0 ? b_scales.shape[0] : 0, - b_scales.shape.size() > 1 ? b_scales.shape[1] : 0); - } - const ck_tile::index_t stride_A = static_cast(Ad1); const ck_tile::index_t stride_B = static_cast(Bd1); const ck_tile::index_t stride_E = static_cast(Dd1); // Pre-shuffle scale buffers for the hardware. - // For the NT-normalized presentation, A scales are MxKScale and B scales are NxKScale. const int a_scale_actual_rows = static_cast(M); const int a_scale_output_rows = ck_tile::integer_least_multiple( From c0fabff0c93569b39a9006f49dea35245f52b8df Mon Sep 17 00:00:00 2001 From: Aristotle <89488299+aris134@users.noreply.github.com> Date: Wed, 6 May 2026 16:40:13 -0400 Subject: [PATCH 12/13] Update cublaslt_gemm.cu --- transformer_engine/common/gemm/cublaslt_gemm.cu | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index b3863350e..1aef1b0de 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -1123,7 +1123,8 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor if (!use_cutlass || num_gemms == 1) { #else // Currently only support cutlass group gemm on Hopper Arch - if (!(is_hopper && use_cutlass)) { + // if (!(is_hopper && use_cutlass)) { + if (!use_cutlass) { #endif if (warn_fallback) { NVTE_WARN("Fallback to cuBLAS grouped GEMM."); From 3db2e5a40d72f9c17efc0e190e37392aeed94ee7 Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Mon, 11 May 2026 09:10:14 -0400 Subject: [PATCH 13/13] address pr comments --- tests/cpp/operator/CMakeLists.txt | 4 ++-- transformer_engine/common/gemm/cublaslt_gemm.cu | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index 4f87a9091..c1bc43faa 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -16,7 +16,6 @@ list(APPEND test_cuda_sources test_dequantize_mxfp8.cu test_dequantize_nvfp4.cu test_cast_nvfp4_transpose.cu - test_ck_grouped_mxfp8.cu test_transpose.cu test_cast_transpose.cu test_cast_transpose_current_scaling.cu @@ -41,7 +40,8 @@ if(USE_CUDA) else() list(APPEND test_cuda_sources test_cublaslt_gemm.cu - test_cast_mxfp4_transpose.cu) + test_cast_mxfp4_transpose.cu + test_ck_grouped_mxfp8.cu) endif() if(USE_CUDA) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 1aef1b0de..445a5ce0e 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -1123,7 +1123,7 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor if (!use_cutlass || num_gemms == 1) { #else // Currently only support cutlass group gemm on Hopper Arch - // if (!(is_hopper && use_cutlass)) { + if (!(is_hopper && use_cutlass)) { if (!use_cutlass) { #endif if (warn_fallback) {