@@ -37,6 +37,13 @@ grpo:
skip_reference_policy_logprobs_calculation: true
seq_logprob_error_threshold: null

async_grpo:
enabled: false # Set to true to enable async training mode
# Max age (in training steps) for trajectories used in training
max_trajectory_age_steps: 1
in_flight_weight_updates: false # Set to true to enable in-flight weight updates
recompute_kv_cache_after_weight_updates: false # Set to true to recompute the KV cache after in-flight weight updates

loss_fn:
reference_policy_kl_penalty: 0
reference_policy_kl_type: "k3"
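These keys are read by the example runner to select the training path (see the run_grpo_nemo_gym.py changes below). As a minimal, self-contained sketch of that gating logic, assuming PyYAML is available; the inline YAML values and print statements are illustrative only, not NeMo RL APIs:

import yaml

# Hedged sketch: how the async_grpo keys above gate async vs. synchronous GRPO,
# mirroring the check added to examples/nemo_gym/run_grpo_nemo_gym.py below.
config = yaml.safe_load("""
grpo:
  async_grpo:
    enabled: true
    max_trajectory_age_steps: 1
    in_flight_weight_updates: false
    recompute_kv_cache_after_weight_updates: false
""")

async_cfg = config["grpo"].get("async_grpo") or {}
if async_cfg.get("enabled", False):
    # Rollouts older than max_trajectory_age_steps training steps are not used.
    print("async GRPO, max trajectory age:", async_cfg["max_trajectory_age_steps"])
else:
    print("synchronous GRPO")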
48 changes: 48 additions & 0 deletions examples/nemo_gym/run_grpo_nemo_gym.py
@@ -231,7 +231,55 @@ def main() -> None:
logger=logger,
master_config=master_config,
)
# Check if async mode is enabled
elif "async_grpo" in config["grpo"] and config["grpo"]["async_grpo"]["enabled"]:
# Async GRPO does not support dynamic sampling, reward scaling, or reward shaping (DAPO features)
unsupported_features = [
"use_dynamic_sampling",
"reward_scaling",
"reward_shaping",
]

for feature in unsupported_features:
if feature not in config["grpo"]:
continue

if feature == "use_dynamic_sampling":
if config["grpo"][feature]:
raise NotImplementedError(
f"{feature} is not supported with async GRPO"
)
else:
if config["grpo"][feature]["enabled"]:
raise NotImplementedError(
f"{feature} is not supported with async GRPO"
)

from nemo_rl.algorithms.grpo import async_grpo_train

print("🚀 Running async GRPO training")

async_config = config["grpo"]["async_grpo"]
# Run async GRPO training
async_grpo_train(
policy=policy,
policy_generation=policy_generation,
dataloader=dataloader,
val_dataloader=val_dataloader,
tokenizer=tokenizer,
loss_fn=loss_fn,
task_to_env=task_to_env,
val_task_to_env=val_task_to_env,
logger=logger,
checkpointer=checkpointer,
grpo_save_state=grpo_state,
master_config=master_config,
max_trajectory_age_steps=async_config["max_trajectory_age_steps"],
)
else:
print("🚀 Running synchronous GRPO training")

# Run standard GRPO training
grpo_train(
policy,
policy_generation,
41 changes: 32 additions & 9 deletions nemo_rl/algorithms/async_utils.py
@@ -642,17 +642,40 @@ def _run_prompt_group_worker(
prompt_idx: int,
) -> None:
try:
# Import here to avoid circular dependency
from nemo_rl.algorithms.grpo import _should_use_nemo_gym
from nemo_rl.experience.rollouts import run_async_nemo_gym_rollout

# Run rollout for this prompt group
# Async engine supports concurrent generation; avoid locking
final_batch, rollout_metrics = run_async_multi_turn_rollout(
policy_generation=self.policy_generation,
input_batch=repeated_batch,
tokenizer=self.tokenizer,
task_to_env=self.task_to_env,
max_seq_len=self.master_config["policy"]["max_total_sequence_length"],
max_rollout_turns=self.master_config["grpo"]["max_rollout_turns"],
greedy=False,
)
# Check if we should use nemo_gym (similar to synchronous GRPO)
if _should_use_nemo_gym(self.master_config):
generation_config = self.master_config["policy"]["generation"]
env_cfg = self.master_config.get("env") or {}
nemo_gym_rollout_result = run_async_nemo_gym_rollout(
policy_generation=self.policy_generation,
input_batch=repeated_batch,
tokenizer=self.tokenizer,
task_to_env=self.task_to_env,
max_seq_len=None,
generation_config=generation_config,
max_rollout_turns=None,
greedy=False,
)
final_batch = nemo_gym_rollout_result.final_batch
rollout_metrics = nemo_gym_rollout_result.rollout_metrics
else:
final_batch, rollout_metrics = run_async_multi_turn_rollout(
policy_generation=self.policy_generation,
input_batch=repeated_batch,
tokenizer=self.tokenizer,
task_to_env=self.task_to_env,
max_seq_len=self.master_config["policy"][
"max_total_sequence_length"
],
max_rollout_turns=self.master_config["grpo"]["max_rollout_turns"],
greedy=False,
)

# Move to CPU and push to buffer (avoid blocking on GC/push)
final_batch_cpu = final_batch.to("cpu")
17 changes: 15 additions & 2 deletions nemo_rl/algorithms/grpo.py
@@ -2975,11 +2975,24 @@ def async_grpo_train(
checkpointer.finalize_checkpoint(checkpoint_path)
policy.offload_after_refit()

log_data = {"content": flat_messages_content}
# Logging
# Log training data (match sync GRPO logging payload for parity)
log_data = {}
if "agent_ref" in repeated_batch:
log_data["agent_ref"] = repeated_batch["agent_ref"]
log_data["content"] = flat_messages_content
log_data["rewards"] = rewards.tolist()
if master_config["grpo"]["use_dynamic_sampling"]:
# In dynamic sampling, `rewards` corresponds to filtered rewards
log_data["filtered_rewards"] = rewards.tolist()
log_data["rewards"] = repeated_batch["total_reward"].tolist()
log_data["input_lengths"] = input_lengths.tolist()
log_data["token_ids"] = train_data["input_ids"].tolist()
log_data["token_loss_mask"] = train_data["token_mask"].tolist()
log_data["sample_loss_mask"] = train_data["sample_mask"].tolist()
log_data["advantages"] = train_data["advantages"].tolist()
log_data["generation_logprobs"] = train_data["generation_logprobs"].tolist()
log_data["prev_logprobs"] = train_data["prev_logprobs"].tolist()
log_data["input_lengths"] = input_lengths.tolist()
logger.log_batched_dict_as_jsonl(
log_data, f"train_data_step{step + 1}.jsonl"
)
14 changes: 14 additions & 0 deletions nemo_rl/environments/nemo_gym.py
@@ -232,6 +232,20 @@ def _postprocess_nemo_gym_to_nemo_rl_result(
)
output_item_dict.pop("generation_log_probs")

if not nemo_rl_message_log:
input_messages = nemo_gym_result["responses_create_params"]["input"]
prompt_token_ids = tokenizer.apply_chat_template(
input_messages, tokenize=True
)
raise ValueError(
f"NeMo Gym returned a result with no generation data. "
f"This typically means the prompt for the first turn already exceeds the vLLM max_model_len, "
f"so vLLM rejected the request before any tokens could be generated.\n"
f" Prompt length: {len(prompt_token_ids)} tokens.\n"
f" → Fix: increase `policy.max_total_sequence_length` and `policy.generation.vllm_cfg.max_model_len` "
f"to a value larger than {len(prompt_token_ids)}."
)

return {
"message_log": nemo_rl_message_log,
"input_message_log": nemo_rl_message_log[:1],
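For context on the fix suggested in this error message, the first-turn prompt length can be checked offline before launching training. The sketch below is a hedged, standalone illustration; the model name, data path, and max_model_len value are assumptions borrowed from the functional test later in this diff, and the prompt is tokenized the same way as in the check above:

import json
from transformers import AutoTokenizer

# Hedged pre-flight check: flag NeMo Gym examples whose first-turn prompt would
# already exceed the configured vLLM context window (assumed values below).
MAX_MODEL_LEN = 4096  # assumed policy.generation.vllm_cfg.max_model_len
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")  # assumed model

with open("data/workplace_assistant/train.jsonl") as f:  # assumed data path
    for i, line in enumerate(f):
        example = json.loads(line)
        input_messages = example["responses_create_params"]["input"]
        prompt_token_ids = tokenizer.apply_chat_template(input_messages, tokenize=True)
        if len(prompt_token_ids) >= MAX_MODEL_LEN:
            print(
                f"example {i}: first-turn prompt is {len(prompt_token_ids)} tokens, "
                f"which exceeds max_model_len={MAX_MODEL_LEN}"
            )

If any example is flagged, either raise policy.max_total_sequence_length and policy.generation.vllm_cfg.max_model_len as the error suggests, or shrink the prompt, as the functional test below does by trimming the tool list with jq.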
58 changes: 44 additions & 14 deletions nemo_rl/models/generation/vllm/vllm_worker_async.py
@@ -369,20 +369,31 @@ async def _preprocess_chat(
messages_for_replace_prefix_tokens = deepcopy(messages)

# res is conversation, [request_prompt], [engine_prompt]
res = await super()._preprocess_chat(
request,
tokenizer,
messages,
chat_template,
chat_template_content_format,
add_generation_prompt,
continue_final_message,
tool_dicts,
documents,
chat_template_kwargs,
tool_parser,
add_special_tokens,
)
try:
res = await super()._preprocess_chat(
request,
tokenizer,
messages,
chat_template,
chat_template_content_format,
add_generation_prompt,
continue_final_message,
tool_dicts,
documents,
chat_template_kwargs,
tool_parser,
add_special_tokens,
)
except ValueError as e:
if "maximum context length" in str(e):
import logging

# Log a clean one-line warning that the max model length has been exceeded.
# The exception is still re-raised, but its traceback is later suppressed by the MaxContextLengthFilter.
logging.getLogger(__name__).warning(
"Prompt exceeds max_model_len: %s", e
)
raise

if request.required_prefix_token_ids is None:
return res
@@ -572,6 +583,24 @@ def filter(self, record: LogRecord) -> bool:

vllm_async_llm_logger.addFilter(CleanLoggingFilter())

from logging import getLogger as _getLogger

_getLogger("vllm.entrypoints.openai.protocol").addFilter(CleanLoggingFilter())

# Suppress the noisy vLLM traceback when a prompt exceeds max_model_len.
# This is expected during multi-turn rollouts; we log a clean one-line
# warning from _preprocess_chat instead.
class MaxContextLengthFilter(LoggingFilter):
def filter(self, record: LogRecord) -> bool:
if record.exc_info and record.exc_info[1]:
if "maximum context length" in str(record.exc_info[1]):
return False
return True

_getLogger("vllm.entrypoints.openai.serving_chat").addFilter(
MaxContextLengthFilter()
)

return app

def _setup_vllm_server(self) -> "tuple[threading.Thread, str, uvicorn.Server]":
@@ -602,6 +631,7 @@ def _setup_vllm_server(self) -> "tuple[threading.Thread, str, uvicorn.Server]":
app,
host="0.0.0.0",
port=free_port,
timeout_keep_alive=120, # Keep connections alive longer (default is 5s), fix for this error: Hit an exception while making a request (try 1): <class 'aiohttp.client_exceptions.ClientOSError'>: [Errno 104] Connection reset by peer
)
server = uvicorn.Server(config=config)

1 change: 1 addition & 0 deletions tests/functional/L1_Functional_Tests_GPU.sh
@@ -32,6 +32,7 @@ time uv run --no-sync bash ./tests/functional/eval.sh
time uv run --no-sync bash ./tests/functional/eval_async.sh
time uv run --no-sync bash ./tests/functional/grpo.sh
time uv run --no-sync bash ./tests/functional/grpo_async.sh
time uv run --no-sync bash ./tests/functional/grpo_async_gym.sh
time uv run --no-sync bash ./tests/functional/grpo_automodel_lora.sh
time uv run --no-sync bash ./tests/functional/grpo_automodel_lora_async.sh
time uv run --no-sync bash ./tests/functional/grpo_automodel_lora_non_colocated.sh
87 changes: 87 additions & 0 deletions tests/functional/grpo_async_gym.sh
@@ -0,0 +1,87 @@
#!/bin/bash

SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd)
PROJECT_ROOT=$(realpath $SCRIPT_DIR/../..)
# Mark the current repo as safe, since wandb fetches metadata about the repo
git config --global --add safe.directory $PROJECT_ROOT

set -eou pipefail

EXP_NAME=$(basename $0 .sh)
EXP_DIR=$SCRIPT_DIR/$EXP_NAME
LOG_DIR=$EXP_DIR/logs
JSON_METRICS=$EXP_DIR/metrics.json
RUN_LOG=$EXP_DIR/run.log
CHECKPOINT_DIR=$EXP_DIR/checkpoints
DATA_DIR=$EXP_DIR/data
export PYTHONPATH=${PROJECT_ROOT}:${PYTHONPATH:-}

rm -rf $EXP_DIR $LOG_DIR
mkdir -p $EXP_DIR $LOG_DIR $CHECKPOINT_DIR $DATA_DIR

cd $PROJECT_ROOT

# Follow nemo-gym instructions here to get this data:
# https://docs.nvidia.com/nemo/gym/0.1.0/tutorials/nemo-rl-grpo/setup.html#training-nemo-rl-grpo-setup
cd 3rdparty/Gym-workspace/Gym
config_paths="responses_api_models/vllm_model/configs/vllm_model_for_training.yaml,\
resources_servers/workplace_assistant/configs/workplace_assistant.yaml"

uv run ng_prepare_data "+config_paths=[${config_paths}]" \
+output_dirpath=data/workplace_assistant \
+mode=train_preparation \
+should_download=true \
+data_source=huggingface
cd -

# Trimming the workplace assistant dataset is necessary because, with all the tools included, the first prompt is >4000 tokens,
# which causes vLLM to return nothing for the first prompt and crashes the RL run. Since we want to keep this smoke test short,
# we trim all but the first tool.
TRAIN_PATH=$DATA_DIR/workplace_assistant_train.jsonl
VALIDATION_PATH=$DATA_DIR/workplace_assistant_validation.jsonl
jq -c '.responses_create_params.tools |= (.[0:1])' 3rdparty/Gym-workspace/Gym/data/workplace_assistant/train.jsonl > $TRAIN_PATH
jq -c '.responses_create_params.tools |= (.[0:1])' 3rdparty/Gym-workspace/Gym/data/workplace_assistant/validation.jsonl > $VALIDATION_PATH

uv run coverage run -a --data-file=$PROJECT_ROOT/tests/.coverage --source=$PROJECT_ROOT/nemo_rl \
$PROJECT_ROOT/examples/nemo_gym/run_grpo_nemo_gym.py \
--config $PROJECT_ROOT/examples/nemo_gym/grpo_qwen3_30ba3b_instruct.yaml \
policy.model_name=Qwen/Qwen3-0.6B \
policy.dtensor_cfg.enabled=false \
policy.megatron_cfg.enabled=true \
policy.megatron_cfg.tensor_model_parallel_size=1 \
policy.megatron_cfg.pipeline_model_parallel_size=1 \
policy.megatron_cfg.expert_model_parallel_size=1 \
policy.megatron_cfg.context_parallel_size=1 \
policy.megatron_cfg.sequence_parallel=false \
policy.generation.vllm_cfg.tensor_parallel_size=1 \
policy.generation.vllm_cfg.async_engine=true \
policy.max_total_sequence_length=512 \
policy.generation.colocated.enabled=false \
policy.generation.colocated.resources.num_nodes=1 \
policy.generation.colocated.resources.gpus_per_node=1 \
grpo.num_prompts_per_step=4 \
grpo.num_generations_per_prompt=2 \
grpo.max_num_steps=10 \
grpo.async_grpo.enabled=true \
grpo.async_grpo.max_trajectory_age_steps=1 \
grpo.async_grpo.in_flight_weight_updates=true \
policy.train_global_batch_size=4 \
policy.train_micro_batch_size=1 \
cluster.gpus_per_node=2 \
loss_fn.use_importance_sampling_correction=true \
logger.tensorboard_enabled=true \
logger.log_dir=$LOG_DIR \
logger.wandb_enabled=false \
logger.monitor_gpus=true \
checkpointing.enabled=false \
checkpointing.checkpoint_dir=$CHECKPOINT_DIR \
data.train.data_path=$TRAIN_PATH \
data.validation.data_path=$VALIDATION_PATH \
$@ \
2>&1 | tee $RUN_LOG

uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS

# Observed to be between 0.8 and 1.3
uv run tests/check_metrics.py $JSON_METRICS \
'median(data["train/gen_kl_error"]) < 1.3'