Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions skyrl-train/skyrl_train/entrypoints/main_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,11 +125,14 @@ def get_cfg_as_str(dict_cfg: DictConfig) -> str:

def get_tokenizer(self, padding_side="left"):
"""Initializes a tokenizer for the given model."""
tokenizer = AutoTokenizer.from_pretrained(
self.cfg.trainer.policy.model.path,
trust_remote_code=True,
use_fast=not self.cfg.trainer.disable_fast_tokenizer,
)
from skyrl_train.utils.io import io

with io.local_read_dir(self.cfg.trainer.policy.model.path) as model_path:
tokenizer = AutoTokenizer.from_pretrained(
model_path,
trust_remote_code=True,
use_fast=not self.cfg.trainer.disable_fast_tokenizer,
)
Comment on lines +130 to +135
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🔴 local_read_dir raises FileNotFoundError for HuggingFace model names, breaking non-local/non-cloud paths

Wrapping model loading in local_read_dir breaks HuggingFace Hub model name support (e.g., Qwen/Qwen2.5-1.5B-Instruct). AutoTokenizer.from_pretrained natively resolves HF model names, but now local_read_dir raises FileNotFoundError before the HF method is ever called.

Root Cause and Impact across all affected call sites

local_read_dir at skyrl-train/skyrl_train/utils/io/io.py:196-198 checks exists(input_path) for non-cloud paths, which uses the local filesystem. A HuggingFace model ID like "Qwen/Qwen2.5-1.5B-Instruct" is neither a cloud path nor a local path, so is_cloud_path() returns False, then exists() returns False, and FileNotFoundError is raised.

Before this PR, AutoTokenizer.from_pretrained(self.cfg.trainer.policy.model.path) handled HF model names natively. Now the local_read_dir wrapper intercepts and fails first.

This same regression affects every new call site added by this PR:

  • skyrl-train/skyrl_train/workers/fsdp/fsdp_worker.py:103 (Policy), :269 (Critic), :346 (Ref) — AutoConfig.from_pretrained and HFModelWrapper both accept HF names natively
  • skyrl-train/skyrl_train/workers/megatron/megatron_worker.py:439 (Policy), :727 (Ref) — especially severe because the existing snapshot_download fallback (which explicitly handles HF model names) is now unreachable: local_read_dir raises before execution reaches it

Impact: Any user specifying a HuggingFace model ID in trainer.policy.model.path (a previously supported workflow) will get a crash at startup.

Prompt for agents
The fix needs to be applied at the local_read_dir function in skyrl-train/skyrl_train/utils/io/io.py (lines 195-198), not at each call site. The else branch currently raises FileNotFoundError for any path that doesn't exist on the local filesystem, but HuggingFace model names (e.g., "Qwen/Qwen2.5-1.5B-Instruct") are valid non-local, non-cloud paths that should be passed through unchanged. The simplest fix is to remove the existence check in the else branch of local_read_dir and just yield input_path directly for non-cloud paths. The downstream libraries (AutoTokenizer, AutoConfig, AutoModelForCausalLM, AutoBridge, etc.) all handle both local paths and HuggingFace model names already. Alternatively, if you want to keep validation for local paths, add a check that distinguishes local filesystem paths from HF model names (e.g., check if the path contains os.sep or starts with / or .).
Open in Devin Review

Was this helpful? React with 👍 or 👎 to provide feedback.

tokenizer.padding_side = padding_side
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
Expand Down
66 changes: 41 additions & 25 deletions skyrl-train/skyrl_train/inference_engines/sglang/sglang_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,31 +247,47 @@ class SGLangInferenceEngine(InferenceEngineInterface):
"""SGLang inference engine that implements InferenceEngineInterface."""

def __init__(self, *args, bundle_indices: Optional[List[int]] = None, **kwargs):
setup_envvars_for_sglang(kwargs, bundle_indices)

# Store common attributes
self._tp_size = kwargs.get("tp_size", 1)
if self._tp_size > 1:
raise ValueError(
"As of now, we don't support tensor parallel inference engine with SGLang. "
"Please set `inference_engine_tensor_parallel_size` to 1."
)
self.tokenizer = kwargs.pop("tokenizer", None)

# Unused kwargs
_ = kwargs.pop("num_gpus", 1)

# Add custom weight loader
kwargs["custom_weight_loader"] = CUSTOM_WEIGHT_LOADER_PATH

# Always use token-in-token-out SGLang engine
# NOTE(Charlie): unlike vLLM, SGLang cannot do token-in-token-out and
# token-in-text-out in the same engine config.
kwargs["skip_tokenizer_init"] = True

# Create the SGLang engine (signal handler issue is now fixed by patching)
self.engine = Engine(**kwargs)
logger.info(f"Created SGLang engine with kwargs: {kwargs}")
from skyrl_train.utils.io import io

original_model_path = kwargs.get("model_path", "")

self._cloud_model_ctx = None
if io.is_cloud_path(original_model_path):
self._cloud_model_ctx = io.local_read_dir(original_model_path)
local_model_path = self._cloud_model_ctx.__enter__()
kwargs["model_path"] = local_model_path

