Skip to content
68 changes: 62 additions & 6 deletions sdk/ml/azure-ai-ml/azure/ai/ml/operations/_local_job_invoker.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,15 @@
def unzip_to_temporary_file(job_definition: JobBaseData, zip_content: Any) -> Path:
    """Unpack an in-memory zip archive into a per-job folder under the system temp directory.

    The folder is ``<tempdir>/<AZUREML_RUNS_DIR>/<job name>`` and is created if it
    does not exist yet. Every archive member name is checked before anything is
    extracted; extraction is refused if any member would resolve to a location
    outside that folder.

    :param job_definition: Job whose ``name`` selects the extraction folder.
    :type job_definition: JobBaseData
    :param zip_content: Raw bytes of the zip archive.
    :type zip_content: Any
    :return: The folder the archive was extracted into.
    :rtype: Path
    :raises ValueError: If an archive member resolves outside the extraction folder.
    """
    extraction_root = Path(tempfile.gettempdir(), AZUREML_RUNS_DIR, job_definition.name)
    extraction_root.mkdir(parents=True, exist_ok=True)

    # Compare against the fully resolved folder so symlinked temp dirs do not
    # cause false positives in the containment check below.
    canonical_root = extraction_root.resolve()
    allowed_prefix = str(canonical_root) + os.sep

    with zipfile.ZipFile(io.BytesIO(zip_content)) as archive:
        for entry_name in archive.namelist():
            target = (canonical_root / entry_name).resolve()
            # A member may map to the folder itself (directory entry) or to
            # something strictly below it; anything else is a traversal attempt.
            stays_inside = target == canonical_root or str(target).startswith(allowed_prefix)
            if not stays_inside:
                raise ValueError(
                    f"Zip archive contains a path traversal entry and cannot be extracted safely: {entry_name}"
                )
        archive.extractall(extraction_root)
    return extraction_root

Expand Down Expand Up @@ -142,7 +150,7 @@ def get_execution_service_response(
try:
local = job_definition.properties.services.get("Local", None)

(url, encodedBody) = local.endpoint.split(EXECUTION_SERVICE_URL_KEY)
url, encodedBody = local.endpoint.split(EXECUTION_SERVICE_URL_KEY)
body = urllib.parse.unquote_plus(encodedBody)
body_dict: Dict = json.loads(body)
response = requests_pipeline.post(url, json=body_dict, headers={"Authorization": "Bearer " + token})
Expand All @@ -167,6 +175,51 @@ def is_local_run(job_definition: JobBaseData) -> bool:
return local is not None and EXECUTION_SERVICE_URL_KEY in local.endpoint


def _safe_tar_extractall(tar: tarfile.TarFile, dest_dir: str) -> None:
    """Extract tar archive members safely, preventing path traversal (TarSlip).

    Every member is validated *before* anything is written. Symbolic links,
    hard links, special entries (anything that is neither a regular file nor a
    directory) and names that resolve outside ``dest_dir`` are rejected. Because
    validation completes before extraction starts, a rejected archive leaves
    ``dest_dir`` untouched, and the set of accepted archives does not depend on
    the Python version in use.

    When the interpreter provides ``tarfile.data_filter`` (Python 3.12+, and the
    security releases of earlier versions that backported extraction filters),
    extraction additionally runs under the built-in ``'data'`` filter as defence
    in depth.

    Members are enumerated with ``getmembers()`` before extraction, so the
    archive should be opened in a seekable mode (for example ``mode="r"``).

    :param tar: An opened tarfile.TarFile object.
    :type tar: tarfile.TarFile
    :param dest_dir: The destination directory for extraction.
    :type dest_dir: str
    :raises ValueError: If a tar member would escape the destination directory,
        contains a symlink, hard link, or unsupported special entry type, or if
        the built-in extraction filter rejects the archive.
    """
    resolved_dest = os.path.realpath(dest_dir)
    dest_prefix = resolved_dest + os.sep

    # Validate the whole archive up front so that nothing is written to disk
    # when any single member is unacceptable.
    for member in tar.getmembers():
        # Reject symbolic and hard links, even ones pointing inside dest_dir
        if member.issym() or member.islnk():
            raise ValueError(
                f"Tar archive contains a symbolic or hard link and cannot be extracted safely: {member.name}"
            )
        # Reject any non-regular, non-directory entries (e.g., devices, FIFOs)
        if not (member.isfile() or member.isdir()):
            raise ValueError(
                f"Tar archive contains an unsupported special entry type and cannot be extracted safely: "
                f"{member.name}"
            )
        # An absolute member name replaces resolved_dest in the join, so it is
        # caught by the same containment check as "../" sequences.
        member_path = os.path.realpath(os.path.join(resolved_dest, member.name))
        if member_path != resolved_dest and not member_path.startswith(dest_prefix):
            raise ValueError(
                f"Tar archive contains a path traversal entry and cannot be extracted safely: {member.name}"
            )

    # All members validated; extract, layering the built-in filter on top when
    # this interpreter has one.
    if hasattr(tarfile, "data_filter"):
        try:
            tar.extractall(resolved_dest, filter="data")
        except tarfile.TarError as exc:
            raise ValueError(f"Failed to safely extract tar archive: {exc}") from exc
    else:
        tar.extractall(resolved_dest)


class CommonRuntimeHelper:
COMMON_RUNTIME_BOOTSTRAPPER_INFO = "common_runtime_bootstrapper_info.json"
COMMON_RUNTIME_JOB_SPEC = "common_runtime_job_spec.json"
Expand Down Expand Up @@ -208,10 +261,14 @@ def __init__(self, job_name: str):
CommonRuntimeHelper.VM_BOOTSTRAPPER_FILE_NAME,
)
self.stdout = open( # pylint: disable=consider-using-with
os.path.join(self.common_runtime_temp_folder, "stdout"), "w+", encoding=DefaultOpenEncoding.WRITE
os.path.join(self.common_runtime_temp_folder, "stdout"),
"w+",
encoding=DefaultOpenEncoding.WRITE,
)
self.stderr = open( # pylint: disable=consider-using-with
os.path.join(self.common_runtime_temp_folder, "stderr"), "w+", encoding=DefaultOpenEncoding.WRITE
os.path.join(self.common_runtime_temp_folder, "stderr"),
"w+",
encoding=DefaultOpenEncoding.WRITE,
)

