|
| 1 | +import getpass |
| 2 | +import gzip |
| 3 | +import os |
| 4 | +import re |
| 5 | +import subprocess |
| 6 | +import tempfile |
| 7 | +from contextlib import contextmanager |
| 8 | +from dataclasses import dataclass |
| 9 | +from pathlib import Path |
| 10 | + |
| 11 | +from db.engine import engine, session_ctx |
| 12 | +from db.initialization import recreate_public_schema |
| 13 | +from services.gcs_helper import get_storage_bucket |
| 14 | + |
| 15 | +LOCAL_POSTGRES_HOSTS = {"localhost", "127.0.0.1", "::1", "db"} |
| 16 | +ROLE_DEPENDENT_SQL_PATTERNS = ( |
| 17 | + re.compile(r"^\s*SET\s+ROLE\b", re.IGNORECASE), |
| 18 | + re.compile(r"^\s*SET\s+SESSION\s+AUTHORIZATION\b", re.IGNORECASE), |
| 19 | + re.compile(r"^\s*ALTER\s+.*\s+OWNER\s+TO\b", re.IGNORECASE), |
| 20 | + re.compile(r"^\s*GRANT\b", re.IGNORECASE), |
| 21 | + re.compile(r"^\s*REVOKE\b", re.IGNORECASE), |
| 22 | + re.compile(r"^\s*ALTER\s+DEFAULT\s+PRIVILEGES\b", re.IGNORECASE), |
| 23 | +) |
| 24 | + |
| 25 | + |
| 26 | +class LocalDbRestoreError(RuntimeError): |
| 27 | + """Raised when a local database restore cannot be performed safely.""" |
| 28 | + |
| 29 | + |
| 30 | +@dataclass(frozen=True) |
| 31 | +class LocalDbRestoreResult: |
| 32 | + sql_file: Path |
| 33 | + source: str |
| 34 | + host: str |
| 35 | + port: str |
| 36 | + user: str |
| 37 | + db_name: str |
| 38 | + |
| 39 | + |
| 40 | +def _is_gcs_uri(source: str) -> bool: |
| 41 | + return source.startswith("gs://") |
| 42 | + |
| 43 | + |
| 44 | +def _parse_gcs_uri(source: str) -> tuple[str, str]: |
| 45 | + if not _is_gcs_uri(source): |
| 46 | + raise LocalDbRestoreError(f"Expected gs:// URI, got {source!r}.") |
| 47 | + |
| 48 | + path = source[5:] |
| 49 | + bucket_name, _, blob_name = path.partition("/") |
| 50 | + if not bucket_name or not blob_name: |
| 51 | + raise LocalDbRestoreError( |
| 52 | + f"Invalid GCS URI {source!r}; expected gs://bucket/path.sql[.gz]." |
| 53 | + ) |
| 54 | + return bucket_name, blob_name |
| 55 | + |
| 56 | + |
| 57 | +def _validate_restore_source_name(source_name: str) -> None: |
| 58 | + if source_name.endswith(".sql") or source_name.endswith(".sql.gz"): |
| 59 | + return |
| 60 | + |
| 61 | + raise LocalDbRestoreError( |
| 62 | + "restore-local-db requires a .sql or .sql.gz source; " f"got {source_name!r}." |
| 63 | + ) |
| 64 | + |
| 65 | + |
| 66 | +def _decompress_gzip_file(source_path: Path, target_path: Path) -> None: |
| 67 | + with gzip.open(source_path, "rb") as compressed: |
| 68 | + with open(target_path, "wb") as expanded: |
| 69 | + while chunk := compressed.read(1024 * 1024): |
| 70 | + expanded.write(chunk) |
| 71 | + |
| 72 | + |
| 73 | +def _sanitize_sql_dump(source_path: Path, target_path: Path) -> None: |
| 74 | + with open(source_path, "r", encoding="utf-8") as infile: |
| 75 | + with open(target_path, "w", encoding="utf-8") as outfile: |
| 76 | + for line in infile: |
| 77 | + if any(pattern.search(line) for pattern in ROLE_DEPENDENT_SQL_PATTERNS): |
| 78 | + continue |
| 79 | + outfile.write(line) |
| 80 | + |
| 81 | + |
| 82 | +@contextmanager |
| 83 | +def _stage_restore_source(source: str | Path): |
| 84 | + source_text = str(source) |
| 85 | + _validate_restore_source_name(source_text) |
| 86 | + |
| 87 | + with tempfile.TemporaryDirectory(prefix="ocotillo-db-restore-") as temp_dir: |
| 88 | + temp_dir_path = Path(temp_dir) |
| 89 | + expanded_sql_path = temp_dir_path / "expanded.sql" |
| 90 | + staged_sql_path = temp_dir_path / "restore.sql" |
| 91 | + |
| 92 | + if _is_gcs_uri(source_text): |
| 93 | + bucket_name, blob_name = _parse_gcs_uri(source_text) |
| 94 | + bucket = get_storage_bucket(bucket=bucket_name) |
| 95 | + blob = bucket.blob(blob_name) |
| 96 | + downloaded_path = temp_dir_path / Path(blob_name).name |
| 97 | + blob.download_to_filename(str(downloaded_path)) |
| 98 | + |
| 99 | + if source_text.endswith(".sql.gz"): |
| 100 | + _decompress_gzip_file(downloaded_path, expanded_sql_path) |
| 101 | + else: |
| 102 | + expanded_sql_path = downloaded_path |
| 103 | + _sanitize_sql_dump(expanded_sql_path, staged_sql_path) |
| 104 | + yield staged_sql_path, source_text |
| 105 | + return |
| 106 | + |
| 107 | + source_path = Path(source_text) |
| 108 | + if not source_path.exists(): |
| 109 | + raise LocalDbRestoreError(f"Restore source not found: {source_path}") |
| 110 | + if not source_path.is_file(): |
| 111 | + raise LocalDbRestoreError(f"Restore source is not a file: {source_path}") |
| 112 | + |
| 113 | + if source_text.endswith(".sql.gz"): |
| 114 | + _decompress_gzip_file(source_path, expanded_sql_path) |
| 115 | + else: |
| 116 | + expanded_sql_path = source_path |
| 117 | + _sanitize_sql_dump(expanded_sql_path, staged_sql_path) |
| 118 | + yield staged_sql_path, str(source_path) |
| 119 | + |
| 120 | + |
| 121 | +def _resolve_restore_target( |
| 122 | + db_name: str | None = None, |
| 123 | +) -> tuple[str, str, str, str, str]: |
| 124 | + driver = (os.environ.get("DB_DRIVER") or "").strip().lower() |
| 125 | + if driver == "cloudsql": |
| 126 | + raise LocalDbRestoreError( |
| 127 | + "restore-local-db only supports local PostgreSQL targets; " |
| 128 | + "DB_DRIVER=cloudsql is not allowed." |
| 129 | + ) |
| 130 | + |
| 131 | + host = (os.environ.get("POSTGRES_HOST") or "localhost").strip() |
| 132 | + if not host: |
| 133 | + host = "localhost" |
| 134 | + if host not in LOCAL_POSTGRES_HOSTS: |
| 135 | + raise LocalDbRestoreError( |
| 136 | + "restore-local-db only supports local PostgreSQL hosts " |
| 137 | + f"({', '.join(sorted(LOCAL_POSTGRES_HOSTS))}); got {host!r}." |
| 138 | + ) |
| 139 | + |
| 140 | + port = (os.environ.get("POSTGRES_PORT") or "5432").strip() |
| 141 | + if not port: |
| 142 | + port = "5432" |
| 143 | + |
| 144 | + user = (os.environ.get("POSTGRES_USER") or "").strip() |
| 145 | + if not user: |
| 146 | + user = getpass.getuser() |
| 147 | + |
| 148 | + target_db = (db_name or os.environ.get("POSTGRES_DB") or "postgres").strip() |
| 149 | + if not target_db: |
| 150 | + raise LocalDbRestoreError("Target database name is empty.") |
| 151 | + |
| 152 | + password = os.environ.get("POSTGRES_PASSWORD", "") |
| 153 | + return host, port, user, target_db, password |
| 154 | + |
| 155 | + |
| 156 | +def _reset_target_schema() -> None: |
| 157 | + try: |
| 158 | + engine.dispose() |
| 159 | + with session_ctx() as session: |
| 160 | + recreate_public_schema(session) |
| 161 | + engine.dispose() |
| 162 | + except Exception as exc: |
| 163 | + raise LocalDbRestoreError( |
| 164 | + f"Failed to reset the public schema before restore: {exc}" |
| 165 | + ) from exc |
| 166 | + |
| 167 | + |
| 168 | +def restore_local_db_from_sql( |
| 169 | + source_file: Path | str, *, db_name: str | None = None |
| 170 | +) -> LocalDbRestoreResult: |
| 171 | + host, port, user, target_db, password = _resolve_restore_target(db_name) |
| 172 | + with _stage_restore_source(source_file) as (staged_sql_file, source_description): |
| 173 | + try: |
| 174 | + _reset_target_schema() |
| 175 | + except LocalDbRestoreError: |
| 176 | + raise |
| 177 | + except Exception as exc: |
| 178 | + raise LocalDbRestoreError( |
| 179 | + f"Failed to reset the public schema before restore: {exc}" |
| 180 | + ) from exc |
| 181 | + command = [ |
| 182 | + "psql", |
| 183 | + "-v", |
| 184 | + "ON_ERROR_STOP=1", |
| 185 | + "-h", |
| 186 | + host, |
| 187 | + "-p", |
| 188 | + port, |
| 189 | + "-U", |
| 190 | + user, |
| 191 | + "-d", |
| 192 | + target_db, |
| 193 | + "-f", |
| 194 | + str(staged_sql_file), |
| 195 | + ] |
| 196 | + |
| 197 | + env = os.environ.copy() |
| 198 | + if password: |
| 199 | + env["PGPASSWORD"] = password |
| 200 | + |
| 201 | + try: |
| 202 | + subprocess.run( |
| 203 | + command, |
| 204 | + check=True, |
| 205 | + env=env, |
| 206 | + capture_output=True, |
| 207 | + text=True, |
| 208 | + ) |
| 209 | + except FileNotFoundError as exc: |
| 210 | + raise LocalDbRestoreError( |
| 211 | + "psql is not installed or not available on PATH." |
| 212 | + ) from exc |
| 213 | + except subprocess.CalledProcessError as exc: |
| 214 | + detail = (exc.stderr or exc.stdout or "").strip() or str(exc) |
| 215 | + raise LocalDbRestoreError( |
| 216 | + f"Restore failed for database {target_db!r}: {detail}" |
| 217 | + ) from exc |
| 218 | + |
| 219 | + return LocalDbRestoreResult( |
| 220 | + sql_file=staged_sql_file, |
| 221 | + source=source_description, |
| 222 | + host=host, |
| 223 | + port=port, |
| 224 | + user=user, |
| 225 | + db_name=target_db, |
| 226 | + ) |
0 commit comments