Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions tests/pytorch/test_numerics.py
Original file line number Diff line number Diff line change
Expand Up @@ -3078,6 +3078,68 @@ def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass):
os.environ.pop("NVTE_USE_CUTLASS_GROUPED_GEMM", None)


@pytest.mark.skipif(
torch.cuda.get_device_capability() != (9, 0) and not IS_HIP_EXTENSION,
reason="Only enable CUTLASS / CK grouped gemm on Hopper or ROCm",
)
@pytest.mark.parametrize("input_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("layout", ["TN", "NT"])
def test_grouped_gemm_fp32_output(input_dtype, layout):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can it be done by adding configs/parameters to test_grouped_gemm?

"""Verify grouped GEMM with fp16/bf16 inputs and fp32 output goes through
the CUTLASS / CK grouped GEMM path (not the per-expert fallback). Exercises
the dispatcher is_supported_dtype check for the common bf16/bf16/fp32 case
used during training with fp32 gradient accumulation."""
if input_dtype == torch.bfloat16 and not is_bf16_available():
pytest.skip("bf16 requires sm_80+")
torch.manual_seed(0)
z, m, k, n = 8, 1027, 128, 512

dist = torch.sort(torch.randint(0, m, (z - 1,))).values.tolist()
m_splits = (torch.tensor(dist + [m]) - torch.tensor([0] + dist)).tolist()

if layout == "TN":
A = [torch.randn(n, k, dtype=input_dtype, device="cuda") for _ in range(z)]
B = list(torch.split(torch.randn(m, k, dtype=input_dtype, device="cuda"), m_splits))
out = [torch.empty(m, n, dtype=torch.float32, device="cuda")]
out_ref = [o.clone() for o in torch.split(out[0], m_splits)]
single_output = True
grad = False
else: # "NT" wgrad: weight gradient in fp32
A = list(torch.split(torch.randn(m, k, dtype=input_dtype, device="cuda"), m_splits))
B = list(torch.split(torch.randn(m, n, dtype=input_dtype, device="cuda"), m_splits))
out = [torch.empty(n, k, dtype=torch.float32, device="cuda") for _ in range(z)]
out_ref = [o.clone() for o in out]
single_output = False
grad = True

os.environ["NVTE_USE_CUTLASS_GROUPED_GEMM"] = "1"
try:
for i in range(z):
general_gemm(
A[i], B[i],
out_dtype=torch.float32,
grad=grad,
layout=layout,
out=out_ref[i],
)
if single_output:
out_ref = [torch.cat(out_ref)]

general_grouped_gemm(
A, B, out, [None] * z,
out_dtype=torch.float32,
m_splits=m_splits,
grad=grad,
layout=layout,
single_output=single_output,
)

for o, o_ref in zip(out, out_ref):
torch.testing.assert_close(o, o_ref, rtol=1.5e-2, atol=1.5e-2)
finally:
os.environ.pop("NVTE_USE_CUTLASS_GROUPED_GEMM", None)


@pytest.mark.parametrize("N", [32])
@pytest.mark.parametrize("datatype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize(
Expand Down
5 changes: 3 additions & 2 deletions transformer_engine/common/gemm/cublaslt_gemm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1163,12 +1163,13 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor
auto A_dt = inputA->data.dtype;
auto B_dt = inputB->data.dtype;
auto D_dt = OutputD->data.dtype;
// CK fp16 dispatcher accepts D in {fp32, fp16, bf16} when A==B is fp16/bf16.
return (
(is_fp8_dtype(A_dt) && is_fp8_dtype(B_dt))
) ||
(
(A_dt == B_dt) && (A_dt == D_dt) &&
(is_fp16_dtype(A_dt))
(A_dt == B_dt) && is_fp16_dtype(A_dt) &&
(is_fp16_dtype(D_dt) || D_dt == transformer_engine::DType::kFloat32)
);
#else
auto A_type = get_cuda_dtype(inputA->data.dtype);
Expand Down
Loading