diff --git a/egomimic/robot/annotations/chips/basic.txt b/egomimic/robot/annotations/chips/basic.txt new file mode 100644 index 000000000..d027450dc --- /dev/null +++ b/egomimic/robot/annotations/chips/basic.txt @@ -0,0 +1 @@ +Using the left hand, pick the doritos from the table. \ No newline at end of file diff --git a/egomimic/robot/annotations/chips/place_pink_plate.txt b/egomimic/robot/annotations/chips/place_pink_plate.txt new file mode 100644 index 000000000..10714f331 --- /dev/null +++ b/egomimic/robot/annotations/chips/place_pink_plate.txt @@ -0,0 +1 @@ +With left hand, place the doritos on the pink plate. \ No newline at end of file diff --git a/egomimic/robot/annotations/croissant/basic.txt b/egomimic/robot/annotations/croissant/basic.txt new file mode 100644 index 000000000..33d426954 --- /dev/null +++ b/egomimic/robot/annotations/croissant/basic.txt @@ -0,0 +1 @@ +Using the left hand, pick the croissant from the table. \ No newline at end of file diff --git a/egomimic/robot/annotations/croissant/place_bowl.txt b/egomimic/robot/annotations/croissant/place_bowl.txt new file mode 100644 index 000000000..b0e02fb76 --- /dev/null +++ b/egomimic/robot/annotations/croissant/place_bowl.txt @@ -0,0 +1 @@ +With left hand, place the croissant in the bowl. \ No newline at end of file diff --git a/egomimic/robot/annotations/croissant/place_pink_cup.txt b/egomimic/robot/annotations/croissant/place_pink_cup.txt new file mode 100644 index 000000000..c321dbaa0 --- /dev/null +++ b/egomimic/robot/annotations/croissant/place_pink_cup.txt @@ -0,0 +1 @@ +With left hand, place the croissant on the left side of pink mug. \ No newline at end of file diff --git a/egomimic/robot/annotations/cup/basic.txt b/egomimic/robot/annotations/cup/basic.txt new file mode 100644 index 000000000..b62a4bb9e --- /dev/null +++ b/egomimic/robot/annotations/cup/basic.txt @@ -0,0 +1 @@ +With left hand, grab the pink mug from the table using the handle. \ No newline at end of file diff --git a/egomimic/robot/annotations/cup/place_brown_toy.txt b/egomimic/robot/annotations/cup/place_brown_toy.txt new file mode 100644 index 000000000..1f6188d1e --- /dev/null +++ b/egomimic/robot/annotations/cup/place_brown_toy.txt @@ -0,0 +1 @@ +With left hand, place the pink mug to the right of the brown toy. \ No newline at end of file diff --git a/egomimic/robot/annotations/cup/place_pink_plate.txt b/egomimic/robot/annotations/cup/place_pink_plate.txt new file mode 100644 index 000000000..9acf59832 --- /dev/null +++ b/egomimic/robot/annotations/cup/place_pink_plate.txt @@ -0,0 +1 @@ +With left hand, place the pink mug on top of the pink plate. \ No newline at end of file diff --git a/egomimic/robot/annotations/screwdriver/basic.txt b/egomimic/robot/annotations/screwdriver/basic.txt new file mode 100644 index 000000000..38e43c1a7 --- /dev/null +++ b/egomimic/robot/annotations/screwdriver/basic.txt @@ -0,0 +1 @@ +Using the left hand, pick the orange screw driver from the table. \ No newline at end of file diff --git a/egomimic/robot/annotations/screwdriver/place_bowl.txt b/egomimic/robot/annotations/screwdriver/place_bowl.txt new file mode 100644 index 000000000..f0dcc7bf1 --- /dev/null +++ b/egomimic/robot/annotations/screwdriver/place_bowl.txt @@ -0,0 +1 @@ +With left hand, place the screw driver in the bowl. \ No newline at end of file diff --git a/egomimic/robot/annotations/screwdriver/place_pink_plate.txt b/egomimic/robot/annotations/screwdriver/place_pink_plate.txt new file mode 100644 index 000000000..b6cad7fc7 --- /dev/null +++ b/egomimic/robot/annotations/screwdriver/place_pink_plate.txt @@ -0,0 +1 @@ +With left hand, place the screw driver on top of the pink plate. \ No newline at end of file diff --git a/egomimic/robot/annotations/toy/pick_basic.txt b/egomimic/robot/annotations/toy/pick_basic.txt new file mode 100644 index 000000000..b05a8be51 --- /dev/null +++ b/egomimic/robot/annotations/toy/pick_basic.txt @@ -0,0 +1 @@ +With left hand, pick up the brown toy from the table. \ No newline at end of file diff --git a/egomimic/robot/annotations/toy/pick_basic2.txt b/egomimic/robot/annotations/toy/pick_basic2.txt new file mode 100644 index 000000000..2f19d3d32 --- /dev/null +++ b/egomimic/robot/annotations/toy/pick_basic2.txt @@ -0,0 +1 @@ +With left hand, grab the brown toy from the table \ No newline at end of file diff --git a/egomimic/robot/annotations/toy/place_basic.txt b/egomimic/robot/annotations/toy/place_basic.txt new file mode 100644 index 000000000..04f9c48b9 --- /dev/null +++ b/egomimic/robot/annotations/toy/place_basic.txt @@ -0,0 +1 @@ +With the left hand, position the brown toy in front of the green bowl. \ No newline at end of file diff --git a/egomimic/robot/annotations/toy/place_bowl.txt b/egomimic/robot/annotations/toy/place_bowl.txt new file mode 100644 index 000000000..d075d88f0 --- /dev/null +++ b/egomimic/robot/annotations/toy/place_bowl.txt @@ -0,0 +1 @@ +With left hand, place the brown toy in the green bowl. \ No newline at end of file diff --git a/egomimic/robot/annotations/toy/place_croissant.txt b/egomimic/robot/annotations/toy/place_croissant.txt new file mode 100644 index 000000000..33a733a81 --- /dev/null +++ b/egomimic/robot/annotations/toy/place_croissant.txt @@ -0,0 +1 @@ +With left hand, place the brown toy on the left side of the croissant. \ No newline at end of file diff --git a/egomimic/robot/annotations/toy/place_pink_plate.txt b/egomimic/robot/annotations/toy/place_pink_plate.txt new file mode 100644 index 000000000..049134f86 --- /dev/null +++ b/egomimic/robot/annotations/toy/place_pink_plate.txt @@ -0,0 +1 @@ +With the left hand, place the brown toy on the pink plate. \ No newline at end of file diff --git a/egomimic/robot/rollout.py b/egomimic/robot/rollout.py index a7cbf97e7..ea3b5157b 100644 --- a/egomimic/robot/rollout.py +++ b/egomimic/robot/rollout.py @@ -10,15 +10,13 @@ import h5py import numpy as np import torch -from torch.utils.data import default_collate from robot_utils import RateLoop from scipy.spatial.transform import Rotation as R from egomimic.models.denoising_policy import DenoisingPolicy from egomimic.pl_utils.pl_model import ModelWrapper -from egomimic.pl_utils.pl_data_utils import build_tokenized_collate +from egomimic.pl_utils.pl_data_utils import annotation_collate from egomimic.rldb.embodiment.embodiment import get_embodiment -from egomimic.utils.hydra_utils import find_run_snapshot_path, load_run_snapshot from egomimic.rldb.embodiment.eva import Eva from egomimic.robot.eva.eva_kinematics import EvaMinkKinematicsSolver from egomimic.utils.egomimicUtils import ( @@ -268,15 +266,18 @@ def __init__( self.debug = debug self.transform_list = Eva.get_transform_list(mode="cartesian_wristframe_ypr") self.annotation = None - self._tokenizer = None - self.collate_fn = default_collate + # PI's process_batch_for_training now owns prompt assembly + tokenization + # (see PI._build_prompts / _tokenize_prompts). The collate's only job is + # to stack tensors and preserve the raw per-sample `annotations` list as + # list-of-lists, which is what _build_prompts iterates over. + self.collate_fn = annotation_collate if annotation_path is not None: if not os.path.isfile(annotation_path): print(f"[rollout] WARNING: annotation file not found: {annotation_path} (continuing without annotation)") else: with open(annotation_path, "r") as f: self.annotation = f.read().strip() - self.collate_fn = self._build_collate_from_checkpoint_cfg(self.annotation) + self._apply_annotation_to_algo() LOCAL_WEIGHT_PATH = "/home/robot/robot_ws/egomimic/algo/pi_checkpoints/pi05_base_pytorch" @@ -318,49 +319,33 @@ def _patch_checkpoint_paths(cls, ckpt_path): print(f"[rollout] Patched checkpoint saved to {patched_path}") return patched_path, cfg - def _build_collate_from_checkpoint_cfg(self, default_prompt): - """Build a tokenized collate_fn using the build_tokenized_collate flags - saved in the checkpoint's hydra config (under the top-level ``data:`` - block). This mirrors the pi prompt formatting the model was trained on - (Task / Embodiment / Control mode / State blocks).""" - data_cfg = (self._ckpt_cfg or {}).get("data", {}) or {} - # TODO: remove (debug — verify build_tokenized_collate flags from ckpt) - print( - f"[rollout][debug] ckpt_cfg top-level keys: " - f"{list((self._ckpt_cfg or {}).keys())}" - ) - print( - f"[rollout][debug] data_cfg flags: " - f"proprio={data_cfg.get('proprio')}, " - f"embodiment_label={data_cfg.get('embodiment_label')}, " - f"control_mode={data_cfg.get('control_mode')}, " - f"state_num_bins={data_cfg.get('state_num_bins')}, " - f"model_name={data_cfg.get('model_name')}" - ) - return build_tokenized_collate( - max_length=128, - model_name=data_cfg.get("model_name", "google/paligemma-3b-mix-224"), - sampling_mode="first", - annotation_key="annotations", - default_prompt=default_prompt, - proprio_keys=data_cfg.get("proprio_keys"), - state_num_bins=data_cfg.get("state_num_bins", 256), - proprio=bool(data_cfg.get("proprio", False)), - embodiment_label=bool(data_cfg.get("embodiment_label", False)), - control_mode=data_cfg.get("control_mode"), - ) + def _apply_annotation_to_algo(self): + """Wire the rollout-time annotation into the PI algo. + + The algo was loaded with its trained-in ``annotation_key`` / + ``sampling_mode`` / ``default_prompt``. Override them so the prompt + the user supplies via --annotation-path is what actually gets + tokenized: + - ``annotation_key="annotations"`` matches the key we stuff into + each per-step sample in ``process_obs_for_transform_list``. + - ``sampling_mode="first"`` makes inference deterministic — there's + only ever one annotation per rollout, but if a list ever shows up + we always pick the same element. + - ``default_prompt=self.annotation`` is the fallback path for + edge cases (e.g. the annotations key gets dropped). + """ + model = getattr(self.policy, "model", None) + if model is None: + return + if hasattr(model, "annotation_key"): + model.annotation_key = "annotations" + if hasattr(model, "sampling_mode"): + model.sampling_mode = "first" + if self.annotation is not None and hasattr(model, "default_prompt"): + model.default_prompt = self.annotation def _load_policy(self): patched_path, _ = self._patch_checkpoint_paths(self.policy_path) - # The .ckpt only stores the model subtree (see trainHydra._build_model_config_tree), - # so load the full hydra run-snapshot (with the data: block) from disk. - snapshot_path = find_run_snapshot_path(self.policy_path) - if snapshot_path is None: - print(f"[rollout] WARNING: no .hydra/config.yaml found near {self.policy_path}") - self._ckpt_cfg = None - else: - print(f"[rollout] Loaded hydra config from {snapshot_path}") - self._ckpt_cfg = load_run_snapshot(self.policy_path) policy = ModelWrapper.load_from_checkpoint( patched_path, weights_only=False, map_location="cpu" ) @@ -418,7 +403,6 @@ def rollout_step(self, i, obs): for transform in self.transform_list: transform_list_batch = transform.transform(transform_list_batch) transform_list_batch = self.collate_fn([transform_list_batch]) - print(f"[rollout][debug] sampled_prompt: {transform_list_batch.get('sampled_prompt')}") # TODO: remove if self.arm == "both": embodiment_name = "eva_bimanual" elif self.arm == "right": @@ -547,9 +531,8 @@ def process_obs_for_transform_list(self, obs): left_cmd_ee_pose = torch.from_numpy(left_xyzwxyz).view(1, 7).repeat(45, 1) data["left.cmd_ee_pose"] = left_cmd_ee_pose - # `embodiment` is consumed by build_tokenized_collate's _embodiment_name, - # which calls int() on it to look up the embodiment name; it must be - # the integer id, not the string. The string lives in metadata.robot_name. + # `embodiment` must be the integer id (the string lives in + # metadata.robot_name). Downstream lookups call int() on it. if self.arm == "both": data["embodiment"] = self.embodiment_id data["metadata.robot_name"] = "eva_bimanual" @@ -566,12 +549,12 @@ def process_obs_for_transform_list(self, obs): return data def load_annotation(self, annotation_path): - """Load a new annotation file, building the tokenized collate only if needed. + """Load a new annotation file. - The annotation text flows through data["annotations"] at each inference - step, so updating self.annotation is sufficient when the tokenized - collate already exists. We only build it when the collate is still the - plain default_collate (i.e. no annotation was provided at init time). + The annotation text flows through ``data["annotations"]`` at each + inference step and is consumed by ``PI.process_batch_for_training``, + so updating ``self.annotation`` is sufficient. We also re-apply it to + the algo so ``default_prompt`` stays in sync as a fallback. Returns True on success, False if the file could not be loaded. """ @@ -580,8 +563,7 @@ def load_annotation(self, annotation_path): return False with open(annotation_path, "r") as f: self.annotation = f.read().strip() - if self.collate_fn is default_collate: - self.collate_fn = self._build_collate_from_checkpoint_cfg(self.annotation) + self._apply_annotation_to_algo() print(f"[rollout] Loaded new annotation from {annotation_path}: '{self.annotation}'") return True