From 6e95e894214f629faf4521402f27fd0fc261c401 Mon Sep 17 00:00:00 2001 From: Joe Doss Date: Fri, 17 Apr 2026 16:15:38 -0500 Subject: [PATCH] Key cache by HMAC of mapping content, auto-reload on mtime change MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The cache was keyed by Podman's hex secret ID, which Podman regenerates every time setup runs its delete+create cycle. This caused the cache to be silently useless after every refresh timer fire — serve's in-memory dict had the old IDs, Podman was handing out new IDs, 100% cache miss rate until a manual serve restart. Observed on the test server: 1554 lookups across 30 minutes after a refresh fire, every single one fell through to the provider. When Infisical then went down, every container secret lookup failed with 502 — the exact outage the cache was built to prevent. Root-cause fix, one change: - Cache keys are HMAC-SHA256 of the mapping's canonical JSON bytes, not the Podman hex ID. Same mapping always yields the same key, no matter how often Podman churns the hex IDs. The HMAC key is random, per-host, stored inside the encrypted cache envelope — mapping hashes cannot be correlated across deployments. - Serve calls cache.maybe_reload() on the lookup hot path. A stat() per request (~1μs); actual reload only when setup has rewritten the file. Rotations propagate to serve without a restart. Cleanups enabled by the new design: - Drop _prune_stale_cache_entries from setup — no stale entries to prune when keys are content-derived. - Drop the id_map return from _register_secrets — cache doesn't need hex IDs any more. - Drop ExecStart=systemctl try-restart psi-secrets.service from the refresh wrapper — auto-reload handles it. Legacy v1 payloads (hex-ID keyed) are discarded on load; next save rewrites in v2 format with a freshly generated HMAC key. Container lookups during the one-time transition fall through to the provider exactly once.
--- psi/cache.py | 127 +++++++++++++++++++++++++++++++----- psi/provider.py | 10 +++ psi/serve.py | 16 +++-- psi/setup.py | 73 +++++++-------------- psi/unitgen.py | 16 ++--- tests/test_cache.py | 97 +++++++++++++++++++++++++++ tests/test_serve_offline.py | 23 ++++--- tests/test_setup.py | 76 ++++----------------- tests/test_unitgen.py | 10 +-- 9 files changed, 288 insertions(+), 160 deletions(-) diff --git a/psi/cache.py b/psi/cache.py index 7d4ae60..905bd8b 100644 --- a/psi/cache.py +++ b/psi/cache.py @@ -10,6 +10,16 @@ - A :class:`threading.Lock` serializes concurrent writers; reads are lock-free under CPython's GIL. +Cache keys are HMAC-SHA256 of the secret's mapping JSON bytes, using a random +32-byte key generated once on first save and preserved across loads. Keying on +the mapping (provider coordinates) instead of Podman's hex secret ID means the +cache survives Podman's delete+create churn during setup runs — the same +mapping always produces the same cache key, regardless of which hex ID Podman +currently associates with it. + +Serve checks the file mtime on each lookup and reloads in place when setup has +written a fresh version. No forced restart required after rotations. + On-disk envelope: :: @@ -21,14 +31,22 @@ Inside the decrypted payload is a JSON object:: - {"version": 1, "written_at": , "entries": {"": ""}} + {"version": 2, "written_at": , + "hmac_key": "", + "entries": {"": ""}} + +Payload version 1 (legacy, keyed by Podman hex IDs) is accepted on load but +discarded — the next save writes version 2 with a freshly generated HMAC key. 
""" from __future__ import annotations import base64 import binascii +import hashlib +import hmac import json +import secrets import struct import threading import time @@ -48,7 +66,8 @@ VERSION = 0x01 _HEADER_FMT = ">4sBB" _HEADER_SIZE = struct.calcsize(_HEADER_FMT) -_PAYLOAD_VERSION = 1 +_PAYLOAD_VERSION = 2 +_HMAC_KEY_BYTES = 32 class CacheError(ProviderError): @@ -94,8 +113,10 @@ def __init__(self, path: Path, backend: CacheBackend) -> None: self._path = path self._backend = backend self._entries: dict[str, bytes] = {} + self._hmac_key: bytes = secrets.token_bytes(_HMAC_KEY_BYTES) self._lock = threading.Lock() self._loaded = False + self._mtime_ns: int = 0 @property def path(self) -> Path: @@ -107,12 +128,28 @@ def backend_tag(self) -> int: """The backend discriminator byte used in the envelope header.""" return self._backend.tag + def cache_key(self, mapping_bytes: bytes) -> str: + """Compute the stable cache key for a secret's mapping JSON. + + HMAC-SHA256 over the raw mapping bytes with the cache's HMAC key. Two + callers (setup writer, serve reader) that see the same mapping bytes + produce identical cache keys, so the cache survives Podman's hex ID + churn. The HMAC key is per-host (never leaves the cache file) so + mapping hashes cannot be correlated across deployments. + """ + return hmac.new(self._hmac_key, mapping_bytes, hashlib.sha256).hexdigest() + def load(self) -> None: """Decrypt the backing file into memory. If the file is missing, start with an empty cache — ``save()`` will create it on the next mutation. If the file exists but is malformed or was written by a different backend, raise :class:`CacheError`. + + Legacy v1 payloads (hex-ID keyed) are silently discarded — the next + save writes v2 with a freshly generated HMAC key. A legacy cache is + indistinguishable from no cache at all; container startups during the + transition fall through to the provider. 
""" if not self._path.exists(): logger.info("Cache file not found at {}, starting empty", self._path) @@ -121,6 +158,7 @@ def load(self) -> None: return raw = self._path.read_bytes() + self._mtime_ns = self._path.stat().st_mtime_ns if len(raw) < _HEADER_SIZE: msg = f"Cache file {self._path} is too short to contain a header" raise CacheError(msg) @@ -141,44 +179,67 @@ def load(self) -> None: raise CacheError(msg) plaintext = self._backend.decrypt(raw[_HEADER_SIZE:]) - self._entries = _parse_payload(plaintext) + parsed = _parse_payload(plaintext) + self._hmac_key = parsed.hmac_key + self._entries = parsed.entries self._loaded = True logger.info("Loaded {} cache entries from {}", len(self._entries), self._path) + def maybe_reload(self) -> bool: + """Reload from disk if the backing file's mtime has changed. + + Called from the lookup hot path so serve picks up fresh entries + written by setup without needing a process restart. A ``stat`` call + is ~1μs on modern kernels; actual reload happens only when setup + has finished a write. + + Returns: + True if a reload happened, False otherwise. 
+ """ + try: + current_mtime = self._path.stat().st_mtime_ns + except FileNotFoundError: + return False + if current_mtime == self._mtime_ns: + return False + self.load() + return True + def save(self) -> None: """Encrypt the in-memory dict and atomically replace the backing file.""" with self._lock: - plaintext = _serialize_payload(self._entries) + plaintext = _serialize_payload(self._entries, self._hmac_key) payload = self._backend.encrypt(plaintext) header = struct.pack(_HEADER_FMT, MAGIC, VERSION, self._backend.tag) write_bytes_secure(self._path, header + payload) + self._mtime_ns = self._path.stat().st_mtime_ns - def get(self, secret_id: str) -> bytes | None: - """Return the cached value for ``secret_id``, or ``None`` on miss.""" - return self._entries.get(secret_id) + def get(self, key: str) -> bytes | None: + """Return the cached value for ``key``, or ``None`` on miss.""" + return self._entries.get(key) - def set(self, secret_id: str, value: bytes) -> None: - """Insert or update ``secret_id``. Does not persist; call :meth:`save`.""" + def set(self, key: str, value: bytes) -> None: + """Insert or update ``key``. Does not persist; call :meth:`save`.""" with self._lock: - self._entries[secret_id] = value + self._entries[key] = value def bulk_set(self, entries: dict[str, bytes]) -> None: """Replace all entries atomically. Does not persist; call :meth:`save`.""" with self._lock: self._entries = dict(entries) - def invalidate(self, secret_id: str) -> bool: - """Drop ``secret_id`` from the cache. Returns True if it was present.""" + def invalidate(self, key: str) -> bool: + """Drop ``key`` from the cache. Returns True if it was present.""" with self._lock: - return self._entries.pop(secret_id, None) is not None + return self._entries.pop(key, None) is not None def clear(self) -> None: """Drop all entries. 
Does not persist; call :meth:`save`.""" with self._lock: self._entries = {} - def __contains__(self, secret_id: object) -> bool: - return secret_id in self._entries + def __contains__(self, key: object) -> bool: + return key in self._entries def __len__(self) -> int: return len(self._entries) @@ -192,16 +253,27 @@ def close(self) -> None: self._backend.close() -def _serialize_payload(entries: dict[str, bytes]) -> bytes: +class _ParsedPayload: + """Deserialized cache payload: the HMAC key and the entry dict.""" + + __slots__ = ("entries", "hmac_key") + + def __init__(self, hmac_key: bytes, entries: dict[str, bytes]) -> None: + self.hmac_key = hmac_key + self.entries = entries + + +def _serialize_payload(entries: dict[str, bytes], hmac_key: bytes) -> bytes: payload = { "version": _PAYLOAD_VERSION, "written_at": int(time.time()), + "hmac_key": base64.b64encode(hmac_key).decode("ascii"), "entries": {k: base64.b64encode(v).decode("ascii") for k, v in entries.items()}, } return json.dumps(payload, separators=(",", ":"), sort_keys=True).encode("utf-8") -def _parse_payload(plaintext: bytes) -> dict[str, bytes]: +def _parse_payload(plaintext: bytes) -> _ParsedPayload: try: payload = json.loads(plaintext.decode("utf-8")) except (UnicodeDecodeError, json.JSONDecodeError) as e: @@ -213,10 +285,29 @@ def _parse_payload(plaintext: bytes) -> dict[str, bytes]: raise CacheError(msg) version = payload.get("version") + if version == 1: + logger.info( + "Cache payload is legacy v1 (hex-ID keyed) — discarding. " + "Next setup run will repopulate with HMAC-keyed v2." 
+ ) + return _ParsedPayload(secrets.token_bytes(_HMAC_KEY_BYTES), {}) if version != _PAYLOAD_VERSION: msg = f"Cache payload version {version} is unsupported (expected {_PAYLOAD_VERSION})" raise CacheError(msg) + raw_hmac_key = payload.get("hmac_key") + if not isinstance(raw_hmac_key, str): + msg = "Cache payload 'hmac_key' must be a base64 string" + raise CacheError(msg) + try: + hmac_key = base64.b64decode(raw_hmac_key, validate=True) + except (ValueError, binascii.Error) as e: + msg = f"Cache 'hmac_key' is not valid base64: {e}" + raise CacheError(msg) from e + if len(hmac_key) != _HMAC_KEY_BYTES: + msg = f"Cache 'hmac_key' must be {_HMAC_KEY_BYTES} bytes, got {len(hmac_key)}" + raise CacheError(msg) + raw_entries = payload.get("entries") if not isinstance(raw_entries, dict): msg = "Cache payload 'entries' must be a JSON object" @@ -232,4 +323,4 @@ def _parse_payload(plaintext: bytes) -> dict[str, bytes]: except (ValueError, binascii.Error) as e: msg = f"Cache entry {key!r} has invalid base64: {e}" raise CacheError(msg) from e - return entries + return _ParsedPayload(hmac_key, entries) diff --git a/psi/provider.py b/psi/provider.py index 81581aa..e2a4eba 100644 --- a/psi/provider.py +++ b/psi/provider.py @@ -48,6 +48,16 @@ def parse_mapping(raw: str) -> dict: return data +def mapping_cache_bytes(mapping_data: dict) -> bytes: + """Canonical byte representation of a parsed mapping for cache keying. + + Both the setup writer and the serve reader compute the cache key from + this canonical form, so any trailing whitespace or key-order differences + in the on-disk mapping file do not produce spurious cache misses. 
+ """ + return json.dumps(mapping_data, separators=(",", ":"), sort_keys=True).encode("utf-8") + + def get_provider(name: str, settings: PsiSettings) -> SecretProvider: """Instantiate a provider by name from settings.""" from psi.providers import create_provider diff --git a/psi/serve.py b/psi/serve.py index b84e45a..a60889b 100644 --- a/psi/serve.py +++ b/psi/serve.py @@ -13,7 +13,12 @@ from psi.errors import PsiError from psi.files import write_bytes_secure -from psi.provider import close_all_providers, open_all_providers, parse_mapping +from psi.provider import ( + close_all_providers, + mapping_cache_bytes, + open_all_providers, + parse_mapping, +) from psi.secret import validate_secret_id from psi.token import resolve_socket_token @@ -225,8 +230,11 @@ def _handle_lookup(self, secret_id: str) -> None: provider=provider_name, ) + cache_key: str | None = None if cache is not None: - cached = cache.get(secret_id) + cache.maybe_reload() + cache_key = cache.cache_key(mapping_cache_bytes(mapping_data)) + cached = cache.get(cache_key) if cached is not None: self._respond(200, cached) audit.bind(outcome="success", source="cache").debug("lookup") @@ -243,8 +251,8 @@ def _handle_lookup(self, secret_id: str) -> None: self._respond_error(502, "internal_error", str(e)) return - if cache is not None: - cache.set(secret_id, value) + if cache is not None and cache_key is not None: + cache.set(cache_key, value) self._respond(200, value) audit.bind(outcome="success", source="provider").info("lookup") diff --git a/psi/setup.py b/psi/setup.py index c7f9ec3..f609ed2 100644 --- a/psi/setup.py +++ b/psi/setup.py @@ -44,7 +44,9 @@ def run_setup( settings.state_dir.mkdir(parents=True, exist_ok=True) cache = _open_setup_cache(settings) - cache_updates: dict[str, bytes] = {} + # Keyed by the canonical mapping bytes so the caller can compute the + # HMAC cache key once the cache object is available. 
+ values_by_mapping: dict[bytes, bytes] = {} try: for workload_name, workload in settings.workloads.items(): @@ -54,18 +56,16 @@ def run_setup( logger.info("Workload: {}", workload_name) if workload.provider == "infisical": - _setup_infisical_workload(settings, workload_name, cache_updates) + _setup_infisical_workload(settings, workload_name, values_by_mapping) elif workload.provider == "nitrokeyhsm": logger.info("Nitrokey HSM workload — secrets created via 'psi nitrokeyhsm store'") else: logger.warning("Unknown provider '{}', skipping", workload.provider) - if cache is not None: - if cache_updates: - logger.info("Writing {} entries to secret cache", len(cache_updates)) - for key, value in cache_updates.items(): - cache.set(key, value) - _prune_stale_cache_entries(cache) + if cache is not None and values_by_mapping: + logger.info("Writing {} entries to secret cache", len(values_by_mapping)) + for mapping_bytes, value in values_by_mapping.items(): + cache.set(cache.cache_key(mapping_bytes), value) cache.save() finally: if cache is not None: @@ -76,27 +76,6 @@ def run_setup( logger.info("Setup complete.") -def _prune_stale_cache_entries(cache: Cache) -> None: - """Drop cache entries whose keys are not currently in Podman's secret store. - - Each time ``_register_secrets`` deletes and re-creates a Podman secret, - Podman assigns a new hex ID. The old ID's cache entry becomes orphaned — - valid ciphertext for a secret that no longer exists. Without pruning, the - cache grows unboundedly across setup runs. 
- """ - try: - active_ids = {s.get("ID", "") for s in _list_podman_shell_secrets()} - except httpx.HTTPError as e: - logger.warning("Cannot query Podman secrets for cache pruning: {}", e) - return - - stale = [k for k in cache.entry_ids() if k not in active_ids] - if stale: - logger.info("Pruning {} stale cache entries", len(stale)) - for key in stale: - cache.invalidate(key) - - def _open_setup_cache(settings: PsiSettings) -> Cache | None: """Open the cache for write during setup, or return None on any failure.""" if not settings.cache.enabled or settings.cache.backend is None: @@ -144,13 +123,13 @@ def _is_retryable(exc: Exception) -> bool: def _setup_infisical_workload( settings: PsiSettings, workload_name: str, - cache_updates: dict[str, bytes], + values_by_mapping: dict[bytes, bytes], ) -> None: """Run Infisical-specific setup for a workload with retry.""" last_exc: Exception | None = None for attempt in range(len(_RETRY_DELAYS) + 1): try: - _fetch_and_register_infisical(settings, workload_name, cache_updates) + _fetch_and_register_infisical(settings, workload_name, values_by_mapping) return except (httpx.ConnectError, httpx.HTTPStatusError, ProviderError) as e: cause = e.__cause__ if isinstance(e, ProviderError) else e @@ -174,15 +153,17 @@ def _setup_infisical_workload( def _fetch_and_register_infisical( settings: PsiSettings, workload_name: str, - cache_updates: dict[str, bytes], + values_by_mapping: dict[bytes, bytes], ) -> None: """Fetch secrets from Infisical and register with Podman. - Populates ``cache_updates`` with ``{podman_hex_id: value_bytes}`` so the - caller can flush the encrypted cache once all workloads are processed. - The cache must be keyed by Podman's hex secret ID because that is what - the serve lookup path receives in ``$SECRET_ID``. + Populates ``values_by_mapping`` with ``{canonical_mapping_bytes: value}``. + The caller computes the HMAC cache key from these bytes once the cache + is available. 
Keying by mapping content makes the cache survive Podman's + delete+create churn — the same mapping always produces the same cache + key, regardless of the hex ID Podman has assigned to it today. """ + from psi.provider import mapping_cache_bytes, parse_mapping from psi.providers.infisical import InfisicalProvider from psi.providers.infisical.models import InfisicalConfig, resolve_auth @@ -229,13 +210,12 @@ def _fetch_and_register_infisical( logger.info("Found {} secrets", len(secrets)) logger.info("Merged: {} unique secrets", len(merged)) - id_map = _register_secrets(settings, workload_name, merged) + _register_secrets(settings, workload_name, merged) _generate_drop_in(settings, workload_name, merged) for key, value in values.items(): - secret_id = id_map.get(key, "") - if secret_id: - cache_updates[secret_id] = value + mapping_bytes = mapping_cache_bytes(parse_mapping(merged[key])) + values_by_mapping[mapping_bytes] = value finally: provider.close() @@ -244,16 +224,15 @@ def _register_secrets( settings: PsiSettings, workload_name: str, secrets: dict[str, str], -) -> dict[str, str]: +) -> None: """Create namespaced Podman secrets with mapping data. - Returns: - Mapping of ``{secret_key: podman_hex_id}`` so the caller can - populate the encrypted cache with the correct lookup key. + The hex IDs Podman assigns during delete+create are no longer tracked — + cache keying is by mapping content via HMAC, not by Podman's volatile + hex IDs. 
""" transport = httpx.HTTPTransport(uds=_podman_socket_url()) base = f"http://localhost/{_PODMAN_API_VERSION}" - id_map: dict[str, str] = {} with httpx.Client(transport=transport, timeout=30.0) as client: for secret_name, mapping_json in secrets.items(): @@ -265,12 +244,8 @@ def _register_secrets( content=mapping_json.encode(), ) resp.raise_for_status() - secret_id = resp.json().get("ID", "") - if secret_id: - id_map[secret_name] = secret_id logger.info("Registered {} secrets with Podman", len(secrets)) - return id_map def _generate_drop_in( diff --git a/psi/unitgen.py b/psi/unitgen.py index 56f0958..a047e34 100644 --- a/psi/unitgen.py +++ b/psi/unitgen.py @@ -159,16 +159,11 @@ def generate_provider_refresh_service(provider: str) -> str: timestamp will only fire once. This wrapper is a plain oneshot with no ``RemainAfterExit``, so its - ``ActiveEnterTimestamp`` updates every run. The timer uses - ``OnUnitActiveSec`` against the wrapper and re-arms correctly. Each run: - - 1. ``systemctl restart psi-{provider}-setup.service`` — re-runs setup, - which re-registers secrets with fresh hex IDs and writes the updated - cache file to disk. - 2. ``systemctl try-restart psi-secrets.service`` — restarts serve so it - reloads the fresh cache from disk. Without this, serve's in-memory - cache keeps the old hex IDs after each refresh and every subsequent - lookup misses the cache until the next operator-triggered restart. + ``ActiveEnterTimestamp`` updates every run. Each run calls + ``systemctl restart`` on the setup unit, which DOES re-run the ExecStart + even when the setup unit was ``active (exited)``. Setup writes the fresh + cache file; ``psi serve`` auto-reloads via mtime watch on the next + lookup — no forced serve restart required. 
""" return ( "[Unit]\n" @@ -178,7 +173,6 @@ def generate_provider_refresh_service(provider: str) -> str: "[Service]\n" "Type=oneshot\n" f"ExecStart=/usr/bin/systemctl restart psi-{provider}-setup.service\n" - "ExecStart=/usr/bin/systemctl try-restart psi-secrets.service\n" ) diff --git a/tests/test_cache.py b/tests/test_cache.py index 062da0a..134d29d 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -196,3 +196,100 @@ class TestClose: def test_close_delegates_to_backend(self, cache: Cache, backend: FakeBackend) -> None: cache.close() assert backend.close_calls == 1 + + +class TestCacheKey: + def test_same_mapping_bytes_produce_same_key(self, cache: Cache) -> None: + mapping = b'{"provider":"infisical","project":"p","path":"/","key":"K"}' + assert cache.cache_key(mapping) == cache.cache_key(mapping) + + def test_different_mappings_produce_different_keys(self, cache: Cache) -> None: + a = b'{"provider":"infisical","project":"p","path":"/","key":"A"}' + b = b'{"provider":"infisical","project":"p","path":"/","key":"B"}' + assert cache.cache_key(a) != cache.cache_key(b) + + def test_key_is_stable_across_save_and_load(self, cache: Cache) -> None: + mapping = b'{"provider":"infisical","project":"p","path":"/","key":"K"}' + key_before = cache.cache_key(mapping) + cache.set(key_before, b"value") + cache.save() + + reloaded = Cache(cache.path, FakeBackend()) + reloaded.load() + assert reloaded.cache_key(mapping) == key_before + assert reloaded.get(key_before) == b"value" + + def test_hmac_key_is_per_host_random(self, tmp_path: Path) -> None: + """Two freshly-initialized caches produce different keys for same input. + + The HMAC key is random on init; only load() imports one from an + existing file. Cross-host correlation of cache contents is impossible. 
+ """ + mapping = b'{"provider":"infisical","project":"p","path":"/","key":"K"}' + a = Cache(tmp_path / "a.enc", FakeBackend()) + b = Cache(tmp_path / "b.enc", FakeBackend()) + assert a.cache_key(mapping) != b.cache_key(mapping) + + +class TestMaybeReload: + def test_no_reload_when_mtime_unchanged(self, cache: Cache) -> None: + cache.set("k", b"v") + cache.save() + assert cache.maybe_reload() is False + + def test_reloads_when_file_is_rewritten_by_another_writer(self, cache: Cache) -> None: + """Setup writes; serve (running in another process) picks up changes.""" + # Initial state: serve sees "old" + cache.set("k", b"old") + cache.save() + + # Simulate setup's writer: new Cache instance, writes new value + other = Cache(cache.path, FakeBackend()) + other.load() + other.set("k", b"new") + # Force a distinct mtime even on fast filesystems + import os as _os + import time as _time + + stat_before = _os.stat(cache.path) + while True: + other.save() + if _os.stat(cache.path).st_mtime_ns != stat_before.st_mtime_ns: + break + _time.sleep(0.01) + + assert cache.maybe_reload() is True + assert cache.get("k") == b"new" + + def test_missing_file_returns_false_no_crash( + self, tmp_path: Path, backend: FakeBackend + ) -> None: + cache = Cache(tmp_path / "does-not-exist.enc", backend) + assert cache.maybe_reload() is False + + +class TestLegacyV1PayloadDiscarded: + def test_v1_payload_is_treated_as_empty_with_fresh_hmac_key( + self, cache: Cache, backend: FakeBackend + ) -> None: + """A v1 cache file (hex-ID keyed) is ignored on load.""" + import base64 + import json + import time + + legacy_payload = { + "version": 1, + "written_at": int(time.time()), + "entries": { + "abc123hex": base64.b64encode(b"stale").decode(), + }, + } + plaintext = json.dumps(legacy_payload, separators=(",", ":")).encode() + encrypted = backend.encrypt(plaintext) + header = struct.pack(">4sBB", MAGIC, VERSION, backend.tag) + cache.path.write_bytes(header + encrypted) + + fresh = Cache(cache.path, 
backend) + fresh.load() + assert len(fresh) == 0 + assert fresh.get("abc123hex") is None diff --git a/tests/test_serve_offline.py b/tests/test_serve_offline.py index 826bb84..5988cfd 100644 --- a/tests/test_serve_offline.py +++ b/tests/test_serve_offline.py @@ -83,10 +83,20 @@ def __init__(self, path: str, headers: dict[str, str], body: bytes = b"") -> Non return TestHandler +_DB_URL_MAPPING = '{"provider":"infisical","project":"p","path":"/","key":"DATABASE_URL"}' + + +def _cache_key_for(cache: Cache, mapping_json: str) -> str: + """Compute the HMAC cache key the same way serve does at lookup time.""" + from psi.provider import mapping_cache_bytes, parse_mapping + + return cache.cache_key(mapping_cache_bytes(parse_mapping(mapping_json))) + + @pytest.fixture def populated_cache(tmp_path: Path) -> Cache: cache = Cache(tmp_path / "cache.enc", _FakeBackend()) - cache.set("myapp--DATABASE_URL", b"postgres://prod/db") + cache.set(_cache_key_for(cache, _DB_URL_MAPPING), b"postgres://prod/db") cache.save() fresh = Cache(tmp_path / "cache.enc", _FakeBackend()) fresh.load() @@ -99,9 +109,7 @@ def test_cached_lookup_succeeds_when_provider_raises( tmp_path: Path, populated_cache: Cache, ) -> None: - (tmp_path / "myapp--DATABASE_URL").write_text( - '{"provider":"infisical","project":"p","path":"/","key":"DATABASE_URL"}' - ) + (tmp_path / "myapp--DATABASE_URL").write_text(_DB_URL_MAPPING) provider = MagicMock() provider.lookup.side_effect = ProviderError( "Cannot reach Infisical API", @@ -122,9 +130,8 @@ def test_cache_miss_falls_through_to_provider( tmp_path: Path, populated_cache: Cache, ) -> None: - (tmp_path / "myapp--NEW_KEY").write_text( - '{"provider":"infisical","project":"p","path":"/","key":"NEW_KEY"}' - ) + new_mapping = '{"provider":"infisical","project":"p","path":"/","key":"NEW_KEY"}' + (tmp_path / "myapp--NEW_KEY").write_text(new_mapping) provider = MagicMock() provider.lookup.return_value = b"fresh-value" @@ -135,7 +142,7 @@ def 
test_cache_miss_falls_through_to_provider( assert h._status == 200 assert h.wfile.getvalue() == b"fresh-value" provider.lookup.assert_called_once() - assert populated_cache.get("myapp--NEW_KEY") == b"fresh-value" + assert populated_cache.get(_cache_key_for(populated_cache, new_mapping)) == b"fresh-value" def test_provider_error_without_cache_entry_returns_502( self, diff --git a/tests/test_setup.py b/tests/test_setup.py index f8e3884..ee64d04 100644 --- a/tests/test_setup.py +++ b/tests/test_setup.py @@ -16,7 +16,6 @@ _RETRY_DELAYS, _generate_drop_in, _is_retryable, - _prune_stale_cache_entries, _register_secrets, _setup_infisical_workload, ) @@ -347,21 +346,16 @@ def mock_fetch(settings, workload_name, cache_updates): assert call_count == 1 -class TestRegisterSecretsIdMap: - def test_returns_secret_id_from_podman_api(self, tmp_path: Path) -> None: - """_register_secrets returns {key: hex_id} for cache keying.""" +class TestRegisterSecrets: + def test_calls_podman_api_with_delete_then_create_per_secret(self, tmp_path: Path) -> None: + """_register_secrets issues delete+create for each mapping.""" delete_resp = httpx.Response(204, request=httpx.Request("DELETE", "http://x")) create_resp = httpx.Response( 200, - json={"ID": "abc123hex"}, + json={"ID": "ignored"}, request=httpx.Request("POST", "http://x"), ) - def mock_request(method, url, **kwargs): - if method == "DELETE": - return delete_resp - return create_resp - settings = _make_settings(tmp_path) with patch("psi.setup.httpx.Client") as mock_client_cls: @@ -369,59 +363,11 @@ def mock_request(method, url, **kwargs): client.delete.return_value = delete_resp client.post.return_value = create_resp - id_map = _register_secrets(settings, "myapp", {"DB_URL": "{}"}) - - assert id_map == {"DB_URL": "abc123hex"} - - -class TestPruneStaleCacheEntries: - def test_drops_entries_not_in_active_podman_ids(self) -> None: - """Orphaned cache entries from prior setup runs are pruned.""" - from unittest.mock import MagicMock - - 
cache = MagicMock() - cache.entry_ids.return_value = ["active1", "stale-old", "active2", "stale-older"] - - podman_secrets = [ - {"ID": "active1", "Spec": {"Name": "x", "Driver": {"Name": "shell"}}}, - {"ID": "active2", "Spec": {"Name": "y", "Driver": {"Name": "shell"}}}, - ] - - with patch("psi.setup._list_podman_shell_secrets", return_value=podman_secrets): - _prune_stale_cache_entries(cache) - - invalidated = [call.args[0] for call in cache.invalidate.call_args_list] - assert sorted(invalidated) == ["stale-old", "stale-older"] - - def test_keeps_cache_intact_if_podman_api_unreachable(self) -> None: - """A Podman API failure should not drop any entries.""" - from unittest.mock import MagicMock + _register_secrets(settings, "myapp", {"DB_URL": "{}", "API_KEY": "{}"}) - cache = MagicMock() - cache.entry_ids.return_value = ["keep1", "keep2"] - - with patch( - "psi.setup._list_podman_shell_secrets", - side_effect=httpx.ConnectError("refused"), - ): - _prune_stale_cache_entries(cache) - - cache.invalidate.assert_not_called() - - def test_no_op_when_cache_is_already_clean(self) -> None: - """No invalidate calls when every entry is already in Podman.""" - from unittest.mock import MagicMock - - cache = MagicMock() - cache.entry_ids.return_value = ["abc", "def"] - - with patch( - "psi.setup._list_podman_shell_secrets", - return_value=[ - {"ID": "abc", "Spec": {"Name": "x", "Driver": {"Name": "shell"}}}, - {"ID": "def", "Spec": {"Name": "y", "Driver": {"Name": "shell"}}}, - ], - ): - _prune_stale_cache_entries(cache) - - cache.invalidate.assert_not_called() + assert client.delete.call_count == 2 + assert client.post.call_count == 2 + assert ( + "myapp--DB_URL" in client.delete.call_args_list[0].args[0] + or "myapp--DB_URL" in client.delete.call_args_list[1].args[0] + ) diff --git a/tests/test_unitgen.py b/tests/test_unitgen.py index 0017757..1f5e2a9 100644 --- a/tests/test_unitgen.py +++ b/tests/test_unitgen.py @@ -478,13 +478,13 @@ def test_orders_after_setup_unit(self) 
-> None: content = generate_provider_refresh_service("infisical") assert "After=psi-infisical-setup.service" in content - def test_restarts_psi_secrets_so_serve_reloads_the_fresh_cache(self) -> None: - """After setup writes a fresh cache with new hex IDs, psi-secrets must - restart to reload it — otherwise serve keeps the old IDs in memory and - every subsequent lookup misses the cache. + def test_does_not_force_restart_psi_secrets(self) -> None: + """Serve auto-reloads via cache file mtime watch. The refresh wrapper + must not restart psi-secrets — doing so caused a 30s window of + lookup failures after every refresh. """ content = generate_provider_refresh_service("infisical") - assert "ExecStart=/usr/bin/systemctl try-restart psi-secrets.service" in content + assert "psi-secrets.service" not in content class TestProviderRefreshTimer: