Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ datasets/
**/datasets/
apikey.txt
slurm-*.out
*.out
slurmoutputs/
*.log
.inductor_cache/
Expand Down
91 changes: 81 additions & 10 deletions egomimic/eval/eval_latent.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@

from egomimic.eval.eval_video import EvalVideo
from egomimic.rldb.embodiment.embodiment import get_embodiment
from egomimic.utils.action_utils import (
PI05_CARTESIAN_ACTION_ENCODING_LEGACY,
PI05_CARTESIAN_ACTION_ENCODING_NORM_ROT_6D,
PI05_CARTESIAN_ACTION_ENCODING_RAW_ROT_6D,
)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -98,6 +103,7 @@ def __init__(
self._layer_keys = {} # layer_name -> list[np.ndarray (B, S, D)]
self._row_hashes = [] # one entry per sample (replicated by S at write time)
self._row_embodiments = []
self._row_frames = [] # source frame index per sample (None-safe)
self._hook_handles = []
# per-step buffer; each layer is recorded only on its FIRST forward
# call within a capture window (prefix pass for PaliGemma, first
Expand Down Expand Up @@ -205,6 +211,22 @@ def _extract_hashes(_batch, batch_size, embodiment_name):
return [str(v) for v in val.cpu().tolist()]
return [str(val)] * batch_size

@staticmethod
def _extract_frame_indices(_batch, batch_size):
"""Per-sample SOURCE frame index, surfaced by ZarrDataset.__getitem__.
Returns None when the batch predates this field, so the caller can
fall back to the legacy per-run counter."""
val = _batch.get("frame_idx")
if val is None:
return None
if torch.is_tensor(val):
return [int(v) for v in val.cpu().tolist()]
if isinstance(val, np.ndarray):
return [int(v) for v in val.tolist()]
if isinstance(val, (list, tuple)):
return [int(v) for v in val]
return [int(val)] * batch_size

# ------------------------------------------------------------------
# Validation lifecycle
# ------------------------------------------------------------------
Expand All @@ -213,6 +235,7 @@ def on_validation_start(self):
self._layer_keys = {}
self._row_hashes = []
self._row_embodiments = []
self._row_frames = []
self._n_rows = 0
self._register_hooks()
if self.trainer.is_global_zero:
Expand Down Expand Up @@ -258,16 +281,46 @@ def compute_metrics_and_viz(self, batch):
)
self._capture_active = False

# Mirror PI's post-processing for metrics + viz.
# sample_actions returns a static CUDA-graph buffer that the
# next embodiment's forward overwrites — clone before use
# (mirrors PI.forward_eval).
pred_actions = pred_actions.clone()

# Mirror PI.forward_eval post-processing for metrics + viz,
# branching on the action encoding so the 6D-rotation models
# unpack the 20-D xyz+6D(+g) action (not the legacy 14-D ypr).
ref = _batch[ac_key]
B, T, D = ref.shape
converter = algo.action_registry.get(embodiment_id, ac_key)
pred_actions_orig = converter.from32(pred_actions)
pred = pred_actions_orig[:, :T, :D]
action_encoding = getattr(
algo, "action_encoding", PI05_CARTESIAN_ACTION_ENCODING_LEGACY
)

predictions = OrderedDict()
predictions[ac_key] = pred
unnorm_actions = algo.norm_stats.unnormalize(predictions, embodiment_id)
if action_encoding == PI05_CARTESIAN_ACTION_ENCODING_RAW_ROT_6D:
pred_actions_orig = converter.from32_raw_rotation(
pred_actions,
stats=algo._action_stats(embodiment_id, ac_key),
norm_mode=algo.norm_stats.norm_mode,
unnormalize_non_rotation=True,
)
unnorm_actions = {ac_key: pred_actions_orig[:, :T, :D]}
elif action_encoding == PI05_CARTESIAN_ACTION_ENCODING_NORM_ROT_6D:
pred_6d = converter.from32_norm_6d(pred_actions)
predictions[ac_key] = pred_6d[:, :T, :D]
unnorm_actions = algo.norm_stats.unnormalize(
predictions, embodiment_id
)
elif action_encoding == PI05_CARTESIAN_ACTION_ENCODING_LEGACY:
pred_actions_orig = converter.from32(pred_actions)
predictions[ac_key] = pred_actions_orig[:, :T, :D]
unnorm_actions = algo.norm_stats.unnormalize(
predictions, embodiment_id
)
else:
raise ValueError(
f"Unsupported PI0.5 action_encoding: {action_encoding!r}"
)
for k in unnorm_actions:
unnorm_preds[f"{embodiment_name}_{k}"] = unnorm_actions[k]

