Skip to content

rollout: normalize proprio before model to match training#472

Open
rl2aloha wants to merge 1 commit into
elmo/rollout-cfgfrom
elmo/rollout-proprio-norm-fix
Open

rollout: normalize proprio before model to match training#472
rl2aloha wants to merge 1 commit into
elmo/rollout-cfgfrom
elmo/rollout-proprio-norm-fix

Conversation

@rl2aloha

Copy link
Copy Markdown
Contributor

The algo-side prompt refactor (9af36e0) moved State-block discretization
from build_tokenized_collate into PI._discretize_state_for_sample. The
old path normalized proprio via data_schematic.normalize_data before
clipping to [-1,1] and binning; the new path skips that step. Raw xyz
in meters saturates the [-1,1] clip, producing meaningless bin indices
the model never saw during training (causing jerky rollouts).

Match MultiDataset.getitem's behavior at the rollout boundary:
normalize the post-transform single sample via the model's norm_stats
(itself a MultiDataset) before collating and handing to
process_batch_for_training. norm_stats are the ones baked into the
checkpoint, so train/inference State bins now agree.

Also adds a one-shot debug print of the assembled prompt on the first
inference step, so the State block can be eyeballed at runtime.

Co-Authored-By: Claude Opus 4.7 noreply@anthropic.com

rl2aloha commented May 24, 2026

Copy link
Copy Markdown
Contributor Author

The algo-side prompt refactor (9af36e0) moved State-block discretization
from build_tokenized_collate into PI._discretize_state_for_sample. The
old path normalized proprio via data_schematic.normalize_data before
clipping to [-1,1] and binning; the new path skips that step. Raw xyz
in meters saturates the [-1,1] clip, producing meaningless bin indices
the model never saw during training (causing jerky rollouts).

Match MultiDataset.__getitem__'s behavior at the rollout boundary:
normalize the post-transform single sample via the model's norm_stats
(itself a MultiDataset) before collating and handing to
process_batch_for_training. norm_stats are the ones baked into the
checkpoint, so train/inference State bins now agree.

Also adds a one-shot debug print of the assembled prompt on the first
inference step, so the State block can be eyeballed at runtime.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
@github-actions

github-actions Bot commented Jun 2, 2026

Copy link
Copy Markdown

Claude Code Review

Summary

Fixes a rollout/training mismatch by normalizing the post-transform sample with the checkpoint's norm_stats before collation, restoring the State-block discretization invariant that the recent algo-side refactor broke.

Key concerns

  1. Double-normalization risk if process_batch_for_training also normalizes. The justification cites MultiDataset.__getitem__ as the normalization point, which is correct for training. But please confirm that PI.process_batch_for_training does not itself re-apply norm_stats.normalize — otherwise rollouts will normalize twice while training normalizes once, re-introducing a (subtler) train/inference skew. A quick grep of process_batch_for_training in egomimic/algo/pi.py would confirm.

  2. Action denormalization downstream. If proprio is now normalized on the way in, ensure the predicted actions returned by forward_eval are denormalized before being sent to the robot. The previous unnormalized-input path may have masked a missing denorm step; worth tracing preds to make sure actions_cartesian are in meters at the robot interface.

  3. Embodiment string ordering. self.embodiment_id is used for normalize() before the embodiment_name is computed from self.arm. Confirm these always agree (i.e., self.embodiment_id == "eva_bimanual" when self.arm == "both"). If embodiment_id is set elsewhere from config, a mismatch would silently normalize against the wrong stats.

Suggestions

  • Replace the one-shot print with logger.debug or gate behind self.debug. Unconditional stdout in rollout makes log diffing across experiments noisier, and "first inference step" prints tend to ossify into permanent noise.
  • Add a brief assertion or comment that self.policy.model.norm_stats is the checkpoint-baked MultiDataset, not the training-time one — this is the load-bearing assumption of the fix and worth pinning down in code.
  • A regression test would be valuable: feed a known proprio sample through both MultiDataset.__getitem__ → collate and through the rollout path, and assert the resulting State tokens are bit-identical. This is exactly the kind of silent drift that motivated the PR.
  • Consider a short comment in PI._discretize_state_for_sample noting that callers are responsible for normalization, so the next refactor doesn't re-break this.

Verdict

Comment — fix direction is right and the diagnosis is convincing, but please verify (1) no double-normalize in process_batch_for_training and (2) action denorm on the return path before merging. Also downgrade the debug print.


Reviewed by Claude · Review workflow

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant