# Support group_size=64 for affine quantized SDPA #2
Merged

CC-Yeh merged 1 commit into CC-Yeh:quantized_sdpa on Apr 18, 2026

## Conversation
Add Affine dispatch entries for group_size=64 at bits={4,6,8} and
relax the validation in quantized_scaled_dot_product_attention.
This matches the default produced by mx.quantize(mode="affine") and
the kv_group_size=64 default used by mlx-lm, so users following the
MLX/mlx-lm conventions no longer hit an error when using fused
quantized attention.
Benchmarks (M4, B=1 H=32 D=128 Lq=1, affine 4-bit):

Context   gs=32 fused   gs=64 fused   speedup
32K       50 us         41 us         +22%
64K       95 us         82 us         +16%
128K      176 us        152 us        +16%
gs=64 is faster at long context because it has half the scale/bias
memory traffic.
Costs:
mlx.metallib: 128,161,428 -> 128,233,236 bytes (+0.056%)
libmlx.dylib: unchanged
Existing 10 test_quantized_sdpa* tests continue to pass (54 subtests).
aa69849 to e1c923e (force-push)

**Author:** fixed the merge conflicts @CC-Yeh
Hey! Really nice work on the ml-explore#3026 PR; I hope it's not buried. I was trying it out with mlx-lm and hit a small issue that blocks the default user flow, so here's a quick fix on top of yours.
## Summary

Adds `group_size=64` support to the affine path of `quantized_scaled_dot_product_attention`. This matches the default produced by `mx.quantize(mode="affine")` and the `kv_group_size=64` default used by mlx-lm for KV-cache quantization.

## Motivation
The current PR only dispatches the affine kernel at `group_size=32`, but the rest of the MLX / mlx-lm stack defaults to `group_size=64` for affine quantization:

- `mx.quantize(x, mode="affine")` → defaults to `group_size=64`
- `python/mlx/nn/layers/quantized.py`: `_defaults_for_mode("affine")` → `(64, 4)`
- `mlx-lm/generate.py` → `kv_group_size: int = 64`
- `mlx-lm/models/base.py` → `quantized_scaled_dot_product_attention(..., group_size: int = 64, bits: int = 8)`

So the natural user flow, `mx.quantize(k, mode="affine")` followed by `quantized_scaled_dot_product_attention(..., mode="affine")`, fails today with:
`ValueError: Affine mode supports group_size 32 but received 64`
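For concreteness, here is a minimal repro sketch of that flow. The `mx.fast.quantized_scaled_dot_product_attention` location, its argument order, and quantizing 4-D tensors directly are assumptions based on the PR text and mlx-lm usage:

```python
import mlx.core as mx

B, H, D, L = 1, 32, 128, 4096  # decode-style: a single query token

q = mx.random.normal((B, H, 1, D))
k = mx.random.normal((B, H, L, D))
v = mx.random.normal((B, H, L, D))

# mode="affine" defaults to group_size=64, bits=4 and returns
# (packed_weights, scales, biases).
k_q = mx.quantize(k, mode="affine")
v_q = mx.quantize(v, mode="affine")

# Before this change this raises:
#   ValueError: Affine mode supports group_size 32 but received 64
out = mx.fast.quantized_scaled_dot_product_attention(
    q, *k_q, *v_q, scale=D**-0.5, mode="affine"
)
```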
## Benchmarks

Decode-style, `B=1, H=32, D=128, Lq=1`, affine 4-bit, M4:

| Context | gs=32 fused | gs=64 fused (this PR) | speedup |
|---------|-------------|-----------------------|---------|
| 32K     | 50 us       | 41 us                 | +22%    |
| 64K     | 95 us       | 82 us                 | +16%    |
| 128K    | 176 us      | 152 us                | +16%    |

`gs=64` is actually faster than `gs=32` at long context because it has half the scale/bias memory traffic (one pair per 64 elements vs one per 32). At 128K tokens × 32 heads this is 16 MB vs 32 MB of scale/bias data moved per decode step.
The fused `gs=64` kernel was also compared against the ops-based fallback path (what users hit today when passing `gs=64`).
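For reference, a sketch of the kind of decode-style timing loop behind numbers like these (hypothetical helper; `fn` wraps one fused SDPA call at a fixed context length):

```python
import time
import mlx.core as mx

def time_decode_step(fn, iters=100, warmup=10):
    # Average wall time per call, forcing evaluation of the lazy graph.
    for _ in range(warmup):
        mx.eval(fn())
    start = time.perf_counter()
    for _ in range(iters):
        mx.eval(fn())
    return (time.perf_counter() - start) / iters
```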
## Changes

- `mlx/backend/metal/kernels/sdpa_vector.h`: add three dispatch entries for `Affine` with `group_size=64` at `bits ∈ {4, 6, 8}`.
- `mlx/fast.cpp`: relax the affine validation to accept a `group_size` of either 32 or 64.
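With the relaxed validation and the new dispatch entries, a sweep over the supported affine configurations should run without raising. A sketch, reusing the tensors and the assumed fused call from the repro above (the `group_size`/`bits` keyword names follow the mlx-lm signature quoted earlier):

```python
# group_size=32 entries existed before; group_size=64 is new in this PR.
for group_size in (32, 64):
    for bits in (4, 6, 8):
        k_q = mx.quantize(k, group_size=group_size, bits=bits, mode="affine")
        v_q = mx.quantize(v, group_size=group_size, bits=bits, mode="affine")
        out = mx.fast.quantized_scaled_dot_product_attention(
            q, *k_q, *v_q, scale=D**-0.5, mode="affine",
            group_size=group_size, bits=bits,
        )
        mx.eval(out)  # force the kernel to compile and run
```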
## Cost

- `mlx.metallib` grows 128,161,428 → 128,233,236 bytes (+72 KB, +0.056%).
- `libmlx.dylib` is unchanged.
- The `scaled_dot_product_attention.metal` compile goes 0.46s → 0.62s. In a full `pip install -e .` build this is within measurement noise (~125s either way).
## Testing

- `test_quantized_sdpa*` tests pass (54 subtests).
- `max_err < 0.001` for both `gs=32` and `gs=64` across `bits ∈ {4, 6, 8}` and `D ∈ {64, 128, 256}`.
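The `max_err` numbers come from a comparison of this shape (hypothetical test body; full-precision `mx.fast.scaled_dot_product_attention` on the dequantized K/V serves as the reference):

```python
def max_error_vs_reference(q, k, v, group_size, bits, scale):
    # Quantize K/V, run the fused quantized kernel, then compare against
    # full-precision SDPA on the dequantized tensors.
    k_q = mx.quantize(k, group_size=group_size, bits=bits, mode="affine")
    v_q = mx.quantize(v, group_size=group_size, bits=bits, mode="affine")

    fused = mx.fast.quantized_scaled_dot_product_attention(
        q, *k_q, *v_q, scale=scale, mode="affine",
        group_size=group_size, bits=bits,
    )

    k_dq = mx.dequantize(*k_q, group_size=group_size, bits=bits, mode="affine")
    v_dq = mx.dequantize(*v_q, group_size=group_size, bits=bits, mode="affine")
    ref = mx.fast.scaled_dot_product_attention(q, k_dq, v_dq, scale=scale)

    return mx.abs(fused - ref).max().item()
```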
## Notes

- With only the current PR head, all affine kernel dispatches fail with "Unable to load function quant_sdpa_vector_2pass_1_..._16", so that has to be fixed for the SDPA call to actually load a kernel before this change can be exercised.
- The `mxfp4`, `mxfp8`, and `nvfp4` group sizes are fixed by the format and remain unchanged.