From 7ee058add08ae0d65996a394dac170a817081270 Mon Sep 17 00:00:00 2001 From: EC2 Default User Date: Wed, 25 Feb 2026 18:28:11 +0000 Subject: [PATCH 1/2] Add configurable MoE runtime flags to MegatronConfig Expose 5 MoE runtime config flags as first-class MegatronConfig fields: - moe_token_dispatcher_type (replaces hardcoded "alltoall") - moe_router_load_balancing_type (replaces hardcoded "none") - moe_grouped_gemm (enables fused grouped GEMM for MoE) - moe_router_score_function (e.g. "sigmoid" for GLM/DeepSeek-V3) - moe_router_enable_expert_bias (learned bias for load balancing) Most architecture flags (num_experts, topk, qkv_bias, rotary, etc.) are auto-detected by AutoBridge from the HF config.json. These 5 flags control runtime behavior that is NOT in the HF config. Advanced/model-specific flags can still be passed through the existing transformer_config_kwargs dict, which is applied after these fields. Co-Authored-By: Claude Opus 4.6 (1M context) --- .../workers/megatron/megatron_worker.py | 15 +++- skyrl/train/config/config.py | 6 ++ .../train/config/megatron_config/policy.yaml | 9 +++ skyrl/train/config/megatron_config/ref.yaml | 6 ++ .../distributed/test_moe_config.py | 68 +++++++++++++++++++ 5 files changed, 101 insertions(+), 3 deletions(-) create mode 100644 tests/backends/skyrl_train/distributed/test_moe_config.py diff --git a/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py b/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py index 9ee7a44477..e531df74b9 100644 --- a/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py +++ b/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py @@ -235,9 +235,18 @@ def init_configs( provider.attention_backend = "flash" if flash_attn else "fused" provider.variable_seq_lengths = True provider.masked_softmax_fusion = True - provider.moe_token_dispatcher_type = "alltoall" - provider.moe_router_load_balancing_type = "none" - + # Apply explicit MoE config fields to the provider. 
+ # Dispatcher type and load balancing replace the old hardcoded values; grouped GEMM is + # newly always-set (default False). All can be overridden by transformer_config_kwargs. + provider.moe_token_dispatcher_type = megatron_config.moe_token_dispatcher_type + provider.moe_router_load_balancing_type = megatron_config.moe_router_load_balancing_type + provider.moe_grouped_gemm = megatron_config.moe_grouped_gemm + if megatron_config.moe_router_score_function is not None: + provider.moe_router_score_function = megatron_config.moe_router_score_function + if megatron_config.moe_router_enable_expert_bias is not None: + provider.moe_router_enable_expert_bias = megatron_config.moe_router_enable_expert_bias + + # Apply any additional transformer config kwargs (can override the above). for k, v in transformer_config_kwargs.items(): setattr(provider, k, v) provider.finalize() diff --git a/skyrl/train/config/config.py b/skyrl/train/config/config.py index 5e15cc2af6..bf6e1e49fb 100644 --- a/skyrl/train/config/config.py +++ b/skyrl/train/config/config.py @@ -145,6 +145,12 @@ class MegatronConfig(BaseConfig): context_parallel_size: int = 1 expert_model_parallel_size: int = 1 expert_tensor_parallel_size: Optional[int] = None + # MoE runtime configuration flags + moe_token_dispatcher_type: str = "alltoall" + moe_router_load_balancing_type: str = "none" + moe_grouped_gemm: bool = False + moe_router_score_function: Optional[str] = None + moe_router_enable_expert_bias: Optional[bool] = None ddp_config: MegatronDDPConfig = field(default_factory=MegatronDDPConfig) torch_profiler_config: MegatronTorchProfilerConfig = field(default_factory=MegatronTorchProfilerConfig) lora_config: MegatronLoraConfig = field(default_factory=MegatronLoraConfig) diff --git a/skyrl/train/config/megatron_config/policy.yaml b/skyrl/train/config/megatron_config/policy.yaml index 3ba0dc8f95..2ce2bb010c 100644 --- a/skyrl/train/config/megatron_config/policy.yaml +++ b/skyrl/train/config/megatron_config/policy.yaml @@ -5,6 +5,15 @@ context_parallel_size: 1 
expert_model_parallel_size: 1 expert_tensor_parallel_size: null +# MoE runtime configuration flags. +# These are set on the Megatron provider before finalization. +# For additional MoE flags not listed here, use transformer_config_kwargs. +moe_token_dispatcher_type: "alltoall" +moe_router_load_balancing_type: "none" +moe_grouped_gemm: false +moe_router_score_function: null +moe_router_enable_expert_bias: null + # pass-through config to Megatron's `DistributedDataParallelConfig` object # https://github.com/NVIDIA/Megatron-LM/blob/core_r0.13.0/megatron/core/distributed/distributed_data_parallel_config.py#L8 ddp_config: diff --git a/skyrl/train/config/megatron_config/ref.yaml b/skyrl/train/config/megatron_config/ref.yaml index dd5447a6bb..493986776e 100644 --- a/skyrl/train/config/megatron_config/ref.yaml +++ b/skyrl/train/config/megatron_config/ref.yaml @@ -5,6 +5,12 @@ context_parallel_size: 1 expert_model_parallel_size: 1 expert_tensor_parallel_size: 1 +# MoE runtime configuration flags (should match policy for model consistency) +moe_token_dispatcher_type: "alltoall" +moe_router_load_balancing_type: "none" +moe_grouped_gemm: false +moe_router_score_function: null +moe_router_enable_expert_bias: null model_config_kwargs: {} diff --git a/tests/backends/skyrl_train/distributed/test_moe_config.py b/tests/backends/skyrl_train/distributed/test_moe_config.py new file mode 100644 index 0000000000..12315224c1 --- /dev/null +++ b/tests/backends/skyrl_train/distributed/test_moe_config.py @@ -0,0 +1,68 @@ +"""Tests for MoE config fields on MegatronConfig dataclass.""" + +from skyrl.train.config.config import MegatronConfig, build_nested_dataclass + + +class TestMegatronConfigMoEFields: + """Verify the 5 new MoE config fields exist with correct types and defaults.""" + + def test_moe_fields_exist(self): + cfg = MegatronConfig() + assert hasattr(cfg, "moe_token_dispatcher_type") + assert hasattr(cfg, "moe_router_load_balancing_type") + assert hasattr(cfg, "moe_grouped_gemm") + 
assert hasattr(cfg, "moe_router_score_function") + assert hasattr(cfg, "moe_router_enable_expert_bias") + + def test_moe_field_defaults(self): + cfg = MegatronConfig() + assert cfg.moe_token_dispatcher_type == "alltoall" + assert cfg.moe_router_load_balancing_type == "none" + assert cfg.moe_grouped_gemm is False + assert cfg.moe_router_score_function is None + assert cfg.moe_router_enable_expert_bias is None + + def test_moe_fields_override(self): + cfg = MegatronConfig( + moe_token_dispatcher_type="allgather", + moe_router_load_balancing_type="aux_loss", + moe_grouped_gemm=True, + moe_router_score_function="sigmoid", + moe_router_enable_expert_bias=True, + ) + assert cfg.moe_token_dispatcher_type == "allgather" + assert cfg.moe_router_load_balancing_type == "aux_loss" + assert cfg.moe_grouped_gemm is True + assert cfg.moe_router_score_function == "sigmoid" + assert cfg.moe_router_enable_expert_bias is True + + def test_moe_config_from_dict(self): + """MoE fields should survive dict -> dataclass round-trip.""" + d = { + "moe_token_dispatcher_type": "alltoall", + "moe_router_load_balancing_type": "none", + "moe_grouped_gemm": True, + "moe_router_score_function": "sigmoid", + "moe_router_enable_expert_bias": True, + } + cfg = build_nested_dataclass(MegatronConfig, d) + assert cfg.moe_grouped_gemm is True + assert cfg.moe_router_score_function == "sigmoid" + assert cfg.moe_router_enable_expert_bias is True + + def test_backward_compatible_defaults(self): + """Default values must match the old hardcoded values for backward compat.""" + cfg = MegatronConfig() + assert cfg.moe_token_dispatcher_type == "alltoall" + assert cfg.moe_router_load_balancing_type == "none" + + def test_parallelism_fields_unchanged(self): + """Existing parallelism fields should still work.""" + cfg = MegatronConfig( + tensor_model_parallel_size=4, + pipeline_model_parallel_size=2, + expert_model_parallel_size=8, + ) + assert cfg.tensor_model_parallel_size == 4 + assert 
cfg.pipeline_model_parallel_size == 2 + assert cfg.expert_model_parallel_size == 8 From f97ab3a16435362af664574e9ec21135b7b84a22 Mon Sep 17 00:00:00 2001 From: EC2 Default User Date: Thu, 26 Feb 2026 00:56:25 +0000 Subject: [PATCH 2/2] Add missing assertions for moe_token_dispatcher_type and moe_router_load_balancing_type in test_moe_config_from_dict Co-Authored-By: Claude Opus 4.6 (1M context) --- tests/backends/skyrl_train/distributed/test_moe_config.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/backends/skyrl_train/distributed/test_moe_config.py b/tests/backends/skyrl_train/distributed/test_moe_config.py index 12315224c1..3f3c8d6c57 100644 --- a/tests/backends/skyrl_train/distributed/test_moe_config.py +++ b/tests/backends/skyrl_train/distributed/test_moe_config.py @@ -46,6 +46,8 @@ def test_moe_config_from_dict(self): "moe_router_enable_expert_bias": True, } cfg = build_nested_dataclass(MegatronConfig, d) + assert cfg.moe_token_dispatcher_type == "alltoall" + assert cfg.moe_router_load_balancing_type == "none" assert cfg.moe_grouped_gemm is True assert cfg.moe_router_score_function == "sigmoid" assert cfg.moe_router_enable_expert_bias is True