Skip to content

Commit f5530ff

Browse files
committed
feat: add restore-local-db command for restoring local databases from SQL dumps
1 parent ccc295f commit f5530ff

6 files changed

Lines changed: 458 additions & 2 deletions

File tree

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,7 @@ Notes:
214214
* dev: `ocotilloapi_dev`
215215
* test: `ocotilloapi_test` (created by init SQL in `docker/db/init/01-create-test-db.sql`)
216216
* The database listens on port `5432` both inside the container and on your host. Ensure `POSTGRES_PORT=5432` and `POSTGRES_DB=ocotilloapi_dev` in your `.env` to run local commands against the Docker dev DB (e.g., `uv run pytest`, `uv run python -m transfers.transfer`).
217+
* To restore a local or GCS-backed SQL dump into your local target DB, run `source .venv/bin/activate && python -m cli.cli restore-local-db path/to/dump.sql` or `source .venv/bin/activate && python -m cli.cli restore-local-db gs://ocotillo/sql-exports/latest.sql.gz`.
217218
* `SESSION_SECRET_KEY` only needs to be set in `.env` if you plan to use `/admin`; without it, the API and `/ogcapi` still boot, but `/admin` will be unavailable.
218219

219220
#### Staging Data

cli/README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ python -m cli.cli --help
1515

1616
## Common commands
1717

18+
- `python -m cli.cli restore-local-db path/to/dump.sql`
19+
- `python -m cli.cli restore-local-db gs://ocotillo/sql-exports/latest.sql.gz`
1820
- `python -m cli.cli transfer-results`
1921
- `python -m cli.cli compare-duplicated-welldata`
2022
- `python -m cli.cli alembic-upgrade-and-data`

