[Draft] support for azure blob storage#1172
[Draft] support for azure blob storage#1172chiragbhatt311 wants to merge 7 commits intoNovaSky-AI:mainfrom
Conversation
There was a problem hiding this comment.
Code Review
This pull request significantly enhances the system by adding support for Azure Blob Storage and introducing a new summarization example. A critical prompt injection vulnerability has been identified in the new summarization judge environment, which could allow the model to manipulate its own rewards. Furthermore, a critical issue exists with a dependency pointing to a personal fork that must be addressed, and several medium-severity suggestions have been provided to improve code quality, consistency, and reproducibility.
| # For now, just always use the current main branch, later it will be better to pin it to a released version. For development, you | ||
| # can set it to your own development branch. | ||
| skyrl-train = { git = "https://github.com/NovaSky-AI/SkyRL", subdirectory = "skyrl-train" } | ||
| skyrl-train = { git = "https://github.com/chiragbhatt311/SkyRL.git", subdirectory = "skyrl-train", branch = "chirag/skyrl_train_multinode" } |
There was a problem hiding this comment.
The skyrl-train dependency points to a personal fork and branch. This is a significant risk for maintainability and stability, and it should be reverted to the official repository path before this pull request is merged.
| skyrl-train = { git = "https://github.com/chiragbhatt311/SkyRL.git", subdirectory = "skyrl-train", branch = "chirag/skyrl_train_multinode" } | |
| skyrl-train = { git = "https://github.com/NovaSky-AI/SkyRL", subdirectory = "skyrl-train" } |
| def _build_grading_prompt(self, summary: str) -> str: | ||
| """Build the user prompt for the grading API.""" | ||
| return f"""## User Intent | ||
| {self._format_user_intent()} | ||
|
|
||
| ### Original Document ### | ||
| {self.original_document} | ||
|
|
||
| ### Summary ### | ||
| {summary}""" |
There was a problem hiding this comment.
The SummarizationJudgeEnv class is vulnerable to prompt injection because it directly concatenates the model-generated summary into the prompt sent to the LLM grader. An attacker (or the model itself during training) can inject instructions into the summary to manipulate the grader's output and obtain higher rewards. This is a form of reward hacking that can compromise the training process and lead to the model learning to exploit the grading mechanism rather than performing the intended task.
To remediate this, use robust delimiters and clear instructions in the system prompt to ensure the grader treats the summary as data only. For example, wrap the summary in XML-like tags and instruct the grader to only evaluate the content within those tags.
| train_dataset = dataset["train"].select(range(128)) | ||
| val_dataset = dataset["test"].select(range(64)) |
There was a problem hiding this comment.
The dataset is being truncated to a very small, hardcoded number of samples (128 for train, 64 for test). This is likely for debugging but significantly reduces the utility of the example script for other users. It would be better to either remove this truncation to use the full dataset by default, or make these sizes configurable via command-line arguments.
| train_dataset = dataset["train"].select(range(128)) | |
| val_dataset = dataset["test"].select(range(64)) | |
| train_dataset = dataset["train"] | |
| val_dataset = dataset["test"] | |
| @@ -0,0 +1,68 @@ | |||
| set -x | |||
|
|
|||
| set -x | |||
|
|
||
| # NOTE (sumanthrh): `micro_train_batch_size_per_gpu` and `micro_forward_batch_size_per_gpu` can be tuned | ||
|
|
||
| DATA_DIR="/mnt/workspace/datasets/gsm8k_with_reward" |
There was a problem hiding this comment.
The DATA_DIR is hardcoded to an absolute path (/mnt/workspace/...) which is not portable across different environments. Consider using a path relative to $HOME or allowing it to be overridden by an environment variable to make the script more reusable.
| DATA_DIR="/mnt/workspace/datasets/gsm8k_with_reward" | |
| DATA_DIR="${DATA_DIR:-$HOME/data/gsm8k_with_reward}" |
| print(f"Could not find score in response: {response_text[:200]}") | ||
| return 0.0 |
There was a problem hiding this comment.
Throughout this file, print() is used for logging errors and warnings (e.g., lines 210, 214, 220, 241). In a distributed application, it's better to use a structured logger like loguru (which seems to be used elsewhere in the project) for better log management and visibility. Please replace these print calls with logger calls (e.g., logger.warning(...)).
| from azure.identity import ManagedIdentityCredential | ||
| credential = ManagedIdentityCredential(client_id=client_id) if client_id else ManagedIdentityCredential() | ||
| return fsspec.filesystem(proto, account_name=account_name, credential=credential) | ||
| except Exception as e: | ||
| logger.warning(f"ManagedIdentityCredential failed ({e}), falling back to DefaultAzureCredential") | ||
| from azure.identity import DefaultAzureCredential | ||
| return fsspec.filesystem(proto, account_name=account_name, credential=DefaultAzureCredential()) |
There was a problem hiding this comment.
The azure.identity imports are inside this function. It's a best practice to place all imports at the top of the file for clarity, performance, and to avoid circular import issues. Please move these imports to the module level. If they are optional dependencies, you can wrap them in a try...except ImportError block.
| def _ensure_azure_default_client(): | ||
| """Register a default AzureBlobClient for cloudpathlib. | ||
|
|
||
| Fallback chain: | ||
| 1. Connection string (AZURE_STORAGE_CONNECTION_STRING) | ||
| 2. Managed identity (AZURE_CLIENT_ID for user-assigned, otherwise system-assigned) | ||
| 3. DefaultAzureCredential (CLI, env-vars, workload identity, etc.) | ||
|
|
||
| Must be called before cloudpathlib tries to parse az:// paths. | ||
| """ | ||
| try: | ||
| from cloudpathlib import AzureBlobClient | ||
|
|
||
| conn_str = os.environ.get("AZURE_STORAGE_CONNECTION_STRING") | ||
| if conn_str: | ||
| try: | ||
| AzureBlobClient(connection_string=conn_str).set_as_default_client() | ||
| logger.info("Azure auth: using connection string") | ||
| return | ||
| except Exception as e: | ||
| logger.warning(f"Connection string auth failed ({e}), falling back to managed identity") | ||
|
|
||
| # Need account_url for credential-based auth | ||
| account_name = os.environ.get("AZURE_STORAGE_ACCOUNT_NAME") | ||
| if not account_name: | ||
| logger.info("No AZURE_STORAGE_CONNECTION_STRING or AZURE_STORAGE_ACCOUNT_NAME set, skipping Azure client setup") | ||
| return | ||
| account_url = f"https://{account_name}.blob.core.windows.net" | ||
|
|
||
| # 2. Managed identity | ||
| client_id = os.environ.get("AZURE_CLIENT_ID") | ||
| try: | ||
| from azure.identity import ManagedIdentityCredential | ||
| credential = ManagedIdentityCredential(client_id=client_id) if client_id else ManagedIdentityCredential() | ||
| AzureBlobClient(account_url=account_url, credential=credential).set_as_default_client() | ||
| logger.info(f"Azure auth: using ManagedIdentityCredential (client_id={client_id or 'system-assigned'})") | ||
| return | ||
| except Exception as e: | ||
| logger.warning(f"ManagedIdentityCredential failed ({e}), falling back to DefaultAzureCredential") | ||
|
|
||
| from azure.identity import DefaultAzureCredential | ||
| credential = DefaultAzureCredential() | ||
| AzureBlobClient(account_url=account_url, credential=credential).set_as_default_client() | ||
| logger.info("Azure auth: using DefaultAzureCredential") | ||
|
|
||
| except Exception as e: | ||
| logger.error(f"Failed to set up AzureBlobClient: {e}") |
There was a problem hiding this comment.
The azure.identity and cloudpathlib imports are inside the _ensure_azure_default_client function. For better code structure, performance, and to avoid potential import issues, these should be moved to the top of the file. Optional dependencies can be handled gracefully with a try...except ImportError block at the module level.
| from tx.tinker.config import _ensure_azure_default_client | ||
| _ensure_azure_default_client() |
| if not account_name and not conn_str: | ||
| raise ValueError( | ||
| "Either AZURE_STORAGE_CONNECTION_STRING or AZURE_STORAGE_ACCOUNT_NAME " | ||
| "must be set for Azure storage" | ||
| ) |
There was a problem hiding this comment.
🔴 Azure _get_filesystem silently falls through to unauthenticated filesystem when connection string fails and no account name is set
When AZURE_STORAGE_CONNECTION_STRING is set but the fsspec.filesystem() call with it fails, and AZURE_STORAGE_ACCOUNT_NAME is not set, the function silently falls through to return fsspec.filesystem(proto) (an unauthenticated filesystem) instead of raising an error.
Root Cause
The guard on line 45 is if not account_name and not conn_str. When conn_str is truthy (it's set, even though it failed), this condition evaluates to False, so the ValueError is never raised. Then if account_name: on line 51 is also False, so the entire Azure block is skipped. Execution falls through to line 61: return fsspec.filesystem(proto), which creates an unauthenticated Azure filesystem.
Trace:
conn_stris set → entersif conn_str:block →fsspec.filesystem(...)raises → warning loggedaccount_nameisNonenot account_name and not conn_str→True and False→False→ ValueError skippedif account_name:→False→ credential-based auth skippedreturn fsspec.filesystem(proto)→ unauthenticated filesystem returned
Impact: Azure storage operations will fail with confusing authentication errors downstream, or worse, silently access the wrong (public) data. The intended error message is never shown.
| if not account_name and not conn_str: | |
| raise ValueError( | |
| "Either AZURE_STORAGE_CONNECTION_STRING or AZURE_STORAGE_ACCOUNT_NAME " | |
| "must be set for Azure storage" | |
| ) | |
| if not account_name: | |
| raise ValueError( | |
| "AZURE_STORAGE_CONNECTION_STRING failed and AZURE_STORAGE_ACCOUNT_NAME " | |
| "is not set. Cannot authenticate to Azure storage." | |
| ) | |
Was this helpful? React with 👍 or 👎 to provide feedback.
| train_dataset = dataset["train"].select(range(128)) | ||
| val_dataset = dataset["test"].select(range(64)) |
There was a problem hiding this comment.
🔴 GSM8K dataset script hardcodes truncation to 128 train / 64 test samples
The GSM8K dataset preparation script now unconditionally truncates the training set to 128 examples and the validation set to 64 examples. The GSM8K training set has 7,473 examples and the test set has 1,319 examples, so this discards >98% of the data.
Root Cause
Lines 51-52 add .select(range(128)) and .select(range(64)) directly after loading the dataset splits. This appears to be a debugging leftover that was accidentally included in the PR.
Before:
train_dataset = dataset["train"]
val_dataset = dataset["test"]After:
train_dataset = dataset["train"].select(range(128))
val_dataset = dataset["test"].select(range(64))The existing --max_train_dataset_length flag on line 54-56 already provides a proper mechanism for limiting dataset size. The hardcoded truncation makes that flag ineffective for values above 128.
Impact: Anyone using this script to prepare GSM8K data will silently get a tiny dataset, leading to undertrained models and misleading evaluation results.
| train_dataset = dataset["train"].select(range(128)) | |
| val_dataset = dataset["test"].select(range(64)) | |
| train_dataset = dataset["train"] | |
| val_dataset = dataset["test"] | |
Was this helpful? React with 👍 or 👎 to provide feedback.
| skyrl-train = { git = "https://github.com/chiragbhatt311/SkyRL.git", subdirectory = "skyrl-train", branch = "chirag/skyrl_train_multinode" } | ||
| #skyrl-train = { path = "../skyrl-train"} |
There was a problem hiding this comment.
🔴 skyrl-train dependency points to personal fork instead of official repository
The skyrl-train UV source in skyrl-tx/pyproject.toml was changed from the official NovaSky-AI/SkyRL repository to a personal fork chiragbhatt311/SkyRL.git on a feature branch.
Root Cause
Line 162 was changed from:
skyrl-train = { git = "https://github.com/NovaSky-AI/SkyRL", subdirectory = "skyrl-train" }
to:
skyrl-train = { git = "https://github.com/chiragbhatt311/SkyRL.git", subdirectory = "skyrl-train", branch = "chirag/skyrl_train_multinode" }
The commented-out local path on line 163 (#skyrl-train = { path = "../skyrl-train"}) confirms this is a development artifact.
Impact: If merged, all skyrl-tx builds would depend on a personal fork that could diverge from upstream, be deleted, or force-pushed, breaking the dependency chain for all users.
| skyrl-train = { git = "https://github.com/chiragbhatt311/SkyRL.git", subdirectory = "skyrl-train", branch = "chirag/skyrl_train_multinode" } | |
| #skyrl-train = { path = "../skyrl-train"} | |
| skyrl-train = { git = "https://github.com/NovaSky-AI/SkyRL", subdirectory = "skyrl-train" } | |
Was this helpful? React with 👍 or 👎 to provide feedback.
| else: | ||
| logger.warning(f"{azure_var} not found in os.environ, skipping") |
There was a problem hiding this comment.
🟡 Azure env vars emit warnings unconditionally for all non-Azure users
The prepare_runtime_environment function unconditionally warns when Azure environment variables are missing, unlike the pattern used for other optional env vars (WANDB, MLFLOW, Harbor) which only log when the variable is present.
Root Cause
At skyrl-train/skyrl_train/utils/utils.py:662-668, the loop iterates over all three Azure env vars and emits logger.warning(...) for each one not found. This runs for every Ray initialization, regardless of whether Azure storage is being used.
Compare with the WANDB/MLFLOW pattern at lines 650-660 which only logs when the variable is found:
if os.environ.get("WANDB_API_KEY"):
logger.info("Exporting wandb api key to ray runtime env")
env_vars["WANDB_API_KEY"] = os.environ["WANDB_API_KEY"]vs the Azure pattern:
for azure_var in (...):
if os.environ.get(azure_var):
env_vars[azure_var] = os.environ[azure_var]
else:
logger.warning(f"{azure_var} not found in os.environ, skipping")Impact: Every non-Azure user will see 3 spurious warning messages on every training run, adding noise to logs.
| else: | |
| logger.warning(f"{azure_var} not found in os.environ, skipping") | |
| logger.debug(f"{azure_var} not found in os.environ, skipping") | |
Was this helpful? React with 👍 or 👎 to provide feedback.
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Uh oh!
There was an error while loading. Please reload this page.