From cfe54d3e6d6c2b588da2133a7f1aea22f94b6a43 Mon Sep 17 00:00:00 2001 From: Ryan Co Date: Mon, 8 Jun 2026 12:44:12 -0400 Subject: [PATCH] short state history memory --- egomimic/algo/pi.py | 94 +++++++- egomimic/algo/test_pi.py | 215 ++++++++---------- egomimic/hydra_configs/model/pi0.5_base.yaml | 12 + egomimic/models/preprocess_pi_obs.py | 13 +- egomimic/rldb/embodiment/embodiment.py | 17 +- egomimic/rldb/embodiment/human.py | 78 ++++++- egomimic/rldb/zarr/action_chunk_transforms.py | 21 +- egomimic/rldb/zarr/test_history_window.py | 87 +++++++ egomimic/rldb/zarr/zarr_dataset_multi.py | 33 ++- 9 files changed, 424 insertions(+), 146 deletions(-) create mode 100644 egomimic/rldb/zarr/test_history_window.py diff --git a/egomimic/algo/pi.py b/egomimic/algo/pi.py index 0415ee478..7340781c9 100644 --- a/egomimic/algo/pi.py +++ b/egomimic/algo/pi.py @@ -70,6 +70,15 @@ def __init__( state_num_bins: int = 256, control_mode: dict[str, str] | None = None, proprio_keys_for_prompt: list[str] | None = None, + # --------------------------- + # Short-term memory: splice a "Prev:" block of the recent EE-pose + # history (as coarse delta bins) into the prompt. `history_len` (K) + # must match the data-side `proprio_history` so proprio arrives as + # [B, K, D]. All off by default → behaviour unchanged. + # --------------------------- + prev_state_in_prompt: bool = False, + history_len: int = 1, + prev_state_num_bins: int = 64, **kwargs, ): self.nets = nn.ModuleDict() @@ -94,6 +103,13 @@ def __init__( ) self._state_bin_edges = np.linspace(-1.0, 1.0, state_num_bins + 1)[:-1] + # Short-term memory knobs (coarser bins than the current-state block + # since deltas are small and we want to spend fewer prompt tokens). + self.prev_state_in_prompt = prev_state_in_prompt + self.history_len = history_len + self.prev_state_num_bins = prev_state_num_bins + self._prev_bin_edges = np.linspace(-1.0, 1.0, prev_state_num_bins + 1)[:-1] + self.camera_transforms = camera_transforms self.train_image_augs = train_image_augs self.eval_image_augs = eval_image_augs @@ -228,6 +244,56 @@ def _discretize_state_for_sample(self, _batch, sample_idx: int) -> str | None: bins = np.digitize(state, bins=self._state_bin_edges) - 1 return " ".join(map(str, bins.tolist())) + # Cap on how many past frames the "Prev:" block encodes, to bound prompt + # token growth even when history_len is large. + _PREV_MAX_FRAMES = 2 + + def _discretize_prev_states_for_sample(self, _batch, sample_idx: int) -> list[str]: + """Encode the recent EE-pose *history* for sample i as coarse delta bins, + one bin-string PER past frame. + + Each proprio value is a ``[K, D]`` backward window (current frame last, + normalized like the "State:" block). We take up to ``_PREV_MAX_FRAMES`` + of the most-recent previous frames; for each, encode its displacement + from the current frame (``prev - cur``, small magnitudes) concatenated + across proprio keys, clip to [-1, 1] and digitize into + ``prev_state_num_bins`` (coarser) bins. + + Returns a list ordered MOST-RECENT-FIRST: element 0 is one step back + (rendered ``Prev1:``), element 1 is two steps back (``Prev2:``), etc. + Empty list when no history window is present (K < 2), so no "Prev*" + block is emitted. + """ + # Per-key delta windows, each (P, D) ordered oldest..newest. + per_key = [] + for k in self.proprio_keys_for_prompt: + if k not in _batch: + continue + v = _batch[k] + if isinstance(v, torch.Tensor): + v = v[sample_idx].detach().cpu().numpy() + else: + v = np.asarray(v)[sample_idx] + v = np.asarray(v, dtype=np.float32) + if v.ndim < 2 or v.shape[0] < 2: + continue # single frame -> no history for this key + cur = v[-1] + prev = v[:-1][-self._PREV_MAX_FRAMES :] # (P, D) oldest..newest + per_key.append(prev - cur[None, :]) + if not per_key: + return [] + n_prev = max(arr.shape[0] for arr in per_key) + out = [] + # f = 1 -> newest past frame (arr[-1]); f = 2 -> arr[-2]; ... + for f in range(1, n_prev + 1): + parts = [arr[-f].reshape(-1) for arr in per_key if f <= arr.shape[0]] + if not parts: + continue + deltas = np.clip(np.concatenate(parts, axis=-1), -1.0, 1.0) + bins = np.digitize(deltas, bins=self._prev_bin_edges) - 1 + out.append(" ".join(map(str, bins.tolist()))) + return out + def _build_prompts( self, _batch, embodiment_name: str, batch_size: int ) -> list[str]: @@ -250,8 +316,12 @@ def _build_prompts( else: # "first" prompts.append(sample[0]) + prev_active = self.prev_state_in_prompt and self.history_len > 1 any_block_active = ( - self.proprio_in_prompt or self.embodiment_label or bool(self.control_mode) + self.proprio_in_prompt + or self.embodiment_label + or bool(self.control_mode) + or prev_active ) if not any_block_active: return prompts @@ -264,6 +334,13 @@ def _build_prompts( blocks.append(f"Embodiment: {emb_name}") if self.control_mode: blocks.append(f"Control mode: {self._control_mode_for(emb_name)}") + # History first: Prev1 (one step back) .. PrevN (oldest), then the + # current State block. + if prev_active: + for j, prev_str in enumerate( + self._discretize_prev_states_for_sample(_batch, i), start=1 + ): + blocks.append(f"Prev{j}: {prev_str}") if self.proprio_in_prompt: state_str = self._discretize_state_for_sample(_batch, i) if state_str is not None: @@ -282,6 +359,21 @@ def _tokenize_prompts(self, prompts: list[str]) -> dict: return_tensors="pt", ) attention_mask = enc["attention_mask"].bool() + # Cheap truncation guard: a row with no padding filled the whole window + # and was likely truncated, dropping the trailing "Action:" anchor. + if ( + self.tokenizer_max_length is not None + and attention_mask.all(dim=1).any() + and not getattr(self, "_prompt_trunc_warned", False) + ): + logger.warning( + "Some prompts fill the entire tokenizer window " + "(max_length=%s) and may be truncated past the 'Action:' " + "anchor. Raise tokenizer_max_length or lower history_len / " + "prev_state_num_bins / state_num_bins.", + self.tokenizer_max_length, + ) + self._prompt_trunc_warned = True token_loss_mask = attention_mask.clone() token_loss_mask[:, -1] = False return { diff --git a/egomimic/algo/test_pi.py b/egomimic/algo/test_pi.py index f8c16c9da..3255d19fa 100644 --- a/egomimic/algo/test_pi.py +++ b/egomimic/algo/test_pi.py @@ -1,148 +1,125 @@ -from types import SimpleNamespace - -import pytest +import numpy as np import torch -import egomimic.algo.pi as pi_module from egomimic.algo.pi import PI -from egomimic.rldb.embodiment.embodiment import get_embodiment_id - - -class _StubNormStats: - def __init__(self, viz_img_keys): - self._viz_img_keys = viz_img_keys - def viz_img_key(self): - return self._viz_img_keys +# NOTE: the former ``test_visualize_preds_*`` tests were removed: visualization +# was moved off the ``PI`` algo into the external evaluators (commit +# "Refactor Eval: Metrics are computed in external eval class"), so the +# ``PI.visualize_preds`` method and ``pi.draw_actions`` import they exercised no +# longer exist. -def _make_transform(name): - return SimpleNamespace( - extrinsics={"name": f"{name}_extrinsics"}, - intrinsics={"name": f"{name}_intrinsics"}, - ) +# --------------------------------------------------------------------------- +# Short-term memory: "Prev:" prompt block built from recent EE-pose history. +# These exercise prompt assembly only (no tokenizer / model / data). +# --------------------------------------------------------------------------- -def _make_pi(camera_transforms, domains): +def _make_prompt_pi(**overrides): + """Bare PI with just the attributes the prompt-assembly methods read.""" pi = object.__new__(PI) - pi.domains = domains - pi.camera_transforms = camera_transforms - pi.is_6dof = False - pi.ac_keys = {get_embodiment_id(domain): "actions_cartesian" for domain in domains} - pi.norm_stats = _StubNormStats( - {get_embodiment_id(domain): "front_img_1" for domain in domains} - ) + pi.annotation_key = None + pi.default_prompt = "do task" + pi.sampling_mode = "first" + pi.proprio_in_prompt = True + pi.embodiment_label = False + pi.control_mode = None + pi.prev_state_in_prompt = False + pi.history_len = 1 + pi.proprio_keys_for_prompt = ["observations.state.ee_pose"] + pi.state_num_bins = 256 + pi.prev_state_num_bins = 64 + pi._state_bin_edges = np.linspace(-1.0, 1.0, 256 + 1)[:-1] + pi._prev_bin_edges = np.linspace(-1.0, 1.0, 64 + 1)[:-1] + for key, val in overrides.items(): + setattr(pi, key, val) return pi -def _make_batch(embodiment_name): - embodiment_id = get_embodiment_id(embodiment_name) +def _hist_batch(prev=0.4, cur=0.5, dim=4): + # [B=1, K=2, D]: row 0 is the previous frame, row 1 (last) is the current. return { - "embodiment": torch.tensor([embodiment_id]), - "front_img_1": torch.zeros(1, 3, 4, 4), - "actions_cartesian": torch.zeros(1, 2, 6), + "observations.state.ee_pose": torch.tensor( + [[[prev] * dim, [cur] * dim]], dtype=torch.float32 + ) } -def _make_predictions(embodiment_name): - return {f"{embodiment_name}_actions_cartesian": torch.ones(1, 2, 6)} - - -def test_visualize_preds_supports_single_transform_object(monkeypatch): - shared_transform = _make_transform("shared") - pi = _make_pi(shared_transform, ["aria_bimanual"]) +def test_prev_block_absent_when_disabled(): + pi = _make_prompt_pi(prev_state_in_prompt=False, history_len=2) + out = pi._build_prompts(_hist_batch(), "eva_bimanual", 1) + assert len(out) == 1 + assert "Prev1:" not in out[0] + assert "State:" in out[0] + assert out[0].endswith(";\nAction: ") - draw_calls = [] - def fake_draw_actions( - im, ac_type, color, actions, extrinsics, intrinsics, arm="both", **kwargs - ): - draw_calls.append((extrinsics, intrinsics)) - return im +def test_prev_block_present_and_ordered_when_enabled(): + pi = _make_prompt_pi(prev_state_in_prompt=True, history_len=2) + out = pi._build_prompts(_hist_batch(), "eva_bimanual", 1)[0] + assert "State:" in out and "Prev1:" in out + assert out.index("Prev1:") < out.index("State:") # recent motion then current + assert out.endswith(";\nAction: ") - monkeypatch.setattr(pi_module, "draw_actions", fake_draw_actions) - ims = pi.visualize_preds( - _make_predictions("aria_bimanual"), _make_batch("aria_bimanual") - ) +def test_prev_delta_bins_zero_motion_is_center_bin(): + pi = _make_prompt_pi(prev_state_in_prompt=True, history_len=2) + # K=2 window -> one past frame -> a single-element list. + s = pi._discretize_prev_states_for_sample(_hist_batch(prev=0.3, cur=0.3, dim=3), 0) + center = int(np.digitize(0.0, np.linspace(-1.0, 1.0, 65)[:-1]) - 1) + assert s == [" ".join([str(center)] * 3)] - assert ims.shape == (1, 4, 4, 3) - assert len(draw_calls) == 2 - assert all( - extrinsics is shared_transform.extrinsics for extrinsics, _ in draw_calls - ) - assert all( - intrinsics is shared_transform.intrinsics for _, intrinsics in draw_calls - ) +def test_prev_delta_bins_known_value(): + pi = _make_prompt_pi(prev_state_in_prompt=True, history_len=2) + s = pi._discretize_prev_states_for_sample(_hist_batch(prev=0.4, cur=0.5, dim=2), 0) + edges = np.linspace(-1.0, 1.0, 65)[:-1] + expected = int(np.digitize(0.4 - 0.5, edges) - 1) # delta = prev - cur + assert s == [" ".join([str(expected)] * 2)] -def test_visualize_preds_raises_clear_error_for_missing_embodiment(): - pi = _make_pi( - {"aria_bimanual": _make_transform("aria")}, - ["aria_bimanual", "eva_bimanual"], - ) - with pytest.raises(KeyError) as exc_info: - pi.visualize_preds( - _make_predictions("eva_bimanual"), _make_batch("eva_bimanual") +def test_prev_returns_empty_without_history(): + pi = _make_prompt_pi(prev_state_in_prompt=True, history_len=2) + # Single-frame proprio [B, D] -> per-sample [D] -> no history window. + batch = { + "observations.state.ee_pose": torch.tensor( + [[0.5, 0.5, 0.5]], dtype=torch.float32 ) - - assert "Missing camera transform for embodiment 'eva_bimanual'" in str( - exc_info.value - ) - assert "aria_bimanual" in str(exc_info.value) + } + assert pi._discretize_prev_states_for_sample(batch, 0) == [] + assert "Prev1:" not in pi._build_prompts(batch, "eva_bimanual", 1)[0] -def test_visualize_preds_rejects_invalid_camera_transform_shape(): - pi = _make_pi({"aria_bimanual": {"extrinsics": {}}}, ["aria_bimanual"]) +def test_prev_caps_number_of_past_frames(): + pi = _make_prompt_pi(prev_state_in_prompt=True, history_len=5) + # K=5 window -> 4 past frames, but only _PREV_MAX_FRAMES are encoded, each + # as its own Prev bin-string of width `dim`. + dim = 2 + window = torch.arange(5 * dim, dtype=torch.float32).reshape(1, 5, dim) / 100.0 + s = pi._discretize_prev_states_for_sample({"observations.state.ee_pose": window}, 0) + assert len(s) == PI._PREV_MAX_FRAMES + assert all(len(frame.split()) == dim for frame in s) - with pytest.raises(TypeError) as exc_info: - pi.visualize_preds( - _make_predictions("aria_bimanual"), _make_batch("aria_bimanual") - ) - assert "camera_transforms must be a CameraTransforms instance or a mapping" in str( - exc_info.value - ) - - -def test_visualize_preds_uses_embodiment_specific_camera_transform(monkeypatch): - aria_transform = _make_transform("aria") - eva_transform = _make_transform("eva") - pi = _make_pi( - {"aria_bimanual": aria_transform, "eva_bimanual": eva_transform}, - ["aria_bimanual", "eva_bimanual"], - ) - - draw_calls = [] - - def fake_draw_actions( - im, ac_type, color, actions, extrinsics, intrinsics, arm="both", **kwargs - ): - draw_calls.append( - { - "ac_type": ac_type, - "color": color, - "extrinsics": extrinsics, - "intrinsics": intrinsics, - "arm": arm, - "shape": tuple(actions.shape), - } - ) - return im - - monkeypatch.setattr(pi_module, "draw_actions", fake_draw_actions) - - ims = pi.visualize_preds( - _make_predictions("aria_bimanual"), _make_batch("aria_bimanual") - ) - - assert ims.shape == (1, 4, 4, 3) - assert len(draw_calls) == 2 - assert all(call["extrinsics"] is aria_transform.extrinsics for call in draw_calls) - assert all(call["intrinsics"] is aria_transform.intrinsics for call in draw_calls) - assert all(call["arm"] == "both" for call in draw_calls) - assert all(call["shape"] == (2, 6) for call in draw_calls) - assert all( - call["extrinsics"] is not eva_transform.extrinsics for call in draw_calls - ) +def test_tokenize_truncation_warns_once(): + pi = object.__new__(PI) + pi.tokenizer_max_length = 4 + pi._prompt_trunc_warned = False + + def fake_tokenizer(prompts, padding, truncation, max_length, return_tensors): + n = len(prompts) + # No padding column -> the window is full -> truncation suspected. + return { + "input_ids": torch.ones(n, 4, dtype=torch.long), + "attention_mask": torch.ones(n, 4, dtype=torch.long), + } + + pi.tokenizer = fake_tokenizer + out = pi._tokenize_prompts(["a"]) + assert pi._prompt_trunc_warned is True + assert out["tokenized_prompt"].shape == (1, 4) + # token_loss_mask masks out the final (anchor) position. + assert out["token_loss_mask"][0, -1].item() is False + pi._tokenize_prompts(["b"]) # stays warned, no crash + assert pi._prompt_trunc_warned is True diff --git a/egomimic/hydra_configs/model/pi0.5_base.yaml b/egomimic/hydra_configs/model/pi0.5_base.yaml index b3fa8f362..f320c0710 100644 --- a/egomimic/hydra_configs/model/pi0.5_base.yaml +++ b/egomimic/hydra_configs/model/pi0.5_base.yaml @@ -39,6 +39,18 @@ robomimic_model: state_num_bins: 256 control_mode: null + # Short-term memory (recent EE-pose history spliced into the prompt as a + # coarse "Prev:" delta block). OFF by default → behaviour unchanged. To + # enable, set prev_state_in_prompt=true, history_len=K (K>1), and match the + # data side so proprio arrives as [B,K,D], e.g. + # +data.train_datasets.eva_bimanual.resolver.key_map.proprio_history=K + # (the get_keymap _target_ forwards proprio_history to every proprio key). + # Mind the token budget (tokenizer_max_length=128): keep K small and bins + # coarse, or bump tokenizer_max_length (<= model.max_token_len=180). + prev_state_in_prompt: false + history_len: 1 + prev_state_num_bins: 64 + train_image_augs: _target_: torchvision.transforms.Compose transforms: diff --git a/egomimic/models/preprocess_pi_obs.py b/egomimic/models/preprocess_pi_obs.py index 8ba299f2b..61ebb09b1 100644 --- a/egomimic/models/preprocess_pi_obs.py +++ b/egomimic/models/preprocess_pi_obs.py @@ -43,11 +43,20 @@ def _mask_from_batch(B: int, device) -> torch.Tensor: def _concat_proprio( batch: dict, proprio_keys: list[str], device: torch.device ) -> torch.Tensor: - """Concat all proprio tensors along last dim → [B, D] (D can be 0).""" + """Concat all proprio tensors along last dim → [B, D] (D can be 0). + + With short-term memory enabled a proprio tensor is [B, K, D] (a backward + window of poses). The continuous ``state`` fed to the model stays the + current frame only, [B, D], so the model-facing observation is unchanged + (the history is consumed separately when building the prompt). + """ parts = [] for k in proprio_keys: if k in batch: - parts.append(batch[k].to(device)) + v = batch[k].to(device) + if v.ndim == 3: # [B, K, D] history window -> current frame [B, D] + v = v[:, -1, :] + parts.append(v) if not parts: # If no proprio, infer B from any tensor in batch (best-effort), else 0 for v in batch.values(): diff --git a/egomimic/rldb/embodiment/embodiment.py b/egomimic/rldb/embodiment/embodiment.py index cb7bbf4f9..512c36bd6 100644 --- a/egomimic/rldb/embodiment/embodiment.py +++ b/egomimic/rldb/embodiment/embodiment.py @@ -136,9 +136,24 @@ def viz( ) @classmethod - def get_keymap(cls, keymap_mode: str, norm_mode: bool = False, annotation_key=None): + def get_keymap( + cls, + keymap_mode: str, + norm_mode: bool = False, + annotation_key=None, + proprio_history: int = 1, + ): """Returns a dictionary mapping from the raw keys in the dataset to the canonical keys used by the model.""" key_map = cls._get_keymap(keymap_mode) + # Short-term memory: when proprio_history>1, read a backward window of + # that many frames for every proprio key. The per-episode reader stacks + # them as a leading [K, ...] axis (current frame last), so the existing + # pose transforms + ConcatKeys emit observations.state.* of shape + # [K, D]. proprio_history=1 (default) is a byte-identical no-op. + if proprio_history and proprio_history > 1: + for entry in key_map.values(): + if entry.get("key_type") == "proprio_keys": + entry["history"] = proprio_history if annotation_key is not None and not norm_mode: key_map[annotation_key] = { "key_type": "annotation_keys", diff --git a/egomimic/rldb/embodiment/human.py b/egomimic/rldb/embodiment/human.py index 957b43535..e4e9f9108 100644 --- a/egomimic/rldb/embodiment/human.py +++ b/egomimic/rldb/embodiment/human.py @@ -31,11 +31,26 @@ class Human(Embodiment): # MANO 21-keypoint topology: 0=wrist, 1-4 thumb, 5-8 index, 9-12 middle, 13-16 ring, 17-20 pinky. # Subclasses with non-MANO conventions (e.g. Aria) override these. FINGER_EDGES = [ - (0, 1), (1, 2), (2, 3), (3, 4), # thumb - (0, 5), (5, 6), (6, 7), (7, 8), # index - (0, 9), (9, 10), (10, 11), (11, 12), # middle - (0, 13), (13, 14), (14, 15), (15, 16), # ring - (0, 17), (17, 18), (18, 19), (19, 20), # pinky + (0, 1), + (1, 2), + (2, 3), + (3, 4), # thumb + (0, 5), + (5, 6), + (6, 7), + (7, 8), # index + (0, 9), + (9, 10), + (10, 11), + (11, 12), # middle + (0, 13), + (13, 14), + (14, 15), + (15, 16), # ring + (0, 17), + (17, 18), + (18, 19), + (19, 20), # pinky ] FINGER_COLORS = { "thumb": (255, 100, 100), @@ -111,7 +126,9 @@ def get_transform_list( ], ) -> list[Transform]: if mode == "cartesian": - return _build_human_cartesian_bimanual_transform_list(stride=cls.ACTION_STRIDE) + return _build_human_cartesian_bimanual_transform_list( + stride=cls.ACTION_STRIDE + ) if mode == "cartesian_padded": return _build_human_cartesian_bimanual_transform_list( stride=cls.ACTION_STRIDE @@ -121,7 +138,9 @@ def get_transform_list( stride=cls.ACTION_STRIDE, rot_repr="6d" ) if mode == "cartesian_wristframe_ypr": - return _build_human_cartesian_eef_frame_transform_list(stride=cls.ACTION_STRIDE) + return _build_human_cartesian_eef_frame_transform_list( + stride=cls.ACTION_STRIDE + ) if mode == "cartesian_wristframe_6d": return _build_human_cartesian_eef_frame_transform_list( stride=cls.ACTION_STRIDE, rot_repr="6d" @@ -150,11 +169,25 @@ class Aria(Human): ACTION_STRIDE = 3 # Aria's 21-keypoint layout is NOT MANO: 0-4 are fingertips, 5 is the palm root. FINGER_EDGES = [ - (5, 6), (6, 7), (7, 0), # thumb - (5, 8), (8, 9), (9, 10), (10, 1), # index - (5, 11), (11, 12), (12, 13), (13, 2), # middle - (5, 14), (14, 15), (15, 16), (16, 3), # ring - (5, 17), (17, 18), (18, 19), (19, 4), # pinky + (5, 6), + (6, 7), + (7, 0), # thumb + (5, 8), + (8, 9), + (9, 10), + (10, 1), # index + (5, 11), + (11, 12), + (12, 13), + (13, 2), # middle + (5, 14), + (14, 15), + (15, 16), + (16, 3), # ring + (5, 17), + (17, 18), + (18, 19), + (19, 4), # pinky ] FINGER_EDGE_RANGES = [ ("thumb", 0, 3), @@ -363,6 +396,7 @@ def get_keymap( annotations: bool = False, norm_mode: bool = False, annotation_key: str | None = None, + proprio_history: int = 1, ): if mode is None: mode = keymap_mode or "cartesian" @@ -465,6 +499,26 @@ def get_keymap( for key in list(key_map): if key_map[key].get("key_type") in ("camera_keys", "annotation_keys"): del key_map[key] + # Short-term memory: when proprio_history>1, read a backward window of + # that many frames for the EE/keypoint observation poses (current frame + # last). proprio_history=1 is a byte-identical no-op. + # + # `obs_head_pose` is EXCLUDED: it is the coordinate REFERENCE frame + # (target_world for the head-frame transforms), not a model observation. + # It must stay single-frame so each windowed EE pose is re-expressed in + # the *current* head frame; giving it a window makes target_world (K, 7) + # and breaks the (B, 7)-only frame math. The reader's `horizon` branch + # also takes precedence, so keypoints-mode keys with a horizon are + # unaffected. + history_exclude = {"obs_head_pose"} + if proprio_history and proprio_history > 1: + for name, entry in key_map.items(): + if ( + entry.get("key_type") == "proprio_keys" + and name not in history_exclude + and entry.get("zarr_key") not in history_exclude + ): + entry["history"] = proprio_history return key_map diff --git a/egomimic/rldb/zarr/action_chunk_transforms.py b/egomimic/rldb/zarr/action_chunk_transforms.py index 0e7aeafb0..75fb054f3 100644 --- a/egomimic/rldb/zarr/action_chunk_transforms.py +++ b/egomimic/rldb/zarr/action_chunk_transforms.py @@ -358,15 +358,20 @@ def __init__( def transform(self, batch: dict) -> dict: pose_world = np.asarray(batch[self.pose_world]) + # A single pose is 1D (D,); a short-term-memory history window is 2D + # (K, D). Re-express either as a (chunk, D) array for the chunk + # transform, then squeeze back only in the single-pose case so the + # default behaviour is byte-identical. + single = pose_world.ndim == 1 + chunk_world = pose_world[None, :] if single else pose_world transformed = self._chunk_transform.transform( { self.target_world: batch[self.target_world], - self.pose_world: pose_world[None, :], + self.pose_world: chunk_world, } ) - batch[self.transformed_key_name] = np.asarray( - transformed[self.transformed_key_name] - )[0] + out = np.asarray(transformed[self.transformed_key_name]) + batch[self.transformed_key_name] = out[0] if single else out return batch @@ -638,12 +643,8 @@ def transform(self, batch: dict) -> dict: ) pad_shape = (*arr.shape[:-1], 1) pad = np.zeros(pad_shape, dtype=arr.dtype) - padded = np.concatenate( - (arr[..., :6], pad, arr[..., 6:], pad), axis=-1 - ) - batch[self.action_key] = ( - torch.from_numpy(padded) if is_tensor else padded - ) + padded = np.concatenate((arr[..., :6], pad, arr[..., 6:], pad), axis=-1) + batch[self.action_key] = torch.from_numpy(padded) if is_tensor else padded return batch diff --git a/egomimic/rldb/zarr/test_history_window.py b/egomimic/rldb/zarr/test_history_window.py new file mode 100644 index 000000000..4e63ee5c9 --- /dev/null +++ b/egomimic/rldb/zarr/test_history_window.py @@ -0,0 +1,87 @@ +"""Unit tests for the short-term-memory backward-window helpers. + +CPU-only and store-free: they exercise ``ZarrDataset._pad_sequences_left`` +(front-padding for a history window read near the start of an episode) and the +backward-window index math used in ``__getitem__``. No zarr store, model, or +data download is required. +""" + +import numpy as np + +from egomimic.rldb.zarr.action_chunk_transforms import PoseCoordinateFrameTransform +from egomimic.rldb.zarr.zarr_dataset_multi import ZarrDataset + + +def _bare_dataset(): + # _pad_sequences_left uses no instance state, so a bare instance is enough. + return object.__new__(ZarrDataset) + + +def test_left_pad_repeats_first_frame(): + ds = _bare_dataset() + arr = np.array([[1.0, 2.0], [3.0, 4.0]]) # (2, 2); window short of history=4 + out = ds._pad_sequences_left({"k": arr.copy()}, history=4)["k"] + assert out.shape == (4, 2) + # Front rows are copies of the original FIRST frame; the tail is untouched, + # so the current frame stays at row -1. + assert np.array_equal(out[0], arr[0]) + assert np.array_equal(out[1], arr[0]) + assert np.array_equal(out[2:], arr) + assert np.array_equal(out[-1], arr[-1]) + + +def test_left_pad_noop_when_window_full(): + ds = _bare_dataset() + arr = np.arange(6).reshape(3, 2).astype(float) + out = ds._pad_sequences_left({"k": arr.copy()}, history=3)["k"] + assert np.array_equal(out, arr) + + +def test_left_pad_none_history_is_noop(): + ds = _bare_dataset() + arr = np.arange(4).reshape(2, 2).astype(float) + out = ds._pad_sequences_left({"k": arr.copy()}, history=None)["k"] + assert np.array_equal(out, arr) + + +def _backward_window(idx, history): + """Replicates the inline read-interval math in ``__getitem__``.""" + start_idx = max(0, idx - history + 1) + return start_idx, idx + 1 + + +def test_backward_window_full_in_interior(): + # Away from the episot start, the window is exactly `history` long and the + # last index is the current frame. + start, end = _backward_window(idx=10, history=4) + assert (start, end) == (7, 11) + assert end - start == 4 + assert end - 1 == 10 # current frame is the last row + + +def test_backward_window_clamps_at_episode_start(): + start, end = _backward_window(idx=1, history=4) + assert (start, end) == (0, 2) # only 2 real frames -> left-pad fills the rest + assert end - 1 == 1 + + +def test_pose_transform_history_matches_single_pose_rowwise(): + # The short-term-memory read feeds a [K, 7] window through + # PoseCoordinateFrameTransform. Each row must be transformed exactly as if + # it had been the lone single pose, so the default (single) path is + # preserved and the history rows are consistent with it. + target = np.array([0.05, -0.1, 0.2, 1.0, 0.0, 0.0, 0.0]) # xyz + quat(wxyz) + pose_a = np.array([0.1, 0.2, 0.3, 1.0, 0.0, 0.0, 0.0]) + pose_b = np.array([-0.2, 0.0, 0.4, 0.0, 1.0, 0.0, 0.0]) # 180deg about x + + t = PoseCoordinateFrameTransform( + target_world="tw", pose_world="pw", transformed_key_name="out", mode="xyzwxyz" + ) + + out_single = t.transform({"tw": target.copy(), "pw": pose_b.copy()})["out"] + assert out_single.shape == (7,) + + window = np.stack([pose_a, pose_b], axis=0) # current frame (pose_b) is last + out_hist = t.transform({"tw": target.copy(), "pw": window.copy()})["out"] + assert out_hist.shape == (2, 7) + assert np.allclose(out_hist[-1], out_single, atol=1e-6) diff --git a/egomimic/rldb/zarr/zarr_dataset_multi.py b/egomimic/rldb/zarr/zarr_dataset_multi.py index 37e6ec0bb..ab707ccc1 100644 --- a/egomimic/rldb/zarr/zarr_dataset_multi.py +++ b/egomimic/rldb/zarr/zarr_dataset_multi.py @@ -1713,6 +1713,27 @@ def _pad_sequences(self, data, horizon: int | None) -> dict: return data + def _pad_sequences_left(self, data, history: int | None) -> dict: + """Front-pad a backward (history) window by repeating the FIRST frame. + + Mirrors ``_pad_sequences`` but for short-term memory reads near the + start of an episode (``idx < history-1``), where the window is shorter + than ``history``. Padding the front keeps the current frame at row -1. + """ + if history is None: + return data + + for k in data: + if isinstance(data[k], np.ndarray): + seq_len = data[k].shape[0] + if seq_len < history: + pad_len = history - seq_len + first_frame = data[k][:1] + padding = np.repeat(first_frame, pad_len, axis=0) + data[k] = np.concatenate([padding, data[k]], axis=0) + + return data + def __getitem__( self, idx: int, @@ -1752,6 +1773,10 @@ def _next(reason: str, key: str = "") -> int: zarr_key = self.key_map[k]["zarr_key"] key_type = self.key_map[k].get("key_type", None) horizon = self.key_map[k].get("horizon", None) + # Short-term memory: read a backward window of `history` frames + # ending at (and inclusive of) the current frame. Mutually + # exclusive with `horizon` (forward action chunk). + history = self.key_map[k].get("history", None) if key_type == "annotation_keys": data[k] = self._annotation_text_for_frame(idx) @@ -1760,11 +1785,15 @@ def _next(reason: str, key: str = "") -> int: if horizon is not None: end_idx = self._chunk_end_idx(idx, horizon, key_type) read_interval = (idx, end_idx) + elif history is not None: + start_idx = max(0, idx - history + 1) + read_interval = (start_idx, idx + 1) # last row == current frame else: read_interval = (idx, None) read_dict = {zarr_key: read_interval} raw_data = self.episode_reader.read(read_dict) self._pad_sequences(raw_data, horizon) # should be able to pad images + self._pad_sequences_left(raw_data, history) # front-pad at ep start data[k] = raw_data[zarr_key] if zarr_key in self._image_keys: @@ -1794,7 +1823,9 @@ def _next(reason: str, key: str = "") -> int: data["embodiment"] = get_embodiment_id(self.embodiment) ep_name = Path(self.episode_path).name - data["episode_hash"] = ep_name[:-5] if ep_name.endswith(".zarr") else ep_name + data["episode_hash"] = ( + ep_name[:-5] if ep_name.endswith(".zarr") else ep_name + ) data["task"] = self.task or self.metadata.get("task_name", "unknown") _ = origin # preserved for symmetry with prior API return data