diff --git a/sdk/ml/azure-ai-ml/azure/ai/ml/operations/_local_job_invoker.py b/sdk/ml/azure-ai-ml/azure/ai/ml/operations/_local_job_invoker.py index b7436d307902..f901d2295817 100644 --- a/sdk/ml/azure-ai-ml/azure/ai/ml/operations/_local_job_invoker.py +++ b/sdk/ml/azure-ai-ml/azure/ai/ml/operations/_local_job_invoker.py @@ -41,7 +41,15 @@ def unzip_to_temporary_file(job_definition: JobBaseData, zip_content: Any) -> Path: temp_dir = Path(tempfile.gettempdir(), AZUREML_RUNS_DIR, job_definition.name) temp_dir.mkdir(parents=True, exist_ok=True) + resolved_temp_dir = temp_dir.resolve() with zipfile.ZipFile(io.BytesIO(zip_content)) as zip_ref: + for member in zip_ref.namelist(): + member_path = (resolved_temp_dir / member).resolve() + # Ensure the member extracts within temp_dir (allow temp_dir itself for directory entries) + if member_path != resolved_temp_dir and not str(member_path).startswith(str(resolved_temp_dir) + os.sep): + raise ValueError( + f"Zip archive contains a path traversal entry and cannot be extracted safely: {member}" + ) zip_ref.extractall(temp_dir) return temp_dir @@ -142,7 +150,7 @@ def get_execution_service_response( try: local = job_definition.properties.services.get("Local", None) - (url, encodedBody) = local.endpoint.split(EXECUTION_SERVICE_URL_KEY) + url, encodedBody = local.endpoint.split(EXECUTION_SERVICE_URL_KEY) body = urllib.parse.unquote_plus(encodedBody) body_dict: Dict = json.loads(body) response = requests_pipeline.post(url, json=body_dict, headers={"Authorization": "Bearer " + token}) @@ -167,6 +175,51 @@ def is_local_run(job_definition: JobBaseData) -> bool: return local is not None and EXECUTION_SERVICE_URL_KEY in local.endpoint +def _safe_tar_extractall(tar: tarfile.TarFile, dest_dir: str) -> None: + """Extract tar archive members safely, preventing path traversal (TarSlip). + + On Python 3.12+, uses the built-in 'data' filter. On older versions, + manually validates each member to ensure no path traversal, symlinks, + hard links, or other special entries that could write outside the + destination directory or create unsafe filesystem nodes. + + :param tar: An opened tarfile.TarFile object. + :type tar: tarfile.TarFile + :param dest_dir: The destination directory for extraction. + :type dest_dir: str + :raises ValueError: If a tar member would escape the destination directory + or contains a symlink, hard link, or unsupported special entry type. + """ + resolved_dest = os.path.realpath(dest_dir) + + # Python 3.12+ has built-in data_filter for safe extraction + if hasattr(tarfile, "data_filter"): + try: + tar.extractall(resolved_dest, filter="data") + except tarfile.TarError as exc: + raise ValueError(f"Failed to safely extract tar archive: {exc}") from exc + else: + for member in tar.getmembers(): + # Reject symbolic and hard links + if member.issym() or member.islnk(): + raise ValueError( + f"Tar archive contains a symbolic or hard link and cannot be extracted safely: {member.name}" + ) + # Reject any non-regular, non-directory entries (e.g., devices, FIFOs) + if not (member.isfile() or member.isdir()): + raise ValueError( + f"Tar archive contains an unsupported special entry type and cannot be extracted safely: " + f"{member.name}" + ) + member_path = os.path.realpath(os.path.join(resolved_dest, member.name)) + if member_path != resolved_dest and not member_path.startswith(resolved_dest + os.sep): + raise ValueError( + f"Tar archive contains a path traversal entry and cannot be extracted safely: {member.name}" + ) + # All members validated; safe to extract + tar.extractall(resolved_dest) + + class CommonRuntimeHelper: COMMON_RUNTIME_BOOTSTRAPPER_INFO = "common_runtime_bootstrapper_info.json" COMMON_RUNTIME_JOB_SPEC = "common_runtime_job_spec.json" @@ -208,10 +261,14 @@ def __init__(self, job_name: str): CommonRuntimeHelper.VM_BOOTSTRAPPER_FILE_NAME, ) self.stdout = open( # pylint: disable=consider-using-with - os.path.join(self.common_runtime_temp_folder, "stdout"), "w+", encoding=DefaultOpenEncoding.WRITE + os.path.join(self.common_runtime_temp_folder, "stdout"), + "w+", + encoding=DefaultOpenEncoding.WRITE, ) self.stderr = open( # pylint: disable=consider-using-with - os.path.join(self.common_runtime_temp_folder, "stderr"), "w+", encoding=DefaultOpenEncoding.WRITE + os.path.join(self.common_runtime_temp_folder, "stderr"), + "w+", + encoding=DefaultOpenEncoding.WRITE, ) # Bug Item number: 2885723 @@ -266,8 +323,7 @@ def copy_bootstrapper_from_container(self, container: "docker.models.containers. for chunk in data_stream: f.write(chunk) with tarfile.open(tar_file, mode="r") as tar: - for file_name in tar.getnames(): - tar.extract(file_name, os.path.dirname(path_in_host)) + _safe_tar_extractall(tar, os.path.dirname(path_in_host)) os.remove(tar_file) except docker.errors.APIError as e: msg = f"Copying {path_in_container} from container has failed. Detailed message: {e}" @@ -408,7 +464,7 @@ def start_run_if_local( :rtype: str """ token = credential.get_token(ws_base_url + "/.default").token - (zip_content, snapshot_id) = get_execution_service_response(job_definition, token, requests_pipeline) + zip_content, snapshot_id = get_execution_service_response(job_definition, token, requests_pipeline) try: temp_dir = unzip_to_temporary_file(job_definition, zip_content) diff --git a/sdk/ml/azure-ai-ml/tests/job_common/unittests/test_local_job_invoker.py b/sdk/ml/azure-ai-ml/tests/job_common/unittests/test_local_job_invoker.py index 53f1ab6db677..cca25c7313c6 100644 --- a/sdk/ml/azure-ai-ml/tests/job_common/unittests/test_local_job_invoker.py +++ b/sdk/ml/azure-ai-ml/tests/job_common/unittests/test_local_job_invoker.py @@ -1,12 +1,19 @@ +import io import os +import shutil +import tarfile import tempfile +import zipfile from pathlib import Path +from unittest.mock import MagicMock import pytest from azure.ai.ml.operations._local_job_invoker import ( _get_creationflags_and_startupinfo_for_background_process, + _safe_tar_extractall, patch_invocation_script_serialization, + unzip_to_temporary_file, ) @@ -61,3 +68,120 @@ def test_creation_flags(self): flags = _get_creationflags_and_startupinfo_for_background_process("linux") assert flags == {"stderr": -2, "stdin": -3, "stdout": -3} + + +def _make_job_definition(name="test-run"): + job_def = MagicMock() + job_def.name = name + return job_def + + +@pytest.mark.unittest +@pytest.mark.training_experiences_test +class TestUnzipPathTraversalPrevention: + """Tests for ZIP path traversal prevention in unzip_to_temporary_file.""" + + def test_normal_zip_extracts_successfully(self): + buf = io.BytesIO() + with zipfile.ZipFile(buf, "w") as zf: + zf.writestr("azureml-setup/invocation.sh", "#!/bin/bash\necho hello\n") + zf.writestr("azureml-setup/config.json", '{"key": "value"}') + zip_bytes = buf.getvalue() + + job_def = _make_job_definition("safe-run") + result = unzip_to_temporary_file(job_def, zip_bytes) + + try: + assert result.exists() + assert (result / "azureml-setup" / "invocation.sh").exists() + assert (result / "azureml-setup" / "config.json").exists() + finally: + if result.exists(): + shutil.rmtree(result, ignore_errors=True) + + def test_zip_with_path_traversal_is_rejected(self): + buf = io.BytesIO() + with zipfile.ZipFile(buf, "w") as zf: + zf.writestr("azureml-setup/invocation.sh", "#!/bin/bash\necho hello\n") + zf.writestr("../../etc/evil.sh", "#!/bin/bash\necho pwned\n") + zip_bytes = buf.getvalue() + + job_def = _make_job_definition("traversal-run") + with pytest.raises(ValueError, match="path traversal"): + unzip_to_temporary_file(job_def, zip_bytes) + + def test_zip_with_absolute_path_is_rejected(self): + buf = io.BytesIO() + with zipfile.ZipFile(buf, "w") as zf: + if os.name == "nt": + zf.writestr("C:/Windows/Temp/evil.sh", "#!/bin/bash\necho pwned\n") + else: + zf.writestr("/tmp/evil.sh", "#!/bin/bash\necho pwned\n") + zip_bytes = buf.getvalue() + + job_def = _make_job_definition("absolute-path-run") + with pytest.raises(ValueError, match="path traversal"): + unzip_to_temporary_file(job_def, zip_bytes) + + +@pytest.mark.unittest +@pytest.mark.training_experiences_test +class TestSafeTarExtract: + """Tests for tar path traversal prevention in _safe_tar_extractall.""" + + def test_normal_tar_extracts_successfully(self): + with tempfile.TemporaryDirectory() as dest: + buf = io.BytesIO() + with tarfile.open(fileobj=buf, mode="w") as tar: + data = b"normal content" + info = tarfile.TarInfo(name="vm-bootstrapper") + info.size = len(data) + tar.addfile(info, io.BytesIO(data)) + buf.seek(0) + + with tarfile.open(fileobj=buf, mode="r") as tar: + _safe_tar_extractall(tar, dest) + + assert os.path.exists(os.path.join(dest, "vm-bootstrapper")) + + def test_tar_with_path_traversal_is_rejected(self): + with tempfile.TemporaryDirectory() as dest: + buf = io.BytesIO() + with tarfile.open(fileobj=buf, mode="w") as tar: + data = b"evil content" + info = tarfile.TarInfo(name="../../evil_script.sh") + info.size = len(data) + tar.addfile(info, io.BytesIO(data)) + buf.seek(0) + + with tarfile.open(fileobj=buf, mode="r") as tar: + with pytest.raises(ValueError): + _safe_tar_extractall(tar, dest) + + def test_tar_with_symlink_is_rejected(self): + with tempfile.TemporaryDirectory() as dest: + buf = io.BytesIO() + with tarfile.open(fileobj=buf, mode="w") as tar: + info = tarfile.TarInfo(name="evil_link") + info.type = tarfile.SYMTYPE + info.linkname = "/etc/passwd" + tar.addfile(info) + buf.seek(0) + + with tarfile.open(fileobj=buf, mode="r") as tar: + with pytest.raises(ValueError): + _safe_tar_extractall(tar, dest) + + def test_tar_with_hardlink_is_rejected(self): + with tempfile.TemporaryDirectory() as dest: + buf = io.BytesIO() + with tarfile.open(fileobj=buf, mode="w") as tar: + info = tarfile.TarInfo(name="evil_hardlink") + info.type = tarfile.LNKTYPE + info.linkname = "/etc/shadow" + tar.addfile(info) + buf.seek(0) + + with tarfile.open(fileobj=buf, mode="r") as tar: + with pytest.raises(ValueError): + _safe_tar_extractall(tar, dest)