
feat(deepseek_v4): G-loop FP8 blockscale for grouped output LoRA wo_a #677

Closed
zufayu wants to merge 1 commit into feat/deepseek-v4-pr1-skeleton from feat/deepseek-v4-wo-a-gloop-blockscale

Conversation

zufayu commented May 1, 2026

Summary

Alternative to #676. Replaces the BF16 grouped einsum at the end of DeepseekV4Attention.forward with a G-loop calling gemm_a8w8_blockscale_preshuffle_impl per group — the same FP8 + per-128-block kernel that wo_b / wq_b / wkv already use.

Touches one file: atom/models/deepseek_v4.py (+77 / −34).
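
For orientation, a minimal sketch of the G-loop structure described above (helper names, signatures, and tensor layouts here are illustrative assumptions, not the actual atom/aiter API; the real change lives in atom/models/deepseek_v4.py):

import torch

def wo_a_gloop(x, w_fp8, w_scale, act_quant, gemm_blockscale):
    # Assumed layouts, for illustration only:
    #   x       : (M, G, K)            BF16 activations, G = n_local_groups (2 at TP=8)
    #   w_fp8   : (G, N, K)            FP8 weights, preshuffled per group
    #   w_scale : (G, N//128, K//128)  on-disk per-128-block weight scales
    xq, x_scale = act_quant(x)            # one per-token-group quant launch
    outs = []
    for g in range(w_fp8.shape[0]):       # G = 2 standard 2D GEMMs
        outs.append(gemm_blockscale(xq[:, g, :], w_fp8[g],
                                    x_scale[:, g], w_scale[g]))
    return torch.stack(outs, dim=1)       # (M, G, N) BF16 output

This is the "act_quant + 2 GEMMs" launch accounting referred to below.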

Why this kernel choice over #676's batched a8w8

V4 wo_a's TP=8 shape:

dim                   value
B (n_local_groups)    2
K (d_per_group)       4096
N (o_lora_rank)       1024
M (token batch)       4 (decode) to ~1024 (prefill)

PR #676 uses batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant. That kernel launches a grid of (B, M*N/tile); it is designed for MLA's B≈128, so at B=2 it severely underutilizes the 304 CUs on an MI355X.

This PR does G=2 standard 2D GEMMs per layer. Each call is (M, K=4096) @ (K=4096, N=1024), well within gemm_a8w8_blockscale_preshuffle's autotune sweet spot (same shape pattern as wo_b). Trade-off: +2 extra kernel launches per layer (act_quant + 2 GEMMs vs 1 fused BMM in #676).

Additional benefit: zero precision loss. The on-disk per-128-block W scale is consumed directly; no requant to scalar (which #676 required for the BMM kernel's (1,) scale interface).
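
To make the scale-format point concrete, the shapes involved for one wo_a group look roughly like the following (assumed layout, not verified against the actual checkpoint format):

import torch

N, K = 1024, 4096
w_fp8 = torch.empty(N, K, dtype=torch.float8_e4m3fn)    # FP8 weight, one group
w_scale = torch.empty(N // 128, K // 128)                # (8, 32) per-128x128-block scales

# This PR passes w_scale to the blockscale GEMM unchanged. #676's BMM kernel
# takes a single per-batched-tensor scale of shape (1,), so the (8, 32) grid
# has to be collapsed to one scalar first, which is the one-time precision
# loss described above.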

Verification (PR #650 standard test config, V4-Pro on 8x MI355X)

ATOM_USE_TRITON_MOE=1 AITER_LOG_LEVEL=WARNING \
python -m atom.examples.simple_inference \
  --model /path/to/DeepSeek-V4-Pro \
  --kv_cache_dtype fp8 -tp 8 \
  --max-num-seqs 4 --max-num-batched-tokens 1024 --max-model-len 1024 \
  --gpu-memory-utilization 0.85 --enforce-eager \
  --temperature 0.0 --max-tokens 128

Correctness

All 4 prompts produce coherent output that is semantically equivalent to the einsum baseline, with the same FP8-quant-noise pattern as #676: deterministic prompts are character-for-character equal, while open-ended prompts diverge at early edge-confidence tokens and then develop independently.

Perf — three-way comparison

Req                        Output           einsum (baseline)   BMM (#676)        G-loop (this PR)
Req 2 (1+2+3)              15-16 tok, eos   0.506s              0.523s (+3.4%)    0.516s (+2.0%)
Req 1 (primes)             94 tok, eos      0.431s              0.448s (+3.9%)    0.440s (+2.1%)
Req 0 (intro)              128 tok, max     0.406s              0.433s (+6.7%)    0.426s (+4.9%)
Req 3 (muscle gain, zh)    128 tok, max     0.406s              0.433s (+6.7%)    0.426s (+4.9%)
TTFT (first call)          -                1.38s               24.72s            4.36s
VRAM / rank                -                baseline            −2~3%             −2~3% (same as #676)

vs PR #676 specifically

Dimension                                  #676 (BMM)                                           This PR (G-loop)
TPOT regression vs einsum                  +3.4 ~ +6.7%                                         +2.0 ~ +4.9%
TTFT first-call cost                       +23s (BMM kernel JIT + autotune)                     +3s
Precision loss                             per-block → scalar W scale (one-time collapse)       zero (per-block scale used directly)
process_weights_after_loading complexity   dequant + reshape + requant + buffer + placeholder   no-op (LinearBase standard FP8 + shuffle)
Memory savings                             same                                                 same

This PR is strictly better than #676 on perf, TTFT, precision, and code simplicity, at the cost of +2 kernel launches per layer in eager mode (cudagraph mode would amortize that to ~zero).

vs einsum baseline

Both this PR and #676 are slightly slower than einsum in eager mode. The root cause is fundamental: V4 wo_a's shape (small M, K=4096, N=1024) is well-suited to rocBLAS BMM (what torch.einsum lowers to), and the FP8 weight-bandwidth savings don't outweigh the activation-quant and kernel-launch overhead at this size.

The pitch is memory (~2-3% per rank), not latency. The eager-mode perf hit is small (2-5%) and likely vanishes in CUDAGraph mode (out of scope per PR #650's "Known limitations").
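
A back-of-envelope for the mechanism behind that number (per-layer wo_a only, TP=8 shapes from the table above; the overall ~2-3% depends on layer count and the rest of the model, so this is illustrative, not a re-derivation of the measured figure):

G, K, N = 2, 4096, 1024
bf16_bytes  = G * K * N * 2                     # 16 MiB of BF16 wo_a weight per layer per rank
fp8_bytes   = G * K * N * 1                     # 8 MiB as FP8
scale_bytes = G * (N // 128) * (K // 128) * 4   # ~2 KiB of fp32 per-block scales
print((bf16_bytes - fp8_bytes - scale_bytes) / 2**20)   # ≈ 8 MiB saved per layer per rank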

Toggle

Setting ATOM_V4_OA_USE_EINSUM=1 falls back to the BF16 einsum path. The flag is read once at __init__, so there is zero per-forward overhead.
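
A minimal sketch of that pattern (class and attribute names are illustrative; the real check lives in DeepseekV4Attention.__init__ in atom/models/deepseek_v4.py):

import os

class _AttentionSketch:
    def __init__(self):
        # Read the flag once at construction; forward only checks the cached
        # bool, so there is no per-call environment lookup.
        self.use_einsum_wo_a = os.environ.get("ATOM_V4_OA_USE_EINSUM", "0") == "1"

    def forward_wo_a(self, x):
        if self.use_einsum_wo_a:
            ...  # legacy BF16 einsum path
        else:
            ...  # FP8 G-loop path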

Open question for reviewer

Same as #676: the memory-vs-eager-latency trade-off. Deferring to the V4 PR #650 maintainer's judgment. Between this PR and #676, the G-loop is strictly better on every measured dimension except kernel-launch count; recommend taking this one over #676 if either is to be merged.

Out of scope (vs follow-up PRs)

File touched

atom/models/deepseek_v4.py (+77, −34)

Alternative to PR #676 (FP8 batched BMM). Replaces the BF16 grouped einsum
in DeepseekV4Attention.forward with a G-loop calling
gemm_a8w8_blockscale_preshuffle_impl per group — the same FP8 + per-128-block
kernel wo_b/wq_b/wkv already use.

Why this approach over the BMM kernel:
- wo_a's V4 shape is B=n_local_groups=2 (with tp=8), K=4096, N=1024.
- The MLA-style batched_gemm_a8w8_per_token_group_prequant kernel grids over
  B and is designed for B≈128. At B=2 it severely underutilizes the 304 CUs.
- This G-loop dispatches G=2 standard 2D GEMMs of shape (M, N=1024, K=4096)
  — N=1024 is well within the kernel's autotune sweet spot, full CU coverage.
- Uses on-disk per-128-block W scale directly: zero precision loss (vs PR
  #676's per-block→scalar collapse).
- process_weights_after_loading becomes a no-op for the default path —
  LinearBase's standard FP8 + shuffle handles wo_a like any other Linear.

Trade-off vs PR #676:
- +2 extra kernel launches per layer in eager mode (act_quant + 2 GEMMs vs
  1 fused BMM); amortized to ~0 in cudagraph mode once V4's .item() syncs
  are removed.
- Each kernel runs faster due to better grid utilization.
- No precision loss from scale-format conversion.

Fallback: ATOM_V4_OA_USE_EINSUM=1 keeps the legacy BF16 einsum path.

Touches one file: atom/models/deepseek_v4.py.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
zufayu (Author) commented May 1, 2026

Closing in favor of #676.

After re-measuring with all caches warm in the same docker session (apples-to-apples), both implementations achieve TPOT parity with the einsum baseline. Memory savings are identical (~2-3% per rank; both use FP8 wo_a storage).

Decision rationale: prefer #676 because:

  • Architectural alignment: #676 (feat(deepseek_v4): FP8 batched BMM for grouped output LoRA wo_a) uses the same Triton kernel (batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant) that MLA's _v_up_proj_and_o_proj already uses in this codebase. A single shared kernel means a single tuning surface and a single optimization target.
  • Single launch per layer (vs G+1=3 launches in this PR). This will matter more in cudagraph mode, where launch overhead vanishes for the BMM but the G-loop still spends GPU dispatch cycles.
  • Slightly leaner code in process_weights_after_loading.

This PR (#677) remains a valid alternative: if accuracy testing later reveals an issue with #676's per-block→scalar quant, this branch is the natural fallback. The branch is retained at feat/deepseek-v4-wo-a-gloop-blockscale.

zufayu closed this May 1, 2026
zufayu pushed a commit that referenced this pull request May 1, 2026
Companion to v4_wo_a_tune.py --microbench (BMM vs einsum). Lets us compare
the two FP8 paths (BMM in PR #676 / G-loop in PR #677) head-to-head against
the einsum baseline on the same shape (B=2, K=4096, N=1024).

Useful because PR #676 BMM came in 1.5-2x SLOWER than einsum at the kernel
level (rocBLAS BMM is just better tuned for B=2 than the AITER Triton BMM).
Need to know if G-loop's gemm_a8w8_blockscale_preshuffle (used widely by
ATOM Linears) fares better.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
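
For reference, the einsum-baseline side of that microbench can be approximated with a few lines of plain PyTorch (the FP8 BMM and G-loop sides need the aiter kernels, whose exact call signatures aren't shown here, so they are omitted; shapes are the B=2, K=4096, N=1024 from the commit above, with a decode-like M):

import torch

B, M, K, N = 2, 4, 4096, 1024
x = torch.randn(B, M, K, dtype=torch.bfloat16, device="cuda")
w = torch.randn(B, K, N, dtype=torch.bfloat16, device="cuda")

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
for _ in range(10):                                # warmup
    torch.einsum("bmk,bkn->bmn", x, w)
torch.cuda.synchronize()
start.record()
for _ in range(100):
    torch.einsum("bmk,bkn->bmn", x, w)
end.record()
torch.cuda.synchronize()
print(f"einsum wo_a baseline: {start.elapsed_time(end) / 100:.3f} ms/iter")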
zufayu pushed a commit that referenced this pull request May 1, 2026
These were research artifacts used to characterize the wo_a kernel choice
(documented in the RFC comment on this PR — see #676 review thread).
Findings:
- aiter Triton FP8 BMM is 1.5-2x slower than rocBLAS BF16 BMM at V4 wo_a
  shape (B=2, K=4096, N=1024).
- Config tuning via aiter's JSON DB doesn't close the gap (B=2 grid
  underutilization is fundamental, not a config choice).
- The G-loop blockscale alternative (PR #677) is even slower (2-7x vs einsum)
  due to the G+1 launches-per-layer overhead at small M.
- The memory saving (−2~3% per rank) is the only kernel-independent win.

The scripts are preserved in git history (commits 26a83b4..761c69e on
this branch) for anyone wanting to reproduce the measurements. They are
not appropriate for the merged production diff.

Production change (atom/models/deepseek_v4.py) unchanged.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>