feat(deepseek_v4): G-loop FP8 blockscale for grouped output LoRA wo_a#677
Closed
zufayu wants to merge 1 commit intofeat/deepseek-v4-pr1-skeletonfrom
Closed
feat(deepseek_v4): G-loop FP8 blockscale for grouped output LoRA wo_a#677zufayu wants to merge 1 commit intofeat/deepseek-v4-pr1-skeletonfrom
zufayu wants to merge 1 commit intofeat/deepseek-v4-pr1-skeletonfrom
Conversation
Alternative to PR #676 (FP8 batched BMM). Replaces the BF16 grouped einsum in DeepseekV4Attention.forward with a G-loop calling gemm_a8w8_blockscale_preshuffle_impl per group — the same FP8 + per-128-block kernel wo_b/wq_b/wkv already use. Why this approach over the BMM kernel: - wo_a's V4 shape is B=n_local_groups=2 (with tp=8), K=4096, N=1024. - The MLA-style batched_gemm_a8w8_per_token_group_prequant kernel grids over B, designed for B≈128. At B=2 it severely underutilizes 304 CU. - This G-loop dispatches G=2 standard 2D GEMMs of shape (M, N=1024, K=4096) — N=1024 is well within the kernel's autotune sweet spot, full CU coverage. - Uses on-disk per-128-block W scale directly: zero precision loss (vs PR #676's per-block→scalar collapse). - process_weights_after_loading becomes a no-op for the default path — LinearBase's standard FP8 + shuffle handles wo_a like any other Linear. Trade-off vs PR #676: - +2 extra kernel launches per layer in eager mode (act_quant + 2 GEMMs vs 1 fused BMM); amortized to ~0 in cudagraph mode once V4's .item() syncs are removed. - Each kernel runs faster due to better grid utilization. - No precision loss from scale-format conversion. Fallback: ATOM_V4_OA_USE_EINSUM=1 keeps the legacy BF16 einsum path. Touches one file: atom/models/deepseek_v4.py. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Author
|
Closing in favor of #676. After re-measuring with all caches warm in the same docker session (apples-to-apples), both implementations achieve TPOT parity with einsum baseline. Memory savings are identical (~2-3% per rank, both use FP8 wo_a storage). Decision rationale: prefer #676 because:
This PR (#677) remains a valid alternative if reviewers prefer:
If accuracy testing later reveals an issue with #676's per-block→scalar quant, this branch is the natural fallback. Branch retained at |
zufayu
pushed a commit
that referenced
this pull request
May 1, 2026
Companion to v4_wo_a_tune.py --microbench (BMM vs einsum). Lets us compare the two FP8 paths (BMM in PR #676 / G-loop in PR #677) head-to-head against the einsum baseline on the same shape (B=2, K=4096, N=1024). Useful because PR #676 BMM came in 1.5-2x SLOWER than einsum at the kernel level (rocBLAS BMM is just better tuned for B=2 than the AITER Triton BMM). Need to know if G-loop's gemm_a8w8_blockscale_preshuffle (used widely by ATOM Linears) fares better. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
zufayu
pushed a commit
that referenced
this pull request
May 1, 2026
These were research artifacts used to characterize the wo_a kernel choice (documented in the RFC comment on this PR — see #676 review thread). Findings: - aiter Triton FP8 BMM is 1.5-2x slower than rocBLAS BF16 BMM at V4 wo_a shape (B=2, K=4096, N=1024). - Config tuning via aiter's JSON DB doesn't close the gap (B=2 grid underutilization is fundamental, not a config choice). - G-loop blockscale alternative (PR #677) is even slower (2-7x vs einsum) due to G+1 launches per layer overhead at small M. - Memory savings (-2~3% per rank) is the only kernel-independent win. The scripts are preserved in git history (commits 26a83b4..761c69e on this branch) for anyone wanting to reproduce the measurements. They are not appropriate for the merged production diff. Production change (atom/models/deepseek_v4.py) unchanged. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Alternative to #676. Replaces the BF16 grouped einsum at the end of
DeepseekV4Attention.forwardwith a G-loop callinggemm_a8w8_blockscale_preshuffle_implper group — the same FP8 + per-128-block kernel that wo_b / wq_b / wkv already use.Touches one file:
atom/models/deepseek_v4.py(+77 / −34).Why this kernel choice over #676's batched a8w8
V4 wo_a's TP=8 shape:
PR #676 uses
batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant. That kernel grids over B with(B, M*N/tile); designed for MLA's B≈128, at B=2 it severely underutilizes 304 CU on MI355X.This PR does G=2 standard 2D GEMMs per layer. Each call is
(M, K=4096) @ (K=4096, N=1024), well withingemm_a8w8_blockscale_preshuffle's autotune sweet spot (same shape pattern as wo_b). Trade-off: +2 extra kernel launches per layer (act_quant + 2 GEMMs vs 1 fused BMM in #676).Additional benefit: zero precision loss. The on-disk per-128-block W scale is consumed directly; no requant to scalar (which #676 required for the BMM kernel's
(1,)scale interface).Verification (PR #650 standard test config, V4-Pro on 8x MI355X)
Correctness
All 4 prompts produce coherent, semantically equivalent output to einsum baseline. Same FP8-quant-noise pattern as #676 (deterministic prompts char-equal, open-ended prompts diverge in early edge-confidence tokens then develop independently).
Perf — three-way comparison
vs PR #676 specifically
process_weights_after_loadingcomplexityThis PR is strictly better than #676 on perf, TTFT, precision, and code simplicity, at the cost of +2 kernel launches per layer in eager mode (cudagraph mode would amortize that to ~zero).
vs einsum baseline
Both this PR and #676 are slightly slower than einsum in eager mode. Root cause is fundamental: V4 wo_a's
(M small, K=4096, N=1024)shape is well-suited to rocBLAS BMM (whattorch.einsumlowers to) and FP8 weight bandwidth savings don't outweigh activation quant + kernel-launch overhead at this size.The pitch is memory (~2-3% per rank), not latency. Eager perf hit is small (2-5%) and likely vanishes in CUDAGraph mode (out of scope per PR #650's "Known limitations").
Toggle
ATOM_V4_OA_USE_EINSUM=1falls back to BF16 einsum. Read once at__init__, zero per-forward overhead.Open question for reviewer
Same as #676 — memory-vs-eager-latency trade. Defer to V4 PR #650 maintainer judgment. This PR vs #676: G-loop is strictly better on every measurable dimension except kernel launch count. Recommend taking this one over #676 if either is to be merged.
Out of scope (vs follow-up PRs)
.item()syncs cleanup.lm_evalGSM8K accuracy gating (test node was offline).gemm_a8w8_blockscale_preshuffleautotune for(M small, N=1024, K=4096)shape — current configs may close remaining 2-5% perf gap.File touched
atom/models/deepseek_v4.py(+77, −34)