-
Notifications
You must be signed in to change notification settings - Fork 263
Add configurable MoE runtime flags to MegatronConfig #1213
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,70 @@ | ||||||||||||||||||
| """Tests for MoE config fields on MegatronConfig dataclass.""" | ||||||||||||||||||
|
|
||||||||||||||||||
| from skyrl.train.config.config import MegatronConfig, build_nested_dataclass | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
||||||||||||||||||
| class TestMegatronConfigMoEFields: | ||||||||||||||||||
| """Verify the 5 new MoE config fields exist with correct types and defaults.""" | ||||||||||||||||||
|
|
||||||||||||||||||
| def test_moe_fields_exist(self): | ||||||||||||||||||
| cfg = MegatronConfig() | ||||||||||||||||||
| assert hasattr(cfg, "moe_token_dispatcher_type") | ||||||||||||||||||
| assert hasattr(cfg, "moe_router_load_balancing_type") | ||||||||||||||||||
| assert hasattr(cfg, "moe_grouped_gemm") | ||||||||||||||||||
| assert hasattr(cfg, "moe_router_score_function") | ||||||||||||||||||
| assert hasattr(cfg, "moe_router_enable_expert_bias") | ||||||||||||||||||
|
|
||||||||||||||||||
| def test_moe_field_defaults(self): | ||||||||||||||||||
| cfg = MegatronConfig() | ||||||||||||||||||
| assert cfg.moe_token_dispatcher_type == "alltoall" | ||||||||||||||||||
| assert cfg.moe_router_load_balancing_type == "none" | ||||||||||||||||||
| assert cfg.moe_grouped_gemm is False | ||||||||||||||||||
| assert cfg.moe_router_score_function is None | ||||||||||||||||||
| assert cfg.moe_router_enable_expert_bias is None | ||||||||||||||||||
|
|
||||||||||||||||||
| def test_moe_fields_override(self): | ||||||||||||||||||
| cfg = MegatronConfig( | ||||||||||||||||||
| moe_token_dispatcher_type="allgather", | ||||||||||||||||||
| moe_router_load_balancing_type="aux_loss", | ||||||||||||||||||
| moe_grouped_gemm=True, | ||||||||||||||||||
| moe_router_score_function="sigmoid", | ||||||||||||||||||
| moe_router_enable_expert_bias=True, | ||||||||||||||||||
| ) | ||||||||||||||||||
| assert cfg.moe_token_dispatcher_type == "allgather" | ||||||||||||||||||
| assert cfg.moe_router_load_balancing_type == "aux_loss" | ||||||||||||||||||
| assert cfg.moe_grouped_gemm is True | ||||||||||||||||||
| assert cfg.moe_router_score_function == "sigmoid" | ||||||||||||||||||
| assert cfg.moe_router_enable_expert_bias is True | ||||||||||||||||||
|
|
||||||||||||||||||
| def test_moe_config_from_dict(self): | ||||||||||||||||||
| """MoE fields should survive dict -> dataclass round-trip.""" | ||||||||||||||||||
| d = { | ||||||||||||||||||
| "moe_token_dispatcher_type": "alltoall", | ||||||||||||||||||
| "moe_router_load_balancing_type": "none", | ||||||||||||||||||
| "moe_grouped_gemm": True, | ||||||||||||||||||
| "moe_router_score_function": "sigmoid", | ||||||||||||||||||
| "moe_router_enable_expert_bias": True, | ||||||||||||||||||
| } | ||||||||||||||||||
| cfg = build_nested_dataclass(MegatronConfig, d) | ||||||||||||||||||
| assert cfg.moe_token_dispatcher_type == "alltoall" | ||||||||||||||||||
| assert cfg.moe_router_load_balancing_type == "none" | ||||||||||||||||||
| assert cfg.moe_grouped_gemm is True | ||||||||||||||||||
| assert cfg.moe_router_score_function == "sigmoid" | ||||||||||||||||||
| assert cfg.moe_router_enable_expert_bias is True | ||||||||||||||||||
|
Comment on lines
+51
to
+53
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This test is not exhaustive. It's good practice to assert all fields that are being set from the dictionary to ensure the
Suggested change
|
||||||||||||||||||
|
|
||||||||||||||||||
| def test_backward_compatible_defaults(self): | ||||||||||||||||||
| """Default values must match the old hardcoded values for backward compat.""" | ||||||||||||||||||
| cfg = MegatronConfig() | ||||||||||||||||||
| assert cfg.moe_token_dispatcher_type == "alltoall" | ||||||||||||||||||
| assert cfg.moe_router_load_balancing_type == "none" | ||||||||||||||||||
|
|
||||||||||||||||||
| def test_parallelism_fields_unchanged(self): | ||||||||||||||||||
| """Existing parallelism fields should still work.""" | ||||||||||||||||||
| cfg = MegatronConfig( | ||||||||||||||||||
| tensor_model_parallel_size=4, | ||||||||||||||||||
| pipeline_model_parallel_size=2, | ||||||||||||||||||
| expert_model_parallel_size=8, | ||||||||||||||||||
| ) | ||||||||||||||||||
| assert cfg.tensor_model_parallel_size == 4 | ||||||||||||||||||
| assert cfg.pipeline_model_parallel_size == 2 | ||||||||||||||||||
| assert cfg.expert_model_parallel_size == 8 | ||||||||||||||||||
Reviewer note: To improve maintainability and reduce repetition, consider looping over the optional MoE configuration fields instead of asserting each one individually. This makes it easier to add more optional fields in the future.