[Feature] support async rl #1360
Conversation
```python
waiting_tasks = set()
dataflow_start_time = time.perf_counter()
task_completion_times = []
with tqdm(total=self.target_batch_size, desc="rollout_controller for training samples") as pbar:
```
Use `tqdm(miniters=10)` (the minimum number of iterations between progress-display updates) and call `pbar.update(finished_samples)` inside the loop instead of manually calling `pbar.refresh()`. Minimize pbar operations inside the loop.
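A minimal sketch of the suggested pattern; `target_batch_size` and the `poll_finished()` stub stand in for the controller's real state:

```python
from tqdm import tqdm

target_batch_size = 512  # stand-in for self.target_batch_size

def poll_finished() -> int:
    """Hypothetical stub: returns the number of newly completed samples."""
    return 1

# miniters=10: the bar redraws at most once per 10 accumulated updates,
# so no manual pbar.refresh() is needed inside the loop.
with tqdm(total=target_batch_size, desc="rollout_controller for training samples",
          miniters=10) as pbar:
    while pbar.n < target_batch_size:
        pbar.update(poll_finished())
```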
```python
data_batches, pack_max_length=self._train_worker_cfg.pack_max_length, rollout_idx=rollout_idx
)
)
```
Nice hierarchical code!
```python
collator="fake_collator",
pack_level="none",
expired_threshold = (
    min(remain_size, self.config.tail_batch_trigger_size)
```
Use `cast(int, ...)` instead.
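A sketch of what that looks like, assuming `tail_batch_trigger_size` is typed `Optional[int]` in the config but is known to be set at this point:

```python
from typing import Optional, cast

remain_size = 8
tail_batch_trigger_size: Optional[int] = 4  # Optional in the config schema

# cast() is a no-op at runtime; it only narrows the type for the checker,
# avoiding an `int | None` complaint inside min().
expired_threshold = min(remain_size, cast(int, tail_batch_trigger_size))
```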
```python
self.finished_samples_count = await self.replay_buffer.get_completed_samples_count.remote()
waiting_tasks = pending_tasks

while len(waiting_tasks) + self.finished_samples_count < max(data_concurrency, self.target_batch_size):
```
The condition should be `len(waiting_tasks) + self.finished_samples_count < data_concurrency + init_finished_samples_count`.
xtuner/v1/data_proto/rl_data.py (Outdated)
```python
extra_info: Dict[str, Any] = Field(default_factory=dict)
state: RolloutState = RolloutState.INIT

def _update_by_append(self, other: "RLRolloutResponseItem") -> None:
```
Suggested change:

```diff
- def _update_by_append(self, other: "RLRolloutResponseItem") -> None:
+ def _update_by_append(self, other: Self) -> None:
```
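Note (an addition, not part of the review): `Self` requires `from typing import Self` on Python ≥ 3.11, or `from typing_extensions import Self` on earlier versions.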
xtuner/v1/data_proto/rl_data.py (Outdated)
```python
self.state = other.state
return

def update(self, other: "RLRolloutResponseItem") -> None:
```
Apply the same change as above (annotate `other` as `Self`).
```python
if other_ids_copy is not None:
    assert self.response_ids is not None, "response_ids must not be None when updating partial data."
    self.response_ids.extend(other_ids_copy.copy())
```
Why copy twice? Judging by its name, `other_ids_copy` is already a copy, so the extra `.copy()` passed to `extend()` is redundant.
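A sketch of the deduplicated version against the hunk above (assuming `other_ids_copy` is indeed already a copy):

```python
if other_ids_copy is not None:
    assert self.response_ids is not None, "response_ids must not be None when updating partial data."
    # other_ids_copy was already copied once; extend() can consume it directly
    self.response_ids.extend(other_ids_copy)
```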
```python
tail_batch_trigger_size: Annotated[
    Optional[int],
    Parameter(
        help="Number of candidate samples needed in the queue to trigger a tail batch operation. Set to 0 to disable."
```
The sentence "Set to 0 to disable." in this help text is incorrect.
There is no real "enable/disable" semantics here; the option only takes effect in combination with `tail_batch_candidate_steps`.
xtuner/v1/data_proto/rl_data.py (Outdated)
```python
response_ids: Optional[List[int]] = None
logprobs: Optional[List[float]] = None
num_return_tokens: Optional[int] = None
versioned_response: List[str] = Field(default_factory=list)
```
Think about whether this part needs to change for future multi-turn scenarios.
```python
    Parameter(help="Weights for different states in the replay buffer."),
] = {}
# async rollout related configs, assigned from dataflow cfg
enable_partial_rollout: Annotated[
```
Since these parameters are auto-assigned rather than user-settable, is there another implementation that hard-prevents users from mistakenly thinking they are configurable?
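One possible direction, sketched under the assumption that the config is a pydantic `BaseModel` (class and field names below are illustrative, not the PR's actual API):

```python
from pydantic import BaseModel, PrivateAttr


class ReplayBufferConfig(BaseModel):  # hypothetical name
    """Public, user-settable fields go here as usual."""

    # Assigned internally from the dataflow cfg; a private attribute never
    # appears in the model schema, so users cannot pass it at construction.
    _enable_partial_rollout: bool = PrivateAttr(default=False)


cfg = ReplayBufferConfig()
cfg._enable_partial_rollout = True  # internal assignment only
```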
```python
else:
    self.dataloader_cfg = DataloaderConfig(
        collator="fake_collator",
        pack_level="none",
```
The default `num_worker` could be set to 1 or 2, with multimodal scenarios in mind.
```python
    self._completed_actions[replay_meta.version].append(action_id)
    self.logger.debug(f"Add sample with root_id: {root_id}, action_id: {action_id} to finished_actions.")
else:
    assert False, f"Unsupported rollout state {state} for action_id {action_id} in ReplayBufferStorage."
```
Use `raise AssertionError(...)` instead of `assert False, ...`.
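Sketch of the suggested change; unlike a bare `assert`, the explicit raise also survives `python -O`, which strips assert statements:

```python
raise AssertionError(
    f"Unsupported rollout state {state} for action_id {action_id} in ReplayBufferStorage."
)
```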
```python
for sample in group_samples:
    assert sample.data.input_ids and sample.data.num_tokens, "input_ids or num_tokens is empty!"
    if "routed_experts" in sample.env.rollout.extra_info:
```
This can't just be deleted outright; potential memory leaks need to be considered.
```python
    data = base64.b64decode(routed_experts)
    routed_experts = ray.cloudpickle.loads(data)
else:
    routed_experts = torch.tensor(routed_experts)  # n, layer, expert
```
sglang takes this branch. If execution reaches this branch, then first doing `routed_experts = ray.put(routed_experts)` and later `await routed_experts` is too odd; it should be handled properly.
```python
cur_routed_experts = cur_routed_experts[exist_routed_experts.shape[0] :, :, :]
concat_routed_experts = np.concatenate((exist_routed_experts, cur_routed_experts), axis=0)
prompt_tokens = response["meta_info"].get("prompt_tokens", 0)
response_tokens = response["meta_info"].get("completion_tokens", 0)
```
An assert could be added here to check that the sequence length of `concat_routed_experts` equals `prompt_tokens + response_tokens - 1`.
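A sketch of the proposed check, using names from the surrounding diff; the comment on why the `-1` holds is my reading, not the reviewer's:

```python
# Presumably one routing record exists per position except the final one,
# hence the -1; the reviewer only states the expected equality.
assert concat_routed_experts.shape[0] == prompt_tokens + response_tokens - 1, (
    f"routed_experts length {concat_routed_experts.shape[0]} != "
    f"prompt_tokens + response_tokens - 1 ({prompt_tokens + response_tokens - 1})"
)
```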
```python
if not self.enable_partial_rollout:
    # clear the previous response_ids and other env data
    if "routed_experts" in sample.env.rollout.extra_info:
        del sample.env.rollout.extra_info["routed_experts"]
```
Same as above (consider memory leaks before deleting).
Has it been considered that, after an interruption, the next request could be sent to the same server so its cache can be reused?
This PR introduces asynchronous RL support to XTuner, enabling partial rollouts and version-based sample management for more efficient training-data generation.
1. Key Concepts:
2. Async logic:

- `staleness_threshold=0.0`, `enable_partial_rollout=0`, `tail_batch_candidate_steps=0`
- `staleness_threshold=0.2`, `enable_partial_rollout=0`, `tail_batch_candidate_steps=0`
  2. Responses are not retained when a rollout is paused
  3. Prioritize sampling data from the abort queue
- `staleness_threshold=0.2`, `enable_partial_rollout=0`, `tail_batch_candidate_steps=1`, `tail_batch_trigger_size=0`
  2. Responses are not retained when paused
  3. Prioritize sampling data from the abort queue
  4. A sample is put into the candidate pool when its abort count reaches `tail_batch_candidate_steps + 1`
- `staleness_threshold=0.2`, `enable_partial_rollout=1`, `tail_batch_candidate_steps=0`, `tail_batch_trigger_size=0`
  2. Responses are retained & concatenated when paused
  3. Prioritize sampling data from the abort queue
- `staleness_threshold=0.2`, `enable_partial_rollout=1`, `tail_batch_candidate_steps=1`, `tail_batch_trigger_size=0`
  2. Responses are retained & concatenated when paused
  3. Prioritize sampling data from the abort queue
  4. A sample is put into the candidate pool when its abort count reaches `tail_batch_candidate_steps + 1` (`tail_batch_candidate_steps` is the off-policy step count)
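For illustration, the fully asynchronous combination from the list above might be written down like this (a sketch only; the dict below is not the PR's actual API, and the flags really live in the dataflow/replay-buffer configs):

```python
# Hypothetical container: only the four flags and their meanings come
# from this PR's description.
async_rl_cfg = dict(
    staleness_threshold=0.2,       # tolerate samples up to 0.2 staleness (off-policy)
    enable_partial_rollout=1,      # retain & concatenate responses when a rollout is paused
    tail_batch_candidate_steps=1,  # off-policy steps before a sample joins the candidate pool
    tail_batch_trigger_size=0,
)
```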
3. BenchMark
4. Relative PR

`sample_from_expired_storage` in dataflow: when `sample_from_expired_storage` is set to True, the dataflow will not over-send data and will return data only after all tasks of the current batch are completed.