Skip to content

Commit 3d17413

Browse files
authored
Merge pull request #597 from DataIntegrationGroup/restore-local-db
feat: add restore-local-db command for restoring local databases from SQL dumps
2 parents ccc295f + b4be5f1 commit 3d17413

5 files changed

Lines changed: 470 additions & 0 deletions

File tree

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,7 @@ Notes:
214214
* dev: `ocotilloapi_dev`
215215
* test: `ocotilloapi_test` (created by init SQL in `docker/db/init/01-create-test-db.sql`)
216216
* The database listens on port `5432` both inside the container and on your host. Ensure `POSTGRES_PORT=5432` and `POSTGRES_DB=ocotilloapi_dev` in your `.env` to run local commands against the Docker dev DB (e.g., `uv run pytest`, `uv run python -m transfers.transfer`).
217+
* To restore a local or GCS-backed SQL dump into your local target DB, run `source .venv/bin/activate && python -m cli.cli restore-local-db path/to/dump.sql` or `source .venv/bin/activate && python -m cli.cli restore-local-db gs://ocotillo/sql-exports/latest.sql.gz`.
217218
* `SESSION_SECRET_KEY` only needs to be set in `.env` if you plan to use `/admin`; without it, the API and `/ogcapi` still boot, but `/admin` will be unavailable.
218219

219220
#### Staging Data

cli/README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ python -m cli.cli --help
1515

1616
## Common commands
1717

18+
- `python -m cli.cli restore-local-db path/to/dump.sql`
19+
- `python -m cli.cli restore-local-db gs://ocotillo/sql-exports/latest.sql.gz`
1820
- `python -m cli.cli transfer-results`
1921
- `python -m cli.cli compare-duplicated-welldata`
2022
- `python -m cli.cli alembic-upgrade-and-data`

