Support cloud model paths (s3://, gs://) for fresh training runs#1113
bthecohen wants to merge 2 commits into NovaSky-AI:main from
Conversation
…ning runs Wrap model loading sites with io.local_read_dir() so users can set trainer.policy.model.path to a cloud path. For local paths this is a no-op; for cloud paths it downloads to a temp dir and cleans up after the model is loaded into GPU memory. Sites wrapped: get_tokenizer, FSDP init_model (policy/critic/ref), Megatron init_configs, vLLM engine init, SGLang engine creation, and VLLMServerActor init. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Code Review
This pull request adds support for using cloud storage paths for initial model loading, which is a valuable enhancement for flexible training workflows. However, when combined with the hardcoded trust_remote_code=True setting, this feature introduces a critical Remote Code Execution (RCE) vulnerability if an attacker can control the model path configuration. It is strongly recommended to make trust_remote_code a configurable parameter that defaults to False, requiring explicit opt-in for potentially untrusted code. Furthermore, improvements are needed in sglang_engine.py and vllm_engine.py to prevent potential resource leaks due to manual context manager handling, and clarification is requested for removed logic in megatron_worker.py to ensure no unintended side effects.
skyrl-train/skyrl_train/entrypoints/main_base.py (133)
The application hardcodes trust_remote_code=True when initializing the tokenizer. Combined with the newly added support for cloud storage paths (S3, GCS), this allows for Remote Code Execution (RCE) if an attacker can influence the model path configuration. An attacker could point the application to a malicious remote model containing executable code in its configuration files, which would be downloaded and executed on the training worker. It is recommended to make trust_remote_code a configurable parameter that defaults to False.
skyrl-train/skyrl_train/workers/fsdp/fsdp_worker.py (119)
The application hardcodes trust_remote_code=True when loading model configurations. This is particularly dangerous as this PR enables loading models from arbitrary cloud storage paths (S3, GCS). If an attacker can control the model_path, they can trigger the execution of arbitrary code embedded in the remote model's configuration. This setting should be configurable and default to False to prevent unauthorized code execution.
skyrl-train/skyrl_train/workers/fsdp/fsdp_worker.py (282)
Hardcoded trust_remote_code=True allows for arbitrary code execution when loading models from the provided path. With the addition of cloud path support, this risk is extended to remote sources. This should be made configurable and default to False.
skyrl-train/skyrl_train/workers/fsdp/fsdp_worker.py (357)
Hardcoded trust_remote_code=True allows for arbitrary code execution when loading models from the provided path. With the addition of cloud path support, this risk is extended to remote sources. This should be made configurable and default to False.
skyrl-train/skyrl_train/workers/megatron/megatron_worker.py (431)
The init_model method now supports loading models from cloud storage paths via io.local_read_dir. However, the underlying init_configs method hardcodes trust_remote_code=True when loading the tokenizer and config. This allows a remote attacker who can influence the model_path to achieve Remote Code Execution (RCE) by providing a malicious model. You should avoid hardcoding trust_remote_code=True and instead make it a user-controlled option that defaults to False.
skyrl-train/skyrl_train/workers/megatron/megatron_worker.py (719)
The init_model method now supports loading models from cloud storage paths via io.local_read_dir. However, the underlying init_configs method hardcodes trust_remote_code=True when loading the tokenizer and config. This allows a remote attacker who can influence the model_path to achieve Remote Code Execution (RCE) by providing a malicious model. You should avoid hardcoding trust_remote_code=True and instead make it a user-controlled option that defaults to False.
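A minimal sketch of the recommended pattern. The config field names below (`ModelConfig`, `trust_remote_code`, `disable_fast_tokenizer`) are illustrative assumptions, not the repository's actual schema:

```python
# Sketch: plumb trust_remote_code from user config instead of hardcoding it.
from dataclasses import dataclass


@dataclass
class ModelConfig:
    path: str
    trust_remote_code: bool = False  # safe default: require explicit opt-in


def tokenizer_kwargs(cfg: ModelConfig, disable_fast_tokenizer: bool = False) -> dict:
    """Build kwargs for AutoTokenizer.from_pretrained without hardcoding trust_remote_code."""
    return {
        "trust_remote_code": cfg.trust_remote_code,
        "use_fast": not disable_fast_tokenizer,
    }


# Usage (requires transformers and network access, shown for context only):
# tokenizer = AutoTokenizer.from_pretrained(cfg.path, **tokenizer_kwargs(cfg))
```

With this shape, remote code execution requires the user to set `trust_remote_code: true` explicitly in their training config rather than being enabled silently for every model path.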
skyrl-train/skyrl_train/inference_engines/sglang/sglang_engine.py (254-258)
This manual handling of the context manager's __enter__ and __exit__ methods is not robust. If an exception occurs during the initialization of the Engine (line 283), the __exit__ method on line 288 will not be called, leading to a resource leak (the temporary directory will not be cleaned up).
To ensure cleanup happens reliably, you should wrap the code between __enter__ and __exit__ in a try...finally block.
Example:

```python
self._cloud_model_ctx = None
if io.is_cloud_path(original_model_path):
    self._cloud_model_ctx = io.local_read_dir(original_model_path)
    kwargs["model_path"] = self._cloud_model_ctx.__enter__()
try:
    # ... existing initialization code ...
    self.engine = Engine(**kwargs)
    # ...
finally:
    if self._cloud_model_ctx is not None:
        self._cloud_model_ctx.__exit__(None, None, None)
        self._cloud_model_ctx = None
```

Since this change would affect lines not in the diff, I'm leaving it as a comment, but this is an important fix for resource management.
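An alternative sketch using contextlib.ExitStack, which avoids the manual `__enter__`/`__exit__` bookkeeping entirely. Class and attribute names here are illustrative, not the repository's actual code; the no-op `nullcontext` stands in for `io.local_read_dir` so the example is self-contained:

```python
import contextlib


class EngineWithCloudModel:
    def __init__(self, model_path, create_engine):
        # ExitStack owns the temp-dir context; close() runs cleanup exactly once,
        # including when create_engine raises.
        self._exit_stack = contextlib.ExitStack()
        try:
            # In the real code, local_read_dir(model_path) would be entered here
            # for cloud paths; nullcontext simply yields model_path unchanged.
            resolved = self._exit_stack.enter_context(contextlib.nullcontext(model_path))
            self.engine = create_engine(resolved)
        finally:
            # Temp files are only needed until weights are loaded, so unwind now.
            self._exit_stack.close()
```

This keeps the cleanup logic in one place and composes naturally if more contexts (e.g. additional downloads) need to be entered later.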
skyrl-train/skyrl_train/inference_engines/vllm/vllm_engine.py (81-87)
Similar to the sglang_engine.py, the manual use of __enter__ and __exit__ can lead to resource leaks. If self._create_engine() on line 102 raises an exception, the temporary directory for the cloud model will not be cleaned up.
Please wrap the engine initialization logic in a try...finally block to ensure self._cloud_model_ctx.__exit__() is always called after __enter__().
Example:

```python
self._cloud_model_ctx = None
if io.is_cloud_path(original_model_path):
    self._cloud_model_ctx = io.local_read_dir(original_model_path)
    kwargs["model"] = self._cloud_model_ctx.__enter__()
    # ...
try:
    # ... existing initialization code ...
    self.llm = self._create_engine(*args, **kwargs)
finally:
    if self._cloud_model_ctx is not None:
        self._cloud_model_ctx.__exit__(None, None, None)
        self._cloud_model_ctx = None
```

skyrl-train/skyrl_train/workers/megatron/megatron_worker.py (202-208)
This block of code, which overrides token IDs in the model configuration based on the tokenizer, has been removed. This change seems unrelated to supporting cloud model paths and is not mentioned in the pull request description.
Could you please clarify the reason for this removal? If this logic was necessary to ensure consistency between the tokenizer and the model config, removing it might introduce subtle bugs. If it's now redundant, a brief explanation would be helpful for future maintainers.
…mp dir cleanup

- Restore accidentally removed update_model_config block in MegatronWorker.init_configs that propagates token IDs from the tokenizer and applies user model_config_kwargs
- Wrap vLLM and SGLang engine init in try/finally to ensure cloud model temp directories are cleaned up even if engine creation fails

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
```python
with io.local_read_dir(self.cfg.trainer.policy.model.path) as model_path:
    tokenizer = AutoTokenizer.from_pretrained(
        model_path,
        trust_remote_code=True,
        use_fast=not self.cfg.trainer.disable_fast_tokenizer,
    )
```
🔴 local_read_dir raises FileNotFoundError for HuggingFace model names, breaking non-local/non-cloud paths
Wrapping model loading in local_read_dir breaks HuggingFace Hub model name support (e.g., Qwen/Qwen2.5-1.5B-Instruct). AutoTokenizer.from_pretrained natively resolves HF model names, but now local_read_dir raises FileNotFoundError before the HF method is ever called.
Root Cause and Impact across all affected call sites
local_read_dir at skyrl-train/skyrl_train/utils/io/io.py:196-198 checks exists(input_path) for non-cloud paths, which uses the local filesystem. A HuggingFace model ID like "Qwen/Qwen2.5-1.5B-Instruct" is neither a cloud path nor a local path, so is_cloud_path() returns False, then exists() returns False, and FileNotFoundError is raised.
Before this PR, AutoTokenizer.from_pretrained(self.cfg.trainer.policy.model.path) handled HF model names natively. Now the local_read_dir wrapper intercepts and fails first.
This same regression affects every new call site added by this PR:
- skyrl-train/skyrl_train/workers/fsdp/fsdp_worker.py:103 (Policy), :269 (Critic), :346 (Ref) — AutoConfig.from_pretrained and HFModelWrapper both accept HF names natively
- skyrl-train/skyrl_train/workers/megatron/megatron_worker.py:439 (Policy), :727 (Ref) — especially severe because the existing snapshot_download fallback (which explicitly handles HF model names) is now unreachable: local_read_dir raises before execution reaches it
Impact: Any user specifying a HuggingFace model ID in trainer.policy.model.path (a previously supported workflow) will get a crash at startup.
Prompt for agents
The fix needs to be applied at the local_read_dir function in skyrl-train/skyrl_train/utils/io/io.py (lines 195-198), not at each call site. The else branch currently raises FileNotFoundError for any path that doesn't exist on the local filesystem, but HuggingFace model names (e.g., "Qwen/Qwen2.5-1.5B-Instruct") are valid non-local, non-cloud paths that should be passed through unchanged. The simplest fix is to remove the existence check in the else branch of local_read_dir and just yield input_path directly for non-cloud paths. The downstream libraries (AutoTokenizer, AutoConfig, AutoModelForCausalLM, AutoBridge, etc.) all handle both local paths and HuggingFace model names already. Alternatively, if you want to keep validation for local paths, add a check that distinguishes local filesystem paths from HF model names (e.g., check if the path contains os.sep or starts with / or .).
```python
with io.local_read_dir(model_path) as model_path:
    # initialize the bridge and provider objects
    self.init_configs(
        model_path,
        self.cfg.trainer.policy.megatron_config,
        self.cfg.trainer.policy.megatron_config.model_config_kwargs,
        self.cfg.trainer.policy.megatron_config.transformer_config_kwargs,
        bf16=self.cfg.trainer.bf16,
        flash_attn=self.cfg.trainer.flash_attn,
    )

    # wrap with DDP for training
    self.actor_module = self.make_megatron_module(
        wrap_with_ddp=True,
        ddp_config=self.cfg.trainer.policy.megatron_config.ddp_config,
        lora_config=self.cfg.trainer.policy.model.lora if self._is_lora else None,
        lora_type=self.cfg.trainer.policy.megatron_config.lora_config.lora_type,
        bf16=self.cfg.trainer.bf16,
    )

    if self._local_rank == 0 and not os.path.exists(
        model_path
    ):  # if not local path, try downloading model weights from huggingface
        snapshot_download(model_path)  # will be no-op if already downloaded
    torch.distributed.barrier()
```
🔴 Megatron snapshot_download fallback unreachable inside local_read_dir for HuggingFace model names
The Megatron policy worker's snapshot_download fallback at lines 459–462 is now inside the local_read_dir context. When model_path is a HuggingFace model name (not a cloud path and not a local path), local_read_dir raises FileNotFoundError at skyrl-train/skyrl_train/utils/io/io.py:197-198 before execution ever reaches the snapshot_download call.
Detailed Explanation
Before this PR, the code handled HF model names gracefully:

```python
self.init_configs(model_path, ...)  # AutoConfig/AutoBridge handle HF names
self.actor_module = self.make_megatron_module(...)
if self._local_rank == 0 and not os.path.exists(model_path):
    snapshot_download(model_path)  # downloads HF model for weight loading
torch.distributed.barrier()
```

After this PR, wrapping in local_read_dir means FileNotFoundError is raised before init_configs is even called. The snapshot_download code at lines 459-462 is dead code for the HF model name case.
The same issue affects MegatronRefWorkerBase.init_model at skyrl-train/skyrl_train/workers/megatron/megatron_worker.py:727 where snapshot_download at lines 745-748 is similarly unreachable.
Impact: Megatron workers can no longer be initialized with HuggingFace model names — a previously working workflow.
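If validation for local paths is kept rather than removed, a small heuristic could distinguish filesystem paths from HF repo IDs. The function below is an illustrative assumption, not the repository's code, and is simplified (e.g. it ignores Windows drive letters); note that checking for `os.sep` alone would misclassify HF IDs, since they contain `/` too:

```python
import os


def looks_like_local_path(path: str) -> bool:
    """Heuristic: treat explicit filesystem prefixes and existing paths as local.

    Bare "org/name" identifiers with no leading "/", "./", "../", or "~" that
    do not exist on disk are assumed to be HuggingFace repo IDs and should be
    passed through to downstream libraries unchanged.
    """
    return path.startswith(("/", "./", "../", "~")) or os.path.exists(path)
```

local_read_dir could then raise FileNotFoundError only when `looks_like_local_path` is true and the path is missing, restoring the snapshot_download fallback for HF model names.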
Summary
Adds support for using cloud storage paths (S3, GCS) as the initial model path when starting fresh training runs. Previously, cloud paths were only supported when resuming from checkpoints.
This enables workflows like:
Approach
Reuses the existing io.local_read_dir() context manager at each model loading site. For local paths, it's a no-op passthrough. For cloud paths, it downloads to a temp directory and yields the local path. Temp files are cleaned up after model weights are loaded into GPU memory.

Changes

- io.py: Fix local_read_dir to handle s3fs nesting files under a subdirectory when downloading recursively
- main_base.py: Wrap get_tokenizer() with local_read_dir
- fsdp_worker.py: Wrap init_model() in Policy, Critic, and Ref workers with local_read_dir
- megatron_worker.py: Wrap init_model() in Policy and Ref workers with local_read_dir (covers both init_configs and make_megatron_module, which both need model files)
- vllm_engine.py: Resolve cloud path in BaseVLLMInferenceEngine.__init__(), preserving original path as served_model_name
- sglang_engine.py: Resolve cloud path in SGLangInferenceEngine.__init__() (same pattern as vLLM)

Design decisions

- served_model_name preserved — inference engines use the original cloud path for API naming consistency

Testing
Tested on a 2-node (16x A100) Ray cluster with an S3-hosted Qwen2.5-1.5B-Instruct model on GSM8K:
Not yet tested / follow-up needed
- __init__). Maintainers, please verify.
- The get_http_inference_client() entrypoint was added after this branch's base commit, so cloud path support for VLLMServerActor should be added as a follow-up once the HTTP inference codepath is the default.

🤖 Generated with Claude Code