cli/cli.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,36 @@ def associate_assets_command(
134134
associate_assets(root_directory)
135135

136136

137+
@cli.command("restore-local-db")
def restore_local_db(
    source: str = typer.Argument(
        ...,
        help="Local .sql/.sql.gz path or gs://bucket/path.sql[.gz] URI.",
    ),
    db_name: str | None = typer.Option(
        None,
        "--db-name",
        help="Override POSTGRES_DB for the restore target.",
    ),
    theme: ThemeMode = typer.Option(
        ThemeMode.auto, "--theme", help="Color theme: auto, light, dark."
    ),
):
    """Restore a local PostgreSQL database from a SQL dump.

    SOURCE may be a local .sql/.sql.gz file or a gs:// URI. The public schema
    of the target database is reset before the dump is loaded, so existing
    data in that schema is replaced.
    """
    # NOTE(review): `theme` is accepted but never read in this command; it is
    # kept so the command-line interface stays unchanged. Confirm whether it
    # should drive output styling.

    # Function-scope import: presumably keeps the database/GCS modules from
    # being imported by unrelated commands -- confirm.
    from cli.db_restore import LocalDbRestoreError, restore_local_db_from_sql

    try:
        result = restore_local_db_from_sql(source, db_name=db_name)
    except LocalDbRestoreError as exc:
        # Known, user-actionable failures: print the message on stderr and
        # exit non-zero instead of showing a traceback.
        typer.echo(str(exc), err=True)
        raise typer.Exit(code=1) from exc

    typer.echo(
        f"Restored {result.source} into {result.db_name} "
        f"on {result.host}:{result.port} as {result.user}."
    )
165+
166+
137167
@cli.command("transfer-results")
138168
def transfer_results(
139169
summary_path: Path = typer.Option(

cli/db_restore.py

Lines changed: 226 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,226 @@
1+
import getpass
2+
import gzip
3+
import os
4+
import re
5+
import subprocess
6+
import tempfile
7+
from contextlib import contextmanager
8+
from dataclasses import dataclass
9+
from pathlib import Path
10+
11+
from db.engine import engine, session_ctx
12+
from db.initialization import recreate_public_schema
13+
from services.gcs_helper import get_storage_bucket
14+
15+
# Host names accepted as "local" restore targets; any other POSTGRES_HOST is
# rejected. ("db" is presumably the Docker Compose service name -- confirm.)
LOCAL_POSTGRES_HOSTS = {"localhost", "127.0.0.1", "::1", "db"}

# Line-anchored, case-insensitive patterns for statements that refer to roles
# (ownership, grants, session identity). Lines matching any of them are left
# out of the staged dump by the sanitizing step.
ROLE_DEPENDENT_SQL_PATTERNS = tuple(
    re.compile(expression, re.IGNORECASE)
    for expression in (
        r"^\s*SET\s+ROLE\b",
        r"^\s*SET\s+SESSION\s+AUTHORIZATION\b",
        r"^\s*ALTER\s+.*\s+OWNER\s+TO\b",
        r"^\s*GRANT\b",
        r"^\s*REVOKE\b",
        r"^\s*ALTER\s+DEFAULT\s+PRIVILEGES\b",
    )
)
24+
25+
26+
class LocalDbRestoreError(RuntimeError):
    """Raised when a local database restore cannot be performed safely.

    Used for invalid sources, non-local targets, schema-reset failures and
    psql failures. The ``restore-local-db`` CLI command catches this type,
    prints its message to stderr and exits with status 1.
    """
28+
29+
30+
@dataclass(frozen=True)
class LocalDbRestoreResult:
    """Immutable summary of a finished restore, as returned by ``restore_local_db_from_sql``."""

    # Path of the sanitized SQL file that was passed to psql. It is created
    # inside a temporary directory that is removed once staging ends, so it
    # should be treated as informational only.
    sql_file: Path
    # Description of where the dump came from: the gs:// URI or the local path.
    source: str
    # Connection parameters that were used for the psql invocation.
    host: str
    port: str
    user: str
    # Name of the database the dump was loaded into.
    db_name: str
38+
39+
40+
def _is_gcs_uri(source: str) -> bool:
41+
return source.startswith("gs://")
42+
43+
44+
def _parse_gcs_uri(source: str) -> tuple[str, str]:
    """Split a ``gs://bucket/object`` URI into ``(bucket_name, blob_name)``.

    Raises:
        LocalDbRestoreError: if *source* is not a gs:// URI, or if either the
            bucket or the object part is empty.
    """
    if not _is_gcs_uri(source):
        raise LocalDbRestoreError(f"Expected gs:// URI, got {source!r}.")

    # Drop the scheme, then split once on the first slash: everything before
    # it is the bucket, everything after it (further slashes included) is the
    # object name.
    bucket_name, _, blob_name = source[len("gs://") :].partition("/")
    if bucket_name and blob_name:
        return bucket_name, blob_name

    raise LocalDbRestoreError(
        f"Invalid GCS URI {source!r}; expected gs://bucket/path.sql[.gz]."
    )
55+
56+
57+
def _validate_restore_source_name(source_name: str) -> None:
58+
if source_name.endswith(".sql") or source_name.endswith(".sql.gz"):
59+
return
60+
61+
raise LocalDbRestoreError(
62+
"restore-local-db requires a .sql or .sql.gz source; " f"got {source_name!r}."
63+
)
64+
65+
66+
def _decompress_gzip_file(source_path: Path, target_path: Path) -> None:
67+
with gzip.open(source_path, "rb") as compressed:
68+
with open(target_path, "wb") as expanded:
69+
while chunk := compressed.read(1024 * 1024):
70+
expanded.write(chunk)
71+
72+
73+
def _sanitize_sql_dump(source_path: Path, target_path: Path) -> None:
74+
with open(source_path, "r", encoding="utf-8") as infile:
75+
with open(target_path, "w", encoding="utf-8") as outfile:
76+
for line in infile:
77+
if any(pattern.search(line) for pattern in ROLE_DEPENDENT_SQL_PATTERNS):
78+
continue
79+
outfile.write(line)
80+
81+
82+
@contextmanager
def _stage_restore_source(source: str | Path):
    """Prepare a sanitized, plain-text copy of the dump for psql.

    The source (a local path or a ``gs://`` URI) is downloaded if needed,
    gunzipped if its name ends in ``.sql.gz``, and passed through
    ``_sanitize_sql_dump``. All intermediate files live in a temporary
    directory that is deleted when the context exits, so the yielded path is
    only valid inside the ``with`` block.

    Yields:
        ``(staged_sql_path, source_description)`` where the description is
        the original URI or local path as a string.

    Raises:
        LocalDbRestoreError: if the name is not ``.sql``/``.sql.gz``, a local
            source is missing or not a file, or decompression fails.
    """
    source_text = str(source)
    _validate_restore_source_name(source_text)
    is_compressed = source_text.endswith(".sql.gz")

    with tempfile.TemporaryDirectory(prefix="ocotillo-db-restore-") as temp_dir:
        temp_dir_path = Path(temp_dir)
        staged_sql_path = temp_dir_path / "restore.sql"

        if _is_gcs_uri(source_text):
            bucket_name, blob_name = _parse_gcs_uri(source_text)
            bucket = get_storage_bucket(bucket=bucket_name)
            # Download under a fixed name instead of the blob's own basename:
            # a blob called "restore.sql" would otherwise land on the staged
            # output path, and sanitizing would truncate its own input.
            download_name = "downloaded.sql.gz" if is_compressed else "downloaded.sql"
            raw_path = temp_dir_path / download_name
            bucket.blob(blob_name).download_to_filename(str(raw_path))
            source_description = source_text
        else:
            raw_path = Path(source_text)
            if not raw_path.exists():
                raise LocalDbRestoreError(f"Restore source not found: {raw_path}")
            if not raw_path.is_file():
                raise LocalDbRestoreError(
                    f"Restore source is not a file: {raw_path}"
                )
            source_description = str(raw_path)

        if is_compressed:
            expanded_sql_path = temp_dir_path / "expanded.sql"
            try:
                _decompress_gzip_file(raw_path, expanded_sql_path)
            except (OSError, EOFError) as exc:
                # gzip.BadGzipFile is an OSError; a truncated archive raises
                # EOFError. Report both as a restore error, not a traceback.
                raise LocalDbRestoreError(
                    f"Failed to decompress {source_description}: {exc}"
                ) from exc
        else:
            expanded_sql_path = raw_path

        _sanitize_sql_dump(expanded_sql_path, staged_sql_path)
        yield staged_sql_path, source_description
119+
120+
121+
def _resolve_restore_target(
    db_name: str | None = None,
) -> tuple[str, str, str, str, str]:
    """Work out the psql connection settings from the environment.

    Reads ``DB_DRIVER`` and the ``POSTGRES_*`` variables. Blank or unset
    values fall back to ``localhost``, ``5432``, the current OS user and the
    ``postgres`` database; *db_name*, when given, takes precedence over
    ``POSTGRES_DB``.

    Returns:
        ``(host, port, user, target_db, password)``; the password is an empty
        string when ``POSTGRES_PASSWORD`` is unset.

    Raises:
        LocalDbRestoreError: if ``DB_DRIVER`` is ``cloudsql``, the host is not
            in ``LOCAL_POSTGRES_HOSTS``, or the database name is blank.
    """
    env = os.environ

    if (env.get("DB_DRIVER") or "").strip().lower() == "cloudsql":
        raise LocalDbRestoreError(
            "restore-local-db only supports local PostgreSQL targets; "
            "DB_DRIVER=cloudsql is not allowed."
        )

    # The trailing `or` covers values that are whitespace only: they are
    # truthy before strip() and empty afterwards.
    host = (env.get("POSTGRES_HOST") or "localhost").strip() or "localhost"
    if host not in LOCAL_POSTGRES_HOSTS:
        allowed_hosts = ", ".join(sorted(LOCAL_POSTGRES_HOSTS))
        raise LocalDbRestoreError(
            "restore-local-db only supports local PostgreSQL hosts "
            f"({allowed_hosts}); got {host!r}."
        )

    port = (env.get("POSTGRES_PORT") or "5432").strip() or "5432"
    user = (env.get("POSTGRES_USER") or "").strip() or getpass.getuser()

    target_db = (db_name or env.get("POSTGRES_DB") or "postgres").strip()
    if not target_db:
        raise LocalDbRestoreError("Target database name is empty.")

    return host, port, user, target_db, env.get("POSTGRES_PASSWORD", "")
154+
155+
156+
def _reset_target_schema() -> None:
    """Recreate the ``public`` schema through the application's engine.

    The engine's connection pool is disposed both before and after the reset
    (presumably so that no pooled connection outlives the schema change --
    confirm).

    Raises:
        LocalDbRestoreError: wrapping any exception raised during the reset.
    """
    try:
        engine.dispose()
        with session_ctx() as db_session:
            recreate_public_schema(db_session)
        engine.dispose()
    except Exception as error:
        message = f"Failed to reset the public schema before restore: {error}"
        raise LocalDbRestoreError(message) from error
166+
167+
168+
def restore_local_db_from_sql(
    source_file: Path | str, *, db_name: str | None = None
) -> LocalDbRestoreResult:
    """Reset the local database's public schema and load a SQL dump with psql.

    Args:
        source_file: Local ``.sql``/``.sql.gz`` path or ``gs://`` URI.
        db_name: Database to restore into; defaults to ``POSTGRES_DB``.

    Returns:
        A ``LocalDbRestoreResult`` describing the source and connection used.
        Its ``sql_file`` points into the staging directory, which no longer
        exists once this function has returned.

    Raises:
        LocalDbRestoreError: if the target is not local, the source is
            invalid, psql is unavailable, the schema reset fails, or psql
            exits with an error.
    """
    # Function-scope import: `shutil` is not among this module's imports.
    import shutil

    host, port, user, target_db, password = _resolve_restore_target(db_name)

    # Check for psql *before* anything destructive happens; otherwise a
    # missing client is only discovered after the schema has been wiped.
    if shutil.which("psql") is None:
        raise LocalDbRestoreError("psql is not installed or not available on PATH.")

    with _stage_restore_source(source_file) as (staged_sql_file, source_description):
        # NOTE(review): the reset runs through the module-level engine and is
        # not told about `db_name`, while psql below connects to `target_db`.
        # If the engine is bound to POSTGRES_DB, a --db-name override resets
        # one database and restores into another -- confirm and align.
        # _reset_target_schema already converts every failure into
        # LocalDbRestoreError, so no additional wrapping is needed here.
        _reset_target_schema()

        command = [
            "psql",
            # Ignore the user's ~/.psqlrc so personal settings cannot change
            # how the restore behaves.
            "--no-psqlrc",
            "-v",
            "ON_ERROR_STOP=1",
            "-h",
            host,
            "-p",
            port,
            "-U",
            user,
            "-d",
            target_db,
            "-f",
            str(staged_sql_file),
        ]

        env = os.environ.copy()
        if password:
            # Passed through the environment so it never appears in argv.
            env["PGPASSWORD"] = password

        try:
            subprocess.run(
                command,
                check=True,
                env=env,
                capture_output=True,
                text=True,
            )
        except FileNotFoundError as exc:
            # Still possible if psql disappears between the check and the run.
            raise LocalDbRestoreError(
                "psql is not installed or not available on PATH."
            ) from exc
        except subprocess.CalledProcessError as exc:
            detail = (exc.stderr or exc.stdout or "").strip() or str(exc)
            raise LocalDbRestoreError(
                f"Restore failed for database {target_db!r}: {detail}"
            ) from exc

        return LocalDbRestoreResult(
            sql_file=staged_sql_file,
            source=source_description,
            host=host,
            port=port,
            user=user,
            db_name=target_db,
        )

0 commit comments

Comments
 (0)