diff --git a/README.md b/README.md
index 44ec7bcd..d9a42a32 100644
--- a/README.md
+++ b/README.md
@@ -214,6 +214,7 @@ Notes:
   * dev: `ocotilloapi_dev`
   * test: `ocotilloapi_test` (created by init SQL in `docker/db/init/01-create-test-db.sql`)
 * The database listens on port `5432` both inside the container and on your host. Ensure `POSTGRES_PORT=5432` and `POSTGRES_DB=ocotilloapi_dev` in your `.env` to run local commands against the Docker dev DB (e.g., `uv run pytest`, `uv run python -m transfers.transfer`).
+* To restore a local or GCS-backed SQL dump into your local target DB, run `source .venv/bin/activate && python -m cli.cli restore-local-db path/to/dump.sql` or `source .venv/bin/activate && python -m cli.cli restore-local-db gs://ocotillo/sql-exports/latest.sql.gz`.
 * `SESSION_SECRET_KEY` only needs to be set in `.env` if you plan to use `/admin`; without it, the API and `/ogcapi` still boot, but `/admin` will be unavailable.
 
 #### Staging Data
diff --git a/cli/README.md b/cli/README.md
index 42d557c8..2433081c 100644
--- a/cli/README.md
+++ b/cli/README.md
@@ -15,6 +15,8 @@ python -m cli.cli --help
 
 ## Common commands
 
+- `python -m cli.cli restore-local-db path/to/dump.sql`
+- `python -m cli.cli restore-local-db gs://ocotillo/sql-exports/latest.sql.gz`
 - `python -m cli.cli transfer-results`
 - `python -m cli.cli compare-duplicated-welldata`
 - `python -m cli.cli alembic-upgrade-and-data`
