diff --git a/cli/db_restore.py b/cli/db_restore.py index c746a18e..5290452e 100644 --- a/cli/db_restore.py +++ b/cli/db_restore.py @@ -21,6 +21,12 @@ re.compile(r"^\s*REVOKE\b", re.IGNORECASE), re.compile(r"^\s*ALTER\s+DEFAULT\s+PRIVILEGES\b", re.IGNORECASE), ) +PSQL_META_COMMAND_PATTERNS = ( + # Newer pg_dump versions emit these psql-only commands for safer restores. + # Older local psql clients reject them, so drop them from staged restores. + re.compile(r"^\s*\\restrict\b", re.IGNORECASE), + re.compile(r"^\s*\\unrestrict\b", re.IGNORECASE), +) class LocalDbRestoreError(RuntimeError): @@ -81,9 +87,13 @@ def _sanitize_sql_dump(source_path: Path, target_path: Path) -> None: with open(source_path, "r", encoding="utf-8") as infile: with open(target_path, "w", encoding="utf-8") as outfile: for line in infile: - if any( + matches_role_sql = any( pattern.search(line) for pattern in ROLE_DEPENDENT_SQL_PATTERNS - ): + ) + matches_psql_meta = any( + pattern.search(line) for pattern in PSQL_META_COMMAND_PATTERNS + ) + if matches_role_sql or matches_psql_meta: continue outfile.write(line) except UnicodeError as exc: @@ -235,6 +245,7 @@ def restore_local_db_from_sql( ) from exc return LocalDbRestoreResult( + sql_file=staged_sql_file, source=source_description, host=host, port=port, diff --git a/tests/test_cli_commands.py b/tests/test_cli_commands.py index 47f451ec..53d136d5 100644 --- a/tests/test_cli_commands.py +++ b/tests/test_cli_commands.py @@ -164,10 +164,12 @@ def fake_associate(source_directory): def test_restore_local_db_invokes_psql(monkeypatch, tmp_path): sql_file = tmp_path / "restore.sql" sql_file.write_text( + "\\restrict abc123\n" "SET ROLE ocotillo;\n" "ALTER TABLE public.sample OWNER TO ocotillo;\n" "GRANT ALL ON TABLE public.sample TO ocotillo;\n" "select 1;\n" + "\\unrestrict abc123\n" ) captured: dict[str, object] = {} call_order: list[str] = []