Implement Gemma 4 inference support #388
Conversation
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: f7effb54c0
```rust
let source_suffix_start = source_state
    .shared_kv_suffix_source_start()
    .expect("Windowed shared KV source layer must have suffix source");
```
Read shared KV suffix from actual windowed write offset
When a shared-KV source layer is windowed, this code always copies suffix rows starting at `window_length`, but the source attention path can write suffix tokens at a compact offset (`ring_length + projection_step`) via `windowed_suffix_write_start` when the ring has free contiguous space. In that common decode state (`ring_offset == 0`, non-prefill, `sampling_start == 0`), the target layer will copy stale K/V rows instead of the just-written suffix, producing incorrect attention results for shared-KV models. The copy start index should be derived from the same write-placement logic used by the source layer.
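To illustrate the suggested fix, here is a minimal Rust sketch. The function name and parameters are hypothetical, and the exact "free contiguous space" condition is an assumption; the real value should come from the same `windowed_suffix_write_start` placement logic the source attention path uses.

```rust
/// Hypothetical sketch: derive the suffix copy start from the same placement
/// rule the source layer uses when writing suffix rows, instead of hard-coding
/// `window_length`. Names and the exact condition are illustrative only.
fn suffix_copy_start(
    window_length: usize,
    ring_length: usize,
    ring_offset: usize,
    projection_step: usize,
    is_prefill: bool,
    sampling_start: usize,
) -> usize {
    // Assumed mirror of `windowed_suffix_write_start`: in the common decode
    // state (non-prefill, ring not wrapped, no sampling offset) with free
    // contiguous space, suffix rows are written compactly after the live ring.
    let compact_write = !is_prefill
        && ring_offset == 0
        && sampling_start == 0
        && ring_length + projection_step <= window_length;
    if compact_write {
        ring_length + projection_step
    } else {
        // Otherwise the suffix sits after the full window, which is what the
        // current copy path always assumes.
        window_length
    }
}
```

Whatever the exact rule is, the point of the suggestion is that the copy side and the write side share that rule rather than duplicate it.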
Can you split supporting the new config format into a separate PR?
Add Gemma 4 model extension execution, per-layer RoPE handling, on-the-fly RoPE kernels, shared-KV/windowed cache fixes, and inference output parsing support. Note: this is a working checkpoint; the code needs follow-up refactor and cleanup before final review.
Gemma 4 inference support. This adds the UZU-side runtime support needed for Gemma 4 text inference:
After the initial implementation, this adds a small performance/cleanup pass:
Finally, it works great! Gemma 4 E2B benchmark, 841 input tokens, 5 runs:
How does it compare to MLX?
This PR adds the UZU-side support needed to run Gemma 4 models exported in the new Lalamo format.
It covers both parts of the Lalamo Gemma 4 change:
Small additions:
- Added `--greedy` to the CLI run command (for debugging).
- Added `--tokens-limit` to the CLI run command (for debugging).
- Added Metal RoPE boundary handling for `position >= max_sequence_length` (sketched after this list).
- Added a RoPE regression test for out-of-range token positions.
- Added CPU reference kernels for Gemma 4 model extension ops.
- Added Metal kernels for Gemma 4 model extension ops.
- Added kernel parity tests for the model extension kernels.
- Added support for per-layer RoPE configs in decoder/language model config parsing.
- Added support for Gemma 4 PLE model/layer config fields.
- Added helper plumbing for Gemma 4 auxiliary buffers.
- Added shared KV source layer handling.
- Added post-layer scalar handling.
- Added soft-cap handling in the Gemma 4 path (sketched after this list).
- Added support for deriving runtime layer metadata from `layer_configs`.
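On the RoPE boundary item above, here is a minimal CPU-side Rust sketch of the failure mode and one way to handle it. Clamping the table lookup is an assumption for illustration only; the PR's on-the-fly RoPE kernels may instead compute the rotation directly from the position, which sidesteps precomputed tables entirely.

```rust
/// CPU sketch of RoPE with explicit boundary handling: never index the
/// precomputed cos/sin tables past `max_sequence_length`. The interleaved-pair
/// rotation convention is assumed for illustration.
fn rope_rotate(
    x: &mut [f32],          // interleaved (even, odd) pairs for one head
    position: usize,        // absolute token position, may exceed table length
    cos_table: &[Vec<f32>], // [max_sequence_length][head_dim / 2]
    sin_table: &[Vec<f32>],
) {
    let max_sequence_length = cos_table.len();
    // Boundary handling: clamp instead of reading out of bounds when
    // position >= max_sequence_length.
    let p = position.min(max_sequence_length.saturating_sub(1));
    let (cos_row, sin_row) = (&cos_table[p], &sin_table[p]);
    for (i, pair) in x.chunks_exact_mut(2).enumerate() {
        let (c, s) = (cos_row[i], sin_row[i]);
        let (x0, x1) = (pair[0], pair[1]);
        pair[0] = x0 * c - x1 * s;
        pair[1] = x0 * s + x1 * c;
    }
}
```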
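On the soft-cap item: in the Gemma family, soft capping is typically the tanh squashing shown below, which bounds values to (-cap, cap). This is a plain CPU sketch; where the cap is applied (attention logits vs. final logits) and the cap value itself come from the model config, not from this sketch.

```rust
/// Gemma-style soft capping: squash values into (-cap, cap) via tanh.
/// CPU sketch for reference only; the PR is assumed to apply the same
/// formula inside its kernels.
fn soft_cap(values: &mut [f32], cap: f32) {
    if cap <= 0.0 {
        return; // treat a non-positive cap as "disabled" in this sketch
    }
    for v in values.iter_mut() {
        *v = cap * (*v / cap).tanh();
    }
}
```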
Notes
Tracer validation is close but not fully green yet!
The latest run failed at `layer_results.7.activation_trace.pre_mixer_norm` with 1029 violations over a limit of 922. Keeping this as a draft until that accuracy gap is resolved.
I spent a significant amount of time investigating this accuracy gap, but it is not resolved yet.
There is also an intermittent decoding/stopping issue: in regular inference the model can produce the correct answer and then continue with extra garbage tokens instead of stopping cleanly. This PR keeps that as a known remaining issue while the Gemma 4 execution path is being validated.