diff --git a/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py b/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py
index 9ee7a44477..e531df74b9 100644
--- a/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py
+++ b/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py
@@ -235,9 +235,18 @@ def init_configs(
     provider.attention_backend = "flash" if flash_attn else "fused"
     provider.variable_seq_lengths = True
     provider.masked_softmax_fusion = True
-    provider.moe_token_dispatcher_type = "alltoall"
-    provider.moe_router_load_balancing_type = "none"
-
+    # Apply explicit MoE config fields to the provider.
+    # These replace the previously hardcoded values and can be further
+    # overridden by transformer_config_kwargs if needed.
+    provider.moe_token_dispatcher_type = megatron_config.moe_token_dispatcher_type
+    provider.moe_router_load_balancing_type = megatron_config.moe_router_load_balancing_type
+    provider.moe_grouped_gemm = megatron_config.moe_grouped_gemm
+    if megatron_config.moe_router_score_function is not None:
+        provider.moe_router_score_function = megatron_config.moe_router_score_function
+    if megatron_config.moe_router_enable_expert_bias is not None:
+        provider.moe_router_enable_expert_bias = megatron_config.moe_router_enable_expert_bias
+
+    # Apply any additional transformer config kwargs (can override the above).
     for k, v in transformer_config_kwargs.items():
         setattr(provider, k, v)
     provider.finalize()
diff --git a/skyrl/train/config/config.py b/skyrl/train/config/config.py
index 5e15cc2af6..bf6e1e49fb 100644
--- a/skyrl/train/config/config.py
+++ b/skyrl/train/config/config.py
@@ -145,6 +145,12 @@ class MegatronConfig(BaseConfig):
     context_parallel_size: int = 1
     expert_model_parallel_size: int = 1
     expert_tensor_parallel_size: Optional[int] = None
+    # MoE runtime configuration flags
+    moe_token_dispatcher_type: str = "alltoall"
+    moe_router_load_balancing_type: str = "none"
+    moe_grouped_gemm: bool = False
+    moe_router_score_function: Optional[str] = None
+    moe_router_enable_expert_bias: Optional[bool] = None
     ddp_config: MegatronDDPConfig = field(default_factory=MegatronDDPConfig)
     torch_profiler_config: MegatronTorchProfilerConfig = field(default_factory=MegatronTorchProfilerConfig)
     lora_config: MegatronLoraConfig = field(default_factory=MegatronLoraConfig)
diff --git a/skyrl/train/config/megatron_config/policy.yaml b/skyrl/train/config/megatron_config/policy.yaml
index 3ba0dc8f95..2ce2bb010c 100644
--- a/skyrl/train/config/megatron_config/policy.yaml
+++ b/skyrl/train/config/megatron_config/policy.yaml
@@ -5,6 +5,15 @@
 context_parallel_size: 1
 expert_model_parallel_size: 1
 expert_tensor_parallel_size: null
+# MoE runtime configuration flags.
+# These are set on the Megatron provider before finalization.
+# For additional MoE flags not listed here, use transformer_config_kwargs.
+moe_token_dispatcher_type: "alltoall"
+moe_router_load_balancing_type: "none"
+moe_grouped_gemm: false
+moe_router_score_function: null
+moe_router_enable_expert_bias: null
+
 # pass-through config to Megatron's `DistributedDataParallelConfig` object
 # https://github.com/NVIDIA/Megatron-LM/blob/core_r0.13.0/megatron/core/distributed/distributed_data_parallel_config.py#L8
 ddp_config:
diff --git a/skyrl/train/config/megatron_config/ref.yaml b/skyrl/train/config/megatron_config/ref.yaml
index dd5447a6bb..493986776e 100644
--- a/skyrl/train/config/megatron_config/ref.yaml
+++ b/skyrl/train/config/megatron_config/ref.yaml
@@ -5,6 +5,12 @@
 context_parallel_size: 1
 expert_model_parallel_size: 1
 expert_tensor_parallel_size: 1
+# MoE runtime configuration flags (should match policy for model consistency)
+moe_token_dispatcher_type: "alltoall"
+moe_router_load_balancing_type: "none"
+moe_grouped_gemm: false
+moe_router_score_function: null
+moe_router_enable_expert_bias: null
 
 model_config_kwargs: {}
 
diff --git a/tests/backends/skyrl_train/distributed/test_moe_config.py b/tests/backends/skyrl_train/distributed/test_moe_config.py
new file mode 100644
index 0000000000..3f3c8d6c57
--- /dev/null
+++ b/tests/backends/skyrl_train/distributed/test_moe_config.py
@@ -0,0 +1,70 @@
+"""Tests for MoE config fields on MegatronConfig dataclass."""
+
+from skyrl.train.config.config import MegatronConfig, build_nested_dataclass
+
+
+class TestMegatronConfigMoEFields:
+    """Verify the 5 new MoE config fields exist with correct types and defaults."""
+
+    def test_moe_fields_exist(self):
+        cfg = MegatronConfig()
+        assert hasattr(cfg, "moe_token_dispatcher_type")
+        assert hasattr(cfg, "moe_router_load_balancing_type")
+        assert hasattr(cfg, "moe_grouped_gemm")
+        assert hasattr(cfg, "moe_router_score_function")
+        assert hasattr(cfg, "moe_router_enable_expert_bias")
+
+    def test_moe_field_defaults(self):
+        cfg = MegatronConfig()
+        assert cfg.moe_token_dispatcher_type == "alltoall"
+        assert cfg.moe_router_load_balancing_type == "none"
+        assert cfg.moe_grouped_gemm is False
+        assert cfg.moe_router_score_function is None
+        assert cfg.moe_router_enable_expert_bias is None
+
+    def test_moe_fields_override(self):
+        cfg = MegatronConfig(
+            moe_token_dispatcher_type="allgather",
+            moe_router_load_balancing_type="aux_loss",
+            moe_grouped_gemm=True,
+            moe_router_score_function="sigmoid",
+            moe_router_enable_expert_bias=True,
+        )
+        assert cfg.moe_token_dispatcher_type == "allgather"
+        assert cfg.moe_router_load_balancing_type == "aux_loss"
+        assert cfg.moe_grouped_gemm is True
+        assert cfg.moe_router_score_function == "sigmoid"
+        assert cfg.moe_router_enable_expert_bias is True
+
+    def test_moe_config_from_dict(self):
+        """MoE fields should survive dict -> dataclass round-trip."""
+        d = {
+            "moe_token_dispatcher_type": "alltoall",
+            "moe_router_load_balancing_type": "none",
+            "moe_grouped_gemm": True,
+            "moe_router_score_function": "sigmoid",
+            "moe_router_enable_expert_bias": True,
+        }
+        cfg = build_nested_dataclass(MegatronConfig, d)
+        assert cfg.moe_token_dispatcher_type == "alltoall"
+        assert cfg.moe_router_load_balancing_type == "none"
+        assert cfg.moe_grouped_gemm is True
+        assert cfg.moe_router_score_function == "sigmoid"
+        assert cfg.moe_router_enable_expert_bias is True
+
+    def test_backward_compatible_defaults(self):
+        """Default values must match the old hardcoded values for backward compat."""
+        cfg = MegatronConfig()
+        assert cfg.moe_token_dispatcher_type == "alltoall"
+        assert cfg.moe_router_load_balancing_type == "none"
+
+    def test_parallelism_fields_unchanged(self):
+        """Existing parallelism fields should still work."""
+        cfg = MegatronConfig(
+            tensor_model_parallel_size=4,
+            pipeline_model_parallel_size=2,
+            expert_model_parallel_size=8,
+        )
+        assert cfg.tensor_model_parallel_size == 4
+        assert cfg.pipeline_model_parallel_size == 2
+        assert cfg.expert_model_parallel_size == 8