Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 109 additions & 18 deletions psi/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,16 @@
- A :class:`threading.Lock` serializes concurrent writers; reads are lock-free
under CPython's GIL.

Cache keys are HMAC-SHA256 of the secret's mapping JSON bytes, using a random
32-byte key generated once on first save and preserved across loads. Keying on
the mapping (provider coordinates) instead of Podman's hex secret ID means the
cache survives Podman's delete+create churn during setup runs — the same
mapping always produces the same cache key, regardless of which hex ID Podman
currently associates with it.

Serve checks the file mtime on each lookup and reloads in place when setup has
written a fresh version. No forced restart required after rotations.

On-disk envelope:

::
Expand All @@ -21,14 +31,22 @@

Inside the decrypted payload is a JSON object::

{"version": 1, "written_at": <unix_ts>, "entries": {"<id>": "<base64_value>"}}
{"version": 2, "written_at": <unix_ts>,
"hmac_key": "<base64 32 bytes>",
"entries": {"<hmac-hex>": "<base64_value>"}}

Payload version 1 (legacy, keyed by Podman hex IDs) is accepted on load but
discarded — the next save writes version 2 with a freshly generated HMAC key.
"""

from __future__ import annotations

import base64
import binascii
import hashlib
import hmac
import json
import secrets
import struct
import threading
import time
Expand All @@ -48,7 +66,8 @@
VERSION = 0x01
_HEADER_FMT = ">4sBB"
_HEADER_SIZE = struct.calcsize(_HEADER_FMT)
_PAYLOAD_VERSION = 1
_PAYLOAD_VERSION = 2
_HMAC_KEY_BYTES = 32


class CacheError(ProviderError):
Expand Down Expand Up @@ -94,8 +113,10 @@ def __init__(self, path: Path, backend: CacheBackend) -> None:
self._path = path
self._backend = backend
self._entries: dict[str, bytes] = {}
self._hmac_key: bytes = secrets.token_bytes(_HMAC_KEY_BYTES)
self._lock = threading.Lock()
self._loaded = False
self._mtime_ns: int = 0

@property
def path(self) -> Path:
Expand All @@ -107,12 +128,28 @@ def backend_tag(self) -> int:
"""The backend discriminator byte used in the envelope header."""
return self._backend.tag

def cache_key(self, mapping_bytes: bytes) -> str:
"""Compute the stable cache key for a secret's mapping JSON.

HMAC-SHA256 over the raw mapping bytes with the cache's HMAC key. Two
callers (setup writer, serve reader) that see the same mapping bytes
produce identical cache keys, so the cache survives Podman's hex ID
churn. The HMAC key is per-host (never leaves the cache file) so
mapping hashes cannot be correlated across deployments.
"""
return hmac.new(self._hmac_key, mapping_bytes, hashlib.sha256).hexdigest()

def load(self) -> None:
"""Decrypt the backing file into memory.

If the file is missing, start with an empty cache — ``save()`` will
create it on the next mutation. If the file exists but is malformed or
was written by a different backend, raise :class:`CacheError`.

Legacy v1 payloads (hex-ID keyed) are silently discarded — the next
save writes v2 with a freshly generated HMAC key. A legacy cache is
indistinguishable from no cache at all; container startups during the
transition fall through to the provider.
"""
if not self._path.exists():
logger.info("Cache file not found at {}, starting empty", self._path)
Expand All @@ -121,6 +158,7 @@ def load(self) -> None:
return

raw = self._path.read_bytes()
self._mtime_ns = self._path.stat().st_mtime_ns
if len(raw) < _HEADER_SIZE:
msg = f"Cache file {self._path} is too short to contain a header"
raise CacheError(msg)
Expand All @@ -141,44 +179,67 @@ def load(self) -> None:
raise CacheError(msg)

plaintext = self._backend.decrypt(raw[_HEADER_SIZE:])
self._entries = _parse_payload(plaintext)
parsed = _parse_payload(plaintext)
self._hmac_key = parsed.hmac_key
self._entries = parsed.entries
self._loaded = True
logger.info("Loaded {} cache entries from {}", len(self._entries), self._path)

def maybe_reload(self) -> bool:
"""Reload from disk if the backing file's mtime has changed.

