Skip to content

perf: cascaded N=3 MTP + share-warp Q8 — GB10 decode past 20 t/s (stacked on #14)#15

Open
TrevorS wants to merge 2 commits into
gb10-mtp-combined-forwardfrom
gb10-decode-perf
Open

perf: cascaded N=3 MTP + share-warp Q8 — GB10 decode past 20 t/s (stacked on #14)#15
TrevorS wants to merge 2 commits into
gb10-mtp-combined-forwardfrom
gb10-decode-perf

Conversation

@TrevorS
Copy link
Copy Markdown
Owner

@TrevorS TrevorS commented May 25, 2026

Stacked on #14 (gb10-mtp-combined-forward).

Summary

Two CUDA-only, weight-preserving decode-perf changes on top of combined-forward:

  1. share-warp Q8 weight hoist — load each Q8_0 weight row's quants once per block and reuse across the N batched tokens, instead of re-reading them per token.
  2. cascaded N=3 combined-forward — verify [first_token, drafts[0], drafts[1]] in one batched main-model pass instead of [first_token, drafts[0]], lifting accepted tokens/iter.

Together they push greedy --mtp decode past 20 t/s on the standard short prompt and to ~22 t/s at long context.

Speed — DGX Spark / GB10 (ds4flash, --mtp, temp 0, --nothink)

prompt #14 (N=2) this PR Δ
knight, n=256 18.85 20.54 +1.69
chat prompt, n=200 19.39 21.66 +2.27
long-context ~30k 17.45 22.19 +4.74

knight is a 5-run mean; tokens/iter 1.79 → 2.24.

What's in the PR

1. share-warp Q8 weight hoist

matmul_q8_0_preq_batch_share_warp_kernel was calling the per-token dot, re-reading the row's Q8 weight quants N times — the one thing the share-warp kernel exists to avoid. Hoist the 8 int32 weight words into registers once per block and dp4a them against each token. Bit-identical: exact int32 dot, unchanged accumulation order.

2. cascaded N=3 combined-forward

  • drafts[1] is chained off drafts[0] by threading the MTP block's own output HC (mtp_state_hc → mtp_next_hc), not the iter-start combined_prev_hc, so the verifier accepts it ~0.58 of the time given drafts[0] accepted — enough to pay for the extra row.
  • Partial-accept rollback gains a prefix2 snapshot (compressor frontier at the row-1 boundary) for the commit==1 case (drafts[0] kept, drafts[1] rejected). prefix1/prefix2 share one restore helper.
  • Fixes a latent bug the N=3 window is the first caller to hit: the verifier's top_rows ≥ 2 readback took the top-N of row 0 only and routed through the generic single-thread indexer_topk (~17 ms/call over the 128k vocab). Replaced with one fast tree-reduce argmax per row.
  • DS4_MTP_NO_CASCADE=1 forces the prior two-token window (kill switch).

Correctness

The verifier defines the exact greedy stream, so the N=3 output is byte-identical to the two-token window on every prompt tested, including ~30k-context prompts where the compressor frontier is fully active and the prefix2 rewind fires for real. Strict mode (DS4_MTP_STRICT / --quality) is unchanged — it never enters combined-forward and stays byte-exact to plain decode.

Tested

  • make clean && make cuda-spark — clean; make cpu — clean
  • ./ds4_test --long-context, --tool-call-quality, --server — OK
  • N=3 ≡ DS4_MTP_NO_CASCADE and strict ≡ plain, byte-identical at short / second-prompt / ~30k-context
  • Two ds4_test checks fail identically on stock upstream/main — not introduced here, and outside the spec-decode path (both run plain/batch decode, not --mtp):
    • --logprob-vectors short_code_completion — markdown code-fence language-tag case near-tie (```c vs ```C); the generated code is byte-identical.
    • --metal-tensor-equivalence — MoE routed-expert down-projection float atomicAdd at n_tokens >= 128 (use_atomic_down) drifts at ulp scale across runs and occasionally flips a greedy argmax at long context; DS4_CUDA_MOE_NO_ATOMIC_DOWN=1 makes it bit-exact.

Hardware: NVIDIA DGX Spark (GB10 / sm_121), driver 580.142, CUDA 13.0
Model: DeepSeek-V4-Flash-IQ2XXS-w2Q2K-AProjQ8-SExpQ8-OutQ8-chat-v2-imatrix.gguf
MTP: DeepSeek-V4-Flash-MTP-Q4K-Q8_0-F32.gguf

AGENT.md compliance

  • Preserve correctness before speed — output byte-identical to the prior decode path; strict mode preserves the byte-exact canonical path.
  • No permanent semantic variants behind flagsDS4_MTP_NO_CASCADE is a diagnostic kill switch that reproduces the prior two-token release path, not a new permanent mode.

TrevorS added 2 commits May 24, 2026 16:17
matmul_q8_0_preq_batch_share_warp_kernel called dot_i8_block once per
token, which re-read the row's Q8 weight quants N_TOK times -- the one
thing the share-warp kernel exists to avoid.  Hoist the 8 int32 weight
words into registers once per block and dp4a them against each token's
activations.

Bit-identical to the previous form and to the N=1 warp8 reference: the
dp4a accumulation order (word 0..7) is unchanged and the per-block dot
is an exact int32, so only the (already-shared) weight load is removed
from the inner token loop.  Verified byte-identical --mtp output.

Bench (DGX Spark / GB10, ds4flash, knight n=256, --mtp, 5-run mean):
  before: 18.50 t/s
  after:  18.75 t/s   (+0.25, +1.4%)

Kernel share of MTP decode GPU time: 27.8% -> 27.1% (88.6us -> 86.2us
avg).  The kernel stays HBM-bandwidth-bound on the 34-byte Q8 block
stride; this removes the redundant-reload overhead on top of that.
Verify [first_token, drafts[0], drafts[1]] in a single batched
main-model pass instead of [first_token, drafts[0]], lifting accepted
tokens/iter on GB10 and pushing decode past 20 t/s on the standard
knight prompt.

drafts[1] is chained off drafts[0] by threading the MTP block's own
output HC (mtp_state_hc -> mtp_next_hc) rather than the iter-start
combined_prev_hc.  With the HC threaded this way the verifier accepts
drafts[1] at roughly the canonical chained rate (~0.58 given drafts[0]
accepted), so the extra row pays for itself.

Partial-accept rollback gains a second target.  The batched verify
advances the compressor frontier through all three rows; a commit==1
accept (drafts[0] kept, drafts[1] rejected) must rewind to the
post-drafts[0] boundary.  This adds a prefix2 snapshot — a mirror of
the prefix1 buffers captured at the row-1 boundary (t==1) — and a
spec_frontier_commit_prefix2 restore.  The prefix1/prefix2 commit
bodies are factored into one spec_frontier_commit_from helper.

Also fixes a latent bug in metal_graph_verify_suffix_tops that the N=3
window is the first caller to reach: the top_rows>=2 branch passed
(n_tokens=1, top_k=top_rows) to ds4_gpu_indexer_topk_tensor, computing
the top-N of row 0 only instead of one argmax per row, and routed
through the generic single-thread indexer_topk_kernel (~17 ms per call
over the 128k vocab).  Replaced with one fast tree-reduce argmax per
row via ds4_gpu_argmax_tensor on row views.  Both issues were dormant
while the window was strictly two tokens (top_rows==1, argmax path).

Correctness: the verifier defines the exact greedy stream, so the N=3
output is byte-identical to the two-token window on every prompt
tested, including ~30k-context long prompts where the compressor
frontier is fully active and the prefix2 rewind fires for real.  Strict
mode (DS4_MTP_STRICT / --quality) is unchanged — it never enters
combined-forward and stays byte-exact to plain decode.
DS4_MTP_NO_CASCADE=1 forces the two-token window byte-for-byte.

Bench (DGX Spark / GB10, ds4flash, --mtp, temp 0, --nothink):
  knight n=256:        18.85 -> 20.54 t/s (5-run mean), 1.789 -> 2.237 tok/iter
  transformer prompt:  19.39 -> 21.66 t/s
  long-context ~30k:   17.45 -> 22.19 t/s

Tested: make cuda-spark clean; ds4_test --long-context,
--tool-call-quality, --server OK; N=3 == DS4_MTP_NO_CASCADE and strict
== plain byte-identical at short, second-prompt, and ~30k-context.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant