cuda: make share-warp Q8 kernel bit-equal at any block count (stacked on #7)#8
Closed
TrevorS wants to merge 1 commit into
Closed
cuda: make share-warp Q8 kernel bit-equal at any block count (stacked on #7)#8TrevorS wants to merge 1 commit into
TrevorS wants to merge 1 commit into
Conversation
PR3-K4 introduced a `blocks <= 32u` constraint on the share-warp Q8
matmul dispatch out of caution: the share-warp kernel was assumed to
diverge from the N=1 warp8 reference once the per-lane block-stride
loop iterated more than once. Closer inspection shows the divergence
concern is unfounded -- both kernels use the same `for (b = lane;
b < blocks; b += 32u)` stride, the same `warp_sum_f32` reduction, and
the same per-block expression `acc += __half2float(*scale_h) *
xscale[b] * (float)dot`. The only ulp-scale risk is FMA-contraction
divergence under nvcc's code-context heuristics: with the more complex
share-warp t-loop body, the compiler might pick a different fma
scheduling for the equivalent arithmetic.
Rewrite the share-warp inner update to use an explicit `__fmaf_rn`
against a precomputed `(wscale * xs)` factor. Under `--fmad=true`
(the default for -O3 / --use_fast_math), nvcc emits the same SASS
for the N=1 expression: `t = scale * xscale; acc = fma(t, dot, acc)`.
Locking share-warp to that form removes the FMA-contraction degree of
freedom and guarantees bit-equality with the N=1 reference for every
output row at any block count.
Drop the `blocks <= 32u` dispatcher constraint now that the kernel is
bit-equal at any block count. Add a long comment block on the kernel
documenting the per-property argument for bit-equality.
Performance impact: none observable on the current PR stack. The
combined-forward MTP flow routes through F16 batched matmuls at
n_tok=4, not through Q8 -- the share-warp small-N dispatch sees only
n_tok=1 (decode) and n_tok=6 (prefill chunk), neither in the 2..4
range it gates on. Confirmed via dispatcher tracing: 0 PR8-eligible
calls during a 32-token MTP-enabled generation. Recovering the
share-warp's intended bandwidth-amortization win for combined-forward
needs the analogous F16 small-N kernel, scoped for separate work.
Also drop the parallel `blocks <= 32u` constraint at the PR5 batch_warp8
dispatch (line 6158) so DS4_CUDA_STRICT_BATCHED callers route through
batch_warp8 at any block count, instead of falling through to
matmul_q8_0_preq_kernel (which uses a 256-thread shared-mem tree
reduction and is NOT bit-equal to N=1 warp8). This makes the
strict-batched gate's "bit-equal to N=1 plain decode" invariant hold
at any in_dim, which was a latent gap in PR5.
Tested:
- `make cuda-spark` clean
- `make cpu` clean
- `./ds4_test --all`: only the pre-existing `--logprob-vectors
short_code_completion` failure (same as upstream/main, PR5/6/7).
`metal-tensor-equivalence`: OK, `cases=5 capture_fail=0 logits_fail=0
greedy_fail=0 top1_mismatch=0`. Stability remains within the
intrinsic engine-non-determinism band observed on PR7 (small-sample
pass-rate variance 30%-80% across runs of n=10; the same test
flakes identically on PR7 baseline -- this is *not* a regression
introduced by PR8).
- Plain decode output byte-identical to PR7 (only timing/cache lines
differ in the diff).
- MTP default-mode decode output byte-identical to PR7.
- Bench (DGX Spark, ds4flash, n=256, "knight"):
plain: 16.11-16.13 t/s (PR7: 16.11-16.13)
mtp: 16.11-16.13 t/s (PR7: 16.11-16.13)
No measurable delta -- consistent with the dispatch-trace finding
that share-warp is dead code in the combined-forward flow.
Hardware: NVIDIA DGX Spark (GB10 / sm_121), driver 580.142, CUDA 13.0
Model: DeepSeek-V4-Flash-IQ2XXS-w2Q2K-AProjQ8-SExpQ8-OutQ8-chat-v2-imatrix.gguf
MTP: DeepSeek-V4-Flash-MTP-Q4K-Q8_0-F32.gguf
bb27595 to
28ff7ce
Compare
fba7778 to
ff5d8a4
Compare
Owner
Author
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
PR8: cuda: make share-warp Q8 kernel bit-equal at any block count (stacked on #7)
Summary
Rewrites
matmul_q8_0_preq_batch_share_warp_kernel<N_TOK>to be bit-equal to the N=1 referencematmul_q8_0_preq_warp8_kernelat any block count, then drops theblocks <= 32uconstraint that previously gated share-warp off at largein_dim.No observable perf delta on the current PR stack. Honest disclosure up front: the share-warp small-N dispatch is dead code in the combined-forward MTP flow (which routes batched matmuls through F16, not Q8). Confirmed by dispatcher tracing — zero share-warp-eligible calls during a 32-token MTP-enabled generation. Recovering the bandwidth-amortization perf win for combined-forward needs the analogous F16 small-N kernel, scoped for separate work.
This PR is infrastructure / future-proofing: removes a stale correctness-guard, locks bit-equality with explicit
__fmaf_rn, and tightens a latent PR5 gap (DS4_CUDA_STRICT_BATCHEDnow works at any block count instead of falling through to a non-bit-equal kernel aboveblocks=32).What this PR is and isn't
blocks <= 32uconstraint on the batch_warp8 dispatcher (line 6158) soDS4_CUDA_STRICT_BATCHEDstrictly routes through batch_warp8 at any block count instead of falling through tomatmul_q8_0_preq_kernel(which uses a 256-thread shared-mem tree reduction and is NOT bit-equal to N=1 warp8).Bit-equality argument
For both
matmul_q8_0_preq_warp8_kernel(N=1 reference) andmatmul_q8_0_preq_batch_share_warp_kernel<N_TOK>(this PR's rewritten kernel), for every output row at any block count:b in {lane, lane+32, lane+64, ...}dot_i8_block(qs, xqb, bn, use_dp4a)acc += __half2float(*scale_h) * xscale[b] * (float)dot(compiler emitsfma(scale*xscale, dot, acc)under--fmad=true)const float s = wscale * xs; acc[t] = __fmaf_rn(s, (float)dot, acc[t])— same SASS by constructionwarp_sum_f32(butterfly shuffle, 16→8→4→2→1)The previous gate was set out of caution that the compiler might choose a different FMA contraction for the share-warp's t-loop body. Locking the form with
__fmaf_rnremoves that degree of freedom.Bench (DGX Spark, ds4flash, n=256, "knight")
--mtp(combined-forward)DS4_MTP_STRICT=1 --mtp(canonical)Consistent with the dispatcher-trace finding that share-warp is dead code in combined-forward.
Dispatcher trace evidence
Instrumented
cuda_matmul_q8_0_tensor_labeledwith a per-calln_toklog. During 16 tokens of MTP generation:Combined-forward N=2 (K=1 + first_token = 2-row) and prefix-1/prefix-2 captures translate to n_tok=4 calls at the F16 dispatcher, not the Q8 one.
What pushes MTP > plain (out of scope, scoped for separate PR)
The actual lever is at the F16 dispatcher (line 6320). cuBLAS GemmEx pads M=4 → 16 (wasting 12/16 = 75% of M-axis work) and uses F16 accumulation which drifts from F32 N=1 reference. A
matmul_f16_share_warp_kernel<N_TOK>analogous to the Q8 one would:matmul_f16_kernelfor row 0 (same construction as this PR)That's the next PR to actually push MTP perf over plain.
Tested against
make clean && make cuda-spark— clean, no warningsmake cpu— clean./ds4_test --all— only pre-existing--logprob-vectors short_code_completionfailure (same asupstream/main, PR5/6/7).metal-tensor-equivalencepasses on this run withcapture_fail=0 logits_fail=0 greedy_fail=0 top1_mismatch=0.--mtp) — byte-identical to PR7--mtpdecode — byte-identical to PR7DS4_MTP_STRICT=1 --mtp— unchanged (gate above already routes canonical)DeepSeek-V4-Flash-IQ2XXS-w2Q2K-AProjQ8-SExpQ8-OutQ8-chat-v2-imatrix.ggufDeepSeek-V4-Flash-MTP-Q4K-Q8_0-F32.ggufTest stability note
ds4_test --metal-tensor-equivalencehas intrinsic engine-level non-determinism on this hardware: 10-run pass-rate variance of 30%-80% on PR7 baseline with no PR8 changes (sample runs: 7/10, 8/10, 4/5, 3/10 on identical code). The same flakiness pattern reproduces on PR8. This is not introduced by PR8 — it appears to be intrinsic to the CUDA backend's batched-N path (likely atomic ops in routed-MoE or cuBLAS workspace scheduling). Separate diagnostic work, out of scope here.AGENT.md compliance
DS4_CUDA_NO_Q8_SHARE_BATCH=1opt-out preserved.Out of scope / follow-ups