Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/openpi/training/misc/roboarena_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import openpi.models.pi0_fast as pi0_fast
import openpi.models.tokenizer as _tokenizer
import openpi.policies.droid_policy as droid_policy
import openpi.training.weight_loaders as weight_loaders
import openpi.transforms as _transforms

ModelType: TypeAlias = _model.ModelType
Expand All @@ -26,6 +27,7 @@ def get_roboarena_configs():
TrainConfig(
# Trained from PaliGemma, using RT-2 / OpenVLA style binning tokenizer.
name="paligemma_binning_droid",
weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi0_fast_base/params"),
model=pi0_fast.Pi0FASTConfig(
action_dim=8,
action_horizon=15,
Expand All @@ -46,6 +48,7 @@ def get_roboarena_configs():
TrainConfig(
# Trained from PaliGemma, using FAST tokenizer (using universal FAST+ tokenizer).
name="paligemma_fast_droid",
weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi0_fast_base/params"),
model=pi0_fast.Pi0FASTConfig(action_dim=8, action_horizon=15),
data=SimpleDataConfig(
assets=AssetsConfig(asset_id="droid"),
Expand All @@ -61,6 +64,7 @@ def get_roboarena_configs():
TrainConfig(
# Trained from PaliGemma, using FAST tokenizer (tokenizer trained on DROID dataset).
name="paligemma_fast_specialist_droid",
weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi0_fast_base/params"),
model=pi0_fast.Pi0FASTConfig(
action_dim=8,
action_horizon=15,
Expand All @@ -81,6 +85,7 @@ def get_roboarena_configs():
TrainConfig(
# Trained from PaliGemma, using FSQ tokenizer.
name="paligemma_vq_droid",
weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi0_fast_base/params"),
model=pi0_fast.Pi0FASTConfig(
action_dim=8,
action_horizon=15,
Expand All @@ -101,6 +106,7 @@ def get_roboarena_configs():
TrainConfig(
# pi0-style diffusion / flow VLA, trained on DROID from PaliGemma.
name="paligemma_diffusion_droid",
weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi0_base/params"),
model=pi0_config.Pi0Config(action_horizon=10, action_dim=8),
data=SimpleDataConfig(
assets=AssetsConfig(asset_id="droid"),
Expand Down
Loading