[ROCm] Allow bf16/bf16/fp32 in nvte_multi_tensor_gemm dispatcher #573
lizamd wants to merge 1 commit into
Conversation
Force-pushed from 764cb65 to ff19241
@matthiasdiener @aris134 Could you review this PR?
wangye805 left a comment
Can you edit an existing test, or add a new one, showing that with your change bf16/fp16 inputs with fp32 outputs now go through the CK flow correctly? Also, please paste some benchmarking data to this ticket for future reference.
// CK FP16/BF16 grouped GEMM dispatcher (ck_tile_grouped_gemm_fp16_dispatch)
// already supports independent D dtype via TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY
// (fp32, fp16, bf16). The previous check required A==B==D, which incorrectly
// rejected the common bf16/bf16/fp32 case (training with fp32 gradient
// accumulation), forcing a fallback to the per-expert hipblaslt loop.
// Relaxed to require A==B in fp16/bf16 and D in {fp32, fp16, bf16}.
I think this explanation may be better suited for the PR description rather than an inline code comment.
aris134 left a comment
Agreed that the CK dispatch logic supports the bf16/fp32 combination. I would remove the detailed history comment about the previous fallback behavior, which is better suited to the PR description itself.
AMD General
+ @kang, ***@***.***>
The is_supported_dtype check in nvte_multi_tensor_gemm previously required
A==B==D for the fp16/bf16 path, which rejected the common bf16/bf16/fp32
case where the GEMM output is fp32 for gradient accumulation. This forced
a fallback to multi_stream_cublas_gemm (a per-expert hipblaslt loop),
bypassing the CK grouped GEMM kernel entirely on ROCm.
The CK FP16 dispatcher (ck_tile_grouped_gemm_fp16_dispatch) already
supports independent D dtype via TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY
(fp32, fp16, bf16). The wrapper check is the only thing that prevents it
from being reached.
Relaxed to require A==B in fp16/bf16 and D in {fp32, fp16, bf16}, which
matches what the CK dispatcher actually accepts. Verified on Qwen3-30B-A3B
MoE training on MI355X (gfx950): fallback warning rate drops from
~1040/step (every GEMM) to ~28/step (~3% of shapes that the CK kernel
itself rejects via Kernel::IsSupportedArgument). Throughput is essentially
unchanged in this workload because hipblaslt's per-shape autotuning
happens to be competitive with the hardcoded CK tile configs for these
MoE shapes; the gain will materialize once the CK dispatcher gains more
tile configs (or shape-aware tile selection by aggregate M).
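For concreteness, a minimal standalone sketch of the relaxed predicate (the enum here is a local stand-in for TE's dtype enum, and the body paraphrases the condition described above rather than quoting the literal patch):

```cpp
#include <cstdio>

// Local stand-in for transformer_engine's dtype enum; the real dispatcher
// operates on the library's own DType values.
enum class DType { kFloat32, kFloat16, kBFloat16 };

// Relaxed check: A and B must be the same 16-bit float type, while D may
// independently be fp32, fp16, or bf16 -- the set the CK dispatcher's
// D-dtype type switch handles.
bool is_supported_dtype(DType a, DType b, DType d) {
  const bool ab_ok = (a == b) && (a == DType::kFloat16 || a == DType::kBFloat16);
  const bool d_ok = (d == DType::kFloat32) || (d == DType::kFloat16) ||
                    (d == DType::kBFloat16);
  return ab_ok && d_ok;  // old check: ab_ok && (a == d)
}

int main() {
  // bf16/bf16/fp32 (fp32 gradient accumulation) now passes the check.
  std::printf("%d\n", is_supported_dtype(DType::kBFloat16, DType::kBFloat16,
                                         DType::kFloat32));  // prints 1
}
```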
This is a CUDA path file; the same patch applies to the AMD path via
hipify. No CUDA-side behavior change since cuBLAS/cutlass dispatch on
NVIDIA still requires A==B==D in the cutlass fast-path pre-conditions.
Follow-ups (out of scope for this PR):
- Add more CK tile configs (e.g. TileCfg_64x256x64, TileCfg_128x256x64)
and shape-aware tile selection by aggregate M per call. Currently
throughput is unchanged on this workload because the existing hipblaslt
fallback is well-tuned and the 3 hardcoded CK tile configs
(TileCfg_256x256x64, TileCfg_256x128x64, TileCfg_256x128x64_padding)
don't fit MoE shapes (highly variable per-expert M) optimally. Real
CK-grouped-GEMM perf wins will materialize once tile selection adapts
to M; a rough sketch of such selection follows this list.
- Investigate the ~3% of GEMMs that hit Kernel::IsSupportedArgument
rejection (likely small per-expert M values that fail tile-size
constraints in the current TileCfg_256x* instantiations).
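A rough sketch of what shape-aware tile selection keyed on aggregate M could look like. Thresholds are illustrative and untuned, and the 128x256x64 / 64x256x64 configs are the proposed additions, not existing instantiations:

```cpp
#include <cstdio>

// Hypothetical tile-config chooser keyed on aggregate M (the sum of
// per-expert M values in one grouped-GEMM call).
enum class TileCfg {
  k256x256x64,  // existing TileCfg_256x256x64
  k256x128x64,  // existing TileCfg_256x128x64 (+ padding variant)
  k128x256x64,  // proposed
  k64x256x64    // proposed
};

TileCfg select_tile_cfg(long aggregate_m) {
  if (aggregate_m >= 8192) return TileCfg::k256x256x64;  // large M: widest M tile
  if (aggregate_m >= 2048) return TileCfg::k256x128x64;
  if (aggregate_m >= 512)  return TileCfg::k128x256x64;
  return TileCfg::k64x256x64;  // small or ragged per-expert M
}

int main() {
  // Experts summing to M = 768 would pick the proposed 128x256x64 tile.
  std::printf("%d\n", static_cast<int>(select_tile_cfg(768)));
}
```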
Force-pushed from ff19241 to d416572
@wangye805 @aris134 Could you check the new commit?
)
@pytest.mark.parametrize("input_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("layout", ["TN", "NT"])
def test_grouped_gemm_fp32_output(input_dtype, layout):
Can it be done by adding configs/parameters to test_grouped_gemm?