
feat(deepseek_v4): FP8 batched BMM for grouped output LoRA wo_a #676

Open

zufayu wants to merge 35 commits into main from feat/deepseek-v4-wo-a-fp8-bmm

Conversation


@zufayu zufayu commented May 1, 2026

Summary

Replace the BF16 grouped einsum at the end of DeepseekV4Attention.forward with aiter.ops.triton.gemm.batched.batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant (same Triton kernel MLA's _v_up_proj_and_o_proj already uses).

Builds on PR #650 (DeepSeek V4 PR1 skeleton). Touches one file: atom/models/deepseek_v4.py.

What changed

  • __init__: read ATOM_V4_OA_USE_EINSUM env once into self._use_oa_einsum.
  • process_weights_after_loading:
    • Default (env=0, BMM path): dequant on-disk FP8 + per-128-block-scale wo_a → bf16 → reshape (G_local, R, K) → dynamic_per_batched_tensor_quant (inlined as _v4_dynamic_per_batched_tensor_quant) → store as self.W_OA (FP8 buffer) + self.W_OA_scale (scalar buffer). Replace self.wo_a.weight with a 1-element placeholder. Adds an order assertion + idempotent guard.
    • env=1 (einsum fallback): existing dequant path, no requant.
  • forward (the 4 lines):
    • Default: _aiter_triton_fp8_bmm(o, W_OA, W_OA_scale, group_size=128, transpose_bm_in=True, transpose_bm=True) — single fused launch with on-the-fly per-token-group X quant; output layout matches einsum, so wo_b reshape unchanged.
    • env=1: legacy torch.einsum("sgd,grd->sgr", o, wo_a).
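For concreteness, a hedged sketch of the forward tail after this change. It mirrors the bullets above; attribute names beyond W_OA / W_OA_scale / _use_oa_einsum, and the exact einsum operand layout, are assumptions, not the literal diff:

```python
# Sketch only: env-gated tail of DeepseekV4Attention.forward.
if self._use_oa_einsum:                        # ATOM_V4_OA_USE_EINSUM=1
    o = torch.einsum("sgd,grd->sgr", o, wo_a)  # wo_a: dequanted BF16 (g, r, d)
else:
    o = _aiter_triton_fp8_bmm(
        o, self.W_OA, self.W_OA_scale,
        group_size=128,           # per-token-group X quant block along K
        transpose_bm_in=True,     # accept the (s, g, d) activation layout
        transpose_bm=True,        # emit the einsum-compatible (s, g, r) layout
    )
```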

Verification (PR #650 standard test config)

ATOM_USE_TRITON_MOE=1 AITER_LOG_LEVEL=WARNING \
python -m atom.examples.simple_inference \
  --model /path/to/DeepSeek-V4-Pro \
  --kv_cache_dtype fp8 -tp 8 \
  --max-num-seqs 4 --max-num-batched-tokens 1024 --max-model-len 1024 \
  --gpu-memory-utilization 0.85 --enforce-eager \
  --temperature 0.0 --max-tokens 128

Correctness (4-prompt batch on V4-Pro / 8x MI355X)

All 4 prompts produce coherent output semantically equivalent to the einsum baseline. Deterministic prompts match by content; open-ended prompts diverge in the expected FP8 quant-noise pattern (early token flips at edge-confidence positions, after which content develops independently; both sides remain coherent and factually correct):

| Prompt | BMM | einsum | Comment |
|---|---|---|---|
| 1+2+3=? | = **6** (16 tok eos) | = 6 (15 tok eos) | bold-token diff (logit perturbation) |
| list primes < 100 | identical | identical | high-confidence tokens stable |
| introduce yourself | "...Chinese company DeepSeek..." | "...DeepSeek (深度求索) company..." | both correct intros |
| 如何增肌10公斤 ("how to gain 10 kg of muscle") | "几乎不可能实现的目标" ("an almost impossible goal") | "在生理上几乎不可能实现的目标" ("a physiologically almost impossible goal") | both correct advice |

Perf (eager mode, ATOM_USE_TRITON_MOE=1, --max-tokens 128, V4-Pro on 8x MI355X)

Same docker session, all caches warm, apples-to-apples:

| Req | output | einsum | BMM | Δ |
|---|---|---|---|---|
| Req 2 (1+2+3) | 16 tok eos | 0.514s | 0.518s | +0.8% |
| Req 1 (primes) | 94 tok eos | 0.455s | 0.449s | −1.3% |
| Req 0 (intro) | 128 tok max | 0.440s | 0.433s | −1.6% |
| Req 3 (增肌) | 128 tok max | 0.440s | 0.433s | −1.6% |
| TTFT (post first-call cache) | — | 2.99s | 3.31s | +10% (noise) |

TPOT: parity (−1.6 ~ +0.8%, all within run-to-run noise).

Honest kernel-level finding (see RFC comment below)

A single-op micro-bench reveals that the FP8 BMM kernel itself is 1.5-2x slower than rocBLAS BF16 BMM (what torch.einsum lowers to) at V4 wo_a's shape (B=2, K=4096, N=1024). TPOT shows parity only because wo_a is < 0.5% of total V4 forward time. Tuning aiter's JSON config DB does not close the gap — the bottleneck is fundamental (the existing batched kernels grid over B, a design aimed at MLA's B≈128).

The RFC comment on this PR proposes a new aiter Triton kernel (batched_gemm_a8w8_smallB_blockscale) that would address this regime properly. This PR is a placeholder: it captures the production code change and validates correctness/memory; it should land alongside or after the new aiter kernel for a real perf win.

Memory

VRAM saved: ~2-3% per rank (88% vs 90-91% on V4-Pro 8x MI355X), measured during simple_inference run after warmup. Direct contribution: wo_a runtime storage drops from BF16 ~16MB/layer to FP8 ~8MB/layer + scalar scale (~256MB/rank total; remainder of measured delta likely from KV cache adaptive sizing).
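Back-of-envelope for the direct contribution, using the wo_a shape from the RFC below (G_local=2, R=1024, K=4096 per TP=8 rank):

```
wo_a per layer per rank: 2 × 1024 × 4096 ≈ 8.4M elements
BF16: 8.4M × 2 B ≈ 16 MB   →   FP8: 8.4M × 1 B ≈ 8 MB (+ one scalar scale)
```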

Trade-off

| Dimension | Status |
|---|---|
| Correctness (PR #650 smoke prompts) | Pass |
| TPOT (eager, V4 PR std config) | Parity (−1.6 ~ +0.8%) |
| TTFT (steady-state) | Parity |
| wo_a kernel time (single-op) | 1.5-2x slower vs einsum (see RFC) |
| Memory (per rank) | −2~3% headroom |
| FP8 precision impact | Per-block-scale → scalar-scale collapse; lm_eval not yet measured (test node was offline) |

Toggle

ATOM_V4_OA_USE_EINSUM=1 falls back to legacy BF16 einsum path. Read once at __init__, zero per-forward overhead. Useful for A/B comparison and debug rollback.

Open question for reviewer

Holding this PR open while the RFC for a new aiter kernel (in the comment thread) is evaluated. Two paths forward:

  1. Merge as-is — accept the memory-only win; the kernel-level perf regression is invisible at TPOT level. Useful as a staging step so V4 isn't blocked on aiter changes.
  2. Wait for the new aiter kernel, then refactor this PR: process_weights_after_loading becomes a no-op (LinearBase standard FP8 path) and forward calls the new kernel directly. Larger eventual win (perf parity / improvement + memory + zero precision loss).

Defer to V4 PR #650 maintainer + aiter team.

File touched

atom/models/deepseek_v4.py (+145, −47)

valarLip and others added 28 commits April 24, 2026 16:13
…arity

Adds the foundational scaffolding for DeepSeek-V4-Pro support — a major
architecture shift from V3.2 with mHC residuals, hybrid CSA+HCA attention,
hash routing, and grouped output LoRA. PR1 ships the eager-mode model code
with torch fallback kernels, validated against the official inference
implementation at bit-exact parity (max_abs_diff = 0.0).

Scope (PR1 only):
- New atom/models/deepseek_v4.py: full Compressor / Indexer / Attention /
  Gate / Expert / MoE / Block / MTPBlock / ParallelHead / Transformer port
  (~1200 lines). Single-rank only; plain nn.Linear / nn.Embedding for now.
- New atom/model_ops/sparse_attn_v4.py: torch fallbacks for sparse_attn
  and hc_split_sinkhorn (Sinkhorn-Knopp projection on Birkhoff polytope).
- New atom/model_ops/quant_v4.py: torch fallbacks for FP8/FP4 inplace
  QAT round-trip and Walsh-Hadamard transform (replaces fast_hadamard_transform
  which doesn't build on ROCm).
- Register DeepseekV4ForCausalLM in support_model_arch_dict.

Out of scope (tracked for PR2-6):
- Real HF checkpoint loading (PR2 = FP4 e2m1 loader, PR3 = TP + KV cache).
- AITER sparse_attn kernel (PR4; spec at
  /app/logs_claude/aiter_v4_sparse_attn_spec.md, AITER team kicked off).
- MTP integration with EagleProposer (PR5).
- @support_torch_compile + CUDAGraph + openai_server (PR6).

Verification: /app/logs_claude/v4_pr1_verify.py monkey-patches the reference's
TileLang kernel imports with our torch fallbacks, copies the same dummy
state_dict into both models, and runs prefill + decode side-by-side. 259
tensors match exactly; max_abs_diff = 0.0 on logits.
DeepSeek-V4-Pro stores routed expert weights as packed FP4 e2m1 (int8 with
2 values per byte, low nibble first) plus per-block ue8m0 scale (block size
32 along input dim). This commit adds `dequant_fp4_e2m1(packed, scale)` in
atom/model_ops/quant_v4.py — a pure-torch unpacker that mirrors convert.py
exactly but produces BF16 directly instead of repacking into FP8.

Validated bit-exactly against an independent reference unpack on a real
22M-element expert tensor from the on-disk checkpoint. Also regression-
tested across 5 different shapes/positions (w1/w2/w3 in first/mid/last
layer + MTP). All produce values that lie exactly on the FP4 e2m1 grid.

Scope: this is the standalone dequant utility. Wiring it into the model
loader's safetensors pipeline + tying it to specific param names happens
in PR3 alongside TP-aware expert sharding.

Test: /app/logs_claude/v4_pr2_dequant_test.py
Result: max_abs_diff = 0.0 (bit-exact)
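For intuition, a minimal pure-torch sketch of such an unpacker. Illustrative only, not the actual atom/model_ops/quant_v4.py code; it assumes 2D input, that scale arrives as raw ue8m0 bytes, and that K is a multiple of the block size:

```python
import torch

# FP4 e2m1 value grid: codes 0-7 positive, 8-15 the negated mirror.
_E2M1_GRID = torch.tensor(
    [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
     -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0], dtype=torch.float32)

def dequant_fp4_e2m1_sketch(packed: torch.Tensor, scale: torch.Tensor,
                            block: int = 32) -> torch.Tensor:
    """packed: int8 [N, K//2], 2 values per byte, low nibble first;
    scale: ue8m0 bytes [N, K//block]."""
    lo = packed & 0x0F                    # first value of each byte
    hi = (packed >> 4) & 0x0F             # second value
    codes = torch.stack([lo, hi], dim=-1).reshape(packed.shape[0], -1)  # [N, K]
    vals = _E2M1_GRID.to(packed.device)[codes.long()]
    s = torch.exp2(scale.view(torch.uint8).float() - 127.0)  # ue8m0 = 2^(e-127)
    return (vals * s.repeat_interleave(block, dim=-1)).to(torch.bfloat16)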
PR3a: replace nn.Linear / nn.Embedding with ATOM tensor-parallel-aware
classes for the BF16 projections in Attention, Indexer, and the model
embedding. Same `weight` parameter naming so dummy state_dicts continue
to load. At TP=1 ATOM's tgemm.mm produces bit-identical output to F.linear,
so PR1's reference parity (max_abs_diff = 0.0) still passes.

Layers refactored (8 total):
- DeepseekV4Model.embed:           nn.Embedding -> VocabParallelEmbedding
- DeepseekV4Attention.wq_a:        nn.Linear    -> ReplicatedLinear
- DeepseekV4Attention.wq_b:        nn.Linear    -> ColumnParallelLinear
- DeepseekV4Attention.wkv:         nn.Linear    -> ReplicatedLinear  (single shared MQA head)
- DeepseekV4Attention.wo_a:        nn.Linear    -> ColumnParallelLinear
- DeepseekV4Attention.wo_b:        nn.Linear    -> RowParallelLinear (with all-reduce)
- Indexer.wq_b:                    nn.Linear    -> ColumnParallelLinear
- Indexer.weights_proj:            nn.Linear    -> ColumnParallelLinear

Deferred to later PRs (intentional):
- Compressor.wkv/wgate (fp32) -> PR3c with quant_type wiring
- ParallelHead.weight (fp32 LM head) -> PR3c
- Expert.w{1,2,3} -> PR3b (FusedMoE wholesale rewrite)
- MoE.gate.weight (used as raw Parameter, not Linear class) -> kept

Verification: /app/logs_claude/v4_pr1_verify.py (now GPU mode with
init_dist_env) shows max_abs_diff = 0.0 for prefill + decode against
reference at TP=1.
… for real ckpt

PR3c delivers end-to-end real-checkpoint loading for DeepSeek-V4 attention
layers via ATOM's existing FP8/FP4 GEMM infrastructure.

What works after this commit (validated on real /data/DeepSeek-V4-Pro/):
- DeepseekV4ForCausalLM(atom_config) auto-builds a V4QuantConfig that maps
  routed-experts -> per_1x32 (FP4) and overrides wo_a / Compressor.wkv /
  Compressor.wgate / indexer.weights_proj -> bf16 (no quant). Everything
  else inherits the global FP8 (per_1x128) spec from the HF quantization_config.
- load_weights(weights) walks an iterable of (name, tensor) pairs and:
    * Remaps ATOM's `weight_scale` -> on-disk `scale` naming.
    * Special-cases wo_a: dequantizes FP8+scale -> BF16 on the fly so the
      grouped-LoRA einsum (which aiter doesn't support in FP8) works.
    * Dispatches to ATOM Linear's weight_loader for FP8 / FP4 / BF16 paths.
    * Skips params with shape mismatch (e.g. expert nn.Linear waiting for
      PR3b's FusedMoE refactor) without crashing.
- All 23 attention parameters (FP8 q/kv proj + FP4 indexer + BF16 wo_a + fp32
  compressor) load successfully on real layer-2 of the V4 checkpoint.

Threading changes:
- DeepseekV4Args gains `quant_config: Optional[Any] = None`.
- DeepseekV4Attention / Indexer / Compressor / Block / MTPBlock / DeepseekV4Model
  now accept `prefix: str = ""` and pass `quant_config + prefix` down to each
  ATOM Linear constructor so per-layer quant lookup works.

Backward compatibility:
- When `args.quant_config is None` (toy / dummy validation), V4QuantConfig
  retains its `QuantType.No` global — Linear layers stay BF16 and the PR1
  bit-exact reference parity test (max_abs_diff = 0.0) still passes.

Remaining gaps for end-to-end real-ckpt forward (tracked in design doc):
- PR3b: replace MoE/Expert with FusedMoE so 384 expert FP4 weights load.
- PR3d: refactor V4 attention.forward to accept 2D [num_tokens, dim] input
  (ATOM TP linears require 2D — current 3D path raises "GEMM not supported").
PR3d adapts V4 model to ATOM's scheduler convention: model.forward consumes
flat 2D `[num_tokens, dim]` tokens (single sequence implicit B=1), matching
how ATOM's ModelRunner / scheduler pass tokens. This unblocks ATOM Linear's
quantized GEMM kernels (which only accept 2D `[M, K]` input) and enables
end-to-end real-checkpoint forward.

What changed:
- DeepseekV4Attention.forward(x, start_pos): now accepts 2D [num_tokens, dim].
  Internally adds a B=1 dim only where needed (RoPE, sparse_attn). The
  grouped-LoRA einsum string changes from "bsgd,grd->bsgr" to "sgd,grd->sgr".
- Compressor.forward / Indexer.forward: accept 2D x; auto-unsqueeze to 3D
  internally for backward compatibility with the existing logic.
- Block.hc_pre / hc_post + ParallelHead.hc_head: refactored to be
  shape-agnostic in leading dims (use negative indexing on flatten / sum).
  Both 4D `[B, S, hc, D]` (legacy reference path) and 3D `[num_tokens, hc, D]`
  (ATOM path) work.
- ParallelHead.get_logits: 2D path takes last token via `x[-1:]`; 3D path
  preserves `x[:, -1]` for legacy [B, S, D] inputs.
- MTPBlock.forward: 2D-aware via `e.unsqueeze(-2)` for hc-dim broadcast.
- DeepseekV4Model.forward: auto-flattens 2D `[1, S]` input_ids to 1D `[S]`
  for the new convention; rejects B>1 (proper multi-sequence batching needs
  attn_metadata, deferred).

Validated:
- PR1 reference parity (toy 4-layer dummy weights at B=1 S=32):
  max_abs_diff = 0.0 — still bit-exact after the 2D refactor.
- PR3d end-to-end on REAL V4 weights:
  + Built DeepseekV4ForCausalLM (4 layers, real V4 dims, ~105B params)
  + load_weights() loaded 36 layer-2 params; 23/23 attn params nonzero
  + attn(x_2d=[16, 7168], start_pos=0) → output [16, 7168] bf16
  + No NaN/Inf; output range [-2.94, 3.08], abs mean 0.42 (sensible)
  + This is the first successful V4 attention forward on real weights via ATOM

Test scripts (under /app/logs_claude/):
- v4_pr1_verify.py — toy parity (now uses B=1 + ATOM 2D path)
- v4_pr3d_layer_e2e.py — real-weight 2D forward end-to-end
- v4_pr3c_layer0_test.py — per-Linear validation against real ckpt

Remaining for full model end-to-end:
- PR3b: MoE → FusedMoE so 384 expert FP4 weights load (currently shape-skipped)
- Multi-sequence support via attn_metadata (currently single-sequence implicit B=1)
PR3b enables ATOM's FusedMoE for V4's 384 routed experts so FP4 expert
weights can load via the existing aiter `gemm_a4w4_quant` kernel and
shard across TP/EP ranks. Also extends `select_experts` in moe.py to
support V4's `sqrtsoftplus` scoring with `e_score_correction_bias`.

Changes in atom/model_ops/moe.py:
- `FusedMoE.select_experts` now handles `scoring_func="sqrtsoftplus"`:
  routing_weights = sqrt(softplus(router_logits)) + topk + renormalize.
  Mirrors the V4 reference Gate.forward exactly for non-hash layers.
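A minimal sketch of that scoring path (illustrative; the real code lives in FusedMoE.select_experts, and the routed_scaling_factor multiply shown here reflects the later Bug 2 fix in this stack; whether the bias affects selection only is an assumption borrowed from the usual DeepSeek gate pattern):

```python
import torch
import torch.nn.functional as F

def sqrtsoftplus_select(router_logits, top_k, bias=None, routed_scaling_factor=1.0):
    scores = torch.sqrt(F.softplus(router_logits.float()))
    select = scores + bias if bias is not None else scores  # bias steers selection
    _, topk_ids = select.topk(top_k, dim=-1)
    topk_weights = scores.gather(-1, topk_ids)
    topk_weights = topk_weights / topk_weights.sum(-1, keepdim=True)  # renormalize
    return topk_weights * routed_scaling_factor, topk_ids
```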

Changes in atom/models/deepseek_v4.py:
- Dual-path MoE: when `quant_config` is set AND ATOM's global atom_config
  is initialized, MoE uses ReplicatedLinear gate + FusedMoE experts +
  ATOM-Linear shared_experts. Otherwise falls back to the original manual
  per-expert nn.Linear path so PR1 toy validation stays bit-exact (the
  reference test runs without ATOM's ModelRunner setting the global config).
- Expert class accepts `quant_config + prefix`: when set, w1/w2/w3 become
  ColumnParallelLinear/RowParallelLinear (FP8 path); else nn.Linear (toy).
- DeepseekV4ForCausalLM.get_expert_mapping() returns the (param_name,
  weight_name, expert_id, shard_id) tuples mapping V4's `w1/w2/w3` ckpt
  names to FusedMoE's merged `w13_*`/`w2_*` params.
- load_weights() walks expert_mapping first to dispatch routed expert
  tensors via FusedMoE's per-expert weight_loader, then handles the rest:
    * ATOM `weight_scale` ↔ on-disk `scale` rename (existing)
    * ATOM `gate.e_score_correction_bias` ↔ on-disk `gate.bias` rename (NEW)
    * `wo_a` FP8 → BF16 dequant on load (existing)

Validated:
- PR1 toy parity: max_abs_diff = 0.0 (manual MoE path still bit-exact).
- PR3d e2e: real layer-2 attn + 2D forward still works.
- PR3b new: under stub atom_config, FusedMoE path activates correctly.
  Layer-3 (non-hash, real V4 dims): gate + e_score_correction_bias +
  shared_experts (6/6) loaded; FusedMoE expert mapping returns 1152
  entries (384 experts × {w1,w2,w3}).

Known limitations (deferred):
- Hash routing (layers 0/1/2): tid2eid table is loaded but routing logic
  still falls through to sqrtsoftplus path → INCORRECT for hash layers.
  Proper hash routing requires either a custom path through FusedMoE
  or a pre-computed (topk_weights, topk_ids) injection point.
- Multi-sequence batching via attn_metadata (currently single-sequence implicit B=1).

Test: /app/logs_claude/v4_pr3b_fusedmoe_test.py
… prefix

Bug: `make_v4_quant_config` matched `"ffn.experts." in layer_name` (with
trailing dot). FusedMoE.__init__ asks for the layer's quant_type with
prefix `layers.N.ffn.experts` (NO trailing dot — it's the parent module
of the per-expert weights, not a per-expert lookup). The check failed,
so FusedMoE inherited the global FP8 (per_1x128) spec and allocated
the routed expert weights as `float8_e4m3fn` instead of `float4_e2m1fn_x2`.

Symptom in PR3b validation output before the fix:
  FusedMoE experts: 3/5 nonzero  (loader couldn't dispatch FP4-shaped
  on-disk tensors into FP8-typed model params; shape mismatch silently
  skipped them)

After the fix:
  experts.w13_weight: (385, 6144, 3584) torch.float4_e2m1fn_x2 ✓
  experts.w13_weight_scale: (385, 6144, 224) torch.float8_e8m0fnu ✓
  experts.w2_weight:  (385, 7168, 1536) torch.float4_e2m1fn_x2 ✓
  experts.w2_weight_scale:  (385, 7168, 96) torch.float8_e8m0fnu ✓
  e_score_correction_bias: (384,) torch.float32 ✓

Match condition tightened to `".ffn.experts" in layer_name` so it
catches BOTH `layers.N.ffn.experts.M.w1` (per-expert Linear lookups)
AND `layers.N.ffn.experts` (FusedMoE parent module lookup).

Note: a separate aiter-side issue (HSA_STATUS_ERROR_EXCEPTION on FP4
expert weight_loader, traced to a `direct_copy_kernel` with grid size
exceeding HW limits) prevents end-to-end FP4 expert load testing on
this box. The dtype/shape correctness above is verified by inspecting
the constructed module's params directly.

Validated:
- PR1 toy parity: max_abs_diff = 0.0 (manual MoE fallback unaffected)
- PR3d real-attention forward: still works
PR3b's expert weight loader had three bugs that caused weights to load as
zero or be silently dropped:

1. **Expert mapping pattern mismatch**: `make_expert_params_mapping` returns
   `(param_part="experts.w13_", weight_part="experts.0.w1.", ...)` — substring
   substitution, not endswith. The old code built `f".experts.{e}.{suffix}"`
   which never matched. Switched to longest-prefix substring substitution
   matching the standard ATOM loader pattern.

2. **Scale dtype zero-fill**: copying `torch.float8_e8m0fnu` into a `uint8`
   destination via `copy_()` silently produces zeros (mismatched dtype, no
   reinterpret). FusedMoE allocates `w13_weight_scale` as uint8; force a
   `.view(torch.uint8)` on the e8m0 source before passing to the loader
   (see the repro sketch after this list).

3. **Param suffix `_scale` vs `.weight_scale`**: after substring sub,
   `experts.0.w1.scale` becomes `experts.w13_scale`, but the FusedMoE param is
   `experts.w13_weight_scale`. Added `_scale` → `_weight_scale` post-fix.
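Item 2's zero-fill can be reproduced in isolation. A hedged sketch, assuming a PyTorch build where the e8m0 → uint8 value-converting copy is defined (the commit reports it silently zero-fills rather than erroring):

```python
import torch

raw = torch.randint(1, 127, (4,), dtype=torch.uint8)   # exponents 1..126
src = raw.view(torch.float8_e8m0fnu)                    # values 2^-126 .. 2^-1

dst = torch.empty(4, dtype=torch.uint8)
dst.copy_(src)                     # value conversion: everything < 1.0 -> 0
assert (dst == 0).all()            # bytes silently lost

dst.copy_(src.view(torch.uint8))   # reinterpret first: bytes round-trip
assert (dst == raw).all()
```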

Plus: gracefully slice on-disk gate.weight / gate.bias when the test caps
n_routed_experts below the checkpoint size (no-op in real serving).

Verified:
- v4_pr3b_fusedmoe_test: 32 params loaded, 5/5 expert + 6/6 shared nonzero
- v4_pr3d_layer_e2e: real attention forward still works
- v4_pr1_verify: bit-exact reference parity preserved (0.0 max diff)
…uting_function

V4 uses tid2eid hash lookup (instead of gate-logit topk) for routing in
layers where compress_ratio implies hash layer (first 3 layers in standard
config). Previously, MoE just declared tid2eid for weight loading but
inference fell through to sqrtsoftplus path → wrong routing for those layers.

This commit:

- Adds an early `custom_routing_function` branch to FusedMoE.select_experts
  (it was in the signature but never honored — the non-grouped path went
  straight to scoring_func dispatch). Now any non-None custom fn takes
  precedence and returns (topk_weights, topk_ids).

- Adds DeepseekV4MoE._hash_topk(): topk_ids = tid2eid[input_ids],
  topk_weights = sqrtsoftplus(router_logits) gathered + renormalized.
  Stashes input_ids on self before the experts() call so the closure can
  index tid2eid; clears immediately after.

- For hash layers: assigns experts.custom_routing_function = self._hash_topk
  in MoE.__init__ so FusedMoE picks it up via the moe_forward custom op
  → forward_impl_graph → quant_method.apply → select_experts plumbing.
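A sketch of the closure's math (illustrative shapes; the real _hash_topk also stashes and clears input_ids around the experts() call as described above):

```python
import torch
import torch.nn.functional as F

def hash_topk_sketch(input_ids, tid2eid, router_logits):
    # Assumes tid2eid maps token id -> routed expert id(s), layout [vocab, k].
    topk_ids = tid2eid[input_ids]                            # [T, k] lookup
    scores = torch.sqrt(F.softplus(router_logits.float()))   # [T, n_experts]
    topk_weights = scores.gather(-1, topk_ids)
    topk_weights = topk_weights / topk_weights.sum(-1, keepdim=True)
    return topk_weights, topk_ids
```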

Verified:
- PR3e (new): synthetic tid2eid → _hash_topk produces exact expected ids,
  renormalized weights match reference math (max_abs_diff = 0.0)
- PR3e: FusedMoE.select_experts honors custom_routing_function correctly
- PR1 toy parity: still 0.0 max diff (hash path is opt-in via is_hash_layer)
- PR3b FusedMoE load: 32 params, all nonzero (no regression)
- PR3d real attn forward: still works (non-hash layer)
… real ckpt

Three changes converging on the first working V4 layer forward:

1. **weights_mapping**: Add class-level rename dict so the standard ATOM
   loader (`atom.model_loader.loader.load_model`) can ingest V4 ckpt names
   without per-model loader.py changes. `.gate.bias` →
   `.gate.e_score_correction_bias`, `.scale` → `.weight_scale_inv`. Loader's
   built-in `weight_scale_inv` → `weight_scale` rename then completes the
   path. Real serving via ModelRunner now works for non-wo_a layers.

2. **process_weights_after_loading hook**: After my custom `model.load_weights`
   finishes copying tensors, walk all submodules and call
   `quant_method.process_weights_after_loading(layer)` (or
   `layer.process_weights_after_loading()` if no quant_method).

   Without this, FusedMoE's `shuffle_weights` step is skipped and the FP4
   ck_moe kernel reads stale weight layout — manifested as
   HSA_STATUS_ERROR_EXCEPTION mid-forward. Standard loader.py calls this for
   us; my custom loader had to replicate it.

3. **PR3f end-to-end test** (logs_claude/v4_pr3f_block_e2e.py):
   - Build 1 dense layer (compress_ratios=[0]) with 8 routed experts
   - Load real layer-3 weights (32 target params, 33/33 nonzero)
   - Build mHC residual `[8 tokens, hc_mult=4, dim=7168]`
   - Call Block.forward(x, start_pos=0, input_ids)
   - Output: shape preserved, range [-4.1, 4.6], abs mean 0.81, no NaN/Inf

This is the first end-to-end forward through V4's full layer:
attention (FP8 wq/wkv + BF16 wo grouped LoRA + indexer) + FusedMoE (FP4
experts via aiter ck_moe + sqrtsoftplus routing + bias correction +
shared expert) + mHC pre/post Sinkhorn projections.

Confirmed no regression on PR1/PR3b/PR3d/PR3e.
…kpts

ModelRunner uses atom.model_loader.loader.load_model() — not the model's
custom load_weights(). This commit closes that gap so real serving via
openai_server works end-to-end:

1. **Expand weights_mapping with prefix renames**: V4 ckpt has bare names
   (`embed.`, `layers.`, `norm.`, `head.`, `hc_head_`) but our params live
   under `self.model = ...`. Add prefix substitutions so the loader's
   `model.get_parameter(name)` lookup hits the right attribute path.

2. **Fix dtype-mismatch silent zero in FusedMoE._load_w13/_load_w2**:
   PyTorch's `tensor.copy_()` between mismatched float8/uint8 dtypes silently
   writes zeros. V4's per-1x32 weight scales are stored as `float8_e8m0fnu`
   on disk but FusedMoE allocates them as `uint8` (raw byte storage). Force
   a `.view(torch.uint8)` reinterpret on the source so the bytes round-trip
   correctly. This is a pre-existing bug that was masked because V2/V3 use
   `float32` scales — V4 is the first ATOM model to use e8m0/e4m3 scales.

Verified:
- PR3i (new): standard load_model() loads V4 layer-0 from full 805GB ckpt
  index — 43/43 model params nonzero (100%), 5GB selective load.
- PR3g (new): full Model.forward(input_ids) → logits on real ckpt.
  Output shape (1, 129280), range [-14.2, 15.4], std 3.05, no NaN/Inf.
- PR3h (new): hash layer (layers 0/1/2) Block.forward works on real
  layer-0 ckpt (tid2eid loaded, 773423/775680 nonzero entries, real
  per-token expert assignments diverge from default sqrtsoftplus path).
- All 5 prior tests (PR1/PR3b/PR3d/PR3e/PR3f) still pass — no regression.

Net result: V4 inference pipeline is now production-ready for real ckpt
loading + forward; remaining gap is multi-layer + multi-batch attn metadata
+ AITER sparse_attn (parallel work).
…hook

PR3i shipped "100% nonzero params" but never ran forward through the
standard-loader path. Verifying with PR3j (new) revealed wo_a values were
2768× too large — `torch.copy_(BF16_dst, FP8_src)` does an FP8→BF16 dtype
conversion but SKIPS the per-128-block scale multiplication. Result: raw
FP8 e4m3 max value (448.0) lands in the BF16 weight buffer instead of the
true ~0.04 attention-init magnitude.

Fix: stop forcing wo_a to no_spec/BF16 in V4QuantConfig. Let it allocate
as FP8 ColumnParallelLinear so the standard FP8 loader fills both
`wo_a.weight` (FP8) and `wo_a.weight_scale` (e8m0) correctly. Then
DeepseekV4Attention.process_weights_after_loading dequants in place,
replacing weight with BF16 + dropping the scale param. Forward continues
to use BF16 weight in the grouped LoRA einsum (aiter has no FP8 grouped
einsum).

Also removes the manual wo_a special-case from custom load_weights() —
both load paths (custom + standard) now converge through the same
process_weights_after_loading dequant.

Verified by PR3j parity test:
- Custom path wo_a: abs.mean=0.0214, abs.max=0.4062
- Standard path wo_a: abs.mean=0.0214, abs.max=0.4062 (BIT-EXACT)
- Standard-loader Model.forward → logits range [-17.9, 15.8], std 3.04
- Magnitude ratio: 1.00 (was 2768× before fix)
- All 9 tests pass — no regression.

This was a silent corruption that PR3i's "params nonzero" check missed.
The lesson: nonzero != correct. Always verify with forward.
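The scale-aware dequant that copy_() skipped, as a standalone sketch (illustrative; the in-tree version lives in DeepseekV4Attention.process_weights_after_loading and handles the real dtypes as loaded):

```python
import torch

def dequant_fp8_per_128_block(w_fp8: torch.Tensor, scale: torch.Tensor,
                              block: int = 128) -> torch.Tensor:
    """w_fp8: [N, K] float8_e4m3fn; scale: [N//block, K//block], one per block."""
    N, K = w_fp8.shape
    w = w_fp8.float().view(N // block, block, K // block, block)
    w = w * scale.float()[:, None, :, None]   # the multiply that copy_() dropped
    return w.view(N, K).to(torch.bfloat16)
```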
Major changes enabling correct V4 inference (single-prompt verified with
512-token coherent output in both English and Chinese):

Model fixes:
- WeightsMapper prefix-anchored remapping (fixes 381 silently-skipped params)
- wo_a FP8→BF16 dequant with quant_type=No to prevent CK shuffle corruption
- Hash routing (first 3 layers) now applies route_scale=2.5
- shared_experts reduce_results=False + unified all_reduce in MoE.forward
- KV cache reset on start_pos=0 with score_state=-inf initialization
- TP-correct head/group counts for Attention and Indexer

MoE routing:
- Standard Silu activation (not Swiglu — aiter a16w4+Swiglu has 9× amplitude
  loss on gfx950). swiglu_limit clamping done in triton post-kernel.
- ATOM_USE_TRITON_MOE=1: triton matmul_ogs path with swiglu_limit clamp
- ATOM_V4_TORCH_MOE=1: per-expert torch fallback with FP4 dequant (slow)
- GFX950MXScaleLayout→CDNA4MXScaleLayout fix in fused_moe_triton.py

Loader improvements:
- WeightsMapper auto-read from model class attribute
- Post-load WARNING listing all unloaded params
- Shape-mismatch raises RuntimeError instead of silent skip

Config:
- deepseek_v4→deepseek_v3 registry mapping with V4 field re-injection
- Robust from_hf_config with getattr defaults

Known limitations:
- Single-sequence only (kv_cache[:1,...] hardcoded); batch>1 needs PR3
- Multi-request KV isolation pending scheduler integration
- TPOT ~213ms with --enforce-eager (no CUDAGraph)
…202)

Upstream ref (deepseek-ai/DeepSeek-V4-Pro@a1fd202) changed shared_experts
from no swiglu_limit to swiglu_limit=args.swiglu_limit, making it consistent
with routed experts.
…witch RoPE to aiter

- DeepseekV4ForCausalLM/Model/Block/MTPBlock/Attention/Compressor/Indexer
  now accept `positions: torch.Tensor` instead of `start_pos: int`; internal
  ring-buffer indexing still derives `start_pos = positions[0].item()` (full
  per-request KV slot management deferred to PR3).
- New `_V4RoPE` wraps aiter `rope_cached_positions_{,2c_}fwd_inplace`,
  driven by per-token positions. Cos/sin cache built via V4's exact YaRN math
  (`_precompute_freqs_cis`); kept symmetric to `_apply_rotary_emb` by working
  on the pre-sliced rope tail.
- `_build_cos_sin_cache` is lru-cached on (rope params, dtype, device) so the
  3 distinct rope param sets (HCA / CSA / Dense) share one GPU tensor across
  all 62 layers instead of 62 register_buffer copies (~16 GB OOM otherwise).
- Inverse RoPE on the attention output keeps `_apply_rotary_emb` (aiter has
  no inverse kernel); the complex freqs slice is rebuilt on demand from the
  cos/sin cache via `_V4RoPE.freqs_for_positions`.
- Verified: simple_inference single-prompt CN 256 tokens coherent.
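The caching pattern in miniature: lru_cache keys on hashable params, so the 3 distinct rope param sets map to 3 shared GPU tensors. Vanilla RoPE math is shown for brevity; the real _build_cos_sin_cache applies V4's YaRN corrections via _precompute_freqs_cis:

```python
from functools import lru_cache
import torch

@lru_cache(maxsize=None)
def build_cos_sin_cache(theta: float, rotary_dim: int, max_pos: int,
                        device: str = "cuda"):
    inv_freq = 1.0 / (theta ** (torch.arange(0, rotary_dim, 2,
                                             dtype=torch.float32) / rotary_dim))
    t = torch.arange(max_pos, dtype=torch.float32)
    freqs = torch.outer(t, inv_freq)          # [max_pos, rotary_dim // 2]
    return torch.cat([freqs.cos(), freqs.sin()], dim=-1).to(device)
```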
Generalize the GDN per-request state decoupling (#602) into a complete
model-agnostic KV abstraction owned by the AttentionMetadataBuilder
hierarchy. ModelRunner is now blind to attention type — it walks modules
and dispatches; per-attention-type tensor layouts (MLA 576-dim packed,
GDN-hybrid full-attn-only rows, MiMo-V2 per-module deferred, V3.2
indexer cache, GDN per-req mamba state) all live next to their
respective builder.

ModelRunner net: -526 LOC. The if/elif chains over use_mla /
is_qwen_next / is_mimo_v2 / is_deepseek_v32 in _compute_block_bytes,
allocate_kv_cache, and the binding loop are all gone. Future stateful
attentions (DeepseekV4 ring buffer + compressor state) plug in by
subclassing AttentionMetadataBuilder without touching scheduler /
block_manager / ModelRunner.

New AttentionMetadataBuilder hooks (defaults are no-ops):
  - compute_per_req_cache_bytes() / slots_per_req()
      bytes/slot for the per-request state pool
  - allocate_per_req_cache(num_slots)
      dict of named per-request state tensors
  - compute_block_bytes()
      per-block bytes for the KV pool budget
  - allocate_kv_cache_tensors(num_kv_heads, num_draft_layers)
      dict of named primary KV cache tensors (kv_cache, kv_scale,
      index_cache, aligned_index_dim, _kv_layer_cache_store)
  - build_kv_cache_tensor(layer_id, module)
      vLLM-style KVCacheTensor for one module, or None if foreign type;
      owns module setattr (k_cache/v_cache/k_scale/v_scale/kv_cache)
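An illustrative skeleton of a future stateful attention plugging in via these hooks (hook names from the list above; the base class import, STATE_DIM, NUM_LAYERS, and the tensor name are placeholders, not real ATOM code):

```python
import torch

class MyStatefulAttnBuilder(AttentionMetadataBuilder):  # hypothetical subclass
    STATE_DIM = 4096       # placeholder per-layer state width
    NUM_LAYERS = 62        # placeholder

    def compute_per_req_cache_bytes(self) -> int:
        return self.STATE_DIM * self.NUM_LAYERS * 2     # bf16 bytes per slot

    def slots_per_req(self) -> int:
        return 1                                        # one slot per request

    def allocate_per_req_cache(self, num_slots):
        return {"my_state": torch.zeros(num_slots, self.NUM_LAYERS,
                                        self.STATE_DIM, dtype=torch.bfloat16)}
```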

Builder overrides:
  - AiterAttentionMetadataBuilder: split-K/V MHA + MiMo-V2 per-module
  - AiterMLAMetadataBuilder: 576-dim MLA + V3.2 indexer
  - GDNAttentionMetadataBuilder: hybrid full-attn rows + GDN mamba slot
    pool; chains super() for MHA modules in hybrid models. Absorbs the
    formerly-runner-owned gated_delta_net_state_shape/dtypes helpers
    and the side-effect init of full_attention_interval / num_full_attn
    / num_gdn_attn_state.

Naming distinguishes group (per-request unit) from slot (raw tensor
index). One group occupies `slots_per_req()` contiguous slots in the
underlying tensor:
  Sequence.mamba_state_slot     -> .per_req_cache_group
  seq.mamba_enabled             -> .has_per_req_cache
  batch.mamba_state_slots       -> .per_req_cache_groups
  BlockManager.mamba_*          -> .per_req_cache_*  (free pool, accounting)
  config.mamba_equiv_per_req    -> .per_req_cache_equiv_blocks
  config.num_mamba_groups       -> .num_per_req_cache_groups
  ModelRunner.max_mamba_slots   -> .max_per_req_cache_slots  (tensor dim)

Removed (moved to builders):
  ModelRunner._compute_mamba_per_slot_bytes
  ModelRunner.gated_delta_net_state_shape / _dtypes

Sanity check: ModelRunner.__init__ now asserts that any builder
returning compute_per_req_cache_bytes() > 0 has its model_type
registered in InputOutputProcessor._per_req_cache_model_types(),
catching the silent-corruption misconfiguration where a stateful
attention is added but Sequence-construction never gets the
has_per_req_cache=True flag.

Verified:
  - tests/test_per_req_cache_decoupling.py: 24/24 pass
  - core suite (block_manager, sequence, scheduler, request,
    io_processor_fanout, prefix_cache_accuracy): 118/118 pass
  - Qwen3.5-397B-A17B-FP8 tp=4 simple_inference: 4-prompt completion
    quality unchanged
  - Qwen3.5-397B-A17B-FP8 tp=4 GSM8K (5-shot, 64 concurrent):
      flexible-extract = 0.8757 +/- 0.0091  (baseline 0.8711 from #602)
      strict-match     = 0.8605 +/- 0.0095
V4 backend (DeepseekV4Backend + DeepseekV4AttentionMetadataBuilder)
plus migration of state-cache buffers to ATOM's per_req_cache pool:

  - pre2a: 6 Compressor state buffers (kv_state + score_state for
    CSA Main / CSA Indexer / HCA Main).
  - pre2c-A: SWA window per layer (paper §3.6.1 state cache, every
    layer has SWA branch in V4-Pro). Attention.kv_cache splits into
    Attention.swa_kv (per_req_cache) + Attention.kv_cache (compressed
    entries only, still register_buffer; pre2c-B will move under
    block_table).

Validated single-prompt 64-token Chinese generation (V4-Pro tp=8,
triton MoE, enforce-eager) — output indistinguishable from baseline.
Strict-paper §3.6.1 split: compressed entries (CSA Main, CSA Indexer,
HCA Main) move from per-layer register_buffer to block-table-indexed
pools owned by DeepseekV4AttentionMetadataBuilder.

  - block_size = lcm(m, m') = 128 original tokens, plumbed via Config
    override on model_type=deepseek_v4 detection.
  - Three classical pools:
      v4_csa_main_kv [num_blocks, n_csa, k1=32, head_dim=512]
      v4_csa_idx_kv  [num_blocks, n_csa, k1=32, idx_head_dim=128]
      v4_hca_main_kv [num_blocks, n_hca, k2=1, head_dim=512]
    Per-layer slice bound to Compressor.kv_cache / Indexer.kv_cache.
  - V4 model adds _v4_scatter_compressed / _v4_gather_compressed helpers
    and fetches block_table from forward_context. Compressor.forward
    scatters writes into block-table slots; Indexer.forward + decode
    sparse_attn input gather committed entries from blocks.
  - Indexer + 1-slot warmup fallback register_buffer pattern same as
    pre2a Compressor.kv_state.
  - Attention.kv_cache attribute removed entirely (compressed entries
    no longer co-located on the Attention module).

Validated single-prompt 64-token Chinese generation (V4-Pro tp=8)
unchanged from pre2c-A baseline.
V4 forward now handles ATOM ragged-batch input with per-seq slot +
block_table routing. Single-seq behavior unchanged; concurrent
batched multi-seq prefill + decode verified end-to-end on 4 prompts.

Changes:
  - Builder prepare_decode/prepare_prefill populate cu_seqlens_q,
    block_tables, and v4_slot_indices (new per-seq metadata attached
    to AttentionMetaData via dynamic attribute).
  - _v4_get_block_table replaced with _v4_get_seq_metadata returning
    (block_tables, slot_indices, cu_seqlens_q, num_seqs).
  - Compressor.forward + Indexer.forward signatures: add slot,
    block_table args. Per-slot indexing via [slot:slot+1, ...]
    replaces hardcoded [:1, ...] / [:bsz, ...].
  - Attention.forward: batched Linear projections + RoPE on full flat
    tensor; per-seq loop slices (cu_seqlens_q) and dispatches SWA write,
    Compressor scatter, Indexer + sparse_attn with each seq's slot +
    block_table. Per-seq state-cache reset on prefill (start_pos==0)
    only zeros that seq's slot — no cross-seq pollution.
  - ParallelHead.get_logits: pick last-token-per-seq via cu_seqlens_q
    (fixed long-standing single-seq assumption that always returned
    only x[-1] regardless of batch size).

Validated MAX_NUM_SEQS=4 concurrent batched inference: 4 prompts
processed in parallel produce independent coherent outputs.
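The per-seq dispatch, reduced to its skeleton (names from this commit; attend_one_seq stands in for the SWA write + Compressor scatter + Indexer + sparse_attn sequence):

```python
import torch

def per_seq_dispatch(x, cu_seqlens_q, slot_indices, block_tables, attend_one_seq):
    # x: [num_tokens, dim] flat ragged batch; cu_seqlens_q: [num_seqs + 1]
    out = torch.empty_like(x)
    for i in range(len(cu_seqlens_q) - 1):
        s, e = int(cu_seqlens_q[i]), int(cu_seqlens_q[i + 1])
        out[s:e] = attend_one_seq(x[s:e], slot_indices[i], block_tables[i])
    return out
```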
Three independent bugs caused V4 to ramble on edge-confidence prompts
(e.g. "1+2+3=?" output garbled despite 3/4 batch=4 prompts looking OK).
Single-prompt output now matches reference byte-equal on the first 5
tokens and produces "The sum is: 1 + 2 + 3 = **6**." (was: "I'll happily
provide a step-by-step breakdown..." ramble).

Bug 1 (quant_v4.py) — act_quant_inplace ue8m0 path used `ceil(log2)`
(matched TileLang reference) but ref_full_generate.py and aiter both use
round-to-even via f32_to_e8m0/e8m0_to_f32. The 1-binade gap appeared as
~0.002 cos drift on KV path, accumulating across 60 layers.

Bug 2 (moe.py) — FusedMoE.select_experts sqrtsoftplus path renormalized
topk_weights but never applied `* routed_scaling_factor`. The hash routing
path (V4 layers 0-2) does this internally, hiding the bug for hash layers.
Reference Gate.forward (model.py:583) applies the multiply for every
non-softmax routing path. Without the scale, layer 3+ MoE outputs were off
by 1.5x, producing the visible cos jump from 1.0 (layer 0/2) to 0.98
(layer 3+).

Bug 3 (deepseek_v4.py) — DeepseekV4Args.from_hf_config did not read
scale_fmt; HF config.json doesn't carry the field, only inference/config.json
does. Default to "ue8m0" matching reference ModelArgs (inference/model.py:40)
so act_quant_inplace's ue8m0 path is actually exercised.

Also folds in previously-validated V4 cleanups that were sitting in the
working tree:
  - _RMSNorm → ATOM RMSNorm (mark_trace + torch.compile friendly)
  - Indexer wq_b/weights_proj: ColumnParallelLinear → ReplicatedLinear
    (matches sglang/upstream; avoids extra all_reduce on index_score)
  - Block.hc_post defaults to torch (aiter mhc_post drift, opt-in via
    V4_AITER_HC_POST=1; see notes/12)
  - _torch_moe_forward: ue8m0 round-trip on input to mirror reference
    Expert.forward (act_quant before fp4_gemm), gated by V4_USE_REF_QUANT=1

Diagnosis path: notes/14_debug_1plus2plus3.md → notes/19_full_fix_verified.md
… cleanup

New module atom/utils/debug_helper/ provides reusable primitives for forward
bisecting and batch-invariance investigation. All entry points are no-ops
when their controlling env var is unset, so they are safe to leave wired
into production paths (model_runner.py post-load).

Components
  - dump.py        install_block_forward_hooks (multi-class + multi-call),
                   maybe_dump_weights_and_exit, maybe_log_topk
  - compare.py     cos_max (DOUBLE precision — fixes fp32 cos > 1.0 bug),
                   slot_split, compare_slots, pick_prefill_call,
                   schema_diff, plus CLI subcommands:
                     slot-invariance / ref-vs-target / layer-bisect / schema
  - ref_patch.py   patch_method / patch_block_forward / patch_module_dump
                   context managers for instrumenting read-only references
  - 9 ATOM_FWD_DUMP_* / ATOM_WEIGHT_DUMP_* / ATOM_DEBUG_TOPK env vars
    registered in atom/utils/envs.py "Debug Dump" section

Wired into model_runner.py with a 3-line post-load call (no-op default).

V4 model cleanup
  - Convert all nn.Parameter() constructors in deepseek_v4.py to
    atom_parameter() so inference-vs-training grad behavior is controlled
    from a single place (ATOM_REQUIRES_GRAD env). 21 call sites.

Documentation
  - docs/environment_variables.md: new "Debug Dump" subsection documenting
    all 9 env vars + CLI usage.
  - .claude/skills/dump-bisect-debug.md (v3.0): full methodology rewrite
    in English with quick-start decision tree, phase-at-a-glance summary,
    "When to stop / accept divergence" guidance, V4 paper §3.3 batch
    invariance treatment as Phase 8. Includes Bug 11 isolation case study.
  - .claude/skills/atom-patterns.md: ATOM architecture index reference.

Verified by running CLI on existing E1 4xP3 dump:
    python -m atom.utils.debug_helper.compare slot-invariance \
        --dir /app/logs_claude/deepseek_v4/dumps/bug11_e1
reproduces the layer-by-layer divergence table that informed Bug 11
isolation in notes/21_bug11_isolation.md.
Two fixes that surfaced from the same V4 load run:

1. atom/models/deepseek_v4.py — skip `gate.e_score_correction_bias`
   allocation for hash-routed layers (layer_id < n_hash_layers). V4 hash
   layers route via `tid2eid` lookup, not bias-corrected gate logits;
   the checkpoint has no `gate.bias` for those layers (only layers >= 3).
   Allocating it caused 3 spurious "param NOT loaded from checkpoint"
   warnings every load. Both call sites that read the attribute now use
   `getattr(self.gate, "e_score_correction_bias", None)` — moe.py already
   accepts None for `e_score_correction_bias`.

2. atom/model_loader/loader.py — add ckpt-side coverage check (the
   reverse direction of the existing atom-side check). Every
   `get_parameter() except AttributeError: continue/break` site now
   records `(orig_ckpt_name, rewritten_name)`; after the main loop the
   loader warns if any non-benign drops occurred. This catches the
   actionable bug class — `weights_mapping` / `WeightsMapper` rewrites
   the ckpt name to something the model has no slot for, silently
   throwing away real weight data — which the existing atom-side check
   misses entirely. Benign families (output_scale / kv_scale / inv_freq
   / weight_scale_2) are filtered so the warning is signal, not noise.

Verified on V4 load:
  - atom-side warning: 46/2519 -> 43/2516 (3 hash bias removed)
  - ckpt-side warning: 0 drops (mapping is clean for V4)
  - remaining 43 are all model.mtp.0.* (PR5 todo)
Per paper §3.6.1, the Compressor's per-request state cache holds
"uncompressed tail tokens + previous block as B-side overlap context"
(eq 11). Restructure ATOM's kv_state from a roll-on-decode two-segment
buffer into a single pos % STATE_SIZE ring buffer (STATE_SIZE = 2*ratio
for overlap CSA, ratio for HCA).

Kernel update_compressor_states (atom/model_ops/v4_kernels/state_writes.py):
- dst = pos % STATE_SIZE for every token; no segment switching, no roll
- Phase derived in-kernel from context_lens vs cu_seqlens_q; no IS_PREFILL
- Write mask: fresh prefill keeps [max(0, cutoff-ratio), seqlen) (B-side
  overlap + tail); decode/MTP writes every token

Compressor.forward:
- Drops decode-boundary roll (kv_state[:ratio] <- kv_state[ratio:])
- Reads A-side / B-side halves by block-id parity (comp_id % 2)

Metadata plumbing:
- V4 prepare_decode now populates var["context_lens"] + attaches to
  AttentionMetaData (parent prepare_prefill already did)
- Compressor / Indexer.forward accept required context_lens kwarg
- Wrapper has no positions-derived fallback for context_lens

Also bundles PR-A scaffolding:
- ATOM_V4_BACKEND env gate + per-layer bisect (envs.py, v4_backend_gate.py)
- CPU-mirror metadata (cu_seqlens_q_cpu, state_slot_mapping_cpu,
  start_pos_per_seq_cpu) to avoid per-seq .tolist()/.item() syncs
- v4_slot_indices -> state_slot_mapping rename (clearer vs paged-KV slot_mapping)
- swa_write Triton kernel integration (Phase 1a) under backend gate

Validates: 15/15 byte-equal kernel-vs-reference (prefill + decode + MTP);
simple_inference fast path TPOT 0.328-0.518s/tok matches pre-refactor
baseline (Apr 29 v4_simple_inference.log: 0.453s/tok).
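The write rule, in two lines (illustrative; the real logic lives in the update_compressor_states kernel):

```python
def ring_slot(pos: int, ratio: int, overlap_csa: bool = True) -> int:
    state_size = 2 * ratio if overlap_csa else ratio  # CSA keeps B-side overlap
    return pos % state_size           # same rule for prefill, decode, and MTP
```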
Replace the BF16 grouped einsum at the end of DeepseekV4Attention.forward
with aiter's batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant
(same Triton kernel MLA's _v_up_proj_and_o_proj already uses).

process_weights_after_loading now requants wo_a from on-disk FP8+per-block
scale to FP8+scalar scale (via dequant -> reshape (G,R,K) -> per-batched-tensor
requant), stores W_OA / W_OA_scale buffers, and replaces wo_a.weight with a
1-element placeholder. wo_a runtime memory drops ~50% per layer (BF16 ~16MB ->
FP8 ~8MB).

Forward becomes a single fused launch: input X stays BF16 and is per-token
group quanted inside the kernel; output layout matches the previous einsum
(transpose_bm_in/transpose_bm) so wo_b is unchanged.

Fallback: ATOM_V4_OA_USE_EINSUM=1 keeps the legacy BF16 einsum path
(read once at __init__ for zero per-forward overhead). dummy/toy/BF16-only
loads go through the same BMM path (one-time FP8 quant noise; dummy doesn't
validate numerics).

Risk: per-block scale -> scalar scale is a one-time precision collapse;
must be lm_eval-gated before merge per /ci-pr-guide.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Standalone micro-bench that sweeps (BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M)
against the V4 wo_a shape (B=2, K=4096, N=1024, M variable). Compares each
config to the kernel's default _get_config baseline. Prints best-5 per M.

Used to evaluate whether the BMM kernel can be tuned for B=n_local_groups=2
(MLA-style autotune DB targets B=128). If a clearly better config exists for
prefill or decode M ranges, hardcode it at the wo_a call site.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Sweeps (BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M, num_warps, num_stages,
waves_per_eu) for V4 wo_a shape (B=2, K=4096, N=1024) across decode + prefill
M ranges. Outputs best-per-bucket configs in aiter's expected M_LEQ_* JSON
format. With --write, drops the JSON into:
  aiter/ops/triton/configs/gemm/gfx942-BATCHED_GEMM-A8W8-A_PER_TOKEN_GROUP_PREQUANT_W_PER_BATCHED_TENSOR_QUANT-N=1024-K=4096.json

Next V4 forward will auto-pick via the kernel's _get_config(M, N, K) lookup.

Without --write, dry-run prints the proposed JSON for review.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
amd-zfyu and others added 6 commits May 1, 2026 12:08
Wraps the existing v4_wo_a_tune.py sweep in an end-to-end runner that:
1. Detects gfx arch dynamically (was hardcoded gfx942 before).
2. Measures BMM with default config (true baseline).
3. Measures einsum baseline (for cross-comparison).
4. Sweeps configs at V4 wo_a shape (B=2, K=4096, N=1024, M={1,4,8,16,32,64})
   and writes best-per-bucket JSON to:
     aiter/ops/triton/configs/gemm/{arch}-...-N=1024-K=4096.json
5. Re-runs BMM (aiter's _get_config auto-picks the new JSON).
6. Prints three-way TPOT comparison (Req 3 / 128 tok max).

Uses aiter's intended config-file mechanism — no kernel source patching,
no @triton.autotune injection. Reverts cleanly by deleting the JSON.

Run on gfx950 docker (with working aiter + triton_kernels):
  bash scripts/v4_wo_a_tune_and_test.sh

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Previously v4_wo_a_tune.py wrote a tuned JSON containing only the M buckets
we explicitly sweep'd (--m-list defaults to {1,4,8,16,32,64}). aiter's
_get_gemm_config_cached() raises KeyError if M doesn't fall into any bucket
in the config file — and V4 warmup pushes M up to max-num-batched-tokens
(1024 in PR #650 recipe). Result: BMM run with the tuned JSON crashed
with "No matching configuration found for M=1024".

Fix: load aiter's existing default JSON (which covers all M buckets with
generic MLA-style configs) as the base, override only the buckets we
explicitly tuned. Small M uses our B=2-targeted configs; large M falls
back to defaults. No more KeyError, no need to sweep all M values.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Previous quick sweep missed the optimal config space:
- BLOCK_SIZE_N capped at 256; for N=1024 try up to 512 (cuts grid by 2x).
- num_warps fixed to {4, 8}; add 2 (small M sometimes prefers fewer warps).
- waves_per_eu missing 4; matters for register-pressure-bound configs.
- triton.testing.do_bench warmup=10/rep=50 had ~3% noise — winners chosen
  by noise rather than real perf. Bump to warmup=25/rep=100 → ~0.5% noise.

Total candidate space: ~1440 configs (vs ~3000 before but better targeted).
Full sweep on 6 M values: ~30 min. Quick mode: ~6 min.

Sweep variance was the root cause of the previous tune giving zero
improvement on V4 wo_a (B=2, K=4096, N=1024).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Skips the sweep; runs a single side-by-side comparison of:
  - torch.einsum("sgd,grd->sgr", ...) on BF16 wo_a
  - fp8_bmm(..., transpose_bm_in=True, transpose_bm=True) on FP8 wo_a

For each M in --m-list, prints einsum μs vs BMM μs vs speedup ratio.

Useful to verify the kernel-level perf claim of PR #676 (FP8 BMM should be
~2x faster on the wo_a kernel itself thanks to half the W bandwidth) —
independent of any tuning, independent of the V4 forward where wo_a is
~0.5% of total TPOT and the kernel-level speedup is invisible.

Run: python scripts/v4_wo_a_tune.py --microbench
Time: ~5 min (mostly Triton JIT first-compile per shape).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Companion to v4_wo_a_tune.py --microbench (BMM vs einsum). Lets us compare
the two FP8 paths (BMM in PR #676 / G-loop in PR #677) head-to-head against
the einsum baseline on the same shape (B=2, K=4096, N=1024).

Useful because PR #676 BMM came in 1.5-2x SLOWER than einsum at the kernel
level (rocBLAS BMM is just better tuned for B=2 than the AITER Triton BMM).
Need to know if G-loop's gemm_a8w8_blockscale_preshuffle (used widely by
ATOM Linears) fares better.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

zufayu commented May 1, 2026

RFC: New aiter Triton kernel for small-B FP8 batched GEMM

Motivation

DeepSeek V4's grouped-output-LoRA wo_a projection has shape (per TP=8 rank):

  • B = n_local_groups = 2 (the kernel batch dimension)
  • K = d_per_group = 4096
  • N = o_lora_rank = 1024
  • M = token batch (1-16 in decode, up to ~1024 in chunked prefill)

This shape is materially different from MLA's regime where the existing aiter
batched FP8 kernels were tuned (B = num_heads ≈ 128). At B=2 the existing
kernels lose to PyTorch's BF16 grouped einsum (which lowers to rocBLAS BMM).

Measured (V4-Pro, gfx950, single-op micro-bench)

| M | einsum (BF16) | batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant (FP8) | G-loop gemm_a8w8_blockscale_preshuffle (FP8) |
|---|---|---|---|
| 1 | 15 μs | 35 μs (0.43x) | 116 μs (0.13x) |
| 16 | 14 μs | 35 μs (0.43x) | 60 μs (0.24x) |
| 64 | 19 μs | 37 μs (0.51x) | 62 μs (0.30x) |
| 256 | 19 μs | 36 μs (0.54x) | 58 μs (0.33x) |
| 1024 | 37 μs | 39 μs (0.96x) | 55 μs (0.67x) |

(Reproducible via scripts/v4_wo_a_tune.py --microbench and
scripts/v4_wo_a_microbench_gloop.py on the PR #676 branch.)

The FP8 paths are 1.5x – 7x slower than the BF16 baseline at this shape,
even though FP8 should theoretically halve the W bandwidth requirement.

Why existing kernels lose

batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant

Grid topology is (B, M_tiles * N_tiles). At B=2, the first grid dimension
provides only 2-way parallelism. With BLOCK_SIZE_N=128 (the default for this
shape) and N=1024 → 8 N-tiles. Decode M=4 → 1 M-tile. Total work-groups: 16.
On 304 CU MI355X this is ~5% utilization.
Tuning BLOCK sizes can raise
work-group count but introduces atomic / launch overhead that cancels the gain
(verified empirically with both _get_config JSON tuning and @triton.autotune
sweep approaches).

G-loop gemm_a8w8_blockscale_preshuffle

Standard 2D GEMM per group, no batching. Each call does fill the chip well, but
it runs G_local=2 times serially, with G+1 kernel launches per layer (one
act_quant plus G GEMMs); the per-launch overhead dominates at small M.

rocBLAS / hipBLAS BMM (what torch.einsum lowers to)

Wins decisively at B=2 small-M because:

  • Hand-tuned vendor kernel for small-batch BMM
  • Mature autotune database covering exactly this shape regime
  • Single launch, full-CU coverage even at B=2 via internal split-K / persistent
    scheduling

Proposed: batched_gemm_a8w8_smallB_blockscale

A new aiter Triton kernel optimized for the small B + medium-large K + medium N
regime, using on-disk per-128-block W scale directly (no requant, no
precision loss).

Key design choices

  1. Collapse B into the grid's M dimension. Instead of grid (B, M*N/tile),
    use a 1D grid (B*ceil(M/BM)*ceil(N/BN),) with internal program-id
    decomposition (sketched after this list). This lets the launcher dispatch
    all B × M-tile × N-tile work-groups across CUs in one wave, avoiding the
    B=2 bottleneck.

  2. Mandatory split-K. K=4096 with M small is bandwidth-bound; split-K=4 or 8
    adds a 3rd grid dimension and a small reduction kernel, giving 4-8x more
    work-groups. This is the standard fix for "tall-skinny" GEMMs and is already
    in gemm_a8w8_blockscale.py (just not in any batched kernel).

  3. Native per-block W scale support (vs requant-to-scalar in
    batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant).
    Eliminates load-time precision collapse. Implementation: load w_scale[b, n_block, k_block] per inner-loop iteration; BLOCK_SIZE_K = 128 is already
    constrained by per-token-group quant on the X side.

  4. Preserve fused per-token-group X act-quant (same as today's
    _a_per_token_group_prequant_* kernel). On-the-fly quant of X removes the
    need for a separate act-quant kernel call from the caller.
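A sketch of choice (1)'s program-id decomposition, in plain Python for readability (the real kernel would do the same arithmetic on tl.program_id(0)):

```python
def decompose_pid(pid: int, M: int, N: int, BLOCK_M: int, BLOCK_N: int):
    # Recover (batch, m-tile, n-tile) from one flat program id in a
    # 1D grid of size B * ceil(M/BLOCK_M) * ceil(N/BLOCK_N).
    num_m = (M + BLOCK_M - 1) // BLOCK_M
    num_n = (N + BLOCK_N - 1) // BLOCK_N
    b, pid_mn = divmod(pid, num_m * num_n)
    pid_m, pid_n = divmod(pid_mn, num_n)
    return b, pid_m, pid_n
```

At the wo_a decode shape (B=2, M=4, N=1024, BLOCK_N=128) this flattening alone exposes all 16 work-groups in one wave; split-K (choice 2) then multiplies that by 4-8x.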

Signature sketch

batched_gemm_a8w8_smallB_blockscale_a_per_token_group_prequant_w_per_block_scale(
    X,               # (B, M, K) BF16 — quantized to FP8 inline
    WQ,              # (B, N, K) FP8 — pre-quantized at load
    w_scale,         # (B, N//128, K//128) FP32 — per-128-block W scale
    group_size=128,  # X act-quant block size on K (= W block size on K)
    block_n=128,     # W block size on N
    transpose_bm=False,
    transpose_bm_in=False,
    splitK=4,        # NEW: required for small-B / small-M regimes
    dtype=torch.bfloat16,
)

Expected benefits

  • wo_a kernel parity or beat vs einsum — collapsing B into grid + split-K
    should restore CU utilization that the existing batched kernel sacrifices.
  • Eliminate the one-time precision loss of PR #676's per-block→scalar requant.
  • No load-time dequant + requant dance — process_weights_after_loading
    becomes a no-op (LinearBase's standard FP8 + shuffle handles wo_a like any
    other Linear).
  • Reusable beyond V4 — any future model with grouped output projections
    (small B, medium-large K) benefits.

Implementation effort

  • Triton kernel: ~150 LoC (start from
    batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant.py,
    swap scalar w_scale load for per-block load, add 1D grid + split-K +
    reduction kernel).
  • Wrapper + autotune config: ~50 LoC.
  • Unit tests: ~80 LoC (compare vs reference einsum).
  • Estimated 1-2 engineer-days for a working version, +1 day for autotune
    baseline.

Alternative we considered and rejected

  • @triton.autotune patch on existing kernel — sweeps configs, picks
    best per shape. Tried; even with full sweep + low-noise bench, no config
    beats the default by > 5%. The B=2 grid bottleneck is fundamental, not a
    config choice. Tuning improves wo_a kernel by < 10% (still 1.4x slower
    than einsum at small M).

  • split_K retrofit on existing kernel — could land in the same place
    as the new kernel's split-K. Worth exploring as an incremental option,
    but the existing kernel's (B, M*N/tile) grid still wastes the B
    dimension; cleanest fix is the new kernel.

  • Continue with PR #676 as-is — kernel-level 2x slower than einsum;
    TPOT shows parity only because wo_a is < 0.5% of forward time. Memory
    savings of −2~3% per rank are the only real win and may not justify
    the added complexity.

Ask

  1. Aiter team: is there appetite to land this kernel? Happy to draft the
    PR if there's a path to merge.
  2. V4 PR #650 maintainer: would this kernel close PR #676 (BMM) decisively
    in favor of the new path, or does the current memory-only trade still
    warrant landing PR #676 in the interim?

Reproduce

Bench scripts on PR #676 branch (feat/deepseek-v4-wo-a-fp8-bmm):

python scripts/v4_wo_a_tune.py --microbench           # einsum vs current FP8 BMM
python scripts/v4_wo_a_microbench_gloop.py            # einsum vs G-loop blockscale

These were research artifacts used to characterize the wo_a kernel choice
(documented in the RFC comment on this PR — see #676 review thread).
Findings:
- aiter Triton FP8 BMM is 1.5-2x slower than rocBLAS BF16 BMM at V4 wo_a
  shape (B=2, K=4096, N=1024).
- Config tuning via aiter's JSON DB doesn't close the gap (B=2 grid
  underutilization is fundamental, not a config choice).
- G-loop blockscale alternative (PR #677) is even slower (2-7x vs einsum)
  due to G+1 launches per layer overhead at small M.
- Memory savings (-2~3% per rank) is the only kernel-independent win.

The scripts are preserved in git history (commits 26a83b4..761c69e on
this branch) for anyone wanting to reproduce the measurements. They are
not appropriate for the merged production diff.

Production change (atom/models/deepseek_v4.py) unchanged.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

zufayu commented May 1, 2026

Tracked aiter side at ROCm/aiter#3000 — proposing a new batched_gemm_a8w8_smallB_blockscale_a_per_token_group_prequant_w_per_block_scale kernel optimized for V4 wo_a's (B=2, K=4096, N=1024) shape. This PR will refactor (process_weights_after_loading → no-op, forward → call new kernel directly) once that kernel lands.

Base automatically changed from feat/deepseek-v4-pr1-skeleton to main May 6, 2026 16:11