Called from the lookup hot path so serve picks up fresh entries
written by setup without needing a process restart. A ``stat`` call
is ~1μs on modern kernels; actual reload happens only when setup
has finished a write.

Returns:
True if a reload happened, False otherwise.
"""
try:
current_mtime = self._path.stat().st_mtime_ns
except FileNotFoundError:
return False
if current_mtime == self._mtime_ns:
return False
self.load()
return True

def save(self) -> None:
"""Encrypt the in-memory dict and atomically replace the backing file."""
with self._lock:
plaintext = _serialize_payload(self._entries)
plaintext = _serialize_payload(self._entries, self._hmac_key)
payload = self._backend.encrypt(plaintext)
header = struct.pack(_HEADER_FMT, MAGIC, VERSION, self._backend.tag)
write_bytes_secure(self._path, header + payload)
self._mtime_ns = self._path.stat().st_mtime_ns

def get(self, secret_id: str) -> bytes | None:
"""Return the cached value for ``secret_id``, or ``None`` on miss."""
return self._entries.get(secret_id)
def get(self, key: str) -> bytes | None:
"""Return the cached value for ``key``, or ``None`` on miss."""
return self._entries.get(key)

def set(self, secret_id: str, value: bytes) -> None:
"""Insert or update ``secret_id``. Does not persist; call :meth:`save`."""
def set(self, key: str, value: bytes) -> None:
"""Insert or update ``key``. Does not persist; call :meth:`save`."""
with self._lock:
self._entries[secret_id] = value
self._entries[key] = value

def bulk_set(self, entries: dict[str, bytes]) -> None:
"""Replace all entries atomically. Does not persist; call :meth:`save`."""
with self._lock:
self._entries = dict(entries)

def invalidate(self, secret_id: str) -> bool:
"""Drop ``secret_id`` from the cache. Returns True if it was present."""
def invalidate(self, key: str) -> bool:
"""Drop ``key`` from the cache. Returns True if it was present."""
with self._lock:
return self._entries.pop(secret_id, None) is not None
return self._entries.pop(key, None) is not None

    def clear(self) -> None:
        """Drop all entries. Does not persist; call :meth:`save`."""
        with self._lock:
            # Rebind rather than mutate in place: concurrent lock-free readers
            # holding the old dict keep a consistent (stale) snapshot.
            self._entries = {}

def __contains__(self, secret_id: object) -> bool:
return secret_id in self._entries
    def __contains__(self, key: object) -> bool:
        """Membership test (``key in cache``) against the in-memory entries."""
        return key in self._entries

    def __len__(self) -> int:
        """Number of entries currently held in memory."""
        return len(self._entries)
Expand All @@ -192,16 +253,27 @@ def close(self) -> None:
self._backend.close()


def _serialize_payload(entries: dict[str, bytes]) -> bytes:
class _ParsedPayload:
"""Deserialized cache payload: the HMAC key and the entry dict."""

__slots__ = ("entries", "hmac_key")

def __init__(self, hmac_key: bytes, entries: dict[str, bytes]) -> None:
self.hmac_key = hmac_key
self.entries = entries


def _serialize_payload(entries: dict[str, bytes], hmac_key: bytes) -> bytes:
    """Build the compact JSON payload bytes that the backend will encrypt.

    The document carries the payload version, a write timestamp, the
    base64-encoded per-host HMAC key, and base64-encoded entry values,
    serialized deterministically (sorted keys, no extra whitespace).
    """

    def _b64(raw: bytes) -> str:
        return base64.b64encode(raw).decode("ascii")

    document = {
        "version": _PAYLOAD_VERSION,
        "written_at": int(time.time()),
        "hmac_key": _b64(hmac_key),
        "entries": {name: _b64(value) for name, value in entries.items()},
    }
    text = json.dumps(document, separators=(",", ":"), sort_keys=True)
    return text.encode("utf-8")


def _parse_payload(plaintext: bytes) -> dict[str, bytes]:
def _parse_payload(plaintext: bytes) -> _ParsedPayload:
try:
payload = json.loads(plaintext.decode("utf-8"))
except (UnicodeDecodeError, json.JSONDecodeError) as e:
Expand All @@ -213,10 +285,29 @@ def _parse_payload(plaintext: bytes) -> dict[str, bytes]:
raise CacheError(msg)

version = payload.get("version")
if version == 1:
logger.info(
"Cache payload is legacy v1 (hex-ID keyed) — discarding. "
"Next setup run will repopulate with HMAC-keyed v2."
)
return _ParsedPayload(secrets.token_bytes(_HMAC_KEY_BYTES), {})
if version != _PAYLOAD_VERSION:
msg = f"Cache payload version {version} is unsupported (expected {_PAYLOAD_VERSION})"
raise CacheError(msg)

raw_hmac_key = payload.get("hmac_key")
if not isinstance(raw_hmac_key, str):
msg = "Cache payload 'hmac_key' must be a base64 string"
raise CacheError(msg)
try:
hmac_key = base64.b64decode(raw_hmac_key, validate=True)
except (ValueError, binascii.Error) as e:
msg = f"Cache 'hmac_key' is not valid base64: {e}"
raise CacheError(msg) from e
if len(hmac_key) != _HMAC_KEY_BYTES:
msg = f"Cache 'hmac_key' must be {_HMAC_KEY_BYTES} bytes, got {len(hmac_key)}"
raise CacheError(msg)

raw_entries = payload.get("entries")
if not isinstance(raw_entries, dict):
msg = "Cache payload 'entries' must be a JSON object"
Expand All @@ -232,4 +323,4 @@ def _parse_payload(plaintext: bytes) -> dict[str, bytes]:
except (ValueError, binascii.Error) as e:
msg = f"Cache entry {key!r} has invalid base64: {e}"
raise CacheError(msg) from e
return entries
return _ParsedPayload(hmac_key, entries)
10 changes: 10 additions & 0 deletions psi/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,16 @@ def parse_mapping(raw: str) -> dict:
return data


def mapping_cache_bytes(mapping_data: dict) -> bytes:
    """Return the canonical byte form of a parsed mapping for cache keying.

    Setup (writer) and serve (reader) both hash this canonical encoding, so
    cosmetic differences in the on-disk mapping file — key order, trailing
    whitespace — can never cause a spurious cache miss.
    """
    canonical = json.dumps(mapping_data, sort_keys=True, separators=(",", ":"))
    return canonical.encode("utf-8")


def get_provider(name: str, settings: PsiSettings) -> SecretProvider:
"""Instantiate a provider by name from settings."""
from psi.providers import create_provider
Expand Down
16 changes: 12 additions & 4 deletions psi/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,12 @@

from psi.errors import PsiError
from psi.files import write_bytes_secure
from psi.provider import close_all_providers, open_all_providers, parse_mapping
from psi.provider import (
close_all_providers,
mapping_cache_bytes,
open_all_providers,
parse_mapping,
)
from psi.secret import validate_secret_id
from psi.token import resolve_socket_token

Expand Down Expand Up @@ -225,8 +230,11 @@ def _handle_lookup(self, secret_id: str) -> None:
provider=provider_name,
)

cache_key: str | None = None
if cache is not None:
cached = cache.get(secret_id)
cache.maybe_reload()
cache_key = cache.cache_key(mapping_cache_bytes(mapping_data))
cached = cache.get(cache_key)
if cached is not None:
self._respond(200, cached)
audit.bind(outcome="success", source="cache").debug("lookup")
Expand All @@ -243,8 +251,8 @@ def _handle_lookup(self, secret_id: str) -> None:
self._respond_error(502, "internal_error", str(e))
return

if cache is not None:
cache.set(secret_id, value)
if cache is not None and cache_key is not None:
cache.set(cache_key, value)

self._respond(200, value)
audit.bind(outcome="success", source="provider").info("lookup")
Expand Down
Loading
Loading