Hello, I'm working on a project where I would like to access the VLM capabilities of the backbone (the project compares the original PaliGemma with the post-trained model).
I was wondering whether the following recipe is reasonable and whether there are any issues. Essentially, we get the hidden states the same way "sample_actions" does but with only prefix_emb, then use lm_head to do the decoding.
Here are the details:
1. Image embedding: We call paligemma.model.get_image_features(pixel_values), which runs SigLIP → multi_modal_projector. This returns the raw projector output with no division by sqrt(hidden_size) (matching the patched get_image_features). Shape: [B, 256, 2048] per image.
2. Text embedding: We call paligemma.model.language_model.embed_tokens(token_ids) and then multiply by sqrt(2048). This matches what PI0Pytorch.embed_prefix does: lang_emb = embed_language_tokens(tokens) * sqrt(lang_emb_dim).
3. Combine: We tokenize the prompt with the HF PaliGemma processor (which inserts token ID 257152 for the image placeholders). We embed all tokens via step 2, then overwrite the image-placeholder positions with the image features from step 1. The result is a single [B, prefix_len, 2048] tensor where image positions hold raw projector output and text positions hold embeddings scaled by sqrt(2048). (See the first sketch after this list.)
4. Prefill: We call paligemma.model.language_model.forward(inputs_embeds=prefix_embs, position_ids=..., use_cache=True), i.e. the patched GemmaModel directly (no normalizer, 18 layers). We get back hidden states and a KV cache.
5. Decode loop: We project the last hidden state through paligemma.lm_head (tied to embed_tokens, shape [2048, 257152]) to get next-token logits. We greedily pick the top token, embed it via step 2, and call GemmaModel.forward again with the KV cache. Repeat until EOS (token 1) or max tokens. (See the second sketch after this list.)
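
To make the recipe concrete, here is a rough sketch of steps 1-3. The names paligemma, processor, prompt and image are just placeholders for my setup; I'm assuming the patched checkpoint and exactly one image per prompt:

```python
import math

import torch

# Assumptions: `paligemma` is the patched PaliGemmaForConditionalGeneration loaded from
# the checkpoint, `processor` is the matching HF PaliGemmaProcessor, one image per prompt.
IMAGE_TOKEN_ID = 257152  # placeholder token the processor inserts for the image
HIDDEN_SIZE = 2048


@torch.no_grad()
def build_prefix_embs(prompt, image):
    inputs = processor(text=prompt, images=image, return_tensors="pt")
    token_ids = inputs["input_ids"]  # [B, prefix_len], contains 256 image placeholders

    # Step 1: SigLIP -> multi_modal_projector, raw projector output (no sqrt scaling).
    image_feats = paligemma.model.get_image_features(inputs["pixel_values"])  # [B, 256, 2048]

    # Step 2: text embeddings scaled by sqrt(hidden_size), as in PI0Pytorch.embed_prefix.
    text_embs = paligemma.model.language_model.embed_tokens(token_ids) * math.sqrt(HIDDEN_SIZE)

    # Step 3: overwrite the image-placeholder positions with the image features.
    prefix_embs = text_embs.clone()
    image_mask = token_ids == IMAGE_TOKEN_ID  # exactly 256 True entries per row
    prefix_embs[image_mask] = image_feats.reshape(-1, HIDDEN_SIZE).to(prefix_embs.dtype)
    return prefix_embs, token_ids
```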
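
And a sketch of the prefill and greedy decode loop (steps 4-5), continuing from the sketch above. I'm omitting the attention mask (assuming an unpadded prefix) and assuming the patched GemmaModel keeps the standard HF return type with last_hidden_state and past_key_values:

```python
import math

import torch


@torch.no_grad()
def greedy_decode(prefix_embs, max_new_tokens=64, eos_id=1):
    device = prefix_embs.device
    batch, prefix_len, hidden_size = prefix_embs.shape

    # Step 4: prefill the patched GemmaModel with the prefix embeddings only.
    position_ids = torch.arange(prefix_len, device=device).unsqueeze(0).expand(batch, -1)
    out = paligemma.model.language_model(
        inputs_embeds=prefix_embs, position_ids=position_ids, use_cache=True
    )
    past_key_values = out.past_key_values
    hidden = out.last_hidden_state[:, -1]  # [B, 2048]

    generated = []
    for step in range(max_new_tokens):
        # Step 5: lm_head (tied to embed_tokens) -> greedy next token.
        logits = paligemma.lm_head(hidden)  # [B, 257152]
        next_id = logits.argmax(dim=-1)     # [B]
        generated.append(next_id)
        if (next_id == eos_id).all():
            break

        # Re-embed the new token the same way as the text prefix and extend the cache.
        next_emb = paligemma.model.language_model.embed_tokens(next_id).unsqueeze(1)
        next_emb = next_emb * math.sqrt(hidden_size)
        next_pos = torch.full((batch, 1), prefix_len + step, device=device)
        out = paligemma.model.language_model(
            inputs_embeds=next_emb,
            position_ids=next_pos,
            past_key_values=past_key_values,
            use_cache=True,
        )
        past_key_values = out.past_key_values
        hidden = out.last_hidden_state[:, -1]

    return torch.stack(generated, dim=1)  # [B, num_generated]
```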
Thanks