Expand All @@ -290,6 +343,10 @@ def compute_metrics_and_viz(self, batch):
hashes = self._extract_hashes(_batch, B, embodiment_name)
self._row_hashes.extend(hashes)
self._row_embodiments.extend([embodiment_name] * B)
frames = self._extract_frame_indices(_batch, B)
# -1 sentinel marks "no source index in batch"; if ANY sample
# is sentinel the writer falls back to the per-run counter.
self._row_frames.extend(frames if frames is not None else [-1] * B)
for layer_name, key_tensor in self._step_capture.items():
keys_bsd = key_tensor.to(torch.float32).cpu().numpy()
self._layer_keys.setdefault(layer_name, []).append(keys_bsd)
Expand Down Expand Up @@ -330,6 +387,7 @@ def on_validation_end(self):
keys = keys_bsd.reshape(N * S, D)
sample_hashes = self._row_hashes
sample_embs = self._row_embodiments
sample_frames = self._row_frames
if N != len(sample_hashes):
n = min(N, len(sample_hashes))
logger.warning(
Expand All @@ -342,14 +400,27 @@ def on_validation_end(self):
keys = keys_bsd[:n].reshape(n * S, D)
sample_hashes = sample_hashes[:n]
sample_embs = sample_embs[:n]
sample_frames = sample_frames[:n]
# Replicate per-sample metadata across the S tokens.
# frame_idx is the per-run sample index (0..N-1) — tokens of
# the same frame share it, so meanpool.py can group on
# (video_hash, frame_idx). token_idx is the position within
# the sequence, useful for token-type slicing later.
# frame_idx is the SOURCE frame index per sample (from
# ZarrDataset.__getitem__) so tokens of the same frame share it
# and the inspector can fetch the right image / annotation.
# Falls back to a per-run counter (0..N-1) only for legacy
# datasets that don't surface frame_idx (sentinel -1 present).
hashes = [h for h in sample_hashes for _ in range(S)]
embs = [e for e in sample_embs for _ in range(S)]
frame_idx = [i for i in range(len(sample_hashes)) for _ in range(S)]
use_source_frames = len(sample_frames) == len(sample_hashes) and all(
fi >= 0 for fi in sample_frames
)
if use_source_frames:
frame_idx = [fi for fi in sample_frames for _ in range(S)]
else:
logger.warning(
"%s: source frame_idx unavailable for some/all samples; "
"falling back to per-run sample counter for frame_idx.",
layer_name,
)
frame_idx = [i for i in range(len(sample_hashes)) for _ in range(S)]
token_idx = list(range(S)) * len(sample_hashes)

# Optional PCA: fit on raw `keys`, log explained variance, and
Expand Down
17 changes: 16 additions & 1 deletion egomimic/eval/latent_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,14 @@
"_shuffle_random",
"_shuffle_pairs",
"_shuffle_custom",
# Tokenizer/annotation knobs declared at the data root for OmegaConf
# interpolation. Tokenization now lives on the algo side, so
# MultiDataModuleWrapper no longer accepts these as constructor args.
"use_tokenizer",
"model_name",
"sampling_mode",
"annotation_key",
"default_prompt",
}
)

