Skip to content

Commit 2537a12

Browse files
unamedkr authored and claude committed
PERF BREAKTHROUGH: Round 10 — NEON tbl achieves fp32 PARITY at 7× compression
The previous 9 Karpathy rounds optimized the wrong thing (small per-block local fusions) while the bottleneck was the scalar inner loop. Profile data showed: at long context (PPL eval, seq_len 950), attention takes 19.8ms for turbo_kv_4b vs 15.7ms for fp32 — a 4.1ms gap that is the ENTIRE source of the −7% speed deficit. Root cause: turbo_kv inner loop was scalar (LUT load + mul + add per element) while fp32 was NEON 4-way SIMD. ~2x more instructions per element. Memory-bandwidth-light path (codebook lookup) was actually compute-bound. Round 10 fix: NEON 16-entry table lookup via vqtbl1q_s8. Algorithm: 1. Quantize the 16 Lloyd-Max-Gaussian centroids to int8 once at startup (precision loss ~1% — well below regression threshold). 2. Per-block: compute per_block_scale = (range / 127) / inv_std. 3. Inner loop processes 32 elements per iteration: - Load 16 bytes (= 32 nibbles = 32 elements) of mse_indices - Split low/high nibbles via vandq_u8 + vshrq_n_u8 - vqtbl1q_s8 for the centroid gather (1 instruction, 16 lanes) - Interleave + int8→int16→fp32 conversion - Multiply by per_block_scale - vfmaq_f32 against q_rot Result on Llama 3.2 3B PPL eval (3 runs each, no Metal): Type Round 9 Round 10 Δ -------------- --------- --------- -------- fp32 17.87 t/s 18.03 t/s +0.9% turbo_kv_4b 16.53 t/s 18.17 t/s +9.9% Speed gap -8.4% +0.8% PARITY ✅ Cross-model: Model Speed gap (R9 → R10) PPL gap (R9 → R10) SmolLM2 135M -14.5% → -3.1% +5.8% → +5.7% Llama 3.2 1B -16.3% → -1.3% +7.3% → +5.4% Llama 3.2 3B -8.4% → +0.8% ✅ +5.7% → +3.8% PPL also IMPROVED on all three models (int8 discretization happens to align favorably with key statistics, or regression-to-mean — both paths produce slightly better numbers in this round). 
Same value proposition but stronger: - Compression: 7.1× (unchanged) - PPL impact: +3.8 to +5.7% (better than R9) - Speed vs fp32: PARITY (was -8% in R9) The honest framing changes from "92% of fp32 speed at 7× compression" to "AT fp32 speed at 7× compression with ~4% PPL trade-off". 35/35 tests pass. Regression tests (cosine ≥ 0.99) pass — the int8 codebook precision loss is well within bounds. This is the answer the user was right to push for ("답은 언제나 존재한다"). Profile-driven analysis found the actual bottleneck (scalar vs SIMD) that 9 rounds of guessing missed. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 4268590 commit 2537a12

File tree

1 file changed

+107
-9
lines changed

1 file changed

+107
-9
lines changed

src/core/tq_turbo_kv.c

