
Fix kernel dispatch naming in quant_sdpa_vector_2pass #1

Open

andershansson wants to merge 20 commits into CC-Yeh:quantized_sdpa from programlabbet:fix-quant-sdpa-vector-2pass-kname

Conversation

@andershansson

Summary

Fixes the kernel-dispatch naming bug @Thump604 reported in #3026 ([metal::Device] Unable to load function quant_sdpa_vector_2pass_1_float_128_16). It's a one-line change in mlx/backend/metal/scaled_dot_product_attention.cpp, and with it the already-existing test_quantized_sdpa_affine tests all pass.

The bug

In quant_sdpa_vector_2pass (around line 660) the dispatch builds the kernel name from v.shape(-1):

kname += std::to_string(v.shape(-1));

For a quantized v, v.shape(-1) returns the packed uint32 dimension, not the logical value head_dim. The kernels in scaled_dot_product_attention.metal are instantiated with the logical dims (64_64, 128_128, 256_256), so the dispatch looks up names that don't exist:

bits | logical head_dim | packed dim | dispatch looks up
4    | 128              | 16         | quant_sdpa_vector_2pass_1_float_128_16
6    | 128              | 24         | quant_sdpa_vector_2pass_1_float_128_24
8    | 128              | 32         | quant_sdpa_vector_2pass_1_float_128_32

(The second pass at line 776 is fine — it uses out.shape(-1) which is the unquantized output shape.)
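To make the shape relationship concrete, here is a minimal Python sketch (assuming a recent MLX build where mx.quantize packs along the last axis into uint32; the 2-D array is just a stand-in for one head's slice of the value cache, not the SDPA internals):

import mlx.core as mx

head_dim = 128
# Stand-in for a slice of the value cache: (context_length, head_dim).
v = mx.random.normal((2048, head_dim))

for bits in (4, 6, 8):
    v_q, scales, biases = mx.quantize(v, group_size=64, bits=bits)
    # Packed uint32 dim shrinks to head_dim * bits / 32: 16, 24, 32.
    print(bits, v_q.shape[-1], v_q.shape[-1] * 32 // bits)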

The fix

The function already has int bits as a parameter, so the logical dim is recoverable from the packed dim:

-  kname += std::to_string(v.shape(-1));
+  // v is stored as packed uint32; the kernel is instantiated with the logical
+  // value head_dim. Recover it: logical_dim = packed_dim * 32 / bits.
+  kname += std::to_string(v.shape(-1) * 32 / bits);

The formula works cleanly for all currently-supported affine bit widths:

bits | packed dim * 32 / bits
4    | 16 * 32 / 4 = 128
6    | 24 * 32 / 6 = 128
8    | 32 * 32 / 8 = 128
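As a quick before/after illustration, here is a Python mirror of the naming logic (the real dispatch is the C++ shown above; the base string layout is simplified for the example):

def kernel_name(dtype, head_dim, v_packed_dim, bits, fixed):
    base = f"quant_sdpa_vector_2pass_1_{dtype}_{head_dim}_"
    if fixed:
        return base + str(v_packed_dim * 32 // bits)  # logical value head_dim
    return base + str(v_packed_dim)  # buggy: packed uint32 dim leaks into the name

for bits, packed in [(4, 16), (6, 24), (8, 32)]:
    print(kernel_name("float", 128, packed, bits, False), "->",
          kernel_name("float", 128, packed, bits, True))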

Verification

Built from source on an M4 Max (macOS 26, Python 3.12) with the downloaded Metal Toolchain 17E188, against commit 6291e80.

Before the fix (test_quantized_sdpa_affine in python/tests/test_quantized.py):

python/tests/test_quantized.py::TestQuantized::test_quantized_sdpa_affine
  SUBFAILED(bits=4) RuntimeError: Unable to load function quant_sdpa_vector_2pass_1_float_128_16
  SUBFAILED(bits=6) RuntimeError: Unable to load function quant_sdpa_vector_2pass_1_float_128_24
  SUBFAILED(bits=8) RuntimeError: Unable to load function quant_sdpa_vector_2pass_1_float_128_32

After the fix:

python/tests/test_quantized.py::TestQuantized::test_quantized_sdpa_affine
  SUBPASSED(bits=4)
  SUBPASSED(bits=6)
  SUBPASSED(bits=8)
  PASSED

===================== 1 passed, 3 subtests passed in 0.18s =====================

All 10 quantized SDPA test methods (test_quantized_sdpa, test_quantized_sdpa_affine, test_quantized_sdpa_masked, test_quantized_sdpa_affine_masked, test_quantized_sdpa_sinks, test_quantized_sdpa_masked_with_sinks, test_quantized_sdpa_affine_masked_with_sinks, test_quantized_sdpa_causal, test_quantized_sdpa_affine_causal, test_quantized_sdpa_causal_with_array_mask_error) now pass — 54 subtests total.

Also independently verified numerical correctness against a dequantize-then-fp16-SDPA reference path on realistic decode shapes (B=1, n_h=32, n_kv=8, head_dim ∈ {64, 128, 256}, bits ∈ {4, 8}, context lengths 2048 and 8192) — cosine similarity 1.000000 across every configuration.
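For reference, the comparison followed this general shape (a sketch only: mx.quantize, mx.dequantize, and mx.fast.scaled_dot_product_attention are real MLX APIs, but quantized_sdpa_out below is a placeholder for the output of whatever entry point exercises the fixed kernels, e.g. the path the tests hit):

import mlx.core as mx

def cosine_sim(a, b):
    a = a.flatten().astype(mx.float32)
    b = b.flatten().astype(mx.float32)
    return mx.sum(a * b) / (mx.sqrt(mx.sum(a * a)) * mx.sqrt(mx.sum(b * b)))

def q_dq(x, bits, group_size=64):
    # Quantize along the last axis, then dequantize back (reference path only).
    flat = x.reshape(-1, x.shape[-1])
    w_q, scales, biases = mx.quantize(flat, group_size=group_size, bits=bits)
    deq = mx.dequantize(w_q, scales, biases, group_size=group_size, bits=bits)
    return deq.reshape(x.shape).astype(x.dtype)

B, n_h, n_kv, L, D, bits = 1, 32, 8, 2048, 128, 4
q = mx.random.normal((B, n_h, 1, D)).astype(mx.float16)
k = mx.random.normal((B, n_kv, L, D)).astype(mx.float16)
v = mx.random.normal((B, n_kv, L, D)).astype(mx.float16)

ref = mx.fast.scaled_dot_product_attention(q, q_dq(k, bits), q_dq(v, bits), scale=D**-0.5)
# print(cosine_sim(quantized_sdpa_out, ref))  # quantized_sdpa_out: output of the kernel under test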

Context

We're maintaining turboquant-mlx, a TurboQuant-style KV cache compression library built on top of MLX, and we're tracking ml-explore#3026 closely. Happy to help move this PR forward with additional M4 Max benchmarks or cross-architecture correctness tests (Qwen2.5 with biased k_proj, Phi-3.5 with head_dim=96, gpt-oss alternating sliding/full attention) once the main PR is unblocked.

Test plan

  • python -m pytest python/tests/test_quantized.py -k "sdpa": all 10 tests (54 subtests) pass on M4 Max
  • Numerical correctness vs fp16 reference across head_dim ∈ {64, 128, 256}, bits ∈ {4, 8}
  • Verified bug reproduces on the pre-fix branch (all three bit-width subtests fail)
  • Verified fix is surgical — only 2 lines changed, single function touched

CC-Yeh and others added 20 commits February 7, 2026 21:15
The dispatch in quant_sdpa_vector_2pass builds the kernel name using
v.shape(-1), but for quantized inputs this returns the packed uint32
dimension (16 for 4-bit packing head_dim=128, 24 for 6-bit, 32 for 8-bit)
rather than the logical value head_dim that the kernels are instantiated
with.

The kernels in scaled_dot_product_attention.metal are instantiated with
logical dims (64_64, 128_128, 256_256), so dispatch fails with:

    [metal::Device] Unable to load function
    quant_sdpa_vector_2pass_1_float_128_16
    Function quant_sdpa_vector_2pass_1_float_128_16 was not found
    in the library

The function already has 'int bits' as a parameter, so the logical dim
is recoverable: logical_dim = packed_dim * 32 / bits. This formula works
cleanly for all currently-supported affine bit widths (4, 6, 8):

    bits=4: packed=16 -> logical=16*32/4=128
    bits=6: packed=24 -> logical=24*32/6=128
    bits=8: packed=32 -> logical=32*32/8=128

Verified by running the existing test_quantized_sdpa_affine test cases
which were failing on all three bit widths before the fix and now pass.
All 10 quantized SDPA test methods (54 subtests) in test_quantized.py
pass after this one-line fix.

Reproduction and verification performed on M4 Max, macOS 26,
Python 3.12, built from source against commit 6291e80 with the
downloaded Metal Toolchain 17E188.
