Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions egomimic/algo/pi.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,17 @@ def process_batch_for_training(self, batch):
# left in `_batch` by `annotation_collate`, plus per-sample proprio
# tensors from `_batch` for the optional State block.
prompts = self._build_prompts(_batch, embodiment_name, B)
if (
os.environ.get("PI_DEBUG_PROMPTS")
and getattr(self, "_debug_prompt_count", 0) < 3
):
logger.info(
"[PI_DEBUG_PROMPTS %d] %s prompt[0]: %r",
getattr(self, "_debug_prompt_count", 0),
embodiment_name,
prompts[0],
)
self._debug_prompt_count = getattr(self, "_debug_prompt_count", 0) + 1
processed_batch[embodiment_id]["sampled_prompt"] = prompts
processed_batch[embodiment_id].update(self._tokenize_prompts(prompts))
processed_batch[embodiment_id]["pad_mask"] = torch.ones(
Expand Down
41 changes: 41 additions & 0 deletions egomimic/eval/eval_pi.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,25 @@ def compute_metrics_and_viz(self, batch):
total_loss = None
n_loss_embodiments = 0

# Bimanual 12-D layout from `HumanBimanualCartesianEuler.from32`:
# [L_xyz(3), L_ypr(3), R_xyz(3), R_ypr(3)]. Split lets us tell a
# translation problem apart from a rotation-reconstruction artifact
# (6D-cols → matrix → YPR can blow up near gimbal lock / ±π wrap).
def _split_mse(pred_t, gt_t):
if pred_t.shape[-1] != 12:
return None, None
xyz_idx = [0, 1, 2, 6, 7, 8]
ypr_idx = [3, 4, 5, 9, 10, 11]
xyz = MeanSquaredError()(
pred_t[..., xyz_idx].cpu().contiguous(),
gt_t[..., xyz_idx].cpu().contiguous(),
)
ypr = MeanSquaredError()(
pred_t[..., ypr_idx].cpu().contiguous(),
gt_t[..., ypr_idx].cpu().contiguous(),
)
return xyz, ypr

for embodiment_id, _batch in batch.items():
_batch = algo.norm_stats.unnormalize(_batch, embodiment_id)
embodiment_name = get_embodiment(embodiment_id).lower()
Expand All @@ -49,6 +68,16 @@ def compute_metrics_and_viz(self, batch):
metrics[f"Valid/{pred_key}_final_mse_avg"] = mse(
preds[pred_key][:, -1].cpu(), _batch[ac_key][:, -1].cpu()
)
xyz_p, ypr_p = _split_mse(preds[pred_key], _batch[ac_key])
if xyz_p is not None:
metrics[f"Valid/{pred_key}_xyz_paired_mse_avg"] = xyz_p
metrics[f"Valid/{pred_key}_ypr_paired_mse_avg"] = ypr_p
xyz_f, ypr_f = _split_mse(
preds[pred_key][:, -1:], _batch[ac_key][:, -1:]
)
if xyz_f is not None:
metrics[f"Valid/{pred_key}_xyz_final_mse_avg"] = xyz_f
metrics[f"Valid/{pred_key}_ypr_final_mse_avg"] = ypr_f

transform_list = self.transform_lists.get(embodiment_name)
gt_batch_viz = _batch
Expand All @@ -74,6 +103,18 @@ def compute_metrics_and_viz(self, batch):
pred_batch_viz[ac_key][:, -1].cpu().contiguous(),
gt_batch_viz[ac_key][:, -1].cpu().contiguous(),
)
xyz_cp, ypr_cp = _split_mse(
pred_batch_viz[ac_key], gt_batch_viz[ac_key]
)
if xyz_cp is not None:
metrics[f"Valid/{pred_key}_cam_xyz_paired_mse_avg"] = xyz_cp
metrics[f"Valid/{pred_key}_cam_ypr_paired_mse_avg"] = ypr_cp
xyz_cf, ypr_cf = _split_mse(
pred_batch_viz[ac_key][:, -1:], gt_batch_viz[ac_key][:, -1:]
)
if xyz_cf is not None:
metrics[f"Valid/{pred_key}_cam_xyz_final_mse_avg"] = xyz_cf
metrics[f"Valid/{pred_key}_cam_ypr_final_mse_avg"] = ypr_cf

preds_for_viz = dict(preds)
preds_for_viz[pred_key] = pred_batch_viz[ac_key]
Expand Down
18 changes: 11 additions & 7 deletions egomimic/hydra_configs/data/mecka_pi.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,25 +8,29 @@ train_datasets:
folder_path: ${paths.dataset_dir}
key_map:
_target_: egomimic.rldb.embodiment.human.Mecka.get_keymap
mode: cartesian_pi
annotation_key: annotations
transform_list:
_target_: egomimic.rldb.embodiment.human.Mecka.get_transform_list
mode: cartesian
filters:
lab: "mecka"
task: "folding_clothes"
_target_: egomimic.rldb.filters.DatasetFilter
filter_lambdas:
- "lambda row: row['lab'] == 'mecka' and row['task'] == 'fold_clothes'"
mode: train

valid_datasets:
mecka_bimanual:
_target_: ${train_datasets.mecka_bimanual._target_}
resolver: ${train_datasets.mecka_bimanual.resolver}
filters: ${train_datasets.mecka_bimanual.filters}
_target_: ${data.train_datasets.mecka_bimanual._target_}
resolver: ${data.train_datasets.mecka_bimanual.resolver}
filters: ${data.train_datasets.mecka_bimanual.filters}
mode: valid

