diff --git a/examples/qwen2.5-0.5B-agentic/agentic_val_sokoban_lora.yaml b/examples/qwen2.5-0.5B-agentic/agentic_val_sokoban_lora.yaml index 5dd3fe377..48f668add 100644 --- a/examples/qwen2.5-0.5B-agentic/agentic_val_sokoban_lora.yaml +++ b/examples/qwen2.5-0.5B-agentic/agentic_val_sokoban_lora.yaml @@ -29,26 +29,27 @@ system_envs: # - roll # - baseline -track_with: tensorboard +track_with: stdout # Disable TensorBoard in smoke run to bypass SummaryWriter type constraints. tracker_kwargs: - log_dir: /data/oss_bucket_0/yali/llm/tensorboard/roll_exp/agentic_sokoban + log_dir: ./output/tensorboard/agentic_sokoban_lora_smoke # Use local path so smoke test does not depend on external mount. checkpoint_config: type: file_system - output_dir: /data/cpfs_0/rl_examples/models/${exp_name} + output_dir: /tmp/roll_output/agentic_sokoban_lora_smoke # Keep checkpoint path local for a portable smoke run. -num_gpus_per_node: 8 +num_gpus_per_node: 4 # Fit smoke test to a 4-GPU node. -max_steps: 1024 +max_steps: 3 # Minimal training smoke: one training step is enough to verify end-to-end path. save_steps: 10000 logging_steps: 1 -eval_steps: 10 +eval_steps: 0 # Disable eval loop for faster smoke validation. resume_from_checkpoint: false +async_generation_ratio: 1 # Required by partial_gpu_mode validation in agentic_pipeline. -rollout_batch_size: 1024 -val_batch_size: 1024 -sequence_length: 8192 +rollout_batch_size: 4 # Keep rollout tiny to reduce runtime/memory. +val_batch_size: 4 +sequence_length: 2048 # Reduce memory pressure while preserving normal train path. advantage_clip: 0.2 ppo_epochs: 1 @@ -75,9 +76,9 @@ actor_train: training_args: learning_rate: 2.0e-5 weight_decay: 0 - per_device_train_batch_size: 2 - gradient_accumulation_steps: 64 - warmup_steps: 10 + per_device_train_batch_size: 1 # Minimal micro-batch for smoke stability. 
+ gradient_accumulation_steps: 2 + warmup_steps: 1 lr_scheduler_type: cosine data_args: template: qwen2_5 @@ -91,7 +92,7 @@ actor_train: expert_model_parallel_size: 1 use_distributed_optimizer: true recompute_granularity: full - device_mapping: list(range(0,8)) + device_mapping: list(range(0,2)) # Constrain actor_train to 4 GPUs for this smoke profile. infer_batch_size: 2 actor_infer: @@ -102,11 +103,11 @@ actor_infer: lora_rank: 32 lora_alpha: 32 generating_args: - max_new_tokens: 128 # single-turn response length - top_p: 0.99 - top_k: 100 + max_new_tokens: 64 # Shorter generation keeps smoke test fast. + top_p: 1 + top_k: 3 num_beams: 1 - temperature: 0.99 + temperature: 0.0 num_return_sequences: 1 data_args: template: qwen2_5 @@ -116,7 +117,7 @@ actor_infer: gpu_memory_utilization: 0.8 block_size: 16 load_format: auto - device_mapping: list(range(0,8)) + device_mapping: list(range(0,4)) # Constrain actor_infer to same 4-GPU pool. reference: model_args: @@ -129,7 +130,7 @@ reference: strategy_args: strategy_name: hf_infer strategy_config: ~ - device_mapping: list(range(0,8)) + device_mapping: list(range(0,2)) # Keep reference mapping consistent with 4-GPU smoke topology. infer_batch_size: 2 reward_normalization: @@ -138,19 +139,19 @@ reward_normalization: train_env_manager: format_penalty: -0.15 # sokoban env penalty_for_step=-0.1 - max_env_num_per_worker: 16 - num_env_groups: 128 + max_env_num_per_worker: 4 # Smaller env fanout for quick smoke startup. + num_env_groups: 2 # under the same group, the env config and env seed are ensured to be equal - group_size: 8 + group_size: 2 tags: [SimpleSokoban] - num_groups_partition: [128] # If not set, all env names divide nums equally. Under the same group, the env config and env seed (prompt) are equal in each generation + num_groups_partition: [2] # Match reduced group count for smoke. 
val_env_manager: - max_env_num_per_worker: 32 - num_env_groups: 1024 + max_env_num_per_worker: 4 # Keep validation manager light even though eval is disabled. + num_env_groups: 4 group_size: 1 # should be set to 1 because val temperature is set to 0 and same prompt leads to same output tags: [SimpleSokoban, LargerSokoban, SokobanDifferentGridVocab, FrozenLake] - num_groups_partition: [256, 256, 256, 256] # TODO: If not set, all env names divide nums equally. Under the same group, the env config and env seed (prompt) are equal in each generation + num_groups_partition: [1, 1, 1, 1] # Minimal partitioning for smoke. # Here, you can override variables defined in the imported envs. max_tokens_per_step: 128 in custom_env.SimpleSokoban, here replaced by 64 diff --git a/examples/qwen2.5-0.5B-agentic/agentic_val_sokoban_mulit_lora_partial_overlap.yaml b/examples/qwen2.5-0.5B-agentic/agentic_val_sokoban_mulit_lora_partial_overlap.yaml new file mode 100644 index 000000000..292528d8f --- /dev/null +++ b/examples/qwen2.5-0.5B-agentic/agentic_val_sokoban_mulit_lora_partial_overlap.yaml @@ -0,0 +1,206 @@ +defaults: + - ../config/traj_envs@_here_ + - ../config/deepspeed_zero@_here_ + - ../config/deepspeed_zero2@_here_ + - ../config/deepspeed_zero3@_here_ + - ../config/deepspeed_zero3_cpuoffload@_here_ + +hydra: + run: + dir: . 
+ output_subdir: null + +pipeline_cls: roll.pipeline.agentic.agentic_multi_lora_pipeline.AgenticMultiLoraPipeline + + + +exp_name: "agent_train_sokoban_multi_lora1" +seed: 42 +logging_dir: ./output/lora_pipeline1/logs +output_dir: ./output/lora_pipeline1 +render_save_dir: /tmp/roll_output/lora_pipeline1/render + +# track_with: wandb +# tracker_kwargs: +# entity: "khd6t7hdhn-university-of-pennsylvania" +# project: "rlix" +# api_key: "${oc.env:WANDB_API_KEY}" + + +system_envs: + USE_MODELSCOPE: "0" + NCCL_SHM_DISABLE: "1" + RAY_PROFILING: "1" + RAY_DEDUP_LOGS: "0" + RAY_TMPDIR: "${oc.env:RAY_TMPDIR,/tmp}" + OMP_NUM_THREADS: "1" + MKL_NUM_THREADS: "1" + OPENBLAS_NUM_THREADS: "1" + RAY_grpc_server_thread_pool_size: "4" + TORCHINDUCTOR_COMPILE_THREADS: "1" + TORCHINDUCTOR_MAX_AUTOTUNE: "0" + # Container lacks SYS_PTRACE capability; disable vLLM custom all-reduce IPC and use NCCL fallback + VLLM_DISABLE_CUSTOM_ALL_REDUCE: "1" + +checkpoint_config: + type: file_system + output_dir: /tmp/roll_output/multi_lora2/checkpoints + +num_gpus_per_node: 2 +model_download_type: HUGGINGFACE_HUB +offload_nccl: true +max_steps: 3 +model_update_buffer_size_mb: 100 # Limit broadcast bucket to 100 MB to avoid OOM with co-located infer workers +model_update_transport: cpu_serialize # CPU byte serialization; avoids pidfd_getfd error in restricted containers +verify_model_after_sync: true +save_steps: 10000 +logging_steps: 1 +eval_steps: 20 +resume_from_checkpoint: false + +async_generation_ratio: 1 + +rollout_batch_size: 4 +val_batch_size: 4 +sequence_length: 1024 # Reduced from 2048: Sokoban max_new_tokens=64 needs ~500 tokens max, halves peak activation memory +max_actions_per_traj: 5 + +advantage_clip: 0.2 +ppo_epochs: 1 +adv_estimator: "grpo" +init_kl_coef: 0.0 +whiten_advantages: true +entropy_loss_coef: 0 +max_grad_norm: 1.0 + +pretrain: Qwen/Qwen2.5-0.5B-Instruct +reward_pretrain: Qwen/Qwen2.5-0.5B-Instruct + +actor_train: + offload_nccl: ${offload_nccl} + model_args: + 
attn_implementation: fa2 + disable_gradient_checkpointing: false + dtype: bf16 + model_type: ~ + adapters: + Sokoban1: + lora_target: all-linear + lora_rank: 8 + lora_alpha: 8 + Sokoban2: + lora_target: all-linear + lora_rank: 8 + lora_alpha: 8 + training_args: + learning_rate: 1.0e-6 + weight_decay: 0 + per_device_train_batch_size: 1 + gradient_accumulation_steps: 2 + warmup_steps: 1 + lr_scheduler_type: cosine + data_args: + template: qwen2_5 + strategy_args: + strategy_name: megatron_train + strategy_config: + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + expert_model_parallel_size: 1 + use_distributed_optimizer: false + is_lora_optimizer_isolated: true + recompute_granularity: full + sequence_parallel: true + overlap_grad_reduce: false # Isolated LoRA mode requires overlap_grad_reduce disabled to avoid grad-sync hang. + # Note: use_sequence_packing is NOT enabled here — sequence packing mixes sequences from different + # LoRA adapters into one microbatch, violating the adapter-homogeneity constraint in inner_forward_step. + # Note: use_dynamic_batching_in_train is also NOT enabled — incompatible with is_lora_optimizer_isolated=true. + device_mapping: "[0, ]" + infer_batch_size: 1 + +actor_infer: + offload_nccl: ${offload_nccl} + model_args: + disable_gradient_checkpointing: true + dtype: bf16 + adapters: + Sokoban1: + lora_target: q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj + lora_rank: 8 + lora_alpha: 8 + Sokoban2: + lora_target: q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj + lora_rank: 8 + lora_alpha: 8 + generating_args: + max_new_tokens: 64 + top_p: 1 + top_k: 3 + num_beams: 1 + temperature: 0.0 + num_return_sequences: 1 + data_args: + template: qwen2_5 + strategy_args: + strategy_name: vllm + strategy_config: + VLLM_USE_V1: 1 + gpu_memory_utilization: 0.8 # Raise cache budget so vLLM has non-zero KV blocks during two-worker startup. 
+ block_size: 16 + load_format: auto + tensor_parallel_size: 1 + max_num_batched_tokens: 1024 # Match reduced sequence_length=1024 + max_num_seqs: 2 + enforce_eager: true + sleep_level: 1 + device_mapping: "[0, 1, ]" + +reference: + offload_nccl: ${offload_nccl} + model_args: + attn_implementation: fa2 + disable_gradient_checkpointing: true + dtype: bf16 + model_type: ~ + data_args: + template: qwen2_5 + strategy_args: + strategy_name: megatron_infer + strategy_config: + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + expert_model_parallel_size: 1 + # Dynamic batching on reference (megatron_infer): trims padding per-microbatch to actual token length + # (rounded to sequence_length_round_in_infer). Reduces peak memory during log_prob computation. + use_dynamic_batching_in_infer: true + max_tokens_per_microbatch_in_infer: 1024 # Match reduced sequence_length=1024 + sequence_length_round_in_infer: 8 + device_mapping: "[0, ]" + infer_batch_size: 1 + +reward_normalization: + grouping: traj_group_id + method: mean_std + +train_env_manager: + format_penalty: -0.15 + max_env_num_per_worker: 4 + num_env_groups: 2 + group_size: 2 + tags: [Sokoban1, Sokoban2] + num_groups_partition: [1, 1] + +val_env_manager: + max_env_num_per_worker: 4 + num_env_groups: 2 + group_size: 2 + tags: [Sokoban1, Sokoban2] + num_groups_partition: [1, 1] + +max_tokens_per_step: 64 + +custom_envs: + Sokoban1: + ${custom_env.SimpleSokoban} + Sokoban2: + ${custom_env.SimpleSokoban} diff --git a/examples/start_agentic_pipeline.py b/examples/start_agentic_pipeline.py index 1b10c685f..4a9b1dab5 100644 --- a/examples/start_agentic_pipeline.py +++ b/examples/start_agentic_pipeline.py @@ -34,6 +34,7 @@ def main(): pipeline = pipeline_cls(pipeline_config=ppo_config) pipeline.run() + print("Pipeline finished.") if __name__ == "__main__": diff --git a/mcore_adapter/src/mcore_adapter/adapters/lora_layer.py b/mcore_adapter/src/mcore_adapter/adapters/lora_layer.py index f201a1446..9611c4ceb 
100644 --- a/mcore_adapter/src/mcore_adapter/adapters/lora_layer.py +++ b/mcore_adapter/src/mcore_adapter/adapters/lora_layer.py @@ -254,6 +254,9 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any): if self.sequence_parallel and self.base_layer.parallel_mode == "row": lora_result = scatter_to_sequence_parallel_region(lora_result) + # Cast per-adapter result before accumulating; each adapter may compute in its own weight dtype. + if lora_result.dtype != previous_dtype: + lora_result = lora_result.to(previous_dtype) result = result + lora_result result = result.to(previous_dtype) @@ -413,6 +416,8 @@ def _create_lora_layers(self, r, lora_bias, **kwargs): in_features = self.in_features * self.tp_size if self.is_grouped: + if not isinstance(TEGroupedLinear, type): + raise RuntimeError("Grouped LoRA layers require Transformer Engine grouped linear support.") r = r // self.config.moe_router_topk lora_a = TERowParallelGroupedLinear( num_gemms=self.base_layer.num_gemms, @@ -457,6 +462,8 @@ def _create_lora_layers(self, r, lora_bias, **kwargs): out_features = self.out_features * self.tp_size if self.is_grouped: + if not isinstance(TEGroupedLinear, type): + raise RuntimeError("Grouped LoRA layers require Transformer Engine grouped linear support.") r = r // self.config.moe_router_topk lora_a = TEGroupedLinear( num_gemms=self.base_layer.num_gemms, @@ -518,6 +525,13 @@ def dispatch_megatron( elif isinstance(target_base_layer, (TELinear, TEGroupedLinear)): # default to column parallel linear for non-parallel linear layers new_module = LoraColumnParallelLinear(base_layer=target, adapter_name=adapter_name, **kwargs) + else: + # Fail fast: non-TE layers are not supported for LoRA. This prevents silent skip + # where peft would leave the module unchanged (no LoRA applied) with no error. + raise RuntimeError( + f"LoRA on {type(target_base_layer).__name__} is not supported. " + "Use transformer_impl=transformer_engine." 
+ ) return new_module diff --git a/mcore_adapter/src/mcore_adapter/initialize.py b/mcore_adapter/src/mcore_adapter/initialize.py index fa8f70457..d397f37a3 100644 --- a/mcore_adapter/src/mcore_adapter/initialize.py +++ b/mcore_adapter/src/mcore_adapter/initialize.py @@ -53,6 +53,9 @@ def _initialize_distributed(args: "TrainingArguments"): rank=int(os.getenv("RANK", "0")), world_size=int(os.getenv("WORLD_SIZE", "1")), timeout=args.ddp_timeout_delta, + # Explicitly bind NCCL to this GPU from the start; avoids ambiguous + # device selection when multiple GPUs are visible to the process. + device_id=torch.device(args.device), ) # Set the tensor model-parallel, pipeline model-parallel, and # data-parallel communicators. diff --git a/mcore_adapter/src/mcore_adapter/models/model_utils.py b/mcore_adapter/src/mcore_adapter/models/model_utils.py index 3fcb7ebbb..681d83f0c 100644 --- a/mcore_adapter/src/mcore_adapter/models/model_utils.py +++ b/mcore_adapter/src/mcore_adapter/models/model_utils.py @@ -112,7 +112,8 @@ def forward(self, hidden_states): class _McaLoraLogitsHelper(torch.autograd.Function): @staticmethod def forward(ctx, logits: "torch.Tensor"): - return logits + # Return a fresh tensor so downstream inplace ops do not invalidate this custom backward. 
+ return logits.clone() @staticmethod def backward(ctx, grad_output: "torch.Tensor"): diff --git a/requirements_common.txt b/requirements_common.txt index 5af345be3..28a599d81 100644 --- a/requirements_common.txt +++ b/requirements_common.txt @@ -1,4 +1,4 @@ -ray[default,cgraph]==2.48.0 # vllm required ray[default,cgraph]>=2.48.0 +ray[default,cgraph] # >=2.48.0 tao: let the pip figure out for vllm # vllm required ray[default,cgraph]>=2.48.0 numpy<2.0a0,>=1.25 tensordict sympy diff --git a/requirements_torch260_vllm.txt b/requirements_torch260_vllm.txt index 8a6d2d93b..546bf7683 100644 --- a/requirements_torch260_vllm.txt +++ b/requirements_torch260_vllm.txt @@ -4,8 +4,11 @@ torch==2.6.0.* torchvision==0.21.0.* torchaudio==2.6.0.* -flash-attn - -transformer-engine[pytorch]==2.2.0 +https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.2.post1/flash_attn-2.7.2.post1+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl +transformers==4.51.1 +tensorboard # needed for metrics tracking by default +# transformer-engine must be installed AFTER torch (pip install transformer-engine[pytorch] --no-build-isolation). +# TE availability is validated at runtime in base_config.py; install separately if needed. +# transformer-engine[pytorch]==2.2.0 deepspeed==0.16.4 vllm==0.8.4 diff --git a/roll/configs/base_config.py b/roll/configs/base_config.py index bd84be971..c37ce252c 100644 --- a/roll/configs/base_config.py +++ b/roll/configs/base_config.py @@ -177,6 +177,30 @@ class BaseConfig(ScheduleConfig): default=1024, metadata={"help": "Buffer size in MB for model update operations (e.g., 1024 = 1GB)."} ) + model_update_transport: str = field( + default="cuda_ipc", + metadata={ + "help": ( + "Transport for colocated model weight transfer (Megatron+vLLM only). " + "'cuda_ipc': default GPU transport via CUDA IPC; may require additional " + "container privileges in restricted environments. 
" + "'cpu_serialize': CPU byte serialization fallback via standard pickle, " + "avoids 'pidfd_getfd: Operation not permitted' in restricted containers. " + "Payload size equals model_update_buffer_size_mb per rank; rank 0 holds " + "num_gpus_per_worker copies during gather." + ) + }, + ) + verify_model_after_sync: bool = field( + default=False, + metadata={ + "help": ( + "When True, verify weight integrity after every model update sync " + "by comparing sender-side stats against receiver-side live weights. " + "Raises RuntimeError on mismatch." + ) + }, + ) num_nodes: int = field( default=1, metadata={"help": "Number of nodes available for distributed training."} @@ -291,6 +315,18 @@ def __post_init__(self): # Only validate for Megatron strategies if 'megatron' in strategy_name.lower(): + # Fail fast when Transformer Engine is required by Megatron config but unavailable. + strategy_config = self.actor_train.strategy_args.strategy_config or {} + transformer_impl = strategy_config.get("transformer_impl", "transformer_engine") + if transformer_impl == "transformer_engine": + from megatron.core.models.gpt.gpt_layer_specs import HAVE_TE + if not HAVE_TE: + raise RuntimeError( + "Transformer Engine is requested by actor_train Megatron config " + "(transformer_impl=transformer_engine) but not available. " + "Install transformer-engine or set " + "actor_train.strategy_args.strategy_config.transformer_impl=local." + ) try: validate_megatron_batch_size( batch_size=self.rollout_batch_size, diff --git a/roll/configs/model_args.py b/roll/configs/model_args.py index c9b8b8446..06c456408 100644 --- a/roll/configs/model_args.py +++ b/roll/configs/model_args.py @@ -3,6 +3,8 @@ import torch +from roll.utils.lora_routing import normalize_domain + # Inspired by: https://github.com/hiyouga/LLaMA-Factory/blob/main/src/llamafactory/hparams/finetuning_args.py @dataclass @@ -11,6 +13,13 @@ class LoraArguments: Arguments pertaining to the LoRA training. 
""" + #todo(tao) rename as lora_name systematically + # Unique identifier for this adapter, used as routing key in multi-LoRA dispatch. + # Names are normalized via normalize_domain() to lowercase slugs (e.g., "Math/v2" -> "math_v2"). + adapter_name: str = field( + default="default", + metadata={"help": "The name of the adapter to be injected."}, + ) additional_target: Optional[str] = field( default=None, metadata={ @@ -61,6 +70,13 @@ class ModelArguments(LoraArguments): "help": "Path to the model weight or identifier from huggingface.co/models or modelscope.cn/models." }, ) + # Multi-LoRA support: maps normalized adapter names to their LoraArguments configs. + # Single-LoRA configs using legacy top-level lora_rank/lora_target are auto-converted + # to adapters={"default": LoraArguments(...)} in __post_init__. + adapters: Optional[Dict[str, LoraArguments]] = field( + default=None, + metadata={"help": "List of LoRA adapter configurations."}, + ) adapter_name_or_path: Optional[str] = field( default=None, metadata={"help": "Path to the adapter weight or identifier from huggingface.co/models."}, @@ -119,19 +135,96 @@ class ModelArguments(LoraArguments): default=1, metadata={"help": "The group size for Ulysses attention."}, ) + # Maps raw adapter names (as written in YAML) to their normalized slugs. + # Used for reverse lookups when routing tags come from external sources that + # may use the original non-normalized spelling. + adapter_name_map: dict[str, str] = field(default_factory=dict, init=False) + + @property + def _is_single_lora(self) -> bool: + """True when using legacy top-level lora fields (no explicit adapters dict). + + Internal only: meaningful before __post_init__ canonicalizes single-LoRA + into an adapters dict. After init, use is_multi_lora to distinguish. 
+ """ + return self.adapters is None and self.lora_rank is not None and self.lora_target is not None + + @property + def is_multi_lora(self) -> bool: + """True when the config carries multiple named LoRA adapters.""" + return self.adapters is not None and len(self.adapters) > 1 + + @staticmethod + def _split_arg(arg): + """Split a comma-separated string into a list of stripped items.""" + if isinstance(arg, str): + return [item.strip() for item in arg.split(",")] + return arg + + def _normalize_adapters(self) -> None: + """Normalize adapter names to lowercase slugs and apply per-adapter defaults.""" + if self.adapters is None: + return + + normalized_adapters: dict[str, LoraArguments] = {} + raw_to_final: dict[str, str] = {} + seen_bases: set[str] = set() + for raw_adapter_name, adapter_config in self.adapters.items(): + base = normalize_domain(raw_adapter_name) + # Fail fast on normalization collisions to keep tag->adapter mapping deterministic. + if base in seen_bases: + raise RuntimeError( + f"Adapter name collision: '{raw_adapter_name}' normalizes to '{base}' " + "which conflicts with an earlier adapter. Use distinct adapter names." + ) + seen_bases.add(base) + adapter_config.adapter_name = base + if adapter_config.lora_alpha is None or adapter_config.lora_alpha <= 0: + adapter_config.lora_alpha = adapter_config.lora_rank * 2 + # Skip splitting when lora_target looks like a regex (contains regex special chars). 
+ if adapter_config.lora_target is not None and not any( + c in adapter_config.lora_target for c in ["*", "$", "|", "(", "^", "[", "+", "?", "\\"] + ): + adapter_config.lora_target = self._split_arg(adapter_config.lora_target) + adapter_config.additional_target = self._split_arg(adapter_config.additional_target) + normalized_adapters[base] = adapter_config + raw_to_final[str(raw_adapter_name)] = base + self.adapters = normalized_adapters + self.adapter_name_map = raw_to_final def __post_init__(self): - def split_arg(arg): - if isinstance(arg, str): - return [item.strip() for item in arg.split(",")] - return arg - - self.lora_alpha = self.lora_alpha or self.lora_rank * 2 - if self.lora_target is not None and not any(c in self.lora_target for c in ["*", "$", "|", "("]): - # split when lora_target is not regex expression - self.lora_target = split_arg(self.lora_target) - self.freeze_module_prefix: Optional[List[str]] = split_arg(self.freeze_module_prefix) - self.additional_target: Optional[List[str]] = split_arg(self.additional_target) + # --- LoRA mode dispatch --- + # Multi-LoRA: adapters dict is set explicitly in config. + # Only normalize the per-adapter configs; top-level lora_rank/lora_alpha are ignored. + # Empty adapters dict ({}) is treated as config error — fail fast to catch typos. + if self.adapters is not None: + if len(self.adapters) == 0: + raise ValueError("adapters dict is empty; remove it or add at least one adapter.") + self._normalize_adapters() + + # Single-LoRA: top-level lora_rank + lora_target set, no adapters dict. + # Canonicalize into a single-entry adapters dict for uniform downstream access. 
+ elif self._is_single_lora: + self.lora_alpha = self.lora_alpha or self.lora_rank * 2 + self.adapters = { + "default": LoraArguments( + adapter_name="default", + lora_rank=self.lora_rank, + lora_alpha=self.lora_alpha, + lora_dropout=self.lora_dropout, + lora_target=self.lora_target, + ) + } + self._normalize_adapters() + + # No-LoRA: neither adapters nor lora_target set. Nothing to do. + + # --- Fields that apply regardless of LoRA mode --- + # Skip splitting when lora_target looks like a regex (contains regex special chars). + if self.lora_target is not None and not any(c in self.lora_target for c in ["*", "$", "|", "(", "^", "[", "+", "?", "\\"]): + self.lora_target = self._split_arg(self.lora_target) + self.freeze_module_prefix: Optional[List[str]] = self._split_arg(self.freeze_module_prefix) + self.additional_target: Optional[List[str]] = self._split_arg(self.additional_target) dtype_mapping = { "fp32": torch.float32, diff --git a/roll/configs/worker_config.py b/roll/configs/worker_config.py index 5ceb721a5..47196e026 100644 --- a/roll/configs/worker_config.py +++ b/roll/configs/worker_config.py @@ -1,4 +1,5 @@ from dataclasses import dataclass, field +import ast from typing import Dict, List, Literal, Optional, Union from roll.configs import DataArguments, GeneratingArguments, ModelArguments @@ -240,7 +241,18 @@ def __post_init__(self): ) if self.device_mapping is not None: - self.device_mapping = eval(self.device_mapping) + if isinstance(self.device_mapping, str): + try: + self.device_mapping = ast.literal_eval(self.device_mapping) + except (ValueError, SyntaxError): + # Backward compatibility: many configs use "list(range(...))". + # RISK: __builtins__={} reduces but does not fully sandbox eval; + # acceptable only because input is a local config file, not user input. 
+ self.device_mapping = eval( + self.device_mapping, + {"__builtins__": {}}, + {"list": list, "range": range}, + ) assert ( len(self.device_mapping) % self.num_gpus_per_worker == 0 ), f"len(device_mapping)={len(self.device_mapping)} must be divisible by num_gpus_per_worker={self.num_gpus_per_worker}." @@ -287,4 +299,3 @@ def is_actor_infer_overlapping_with_any_cluster(actor_infer: WorkerConfig, actor return True return False - diff --git a/roll/distributed/executor/cluster.py b/roll/distributed/executor/cluster.py index 15920ee8f..07b843072 100644 --- a/roll/distributed/executor/cluster.py +++ b/roll/distributed/executor/cluster.py @@ -20,7 +20,7 @@ dispatch_one_to_all, ) from roll.platforms import current_platform -from roll.utils.constants import RAY_NAMESPACE +from roll.utils.constants import RAY_NAMESPACE, rlix_env_vars from roll.distributed.scheduler.resource_manager import ResourceManager from roll.utils.import_utils import safe_import_class from roll.utils.logging import get_logger @@ -132,6 +132,7 @@ def _create_workers(self): "CLUSTER_NAME": self.cluster_name, "WORKER_NAME": worker_name, } + env_vars.update(rlix_env_vars()) if rank != 0: env_vars["MASTER_ADDR"] = self.master_addr diff --git a/roll/distributed/executor/model_update_group.py b/roll/distributed/executor/model_update_group.py index 3ea8effd4..1dbde1913 100644 --- a/roll/distributed/executor/model_update_group.py +++ b/roll/distributed/executor/model_update_group.py @@ -4,6 +4,52 @@ from roll.distributed.executor.cluster import Cluster from roll.distributed.scheduler.protocol import DataProto from roll.utils.functionals import reduce_metrics_list +from roll.utils.logging import get_logger + +logger = get_logger() + + +def _aggregate_sender_stats(stats_list: list[dict]) -> dict: + """Aggregate weight stats across PP stages (sum-of-sums, max-of-maxes, min-of-mins). + + Each entry in stats_list comes from one PP-stage reporter. 
For colocated path + only one worker reports, so no aggregation is needed. For separated path, one + reporter per PP stage — their stats must be combined. + """ + result: dict = {} + + # Aggregate base stats. + base_entries = [entry["base"] for entry in stats_list if "base" in entry] + if base_entries: + result["base"] = { + "sum": sum(entry["sum"] for entry in base_entries), + "max": max(entry["max"] for entry in base_entries), + "min": min(entry["min"] for entry in base_entries), + } + + # Aggregate per-adapter LoRA stats. + all_adapter_names: set[str] = set() + for entry in stats_list: + if "lora" in entry: + all_adapter_names.update(entry["lora"].keys()) + if all_adapter_names: + lora_result: dict = {} + for adapter_name in sorted(all_adapter_names): + adapter_entries = [ + entry["lora"][adapter_name] + for entry in stats_list + if "lora" in entry and adapter_name in entry["lora"] + ] + if adapter_entries: + lora_result[adapter_name] = { + "sum": sum(entry["sum"] for entry in adapter_entries), + "max": max(entry["max"] for entry in adapter_entries), + "min": min(entry["min"] for entry in adapter_entries), + } + if lora_result: + result["lora"] = lora_result + + return result class ModelUpdateGroup: @@ -28,14 +74,45 @@ def __init__(self, src_cluster: Cluster, tgt_cluster: Cluster, pipeline_config: ] ) - def model_update(self, step=None): + def model_update(self, step=None, adapters_to_update: set[str] | None = None): if step % self.frequency != 0: return {} + kwargs = {"model_update_name": self.model_update_name} + if adapters_to_update is not None: + kwargs["adapters_to_update"] = sorted(adapters_to_update) + dataprotos: list[DataProto] = ray.get( [ - train_worker.start_model_update.remote(model_update_name=self.model_update_name) + train_worker.start_model_update.remote(**kwargs) for train_worker in self.src_cluster.workers ] ) + + # Post-sync verification gated by config flag (disabled by default). 
+ if not self.pipeline_config.verify_model_after_sync: + return reduce_metrics_list([dataproto.meta_info["metrics"] for dataproto in dataprotos]) + + # Extract weight_stats separately before reduce_metrics_list (which would + # corrupt nested dicts via np.mean). Only non-empty stats from canonical + # reporter workers are included. + sender_stats_list = [ + dataproto.meta_info["weight_stats"] + for dataproto in dataprotos + if dataproto.meta_info.get("weight_stats") + ] + if sender_stats_list: + aggregated_stats = _aggregate_sender_stats(sender_stats_list) + if aggregated_stats: + # Fire verify_model on all target infer workers and wait (fail-fast). + verify_refs = [ + infer_worker.verify_model.remote(expected_stats=aggregated_stats) + for infer_worker in self.tgt_cluster.workers + ] + ray.get(verify_refs) + logger.info( + "[ModelUpdateGroup] verify_model ok tgt_workers=%d stats_keys=%s", + len(verify_refs), sorted(aggregated_stats.keys()), + ) + return reduce_metrics_list([dataproto.meta_info["metrics"] for dataproto in dataprotos]) diff --git a/roll/distributed/executor/worker.py b/roll/distributed/executor/worker.py index fd16a7bf7..372416b95 100644 --- a/roll/distributed/executor/worker.py +++ b/roll/distributed/executor/worker.py @@ -1,9 +1,12 @@ import logging import os import socket +import asyncio +import inspect +import threading from concurrent import futures from dataclasses import dataclass -from typing import Dict, Optional, List +from typing import Any, Dict, Optional, List import ray @@ -12,7 +15,7 @@ from roll.distributed.scheduler.protocol import DataProto from roll.distributed.scheduler.storage import SharedStorage from roll.utils.checkpoint_manager import download_model -from roll.utils.constants import RAY_NAMESPACE, STORAGE_NAME +from roll.utils.constants import GLOBAL_STORAGE_NAMESPACE, RAY_NAMESPACE, STORAGE_NAME from roll.utils.context_managers import state_offload_manger from roll.utils.logging import get_logger from 
roll.utils.network_utils import collect_free_port, get_node_ip @@ -42,6 +45,22 @@ def is_pipeline_last_stage(self): class Worker: + """ + Base worker class for distributed training and inference. + + A Worker wraps a strategy (e.g., FSDP, Megatron, vLLM) and provides a unified interface + for model loading, state management, and distributed communication setup. + + Workers are created by Cluster and run as Ray actors. Each worker has: + - A unique rank within its cluster + - A strategy that implements framework-specific logic + - Access to shared storage for cross-worker coordination + + Multi-pipeline support: + - Workers can belong to different pipelines in the same Ray cluster + - Pipeline isolation is achieved via pipeline-scoped rendezvous keys and port claims + - PIPELINE_ID environment variable identifies the pipeline (None for single-pipeline mode) + """ def __init__(self, worker_config: WorkerConfig): if worker_config.offload_nccl: @@ -53,8 +72,11 @@ def __init__(self, worker_config: WorkerConfig): self.rank = int(os.environ.get("RANK", 0)) self.world_size = int(os.environ.get("WORLD_SIZE", 1)) self.local_rank = int(os.environ.get("LOCAL_RANK", 0)) + # Pipeline identifier for multi-pipeline isolation. None in single-pipeline mode. + self.pipeline_id = os.environ.get("PIPELINE_ID") or None + # Use GLOBAL_STORAGE_NAMESPACE for cross-pipeline visibility (shared across all pipelines). 
self.shared_storage = SharedStorage.options( - name=STORAGE_NAME, get_if_exists=True, namespace=RAY_NAMESPACE + name=STORAGE_NAME, get_if_exists=True, namespace=GLOBAL_STORAGE_NAMESPACE ).remote() if self.rank == 0: @@ -65,9 +87,13 @@ def __init__(self, worker_config: WorkerConfig): self.master_addr = os.environ["MASTER_ADDR"] self.master_port = int(os.environ["MASTER_PORT"]) - self.shared_storage.put.remote( - self.cluster_name, {"MASTER_ADDR": self.master_addr, "MASTER_PORT": self.master_port} - ) + # Guard against misconfiguration: ROLL_RAY_NAMESPACE implies multi-pipeline mode. + if self.pipeline_id is None and os.environ.get("ROLL_RAY_NAMESPACE"): + raise RuntimeError("PIPELINE_ID must be set when ROLL_RAY_NAMESPACE is set (multi-pipeline mode)") + # Pipeline-scoped rendezvous key for MASTER_ADDR/PORT discovery by other workers. + # Format: "{pipeline_id}:{cluster_name}" for multi-pipeline, "{cluster_name}" for single-pipeline. + rendezvous_key = f"{self.pipeline_id}:{self.cluster_name}" if self.pipeline_id else self.cluster_name + self.shared_storage.put.remote(rendezvous_key, {"MASTER_ADDR": self.master_addr, "MASTER_PORT": self.master_port}) # NOTE: 自定义Worker时根据需要配置rank_info self.rank_info = RankInfo( world_size=self.world_size, @@ -95,17 +121,33 @@ def get_node_ip(): @staticmethod def get_free_port(): + """ + Allocate a unique free port for distributed communication setup. + + Uses atomic try_put on SharedStorage to claim ports across all workers in the cluster. + In multi-pipeline mode, ports are claimed with pipeline_id as the value, enabling + per-pipeline port isolation while still detecting conflicts across pipelines. + + Returns: + A unique port number available for use. + + Raises: + RuntimeError: If no unique port can be allocated within MAX_PORT_RETRY_COUNT attempts. 
+ """ shared_storage = SharedStorage.options( - name=STORAGE_NAME, get_if_exists=True, namespace=RAY_NAMESPACE + name=STORAGE_NAME, get_if_exists=True, namespace=GLOBAL_STORAGE_NAMESPACE ).remote() + pipeline_id = os.environ.get("PIPELINE_ID") or None master_addr = Worker.get_node_ip() max_retry_count = int(os.environ.get("MAX_PORT_RETRY_COUNT", 1000)) retry_count = 0 master_port = collect_free_port() while retry_count < max_retry_count: master_addr_port_key = f"MASTER_ADDR_PORT:{master_addr}:{master_port}" - if ray.get(shared_storage.get.remote(master_addr_port_key)) is None: - ray.get(shared_storage.put.remote(master_addr_port_key, True)) + # try_put returns True if key was successfully claimed (didn't exist before). + # Value is pipeline_id for multi-pipeline, True for single-pipeline mode. + claimed = ray.get(shared_storage.try_put.remote(master_addr_port_key, pipeline_id if pipeline_id else True)) + if claimed: break master_port = collect_free_port() retry_count += 1 @@ -166,20 +208,27 @@ def initialize(self, pipeline_config, *args, **kwargs): @register(dispatch_mode=Dispatch.ONE_TO_ALL) def load_states(self, *args, **kwargs): if getattr(self, "strategy", None) is not None: - self.strategy.load_states() + self.strategy.load_states(*args, **kwargs) else: self.logger.warning("worker has not strategy") @register(dispatch_mode=Dispatch.ONE_TO_ALL) def process_weights_after_loading(self): + """ + Process weights after model loading (e.g., weight slicing, dtype conversion). + + Uses _maybe_await so sync and async strategy implementations are both handled consistently, + matching the pattern used by setup_collective_group and destroy_collective_group. + Any exception from the async path is re-raised loudly by _maybe_await. 
+ """ if getattr(self, "strategy", None) is not None: - self.strategy.process_weights_after_loading() + self._maybe_await(self.strategy.process_weights_after_loading()) @register(dispatch_mode=Dispatch.ONE_TO_ALL) def offload_states(self, *args, **kwargs): if getattr(self, "strategy", None) is not None: - self.strategy.offload_states() + self.strategy.offload_states(*args, **kwargs) else: self.logger.warning("worker has not strategy") @@ -193,11 +242,84 @@ def setup_model_update(self, *args, **kwargs): self.strategy.setup_model_update(*args, **kwargs) def setup_collective_group(self, *args, **kwargs): + """ + Set up a distributed collective communication group (e.g., NCCL process group). + + Delegates to strategy.setup_collective_group(), which may be sync or async. + Uses _maybe_await to handle both cases transparently. + """ + if getattr(self, "strategy", None) is not None: + self._maybe_await(self.strategy.setup_collective_group(*args, **kwargs)) + else: + self.logger.warning("worker has not strategy") + + def destroy_collective_group(self, group_name: str, model_update_name: str | None = None) -> None: + """ + Destroy a collective communication group. + + Args: + group_name: Process group name to destroy. + model_update_name: Optional identifier for the model update session (used for bookkeeping cleanup). + """ if getattr(self, "strategy", None) is not None: - self.strategy.setup_collective_group(*args, **kwargs) + self._maybe_await(self.strategy.destroy_collective_group(group_name, model_update_name)) else: self.logger.warning("worker has not strategy") + @staticmethod + def _maybe_await(result: Any) -> Any: + """ + Execute a result that may be sync or async, returning the resolved value. + + This helper allows Worker methods to call strategy methods that may be either + synchronous or asynchronous without knowing the implementation at call site. + + Handles three scenarios: + 1. Non-awaitable result: Return directly (sync path) + 2. 
Awaitable with no running event loop: Use run_until_complete + 3. Awaitable with running event loop: Spawn a thread with asyncio.run + + Args: + result: Either a direct value or an awaitable (coroutine/Future). + + Returns: + The resolved value from the awaitable, or the original result if sync. + + Raises: + Any exception raised by the awaitable. + """ + if not inspect.isawaitable(result): + return result + + try: + loop = asyncio.get_event_loop() + except RuntimeError: + # No event loop exists; create one and run the coroutine. + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + if not loop.is_running(): + # Event loop exists but not running; use it directly. + return loop.run_until_complete(result) + + # Event loop is running (we're inside an async context). + # Spawn a daemon thread with a fresh loop to avoid blocking the current loop. + out: Dict[str, Any] = {} + err: Dict[str, BaseException] = {} + + def runner(): + try: + out["value"] = asyncio.run(result) + except BaseException as exc: + err["exc"] = exc + + t = threading.Thread(target=runner, daemon=True) + t.start() + t.join() + if err: + raise err["exc"] + return out.get("value") + def setup_p2p_collective_group(self, *args, **kwargs): if getattr(self, "strategy", None) is not None: self.strategy.setup_p2p_collective_group(*args, **kwargs) @@ -206,6 +328,7 @@ def setup_p2p_collective_group(self, *args, **kwargs): def start_model_update(self, *args, **kwargs): metrics = {} + weight_stats: dict = {} if getattr(self, "strategy", None) is not None: with state_offload_manger( strategy=self.strategy, @@ -213,13 +336,16 @@ def start_model_update(self, *args, **kwargs): metric_infix=f"{self.cluster_name}/model_update", load_kwargs={"include": [OffloadStateType.model_params]}, ): - exec_metrics: Dict = self.strategy.model_update(*args, **kwargs) + exec_result: Dict = self.strategy.model_update(*args, **kwargs) + # Separate weight_stats from timing metrics before key-prefixing. 
+ # weight_stats is a nested dict that would be corrupted by reduce_metrics_list. + weight_stats = exec_result.pop("weight_stats", {}) metric_prefix = f"time/{self.cluster_name}/model_update" - metrics.update({f"{metric_prefix}/{k}": v for k, v in exec_metrics.items()}) + metrics.update({f"{metric_prefix}/{k}": v for k, v in exec_result.items()}) else: self.logger.warning("worker has not strategy") - output = DataProto(meta_info={"metrics": metrics}) + output = DataProto(meta_info={"metrics": metrics, "weight_stats": weight_stats}) return output def model_update_set_read_done_handle(self, *args, **kwargs): @@ -234,12 +360,137 @@ def update_parameter_in_bucket(self, *args, **kwargs): else: self.logger.warning("worker has not strategy") + def build_latest_bucket_cache( + self, checkpoint_version: int, adapter_name: str | None = None + ) -> None: + """ + Build a sender-side CPU bucket cache for selective parameter sync. + + In RLix's selective sync protocol, the sender (actor_train) builds bucket caches + containing the latest parameter state. These caches are then transferred to receiver + workers (actor_infer) during model update, avoiding redundant checkpoint loading. + + Args: + checkpoint_version: Unique version identifier for the cached weights snapshot. + adapter_name: Optional LoRA adapter name; None means base model. + + Raises: + RuntimeError: If strategy does not implement _build_latest_bucket_cache. + """ + if getattr(self, "strategy", None) is None: + raise RuntimeError("worker has no strategy") + fn = getattr(self.strategy, "_build_latest_bucket_cache", None) + if not callable(fn): + raise RuntimeError(f"{type(self.strategy).__name__} does not support build_latest_bucket_cache") + fn(checkpoint_version=int(checkpoint_version), adapter_name=adapter_name) + + def promote_active_checkpoint(self, checkpoint_version: int) -> None: + """ + Promote a checkpoint version as the active one for subsequent operations. 
+ + After building bucket caches, this marks which version should be used for + the next selective sync. Enables atomic version switching without race conditions. + + Args: + checkpoint_version: Unique version identifier to promote. + + Raises: + RuntimeError: If strategy does not implement promote_active_checkpoint. + """ + if getattr(self, "strategy", None) is None: + raise RuntimeError("worker has no strategy") + promote = getattr(self.strategy, "promote_active_checkpoint", None) + if not callable(promote): + raise RuntimeError(f"{type(self.strategy).__name__} does not support promote_active_checkpoint") + promote(checkpoint_version=int(checkpoint_version)) + + def promote_active_adapter_checkpoint( + self, adapter_name: str, checkpoint_version: int + ) -> None: + """ + Promote a per-adapter checkpoint version as active (multi-LoRA support). + + Similar to promote_active_checkpoint but scoped to a specific LoRA adapter, + allowing independent version management per adapter. + + Args: + adapter_name: Name of the LoRA adapter. + checkpoint_version: Unique version identifier to promote. + + Raises: + RuntimeError: If strategy does not implement promote_active_adapter_checkpoint. + """ + if getattr(self, "strategy", None) is None: + raise RuntimeError("worker has no strategy") + fn = getattr(self.strategy, "promote_active_adapter_checkpoint", None) + if not callable(fn): + raise RuntimeError(f"{type(self.strategy).__name__} does not support promote_active_adapter_checkpoint") + fn(str(adapter_name), int(checkpoint_version)) + + def selective_sync_active_cache( + self, + *, + sync_id: str, + tgt_dp_ranks, + tgt_workers, + tgt_device_mapping, + tgt_num_gpus_per_worker: int, + comm_plan=None, + adapters_to_sync: list[str] | None = None, + ) -> None: + """ + Perform selective parameter synchronization from sender to receiver workers. + + This is the core RLix operation that transfers parameter buckets from actor_train + (sender) to actor_infer (receiver) workers. 
Uses pre-built bucket caches and + optional communication plans to minimize transfer overhead. + + Args: + sync_id: Unique identifier for this sync operation (for logging/tracing). + tgt_dp_ranks: Target data parallel ranks to sync to. + tgt_workers: Target worker actor handles. + tgt_device_mapping: Device mapping for target workers. + tgt_num_gpus_per_worker: Number of GPUs per target worker. + comm_plan: Optional pre-computed communication plan for optimized transfers. + adapters_to_sync: Optional list of LoRA adapters to sync (multi-LoRA mode). + + Raises: + RuntimeError: If strategy does not implement selective_sync_active_cache. + """ + if getattr(self, "strategy", None) is None: + raise RuntimeError("worker has no strategy") + fn = getattr(self.strategy, "selective_sync_active_cache", None) + if not callable(fn): + raise RuntimeError(f"{type(self.strategy).__name__} does not support selective_sync_active_cache") + self.logger.info( + "[rlix][selective_sync] worker_call_enter " + f"sync_id={sync_id} tgt_dp_ranks={list(tgt_dp_ranks)} " + f"tgt_num_gpus_per_worker={tgt_num_gpus_per_worker}" + ) + result = fn( + tgt_dp_ranks=tgt_dp_ranks, + tgt_workers=tgt_workers, + tgt_device_mapping=tgt_device_mapping, + tgt_num_gpus_per_worker=int(tgt_num_gpus_per_worker), + comm_plan=comm_plan, + adapters_to_sync=adapters_to_sync, + ) + self.logger.info(f"[rlix][selective_sync] worker_call_exit sync_id={sync_id}") + return result + def add_lora(self, *args, **kwargs): if getattr(self, "strategy", None) is not None: self.strategy.add_lora(*args, **kwargs) else: self.logger.warning("worker has not strategy") + def verify_model(self, *args, **kwargs): + """Delegate post-sync weight verification to the strategy layer.""" + if getattr(self, "strategy", None) is not None: + self.strategy.verify_model(*args, **kwargs) + else: + self.logger.warning("worker has not strategy") + @register(dispatch_mode=Dispatch.DP_MP_COMPUTE) def get_metrics(self, metric_names: Optional[List[str]] = 
None) -> DataProto: """ diff --git a/roll/distributed/scheduler/async_generate_scheduler.py b/roll/distributed/scheduler/async_generate_scheduler.py index 9a854f857..3ff2cc6f6 100644 --- a/roll/distributed/scheduler/async_generate_scheduler.py +++ b/roll/distributed/scheduler/async_generate_scheduler.py @@ -2,6 +2,7 @@ import enum import itertools import math +import os import queue import random import threading @@ -22,7 +23,7 @@ from roll.distributed.scheduler.generate_scheduler import GlobalCounter from roll.distributed.scheduler.protocol import DataProto from roll.models.model_providers import default_tokenizer_provider -from roll.utils.constants import RAY_NAMESPACE +from roll.utils.constants import RAY_NAMESPACE, rlix_env_vars from roll.utils.functionals import ( GenerateRequestType, concatenate_input_and_output, @@ -402,10 +403,15 @@ def set_scheduler( logger.info("use_additional_prompts is False, disable query and response filtering.") self.cluster_max_running_requests = self.pipeline_config.max_running_requests * self.actor_cluster.dp_size + pipeline_id = os.environ.get("PIPELINE_ID") or None + if pipeline_id is None and os.environ.get("ROLL_RAY_NAMESPACE"): + raise RuntimeError("PIPELINE_ID must be set when ROLL_RAY_NAMESPACE is set (multi-pipeline mode)") + counter_name = f"{pipeline_id}_DynamicSchedulerRequestCounter" if pipeline_id else "DynamicSchedulerRequestCounter" self.request_counter = GlobalCounter.options( - name="DynamicSchedulerRequestCounter", + name=counter_name, get_if_exists=True, namespace=RAY_NAMESPACE, + runtime_env={"env_vars": rlix_env_vars()}, ).remote() def reset_status(self): diff --git a/roll/distributed/scheduler/generate_scheduler.py b/roll/distributed/scheduler/generate_scheduler.py index 9c17b7984..4c32567c7 100644 --- a/roll/distributed/scheduler/generate_scheduler.py +++ b/roll/distributed/scheduler/generate_scheduler.py @@ -5,6 +5,7 @@ import math import uuid import time +import os from collections import defaultdict, 
deque from dataclasses import dataclass, fields from itertools import cycle @@ -34,6 +35,7 @@ from roll.utils.metrics.metrics_manager import DurationTracker from roll.utils.import_utils import safe_import_class from roll.utils.logging import get_logger +from roll.utils.constants import DO_TIME_SHARING logger = get_logger() @@ -603,6 +605,24 @@ def next_request_id(self): return request_id +@ray.remote +class GlobalCounter: + """Monotonically increasing counter as a Ray actor. + + Used to assign unique global IDs across distributed workers without coordination cost. + get_value() returns the current counter and increments it atomically (single-actor + execution guarantees no races). + """ + + def __init__(self): + self._value = 0 + + def get_value(self) -> int: + value = int(self._value) + self._value = value + 1 + return value + + @ray.remote class GenerateScheduler(Scheduler): def __init__(self, pipeline_config=None): @@ -1303,19 +1323,43 @@ def __init__(self, infer_cluster, pipeline_config, resource_manager): # Active DP ranks for request routing self.active_dp_ranks: Set[int] = set(range(self.infer_cluster.world_size)) # All ranks initially active self.routing_lock = asyncio.Lock() # Protect routing updates + self._op_lock = asyncio.Lock() # Serializes scheduling ops (shrink/expand) with adapter sync - async def generate_one_request(self, data: DataProto): - await self._check_suspend() + def get_active_dp_ranks(self) -> Set[int]: + """Return a copy of the current active DP ranks set. - src_rank = data.meta_info["src_rank"] - # Atomic routing assignment under lock to prevent TOCTOU race with shrink/expand - async with self.routing_lock: - # Least-loaded dispatch - if src_rank not in self.src_rank2_dp_rank: - dp_rank = self._get_least_active_dp_rank() - self.src_rank2_dp_rank[src_rank] = dp_rank + Used for state verification after initialization shrink operations. 
+ """ + return set(self.active_dp_ranks) + + async def generate_one_request(self, data: DataProto): + src_rank = data.meta_info.get("src_rank") + global_step = data.meta_info.get("global_step") + t0 = time.time() + # NOTE: do not block while holding routing_lock. Re-check suspend after acquiring lock + # to avoid TOCTOU with shrink-to-zero and concurrent shrink/expand. + while True: + await self._check_suspend() + if src_rank is None: + src_rank = data.meta_info["src_rank"] + # Atomic routing assignment under lock to prevent TOCTOU race with shrink/expand + async with self.routing_lock: + if self.need_suspend: + continue + if not self.active_dp_ranks: + raise RuntimeError("No active DP ranks and not suspended") + + dp_rank = self.src_rank2_dp_rank.get(src_rank) + if dp_rank is not None and dp_rank not in self.active_dp_ranks: + self.src_rank2_dp_rank.pop(src_rank, None) + dp_rank = None + + # Least-loaded dispatch + if dp_rank is None: + dp_rank = self._get_least_active_dp_rank() + self.src_rank2_dp_rank[src_rank] = dp_rank + break - dp_rank = self.src_rank2_dp_rank[src_rank] request_id = f"{self.request_id}_{self.request_counter}" self.request_counter += 1 data.meta_info["request_id"] = request_id @@ -1330,6 +1374,7 @@ async def generate_one_request(self, data: DataProto): self.running_requests[dp_rank].remove(request_id) self.empty_notifier.set() # Cleanup tracking (on both success and abort paths) + self.request_id_2_dp_rank.pop(request_id, None) self.request_id_2_src_rank.pop(request_id, None) assert response_data is not None @@ -1502,13 +1547,23 @@ async def _rebalance_on_shrink(self, shrink_dp_ranks: List[int]) -> Dict[str, in Raises: RuntimeError: If shrink operation fails + + Side Effects: + - Sets need_suspend=True and clears suspend_notifier if shrinking to zero + active ranks (blocks future generate_one_request() until expansion). + - On exception: rolls back active_dp_ranks and need_suspend, re-sets + suspend_notifier to unblock waiters. 
+ - See FIXME (G02-RULE-26.2) in this method for known locking constraints on + abort RPCs under routing_lock. """ keep_ranks = list(self.active_dp_ranks - set(shrink_dp_ranks)) - if not keep_ranks: - raise ValueError("Cannot shrink to zero active ranks") - old_active_ranks = self.active_dp_ranks.copy() + old_need_suspend = self.need_suspend self.active_dp_ranks = set(keep_ranks) + if not self.active_dp_ranks: + # Shrink-to-zero: block future generate_one_request() calls until expansion. + self.suspend_notifier.clear() + self.need_suspend = True try: total_aborted = 0 @@ -1527,6 +1582,10 @@ async def _rebalance_on_shrink(self, shrink_dp_ranks: List[int]) -> Dict[str, in + # FIXME(G02-RULE-26.2): abort RPCs and this drain loop run while routing_lock is held, + # blocking generate_one_request for the full drain duration. Fix: split into + # _shrink_routing_state (sync, under lock) + drain outside lock + brief re-lock to clear mappings. + # Acceptable for now because aborts are infrequent and expected to complete quickly. 
await asyncio.gather(*abort_futures) while True: @@ -1534,7 +1593,7 @@ async def _rebalance_on_shrink(self, shrink_dp_ranks: List[int]) -> Dict[str, in if remain == 0: break logger.info(f"Shrink: waiting for {len(shrink_dp_ranks)} workers {remain=} to finish abort") - await asyncio.sleep(3) + await asyncio.sleep(0.5) # Clear ALL mappings pointing to shrinking workers (not just in-flight) shrink_dp_ranks_set = set(shrink_dp_ranks) @@ -1553,6 +1612,9 @@ async def _rebalance_on_shrink(self, shrink_dp_ranks: List[int]) -> Dict[str, in except Exception as e: self.active_dp_ranks = old_active_ranks + self.need_suspend = old_need_suspend + if not self.need_suspend: + self.suspend_notifier.set() raise RuntimeError(f"Shrink failed: {e}") from e async def rebalance_on_expand(self, expand_dp_ranks: List[int]) -> Dict[str, int]: @@ -1597,18 +1659,16 @@ async def _rebalance_on_expand(self, expand_dp_ranks: List[int]) -> Dict[str, in Algorithm: Round-robin selection across old workers 1. Calculate proportional src_ranks to abort: src_ranks_to_keep = ceil(total * old_count / new_count) 2. Group existing src_ranks by dp_rank (only old workers) - 3. Round-robin iterate over old workers using cycle() + 3. Round-robin iterate over old workers using while loop with empty-streak guard 4. Select one src_rank at a time until remaining_to_abort reaches 0 5. Abort ALL requests from selected src_ranks 6. 
Clear src_rank mappings for reallocation to new workers Implementation Notes: - - Uses cycle() for infinite round-robin iteration over old workers - - Check at line 1146 (if not dp_rank in old_active_dp_ranks) is redundant - since dp_rank_to_src_ranks already contains only old workers, but kept as defensive guard - - Loop terminates when remaining_to_abort <= 0 or all worker lists are exhausted - - If all workers exhausted before reaching target, loop may cycle indefinitely - (no explicit check for empty state, but pop(0) will eventually empty all lists) + - Round-robin uses a while loop with empty_streak detection (not cycle()) to + terminate cleanly when all worker lists are exhausted before the abort target + - Calls self.resume() automatically when expanding from zero active ranks + (was_empty check), unblocking suspended generate_one_request() callers Args: expand_dp_ranks: DP ranks to add to active set (already validated) @@ -1628,9 +1688,12 @@ async def _rebalance_on_expand(self, expand_dp_ranks: List[int]) -> Dict[str, in # Calculate counts before updating active_dp_ranks old_dp_count = len(self.active_dp_ranks) old_active_dp_ranks = self.active_dp_ranks.copy() + was_empty = old_dp_count == 0 self.active_dp_ranks.update(expand_dp_ranks) new_dp_count = len(self.active_dp_ranks) + if was_empty and new_dp_count > 0: + self.resume() total_src_ranks = len(self.src_rank2_dp_rank) if total_src_ranks == 0: @@ -1652,20 +1715,26 @@ async def _rebalance_on_expand(self, expand_dp_ranks: List[int]) -> Dict[str, in # Round-robin selection: iterate over old workers and select one src_rank at a time # todo optimization:(yangpeng) take uneven dp load into consideration and do dynamic load balancing, not just RR + available_to_abort = sum(len(v) for v in dp_rank_to_src_ranks.values()) + if available_to_abort <= 0: + logger.info("Expand: no rebalancing possible (no src_ranks on old workers)") + return {"aborted": 0, "remapped": 0} + remaining_to_abort = 
min(src_ranks_to_abort, available_to_abort) selected_src_ranks = [] - remaining_to_abort = src_ranks_to_abort - for dp_rank in cycle(dp_rank_to_src_ranks.keys()): - if not dp_rank in old_active_dp_ranks: - continue - - if remaining_to_abort <= 0: - break - + dp_ranks_rr = list(dp_rank_to_src_ranks.keys()) + empty_streak = 0 + idx = 0 + while remaining_to_abort > 0: + dp_rank = dp_ranks_rr[idx % len(dp_ranks_rr)] + idx += 1 src_ranks_on_worker = dp_rank_to_src_ranks.get(dp_rank, []) if not src_ranks_on_worker: + empty_streak += 1 + if empty_streak >= len(dp_ranks_rr): + break continue + empty_streak = 0 selected_src_ranks.append(src_ranks_on_worker.pop(0)) - remaining_to_abort -= 1 # Remove from mapping and group by dp_rank for abort @@ -1771,25 +1840,68 @@ def _validate_calculated_ranks(self, ranks: List[int], mode: str) -> None: raise ValueError(f"[{mode}] DP rank {dp_rank} out of range [0, {self.infer_cluster.world_size})") # AST: State consistency + if mode not in ("shrink", "expand"): + raise ValueError(f"Invalid mode: {mode}") - for dp_rank in ranks: - if dp_rank not in self.active_dp_ranks: - raise ValueError(f"DP rank {dp_rank} not active {mode=}") + if mode == "shrink": + for dp_rank in ranks: + if dp_rank not in self.active_dp_ranks: + raise ValueError(f"[shrink] DP rank {dp_rank} not active") + else: + for dp_rank in ranks: + if dp_rank in self.active_dp_ranks: + raise ValueError(f"[expand] DP rank {dp_rank} already active") - async def shrink_workers(self, target_gpus: List[int]) -> Dict[str, Any]: + def _validate_dp_ranks_input(self, dp_ranks: List[int], *, mode: str) -> List[int]: + """Validate and normalize a dp_ranks list input. + + Checks: non-empty list[int], each value in [0, world_size), no duplicates. + Returns a normalized list of plain ints (coerces numpy ints etc.). + + Args: + dp_ranks: Candidate DP ranks to validate. + mode: Label used in error messages ("shrink" or "expand"). + + Returns: + Normalized list[int] with duplicates rejected. 
+ + Raises: + ValueError: If list is empty, values out of range, or contains duplicates. + TypeError: If any element is not an int. + """ + if not isinstance(dp_ranks, list) or not dp_ranks: + raise ValueError(f"{mode}: dp_ranks must be a non-empty list[int]") + out: List[int] = [] + for x in dp_ranks: + if not isinstance(x, int): + raise TypeError(f"{mode}: dp_ranks must be list[int], got element {type(x).__name__}") + if not (0 <= x < self.infer_cluster.world_size): + raise ValueError(f"{mode}: dp_rank {x} out of range [0, {self.infer_cluster.world_size})") + out.append(int(x)) + if len(out) != len(set(out)): + raise ValueError(f"{mode}: dp_ranks contains duplicates") + return out + + async def shrink_workers(self, dp_ranks: List[int], skip_offload: bool = False) -> Dict[str, Any]: """Complete atomic shrink operation: validate → rebalance → offload → update routing. Orchestrates the full worker shrink process: - 1. Validates target_gpus input - 2. Calculates DP ranks to offload based on GPU overlap - 3. Validates calculated ranks against active state + 1. Validates dp_ranks input (type, range, duplicates) + 2. If skip_offload=True: filters to only currently-active ranks (idempotent no-op + if all ranks already inactive) + 3. If skip_offload=False: validates ranks are active (strict check) 4. Atomically (under routing_lock): - - Rebalances routing (aborts requests on shrinking workers) - - Offloads model states from shrinking workers - 5. Returns metrics for monitoring + - Rebalances routing: aborts in-flight requests on shrinking workers and drains + their queues (abort RPCs and drain also run under routing_lock — see FIXME + comment in _rebalance_on_shrink for G02-RULE-26.2) + 5. If skip_offload=False: offloads model states from shrinking workers to CPU + 6. Returns metrics for monitoring Args: - target_gpus: GPU IDs to free (e.g., [4, 5, 6, 7] to free second half of 8 GPUs) + dp_ranks: DP ranks to deactivate/offload. 
+ skip_offload: If True, skip physical model offload and treat already-inactive + ranks as a no-op. Use when another coupled scheduler will handle the offload, + or during init-time shrink where ranks are not yet loaded. Returns: Metrics dict containing: @@ -1799,61 +1911,83 @@ async def shrink_workers(self, target_gpus: List[int]) -> Dict[str, Any]: - "offload_ranks": List of DP ranks that were offloaded Raises: - ValueError: If target_gpus invalid (empty, duplicates) or - calculated ranks invalid (not active, out of range) + ValueError: If dp_ranks invalid (empty, duplicates, out of range) or + ranks not active (when skip_offload=False) RuntimeError: If rebalance or offload operations fail Example: - # Shrink to free GPUs [4, 5, 6, 7] (second half of 8-GPU setup) - result = await scheduler.shrink_workers([4, 5, 6, 7]) + # Full shrink with offload + result = await scheduler.shrink_workers([2, 3]) # Returns: {"aborted": 10, "remapped": 5, "shrink_duration_ms": 2340.5, "offload_ranks": [2, 3]} + # Routing-only shrink (another scheduler handles offload) + result = await scheduler.shrink_workers([2, 3], skip_offload=True) + Side Effects: - Updates active_dp_ranks (removes offload_ranks) - Aborts in-flight requests on shrinking workers - Clears src_rank mappings for remapped environments - - Offloads model states from shrinking workers to CPU + - Offloads model states from shrinking workers to CPU (unless skip_offload=True) + - Serialized under _op_lock (prevents concurrent shrink/expand) + - If skip_offload=True and ranks already inactive: returns zero-metrics immediately """ - start_time = time.time() - - # VAL: VAL_NON_EMPTY, VAL_NO_DUPLICATES - self._validate_target_gpus(target_gpus, mode="shrink") - # Calculate DP ranks to offload - target_gpus = set(target_gpus) - offload_ranks = [dp for dp in range(self.infer_cluster.world_size) - if set(self._get_gpus_for_dp_rank(dp)).intersection(target_gpus)] - - # VAL: VAL_NON_EMPTY, state consistency check - 
self._validate_calculated_ranks(offload_ranks, mode="shrink") - - # Atomic operation under routing_lock - async with self.routing_lock: - # Rebalance (abort + update active_dp_ranks) - result = await self.rebalance_on_shrink(offload_ranks) - # release the lock before blocking offload so that active dp rank can work immediately - # Offload states from target workers - offload_refs = self.infer_cluster.offload_states_partial(offload_ranks, blocking=False) - await asyncio.gather(*[asyncio.wrap_future(ref.future()) for ref in offload_refs]) - - return {**result, "shrink_duration_ms": (time.time() - start_time) * 1000, - "offload_ranks": offload_ranks} - - async def expand_workers(self, target_gpus: List[int], skip_load: bool = False) -> Dict[str, Any]: + async with self._op_lock: + start_time = time.time() + offload_ranks = self._validate_dp_ranks_input(dp_ranks, mode="shrink") + + # In coupled scheduler flows (e.g., init paths), shrink can be called multiple times + # with skip_offload=True. Treat already-inactive ranks as no-op for idempotence. + if bool(skip_offload): + active_set = set(self.active_dp_ranks) + offload_ranks = [r for r in offload_ranks if r in active_set] + if not offload_ranks: + return { + "aborted": 0, + "remapped": 0, + "shrink_duration_ms": (time.time() - start_time) * 1000, + "offload_ranks": [], + } + else: + # VAL: VAL_NON_EMPTY, state consistency check (strict for normal shrink calls) + self._validate_calculated_ranks(offload_ranks, mode="shrink") + + # Atomic routing update under routing_lock + async with self.routing_lock: + # Rebalance (abort + update active_dp_ranks) + result = await self.rebalance_on_shrink(offload_ranks) + + if not bool(skip_offload): + # Offload states from target workers + # Use explicit keyword args so Ray signature binding stays stable across wrapped actor methods. 
+ offload_refs = self.infer_cluster.offload_states_partial( + target_dp_ranks=offload_ranks, blocking=False + ) + await asyncio.gather(*[asyncio.wrap_future(ref.future()) for ref in offload_refs]) + + return { + **result, + "shrink_duration_ms": (time.time() - start_time) * 1000, + "offload_ranks": offload_ranks, + } + + async def expand_workers(self, dp_ranks: List[int], skip_load: bool = False) -> Dict[str, Any]: """Complete atomic expand operation: validate → load → rebalance → update routing. Orchestrates the full worker expand process: - 1. Validates target_gpus input - 2. Calculates DP ranks to restore based on GPU overlap - 3. Validates calculated ranks against active state (skip if skip_load=True) + 1. Validates dp_ranks input (type, range, duplicates) + 2. If skip_load=True: filters to only currently-inactive ranks (no-op if all + already active). Skips model loading; only updates routing state. + 3. If skip_load=False: validates ranks are inactive (strict check) 4. Atomically (under routing_lock): - Loads model states on expanding workers (skip if skip_load=True) - Rebalances routing (proportionally redistributes requests) 5. Returns metrics for monitoring Args: - target_gpus: GPU IDs to restore (e.g., [4, 5, 6, 7] to restore second half of 8 GPUs) - skip_load: If True, skip model loading and validation (use when model_update already loaded states). - This only updates active_dp_ranks to restore routing state without re-loading models. + dp_ranks: DP ranks to restore to active set. + skip_load: If True, skip model loading (use when model_update already synced + weights). In DO_TIME_SHARING mode (when skip_load=False), triggers selective + model weight sync via ModelUpdateService before loading vLLM states. 
Returns: Metrics dict containing: @@ -1863,47 +1997,92 @@ async def expand_workers(self, target_gpus: List[int], skip_load: bool = False) - "load_ranks": List of DP ranks that were restored Raises: - ValueError: If target_gpus invalid (empty, duplicates) or - calculated ranks invalid (already active, out of range) + ValueError: If dp_ranks invalid (empty, duplicates, out of range) or + ranks already active (when skip_load=False) RuntimeError: If load or rebalance operations fail Example: - # Expand to restore GPUs [4, 5, 6, 7] (second half of 8-GPU setup) - result = await scheduler.expand_workers([4, 5, 6, 7]) + # Full expand with load + result = await scheduler.expand_workers([2, 3]) # Returns: {"aborted": 3, "remapped": 3, "expand_duration_ms": 1850.2, "load_ranks": [2, 3]} - # After model_update already loaded states to all GPUs, just restore routing: - result = await scheduler.expand_workers([4, 5, 6, 7], skip_load=True) + # After model_update already loaded states, just restore routing: + result = await scheduler.expand_workers([2, 3], skip_load=True) Side Effects: - Updates active_dp_ranks (adds load_ranks) - Loads model states from CPU to expanding workers (unless skip_load=True) - Aborts some requests from old workers for proportional rebalancing - Clears src_rank mappings for rebalanced environments (will route to new workers) + - Serialized under _op_lock (prevents concurrent shrink/expand) + - If skip_load=True and ranks already active: returns zero-metrics immediately + - In DO_TIME_SHARING mode: syncs selected worker weights via ModelUpdateService + before loading vLLM states (avoids holding KV cache during weight sync) """ - start_time = time.time() - - # VAL: VAL_NON_EMPTY, VAL_NO_DUPLICATES - self._validate_target_gpus(target_gpus, mode="expand") + async with self._op_lock: + start_time = time.time() + load_ranks = self._validate_dp_ranks_input(dp_ranks, mode="expand") + + # Mirror shrink_workers(skip_offload=True): filter to only inactive ranks so 
already-active + # ranks are a no-op. Rebalancing still runs to update active_dp_ranks and trigger resume(). + if skip_load: + inactive_ranks = set(range(self.infer_cluster.world_size)) - self.active_dp_ranks + load_ranks = [r for r in load_ranks if r in inactive_ranks] + if not load_ranks: + return { + "aborted": 0, + "remapped": 0, + "expand_duration_ms": (time.time() - start_time) * 1000, + "load_ranks": [], + } - # Calculate DP ranks to restore - target_gpus = set(target_gpus) - load_ranks = [dp for dp in range(self.infer_cluster.world_size) - if set(self._get_gpus_for_dp_rank(dp)).issubset(target_gpus)] - - # VAL: VAL_NON_EMPTY, state consistency check - # Skip validation when skip_load=True because ranks may already be "active" in cluster - # (model states loaded by model_update) but not tracked in active_dp_ranks yet - if not skip_load: - self._validate_calculated_ranks(load_ranks, mode="expand") - load_refs = self.infer_cluster.load_states_partial(load_ranks, blocking=False) - await asyncio.gather(*[asyncio.wrap_future(ref.future()) for ref in load_refs]) + else: + self._validate_calculated_ranks(load_ranks, mode="expand") + # In RLix mode, delay vLLM KV cache init until after selective model update completes. + # This avoids holding large KV allocations during weight sync (which needs extra headroom). 
+ if DO_TIME_SHARING and load_ranks: + pipeline_id = os.environ.get("PIPELINE_ID") or None + ray_namespace = os.environ.get("ROLL_RAY_NAMESPACE") or None + if not pipeline_id: + raise RuntimeError("DO_TIME_SHARING mode requires PIPELINE_ID to be set") + if not ray_namespace: + raise RuntimeError("DO_TIME_SHARING mode requires ROLL_RAY_NAMESPACE to be set") + try: + model_update_service = ray.get_actor( + f"{pipeline_id}_model_update_service", + namespace=ray_namespace, + ) + except Exception as e: + raise RuntimeError( + f"Failed to resolve ModelUpdateService for pipeline_id={pipeline_id!r} " + f"(expected name={pipeline_id}_model_update_service in namespace={ray_namespace!r})" + ) from e + ref = model_update_service.sync_selected_workers.remote(load_ranks) + await asyncio.wrap_future(ref.future()) + # vLLM may require post-load processing after weights are updated (e.g., FP8 hooks). + process_refs = self.infer_cluster.process_weights_after_loading(blocking=False) + await asyncio.gather(*[asyncio.wrap_future(ref.future()) for ref in process_refs]) + # Now that weights are synced, initialize full infer states (incl. KV cache) for rollout. + # Use explicit keyword args so Ray signature binding stays stable across wrapped actor methods. + load_refs = self.infer_cluster.load_states_partial( + target_dp_ranks=load_ranks, blocking=False + ) + await asyncio.gather(*[asyncio.wrap_future(ref.future()) for ref in load_refs]) + else: + # Use explicit keyword args so Ray signature binding stays stable across wrapped actor methods. 
+ load_refs = self.infer_cluster.load_states_partial( + target_dp_ranks=load_ranks, blocking=False + ) + await asyncio.gather(*[asyncio.wrap_future(ref.future()) for ref in load_refs]) - # Atomic operation under routing_lock - async with self.routing_lock: + # Atomic operation under routing_lock + async with self.routing_lock: + # Rebalance (update active_dp_ranks + conditional abort) + result = await self.rebalance_on_expand(load_ranks) - # Rebalance (update active_dp_ranks + conditional abort) - result = await self.rebalance_on_expand(load_ranks) + return { + **result, + "expand_duration_ms": (time.time() - start_time) * 1000, + "load_ranks": load_ranks, + } - return {**result, "expand_duration_ms": (time.time() - start_time) * 1000, - "load_ranks": load_ranks} diff --git a/roll/distributed/scheduler/initialize.py b/roll/distributed/scheduler/initialize.py index 877e4ef18..7384f5e5a 100644 --- a/roll/distributed/scheduler/initialize.py +++ b/roll/distributed/scheduler/initialize.py @@ -17,7 +17,7 @@ wait_for_nodes, ) from roll.distributed.scheduler.log_monitor import LogMonitorListener -from roll.utils.constants import RAY_NAMESPACE +from roll.utils.constants import DO_TIME_SHARING, RAY_NAMESPACE from roll.utils.logging import get_logger from roll.platforms import current_platform @@ -46,6 +46,14 @@ def start_ray_cluster(): logger.info(f"Starting ray cluster: {cmd}") ret = subprocess.run(cmd, shell=True, capture_output=True) if ret.returncode != 0: + # Fallback for Ray CLI bug: some Ray builds crash with "is not a valid Sentinel" error + # in Click's deepcopy. When this occurs on rank 0 (head node), return False to signal + # the caller to use ray.init(address=None) for in-process local cluster startup instead. + # This only works for single-node runs; multi-node distributed runs will still fail. 
+ stderr_text = ret.stderr.decode("utf-8", errors="ignore") + if rank == 0 and "is not a valid Sentinel" in stderr_text: + logger.warning("Ray CLI failed with Sentinel bug; falling back to in-process ray.init startup") + return False logger.error(f"Failed to start ray cluster: {cmd}") logger.error(f"ret.stdout: {ret.stdout}") logger.error(f"ret.stderr: {ret.stderr}") @@ -54,6 +62,24 @@ def start_ray_cluster(): def init(): + if DO_TIME_SHARING: + # Time-sharing mode: RLix scheduler manages the Ray cluster lifecycle. We only connect to + # the existing cluster via address="auto" and skip node startup, log monitoring, and + # atexit shutdown handlers. This allows multiple ROLL pipelines to share a single cluster. + runtime_env = { + "env_vars": current_platform.get_custom_env_vars(), + } + if not ray.is_initialized(): + ray.init( + address="auto", + namespace=RAY_NAMESPACE, + ignore_reinit_error=True, + log_to_driver=True, + runtime_env=runtime_env, + ) + logger.info("ROLL init: time-sharing mode enabled; leaving Ray cluster lifecycle to RLix scheduler") + return + rank = get_driver_rank() world_size = get_driver_world_size() master_addr = get_driver_master_addr() diff --git a/roll/distributed/scheduler/log_monitor.py b/roll/distributed/scheduler/log_monitor.py index ff64646b5..8500f0e5d 100644 --- a/roll/distributed/scheduler/log_monitor.py +++ b/roll/distributed/scheduler/log_monitor.py @@ -26,7 +26,7 @@ from ray._private.worker import print_to_stdstream, logger as monitor_logger, print_worker_logs from roll.distributed.scheduler.driver_utils import get_driver_rank, wait_for_nodes, get_driver_world_size -from roll.utils.constants import RAY_NAMESPACE +from roll.utils.constants import DO_TIME_SHARING, RAY_NAMESPACE from roll.utils.logging import get_logger logger = get_logger() @@ -218,6 +218,18 @@ def wait_for_grace_stop(self): time.sleep(0.1) def stop(self): + if DO_TIME_SHARING: + # Time-sharing mode: cluster is shared across multiple pipelines managed by RLix 
scheduler. + # Only clean up local resources (file handles, thread) - skip cluster teardown to avoid + # disrupting other co-tenant pipelines. + logger.info("LogMonitorListener.stop: time-sharing mode - cleaning up local resources only") + StdPublisher.close_file_handlers() + time.sleep(0.2) + try: + self.log_monitor_thread.join(2) + except Exception: + pass + return StdPublisher.close_file_handlers() time.sleep(5) self.log_monitor_thread.join(2) @@ -235,6 +247,12 @@ def stop(self): subprocess.run(cmd, shell=True, capture_output=True) def start(self): + if DO_TIME_SHARING: + # Time-sharing mode: skip log monitoring setup since the Ray cluster is shared and + # managed by RLix scheduler. ExceptionMonitor actors would collide across pipelines, + # and atexit shutdown handlers would incorrectly tear down the shared cluster. + logger.info("LogMonitorListener.start: time-sharing mode - skipping log monitoring setup") + return atexit.register(self.stop) if self.rank == 0: diff --git a/roll/distributed/scheduler/resource_manager.py b/roll/distributed/scheduler/resource_manager.py index ac9810f41..91ff41580 100644 --- a/roll/distributed/scheduler/resource_manager.py +++ b/roll/distributed/scheduler/resource_manager.py @@ -7,6 +7,7 @@ from roll.platforms import current_platform from roll.utils.ray_utils import get_visible_gpus, get_node_rank +from roll.utils.rlix_compat import ROLL_RESOURCE_MANAGER_ACTOR_NAME, RLIX_NAMESPACE class ResourceManager: @@ -15,14 +16,27 @@ def __init__(self, num_gpus_per_node, num_nodes): The ResourceManager centrally manages the required GPU/CPU resources, facilitating Ray to deploy Actors on specified GPU devices. """ - available_resources = ray.available_resources() - available_gpu = available_resources.get(current_platform.ray_device_key, 0) + # NOTE: Some environments can expose Ray "GPU" resources even when `torch.cuda.is_available()` + # is False inside the driver/worker process (e.g., CUDA init failure or restricted device access). 
+ # For GPU placement groups we must use Ray's built-in "GPU" resource key to ensure bundles are schedulable. + ray_device_key = current_platform.ray_device_key + device_control_env_var = getattr(current_platform, "device_control_env_var", None) + if int(num_gpus_per_node or 0) > 0: + ray_device_key = "GPU" + if not device_control_env_var: + device_control_env_var = "CUDA_VISIBLE_DEVICES" + + # Use cluster_resources (total capacity) rather than available_resources + # (currently free) because fractional GPU allocations like num_gpus=0.01 + # on coordinator actors temporarily reduce available_gpu below the integer total. + cluster_resources = ray.cluster_resources() + available_gpu = cluster_resources.get(ray_device_key, 0) nodes_maybe_used = [] ray_nodes = ray.nodes() for node in ray_nodes: resource = node["Resources"] - node_gpu_num = int(resource.get(current_platform.ray_device_key, 0)) + node_gpu_num = int(resource.get(ray_device_key, 0)) if node_gpu_num >= num_gpus_per_node: nodes_maybe_used.append(node) nodes_maybe_used = sorted(nodes_maybe_used, key=lambda n: n["Resources"]["CPU"]) @@ -43,19 +57,18 @@ def __init__(self, num_gpus_per_node, num_nodes): for i in range(self.num_nodes): node = nodes_maybe_used[i] node_cpu = int(node["Resources"]["CPU"]) - bundles.append({current_platform.ray_device_key: self.gpu_per_node, "CPU": max(node_cpu / 2, 1)}) + bundles.append({ray_device_key: self.gpu_per_node, "CPU": max(node_cpu / 2, 1)}) - self.placement_groups = [ray.util.placement_group([bundle]) for bundle in bundles] + self.placement_groups = [ + ray.util.placement_group([bundle]) + for i, bundle in enumerate(bundles) + ] ray.get([pg.ready() for pg in self.placement_groups]) gpu_ranks = ray.get([ get_visible_gpus.options( placement_group=pg, - **( - {"num_gpus": self.gpu_per_node} - if current_platform.ray_device_key == "GPU" - else {"resources": {current_platform.ray_device_key: self.gpu_per_node}} - ) - ).remote(current_platform.device_control_env_var) + 
**({"num_gpus": self.gpu_per_node} if self.gpu_per_node > 0 else {}) + ).remote(device_control_env_var) for pg in self.placement_groups ]) print(f"gpu ranks: {gpu_ranks}") @@ -75,13 +88,28 @@ def __init__(self, num_gpus_per_node, num_nodes): node = nodes_maybe_used[0] node_cpu = int(node["Resources"]["CPU"]) bundles = [{"CPU": node_cpu}] * self.num_nodes - self.placement_groups = [ray.util.placement_group([bundle]) for bundle in bundles] + self.placement_groups = [ + ray.util.placement_group([bundle]) + for i, bundle in enumerate(bundles) + ] ray.get([pg.ready() for pg in self.placement_groups]) self.node_ranks = [0] self.node2pg: Dict[int, PlacementGroup] = {} for node_rank, placement_group in zip(self.node_ranks, self.placement_groups): self.node2pg[node_rank] = placement_group + def get_state(self) -> dict: + """Return serializable state for proxy construction.""" + return { + "num_nodes": self.num_nodes, + "gpu_per_node": self.gpu_per_node, + "num_gpus": self.num_gpus, + "node_ranks": list(self.node_ranks), + "gpu_ranks": list(getattr(self, "gpu_ranks", [])), + "node2pg": dict(self.node2pg), + "placement_groups": list(self.placement_groups), + } + def nodes_placement_group(self, node_rank) -> PlacementGroup: """ mesh table是 m×n,获取第node_rank nodel上gpu_rank的PlacementGroup,用于把ray.Actor部署到指定的GPU上 @@ -148,3 +176,79 @@ def allocate_placement_group(self, world_size, device_mapping: List[int] = None) assert len(allocated_pg) == world_size return allocated_pg + + +# --------------------------------------------------------------------------- +# Singleton actor + proxy for RLix control-plane mode +# --------------------------------------------------------------------------- + +# Use imported constants from rlix.protocol.types for consistency +_ROLL_RM_ACTOR_NAME = ROLL_RESOURCE_MANAGER_ACTOR_NAME +_ROLL_RM_NAMESPACE = RLIX_NAMESPACE + + +@ray.remote(num_cpus=0, max_restarts=0, max_task_retries=0) +class _RollResourceManagerActor(ResourceManager): + """Cluster-wide 
singleton Ray actor wrapping ResourceManager for RLix control-plane mode.""" + pass + + +class RollResourceManagerProxy(ResourceManager): + """Synchronous drop-in for ResourceManager backed by a shared Ray actor. + + Used in RLix control-plane mode so all concurrent pipelines share a single + ResourceManager actor (and its placement groups) rather than each pipeline + creating its own, which would exhaust cluster GPU resources. + + State (placement groups, node/gpu topology) is fetched from the actor once + in __init__ and cached locally. This makes allocate_placement_group() safe + to call from within async Ray actors — no blocking ray.get() RPCs at call time. + + destroy_placement_group() is a no-op: the singleton actor owns the PGs and + they are cleaned up when the orchestrator tears down the actor. + """ + + def __init__(self, num_gpus_per_node: int) -> None: + # Get or lazily create the cluster-wide singleton ResourceManager actor. + # All concurrent pipelines share one actor so placement groups are allocated once. + try: + actor = ray.get_actor(_ROLL_RM_ACTOR_NAME, namespace=_ROLL_RM_NAMESPACE) + except ValueError: + try: + actor = ( + _RollResourceManagerActor.options( + name=_ROLL_RM_ACTOR_NAME, + namespace=_ROLL_RM_NAMESPACE, + get_if_exists=True, + max_restarts=0, + max_task_retries=0, + ).remote(num_gpus_per_node=num_gpus_per_node, num_nodes=None) + ) + except Exception: + actor = ray.get_actor(_ROLL_RM_ACTOR_NAME, namespace=_ROLL_RM_NAMESPACE) + + self._actor = actor + state = ray.get(self._actor.get_state.remote()) + self.num_nodes = state["num_nodes"] + self.gpu_per_node = state["gpu_per_node"] + self.num_gpus = state["num_gpus"] + self.node_ranks = state["node_ranks"] + self.gpu_ranks = state["gpu_ranks"] + self.node2pg = state["node2pg"] + self.placement_groups = state["placement_groups"] + + # Fail fast if this pipeline was configured with a different num_gpus_per_node + # than the singleton was created with. 
Silent mismatch causes wrong placement sizing. + assert self.gpu_per_node == num_gpus_per_node, ( + f"num_gpus_per_node mismatch: singleton actor has {self.gpu_per_node}, " + f"caller expected {num_gpus_per_node}. All pipelines must use the same value." + ) + + # nodes_placement_group and allocate_placement_group are inherited from ResourceManager. + # State is cached locally in __init__, so these methods work without blocking RPCs. + + def destroy_placement_group(self): + raise NotImplementedError( + "RollResourceManagerProxy is a read-only proxy to a shared singleton ResourceManager actor. " + "Placement groups are owned by the singleton actor; teardown is handled by the orchestrator." + ) diff --git a/roll/distributed/scheduler/rollout_scheduler.py b/roll/distributed/scheduler/rollout_scheduler.py index 6ce801c31..4c7d354c4 100644 --- a/roll/distributed/scheduler/rollout_scheduler.py +++ b/roll/distributed/scheduler/rollout_scheduler.py @@ -1,4 +1,6 @@ import asyncio +import math +import os import random import time from dataclasses import dataclass, field @@ -7,6 +9,7 @@ import ray from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy from ray._private import profiling +from ray.runtime_env import RuntimeEnv from tqdm import tqdm from roll.distributed.executor.cluster import Cluster @@ -17,7 +20,9 @@ from roll.pipeline.agentic.agentic_config import EnvManagerConfig from roll.utils.functionals import append_to_dict from roll.utils.import_utils import safe_import_class +from roll.utils.constants import DO_TIME_SHARING, RAY_NAMESPACE, rlix_env_vars from roll.utils.logging import get_logger +from roll.utils.rlix_compat import COORDINATOR_ACTOR_NAME_PREFIX, get_pipeline_namespace, ProgressReport logger = get_logger() @@ -202,11 +207,17 @@ def check_and_log_hung_envs(self): @dataclass class GroupData: + """Holds state for a single episode (group of rollouts). 
+ + Created when an episode starts, tracks all rollouts submitted by env workers + until the group_size is reached. Used for progress tracking and FIFO ordering. + """ group_id: int episode_id: int create_step: int + created_at: float # Timestamp for FIFO priority; used to detect oldest unfinished episode rollouts: List[DataProto] = field(default_factory=list) - running_rollouts: int = 0 + running_rollouts: int = 0 # Count of envs currently processing this episode class GroupQueue: def __init__( @@ -256,7 +267,11 @@ def shutdown(self): def advance_group(self, create_step): assert not self.quit self.groups[self.next_episode_id] = GroupData( - group_id=self.group_id, episode_id=self.next_episode_id, create_step=create_step) + group_id=self.group_id, + episode_id=self.next_episode_id, + create_step=create_step, + created_at=time.time(), + ) self.next_episode_id += 1 def _advance_step(self, create_step): @@ -317,9 +332,28 @@ async def get_episode_id(self, env_id: Optional[int] = None) -> Optional[int]: await self.progress.wait() return None - def put(self, episode_id, start_step, rollout): - if episode_id not in self.groups: # ignore rollouts from outdated episode - return + def put(self, episode_id, start_step, rollout) -> Dict[str, Any]: + """Submit a rollout from an env worker to this episode's group. 
+ + Args: + episode_id: Episode this rollout belongs to + start_step: Training step when this rollout was generated + rollout: Rollout data (None indicates env is exiting) + + Returns: + Dict with keys: + - "status": One of "ignored", "exit", "filtered", "completed", "partial" + - "non_null_added": 1 if rollout is not None, else 0 + + Status meanings: + - "ignored": Episode was already removed (outdated) + - "exit": All rollouts in group are None (env shutdown) + - "filtered": Group rejected by GroupFilter (e.g., reward threshold) + - "completed": Group reached group_size and passed filter + - "partial": Group not yet complete + """ + if episode_id not in self.groups: # ignore rollouts from outdated episode + return {"status": "ignored", "non_null_added": 0} group = self.groups[episode_id] assert start_step >= group.create_step, f"{start_step=} {group.create_step=}" group.rollouts.append(rollout) @@ -327,6 +361,7 @@ def put(self, episode_id, start_step, rollout): if all(rollout is None for rollout in group.rollouts): logger.info(f"GroupQueue: group {self.group_id} exit") self.complete.set() + return {"status": "exit", "non_null_added": 0} elif self.group_filter.filter(group_id=self.group_id, episode_id=episode_id, group=group.rollouts): logger.info(f"filter rollout group {group.group_id} episode {group.episode_id}") self.group_filter_count += 1 @@ -334,9 +369,12 @@ def put(self, episode_id, start_step, rollout): if self.env_monitor: self.env_monitor.cleanup_episode(self.group_id, episode_id) self.advance_group(create_step=self.current_step) + return {"status": "filtered", "non_null_added": 0 if rollout is None else 1} else: self.complete.set() self.progress_bar.update(self.group_size) + return {"status": "completed", "non_null_added": 0 if rollout is None else 1} + return {"status": "partial", "non_null_added": 0 if rollout is None else 1} async def get(self) -> GroupData: while True: @@ -355,14 +393,42 @@ async def get(self) -> GroupData: @ray.remote class 
GroupQueueManager: + """Central coordinator for collecting rollouts from environment workers. + + Manages per-group GroupQueues, tracks progress for RLix scheduler integration, + and provides batch retrieval for RolloutScheduler.get_batch(). + + In time-sharing mode (DO_TIME_SHARING=True), reports progress to the + RlixCoordinator at 2% bucket granularity for GPU scheduling decisions. + """ + def __init__(self, config, env_manager_config: EnvManagerConfig, mode): self.mode = mode + self.config = config self.env_manager_config = env_manager_config self.group_size = self.env_manager_config.group_size self.progress_bar = tqdm(desc=f"{self.mode} rollout progress(total trajectory)", mininterval=self.env_manager_config.max_traj_per_env) self.pending_gets = set() self.rollout_complete = {} + self.pipeline_id = os.environ.get("PIPELINE_ID") or None + # Both train and val modes report to coordinator so gap_ratio accounts for all infer capacity. + self._rlix_enabled = bool(DO_TIME_SHARING) + self.adapter_id = self.env_manager_config.tags[0] if getattr(self.env_manager_config, "tags", None) else None # lora adapter + self._rlix_coordinator = None + if self._rlix_enabled: + if not self.pipeline_id: + raise RuntimeError("DO_TIME_SHARING mode requires PIPELINE_ID to be set") + coordinator_name = f"{COORDINATOR_ACTOR_NAME_PREFIX}{self.pipeline_id}" + coordinator_namespace = get_pipeline_namespace(self.pipeline_id) + try: + self._rlix_coordinator = ray.get_actor(coordinator_name, namespace=coordinator_namespace) + except Exception as exc: + raise RuntimeError( + f"Failed to resolve coordinator {coordinator_name!r} in namespace {coordinator_namespace!r}. " + "GroupQueueManager expects the coordinator actor to exist before startup." 
+ ) from exc + group_filter_cls = safe_import_class(env_manager_config.group_filter_cls) assert group_filter_cls self.group_filter = group_filter_cls(config, env_manager_config, mode) @@ -370,9 +436,11 @@ def __init__(self, config, env_manager_config: EnvManagerConfig, mode): if self.mode == "train": self.async_generation_ratio = config.async_generation_ratio self.max_traj_per_env = env_manager_config.max_traj_per_env if config.rollout_batch_size > 0 else None + self.rollout_batch_size = int(config.rollout_batch_size) else: self.async_generation_ratio = 0 self.max_traj_per_env = env_manager_config.max_traj_per_env if config.val_batch_size > 0 else None + self.rollout_batch_size = int(config.val_batch_size) # Initialize env activity monitor first (before creating GroupQueues) self.group_queue: Dict[int, GroupQueue] = {} @@ -405,6 +473,167 @@ def __init__(self, config, env_manager_config: EnvManagerConfig, mode): self.total = 0 self.waiting = 0 + # === RLix Progress Tracking === + # Tracks rollout collection progress for time-sharing scheduler decisions. + # Progress is reported at 2% bucket granularity to minimize coordinator overhead. + self._progress_last_bucket: Optional[int] = None # Last emitted 2% bucket (0-50); emits only on change + self._progress_new_batch = False # True when new batch starts; forces immediate emit + self._progress_total_required_estimated = self._estimate_total_required() # Target: batch_size * num_return_sequences + self._progress_collected_estimated = 0 # Trajectories collected so far (clamped to total_required) + self._progress_episode_non_null: Dict[Tuple[int, int], int] = {} # Per-episode count for rollback on filter/reject + self._progress_active = False # True only between begin_progress_batch/end_progress_batch + + def _resolve_num_return_sequences(self) -> int: + # RLix progress should be expressed in "trajectory units" that match the rollout batch contract. 
+ # + # In ROLL's request scheduler, the effective number of finished samples required for a "batch" is + # `batch_size * num_return_sequences` (not scaled by async_generation_ratio). + raw = None + generating_args = getattr(self.env_manager_config, "generating_args", None) + if generating_args is not None: + raw = getattr(generating_args, "num_return_sequences", None) + if raw is None: + actor_infer = getattr(self.config, "actor_infer", None) + generating_args = getattr(actor_infer, "generating_args", None) if actor_infer is not None else None + if generating_args is not None: + raw = getattr(generating_args, "num_return_sequences", None) + + n = 1 if raw is None else int(raw) + if n <= 0: + raise RuntimeError(f"Invalid num_return_sequences={raw!r}; expected > 0") + return n + + def _estimate_total_required(self) -> int: + """Calculate total trajectories required per training step. + + Formula: rollout_batch_size * num_return_sequences + + Note: Does NOT depend on async_generation_ratio. Async controls + overlap/pausing behavior, not the required sample count for a batch. + + Returns: + 0 if max_traj_per_env is None (unbounded mode), else positive int. + """ + if self.max_traj_per_env is None: + return 0 + # Denominator for progress is the per-step rollout batch target (trajectory units). + # It must not depend on async_generation_ratio (async controls overlap/pausing, not the required sample count). + num_return_sequences = self._resolve_num_return_sequences() + return int(self.rollout_batch_size) * int(num_return_sequences) + + def _mark_new_batch(self) -> None: + self._progress_total_required_estimated = self._estimate_total_required() + self._progress_new_batch = True + + def _reset_progress_for_new_batch(self, current_train_step: Optional[int]) -> None: + """Reset progress tracking and emit report for a new batch cycle. + + Called at the start of each training step (advance_step) or when clearing + state after suspension/rollback (clear). 
Only tracks progress in bounded mode. + + Args: + current_train_step: The training step number, or None if unknown. + """ + if self.max_traj_per_env is not None: + self._progress_collected_estimated = 0 + self._progress_episode_non_null.clear() + self._mark_new_batch() + self._maybe_emit_progress(current_train_step=current_train_step) + + def _compute_progress(self) -> Tuple[int, int, int, Optional[float]]: + """Compute current progress state for this scheduler stream. + + Returns: + Tuple of (total_required, collected, remaining, oldest_unfinished_ts): + - total_required: Target trajectories for this step + - collected: Trajectories collected (clamped to total_required) + - remaining: max(total_required - collected, 0) + - oldest_unfinished_ts: Creation time of oldest incomplete episode, + used for FIFO priority and hang detection; None if no incomplete episodes. + """ + if self.max_traj_per_env is None: + # Unbounded mode: do not report progress in Phase 3. + return 0, 0, 0, None + + total_required = self._progress_total_required_estimated + collected = min(self._progress_collected_estimated, total_required) + + oldest_ts: Optional[float] = min( + (group.created_at + for gq in self.group_queue.values() + for group in gq.groups.values() + if len(group.rollouts) < self.group_size), + default=None, + ) + + remaining = max(total_required - collected, 0) + return total_required, collected, remaining, oldest_ts + + def _maybe_emit_progress(self, *, current_train_step: Optional[int]) -> None: + """Emit progress report to coordinator if conditions are met. + + Suppressed when _progress_active is False (before begin or after end), + preventing stale emissions from late put() calls after batch deactivation. + + Emits when: + - 2% bucket changed (bucket != self._progress_last_bucket), OR + - Batch complete (remaining == 0), OR + - New batch started (self._progress_new_batch == True) + + Uses fire-and-forget RPC to avoid blocking rollout collection. 
+ Coordinator aggregates reports from all streams (train, val, LoRAs) + and forwards a single aggregated report to the central scheduler. + + Args: + current_train_step: Current training step (for metrics), or None if unknown. + """ + if not self._progress_active: + return + if not self._rlix_enabled: + return + if self.max_traj_per_env is None: + return + if self._rlix_coordinator is None: + raise RuntimeError("RLIX progress enabled but coordinator handle is missing") + if not self.pipeline_id: + raise RuntimeError("RLIX progress enabled but PIPELINE_ID is missing") + + total_required, collected, remaining, oldest_ts = self._compute_progress() + if total_required <= 0: + return + + percent_completed = float(collected) / float(max(total_required, 1)) + # 2% buckets (0..50). Bucket 0 means 0% completed, bucket 50 means 100% completed. + bucket = math.floor(percent_completed * 50) + + should_emit = ( + bucket != self._progress_last_bucket + or remaining == 0 + or self._progress_new_batch + ) + if not should_emit: + return + + emitted_for_new_batch = self._progress_new_batch + self._progress_last_bucket = bucket + self._progress_new_batch = False + + report = ProgressReport( + pipeline_id=str(self.pipeline_id), + step_target_trajectories=int(total_required), + fifo_timestamp=time.time(), + metrics={ + "mode": self.mode, + "collected": int(self._progress_collected_estimated), + "bucket": int(bucket), + "new_batch": bool(emitted_for_new_batch), + "current_train_step": current_train_step, + "adapter_id": self.adapter_id, + }, + ) + # Fire-and-forget to coordinator; coordinator aggregates and forwards to rlix scheduler. 
+ self._rlix_coordinator.report_progress_from_scheduler.remote(report) + def collect_metrics(self): group_filter_count = 0 for group_queue in self.group_queue.values(): @@ -413,6 +642,13 @@ def collect_metrics(self): return {"scheduler/group_filter_count": group_filter_count} def clear(self): + """Reset scheduler state for a new training step or after suspension. + + Cancels pending batch retrieval tasks and clears all group queue state. + Called when rolling back to a checkpoint or when starting fresh after a + suspend operation. Progress deactivation is handled separately by + end_progress_batch() via the RolloutScheduler lifecycle. + """ self.rollout_complete = {} for get_task in self.pending_gets: get_task.cancel() @@ -421,9 +657,46 @@ def clear(self): group_queue.clear() def advance_step(self, step): + """Advance to a new training step. + + Propagates step advancement to all group queues (creates/expires async + groups, wakes waiters). Does NOT reset progress; that is now handled + by begin_progress_batch() at the start of each get_batch() request. + + Args: + step: The new training step number, or None if step is unknown. + """ for group_queue in self.group_queue.values(): group_queue.advance_step(step) + def begin_progress_batch(self, current_train_step: Optional[int]) -> None: + """Activate progress tracking for a new batch collection cycle. + + Called by RolloutScheduler.get_batch() to mark the start of active demand. + Resets counters, sets _progress_active, and emits a new_batch report to + the coordinator so the scheduler allocates GPU resources for this stream. + + Args: + current_train_step: The training step number, or None if unknown. + """ + self._progress_active = True + self._reset_progress_for_new_batch(current_train_step=current_train_step) + + def end_progress_batch(self) -> None: + """Deactivate progress tracking for the completed batch. 
+ + Called by RolloutScheduler.get_batch() after batch collection finishes + (success, empty, or exception). Sets _progress_active = False to suppress + late put()-driven emissions, then tells the coordinator to remove this + stream from aggregation so stale demand does not distort scheduling. + """ + self._progress_active = False + if self._rlix_coordinator is not None: + self._rlix_coordinator.clear_progress_stream.remote( + mode=self.mode, + adapter_id=self.adapter_id, + ) + async def get_episode_id(self, group_id, env_id=None): """ Get the next episode ID for an environment. @@ -470,9 +743,31 @@ def put(self, group_id, episode_id, start_step, rollout: DataProto, env_id=None) self.env_monitor.record_activity(group_id, env_id, episode_id, rollout) self.waiting += 1 - self.group_queue[group_id].put(episode_id, start_step, rollout) + put_result = self.group_queue[group_id].put(episode_id, start_step, rollout) + + # === Progress Tracking (bounded mode only) === + # Track collected trajectories for RLix scheduler progress reports. + # Must handle rollback when groups are filtered/rejected. 
+ if self.max_traj_per_env is not None: + status = str(put_result.get("status", "")) + non_null_added = int(put_result.get("non_null_added", 0)) + episode_key = (group_id, episode_id) + + # Increment progress for each valid trajectory added to the episode + if non_null_added: + self._progress_episode_non_null[episode_key] = self._progress_episode_non_null.get(episode_key, 0) + 1 + self._progress_collected_estimated += non_null_added + + # Rollback progress if episode was rejected (filtered) or env exited + if status in {"filtered", "exit"}: + rolled_back = self._progress_episode_non_null.pop(episode_key, 0) + self._progress_collected_estimated = max(self._progress_collected_estimated - rolled_back, 0) + # Episode completed: stop tracking but keep the count (already added to progress) + elif status == "completed": + self._progress_episode_non_null.pop(episode_key, None) self.waiting -= 1 self.total += 1 + self._maybe_emit_progress(current_train_step=int(start_step) if start_step is not None else None) async def get_batch(self, batch_size, current_step) -> List[DataProto]: """ @@ -509,7 +804,16 @@ async def wait_a_episode(): done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED) while done and (batch_size < 0 or len(ret) < batch_size): d = done.pop() - group = await d + try: + group = await d + except asyncio.CancelledError: + # Best-effort cleanup: cancellation is expected during clear()/shutdown(). + continue + except Exception as e: + # Fail-fast: clean up any outstanding tasks and surface the error. + for t in pending: + t.cancel() + raise RuntimeError(f"GroupQueue.get() task failed (group_id={d.get_name()!r})") from e group_rollout = group.rollouts self.total -= len(group_rollout) @@ -540,7 +844,18 @@ async def wait_a_episode(): return ret class RolloutScheduler(RolloutMockMixin): - """ + """Orchestrates rollout generation for a single mode (train or val). + + Coordinates three main components: + 1. 
GroupQueueManager: Collects rollouts from env workers + 2. RequestScheduler: Routes GPU inference requests + 3. Cluster (EnvManager): Runs environment rollout loops + + In time-sharing mode, integrates with RLix scheduler for GPU allocation: + - Reports progress to RlixCoordinator for scheduling decisions + - Supports shrink/expand for GPU reassignment between pipelines + - Uses async __init__ to allow concurrent pipeline startup + Usage: # User should control load_states/offload_states in pipeline by themselves. actor_infer @@ -555,48 +870,101 @@ class RolloutScheduler(RolloutMockMixin): rollout() ray.get(train_rollout_scheduler.shutdown.remote()) """ - def __init__(self, config, env_manager_config: EnvManagerConfig, resource_manager, infer_cluster, mode, collator=None): + async def __init__(self, config, env_manager_config: EnvManagerConfig, resource_manager, infer_cluster, mode, collator=None): + """Initialize the rollout scheduler actor. + + NOTE: This is an async __init__. Avoid blocking calls here. + Env worker initialization is deferred to get_batch() to allow + concurrent pipeline startup in multi-pipeline scenarios. + + Args: + config: Pipeline configuration (e.g., PPOConfig) + env_manager_config: Environment manager configuration + resource_manager: Ray placement group resource manager + infer_cluster: Inference cluster for GPU routing + mode: "train" or "val" + collator: Optional data collator for preprocessing + """ + # NOTE: This actor is async. Avoid blocking calls in __init__ and log each construction phase + # to pinpoint startup stalls (e.g., placement group allocation, child actor creation). 
+ logger.info(f"[RolloutScheduler] __init__ enter mode={mode}") self.config = config self.env_manager_config = env_manager_config self.resource_manager = resource_manager self.infer_cluster = infer_cluster self.mode = mode + self.pipeline_id = os.environ.get("PIPELINE_ID") or None + if self.pipeline_id is None and os.environ.get("ROLL_RAY_NAMESPACE"): + raise RuntimeError("PIPELINE_ID must be set when ROLL_RAY_NAMESPACE is set (multi-pipeline mode)") env_num = self.env_manager_config.world_size * self.env_manager_config.max_env_num_per_worker + runtime_env = RuntimeEnv(env_vars=rlix_env_vars()) + self.env_output_queue = GroupQueueManager.options( - name=f"GroupQueueManager-{mode}", + name=( + # Include env-manager name so multiple train schedulers (one per tag) do not collide on actor name. + f"{self.pipeline_id}_group_queue_manager_{self.env_manager_config.name}_{mode}" + if self.pipeline_id + else f"GroupQueueManager-{self.env_manager_config.name}-{mode}" + ), + namespace=RAY_NAMESPACE, scheduling_strategy=NodeAffinitySchedulingStrategy( node_id=ray.get_runtime_context().get_node_id(), soft=False), - max_concurrency = env_num + 1 # reserve extra one for get_batch + runtime_env=runtime_env, + max_concurrency=env_num + 1, # reserve extra one for get_batch ).remote( self.config, self.env_manager_config, mode ) + # Include env-manager name so multiple train schedulers (one per tag) do not collide on actor name. 
self.generate_scheduler = RequestScheduler.options( - name=f"RequestScheduler-{self.env_manager_config.name}-{mode}", + name=( + f"{self.pipeline_id}_request_scheduler_{self.env_manager_config.name}_{mode}" + if self.pipeline_id + else f"RequestScheduler-{self.env_manager_config.name}-{mode}" + ), + namespace=RAY_NAMESPACE, scheduling_strategy=NodeAffinitySchedulingStrategy( node_id=ray.get_runtime_context().get_node_id(), soft=False, ), - max_concurrency = env_num + 1 # reserve extra one for suspend/resume + runtime_env=runtime_env, + max_concurrency=env_num + 1, # reserve extra one for suspend/resume ).remote(infer_cluster=self.infer_cluster, pipeline_config=config, resource_manager=self.resource_manager) - - self.es_manager: Any = Cluster( - name=self.env_manager_config.name, - worker_cls=self.env_manager_config.worker_cls, - resource_manager=self.resource_manager, - worker_config=self.env_manager_config, + logger.info(f"[RolloutScheduler] created RequestScheduler mode={self.mode}") + + logger.info(f"[RolloutScheduler] creating env Cluster mode={self.mode} name={self.env_manager_config.name}") + # Cluster.__init__ calls ray.get() for topology resolution, which blocks the event loop. + # Run it in a thread executor to avoid freezing this async actor's constructor. + loop = asyncio.get_event_loop() + self.es_manager: Any = await loop.run_in_executor( + None, + lambda: Cluster( + name=self.env_manager_config.name, + worker_cls=self.env_manager_config.worker_cls, + resource_manager=self.resource_manager, + worker_config=self.env_manager_config, + ), ) - self.es_manager.initialize( + logger.info(f"[RolloutScheduler] created env Cluster mode={self.mode} name={self.env_manager_config.name}") + # Do not block with ray.get() inside this async actor's constructor. + # We kick off initialization on env workers and await it in get_batch(). 
+ self._es_initialize_refs = self.es_manager.initialize( pipeline_config=self.config, generate_scheduler=self.generate_scheduler, output_queue=self.env_output_queue, collator=collator, mode=self.mode, + blocking=False, + ) + self._es_initialized = False + logger.info( + f"[RolloutScheduler] submitted env initialize mode={self.mode} " + f"num_refs={len(self._es_initialize_refs) if self._es_initialize_refs else 0}" ) self.rollout_task = None @@ -606,20 +974,35 @@ def __init__(self, config, env_manager_config: EnvManagerConfig, resource_manage # Initialize rollout mock mechanism from mixin self._init_rollout_mock() + logger.info(f"[RolloutScheduler] __init__ exit mode={self.mode}") - async def shutdown(self): + async def shutdown(self, timeout: float = 10.0): if self.rollout_task is None: return - await asyncio.gather(*self.es_manager.stop(blocking=False)) - await self.env_output_queue.shutdown.remote() - await self.generate_scheduler.abort_request.remote() - await self.rollout_task + + async def _do_shutdown(): + await asyncio.gather(*self.es_manager.stop(blocking=False)) + await self.env_output_queue.shutdown.remote() + await self.generate_scheduler.abort_request.remote() + await self.rollout_task + + try: + await asyncio.wait_for(_do_shutdown(), timeout=timeout) + except asyncio.TimeoutError: + logger.warning(f"shutdown timed out after {timeout}s, force-skipping") + if self.rollout_task is not None and not self.rollout_task.done(): + self.rollout_task.cancel() self.rollout_task = None async def suspend(self): await self.generate_scheduler.suspend.remote() + async def resume(self): + # Delegate resume so partial-GPU pipeline can re-enable request dispatch after expand. 
+ await self.generate_scheduler.resume.remote() + async def _run_rollout_loop(self, seed): + logger.info(f"[RolloutScheduler] start _run_rollout_loop seed={seed} mode={self.mode}") await asyncio.gather(*self.es_manager.run_rollout_loop(seed, blocking=False)) async def _get_batch(self, batch_size, global_step): @@ -627,59 +1010,99 @@ async def _get_batch(self, batch_size, global_step): async def get_batch(self, data: DataProto, batch_size): global_step = data.meta_info["global_step"] + logger.info(f"[RolloutScheduler] get_batch enter mode={self.mode} global_step={global_step} batch_size={batch_size}") # MOCK MODE: Load pre-recorded data, skip rollout (from mixin) if self._should_load_mock(global_step): return await self._load_mock_batch(global_step) + if not self._es_initialized: + logger.info(f"[RolloutScheduler] awaiting env worker initialize mode={self.mode}") + init_refs = self._es_initialize_refs or [] + await asyncio.gather(*init_refs) + self._es_initialized = True + logger.info(f"[RolloutScheduler] env worker initialize done mode={self.mode}") + # start env manager if self.rollout_task is None: seed = random.randint(0, 1000000) if self.mode == "train" else self.config.seed self.rollout_task = asyncio.create_task(self._run_rollout_loop(seed)) + logger.info(f"[RolloutScheduler] update_step start mode={self.mode} global_step={global_step}") await asyncio.gather(*self.es_manager.update_step(global_step, blocking=False)) - await self.env_output_queue.advance_step.remote(global_step) - await self.generate_scheduler.resume.remote() - - get_task = asyncio.create_task(self._get_batch(batch_size, global_step)) - await asyncio.wait({get_task, self.rollout_task}, return_when=asyncio.FIRST_COMPLETED) - if self.rollout_task.done() and self.rollout_task.exception() is not None: - await self.rollout_task - data_batch = await get_task - if batch_size <= 0: - await self.rollout_task - self.rollout_task = None - await self.env_output_queue.clear.remote() - - if len(data_batch) 
== 0: - return None + logger.info(f"[RolloutScheduler] update_step done mode={self.mode} global_step={global_step}") - metrics = {} - get_batch_return_start_time = None - for d_item in data_batch: - get_batch_return_start_time = d_item.meta_info.pop("get_batch_return_start_time", None) - append_to_dict(metrics, d_item.meta_info["metrics"]) - if get_batch_return_start_time is not None: - metrics["time/get_batch_cost_gqm"] = time.time() - get_batch_return_start_time - metrics.update(await self.env_output_queue.collect_metrics.remote()) - batch = DataProto.concat(data_batch) - batch.meta_info["metrics"] = metrics - batch.meta_info["get_batch_return_start_time"] = time.time() - - # DUMP MODE: Save merged batch (from mixin) - await self._maybe_dump_batch(batch, global_step) - - return batch + logger.info(f"[RolloutScheduler] advance_step start mode={self.mode} global_step={global_step}") + await self.env_output_queue.advance_step.remote(global_step) + logger.info(f"[RolloutScheduler] advance_step done mode={self.mode} global_step={global_step}") + # In time-sharing mode (DO_TIME_SHARING=True), external RLix coordinator controls suspend/resume. + # In standalone mode, this scheduler must resume generate_scheduler before each batch. + if not DO_TIME_SHARING: + await self.generate_scheduler.resume.remote() + + # Activate progress tracking for this batch. + await self.env_output_queue.begin_progress_batch.remote( + int(global_step) if global_step is not None else None + ) - async def shrink_sampler(self, target_gpus: List[int]) -> Dict[str, Any]: - """Thin wrapper: Delegate shrink operation to RequestScheduler. 
+ try: + get_task = asyncio.create_task(self._get_batch(batch_size, global_step)) + await asyncio.wait({get_task, self.rollout_task}, return_when=asyncio.FIRST_COMPLETED) + if self.rollout_task.done() and self.rollout_task.exception() is not None: + await self.rollout_task + data_batch = await get_task + logger.info( + f"[RolloutScheduler] env_output_queue.get_batch returned mode={self.mode} " + f"global_step={global_step} items={len(data_batch) if data_batch else 0}" + ) + if batch_size <= 0: + await self.rollout_task + self.rollout_task = None + await self.env_output_queue.clear.remote() + + if len(data_batch) == 0: + return None + + metrics = {} + get_batch_return_start_time = None + for d_item in data_batch: + get_batch_return_start_time = d_item.meta_info.pop("get_batch_return_start_time", None) + append_to_dict(metrics, d_item.meta_info["metrics"]) + if get_batch_return_start_time is not None: + metrics["time/get_batch_cost_gqm"] = time.time() - get_batch_return_start_time + metrics.update(await self.env_output_queue.collect_metrics.remote()) + batch = DataProto.concat(data_batch) + batch.meta_info["metrics"] = metrics + batch.meta_info["get_batch_return_start_time"] = time.time() + + # DUMP MODE: Save merged batch (from mixin) + await self._maybe_dump_batch(batch, global_step) + + return batch + finally: + # Deactivate progress: suppress late emissions and clear coordinator/scheduler state. + # Awaited (not fire-and-forget) to serialize with the next begin_progress_batch call + # and prevent lifecycle races on the GQM actor (which has max_concurrency > 1). + await self.env_output_queue.end_progress_batch.remote() + + async def shrink_sampler(self, dp_ranks: List[int], skip_offload: bool = False) -> Dict[str, Any]: + """Offload model weights from specified DP ranks and deactivate them for routing. + + Called by RLix scheduler during GPU reassignment. After shrink, + the specified DP ranks will not receive inference requests. 
+ + This is a thin wrapper delegating to RequestScheduler.shrink_workers() + for atomic execution under routing_lock. v4.6 ARCHITECTURAL CHANGE: RolloutScheduler no longer performs validation, calculation, or state management. All worker lifecycle operations are now owned by RequestScheduler for atomic execution under routing_lock. Args: - target_gpus: GPU IDs to free (e.g., [4,5] for actor_train or [6,7] for critic) + dp_ranks: DP ranks to offload (e.g., [0, 1] for first two ranks) + skip_offload: If True, skip physical weight offload. Use when + another coupled scheduler (e.g., critic) already offloaded + the same GPUs, avoiding redundant offload operations. Returns: Dict with metrics from RequestScheduler.shrink_workers(): @@ -705,24 +1128,30 @@ async def shrink_sampler(self, target_gpus: List[int]) -> Dict[str, Any]: start_time = time.time() # Delegate complete shrink operation to RequestScheduler (atomic under routing_lock) - result = await self.generate_scheduler.shrink_workers.remote(target_gpus) + result = await self.generate_scheduler.shrink_workers.remote(dp_ranks, skip_offload=bool(skip_offload)) # Add timing from RolloutScheduler perspective result["rollout_scheduler_duration_ms"] = (time.time() - start_time) * 1000 return result - async def expand_sampler(self, target_gpus: List[int], skip_load: bool = False) -> Dict[str, Any]: - """Thin wrapper: Delegate expand operation to RequestScheduler. + async def expand_sampler(self, dp_ranks: List[int], skip_load: bool = False) -> Dict[str, Any]: + """Load model weights to specified DP ranks and activate them for routing. + + Called by RLix scheduler during GPU reassignment. After expand, + the specified DP ranks will receive inference requests again. + + This is a thin wrapper delegating to RequestScheduler.expand_workers() + for atomic execution under routing_lock. v4.6 ARCHITECTURAL CHANGE: RolloutScheduler no longer performs validation, calculation, or state management. 
All worker lifecycle operations are now owned by RequestScheduler for atomic execution under routing_lock. Args: - target_gpus: GPU IDs to restore (e.g., [4,5] for actor_train or [6,7] for critic) - skip_load: If True, skip model loading (use when model_update already loaded states). - This only updates active_dp_ranks to restore routing state. + dp_ranks: DP ranks to load/activate (e.g., [0, 1] for first two ranks) + skip_load: If True, skip model loading (use when model_update already + loaded states). This only updates active_dp_ranks to restore routing. Returns: Dict with metrics from RequestScheduler.expand_workers(): @@ -751,7 +1180,7 @@ async def expand_sampler(self, target_gpus: List[int], skip_load: bool = False) start_time = time.time() # Delegate complete expand operation to RequestScheduler (atomic under routing_lock) - result = await self.generate_scheduler.expand_workers.remote(target_gpus, skip_load) + result = await self.generate_scheduler.expand_workers.remote(dp_ranks, skip_load) # Add timing from RolloutScheduler perspective result["rollout_scheduler_duration_ms"] = (time.time() - start_time) * 1000 diff --git a/roll/distributed/scheduler/storage.py b/roll/distributed/scheduler/storage.py index da4c9e1d5..f2716fe1e 100644 --- a/roll/distributed/scheduler/storage.py +++ b/roll/distributed/scheduler/storage.py @@ -1,3 +1,21 @@ +"""Cluster-wide key-value store for multi-pipeline coordination. + +This module provides SharedStorage, a Ray actor that persists across pipeline lifecycles +and enables coordination between ephemeral pipelines in time-sharing mode. + +Usage: +- Port allocation claims: MASTER_ADDR_PORT:: -> pipeline_id +- Checkpoint path caching: model_path: -> local_path + +# TODO(rlix): Delegate SharedStorage to rlix orchestrator similar to ResourceManager pattern. +# Currently rlix directly calls delete_port_claims/delete_prefix on this actor. 
+# Consider creating SharedStorageProxy and moving actor lifecycle management to rlix, +# with actor name/namespace defined in rlix.protocol.types (like ROLL_RESOURCE_MANAGER_ACTOR_NAME). +# This would: +# - Remove hardcoded "SHARED_STORAGE_ACTOR" / "global_storage_namespace" from rlix/orchestrator.py +# - Allow rlix to control storage lifecycle independent of ROLL +# - Enable per-tenant storage isolation if needed +""" import ray from roll.utils.logging import get_logger @@ -7,17 +25,127 @@ @ray.remote class SharedStorage: + """Cluster-wide key-value store shared across all ROLL pipelines. + + In time-sharing mode, this actor is shared across all ROLL pipelines and used for: + - Port allocation claims (MASTER_ADDR_PORT:* keys) to prevent port collisions + - Model checkpoint path caching to avoid redundant downloads + + The actor persists for the lifetime of the Ray cluster, enabling coordination + between ephemeral pipelines that come and go. + """ def __init__(self): self._storage = {} def put(self, key, data): + """Store data under the given key, overwriting any existing value. + + Args: + key: The storage key (typically a string identifier). + data: The value to store (will be placed in Ray object store). + + Note: + This overwrites existing values silently. Use try_put() for atomic put-if-absent. + """ ref = ray.put(data) self._storage[key] = ref + def try_put(self, key, data) -> bool: + """Atomically store data only if the key does not already exist. + + This is used for lock-free resource claiming (e.g., port allocation). + Multiple pipelines can race to claim a port; only one succeeds. + + Args: + key: The storage key to claim. + data: The value to store (typically pipeline_id for port claims). + + Returns: + True if the key was successfully claimed (did not exist). + False if the key already exists (claim failed). 
+ """ + if key in self._storage: + return False + ref = ray.put(data) + self._storage[key] = ref + return True + def get(self, key): + """Retrieve data stored under the given key. + + Args: + key: The storage key to look up. + + Returns: + The stored value, or None if the key does not exist. + """ ref = self._storage.get(key) if ref is None: logger.warning(f"{key} is not found in storage") return None return ray.get(ref) + + def delete(self, key) -> None: + """Remove a single key from storage. + + Args: + key: The storage key to delete. + + Note: + Silently ignores keys that don't exist. + """ + self._storage.pop(key, None) + + def delete_prefix(self, prefix: str) -> int: + """Delete all keys that start with the given prefix. + + Used for bulk cleanup when a pipeline terminates, removing all its + associated entries from shared storage. + + Args: + prefix: The key prefix to match (e.g., "pipeline_123:"). + + Returns: + The number of keys deleted. + + Raises: + ValueError: If prefix is not a string. + """ + if not isinstance(prefix, str): + raise ValueError(f"prefix must be str, got {type(prefix).__name__}") + keys = [k for k in self._storage.keys() if isinstance(k, str) and k.startswith(prefix)] + for k in keys: + self._storage.pop(k, None) + return len(keys) + + def delete_port_claims(self, pipeline_id: str) -> int: + """Release all port claims owned by a specific pipeline. + + When a pipeline terminates, this removes all MASTER_ADDR_PORT:* entries + that were claimed by that pipeline, allowing the ports to be reused. + + Args: + pipeline_id: The ID of the pipeline whose port claims should be released. + + Returns: + The number of port claims deleted. + + Raises: + ValueError: If pipeline_id is not a non-empty string. 
+ """ + if not isinstance(pipeline_id, str) or pipeline_id == "": + raise ValueError("pipeline_id must be non-empty str") + deleted = 0 + for key in list(self._storage.keys()): + if not isinstance(key, str) or not key.startswith("MASTER_ADDR_PORT:"): + continue + ref = self._storage.get(key) + if ref is None: + continue + value = ray.get(ref) + if value != pipeline_id: + continue + self._storage.pop(key, None) + deleted += 1 + return deleted diff --git a/roll/distributed/strategy/deepspeed_strategy.py b/roll/distributed/strategy/deepspeed_strategy.py index f81fb19da..646ae0a67 100644 --- a/roll/distributed/strategy/deepspeed_strategy.py +++ b/roll/distributed/strategy/deepspeed_strategy.py @@ -552,6 +552,14 @@ def load_checkpoint(self, load_dir, tag="checkpoint", **kwargs): self.model.load_checkpoint(load_dir, tag=tag, **kwargs) def collect_lora_params(self): + # Both ZeRO-3 and non-ZeRO-3 paths export only the default adapter; + # fail fast before any silent single-adapter export in multi-LoRA configs. + adapters = self.worker_config.model_args.adapters + if adapters is not None and len(adapters) > 1: + raise RuntimeError( + "DeepSpeed LoRA collection does not support multi-LoRA. " + f"Configured adapters: {sorted(adapters.keys())}" + ) peft_model = self.unwrap_model() if not self.ds_config.is_zero3(): lora_state_dict = get_peft_model_state_dict(peft_model) @@ -570,7 +578,8 @@ def collect_lora_params(self): def setup_model_update(self, infer_cluster, model_update_name: str): assert model_update_name not in self.weight_updaters - is_lora = self.worker_config.model_args.lora_target is not None + # Use adapters (not lora_target) so explicit multi-LoRA configs are recognized. 
+ is_lora = self.worker_config.model_args.adapters is not None self.weight_updaters[model_update_name] = DeepSpeedWeightUpdater( pipeline_config=self.worker.pipeline_config, infer_cluster=infer_cluster, diff --git a/roll/distributed/strategy/fsdp2_strategy.py b/roll/distributed/strategy/fsdp2_strategy.py index 389ff9cb2..97cace9f5 100644 --- a/roll/distributed/strategy/fsdp2_strategy.py +++ b/roll/distributed/strategy/fsdp2_strategy.py @@ -636,7 +636,8 @@ def _prepare_fsdp2_model( finally: clear_fsdp2_init_context() - self.is_lora = self.worker_config.model_args.lora_target is not None + # Use adapters (not lora_target) so explicit multi-LoRA configs are recognized. + self.is_lora = self.worker_config.model_args.adapters is not None return model, torch_dtype, cp_size @@ -1254,7 +1255,8 @@ def train_step( def setup_model_update(self, infer_cluster, model_update_name: str): assert model_update_name not in self.weight_updaters - is_lora = self.worker_config.model_args.lora_target is not None + # Use adapters (not lora_target) so explicit multi-LoRA configs are recognized. 
+ is_lora = self.worker_config.model_args.adapters is not None self.weight_updaters[model_update_name] = FSDP2WeightUpdater( pipeline_config=self.worker.pipeline_config, infer_cluster=infer_cluster, diff --git a/roll/distributed/strategy/megatron_strategy.py b/roll/distributed/strategy/megatron_strategy.py index 68fa26b37..7cd48b62a 100644 --- a/roll/distributed/strategy/megatron_strategy.py +++ b/roll/distributed/strategy/megatron_strategy.py @@ -1,10 +1,14 @@ +import io import math import os +import pickle import random +import threading from collections import defaultdict from contextlib import nullcontext +from dataclasses import asdict, dataclass from functools import partial -from typing import TYPE_CHECKING, Callable, Dict, Iterator, List, Tuple +from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, List, Optional, Set, Tuple import numpy as np import ray @@ -46,7 +50,11 @@ from roll.distributed.strategy.strategy import InferenceStrategy, TrainStrategy from roll.models.model_providers import default_processor_provider, default_tokenizer_provider from roll.platforms import current_platform -from roll.third_party.megatron.model_update import MegatronWeightUpdater +from roll.third_party.megatron.model_update import ( + MegatronWeightUpdater, + gather_all_hf_weights, + gather_weights_meta_cross_pp, +) from roll.third_party.megatron.offload_states_patch import ( MegatronOffloadStateType, bind_megatron_offload_states_func, @@ -56,6 +64,7 @@ from roll.third_party.megatron.optimizer import get_megatron_optimizer from roll.third_party.megatron.tensor_parallel import vocab_parallel_entropy from roll.utils.constants import ( + DO_TIME_SHARING, DIST_OPTIMIZER_DIR, IGNORE_INDEX, OPTIMIZER_NAME, @@ -63,10 +72,20 @@ SCHEDULER_NAME, ) from roll.utils.context_managers import disable_gradients +from roll.utils.cuda_ipc_utils import MultiprocessingSerializer from roll.utils.dynamic_batching import make_micro_batch_iter_for_dynamic_batching from 
roll.utils.functionals import append_to_dict, reduce_metrics, adjust_sequence_length +from roll.utils.collective import collective from roll.utils.logging import get_logger +from roll.utils.lora_routing import resolve_microbatch_lora_name +from roll.utils.network_utils import collect_free_port, get_node_ip from roll.utils.offload_states import OffloadStateType +from roll.utils.send_recv_utils import ( + _bucket_named_tensors, + compute_weight_stats, + monkey_patch_torch_reductions, + named_tensors_from_bucket, +) from roll.utils.sequence_packing import make_micro_batch_iter_for_sequence_packing, restore_results_order @@ -76,6 +95,31 @@ logger = get_logger() +def _safe_dist_barrier(subgroup=None): + """Synchronize ranks at a barrier, handling two common failure modes. + + Safe to call even when ``dist`` is not initialized (single-process or + workers that skipped dist init) — the barrier becomes a no-op in that case. + + For NCCL backend, passes ``device_ids`` explicitly to avoid a hang that + occurs when no default CUDA device is set (NCCL requires an explicit device + for the barrier collective; see PyTorch issue fixed after v2.9.1). + + Args: + subgroup: Optional process-group subset (e.g. TP group, PP group). + When None, synchronizes all ranks in the global default group. 
+ """ + if not dist.is_available() or not dist.is_initialized(): + return + kwargs = {} + if dist.get_backend() == "nccl" and current_platform.is_available(): + kwargs["device_ids"] = [current_platform.current_device()] + if subgroup is None: + dist.barrier(**kwargs) + else: + dist.barrier(group=subgroup, **kwargs) + + class MegatronInferStrategy(InferenceStrategy): strategy_name = "megatron_infer" @@ -89,6 +133,13 @@ def __init__(self, worker: Worker): # maybe put max_grad_norm into training_args as transformers do, rather # than in pipeline_config (PPOConfig) config_dict.update({"max_grad_norm": self.worker.pipeline_config.max_grad_norm}) + # Filter out strategy_config keys (e.g., is_lora_optimizer_isolated) that are not + # valid TrainingArguments fields — otherwise TrainingArguments(**config_dict) raises TypeError. + supported_keys = set(TrainingArguments.__dataclass_fields__.keys()) + dropped_keys = [k for k in config_dict if k not in supported_keys] + if dropped_keys: + logger.warning(f"Ignore non-TrainingArguments keys: {dropped_keys}") + config_dict = {k: v for k, v in config_dict.items() if k in supported_keys} logger.info(f"training_args: {config_dict}") self.megatron_train_args = TrainingArguments(**config_dict) self.model = None @@ -112,6 +163,8 @@ def initialize(self, model_provider): self.forward_backward_func = get_forward_backward_func() self.seq_length = self.worker.pipeline_config.sequence_length + # True when PEFT LoRA adapters are configured; gates adapter-routing code paths. 
+ self.is_lora = self.worker_config.model_args.adapters is not None self.worker.rank_info.dp_rank = mpu.get_data_parallel_rank(with_context_parallel=False) self.worker.rank_info.dp_size = mpu.get_data_parallel_world_size(with_context_parallel=False) @@ -127,7 +180,7 @@ def initialize(self, model_provider): logger.info("Set variable_seq_lengths to True when use dynamic batching and pipeline parallel.") logger.info(f"{self.model.get_models()}") - dist.barrier() + _safe_dist_barrier() def get_data_input(self, batch: DataProto): def broadcast_obj(obj, group): @@ -229,6 +282,7 @@ def forward_step( return results def _get_feature_on_this_cp_rank(self, feature: torch.Tensor, feature_name: str = "input_ids") -> torch.Tensor: + """Slice a feature tensor for this Context Parallel rank.""" return self.models_unwrapped[0].get_batch_on_this_cp_rank({feature_name: feature}, dim3_keys=[])[feature_name] def _get_unpad_seqlen(self, attention_mask: torch.Tensor, pad_to_multiple_of: int = 256) -> int: @@ -400,7 +454,19 @@ def _unpack_sequences(self, output_tensor, cu_seqlens_padded): yield local_chunk def inner_forward_step(self, loss_func, data_iterator: Iterator[DataProto], model): + """Single micro-batch forward step called by Megatron's forward_backward_func. + + Multi-LoRA: ``set_adapter`` is called per microbatch because different + microbatches may target different LoRA adapters. + + """ data = next(data_iterator) + # Multi-LoRA: activate the correct adapter for this microbatch before forward. + if self.is_lora: + routing = resolve_microbatch_lora_name(data.non_tensor_batch) + for m in self.models_unwrapped: + m.set_adapter(routing.lora_name) + # get_data_input broadcasts batch.batch to all PP/TP/CP ranks, so tensors are always available. 
input_ids = data.batch["input_ids"] attention_mask = data.batch["attention_mask"] labels = data.batch["labels"] if "labels" in data.batch else None # labels is only used for sft @@ -418,6 +484,9 @@ def inner_forward_step(self, loss_func, data_iterator: Iterator[DataProto], mode attention_mask = self._get_feature_on_this_cp_rank(attention_mask, "attention_mask") if labels is not None: labels = self._get_feature_on_this_cp_rank(labels, "labels") + # Megatron TransformerEngine expects bool attention_mask; some pipelines produce int tensors. + if attention_mask is not None and attention_mask.dtype != torch.bool and not torch.is_floating_point(attention_mask): + attention_mask = attention_mask.bool() position_ids = None # attention_mask: SelfAttention defalt to te DotProductAttention with # AttnMaskType.causal in which attention_mask would not be used, pass @@ -444,7 +513,7 @@ def inner_forward_step(self, loss_func, data_iterator: Iterator[DataProto], mode multi_modal_data[key].append(sample_mm_inputs[key]) for key in multi_modal_data.keys(): assert key not in forward_args - # DataProto.to('cuda') in upper frame not work for non_tensor_batch + # DataProto.to('cuda') in upper frame not work for non_tensor_batch. forward_args[key] = torch.concat(multi_modal_data[key], dim=0).to(input_ids.device) forward_args.update({"force_vit_image": True}) @@ -955,6 +1024,17 @@ def op_compute_language_loss(self, losses: torch.Tensor, labels: torch.Tensor, b return loss, metrics + +@dataclass +class SplitBatchResult: + """Result of splitting a batch into microbatches for training.""" + + microbatches: List[DataProto] + num_microbatches: int + # 1 for dynamic batching / sequence packing; per_device_train_batch_size otherwise. 
+ micro_batch_size: int + + class MegatronTrainStrategy(MegatronInferStrategy, TrainStrategy): strategy_name = "megatron_train" @@ -965,6 +1045,37 @@ def __init__(self, worker: Worker): self.processor = None self._validate_access_integrity = True + # ---------- Versioned Bucket Cache for Selective Sync (Time-Sharing) ---------- + # Design: after each train_step, weights are gathered across PP ranks into CPU + # buckets and stored in a versioned cache. Only one rank (pp0/dp0/tp0/cp0, the + # "cache owner") stores the buckets; other ranks participate in the PP collective + # but discard results. When selective_sync_active_cache is called, the cache + # owner replays the "active" version's buckets to inference workers via CUDA IPC + # (colocated) or NCCL broadcast (remote), avoiding a full model_update cycle. + # + # _latest_cached: version just built (may not yet be promoted) + # _active_cached: version promoted for the next selective sync + # GC policy: keep latest + active; evict everything else. + self._cache_lock = threading.Lock() + self._cache_map: Dict[int, List[Any]] = {} + self._latest_cached: Optional[int] = None + self._active_cached: Optional[int] = None + # weights_meta is computed per-adapter inside _build_latest_bucket_cache() + # so that metadata names match the adapter-specific state dict keys. + # Single global cache owner: pp0/dp0/tp0/cp0 only; set during initialize(). + self._is_cache_owner: bool = False + + # Sender stats for post-sync verification, keyed by cache version. + self._cache_stats: Dict[int, dict] = {} + # Per-adapter versioned cache (multi-LoRA selective sync): same design as base + # cache but keyed by adapter name, so each adapter's LoRA weights can be synced + # independently at different versions. 
+ self._adapter_cache_map: Dict[str, Dict[int, List[Any]]] = {} + self._latest_adapter_cached: Dict[str, Optional[int]] = {} + self._active_adapter_cached: Dict[str, Optional[int]] = {} + # Per-adapter sender stats keyed by (adapter_name, cache_key). + self._adapter_cache_stats: Dict[tuple, dict] = {} + def initialize(self, model_provider): self.seq_length = self.worker.pipeline_config.sequence_length self.weight_updaters: dict[str, MegatronWeightUpdater] = {} @@ -980,6 +1091,122 @@ def initialize(self, model_provider): ) self.forward_backward_func = get_forward_backward_func() self.model.config.finalize_model_grads_func = finalize_model_grads + + # Capture unwrapped models before DDP replaces self.model.models. + self.models_unwrapped = self.model.get_models() + + # LoRA detection: check both explicit adapter configs and the legacy lora_target field. + self.is_lora = (self.worker_config.model_args.adapters is not None) or \ + (getattr(self.worker_config.model_args, "lora_target", None) is not None) + # Multi-adapter discriminator: True only for RLix multi-adapter LoRA configs. + # Legacy single-LoRA (lora_target only, no adapters dict) uses train_step + shared optimizer. + self.has_multi_adapter = self.worker_config.model_args.adapters is not None and len(self.worker_config.model_args.adapters) > 1 + + # --- Config validation: reject incompatible configs before DDP wrapping --- + + # Read boolean flag; defaults to False when absent. + self.is_lora_optimizer_isolated: bool = bool( + self.worker_config.strategy_args.strategy_config.get("is_lora_optimizer_isolated", False) + if self.worker_config.strategy_args and self.worker_config.strategy_args.strategy_config + else False + ) + # Multi-adapter requires isolated optimizers — one per adapter. + if self.has_multi_adapter and not self.is_lora_optimizer_isolated: + raise ValueError( + "model_args.adapters is configured but is_lora_optimizer_isolated is not set. " + "Set strategy_config.is_lora_optimizer_isolated=true." 
+ ) + + if self.is_lora_optimizer_isolated: + if self.megatron_train_args.use_distributed_optimizer: + raise ValueError( + "Isolated multi-adapter LoRA requires use_distributed_optimizer=False. " + "Distributed optimizer shards state across ranks, which conflicts " + "with per-adapter optimizer isolation." + ) + if self.megatron_train_args.overlap_grad_reduce: + raise ValueError( + "Isolated multi-adapter LoRA requires overlap_grad_reduce=False. " + "With overlap_grad_reduce=True, idle adapters' DDP backward hooks " + "never fire during another adapter's sequential pass, causing a " + "hang in finish_grad_sync()." + ) + if getattr(self.worker_config.model_args, "model_type", None) == "trl": + raise ValueError( + "Isolated multi-adapter LoRA does not support TRL value-head models " + "(model_type='trl'). Disable value head." + ) + + # --- Model-structure validation: needs instantiated model, not DDP --- + # When is_lora_optimizer_isolated=True, each adapter has its own optimizer. + # This requires every trainable parameter to belong to exactly one adapter. + # Shared trainable parameters (e.g., a value head not scoped to any adapter) + # would receive gradient updates from multiple optimizers, corrupting state. + # + # Example of VALID param names (adapter-scoped): + # "layers.0.self_attn.q_proj.lora_A.adapter_A.weight" + # "layers.0.self_attn.q_proj.lora_B.adapter_B.weight" + # + # Example of INVALID shared trainable (would cause error): + # "v_head.weight" # not scoped to any adapter → shared across optimizers + if self.is_lora_optimizer_isolated: + adapter_names = list(self.worker_config.model_args.adapters.keys()) + if not adapter_names: + raise ValueError( + "Multi-adapter LoRA requires at least one adapter in model_args.adapters" + ) + + # Activate all adapters so their LoRA params are marked trainable for inspection. 
+ for model in self.models_unwrapped: + base_model = getattr(model, "base_model", None) + if base_model is not None and hasattr(base_model, "set_adapter"): + base_model.set_adapter(adapter_names) + + # Aggregate params from all chunks with a chunk-index prefix so names are unique. + # Virtual-pipeline chunks each hold different layers; the same local name (e.g. + # "layers.0.weight") can appear in multiple chunks, so the prefix is required. + name_to_param: Dict[str, torch.nn.Parameter] = {} + for chunk_idx, chunk_model in enumerate(self.models_unwrapped): + for param_name, param in chunk_model.named_parameters(): + name_to_param[f"chunk{chunk_idx}.{param_name}"] = param + + original_requires_grad: Dict[str, bool] = { + n: bool(p.requires_grad) for n, p in name_to_param.items() + } + + # Build adapter markers for name-matching. Example: {adapter_A: ".adapter_A.", ...} + markers = {adapter_name: f".{adapter_name}." for adapter_name in adapter_names} + + # Find shared trainables: params that are trainable but not scoped to any adapter. + # A param is adapter-scoped if its name contains one of the markers (e.g., ".adapter_A.") + shared_trainables: List[str] = [] + for name, param in name_to_param.items(): + if not original_requires_grad[name]: + # Skip frozen params — they don't participate in optimizer updates. + continue + if not any(marker in name for marker in markers.values()): + # Trainable but not adapter-scoped → shared across all adapters. + shared_trainables.append(name) + + if shared_trainables: + preview = ", ".join(repr(n) for n in shared_trainables[:10]) + likely_value_head = any( + ("v_head" in n or "value_head" in n) for n in shared_trainables + ) + hint = ( + " This looks like a value head / TRL wrapper. Set model_type: ~ to disable." + if likely_value_head + else "" + ) + raise ValueError( + "Multi-adapter LoRA requires all trainable parameters to be " + f"adapter-scoped (name must include one of: {sorted(markers.values())}). 
" + f"Found shared trainables (first 10): {preview}. " + "Freeze these parameters to use per-adapter optimizer mode." + + hint + ) + + # --- DDP wrapping: all config and model-structure checks passed --- ddp_config = DistributedDataParallelConfig( grad_reduce_in_fp32=self.megatron_train_args.accumulate_allreduce_grads_in_fp32, overlap_grad_reduce=self.megatron_train_args.overlap_grad_reduce, @@ -996,9 +1223,8 @@ def initialize(self, model_provider): # model chunks is overlapped with compute anyway. disable_bucketing=(model_index > 0), ) - for model_index, m in enumerate(self.model.get_models()) + for model_index, m in enumerate(self.models_unwrapped) ] - self.models_unwrapped = self.model.get_models() self.model.models = self.models_wrapped params_dtype = ( @@ -1006,6 +1232,7 @@ def initialize(self, model_provider): if self.megatron_train_args.fp16 else torch.bfloat16 if self.megatron_train_args.bf16 else torch.float32 ) + optimizer_config = OptimizerConfig( optimizer=self.megatron_train_args.optimizer, lr=self.megatron_train_args.learning_rate, @@ -1020,11 +1247,107 @@ def initialize(self, model_provider): use_distributed_optimizer=self.megatron_train_args.use_distributed_optimizer, clip_grad=self.megatron_train_args.max_grad_norm, ) - self.optimizer: MegatronOptimizer = get_megatron_optimizer(optimizer_config, self.models_wrapped) - logger.info(f"megatron optimizer: {self.optimizer}") + self.adapter_optimizers: Dict[str, MegatronOptimizer] | None = None + self.adapter_schedulers: Dict[str, Any] | None = None - bind_megatron_offload_states_func(optimizer=self.optimizer) + if not self.has_multi_adapter: + # Non-LoRA or legacy single-LoRA: single optimizer (upstream v0.2.0 path). 
+ self.optimizer: MegatronOptimizer = get_megatron_optimizer(optimizer_config, self.models_wrapped) + logger.info(f"megatron optimizer: {self.optimizer}") + bind_megatron_offload_states_func(optimizer=self.optimizer) + else: + # ---- Isolated mode: one optimizer + scheduler per adapter ---- + # adapter_names, name_to_param, original_requires_grad, markers already + # computed during model-structure validation above. + + def _apply_trainability_mask_for_adapter(active_adapter: str) -> None: + """Freeze all params except this adapter's LoRA weights. + + Used before ``get_megatron_optimizer`` so the optimizer only captures + parameters that belong to ``active_adapter``. The trainability mask + is restored after all per-adapter optimizers are constructed. + """ + marker = markers[active_adapter] + for n, p in name_to_param.items(): + p.requires_grad_(bool(original_requires_grad[n] and (marker in n))) + + self.adapter_optimizers = {} + self.adapter_schedulers = {} + param_id_to_name = {id(p): n for n, p in name_to_param.items()} + seen_param_ids: Set[int] = set() + for adapter_name in adapter_names: + # Activate the current adapter on every chunk so PEFT routes forward + # correctly; chunk 0 alone is not sufficient for virtual-pipeline models. + for chunk_model in self.models_unwrapped: + chunk_model.set_adapter(adapter_name) + _apply_trainability_mask_for_adapter(adapter_name) + adapter_opt = get_megatron_optimizer(optimizer_config, self.models_wrapped) + # bind_megatron_offload_states_func is deferred to the ChainedOptimizer + # call below (line ~1306), which recursively binds all sub-optimizers. + + # Assert optimizer param ownership is isolated to this adapter. 
+ marker = markers[adapter_name] + for group in getattr(adapter_opt, "param_groups", []): + for param in group.get("params", []): + pid = id(param) + pname = param_id_to_name.get(pid) + if pname is None: + # Megatron optimizers may create FP32 "main params" (new Parameter + # objects) for FP16/BF16 model params. Those parameters are not + # present in model.named_parameters(), so we cannot verify their + # adapter ownership here. + continue + if marker not in pname: + raise RuntimeError( + f"Per-adapter optimizer for {adapter_name!r} captured " + f"non-scoped param {pname!r}" + ) + if pid in seen_param_ids: + raise RuntimeError( + f"Parameter {pname!r} appears in multiple per-adapter optimizers; " + "expected disjoint param sets" + ) + seen_param_ids.add(pid) + + self.adapter_optimizers[adapter_name] = adapter_opt + self.adapter_schedulers[adapter_name] = get_megatron_lr_scheduler( + self.megatron_train_args, + self.megatron_train_args.max_steps, + optimizer=adapter_opt, + ) + + # Restore original trainability. + for n, p in name_to_param.items(): + p.requires_grad_(original_requires_grad[n]) + + # ChainedOptimizer wraps all per-adapter optimizers so that generic + # offload/reload/state_dict calls (which expect a single self.optimizer) + # fan out to every adapter optimizer transparently. + # Tradeoff: all-or-nothing handling means all adapters are reloaded/offloaded together, + # even when train_step_lora() only trains one adapter at a time. + # fixme(tao) HACK can we do lora granular swap of optimizer? + # Each sub-optimizer already has reload_states/offload_states bound by + # bind_megatron_offload_states_func, so adapter_optimizers[adapter_name].reload_states() + # would work mechanically — but train_step_lora still calls self.load_states()/ + # self.offload_states() which go through ChainedOptimizer and swap all adapters. 
+ from megatron.core.optimizer import ChainedOptimizer + self.optimizer = ChainedOptimizer(list(self.adapter_optimizers.values())) + bind_megatron_offload_states_func(optimizer=self.optimizer) + + # Initialize per-adapter RNG states for sequential training (plan item 15). + # Each adapter starts from the current global RNG state; they diverge as training progresses. + # Includes Megatron TP CUDA RNG tracker for deterministic TP-parallel dropout per adapter. + self.adapter_rng_states: Dict[str, Dict[str, Any]] = { + name: { + "cpu": torch.get_rng_state(), + "cuda": torch.cuda.get_rng_state(), + "python": random.getstate(), + "numpy": np.random.get_state(), + "rng_tracker_states": tensor_parallel.get_cuda_rng_tracker().get_states(), + } + for name in adapter_names + } self.worker.rank_info.dp_rank = mpu.get_data_parallel_rank() self.worker.rank_info.dp_size = mpu.get_data_parallel_world_size() @@ -1035,6 +1358,14 @@ def initialize(self, model_provider): self.worker.rank_info.cp_size = mpu.get_context_parallel_world_size() self.worker.rank_info.cp_rank = mpu.get_context_parallel_rank() + # Single global cache owner: the unique rank with all parallel dimensions at 0. + self._is_cache_owner = ( + mpu.get_pipeline_model_parallel_rank() == 0 + and mpu.get_data_parallel_rank() == 0 + and mpu.get_tensor_model_parallel_rank() == 0 + and mpu.get_context_parallel_rank() == 0 + ) + logger.info(f"max steps pipeline {self.worker_config.training_args.max_steps}") self.worker_config.training_args.max_steps = ( self.worker_config.training_args.max_steps // self.worker.rank_info.dp_size @@ -1042,6 +1373,18 @@ def initialize(self, model_provider): self.megatron_train_args.max_steps = self.worker_config.training_args.max_steps logger.info(f"max steps worker train {self.worker_config.training_args.max_steps}") + # Per-adapter schedulers must use DP-adjusted max_steps. They were initially + # created before dp_size was known, so rebuild here with the final step budget. 
+ if self.has_multi_adapter and self.adapter_optimizers: + self.adapter_schedulers = { + adapter_name: get_megatron_lr_scheduler( + self.megatron_train_args, + self.megatron_train_args.max_steps, + optimizer=adapter_opt, + ) + for adapter_name, adapter_opt in self.adapter_optimizers.items() + } + self.scheduler = get_megatron_lr_scheduler( self.megatron_train_args, self.megatron_train_args.max_steps, optimizer=self.optimizer ) @@ -1073,53 +1416,35 @@ def initialize(self, model_provider): logger.info("Set variable_seq_lengths to True when use dynamic batching and pipeline parallel.") logger.info(f"{self.model.get_models()}") - dist.barrier() + _safe_dist_barrier() def train_step(self, batch: DataProto, loss_func: Callable): self.model.train() global_step = batch.meta_info.get("global_step", 0) - is_offload_optimizer_states_in_train_step = batch.meta_info.get("is_offload_optimizer_states_in_train_step", True) - batch.meta_info['batch_num_tokens'] = self._get_batch_num_tokens(batch, dp_group=mpu.get_data_parallel_group()) - batch.meta_info['global_valid_samples'] = self._get_global_valid_samples(batch, dp_group=mpu.get_data_parallel_group()) + is_offload_optimizer_states_in_train_step = batch.meta_info.get( + "is_offload_optimizer_states_in_train_step", True + ) - if self.worker_config.use_dynamic_batching_in_train: - micro_batches_list = list(make_micro_batch_iter_for_dynamic_batching(batch)) - num_microbatches = batch.meta_info["num_micro_batchs"] - mini_batch_size = 1 - elif self.use_sequence_packing: - vp_size = self.worker_config.strategy_args.strategy_config['virtual_pipeline_model_parallel_size']\ - if 'virtual_pipeline_model_parallel_size' in self.worker_config.strategy_args.strategy_config else 1 - micro_batches_list = list(make_micro_batch_iter_for_sequence_packing(batch, tp_size=self.worker.rank_info.tp_size, - cp_size=self.worker.rank_info.cp_size, - vp_size=vp_size, is_train=True, - dp_group=mpu.get_data_parallel_group(with_context_parallel=True), - 
micro_batch_size=self.worker_config.training_args.per_device_train_batch_size, - config=self.worker_config.sequence_packing_args)) - num_microbatches = micro_batches_list[0].meta_info["num_micro_batchs"] - mini_batch_size = 1 - else: - mini_batch_size = self.worker_config.training_args.per_device_train_batch_size - num_microbatches = batch.batch.batch_size[0] // self.worker_config.training_args.per_device_train_batch_size - assert ( - num_microbatches == self.megatron_train_args.gradient_accumulation_steps - ), f"num_microbatches={num_microbatches} gradient_accumulation_steps={self.megatron_train_args.gradient_accumulation_steps}" - micro_batches_list = batch.chunk(chunks=num_microbatches) + # Shared: populate batch-level metadata. + self._ensure_train_batch_meta(batch) - for micro_batch in micro_batches_list: - micro_batch.meta_info['loss_scale'] = num_microbatches * mpu.get_data_parallel_world_size() - micro_batch.meta_info['micro_batch_size'] = micro_batch.batch.batch_size[0] + # Shared: split batch into microbatches. + split = self._split_batch_to_microbatches(batch) - data_iterator = [iter(micro_batches_list) for _ in range(len(self.model))] + # Shared: stamp loss_scale, micro_batch_size, batch_num_tokens, global_valid_samples. + self._annotate_microbatches_for_train( + split.microbatches, split.num_microbatches, batch.meta_info + ) - metrics_tensors: List[Dict[str, "torch.Tensor"]] = self.forward_backward_func( - forward_step_func=partial(self.inner_forward_step, loss_func), - data_iterator=data_iterator, - model=self.model.get_models(), - num_microbatches=num_microbatches, + # Shared: forward/backward passes. + # train_step always uses self.seq_length, even for sequence packing (current RLIX behavior). 
+ metrics = self._run_forward_backward( + microbatches=split.microbatches, + loss_func=loss_func, + num_microbatches=split.num_microbatches, + micro_batch_size=split.micro_batch_size, seq_length=self.seq_length, - micro_batch_size=mini_batch_size, - forward_only=False, ) # 只有step的时候需要load optimizer states @@ -1134,23 +1459,179 @@ def train_step(self, batch: DataProto, loss_func: Callable): else: raise NotImplementedError("megatron optimizer step failed!") + # Shared: zero grad buffers and optimizer state, then clear stale bucket caches. + self._zero_grad() + self._clear_bucket_caches() + + metrics.update({self.worker_config.name + "/" + "grad_norm": grad_norm}) + self._collect_auxiliary_loss_metrics(metrics) + + # Time-sharing: build a versioned bucket cache of the current weights. + # Promotion is NOT done here — the RLix pipeline calls + # promote_active_checkpoint explicitly after train_step to control which + # version is broadcast via selective_sync_active_cache. + if DO_TIME_SHARING: + checkpoint_version = int(batch.meta_info["checkpoint_version"]) + self._build_latest_bucket_cache(checkpoint_version=checkpoint_version) + return metrics + + + # ------------------------------------------------------------------ + # Shared helpers extracted from train_step (Changes 2-6) + # ------------------------------------------------------------------ + def _zero_grad(self) -> None: + """Zero Megatron DDP grad buffers and optimizer grad state.""" for model in self.model: model.zero_grad_buffer() - # Offload/reload does not update cached_param_buffer_shard_list/cached_grad_buffer_shard_list, - # resulting using old params in `start_param_sync`, which leads to wrong results. So we clear the cache. 
- for bucket_group in model.bucket_groups + model.expert_parallel_bucket_groups: - if hasattr(bucket_group, "cached_param_buffer_shard_list"): - bucket_group.cached_param_buffer_shard_list = [None] * len(bucket_group.buckets) - if hasattr(bucket_group, "cached_grad_buffer_shard_list"): - bucket_group.cached_grad_buffer_shard_list = [None] * len(bucket_group.buckets) self.optimizer.zero_grad() - metrics = {} + def _ensure_train_batch_meta(self, batch: DataProto) -> None: + """Populate batch_num_tokens and global_valid_samples on batch.meta_info. + + Uses direct assignment matching train_step baseline. + DataProto.chunk()/make_iterator() share the same meta_info dict reference + across microbatches, so setdefault would preserve stale values from a + previous mini-batch iteration. + """ + if batch.meta_info is None: + batch.meta_info = {} + batch.meta_info['batch_num_tokens'] = self._get_batch_num_tokens( + batch, dp_group=mpu.get_data_parallel_group() + ) + batch.meta_info['global_valid_samples'] = self._get_global_valid_samples( + batch, dp_group=mpu.get_data_parallel_group() + ) + + def _split_batch_to_microbatches( + self, + batch: DataProto, + ) -> SplitBatchResult: + """Split a DataProto batch into microbatches for training. + + Three splitting strategies, selected by worker config: + - Dynamic batching: variable-length microbatches via make_micro_batch_iter_for_dynamic_batching. + - Sequence packing: load-balanced packed partitions via make_micro_batch_iter_for_sequence_packing. + - Standard: equal-size chunks by per_device_train_batch_size, with + num_microbatches == gradient_accumulation_steps assertion. + """ + if self.worker_config.use_dynamic_batching_in_train: + # Fail fast if upstream caller did not run dynamic_batching_shard() to prepare + # required batch metadata. See dynamic_batching.py:118. 
+ if not batch.meta_info or "micro_batch_indices" not in batch.meta_info: + raise RuntimeError( + "use_dynamic_batching_in_train requires batch metadata from " + "dynamic_batching_shard(). Ensure the pipeline calls " + "dynamic_batching_shard() before train_step/train_step_lora." + ) + microbatches = list(make_micro_batch_iter_for_dynamic_batching(batch)) + num_microbatches = batch.meta_info["num_micro_batchs"] + return SplitBatchResult( + microbatches=microbatches, + num_microbatches=num_microbatches, + micro_batch_size=1, + ) + + if self.use_sequence_packing: + vp_size = self.worker_config.strategy_args.strategy_config.get( + "virtual_pipeline_model_parallel_size", 1 + ) + microbatches = list( + make_micro_batch_iter_for_sequence_packing( + batch, + tp_size=self.worker.rank_info.tp_size, + cp_size=self.worker.rank_info.cp_size, + vp_size=vp_size, + is_train=True, + dp_group=mpu.get_data_parallel_group(with_context_parallel=True), + micro_batch_size=self.worker_config.training_args.per_device_train_batch_size, + config=self.worker_config.sequence_packing_args, + ) + ) + num_microbatches = microbatches[0].meta_info["num_micro_batchs"] + return SplitBatchResult( + microbatches=microbatches, + num_microbatches=num_microbatches, + micro_batch_size=1, + ) + + # Standard path: equal chunks by per_device_train_batch_size. 
+ per_device_batch_size = self.worker_config.training_args.per_device_train_batch_size + total_batch_size = batch.batch.batch_size[0] + num_microbatches = total_batch_size // per_device_batch_size + assert num_microbatches == self.megatron_train_args.gradient_accumulation_steps, ( + f"num_microbatches={num_microbatches} " + f"gradient_accumulation_steps={self.megatron_train_args.gradient_accumulation_steps}" + ) + microbatches = batch.chunk(chunks=num_microbatches) + return SplitBatchResult( + microbatches=microbatches, + num_microbatches=num_microbatches, + micro_batch_size=per_device_batch_size, + ) + + def _annotate_microbatches_for_train( + self, + microbatches: List[DataProto], + num_microbatches: int, + batch_meta: Dict[str, Any], + ) -> None: + """Stamp loss_scale, micro_batch_size, and batch-level metadata on each microbatch. + + loss_scale = num_microbatches * dp_world_size. This is the standard train_step + convention — inner_forward_step multiplies loss by this value to normalize + gradient accumulation across microbatches and data parallel ranks. + """ + for micro_batch in microbatches: + if micro_batch.meta_info is None: + micro_batch.meta_info = {} + # Direct assignment for loss_scale and micro_batch_size, matching train_step + # baseline. These are always set fresh by the training step. + micro_batch.meta_info['loss_scale'] = num_microbatches * mpu.get_data_parallel_world_size() + micro_batch.meta_info['micro_batch_size'] = micro_batch.batch.batch_size[0] + # setdefault for batch-level metadata that may already be populated. 
+ micro_batch.meta_info.setdefault("batch_num_tokens", batch_meta.get("batch_num_tokens")) + micro_batch.meta_info.setdefault("global_valid_samples", batch_meta.get("global_valid_samples")) + + def _run_forward_backward( + self, + microbatches: List[DataProto], + loss_func: Callable, + num_microbatches: int, + micro_batch_size: int, + seq_length: int, + ) -> Dict[str, Any]: + """Run forward/backward passes on explicit microbatch list. Does NOT step optimizer. + + Builds the data_iterator from the provided microbatch list and calls + forward_backward_func. Does not re-split — the microbatch list is used as-is, + preserving packed partition boundaries for sequence packing. + + Loss scaling is handled by _annotate_microbatches_for_train which stamps + loss_scale = num_microbatches * dp_world_size on each microbatch. The + inner_forward_step loss_wrapper applies this scale. + """ + data_iterator = [iter(microbatches) for _ in range(len(self.model))] + metrics_tensors: List[Dict[str, "torch.Tensor"]] = self.forward_backward_func( + forward_step_func=partial(self.inner_forward_step, loss_func), + data_iterator=data_iterator, + model=self.model.get_models(), + num_microbatches=num_microbatches, + seq_length=seq_length, + micro_batch_size=micro_batch_size, + forward_only=False, + ) + + metrics: Dict[str, Any] = {} for mini_metrics in metrics_tensors: append_to_dict(metrics, mini_metrics) + return metrics - metrics.update({self.worker_config.name + "/" + "grad_norm": grad_norm}) + def _collect_auxiliary_loss_metrics(self, metrics: Dict[str, Any]) -> None: + """Collect MoE and MTP auxiliary loss metrics after a training step. + Called by both train_step and train_step_lora to ensure auxiliary losses + are always reported regardless of training path. 
+ """ if self.model.config.num_moe_experts is not None and self.model.config.num_moe_experts > 1: reduce_aux_losses_tracker_across_ranks() tracker = get_moe_layer_wise_logging_tracker() @@ -1163,48 +1644,826 @@ def train_step(self, batch: DataProto, loss_func: Callable): metrics.update(moe_losses) if self.model.config.mtp_num_layers is not None and self.model.config.mtp_num_layers > 0: - mtp_total_loss_dict = {} + mtp_total_loss_dict: Dict[str, Any] = {} MTPLossLoggingHelper.reduce_loss_in_tracker() tracker = MTPLossLoggingHelper.tracker if "values" in tracker: loss_scale = 1 / self.megatron_train_args.gradient_accumulation_steps mtp_losses = tracker["values"] * loss_scale mtp_num_layers = mtp_losses.shape[0] - for i in range(mtp_num_layers): - name = self.worker_config.name + "/" + f"mtp_{i+1} loss" - mtp_total_loss_dict[name] = mtp_losses[i].item() + for layer_idx in range(mtp_num_layers): + name = self.worker_config.name + "/" + f"mtp_{layer_idx+1} loss" + mtp_total_loss_dict[name] = mtp_losses[layer_idx].item() MTPLossLoggingHelper.clean_loss_in_tracker() metrics.update(mtp_total_loss_dict) + + def _clear_bucket_caches(self) -> None: + """Clear cached param/grad buffer shard lists after optimizer step. + + Offload/reload does not update these caches, so stale params in + start_param_sync would lead to wrong results. + """ + for model in self.model: + for bucket_group in model.bucket_groups + model.expert_parallel_bucket_groups: + if hasattr(bucket_group, "cached_param_buffer_shard_list"): + bucket_group.cached_param_buffer_shard_list = [None] * len(bucket_group.buckets) + if hasattr(bucket_group, "cached_grad_buffer_shard_list"): + bucket_group.cached_grad_buffer_shard_list = [None] * len(bucket_group.buckets) + + def train_step_lora(self, batch: DataProto, loss_func: Callable) -> dict: + """Single-adapter-per-call LoRA training step. + + Callers guarantee exactly one adapter per call. 
The adapter's per-adapter + optimizer and scheduler are stepped independently. + """ + self.model.train() + + if not self.is_lora_optimizer_isolated: + raise RuntimeError( + "train_step_lora requires model_args.adapters. " + "Legacy (lora_target only) should use train_step." + ) + + if self.adapter_optimizers is None or self.adapter_schedulers is None: + raise RuntimeError( + "train_step_lora requires adapter_optimizers/adapter_schedulers " + "to be initialized" + ) + + # Shared: populate batch-level metadata. + self._ensure_train_batch_meta(batch) + + # Shared: split batch into microbatches (same contract as train_step). + split = self._split_batch_to_microbatches(batch) + microbatches = split.microbatches + + # Shared: stamp loss_scale, micro_batch_size, batch_num_tokens, global_valid_samples. + # loss_scale = num_microbatches * dp_world_size, matching train_step semantics. + self._annotate_microbatches_for_train( + microbatches, split.num_microbatches, batch.meta_info + ) + + # LoRA-specific: resolve adapter name from non_tensor_batch. + # resolve_microbatch_lora_name validates homogeneity within each microbatch. + # All callers set lora_name via non_tensor_batch (pipeline routing). + adapter_name = resolve_microbatch_lora_name(microbatches[0].non_tensor_batch).lora_name + # Validate all microbatches target the same adapter (single-adapter-per-call contract). 
+ for mb_idx, mb in enumerate(microbatches[1:], start=1): + mb_adapter = resolve_microbatch_lora_name(mb.non_tensor_batch).lora_name + if mb_adapter != adapter_name: + raise ValueError( + f"train_step_lora expects single adapter per call, but microbatch[{mb_idx}] " + f"has adapter={mb_adapter!r}, expected {adapter_name!r}" + ) + + is_offload_optimizer_states_in_train_step = bool( + batch.meta_info.get("is_offload_optimizer_states_in_train_step", True) + ) + + opt = self.adapter_optimizers.get(adapter_name) + sch = self.adapter_schedulers.get(adapter_name) + if opt is None or sch is None: + raise RuntimeError(f"Missing optimizer/scheduler for adapter {adapter_name!r}") + + # LoRA-specific: restore adapter RNG state (including TP CUDA RNG tracker for dropout). + self.load_states(include=[OffloadStateType.optimizer_states]) + rng = self.adapter_rng_states[adapter_name] + torch.set_rng_state(rng["cpu"]) + torch.cuda.set_rng_state(rng["cuda"]) + random.setstate(rng["python"]) + np.random.set_state(rng["numpy"]) + tensor_parallel.get_cuda_rng_tracker().set_states(rng["rng_tracker_states"]) + + # Shared: forward/backward passes (same call signature as train_step). + metrics = self._run_forward_backward( + microbatches=microbatches, + loss_func=loss_func, + num_microbatches=split.num_microbatches, + micro_batch_size=split.micro_batch_size, + seq_length=self.seq_length, + ) + + # LoRA-specific: save adapter RNG state (including TP CUDA RNG tracker for dropout). + self.adapter_rng_states[adapter_name] = { + "cpu": torch.get_rng_state(), + "cuda": torch.cuda.get_rng_state(), + "python": random.getstate(), + "numpy": np.random.get_state(), + "rng_tracker_states": tensor_parallel.get_cuda_rng_tracker().get_states(), + } + + # LoRA-specific: per-adapter optimizer step. 
+ update_successful, grad_norm, _ = opt.step() + if update_successful: + sch.step() + else: + raise NotImplementedError("megatron optimizer step failed!") + + # Shared: zero grad buffers and optimizer state, then clear stale bucket caches. + self._zero_grad() + self._clear_bucket_caches() + + # Time-sharing: build per-adapter bucket cache while GPU weights are still resident. + # Must run before offload_states moves weights to CPU. + # Promotion is NOT done here — the pipeline calls promote methods explicitly. + if DO_TIME_SHARING: + checkpoint_version = int(batch.meta_info["checkpoint_version"]) + self._build_latest_bucket_cache( + checkpoint_version=checkpoint_version, + adapter_name=adapter_name, + ) + + metrics.update({ + f"{self.worker_config.name}/{adapter_name}/grad_norm": grad_norm, + }) + self._collect_auxiliary_loss_metrics(metrics) + + if is_offload_optimizer_states_in_train_step: + self.offload_states(include=[OffloadStateType.optimizer_states], non_blocking=True) + + # Restore all adapters active (PEFT sometimes expects list of active adapters). + active_adapters = list(self.worker_config.model_args.adapters.keys()) + for model in self.models_unwrapped: + model.base_model.set_adapter(active_adapters) + return metrics - def model_update(self, model_update_name: str): - return self.weight_updaters[model_update_name].model_update() + def model_update(self, model_update_name: str, adapters_to_update: list[str] | None = None): + # Forward optional adapter subset to weight updater for multi-LoRA selective sync. + return self.weight_updaters[model_update_name].model_update(adapters_to_update=adapters_to_update) + + + def get_lora_tensors(self, adapter_name: str) -> Dict[str, torch.Tensor]: + """Return a CPU copy of all LoRA parameter tensors for *adapter_name*. + + Reads parameters from models_unwrapped[0] (TP/DP ranks share identical + LoRA weights, so rank 0 is sufficient). 
+ + Note: used only by integration tests for weight inspection and snapshot + comparison. Not called in any production pipeline. + """ + if not self.is_lora: + raise RuntimeError( + "get_lora_tensors called but LoRA is not enabled for this strategy." + ) + marker = f".{adapter_name}." + tensors: Dict[str, torch.Tensor] = {} + for name, param in self.models_unwrapped[0].named_parameters(): + if "lora_" not in name: + continue + if marker not in name: + continue + tensors[name] = param.detach().cpu().clone() + if not tensors: + raise RuntimeError( + f"No LoRA tensors found for adapter {adapter_name!r}; check adapter naming." + ) + return tensors + + def set_lora_tensors( + self, *, adapter_name: str, tensors: Dict[str, torch.Tensor] + ) -> int: + """Overwrite the LoRA parameters for *adapter_name* with *tensors* (in-place). + + Also refreshes the optimizer's FP32 main-param copies via + ``optimizer.reload_model_params()`` so the next step starts from the + updated weights, not stale copies. + + Note: used only by integration tests to reset adapter weights to a known + state before a reference run. Not called in any production pipeline. + """ + if not self.is_lora: + raise RuntimeError( + "set_lora_tensors called but LoRA is not enabled for this strategy." + ) + marker = f".{adapter_name}." 
+ name_to_param = dict(self.models_unwrapped[0].named_parameters()) + copied = 0 + for name, value in tensors.items(): + if "lora_" not in name: + continue + if marker not in name: + continue + if name not in name_to_param: + raise KeyError( + f"Unknown LoRA param name {name!r} when setting adapter {adapter_name!r}" + ) + param = name_to_param[name] + src = value.detach() + if src.device != param.device or src.dtype != param.dtype: + src = src.to(device=param.device, dtype=param.dtype) + param.data.copy_(src) + copied += 1 + copied_total = copied + if dist.is_initialized(): + copied_total_tensor = torch.tensor([copied], dtype=torch.int64, device=current_platform.current_device()) + dist.all_reduce(copied_total_tensor, op=dist.ReduceOp.SUM) + copied_total = int(copied_total_tensor.item()) + if copied_total == 0: + raise RuntimeError( + f"No LoRA tensors applied for adapter {adapter_name!r}; " + "check naming and tensor keys." + ) + + # Sync BF16 model params → FP32 main params. + # Megatron's mixed-precision optimizers keep a separate FP32 "main params" copy of + # BF16/FP16 model weights and use it as the authoritative source in optimizer.step(). + # We just mutated the BF16 side directly (bypassing the optimizer), so push those + # changes into FP32 now — otherwise the next step() would overwrite our writes. + self.optimizer.reload_model_params() + return copied + + def copy_lora_params(self, *, src_adapter: str, dst_adapter: str) -> int: + """Copy LoRA parameters in-place from *src_adapter* to *dst_adapter*. + + Matches source parameter names to destination names by substituting the + adapter marker (``..`` → ``..``) and raises + ``KeyError`` if the expected destination parameter does not exist. + + Note: used only by integration tests to synchronize all adapters to the + same initial weights. Not called in any production pipeline. + """ + if not self.is_lora: + raise RuntimeError( + "copy_lora_params called but LoRA is not enabled for this strategy." 
+ ) + src_marker = f".{src_adapter}." + dst_marker = f".{dst_adapter}." + name_to_param = dict(self.models_unwrapped[0].named_parameters()) + copied = 0 + for name, param in name_to_param.items(): + if "lora_" not in name: + continue + if src_marker not in name: + continue + dst_name = name.replace(src_marker, dst_marker) + if dst_name not in name_to_param: + raise KeyError( + f"Expected destination param {dst_name!r} for source {name!r}" + ) + name_to_param[dst_name].data.copy_(param.data) + copied += 1 + if copied == 0: + raise RuntimeError( + "No LoRA parameters copied; check adapter naming and parameter patterns." + ) + + # Sync BF16 model params → FP32 main params (same reason as set_lora_tensors). + self.optimizer.reload_model_params() + return copied + + def _build_latest_bucket_cache( + self, *, checkpoint_version: int, adapter_name: Optional[str] = None + ) -> None: + """Gather current model weights across PP ranks and store as CPU buckets. + + All PP ranks must participate in ``gather_all_hf_weights`` (it uses PP + collectives internally). Only the cache owner (pp0/dp0/tp0/cp0) stores + the resulting buckets; non-owners drain the generator to keep the + collective moving but discard results. + + When ``adapter_name`` is given, only that adapter's LoRA weights are + cached (stored in ``_adapter_cache_map``); otherwise base weights are + cached in ``_cache_map``. + """ + buffer_size = int(self.worker.pipeline_config.model_update_buffer_size_mb) * 1024 * 1024 + cache_key = int(checkpoint_version) + + with self._cache_lock: + # Compute weights_meta with the actual adapter_name so metadata names match + # the state dict keys used by gather_all_hf_weights (base vs LoRA names). + weights_meta = gather_weights_meta_cross_pp(self.models_unwrapped, adapter_name=adapter_name) + + # All PP ranks must participate in gather_all_hf_weights (PP collective). + # Only the cache owner stores results; non-owners drain and discard each batch. 
+ cached_buckets: List[Any] = [] + # Accumulate sender stats from globally-gathered weights for verification. + # Gated by config flag to skip stats computation when verification is disabled. + compute_stats = self.worker.pipeline_config.verify_model_after_sync + running_sum = 0.0 + running_max = float("-inf") + running_min = float("inf") + batch_count = 0 + for hf_named_weights in gather_all_hf_weights( + self.models_unwrapped, + buffer_size=buffer_size, + weights_meta=weights_meta, + adapter_name=adapter_name, + ): + if not self._is_cache_owner: + # Non-owner must consume the generator element to keep the PP collective moving, + # but does not store anything. + continue + # Compute sender stats on GPU tensors before CPU copy (GPU reductions are + # ~20-40x faster than CPU for large models). + if compute_stats: + batch_stats = compute_weight_stats(hf_named_weights) + if batch_stats: + running_sum += batch_stats["sum"] + running_max = max(running_max, batch_stats["max"]) + running_min = min(running_min, batch_stats["min"]) + batch_count += 1 + # Cache as raw CPU tensors. GPU staging + serialization happens at transport + # time because IPC handles are ephemeral (tied to specific GPU allocations). + cpu_named_weights = [ + (str(name), weight.detach().to("cpu").contiguous()) + for name, weight in hf_named_weights + ] + + bucket, tensors_meta = _bucket_named_tensors(cpu_named_weights) # CPU int8 + cached_buckets.append((tensors_meta, bucket)) + + if not self._is_cache_owner: + return + + # Store sender stats alongside cached buckets for later verification. + sender_stats = {} + if batch_count > 0: + sender_stats = {"sum": running_sum, "max": running_max, "min": running_min} + + if adapter_name is not None: + self._adapter_cache_map.setdefault(adapter_name, {})[cache_key] = cached_buckets + self._latest_adapter_cached[adapter_name] = cache_key + # Store per-adapter stats keyed by (adapter_name, cache_key). 
+ self._adapter_cache_stats[(adapter_name, cache_key)] = sender_stats + else: + self._cache_map[cache_key] = cached_buckets + self._latest_cached = cache_key + self._cache_stats[cache_key] = sender_stats + + def promote_active_checkpoint(self, checkpoint_version: int) -> None: + """Mark a cached version as the "active" snapshot for selective sync. + + The distinction between "latest" and "active" allows a new cache to be + built concurrently while selective_sync_active_cache reads the previous + active version. After promotion, all versions except latest and active + are garbage-collected. + """ + if not DO_TIME_SHARING: + raise RuntimeError("promote_active_checkpoint is only supported under RLix control plane") + # Non-owners hold no cache, so there is nothing to promote. + if not self._is_cache_owner: + return + + cache_key = int(checkpoint_version) + with self._cache_lock: + if cache_key not in self._cache_map: + raise RuntimeError(f"promote_active_checkpoint missing cache_key={cache_key}") + self._active_cached = cache_key + + keep: Set[int] = set() + if self._latest_cached is not None: + keep.add(self._latest_cached) + keep.add(self._active_cached) + + for key in list(self._cache_map.keys()): + if key not in keep: + del self._cache_map[key] + + def promote_active_adapter_checkpoint( + self, adapter_name: str, checkpoint_version: int + ) -> None: + """Same as ``promote_active_checkpoint`` but for a single adapter's LoRA cache.""" + if not DO_TIME_SHARING: + raise RuntimeError("promote_active_adapter_checkpoint is only supported under RLix control plane") + # Non-owners hold no cache, so there is nothing to promote. 
+ if not self._is_cache_owner: + return + cache_key = int(checkpoint_version) + with self._cache_lock: + if cache_key not in self._adapter_cache_map.get(adapter_name, {}): + raise RuntimeError( + f"promote_active_adapter_checkpoint missing cache for adapter={adapter_name!r} key={cache_key}" + ) + self._active_adapter_cached[adapter_name] = cache_key + keep: Set[int] = set() + if self._latest_adapter_cached.get(adapter_name) is not None: + keep.add(self._latest_adapter_cached[adapter_name]) + keep.add(self._active_adapter_cached[adapter_name]) + for key in list(self._adapter_cache_map[adapter_name].keys()): + if key not in keep: + del self._adapter_cache_map[adapter_name][key] + + def selective_sync_active_cache( + self, + *, + tgt_dp_ranks: List[int], + tgt_workers, + tgt_device_mapping: List[int], + tgt_num_gpus_per_worker: int, + comm_plan: Optional[dict] = None, + adapters_to_sync: Optional[List[str]] = None, + ) -> None: + """Replay the active bucket cache to inference workers (time-sharing). + + Transport flow (executed only by the single cache-owner rank): + 1. **Cache lookup**: read the promoted "active" version from ``_cache_map`` + (base weights) and optionally ``_adapter_cache_map`` (per-adapter LoRA). + 2. **Decode comm_plan**: the ``ModelUpdateService`` builds a per-rank plan + specifying IPC targets (colocated workers) and NCCL broadcast targets. + 3. **Transport**: for each cached bucket, stage to GPU once, then: + - IPC path: serialize the GPU tensor via CUDA IPC and push to colocated workers. + - Broadcast path: NCCL broadcast to remote workers. + 4. **LoRA registration**: after adapter weights are transported, call + ``add_lora`` on each target worker to register the adapter with its PEFT config. + 5. **Group teardown**: destroy the temporary NCCL broadcast group. + + Non-owner ranks return immediately; ``ray.get(sync_refs)`` in + ``ModelUpdateService`` provides the cross-worker sync barrier. 
+ """ + if not DO_TIME_SHARING: + raise RuntimeError("selective_sync_active_cache is only supported under RLix control plane") + + tgt_dp_ranks = sorted(set(int(r) for r in tgt_dp_ranks)) + if not tgt_dp_ranks: + raise ValueError("tgt_dp_ranks must be non-empty") + if not tgt_device_mapping: + raise ValueError("tgt_device_mapping must be non-empty") + if not isinstance(tgt_num_gpus_per_worker, int) or int(tgt_num_gpus_per_worker) <= 0: + raise ValueError("tgt_num_gpus_per_worker must be positive int") + if len(tgt_device_mapping) % int(tgt_num_gpus_per_worker) != 0: + raise RuntimeError("tgt_device_mapping length must be divisible by tgt_num_gpus_per_worker") + + world_rank = int(self.worker.rank) + logger.info(f"[rlix][selective_sync_active_cache] enter world_rank={world_rank} is_cache_owner={self._is_cache_owner}") + + # Non-owners have no cache and do no transport. + # ray.get(sync_refs) in ModelUpdateService provides the sync barrier for all train workers. + if not self._is_cache_owner: + return None + + # Owner acquires lock for the entire replay (cache lookup + all transport + group teardown). + # This prevents concurrent promote_active_checkpoint or _build_latest_bucket_cache from + # racing with in-flight transport. + logger.info("[rlix][selective_sync_active_cache] acquiring _cache_lock") + with self._cache_lock: + logger.info("[rlix][selective_sync_active_cache] _cache_lock acquired") + # --- Cache lookup --- + adapter_names_to_register: List[str] = [] + base_cached_buckets: List[Any] = [] + adapter_cached_buckets: Dict[str, List[Any]] = {} + + if adapters_to_sync is not None: + # Sync specified adapters using their active versions. 
+ missing = [a for a in adapters_to_sync if self._active_adapter_cached.get(a) is None] + if missing: + raise RuntimeError(f"selective_sync_active_cache: no active version for adapters {missing}") + adapter_names_to_register = list(dict.fromkeys(str(a) for a in adapters_to_sync)) + if self._active_cached is None: + raise RuntimeError( + "selective_sync_active_cache(is_lora): active base cache is unset; " + "call promote_active_checkpoint first" + ) + if self._active_cached not in self._cache_map: + raise RuntimeError(f"selective_sync_active_cache: base active cache missing key={self._active_cached}") + base_cached_buckets = list(self._cache_map[self._active_cached]) + for a in adapters_to_sync: + key = self._active_adapter_cached[a] + adapter_cached_buckets[a] = list(self._adapter_cache_map[a][key]) + elif self.is_lora: + # adapters_to_sync=None + LoRA mode: sync ALL active adapters (expand path). + active_entries = {a: k for a, k in self._active_adapter_cached.items() if k is not None} + if not active_entries: + raise RuntimeError( + "selective_sync_active_cache(is_lora, adapters_to_sync=None): no active adapter caches promoted yet" + ) + adapter_names_to_register = list(sorted(active_entries.keys())) + if self._active_cached is None: + raise RuntimeError( + "selective_sync_active_cache(is_lora): active base cache is unset; " + "call promote_active_checkpoint first" + ) + if self._active_cached not in self._cache_map: + raise RuntimeError(f"selective_sync_active_cache: base active cache missing key={self._active_cached}") + base_cached_buckets = list(self._cache_map[self._active_cached]) + for a, key in active_entries.items(): + adapter_cached_buckets[a] = list(self._adapter_cache_map[a][key]) + else: + # Full fine-tune path. 
+ if self._active_cached is None: + raise RuntimeError( + "selective_sync_active_cache requires an active promoted cache (active_cached is unset)" + ) + if self._active_cached not in self._cache_map: + raise RuntimeError(f"active_cached={self._active_cached} missing from cache_map") + base_cached_buckets = list(self._cache_map[self._active_cached]) + + # --- Decode comm_plan for the single owner --- + # comm_plan is always non-None for the owner (ModelUpdateService guarantees this). + if comm_plan is None: + raise RuntimeError( + "selective_sync_active_cache: comm_plan must be non-None for the cache owner. " + "ModelUpdateService must always build a comm_plan keyed by the owner's src_rank." + ) + if world_rank not in comm_plan: + raise RuntimeError( + "selective_sync_active_cache comm_plan missing owner rank. " + f"owner_rank={world_rank} keys={sorted(int(k) for k in comm_plan.keys())}" + ) + comm_plan_args = comm_plan[world_rank] + group_name: Optional[str] = str(comm_plan_args["group_name"]) + ipc_targets: List[Dict[str, Any]] = comm_plan_args.get("ipc_targets", []) + broadcast_local_ranks_by_dp_rank: Dict[int, List[int]] = comm_plan_args.get( + "broadcast_local_ranks_by_dp_rank", {} + ) + planned_broadcast_ranks = sorted({int(td["rank"]) for td in comm_plan_args.get("tgt_devices", [])}) + broadcast_workers = [tgt_workers[r] for r in planned_broadcast_ranks] + logger.info( + f"[rlix][selective_sync_active_cache] comm_plan decoded: " + f"group_name={group_name} ipc_targets={len(ipc_targets)} " + f"broadcast_ranks={planned_broadcast_ranks} " + f"base_buckets={len(base_cached_buckets)} is_lora={self.is_lora}" + ) + + def _transport_bucket_sequence( + bucket_sequence: List[Any], + *, + is_lora_stage: bool, + phase_tag: str, + adapter_label: Optional[str] = None, + ) -> None: + """Transport one bucket sequence (base or adapter) to all target workers. 
+ + For each bucket: stage CPU->GPU once, then fan out via IPC to + colocated workers and NCCL broadcast to remote workers. GPU staging + buffer is freed after each bucket to limit peak VRAM. + + When model_update_transport="cpu_serialize", the IPC path serializes the + CPU bucket directly with standard pickle (avoiding CUDA IPC). GPU + staging is skipped when there are no broadcast workers. + """ + transport = self.worker.pipeline_config.model_update_transport + for bucket_idx, (tensors_meta, cpu_bucket) in enumerate(bucket_sequence): + logger.info(f"[rlix][transport] bucket={bucket_idx}/{len(bucket_sequence)} phase={phase_tag} transport={transport}") + + # GPU staging is needed for NCCL broadcast or CUDA IPC serialization. + # With cpu_serialize and no broadcast workers, skip GPU staging entirely. + need_gpu_staging = bool(broadcast_workers) or transport == "cuda_ipc" + gpu_bucket = None + if need_gpu_staging: + gpu_bucket = cpu_bucket.to(current_platform.device_type).contiguous() + logger.info(f"[rlix][transport] bucket={bucket_idx} staged_to_gpu") + + # Transport workflow (IPC + NCCL overlap): + # 1. Fire async: IPC sends to colocated workers (same node) + # 2. Fire async: NCCL broadcasts to remote workers (cross-node, GPU-to-GPU) + # 3. Barrier: wait on all IPC + NCCL to finish + # 4. Free gpu_bucket — safe because all consumers have copied the data + # IPC and NCCL run concurrently to hide transfer latency. + + # Step 1: IPC path — serialize bucket once, then fan out to all colocated workers. + # Payload is identical for every IPC target, so serialize before the loop. + ipc_payload: Optional[bytes] = None + if ipc_targets: + if transport == "cpu_serialize": + # CPU serialization: torch.save for ~1.6x speedup over pickle.dumps + # on large tensors. Avoids CUDA IPC in restricted containers. 
+ buf = io.BytesIO() + torch.save( + {"bucket": cpu_bucket.contiguous(), "tensors_meta": tensors_meta}, buf + ) + ipc_payload = buf.getvalue() + elif transport == "cuda_ipc": + # CUDA IPC: serialize GPU tensor via ForkingPickler. + # Ensure pickle uses GPU UUIDs instead of raw device indices, + # so the receiver resolves the correct local device even when + # CUDA_VISIBLE_DEVICES orderings differ between processes. + monkey_patch_torch_reductions() + ipc_payload = MultiprocessingSerializer.serialize( + {"bucket": gpu_bucket, "tensors_meta": tensors_meta} + ) + else: + raise ValueError( + f"Unsupported model_update_transport: {transport!r}. " + f"Expected 'cuda_ipc' or 'cpu_serialize'." + ) + + ipc_refs: List[ray.ObjectRef] = [] + for ipc_entry in ipc_targets: + tgt_dp_rank = int(ipc_entry["dp_rank"]) + ipc_local_ranks: List[int] = [int(r) for r in ipc_entry["local_ranks"]] + # Build a list long enough to cover all TP ranks (worker indexes by self.rank). + payload_list = [ipc_payload] * tgt_num_gpus_per_worker + ipc_refs.append( + tgt_workers[tgt_dp_rank].update_parameter_in_bucket.remote( + payload_list, + is_lora=is_lora_stage, + ipc_local_ranks=ipc_local_ranks, + model_update_transport=transport, + ) + ) + + # Step 2: NCCL path — broadcast to remote (non-colocated) workers. + # Requires gpu_bucket; only entered when broadcast_workers is non-empty + # (which guarantees gpu_bucket was staged above). 
+ nccl_handles: List[Any] = [] + recv_refs: List[ray.ObjectRef] = [] + named_params: List[Any] = [] + if broadcast_workers and gpu_bucket is not None: + named_params = list(named_tensors_from_bucket(bucket=gpu_bucket, tensors_meta=tensors_meta)) + names = [n for n, _ in named_params] + dtypes = [t.dtype for _, t in named_params] + shapes = [t.shape for _, t in named_params] + + recv_refs = [ + worker.broadcast_parameter.remote( + group_name=group_name, + names=names, + dtypes=dtypes, + shapes=shapes, + is_lora=is_lora_stage, + broadcast_local_ranks=broadcast_local_ranks_by_dp_rank.get( + int(planned_broadcast_ranks[worker_idx]) + ), + ) + for worker_idx, worker in enumerate(broadcast_workers) + ] + + for _, weight in named_params: + nccl_handles.append( + collective.broadcast( + tensor=weight, + src_rank=0, + group_name=group_name, + async_op=True, + ) + ) + + # Step 3+4: barrier — wait for all transfers, then free GPU memory. + logger.info(f"[rlix][transport] bucket={bucket_idx} waiting nccl_handles={len(nccl_handles)} ipc_refs={len(ipc_refs)} recv_refs={len(recv_refs)}") + for nccl_handle in nccl_handles: + nccl_handle.wait() + logger.info(f"[rlix][transport] bucket={bucket_idx} nccl_done, waiting ray.get") + ray.get(ipc_refs + recv_refs) + logger.info(f"[rlix][transport] bucket={bucket_idx} all_done") + del nccl_handles, named_params + if gpu_bucket is not None: + del gpu_bucket + current_platform.empty_cache() + + # --- Transport: base buckets first, then per-adapter buckets --- + _transport_bucket_sequence(base_cached_buckets, is_lora_stage=False, phase_tag="base") + + if self.is_lora and adapter_names_to_register: + peft_configs = getattr(self.models_unwrapped[0], "peft_config", None) or {} + missing_cfg = [a for a in adapter_names_to_register if a not in peft_configs] + if missing_cfg: + raise RuntimeError( + f"selective_sync_active_cache: missing peft_config for adapters {missing_cfg}" + ) + for adapter_label in adapter_names_to_register: + buckets = 
adapter_cached_buckets.get(adapter_label, []) + if not buckets: + raise RuntimeError( + f"selective_sync_active_cache: no cached buckets for adapter={adapter_label!r}; " + "promote_active_adapter_checkpoint must be called before sync" + ) + _transport_bucket_sequence( + buckets, + is_lora_stage=True, + phase_tag="adapter", + adapter_label=adapter_label, + ) + # Compute the union of IPC and broadcast local ranks for this adapter's add_lora call. + # Collect all unique target actors across both paths. + adapter_target_actor_dp_ranks: Set[int] = set() + ipc_local_ranks_by_dp: Dict[int, List[int]] = { + int(entry["dp_rank"]): [int(r) for r in entry["local_ranks"]] + for entry in ipc_targets + } + for entry in ipc_targets: + adapter_target_actor_dp_ranks.add(int(entry["dp_rank"])) + for dp_rank in planned_broadcast_ranks: + adapter_target_actor_dp_ranks.add(int(dp_rank)) + + for dp_rank in sorted(adapter_target_actor_dp_ranks): + ipc_lr = ipc_local_ranks_by_dp.get(dp_rank, []) + broadcast_lr = broadcast_local_ranks_by_dp_rank.get(dp_rank, []) + lora_local_ranks = sorted(set(ipc_lr) | set(broadcast_lr)) or None + ray.get( + tgt_workers[dp_rank].add_lora.remote( + adapter_name=adapter_label, + peft_config=asdict(peft_configs[adapter_label]), + lora_local_ranks=lora_local_ranks, + ) + ) + + # --- Teardown broadcast group once after all replay completes --- + if broadcast_workers: + logger.info(f"[rlix][selective_sync_active_cache] teardown: destroying sender group {group_name}") + collective.destroy_collective_group(group_name) + logger.info(f"[rlix][selective_sync_active_cache] teardown: sender destroyed, destroying receiver groups") + ray.get([w.destroy_collective_group.remote(group_name) for w in broadcast_workers]) + logger.info(f"[rlix][selective_sync_active_cache] teardown: all groups destroyed") + + # Collect sender stats from cached versions for post-sync verification. 
+ weight_stats: dict = {} + if base_cached_buckets: + base_key = self._active_cached + base_stats = self._cache_stats.get(base_key, {}) + if base_stats: + weight_stats["base"] = base_stats + if adapter_cached_buckets: + lora_stats: dict = {} + for adapter_label in adapter_names_to_register: + adapter_key = self._active_adapter_cached.get(adapter_label) + adapter_stats = self._adapter_cache_stats.get((adapter_label, adapter_key), {}) + if adapter_stats: + lora_stats[adapter_label] = adapter_stats + if lora_stats: + weight_stats["lora"] = lora_stats + + # Lock released. No dist.barrier() here: ray.get(sync_refs) in ModelUpdateService + # waits for all train workers to complete before the next sync is allowed. + return {"weight_stats": weight_stats} if weight_stats else None + + def _translate_offload_include( + self, include: Optional[List[OffloadStateType]] + ) -> Tuple[bool, List[MegatronOffloadStateType]]: + """Derive request intent from caller's include arg. + + Returns: + wants_model_params: whether model_params reload/offload is requested. + translated: Megatron-internal state types corresponding to requested states. + When include is None (all states), returns all three types explicitly. 
+ """ + if include is None: + return True, [ + MegatronOffloadStateType.model_params, + MegatronOffloadStateType.other_params, + MegatronOffloadStateType.optimizer_states, + ] + translated: List[MegatronOffloadStateType] = [] + if OffloadStateType.model_params in include: + translated.append(MegatronOffloadStateType.model_params) + if OffloadStateType.other_params in include: + translated.append(MegatronOffloadStateType.other_params) + if OffloadStateType.optimizer_states in include: + translated.append(MegatronOffloadStateType.optimizer_states) + return OffloadStateType.model_params in include, translated def load_states(self, include=None, non_blocking=False): - if include is not None: - include_states = [] - if OffloadStateType.model_params in include: - reload_megatron_no_grad_module(model_chunks=self.model.get_models()) - include_states.append(MegatronOffloadStateType.model_params) - if OffloadStateType.other_params in include: - include_states.append(MegatronOffloadStateType.other_params) - if OffloadStateType.optimizer_states in include: - include_states.append(MegatronOffloadStateType.optimizer_states) - include = include_states - self.optimizer.reload_states(include=include, non_blocking=non_blocking) + """Reload optimizer and model states back to GPU. 
+ + Behavior by caller context: + - isolated + include=None: no-grad swap runs, optimizer gets explicit full list + - non-isolated + include=None: no no-grad swap, optimizer gets raw None + - either + explicit include with model_params: no-grad swap runs, optimizer gets translated list + - either + explicit include without model_params: no no-grad swap + - isolated + include=[]: no no-grad swap, optimizer call skipped + - non-isolated + include=[]: no no-grad swap, optimizer called with [] + """ + wants_model_params, translated_include = self._translate_offload_include(include) + + # Manual no-grad reload needed when: + # - Isolated optimizer: always (optimizer doesn't manage frozen base params) + # - Explicit include with model_params: optimizer gets a filtered list, + # so it won't reload no-grad params on its own + # Skipped only for non-isolated + include=None where the optimizer handles all. + if wants_model_params and (self.is_lora_optimizer_isolated or include is not None): + reload_megatron_no_grad_module(model_chunks=self.model.get_models()) + + # Isolated path: always pass explicit translated list (never raw None). + # Non-isolated path: preserve raw None so optimizer uses its default "all" handling. 
+ if self.is_lora_optimizer_isolated: + if translated_include: + self.optimizer.reload_states(include=translated_include, non_blocking=non_blocking) + else: + optimizer_include = None if include is None else translated_include + self.optimizer.reload_states(include=optimizer_include, non_blocking=non_blocking) def offload_states(self, include=None, non_blocking=False, pin_memory=True): - if include is not None: - include_states = [] - if OffloadStateType.model_params in include: - offload_megatron_no_grad_module(model_chunks=self.model.get_models(), pin_memory=pin_memory) - include_states.append(MegatronOffloadStateType.model_params) - if OffloadStateType.other_params in include: - include_states.append(MegatronOffloadStateType.other_params) - if OffloadStateType.optimizer_states in include: - include_states.append(MegatronOffloadStateType.optimizer_states) - include = include_states - self.optimizer.offload_states(include=include, non_blocking=non_blocking, pin_memory=pin_memory) + """Offload optimizer and model states from GPU. + + Behavior by caller context: + - isolated + include=None: no-grad swap runs, optimizer gets explicit full list + - non-isolated + include=None: no no-grad swap, optimizer gets raw None + - either + explicit include with model_params: no-grad swap runs, optimizer gets translated list + - either + explicit include without model_params: no no-grad swap + - isolated + include=[]: no no-grad swap, optimizer call skipped + - non-isolated + include=[]: no no-grad swap, optimizer called with [] + - rotary cache clear + CUDA cache clear always runs + """ + wants_model_params, translated_include = self._translate_offload_include(include) + + # Same manual no-grad condition as load_states. 
+ if wants_model_params and (self.is_lora_optimizer_isolated or include is not None): + offload_megatron_no_grad_module( + model_chunks=self.model.get_models(), pin_memory=pin_memory, + ) + + if self.is_lora_optimizer_isolated: + if translated_include: + self.optimizer.offload_states( + include=translated_include, non_blocking=non_blocking, pin_memory=pin_memory, + ) + else: + optimizer_include = None if include is None else translated_include + self.optimizer.offload_states( + include=optimizer_include, non_blocking=non_blocking, pin_memory=pin_memory, + ) + + # Unconditional cleanup after offload (both paths, matches current behavior). RotaryEmbedding.forward.cache_clear() current_platform.empty_cache() @@ -1261,16 +2520,34 @@ def save_checkpoint(self, save_dir, global_step, ckpt_id, tag="checkpoint", loca validate_access_integrity=self._validate_access_integrity, ) self._validate_access_integrity = False - elif not dist.is_initialized() or mpu.get_data_modulo_expert_parallel_rank() == 0: + # Compatibility: older Megatron builds do not expose get_data_modulo_expert_parallel_rank(). + # Save optimizer when single-process (no dist) OR when data-parallel rank is 0. + elif (not dist.is_initialized()) or ( + ( + mpu.get_data_modulo_expert_parallel_rank() + if hasattr(mpu, "get_data_modulo_expert_parallel_rank") + else mpu.get_data_parallel_rank(with_context_parallel=False) + ) + == 0 + ): torch.save(self.optimizer.state_dict(), os.path.join(checkpoint_dir, OPTIMIZER_NAME)) logger.info(f"Saving optimizer state to {os.path.join(checkpoint_dir, OPTIMIZER_NAME)}") if dist.is_initialized(): - dist.barrier() + _safe_dist_barrier() - # save lr_scheduler + # save lr_scheduler — isolated mode saves a dict with {"mode": "isolated", + # "schedulers": {adapter_name: state_dict, ...}} so load_checkpoint can restore each + # adapter's LR schedule independently. 
if dist.get_rank() == 0: - torch.save(self.scheduler.state_dict(), os.path.join(save_dir, SCHEDULER_NAME)) + if self.adapter_schedulers is not None: + scheduler_state = { + "mode": "isolated", + "schedulers": {k: v.state_dict() for k, v in self.adapter_schedulers.items()}, + } + else: + scheduler_state = self.scheduler.state_dict() + torch.save(scheduler_state, os.path.join(save_dir, SCHEDULER_NAME)) # save rng state rng_states = { @@ -1280,6 +2557,9 @@ def save_checkpoint(self, save_dir, global_step, ckpt_id, tag="checkpoint", loca "cuda_rng_state": current_platform.get_rng_state(), "rng_tracker_states": tensor_parallel.get_cuda_rng_tracker().get_states(), } + # Per-adapter RNG states enable deterministic per-adapter dropout across checkpoint restarts. + if getattr(self, "adapter_rng_states", None) is not None: + rng_states["adapter_rng_states"] = self.adapter_rng_states rgn_path = os.path.join(save_dir, RNG_STATE_DIR, f"rng_state_{dist.get_rank()}.pth") os.makedirs(os.path.dirname(rgn_path), exist_ok=True) torch.save(rng_states, rgn_path) @@ -1327,7 +2607,23 @@ def load_checkpoint(self, load_dir, tag="checkpoint", **kwargs): self.optimizer.load_state_dict(state_dict) # load lr_scheduler - self.scheduler.load_state_dict(torch.load(os.path.join(load_dir, SCHEDULER_NAME))) + scheduler_state = torch.load(os.path.join(load_dir, SCHEDULER_NAME), weights_only=False) + if isinstance(scheduler_state, dict) and scheduler_state.get("mode") == "isolated": + if self.adapter_schedulers is None: + raise RuntimeError( + "Checkpoint contains shared-mode LoRA scheduler state which is no longer supported. " + "Only per-adapter LoRA checkpoints can be resumed." + ) + for adapter_name, state in scheduler_state["schedulers"].items(): + if adapter_name not in self.adapter_schedulers: + raise RuntimeError( + f"Checkpoint contains scheduler state for adapter {adapter_name!r} " + "but this adapter is not registered in the current strategy." 
+ ) + self.adapter_schedulers[adapter_name].load_state_dict(state) + logger.info(f"Loaded per-adapter scheduler states for: {sorted(scheduler_state['schedulers'].keys())}") + else: + self.scheduler.load_state_dict(scheduler_state) # load model state dict state_dict = load_state_dict_from_checkpoint(load_dir) @@ -1349,6 +2645,9 @@ def load_checkpoint(self, load_dir, tag="checkpoint", **kwargs): if not checkpoint_rng_state["rng_tracker_states"]: raise KeyError tensor_parallel.get_cuda_rng_tracker().set_states(checkpoint_rng_state["rng_tracker_states"]) + if "adapter_rng_states" in checkpoint_rng_state and getattr(self, "adapter_rng_states", None) is not None: + self.adapter_rng_states.update(checkpoint_rng_state["adapter_rng_states"]) + logger.info(f"Loaded adapter RNG states for: {sorted(checkpoint_rng_state['adapter_rng_states'].keys())}") else: logger.info(f"not load rng state, not found file: {rng_file}") diff --git a/roll/distributed/strategy/sglang_strategy.py b/roll/distributed/strategy/sglang_strategy.py index 475e0eb70..757a888fa 100644 --- a/roll/distributed/strategy/sglang_strategy.py +++ b/roll/distributed/strategy/sglang_strategy.py @@ -23,6 +23,8 @@ InitWeightsUpdateGroupReqInput, UpdateWeightsFromTensorReqInput, ) + +from roll.utils.constants import DO_TIME_SHARING from roll.utils.functionals import concatenate_input_and_output from roll.utils.logging import get_logger from roll.utils.network_utils import collect_free_port @@ -109,7 +111,7 @@ async def initialize(self, model_provider): "trust_remote_code": True, "tp_size": tp_size, "log_level": sglang_config.get("log_level", "info"), - "port": 30000 + dp_rank * 500, + "port": self.worker.get_free_port(), # 'disable_cuda_graph': True, "disable_custom_all_reduce": sglang_config.get("disable_custom_all_reduce", True), 'nnodes': nnodes, @@ -118,7 +120,7 @@ async def initialize(self, model_provider): ) if nnodes > 1: - sglang_config['dist_init_addr'] = 
f'{ray.util.get_node_ip_address()}:{collect_free_port()}' + sglang_config['dist_init_addr'] = f'{ray.util.get_node_ip_address()}:{self.worker.get_free_port()}' logger.info(f"[sglang][sglang_config]: {sglang_config}") @@ -378,7 +380,7 @@ async def load_states(self, *args, **kwargs): async def offload_states(self, include=None, non_blocking=False): if include is None or OffloadStateType.model_params in include: - if self.worker.pipeline_config.is_actor_infer_colocated and self.is_model_in_gpu: + if (self.worker.pipeline_config.is_actor_infer_colocated or DO_TIME_SHARING) and self.is_model_in_gpu: await self.model.tokenizer_manager.release_memory_occupation(ReleaseMemoryOccupationReqInput(), None) logger.info("self.model.release_memory_occupation exec ....") # always release all diff --git a/roll/distributed/strategy/strategy.py b/roll/distributed/strategy/strategy.py index 410af36ae..862e57330 100644 --- a/roll/distributed/strategy/strategy.py +++ b/roll/distributed/strategy/strategy.py @@ -90,7 +90,13 @@ def setup_model_update(self, *args, **kwargs): raise NotImplementedError def _setup_collective_group_impl( - self, model_update_name, comm_plan, backend, mode + self, + model_update_name, + comm_plan, + backend, + mode, + *, + timeout_s: Optional[float] = None, ): """ mode: @@ -124,7 +130,7 @@ def _setup_collective_group_impl( collective.init_collective_group( world_size, rank, backend=backend, group_name=group_name, - master_addr=master_addr, master_port=master_port + master_addr=master_addr, master_port=master_port, timeout_s=timeout_s ) collective.allreduce(torch.zeros(1).to(current_platform.device_type), group_name=group_name) @@ -140,11 +146,36 @@ def _setup_collective_group_impl( ) logger.info(f"warmup setup_collective_group: {group_name} rank: {rank} world_size: {world_size}") - def setup_collective_group(self, model_update_name, comm_plan, backend=None, mode="receiver"): + def setup_collective_group( + self, + model_update_name, + comm_plan, + backend=None, 
+ mode="receiver", + *, + timeout_s: Optional[float] = None, + ): """ 单卡infer strategy可直接复用,多卡infer strategy需要自行管理 """ - self._setup_collective_group_impl(model_update_name, comm_plan, backend, mode=mode) + self._setup_collective_group_impl(model_update_name, comm_plan, backend, mode=mode, timeout_s=timeout_s) + + def destroy_collective_group(self, group_name: str, model_update_name: str | None = None) -> None: + # Destroy a single collective group and optionally clean up bookkeeping. + collective.destroy_collective_group(group_name) + + # Clean up bookkeeping if model_update_name is provided. + # Structure: model_update_comm_plan[model_update_name][src_pp_rank] = {group_name: ..., ...} + # Multiple src_pp_rank entries may exist under one model_update_name, each with its own group_name. + # We must iterate to remove only entries matching this group_name, then prune model_update_name if empty. + if model_update_name is not None: + plan = getattr(self, "model_update_comm_plan", None) + if isinstance(plan, dict) and model_update_name in plan: + for src_pp_rank in list(plan[model_update_name].keys()): + if plan[model_update_name][src_pp_rank].get("group_name") == group_name: + plan[model_update_name].pop(src_pp_rank, None) + if not plan[model_update_name]: + plan.pop(model_update_name, None) # offload/load 相关接口 def load_states(self, *args, **kwargs): @@ -439,8 +470,16 @@ def __init__(self, worker: "Worker"): self.scheduler = None self.checkpoint_manager = CheckpointManager(checkpoint_config=self.worker_config.checkpoint_config) - def setup_collective_group(self, model_update_name, comm_plan, backend=None, mode="sender"): - self._setup_collective_group_impl(model_update_name, comm_plan, backend, mode=mode) + def setup_collective_group( + self, + model_update_name, + comm_plan, + backend=None, + mode="sender", + *, + timeout_s: Optional[float] = None, + ): + self._setup_collective_group_impl(model_update_name, comm_plan, backend, mode=mode, timeout_s=timeout_s) def 
setup_p2p_collective_group(self, model_update_name, comm_plan, backend="nccl"): diff --git a/roll/distributed/strategy/vllm_strategy.py b/roll/distributed/strategy/vllm_strategy.py index 0e446efe5..801b87e86 100644 --- a/roll/distributed/strategy/vllm_strategy.py +++ b/roll/distributed/strategy/vllm_strategy.py @@ -19,8 +19,10 @@ from roll.distributed.scheduler.protocol import DataProto, list_of_dict_to_dict_of_list from roll.distributed.strategy.strategy import InferenceStrategy from roll.third_party.vllm import create_async_llm +from roll.utils.constants import DO_TIME_SHARING from roll.utils.functionals import concatenate_input_and_output, reduce_metrics from roll.utils.logging import get_logger +from roll.utils.lora_routing import ensure_lora_name_in_batch, get_lora_name_array, resolve_microbatch_lora_name from roll.utils.offload_states import OffloadStateType from roll.platforms import current_platform @@ -28,6 +30,43 @@ logger = get_logger() +def _normalize_lora_int_ids_loaded(value) -> list[int]: + """Normalize LoRA adapter integer IDs returned by vLLM's list_loras RPC. + + vLLM's ``list_loras`` API has inconsistent return formats across versions and + distributed configurations: + - Single GPU: returns ``[id1, id2, ...]`` (flat list of ints) + - Multi-GPU/Tensor Parallel: may return ``[[id1, id2], [id1, id2], ...]`` + where each sub-list corresponds to a different rank's view + - Empty state: returns ``[]`` or ``[[]]`` + + This helper flattens nested structures, deduplicates across ranks, and returns + a sorted list of unique integer adapter IDs for consistent downstream handling. + + Args: + value: The raw return value from ``await model.list_loras()``. May be + a flat list of ints, a nested list of lists, or an empty list. + + Returns: + A sorted list of unique integer LoRA adapter IDs. Returns an empty list + for invalid or empty inputs. + """ + if not isinstance(value, list) or not value: + return [] + # Handle nested [[id,...], ...] 
format from multi-rank responses + if isinstance(value[0], list): + flat: list[int] = [] + for sub in value: + if not isinstance(sub, list): + continue + for item in sub: + if isinstance(item, int): + flat.append(item) + return sorted(set(flat)) + # Handle flat [id, ...] format from single-rank responses + return [item for item in value if isinstance(item, int)] + + class VllmStrategy(InferenceStrategy): strategy_name = "vllm" @@ -85,15 +124,58 @@ async def initialize(self, model_provider): } ) - self.is_lora = self.worker_config.model_args.lora_target is not None + # ===================================================================== + # Multi-LoRA Configuration + # ===================================================================== + # Detection: LoRA mode is active when adapters dict is configured and non-empty. + # This replaces the legacy lora_target field check. + # Note: We check both `is not None` and `len() > 0` because: + # - adapters=None → LoRA disabled + # - adapters={} (empty dict) → invalid config, would crash on max_lora_rank + adapters = self.worker_config.model_args.adapters + self.is_lora = adapters is not None and len(adapters) > 0 if self.is_lora: + # ----------------------------------------------------------------- + # vLLM V1 Multi-LoRA Support: + # ----------------------------------------------------------------- + # vLLM V1 supports multi-LoRA with prefix caching and chunked prefill: + # - Block hashes include LoRA adapter name via _gen_lora_extra_hash_keys() + # - Each request's lora_request.lora_name is part of the cache key + # - See: vllm/v1/core/kv_cache_utils.py:generate_block_hash_extra_keys() + # ----------------------------------------------------------------- + + # max_loras: Maximum number of LoRA adapters that can be resident in GPU + # memory simultaneously. Set to at least configured adapters + 1 for + # dynamic loading headroom. 
+ max_loras_cfg = int(vllm_config.get("max_loras", 0) or 0) lora_kwargs = { "enable_lora": True, - "max_loras": 1, - "max_lora_rank": self.worker_config.model_args.lora_rank, + "max_loras": max(max_loras_cfg, len(adapters) + 1), + "max_lora_rank": max(a.lora_rank for a in adapters.values()), } vllm_config.update(lora_kwargs) - vllm_config["load_format"] = "auto" # enables vLLM to load the base model for add_lora + # LoRA mode requires real base model weights for adapter weight initialization. + # "dummy" load_format only works for weight broadcasting from trainer. + vllm_config["load_format"] = "auto" + + # Guard: LoRA mode is incompatible with dummy load_format (used for weight broadcasting). + # Users must either set load_format='auto' or disable LoRA. + if self.is_lora and vllm_config.get("load_format") == "dummy": + raise RuntimeError( + "vLLM LoRA mode requires real base model weights; got load_format='dummy'. " + "Set vllm strategy_config.load_format='auto' or disable LoRA." + ) + + # Guard: Multi-LoRA routing requires vLLM V1 engine for adapter-id RPC APIs. + # The V0 engine does not expose the per-request adapter selection APIs needed + # for routing different samples to different LoRA adapters in a single batch. + if self.is_lora: + vllm_use_v1 = int(os.environ.get("VLLM_USE_V1", "1")) + if vllm_use_v1 != 1: + raise RuntimeError( + "LoRA mode in ROLL requires VLLM_USE_V1=1. " + "Non-v1 engine path does not expose adapter-id APIs required by multi-LoRA routing." + ) logger.info(f"vllm_config: {vllm_config}") assert not dist.is_initialized() @@ -149,7 +231,26 @@ def _should_use_beam_search(self, generation_config) -> bool: return generation_config.get("num_beams", 1) > 1 or generation_config.get("use_beam_search", False) async def _generate_standard(self, batch: DataProto, generation_config: Dict) -> torch.Tensor: - """Standard generate method for non-beam search cases.""" + """Standard generate method for non-beam search cases with multi-LoRA routing. 
+ + This method handles both single-LoRA and multi-LoRA scenarios: + - Single-LoRA: All samples use the same adapter (auto-filled if not specified) + - Multi-LoRA: Each sample specifies its adapter via ``lora_name`` in non_tensor_batch + + The multi-LoRA routing flow: + 1. Extract per-sample adapter names from ``batch.non_tensor_batch["lora_name"]`` + 2. Resolve each adapter name to its vLLM-assigned integer ID + 3. Construct a ``LoRARequest`` per sample + 4. Pass the per-sample requests to vLLM's generate API + + Args: + batch: Input batch containing ``batch`` (tensor data) and ``non_tensor_batch`` + (metadata including optional ``lora_name`` array). + generation_config: Generation parameters (temperature, top_p, etc.). + + Returns: + Output tensor of shape ``(bs * num_return_sequences, input_len + max_response_len)``. + """ sampling_params = create_sampling_params_for_vllm(gen_kwargs=generation_config) input_ids = batch.batch["input_ids"] # (bs, prompt_length) @@ -162,14 +263,70 @@ async def _generate_standard(self, batch: DataProto, generation_config: Dict) -> for prompt in gather_unpadded_input_ids(input_ids=input_ids, attention_mask=attention_mask) ] - lora_request = None + # ===================================================================== + # Multi-LoRA Per-Sample Routing + # ===================================================================== + # In multi-LoRA mode, each sample in the batch may use a different adapter. + # The adapter assignment is determined by: + # 1. Explicit per-sample ``lora_name`` in non_tensor_batch (producer sets this) + # 2. 
Single-adapter fallback: if only one adapter is configured, all samples + # use it automatically (ensures backward compatibility) + # + # The routing validation ensures: + # - ``lora_name`` array length matches batch size + # - All referenced adapters are registered and loaded in vLLM + # ===================================================================== if self.is_lora: - lora_int_ids = list(await self.model.list_loras()) - if len(lora_int_ids) > 0: - lora_int_id = lora_int_ids[0] - lora_request = LoRARequest(lora_name=f"{lora_int_id}", lora_int_id=lora_int_id, lora_path="dummy_lora_path") + ensure_lora_name_in_batch( + batch.non_tensor_batch, + adapters=self.worker_config.model_args.adapters, + batch_size=batch.batch["input_ids"].size(0), + ) + + lora_requests: list[LoRARequest | None] | None = None + if self.is_lora: + # Step 1: Extract per-sample adapter names + lora_names = get_lora_name_array(batch.non_tensor_batch) + + # Step 2: Validate adapter count matches prompt count + if len(lora_names) != len(prompts): + logger.error("LoRA routing mismatch: len(lora_names)=%s len(prompts)=%s", len(lora_names), len(prompts)) + raise RuntimeError( + f"vLLM routing requires len(lora_name)==len(prompts), got {len(lora_names)} vs {len(prompts)}" + ) - async def _generate(prompt): + # Step 3: Build adapter name -> integer ID mapping + adapters = [str(d) for d in lora_names.tolist()] + lora_request_path = self.worker_config.model_args.model_name_or_path + lora_int_ids_loaded = _normalize_lora_int_ids_loaded(await self.model.list_loras()) + adapter_to_int_id: dict[str, int] = {} + for adapter in sorted(set(adapters)): + # Validate adapter is configured + if adapter not in self.worker_config.model_args.adapters: + raise RuntimeError(f"Unknown LoRA adapter requested by lora_name={adapter!r}") + # Get vLLM-assigned integer ID + lora_int_id = await self.get_lora_id(adapter) + if lora_int_id is None: + raise RuntimeError(f"Missing LoRA adapter in vLLM engine: {adapter!r}") + # 
Verify adapter is loaded (visible in list_loras) + if lora_int_id not in lora_int_ids_loaded: + raise RuntimeError( + f"LoRA adapter id not loaded in vLLM engine: adapter={adapter!r} lora_int_id={lora_int_id}" + ) + adapter_to_int_id[adapter] = lora_int_id + + # Step 4: Construct per-sample LoRARequest objects + # vLLM uses these to route each request to the correct adapter's weights. + lora_requests = [ + LoRARequest( + lora_name=adapter, + lora_int_id=adapter_to_int_id[adapter], + lora_path=lora_request_path, + ) + for adapter in adapters + ] + + async def _generate(prompt, lora_request: LoRARequest | None): request_id = random_uuid() result_generator = self.model.generate( prompt=prompt, @@ -182,7 +339,13 @@ async def _generate(prompt): output = result return output - vllm_outputs = await asyncio.gather(*[_generate(prompt) for prompt in prompts]) + # Execute all generations in parallel, each with its LoRARequest (or None for non-LoRA mode) + if lora_requests is None: + vllm_outputs = await asyncio.gather(*[_generate(prompt, None) for prompt in prompts]) + else: + vllm_outputs = await asyncio.gather( + *[_generate(prompt, lora_request) for prompt, lora_request in zip(prompts, lora_requests, strict=True)] + ) # (bs * num_return_sequences, max_response_len) output_ids = gather_outputs_to_pad_tensor( @@ -260,6 +423,37 @@ async def _beam_search(prompt): return output async def generate_request(self, data: DataProto): + """Generate for a single streaming request with LoRA adapter routing. + + Unlike ``_generate_standard`` which handles batch inference with per-sample + LoRA routing, this method handles single-request streaming generation where + each request uses exactly one LoRA adapter. + + The LoRA routing flow for single requests: + 1. Resolve the adapter name from ``non_tensor_batch`` (single value, not array) + 2. Look up the vLLM-assigned integer ID for the adapter + 3. Verify the adapter is loaded + 4. 
Construct and pass a single ``LoRARequest`` to vLLM + + Routing metadata is recorded in ``data.meta_info`` for observability: + - ``routed_lora_name``: The resolved adapter name + - ``routed_lora_int_id``: The vLLM integer ID for the adapter + + Args: + data: Input data proto containing: + - ``batch``: Tensor data (input_ids, attention_mask) + - ``non_tensor_batch``: Metadata including ``lora_name`` + - ``meta_info``: Request ID, generation config, etc. + + Returns: + DataProto with output tokens, finish reasons, and logprobs in meta_info. + + Raises: + RuntimeError: If LoRA routing fails (adapter not found, not loaded, etc.) + """ + # Keep meta_info writable for routing diagnostics; some callers may pass None. + if data.meta_info is None: + data.meta_info = {} collect_unfinished = data.meta_info.get("collect_unfinished", False) input_ids = data.batch["input_ids"] attention_mask = data.batch["attention_mask"] @@ -283,12 +477,51 @@ async def generate_request(self, data: DataProto): prompt_token_ids = gather_unpadded_input_ids(input_ids=input_ids, attention_mask=attention_mask) assert len(prompt_token_ids) == 1 prompt = TokensPrompt(prompt_token_ids=prompt_token_ids[0]) + + # ===================================================================== + # Single-Request LoRA Routing + # ===================================================================== + # For streaming requests, each request uses exactly one LoRA adapter. + # The adapter name is resolved from non_tensor_batch (single value, not array). + # This differs from _generate_standard where we handle per-sample routing + # within a batch. 
+ # ===================================================================== + if self.is_lora: + ensure_lora_name_in_batch( + data.non_tensor_batch, + adapters=self.worker_config.model_args.adapters, + batch_size=data.batch["input_ids"].size(0), + ) + lora_request = None if self.is_lora: - lora_int_ids = list(await self.model.list_loras()) - if len(lora_int_ids) > 0: - lora_int_id = lora_int_ids[0] - lora_request = LoRARequest(lora_name=f"{lora_int_id}", lora_int_id=lora_int_id, lora_path="dummy_lora_path") + # Step 1: Resolve the adapter name for this single request + routing = resolve_microbatch_lora_name(data.non_tensor_batch) + + # Step 2: Get vLLM-assigned integer ID for the adapter + lora_name = routing.lora_name + lora_int_id = await self.get_lora_id(lora_name) + if lora_int_id is None: + raise RuntimeError(f"Missing LoRA adapter in vLLM engine: {lora_name!r}") + + # Record routing decision for observability + data.meta_info["routed_lora_name"] = lora_name + data.meta_info["routed_lora_int_id"] = int(lora_int_id) + + # Step 3: Verify adapter is loaded (handle race condition after add_lora) + lora_int_ids_loaded = _normalize_lora_int_ids_loaded(await self.model.list_loras()) + if lora_int_id not in lora_int_ids_loaded: + # Fail fast if adapter not visible - add_lora should have waited + raise RuntimeError( + f"LoRA adapter id not loaded: adapter={lora_name!r} lora_int_id={lora_int_id} loaded={lora_int_ids_loaded[:16]!r}" + ) + + # Step 4: Construct LoRARequest for vLLM + lora_request = LoRARequest( + lora_name=lora_name, + lora_int_id=lora_int_id, + lora_path=self.worker_config.model_args.model_name_or_path, + ) result_generator = self.model.generate( prompt=prompt, @@ -334,38 +567,311 @@ async def abort_requests(self, request_ids): # offload/reload 接口 async def load_states(self, *args, **kwargs): - await self.model.reset_prefix_cache() + # Ensure KV/block manager exists before reset_prefix_cache. 
Calling reset on an + # uninitialized engine state can block indefinitely. + logger.info("[vllm_strategy][load_states] enter is_model_in_gpu=%s", self.is_model_in_gpu) if not self.is_model_in_gpu: + logger.info("[vllm_strategy][load_states] calling model.load_states()") await self.model.load_states() self.is_model_in_gpu = True + logger.info("[vllm_strategy][load_states] model.load_states() done") + logger.info("[vllm_strategy][load_states] calling reset_prefix_cache()") + await self.model.reset_prefix_cache() + logger.info("[vllm_strategy][load_states] reset_prefix_cache() done") async def offload_states(self, include=None, non_blocking=False): await self.model.reset_prefix_cache() if include is None or OffloadStateType.model_params in include: - if self.is_model_in_gpu and self.worker.pipeline_config.is_actor_infer_colocated: + if self.is_model_in_gpu and (self.worker.pipeline_config.is_actor_infer_colocated or DO_TIME_SHARING): await self.model.offload_states(self.sleep_level) self.is_model_in_gpu = False gc.collect() current_platform.empty_cache() def process_weights_after_loading(self,*args, **kwargs): - self.model.process_weights_after_loading() + # CustomAsyncLLM.process_weights_after_loading is async; return the awaitable so caller can await. + return self.model.process_weights_after_loading() + + # ===================================================================== + # Collective Communication Group Management + # ===================================================================== + # These methods manage process groups for distributed weight synchronization + # between trainer (FSDP2) and inference workers. Two call styles are supported: + # + # 1. Dynamic comm_plan style (modern, selective model-update): + # Used for fine-grained control over which ranks participate in each group. + # setup_collective_group(comm_plan=..., backend=?, timeout_s=?) + # + # 2. 
Legacy/persistent broadcast group style: + # Used for traditional all-rank broadcast communication patterns. + # setup_collective_group(master_address=..., master_port=..., rank_offset=..., + # world_size=..., group_name=..., backend=?, timeout_s=?) + # ===================================================================== + + async def setup_collective_group(self, *args, **kwargs) -> None: + """Create a collective communication group for distributed operations. + + This method supports two calling conventions for different use cases: + + **Style 1: Dynamic comm_plan (recommended for multi-LoRA)** + Uses a communication plan that specifies which ranks participate. + This enables selective model updates where only relevant workers + receive weight broadcasts for specific adapters. + + Required kwargs: + comm_plan: Communication plan specifying participant ranks. + + Optional kwargs: + backend: Communication backend (defaults to platform default). + timeout_s: Timeout for group creation in seconds. + + **Style 2: Legacy broadcast group** + Creates a persistent process group for traditional all-rank broadcasts. + Used when all workers need to participate in weight synchronization. + + Required kwargs: + master_address: Address of the rank 0 process. + master_port: Port for communication. + rank_offset: Offset to apply to local ranks. + world_size: Total number of ranks in the group. + group_name: Unique identifier for this process group. + + Optional kwargs: + backend: Communication backend (defaults to platform default). + timeout_s: Timeout for group creation in seconds. + + Raises: + TypeError: If neither style's required arguments are provided. 
+ """ + # Style 1: Dynamic comm_plan based group setup + if "comm_plan" in kwargs: + backend = kwargs.get("backend", None) + timeout_s = kwargs.get("timeout_s", None) + comm_plan = kwargs["comm_plan"] + backend = backend if backend is not None else current_platform.communication_backend + await self.model.setup_collective_group( + comm_plan=comm_plan, backend=backend, rank_in_cluster=self.worker.rank, timeout_s=timeout_s + ) + return + + # Style 2: Legacy/persistent broadcast group + required = {"master_address", "master_port", "rank_offset", "world_size", "group_name"} + if required.issubset(kwargs.keys()): + backend = kwargs.get("backend", None) + timeout_s = kwargs.get("timeout_s", None) + backend = backend if backend is not None else current_platform.communication_backend + logger.info(f"setup_collective_group group_name={kwargs['group_name']!r}") + await self.model.setup_collective_group( + kwargs["master_address"], + kwargs["master_port"], + kwargs["rank_offset"], + kwargs["world_size"], + kwargs["group_name"], + backend, + timeout_s=timeout_s, + ) + return - # 参数同步相关接口 - async def setup_collective_group(self, master_address, master_port, rank_offset, world_size, group_name, backend=None): - logger.info(f"setup_collective_group {group_name=}") - backend = backend if backend is not None else current_platform.communication_backend - await self.model.setup_collective_group(master_address, master_port, rank_offset, world_size, group_name, backend) + raise TypeError( + "VllmStrategy.setup_collective_group expects either " + "(comm_plan=..., backend=?, timeout_s=?) " + "or (master_address=..., master_port=..., rank_offset=..., world_size=..., group_name=..., backend=?, timeout_s=?)." 
+ ) + + async def broadcast_parameter(self, names, dtypes, shapes, group_name, is_lora=False, *, broadcast_local_ranks=None): + await self.model.broadcast_parameter(names, dtypes, shapes, group_name, is_lora, broadcast_local_ranks=broadcast_local_ranks) - async def broadcast_parameter(self, names, dtypes, shapes, group_name, is_lora=False): - await self.model.broadcast_parameter(names, dtypes, shapes, group_name, is_lora) + async def update_parameter_in_bucket( + self, serialized_named_tensors, is_lora=False, *, ipc_local_ranks=None, model_update_transport="cuda_ipc" + ): + await self.model.update_parameter_in_bucket( + serialized_named_tensors, is_lora, + ipc_local_ranks=ipc_local_ranks, model_update_transport=model_update_transport, + ) - async def update_parameter_in_bucket(self, serialized_named_tensors, is_lora=False): - await self.model.update_parameter_in_bucket(serialized_named_tensors, is_lora) + async def destroy_collective_group(self, group_name: str, model_update_name: str | None = None) -> None: + """Destroy a previously created collective communication group. - async def add_lora(self, peft_config): - peft_config["target_modules"] = set(self.worker_config.model_args.lora_target) - await self.model.add_lora(peft_config) + Args: + group_name: The name of the process group to destroy. + model_update_name: Unused in vLLM strategy (kept for API compatibility + with other strategies that track model_update_comm_plan state). + """ + # vLLM has no model_update_comm_plan bookkeeping; model_update_name is unused. + del model_update_name + await self.model.destroy_collective_group(group_name) + + async def add_lora( + self, + adapter_name: str = "default", + peft_config: dict = None, + *, + lora_local_ranks=None, + wake_after_add: bool = True, + ): + """Register a LoRA adapter with the vLLM inference engine. + + This method handles the full lifecycle of LoRA adapter registration: + 1. Validates the adapter name against the configured adapters dict + 2. 
Calls vLLM's add_lora RPC with the PEFT configuration + 3. Tracks readiness via wake_after_add without follow-up visibility RPCs + + The method is designed for multi-LoRA scenarios where different samples + in a batch may need different adapters. Each adapter must be registered + before it can be used in inference via LoRARequest routing. + + This method always re-registers the adapter to ensure updated LoRA weights + from the latest training step are applied. The caller must evict stale + registrations via offload_states() before the next model_update, because + LoRA GPU tensors are discarded for both sleep_level=1 and sleep_level=2. + + Args: + adapter_name: Name of the adapter to register. Must match a key in + ``worker_config.model_args.adapters``. Defaults to "default" for + backward compatibility with single-LoRA callers. + peft_config: PEFT configuration dict containing LoRA parameters. + Required. The ``target_modules`` field is overwritten from the + configured adapter spec to ensure consistency. + wake_after_add: Whether this adapter registration should fully wake + the vLLM engine (weights + KV cache). For multi-adapter updates, + callers set this only on the last adapter. + Raises: + RuntimeError: If: + - ``peft_config`` is None + - ``adapter_name`` is not in the configured adapters + - ``adapter_name="default"`` in multi-LoRA mode (FSDP2 limitation) + + Note: + - This method intentionally avoids immediate post-registration visibility + RPC checks (``get_lora_id``/``list_loras``) to avoid reentrancy stalls. + Readiness is tracked via ``wake_after_add``: non-final adapters keep + KV cache asleep, while the final adapter marks the model ready. + - For multi-LoRA with FSDP2 trainer, use explicit adapter names instead + of the "default" placeholder to avoid ambiguity. + """ + # Backward-compatible: single-LoRA callers may pass only peft_config and rely on adapter_name default. 
+ if peft_config is None: + raise RuntimeError("add_lora: peft_config must not be None") + adapters = self.worker_config.model_args.adapters or {} + if adapter_name not in adapters: + raise RuntimeError( + f"add_lora: unknown adapter_name={adapter_name!r}. " + f"Valid adapters: {sorted(adapters.keys())}" + ) + # Guard: FSDP2 model_update path does not support multi-LoRA weight broadcasting. + # Using "default" name in multi-LoRA config would cause ambiguity. + if adapter_name == "default" and len(adapters) > 1: + raise RuntimeError( + "add_lora called with adapter_name='default' in multi-LoRA mode. " + "FSDP2 model_update path does not support multi-LoRA. " + f"Configured adapters: {list(adapters.keys())}" + ) + # Keep target_modules JSON-serializable and deterministic for worker-side hashing. + peft_config["target_modules"] = sorted(adapters[adapter_name].lora_target) + # Blocking RPC: does not return until custom_add_lora on the worker completes. + # Inside custom_add_lora the sequence is: + # 1. reload_model() → wake_up(["weights"]) only (no KV cache wake-up) + # 2. vLLM.add_lora() → LoRA tensors loaded to GPU, adapter registered in vLLM Python cache + # 3. register(name, id) → _lora_names updated only after vLLM confirms success + await self.model.add_lora( + adapter_name, + peft_config, + lora_local_ranks=lora_local_ranks, + wake_after_add=wake_after_add, + ) + # No follow-up visibility RPCs here (get_lora_id/list_loras) to avoid + # reentrancy hazards. Trust worker-level add_lora success and track GPU + # readiness based on whether this call performed the final wake-up. + self.is_model_in_gpu = wake_after_add + logger.info( + "[vllm_strategy][add_lora] registered adapter=%s (worker-level ok; is_model_in_gpu=%s)", + adapter_name, self.is_model_in_gpu, + ) + + async def verify_model(self, expected_stats: dict) -> None: + """Verify post-sync weights match sender stats, with TP aggregation for base model. 
+ + Dispatches custom_verify_model to all TP ranks via collective_rpc_async. + Base stats: aggregated across TP ranks (sum-of-sums, max-of-maxes, min-of-mins) + because post-ingestion weights are TP-sharded. + LoRA stats: identical across TP ranks (broadcast sends same data to each rank), + so first rank's result is used directly. + """ + from roll.utils.send_recv_utils import verify_weight_stats + + per_rank_results = await self.model.verify_model(expected_stats=expected_stats) + + # Normalize collective_rpc_async return format (same shape variations as get_lora_id). + if not isinstance(per_rank_results, list): + per_rank_results = [per_rank_results] + # Flatten nested [[result], ...] format from some vLLM versions. + if len(per_rank_results) == 1 and isinstance(per_rank_results[0], list): + per_rank_results = per_rank_results[0] + + # Base model: TP-aggregate then compare against sender stats. + if "base" in expected_stats: + agg_sum = sum(rank_result["base"]["sum"] for rank_result in per_rank_results) + agg_max = max(rank_result["base"]["max"] for rank_result in per_rank_results) + agg_min = min(rank_result["base"]["min"] for rank_result in per_rank_results) + aggregated_base = {"sum": agg_sum, "max": agg_max, "min": agg_min} + verify_weight_stats(aggregated_base, expected_stats["base"], label="base") + + # LoRA: all TP ranks have identical raw tensors; take first rank's result. 
+ if "lora" in expected_stats: + first_rank_lora = per_rank_results[0].get("lora", {}) + for adapter_name, expected_adapter_stats in expected_stats["lora"].items(): + actual_adapter_stats = first_rank_lora.get(adapter_name) + if actual_adapter_stats is None: + raise RuntimeError( + f"verify_model: adapter {adapter_name!r} missing from rank 0 result; " + f"available={sorted(first_rank_lora.keys())}" + ) + verify_weight_stats( + actual_adapter_stats, expected_adapter_stats, label=f"lora/{adapter_name}" + ) + + logger.info("[vllm_strategy][verify_model] ok tp_ranks=%d", len(per_rank_results)) + + async def get_lora_id(self, adapter_name: str) -> int | None: + """Get the integer ID assigned by vLLM for a named LoRA adapter. + + vLLM assigns unique integer IDs to each registered LoRA adapter. These IDs + are required for constructing ``LoRARequest`` objects during inference. + + Note: + vLLM's ``get_lora_id`` RPC may return various formats depending on the + distributed configuration: + - Single rank: returns ``int`` directly + - Multi-rank via collective_rpc: returns ``[int]`` or ``[[int]]`` + This method normalizes all formats to a single ``int | None``. + + Args: + adapter_name: The name of the LoRA adapter to query. + + Returns: + The integer ID if the adapter is registered, or ``None`` if not found. + + Raises: + RuntimeError: If different ranks report different IDs for the same + adapter name, indicating a registration inconsistency. + """ + lora_id = await self.model.get_lora_id(adapter_name) + # Handle vLLM collective_rpc return format variations: + # - Single rank: int + # - Multi-rank: [int, int, ...] (one per rank) or [[int], ...] + if isinstance(lora_id, list): + if not lora_id: + return None + # Handle nested [[id], ...] format + if len(lora_id) == 1 and isinstance(lora_id[0], list): + inner = lora_id[0] + return inner[0] if inner else None + # Handle [id, id, ...] 
format - verify consistency across ranks + first = lora_id[0] + if all(x == first for x in lora_id): + return first + raise RuntimeError(f"Inconsistent LoRA id across ranks for adapter {adapter_name!r}: {lora_id!r}") + return lora_id async def _collect_metrics_snapshot(self): """Collect metrics snapshots periodically in a background thread.""" diff --git a/roll/models/model_providers.py b/roll/models/model_providers.py index 32ebc2b3f..cbb26b80c 100644 --- a/roll/models/model_providers.py +++ b/roll/models/model_providers.py @@ -149,40 +149,135 @@ def freeze_model(model, model_args: "ModelArguments"): param.requires_grad_(False) +def _resolve_lora_target_modules(model: "torch.nn.Module", lora_target: Any) -> Any: + """Resolve magic targets like 'all-linear' into explicit module-name lists. + + Accepts lora_target as a comma-separated string, a list of module names, or a regex + string (detected by a heuristic check for common regex metacharacters). Magic tokens + 'all-linear', 'all-embedding', and 'all-router' are expanded by introspecting the + model's named modules via mcore_adapter helpers. + + Returns either a list[str] of concrete module names or a regex string (passthrough). + + Note: unlike upstream's get_target_modules which mutates model_args.lora_target in-place + via .remove(), this function creates new lists to avoid corrupting the config when called + multiple times on the same model_args (e.g. actor model + MCA model). + """ + + def _split_targets(target: Any) -> Any: + if target is None: + return [] + if isinstance(target, str): + # Treat as regex when it looks like one; otherwise split on commas. 
+ if any(c in target for c in ["*", "$", "|", "(", "^", "[", "+", "?", "\\"]): + return target + return [item.strip() for item in target.split(",") if item.strip()] + return list(target) + + target_modules = _split_targets(lora_target) + if isinstance(target_modules, str): + return target_modules + + if "all-linear" in target_modules: + target_modules = [m for m in target_modules if m != "all-linear"] + target_modules += find_all_linear_modules(model) + if "all-embedding" in target_modules: + target_modules = [m for m in target_modules if m != "all-embedding"] + target_modules += find_all_embedding_modules(model) + if "all-router" in target_modules: + target_modules = [m for m in target_modules if m != "all-router"] + target_modules += find_all_router_modules(model) + return target_modules + + + + # Inspired by: https://github.com/hiyouga/LLaMA-Factory/blob/main/src/llamafactory/model/adapter.py -def setup_lora_training( - config, model, model_args: "ModelArguments", is_trainable: Optional[bool] = False, is_mca: Optional[bool] = False +def setup_lora_training_from_adapters( + model, + adapters: dict, + is_trainable: Optional[bool] = False, + is_mca: Optional[bool] = False, ): + """Wrap model with one or more named LoRA adapters from the adapters dict. + + Each entry in adapters maps adapter_name -> adapter_args with per-adapter + lora_rank, lora_target, lora_alpha, lora_dropout, and additional_target. + Target modules are resolved against the pre-LoRA base model to avoid + LoRA-on-LoRA when adding multiple adapters. + + Post-setup active-adapter policy: + - is_mca=True (Megatron): all adapters activated so Megatron allocates grad + buffers for every adapter before wrapping. Per-step routing switches later. + - is_mca=False (HF/FSDP2/DeepSpeed): only the first adapter is activated, + matching upstream single-adapter semantics. + + When is_trainable is False, only enables input require grads without adding LoRA. 
+ """ model.enable_input_require_grads() - if is_trainable: - - def get_target_modules(model: "torch.nn.Module", model_args: "ModelArguments"): - target_modules = model_args.lora_target - if "all-linear" in model_args.lora_target: - target_modules.remove("all-linear") - target_modules += find_all_linear_modules(model) - if "all-embedding" in model_args.lora_target: - target_modules.remove("all-embedding") - target_modules += find_all_embedding_modules(model) - if "all-router" in model_args.lora_target: - target_modules.remove("all-router") - target_modules += find_all_router_modules(model) - return target_modules - - target_modules = get_target_modules(model, model_args) - lora_config = { - "r": model_args.lora_rank, + if not is_trainable: + return model + + base_model = model + target_modules_map: dict[str, Any] = {} + for adapter_name, adapter_args in adapters.items(): + lora_target = getattr(adapter_args, "lora_target", None) + if lora_target is None: + raise ValueError(f"adapter '{adapter_name}' has no lora_target — cannot create LoRA config.") + # Resolve module targets against the *pre-LoRA* model to avoid LoRA-on-LoRA + # when adding multiple adapters. + target_modules_map[adapter_name] = _resolve_lora_target_modules(base_model, lora_target) + + # First adapter uses get_peft_model() to wrap base model into PeftModel; + # subsequent adapters use add_adapter() on the already-wrapped PeftModel. + peft_model = None + for adapter_name, adapter_args in adapters.items(): + target_modules = target_modules_map[adapter_name] + # adapter_args is always a LoraArguments dataclass pre-normalized by _normalize_adapters; + # direct attribute access so misuse fails fast instead of silently using defaults. 
+ lora_config: dict = { + "r": adapter_args.lora_rank, "target_modules": target_modules, - "lora_alpha": model_args.lora_alpha, - "lora_dropout": model_args.lora_dropout, - "modules_to_save": model_args.additional_target, + "lora_alpha": adapter_args.lora_alpha, + "lora_dropout": adapter_args.lora_dropout, + "modules_to_save": adapter_args.additional_target, } if not is_mca: lora_config.update({"task_type": TaskType.CAUSAL_LM}) - model = get_peft_model( - model, LoraConfig(**lora_config), autocast_adapter_dtype=model_args.autocast_adapter_dtype - ) - return model + + peft_config = LoraConfig(**lora_config) + if peft_model is None: + peft_model = get_peft_model( + base_model, + peft_config, + adapter_name=adapter_name, + autocast_adapter_dtype=adapter_args.autocast_adapter_dtype, + ) + else: + peft_model.add_adapter(adapter_name, peft_config) + # PEFT only autocasts adapter dtype for the *initial* adapter created via get_peft_model. + # For additional adapters added via add_adapter(), we must apply the same casting logic + # to match single-adapter training semantics (critical for Phase-0 step equivalence). + base = getattr(peft_model, "base_model", None) + if base is not None and hasattr(base, "_cast_adapter_dtype"): + base._cast_adapter_dtype( + adapter_name=adapter_name, + autocast_adapter_dtype=adapter_args.autocast_adapter_dtype, + ) + + if peft_model is None: + raise ValueError("adapters is empty but setup_lora_training_from_adapters was called.") + + # Megatron (is_mca): activate ALL adapters so PEFT marks every adapter trainable before + # Megatron wraps the model (grad buffers / main_grad allocated for every adapter). + # Per-step routing will call set_adapter(single_name) at runtime. + # Non-Megatron: activate only the first adapter to match upstream single-adapter semantics. 
+ if is_mca: + peft_model.base_model.set_adapter(list(adapters.keys())) + else: + peft_model.base_model.set_adapter(next(iter(adapters))) + return peft_model + def load_model( @@ -258,10 +353,11 @@ def load_model( model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": use_reentrant}) - if model_args.lora_target is None: + # adapters is always populated by __post_init__ when LoRA is enabled. + if model_args.adapters is None: freeze_model(model, model_args) else: - model = setup_lora_training(config, model, model_args, is_trainable) + model = setup_lora_training_from_adapters(model, model_args.adapters, is_trainable) if add_valuehead: from trl import AutoModelForCausalLMWithValueHead @@ -469,20 +565,28 @@ def default_actor_model_provider( if model_args.moe_aux_loss_coef is not None and training_args.moe_aux_loss_coeff is None: training_args.moe_aux_loss_coeff = model_args.moe_aux_loss_coef model = AutoModel.from_pretrained(model_args.model_name_or_path, training_args) + # adapters is always populated by __post_init__ when LoRA is enabled. + lora_enabled = model_args.adapters is not None if is_trainable: model.train() for param in model.parameters(): - param.requires_grad = True + # LoRA fine-tuning keeps the base model frozen. 
+ param.requires_grad = not lora_enabled else: model.eval() for param in model.parameters(): param.requires_grad = False - if model_args.lora_target is None: + if not lora_enabled: freeze_model(model, model_args) else: apply_megatron_lora() set_linear_is_expert(model[0]) - model.models[0] = setup_lora_training(model[0].config, model[0], model_args, is_trainable, is_mca=True) + model.models[0] = setup_lora_training_from_adapters( + model[0], + model_args.adapters, + is_trainable, + is_mca=True, + ) patch_model(model, config, use_mcore=True) else: # hf diff --git a/roll/pipeline/agentic/agentic_config.py b/roll/pipeline/agentic/agentic_config.py index def3246cd..e0c55015f 100644 --- a/roll/pipeline/agentic/agentic_config.py +++ b/roll/pipeline/agentic/agentic_config.py @@ -288,7 +288,8 @@ def __post_init__(self): elif self.train_env_manager.max_traj_per_env < 0: self.train_env_manager.max_traj_per_env = traj_per_env logger.info(f"train_env_manager.max_traj_per_env: {self.train_env_manager.max_traj_per_env}") - assert self.train_env_manager.max_traj_per_env >= traj_per_env, f"max_traj_per_env must be >= {traj_per_env}" + # Ensure max_traj_per_env meets the minimum floor for the given batch and env count. + self.ensure_min_traj_per_env(self.train_env_manager, self.rollout_batch_size) # Validate rollout_batch_size is compatible with group_size # The scheduler collects trajectories in complete groups to maintain variance reduction properties @@ -322,7 +323,8 @@ def __post_init__(self): if self.val_env_manager.max_traj_per_env < 0: self.val_env_manager.max_traj_per_env = traj_per_env logger.info(f"val_env_manager.max_traj_per_env: {self.val_env_manager.max_traj_per_env}") - assert self.val_env_manager.max_traj_per_env >= traj_per_env, f"max_traj_per_env must be >= {traj_per_env}" + # Ensure max_traj_per_env meets the minimum floor for the given batch and env count. 
+ self.ensure_min_traj_per_env(self.val_env_manager, self.val_batch_size) if ( hasattr(self, "actor_infer") @@ -385,5 +387,16 @@ def make_env_configs(self, env_manager_config: EnvManagerConfig): f"[ENV CONFIG] tag: {tag}, group_id: {group_id}, group_seeds: {group_seeds[group_id]}, env_id: {env_id}" ) done_groups += n_group - assert done_groups == env_manager_config.num_env_groups + assert done_groups == env_manager_config.num_env_groups, f"{done_groups=} is not { env_manager_config.num_env_groups=}" env_manager_config.env_configs = env_configs + + def ensure_min_traj_per_env(self, env_manager_config: EnvManagerConfig, batch_size: int) -> None: + """Ensure max_traj_per_env is sufficient for the given env count and batch size.""" + env_count = env_manager_config.num_env_groups * env_manager_config.group_size + min_traj_per_env = (batch_size + env_count - 1) // env_count + if env_manager_config.max_traj_per_env < min_traj_per_env: + logger.warning( + "Overriding max_traj_per_env: %d -> %d (batch_size=%d, env_count=%d)", + env_manager_config.max_traj_per_env, min_traj_per_env, batch_size, env_count, + ) + env_manager_config.max_traj_per_env = min_traj_per_env diff --git a/roll/pipeline/agentic/agentic_multi_lora_pipeline.py b/roll/pipeline/agentic/agentic_multi_lora_pipeline.py new file mode 100644 index 000000000..5ce40f2e3 --- /dev/null +++ b/roll/pipeline/agentic/agentic_multi_lora_pipeline.py @@ -0,0 +1,1007 @@ +import threading +import time + +from dataclasses import replace +from typing import Any, Dict, List, Optional + +import numpy as np +import ray +import torch +from codetiming import Timer +from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy +from ray.util.timer import _Timer + +from roll.distributed.executor.cluster import Cluster +from roll.distributed.scheduler.protocol import DataProto +from roll.distributed.scheduler.rollout_scheduler import RolloutScheduler +from roll.models.model_providers import default_tokenizer_provider +from 
roll.pipeline.agentic.agentic_config import AgenticConfig, EnvManagerConfig +from roll.datasets.global_dataset import GlobalDatasetManager +from roll.pipeline.agentic.agentic_pipeline import ( + compute_rollout_traj_metrics, + compute_train_data_metrics, + get_episode_scores, +) +from roll.utils.constants import RAY_NAMESPACE +from roll.pipeline.agentic.utils import ( + agentic_compute_advantage, + compute_discounted_returns, + compute_response_level_rewards, + dump_rollout_trajectories, + get_agentic_response_level_mask, +) +from roll.pipeline.base_pipeline import BasePipeline +from roll.utils.dynamic_batching import dynamic_batching_shard +from roll.utils.functionals import ( + RunningMoments, + agg_loss, + batch_balance, + compute_token_reward, + masked_mean, + reduce_metrics, +) +from roll.utils.kl_controller import get_kl_controller +from roll.utils.logging import get_logger + +from roll.utils.lora_routing import normalize_domain +from roll.utils.train_infer_corrections import apply_train_infer_correction_to_batch + + +logger = get_logger() + + +def is_lora_training(pipeline_config: AgenticConfig) -> bool: + return pipeline_config.actor_train.model_args.adapters is not None + + +class AgenticMultiLoraPipeline(BasePipeline): + """ + Async multi-LoRA Agentic pipeline: + - multiple env tags sampled concurrently + - each batch routes via non_tensor_batch["lora_name"] + - per-adapter optimizer stepping via actor_train.train_step_lora([...]) + """ + + def __init__(self, pipeline_config: AgenticConfig): + super().__init__(pipeline_config) + self.pipeline_config: AgenticConfig + + self.pipeline_config.set_max_steps(max_steps=self.pipeline_config.max_steps) + if not is_lora_training(self.pipeline_config): + raise RuntimeError( + "AgenticMultiLoraPipeline requires LoRA adapters (actor_train.model_args.adapters). " + "For full fine-tune (FFT), use AgenticPipeline. " + "FFT reference requires a separate frozen reference model/cluster (not disable_adapter)." 
+ ) + + actor_infer_strategy = getattr(self.pipeline_config.actor_infer, "strategy_args", None) + if actor_infer_strategy is not None and getattr(actor_infer_strategy, "strategy_name", None) == "vllm": + strategy_config = actor_infer_strategy.strategy_config or {} + sleep_level = int(strategy_config.get("sleep_level", 1)) + if sleep_level != 1: + raise RuntimeError( + "AgenticMultiLoraPipeline requires vLLM sleep_level=1. " + "Level 1 offloads weights to CPU (restorable); Level 2 discards weights entirely. " + "Multi-LoRA needs restorable offload for train/infer weight sync cycles." + ) + + # For multi-LoRA training, reference is the same backbone with LoRA disabled. + # Use actor_train.disable_adapter() to compute ref_log_probs; do not create a separate reference cluster. + self.use_ref_model = False + + # TODO: support GAE with per-LoRA critics: frozen backbone + per-LoRA adapters + per-LoRA value heads. + # Critic setup per LoRA task: + # - Value head: fully tuned linear layer (hidden_state → scalar value) + # - Backbone: frozen weights + LoRA adapters (only adapters updated to save memory) + if self.pipeline_config.adv_estimator == "gae": + raise NotImplementedError( + "AgenticMultiLoraPipeline does not support adv_estimator='gae'. " + "A single shared critic cannot produce accurate advantages across different LoRA tasks. " + "Requires per-LoRA critic adapters and per-LoRA value heads on a shared backbone " + "(not yet implemented). Use 'grpo' or 'reinforce_plus_plus' instead." 
+ ) + + self.kl_ctrl = get_kl_controller( + init_kl_coef=self.pipeline_config.init_kl_coef, + target_kl=self.pipeline_config.target_kl, + kl_horizon=self.pipeline_config.kl_horizon, + ) + + self.actor_train: Any = Cluster( + name=self.pipeline_config.actor_train.name, + worker_cls=self.pipeline_config.actor_train.worker_cls, + resource_manager=self.resource_manager, + worker_config=self.pipeline_config.actor_train, + ) + + self.actor_infer: Any = Cluster( + name=self.pipeline_config.actor_infer.name, + worker_cls=self.pipeline_config.actor_infer.worker_cls, + resource_manager=self.resource_manager, + worker_config=self.pipeline_config.actor_infer, + ) + download_clusters = [self.actor_train, self.actor_infer] + + # INIT PHASE: Download models and tokenizer + self.download_models(*download_clusters) + self.tokenizer = default_tokenizer_provider(model_args=self.pipeline_config.actor_train.model_args) + + # INIT PHASE: Initialize clusters + self.actor_train.initialize(pipeline_config=self.pipeline_config, blocking=True) + + self.actor_infer.initialize(pipeline_config=self.pipeline_config, blocking=True) + + # INIT PHASE: Model update pairing (train -> infer) + self.set_model_update_pair( + src_cluster=self.actor_train, + tgt_cluster=self.actor_infer, + frequency=self.pipeline_config.actor_train.model_update_frequency, + ) + + self.set_checkpoint_clusters(self.actor_train) + + self.running = RunningMoments() + + self.partial_gpu_mode: bool = False + if hasattr(self.pipeline_config, "partial_gpu_mode") and self.pipeline_config.partial_gpu_mode: + self._validate_partial_gpu_config() + self.partial_gpu_mode = True + + # Per-tag rollout schedulers (shared actor_infer). + self.rollout_schedulers: dict[str, Any] = {} + base_env: EnvManagerConfig = self.pipeline_config.train_env_manager + for tag, n_group in zip(base_env.tags, base_env.num_groups_partition): + # Shallow-copy the base config so per-tag mutations don't affect other tags. 
+ env_cfg = replace(base_env) + # Narrow the config to this single tag's env subset (one tag, one partition). + env_cfg.tags = [tag] + env_cfg.num_groups_partition = [n_group] + env_cfg.num_env_groups = n_group + env_cfg.name = f"train_env_{tag}" + # Recompute derived fields (world_size, max_env_num_per_worker, etc.) for the reduced env count. + env_cfg.__post_init__() + # Ensure per-tag max_traj_per_env is sufficient after narrowing to this tag's env subset. + self.pipeline_config.ensure_min_traj_per_env(env_cfg, self.pipeline_config.rollout_batch_size) + # Rebuild env_configs so worker_rank → env_id mapping reflects only this tag's envs. + self.pipeline_config.make_env_configs(env_cfg) + self.rollout_schedulers[tag] = ray.remote(RolloutScheduler).options( + name=f"RolloutScheduler-train-{tag}", + namespace=RAY_NAMESPACE, + scheduling_strategy=NodeAffinitySchedulingStrategy( + node_id=ray.get_runtime_context().get_node_id(), + soft=False, + ), + ).remote( + config=self.pipeline_config, + env_manager_config=env_cfg, + resource_manager=self.resource_manager, + infer_cluster=self.actor_infer, + mode="train", + ) + + # Per-tag val rollout schedulers (mirrors train schedulers for per-adapter eval). + val_env: EnvManagerConfig = self.pipeline_config.val_env_manager + val_tags = list(val_env.tags) if getattr(val_env, "tags", None) else [] + # Val tags must match train tags exactly for correct per-adapter eval. + assert val_tags == list(base_env.tags), ( + f"val_env_manager.tags must match train_env_manager.tags: " + f"val={val_tags} train={list(base_env.tags)}" + ) + num_tags = len(val_tags) + + # Validate val partition: no fallback, require explicit valid config. 
+ val_num_groups_partition = list(getattr(val_env, "num_groups_partition", []) or []) + assert len(val_num_groups_partition) == num_tags, ( + f"val_env_manager.num_groups_partition length ({len(val_num_groups_partition)}) " + f"must match num_tags ({num_tags})" + ) + assert all(n_group > 0 for n_group in val_num_groups_partition), ( + f"val_env_manager.num_groups_partition entries must all be > 0: {val_num_groups_partition}" + ) + assert sum(val_num_groups_partition) == val_env.num_env_groups, ( + f"sum(val_env_manager.num_groups_partition) = {sum(val_num_groups_partition)} " + f"must equal val_env_manager.num_env_groups = {val_env.num_env_groups}" + ) + + # Per-tag val_batch_size: equal split, validated per-tag. + assert self.pipeline_config.val_batch_size % num_tags == 0, ( + f"val_batch_size ({self.pipeline_config.val_batch_size}) must be divisible by " + f"num_tags ({num_tags})" + ) + val_batch_size_per_tag = self.pipeline_config.val_batch_size // num_tags + self._val_batch_size_per_tag: dict[str, int] = {} + for tag, val_n_group in zip(val_tags, val_num_groups_partition): + tag_val_env_num = val_n_group * val_env.group_size + assert val_batch_size_per_tag % tag_val_env_num == 0, ( + f"per-tag val_batch_size ({val_batch_size_per_tag}) must be divisible by " + f"tag {tag!r} val_env_num ({tag_val_env_num} = {val_n_group} * {val_env.group_size})" + ) + self._val_batch_size_per_tag[tag] = val_batch_size_per_tag + + self.val_rollout_schedulers: dict[str, Any] = {} + for tag, val_n_group in zip(val_tags, val_num_groups_partition): + val_env_cfg = replace(val_env) + val_env_cfg.tags = [tag] + val_env_cfg.num_groups_partition = [val_n_group] + val_env_cfg.num_env_groups = val_n_group + val_env_cfg.name = f"val_env_{tag}" + val_env_cfg.__post_init__() + # Ensure per-tag max_traj_per_env is sufficient for the proportional val batch. 
+ self.pipeline_config.ensure_min_traj_per_env(val_env_cfg, self._val_batch_size_per_tag[tag]) + self.pipeline_config.make_env_configs(val_env_cfg) + self.val_rollout_schedulers[tag] = ray.remote(RolloutScheduler).options( + name=f"RolloutScheduler-val-{tag}", + namespace=RAY_NAMESPACE, + scheduling_strategy=NodeAffinitySchedulingStrategy( + node_id=ray.get_runtime_context().get_node_id(), + soft=False, + ), + ).remote( + config=self.pipeline_config, + env_manager_config=val_env_cfg, + resource_manager=self.resource_manager, + infer_cluster=self.actor_infer, + mode="val", + ) + + self.val_dataset_manager: Any = GlobalDatasetManager.options( + name="val_dataset_manager", + get_if_exists=True, + namespace=RAY_NAMESPACE, + ).remote() + + # Serialize concurrent shrink/expand calls from partial-GPU mode. + self._infer_resize_lock = threading.Lock() + + # Initial model update to register/load adapters on inference before first rollout. + self._initial_model_update() + self._create_lora_trackers() + + def _create_lora_trackers(self) -> None: + """Create one metrics tracker per LoRA adapter for independent per-adapter tracking.""" + from roll.utils.tracking import create_lora_tracker + + adapters = self.pipeline_config.actor_train.model_args.adapters or {} + if not adapters: + return + adapter_names = sorted(adapters.keys()) + tracker_name = self.pipeline_config.track_with + + self.lora_trackers: dict[str, Any] = {} + for name in adapter_names: + self.lora_trackers[name] = create_lora_tracker( + tracker_name=tracker_name, + lora_name=name, + config=self.pipeline_config.to_dict(), + **self.pipeline_config.tracker_kwargs, + ) + logger.info("Created per-LoRA trackers for adapters: %s", adapter_names) + + def _initial_model_update(self) -> None: + # Full offload: discard model weights, KV cache, and all LoRA tensors before initial sync. 
+ self.actor_infer.offload_states() + adapters = set(self.pipeline_config.actor_train.model_args.adapters.keys()) if self.pipeline_config.actor_train.model_args.adapters else None + _ = self.model_update_lora_subset(global_step=0, adapters_to_update=adapters) + self.actor_infer.load_states() + + def adjust_batch(self, data: DataProto, mode: str = "copy") -> DataProto: + # TODO: extract adjust_batch into a standalone utility function instead of + # calling an unbound method from a sibling class (fragile, bypasses inheritance). + from roll.pipeline.agentic.agentic_pipeline import AgenticPipeline + + return AgenticPipeline.adjust_batch(self, data=data, mode=mode) # type: ignore[misc] + + + + def val(self, lora_name: str, global_step: int) -> dict: + """Validate a single adapter by running only its matching tag's val scheduler.""" + metrics: dict = {} + ray.get(self.val_dataset_manager.reset.remote()) + + for tag, val_scheduler in self.val_rollout_schedulers.items(): + # Only validate the tag that maps to the given adapter. 
+ if normalize_domain(tag) != lora_name: + continue + metrics.update(self._val_tag(tag, val_scheduler, global_step)) + + logger.info(f"val lora={lora_name} metrics: {metrics}") + return metrics + + def _val_tag(self, tag: str, val_scheduler: Any, global_step: int) -> dict: + """Run validation for a single tag and return prefixed metrics.""" + metrics: dict = {} + batch = DataProto(meta_info={"is_offload_states": False, "global_step": global_step}) + eval_batch = ray.get(val_scheduler.get_batch.remote(batch, self._val_batch_size_per_tag[tag])) + + if "get_batch_return_start_time" in eval_batch.meta_info: + metrics[f"time/get_batch_cost_val/{tag}"] = ( + time.time() - eval_batch.meta_info.pop("get_batch_return_start_time") + ) + + dump_rollout_trajectories(self.pipeline_config.rollout_dump_dir, global_step, eval_batch) + eval_metrics = reduce_metrics(eval_batch.meta_info.get("metrics", {})) + eval_score = get_episode_scores(eval_batch) + eval_metrics["score/mean"] = torch.mean(eval_score).detach().item() + eval_metrics["score/max"] = torch.max(eval_score).detach().item() + eval_metrics["score/min"] = torch.min(eval_score).detach().item() + + metrics.update({f"val/{tag}/{k}": v for k, v in eval_metrics.items()}) + return metrics + + def _validate_partial_gpu_config(self) -> bool: + train_devices = set(self.actor_train.worker_config.device_mapping) + infer_devices = set(self.actor_infer.worker_config.device_mapping) + + if not train_devices or not infer_devices: + raise ValueError( + "device_mapping cannot be empty: " + f"train={list(train_devices)}, infer={list(infer_devices)}" + ) + + if train_devices.isdisjoint(infer_devices): + raise RuntimeError( + "AgenticMultiLoraPipeline does not support disjoint actor_train/actor_infer device_mapping. " + "Use partial overlap (actor_train ⊂ actor_infer) so inference can continue on remaining GPUs while " + "training runs." 
+ ) + + if train_devices.issubset(infer_devices) and len(train_devices) < len(infer_devices): + logger.info("Detected Configuration Model B: Subset device_mapping, partial_gpu_mode=True") + infer_dp_size = self.actor_infer.worker_config.world_size + assert infer_dp_size >= 2, ( + f"partial_gpu_mode requires actor_infer.dp_size >= 2, got {infer_dp_size}" + ) + async_ratio = self.pipeline_config.async_generation_ratio + assert async_ratio >= 0, f"async_generation_ratio must be >= 0, got {async_ratio}" + + infer_strategy_config = self.actor_infer.worker_config.strategy_args.strategy_config + tp_size = infer_strategy_config.get("tensor_parallel_size", 1) + pp_size = infer_strategy_config.get("pipeline_parallel_size", 1) + assert tp_size >= 1 and pp_size >= 1, f"tp_size and pp_size must be >= 1: tp={tp_size}, pp={pp_size}" + + expected_gpu_count = tp_size * pp_size * infer_dp_size + actual_gpu_count = len(infer_devices) + assert expected_gpu_count == actual_gpu_count, ( + "Parallelism configuration mismatch: " + f"tp_size * pp_size * dp_size = {tp_size} * {pp_size} * {infer_dp_size} = {expected_gpu_count}, " + f"but device_mapping has {actual_gpu_count} GPUs" + ) + + gpus_per_dp_rank = tp_size * pp_size + freed_gpus = train_devices + self._validate_minimum_active_ranks(infer_dp_size, infer_devices, list(freed_gpus), gpus_per_dp_rank) + # Store TP/PP-aware attributes for GPU→dp_rank translation in shrink/expand. + self._infer_gpus_per_dp_rank = gpus_per_dp_rank + self._infer_device_mapping = list(self.actor_infer.worker_config.device_mapping) + logger.info(f"Partial GPU mode validated: infer_dp_size={infer_dp_size}, freed_gpus={sorted(freed_gpus)}") + return True + + if len(train_devices) == len(infer_devices): + raise RuntimeError( + "AgenticMultiLoraPipeline does not support actor_train/actor_infer colocating mode " + "(train device_mapping == infer device_mapping). Use partial overlap (actor_train ⊂ actor_infer)." 
+ ) + + raise RuntimeError( + "Unsupported device_mapping relationship for AgenticMultiLoraPipeline. " + f"train={sorted(train_devices)} infer={sorted(infer_devices)}" + ) + + def _validate_minimum_active_ranks( + self, + infer_dp_size: int, + infer_devices: set, + freed_gpu_list: list, + gpus_per_dp_rank: int, + ) -> None: + # TODO: extract _validate_minimum_active_ranks into a shared utility instead of + # calling an unbound method from a sibling class (fragile, bypasses inheritance). + from roll.pipeline.agentic.agentic_pipeline import AgenticPipeline + + AgenticPipeline._validate_minimum_active_ranks( + self, infer_dp_size, infer_devices, freed_gpu_list, gpus_per_dp_rank + ) + + def _prepare_batch(self, batch: DataProto, metrics: dict) -> DataProto: + """Transform raw rollout data into a training-ready batch for the actor update. + + Multi-LoRA pipelines do NOT use a critic (GAE is explicitly unsupported because a + single shared critic cannot produce accurate values across different LoRA tasks). + Instead, critic-free estimators like GRPO, Reinforce++, or GIGPO are used. + + Processing pipeline (in order): + 1. Discounted returns — collapse multi-step rewards into per-token returns. + 2. Batch adjustment — filter/transform samples (e.g. drop low-quality trajectories). + 3. Reference log-probs — dynamic_batching_shard (if enabled), disable LoRA adapter, + batch_balance, then compute log-probs under the frozen base model for KL penalty. + Matches agentic_pipeline.py:404-436. + 4. Old log-probs — compute log-probs under the *current* policy (LoRA enabled) to form + the importance-sampling ratio (π_new / π_old) used by the clipped surrogate objective. + If old-logprob recompute is disabled, zeros are used (ratio = 1, i.e. on-policy). + 5. Response-level mask — build segment/token masks that select which parts of the + response are included in the loss. + 6. Response-level rewards — normalize and reshape rewards per response segment. + 7. 
Token-level rewards — apply KL penalty and per-token reward shaping. + 8. Advantage estimation — critic-free estimator (GRPO / Reinforce++ / GIGPO) + over the shaped rewards. + 9. Train-infer correction — optionally down-weight stale samples whose old log-probs + diverge too far from the current policy (importance-weight clipping). + + Args: + batch: Raw rollout output containing token ids, attention masks, rewards, + and meta_info produced by the environment / rollout workers. + metrics: Mutable dict; timing and scalar metrics are added in-place. + + Returns: + The same ``batch`` object, enriched with fields required by the actor + training step: ``old_log_probs``, ``ref_log_probs``, ``advantages``, + ``returns``, token-level rewards, and response-level masks. + """ + # Step 1: collapse multi-step rewards into discounted returns. + batch = compute_discounted_returns(batch, self.pipeline_config.adv_estimator, self.pipeline_config.step_reward_gamma) + + # Step 2: filter/transform samples (e.g. drop low-quality trajectories). + batch = self.adjust_batch(batch, mode=self.pipeline_config.batch_adjust_mode) + metrics.update(reduce_metrics(batch.meta_info.pop("metrics", {}))) + + # Step 3: reference log-probs — run the base model (LoRA disabled) to get the + # KL-divergence anchor used in the training loss (matching agentic_pipeline.py:404-436). + with Timer(name="cal_ref_log_probs", logger=None) as cal_timer: + if self.pipeline_config.enable_reference: + # Dynamic batching for ref path (same guard as old log-prob path below). 
+ if self.pipeline_config.actor_train.use_dynamic_batching_in_infer: + batch, dynamic_batching_metrics = dynamic_batching_shard( + batch, + self.actor_train.dp_size, + self.pipeline_config.actor_train.max_tokens_per_microbatch_in_infer, + self.pipeline_config.actor_train.sequence_length_round_in_infer, + self.pipeline_config.actor_train.strategy_args.strategy_config.get("pipeline_model_parallel_size", 1), + self.pipeline_config.actor_train.strategy_args.strategy_config.get("virtual_pipeline_model_parallel_size", None), + "reference/compute_log_probs", + ) + metrics.update(dynamic_batching_metrics) + # For multi-LoRA, reference logprobs are computed by disabling the LoRA adapter on the actor. + batch.meta_info["disable_adapter"] = True + batch.meta_info["is_offload_states"] = False + batch_balance(batch, dp_size=self.actor_train.dp_size, minibatch_size=len(batch)) + ref_log_probs: DataProto = self.actor_train.compute_log_probs(batch, blocking=True) + batch.meta_info.pop("disable_adapter", None) + # Use rename + union to preserve all fields from the ref DataProto (matching agentic_pipeline.py:431-432). + ref_log_probs.rename(old_keys="log_probs", new_keys="ref_log_probs") + batch = batch.union(ref_log_probs) + avg_ref_log_prob = masked_mean(batch.batch["ref_log_probs"], batch.batch["response_mask"][:, 1:]) + metrics.update(reduce_metrics(ref_log_probs.meta_info.pop("metrics", {}))) + metrics.update({"critic/ref_log_prob/mean": avg_ref_log_prob.item()}) + metrics["time/step_ref_log_probs_values_reward"] = cal_timer.last + + # Re-balance after ref log-prob compute may have changed padding. + batch_balance(batch, dp_size=self.actor_train.dp_size, minibatch_size=len(batch)) + + # Step 4: old log-probs — compute π_old(a|s) under the current LoRA-enabled policy. + # These form the denominator of the importance-sampling ratio π_new / π_old. 
+ with Timer(name="cal_old_log_probs_values", logger=None) as cal_old_logpb_timer: + batch.meta_info["is_offload_states"] = False + if self.pipeline_config.enable_old_logprobs_recompute: + if self.pipeline_config.actor_train.use_dynamic_batching_in_infer: + batch, dynamic_batching_metrics = dynamic_batching_shard( + batch, + self.actor_train.dp_size, + self.pipeline_config.actor_train.max_tokens_per_microbatch_in_infer, + self.pipeline_config.actor_train.sequence_length_round_in_infer, + self.pipeline_config.actor_train.strategy_args.strategy_config.get("pipeline_model_parallel_size", 1), + self.pipeline_config.actor_train.strategy_args.strategy_config.get("virtual_pipeline_model_parallel_size", None), + "actor_train/compute_log_probs", + ) + metrics.update(dynamic_batching_metrics) + old_log_probs: DataProto = self.actor_train.compute_log_probs(batch, blocking=True) + batch.batch["old_log_probs"] = old_log_probs.batch["log_probs"] + avg_old_log_prob = masked_mean(batch.batch["old_log_probs"], batch.batch["response_mask"][:, 1:]) + metrics.update({"critic/old_log_prob/mean": avg_old_log_prob.item()}) + metrics.update(reduce_metrics(old_log_probs.meta_info.pop("metrics", {}))) + agg_entropy = agg_loss( + loss_mat=old_log_probs.batch["entropy"], + loss_mask=batch.batch["response_mask"][:, 1:], + loss_agg_mode="token-mean", + ) + metrics.update({"critic/entropy/mean": agg_entropy.item()}) + else: + batch.batch["old_log_probs"] = torch.zeros_like(batch.batch["attention_mask"][:, 1:]) + + # Reference logprobs (if reference disabled, mock with old_log_probs) + if not self.pipeline_config.enable_reference: + batch.batch["ref_log_probs"] = batch.batch["old_log_probs"].clone() + avg_ref_log_prob = masked_mean(batch.batch["ref_log_probs"], batch.batch["response_mask"][:, 1:]) + metrics.update({"critic/ref_log_prob/mean": avg_ref_log_prob.item()}) + metrics["time/step_old_log_probs_values"] = cal_old_logpb_timer.last + + # Step 5: build response-level masks that select which 
tokens/segments contribute to the loss. + with Timer(name="cal_response_level_mask", logger=None) as timer: + batch, mask_metrics = get_agentic_response_level_mask(batch, self.pipeline_config) + metrics.update(mask_metrics) + metrics["time/step_cal_response_level_mask"] = timer.last + + # Step 6: normalize and reshape rewards per response segment. + with Timer(name="cal_response_norm_rewards", logger=None) as timer: + batch, reward_metrics = compute_response_level_rewards(batch=batch, pipeline_config=self.pipeline_config) + metrics.update(reduce_metrics(batch.meta_info.pop("metrics", {}))) + metrics.update(reward_metrics) + metrics["time/step_cal_norm_rewards"] = timer.last + + # Step 7: apply KL penalty and per-token reward shaping via the adaptive KL controller. + with Timer(name="cal_token_reward", logger=None) as timer: + batch, token_level_metrics = compute_token_reward(batch, self.pipeline_config, self.kl_ctrl) + metrics.update(token_level_metrics) + metrics["time/step_cal_token_reward"] = timer.last + + # Step 8: compute advantages using critic-free estimator (GRPO / Reinforce++ / GIGPO). + with Timer(name="compute_advantage", logger=None) as timer: + batch = agentic_compute_advantage( + data=batch, + gamma=self.pipeline_config.gamma, + lambd=self.pipeline_config.lambd, + adv_estimator=self.pipeline_config.adv_estimator, + advantage_clip=self.pipeline_config.advantage_clip, + whiten_advantages=self.pipeline_config.whiten_advantages, + whiten_rewards=self.pipeline_config.whiten_rewards, + ) + metrics.update(reduce_metrics(batch.meta_info.pop("metrics", {}))) + metrics["time/step_adv"] = timer.last + # Step 9: importance-weight correction — down-weight stale samples whose old + # log-probs diverge too far from the current policy (only when old logprobs are recomputed). 
+ if self.pipeline_config.enable_old_logprobs_recompute: + batch, corr_metrics = apply_train_infer_correction_to_batch( + self.pipeline_config, + batch, + update_mask_keys=batch.meta_info["loss_mask_keys"], + ) + metrics.update(corr_metrics) + return batch + + def _shrink_workers(self, *, dp_ranks_to_remove: List[int]) -> Dict[str, Any]: + """Shrink inference off training GPUs across all per-tag schedulers. + + 2-phase pattern (mirrors agentic_pipeline._shrink_workers): + - Phase 1: all schedulers except last do routing-only shrink (skip_offload=True). + - Phase 2: last scheduler does routing + physical offload (skip_offload=False). + """ + if not isinstance(dp_ranks_to_remove, list) or not dp_ranks_to_remove: + raise ValueError("dp_ranks_to_remove must be a non-empty list[int]") + with self._infer_resize_lock: + all_schedulers = list(self.rollout_schedulers.values()) + list(self.val_rollout_schedulers.values()) + # Phase 1: routing-only shrink on all except last. + if len(all_schedulers) > 1: + phase1_metrics = ray.get( + [sched.shrink_sampler.remote(dp_ranks_to_remove, skip_offload=True) for sched in all_schedulers[:-1]] + ) + else: + phase1_metrics = [] + # Phase 2: last scheduler does routing + physical offload. + phase2_metrics = ray.get(all_schedulers[-1].shrink_sampler.remote(dp_ranks_to_remove, skip_offload=False)) + shrink_metrics_list = phase1_metrics + [phase2_metrics] + result: Dict[str, Any] = {} + for idx, shrink_metrics in enumerate(shrink_metrics_list): + result.update({f"shrink/{idx}/{k}": v for k, v in shrink_metrics.items()}) + return result + + def _expand_workers(self, *, dp_ranks_to_add: List[int], train_skip_load: bool) -> Dict[str, Any]: + """Expand inference back to training GPUs across all per-tag schedulers. + + Sequential pattern (mirrors agentic_pipeline._expand_workers): + - First scheduler does physical load (skip_load determined by train_skip_load). + - Rest do routing-only expand (skip_load=True). 
+ """ + if not isinstance(dp_ranks_to_add, list) or not dp_ranks_to_add: + raise ValueError("dp_ranks_to_add must be a non-empty list[int]") + with self._infer_resize_lock: + all_schedulers = list(self.rollout_schedulers.values()) + list(self.val_rollout_schedulers.values()) + # First scheduler loads model states (skip if model_update already loaded them). + first_metrics = ray.get(all_schedulers[0].expand_sampler.remote(dp_ranks_to_add, skip_load=bool(train_skip_load))) + # Rest do routing-only expand. + rest_metrics = ray.get( + [sched.expand_sampler.remote(dp_ranks_to_add, skip_load=True) for sched in all_schedulers[1:]] + ) + expand_metrics_list = [first_metrics] + rest_metrics + result: Dict[str, Any] = {} + for idx, expand_metrics in enumerate(expand_metrics_list): + result.update({f"expand/{idx}/{k}": v for k, v in expand_metrics.items()}) + return result + + @torch.no_grad() + def run(self): + if not is_lora_training(self.pipeline_config): + raise RuntimeError("AgenticMultiLoraPipeline requires actor_train.model_args.adapters to be configured.") + + success = False + try: + max_steps_per_lora = int(self.pipeline_config.max_steps) + adapters = list(self.pipeline_config.actor_train.model_args.adapters.keys()) + lora_step: dict[str, int] = {name: 0 for name in adapters} + global_tick = 0 + # Adapter keys in model_args.adapters are canonical lowercase (normalized in __post_init__). + tag_to_adapter = {tag: normalize_domain(tag) for tag in self.rollout_schedulers.keys()} + + # Resume per-lora state from checkpoint if available. 
+ if "lora_step_by_adapter" in self.state.kv: + saved_mapping = self.state.kv["tag_to_adapter"] + if saved_mapping != tag_to_adapter: + raise RuntimeError( + f"Checkpoint tag_to_adapter mismatch: saved={saved_mapping} current={tag_to_adapter}" + ) + lora_step = dict(self.state.kv["lora_step_by_adapter"]) + global_tick = int(self.state.kv["global_tick"]) + logger.info(f"Resumed from checkpoint: global_tick={global_tick} lora_step={lora_step}") + + unknown = sorted({a for a in tag_to_adapter.values() if a not in lora_step}) + if unknown: + raise RuntimeError( + f"Train env tags must map to configured LoRA adapters. Unknown adapters from tags: {unknown}. " + f"adapters={sorted(lora_step.keys())} tag_to_adapter={tag_to_adapter}" + ) + + # Calculate tokens-per-second system throughput + tps_timer = _Timer(window_size=5) + # Monotonic clock origin for all relative timestamps in this pipeline run. + pipeline_start_mono = time.monotonic() + + # Kick off one in-flight get_batch per tag. + in_flight: dict[str, ray.ObjectRef] = {} + pending_by_tag: dict[str, DataProto] = {} + # Per-tag monotonic timestamp of when get_batch.remote() was issued, + # used to measure rollout latency (submission → ray.wait ready). + submitted_at_mono: dict[str, float] = {} + tags = list(self.rollout_schedulers.keys()) + for tag in tags: + adapter = tag_to_adapter[tag] + if lora_step.get(adapter, 0) >= max_steps_per_lora: + continue + # Use per-adapter step for rollout-facing operations (not global_tick). + data = DataProto(meta_info={"global_step": lora_step.get(adapter, 0)}) + in_flight[tag] = self.rollout_schedulers[tag].get_batch.remote( + data, self.pipeline_config.rollout_batch_size + ) + submitted_at_mono[tag] = time.monotonic() + + # Monotonic timestamp when the current ray.wait polling started (None when not waiting). + # Used to measure wall-clock time spent blocked in ray.wait per tick. 
+ wait_ready_since_mono: float | None = None + # Single-adapter first-ready tick: each tick processes one ready tag batch. + last_get_batch_done_ts_by_adapter: dict[str, float] = {} + last_train_step_done_ts_by_adapter: dict[str, float] = {} + last_train_step_done_ts_global: float | None = None + + while any(lora_step[name] < max_steps_per_lora for name in adapters): + active_tags = [tag for tag in tags if lora_step.get(tag_to_adapter[tag], 0) < max_steps_per_lora] + active_tags_in_flight = [tag for tag in active_tags if tag in in_flight] + active_refs = [in_flight[tag] for tag in active_tags_in_flight] + assert len(active_refs) > 0 + + if wait_ready_since_mono is None: + wait_ready_since_mono = time.monotonic() + + # ray.wait with no timeout blocks until num_returns refs are ready. + ready, _ = ray.wait(active_refs, num_returns=1) + + ready_now_mono = time.monotonic() + + tick_wait_ready_batch_s = ready_now_mono - wait_ready_since_mono + wait_ready_since_mono = None + + # Single-adapter tick: consume exactly one ready batch per train_step_lora call. + ready_ref = ready[0] + ready_tag = next((t for t, r in in_flight.items() if r == ready_ref), None) + if ready_tag is None: + raise RuntimeError("ray.wait returned a ref that is not tracked in in_flight") + + batch = ray.get(ready_ref) + if batch is None: + raise RuntimeError(f"get_batch returned None for tag={ready_tag!r}") + # Derive sample UUIDs from traj_id (same as agentic_pipeline.py:338-339). 
+ sample_uuids = [f"{traj_id}_{idx}" for idx, traj_id in enumerate(batch.non_tensor_batch['traj_id'])] + batch.non_tensor_batch['sample_uuid'] = np.array(sample_uuids, dtype=object) + # Align with AgenticPipeline timing metrics: + # - time/get_batch_cost_train from rollout scheduler's internal marker (if present) + # - time/step_rollout approximated later as (wait + preprocess) per adapter + batch.meta_info.setdefault("metrics", {}) + batch.meta_info["metrics"]["time/ray_wait_ready_batch_s"] = tick_wait_ready_batch_s + adapter_name = tag_to_adapter.get(ready_tag, ready_tag) + get_batch_done_ts = time.monotonic() - pipeline_start_mono + batch.meta_info["metrics"]["time/get_batch_done_ts"] = get_batch_done_ts + issue_mono = submitted_at_mono.get(ready_tag) + if issue_mono is None: + raise RuntimeError(f"Missing submitted_at timestamp for ready tag={ready_tag!r}") + issue_ts = issue_mono - pipeline_start_mono + batch.meta_info["metrics"]["time/get_batch_issue_ts"] = issue_ts + batch.meta_info["metrics"]["time/get_batch_latency_s"] = get_batch_done_ts - issue_ts + prev_done_ts = last_get_batch_done_ts_by_adapter.get(adapter_name) + batch.meta_info["metrics"]["time/get_batch_done_interval_s"] = ( + 0.0 if prev_done_ts is None else get_batch_done_ts - prev_done_ts + ) + last_get_batch_done_ts_by_adapter[adapter_name] = get_batch_done_ts + if "get_batch_return_start_time" in batch.meta_info: + batch.meta_info["metrics"]["time/get_batch_cost_train"] = ( + time.time() - batch.meta_info.pop("get_batch_return_start_time") + ) + pending_by_tag[ready_tag] = batch + in_flight.pop(ready_tag, None) + start_mono = submitted_at_mono.pop(ready_tag, None) + if start_mono is None: + raise RuntimeError(f"Missing submitted_at timestamp for popped tag={ready_tag!r}") + wait_s = time.monotonic() - start_mono + batch.meta_info["metrics"]["time/get_batch_wait_s"] = wait_s + logger.info(f"get_batch done tag={ready_tag!r} global_tick={global_tick} elapsed_s={wait_s:.3f}") + + # Greedy tick: 
once any tag has a ready batch, proceed to train. In partial-GPU mode, `shrink_sampler` + # relies on RequestScheduler to abort/remap + update routing safely for any in-flight requests. + + tick_metrics: dict = {} + shrink_duration_s: Optional[float] = None + with Timer(name="pipeline_tick_total", logger=None) as tick_timer: + with tps_timer: + # Partial GPU: shrink inference off training GPUs before training. + if self.partial_gpu_mode: + with Timer(name="exec_shrink", logger=None) as shrink_timer: + target_gpus: list[int] = [] + if hasattr(self.actor_train.worker_config, "device_mapping") and self.actor_train.worker_config.device_mapping: + target_gpus.extend(self.actor_train.worker_config.device_mapping) + if target_gpus: + dp_ranks = self._target_gpus_to_dp_ranks_to_remove( + target_gpus=target_gpus, + ) + tick_metrics.update(self._shrink_workers(dp_ranks_to_remove=dp_ranks)) + shrink_duration_s = float(shrink_timer.last) + + # Collect actor inference metrics once per tick + actor_infer_metrics = self.actor_infer.get_metrics() + actor_infer_reduced = {} + if "metrics" in actor_infer_metrics.meta_info: + actor_infer_reduced = reduce_metrics(actor_infer_metrics.meta_info.pop("metrics", {})) + + # Exactly one batch is ready per tick (ray.wait returns 1, + # cleared before next iteration). 
+ if len(pending_by_tag) != 1: + raise RuntimeError( + f"Expected exactly 1 pending batch per tick, got {len(pending_by_tag)}: " + f"{sorted(pending_by_tag.keys())}" + ) + (ready_tag_for_tick, ready_batch_for_tick), = pending_by_tag.items() + + dirty_adapters: set[str] = set() + lora_metrics: dict[str, dict] = {} + + adapter_for_tag = tag_to_adapter[ready_tag_for_tick] + adapter_metrics = lora_metrics.setdefault(adapter_for_tag, {}) + if actor_infer_reduced: + adapter_metrics.update(actor_infer_reduced) + tick_wait_ready_batch_s = float( + ready_batch_for_tick.meta_info.get("metrics", {}).get("time/ray_wait_ready_batch_s", 0.0) or 0.0 + ) + tick_metrics["time/ray_wait_ready_batch_s"] = tick_wait_ready_batch_s + adapter_metrics["time/ray_wait_ready_batch_s"] = tick_wait_ready_batch_s + + wait_s = float(ready_batch_for_tick.meta_info.get("metrics", {}).get("time/get_batch_wait_s", 0.0) or 0.0) + # Use per-adapter step for rollout-facing metadata (not global_tick). + ready_batch_for_tick.meta_info["global_step"] = lora_step[adapter_for_tag] + ready_batch_for_tick.meta_info["_broadcast_non_tensor_batch"] = True + # Keep strategy token-count accounting contract identical to agentic_pipeline. + ready_batch_for_tick.meta_info["loss_mask_keys"] = ["response_mask"] + with Timer(name="rollout", logger=None) as rollout_timer: + adapter_metrics.update(reduce_metrics(ready_batch_for_tick.meta_info.pop("metrics", {}))) + adapter_metrics.update(compute_rollout_traj_metrics(ready_batch_for_tick)) + dump_rollout_trajectories(self.pipeline_config.rollout_dump_dir, lora_step[adapter_for_tag], ready_batch_for_tick) + adapter_metrics["time/step_rollout"] = rollout_timer.last + wait_s + + prepared_batch = self._prepare_batch(ready_batch_for_tick, adapter_metrics) + + # Extract the single adapter name from the prepared batch. 
+ lora_names = prepared_batch.non_tensor_batch["lora_name"] + unique = list(dict.fromkeys(lora_names.tolist())) + if len(unique) != 1: + raise RuntimeError(f"Expected homogeneous lora_name per prepared batch, got {unique}") + adapter_name = str(unique[0]) + # Fail fast on adapter mismatch: adapter_for_tag is the canonical + # step source for rollout dump, model_update, and checkpoint. + if adapter_name != adapter_for_tag: + raise RuntimeError( + f"Adapter mismatch: tag={ready_tag_for_tick!r} expected adapter={adapter_for_tag!r} " + f"but prepared batch contains adapter={adapter_name!r}" + ) + dirty_adapters.add(adapter_name) + + # Per-adapter data metrics inline (single batch, no deferred concat needed). + with Timer(name="compute_data_metrics", logger=None) as data_metrics_timer: + adapter_metrics.update(compute_train_data_metrics(batch=prepared_batch)) + adapter_metrics["time/step_compute_data_metrics"] = data_metrics_timer.last + + # Balance batch for training (production pattern: agentic_pipeline.py:534-537). + batch_balance_metrics = batch_balance( + batch=prepared_batch, + dp_size=self.actor_train.dp_size, + minibatch_size=self.actor_train.dp_size + * self.pipeline_config.actor_train.training_args.per_device_train_batch_size + * self.pipeline_config.actor_train.training_args.gradient_accumulation_steps, + logging_prefix="global_seqlen/actor_train", + ) + tick_metrics.update(batch_balance_metrics) + adapter_metrics.update(batch_balance_metrics) + + # Dynamic batching: shard prepared_batch before train_step_lora + # (same pattern as agentic_pipeline.py train_step path). 
+ if self.pipeline_config.actor_train.use_dynamic_batching_in_train: + prepared_batch, dynamic_batching_metrics = dynamic_batching_shard( + prepared_batch, + self.actor_train.dp_size, + self.pipeline_config.actor_train.max_tokens_per_microbatch_in_train, + self.pipeline_config.actor_train.sequence_length_round_in_train, + self.pipeline_config.actor_train.strategy_args.strategy_config.get( + "pipeline_model_parallel_size", 1 + ), + self.pipeline_config.actor_train.strategy_args.strategy_config.get( + "virtual_pipeline_model_parallel_size", None + ), + "actor_train/train_step_lora", + ) + adapter_metrics.update(dynamic_batching_metrics) + + # Train single adapter. + with Timer(name="train_timer", logger=None) as train_timer: + train_refs: list[ray.ObjectRef] = self.actor_train.train_step_lora(prepared_batch, blocking=False) + train_metrics = DataProto.materialize_concat(data_refs=train_refs) + reduced_train_metrics = reduce_metrics(train_metrics.meta_info.pop("metrics", {})) + tick_metrics.update(reduced_train_metrics) + tps_timer.push_units_processed(n=torch.sum(prepared_batch.batch["attention_mask"]).detach().item()) + train_step_s = float(train_timer.last) + train_step_done_ts = time.monotonic() - pipeline_start_mono + tick_metrics["time/train_step_done_ts"] = train_step_done_ts + tick_metrics["time/train_step_done_interval_s"] = ( + 0.0 + if last_train_step_done_ts_global is None + else train_step_done_ts - last_train_step_done_ts_global + ) + last_train_step_done_ts_global = train_step_done_ts + tick_metrics["system/tps"] = tps_timer.mean_throughput + for name in dirty_adapters: + adapter_metrics = lora_metrics.setdefault(name, {}) + adapter_metrics["time/step_train"] = train_step_s + adapter_metrics["time/step_train_step_lora"] = train_step_s + adapter_metrics["time/train_step_done_ts"] = train_step_done_ts + prev_train_done_ts = last_train_step_done_ts_by_adapter.get(name) + lora_train_step_interval_s = ( + 0.0 if prev_train_done_ts is None else 
train_step_done_ts - prev_train_done_ts + ) + adapter_metrics["time/train_step_done_interval_s"] = lora_train_step_interval_s + last_train_step_done_ts_by_adapter[name] = train_step_done_ts + for k, v in reduced_train_metrics.items(): + if f"/{name}/" in k: + adapter_metrics[k] = v + else: + adapter_metrics.setdefault(k, v) + + # Update step counters. + lora_step[adapter_for_tag] += 1 + global_tick += 1 + + tick_metrics["system/global_tick"] = global_tick + for name, step in lora_step.items(): + tick_metrics[f"system/lora_step/{name}"] = step + # Cumulative sample count (pattern from agentic_pipeline.py:569). + tick_metrics["system/samples"] = global_tick * self.pipeline_config.rollout_batch_size + for name in dirty_adapters: + adapter_metrics = lora_metrics.setdefault(name, {}) + adapter_metrics["system/global_tick"] = global_tick + adapter_metrics["system/lora_step"] = lora_step[adapter_for_tag] + + # Model update boundary: suspend rollouts only for model_update. + # TODO: fine-granular rollout interruption — currently we abort ALL loras' rollouts + # and force-sync ALL adapters. Better approach: only abort/interrupt requests for the + # just-trained adapter (dirty_adapters), leave other loras' in-flight rollouts running, + # and sync only the updated adapter weights instead of all_adapters. + with Timer(name="model_update", logger=None) as model_update_timer: + ray.get([sched.suspend.remote() for sched in self.rollout_schedulers.values()]) + + if self.pipeline_config.async_pipeline: + # Full offload: stop generation server, discard KV cache + all LoRA tensors. + self.actor_infer.offload_states() + # Full offload destroys all LoRA tensors on infer side — must re-sync every adapter. + # Train-side weights are preserved in pinned CPU memory across offload cycles. 
+ all_adapters = set(self.pipeline_config.actor_train.model_args.adapters.keys()) if self.pipeline_config.actor_train.model_args.adapters else None + model_update_metrics = self.model_update_lora_subset(global_tick, adapters_to_update=all_adapters) + tick_metrics.update(model_update_metrics) + for name in dirty_adapters: + lora_metrics.setdefault(name, {}).update(model_update_metrics) + self.actor_infer.load_states() + # Partial GPU: expand routing state after model_update reloads to all GPUs. + if self.partial_gpu_mode and global_tick > 0: + target_gpus = [] + + target_gpus.extend(self.actor_train.worker_config.device_mapping) + + + # but the lost rank is silent — only the alignment warning in the callee signals it. + dp_ranks_to_add = self._target_gpus_to_dp_ranks_to_add(target_gpus=target_gpus) + expand_result = self._expand_workers(dp_ranks_to_add=dp_ranks_to_add, + train_skip_load=True) + + + tick_metrics.update(expand_result) + for name in dirty_adapters: + lora_metrics.setdefault(name, {}).update(expand_result) + + model_update_s = float(model_update_timer.last) + tick_metrics["time/step_model_update"] = model_update_s + for name in dirty_adapters: + lora_metrics.setdefault(name, {})["time/step_model_update"] = model_update_s + + # Per-adapter validation: run after model_update + expand so inference + # weights are current and schedulers are resumed. 
+ if self.pipeline_config.eval_steps > 0: + for name in dirty_adapters: + if lora_step[name] % self.pipeline_config.eval_steps == 0: + with Timer(name="val", logger=None) as val_timer: + val_metrics = self.val(lora_name=name, global_step=lora_step[name]) + val_metrics["time/step_val"] = val_timer.last + lora_metrics.setdefault(name, {}).update(val_metrics) + + tick_total_s = float(tick_timer.last) + for name in dirty_adapters: + lora_metrics.setdefault(name, {})["time/step_total"] = tick_total_s + if shrink_duration_s is not None: + lora_metrics.setdefault(name, {})["time/step_shrink"] = shrink_duration_s + + if self.pipeline_config.logging_steps > 0 and global_tick % self.pipeline_config.logging_steps == 0: + logger.info(f"tick={global_tick} lora_step={lora_step}") + logger.info(tick_metrics) + + # Per-LoRA metrics to per-LoRA trackers (independent step counters). + if hasattr(self, "lora_trackers"): + for name in sorted(dirty_adapters): + per_lora_metrics = dict(lora_metrics.get(name, {})) + per_lora_metrics["system/lora_name"] = name + self.lora_trackers[name].log(values=per_lora_metrics, step=lora_step[name]) + # Global tick metrics to pipeline-level tracker (shared step counter). + self.tracker.log(values=tick_metrics, step=global_tick) + + # Persist per-lora state for checkpoint resume. + all_done = all(lora_step[name] >= max_steps_per_lora for name in adapters) + self.state.kv["lora_step_by_adapter"] = dict(lora_step) + self.state.kv["global_tick"] = global_tick + self.state.kv["tag_to_adapter"] = dict(tag_to_adapter) + self.state.step = global_tick + # Minimal log_history entry for do_checkpoint (reads log_history[-1] for system/step). + # Do not persist full tick_metrics: base resume replay lacks lora_name context. 
+ self.state.log_history.append({"system/step": global_tick}) + self.do_checkpoint(global_step=global_tick, is_last_step=all_done) + + pending_by_tag.clear() + for tag in tags: + adapter = tag_to_adapter[tag] + if lora_step.get(adapter, 0) >= max_steps_per_lora: + in_flight.pop(tag, None) + continue + if tag in in_flight: + # Keep the existing in-flight request; do not clobber it. + continue + # Use post-increment lora_step for next tick's rollout. + data = DataProto(meta_info={"global_step": lora_step[adapter]}) + in_flight[tag] = self.rollout_schedulers[tag].get_batch.remote( + data, self.pipeline_config.rollout_batch_size + ) + submitted_at_mono[tag] = time.monotonic() + + success = True + finally: + try: + ray.get( + [sched.shutdown.remote() for sched in self.rollout_schedulers.values()] + + [sched.shutdown.remote() for sched in self.val_rollout_schedulers.values()] + ) + except Exception: + logger.exception("Failed to shutdown rollout schedulers") + try: + if hasattr(self, "lora_trackers"): + for lora_tracker in self.lora_trackers.values(): + lora_tracker.finish() + self.tracker.finish() + except Exception: + logger.exception("tracker.finish failed") + if success: + logger.info("pipeline complete!") diff --git a/roll/pipeline/agentic/agentic_pipeline.py b/roll/pipeline/agentic/agentic_pipeline.py index d3e812cd7..6fd1b8f7a 100644 --- a/roll/pipeline/agentic/agentic_pipeline.py +++ b/roll/pipeline/agentic/agentic_pipeline.py @@ -2,7 +2,8 @@ import os.path import time from concurrent.futures import ThreadPoolExecutor -from typing import Any, Dict, List +import threading +from typing import Any, Dict, List, Optional import numpy as np import ray @@ -26,7 +27,7 @@ get_agentic_response_level_mask, ) from roll.pipeline.base_pipeline import BasePipeline -from roll.utils.constants import RAY_NAMESPACE +from roll.utils.constants import DO_TIME_SHARING, RAY_NAMESPACE from roll.utils.dynamic_batching import dynamic_batching_shard from roll.utils.functionals import ( 
RunningMoments, @@ -39,12 +40,13 @@ from roll.utils.train_infer_corrections import apply_train_infer_correction_to_batch from roll.utils.kl_controller import get_kl_controller from roll.utils.logging import get_logger -from roll.utils.offload_states import OffloadStateType + logger = get_logger() + def is_lora_training(pipeline_config: AgenticConfig) -> bool: return pipeline_config.actor_train.model_args.lora_target is not None @@ -54,10 +56,23 @@ def __init__(self, pipeline_config: AgenticConfig): self.pipeline_config: AgenticConfig self.pipeline_config.set_max_steps(max_steps=self.pipeline_config.max_steps) + + # AgenticPipeline supports at most one adapter (the auto-converted "default"). + # Multi-adapter training requires AgenticMultiLoraPipeline for per-adapter step counting. + if self.pipeline_config.actor_train.model_args.is_multi_lora: + adapter_names = sorted(self.pipeline_config.actor_train.model_args.adapters.keys()) + raise RuntimeError( + f"AgenticPipeline supports at most 1 LoRA adapter, got {len(adapter_names)}: {adapter_names}. " + "For multi-adapter training, set pipeline_cls to " + "roll.pipeline.agentic.agentic_multi_lora_pipeline.AgenticMultiLoraPipeline." + ) + self.use_ref_model = self.pipeline_config.enable_reference and (not is_lora_training(self.pipeline_config)) # Derived configuration for partial GPU mode (auto-detected from device_mapping) self.partial_gpu_mode: bool = False + # Agentic pipeline does not support time-sharing mode. 
+ assert not DO_TIME_SHARING, "AgenticPipeline must not be instantiated with DO_TIME_SHARING=True" self.kl_ctrl = get_kl_controller( init_kl_coef=self.pipeline_config.init_kl_coef, @@ -79,6 +94,7 @@ def __init__(self, pipeline_config: AgenticConfig): resource_manager=self.resource_manager, worker_config=self.pipeline_config.actor_infer, ) + download_clusters = [self.actor_train, self.actor_infer] if self.use_ref_model: @@ -125,6 +141,7 @@ def __init__(self, pipeline_config: AgenticConfig): name=f"RewardScheduler-{self.pipeline_config.reward.name}", get_if_exists=True, namespace=RAY_NAMESPACE, + scheduling_strategy=NodeAffinitySchedulingStrategy( node_id=ray.get_runtime_context().get_node_id(), soft=False, @@ -139,6 +156,7 @@ def __init__(self, pipeline_config: AgenticConfig): # INIT PHASE: Create RolloutSchedulers self.train_rollout_scheduler = ray.remote(RolloutScheduler).options( name="RolloutScheduler-train", + namespace=RAY_NAMESPACE, scheduling_strategy=NodeAffinitySchedulingStrategy( node_id=ray.get_runtime_context().get_node_id(), soft=False)).remote( @@ -151,6 +169,7 @@ def __init__(self, pipeline_config: AgenticConfig): self.val_rollout_scheduler = ray.remote(RolloutScheduler).options( name="RolloutScheduler-val", + namespace=RAY_NAMESPACE, scheduling_strategy=NodeAffinitySchedulingStrategy( node_id=ray.get_runtime_context().get_node_id(), soft=False)).remote( @@ -162,7 +181,20 @@ def __init__(self, pipeline_config: AgenticConfig): ) self.val_dataset_manager = GlobalDatasetManager.options(name=f"val_dataset_manager", get_if_exists=True, - namespace=RAY_NAMESPACE).remote() + namespace=RAY_NAMESPACE, + + ).remote() + + # Per-pipeline infer resize serialization boundary (ENG-123). 
+ infer_strategy_config = self.actor_infer.worker_config.strategy_args.strategy_config + tp_size = int(infer_strategy_config.get("tensor_parallel_size", 1)) + pp_size = int(infer_strategy_config.get("pipeline_parallel_size", 1)) + self._infer_gpus_per_dp_rank = tp_size * pp_size + self._infer_device_mapping = list(getattr(self.pipeline_config.actor_infer, "device_mapping", None) or []) + if not self._infer_device_mapping: + raise RuntimeError("actor_infer.device_mapping must be set") + self._infer_resize_lock = threading.Lock() + # INIT PHASE: Initialize Clusters refs: List[ray.ObjectRef] = [] refs.extend(self.actor_train.initialize(pipeline_config=self.pipeline_config, blocking=False)) @@ -193,12 +225,51 @@ def __init__(self, pipeline_config: AgenticConfig): self.running = RunningMoments() + # TODO: sync LoRA adapters to actor_infer before first rollout (see AgenticMultiLoraPipeline._initial_model_update). + # Validate partial GPU mode configuration and set self.partial_gpu_mode if self.pipeline_config.partial_gpu_mode: self.partial_gpu_mode = self._validate_partial_gpu_config() else: self.partial_gpu_mode = False + + def _shrink_workers(self, *, dp_ranks_to_remove: List[int]) -> Dict[str, Any]: + """Pipeline-local shrink helper (ENG-123). + + Serializes with a per-pipeline lock and performs: + - val routing-only shrink first (skip_offload=True) + - train shrink second (skip_offload=False; real offload) + """ + if not isinstance(dp_ranks_to_remove, list) or not dp_ranks_to_remove: + raise ValueError("dp_ranks_to_remove must be a non-empty list[int]") + with self._infer_resize_lock: + val_metrics = ray.get( + self.val_rollout_scheduler.shrink_sampler.remote(dp_ranks_to_remove, skip_offload=True) + ) + train_metrics = ray.get( + self.train_rollout_scheduler.shrink_sampler.remote(dp_ranks_to_remove, skip_offload=False) + ) + # Val scheduler call is routing-only side effect; discard its metrics to match upstream return shape. 
+ return dict(train_metrics or {}) + + def _expand_workers(self, *, dp_ranks_to_add: List[int], train_skip_load: bool) -> Dict[str, Any]: + """Pipeline-local expand helper (ENG-123). + + Serializes with a per-pipeline lock and performs: + - train expand first (skip_load=train_skip_load) + - val routing-only expand second (skip_load=True) + """ + if not isinstance(dp_ranks_to_add, list) or not dp_ranks_to_add: + raise ValueError("dp_ranks_to_add must be a non-empty list[int]") + with self._infer_resize_lock: + train_metrics = ray.get( + self.train_rollout_scheduler.expand_sampler.remote(dp_ranks_to_add, skip_load=bool(train_skip_load)) + ) + # Val scheduler call is routing-only side effect; discard its metrics to match upstream return shape. + ray.get(self.val_rollout_scheduler.expand_sampler.remote(dp_ranks_to_add, skip_load=True)) + return dict(train_metrics or {}) + @torch.no_grad() def run(self): # Calculate tokens-per-second system throughput @@ -223,9 +294,9 @@ def run(self): # Suspend rollout scheduler to pause request processing ray.get(self.train_rollout_scheduler.suspend.remote()) - # Stop generation server if using async mode (will restart after model update) + # Full offload: stop generation server, discard KV cache + LoRA (will restart after model update). 
if self.pipeline_config.async_pipeline: - self.actor_infer.offload_states(include=OffloadStateType.other_params) + self.actor_infer.offload_states() # PHASE 3: Model Update with Timer(name="model_update", logger=None) as model_update_timer: @@ -252,12 +323,21 @@ def run(self): if hasattr(self.critic.worker_config, 'device_mapping') and self.critic.worker_config.device_mapping: target_gpus.extend(self.critic.worker_config.device_mapping) - if target_gpus: - expand_metrics = ray.get( - self.train_rollout_scheduler.expand_sampler.remote(target_gpus, skip_load=True) - ) - logger.info(f"Expand routing state (skip_load): {expand_metrics}") - metrics.update({"expand/" + k: v for k, v in expand_metrics.items()}) + # Routing restore after model_update: model_update loaded states to all GPUs in actor_infer's + # device_mapping. Expand should restore routing without re-loading. + # + # Expand selection rule (current requirement): add a dp-rank iff its full dp-slice (the + # gpus_per_dp_rank-sized slice in device_mapping) is fully contained in the available GPU set. + # + # Limitation: if a dp-rank slice straddles the trainer/infer GPU boundary, it will be + # removed during shrink (intersection semantics) but never re-added during expand (subset + # semantics), effectively reducing rollout parallelism. The callee pair + # (_target_gpus_to_dp_ranks_to_remove / _target_gpus_to_dp_ranks_to_add) handles this safely + # but the lost rank is silent — only the alignment warning in the callee signals it. 
+ dp_ranks_to_add = self._target_gpus_to_dp_ranks_to_add(target_gpus=target_gpus) + expand_metrics = self._expand_workers(dp_ranks_to_add=dp_ranks_to_add, train_skip_load=True) + logger.info(f"Expand routing state: {expand_metrics}") + metrics.update({"expand/" + k: v for k, v in expand_metrics.items()}) batch: DataProto = DataProto() batch.meta_info = {"global_step": global_step} @@ -317,7 +397,7 @@ def run(self): # During training: actor_train uses freed GPUs [0,1] # Next iteration: model_update reloads actor_infer to all GPUs [0,1,2,3] elif self.partial_gpu_mode: - with Timer(name="cal_ref_log_probs", logger=None) as shrink_timer: + with Timer(name="exec_shrink", logger=None) as shrink_timer: target_gpus = [] # Collect actor_train GPUs if hasattr(self.actor_train.worker_config, 'device_mapping') and self.actor_train.worker_config.device_mapping: @@ -328,7 +408,8 @@ def run(self): target_gpus.extend(self.critic.worker_config.device_mapping) assert target_gpus, "cannot be empty" - shrink_metrics = ray.get(self.train_rollout_scheduler.shrink_sampler.remote(target_gpus)) + dp_ranks_to_remove = self._target_gpus_to_dp_ranks_to_remove(target_gpus=list(target_gpus)) + shrink_metrics = self._shrink_workers(dp_ranks_to_remove=dp_ranks_to_remove) logger.info(f"Shrink sampler: {shrink_metrics}") metrics.update({"shrink/" + k: v for k, v in shrink_metrics.items()}) metrics["time/step_shrink"] = shrink_timer.last diff --git a/roll/pipeline/agentic/env/deepeyes/env.py b/roll/pipeline/agentic/env/deepeyes/env.py index 8b3b31cfc..9300d075f 100644 --- a/roll/pipeline/agentic/env/deepeyes/env.py +++ b/roll/pipeline/agentic/env/deepeyes/env.py @@ -20,7 +20,7 @@ from roll.pipeline.rlvr.rlvr_config import RewardConfig from roll.pipeline.agentic.llm_proxy.proxy_utils import generate_by_proxy from roll.utils.checkpoint_manager import file_lock_context -from roll.utils.constants import RAY_NAMESPACE, EpisodeStopReason +from roll.utils.constants import RAY_NAMESPACE, EpisodeStopReason, 
rlix_env_vars from roll.utils.random_utils import all_seed from roll.utils.logging import get_logger @@ -207,7 +207,10 @@ def __init__( # Convert train/val mode to sample/traversal for GlobalDataset global_dataset_mode = "sample" if self.mode == "train" else "traversal" self.dataset = DeepEyesDataset.options( - name=f"{self.mode}_deepeyes", get_if_exists=True, namespace=RAY_NAMESPACE + name=f"{self.mode}_deepeyes", + get_if_exists=True, + namespace=RAY_NAMESPACE, + runtime_env={"env_vars": rlix_env_vars()}, ).remote( dataset_name=data_args.file_name, split="train", @@ -218,7 +221,10 @@ def __init__( idx=idx, ) self.dataset_manager = GlobalDatasetManager.options( - name=f"{self.mode}_dataset_manager", get_if_exists=True, namespace=RAY_NAMESPACE + name=f"{self.mode}_dataset_manager", + get_if_exists=True, + namespace=RAY_NAMESPACE, + runtime_env={"env_vars": rlix_env_vars()}, ).remote() ray.get(self.dataset_manager.register.remote(dataset_name="deepeyes", dataset_ref=self.dataset)) diff --git a/roll/pipeline/agentic/env/gem/math_env.py b/roll/pipeline/agentic/env/gem/math_env.py index cec5cc513..cecc59fda 100644 --- a/roll/pipeline/agentic/env/gem/math_env.py +++ b/roll/pipeline/agentic/env/gem/math_env.py @@ -11,7 +11,7 @@ import ray from roll.datasets.global_dataset import GlobalDataset, GlobalDatasetManager -from roll.utils.constants import RAY_NAMESPACE +from roll.utils.constants import RAY_NAMESPACE, rlix_env_vars logger = logging.getLogger(__name__) @@ -38,12 +38,16 @@ def __init__( global_dataset_mode = "sample" if self.mode == "train" else "traversal" self.dataset = GlobalDataset.options(name=f"{self.mode}_{dataset_name}", get_if_exists=True, - namespace=RAY_NAMESPACE).remote(dataset_name=dataset_name, + namespace=RAY_NAMESPACE, + runtime_env={"env_vars": rlix_env_vars()}, + ).remote(dataset_name=dataset_name, split=split, mode=global_dataset_mode) self.dataset_manager = GlobalDatasetManager.options(name=f"{self.mode}_dataset_manager", get_if_exists=True, - 
namespace=RAY_NAMESPACE).remote() + namespace=RAY_NAMESPACE, + runtime_env={"env_vars": rlix_env_vars()}, + ).remote() ray.get(self.dataset_manager.register.remote(dataset_name=dataset_name, dataset_ref=self.dataset)) self.idx = 0 self.epoch = 0 @@ -93,4 +97,4 @@ def step( "metrics": metrics, "metrics_agg_mode": metrics_agg_mode } - return TERMINAL_STATE, reward, True, True, info \ No newline at end of file + return TERMINAL_STATE, reward, True, True, info diff --git a/roll/pipeline/agentic/env_manager/agent_native_env_manager.py b/roll/pipeline/agentic/env_manager/agent_native_env_manager.py index 6d11285a6..fe292920d 100644 --- a/roll/pipeline/agentic/env_manager/agent_native_env_manager.py +++ b/roll/pipeline/agentic/env_manager/agent_native_env_manager.py @@ -1,5 +1,6 @@ import copy import json +import os import time from datetime import datetime from typing import List, Union, Dict, Optional @@ -15,11 +16,12 @@ from roll.distributed.scheduler.protocol import DataProto from roll.pipeline.agentic.env_manager.token_mask_utils import convert_list_content_str from roll.pipeline.agentic.env_manager.traj_env_manager import TrajEnvManager -from roll.utils.constants import GenerateStopReason, EpisodeStopReason +from roll.utils.constants import DO_TIME_SHARING, GenerateStopReason, EpisodeStopReason from roll.utils.functionals import pad_to_length, aggregate_metrics from roll.utils.hash_utils import compute_object_hash + class AgentNativeStepEnvManager(TrajEnvManager): """ Used for native like format. 
@@ -73,6 +75,8 @@ def run_rollout_loop(self, data: DataProto): self.stop_reason = EpisodeStopReason.MAX_LENGTH elif stop_reason == GenerateStopReason.ABORT: self.stop_reason = EpisodeStopReason.ABORT + + self.rollout_cache.attempt += 1 self.log_stats["current_step"].append(self.current_step) self.log_stats["generate_time"].append(round(generate_timer.last)) @@ -139,6 +143,7 @@ def step(self, llm_output: DataProto): observation, reward, terminated, truncated, info = self.env.step(action=response) self.rollout_cache.step += 1 + self.rollout_cache.attempt = 0 # terminated 完全由swe|tb env决定 self.rollout_cache.terminated = terminated @@ -216,6 +221,10 @@ def format_messages(self, rollout_cache: RolloutCache) -> DataProto: "attention_mask": attention_mask, "position_ids": position_ids, }, batch_size=input_ids.shape[0]) + # Inject lora_name for inference routing; single-adapter uses sole key, multi-adapter validates normalized tag. + lora_name = self._resolve_lora_name(self.pipeline_config.actor_infer.model_args.adapters) + if lora_name is not None: + lm_input.non_tensor_batch["lora_name"] = np.array([lora_name], dtype=object) current_cache["prompt_ids"] = prompt_ids current_cache['state_hash'] = compute_object_hash(messages) @@ -238,6 +247,8 @@ def formulate_rollouts(self, rollout_cache: RolloutCache): samples: List[DataProto] = [] step_rewards = [i['reward'] for i in self.rollout_cache.history] episode_score = sum(step_rewards) + # Resolve lora_name once per rollout; adapter map and tag are rollout-constant. 
+ _lora_name = self._resolve_lora_name(self.pipeline_config.actor_train.model_args.adapters) # Initialize lists for step length statistics step_prompt_length_list = [] @@ -302,6 +313,7 @@ def formulate_rollouts(self, rollout_cache: RolloutCache): "env_ids": np.array([self.rollout_cache.env_id], dtype=object), "group_ids": np.array([self.rollout_cache.group_id], dtype=object), "tags": np.array([self.rollout_cache.tag], dtype=object), + **({"lora_name": np.array([_lora_name], dtype=object)} if _lora_name is not None else {}), "step_scores": np.array([history["reward"]], dtype=object), # step-level reward, return by env "episode_scores": np.array([episode_score], dtype=object), "state_hash": np.array([history['state_hash']], dtype=object), @@ -436,7 +448,8 @@ def create_placeholder_rollout(self, episode_id): """ self.logger.info(f"[PLACEHOLDER_ROLLOUT] failure_mode: {self.failure_mode}") - seq_len = length=self.pipeline_config.sequence_length + # Keep placeholder length aligned with training sequence_length to preserve tensor contracts. + seq_len = self.pipeline_config.sequence_length input_ids = torch.full((1, seq_len), self.tokenizer.pad_token_id, dtype=torch.long) attention_mask = torch.zeros((1, seq_len), dtype=torch.long) position_ids = torch.zeros((1, seq_len), dtype=torch.long) @@ -458,10 +471,16 @@ def create_placeholder_rollout(self, episode_id): infer_logprobs = torch.zeros((1, seq_len - 1), dtype=torch.float) lm_input.batch["infer_logprobs"] = infer_logprobs + # Inject lora_name only when adapters are configured; uses env_config tag since rollout_cache may not exist. 
+ _placeholder_lora_name = self._resolve_lora_name( + self.pipeline_config.actor_train.model_args.adapters, tag=self.env_config['tag'] + ) + _placeholder_lora = {"lora_name": np.array([_placeholder_lora_name], dtype=object)} if _placeholder_lora_name is not None else {} lm_input.non_tensor_batch = { "env_ids": np.array([self.env_config['env_id']], dtype=object), "group_ids": np.array([self.env_config['group_id']], dtype=object), "tags": np.array([self.env_config['tag']], dtype=object), + **_placeholder_lora, "step_scores": np.array([0], dtype=object), "episode_scores": np.array([0], dtype=object), "state_hash": np.array([''], dtype=object), @@ -518,4 +537,4 @@ def filter(self, group_id: int, episode_id: int, group: list[DataProto]): return False self.global_filter_stats["filtered"] += 1 - return True \ No newline at end of file + return True diff --git a/roll/pipeline/agentic/env_manager/base_env_manager.py b/roll/pipeline/agentic/env_manager/base_env_manager.py index 643a08e30..b24f069a1 100644 --- a/roll/pipeline/agentic/env_manager/base_env_manager.py +++ b/roll/pipeline/agentic/env_manager/base_env_manager.py @@ -18,6 +18,7 @@ class RolloutCache: truncated: bool = False terminated: bool = False step: int = 0 + attempt: int = 0 class BaseEnvManager: @@ -61,4 +62,4 @@ def update_step(self, global_step): self.current_step = global_step def stop(self): - self.running = False \ No newline at end of file + self.running = False diff --git a/roll/pipeline/agentic/env_manager/step_concat_env_manager.py b/roll/pipeline/agentic/env_manager/step_concat_env_manager.py index 23438f935..66773c60a 100644 --- a/roll/pipeline/agentic/env_manager/step_concat_env_manager.py +++ b/roll/pipeline/agentic/env_manager/step_concat_env_manager.py @@ -1,3 +1,4 @@ +import numpy as np import torch from tensordict import TensorDict @@ -6,6 +7,7 @@ from roll.distributed.scheduler.protocol import DataProto from roll.pipeline.agentic.env_manager.step_env_manager import StepEnvManager from 
roll.utils.hash_utils import compute_object_hash + from roll.utils.str_utils import contains_renderable_field @@ -44,6 +46,10 @@ def format_messages(self, rollout_cache: RolloutCache) -> DataProto: "attention_mask": attention_mask, "position_ids": position_ids, }, batch_size=input_ids.shape[0]) + # Inject lora_name for inference routing; single-adapter uses sole key, multi-adapter validates normalized tag. + lora_name = self._resolve_lora_name(self.pipeline_config.actor_infer.model_args.adapters) + if lora_name is not None: + lm_input.non_tensor_batch["lora_name"] = np.array([lora_name], dtype=object) current_cache["prompt_ids"] = prompt_ids current_cache['state_hash'] = compute_object_hash(current_observation) current_cache['messages'] = messages diff --git a/roll/pipeline/agentic/env_manager/step_env_manager.py b/roll/pipeline/agentic/env_manager/step_env_manager.py index 987ff97f1..f00ddcdf7 100644 --- a/roll/pipeline/agentic/env_manager/step_env_manager.py +++ b/roll/pipeline/agentic/env_manager/step_env_manager.py @@ -10,6 +10,7 @@ from roll.pipeline.agentic.env_manager.traj_env_manager import TrajEnvManager from roll.utils.functionals import pad_to_length, aggregate_metrics from roll.utils.hash_utils import compute_object_hash + from roll.utils.str_utils import contains_renderable_field @@ -59,6 +60,10 @@ def format_messages(self, rollout_cache: RolloutCache) -> DataProto: "attention_mask": attention_mask, "position_ids": position_ids, }, batch_size=input_ids.shape[0]) + # Inject lora_name for inference routing; single-adapter uses sole key, multi-adapter validates normalized tag. 
+ lora_name = self._resolve_lora_name(self.pipeline_config.actor_infer.model_args.adapters) + if lora_name is not None: + lm_input.non_tensor_batch["lora_name"] = np.array([lora_name], dtype=object) current_cache["prompt_ids"] = prompt_ids current_cache['state_hash'] = compute_object_hash(current_observation) current_cache['messages'] = messages @@ -73,6 +78,8 @@ def formulate_rollouts(self, rollout_cache: RolloutCache): samples: List[DataProto] = [] episode_score = sum([i['reward'] for i in self.rollout_cache.history]) + # Resolve lora_name once per rollout; adapter map and tag are rollout-constant. + _lora_name = self._resolve_lora_name(self.pipeline_config.actor_train.model_args.adapters) for step, history in enumerate(rollout_cache.history): token_ids = history["prompt_ids"] + history["response_ids"] response_masks = [0] * len(history["prompt_ids"]) + [1] * len(history["response_ids"]) @@ -100,6 +107,17 @@ def formulate_rollouts(self, rollout_cache: RolloutCache): response_mask = pad_to_length(response_mask, length=self.pipeline_config.sequence_length, pad_value=0) prompt_mask = pad_to_length(prompt_mask, length=self.pipeline_config.sequence_length, pad_value=0) score_tensor = pad_to_length(score_tensor, length=self.pipeline_config.sequence_length, pad_value=0) + non_tensor_batch = { + "episode_scores": np.array([episode_score], dtype=object), + "step_scores": np.array([history["reward"]], dtype=object), # step-level reward, return by env + "tags": np.array([self.rollout_cache.tag], dtype=object), + "env_ids": np.array([self.rollout_cache.env_id], dtype=object), + "group_ids": np.array([self.rollout_cache.group_id], dtype=object), + "state_hash": np.array([history['state_hash']], dtype=object), + "step": np.array([step], dtype=object), + } + if _lora_name is not None: + non_tensor_batch["lora_name"] = np.array([_lora_name], dtype=object) lm_input = DataProto( batch=TensorDict( { @@ -111,15 +129,7 @@ def formulate_rollouts(self, rollout_cache: RolloutCache): 
"scores": score_tensor, }, batch_size=input_ids.shape[0]), - non_tensor_batch={ - "episode_scores": np.array([episode_score], dtype=object), - "step_scores": np.array([history["reward"]], dtype=object), # step-level reward, return by env - "tags": np.array([self.rollout_cache.tag], dtype=object), - "env_ids": np.array([self.rollout_cache.env_id], dtype=object), - "group_ids": np.array([self.rollout_cache.group_id], dtype=object), - "state_hash": np.array([history['state_hash']], dtype=object), - "step": np.array([step], dtype=object), - } + non_tensor_batch=non_tensor_batch ) if len(infer_logprobs): infer_logprobs = torch.tensor(infer_logprobs, dtype=torch.float).unsqueeze(0) @@ -138,4 +148,4 @@ def formulate_rollouts(self, rollout_cache: RolloutCache): env_metric = {f"env/{rollout_cache.tag}/{k}": v for k, v in env_metric.items()} env_metric["env/response_length"] = response_length batch.meta_info = {"metrics": env_metric} - return batch \ No newline at end of file + return batch diff --git a/roll/pipeline/agentic/env_manager/traj_env_manager.py b/roll/pipeline/agentic/env_manager/traj_env_manager.py index 2a1f23ee3..b88c709d8 100644 --- a/roll/pipeline/agentic/env_manager/traj_env_manager.py +++ b/roll/pipeline/agentic/env_manager/traj_env_manager.py @@ -21,9 +21,10 @@ from roll.distributed.scheduler.generate_scheduler import RequestScheduler from roll.distributed.scheduler.protocol import DataProto from roll.pipeline.agentic.agentic_config import EnvManagerConfig, AgenticConfig -from roll.utils.constants import GenerateStopReason +from roll.utils.constants import DO_TIME_SHARING, GenerateStopReason from roll.utils.functionals import pad_to_length, aggregate_metrics from roll.utils.logging import get_logger +from roll.utils.lora_routing import normalize_domain from roll.utils.str_utils import contains_renderable_field @@ -120,6 +121,10 @@ def run_rollout_loop(self, data: DataProto): with Timer(name="step", logger=None) as step_timer: if stop_reason == 
GenerateStopReason.FINISH: rollout_cache: RolloutCache = self.step(lm_output) + elif stop_reason == GenerateStopReason.ABORT: + # Retry the same turn (same step) after abort. This is used to survive + # shrink/rebalance aborts. Each retry increments attempt so request_ids remain unique. + self.rollout_cache.attempt += 1 log_stats["step_time"].append(step_timer.last) if self.running and (rollout_cache.terminated or stop_reason == GenerateStopReason.MAX_LENGTH): @@ -173,6 +178,7 @@ def step(self, llm_output: DataProto): suffix = info.pop("suffix", None) self.rollout_cache.step += 1 + self.rollout_cache.attempt = 0 self.rollout_cache.terminated = terminated self.rollout_cache.truncated = truncated if self.rollout_cache.step >= self.env_config.max_steps: @@ -232,6 +238,26 @@ def make_decision(self, rollout_cache: RolloutCache): lm_output.meta_info["stop_reason"] = GenerateStopReason.FINISH return lm_output + def _resolve_lora_name(self, adapters: dict | None, tag: str | None = None) -> str | None: + """Resolve LoRA adapter name from configured adapters and env tag. + + Returns the resolved adapter name, or None if no adapters configured. + Defaults to self.rollout_cache.tag if tag is not provided. 
+ """ + if adapters is None: + return None + if len(adapters) == 1: + return next(iter(adapters.keys())) + resolved_tag = tag if tag is not None else self.rollout_cache.tag + normalized = normalize_domain(resolved_tag) + valid_adapters = set(adapters.keys()) + if normalized not in valid_adapters: + raise RuntimeError( + f"Env tag {resolved_tag!r} normalizes to {normalized!r} " + f"which is not in configured adapters: {sorted(valid_adapters)}" + ) + return normalized + def format_messages(self, history: RolloutCache) -> DataProto: content = self.rollout_cache.history[-1] @@ -278,6 +304,10 @@ def format_messages(self, history: RolloutCache) -> DataProto: "attention_mask": attention_mask, "position_ids": position_ids, }, batch_size=input_ids.shape[0]) + # Inject lora_name for inference routing; single-adapter uses sole key, multi-adapter validates normalized tag. + lora_name = self._resolve_lora_name(self.pipeline_config.actor_infer.model_args.adapters) + if lora_name is not None: + lm_input.non_tensor_batch["lora_name"] = np.array([lora_name], dtype=object) content["prompt_ids"] = prompt_ids content["messages"] = messages return lm_input @@ -355,13 +385,18 @@ def formulate_rollouts(self, rollout_cache: RolloutCache): infer_logprobs = pad_to_length(infer_logprobs, length=self.pipeline_config.sequence_length, pad_value=0) lm_input.batch["infer_logprobs"] = infer_logprobs[:, 1:] - lm_input.non_tensor_batch.update({ + non_tensor_update = { "env_ids": np.array([self.rollout_cache.env_id], dtype=object), "group_ids": np.array([self.rollout_cache.group_id], dtype=object), "tags": np.array([self.rollout_cache.tag], dtype=object), "step_scores": np.array([scores], dtype=object), "episode_scores": np.array([episode_score], dtype=object), - }) + } + # Inject lora_name only when adapters are configured; non-LoRA paths never read it. 
+ lora_name = self._resolve_lora_name(self.pipeline_config.actor_train.model_args.adapters) + if lora_name is not None: + non_tensor_update["lora_name"] = np.array([lora_name], dtype=object) + lm_input.non_tensor_batch.update(non_tensor_update) metrics_agg_mode = self.rollout_cache.history[-1].get('metrics_agg_mode', {}) history_metrics = [item.get("metrics", {}) for item in self.rollout_cache.history] @@ -371,4 +406,4 @@ def formulate_rollouts(self, rollout_cache: RolloutCache): env_metric = {f"env/{rollout_cache.tag}/{k}": v for k, v in env_metric.items()} env_metric["env/response_length"] = response_length lm_input.meta_info = {"metrics": env_metric} - return lm_input \ No newline at end of file + return lm_input diff --git a/roll/pipeline/agentic/env_manager/vl_traj_env_manager.py b/roll/pipeline/agentic/env_manager/vl_traj_env_manager.py index f109ca752..4c9b892a9 100644 --- a/roll/pipeline/agentic/env_manager/vl_traj_env_manager.py +++ b/roll/pipeline/agentic/env_manager/vl_traj_env_manager.py @@ -22,12 +22,13 @@ token_ids_to_assistant_mask from roll.pipeline.agentic.env_manager.traj_env_manager import TrajEnvManager from roll.pipeline.agentic.llm_proxy import BaseLLMProxy, create_llm_proxy -from roll.utils.constants import EpisodeStopReason, GenerateStopReason, RAY_NAMESPACE +from roll.utils.constants import DO_TIME_SHARING, EpisodeStopReason, GenerateStopReason, RAY_NAMESPACE from roll.utils.env_action_limiter import get_global_limiter from roll.utils.functionals import pad_to_length, aggregate_metrics from roll.utils.logging import get_logger + class VLTrajEnvManager(TrajEnvManager): def __init__(self, worker_config: EnvManagerConfig, @@ -184,6 +185,7 @@ def run_rollout_loop(self, data: DataProto): self.stop_reason = EpisodeStopReason.MAX_LENGTH elif generation_stop_reason == GenerateStopReason.ABORT: self.stop_reason = EpisodeStopReason.ABORT + self.rollout_cache.attempt += 1 log_stats["current_step"].append(self.current_step) 
log_stats["generate_time"].append(generate_timer.last) @@ -223,6 +225,7 @@ def step(self, llm_output: DataProto): suffix = info.pop("suffix", None) self.rollout_cache.step += 1 + self.rollout_cache.attempt = 0 self.rollout_cache.terminated = terminated self.rollout_cache.truncated = truncated if self.rollout_cache.step >= self.env_config.max_steps: @@ -403,6 +406,10 @@ def replace_placeholder(text): f"extra_suffix_length={history.history[-1]['extra_suffix_length']}" ) + # Inject lora_name for inference routing; single-adapter uses sole key, multi-adapter validates normalized tag. + lora_name = self._resolve_lora_name(self.pipeline_config.actor_infer.model_args.adapters) + if lora_name is not None: + lm_input.non_tensor_batch["lora_name"] = np.array([lora_name], dtype=object) return lm_input, messages def formulate_rollouts(self, rollout_cache: RolloutCache): @@ -474,14 +481,19 @@ def formulate_rollouts(self, rollout_cache: RolloutCache): "prompt_mask": prompt_mask, "scores": score_tensor, }) - lm_input.non_tensor_batch.update({ + non_tensor_update = { "env_ids": np.array([self.rollout_cache.env_id], dtype=object), "group_ids": np.array([self.rollout_cache.group_id], dtype=object), "messages_list": np.array([messages], dtype=object), "tags": np.array([self.rollout_cache.tag], dtype=object), "step_scores": np.array([scores], dtype=object), "episode_scores": np.array([episode_score], dtype=object), - }) + } + # Inject lora_name only when adapters are configured; non-LoRA paths never read it. 
+ lora_name = self._resolve_lora_name(self.pipeline_config.actor_train.model_args.adapters) + if lora_name is not None: + non_tensor_update["lora_name"] = np.array([lora_name], dtype=object) + lm_input.non_tensor_batch.update(non_tensor_update) metrics_agg_mode = self.rollout_cache.history[-1].get('metrics_agg_mode', {}) history_metrics = [item.get("metrics", {}) for item in self.rollout_cache.history] diff --git a/roll/pipeline/agentic/environment_worker.py b/roll/pipeline/agentic/environment_worker.py index bd4ede7b6..c0a969b2d 100644 --- a/roll/pipeline/agentic/environment_worker.py +++ b/roll/pipeline/agentic/environment_worker.py @@ -98,12 +98,12 @@ async def run_rollout_loop(self, seed): os.environ["WORKER_NAME"] = f"EnvironmentWorker_{self.rank}" loop = asyncio.get_event_loop() - pool = ThreadPoolExecutor(max_workers=len(self.env_managers)) + pool = ThreadPoolExecutor(max_workers= max(len(self.env_managers), 1)) def run_with_profiler(env_manager, data_proto): with local_profiler(): return env_manager.run_rollout_loop(data_proto) - + def run_without_profiler(env_manager, data_proto): return env_manager.run_rollout_loop(data_proto) diff --git a/roll/pipeline/base_pipeline.py b/roll/pipeline/base_pipeline.py index 5c4d67e78..2492b21de 100644 --- a/roll/pipeline/base_pipeline.py +++ b/roll/pipeline/base_pipeline.py @@ -15,6 +15,7 @@ from roll.distributed.scheduler.protocol import DataProto from roll.distributed.scheduler.resource_manager import ResourceManager from roll.utils.checkpoint_manager import CheckpointManager, download_model +from roll.utils.constants import DO_TIME_SHARING from roll.utils.functionals import reduce_metrics from roll.utils.logging import get_logger from roll.utils.tracking import create_tracker @@ -30,9 +31,15 @@ class BasePipeline: def __init__(self, pipeline_config): set_seed(seed=pipeline_config.seed) self.pipeline_config = pipeline_config - self.resource_manager = ResourceManager( - num_nodes=self.pipeline_config.num_nodes, 
num_gpus_per_node=self.pipeline_config.num_gpus_per_node - ) + if DO_TIME_SHARING: + from roll.distributed.scheduler.resource_manager import RollResourceManagerProxy + self.resource_manager = RollResourceManagerProxy( + num_gpus_per_node=self.pipeline_config.num_gpus_per_node + ) + else: + self.resource_manager = ResourceManager( + num_nodes=self.pipeline_config.num_nodes, num_gpus_per_node=self.pipeline_config.num_gpus_per_node + ) self.state = WorkerState() self.checkpoint_manager = CheckpointManager(checkpoint_config=self.pipeline_config.checkpoint_config) self.tracker = create_tracker( @@ -51,11 +58,22 @@ def __init__(self, pipeline_config): load_dir = os.path.join(self.resume_from_checkpoint, "pipeline") self.state = WorkerState.load_from_json(load_dir=load_dir, tag="pipeline") - def resume_metrics(): - for metrics in self.state.log_history: - self.tracker.log(values=metrics, step=metrics["system/step"]) - - self.resume_futures.append(self.executor.submit(resume_metrics)) + # Skip log_history replay for multi-LoRA checkpoints: per-LoRA tracker logs + # are not reconstructable from minimal log_history entries. + saved_tag_to_adapter = self.state.kv.get("tag_to_adapter") + is_multi_lora_resume = ( + isinstance(saved_tag_to_adapter, dict) and len(set(saved_tag_to_adapter.values())) > 1 + ) + if is_multi_lora_resume: + logger.warning( + "Resuming from a multi-LoRA checkpoint. Skipping log_history replay: " + "per-LoRA tracker logs are not reconstructable from minimal log_history entries." 
+ ) + else: + def resume_metrics(): + for metrics in self.state.log_history: + self.tracker.log(values=metrics, step=metrics["system/step"]) + self.resume_futures.append(self.executor.submit(resume_metrics)) def run(self): pass @@ -75,19 +93,34 @@ def model_update(self, global_step): model_update_group.tgt_cluster.process_weights_after_loading() return metrics - def do_checkpoint(self, global_step, is_last_step=None): + def model_update_lora_subset(self, global_step: int, *, adapters_to_update: set[str] | None = None) -> dict: + """Adapter-subset model update helper for multi-LoRA pipelines.""" + metrics: dict = {} + for model_update_group in self.model_update_groups: + metrics.update(model_update_group.model_update(step=global_step, adapters_to_update=adapters_to_update)) + model_update_group.tgt_cluster.process_weights_after_loading() + return metrics + + def do_checkpoint(self, global_step, is_last_step=None, offload_after_checkpoint: bool = False): if is_last_step is None: is_last_step = global_step == self.pipeline_config.max_steps - 1 metrics = self.state.log_history[-1] metrics["system/step"] = global_step if global_step > 0 and ( - global_step % self.pipeline_config.save_steps == 0 or global_step == self.pipeline_config.max_steps - 1 + global_step % self.pipeline_config.save_steps == 0 + or global_step == self.pipeline_config.max_steps - 1 + or is_last_step ): ckpt_metrics_refss = [] for cluster in self.checkpoint_clusters: ckpt_metrics_refss.append( - cluster.do_checkpoint(global_step=global_step, is_last_step=is_last_step, blocking=False) + cluster.do_checkpoint( + global_step=global_step, + is_last_step=is_last_step, + offload_after_checkpoint=offload_after_checkpoint, + blocking=False, + ) ) for ckpt_metrics_refs in ckpt_metrics_refss: @@ -154,6 +187,80 @@ def _cleanup_old_checkpoints(self): except Exception as e: logger.warning(f"Failed to delete checkpoint {ckpt_dir}: {e}") + # -- Partial-GPU helpers: translate GPU IDs to DP ranks for shrink/expand -- 
+ # Subclasses must set _infer_gpus_per_dp_rank and _infer_device_mapping during __init__. + + _infer_gpus_per_dp_rank: int = 0 + _infer_device_mapping: List[int] = [] + + def _target_gpus_to_dp_ranks_to_remove(self, *, target_gpus: List[int]) -> List[int]: + """Translate target GPU IDs to DP ranks for shrink (intersection semantics). + + A DP rank is included if ANY of its GPUs overlap with target_gpus. + This is used for shrink operations where we want to offload any rank + that touches the training GPU set. + """ + if not isinstance(target_gpus, list) or not target_gpus: + raise ValueError("target_gpus must be a non-empty list[int]") + gpus_per_dp_rank = int(self._infer_gpus_per_dp_rank) + device_mapping = list(self._infer_device_mapping) + if len(device_mapping) % gpus_per_dp_rank != 0: + raise RuntimeError("device_mapping length must be divisible by gpus_per_dp_rank") + target = set(int(gpu_id) for gpu_id in target_gpus) + min_gpu = min(target) + max_gpu = max(target) + if min_gpu % gpus_per_dp_rank != 0 or (max_gpu + 1) % gpus_per_dp_rank != 0: + logger.warning( + f"Target GPU range [{min_gpu}, {max_gpu}] not aligned with DP granularity " + f"({gpus_per_dp_rank}). DP rank boundary violation detected " + f"for target GPUs {sorted(target)}. " + f"Rollout DP ranks may not cleanly map to training GPUs." + ) + max_dp = len(device_mapping) // gpus_per_dp_rank + out: List[int] = [] + for dp_rank in range(max_dp): + start = dp_rank * gpus_per_dp_rank + dp_gpus = set(int(gpu_id) for gpu_id in device_mapping[start : start + gpus_per_dp_rank]) + if dp_gpus.intersection(target): + out.append(dp_rank) + if not out: + raise RuntimeError("No dp ranks matched target_gpus for shrink") + return out + + def _target_gpus_to_dp_ranks_to_add(self, *, target_gpus: List[int]) -> List[int]: + """Translate target GPU IDs to DP ranks for expand (subset semantics). + + A DP rank is included only if ALL its GPUs are in target_gpus. 
+ This is used for expand operations where we only want to activate ranks + whose full GPU slice is available. + """ + if not isinstance(target_gpus, list) or not target_gpus: + raise ValueError("target_gpus must be a non-empty list[int]") + gpus_per_dp_rank = int(self._infer_gpus_per_dp_rank) + device_mapping = list(self._infer_device_mapping) + if len(device_mapping) % gpus_per_dp_rank != 0: + raise RuntimeError("device_mapping length must be divisible by gpus_per_dp_rank") + target = set(int(gpu_id) for gpu_id in target_gpus) + min_gpu = min(target) + max_gpu = max(target) + if min_gpu % gpus_per_dp_rank != 0 or (max_gpu + 1) % gpus_per_dp_rank != 0: + logger.warning( + f"Target GPU range [{min_gpu}, {max_gpu}] not aligned with DP granularity " + f"({gpus_per_dp_rank}). DP rank boundary violation detected " + f"for target GPUs {sorted(target)}. " + f"Rollout DP ranks may not cleanly map to training GPUs." + ) + max_dp = len(device_mapping) // gpus_per_dp_rank + out: List[int] = [] + for dp_rank in range(max_dp): + start = dp_rank * gpus_per_dp_rank + dp_gpus = set(int(gpu_id) for gpu_id in device_mapping[start : start + gpus_per_dp_rank]) + if dp_gpus and dp_gpus.issubset(target): + out.append(dp_rank) + if not out: + raise RuntimeError("No dp ranks matched target_gpus for expand") + return out + def download_models(self, *clusters: Cluster): node2pg: Dict[str, PlacementGroup] = {} node2model_names: Dict[str, set[str]] = defaultdict(set) diff --git a/roll/pipeline/base_worker.py b/roll/pipeline/base_worker.py index c0ae33d06..95f6d920c 100644 --- a/roll/pipeline/base_worker.py +++ b/roll/pipeline/base_worker.py @@ -23,9 +23,11 @@ ) from roll.platforms import current_platform from roll.utils.checkpoint_manager import download_model +from roll.utils.constants import DO_TIME_SHARING from roll.utils.context_managers import state_offload_manger, log_gpu_memory_usage from roll.utils.dynamic_batching import make_mini_batch_iter_for_dynamic_batching from 
roll.utils.functionals import agg_loss, append_to_dict, compute_approx_kl, masked_mean, postprocess_generate, reduce_metrics +from roll.utils.lora_routing import ensure_lora_name_in_batch from roll.utils.offload_nccl import reload_process_groups from roll.utils.offload_states import OffloadStateType @@ -73,6 +75,9 @@ def train_step(self, data: DataProto): is_offload_states=is_offload_states, load_kwargs={"include": [OffloadStateType.model_params, OffloadStateType.other_params]}, ): + # TODO: to(device) before get_data_input() is legacy order — broadcast_object_list + # serializes GPU tensors back to CPU. Prefer get_data_input() first (as in train_step_lora) + # to broadcast while data is still on CPU, then move to GPU once. data = data.to(current_platform.device_type) data = self.strategy.get_data_input(data) per_device_train_batch_size = self.worker_config.training_args.per_device_train_batch_size @@ -94,22 +99,126 @@ def train_step(self, data: DataProto): dataloader_kwargs={"shuffle": True}, ) - for batch_idx, backward_batch in tqdm(enumerate(dataloader), - desc=f"{self.worker_name} train global step {global_step}", - total=data.batch.batch_size[0] * self.pipeline_config.ppo_epochs // backward_batch_size): + # Count actual iterations instead of static formula — dynamic batching can change the count. 
+ actual_backward_steps = 0 + for backward_batch in tqdm(dataloader, + desc=f"{self.worker_name} train global step {global_step}"): pg_metrics = self.strategy.train_step(batch=backward_batch, loss_func=self.loss_func) if self.worker_config.use_dynamic_batching_in_train or self.worker_config.use_sequence_packing: pg_metrics = reduce_metrics(pg_metrics) append_to_dict(metrics, pg_metrics) + actual_backward_steps += 1 metrics["actor/lr"] = self.strategy.scheduler.get_last_lr()[0] - metrics["actor/backward_steps"] = data.batch.batch_size[0] * self.pipeline_config.ppo_epochs // backward_batch_size + metrics["actor/backward_steps"] = actual_backward_steps data.to("cpu") self._logprobs_cache.clear() output = DataProto(meta_info={"metrics": metrics}) return output + @register(dispatch_mode=Dispatch.DP_MP_DISPATCH_FIRST) + def train_step_lora(self, data: DataProto): + """Single-adapter-per-call LoRA training step. + + Routes microbatches via ``non_tensor_batch["lora_name"]`` to + ``MegatronTrainStrategy.train_step_lora`` with per-adapter optimizer. + """ + global_step = data.meta_info.get("global_step", 0) + is_offload_states = data.meta_info.get("is_offload_states", True) + metrics = {} + self.logger.info(f"{self.worker_name} lora train global step {global_step}") + + # Fail fast before loading GPU states — caller must broadcast non_tensor_batch for LoRA routing. + if not (data.meta_info or {}).get("_broadcast_non_tensor_batch"): + raise RuntimeError( + "train_step_lora requires caller to set meta_info['_broadcast_non_tensor_batch'] = True" + ) + + # Keep train_step_lora state lifecycle consistent with train_step: + # reload model params before forward, then offload afterwards. 
+ with state_offload_manger( + strategy=self.strategy, + metrics=metrics, + metric_infix=f"{self.cluster_name}/train_step_lora", + is_offload_states=is_offload_states, + load_kwargs={"include": [OffloadStateType.model_params, OffloadStateType.other_params]}, + ): + # Keep a stable denominator for train-step summary metrics, aligned with train_step. + per_device_train_batch_size = self.worker_config.training_args.per_device_train_batch_size + backward_batch_size = ( + per_device_train_batch_size * self.worker_config.training_args.gradient_accumulation_steps + ) + # DP_MP_DISPATCH_FIRST sends batch=None to non-source TP/CP ranks; only normalize routing + # keys on the source shard before strategy broadcast reconstructs full DataProto everywhere. + if data.batch is not None: + batch_size = data.batch.batch_size[0] + ensure_lora_name_in_batch( + data.non_tensor_batch, + adapters=self.worker_config.model_args.adapters, + batch_size=batch_size, + ) + # Broadcast non_tensor_batch then move tensors to GPU. + data = self.strategy.get_data_input(data) + data = data.to(current_platform.device_type) + + # PPO epoch loop — mirrors train_step() so LoRA training runs the same number of + # optimizer steps as full-model training when ppo_epochs > 1. + if self.worker_config.use_dynamic_batching_in_train: + dataloader = make_mini_batch_iter_for_dynamic_batching( + data=data, + epochs=self.pipeline_config.ppo_epochs, + ga_steps=self.worker_config.training_args.gradient_accumulation_steps, + ) + else: + dataloader = data.make_iterator( + mini_batch_size=backward_batch_size, + epochs=self.pipeline_config.ppo_epochs, + seed=self.pipeline_config.seed, + dataloader_kwargs={"shuffle": True}, + ) + + # Count actual iterations instead of static formula — dynamic batching can change the count. 
+ actual_backward_steps = 0 + for backward_batch in dataloader: + lora_metrics = self.strategy.train_step_lora(backward_batch, loss_func=self.loss_func) + if self.worker_config.use_dynamic_batching_in_train or self.worker_config.use_sequence_packing: + lora_metrics = reduce_metrics(lora_metrics) + # Use append_to_dict to match train_step accumulation pattern (consistent with reducers). + append_to_dict(metrics, lora_metrics) + actual_backward_steps += 1 + + # Mirror train_step summary metrics so dashboards remain comparable in multi-LoRA mode. + # For per-adapter optimizer mode, avoid using the top-level scheduler LR because it can + # diverge from actual adapter schedulers; prefer active-adapter LR(s). + if "actor/lr" not in metrics: + if getattr(self.strategy, "adapter_schedulers", None): + active_adapters = [] + lora_arr = data.non_tensor_batch.get("lora_name", None) if data.non_tensor_batch else None + if lora_arr is not None: + active_adapters = list(dict.fromkeys(str(name) for name in lora_arr.tolist())) + lr_values = [] + for adapter_name in active_adapters: + sch = self.strategy.adapter_schedulers.get(adapter_name, None) + if sch is None: + continue + lr = sch.get_last_lr()[0] + metrics[f"{self.worker_config.name}/{adapter_name}/lr"] = lr + lr_values.append(float(lr)) + if lr_values: + metrics["actor/lr"] = sum(lr_values) / len(lr_values) + elif hasattr(self.strategy, "scheduler") and self.strategy.scheduler is not None: + metrics["actor/lr"] = self.strategy.scheduler.get_last_lr()[0] + # Use actual loop count instead of static formula to handle dynamic batching correctly. + metrics["actor/backward_steps"] = actual_backward_steps + data.to("cpu") + + # Keep cache lifecycle consistent with train_step to avoid stale logprob cache accumulation. + self._logprobs_cache.clear() + # Match train_step return style: no .to("cpu") since meta_info holds scalar metrics only. 
+ output = DataProto(meta_info={"metrics": metrics}) + return output + @register(dispatch_mode=Dispatch.DP_MP_DISPATCH_FIRST) def compute_log_probs(self, data: DataProto): """ @@ -307,7 +416,7 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): return total_loss, pg_metrics @register(dispatch_mode=Dispatch.ONE_TO_ALL) - def do_checkpoint(self, global_step, is_last_step=None): + def do_checkpoint(self, global_step, is_last_step=None, offload_after_checkpoint: bool = False): if self.worker_config.offload_nccl: reload_process_groups() with Timer("do_checkpoint") as total_timer: @@ -321,6 +430,10 @@ def do_checkpoint(self, global_step, is_last_step=None): exec_metrics: Dict = self.strategy.save_checkpoint( save_dir, global_step, ckpt_id, is_last_step=is_last_step ) + # Offload all states (model + optimizer) reloaded during save_checkpoint so GPU + # memory is fully released before the caller returns the GPU to the scheduler. + if offload_after_checkpoint: + self.strategy.offload_states() metrics = { f"time/{self.cluster_name}/do_checkpoint/total": total_timer.last, @@ -367,6 +480,16 @@ async def load_states(self, *args, **kwargs): async def offload_states(self, *args, **kwargs): await self.strategy.offload_states(*args, **kwargs) + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + async def process_weights_after_loading(self, *args, **kwargs): + # Keep this path fully async for InferWorker. The base Worker implementation + # uses _maybe_await with a helper thread when an event loop is already running, + # which can deadlock vLLM collective RPC coroutines waiting on the actor loop. + # Running and awaiting here ensures the coroutine stays on the actor event loop. 
+ strategy_result = self.strategy.process_weights_after_loading(*args, **kwargs) + if inspect.isawaitable(strategy_result): + await strategy_result + @register(dispatch_mode=Dispatch.ONE_TO_ALL) async def load_states_partial(self, target_dp_ranks: List[int]): """Load states for workers whose dp_rank is in target_dp_ranks.""" @@ -381,16 +504,13 @@ async def load_states_partial(self, target_dp_ranks: List[int]): assert getattr(self, "strategy", None) is not None, "worker has no strategy to load" if self.rank_info.dp_rank in target_dp_ranks: - # AST: AST_PRECONDITION(is_model_in_gpu is False) - verify strategy offloaded before load is_loaded = self._get_strategy_load_state() - - assert is_loaded is False, ( - f"Pre-condition: strategy must be offloaded before load_states_partial, " - f"got Worker {self.rank} (DP {self.rank_info.dp_rank}) is_model_in_gpu={is_loaded}" - ) - - await self.strategy.load_states() - self.logger.info(f"Worker {self.rank} (DP {self.rank_info.dp_rank}) loaded states") + if not is_loaded: + # NOT Already loaded — vllm_strategy.add_lora() set is_model_in_gpu=True because + # custom_add_lora calls load_states() on the worker before this point. + # Nothing to do; skip the no-op collective RPC. + await self.strategy.load_states() + self.logger.info(f"Worker {self.rank} (DP {self.rank_info.dp_rank}) loaded states") else: self.logger.debug(f"Worker {self.rank} (DP {self.rank_info.dp_rank}) skipped load") @@ -462,9 +582,20 @@ async def start_model_update(self, *args, **kwargs): async def update_parameter_in_bucket(self, *args, **kwargs): await self.strategy.update_parameter_in_bucket(*args, **kwargs) + async def destroy_collective_group(self, group_name: str, model_update_name: str | None = None) -> None: + # Must be async to match InferWorker's async actor dispatch pattern. 
+ # Without this override, Worker.destroy_collective_group (sync) runs in a threadpool + # where _maybe_await tries to drive collective_rpc_async on a fresh event loop, + # deadlocking the engine's ZMQ transport which is bound to the actor's main loop. + await self.strategy.destroy_collective_group(group_name, model_update_name) + async def add_lora(self, *args, **kwargs): await self.strategy.add_lora(*args, **kwargs) + async def verify_model(self, *args, **kwargs): + """Async override — InferWorker runs on an async event loop.""" + await self.strategy.verify_model(*args, **kwargs) + @register(dispatch_mode=Dispatch.DP_MP_COMPUTE) async def generate(self, data: DataProto): """ diff --git a/roll/pipeline/distill/distill_pipeline.py b/roll/pipeline/distill/distill_pipeline.py index 0b4dc6d8a..7d59408d0 100644 --- a/roll/pipeline/distill/distill_pipeline.py +++ b/roll/pipeline/distill/distill_pipeline.py @@ -285,7 +285,8 @@ def run(self): batch: DataProto = DataProto.from_single_dict(batch_dict) batch.meta_info = {"global_step": global_step, "is_offload_states": False, "is_offload_optimizer_states_in_train_step": False, - 'loss_mask_keys': ['labels_for_loss']} + 'loss_mask_keys': ['labels_for_loss'], + "_broadcast_non_tensor_batch": True} # Reorder data for DP rank load balancing batch_balance_metrics = batch_balance(batch, dp_size=self.student.dp_size, minibatch_size=self.batch_size) metrics_mgr.add_metrics(batch_balance_metrics) @@ -338,7 +339,8 @@ def val(self): val_loss_list = [] for batch_dict in self.val_dataloader: batch: DataProto = DataProto.from_single_dict(batch_dict) - batch.meta_info = {"is_offload_optimizer_states_in_train_step": False} + batch.meta_info = {"is_offload_optimizer_states_in_train_step": False, + "_broadcast_non_tensor_batch": True} val_metrics_refs = self.student.val_step(batch, blocking=False) val_metrics = DataProto.materialize_concat(data_refs=val_metrics_refs) val_metrics = val_metrics.meta_info.pop("metrics", {}) diff --git 
a/roll/pipeline/distill/distill_vlm_pipeline.py b/roll/pipeline/distill/distill_vlm_pipeline.py index 40672161a..927e84ff5 100644 --- a/roll/pipeline/distill/distill_vlm_pipeline.py +++ b/roll/pipeline/distill/distill_vlm_pipeline.py @@ -262,8 +262,10 @@ def run(self): metrics_mgr.clear_metrics() batch: DataProto = DataProto.from_single_dict(batch_dict) + # VLM batches carry multimodal data in non_tensor_batch; broadcast to all PP/TP/CP ranks. batch.meta_info = {"global_step": global_step, "is_offload_states": False, - "is_offload_optimizer_states_in_train_step": False, "loss_mask_keys": ["labels_for_loss"]} + "is_offload_optimizer_states_in_train_step": False, "loss_mask_keys": ["labels_for_loss"], + "_broadcast_non_tensor_batch": True} batch_offset = self.logits_transfer_group.apply_offset_by_dp(batch) with Timer(name="step_train", logger=None) as step_train_timer: with Timer(name="teacher_forward", logger=None) as teacher_timer: diff --git a/roll/pipeline/rlvr/rlvr_pipeline.py b/roll/pipeline/rlvr/rlvr_pipeline.py index 73d1a32ce..0d5157a20 100644 --- a/roll/pipeline/rlvr/rlvr_pipeline.py +++ b/roll/pipeline/rlvr/rlvr_pipeline.py @@ -41,7 +41,7 @@ from roll.utils.kl_controller import get_kl_controller from roll.utils.logging import get_logger from roll.utils.metrics.metrics_manager import MetricsManager -from roll.utils.offload_states import OffloadStateType + logger = get_logger() @@ -435,7 +435,8 @@ def run(self): with Timer(name="step_stop_server", logger=None) as step_stop_server_timer: if self.pipeline_config.async_pipeline: ray.get([scheduler.pause_sampling.remote() for scheduler in self.generate_schedulers.values()]) - self.actor_infer.offload_states(include=OffloadStateType.other_params) + # Full offload: stop generation server, discard KV cache + LoRA (will restart after model update). 
+ self.actor_infer.offload_states() metrics_mgr.add_metric("time/step_stop_server", step_stop_server_timer.last) with Timer(name="step_model_update", logger=None) as step_model_update_timer: diff --git a/roll/pipeline/rlvr/rlvr_vlm_pipeline.py b/roll/pipeline/rlvr/rlvr_vlm_pipeline.py index 19a462801..5abc7d213 100644 --- a/roll/pipeline/rlvr/rlvr_vlm_pipeline.py +++ b/roll/pipeline/rlvr/rlvr_vlm_pipeline.py @@ -44,7 +44,7 @@ from roll.utils.logging import get_logger from roll.utils.metrics.metrics_manager import MetricsManager from roll.utils.packages import is_transformers_version_greater_than -from roll.utils.offload_states import OffloadStateType + logger = get_logger() @@ -473,7 +473,8 @@ def run(self): with Timer(name="step_stop_server", logger=None) as step_stop_server_timer: if self.pipeline_config.async_pipeline: ray.get([scheduler.pause_sampling.remote() for scheduler in self.generate_schedulers.values()]) - self.actor_infer.offload_states(include=OffloadStateType.other_params) + # Full offload: stop generation server, discard KV cache + LoRA (will restart after model update). 
+ self.actor_infer.offload_states() metrics_mgr.add_metric("time/step_stop_server", step_stop_server_timer.last) with Timer(name="step_model_update", logger=None) as step_model_update_timer: diff --git a/roll/pipeline/sft/sft_pipeline.py b/roll/pipeline/sft/sft_pipeline.py index 1d30bf5f6..6dbf3304b 100644 --- a/roll/pipeline/sft/sft_pipeline.py +++ b/roll/pipeline/sft/sft_pipeline.py @@ -208,7 +208,8 @@ def run(self): with Timer(name="step_train", logger=None) as step_train_timer: batch: DataProto = DataProto.from_single_dict(batch_dict) batch.meta_info = {"global_step": global_step, "is_offload_optimizer_states_in_train_step": False, - "loss_mask_keys": ["labels"]} + "loss_mask_keys": ["labels"], + "_broadcast_non_tensor_batch": True} # Reorder data for DP rank load balancing batch_balance_metrics = batch_balance(batch, dp_size=self.sft_train.dp_size, minibatch_size=self.global_train_batch_size) diff --git a/roll/pipeline/sft/sft_worker.py b/roll/pipeline/sft/sft_worker.py index d76866b96..712514977 100644 --- a/roll/pipeline/sft/sft_worker.py +++ b/roll/pipeline/sft/sft_worker.py @@ -11,6 +11,7 @@ from roll.distributed.strategy.factory import create_strategy from roll.distributed.strategy.strategy import InferenceStrategy, TrainStrategy from roll.utils.functionals import reduce_metrics +from roll.utils.lora_routing import ensure_lora_name_in_batch from roll.models.model_providers import default_actor_model_provider from roll.platforms import current_platform @@ -30,8 +31,11 @@ def initialize(self, pipeline_config): @register(Dispatch.DP_MP_DISPATCH_FIRST, clear_cache=False) def train_step(self, data: DataProto): - data = data.to(current_platform.device_type) + # Caller must provide meta_info; guard against None for get_data_input. 
+ if data.meta_info is None: + data.meta_info = {} data = self.strategy.get_data_input(data) + data = data.to(current_platform.device_type) metrics = self.strategy.train_step(batch=data, loss_func=self.loss_func) @@ -39,10 +43,36 @@ def train_step(self, data: DataProto): return output @register(Dispatch.DP_MP_DISPATCH_FIRST, clear_cache=False) - def val_step(self, data: DataProto): + def train_step_lora(self, data: DataProto): + """Single-adapter-per-call LoRA training step. + + Routes to ``MegatronTrainStrategy.train_step_lora`` which dispatches + the per-adapter optimizer.step() for the adapter identified by + ``non_tensor_batch["lora_name"]``. + """ + # Caller must set _broadcast_non_tensor_batch=True so LoRA routing keys reach all ranks. + if not (data.meta_info or {}).get("_broadcast_non_tensor_batch"): + raise RuntimeError( + "train_step_lora requires caller to set meta_info['_broadcast_non_tensor_batch'] = True" + ) + data = self.strategy.get_data_input(data) + # Validate/fill lora_name after broadcast — all ranks now have non_tensor_batch. 
+ batch_size = data.batch.batch_size[0] if data.batch is not None else None + ensure_lora_name_in_batch( + data.non_tensor_batch, + adapters=self.worker_config.model_args.adapters, + batch_size=batch_size, + ) data = data.to(current_platform.device_type) + metrics = self.strategy.train_step_lora(data, loss_func=self.loss_func) + output = DataProto(meta_info={"metrics": metrics}).to("cpu") + return output + + @register(Dispatch.DP_MP_DISPATCH_FIRST, clear_cache=False) + def val_step(self, data: DataProto): data.meta_info["micro_batch_size"] = self.worker_config.infer_batch_size data = self.strategy.get_data_input(data) + data = data.to(current_platform.device_type) metrics = self.strategy.forward_step(batch=data, forward_func=self.loss_func) if metrics is None: metrics = {} @@ -66,6 +96,37 @@ def do_checkpoint(self, global_step, is_last_step=False): output = DataProto(meta_info={"metrics": metrics}) return output + # ------------------------------------------------------------------ + # Per-adapter LoRA weight management (Phase-1 multi-LoRA port) + # ------------------------------------------------------------------ + + @register(Dispatch.ONE_TO_ALL) + def get_lora_tensors(self, adapter_name: str) -> Dict[str, torch.Tensor]: + """Return a CPU copy of all LoRA parameter tensors for *adapter_name*. + + Dispatched to all workers (ONE_TO_ALL); callers typically use ``result[0]`` + (rank-0 copy) since all DP/TP ranks hold identical LoRA weights. + + Note: used only by integration tests. Not called in any production pipeline. + """ + return self.strategy.get_lora_tensors(adapter_name) + + @register(Dispatch.ONE_TO_ALL) + def set_lora_tensors(self, adapter_name: str, tensors: Dict[str, torch.Tensor]) -> int: + """Overwrite LoRA parameters for *adapter_name* in-place on all workers. + + Note: used only by integration tests. Not called in any production pipeline. 
+ """ + return self.strategy.set_lora_tensors(adapter_name=adapter_name, tensors=tensors) + + @register(Dispatch.ONE_TO_ALL) + def copy_lora_params(self, src_adapter: str, dst_adapter: str) -> int: + """Copy LoRA parameters from *src_adapter* to *dst_adapter* on all workers. + + Note: used only by integration tests. Not called in any production pipeline. + """ + return self.strategy.copy_lora_params(src_adapter=src_adapter, dst_adapter=dst_adapter) + def loss_func(self, data: DataProto, output_tensor: torch.Tensor): labels = data.batch["labels"] batch_num_tokens = data.meta_info['batch_num_tokens']['labels'] diff --git a/roll/platforms/cpu.py b/roll/platforms/cpu.py index 3149938d3..e5a7083d0 100644 --- a/roll/platforms/cpu.py +++ b/roll/platforms/cpu.py @@ -1,4 +1,7 @@ +import os + from .platform import Platform +from ..utils.constants import DO_TIME_SHARING from ..utils.logging import get_logger @@ -6,12 +9,29 @@ class CpuPlatform(Platform): + """Platform for nodes without GPU/NPU accelerators (e.g., scheduler/coordinator nodes). + + In RLix mode (time-sharing), CPU actors spawn GPU workers on other nodes and need to + configure GPU visibility for those child workers. The device_control_env_var and + ray_experimental_noset attributes are only applied in this mode via update_env_vars_for_visible_devices. + + In standalone mode, CPU actors don't spawn GPU workers, so these attributes are unused. + """ device_name: str = "CPU" device_type: str = "cpu" dispatch_key: str = "CPU" ray_device_key: str = "CPU" communication_backend: str = "gloo" + # GPU visibility attributes: only needed in RLix mode where CPU actors spawn GPU workers. + # In standalone mode, these are None and update_env_vars_for_visible_devices() early-exits. 
+ if DO_TIME_SHARING: + device_control_env_var: str = "CUDA_VISIBLE_DEVICES" + ray_experimental_noset: str = "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES" + else: + device_control_env_var: str = None + ray_experimental_noset: str = None + @classmethod def clear_cublas_workspaces(cls) -> None: return diff --git a/roll/platforms/platform.py b/roll/platforms/platform.py index d9050dd31..ad1fe123c 100644 --- a/roll/platforms/platform.py +++ b/roll/platforms/platform.py @@ -48,18 +48,18 @@ class Platform: # Examples: "GPU", "NPU" ray_device_key: str - # platform-agnostic way to specify the device control environment variable, - # .e.g. CUDA_VISIBLE_DEVICES for CUDA. + # Platform-specific device visibility environment variable. # hint: search for "get_visible_accelerator_ids_env_var" in # https://github.com/ray-project/ray/tree/master/python/ray/_private/accelerators # noqa # Examples: "CUDA_VISIBLE_DEVICES", "ASCEND_RT_VISIBLE_DEVICES" - device_control_env_var: str + # Set to None for platforms that don't control GPU visibility (e.g., CpuPlatform in standalone mode). + device_control_env_var: str | None - # Optional Ray experimental config - # Some accelerators require specific flags in Ray start parameters; - # leave blank if not needed + # Ray experimental flag to prevent auto-setting device visibility. + # Required when spawning workers with specific GPU assignments via placement groups. # Example: "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES" - ray_experimental_noset: str + # Set to None for platforms that don't control GPU visibility. + ray_experimental_noset: str | None # Communication backend for distributed training # Examples: "nccl", "hccl" @@ -132,7 +132,11 @@ def update_env_vars_for_visible_devices(cls, env_vars: dict, gpu_ranks: list) -> Behavior: - Sets the platform-specific visibility environment variable. - Sets the corresponding Ray experimental flag if needed. + - Skips if platform doesn't support device visibility (None attributes). 
""" + # Skip if platform doesn't support device visibility control (e.g., CpuPlatform in standalone mode). + if cls.device_control_env_var is None or cls.ray_experimental_noset is None: + return visible_devices_env_vars = { cls.device_control_env_var: ",".join(map(str, gpu_ranks)), cls.ray_experimental_noset: "1", diff --git a/roll/third_party/fsdp2/model_update.py b/roll/third_party/fsdp2/model_update.py index f575ef82d..3b665ff2b 100644 --- a/roll/third_party/fsdp2/model_update.py +++ b/roll/third_party/fsdp2/model_update.py @@ -47,6 +47,15 @@ class FSDP2WeightUpdater: def __init__( self, pipeline_config: PPOConfig, infer_cluster, worker_config, model_update_name: str, model, is_lora ): + # gather_fsdp2_weights and _add_lora_to_infer_workers both export only the default adapter; + # fail fast before any weight gather or broadcast in multi-LoRA configs. + if is_lora: + adapters = getattr(getattr(worker_config, "model_args", None), "adapters", None) + if adapters is not None and len(adapters) > 1: + raise RuntimeError( + "FSDP2 model_update does not support multi-LoRA. 
" + f"Configured adapters: {sorted(adapters.keys())}" + ) self.pipeline_config = pipeline_config self.worker_config = worker_config self.model_update_name = model_update_name diff --git a/roll/third_party/megatron/model_update.py b/roll/third_party/megatron/model_update.py index 970ae888d..558d86619 100644 --- a/roll/third_party/megatron/model_update.py +++ b/roll/third_party/megatron/model_update.py @@ -1,6 +1,6 @@ import time from dataclasses import asdict -from typing import Optional +from typing import Generator, Optional import ray import torch @@ -19,7 +19,7 @@ from roll.utils.constants import RAY_NAMESPACE from roll.utils.logging import get_logger from roll.utils.network_utils import collect_free_port, get_node_ip -from roll.utils.send_recv_utils import serialize_named_weights +from roll.utils.send_recv_utils import compute_weight_stats, serialize_named_weights if is_peft_available(): @@ -142,7 +142,7 @@ def _process_and_yield_weights(weights_info, group=None, ep_group=None): for mcore_name, weight in weights_info: weight_size = weight.numel() * weight.element_size() * group_size if buffer_size is not None and waiting_weights_size + weight_size > buffer_size: - yield gather_and_convert_weights(waiting_weights, model_converter, group, ep_group) + yield gather_and_convert_weights(waiting_weights, model_converter, group, ep_group, **kwargs) waiting_weights, waiting_weights_size = [], 0 waiting_weights.append((mcore_name, weight)) waiting_weights_size += weight_size @@ -158,10 +158,47 @@ def _process_and_yield_weights(weights_info, group=None, ep_group=None): yield from _process_and_yield_weights(other_weights_with_info, mpu.get_tensor_model_parallel_group()) -def _iter_vp_stage_named_weights(models: list[McaGPTModel], model_converter: ModelConverter): +def _iter_vp_stage_named_weights( + models: list[McaGPTModel], model_converter: ModelConverter, adapter_name: str | None = None +) -> Generator[tuple[str, torch.Tensor], None, None]: + """Yield (mcore_name, 
tensor) pairs across all virtual pipeline stages. + + adapter_name=None: export base model weights only (LoRA keys filtered out, + `.base_layer.` prefix stripped to restore canonical mcore names). + adapter_name="": export only that adapter's LoRA delta weights. + Non-PeftModel: export full state dict as-is. + + NOTE: Semantic difference from upstream PEFT — upstream's + ``get_peft_model_state_dict(adapter_name=None)`` exports the *default* + adapter's LoRA state, whereas here ``adapter_name=None`` means "export + base weights only, with LoRA keys stripped". Callers must pass an + explicit adapter name to get LoRA delta weights. + """ for vp_stage, model in enumerate(models): if is_peft_available() and isinstance(model, PeftModel): - mcore_state_dict = get_peft_model_state_dict(model, model.state_dict_for_save_checkpoint()) + # adapter_name=None: export base weights only, stripping PEFT wrapper + # naming (`.base_layer.`) so ModelConverter sees canonical mcore names. + if adapter_name is None: + full_state_dict = model.state_dict_for_save_checkpoint() + mcore_state_dict = {} + for name, weight in full_state_dict.items(): + if ".lora_" in name or ".lora_embedding_" in name or ".lora_magnitude_vector" in name: + continue + # LoRA wrappers expose base tensors as "...base_layer."; converter expects + # the original tensor names without the wrapper hop. + normalized_name = name.replace(".base_layer.", ".") + # Fail fast if stripping ".base_layer." produces a key that already exists, + # which would silently overwrite an unrelated weight. 
+ if normalized_name in mcore_state_dict: + raise ValueError( + f"base_layer name collision: '{name}' normalizes to '{normalized_name}' " + f"which already exists in state dict" + ) + mcore_state_dict[normalized_name] = weight + else: + mcore_state_dict = get_peft_model_state_dict( + model, model.state_dict_for_save_checkpoint(), adapter_name=adapter_name + ) else: mcore_state_dict = model.state_dict_for_save_checkpoint() for mcore_name, weight in sorted(mcore_state_dict.items()): @@ -171,19 +208,31 @@ def _iter_vp_stage_named_weights(models: list[McaGPTModel], model_converter: Mod yield mcore_name, weight -def gather_pp_stage_hf_weights(models: list[McaGPTModel], buffer_size, **kwargs): +def gather_pp_stage_hf_weights( + models: list[McaGPTModel], buffer_size: int, adapter_name: str | None = None, **kwargs +): # gather tp&ep weights, not including pipeline parallel if not mpu.model_parallel_is_initialized(): raise RuntimeError("Model parallelism must be initialized before save as hf inflight.") + # Resolve lora_rank from peft_config if not already provided by caller + # (gather_all_hf_weights already resolves rank and passes it in **kwargs). 
+ if "lora_rank" not in kwargs: + lora_rank = _resolve_lora_rank(models[0], adapter_name) + if lora_rank is not None: + kwargs["lora_rank"] = lora_rank + model_config = models[0].config model_converter = ModelConverter(model_config, to_hf=True, efficient_mode=True) yield from _gather_hf_weights( - model_converter, list(_iter_vp_stage_named_weights(models, model_converter)), buffer_size, **kwargs + model_converter, + list(_iter_vp_stage_named_weights(models, model_converter, adapter_name=adapter_name)), + buffer_size, + **kwargs, ) -def gather_weights_meta_cross_pp(models: list[McaGPTModel]): +def gather_weights_meta_cross_pp(models: list[McaGPTModel], adapter_name: str | None = None): if not mpu.model_parallel_is_initialized(): raise RuntimeError("Model parallelism must be initialized before save as hf inflight.") model_config = models[0].config @@ -192,7 +241,7 @@ def gather_weights_meta_cross_pp(models: list[McaGPTModel]): pp_rank = mpu.get_pipeline_model_parallel_rank() model_converter = ModelConverter(model_config, to_hf=True, efficient_mode=True) named_weights_meta = [] - for mcore_name, weight in _iter_vp_stage_named_weights(models, model_converter): + for mcore_name, weight in _iter_vp_stage_named_weights(models, model_converter, adapter_name=adapter_name): weight_size = weight.numel() * weight.element_size() if model_converter.dist_converter.is_expert_parallel_weight(mcore_name): weight_size *= model_config.expert_model_parallel_size * model_config.expert_tensor_parallel_size @@ -222,19 +271,63 @@ def gather_weights_meta_cross_pp(models: list[McaGPTModel]): return expert_weights_meta + other_weights_meta -def gather_all_hf_weights(models: list[McaGPTModel], buffer_size: int, weights_meta: Optional[list[dict]]): - # weights_meta: list of dict, each dict is {"name": str, "shape": list, "dtype": str, "pp_stage": int, "size": int} +def _resolve_lora_rank(model: McaGPTModel, adapter_name: str | None) -> int | None: + """Resolve LoRA rank from model.peft_config, 
the single source of truth. + + Works for both PeftModel instances and project-specific wrappers that carry peft_config + without subclassing PeftModel. Fails fast when an explicit adapter_name is requested + but not found in the config. + """ + peft_configs = getattr(model, "peft_config", None) + if not isinstance(peft_configs, dict) or not peft_configs: + return None + + if adapter_name is None: + # Best-effort: use any adapter's rank for converter ops during base-weight export. + peft_cfg = next(iter(peft_configs.values())) + else: + peft_cfg = peft_configs.get(adapter_name) + if peft_cfg is None: + raise RuntimeError(f"Missing peft_config for adapter {adapter_name!r}") + + lora_rank = getattr(peft_cfg, "r", None) + return int(lora_rank) if lora_rank is not None else None + + +def gather_all_hf_weights( + models: list[McaGPTModel], buffer_size: int, weights_meta: Optional[list[dict]], adapter_name: str | None = None +): + """Gather weights across all parallelism dimensions (TP, EP, PP) and convert to HF naming. + + Yields batches of (hf_name, tensor) pairs, bounded by buffer_size. + + Without PP (pp_size <= 1): delegates to gather_pp_stage_hf_weights for TP/EP gather only. + With PP: broadcasts each weight from its owning PP stage to all ranks, then gathers TP/EP. + Weights are processed in buffered batches to limit peak memory. + + weights_meta: pre-computed list from gather_weights_meta_cross_pp(), each entry is + {"name": str, "shape": list, "dtype": str, "pp_stage": int, "size": int}. + Only needed when pp_size > 1; ignored otherwise. 
+ """ if not mpu.model_parallel_is_initialized(): raise RuntimeError("Model parallelism must be initialized before save as hf inflight.") - kwargs = {} - if is_peft_available() and isinstance(models[0], PeftModel): - lora_rank = next(iter(models[0].peft_config.values())).r - kwargs = {"lora_rank": lora_rank} + kwargs: dict = {} + lora_rank = _resolve_lora_rank(models[0], adapter_name) + if lora_rank is not None: + kwargs["lora_rank"] = lora_rank + + if dist.is_initialized() and dist.get_rank() == 0: + logger.info( + "gather_all_hf_weights: adapter=%r lora_rank=%s model_cls=%s", + adapter_name, + lora_rank, + type(models[0]).__name__, + ) pp_size = models[0].config.pipeline_model_parallel_size if pp_size <= 1: - yield from gather_pp_stage_hf_weights(models, buffer_size, **kwargs) + yield from gather_pp_stage_hf_weights(models, buffer_size, adapter_name=adapter_name, **kwargs) return pp_rank = mpu.get_pipeline_model_parallel_rank() @@ -242,7 +335,8 @@ def gather_all_hf_weights(models: list[McaGPTModel], buffer_size: int, weights_m models[0].config, pipeline_model_parallel_rank=pp_rank, to_hf=True, efficient_mode=True ) cur_stage_state_dict = { - mcore_name: weight for mcore_name, weight in _iter_vp_stage_named_weights(models, model_converter) + mcore_name: weight + for mcore_name, weight in _iter_vp_stage_named_weights(models, model_converter, adapter_name=adapter_name) } def _gather_batch_params(named_weights_with_stage: list[tuple[str, torch.Tensor, int]]): @@ -292,6 +386,9 @@ def __init__( self._model_update_buffer_size = ( pipeline_config.model_update_buffer_size_mb * 1024 * 1024 ) # Convert MB to bytes + # Uses `adapters` (not legacy `lora_target`) because __post_init__ normalizes + # single-LoRA configs into the `adapters` dict, so this covers both cases. 
+ self.is_lora = self.worker_config.model_args.adapters is not None self.infer_worker_config = infer_cluster.worker_config self.infer_cluster = infer_cluster self.is_colocated = is_actor_infer_overlapping_with_any_cluster( @@ -307,17 +404,16 @@ def __init__( # Separated mode attributes self.model_update_group_name = None self._model_update_locker = None - self._weights_meta = None if self.is_colocated: self._setup_colocated_model_update() else: self._setup_separated_model_update() - def model_update(self): + def model_update(self, adapters_to_update: list[str] | None = None): if self.is_colocated: - return self._colocated_model_update() - return self._separated_model_update() + return self._colocated_model_update(adapters_to_update=adapters_to_update) + return self._separated_model_update(adapters_to_update=adapters_to_update) def _setup_colocated_model_update(self): logger.info(f"RANK {dist.get_rank()} Setup colocated model update") @@ -355,7 +451,8 @@ def _setup_colocated_model_update(self): ) self._setup_broadcast_group() - self._weights_meta = gather_weights_meta_cross_pp(self.models_unwrapped) + # weights_meta is computed per-adapter inside _gather_and_distribute_weights() + # so that metadata names match the adapter-specific state dict keys. 
def _setup_separated_model_update(self): self._model_update_locker = Locker.options( @@ -416,7 +513,7 @@ def _broadcast_to_infer_workers(self, hf_named_weights) -> list[ray.ObjectRef]: names=[n for n, _ in hf_named_weights], dtypes=[w.dtype for _, w in hf_named_weights], shapes=[w.shape for _, w in hf_named_weights], - is_lora=self.worker_config.model_args.lora_target is not None, + is_lora=self.is_lora, ) for worker in self._broadcast_workers ] @@ -429,18 +526,118 @@ def _broadcast_to_infer_workers(self, hf_named_weights) -> list[ray.ObjectRef]: handle.wait() return refs - def _colocated_model_update(self): + def _colocated_model_update(self, *, adapters_to_update: list[str] | None = None): + """Transfer weights to colocated inference workers via CUDA IPC. + + LoRA mode: loops over each adapter, transfers only LoRA delta weights (not base model), + then calls add_lora() to register the adapter in the inference engine. + Base mode: transfers full model weights in a single pass. + + Returns dict with timing info and weight_stats for post-sync verification. + Only dist.get_rank()==0 reports stats (all workers have identical globally-gathered + weights via gather_all_hf_weights; picking one avoids duplication). + """ + co_infer_rank = dist.get_rank(self._infer_parallel_cpu_group) + # Only global rank 0 reports stats; all workers have identical gathered weights. + # Gated by config flag to skip stats computation when verification is disabled. + is_stats_reporter = dist.get_rank() == 0 and self.pipeline_config.verify_model_after_sync + weight_stats: dict = {} + + if self.is_lora: + peft_configs = self.models_unwrapped[0].peft_config + selected = set(adapters_to_update) if adapters_to_update is not None else None + # Materialize selected adapters once so registration order is explicit and + # consistent across colocated and broadcast worker paths. 
+ adapter_items = [ + (adapter_name, peft_config) + for adapter_name, peft_config in peft_configs.items() + if selected is None or adapter_name in selected + ] + # Accumulate per-adapter sender stats for verification. + lora_stats: dict[str, dict[str, float]] = {} + for adapter_index, (adapter_name, peft_config) in enumerate(adapter_items): + wake_after_add = adapter_index == len(adapter_items) - 1 + batch_stats = self._gather_and_distribute_weights(adapter_name, compute_stats=is_stats_reporter) + if is_stats_reporter and batch_stats: + lora_stats[adapter_name] = batch_stats + # Register adapter on all infer workers (colocated + broadcast). + # BLOCKING: upstream was fire-and-forget which races with inference requests. + # Fix vs upstream: upstream only registered on _co_infer_worker. + # Pure colocated: add_lora on _co_infer_worker only. + # Partial overlap: add_lora on both _co_infer_worker and _broadcast_workers. + # Pure separated: handled by _separated_model_update() on _broadcast_workers. 
+ add_lora_refs: list[ray.ObjectRef] = [] + if co_infer_rank == 0 and self._co_infer_worker is not None: + add_lora_refs.append( + self._co_infer_worker.add_lora.remote( + adapter_name=adapter_name, + peft_config=asdict(peft_config), + wake_after_add=wake_after_add, + ) + ) + if dist.get_rank() == 0 and self._broadcast_workers: + add_lora_refs.extend( + w.add_lora.remote( + adapter_name=adapter_name, + peft_config=asdict(peft_config), + wake_after_add=wake_after_add, + ) + for w in self._broadcast_workers + ) + if add_lora_refs: + ray.get(add_lora_refs) + if lora_stats: + weight_stats["lora"] = lora_stats + else: + batch_stats = self._gather_and_distribute_weights(None, compute_stats=is_stats_reporter) + if is_stats_reporter and batch_stats: + weight_stats["base"] = batch_stats + return {"weight_stats": weight_stats} + + def _gather_and_distribute_weights( + self, adapter_name: str | None = None, compute_stats: bool = False + ) -> dict[str, float]: + """Gather HF-format weights from this PP stage and push them to colocated inference workers. + + Converts Megatron-Core weights to HF naming in buffered batches, serializes each batch, + and distributes to inference workers via gather (colocated) or broadcast (remote). + + adapter_name: which adapter's weights to gather. None means base weights only + (LoRA keys stripped); a string means that adapter's LoRA delta weights. + compute_stats: if True, accumulate running sum/max/min across all batches for verification. + Returns aggregate stats dict if compute_stats=True, otherwise empty dict. + """ + # Running stats accumulators across all weight batches. 
+ running_sum = 0.0 + running_max = float("-inf") + running_min = float("inf") + tensor_count = 0 refs = [] infer_parallel_size = dist.get_world_size(self._infer_parallel_cpu_group) co_infer_rank = dist.get_rank(self._infer_parallel_cpu_group) - if is_lora := (self.worker_config.model_args.lora_target is not None): - peft_config = self.models_unwrapped[0].peft_config.get("default", None) + # Compute weights_meta with the actual adapter_name so metadata names match + # the state dict keys used by gather_all_hf_weights (base vs LoRA names). + weights_meta = gather_weights_meta_cross_pp(self.models_unwrapped, adapter_name=adapter_name) for hf_named_weights in gather_all_hf_weights( - self.models_unwrapped, buffer_size=self._model_update_buffer_size, weights_meta=self._weights_meta + self.models_unwrapped, + buffer_size=self._model_update_buffer_size, + weights_meta=weights_meta, + adapter_name=adapter_name, ): + # Accumulate running stats across batches for post-sync verification. + if compute_stats: + batch_stats = compute_weight_stats(hf_named_weights) + if batch_stats: + running_sum += batch_stats["sum"] + running_max = max(running_max, batch_stats["max"]) + running_min = min(running_min, batch_stats["min"]) + tensor_count += 1 + if self._co_infer_worker is not None: serialized_tensors = serialize_named_weights( - hf_named_weights, infer_strategy=self.infer_worker_config.strategy_args.strategy_name + hf_named_weights, + infer_strategy=self.infer_worker_config.strategy_args.strategy_name, + model_update_transport=self.pipeline_config.model_update_transport, ) infer_parallel_tensors = [None] * infer_parallel_size if co_infer_rank == 0 else None dist.gather_object( @@ -452,32 +649,124 @@ def _colocated_model_update(self): refs = [] if co_infer_rank == 0 and self._co_infer_worker is not None: refs.append( - self._co_infer_worker.update_parameter_in_bucket.remote(infer_parallel_tensors, is_lora=is_lora) + self._co_infer_worker.update_parameter_in_bucket.remote( + 
infer_parallel_tensors, + is_lora=self.is_lora, + model_update_transport=self.pipeline_config.model_update_transport, + ) ) if self._broadcast_workers: refs.extend(self._broadcast_to_infer_workers(hf_named_weights)) if refs: ray.get(refs) - refs = [] - if is_lora and co_infer_rank == 0 and self._co_infer_worker is not None: - refs.append(self._co_infer_worker.add_lora.remote(peft_config=asdict(peft_config))) + if compute_stats and tensor_count > 0: + return {"sum": running_sum, "max": running_max, "min": running_min} return {} - def _separated_model_update(self): + def _separated_model_update(self, *, adapters_to_update: list[str] | None = None): + """Broadcast weights from train workers to remote (non-colocated) infer workers. + + Gathers HF-format weights from this PP stage in buffered batches and broadcasts + each batch under a distributed lock to avoid conflicts with inference requests. + + LoRA mode: iterates over each adapter, broadcasts LoRA adapter weights, then registers + the adapter on infer workers via add_lora(). + Base mode: broadcasts full model weights in a single pass. + + Returns dict with weight_stats for post-sync verification. + Only workers with _broadcast_workers report stats (dp==0, tp==0, one per PP stage). + """ if not mpu.get_expert_data_parallel_rank() == 0: return {} + # Only workers with _broadcast_workers are canonical reporters (dp==0, tp==0). + # Gated by config flag to skip stats computation when verification is disabled. 
+ is_stats_reporter = bool(self._broadcast_workers) and self.pipeline_config.verify_model_after_sync + weight_stats: dict = {} + logger.info(f"start broadcast model update {self.model_update_name}") - for hf_named_weights in gather_pp_stage_hf_weights( - self.models_unwrapped, buffer_size=self._model_update_buffer_size - ): - if not self._broadcast_workers: - continue - while not ray.get(self._model_update_locker.acquire.remote()): - time.sleep(0.1) + if self.is_lora: + peft_configs = self.models_unwrapped[0].peft_config + selected = set(adapters_to_update) if adapters_to_update is not None else None + # Materialize selected adapters once so registration order is explicit and + # matches colocated mode behavior. + adapter_items = [ + (adapter_name, peft_config) + for adapter_name, peft_config in peft_configs.items() + if selected is None or adapter_name in selected + ] + lora_stats: dict[str, dict[str, float]] = {} + for adapter_index, (adapter_name, peft_config) in enumerate(adapter_items): + wake_after_add = adapter_index == len(adapter_items) - 1 + logger.info(f"model_update: broadcasting adapter={adapter_name!r}") + # Accumulate stats across all batches for this adapter. 
+ running_sum = 0.0 + running_max = float("-inf") + running_min = float("inf") + batch_count = 0 + for hf_named_weights in gather_pp_stage_hf_weights( + self.models_unwrapped, + buffer_size=self._model_update_buffer_size, + adapter_name=adapter_name, + ): + if is_stats_reporter: + batch_stats = compute_weight_stats(hf_named_weights) + if batch_stats: + running_sum += batch_stats["sum"] + running_max = max(running_max, batch_stats["max"]) + running_min = min(running_min, batch_stats["min"]) + batch_count += 1 + if not self._broadcast_workers: + continue + self._broadcast_bucket_under_lock(hf_named_weights) + if is_stats_reporter and batch_count > 0: + lora_stats[adapter_name] = {"sum": running_sum, "max": running_max, "min": running_min} + # After broadcasting LoRA tensors, register the adapter on all infer workers. + if self._broadcast_workers: + logger.info(f"model_update: registering adapter={adapter_name!r} on infer workers") + ray.get( + [ + w.add_lora.remote( + adapter_name=adapter_name, + peft_config=asdict(peft_config), + wake_after_add=wake_after_add, + ) + for w in self._broadcast_workers + ] + ) + logger.info(f"model_update: adapter={adapter_name!r} registration complete") + if lora_stats: + weight_stats["lora"] = lora_stats + else: + running_sum = 0.0 + running_max = float("-inf") + running_min = float("inf") + batch_count = 0 + for hf_named_weights in gather_pp_stage_hf_weights( + self.models_unwrapped, buffer_size=self._model_update_buffer_size + ): + if is_stats_reporter: + batch_stats = compute_weight_stats(hf_named_weights) + if batch_stats: + running_sum += batch_stats["sum"] + running_max = max(running_max, batch_stats["max"]) + running_min = min(running_min, batch_stats["min"]) + batch_count += 1 + if not self._broadcast_workers: + continue + self._broadcast_bucket_under_lock(hf_named_weights) + if is_stats_reporter and batch_count > 0: + weight_stats["base"] = {"sum": running_sum, "max": running_max, "min": running_min} + return {"weight_stats": 
weight_stats} + + def _broadcast_bucket_under_lock(self, hf_named_weights: list[tuple[str, torch.Tensor]]) -> None: + """Acquire model_update lock, broadcast one bucket to infer workers, then release.""" + while not ray.get(self._model_update_locker.acquire.remote()): + time.sleep(0.1) + try: refs = self._broadcast_to_infer_workers(hf_named_weights) ray.get(refs) + finally: ray.get(self._model_update_locker.release.remote()) - return {} diff --git a/roll/third_party/vllm/async_llm.py b/roll/third_party/vllm/async_llm.py index 950a06ef5..ded9101a9 100644 --- a/roll/third_party/vllm/async_llm.py +++ b/roll/third_party/vllm/async_llm.py @@ -18,11 +18,34 @@ async def setup_collective_group(self, *args, **kwargs): async def broadcast_parameter(self, *args, **kwargs): await self.engine_core.collective_rpc_async(method="broadcast_parameter", args=args, kwargs=kwargs) - async def update_parameter_in_bucket(self, serialized_named_tensors, is_lora=False): - await self.engine_core.collective_rpc_async(method="update_parameter_in_bucket", args=(serialized_named_tensors, is_lora)) + async def update_parameter_in_bucket( + self, serialized_named_tensors, is_lora=False, *, ipc_local_ranks=None, model_update_transport="cuda_ipc" + ): + await self.engine_core.collective_rpc_async( + method="update_parameter_in_bucket", + args=(serialized_named_tensors, is_lora), + kwargs={"ipc_local_ranks": ipc_local_ranks, "model_update_transport": model_update_transport}, + ) + + async def destroy_collective_group(self, group_name: str): + await self.engine_core.collective_rpc_async(method="destroy_collective_group", args=(group_name,)) async def add_lora(self, *args, **kwargs): await self.engine_core.collective_rpc_async(method="custom_add_lora", args=args, kwargs=kwargs) + async def get_lora_id(self, *args, **kwargs): + # Keep adapter-id lookup on the same collective RPC path as add_lora. 
+ return await self.engine_core.collective_rpc_async(method="custom_get_lora_id", args=args, kwargs=kwargs) + + async def list_loras(self) -> list[int]: + # Expose loaded adapter ids so strategy can verify routing readiness. + return await self.engine_core.collective_rpc_async(method="custom_list_loras") + + async def verify_model(self, expected_stats: dict | None = None) -> list: + """Dispatch custom_verify_model to all TP ranks and return per-rank stats.""" + return await self.engine_core.collective_rpc_async( + method="custom_verify_model", kwargs={"expected_stats": expected_stats} + ) + async def process_weights_after_loading(self): await self.engine_core.collective_rpc_async(method="process_weights_after_loading") diff --git a/roll/third_party/vllm/async_llm_engine.py b/roll/third_party/vllm/async_llm_engine.py index 25a7a025e..ee6ce63de 100644 --- a/roll/third_party/vllm/async_llm_engine.py +++ b/roll/third_party/vllm/async_llm_engine.py @@ -2,26 +2,29 @@ class CustomAsyncLLMEngine(AsyncLLMEngine): async def custom_init_worker(self): - self.engine.model_executor.collective_rpc(method="custom_init_worker") + await self.engine.model_executor.collective_rpc(method="custom_init_worker") async def load_states(self): - self.engine.model_executor.collective_rpc(method="load_states") + await self.engine.model_executor.collective_rpc(method="load_states") async def offload_states(self, level): - self.reset_prefix_cache() - self.engine.model_executor.collective_rpc(method="offload_states", args=(level,)) + await self.reset_prefix_cache() + await self.engine.model_executor.collective_rpc(method="offload_states", args=(level,)) async def setup_collective_group(self, *args, **kwargs): - self.engine.model_executor.collective_rpc(method="setup_collective_group", args=args, kwargs=kwargs) + await self.engine.model_executor.collective_rpc(method="setup_collective_group", args=args, kwargs=kwargs) async def broadcast_parameter(self, *args, **kwargs): - 
self.engine.model_executor.collective_rpc(method="broadcast_parameter", args=args, kwargs=kwargs) + await self.engine.model_executor.collective_rpc(method="broadcast_parameter", args=args, kwargs=kwargs) async def update_parameter_in_bucket(self, *args, **kwargs): - self.engine.model_executor.collective_rpc(method="update_parameter_in_bucket", args=args, kwargs=kwargs) + await self.engine.model_executor.collective_rpc(method="update_parameter_in_bucket", args=args, kwargs=kwargs) + + async def destroy_collective_group(self, group_name: str): + await self.engine.model_executor.collective_rpc(method="destroy_collective_group", args=(group_name,), kwargs={}) async def add_lora(self, *args, **kwargs): - self.engine.model_executor.collective_rpc(method="custom_add_lora", args=args, kwargs=kwargs) + await self.engine.model_executor.collective_rpc(method="custom_add_lora", args=args, kwargs=kwargs) async def process_weights_after_loading(self): await self.engine.model_executor.collective_rpc(method="process_weights_after_loading") diff --git a/roll/third_party/vllm/vllm_0_8_4/__init__.py b/roll/third_party/vllm/vllm_0_8_4/__init__.py index 633252a34..95d64cb79 100644 --- a/roll/third_party/vllm/vllm_0_8_4/__init__.py +++ b/roll/third_party/vllm/vllm_0_8_4/__init__.py @@ -13,6 +13,77 @@ from roll.third_party.vllm.async_llm import CustomAsyncLLM +# vLLM 0.8.4 compatibility: some builds call LoRALRUCache._LRUCache__update(), +# while others only provide LRUCache._LRUCache__touch(); add alias for subprocess engine path. +from vllm.lora.models import LoRALRUCache as _LoRALRUCache +from vllm.utils import LRUCache as _LRUCache +if not hasattr(_LoRALRUCache, "_LRUCache__update") and hasattr(_LRUCache, "_LRUCache__touch"): + setattr(_LoRALRUCache, "_LRUCache__update", _LRUCache._LRUCache__touch) + +# Patch vLLM v1 dummy profiling run to avoid indexing with a NumPy int64 array. 
+# Source: vllm/v1/worker/gpu_model_runner.py::GPUModelRunner._dummy_run (vllm==0.8.4) +# +# vllm==0.8.4 builds `logit_indices` as a NumPy array and uses it to index a torch.Tensor +# (`hidden_states[logit_indices]`). In some environments this raises: +# RuntimeError: Could not infer dtype of numpy.int64 +# Fix: convert indices to torch.LongTensor on the correct device before indexing (lines 80-81). +import vllm.v1.worker.gpu_model_runner as _v1_gpu_model_runner +import torch as _torch + +@_torch.inference_mode() +def _dummy_run_fixed(self, num_tokens: int) -> _torch.Tensor: + assert num_tokens <= self.scheduler_config.max_num_batched_tokens + max_num_reqs = self.scheduler_config.max_num_seqs + num_reqs = max_num_reqs if num_tokens >= max_num_reqs else num_tokens + min_tokens_per_req = num_tokens // num_reqs + num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs + num_scheduled_tokens_list[-1] += num_tokens % num_reqs + assert sum(num_scheduled_tokens_list) == num_tokens + assert len(num_scheduled_tokens_list) == num_reqs + num_scheduled_tokens = _v1_gpu_model_runner.np.array( + num_scheduled_tokens_list, dtype=_v1_gpu_model_runner.np.int32 + ) + + with self.maybe_dummy_run_with_lora(self.lora_config, num_scheduled_tokens): + model = self.model + if self.is_multimodal_model: + input_ids = None + inputs_embeds = self.inputs_embeds[:num_tokens] + else: + input_ids = self.input_ids[:num_tokens] + inputs_embeds = None + if self.uses_mrope: + positions = self.mrope_positions[:, :num_tokens] + else: + positions = self.positions[:num_tokens] + + if _v1_gpu_model_runner.get_pp_group().is_first_rank: + intermediate_tensors = None + else: + if self.intermediate_tensors is None: + self.intermediate_tensors = self.model.make_empty_intermediate_tensors( + batch_size=self.max_num_tokens, + dtype=self.model_config.dtype, + device=self.device, + ) + intermediate_tensors = _v1_gpu_model_runner.IntermediateTensors( + {k: v[:num_tokens] for k, v in 
self.intermediate_tensors.items()} + ) + + with _v1_gpu_model_runner.set_forward_context(None, self.vllm_config, num_tokens=num_tokens): + hidden_states = model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + ) + + logit_indices_np = _v1_gpu_model_runner.np.cumsum(num_scheduled_tokens) - 1 + logit_indices = _torch.as_tensor(logit_indices_np, device=hidden_states.device, dtype=_torch.long) + return hidden_states[logit_indices] + +_v1_gpu_model_runner.GPUModelRunner._dummy_run = _dummy_run_fixed + async def generate( self, prompt: PromptType, @@ -74,7 +145,8 @@ def abort_requests( OutputProcessor.abort_requests = abort_requests -# patch qwen3 fp8 +# Patch Qwen3 fp8 block quantization offset/size calculation. +# Source: vllm/model_executor/layers/linear.py::QKVParallelLinear.weight_loader_v2 (vllm==0.8.4) # https://github.com/vllm-project/vllm/issues/17327 # https://github.com/vllm-project/vllm/pull/17318 from vllm.model_executor.layers.linear import QKVParallelLinear diff --git a/roll/third_party/vllm/vllm_utils.py b/roll/third_party/vllm/vllm_utils.py index f8d65a86c..7c43666e2 100644 --- a/roll/third_party/vllm/vllm_utils.py +++ b/roll/third_party/vllm/vllm_utils.py @@ -54,6 +54,14 @@ class TensorLoRARequest(LoRARequest): def patch_vllm_lora_manager(): + # vLLM 0.8.4 compatibility: some builds call LoRALRUCache._LRUCache__update(), + # while the installed LRUCache implementation exposes _LRUCache__touch(). + # Provide an alias so LoRA adapter activation does not crash during engine profiling. 
+ from vllm.lora.models import LoRALRUCache + from vllm.utils import LRUCache + if not hasattr(LoRALRUCache, "_LRUCache__update") and hasattr(LRUCache, "_LRUCache__touch"): + setattr(LoRALRUCache, "_LRUCache__update", LRUCache._LRUCache__touch) + def load_adapter(self, lora_request: TensorLoRARequest) -> LoRAModel: """ based on vllm.lora.worker_manager.WorkerLoRAManager._load_adapter, support load adapter with lora tensors diff --git a/roll/third_party/vllm/worker.py b/roll/third_party/vllm/worker.py index ea82ceb40..6c7cba3e3 100644 --- a/roll/third_party/vllm/worker.py +++ b/roll/third_party/vllm/worker.py @@ -1,9 +1,11 @@ import gc import hashlib +import io import json +import pickle import time from collections import OrderedDict -from typing import Iterable, Tuple +from typing import Iterable, List, Optional, Tuple import torch import vllm @@ -13,45 +15,121 @@ from roll.third_party.vllm.vllm_utils import TensorLoRARequest, patch_vllm_lora_manager from roll.utils.collective import collective from roll.utils.cuda_ipc_utils import MultiprocessingSerializer +from roll.utils.functionals import get_dist_info_from_comm_plan from roll.utils.logging import get_logger -from roll.utils.send_recv_utils import monkey_patch_torch_reductions, named_tensors_from_bucket +from roll.utils.send_recv_utils import compute_weight_stats, monkey_patch_torch_reductions, named_tensors_from_bucket logger = get_logger() class TensorLoraManager: + """Manages LoRA adapter staging and confirmed registration for vLLM workers. + + Two concerns: + (a) Staging: collects incoming tensor weights (add_weight) before they are passed + to vLLM for ingestion via build_request. + (b) Tracking: maintains a confirmed adapter_name -> lora_int_id map (_lora_names) + so routing and readiness checks can look up integer adapter ids by name. + An entry in _lora_names means vLLM has confirmed the adapter is loaded on GPU; + it is never set speculatively. 
+ """ + def __init__(self): self.lora_params = OrderedDict() self.add_lora_count = 0 + self._lora_names: dict[str, int] = {} # Track adapter_name -> lora_int_id for routing lookups. + # Preserve raw received tensors (HF-format) per adapter for post-sync verification. + # Populated in build_request() before lora_params is cleared; survives until next sync. + self._staged_weights: dict[str, OrderedDict] = {} + + def get_lora_id(self, adapter_name: str) -> int | None: + """Return the vLLM integer adapter id for adapter_name, or None if not yet registered. + + Returns None when the adapter has not been confirmed as loaded by vLLM on this worker. + Callers should treat None as "not ready" and retry or skip the operation. + """ + # Return None when adapter has not been registered on this worker yet. + return self._lora_names.get(adapter_name, None) + + def register(self, adapter_name: str, lora_int_id: int) -> None: + """Record a confirmed adapter registration. + + Must be called only after vLLM's add_lora succeeds. + Invariant: an entry in _lora_names means the adapter is actually loaded in vLLM on GPU. + Violation (calling before vLLM confirms) would cause routing to route to an unloaded adapter. + """ + # Called only after vLLM confirms the adapter is loaded successfully. + # Invariant: entry in _lora_names ↔ adapter successfully registered in vLLM. + self._lora_names[adapter_name] = lora_int_id def add_weight(self, name: str, weight: torch.Tensor): self.lora_params[name] = weight - def build_request(self, peft_config: dict) -> TensorLoRARequest: - """ - Generate a unique LoRA ID based on the PEFT configuration rather than - using a timestamp to assert all tp-ranks get the same LoRA ID. + def build_request(self, adapter_name: str, peft_config: dict) -> TensorLoRARequest: + """Build a TensorLoRARequest from staged weights and return it. 
+ + Computes a stable lora_int_id from the adapter name + PEFT config so every + TP-rank worker produces the same integer id for the same adapter, regardless + of registration order. The old design used a call-order counter, which caused + different TP ranks to compute different ids when adapters were registered in + different orders — leading to NCCL group membership mismatches. + + Does NOT update _lora_names. Registration is intentionally deferred to + register(), which is called by custom_add_lora only after vLLM confirms success. + This keeps _lora_names as a strictly confirmed-state map. + + Consumes and resets self.lora_params after building the request. """ self.add_lora_count += 1 peft_config["add_lora_count"] = self.add_lora_count - peft_config_str = json.dumps(peft_config, sort_keys=True) + # Use a stable hash key (adapter + config only). Do NOT include call-order counters, + # otherwise different registration order across workers yields inconsistent adapter ids. + # Exclude add_lora_count from hash — it increments per call, producing different int_ids + # for the same adapter across sync cycles and causing vLLM LRU eviction mismatches. + peft_config_for_hash = {k: v for k, v in peft_config.items() if k != "add_lora_count"} + peft_config_for_hash["adapter_name"] = adapter_name + peft_config_str = json.dumps(peft_config_for_hash, sort_keys=True) hash_obj = hashlib.sha256(peft_config_str.encode("utf-8")) hex_dig = hash_obj.hexdigest() lora_int_id = int(hex_dig, 16) % 0x7FFFFFFF + # Do NOT set _lora_names here — registration is recorded by register() only after + # vLLM confirms the adapter loaded successfully in custom_add_lora. 
lora_request = TensorLoRARequest( - lora_name=f"{lora_int_id}", + lora_name=adapter_name, lora_int_id=lora_int_id, lora_path="dummy_lora_path", - peft_config=peft_config, + peft_config=peft_config_for_hash, lora_tensors=self.lora_params, ) - del self.lora_params + # Preserve raw received tensors for post-sync verification before clearing. + # These are the same HF-format tensors the sender produced, so stats comparison + # against sender stats is valid (same format, no vLLM transformation applied yet). + self._staged_weights[adapter_name] = self.lora_params + # Normal-path cleanup: transfer ownership of staged tensors to lora_request, then + # reset lora_params immediately. lora_request is a local in custom_add_lora; once + # vLLM's add_lora() copies the tensors into GPU memory and the function returns, + # lora_request goes out of scope and Python GC frees the staging buffers. + # No separate cleanup step is needed on the happy path. self.lora_params = OrderedDict() return lora_request class WorkerBase: + """Mixin that extends vLLM's WorkerExtensionCls with RLix-specific lifecycle methods. + + All methods use the "custom_" prefix to avoid name conflicts with vLLM's own worker + methods. WorkerV1 (and future V2) subclass this to inherit the full implementation; + they only override what differs between engine versions. + + Key responsibilities: + - LoRA adapter registration and lifecycle (custom_add_lora, custom_list_loras, + custom_get_lora_id). + - GPU memory lifecycle: reload_model, load_states, offload_states. + - Parameter broadcast and bucket update for model-weight synchronisation. + - NCCL collective group management for model updates. 
+ """ + def custom_init_worker(self, *args, **kwargs): self.weight_loaded: bool = True self.kv_cache_loaded: bool = True @@ -59,21 +137,320 @@ def custom_init_worker(self, *args, **kwargs): self.buffer_cache = None self.tensor_lora_manager = TensorLoraManager() + # Use custom prefix because worker_extension_cls can not have conflicting method names with vllm worker. + def custom_add_lora( + self, + adapter_name: str, + peft_config: dict, + *, + lora_local_ranks: Optional[List[int]] = None, + wake_after_add: bool = True, + ) -> bool: + """Register a LoRA adapter with vLLM on this worker. + + Pre-condition: staged LoRA tensors have already been delivered via add_weight calls. + Post-condition: adapter is loaded in vLLM and tensor_lora_manager._lora_names[adapter_name] + is set only on success. + + Why conditional wake-up here: + LoRA tensors are allocated outside the cumem "weights" pool. If we only called + reload_model() (which wakes weights only), the KV cache would remain uninitialised. + A subsequent load_states_partial call that tries wake_up(["kv_cache"]) on a GPU + that is already near-full with model weights + LoRA tensors would OOM. + For multi-adapter updates: + - non-final adapters call reload_model() to keep broadcast memory low + - final adapter calls load_states() to initialize KV cache before rollout + We avoid follow-up strategy RPC verification after this call to prevent + reentrancy stalls. + + Registration is deferred to after vLLM confirms success so _lora_names only ever + holds adapters that are actually resident on GPU. + """ + # Partial-overlap support: skip registration on ranks not in the mask. + if lora_local_ranks is not None and self.rank not in lora_local_ranks: + return True # match existing True return convention for non-participating ranks + + # Build request with adapter name so routing can map name -> id consistently. 
+ lora_request = self.tensor_lora_manager.build_request(adapter_name, peft_config) + lora_int_id = lora_request.lora_int_id + staged_count = len(lora_request.lora_tensors) if lora_request.lora_tensors else 0 + # Diagnostic: check if adapter is still in vLLM's Python registry. After offload_states() at + # either sleep level, the registry is cleared, so in_vllm_cache=True here means the adapter was + # registered without an intervening sleep (e.g. back-to-back add_lora calls). GPU tensors are valid here. + lora_manager = getattr(getattr(self, "model_runner", None), "lora_manager", None) + in_vllm_cache = ( + lora_int_id in lora_manager.list_adapters() + if lora_manager is not None and callable(getattr(lora_manager, "list_adapters", None)) + else None + ) + logger.info( + "[vllm][add_lora] enter adapter=%s int_id=%s staged_tensors=%s in_vllm_cache=%s weight_loaded=%s wake_after_add=%s", + adapter_name, lora_int_id, staged_count, in_vllm_cache, self.weight_loaded, wake_after_add, + ) + # Ensure weights are resident before add_lora. Final adapter also wakes KV cache. + if wake_after_add: + self.load_states() + else: + self.reload_model() + add_lora = getattr(getattr(self, "model_runner", None), "add_lora", None) + if not callable(add_lora): + raise NotImplementedError( + "vLLM worker does not expose model_runner.add_lora; " + "ensure the configured vLLM version supports runtime LoRA registration." 
+ ) + try: + ok = add_lora(lora_request) + except Exception as exc: + logger.error( + "[vllm][add_lora] FAILED adapter=%s int_id=%s in_vllm_cache=%s exc=%s", + adapter_name, lora_int_id, in_vllm_cache, exc, + ) + raise + if ok is False: + logger.error( + "[vllm][add_lora] returned_False adapter=%s int_id=%s in_vllm_cache=%s", + adapter_name, lora_int_id, in_vllm_cache, + ) + raise RuntimeError(f"vLLM add_lora returned False for adapter={adapter_name!r}") + # vLLM confirmed success — record the registration now so _lora_names only ever + # contains adapters that are actually loaded in vLLM. + self.tensor_lora_manager.register(adapter_name, lora_request.lora_int_id) + logger.info( + "[vllm][add_lora] ok adapter=%s int_id=%s in_vllm_cache=%s", + adapter_name, lora_int_id, in_vllm_cache, + ) + return True + + def custom_list_loras(self) -> list[int]: + """Return the sorted list of vLLM integer adapter ids currently loaded on this worker. + + Queries the live vLLM LoRA manager directly rather than tensor_lora_manager._lora_names, + because _lora_names is a local Python map that is cleared on sleep(). Querying vLLM at + runtime detects evicted slots that the Python map might still show after partial failures. + + Normalises heterogeneous return types across vLLM versions: + - dict → keys are adapter ids + - list[int] → used directly + - list[str] → numeric strings cast to int; name strings resolved via _lora_names + - list[object with lora_int_id attr] → attribute extracted + + Returns an empty list when no LoRA manager is present (LoRA not enabled). + """ + # Query runtime vLLM LoRA state instead of tensor_lora_manager._lora_names. + # This allows strategy-side visibility checks to detect slots that were evicted from GPU state. 
+ lora_manager = getattr(getattr(self, "model_runner", None), "lora_manager", None) + if lora_manager is None: + return [] + list_adapters = getattr(lora_manager, "list_adapters", None) + if not callable(list_adapters): + return [] + raw = list_adapters() + if isinstance(raw, dict): + raw = list(raw.keys()) + lora_ids = [] + for item in raw: + if isinstance(item, int): + lora_ids.append(item) + continue + # Some vLLM versions may return adapter names/ids as strings. + # Resolve names through local adapter_name->id map to keep readiness checks accurate. + if isinstance(item, str): + if item.isdigit(): + lora_ids.append(int(item)) + continue + mapped_id = self.tensor_lora_manager.get_lora_id(item) + if isinstance(mapped_id, int): + lora_ids.append(mapped_id) + continue + lora_int_id = getattr(item, "lora_int_id", None) + if isinstance(lora_int_id, int): + lora_ids.append(lora_int_id) + return sorted(set(lora_ids)) + + def custom_get_lora_id(self, adapter_name: str) -> int | None: + """Return the vLLM integer adapter id for adapter_name, or None if not yet registered. + + Provides a stable public API on the worker so strategy code does not need to reach into + tensor_lora_manager directly. Returns None when the adapter has not been confirmed loaded. + """ + # Strategy uses this to resolve adapter name into vLLM integer adapter id. + return self.tensor_lora_manager.get_lora_id(adapter_name) + + def custom_verify_model(self, expected_stats: dict) -> dict: + """Compute weight stats from this TP rank and return them for strategy-level aggregation. + + Base model: reads live GPU parameters from model_runner.model.named_parameters(). + End-to-end — these are the actual tensors used for inference. Stats are computed + in-place using .sum(dtype=float32) — no fp32 copy is allocated, only a scalar. 
+ When LoRA modules are active, named_parameters() returns base weights only; LoRA + delta tensors are plain torch.Tensors (not nn.Parameters) stored in + lora_a_stacked/lora_b_stacked GPU buffers, so they do NOT appear in named_parameters(). + + LoRA: reads raw received tensors from tensor_lora_manager._staged_weights (transport+ + delivery verification — same HF-format as sender, before vLLM's _load_adapter + transformation). Identical across all TP ranks. + + Also performs a LoRA presence check: verifies every adapter in _lora_names exists in + vLLM's live lora_manager.list_adapters(). + + Returns per-rank stats dict for strategy-level TP aggregation (base) and comparison (LoRA). + """ + result: dict = {} + + # LoRA presence check: every registered adapter must be in vLLM's live manager. + # Direct attribute access — model_runner.lora_manager is always present on vLLM + # workers when LoRA is active (which is the only case where _lora_names is non-empty). + if self.tensor_lora_manager._lora_names: + live_ids = set(self.model_runner.lora_manager.list_adapters()) + for adapter_name, expected_id in self.tensor_lora_manager._lora_names.items(): + if expected_id not in live_ids: + raise RuntimeError( + f"verify_model: adapter {adapter_name!r} (int_id={expected_id}) " + f"not in vLLM live adapters {sorted(live_ids)}" + ) + + # Base model stats: live GPU parameters (TP-sharded per rank). + # remove_duplicate=True (default) so tied weights (e.g. embed_tokens/lm_head when + # tie_word_embeddings=True) are counted once, matching the sender's gather_all_hf_weights. + if "base" in expected_stats: + model = self.model_runner.model + base_stats = compute_weight_stats(model.named_parameters()) + result["base"] = base_stats + + # LoRA stats: raw received tensors (identical across TP ranks). 
+ if "lora" in expected_stats: + result["lora"] = {} + for adapter_name in expected_stats["lora"]: + staged = self.tensor_lora_manager._staged_weights.get(adapter_name) + if staged is None: + raise RuntimeError( + f"verify_model: no staged weights for adapter {adapter_name!r}; " + f"available={sorted(self.tensor_lora_manager._staged_weights.keys())}" + ) + adapter_stats = compute_weight_stats(staged.items()) + result["lora"][adapter_name] = adapter_stats + + logger.info("[vllm][verify_model] rank=%s stats_keys=%s", self.rank, sorted(result.keys())) + return result + def reload_model(self): + """Allocate the GPU weight memory pool — does NOT update parameter values. + + Calls wake_up(["weights"]) to restore the CuMem "weights" pool back to GPU. + After this returns, weight tensors are addressable on GPU but their values are + whatever was there before sleep (restored from CPU at level=1, or re-initialized + at level=2). No new parameter values are written here. + + To write new trainer weights into the restored pool, call load_weights() next. + For a full wake-up (weights + KV cache), use load_states() instead. + + Idempotent: guarded by weight_loaded flag, so repeated calls are no-ops. + + The [debug][wake_up_done] log is a Stage 3 breadcrumb for memory profiling: + at this point no receive buffers exist yet (streaming approach), so + device_used = baseline + model_weights only. + """ if not self.weight_loaded: self.wake_up(["weights"]) self.weight_loaded = True + # [debug] Stage 3: model structure just allocated on GPU by wake_up. + # With the streaming approach in broadcast_parameter, no receive buffers exist yet, + # so device_used = baseline + model_weights only. 
+ _free3, _total3 = torch.cuda.mem_get_info() + logger.info( + f"[debug][wake_up_done] " + f"device_used={(_total3 - _free3) / 1024**3:.3f}GB " + f"allocated={torch.cuda.memory_allocated() / 1024**3:.3f}GB" + ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - # before updating the parameters, we need to reinitialize the previously released model + """Overwrite in-GPU parameter values with the trainer's latest weights. + + This is the second step of model update, after reload_model() allocates the + weight memory pool. reload_model() makes weight tensors addressable on GPU; + load_weights() makes them correct by copying the trainer's new values in. + + Accepts a generator of (param_name, tensor) pairs so tensors arrive one at a + time (streaming), keeping peak GPU memory low during broadcast. + + LoRA alias patch: + When LoRA is active, vLLM wraps every target module at init time + (e.g. gate_up_proj → gate_up_proj.base_layer). AutoWeightsLoader then looks + for the original fused key (gate_up_proj) which no longer exists → KeyError. + Fix: temporarily monkey-patch named_parameters() on affected submodules to + also yield the unwrapped alias, then restore the original after load_weights. + """ + # Before updating parameters, reinitialize the previously released model. self.reload_model() if vllm.__version__ < "0.8.5": from roll.third_party.vllm.vllm_utils import patch_vllm_moe_model_weight_loader patch_vllm_moe_model_weight_loader(self.model_runner.model) - self.model_runner.model.load_weights(weights=weights) + # Root cause: vLLM's _create_lora_modules() permanently replaces all LoRA target modules + # with wrapper objects at LoRAModelManager init time (e.g. gate_up_proj becomes + # gate_up_proj.base_layer). AutoWeightsLoader skips the root module and directly calls + # each child's load_weights (e.g. Qwen2Model.load_weights). 
That child builds its own + # params_dict from self.named_parameters() and applies stacked_params_mapping + # (gate_proj -> gate_up_proj), producing a fused key that no longer exists -> KeyError. + # Fix: patch named_parameters() on every submodule that has its own load_weights + # (those are the ones AutoWeightsLoader calls directly). Each alias maps the unwrapped + # name to the same tensor as the base_layer counterpart. + model = self.model_runner.model + params_dict = dict(model.named_parameters(remove_duplicate=False)) + lora_active = any(".base_layer." in k for k in params_dict) + if not lora_active: + model.load_weights(weights=weights) + return + # Collect submodules (not root) that have their own load_weights — AutoWeightsLoader + # calls these directly. Build per-submodule aliases stripping ".base_layer.". + patches: dict = {} + for submod_name, submod in model.named_modules(): + if submod is model: + continue # AutoWeightsLoader skips root to avoid infinite recursion + if not callable(getattr(submod, "load_weights", None)): + continue + sub_params = dict(submod.named_parameters(remove_duplicate=False)) + if not any(".base_layer." in k for k in sub_params): + continue + # Build aliases stripping ".base_layer." — fail fast if both the + # wrapped key and its canonical form exist in the same submodule. + sub_aliases = {} + for param_name, param_value in sub_params.items(): + if ".base_layer." not in param_name: + continue + canonical = param_name.replace(".base_layer.", ".") + if canonical in sub_params: + raise ValueError( + f"base_layer alias collision: both '{param_name}' and '{canonical}' " + f"exist in submodule parameters" + ) + sub_aliases[canonical] = param_value + orig = submod.named_parameters + + # _make_aliased is a factory to avoid the classic Python late-binding closure bug. 
+ # Without it, a plain lambda/def inside the loop would capture `orig` and `sub_aliases` + # by reference, so all patches would use the values from the last loop iteration. + def _make_aliased(orig_fn, aliased_dict): + def _aliased(*args, **kwargs): + yield from orig_fn(*args, **kwargs) + yield from aliased_dict.items() + return _aliased + + submod.named_parameters = _make_aliased(orig, sub_aliases) + patches[submod_name] = (submod, orig) + try: + model.load_weights(weights=weights) + finally: + for _, (submod, orig) in patches.items(): + submod.named_parameters = orig def load_states(self): + """Fully wake up this worker: model weights + KV cache. + + Idempotent: each sub-step is guarded by its own flag (weight_loaded, kv_cache_loaded). + Use this instead of reload_model() when LoRA adapters will be registered immediately + after, to avoid a later wake_up(["kv_cache"]) on a near-full GPU (OOM risk). + """ self.reload_model() if not self.kv_cache_loaded: self.wake_up(["kv_cache"]) @@ -86,10 +463,40 @@ def load_states(self): buffer.data.copy_(self.buffers[name].data) self.buffers = None - def offload_states(self, level): + def offload_states(self, level: int): + """Sleep this worker to free GPU memory, evicting LoRA state as part of the teardown. + + level=1: swap model weights to CPU, discard KV cache and LoRA tensors. + level=2: destroy everything (weights, KV cache, LoRA tensors). + + LoRA eviction rationale: + LoRA tensors use the default CuMem tag (not the "weights" tag), so sleep() at either + level discards their GPU memory. However, vLLM's Python-side LRUCacheWorkerLoRAManager + still holds entries pointing at the now-freed GPU memory. On the next add_lora call, + vLLM finds the adapter "in cache" and skips reloading, then accesses the freed memory → + CUDA error or silent corruption. + Fix: always evict stale vLLM adapter registrations here so the next add_lora always + takes the fresh-load path and applies the latest trained LoRA weights. 
+ + Assert invariant: weight_loaded and kv_cache_loaded must be in sync — either both + True (fully awake) or both False (already offloaded). A mixed state indicates a bug. + """ assert (self.weight_loaded and self.kv_cache_loaded) or (not self.weight_loaded and not self.kv_cache_loaded) if not self.weight_loaded: + logger.info("[vllm][offload] already offloaded, skip (level=%s)", level) + # Safety-net cleanup: staged tensors survive only if staging happened but + # custom_add_lora was never called (e.g. error mid-cycle, aborted training step). + # On the normal path, build_request() already transferred ownership to a local + # lora_request that goes out of scope after add_lora() returns, freeing the + # tensors then. This block handles the abnormal path to prevent GPU leaks. + if getattr(self, "tensor_lora_manager", None) is not None and self.tensor_lora_manager.lora_params: + staged_count = len(self.tensor_lora_manager.lora_params) + self.tensor_lora_manager.lora_params = OrderedDict() + logger.info("[vllm][offload] cleared staged LoRA tensors while already-offloaded: count=%s", staged_count) return + # LoRA tensors use the default CuMem tag, not the "weights" tag, so sleep(level=1) discards them too. + _desc = "destroy weights+KV+LoRA" if level == 2 else "swap weights to CPU, discard KV+LoRA" + logger.info("[vllm][offload] sleep(level=%s) start: %s", level, _desc) if vllm.__version__ < "0.8.5" and level == 2: # https://github.com/vllm-project/vllm/issues/16564 model = self.model_runner.model @@ -99,44 +506,273 @@ def offload_states(self, level): self.kv_cache_loaded = False if hasattr(self, "recv_manager"): self.recv_manager.clear() + # Drop staged LoRA tensors so repeated selective-sync cycles do not accumulate GPU buffers. 
+ if getattr(self, "tensor_lora_manager", None) is not None and self.tensor_lora_manager.lora_params: + staged_count = len(self.tensor_lora_manager.lora_params) + self.tensor_lora_manager.lora_params = OrderedDict() + logger.info("[vllm][offload] cleared staged LoRA tensors: count=%s", staged_count) + # LoRA tensors use the default CuMem tag, not the "weights" tag. + # sleep(level=1) therefore discards LoRA GPU memory just like level=2 does. + # vLLM's Python-side LoRA cache (LRUCacheWorkerLoRAManager) still holds entries pointing at + # now-freed GPU memory after either sleep level. On the next add_lora call, vLLM would take the + # else-branch (adapter "in cache") and skip reloading → using freed memory → CUDA error / crash. + # Fix: always evict stale vLLM adapter registrations after any sleep level, so the next add_lora + # always takes the fresh-load path and newly trained LoRA weights are applied every cycle. + if ( + getattr(self, "tensor_lora_manager", None) is not None + and self.tensor_lora_manager._lora_names + ): + lora_manager = getattr(getattr(self, "model_runner", None), "lora_manager", None) + remove_adapter = getattr(lora_manager, "remove_adapter", None) if lora_manager is not None else None + evicted = 0 + if callable(remove_adapter): + for int_id in self.tensor_lora_manager._lora_names.values(): + remove_adapter(int_id) + evicted += 1 + self.tensor_lora_manager._lora_names = {} + logger.info("[vllm][offload] cleared adapter id map and evicted vllm cache: count=%s", evicted) gc.collect() current_platform.empty_cache() + logger.info("[vllm][offload] sleep(level=%s) done: GPU memory %s", level, "fully freed" if level == 2 else "weights on CPU, KV+LoRA discarded") + + def setup_collective_group(self, *args, **kwargs): + """Initialise an NCCL collective group for model-weight broadcasting. + + Supports two call styles: + + 1. comm_plan style (RLix selective model-update): + Keyword args: comm_plan, backend, rank_in_cluster, timeout_s (optional). 
+ Calls get_dist_info_from_comm_plan to resolve which NCCL group this worker + belongs to. If group_rank is None, this worker is not part of the update + group and the call returns immediately (skip — not an error). + Ends with a dummy allreduce barrier to verify NCCL connectivity before any + broadcast, catching misconfigured groups early. - def setup_collective_group(self, master_address, master_port, rank_offset, world_size, group_name, backend): - group_rank = self.rank + rank_offset + 2. Legacy positional style (persistent broadcast group): + Positional args: master_address, master_port, rank_offset, world_size, + group_name, backend. Optional kwarg: timeout_s. + All workers are expected to participate. + + master_port is always cast to int to prevent type mismatch errors in collective init. + """ + # Dynamic comm_plan based group setup (selective model-update style). + if "comm_plan" in kwargs: + comm_plan = kwargs["comm_plan"] + backend = kwargs["backend"] + rank_in_cluster = int(kwargs["rank_in_cluster"]) + timeout_s = kwargs.get("timeout_s", None) + + group_rank, comm_plan_args = get_dist_info_from_comm_plan( + comm_plan, rank_in_cluster=rank_in_cluster, rank_in_worker=int(self.rank) + ) + if group_rank is None: + logger.info( + f"[rlix][vllm][collective] setup_skip " + f"rank_in_cluster={rank_in_cluster} rank_in_worker={int(self.rank)}" + ) + return + + group_name = comm_plan_args["group_name"] + master_address = comm_plan_args["master_addr"] + master_port = comm_plan_args["master_port"] + world_size = int(len(comm_plan_args["tgt_devices"]) + 1) + logger.info( + f"[rlix][vllm][collective] setup_enter group_name={group_name} " + f"rank={group_rank} world_size={world_size} master={master_address}:{master_port} " + f"timeout_s={timeout_s}" + ) + collective.init_collective_group( + world_size, + rank=int(group_rank), + backend=backend, + group_name=group_name, + master_addr=master_address, + master_port=master_port, + timeout_s=timeout_s, + ) + # Dummy 
allreduce barrier: verifies NCCL connectivity immediately after init. + # Detects misconfigured groups (wrong world_size, wrong ranks) before any real broadcast. + collective.allreduce(torch.zeros(1, device=current_platform.device_type), group_name=group_name) + logger.info( + f"[rlix][vllm][collective] setup_exit group_name={group_name} " + f"rank={group_rank} world_size={world_size}" + ) + return + + # Legacy / persistent broadcast group style. + if len(args) < 6: + raise TypeError( + "setup_collective_group expects either comm_plan kwargs or " + "(master_address, master_port, rank_offset, world_size, group_name, backend, timeout_s=?)." + ) + master_address, master_port, rank_offset, world_size, group_name, backend = args[:6] + timeout_s = kwargs.get("timeout_s", None) + group_rank = int(self.rank) + int(rank_offset) + logger.info( + f"[rlix][vllm][collective] setup_enter group_name={group_name} " + f"rank={group_rank} world_size={world_size} master={master_address}:{master_port} " + f"rank_offset={rank_offset} timeout_s={timeout_s}" + ) collective.init_collective_group( - world_size, + int(world_size), rank=group_rank, backend=backend, group_name=group_name, master_addr=master_address, - master_port=master_port, + master_port=int(master_port), + timeout_s=timeout_s, + ) + logger.info( + f"[rlix][vllm][collective] setup_exit group_name={group_name} " + f"rank={group_rank} world_size={world_size}" ) - logger.info(f"setup_collective_group: {group_name} rank: {group_rank} world_size: {world_size}") - def broadcast_parameter(self, names, dtypes, shapes, group_name, is_lora=False): - weights_and_handles = [] - for name, dtype, shape in zip(names, dtypes, shapes): - target_dtype = dtype if isinstance(dtype, torch.dtype) else getattr(torch, dtype) - weight = torch.empty(shape, dtype=target_dtype, device=self.device) - handle = collective.broadcast(tensor=weight, src_rank=0, group_name=group_name, async_op=True) - weights_and_handles.append((name, weight, handle)) + def 
destroy_collective_group(self, group_name: str): + """Tear down an NCCL collective group and release its resources. - def weights_iter(): - for name, weight, handle in weights_and_handles: - handle.wait() - yield name, weight + Call after each model-update cycle completes to free NCCL communicator handles. + A new group will be created on the next setup_collective_group call. + + Guard: partial-overlap IPC local ranks never called setup_collective_group, so + collective.is_group_exist() returns False for them — skip destroy silently to + avoid a KeyError in collective.destroy_collective_group (collective.py:65). + """ + if not collective.is_group_exist(group_name): + logger.info( + f"[rlix][vllm][collective] destroy_skip_not_joined group_name={group_name} rank={self.rank}" + ) + return + logger.info(f"[rlix][vllm][collective] destroy_enter group_name={group_name}") + collective.destroy_collective_group(group_name) + logger.info(f"[rlix][vllm][collective] destroy_exit group_name={group_name}") + + def broadcast_parameter(self, names, dtypes, shapes, group_name, is_lora=False, *, broadcast_local_ranks=None): + """Receive broadcasted tensors from rank 0. Base weights are written to GPU immediately; + LoRA tensors are staged in tensor_lora_manager for later add_lora registration. + + is_lora=False (base model weights): + Overwrites the model's in-GPU weight tensors directly, one at a time via a streaming + generator. reload_model() is called first to ensure the weight memory pool exists, + then each tensor is received and written in-place before the next buffer is allocated. + Peak memory = model_weights + one_tensor_buffer. + + is_lora=True (LoRA adapter weights): + Does NOT write to the model. Received tensors are staged in tensor_lora_manager + and only applied to the vLLM engine later when custom_add_lora is called. + LoRA tensors are small so all receives are issued async in a batch to let NCCL + pipeline the transfers. 
+ """ + # [debug] Stage 1: log GPU memory before any receive buffer is allocated. + # If another process still has model weights loaded, device_used will be much higher + # than the expected idle baseline (~3.5 GiB for 6 idle processes on this test config). + _free_bytes, _total_bytes = torch.cuda.mem_get_info() + _device_used_gb = (_total_bytes - _free_bytes) / 1024**3 + _alloc_gb = torch.cuda.memory_allocated() / 1024**3 + logger.info( + f"[rlix][vllm][broadcast] enter group_name={group_name} " + f"num_tensors={len(names)} is_lora={int(bool(is_lora))} " + f"[debug] device_used={_device_used_gb:.3f}GB allocated={_alloc_gb:.3f}GB " + f"device_total={_total_bytes / 1024**3:.3f}GB" + ) + + # Partial-overlap support: ranks not in the mask never joined the NCCL group; skip early. + if broadcast_local_ranks is not None and self.rank not in broadcast_local_ranks: + return if is_lora: - for name, weight in weights_iter(): + # LoRA tensors are small: keep async batch pattern so NCCL can pipeline transfers. 
+ weights_and_handles = [] + for name, dtype, shape in zip(names, dtypes, shapes): + target_dtype = dtype if isinstance(dtype, torch.dtype) else getattr(torch, dtype) + weight = torch.empty(shape, dtype=target_dtype, device=self.device) + handle = collective.broadcast(tensor=weight, src_rank=0, group_name=group_name, async_op=True) + weights_and_handles.append((name, weight, handle)) + for name, weight, handle in weights_and_handles: + handle.wait() self.tensor_lora_manager.add_weight(name, weight) + logger.info(f"[rlix][vllm][broadcast] exit group_name={group_name} mode=lora") return - self.load_weights(weights=weights_iter()) - def update_parameter_in_bucket(self, serialized_named_tensors, is_lora=False): - monkey_patch_torch_reductions() - bucket_with_meta = MultiprocessingSerializer.deserialize(serialized_named_tensors[self.rank]) - named_params = named_tensors_from_bucket(**bucket_with_meta) + # Base weights: reload model FIRST, then stream one tensor at a time via a generator. + # Peak memory = model_weights + one_tensor_buffer (not model + ALL buffers simultaneously). + # Passing a generator to load_weights means LoRA patch logic runs ONCE for all tensors + # (O(1) named_modules scan), vs calling load_weights 290 times (O(290) scans). + self.reload_model() + + def _streaming_weights_gen(): + # One buffer at a time: allocate → blocking broadcast (wait for data) → yield → + # del _buf before the loop advances to the next tensor. This keeps peak memory at + # model_weights + one_buffer rather than model_weights + all_buffers. + for _name, _dtype, _shape in zip(names, dtypes, shapes): + _target_dtype = _dtype if isinstance(_dtype, torch.dtype) else getattr(torch, _dtype) + _buf = torch.empty(_shape, dtype=_target_dtype, device=self.device) + # Blocking broadcast: receive this tensor before allocating the next buffer. 
+ collective.broadcast(tensor=_buf, src_rank=0, group_name=group_name, async_op=False) + yield _name, _buf + # Each parameter has a different shape (embedding, attention, MLP, bias, ...), + # so the buffer cannot be reused — a new torch.empty is required each iteration. + # del here ensures the old GPU block is returned to the CUDA caching allocator + # before the next torch.empty runs. Without it, both tensors would be alive + # simultaneously at the loop boundary → peak = 2 buffers instead of 1. + del _buf + + # load_weights calls reload_model() internally; no-op since weight_loaded=True after + # the reload_model() call above. + self.load_weights(weights=_streaming_weights_gen()) + + # [debug] Stage 4: all tensors loaded; peak (model + one_buffer) has already passed. + _free4, _total4 = torch.cuda.mem_get_info() + logger.info( + f"[debug][broadcast_load_done] group_name={group_name} " + f"device_used={(_total4 - _free4) / 1024**3:.3f}GB " + f"allocated={torch.cuda.memory_allocated() / 1024**3:.3f}GB" + ) + logger.info(f"[rlix][vllm][broadcast] exit group_name={group_name} mode=weights") + + def update_parameter_in_bucket( + self, serialized_named_tensors, is_lora=False, *, ipc_local_ranks=None, model_update_transport="cuda_ipc" + ): + """Deserialise a packed parameter bucket and apply it to the model or stage for LoRA. + + Counterpart to broadcast_parameter: same base/LoRA split, but tensors arrive + pre-packed in a serialized bucket (CUDA-IPC or CPU-bytes) instead of via NCCL broadcast. + + is_lora=False (base model weights): + Calls load_weights() to overwrite in-GPU parameter values with the unpacked tensors. + No explicit reload_model() here — load_weights() handles that internally. + + is_lora=True (LoRA adapter weights): + Stages each unpacked tensor in tensor_lora_manager.add_weight(), same as + broadcast_parameter's LoRA path. Applied to vLLM later via custom_add_lora. + + The bucket is serialised as {"bucket": , "tensors_meta": ...}. 
+ cpu_serialize uses torch.save/torch.load format; cuda_ipc uses ForkingPickler with + cudaIpcGetMemHandle. The model_update_transport parameter selects the deserializer. + + named_params is materialised with list() because named_tensors_from_bucket returns a + generator and generators can only be consumed once. + """ + # Partial-overlap support: broadcast-only ranks receive weights via NCCL instead; + # returning early here prevents double-application of the same weights. + if ipc_local_ranks is not None and self.rank not in ipc_local_ranks: + return + raw = serialized_named_tensors[self.rank] + if model_update_transport == "cpu_serialize": + # torch.save format: PyTorch storage-aware deserialization. + # weights_only=True is safe — payload contains only {Tensor, dict, torch.dtype}. + bucket_with_meta = torch.load(io.BytesIO(raw), weights_only=True) + else: + # CUDA IPC format: pickle with patched GPU tensor reducers. + monkey_patch_torch_reductions() + bucket_with_meta = pickle.loads(raw) + bucket = bucket_with_meta["bucket"] + if not getattr(bucket, "is_cuda", False): + # Pinned DMA for CPU→GPU: ~8.5x faster than pageable .to() copy + # (319ms vs 2.7s at 3.4GB on PCIe 4.0). + bucket = bucket.contiguous().pin_memory() + bucket = bucket.to(device=self.device, non_blocking=True) + torch.cuda.current_stream().synchronize() + named_params = list(named_tensors_from_bucket(bucket=bucket, tensors_meta=bucket_with_meta["tensors_meta"])) if is_lora: for name, weight in named_params: self.tensor_lora_manager.add_weight(name, weight) @@ -157,13 +793,19 @@ def process_weights_after_loading(self): class WorkerV1(WorkerBase): + """vLLM V1 engine worker variant. + + The only V1-specific behaviour is calling patch_vllm_lora_manager() at init time. + That patch fixes vLLM's LRUCacheWorkerLoRAManager so evicted adapter entries are + properly removed from the Python-side cache, preventing stale-pointer CUDA errors on + the next add_lora call after a sleep cycle. 
+ + All other logic (LoRA registration, weight broadcasting, collective group management, + offload/reload lifecycle) is inherited from WorkerBase. + """ + def custom_init_worker(self, *args, **kwargs): super().custom_init_worker(*args, **kwargs) patch_vllm_lora_manager() - # Use custom prefix because worker_extension_cls can not has - # conflicting method name with vllm worker. - def custom_add_lora(self, peft_config) -> bool: - lora_request = self.tensor_lora_manager.build_request(peft_config) - super().reload_model() - return self.model_runner.add_lora(lora_request) + # custom_add_lora is inherited from WorkerBase so all worker variants share adapter-name logic. diff --git a/roll/utils/checkpoint_manager.py b/roll/utils/checkpoint_manager.py index 92c6d499d..bf2f40fe5 100644 --- a/roll/utils/checkpoint_manager.py +++ b/roll/utils/checkpoint_manager.py @@ -12,7 +12,7 @@ from huggingface_hub import snapshot_download from roll.distributed.scheduler.storage import SharedStorage -from roll.utils.constants import STORAGE_NAME, RAY_NAMESPACE +from roll.utils.constants import GLOBAL_STORAGE_NAMESPACE, STORAGE_NAME from roll.utils.logging import get_logger from roll.utils.network_utils import get_node_ip from roll.utils.upload_utils import uploader_registry @@ -43,7 +43,7 @@ def wrapper(model_name_or_path: str, local_dir: Optional[str] = None): global shared_storage if shared_storage is None: shared_storage = SharedStorage.options( - name=STORAGE_NAME, get_if_exists=True, namespace=RAY_NAMESPACE + name=STORAGE_NAME, get_if_exists=True, namespace=GLOBAL_STORAGE_NAMESPACE ).remote() cached_path = ray.get(shared_storage.get.remote(key=f"{node_ip}:{model_name_or_path}")) if cached_path is None or not os.path.exists(cached_path): diff --git a/roll/utils/collective/collective.py b/roll/utils/collective/collective.py index 78bcd5fcb..20e52e518 100644 --- a/roll/utils/collective/collective.py +++ b/roll/utils/collective/collective.py @@ -1,4 +1,5 @@ -from typing import Union, 
Optional +from datetime import timedelta +from typing import Optional, Union from torch._C._distributed_c10d import ReduceOp from torch.distributed import Backend @@ -15,23 +16,39 @@ class GroupManager: def __init__(self): """ - 基于torch ProcessGroup 实现 + ProcessGroup manager backed by torch.distributed. ref: https://github.com/ray-project/ray/blob/master/python/ray/util/collective/collective.py """ self._name_group_map = {} + # Reverse map: group object → name. Needed for backend inspection via get_group_backend(). self._group_name_map = {} - def create_collective_group(self, backend, world_size, rank, master_addr: str, master_port: int, group_name, global_ranks=None): - self._name_group_map[group_name] = init_custom_process_group( + def create_collective_group( + self, + backend, + world_size, + rank, + master_addr: str, + master_port: int, + group_name, + global_ranks=None, + timeout_s: Optional[float] = None, + ): + # Convert seconds to timedelta; None keeps the PyTorch default (1800s). + # Configurable timeout lets callers tune for slow cross-node initializations. + timeout = None if timeout_s is None else timedelta(seconds=float(timeout_s)) + group = init_custom_process_group( backend=backend, init_method=f"tcp://{master_addr}:{master_port}", + timeout=timeout, world_size=world_size, rank=rank, group_name=group_name, global_ranks=global_ranks ) - - return self._name_group_map[group_name] + self._name_group_map[group_name] = group + self._group_name_map[group] = group_name + return group def is_group_exist(self, group_name): return group_name in self._name_group_map @@ -39,18 +56,22 @@ def is_group_exist(self, group_name): def get_group_by_name(self, group_name): """Get the collective group handle by its name.""" if not self.is_group_exist(group_name): - logger.warning("The group '{}' is not initialized.".format(group_name)) - return None + # Fail fast: returning None here caused silent hangs in downstream collective ops. 
+ raise KeyError("The group '{}' is not initialized.".format(group_name)) return self._name_group_map[group_name] def destroy_collective_group(self, group_name): """Group destructor.""" if not self.is_group_exist(group_name): - logger.warning("The group '{}' does not exist.".format(group_name)) - return + raise KeyError("The group '{}' does not exist.".format(group_name)) # release the collective group resource g = self._name_group_map[group_name] + try: + dist.destroy_process_group(g) + except Exception as e: + # Wrap with group name so callers can identify which group failed. + raise RuntimeError(f"Failed to destroy process group: group_name={group_name}") from e # clean up the dicts del self._group_name_map[g] del self._name_group_map[group_name] @@ -67,6 +88,9 @@ def init_collective_group( backend: Union[str, Backend] = current_platform.communication_backend, group_name: str = "default", global_ranks: Optional[list] = None, + # Per-group timeout (seconds). None uses PyTorch's default (1800s). + # Set explicitly for groups that span slow cross-node links. 
+ timeout_s: Optional[float] = None, ): global _group_mgr if not group_name: @@ -78,7 +102,22 @@ def init_collective_group( assert world_size > 0 assert rank >= 0 assert rank < world_size - _group_mgr.create_collective_group(backend, world_size, rank, master_addr, master_port, group_name, global_ranks=global_ranks) + logger.info( + "[rlix][collective] init_enter " + f"group_name={group_name} backend={backend} rank={rank}/{world_size} master={master_addr}:{master_port} " + f"timeout_s={timeout_s}" + ) + _group_mgr.create_collective_group( + backend, + world_size, + rank, + master_addr, + master_port, + group_name, + global_ranks=global_ranks, + timeout_s=timeout_s, + ) + logger.info(f"[rlix][collective] init_exit group_name={group_name} rank={rank}/{world_size}") def allreduce(tensor, group_name: str = "default", op=ReduceOp.SUM): @@ -103,3 +142,21 @@ def broadcast_object_list(object_list, src=None, group_name="default", device=No assert (src is not None and group_src is None) or (src is None and group_src is not None),\ ("Either src or group_src must be set, but they cannot be set simultaneously.") dist.broadcast_object_list(object_list, src=src, group_src=group_src, group=_group_mgr.get_group_by_name(group_name)) + + +def is_group_exist(group_name: str) -> bool: + """Check if a collective group with the given name exists.""" + global _group_mgr + return _group_mgr.is_group_exist(group_name) + + +def destroy_collective_group(group_name: str) -> None: + global _group_mgr + _group_mgr.destroy_collective_group(group_name) + + +def get_group_backend(group_name: str): + # Expose backend lookup for callers that need to branch on CPU/GPU transport behavior. 
+ global _group_mgr + group = _group_mgr.get_group_by_name(group_name) + return dist.get_backend(group) diff --git a/roll/utils/constants.py b/roll/utils/constants.py index 94e5fb875..697b9b498 100644 --- a/roll/utils/constants.py +++ b/roll/utils/constants.py @@ -1,8 +1,30 @@ import enum +import logging import os -RAY_NAMESPACE = "roll" +# Validate required env vars at import time so that misconfigured Ray workers +# crash immediately with a clear message rather than failing deep in actor init. +_RLIX_CONTROL_PLANE = os.environ.get("RLIX_CONTROL_PLANE", "") +if _RLIX_CONTROL_PLANE == "rlix": + try: + import rlix # noqa: F401 + except ImportError: + raise RuntimeError( + "RLIX_CONTROL_PLANE=rlix requires the 'rlix' package. " + "Either install rlix or set RLIX_CONTROL_PLANE to a different value." + ) +DO_TIME_SHARING = _RLIX_CONTROL_PLANE == "rlix" # True when running under RLix scheduler +if _RLIX_CONTROL_PLANE == "rlix": + ray_namespace = os.environ.get("ROLL_RAY_NAMESPACE") + if not ray_namespace: + raise RuntimeError("RLIX_CONTROL_PLANE=rlix requires ROLL_RAY_NAMESPACE to be set before importing roll.*") + pipeline_id = os.environ.get("PIPELINE_ID") + if not pipeline_id: + raise RuntimeError("RLIX_CONTROL_PLANE=rlix requires PIPELINE_ID to be set before importing roll.*") + +RAY_NAMESPACE = os.environ.get("ROLL_RAY_NAMESPACE", "roll") +GLOBAL_STORAGE_NAMESPACE = "global_storage_namespace" STORAGE_NAME = "SHARED_STORAGE_ACTOR" GENERATE_SCHEDULER_NAME = "GENERATE_SCHEDULER_ACTOR" REWARD_SCHEDULER_NAME = "REWARD_SCHEDULER_ACTOR" @@ -21,6 +43,55 @@ IGNORE_INDEX = -100 +def rlix_env_vars() -> dict[str, str]: + """Env vars for all per-pipeline Ray actor processes in RLix mode. + + Use this when creating child actors from within a pipeline actor; Ray does not reliably + inherit runtime_env env vars from parent actors. + + Only propagates vars that are explicitly set in the environment; no defaults are applied. 
+ Thread/compile-limiting vars (OMP_NUM_THREADS, TORCH_COMPILE_DISABLE, etc.) are included + when set — configure them in the container or orchestrator environment. + """ + if not DO_TIME_SHARING: + return {} + # In RLix mode, roll.* import already validated these exist; keep them explicit here too. + pipeline_id = os.environ.get("PIPELINE_ID") + ray_namespace = os.environ.get("ROLL_RAY_NAMESPACE") + if not pipeline_id: + raise RuntimeError("DO_TIME_SHARING mode requires PIPELINE_ID to be set") + if not ray_namespace: + raise RuntimeError("DO_TIME_SHARING mode requires ROLL_RAY_NAMESPACE to be set") + + env_vars: dict[str, str] = { + "PIPELINE_ID": pipeline_id, + "ROLL_RAY_NAMESPACE": ray_namespace, + "RLIX_CONTROL_PLANE": "rlix", + } + + # Propagate PYTHONPATH if set (for imports when Ray workers start outside repo root). + if pythonpath := os.environ.get("PYTHONPATH"): + env_vars["PYTHONPATH"] = pythonpath + + # Propagate thread/compile-limiting vars only if explicitly set in the environment. + # These cap thread pools and disable TorchInductor subprocess spawning in control-plane actors. 
+ for var in ( + "OMP_NUM_THREADS", "MKL_NUM_THREADS", "OPENBLAS_NUM_THREADS", + "NUMEXPR_NUM_THREADS", "RAY_grpc_server_thread_pool_size", + "RAY_num_server_call_thread", "TORCH_COMPILE_DISABLE", + "TORCHINDUCTOR_COMPILE_THREADS", "TOKENIZERS_PARALLELISM", + ): + if value := os.environ.get(var): + env_vars[var] = value + + logging.getLogger(__name__).info( + "[rlix_env_vars] pid=%d propagating %d env vars", + os.getpid(), + len(env_vars), + ) + return env_vars + + class GenerateStopReason(enum.Enum): FINISH = enum.auto() ABORT = enum.auto() @@ -39,4 +110,4 @@ class EpisodeStopReason(enum.Enum): LLM_GENERATE_FAILED = "llm_generate_failed" UNKNOWN = "unknown" NO_SYSTEM_PROMPT = "no_system_prompt" - EVAL_GT = "eval_gt" \ No newline at end of file + EVAL_GT = "eval_gt" diff --git a/roll/utils/context_managers.py b/roll/utils/context_managers.py index 7fa574f17..832704cf0 100644 --- a/roll/utils/context_managers.py +++ b/roll/utils/context_managers.py @@ -201,7 +201,22 @@ def state_offload_manger(strategy, metrics: Dict, metric_infix: str, is_offload_ metrics[f"time/{metric_infix}/execute"] = execute_timer.last metrics[f"time/{metric_infix}/onload"] = onload_timer.last metrics[f"time/{metric_infix}/offload"] = offload_timer.last - del os.environ["roll_EXEC_FUNC_NAME"] + # Use pop(key, None) so cleanup is safe even when the yield body raises an exception. + os.environ.pop("roll_EXEC_FUNC_NAME", None) + + +@contextmanager +def state_offload_manager(strategy, metrics: Dict, metric_infix: str, is_offload_states=True, load_kwargs={}): + # Compatibility wrapper: upstream historically used a misspelled name. + # TODO(ENG-123): migrate callers to state_offload_manager(...) and remove alias in a future cleanup. 
+ with state_offload_manger( + strategy, + metrics=metrics, + metric_infix=metric_infix, + is_offload_states=is_offload_states, + load_kwargs=load_kwargs, + ): + yield @contextmanager diff --git a/roll/utils/env_action_limiter.py b/roll/utils/env_action_limiter.py index e28d83ae9..6fb36beeb 100644 --- a/roll/utils/env_action_limiter.py +++ b/roll/utils/env_action_limiter.py @@ -1,8 +1,9 @@ import asyncio +import os import time from typing import Dict import ray -from roll.utils.constants import RAY_NAMESPACE +from roll.utils.constants import RAY_NAMESPACE, rlix_env_vars @ray.remote class GlobalLimiter: @@ -67,17 +68,25 @@ class LimiterClient: def __init__(self, tag: str = "default", max_concurrent_calls: int = 10): self.tag = tag + # Scope limiter actors per pipeline so concurrent pipelines don't share rate limits. + self.pipeline_id = os.environ.get("PIPELINE_ID") or "" self.limiter = None self.max_concurrent_calls = max_concurrent_calls self._initialize_limiter() - + def _initialize_limiter(self): """Initialize global rate limiter""" - limiter_name = f"GlobalLimiter_{self.tag}" + if self.pipeline_id: + # Prefix with pipeline_id so each pipeline gets its own Ray actor instance. + limiter_name = f"{self.pipeline_id}_GlobalLimiter_{self.tag}" + else: + limiter_name = f"GlobalLimiter_{self.tag}" self.limiter = GlobalLimiter.options( name=limiter_name, get_if_exists=True, namespace=RAY_NAMESPACE, + # Ray doesn't inherit runtime_env env vars from parent actors; propagate explicitly. 
+ runtime_env={"env_vars": rlix_env_vars()}, ).remote(max_concurrent_calls=self.max_concurrent_calls) def acquire(self) -> str: @@ -117,9 +126,12 @@ def __exit__(self, exc_type, exc_val, exc_tb): def get_global_limiter(tag: str = "default", max_concurrent_calls: int = 10) -> LimiterClient: """Get API rate limiter instance for specified tag""" global _global_limiters - if tag not in _global_limiters: - _global_limiters[tag] = LimiterClient(tag=tag, max_concurrent_calls=max_concurrent_calls) - return _global_limiters[tag] + pipeline_id = os.environ.get("PIPELINE_ID") or "" + # Use pipeline_id:tag as the cache key so each pipeline gets an isolated singleton. + key = f"{pipeline_id}:{tag}" if pipeline_id else tag + if key not in _global_limiters: + _global_limiters[key] = LimiterClient(tag=tag, max_concurrent_calls=max_concurrent_calls) + return _global_limiters[key] def clear_global_limiters(tag: str = None): """Clear limiter instances @@ -136,4 +148,4 @@ def clear_global_limiters(tag: str = None): def get_active_limiter_tags() -> list: """Get list of all active limiter tags""" global _global_limiters - return list(_global_limiters.keys()) \ No newline at end of file + return list(_global_limiters.keys()) diff --git a/roll/utils/functionals.py b/roll/utils/functionals.py index 0cee47c75..99632292c 100644 --- a/roll/utils/functionals.py +++ b/roll/utils/functionals.py @@ -926,10 +926,58 @@ def postprocess_generate( batch["prompt_id"] = prompt_id if logprobs is not None: batch["infer_logprobs"] = logprobs - return DataProto(batch=batch) + meta_info = dict(prompts.meta_info) if prompts.meta_info is not None else {} + + # Propagate non_tensor_batch (e.g., lora_name for multi-LoRA routing) from prompt to output. + # The prompt has N entries; the output has N*num_return_sequences entries. + # Values already at output size are copied as-is; values at prompt size are repeated once per return sequence. 
+ non_tensor_batch = {} + if prompts.non_tensor_batch: + if prompts.batch is None: + raise RuntimeError("postprocess_generate: prompts.batch is None but non_tensor_batch is set; all callers must provide a tensor batch.") + input_batch_size = int(prompts.batch.batch_size[0]) + for key, val in prompts.non_tensor_batch.items(): + if val is None: + continue + if not isinstance(val, np.ndarray): + raise TypeError(f"non_tensor_batch[{key!r}] must be np.ndarray, got {type(val).__name__}") + if len(val) == output_batch_size: + non_tensor_batch[key] = val + elif len(val) == input_batch_size: + # np.repeat groups responses by prompt: repeat(["A","B"], K) → ["A","A","A","B","B","B"]. + # This matches the output tensor layout where rows i*K..(i+1)*K all belong to prompt i. + # np.tile would interleave instead: ["A","B","A","B"] — misaligning lora_name with output rows. + non_tensor_batch[key] = np.repeat(val, int(num_return_sequences)) + else: + raise ValueError( + f"non_tensor_batch[{key!r}] length mismatch: len={len(val)}, expected {input_batch_size} or {output_batch_size}" + ) + return DataProto(batch=batch, non_tensor_batch=non_tensor_batch, meta_info=meta_info) def get_dist_info_from_comm_plan(comm_plan, rank_in_cluster, rank_in_worker): + """Find this worker's NCCL group rank and plan args from a comm_plan. + + comm_plan structure: + {src_rank: {"tgt_devices": [{"rank": cluster_rank, "device": {"rank": worker_rank}, ...}, ...], + "group_name": str, "master_addr": str, "master_port": int, ...}} + Each key is the sender's cluster rank. tgt_devices lists all receiver devices in + the NCCL group, ordered by insertion. The sender is always rank 0 in the group; + receivers are assigned ranks 1..N in tgt_devices order. + + Args: + comm_plan: the full plan dict as described above. + rank_in_cluster: this worker actor's cluster-level rank (Ray actor index). 
+ rank_in_worker: this GPU's local rank within the worker (0 for single-GPU workers, + 0..TP_size-1 for tensor-parallel workers). + + Returns: + (group_rank, comm_plan_args) if this worker is a receiver in any group, + (None, None) if this worker does not appear in any group (skip setup). + + group_rank is 1-indexed: rank 1 is the first receiver, rank 2 the second, etc. + (The sender occupies rank 0 and calls setup_collective_group with rank=0 separately.) + """ for src_rank, comm_plan_args in comm_plan.items(): start_rank = 0 for tgt_device in comm_plan_args["tgt_devices"]: @@ -1286,4 +1334,3 @@ def calculate_workload(seq_len_list): metrics = {} metrics.update(global_balance_stats) return metrics - diff --git a/roll/utils/lora_routing.py b/roll/utils/lora_routing.py new file mode 100644 index 000000000..aef84b2e1 --- /dev/null +++ b/roll/utils/lora_routing.py @@ -0,0 +1,161 @@ +"""LoRA routing utilities for multi-LoRA microbatch dispatch. + +Routing contract +---------------- +Every batch that reaches a multi-LoRA worker must carry a per-sample adapter +name in ``non_tensor_batch["lora_name"]`` as an ``np.ndarray(dtype=object)``. + +- **Multi-adapter producers** (schedulers, env managers) must inject this key + and normalize adapter names via ``normalize_domain()`` before dispatch. +- **Single-adapter producers** may call ``ensure_lora_name_in_batch()`` to + auto-fill the array; it raises if the batch is multi-adapter and the key is + missing. +- **Workers** call ``resolve_microbatch_lora_name()`` to assert homogeneity + (all samples in the microbatch belong to the same adapter) and retrieve the + routing key. +""" +from __future__ import annotations + +import re +from dataclasses import dataclass +from typing import Any, Mapping + +import numpy as np + + +_INVALID_ADAPTER_CHARS = re.compile(r"[^a-z0-9._-]+") +_MULTI_UNDERSCORES = re.compile(r"_+") + + +def normalize_domain(domain: str) -> str: + """Canonicalize an adapter name to a lowercase slug. 
+
+    All routing lookups compare normalized names, so callers that use different
+    capitalizations or separators (e.g. "Math/v2", "math_v2") resolve to the
+    same key. Raises ``ValueError`` on an empty result so a bad name never
+    silently maps to the wrong adapter.
+
+    Examples::
+
+        normalize_domain("Math/v2") -> "math_v2"
+        normalize_domain(" GPT ") -> "gpt"
+        normalize_domain("a--b__c") -> "a--b_c"  # only underscore runs collapse; hyphens kept
+    """
+    # Lowercase and strip surrounding whitespace first.
+    domain = domain.strip().lower()
+    # Replace any char outside [a-z0-9._-] with underscore.
+    domain = _INVALID_ADAPTER_CHARS.sub("_", domain)
+    # Collapse consecutive underscores and remove leading/trailing ones.
+    domain = _MULTI_UNDERSCORES.sub("_", domain).strip("_")
+    if not domain:
+        raise ValueError("normalize_domain() produced an empty adapter name")
+    return domain
+
+
+@dataclass(frozen=True)
+class LoraNameRouting:
+    """Resolved adapter name for one microbatch.
+
+    Keeping both raw and normalized names allows callers to log the original
+    spelling for debugging while using the normalized form for registry lookups.
+    """
+
+    raw_lora_name: str  # name as it appeared in non_tensor_batch["lora_name"]
+    lora_name: str  # normalized slug used for adapter registry lookups
+
+
+def _require_str(val: Any, *, where: str) -> str:
+    # dtype=object arrays can hold arbitrary items (None, bytes, ints); reject
+    # non-str early (numpy.str_ subclasses str, so it passes this check).
+    if not isinstance(val, str):
+        raise TypeError(f"Expected str for {where}, got {type(val)}")
+    return val
+
+
+def get_lora_name_array(non_tensor_batch: Mapping[str, Any]) -> np.ndarray:
+    """Return the per-sample LoRA name array from ``non_tensor_batch["lora_name"]``.
+
+    Raises ``RuntimeError`` if the key is absent, ``TypeError`` if the value is
+    not an ``np.ndarray(dtype=object)``.
+ """ + if "lora_name" not in non_tensor_batch: + raise RuntimeError( + 'Missing `non_tensor_batch["lora_name"]` (required for multi-LoRA routing). ' + f"Available keys={sorted(non_tensor_batch.keys())}" + ) + lora_name = non_tensor_batch["lora_name"] + if not isinstance(lora_name, np.ndarray) or lora_name.dtype != object: + raise TypeError( + f'Expected `non_tensor_batch["lora_name"]` to be np.ndarray(dtype=object), ' + f"got {type(lora_name)} dtype={getattr(lora_name, 'dtype', None)} " + f"shape={getattr(lora_name, 'shape', None)}" + ) + return lora_name + + +def ensure_lora_name_in_batch( + non_tensor_batch: dict, + *, + adapters: Mapping[str, Any] | None, + batch_size: int | None = None, +) -> None: + """Ensure ``non_tensor_batch["lora_name"]`` exists, enforcing single-vs-multi policy. + + - If the key already exists: no-op. + - If ``adapters`` is empty or None: no-op (non-LoRA path). + - If exactly one adapter is configured: auto-fill the array with that adapter's name. + ``batch_size`` is inferred from another batch key when not provided. + - If multiple adapters are configured: raise ``RuntimeError`` — producers must inject + ``lora_name`` explicitly; there is no safe default to choose. + """ + if "lora_name" in non_tensor_batch: + return + if not adapters: + return + if len(adapters) == 1: + only_key = next(iter(adapters.keys())) + # Keep this strict: infer shape or fail so callers fix producer contract early. + if batch_size is None: + if not non_tensor_batch: + raise RuntimeError( + "ensure_lora_name_in_batch: cannot auto-fill lora_name in single-adapter " + "mode with empty non_tensor_batch and no batch_size provided." + ) + batch_size = len(next(iter(non_tensor_batch.values()))) + non_tensor_batch["lora_name"] = np.full(batch_size, only_key, dtype=object) + return + raise RuntimeError( + "Missing non_tensor_batch['lora_name'] in multi-adapter mode. " + f"Configured adapters: {sorted(adapters.keys())}. " + "Producers must inject lora_name." 
+ ) + + +def resolve_microbatch_lora_name(non_tensor_batch: Mapping[str, Any]) -> LoraNameRouting: + """Resolve the adapter name for a homogeneous microbatch. + + Asserts that every sample in the microbatch belongs to the same adapter; + raises ``RuntimeError`` if adapters are mixed or the name is not normalized. + + Workers call this immediately before dispatching to an adapter-specific + forward pass. Producers are responsible for splitting mixed batches before + this point. + """ + lora_arr = get_lora_name_array(non_tensor_batch) + if lora_arr.size == 0: + raise RuntimeError('Empty adapter name array in non_tensor_batch.') + raw_lora_names = [_require_str(d, where='adapter name item') for d in lora_arr.tolist()] + unique = sorted(set(raw_lora_names)) + if len(unique) != 1: + raise RuntimeError(f"Microbatch must be adapter-homogeneous; got adapter names={unique}") + raw_lora_name = unique[0] + normalized = normalize_domain(raw_lora_name) + # Names in the batch must already be normalized so registry lookups are exact. + # Producers (schedulers, env managers) must call normalize_domain() before + # writing lora_name into non_tensor_batch; catch violations here early. + if normalized != raw_lora_name: + raise RuntimeError( + f"Invalid adapter name={raw_lora_name!r}: expected normalized form {normalized!r}. " + "Adapter names must be lowercase alphanumeric with dots, hyphens, or underscores." 
+ ) + return LoraNameRouting(raw_lora_name=raw_lora_name, lora_name=raw_lora_name) diff --git a/roll/utils/network_utils.py b/roll/utils/network_utils.py index a9719f6d5..9b6d3aa87 100644 --- a/roll/utils/network_utils.py +++ b/roll/utils/network_utils.py @@ -2,9 +2,16 @@ def get_node_ip(): - s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - s.connect(("8.8.8.8", 80)) - return s.getsockname()[0] + try: + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + try: + s.connect(("8.8.8.8", 80)) + return s.getsockname()[0] + finally: + s.close() + except PermissionError: + # Some sandboxed environments disallow socket creation; default to loopback. + return "127.0.0.1" def collect_free_port(): diff --git a/roll/utils/rlix_compat.py b/roll/utils/rlix_compat.py new file mode 100644 index 000000000..8cbbf2563 --- /dev/null +++ b/roll/utils/rlix_compat.py @@ -0,0 +1,21 @@ +"""Optional rlix compatibility layer.""" + +try: + from rlix.protocol.types import ( + RLIX_NAMESPACE, + COORDINATOR_ACTOR_NAME_PREFIX, + ROLL_RESOURCE_MANAGER_ACTOR_NAME, + get_pipeline_namespace, + ProgressReport, + ) + RLIX_AVAILABLE = True +except ImportError: + RLIX_AVAILABLE = False + RLIX_NAMESPACE: str = "rlix" + COORDINATOR_ACTOR_NAME_PREFIX: str = "rlix:coordinator:" + ROLL_RESOURCE_MANAGER_ACTOR_NAME: str = "rlix:roll_resource_manager" + + def get_pipeline_namespace(pipeline_id: str) -> str: + return f"pipeline_{pipeline_id}_NS" + + ProgressReport = None # type: ignore[misc,assignment] \ No newline at end of file diff --git a/roll/utils/send_recv_utils.py b/roll/utils/send_recv_utils.py index 6eab849f5..4dcd78ce9 100644 --- a/roll/utils/send_recv_utils.py +++ b/roll/utils/send_recv_utils.py @@ -1,4 +1,7 @@ -from typing import Dict +import io +import math +import pickle +from typing import Dict, Iterable import torch from torch.multiprocessing import reductions @@ -244,7 +247,67 @@ def named_tensors_from_bucket(bucket: "torch.Tensor", tensors_meta: list[dict]) return reconstructed -def 
serialize_named_weights(named_weights: list[tuple[str, torch.Tensor]], infer_strategy: str): +def compute_weight_stats(named_tensors: Iterable[tuple[str, torch.Tensor]]) -> dict[str, float]: + """Accumulate running sum/max/min across all tensors for verification. + + Uses dtype=torch.float32 in .sum() to avoid bf16/fp16 overflow without + creating a full fp32 copy (only a scalar is allocated). max/min do not + overflow so they use the original dtype. + Returns {"sum": float, "max": float, "min": float} or empty dict if no tensors. + """ + running_sum = 0.0 + running_max = float("-inf") + running_min = float("inf") + tensor_count = 0 + for _name, tensor in named_tensors: + # .sum(dtype=float32) accumulates in fp32 without allocating a full copy — only a scalar. + running_sum += tensor.detach().sum(dtype=torch.float32).item() + # max/min are safe in original dtype (no overflow risk for extrema). + running_max = max(running_max, tensor.detach().max().item()) + running_min = min(running_min, tensor.detach().min().item()) + tensor_count += 1 + if tensor_count == 0: + return {} + return {"sum": running_sum, "max": running_max, "min": running_min} + + +def verify_weight_stats( + actual: dict[str, float], + expected: dict[str, float], + label: str, + rel_tol: float = 1e-4, +) -> None: + """Compare actual vs expected weight stats; raise RuntimeError on mismatch.""" + for stat_key in ("sum", "max", "min"): + actual_val = actual.get(stat_key) + expected_val = expected.get(stat_key) + if actual_val is None or expected_val is None: + raise RuntimeError( + f"verify_weight_stats({label}): missing '{stat_key}' — " + f"actual_keys={sorted(actual.keys())} expected_keys={sorted(expected.keys())}" + ) + if not math.isclose(actual_val, expected_val, rel_tol=rel_tol): + raise RuntimeError( + f"verify_weight_stats({label}): {stat_key} mismatch — " + f"actual={actual_val} expected={expected_val} rel_tol={rel_tol}" + ) + + +def serialize_named_weights( + named_weights: list[tuple[str, 
torch.Tensor]], + infer_strategy: str, + model_update_transport: str = "cuda_ipc", +) -> bytes: + """Serialize named weight tensors into bytes for cross-process transfer. + + Args: + named_weights: list of (name, tensor) pairs to serialize. + infer_strategy: inference backend name ("sglang" or "vllm"). + model_update_transport: "cuda_ipc" (default) for CUDA IPC via ForkingPickler, + or "cpu_serialize" for CPU byte serialization via standard pickle. The + cpu_serialize fallback avoids pidfd_getfd errors in restricted containers. + """ + # sglang path — unchanged, always uses ForkingPickler + CUDA IPC. if infer_strategy == "sglang": from sglang.srt.weight_sync.tensor_bucket import FlattenedTensorBucket @@ -266,14 +329,34 @@ def serialize_named_weights(named_weights: list[tuple[str, torch.Tensor]], infer serialized_tensors = MultiprocessingSerializer.serialize(flattened_tensor_data) return serialized_tensors + # vLLM path — transport-dependent serialization. bucket, tensors_meta = _bucket_named_tensors(named_weights) - # PumpkinComment: - # FSDP2 will fail if using CPUOffload Policy without this check - if not getattr(bucket, "is_cuda", False): - bucket = bucket.to(current_platform.device_type).contiguous() - - monkey_patch_torch_reductions() - - serialized_tensors = MultiprocessingSerializer.serialize({"bucket": bucket, "tensors_meta": tensors_meta}) - return serialized_tensors + if model_update_transport == "cpu_serialize": + # CPU serialization fallback for restricted containers where CUDA IPC is + # unavailable. torch.save gives ~1.6x speedup over pickle on large tensors + # (storage-aware raw bytes format vs pickle per-object protocol overhead). + if getattr(bucket, "is_cuda", False): + # Pinned host buffer + non_blocking DMA: ~10x faster than pageable + # .cpu() for GPU→CPU copy (270ms vs 2.7s at 3.4GB on PCIe 4.0). 
+ pinned_bucket = torch.empty_like(bucket, device="cpu").pin_memory() + pinned_bucket.copy_(bucket, non_blocking=True) + torch.cuda.current_stream().synchronize() + bucket = pinned_bucket + bucket = bucket.contiguous() + buf = io.BytesIO() + torch.save({"bucket": bucket, "tensors_meta": tensors_meta}, buf) + return buf.getvalue() + elif model_update_transport == "cuda_ipc": + # CUDA IPC path (default): tensor stays on GPU, serialized via ForkingPickler + # which uses cudaIpcGetMemHandle. Requires CAP_SYS_PTRACE on Linux 5.6+. + if not getattr(bucket, "is_cuda", False): + bucket = bucket.to(current_platform.device_type) + bucket = bucket.contiguous() + monkey_patch_torch_reductions() + return MultiprocessingSerializer.serialize({"bucket": bucket, "tensors_meta": tensors_meta}) + else: + raise ValueError( + f"Unsupported model_update_transport: {model_update_transport!r}. " + f"Expected 'cuda_ipc' or 'cpu_serialize'." + ) diff --git a/roll/utils/tracking.py b/roll/utils/tracking.py index 785eca881..d6a518ebf 100644 --- a/roll/utils/tracking.py +++ b/roll/utils/tracking.py @@ -1,4 +1,5 @@ import json +import os from functools import wraps from typing import Optional, Dict, Any @@ -158,6 +159,40 @@ def finish(self): pass +def create_lora_tracker( + tracker_name: str, + lora_name: str, + config: dict, + **base_kwargs: Any, +) -> BaseTracker: + """Create a tracker instance scoped to a single LoRA adapter. + + Shapes kwargs per-backend so each LoRA gets its own run/log directory: + - TensorBoard: appends lora_name to log_dir + - W&B: appends lora_name to run name + - Swanlab: appends lora_name to experiment_name + - Stdout: no change (stateless) + """ + from copy import deepcopy + kwargs = deepcopy(base_kwargs) + + if tracker_name == "tensorboard": + # Each LoRA adapter gets its own TensorBoard subdirectory. 
+ if "log_dir" in kwargs: + kwargs["log_dir"] = os.path.join(kwargs["log_dir"], lora_name) + elif tracker_name == "wandb": + # Each LoRA adapter gets its own W&B run with a suffixed name. + base_name = kwargs.get("name") + kwargs["name"] = f"{base_name}/{lora_name}" if base_name else lora_name + elif tracker_name == "swanlab": + # Each LoRA adapter gets its own Swanlab experiment. + base_exp = kwargs.get("experiment_name") + kwargs["experiment_name"] = f"{base_exp}/{lora_name}" if base_exp else lora_name + # stdout: no per-lora shaping needed + + return create_tracker(tracker_name=tracker_name, config=config, **kwargs) + + def create_tracker(tracker_name: str, config: dict, **kwargs) -> BaseTracker: if not tracker_name: return BaseTracker() diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/integration/test_isolated_single_lora_step_equivalence.py b/tests/integration/test_isolated_single_lora_step_equivalence.py new file mode 100644 index 000000000..e7b1afd1c --- /dev/null +++ b/tests/integration/test_isolated_single_lora_step_equivalence.py @@ -0,0 +1,881 @@ +""" +Integration tests: isolated single-LoRA step equivalence (sequential clusters). + +Strategy +-------- +Run the two clusters **sequentially** on the *same* GPU set so GPU requirements are +halved compared to running them in parallel. + +Phase 1 — isolated cluster (multi-LoRA, ROLL ported strategy): + - Register all adapters under ``is_lora_optimizer_isolated=True``. + - For each adapter in turn, run ``train_step_lora`` for *n_steps* steps. + - Record the scalar loss returned at every step. + - Teardown. + +Phase 2 — reference clusters (upstream single-LoRA, standard megatron_train): + - For each adapter, create a **fresh** single-adapter cluster on the *same* GPUs. + - Restore the adapter's initial weights (saved before Phase 1). + - Run ``train_step`` for the same *n_steps* steps with the same token tensors. 
+ - Record the scalar loss at every step. + - Teardown. + +Assertion +--------- + loss[adapter][step] from Phase 1 == loss[adapter][step] from Phase 2 + (``torch.testing.assert_close(rtol=1e-5, atol=1e-6)`` on every scalar). + +Test matrix +----------- +| TC | dp | tp | pp | Adapters | GPUs needed | +|----|----|----|----|----------|---------------| +| 1 | 1 | 1 | 1 | a, b | 1 (dp*tp*pp) | +| 2 | 2 | 1 | 1 | a, b, c | 2 (dp*tp*pp) | +| 3 | 1 | 2 | 1 | a, b, c | 2 (dp*tp*pp) | +| 4 | 2 | 2 | 1 | a, b, c | 4 (dp*tp*pp) | +| 5 | 1 | 1 | 2 | a, b, c | 2 (dp*tp*pp) | +| 6 | 1 | 2 | 2 | a, b, c | 4 (dp*tp*pp) | +| 7 | 2 | 1 | 2 | a, b, c | 4 (dp*tp*pp) | + +Determinism contract +-------------------- +For the two-phase sequential design to produce numerically identical losses, ALL +stochastic operations must be either eliminated or seeded identically across phases. +The test enforces this via four mechanisms: + +1. ``lora_dropout=0.0`` + LoRA adapter layers have no dropout, removing the primary source of + LoRA-specific stochasticity. + +2. ``model_config_kwargs={"attention_dropout": 0.0, "hidden_dropout": 0.0}`` + The frozen base model's dropout layers (attention and hidden) affect the + activations that flow back through LoRA parameters. Even though base weights + are frozen, non-zero dropout yields different activation patterns across phases + (because the global RNG state advances during Phase 1 adapter_a training and + is NOT at the same position when Phase 2 reference for adapter_b starts from a + fresh seed). Setting both to 0.0 via ``model_config_kwargs`` eliminates this + dependence on RNG state entirely. + NOTE: Qwen2.5-0.5B-Instruct already ships with attention_dropout=0.0, so this + is defensive rather than corrective for that model, but is required for safety. + +3. 
``is_offload_optimizer_states_in_train_step=False`` in microbatch meta_info + Prevents asynchronous CPU↔GPU optimizer-state offload between steps, which + could introduce timing-dependent numerical differences. + +4. ``pipeline_config.seed=42`` (same for all clusters) + Megatron uses this seed to initialise its per-rank RNG tracker. Both clusters + are seeded identically so any remaining RNG-dependent operation (e.g., Megatron + TP dropout, weight init) starts from the same state. + +Phase 1 dependencies (must be ported into ROLL before tests pass): + - ``MegatronTrainStrategy.train_step_lora`` with ``is_lora_optimizer_isolated=True`` + - ``Worker.train_step_lora`` + - ``Worker.{get_lora_tensors, set_lora_tensors, copy_lora_params}`` +""" +import os +import random +import uuid +from pathlib import Path +from types import SimpleNamespace + +import numpy as np +import pytest +import ray +import torch + +from roll.configs.model_args import LoraArguments, ModelArguments +from roll.configs.training_args import TrainingArguments +from roll.configs.worker_config import StrategyArguments, WorkerConfig +from roll.distributed.executor.cluster import Cluster +from roll.distributed.scheduler.protocol import DataProto +from roll.distributed.scheduler.resource_manager import ResourceManager +from roll.distributed.scheduler.storage import SharedStorage +from roll.utils.constants import RAY_NAMESPACE, STORAGE_NAME + +# Worker name shared between the two phases so loss key extraction is uniform. +_WORKER_NAME = "sft_train" + +# ---- Determinism: zero out ALL base-model dropout (see module docstring §2) ---- +# These kwargs are forwarded to the Hugging Face / Megatron model config so that +# attention softmax dropout and hidden-state FF dropout are disabled for every +# cluster in both phases. This ensures forward-pass activations are deterministic +# regardless of the global PyTorch RNG state. 
+_ZERO_DROPOUT_MODEL_CONFIG_KWARGS: dict = { + "attention_dropout": 0.0, + "hidden_dropout": 0.0, + +} +_LORA_TARGETS = "all-linear,all-router" + + +# --------------------------------------------------------------------------- +# Shared helpers +# --------------------------------------------------------------------------- + +def _unique_cluster_name(prefix: str) -> str: + return f"{prefix}_{uuid.uuid4().hex}" + + +def _ensure_shared_storage() -> None: + try: + SharedStorage.options(name=STORAGE_NAME, get_if_exists=True, namespace=RAY_NAMESPACE).remote() + except Exception: + SharedStorage.options(name=STORAGE_NAME, namespace=RAY_NAMESPACE).remote() + + +def _ray_init() -> None: + if ray.is_initialized(): + ray.shutdown() + ray.init(namespace=RAY_NAMESPACE, ignore_reinit_error=True, log_to_driver=False) + _ensure_shared_storage() + + +def _seed_driver(seed: int = 42) -> None: + """Seed the driver-process RNG. + + Ray worker processes are seeded via ``pipeline_config.seed``; this seeds the + driver-side Python/NumPy/Torch state for any host-side random operations + (e.g. generating token sequences). Call before each cluster creation phase + so both phases start from the same host-side RNG position. 
+ """ + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + + +def _make_pipeline_config(*, seed: int = 42, sequence_length: int = 64) -> SimpleNamespace: + return SimpleNamespace( + seed=seed, + max_grad_norm=1.0, + sequence_length=sequence_length, + resume_from_checkpoint=False, + model_update_buffer_size_mb=256, + is_actor_infer_colocated=False, + ) + + +def _download_model(model_id: str) -> str: + """Download model from Hugging Face and return the local snapshot path.""" + from huggingface_hub import snapshot_download + + return snapshot_download(repo_id=model_id) + + +def _system_envs() -> dict: + root = Path(__file__).resolve().parents[2] + pythonpath = os.pathsep.join([str(root), str(root / "mcore_adapter" / "src")]) + return {"PYTHONPATH": pythonpath} + + +def _isolated_worker_config( + *, + adapter_names: list[str], + model_dir: str, + dp: int, + tp: int, + pp: int = 1, + gradient_accumulation_steps: int = 1, +) -> WorkerConfig: + """WorkerConfig for the isolated multi-LoRA cluster. + + Determinism: + - ``lora_dropout=0.0`` — no randomness in LoRA layers. + - ``model_config_kwargs`` — zeros attention & hidden dropout in the base model + so frozen base-model activations are deterministic regardless of RNG state. 
+ """ + adapters = { + name: LoraArguments(lora_rank=8, lora_alpha=16, lora_dropout=0.0, lora_target=_LORA_TARGETS) + for name in adapter_names + } + return WorkerConfig( + name=_WORKER_NAME, + worker_cls="roll.pipeline.sft.sft_worker.SFTWorker", + model_args=ModelArguments( + model_name_or_path=model_dir, + dtype="bf16", + adapters=adapters, + model_config_kwargs=_ZERO_DROPOUT_MODEL_CONFIG_KWARGS, + ), + training_args=TrainingArguments( + max_steps=999, # effectively unlimited; we drive steps externally + per_device_train_batch_size=1, + gradient_accumulation_steps=gradient_accumulation_steps, + learning_rate=1e-4, + weight_decay=0.0, + ), + strategy_args=StrategyArguments( + strategy_name="megatron_train", + strategy_config={ + "tensor_model_parallel_size": tp, + "pipeline_model_parallel_size": pp, + "expert_model_parallel_size": 1, + "context_parallel_size": 1, + "overlap_p2p_comm": False, + "use_distributed_optimizer": False, # required by isolated prototype + "is_lora_optimizer_isolated": True, + }, + ), + device_mapping=f"list(range(0, {dp * tp * pp}))", + infer_batch_size=1, + system_envs=_system_envs(), + ) + + +def _reference_worker_config( + *, + adapter_name: str, + model_dir: str, + dp: int, + tp: int, + pp: int = 1, + gradient_accumulation_steps: int = 1, +) -> WorkerConfig: + """WorkerConfig for an upstream single-LoRA reference cluster. + + Uses the *same* GPU set as the isolated cluster (sequential execution). + + Determinism: applies the same ``model_config_kwargs`` and ``lora_dropout=0.0`` + as the isolated cluster so both phases are identically dropout-free. 
+ """ + adapters = { + adapter_name: LoraArguments(lora_rank=8, lora_alpha=16, lora_dropout=0.0, lora_target=_LORA_TARGETS) + } + return WorkerConfig( + name=_WORKER_NAME, + worker_cls="roll.pipeline.sft.sft_worker.SFTWorker", + model_args=ModelArguments( + model_name_or_path=model_dir, + dtype="bf16", + adapters=adapters, + model_config_kwargs=_ZERO_DROPOUT_MODEL_CONFIG_KWARGS, + ), + training_args=TrainingArguments( + max_steps=999, + per_device_train_batch_size=1, + gradient_accumulation_steps=gradient_accumulation_steps, + learning_rate=1e-4, + weight_decay=0.0, + ), + strategy_args=StrategyArguments( + strategy_name="megatron_train", + strategy_config={ + "tensor_model_parallel_size": tp, + "pipeline_model_parallel_size": pp, + "expert_model_parallel_size": 1, + "context_parallel_size": 1, + "overlap_p2p_comm": False, + "use_distributed_optimizer": False, + }, + ), + device_mapping=f"list(range(0, {dp * tp * pp}))", + infer_batch_size=1, + system_envs=_system_envs(), + ) + + +def _make_microbatch(input_ids: torch.Tensor, adapter_name: str, global_step: int) -> DataProto: + """Build a single-row DataProto microbatch routed to *adapter_name*. + + Determinism: ``is_offload_optimizer_states_in_train_step=False`` disables the + async CPU↔GPU optimizer-state offload that happens between steps. In + ``isolated`` mode the optimizer states are always kept resident anyway, but + setting this on the reference cluster prevents any timing-dependent numerical + differences from asynchronous offload. 
+ """ + attention_mask = torch.ones_like(input_ids) + labels = input_ids.clone() + mb = DataProto.from_single_dict( + {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels} + ) + mb.non_tensor_batch["lora_name"] = np.array([adapter_name] * input_ids.shape[0], dtype=object) + mb.meta_info = { + "lora_name": adapter_name, + "global_step": global_step, + "_broadcast_non_tensor_batch": True, + # Disable async optimizer-state offload to remove a potential source of + # timing-dependent numerical non-determinism between the two phases. + "is_offload_optimizer_states_in_train_step": False, + "loss_mask_keys": ["labels"], + } + return mb + + +def _extract_loss(result: DataProto) -> float: + """Extract the scalar loss from a train_step / train_step_lora DataProto result. + + Checks both ``{worker_name}/loss`` (upstream convention) and + ``{worker_name}/loss@sum`` (ROLL convention). + """ + metrics: dict = result.meta_info.get("metrics", {}) if result.meta_info else {} + for key in (f"{_WORKER_NAME}/loss", f"{_WORKER_NAME}/loss@sum"): + if key in metrics: + val = metrics[key] + # val may be a tensor or a list of tensors (append_to_dict accumulates into lists) + if isinstance(val, (list, tuple)): + val = val[0] + if isinstance(val, torch.Tensor): + return float(val.mean().item()) + return float(val) + available = list(metrics.keys()) + raise KeyError( + f"Expected loss key '{_WORKER_NAME}/loss' (or '/loss@sum') in metrics but got: {available}. " + "Check that the SFTWorker's loss_func emits the expected key." 
+ ) + + +def _shutdown(cluster: Cluster) -> None: + try: + cluster.execute_all_sync("shutdown") + except Exception: + pass + for worker in getattr(cluster, "workers", []): + try: + ray.kill(worker, no_restart=True) + except Exception: + pass + + +# --------------------------------------------------------------------------- +# Core test logic, shared by all 4 test cases +# --------------------------------------------------------------------------- + +def _run_equivalence_test( + *, + adapter_names: list[str], + dp: int, + tp: int, + pp: int = 1, + model_dir: str, + resource_manager: ResourceManager, + pipeline_config: SimpleNamespace, + n_steps: int = 3, + seed: int = 42, + phase1_order: str = "sequential", +) -> None: + """ + Phase 1: isolated multi-LoRA cluster + ---------------------------------------- + 1. Create cluster (all adapters, ``is_lora_optimizer_isolated=True``). + 2. Seed all adapters with identical initial weights (copy from first). + 3. Save those initial weights for Phase 2 reference clusters. + 4. Train all adapters for *n_steps* steps under one of two orderings + controlled by *phase1_order*: + + ``"sequential"`` — train every step of adapter A before touching adapter B: + + for adapter in adapter_names: + for step in range(n_steps): + train_step_lora(adapter, step) + + ``"interleaved"`` — round-robin across adapters, one step at a time: + + for step in range(n_steps): + for adapter in adapter_names: + train_step_lora(adapter, step) + + Both orderings must produce the *same* per-adapter per-step loss because + ``isolated`` mode isolates each adapter's optimizer state so that one + adapter's step does NOT affect any other adapter's weight or momentum. + 5. Teardown cluster. + + Phase 2: upstream single-LoRA reference clusters (sequential, same GPUs) + -------------------------------------------------------------------------- + 6. For each adapter: + a. Create a **fresh** single-adapter cluster on the *same* GPUs. + b. 
Restore this adapter's initial weights (saved in step 3). + c. Run ``train_step`` for *n_steps* steps with the same token tensors + and the same ``pipeline_config.seed``. + d. Collect per-step loss. + e. Teardown cluster. + + Assertion + --------- + For every (adapter, step) pair: + isolated_loss[adapter][step] == reference_loss[adapter][step]. + + Determinism + ----------- + - ``lora_dropout=0.0`` in both WorkerConfigs. + - ``model_config_kwargs`` forces ``attention_dropout=0.0`` and + ``hidden_dropout=0.0`` so frozen base-model activations are RNG-independent. + - ``is_offload_optimizer_states_in_train_step=False`` in every microbatch. + - Driver-side RNG is reset via ``_seed_driver(seed)`` before both phases. + - Both clusters use the same ``pipeline_config.seed`` (worker-side Megatron RNG). + """ + debug_trace = os.environ.get("RLIX_DEBUG_ISOLATED_LORA", "") not in ("", "0", "false", "False") + + # Fixed token sequences, one per step (different steps → different data, + # making the multi-step comparison more discriminating). + # These are generated with a deterministic formula so they don't depend on + # host-side RNG state (same tensors across phases). + # Replicate batch across dp-ranks so dispatch_dp_mp_dispatch_first can chunk + # the batch evenly (batch_size must be >= dp). Each dp rank receives an + # identical row so the per-rank loss equals the single-rank reference loss. + # Megatron PP with non-interleaved schedule needs >=2 microbatches in practice. + # Keep GA=1 for non-PP tests, and GA=2 for PP tests to avoid PP stalls. 
+ ga_steps = 2 if pp > 1 else 1 + token_width = int(pipeline_config.sequence_length) if pp > 1 else 8 + step_input_ids: list[torch.Tensor] = [ + torch.tensor( + [[((step * 7 + i) % 29) + 1 for i in range(token_width)]] * (dp * ga_steps), + dtype=torch.long, + ) + for step in range(n_steps) + ] + + # ----------------------------------------------------------------------- + # Phase 1: isolated cluster + # Reset driver-side RNG so host-side tensor construction is reproducible. + # ----------------------------------------------------------------------- + _seed_driver(seed) + pa_cfg = _isolated_worker_config( + adapter_names=adapter_names, + model_dir=model_dir, + dp=dp, + tp=tp, + pp=pp, + gradient_accumulation_steps=ga_steps, + ) + pa_cluster = Cluster( + name=_unique_cluster_name("multi_lora_isolated"), + worker_cls=pa_cfg.worker_cls, + resource_manager=resource_manager, + worker_config=pa_cfg, + ) + pa_cluster.initialize(pipeline_config=pipeline_config, blocking=True) + + # Ensure all adapters start from identical weights (copy from first). + first = adapter_names[0] + for other in adapter_names[1:]: + pa_cluster.copy_lora_params(src_adapter=first, dst_adapter=other) + # For non-PP runs, normalize DP rank drift at init. + # PP runs shard LoRA tensors by stage, so rank-0 tensors cannot be broadcast + # to every rank. + if pp == 1: + for name in adapter_names: + pa_cluster.set_lora_tensors(name, pa_cluster.get_lora_tensors(name)[0]) + + init_weights: dict[str, dict[str, torch.Tensor]] | None = None + if pp == 1: + init_weights = { + name: pa_cluster.get_lora_tensors(name)[0] + for name in adapter_names + } + + # Train all adapters for n_steps steps under the requested ordering. + isolated_losses: dict[str, list[float]] = {name: [] for name in adapter_names} + isolated_lora_trace: dict[str, list[dict[str, torch.Tensor]]] = { + name: [] for name in adapter_names + } + + if phase1_order == "sequential": + # All steps for adapter A, then all steps for adapter B, ... 
+ # Mirrors the simplest RLix scheduling policy. + for name in adapter_names: + for step in range(n_steps): + mb = _make_microbatch(step_input_ids[step], name, global_step=step) + result = pa_cluster.train_step_lora(mb) + isolated_losses[name].append(_extract_loss(result)) + if debug_trace: + isolated_lora_trace[name].append(pa_cluster.get_lora_tensors(name)[0]) + + elif phase1_order == "interleaved": + # Round-robin: one step per adapter per outer iteration. + # Verifies that interleaving does NOT corrupt any adapter's loss + # trajectory — the key correctness claim of isolated optimizer + # isolation. Each adapter has its own step counter so global_step + # is per-adapter, matching what the reference cluster sees. + adapter_step: dict[str, int] = {name: 0 for name in adapter_names} + for _outer in range(n_steps): + for name in adapter_names: + s = adapter_step[name] + mb = _make_microbatch(step_input_ids[s], name, global_step=s) + result = pa_cluster.train_step_lora(mb) + isolated_losses[name].append(_extract_loss(result)) + if debug_trace: + isolated_lora_trace[name].append(pa_cluster.get_lora_tensors(name)[0]) + adapter_step[name] += 1 + + else: + raise ValueError( + f"Unknown phase1_order={phase1_order!r}; expected 'sequential' or 'interleaved'" + ) + + _shutdown(pa_cluster) + + # ----------------------------------------------------------------------- + # Phase 2: upstream single-LoRA reference clusters (sequential, same GPUs) + # Reset driver-side RNG to the same state as before Phase 1 so any + # driver-side random ops are identical. 
+ # ----------------------------------------------------------------------- + _seed_driver(seed) + reference_losses: dict[str, list[float]] = {} + reference_lora_trace: dict[str, list[dict[str, torch.Tensor]]] = { + name: [] for name in adapter_names + } + + for name in adapter_names: + ref_cfg = _reference_worker_config( + adapter_name=name, + model_dir=model_dir, + dp=dp, + tp=tp, + pp=pp, + gradient_accumulation_steps=ga_steps, + ) + ref_cluster = Cluster( + name=_unique_cluster_name(f"ref_{name}"), + worker_cls=ref_cfg.worker_cls, + resource_manager=resource_manager, + worker_config=ref_cfg, + ) + ref_cluster.initialize(pipeline_config=pipeline_config, blocking=True) + + # Restore initial weights from Phase 1 so both runs start identically. + # PP runs keep LoRA tensors sharded by stage; this helper applies one + # tensor dict to all ranks, so only restore in non-PP mode. + if init_weights is not None: + ref_cluster.set_lora_tensors(name, init_weights[name]) + + step_losses: list[float] = [] + for step in range(n_steps): + mb = _make_microbatch(step_input_ids[step], name, global_step=step) + result = ref_cluster.train_step(mb) + step_losses.append(_extract_loss(result)) + if debug_trace: + reference_lora_trace[name].append(ref_cluster.get_lora_tensors(name)[0]) + + _shutdown(ref_cluster) + reference_losses[name] = step_losses + + if debug_trace: + # Lightweight diff report to bisect divergence between isolated and reference runs. 
+ for name in adapter_names: + if init_weights is None: + continue + init_tensors = init_weights[name] + for step in range(n_steps): + pa_tensors = isolated_lora_trace[name][step] + ref_tensors = reference_lora_trace[name][step] + max_diff = 0.0 + max_key = None + max_pa_delta = 0.0 + max_ref_delta = 0.0 + for k, pa_v in pa_tensors.items(): + ref_v = ref_tensors.get(k) + if ref_v is None: + raise KeyError(f"[debug] Missing tensor {k!r} in reference trace for {name!r}") + d = (pa_v.float() - ref_v.float()).abs().max().item() + if d > max_diff: + max_diff = d + max_key = k + init_v = init_tensors.get(k) + if init_v is None: + raise KeyError(f"[debug] Missing tensor {k!r} in init trace for {name!r}") + pa_d = (pa_v.float() - init_v.float()).abs().max().item() + ref_d = (ref_v.float() - init_v.float()).abs().max().item() + if pa_d > max_pa_delta: + max_pa_delta = pa_d + if ref_d > max_ref_delta: + max_ref_delta = ref_d + print(f"[debug] adapter={name} step={step} max_lora_param_abs_diff={max_diff:.6e} key={max_key}") + print(f"[debug] adapter={name} step={step} max_abs_delta_vs_init: isolated={max_pa_delta:.6e} reference={max_ref_delta:.6e}") + + # ----------------------------------------------------------------------- + # Assert: isolated loss == reference loss at every (adapter, step) + # ----------------------------------------------------------------------- + for name in adapter_names: + pa_losses = isolated_losses[name] + ref_losses = reference_losses[name] + assert len(pa_losses) == len(ref_losses) == n_steps, ( + f"[adapter={name}] Unexpected step count: pa={len(pa_losses)}, ref={len(ref_losses)}" + ) + for step, (pa_loss, ref_loss) in enumerate(zip(pa_losses, ref_losses)): + pa_t = torch.tensor(pa_loss) + ref_t = torch.tensor(ref_loss) + torch.testing.assert_close( + pa_t, + ref_t, + rtol=1e-5, + atol=1e-6, + msg=( + f"Loss mismatch at adapter={name!r} step={step} " + f"[dp={dp}, tp={tp}, pp={pp}]: " + f"isolated={pa_loss:.8f}, reference={ref_loss:.8f}" + ), + 
) + + +# --------------------------------------------------------------------------- +# TC-1: dp=1, tp=1, adapters=[a, b] — needs 1 GPU +# --------------------------------------------------------------------------- + +@pytest.mark.skipif( + torch.cuda.device_count() < 1, + reason="TC-1 requires >= 1 CUDA device (dp=1, tp=1).", +) +def test_tc1_isolated_single_lora_step_dp1_tp1(): + """ + TC-1 dp=1, tp=1, adapters=[a, b], n_steps=3. + + Exercises both Phase-1 orderings against the same single-LoRA reference: + - ``sequential``: all steps for adapter_a, then all steps for adapter_b. + - ``interleaved``: step 0 → [a, b], step 1 → [a, b], step 2 → [a, b]. + + Both must produce losses matching the reference at every (adapter, step). + GPU budget: 1 (clusters run sequentially on the same GPU). + """ + model_id = "Qwen/Qwen2.5-0.5B-Instruct" + model_dir = _download_model(model_id) + + os.environ.setdefault("roll_RPC_TIMEOUT", "600") + _ray_init() + + dp, tp = 1, 1 + resource_manager = ResourceManager(num_nodes=1, num_gpus_per_node=torch.cuda.device_count()) + pipeline_config = _make_pipeline_config(seed=42, sequence_length=64) + + for order in ("sequential", "interleaved"): + _run_equivalence_test( + adapter_names=["adapter_a", "adapter_b"], + dp=dp, + tp=tp, + model_dir=model_dir, + resource_manager=resource_manager, + pipeline_config=pipeline_config, + n_steps=3, + phase1_order=order, + ) + + +# --------------------------------------------------------------------------- +# TC-2: dp=2, tp=1, adapters=[a, b, c] — needs 2 GPUs +# --------------------------------------------------------------------------- + +@pytest.mark.skipif( + torch.cuda.device_count() < 2, + reason="TC-2 requires >= 2 CUDA devices (dp=2, tp=1).", +) +def test_tc2_isolated_single_lora_step_dp2_tp1(): + """ + TC-2 dp=2, tp=1, adapters=[a, b, c], n_steps=3. + + Exercises both Phase-1 orderings under data parallelism (dp=2). + GPU budget: 2 (clusters run sequentially). 
+ """ + model_id = "Qwen/Qwen2.5-0.5B-Instruct" + model_dir = _download_model(model_id) + + os.environ.setdefault("roll_RPC_TIMEOUT", "600") + _ray_init() + + dp, tp = 2, 1 + resource_manager = ResourceManager(num_nodes=1, num_gpus_per_node=torch.cuda.device_count()) + pipeline_config = _make_pipeline_config(seed=42, sequence_length=64) + + for order in ("sequential", "interleaved"): + _run_equivalence_test( + adapter_names=["adapter_a", "adapter_b", "adapter_c"], + dp=dp, + tp=tp, + model_dir=model_dir, + resource_manager=resource_manager, + pipeline_config=pipeline_config, + n_steps=3, + phase1_order=order, + ) + + +# --------------------------------------------------------------------------- +# TC-3: dp=1, tp=2, adapters=[a, b, c] — needs 2 GPUs +# --------------------------------------------------------------------------- + +@pytest.mark.skipif( + torch.cuda.device_count() < 2, + reason="TC-3 requires >= 2 CUDA devices (dp=1, tp=2).", +) +def test_tc3_isolated_single_lora_step_dp1_tp2(): + """ + TC-3 dp=1, tp=2, adapters=[a, b, c], n_steps=3. + + Exercises both Phase-1 orderings under tensor parallelism (tp=2). + GPU budget: 2 (clusters run sequentially). 
+ """ + model_id = "Qwen/Qwen2.5-0.5B-Instruct" + model_dir = _download_model(model_id) + + os.environ.setdefault("roll_RPC_TIMEOUT", "600") + _ray_init() + + dp, tp = 1, 2 + resource_manager = ResourceManager(num_nodes=1, num_gpus_per_node=torch.cuda.device_count()) + pipeline_config = _make_pipeline_config(seed=42, sequence_length=64) + + for order in ("sequential", "interleaved"): + _run_equivalence_test( + adapter_names=["adapter_a", "adapter_b", "adapter_c"], + dp=dp, + tp=tp, + model_dir=model_dir, + resource_manager=resource_manager, + pipeline_config=pipeline_config, + n_steps=3, + phase1_order=order, + ) + + +# --------------------------------------------------------------------------- +# TC-4: dp=2, tp=2, adapters=[a, b, c] — needs 4 GPUs +# --------------------------------------------------------------------------- + +@pytest.mark.skipif( + torch.cuda.device_count() < 4, + reason="TC-4 requires >= 4 CUDA devices (dp=2, tp=2).", +) +def test_tc4_isolated_single_lora_step_dp2_tp2(): + """ + TC-4 dp=2, tp=2, adapters=[a, b, c], n_steps=3. + + Exercises both Phase-1 orderings under combined data + tensor parallelism. + GPU budget: 4 (clusters run sequentially). 
+ """ + model_id = "Qwen/Qwen2.5-0.5B-Instruct" + model_dir = _download_model(model_id) + + os.environ.setdefault("roll_RPC_TIMEOUT", "600") + _ray_init() + + dp, tp = 2, 2 + resource_manager = ResourceManager(num_nodes=1, num_gpus_per_node=torch.cuda.device_count()) + pipeline_config = _make_pipeline_config(seed=42, sequence_length=64) + + for order in ("sequential", "interleaved"): + _run_equivalence_test( + adapter_names=["adapter_a", "adapter_b", "adapter_c"], + dp=dp, + tp=tp, + pp=1, + model_dir=model_dir, + resource_manager=resource_manager, + pipeline_config=pipeline_config, + n_steps=3, + phase1_order=order, + ) + + +# --------------------------------------------------------------------------- +# TC-5: dp=1, tp=1, pp=2, adapters=[a, b, c] — needs 2 GPUs +# --------------------------------------------------------------------------- + +@pytest.mark.skipif( + torch.cuda.device_count() < 2, + reason="TC-5 requires >= 2 CUDA devices (dp=1, tp=1, pp=2).", +) +def test_tc5_isolated_single_lora_step_dp1_tp1_pp2(): + """ + TC-5 dp=1, tp=1, pp=2, adapters=[a, b, c], n_steps=1. + + Exercises both Phase-1 orderings under pipeline parallelism. + GPU budget: 2 (clusters run sequentially). 
+ """ + model_id = "Qwen/Qwen2.5-0.5B-Instruct" + model_dir = _download_model(model_id) + + os.environ.setdefault("roll_RPC_TIMEOUT", "600") + _ray_init() + + dp, tp, pp = 1, 1, 2 + resource_manager = ResourceManager(num_nodes=1, num_gpus_per_node=torch.cuda.device_count()) + pipeline_config = _make_pipeline_config(seed=42, sequence_length=64) + + for order in ("sequential", "interleaved"): + _run_equivalence_test( + adapter_names=["adapter_a", "adapter_b", "adapter_c"], + dp=dp, + tp=tp, + pp=pp, + model_dir=model_dir, + resource_manager=resource_manager, + pipeline_config=pipeline_config, + n_steps=3, + phase1_order=order, + ) + + +# --------------------------------------------------------------------------- +# TC-6: dp=1, tp=2, pp=2, adapters=[a, b, c] — needs 4 GPUs +# --------------------------------------------------------------------------- + +@pytest.mark.skipif( + torch.cuda.device_count() < 4, + reason="TC-6 requires >= 4 CUDA devices (dp=1, tp=2, pp=2).", +) +def test_tc6_isolated_single_lora_step_dp1_tp2_pp2(): + """ + TC-6 dp=1, tp=2, pp=2, adapters=[a, b, c], n_steps=1. + + Exercises both Phase-1 orderings under combined tensor + pipeline parallelism. + GPU budget: 4 (clusters run sequentially). 
+ """ + model_id = "Qwen/Qwen2.5-0.5B-Instruct" + model_dir = _download_model(model_id) + + os.environ.setdefault("roll_RPC_TIMEOUT", "600") + _ray_init() + + dp, tp, pp = 1, 2, 2 + resource_manager = ResourceManager(num_nodes=1, num_gpus_per_node=torch.cuda.device_count()) + pipeline_config = _make_pipeline_config(seed=42, sequence_length=64) + + for order in ("sequential", "interleaved"): + _run_equivalence_test( + adapter_names=["adapter_a", "adapter_b", "adapter_c"], + dp=dp, + tp=tp, + pp=pp, + model_dir=model_dir, + resource_manager=resource_manager, + pipeline_config=pipeline_config, + n_steps=3, + phase1_order=order, + ) + + +# --------------------------------------------------------------------------- +# TC-7: dp=2, tp=1, pp=2, adapters=[a, b, c] — needs 4 GPUs +# --------------------------------------------------------------------------- + +@pytest.mark.skipif( + torch.cuda.device_count() < 4, + reason="TC-7 requires >= 4 CUDA devices (dp=2, tp=1, pp=2).", +) +def test_tc7_isolated_single_lora_step_dp2_tp1_pp2(): + """ + TC-7 dp=2, tp=1, pp=2, adapters=[a, b, c], n_steps=1. + + Exercises both Phase-1 orderings under combined data + pipeline parallelism. + GPU budget: 4 (clusters run sequentially). 
+ """ + model_id = "Qwen/Qwen2.5-0.5B-Instruct" + model_dir = _download_model(model_id) + + os.environ.setdefault("roll_RPC_TIMEOUT", "600") + _ray_init() + + dp, tp, pp = 2, 1, 2 + resource_manager = ResourceManager(num_nodes=1, num_gpus_per_node=torch.cuda.device_count()) + pipeline_config = _make_pipeline_config(seed=42, sequence_length=64) + + for order in ("sequential", "interleaved"): + _run_equivalence_test( + adapter_names=["adapter_a", "adapter_b", "adapter_c"], + dp=dp, + tp=tp, + pp=pp, + model_dir=model_dir, + resource_manager=resource_manager, + pipeline_config=pipeline_config, + n_steps=3, + phase1_order=order, + ) diff --git a/tests/utils/test_send_recv_utils.py b/tests/utils/test_send_recv_utils.py new file mode 100644 index 000000000..0172f050a --- /dev/null +++ b/tests/utils/test_send_recv_utils.py @@ -0,0 +1,621 @@ +"""Tests for cpu_serialize weight transfer in send_recv_utils. + +Covers bucket pack/unpack, full cpu_serialize serialize/deserialize round-trips, +and end-to-end scenarios with realistic model weights. All tests are CPU-only +since the cpu_serialize path is specifically designed for CPU serialization. + +The benchmark test uses Ray object store (ray.put/ray.get) to faithfully +measure cross-process transport cost, matching the production flow where +serialized bytes travel through Ray's object store (/dev/shm) between +training and inference workers. 
+""" + +import io +import time +from typing import Optional + +import pytest +import ray +import torch + +from roll.utils.send_recv_utils import ( + _bucket_named_tensors, + named_tensors_from_bucket, + serialize_named_weights, +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_named_weights( + shapes: list[tuple[str, tuple[int, ...]]], + dtype: torch.dtype = torch.float32, +) -> list[tuple[str, torch.Tensor]]: + """Create deterministic named tensors via torch.arange for reproducibility.""" + named_weights: list[tuple[str, torch.Tensor]] = [] + for name, shape in shapes: + numel = 1 + for dim in shape: + numel *= dim + # arange in float32 then cast — avoids dtype issues with arange on bf16 + tensor = torch.arange(numel, dtype=torch.float32).reshape(shape).to(dtype) + named_weights.append((name, tensor)) + return named_weights + + +def _deserialize_cpu_serialize( + serialized: bytes, + target_device: str = "cpu", +) -> list[tuple[str, torch.Tensor]]: + """Mirror the vLLM worker deserialization path for cpu_serialize payloads. + + Steps: torch.load -> move bucket to device -> named_tensors_from_bucket. + This matches the logic in worker.py for cpu_serialize transport (torch.save format). 
+ """ + payload = torch.load(io.BytesIO(serialized), weights_only=True) + bucket = payload["bucket"].to(target_device) + tensors_meta = payload["tensors_meta"] + return named_tensors_from_bucket(bucket, tensors_meta) + + +def _assert_named_weights_equal( + actual: list[tuple[str, torch.Tensor]], + expected: list[tuple[str, torch.Tensor]], + msg: Optional[str] = None, +) -> None: + """Assert exact match on name, shape, dtype, and values.""" + prefix = f"{msg}: " if msg else "" + assert len(actual) == len(expected), ( + f"{prefix}length mismatch: {len(actual)} vs {len(expected)}" + ) + for idx, ((actual_name, actual_tensor), (expected_name, expected_tensor)) in enumerate( + zip(actual, expected) + ): + assert actual_name == expected_name, ( + f"{prefix}name mismatch at index {idx}: {actual_name!r} vs {expected_name!r}" + ) + assert actual_tensor.shape == expected_tensor.shape, ( + f"{prefix}shape mismatch for {actual_name}: {actual_tensor.shape} vs {expected_tensor.shape}" + ) + assert actual_tensor.dtype == expected_tensor.dtype, ( + f"{prefix}dtype mismatch for {actual_name}: {actual_tensor.dtype} vs {expected_tensor.dtype}" + ) + assert torch.equal(actual_tensor, expected_tensor), ( + f"{prefix}value mismatch for {actual_name}" + ) + + +# --------------------------------------------------------------------------- +# TestBucketRoundTrip — unit tests for _bucket_named_tensors / named_tensors_from_bucket +# --------------------------------------------------------------------------- + + +class TestBucketRoundTrip: + """Unit tests for bucket pack (_bucket_named_tensors) and unpack (named_tensors_from_bucket).""" + + def test_single_tensor(self) -> None: + """One (4,3) float32 tensor survives bucket round-trip.""" + weights = _make_named_weights([("layer.weight", (4, 3))]) + bucket, meta = _bucket_named_tensors(weights) + reconstructed = named_tensors_from_bucket(bucket, meta) + _assert_named_weights_equal(reconstructed, weights) + + def test_multiple_tensors(self) -> 
None: + """Three tensors with different shapes survive bucket round-trip.""" + shapes = [ + ("layer0.weight", (4, 3)), + ("layer1.bias", (8,)), + ("layer2.weight", (2, 5, 3)), + ] + weights = _make_named_weights(shapes) + bucket, meta = _bucket_named_tensors(weights) + reconstructed = named_tensors_from_bucket(bucket, meta) + _assert_named_weights_equal(reconstructed, weights) + + def test_preserves_dtype_bfloat16(self) -> None: + """bfloat16 dtype is preserved through bucket round-trip.""" + weights = _make_named_weights([("bf16.weight", (4, 3))], dtype=torch.bfloat16) + bucket, meta = _bucket_named_tensors(weights) + reconstructed = named_tensors_from_bucket(bucket, meta) + _assert_named_weights_equal(reconstructed, weights) + + def test_preserves_dtype_float16(self) -> None: + """float16 dtype is preserved through bucket round-trip.""" + weights = _make_named_weights([("fp16.weight", (4, 3))], dtype=torch.float16) + bucket, meta = _bucket_named_tensors(weights) + reconstructed = named_tensors_from_bucket(bucket, meta) + _assert_named_weights_equal(reconstructed, weights) + + def test_empty_raises(self) -> None: + """Empty input raises ValueError.""" + with pytest.raises(ValueError, match="Cannot create empty tensor bucket"): + _bucket_named_tensors([]) + + def test_scalar_shaped_tensor(self) -> None: + """(1,) shaped tensor (edge case) survives bucket round-trip.""" + weights = _make_named_weights([("scalar.param", (1,))]) + bucket, meta = _bucket_named_tensors(weights) + reconstructed = named_tensors_from_bucket(bucket, meta) + _assert_named_weights_equal(reconstructed, weights) + + def test_large_tensor(self) -> None: + """(1024, 512) tensor survives bucket round-trip.""" + weights = _make_named_weights([("large.weight", (1024, 512))]) + bucket, meta = _bucket_named_tensors(weights) + reconstructed = named_tensors_from_bucket(bucket, meta) + _assert_named_weights_equal(reconstructed, weights) + + +# 
--------------------------------------------------------------------------- +# TestCpuSerializeSerialize — unit tests for serialize_named_weights with cpu_serialize +# --------------------------------------------------------------------------- + + +class TestCpuSerializeSerialize: + """Unit tests for full cpu_serialize serialize -> deserialize round-trip.""" + + def test_roundtrip_single_tensor(self) -> None: + """Single tensor survives cpu_serialize serialize/deserialize.""" + weights = _make_named_weights([("layer.weight", (4, 3))]) + serialized = serialize_named_weights(weights, infer_strategy="vllm", model_update_transport="cpu_serialize") + reconstructed = _deserialize_cpu_serialize(serialized) + _assert_named_weights_equal(reconstructed, weights) + + def test_roundtrip_multiple_tensors(self) -> None: + """Multiple tensors survive cpu_serialize serialize/deserialize.""" + shapes = [ + ("model.embed.weight", (16, 8)), + ("model.layer.weight", (8, 8)), + ("model.head.bias", (16,)), + ] + weights = _make_named_weights(shapes) + serialized = serialize_named_weights(weights, infer_strategy="vllm", model_update_transport="cpu_serialize") + reconstructed = _deserialize_cpu_serialize(serialized) + _assert_named_weights_equal(reconstructed, weights) + + def test_roundtrip_bfloat16(self) -> None: + """bfloat16 tensors survive cpu_serialize round-trip.""" + weights = _make_named_weights([("bf16.weight", (4, 3))], dtype=torch.bfloat16) + serialized = serialize_named_weights(weights, infer_strategy="vllm", model_update_transport="cpu_serialize") + reconstructed = _deserialize_cpu_serialize(serialized) + _assert_named_weights_equal(reconstructed, weights) + + def test_roundtrip_float16(self) -> None: + """float16 tensors survive cpu_serialize round-trip.""" + weights = _make_named_weights([("fp16.weight", (4, 3))], dtype=torch.float16) + serialized = serialize_named_weights(weights, infer_strategy="vllm", model_update_transport="cpu_serialize") + reconstructed = 
_deserialize_cpu_serialize(serialized) + _assert_named_weights_equal(reconstructed, weights) + + def test_roundtrip_large_multi_layer(self) -> None: + """4 layers with realistic shapes survive cpu_serialize round-trip.""" + shapes = [ + ("model.layers.0.self_attn.q_proj.weight", (512, 512)), + ("model.layers.0.self_attn.k_proj.weight", (128, 512)), + ("model.layers.0.mlp.gate_proj.weight", (1024, 512)), + ("model.layers.0.mlp.down_proj.weight", (512, 1024)), + ] + weights = _make_named_weights(shapes, dtype=torch.bfloat16) + serialized = serialize_named_weights(weights, infer_strategy="vllm", model_update_transport="cpu_serialize") + reconstructed = _deserialize_cpu_serialize(serialized) + _assert_named_weights_equal(reconstructed, weights) + + def test_payload_is_bytes(self) -> None: + """serialize_named_weights with cpu_serialize returns bytes.""" + weights = _make_named_weights([("layer.weight", (4, 3))]) + serialized = serialize_named_weights(weights, infer_strategy="vllm", model_update_transport="cpu_serialize") + assert isinstance(serialized, bytes) + + def test_payload_contains_cpu_tensor(self) -> None: + """Deserialized bucket tensor resides on CPU.""" + weights = _make_named_weights([("layer.weight", (4, 3))]) + serialized = serialize_named_weights(weights, infer_strategy="vllm", model_update_transport="cpu_serialize") + # cpu_serialize now uses torch.save format + payload = torch.load(io.BytesIO(serialized), weights_only=True) + bucket = payload["bucket"] + assert bucket.device == torch.device("cpu") + + def test_invalid_transport_raises(self) -> None: + """Unknown transport raises ValueError.""" + weights = _make_named_weights([("layer.weight", (4, 3))]) + with pytest.raises(ValueError, match="Unsupported model_update_transport"): + serialize_named_weights(weights, infer_strategy="vllm", model_update_transport="unknown_transport") + + +# --------------------------------------------------------------------------- +# TestCpuSerializeEndToEnd — realistic 
end-to-end scenarios +# --------------------------------------------------------------------------- + + +class TestCpuSerializeEndToEnd: + """End-to-end tests simulating realistic weight transfer scenarios.""" + + def test_multi_rank_independent_payloads(self) -> None: + """Serialize same weights N times (simulating N ranks), deserialize each independently.""" + shapes = [ + ("model.embed.weight", (32, 16)), + ("model.layer.weight", (16, 16)), + ("model.head.weight", (32, 16)), + ] + original_weights = _make_named_weights(shapes, dtype=torch.bfloat16) + num_ranks = 4 + + # Simulate per-rank serialization (each rank gets its own copy) + serialized_list = [ + serialize_named_weights(original_weights, infer_strategy="vllm", model_update_transport="cpu_serialize") + for _rank in range(num_ranks) + ] + + # Each rank deserializes independently and must match original + for rank_idx in range(num_ranks): + reconstructed = _deserialize_cpu_serialize(serialized_list[rank_idx]) + _assert_named_weights_equal(reconstructed, original_weights, msg=f"rank {rank_idx}") + + def test_batched_weight_updates(self) -> None: + """Split weights into batches, serialize/deserialize each, combine and verify.""" + all_shapes = [ + ("model.layers.0.weight", (64, 32)), + ("model.layers.1.weight", (64, 32)), + ("model.layers.2.weight", (64, 32)), + ("model.layers.3.weight", (64, 32)), + ] + all_weights = _make_named_weights(all_shapes, dtype=torch.bfloat16) + + # Split into two batches (simulating buffer-size-bounded transfer) + batch_size = 2 + batches = [all_weights[start:start + batch_size] for start in range(0, len(all_weights), batch_size)] + + # Serialize and deserialize each batch, collect results + combined: list[tuple[str, torch.Tensor]] = [] + for batch in batches: + serialized = serialize_named_weights(batch, infer_strategy="vllm", model_update_transport="cpu_serialize") + reconstructed = _deserialize_cpu_serialize(serialized) + combined.extend(reconstructed) + + # Combined results 
must match original + _assert_named_weights_equal(combined, all_weights) + + def test_deserialized_bucket_is_contiguous(self) -> None: + """Deserialized bucket tensor is contiguous in memory.""" + weights = _make_named_weights([("layer.weight", (32, 16))], dtype=torch.bfloat16) + serialized = serialize_named_weights(weights, infer_strategy="vllm", model_update_transport="cpu_serialize") + # cpu_serialize now uses torch.save format + payload = torch.load(io.BytesIO(serialized), weights_only=True) + bucket = payload["bucket"] + assert bucket.is_contiguous(), "Deserialized bucket must be contiguous" + + def test_lora_adapter_weights(self) -> None: + """LoRA-style weight names with low-rank shapes survive round-trip.""" + # Typical LoRA adapter naming and shapes (rank=8) + lora_rank = 8 + hidden_dim = 512 + shapes = [ + ("base_model.model.layers.0.self_attn.q_proj.lora_A.weight", (lora_rank, hidden_dim)), + ("base_model.model.layers.0.self_attn.q_proj.lora_B.weight", (hidden_dim, lora_rank)), + ("base_model.model.layers.0.self_attn.v_proj.lora_A.weight", (lora_rank, hidden_dim)), + ("base_model.model.layers.0.self_attn.v_proj.lora_B.weight", (hidden_dim, lora_rank)), + ] + weights = _make_named_weights(shapes, dtype=torch.bfloat16) + serialized = serialize_named_weights(weights, infer_strategy="vllm", model_update_transport="cpu_serialize") + reconstructed = _deserialize_cpu_serialize(serialized) + _assert_named_weights_equal(reconstructed, weights, msg="lora adapter") + + @pytest.mark.slow + def test_full_model_state_dict_roundtrip(self) -> None: + """Init Qwen/Qwen2.5-1.5B-Instruct with dummy weights, serialize all, verify exact match. + + Uses from_config (random init) to skip multi-GB download — correctness + test only needs matching shapes/dtypes, not pretrained values. + Skipped if transformers is not installed. 
+ """ + transformers = pytest.importorskip("transformers") + + config = transformers.AutoConfig.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct") + model = transformers.AutoModelForCausalLM.from_config(config).to(dtype=torch.bfloat16) + + # Convert state_dict to list of (name, tensor) pairs + original_weights = list(model.state_dict().items()) + assert len(original_weights) > 0, "Model state dict should not be empty" + + # Serialize via cpu_serialize and deserialize + serialized = serialize_named_weights( + original_weights, infer_strategy="vllm", model_update_transport="cpu_serialize" + ) + reconstructed = _deserialize_cpu_serialize(serialized) + + _assert_named_weights_equal(reconstructed, original_weights, msg="full model state dict") + + +# --------------------------------------------------------------------------- +# Benchmark — compare cuda_ipc vs cpu_serialize end-to-end with real model +# +# Uses Ray object store (ray.put/ray.get) for cross-process transport, +# matching the production flow: serialize_named_weights() → .remote() → +# Ray object store (/dev/shm) → pickle.loads() on inference worker. +# --------------------------------------------------------------------------- + + +@ray.remote +def _ray_cpu_serialize_deserialize(serialized: bytes) -> float: + """Deserialize cpu_serialize payload in a Ray worker — mirrors production worker.py:761-766. + + In production, serialized bytes travel through Ray's object store (/dev/shm) + before reaching the inference worker. Calling this via .remote() replicates + that exact transfer path, including the object store transport cost. 
+ """ + import io + import time + + import torch + + from roll.utils.send_recv_utils import named_tensors_from_bucket + + deserialize_start = time.perf_counter() + # torch.load matches production worker.py deserialization (torch.save format) + payload = torch.load(io.BytesIO(serialized), weights_only=True) + bucket = payload["bucket"] + # Unpack tensors from bucket (same as production worker.py:766) + _ = list(named_tensors_from_bucket(bucket=bucket, tensors_meta=payload["tensors_meta"])) + return time.perf_counter() - deserialize_start + + +@ray.remote(num_gpus=1) +def _ray_cuda_ipc_deserialize(serialized: bytes) -> float: + """Deserialize cuda_ipc payload in a Ray worker — mirrors production worker.py:758-766. + + CUDA IPC handles are cross-process only, so deserialization must happen in a + separate process. Ray worker runs on a different process with its own GPU context. + """ + import time + + from roll.utils.cuda_ipc_utils import MultiprocessingSerializer + from roll.utils.send_recv_utils import monkey_patch_torch_reductions + + # Production path: monkey patch must be applied before deserialization (worker.py:760) + monkey_patch_torch_reductions() + + deserialize_start = time.perf_counter() + payload = MultiprocessingSerializer.deserialize(serialized) + # Force materialization of the GPU tensor + _ = payload["bucket"].shape + return time.perf_counter() - deserialize_start + + +@pytest.mark.slow +def test_benchmark_cuda_ipc_vs_cpu_serialize() -> None: + """Benchmark serialize + transport + deserialize for both transports. + + Both transports use Ray object store for cross-process transfer, matching the + production flow where .remote() implicitly puts the payload into Ray's object + store (/dev/shm shared memory). This captures the transport cost that dominates + cpu_serialize (~1.2 GB bytes blob) vs cuda_ipc (~few KB IPC handles). + + If CUDA is not available, only the cpu_serialize path is benchmarked and a warning is printed. + Requires transformers. 
+ """ + transformers = pytest.importorskip("transformers") + + cuda_available = torch.cuda.is_available() + if not cuda_available: + print("\nWARNING: CUDA not available — only benchmarking cpu_serialize (skipping cuda_ipc)") + + # Use dummy weights (random init from config) to skip multi-GB download. + # Benchmark measures transport performance, not weight correctness. + benchmark_model_name = "Qwen/Qwen2.5-1.5B-Instruct" + config = transformers.AutoConfig.from_pretrained(benchmark_model_name) + model = transformers.AutoModelForCausalLM.from_config(config).to(dtype=torch.bfloat16) + + # Extract state dict and free model to reclaim memory for bucket allocation + weights_cpu = list(model.state_dict().items()) + del model + if cuda_available: + torch.cuda.empty_cache() + + # Ray must be initialized for cross-process transport via object store + num_gpus = torch.cuda.device_count() if cuda_available else 0 + ray.init(num_cpus=2, num_gpus=num_gpus, ignore_reinit_error=True) + try: + _run_benchmark(weights_cpu, cuda_available=cuda_available, model_name=benchmark_model_name) + finally: + ray.shutdown() + + +def _run_benchmark( + weights_cpu: list[tuple[str, torch.Tensor]], *, cuda_available: bool, model_name: str +) -> None: + """Run the actual benchmark after Ray is initialized. + + Separated from the test function to keep ray.init/shutdown in a clean try/finally. + When cuda_available=False, only the cpu_serialize path is benchmarked. + + Args: + weights_cpu: Pre-extracted state dict entries on CPU. Model must be deleted + before calling this to free GPU memory for bucket allocation. + cuda_available: Whether CUDA is available for cuda_ipc benchmarking. + model_name: Model identifier for benchmark output header. 
+ """ + total_bytes = sum(tensor.numel() * tensor.element_size() for _, tensor in weights_cpu) + total_mb = total_bytes / (1024 * 1024) + + NUM_WARMUP_ROUNDS = 2 + NUM_BENCHMARK_ROUNDS = 5 + + # Move weights to GPU if available — production path starts with GPU tensors. + # Both cpu_serialize and cuda_ipc serialize from GPU weights in production. + weights_gpu: list[tuple[str, torch.Tensor]] | None = None + if cuda_available: + try: + weights_gpu = [(name, tensor.cuda()) for name, tensor in weights_cpu] + except torch.cuda.OutOfMemoryError: + print( + "\nWARNING: Not enough GPU memory to hold weights — " + "benchmarking cpu_serialize with CPU weights only (pinned memory path not exercised)" + ) + + # ------ cpu_serialize: serialize (pinned GPU→CPU + torch.save) + transport via Ray + deserialize ------ + # Production: weights on GPU → serialize_named_weights does pinned GPU→CPU + torch.save. + # Falls back to CPU weights if GPU memory is insufficient. + cpu_serialize_weights = weights_gpu if weights_gpu is not None else weights_cpu + + # Warmup: run full serialize + ray.put + remote deserialize cycle + for _warmup in range(NUM_WARMUP_ROUNDS): + serialized = serialize_named_weights( + cpu_serialize_weights, infer_strategy="vllm", model_update_transport="cpu_serialize" + ) + serialized_ref = ray.put(serialized) + ray.get(_ray_cpu_serialize_deserialize.remote(serialized_ref)) + + cpu_serialize_serialize_times: list[float] = [] + cpu_serialize_transport_times: list[float] = [] + cpu_serialize_deserialize_times: list[float] = [] + cpu_serialize_payload_bytes = 0 + for _round in range(NUM_BENCHMARK_ROUNDS): + if cuda_available: + torch.cuda.synchronize() + + # Serialize: pinned GPU→CPU copy + torch.save (or torch.save only if CPU weights) + start = time.perf_counter() + serialized = serialize_named_weights( + cpu_serialize_weights, infer_strategy="vllm", model_update_transport="cpu_serialize" + ) + cpu_serialize_serialize_times.append(time.perf_counter() - start) + 
cpu_serialize_payload_bytes = len(serialized) + + # Transport: put bytes into Ray object store (/dev/shm) + start = time.perf_counter() + serialized_ref = ray.put(serialized) + cpu_serialize_transport_times.append(time.perf_counter() - start) + + # Deserialize: Ray worker receives from object store + torch.load + deserialize_elapsed = ray.get(_ray_cpu_serialize_deserialize.remote(serialized_ref)) + cpu_serialize_deserialize_times.append(deserialize_elapsed) + + # ------ cuda_ipc: requires GPU weights (already moved above) ------ + cuda_ipc_serialize_times: list[float] = [] + cuda_ipc_transport_times: list[float] = [] + cuda_ipc_deserialize_times: list[float] = [] + cuda_ipc_payload_bytes = 0 + + if weights_gpu is not None: + from roll.utils.send_recv_utils import monkey_patch_torch_reductions + + monkey_patch_torch_reductions() + + # Warmup: serialize + transport + remote deserialize + for _warmup in range(NUM_WARMUP_ROUNDS): + serialized = serialize_named_weights( + weights_gpu, infer_strategy="vllm", model_update_transport="cuda_ipc" + ) + serialized_ref = ray.put(serialized) + ray.get(_ray_cuda_ipc_deserialize.remote(serialized_ref)) + + for _round in range(NUM_BENCHMARK_ROUNDS): + torch.cuda.synchronize() + + # Serialize: ForkingPickler with cudaIpcGetMemHandle + start = time.perf_counter() + serialized = serialize_named_weights( + weights_gpu, infer_strategy="vllm", model_update_transport="cuda_ipc" + ) + cuda_ipc_serialize_times.append(time.perf_counter() - start) + cuda_ipc_payload_bytes = len(serialized) + + # Transport: put serialized IPC handles into Ray object store + start = time.perf_counter() + serialized_ref = ray.put(serialized) + cuda_ipc_transport_times.append(time.perf_counter() - start) + + # Deserialize: Ray worker (with GPU) receives and reconstructs via IPC handles + deserialize_elapsed = ray.get(_ray_cuda_ipc_deserialize.remote(serialized_ref)) + cuda_ipc_deserialize_times.append(deserialize_elapsed) + + # ------ Print results ------ + 
_print_benchmark_results( + model_name=model_name, + num_weights=len(weights_cpu), + total_mb=total_mb, + num_benchmark_rounds=NUM_BENCHMARK_ROUNDS, + num_warmup_rounds=NUM_WARMUP_ROUNDS, + cpu_serialize_serialize_times=cpu_serialize_serialize_times, + cpu_serialize_transport_times=cpu_serialize_transport_times, + cpu_serialize_deserialize_times=cpu_serialize_deserialize_times, + cpu_serialize_payload_bytes=cpu_serialize_payload_bytes, + cuda_ipc_serialize_times=cuda_ipc_serialize_times, + cuda_ipc_transport_times=cuda_ipc_transport_times, + cuda_ipc_deserialize_times=cuda_ipc_deserialize_times, + cuda_ipc_payload_bytes=cuda_ipc_payload_bytes, + ) + + +def _median(values: list[float]) -> float: + """Return the median of a list of floats.""" + sorted_values = sorted(values) + mid = len(sorted_values) // 2 + return sorted_values[mid] + + +def _print_benchmark_results( + *, + model_name: str, + num_weights: int, + total_mb: float, + num_benchmark_rounds: int, + num_warmup_rounds: int, + cpu_serialize_serialize_times: list[float], + cpu_serialize_transport_times: list[float], + cpu_serialize_deserialize_times: list[float], + cpu_serialize_payload_bytes: int, + cuda_ipc_serialize_times: list[float], + cuda_ipc_transport_times: list[float], + cuda_ipc_deserialize_times: list[float], + cuda_ipc_payload_bytes: int, +) -> None: + """Print formatted benchmark comparison of cpu_serialize vs cuda_ipc. + + When cuda_ipc lists are empty (no CUDA available), only cpu_serialize results are printed. 
+ """ + cpu_ser = _median(cpu_serialize_serialize_times) + cpu_trans = _median(cpu_serialize_transport_times) + cpu_de = _median(cpu_serialize_deserialize_times) + cpu_total = cpu_ser + cpu_trans + cpu_de + cpu_payload_mb = cpu_serialize_payload_bytes / (1024 * 1024) + + has_cuda_ipc = len(cuda_ipc_serialize_times) > 0 + + print(f"\n{'=' * 95}") + print(f"Benchmark: {model_name} ({total_mb:.1f} MB, {num_weights} tensors)") + print(f"Rounds: {num_benchmark_rounds} (median), Warmup: {num_warmup_rounds}") + print(f"Transport: Ray object store (ray.put/ray.get), matching production .remote() path") + print(f"{'-' * 95}") + print( + f"{'Transport':<12} {'Payload (MB)':>13} {'Serialize (ms)':>15} " + f"{'Transport (ms)':>15} {'Deserialize (ms)':>17} {'Total (ms)':>11}" + ) + print(f"{'-' * 95}") + print( + f"{'cpu_serialize':<12} {cpu_payload_mb:>13.1f} {cpu_ser * 1000:>15.2f} " + f"{cpu_trans * 1000:>15.2f} {cpu_de * 1000:>17.2f} {cpu_total * 1000:>11.2f}" + ) + + if has_cuda_ipc: + ipc_ser = _median(cuda_ipc_serialize_times) + ipc_trans = _median(cuda_ipc_transport_times) + ipc_de = _median(cuda_ipc_deserialize_times) + ipc_total = ipc_ser + ipc_trans + ipc_de + ipc_payload_mb = cuda_ipc_payload_bytes / (1024 * 1024) + + print( + f"{'cuda_ipc':<12} {ipc_payload_mb:>13.3f} {ipc_ser * 1000:>15.2f} " + f"{ipc_trans * 1000:>15.2f} {ipc_de * 1000:>17.2f} {ipc_total * 1000:>11.2f}" + ) + print(f"{'-' * 95}") + speedup = cpu_total / ipc_total if ipc_total > 0 else float("inf") + print(f"cuda_ipc speedup: {speedup:.2f}x") + else: + print(f"{'-' * 95}") + print("cuda_ipc: SKIPPED (CUDA not available)") + + print(f"{'=' * 95}")