Conversation

@YanhuiDua (Collaborator) commented Dec 15, 2025

This PR introduces asynchronous RL support to Xtuner, enabling partial rollouts and version-based sample management for more efficient training data generation.

1. Key Concepts:

  • staleness_threshold: The maximum allowed fraction of stale (expired) samples in a training batch.
  • enable_partial_rollout: Whether to enable partial rollout for asynchronous data generation.
  • tail_batch_candidate_steps: Number of rollout steps after which a sample becomes a candidate for the tail batch. Set to 0 to disable the tail batch entirely.
  • tail_batch_trigger_size: Number of candidate samples needed in the queue to trigger a tail batch operation. Falls back to global_batch_size when not provided or set to 0. (A config sketch of these four parameters follows.)
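
A minimal sketch of how these four parameters fit together, assuming a dataclass-style config (the DataflowConfig name and the default values are illustrative, not the PR's actual code):

```python
from dataclasses import dataclass


@dataclass
class DataflowConfig:  # hypothetical container; field names follow this PR
    global_batch_size: int = 512
    staleness_threshold: float = 0.2      # allow up to 20% stale samples per batch
    enable_partial_rollout: int = 0       # 1: retain & concatenate paused responses
    tail_batch_candidate_steps: int = 0   # 0 disables the tail batch
    tail_batch_trigger_size: int = 0      # 0 / unset falls back to global_batch_size

    def __post_init__(self) -> None:
        # Mirrors the documented default: unset or 0 means global_batch_size.
        if not self.tail_batch_trigger_size:
            self.tail_batch_trigger_size = self.global_batch_size
```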

2. Async logic:

| Strategy Type | Settings | Core Features |
| --- | --- | --- |
| Synchronous Strategy | staleness_threshold=0.0<br>enable_partial_rollout=0<br>tail_batch_candidate_steps=0 | 1. No data oversending |
| Asynchronous 1 | staleness_threshold=0.2<br>enable_partial_rollout=0<br>tail_batch_candidate_steps=0 | 1. 20% data oversending<br>2. Responses not retained when rollout is paused<br>3. Prioritize sampling data from the abort queue |
| Asynchronous 2 | staleness_threshold=0.2<br>enable_partial_rollout=0<br>tail_batch_candidate_steps=1<br>tail_batch_trigger_size=0 | 1. 20% data oversending<br>2. Responses not retained when rollout is paused<br>3. Prioritize sampling data from the abort queue<br>4. A sample enters the candidate pool once its abort count reaches tail_batch_candidate_steps+1 |
| Asynchronous 3 | staleness_threshold=0.2<br>enable_partial_rollout=1<br>tail_batch_candidate_steps=0<br>tail_batch_trigger_size=0 | 1. 20% data oversending<br>2. Responses retained & concatenated when rollout is paused<br>3. Prioritize sampling data from the abort queue |
| Asynchronous 4 | staleness_threshold=0.2<br>enable_partial_rollout=1<br>tail_batch_candidate_steps=1<br>tail_batch_trigger_size=0 | 1. 20% data oversending<br>2. Responses retained & concatenated when rollout is paused<br>3. Prioritize sampling data from the abort queue<br>4. A sample enters the candidate pool once its abort count reaches tail_batch_candidate_steps+1 (tail_batch_candidate_steps is the off-policy step count) |
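
As a worked example of the oversending column (numbers assumed for illustration, not taken from the PR): with global_batch_size = 512 and staleness_threshold = 0.2, the controller may dispatch roughly 20% more prompts than one batch needs:

```python
global_batch_size = 512
staleness_threshold = 0.2

# Up to 20% extra prompts in flight beyond one training batch.
max_in_flight = int(global_batch_size * (1 + staleness_threshold))   # 614
max_stale_next_batch = max_in_flight - global_batch_size             # 102
```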

3. Benchmark

4. Related PRs

  • Added async-related configuration parameters, including partial_rollout, tail_batch_candidate_steps, tail_batch_trigger_size, and staleness_threshold
  • Refactored replay buffer storage to support versioned samples with bucketed tracking of completed, aborted, and expired states (see the sketch after this list)
  • Renamed Sampler to DatasetSampler and separated dataset sampling logic from replay buffer sampling
  • Applied sample_from_expired_storage in dataflow: when set to True, the dataflow does not oversend data and returns data only after all tasks of the current batch are completed
  • Added task timing log info
  • Added partial rollout functionality with versioned response tracking to accumulate tokens across multiple generation steps
  • Implemented an automatic worker-restart mechanism for when all rollout workers become inactive
  • Fixed state handling for aborted rollouts and improved error logging
  • Added TensorBoard logging for training and rollout metrics
  • Refactored the training loop in fit() to conditionally execute rollout, training, and weight synchronization based on debug mode
  • Fixed async runtime bugs
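
A minimal sketch of the bucketed, versioned storage described in the second bullet, assuming per-version buckets keyed by rollout state (class and method names here are illustrative, not the PR's actual API):

```python
from collections import defaultdict
from enum import Enum


class RolloutState(Enum):  # a subset of the states named in this PR
    COMPLETED = "completed"
    ABORTED = "aborted"
    EXPIRED = "expired"


class VersionedBuckets:
    """Illustrative storage: action_ids bucketed by (model version, state)."""

    def __init__(self) -> None:
        self._buckets: dict[int, dict[RolloutState, list[str]]] = defaultdict(
            lambda: defaultdict(list)
        )

    def add(self, version: int, state: RolloutState, action_id: str) -> None:
        self._buckets[version][state].append(action_id)

    def expire_older_than(self, min_version: int) -> list[str]:
        """Move COMPLETED/ABORTED ids generated before min_version to EXPIRED."""
        moved: list[str] = []
        for version in [v for v in self._buckets if v < min_version]:
            for state in (RolloutState.COMPLETED, RolloutState.ABORTED):
                ids = self._buckets[version].pop(state, [])
                self._buckets[version][RolloutState.EXPIRED].extend(ids)
                moved.extend(ids)
        return moved
```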

@YanhuiDua force-pushed the support_async_rl_4 branch 2 times, most recently from 5e3f135 to aaa4860 on December 19, 2025 04:20
waiting_tasks = set()
dataflow_start_time = time.perf_counter()
task_completion_times = []
with tqdm(total=self.target_batch_size, desc="rollout_controller for training samples") as pbar:

Collaborator: Use tqdm(miniters=10) (the minimum progress-display update interval, in iterations) and call pbar.update(finished_samples) in the loop instead of manually calling pbar.refresh(). Minimize pbar operations inside the loop.
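
A sketch of this suggestion (miniters is a standard tqdm argument; the loop body below is a stand-in, not the PR's code):

```python
import time

from tqdm import tqdm

target_batch_size = 512

# miniters=10 lets tqdm redraw at most once per 10 completed units, keeping
# the loop body cheap; pbar.update() then replaces any manual pbar.refresh().
with tqdm(total=target_batch_size, miniters=10,
          desc="rollout_controller for training samples") as pbar:
    done = 0
    while done < target_batch_size:
        finished_samples = 32   # stand-in for newly completed rollout tasks
        time.sleep(0.01)        # stand-in for awaiting rollout results
        done += finished_samples
        pbar.update(finished_samples)
```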

data_batches, pack_max_length=self._train_worker_cfg.pack_max_length, rollout_idx=rollout_idx
)
)


Collaborator: Nice hierarchical code!

collator="fake_collator",
pack_level="none",
expired_threshold = (
min(remain_size, self.config.tail_batch_trigger_size)

@jayhenry (Collaborator) commented Dec 23, 2025: use cast(int, xxx) instead
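
Presumably this refers to typing.cast, which asserts to the type checker, with no runtime effect, that the Optional field is an int at this point. A sketch under that assumption (the helper name is hypothetical):

```python
from typing import Optional, cast


def compute_expired_threshold(remain_size: int, trigger_size: Optional[int]) -> int:
    # cast() is a runtime no-op; it only tells the type checker that
    # trigger_size has already been resolved to a concrete int here.
    return min(remain_size, cast(int, trigger_size))
```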

self.finished_samples_count = await self.replay_buffer.get_completed_samples_count.remote()
waiting_tasks = pending_tasks

while len(waiting_tasks) + self.finished_samples_count < max(data_concurrency, self.target_batch_size):

Collaborator: len(waiting_tasks) + self.finished_samples_count < data_concurrency + init_finished_samples_count
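
As I read it, the reviewer wants the finished count snapshotted before the loop, so samples already finished at loop entry do not shrink the in-flight budget. A sketch with stand-in values (the snapshot placement and the submit step are assumptions):

```python
data_concurrency = 64
finished_samples_count = 10            # e.g. read from the replay buffer
init_finished_samples_count = finished_samples_count  # snapshot before the loop
waiting_tasks: set[int] = set()

# The budget only counts progress made after loop entry: in-flight tasks plus
# newly finished samples stay below data_concurrency.
while len(waiting_tasks) + finished_samples_count < (
    data_concurrency + init_finished_samples_count
):
    waiting_tasks.add(len(waiting_tasks))  # stand-in for submitting a task
```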

extra_info: Dict[str, Any] = Field(default_factory=dict)
state: RolloutState = RolloutState.INIT

def _update_by_append(self, other: "RLRolloutResponseItem") -> None:

Collaborator:
Suggested change
def _update_by_append(self, other: "RLRolloutResponseItem") -> None:
def _update_by_append(self, other: Self) -> None:

self.state = other.state
return

def update(self, other: "RLRolloutResponseItem") -> None:

Collaborator: Make the same change as above.
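
For reference, a sketch of the suggested Self annotation (available from typing on Python 3.11+, or from typing_extensions earlier); the class body here is a simplified stand-in:

```python
from typing import List, Optional

from typing_extensions import Self  # or: from typing import Self  (3.11+)


class RLRolloutResponseItem:  # simplified stand-in for the PR's model
    response_ids: Optional[List[int]] = None

    def _update_by_append(self, other: Self) -> None:
        # Self keeps the annotation correct in subclasses without
        # repeating the class name as a string literal.
        if other.response_ids:
            if self.response_ids is None:
                self.response_ids = []
            self.response_ids.extend(other.response_ids)
```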


if other_ids_copy is not None:
assert self.response_ids is not None, "response_ids must not be None when updating partial data."
self.response_ids.extend(other_ids_copy.copy())

Collaborator: Why copy twice?
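
If other_ids_copy is already a defensive copy, the second .copy() buys nothing: list.extend() copies element references into the target list regardless. A standalone illustration:

```python
src = [1, 2, 3]
src_copy = src.copy()   # first (defensive) copy
dst: list[int] = []
dst.extend(src_copy)    # extend() does not alias src_copy; no second .copy() needed
```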

tail_batch_trigger_size: Annotated[
Optional[int],
Parameter(
help="Number of candidate samples needed in the queue to trigger a tail batch operation. Set to 0 to disable."

Collaborator: The description "Set to 0 to disable." is incorrect here.

Collaborator: There is no real enable/disable semantics for this parameter; it only takes effect together with tail_batch_candidate_steps.
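
Given the defaulting rule in the PR description, the help text might instead read along these lines (suggested wording only; Annotated, Optional, and Parameter as in the snippet above):

```python
tail_batch_trigger_size: Annotated[
    Optional[int],
    Parameter(
        help=(
            "Number of candidate samples needed in the queue to trigger a "
            "tail batch operation. Only takes effect together with "
            "tail_batch_candidate_steps > 0; falls back to global_batch_size "
            "when unset or 0."
        )
    ),
] = None
```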

response_ids: Optional[List[int]] = None
logprobs: Optional[List[float]] = None
num_return_tokens: Optional[int] = None
versioned_response: List[str] = Field(default_factory=list)

Collaborator: Think about whether this field will need to change for future multi-turn scenarios.

Parameter(help="Weights for different states in the replay buffer."),
] = {}
# async rollout related configs, assigned from dataflow cfg
enable_partial_rollout: Annotated[

Collaborator: Since these parameters are auto-assigned rather than user-settable, is there another implementation that prevents users from mistakenly thinking they can set them?
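
One pattern that would express this, assuming the config is a pydantic model: keep auto-assigned fields out of the public, CLI-visible schema as private attributes set by the dataflow (illustrative only; the class and method names are hypothetical):

```python
from pydantic import BaseModel, PrivateAttr


class ReplayBufferConfig(BaseModel):  # simplified stand-in
    # Public, user-facing option:
    replay_weights: dict = {}

    # Auto-assigned from the dataflow cfg; PrivateAttr keeps it out of the
    # model's public fields, so users cannot pass it at construction time.
    _enable_partial_rollout: bool = PrivateAttr(default=False)

    def bind_dataflow(self, enable_partial_rollout: bool) -> None:
        self._enable_partial_rollout = enable_partial_rollout
```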

else:
self.dataloader_cfg = DataloaderConfig(
collator="fake_collator",
pack_level="none",

Collaborator: num_worker could default to 1 or 2 here, considering multimodal scenarios.

self._completed_actions[replay_meta.version].append(action_id)
self.logger.debug(f"Add sample with root_id: {root_id}, action_id: {action_id} to finished_actions.")
else:
assert False, f"Unsupported rollout state {state} for action_id {action_id} in ReplayBufferStorage."

Collaborator: Use raise AssertionError(xxxx) instead.
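
Applied to the branch above, it would look like this; since assert statements are stripped under python -O, an explicit raise is sturdier:

```python
else:
    raise AssertionError(
        f"Unsupported rollout state {state} for action_id {action_id} "
        "in ReplayBufferStorage."
    )
```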


for sample in group_samples:
assert sample.data.input_ids and sample.data.num_tokens, "input_ids or num_tokens is empty!"
if "routed_experts" in sample.env.rollout.extra_info:

Collaborator: This can't simply be deleted outright; you need to consider memory-leak scenarios.

data = base64.b64decode(routed_experts)
routed_experts = ray.cloudpickle.loads(data)
else:
routed_experts = torch.tensor(routed_experts) # n,layer,expert

Collaborator: sglang takes this branch. If execution reaches here, doing routed_experts = ray.put(routed_experts) first and then await routed_experts would be very odd. Suggest handling this properly.

cur_routed_experts = cur_routed_experts[exist_routed_experts.shape[0] :, :, :]
concat_routed_experts = np.concatenate((exist_routed_experts, cur_routed_experts), axis=0)
prompt_tokens = response["meta_info"].get("prompt_tokens", 0)
response_tokens = response["meta_info"].get("completion_tokens", 0)

Collaborator: You could add an assert here that the sequence length of concat_routed_experts equals prompt_tokens + response_tokens - 1.
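
A sketch of that check using the names from the snippet above (the reviewer's formula, prompt_tokens + response_tokens - 1, taken as given):

```python
# Sanity check on the concatenated expert-routing trace.
expected_len = prompt_tokens + response_tokens - 1
assert concat_routed_experts.shape[0] == expected_len, (
    f"routed_experts length {concat_routed_experts.shape[0]} != {expected_len} "
    "(prompt_tokens + response_tokens - 1)"
)
```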

if not self.enable_partial_rollout:
# Clear the previous rollout's response_ids and other env data
if "routed_experts" in sample.env.rollout.extra_info:
del sample.env.rollout.extra_info["routed_experts"]

Collaborator: Same as above (consider the memory-leak implications).

Collaborator: Have you considered, after an abort, sending the next request for the same sample to the same server so its KV cache can be reused?
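
A minimal sketch of such session affinity, assuming the router keeps a map from a sample's root_id to the server that last served it (all names hypothetical):

```python
class StickyRouter:
    """Route resumed (partial-rollout) requests back to their last server."""

    def __init__(self, servers: list[str]) -> None:
        self._servers = servers
        self._affinity: dict[str, str] = {}  # root_id -> server url

    def pick(self, root_id: str) -> str:
        # Reuse the previous server so its prefix KV cache can be hit again;
        # fall back to hashing for first-time samples.
        if root_id in self._affinity:
            return self._affinity[root_id]
        server = self._servers[hash(root_id) % len(self._servers)]
        self._affinity[root_id] = server
        return server
```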
