diff --git a/src/benchflow/_snapshot.py b/src/benchflow/_snapshot.py index 1cabc33..a2b7699 100644 --- a/src/benchflow/_snapshot.py +++ b/src/benchflow/_snapshot.py @@ -13,6 +13,8 @@ """ import logging +import shlex +import re as _re from pathlib import PurePosixPath logger = logging.getLogger(__name__) @@ -26,16 +28,14 @@ async def snapshot(env, name: str, workspace: str = "/app") -> str: Returns a reference string suitable for restore() and for recording in trial metadata / rewards.jsonl. """ - import re - - if not re.match(r"^[a-zA-Z0-9_-]+$", name): + if not _re.match(r"^[a-zA-Z0-9_-]+$", name): raise ValueError( f"Snapshot name must be alphanumeric/dash/underscore, got: {name!r}" ) await env.exec(f"mkdir -p {_SNAP_DIR}") snap_path = f"{_SNAP_DIR}/{name}.tar.gz" result = await env.exec( - f"tar czf {snap_path} -C {workspace} .", + f"tar czf {shlex.quote(snap_path)} -C {shlex.quote(workspace)} .", timeout_sec=120, ) if result.return_code != 0: @@ -54,12 +54,17 @@ async def restore(env, ref: str, workspace: str = "/app") -> None: if len(parts) != 3 or parts[0] != "fs": raise ValueError(f"invalid snapshot ref: {ref}") snap_path = parts[2] - check = await env.exec(f"test -f {snap_path} && echo ok || echo missing") + # Validate snap_path: must be under _SNAP_DIR and a .tar.gz file + if not snap_path.startswith(_SNAP_DIR + "/") or not snap_path.endswith(".tar.gz"): + raise ValueError(f"invalid snapshot ref: path must be under {_SNAP_DIR}") + if ".." in snap_path.split("/"): + raise ValueError("invalid snapshot ref: path traversal not allowed") + check = await env.exec(f"test -f {shlex.quote(snap_path)} && echo ok || echo missing") if "missing" in (check.stdout or ""): raise FileNotFoundError(f"snapshot not found: {snap_path}") result = await env.exec( - f"rm -rf {workspace}/* {workspace}/.[!.]* 2>/dev/null; " - f"tar xzf {snap_path} -C {workspace}", + f"rm -rf {shlex.quote(workspace)}/* {shlex.quote(workspace)}/.[!.]* 2>/dev/null; " + f"tar xzf {shlex.quote(snap_path)} -C {shlex.quote(workspace)}", timeout_sec=120, ) if result.return_code != 0: