Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ Notes:
* dev: `ocotilloapi_dev`
* test: `ocotilloapi_test` (created by init SQL in `docker/db/init/01-create-test-db.sql`)
* The database listens on port `5432` both inside the container and on your host. Ensure `POSTGRES_PORT=5432` and `POSTGRES_DB=ocotilloapi_dev` in your `.env` to run local commands against the Docker dev DB (e.g., `uv run pytest`, `uv run python -m transfers.transfer`).
* To restore a local or GCS-backed SQL dump into your local target DB, run `source .venv/bin/activate && python -m cli.cli restore-local-db path/to/dump.sql` or `source .venv/bin/activate && python -m cli.cli restore-local-db gs://ocotillo/sql-exports/latest.sql.gz`.
* `SESSION_SECRET_KEY` only needs to be set in `.env` if you plan to use `/admin`; without it, the API and `/ogcapi` still boot, but `/admin` will be unavailable.

#### Staging Data
Expand Down
2 changes: 2 additions & 0 deletions cli/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ python -m cli.cli --help

## Common commands

- `python -m cli.cli restore-local-db path/to/dump.sql`
- `python -m cli.cli restore-local-db gs://ocotillo/sql-exports/latest.sql.gz`
- `python -m cli.cli transfer-results`
- `python -m cli.cli compare-duplicated-welldata`
- `python -m cli.cli alembic-upgrade-and-data`
Expand Down
30 changes: 30 additions & 0 deletions cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,36 @@ def associate_assets_command(
associate_assets(root_directory)


# CLI entry point for `restore-local-db`: hands the dump location to
# cli.db_restore and reports the outcome on the terminal.
# (Kept as comments rather than a docstring so the generated --help text is
# unchanged.)
@cli.command("restore-local-db")
def restore_local_db(
    source: str = typer.Argument(
        ...,
        help="Local .sql/.sql.gz path or gs://bucket/path.sql[.gz] URI.",
    ),
    db_name: str | None = typer.Option(
        None,
        "--db-name",
        help="Override POSTGRES_DB for the restore target.",
    ),
    theme: ThemeMode = typer.Option(
        ThemeMode.auto, "--theme", help="Color theme: auto, light, dark."
    ),
):
    # NOTE(review): `theme` is accepted but never read in this body --
    # presumably kept for option parity with the other commands; confirm.

    # Imported inside the command rather than at module import time.
    from cli.db_restore import LocalDbRestoreError, restore_local_db_from_sql

    try:
        outcome = restore_local_db_from_sql(source, db_name=db_name)
    except LocalDbRestoreError as error:
        # Expected failures become a one-line message on stderr and exit code 1.
        typer.echo(str(error), err=True)
        raise typer.Exit(code=1) from error

    endpoint = f"{outcome.host}:{outcome.port}"
    typer.echo(
        f"Restored {outcome.source} into {outcome.db_name} "
        f"on {endpoint} as {outcome.user}."
    )


@cli.command("transfer-results")
def transfer_results(
summary_path: Path = typer.Option(
Expand Down
243 changes: 243 additions & 0 deletions cli/db_restore.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,243 @@
import getpass
import gzip
import os
import re
import subprocess
import tempfile
from contextlib import contextmanager
from dataclasses import dataclass
from pathlib import Path

from db.engine import engine, session_ctx
from db.initialization import recreate_public_schema
from services.gcs_helper import get_storage_bucket

# Host names that _resolve_restore_target accepts as a local restore target.
# NOTE(review): "db" is presumably the Docker Compose service name -- confirm.
LOCAL_POSTGRES_HOSTS = {"localhost", "127.0.0.1", "::1", "db"}

# Case-insensitive, start-of-line matchers for statements that name a database
# role (ownership, grants, session role switches). _sanitize_sql_dump drops
# every line that one of these matches.
ROLE_DEPENDENT_SQL_PATTERNS = tuple(
    re.compile(expression, re.IGNORECASE)
    for expression in (
        r"^\s*SET\s+ROLE\b",
        r"^\s*SET\s+SESSION\s+AUTHORIZATION\b",
        r"^\s*ALTER\s+.*\s+OWNER\s+TO\b",
        r"^\s*GRANT\b",
        r"^\s*REVOKE\b",
        r"^\s*ALTER\s+DEFAULT\s+PRIVILEGES\b",
    )
)


class LocalDbRestoreError(RuntimeError):
    """Signal that a local database restore could not be carried out safely.

    The helpers in this module raise it for every failure they handle
    explicitly, so a caller can catch this one type and show ``str(exc)``.
    """


@dataclass(frozen=True)
class LocalDbRestoreResult:
sql_file: Path
Copy link

Copilot AI Mar 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LocalDbRestoreResult declares a required sql_file: Path, but restore_local_db_from_sql() never populates it. As written this will raise a TypeError at runtime when constructing the dataclass, breaking the CLI and tests. Either remove sql_file from the result type, or make it optional / define what it should represent (note the staged file is in a TemporaryDirectory and will not exist after the context exits).

Suggested change
sql_file: Path

Copilot uses AI. Check for mistakes.
source: str
host: str
port: str
user: str
db_name: str


def _is_gcs_uri(source: str) -> bool:
return source.startswith("gs://")


def _parse_gcs_uri(source: str) -> tuple[str, str]:
    """Split a ``gs://bucket/object`` URI into ``(bucket name, object name)``.

    Raises:
        LocalDbRestoreError: if *source* lacks the ``gs://`` prefix, or if
            either the bucket or the object part is empty.
    """
    if not _is_gcs_uri(source):
        raise LocalDbRestoreError(f"Expected gs:// URI, got {source!r}.")

    # Strip the scheme, then split on the first slash only, so the object
    # name keeps any further slashes.
    bucket_name, _, blob_name = source[len("gs://") :].partition("/")
    if bucket_name and blob_name:
        return bucket_name, blob_name

    raise LocalDbRestoreError(
        f"Invalid GCS URI {source!r}; expected gs://bucket/path.sql[.gz]."
    )


def _validate_restore_source_name(source_name: str) -> None:
if source_name.endswith(".sql") or source_name.endswith(".sql.gz"):
return

raise LocalDbRestoreError(
"restore-local-db requires a .sql or .sql.gz source; " f"got {source_name!r}."
)


def _decompress_gzip_file(source_path: Path, target_path: Path) -> None:
try:
with gzip.open(source_path, "rb") as compressed:
with open(target_path, "wb") as expanded:
while chunk := compressed.read(1024 * 1024):
expanded.write(chunk)
except (OSError, gzip.BadGzipFile) as exc:
raise LocalDbRestoreError(
f"Failed to decompress gzip source {source_path!r}: "
"file is not a valid gzip-compressed SQL dump or is corrupted."
) from exc


def _sanitize_sql_dump(source_path: Path, target_path: Path) -> None:
try:
with open(source_path, "r", encoding="utf-8") as infile:
with open(target_path, "w", encoding="utf-8") as outfile:
for line in infile:
if any(
pattern.search(line) for pattern in ROLE_DEPENDENT_SQL_PATTERNS
):
continue
outfile.write(line)
except UnicodeError as exc:
raise LocalDbRestoreError(
f"Failed to read SQL dump {source_path} as UTF-8. "
"Ensure the dump file is UTF-8 encoded and not truncated."
) from exc
except OSError as exc:
raise LocalDbRestoreError(
f"I/O error while processing SQL dump {source_path} -> {target_path}: {exc}"
) from exc


@contextmanager
def _stage_restore_source(source: str | Path):
    """Yield ``(staged_sql_path, source_description)`` for a restore source.

    The source is a local ``.sql``/``.sql.gz`` path or a ``gs://`` URI. It is
    downloaded if remote, decompressed if gzipped, and sanitized into
    ``restore.sql`` inside a temporary directory. That directory is removed
    when the context exits, so the yielded path is only valid inside the
    ``with`` body.

    Raises:
        LocalDbRestoreError: for an unsupported file name, a missing or
            non-file local source, or a decompression or sanitizing failure.
    """
    source_text = str(source)
    _validate_restore_source_name(source_text)
    is_gzipped = source_text.endswith(".sql.gz")

    with tempfile.TemporaryDirectory(prefix="ocotillo-db-restore-") as temp_dir:
        temp_dir_path = Path(temp_dir)
        expanded_sql_path = temp_dir_path / "expanded.sql"
        staged_sql_path = temp_dir_path / "restore.sql"

        if _is_gcs_uri(source_text):
            bucket_name, blob_name = _parse_gcs_uri(source_text)
            bucket = get_storage_bucket(bucket=bucket_name)
            blob = bucket.blob(blob_name)
            # Download into its own subdirectory. An object named
            # "restore.sql" used to land on staged_sql_path, so the sanitizer
            # truncated its own input and psql restored an empty file.
            download_dir = temp_dir_path / "download"
            download_dir.mkdir()
            raw_path = download_dir / Path(blob_name).name
            blob.download_to_filename(str(raw_path))
            source_description = source_text
        else:
            raw_path = Path(source_text)
            if not raw_path.exists():
                raise LocalDbRestoreError(f"Restore source not found: {raw_path}")
            if not raw_path.is_file():
                raise LocalDbRestoreError(
                    f"Restore source is not a file: {raw_path}"
                )
            source_description = str(raw_path)

        if is_gzipped:
            _decompress_gzip_file(raw_path, expanded_sql_path)
        else:
            # Plain .sql input is read in place; no copy is made.
            expanded_sql_path = raw_path
        _sanitize_sql_dump(expanded_sql_path, staged_sql_path)
        yield staged_sql_path, source_description


def _resolve_restore_target(
    db_name: str | None = None,
) -> tuple[str, str, str, str, str]:
    """Work out the connection settings for the restore from the environment.

    Reads ``DB_DRIVER``, ``POSTGRES_HOST``, ``POSTGRES_PORT``,
    ``POSTGRES_USER``, ``POSTGRES_DB`` and ``POSTGRES_PASSWORD``. A non-empty
    *db_name* takes precedence over ``POSTGRES_DB``.

    Returns:
        ``(host, port, user, target_db, password)``, all strings. ``password``
        is ``""`` when ``POSTGRES_PASSWORD`` is unset.

    Raises:
        LocalDbRestoreError: if ``DB_DRIVER`` is ``cloudsql``, the host is not
            in ``LOCAL_POSTGRES_HOSTS``, or the database name is blank.
    """

    def stripped_env(name: str) -> str:
        # An unset variable and a whitespace-only one both collapse to "".
        return (os.environ.get(name) or "").strip()

    if stripped_env("DB_DRIVER").lower() == "cloudsql":
        raise LocalDbRestoreError(
            "restore-local-db only supports local PostgreSQL targets; "
            "DB_DRIVER=cloudsql is not allowed."
        )

    host = stripped_env("POSTGRES_HOST") or "localhost"
    if host not in LOCAL_POSTGRES_HOSTS:
        allowed_hosts = ", ".join(sorted(LOCAL_POSTGRES_HOSTS))
        raise LocalDbRestoreError(
            "restore-local-db only supports local PostgreSQL hosts "
            f"({allowed_hosts}); got {host!r}."
        )

    port = stripped_env("POSTGRES_PORT") or "5432"

    # Fall back to the current OS user when no database user is configured.
    user = stripped_env("POSTGRES_USER") or getpass.getuser()

    # A whitespace-only override is truthy, so it wins the `or` chain and then
    # strips to "", which is reported as an error rather than falling back.
    target_db = (db_name or os.environ.get("POSTGRES_DB") or "postgres").strip()
    if not target_db:
        raise LocalDbRestoreError("Target database name is empty.")

    return host, port, user, target_db, os.environ.get("POSTGRES_PASSWORD", "")


def _reset_target_schema() -> None:
    """Recreate the public schema through the shared engine.

    The engine is disposed before and after the reset. Any failure, whatever
    its type, is re-raised as ``LocalDbRestoreError`` with the original
    exception chained.
    """
    try:
        engine.dispose()
        with session_ctx() as db_session:
            recreate_public_schema(db_session)
        engine.dispose()
    except Exception as error:
        failure = f"Failed to reset the public schema before restore: {error}"
        raise LocalDbRestoreError(failure) from error


def restore_local_db_from_sql(
    source_file: Path | str, *, db_name: str | None = None
) -> LocalDbRestoreResult:
    """Restore a SQL dump into the local PostgreSQL database using ``psql``.

    Steps, in order: resolve and validate the local target, stage the dump
    (download, decompress, sanitize), reset the public schema, then replay the
    staged file with ``psql -v ON_ERROR_STOP=1``. The schema is reset only
    after staging succeeds, so a bad source leaves the database untouched.

    Args:
        source_file: Local ``.sql``/``.sql.gz`` path or ``gs://`` URI.
        db_name: Optional override for ``POSTGRES_DB``.

    Returns:
        A ``LocalDbRestoreResult`` describing what was restored and where.

    Raises:
        LocalDbRestoreError: if the target is not local, the source is
            invalid, the schema reset fails, ``psql`` is missing, or ``psql``
            exits with a non-zero status.
    """
    host, port, user, target_db, password = _resolve_restore_target(db_name)
    with _stage_restore_source(source_file) as (staged_sql_file, source_description):
        # _reset_target_schema already converts every failure into
        # LocalDbRestoreError, so no extra wrapper is needed here.
        # NOTE(review): the reset goes through the shared engine and takes no
        # target argument. If that engine is not bound to `target_db` (for
        # example when db_name overrides POSTGRES_DB), the reset and the psql
        # restore could hit different databases -- confirm against db.engine.
        _reset_target_schema()

        command = [
            "psql",
            "-v",
            "ON_ERROR_STOP=1",
            "-h",
            host,
            "-p",
            port,
            "-U",
            user,
            "-d",
            target_db,
            "-f",
            str(staged_sql_file),
        ]

        # The password travels in the child's environment, not on the
        # command line.
        env = os.environ.copy()
        if password:
            env["PGPASSWORD"] = password

        try:
            subprocess.run(
                command,
                check=True,
                env=env,
                capture_output=True,
                text=True,
            )
        except FileNotFoundError as exc:
            raise LocalDbRestoreError(
                "psql is not installed or not available on PATH."
            ) from exc
        except subprocess.CalledProcessError as exc:
            detail = (exc.stderr or exc.stdout or "").strip() or str(exc)
            raise LocalDbRestoreError(
                f"Restore failed for database {target_db!r}: {detail}"
            ) from exc

    return LocalDbRestoreResult(
        source=source_description,
        host=host,
        port=port,
        user=user,
        db_name=target_db,
        # The result type declares sql_file; omitting it raised TypeError
        # after a successful restore. The staging directory is gone by now,
        # so this path is informational only.
        sql_file=staged_sql_file,
    )
Loading
Loading