feat(gemm): optimize blockwise FP8 GEMM Triton kernels (gfx950 / MI355X)#339
feat(gemm): optimize blockwise FP8 GEMM Triton kernels (gfx950 / MI355X)#339ChengYao-amd wants to merge 4 commits into
Conversation
411e894 to
08db45b
Compare
(gfx942 / MI300X) Performance ReportComparison: 1. Executive Summary
Highlights
2. Per-Model Summary (Mean Improvement)
The optimization benefits small- and mid-sized models more uniformly. Large-model shapes (405B / 72B) show a smaller but still consistently positive backward speedup. 3. Top 10 Backward Improvements
These top backward gains share the pattern of moderate N/K dimensions (N ≤ 6144, K ≤ 4096) combined with relatively large M. The backward pass computes 4. Top 10 Forward Improvements
The largest forward gains land on (a) small square shapes (≤4096) and (b) very wide N shapes (N ≥ 100k) of the 405B model, suggesting the optimized kernel handles both ends of the shape spectrum better than the baseline. 5. Forward RegressionsTen shapes show a small forward slowdown. All ten still report a healthy positive backward speedup.
Observations on the regression pattern:
6. Correctness CheckOne test case has a different correctness status between the two builds.
The optimized kernel resolves the previously failing accuracy check on the largest 405B shape. All remaining 83 cases keep 7. Conclusion
|
Optimize the Triton-backend blockwise FP8 GEMM path. All performance
gains land on the MI355X (gfx950) target; gfx942 falls back to the
shared kernel and is unaffected.
Forward + backward (geomean over 84 PASS shapes, vs current main):
forward : 1470.1 -> 1483.6 (+0.92%)
backward : 1022.7 -> 1198.0 (+17.13%)
combined-step: 1138.1 -> 1276.1 (+12.13%)
Backward distribution is heavily right-skewed: 56/84 shapes >= +1% on bwd,
13/84 shapes >= +5% on bwd, only 8/84 shapes <= -1% on bwd (the regressions
all sit in the very-large N >= 57344 wgrad sub-space and are documented in
pr_report.md).
Source-level changes (4 files):
* primus_turbo/triton/gemm/gemm_fp8_kernel.py
- Layout-specialised kernels: split the shared kernel into
`_blockwise_fp8_nt_kernel` (forward), `_blockwise_fp8_nn_kernel`
(dgrad) and a new `_blockwise_fp8_tn_kernel` (wgrad). The shared
kernel is kept only for grouped GEMM.
- EVEN_K fast path on NT/NN to avoid K-tail masking.
- Wgrad transposed-store epilogue: `tl.trans(acc.to(bf16))` + swapped
pointer addressing turns `buffer_store_short x 64` into
`(buffer|global)_store_dwordx4 x 4-8` on all cached binaries.
- Autotune space: add `BLOCK_N=64` candidates and `num_stages=3` for
fwd/dgrad, while wgrad keeps `num_stages in {1,2}` to avoid the
Triton 3.7 AMD-backend SIGABRT on the dual strided-K load pattern.
- NN autotune stacks `matrix_instr_nonkdim=16` candidates only on the
`BLOCK_M=256` path, where 32x32x64 MFMA overflows 16 AGPR/warp.
Picked by 49/69 NN keys (71%); responsible for the +5-6% bwd gain
on Llama-2-70B / Qwen2.5-72B large-K shapes.
- Llama-3.1-405B M=32768 N=106496 K=16384 correctness fix (main FAILs
this shape with dA=NaN; this branch passes 84/84).
* primus_turbo/triton/quantization/quant_blockwise.py
- Single-launch row+col fusion (`quant_fp8_blockwise_dual`) so the
forward path produces both row- and col-major scaled tensors in one
kernel instead of two; the col output is `save_for_backward`-ed and
consumed by wgrad without re-quantising.
* primus_turbo/pytorch/kernels/quantization/quantization_impl.py
- Add the dual-quant impl + custom-op registration alongside the
existing single-direction quantization.
* primus_turbo/pytorch/ops/gemm_fp8.py
- Wire the dual-quant fusion through the Float8 forward, and reuse
the cached col-major quant in the backward without any
id()-keyed wrapper caches (those were rolled back in 4394fe0
because their hit rate only existed under the benchmark's
100-iteration tensor reuse, not in real training).
Co-authored-by: kyle-256 <Kyle.Zhao@amd.com>
Co-authored-by: xiaobochen-amd <xiaobo.chen@amd.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
086cd1b to
bf3893f
Compare
bf3893f to
4b1402d
Compare
Optimize blockwise FP8 GEMM Triton kernels (gfx950 / MI355X)
PR branch:
dev/yaoc/optimize_fp8_blockwise_gemm_tritonBase branch:
main@06b8d3f(Add HYBRID FP8 format support for Triton backend in gemm and grouped_gemm)Commit on this branch:
opt: blockwise FP8 GEMM Triton kernels (gfx950 / MI355X)(4 files, +597 / −46)1. TL;DR
This PR optimizes the blockwise FP8 GEMM Triton path on gfx950 (MI355X).
Source-level changes touch four files only:
primus_turbo/triton/gemm/gemm_fp8_kernel.pyprimus_turbo/triton/quantization/quant_blockwise.pyprimus_turbo/pytorch/kernels/quantization/quantization_impl.pyprimus_turbo/pytorch/ops/gemm_fp8.pyAggregate impact on the 84-shape
bench_gemm_turbo.pysuite(
PRIMUS_TURBO_GEMM_BACKEND=TRITON,--dtype fp8 --granularity blockwise,geomean over the 83 shapes both branches pass correctness on):
main(gmean)¹
Combined Step TFLOPS = 6·M·N·K / (fwd_ms + bwd_ms) / 1e9.Additional wins:
mainFAILsTestID 47(Llama-3.1-405BM=32768 N=106496 K=16384, dA=NaN). This PR PASSes all 84 shapes.
all 83 PASS-on-both shapes are ≥ +1% on combined-step
(median +13.13%, p25 +10.95%, p75 +15.30%), 82 / 83 shapes are
≥ +5%. The backward delta is strictly positive on every shape
(median +18.48%, min +3.22%, max +25.76%). Forward is centred near
zero (median +0.63%, mean +0.28%) but 5 shapes carry a fwd
regression ≤ −5% — see §4.1 for why these still come out net
positive on combined-step.
CK / hipBLASLt / grouped-GEMM paths are not touched. gfx942 (MI300X /
MI325X) falls back to the unchanged shared kernel.
2. Scope
The PR contains exactly four files of source change:
primus_turbo/triton/gemm/gemm_fp8_kernel.pyprimus_turbo/triton/quantization/quant_blockwise.pyquant_fp8_blockwise_dual)primus_turbo/pytorch/kernels/quantization/quantization_impl.pyprimus_turbo/pytorch/ops/gemm_fp8.pyOut of scope (not in this PR):
agent/,agent/historical_experience/— agent-tooling assets areshipped in a separate PR.
gemm_kernel.py— only the FP8 blockwise kernel file is touched; theshared bf16/fp16 GEMM kernel is unchanged.
gemm_fp8_kernel.pyshared kernel_blockwise_fp8_autotune_kernelispreserved (still used by grouped-GEMM); only NT/NN/TN dispatch is
switched to the layout-specialised variants.
3. Optimization techniques (kept in this PR)
3.1 Layout-specialised kernels with
EVEN_Kfast pathReplace the single shared
_blockwise_fp8_autotune_kernel(still keptfor grouped-GEMM) with three dispatch-specialised variants:
_blockwise_fp8_nt_kernel— forward (A @ B.T)_blockwise_fp8_nn_kernel— dgrad (grad_out @ B)_blockwise_fp8_tn_kernel— wgrad (A.T @ grad_out), new in §3.4Each variant hard-codes its layout
constexpr(A_K_CONTIGUOUS,B_K_CONTIGUOUS,SCALE_2D_B) at the kernel level, eliminating theruntime per-call branch decisions and shrinking each binary.
EVEN_Kfast path on NT/NN: whenK % BLOCK_K == 0(the commontraining shape), the K-tail mask
tl.load(..., mask=...)is replacedwith an unmasked path, removing two
v_cmp_*+ onev_cndmask_*perK iteration.
3.2 Autotune space extension
The
mainautotune space is 32 configs(
BM ∈ {128,256},BN=128,BK=128,kpack ∈ {1,2},GM ∈ {4,8},CHUNK ∈ {32,64},num_stages ∈ {1,2},num_warps = 4 if BM=128 else 8).This PR extends it to 96 configs on gfx950 (forward / dgrad):
BLOCK_N ∈ {64,128}andnum_stages ∈ {1,2,3}are added.The wgrad path is restricted to
num_stages ∈ {1,2}to work around aTriton 3.7 AMD-backend assertion that fires only on the dual
strided-K load pattern combined with
num_stages=3. This is thesingle-line fix carried in the otherwise larger §3.1 change set;
maindoes not need it becausemain's autotune space does notcontain
num_stages=3in the first place.3.3 NN-only
matrix_instr_nonkdim=16candidates onBLOCK_M=256NN (dgrad) at
BM=256 BN=128 nw=8 ns=3overflows the 16 AGPR / warpbudget that gfx950's default
v_mfma_f32_32x32x64_f8f6f4requires forits accumulator (405
v_accvgpr_*ops, ~24% of the NN kernel ASM inthe dev-branch baseline). Stacking
matrix_instr_nonkdim=16candidatesonly on the BM=256 path lets Triton compile a
v_mfma_f32_16x16x128_f8f6f4variant that uses 4 AGPR / warp; theautotuner picks it on the shapes where the AGPR pressure dominates.
The
BM=128 nw=4path is left alone because it already only uses16 AGPR / warp, and stacking a 16x16 MFMA candidate there would just
inflate cold-autotune time without any binary winning.
Empirical hit rate:
nonkdim=16is selected by 49 / 69 NN cache keys(71%). Top wins land on Llama-2-70B
M=16384 N=8192 K=28672(bwd+5.86%) and Qwen2.5-72B
M=8192 N=8192 K=29568(bwd +5.96%).Cost: NN autotune candidate count 96 → 144 (+50% cold-bench time on a
fresh cache).
3.4 wgrad-specialised kernel with transposed-store epilogue
In the dev-branch baseline, wgrad (
_blockwise_tn,trans_c=True)shared the kernel with NT/NN and produced its
(N, M)output via theshared-kernel store path. With
stride_cm=1, stride_cn=N, the BNdimension is strided in physical memory while BM is contiguous, which
the compiler can only serialise into 64 ×
buffer_store_short(32low-half + 32 high-half) per tile — about 6-16× lower effective store
throughput than the NT/NN path's 8 ×
buffer_store_dwordx4.The new
_blockwise_fp8_tn_kernelrewrites the epilogue:The
tl.transswapsBNto the outer dim, swapping pointer indexingputs the contiguous-stride dim on the inner axis, and the store
naturally coalesces into
dwordx4.ASM-level verification on a fresh Triton cache:
96 / 96 wgrad binaries now use
(buffer|global)_store_dwordx4 × 4-8,fully eliminating the
buffer_store_short × 64pattern.The shared kernel is still used by grouped-GEMM (which uses
trans_c=Falseand already stores into a contiguous(M, N)layout, so it does not need the transposed-store path).
3.5 Single-launch row + col fused quant (
quant_fp8_blockwise_dual)The forward path needs both row-major (for
A @ B.T) and col-major(saved for
A.T @ grad_outin backward) FP8 blockwise quantisationsof the input activation. The dev-branch baseline launched two separate
quant kernels; this PR fuses them into a single
quant_fp8_blockwise_duallaunch:
tl.loadof the bf16 activation tile.The col-major FP8 output is
save_for_backward-ed and consumeddirectly by wgrad without re-quantising.
This fusion is implemented at kernel level rather than as a
wrapper-level cache (no
id()-keyed reuse of quant outputs acrosscalls), so the gain holds in real LLM training where activations
are fresh each iteration; it does not depend on the benchmark
harness's 100-iteration tensor-reuse pattern.
3.6 Llama-3.1-405B correctness fix
mainFAILsTestID 47(M=32768 N=106496 K=16384) withdA=NaNinthe dgrad NN path. This PR includes the targeted index-arithmetic fix
in
_blockwise_fp8_nn_kernel(5 lines). The shape now passes withdA=28.6 dB.4. Performance results
4.1 Aggregate (PASS-on-both, 83 cases)
mainPer-shape delta distribution over the same 83 shapes:
The combined-step delta is strictly positive on every shape that both
branches pass; the smallest combined gain is +3.29% (TestID 43,
Llama-3.1-405B M=16384 N=106496 K=16384) and the largest is +18.51%
(TestID 64, Qwen2.5-72B M=8192 N=8192 K=29568).
5 shapes carry a forward regression ≤ −5% — all of them have
(N, K) = (4096, 4096)or(N, K) = (4096, 11008):A further 5 shapes show a fwd regression in the −5% to −2.5% band
(TID 8, 14, 49, 2, 62). All 10 net-negative-fwd shapes remain net
positive on combined-step thanks to a +12-26% bwd uplift on the same
shapes.
4.2 Per-model (PASS-on-both)
mainnumbers come fromtmp/main/gemm_fp8_blockwise_triton_benchmark.csv.This PRnumbers come fromtmp/new-branch/gemm_fp8_blockwise_triton_benchmark.csv.maincmb gmean² Llama-3.1-405B excludes
TestID 47from the comparison because itFAILs on
main(PASS on this PR — see §3.6 and §4.4). The smaller+9.15% on this model is dominated by the two large-N wgrad shapes
(TID 43 +3.29%, TID 39 +7.72% combined) where the BN tile count
explodes to
N / 128 = 832; the other 9 Llama-3.1-405B shapes are at+7.7% to +12.7% combined and match the other models.
4.3 Best combined movers (vs
main)maincombThe largest gains span both small-K + mid-M (
K=4096, thequant-fusion + wgrad transposed-store §3.5/§3.4 dominate) and
large-K + mid-M (
K ∈ {28672, 29568}, the NN nonkdim=16 §3.3plus the wgrad transposed-store §3.4 stack). The forward delta on
these top-bwd shapes is in the −1% to +3% range, confirming the gains
live primarily on the backward path.
4.4 Smallest combined gains (vs
main)The "worst" rows below are all still positive — no shape that PASSes
on both branches regresses on combined-step:
maincombTwo distinct populations show up in this Bottom-5:
N ∈ {106496, 59136}) wgrad-boundshapes. The transposed-store epilogue (§3.4) brings less here
because the BN tile count blows up (
N / 128 = 462-832) and theper-tile
tl.trans(acc)exhausts register budget; see §5.1.(8192, 4096, 4096)square-ish shapes whosecombined gain is held back by a fwd regression of −6.85% to −8.29%
(cold-autotune drift, see §5.3). Their bwd is still +12-13%, so
combined remains net positive.
TestID 47(Llama-3.1-405B M=32768 N=106496 K=16384) FAILs onmainand does not appear in the apples-to-apples table. On this PR it
PASSes with combined = 1347.6 TFLOPS (fwd 1626.5 TFLOPS / 70.30 ms,
bwd 1241.1 TFLOPS / 184.27 ms).
5. Known limitations
5.1 Large-N wgrad shapes get smaller combined gain
Three shapes with
N ≥ 57344produce noticeably smaller combinedgain than the +12.71% gmean:
Root cause: on these shapes the BN tile count explodes to
N / 128 = 462-832and the per-tiletl.trans(acc)in the wgradtransposed-store epilogue (§3.4) exhausts register budget, spilling
the transpose to LDS-exchange. The shapes are still strictly net
positive vs
mainbecausemain's wgrad starts from a much lowerbaseline (Bwd TFLOPS in the 1100-1190 range, vs the post-PR 1200+).
Mitigation, deferred to a follow-up PR: add a
WGRAD_TRANS_STORE: tl.constexprswitch and let the autotuner pickbetween the transposed-store path and the legacy short-store path
per shape; expected to recover the regressed subspace without giving
up the +17.83% bwd gmean on the other 80 shapes.
5.2 Cold-autotune wall-time
This PR's autotune candidate count is larger than
main's:BLOCK_N=64andnum_stages=3)nonkdim=16stack onBM=256)End-to-end cold benchmark wall-time on the 84-shape suite goes from
~25 min (main) to ~120 min (this PR). For production deployment we
recommend caching the per-shape best config into a lookup table after
a one-time autotune sweep; this is not in this PR but is straight
to add as a follow-up.
5.3 Cold-autotune drift on near-tie configs
The 5 shapes with
Δ fwd ≤ −5%listed in §4.1 are all in the(*, 4096, 4096)/(4096, 4096, *)square-ish family. On theseshapes several autotune candidates are within ~1% of each other on
fwd time, and the cold-cache pick can flip between sweeps; this
shows up as a per-shape ±5-10% fwd drift between independent runs
even when the underlying kernel is unchanged. The drift does not
affect the PR-vs-main aggregate (every drifting shape is still net
positive on combined-step thanks to the +12-26% bwd uplift), but
pinning the best config per shape (§5.2) is the recommended fix in
production.
6. Reproducing the numbers
Environment used for the numbers in this report:
2.10.0a0+git449b1763.7.0+gitf4a3db9edev/yaoc/optimize_fp8_blockwise_gemm_tritonPRIMUS_TURBO_GEMM_BACKEND=TRITONblockwise, FP8 E4M3,block_size=128