From 015ab43bef41b987e1e651c0d11a3bd818ff69f0 Mon Sep 17 00:00:00 2001 From: Trevor Strieber Date: Fri, 22 May 2026 18:44:29 -0700 Subject: [PATCH 1/2] cuda: shared-weight Q8 matmul kernel for small-N batched (MTP verifier) At small batch (n_tok = 2..4) the per-token batch_warp8 kernel re-reads each weight row N times. This commit adds matmul_q8_0_preq_batch_share_warp_kernel -- a templated warp kernel that reads each row of weights exactly ONCE per warp and computes N_TOK partial dot products against N_TOK token inputs, amortizing weight bandwidth N-fold. N_TOK = {2, 3, 4} are instantiated. Gating (cuda_matmul_q8_0_tensor_labeled): the share kernel only replaces the existing batch_warp8 fallback path; it does NOT replace cuBLAS Gemm on cached F32/F16 weights. Conditions to fire: n_tok in [2, 4] AND blocks <= 32 AND (same blocks cap as batch_warp8) no F32 cuBLAS cache hit for this weight AND no F16 cuBLAS cache hit for this weight When any cuBLAS cache is present the reference path stays in charge of that weight, preserving byte-equality with upstream/main on all q8 matmuls that hit cuBLAS (attn_output_a/b, ffn_*_shexp, attn_q_b, the 4096x{2048,1024,512} / 2048x4096 shapes). The share kernel reads the same weights, performs the same per-block FMA (wscale * xs * dot in the same operand order), and uses the same warp_sum_f32 reduction as batch_warp8, so it is bit-identical to batch_warp8 on the weights it serves. Disable with DS4_CUDA_NO_Q8_SHARE_BATCH=1. Also respects DS4_CUDA_NO_Q8_BATCH_WARP=1 (since the fallback this replaces is gated by the same flag). Per-layer impact (DS4_METAL_LAYER_STAGE_PROFILE at n_tokens=2): per-layer total ~2.10ms -> ~1.60ms (-24%) on stages where cuBLAS is not used. Bench (DGX Spark, ds4flash.gguf + MTP-Q4K, n=256, --mtp-draft 2, 3-run avg, "knight" prompt): K=2 batched verifier ~8-9 t/s with this change, up from ~8.0 baseline. Byte-equality ------------- Share-warp kernel is mathematically and numerically identical to the batch_warp8 kernel for the weights it serves. Plain decode and MTP K=2 batched-verifier STRICT outputs are byte-equal to upstream/main. LOC --- ds4_cuda.cu: +102/-7 (one new templated kernel, one dispatch branch, three template instantiations). --- ds4_cuda.cu | 102 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 102 insertions(+) diff --git a/ds4_cuda.cu b/ds4_cuda.cu index 4821b841..33da9b9d 100644 --- a/ds4_cuda.cu +++ b/ds4_cuda.cu @@ -2127,6 +2127,52 @@ __global__ static void matmul_q8_0_preq_batch_warp8_kernel( if (lane == 0) out[tok * out_dim + row] = acc; } +/* Shared-weight variant: each warp reads one row of weights once and + * computes N_TOK dot products against N_TOK different token inputs. + * Cuts weight-bandwidth N-fold vs the per-token kernel above. Used for + * small batches (MTP spec verify at N=2-4) where cuBLAS GEMM pads the + * tensor-core M tile (16 for f16) and wastes ~7/8 of the M-axis work. */ +template +__global__ static void matmul_q8_0_preq_batch_share_warp_kernel( + float *out, + const unsigned char *w, + const int8_t *xq, + const float *xscale, + uint64_t in_dim, + uint64_t out_dim, + uint64_t blocks, + int use_dp4a) { + const uint64_t row = (uint64_t)blockIdx.x * 8u + (threadIdx.x >> 5u); + const uint32_t lane = threadIdx.x & 31u; + if (row >= out_dim) return; + + const unsigned char *wr = w + row * blocks * 34; + float acc[N_TOK]; + #pragma unroll + for (int t = 0; t < N_TOK; t++) acc[t] = 0.0f; + + for (uint64_t b = lane; b < blocks; b += 32u) { + const uint64_t i0 = b * 32; + const uint64_t bn = in_dim - i0 < 32 ? in_dim - i0 : 32; + const __half *scale_h = (const __half *)(wr + b * 34); + const int8_t *qs = (const int8_t *)(wr + b * 34 + 2); + const float wscale = __half2float(*scale_h); + #pragma unroll + for (int t = 0; t < N_TOK; t++) { + const int8_t *xqb = xq + (uint64_t)t * blocks * 32 + b * 32; + const float xs = xscale[(uint64_t)t * blocks + b]; + int dot = dot_i8_block(qs, xqb, bn, use_dp4a); + acc[t] += wscale * xs * (float)dot; + } + } + #pragma unroll + for (int t = 0; t < N_TOK; t++) acc[t] = warp_sum_f32(acc[t]); + if (lane == 0) { + #pragma unroll + for (int t = 0; t < N_TOK; t++) out[(uint64_t)t * out_dim + row] = acc[t]; + } +} + __global__ static void dequant_q8_0_to_f16_kernel( __half *out, const unsigned char *w, @@ -5947,6 +5993,62 @@ static int cuda_matmul_q8_0_tensor_labeled(ds4_gpu_tensor *out, const void *mode out->bytes < n_tok * out_dim * sizeof(float)) return 0; const char *wptr = cuda_model_range_ptr(model_map, weight_offset, weight_bytes, "q8_0"); if (!wptr) return 0; + /* Small-batch shared-weight path: at n_tok = 2..4, the hand-rolled warp + * kernel that reads each weight row once and computes N dot products + * against N tokens replaces the per-token batch_warp8 kernel and is + * bit-identical to it (same blocks, same per-block FMA order, same warp + * reduction). Gate to the same conditions under which batch_warp8 + * would have been chosen: no F32/F16 cuBLAS cache hit and blocks <= 32. + * Otherwise fall through so cuBLAS Gemm (the existing reference path) + * stays in charge for that weight. Disable with + * DS4_CUDA_NO_Q8_SHARE_BATCH=1. */ + if (n_tok >= 2u && n_tok <= 4u && blocks <= 32u && + getenv("DS4_CUDA_NO_Q8_SHARE_BATCH") == NULL && + getenv("DS4_CUDA_NO_Q8_BATCH_WARP") == NULL && + (!g_cublas_ready || + (cuda_q8_f32_ptr(model_map, weight_offset, weight_bytes, in_dim, out_dim, label) == NULL && + cuda_q8_f16_ptr(model_map, weight_offset, weight_bytes, in_dim, out_dim, label) == NULL))) { + const uint64_t share_xq_bytes = n_tok * blocks * 32u; + const uint64_t share_scale_offset = (share_xq_bytes + 15u) & ~15ull; + const uint64_t share_tmp_bytes = share_scale_offset + n_tok * blocks * sizeof(float); + void *share_tmp = cuda_tmp_alloc(share_tmp_bytes, "q8_0 share prequant"); + if (share_tmp) { + int8_t *share_xq = (int8_t *)share_tmp; + float *share_xscale = (float *)((char *)share_tmp + share_scale_offset); + const int share_dp4a = cuda_q8_use_dp4a(); + dim3 share_qgrid((unsigned)blocks, (unsigned)n_tok, 1); + quantize_q8_0_f32_kernel<<>>(share_xq, share_xscale, + (const float *)x->ptr, + in_dim, blocks); + if (cuda_ok(cudaGetLastError(), "matmul_q8_0 share quantize launch")) { + const unsigned grid_x = ((unsigned)out_dim + 7u) / 8u; + bool launched = false; + if (n_tok == 2u) { + matmul_q8_0_preq_batch_share_warp_kernel<2><<>>( + (float *)out->ptr, + reinterpret_cast(wptr), + share_xq, share_xscale, in_dim, out_dim, blocks, share_dp4a); + launched = true; + } else if (n_tok == 3u) { + matmul_q8_0_preq_batch_share_warp_kernel<3><<>>( + (float *)out->ptr, + reinterpret_cast(wptr), + share_xq, share_xscale, in_dim, out_dim, blocks, share_dp4a); + launched = true; + } else if (n_tok == 4u) { + matmul_q8_0_preq_batch_share_warp_kernel<4><<>>( + (float *)out->ptr, + reinterpret_cast(wptr), + share_xq, share_xscale, in_dim, out_dim, blocks, share_dp4a); + launched = true; + } + if (launched && cuda_ok(cudaGetLastError(), "matmul_q8_0 share warp launch")) { + return 1; + } + } + } + /* Falls through to cuBLAS / fallback if anything above failed. */ + } if (g_cublas_ready && n_tok > 1) { const float *w_f32 = cuda_q8_f32_ptr(model_map, weight_offset, weight_bytes, in_dim, out_dim, label); if (w_f32) { From 65d8182bd24b110a7d52e7dbbd005c02c1d0f64f Mon Sep 17 00:00:00 2001 From: Trevor Strieber Date: Fri, 22 May 2026 09:54:13 -0700 Subject: [PATCH 2/2] cuda: fuse head_rms_norm + rope_tail on Q (decode + batched paths) Swaps the back-to-back ds4_gpu_head_rms_norm_tensor + ds4_gpu_rope_tail_tensor pair on the Q tensor for the existing ds4_gpu_head_rms_norm_rope_tail_tensor fused kernel. Mainline already defined the fused kernel in ds4_cuda.cu but only the standalone function - no callers were using it on the decode hot path. Sites updated: - decode q_path: single-token raw-SWA decode path - batched q_path: batched verifier / prefill path Savings per call: one DRAM round trip on the Q tensor and one kernel launch per layer. Numerical-parity fix vs sequential reference -------------------------------------------- The naive fused kernel reads tail[i] from memory, multiplies by the rms scale in-register, then immediately combines with the RoPE cos/sin into the rotated output. Under --use_fast_math nvcc was folding the scale multiply into the c/s multiply via FMA, which differs from the sequential reference (head_rms_norm writes scale*tail[i] back to fp32 memory, then rope_tail reads it back). This ULP-scale per-pair drift compounds across 30 layers and high pos0 values, flipping argmax decisions on long-context prompts. Fix: wrap the scale multiply in __fmul_rn (single-rounded fp32 multiply, hard barrier to FMA contraction). x0 = __fmul_rn(tail[i], scale) and x1 = __fmul_rn(tail[i+1], scale) reproduce the bit pattern that the sequential path's memory store-then-load produces, eliminating the long- context drift. Verification ------------ - Plain decode 64-token "knight" output: byte-equal to upstream/main. - MTP K=2 batched-verifier STRICT 64-token output: byte-equal to upstream/main. - ds4_test --all: tensor-equivalence summary matches PR2 base distribution (pre-existing intermittent long_memory_archive atomicAdd non-determinism in MoE prefill at n_tokens>=128 is unchanged; not introduced by this commit). Header: added ds4_gpu_head_rms_norm_rope_tail_tensor declaration to ds4_gpu.h (the wrapper existed in ds4_cuda.cu but was not declared). Bench (DGX Spark, ds4flash.gguf, --temp 0, 32 gen tokens): Plain decode: ~16.3 t/s, no regression vs PR2 base. --- ds4.c | 67 ++++++++++++++++++++++++++--------------------------- ds4_cuda.cu | 11 +++++++-- ds4_gpu.h | 22 ++++++++++++++++++ 3 files changed, 64 insertions(+), 36 deletions(-) diff --git a/ds4.c b/ds4.c index 69763014..cd67c297 100644 --- a/ds4.c +++ b/ds4.c @@ -9617,15 +9617,19 @@ static bool metal_graph_encode_decode_layer( if (ok) { metal_graph_debug_dump_tensor("Qraw", g->q, q_dim, il, pos); } - if (ok) ok = ds4_gpu_head_rms_norm_tensor(g->q, 1, DS4_N_HEAD, DS4_N_HEAD_DIM, DS4_RMS_EPS) != 0; - if (ok) { - metal_graph_debug_dump_tensor("Qnorm", g->q, q_dim, il, pos); - } - if (ok) ok = ds4_gpu_rope_tail_tensor(g->q, 1, DS4_N_HEAD, DS4_N_HEAD_DIM, - DS4_N_ROT, pos, - compressed ? (uint32_t)DS4_ROPE_ORIG_CTX : 0, - false, freq_base, freq_scale, ext_factor, attn_factor, - DS4_ROPE_YARN_BETA_FAST, DS4_ROPE_YARN_BETA_SLOW) != 0; + /* Fused head-rms-norm + RoPE rotation on Q (mainline already has the + * batched variant of this fusion implicit in some paths; the standalone + * fused kernel ds4_gpu_head_rms_norm_rope_tail_tensor saves one DRAM + * round trip and one kernel launch per layer on the decode hot path). + * Mathematically equivalent to the prior two-kernel sequence; FMA + * reordering may produce ULP-scale differences. */ + if (ok) ok = ds4_gpu_head_rms_norm_rope_tail_tensor(g->q, 1, DS4_N_HEAD, DS4_N_HEAD_DIM, + DS4_N_ROT, pos, + compressed ? (uint32_t)DS4_ROPE_ORIG_CTX : 0, + false, freq_base, freq_scale, + ext_factor, attn_factor, + DS4_ROPE_YARN_BETA_FAST, DS4_ROPE_YARN_BETA_SLOW, + DS4_RMS_EPS) != 0; DS4_METAL_PROFILE_DECODE_STAGE("q_path"); if (ok) { metal_graph_debug_dump_tensor("Qcur", g->q, q_dim, il, pos); @@ -11666,35 +11670,30 @@ static bool metal_graph_encode_layer_attention_batch( (uint64_t)n_tokens * q_dim, il, pos0); } DS4_METAL_PROFILE_Q_STAGE("q_b"); - if (ok) ok = ds4_gpu_head_rms_norm_tensor(g->batch_q, - n_tokens, - DS4_N_HEAD, - DS4_N_HEAD_DIM, - DS4_RMS_EPS) != 0; - if (ok) { - metal_graph_debug_dump_tensor("Qnorm", g->batch_q, - (uint64_t)n_tokens * q_dim, il, pos0); - } - DS4_METAL_PROFILE_Q_STAGE("head_norm"); - if (ok) ok = ds4_gpu_rope_tail_tensor(g->batch_q, - n_tokens, - DS4_N_HEAD, - DS4_N_HEAD_DIM, - DS4_N_ROT, - pos0, - compressed ? (uint32_t)DS4_ROPE_ORIG_CTX : 0, - false, - freq_base, - freq_scale, - ext_factor, - attn_factor, - DS4_ROPE_YARN_BETA_FAST, - DS4_ROPE_YARN_BETA_SLOW) != 0; + /* Fused head-rms-norm + RoPE tail on Q (batched path). Replaces the + * head_rms_norm + rope_tail pair that ran sequentially; saves one DRAM + * round-trip and one launch per layer. ULP-scale FMA reordering may + * differ from the sequential pair. */ + if (ok) ok = ds4_gpu_head_rms_norm_rope_tail_tensor(g->batch_q, + n_tokens, + DS4_N_HEAD, + DS4_N_HEAD_DIM, + DS4_N_ROT, + pos0, + compressed ? (uint32_t)DS4_ROPE_ORIG_CTX : 0, + false, + freq_base, + freq_scale, + ext_factor, + attn_factor, + DS4_ROPE_YARN_BETA_FAST, + DS4_ROPE_YARN_BETA_SLOW, + DS4_RMS_EPS) != 0; if (ok) { metal_graph_debug_dump_tensor("Qcur", g->batch_q, (uint64_t)n_tokens * q_dim, il, pos0); } - DS4_METAL_PROFILE_Q_STAGE("rope"); + DS4_METAL_PROFILE_Q_STAGE("head_norm_rope"); DS4_METAL_PROFILE_ATTN_STAGE("q_path"); if (!qkv_rms_fused) { if (ok) ok = metal_graph_matmul_q8_0_named_tensor("attn_kv", diff --git a/ds4_cuda.cu b/ds4_cuda.cu index 33da9b9d..5035c33e 100644 --- a/ds4_cuda.cu +++ b/ds4_cuda.cu @@ -2414,8 +2414,15 @@ __global__ static void head_rms_norm_rope_tail_kernel( float s = sinf(theta) * mscale; if (inverse) s = -s; float *tail = xr + n_nope; - float x0 = tail[i] * scale; - float x1 = tail[i + 1] * scale; + /* Match the sequential (rms-then-rope) numerical path: that path + * stores scale*tail[i] back to fp32 memory before the RoPE rotation + * reads it. Use __fmul_rn to force a single-rounded fp32 multiply + * for the scale step, preventing the compiler from fusing scale*x + * into the c/s multiply via FMA. Without this barrier the long- + * context (high pos0 -> large theta) drift compounds across layers + * and flips argmax decisions on long_memory_archive. */ + float x0 = __fmul_rn(tail[i], scale); + float x1 = __fmul_rn(tail[i + 1], scale); tail[i] = x0 * c - x1 * s; tail[i + 1] = x0 * s + x1 * c; } diff --git a/ds4_gpu.h b/ds4_gpu.h index fa277716..223fad48 100644 --- a/ds4_gpu.h +++ b/ds4_gpu.h @@ -299,6 +299,28 @@ int ds4_gpu_rope_tail_tensor( float beta_fast, float beta_slow); +/* Fused per-head RMS norm + RoPE tail rotation on Q-style tensors. + * Mathematically equivalent to head_rms_norm_tensor + rope_tail_tensor + * applied back-to-back, but in a single kernel — saves one DRAM + * round-trip + one launch per call. ULP-scale FMA reordering may differ + * from the sequential pair. */ +int ds4_gpu_head_rms_norm_rope_tail_tensor( + ds4_gpu_tensor *x, + uint32_t n_tok, + uint32_t n_head, + uint32_t head_dim, + uint32_t n_rot, + uint32_t pos0, + uint32_t n_ctx_orig, + bool inverse, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow, + float eps); + /* Release decode fused KV finalizer: after the standalone RoPE kernel, this * performs DS4's FP8 non-RoPE KV round trip and writes the F16-rounded raw * attention cache row in one dispatch. */