
Remove redundant config_name arg from JAX->PyTorch model conversion script #915

Open
tahsinkose wants to merge 1 commit into Physical-Intelligence:main from tahsinkose:fix/remove-config-name-from-jax-pytorch-conversion

Conversation

@tahsinkose

Fixes #781.

### Background

convert_jax_model_to_pytorch.py previously required a --config_name argument (e.g. pi05_base, pi05_droid) to look up a TrainConfig and extract a Pi0Config. This was unnecessary because all the information the conversion actually needs is either inferable from the checkpoint contents or irrelevant to the weight conversion entirely.


### Approach

Each field of Pi0Config consumed by the script was audited:

#### 1. pi05 (bool)

Determines the model architecture: which projection keys exist (time_mlp_in/time_mlp_out vs state_proj/action_time_mlp_in/action_time_mlp_out) and whether expert normalization layers use adaptive Dense layers or standard RMSNorm scale.

Resolution: reliably detectable from checkpoint contents — "time_mlp_in" is present in the projection params if and only if the checkpoint is pi05. No path-name heuristics needed.
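This detection rule can be sketched as follows. The nested dicts are illustrative stand-ins for a real checkpoint's projection params; only the key names follow the ones quoted above:

```python
# Sketch of the pi05 auto-detection: "time_mlp_in" is present in the
# projection params if and only if the checkpoint is pi05.

def detect_pi05(projection_params: dict) -> bool:
    """Return True when the checkpoint uses the pi05 architecture."""
    return "time_mlp_in" in projection_params

# Fake param trees for both architectures (values elided; keys matter).
pi05_params = {"time_mlp_in": {}, "time_mlp_out": {}}
pi0_params = {"state_proj": {}, "action_time_mlp_in": {}, "action_time_mlp_out": {}}

print(detect_pi05(pi05_params))  # True
print(detect_pi05(pi0_params))   # False
```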

#### 2. action_dim

Used in PI0Pytorch.__init__ to size the projection layers:

  self.action_in_proj = nn.Linear(config.action_dim, action_expert_config.width)
  self.action_out_proj = nn.Linear(action_expert_config.width, config.action_dim)
  self.state_proj = nn.Linear(config.action_dim, action_expert_config.width)

Must be correct for load_state_dict to succeed.
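As a quick illustration of why the dimension must match, loading weights into a Linear layer sized with the wrong input dimension fails immediately (a generic example, not the script's actual code):

```python
import torch.nn as nn

src = nn.Linear(32, 1024)   # layer sized with the checkpoint's true action_dim
dst = nn.Linear(24, 1024)   # layer sized with a wrong, guessed action_dim

try:
    dst.load_state_dict(src.state_dict())
except RuntimeError as err:
    # PyTorch reports a size mismatch for the weight tensor.
    print("load_state_dict failed:", type(err).__name__)
```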

Resolution: inferred from the checkpoint — action_in_proj/kernel.shape[0] gives action_dim directly.
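A minimal sketch of that inference, using a fabricated kernel (JAX/Flax Dense kernels are stored as `(in_features, out_features)`; the dict layout is an assumption for demonstration):

```python
import numpy as np

def infer_action_dim(projection_params: dict) -> int:
    # The input side of action_in_proj's kernel is the action dimension.
    return int(projection_params["action_in_proj"]["kernel"].shape[0])

# Fake checkpoint fragment: action_dim=32, expert width=1024 (illustrative).
fake_params = {"action_in_proj": {"kernel": np.zeros((32, 1024))}}
print(infer_action_dim(fake_params))  # 32
```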

#### 3. action_horizon

Only referenced in the forward/sample methods (embed_suffix, forward, sample_actions, denoise_step) for attention-mask construction and output slicing. It is never used in __init__, so it has no effect on layer sizes or the weight conversion.

Notably, action_horizon also varies across checkpoints (pi05_droid: 15, pi05_libero: 10, default: 50), so hardcoding a default would be silently wrong; but since it does not affect the conversion at all, it is simply not needed here.

The config.json written alongside the converted weights is a human-readable reference only; nothing in the codebase reads it back. Omitting action_horizon from it therefore has no functional impact on inference or finetuning; callers must construct PI0Pytorch(config) with the correct config themselves regardless.

#### 4. paligemma_variant / action_expert_variant

Both are already hardcoded elsewhere in the script: the inline PaliGemmaConfig class for the vision/language side and openpi.models.gemma.get_config("gemma_300m") for the action expert. Neither is needed from the config.

### Changes

  • Removed config_name parameter from main and convert_pi0_checkpoint
  • Auto-detect pi05 from checkpoint contents: "time_mlp_in" in initial_params["projection_params"]
  • Infer action_dim from checkpoint: action_in_proj_kernel.shape[0]
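Putting the two inferences together, the auto-detection might look roughly like this (the function name and fake checkpoint dict are illustrative; the real script reads params from the loaded checkpoint):

```python
import numpy as np

def infer_conversion_config(initial_params: dict) -> dict:
    """Derive everything the conversion needs from the checkpoint itself."""
    projection_params = initial_params["projection_params"]
    return {
        # pi05 architecture iff the time_mlp_in projection exists.
        "pi05": "time_mlp_in" in projection_params,
        # Input side of action_in_proj's kernel is the action dimension.
        "action_dim": int(projection_params["action_in_proj"]["kernel"].shape[0]),
    }

# Fake pi05-style checkpoint fragment (shapes are illustrative).
fake_checkpoint = {
    "projection_params": {
        "time_mlp_in": {"kernel": np.zeros((2048, 1024))},
        "action_in_proj": {"kernel": np.zeros((32, 1024))},
    }
}
print(infer_conversion_config(fake_checkpoint))  # {'pi05': True, 'action_dim': 32}
```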



Development

Successfully merging this pull request may close these issues.

How to convert pi05 base model to torch
