diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py
index 4a768377e..33847f5b6 100644
--- a/tests/pytorch/test_numerics.py
+++ b/tests/pytorch/test_numerics.py
@@ -3078,6 +3078,68 @@ def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass):
     os.environ.pop("NVTE_USE_CUTLASS_GROUPED_GEMM", None)
 
 
+@pytest.mark.skipif(
+    torch.cuda.get_device_capability() != (9, 0) and not IS_HIP_EXTENSION,
+    reason="Only enable CUTLASS / CK grouped gemm on Hopper or ROCm",
+)
+@pytest.mark.parametrize("input_dtype", [torch.bfloat16, torch.float16])
+@pytest.mark.parametrize("layout", ["TN", "NT"])
+def test_grouped_gemm_fp32_output(input_dtype, layout):
+    """Verify grouped GEMM with fp16/bf16 inputs and fp32 output goes through
+    the CUTLASS / CK grouped GEMM path (not the per-expert fallback). Exercises
+    the dispatcher is_supported_dtype check for the common bf16/bf16/fp32 case
+    used during training with fp32 gradient accumulation."""
+    if input_dtype == torch.bfloat16 and not is_bf16_available():
+        pytest.skip("bf16 requires sm_80+")
+    torch.manual_seed(0)
+    z, m, k, n = 8, 1027, 128, 512
+
+    dist = torch.sort(torch.randint(0, m, (z - 1,))).values.tolist()
+    m_splits = (torch.tensor(dist + [m]) - torch.tensor([0] + dist)).tolist()
+
+    if layout == "TN":
+        A = [torch.randn(n, k, dtype=input_dtype, device="cuda") for _ in range(z)]
+        B = list(torch.split(torch.randn(m, k, dtype=input_dtype, device="cuda"), m_splits))
+        out = [torch.empty(m, n, dtype=torch.float32, device="cuda")]
+        out_ref = [o.clone() for o in torch.split(out[0], m_splits)]
+        single_output = True
+        grad = False
+    else:  # "NT" wgrad: weight gradient in fp32
+        A = list(torch.split(torch.randn(m, k, dtype=input_dtype, device="cuda"), m_splits))
+        B = list(torch.split(torch.randn(m, n, dtype=input_dtype, device="cuda"), m_splits))
+        out = [torch.empty(n, k, dtype=torch.float32, device="cuda") for _ in range(z)]
+        out_ref = [o.clone() for o in out]
+        single_output = False
+        grad = True
+
+    os.environ["NVTE_USE_CUTLASS_GROUPED_GEMM"] = "1"
+    try:
+        for i in range(z):
+            general_gemm(
+                A[i], B[i],
+                out_dtype=torch.float32,
+                grad=grad,
+                layout=layout,
+                out=out_ref[i],
+            )
+        if single_output:
+            out_ref = [torch.cat(out_ref)]
+
+        general_grouped_gemm(
+            A, B, out, [None] * z,
+            out_dtype=torch.float32,
+            m_splits=m_splits,
+            grad=grad,
+            layout=layout,
+            single_output=single_output,
+        )
+
+        for o, o_ref in zip(out, out_ref):
+            torch.testing.assert_close(o, o_ref, rtol=1.5e-2, atol=1.5e-2)
+    finally:
+        os.environ.pop("NVTE_USE_CUTLASS_GROUPED_GEMM", None)
+
+
 @pytest.mark.parametrize("N", [32])
 @pytest.mark.parametrize("datatype", [torch.float16, torch.bfloat16])
 @pytest.mark.parametrize(
diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu
index 7326f330f..0a19ef1e3 100644
--- a/transformer_engine/common/gemm/cublaslt_gemm.cu
+++ b/transformer_engine/common/gemm/cublaslt_gemm.cu
@@ -1163,12 +1163,13 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor
   auto A_dt = inputA->data.dtype;
   auto B_dt = inputB->data.dtype;
   auto D_dt = OutputD->data.dtype;
+  // CK fp16 dispatcher accepts D in {fp32, fp16, bf16} when A==B is fp16/bf16.
   return (
     (is_fp8_dtype(A_dt) && is_fp8_dtype(B_dt))
   ) || (
-    (A_dt == B_dt) && (A_dt == D_dt) &&
-    (is_fp16_dtype(A_dt))
+    (A_dt == B_dt) && is_fp16_dtype(A_dt) &&
+    (is_fp16_dtype(D_dt) || D_dt == transformer_engine::DType::kFloat32)
   );
 
 #else
   auto A_type = get_cuda_dtype(inputA->data.dtype);