cuda: make share-warp Q8 kernel bit-equal at any block count (stacked on #7) #8

Closed
TrevorS wants to merge 1 commit into mtp-beats-plain-kernels-v6 from mtp-beats-plain-kernels-v7

@TrevorS TrevorS commented May 24, 2026

PR8: cuda: make share-warp Q8 kernel bit-equal at any block count (stacked on #7)

Summary

Rewrites matmul_q8_0_preq_batch_share_warp_kernel<N_TOK> to be bit-equal to the N=1 reference matmul_q8_0_preq_warp8_kernel at any block count, then drops the blocks <= 32u constraint that previously gated share-warp off at large in_dim.

No observable perf delta on the current PR stack. Honest disclosure up front: the share-warp small-N dispatch is dead code in the combined-forward MTP flow (which routes batched matmuls through F16, not Q8). Confirmed by dispatcher tracing — zero share-warp-eligible calls during a 32-token MTP-enabled generation. Recovering the bandwidth-amortization perf win for combined-forward needs the analogous F16 small-N kernel, scoped for separate work.

This PR is infrastructure / future-proofing: removes a stale correctness-guard, locks bit-equality with explicit __fmaf_rn, and tightens a latent PR5 gap (DS4_CUDA_STRICT_BATCHED now works at any block count instead of falling through to a non-bit-equal kernel above blocks=32).

What this PR is and isn't

  • Is: a code-correctness improvement — share-warp's bit-equality with N=1 warp8 is now proven by construction (explicit fma lock-step), not just bit-equal-by-luck at small blocks. Future callers can rely on it at any block count.
  • Is: a latent PR5 fix — drops a parallel blocks <= 32u constraint on the batch_warp8 dispatcher (line 6158) so DS4_CUDA_STRICT_BATCHED strictly routes through batch_warp8 at any block count instead of falling through to matmul_q8_0_preq_kernel (which uses a 256-thread shared-mem tree reduction and is NOT bit-equal to N=1 warp8).
  • Is NOT: a perf win. The share-warp Q8 dispatch sees only n_tok={1, 6} in combined-forward — never n_tok=2..4. Both PR7 and PR8 bench identically (16.11-16.13 t/s plain and MTP).
  • Is NOT: the path to "MTP > plain". Today MTP ≈ plain at parity on Spark. Pushing MTP measurably above plain requires F16 small-N work (where combined-forward actually runs at n_tok=4), not Q8.
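
The gate change described in the bullets above can be sketched as a small host-side selection function. This is only an illustration under stated assumptions: `Q8Kernel`, `Q8Call`, `pick_q8_kernel`, and the `keep_blocks_gate` switch are hypothetical names, and the order of the checks is a guess. Only the kernel roles, the `blocks <= 32u` limit, and the n_tok 2..4 scope come from this PR's text.

```cpp
#include <cassert>

// Hypothetical model of the Q8 dispatch decision. The real dispatcher is not
// shown in this PR description, so names and check ordering are assumptions.
enum class Q8Kernel {
    Warp8,       // N=1 reference (matmul_q8_0_preq_warp8_kernel)
    BatchWarp8,  // strict batched path, bit-equal to N=1
    ShareWarp,   // small-N share-warp path (n_tok in 2..4)
    Generic      // matmul_q8_0_preq_kernel (256-thread tree reduction)
};

struct Q8Call {
    unsigned n_tok;
    unsigned blocks;
    bool strict_batched;  // models DS4_CUDA_STRICT_BATCHED being set
    bool no_share_batch;  // models DS4_CUDA_NO_Q8_SHARE_BATCH=1
};

// keep_blocks_gate = true models the pre-PR8 behaviour (blocks <= 32u required);
// keep_blocks_gate = false models the behaviour after the gate is dropped.
inline Q8Kernel pick_q8_kernel(const Q8Call& c, bool keep_blocks_gate) {
    assert(c.n_tok >= 1u);
    const bool blocks_ok = !keep_blocks_gate || c.blocks <= 32u;
    if (c.n_tok == 1u) return Q8Kernel::Warp8;  // assumed: N=1 always uses warp8
    if (c.strict_batched && blocks_ok) return Q8Kernel::BatchWarp8;
    const bool small_n = c.n_tok >= 2u && c.n_tok <= 4u;
    if (!c.strict_batched && !c.no_share_batch && small_n && blocks_ok)
        return Q8Kernel::ShareWarp;
    return Q8Kernel::Generic;  // the non-bit-equal fallback
}
```

With the gate kept, a strict-batched call at `blocks = 64` falls through to `Generic`; with the gate dropped, the same call lands on `BatchWarp8`, which is the latent PR5 gap described above.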

Bit-equality argument

For both matmul_q8_0_preq_warp8_kernel (N=1 reference) and matmul_q8_0_preq_batch_share_warp_kernel<N_TOK> (this PR's rewritten kernel), for every output row at any block count:

| Property | N=1 warp8 | share-warp (PR8) |
| --- | --- | --- |
| Per-lane block sequence | `b in {lane, lane+32, lane+64, ...}` | identical |
| Per-block dot | `dot_i8_block(qs, xqb, bn, use_dp4a)` | identical (same helper, same inputs for token 0) |
| Per-block update | `acc += __half2float(*scale_h) * xscale[b] * (float)dot` (compiler emits `fma(scale*xscale, dot, acc)` under `--fmad=true`) | explicit: `const float s = wscale * xs; acc[t] = __fmaf_rn(s, (float)dot, acc[t])`; same SASS by construction |
| Warp reduction | `warp_sum_f32` (butterfly shuffle, 16→8→4→2→1) | identical |

The previous gate was set out of caution that the compiler might choose a different FMA contraction for the share-warp's t-loop body. Locking the form with __fmaf_rn removes that degree of freedom.
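
The per-property argument can be checked on the host with a small C++ model. This is a sketch, not the kernel source: `warp_sum_model`, `row_n1_model`, `row_share_warp_model`, `fill_model`, and `float_bits` are hypothetical names, the per-block integer dots are taken as precomputed inputs rather than calling `dot_i8_block`, and the N=1 path is written directly in the contracted `fma` form that the table says nvcc emits.

```cpp
#include <array>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <vector>

constexpr unsigned kWarp = 32u;

// Host model of a butterfly warp sum (xor offsets 16, 8, 4, 2, 1).
inline float warp_sum_model(std::array<float, kWarp> v) {
    for (unsigned off = kWarp / 2u; off > 0u; off >>= 1u) {
        std::array<float, kWarp> next{};
        for (unsigned lane = 0; lane < kWarp; ++lane) next[lane] = v[lane] + v[lane ^ off];
        v = next;
    }
    return v[0];
}

// N=1 reference: each lane walks b = lane, lane+32, ... with one fma per block.
inline float row_n1_model(const std::vector<float>& wscale, const std::vector<float>& xscale,
                          const std::vector<int32_t>& dot) {
    const unsigned blocks = static_cast<unsigned>(wscale.size());
    assert(xscale.size() == blocks && dot.size() == blocks);
    std::array<float, kWarp> acc{};
    for (unsigned lane = 0; lane < kWarp; ++lane)
        for (unsigned b = lane; b < blocks; b += kWarp) {
            const float s = wscale[b] * xscale[b];
            acc[lane] = std::fma(s, static_cast<float>(dot[b]), acc[lane]);
        }
    return warp_sum_model(acc);
}

// Share-warp model: same stride and same fma form, with an inner loop over tokens.
// xscale and dot are laid out as [N_TOK][blocks].
template <unsigned N_TOK>
std::array<float, N_TOK> row_share_warp_model(const std::vector<float>& wscale,
                                              const std::vector<float>& xscale,
                                              const std::vector<int32_t>& dot) {
    const unsigned blocks = static_cast<unsigned>(wscale.size());
    assert(xscale.size() == N_TOK * blocks && dot.size() == N_TOK * blocks);
    std::array<std::array<float, kWarp>, N_TOK> acc{};
    for (unsigned lane = 0; lane < kWarp; ++lane)
        for (unsigned b = lane; b < blocks; b += kWarp)
            for (unsigned t = 0; t < N_TOK; ++t) {
                const float s = wscale[b] * xscale[t * blocks + b];
                acc[t][lane] = std::fma(s, static_cast<float>(dot[t * blocks + b]), acc[t][lane]);
            }
    std::array<float, N_TOK> out{};
    for (unsigned t = 0; t < N_TOK; ++t) out[t] = warp_sum_model(acc[t]);
    return out;
}

// Deterministic filler (arbitrary values) so the comparison is reproducible.
inline void fill_model(unsigned n_tok, unsigned blocks, std::vector<float>& w,
                       std::vector<float>& xs, std::vector<int32_t>& d) {
    uint32_t state = 12345u;
    auto next = [&state]() { state = state * 1664525u + 1013904223u; return state >> 8; };
    w.resize(blocks);
    xs.resize(n_tok * blocks);
    d.resize(n_tok * blocks);
    for (auto& v : w) v = static_cast<float>(next() % 1000u) / 997.0f + 0.001f;
    for (auto& v : xs) v = static_cast<float>(next() % 1000u) / 991.0f + 0.001f;
    for (auto& v : d) v = static_cast<int32_t>(next() % 8001u) - 4000;
}

// Bit pattern of a float, for exact (not approximate) comparison.
inline uint32_t float_bits(float f) {
    uint32_t u;
    std::memcpy(&u, &f, sizeof u);
    return u;
}
```

Calling `fill_model(4, 100, ...)` and comparing `float_bits(row_share_warp_model<4>(...)[0])` against `float_bits(row_n1_model(...))` on the token-0 slice exercises a block count above 32, where each lane's stride loop runs more than once. The model only demonstrates that identical operations in identical order give identical bits; it does not prove that nvcc generates the same SASS for the two kernels.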

Bench (DGX Spark, ds4flash, n=256, "knight")

| Mode | PR7 (t/s) | PR8 (t/s) | Δ |
| --- | --- | --- | --- |
| Plain decode | 16.11-16.13 | 16.11-16.13 | none |
| Default --mtp (combined-forward) | 16.11-16.13 | 16.11-16.13 | none |
| DS4_MTP_STRICT=1 --mtp (canonical) | unchanged | unchanged | none |

Consistent with the dispatcher-trace finding that share-warp is dead code in combined-forward.

Dispatcher trace evidence

Instrumented cuda_matmul_q8_0_tensor_labeled with a per-call n_tok log. During 16 tokens of MTP generation:

1936 PR8-f16: n_tok=1            (decode F16)
 253 PR8-f16: n_tok=6            (prefill chunk F16)
  84 PR8-f16: n_tok=4            (combined-forward F16!)
 661 PR8-q8:  n_tok=1            (decode Q8)
 301 PR8-q8:  n_tok=6            (prefill chunk Q8)
   0 PR8-q8:  n_tok=2..4         (NEVER -- share-warp gate scope)

Combined-forward N=2 (K=1 + first_token = 2-row) and prefix-1/prefix-2 captures translate to n_tok=4 calls at the F16 dispatcher, not the Q8 one.
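
A per-call tally like the one above can be modelled with a small counter keyed on dispatch path and n_tok. This is a sketch of the bookkeeping only: `DispatchTally`, `record`, and `in_range` are hypothetical names, not the instrumentation actually added to `cuda_matmul_q8_0_tensor_labeled`.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <utility>

// Hypothetical tally of dispatcher calls, keyed by (path, n_tok).
struct DispatchTally {
    std::map<std::pair<std::string, unsigned>, unsigned long> counts;

    // Record `times` calls on `path` ("q8" or "f16") with the given n_tok.
    void record(const std::string& path, unsigned n_tok, unsigned long times = 1) {
        assert(n_tok >= 1u);
        counts[{path, n_tok}] += times;
    }

    // Total calls on `path` whose n_tok lies in [lo, hi].
    unsigned long in_range(const std::string& path, unsigned lo, unsigned hi) const {
        unsigned long total = 0;
        for (const auto& kv : counts)
            if (kv.first.first == path && kv.first.second >= lo && kv.first.second <= hi)
                total += kv.second;
        return total;
    }
};
```

Feeding in the counts from the trace (`record("q8", 1, 661)`, `record("q8", 6, 301)`, `record("f16", 4, 84)`) gives `in_range("q8", 2, 4) == 0` and `in_range("f16", 2, 4) == 84`, which is the finding stated above: the share-warp gate scope sees no Q8 calls, while the F16 path carries the n_tok=4 traffic.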

What pushes MTP > plain (out of scope, scoped for separate PR)

The actual lever is at the F16 dispatcher (line 6320). cuBLAS GemmEx pads M=4 → 16 (wasting 12/16 = 75% of M-axis work) and uses F16 accumulation, which drifts from the F32 N=1 reference. A matmul_f16_share_warp_kernel<N_TOK> analogous to the Q8 one would:

  • Amortize F16 weight reads N-fold (one row per warp, N tokens computed together)
  • Avoid the M-tile padding waste
  • Be bit-equal to N=1 matmul_f16_kernel for row 0 (same construction as this PR)

That's the next PR to actually push MTP perf over plain.
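
The padding arithmetic quoted above (M=4 padded to 16, 75% wasted) can be reproduced with a couple of helpers. The function names are hypothetical, and the tile size of 16 is taken from the figure in this PR text rather than from cuBLAS documentation.

```cpp
#include <cassert>

// Round m up to the next multiple of tile.
inline unsigned padded_rows(unsigned m, unsigned tile) {
    assert(tile > 0u);
    return ((m + tile - 1u) / tile) * tile;
}

// Fraction of the padded M-axis work spent on rows that carry no data.
inline double wasted_fraction(unsigned m, unsigned tile) {
    assert(m > 0u);
    const unsigned padded = padded_rows(m, tile);
    return static_cast<double>(padded - m) / static_cast<double>(padded);
}
```

For `m = 4` and `tile = 16`, `padded_rows` returns 16 and `wasted_fraction` returns 0.75, matching the 12/16 figure. A one-row-per-warp kernel that loops over exactly N tokens would not pay for this padding.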

Tested against

  • make clean && make cuda-spark — clean, no warnings
  • make cpu — clean
  • ./ds4_test --all — only pre-existing --logprob-vectors short_code_completion failure (same as upstream/main, PR5/6/7). metal-tensor-equivalence passes on this run with capture_fail=0 logits_fail=0 greedy_fail=0 top1_mismatch=0.
  • Plain decode (no --mtp) — byte-identical to PR7
  • Default --mtp decode — byte-identical to PR7
  • DS4_MTP_STRICT=1 --mtp — unchanged (gate above already routes canonical)
  • 3-run bench stability: 16.11 / 16.13 / 16.11 t/s
  • Hardware: NVIDIA DGX Spark (GB10 / sm_121), driver 580.142, CUDA 13.0
  • Model: DeepSeek-V4-Flash-IQ2XXS-w2Q2K-AProjQ8-SExpQ8-OutQ8-chat-v2-imatrix.gguf
  • MTP: DeepSeek-V4-Flash-MTP-Q4K-Q8_0-F32.gguf

Test stability note

ds4_test --metal-tensor-equivalence has intrinsic engine-level non-determinism on this hardware: 10-run pass-rate variance of 30%-80% on PR7 baseline with no PR8 changes (sample runs: 7/10, 8/10, 4/5, 3/10 on identical code). The same flakiness pattern reproduces on PR8. This is not introduced by PR8 — it appears to be intrinsic to the CUDA backend's batched-N path (likely atomic ops in routed-MoE or cuBLAS workspace scheduling). Separate diagnostic work, out of scope here.

AGENT.md compliance

  • "Preserve correctness before speed" — kernel is now bit-equal to N=1 by construction at any block count (was bit-equal-by-luck at small blocks). Strict-mode unchanged.
  • "Do not add permanent semantic variants behind flags" — no new flag. Pre-existing DS4_CUDA_NO_Q8_SHARE_BATCH=1 opt-out preserved.
  • "Diagnostic switches are fine when they validate the one release path" — env knobs unchanged.

Out of scope / follow-ups

  • F16 small-N share-warp kernel (where combined-forward actually has a perf lever)
  • Engine-level test non-determinism investigation (atomics / cuBLAS workspace)
  • Captured-graph spec decode subsystem

PR3-K4 introduced a `blocks <= 32u` constraint on the share-warp Q8
matmul dispatch out of caution: the share-warp kernel was assumed to
diverge from the N=1 warp8 reference once the per-lane block-stride
loop iterated more than once.  Closer inspection shows the divergence
concern is unfounded -- both kernels use the same `for (b = lane;
b < blocks; b += 32u)` stride, the same `warp_sum_f32` reduction, and
the same per-block expression `acc += __half2float(*scale_h) *
xscale[b] * (float)dot`.  The only ulp-scale risk is FMA-contraction
divergence under nvcc's code-context heuristics: with the more complex
share-warp t-loop body, the compiler might pick a different fma
scheduling for the equivalent arithmetic.

Rewrite the share-warp inner update to use an explicit `__fmaf_rn`
against a precomputed `(wscale * xs)` factor.  Under `--fmad=true`
(nvcc's default, also implied by --use_fast_math), nvcc emits the same
SASS for the N=1 expression:
`t = scale * xscale; acc = fma(t, dot, acc)`.
Locking share-warp to that form removes the FMA-contraction degree of
freedom and guarantees bit-equality with the N=1 reference for every
output row at any block count.
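
Why the contraction choice matters at the bit level can be shown with a host-side C++ example. This is an illustration of the rounding difference only, with hypothetical helper names; it uses `std::fma` as a stand-in for the device intrinsic and makes no claim about what nvcc emits for either kernel.

```cpp
#include <cassert>
#include <cmath>

// Multiply, round to float, then add: two roundings.
inline float mul_then_add(float a, float b, float c) {
    volatile float p = a * b;  // volatile forces the product to be rounded to float
    return p + c;
}

// Fused multiply-add: the exact product is added to c and rounded once.
inline float fused(float a, float b, float c) {
    return std::fma(a, b, c);
}

struct ContractionDemo {
    float separate;      // result with two roundings
    float fused_result;  // result with one rounding
};

// a = 1 + 2^-12, so a*a = 1 + 2^-11 + 2^-24 exactly. The 2^-24 term is half an
// ulp at this magnitude and is lost (ties-to-even) when the product is rounded.
inline ContractionDemo contraction_demo() {
    const float a = 1.0f + std::ldexp(1.0f, -12);
    const float c = -(1.0f + std::ldexp(1.0f, -11));
    return {mul_then_add(a, a, c), fused(a, a, c)};
}
```

Here the two-rounding form returns exactly 0, while the fused form returns 2^-24. If one kernel were compiled to the first form and the other to the second, per-block updates could differ in this way, which is the degree of freedom the explicit `__fmaf_rn` is meant to remove.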

Drop the `blocks <= 32u` dispatcher constraint now that the kernel is
bit-equal at any block count.  Add a long comment block on the kernel
documenting the per-property argument for bit-equality.

Performance impact: none observable on the current PR stack.  The
combined-forward MTP flow routes through F16 batched matmuls at
n_tok=4, not through Q8 -- the share-warp small-N dispatch sees only
n_tok=1 (decode) and n_tok=6 (prefill chunk), neither in the 2..4
range it gates on.  Confirmed via dispatcher tracing: 0 PR8-eligible
calls during a 32-token MTP-enabled generation.  Recovering the
share-warp's intended bandwidth-amortization win for combined-forward
needs the analogous F16 small-N kernel, scoped for separate work.

Also drop the parallel `blocks <= 32u` constraint at the PR5 batch_warp8
dispatch (line 6158) so DS4_CUDA_STRICT_BATCHED callers route through
batch_warp8 at any block count, instead of falling through to
matmul_q8_0_preq_kernel (which uses a 256-thread shared-mem tree
reduction and is NOT bit-equal to N=1 warp8).  This makes the
strict-batched gate's "bit-equal to N=1 plain decode" invariant hold
at any in_dim, which was a latent gap in PR5.

Tested:
  - `make cuda-spark` clean
  - `make cpu` clean
  - `./ds4_test --all`: only the pre-existing `--logprob-vectors
    short_code_completion` failure (same as upstream/main, PR5/6/7).
    `metal-tensor-equivalence`: OK, `cases=5 capture_fail=0 logits_fail=0
    greedy_fail=0 top1_mismatch=0`.  Stability remains within the
    intrinsic engine-non-determinism band observed on PR7 (small-sample
    pass-rate variance 30%-80% across runs of n=10; the same test
    flakes identically on PR7 baseline -- this is *not* a regression
    introduced by PR8).
  - Plain decode output byte-identical to PR7 (only timing/cache lines
    differ in the diff).
  - MTP default-mode decode output byte-identical to PR7.
  - Bench (DGX Spark, ds4flash, n=256, "knight"):
      plain:    16.11-16.13 t/s  (PR7: 16.11-16.13)
      mtp:      16.11-16.13 t/s  (PR7: 16.11-16.13)
    No measurable delta -- consistent with the dispatch-trace finding
    that share-warp is dead code in the combined-forward flow.

Hardware: NVIDIA DGX Spark (GB10 / sm_121), driver 580.142, CUDA 13.0
Model: DeepSeek-V4-Flash-IQ2XXS-w2Q2K-AProjQ8-SExpQ8-OutQ8-chat-v2-imatrix.gguf
MTP: DeepSeek-V4-Flash-MTP-Q4K-Q8_0-F32.gguf
@TrevorS TrevorS force-pushed the mtp-beats-plain-kernels-v6 branch from bb27595 to 28ff7ce Compare May 24, 2026 17:14
@TrevorS TrevorS force-pushed the mtp-beats-plain-kernels-v7 branch from fba7778 to ff5d8a4 Compare May 24, 2026 17:14

TrevorS commented May 24, 2026

Superseded by the reframed 2-PR stack (#11 + #12), which tells the same Spark/GB10 + MTP combined-forward story more concisely, rebased on current upstream/main, with the exploratory paths dropped.

@TrevorS TrevorS closed this May 24, 2026
@TrevorS TrevorS deleted the mtp-beats-plain-kernels-v7 branch May 24, 2026 22:43