diff --git a/skyrl-train/skyrl_train/entrypoints/main_base.py b/skyrl-train/skyrl_train/entrypoints/main_base.py index ba535ac8f2..2299e88c9e 100644 --- a/skyrl-train/skyrl_train/entrypoints/main_base.py +++ b/skyrl-train/skyrl_train/entrypoints/main_base.py @@ -125,11 +125,14 @@ def get_cfg_as_str(dict_cfg: DictConfig) -> str: def get_tokenizer(self, padding_side="left"): """Initializes a tokenizer for the given model.""" - tokenizer = AutoTokenizer.from_pretrained( - self.cfg.trainer.policy.model.path, - trust_remote_code=True, - use_fast=not self.cfg.trainer.disable_fast_tokenizer, - ) + from skyrl_train.utils.io import io + + with io.local_read_dir(self.cfg.trainer.policy.model.path) as model_path: + tokenizer = AutoTokenizer.from_pretrained( + model_path, + trust_remote_code=True, + use_fast=not self.cfg.trainer.disable_fast_tokenizer, + ) tokenizer.padding_side = padding_side if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token diff --git a/skyrl-train/skyrl_train/inference_engines/sglang/sglang_engine.py b/skyrl-train/skyrl_train/inference_engines/sglang/sglang_engine.py index f20c8ecc19..6c63f019ba 100644 --- a/skyrl-train/skyrl_train/inference_engines/sglang/sglang_engine.py +++ b/skyrl-train/skyrl_train/inference_engines/sglang/sglang_engine.py @@ -247,31 +247,47 @@ class SGLangInferenceEngine(InferenceEngineInterface): """SGLang inference engine that implements InferenceEngineInterface.""" def __init__(self, *args, bundle_indices: Optional[List[int]] = None, **kwargs): - setup_envvars_for_sglang(kwargs, bundle_indices) - - # Store common attributes - self._tp_size = kwargs.get("tp_size", 1) - if self._tp_size > 1: - raise ValueError( - "As of now, we don't support tensor parallel inference engine with SGLang. " - "Please set `inference_engine_tensor_parallel_size` to 1." - ) - self.tokenizer = kwargs.pop("tokenizer", None) - - # Unused kwargs - _ = kwargs.pop("num_gpus", 1) - - # Add custom weight loader - kwargs["custom_weight_loader"] = CUSTOM_WEIGHT_LOADER_PATH - - # Always use token-in-token-out SGLang engine - # NOTE(Charlie): unlike vLLM, SGLang cannot do token-in-token-out and - # token-in-text-out in the same engine config. - kwargs["skip_tokenizer_init"] = True - - # Create the SGLang engine (signal handler issue is now fixed by patching) - self.engine = Engine(**kwargs) - logger.info(f"Created SGLang engine with kwargs: {kwargs}") + from skyrl_train.utils.io import io + + original_model_path = kwargs.get("model_path", "") + + self._cloud_model_ctx = None + if io.is_cloud_path(original_model_path): + self._cloud_model_ctx = io.local_read_dir(original_model_path) + local_model_path = self._cloud_model_ctx.__enter__() + kwargs["model_path"] = local_model_path + + try: + setup_envvars_for_sglang(kwargs, bundle_indices) + + # Store common attributes + self._tp_size = kwargs.get("tp_size", 1) + if self._tp_size > 1: + raise ValueError( + "As of now, we don't support tensor parallel inference engine with SGLang. " + "Please set `inference_engine_tensor_parallel_size` to 1." + ) + self.tokenizer = kwargs.pop("tokenizer", None) + + # Unused kwargs + _ = kwargs.pop("num_gpus", 1) + + # Add custom weight loader + kwargs["custom_weight_loader"] = CUSTOM_WEIGHT_LOADER_PATH + + # Always use token-in-token-out SGLang engine + # NOTE(Charlie): unlike vLLM, SGLang cannot do token-in-token-out and + # token-in-text-out in the same engine config. + kwargs["skip_tokenizer_init"] = True + + # Create the SGLang engine (signal handler issue is now fixed by patching) + self.engine = Engine(**kwargs) + logger.info(f"Created SGLang engine with kwargs: {kwargs}") + finally: + # Clean up temp directory now that model is loaded into GPU memory + if self._cloud_model_ctx is not None: + self._cloud_model_ctx.__exit__(None, None, None) + self._cloud_model_ctx = None # Create weight loader for coordinating weight updates self._weight_loader = SGLangWeightLoader(self.engine, self._tp_size) diff --git a/skyrl-train/skyrl_train/inference_engines/vllm/vllm_engine.py b/skyrl-train/skyrl_train/inference_engines/vllm/vllm_engine.py index 14474eaa8a..ff46aff8f5 100644 --- a/skyrl-train/skyrl_train/inference_engines/vllm/vllm_engine.py +++ b/skyrl-train/skyrl_train/inference_engines/vllm/vllm_engine.py @@ -74,23 +74,41 @@ class BaseVLLMInferenceEngine(InferenceEngineInterface): """Base class containing shared logic between sync and async VLLM engines.""" def __init__(self, *args, bundle_indices: list = None, **kwargs): - setup_envvars_for_vllm(kwargs, bundle_indices) - vllm_v1_disable_multiproc = kwargs.pop("vllm_v1_disable_multiproc", False) - if vllm_v1_disable_multiproc or vllm.__version__ == "0.8.2": - # https://github.com/vllm-project/vllm/blob/effc5d24fae10b29996256eb7a88668ff7941aed/examples/offline_inference/reproduciblity.py#L11 - os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0" - - # Store common attributes - self._tp_size = kwargs.get("tensor_parallel_size", 1) - self._pp_size = kwargs.get("pipeline_parallel_size", 1) - self._dp_size = kwargs.get("data_parallel_size", 1) - self._is_lora = kwargs.get("enable_lora", False) - - # Let subclass create the appropriate engine - self.llm = self._create_engine(*args, **kwargs) - - # Weight loader is created by subclass after engine initialization - self._weight_loader = None + from skyrl_train.utils.io import io + + original_model_path = kwargs.get("model", "") + + self._cloud_model_ctx = None + if io.is_cloud_path(original_model_path): + self._cloud_model_ctx = io.local_read_dir(original_model_path) + local_model_path = self._cloud_model_ctx.__enter__() + kwargs["model"] = local_model_path + if kwargs.get("served_model_name") is None: + kwargs["served_model_name"] = original_model_path + + try: + setup_envvars_for_vllm(kwargs, bundle_indices) + vllm_v1_disable_multiproc = kwargs.pop("vllm_v1_disable_multiproc", False) + if vllm_v1_disable_multiproc or vllm.__version__ == "0.8.2": + # https://github.com/vllm-project/vllm/blob/effc5d24fae10b29996256eb7a88668ff7941aed/examples/offline_inference/reproduciblity.py#L11 + os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0" + + # Store common attributes + self._tp_size = kwargs.get("tensor_parallel_size", 1) + self._pp_size = kwargs.get("pipeline_parallel_size", 1) + self._dp_size = kwargs.get("data_parallel_size", 1) + self._is_lora = kwargs.get("enable_lora", False) + + # Let subclass create the appropriate engine + self.llm = self._create_engine(*args, **kwargs) + + # Weight loader is created by subclass after engine initialization + self._weight_loader = None + finally: + # Clean up temp directory now that model is loaded into GPU memory + if self._cloud_model_ctx is not None: + self._cloud_model_ctx.__exit__(None, None, None) + self._cloud_model_ctx = None def tp_size(self): return self._tp_size diff --git a/skyrl-train/skyrl_train/utils/io/io.py b/skyrl-train/skyrl_train/utils/io/io.py index 62877735ce..1f27662b27 100644 --- a/skyrl-train/skyrl_train/utils/io/io.py +++ b/skyrl-train/skyrl_train/utils/io/io.py @@ -184,7 +184,14 @@ def local_read_dir(input_path: str): # Download everything from cloud path to temp_dir download_directory(input_path, temp_dir) logger.info(f"Downloaded directory contents from {input_path}") - yield temp_dir + # s3fs.get with recursive=True may nest files under a subdirectory + # named after the last path component. If temp_dir contains a single + # subdirectory and no files, yield that subdirectory instead. + entries = os.listdir(temp_dir) + if len(entries) == 1 and os.path.isdir(os.path.join(temp_dir, entries[0])): + yield os.path.join(temp_dir, entries[0]) + else: + yield temp_dir else: # For local paths, use directly (but check it exists) if not exists(input_path): diff --git a/skyrl-train/skyrl_train/workers/fsdp/fsdp_worker.py b/skyrl-train/skyrl_train/workers/fsdp/fsdp_worker.py index a3ade9879f..ed2357bf4e 100644 --- a/skyrl-train/skyrl_train/workers/fsdp/fsdp_worker.py +++ b/skyrl-train/skyrl_train/workers/fsdp/fsdp_worker.py @@ -98,76 +98,79 @@ def backload_to_gpu(self, non_blocking=True, backload_optimizer=True, backload_m self.strategy.backload_to_gpu(self.model, self.optimizer, non_blocking, backload_optimizer, backload_model) def init_model(self, model_path, num_training_steps: int = None): - assert self.cfg.trainer.strategy in ("fsdp", "fsdp2") - strategy = FSDPStrategy( - fsdp_config=self.cfg.trainer.policy.fsdp_config, - optimizer_config=self.cfg.trainer.policy.optimizer_config, - model_config=self.cfg.trainer.policy.model, - fsdp_strategy=self.cfg.trainer.strategy, - seed=self.cfg.trainer.seed, - micro_train_batch_size_per_gpu=self.cfg.trainer.micro_train_batch_size_per_gpu, - num_training_steps=num_training_steps, - ) - strategy.setup_distributed() - self.strategy = strategy + from skyrl_train.utils.io import io + + with io.local_read_dir(model_path) as model_path: + assert self.cfg.trainer.strategy in ("fsdp", "fsdp2") + strategy = FSDPStrategy( + fsdp_config=self.cfg.trainer.policy.fsdp_config, + optimizer_config=self.cfg.trainer.policy.optimizer_config, + model_config=self.cfg.trainer.policy.model, + fsdp_strategy=self.cfg.trainer.strategy, + seed=self.cfg.trainer.seed, + micro_train_batch_size_per_gpu=self.cfg.trainer.micro_train_batch_size_per_gpu, + num_training_steps=num_training_steps, + ) + strategy.setup_distributed() + self.strategy = strategy - self._is_lora = self.cfg.trainer.policy.model.lora.rank > 0 + self._is_lora = self.cfg.trainer.policy.model.lora.rank > 0 - model_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) - init_context = get_init_weight_context_manager( - use_meta_tensor=not model_config.tie_word_embeddings, mesh=self.strategy.device_mesh - ) - with init_context(): - - wrapped_model = HFModelWrapper( - model_path, - use_flash_attention_2=self.cfg.trainer.flash_attn, - # NOTE (sumanthrh): Model initialization should always be in fp32 - # during training - bf16=False, - lora_rank=self.cfg.trainer.policy.model.lora.rank, - lora_alpha=self.cfg.trainer.policy.model.lora.alpha, - lora_dropout=self.cfg.trainer.policy.model.lora.dropout, - lora_init_method=self.cfg.trainer.policy.model.lora.init_method, - target_modules=self.cfg.trainer.policy.model.lora.target_modules, - exclude_modules=self.cfg.trainer.policy.model.lora.exclude_modules, - sequence_parallel_size=self.cfg.trainer.policy.sequence_parallel_size, - use_sample_packing=self.cfg.trainer.use_sample_packing, - use_torch_compile=self.cfg.trainer.policy.use_torch_compile, - rope_scaling=get_rope_scaling_config(self.cfg.trainer), - rope_theta=get_rope_theta_config(self.cfg.trainer), - model_config_kwargs=self.cfg.trainer.policy.model_config_kwargs, + model_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) + init_context = get_init_weight_context_manager( + use_meta_tensor=not model_config.tie_word_embeddings, mesh=self.strategy.device_mesh ) - # in-place patch - self._seq_parallel_monkey_patch(model=wrapped_model.model) - - if self.cfg.trainer.gradient_checkpointing: - wrapped_model.gradient_checkpointing_enable( - gradient_checkpointing_kwargs={ - "use_reentrant": self.cfg.trainer.gradient_checkpointing_use_reentrant - } + with init_context(): + + wrapped_model = HFModelWrapper( + model_path, + use_flash_attention_2=self.cfg.trainer.flash_attn, + # NOTE (sumanthrh): Model initialization should always be in fp32 + # during training + bf16=False, + lora_rank=self.cfg.trainer.policy.model.lora.rank, + lora_alpha=self.cfg.trainer.policy.model.lora.alpha, + lora_dropout=self.cfg.trainer.policy.model.lora.dropout, + lora_init_method=self.cfg.trainer.policy.model.lora.init_method, + target_modules=self.cfg.trainer.policy.model.lora.target_modules, + exclude_modules=self.cfg.trainer.policy.model.lora.exclude_modules, + sequence_parallel_size=self.cfg.trainer.policy.sequence_parallel_size, + use_sample_packing=self.cfg.trainer.use_sample_packing, + use_torch_compile=self.cfg.trainer.policy.use_torch_compile, + rope_scaling=get_rope_scaling_config(self.cfg.trainer), + rope_theta=get_rope_theta_config(self.cfg.trainer), + model_config_kwargs=self.cfg.trainer.policy.model_config_kwargs, ) - - self.model, self.optimizer, self.scheduler = strategy.prepare( - (wrapped_model, None, None), - ) - assert ( - self.optimizer is not None and self.scheduler is not None - ), "FSDP preparation should create optimizer and scheduler" - - # Initialize weight extractor - # TODO(haochen): Now module grouping (in order to support FlashRL) is only enabled for the CUDA IPC - # transfer strategy, we can enable it for other strategies as well. - from skyrl_train.weight_sync import CudaIpcTransferStrategy - - group_by_module = self._transfer_strategy_cls is CudaIpcTransferStrategy - self.weight_extractor = FSDPWeightExtractor( - self.model.model, - group_by_module=group_by_module, - batch_size_threshold_gb=( - self.cfg.generator.weight_transfer_threshold_cuda_ipc_GB if group_by_module else 0.0 - ), - ) + # in-place patch + self._seq_parallel_monkey_patch(model=wrapped_model.model) + + if self.cfg.trainer.gradient_checkpointing: + wrapped_model.gradient_checkpointing_enable( + gradient_checkpointing_kwargs={ + "use_reentrant": self.cfg.trainer.gradient_checkpointing_use_reentrant + } + ) + + self.model, self.optimizer, self.scheduler = strategy.prepare( + (wrapped_model, None, None), + ) + assert ( + self.optimizer is not None and self.scheduler is not None + ), "FSDP preparation should create optimizer and scheduler" + + # Initialize weight extractor + # TODO(haochen): Now module grouping (in order to support FlashRL) is only enabled for the CUDA IPC + # transfer strategy, we can enable it for other strategies as well. + from skyrl_train.weight_sync import CudaIpcTransferStrategy + + group_by_module = self._transfer_strategy_cls is CudaIpcTransferStrategy + self.weight_extractor = FSDPWeightExtractor( + self.model.model, + group_by_module=group_by_module, + batch_size_threshold_gb=( + self.cfg.generator.weight_transfer_threshold_cuda_ipc_GB if group_by_module else 0.0 + ), + ) async def _save_lora_adapters_and_sync(self, peft_model, lora_sync_path, inference_engine_client): """Collect LoRA parameters, save and call inference engine to load.""" @@ -261,55 +264,58 @@ def backload_to_gpu(self, non_blocking=True, backload_optimizer=True, backload_m self.strategy.backload_to_gpu(self.model, self.optimizer, non_blocking, backload_optimizer, backload_model) def init_model(self, model_path, num_training_steps: int = None): - assert self.cfg.trainer.strategy in ("fsdp", "fsdp2") - strategy = FSDPStrategy( - fsdp_config=self.cfg.trainer.critic.fsdp_config, - optimizer_config=self.cfg.trainer.critic.optimizer_config, - fsdp_strategy=self.cfg.trainer.strategy, - seed=self.cfg.trainer.seed, - micro_train_batch_size_per_gpu=self.cfg.trainer.micro_train_batch_size_per_gpu, - num_training_steps=num_training_steps, - ) - strategy.setup_distributed() - self.strategy = strategy - - model_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) - init_context = get_init_weight_context_manager( - use_meta_tensor=not model_config.tie_word_embeddings, mesh=self.strategy.device_mesh - ) - with init_context(): - critic = get_llm_for_sequence_regression( - model_path, - "critic", - use_flash_attention_2=self.cfg.trainer.flash_attn, - # NOTE (sumanthrh): Model initialization should always be in fp32 - # during training - bf16=False, - lora_rank=self.cfg.trainer.critic.model.lora.rank, - lora_alpha=self.cfg.trainer.critic.model.lora.alpha, - lora_dropout=self.cfg.trainer.critic.model.lora.dropout, - target_modules=self.cfg.trainer.critic.model.lora.target_modules, - exclude_modules=self.cfg.trainer.critic.model.lora.exclude_modules, - value_head_prefix=self.cfg.trainer.algorithm.value_head_prefix, - init_value_head=self.cfg.trainer.policy.model.path == self.cfg.trainer.critic.model.path, - sequence_parallel_size=self.cfg.trainer.critic.sequence_parallel_size, - use_sample_packing=self.cfg.trainer.use_sample_packing, - model_config_kwargs=self.cfg.trainer.critic.model_config_kwargs, + from skyrl_train.utils.io import io + + with io.local_read_dir(model_path) as model_path: + assert self.cfg.trainer.strategy in ("fsdp", "fsdp2") + strategy = FSDPStrategy( + fsdp_config=self.cfg.trainer.critic.fsdp_config, + optimizer_config=self.cfg.trainer.critic.optimizer_config, + fsdp_strategy=self.cfg.trainer.strategy, + seed=self.cfg.trainer.seed, + micro_train_batch_size_per_gpu=self.cfg.trainer.micro_train_batch_size_per_gpu, + num_training_steps=num_training_steps, ) - self._seq_parallel_monkey_patch(model=critic, use_parent_class=True) + strategy.setup_distributed() + self.strategy = strategy - if self.cfg.trainer.gradient_checkpointing: - critic.gradient_checkpointing_enable( - gradient_checkpointing_kwargs={ - "use_reentrant": self.cfg.trainer.gradient_checkpointing_use_reentrant - } + model_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) + init_context = get_init_weight_context_manager( + use_meta_tensor=not model_config.tie_word_embeddings, mesh=self.strategy.device_mesh + ) + with init_context(): + critic = get_llm_for_sequence_regression( + model_path, + "critic", + use_flash_attention_2=self.cfg.trainer.flash_attn, + # NOTE (sumanthrh): Model initialization should always be in fp32 + # during training + bf16=False, + lora_rank=self.cfg.trainer.critic.model.lora.rank, + lora_alpha=self.cfg.trainer.critic.model.lora.alpha, + lora_dropout=self.cfg.trainer.critic.model.lora.dropout, + target_modules=self.cfg.trainer.critic.model.lora.target_modules, + exclude_modules=self.cfg.trainer.critic.model.lora.exclude_modules, + value_head_prefix=self.cfg.trainer.algorithm.value_head_prefix, + init_value_head=self.cfg.trainer.policy.model.path == self.cfg.trainer.critic.model.path, + sequence_parallel_size=self.cfg.trainer.critic.sequence_parallel_size, + use_sample_packing=self.cfg.trainer.use_sample_packing, + model_config_kwargs=self.cfg.trainer.critic.model_config_kwargs, ) - - # prepare models/optimizers... - self.model, self.optimizer, self.scheduler = strategy.prepare( - (critic, None, None), - ) - assert self.optimizer is not None + self._seq_parallel_monkey_patch(model=critic, use_parent_class=True) + + if self.cfg.trainer.gradient_checkpointing: + critic.gradient_checkpointing_enable( + gradient_checkpointing_kwargs={ + "use_reentrant": self.cfg.trainer.gradient_checkpointing_use_reentrant + } + ) + + # prepare models/optimizers... + self.model, self.optimizer, self.scheduler = strategy.prepare( + (critic, None, None), + ) + assert self.optimizer is not None def forward( self, @@ -335,36 +341,39 @@ def backload_to_gpu(self, non_blocking=True, **kwargs): self.strategy.backload_to_gpu(self.model, None, non_blocking) def init_model(self, model_path): - assert self.cfg.trainer.strategy in ("fsdp", "fsdp2") - strategy = FSDPStrategy( - fsdp_config=self.cfg.trainer.ref.fsdp_config, - fsdp_strategy=self.cfg.trainer.strategy, - seed=self.cfg.trainer.seed, - micro_train_batch_size_per_gpu=self.cfg.trainer.micro_train_batch_size_per_gpu, - ) - strategy.setup_distributed() - self.strategy = strategy - - model_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) - init_context = get_init_weight_context_manager( - use_meta_tensor=not model_config.tie_word_embeddings, mesh=self.strategy.device_mesh - ) + from skyrl_train.utils.io import io + + with io.local_read_dir(model_path) as model_path: + assert self.cfg.trainer.strategy in ("fsdp", "fsdp2") + strategy = FSDPStrategy( + fsdp_config=self.cfg.trainer.ref.fsdp_config, + fsdp_strategy=self.cfg.trainer.strategy, + seed=self.cfg.trainer.seed, + micro_train_batch_size_per_gpu=self.cfg.trainer.micro_train_batch_size_per_gpu, + ) + strategy.setup_distributed() + self.strategy = strategy - with init_context(): - wrapped_model = HFModelWrapper( - model_path, - use_flash_attention_2=self.cfg.trainer.flash_attn, - bf16=self.cfg.trainer.bf16, - sequence_parallel_size=self.cfg.trainer.ref.sequence_parallel_size, - use_sample_packing=self.cfg.trainer.use_sample_packing, - rope_scaling=get_rope_scaling_config(self.cfg.trainer), - rope_theta=get_rope_theta_config(self.cfg.trainer), - model_config_kwargs=self.cfg.trainer.ref.model_config_kwargs, + model_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) + init_context = get_init_weight_context_manager( + use_meta_tensor=not model_config.tie_word_embeddings, mesh=self.strategy.device_mesh ) - self._seq_parallel_monkey_patch(model=wrapped_model.model) - self.model = strategy.prepare(wrapped_model) - self.model.eval() + with init_context(): + wrapped_model = HFModelWrapper( + model_path, + use_flash_attention_2=self.cfg.trainer.flash_attn, + bf16=self.cfg.trainer.bf16, + sequence_parallel_size=self.cfg.trainer.ref.sequence_parallel_size, + use_sample_packing=self.cfg.trainer.use_sample_packing, + rope_scaling=get_rope_scaling_config(self.cfg.trainer), + rope_theta=get_rope_theta_config(self.cfg.trainer), + model_config_kwargs=self.cfg.trainer.ref.model_config_kwargs, + ) + self._seq_parallel_monkey_patch(model=wrapped_model.model) + + self.model = strategy.prepare(wrapped_model) + self.model.eval() def forward( self, diff --git a/skyrl-train/skyrl_train/workers/megatron/megatron_worker.py b/skyrl-train/skyrl_train/workers/megatron/megatron_worker.py index 60c591d343..a41a950ba8 100644 --- a/skyrl-train/skyrl_train/workers/megatron/megatron_worker.py +++ b/skyrl-train/skyrl_train/workers/megatron/megatron_worker.py @@ -434,30 +434,33 @@ def init_model(self, model_path, num_training_steps: int = 1e9): """ Initialize the model, optimizer, and scheduler for the policy worker. """ - # initialize the bridge and provider objects - self.init_configs( - model_path, - self.cfg.trainer.policy.megatron_config, - self.cfg.trainer.policy.megatron_config.model_config_kwargs, - self.cfg.trainer.policy.megatron_config.transformer_config_kwargs, - bf16=self.cfg.trainer.bf16, - flash_attn=self.cfg.trainer.flash_attn, - ) + from skyrl_train.utils.io import io + + with io.local_read_dir(model_path) as model_path: + # initialize the bridge and provider objects + self.init_configs( + model_path, + self.cfg.trainer.policy.megatron_config, + self.cfg.trainer.policy.megatron_config.model_config_kwargs, + self.cfg.trainer.policy.megatron_config.transformer_config_kwargs, + bf16=self.cfg.trainer.bf16, + flash_attn=self.cfg.trainer.flash_attn, + ) - # wrap with DDP for training - self.actor_module = self.make_megatron_module( - wrap_with_ddp=True, - ddp_config=self.cfg.trainer.policy.megatron_config.ddp_config, - lora_config=self.cfg.trainer.policy.model.lora if self._is_lora else None, - lora_type=self.cfg.trainer.policy.megatron_config.lora_config.lora_type, - bf16=self.cfg.trainer.bf16, - ) + # wrap with DDP for training + self.actor_module = self.make_megatron_module( + wrap_with_ddp=True, + ddp_config=self.cfg.trainer.policy.megatron_config.ddp_config, + lora_config=self.cfg.trainer.policy.model.lora if self._is_lora else None, + lora_type=self.cfg.trainer.policy.megatron_config.lora_config.lora_type, + bf16=self.cfg.trainer.bf16, + ) - if self._local_rank == 0 and not os.path.exists( - model_path - ): # if not local path, try downloading model weights from huggingface - snapshot_download(model_path) # will be no-op if already downloaded - torch.distributed.barrier() + if self._local_rank == 0 and not os.path.exists( + model_path + ): # if not local path, try downloading model weights from huggingface + snapshot_download(model_path) # will be no-op if already downloaded + torch.distributed.barrier() if self._rank == 0: print_model_size(self.actor_module[0]) @@ -719,28 +722,31 @@ def init_model(self, model_path, num_training_steps: int = 1e9): """ Initialize the model for the ref worker. """ - # initialize the bridge and provider objects - self.init_configs( - model_path, - self.cfg.trainer.ref.megatron_config, - self.cfg.trainer.ref.megatron_config.model_config_kwargs, - self.cfg.trainer.ref.megatron_config.transformer_config_kwargs, - bf16=self.cfg.trainer.bf16, - flash_attn=self.cfg.trainer.flash_attn, - ) + from skyrl_train.utils.io import io + + with io.local_read_dir(model_path) as model_path: + # initialize the bridge and provider objects + self.init_configs( + model_path, + self.cfg.trainer.ref.megatron_config, + self.cfg.trainer.ref.megatron_config.model_config_kwargs, + self.cfg.trainer.ref.megatron_config.transformer_config_kwargs, + bf16=self.cfg.trainer.bf16, + flash_attn=self.cfg.trainer.flash_attn, + ) - self.actor_module = self.make_megatron_module( - wrap_with_ddp=False, - ddp_config=None, - bf16=self.cfg.trainer.bf16, - ) + self.actor_module = self.make_megatron_module( + wrap_with_ddp=False, + ddp_config=None, + bf16=self.cfg.trainer.bf16, + ) - # download model weights from huggingface (need to be done for ref worker as well, else errors when colocate_all=False) - if self._local_rank == 0 and not os.path.exists( - model_path - ): # if not local path, try downloading model weights from huggingface - snapshot_download(model_path) # will be no-op if already downloaded - torch.distributed.barrier() + # download model weights from huggingface (need to be done for ref worker as well, else errors when colocate_all=False) + if self._local_rank == 0 and not os.path.exists( + model_path + ): # if not local path, try downloading model weights from huggingface + snapshot_download(model_path) # will be no-op if already downloaded + torch.distributed.barrier() # load weights if self._rank == 0: