Fix Megatron backend correctness: grad_scale_func, PP seed variation, weight sync pause#1212
tyler-griggs wants to merge 2 commits into main from
Conversation
…ibuted training

Three correctness fixes identified through gap analysis:

- Set `config.grad_scale_func = optimizer.scale_loss` on the Megatron model config (C1). Latent bug: a no-op with the bf16 default, but needed for fp16 dynamic loss scaling, MoE auxiliary loss scaling, and explicit `loss_scale`.
- Vary the random seed by PP rank: `seed + 100 * pp_rank` (C4). Ensures different pipeline stages get different dropout masks and stochastic noise, matching Megatron standard practice. No-op when PP=1.
- Add `pause_generation`/`resume_generation` for non-colocated weight sync (C5). Prevents inference engines from seeing partially updated weights during the NCCL broadcast. The colocated path (`sleep`/`wake_up`) is unchanged.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
a124b74 to 79d2ace
```python
else:
    # Non-colocated: pause generation to prevent in-flight requests from
    # reading partially-updated weights during the NCCL broadcast.
    await self._inference_engine_client.pause_generation()
self.broadcast_to_inference_engines(self._inference_engine_client)
self.finish_weight_sync()
if self.colocate_all:
    await self._inference_engine_client.wake_up(tags=["kv_cache"])
else:
    await self._inference_engine_client.resume_generation()
```
🔴 Missing try/finally around pause_generation/resume_generation leaves inference engine permanently paused on failure
In the non-colocated path of save_weights_for_sampler, if broadcast_to_inference_engines() or finish_weight_sync() raises an exception after pause_generation() has been called, resume_generation() is never invoked.
Root Cause and Impact
pause_generation() at skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py:614 sets self.generation_paused_event, which blocks all new and in-flight inference requests. If the broadcast or finish step fails (e.g. NCCL timeout, network error), the event is never cleared because resume_generation() at line 406 is skipped.
This leaves the inference engine permanently paused. Worse, any subsequent call to pause_generation() will raise RuntimeError("Generation is already paused, cannot pause again.") (line 612-613 of inference_engine_client.py), making recovery impossible without a full restart.
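The failure mode is easy to reproduce in isolation. The sketch below is hypothetical (the `PauseDemo` class and its method bodies are stand-ins, not the actual `inference_engine_client.py` code); only the `asyncio.Event`-based pause semantics described above are mirrored:

```python
import asyncio

class PauseDemo:
    """Minimal stand-in for an event-based generation pause. Illustrates why
    a skipped resume_generation() leaves all requests blocked permanently."""

    def __init__(self):
        self._paused = asyncio.Event()  # set => generation is paused

    async def pause_generation(self):
        if self._paused.is_set():
            # Mirrors the RuntimeError described above: double-pause is fatal.
            raise RuntimeError("Generation is already paused, cannot pause again.")
        self._paused.set()

    async def resume_generation(self):
        self._paused.clear()

    async def generate(self) -> str:
        # New and in-flight requests spin here while paused; this never
        # returns if resume_generation() was skipped.
        while self._paused.is_set():
            await asyncio.sleep(0.01)
        return "ok"

async def main():
    client = PauseDemo()
    await client.pause_generation()
    try:
        raise RuntimeError("simulated NCCL broadcast failure")
    except RuntimeError:
        pass  # without try/finally, resume_generation() is never reached
    try:
        await asyncio.wait_for(client.generate(), timeout=0.1)
    except asyncio.TimeoutError:
        print("request blocked: engine is wedged")

asyncio.run(main())
```

Note that recovery is impossible from the client side: a retry of the sync path hits the double-pause `RuntimeError` before it can broadcast again.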
The fix is to wrap lines 401-406 in a try/finally block so that resume_generation() is always called in the non-colocated path, even if an exception occurs during broadcast or finish.
Suggested change:

```python
else:
    # Non-colocated: pause generation to prevent in-flight requests from
    # reading partially-updated weights during the NCCL broadcast.
    await self._inference_engine_client.pause_generation()
try:
    self.broadcast_to_inference_engines(self._inference_engine_client)
    self.finish_weight_sync()
finally:
    if self.colocate_all:
        await self._inference_engine_client.wake_up(tags=["kv_cache"])
    else:
        await self._inference_engine_client.resume_generation()
```
Code Review
This pull request introduces three important correctness fixes for the Megatron backend, addressing gradient scaling, random seed variation across pipeline parallel ranks, and safe weight synchronization in non-colocated setups. The changes are well-implemented and accompanied by a new suite of tests to ensure their correctness. My review includes a high-severity suggestion to improve the robustness of the weight synchronization logic using a try...finally block to prevent the system from getting stuck in a paused state. I've also included a couple of medium-severity suggestions to improve code maintainability by removing a magic number and refactoring a test for clarity. Overall, this is a solid contribution that enhances the stability of the distributed training backend.
```python
if self.colocate_all:
    await self._inference_engine_client.wake_up(tags=["weights"])
else:
    # Non-colocated: pause generation to prevent in-flight requests from
    # reading partially-updated weights during the NCCL broadcast.
    await self._inference_engine_client.pause_generation()
self.broadcast_to_inference_engines(self._inference_engine_client)
self.finish_weight_sync()
if self.colocate_all:
    await self._inference_engine_client.wake_up(tags=["kv_cache"])
else:
    await self._inference_engine_client.resume_generation()
```
The current implementation does not guarantee that resume_generation() will be called if an exception occurs during broadcast_to_inference_engines() or finish_weight_sync(). This could leave the inference engines in a permanently paused state. It's safer to wrap the broadcast logic in a try...finally block to ensure that resume_generation() is always executed.
Suggested change:

```python
try:
    if self.colocate_all:
        await self._inference_engine_client.wake_up(tags=["weights"])
    else:
        # Non-colocated: pause generation to prevent in-flight requests from
        # reading partially-updated weights during the NCCL broadcast.
        await self._inference_engine_client.pause_generation()
    self.broadcast_to_inference_engines(self._inference_engine_client)
finally:
    self.finish_weight_sync()
    if self.colocate_all:
        await self._inference_engine_client.wake_up(tags=["kv_cache"])
    else:
        await self._inference_engine_client.resume_generation()
```
```python
# Vary seed by pipeline parallel rank so that different PP stages get
# different dropout masks and stochastic noise (matches Megatron standard
# practice).
seed = seed + 100 * mpu.get_pipeline_model_parallel_rank()
```
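The arithmetic is easy to sanity-check in isolation. `seed_for_pp_rank` below is a hypothetical helper that mirrors the single line above, not a function from the codebase:

```python
PP_SEED_OFFSET = 100  # the per-rank offset used in the diff above

def seed_for_pp_rank(base_seed: int, pp_rank: int) -> int:
    """Derive a per-stage seed so each pipeline stage draws different
    dropout masks. A no-op when PP=1, since pp_rank is then always 0."""
    return base_seed + PP_SEED_OFFSET * pp_rank

print([seed_for_pp_rank(42, r) for r in range(4)])  # [42, 142, 242, 342]
```

With a typical base seed and realistic PP sizes (≤ 16 or so), the offsets stay far apart, so no two stages collide on the same seed.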
```python
def test_seed_offset_by_pp_rank(self):
    from skyrl.backends.skyrl_train.distributed.megatron.megatron_strategy import (
        MegatronStrategy,
    )
    from skyrl.train.config.config import MegatronConfig

    strategy = MegatronStrategy(megatron_config=MegatronConfig(), seed=42)

    with patch("skyrl.backends.skyrl_train.distributed.megatron.megatron_strategy.mpu") as mock_mpu:
        seeds_seen = []

        mock_mpu.get_pipeline_model_parallel_rank.return_value = 0
        with patch("random.seed", side_effect=lambda s: seeds_seen.append(s)):
            strategy.set_seed(42)
        assert seeds_seen[-1] == 42

        mock_mpu.get_pipeline_model_parallel_rank.return_value = 1
        with patch("random.seed", side_effect=lambda s: seeds_seen.append(s)):
            strategy.set_seed(42)
        assert seeds_seen[-1] == 142

        mock_mpu.get_pipeline_model_parallel_rank.return_value = 3
        with patch("random.seed", side_effect=lambda s: seeds_seen.append(s)):
            strategy.set_seed(42)
        assert seeds_seen[-1] == 342
```
This test can be made more concise and easier to extend by using pytest.mark.parametrize. This would eliminate the repeated blocks of code for each pipeline parallel rank being tested.
Suggested change:

```python
@pytest.mark.parametrize("pp_rank, expected_seed", [
    (0, 42),
    (1, 142),
    (3, 342),
])
def test_seed_offset_by_pp_rank(self, pp_rank, expected_seed):
    from skyrl.backends.skyrl_train.distributed.megatron.megatron_strategy import (
        MegatronStrategy,
    )
    from skyrl.train.config.config import MegatronConfig

    strategy = MegatronStrategy(megatron_config=MegatronConfig(), seed=42)

    with patch("skyrl.backends.skyrl_train.distributed.megatron.megatron_strategy.mpu") as mock_mpu:
        seeds_seen = []
        mock_mpu.get_pipeline_model_parallel_rank.return_value = pp_rank
        with patch("random.seed", side_effect=lambda s: seeds_seen.append(s)):
            strategy.set_seed(42)
        assert seeds_seen[-1] == expected_seed
```
…stant

- Wrap the non-colocated broadcast in try/finally so `resume_generation` is always called, even if `broadcast_to_inference_engines` raises. Prevents inference engines from being permanently paused on failure.
- Extract the seed offset magic number 100 to a `_PP_SEED_OFFSET` constant.
- Parametrize the seed variation test with `pytest.mark.parametrize`.
- Add a test for resume-on-failure behavior.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
**Set `grad_scale_func` on Megatron model config**
- `megatron_model_wrapper.py`: set `config.grad_scale_func = optimizer.scale_loss` so Megatron's pipeline schedule can scale the loss before backward
- Guarded by `if actor_optimizer is not None` (the ref model has no optimizer)

**Vary random seed by PP rank**
- `megatron_strategy.py`: `seed = seed + 100 * mpu.get_pipeline_model_parallel_rank()` at the top of `set_seed()`

**Add pause/resume for non-colocated weight sync**
- `worker_dispatch.py`: `pause_generation()` before and `resume_generation()` after the weight broadcast in the non-colocated path
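The `grad_scale_func` wiring can be sketched without Megatron. `TinyOptimizer`, `TinyConfig`, and `backward_step` below are illustrative stand-ins, not Megatron's actual classes; only the "scale the loss iff `config.grad_scale_func` is set" branch mirrors what the pipeline schedule does before calling backward:

```python
class TinyOptimizer:
    """Stand-in for Megatron's mixed-precision optimizer exposing scale_loss
    (used for fp16 dynamic loss scaling)."""

    def __init__(self, loss_scale: float):
        self.loss_scale = loss_scale

    def scale_loss(self, loss: float) -> float:
        return loss * self.loss_scale


class TinyConfig:
    grad_scale_func = None  # default: no extra scaling (the bf16 no-op case)


def backward_step(config: TinyConfig, loss: float) -> float:
    # Mirrors the relevant branch of a pipeline schedule: apply the
    # configured scale function to the loss before backpropagating.
    if config.grad_scale_func is not None:
        loss = config.grad_scale_func(loss)
    return loss  # in real code this would be followed by loss.backward()


optimizer = TinyOptimizer(loss_scale=128.0)
config = TinyConfig()
config.grad_scale_func = optimizer.scale_loss  # the wiring this PR adds
print(backward_step(config, 0.5))  # 64.0
```

This also shows why the bug was latent: with `grad_scale_func` left as `None` (the bf16 default), the branch is skipped and the loss passes through unchanged, so nothing breaks until fp16 loss scaling, MoE auxiliary loss scaling, or an explicit `loss_scale` is enabled.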