[kernels] wvSplitK: pick (YTILE=1, UNRL=4) for gfx11 K<=2048 #925
Open
mgehre-amd wants to merge 1 commit into gfx11
For the bf16/fp16 wvSplitK path, the generic dispatcher in
WVSPLIT_TILE_CFG picks a suboptimal tile/unroll configuration on
gfx11 when K_in <= 2048. A sweep over (YTILE, UNRL) on the
Qwen3.5-35B-A3B shapes (M=248320 lm_head and M=1024 attention proj)
shows that (YTILE=1, UNRL=4) is consistently best across N=1..4 in
both bf16 and fp16, with speedups ranging from 1.05x to 1.78x.
Add a gfx11-only specialization gated on K_in <= 2048 above the
generic dispatch branches. The K constraint is intentionally
narrow so that it only triggers in the targeted regime. Bench
verification on the existing shape set (Qwen3-4B, Qwen2.5VL-7B
qkv/o_proj/gate_up/down/lm_head, K in {2560, 3584, 4096, 9728,
18944}) confirms that no shape regresses in either dtype: the worst
case is +0.6% on a single dtype, which is within run-to-run noise.
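
The selection order described above can be modeled in a few lines of Python. This is only a sketch: the real logic lives in the WVSPLIT_TILE_CFG macro in csrc/rocm/skinny_gemms.cu, which is not reproduced here, and the names `pick_tile_cfg`, `is_gfx11`, and `placeholder_generic_cfg` are hypothetical. In particular, the generic fallback below is a stand-in value, not the actual heuristic.

```python
from typing import Callable, Tuple

TileCfg = Tuple[int, int]  # (YTILE, UNRL)


def placeholder_generic_cfg(k_in: int, n: int) -> TileCfg:
    """Stand-in for the generic heuristic; NOT the real dispatcher logic."""
    return (2, 2)


def pick_tile_cfg(
    is_gfx11: bool,
    k_in: int,
    n: int,
    generic_cfg: Callable[[int, int], TileCfg] = placeholder_generic_cfg,
) -> TileCfg:
    # The gfx11 specialization is checked first, so it takes precedence
    # over the generic branches for every N when K_in <= 2048.
    if is_gfx11 and k_in <= 2048:
        return (1, 4)
    # All other architectures and larger K fall through unchanged.
    return generic_cfg(k_in, n)


print(pick_tile_cfg(True, 2048, 1))   # specialization applies
print(pick_tile_cfg(True, 2560, 1))   # falls through to the generic path
```

Note that a predicate of this form covers every K_in up to 2048, while the shapes benchmarked in this description all have K=2048.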
Changes:
- csrc/rocm/skinny_gemms.cu: extend WVSPLIT_TILE_CFG with a gfx11
K<=2048 branch picking (YTILE=1, UNRL=4) for all N=1..4.
- tests/kernels/quantization/bench_rocm_skinny_gemm_bf16.py:
add the Qwen3.5-35B-A3B lm_head shape (M=248320, K=2048) and
the M=1024 attention proj shape.
Bench results (gfx1151 / Strix Halo, A/B vs origin/gfx11):
248320x2048 N=1 bf16: 4599 -> 4348 us (-5.4%, 206 -> 218 GiB/s)
248320x2048 N=2 bf16: 4977 -> 4290 us (-13.8%, 190 -> 221 GiB/s)
248320x2048 N=3 bf16: 4881 -> 4282 us (-12.3%, 194 -> 221 GiB/s)
248320x2048 N=4 bf16: 4806 -> 4291 us (-10.7%, 197 -> 221 GiB/s)
1024x2048 N=1 bf16: 40 -> 30 us (-25.7%, 96 -> 130 GiB/s)
1024x2048 N=2 bf16: 52 -> 31 us (-40.9%, 75 -> 128 GiB/s)
1024x2048 N=3 bf16: 52 -> 31 us (-40.0%, 75 -> 125 GiB/s)
1024x2048 N=4 bf16: 54 -> 31 us (-43.2%, 72 -> 127 GiB/s)
fp16 numbers within ~1% of bf16 across all the above.
All other shapes: within +/-1% in both dtypes (noise).
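
For readers who want to check the derived columns, here is a small sketch of how the percentage and GiB/s figures can be recomputed from the timings. It assumes that bandwidth counts only the M x K weight matrix at 2 bytes per element (bf16/fp16); that assumption is mine, not stated in the description, but it reproduces the reported lm_head figures. The helper names are hypothetical.

```python
def bandwidth_gib_s(m: int, k: int, time_us: float, bytes_per_elem: int = 2) -> float:
    """Effective bandwidth, counting only the M x K weight matrix (assumption)."""
    weight_bytes = m * k * bytes_per_elem
    return weight_bytes / (time_us * 1e-6) / 2**30


def pct_change(before_us: float, after_us: float) -> float:
    """Relative change in runtime; negative means faster."""
    return (after_us - before_us) / before_us * 100.0


# 248320x2048 N=1 bf16, before and after
print(round(bandwidth_gib_s(248320, 2048, 4599)))  # 206
print(round(bandwidth_gib_s(248320, 2048, 4348)))  # 218

# 248320x2048 N=2 bf16
print(round(pct_change(4977, 4290), 1))  # -13.8
```

Because the table lists timings rounded to whole microseconds, recomputed values for the small 1024x2048 shape can differ from the reported ones by a few percent.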
Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>
944e264 to 64a7708
eble-amd approved these changes on May 8, 2026 and left a comment:
A very minor concern is that the condition in the code seems more general than what the description says was measured.