perf: cascaded N=3 MTP + share-warp Q8 — GB10 decode past 20 t/s (stacked on #14)#15
Open
TrevorS wants to merge 2 commits into
Open
perf: cascaded N=3 MTP + share-warp Q8 — GB10 decode past 20 t/s (stacked on #14)#15TrevorS wants to merge 2 commits into
TrevorS wants to merge 2 commits into
Conversation
matmul_q8_0_preq_batch_share_warp_kernel called dot_i8_block once per token, which re-read the row's Q8 weight quants N_TOK times -- the one thing the share-warp kernel exists to avoid. Hoist the 8 int32 weight words into registers once per block and dp4a them against each token's activations. Bit-identical to the previous form and to the N=1 warp8 reference: the dp4a accumulation order (word 0..7) is unchanged and the per-block dot is an exact int32, so only the (already-shared) weight load is removed from the inner token loop. Verified byte-identical --mtp output. Bench (DGX Spark / GB10, ds4flash, knight n=256, --mtp, 5-run mean): before: 18.50 t/s after: 18.75 t/s (+0.25, +1.4%) Kernel share of MTP decode GPU time: 27.8% -> 27.1% (88.6us -> 86.2us avg). The kernel stays HBM-bandwidth-bound on the 34-byte Q8 block stride; this removes the redundant-reload overhead on top of that.
Verify [first_token, drafts[0], drafts[1]] in a single batched main-model pass instead of [first_token, drafts[0]], lifting accepted tokens/iter on GB10 and pushing decode past 20 t/s on the standard knight prompt. drafts[1] is chained off drafts[0] by threading the MTP block's own output HC (mtp_state_hc -> mtp_next_hc) rather than the iter-start combined_prev_hc. With the HC threaded this way the verifier accepts drafts[1] at roughly the canonical chained rate (~0.58 given drafts[0] accepted), so the extra row pays for itself. Partial-accept rollback gains a second target. The batched verify advances the compressor frontier through all three rows; a commit==1 accept (drafts[0] kept, drafts[1] rejected) must rewind to the post-drafts[0] boundary. This adds a prefix2 snapshot — a mirror of the prefix1 buffers captured at the row-1 boundary (t==1) — and a spec_frontier_commit_prefix2 restore. The prefix1/prefix2 commit bodies are factored into one spec_frontier_commit_from helper. Also fixes a latent bug in metal_graph_verify_suffix_tops that the N=3 window is the first caller to reach: the top_rows>=2 branch passed (n_tokens=1, top_k=top_rows) to ds4_gpu_indexer_topk_tensor, computing the top-N of row 0 only instead of one argmax per row, and routed through the generic single-thread indexer_topk_kernel (~17 ms per call over the 128k vocab). Replaced with one fast tree-reduce argmax per row via ds4_gpu_argmax_tensor on row views. Both issues were dormant while the window was strictly two tokens (top_rows==1, argmax path). Correctness: the verifier defines the exact greedy stream, so the N=3 output is byte-identical to the two-token window on every prompt tested, including ~30k-context long prompts where the compressor frontier is fully active and the prefix2 rewind fires for real. Strict mode (DS4_MTP_STRICT / --quality) is unchanged — it never enters combined-forward and stays byte-exact to plain decode. DS4_MTP_NO_CASCADE=1 forces the two-token window byte-for-byte. Bench (DGX Spark / GB10, ds4flash, --mtp, temp 0, --nothink): knight n=256: 18.85 -> 20.54 t/s (5-run mean), 1.789 -> 2.237 tok/iter transformer prompt: 19.39 -> 21.66 t/s long-context ~30k: 17.45 -> 22.19 t/s Tested: make cuda-spark clean; ds4_test --long-context, --tool-call-quality, --server OK; N=3 == DS4_MTP_NO_CASCADE and strict == plain byte-identical at short, second-prompt, and ~30k-context.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Stacked on #14 (
gb10-mtp-combined-forward).Summary
Two CUDA-only, weight-preserving decode-perf changes on top of combined-forward:
[first_token, drafts[0], drafts[1]]in one batched main-model pass instead of[first_token, drafts[0]], lifting accepted tokens/iter.Together they push greedy
--mtpdecode past 20 t/s on the standard short prompt and to ~22 t/s at long context.Speed — DGX Spark / GB10 (ds4flash,
--mtp, temp 0,--nothink)knight, n=256knightis a 5-run mean; tokens/iter 1.79 → 2.24.What's in the PR
1. share-warp Q8 weight hoist
matmul_q8_0_preq_batch_share_warp_kernelwas calling the per-token dot, re-reading the row's Q8 weight quants N times — the one thing the share-warp kernel exists to avoid. Hoist the 8 int32 weight words into registers once per block and dp4a them against each token. Bit-identical: exact int32 dot, unchanged accumulation order.2. cascaded N=3 combined-forward
drafts[1]is chained offdrafts[0]by threading the MTP block's own output HC (mtp_state_hc → mtp_next_hc), not the iter-startcombined_prev_hc, so the verifier accepts it ~0.58 of the time givendrafts[0]accepted — enough to pay for the extra row.commit==1case (drafts[0]kept,drafts[1]rejected). prefix1/prefix2 share one restore helper.top_rows ≥ 2readback took the top-N of row 0 only and routed through the generic single-threadindexer_topk(~17 ms/call over the 128k vocab). Replaced with one fast tree-reduce argmax per row.DS4_MTP_NO_CASCADE=1forces the prior two-token window (kill switch).Correctness
The verifier defines the exact greedy stream, so the N=3 output is byte-identical to the two-token window on every prompt tested, including ~30k-context prompts where the compressor frontier is fully active and the
prefix2rewind fires for real. Strict mode (DS4_MTP_STRICT/--quality) is unchanged — it never enters combined-forward and stays byte-exact to plain decode.Tested
make clean && make cuda-spark— clean;make cpu— clean./ds4_test --long-context,--tool-call-quality,--server— OKDS4_MTP_NO_CASCADEand strict ≡ plain, byte-identical at short / second-prompt / ~30k-contextds4_testchecks fail identically on stockupstream/main— not introduced here, and outside the spec-decode path (both run plain/batch decode, not--mtp):--logprob-vectors short_code_completion— markdown code-fence language-tag case near-tie (```cvs```C); the generated code is byte-identical.--metal-tensor-equivalence— MoE routed-expert down-projection floatatomicAddatn_tokens >= 128(use_atomic_down) drifts at ulp scale across runs and occasionally flips a greedy argmax at long context;DS4_CUDA_MOE_NO_ATOMIC_DOWN=1makes it bit-exact.Hardware: NVIDIA DGX Spark (GB10 / sm_121), driver 580.142, CUDA 13.0
Model:
DeepSeek-V4-Flash-IQ2XXS-w2Q2K-AProjQ8-SExpQ8-OutQ8-chat-v2-imatrix.ggufMTP:
DeepSeek-V4-Flash-MTP-Q4K-Q8_0-F32.ggufAGENT.md compliance
DS4_MTP_NO_CASCADEis a diagnostic kill switch that reproduces the prior two-token release path, not a new permanent mode.