Expand All @@ -72,6 +80,7 @@ def build_dataset(
embodiment: str,
resolver,
hashes: list[str] | None = None,
exclude_hashes: list[str] | None = None,
frames_per_episode: int | None = 128,
stride: int | None = None,
valid_ratio: float = 0.05,
Expand All @@ -85,11 +94,17 @@ def build_dataset(
Unknown `mode` values raise. Extra kwargs are NOT accepted (no silent
swallowing) — typos in yaml fail loudly.
"""
excl = tuple(str(h) for h in (exclude_hashes or []))
if mode == "random":
lam = (
"lambda row: row['task'] == "
f"{task!r} and row['robot_name'] == {embodiment!r}"
f"{task!r} and row.get('embodiment') == {embodiment!r}"
)
# Drop known-corrupt episodes (e.g. zero-norm quaternions that crash the
# quat->ypr transform; the per-sample retry can't escape a fully-bad
# episode since its fallback pool is that episode's own frames).
if excl:
lam += f" and row['episode_hash'] not in {excl!r}"
filters = DatasetFilter(filter_lambdas=[lam])
logger.info("[build_dataset] %s | random mode | filter=%s", embodiment, lam)
return MultiDataset._from_resolver(
Expand Down
57 changes: 43 additions & 14 deletions egomimic/rldb/embodiment/human.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,26 @@ class Human(Embodiment):
# MANO 21-keypoint topology: 0=wrist, 1-4 thumb, 5-8 index, 9-12 middle, 13-16 ring, 17-20 pinky.
# Subclasses with non-MANO conventions (e.g. Aria) override these.
FINGER_EDGES = [
(0, 1), (1, 2), (2, 3), (3, 4), # thumb
(0, 5), (5, 6), (6, 7), (7, 8), # index
(0, 9), (9, 10), (10, 11), (11, 12), # middle
(0, 13), (13, 14), (14, 15), (15, 16), # ring
(0, 17), (17, 18), (18, 19), (19, 20), # pinky
(0, 1),
(1, 2),
(2, 3),
(3, 4), # thumb
(0, 5),
(5, 6),
(6, 7),
(7, 8), # index
(0, 9),
(9, 10),
(10, 11),
(11, 12), # middle
(0, 13),
(13, 14),
(14, 15),
(15, 16), # ring
(0, 17),
(17, 18),
(18, 19),
(19, 20), # pinky
]
FINGER_COLORS = {
"thumb": (255, 100, 100),
Expand Down Expand Up @@ -137,7 +152,7 @@ def get_transform_list(
stride=cls.ACTION_STRIDE
) + [CartesianYPRToRot6D(action_key="actions_cartesian")]
elif mode == "keypoints_headframe_ypr":
return _build_aria_keypoints_bimanual_transform_list(
return _build_human_keypoints_bimanual_transform_list(
stride=cls.ACTION_STRIDE, is_quat=False
)
if mode == "keypoints_headframe_quat":
Expand All @@ -160,11 +175,25 @@ class Aria(Human):
ACTION_STRIDE = 3
# Aria's 21-keypoint layout is NOT MANO: 0-4 are fingertips, 5 is the palm root.
FINGER_EDGES = [
(5, 6), (6, 7), (7, 0), # thumb
(5, 8), (8, 9), (9, 10), (10, 1), # index
(5, 11), (11, 12), (12, 13), (13, 2), # middle
(5, 14), (14, 15), (15, 16), (16, 3), # ring
(5, 17), (17, 18), (18, 19), (19, 4), # pinky
(5, 6),
(6, 7),
(7, 0), # thumb
(5, 8),
(8, 9),
(9, 10),
(10, 1), # index
(5, 11),
(11, 12),
(12, 13),
(13, 2), # middle
(5, 14),
(14, 15),
(15, 16),
(16, 3), # ring
(5, 17),
(17, 18),
(18, 19),
(19, 4), # pinky
]
FINGER_EDGE_RANGES = [
("thumb", 0, 3),
Expand Down Expand Up @@ -937,7 +966,7 @@ def _build_human_keypoints_bimanual_transform_list(
return transform_list


def _build_human_cartesian_revert_eef_frame_transform_list(
def _build_aria_cartesian_revert_eef_frame_transform_list(
*,
action_key: str = "actions_cartesian",
obs_key: str = "observations.state.ee_pose",
Expand All @@ -951,7 +980,7 @@ def _build_human_cartesian_revert_eef_frame_transform_list(
) -> list[Transform]:
"""Revert wrist-frame ARIA cartesian actions back to head (camera) frame.

Inverse of ``_build_human_cartesian_eef_frame_transform_list`` for viz: the
Inverse of ``_build_aria_cartesian_eef_frame_transform_list`` for viz: the
action chunks live in each side's wrist frame, the proprio ee-poses live in
headframe (= Aria camera frame). Re-composes ``target_headframe @ chunk_wristframe``
so action chunks are back in headframe / camera frame.
Expand Down Expand Up @@ -1146,7 +1175,7 @@ def _build_aria_cartesian_eef_frame_transform_list(
return transform_list


def _build_human_cartesian_bimanual_transform_list(
def _build_aria_cartesian_bimanual_transform_list(
*,
target_world: str = "obs_head_pose",
target_world_ypr: str = "obs_head_pose_ypr",
Expand Down
5 changes: 5 additions & 0 deletions egomimic/rldb/zarr/zarr_dataset_multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -1730,6 +1730,11 @@ def _next(reason: str, key: str = "") -> int:
data["episode_hash"] = (
ep_name[:-5] if ep_name.endswith(".zarr") else ep_name
)
# The actual source frame served (after any JPEG-decode fallback,
# `idx` is the frame finally read). Surfaced parallel to
# episode_hash so latent capture can label rows with the true
# frame index instead of a per-run counter.
data["frame_idx"] = int(idx)
_ = origin # preserved for symmetry with prior API
return data

Expand Down
Loading