cli/cli.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,36 @@ def associate_assets_command(
134134
associate_assets(root_directory)
135135

136136

137+
@cli.command("restore-local-db")
def restore_local_db(
    source: str = typer.Argument(
        ...,
        help="Local .sql/.sql.gz path or gs://bucket/path.sql[.gz] URI.",
    ),
    db_name: str | None = typer.Option(
        None,
        "--db-name",
        help="Override POSTGRES_DB for the restore target.",
    ),
    theme: ThemeMode = typer.Option(
        ThemeMode.auto, "--theme", help="Color theme: auto, light, dark."
    ),
):
    # CLI entry point: restore a SQL dump (local path or gs:// URI) into the
    # local database and print where it landed. Exits with status 1 and the
    # error text on stderr when the restore layer raises LocalDbRestoreError.
    #
    # NOTE(review): `theme` is accepted but never referenced in this body --
    # confirm whether it should drive output styling or can be dropped.

    # Imported lazily so the restore machinery is only loaded when this
    # command actually runs.
    from cli.db_restore import LocalDbRestoreError, restore_local_db_from_sql

    try:
        outcome = restore_local_db_from_sql(source, db_name=db_name)
    except LocalDbRestoreError as error:
        typer.echo(str(error), err=True)
        raise typer.Exit(code=1) from error

    summary = (
        "Restored "
        f"{outcome.source} into {outcome.db_name} "
        f"on {outcome.host}:{outcome.port} as {outcome.user}."
    )
    typer.echo(summary)
165+
166+
137167
@cli.command("transfer-results")
138168
def transfer_results(
139169
summary_path: Path = typer.Option(

cli/db_restore.py

Lines changed: 243 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,243 @@
1+
import getpass
2+
import gzip
3+
import os
4+
import re
5+
import subprocess
6+
import tempfile
7+
from contextlib import contextmanager
8+
from dataclasses import dataclass
9+
from pathlib import Path
10+
11+
from db.engine import engine, session_ctx
12+
from db.initialization import recreate_public_schema
13+
from services.gcs_helper import get_storage_bucket
14+
15+
# Host names accepted as restore targets; any other POSTGRES_HOST is rejected.
# NOTE(review): "db" is presumably the Docker Compose service name -- confirm.
LOCAL_POSTGRES_HOSTS = {"localhost", "127.0.0.1", "::1", "db"}

# A dump line matching any of these (case-insensitively, after optional
# leading whitespace) is dropped during sanitizing: role switches, ownership
# changes, and privilege grants/revokes.
ROLE_DEPENDENT_SQL_PATTERNS = tuple(
    re.compile(expression, re.IGNORECASE)
    for expression in (
        r"^\s*SET\s+ROLE\b",
        r"^\s*SET\s+SESSION\s+AUTHORIZATION\b",
        r"^\s*ALTER\s+.*\s+OWNER\s+TO\b",
        r"^\s*GRANT\b",
        r"^\s*REVOKE\b",
        r"^\s*ALTER\s+DEFAULT\s+PRIVILEGES\b",
    )
)
24+
25+
26+
class LocalDbRestoreError(RuntimeError):
    """Signal that a local database restore was refused or failed.

    The restore helpers raise this for every failure they report, so a
    caller can handle restore problems with a single ``except`` clause.
    """
28+
29+
30+
@dataclass(frozen=True)
class LocalDbRestoreResult:
    """Immutable summary of a completed local database restore."""

    # SQL file associated with the restore.
    # NOTE(review): presumably the staged/sanitized file -- confirm with caller.
    sql_file: Path
    # Human-readable description of where the dump came from (path or URI).
    source: str
    # Connection details of the database that received the dump.
    host: str
    port: str
    user: str
    db_name: str
38+
39+
40+
def _is_gcs_uri(source: str) -> bool:
41+
return source.startswith("gs://")
42+
43+
44+
def _parse_gcs_uri(source: str) -> tuple[str, str]:
    """Split a ``gs://bucket/blob`` URI into ``(bucket_name, blob_name)``.

    Raises:
        LocalDbRestoreError: If *source* is not a ``gs://`` URI, or if either
            the bucket or the blob part is empty.
    """
    if not _is_gcs_uri(source):
        raise LocalDbRestoreError(f"Expected gs:// URI, got {source!r}.")

    # Everything up to the first "/" is the bucket; the remainder (which may
    # itself contain slashes) is the blob name.
    bucket_name, _separator, blob_name = source.removeprefix("gs://").partition("/")
    if bucket_name and blob_name:
        return bucket_name, blob_name

    raise LocalDbRestoreError(
        f"Invalid GCS URI {source!r}; expected gs://bucket/path.sql[.gz]."
    )
55+
56+
57+
def _validate_restore_source_name(source_name: str) -> None:
58+
if source_name.endswith(".sql") or source_name.endswith(".sql.gz"):
59+
return
60+
61+
raise LocalDbRestoreError(
62+
"restore-local-db requires a .sql or .sql.gz source; " f"got {source_name!r}."
63+
)
64+
65+
66+
def _decompress_gzip_file(source_path: Path, target_path: Path) -> None:
67+
try:
68+
with gzip.open(source_path, "rb") as compressed:
69+
with open(target_path, "wb") as expanded:
70+
while chunk := compressed.read(1024 * 1024):
71+
expanded.write(chunk)
72+
except (OSError, gzip.BadGzipFile) as exc:
73+
raise LocalDbRestoreError(
74+
f"Failed to decompress gzip source {source_path!r}: "
75+
"file is not a valid gzip-compressed SQL dump or is corrupted."
76+
) from exc
77+
78+
79+
def _sanitize_sql_dump(source_path: Path, target_path: Path) -> None:
    """Copy a SQL dump to *target_path*, dropping role-dependent lines.

    Filtering is line-based: a line is omitted when any pattern in
    ``ROLE_DEPENDENT_SQL_PATTERNS`` matches it; every other line is written
    unchanged. Both files are handled as UTF-8 text.

    Raises:
        LocalDbRestoreError: If the dump is not valid UTF-8 or an I/O error
            occurs while reading or writing.
    """

    def _is_role_dependent(sql_line: str) -> bool:
        return any(rule.search(sql_line) for rule in ROLE_DEPENDENT_SQL_PATTERNS)

    try:
        with open(source_path, "r", encoding="utf-8") as dump, open(
            target_path, "w", encoding="utf-8"
        ) as sanitized:
            for sql_line in dump:
                if not _is_role_dependent(sql_line):
                    sanitized.write(sql_line)
    except UnicodeError as error:
        raise LocalDbRestoreError(
            f"Failed to read SQL dump {source_path} as UTF-8. "
            "Ensure the dump file is UTF-8 encoded and not truncated."
        ) from error
    except OSError as error:
        raise LocalDbRestoreError(
            f"I/O error while processing SQL dump {source_path} -> {target_path}: {error}"
        ) from error
98+
99+
100+
@contextmanager
def _stage_restore_source(source: str | Path):
    """Stage a restore source as a sanitized SQL file in a temp directory.

    The source is fetched from GCS (``gs://`` URI) or read from the local
    filesystem, decompressed when its name ends in ``.sql.gz``, and then
    sanitized into ``restore.sql`` inside a temporary directory.

    Args:
        source: Local ``.sql``/``.sql.gz`` path or ``gs://bucket/path`` URI.

    Yields:
        ``(staged_sql_path, source_description)``: the sanitized file to
        restore from, and the source as text for reporting. The staged file
        is deleted together with its temporary directory when the context
        exits.

    Raises:
        LocalDbRestoreError: If the name has an unsupported suffix, the GCS
            download fails, the local path is missing or not a regular file,
            or decompression/sanitizing fails.
    """
    source_text = str(source)
    _validate_restore_source_name(source_text)
    is_gzip = source_text.endswith(".sql.gz")

    with tempfile.TemporaryDirectory(prefix="ocotillo-db-restore-") as temp_dir:
        temp_dir_path = Path(temp_dir)
        expanded_sql_path = temp_dir_path / "expanded.sql"
        staged_sql_path = temp_dir_path / "restore.sql"

        if _is_gcs_uri(source_text):
            bucket_name, blob_name = _parse_gcs_uri(source_text)
            # Download under a fixed name rather than the blob's basename: a
            # blob called "restore.sql" would otherwise land on
            # staged_sql_path, and sanitizing would then truncate its own
            # input and restore an empty file.
            raw_path = temp_dir_path / (
                "download.sql.gz" if is_gzip else "download.sql"
            )
            try:
                bucket = get_storage_bucket(bucket=bucket_name)
                bucket.blob(blob_name).download_to_filename(str(raw_path))
            except Exception as exc:
                # The storage client's exception types are not visible from
                # this module, so everything is converted here to keep
                # LocalDbRestoreError the only error callers must handle.
                raise LocalDbRestoreError(
                    f"Failed to download {source_text!r} from GCS: {exc}"
                ) from exc
            source_description = source_text
        else:
            raw_path = Path(source_text)
            if not raw_path.exists():
                raise LocalDbRestoreError(f"Restore source not found: {raw_path}")
            if not raw_path.is_file():
                raise LocalDbRestoreError(
                    f"Restore source is not a file: {raw_path}"
                )
            source_description = str(raw_path)

        if is_gzip:
            _decompress_gzip_file(raw_path, expanded_sql_path)
        else:
            # Plain SQL needs no expansion; sanitize straight from the source.
            expanded_sql_path = raw_path
        _sanitize_sql_dump(expanded_sql_path, staged_sql_path)
        yield staged_sql_path, source_description
137+
138+
139+
def _resolve_restore_target(
    db_name: str | None = None,
) -> tuple[str, str, str, str, str]:
    """Work out the local PostgreSQL target from the environment.

    Reads ``DB_DRIVER`` and the ``POSTGRES_*`` variables. Blank or unset
    values fall back to ``localhost``, ``5432``, the current OS user, and the
    ``postgres`` database respectively.

    Args:
        db_name: Optional database name that takes precedence over
            ``POSTGRES_DB``.

    Returns:
        ``(host, port, user, target_db, password)``; the password is an empty
        string when ``POSTGRES_PASSWORD`` is unset.

    Raises:
        LocalDbRestoreError: If ``DB_DRIVER`` is ``cloudsql``, the host is not
            in ``LOCAL_POSTGRES_HOSTS``, or the database name is blank.
    """
    env = os.environ

    if (env.get("DB_DRIVER") or "").strip().lower() == "cloudsql":
        raise LocalDbRestoreError(
            "restore-local-db only supports local PostgreSQL targets; "
            "DB_DRIVER=cloudsql is not allowed."
        )

    host = (env.get("POSTGRES_HOST") or "").strip() or "localhost"
    if host not in LOCAL_POSTGRES_HOSTS:
        allowed_hosts = ", ".join(sorted(LOCAL_POSTGRES_HOSTS))
        raise LocalDbRestoreError(
            "restore-local-db only supports local PostgreSQL hosts "
            f"({allowed_hosts}); got {host!r}."
        )

    port = (env.get("POSTGRES_PORT") or "").strip() or "5432"

    # getpass is only consulted when POSTGRES_USER is unset or blank.
    user = (env.get("POSTGRES_USER") or "").strip() or getpass.getuser()

    requested_db = db_name or env.get("POSTGRES_DB") or "postgres"
    target_db = requested_db.strip()
    if not target_db:
        raise LocalDbRestoreError("Target database name is empty.")

    return host, port, user, target_db, env.get("POSTGRES_PASSWORD", "")
172+
173+
174+
def _reset_target_schema() -> None:
    """Recreate the public schema through the project's database session.

    The shared ``engine`` is disposed before the session is opened and again
    after it closes.

    Raises:
        LocalDbRestoreError: Wrapping any exception raised while disposing
            the engine or recreating the schema.
    """
    try:
        engine.dispose()
        with session_ctx() as db_session:
            recreate_public_schema(db_session)
        engine.dispose()
    except Exception as error:
        message = f"Failed to reset the public schema before restore: {error}"
        raise LocalDbRestoreError(message) from error
184+
185+
186+
def restore_local_db_from_sql(
    source_file: Path | str, *, db_name: str | None = None
) -> LocalDbRestoreResult:
    """Restore a SQL dump into the local PostgreSQL database using ``psql``.

    Steps: resolve the target from the environment, stage and sanitize the
    dump, reset the public schema, then run ``psql -f`` on the staged file
    with ``ON_ERROR_STOP=1`` so the first SQL error aborts the restore.

    Args:
        source_file: Local ``.sql``/``.sql.gz`` path or ``gs://`` URI.
        db_name: Optional override for the target database name.

    Returns:
        A ``LocalDbRestoreResult`` describing the source and the target. Its
        ``sql_file`` is the staged file that was fed to ``psql``; that file
        sits in a temporary directory that has been removed by the time this
        function returns, so the path is informational only.

    Raises:
        LocalDbRestoreError: If the target is not local, staging fails, the
            schema reset fails, ``psql`` is missing, or ``psql`` exits
            non-zero.
    """
    host, port, user, target_db, password = _resolve_restore_target(db_name)
    with _stage_restore_source(source_file) as (staged_sql_file, source_description):
        # NOTE(review): the schema reset runs through the project's
        # engine/session, whose target database is not visible from here --
        # confirm it honours the `db_name` override, otherwise the reset and
        # the psql restore could hit different databases.
        try:
            _reset_target_schema()
        except LocalDbRestoreError:
            raise
        except Exception as exc:
            raise LocalDbRestoreError(
                f"Failed to reset the public schema before restore: {exc}"
            ) from exc

        command = [
            "psql",
            "-v",
            "ON_ERROR_STOP=1",
            "-h",
            host,
            "-p",
            port,
            "-U",
            user,
            "-d",
            target_db,
            "-f",
            str(staged_sql_file),
        ]

        # Pass the password through the child's environment, not argv.
        env = os.environ.copy()
        if password:
            env["PGPASSWORD"] = password

        try:
            subprocess.run(
                command,
                check=True,
                env=env,
                capture_output=True,
                text=True,
            )
        except FileNotFoundError as exc:
            raise LocalDbRestoreError(
                "psql is not installed or not available on PATH."
            ) from exc
        except subprocess.CalledProcessError as exc:
            # Prefer psql's own diagnostics; fall back to the exception text.
            detail = (exc.stderr or exc.stdout or "").strip() or str(exc)
            raise LocalDbRestoreError(
                f"Restore failed for database {target_db!r}: {detail}"
            ) from exc

    # `sql_file` is a required dataclass field; omitting it made every
    # successful restore end in a TypeError.
    return LocalDbRestoreResult(
        sql_file=staged_sql_file,
        source=source_description,
        host=host,
        port=port,
        user=user,
        db_name=target_db,
    )

0 commit comments

Comments
 (0)