Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,9 +235,18 @@ def init_configs(
provider.attention_backend = "flash" if flash_attn else "fused"
provider.variable_seq_lengths = True
provider.masked_softmax_fusion = True
provider.moe_token_dispatcher_type = "alltoall"
provider.moe_router_load_balancing_type = "none"

# Apply explicit MoE config fields to the provider.
# These replace the previously hardcoded values and can be further
# overridden by transformer_config_kwargs if needed.
provider.moe_token_dispatcher_type = megatron_config.moe_token_dispatcher_type
provider.moe_router_load_balancing_type = megatron_config.moe_router_load_balancing_type
provider.moe_grouped_gemm = megatron_config.moe_grouped_gemm
if megatron_config.moe_router_score_function is not None:
provider.moe_router_score_function = megatron_config.moe_router_score_function
if megatron_config.moe_router_enable_expert_bias is not None:
provider.moe_router_enable_expert_bias = megatron_config.moe_router_enable_expert_bias
Comment on lines +244 to +247
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

To improve maintainability and reduce repetition, you can loop over the optional MoE configuration fields. This makes it easier to add more optional fields in the future.

Suggested change
if megatron_config.moe_router_score_function is not None:
provider.moe_router_score_function = megatron_config.moe_router_score_function
if megatron_config.moe_router_enable_expert_bias is not None:
provider.moe_router_enable_expert_bias = megatron_config.moe_router_enable_expert_bias
for field_name in ("moe_router_score_function", "moe_router_enable_expert_bias"):
value = getattr(megatron_config, field_name)
if value is not None:
setattr(provider, field_name, value)


# Apply any additional transformer config kwargs (can override the above).
for k, v in transformer_config_kwargs.items():
setattr(provider, k, v)
provider.finalize()
Expand Down
6 changes: 6 additions & 0 deletions skyrl/train/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,12 @@ class MegatronConfig(BaseConfig):
context_parallel_size: int = 1
expert_model_parallel_size: int = 1
expert_tensor_parallel_size: Optional[int] = None
# MoE runtime configuration flags.
# The dispatcher and load-balancing defaults ("alltoall" / "none") match the
# values that were previously hardcoded on the Megatron provider, so existing
# configs keep their old behavior.
moe_token_dispatcher_type: str = "alltoall"
moe_router_load_balancing_type: str = "none"
# Always copied onto the provider (False leaves grouped GEMM disabled).
moe_grouped_gemm: bool = False
# None means "keep the Megatron provider's own default": the worker only
# copies these two fields onto the provider when they are not None.
moe_router_score_function: Optional[str] = None
moe_router_enable_expert_bias: Optional[bool] = None
ddp_config: MegatronDDPConfig = field(default_factory=MegatronDDPConfig)
torch_profiler_config: MegatronTorchProfilerConfig = field(default_factory=MegatronTorchProfilerConfig)
lora_config: MegatronLoraConfig = field(default_factory=MegatronLoraConfig)
Expand Down
9 changes: 9 additions & 0 deletions skyrl/train/config/megatron_config/policy.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,15 @@ context_parallel_size: 1
expert_model_parallel_size: 1
expert_tensor_parallel_size: null

# MoE runtime configuration flags.
# These are set on the Megatron provider before finalization.
# For additional MoE flags not listed here, use transformer_config_kwargs.
moe_token_dispatcher_type: "alltoall"
moe_router_load_balancing_type: "none"
moe_grouped_gemm: false
# null means "keep the Megatron provider's own default" — these two keys are
# only applied to the provider when set to a non-null value.
moe_router_score_function: null
moe_router_enable_expert_bias: null

# pass-through config to Megatron's `DistributedDataParallelConfig` object
# https://github.com/NVIDIA/Megatron-LM/blob/core_r0.13.0/megatron/core/distributed/distributed_data_parallel_config.py#L8
ddp_config:
Expand Down
6 changes: 6 additions & 0 deletions skyrl/train/config/megatron_config/ref.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@ context_parallel_size: 1
expert_model_parallel_size: 1
expert_tensor_parallel_size: 1

# MoE runtime configuration flags (should match policy for model consistency)
# Applied to the Megatron provider before finalization; null means "keep the
# provider's own default". Extra MoE flags go in transformer_config_kwargs.
moe_token_dispatcher_type: "alltoall"
moe_router_load_balancing_type: "none"
moe_grouped_gemm: false
moe_router_score_function: null
moe_router_enable_expert_bias: null

model_config_kwargs: {}

Expand Down
70 changes: 70 additions & 0 deletions tests/backends/skyrl_train/distributed/test_moe_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
"""Tests for MoE config fields on MegatronConfig dataclass."""

from skyrl.train.config.config import MegatronConfig, build_nested_dataclass


class TestMegatronConfigMoEFields:
"""Verify the 5 new MoE config fields exist with correct types and defaults."""

def test_moe_fields_exist(self):
    """All five new MoE fields must be present on a default config."""
    expected_fields = (
        "moe_token_dispatcher_type",
        "moe_router_load_balancing_type",
        "moe_grouped_gemm",
        "moe_router_score_function",
        "moe_router_enable_expert_bias",
    )
    cfg = MegatronConfig()
    for field_name in expected_fields:
        assert hasattr(cfg, field_name)

def test_moe_field_defaults(self):
    """A fresh config carries the documented MoE defaults."""
    cfg = MegatronConfig()
    # Optional fields default to None (meaning: leave provider defaults alone).
    assert cfg.moe_router_score_function is None
    assert cfg.moe_router_enable_expert_bias is None
    # Concrete defaults.
    assert cfg.moe_grouped_gemm is False
    assert cfg.moe_token_dispatcher_type == "alltoall"
    assert cfg.moe_router_load_balancing_type == "none"

def test_moe_fields_override(self):
    """Every MoE field accepts an explicit non-default value."""
    overrides = {
        "moe_token_dispatcher_type": "allgather",
        "moe_router_load_balancing_type": "aux_loss",
        "moe_grouped_gemm": True,
        "moe_router_score_function": "sigmoid",
        "moe_router_enable_expert_bias": True,
    }
    cfg = MegatronConfig(**overrides)
    # Each override must land on the dataclass unchanged.
    for field_name, expected in overrides.items():
        assert getattr(cfg, field_name) == expected

def test_moe_config_from_dict(self):
    """MoE fields should survive dict -> dataclass round-trip."""
    raw = {
        "moe_token_dispatcher_type": "alltoall",
        "moe_router_load_balancing_type": "none",
        "moe_grouped_gemm": True,
        "moe_router_score_function": "sigmoid",
        "moe_router_enable_expert_bias": True,
    }
    cfg = build_nested_dataclass(MegatronConfig, raw)
    # Every key fed in must come back unchanged on the built dataclass.
    for field_name, expected in raw.items():
        assert getattr(cfg, field_name) == expected
Comment on lines +51 to +53
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This test is not exhaustive. It's good practice to assert all fields that are being set from the dictionary to ensure the build_nested_dataclass function works as expected for all new fields. The assertions for moe_token_dispatcher_type and moe_router_load_balancing_type are missing.

Suggested change
assert cfg.moe_grouped_gemm is True
assert cfg.moe_router_score_function == "sigmoid"
assert cfg.moe_router_enable_expert_bias is True
assert cfg.moe_token_dispatcher_type == "alltoall"
assert cfg.moe_router_load_balancing_type == "none"
assert cfg.moe_grouped_gemm is True
assert cfg.moe_router_score_function == "sigmoid"
assert cfg.moe_router_enable_expert_bias is True


def test_backward_compatible_defaults(self):
    """Default values must match the old hardcoded values for backward compat."""
    defaults = MegatronConfig()
    # "alltoall" and "none" were hardcoded on the provider before these
    # fields became configurable; the defaults must reproduce them.
    assert defaults.moe_token_dispatcher_type == "alltoall"
    assert defaults.moe_router_load_balancing_type == "none"

def test_parallelism_fields_unchanged(self):
    """Existing parallelism fields should still work."""
    sizes = {
        "tensor_model_parallel_size": 4,
        "pipeline_model_parallel_size": 2,
        "expert_model_parallel_size": 8,
    }
    cfg = MegatronConfig(**sizes)
    # Pre-existing fields must be unaffected by the new MoE additions.
    for field_name, expected in sizes.items():
        assert getattr(cfg, field_name) == expected