From 38900e915488f88c784c7d24b1f1857bf90e728a Mon Sep 17 00:00:00 2001 From: Ayushh Garg Date: Sat, 14 Mar 2026 01:25:28 +0530 Subject: [PATCH 1/4] Path traversal protection in unzip to temp file --- .../ai/ml/operations/_local_job_invoker.py | 43 ++++++- sdk/ml/azure-ai-ml/python | 0 .../unittests/test_local_job_invoker.py | 118 +++++++++++++++++- 3 files changed, 158 insertions(+), 3 deletions(-) create mode 100644 sdk/ml/azure-ai-ml/python diff --git a/sdk/ml/azure-ai-ml/azure/ai/ml/operations/_local_job_invoker.py b/sdk/ml/azure-ai-ml/azure/ai/ml/operations/_local_job_invoker.py index b7436d307902..c5d8b02793ac 100644 --- a/sdk/ml/azure-ai-ml/azure/ai/ml/operations/_local_job_invoker.py +++ b/sdk/ml/azure-ai-ml/azure/ai/ml/operations/_local_job_invoker.py @@ -41,7 +41,15 @@ def unzip_to_temporary_file(job_definition: JobBaseData, zip_content: Any) -> Path: temp_dir = Path(tempfile.gettempdir(), AZUREML_RUNS_DIR, job_definition.name) temp_dir.mkdir(parents=True, exist_ok=True) + resolved_temp_dir = temp_dir.resolve() with zipfile.ZipFile(io.BytesIO(zip_content)) as zip_ref: + for member in zip_ref.namelist(): + member_path = (resolved_temp_dir / member).resolve() + # Ensure the member extracts within temp_dir (allow temp_dir itself for directory entries) + if member_path != resolved_temp_dir and not str(member_path).startswith( + str(resolved_temp_dir) + os.sep + ): + raise ValueError(f"Zip archive contains a path traversal entry and cannot be extracted safely: {member}") zip_ref.extractall(temp_dir) return temp_dir @@ -166,6 +174,38 @@ def is_local_run(job_definition: JobBaseData) -> bool: local = job_definition.properties.services.get("Local", None) return local is not None and EXECUTION_SERVICE_URL_KEY in local.endpoint +def _safe_tar_extractall(tar: tarfile.TarFile, dest_dir: str) -> None: + """Extract tar archive members safely, preventing path traversal (TarSlip). + + On Python 3.12+, uses the built-in 'data' filter. On older versions, + manually validates each member to ensure no path traversal, symlinks, + or hard links that could write outside the destination directory. + + :param tar: An opened tarfile.TarFile object. + :type tar: tarfile.TarFile + :param dest_dir: The destination directory for extraction. + :type dest_dir: str + :raises ValueError: If a tar member would escape the destination directory + or contains a symlink/hard link. + """ + resolved_dest = os.path.realpath(dest_dir) + + # Python 3.12+ has built-in data_filter for safe extraction + if hasattr(tarfile, "data_filter"): + tar.extractall(resolved_dest, filter="data") + else: + for member in tar.getmembers(): + if member.issym() or member.islnk(): + raise ValueError( + f"Tar archive contains a symbolic or hard link and cannot be extracted safely: {member.name}" + ) + member_path = os.path.realpath(os.path.join(resolved_dest, member.name)) + if member_path != resolved_dest and not member_path.startswith(resolved_dest + os.sep): + raise ValueError( + f"Tar archive contains a path traversal entry and cannot be extracted safely: {member.name}" + ) + # All members validated; safe to extract + tar.extractall(resolved_dest) class CommonRuntimeHelper: COMMON_RUNTIME_BOOTSTRAPPER_INFO = "common_runtime_bootstrapper_info.json" @@ -266,8 +306,7 @@ def copy_bootstrapper_from_container(self, container: "docker.models.containers. for chunk in data_stream: f.write(chunk) with tarfile.open(tar_file, mode="r") as tar: - for file_name in tar.getnames(): - tar.extract(file_name, os.path.dirname(path_in_host)) + _safe_tar_extractall(tar, os.path.dirname(path_in_host)) os.remove(tar_file) except docker.errors.APIError as e: msg = f"Copying {path_in_container} from container has failed. Detailed message: {e}" diff --git a/sdk/ml/azure-ai-ml/python b/sdk/ml/azure-ai-ml/python new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sdk/ml/azure-ai-ml/tests/job_common/unittests/test_local_job_invoker.py b/sdk/ml/azure-ai-ml/tests/job_common/unittests/test_local_job_invoker.py index 53f1ab6db677..79c90a33405a 100644 --- a/sdk/ml/azure-ai-ml/tests/job_common/unittests/test_local_job_invoker.py +++ b/sdk/ml/azure-ai-ml/tests/job_common/unittests/test_local_job_invoker.py @@ -1,14 +1,19 @@ +import io import os +import tarfile import tempfile from pathlib import Path +from unittest.mock import MagicMock import pytest from azure.ai.ml.operations._local_job_invoker import ( _get_creationflags_and_startupinfo_for_background_process, + _safe_tar_extractall, patch_invocation_script_serialization, + unzip_to_temporary_file, ) - +import zipfile @pytest.mark.unittest @pytest.mark.training_experiences_test @@ -61,3 +66,114 @@ def test_creation_flags(self): flags = _get_creationflags_and_startupinfo_for_background_process("linux") assert flags == {"stderr": -2, "stdin": -3, "stdout": -3} + +def _make_job_definition(name="test-run"): + job_def = MagicMock() + job_def.name = name + return job_def + +@pytest.mark.unittest +@pytest.mark.training_experiences_test +class TestUnzipPathTraversalPrevention: + """Tests for ZIP path traversal prevention in unzip_to_temporary_file.""" + + def test_normal_zip_extracts_successfully(self): + buf = io.BytesIO() + with zipfile.ZipFile(buf, "w") as zf: + zf.writestr("azureml-setup/invocation.sh", "#!/bin/bash\necho hello\n") + zf.writestr("azureml-setup/config.json", '{"key": "value"}') + zip_bytes = buf.getvalue() + + job_def = _make_job_definition("safe-run") + result = unzip_to_temporary_file(job_def, zip_bytes) + + assert result.exists() + assert (result / "azureml-setup" / "invocation.sh").exists() + assert (result / "azureml-setup" / "config.json").exists() + + def test_zip_with_path_traversal_is_rejected(self): + buf = io.BytesIO() + with zipfile.ZipFile(buf, "w") as zf: + zf.writestr("azureml-setup/invocation.sh", "#!/bin/bash\necho hello\n") + zf.writestr("../../etc/evil.sh", "#!/bin/bash\necho pwned\n") + zip_bytes = buf.getvalue() + + job_def = _make_job_definition("traversal-run") + with pytest.raises(ValueError, match="path traversal"): + unzip_to_temporary_file(job_def, zip_bytes) + + def test_zip_with_absolute_path_is_rejected(self): + buf = io.BytesIO() + with zipfile.ZipFile(buf, "w") as zf: + if os.name == "nt": + zf.writestr("C:/Windows/Temp/evil.sh", "#!/bin/bash\necho pwned\n") + else: + zf.writestr("/tmp/evil.sh", "#!/bin/bash\necho pwned\n") + zip_bytes = buf.getvalue() + + job_def = _make_job_definition("absolute-path-run") + with pytest.raises(ValueError, match="path traversal"): + unzip_to_temporary_file(job_def, zip_bytes) + + +@pytest.mark.unittest +@pytest.mark.training_experiences_test +class TestSafeTarExtract: + """Tests for tar path traversal prevention in _safe_tar_extractall.""" + + def test_normal_tar_extracts_successfully(self): + with tempfile.TemporaryDirectory() as dest: + buf = io.BytesIO() + with tarfile.open(fileobj=buf, mode="w") as tar: + data = b"normal content" + info = tarfile.TarInfo(name="vm-bootstrapper") + info.size = len(data) + tar.addfile(info, io.BytesIO(data)) + buf.seek(0) + + with tarfile.open(fileobj=buf, mode="r") as tar: + _safe_tar_extractall(tar, dest) + + assert os.path.exists(os.path.join(dest, "vm-bootstrapper")) + + def test_tar_with_path_traversal_is_rejected(self): + with tempfile.TemporaryDirectory() as dest: + buf = io.BytesIO() + with tarfile.open(fileobj=buf, mode="w") as tar: + data = b"evil content" + info = tarfile.TarInfo(name="../../evil_script.sh") + info.size = len(data) + tar.addfile(info, io.BytesIO(data)) + buf.seek(0) + + with tarfile.open(fileobj=buf, mode="r") as tar: + with pytest.raises((ValueError, Exception)): + _safe_tar_extractall(tar, dest) + + def test_tar_with_symlink_is_rejected(self): + with tempfile.TemporaryDirectory() as dest: + buf = io.BytesIO() + with tarfile.open(fileobj=buf, mode="w") as tar: + info = tarfile.TarInfo(name="evil_link") + info.type = tarfile.SYMTYPE + info.linkname = "/etc/passwd" + tar.addfile(info) + buf.seek(0) + + with tarfile.open(fileobj=buf, mode="r") as tar: + with pytest.raises((ValueError, Exception)): + _safe_tar_extractall(tar, dest) + + def test_tar_with_hardlink_is_rejected(self): + with tempfile.TemporaryDirectory() as dest: + buf = io.BytesIO() + with tarfile.open(fileobj=buf, mode="w") as tar: + info = tarfile.TarInfo(name="evil_hardlink") + info.type = tarfile.LNKTYPE + info.linkname = "/etc/shadow" + tar.addfile(info) + buf.seek(0) + + with tarfile.open(fileobj=buf, mode="r") as tar: + with pytest.raises((ValueError, Exception)): + _safe_tar_extractall(tar, dest) \ No newline at end of file From 4dbd7b22215099031f4ea878aed23a20e89861ec Mon Sep 17 00:00:00 2001 From: Ayushh Garg Date: Mon, 16 Mar 2026 01:19:52 +0530 Subject: [PATCH 2/4] ghcp changes --- .../ai/ml/operations/_local_job_invoker.py | 17 ++++++++++++--- sdk/ml/azure-ai-ml/python | 0 .../unittests/test_local_job_invoker.py | 21 ++++++++++++------- 3 files changed, 27 insertions(+), 11 deletions(-) delete mode 100644 sdk/ml/azure-ai-ml/python diff --git a/sdk/ml/azure-ai-ml/azure/ai/ml/operations/_local_job_invoker.py b/sdk/ml/azure-ai-ml/azure/ai/ml/operations/_local_job_invoker.py index c5d8b02793ac..a45bdafd051e 100644 --- a/sdk/ml/azure-ai-ml/azure/ai/ml/operations/_local_job_invoker.py +++ b/sdk/ml/azure-ai-ml/azure/ai/ml/operations/_local_job_invoker.py @@ -179,26 +179,37 @@ def _safe_tar_extractall(tar: tarfile.TarFile, dest_dir: str) -> None: On Python 3.12+, uses the built-in 'data' filter. On older versions, manually validates each member to ensure no path traversal, symlinks, - or hard links that could write outside the destination directory. + hard links, or other special entries that could write outside the + destination directory or create unsafe filesystem nodes. :param tar: An opened tarfile.TarFile object. :type tar: tarfile.TarFile :param dest_dir: The destination directory for extraction. :type dest_dir: str :raises ValueError: If a tar member would escape the destination directory - or contains a symlink/hard link. + or contains a symlink, hard link, or unsupported special entry type. """ resolved_dest = os.path.realpath(dest_dir) # Python 3.12+ has built-in data_filter for safe extraction if hasattr(tarfile, "data_filter"): - tar.extractall(resolved_dest, filter="data") + try: + tar.extractall(resolved_dest, filter="data") + except tarfile.TarError as exc: + raise ValueError(f"Failed to safely extract tar archive: {exc}") from exc else: for member in tar.getmembers(): + # Reject symbolic and hard links if member.issym() or member.islnk(): raise ValueError( f"Tar archive contains a symbolic or hard link and cannot be extracted safely: {member.name}" ) + # Reject any non-regular, non-directory entries (e.g., devices, FIFOs) + if not (member.isfile() or member.isdir()): + raise ValueError( + f"Tar archive contains an unsupported special entry type and cannot be extracted safely: " + f"{member.name}" + ) member_path = os.path.realpath(os.path.join(resolved_dest, member.name)) if member_path != resolved_dest and not member_path.startswith(resolved_dest + os.sep): raise ValueError( diff --git a/sdk/ml/azure-ai-ml/python b/sdk/ml/azure-ai-ml/python deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/sdk/ml/azure-ai-ml/tests/job_common/unittests/test_local_job_invoker.py b/sdk/ml/azure-ai-ml/tests/job_common/unittests/test_local_job_invoker.py index 79c90a33405a..0ae1f20e905f 100644 --- a/sdk/ml/azure-ai-ml/tests/job_common/unittests/test_local_job_invoker.py +++ b/sdk/ml/azure-ai-ml/tests/job_common/unittests/test_local_job_invoker.py @@ -1,7 +1,9 @@ import io import os +import shutil import tarfile import tempfile +import zipfile from pathlib import Path from unittest.mock import MagicMock @@ -13,7 +15,6 @@ patch_invocation_script_serialization, unzip_to_temporary_file, ) -import zipfile @pytest.mark.unittest @pytest.mark.training_experiences_test @@ -66,12 +67,12 @@ def test_creation_flags(self): flags = _get_creationflags_and_startupinfo_for_background_process("linux") assert flags == {"stderr": -2, "stdin": -3, "stdout": -3} - def _make_job_definition(name="test-run"): job_def = MagicMock() job_def.name = name return job_def + @pytest.mark.unittest @pytest.mark.training_experiences_test class TestUnzipPathTraversalPrevention: @@ -87,9 +88,13 @@ def test_normal_zip_extracts_successfully(self): job_def = _make_job_definition("safe-run") result = unzip_to_temporary_file(job_def, zip_bytes) - assert result.exists() - assert (result / "azureml-setup" / "invocation.sh").exists() - assert (result / "azureml-setup" / "config.json").exists() + try: + assert result.exists() + assert (result / "azureml-setup" / "invocation.sh").exists() + assert (result / "azureml-setup" / "config.json").exists() + finally: + if result.exists(): + shutil.rmtree(result, ignore_errors=True) def test_zip_with_path_traversal_is_rejected(self): buf = io.BytesIO() @@ -147,7 +152,7 @@ def test_tar_with_path_traversal_is_rejected(self): buf.seek(0) with tarfile.open(fileobj=buf, mode="r") as tar: - with pytest.raises((ValueError, Exception)): + with pytest.raises(ValueError): _safe_tar_extractall(tar, dest) def test_tar_with_symlink_is_rejected(self): @@ -161,7 +166,7 @@ def test_tar_with_symlink_is_rejected(self): buf.seek(0) with tarfile.open(fileobj=buf, mode="r") as tar: - with pytest.raises((ValueError, Exception)): + with pytest.raises(ValueError): _safe_tar_extractall(tar, dest) def test_tar_with_hardlink_is_rejected(self): @@ -175,5 +180,5 @@ def test_tar_with_hardlink_is_rejected(self): buf.seek(0) with tarfile.open(fileobj=buf, mode="r") as tar: - with pytest.raises((ValueError, Exception)): + with pytest.raises(ValueError): _safe_tar_extractall(tar, dest) \ No newline at end of file From 74f0132aacb771121cb0c899ff17322859dbb4fe Mon Sep 17 00:00:00 2001 From: Ayushh Garg Date: Mon, 16 Mar 2026 09:44:05 +0530 Subject: [PATCH 3/4] fix black formatting and pylint errors --- .../ai/ml/operations/_local_job_invoker.py | 97 ++++++++++++++----- .../unittests/test_local_job_invoker.py | 5 +- 2 files changed, 77 insertions(+), 25 deletions(-) diff --git a/sdk/ml/azure-ai-ml/azure/ai/ml/operations/_local_job_invoker.py b/sdk/ml/azure-ai-ml/azure/ai/ml/operations/_local_job_invoker.py index a45bdafd051e..497476198a25 100644 --- a/sdk/ml/azure-ai-ml/azure/ai/ml/operations/_local_job_invoker.py +++ b/sdk/ml/azure-ai-ml/azure/ai/ml/operations/_local_job_invoker.py @@ -49,7 +49,9 @@ def unzip_to_temporary_file(job_definition: JobBaseData, zip_content: Any) -> Pa if member_path != resolved_temp_dir and not str(member_path).startswith( str(resolved_temp_dir) + os.sep ): - raise ValueError(f"Zip archive contains a path traversal entry and cannot be extracted safely: {member}") + raise ValueError( + f"Zip archive contains a path traversal entry and cannot be extracted safely: {member}" + ) zip_ref.extractall(temp_dir) return temp_dir @@ -102,19 +104,25 @@ def patch_invocation_script_serialization(invocation_path: Path) -> None: if searchRes: patched_json = searchRes.group(2).replace('"', '\\"') patched_json = patched_json.replace("'", '"') - invocation_path.write_text(searchRes.group(1) + patched_json + searchRes.group(3)) + invocation_path.write_text( + searchRes.group(1) + patched_json + searchRes.group(3) + ) def invoke_command(project_temp_dir: Path) -> None: if os.name == "nt": - invocation_script = project_temp_dir / AZUREML_RUN_SETUP_DIR / INVOCATION_BAT_FILE + invocation_script = ( + project_temp_dir / AZUREML_RUN_SETUP_DIR / INVOCATION_BAT_FILE + ) # There is a bug in Execution service on the serialized json for snapshots. # This is a client-side patch until the service fixes it, at which point it should # be a no-op patch_invocation_script_serialization(invocation_script) invoked_command = ["cmd.exe", "/c", "{0}".format(invocation_script)] else: - invocation_script = project_temp_dir / AZUREML_RUN_SETUP_DIR / INVOCATION_BASH_FILE + invocation_script = ( + project_temp_dir / AZUREML_RUN_SETUP_DIR / INVOCATION_BASH_FILE + ) subprocess.check_output(["chmod", "+x", invocation_script]) invoked_command = ["/bin/bash", "-c", "{0}".format(invocation_script)] @@ -150,10 +158,12 @@ def get_execution_service_response( try: local = job_definition.properties.services.get("Local", None) - (url, encodedBody) = local.endpoint.split(EXECUTION_SERVICE_URL_KEY) + url, encodedBody = local.endpoint.split(EXECUTION_SERVICE_URL_KEY) body = urllib.parse.unquote_plus(encodedBody) body_dict: Dict = json.loads(body) - response = requests_pipeline.post(url, json=body_dict, headers={"Authorization": "Bearer " + token}) + response = requests_pipeline.post( + url, json=body_dict, headers={"Authorization": "Bearer " + token} + ) response.raise_for_status() return (response.content, body_dict.get("SnapshotId", None)) # type: ignore[return-value] except AzureError as err: @@ -174,6 +184,7 @@ def is_local_run(job_definition: JobBaseData) -> bool: local = job_definition.properties.services.get("Local", None) return local is not None and EXECUTION_SERVICE_URL_KEY in local.endpoint + def _safe_tar_extractall(tar: tarfile.TarFile, dest_dir: str) -> None: """Extract tar archive members safely, preventing path traversal (TarSlip). @@ -211,13 +222,16 @@ def _safe_tar_extractall(tar: tarfile.TarFile, dest_dir: str) -> None: f"{member.name}" ) member_path = os.path.realpath(os.path.join(resolved_dest, member.name)) - if member_path != resolved_dest and not member_path.startswith(resolved_dest + os.sep): + if member_path != resolved_dest and not member_path.startswith( + resolved_dest + os.sep + ): raise ValueError( f"Tar archive contains a path traversal entry and cannot be extracted safely: {member.name}" ) # All members validated; safe to extract tar.extractall(resolved_dest) + class CommonRuntimeHelper: COMMON_RUNTIME_BOOTSTRAPPER_INFO = "common_runtime_bootstrapper_info.json" COMMON_RUNTIME_JOB_SPEC = "common_runtime_job_spec.json" @@ -243,14 +257,18 @@ class CommonRuntimeHelper: "Unable to communicate with Docker daemon. Is Docker running/installed?\n " "For local submissions, we need to build a Docker container to run your job in.\n Detailed message: {}" ) - DOCKER_LOGIN_FAILURE_MSG = "Login to Docker registry '{}' failed. See error message: {}" + DOCKER_LOGIN_FAILURE_MSG = ( + "Login to Docker registry '{}' failed. See error message: {}" + ) BOOTSTRAP_BINARY_FAILURE_MSG = ( "Azure Common Runtime execution failed. See detailed message below for troubleshooting " "information or re-submit with flag --use-local-runtime to try running on your local runtime: {}" ) def __init__(self, job_name: str): - self.common_runtime_temp_folder = os.path.join(Path.home(), ".azureml-common-runtime", job_name) + self.common_runtime_temp_folder = os.path.join( + Path.home(), ".azureml-common-runtime", job_name + ) if os.path.exists(self.common_runtime_temp_folder): shutil.rmtree(self.common_runtime_temp_folder) Path(self.common_runtime_temp_folder).mkdir(parents=True) @@ -259,10 +277,14 @@ def __init__(self, job_name: str): CommonRuntimeHelper.VM_BOOTSTRAPPER_FILE_NAME, ) self.stdout = open( # pylint: disable=consider-using-with - os.path.join(self.common_runtime_temp_folder, "stdout"), "w+", encoding=DefaultOpenEncoding.WRITE + os.path.join(self.common_runtime_temp_folder, "stdout"), + "w+", + encoding=DefaultOpenEncoding.WRITE, ) self.stderr = open( # pylint: disable=consider-using-with - os.path.join(self.common_runtime_temp_folder, "stderr"), "w+", encoding=DefaultOpenEncoding.WRITE + os.path.join(self.common_runtime_temp_folder, "stderr"), + "w+", + encoding=DefaultOpenEncoding.WRITE, ) # Bug Item number: 2885723 @@ -294,9 +316,13 @@ def get_docker_client(self, registry: Dict) -> "docker.DockerClient": # type: i registry=registry.get("url"), ) except Exception as e: - raise RuntimeError(self.DOCKER_LOGIN_FAILURE_MSG.format(registry.get("url"), e)) from e + raise RuntimeError( + self.DOCKER_LOGIN_FAILURE_MSG.format(registry.get("url"), e) + ) from e else: - raise RuntimeError("Registry information is missing from bootstrapper configuration.") + raise RuntimeError( + "Registry information is missing from bootstrapper configuration." + ) return client @@ -323,7 +349,9 @@ def copy_bootstrapper_from_container(self, container: "docker.models.containers. msg = f"Copying {path_in_container} from container has failed. Detailed message: {e}" raise MlException(message=msg, no_personal_data_message=msg) from e - def get_common_runtime_info_from_response(self, response: Any) -> Tuple[Dict[str, str], str]: + def get_common_runtime_info_from_response( + self, response: Any + ) -> Tuple[Dict[str, str], str]: """Extract common-runtime info from Execution Service response. :param response: Content of zip file from Execution Service containing all the @@ -334,13 +362,22 @@ def get_common_runtime_info_from_response(self, response: Any) -> Tuple[Dict[str """ with zipfile.ZipFile(io.BytesIO(response)) as zip_ref: - bootstrapper_path = f"{AZUREML_RUN_SETUP_DIR}/{self.COMMON_RUNTIME_BOOTSTRAPPER_INFO}" + bootstrapper_path = ( + f"{AZUREML_RUN_SETUP_DIR}/{self.COMMON_RUNTIME_BOOTSTRAPPER_INFO}" + ) job_spec_path = f"{AZUREML_RUN_SETUP_DIR}/{self.COMMON_RUNTIME_JOB_SPEC}" - if not all(file_path in zip_ref.namelist() for file_path in [bootstrapper_path, job_spec_path]): - raise RuntimeError(f"{bootstrapper_path}, {job_spec_path} are not in the execution service response.") + if not all( + file_path in zip_ref.namelist() + for file_path in [bootstrapper_path, job_spec_path] + ): + raise RuntimeError( + f"{bootstrapper_path}, {job_spec_path} are not in the execution service response." + ) with zip_ref.open(bootstrapper_path, "r") as bootstrapper_file: - bootstrapper_json = json.loads(base64.b64decode(bootstrapper_file.read())) + bootstrapper_json = json.loads( + base64.b64decode(bootstrapper_file.read()) + ) with zip_ref.open(job_spec_path, "r") as job_spec_file: job_spec = job_spec_file.read().decode("utf-8") @@ -362,9 +399,13 @@ def get_bootstrapper_binary(self, bootstrapper_info: Dict) -> None: tag = bootstrapper_info.get("tag") if repo_prefix: - bootstrapper_image = f"{repository}/{repo_prefix}/boot/vm-bootstrapper/binimage/linux:{tag}" + bootstrapper_image = ( + f"{repository}/{repo_prefix}/boot/vm-bootstrapper/binimage/linux:{tag}" + ) else: - bootstrapper_image = f"{repository}/boot/vm-bootstrapper/binimage/linux:{tag}" + bootstrapper_image = ( + f"{repository}/boot/vm-bootstrapper/binimage/linux:{tag}" + ) try: boot_img = docker_client.images.pull(bootstrapper_image) @@ -378,7 +419,9 @@ def get_bootstrapper_binary(self, bootstrapper_info: Dict) -> None: boot_container.stop() boot_container.remove() - def execute_bootstrapper(self, bootstrapper_binary: str, job_spec: str) -> subprocess.Popen: + def execute_bootstrapper( + self, bootstrapper_binary: str, job_spec: str + ) -> subprocess.Popen: """Runs vm-bootstrapper with the job specification passed to it. This will build the Docker container, create all necessary files and directories, and run the job locally. Command args are defined by Common Runtime team here: https://msdata.visualstudio.com/Vienna/_git/vienna?path=/src/azureml- job-runtime/common- @@ -422,7 +465,9 @@ def execute_bootstrapper(self, bootstrapper_binary: str, job_spec: str) -> subpr process.kill() raise RuntimeError(LOCAL_JOB_FAILURE_MSG.format(self.stderr.read())) - def check_bootstrapper_process_status(self, bootstrapper_process: subprocess.Popen) -> Optional[int]: + def check_bootstrapper_process_status( + self, bootstrapper_process: subprocess.Popen + ) -> Optional[int]: """Check if bootstrapper process status is non-zero. :param bootstrapper_process: bootstrapper process @@ -433,7 +478,9 @@ def check_bootstrapper_process_status(self, bootstrapper_process: subprocess.Pop return_code = bootstrapper_process.poll() if return_code: self.stderr.seek(0) - raise RuntimeError(self.BOOTSTRAP_BINARY_FAILURE_MSG.format(self.stderr.read())) + raise RuntimeError( + self.BOOTSTRAP_BINARY_FAILURE_MSG.format(self.stderr.read()) + ) return return_code @@ -458,7 +505,9 @@ def start_run_if_local( :rtype: str """ token = credential.get_token(ws_base_url + "/.default").token - (zip_content, snapshot_id) = get_execution_service_response(job_definition, token, requests_pipeline) + zip_content, snapshot_id = get_execution_service_response( + job_definition, token, requests_pipeline + ) try: temp_dir = unzip_to_temporary_file(job_definition, zip_content) diff --git a/sdk/ml/azure-ai-ml/tests/job_common/unittests/test_local_job_invoker.py b/sdk/ml/azure-ai-ml/tests/job_common/unittests/test_local_job_invoker.py index 0ae1f20e905f..cca25c7313c6 100644 --- a/sdk/ml/azure-ai-ml/tests/job_common/unittests/test_local_job_invoker.py +++ b/sdk/ml/azure-ai-ml/tests/job_common/unittests/test_local_job_invoker.py @@ -16,6 +16,7 @@ unzip_to_temporary_file, ) + @pytest.mark.unittest @pytest.mark.training_experiences_test class TestLocalJobInvoker: @@ -67,6 +68,8 @@ def test_creation_flags(self): flags = _get_creationflags_and_startupinfo_for_background_process("linux") assert flags == {"stderr": -2, "stdin": -3, "stdout": -3} + + def _make_job_definition(name="test-run"): job_def = MagicMock() job_def.name = name @@ -181,4 +184,4 @@ def test_tar_with_hardlink_is_rejected(self): with tarfile.open(fileobj=buf, mode="r") as tar: with pytest.raises(ValueError): - _safe_tar_extractall(tar, dest) \ No newline at end of file + _safe_tar_extractall(tar, dest) From 5828aa69e6b3d35c8df870f0708090ff91cb06f0 Mon Sep 17 00:00:00 2001 From: Ayushh Garg Date: Wed, 18 Mar 2026 13:24:13 +0530 Subject: [PATCH 4/4] tox formatting --- .../ai/ml/operations/_local_job_invoker.py | 85 +++++-------------- 1 file changed, 21 insertions(+), 64 deletions(-) diff --git a/sdk/ml/azure-ai-ml/azure/ai/ml/operations/_local_job_invoker.py b/sdk/ml/azure-ai-ml/azure/ai/ml/operations/_local_job_invoker.py index 497476198a25..f901d2295817 100644 --- a/sdk/ml/azure-ai-ml/azure/ai/ml/operations/_local_job_invoker.py +++ b/sdk/ml/azure-ai-ml/azure/ai/ml/operations/_local_job_invoker.py @@ -46,9 +46,7 @@ def unzip_to_temporary_file(job_definition: JobBaseData, zip_content: Any) -> Pa for member in zip_ref.namelist(): member_path = (resolved_temp_dir / member).resolve() # Ensure the member extracts within temp_dir (allow temp_dir itself for directory entries) - if member_path != resolved_temp_dir and not str(member_path).startswith( - str(resolved_temp_dir) + os.sep - ): + if member_path != resolved_temp_dir and not str(member_path).startswith(str(resolved_temp_dir) + os.sep): raise ValueError( f"Zip archive contains a path traversal entry and cannot be extracted safely: {member}" ) @@ -104,25 +102,19 @@ def patch_invocation_script_serialization(invocation_path: Path) -> None: if searchRes: patched_json = searchRes.group(2).replace('"', '\\"') patched_json = patched_json.replace("'", '"') - invocation_path.write_text( - searchRes.group(1) + patched_json + searchRes.group(3) - ) + invocation_path.write_text(searchRes.group(1) + patched_json + searchRes.group(3)) def invoke_command(project_temp_dir: Path) -> None: if os.name == "nt": - invocation_script = ( - project_temp_dir / AZUREML_RUN_SETUP_DIR / INVOCATION_BAT_FILE - ) + invocation_script = project_temp_dir / AZUREML_RUN_SETUP_DIR / INVOCATION_BAT_FILE # There is a bug in Execution service on the serialized json for snapshots. # This is a client-side patch until the service fixes it, at which point it should # be a no-op patch_invocation_script_serialization(invocation_script) invoked_command = ["cmd.exe", "/c", "{0}".format(invocation_script)] else: - invocation_script = ( - project_temp_dir / AZUREML_RUN_SETUP_DIR / INVOCATION_BASH_FILE - ) + invocation_script = project_temp_dir / AZUREML_RUN_SETUP_DIR / INVOCATION_BASH_FILE subprocess.check_output(["chmod", "+x", invocation_script]) invoked_command = ["/bin/bash", "-c", "{0}".format(invocation_script)] @@ -161,9 +153,7 @@ def get_execution_service_response( url, encodedBody = local.endpoint.split(EXECUTION_SERVICE_URL_KEY) body = urllib.parse.unquote_plus(encodedBody) body_dict: Dict = json.loads(body) - response = requests_pipeline.post( - url, json=body_dict, headers={"Authorization": "Bearer " + token} - ) + response = requests_pipeline.post(url, json=body_dict, headers={"Authorization": "Bearer " + token}) response.raise_for_status() return (response.content, body_dict.get("SnapshotId", None)) # type: ignore[return-value] except AzureError as err: @@ -222,9 +212,7 @@ def _safe_tar_extractall(tar: tarfile.TarFile, dest_dir: str) -> None: f"{member.name}" ) member_path = os.path.realpath(os.path.join(resolved_dest, member.name)) - if member_path != resolved_dest and not member_path.startswith( - resolved_dest + os.sep - ): + if member_path != resolved_dest and not member_path.startswith(resolved_dest + os.sep): raise ValueError( f"Tar archive contains a path traversal entry and cannot be extracted safely: {member.name}" ) @@ -257,18 +245,14 @@ class CommonRuntimeHelper: "Unable to communicate with Docker daemon. Is Docker running/installed?\n " "For local submissions, we need to build a Docker container to run your job in.\n Detailed message: {}" ) - DOCKER_LOGIN_FAILURE_MSG = ( - "Login to Docker registry '{}' failed. See error message: {}" - ) + DOCKER_LOGIN_FAILURE_MSG = "Login to Docker registry '{}' failed. See error message: {}" BOOTSTRAP_BINARY_FAILURE_MSG = ( "Azure Common Runtime execution failed. See detailed message below for troubleshooting " "information or re-submit with flag --use-local-runtime to try running on your local runtime: {}" ) def __init__(self, job_name: str): - self.common_runtime_temp_folder = os.path.join( - Path.home(), ".azureml-common-runtime", job_name - ) + self.common_runtime_temp_folder = os.path.join(Path.home(), ".azureml-common-runtime", job_name) if os.path.exists(self.common_runtime_temp_folder): shutil.rmtree(self.common_runtime_temp_folder) Path(self.common_runtime_temp_folder).mkdir(parents=True) @@ -316,13 +300,9 @@ def get_docker_client(self, registry: Dict) -> "docker.DockerClient": # type: i registry=registry.get("url"), ) except Exception as e: - raise RuntimeError( - self.DOCKER_LOGIN_FAILURE_MSG.format(registry.get("url"), e) - ) from e + raise RuntimeError(self.DOCKER_LOGIN_FAILURE_MSG.format(registry.get("url"), e)) from e else: - raise RuntimeError( - "Registry information is missing from bootstrapper configuration." - ) + raise RuntimeError("Registry information is missing from bootstrapper configuration.") return client @@ -349,9 +329,7 @@ def copy_bootstrapper_from_container(self, container: "docker.models.containers. msg = f"Copying {path_in_container} from container has failed. Detailed message: {e}" raise MlException(message=msg, no_personal_data_message=msg) from e - def get_common_runtime_info_from_response( - self, response: Any - ) -> Tuple[Dict[str, str], str]: + def get_common_runtime_info_from_response(self, response: Any) -> Tuple[Dict[str, str], str]: """Extract common-runtime info from Execution Service response. :param response: Content of zip file from Execution Service containing all the @@ -362,22 +340,13 @@ def get_common_runtime_info_from_response( """ with zipfile.ZipFile(io.BytesIO(response)) as zip_ref: - bootstrapper_path = ( - f"{AZUREML_RUN_SETUP_DIR}/{self.COMMON_RUNTIME_BOOTSTRAPPER_INFO}" - ) + bootstrapper_path = f"{AZUREML_RUN_SETUP_DIR}/{self.COMMON_RUNTIME_BOOTSTRAPPER_INFO}" job_spec_path = f"{AZUREML_RUN_SETUP_DIR}/{self.COMMON_RUNTIME_JOB_SPEC}" - if not all( - file_path in zip_ref.namelist() - for file_path in [bootstrapper_path, job_spec_path] - ): - raise RuntimeError( - f"{bootstrapper_path}, {job_spec_path} are not in the execution service response." - ) + if not all(file_path in zip_ref.namelist() for file_path in [bootstrapper_path, job_spec_path]): + raise RuntimeError(f"{bootstrapper_path}, {job_spec_path} are not in the execution service response.") with zip_ref.open(bootstrapper_path, "r") as bootstrapper_file: - bootstrapper_json = json.loads( - base64.b64decode(bootstrapper_file.read()) - ) + bootstrapper_json = json.loads(base64.b64decode(bootstrapper_file.read())) with zip_ref.open(job_spec_path, "r") as job_spec_file: job_spec = job_spec_file.read().decode("utf-8") @@ -399,13 +368,9 @@ def get_bootstrapper_binary(self, bootstrapper_info: Dict) -> None: tag = bootstrapper_info.get("tag") if repo_prefix: - bootstrapper_image = ( - f"{repository}/{repo_prefix}/boot/vm-bootstrapper/binimage/linux:{tag}" - ) + bootstrapper_image = f"{repository}/{repo_prefix}/boot/vm-bootstrapper/binimage/linux:{tag}" else: - bootstrapper_image = ( - f"{repository}/boot/vm-bootstrapper/binimage/linux:{tag}" - ) + bootstrapper_image = f"{repository}/boot/vm-bootstrapper/binimage/linux:{tag}" try: boot_img = docker_client.images.pull(bootstrapper_image) @@ -419,9 +384,7 @@ def get_bootstrapper_binary(self, bootstrapper_info: Dict) -> None: boot_container.stop() boot_container.remove() - def execute_bootstrapper( - self, bootstrapper_binary: str, job_spec: str - ) -> subprocess.Popen: + def execute_bootstrapper(self, bootstrapper_binary: str, job_spec: str) -> subprocess.Popen: """Runs vm-bootstrapper with the job specification passed to it. This will build the Docker container, create all necessary files and directories, and run the job locally. Command args are defined by Common Runtime team here: https://msdata.visualstudio.com/Vienna/_git/vienna?path=/src/azureml- job-runtime/common- @@ -465,9 +428,7 @@ def execute_bootstrapper( process.kill() raise RuntimeError(LOCAL_JOB_FAILURE_MSG.format(self.stderr.read())) - def check_bootstrapper_process_status( - self, bootstrapper_process: subprocess.Popen - ) -> Optional[int]: + def check_bootstrapper_process_status(self, bootstrapper_process: subprocess.Popen) -> Optional[int]: """Check if bootstrapper process status is non-zero. :param bootstrapper_process: bootstrapper process @@ -478,9 +439,7 @@ def check_bootstrapper_process_status( return_code = bootstrapper_process.poll() if return_code: self.stderr.seek(0) - raise RuntimeError( - self.BOOTSTRAP_BINARY_FAILURE_MSG.format(self.stderr.read()) - ) + raise RuntimeError(self.BOOTSTRAP_BINARY_FAILURE_MSG.format(self.stderr.read())) return return_code @@ -505,9 +464,7 @@ def start_run_if_local( :rtype: str """ token = credential.get_token(ws_base_url + "/.default").token - zip_content, snapshot_id = get_execution_service_response( - job_definition, token, requests_pipeline - ) + zip_content, snapshot_id = get_execution_service_response(job_definition, token, requests_pipeline) try: temp_dir = unzip_to_temporary_file(job_definition, zip_content)