From 60044b8bf2ac0a2f21c7c1434f6cc6272f5774b4 Mon Sep 17 00:00:00 2001 From: dubin555 Date: Sat, 14 Mar 2026 10:11:49 +0000 Subject: [PATCH] fix: rename loop variable to avoid shadowing outer `data` in train_step In CriticWorker.train_step and diffusion ActorWorker.train_step, the for-loop used `data` as both the outer batch variable and the loop iteration variable. After the loop, `data.to("cpu")` only moved the last mini-batch to CPU instead of the full batch, causing GPU memory to not be properly released. Rename the loop variable from `data` to `mini_batch` so the outer `data` variable remains intact for proper cleanup after the loop. --- roll/pipeline/base_worker.py | 4 +- .../diffusion/reward_fl/actor_worker.py | 4 +- tests/test_train_step_variable_shadow.py | 81 +++++++++++++++++++ 3 files changed, 85 insertions(+), 4 deletions(-) create mode 100644 tests/test_train_step_variable_shadow.py diff --git a/roll/pipeline/base_worker.py b/roll/pipeline/base_worker.py index 6f1ce7b17..d25ca4844 100644 --- a/roll/pipeline/base_worker.py +++ b/roll/pipeline/base_worker.py @@ -619,12 +619,12 @@ def train_step(self, data: DataProto): dataloader_kwargs={"shuffle": True}, ) - for batch_idx, data in tqdm( + for batch_idx, mini_batch in tqdm( enumerate(dataloader), desc=f"{self.worker_name} train global step {global_step}", total=data.batch.batch_size[0] * self.pipeline_config.ppo_epochs // backward_batch_size, ): - vf_metrics = self.strategy.train_step(batch=data, loss_func=self.loss_func) + vf_metrics = self.strategy.train_step(batch=mini_batch, loss_func=self.loss_func) append_to_dict(metrics, vf_metrics) data.to("cpu") diff --git a/roll/pipeline/diffusion/reward_fl/actor_worker.py b/roll/pipeline/diffusion/reward_fl/actor_worker.py index 085b0f338..8d7f3139a 100644 --- a/roll/pipeline/diffusion/reward_fl/actor_worker.py +++ b/roll/pipeline/diffusion/reward_fl/actor_worker.py @@ -31,12 +31,12 @@ def train_step(self, data: DataProto): dataloader_kwargs={"shuffle": 
False}, ) - for batch_idx, data in tqdm( + for batch_idx, mini_batch in tqdm( enumerate(dataloader), desc=f"{self.worker_name} train global step {global_step}", total=data.batch.batch_size[0] // backward_batch_size, ): - pg_metrics = self.strategy.train_step(batch=data, loss_func=self.loss_func) + pg_metrics = self.strategy.train_step(batch=mini_batch, loss_func=self.loss_func) append_to_dict(metrics, pg_metrics) metrics["actor/loss"] = np.mean(metrics["actor/loss"]) diff --git a/tests/test_train_step_variable_shadow.py b/tests/test_train_step_variable_shadow.py new file mode 100644 index 000000000..6598226c2 --- /dev/null +++ b/tests/test_train_step_variable_shadow.py @@ -0,0 +1,81 @@ +"""Test that loop variables in train_step don't shadow the outer `data` variable. + +Bug: In CriticWorker.train_step and diffusion ActorWorker.train_step, the for-loop +used `data` as both the outer batch variable and the loop iteration variable: + + for batch_idx, data in tqdm(enumerate(dataloader), ...): + ... + data.to("cpu") # Only moves the last mini-batch, not the full batch + +This causes `data.to("cpu")` after the loop to move only the last mini-batch to CPU, +leaving the full batch in GPU memory. + +Fix: Rename loop variable to `mini_batch`. 
+""" + +import ast +import textwrap + + +def _get_for_loop_targets(source: str) -> list: + """Extract all for-loop target variable names from source code.""" + tree = ast.parse(textwrap.dedent(source)) + targets = [] + for node in ast.walk(tree): + if isinstance(node, ast.For): + # Handle `for batch_idx, var in ...` + target = node.target + if isinstance(target, ast.Tuple): + for elt in target.elts: + if isinstance(elt, ast.Name): + targets.append(elt.id) + elif isinstance(target, ast.Name): + targets.append(target.id) + return targets + + +class TestCriticWorkerTrainStepNoShadow: + """Verify CriticWorker.train_step loop variable doesn't shadow outer `data`.""" + + def test_base_worker_critic_train_step_no_data_shadow(self): + """The for-loop in CriticWorker.train_step must not use `data` as loop variable.""" + import roll.pipeline.base_worker as mod + import inspect + + source = inspect.getsource(mod.CriticWorker.train_step) + targets = _get_for_loop_targets(source) + assert "data" not in targets, ( + "CriticWorker.train_step for-loop uses `data` as loop variable, " + "which shadows the outer `data` and causes only the last mini-batch " + "to be moved to CPU instead of the full batch (GPU memory leak)." + ) + + def test_diffusion_actor_worker_train_step_no_data_shadow(self): + """The for-loop in diffusion ActorWorker.train_step must not use `data` as loop variable.""" + import roll.pipeline.diffusion.reward_fl.actor_worker as mod + import inspect + + source = inspect.getsource(mod.ActorWorker.train_step) + targets = _get_for_loop_targets(source) + assert "data" not in targets, ( + "diffusion ActorWorker.train_step for-loop uses `data` as loop variable, " + "which shadows the outer `data` and causes only the last mini-batch " + "to be moved to CPU instead of the full batch (GPU memory leak)." 
+ ) + + def test_loop_variable_is_mini_batch(self): + """The loop variable should be renamed to `mini_batch`.""" + import roll.pipeline.base_worker as mod + import roll.pipeline.diffusion.reward_fl.actor_worker as diffusion_mod + import inspect + + for cls, name in [ + (mod.CriticWorker, "CriticWorker"), + (diffusion_mod.ActorWorker, "diffusion ActorWorker"), + ]: + source = inspect.getsource(cls.train_step) + targets = _get_for_loop_targets(source) + assert "mini_batch" in targets, ( + f"{name}.train_step for-loop should use `mini_batch` as loop variable, " + f"got targets: {targets}" + )