From f0bd0eb068bd75cb0d00d0e67adb282bd850d95a Mon Sep 17 00:00:00 2001 From: ElmoPA Date: Tue, 2 Jun 2026 21:40:12 -0400 Subject: [PATCH] inspector: cKDTree-based KNN for fast neighbor lookup on click --- .gitignore | 1 + egomimic/eval/eval_latent.py | 91 +- egomimic/eval/latent_dataset.py | 17 +- egomimic/rldb/embodiment/human.py | 57 +- egomimic/rldb/zarr/zarr_dataset_multi.py | 5 + .../data_visualization/align_embodiment.py | 359 ++++++++ .../data_visualization/inspector_lib/app.py | 123 ++- .../inspector_lib/caches.py | 171 +++- .../inspector_lib/language.py | 5 +- .../inspector_lib/pair_rank.py | 173 ++++ .../data_visualization/inspector_lib/views.py | 803 ++++++++++++++++-- .../data_visualization/latent_inspector.py | 36 + 12 files changed, 1725 insertions(+), 116 deletions(-) create mode 100644 egomimic/scripts/data_visualization/align_embodiment.py create mode 100644 egomimic/scripts/data_visualization/inspector_lib/pair_rank.py diff --git a/.gitignore b/.gitignore index 4d617a414..7b121f1eb 100644 --- a/.gitignore +++ b/.gitignore @@ -49,6 +49,7 @@ datasets/ **/datasets/ apikey.txt slurm-*.out +*.out slurmoutputs/ *.log .inductor_cache/ diff --git a/egomimic/eval/eval_latent.py b/egomimic/eval/eval_latent.py index 2d663b3fc..7fb173e86 100644 --- a/egomimic/eval/eval_latent.py +++ b/egomimic/eval/eval_latent.py @@ -13,6 +13,11 @@ from egomimic.eval.eval_video import EvalVideo from egomimic.rldb.embodiment.embodiment import get_embodiment +from egomimic.utils.action_utils import ( + PI05_CARTESIAN_ACTION_ENCODING_LEGACY, + PI05_CARTESIAN_ACTION_ENCODING_NORM_ROT_6D, + PI05_CARTESIAN_ACTION_ENCODING_RAW_ROT_6D, +) logger = logging.getLogger(__name__) @@ -98,6 +103,7 @@ def __init__( self._layer_keys = {} # layer_name -> list[np.ndarray (B, S, D)] self._row_hashes = [] # one entry per sample (replicated by S at write time) self._row_embodiments = [] + self._row_frames = [] # source frame index per sample (None-safe) self._hook_handles = [] # per-step buffer; each layer is recorded only on its FIRST forward # call within a capture window (prefix pass for PaliGemma, first @@ -205,6 +211,22 @@ def _extract_hashes(_batch, batch_size, embodiment_name): return [str(v) for v in val.cpu().tolist()] return [str(val)] * batch_size + @staticmethod + def _extract_frame_indices(_batch, batch_size): + """Per-sample SOURCE frame index, surfaced by ZarrDataset.__getitem__. + Returns None when the batch predates this field, so the caller can + fall back to the legacy per-run counter.""" + val = _batch.get("frame_idx") + if val is None: + return None + if torch.is_tensor(val): + return [int(v) for v in val.cpu().tolist()] + if isinstance(val, np.ndarray): + return [int(v) for v in val.tolist()] + if isinstance(val, (list, tuple)): + return [int(v) for v in val] + return [int(val)] * batch_size + # ------------------------------------------------------------------ # Validation lifecycle # ------------------------------------------------------------------ @@ -213,6 +235,7 @@ def on_validation_start(self): self._layer_keys = {} self._row_hashes = [] self._row_embodiments = [] + self._row_frames = [] self._n_rows = 0 self._register_hooks() if self.trainer.is_global_zero: @@ -258,16 +281,46 @@ def compute_metrics_and_viz(self, batch): ) self._capture_active = False - # Mirror PI's post-processing for metrics + viz. + # sample_actions returns a static CUDA-graph buffer that the + # next embodiment's forward overwrites — clone before use + # (mirrors PI.forward_eval). + pred_actions = pred_actions.clone() + + # Mirror PI.forward_eval post-processing for metrics + viz, + # branching on the action encoding so the 6D-rotation models + # unpack the 20-D xyz+6D(+g) action (not the legacy 14-D ypr). ref = _batch[ac_key] B, T, D = ref.shape converter = algo.action_registry.get(embodiment_id, ac_key) - pred_actions_orig = converter.from32(pred_actions) - pred = pred_actions_orig[:, :T, :D] + action_encoding = getattr( + algo, "action_encoding", PI05_CARTESIAN_ACTION_ENCODING_LEGACY + ) predictions = OrderedDict() - predictions[ac_key] = pred - unnorm_actions = algo.norm_stats.unnormalize(predictions, embodiment_id) + if action_encoding == PI05_CARTESIAN_ACTION_ENCODING_RAW_ROT_6D: + pred_actions_orig = converter.from32_raw_rotation( + pred_actions, + stats=algo._action_stats(embodiment_id, ac_key), + norm_mode=algo.norm_stats.norm_mode, + unnormalize_non_rotation=True, + ) + unnorm_actions = {ac_key: pred_actions_orig[:, :T, :D]} + elif action_encoding == PI05_CARTESIAN_ACTION_ENCODING_NORM_ROT_6D: + pred_6d = converter.from32_norm_6d(pred_actions) + predictions[ac_key] = pred_6d[:, :T, :D] + unnorm_actions = algo.norm_stats.unnormalize( + predictions, embodiment_id + ) + elif action_encoding == PI05_CARTESIAN_ACTION_ENCODING_LEGACY: + pred_actions_orig = converter.from32(pred_actions) + predictions[ac_key] = pred_actions_orig[:, :T, :D] + unnorm_actions = algo.norm_stats.unnormalize( + predictions, embodiment_id + ) + else: + raise ValueError( + f"Unsupported PI0.5 action_encoding: {action_encoding!r}" + ) for k in unnorm_actions: unnorm_preds[f"{embodiment_name}_{k}"] = unnorm_actions[k] @@ -290,6 +343,10 @@ def compute_metrics_and_viz(self, batch): hashes = self._extract_hashes(_batch, B, embodiment_name) self._row_hashes.extend(hashes) self._row_embodiments.extend([embodiment_name] * B) + frames = self._extract_frame_indices(_batch, B) + # -1 sentinel marks "no source index in batch"; if ANY sample + # is sentinel the writer falls back to the per-run counter. + self._row_frames.extend(frames if frames is not None else [-1] * B) for layer_name, key_tensor in self._step_capture.items(): keys_bsd = key_tensor.to(torch.float32).cpu().numpy() self._layer_keys.setdefault(layer_name, []).append(keys_bsd) @@ -330,6 +387,7 @@ def on_validation_end(self): keys = keys_bsd.reshape(N * S, D) sample_hashes = self._row_hashes sample_embs = self._row_embodiments + sample_frames = self._row_frames if N != len(sample_hashes): n = min(N, len(sample_hashes)) logger.warning( @@ -342,14 +400,27 @@ def on_validation_end(self): keys = keys_bsd[:n].reshape(n * S, D) sample_hashes = sample_hashes[:n] sample_embs = sample_embs[:n] + sample_frames = sample_frames[:n] # Replicate per-sample metadata across the S tokens. - # frame_idx is the per-run sample index (0..N-1) — tokens of - # the same frame share it, so meanpool.py can group on - # (video_hash, frame_idx). token_idx is the position within - # the sequence, useful for token-type slicing later. + # frame_idx is the SOURCE frame index per sample (from + # ZarrDataset.__getitem__) so tokens of the same frame share it + # and the inspector can fetch the right image / annotation. + # Falls back to a per-run counter (0..N-1) only for legacy + # datasets that don't surface frame_idx (sentinel -1 present). hashes = [h for h in sample_hashes for _ in range(S)] embs = [e for e in sample_embs for _ in range(S)] - frame_idx = [i for i in range(len(sample_hashes)) for _ in range(S)] + use_source_frames = len(sample_frames) == len(sample_hashes) and all( + fi >= 0 for fi in sample_frames + ) + if use_source_frames: + frame_idx = [fi for fi in sample_frames for _ in range(S)] + else: + logger.warning( + "%s: source frame_idx unavailable for some/all samples; " + "falling back to per-run sample counter for frame_idx.", + layer_name, + ) + frame_idx = [i for i in range(len(sample_hashes)) for _ in range(S)] token_idx = list(range(S)) * len(sample_hashes) # Optional PCA: fit on raw `keys`, log explained variance, and diff --git a/egomimic/eval/latent_dataset.py b/egomimic/eval/latent_dataset.py index 161bbd09e..3e8787392 100644 --- a/egomimic/eval/latent_dataset.py +++ b/egomimic/eval/latent_dataset.py @@ -61,6 +61,14 @@ "_shuffle_random", "_shuffle_pairs", "_shuffle_custom", + # Tokenizer/annotation knobs declared at the data root for OmegaConf + # interpolation. Tokenization now lives on the algo side, so + # MultiDataModuleWrapper no longer accepts these as constructor args. + "use_tokenizer", + "model_name", + "sampling_mode", + "annotation_key", + "default_prompt", } ) @@ -72,6 +80,7 @@ def build_dataset( embodiment: str, resolver, hashes: list[str] | None = None, + exclude_hashes: list[str] | None = None, frames_per_episode: int | None = 128, stride: int | None = None, valid_ratio: float = 0.05, @@ -85,11 +94,17 @@ def build_dataset( Unknown `mode` values raise. Extra kwargs are NOT accepted (no silent swallowing) — typos in yaml fail loudly. """ + excl = tuple(str(h) for h in (exclude_hashes or [])) if mode == "random": lam = ( "lambda row: row['task'] == " - f"{task!r} and row['robot_name'] == {embodiment!r}" + f"{task!r} and row.get('embodiment') == {embodiment!r}" ) + # Drop known-corrupt episodes (e.g. zero-norm quaternions that crash the + # quat->ypr transform; the per-sample retry can't escape a fully-bad + # episode since its fallback pool is that episode's own frames). + if excl: + lam += f" and row['episode_hash'] not in {excl!r}" filters = DatasetFilter(filter_lambdas=[lam]) logger.info("[build_dataset] %s | random mode | filter=%s", embodiment, lam) return MultiDataset._from_resolver( diff --git a/egomimic/rldb/embodiment/human.py b/egomimic/rldb/embodiment/human.py index fe7e95a9d..7827ca7a1 100644 --- a/egomimic/rldb/embodiment/human.py +++ b/egomimic/rldb/embodiment/human.py @@ -31,11 +31,26 @@ class Human(Embodiment): # MANO 21-keypoint topology: 0=wrist, 1-4 thumb, 5-8 index, 9-12 middle, 13-16 ring, 17-20 pinky. # Subclasses with non-MANO conventions (e.g. Aria) override these. FINGER_EDGES = [ - (0, 1), (1, 2), (2, 3), (3, 4), # thumb - (0, 5), (5, 6), (6, 7), (7, 8), # index - (0, 9), (9, 10), (10, 11), (11, 12), # middle - (0, 13), (13, 14), (14, 15), (15, 16), # ring - (0, 17), (17, 18), (18, 19), (19, 20), # pinky + (0, 1), + (1, 2), + (2, 3), + (3, 4), # thumb + (0, 5), + (5, 6), + (6, 7), + (7, 8), # index + (0, 9), + (9, 10), + (10, 11), + (11, 12), # middle + (0, 13), + (13, 14), + (14, 15), + (15, 16), # ring + (0, 17), + (17, 18), + (18, 19), + (19, 20), # pinky ] FINGER_COLORS = { "thumb": (255, 100, 100), @@ -137,7 +152,7 @@ def get_transform_list( stride=cls.ACTION_STRIDE ) + [CartesianYPRToRot6D(action_key="actions_cartesian")] elif mode == "keypoints_headframe_ypr": - return _build_aria_keypoints_bimanual_transform_list( + return _build_human_keypoints_bimanual_transform_list( stride=cls.ACTION_STRIDE, is_quat=False ) if mode == "keypoints_headframe_quat": @@ -160,11 +175,25 @@ class Aria(Human): ACTION_STRIDE = 3 # Aria's 21-keypoint layout is NOT MANO: 0-4 are fingertips, 5 is the palm root. FINGER_EDGES = [ - (5, 6), (6, 7), (7, 0), # thumb - (5, 8), (8, 9), (9, 10), (10, 1), # index - (5, 11), (11, 12), (12, 13), (13, 2), # middle - (5, 14), (14, 15), (15, 16), (16, 3), # ring - (5, 17), (17, 18), (18, 19), (19, 4), # pinky + (5, 6), + (6, 7), + (7, 0), # thumb + (5, 8), + (8, 9), + (9, 10), + (10, 1), # index + (5, 11), + (11, 12), + (12, 13), + (13, 2), # middle + (5, 14), + (14, 15), + (15, 16), + (16, 3), # ring + (5, 17), + (17, 18), + (18, 19), + (19, 4), # pinky ] FINGER_EDGE_RANGES = [ ("thumb", 0, 3), @@ -937,7 +966,7 @@ def _build_human_keypoints_bimanual_transform_list( return transform_list -def _build_human_cartesian_revert_eef_frame_transform_list( +def _build_aria_cartesian_revert_eef_frame_transform_list( *, action_key: str = "actions_cartesian", obs_key: str = "observations.state.ee_pose", @@ -951,7 +980,7 @@ def _build_human_cartesian_revert_eef_frame_transform_list( ) -> list[Transform]: """Revert wrist-frame ARIA cartesian actions back to head (camera) frame. - Inverse of ``_build_human_cartesian_eef_frame_transform_list`` for viz: the + Inverse of ``_build_aria_cartesian_eef_frame_transform_list`` for viz: the action chunks live in each side's wrist frame, the proprio ee-poses live in headframe (= Aria camera frame). Re-composes ``target_headframe @ chunk_wristframe`` so action chunks are back in headframe / camera frame. @@ -1146,7 +1175,7 @@ def _build_aria_cartesian_eef_frame_transform_list( return transform_list -def _build_human_cartesian_bimanual_transform_list( +def _build_aria_cartesian_bimanual_transform_list( *, target_world: str = "obs_head_pose", target_world_ypr: str = "obs_head_pose_ypr", diff --git a/egomimic/rldb/zarr/zarr_dataset_multi.py b/egomimic/rldb/zarr/zarr_dataset_multi.py index 7c941a553..7db8843ff 100644 --- a/egomimic/rldb/zarr/zarr_dataset_multi.py +++ b/egomimic/rldb/zarr/zarr_dataset_multi.py @@ -1730,6 +1730,11 @@ def _next(reason: str, key: str = "") -> int: data["episode_hash"] = ( ep_name[:-5] if ep_name.endswith(".zarr") else ep_name ) + # The actual source frame served (after any JPEG-decode fallback, + # `idx` is the frame finally read). Surfaced parallel to + # episode_hash so latent capture can label rows with the true + # frame index instead of a per-run counter. + data["frame_idx"] = int(idx) _ = origin # preserved for symmetry with prior API return data diff --git a/egomimic/scripts/data_visualization/align_embodiment.py b/egomimic/scripts/data_visualization/align_embodiment.py new file mode 100644 index 000000000..4b58580c5 --- /dev/null +++ b/egomimic/scripts/data_visualization/align_embodiment.py @@ -0,0 +1,359 @@ +"""Embodiment-alignment sweep over captured latents (action-expert layers). + +Hypothesis under test: in the raw / PCA-50 latent space, a few directions +encode *embodiment* (aria vs eva) rather than task semantics. If we strip those +directions, the human and robot manifolds superimpose and ordinary KNN returns +cross-embodiment neighbors organically — no explicit cross-embodiment sampling. + +For each action-expert layer (`expert_layer_NN`) this script: + 1. loads `_keys.pt` (N, D) + the row-aligned `.csv` + (video_hash, embodiment, frame_idx, token_idx), + 2. attaches a frame-level *semantic* label = the language annotation spanning + that (hash, frame) — pulled from the zarr the same way the inspector does, + 3. applies several embodiment-removal transforms, + 4. scores each transform on two axes: + - embodiment removal : logreg 5-fold CV acc (want -> chance), + - semantic preservation: annotation kNN 5-fold CV acc (want stays high), + - cross-emb kNN rate : mean fraction of each point's k-NN that are the + OTHER embodiment (want -> population mix), + 5. optionally writes an inspector-loadable UMAP CSV per (layer, method). + +Methods + none standardized features, no removal (baseline) + mean_center subtract each embodiment's own mean (rank-1 domain offset) + zscore_emb per-embodiment mean+std standardization (1st+2nd order offset) + drop_pc:k PCA-50, drop the top-k embodiment-discriminative PCs (Fisher rank) + +Run (emimic venv, on a GPU node so cuML UMAP is available): + python egomimic/scripts/data_visualization/align_embodiment.py \ + --latent-dir logs/objgen6d_latent/.../latents/epoch_0 \ + --zarr-root /storage/.../datasets/pick_place_cotrain \ + --out-dir logs/objgen6d_latent/.../latents/epoch_0/aligned \ + --umap # also emit UMAP CSV/PNG per (layer, method) +""" + +from __future__ import annotations + +import argparse +import csv +import glob +import json +import os +from collections import Counter + +import numpy as np + +# sklearn is numpy-2 compatible in recent versions; all linear algebra here is +# small (subsampled), so CPU sklearn is fine. cuML is used only for UMAP. +from sklearn.decomposition import PCA +from sklearn.linear_model import LogisticRegression +from sklearn.model_selection import cross_val_score +from sklearn.neighbors import KNeighborsClassifier, NearestNeighbors +from sklearn.preprocessing import StandardScaler + +from egomimic.scripts.data_visualization.inspector_lib.language import ( + annotation_intervals, + lang_for_frame, +) + +DROP_KS = (1, 2, 3, 5, 10) # top-k embodiment PCs to drop for drop_pc sweep +PCA_DIM = 50 +KNN_K = 10 +CV = 5 +RNG = np.random.RandomState(0) + + +# ---------------------------------------------------------------------------- +# IO +# ---------------------------------------------------------------------------- +def list_expert_layers(latent_dir): + csvs = sorted(glob.glob(os.path.join(latent_dir, "expert_layer_*.csv"))) + # drop any _img/_lang/_combined/_tsne derivatives — bare expert layers only + return [ + c + for c in csvs + if os.path.basename(c)[: -len(".csv")].replace("expert_layer_", "").isdigit() + ] + + +def load_layer(csv_path): + import torch + + keys_path = csv_path[: -len(".csv")] + "_keys.pt" + if not os.path.isfile(keys_path): + return None + X = torch.load(keys_path, map_location="cpu").to(torch.float32).numpy() + hashes, embs, frames = [], [], [] + with open(csv_path, newline="") as f: + r = csv.DictReader(f) + for row in r: + hashes.append(row["video_hash"]) + embs.append(row["embodiment"]) + frames.append(int(row.get("frame_idx", -1))) + n = min(len(hashes), X.shape[0]) + return X[:n], np.array(hashes[:n]), np.array(embs[:n]), np.array(frames[:n]) + + +def annotation_labels(zarr_root, hashes, frames): + """Frame-level annotation string per row (cached per hash).""" + cache = {} + out = [] + for h, fr in zip(hashes, frames): + if h not in cache: + try: + cache[h] = annotation_intervals(zarr_root, h) + except Exception: + cache[h] = tuple() + out.append(lang_for_frame(cache[h], int(fr)) or "") + return np.array(out) + + +def balanced_subsample(embs, max_rows): + idx_by = {e: np.where(embs == e)[0] for e in np.unique(embs)} + per = max(1, max_rows // max(1, len(idx_by))) + keep = [] + for e, ix in idx_by.items(): + if len(ix) > per: + ix = RNG.choice(ix, per, replace=False) + keep.append(ix) + keep = np.concatenate(keep) + RNG.shuffle(keep) + return keep + + +# ---------------------------------------------------------------------------- +# Removal methods -> return cleaned feature matrix F (N, d) +# ---------------------------------------------------------------------------- +def method_none(Xs, y): + return Xs + + +def method_mean_center(Xs, y): + F = Xs.copy() + for e in np.unique(y): + m = Xs[y == e].mean(0) + F[y == e] -= m + return F + + +def method_zscore_emb(Xs, y): + F = Xs.copy() + for e in np.unique(y): + sl = y == e + m, s = Xs[sl].mean(0), Xs[sl].std(0) + 1e-6 + F[sl] = (Xs[sl] - m) / s + return F + + +def method_drop_pc(Xs, y, k, pca_scores=None, fisher_order=None): + """PCA-50 then zero the top-k embodiment-discriminative PCs (by Fisher).""" + order = fisher_order + Z = pca_scores + keep = order[k:] # drop the k most embodiment-discriminative PCs + return Z[:, keep] + + +# ---------------------------------------------------------------------------- +# Metrics +# ---------------------------------------------------------------------------- +def emb_decode_acc(F, y_emb): + return float( + cross_val_score( + LogisticRegression(max_iter=2000), F, y_emb, cv=CV, n_jobs=-1 + ).mean() + ) + + +def cross_emb_knn_rate(F, y_emb, k=KNN_K): + nn = NearestNeighbors(n_neighbors=k + 1).fit(F) + _, idx = nn.kneighbors(F) + idx = idx[:, 1:] # drop self + other = (y_emb[idx] != y_emb[:, None]).mean() + return float(other) + + +def semantic_knn_acc(F, y_sem, k=KNN_K): + # keep classes with enough support for CV + cnt = Counter(y_sem) + good = {c for c, n in cnt.items() if n >= max(CV, k + 1) and c != ""} + if len(good) < 2: + return float("nan"), 0, 0 + m = np.array([s in good for s in y_sem]) + Fg, yg = F[m], y_sem[m] + shared_note = len(good) + acc = float( + cross_val_score( + KNeighborsClassifier(n_neighbors=k), Fg, yg, cv=CV, n_jobs=-1 + ).mean() + ) + return acc, int(m.sum()), shared_note + + +def shared_annotation_classes(y_sem, y_emb): + by = {e: set(y_sem[y_emb == e]) for e in np.unique(y_emb)} + sets = [s - {""} for s in by.values()] + return len(set.intersection(*sets)) if sets else 0 + + +# ---------------------------------------------------------------------------- +# UMAP (cuML) viz output for the inspector +# ---------------------------------------------------------------------------- +def write_umap_csv(out_csv, F, hashes, embs, frames): + try: + from cuml.manifold import UMAP as cuUMAP + + xyz = cuUMAP(n_components=3, random_state=0).fit_transform(F) + xyz = np.asarray(xyz) + except Exception as e: # noqa: BLE001 + print(f" [umap] skipped ({type(e).__name__}: {e})") + return + os.makedirs(os.path.dirname(out_csv), exist_ok=True) + with open(out_csv, "w", newline="") as f: + w = csv.writer(f) + w.writerow( + [ + "video_hash", + "embodiment", + "frame_idx", + "token_idx", + "umap_x", + "umap_y", + "umap_z", + ] + ) + for i in range(F.shape[0]): + w.writerow( + [ + hashes[i], + embs[i], + int(frames[i]), + 0, + float(xyz[i, 0]), + float(xyz[i, 1]), + float(xyz[i, 2]), + ] + ) + + +# ---------------------------------------------------------------------------- +def run_layer(csv_path, zarr_root, out_dir, max_rows, do_umap): + name = os.path.basename(csv_path)[: -len(".csv")] + loaded = load_layer(csv_path) + if loaded is None: + print(f"[{name}] no keys.pt — skip") + return [] + X, hashes, embs, frames = loaded + sub = balanced_subsample(embs, max_rows) + X, hashes, embs, frames = X[sub], hashes[sub], embs[sub], frames[sub] + y_sem = annotation_labels(zarr_root, hashes, frames) + + Xs = StandardScaler().fit_transform(X) + pca = PCA(n_components=min(PCA_DIM, Xs.shape[1])).fit(Xs) + Z = pca.transform(Xs) + # Fisher ratio of each PC between the two embodiments -> drop order + classes = np.unique(embs) + fisher = np.zeros(Z.shape[1]) + if len(classes) == 2: + a, b = (embs == classes[0]), (embs == classes[1]) + num = (Z[a].mean(0) - Z[b].mean(0)) ** 2 + den = Z[a].var(0) + Z[b].var(0) + 1e-9 + fisher = num / den + fisher_order = np.argsort(-fisher) # most embodiment-discriminative first + + chance = max(Counter(embs).values()) / len(embs) + pop_other = ( + float((embs != classes[0]).mean()) if len(classes) == 2 else float("nan") + ) + shared = shared_annotation_classes(y_sem, embs) + + methods = { + "none": method_none(Xs, embs), + "mean_center": method_mean_center(Xs, embs), + "zscore_emb": method_zscore_emb(Xs, embs), + } + for k in DROP_KS: + methods[f"drop_pc:{k}"] = method_drop_pc( + Xs, embs, k, pca_scores=Z, fisher_order=fisher_order + ) + + rows = [] + print( + f"\n[{name}] N={len(embs)} emb-chance={chance:.2f} " + f"pop_other={pop_other:.2f} annot_classes(shared 2-emb)={shared} " + f"top-PC-Fisher={fisher[fisher_order[0]]:.3f}" + ) + print(f" {'method':<14}{'emb_acc':>9}{'crossKNN':>10}{'sem_acc':>9}{'sem_n':>7}") + for mname, F in methods.items(): + ea = emb_decode_acc(F, embs) + ck = cross_emb_knn_rate(F, embs) + sa, sn, _ = semantic_knn_acc(F, y_sem) + print( + f" {mname:<14}{ea:>9.3f}{ck:>10.3f}" + f"{(sa if sa==sa else float('nan')):>9.3f}{sn:>7d}" + ) + rows.append( + dict( + layer=name, + method=mname, + n=len(embs), + emb_chance=round(chance, 4), + emb_acc=round(ea, 4), + cross_emb_knn=round(ck, 4), + pop_other_frac=round(pop_other, 4), + sem_knn_acc=(round(sa, 4) if sa == sa else None), + sem_n=sn, + shared_annot_classes=shared, + ) + ) + if do_umap and mname in ("none", "mean_center", "drop_pc:3", "drop_pc:10"): + write_umap_csv( + os.path.join(out_dir, name, mname.replace(":", "_"), f"{name}.csv"), + F, + hashes, + embs, + frames, + ) + return rows + + +def main(): + ap = argparse.ArgumentParser(description=__doc__) + ap.add_argument("--latent-dir", required=True) + ap.add_argument("--zarr-root", required=True) + ap.add_argument("--out-dir", default=None) + ap.add_argument("--max-rows", type=int, default=20000) + ap.add_argument( + "--layers", + default=None, + help="comma list of layer indices, e.g. 0,8,17 (default: all)", + ) + ap.add_argument("--umap", action="store_true") + args = ap.parse_args() + + out_dir = args.out_dir or os.path.join(args.latent_dir, "aligned") + os.makedirs(out_dir, exist_ok=True) + layers = list_expert_layers(args.latent_dir) + if args.layers: + want = {f"{int(i):02d}" for i in args.layers.split(",")} + layers = [c for c in layers if os.path.basename(c).split("_")[-1][:2] in want] + if not layers: + raise SystemExit(f"No expert_layer_*.csv in {args.latent_dir}") + print(f"Sweeping {len(layers)} action-expert layers -> {out_dir}") + + all_rows = [] + for c in layers: + all_rows += run_layer(c, args.zarr_root, out_dir, args.max_rows, args.umap) + + with open(os.path.join(out_dir, "alignment_metrics.json"), "w") as f: + json.dump(all_rows, f, indent=2) + # also a flat CSV for quick eyeballing + if all_rows: + with open(os.path.join(out_dir, "alignment_metrics.csv"), "w", newline="") as f: + w = csv.DictWriter(f, fieldnames=list(all_rows[0].keys())) + w.writeheader() + w.writerows(all_rows) + print(f"\nWrote {out_dir}/alignment_metrics.{{json,csv}}") + + +if __name__ == "__main__": + main() diff --git a/egomimic/scripts/data_visualization/inspector_lib/app.py b/egomimic/scripts/data_visualization/inspector_lib/app.py index ad8c52687..efae4c539 100644 --- a/egomimic/scripts/data_visualization/inspector_lib/app.py +++ b/egomimic/scripts/data_visualization/inspector_lib/app.py @@ -304,6 +304,83 @@ def build_app( "marginTop": "6px", }, ), + html.Div( + "Neighbor pairs", + style={**LABEL_STYLE, "marginTop": "12px"}, + ), + dcc.RadioItems( + id="browser_knn_pairs", + options=[ + { + "label": " Cross-embodiment", + "value": "cross", + }, + { + "label": " Same embodiment", + "value": "same", + }, + ], + value="cross", + labelStyle={ + "display": "block", + "fontSize": "13px", + "marginBottom": "4px", + "cursor": "pointer", + }, + ), + html.Div( + "Same-embodiment skips the precomputed " + "KNN (cross-only) and excludes the " + "clicked frame's own action (its " + "annotation interval; whole recording " + "if unannotated).", + style={ + "fontSize": "10px", + "color": MUTED, + "marginTop": "6px", + }, + ), + ], + ), + # PCA component removal: drop the top-K PCA-50 dims + # most discriminative between aria and eva + # embodiments (Fisher score). Affects BOTH views: + # scatter recomputes UMAP live on the filtered + # features (click Apply), and the KNN lists + # (scatter pane + browser) measure distances in + # the reduced space. + html.Div( + style=CARD_STYLE, + children=[ + html.Div( + "PCA removal (aria vs eva Fisher)", + style=LABEL_STYLE, + ), + dcc.Slider( + id="pca_drop_k", + min=0, + max=50, + step=5, + value=0, + marks={i: str(i) for i in range(0, 51, 10)}, + tooltip={ + "placement": "bottom", + "always_visible": False, + }, + ), + html.Div( + "Drops the top-K of 50 PCA dims ranked " + "by aria↔eva Fisher score. Scatter: " + "live UMAP on the filtered features " + "(click Apply; first run fits PCA + " + "UMAP — slow). KNN lists use the " + "reduced space in both views.", + style={ + "fontSize": "10px", + "color": MUTED, + "marginTop": "6px", + }, + ), ], ), # Scatter-only controls: hidden when view_mode='browser'. @@ -463,6 +540,35 @@ def build_app( "cursor": "pointer", }, ), + html.Div( + "Browse order", + style={**LABEL_STYLE, "marginTop": "12px"}, + ), + dcc.Checklist( + id="browser_shuffle", + options=[ + { + "label": " Shuffle frames", + "value": "on", + } + ], + value=[], + labelStyle={ + "fontSize": "13px", + "cursor": "pointer", + }, + ), + html.Div( + "Fixed-seed permutation of the whole " + "list — paging via Load more stays " + "stable. Off = round-robin across " + "recordings, frames in temporal order.", + style={ + "fontSize": "10px", + "color": MUTED, + "marginTop": "4px", + }, + ), html.Div( "Exclude annotation substring (case-insensitive):", style={ @@ -904,10 +1010,21 @@ def build_app( dash.Input("layer", "value"), dash.Input("view_mode", "value"), dash.Input("browser_knn_space", "value"), + dash.Input("pca_drop_k", "value"), + dash.Input("browser_knn_pairs", "value"), dash.Input("browser_nav_stack", "data"), dash.Input("nav_stack", "data"), ) - def _state_to_url(run, layer, view_mode, knn_space, browser_nav, scatter_nav): + def _state_to_url( + run, + layer, + view_mode, + knn_space, + pca_drop_k, + knn_pairs, + browser_nav, + scatter_nav, + ): params: dict[str, str] = {} if run: params["run"] = run @@ -917,6 +1034,10 @@ def _state_to_url(run, layer, view_mode, knn_space, browser_nav, scatter_nav): params["view"] = view_mode if knn_space: params["knn"] = knn_space + if pca_drop_k: + params["pcadrop"] = str(int(pca_drop_k)) + if knn_pairs: + params["pairs"] = knn_pairs # The most recent clicked frame is the tail of whichever nav stack # belongs to the active view. active_nav = (browser_nav if view_mode == "browser" else scatter_nav) or [] diff --git a/egomimic/scripts/data_visualization/inspector_lib/caches.py b/egomimic/scripts/data_visualization/inspector_lib/caches.py index ecefb0aa1..46cd0c19a 100644 --- a/egomimic/scripts/data_visualization/inspector_lib/caches.py +++ b/egomimic/scripts/data_visualization/inspector_lib/caches.py @@ -36,15 +36,27 @@ def __init__( keys_max: int = 2, pca_max: int = 4, knn_max: int = 8, + drop_max: int = 2, ): self.full_max = full_max self.keys_max = keys_max self.pca_max = pca_max self.knn_max = knn_max + self.drop_max = drop_max self.full_cache: "OrderedDict[tuple[str, str], dict]" = OrderedDict() self.keys_cache: "OrderedDict[tuple[str, str], np.ndarray]" = OrderedDict() - self.pca_cache: "OrderedDict[tuple[str, str, int], np.ndarray]" = OrderedDict() + # (run, layer, n_components) -> {'feats', 'components', 'mean', 'evr'} + self.pca_cache: "OrderedDict[tuple[str, str, int], dict]" = OrderedDict() self.knn_cache: "OrderedDict[tuple[str, str], dict | None]" = OrderedDict() + # (run, layer, n_components) -> {'scores', 'order', 'group_a', 'group_b'} + self.fisher_cache: "OrderedDict[tuple[str, str, int], dict | None]" = ( + OrderedDict() + ) + # (run, layer, n_components, drop_k) -> feats with dropped columns. + # These are full-N float32 copies — keep the cap tight. + self.drop_cache: "OrderedDict[tuple[str, str, int, int], np.ndarray]" = ( + OrderedDict() + ) def layers_for(self, run_path: str) -> list[str]: paths = list_layer_csvs(run_path) @@ -319,10 +331,11 @@ def load_knn(self, run_path: str, layer: str): ) return self.knn_cache[key] - def pca_features(self, run_path: str, layer: str, n_components: int = 50): - """Fit PCA(n_components) on the layer's raw keys and return the - transformed (N, n_components) array. Cached. Returns None if raw - keys aren't available.""" + def _pca_entry(self, run_path: str, layer: str, n_components: int = 50): + """Fit PCA(n_components) on the layer's raw keys. Returns a cached + dict with 'feats' (N, k) float32, 'components' (k, D) float32, + 'mean' (D,) float32, 'evr' (k,) float32 — or None if raw keys + aren't available / the fit fails.""" cache_key = (run_path, layer, n_components) if cache_key in self.pca_cache: self.pca_cache.move_to_end(cache_key) @@ -342,15 +355,20 @@ def pca_features(self, run_path: str, layer: str, n_components: int = 50): "PCA failed for %s | %s: %s", os.path.basename(run_path), layer, e ) return None - feats = feats.astype(np.float32) + entry = { + "feats": feats.astype(np.float32), + "components": np.asarray(pca.components_, dtype=np.float32), + "mean": np.asarray(pca.mean_, dtype=np.float32), + "evr": np.asarray(pca.explained_variance_ratio_, dtype=np.float32), + } logger.info( "PCA features for %s | %s shape=%s (var explained=%.3f)", os.path.basename(run_path), layer, - feats.shape, - float(getattr(pca, "explained_variance_ratio_", np.zeros(1)).sum()), + entry["feats"].shape, + float(entry["evr"].sum()), ) - self.pca_cache[cache_key] = feats + self.pca_cache[cache_key] = entry while len(self.pca_cache) > self.pca_max: evicted_key, _ = self.pca_cache.popitem(last=False) logger.info( @@ -359,4 +377,137 @@ def pca_features(self, run_path: str, layer: str, n_components: int = 50): evicted_key[1], evicted_key[2], ) - return self.pca_cache[cache_key] + return entry + + def pca_features(self, run_path: str, layer: str, n_components: int = 50): + """Fit PCA(n_components) on the layer's raw keys and return the + transformed (N, n_components) array. Cached. Returns None if raw + keys aren't available.""" + entry = self._pca_entry(run_path, layer, n_components) + return None if entry is None else entry["feats"] + + def fisher_ranking(self, run_path: str, layer: str, n_components: int = 50): + """Rank PCA dims by Fisher score between the aria-* and eva-* + embodiment groups: F_d = (mu_a - mu_b)^2 / (var_a + var_b). Returns + a cached dict with 'scores' (k,) float32, 'order' (k,) int64 + (descending score), 'group_a'/'group_b' labels — or None if PCA + features are unavailable, misaligned with the CSV rows, or fewer + than two embodiment groups exist.""" + cache_key = (run_path, layer, n_components) + if cache_key in self.fisher_cache: + self.fisher_cache.move_to_end(cache_key) + return self.fisher_cache[cache_key] + + result = None + entry = self._pca_entry(run_path, layer, n_components) + if entry is not None: + feats = entry["feats"] + embs = np.asarray(self.load(run_path, layer)["embs"]) + if feats.shape[0] != len(embs): + logger.warning( + "fisher_ranking: PCA rows (%d) mismatch CSV rows (%d) for %s | %s", + feats.shape[0], + len(embs), + os.path.basename(run_path), + layer, + ) + else: + # Embodiment names look like 'aria_bimanual' / 'eva_right_arm'. + # Prefer the aria-vs-eva split; fall back to the two largest + # embodiment groups when those substrings are absent. + lower = np.array([str(e).lower() for e in embs]) + mask_a = np.char.find(lower.astype(str), "aria") >= 0 + mask_b = np.char.find(lower.astype(str), "eva") >= 0 + group_a, group_b = "aria*", "eva*" + if not (mask_a.any() and mask_b.any()): + uniq, counts = np.unique(lower, return_counts=True) + if len(uniq) >= 2: + top2 = uniq[np.argsort(counts)[::-1][:2]] + group_a, group_b = str(top2[0]), str(top2[1]) + mask_a = lower == top2[0] + mask_b = lower == top2[1] + else: + mask_a = mask_b = None + if mask_a is not None and mask_a.any() and mask_b.any(): + fa = feats[mask_a] + fb = feats[mask_b] + mu_diff = fa.mean(axis=0) - fb.mean(axis=0) + denom = fa.var(axis=0) + fb.var(axis=0) + 1e-12 + scores = (mu_diff * mu_diff / denom).astype(np.float32) + order = np.argsort(scores)[::-1].copy() + result = { + "scores": scores, + "order": order, + "group_a": group_a, + "group_b": group_b, + } + logger.info( + "Fisher ranking for %s | %s (%s n=%d vs %s n=%d): " + "top-5 dims=%s scores=%s", + os.path.basename(run_path), + layer, + group_a, + int(mask_a.sum()), + group_b, + int(mask_b.sum()), + order[:5].tolist(), + [f"{v:.3g}" for v in scores[order[:5]]], + ) + else: + logger.warning( + "fisher_ranking: <2 embodiment groups for %s | %s", + os.path.basename(run_path), + layer, + ) + + self.fisher_cache[cache_key] = result + while len(self.fisher_cache) > self.pca_max: + self.fisher_cache.popitem(last=False) + return result + + def pca_features_dropped( + self, run_path: str, layer: str, n_components: int = 50, drop_k: int = 0 + ): + """PCA features with the top-`drop_k` Fisher-ranked dims removed + (columns deleted, so distances live in the remaining subspace). + Cached (full-N copies — tight LRU). Returns None when PCA features + or the Fisher ranking are unavailable.""" + if drop_k <= 0: + return self.pca_features(run_path, layer, n_components) + cache_key = (run_path, layer, n_components, int(drop_k)) + if cache_key in self.drop_cache: + self.drop_cache.move_to_end(cache_key) + return self.drop_cache[cache_key] + feats = self.pca_features(run_path, layer, n_components) + ranking = self.fisher_ranking(run_path, layer, n_components) + if feats is None or ranking is None: + return None + drop = ranking["order"][: min(int(drop_k), feats.shape[1])] + out = np.delete(feats, drop, axis=1) + logger.info( + "PCA-dropped features for %s | %s: removed top-%d Fisher dims -> shape=%s", + os.path.basename(run_path), + layer, + len(drop), + out.shape, + ) + self.drop_cache[cache_key] = out + while len(self.drop_cache) > self.drop_max: + self.drop_cache.popitem(last=False) + return out + + def fisher_dropped_directions( + self, run_path: str, layer: str, n_components: int = 50, drop_k: int = 0 + ): + """The (drop_k, D) raw-space directions of the top-`drop_k` + Fisher-ranked PCA components — used to project embodiment- + discriminative directions out of raw-key distances. Returns None + when unavailable.""" + if drop_k <= 0: + return None + entry = self._pca_entry(run_path, layer, n_components) + ranking = self.fisher_ranking(run_path, layer, n_components) + if entry is None or ranking is None: + return None + drop = ranking["order"][: min(int(drop_k), entry["components"].shape[0])] + return entry["components"][drop] diff --git a/egomimic/scripts/data_visualization/inspector_lib/language.py b/egomimic/scripts/data_visualization/inspector_lib/language.py index f1b439c60..a9ce32466 100644 --- a/egomimic/scripts/data_visualization/inspector_lib/language.py +++ b/egomimic/scripts/data_visualization/inspector_lib/language.py @@ -16,7 +16,10 @@ logger = logging.getLogger(__name__) -@lru_cache(maxsize=512) +# Sized to hold every episode of a large run at once (full-cotrain latent +# runs span >1000 recordings) — eviction thrash here means re-opening zarrs +# over NFS on every click. +@lru_cache(maxsize=4096) def annotation_intervals( zarr_root: str, video_hash: str ) -> tuple[tuple[int, int, str, str], ...]: diff --git a/egomimic/scripts/data_visualization/inspector_lib/pair_rank.py b/egomimic/scripts/data_visualization/inspector_lib/pair_rank.py new file mode 100644 index 000000000..afbef1537 --- /dev/null +++ b/egomimic/scripts/data_visualization/inspector_lib/pair_rank.py @@ -0,0 +1,173 @@ +"""Perfect-pair retrieval report: average rank ("place") of the paired action. + +For each layer CSV in a run dir, sample N rows; for every sampled row, rank all +opposite-embodiment ACTION instances (recording + annotation interval, the same +identity the KNN dedupe uses) by their min row distance to the query, and find +the place of the query's true paired action (1 = nearest). Reports avg/median +place and top-1 rate per layer. + +Distances are computed in one of the coordinate spaces baked into the CSVs +(umap / pca_umap / tsne2d) — the same spaces the scatter view plots — so the +report matches what the eye sees in the viewer without the multi-GB keys.pt +reads a raw-space sweep over every layer would need. +""" + +from __future__ import annotations + +import glob +import logging +import os +import time + +import numpy as np +import pandas as pd + +from .views import _find_pair_action, _row_action_ids + +logger = logging.getLogger(__name__) + +_SPACE_COLS = { + "umap": ["umap_x", "umap_y", "umap_z"], + "pca_umap": ["pca_umap_x", "pca_umap_y", "pca_umap_z"], + "tsne2d": ["tsne2d_x", "tsne2d_y"], +} + +_META_COLS = ["video_hash", "frame_idx", "token_idx", "embodiment"] + + +def _layer_places(zarr_root, df, space_cols, n_samples, seed): + """Per-layer core: returns (places list, n_actions_opposite_avg).""" + data = { + "hashes": df["video_hash"].to_numpy(str), + "frame_idx": df["frame_idx"].to_numpy(np.int64), + "embs": df["embodiment"].to_numpy(str), + } + coords = df[space_cols].to_numpy(np.float32) + ids = _row_action_ids(zarr_root, data) + embs = data["embs"] + + # action id -> (representative row, embodiment) + _, first_pos = np.unique(ids, return_index=True) + rep_row = {int(ids[i]): int(i) for i in first_pos} + + # action id -> paired opposite-embodiment action id (None when no twin) + pair_of: dict[int, int | None] = {} + for aid, ridx in rep_row.items(): + res = _find_pair_action( + zarr_root, + data, + data["hashes"][ridx], + int(data["frame_idx"][ridx]), + str(embs[ridx]), + ) + if res is None: + pair_of[aid] = None + continue + twin_h, s2, e2, _prompt = res + twin_rows = np.where( + (data["hashes"] == twin_h) + & (data["frame_idx"] >= s2) + & (data["frame_idx"] < e2) + )[0] + pair_of[aid] = int(ids[twin_rows[0]]) if twin_rows.size else None + + paired_aid_per_row = np.array( + [pair_of.get(int(a)) if pair_of.get(int(a)) is not None else -1 for a in ids], + dtype=np.int64, + ) + eligible = np.where(paired_aid_per_row >= 0)[0] + if eligible.size == 0: + return [], 0.0 + + rng = np.random.default_rng(seed) + take = min(n_samples, eligible.size) + qidx = rng.choice(eligible, size=take, replace=False) + + places: list[int] = [] + n_opp_actions_seen = [] + for emb in np.unique(embs[qidx]): + q_e = qidx[embs[qidx] == emb] + opp_rows = np.where(embs != emb)[0] + if opp_rows.size == 0: + continue + # Sort opposite rows by action id so per-action mins reduce with + # one reduceat per query instead of a python loop over actions. + order = np.argsort(ids[opp_rows], kind="stable") + opp_sorted = opp_rows[order] + opp_ids_sorted = ids[opp_sorted] + opp_action_ids, starts = np.unique(opp_ids_sorted, return_index=True) + opp_coords = coords[opp_sorted] + n_opp_actions_seen.append(len(opp_action_ids)) + + paired_pos = np.searchsorted(opp_action_ids, paired_aid_per_row[q_e]) + # Guard: paired action must exist on the opposite side of this layer. + valid = (paired_pos < len(opp_action_ids)) & ( + opp_action_ids[np.minimum(paired_pos, len(opp_action_ids) - 1)] + == paired_aid_per_row[q_e] + ) + q_e, paired_pos = q_e[valid], paired_pos[valid] + + # ||q-x||² = ||q||² + ||x||² − 2 q·x, chunked — the broadcasted + # (chunk, M, 3) diff tensor peaked at ~2 GB on the 2.7M-row + # paligemma layers and got the report OOM-killed on login nodes. + opp_sq = np.einsum("nd,nd->n", opp_coords, opp_coords) + CHUNK = 32 + for c0 in range(0, len(q_e), CHUNK): + qc = q_e[c0 : c0 + CHUNK] + pc = paired_pos[c0 : c0 + CHUNK] + qcoords = coords[qc] + q_sq = np.einsum("qd,qd->q", qcoords, qcoords) + d2 = q_sq[:, None] + opp_sq[None, :] - 2.0 * (qcoords @ opp_coords.T) + mins = np.minimum.reduceat(d2, starts, axis=1) # (c, n_opp_actions) + paired_min = mins[np.arange(len(qc)), pc] + place = 1 + (mins < paired_min[:, None]).sum(axis=1) + places.extend(place.tolist()) + return places, (float(np.mean(n_opp_actions_seen)) if n_opp_actions_seen else 0.0) + + +def pair_rank_report( + run_dir: str, + zarr_root: str, + n_samples: int = 1000, + space: str = "umap", + seed: int = 0, +) -> None: + """Print the per-layer average place of the perfect pair.""" + if space not in _SPACE_COLS: + raise ValueError(f"space must be one of {sorted(_SPACE_COLS)}, got {space!r}") + space_cols = _SPACE_COLS[space] + csvs = sorted(glob.glob(os.path.join(run_dir, "*.csv"))) + if not csvs: + raise SystemExit(f"No layer CSVs in {run_dir}") + + print( + f"\nPerfect-pair retrieval report — avg place of the paired action " + f"(1 = nearest) among opposite-embodiment actions\n" + f"run: {run_dir}\nspace: {space} | samples/layer: {n_samples} | seed: {seed}\n" + ) + header = f"{'layer':38s} {'n_eval':>6s} {'avg':>7s} {'median':>7s} {'top1%':>6s} {'#actions':>8s}" + print(header) + print("-" * len(header)) + for path in csvs: + layer = os.path.basename(path)[:-4] + t0 = time.time() + try: + df = pd.read_csv(path, usecols=_META_COLS + space_cols) + except ValueError: + print(f"{layer:38s} (missing {'/'.join(space_cols)} — skipped)") + continue + places, n_opp = _layer_places(zarr_root, df, space_cols, n_samples, seed) + if not places: + print(f"{layer:38s} (no paired actions found — skipped)") + continue + arr = np.asarray(places) + print( + f"{layer:38s} {len(arr):>6d} {arr.mean():>7.2f} " + f"{np.median(arr):>7.1f} {100.0 * (arr == 1).mean():>5.1f}% " + f"{n_opp:>8.0f} [{time.time() - t0:.1f}s]" + ) + print( + "\nplace = rank of the true paired action when all opposite-embodiment " + "actions are sorted by min row distance to the sampled row.\n" + "Chance level = (#actions + 1) / 2." + ) diff --git a/egomimic/scripts/data_visualization/inspector_lib/views.py b/egomimic/scripts/data_visualization/inspector_lib/views.py index 9ea930cdd..5fa07d005 100644 --- a/egomimic/scripts/data_visualization/inspector_lib/views.py +++ b/egomimic/scripts/data_visualization/inspector_lib/views.py @@ -9,6 +9,7 @@ import logging import os +from collections import OrderedDict import numpy as np @@ -18,6 +19,7 @@ from .language import ( all_lang_concat_lower, annotation_intervals, + interval_for_frame, load_language_prompt, ) @@ -88,24 +90,195 @@ def _get_langs_per_row(data: dict, zarr_root: str) -> np.ndarray: def _filter_data_by_lang(data: dict, excludes: list[str], zarr_root: str): - """Return a shallow-copy of `data` with all per-row arrays masked - down to rows whose language doesn't contain any of `excludes`. - If `excludes` is empty, returns `data` unchanged.""" + """Return (filtered_data, n_dropped, keep_mask). `filtered_data` is a + shallow-copy of `data` with all per-row arrays masked down to rows + whose language doesn't contain any of `excludes`; `keep_mask` is the + boolean row mask (None when nothing was filtered) so callers can map + filtered row indices back to original CSV rows. If `excludes` is + empty, returns `data` unchanged.""" if not excludes: - return data, 0 + return data, 0, None langs = _get_langs_per_row(data, zarr_root) keep = np.ones(len(langs), dtype=bool) for sub in excludes: keep &= np.array([sub not in (text or "") for text in langs]) if keep.all(): - return data, 0 + return data, 0, None out = {} for k, v in data.items(): if isinstance(v, (np.ndarray, LazyStringArray)) and len(v) == len(keep): out[k] = v[keep] else: out[k] = v - return out, int((~keep).sum()) + return out, int((~keep).sum()), keep + + +def _action_key(zarr_root: str, video_hash: str, frame_idx: int): + """Identity of the action instance a row belongs to: the recording plus + the first annotation interval covering the frame. Rows with no covering + interval (or unreadable annotations) collapse into one whole-recording + action — same fallback as the same-embodiment action exclusion.""" + for s, e, *_ in annotation_intervals(zarr_root, str(video_hash)): + if s <= frame_idx < e: + return (str(video_hash), s, e) + return (str(video_hash), -1, -1) + + +# (cache_key, n_rows) -> (N,) int64 action ids. Small: a few layers' worth of +# int arrays. Keyed per (run, layer) so LayerStore cache turnover can't serve +# stale rows. +_ACTION_IDS_CACHE: "OrderedDict[tuple, np.ndarray]" = OrderedDict() +_ACTION_IDS_CACHE_MAX = 8 + + +def _row_action_ids(zarr_root: str, data: dict, cache_key: tuple | None = None): + """(N,) int array assigning every row its action-instance id (recording + + first covering annotation interval; whole recording when uncovered). + + Built vectorized per unique recording with the intervals prefetched in + parallel — per-row python/zarr lookups made click latency scale with both + row count and episode count on the network FS.""" + n = len(np.asarray(data["frame_idx"])) + key = (cache_key, n) if cache_key is not None else None + if key is not None and key in _ACTION_IDS_CACHE: + _ACTION_IDS_CACHE.move_to_end(key) + return _ACTION_IDS_CACHE[key] + + hashes = data["hashes"] + frames = np.asarray(data["frame_idx"]) + uniq = sorted({str(x) for x in hashes}) + + # Warm the annotation_intervals lru in parallel (NFS-bound, like + # _populate_browser_list does). + from concurrent.futures import ThreadPoolExecutor + + with ThreadPoolExecutor(max_workers=min(16, max(1, len(uniq)))) as ex: + list(ex.map(lambda h: annotation_intervals(zarr_root, h), uniq)) + + ids = np.full(n, -1, dtype=np.int64) + next_id = 0 + for h in uniq: + hmask = np.asarray(hashes == h) + if not hmask.any(): + continue + assigned = np.zeros(n, dtype=bool) + spans = [] + for s, e, *_ in annotation_intervals(zarr_root, h): + if (s, e) not in spans: + spans.append((s, e)) + # Process in list order so overlapping spans resolve to the first + # covering interval — same semantics as _action_key. + for s, e in spans: + sel = hmask & ~assigned & (frames >= s) & (frames < e) + if sel.any(): + ids[sel] = next_id + assigned |= sel + next_id += 1 + rest = hmask & ~assigned + if rest.any(): + ids[rest] = next_id # whole-recording fallback action + next_id += 1 + + if key is not None: + _ACTION_IDS_CACHE[key] = ids + while len(_ACTION_IDS_CACHE) > _ACTION_IDS_CACHE_MAX: + _ACTION_IDS_CACHE.popitem(last=False) + return ids + + +def _dedupe_neighbors_by_action( + zarr_root: str, data: dict, candidates, k: int, cache_key: tuple | None = None +): + """Keep only the nearest row per action instance, up to k rows. + `candidates` is an iterable of (distance, row_idx) pairs already sorted + by ascending distance; consumed lazily, so callers can pass a generator + over a full argsort without materializing it.""" + ids = _row_action_ids(zarr_root, data, cache_key=cache_key) + seen: set = set() + out: list[tuple[float, int]] = [] + for dist, ridx in candidates: + a = int(ids[ridx]) + if a in seen: + continue + seen.add(a) + out.append((dist, ridx)) + if len(out) >= k: + break + return out + + +def _find_pair_action( + zarr_root: str, data: dict, video_hash: str, frame_idx: int, src_emb: str +): + """Locate the matching action in the paired opposite-embodiment episode. + + Perfect-pair episodes carry (near-)identical annotation sets, so the twin + is the opposite-embodiment episode sharing the most (start, end, text) + annotation entries with the clicked episode. The matching action is the + twin's interval equal to the clicked frame's covering interval, with a + text-only fallback (two aria pair episodes share canonical texts but not + every rephrasing). Returns (twin_hash, start, end) or None when the frame + is unannotated or no opposite-embodiment episode shares any entry.""" + intervals = annotation_intervals(zarr_root, str(video_hash)) + cover = [(s, e) for s, e, _, _ in intervals if s <= frame_idx < e] + if not cover: + return None + s0, e0 = cover[0] + own = {(s, e, low) for s, e, low, _ in intervals} + embs = np.asarray(data["embs"]) + opp_hashes = sorted({str(x) for x in data["hashes"][embs != src_emb]}) + best_hash, best_iv, best_score = None, None, 0 + for h2 in opp_hashes: + iv2 = annotation_intervals(zarr_root, h2) + score = sum((s, e, low) in own for s, e, low, _ in iv2) + if score > best_score: + best_hash, best_iv, best_score = h2, iv2, score + if best_hash is None: + return None + # The matching action is the twin's interval with the same boundaries + # (pair members share segmentation); fall back to max frame-overlap. + spans = sorted({(s, e) for s, e, _, _ in best_iv}) + if (s0, e0) in spans: + s2, e2 = s0, e0 + else: + s2, e2 = max(spans, key=lambda sp: max(0, min(sp[1], e0) - max(sp[0], s0))) + if min(e2, e0) - max(s2, s0) <= 0: + return None + # Select the twin's prompt with the SAME rule the clicked-frame display + # uses (interval_for_frame → last paraphrase covering the frame). With + # byte-identical pair annotations this makes the paired prompt equal the + # clicked prompt; when they differ it stays an apples-to-apples compare. + m = interval_for_frame(best_iv, s2) + prompt = m[2] if m is not None else "" + return best_hash, s2, e2, prompt + + +def _pair_action_rows(data: dict, twin_hash: str, start: int, end: int): + """Row indices of the twin episode's frames inside the action interval.""" + frames = np.asarray(data["frame_idx"]) + return np.where((data["hashes"] == twin_hash) & (frames >= start) & (frames < end))[ + 0 + ] + + +def _pair_distance_text( + twin_hash: str, + start: int, + end: int, + prompt: str, + dists_rows, + rows, + frames, + space_label: str, +) -> str: + """Render the paired-action distance block shown under the clicked frame.""" + j = int(np.argmin(dists_rows)) + return ( + f"\n\n-- paired action ({space_label}) --\n" + f"{twin_hash} seg [{start},{end})\n" + f"prompt: {prompt}\n" + f"min dist {float(dists_rows[j]):.4f} @ frame {int(frames[int(rows[j])])}" + ) class ScatterView: @@ -126,6 +299,65 @@ def __init__( self.lang_key = lang_key self.image_key = image_key self.default_sample = default_sample + self._knn_trees = {} # (run_path, layer, reduction, target_emb) -> (cKDTree, target_indices) + # (run, layer, drop_k, rows-md5) -> (rows, umap_xyz) for the live + # PCA-removal recompute. Small: only sampled points are embedded. + self._live_umap_cache = {} + + def _live_umap_coords(self, run_path: str, layer: str, drop_k: int, rows): + """UMAP-3D embedding of the Fisher-filtered PCA features for the + given original-CSV `rows` (the sampled points). Cached by content, + so re-clicking Apply with unchanged controls is instant. Returns + (coords (len(rows), 3), None) or (None, error_message).""" + import hashlib + import time + + cache_key = ( + run_path, + layer, + int(drop_k), + hashlib.md5(np.ascontiguousarray(rows).tobytes()).hexdigest(), + ) + cached = self._live_umap_cache.get(cache_key) + if cached is not None: + return cached, None + + feats = self.store.pca_features_dropped( + run_path, layer, n_components=50, drop_k=drop_k + ) + if feats is None: + return None, ( + f"PCA removal needs raw keys + both embodiments — " + f"{layer}_keys.pt missing or Fisher ranking unavailable " + f"(re-run eval with +force_reeval=true to write raw keys)" + ) + if feats.shape[1] == 0: + return None, "all PCA dims removed — lower the slider" + if rows.max() >= feats.shape[0]: + return None, ( + f"PCA features rows ({feats.shape[0]}) mismatch CSV rows — " + f"stale {layer}_keys.pt?" + ) + try: + import umap + except ImportError: + return None, "umap-learn not installed — pip install umap-learn" + t0 = time.perf_counter() + X = feats[rows] + coords = umap.UMAP(n_components=3, random_state=42).fit_transform(X) + coords = np.asarray(coords, dtype=np.float32) + logger.info( + "live UMAP for %s | %s drop_k=%d on %d pts in %.1fs", + os.path.basename(run_path), + layer, + drop_k, + len(rows), + time.perf_counter() - t0, + ) + self._live_umap_cache[cache_key] = coords + while len(self._live_umap_cache) > 8: + self._live_umap_cache.pop(next(iter(self._live_umap_cache))) + return coords, None def build_figure( self, @@ -137,6 +369,7 @@ def build_figure( remove_outliers: bool = False, outlier_thresh: float = 3.0, excludes: list[str] | None = None, + pca_drop_k: int = 0, ): import time @@ -163,22 +396,27 @@ def _phase(name: str, t0: float): data = self.store.load(run_path, layer) _phase("store.load", t0) + pca_drop_k = int(pca_drop_k or 0) + # Trigger the selected reduction's lazy materialization NOW (before # filter/sample). Lazy fields are None placeholders in the dict; # without this, the filtered/sampled copy propagates None and the - # scatter goes blank. - t0 = time.perf_counter() - reduction_key = _REDUCTION_TO_COORDS.get(reduction, ("umap_xyz", "?"))[0] - _ = data[reduction_key] - _phase(f"materialize[{reduction_key}]", t0) + # scatter goes blank. Skipped when PCA removal is active — coords + # then come from a live UMAP recompute, not the CSV. + if pca_drop_k <= 0: + t0 = time.perf_counter() + reduction_key = _REDUCTION_TO_COORDS.get(reduction, ("umap_xyz", "?"))[0] + _ = data[reduction_key] + _phase(f"materialize[{reduction_key}]", t0) # Apply lang filter (e.g. drop "home" frames) BEFORE sampling so the # filter affects what shows up on the plot, not just whether the # `sample` points include filtered ones. t0 = time.perf_counter() n_lang_dropped = 0 + lang_keep = None if excludes: - data, n_lang_dropped = _filter_data_by_lang( + data, n_lang_dropped, lang_keep = _filter_data_by_lang( data, list(excludes), self.zarr_root ) _phase("lang_filter", t0) @@ -194,6 +432,13 @@ def _phase(name: str, t0: float): _phase("sample", t0) sub_groups = groups[keep] n_outliers_removed = 0 + # Original-CSV row indices of the sampled points — needed to slice + # the full-N PCA features when PCA removal is active. + orig_rows = ( + np.flatnonzero(lang_keep)[keep] + if lang_keep is not None + else np.asarray(keep) + ) # All reductions are read directly from precomputed coords produced # by eval_latent — no client-side recompute. If the user picks a @@ -217,7 +462,20 @@ def _phase(name: str, t0: float): "tsne3d": ("tsne3d_xyz", "t-SNE 3D", 3, "evaluator.compute_tsne_3d=true"), "pca": ("pca_xyz", "PCA", 3, "evaluator.compute_pca=true"), } - if reduction not in spec: + if pca_drop_k > 0: + # PCA removal active: ignore the precomputed reductions and embed + # the sampled points' Fisher-filtered PCA features with a live + # UMAP fit (cached per (run, layer, k, sample)). + t0 = time.perf_counter() + coords, live_err = self._live_umap_coords( + run_path, layer, pca_drop_k, orig_rows + ) + _phase("live_umap", t0) + if coords is None: + missing_msg = live_err + else: + axis_prefix, d = "umap", 3 + elif reduction not in spec: missing_msg = f"unknown reduction: {reduction!r}" else: key, label, dims, flag = spec[reduction] @@ -340,6 +598,11 @@ def _phase(name: str, t0: float): fig.update_layout( title=( f"{layer} ({axis_prefix} {d}D, n={coords.shape[0]}" + + ( + f", live UMAP on PCA-50 −top{pca_drop_k} Fisher dims" + if pca_drop_k > 0 + else "" + ) + ( f", {n_outliers_removed} outliers hidden" if n_outliers_removed > 0 @@ -397,32 +660,122 @@ def preset_button(self, label, value): }, ) + def _get_knn_tree(self, data, run_path, layer, reduction, target_emb, coords=None): + """Get or build a cKDTree for the target embodiment's coords.""" + from scipy.spatial import cKDTree + + cache_key = (run_path, layer, reduction, target_emb) + if cache_key in self._knn_trees: + return self._knn_trees[cache_key] + + if coords is None: + coord_key = _REDUCTION_TO_COORDS.get(reduction, ("umap_xyz", "?"))[0] + coords = data.get(coord_key) + if coords is None: + return None + + embs = data["embs"] + target_mask = embs == target_emb + if not target_mask.any(): + self._knn_trees[cache_key] = None + return None + + target_idx = np.where(target_mask)[0] + tree = cKDTree(coords[target_idx]) + entry = (tree, target_idx) + self._knn_trees[cache_key] = entry + if len(self._knn_trees) > 16: + oldest = next(iter(self._knn_trees)) + del self._knn_trees[oldest] + return entry + def knn_other_embodiment( - self, data, src_idx, src_emb, reduction, k=10, coords_override=None + self, + data, + src_idx, + src_emb, + reduction, + k=10, + coords_override=None, + run_path=None, + layer=None, ): """Return list of dicts for the K closest opposite-embodiment rows. - If `coords_override` is given, distances are computed on those coords - (use for raw 256d KNN). Otherwise the reduction's CSV columns are - used. Each dict has hash, frame, token, embodiment, distance. - None on missing data, [] when no opposite-embodiment rows exist.""" + Uses a cKDTree for O(log N) queries instead of brute-force.""" if coords_override is not None: coords = coords_override else: coord_key = _REDUCTION_TO_COORDS.get(reduction, ("umap_xyz", "?"))[0] coords = data.get(coord_key) if coords is None: - return None # missing column / file + return None + embs = data["embs"] - target_mask = embs != src_emb - if not target_mask.any(): + unique_embs = set(np.unique(embs)) + unique_embs.discard(src_emb) + if not unique_embs: return [] - target_idx = np.where(target_mask)[0] - diff = coords[target_idx] - coords[src_idx] - dists = np.linalg.norm(diff, axis=1) - order = np.argsort(dists)[:k] + + all_neighbors = [] + for target_emb in unique_embs: + if run_path and layer: + entry = self._get_knn_tree( + data, run_path, layer, reduction, target_emb, coords + ) + else: + entry = None + + if entry is not None: + tree, target_idx = entry + query_pt = coords[src_idx].reshape(1, -1) + n_target = tree.data.shape[0] + # At most one row per action instance: the raw top-k is + # dominated by temporally-adjacent rows of a single + # opposite-embodiment motion, so over-query and dedupe, + # expanding until k distinct actions are found or the + # embodiment is exhausted. + query_k = min(max(64, k * 16), n_target) + while True: + dists_arr, idx_arr = tree.query(query_pt, k=query_k) + cand = ( + (float(d), int(target_idx[int(i)])) + for d, i in zip( + np.atleast_1d(np.squeeze(dists_arr)), + np.atleast_1d(np.squeeze(idx_arr)), + ) + ) + kept = _dedupe_neighbors_by_action( + self.zarr_root, + data, + cand, + k, + cache_key=(run_path, layer), + ) + if len(kept) >= k or query_k >= n_target: + break + query_k = min(query_k * 8, n_target) + all_neighbors.extend(kept) + else: + target_mask = embs == target_emb + target_idx = np.where(target_mask)[0] + diff = coords[target_idx] - coords[src_idx] + dists = np.linalg.norm(diff, axis=1) + order = np.argsort(dists) + cand = ((float(dists[int(o)]), int(target_idx[int(o)])) for o in order) + all_neighbors.extend( + _dedupe_neighbors_by_action( + self.zarr_root, data, cand, k, cache_key=(run_path, layer) + ) + ) + + all_neighbors.sort(key=lambda x: x[0]) + # Re-dedupe across embodiments (no-op when there is only one + # opposite embodiment) and truncate to k. + all_neighbors = _dedupe_neighbors_by_action( + self.zarr_root, data, all_neighbors, k, cache_key=(run_path, layer) + ) out = [] - for rank, o in enumerate(order, start=1): - ridx = int(target_idx[o]) + for rank, (dist, ridx) in enumerate(all_neighbors, start=1): out.append( { "rank": rank, @@ -430,7 +783,7 @@ def knn_other_embodiment( "frame_idx": int(data["frame_idx"][ridx]), "token_idx": int(data["token_idx"][ridx]), "embodiment": str(data["embs"][ridx]), - "distance": float(dists[o]), + "distance": dist, } ) return out @@ -598,10 +951,21 @@ def knn_buttons(self, neighbors): return rows def build_inspect_payload( - self, run_path, layer, video_hash, frame_idx, token_idx, emb, reduction + self, + run_path, + layer, + video_hash, + frame_idx, + token_idx, + emb, + reduction, + pca_drop_k: int = 0, ): """Compute everything needed to populate the right pane for one - clicked frame: meta text, image URI, lang text, KNN buttons, label.""" + clicked frame: meta text, image URI, lang text, KNN buttons, label. + When `pca_drop_k` > 0 the KNN is computed in the Fisher-filtered + PCA-50 space (matching the live-UMAP scatter) instead of the + precomputed reduction coords.""" from dash import html meta = ( @@ -633,8 +997,12 @@ def build_inspect_payload( "Paths attempted:\n " + "\n ".join(lang_tried[:25]) ) - red_label = _REDUCTION_TO_COORDS.get(reduction, ("?", reduction))[1] - knn_label = f"10 closest opposite-embodiment frames ({red_label})" + pca_drop_k = int(pca_drop_k or 0) + if pca_drop_k > 0: + red_label = f"PCA-50 −top{pca_drop_k} Fisher dims" + else: + red_label = _REDUCTION_TO_COORDS.get(reduction, ("?", reduction))[1] + knn_label = f"10 closest opposite-embodiment frames ({red_label}, 1 per action)" knn_buttons = [html.Div("(no run/layer selected)", style={"color": "#94a3b8"})] if run_path and layer: try: @@ -655,10 +1023,64 @@ def build_inspect_payload( else: src_idx = int(where[0]) src_emb = str(data["embs"][src_idx]) + coords_override = None + knn_reduction = reduction + if pca_drop_k > 0: + # Same Fisher-filtered space the live-UMAP scatter is + # built from. The tag keeps the cKDTree cache keyed + # per drop level. + coords_override = self.store.pca_features_dropped( + run_path, layer, n_components=50, drop_k=pca_drop_k + ) + knn_reduction = f"pca50_drop{pca_drop_k}" + if coords_override is None or coords_override.shape[0] != len( + data["hashes"] + ): + raise RuntimeError( + "Fisher-filtered PCA features unavailable or " + "misaligned — re-run eval with " + "+force_reeval=true to write raw keys" + ) neighbors = self.knn_other_embodiment( - data, src_idx, src_emb, reduction, k=10 + data, + src_idx, + src_emb, + knn_reduction, + k=10, + coords_override=coords_override, + run_path=run_path, + layer=layer, ) knn_buttons = self.knn_buttons(neighbors) + # Paired-action distance, in the same space as the KNN + # list, appended below the clicked frame's language. + if coords_override is not None: + pair_coords = coords_override + else: + coord_key = _REDUCTION_TO_COORDS.get( + knn_reduction, ("umap_xyz", "?") + )[0] + pair_coords = data.get(coord_key) + pair = _find_pair_action( + self.zarr_root, data, str(video_hash), int(frame_idx), src_emb + ) + if pair is not None and pair_coords is not None: + twin, s2, e2, pair_prompt = pair + rows = _pair_action_rows(data, twin, s2, e2) + if rows.size: + d = np.linalg.norm( + pair_coords[rows] - pair_coords[src_idx], axis=1 + ) + lang_display = lang_display + _pair_distance_text( + twin, + s2, + e2, + pair_prompt, + d, + rows, + np.asarray(data["frame_idx"]), + red_label, + ) except Exception as e: knn_buttons = [ html.Div( @@ -741,6 +1163,7 @@ def _on_run_change(run_path): State("outlier_thresh", "value"), State("browser_hide_home", "value"), State("browser_lang_exclude", "value"), + State("pca_drop_k", "value"), ) def update_figure( _n_clicks, @@ -752,6 +1175,7 @@ def update_figure( outlier_thresh_val, hide_home_val, lang_exclude_val, + pca_drop_k, ): # Only re-render when the user clicks Apply. Pre-1st-click also fires # once (n_clicks=0) so the initial figure renders on page load. @@ -773,6 +1197,7 @@ def update_figure( remove_outliers=remove_out, outlier_thresh=thresh, excludes=excludes, + pca_drop_k=int(pca_drop_k or 0), ) # Unified click handler — fires on: @@ -802,9 +1227,17 @@ def update_figure( State("run", "value"), State("layer", "value"), State("reduction", "value"), + State("pca_drop_k", "value"), ) def on_inspect( - clickData, _knn_clicks, _back_clicks, nav_stack, run_path, layer, reduction + clickData, + _knn_clicks, + _back_clicks, + nav_stack, + run_path, + layer, + reduction, + pca_drop_k, ): nav_stack = list(nav_stack or []) ctx = dash.callback_context @@ -828,6 +1261,7 @@ def on_inspect( int(prev["token"]), prev["emb"], reduction, + pca_drop_k=int(pca_drop_k or 0), ) back_disabled = len(nav_stack) <= 1 return (*payload, nav_stack, back_disabled) @@ -908,6 +1342,7 @@ def on_inspect( int(nav_stack[-1]["token"]), nav_stack[-1]["emb"], reduction, + pca_drop_k=int(pca_drop_k or 0), ), nav_stack, len(nav_stack) <= 1, @@ -926,6 +1361,7 @@ def on_inspect( new_pt["token"], new_pt["emb"], reduction, + pca_drop_k=int(pca_drop_k or 0), ) back_disabled = len(nav_stack) <= 1 return (*payload, nav_stack, back_disabled) @@ -1096,10 +1532,25 @@ def browser_knn_buttons(self, neighbors): ) return rows - def render_browser_detail(self, run_path, layer, h, f, tok, knn_space="raw"): + def render_browser_detail( + self, + run_path, + layer, + h, + f, + tok, + knn_space="raw", + pca_drop_k=0, + pair_mode="cross", + ): """Build (meta, img_src, lang_display, knn_label, knn_buttons) for one clicked/popped (hash, frame, token). `knn_space` is 'raw' (full-D - keys) or 'pca' (50-d PCA features fitted on the layer's keys).""" + keys) or 'pca' (50-d PCA features fitted on the layer's keys). + `pca_drop_k` > 0 removes the top-k PCA dims ranked by aria-vs-eva + Fisher score from the KNN space: dropped columns in 'pca' mode, + projected-out component directions in 'raw' mode. `pair_mode` is + 'cross' (neighbors from OTHER embodiments — default) or 'same' + (neighbors from the clicked frame's own embodiment, self excluded).""" import time from dash import html @@ -1110,13 +1561,17 @@ def render_browser_detail(self, run_path, layer, h, f, tok, knn_space="raw"): def _phase(name: str, t0: float): _t_phases[name] = time.perf_counter() - t0 + pca_drop_k = int(pca_drop_k or 0) + pair_mode = pair_mode or "cross" logger.info( - "render_browser_detail ENTER layer=%s h=%s f=%s tok=%s knn_space=%s", + "render_browser_detail ENTER layer=%s h=%s f=%s tok=%s knn_space=%s drop_k=%d pairs=%s", layer, h, f, tok, knn_space, + pca_drop_k, + pair_mode, ) t0 = time.perf_counter() @@ -1170,8 +1625,15 @@ def _phase(name: str, t0: float): # raw-key space and dumped it as `_knn.pt`. When present # AND aligned with the current CSV, we skip the full-D distance # scan entirely. PCA mode still computes on demand because the - # n_components knob is inspector-side. - if knn_space != "pca" and src_idx is not None: + # n_components knob is inspector-side. PCA removal and + # same-embodiment pairs also bypass this path — the precomputed + # neighbors were built cross-embodiment in the full raw space. + if ( + knn_space != "pca" + and pca_drop_k <= 0 + and pair_mode != "same" + and src_idx is not None + ): t_knn = time.perf_counter() knn_pre = self.store.load_knn(run_path, layer) _phase("load_knn", t_knn) @@ -1201,11 +1663,61 @@ def _phase(name: str, t0: float): style={"color": "#94a3b8"}, ) ] - else: + phase_str = " ".join( + f"{n}={dt*1000:.0f}ms" for n, dt in _t_phases.items() + ) + logger.info( + "render_browser_detail (precomputed-KNN path) %s in %.2fs | %s", + layer, + time.perf_counter() - _t_overall, + phase_str, + ) + return meta, img_src, lang_display, knn_label, knn_buttons + # One row per action instance. The sidecar stores only the + # K (=8) globally-nearest rows, which usually collapse to a + # couple of distinct actions — only take this fast path when + # it can still fill the list; otherwise fall through to the + # full computed scan below. + cand = ( + (float(d), int(o)) + for d, o in zip(pre_dist.tolist(), pre_idx.tolist()) + ) + deduped = _dedupe_neighbors_by_action( + self.zarr_root, data, cand, 10, cache_key=(run_path, layer) + ) + if len(deduped) >= 10: + # Paired-action distance (raw space — same space the + # precomputed neighbors live in). Row-only gather from + # the mmap-backed keys keeps this path per-click cheap. + pair = _find_pair_action(self.zarr_root, data, h, f, src_emb) + if pair is not None: + twin, s2, e2, pair_prompt = pair + pair_rows = _pair_action_rows(data, twin, s2, e2) + keys_arr = ( + self.store.load_keys(run_path, layer) + if pair_rows.size + else None + ) + if keys_arr is not None and keys_arr.shape[0] == len( + data["hashes"] + ): + d = np.linalg.norm( + np.asarray(keys_arr[pair_rows], dtype=np.float32) + - np.asarray(keys_arr[src_idx], dtype=np.float32), + axis=1, + ) + lang_display = lang_display + _pair_distance_text( + twin, + s2, + e2, + pair_prompt, + d, + pair_rows, + np.asarray(data["frame_idx"]), + "raw", + ) neighbors = [] - for rank, (o, dist) in enumerate( - zip(pre_idx.tolist(), pre_dist.tolist()), start=1 - ): + for rank, (dist, o) in enumerate(deduped, start=1): neighbors.append( { "rank": rank, @@ -1213,33 +1725,44 @@ def _phase(name: str, t0: float): "frame_idx": int(data["frame_idx"][o]), "token_idx": int(data["token_idx"][o]), "embodiment": str(data["embs"][o]), - "distance": float(dist), + "distance": dist, } ) knn_label = ( f"{len(neighbors)} closest CROSS-embodiment in " - f"{space_label} (source='{src_emb}')" + f"{space_label} (source='{src_emb}', 1 per action)" ) knn_buttons = self.browser_knn_buttons(neighbors) - phase_str = " ".join( - f"{n}={dt*1000:.0f}ms" for n, dt in _t_phases.items() - ) - logger.info( - "render_browser_detail (precomputed-KNN path) %s in %.2fs | %s", - layer, - time.perf_counter() - _t_overall, - phase_str, - ) - return meta, img_src, lang_display, knn_label, knn_buttons + phase_str = " ".join( + f"{n}={dt*1000:.0f}ms" for n, dt in _t_phases.items() + ) + logger.info( + "render_browser_detail (precomputed-KNN path) %s in %.2fs | %s", + layer, + time.perf_counter() - _t_overall, + phase_str, + ) + return meta, img_src, lang_display, knn_label, knn_buttons + # `drop_dirs` is only set in raw mode with removal active: the + # (drop_k, D) component directions whose contribution gets + # subtracted from the raw distances below. + drop_dirs = None if knn_space == "pca": - feats = self.store.pca_features(run_path, layer, n_components=50) - space_label = "PCA-50" + feats = self.store.pca_features_dropped( + run_path, layer, n_components=50, drop_k=pca_drop_k + ) + space_label = ( + f"PCA-50 −top{pca_drop_k} Fisher dims" if pca_drop_k > 0 else "PCA-50" + ) missing_msg = ( f"({layer}_keys.pt missing — PCA features can't be " f"fitted; re-run eval with +force_reeval=true to " f"write raw keys)" ) + if pca_drop_k > 0 and feats is not None and feats.shape[1] == 0: + feats = None + missing_msg = "(all PCA dims removed — lower the slider)" else: feats = self.store.load_keys(run_path, layer) space_label = f"raw D={feats.shape[1]}" if feats is not None else "raw keys" @@ -1248,6 +1771,18 @@ def _phase(name: str, t0: float): f"+force_reeval=true to write raw keys " f"(KNN is bundled inside the same file))" ) + if pca_drop_k > 0 and feats is not None: + drop_dirs = self.store.fisher_dropped_directions( + run_path, layer, n_components=50, drop_k=pca_drop_k + ) + if drop_dirs is None: + feats = None + missing_msg = ( + "(Fisher ranking unavailable — needs PCA on raw keys " + "and both aria/eva embodiments in this layer)" + ) + else: + space_label += f" −top{pca_drop_k} Fisher PCA dims" if feats is None: knn_label = f"10 closest in {space_label}" @@ -1271,29 +1806,103 @@ def _phase(name: str, t0: float): ] else: diff = feats - feats[src_idx] - dists = np.linalg.norm(diff, axis=1) - # Cross-embodiment only: mask self AND every row whose - # embodiment matches the clicked frame's. Without this the - # neighbors are dominated by frames from the same recording - # type, which defeats the point of the visualization. - same_or_self = data["embs"] == src_emb - dists[same_or_self] = np.inf - n_candidates = int((~same_or_self).sum()) + if drop_dirs is not None: + # Raw-space removal: subtract the dropped components' + # contribution from the squared distances instead of + # materializing a filtered copy of the (N, D) keys. + d2 = np.einsum("ij,ij->i", diff, diff) + proj = diff @ drop_dirs.T + d2 -= np.einsum("ij,ij->i", proj, proj) + dists = np.sqrt(np.clip(d2, 0.0, None)) + else: + dists = np.linalg.norm(diff, axis=1) + # Paired-action distance, appended below the clicked frame's + # language. Computed before the pair-mode masking so the twin's + # (opposite-embodiment) rows still hold finite distances. + pair = _find_pair_action(self.zarr_root, data, h, f, src_emb) + if pair is not None: + twin, s2, e2, pair_prompt = pair + pair_rows = _pair_action_rows(data, twin, s2, e2) + if pair_rows.size: + lang_display = lang_display + _pair_distance_text( + twin, + s2, + e2, + pair_prompt, + dists[pair_rows], + pair_rows, + np.asarray(data["frame_idx"]), + space_label, + ) + # Candidate masking by pair mode. Cross (default): mask self AND + # every row whose embodiment matches the clicked frame's — + # without this the neighbors are dominated by frames from the + # same recording type, which defeats the point of the + # visualization. Same: keep only the clicked frame's embodiment, + # masking the clicked row itself. + if pair_mode == "same": + mask_out = np.asarray(data["embs"] != src_emb) + mask_out[src_idx] = True + # Exclude the clicked frame's own ACTION: rows from the same + # recording whose frame falls inside any annotation interval + # covering the clicked frame. Otherwise the list is just the + # temporally adjacent frames of the same motion. When no + # interval covers the frame (or annotations are unreadable), + # fall back to excluding the whole recording. + same_hash = np.asarray(data["hashes"] == h) + intervals = annotation_intervals(self.zarr_root, h) + spans = [(s, e) for s, e, *_ in intervals if s <= f < e] + if spans: + frames = np.asarray(data["frame_idx"]) + in_action = np.zeros(len(frames), dtype=bool) + for s, e in spans: + in_action |= (frames >= s) & (frames < e) + mask_out |= same_hash & in_action + else: + mask_out |= same_hash + pair_label = "SAME" + pair_suffix = ", same action excluded" + empty_msg = ( + f"(no rows from embodiment '{src_emb}' outside the " + f"clicked action in this layer)" + ) + else: + mask_out = np.asarray(data["embs"] == src_emb) + pair_label = "CROSS" + pair_suffix = ", 1 per action" + empty_msg = ( + f"(no rows from a different embodiment than '{src_emb}' " + f"in this layer)" + ) + dists[mask_out] = np.inf + n_candidates = int((~mask_out).sum()) if n_candidates == 0: knn_label = f"10 closest in {space_label}" knn_buttons = [ html.Div( - f"(no rows from a different embodiment than '{src_emb}' " - f"in this layer)", + empty_msg, style={"color": "#94a3b8"}, ) ] else: - k = min(10, n_candidates) - order = np.argsort(dists)[:k] + order = np.argsort(dists) + if pair_mode == "same": + k = min(10, n_candidates) + picked = [(float(dists[int(o)]), int(o)) for o in order[:k]] + else: + # Cross mode: at most one row per action instance — + # otherwise the list is 10 temporally-adjacent frames + # of a single opposite-embodiment motion. argsort puts + # the masked (inf) rows last, so capping the walk at + # n_candidates only visits valid rows. + cand = ( + (float(dists[int(o)]), int(o)) for o in order[:n_candidates] + ) + picked = _dedupe_neighbors_by_action( + self.zarr_root, data, cand, 10, cache_key=(run_path, layer) + ) neighbors = [] - for rank, o in enumerate(order, start=1): - o = int(o) + for rank, (dist, o) in enumerate(picked, start=1): neighbors.append( { "rank": rank, @@ -1301,12 +1910,12 @@ def _phase(name: str, t0: float): "frame_idx": int(data["frame_idx"][o]), "token_idx": int(data["token_idx"][o]), "embodiment": str(data["embs"][o]), - "distance": float(dists[o]), + "distance": dist, } ) knn_label = ( - f"10 closest CROSS-embodiment in {space_label} " - f"(source='{src_emb}')" + f"10 closest {pair_label}-embodiment in {space_label} " + f"(source='{src_emb}'{pair_suffix})" ) knn_buttons = self.browser_knn_buttons(neighbors) @@ -1379,6 +1988,7 @@ def _toggle_view( Input("view_mode", "value"), Input("browser_hide_home", "value"), Input("browser_lang_exclude", "value"), + Input("browser_shuffle", "value"), prevent_initial_call=True, ) def _reset_visible_count(*_): @@ -1393,9 +2003,16 @@ def _reset_visible_count(*_): Input("browser_visible_count", "data"), Input("browser_hide_home", "value"), Input("browser_lang_exclude", "value"), + Input("browser_shuffle", "value"), ) def _populate_browser_list( - run_path, layer, mode, visible_count, hide_home_val, lang_exclude_val + run_path, + layer, + mode, + visible_count, + hide_home_val, + lang_exclude_val, + shuffle_val, ): import time as _time @@ -1494,6 +2111,14 @@ def _populate_browser_list( len(items), _time.perf_counter() - _t_pop_start, ) + # Optional shuffle of the browse order. Fixed seed: the full list + # is rebuilt on every callback (incl. Load-more re-fires), so a + # stable permutation keeps already-rendered cards in place while + # paging. Applied before the lang-filter slicing below so the + # filtered window matches what gets rendered. + if shuffle_val and "on" in shuffle_val and items: + rng = np.random.default_rng(12345) + items = [items[int(i)] for i in rng.permutation(len(items))] # Apply lang filter, if any. KEY OPTIMIZATION: only filter the # `visible_count + FILTER_BUFFER` items we're about to render — # NOT the whole 249-card list. Each filter hit costs a zarr open @@ -1757,6 +2382,8 @@ def _populate_browser_list( ), Input("browser_back", "n_clicks"), Input("browser_knn_space", "value"), + Input("pca_drop_k", "value"), + Input("browser_knn_pairs", "value"), State("browser_nav_stack", "data"), State("run", "value"), State("layer", "value"), @@ -1767,6 +2394,8 @@ def _on_browser_click( _knn_clicks, _back_clicks, knn_space, + pca_drop_k, + knn_pairs, nav_stack, run_path, layer, @@ -1778,18 +2407,23 @@ def _on_browser_click( return (dash.no_update,) * 7 triggered = ctx.triggered[0]["prop_id"] logger.info( - "_on_browser_click trigger=%s knn_space=%s run=%s layer=%s nav_depth=%d", + "_on_browser_click trigger=%s knn_space=%s drop_k=%s pairs=%s run=%s layer=%s nav_depth=%d", triggered, knn_space, + pca_drop_k, + knn_pairs, os.path.basename(run_path) if run_path else None, layer, len(nav_stack), ) - # ---- Branch 0: KNN-space radio toggled ---- + # ---- Branch 0: KNN-space radio, PCA-removal slider, or pair-mode + # toggle changed ---- # Re-render the current selection (if any) using the new space. # No nav_stack mutation. - if triggered.startswith("browser_knn_space"): + if triggered.startswith( + ("browser_knn_space", "pca_drop_k", "browser_knn_pairs") + ): if not nav_stack: return (dash.no_update,) * 7 cur = nav_stack[-1] @@ -1800,6 +2434,8 @@ def _on_browser_click( int(cur["frame"]), int(cur["token"]), knn_space=knn_space, + pca_drop_k=pca_drop_k, + pair_mode=knn_pairs, ) return (*payload, nav_stack, len(nav_stack) <= 1) @@ -1825,6 +2461,8 @@ def _on_browser_click( int(prev["frame"]), int(prev["token"]), knn_space=knn_space, + pca_drop_k=pca_drop_k, + pair_mode=knn_pairs, ) return (*payload, nav_stack, len(nav_stack) <= 1) @@ -1863,7 +2501,14 @@ def _on_browser_click( nav_stack.append(new_pt) payload = self.render_browser_detail( - run_path, layer, h, f, tok, knn_space=knn_space + run_path, + layer, + h, + f, + tok, + knn_space=knn_space, + pca_drop_k=pca_drop_k, + pair_mode=knn_pairs, ) return (*payload, nav_stack, len(nav_stack) <= 1) diff --git a/egomimic/scripts/data_visualization/latent_inspector.py b/egomimic/scripts/data_visualization/latent_inspector.py index 153b83815..ce3005b67 100644 --- a/egomimic/scripts/data_visualization/latent_inspector.py +++ b/egomimic/scripts/data_visualization/latent_inspector.py @@ -94,6 +94,28 @@ def main(): p.add_argument( "--host", default="0.0.0.0", help="Set to 127.0.0.1 to bind localhost only." ) + p.add_argument( + "--pair-rank", + type=int, + default=0, + metavar="N", + help="Report mode (no server): for each layer, sample N rows and " + "print the average place of the perfect-pair action among " + "opposite-embodiment actions (1 = nearest). E.g. --pair-rank 1000.", + ) + p.add_argument( + "--pair-rank-space", + default="umap", + choices=["umap", "pca_umap", "tsne2d"], + help="Coordinate space for the --pair-rank distances (CSV-baked " + "reductions; same spaces the scatter view plots).", + ) + p.add_argument( + "--pair-rank-seed", + type=int, + default=0, + help="Sampling seed for --pair-rank.", + ) args = p.parse_args() if args.root: @@ -115,6 +137,20 @@ def main(): ) ] + if args.pair_rank > 0: + from inspector_lib.pair_rank import pair_rank_report + + for disp, run_dir in runs: + print(f"\n===== run: {disp} =====") + pair_rank_report( + run_dir, + args.zarr_root, + n_samples=args.pair_rank, + space=args.pair_rank_space, + seed=args.pair_rank_seed, + ) + return + app = build_app( runs=runs, zarr_root=args.zarr_root,