Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 12 additions & 7 deletions src/benchflow/_snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
"""

import logging
import shlex
import re as _re
from pathlib import PurePosixPath

logger = logging.getLogger(__name__)
Expand All @@ -26,16 +28,14 @@ async def snapshot(env, name: str, workspace: str = "/app") -> str:
Returns a reference string suitable for restore() and for recording
in trial metadata / rewards.jsonl.
"""
import re

if not re.match(r"^[a-zA-Z0-9_-]+$", name):
if not _re.match(r"^[a-zA-Z0-9_-]+$", name):
raise ValueError(
f"Snapshot name must be alphanumeric/dash/underscore, got: {name!r}"
)
await env.exec(f"mkdir -p {_SNAP_DIR}")
snap_path = f"{_SNAP_DIR}/{name}.tar.gz"
result = await env.exec(
f"tar czf {snap_path} -C {workspace} .",
f"tar czf {shlex.quote(snap_path)} -C {shlex.quote(workspace)} .",
timeout_sec=120,
)
if result.return_code != 0:
Expand All @@ -54,12 +54,17 @@ async def restore(env, ref: str, workspace: str = "/app") -> None:
if len(parts) != 3 or parts[0] != "fs":
raise ValueError(f"invalid snapshot ref: {ref}")
snap_path = parts[2]
check = await env.exec(f"test -f {snap_path} && echo ok || echo missing")
# Validate snap_path: must be under _SNAP_DIR and a .tar.gz file
if not snap_path.startswith(_SNAP_DIR + "/") or not snap_path.endswith(".tar.gz"):
raise ValueError(f"invalid snapshot ref: path must be under {_SNAP_DIR}")
if ".." in snap_path.split("/"):
raise ValueError("invalid snapshot ref: path traversal not allowed")
check = await env.exec(f"test -f {shlex.quote(snap_path)} && echo ok || echo missing")
if "missing" in (check.stdout or ""):
raise FileNotFoundError(f"snapshot not found: {snap_path}")
result = await env.exec(
f"rm -rf {workspace}/* {workspace}/.[!.]* 2>/dev/null; "
f"tar xzf {snap_path} -C {workspace}",
f"rm -rf {shlex.quote(workspace)}/* {shlex.quote(workspace)}/.[!.]* 2>/dev/null; "
f"tar xzf {shlex.quote(snap_path)} -C {shlex.quote(workspace)}",
timeout_sec=120,
)
if result.return_code != 0:
Expand Down