try:
setup_envvars_for_sglang(kwargs, bundle_indices)

# Store common attributes
self._tp_size = kwargs.get("tp_size", 1)
if self._tp_size > 1:
raise ValueError(
"As of now, we don't support tensor parallel inference engine with SGLang. "
"Please set `inference_engine_tensor_parallel_size` to 1."
)
self.tokenizer = kwargs.pop("tokenizer", None)

# Unused kwargs
_ = kwargs.pop("num_gpus", 1)

# Add custom weight loader
kwargs["custom_weight_loader"] = CUSTOM_WEIGHT_LOADER_PATH

# Always use token-in-token-out SGLang engine
# NOTE(Charlie): unlike vLLM, SGLang cannot do token-in-token-out and
# token-in-text-out in the same engine config.
kwargs["skip_tokenizer_init"] = True

# Create the SGLang engine (signal handler issue is now fixed by patching)
self.engine = Engine(**kwargs)
logger.info(f"Created SGLang engine with kwargs: {kwargs}")
finally:
# Clean up temp directory now that model is loaded into GPU memory
if self._cloud_model_ctx is not None:
self._cloud_model_ctx.__exit__(None, None, None)
self._cloud_model_ctx = None

# Create weight loader for coordinating weight updates
self._weight_loader = SGLangWeightLoader(self.engine, self._tp_size)
Expand Down
52 changes: 35 additions & 17 deletions skyrl-train/skyrl_train/inference_engines/vllm/vllm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,23 +74,41 @@ class BaseVLLMInferenceEngine(InferenceEngineInterface):
"""Base class containing shared logic between sync and async VLLM engines."""

def __init__(self, *args, bundle_indices: list = None, **kwargs):
    """Initialize the vLLM engine, staging cloud-hosted model weights locally first.

    If ``kwargs["model"]`` is a cloud path (e.g. ``s3://...``), the directory is
    downloaded to a temporary dir via ``io.local_read_dir`` and ``kwargs["model"]``
    is rewritten to that local copy; the temp dir is cleaned up in ``finally`` once
    the engine has loaded the weights into GPU memory. Non-cloud paths (local dirs
    or HuggingFace model names) pass through unchanged.

    Args:
        bundle_indices: Ray placement-group bundle indices for this engine's GPUs.
        **kwargs: vLLM engine keyword arguments; ``model``, parallelism sizes and
            ``vllm_v1_disable_multiproc`` are consumed/inspected here.
    """
    from skyrl_train.utils.io import io

    original_model_path = kwargs.get("model", "")

    self._cloud_model_ctx = None
    if io.is_cloud_path(original_model_path):
        self._cloud_model_ctx = io.local_read_dir(original_model_path)
        local_model_path = self._cloud_model_ctx.__enter__()
        kwargs["model"] = local_model_path
        # Keep the original (cloud) path as the externally visible model name.
        if kwargs.get("served_model_name") is None:
            kwargs["served_model_name"] = original_model_path

    try:
        setup_envvars_for_vllm(kwargs, bundle_indices)
        vllm_v1_disable_multiproc = kwargs.pop("vllm_v1_disable_multiproc", False)
        if vllm_v1_disable_multiproc or vllm.__version__ == "0.8.2":
            # https://github.com/vllm-project/vllm/blob/effc5d24fae10b29996256eb7a88668ff7941aed/examples/offline_inference/reproduciblity.py#L11
            os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"

        # Store common attributes
        self._tp_size = kwargs.get("tensor_parallel_size", 1)
        self._pp_size = kwargs.get("pipeline_parallel_size", 1)
        self._dp_size = kwargs.get("data_parallel_size", 1)
        self._is_lora = kwargs.get("enable_lora", False)

        # Let subclass create the appropriate engine
        self.llm = self._create_engine(*args, **kwargs)

        # Weight loader is created by subclass after engine initialization
        self._weight_loader = None
    finally:
        # Clean up temp directory now that the model is loaded into GPU memory;
        # runs even if engine creation fails, so no temp dirs are leaked.
        if self._cloud_model_ctx is not None:
            self._cloud_model_ctx.__exit__(None, None, None)
            self._cloud_model_ctx = None

def tp_size(self):
return self._tp_size
Expand Down
9 changes: 8 additions & 1 deletion skyrl-train/skyrl_train/utils/io/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,14 @@ def local_read_dir(input_path: str):
# Download everything from cloud path to temp_dir
download_directory(input_path, temp_dir)
logger.info(f"Downloaded directory contents from {input_path}")
yield temp_dir
# s3fs.get with recursive=True may nest files under a subdirectory
# named after the last path component. If temp_dir contains a single
# subdirectory and no files, yield that subdirectory instead.
entries = os.listdir(temp_dir)
if len(entries) == 1 and os.path.isdir(os.path.join(temp_dir, entries[0])):
yield os.path.join(temp_dir, entries[0])
else:
yield temp_dir
else:
# For local paths, use directly (but check it exists)
if not exists(input_path):
Expand Down
Loading