40 changes: 27 additions & 13 deletions scripts/train_pytorch.py
@@ -45,7 +45,7 @@
import openpi.shared.normalize as _normalize
import openpi.training.config as _config
import openpi.training.data_loader as _data

import openpi.models_pytorch.lora_pytorch as lora_utils

def init_logging():
level_mapping = {"DEBUG": "D", "INFO": "I", "WARNING": "W", "ERROR": "E", "CRITICAL": "C"}
@@ -146,7 +146,7 @@ def get_model_parameters(model):
)


def save_checkpoint(model, optimizer, global_step, config, is_main, data_config):
def save_checkpoint(model, optimizer, global_step, config, is_main, data_config, lora_enabled=False):
"""Save a checkpoint with model state, optimizer state, and metadata."""
if not is_main:
return
@@ -416,6 +416,23 @@ def train_loop(config: _config.TrainConfig):
enable_gradient_checkpointing = False
logging.info("Gradient checkpointing is not supported for this model")

# Check if LoRA is enabled
lora_enabled = hasattr(config, "lora_config") and config.lora_config is not None and config.lora_config.enabled

# Load weights from weight_loader if specified (for fine-tuning)
if config.pytorch_weight_path is not None:
logging.info(f"Loading weights from: {config.pytorch_weight_path}")

model_path = os.path.join(config.pytorch_weight_path, "model.safetensors")
safetensors.torch.load_model(model, model_path)
logging.info(f"Loaded PyTorch weights from {config.pytorch_weight_path}")

# Apply LoRA AFTER loading pretrained weights
if lora_enabled:
logging.info("Applying LoRA adapters to model...")
frozen_count, trainable_count = lora_utils.apply_lora_to_pi0_pytorch(model, config.lora_config)
logging.info(f"LoRA applied: {trainable_count:,} trainable params, {frozen_count:,} frozen params")

# Log initial memory usage after model creation
if is_main and torch.cuda.is_available():
log_memory_usage(device, 0, "after_model_creation")
@@ -438,25 +455,22 @@ def train_loop(config: _config.TrainConfig):
static_graph=world_size >= 8, # Enable for 8+ GPUs
)

# Load weights from weight_loader if specified (for fine-tuning)
if config.pytorch_weight_path is not None:
logging.info(f"Loading weights from: {config.pytorch_weight_path}")

model_path = os.path.join(config.pytorch_weight_path, "model.safetensors")
safetensors.torch.load_model(
(model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model), model_path
)
logging.info(f"Loaded PyTorch weights from {config.pytorch_weight_path}")

# Optimizer + learning rate schedule from config
warmup_steps = config.lr_schedule.warmup_steps
peak_lr = config.lr_schedule.peak_lr
decay_steps = config.lr_schedule.decay_steps
end_lr = config.lr_schedule.decay_lr

# Get trainable parameters (respects LoRA freezing)
if lora_enabled:
trainable_params = [p for p in model.parameters() if p.requires_grad]
logging.info(f"Optimizing {len(trainable_params)} parameter groups (LoRA mode)")
else:
trainable_params = model.parameters()

# Create optimizer with config parameters
optim = torch.optim.AdamW(
model.parameters(),
trainable_params,
lr=peak_lr,
betas=(config.optimizer.b1, config.optimizer.b2),
eps=config.optimizer.eps,
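The lora_pytorch module itself is not part of this diff, so the sketch below is only an illustration of the interface the call sites above rely on: apply_lora_to_pi0_pytorch(model, lora_config) freezes the base weights, wraps linear layers with low-rank adapters, and returns (frozen_count, trainable_count) in the order unpacked in train_loop. The LoRAConfig fields, the LoRALinear wrapper, and the choice to wrap every nn.Linear are assumptions made for illustration, not the PR's actual implementation.

# Illustrative sketch only: the real lora_pytorch module is not shown in this diff.
# LoRAConfig fields, LoRALinear, and wrapping every nn.Linear are assumptions.
import dataclasses

import torch


@dataclasses.dataclass
class LoRAConfig:
    enabled: bool = True
    rank: int = 16
    alpha: float = 32.0


class LoRALinear(torch.nn.Module):
    """A frozen nn.Linear plus a trainable low-rank update: y = Wx + (alpha/r) * B(Ax)."""

    def __init__(self, base: torch.nn.Linear, rank: int, alpha: float):
        super().__init__()
        self.base = base
        self.scaling = alpha / rank
        self.lora_a = torch.nn.Linear(base.in_features, rank, bias=False)
        self.lora_b = torch.nn.Linear(rank, base.out_features, bias=False)
        torch.nn.init.zeros_(self.lora_b.weight)  # adapters start as a no-op
        for p in self.base.parameters():
            p.requires_grad_(False)

    def forward(self, x):
        return self.base(x) + self.scaling * self.lora_b(self.lora_a(x))


def apply_lora_to_pi0_pytorch(model: torch.nn.Module, cfg: LoRAConfig) -> tuple[int, int]:
    """Freezes the base model, wraps nn.Linear layers, returns (frozen, trainable) param counts."""
    # Freeze everything first; only the newly created adapter weights stay trainable.
    for p in model.parameters():
        p.requires_grad_(False)
    # Snapshot the module list before mutating so replacements are not revisited.
    for module in list(model.modules()):
        for child_name, child in list(module.named_children()):
            if isinstance(child, torch.nn.Linear):
                setattr(module, child_name, LoRALinear(child, cfg.rank, cfg.alpha))
    frozen = sum(p.numel() for p in model.parameters() if not p.requires_grad)
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return frozen, trainable

In the real module, wrapping is likely restricted to specific projection matrices rather than every linear layer, but train_pytorch.py only depends on the calling convention shown here.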
14 changes: 14 additions & 0 deletions src/openpi/models/model.py
@@ -18,6 +18,7 @@
import torch

from openpi.models_pytorch import pi0_pytorch
from openpi.models_pytorch import lora_pytorch
from openpi.shared import image_tools
import openpi.shared.array_typing as at

@@ -243,6 +244,19 @@ def load(self, params: at.Params, *, remove_extra_params: bool = True) -> "BaseModel":
def load_pytorch(self, train_config, weight_path: str):
logger.info(f"train_config: {train_config}")
model = pi0_pytorch.PI0Pytorch(config=train_config.model)

# Check if train_config has LoRA config
has_lora_config = (
hasattr(train_config, "lora_config")
and train_config.lora_config is not None
and train_config.lora_config.enabled
)

if has_lora_config:
# Use the config's LoRA settings
lora_pytorch.apply_lora_to_pi0_pytorch(model, train_config.lora_config)
logger.info("LoRA layers applied successfully")

safetensors.torch.load_model(model, weight_path)
return model

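Note the ordering in load_pytorch: the adapters are applied before safetensors.torch.load_model, so a checkpoint written after LoRA fine-tuning, whose state dict already contains the adapter tensors, finds matching parameter names in the freshly built module tree. The training path above is the mirror image: it loads the pre-LoRA base checkpoint first and applies the adapters afterwards. A minimal usage sketch follows, assuming load_pytorch is exposed on the model config returned by openpi.training.config.get_config; the config name and checkpoint path are placeholders.

# Hypothetical usage sketch: "pi0_lora_example" and the checkpoint path are placeholders.
import openpi.training.config as _config

# Any TrainConfig whose lora_config.enabled is True takes the LoRA branch in load_pytorch.
train_config = _config.get_config("pi0_lora_example")

# LoRA adapters are applied before loading, so the adapter tensors stored in the
# safetensors checkpoint have matching keys in the module tree.
model = train_config.model.load_pytorch(train_config, "/path/to/checkpoint/model.safetensors")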