diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index 0ebd7fdfe..c1bc43faa 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -31,7 +31,8 @@ list(APPEND test_cuda_sources test_multi_unpadding.cu test_causal_softmax.cu test_swap_first_dims.cu - ../test_common.cu) + ../test_common.cu) + if(USE_CUDA) list(APPEND test_cuda_sources test_cast_float8blockwise.cu @@ -39,7 +40,8 @@ if(USE_CUDA) else() list(APPEND test_cuda_sources test_cublaslt_gemm.cu - test_cast_mxfp4_transpose.cu) + test_cast_mxfp4_transpose.cu + test_ck_grouped_mxfp8.cu) endif() if(USE_CUDA) @@ -54,12 +56,14 @@ endif() # Find required packages find_package(OpenMP REQUIRED) + if(USE_CUDA) list(APPEND test_operator_LINKER_LIBS CUDA::cudart GTest::gtest_main ${TE_LIB} CUDA::nvrtc CUDNN::cudnn) target_link_libraries(test_operator PUBLIC ${test_operator_LINKER_LIBS} OpenMP::OpenMP_CXX) else() target_link_libraries(test_operator PUBLIC hip::host hip::device GTest::gtest_main ${TE_LIB} OpenMP::OpenMP_CXX hiprand) endif() + target_compile_options(test_operator PRIVATE -O2 -fopenmp) include(GoogleTest) diff --git a/tests/cpp/operator/test_ck_grouped_mxfp8.cu b/tests/cpp/operator/test_ck_grouped_mxfp8.cu new file mode 100644 index 000000000..7ea939320 --- /dev/null +++ b/tests/cpp/operator/test_ck_grouped_mxfp8.cu @@ -0,0 +1,625 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. See LICENSE for more information + ************************************************************************/ + +// TE CK grouped MXFP8 validation. +// +// Compares three paths for grouped MXFP8 GEMM across NN/NT/TN transpose layouts: +// 1. TE nvte_multi_tensor_gemm grouped path (CK backend selected by env) +// 2. ck_tile::reference_mx_gemm host reference, using exact quantized operands/scales +// 3. TE HIP reference kernel adapted from test_cublaslt_gemm.cu compute_ref_kernel + +#ifndef CK_TILE_USE_OCP_FP8 +#define CK_TILE_USE_OCP_FP8 1 +#endif + +#include +#include +#include + +#include +#include +#include + +#include "../test_common.h" + +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace transformer_engine; +using namespace test; + +using fp8 = fp8e4m3; +using bf8 = fp8e5m2; +using bf16_t = bf16; +using e8m0_t_te = fp8e8m0; + +namespace { + +enum class MXOperandDType { + FP8, + BF8, +}; + +struct DTypeConfig { + const char* name; + MXOperandDType a; + MXOperandDType b; +}; + +static DType te_dtype(MXOperandDType t) { + return t == MXOperandDType::FP8 ? 
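+      // FP8 here is the OCP E4M3 encoding and BF8 the E5M2 encoding (see the fp8/bf8
+      // aliases above); under MXFP8 both are paired with E8M0 block scales.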
DType::kFloat8E4M3 : DType::kFloat8E5M2; +} + +struct LayoutConfig { + const char* name; + bool transa; + bool transb; +}; + +struct CaseConfig { + size_t m_total; + size_t n; + size_t k; + int experts; + float scale; + int seed; + LayoutConfig layout; + DTypeConfig dtype; +}; + +static std::string case_name(const testing::TestParamInfo& info) { + const auto& c = info.param; + std::ostringstream os; + os << "M" << c.m_total << "_N" << c.n << "_K" << c.k + << "_E" << c.experts << "_" << c.layout.name << "_" << c.dtype.name; + return os.str(); +} + +static void set_env_defaults() { + setenv("NVTE_USE_CUTLASS_GROUPED_GEMM", "1", 1); + setenv("NVTE_CUTLASS_GROUPED_GEMM_WARN_FALLBACK", "1", 0); + setenv("NVTE_ROCM_ENABLE_MXFP8", "1", 0); +} + +static float to_float(float x) { return x; } +static float to_float(const bf16_t& x) { return static_cast(x); } +static float to_float(const ck_tile::bfloat16_t& x) { return static_cast(x); } + +__device__ __host__ __forceinline__ float ref_gelu_unused(float x) { + float cdf = 0.5f * (1.0f + tanhf((0.7978845608028654f * (x + 0.044715f * x * x * x)))); + return x * cdf; +} + +template +__global__ void compute_ref_kernel( + const A_Type* __restrict__ a_data, + const B_Type* __restrict__ b_data, + float a_scale_inv_scalar, + float b_scale_inv_scalar, + const e8m0_t_te* __restrict__ a_scale_inv_mxfp8, + const e8m0_t_te* __restrict__ b_scale_inv_mxfp8, + size_t a_scale_ld, + size_t b_scale_ld, + bool a_scale_is_colwise, + bool b_scale_is_colwise, + const Bias_Type* __restrict__ bias_data, + float d_scale, + size_t m, size_t k, size_t n, + D_Type* __restrict__ d_data, + float* __restrict__ d_amax, + Gelu_Type* __restrict__ gelu_data, + bool transa, + bool transb, + bool is_fp8_output, + bool a_is_colwise, + bool b_is_colwise, + bool use_mxfp8) { + const size_t jj = blockIdx.x * blockDim.x + threadIdx.x; + const size_t ii = blockIdx.y * blockDim.y + threadIdx.y; + const bool in_range = (ii < m) && (jj < n); + + float val = 0.0f; + if (in_range) { + for (size_t kk = 0; kk < k; ++kk) { + size_t a_idx = 0; + size_t b_idx = 0; + + if (use_mxfp8) { + a_idx = transa ? (ii * k + kk) : (kk * m + ii); + b_idx = transb ? (kk * n + jj) : (jj * k + kk); + } else { + a_idx = a_is_colwise ? (ii * k + kk) + : (transa ? (ii * k + kk) : (kk * m + ii)); + b_idx = b_is_colwise ? (jj * k + kk) + : (transb ? (kk * n + jj) : (jj * k + kk)); + } + + float a_scale_inv_val = a_scale_inv_scalar; + float b_scale_inv_val = b_scale_inv_scalar; + + if (a_scale_inv_mxfp8) { + const size_t kc = kk / 32; + const size_t a_scale_idx = + a_scale_is_colwise ? (kc * a_scale_ld + ii) : (ii * a_scale_ld + kc); + const size_t b_scale_idx = + b_scale_is_colwise ? (kc * b_scale_ld + jj) : (jj * b_scale_ld + kc); + a_scale_inv_val = exp2f(a_scale_inv_mxfp8[a_scale_idx] - 127.0f); + b_scale_inv_val = exp2f(b_scale_inv_mxfp8[b_scale_idx] - 127.0f); + } + + const float a_val = static_cast(a_data[a_idx]); + const float b_val = static_cast(b_data[b_idx]); + val += a_scale_inv_val * a_val * b_scale_inv_val * b_val; + } + + if (bias_data) val += static_cast(bias_data[ii]); + if (gelu_data) { + gelu_data[ii + jj * m] = static_cast(val); + val = ref_gelu_unused(val); + } + + const float scaled = val * d_scale; + d_data[ii + jj * m] = static_cast(scaled); + } + + if (is_fp8_output && d_amax) { + const int tid = threadIdx.y * blockDim.x + threadIdx.x; + const int nthreads = blockDim.x * blockDim.y; + extern __shared__ float s_amax[]; + s_amax[tid] = in_range ? 
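+    // Block-level amax: each thread contributes |val| (out-of-range threads contribute 0),
+    // a shared-memory tree reduction takes the block max, and thread 0 folds it into the
+    // global amax with atomicMax.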
fabsf(val) : 0.0f; + __syncthreads(); + for (int offset = nthreads / 2; offset > 0; offset /= 2) { + if (tid < offset) s_amax[tid] = fmaxf(s_amax[tid], s_amax[tid + offset]); + __syncthreads(); + } + if (tid == 0) atomicMax(d_amax, s_amax[0]); + } +} + +template +static void fill_randn_cpu(Tensor* t, float scale, int seed) { + std::mt19937 gen(seed); + std::normal_distribution dist(0.0f, scale); + const size_t n = product(t->rowwise_shape()); + T* ptr = t->rowwise_cpu_dptr(); + for (size_t i = 0; i < n; ++i) ptr[i] = static_cast(dist(gen)); + t->from_cpu(); +} + +static std::vector split_even(size_t m_total, int experts) { + NVTE_CHECK(experts > 0, "experts must be > 0"); + NVTE_CHECK(m_total % static_cast(experts) == 0, + "m_total must be divisible by experts"); + return std::vector(experts, m_total / static_cast(experts)); +} + +static std::vector a_shape_for_te(size_t n, size_t k, bool transa) { + // TE grouped GEMM computes output shape [M,N]. A contributes the N dimension. + // transa=true means physical A is [N,K]; transa=false means physical A is [K,N]. + return transa ? std::vector{n, k} : std::vector{k, n}; +} + +static std::vector b_shape_for_te(size_t m, size_t k, bool transb) { + // B contributes the M dimension. + // transb=false means physical B is [M,K]; transb=true means physical B is [K,M]. + return transb ? std::vector{k, m} : std::vector{m, k}; +} + +struct ErrorStats { + size_t count = 0; + double sum_abs = 0.0; + double sum_rel = 0.0; + double sum_ref_abs = 0.0; + double sum_got_abs = 0.0; + float max_abs = 0.0f; + float max_rel = 0.0f; + std::vector abs_errs; +}; + +static void add_err(ErrorStats& s, float got, float ref) { + const float abs_err = std::abs(got - ref); + const float rel_err = abs_err / std::max(std::abs(ref), 1.0e-12f); + s.count++; + s.sum_abs += abs_err; + s.sum_rel += rel_err; + s.sum_ref_abs += std::abs(ref); + s.sum_got_abs += std::abs(got); + s.max_abs = std::max(s.max_abs, abs_err); + s.max_rel = std::max(s.max_rel, rel_err); + s.abs_errs.push_back(abs_err); +} + + +static void expect_reference_match(const std::string& label, + const ErrorStats& stats, + float max_abs_limit, + float mean_abs_limit) { + EXPECT_LE(stats.max_abs, max_abs_limit) << label; + EXPECT_LE(stats.sum_abs / static_cast(std::max(stats.count, 1)), + static_cast(mean_abs_limit)) << label; +} + +static void run_te_grouped_mxfp8(const std::vector& a_mx, + const std::vector& b_mx, + std::vector* outputs, + Tensor* workspace, + bool transa, + bool transb, + int math_sm_count) { + const size_t groups = a_mx.size(); + std::vector A(groups), B(groups), D(groups), Bias(groups), PreGelu(groups); + std::vector empty_bias(groups), empty_pregelu(groups); + + for (size_t i = 0; i < groups; ++i) { + A[i] = const_cast(a_mx[i]).data(); + B[i] = const_cast(b_mx[i]).data(); + D[i] = (*outputs)[i].data(); + Bias[i] = empty_bias[i].data(); + PreGelu[i] = empty_pregelu[i].data(); + } + + std::vector Workspaces(1); + Workspaces[0] = workspace->data(); + + nvte_multi_tensor_gemm(A.data(), + B.data(), + D.data(), + Bias.data(), + PreGelu.data(), + groups, + transa, + transb, + false, // grad + Workspaces.data(), + false, // accumulate + false, // use_split_accumulator + math_sm_count, + 0); + NVTE_CHECK_CUDA(cudaDeviceSynchronize()); +} + +template +static void run_hip_ref_for_group(const Tensor& a_mx, + const Tensor& b_mx, + Tensor* ref_d_colmajor, + size_t m, + size_t k, + size_t n, + bool transa, + bool transb) { + // TE grouped GEMM output is op(B) [M,K] * op(A) [K,N] -> [M,N]. 
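+  // (This mirrors the shape helpers in this file: TE's A operand supplies the N dimension
+  //  and TE's B supplies the M dimension, so the reference has to swap operands.)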
+ // compute_ref_kernel convention is A_left [M,K] * B_right [K,N]. + // Therefore left operand is TE B and right operand is TE A. + const bool left_transa = !transb; + const bool right_transb = !transa; + + const bool left_use_colwise = !left_transa; // Same rule as test_cublaslt_gemm run_reference. + const bool right_use_colwise = right_transb; // Same rule as test_cublaslt_gemm run_reference. + + const auto left_s = left_use_colwise ? b_mx.columnwise_scale_inv_shape() + : b_mx.rowwise_scale_inv_shape(); + const auto right_s = right_use_colwise ? a_mx.columnwise_scale_inv_shape() + : a_mx.rowwise_scale_inv_shape(); + NVTE_CHECK(left_s.ndim == 2 && right_s.ndim == 2, "Expected 2D MXFP8 scale_inv tensors"); + const size_t left_scale_ld = left_s.data[1]; + const size_t right_scale_ld = right_s.data[1]; + + dim3 block(16, 16); + dim3 grid(static_cast((n + block.x - 1) / block.x), + static_cast((m + block.y - 1) / block.y)); + const size_t shmem_bytes = size_t(block.x) * size_t(block.y) * sizeof(float); + + compute_ref_kernel + <<>>( + static_cast(left_use_colwise ? b_mx.columnwise_dptr() : b_mx.rowwise_dptr()), + static_cast(right_use_colwise ? a_mx.columnwise_dptr() : a_mx.rowwise_dptr()), + 1.0f, + 1.0f, + static_cast(left_use_colwise ? b_mx.columnwise_scale_inv_dptr() + : b_mx.rowwise_scale_inv_dptr()), + static_cast(right_use_colwise ? a_mx.columnwise_scale_inv_dptr() + : a_mx.rowwise_scale_inv_dptr()), + left_scale_ld, + right_scale_ld, + left_use_colwise, + right_use_colwise, + nullptr, + 1.0f, + m, k, n, + static_cast(ref_d_colmajor->rowwise_dptr()), + nullptr, + nullptr, + left_transa, + right_transb, + false, + left_use_colwise, + right_use_colwise, + true); + NVTE_CHECK_CUDA(cudaGetLastError()); + NVTE_CHECK_CUDA(cudaDeviceSynchronize()); +} + +template +static ck_tile::HostTensor run_ck_tile_reference_for_group( + const Tensor& a_mx, + const Tensor& b_mx, + size_t m, + size_t k, + size_t n, + bool transa, + bool transb) { + using namespace ck_tile::literals; + using AType = CkAType; + using BType = CkBType; + using CType = ck_tile::bfloat16_t; + using ScaleType = ck_tile::e8m0_t; + + const size_t kscale = k / 32; + + const bool left_transa = !transb; + const bool right_transb = !transa; + const bool left_use_colwise = !left_transa; + const bool right_use_colwise = right_transb; + + ck_tile::HostTensor a_left( + left_transa ? ck_tile::HostTensorDescriptor({m, k}, {k, 1_uz}) + : ck_tile::HostTensorDescriptor({m, k}, {1_uz, m})); + ck_tile::HostTensor b_right( + right_transb ? ck_tile::HostTensorDescriptor({k, n}, {n, 1_uz}) + : ck_tile::HostTensorDescriptor({k, n}, {1_uz, k})); + ck_tile::HostTensor c_ref( + ck_tile::HostTensorDescriptor({m, n}, {n, 1_uz})); + + ck_tile::HostTensor a_scale_ref( + left_use_colwise ? ck_tile::HostTensorDescriptor({m, kscale}, {1_uz, m}) + : ck_tile::HostTensorDescriptor({m, kscale}, {kscale, 1_uz})); + ck_tile::HostTensor b_scale_ref( + right_use_colwise ? ck_tile::HostTensorDescriptor({kscale, n}, {n, 1_uz}) + : ck_tile::HostTensorDescriptor({kscale, n}, {1_uz, kscale})); + + c_ref.SetZero(); + + NVTE_CHECK_CUDA(cudaMemcpy(a_left.data(), + left_use_colwise ? b_mx.columnwise_dptr() : b_mx.rowwise_dptr(), + a_left.get_element_space_size_in_bytes(), + cudaMemcpyDeviceToHost)); + NVTE_CHECK_CUDA(cudaMemcpy(b_right.data(), + right_use_colwise ? a_mx.columnwise_dptr() : a_mx.rowwise_dptr(), + b_right.get_element_space_size_in_bytes(), + cudaMemcpyDeviceToHost)); + NVTE_CHECK_CUDA(cudaMemcpy(a_scale_ref.data(), + left_use_colwise ? 
b_mx.columnwise_scale_inv_dptr() + : b_mx.rowwise_scale_inv_dptr(), + a_scale_ref.get_element_space_size_in_bytes(), + cudaMemcpyDeviceToHost)); + NVTE_CHECK_CUDA(cudaMemcpy(b_scale_ref.data(), + right_use_colwise ? a_mx.columnwise_scale_inv_dptr() + : a_mx.rowwise_scale_inv_dptr(), + b_scale_ref.get_element_space_size_in_bytes(), + cudaMemcpyDeviceToHost)); + + ck_tile::reference_mx_gemm( + a_left, b_right, c_ref, a_scale_ref, b_scale_ref); + return c_ref; +} + +static ErrorStats compare_te_vs_hip(const Tensor& te_out_rowmajor, + const Tensor& hip_ref_colmajor, + size_t m, + size_t n) { + ErrorStats stats; + const bf16_t* te = te_out_rowmajor.rowwise_cpu_dptr(); + const bf16_t* hip = hip_ref_colmajor.rowwise_cpu_dptr(); + for (size_t i = 0; i < m; ++i) { + for (size_t j = 0; j < n; ++j) { + add_err(stats, to_float(te[i * n + j]), to_float(hip[j * m + i])); + } + } + return stats; +} + +static ErrorStats compare_te_vs_ck(const Tensor& te_out_rowmajor, + const ck_tile::HostTensor& ck_ref, + size_t m, + size_t n) { + ErrorStats stats; + const bf16_t* te = te_out_rowmajor.rowwise_cpu_dptr(); + for (size_t i = 0; i < m; ++i) { + for (size_t j = 0; j < n; ++j) { + add_err(stats, to_float(te[i * n + j]), to_float(ck_ref(i, j))); + } + } + return stats; +} + +static ErrorStats compare_ck_vs_hip(const ck_tile::HostTensor& ck_ref, + const Tensor& hip_ref_colmajor, + size_t m, + size_t n) { + ErrorStats stats; + const bf16_t* hip = hip_ref_colmajor.rowwise_cpu_dptr(); + for (size_t i = 0; i < m; ++i) { + for (size_t j = 0; j < n; ++j) { + add_err(stats, to_float(ck_ref(i, j)), to_float(hip[j * m + i])); + } + } + return stats; +} + +template +static void run_case_typed(const CaseConfig& cfg) { + set_env_defaults(); + + ASSERT_EQ(cfg.k % 128, 0UL) << "K must be a multiple of 128 for MXFP8"; + ASSERT_EQ(cfg.m_total % static_cast(cfg.experts), 0UL); + + cudaDeviceProp prop; + NVTE_CHECK_CUDA(cudaGetDeviceProperties(&prop, 0)); +#ifdef __HIP_PLATFORM_AMD__ + const bool is_gfx1250 = (prop.major == 12 && prop.minor == 5); + + if (!is_gfx1250) { + GTEST_SKIP() << "This MXFP8 grouped GEMM test currently exercises the gfx1250-compatible CK pipeline only. GPU=" + << prop.name << " major=" << prop.major << " minor=" << prop.minor; + } +#endif + + const auto m_splits = split_even(cfg.m_total, cfg.experts); + + std::vector a_src; + std::vector b_src; + std::vector a_mx; + std::vector b_mx; + std::vector output_te; + std::vector output_hip_colmajor; + a_src.reserve(cfg.experts); + b_src.reserve(cfg.experts); + a_mx.reserve(cfg.experts); + b_mx.reserve(cfg.experts); + output_te.reserve(cfg.experts); + output_hip_colmajor.reserve(cfg.experts); + + for (int g = 0; g < cfg.experts; ++g) { + const size_t m = m_splits[g]; + const auto a_shape = a_shape_for_te(cfg.n, cfg.k, cfg.layout.transa); + const auto b_shape = b_shape_for_te(m, cfg.k, cfg.layout.transb); + + a_src.emplace_back("a_src", a_shape, DType::kBFloat16); + b_src.emplace_back("b_src", b_shape, DType::kBFloat16); + + fill_randn_cpu(&a_src.back(), cfg.scale, cfg.seed + 1009 * g + 17); + fill_randn_cpu(&b_src.back(), cfg.scale, cfg.seed + 1009 * g + 29); + + // Allocate both rowwise and columnwise MX views so the backend can canonicalize NN/NT/TN. 
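+    // (The two boolean arguments request rowwise and columnwise buffers respectively;
+    //  NVTE_MXFP8_1D_SCALING attaches one E8M0 scale per 32 elements along K, which is
+    //  what the kk/32 indexing in compute_ref_kernel above assumes.)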
+ a_mx.emplace_back("a_mx", a_shape, te_dtype(cfg.dtype.a), + true, true, NVTEScalingMode::NVTE_MXFP8_1D_SCALING); + b_mx.emplace_back("b_mx", b_shape, te_dtype(cfg.dtype.b), + true, true, NVTEScalingMode::NVTE_MXFP8_1D_SCALING); + + nvte_quantize(a_src.back().data(), a_mx.back().data(), 0); + nvte_quantize(b_src.back().data(), b_mx.back().data(), 0); + + output_te.emplace_back("output_te", std::vector{m, cfg.n}, DType::kBFloat16); + output_hip_colmajor.emplace_back("output_hip_colmajor", std::vector{cfg.n, m}, DType::kBFloat16); + } + NVTE_CHECK_CUDA(cudaDeviceSynchronize()); + + Tensor workspace("workspace", std::vector{67108864}, DType::kByte); + + run_te_grouped_mxfp8(a_mx, b_mx, &output_te, &workspace, + cfg.layout.transa, cfg.layout.transb, + prop.multiProcessorCount); + for (auto& out : output_te) out.to_cpu(); + + for (int g = 0; g < cfg.experts; ++g) { + run_hip_ref_for_group(a_mx[g], b_mx[g], &output_hip_colmajor[g], + m_splits[g], cfg.k, cfg.n, + cfg.layout.transa, cfg.layout.transb); + output_hip_colmajor[g].to_cpu(); + expect_reference_match("group " + std::to_string(g) + " TE_vs_HIP_REF", + compare_te_vs_hip(output_te[g], output_hip_colmajor[g], + m_splits[g], cfg.n), + 0.25f, + 0.03f); + } + + for (int g = 0; g < cfg.experts; ++g) { + auto ck_ref = run_ck_tile_reference_for_group(a_mx[g], b_mx[g], + m_splits[g], cfg.k, cfg.n, + cfg.layout.transa, cfg.layout.transb); + expect_reference_match("group " + std::to_string(g) + " TE_vs_CK_REF ", + compare_te_vs_ck(output_te[g], ck_ref, m_splits[g], cfg.n), + 0.25f, + 0.03f); + expect_reference_match("group " + std::to_string(g) + " CK_vs_HIP_REF", + compare_ck_vs_hip(ck_ref, output_hip_colmajor[g], + m_splits[g], cfg.n), + 0.25f, + 0.03f); + } +} + +static void run_case(const CaseConfig& cfg) { + if (cfg.dtype.a == MXOperandDType::FP8 && cfg.dtype.b == MXOperandDType::FP8) { + run_case_typed(cfg); + } else if (cfg.dtype.a == MXOperandDType::FP8 && cfg.dtype.b == MXOperandDType::BF8) { + run_case_typed(cfg); + } else if (cfg.dtype.a == MXOperandDType::BF8 && cfg.dtype.b == MXOperandDType::FP8) { + run_case_typed(cfg); + } else { + run_case_typed(cfg); + } +} + +} // namespace + +class GroupedMXFP8TestSuite : public ::testing::TestWithParam {}; + +TEST_P(GroupedMXFP8TestSuite, MatchesCKTileAndHIPReferences) { + run_case(GetParam()); +} + +static constexpr LayoutConfig kNN{"NN", false, false}; +static constexpr LayoutConfig kNT{"NT", false, true}; +static constexpr LayoutConfig kTN{"TN", true, false}; + +static constexpr DTypeConfig kFP8FP8{"FP8xFP8", MXOperandDType::FP8, MXOperandDType::FP8}; +static constexpr DTypeConfig kFP8BF8{"FP8xBF8", MXOperandDType::FP8, MXOperandDType::BF8}; +static constexpr DTypeConfig kBF8FP8{"BF8xFP8", MXOperandDType::BF8, MXOperandDType::FP8}; +static constexpr DTypeConfig kBF8BF8{"BF8xBF8", MXOperandDType::BF8, MXOperandDType::BF8}; + +static std::vector make_cases() { + const std::vector dtypes = {kFP8FP8, kFP8BF8, kBF8FP8, kBF8BF8}; + const std::vector base_cases = { + // Small sanity across NN/NT/TN. + CaseConfig{1024, 1024, 1024, 2, 0.25f, 1234, kNN, kFP8FP8}, + CaseConfig{1024, 1024, 1024, 2, 0.25f, 1234, kNT, kFP8FP8}, + CaseConfig{1024, 1024, 1024, 2, 0.25f, 1234, kTN, kFP8FP8}, + // Earlier failure regime across NN/NT/TN. + CaseConfig{1536, 4096, 4096, 3, 0.25f, 1234, kNN, kFP8FP8}, + CaseConfig{1536, 4096, 4096, 3, 0.25f, 1234, kNT, kFP8FP8}, + CaseConfig{1536, 4096, 4096, 3, 0.25f, 1234, kTN, kFP8FP8}, + // Llama-ish suspicious path across NN/NT/TN. 
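+      // (The kFP8FP8 entries above and below are placeholders; the loop at the end of
+      //  make_cases() expands every base case over all four dtype combinations.)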
+ CaseConfig{4096, 12288, 4096, 4, 0.25f, 1234, kNN, kFP8FP8}, + CaseConfig{4096, 12288, 4096, 4, 0.25f, 1234, kNT, kFP8FP8}, + CaseConfig{4096, 12288, 4096, 4, 0.25f, 1234, kTN, kFP8FP8}, + }; + + std::vector cases; + cases.reserve(base_cases.size() * dtypes.size()); + for (const auto& base : base_cases) { + for (const auto& dtype : dtypes) { + CaseConfig c = base; + c.dtype = dtype; + cases.push_back(c); + } + } + return cases; +} + +static const std::vector kCases = make_cases(); + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, + GroupedMXFP8TestSuite, + ::testing::ValuesIn(kCases), + case_name); diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index d9c7d1fb0..642999ef7 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -2253,6 +2253,7 @@ def test_grouped_linear_accuracy_cutlass( delay_wgrad_compute, ): os.environ["NVTE_USE_CUTLASS_GROUPED_GEMM"] = "1" + os.environ["NVTE_ROCM_ENABLE_MXFP8"] = "1" test_grouped_linear_accuracy( dtype, num_gemms, @@ -2268,6 +2269,7 @@ def test_grouped_linear_accuracy_cutlass( use_cutlass=True, ) os.environ.pop("NVTE_USE_CUTLASS_GROUPED_GEMM", None) + os.environ.pop("NVTE_ROCM_ENABLE_MXFP8", None) @pytest.mark.parametrize("dtype", param_types, ids=str) diff --git a/transformer_engine/common/gemm/ck_mx_grouped_gemm/ck_mx_grouped_gemm.cpp b/transformer_engine/common/gemm/ck_mx_grouped_gemm/ck_mx_grouped_gemm.cpp new file mode 100644 index 000000000..4d7323be3 --- /dev/null +++ b/transformer_engine/common/gemm/ck_mx_grouped_gemm/ck_mx_grouped_gemm.cpp @@ -0,0 +1,588 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. See LICENSE for more information + ************************************************************************/ + +#include +#include "../../common.h" + +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/ops/gemm/kernel/mx_grouped_gemm_kernel.hpp" +#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" + +#include +#include +#include + +namespace transformer_engine { +namespace mx_grouped_gemm { + +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using mx_grouped_gemm_kargs = ck_tile::MxGroupedGemmHostArgs<>; + +template struct TETypeToCKType; +template <> struct TETypeToCKType { using type = ck_tile::fp8_t; }; +template <> struct TETypeToCKType { using type = ck_tile::bf8_t; }; +template <> struct TETypeToCKType { using type = ck_tile::half_t; }; +template <> struct TETypeToCKType { using type = ck_tile::bfloat16_t; }; +template <> struct TETypeToCKType { using type = float; }; + +struct GroupedGemmRunContext { + const NVTETensor* A = nullptr; + const NVTETensor* B = nullptr; + NVTETensor* D = nullptr; + + int group_num = 0; + bool transA = false; + bool transB = false; + + void* workspace = nullptr; + size_t workspace_bytes = 0; + hipStream_t stream = nullptr; + + bool use_a_colwise_data = false; + bool use_b_colwise_data = false; +}; + +// Treat TE tensors as generalized 2D matrices by flattening: +// (D1, D2, ..., Dn) -> (D1*...*D(n-1), Dn), consistent with TE Tensor::flat_*_dim. 
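+// Example: an (e, m, k) tensor flattens to (e*m, k) and a 2-D (m, k) tensor is returned
+// as-is; rank-0/1 tensors are rejected (the helper returns false).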
+static inline bool get_flat_2d_dims(const transformer_engine::Tensor& t, + int64_t& d0, int64_t& d1) { + if (t.shape().size() < 2) { + return false; + } + d0 = static_cast(t.flat_first_dim()); + d1 = static_cast(t.flat_last_dim()); + return true; +} + +static constexpr ck_tile::index_t ScaleBlockSize = 32; + +enum struct MxGemmPipelineType +{ + CompTDMV1, + CompTDMV2 +}; + +template +struct MxGemmPipelineTypeSelector; +template +struct MxGemmPipelineTypeSelector +{ + using base_pipeline = ck_tile::BaseGemmPipelineAgBgCrCompTDM; + using pipeline = ck_tile::GemmPipelineAgBgCrCompTDMV1; + static constexpr auto GetName() { return "GemmPipelineAgBgCrCompTDMV1"; } +}; + +template +struct MxGemmPipelineTypeSelector +{ + using base_pipeline = ck_tile::BaseGemmPipelineAgBgCrCompTDM; + using pipeline = ck_tile::GemmPipelineAgBgCrCompTDMV2; + static constexpr auto GetName() { return "GemmPipelineAgBgCrCompTDMV2"; } +}; + +template +static inline bool has_sufficient_workspace(const GroupedGemmRunContext& ctx) { + const size_t needed = Kernel::GetWorkSpaceSize(ctx.group_num); + if (!ctx.workspace || ctx.workspace_bytes < needed) { + NVTE_WARN("ck_tile_mx_grouped_gemm: insufficient workspace for CK path. Needed bytes=", needed, + ", available bytes=", ctx.workspace_bytes, ". Falling back."); + return false; + } + return true; +} + +struct GroupedGemKernelParam_Wmma +{ + static const bool kPadM = false; + static const bool kPadN = false; + static const bool kPadK = false; + static const int kBlockPerCu = 1; + static const ck_tile::index_t M_Tile = 64; + static const ck_tile::index_t N_Tile = 64; + static const ck_tile::index_t K_Tile = 128; + static const ck_tile::index_t M_Warp = 2; + static const ck_tile::index_t N_Warp = 2; + static const ck_tile::index_t K_Warp = 1; + static const ck_tile::index_t M_Warp_Tile = 32; + static const ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = 128; +}; + +template +__global__ void preshuffle_scale_gfx1250_kernel(const ScaleType* __restrict__ src, + ScaleType* __restrict__ dst, + int actual_rows, + int output_rows, + int KScale) +{ + static_assert(ScaleBlockSize == 32 && sizeof(ScaleType) == 1, + "gfx1250 scale preshuffle only supports 8-bit scale with ScaleBlockSize=32"); + constexpr int MPerXdlops = 16; + constexpr int KPerXdlops = 128; + constexpr int MNPack = 2; + constexpr int KPack = 1; + constexpr int MNStep = MPerXdlops; // 16 + constexpr int KStep = KPerXdlops / ScaleBlockSize; // 4 + const int K0 = KScale / (KPack * KStep); + const int linear = blockIdx.x * blockDim.x + threadIdx.x; + const int total = output_rows * KScale; + if(linear >= total) + return; + const int mn = linear / KScale; + const int k = linear % KScale; + const int iMNRepeat = mn / (MNStep * MNPack); + const int tempmn = mn % (MNStep * MNPack); + const int iKRepeat = k / (KStep * KPack); + const int tempk = k % (KStep * KPack); + const int outputIndex = + (iMNRepeat * MNPack * MNStep) * (KStep * KPack * K0) + + (iKRepeat * KStep * KPack) * (MNStep * MNPack) + + tempmn * (KStep * KPack) + + tempk; + ScaleType value{}; + if(mn < actual_rows) + { + if constexpr(KStride) + value = src[mn * KScale + k]; + else + value = src[k * actual_rows + mn]; + } + dst[outputIndex] = value; +} + +template +void preShuffleScaleBuffer_gfx1250(const ScaleType* src, + ScaleType* dst, + int actual_rows, + int output_rows, + int KScale, + hipStream_t stream) +{ + constexpr int KPerXdlops = 128; + constexpr int KStep = KPerXdlops / ScaleBlockSize; // 4 + if(KScale % KStep != 0) 
+ { + NVTE_ERROR("preshuffle_scale_gfx1250: KScale must be a multiple of 4, " + "i.e. original K must be a multiple of 128 for ScaleBlockSize=32."); + } + const int total = output_rows * KScale; + constexpr int block_size = 256; + const int grid_size = (total + block_size - 1) / block_size; + hipLaunchKernelGGL((preshuffle_scale_gfx1250_kernel), + dim3(grid_size), + dim3(block_size), + 0, + stream, + src, + dst, + actual_rows, + output_rows, + KScale); + NVTE_CHECK_CUDA(hipGetLastError()); +} + +template +bool invoke_mx_grouped_gemm(const std::vector& descs, const GroupedGemmRunContext& ctx, const ck_tile::stream_config& stream_cfg) +{ + // check hardware WMMA support for the warp tile + static constexpr bool has_wmma_support = + ck_tile::has_wmma_traits_v; + + NVTE_CHECK(has_wmma_support, + "ck_tile_mx_grouped_gemm: unsupported gfx125 WMMA traits for " + "AType/BType/AccType with warp tile shape ", + MXFP8GemmConfig::M_Warp_Tile, "x", + MXFP8GemmConfig::N_Warp_Tile, "x", + MXFP8GemmConfig::K_Warp_Tile); + + using CLayout = RowMajor; + constexpr bool preshuffle = false; + constexpr bool DoubleSmemBuffer = true; // TDM pipeline requires double smem buffer + constexpr bool TransposeC = + std::is_same_v && + MXFP8GemmConfig::M_Warp_Tile == MXFP8GemmConfig::N_Warp_Tile; + static constexpr bool StructuredSparsity = false; + static constexpr bool NumWaveGroup = 1; + constexpr ck_tile::index_t TileParitionerGroupNum = 8; + constexpr ck_tile::index_t TileParitionerM01 = 4; + using GemmShape = + ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence>; + using TilePartitioner = ck_tile:: + GemmSpatiallyLocalTilePartitioner; + TRANSFORMER_ENGINE_SWITCH_CONDITION(ctx.transA, kTransA, { + using ALayout = std::conditional_t; + TRANSFORMER_ENGINE_SWITCH_CONDITION(ctx.transB, kTransB, { + using BLayout = std::conditional_t; + using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; + using UniversalGemmProblem = + ck_tile::MxGemmPipelineProblem; + /* make pipeline selective */ + using GemmPipeline = + typename MxGemmPipelineTypeSelector::pipeline; + using GemmEpilogue = ck_tile::TdmEpilogue< + ck_tile::CShuffleEpilogueProblem,//DsDataType + float, + CType, + ck_tile::tuple<>,//DsLayout + CLayout, + ck_tile::element_wise::PassThrough, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + MXFP8GemmConfig::M_Warp, + MXFP8GemmConfig::N_Warp, + MXFP8GemmConfig::M_Warp_Tile, + MXFP8GemmConfig::N_Warp_Tile, + MXFP8GemmConfig::K_Warp_Tile, + UniversalGemmProblem::TransposeC, + 1, /*kNumWaveGroups_*/ + false, /*FixedVectorSize_*/ + 1, /*VectorSizeC_*/ + false, /*TiledMMAPermuteN_*/ + 1, /*BlockedXDLN_PerWarp_*/ + DoubleSmemBuffer, /*DoubleSmemBuffer*/ + AType, /*AType_*/ + BType /*BType_*/>>; + using Kernel = ck_tile::MxGroupedGemmKernel; + + if (!has_sufficient_workspace(ctx)) { + return false; + } + + auto kargs = Kernel::MakeKargs(descs); + if(!Kernel::IsSupportedArgument(kargs)) + { + NVTE_WARN("ck_tile_mx_grouped_gemm: CK_Tile kernel arguments not supported for this config. 
" + "Falling back."); + return false; + } + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(kargs); + NVTE_CHECK_CUDA(hipMemcpyAsync(ctx.workspace, + kargs.data(), + kargs.size() * sizeof(typename decltype(kargs)::value_type), + hipMemcpyHostToDevice, + ctx.stream)); + ck_tile::ignore = ck_tile::launch_kernel( + stream_cfg, ck_tile::make_kernel( + Kernel{}, grids, blocks, 0, + ck_tile::cast_pointer_to_constant_address_space(ctx.workspace), + kargs.size())); + return true; + }); + }); + return false; +} + +bool ck_tile_mx_grouped_gemm(const NVTETensor* A, + const NVTETensor* B, + NVTETensor* D, + int group_num, + bool transA, + bool transB, + NVTETensor* workspace, + bool accumulate,//ignored for now + hipStream_t stream) { + if (group_num <= 0) { + return true; + } + + // Normalize input mats + // I.e., swap A and B, as well as transa and transb. + const NVTETensor* A_use = B; + const NVTETensor* B_use = A; + bool transA_use = transB; + bool transB_use = transA; + + // Note: for MXFP8, row-wise and col-wise data are scaled along different + // dims, with the mat interpreted in row-major. + // Use the operand transpose flags to select the correct view. + // Scale view needs to match data view. + const bool use_a_colwise_data = transA_use; + const bool use_b_colwise_data = !transB_use; + + Tensor* A0_te = convertNVTETensorCheck(A_use[0]); + Tensor* B0_te = convertNVTETensorCheck(B_use[0]); + + // Validate scale type / data type combination. + // Expected input data format: fp8/bf8 (e4m3/e5m2) + // Expected scale data format: e8m0 + const auto* D0 = convertNVTETensorCheck(D[0]); + + const auto& A0_data = use_a_colwise_data ? A0_te->columnwise_data : A0_te->data; + const auto& B0_data = use_b_colwise_data ? B0_te->columnwise_data : B0_te->data; + const auto& A0_scale = use_a_colwise_data ? A0_te->columnwise_scale_inv : A0_te->scale_inv; + const auto& B0_scale = use_b_colwise_data ? 
B0_te->columnwise_scale_inv : B0_te->scale_inv; + + NVTE_CHECK(A0_data.dptr != nullptr, + "ck_tile_mx_grouped_gemm: A[0] data is not initialized"); + NVTE_CHECK(B0_data.dptr != nullptr, + "ck_tile_mx_grouped_gemm: B[0] data is not initialized"); + NVTE_CHECK(A0_scale.dptr != nullptr, + "ck_tile_mx_grouped_gemm: A[0] scale_inv is not initialized"); + NVTE_CHECK(B0_scale.dptr != nullptr, + "ck_tile_mx_grouped_gemm: B[0] scale_inv is not initialized"); + + const auto a_scale_dtype = A0_scale.dtype; + const auto b_scale_dtype = B0_scale.dtype; + NVTE_CHECK(a_scale_dtype == DType::kFloat8E8M0, + "ck_tile_mx_grouped_gemm: A scale_inv dtype must be Float8E8M0, got ", + static_cast(a_scale_dtype)); + + NVTE_CHECK(b_scale_dtype == DType::kFloat8E8M0, + "ck_tile_mx_grouped_gemm: B scale_inv dtype must be Float8E8M0, got ", + static_cast(b_scale_dtype)); + + const auto a_dtype = A0_data.dtype; + const auto b_dtype = B0_data.dtype; + const auto d_dtype = D0->dtype(); + NVTE_CHECK(is_fp8_dtype(a_dtype), "ck_tile_mx_grouped_gemm: A dtype must be FP8"); + NVTE_CHECK(is_fp8_dtype(b_dtype), "ck_tile_mx_grouped_gemm: B dtype must be FP8"); + + using AScaleType = ck_tile::e8m0_t; + using BScaleType = ck_tile::e8m0_t; + + void* ws_ptr = nullptr; + size_t ws_bytes = 0; + if (workspace) { + auto* ws_te = convertNVTETensorCheck(*workspace); + ws_ptr = ws_te->data.dptr; + ws_bytes = ws_te->data.numel() * typeToSize(ws_te->data.dtype); + } + + GroupedGemmRunContext ctx = { + A_use, + B_use, + D, + group_num, + transA_use, + transB_use, + ws_ptr, + ws_bytes, + stream, + use_a_colwise_data, + use_b_colwise_data}; + + const ck_tile::stream_config s{ctx.stream}; + + std::vector descs; + descs.reserve(group_num); + + std::vector> a_scale_shuffled_bufs; + std::vector> b_scale_shuffled_bufs; + a_scale_shuffled_bufs.reserve(group_num); + b_scale_shuffled_bufs.reserve(group_num); + + for (int i = 0; i < group_num; i++) { + const transformer_engine::Tensor* const A_te = + transformer_engine::convertNVTETensorCheck(ctx.A[i]); + const transformer_engine::Tensor* const B_te = + transformer_engine::convertNVTETensorCheck(ctx.B[i]); + transformer_engine::Tensor* D_te = + transformer_engine::convertNVTETensorCheck(ctx.D[i]); + + const auto& a = ctx.use_a_colwise_data ? A_te->columnwise_data : A_te->data; + const auto& b = ctx.use_b_colwise_data ? B_te->columnwise_data : B_te->data; + const auto& d = D_te->data; + const auto& a_scales = + ctx.use_a_colwise_data ? A_te->columnwise_scale_inv : A_te->scale_inv; + const auto& b_scales = + ctx.use_b_colwise_data ? B_te->columnwise_scale_inv : B_te->scale_inv; + + int64_t Ad0 = 0, Ad1 = 0, Bd0 = 0, Bd1 = 0, Dd0 = 0, Dd1 = 0; + + if (!get_flat_2d_dims(*A_te, Ad0, Ad1)) { + NVTE_ERROR("ck_tile_mx_grouped_gemm: expected rank>=2 for normalized A in group ", i); + } + + if (!get_flat_2d_dims(*B_te, Bd0, Bd1)) { + NVTE_ERROR("ck_tile_mx_grouped_gemm: expected rank>=2 for normalized B in group ", i); + } + + if (!get_flat_2d_dims(*D_te, Dd0, Dd1)) { + NVTE_ERROR("ck_tile_mx_grouped_gemm: expected rank>=2 for normalized D in group ", i); + } + if (a.dptr == nullptr || b.dptr == nullptr || a_scales.dptr == nullptr || + b_scales.dptr == nullptr) { + NVTE_ERROR("ck_tile_mx_grouped_gemm: effective A/B data or scale_inv is missing."); + } + if (a_scales.shape.size() != 2 || b_scales.shape.size() != 2) { + NVTE_ERROR("ck_tile_mx_grouped_gemm: expected effective A/B scale_inv tensors to be rank-2."); + } + + const int64_t M = ctx.transA ? Ad1 : Ad0; + const int64_t K = ctx.transA ? 
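+    // Normalized op(A) must be [M, K]: with transA the flattened A buffer is [K, M]
+    // (so M = Ad1 and K = Ad0); otherwise the buffer is already [M, K].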
Ad0 : Ad1; + const int64_t N = ctx.transB ? Bd0 : Bd1; + const int64_t Kb = ctx.transB ? Bd1 : Bd0; + if (K % ScaleBlockSize != 0) { + NVTE_ERROR("ck_tile_mx_grouped_gemm: K must be a multiple of ScaleBlockSize for MX GEMM", i); + } + const int KScale = static_cast(K / ScaleBlockSize); + if (Kb != K) { + NVTE_ERROR("ck_tile_mx_grouped_gemm: K mismatch between A and B in group ", i, + ". op(A)=", M, "x", K, ", op(B)=", Kb, "x", N); + } + if (Dd0 != M || Dd1 != N) { + NVTE_ERROR("ck_tile_mx_grouped_gemm: D shape mismatch in group ", i, + ". D=", Dd0, "x", Dd1, ", expected=", M, "x", N); + } + + const ck_tile::index_t stride_A = static_cast(Ad1); + const ck_tile::index_t stride_B = static_cast(Bd1); + const ck_tile::index_t stride_E = static_cast(Dd1); + + // Pre-shuffle scale buffers for the hardware. + const int a_scale_actual_rows = static_cast(M); + const int a_scale_output_rows = + ck_tile::integer_least_multiple( + static_cast(M), + static_cast(GroupedGemKernelParam_Wmma::M_Warp_Tile)); + const int b_scale_actual_rows = static_cast(N); + const int b_scale_output_rows = static_cast(N); + const size_t a_scale_shuffled_bytes = + static_cast(a_scale_output_rows) * + static_cast(KScale) * + sizeof(AScaleType); + const size_t b_scale_shuffled_bytes = + static_cast(b_scale_output_rows) * + static_cast(KScale) * + sizeof(BScaleType); + a_scale_shuffled_bufs.push_back( + std::make_unique(a_scale_shuffled_bytes)); + b_scale_shuffled_bufs.push_back( + std::make_unique(b_scale_shuffled_bytes)); + void* a_scale_shuffled_ptr = a_scale_shuffled_bufs.back()->GetDeviceBuffer(); + void* b_scale_shuffled_ptr = b_scale_shuffled_bufs.back()->GetDeviceBuffer(); + // CK expects canonical pre-shuffled scale buffers laid out as + // A: [M, KScale] and B: [N, KScale], independent of A/B data layouts. + // TE rowwise MXFP8 scale_inv is [rows, KScale] and can be read with + // KStride=true. TE columnwise_scale_inv is [KScale, rows] and must be + // read with KStride=false before writing CK's canonical shuffled layout. 
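+    // Illustrative sizes (hypothetical): rows = 512 and K = 4096 give KScale = 128, so the
+    // rowwise scale_inv view is 512 x 128 (read with KStride=true: src[mn * KScale + k])
+    // and the columnwise view is 128 x 512 (read with KStride=false: src[k * rows + mn]).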
+ if (ctx.use_a_colwise_data) { + preShuffleScaleBuffer_gfx1250( + reinterpret_cast(a_scales.dptr), + reinterpret_cast(a_scale_shuffled_ptr), + a_scale_actual_rows, + a_scale_output_rows, + KScale, + stream); + } else { + preShuffleScaleBuffer_gfx1250( + reinterpret_cast(a_scales.dptr), + reinterpret_cast(a_scale_shuffled_ptr), + a_scale_actual_rows, + a_scale_output_rows, + KScale, + stream); + } + + if (ctx.use_b_colwise_data) { + preShuffleScaleBuffer_gfx1250( + reinterpret_cast(b_scales.dptr), + reinterpret_cast(b_scale_shuffled_ptr), + b_scale_actual_rows, + b_scale_output_rows, + KScale, + stream); + } else { + preShuffleScaleBuffer_gfx1250( + reinterpret_cast(b_scales.dptr), + reinterpret_cast(b_scale_shuffled_ptr), + b_scale_actual_rows, + b_scale_output_rows, + KScale, + stream); + } + descs.emplace_back(mx_grouped_gemm_kargs( + a.dptr, + a_scale_shuffled_ptr, + b.dptr, + b_scale_shuffled_ptr, + {/*ds_ptr*/}, + d.dptr, + 1,//kbatch + M, + N, + K, + stride_A, + stride_B, + {/*stride_Ds*/}, + stride_E)); + } + // invoke gemm + bool ok = false; + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(a_dtype, a_te_type, { + using AType = typename TETypeToCKType::type; + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(b_dtype, b_te_type, { + using BType = typename TETypeToCKType::type; + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(d_dtype, d_te_type, { + using CType = typename TETypeToCKType::type; + ok = invoke_mx_grouped_gemm(descs,ctx,s); + }); + }); + }); + return ok; +} + +} // namespace mx_grouped_gemm +} // namespace transformer_engine + +bool ck_tile_mx_grouped_gemm(const NVTETensor* A, + const NVTETensor* B, + NVTETensor* D, + int group_num, + bool transA, + bool transB, + NVTETensor* workspace, + bool accumulate, + hipStream_t stream) { + return transformer_engine::mx_grouped_gemm::ck_tile_mx_grouped_gemm( + A, B, D, group_num, transA, transB, workspace, accumulate, stream); +} diff --git a/transformer_engine/common/gemm/ck_mx_grouped_gemm/ck_mx_grouped_gemm.hpp b/transformer_engine/common/gemm/ck_mx_grouped_gemm/ck_mx_grouped_gemm.hpp new file mode 100644 index 000000000..96d3cd11b --- /dev/null +++ b/transformer_engine/common/gemm/ck_mx_grouped_gemm/ck_mx_grouped_gemm.hpp @@ -0,0 +1,16 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. 
See LICENSE for more information + ************************************************************************/ + +bool ck_tile_mx_grouped_gemm(const NVTETensor* A, + const NVTETensor* B, + NVTETensor* D, + int group_num, + bool transA, + bool transB, + NVTETensor* workspace, + bool accumulate, + hipStream_t stream); + diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 7326f330f..445a5ce0e 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -33,7 +33,8 @@ #include "./cutlass_grouped_gemm.cuh" #else #include "ck_grouped_gemm/ck_grouped_gemm.h" -#endif +#include "ck_mx_grouped_gemm/ck_mx_grouped_gemm.hpp" + #ifndef __HIP_PLATFORM_AMD__ namespace { @@ -1123,7 +1124,11 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor #else // Currently only support cutlass group gemm on Hopper Arch if (!(is_hopper && use_cutlass)) { + if (!use_cutlass) { #endif + if (warn_fallback) { + NVTE_WARN("Fallback to cuBLAS grouped GEMM."); + } cublas_path(); return; } @@ -1155,21 +1160,29 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor }; #endif +#ifdef __HIP_PLATFORM_AMD__ + auto effective_dtype = [](const transformer_engine::Tensor *t) { + if (t->has_data()) { + return t->data.dtype; + } + if (t->has_columnwise_data()) { + return t->columnwise_data.dtype; + } + return t->data.dtype; + }; +#endif + auto is_supported_dtype = [&]() -> bool { auto *inputA = transformer_engine::convertNVTETensorCheck(A[0]); auto *inputB = transformer_engine::convertNVTETensorCheck(B[0]); auto *OutputD = transformer_engine::convertNVTETensorCheck(D[0]); #ifdef __HIP_PLATFORM_AMD__ - auto A_dt = inputA->data.dtype; - auto B_dt = inputB->data.dtype; + auto A_dt = effective_dtype(inputA); + auto B_dt = effective_dtype(inputB); auto D_dt = OutputD->data.dtype; - return ( - (is_fp8_dtype(A_dt) && is_fp8_dtype(B_dt)) - ) || - ( - (A_dt == B_dt) && (A_dt == D_dt) && - (is_fp16_dtype(A_dt)) - ); + + return ((is_fp8_dtype(A_dt) && is_fp8_dtype(B_dt)) || + ((A_dt == B_dt) && (A_dt == D_dt) && is_fp16_dtype(A_dt))); #else auto A_type = get_cuda_dtype(inputA->data.dtype); auto B_type = get_cuda_dtype(inputB->data.dtype); @@ -1192,11 +1205,22 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor if (is_empty_arr(bias) && is_empty_arr(pre_gelu_out) && is_supported_dtype() && #ifdef __HIP_PLATFORM_AMD__ true) { - if (!ck_tile_grouped_gemm(A, B, D, num_gemms, transa, transb, workspace, accumulate, stream)) { - if (warn_fallback) { - NVTE_WARN("Fallback to cuBLAS grouped GEMM."); - } - cublas_path(); + const bool mxfp8_gemm = !use_fp4 && is_mxfp8_scaling(inputA->scaling_mode); + + bool handled_by_ck = false; + if (mxfp8_gemm) { + handled_by_ck = ck_tile_mx_grouped_gemm( + A, B, D, num_gemms, transa, transb, workspace, accumulate, stream); + } else { + handled_by_ck = ck_tile_grouped_gemm( + A, B, D, num_gemms, transa, transb, workspace, accumulate, stream); + } + + if (!handled_by_ck) { + if (warn_fallback) { + NVTE_WARN("Fallback to cuBLAS grouped GEMM."); + } + cublas_path(); } #else all_groups_uniform_k128(B, transb)) {