Pass norm stats to collate for proprio tokenization #400

Open
ElmoPA wants to merge 1 commit into main from elmo/norm-collate

ElmoPA commented May 8, 2026

Contributor

No description provided.

github-actions Bot commented Jun 2, 2026

Claude Code Review

Summary

Threads data_schematic into build_tokenized_collate so proprio tokenization uses properly normalized values (via normalize_data) instead of assuming upstream normalization. This fixes a latent bug where raw proprio was being clipped to [-1,1] and binned.
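
To illustrate why clipping raw values before binning is a problem, here is a minimal sketch. It assumes 256 bins and min-max scaling to [-1, 1]; the PR does not state the bin count or the scheme used by normalize_data, so `N_BINS`, `bin_proprio`, and `minmax_normalize` are hypothetical stand-ins.

```python
import numpy as np

N_BINS = 256  # assumed bin count; not stated in the PR


def bin_proprio(x, n_bins=N_BINS):
    """Clip to [-1, 1], then map to integer bins in [0, n_bins - 1]."""
    x = np.clip(np.asarray(x, dtype=np.float64), -1.0, 1.0)
    return np.minimum(((x + 1.0) / 2.0 * n_bins).astype(np.int64), n_bins - 1)


def minmax_normalize(x, lo, hi):
    """Hypothetical stand-in for normalize_data: scale [lo, hi] to [-1, 1]."""
    return 2.0 * (np.asarray(x, dtype=np.float64) - lo) / (hi - lo) - 1.0


raw = np.array([0.5, 1.5, 2.5, 3.0])  # e.g. joint angles in radians
lo, hi = 0.0, 3.0                      # assumed dataset min/max

bins_raw = bin_proprio(raw)                               # old path: raw values clipped
bins_normed = bin_proprio(minmax_normalize(raw, lo, hi))  # new path: normalize first

print(bins_raw.tolist())     # [192, 255, 255, 255]
print(bins_normed.tolist())  # [42, 128, 213, 255]
```

On the old path, every raw value above 1.0 collapses into the top bin, so three distinct joint angles become indistinguishable in the prompt.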

Key concerns

  1. MultiDataModuleWrapper signature — The diff updates the hydra.utils.instantiate call to pass data_schematic=data_schematic, but doesn't show the corresponding change in MultiDataModuleWrapper.__init__ or the wiring through to build_tokenized_collate. If the wrapper doesn't accept/forward data_schematic, this will fail at instantiate time. Please confirm that change is in this PR.

  2. Silent skip of unregistered keys could mask config errors. When zarr_key_to_keyname returns None, the key is silently dropped. If all proprio_keys are unregistered for a given embodiment, raw is empty and the function returns None — no prompt state, no error. Given the hard-fail philosophy elsewhere in this function, consider warning or erroring when a key in proprio_keys is configured but unregistered for an active embodiment. At minimum, a one-time warning per (embodiment, key) would help catch config drift.

  3. Order preservation relies on dict insertion order. The comment says "Iterate in raw insertion order (which mirrors proprio_keys ordering)" — this is correct in Py3.7+, but if normalize_data ever returns a re-ordered dict (or if a downstream refactor changes that), the bin layout in the prompt will silently shift and break trained checkpoints. Safer to iterate proprio_keys directly and look up the translated keyname:

    for k in proprio_keys:
        keyname = keyname_to_zarr_inverse.get(k)
        if keyname is None or keyname not in normed:
            continue
        v = normed[keyname]
        ...

  4. Norm stats checkpoint compatibility — Any existing pi0.5 checkpoints trained with proprio=True under the old "assumes upstream normalization" path will have learned bins on a different value distribution. This silently invalidates prior runs that used this code path. Worth a callout in the PR description / release notes so nobody resumes a stale checkpoint.

Suggestions

  • Show / verify the MultiDataModuleWrapper.__init__ change in this PR.
  • Iterate proprio_keys (not raw.keys()) for the concat to make ordering robust to dict semantics.
  • Add at least one unit test: build a collate with a stub data_schematic, feed a sample with known stats, assert bins match a hand-computed expectation. Also test the three failure modes (data_schematic=None, norm_stats=None, missing embodiment).
  • Consider logging (once) when a proprio_keys entry is dropped because zarr_key_to_keyname returned None.
  • Add a PR description — this is a semantically meaningful change for any pi0.5 experiment in flight.
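
As a starting point for the suggested unit test, here is a rough sketch of its shape. Everything in it is an assumption made for illustration: `StubSchematic`, `tokenize_proprio`, the min/max layout of the stats, and the 256-bin count are hypothetical and do not reflect the real build_tokenized_collate or data_schematic interfaces, which are not shown in the diff. A real test would call the actual collate in place of the stand-in.

```python
import numpy as np


class StubSchematic:
    """Hypothetical stand-in for data_schematic; it only carries norm stats."""

    def __init__(self, norm_stats):
        self.norm_stats = norm_stats  # {embodiment: {"min": [...], "max": [...]}}


def tokenize_proprio(raw, schematic, embodiment, n_bins=256):
    """Hypothetical stand-in for the collate's proprio path; hard-fails on missing stats."""
    if schematic is None:
        raise ValueError("data_schematic is required for proprio tokenization")
    if schematic.norm_stats is None:
        raise ValueError("data_schematic.norm_stats is None")
    if embodiment not in schematic.norm_stats:
        raise ValueError(f"no norm stats for embodiment {embodiment!r}")
    stats = schematic.norm_stats[embodiment]
    lo = np.asarray(stats["min"], dtype=np.float64)
    hi = np.asarray(stats["max"], dtype=np.float64)
    normed = np.clip(2.0 * (np.asarray(raw) - lo) / (hi - lo) - 1.0, -1.0, 1.0)
    return np.minimum(((normed + 1.0) / 2.0 * n_bins).astype(np.int64), n_bins - 1)


def raises_value_error(fn):
    """Return True if calling fn raises ValueError."""
    try:
        fn()
    except ValueError:
        return True
    return False


def test_bins_match_hand_computed():
    schematic = StubSchematic({"robot_a": {"min": [0.0, -2.0], "max": [4.0, 2.0]}})
    bins = tokenize_proprio([1.0, 2.0], schematic, "robot_a")
    # 1.0 in [0, 4] normalizes to -0.5, which is bin 64;
    # 2.0 in [-2, 2] normalizes to 1.0, which is the top bin, 255.
    assert bins.tolist() == [64, 255]


def test_failure_modes():
    ok = StubSchematic({"robot_a": {"min": [0.0], "max": [1.0]}})
    assert raises_value_error(lambda: tokenize_proprio([0.5], None, "robot_a"))
    assert raises_value_error(lambda: tokenize_proprio([0.5], StubSchematic(None), "robot_a"))
    assert raises_value_error(lambda: tokenize_proprio([0.5], ok, "robot_b"))
```

Choosing stats that make the expected bins easy to compute by hand keeps the assertion readable and makes any future shift in bin layout show up as an obvious diff.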

Verdict: Request Changes

Mainly to (a) confirm the MultiDataModuleWrapper signature change is included and (b) harden the ordering and add a minimal test, given this affects token alignment in a way that's hard to spot post-hoc.


Reviewed by Claude · Review workflow
