Skip to content

[Draft] support for azure blob storage#1172

Open
chiragbhatt311 wants to merge 7 commits intoNovaSky-AI:mainfrom
chiragbhatt311:chirag/skyrl_train_multinode
Open

[Draft] support for azure blob storage#1172
chiragbhatt311 wants to merge 7 commits intoNovaSky-AI:mainfrom
chiragbhatt311:chirag/skyrl_train_multinode

Conversation

@chiragbhatt311
Copy link
Contributor

@chiragbhatt311 chiragbhatt311 commented Feb 18, 2026

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request significantly enhances the system by adding support for Azure Blob Storage and introducing a new summarization example. A critical prompt injection vulnerability has been identified in the new summarization judge environment, which could allow the model to manipulate its own rewards. Furthermore, a critical issue exists with a dependency pointing to a personal fork that must be addressed, and several medium-severity suggestions have been provided to improve code quality, consistency, and reproducibility.

# For now, just always use the current main branch, later it will be better to pin it to a released version. For development, you
# can set it to your own development branch.
skyrl-train = { git = "https://github.com/NovaSky-AI/SkyRL", subdirectory = "skyrl-train" }
skyrl-train = { git = "https://github.com/chiragbhatt311/SkyRL.git", subdirectory = "skyrl-train", branch = "chirag/skyrl_train_multinode" }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The skyrl-train dependency points to a personal fork and branch. This is a significant risk for maintainability and stability, and it should be reverted to the official repository path before this pull request is merged.

Suggested change
skyrl-train = { git = "https://github.com/chiragbhatt311/SkyRL.git", subdirectory = "skyrl-train", branch = "chirag/skyrl_train_multinode" }
skyrl-train = { git = "https://github.com/NovaSky-AI/SkyRL", subdirectory = "skyrl-train" }

Comment on lines +173 to +182
def _build_grading_prompt(self, summary: str) -> str:
"""Build the user prompt for the grading API."""
return f"""## User Intent
{self._format_user_intent()}

### Original Document ###
{self.original_document}

### Summary ###
{summary}"""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

security-medium medium

The SummarizationJudgeEnv class is vulnerable to prompt injection because it directly concatenates the model-generated summary into the prompt sent to the LLM grader. An attacker (or the model itself during training) can inject instructions into the summary to manipulate the grader's output and obtain higher rewards. This is a form of reward hacking that can compromise the training process and lead to the model learning to exploit the grading mechanism rather than performing the intended task.

To remediate this, use robust delimiters and clear instructions in the system prompt to ensure the grader treats the summary as data only. For example, wrap the summary in XML-like tags and instruct the grader to only evaluate the content within those tags.

Comment on lines +51 to +52
train_dataset = dataset["train"].select(range(128))
val_dataset = dataset["test"].select(range(64))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The dataset is being truncated to a very small, hardcoded number of samples (128 for train, 64 for test). This is likely for debugging but significantly reduces the utility of the example script for other users. It would be better to either remove this truncation to use the full dataset by default, or make these sizes configurable via command-line arguments.

Suggested change
train_dataset = dataset["train"].select(range(128))
val_dataset = dataset["test"].select(range(64))
train_dataset = dataset["train"]
val_dataset = dataset["test"]

@@ -0,0 +1,68 @@
set -x

set -x
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This set -x is redundant as it's already set on line 1. Please remove it to keep the script clean.


# NOTE (sumanthrh): `micro_train_batch_size_per_gpu` and `micro_forward_batch_size_per_gpu` can be tuned

DATA_DIR="/mnt/workspace/datasets/gsm8k_with_reward"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The DATA_DIR is hardcoded to an absolute path (/mnt/workspace/...) which is not portable across different environments. Consider using a path relative to $HOME or allowing it to be overridden by an environment variable to make the script more reusable.

Suggested change
DATA_DIR="/mnt/workspace/datasets/gsm8k_with_reward"
DATA_DIR="${DATA_DIR:-$HOME/data/gsm8k_with_reward}"

Comment on lines +210 to +211
print(f"Could not find score in response: {response_text[:200]}")
return 0.0
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Throughout this file, print() is used for logging errors and warnings (e.g., lines 210, 214, 220, 241). In a distributed application, it's better to use a structured logger like loguru (which seems to be used elsewhere in the project) for better log management and visibility. Please replace these print calls with logger calls (e.g., logger.warning(...)).

Comment on lines +54 to +60
from azure.identity import ManagedIdentityCredential
credential = ManagedIdentityCredential(client_id=client_id) if client_id else ManagedIdentityCredential()
return fsspec.filesystem(proto, account_name=account_name, credential=credential)
except Exception as e:
logger.warning(f"ManagedIdentityCredential failed ({e}), falling back to DefaultAzureCredential")
from azure.identity import DefaultAzureCredential
return fsspec.filesystem(proto, account_name=account_name, credential=DefaultAzureCredential())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The azure.identity imports are inside this function. It's a best practice to place all imports at the top of the file for clarity, performance, and to avoid circular import issues. Please move these imports to the module level. If they are optional dependencies, you can wrap them in a try...except ImportError block.

