From d41657207dfc8d70d75720be9ef6b9de6ee39805 Mon Sep 17 00:00:00 2001
From: lizamd <161388580+lizamd@users.noreply.github.com>
Date: Tue, 5 May 2026 00:02:50 +0000
Subject: [PATCH] [ROCm] Allow bf16/bf16/fp32 in nvte_multi_tensor_gemm dispatcher

The is_supported_dtype check in nvte_multi_tensor_gemm previously required
A == B == D for the fp16/bf16 path, which rejected the common bf16/bf16/fp32
case where the GEMM output is fp32 for gradient accumulation. This forced a
fallback to multi_stream_cublas_gemm (a per-expert hipblaslt loop), bypassing
the CK grouped GEMM kernel entirely on ROCm.

The CK FP16 dispatcher (ck_tile_grouped_gemm_fp16_dispatch) already supports
an independent D dtype via TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY (fp32,
fp16, bf16); the wrapper check was the only thing preventing it from being
reached. The check is now relaxed to require A == B in {fp16, bf16} and D in
{fp32, fp16, bf16}, which matches what the CK dispatcher actually accepts.
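
For illustration, the relaxed gate in isolation (a standalone, hypothetical
sketch; the real check lives inside nvte_multi_tensor_gemm and uses the
transformer_engine::DType helpers):

    // Sketch only: enum values and helper names mirror transformer_engine,
    // but this snippet is not part of the patch.
    enum class DType { kFloat32, kFloat16, kBFloat16, kFloat8E4M3, kFloat8E5M2 };

    static bool is_fp16_dtype(DType t) {  // true for both fp16 and bf16
      return t == DType::kFloat16 || t == DType::kBFloat16;
    }
    static bool is_fp8_dtype(DType t) {
      return t == DType::kFloat8E4M3 || t == DType::kFloat8E5M2;
    }

    // Before: (A_dt == B_dt) && (A_dt == D_dt) && is_fp16_dtype(A_dt)
    //         -> rejects bf16/bf16/fp32.
    // After:  A == B in {fp16, bf16} and D in {fp32, fp16, bf16}.
    static bool is_supported_dtype(DType A_dt, DType B_dt, DType D_dt) {
      return (is_fp8_dtype(A_dt) && is_fp8_dtype(B_dt)) ||
             ((A_dt == B_dt) && is_fp16_dtype(A_dt) &&
              (is_fp16_dtype(D_dt) || D_dt == DType::kFloat32));
    }
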
Verified on Qwen3-30B-A3B MoE training on MI355X (gfx950): the fallback
warning rate drops from ~1040/step (every GEMM) to ~28/step (the ~3% of
shapes that the CK kernel itself rejects via Kernel::IsSupportedArgument).
Throughput is essentially unchanged in this workload because hipblaslt's
per-shape autotuning happens to be competitive with the hardcoded CK tile
configs for these MoE shapes; the gain should materialize once the CK
dispatcher gains more tile configs (or shape-aware tile selection by
aggregate M).

This is a CUDA-path file; the same patch applies to the AMD path via hipify.
There is no CUDA-side behavior change, since cuBLAS/CUTLASS dispatch on
NVIDIA still requires A == B == D in the CUTLASS fast-path pre-conditions.

Follow-ups (out of scope for this PR):
- Add more CK tile configs (e.g. TileCfg_64x256x64, TileCfg_128x256x64) and
  shape-aware tile selection by aggregate M per call. The three hardcoded
  configs (TileCfg_256x256x64, TileCfg_256x128x64,
  TileCfg_256x128x64_padding) are not an optimal fit for MoE shapes with
  highly variable per-expert M; real CK grouped GEMM perf wins will
  materialize once tile selection adapts to M (see the sketch at the end of
  this patch).
- Investigate the ~3% of GEMMs that hit Kernel::IsSupportedArgument
  rejection (likely small per-expert M values that fail tile-size
  constraints in the current TileCfg_256x* instantiations).
---
 tests/pytorch/test_numerics.py               | 62 +++++++++++++++++++
 .../common/gemm/cublaslt_gemm.cu             |  5 +-
 2 files changed, 65 insertions(+), 2 deletions(-)

diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py
index 4a768377e..33847f5b6 100644
--- a/tests/pytorch/test_numerics.py
+++ b/tests/pytorch/test_numerics.py
@@ -3078,6 +3078,68 @@ def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass):
         os.environ.pop("NVTE_USE_CUTLASS_GROUPED_GEMM", None)
 
 
+@pytest.mark.skipif(
+    torch.cuda.get_device_capability() != (9, 0) and not IS_HIP_EXTENSION,
+    reason="Only enable CUTLASS / CK grouped gemm on Hopper or ROCm",
+)
+@pytest.mark.parametrize("input_dtype", [torch.bfloat16, torch.float16])
+@pytest.mark.parametrize("layout", ["TN", "NT"])
+def test_grouped_gemm_fp32_output(input_dtype, layout):
+    """Verify grouped GEMM with fp16/bf16 inputs and fp32 output goes through
+    the CUTLASS / CK grouped GEMM path (not the per-expert fallback). Exercises
+    the dispatcher is_supported_dtype check for the common bf16/bf16/fp32 case
+    used during training with fp32 gradient accumulation."""
+    if input_dtype == torch.bfloat16 and not is_bf16_available():
+        pytest.skip("bf16 requires sm_80+")
+    torch.manual_seed(0)
+    z, m, k, n = 8, 1027, 128, 512
+
+    dist = torch.sort(torch.randint(0, m, (z - 1,))).values.tolist()
+    m_splits = (torch.tensor(dist + [m]) - torch.tensor([0] + dist)).tolist()
+
+    if layout == "TN":
+        A = [torch.randn(n, k, dtype=input_dtype, device="cuda") for _ in range(z)]
+        B = list(torch.split(torch.randn(m, k, dtype=input_dtype, device="cuda"), m_splits))
+        out = [torch.empty(m, n, dtype=torch.float32, device="cuda")]
+        out_ref = [o.clone() for o in torch.split(out[0], m_splits)]
+        single_output = True
+        grad = False
+    else:  # "NT" wgrad: weight gradient in fp32
+        A = list(torch.split(torch.randn(m, k, dtype=input_dtype, device="cuda"), m_splits))
+        B = list(torch.split(torch.randn(m, n, dtype=input_dtype, device="cuda"), m_splits))
+        out = [torch.empty(n, k, dtype=torch.float32, device="cuda") for _ in range(z)]
+        out_ref = [o.clone() for o in out]
+        single_output = False
+        grad = True
+
+    os.environ["NVTE_USE_CUTLASS_GROUPED_GEMM"] = "1"
+    try:
+        for i in range(z):
+            general_gemm(
+                A[i], B[i],
+                out_dtype=torch.float32,
+                grad=grad,
+                layout=layout,
+                out=out_ref[i],
+            )
+        if single_output:
+            out_ref = [torch.cat(out_ref)]
+
+        general_grouped_gemm(
+            A, B, out, [None] * z,
+            out_dtype=torch.float32,
+            m_splits=m_splits,
+            grad=grad,
+            layout=layout,
+            single_output=single_output,
+        )
+
+        for o, o_ref in zip(out, out_ref):
+            torch.testing.assert_close(o, o_ref, rtol=1.5e-2, atol=1.5e-2)
+    finally:
+        os.environ.pop("NVTE_USE_CUTLASS_GROUPED_GEMM", None)
+
+
 @pytest.mark.parametrize("N", [32])
 @pytest.mark.parametrize("datatype", [torch.float16, torch.bfloat16])
 @pytest.mark.parametrize(
diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu
index 7326f330f..0a19ef1e3 100644
--- a/transformer_engine/common/gemm/cublaslt_gemm.cu
+++ b/transformer_engine/common/gemm/cublaslt_gemm.cu
@@ -1163,12 +1163,13 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor
     auto A_dt = inputA->data.dtype;
     auto B_dt = inputB->data.dtype;
     auto D_dt = OutputD->data.dtype;
+    // CK fp16 dispatcher accepts D in {fp32, fp16, bf16} when A == B is fp16/bf16.
     return (
       (is_fp8_dtype(A_dt) && is_fp8_dtype(B_dt))
     ) || (
-      (A_dt == B_dt) && (A_dt == D_dt) &&
-      (is_fp16_dtype(A_dt))
+      (A_dt == B_dt) && is_fp16_dtype(A_dt) &&
+      (is_fp16_dtype(D_dt) || D_dt == transformer_engine::DType::kFloat32)
     );
 #else
     auto A_type = get_cuda_dtype(inputA->data.dtype);
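
Appendix (illustration, not part of the patch): a rough sketch of the
shape-aware tile selection by aggregate M mentioned in the follow-ups. All
names and thresholds here are hypothetical; a real implementation would map
onto the CK TileCfg_* instantiations and still validate each choice against
Kernel::IsSupportedArgument.

    #include <cstdint>
    #include <numeric>
    #include <vector>

    enum class TileCfg { k256x256x64, k256x128x64, k128x256x64, k64x256x64 };

    // Pick a tile config from the aggregate M across experts: a large total
    // amortizes wide M-tiles, while many small per-expert M values favour
    // narrower tiles (fewer padded rows per group).
    TileCfg select_tile_cfg(const std::vector<std::int64_t> &m_splits) {
      const std::int64_t total_m =
          std::accumulate(m_splits.begin(), m_splits.end(), std::int64_t{0});
      if (total_m >= 8192) return TileCfg::k256x256x64;
      if (total_m >= 2048) return TileCfg::k256x128x64;
      if (total_m >= 512) return TileCfg::k128x256x64;
      return TileCfg::k64x256x64;
    }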