
[kernels] wvSplitK: pick (YTILE=1, UNRL=4) for gfx11 K<=2048#925

Open
mgehre-amd wants to merge 1 commit into gfx11 from matthias.wvsplitk-gfx11-n1-k2048

Conversation


@mgehre-amd mgehre-amd commented May 7, 2026

For the bf16/fp16 wvSplitK path, the generic dispatcher in WVSPLIT_TILE_CFG picks a suboptimal tile/unroll configuration on gfx11 when K_in <= 2048. A sweep over (YTILE, UNRL) for the Qwen3.5-35B-A3B lm_head shape (M=248320, K=2048, N=1..4) shows that (YTILE=1, UNRL=4) is consistently best, beating the heuristic-picked configs by 5-14% across N=1..4 in both bf16 and fp16.
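
The sweep described above amounts to timing every candidate (YTILE, UNRL) pair for a fixed shape and keeping the fastest. A toy sketch of that selection step (the timings dict below is illustrative, not measured data, except that the 4290/4977 us values echo the N=2 bench numbers reported later):

```python
# Toy sketch of the (YTILE, UNRL) sweep selection: given measured kernel
# times per candidate config, pick the fastest. Illustrative only; the
# real sweep times the HIP kernel on-device.

def best_config(timings_us: dict[tuple[int, int], float]) -> tuple[int, int]:
    """Return the (YTILE, UNRL) pair with the lowest measured time."""
    return min(timings_us, key=timings_us.get)

# Illustrative sweep result for one shape (times in microseconds):
sweep = {(1, 4): 4290.0, (2, 2): 4977.0, (2, 4): 4800.0, (4, 2): 5100.0}
```
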

Add a gfx11-only specialization gated on K_in <= 2048 above the generic dispatch branches. The K constraint is intentionally narrow so that it only triggers in the targeted regime. Bench verification on the existing shape set (Qwen3-4B, Qwen2.5VL-7B qkv/o_proj/gate_up/down/lm_head, K in {2560, 3584, 4096, 9728, 18944}) confirms that no shape regresses in either dtype; the largest delta is +0.6% on a single dtype, within run-to-run noise.

Changes:

  • csrc/rocm/skinny_gemms.cu: extend WVSPLIT_TILE_CFG with a gfx11 K<=2048 branch picking (YTILE=1, UNRL=4) for all N=1..4.
  • tests/kernels/quantization/bench_rocm_skinny_gemm_bf16.py: add the Qwen3.5-35B-A3B lm_head shape (M=248320, K=2048) and the M=1024 attention proj shape.
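
The shape of the dispatch change can be modeled roughly as follows. This is a hypothetical Python rendering, not the actual C++ macro: the real logic lives in the WVSPLIT_TILE_CFG macro in csrc/rocm/skinny_gemms.cu, and the generic-branch configs below are illustrative placeholders, not the real gfx11 heuristic.

```python
# Hypothetical model of the tile-config selection described above.
# Only the gfx11 K_in <= 2048 branch reflects the PR; the generic
# fallback values are made-up placeholders.

def pick_tile_cfg(k_in: int, n: int, is_gfx11: bool) -> tuple[int, int]:
    """Return (YTILE, UNRL) for the wvSplitK kernel."""
    # gfx11-only specialization: for K_in <= 2048, (YTILE=1, UNRL=4)
    # won the sweep for every N in 1..4, so it shadows the generic
    # dispatch branches.
    if is_gfx11 and k_in <= 2048:
        return (1, 4)
    # Placeholder generic dispatch (illustrative only).
    if n <= 2:
        return (2, 2)
    return (4, 2)
```
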

Bench results (gfx1151 / Strix Halo, A/B vs origin/gfx11):
  248320x2048 N=1 bf16: 4599 -> 4348 us (-5.4%, 206 -> 218 GiB/s)
  248320x2048 N=2 bf16: 4977 -> 4290 us (-13.8%, 190 -> 221 GiB/s)
  248320x2048 N=3 bf16: 4881 -> 4282 us (-12.3%, 194 -> 221 GiB/s)
  248320x2048 N=4 bf16: 4806 -> 4291 us (-10.7%, 197 -> 221 GiB/s)
    1024x2048 N=1 bf16:   40 ->   30 us (-25.7%,  96 -> 130 GiB/s)
    1024x2048 N=2 bf16:   52 ->   31 us (-40.9%,  75 -> 128 GiB/s)
    1024x2048 N=3 bf16:   52 ->   31 us (-40.0%,  75 -> 125 GiB/s)
    1024x2048 N=4 bf16:   54 ->   31 us (-43.2%,  72 -> 127 GiB/s)
  fp16 numbers within ~1% of bf16 across all the above.
  All other shapes: within +/-1% in both dtypes (noise).
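
The GiB/s figures above can be reproduced from the timings: wvSplitK is bandwidth-bound, so effective bandwidth is the bytes moved (weight, activation, and output matrices at 2 bytes per bf16/fp16 element) over the kernel time. The exact byte-accounting used by the benchmark script is an assumption, but it matches the reported numbers:

```python
# Sketch of the effective-bandwidth calculation behind the GiB/s
# columns: (M*K + K*N + M*N) elements, 2 bytes each for bf16/fp16,
# divided by kernel time. The accounting is inferred, not copied
# from the benchmark script.

def effective_gib_per_s(m: int, k: int, n: int, seconds: float) -> float:
    elems = m * k + k * n + m * n   # weight + activation + output
    return elems * 2 / seconds / 2**30
```

For example, the 248320x2048 N=1 result (4348 us) works out to roughly 218 GiB/s, matching the table.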

Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>
@mgehre-amd force-pushed the matthias.wvsplitk-gfx11-n1-k2048 branch from 944e264 to 64a7708 on May 8, 2026 06:23
@mgehre-amd requested a review from eble-amd on May 8, 2026 06:24

@eble-amd left a comment


A very minor concern is that the condition in the code seems more general than what the description says was measured.