diff --git a/cli/cli.py b/cli/cli.py
index 134c3538..44e9d02f 100644
--- a/cli/cli.py
+++ b/cli/cli.py
@@ -134,6 +134,34 @@ def associate_assets_command(
     associate_assets(root_directory)
 
 
+@cli.command("restore-local-db")
+def restore_local_db(
+    source: str = typer.Argument(
+        ...,
+        help="Local .sql/.sql.gz path or gs://bucket/path.sql[.gz] URI.",
+    ),
+    db_name: str | None = typer.Option(
+        None,
+        "--db-name",
+        help="Override POSTGRES_DB for the restore target.",
+    ),
+):
+    """Restore a local or GCS-hosted SQL dump into the local Postgres database."""
+    from cli.db_restore import LocalDbRestoreError, restore_local_db_from_sql
+
+    try:
+        result = restore_local_db_from_sql(source, db_name=db_name)
+    except LocalDbRestoreError as exc:
+        typer.echo(str(exc), err=True)
+        raise typer.Exit(code=1) from exc
+
+    typer.echo(
+        "Restored "
+        f"{result.source} into {result.db_name} "
+        f"on {result.host}:{result.port} as {result.user}."
+    )
+
+
 @cli.command("transfer-results")
 def transfer_results(
     summary_path: Path = typer.Option(
diff --git a/cli/db_restore.py b/cli/db_restore.py
new file mode 100644
index 00000000..c746a18e
--- /dev/null
+++ b/cli/db_restore.py
@@ -0,0 +1,246 @@
+"""Helpers for restoring SQL dumps (local files or GCS objects) into a local PostgreSQL database."""
+
+import getpass
+import gzip
+import os
+import re
+import subprocess
+import tempfile
+from contextlib import contextmanager
+from dataclasses import dataclass
+from pathlib import Path
+
+from db.engine import engine, session_ctx
+from db.initialization import recreate_public_schema
+from services.gcs_helper import get_storage_bucket
+
+LOCAL_POSTGRES_HOSTS = {"localhost", "127.0.0.1", "::1", "db"}
+ROLE_DEPENDENT_SQL_PATTERNS = (
+    re.compile(r"^\s*SET\s+ROLE\b", re.IGNORECASE),
+    re.compile(r"^\s*SET\s+SESSION\s+AUTHORIZATION\b", re.IGNORECASE),
+    re.compile(r"^\s*ALTER\s+.*\s+OWNER\s+TO\b", re.IGNORECASE),
+    re.compile(r"^\s*GRANT\b", re.IGNORECASE),
+    re.compile(r"^\s*REVOKE\b", re.IGNORECASE),
+    re.compile(r"^\s*ALTER\s+DEFAULT\s+PRIVILEGES\b", re.IGNORECASE),
+)
+
+
+class LocalDbRestoreError(RuntimeError):
+    """Raised when a local database restore cannot be performed safely."""
+
+
+@dataclass(frozen=True)
+class LocalDbRestoreResult:
+    source: str
+    host: str
+    port: str
+    user: str
+    db_name: str
+
+
+def _is_gcs_uri(source: str) -> bool:
+    return source.startswith("gs://")
+
+
+def _parse_gcs_uri(source: str) -> tuple[str, str]:
+    if not _is_gcs_uri(source):
+        raise LocalDbRestoreError(f"Expected gs:// URI, got {source!r}.")
+
+    path = source[5:]
+    bucket_name, _, blob_name = path.partition("/")
+    if not bucket_name or not blob_name:
+        raise LocalDbRestoreError(
+            f"Invalid GCS URI {source!r}; expected gs://bucket/path.sql[.gz]."
+        )
+    return bucket_name, blob_name
+
+
+def _validate_restore_source_name(source_name: str) -> None:
+    if source_name.endswith(".sql") or source_name.endswith(".sql.gz"):
+        return
+
+    raise LocalDbRestoreError(
+        "restore-local-db requires a .sql or .sql.gz source; " f"got {source_name!r}."
+    )
+
+
+def _decompress_gzip_file(source_path: Path, target_path: Path) -> None:
+    try:
+        with gzip.open(source_path, "rb") as compressed:
+            with open(target_path, "wb") as expanded:
+                while chunk := compressed.read(1024 * 1024):
+                    expanded.write(chunk)
+    except (OSError, gzip.BadGzipFile) as exc:
+        raise LocalDbRestoreError(
+            f"Failed to decompress gzip source {source_path!r}: "
+            "file is not a valid gzip-compressed SQL dump or is corrupted."
+        ) from exc
+
+
+def _sanitize_sql_dump(source_path: Path, target_path: Path) -> None:
+    try:
+        with open(source_path, "r", encoding="utf-8") as infile:
+            with open(target_path, "w", encoding="utf-8") as outfile:
+                for line in infile:
+                    if any(
+                        pattern.search(line) for pattern in ROLE_DEPENDENT_SQL_PATTERNS
+                    ):
+                        continue
+                    outfile.write(line)
+    except UnicodeError as exc:
+        raise LocalDbRestoreError(
+            f"Failed to read SQL dump {source_path} as UTF-8. "
+            "Ensure the dump file is UTF-8 encoded and not truncated."
+        ) from exc
+    except OSError as exc:
+        raise LocalDbRestoreError(
+            f"I/O error while processing SQL dump {source_path} -> {target_path}: {exc}"
+        ) from exc
+
+
+@contextmanager
+def _stage_restore_source(source: str | Path):
+    source_text = str(source)
+    _validate_restore_source_name(source_text)
+
+    with tempfile.TemporaryDirectory(prefix="ocotillo-db-restore-") as temp_dir:
+        temp_dir_path = Path(temp_dir)
+        expanded_sql_path = temp_dir_path / "expanded.sql"
+        staged_sql_path = temp_dir_path / "restore.sql"
+
+        if _is_gcs_uri(source_text):
+            bucket_name, blob_name = _parse_gcs_uri(source_text)
+            bucket = get_storage_bucket(bucket=bucket_name)
+            blob = bucket.blob(blob_name)
+            # Fixed name so a blob whose basename is restore.sql/expanded.sql
+            # cannot collide with the staging paths above.
+            downloaded_path = temp_dir_path / f"download{''.join(Path(blob_name).suffixes)}"
+            blob.download_to_filename(str(downloaded_path))
+
+            if source_text.endswith(".sql.gz"):
+                _decompress_gzip_file(downloaded_path, expanded_sql_path)
+            else:
+                expanded_sql_path = downloaded_path
+            _sanitize_sql_dump(expanded_sql_path, staged_sql_path)
+            yield staged_sql_path, source_text
+            return
+
+        source_path = Path(source_text)
+        if not source_path.exists():
+            raise LocalDbRestoreError(f"Restore source not found: {source_path}")
+        if not source_path.is_file():
+            raise LocalDbRestoreError(f"Restore source is not a file: {source_path}")
+
+        if source_text.endswith(".sql.gz"):
+            _decompress_gzip_file(source_path, expanded_sql_path)
+        else:
+            expanded_sql_path = source_path
+        _sanitize_sql_dump(expanded_sql_path, staged_sql_path)
+        yield staged_sql_path, str(source_path)
+
+
+def _resolve_restore_target(
+    db_name: str | None = None,
+) -> tuple[str, str, str, str, str]:
+    driver = (os.environ.get("DB_DRIVER") or "").strip().lower()
+    if driver == "cloudsql":
+        raise LocalDbRestoreError(
+            "restore-local-db only supports local PostgreSQL targets; "
+            "DB_DRIVER=cloudsql is not allowed."
+        )
+
+    host = (os.environ.get("POSTGRES_HOST") or "localhost").strip()
+    if not host:
+        host = "localhost"
+    if host not in LOCAL_POSTGRES_HOSTS:
+        raise LocalDbRestoreError(
+            "restore-local-db only supports local PostgreSQL hosts "
+            f"({', '.join(sorted(LOCAL_POSTGRES_HOSTS))}); got {host!r}."
+        )
+
+    port = (os.environ.get("POSTGRES_PORT") or "5432").strip()
+    if not port:
+        port = "5432"
+
+    user = (os.environ.get("POSTGRES_USER") or "").strip()
+    if not user:
+        user = getpass.getuser()
+
+    target_db = (db_name or os.environ.get("POSTGRES_DB") or "postgres").strip()
+    if not target_db:
+        raise LocalDbRestoreError("Target database name is empty.")
+
+    password = os.environ.get("POSTGRES_PASSWORD", "")
+    return host, port, user, target_db, password
+
+
+def _reset_target_schema() -> None:
+    try:
+        engine.dispose()
+        with session_ctx() as session:
+            recreate_public_schema(session)
+        engine.dispose()
+    except Exception as exc:
+        raise LocalDbRestoreError(
+            f"Failed to reset the public schema before restore: {exc}"
+        ) from exc
+
+
+def restore_local_db_from_sql(
+    source_file: Path | str, *, db_name: str | None = None
+) -> LocalDbRestoreResult:
+    host, port, user, target_db, password = _resolve_restore_target(db_name)
+    with _stage_restore_source(source_file) as (staged_sql_file, source_description):
+        try:
+            _reset_target_schema()
+        except LocalDbRestoreError:
+            raise
+        except Exception as exc:
+            raise LocalDbRestoreError(
+                f"Failed to reset the public schema before restore: {exc}"
+            ) from exc
+        command = [
+            "psql",
+            "-v",
+            "ON_ERROR_STOP=1",
+            "-h",
+            host,
+            "-p",
+            port,
+            "-U",
+            user,
+            "-d",
+            target_db,
+            "-f",
+            str(staged_sql_file),
+        ]
+
+        env = os.environ.copy()
+        if password:
+            env["PGPASSWORD"] = password
+
+        try:
+            subprocess.run(
+                command,
+                check=True,
+                env=env,
+                capture_output=True,
+                text=True,
+            )
+        except FileNotFoundError as exc:
+            raise LocalDbRestoreError(
+                "psql is not installed or not available on PATH."
+            ) from exc
+        except subprocess.CalledProcessError as exc:
+            detail = (exc.stderr or exc.stdout or "").strip() or str(exc)
+            raise LocalDbRestoreError(
+                f"Restore failed for database {target_db!r}: {detail}"
+            ) from exc
+
+        return LocalDbRestoreResult(
+            source=source_description,
+            host=host,
+            port=port,
+            user=user,
+            db_name=target_db,
+        )
diff --git a/tests/test_cli_commands.py b/tests/test_cli_commands.py
index 499be641..47f451ec 100644
--- a/tests/test_cli_commands.py
+++ b/tests/test_cli_commands.py
@@ -15,9 +15,11 @@
 # ===============================================================================
 from __future__ import annotations
 
+import gzip
 import textwrap
 import uuid
 from pathlib import Path
+from subprocess import CalledProcessError
 from types import SimpleNamespace
 
 from sqlalchemy import select
@@ -159,6 +159,198 @@
     assert captured["path"] == asset_dir
 
 
+def test_restore_local_db_invokes_psql(monkeypatch, tmp_path):
+    sql_file = tmp_path / "restore.sql"
+    sql_file.write_text(
+        "SET ROLE ocotillo;\n"
+        "ALTER TABLE public.sample OWNER TO ocotillo;\n"
+        "GRANT ALL ON TABLE public.sample TO ocotillo;\n"
+        "select 1;\n"
+    )
+    captured: dict[str, object] = {}
+    call_order: list[str] = []
+
+    def fake_reset():
+        call_order.append("reset")
+
+    def fake_run(command, check, env, capture_output, text):
+        call_order.append("psql")
+        captured["command"] = command
+        captured["check"] = check
+        captured["env"] = env
+        captured["capture_output"] = capture_output
+        captured["text"] = text
+        captured["restored_sql"] = Path(command[-1]).read_text()
+        return SimpleNamespace(returncode=0)
+
+    monkeypatch.setattr("cli.db_restore._reset_target_schema", fake_reset)
+    monkeypatch.setattr("cli.db_restore.subprocess.run", fake_run)
+    monkeypatch.setenv("POSTGRES_HOST", "localhost")
+    monkeypatch.setenv("POSTGRES_PORT", "5432")
+    monkeypatch.setenv("POSTGRES_USER", "nm_user")
+    monkeypatch.setenv("POSTGRES_PASSWORD", "secret")
+    monkeypatch.setenv("POSTGRES_DB", "ocotilloapi_dev")
+
+    runner = CliRunner()
+    result = runner.invoke(cli, ["restore-local-db", str(sql_file)])
+
+    assert result.exit_code == 0, result.output
+    assert captured["command"][:-1] == [
+        "psql",
+        "-v",
+        "ON_ERROR_STOP=1",
+        "-h",
+        "localhost",
+        "-p",
+        "5432",
+        "-U",
+        "nm_user",
+        "-d",
+        "ocotilloapi_dev",
+        "-f",
+    ]
+    assert captured["command"][-1].endswith("/restore.sql")
+    assert captured["check"] is True
+    assert captured["capture_output"] is True
+    assert captured["text"] is True
+    assert captured["env"]["PGPASSWORD"] == "secret"
+    assert captured["restored_sql"] == "select 1;\n"
+    assert call_order == ["reset", "psql"]
+    assert "Restored" in result.output
+    assert "ocotilloapi_dev" in result.output
+
+
+def test_restore_local_db_rejects_non_sql_files(tmp_path):
+    source_file = tmp_path / "restore.dump"
+    source_file.write_text("not sql")
+
+    runner = CliRunner()
+    result = runner.invoke(cli, ["restore-local-db", str(source_file)])
+
+    assert result.exit_code == 1
+    assert "requires a .sql or .sql.gz source" in result.output
+
+
+def test_restore_local_db_rejects_remote_host(monkeypatch, tmp_path):
+    sql_file = tmp_path / "restore.sql"
+    sql_file.write_text("select 1;\n")
+    called = {"value": False}
+
+    def fake_run(*args, **kwargs):
+        called["value"] = True
+        raise AssertionError("subprocess.run should not be called for remote hosts")
+
+    monkeypatch.setattr("cli.db_restore.subprocess.run", fake_run)
+    monkeypatch.setenv("POSTGRES_HOST", "db.example.com")
+
+    runner = CliRunner()
+    result = runner.invoke(cli, ["restore-local-db", str(sql_file)])
+
+    assert result.exit_code == 1
+    assert "only supports local PostgreSQL hosts" in result.output
+    assert called["value"] is False
+
+
+def test_restore_local_db_reports_psql_failures(monkeypatch, tmp_path):
+    sql_file = tmp_path / "restore.sql"
+    sql_file.write_text("select 1;\n")
+
+    def fake_run(command, check, env, capture_output, text):
+        raise CalledProcessError(
+            1,
+            command,
+            stderr='psql: role "missing" does not exist',
+        )
+
+    monkeypatch.setattr("cli.db_restore._reset_target_schema", lambda: None)
+    monkeypatch.setattr("cli.db_restore.subprocess.run", fake_run)
+    monkeypatch.setenv("POSTGRES_HOST", "localhost")
+    monkeypatch.setenv("POSTGRES_DB", "ocotilloapi_dev")
+
+    runner = CliRunner()
+    result = runner.invoke(cli, ["restore-local-db", str(sql_file)])
+
+    assert result.exit_code == 1
+    assert "Restore failed for database 'ocotilloapi_dev'" in result.output
+    assert 'role "missing" does not exist' in result.output
+
+
+def test_restore_local_db_downloads_and_restores_gcs_gzip(monkeypatch, tmp_path):
+    source_uri = "gs://ocotillo/sql-exports/latest.sql.gz"
+    sql_text = (
+        "SET SESSION AUTHORIZATION 'ocotillo';\n"
+        "REVOKE ALL ON SCHEMA public FROM ocotillo;\n"
+        "select 42;\n"
+    )
+    gz_payload = gzip.compress(sql_text.encode("utf-8"))
+    captured: dict[str, object] = {}
+
+    class FakeBlob:
+        def download_to_filename(self, filename):
+            Path(filename).write_bytes(gz_payload)
+
+    class FakeBucket:
+        def __init__(self):
+            self.requested_blob_name = None
+
+        def blob(self, blob_name):
+            self.requested_blob_name = blob_name
+            captured["blob_name"] = blob_name
+            return FakeBlob()
+
+    fake_bucket = FakeBucket()
+
+    def fake_get_storage_bucket(client=None, bucket=None):
+        captured["bucket_name"] = bucket
+        return fake_bucket
+
+    def fake_run(command, check, env, capture_output, text):
+        captured["command"] = command
+        captured["restored_sql"] = Path(command[-1]).read_text()
+        return SimpleNamespace(returncode=0)
+
+    monkeypatch.setattr("cli.db_restore._reset_target_schema", lambda: None)
+    monkeypatch.setattr("cli.db_restore.get_storage_bucket", fake_get_storage_bucket)
+    monkeypatch.setattr("cli.db_restore.subprocess.run", fake_run)
+    monkeypatch.setenv("POSTGRES_HOST", "localhost")
+    monkeypatch.setenv("POSTGRES_DB", "ocotilloapi_dev")
+
+    runner = CliRunner()
+    result = runner.invoke(cli, ["restore-local-db", source_uri])
+
+    assert result.exit_code == 0, result.output
+    assert captured["bucket_name"] == "ocotillo"
+    assert captured["blob_name"] == "sql-exports/latest.sql.gz"
+    assert captured["restored_sql"] == "select 42;\n"
+    assert captured["command"][-2:] == ["-f", captured["command"][-1]]
+    assert source_uri in result.output
+
+
+def test_restore_local_db_reports_schema_reset_failures(monkeypatch, tmp_path):
+    sql_file = tmp_path / "restore.sql"
+    sql_file.write_text("select 1;\n")
+    called = {"psql": False}
+
+    def fake_reset():
+        raise RuntimeError("permission denied to drop schema public")
+
+    def fake_run(*args, **kwargs):
+        called["psql"] = True
+        raise AssertionError("psql should not be called when schema reset fails")
+
+    monkeypatch.setattr("cli.db_restore._reset_target_schema", fake_reset)
+    monkeypatch.setattr("cli.db_restore.subprocess.run", fake_run)
+    monkeypatch.setenv("POSTGRES_HOST", "localhost")
+    monkeypatch.setenv("POSTGRES_DB", "ocotilloapi_dev")
+
+    runner = CliRunner()
+    result = runner.invoke(cli, ["restore-local-db", str(sql_file)])
+
+    assert result.exit_code == 1
+    assert "permission denied to drop schema public" in result.output
+    assert called["psql"] is False
+
+
 def test_well_inventory_csv_command_calls_service(monkeypatch, tmp_path):
     inventory_file = tmp_path / "inventory.csv"
     inventory_file.write_text("header\nvalue\n")