Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 93 additions & 1 deletion egomimic/algo/pi.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,15 @@ def __init__(
state_num_bins: int = 256,
control_mode: dict[str, str] | None = None,
proprio_keys_for_prompt: list[str] | None = None,
# ---------------------------
# Short-term memory: splice a "Prev:" block of the recent EE-pose
# history (as coarse delta bins) into the prompt. `history_len` (K)
# must match the data-side `proprio_history` so proprio arrives as
# [B, K, D]. All off by default → behaviour unchanged.
# ---------------------------
prev_state_in_prompt: bool = False,
history_len: int = 1,
prev_state_num_bins: int = 64,
**kwargs,
):
self.nets = nn.ModuleDict()
Expand All @@ -94,6 +103,13 @@ def __init__(
)
self._state_bin_edges = np.linspace(-1.0, 1.0, state_num_bins + 1)[:-1]

# Short-term memory knobs (coarser bins than the current-state block
# since deltas are small and we want to spend fewer prompt tokens).
self.prev_state_in_prompt = prev_state_in_prompt
self.history_len = history_len
self.prev_state_num_bins = prev_state_num_bins
self._prev_bin_edges = np.linspace(-1.0, 1.0, prev_state_num_bins + 1)[:-1]

self.camera_transforms = camera_transforms
self.train_image_augs = train_image_augs
self.eval_image_augs = eval_image_augs
Expand Down Expand Up @@ -228,6 +244,56 @@ def _discretize_state_for_sample(self, _batch, sample_idx: int) -> str | None:
bins = np.digitize(state, bins=self._state_bin_edges) - 1
return " ".join(map(str, bins.tolist()))

# Upper bound on how many past frames the "Prev*" prompt blocks encode. It is
# applied when slicing the history window, so prompt token growth stays
# bounded even when history_len is large.
_PREV_MAX_FRAMES = 2

def _discretize_prev_states_for_sample(self, _batch, sample_idx: int) -> list[str]:
    """Encode the recent EE-pose *history* of one sample as coarse delta bins.

    Each proprio entry is expected to be a ``[K, D]`` backward window for the
    sample (current frame last, normalized like the "State:" block). Up to
    ``_PREV_MAX_FRAMES`` of the most recent previous frames are used; for
    each one the displacement from the current frame (``prev - cur``) is
    concatenated across all proprio keys, clipped to ``[-1, 1]`` and
    digitized with ``self._prev_bin_edges``.

    Args:
        _batch: Mapping from proprio key to a batched tensor / array whose
            first axis indexes samples. Keys listed in
            ``self.proprio_keys_for_prompt`` but absent from the batch are
            ignored.
        sample_idx: Index of the sample inside the batch.

    Returns:
        A list of space-separated bin strings ordered MOST-RECENT-FIRST:
        element 0 is one step back (rendered ``Prev1:``), element 1 is two
        steps back (``Prev2:``), and so on. The list is empty when no key
        carries a history window (fewer than two frames), when no proprio
        keys are configured, or when the frame cap is not positive.
    """
    max_frames = int(self._PREV_MAX_FRAMES)
    if max_frames <= 0:
        # A literal ``[-0:]`` slice would select *every* past frame, which is
        # the opposite of what a zero cap asks for.
        return []

    # Per-key delta windows, each (P, D) ordered oldest..newest.
    per_key = []
    for key in self.proprio_keys_for_prompt or ():
        if key not in _batch:
            continue
        value = _batch[key]
        if isinstance(value, torch.Tensor):
            # Upcast inside torch first: NumPy has no bfloat16 dtype, so a
            # direct ``.numpy()`` on a bf16 tensor raises.
            window = value[sample_idx].detach().cpu().float().numpy()
        else:
            window = np.asarray(value)[sample_idx]
        window = np.asarray(window, dtype=np.float32)
        if window.ndim < 2 or window.shape[0] < 2:
            continue  # single frame -> no history for this key
        current = window[-1]
        past = window[:-1][-max_frames:]  # (P, D) oldest..newest
        per_key.append(past - current[None, :])
    if not per_key:
        return []

    n_prev = max(arr.shape[0] for arr in per_key)
    encoded = []
    # back = 1 -> newest past frame (arr[-1]); back = 2 -> arr[-2]; ...
    for back in range(1, n_prev + 1):
        # Keys with a shorter window simply drop out of the older frames. At
        # least one key always contributes because n_prev is the maximum.
        parts = [arr[-back].reshape(-1) for arr in per_key if back <= arr.shape[0]]
        deltas = np.clip(np.concatenate(parts, axis=-1), -1.0, 1.0)
        bins = np.digitize(deltas, bins=self._prev_bin_edges) - 1
        encoded.append(" ".join(map(str, bins.tolist())))
    return encoded