Comment on lines +15 to +61
def _ensure_azure_default_client():
"""Register a default AzureBlobClient for cloudpathlib.

Fallback chain:
1. Connection string (AZURE_STORAGE_CONNECTION_STRING)
2. Managed identity (AZURE_CLIENT_ID for user-assigned, otherwise system-assigned)
3. DefaultAzureCredential (CLI, env-vars, workload identity, etc.)

Must be called before cloudpathlib tries to parse az:// paths.
"""
try:
from cloudpathlib import AzureBlobClient

conn_str = os.environ.get("AZURE_STORAGE_CONNECTION_STRING")
if conn_str:
try:
AzureBlobClient(connection_string=conn_str).set_as_default_client()
logger.info("Azure auth: using connection string")
return
except Exception as e:
logger.warning(f"Connection string auth failed ({e}), falling back to managed identity")

# Need account_url for credential-based auth
account_name = os.environ.get("AZURE_STORAGE_ACCOUNT_NAME")
if not account_name:
logger.info("No AZURE_STORAGE_CONNECTION_STRING or AZURE_STORAGE_ACCOUNT_NAME set, skipping Azure client setup")
return
account_url = f"https://{account_name}.blob.core.windows.net"

# 2. Managed identity
client_id = os.environ.get("AZURE_CLIENT_ID")
try:
from azure.identity import ManagedIdentityCredential
credential = ManagedIdentityCredential(client_id=client_id) if client_id else ManagedIdentityCredential()
AzureBlobClient(account_url=account_url, credential=credential).set_as_default_client()
logger.info(f"Azure auth: using ManagedIdentityCredential (client_id={client_id or 'system-assigned'})")
return
except Exception as e:
logger.warning(f"ManagedIdentityCredential failed ({e}), falling back to DefaultAzureCredential")

from azure.identity import DefaultAzureCredential
credential = DefaultAzureCredential()
AzureBlobClient(account_url=account_url, credential=credential).set_as_default_client()
logger.info("Azure auth: using DefaultAzureCredential")

except Exception as e:
logger.error(f"Failed to set up AzureBlobClient: {e}")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The azure.identity and cloudpathlib imports are inside the _ensure_azure_default_client function. For better code structure, performance, and to avoid potential import issues, these should be moved to the top of the file. Optional dependencies can be handled gracefully with a try...except ImportError block at the module level.

Comment on lines +718 to +719
from tx.tinker.config import _ensure_azure_default_client
_ensure_azure_default_client()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This local import should be moved to the top of the file to follow standard Python conventions. This improves readability and makes dependencies clear.

Copy link
Contributor

@devin-ai-integration devin-ai-integration bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Devin Review found 4 potential issues.

View 6 additional findings in Devin Review.

Open in Devin Review

