
Support group_size=64 for affine quantized SDPA #2

Merged

CC-Yeh merged 1 commit into CC-Yeh:quantized_sdpa from dogukanveziroglu:fix/sdpa-affine-group64-b on Apr 18, 2026

Conversation


@dogukanveziroglu commented Apr 14, 2026

Hey! Really nice work on the ml-explore#3026 PR, I hope it's not buried. I was trying it out with mlx-lm and hit a small issue that blocks the default user flow, so here's a quick fix on top of yours.

Summary

Adds group_size=64 support to the affine path of
quantized_scaled_dot_product_attention. This matches the default produced
by mx.quantize(mode="affine") and the kv_group_size=64 default used by
mlx-lm for KV-cache quantization.

Motivation

The current PR only dispatches the affine kernel at group_size=32, but the
rest of the MLX / mlx-lm stack defaults to group_size=64 for affine
quantization:

  • mx.quantize(x, mode="affine") → defaults to group_size=64
  • python/mlx/nn/layers/quantized.py: _defaults_for_mode("affine") returns (64, 4)
  • mlx-lm/generate.py: kv_group_size: int = 64
  • mlx-lm/models/base.py: quantized_scaled_dot_product_attention(..., group_size: int = 64, bits: int = 8)

So the natural user flow — mx.quantize(k, mode="affine") followed by
quantized_scaled_dot_product_attention(..., mode="affine") — fails today
with ValueError: Affine mode supports group_size 32 but received 64.
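For concreteness, here is a minimal sketch of that flow. The fused entry point name and keyword layout below (mx.fast.quantized_scaled_dot_product_attention and its arguments) are illustrative assumptions rather than the confirmed API of ml-explore#3026; only mx.quantize and its affine defaults come from the stack described above:

```python
import mlx.core as mx

B, H, L, D = 1, 32, 4096, 128
q = mx.random.normal((B, H, 1, D), dtype=mx.float16)
k = mx.random.normal((B, H, L, D), dtype=mx.float16)
v = mx.random.normal((B, H, L, D), dtype=mx.float16)

# Affine quantization defaults to group_size=64, bits=4 and returns
# (packed weights, scales, biases) for each tensor.
k_q = mx.quantize(k, mode="affine")
v_q = mx.quantize(v, mode="affine")

# Hypothetical fused call; with the current PR head, the affine path rejects
# this with "Affine mode supports group_size 32 but received 64".
out = mx.fast.quantized_scaled_dot_product_attention(
    q, *k_q, *v_q,
    scale=D**-0.5,
    group_size=64,
    bits=4,
    mode="affine",
)
```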

Benchmarks

Decode-style, B=1, H=32, D=128, Lq=1, affine 4-bit, M4:

KV context   fp16     fused gs=32   fused gs=64 (this PR)   gs=64 speedup over gs=32
2K           10 µs    5.1 µs        5.0 µs                  noise
8K           30 µs    9.0 µs        9.0 µs                  0%
32K          231 µs   50 µs         41 µs                   +22%
64K          437 µs   95 µs         82 µs                   +16%
128K         879 µs   176 µs        152 µs                  +16%

gs=64 is actually faster than gs=32 at long context because it has half
the scale/bias memory traffic (one pair per 64 elements vs one per 32). At
128K tokens × 32 heads this is 16 MB vs 32 MB of scale/bias data moved per
decode step.

For comparison against the ops-based fallback path (what users hit today
when passing gs=64):

KV context   gs=64 fallback   gs=64 fused (this PR)   Fused speedup
32K          66 µs            41 µs                   1.6×
64K          146 µs           82 µs                   1.8×
128K         313 µs           152 µs                  2.1×

Changes

  1. mlx/backend/metal/kernels/sdpa_vector.h: add three dispatch entries for
    Affine with group_size=64 at bits ∈ {4, 6, 8}.
  2. mlx/fast.cpp: relax the affine validation to accept group_size of
    either 32 or 64.
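
Item 2 above, as a conceptual sketch only: the real check lives in C++ in mlx/fast.cpp, and this Python-level pseudocode just illustrates which affine group sizes are accepted after this change.

```python
def check_affine_group_size(group_size: int) -> None:
    # group_size=64 is newly accepted by this PR; 32 was already supported.
    if group_size not in (32, 64):
        raise ValueError(
            f"Affine mode supports group_size 32 or 64 but received {group_size}"
        )
```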

Cost

  • Binary: mlx.metallib grows 128,161,428 → 128,233,236 bytes (+72 KB,
    +0.056%). libmlx.dylib unchanged.
  • Compile: isolated scaled_dot_product_attention.metal compile goes
    0.46s → 0.62s. In a full pip install -e . build this is within
    measurement noise (~125s either way).

Testing

  • All 10 existing test_quantized_sdpa* tests pass (54 subtests).
  • Verified numerical correctness vs fp16 reference: max_err < 0.001 for
    both gs=32 and gs=64 across bits ∈ {4, 6, 8}, D ∈ {64, 128, 256}.
  • Built and tested on M4, macOS 25.4, Python 3.12.
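
For reference, the correctness check is roughly of the following shape. As above, the fused entry point name and its keyword arguments are assumptions for illustration; mx.quantize, mx.dequantize, and mx.fast.scaled_dot_product_attention are standard MLX ops:

```python
import mlx.core as mx

def max_abs_err(group_size=64, bits=4, D=128, L=4096, H=32):
    # Decode-style problem: one query token attending over an L-token KV cache.
    q = mx.random.normal((1, H, 1, D), dtype=mx.float16)
    k = mx.random.normal((1, H, L, D), dtype=mx.float16)
    v = mx.random.normal((1, H, L, D), dtype=mx.float16)
    scale = D**-0.5

    k_q = mx.quantize(k, group_size=group_size, bits=bits)
    v_q = mx.quantize(v, group_size=group_size, bits=bits)

    # Reference path: ordinary fp16 SDPA over dequantized K/V, so both paths
    # see the same quantization error and only the kernel math is compared.
    k_dq = mx.dequantize(*k_q, group_size=group_size, bits=bits)
    v_dq = mx.dequantize(*v_q, group_size=group_size, bits=bits)
    ref = mx.fast.scaled_dot_product_attention(q, k_dq, v_dq, scale=scale)

    # Assumed name/signature for the fused affine path under test.
    out = mx.fast.quantized_scaled_dot_product_attention(
        q, *k_q, *v_q, scale=scale,
        group_size=group_size, bits=bits, mode="affine",
    )
    return mx.abs(out - ref).max().item()

for gs in (32, 64):
    for b in (4, 6, 8):
        assert max_abs_err(group_size=gs, bits=b) < 1e-3
```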

Notes

  • Depends on #1, "Fix kernel dispatch naming in quant_sdpa_vector_2pass" (the
    kernel-name dispatch fix by @andershansson), for any affine SDPA call to
    actually load a kernel. With only the current PR head, every affine kernel
    dispatch fails with "Unable to load function
    quant_sdpa_vector_2pass_1_..._16" before this change can be exercised.
  • Only affine mode is expanded. mxfp4, mxfp8, nvfp4 group sizes are
    fixed by the format and remain unchanged.

@dogukanveziroglu changed the title from "Fix/sdpa affine group64 b" to "Support group_size=64 for affine quantized SDPA" on Apr 14, 2026
@dogukanveziroglu changed the base branch from main to quantized_sdpa on April 14, 2026 10:46
Add Affine dispatch entries for group_size=64 at bits={4,6,8} and
relax the validation in quantized_scaled_dot_product_attention.

This matches the default produced by mx.quantize(mode="affine") and
the kv_group_size=64 default used by mlx-lm, so users following the
MLX/mlx-lm conventions no longer hit an error when using fused
quantized attention.

Benchmarks (M4, B=1 H=32 D=128 Lq=1, affine 4-bit):
  Context  gs=32 fused   gs=64 fused   speedup
  32K      50 us         41 us         +22%
  64K      95 us         82 us         +16%
  128K     176 us        152 us        +16%

gs=64 is faster at long context because it has half the scale/bias
memory traffic.

Costs:
  mlx.metallib: 128,161,428 -> 128,233,236 bytes (+0.056%)
  libmlx.dylib: unchanged

Existing 10 test_quantized_sdpa* tests continue to pass (54 subtests).
@dogukanveziroglu force-pushed the fix/sdpa-affine-group64-b branch from aa69849 to e1c923e on April 18, 2026 16:43
@dogukanveziroglu (Author) commented

fixed the merge conflicts @CC-Yeh

@CC-Yeh merged commit bfed86b into CC-Yeh:quantized_sdpa on Apr 18, 2026
@dogukanveziroglu deleted the fix/sdpa-affine-group64-b branch on April 19, 2026 19:33