From b7a5cf608b04cfb23e36ce981383445410ffc616 Mon Sep 17 00:00:00 2001 From: dubin555 Date: Sat, 14 Mar 2026 10:12:54 +0000 Subject: [PATCH] fix: use Cluster instead of WorkerConfig for dynamic batching dp_size When `use_ref_model=False`, the `worker` variable was set to `self.pipeline_config.actor_train` (a WorkerConfig) instead of `self.actor_train` (a Cluster). WorkerConfig does not have `dp_size`, so `worker.dp_size` raises AttributeError when dynamic batching is enabled for reference log prob computation without a separate reference model. Change the else-branch to use `self.actor_train` (Cluster) which has the `dp_size` property. --- roll/pipeline/rlvr/rlvr_pipeline.py | 2 +- tests/test_ref_worker_type_consistency.py | 65 +++++++++++++++++++++++ 2 files changed, 66 insertions(+), 1 deletion(-) create mode 100644 tests/test_ref_worker_type_consistency.py diff --git a/roll/pipeline/rlvr/rlvr_pipeline.py b/roll/pipeline/rlvr/rlvr_pipeline.py index cd5823abc..5a133e9ad 100644 --- a/roll/pipeline/rlvr/rlvr_pipeline.py +++ b/roll/pipeline/rlvr/rlvr_pipeline.py @@ -541,7 +541,7 @@ def run(self): with Timer(name="cal_ref_log_probs", logger=None) as cal_ref_log_probs_timer: if self.pipeline_config.enable_reference: worker_config = self.pipeline_config.reference if self.use_ref_model else self.pipeline_config.actor_train - worker = self.reference if self.use_ref_model else self.pipeline_config.actor_train + worker = self.reference if self.use_ref_model else self.actor_train if worker_config.use_dynamic_batching_in_infer: batch, dynamic_batching_metrics = dynamic_batching_shard( batch, diff --git a/tests/test_ref_worker_type_consistency.py b/tests/test_ref_worker_type_consistency.py new file mode 100644 index 000000000..6abddf7f0 --- /dev/null +++ b/tests/test_ref_worker_type_consistency.py @@ -0,0 +1,65 @@ +"""Test that reference log prob computation uses Cluster (not WorkerConfig) for dp_size. + +Bug: In RLVRPipeline._train, when use_ref_model=False: + + worker_config = self.pipeline_config.reference if self.use_ref_model else self.pipeline_config.actor_train + worker = self.reference if self.use_ref_model else self.pipeline_config.actor_train # BUG + +The `worker` variable is set to `self.pipeline_config.actor_train` (a WorkerConfig), +but it should be `self.actor_train` (a Cluster). WorkerConfig has no `dp_size` attribute, +so `worker.dp_size` on line 548 raises AttributeError. + +Fix: Change `self.pipeline_config.actor_train` to `self.actor_train` on that line. +""" + +import ast +import inspect +import textwrap + + +def test_ref_worker_uses_cluster_not_config(): + """When use_ref_model=False, `worker` must be `self.actor_train` (Cluster), not `self.pipeline_config.actor_train` (WorkerConfig).""" + import roll.pipeline.rlvr.rlvr_pipeline as mod + + source = inspect.getsource(mod.RLVRPipeline) + + # The buggy pattern: `self.pipeline_config.actor_train` used where `self.actor_train` is needed + # The fix ensures `worker = ... else self.actor_train` (without pipeline_config prefix) + # + # We check: in the line that assigns `worker = ...`, the else-branch must NOT + # reference `self.pipeline_config.actor_train` + tree = ast.parse(textwrap.dedent(source)) + + found_worker_assign = False + for node in ast.walk(tree): + if not isinstance(node, ast.Assign): + continue + # Look for: worker = + for target in node.targets: + if isinstance(target, ast.Name) and target.id == "worker": + if isinstance(node.value, ast.IfExp): + found_worker_assign = True + # Check the orelse (else branch) of the ternary + orelse = node.value.orelse + # It should be self.actor_train, NOT self.pipeline_config.actor_train + source_segment = ast.dump(orelse) + assert "pipeline_config" not in source_segment, ( + "Bug: `worker` assignment else-branch references " + "`self.pipeline_config.actor_train` (WorkerConfig) instead of " + "`self.actor_train` (Cluster). WorkerConfig has no `dp_size` property." + ) + + assert found_worker_assign, ( + "Could not find `worker = ... if ... else ...` ternary assignment in RLVRPipeline. " + "The code structure may have changed." + ) + + +def test_worker_config_has_no_dp_size(): + """WorkerConfig should NOT have dp_size - it's only on Cluster.""" + from roll.configs.worker_config import WorkerConfig + + assert not hasattr(WorkerConfig, "dp_size"), ( + "WorkerConfig should not have dp_size attribute; " + "dp_size is a property of Cluster, not WorkerConfig." + )