rollout: normalize proprio before model to match training#472
Conversation
|
Warning This pull request is not mergeable via GitHub because a downstack PR is open. Once all requirements are satisfied, merge this PR as a stack on Graphite.
This stack of pull requests is managed by Graphite. Learn more about stacking. |
The algo-side prompt refactor (9af36e0) moved State-block discretization from build_tokenized_collate into PI._discretize_state_for_sample. The old path normalized proprio via data_schematic.normalize_data before clipping to [-1,1] and binning; the new path skips that step. Raw xyz in meters saturates the [-1,1] clip, producing meaningless bin indices the model never saw during training (causing jerky rollouts). Match MultiDataset.__getitem__'s behavior at the rollout boundary: normalize the post-transform single sample via the model's norm_stats (itself a MultiDataset) before collating and handing to process_batch_for_training. norm_stats are the ones baked into the checkpoint, so train/inference State bins now agree. Also adds a one-shot debug print of the assembled prompt on the first inference step, so the State block can be eyeballed at runtime. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
8085868 to
486eb81
Compare
Claude Code ReviewSummaryFixes a rollout/training mismatch by normalizing the post-transform sample with the checkpoint's Key concerns
Suggestions
VerdictComment — fix direction is right and the diagnosis is convincing, but please verify (1) no double-normalize in Reviewed by Claude · Review workflow |

The algo-side prompt refactor (9af36e0) moved State-block discretization
from build_tokenized_collate into PI._discretize_state_for_sample. The
old path normalized proprio via data_schematic.normalize_data before
clipping to [-1,1] and binning; the new path skips that step. Raw xyz
in meters saturates the [-1,1] clip, producing meaningless bin indices
the model never saw during training (causing jerky rollouts).
Match MultiDataset.getitem's behavior at the rollout boundary:
normalize the post-transform single sample via the model's norm_stats
(itself a MultiDataset) before collating and handing to
process_batch_for_training. norm_stats are the ones baked into the
checkpoint, so train/inference State bins now agree.
Also adds a one-shot debug print of the assembled prompt on the first
inference step, so the State block can be eyeballed at runtime.
Co-Authored-By: Claude Opus 4.7 noreply@anthropic.com