# Bug Item number: 2885723
Expand Down Expand Up @@ -266,8 +323,7 @@ def copy_bootstrapper_from_container(self, container: "docker.models.containers.
for chunk in data_stream:
f.write(chunk)
with tarfile.open(tar_file, mode="r") as tar:
for file_name in tar.getnames():
tar.extract(file_name, os.path.dirname(path_in_host))
_safe_tar_extractall(tar, os.path.dirname(path_in_host))
os.remove(tar_file)
except docker.errors.APIError as e:
msg = f"Copying {path_in_container} from container has failed. Detailed message: {e}"
Expand Down Expand Up @@ -408,7 +464,7 @@ def start_run_if_local(
:rtype: str
"""
token = credential.get_token(ws_base_url + "/.default").token
(zip_content, snapshot_id) = get_execution_service_response(job_definition, token, requests_pipeline)
zip_content, snapshot_id = get_execution_service_response(job_definition, token, requests_pipeline)

try:
temp_dir = unzip_to_temporary_file(job_definition, zip_content)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
import io
import os
import shutil
import tarfile
import tempfile
import zipfile
from pathlib import Path
from unittest.mock import MagicMock

import pytest

from azure.ai.ml.operations._local_job_invoker import (
_get_creationflags_and_startupinfo_for_background_process,
_safe_tar_extractall,
patch_invocation_script_serialization,
unzip_to_temporary_file,
)


Expand Down Expand Up @@ -61,3 +68,120 @@ def test_creation_flags(self):
flags = _get_creationflags_and_startupinfo_for_background_process("linux")

assert flags == {"stderr": -2, "stdin": -3, "stdout": -3}


def _make_job_definition(name="test-run"):
    """Return a MagicMock standing in for a job definition with the given ``name``.

    ``name`` has to be assigned after construction: passing it to the MagicMock
    constructor would label the mock itself rather than create an attribute.
    """
    stub = MagicMock()
    stub.configure_mock(name=name)
    return stub


@pytest.mark.unittest
@pytest.mark.training_experiences_test
class TestUnzipPathTraversalPrevention:
    """Tests for ZIP path traversal prevention in unzip_to_temporary_file.

    unzip_to_temporary_file creates its output folder under
    ``tempfile.gettempdir()`` before it validates the archive, so even a
    rejected archive leaves a folder behind. Each test therefore redirects the
    temp directory to pytest's per-test ``tmp_path``: nothing is written to the
    real system temp directory, nothing leaks between runs, and fixed job names
    cannot collide across parallel test sessions.
    """

    @pytest.fixture(autouse=True)
    def _redirect_tempdir(self, tmp_path, monkeypatch):
        # tempfile.gettempdir() returns tempfile.tempdir when it is set;
        # monkeypatch restores the previous value after each test.
        monkeypatch.setattr(tempfile, "tempdir", str(tmp_path))

    @staticmethod
    def _zip_bytes(entries):
        """Build an in-memory zip from a ``{member name: text}`` mapping and return its bytes."""
        buf = io.BytesIO()
        with zipfile.ZipFile(buf, "w") as zf:
            for member_name, text in entries.items():
                zf.writestr(member_name, text)
        return buf.getvalue()

    def test_normal_zip_extracts_successfully(self, tmp_path):
        zip_bytes = self._zip_bytes(
            {
                "azureml-setup/invocation.sh": "#!/bin/bash\necho hello\n",
                "azureml-setup/config.json": '{"key": "value"}',
            }
        )

        job_def = _make_job_definition("safe-run")
        result = unzip_to_temporary_file(job_def, zip_bytes)

        assert result.exists()
        # Guard against the redirect silently not taking effect.
        assert tmp_path in result.parents
        assert (result / "azureml-setup" / "invocation.sh").exists()
        assert (result / "azureml-setup" / "config.json").exists()

    def test_zip_with_path_traversal_is_rejected(self, tmp_path):
        zip_bytes = self._zip_bytes(
            {
                "azureml-setup/invocation.sh": "#!/bin/bash\necho hello\n",
                "../../etc/evil.sh": "#!/bin/bash\necho pwned\n",
            }
        )

        job_def = _make_job_definition("traversal-run")
        with pytest.raises(ValueError, match="path traversal"):
            unzip_to_temporary_file(job_def, zip_bytes)

        # Validation runs before extraction, so not even the benign member
        # that precedes the malicious one may have been written.
        assert not list(tmp_path.rglob("*.sh"))

    def test_zip_with_absolute_path_is_rejected(self):
        if os.name == "nt":
            member_name = "C:/Windows/Temp/evil.sh"
        else:
            member_name = "/tmp/evil.sh"
        zip_bytes = self._zip_bytes({member_name: "#!/bin/bash\necho pwned\n"})

        job_def = _make_job_definition("absolute-path-run")
        with pytest.raises(ValueError, match="path traversal"):
            unzip_to_temporary_file(job_def, zip_bytes)


@pytest.mark.unittest
@pytest.mark.training_experiences_test
class TestSafeTarExtract:
    """Exercise _safe_tar_extractall against one benign and three hostile tar archives."""

    @staticmethod
    def _tar_stream(entry, payload=None):
        """Write a single TarInfo (plus optional payload) to an in-memory tar and rewind it."""
        stream = io.BytesIO()
        with tarfile.open(fileobj=stream, mode="w") as archive:
            if payload is None:
                # Link entries carry no data, only header fields.
                archive.addfile(entry)
            else:
                entry.size = len(payload)
                archive.addfile(entry, io.BytesIO(payload))
        stream.seek(0)
        return stream

    def test_normal_tar_extracts_successfully(self):
        stream = self._tar_stream(tarfile.TarInfo(name="vm-bootstrapper"), b"normal content")

        with tempfile.TemporaryDirectory() as dest:
            with tarfile.open(fileobj=stream, mode="r") as archive:
                _safe_tar_extractall(archive, dest)

            assert os.path.exists(os.path.join(dest, "vm-bootstrapper"))

    def test_tar_with_path_traversal_is_rejected(self):
        stream = self._tar_stream(tarfile.TarInfo(name="../../evil_script.sh"), b"evil content")

        with tempfile.TemporaryDirectory() as dest:
            with tarfile.open(fileobj=stream, mode="r") as archive:
                with pytest.raises(ValueError):
                    _safe_tar_extractall(archive, dest)

    def test_tar_with_symlink_is_rejected(self):
        entry = tarfile.TarInfo(name="evil_link")
        entry.type = tarfile.SYMTYPE
        entry.linkname = "/etc/passwd"
        stream = self._tar_stream(entry)

        with tempfile.TemporaryDirectory() as dest:
            with tarfile.open(fileobj=stream, mode="r") as archive:
                with pytest.raises(ValueError):
                    _safe_tar_extractall(archive, dest)

    def test_tar_with_hardlink_is_rejected(self):
        entry = tarfile.TarInfo(name="evil_hardlink")
        entry.type = tarfile.LNKTYPE
        entry.linkname = "/etc/shadow"
        stream = self._tar_stream(entry)

        with tempfile.TemporaryDirectory() as dest:
            with tarfile.open(fileobj=stream, mode="r") as archive:
                with pytest.raises(ValueError):
                    _safe_tar_extractall(archive, dest)