Tune hybrid_triton_w4a16 prefill kernel for gfx1151#879
Draft
mgehre-amd wants to merge 3 commits into gfx11 from
Conversation
Add num_stages=1 to the Triton kernel launch and update the M>1024 tile heuristic from BM=128,BN=64,BK=64,w=8 to BM=64,BN=256,BK=64,w=8.

The num_stages=1 parameter was missing (all other ROCm Triton kernels in vLLM use it) and enables the compiler to pipeline global memory loads more effectively. The wider N-tile (256 vs 64) improves L2 cache reuse and CU occupancy for prefill-sized GEMMs (M=1000-2000).

Benchmarked on gfx1151 with Qwen2.5-VL-7B-Instruct-AWQ shapes (M=1606, K=3584, group_size=128, symmetric):

- 37888x3584: 19.0ms -> 17.5ms (+8%)
- 18944x3584: 9.6ms -> 8.7ms (+10%)
- 4608x3584: 2.4ms -> 2.1ms (+12%)
- 3584x3584: 1.8ms -> 1.6ms (+12%)

Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>
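The M-based tile selection described in this commit can be sketched in plain Python. This is illustrative only: the function name is an assumption, and the small-M fallback config is a made-up placeholder for contrast, not vLLM's actual code.

```python
def choose_tile_config(M: int) -> dict:
    """Illustrative sketch of the tile heuristic from this PR.

    The M > 1024 branch matches the config named in the commit message;
    the fallback branch is a placeholder for the existing small-M path.
    """
    if M > 1024:
        # Prefill-sized GEMMs: wide N-tile (256) for better L2 reuse
        # and CU occupancy; num_stages=1 lets the compiler pipeline loads.
        return dict(BLOCK_M=64, BLOCK_N=256, BLOCK_K=64,
                    num_warps=8, num_stages=1)
    # Small-M / decode-adjacent sizes (placeholder values, not from the PR).
    return dict(BLOCK_M=16, BLOCK_N=64, BLOCK_K=64,
                num_warps=4, num_stages=1)
```

In the real kernel these values would be forwarded to the Triton launch as constexpr/meta-parameters.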
Add UNROLL_K parameter that statically unrolls multiple BLOCK_K tiles per outer loop iteration using tl.static_range. This amortises loop overhead and gives the Triton compiler more instruction scheduling freedom across independent load/unpack/dequant/dot sequences.

For M > 1024 (prefill), UNROLL_K=4 processes 4 groups per outer iteration. For smaller M (decode-adjacent sizes), UNROLL_K=1 preserves the current behavior to avoid register pressure.

Benchmarked on gfx1151 with Qwen2.5-VL-7B-Instruct-AWQ shapes (M=1606, K=3584, group_size=128, BM=64, BN=256, BK=64, w=8):

- 37888x3584: 17.1ms -> 16.4ms (+4%)
- 4608x3584: 2.0ms -> 1.97ms (+3%)
- 3584x3584: 1.6ms -> 1.58ms (+2%)

End-to-end TTFT: 1277ms -> 1243ms (2.7%)

Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>
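The UNROLL_K loop restructuring can be sketched in plain Python (a hypothetical helper, not the kernel itself). In the Triton kernel the inner loop is a tl.static_range, so the load/unpack/dequant/dot body is replicated UNROLL_K times at compile time and the compiler can schedule the independent copies together; here we only model the order in which K-tiles are visited.

```python
def k_loop_iterations(K: int, BLOCK_K: int, UNROLL_K: int):
    """Yield BLOCK_K tile offsets in the order the unrolled loop visits them.

    With UNROLL_K=4 each outer iteration covers 4 consecutive BLOCK_K tiles,
    amortising per-iteration loop overhead across them.
    """
    # Illustrative assumption: K divides evenly into unrolled outer steps.
    assert K % (BLOCK_K * UNROLL_K) == 0
    for outer in range(0, K, BLOCK_K * UNROLL_K):
        for u in range(UNROLL_K):  # tl.static_range(UNROLL_K) in the kernel
            yield outer + u * BLOCK_K
```

For K=3584 and BLOCK_K=64 there are 56 tiles; UNROLL_K=4 visits them in 14 outer iterations, covering exactly the same offsets as the unrolled-by-1 loop.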
Add test_triton_w4a16_prefill_perf_regression that benchmarks the Triton W4A16 kernel on representative prefill shapes (M=1606) from Qwen2.5-7B and Qwen3-4B models and asserts TFLOPS stays within 5% of hardcoded reference values measured on gfx1151.

Reference values: 24.5-26.0 TFLOPS depending on shape, measured with num_stages=1, UNROLL_K=4, BM=64/BN=256/BK=64/w=8.

Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>
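The shape of that regression check can be sketched as follows (hypothetical helper names; the actual test lives in the vLLM test suite and runs the real kernel):

```python
def gemm_tflops(M: int, N: int, K: int, seconds: float) -> float:
    """TFLOPS of an MxK @ KxN GEMM (2*M*N*K FLOPs; dequant cost ignored)."""
    return 2.0 * M * N * K / seconds / 1e12

def check_no_regression(measured: float, reference: float,
                        tolerance: float = 0.05) -> None:
    """Fail if measured throughput drops more than `tolerance` below reference."""
    assert measured >= reference * (1.0 - tolerance), (
        f"perf regression: {measured:.2f} TFLOPS vs "
        f"reference {reference:.2f} (tolerance {tolerance:.0%})")
```

As a sanity check against the numbers above: the 37888x3584 shape at M=1606 and 17.5 ms works out to roughly 24.9 TFLOPS, consistent with the 24.5-26.0 reference range.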
Summary
- Add num_stages=1 to the Triton kernel launch and update the M>1024 tile heuristic from BM=128,BN=64,BK=64,w=8 to BM=64,BN=256,BK=64,w=8 for better L2 reuse
- Add UNROLL_K parameter that statically unrolls multiple BLOCK_K tiles per outer loop iteration using tl.static_range, amortising loop overhead

Benchmarks
Qwen2.5-VL-7B-Instruct-AWQ on gfx1151 (Strix Halo), M=1606, K=3584, group_size=128, symmetric.
Kernel-level (profiled, --enforce-eager): per-shape timings listed in the commit messages above.

End-to-end TTFT: 1277ms -> 1243ms (2.7%)
Other ideas tested (not adopted)
- OPTIMIZE_EPILOGUE=1: mixed results, interferes with num_stages=1
- tl.trans(): strided loads 5-8% slower
- @triton.autotune: heuristic already matches optimal config

Test plan
- pytest tests/kernels/quantization/test_hybrid_w4a16_triton.py (20/20 pass)
- Kernel microbenchmark sweep (sweep_prefill_int4.py)
- End-to-end benchmark (vllm-bench.py with Qwen2.5-VL-7B-Instruct-AWQ)