
Add GLM-4.7-Flash (30B MoE) training example #1215

Draft

tyler-griggs wants to merge 2 commits into main from tgriggs/glm47-example-config

Conversation

tyler-griggs (Member) commented Feb 25, 2026

Summary

Adds a complete GRPO training example for GLM-4.7-Flash (zai-org/GLM-4.7-Flash), a DeepSeek-V3 architecture clone with MLA + MoE (64 routed experts, 4 active per token, ~3B active parameters).

Files

  • skyrl/train/config/examples/glm4_7_30b_moe_grpo.yaml — Hydra config
  • skyrl-train/examples/megatron/run_megatron_grpo_glm4_7_30b.sh — Launch script

Configuration

  • Parallelism: TP=2, EP=8 on 2 nodes x 8 H100 GPUs (16 GPUs total)
  • MoE flags: sigmoid routing, expert bias, grouped GEMM
  • Attention: flash_attn: false (MLA not compatible with TE flash attention)
  • Algorithm: GRPO with KL loss
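The bullets above map onto the Megatron settings roughly as follows. This is a hedged sketch using the override key names that appear in the launch script; the actual nesting inside glm4_7_30b_moe_grpo.yaml may differ:

```yaml
# Sketch only: key names taken from the launch-script overrides,
# nesting is an assumption, not copied from the real config file.
trainer:
  flash_attn: false                     # MLA is not compatible with TE flash attention
  algorithm:
    advantage_estimator: grpo
    use_kl_loss: true
  policy:
    megatron_config:
      tensor_model_parallel_size: 2     # TP=2
      expert_model_parallel_size: 8     # EP=8 across 2 nodes x 8 H100s
      moe_router_score_function: sigmoid
      moe_router_enable_expert_bias: true
      moe_grouped_gemm: true
```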

Dependencies

Requires the following PRs to be merged first:



Add GRPO training example for zai-org/GLM-4.7-Flash, a DeepSeek-V3
architecture clone with MLA + MoE (64 routed experts, 4 active per
token, ~3B active parameters).

Default config targets 1x8 H100 GPUs (TP=2, EP=4, ~7.5GB model weights
per GPU in bf16). Comments note how to scale to 16 GPUs (EP=8).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
tyler-griggs force-pushed the tgriggs/glm47-example-config branch from 4c971eb to f43e651 on February 25, 2026 at 19:19
tyler-griggs marked this pull request as ready for review February 26, 2026 00:04
gemini-code-assist bot (Contributor) left a comment


Code Review

This pull request adds a new training example for GLM-4.7-Flash. The changes look good overall, but there are a couple of issues to address. First, the launch script duplicates a lot of configuration from the YAML file, which is a maintainability concern. I've suggested refactoring the script to load the YAML file directly. Second, the generator configuration in the YAML file is inconsistent with the documented 8-GPU setup, as it's configured for 16 GPUs. I've suggested correcting these values.

Comment on lines +43 to +107
uv run --isolated --extra mcore -m skyrl_train.entrypoints.main_base \
data.train_data="['$DATA_DIR/train.parquet']" \
data.val_data="['$DATA_DIR/validation.parquet']" \
trainer.algorithm.advantage_estimator="grpo" \
trainer.policy.model.path=$MODEL_NAME \
trainer.placement.colocate_all=true \
trainer.strategy=megatron \
trainer.placement.policy_num_nodes=$NUM_NODES \
trainer.placement.ref_num_nodes=$NUM_NODES \
trainer.placement.policy_num_gpus_per_node=$NUM_GPUS \
trainer.placement.ref_num_gpus_per_node=$NUM_GPUS \
generator.num_inference_engines=$NUM_INFERENCE_ENGINES \
generator.inference_engine_tensor_parallel_size=$INFERENCE_ENGINE_TP \
trainer.policy.megatron_config.tensor_model_parallel_size=$MEGATRON_TP \
trainer.policy.megatron_config.pipeline_model_parallel_size=$MEGATRON_PP \
trainer.policy.megatron_config.context_parallel_size=$MEGATRON_CP \
trainer.ref.megatron_config.tensor_model_parallel_size=$MEGATRON_TP \
trainer.ref.megatron_config.context_parallel_size=$MEGATRON_CP \
trainer.ref.megatron_config.pipeline_model_parallel_size=$MEGATRON_PP \
trainer.policy.megatron_config.expert_model_parallel_size=$MEGATRON_EP \
trainer.policy.megatron_config.expert_tensor_parallel_size=$MEGATRON_ETP \
trainer.ref.megatron_config.expert_model_parallel_size=$MEGATRON_EP \
trainer.ref.megatron_config.expert_tensor_parallel_size=$MEGATRON_ETP \
trainer.policy.megatron_config.moe_token_dispatcher_type=$MOE_TOKEN_DISPATCHER \
trainer.policy.megatron_config.moe_router_load_balancing_type=$MOE_ROUTER_LB \
trainer.policy.megatron_config.moe_grouped_gemm=$MOE_GROUPED_GEMM \
trainer.policy.megatron_config.moe_router_score_function=$MOE_ROUTER_SCORE_FN \
trainer.policy.megatron_config.moe_router_enable_expert_bias=$MOE_ROUTER_EXPERT_BIAS \
trainer.ref.megatron_config.moe_token_dispatcher_type=$MOE_TOKEN_DISPATCHER \
trainer.ref.megatron_config.moe_router_load_balancing_type=$MOE_ROUTER_LB \
trainer.ref.megatron_config.moe_grouped_gemm=$MOE_GROUPED_GEMM \
trainer.ref.megatron_config.moe_router_score_function=$MOE_ROUTER_SCORE_FN \
trainer.ref.megatron_config.moe_router_enable_expert_bias=$MOE_ROUTER_EXPERT_BIAS \
trainer.use_sample_packing=true \
trainer.flash_attn=$FLASH_ATTN \
trainer.epochs=20 \
trainer.eval_batch_size=1024 \
trainer.eval_before_train=false \
trainer.eval_interval=5 \
trainer.update_epochs_per_batch=1 \
trainer.train_batch_size=128 \
trainer.policy_mini_batch_size=64 \
trainer.micro_forward_batch_size_per_gpu=4 \
trainer.micro_train_batch_size_per_gpu=4 \
trainer.ckpt_interval=10 \
trainer.max_prompt_length=512 \
generator.sampling_params.max_generate_length=1024 \
trainer.policy.optimizer_config.lr=1.0e-6 \
trainer.policy.optimizer_config.weight_decay=0.1 \
trainer.policy.optimizer_config.max_grad_norm=1.0 \
trainer.algorithm.use_kl_loss=true \
generator.backend=$INFERENCE_BACKEND \
generator.run_engines_locally=true \
generator.weight_sync_backend=nccl \
generator.async_engine=true \
generator.batched=true \
environment.env_class=gsm8k \
generator.n_samples_per_prompt=5 \
generator.gpu_memory_utilization=0.6 \
trainer.logger="$LOGGER" \
trainer.project_name="glm4_7_30b_grpo" \
trainer.run_name="glm4_7_30b_a3b_grpo_megatron_tp${MEGATRON_TP}_pp${MEGATRON_PP}_cp${MEGATRON_CP}_ep${MEGATRON_EP}_etp${MEGATRON_ETP}" \
trainer.resume_mode=null \
trainer.ckpt_path="$HOME/ckpts/glm4_7_30b_a3b_grpo_megatron" \
$@
Severity: high

The uv run command duplicates a large amount of configuration that is also present in the glm4_7_30b_moe_grpo.yaml file. This makes maintenance difficult, as changes need to be synchronized in two places. The script should be simplified to load the YAML configuration file using Hydra's command-line arguments and only override parameters that are specific to the execution environment.

This would make the script much shorter and more maintainable. For example, you could use --config-name to specify the example config. You might need to also use --config-dir or --config-path if the example config is not in Hydra's default search path.

A refactored command could look like this:

# Assuming the config path is correctly set up for Hydra to find the example config.
CONFIG_NAME="examples/glm4_7_30b_moe_grpo"

uv run --isolated --extra mcore -m skyrl_train.entrypoints.main_base \
  --config-name $CONFIG_NAME \
  data.train_data="['$DATA_DIR/train.parquet']" \
  data.val_data="['$DATA_DIR/validation.parquet']" \
  trainer.placement.policy_num_nodes=$NUM_NODES \
  trainer.placement.ref_num_nodes=$NUM_NODES \
  trainer.placement.policy_num_gpus_per_node=$NUM_GPUS \
  trainer.placement.ref_num_gpus_per_node=$NUM_GPUS \
  trainer.policy.megatron_config.tensor_model_parallel_size=$MEGATRON_TP \
  trainer.policy.megatron_config.expert_model_parallel_size=$MEGATRON_EP \
  generator.num_inference_engines=$NUM_INFERENCE_ENGINES \
  generator.inference_engine_tensor_parallel_size=$INFERENCE_ENGINE_TP \
  trainer.logger="$LOGGER" \
  trainer.run_name="glm4_7_30b_a3b_grpo_megatron_tp${MEGATRON_TP}_pp${MEGATRON_PP}_cp${MEGATRON_CP}_ep${MEGATRON_EP}_etp${MEGATRON_ETP}" \
  trainer.ckpt_path="$HOME/ckpts/glm4_7_30b_a3b_grpo_megatron" \
  $@

Most of the parameters from line 46 to 106 could be removed from the script as they would be loaded from the YAML file. I've kept a few as an example of what might be useful to keep as overrides.

Comment on lines 114 to 115
num_inference_engines: 2
inference_engine_tensor_parallel_size: 8
Severity: high

The generator configuration seems to be incorrect for the intended hardware setup. The comments state this configuration is for an 8-GPU setup, but the generator is configured with num_inference_engines: 2 and inference_engine_tensor_parallel_size: 8, which would require 16 GPUs (2 * 8). This contradicts the comment and the setup in the corresponding launch script run_megatron_grpo_glm4_7_30b.sh, which uses more reasonable values for an 8-GPU node (NUM_INFERENCE_ENGINES=1, INFERENCE_ENGINE_TP=4). Please update these values to be consistent with an 8-GPU setup.

  num_inference_engines: 1
  inference_engine_tensor_parallel_size: 4

devin-ai-integration bot (Contributor) left a comment


Devin Review found 2 potential issues.

View 3 additional findings in Devin Review.


Comment on lines 114 to 115
num_inference_engines: 2
inference_engine_tensor_parallel_size: 8

🔴 YAML inference engine GPU count mismatches policy GPU count, causing assertion failure

The YAML config sets colocate_all: true with policy_num_nodes: 1 and policy_num_gpus_per_node: 8 (8 policy GPUs), but num_inference_engines: 2 and inference_engine_tensor_parallel_size: 8 (16 rollout GPUs). This will trigger the assertion at skyrl/train/utils/utils.py:363 which requires num_policy_gpus == num_rollout_gpus when colocate_all is true.

Root Cause

The config appears to have been copied from a 2-node setup (like the Qwen3-30B or Moonlight examples which use 2 nodes × 8 GPUs = 16 GPUs with num_inference_engines=2, inference_engine_tensor_parallel_size=8) without adjusting the inference engine settings for the 1-node/8-GPU target.

The validation at skyrl/train/utils/utils.py:354-366 computes:

  • num_policy_gpus = 1 * 8 = 8
  • num_rollout_gpus = 2 * 8 * 1 * 1 = 16

Since 8 != 16, the assertion fails and training cannot start.

Impact: The YAML config is unusable as-is — it will crash immediately at config validation.
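The failing arithmetic can be reproduced in a few lines of shell. This is an illustrative sketch of the check described above, assuming the simple products shown; it is not the actual code from skyrl/train/utils/utils.py:

```shell
#!/bin/sh
# Values from the YAML config under review: 1 node x 8 GPUs for the policy,
# but 2 inference engines x TP=8 for rollout.
NUM_NODES=1
NUM_GPUS=8
NUM_INFERENCE_ENGINES=2
INFERENCE_ENGINE_TP=8

num_policy_gpus=$((NUM_NODES * NUM_GPUS))                          # 1 * 8 = 8
num_rollout_gpus=$((NUM_INFERENCE_ENGINES * INFERENCE_ENGINE_TP)) # 2 * 8 = 16

# colocate_all requires the two GPU pools to match exactly.
if [ "$num_policy_gpus" -ne "$num_rollout_gpus" ]; then
  echo "colocate_all mismatch: policy=$num_policy_gpus rollout=$num_rollout_gpus"
else
  echo "colocate_all ok: $num_policy_gpus GPUs"
fi
```

With the values above the script prints `colocate_all mismatch: policy=8 rollout=16`, matching the 8 != 16 failure described in the comment.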

Suggested change
- num_inference_engines: 2
- inference_engine_tensor_parallel_size: 8
+ num_inference_engines: 1
+ inference_engine_tensor_parallel_size: 8

Comment on lines 30 to 31
NUM_INFERENCE_ENGINES=1
INFERENCE_ENGINE_TP=4

🔴 Shell script inference engine GPU count mismatches policy GPU count, causing assertion failure

The shell script sets colocate_all=true with NUM_NODES=1 and NUM_GPUS=8 (8 policy GPUs), but NUM_INFERENCE_ENGINES=1 and INFERENCE_ENGINE_TP=4 (4 rollout GPUs). This will trigger the assertion at skyrl/train/utils/utils.py:363 which requires num_policy_gpus == num_rollout_gpus when colocate_all is true.

Root Cause

The validation at skyrl/train/utils/utils.py:354-366 computes:

  • num_policy_gpus = 1 * 8 = 8
  • num_rollout_gpus = 1 * 4 * 1 * 1 = 4

Since 8 != 4, the assertion fails. Comparing with working examples like run_megatron_moonlight.sh and run_megatron_qwen3-30b-a3b.sh which both use NUM_INFERENCE_ENGINES=2, INFERENCE_ENGINE_TP=8 for 16 GPUs (matching their NUM_NODES=2 * NUM_GPUS=8 = 16), the correct values for 1 node × 8 GPUs should be either NUM_INFERENCE_ENGINES=1, INFERENCE_ENGINE_TP=8 or NUM_INFERENCE_ENGINES=2, INFERENCE_ENGINE_TP=4.

Impact: The shell script will crash immediately at config validation and cannot be used to launch training.

Suggested change
- NUM_INFERENCE_ENGINES=1
- INFERENCE_ENGINE_TP=4
+ NUM_INFERENCE_ENGINES=1
+ INFERENCE_ENGINE_TP=8

tyler-griggs marked this pull request as draft February 26, 2026 00:46
The colocate_all assertion requires num_policy_gpus == num_rollout_gpus.
- YAML: 2 inference engines * TP=8 = 16 rollout GPUs, but policy GPUs = 8. Fixed to 1 engine * TP=8 = 8.
- Shell: INFERENCE_ENGINE_TP was 4, giving 1 * 4 = 4 rollout GPUs vs. 8 policy GPUs. Fixed TP to 8.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>