diff --git a/egomimic/algo/pi.py b/egomimic/algo/pi.py index abc7ede62..626f75057 100644 --- a/egomimic/algo/pi.py +++ b/egomimic/algo/pi.py @@ -332,6 +332,17 @@ def process_batch_for_training(self, batch): # left in `_batch` by `annotation_collate`, plus per-sample proprio # tensors from `_batch` for the optional State block. prompts = self._build_prompts(_batch, embodiment_name, B) + if ( + os.environ.get("PI_DEBUG_PROMPTS") + and getattr(self, "_debug_prompt_count", 0) < 3 + ): + logger.info( + "[PI_DEBUG_PROMPTS %d] %s prompt[0]: %r", + getattr(self, "_debug_prompt_count", 0), + embodiment_name, + prompts[0], + ) + self._debug_prompt_count = getattr(self, "_debug_prompt_count", 0) + 1 processed_batch[embodiment_id]["sampled_prompt"] = prompts processed_batch[embodiment_id].update(self._tokenize_prompts(prompts)) processed_batch[embodiment_id]["pad_mask"] = torch.ones( diff --git a/egomimic/eval/eval_pi.py b/egomimic/eval/eval_pi.py index 720619770..5e6b90faa 100644 --- a/egomimic/eval/eval_pi.py +++ b/egomimic/eval/eval_pi.py @@ -27,6 +27,25 @@ def compute_metrics_and_viz(self, batch): total_loss = None n_loss_embodiments = 0 + # Bimanual 12-D layout from `HumanBimanualCartesianEuler.from32`: + # [L_xyz(3), L_ypr(3), R_xyz(3), R_ypr(3)]. Split lets us tell a + # translation problem apart from a rotation-reconstruction artifact + # (6D-cols → matrix → YPR can blow up near gimbal lock / ±π wrap). + def _split_mse(pred_t, gt_t): + if pred_t.shape[-1] != 12: + return None, None + xyz_idx = [0, 1, 2, 6, 7, 8] + ypr_idx = [3, 4, 5, 9, 10, 11] + xyz = MeanSquaredError()( + pred_t[..., xyz_idx].cpu().contiguous(), + gt_t[..., xyz_idx].cpu().contiguous(), + ) + ypr = MeanSquaredError()( + pred_t[..., ypr_idx].cpu().contiguous(), + gt_t[..., ypr_idx].cpu().contiguous(), + ) + return xyz, ypr + for embodiment_id, _batch in batch.items(): _batch = algo.norm_stats.unnormalize(_batch, embodiment_id) embodiment_name = get_embodiment(embodiment_id).lower() @@ -49,6 +68,16 @@ def compute_metrics_and_viz(self, batch): metrics[f"Valid/{pred_key}_final_mse_avg"] = mse( preds[pred_key][:, -1].cpu(), _batch[ac_key][:, -1].cpu() ) + xyz_p, ypr_p = _split_mse(preds[pred_key], _batch[ac_key]) + if xyz_p is not None: + metrics[f"Valid/{pred_key}_xyz_paired_mse_avg"] = xyz_p + metrics[f"Valid/{pred_key}_ypr_paired_mse_avg"] = ypr_p + xyz_f, ypr_f = _split_mse( + preds[pred_key][:, -1:], _batch[ac_key][:, -1:] + ) + if xyz_f is not None: + metrics[f"Valid/{pred_key}_xyz_final_mse_avg"] = xyz_f + metrics[f"Valid/{pred_key}_ypr_final_mse_avg"] = ypr_f transform_list = self.transform_lists.get(embodiment_name) gt_batch_viz = _batch @@ -74,6 +103,18 @@ def compute_metrics_and_viz(self, batch): pred_batch_viz[ac_key][:, -1].cpu().contiguous(), gt_batch_viz[ac_key][:, -1].cpu().contiguous(), ) + xyz_cp, ypr_cp = _split_mse( + pred_batch_viz[ac_key], gt_batch_viz[ac_key] + ) + if xyz_cp is not None: + metrics[f"Valid/{pred_key}_cam_xyz_paired_mse_avg"] = xyz_cp + metrics[f"Valid/{pred_key}_cam_ypr_paired_mse_avg"] = ypr_cp + xyz_cf, ypr_cf = _split_mse( + pred_batch_viz[ac_key][:, -1:], gt_batch_viz[ac_key][:, -1:] + ) + if xyz_cf is not None: + metrics[f"Valid/{pred_key}_cam_xyz_final_mse_avg"] = xyz_cf + metrics[f"Valid/{pred_key}_cam_ypr_final_mse_avg"] = ypr_cf preds_for_viz = dict(preds) preds_for_viz[pred_key] = pred_batch_viz[ac_key] diff --git a/egomimic/hydra_configs/data/mecka_pi.yaml b/egomimic/hydra_configs/data/mecka_pi.yaml index 8c7bfec04..a01c9adea 100644 --- a/egomimic/hydra_configs/data/mecka_pi.yaml +++ b/egomimic/hydra_configs/data/mecka_pi.yaml @@ -8,25 +8,29 @@ train_datasets: folder_path: ${paths.dataset_dir} key_map: _target_: egomimic.rldb.embodiment.human.Mecka.get_keymap + mode: cartesian_pi + annotation_key: annotations transform_list: _target_: egomimic.rldb.embodiment.human.Mecka.get_transform_list + mode: cartesian filters: - lab: "mecka" - task: "folding_clothes" + _target_: egomimic.rldb.filters.DatasetFilter + filter_lambdas: + - "lambda row: row['lab'] == 'mecka' and row['task'] == 'fold_clothes'" mode: train valid_datasets: mecka_bimanual: - _target_: ${train_datasets.mecka_bimanual._target_} - resolver: ${train_datasets.mecka_bimanual.resolver} - filters: ${train_datasets.mecka_bimanual.filters} + _target_: ${data.train_datasets.mecka_bimanual._target_} + resolver: ${data.train_datasets.mecka_bimanual.resolver} + filters: ${data.train_datasets.mecka_bimanual.filters} mode: valid train_dataloader_params: mecka_bimanual: - batch_size: 32 + batch_size: 64 num_workers: 10 valid_dataloader_params: mecka_bimanual: - batch_size: 32 + batch_size: 64 num_workers: 10 diff --git a/egomimic/hydra_configs/data/mecka_scale_cotrain_pi.yaml b/egomimic/hydra_configs/data/mecka_scale_cotrain_pi.yaml deleted file mode 100644 index 616165a7c..000000000 --- a/egomimic/hydra_configs/data/mecka_scale_cotrain_pi.yaml +++ /dev/null @@ -1,56 +0,0 @@ -_target_: egomimic.pl_utils.pl_data_utils.MultiDataModuleWrapper - -train_datasets: - mecka_bimanual: - _target_: egomimic.rldb.zarr.zarr_dataset_multi.MultiDataset._from_resolver - resolver: - _target_: egomimic.rldb.zarr.zarr_dataset_multi.S3EpisodeResolver - folder_path: ${paths.dataset_dir} - key_map: - _target_: egomimic.rldb.embodiment.human.Mecka.get_keymap - transform_list: - _target_: egomimic.rldb.embodiment.human.Mecka.get_transform_list - filters: - lab: "mecka" - task: "fold_clothes" - mode: train - scale_bimanual: - _target_: egomimic.rldb.zarr.zarr_dataset_multi.MultiDataset._from_resolver - resolver: - _target_: egomimic.rldb.zarr.zarr_dataset_multi.S3EpisodeResolver - folder_path: ${paths.dataset_dir} - key_map: - _target_: egomimic.rldb.embodiment.human.Scale.get_keymap - transform_list: - _target_: egomimic.rldb.embodiment.human.Scale.get_transform_list - filters: - lab: "scale" - task: "flagship_fold_clothes" - mode: train - -valid_datasets: - mecka_bimanual: - _target_: ${train_datasets.mecka_bimanual._target_} - resolver: ${train_datasets.mecka_bimanual.resolver} - filters: ${train_datasets.mecka_bimanual.filters} - mode: valid - scale_bimanual: - _target_: ${train_datasets.scale_bimanual._target_} - resolver: ${train_datasets.scale_bimanual.resolver} - filters: ${train_datasets.scale_bimanual.filters} - mode: valid - -train_dataloader_params: - mecka_bimanual: - batch_size: 32 - num_workers: 10 - scale_bimanual: - batch_size: 32 - num_workers: 10 -valid_dataloader_params: - mecka_bimanual: - batch_size: 32 - num_workers: 10 - scale_bimanual: - batch_size: 32 - num_workers: 10 diff --git a/egomimic/hydra_configs/evaluator/viz/pi_cartesian.yaml b/egomimic/hydra_configs/evaluator/viz/pi_cartesian.yaml index 0b9e9267f..4b2d5a72f 100644 --- a/egomimic/hydra_configs/evaluator/viz/pi_cartesian.yaml +++ b/egomimic/hydra_configs/evaluator/viz/pi_cartesian.yaml @@ -15,4 +15,10 @@ scale_bimanual: _partial_: true image_key: base_0_rgb action_key: actions_cartesian + mode: traj +mecka_bimanual: + _target_: egomimic.rldb.embodiment.human.Mecka.viz_gt_preds + _partial_: true + image_key: base_0_rgb + action_key: actions_cartesian mode: traj \ No newline at end of file diff --git a/egomimic/hydra_configs/evaluator/viz/pi_cartesian_lang.yaml b/egomimic/hydra_configs/evaluator/viz/pi_cartesian_lang.yaml index 3b7859dc6..911ecc3db 100644 --- a/egomimic/hydra_configs/evaluator/viz/pi_cartesian_lang.yaml +++ b/egomimic/hydra_configs/evaluator/viz/pi_cartesian_lang.yaml @@ -8,3 +8,5 @@ aria_bimanual: annotation_key: sampled_prompt scale_bimanual: annotation_key: sampled_prompt +mecka_bimanual: + annotation_key: sampled_prompt diff --git a/egomimic/hydra_configs/model/pi0.5_base.yaml b/egomimic/hydra_configs/model/pi0.5_base.yaml index 1711ecc75..912cd9313 100644 --- a/egomimic/hydra_configs/model/pi0.5_base.yaml +++ b/egomimic/hydra_configs/model/pi0.5_base.yaml @@ -33,7 +33,7 @@ robomimic_model: tokenizer_max_length: 128 sampling_mode: "random" annotation_key: "annotations" - default_prompt: "This is a bad action." + default_prompt: "" proprio_in_prompt: true embodiment_label: true state_num_bins: 256 @@ -58,15 +58,15 @@ optimizer: _target_: torch.optim.AdamW _partial_: true lr: 3e-5 - betas: [0.9, 0.999] + betas: [0.9, 0.95] eps: 1e-8 - weight_decay: 0.0 + weight_decay: 1e-10 scheduler: _target_: transformers.get_cosine_schedule_with_warmup _partial_: true - num_warmup_steps: 2000 + num_warmup_steps: 1000 num_training_steps: 60000 num_cycles: 0.5 -scheduler_interval: step \ No newline at end of file +scheduler_interval: step diff --git a/egomimic/hydra_configs/paths/default.yaml b/egomimic/hydra_configs/paths/default.yaml index 2ba5595fc..c3f2fdd6b 100644 --- a/egomimic/hydra_configs/paths/default.yaml +++ b/egomimic/hydra_configs/paths/default.yaml @@ -7,7 +7,7 @@ # data_dir: ${paths.root_dir}/data/ # default folder containing zarr episode datasets -dataset_dir: /coc/flash7/scratch/egoverseS3ZarrDataset +dataset_dir: /storage/project/r-dxu345-0/shared/egoverseS3ZarrDatasets # path to logging directory log_dir: ${paths.root_dir}/logs/ diff --git a/egomimic/hydra_configs/train_zarr_cartesian.yaml b/egomimic/hydra_configs/train_zarr_cartesian.yaml index 31fe7bd20..1c85c859f 100644 --- a/egomimic/hydra_configs/train_zarr_cartesian.yaml +++ b/egomimic/hydra_configs/train_zarr_cartesian.yaml @@ -31,8 +31,9 @@ seed: 42 # Normalization and norm-stat cache settings (override per-run as needed). norm_stats: norm_mode: quantile - sample_frac: 1.0 - num_workers: 4 + sample_frac: 0.2 + num_workers: 6 save_cache_dir: ${hydra:runtime.output_dir} precomputed_norm_path: null + # precomputed_norm_path: /storage/project/r-dxu345-0/rco3/EgoVerse/logs/norm_stats reject_outliers: true diff --git a/egomimic/rldb/embodiment/human.py b/egomimic/rldb/embodiment/human.py index a8f6f1ed1..e97de31e2 100644 --- a/egomimic/rldb/embodiment/human.py +++ b/egomimic/rldb/embodiment/human.py @@ -325,13 +325,38 @@ class Mecka(Human): VIZ_INTRINSICS_KEY = "mecka" ACTION_STRIDE = 1 + @classmethod + def get_transform_list( + cls, + mode: Literal["cartesian",] = "cartesian", + chunk_length: int = 100, + ) -> list[Transform]: + if mode == "cartesian": + return _build_aria_cartesian_bimanual_transform_list( + stride=cls.ACTION_STRIDE, + chunk_length=chunk_length, + ) + @classmethod def get_keymap( - cls, mode: Literal["cartesian", "keypoints"], annotations: bool = False + cls, + mode: Literal["cartesian", "cartesian_pi", "keypoints"] | None = None, + keymap_mode: Literal["cartesian", "cartesian_pi", "keypoints"] | None = None, + annotations: bool = False, + norm_mode: bool = False, + annotation_key: str | None = None, ): - if mode == "cartesian": + if mode is None: + mode = keymap_mode or "cartesian" + elif keymap_mode is not None and keymap_mode != mode: + raise ValueError( + f"Conflicting Mecka keymap modes: mode='{mode}', keymap_mode='{keymap_mode}'." + ) + + if mode in ("cartesian", "cartesian_pi"): + front_key = "base_0_rgb" if mode == "cartesian_pi" else cls.VIZ_IMAGE_KEY key_map = { - cls.VIZ_IMAGE_KEY: { + front_key: { "key_type": "camera_keys", "zarr_key": "images.front_1", }, @@ -407,13 +432,21 @@ def get_keymap( } else: raise ValueError( - f"Unsupported mode '{mode}'. Expected one of: 'cartesian', 'keypoints'." + f"Unsupported mode '{mode}'. Expected one of: 'cartesian', 'cartesian_pi', 'keypoints'." ) - if annotations: - key_map["annotations"] = { + + requested_annotation_key = annotation_key or ( + "annotations" if annotations else None + ) + if requested_annotation_key is not None: + key_map[requested_annotation_key] = { "key_type": "annotation_keys", - "zarr_key": "annotations", + "zarr_key": requested_annotation_key, } + if norm_mode: + for key in list(key_map): + if key_map[key].get("key_type") in ("camera_keys", "annotation_keys"): + del key_map[key] return key_map diff --git a/egomimic/rldb/zarr/test_multi_retry.py b/egomimic/rldb/zarr/test_multi_retry.py new file mode 100644 index 000000000..c9001b014 --- /dev/null +++ b/egomimic/rldb/zarr/test_multi_retry.py @@ -0,0 +1,105 @@ +"""Tests for MultiDataset's per-sample retry on bounds/exception failures. + +Verifies that retries draw from the *whole* MultiDataset, not just the +failing leaf, so a single bad episode can't trap the loader inside itself. +""" + +from __future__ import annotations + +import random + +import pytest + +from egomimic.rldb.zarr.zarr_dataset_multi import MultiDataset + + +class _DummyLeaf: + """Minimal stand-in for a ZarrDataset leaf. + + `bad_local_idxs` is the set of local indices that should raise; all + other indices return a trivial sample dict. We do not set `embodiment`, + so MultiDataset._check_bounds short-circuits to None and returns the + sample as-is — the only retry path exercised here is the + `except Exception` branch in __getitem__. + """ + + def __init__(self, name: str, length: int, bad_local_idxs: set[int] | None = None): + self.name = name + self._length = length + self._bad = bad_local_idxs or set() + + def __len__(self) -> int: + return self._length + + def __getitem__(self, local_idx: int) -> dict: + if local_idx in self._bad: + raise ValueError(f"intentional failure at {self.name}[{local_idx}]") + return {"leaf": self.name, "local_idx": local_idx} + + +def _make_mds(datasets: dict) -> MultiDataset: + return MultiDataset(datasets=datasets, mode="total") + + +def test_retry_escapes_fully_bad_leaf(): + """A leaf where every local index fails must be escapable via retry.""" + bad_len = 4 + good_len = 6 + bad = _DummyLeaf("bad_ep", bad_len, bad_local_idxs=set(range(bad_len))) + good = _DummyLeaf("good_ep", good_len) + + mds = _make_mds({"bad_ep": bad, "good_ep": good}) + # Sanity: index_map covers both leaves. + assert len(mds) == bad_len + good_len + + # Hit an index that lives in the bad leaf. With per-leaf retry this would + # exhaust attempts inside "bad_ep" and raise RuntimeError. With global + # retry it should land somewhere in "good_ep" instead. + bad_global_idx = mds._global_indices_by_dataset["bad_ep"][0] + + random.seed(0) + sample = mds[bad_global_idx] + assert sample["leaf"] == "good_ep" + + +def test_retry_pool_is_global(monkeypatch): + """The candidate pool passed to random.choice should be the whole + index_map (minus the failing index), not just the failing leaf's + indices.""" + bad = _DummyLeaf("bad_ep", 3, bad_local_idxs={0, 1, 2}) + good = _DummyLeaf("good_ep", 5) + mds = _make_mds({"bad_ep": bad, "good_ep": good}) + + captured: list[list[int]] = [] + + real_choice = random.choice + + def spy_choice(seq): + # Materialize whatever was passed so we can inspect it. + captured.append(list(seq)) + return real_choice(captured[-1]) + + monkeypatch.setattr(random, "choice", spy_choice) + + bad_global_idx = mds._global_indices_by_dataset["bad_ep"][0] + mds[bad_global_idx] + + assert captured, "retry never sampled — bad leaf wasn't exercised" + # First retry's pool should include indices from the good leaf. + good_globals = set(mds._global_indices_by_dataset["good_ep"]) + first_pool = set(captured[0]) + assert first_pool & good_globals, ( + f"first retry pool {first_pool} did not include any good-leaf indices " + f"{good_globals} — retry is still scoped to the failing leaf" + ) + + +def test_retry_exhausts_when_everything_is_bad(): + """If every leaf is fully bad, the bounded while-loop must raise rather + than spin forever.""" + a = _DummyLeaf("a", 2, bad_local_idxs={0, 1}) + b = _DummyLeaf("b", 2, bad_local_idxs={0, 1}) + mds = _make_mds({"a": a, "b": b}) + + with pytest.raises(RuntimeError, match="Entire MultiDataset bad"): + mds[0] diff --git a/egomimic/rldb/zarr/zarr_dataset_multi.py b/egomimic/rldb/zarr/zarr_dataset_multi.py index 069eb3a70..2e24ade88 100644 --- a/egomimic/rldb/zarr/zarr_dataset_multi.py +++ b/egomimic/rldb/zarr/zarr_dataset_multi.py @@ -912,7 +912,6 @@ def __getitem__(self, idx, _attempts: int | None = None): except Exception as e: next_idx, attempts = self._next_after_failure( idx, - dataset_name, attempts, reason=f"Sample failed ({type(e).__name__}: {e}) at " f"{dataset_name}[{local_idx}]", @@ -929,7 +928,6 @@ def __getitem__(self, idx, _attempts: int | None = None): if violation is not None: next_idx, attempts = self._next_after_failure( idx, - dataset_name, attempts, reason=violation, ) @@ -942,17 +940,19 @@ def __getitem__(self, idx, _attempts: int | None = None): return data def _next_after_failure( - self, idx: int, dataset_name: str, attempts: int | None, *, reason: str + self, idx: int, attempts: int | None, *, reason: str ) -> tuple[int, int]: - global_candidates = self._global_indices_by_dataset[dataset_name] + # Resample across the whole MultiDataset, not just the failing leaf — + # otherwise a systematically bad leaf (e.g. an episode where every + # frame's actions_cartesian violates bounds) chews through all its + # own frames before raising, instead of escaping to a healthy leaf. + global_candidates = range(len(self.index_map)) next_idx, attempts = get_fallback_idx( idx=idx, candidates=global_candidates, _attempts=attempts, max_attempts=len(global_candidates), - exhausted_error=( - f"Entire dataset bad (no valid indices): dataset={dataset_name}" - ), + exhausted_error=("Entire MultiDataset bad (no valid indices)"), ) next_dataset_name, next_local_idx = self.index_map[next_idx] logger.warning(