diff --git a/scripts/convert_to_robot.py b/scripts/convert_to_robot.py index fb5651d..31ed0fa 100644 --- a/scripts/convert_to_robot.py +++ b/scripts/convert_to_robot.py @@ -17,8 +17,11 @@ import pickle import shutil import sys +import tempfile from pathlib import Path +import numpy as np + sys.path.insert(0, str(Path(__file__).parent.parent)) from video2robot.robot import RobotRetargeter @@ -27,6 +30,35 @@ from video2robot.utils import get_next_project_dir, emit_progress +def prepare_smplx_betas_for_gmr(smplx_path: Path, tmp_dir: Path, target_num_betas: int = 16) -> Path: + """Create a temporary SMPL-X file with betas padded/truncated for GMR.""" + with np.load(smplx_path, allow_pickle=True) as smplx_data: + data = {key: smplx_data[key] for key in smplx_data.files} + + betas = np.asarray(data["betas"], dtype=np.float32).reshape(-1) + if betas.shape[0] == target_num_betas: + return smplx_path + + if betas.shape[0] < target_num_betas: + adjusted_betas = np.pad(betas, (0, target_num_betas - betas.shape[0])) + else: + adjusted_betas = betas[:target_num_betas] + + data["betas"] = adjusted_betas.astype(np.float32) + + tmp = tempfile.NamedTemporaryFile( + prefix=f"{smplx_path.stem}_", + suffix=".npz", + dir=tmp_dir, + delete=False, + ) + tmp_path = Path(tmp.name) + tmp.close() + np.savez(tmp_path, **data) + print(f"[SMPL-X] Adjusted betas for GMR: {betas.shape[0]} → {target_num_betas}") + return tmp_path + + def main(): parser = argparse.ArgumentParser(description="Convert SMPL-X to robot motion") @@ -116,6 +148,8 @@ def main(): retargeter = RobotRetargeter(robot_type=args.robot) motion_paths: dict[int, Path] = {} twist_paths: dict[int, Path] = {} + temp_smplx_dir = tempfile.TemporaryDirectory(prefix="video2robot_gmr_") + temp_smplx_path = Path(temp_smplx_dir.name) for track in selected_tracks: smplx_path = track.smplx_path @@ -128,8 +162,9 @@ def main(): if track.track_id: print(f"[Robot] Track ID: {track.track_id}") + gmr_smplx_path = prepare_smplx_betas_for_gmr(smplx_path, temp_smplx_path) retargeter.retarget( - smplx_path=smplx_path, + smplx_path=gmr_smplx_path, output_path=output_path, target_fps=args.fps, visualize=args.visualize,