train_dataloader_params:
mecka_bimanual:
batch_size: 32
batch_size: 64
num_workers: 10
valid_dataloader_params:
mecka_bimanual:
batch_size: 32
batch_size: 64
num_workers: 10
56 changes: 0 additions & 56 deletions egomimic/hydra_configs/data/mecka_scale_cotrain_pi.yaml

This file was deleted.

6 changes: 6 additions & 0 deletions egomimic/hydra_configs/evaluator/viz/pi_cartesian.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,10 @@ scale_bimanual:
_partial_: true
image_key: base_0_rgb
action_key: actions_cartesian
mode: traj
mecka_bimanual:
_target_: egomimic.rldb.embodiment.human.Mecka.viz_gt_preds
_partial_: true
image_key: base_0_rgb
action_key: actions_cartesian
mode: traj
2 changes: 2 additions & 0 deletions egomimic/hydra_configs/evaluator/viz/pi_cartesian_lang.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,5 @@ aria_bimanual:
annotation_key: sampled_prompt
scale_bimanual:
annotation_key: sampled_prompt
mecka_bimanual:
annotation_key: sampled_prompt
10 changes: 5 additions & 5 deletions egomimic/hydra_configs/model/pi0.5_base.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ robomimic_model:
tokenizer_max_length: 128
sampling_mode: "random"
annotation_key: "annotations"
default_prompt: "This is a bad action."
default_prompt: ""
proprio_in_prompt: true
embodiment_label: true
state_num_bins: 256
Expand All @@ -58,15 +58,15 @@ optimizer:
_target_: torch.optim.AdamW
_partial_: true
lr: 3e-5
betas: [0.9, 0.999]
betas: [0.9, 0.95]
eps: 1e-8
weight_decay: 0.0
weight_decay: 1e-10

scheduler:
_target_: transformers.get_cosine_schedule_with_warmup
_partial_: true
num_warmup_steps: 2000
num_warmup_steps: 1000
num_training_steps: 60000
num_cycles: 0.5

scheduler_interval: step
scheduler_interval: step
2 changes: 1 addition & 1 deletion egomimic/hydra_configs/paths/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
# data_dir: ${paths.root_dir}/data/

# default folder containing zarr episode datasets
dataset_dir: /coc/flash7/scratch/egoverseS3ZarrDataset
dataset_dir: /storage/project/r-dxu345-0/shared/egoverseS3ZarrDatasets

# path to logging directory
log_dir: ${paths.root_dir}/logs/
Expand Down
5 changes: 3 additions & 2 deletions egomimic/hydra_configs/train_zarr_cartesian.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,9 @@ seed: 42
# Normalization and norm-stat cache settings (override per-run as needed).
norm_stats:
norm_mode: quantile
sample_frac: 1.0
num_workers: 4
sample_frac: 0.2
num_workers: 6
save_cache_dir: ${hydra:runtime.output_dir}
precomputed_norm_path: null
# precomputed_norm_path: /storage/project/r-dxu345-0/rco3/EgoVerse/logs/norm_stats
reject_outliers: true
47 changes: 40 additions & 7 deletions egomimic/rldb/embodiment/human.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,13 +325,38 @@ class Mecka(Human):
VIZ_INTRINSICS_KEY = "mecka"
ACTION_STRIDE = 1

@classmethod
def get_transform_list(
cls,
mode: Literal["cartesian",] = "cartesian",
chunk_length: int = 100,
) -> list[Transform]:
if mode == "cartesian":
return _build_aria_cartesian_bimanual_transform_list(
stride=cls.ACTION_STRIDE,
chunk_length=chunk_length,
)

@classmethod
def get_keymap(
cls, mode: Literal["cartesian", "keypoints"], annotations: bool = False
cls,
mode: Literal["cartesian", "cartesian_pi", "keypoints"] | None = None,
keymap_mode: Literal["cartesian", "cartesian_pi", "keypoints"] | None = None,
annotations: bool = False,
norm_mode: bool = False,
annotation_key: str | None = None,
):
if mode == "cartesian":
if mode is None:
mode = keymap_mode or "cartesian"
elif keymap_mode is not None and keymap_mode != mode:
raise ValueError(
f"Conflicting Mecka keymap modes: mode='{mode}', keymap_mode='{keymap_mode}'."
)

if mode in ("cartesian", "cartesian_pi"):
front_key = "base_0_rgb" if mode == "cartesian_pi" else cls.VIZ_IMAGE_KEY
key_map = {
cls.VIZ_IMAGE_KEY: {
front_key: {
"key_type": "camera_keys",
"zarr_key": "images.front_1",
},
Expand Down Expand Up @@ -407,13 +432,21 @@ def get_keymap(
}
else:
raise ValueError(
f"Unsupported mode '{mode}'. Expected one of: 'cartesian', 'keypoints'."
f"Unsupported mode '{mode}'. Expected one of: 'cartesian', 'cartesian_pi', 'keypoints'."
)
if annotations:
key_map["annotations"] = {

requested_annotation_key = annotation_key or (
"annotations" if annotations else None
)
if requested_annotation_key is not None:
key_map[requested_annotation_key] = {
"key_type": "annotation_keys",
"zarr_key": "annotations",
"zarr_key": requested_annotation_key,
}
if norm_mode:
for key in list(key_map):
if key_map[key].get("key_type") in ("camera_keys", "annotation_keys"):
del key_map[key]
return key_map


Expand Down
Loading
Loading