def _build_prompts(
self, _batch, embodiment_name: str, batch_size: int
) -> list[str]:
Expand All @@ -250,8 +316,12 @@ def _build_prompts(
else: # "first"
prompts.append(sample[0])

prev_active = self.prev_state_in_prompt and self.history_len > 1
any_block_active = (
self.proprio_in_prompt or self.embodiment_label or bool(self.control_mode)
self.proprio_in_prompt
or self.embodiment_label
or bool(self.control_mode)
or prev_active
)
if not any_block_active:
return prompts
Expand All @@ -264,6 +334,13 @@ def _build_prompts(
blocks.append(f"Embodiment: {emb_name}")
if self.control_mode:
blocks.append(f"Control mode: {self._control_mode_for(emb_name)}")
# History first: Prev1 (one step back) .. PrevN (oldest), then the
# current State block.
if prev_active:
for j, prev_str in enumerate(
self._discretize_prev_states_for_sample(_batch, i), start=1
):
blocks.append(f"Prev{j}: {prev_str}")
if self.proprio_in_prompt:
state_str = self._discretize_state_for_sample(_batch, i)
if state_str is not None:
Expand All @@ -282,6 +359,21 @@ def _tokenize_prompts(self, prompts: list[str]) -> dict:
return_tensors="pt",
)
attention_mask = enc["attention_mask"].bool()
# Cheap truncation guard: a row with no padding filled the whole window
# and was likely truncated, dropping the trailing "Action:" anchor.
if (
self.tokenizer_max_length is not None
and attention_mask.all(dim=1).any()
and not getattr(self, "_prompt_trunc_warned", False)
):
logger.warning(
"Some prompts fill the entire tokenizer window "
"(max_length=%s) and may be truncated past the 'Action:' "
"anchor. Raise tokenizer_max_length or lower history_len / "
"prev_state_num_bins / state_num_bins.",
self.tokenizer_max_length,
)
self._prompt_trunc_warned = True
token_loss_mask = attention_mask.clone()
token_loss_mask[:, -1] = False
return {
Expand Down
215 changes: 96 additions & 119 deletions egomimic/algo/test_pi.py
Original file line number Diff line number Diff line change
@@ -1,148 +1,125 @@
from types import SimpleNamespace

import pytest
import numpy as np
import torch

import egomimic.algo.pi as pi_module
from egomimic.algo.pi import PI
from egomimic.rldb.embodiment.embodiment import get_embodiment_id


class _StubNormStats:
def __init__(self, viz_img_keys):
self._viz_img_keys = viz_img_keys

def viz_img_key(self):
return self._viz_img_keys
# NOTE: the former ``test_visualize_preds_*`` tests were removed: visualization
# was moved off the ``PI`` algo into the external evaluators (commit
# "Refactor Eval: Metrics are computed in external eval class"), so the
# ``PI.visualize_preds`` method and ``pi.draw_actions`` import they exercised no
# longer exist.


def _make_transform(name):
return SimpleNamespace(
extrinsics={"name": f"{name}_extrinsics"},
intrinsics={"name": f"{name}_intrinsics"},
)
# ---------------------------------------------------------------------------
# Short-term memory: "Prev:" prompt block built from recent EE-pose history.
# These exercise prompt assembly only (no tokenizer / model / data).
# ---------------------------------------------------------------------------


def _make_pi(camera_transforms, domains):
def _make_prompt_pi(**overrides):
"""Bare PI with just the attributes the prompt-assembly methods read."""
pi = object.__new__(PI)
pi.domains = domains
pi.camera_transforms = camera_transforms
pi.is_6dof = False
pi.ac_keys = {get_embodiment_id(domain): "actions_cartesian" for domain in domains}
pi.norm_stats = _StubNormStats(
{get_embodiment_id(domain): "front_img_1" for domain in domains}
)
pi.annotation_key = None
pi.default_prompt = "do task"
pi.sampling_mode = "first"
pi.proprio_in_prompt = True
pi.embodiment_label = False
pi.control_mode = None
pi.prev_state_in_prompt = False
pi.history_len = 1
pi.proprio_keys_for_prompt = ["observations.state.ee_pose"]
pi.state_num_bins = 256
pi.prev_state_num_bins = 64
pi._state_bin_edges = np.linspace(-1.0, 1.0, 256 + 1)[:-1]
pi._prev_bin_edges = np.linspace(-1.0, 1.0, 64 + 1)[:-1]
for key, val in overrides.items():
setattr(pi, key, val)
return pi


def _make_batch(embodiment_name):
embodiment_id = get_embodiment_id(embodiment_name)
def _hist_batch(prev=0.4, cur=0.5, dim=4):
# [B=1, K=2, D]: row 0 is the previous frame, row 1 (last) is the current.
return {
"embodiment": torch.tensor([embodiment_id]),
"front_img_1": torch.zeros(1, 3, 4, 4),
"actions_cartesian": torch.zeros(1, 2, 6),
"observations.state.ee_pose": torch.tensor(
[[[prev] * dim, [cur] * dim]], dtype=torch.float32
)
}


def _make_predictions(embodiment_name):
return {f"{embodiment_name}_actions_cartesian": torch.ones(1, 2, 6)}


def test_visualize_preds_supports_single_transform_object(monkeypatch):
shared_transform = _make_transform("shared")
pi = _make_pi(shared_transform, ["aria_bimanual"])
def test_prev_block_absent_when_disabled():
    """With the flag off, a history window in the batch adds no Prev block."""
    model = _make_prompt_pi(prev_state_in_prompt=False, history_len=2)
    prompts = model._build_prompts(_hist_batch(), "eva_bimanual", 1)
    assert len(prompts) == 1
    prompt = prompts[0]
    # The history block is missing while the current-state block remains.
    assert "Prev1:" not in prompt
    assert "State:" in prompt
    # The trailing action anchor must be untouched.
    assert prompt.endswith(";\nAction: ")

draw_calls = []

def fake_draw_actions(
im, ac_type, color, actions, extrinsics, intrinsics, arm="both", **kwargs
):
draw_calls.append((extrinsics, intrinsics))
return im
def test_prev_block_present_and_ordered_when_enabled():
    """With the flag on, Prev1 is emitted and placed before the State block."""
    model = _make_prompt_pi(prev_state_in_prompt=True, history_len=2)
    prompt = model._build_prompts(_hist_batch(), "eva_bimanual", 1)[0]
    for label in ("State:", "Prev1:"):
        assert label in prompt
    # Recent motion comes first, then the current state.
    prev_at = prompt.index("Prev1:")
    state_at = prompt.index("State:")
    assert prev_at < state_at
    assert prompt.endswith(";\nAction: ")

monkeypatch.setattr(pi_module, "draw_actions", fake_draw_actions)

ims = pi.visualize_preds(
_make_predictions("aria_bimanual"), _make_batch("aria_bimanual")
)
def test_prev_delta_bins_zero_motion_is_center_bin():
    """Identical previous and current frames map every dimension to the 0.0 bin."""
    model = _make_prompt_pi(prev_state_in_prompt=True, history_len=2)
    # K=2 window -> one past frame -> a single-element list.
    batch = _hist_batch(prev=0.3, cur=0.3, dim=3)
    encoded = model._discretize_prev_states_for_sample(batch, 0)
    edges = np.linspace(-1.0, 1.0, 65)[:-1]
    center = int(np.digitize(0.0, edges) - 1)
    expected = " ".join([str(center)] * 3)
    assert encoded == [expected]

assert ims.shape == (1, 4, 4, 3)
assert len(draw_calls) == 2
assert all(
extrinsics is shared_transform.extrinsics for extrinsics, _ in draw_calls
)
assert all(
intrinsics is shared_transform.intrinsics for _, intrinsics in draw_calls
)

def test_prev_delta_bins_known_value():
    """A known displacement lands in the bin np.digitize predicts for it."""
    model = _make_prompt_pi(prev_state_in_prompt=True, history_len=2)
    batch = _hist_batch(prev=0.4, cur=0.5, dim=2)
    encoded = model._discretize_prev_states_for_sample(batch, 0)
    edges = np.linspace(-1.0, 1.0, 65)[:-1]
    # The encoded quantity is delta = prev - cur.
    expected_bin = int(np.digitize(0.4 - 0.5, edges) - 1)
    expected = " ".join([str(expected_bin)] * 2)
    assert encoded == [expected]

def test_visualize_preds_raises_clear_error_for_missing_embodiment():
pi = _make_pi(
{"aria_bimanual": _make_transform("aria")},
["aria_bimanual", "eva_bimanual"],
)

with pytest.raises(KeyError) as exc_info:
pi.visualize_preds(
_make_predictions("eva_bimanual"), _make_batch("eva_bimanual")
def test_prev_returns_empty_without_history():
pi = _make_prompt_pi(prev_state_in_prompt=True, history_len=2)
# Single-frame proprio [B, D] -> per-sample [D] -> no history window.
batch = {
"observations.state.ee_pose": torch.tensor(
[[0.5, 0.5, 0.5]], dtype=torch.float32
)

assert "Missing camera transform for embodiment 'eva_bimanual'" in str(
exc_info.value
)
assert "aria_bimanual" in str(exc_info.value)
}
assert pi._discretize_prev_states_for_sample(batch, 0) == []
assert "Prev1:" not in pi._build_prompts(batch, "eva_bimanual", 1)[0]


def test_visualize_preds_rejects_invalid_camera_transform_shape():
pi = _make_pi({"aria_bimanual": {"extrinsics": {}}}, ["aria_bimanual"])
def test_prev_caps_number_of_past_frames():
    """A long window is truncated to PI._PREV_MAX_FRAMES encoded frames."""
    model = _make_prompt_pi(prev_state_in_prompt=True, history_len=5)
    # K=5 window -> 4 past frames, but only _PREV_MAX_FRAMES are encoded, each
    # as its own Prev<j> bin-string of width `dim`.
    dim = 2
    values = torch.arange(5 * dim, dtype=torch.float32)
    window = values.reshape(1, 5, dim) / 100.0
    batch = {"observations.state.ee_pose": window}
    frames = model._discretize_prev_states_for_sample(batch, 0)
    assert len(frames) == PI._PREV_MAX_FRAMES
    for frame in frames:
        assert len(frame.split()) == dim

with pytest.raises(TypeError) as exc_info:
pi.visualize_preds(
_make_predictions("aria_bimanual"), _make_batch("aria_bimanual")
)

assert "camera_transforms must be a CameraTransforms instance or a mapping" in str(
exc_info.value
)


def test_visualize_preds_uses_embodiment_specific_camera_transform(monkeypatch):
aria_transform = _make_transform("aria")
eva_transform = _make_transform("eva")
pi = _make_pi(
{"aria_bimanual": aria_transform, "eva_bimanual": eva_transform},
["aria_bimanual", "eva_bimanual"],
)

draw_calls = []

def fake_draw_actions(
im, ac_type, color, actions, extrinsics, intrinsics, arm="both", **kwargs
):
draw_calls.append(
{
"ac_type": ac_type,
"color": color,
"extrinsics": extrinsics,
"intrinsics": intrinsics,
"arm": arm,
"shape": tuple(actions.shape),
}
)
return im

monkeypatch.setattr(pi_module, "draw_actions", fake_draw_actions)

ims = pi.visualize_preds(
_make_predictions("aria_bimanual"), _make_batch("aria_bimanual")
)

assert ims.shape == (1, 4, 4, 3)
assert len(draw_calls) == 2
assert all(call["extrinsics"] is aria_transform.extrinsics for call in draw_calls)
assert all(call["intrinsics"] is aria_transform.intrinsics for call in draw_calls)
assert all(call["arm"] == "both" for call in draw_calls)
assert all(call["shape"] == (2, 6) for call in draw_calls)
assert all(
call["extrinsics"] is not eva_transform.extrinsics for call in draw_calls
)
def test_tokenize_truncation_warns_once():
    """A fully filled token window sets the warned flag, and it stays set."""
    window = 4
    model = object.__new__(PI)
    model.tokenizer_max_length = window
    model._prompt_trunc_warned = False

    def fake_tokenizer(prompts, padding, truncation, max_length, return_tensors):
        # No padding column -> the window is full -> truncation suspected.
        shape = (len(prompts), 4)
        ids = torch.ones(*shape, dtype=torch.long)
        mask = torch.ones(*shape, dtype=torch.long)
        return {"input_ids": ids, "attention_mask": mask}

    model.tokenizer = fake_tokenizer

    first = model._tokenize_prompts(["a"])
    assert model._prompt_trunc_warned is True
    assert first["tokenized_prompt"].shape == (1, 4)
    # token_loss_mask masks out the final (anchor) position.
    last_position = first["token_loss_mask"][0, -1]
    assert last_position.item() is False

    # A second call must not crash and must leave the flag set.
    model._tokenize_prompts(["b"])
    assert model._prompt_trunc_warned is True
Loading