diff --git a/psi/cache.py b/psi/cache.py
index 7d4ae60..905bd8b 100644
--- a/psi/cache.py
+++ b/psi/cache.py
@@ -10,6 +10,16 @@
 - A :class:`threading.Lock` serializes concurrent writers; reads are
   lock-free under CPython's GIL.
 
+Cache keys are HMAC-SHA256 of the secret's mapping JSON bytes, using a random
+32-byte key generated once on first save and preserved across loads. Keying on
+the mapping (provider coordinates) instead of Podman's hex secret ID means the
+cache survives Podman's delete+create churn during setup runs — the same
+mapping always produces the same cache key, regardless of which hex ID Podman
+currently associates with it.
+
+Serve checks the file mtime on each lookup and reloads in place when setup has
+written a fresh version. No forced restart required after rotations.
+
 On-disk envelope: ::
@@ -21,14 +31,22 @@
 Inside the decrypted payload is a JSON object::
 
-    {"version": 1, "written_at": <unix-ts>, "entries": {"<podman-hex-id>": "<base64>"}}
+    {"version": 2, "written_at": <unix-ts>,
+     "hmac_key": "<base64>",
+     "entries": {"<hmac-hex-key>": "<base64>"}}
+
+Payload version 1 (legacy, keyed by Podman hex IDs) is accepted on load but
+discarded — the next save writes version 2 with a freshly generated HMAC key.
 """
 
 from __future__ import annotations
 
 import base64
 import binascii
+import hashlib
+import hmac
 import json
+import secrets
 import struct
 import threading
 import time
@@ -48,7 +66,8 @@
 VERSION = 0x01
 _HEADER_FMT = ">4sBB"
 _HEADER_SIZE = struct.calcsize(_HEADER_FMT)
-_PAYLOAD_VERSION = 1
+_PAYLOAD_VERSION = 2
+_HMAC_KEY_BYTES = 32
 
 
 class CacheError(ProviderError):
@@ -94,8 +113,10 @@ def __init__(self, path: Path, backend: CacheBackend) -> None:
         self._path = path
         self._backend = backend
         self._entries: dict[str, bytes] = {}
+        self._hmac_key: bytes = secrets.token_bytes(_HMAC_KEY_BYTES)
         self._lock = threading.Lock()
         self._loaded = False
+        self._mtime_ns: int = 0
 
     @property
     def path(self) -> Path:
@@ -107,12 +128,28 @@ def backend_tag(self) -> int:
         """The backend discriminator byte used in the envelope header."""
         return self._backend.tag
 
+    def cache_key(self, mapping_bytes: bytes) -> str:
+        """Compute the stable cache key for a secret's mapping JSON.
+
+        HMAC-SHA256 over the raw mapping bytes with the cache's HMAC key. Two
+        callers (setup writer, serve reader) that see the same mapping bytes
+        produce identical cache keys, so the cache survives Podman's hex ID
+        churn. The HMAC key is per-host (stored only inside the encrypted
+        cache file), so mapping hashes cannot be correlated across
+        deployments.
+        """
+        return hmac.new(self._hmac_key, mapping_bytes, hashlib.sha256).hexdigest()
+
     def load(self) -> None:
         """Decrypt the backing file into memory.
 
        If the file is missing, start with an empty cache — ``save()`` will
        create it on the next mutation. If the file exists but is malformed
        or was written by a different backend, raise :class:`CacheError`.
+
+        Legacy v1 payloads (hex-ID keyed) are silently discarded — the next
+        save writes v2 with a freshly generated HMAC key. A legacy cache is
+        indistinguishable from no cache at all; container startups during the
+        transition fall through to the provider.
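+
+        A typical serve-side sequence (a sketch; ``mapping_bytes`` is the
+        canonical form from :func:`psi.provider.mapping_cache_bytes`)::
+
+            cache.load()                     # once at startup
+            ...
+            cache.maybe_reload()             # per lookup: a cheap stat()
+            value = cache.get(cache.cache_key(mapping_bytes))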
""" if not self._path.exists(): logger.info("Cache file not found at {}, starting empty", self._path) @@ -121,6 +158,7 @@ def load(self) -> None: return raw = self._path.read_bytes() + self._mtime_ns = self._path.stat().st_mtime_ns if len(raw) < _HEADER_SIZE: msg = f"Cache file {self._path} is too short to contain a header" raise CacheError(msg) @@ -141,44 +179,67 @@ def load(self) -> None: raise CacheError(msg) plaintext = self._backend.decrypt(raw[_HEADER_SIZE:]) - self._entries = _parse_payload(plaintext) + parsed = _parse_payload(plaintext) + self._hmac_key = parsed.hmac_key + self._entries = parsed.entries self._loaded = True logger.info("Loaded {} cache entries from {}", len(self._entries), self._path) + def maybe_reload(self) -> bool: + """Reload from disk if the backing file's mtime has changed. + + Called from the lookup hot path so serve picks up fresh entries + written by setup without needing a process restart. A ``stat`` call + is ~1μs on modern kernels; actual reload happens only when setup + has finished a write. + + Returns: + True if a reload happened, False otherwise. + """ + try: + current_mtime = self._path.stat().st_mtime_ns + except FileNotFoundError: + return False + if current_mtime == self._mtime_ns: + return False + self.load() + return True + def save(self) -> None: """Encrypt the in-memory dict and atomically replace the backing file.""" with self._lock: - plaintext = _serialize_payload(self._entries) + plaintext = _serialize_payload(self._entries, self._hmac_key) payload = self._backend.encrypt(plaintext) header = struct.pack(_HEADER_FMT, MAGIC, VERSION, self._backend.tag) write_bytes_secure(self._path, header + payload) + self._mtime_ns = self._path.stat().st_mtime_ns - def get(self, secret_id: str) -> bytes | None: - """Return the cached value for ``secret_id``, or ``None`` on miss.""" - return self._entries.get(secret_id) + def get(self, key: str) -> bytes | None: + """Return the cached value for ``key``, or ``None`` on miss.""" + return self._entries.get(key) - def set(self, secret_id: str, value: bytes) -> None: - """Insert or update ``secret_id``. Does not persist; call :meth:`save`.""" + def set(self, key: str, value: bytes) -> None: + """Insert or update ``key``. Does not persist; call :meth:`save`.""" with self._lock: - self._entries[secret_id] = value + self._entries[key] = value def bulk_set(self, entries: dict[str, bytes]) -> None: """Replace all entries atomically. Does not persist; call :meth:`save`.""" with self._lock: self._entries = dict(entries) - def invalidate(self, secret_id: str) -> bool: - """Drop ``secret_id`` from the cache. Returns True if it was present.""" + def invalidate(self, key: str) -> bool: + """Drop ``key`` from the cache. Returns True if it was present.""" with self._lock: - return self._entries.pop(secret_id, None) is not None + return self._entries.pop(key, None) is not None def clear(self) -> None: """Drop all entries. 
         """Drop all entries. Does not persist; call :meth:`save`."""
         with self._lock:
             self._entries = {}
 
-    def __contains__(self, secret_id: object) -> bool:
-        return secret_id in self._entries
+    def __contains__(self, key: object) -> bool:
+        return key in self._entries
 
     def __len__(self) -> int:
         return len(self._entries)
@@ -192,16 +253,27 @@ def close(self) -> None:
         self._backend.close()
 
 
-def _serialize_payload(entries: dict[str, bytes]) -> bytes:
+class _ParsedPayload:
+    """Deserialized cache payload: the HMAC key and the entry dict."""
+
+    __slots__ = ("entries", "hmac_key")
+
+    def __init__(self, hmac_key: bytes, entries: dict[str, bytes]) -> None:
+        self.hmac_key = hmac_key
+        self.entries = entries
+
+
+def _serialize_payload(entries: dict[str, bytes], hmac_key: bytes) -> bytes:
     payload = {
         "version": _PAYLOAD_VERSION,
         "written_at": int(time.time()),
+        "hmac_key": base64.b64encode(hmac_key).decode("ascii"),
         "entries": {k: base64.b64encode(v).decode("ascii") for k, v in entries.items()},
     }
     return json.dumps(payload, separators=(",", ":"), sort_keys=True).encode("utf-8")
 
 
-def _parse_payload(plaintext: bytes) -> dict[str, bytes]:
+def _parse_payload(plaintext: bytes) -> _ParsedPayload:
     try:
         payload = json.loads(plaintext.decode("utf-8"))
     except (UnicodeDecodeError, json.JSONDecodeError) as e:
@@ -213,10 +285,29 @@ def _parse_payload(plaintext: bytes) -> dict[str, bytes]:
         raise CacheError(msg)
 
     version = payload.get("version")
+    if version == 1:
+        logger.info(
+            "Cache payload is legacy v1 (hex-ID keyed) — discarding. "
+            "Next setup run will repopulate with HMAC-keyed v2."
+        )
+        return _ParsedPayload(secrets.token_bytes(_HMAC_KEY_BYTES), {})
     if version != _PAYLOAD_VERSION:
         msg = f"Cache payload version {version} is unsupported (expected {_PAYLOAD_VERSION})"
         raise CacheError(msg)
 
+    raw_hmac_key = payload.get("hmac_key")
+    if not isinstance(raw_hmac_key, str):
+        msg = "Cache payload 'hmac_key' must be a base64 string"
+        raise CacheError(msg)
+    try:
+        hmac_key = base64.b64decode(raw_hmac_key, validate=True)
+    except (ValueError, binascii.Error) as e:
+        msg = f"Cache 'hmac_key' is not valid base64: {e}"
+        raise CacheError(msg) from e
+    if len(hmac_key) != _HMAC_KEY_BYTES:
+        msg = f"Cache 'hmac_key' must be {_HMAC_KEY_BYTES} bytes, got {len(hmac_key)}"
+        raise CacheError(msg)
+
     raw_entries = payload.get("entries")
     if not isinstance(raw_entries, dict):
         msg = "Cache payload 'entries' must be a JSON object"
@@ -232,4 +323,4 @@ def _parse_payload(plaintext: bytes) -> dict[str, bytes]:
         except (ValueError, binascii.Error) as e:
             msg = f"Cache entry {key!r} has invalid base64: {e}"
             raise CacheError(msg) from e
-    return entries
+    return _ParsedPayload(hmac_key, entries)
diff --git a/psi/provider.py b/psi/provider.py
index 81581aa..e2a4eba 100644
--- a/psi/provider.py
+++ b/psi/provider.py
@@ -48,6 +48,16 @@ def parse_mapping(raw: str) -> dict:
     return data
 
 
+def mapping_cache_bytes(mapping_data: dict) -> bytes:
+    """Canonical byte representation of a parsed mapping for cache keying.
+
+    Both the setup writer and the serve reader compute the cache key from
+    this canonical form, so any trailing whitespace or key-order differences
+    in the on-disk mapping file do not produce spurious cache misses.
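+
+    A hypothetical example: parsing ``'{"provider": "p", "key": "K"}'`` and
+    ``'{"key":"K","provider":"p"}'`` yields the same dict, and both
+    serialize here to ``b'{"key":"K","provider":"p"}'``.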
+ """ + return json.dumps(mapping_data, separators=(",", ":"), sort_keys=True).encode("utf-8") + + def get_provider(name: str, settings: PsiSettings) -> SecretProvider: """Instantiate a provider by name from settings.""" from psi.providers import create_provider diff --git a/psi/serve.py b/psi/serve.py index b84e45a..a60889b 100644 --- a/psi/serve.py +++ b/psi/serve.py @@ -13,7 +13,12 @@ from psi.errors import PsiError from psi.files import write_bytes_secure -from psi.provider import close_all_providers, open_all_providers, parse_mapping +from psi.provider import ( + close_all_providers, + mapping_cache_bytes, + open_all_providers, + parse_mapping, +) from psi.secret import validate_secret_id from psi.token import resolve_socket_token @@ -225,8 +230,11 @@ def _handle_lookup(self, secret_id: str) -> None: provider=provider_name, ) + cache_key: str | None = None if cache is not None: - cached = cache.get(secret_id) + cache.maybe_reload() + cache_key = cache.cache_key(mapping_cache_bytes(mapping_data)) + cached = cache.get(cache_key) if cached is not None: self._respond(200, cached) audit.bind(outcome="success", source="cache").debug("lookup") @@ -243,8 +251,8 @@ def _handle_lookup(self, secret_id: str) -> None: self._respond_error(502, "internal_error", str(e)) return - if cache is not None: - cache.set(secret_id, value) + if cache is not None and cache_key is not None: + cache.set(cache_key, value) self._respond(200, value) audit.bind(outcome="success", source="provider").info("lookup") diff --git a/psi/setup.py b/psi/setup.py index c7f9ec3..f609ed2 100644 --- a/psi/setup.py +++ b/psi/setup.py @@ -44,7 +44,9 @@ def run_setup( settings.state_dir.mkdir(parents=True, exist_ok=True) cache = _open_setup_cache(settings) - cache_updates: dict[str, bytes] = {} + # Keyed by the canonical mapping bytes so the caller can compute the + # HMAC cache key once the cache object is available. + values_by_mapping: dict[bytes, bytes] = {} try: for workload_name, workload in settings.workloads.items(): @@ -54,18 +56,16 @@ def run_setup( logger.info("Workload: {}", workload_name) if workload.provider == "infisical": - _setup_infisical_workload(settings, workload_name, cache_updates) + _setup_infisical_workload(settings, workload_name, values_by_mapping) elif workload.provider == "nitrokeyhsm": logger.info("Nitrokey HSM workload — secrets created via 'psi nitrokeyhsm store'") else: logger.warning("Unknown provider '{}', skipping", workload.provider) - if cache is not None: - if cache_updates: - logger.info("Writing {} entries to secret cache", len(cache_updates)) - for key, value in cache_updates.items(): - cache.set(key, value) - _prune_stale_cache_entries(cache) + if cache is not None and values_by_mapping: + logger.info("Writing {} entries to secret cache", len(values_by_mapping)) + for mapping_bytes, value in values_by_mapping.items(): + cache.set(cache.cache_key(mapping_bytes), value) cache.save() finally: if cache is not None: @@ -76,27 +76,6 @@ def run_setup( logger.info("Setup complete.") -def _prune_stale_cache_entries(cache: Cache) -> None: - """Drop cache entries whose keys are not currently in Podman's secret store. - - Each time ``_register_secrets`` deletes and re-creates a Podman secret, - Podman assigns a new hex ID. The old ID's cache entry becomes orphaned — - valid ciphertext for a secret that no longer exists. Without pruning, the - cache grows unboundedly across setup runs. 
- """ - try: - active_ids = {s.get("ID", "") for s in _list_podman_shell_secrets()} - except httpx.HTTPError as e: - logger.warning("Cannot query Podman secrets for cache pruning: {}", e) - return - - stale = [k for k in cache.entry_ids() if k not in active_ids] - if stale: - logger.info("Pruning {} stale cache entries", len(stale)) - for key in stale: - cache.invalidate(key) - - def _open_setup_cache(settings: PsiSettings) -> Cache | None: """Open the cache for write during setup, or return None on any failure.""" if not settings.cache.enabled or settings.cache.backend is None: @@ -144,13 +123,13 @@ def _is_retryable(exc: Exception) -> bool: def _setup_infisical_workload( settings: PsiSettings, workload_name: str, - cache_updates: dict[str, bytes], + values_by_mapping: dict[bytes, bytes], ) -> None: """Run Infisical-specific setup for a workload with retry.""" last_exc: Exception | None = None for attempt in range(len(_RETRY_DELAYS) + 1): try: - _fetch_and_register_infisical(settings, workload_name, cache_updates) + _fetch_and_register_infisical(settings, workload_name, values_by_mapping) return except (httpx.ConnectError, httpx.HTTPStatusError, ProviderError) as e: cause = e.__cause__ if isinstance(e, ProviderError) else e @@ -174,15 +153,17 @@ def _setup_infisical_workload( def _fetch_and_register_infisical( settings: PsiSettings, workload_name: str, - cache_updates: dict[str, bytes], + values_by_mapping: dict[bytes, bytes], ) -> None: """Fetch secrets from Infisical and register with Podman. - Populates ``cache_updates`` with ``{podman_hex_id: value_bytes}`` so the - caller can flush the encrypted cache once all workloads are processed. - The cache must be keyed by Podman's hex secret ID because that is what - the serve lookup path receives in ``$SECRET_ID``. + Populates ``values_by_mapping`` with ``{canonical_mapping_bytes: value}``. + The caller computes the HMAC cache key from these bytes once the cache + is available. Keying by mapping content makes the cache survive Podman's + delete+create churn — the same mapping always produces the same cache + key, regardless of the hex ID Podman has assigned to it today. """ + from psi.provider import mapping_cache_bytes, parse_mapping from psi.providers.infisical import InfisicalProvider from psi.providers.infisical.models import InfisicalConfig, resolve_auth @@ -229,13 +210,12 @@ def _fetch_and_register_infisical( logger.info("Found {} secrets", len(secrets)) logger.info("Merged: {} unique secrets", len(merged)) - id_map = _register_secrets(settings, workload_name, merged) + _register_secrets(settings, workload_name, merged) _generate_drop_in(settings, workload_name, merged) for key, value in values.items(): - secret_id = id_map.get(key, "") - if secret_id: - cache_updates[secret_id] = value + mapping_bytes = mapping_cache_bytes(parse_mapping(merged[key])) + values_by_mapping[mapping_bytes] = value finally: provider.close() @@ -244,16 +224,15 @@ def _register_secrets( settings: PsiSettings, workload_name: str, secrets: dict[str, str], -) -> dict[str, str]: +) -> None: """Create namespaced Podman secrets with mapping data. - Returns: - Mapping of ``{secret_key: podman_hex_id}`` so the caller can - populate the encrypted cache with the correct lookup key. + The hex IDs Podman assigns during delete+create are no longer tracked — + cache keying is by mapping content via HMAC, not by Podman's volatile + hex IDs. 
""" transport = httpx.HTTPTransport(uds=_podman_socket_url()) base = f"http://localhost/{_PODMAN_API_VERSION}" - id_map: dict[str, str] = {} with httpx.Client(transport=transport, timeout=30.0) as client: for secret_name, mapping_json in secrets.items(): @@ -265,12 +244,8 @@ def _register_secrets( content=mapping_json.encode(), ) resp.raise_for_status() - secret_id = resp.json().get("ID", "") - if secret_id: - id_map[secret_name] = secret_id logger.info("Registered {} secrets with Podman", len(secrets)) - return id_map def _generate_drop_in( diff --git a/psi/unitgen.py b/psi/unitgen.py index 56f0958..a047e34 100644 --- a/psi/unitgen.py +++ b/psi/unitgen.py @@ -159,16 +159,11 @@ def generate_provider_refresh_service(provider: str) -> str: timestamp will only fire once. This wrapper is a plain oneshot with no ``RemainAfterExit``, so its - ``ActiveEnterTimestamp`` updates every run. The timer uses - ``OnUnitActiveSec`` against the wrapper and re-arms correctly. Each run: - - 1. ``systemctl restart psi-{provider}-setup.service`` — re-runs setup, - which re-registers secrets with fresh hex IDs and writes the updated - cache file to disk. - 2. ``systemctl try-restart psi-secrets.service`` — restarts serve so it - reloads the fresh cache from disk. Without this, serve's in-memory - cache keeps the old hex IDs after each refresh and every subsequent - lookup misses the cache until the next operator-triggered restart. + ``ActiveEnterTimestamp`` updates every run. Each run calls + ``systemctl restart`` on the setup unit, which DOES re-run the ExecStart + even when the setup unit was ``active (exited)``. Setup writes the fresh + cache file; ``psi serve`` auto-reloads via mtime watch on the next + lookup — no forced serve restart required. """ return ( "[Unit]\n" @@ -178,7 +173,6 @@ def generate_provider_refresh_service(provider: str) -> str: "[Service]\n" "Type=oneshot\n" f"ExecStart=/usr/bin/systemctl restart psi-{provider}-setup.service\n" - "ExecStart=/usr/bin/systemctl try-restart psi-secrets.service\n" ) diff --git a/tests/test_cache.py b/tests/test_cache.py index 062da0a..134d29d 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -196,3 +196,100 @@ class TestClose: def test_close_delegates_to_backend(self, cache: Cache, backend: FakeBackend) -> None: cache.close() assert backend.close_calls == 1 + + +class TestCacheKey: + def test_same_mapping_bytes_produce_same_key(self, cache: Cache) -> None: + mapping = b'{"provider":"infisical","project":"p","path":"/","key":"K"}' + assert cache.cache_key(mapping) == cache.cache_key(mapping) + + def test_different_mappings_produce_different_keys(self, cache: Cache) -> None: + a = b'{"provider":"infisical","project":"p","path":"/","key":"A"}' + b = b'{"provider":"infisical","project":"p","path":"/","key":"B"}' + assert cache.cache_key(a) != cache.cache_key(b) + + def test_key_is_stable_across_save_and_load(self, cache: Cache) -> None: + mapping = b'{"provider":"infisical","project":"p","path":"/","key":"K"}' + key_before = cache.cache_key(mapping) + cache.set(key_before, b"value") + cache.save() + + reloaded = Cache(cache.path, FakeBackend()) + reloaded.load() + assert reloaded.cache_key(mapping) == key_before + assert reloaded.get(key_before) == b"value" + + def test_hmac_key_is_per_host_random(self, tmp_path: Path) -> None: + """Two freshly-initialized caches produce different keys for same input. + + The HMAC key is random on init; only load() imports one from an + existing file. Cross-host correlation of cache contents is impossible. 
+ """ + mapping = b'{"provider":"infisical","project":"p","path":"/","key":"K"}' + a = Cache(tmp_path / "a.enc", FakeBackend()) + b = Cache(tmp_path / "b.enc", FakeBackend()) + assert a.cache_key(mapping) != b.cache_key(mapping) + + +class TestMaybeReload: + def test_no_reload_when_mtime_unchanged(self, cache: Cache) -> None: + cache.set("k", b"v") + cache.save() + assert cache.maybe_reload() is False + + def test_reloads_when_file_is_rewritten_by_another_writer(self, cache: Cache) -> None: + """Setup writes; serve (running in another process) picks up changes.""" + # Initial state: serve sees "old" + cache.set("k", b"old") + cache.save() + + # Simulate setup's writer: new Cache instance, writes new value + other = Cache(cache.path, FakeBackend()) + other.load() + other.set("k", b"new") + # Force a distinct mtime even on fast filesystems + import os as _os + import time as _time + + stat_before = _os.stat(cache.path) + while True: + other.save() + if _os.stat(cache.path).st_mtime_ns != stat_before.st_mtime_ns: + break + _time.sleep(0.01) + + assert cache.maybe_reload() is True + assert cache.get("k") == b"new" + + def test_missing_file_returns_false_no_crash( + self, tmp_path: Path, backend: FakeBackend + ) -> None: + cache = Cache(tmp_path / "does-not-exist.enc", backend) + assert cache.maybe_reload() is False + + +class TestLegacyV1PayloadDiscarded: + def test_v1_payload_is_treated_as_empty_with_fresh_hmac_key( + self, cache: Cache, backend: FakeBackend + ) -> None: + """A v1 cache file (hex-ID keyed) is ignored on load.""" + import base64 + import json + import time + + legacy_payload = { + "version": 1, + "written_at": int(time.time()), + "entries": { + "abc123hex": base64.b64encode(b"stale").decode(), + }, + } + plaintext = json.dumps(legacy_payload, separators=(",", ":")).encode() + encrypted = backend.encrypt(plaintext) + header = struct.pack(">4sBB", MAGIC, VERSION, backend.tag) + cache.path.write_bytes(header + encrypted) + + fresh = Cache(cache.path, backend) + fresh.load() + assert len(fresh) == 0 + assert fresh.get("abc123hex") is None diff --git a/tests/test_serve_offline.py b/tests/test_serve_offline.py index 826bb84..5988cfd 100644 --- a/tests/test_serve_offline.py +++ b/tests/test_serve_offline.py @@ -83,10 +83,20 @@ def __init__(self, path: str, headers: dict[str, str], body: bytes = b"") -> Non return TestHandler +_DB_URL_MAPPING = '{"provider":"infisical","project":"p","path":"/","key":"DATABASE_URL"}' + + +def _cache_key_for(cache: Cache, mapping_json: str) -> str: + """Compute the HMAC cache key the same way serve does at lookup time.""" + from psi.provider import mapping_cache_bytes, parse_mapping + + return cache.cache_key(mapping_cache_bytes(parse_mapping(mapping_json))) + + @pytest.fixture def populated_cache(tmp_path: Path) -> Cache: cache = Cache(tmp_path / "cache.enc", _FakeBackend()) - cache.set("myapp--DATABASE_URL", b"postgres://prod/db") + cache.set(_cache_key_for(cache, _DB_URL_MAPPING), b"postgres://prod/db") cache.save() fresh = Cache(tmp_path / "cache.enc", _FakeBackend()) fresh.load() @@ -99,9 +109,7 @@ def test_cached_lookup_succeeds_when_provider_raises( tmp_path: Path, populated_cache: Cache, ) -> None: - (tmp_path / "myapp--DATABASE_URL").write_text( - '{"provider":"infisical","project":"p","path":"/","key":"DATABASE_URL"}' - ) + (tmp_path / "myapp--DATABASE_URL").write_text(_DB_URL_MAPPING) provider = MagicMock() provider.lookup.side_effect = ProviderError( "Cannot reach Infisical API", @@ -122,9 +130,8 @@ def 
         tmp_path: Path,
         populated_cache: Cache,
     ) -> None:
-        (tmp_path / "myapp--NEW_KEY").write_text(
-            '{"provider":"infisical","project":"p","path":"/","key":"NEW_KEY"}'
-        )
+        new_mapping = '{"provider":"infisical","project":"p","path":"/","key":"NEW_KEY"}'
+        (tmp_path / "myapp--NEW_KEY").write_text(new_mapping)
         provider = MagicMock()
         provider.lookup.return_value = b"fresh-value"
 
@@ -135,7 +142,7 @@ def test_cache_miss_falls_through_to_provider(
         assert h._status == 200
         assert h.wfile.getvalue() == b"fresh-value"
         provider.lookup.assert_called_once()
-        assert populated_cache.get("myapp--NEW_KEY") == b"fresh-value"
+        assert populated_cache.get(_cache_key_for(populated_cache, new_mapping)) == b"fresh-value"
 
     def test_provider_error_without_cache_entry_returns_502(
         self,
diff --git a/tests/test_setup.py b/tests/test_setup.py
index f8e3884..ee64d04 100644
--- a/tests/test_setup.py
+++ b/tests/test_setup.py
@@ -16,7 +16,6 @@
     _RETRY_DELAYS,
     _generate_drop_in,
     _is_retryable,
-    _prune_stale_cache_entries,
     _register_secrets,
     _setup_infisical_workload,
 )
@@ -347,21 +346,16 @@ def mock_fetch(settings, workload_name, cache_updates):
         assert call_count == 1
 
 
-class TestRegisterSecretsIdMap:
-    def test_returns_secret_id_from_podman_api(self, tmp_path: Path) -> None:
-        """_register_secrets returns {key: hex_id} for cache keying."""
+class TestRegisterSecrets:
+    def test_calls_podman_api_with_delete_then_create_per_secret(self, tmp_path: Path) -> None:
+        """_register_secrets issues delete+create for each mapping."""
         delete_resp = httpx.Response(204, request=httpx.Request("DELETE", "http://x"))
         create_resp = httpx.Response(
             200,
-            json={"ID": "abc123hex"},
+            json={"ID": "ignored"},
             request=httpx.Request("POST", "http://x"),
         )
 
-        def mock_request(method, url, **kwargs):
-            if method == "DELETE":
-                return delete_resp
-            return create_resp
-
         settings = _make_settings(tmp_path)
 
         with patch("psi.setup.httpx.Client") as mock_client_cls:
@@ -369,59 +363,11 @@ def mock_request(method, url, **kwargs):
             client.delete.return_value = delete_resp
             client.post.return_value = create_resp
 
-            id_map = _register_secrets(settings, "myapp", {"DB_URL": "{}"})
-
-            assert id_map == {"DB_URL": "abc123hex"}
-
-
-class TestPruneStaleCacheEntries:
-    def test_drops_entries_not_in_active_podman_ids(self) -> None:
-        """Orphaned cache entries from prior setup runs are pruned."""
-        from unittest.mock import MagicMock
-
-        cache = MagicMock()
-        cache.entry_ids.return_value = ["active1", "stale-old", "active2", "stale-older"]
-
-        podman_secrets = [
-            {"ID": "active1", "Spec": {"Name": "x", "Driver": {"Name": "shell"}}},
-            {"ID": "active2", "Spec": {"Name": "y", "Driver": {"Name": "shell"}}},
-        ]
-
-        with patch("psi.setup._list_podman_shell_secrets", return_value=podman_secrets):
-            _prune_stale_cache_entries(cache)
-
-        invalidated = [call.args[0] for call in cache.invalidate.call_args_list]
-        assert sorted(invalidated) == ["stale-old", "stale-older"]
-
-    def test_keeps_cache_intact_if_podman_api_unreachable(self) -> None:
-        """A Podman API failure should not drop any entries."""
-        from unittest.mock import MagicMock
+            _register_secrets(settings, "myapp", {"DB_URL": "{}", "API_KEY": "{}"})
 
-        cache = MagicMock()
-        cache.entry_ids.return_value = ["keep1", "keep2"]
-
-        with patch(
-            "psi.setup._list_podman_shell_secrets",
-            side_effect=httpx.ConnectError("refused"),
-        ):
-            _prune_stale_cache_entries(cache)
-
-        cache.invalidate.assert_not_called()
-
-    def test_no_op_when_cache_is_already_clean(self) -> None:
-        """No invalidate calls when every entry is already in Podman."""
-        from unittest.mock import MagicMock
-
-        cache = MagicMock()
-        cache.entry_ids.return_value = ["abc", "def"]
-
-        with patch(
-            "psi.setup._list_podman_shell_secrets",
-            return_value=[
-                {"ID": "abc", "Spec": {"Name": "x", "Driver": {"Name": "shell"}}},
-                {"ID": "def", "Spec": {"Name": "y", "Driver": {"Name": "shell"}}},
-            ],
-        ):
-            _prune_stale_cache_entries(cache)
-
-        cache.invalidate.assert_not_called()
+            assert client.delete.call_count == 2
+            assert client.post.call_count == 2
+            assert (
+                "myapp--DB_URL" in client.delete.call_args_list[0].args[0]
+                or "myapp--DB_URL" in client.delete.call_args_list[1].args[0]
+            )
diff --git a/tests/test_unitgen.py b/tests/test_unitgen.py
index 0017757..1f5e2a9 100644
--- a/tests/test_unitgen.py
+++ b/tests/test_unitgen.py
@@ -478,13 +478,13 @@ def test_orders_after_setup_unit(self) -> None:
         content = generate_provider_refresh_service("infisical")
         assert "After=psi-infisical-setup.service" in content
 
-    def test_restarts_psi_secrets_so_serve_reloads_the_fresh_cache(self) -> None:
-        """After setup writes a fresh cache with new hex IDs, psi-secrets must
-        restart to reload it — otherwise serve keeps the old IDs in memory and
-        every subsequent lookup misses the cache.
+    def test_does_not_force_restart_psi_secrets(self) -> None:
+        """Serve auto-reloads via cache file mtime watch. The refresh wrapper
+        must not restart psi-secrets — doing so caused a 30s window of
+        lookup failures after every refresh.
         """
         content = generate_provider_refresh_service("infisical")
-        assert "ExecStart=/usr/bin/systemctl try-restart psi-secrets.service" in content
+        assert "psi-secrets.service" not in content
 
 
 class TestProviderRefreshTimer: