
Implement gemma 4 inference support #388

Open
Madfyre wants to merge 11 commits into trymirai:main from Madfyre:implement-gemma-4-inference-support

Conversation

Madfyre commented May 5, 2026

This PR adds the UZU-side support needed to run Gemma 4 models exported by the new Lalamo format.

It covers both parts of the Lalamo Gemma 4 change:

  • supports per-layer RoPE configs and derives the runtime global/local RoPE buffers from layer configs (a minimal sketch of this derivation follows this list)
  • adds the Gemma 4 model extension path, including per-layer embeddings, layer PLE blocks, shared KV wiring, post-layer scalar handling, and the required CPU/Metal kernels
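
For illustration, here is a minimal sketch of what the per-layer derivation amounts to, assuming each layer's RoPE config only marks the layer as global or local and carries its own base; the type and field names below are hypothetical stand-ins, not the actual Lalamo/UZU config types:

// Sketch only: hypothetical config types. The idea is that the runtime scans the
// per-layer configs once and ends up with at most one global and one local RoPE
// base, so it can precompute exactly two frequency buffers.
#[derive(Clone, Copy, PartialEq)]
enum RopeKind {
    Global,
    Local, // e.g. sliding-window attention layers with their own RoPE base
}

#[derive(Clone, Copy)]
struct LayerRopeConfig {
    kind: RopeKind,
    base: f32,
}

fn derive_rope_bases(layers: &[LayerRopeConfig]) -> (Option<f32>, Option<f32>) {
    let global = layers.iter().find(|l| l.kind == RopeKind::Global).map(|l| l.base);
    let local = layers.iter().find(|l| l.kind == RopeKind::Local).map(|l| l.base);
    (global, local)
}

fn main() {
    let layers = [
        LayerRopeConfig { kind: RopeKind::Local, base: 10_000.0 },
        LayerRopeConfig { kind: RopeKind::Global, base: 1_000_000.0 },
    ];
    assert_eq!(derive_rope_bases(&layers), (Some(1_000_000.0), Some(10_000.0)));
}

The runtime would then build one frequency buffer per derived base rather than one per layer.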

Small additions:

  • Added --greedy to the CLI run command (for debugging).
  • Added --tokens-limit to the CLI run command (for debugging).
  • Added Metal RoPE boundary handling for position >= max_sequence_length (a CPU-side sketch of this boundary case follows this list).
  • Added a RoPE regression test for out-of-range token positions.
  • Added CPU reference kernels for Gemma 4 model extension ops.
  • Added Metal kernels for Gemma 4 model extension ops.
  • Added kernel parity tests for the model extension kernels.
  • Added support for per-layer RoPE configs in decoder/language model config parsing.
  • Added support for Gemma 4 PLE model/layer config fields.
  • Added helper plumbing for Gemma 4 auxiliary buffers.
  • Added shared KV source layer handling.
  • Added post-layer scalar handling.
  • Added soft-cap handling in the Gemma 4 path.
  • Added support for deriving runtime layer metadata from layer_configs.
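
A minimal CPU-side sketch of the RoPE boundary case above, assuming angles come from a table precomputed up to max_sequence_length and that out-of-range positions fall back to direct computation; the function and parameter names are illustrative, and this is not necessarily how the Metal kernel handles it:

// Sketch only: the real kernels are Metal and CPU implementations inside UZU, and
// whether out-of-range positions are clamped or recomputed is an assumption here.
fn rope_angle(
    position: usize,
    dim_pair: usize,
    head_dim: usize,
    base: f32,
    table: &[f32], // precomputed angles, len == max_sequence_length * head_dim / 2
    max_sequence_length: usize,
) -> f32 {
    let pairs = head_dim / 2;
    if position < max_sequence_length {
        table[position * pairs + dim_pair]
    } else {
        // Boundary case: position >= max_sequence_length, compute on the fly
        // instead of reading past the end of the table.
        let inv_freq = base.powf(-2.0 * dim_pair as f32 / head_dim as f32);
        position as f32 * inv_freq
    }
}

fn main() {
    let (head_dim, max_len, base) = (4usize, 2usize, 10_000.0_f32);
    // Precompute the in-range part of the table.
    let mut table = Vec::new();
    for pos in 0..max_len {
        for pair in 0..head_dim / 2 {
            let inv_freq = base.powf(-2.0 * pair as f32 / head_dim as f32);
            table.push(pos as f32 * inv_freq);
        }
    }
    // In-range and out-of-range positions both get a well-defined angle.
    assert_eq!(rope_angle(1, 0, head_dim, base, &table, max_len), 1.0);
    assert_eq!(rope_angle(5, 0, head_dim, base, &table, max_len), 5.0);
}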

Notes

Tracer validation is close but not fully green yet. The latest run failed at layer_results.7.activation_trace.pre_mixer_norm with 1029 violations over a limit of 922. I spent a significant amount of time investigating this accuracy gap, but it is not resolved yet, so this stays a draft until it is.

There is also still an intermittent decoding/stopping issue: in regular inference the model can produce the correct answer and then continue with extra garbage tokens instead of stopping cleanly. This PR keeps that as a known remaining issue while the Gemma 4 execution path is being validated.


chatgpt-codex-connector (bot) left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: f7effb54c0


Comment on lines +275 to +277
let source_suffix_start = source_state
.shared_kv_suffix_source_start()
.expect("Windowed shared KV source layer must have suffix source");

P1: Read shared KV suffix from actual windowed write offset

When a shared-KV source layer is windowed, this code always copies suffix rows starting at window_length, but the source attention path can write suffix tokens at a compact offset (ring_length + projection_step) via windowed_suffix_write_start when the ring has free contiguous space. In that common decode state (ring_offset == 0, non-prefill, sampling_start == 0), the target layer will copy stale K/V rows instead of the just-written suffix, producing incorrect attention results for shared-KV models. The copy start index should be derived from the same write-placement logic used by the source layer.
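
As an illustration of the suggested direction, the copy start would mirror the source layer's write placement. This is a sketch only: the struct and fields below are hypothetical stand-ins for the UZU KV-cache state, with only window_length, ring_length, projection_step, and windowed_suffix_write_start taken from the comment above.

// Hypothetical stand-in for the source layer's KV-cache state.
struct KvSourceState {
    window_length: usize,
    ring_length: usize,
    projection_step: usize,
    has_free_contiguous_space: bool,
}

impl KvSourceState {
    /// Where the source attention path actually wrote the suffix rows.
    fn windowed_suffix_write_start(&self) -> usize {
        if self.has_free_contiguous_space {
            // Compact placement described in the review comment.
            self.ring_length + self.projection_step
        } else {
            self.window_length
        }
    }

    /// The target layer copies from the same offset the source wrote to,
    /// instead of always starting at window_length.
    fn shared_kv_suffix_copy_start(&self) -> usize {
        self.windowed_suffix_write_start()
    }
}

fn main() {
    let state = KvSourceState {
        window_length: 512,
        ring_length: 100,
        projection_step: 4,
        has_free_contiguous_space: true,
    };
    // Copies from 104, not 512, matching where the suffix was just written.
    assert_eq!(state.shared_kv_suffix_copy_start(), 104);
}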


uuuvn (Contributor) commented May 5, 2026

Can you split supporting the new config format into a separate PR?

Add Gemma 4 model extension execution, per-layer RoPE handling, on-the-fly RoPE kernels, shared-KV/windowed cache fixes, and inference output parsing support.

Note: this is a working checkpoint; the code needs follow-up refactor and cleanup before final review.
Madfyre (Author) commented May 8, 2026

Gemma 4 inference support.

This adds the UZU-side runtime support needed for Gemma 4 text inference:

  • Added Gemma 4 attention shape handling for heterogeneous layers.
  • Added shared-KV wiring for layers that reuse another layer's K/V cache.
  • Added/updated CPU and Metal support for the Gemma 4-specific model-extension kernels.
  • Added output parser support for Gemma 4 thought/final channel formatting, so reasoning text is not mixed into the visible final answer (a minimal parsing sketch follows this list).
  • Added tests/bench coverage around the new RoPE and model-extension paths.
  • Fixed the intermittent decoding/stopping issue noted earlier, so the model now stops cleanly after the answer.
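
A minimal sketch of the channel split mentioned above, with hypothetical marker strings; the actual Gemma 4 channel formatting and the UZU parser API are not shown in this thread, so treat the markers and function shape as assumptions:

// Sketch only: splits reasoning ("thought") text from the visible answer.
// The channel markers below are hypothetical placeholders.
const THOUGHT_START: &str = "<thought>";
const THOUGHT_END: &str = "</thought>";

/// Returns (reasoning_text, final_answer) for a decoded model output.
fn split_channels(output: &str) -> (String, String) {
    match (output.find(THOUGHT_START), output.find(THOUGHT_END)) {
        (Some(start), Some(end)) if end > start => {
            let thought = &output[start + THOUGHT_START.len()..end];
            let final_answer = format!(
                "{}{}",
                &output[..start],
                &output[end + THOUGHT_END.len()..]
            );
            (thought.trim().to_string(), final_answer.trim().to_string())
        }
        // No thought channel: everything is visible output.
        _ => (String::new(), output.trim().to_string()),
    }
}

fn main() {
    let raw = "<thought>check units first</thought>The answer is 42.";
    let (thought, answer) = split_channels(raw);
    assert_eq!(thought, "check units first");
    assert_eq!(answer, "The answer is 42.");
}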

After the initial implementation, this adds a small performance/cleanup pass:

  • Shared-KV target layers now read from the source layer K/V cache directly instead of physically copying K/V rows into their own target cache each pass.
  • Shared-KV target layers skip their own KV-cache update, matching the reference behavior more closely.
  • inverse_frequencies are now precomputed once per Rope block into a tiny buffer and passed into the CPU/Metal kernels. This removes the repeated base/scaling/log/exp inverse-frequency calculation from the inner RoPE kernel loop (a minimal sketch of the precompute follows this list).
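
A minimal sketch of the precompute, assuming the standard RoPE inverse-frequency formula inv_freq[i] = base^(-2i / head_dim); the buffer layout and any Gemma-specific frequency scaling used in UZU are not shown here:

// Computed once per RoPE block and handed to the CPU/Metal kernels, so the inner
// kernel loop only multiplies by the token position instead of recomputing the
// base/scaling pow per element.
fn precompute_inverse_frequencies(head_dim: usize, base: f32) -> Vec<f32> {
    (0..head_dim / 2)
        .map(|i| base.powf(-2.0 * i as f32 / head_dim as f32))
        .collect()
}

fn main() {
    let inv_freq = precompute_inverse_frequencies(128, 10_000.0);
    assert_eq!(inv_freq.len(), 64);
    assert!((inv_freq[0] - 1.0).abs() < 1e-6);
}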

Finally, it works great!

Gemma 4 E2B benchmark, 841 input tokens, 5 runs:
Shared-KV no-copy + RoPE inverse-frequency precompute:

  • TTFT is around 0.420-0.423 s
  • Prompt throughput is around 1990-2005 t/s
  • Generate throughput is around 24.2-24.5 t/s

CC-Yeh (Contributor) commented May 9, 2026

Gemma 4 E2B benchmark, 841 input tokens, 5 runs:
Shared-KV no-copy + RoPE inverse-frequency precompute:

  • TTFT is around 0.420-0.423 s
  • Prompt throughput is around 1990-2005 t/s
  • Generate throughput is around 24.2-24.5 t/s

How does it compare to MLX?
