
Fix Megatron backend correctness: grad_scale_func, PP seed variation, weight sync pause#1212

Draft
tyler-griggs wants to merge 2 commits into main from
tgriggs/megatron-correctness-fixes

Conversation


@tyler-griggs tyler-griggs commented Feb 25, 2026

Set grad_scale_func on Megatron model config

  • File: megatron_model_wrapper.py
  • Sets config.grad_scale_func = optimizer.scale_loss so Megatron's pipeline schedule can scale the loss before backward
  • Latent bug: no-op with bf16 default (scale=1.0), but needed for fp16 dynamic loss scaling, MoE auxiliary loss scaling, and explicit loss_scale configs
  • Guarded by if actor_optimizer is not None (ref model has no optimizer)
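The wiring can be sketched with stand-ins (`DummyOptimizer` and the bare `config` namespace below are hypothetical; only the `config.grad_scale_func = optimizer.scale_loss` assignment and the `actor_optimizer is not None` guard come from the PR description):

```python
from types import SimpleNamespace

class DummyOptimizer:
    """Hypothetical stand-in for Megatron's mixed-precision optimizer."""
    def __init__(self, loss_scale):
        self.loss_scale = loss_scale

    def scale_loss(self, loss):
        # Megatron's optimizer.scale_loss multiplies the loss by the
        # current loss scale before backward.
        return loss * self.loss_scale

config = SimpleNamespace(grad_scale_func=None)
actor_optimizer = DummyOptimizer(loss_scale=1024.0)  # e.g. fp16 dynamic scale

# The fix: only set the hook when an optimizer exists (the ref model has none).
if actor_optimizer is not None:
    config.grad_scale_func = actor_optimizer.scale_loss

scaled = config.grad_scale_func(2.0)
print(scaled)  # 2048.0
```

With the default bf16 scale of 1.0 the hook returns the loss unchanged, which is why the missing assignment was a latent rather than an active bug.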

Vary random seed by PP rank

  • File: megatron_strategy.py
  • Adds seed = seed + 100 * mpu.get_pipeline_model_parallel_rank() at the top of set_seed()
  • Ensures different PP stages get different dropout masks and stochastic noise
  • No-op when PP=1 (rank=0, offset=0)
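The offset arithmetic is simple enough to check in isolation; `seed_for_pp_rank` is a hypothetical helper, but the formula matches the PR (`seed + 100 * pp_rank`):

```python
_PP_SEED_OFFSET = 100  # per-stage offset; constant name from the follow-up commit

def seed_for_pp_rank(base_seed: int, pp_rank: int) -> int:
    # PP=1 means rank 0, so the offset is 0 and the seed is unchanged.
    return base_seed + _PP_SEED_OFFSET * pp_rank

print(seed_for_pp_rank(42, 0))  # 42
print(seed_for_pp_rank(42, 1))  # 142
print(seed_for_pp_rank(42, 3))  # 342
```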

Add pause/resume for non-colocated weight sync

  • File: worker_dispatch.py
  • Adds pause_generation() before and resume_generation() after weight broadcast in the non-colocated path
  • Prevents inference engines from seeing partially-updated weights during NCCL transfer
  • Colocated path (sleep/wake_up) is unchanged
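A minimal asyncio sketch of the pause-broadcast-resume pattern, using try/finally so a failed broadcast cannot leave generation paused forever (`FakeEngineClient` and `sync_weights` are hypothetical stand-ins, not the real InferenceEngineClient API):

```python
import asyncio

class FakeEngineClient:
    """Hypothetical stand-in mirroring the pause/resume semantics in this PR."""
    def __init__(self):
        self.paused = False

    async def pause_generation(self):
        if self.paused:
            raise RuntimeError("Generation is already paused, cannot pause again.")
        self.paused = True

    async def resume_generation(self):
        self.paused = False

async def sync_weights(client, broadcast):
    # Pause, broadcast, then resume -- the finally block guarantees resume
    # runs even if the broadcast raises (e.g. an NCCL timeout).
    await client.pause_generation()
    try:
        broadcast()
    finally:
        await client.resume_generation()

def failing_broadcast():
    raise RuntimeError("simulated NCCL timeout")

client = FakeEngineClient()
try:
    asyncio.run(sync_weights(client, failing_broadcast))
except RuntimeError:
    pass
print(client.paused)  # False: resume ran even though the broadcast failed
```

Without the finally, a broadcast failure would leave `paused` set, and the next pause attempt would raise rather than recover.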


…ibuted training

Three correctness fixes identified through gap analysis:

- Set config.grad_scale_func = optimizer.scale_loss on the Megatron model
  config (C1). Latent bug: no-op with bf16 default but needed for fp16
  dynamic loss scaling, MoE auxiliary loss scaling, and explicit loss_scale.

- Vary random seed by PP rank: seed + 100 * pp_rank (C4). Ensures different
  pipeline stages get different dropout masks and stochastic noise, matching
  Megatron standard practice. No-op when PP=1.

- Add pause_generation/resume_generation for non-colocated weight sync (C5).
  Prevents inference engines from seeing partially-updated weights during
  NCCL broadcast. The colocated path (sleep/wake_up) is unchanged.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@tyler-griggs tyler-griggs force-pushed the tgriggs/megatron-correctness-fixes branch from a124b74 to 79d2ace on February 25, 2026 18:50
@tyler-griggs tyler-griggs marked this pull request as ready for review February 25, 2026 19:02

@devin-ai-integration devin-ai-integration bot left a comment


Devin Review found 1 potential issue.

View 5 additional findings in Devin Review.


Comment on lines 397 to 406
```python
else:
    # Non-colocated: pause generation to prevent in-flight requests from
    # reading partially-updated weights during the NCCL broadcast.
    await self._inference_engine_client.pause_generation()
self.broadcast_to_inference_engines(self._inference_engine_client)
self.finish_weight_sync()
if self.colocate_all:
    await self._inference_engine_client.wake_up(tags=["kv_cache"])
else:
    await self._inference_engine_client.resume_generation()
```

🔴 Missing try/finally around pause_generation/resume_generation leaves inference engine permanently paused on failure

In the non-colocated path of save_weights_for_sampler, if broadcast_to_inference_engines() or finish_weight_sync() raises an exception after pause_generation() has been called, resume_generation() is never invoked.

Root Cause and Impact

pause_generation() at skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py:614 sets self.generation_paused_event, which blocks all new and in-flight inference requests. If the broadcast or finish step fails (e.g. NCCL timeout, network error), the event is never cleared because resume_generation() at line 406 is skipped.

This leaves the inference engine permanently paused. Worse, any subsequent call to pause_generation() will raise RuntimeError("Generation is already paused, cannot pause again.") (line 612-613 of inference_engine_client.py), making recovery impossible without a full restart.

The fix is to wrap lines 401-406 in a try/finally block so that resume_generation() is always called in the non-colocated path, even if an exception occurs during broadcast or finish.

Suggested change

```diff
 else:
     # Non-colocated: pause generation to prevent in-flight requests from
     # reading partially-updated weights during the NCCL broadcast.
     await self._inference_engine_client.pause_generation()
-self.broadcast_to_inference_engines(self._inference_engine_client)
-self.finish_weight_sync()
-if self.colocate_all:
-    await self._inference_engine_client.wake_up(tags=["kv_cache"])
-else:
-    await self._inference_engine_client.resume_generation()
+try:
+    self.broadcast_to_inference_engines(self._inference_engine_client)
+    self.finish_weight_sync()
+finally:
+    if self.colocate_all:
+        await self._inference_engine_client.wake_up(tags=["kv_cache"])
+    else:
+        await self._inference_engine_client.resume_generation()
```


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces three important correctness fixes for the Megatron backend, addressing gradient scaling, random seed variation across pipeline parallel ranks, and safe weight synchronization in non-colocated setups. The changes are well-implemented and accompanied by a new suite of tests to ensure their correctness. My review includes a high-severity suggestion to improve the robustness of the weight synchronization logic using a try...finally block to prevent the system from getting stuck in a paused state. I've also included a couple of medium-severity suggestions to improve code maintainability by removing a magic number and refactoring a test for clarity. Overall, this is a solid contribution that enhances the stability of the distributed training backend.

Comment on lines 395 to 406
```python
if self.colocate_all:
    await self._inference_engine_client.wake_up(tags=["weights"])
else:
    # Non-colocated: pause generation to prevent in-flight requests from
    # reading partially-updated weights during the NCCL broadcast.
    await self._inference_engine_client.pause_generation()
self.broadcast_to_inference_engines(self._inference_engine_client)
self.finish_weight_sync()
if self.colocate_all:
    await self._inference_engine_client.wake_up(tags=["kv_cache"])
else:
    await self._inference_engine_client.resume_generation()
```

Severity: high

The current implementation does not guarantee that resume_generation() will be called if an exception occurs during broadcast_to_inference_engines() or finish_weight_sync(). This could leave the inference engines in a permanently paused state. It's safer to wrap the broadcast logic in a try...finally block to ensure that resume_generation() is always executed.

Suggested change

```diff
-if self.colocate_all:
-    await self._inference_engine_client.wake_up(tags=["weights"])
-else:
-    # Non-colocated: pause generation to prevent in-flight requests from
-    # reading partially-updated weights during the NCCL broadcast.
-    await self._inference_engine_client.pause_generation()
-self.broadcast_to_inference_engines(self._inference_engine_client)
-self.finish_weight_sync()
-if self.colocate_all:
-    await self._inference_engine_client.wake_up(tags=["kv_cache"])
-else:
-    await self._inference_engine_client.resume_generation()
+try:
+    if self.colocate_all:
+        await self._inference_engine_client.wake_up(tags=["weights"])
+    else:
+        # Non-colocated: pause generation to prevent in-flight requests from
+        # reading partially-updated weights during the NCCL broadcast.
+        await self._inference_engine_client.pause_generation()
+    self.broadcast_to_inference_engines(self._inference_engine_client)
+finally:
+    self.finish_weight_sync()
+    if self.colocate_all:
+        await self._inference_engine_client.wake_up(tags=["kv_cache"])
+    else:
+        await self._inference_engine_client.resume_generation()
```

```python
# Vary seed by pipeline parallel rank so that different PP stages get
# different dropout masks and stochastic noise (matches Megatron standard
# practice).
seed = seed + 100 * mpu.get_pipeline_model_parallel_rank()
```

Severity: medium

The magic number 100 is used here. It's better to define it as a module-level constant with a descriptive name, e.g., PIPELINE_PARALLEL_SEED_OFFSET = 100. This improves code clarity and makes it easier to change this value in the future if needed.

Comment on lines 91 to 115
```python
def test_seed_offset_by_pp_rank(self):
    from skyrl.backends.skyrl_train.distributed.megatron.megatron_strategy import (
        MegatronStrategy,
    )
    from skyrl.train.config.config import MegatronConfig

    strategy = MegatronStrategy(megatron_config=MegatronConfig(), seed=42)

    with patch("skyrl.backends.skyrl_train.distributed.megatron.megatron_strategy.mpu") as mock_mpu:
        seeds_seen = []

        mock_mpu.get_pipeline_model_parallel_rank.return_value = 0
        with patch("random.seed", side_effect=lambda s: seeds_seen.append(s)):
            strategy.set_seed(42)
        assert seeds_seen[-1] == 42

        mock_mpu.get_pipeline_model_parallel_rank.return_value = 1
        with patch("random.seed", side_effect=lambda s: seeds_seen.append(s)):
            strategy.set_seed(42)
        assert seeds_seen[-1] == 142

        mock_mpu.get_pipeline_model_parallel_rank.return_value = 3
        with patch("random.seed", side_effect=lambda s: seeds_seen.append(s)):
            strategy.set_seed(42)
        assert seeds_seen[-1] == 342
```

Severity: medium

This test can be made more concise and easier to extend by using pytest.mark.parametrize. This would eliminate the repeated blocks of code for each pipeline parallel rank being tested.

Suggested change

```python
@pytest.mark.parametrize("pp_rank, expected_seed", [
    (0, 42),
    (1, 142),
    (3, 342),
])
def test_seed_offset_by_pp_rank(self, pp_rank, expected_seed):
    from skyrl.backends.skyrl_train.distributed.megatron.megatron_strategy import (
        MegatronStrategy,
    )
    from skyrl.train.config.config import MegatronConfig

    strategy = MegatronStrategy(megatron_config=MegatronConfig(), seed=42)

    with patch("skyrl.backends.skyrl_train.distributed.megatron.megatron_strategy.mpu") as mock_mpu:
        seeds_seen = []
        mock_mpu.get_pipeline_model_parallel_rank.return_value = pp_rank
        with patch("random.seed", side_effect=lambda s: seeds_seen.append(s)):
            strategy.set_seed(42)
        assert seeds_seen[-1] == expected_seed
```

@tyler-griggs tyler-griggs marked this pull request as draft February 26, 2026 00:03
…stant

- Wrap non-colocated broadcast in try/finally so resume_generation is
  always called, even if broadcast_to_inference_engines raises. Prevents
  inference engines from being permanently paused on failure.
- Extract seed offset magic number 100 to _PP_SEED_OFFSET constant.
- Parametrize seed variation test with pytest.mark.parametrize.
- Add test for resume-on-failure behavior.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>