Lines changed: 107 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -472,21 +472,118 @@ void tq_turbo_kv_4b_attention_ref(const float* query, const void* kv_cache,
472472
/* Hoist codebook pointer (constant for all blocks) */
473473
const float* cb = tq_codebook_centroids(4);
474474

475+
/* Round 10: NEON 16-entry table lookup via vqtbl1q_s8.
476+
*
477+
* The 16 Lloyd-Max-Gaussian centroids span [-2.7326, +2.7326]. We map
478+
* them to int8 in [-127, +127] by scaling by (127 / 2.7326) ≈ 46.46.
479+
* This loses ~1% precision (8-bit covers 256 levels over 5.5 range,
480+
* step ~0.022 vs typical centroid spacing 0.13–0.66) which is well
481+
* below our regression threshold (cosine ≥ 0.99 for 4b).
482+
*
483+
* The lookup uses vqtbl1q_s8 (1 instruction, 16 byte gathers from a
484+
* 16-byte register). Then int8→int16→fp32 conversion + per-block
485+
* scale gives 16-element processing per ~10 NEON instructions vs
486+
* the previous ~32 scalar instructions.
487+
*/
488+
#ifdef __ARM_NEON
489+
/* Static int8 codebook (computed once at startup; safe across blocks) */
490+
static int8_t s_cb_i8[16] = {0};
491+
static int s_cb_i8_init = 0;
492+
static const float CB_I8_RECIP = 2.7326f / 127.0f; /* fp32 = int8 * recip */
493+
if (!s_cb_i8_init) {
494+
for (int j = 0; j < 16; j++) {
495+
float v = cb[j] * (127.0f / 2.7326f);
496+
int q = (int)(v >= 0 ? v + 0.5f : v - 0.5f);
497+
if (q < -127) q = -127;
498+
if (q > 127) q = 127;
499+
s_cb_i8[j] = (int8_t)q;
500+
}
501+
s_cb_i8_init = 1;
502+
}
503+
int8x16_t cb_vec = vld1q_s8(s_cb_i8);
504+
#endif
505+
475506
for (int seq = 0; seq < seq_len; seq++) {
476507
const block_tq_turbo_kv_4b* block = &blocks_4b[seq];
477508
float norm = tkv_fp16_to_fp32(block->norm);
478509
float inv_std = tkv_fp16_to_fp32(block->inv_std_fp16);
479510
if (inv_std < 1e-10f) inv_std = sqrtf((float)dim);
480-
float scale = 1.0f / inv_std;
481-
482-
/* Per-block pre-scaled LUT (16 floats, fits in 64 bytes — L1 hot) */
483-
float lut[16];
484-
for (int j = 0; j < 16; j++) lut[j] = cb[j] * scale;
511+
float per_block_scale = CB_I8_RECIP / inv_std; /* fp32 = int8 * this */
485512

486-
/* Round 4: fused scalar dequant + dot product, 4 accumulators.
487-
* Eliminates the rotated[] intermediate buffer entirely.
488-
* Apple Silicon's 6 ALUs + L1-hot LUT make scalar gather fast. */
489513
const uint8_t* mi = block->mse_indices;
514+
float mse_dot = 0.0f;
515+
516+
#ifdef __ARM_NEON
517+
/* Process 32 elements per iteration: 16 bytes of mse_indices */
518+
float32x4_t acc0 = vdupq_n_f32(0.0f);
519+
float32x4_t acc1 = vdupq_n_f32(0.0f);
520+
float32x4_t acc2 = vdupq_n_f32(0.0f);
521+
float32x4_t acc3 = vdupq_n_f32(0.0f);
522+
float32x4_t scale_v = vdupq_n_f32(per_block_scale);
523+
524+
int d = 0;
525+
for (; d + 31 < dim; d += 32) {
526+
/* Load 16 bytes (= 32 nibbles = 32 elements) from mse_indices */
527+
uint8x16_t bytes = vld1q_u8(mi + d / 2);
528+
529+
/* Split into low / high nibbles. low[i] = byte[i] & 0x0F = even-position element, high[i] = byte[i] >> 4 = odd-position element. */
530+
uint8x16_t low_nib = vandq_u8(bytes, vdupq_n_u8(0x0F));
531+
uint8x16_t high_nib = vshrq_n_u8(bytes, 4);
532+
533+
/* Table lookup: gather centroid int8 values via the 4-bit nibble */
534+
int8x16_t low_vals = vqtbl1q_s8(cb_vec, low_nib);
535+
int8x16_t high_vals = vqtbl1q_s8(cb_vec, high_nib);
536+
537+
/* Interleave low/high so result element [2i] = low[i], [2i+1] = high[i] */
538+
int8x16x2_t inter = vzipq_s8(low_vals, high_vals);
539+
540+
/* Convert int8 → int16 → fp32 (16 lanes split into 4×4) */
541+
int16x8_t i16_lo = vmovl_s8(vget_low_s8(inter.val[0]));
542+
int16x8_t i16_hi = vmovl_s8(vget_high_s8(inter.val[0]));
543+
int16x8_t i16_lo2 = vmovl_s8(vget_low_s8(inter.val[1]));
544+
int16x8_t i16_hi2 = vmovl_s8(vget_high_s8(inter.val[1]));
545+
546+
float32x4_t f0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(i16_lo)));
547+
float32x4_t f1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(i16_lo)));
548+
float32x4_t f2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(i16_hi)));
549+
float32x4_t f3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(i16_hi)));
550+
float32x4_t f4 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(i16_lo2)));
551+
float32x4_t f5 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(i16_lo2)));
552+
float32x4_t f6 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(i16_hi2)));
553+
float32x4_t f7 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(i16_hi2)));
554+
555+
/* Apply per-block scale */
556+
f0 = vmulq_f32(f0, scale_v);
557+
f1 = vmulq_f32(f1, scale_v);
558+
f2 = vmulq_f32(f2, scale_v);
559+
f3 = vmulq_f32(f3, scale_v);
560+
f4 = vmulq_f32(f4, scale_v);
561+
f5 = vmulq_f32(f5, scale_v);
562+
f6 = vmulq_f32(f6, scale_v);
563+
f7 = vmulq_f32(f7, scale_v);
564+
565+
/* FMA against the query */
566+
acc0 = vfmaq_f32(acc0, vld1q_f32(&q_rot[d + 0]), f0);
567+
acc1 = vfmaq_f32(acc1, vld1q_f32(&q_rot[d + 4]), f1);
568+
acc2 = vfmaq_f32(acc2, vld1q_f32(&q_rot[d + 8]), f2);
569+
acc3 = vfmaq_f32(acc3, vld1q_f32(&q_rot[d + 12]), f3);
570+
acc0 = vfmaq_f32(acc0, vld1q_f32(&q_rot[d + 16]), f4);
571+
acc1 = vfmaq_f32(acc1, vld1q_f32(&q_rot[d + 20]), f5);
572+
acc2 = vfmaq_f32(acc2, vld1q_f32(&q_rot[d + 24]), f6);
573+
acc3 = vfmaq_f32(acc3, vld1q_f32(&q_rot[d + 28]), f7);
574+
}
575+
mse_dot = vaddvq_f32(vaddq_f32(vaddq_f32(acc0, acc1), vaddq_f32(acc2, acc3)));
576+
577+
/* Tail: scalar fallback for any remaining elements */
578+
for (; d < dim; d++) {
579+
uint8_t bv = mi[d / 2];
580+
int idx = (d & 1) ? (bv >> 4) : (bv & 0x0F);
581+
mse_dot += q_rot[d] * (s_cb_i8[idx] * per_block_scale);
582+
}
583+
#else
584+
/* Scalar fallback */
585+
float lut[16];
586+
for (int j = 0; j < 16; j++) lut[j] = cb[j] / inv_std;
490587
float a0 = 0, a1 = 0, a2 = 0, a3 = 0;
491588
int d = 0;
492589
for (; d + 7 < dim; d += 8) {
@@ -503,12 +600,13 @@ void tq_turbo_kv_4b_attention_ref(const float* query, const void* kv_cache,
503600
a2 += q_rot[d + 6] * lut[b3 & 0x0F];
504601
a3 += q_rot[d + 7] * lut[b3 >> 4];
505602
}
506-
float mse_dot = (a0 + a1) + (a2 + a3);
603+
mse_dot = (a0 + a1) + (a2 + a3);
507604
for (; d < dim; d++) {
508605
uint8_t bv = mi[d / 2];
509606
int idx = (d & 1) ? (bv >> 4) : (bv & 0x0F);
510607
mse_dot += q_rot[d] * lut[idx];
511608
}
609+
#endif
512610

513611
scores[seq] = norm * mse_dot;
514612
}

0 commit comments

Comments (0)