Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 39 additions & 7 deletions internal/providers/cache/redis.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,20 @@ import (
// of a nil-pointer panic.
var errNilRedisClient = errors.New("redis admin client is not configured")

// randRead supplies the random bytes used for generated ACL passwords.
//
// It is initialised to crypto/rand.Read and is never reassigned by
// production code. It is a variable rather than a direct call so that an
// in-package test can swap in a failing reader and reach the RNG-error
// branch of provisionLocal. Tests that replace it must restore the
// original value afterwards.
var randRead func(b []byte) (n int, err error) = rand.Read

// storageBytesScanCap bounds how many keys a single StorageBytes call will
// inspect before it stops and reports a truncated total.
//
// It is initialised from storageMaxKeys (see that constant's doc comment for
// why the ceiling exists). It is a variable only so that an in-package test
// can shrink it to a few keys and hit the truncation branch deterministically
// instead of writing the full production number of keys. Production code
// never reassigns it.
var storageBytesScanCap = storageMaxKeys

// aclAllowlist is the safe command allowlist applied to every provisioned ACL
// user on the shared Redis backend. It replaces "+@all" which would grant
// dangerous cross-tenant commands such as FLUSHDB, MONITOR, and CONFIG SET.
Expand Down Expand Up @@ -157,7 +171,7 @@ func (p *Provider) provisionLocal(ctx context.Context, token string) (*Credentia

// Generate a random password for the ACL user.
pwBytes := make([]byte, 16)
if _, err := rand.Read(pwBytes); err != nil {
if _, err := randRead(pwBytes); err != nil {
return nil, fmt.Errorf("cache.provisionLocal: generate password: %w", err)
}
password := hex.EncodeToString(pwBytes)
Expand Down Expand Up @@ -222,6 +236,24 @@ func (p *Provider) StorageBytes(ctx context.Context, token string) (int64, error
if p.rdb == nil {
return 0, fmt.Errorf("cache.StorageBytes: %w", errNilRedisClient)
}
return storageBytes(ctx, p.rdb, token)
}

// redisScanner is the narrow, consumer-side view of *redis.Client used by
// storageBytes.
//
// It exists as a test seam: a deterministic fake can implement these two
// methods to simulate a key vanishing mid-scan or to hit the truncation
// ceiling, without racing a live Redis. *redis.Client already provides both
// methods, so production code passes the real client unchanged.
type redisScanner interface {
	// Scan issues one SCAN page starting at cursor for keys matching match.
	Scan(ctx context.Context, cursor uint64, match string, count int64) *redis.ScanCmd
	// MemoryUsage issues MEMORY USAGE for key.
	MemoryUsage(ctx context.Context, key string, samples ...int) *redis.IntCmd
}

// storageBytes is the seam-friendly implementation behind Provider.StorageBytes.
// It takes the scanner as an interface so a test fake can return a SCAN key that
// the subsequent MEMORY USAGE reports as missing (deterministic skip branch),
// and so storageBytesScanCap can be lowered to fire the truncation branch with a
// handful of keys.
func storageBytes(ctx context.Context, rdb redisScanner, token string) (int64, error) {
prefix := token + ":*"

var (
Expand All @@ -231,20 +263,20 @@ func (p *Provider) StorageBytes(ctx context.Context, token string) (int64, error
)

for {
keys, nextCursor, err := p.rdb.Scan(ctx, cursor, prefix, storageScanBatch).Result()
keys, nextCursor, err := rdb.Scan(ctx, cursor, prefix, storageScanBatch).Result()
if err != nil {
return 0, fmt.Errorf("cache.StorageBytes scan: %w", err)
}

for _, key := range keys {
if totalKeys >= storageMaxKeys {
if totalKeys >= storageBytesScanCap {
break
}
totalKeys++

// MEMORY USAGE returns bytes used by the key including metadata.
// Err is non-nil if the key doesn't exist (just deleted).
mem, err := p.rdb.MemoryUsage(ctx, key).Result()
mem, err := rdb.MemoryUsage(ctx, key).Result()
if err != nil {
// Key was deleted between SCAN and MEMORY USAGE — skip it.
if strings.Contains(err.Error(), "ERR") || err == redis.Nil {
Expand All @@ -256,16 +288,16 @@ func (p *Provider) StorageBytes(ctx context.Context, token string) (int64, error
}

cursor = nextCursor
if cursor == 0 || totalKeys >= storageMaxKeys {
if cursor == 0 || totalKeys >= storageBytesScanCap {
break
}
}

if totalKeys >= storageMaxKeys {
if totalKeys >= storageBytesScanCap {
slog.Warn("cache.StorageBytes.truncated",
"token", token,
"keys_scanned", totalKeys,
"max_keys", storageMaxKeys,
"max_keys", storageBytesScanCap,
"impact", "storage_bytes under-reported — tenant exceeds the per-call key ceiling",
)
}
Expand Down
210 changes: 210 additions & 0 deletions internal/providers/cache/redis_unit_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
package cache

// redis_unit_test.go is an in-package (white-box) companion to redis_test.go.
// It drives the unexported helpers and the error/fallback branches that the
// black-box redis_test.go cannot reach: nil-client handling, the ACL→key-
// namespace fallback, the Upstash stub, StorageBytes edge cases, and the
// legacy username derivation.

import (
"context"
"strings"
"testing"

"github.com/redis/go-redis/v9"
)

// TestNew_Defaults checks that New fills in "local" / "localhost" when the
// backend and host arguments are empty, and keeps explicit values untouched.
func TestNew_Defaults(t *testing.T) {
	// Empty arguments: both fields must fall back to their defaults.
	defaulted := New(nil, "", "")
	if defaulted.backend != "local" {
		t.Fatalf("empty backend must default to local; got %q", defaulted.backend)
	}
	if defaulted.redisHost != "localhost" {
		t.Fatalf("empty host must default to localhost; got %q", defaulted.redisHost)
	}

	// Explicit arguments: both fields must be stored verbatim.
	p2 := New(nil, "upstash", "h:6379")
	if !(p2.backend == "upstash" && p2.redisHost == "h:6379") {
		t.Fatalf("explicit values lost: %+v", p2)
	}
}

// TestACLUsernameDerivation pins the two username schemes: aclUsername keeps
// the whole token, while legacyACLUsername keeps at most the first 8
// characters of it.
func TestACLUsernameDerivation(t *testing.T) {
	// Current scheme (P1-E): "usr_" + the full token.
	full := aclUsername("abcdef0123456789ff")
	if full != "usr_abcdef0123456789ff" {
		t.Fatalf("aclUsername = %q", full)
	}

	// Legacy scheme, long token: only the first 8 characters survive.
	legacyLong := legacyACLUsername("abcdef0123456789")
	if legacyLong != "usr_abcdef01" {
		t.Fatalf("legacyACLUsername(long) = %q", legacyLong)
	}

	// Legacy scheme, token shorter than 8 characters: used as-is.
	legacyShort := legacyACLUsername("abc")
	if legacyShort != "usr_abc" {
		t.Fatalf("legacyACLUsername(short) = %q", legacyShort)
	}
}

// TestProvision_NilClient verifies that Provision on a provider built with a
// nil Redis client returns a "not configured" error instead of panicking.
func TestProvision_NilClient(t *testing.T) {
	provider := New(nil, "local", "localhost")
	_, err := provider.Provision(context.Background(), "tok", "anonymous")
	if err != nil && strings.Contains(err.Error(), "not configured") {
		return
	}
	t.Fatalf("nil client must surface a configured error; got %v", err)
}

// TestStorageBytes_NilClient verifies that StorageBytes on a provider built
// with a nil Redis client returns a "not configured" error.
func TestStorageBytes_NilClient(t *testing.T) {
	provider := New(nil, "local", "localhost")
	_, err := provider.StorageBytes(context.Background(), "tok")
	if err != nil && strings.Contains(err.Error(), "not configured") {
		return
	}
	t.Fatalf("nil client StorageBytes must error; got %v", err)
}

// TestProvision_Upstash_Stub verifies that selecting the "upstash" backend
// makes Provision return a "not yet implemented" error.
func TestProvision_Upstash_Stub(t *testing.T) {
	provider := New(nil, "upstash", "localhost")
	_, err := provider.Provision(context.Background(), "tok", "anonymous")
	if err != nil && strings.Contains(err.Error(), "not yet implemented") {
		return
	}
	t.Fatalf("upstash backend must return not-implemented; got %v", err)
}

// deadClient builds a Redis client aimed at 127.0.0.1:1, an address where no
// server is expected to be listening, with retries disabled (MaxRetries: -1).
// Every command issued through it (ACL SETUSER, SCAN, ...) is therefore
// expected to fail, which the tests use to reach the ACL→key-namespace
// fallback and the SCAN error branch.
func deadClient() *redis.Client {
	opts := &redis.Options{
		Addr:       "127.0.0.1:1",
		MaxRetries: -1,
	}
	return redis.NewClient(opts)
}

// TestProvisionLocal_ACLFallback exercises the path taken when ACL SETUSER
// fails (here: because the client cannot connect). Provision must still
// succeed, returning the shared-host URL together with a per-token key
// prefix for namespace isolation.
func TestProvisionLocal_ACLFallback(t *testing.T) {
	provider := New(deadClient(), "local", "redishost")

	creds, err := provider.Provision(context.Background(), "fallback-token", "anonymous")
	if err != nil {
		t.Fatalf("fallback must not error: %v", err)
	}

	// The key prefix is the token followed by a colon.
	if prefix := creds.KeyPrefix; prefix != "fallback-token:" {
		t.Fatalf("fallback must set KeyPrefix; got %q", prefix)
	}

	// The URL must point at the shared host passed to New.
	if url := creds.URL; !strings.HasPrefix(url, "redis://redishost:6379/0") {
		t.Fatalf("fallback URL must be the shared host; got %q", url)
	}
}

// TestStorageBytes_ScanError checks that a failing SCAN (driven by the
// unreachable deadClient) is returned from StorageBytes as an error whose
// message mentions "scan".
func TestStorageBytes_ScanError(t *testing.T) {
	provider := New(deadClient(), "local", "localhost")
	_, err := provider.StorageBytes(context.Background(), "tok")
	if err != nil && strings.Contains(err.Error(), "scan") {
		return
	}
	t.Fatalf("dead client SCAN must error; got %v", err)
}

// fakeScanner is an in-memory redisScanner for tests. SCAN always yields the
// same fixed key list in a single page, and MEMORY USAGE can be told to
// report chosen keys as gone. That makes the vanished-key skip branch
// reproducible without any timing dependence.
type fakeScanner struct {
	// keys is the full result of the one and only SCAN page (next cursor 0).
	keys []string
	// memBytes maps each present key to the size MEMORY USAGE reports for it.
	memBytes map[string]int64
	// missing marks keys for which MEMORY USAGE reports the key as gone.
	missing map[string]bool
	// scanError, when non-nil, is returned by SCAN instead of any keys.
	scanError error
}

// Scan implements redisScanner. It ignores cursor, match and count: if
// scanError is set the returned command carries that error; otherwise it
// carries every configured key with next-cursor 0, so the caller's scan loop
// ends after one page.
func (f *fakeScanner) Scan(ctx context.Context, cursor uint64, match string, count int64) *redis.ScanCmd {
	result := redis.NewScanCmd(ctx, nil)
	if f.scanError == nil {
		// Single page: all keys, cursor 0 terminates the loop.
		result.SetVal(f.keys, 0)
	} else {
		result.SetErr(f.scanError)
	}
	return result
}

// MemoryUsage implements redisScanner. Keys flagged in missing yield
// redis.Nil, imitating a key deleted between SCAN and MEMORY USAGE; every
// other key yields its memBytes entry (zero when the key has no entry).
func (f *fakeScanner) MemoryUsage(ctx context.Context, key string, samples ...int) *redis.IntCmd {
	result := redis.NewIntCmd(ctx)
	if gone := f.missing[key]; gone {
		result.SetErr(redis.Nil)
	} else {
		result.SetVal(f.memBytes[key])
	}
	return result
}

// TestStorageBytes_MemoryUsageSkip drives the "key vanished between SCAN and
// MEMORY USAGE" branch with a fake: SCAN lists three keys, but MEMORY USAGE
// answers redis.Nil for the middle one. storageBytes must ignore that key and
// return the sum of the other two. The fake removes any need for a live
// Redis or a timing race.
func TestStorageBytes_MemoryUsageSkip(t *testing.T) {
	scanner := &fakeScanner{
		keys:     []string{"tok:a", "tok:gone", "tok:b"},
		memBytes: map[string]int64{"tok:a": 100, "tok:b": 250},
		missing:  map[string]bool{"tok:gone": true},
	}

	total, err := storageBytes(context.Background(), scanner, "tok")
	switch {
	case err != nil:
		t.Fatalf("storageBytes must tolerate a vanished key: %v", err)
	case total != 350:
		t.Fatalf("want 350 (100+250, skipping the gone key); got %d", total)
	}
}

// TestStorageBytes_TruncationCeiling drives the truncation branch: once the
// number of inspected keys reaches storageBytesScanCap, storageBytes stops
// counting. The cap is temporarily lowered to 2 (and restored via t.Cleanup)
// so a three-key fake is enough to trip it.
//
// Only the returned total is asserted; the truncation warning that the same
// branch logs is executed but not inspected here. Because this test mutates
// a package-level variable it must not run in parallel with other tests that
// call storageBytes.
func TestStorageBytes_TruncationCeiling(t *testing.T) {
	previousCap := storageBytesScanCap
	storageBytesScanCap = 2
	t.Cleanup(func() { storageBytesScanCap = previousCap })

	scanner := &fakeScanner{
		keys: []string{"tok:a", "tok:b", "tok:c"},
		memBytes: map[string]int64{
			"tok:a": 10,
			"tok:b": 20,
			"tok:c": 40,
		},
	}

	total, err := storageBytes(context.Background(), scanner, "tok")
	if err != nil {
		t.Fatalf("storageBytes: %v", err)
	}
	// With cap=2 only tok:a and tok:b are summed; tok:c is never measured.
	if total != 30 {
		t.Fatalf("want 30 (cap=2 → only first two keys); got %d", total)
	}
}

// TestStorageBytes_ScanError_Fake checks SCAN error propagation using the
// fake scanner, so the assertion does not depend on a dead-port network
// client.
func TestStorageBytes_ScanError_Fake(t *testing.T) {
	scanner := &fakeScanner{scanError: context.DeadlineExceeded}
	_, err := storageBytes(context.Background(), scanner, "tok")
	if err != nil && strings.Contains(err.Error(), "scan") {
		return
	}
	t.Fatalf("SCAN error must propagate; got %v", err)
}

// TestProvisionLocal_RandReadFailure reaches the password-generation failure
// branch of provisionLocal by replacing the randRead seam with a reader that
// always fails. The original reader is restored via t.Cleanup.
//
// The provider is given deadClient() purely so the nil-client guard passes:
// the client only has to be non-nil, since the RNG failure is expected to
// return before any Redis command is issued. Because this test mutates a
// package-level variable it must not run in parallel with other tests that
// call Provision.
func TestProvisionLocal_RandReadFailure(t *testing.T) {
	realRandRead := randRead
	randRead = func(b []byte) (int, error) { return 0, errBoomRand }
	t.Cleanup(func() { randRead = realRandRead })

	provider := New(deadClient(), "local", "localhost")
	_, err := provider.Provision(context.Background(), "tok", "anonymous")
	if err != nil && strings.Contains(err.Error(), "generate password") {
		return
	}
	t.Fatalf("randRead failure must surface as generate-password error; got %v", err)
}

// errBoom is a string-backed error type used to build fixed test errors.
type errBoom string

// Error implements the error interface by returning the underlying string.
func (e errBoom) Error() string {
	return string(e)
}

// errBoomRand is the error returned by the fault-injecting randRead
// replacement in TestProvisionLocal_RandReadFailure.
var errBoomRand = errBoom("rand exhausted")
35 changes: 30 additions & 5 deletions internal/providers/db/local.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,38 @@ import (
"math/big"

"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
)

const defaultCustomersURL = "postgres://instant_cust:instant_cust@postgres-customers:5432/instant_customers?sslmode=disable"

// alphanumChars is the charset for generated passwords.
const alphanumChars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"

// randInt supplies the random indices used by generatePassword.
//
// It is initialised to crypto/rand.Int and is never reassigned by production
// code. It is a variable rather than a direct call so that an in-package test
// can swap in a failing source and reach the RNG-error branch of
// generatePassword. Tests that replace it must restore the original value
// afterwards.
var randInt = rand.Int

// pgConn is the narrow, consumer-side view of *pgx.Conn used by this backend.
//
// It exists as a test seam: a deterministic fake can implement these three
// methods to force defensive branches that a healthy connection does not
// reach, such as a failing Close (the deferred-close error log) or a failing
// REVOKE / GRANT / DROP USER exec (the non-fatal error logs). *pgx.Conn
// already provides all three methods, so production code uses the real
// connection unchanged.
type pgConn interface {
	// Exec runs a statement and returns its command tag.
	Exec(ctx context.Context, sql string, args ...any) (pgconn.CommandTag, error)
	// QueryRow runs a query expected to return at most one row.
	QueryRow(ctx context.Context, sql string, args ...any) pgx.Row
	// Close closes the connection.
	Close(ctx context.Context) error
}

// pgxConnect is the connection factory used by the backend. By default it
// delegates to pgx.Connect; it is a package variable only so that a test can
// inject a fake pgConn. Tests that replace it must restore the original
// value afterwards.
//
// On failure it returns a literal nil pgConn. Returning pgx.Connect's
// (*pgx.Conn, error) pair directly would wrap a nil *pgx.Conn in the pgConn
// interface, producing a non-nil interface value that holds a nil pointer;
// any caller comparing the connection against nil would then be misled.
var pgxConnect = func(ctx context.Context, connString string) (pgConn, error) {
	conn, err := pgx.Connect(ctx, connString)
	if err != nil {
		return nil, err
	}
	return conn, nil
}

// LocalBackend provisions databases on the shared postgres-customers instance.
type LocalBackend struct {
customersURL string // admin connection URL
Expand All @@ -37,7 +62,7 @@ func generatePassword(n int) (string, error) {
buf := make([]byte, n)
charsetLen := big.NewInt(int64(len(alphanumChars)))
for i := range buf {
idx, err := rand.Int(rand.Reader, charsetLen)
idx, err := randInt(rand.Reader, charsetLen)
if err != nil {
return "", fmt.Errorf("generatePassword: %w", err)
}
Expand Down Expand Up @@ -72,7 +97,7 @@ func (b *LocalBackend) ProvisionWithExtensions(ctx context.Context, token, tier
}

// Connect as admin.
conn, err := pgx.Connect(ctx, b.customersURL)
conn, err := pgxConnect(ctx, b.customersURL)
if err != nil {
return nil, fmt.Errorf("db.local.Provision: connect: %w", err)
}
Expand Down Expand Up @@ -110,7 +135,7 @@ func (b *LocalBackend) ProvisionWithExtensions(ctx context.Context, token, tier
// run as a superuser/admin, not the per-token user (which lacks
// CREATE-on-pg_catalog privileges).
newDBURL := b.buildDBURL(username, pass, dbName)
adminNewDB, err := pgx.Connect(ctx, b.buildAdminNewDBURL(dbName))
adminNewDB, err := pgxConnect(ctx, b.buildAdminNewDBURL(dbName))
if err != nil {
slog.Error("db.local.Provision: connect new db for schema grant (non-fatal)", "error", err)
// If extensions were requested and we couldn't connect to the new
Expand Down Expand Up @@ -158,7 +183,7 @@ func (b *LocalBackend) ProvisionWithExtensions(ctx context.Context, token, tier

// StorageBytes returns the size of db_{token} in bytes using pg_database_size.
func (b *LocalBackend) StorageBytes(ctx context.Context, token, providerResourceID string) (int64, error) {
conn, err := pgx.Connect(ctx, b.customersURL)
conn, err := pgxConnect(ctx, b.customersURL)
if err != nil {
return 0, fmt.Errorf("db.local.StorageBytes: connect: %w", err)
}
Expand All @@ -181,7 +206,7 @@ func (b *LocalBackend) Deprovision(ctx context.Context, token, providerResourceI
dbName := "db_" + token
username := "usr_" + token

conn, err := pgx.Connect(ctx, b.customersURL)
conn, err := pgxConnect(ctx, b.customersURL)
if err != nil {
return fmt.Errorf("db.local.Deprovision: connect: %w", err)
}
Expand Down
Loading
Loading