Comment on lines +45 to +49
if not account_name and not conn_str:
raise ValueError(
"Either AZURE_STORAGE_CONNECTION_STRING or AZURE_STORAGE_ACCOUNT_NAME "
"must be set for Azure storage"
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🔴 Azure _get_filesystem silently falls through to unauthenticated filesystem when connection string fails and no account name is set

When AZURE_STORAGE_CONNECTION_STRING is set but the fsspec.filesystem() call with it fails, and AZURE_STORAGE_ACCOUNT_NAME is not set, the function silently falls through to return fsspec.filesystem(proto) (an unauthenticated filesystem) instead of raising an error.

Root Cause

The guard on line 45 is if not account_name and not conn_str. When conn_str is truthy (it's set, even though it failed), this condition evaluates to False, so the ValueError is never raised. Then if account_name: on line 51 is also False, so the entire Azure block is skipped. Execution falls through to line 61: return fsspec.filesystem(proto), which creates an unauthenticated Azure filesystem.

Trace:

  1. conn_str is set → enters if conn_str: block → fsspec.filesystem(...) raises → warning logged
  2. account_name is None
  3. not account_name and not conn_strTrue and FalseFalseValueError skipped
  4. if account_name:Falsecredential-based auth skipped
  5. return fsspec.filesystem(proto)unauthenticated filesystem returned

Impact: Azure storage operations will fail with confusing authentication errors downstream, or worse, silently access the wrong (public) data. The intended error message is never shown.

Suggested change
if not account_name and not conn_str:
raise ValueError(
"Either AZURE_STORAGE_CONNECTION_STRING or AZURE_STORAGE_ACCOUNT_NAME "
"must be set for Azure storage"
)
if not account_name:
raise ValueError(
"AZURE_STORAGE_CONNECTION_STRING failed and AZURE_STORAGE_ACCOUNT_NAME "
"is not set. Cannot authenticate to Azure storage."
)
Open in Devin Review

Was this helpful? React with 👍 or 👎 to provide feedback.

Comment on lines +51 to +52
train_dataset = dataset["train"].select(range(128))
val_dataset = dataset["test"].select(range(64))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🔴 GSM8K dataset script hardcodes truncation to 128 train / 64 test samples

The GSM8K dataset preparation script now unconditionally truncates the training set to 128 examples and the validation set to 64 examples. The GSM8K training set has 7,473 examples and the test set has 1,319 examples, so this discards >98% of the data.

Root Cause

Lines 51-52 add .select(range(128)) and .select(range(64)) directly after loading the dataset splits. This appears to be a debugging leftover that was accidentally included in the PR.

Before:

train_dataset = dataset["train"]
val_dataset = dataset["test"]

After:

train_dataset = dataset["train"].select(range(128))
val_dataset = dataset["test"].select(range(64))

The existing --max_train_dataset_length flag on line 54-56 already provides a proper mechanism for limiting dataset size. The hardcoded truncation makes that flag ineffective for values above 128.

Impact: Anyone using this script to prepare GSM8K data will silently get a tiny dataset, leading to undertrained models and misleading evaluation results.

Suggested change
train_dataset = dataset["train"].select(range(128))
val_dataset = dataset["test"].select(range(64))
train_dataset = dataset["train"]
val_dataset = dataset["test"]
Open in Devin Review

Was this helpful? React with 👍 or 👎 to provide feedback.

Comment on lines +162 to +163
skyrl-train = { git = "https://github.com/chiragbhatt311/SkyRL.git", subdirectory = "skyrl-train", branch = "chirag/skyrl_train_multinode" }
#skyrl-train = { path = "../skyrl-train"}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🔴 skyrl-train dependency points to personal fork instead of official repository

The skyrl-train UV source in skyrl-tx/pyproject.toml was changed from the official NovaSky-AI/SkyRL repository to a personal fork chiragbhatt311/SkyRL.git on a feature branch.

Root Cause

Line 162 was changed from:

skyrl-train = { git = "https://github.com/NovaSky-AI/SkyRL", subdirectory = "skyrl-train" }

to:

skyrl-train = { git = "https://github.com/chiragbhatt311/SkyRL.git", subdirectory = "skyrl-train", branch = "chirag/skyrl_train_multinode" }

The commented-out local path on line 163 (#skyrl-train = { path = "../skyrl-train"}) confirms this is a development artifact.

Impact: If merged, all skyrl-tx builds would depend on a personal fork that could diverge from upstream, be deleted, or force-pushed, breaking the dependency chain for all users.

Suggested change
skyrl-train = { git = "https://github.com/chiragbhatt311/SkyRL.git", subdirectory = "skyrl-train", branch = "chirag/skyrl_train_multinode" }
#skyrl-train = { path = "../skyrl-train"}
skyrl-train = { git = "https://github.com/NovaSky-AI/SkyRL", subdirectory = "skyrl-train" }
Open in Devin Review

Was this helpful? React with 👍 or 👎 to provide feedback.

Comment on lines +667 to +668
else:
logger.warning(f"{azure_var} not found in os.environ, skipping")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 Azure env vars emit warnings unconditionally for all non-Azure users

The prepare_runtime_environment function unconditionally warns when Azure environment variables are missing, unlike the pattern used for other optional env vars (WANDB, MLFLOW, Harbor) which only log when the variable is present.

Root Cause

At skyrl-train/skyrl_train/utils/utils.py:662-668, the loop iterates over all three Azure env vars and emits logger.warning(...) for each one not found. This runs for every Ray initialization, regardless of whether Azure storage is being used.

Compare with the WANDB/MLFLOW pattern at lines 650-660 which only logs when the variable is found:

if os.environ.get("WANDB_API_KEY"):
    logger.info("Exporting wandb api key to ray runtime env")
    env_vars["WANDB_API_KEY"] = os.environ["WANDB_API_KEY"]

vs the Azure pattern:

for azure_var in (...):
    if os.environ.get(azure_var):
        env_vars[azure_var] = os.environ[azure_var]
    else:
        logger.warning(f"{azure_var} not found in os.environ, skipping")

Impact: Every non-Azure user will see 3 spurious warning messages on every training run, adding noise to logs.

Suggested change
else:
logger.warning(f"{azure_var} not found in os.environ, skipping")
logger.debug(f"{azure_var} not found in os.environ, skipping")
Open in Devin Review

Was this helpful? React with 👍 or 👎 to provide feedback.

@chiragbhatt311 chiragbhatt311 changed the title support for azure blob storage [Draft] support for azure blob storage Feb 18, 2026
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant