
Tune hybrid_triton_w4a16 prefill kernel for gfx1151 #879

Draft
mgehre-amd wants to merge 3 commits into gfx11 from matthias.tune-hybrid-w4a16-prefill

Conversation

@mgehre-amd

Summary

  • Add num_stages=1 to the Triton kernel launch and update the M>1024 tile heuristic from BM=128,BN=64,BK=64,w=8 to BM=64,BN=256,BK=64,w=8 for better L2 reuse
  • Add UNROLL_K parameter that statically unrolls multiple BLOCK_K tiles per outer loop iteration using tl.static_range, amortising loop overhead
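
A minimal sketch of the UNROLL_K restructuring (the actual hybrid_triton_w4a16 kernel is not reproduced here; this is a plain fp16 Triton GEMM skeleton with assumed names and no boundary masking, whereas the real kernel also unpacks and dequantizes the int4 weights inside the unrolled body):

```python
import triton
import triton.language as tl


@triton.jit
def gemm_unrolled_k(
    a_ptr, b_ptr, c_ptr, M, N, K,
    stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
    UNROLL_K: tl.constexpr,
):
    # Assumes M, N are multiples of the block sizes and K is a multiple of
    # BLOCK_K * UNROLL_K, so no boundary masks are needed.
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    rk = tl.arange(0, BLOCK_K)

    a_ptrs = a_ptr + rm[:, None] * stride_am + rk[None, :] * stride_ak
    b_ptrs = b_ptr + rk[:, None] * stride_bk + rn[None, :] * stride_bn
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    # Each outer iteration advances UNROLL_K * BLOCK_K along K; the inner
    # tl.static_range loop is fully unrolled at compile time, exposing
    # UNROLL_K independent load/dot sequences to the instruction scheduler.
    for _ in range(0, tl.cdiv(K, BLOCK_K * UNROLL_K)):
        for u in tl.static_range(UNROLL_K):
            a = tl.load(a_ptrs + u * BLOCK_K * stride_ak)
            b = tl.load(b_ptrs + u * BLOCK_K * stride_bk)
            acc += tl.dot(a, b)
        a_ptrs += UNROLL_K * BLOCK_K * stride_ak
        b_ptrs += UNROLL_K * BLOCK_K * stride_bk

    c_ptrs = c_ptr + rm[:, None] * stride_cm + rn[None, :] * stride_cn
    tl.store(c_ptrs, acc.to(tl.float16))
```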

Benchmarks

Qwen2.5-VL-7B-Instruct-AWQ on gfx1151 (Strix Halo), M=1606, K=3584, group_size=128, symmetric.

Kernel-level (profiled, --enforce-eager)

| Kernel shape | Before (ms) | After (ms) | Speedup |
| --- | --- | --- | --- |
| 1606x37888x3584 (qkv+gate_up) | 18.57 (23.5 TFLOPS) | 17.34 (25.2 TFLOPS) | +7% |
| 1606x3584x18944 (down_proj) | 10.37 (21.0 TFLOPS) | 9.49 (23.0 TFLOPS) | +9% |
| 1606x4608x3584 (o_proj) | 2.17 (24.5 TFLOPS) | 2.21 (24.0 TFLOPS) | ~0% |
| 1606x3584x3584 (q_proj) | 1.69 (24.4 TFLOPS) | 1.66 (24.9 TFLOPS) | +2% |

End-to-end TTFT

| | TTFT (median) |
| --- | --- |
| Before | 1287 ms |
| After | 1243 ms |
| Improvement | 44 ms (3.4%) |

Other ideas tested (not adopted)

  • OPTIMIZE_EPILOGUE=1: mixed results, interferes with num_stages=1
  • Eliminate tl.trans(): strided loads 5-8% slower
  • Alternative unpacking (reshape-broadcast vs interleave): 3% slower
  • @triton.autotune: heuristic already matches optimal config
  • Split-K: no shape-specific lag to justify complexity

Test plan

  • pytest tests/kernels/quantization/test_hybrid_w4a16_triton.py (20/20 pass)
  • Kernel-level benchmarks (sweep_prefill_int4.py)
  • End-to-end TTFT (vllm-bench.py with Qwen2.5-VL-7B-Instruct-AWQ)

Add num_stages=1 to the Triton kernel launch and update the M>1024
tile heuristic from BM=128,BN=64,BK=64,w=8 to BM=64,BN=256,BK=64,w=8.

The num_stages=1 parameter was missing (all other ROCm Triton kernels
in vLLM use it) and enables the compiler to pipeline global memory
loads more effectively.

The wider N-tile (256 vs 64) improves L2 cache reuse and CU occupancy
for prefill-sized GEMMs (M=1000-2000).
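
A minimal sketch of the heuristic's shape (select_tile_config and small_m_config are assumed names for illustration, not the actual vLLM code; only the changed M>1024 branch is spelled out, and the small-M branch is untouched by this commit):

```python
def select_tile_config(M: int, small_m_config: dict) -> dict:
    """Illustrative tile/launch-parameter selection, not the actual vLLM helper."""
    if M > 1024:
        # Prefill-sized GEMMs: wide BLOCK_N for L2 reuse and CU occupancy,
        # launched with num_stages=1 like the other ROCm Triton kernels.
        return dict(BLOCK_M=64, BLOCK_N=256, BLOCK_K=64,
                    num_warps=8, num_stages=1)
    # Decode-adjacent sizes keep the pre-existing tiling.
    return small_m_config
```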

Benchmarked on gfx1151 with Qwen2.5-VL-7B-Instruct-AWQ shapes
(M=1606, K=3584, group_size=128, symmetric):
  37888x3584: 19.0ms -> 17.5ms (+8%)
  18944x3584: 9.6ms -> 8.7ms  (+10%)
   4608x3584: 2.4ms -> 2.1ms  (+12%)
   3584x3584: 1.8ms -> 1.6ms  (+12%)

Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>

Add UNROLL_K parameter that statically unrolls multiple BLOCK_K tiles
per outer loop iteration using tl.static_range. This amortises loop
overhead and gives the Triton compiler more instruction scheduling
freedom across independent load/unpack/dequant/dot sequences.

For M > 1024 (prefill), UNROLL_K=4 processes 4 groups per outer
iteration. For smaller M (decode-adjacent sizes), UNROLL_K=1
preserves the current behavior to avoid register pressure.

Benchmarked on gfx1151 with Qwen2.5-VL-7B-Instruct-AWQ shapes
(M=1606, K=3584, group_size=128, BM=64, BN=256, BK=64, w=8):
  37888x3584: 17.1ms -> 16.4ms (+4%)
   4608x3584: 2.0ms -> 1.97ms (+3%)
   3584x3584: 1.6ms -> 1.58ms (+2%)
End-to-end TTFT: 1277ms -> 1243ms (2.7%)

Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>

Add test_triton_w4a16_prefill_perf_regression that benchmarks the
Triton W4A16 kernel on representative prefill shapes (M=1606) from
Qwen2.5-7B and Qwen3-4B models and asserts TFLOPS stays within 5%
of hardcoded reference values measured on gfx1151.

Reference values: 24.5-26.0 TFLOPS depending on shape, measured
with num_stages=1, UNROLL_K=4, BM=64/BN=256/BK=64/w=8.
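
A rough sketch of the shape of such a regression check (the names, the reference numbers below, and the torch.matmul stand-in are placeholders; the actual test calls the W4A16 Triton kernel with packed int4 weights and scales):

```python
import time

import pytest
import torch

M = 1606            # representative prefill token count
TOLERANCE = 0.05    # fail if more than 5% below the reference

# (N, K, reference TFLOPS) -- placeholder values, not the committed ones
PREFILL_SHAPES = [
    (37888, 3584, 25.0),
    (3584, 3584, 24.5),
]


@pytest.mark.parametrize("n,k,ref_tflops", PREFILL_SHAPES)
def test_w4a16_prefill_perf_regression(n, k, ref_tflops):
    a = torch.randn(M, k, dtype=torch.float16, device="cuda")
    b = torch.randn(k, n, dtype=torch.float16, device="cuda")

    # Warm up, then time a few iterations of the kernel under test.
    for _ in range(3):
        torch.matmul(a, b)
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    iters = 10
    for _ in range(iters):
        torch.matmul(a, b)
    torch.cuda.synchronize()
    elapsed = (time.perf_counter() - t0) / iters

    tflops = 2 * M * n * k / elapsed / 1e12
    assert tflops >= ref_tflops * (1 - TOLERANCE), (
        f"{M}x{n}x{k}: {tflops:.1f} TFLOPS below reference {ref_tflops}")
```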

Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>
mgehre-amd force-pushed the matthias.tune-hybrid-w4a16-prefill branch from 37eaa25 to f0b6d24 on April 15, 2026 12:11