diff --git a/internal/providers/cache/redis.go b/internal/providers/cache/redis.go index 99dc421..39d7588 100644 --- a/internal/providers/cache/redis.go +++ b/internal/providers/cache/redis.go @@ -23,6 +23,20 @@ import ( // of a nil-pointer panic. var errNilRedisClient = errors.New("redis admin client is not configured") +// randRead is the entropy source for ACL passwords. It defaults to +// crypto/rand.Read and is a package var only so a test can substitute a +// fault-injecting reader to exercise the (otherwise unreachable) RNG-failure +// branch in provisionLocal. Production behaviour is identical to calling +// crypto/rand.Read directly. +var randRead = rand.Read + +// storageBytesScanCap is the hard ceiling on keys inspected per StorageBytes +// call. It defaults to storageMaxKeys (200k). It is a package var only so a +// test can lower it to a handful and deterministically exercise the +// truncation branch without writing 200k keys. Production uses the 200k +// default — see the storageMaxKeys doc comment for the rationale. +var storageBytesScanCap = storageMaxKeys + // aclAllowlist is the safe command allowlist applied to every provisioned ACL // user on the shared Redis backend. It replaces "+@all" which would grant // dangerous cross-tenant commands such as FLUSHDB, MONITOR, and CONFIG SET. @@ -157,7 +171,7 @@ func (p *Provider) provisionLocal(ctx context.Context, token string) (*Credentia // Generate a random password for the ACL user. pwBytes := make([]byte, 16) - if _, err := rand.Read(pwBytes); err != nil { + if _, err := randRead(pwBytes); err != nil { return nil, fmt.Errorf("cache.provisionLocal: generate password: %w", err) } password := hex.EncodeToString(pwBytes) @@ -222,6 +236,24 @@ func (p *Provider) StorageBytes(ctx context.Context, token string) (int64, error if p.rdb == nil { return 0, fmt.Errorf("cache.StorageBytes: %w", errNilRedisClient) } + return storageBytes(ctx, p.rdb, token) +} + +// redisScanner is the minimal slice of *redis.Client that storageBytes needs. +// It exists only so a deterministic fake can drive the mid-scan vanished-key +// skip and the truncation ceiling without a live Redis race. *redis.Client +// satisfies it directly, so production wiring is unchanged. +type redisScanner interface { + Scan(ctx context.Context, cursor uint64, match string, count int64) *redis.ScanCmd + MemoryUsage(ctx context.Context, key string, samples ...int) *redis.IntCmd +} + +// storageBytes is the seam-friendly implementation behind Provider.StorageBytes. +// It takes the scanner as an interface so a test fake can return a SCAN key that +// the subsequent MEMORY USAGE reports as missing (deterministic skip branch), +// and so storageBytesScanCap can be lowered to fire the truncation branch with a +// handful of keys. +func storageBytes(ctx context.Context, rdb redisScanner, token string) (int64, error) { prefix := token + ":*" var ( @@ -231,20 +263,20 @@ func (p *Provider) StorageBytes(ctx context.Context, token string) (int64, error ) for { - keys, nextCursor, err := p.rdb.Scan(ctx, cursor, prefix, storageScanBatch).Result() + keys, nextCursor, err := rdb.Scan(ctx, cursor, prefix, storageScanBatch).Result() if err != nil { return 0, fmt.Errorf("cache.StorageBytes scan: %w", err) } for _, key := range keys { - if totalKeys >= storageMaxKeys { + if totalKeys >= storageBytesScanCap { break } totalKeys++ // MEMORY USAGE returns bytes used by the key including metadata. // Err is non-nil if the key doesn't exist (just deleted). - mem, err := p.rdb.MemoryUsage(ctx, key).Result() + mem, err := rdb.MemoryUsage(ctx, key).Result() if err != nil { // Key was deleted between SCAN and MEMORY USAGE — skip it. if strings.Contains(err.Error(), "ERR") || err == redis.Nil { @@ -256,16 +288,16 @@ func (p *Provider) StorageBytes(ctx context.Context, token string) (int64, error } cursor = nextCursor - if cursor == 0 || totalKeys >= storageMaxKeys { + if cursor == 0 || totalKeys >= storageBytesScanCap { break } } - if totalKeys >= storageMaxKeys { + if totalKeys >= storageBytesScanCap { slog.Warn("cache.StorageBytes.truncated", "token", token, "keys_scanned", totalKeys, - "max_keys", storageMaxKeys, + "max_keys", storageBytesScanCap, "impact", "storage_bytes under-reported — tenant exceeds the per-call key ceiling", ) } diff --git a/internal/providers/cache/redis_unit_test.go b/internal/providers/cache/redis_unit_test.go new file mode 100644 index 0000000..702600e --- /dev/null +++ b/internal/providers/cache/redis_unit_test.go @@ -0,0 +1,210 @@ +package cache + +// redis_unit_test.go is an in-package (white-box) companion to redis_test.go. +// It drives the unexported helpers and the error/fallback branches that the +// black-box redis_test.go cannot reach: nil-client handling, the ACL→key- +// namespace fallback, the Upstash stub, StorageBytes edge cases, and the +// legacy username derivation. + +import ( + "context" + "strings" + "testing" + + "github.com/redis/go-redis/v9" +) + +func TestNew_Defaults(t *testing.T) { + p := New(nil, "", "") + if p.backend != "local" { + t.Fatalf("empty backend must default to local; got %q", p.backend) + } + if p.redisHost != "localhost" { + t.Fatalf("empty host must default to localhost; got %q", p.redisHost) + } + p2 := New(nil, "upstash", "h:6379") + if p2.backend != "upstash" || p2.redisHost != "h:6379" { + t.Fatalf("explicit values lost: %+v", p2) + } +} + +func TestACLUsernameDerivation(t *testing.T) { + // Full-token username (P1-E). + if got := aclUsername("abcdef0123456789ff"); got != "usr_abcdef0123456789ff" { + t.Fatalf("aclUsername = %q", got) + } + // Legacy username truncates to 8 chars. + if got := legacyACLUsername("abcdef0123456789"); got != "usr_abcdef01" { + t.Fatalf("legacyACLUsername(long) = %q", got) + } + // Short token: no truncation. + if got := legacyACLUsername("abc"); got != "usr_abc" { + t.Fatalf("legacyACLUsername(short) = %q", got) + } +} + +func TestProvision_NilClient(t *testing.T) { + p := New(nil, "local", "localhost") + _, err := p.Provision(context.Background(), "tok", "anonymous") + if err == nil || !strings.Contains(err.Error(), "not configured") { + t.Fatalf("nil client must surface a configured error; got %v", err) + } +} + +func TestStorageBytes_NilClient(t *testing.T) { + p := New(nil, "local", "localhost") + _, err := p.StorageBytes(context.Background(), "tok") + if err == nil || !strings.Contains(err.Error(), "not configured") { + t.Fatalf("nil client StorageBytes must error; got %v", err) + } +} + +func TestProvision_Upstash_Stub(t *testing.T) { + p := New(nil, "upstash", "localhost") + _, err := p.Provision(context.Background(), "tok", "anonymous") + if err == nil || !strings.Contains(err.Error(), "not yet implemented") { + t.Fatalf("upstash backend must return not-implemented; got %v", err) + } +} + +// deadClient returns a redis client pointed at a closed port so every command +// (including ACL SETUSER and SCAN) errors. Used to drive the ACL→key-namespace +// fallback and the SCAN error branch. +func deadClient() *redis.Client { + return redis.NewClient(&redis.Options{Addr: "127.0.0.1:1", MaxRetries: -1}) +} + +// TestProvisionLocal_ACLFallback covers the branch where ACL SETUSER fails and +// the provider falls back to shared-URL + key-namespace isolation. +func TestProvisionLocal_ACLFallback(t *testing.T) { + p := New(deadClient(), "local", "redishost") + creds, err := p.Provision(context.Background(), "fallback-token", "anonymous") + if err != nil { + t.Fatalf("fallback must not error: %v", err) + } + if creds.KeyPrefix != "fallback-token:" { + t.Fatalf("fallback must set KeyPrefix; got %q", creds.KeyPrefix) + } + if !strings.HasPrefix(creds.URL, "redis://redishost:6379/0") { + t.Fatalf("fallback URL must be the shared host; got %q", creds.URL) + } +} + +// TestStorageBytes_ScanError covers the SCAN error return. +func TestStorageBytes_ScanError(t *testing.T) { + p := New(deadClient(), "local", "localhost") + _, err := p.StorageBytes(context.Background(), "tok") + if err == nil || !strings.Contains(err.Error(), "scan") { + t.Fatalf("dead client SCAN must error; got %v", err) + } +} + +// fakeScanner is a deterministic redisScanner: it returns a fixed set of keys +// from SCAN, and reports a configurable subset as missing on MEMORY USAGE. This +// drives the mid-scan vanished-key skip branch with zero timing dependence. +type fakeScanner struct { + keys []string // returned by a single SCAN page (cursor → 0) + memBytes map[string]int64 // present keys → byte size + missing map[string]bool // keys that MEMORY USAGE reports gone + scanError error // when set, SCAN returns this error +} + +func (f *fakeScanner) Scan(ctx context.Context, cursor uint64, match string, count int64) *redis.ScanCmd { + cmd := redis.NewScanCmd(ctx, nil) + if f.scanError != nil { + cmd.SetErr(f.scanError) + return cmd + } + // Single page: return all keys and cursor 0 so the loop terminates. + cmd.SetVal(f.keys, 0) + return cmd +} + +func (f *fakeScanner) MemoryUsage(ctx context.Context, key string, samples ...int) *redis.IntCmd { + cmd := redis.NewIntCmd(ctx) + if f.missing[key] { + // Simulate the key vanishing between SCAN and MEMORY USAGE. + cmd.SetErr(redis.Nil) + return cmd + } + cmd.SetVal(f.memBytes[key]) + return cmd +} + +// TestStorageBytes_MemoryUsageSkip covers the MEMORY USAGE error/skip branch +// deterministically: a fake scanner returns three keys via SCAN but reports the +// middle one as gone (redis.Nil) on MEMORY USAGE. The loop must skip it and sum +// only the two present keys. No timing, no live Redis, no flake. +func TestStorageBytes_MemoryUsageSkip(t *testing.T) { + f := &fakeScanner{ + keys: []string{"tok:a", "tok:gone", "tok:b"}, + memBytes: map[string]int64{ + "tok:a": 100, + "tok:b": 250, + }, + missing: map[string]bool{"tok:gone": true}, + } + got, err := storageBytes(context.Background(), f, "tok") + if err != nil { + t.Fatalf("storageBytes must tolerate a vanished key: %v", err) + } + if got != 350 { + t.Fatalf("want 350 (100+250, skipping the gone key); got %d", got) + } +} + +// TestStorageBytes_TruncationCeiling covers the truncation branch: when the key +// count reaches storageBytesScanCap, scanning stops and the under-report warning +// fires. We lower the cap to 2 (restored on cleanup) so a 3-key fake trips it +// with no 200k write. +func TestStorageBytes_TruncationCeiling(t *testing.T) { + orig := storageBytesScanCap + storageBytesScanCap = 2 + t.Cleanup(func() { storageBytesScanCap = orig }) + + f := &fakeScanner{ + keys: []string{"tok:a", "tok:b", "tok:c"}, + memBytes: map[string]int64{"tok:a": 10, "tok:b": 20, "tok:c": 40}, + } + got, err := storageBytes(context.Background(), f, "tok") + if err != nil { + t.Fatalf("storageBytes: %v", err) + } + // Only the first 2 keys are counted before the cap halts the scan. + if got != 30 { + t.Fatalf("want 30 (cap=2 → only first two keys); got %d", got) + } +} + +// TestStorageBytes_ScanError_Fake covers the SCAN error return via the fake, +// independent of a live dead-port client. +func TestStorageBytes_ScanError_Fake(t *testing.T) { + f := &fakeScanner{scanError: context.DeadlineExceeded} + _, err := storageBytes(context.Background(), f, "tok") + if err == nil || !strings.Contains(err.Error(), "scan") { + t.Fatalf("SCAN error must propagate; got %v", err) + } +} + +// TestProvisionLocal_RandReadFailure covers the crypto/rand failure branch in +// provisionLocal via the randRead seam. We can use a nil-but-non-nil client +// because the RNG failure returns before any Redis command is issued. +func TestProvisionLocal_RandReadFailure(t *testing.T) { + orig := randRead + randRead = func(b []byte) (int, error) { return 0, errBoomRand } + t.Cleanup(func() { randRead = orig }) + + // A live (or dead) non-nil client passes the nil-check; the RNG failure + // short-circuits before the client is touched. + p := New(deadClient(), "local", "localhost") + _, err := p.Provision(context.Background(), "tok", "anonymous") + if err == nil || !strings.Contains(err.Error(), "generate password") { + t.Fatalf("randRead failure must surface as generate-password error; got %v", err) + } +} + +var errBoomRand = errBoom("rand exhausted") + +type errBoom string + +func (e errBoom) Error() string { return string(e) } diff --git a/internal/providers/db/local.go b/internal/providers/db/local.go index b89ea3e..57de37a 100644 --- a/internal/providers/db/local.go +++ b/internal/providers/db/local.go @@ -12,6 +12,7 @@ import ( "math/big" "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" ) const defaultCustomersURL = "postgres://instant_cust:instant_cust@postgres-customers:5432/instant_customers?sslmode=disable" @@ -19,6 +20,30 @@ const defaultCustomersURL = "postgres://instant_cust:instant_cust@postgres-custo // alphanumChars is the charset for generated passwords. const alphanumChars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" +// randInt is the entropy source for generatePassword. It defaults to +// crypto/rand.Int and is a package var only so a test can substitute a +// fault-injecting source to exercise the (otherwise unreachable) RNG-failure +// branch. Production behaviour is identical to calling crypto/rand.Int. +var randInt = rand.Int + +// pgConn is the minimal slice of *pgx.Conn that this backend exercises. Routing +// through an interface lets a deterministic fake force the otherwise-unreachable +// defensive branches: a Close that errors (defer-error log), and a REVOKE / GRANT +// / DROP USER exec that errors (non-fatal log). *pgx.Conn satisfies it directly, +// so production wiring is unchanged. +type pgConn interface { + Exec(ctx context.Context, sql string, args ...any) (pgconn.CommandTag, error) + QueryRow(ctx context.Context, sql string, args ...any) pgx.Row + Close(ctx context.Context) error +} + +// pgxConnect is the connection factory the backend uses. It defaults to a thin +// wrapper over pgx.Connect and is a package var only so a test can inject a fake +// pgConn. Production behaviour is identical to calling pgx.Connect. +var pgxConnect = func(ctx context.Context, connString string) (pgConn, error) { + return pgx.Connect(ctx, connString) +} + // LocalBackend provisions databases on the shared postgres-customers instance. type LocalBackend struct { customersURL string // admin connection URL @@ -37,7 +62,7 @@ func generatePassword(n int) (string, error) { buf := make([]byte, n) charsetLen := big.NewInt(int64(len(alphanumChars))) for i := range buf { - idx, err := rand.Int(rand.Reader, charsetLen) + idx, err := randInt(rand.Reader, charsetLen) if err != nil { return "", fmt.Errorf("generatePassword: %w", err) } @@ -72,7 +97,7 @@ func (b *LocalBackend) ProvisionWithExtensions(ctx context.Context, token, tier } // Connect as admin. - conn, err := pgx.Connect(ctx, b.customersURL) + conn, err := pgxConnect(ctx, b.customersURL) if err != nil { return nil, fmt.Errorf("db.local.Provision: connect: %w", err) } @@ -110,7 +135,7 @@ func (b *LocalBackend) ProvisionWithExtensions(ctx context.Context, token, tier // run as a superuser/admin, not the per-token user (which lacks // CREATE-on-pg_catalog privileges). newDBURL := b.buildDBURL(username, pass, dbName) - adminNewDB, err := pgx.Connect(ctx, b.buildAdminNewDBURL(dbName)) + adminNewDB, err := pgxConnect(ctx, b.buildAdminNewDBURL(dbName)) if err != nil { slog.Error("db.local.Provision: connect new db for schema grant (non-fatal)", "error", err) // If extensions were requested and we couldn't connect to the new @@ -158,7 +183,7 @@ func (b *LocalBackend) ProvisionWithExtensions(ctx context.Context, token, tier // StorageBytes returns the size of db_{token} in bytes using pg_database_size. func (b *LocalBackend) StorageBytes(ctx context.Context, token, providerResourceID string) (int64, error) { - conn, err := pgx.Connect(ctx, b.customersURL) + conn, err := pgxConnect(ctx, b.customersURL) if err != nil { return 0, fmt.Errorf("db.local.StorageBytes: connect: %w", err) } @@ -181,7 +206,7 @@ func (b *LocalBackend) Deprovision(ctx context.Context, token, providerResourceI dbName := "db_" + token username := "usr_" + token - conn, err := pgx.Connect(ctx, b.customersURL) + conn, err := pgxConnect(ctx, b.customersURL) if err != nil { return fmt.Errorf("db.local.Deprovision: connect: %w", err) } diff --git a/internal/providers/db/local_seam_test.go b/internal/providers/db/local_seam_test.go new file mode 100644 index 0000000..b04d53a --- /dev/null +++ b/internal/providers/db/local_seam_test.go @@ -0,0 +1,253 @@ +package db + +// local_seam_test.go drives the LocalBackend defensive branches that cannot be +// triggered deterministically against a real superuser Postgres connection: +// the crypto/rand failure, the conn.Close(ctx) defer-error logs, and the +// non-fatal REVOKE / GRANT / DROP USER exec-error logs. It does so via the +// package seams (randInt + pgxConnect) using an in-memory fake pgConn, so the +// tests need no live database and never flake. + +import ( + "context" + "errors" + "io" + "math/big" + "strings" + "testing" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" +) + +// fakePgConn is a deterministic pgConn. execErr decides, per-statement, whether +// Exec returns an error (matched by a case-insensitive substring of the SQL). +// closeErr is returned by Close. queryRowErr is surfaced by the returned Row's +// Scan. +type fakePgConn struct { + execErr map[string]error // SQL substring → error to return + closeErr error + queryRowErr error + queryRowVal int64 + closed int +} + +func (f *fakePgConn) Exec(ctx context.Context, sql string, args ...any) (pgconn.CommandTag, error) { + for sub, err := range f.execErr { + if strings.Contains(sql, sub) { + return pgconn.CommandTag{}, err + } + } + return pgconn.CommandTag{}, nil +} + +func (f *fakePgConn) QueryRow(ctx context.Context, sql string, args ...any) pgx.Row { + return &fakeRow{err: f.queryRowErr, val: f.queryRowVal} +} + +func (f *fakePgConn) Close(ctx context.Context) error { + f.closed++ + return f.closeErr +} + +type fakeRow struct { + err error + val int64 +} + +func (r *fakeRow) Scan(dest ...any) error { + if r.err != nil { + return r.err + } + if len(dest) == 1 { + if p, ok := dest[0].(*int64); ok { + *p = r.val + } + } + return nil +} + +// withFakeConn installs a pgxConnect seam returning the given fakes in sequence +// (first call → conns[0], etc.) and restores the real factory on cleanup. A +// connErrs entry, when non-nil, makes that connect attempt fail instead. +func withFakeConn(t *testing.T, conns []*fakePgConn, connErrs []error) { + t.Helper() + orig := pgxConnect + var i int + pgxConnect = func(ctx context.Context, connString string) (pgConn, error) { + idx := i + i++ + if idx < len(connErrs) && connErrs[idx] != nil { + return nil, connErrs[idx] + } + if idx < len(conns) { + return conns[idx], nil + } + return &fakePgConn{}, nil + } + t.Cleanup(func() { pgxConnect = orig }) +} + +// TestGeneratePassword_RandFailure covers the crypto/rand error branch in +// generatePassword via the randInt seam. +func TestGeneratePassword_RandFailure(t *testing.T) { + orig := randInt + randInt = func(_ io.Reader, _ *big.Int) (*big.Int, error) { + return nil, errors.New("entropy depleted") + } + t.Cleanup(func() { randInt = orig }) + + _, err := generatePassword(16) + if err == nil || !strings.Contains(err.Error(), "generatePassword") { + t.Fatalf("randInt failure must surface; got %v", err) + } +} + +// TestProvision_RandFailure asserts the password failure propagates out of +// Provision before any DB connection is opened. +func TestProvision_RandFailure(t *testing.T) { + orig := randInt + randInt = func(_ io.Reader, _ *big.Int) (*big.Int, error) { + return nil, errors.New("entropy depleted") + } + t.Cleanup(func() { randInt = orig }) + + b := newLocalBackend("postgres://x:y@h:5432/d") + _, err := b.Provision(context.Background(), "tok", "anonymous") + if err == nil || !strings.Contains(err.Error(), "db.local.Provision") { + t.Fatalf("Provision must fail on RNG error; got %v", err) + } +} + +// TestProvision_NonFatalBranches covers, in one provision, the REVOKE CONNECT +// non-fatal log, the GRANT SCHEMA non-fatal log, and BOTH conn.Close defer-error +// logs (admin conn + new-db conn). The provision still succeeds — these branches +// are best-effort by design. +func TestProvision_NonFatalBranches(t *testing.T) { + adminConn := &fakePgConn{ + execErr: map[string]error{"REVOKE CONNECT": errors.New("revoke denied")}, + closeErr: errors.New("admin close failed"), + } + newDBConn := &fakePgConn{ + execErr: map[string]error{"GRANT ALL ON SCHEMA": errors.New("schema grant denied")}, + closeErr: errors.New("newdb close failed"), + } + withFakeConn(t, []*fakePgConn{adminConn, newDBConn}, nil) + + b := newLocalBackend("postgres://admin:pw@host:5432/instant_customers") + creds, err := b.Provision(context.Background(), "nonfatal", "anonymous") + if err != nil { + t.Fatalf("non-fatal branches must not fail provision: %v", err) + } + if creds.DatabaseName != "db_nonfatal" { + t.Fatalf("DatabaseName = %q", creds.DatabaseName) + } + if adminConn.closed == 0 || newDBConn.closed == 0 { + t.Fatal("both connections must be Closed (defer-error branches)") + } +} + +// TestProvision_NewDBConnectFails_WithExtensions covers the branch where the +// connect-to-new-db step fails AND extensions were requested, so the provision +// errors loudly instead of returning a non-vector DB. +func TestProvision_NewDBConnectFails_WithExtensions(t *testing.T) { + adminConn := &fakePgConn{} + // First connect (admin) succeeds; second connect (new DB) fails. + withFakeConn(t, []*fakePgConn{adminConn}, []error{nil, errors.New("new db unreachable")}) + + b := newLocalBackend("postgres://admin:pw@host:5432/instant_customers") + _, err := b.Provision(context.Background(), "extfail", "pro") + // No extensions requested → non-fatal, provision succeeds. + if err != nil { + t.Fatalf("no-extension new-db connect failure must be non-fatal: %v", err) + } +} + +// TestProvisionWithExtensions_NewDBConnectFails covers the loud-failure arm: +// extensions requested but the new-DB connect failed. +func TestProvisionWithExtensions_NewDBConnectFails(t *testing.T) { + adminConn := &fakePgConn{} + withFakeConn(t, []*fakePgConn{adminConn}, []error{nil, errors.New("new db unreachable")}) + + b := newLocalBackend("postgres://admin:pw@host:5432/instant_customers") + _, err := b.ProvisionWithExtensions(context.Background(), "extloud", "pro", []string{"vector"}) + if err == nil || !strings.Contains(err.Error(), "install extensions") { + t.Fatalf("requested extensions + new-db connect fail must error loudly; got %v", err) + } +} + +// TestProvision_CreateUserFails covers the CREATE USER error return via the fake. +func TestProvision_CreateUserFails(t *testing.T) { + adminConn := &fakePgConn{execErr: map[string]error{"CREATE USER": errors.New("user exists")}} + withFakeConn(t, []*fakePgConn{adminConn}, nil) + + b := newLocalBackend("postgres://admin:pw@host:5432/d") + _, err := b.Provision(context.Background(), "dupuser", "anonymous") + if err == nil || !strings.Contains(err.Error(), "CREATE USER") { + t.Fatalf("want CREATE USER error; got %v", err) + } +} + +// TestProvision_GrantDatabaseFails covers the fatal GRANT DATABASE error return. +func TestProvision_GrantDatabaseFails(t *testing.T) { + adminConn := &fakePgConn{execErr: map[string]error{"GRANT ALL PRIVILEGES ON DATABASE": errors.New("denied")}} + withFakeConn(t, []*fakePgConn{adminConn}, nil) + + b := newLocalBackend("postgres://admin:pw@host:5432/d") + _, err := b.Provision(context.Background(), "grantfail", "anonymous") + if err == nil || !strings.Contains(err.Error(), "GRANT DATABASE") { + t.Fatalf("want GRANT DATABASE error; got %v", err) + } +} + +// TestStorageBytes_CloseError covers the StorageBytes disconnect defer-error log +// while the query itself succeeds. +func TestStorageBytes_CloseError(t *testing.T) { + conn := &fakePgConn{closeErr: errors.New("close failed"), queryRowVal: 4096} + withFakeConn(t, []*fakePgConn{conn}, nil) + + b := newLocalBackend("postgres://admin:pw@host:5432/d") + size, err := b.StorageBytes(context.Background(), "tok", "") + if err != nil { + t.Fatalf("StorageBytes: %v", err) + } + if size != 4096 { + t.Fatalf("size = %d, want 4096", size) + } + if conn.closed == 0 { + t.Fatal("Close must have run (defer-error branch)") + } +} + +// TestDeprovision_NonFatalBranches covers the terminate-connections failure log, +// the DROP USER non-fatal log, and the Close defer-error log — all in one +// Deprovision that still succeeds. +func TestDeprovision_NonFatalBranches(t *testing.T) { + conn := &fakePgConn{ + execErr: map[string]error{ + "pg_terminate_backend": errors.New("terminate denied"), + "DROP USER": errors.New("drop user denied"), + }, + closeErr: errors.New("close failed"), + } + withFakeConn(t, []*fakePgConn{conn}, nil) + + b := newLocalBackend("postgres://admin:pw@host:5432/d") + if err := b.Deprovision(context.Background(), "tok", ""); err != nil { + t.Fatalf("Deprovision non-fatal branches must succeed: %v", err) + } + if conn.closed == 0 { + t.Fatal("Close must have run") + } +} + +// TestDeprovision_DropDatabaseFails covers the fatal DROP DATABASE error return. +func TestDeprovision_DropDatabaseFails(t *testing.T) { + conn := &fakePgConn{execErr: map[string]error{"DROP DATABASE": errors.New("not owner")}} + withFakeConn(t, []*fakePgConn{conn}, nil) + + b := newLocalBackend("postgres://admin:pw@host:5432/d") + if err := b.Deprovision(context.Background(), "tok", ""); err == nil || + !strings.Contains(err.Error(), "DROP DATABASE") { + t.Fatalf("want DROP DATABASE error; got %v", err) + } +} diff --git a/internal/providers/db/local_test.go b/internal/providers/db/local_test.go new file mode 100644 index 0000000..d68e2af --- /dev/null +++ b/internal/providers/db/local_test.go @@ -0,0 +1,475 @@ +package db + +// local_test.go drives the LocalBackend (CREATE DATABASE / USER on the shared +// Postgres pod) against a real Postgres instance. Set TEST_POSTGRES_CUSTOMERS_URL to a +// superuser DSN (e.g. postgres://instant_cust:instant_cust@127.0.0.1:55432/ +// instant_customers?sslmode=disable). The DSN MUST belong to a role with +// CREATEDB + CREATEROLE so the provisioning DDL succeeds; the docker container +// the coverage harness starts uses the POSTGRES_USER bootstrap superuser. + +import ( + "context" + "crypto/rand" + "os" + "strings" + "testing" + "time" + + "github.com/jackc/pgx/v5" +) + +// testCustomersURL returns the admin DSN for the local backend, skipping the +// test if TEST_POSTGRES_CUSTOMERS_URL is unset. +func testCustomersURL(t *testing.T) string { + t.Helper() + dsn := os.Getenv("TEST_POSTGRES_CUSTOMERS_URL") + if dsn == "" { + t.Skip("TEST_POSTGRES_CUSTOMERS_URL not set — skipping local Postgres provisioning tests") + } + return dsn +} + +// cleanupDB drops the database and user a test provisioned, ignoring errors so +// a half-failed provision doesn't leave the test red on teardown. +func cleanupDB(t *testing.T, dsn, token string) { + t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + conn, err := pgx.Connect(ctx, dsn) + if err != nil { + t.Logf("cleanupDB: connect: %v", err) + return + } + defer conn.Close(ctx) + _, _ = conn.Exec(ctx, + "SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE datname='db_"+token+"' AND pid<>pg_backend_pid()") + _, _ = conn.Exec(ctx, `DROP DATABASE IF EXISTS "db_`+token+`"`) + _, _ = conn.Exec(ctx, `DROP USER IF EXISTS "usr_`+token+`"`) +} + +// uniqueToken returns a Postgres-identifier-safe token unique per test AND per +// process run. The per-run seed (random hex chosen once at init) prevents +// collisions between two same-second test runs against a shared Postgres — +// without it, time.Now().Format("150405") repeats every day at the same second +// and could collide with an orphan db_ left by an earlier run. +func uniqueToken(prefix string) string { + return prefix + runSeed + randHex() +} + +// runSeed is a per-process lowercase-hex suffix making every token in this run +// disjoint from any other run's tokens. +var runSeed = func() string { + b := make([]byte, 4) + _, _ = rand.Read(b) + const hexd = "0123456789abcdef" + out := make([]byte, 8) + for i, x := range b { + out[2*i] = hexd[x>>4] + out[2*i+1] = hexd[x&0x0f] + } + return string(out) +}() + +var hexCounter int + +func randHex() string { + hexCounter++ + return string(rune('a'+(hexCounter%26))) + string(rune('a'+((hexCounter/26)%26))) +} + +func TestNewLocalBackend_DefaultURL(t *testing.T) { + b := newLocalBackend("") + if b.customersURL != defaultCustomersURL { + t.Fatalf("empty URL must fall back to default; got %q", b.customersURL) + } + b2 := newLocalBackend("postgres://x:y@h:5432/d") + if b2.customersURL != "postgres://x:y@h:5432/d" { + t.Fatalf("explicit URL must be preserved; got %q", b2.customersURL) + } +} + +func TestGeneratePassword(t *testing.T) { + p, err := generatePassword(16) + if err != nil { + t.Fatalf("generatePassword: %v", err) + } + if len(p) != 16 { + t.Fatalf("want length 16, got %d", len(p)) + } + // Every char must come from the alphanum charset. + for _, c := range p { + if !strings.ContainsRune(alphanumChars, c) { + t.Fatalf("password char %q not in charset", c) + } + } + // Zero length is valid and returns "". + z, err := generatePassword(0) + if err != nil || z != "" { + t.Fatalf("generatePassword(0) = (%q,%v); want (\"\",nil)", z, err) + } +} + +func TestExtractHostAndIndexOf(t *testing.T) { + cases := []struct{ in, want string }{ + {"postgres://u:p@host:5432/db", "host:5432"}, + {"postgres://u:p@host/db", "host"}, + {"postgres://host:5432/db", "host:5432"}, // no auth + {"host:5432", "host:5432"}, // no prefix, no slash, no @ + {"postgres://u:p@host:5432", "host:5432"}, // no trailing slash + } + for _, c := range cases { + if got := extractHost(c.in); got != c.want { + t.Errorf("extractHost(%q) = %q; want %q", c.in, got, c.want) + } + } + if indexOf("abc", 'b') != 1 { + t.Fatal("indexOf hit wrong") + } + if indexOf("abc", 'z') != -1 { + t.Fatal("indexOf miss should be -1") + } +} + +func TestBuildURLs(t *testing.T) { + b := newLocalBackend("postgres://admin:pw@pghost:5432/instant_customers?sslmode=disable") + got := b.buildDBURL("usr_x", "secret", "db_x") + want := "postgres://usr_x:secret@pghost:5432/db_x" + if got != want { + t.Fatalf("buildDBURL = %q; want %q", got, want) + } + // buildAdminNewDBURL replaces the trailing path component. + admin := b.buildAdminNewDBURL("db_x") + if !strings.HasSuffix(admin, "/db_x") { + t.Fatalf("buildAdminNewDBURL must end with /db_x; got %q", admin) + } + // No-slash URL falls back to appending /db_x. + b2 := newLocalBackend("postgresnohost") + if got := b2.buildAdminNewDBURL("db_y"); got != "postgresnohost/db_y" { + t.Fatalf("no-slash fallback = %q", got) + } +} + +// TestLocalBackend_ProvisionDeprovision_HappyPath exercises the full +// CREATE DATABASE / CREATE USER / GRANT / DROP lifecycle against real Postgres. +func TestLocalBackend_ProvisionDeprovision_HappyPath(t *testing.T) { + dsn := testCustomersURL(t) + token := uniqueToken("happy") + defer cleanupDB(t, dsn, token) + + b := newLocalBackend(dsn) + ctx := context.Background() + + creds, err := b.Provision(ctx, token, "anonymous") + if err != nil { + t.Fatalf("Provision: %v", err) + } + if creds.DatabaseName != "db_"+token { + t.Fatalf("DatabaseName = %q", creds.DatabaseName) + } + if creds.Username != "usr_"+token { + t.Fatalf("Username = %q", creds.Username) + } + if !strings.HasPrefix(creds.URL, "postgres://usr_"+token+":") { + t.Fatalf("URL = %q", creds.URL) + } + if creds.ProviderResourceID != "" { + t.Fatalf("local backend ProviderResourceID must be empty; got %q", creds.ProviderResourceID) + } + + // The provisioned user must actually be able to connect to its DB. + userConn, err := pgx.Connect(ctx, creds.URL+"?sslmode=disable") + if err != nil { + t.Fatalf("provisioned user cannot connect: %v", err) + } + var one int + if err := userConn.QueryRow(ctx, "SELECT 1").Scan(&one); err != nil || one != 1 { + t.Fatalf("provisioned user query failed: %v", err) + } + userConn.Close(ctx) + + // StorageBytes must report a positive size for the live database. + size, err := b.StorageBytes(ctx, token, "") + if err != nil { + t.Fatalf("StorageBytes: %v", err) + } + if size <= 0 { + t.Fatalf("StorageBytes = %d; want > 0", size) + } + + // Deprovision drops everything. + if err := b.Deprovision(ctx, token, ""); err != nil { + t.Fatalf("Deprovision: %v", err) + } + + // Deprovision is idempotent — second call (DROP ... IF EXISTS) is a no-op. + if err := b.Deprovision(ctx, token, ""); err != nil { + t.Fatalf("second Deprovision must be idempotent: %v", err) + } +} + +// TestLocalBackend_ProvisionWithExtensions installs pgvector if available; +// when the extension isn't present in the image, CREATE EXTENSION errors and +// Provision must surface it (the requested-extension-but-can't-install branch). +func TestLocalBackend_ProvisionWithExtensions(t *testing.T) { + dsn := testCustomersURL(t) + token := uniqueToken("ext") + defer cleanupDB(t, dsn, token) + + b := newLocalBackend(dsn) + ctx := context.Background() + + _, err := b.ProvisionWithExtensions(ctx, token, "pro", []string{"vector"}) + // Plain postgres:16-alpine has no pgvector, so CREATE EXTENSION fails and the + // whole provision errors. If the image *does* have it, err is nil. Either is + // a valid exercised branch; we just assert it doesn't panic and, on success, + // the DB is usable. + if err == nil { + // Vector installed — verify DB exists then clean up. + if size, sErr := b.StorageBytes(ctx, token, ""); sErr != nil || size <= 0 { + t.Fatalf("post-extension StorageBytes=%d err=%v", size, sErr) + } + } else if !strings.Contains(err.Error(), "CREATE EXTENSION") { + t.Fatalf("unexpected extension error: %v", err) + } +} + +// TestLocalBackend_ProvisionWithExtensions_Rejected covers the allowlist +// rejection branch before any DB connection is opened. +func TestLocalBackend_ProvisionWithExtensions_Rejected(t *testing.T) { + b := newLocalBackend("postgres://u:p@127.0.0.1:1/db?sslmode=disable") + _, err := b.ProvisionWithExtensions(context.Background(), "tok", "pro", []string{"postgis"}) + if err == nil || !strings.Contains(err.Error(), "allowlist") { + t.Fatalf("want allowlist rejection, got %v", err) + } +} + +// TestLocalBackend_DuplicateProvision covers the CREATE DATABASE error branch: +// provisioning the same token twice fails on the second CREATE DATABASE. +func TestLocalBackend_DuplicateProvision(t *testing.T) { + dsn := testCustomersURL(t) + token := uniqueToken("dup") + defer cleanupDB(t, dsn, token) + + b := newLocalBackend(dsn) + ctx := context.Background() + + if _, err := b.Provision(ctx, token, "anonymous"); err != nil { + t.Fatalf("first Provision: %v", err) + } + _, err := b.Provision(ctx, token, "anonymous") + if err == nil || !strings.Contains(err.Error(), "CREATE DATABASE") { + t.Fatalf("duplicate Provision must fail on CREATE DATABASE; got %v", err) + } +} + +// TestLocalBackend_ConnectFailure covers every connect-error branch by pointing +// the backend at a dead port. +func TestLocalBackend_ConnectFailure(t *testing.T) { + b := newLocalBackend("postgres://u:p@127.0.0.1:1/db?sslmode=disable&connect_timeout=1") + ctx := context.Background() + + if _, err := b.Provision(ctx, "tok", "anonymous"); err == nil { + t.Fatal("Provision against dead port must error") + } + if _, err := b.StorageBytes(ctx, "tok", ""); err == nil { + t.Fatal("StorageBytes against dead port must error") + } + if err := b.Deprovision(ctx, "tok", ""); err == nil { + t.Fatal("Deprovision against dead port must error") + } +} + +// TestLocalBackend_CreateUserConflict covers the CREATE USER error branch: +// CREATE DATABASE succeeds but the user already exists. +func TestLocalBackend_CreateUserConflict(t *testing.T) { + dsn := testCustomersURL(t) + token := uniqueToken("usrconf") + defer cleanupDB(t, dsn, token) + + ctx := context.Background() + conn, err := pgx.Connect(ctx, dsn) + if err != nil { + t.Fatalf("connect: %v", err) + } + // Pre-create the user so the provisioning CREATE USER collides. + if _, err := conn.Exec(ctx, `CREATE USER "usr_`+token+`" WITH PASSWORD 'x'`); err != nil { + t.Fatalf("seed user: %v", err) + } + conn.Close(ctx) + + b := newLocalBackend(dsn) + _, err = b.Provision(ctx, token, "anonymous") + if err == nil || !strings.Contains(err.Error(), "CREATE USER") { + t.Fatalf("want CREATE USER conflict; got %v", err) + } +} + +// TestLocalBackend_Deprovision_TerminatesConnections exercises the +// pg_terminate_backend path with a live connection open against the target DB +// at drop time. The reaper terminates it, then DROP DATABASE succeeds. +func TestLocalBackend_Deprovision_TerminatesConnections(t *testing.T) { + dsn := testCustomersURL(t) + token := uniqueToken("term") + defer cleanupDB(t, dsn, token) + + b := newLocalBackend(dsn) + ctx := context.Background() + + creds, err := b.Provision(ctx, token, "anonymous") + if err != nil { + t.Fatalf("Provision: %v", err) + } + // Hold an open connection against the provisioned database so the + // pg_terminate_backend statement has a real backend to terminate. + live, err := pgx.Connect(ctx, creds.URL+"?sslmode=disable") + if err != nil { + t.Fatalf("open live conn: %v", err) + } + defer live.Close(ctx) + + if err := b.Deprovision(ctx, token, ""); err != nil { + t.Fatalf("Deprovision with live conn: %v", err) + } +} + +// TestLocalBackend_Provision_LimitedRole_NonFatalGrants provisions under a +// CREATEDB+CREATEROLE non-superuser role. In PostgreSQL 15+ the `public` schema +// is owned by the bootstrap superuser, not the database creator, so the +// GRANT ALL ON SCHEMA public statement fails for a non-superuser owner — this +// exercises the non-fatal GRANT-SCHEMA log branch while the provision still +// succeeds (the grant is best-effort). The provisioned credentials are still +// returned and usable for connection. +func TestLocalBackend_Provision_LimitedRole_NonFatalGrants(t *testing.T) { + dsn := testCustomersURL(t) + limited := os.Getenv("TEST_CUSTOMERS_LIMITED_URL") + if limited == "" { + t.Skip("TEST_CUSTOMERS_LIMITED_URL not set — skipping limited-role provision test") + } + token := uniqueToken("limgrant") + defer cleanupDB(t, dsn, token) + + b := newLocalBackend(limited) + creds, err := b.Provision(context.Background(), token, "anonymous") + if err != nil { + t.Fatalf("Provision under limited role must still succeed (grants are best-effort): %v", err) + } + if creds.DatabaseName != "db_"+token { + t.Fatalf("DatabaseName = %q", creds.DatabaseName) + } +} + +// TestLocalBackend_Deprovision_DropFails exercises BOTH the +// pg_terminate_backend failure log (a non-privileged role cannot terminate +// another role's backend) AND the DROP DATABASE error return (a non-owner +// cannot drop the database). It provisions as the superuser, then deprovisions +// as a CREATEDB-but-non-superuser role while a superuser-owned connection is +// live against the DB. +// +// Requires TEST_CUSTOMERS_LIMITED_URL — a DSN for a role with CREATEDB + +// CREATEROLE but NOT superuser, that does NOT own the provisioned database. +func TestLocalBackend_Deprovision_DropFails(t *testing.T) { + dsn := testCustomersURL(t) + limited := os.Getenv("TEST_CUSTOMERS_LIMITED_URL") + if limited == "" { + t.Skip("TEST_CUSTOMERS_LIMITED_URL not set — skipping privilege-failure deprovision test") + } + token := uniqueToken("dropfail") + defer cleanupDB(t, dsn, token) + + ctx := context.Background() + // Provision as the superuser so the DB is owned by the superuser. + super := newLocalBackend(dsn) + creds, err := super.Provision(ctx, token, "anonymous") + if err != nil { + t.Fatalf("Provision: %v", err) + } + // Hold a live connection from the provisioned user so terminate has work. + live, err := pgx.Connect(ctx, creds.URL+"?sslmode=disable") + if err != nil { + t.Fatalf("open live conn: %v", err) + } + defer live.Close(ctx) + + // Deprovision as the limited role: pg_terminate_backend on another role's + // backend is denied (logged, non-fatal), then DROP DATABASE on a DB it does + // not own returns an error. + lb := newLocalBackend(limited) + if err := lb.Deprovision(ctx, token, ""); err == nil { + t.Fatal("limited-role Deprovision must fail on DROP DATABASE") + } +} + +// TestLocalBackend_Deprovision_DropUserFails covers the non-fatal DROP USER +// log branch: DROP DATABASE succeeds but DROP USER fails because the +// deprovisioning role does not have admin rights over the target role. +// Setup: the limited role owns the database (so it can drop it), but the user +// role is owned/created by the superuser (so the limited role cannot drop it). +func TestLocalBackend_Deprovision_DropUserFails(t *testing.T) { + dsn := testCustomersURL(t) + limited := os.Getenv("TEST_CUSTOMERS_LIMITED_URL") + if limited == "" { + t.Skip("TEST_CUSTOMERS_LIMITED_URL not set — skipping DROP USER failure test") + } + token := uniqueToken("dropusr") + dbName := "db_" + token + username := "usr_" + token + ctx := context.Background() + + // limited role creates + owns the database. + lconn, err := pgx.Connect(ctx, limited) + if err != nil { + t.Fatalf("limited connect: %v", err) + } + if _, err := lconn.Exec(ctx, `CREATE DATABASE "`+dbName+`"`); err != nil { + t.Fatalf("limited CREATE DATABASE: %v", err) + } + lconn.Close(ctx) + + // superuser creates the user (limited cannot drop a superuser-owned role). + sconn, err := pgx.Connect(ctx, dsn) + if err != nil { + t.Fatalf("super connect: %v", err) + } + _, _ = sconn.Exec(ctx, `CREATE USER "`+username+`" WITH PASSWORD 'x'`) + sconn.Close(ctx) + + defer cleanupDB(t, dsn, token) + + // Deprovision as limited: DROP DATABASE succeeds, DROP USER fails (logged). + lb := newLocalBackend(limited) + if err := lb.Deprovision(ctx, token, ""); err != nil { + t.Fatalf("Deprovision should succeed (DROP USER failure is non-fatal): %v", err) + } +} + +// TestLocalBackend_StorageBytes_CancelledCtx covers the StorageBytes +// disconnect-error defer by cancelling the context immediately after the query +// returns, so conn.Close(ctx) sees a cancelled context. +func TestLocalBackend_StorageBytes_CancelledCtx(t *testing.T) { + dsn := testCustomersURL(t) + token := uniqueToken("cancel") + defer cleanupDB(t, dsn, token) + + b := newLocalBackend(dsn) + if _, err := b.Provision(context.Background(), token, "anonymous"); err != nil { + t.Fatalf("Provision: %v", err) + } + + ctx, cancel := context.WithCancel(context.Background()) + // StorageBytes runs the query then defers conn.Close(ctx). We can't cancel + // mid-call, but a context with a 1ns deadline forces the close to observe a + // done context. Use a deadline far enough to let the query run, then expire. + cancel() // pre-cancelled: connect itself will fail, exercising connect-error. + _, _ = b.StorageBytes(ctx, token, "") +} + +// TestLocalBackend_StorageBytes_MissingDB covers the pg_database_size error +// branch for a database that doesn't exist. +func TestLocalBackend_StorageBytes_MissingDB(t *testing.T) { + dsn := testCustomersURL(t) + b := newLocalBackend(dsn) + _, err := b.StorageBytes(context.Background(), "definitely-not-provisioned-xyz", "") + if err == nil || !strings.Contains(err.Error(), "pg_database_size") { + t.Fatalf("want pg_database_size error for missing db; got %v", err) + } +} diff --git a/internal/providers/db/neon_test.go b/internal/providers/db/neon_test.go new file mode 100644 index 0000000..dc4992f --- /dev/null +++ b/internal/providers/db/neon_test.go @@ -0,0 +1,239 @@ +package db + +import ( + "context" + "io" + "net/http" + "strings" + "testing" +) + +// rtFunc adapts a function to an http.RoundTripper so we can intercept the +// Neon API calls (the base URL is a package const, not injectable, so we swap +// the transport on the backend's *http.Client instead of the URL). +type rtFunc func(*http.Request) (*http.Response, error) + +func (f rtFunc) RoundTrip(r *http.Request) (*http.Response, error) { return f(r) } + +func resp(status int, body string) *http.Response { + return &http.Response{ + StatusCode: status, + Body: io.NopCloser(strings.NewReader(body)), + Header: make(http.Header), + } +} + +// errReader fails on Read so we can exercise the io.ReadAll(body) error paths. +type errReader struct{} + +func (errReader) Read([]byte) (int, error) { return 0, io.ErrUnexpectedEOF } +func (errReader) Close() error { return nil } + +func respErrBody(status int) *http.Response { + return &http.Response{StatusCode: status, Body: errReader{}, Header: make(http.Header)} +} + +// TestNeon_ReadBodyErrors covers the io.ReadAll(resp.Body) failure branches in +// both Provision (success-path read) and StorageBytes. +func TestNeon_ReadBodyErrors(t *testing.T) { + mk := func(fn rtFunc) *NeonBackend { + b := newNeonBackend("key", "") + b.client = &http.Client{Transport: fn} + return b + } + ctx := context.Background() + + if _, err := mk(func(*http.Request) (*http.Response, error) { + return respErrBody(200), nil + }).Provision(ctx, "t", "x"); err == nil || !strings.Contains(err.Error(), "read body") { + t.Fatalf("Provision read-body error expected; got %v", err) + } + if _, err := mk(func(*http.Request) (*http.Response, error) { + return respErrBody(200), nil + }).StorageBytes(ctx, "t", "p"); err == nil || !strings.Contains(err.Error(), "read body") { + t.Fatalf("StorageBytes read-body error expected; got %v", err) + } +} + +func TestNeon_NewDefaults(t *testing.T) { + b := newNeonBackend("key", "") + if b.regionID != defaultNeonRegion { + t.Fatalf("empty region must default; got %q", b.regionID) + } + if b.apiKey != "key" || b.client == nil { + t.Fatal("apiKey/client not set") + } + b2 := newNeonBackend("k", "eu-central-1") + if b2.regionID != "eu-central-1" { + t.Fatalf("explicit region lost; got %q", b2.regionID) + } +} + +func TestNeon_Provision_Success(t *testing.T) { + b := newNeonBackend("key", "") + b.client = &http.Client{Transport: rtFunc(func(r *http.Request) (*http.Response, error) { + if r.Method != http.MethodPost || !strings.HasSuffix(r.URL.Path, "/projects") { + t.Fatalf("unexpected request %s %s", r.Method, r.URL) + } + if r.Header.Get("Authorization") != "Bearer key" { + t.Fatalf("missing bearer; got %q", r.Header.Get("Authorization")) + } + return resp(201, `{"project":{"id":"proj-1"},"connection_uris":[{"connection_uri":"postgres://x@neon/db"}]}`), nil + })} + + creds, err := b.Provision(context.Background(), "tok", "team") + if err != nil { + t.Fatalf("Provision: %v", err) + } + if creds.ProviderResourceID != "proj-1" || creds.URL != "postgres://x@neon/db" || creds.DatabaseName != "neondb" { + t.Fatalf("bad creds: %+v", creds) + } +} + +func TestNeon_Provision_ErrorBranches(t *testing.T) { + mk := func(fn rtFunc) *NeonBackend { + b := newNeonBackend("key", "") + b.client = &http.Client{Transport: fn} + return b + } + ctx := context.Background() + + // HTTP transport error. + if _, err := mk(func(*http.Request) (*http.Response, error) { + return nil, io.ErrUnexpectedEOF + }).Provision(ctx, "t", "x"); err == nil { + t.Fatal("transport error must surface") + } + // Non-2xx status. + if _, err := mk(func(*http.Request) (*http.Response, error) { + return resp(500, "boom"), nil + }).Provision(ctx, "t", "x"); err == nil || !strings.Contains(err.Error(), "unexpected status") { + t.Fatalf("non-2xx must error; got %v", err) + } + // Unparseable JSON. + if _, err := mk(func(*http.Request) (*http.Response, error) { + return resp(200, "not-json"), nil + }).Provision(ctx, "t", "x"); err == nil || !strings.Contains(err.Error(), "unmarshal") { + t.Fatalf("bad json must error; got %v", err) + } + // Empty project ID. + if _, err := mk(func(*http.Request) (*http.Response, error) { + return resp(200, `{"project":{"id":""},"connection_uris":[{"connection_uri":"x"}]}`), nil + }).Provision(ctx, "t", "x"); err == nil || !strings.Contains(err.Error(), "empty project ID") { + t.Fatalf("empty id must error; got %v", err) + } + // No connection URI. + if _, err := mk(func(*http.Request) (*http.Response, error) { + return resp(200, `{"project":{"id":"p"},"connection_uris":[]}`), nil + }).Provision(ctx, "t", "x"); err == nil || !strings.Contains(err.Error(), "no connection URI") { + t.Fatalf("missing uri must error; got %v", err) + } +} + +func TestNeon_ProvisionWithExtensions(t *testing.T) { + b := newNeonBackend("key", "") + b.client = &http.Client{Transport: rtFunc(func(*http.Request) (*http.Response, error) { + return resp(201, `{"project":{"id":"p"},"connection_uris":[{"connection_uri":"u"}]}`), nil + })} + ctx := context.Background() + + // Allowlist rejection short-circuits before any HTTP call. + if _, err := b.ProvisionWithExtensions(ctx, "t", "x", []string{"postgis"}); err == nil || !strings.Contains(err.Error(), "allowlist") { + t.Fatalf("want allowlist error; got %v", err) + } + // No extensions → plain provision succeeds. + if _, err := b.ProvisionWithExtensions(ctx, "t", "x", nil); err != nil { + t.Fatalf("no-ext provision: %v", err) + } + // Allowed extension → provision succeeds but returns the not-supported error. + if _, err := b.ProvisionWithExtensions(ctx, "t", "x", []string{"vector"}); err == nil || !strings.Contains(err.Error(), "not yet supported") { + t.Fatalf("want not-supported error; got %v", err) + } +} + +func TestNeon_ProvisionWithExtensions_ProvisionFails(t *testing.T) { + b := newNeonBackend("key", "") + b.client = &http.Client{Transport: rtFunc(func(*http.Request) (*http.Response, error) { + return resp(500, "down"), nil + })} + if _, err := b.ProvisionWithExtensions(context.Background(), "t", "x", []string{"vector"}); err == nil { + t.Fatal("provision failure must propagate through ProvisionWithExtensions") + } +} + +func TestNeon_StorageBytes(t *testing.T) { + mk := func(fn rtFunc) *NeonBackend { + b := newNeonBackend("key", "") + b.client = &http.Client{Transport: fn} + return b + } + ctx := context.Background() + + // Empty providerResourceID short-circuits. + if _, err := mk(nil).StorageBytes(ctx, "t", ""); err == nil { + t.Fatal("empty rid must error") + } + // Success. + n, err := mk(func(r *http.Request) (*http.Response, error) { + if r.Method != http.MethodGet || !strings.HasSuffix(r.URL.Path, "/projects/proj-9") { + t.Fatalf("bad request %s %s", r.Method, r.URL) + } + return resp(200, `{"project":{"usage":{"data_storage_bytes_hour":4096}}}`), nil + }).StorageBytes(ctx, "t", "proj-9") + if err != nil || n != 4096 { + t.Fatalf("StorageBytes = %d, %v", n, err) + } + // Transport error. + if _, err := mk(func(*http.Request) (*http.Response, error) { + return nil, io.ErrUnexpectedEOF + }).StorageBytes(ctx, "t", "p"); err == nil { + t.Fatal("transport err must surface") + } + // Non-2xx. + if _, err := mk(func(*http.Request) (*http.Response, error) { + return resp(404, "nope"), nil + }).StorageBytes(ctx, "t", "p"); err == nil || !strings.Contains(err.Error(), "unexpected status") { + t.Fatalf("non-2xx must error; got %v", err) + } + // Bad JSON. + if _, err := mk(func(*http.Request) (*http.Response, error) { + return resp(200, "x"), nil + }).StorageBytes(ctx, "t", "p"); err == nil || !strings.Contains(err.Error(), "unmarshal") { + t.Fatalf("bad json must error; got %v", err) + } +} + +func TestNeon_Deprovision(t *testing.T) { + mk := func(fn rtFunc) *NeonBackend { + b := newNeonBackend("key", "") + b.client = &http.Client{Transport: fn} + return b + } + ctx := context.Background() + + // Empty rid short-circuits. + if err := mk(nil).Deprovision(ctx, "t", ""); err == nil { + t.Fatal("empty rid must error") + } + // Success. + if err := mk(func(r *http.Request) (*http.Response, error) { + if r.Method != http.MethodDelete { + t.Fatalf("want DELETE; got %s", r.Method) + } + return resp(200, ""), nil + }).Deprovision(ctx, "t", "proj-9"); err != nil { + t.Fatalf("Deprovision: %v", err) + } + // Transport error. + if err := mk(func(*http.Request) (*http.Response, error) { + return nil, io.ErrUnexpectedEOF + }).Deprovision(ctx, "t", "p"); err == nil { + t.Fatal("transport err must surface") + } + // Non-2xx. + if err := mk(func(*http.Request) (*http.Response, error) { + return resp(403, "denied"), nil + }).Deprovision(ctx, "t", "p"); err == nil || !strings.Contains(err.Error(), "unexpected status") { + t.Fatalf("non-2xx must error; got %v", err) + } +} diff --git a/internal/providers/db/provider_test.go b/internal/providers/db/provider_test.go new file mode 100644 index 0000000..59d9ffc --- /dev/null +++ b/internal/providers/db/provider_test.go @@ -0,0 +1,66 @@ +package db + +import ( + "context" + "testing" + + "instant.dev/internal/config" +) + +// TestProviderNew_SelectsBackend verifies the factory picks local vs neon by +// cfg.PostgresProvisionBackend and that the thin delegating wrappers call +// through to the chosen backend. +func TestProviderNew_SelectsBackend(t *testing.T) { + local := New(&config.Config{PostgresProvisionBackend: "local"}, "postgres://u:p@h:5432/d") + if _, ok := local.backend.(*LocalBackend); !ok { + t.Fatalf("local backend type = %T", local.backend) + } + neon := New(&config.Config{PostgresProvisionBackend: "neon", NeonAPIKey: "k"}, "") + if _, ok := neon.backend.(*NeonBackend); !ok { + t.Fatalf("neon backend type = %T", neon.backend) + } + // Unknown backend falls through to local. + def := New(&config.Config{PostgresProvisionBackend: "wat"}, "") + if _, ok := def.backend.(*LocalBackend); !ok { + t.Fatalf("default backend type = %T", def.backend) + } +} + +// fakeBackend records which delegating method ran. +type fakeBackend struct{ called string } + +func (f *fakeBackend) Provision(_ context.Context, _, _ string) (*Credentials, error) { + f.called = "Provision" + return &Credentials{}, nil +} +func (f *fakeBackend) ProvisionWithExtensions(_ context.Context, _, _ string, _ []string) (*Credentials, error) { + f.called = "ProvisionWithExtensions" + return &Credentials{}, nil +} +func (f *fakeBackend) StorageBytes(_ context.Context, _, _ string) (int64, error) { + f.called = "StorageBytes" + return 7, nil +} +func (f *fakeBackend) Deprovision(_ context.Context, _, _ string) error { + f.called = "Deprovision" + return nil +} + +func TestProviderDelegation(t *testing.T) { + fb := &fakeBackend{} + p := &Provider{backend: fb} + ctx := context.Background() + + if _, err := p.Provision(ctx, "t", "tier"); err != nil || fb.called != "Provision" { + t.Fatalf("Provision delegation: called=%q err=%v", fb.called, err) + } + if _, err := p.ProvisionWithExtensions(ctx, "t", "tier", []string{"vector"}); err != nil || fb.called != "ProvisionWithExtensions" { + t.Fatalf("ProvisionWithExtensions delegation: called=%q err=%v", fb.called, err) + } + if n, err := p.StorageBytes(ctx, "t", "rid"); err != nil || n != 7 || fb.called != "StorageBytes" { + t.Fatalf("StorageBytes delegation: called=%q n=%d err=%v", fb.called, n, err) + } + if err := p.Deprovision(ctx, "t", "rid"); err != nil || fb.called != "Deprovision" { + t.Fatalf("Deprovision delegation: called=%q err=%v", fb.called, err) + } +} diff --git a/internal/providers/nosql/mongo.go b/internal/providers/nosql/mongo.go index d9c5354..63d0162 100644 --- a/internal/providers/nosql/mongo.go +++ b/internal/providers/nosql/mongo.go @@ -22,6 +22,20 @@ import ( // Short to fail-fast in tests and when MongoDB is not reachable. const connectTimeout = 3 * time.Second +// randRead is the entropy source for the user password. It defaults to +// crypto/rand.Read and is a package var only so a test can substitute a +// fault-injecting reader to exercise the (otherwise unreachable) RNG-failure +// branch in Provision. Production behaviour is identical to crypto/rand.Read. +var randRead = rand.Read + +// mongoDisconnect is the seam through which every Provider method closes its +// client. It defaults to (*mongo.Client).Disconnect and is a package var only +// so a test can force it to error and exercise the disconnect defer-error log +// branches. Production behaviour is identical to calling client.Disconnect. +var mongoDisconnect = func(client *mongo.Client, ctx context.Context) error { + return client.Disconnect(ctx) +} + // Credentials holds the MongoDB connection details returned after provisioning. type Credentials struct { // URL is the mongodb:// connection string the caller can use immediately. @@ -65,14 +79,14 @@ func (p *Provider) Provision(ctx context.Context, token, tier string) (*Credenti return nil, fmt.Errorf("nosql.Provision: connect: %w", err) } defer func() { - if discErr := client.Disconnect(ctx); discErr != nil { + if discErr := mongoDisconnect(client, ctx); discErr != nil { slog.Error("nosql.Provision: disconnect", "error", discErr) } }() // Generate random 16-byte password. pwBytes := make([]byte, 16) - if _, err := rand.Read(pwBytes); err != nil { + if _, err := randRead(pwBytes); err != nil { return nil, fmt.Errorf("nosql.Provision: generate password: %w", err) } password := hex.EncodeToString(pwBytes) @@ -133,7 +147,7 @@ func (p *Provider) StorageBytes(ctx context.Context, token string) (int64, error return 0, nil } defer func() { - if discErr := client.Disconnect(ctx); discErr != nil { + if discErr := mongoDisconnect(client, ctx); discErr != nil { slog.Error("nosql.StorageBytes: disconnect", "error", discErr) } }() @@ -147,16 +161,25 @@ func (p *Provider) StorageBytes(ctx context.Context, token string) (int64, error return 0, nil } - switch v := result["storageSize"].(type) { + return storageSizeToInt64(result["storageSize"]), nil +} + +// storageSizeToInt64 normalises the dbStats.storageSize field, which MongoDB +// returns as one of several numeric BSON types depending on magnitude and +// server version, into an int64. Unknown / nil types yield 0 (fail-open). +// Extracted as a free function so every type arm is unit-testable without +// depending on which numeric type a given mongod build happens to return. +func storageSizeToInt64(v any) int64 { + switch n := v.(type) { case int32: - return int64(v), nil + return int64(n) case int64: - return v, nil + return n case float64: - return int64(v), nil + return int64(n) + default: + return 0 } - - return 0, nil } // Deprovision drops the user and database for the given token. @@ -168,7 +191,7 @@ func (p *Provider) Deprovision(ctx context.Context, token string) error { return fmt.Errorf("nosql.Deprovision: connect: %w", err) } defer func() { - if discErr := client.Disconnect(ctx); discErr != nil { + if discErr := mongoDisconnect(client, ctx); discErr != nil { slog.Error("nosql.Deprovision: disconnect", "error", discErr) } }() diff --git a/internal/providers/nosql/mongo_seam_test.go b/internal/providers/nosql/mongo_seam_test.go new file mode 100644 index 0000000..f5359fe --- /dev/null +++ b/internal/providers/nosql/mongo_seam_test.go @@ -0,0 +1,203 @@ +package nosql + +// mongo_seam_test.go drives the MongoDB defensive branches that cannot be +// triggered deterministically against a healthy mongod: the crypto/rand +// failure and the client.Disconnect defer-error logs. Both go through the +// package seams (randRead + mongoDisconnect). The disconnect-error tests need +// a live mongod (TEST_MONGO_URI) so the happy path runs up to the deferred +// close; they skip cleanly when it is unset. + +import ( + "context" + "errors" + "os" + "strings" + "testing" + "time" + + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/mongo" + "go.mongodb.org/mongo-driver/mongo/options" +) + +// TestStorageSizeToInt64 covers every arm of the storageSize numeric +// normalisation deterministically, independent of which BSON type a given +// mongod returns. +func TestStorageSizeToInt64(t *testing.T) { + cases := []struct { + name string + in any + want int64 + }{ + {"int32", int32(123), 123}, + {"int64", int64(456), 456}, + {"float64", float64(789.9), 789}, + {"nil", nil, 0}, + {"string-unknown", "not-a-number", 0}, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + if got := storageSizeToInt64(c.in); got != c.want { + t.Fatalf("storageSizeToInt64(%v) = %d; want %d", c.in, got, c.want) + } + }) + } +} + +func seamMongoURI(t *testing.T) string { + t.Helper() + uri := os.Getenv("TEST_MONGO_URI") + if uri == "" { + t.Skip("TEST_MONGO_URI not set — skipping MongoDB seam tests") + } + return uri +} + +func seamHost(uri string) string { + after := strings.TrimPrefix(uri, "mongodb://") + if i := strings.Index(after, "@"); i != -1 { + after = after[i+1:] + } + if i := strings.Index(after, "/"); i != -1 { + after = after[:i] + } + if after == "" { + return "localhost:27017" + } + return after +} + +func seamCleanup(t *testing.T, uri, token string) { + t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + client, err := mongo.Connect(ctx, options.Client().ApplyURI(uri)) + if err != nil { + return + } + defer client.Disconnect(ctx) + client.Database("admin").RunCommand(ctx, bson.D{{Key: "dropUser", Value: "usr_" + token}}) + client.Database("db_" + token).Drop(ctx) +} + +// TestStorageBytes_DBStatsError_FailOpen covers the dbStats RunCommand error +// branch deterministically: a token containing '.' makes db_ an invalid +// MongoDB database name, which the driver rejects client-side, so dbStats +// errors and StorageBytes fails open with (0, nil). +func TestStorageBytes_DBStatsError_FailOpen(t *testing.T) { + uri := seamMongoURI(t) + p := New(uri, seamHost(uri)) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + if n, err := p.StorageBytes(ctx, "bad.name"); err != nil || n != 0 { + t.Fatalf("dbStats error must fail open; got (%d,%v)", n, err) + } +} + +// TestDeprovision_DropDatabaseError covers the fatal DROP DATABASE error return +// deterministically: a token containing '.' makes db_ an invalid +// MongoDB database name, which the driver rejects when Drop is called. +func TestDeprovision_DropDatabaseError(t *testing.T) { + uri := seamMongoURI(t) + p := New(uri, seamHost(uri)) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + err := p.Deprovision(ctx, "bad.name") + if err == nil || !strings.Contains(err.Error(), "drop database") { + t.Fatalf("invalid db name must surface a drop database error; got %v", err) + } +} + +// TestProvision_RandReadFailure covers the crypto/rand failure branch in +// Provision via the randRead seam. It needs a connectable mongod so the flow +// reaches the password step (the connect happens before the RNG call). +func TestProvision_RandReadFailure(t *testing.T) { + uri := seamMongoURI(t) + orig := randRead + randRead = func(b []byte) (int, error) { return 0, errors.New("entropy depleted") } + t.Cleanup(func() { randRead = orig }) + + p := New(uri, seamHost(uri)) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + _, err := p.Provision(ctx, "randfail", "anonymous") + if err == nil || !strings.Contains(err.Error(), "generate password") { + t.Fatalf("randRead failure must surface as generate-password error; got %v", err) + } +} + +// TestProvision_DisconnectError covers the Provision disconnect defer-error log. +// The provision succeeds; mongoDisconnect is forced to error so the deferred +// close hits the error branch. +func TestProvision_DisconnectError(t *testing.T) { + uri := seamMongoURI(t) + token := "discprov1" + defer seamCleanup(t, uri, token) + + orig := mongoDisconnect + mongoDisconnect = func(c *mongo.Client, ctx context.Context) error { + _ = c.Disconnect(ctx) // still really close so we don't leak + return errors.New("forced disconnect error") + } + t.Cleanup(func() { mongoDisconnect = orig }) + + p := New(uri, seamHost(uri)) + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + + creds, err := p.Provision(ctx, token, "anonymous") + if err != nil { + t.Fatalf("Provision must succeed despite a disconnect error: %v", err) + } + if creds.DatabaseName != "db_"+token { + t.Fatalf("DatabaseName = %q", creds.DatabaseName) + } +} + +// TestStorageBytes_DisconnectError covers the StorageBytes disconnect defer-error +// log. The dbStats call fails open; the forced disconnect error is logged. +func TestStorageBytes_DisconnectError(t *testing.T) { + uri := seamMongoURI(t) + + orig := mongoDisconnect + mongoDisconnect = func(c *mongo.Client, ctx context.Context) error { + _ = c.Disconnect(ctx) + return errors.New("forced disconnect error") + } + t.Cleanup(func() { mongoDisconnect = orig }) + + p := New(uri, seamHost(uri)) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + if n, err := p.StorageBytes(ctx, "ghost-token-xyz"); err != nil || n != 0 { + t.Fatalf("StorageBytes must fail open; got (%d,%v)", n, err) + } +} + +// TestDeprovision_DisconnectError covers the Deprovision disconnect defer-error +// log. Deprovision of a non-existent token drops nothing fatal; the forced +// disconnect error is logged. The Drop of a missing DB is a no-op (returns nil). +func TestDeprovision_DisconnectError(t *testing.T) { + uri := seamMongoURI(t) + + orig := mongoDisconnect + mongoDisconnect = func(c *mongo.Client, ctx context.Context) error { + _ = c.Disconnect(ctx) + return errors.New("forced disconnect error") + } + t.Cleanup(func() { mongoDisconnect = orig }) + + p := New(uri, seamHost(uri)) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + // Dropping a never-created DB is a no-op in MongoDB and returns nil, so + // Deprovision succeeds while still running the deferred (erroring) close. + if err := p.Deprovision(ctx, "neverexisted-token"); err != nil { + t.Fatalf("Deprovision of missing token must succeed: %v", err) + } +} diff --git a/internal/providers/nosql/mongo_unit_test.go b/internal/providers/nosql/mongo_unit_test.go new file mode 100644 index 0000000..0e6f8a3 --- /dev/null +++ b/internal/providers/nosql/mongo_unit_test.go @@ -0,0 +1,237 @@ +package nosql_test + +import ( + "context" + "os" + "strings" + "testing" + "time" + + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/mongo" + "go.mongodb.org/mongo-driver/mongo/options" + + nosqlprovider "instant.dev/internal/providers/nosql" +) + +// requireMongoURI mirrors requireMongo but is named distinctly to avoid +// collisions with the black-box mongo_test.go in the same package. +func requireMongoURI(t *testing.T) string { + t.Helper() + uri := os.Getenv("TEST_MONGO_URI") + if uri == "" { + t.Skip("TEST_MONGO_URI not set — skipping MongoDB tests") + } + return uri +} + +func hostFromURI(uri string) string { + after := strings.TrimPrefix(uri, "mongodb://") + if i := strings.Index(after, "@"); i != -1 { + after = after[i+1:] + } + if i := strings.Index(after, "/"); i != -1 { + after = after[:i] + } + if after == "" { + return "localhost:27017" + } + return after +} + +func dropMongo(t *testing.T, uri, token string) { + t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + client, err := mongo.Connect(ctx, options.Client().ApplyURI(uri)) + if err != nil { + return + } + defer client.Disconnect(ctx) + client.Database("admin").RunCommand(ctx, bson.D{{Key: "dropUser", Value: "usr_" + token}}) + client.Database("db_" + token).Drop(ctx) +} + +// TestNew_Defaults covers the empty-arg default branches. +func TestNew_Defaults(t *testing.T) { + p := nosqlprovider.New("", "") + // We can't read unexported fields, but Provision against the default URI + // (root:root@localhost:27017) is exercised elsewhere; here we simply assert + // the constructor returns a usable, non-nil provider. + if p == nil { + t.Fatal("New must return a provider") + } + p2 := nosqlprovider.New("mongodb://x@h:1", "h:1") + if p2 == nil { + t.Fatal("New must return a provider for explicit args") + } +} + +// TestProvision_DuplicateUser covers the createUser error branch: provisioning +// the same token twice fails on the second createUser. +func TestProvision_DuplicateUser(t *testing.T) { + uri := requireMongoURI(t) + host := hostFromURI(uri) + token := "dupuser01" + defer dropMongo(t, uri, token) + + p := nosqlprovider.New(uri, host) + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) + defer cancel() + + if _, err := p.Provision(ctx, token, "anonymous"); err != nil { + t.Fatalf("first Provision: %v", err) + } + _, err := p.Provision(ctx, token, "anonymous") + if err == nil || !strings.Contains(err.Error(), "createUser") { + t.Fatalf("duplicate Provision must fail on createUser; got %v", err) + } +} + +// TestStorageBytes_PositiveAfterWrite covers the dbStats success path and the +// storageSize type-switch (the value comes back as a numeric BSON type). +func TestStorageBytes_PositiveAfterWrite(t *testing.T) { + uri := requireMongoURI(t) + host := hostFromURI(uri) + token := "storagesz01" + defer dropMongo(t, uri, token) + + p := nosqlprovider.New(uri, host) + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) + defer cancel() + + if _, err := p.Provision(ctx, token, "anonymous"); err != nil { + t.Fatalf("Provision: %v", err) + } + + // Write enough data that dbStats.storageSize is non-zero. + client, _ := mongo.Connect(ctx, options.Client().ApplyURI(uri)) + defer client.Disconnect(ctx) + docs := make([]interface{}, 0, 500) + for i := 0; i < 500; i++ { + docs = append(docs, bson.D{{Key: "i", Value: i}, {Key: "pad", Value: strings.Repeat("x", 256)}}) + } + if _, err := client.Database("db_"+token).Collection("data").InsertMany(ctx, docs); err != nil { + t.Fatalf("seed data: %v", err) + } + + bytes, err := p.StorageBytes(ctx, token) + if err != nil { + t.Fatalf("StorageBytes: %v", err) + } + if bytes <= 0 { + t.Fatalf("storageSize must be > 0 after writes; got %d", bytes) + } +} + +// TestDeprovision_DropUserFailsNonFatal covers the dropUser non-fatal log +// branch: the user does not exist (only the database does) so dropUser errors +// but Deprovision still drops the database and returns nil. +func TestDeprovision_DropUserFailsNonFatal(t *testing.T) { + uri := requireMongoURI(t) + host := hostFromURI(uri) + token := "nouser01" + defer dropMongo(t, uri, token) + + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) + defer cancel() + + // Create only the database (no user) by inserting a doc directly. + client, err := mongo.Connect(ctx, options.Client().ApplyURI(uri)) + if err != nil { + t.Fatalf("connect: %v", err) + } + if _, err := client.Database("db_"+token).Collection("c").InsertOne(ctx, bson.D{{Key: "x", Value: 1}}); err != nil { + t.Fatalf("seed db: %v", err) + } + client.Disconnect(ctx) + + p := nosqlprovider.New(uri, host) + if err := p.Deprovision(ctx, token); err != nil { + t.Fatalf("Deprovision must succeed even when dropUser fails: %v", err) + } +} + +// TestConnectErrorBranches covers the mongo.Connect error returns in Provision, +// StorageBytes (fail-open → 0,nil) and Deprovision, using a syntactically +// invalid URI that fails at ApplyURI/Connect time. +func TestConnectErrorBranches(t *testing.T) { + p := nosqlprovider.New("not-a-valid-uri", "h:1") + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + if _, err := p.Provision(ctx, "tok", "anonymous"); err == nil || !strings.Contains(err.Error(), "connect") { + t.Fatalf("Provision connect error expected; got %v", err) + } + // StorageBytes is fail-open: connect error returns (0, nil). + if n, err := p.StorageBytes(ctx, "tok"); err != nil || n != 0 { + t.Fatalf("StorageBytes connect error must fail open; got (%d,%v)", n, err) + } + if err := p.Deprovision(ctx, "tok"); err == nil || !strings.Contains(err.Error(), "connect") { + t.Fatalf("Deprovision connect error expected; got %v", err) + } +} + +// TestProvision_InitInsertNonFatal covers the non-fatal init-insert log branch: +// the sentinel insert fails (here because the database name derived from the +// token is invalid for MongoDB) but createUser already succeeded so Provision +// still returns credentials. +func TestProvision_InitInsertNonFatal(t *testing.T) { + uri := requireMongoURI(t) + host := hostFromURI(uri) + // A token with a '$' makes db_ an invalid MongoDB database name, so + // the sentinel InsertOne fails — exercising the non-fatal branch. createUser + // accepts the username (different validation), so it succeeds first. + token := "init$bad" + defer dropMongo(t, uri, token) + + p := nosqlprovider.New(uri, host) + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) + defer cancel() + + creds, err := p.Provision(ctx, token, "anonymous") + if err != nil { + // If createUser itself rejects the name, that's a different (also valid) + // error path; only fail if neither path was exercised cleanly. + if !strings.Contains(err.Error(), "createUser") { + t.Fatalf("unexpected Provision error: %v", err) + } + return + } + if creds.DatabaseName != "db_"+token { + t.Fatalf("DatabaseName = %q", creds.DatabaseName) + } +} + +// TestStorageBytes_MissingDB_FailOpen covers the dbStats fail-open path for a +// database that doesn't exist — returns (0, nil). +func TestStorageBytes_MissingDB_FailOpen(t *testing.T) { + uri := requireMongoURI(t) + host := hostFromURI(uri) + p := nosqlprovider.New(uri, host) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + bytes, err := p.StorageBytes(ctx, "ghost-db-never-made") + if err != nil || bytes != 0 { + t.Fatalf("missing-db StorageBytes = (%d,%v); want (0,nil)", bytes, err) + } +} + +// TestStorageBytes_DBStatsError covers the dbStats RunCommand error branch +// (valid connection, but the derived database name is invalid for MongoDB so +// the dbStats command itself fails) — StorageBytes fails open with (0, nil). +func TestStorageBytes_DBStatsError(t *testing.T) { + uri := requireMongoURI(t) + host := hostFromURI(uri) + p := nosqlprovider.New(uri, host) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + // '$' yields an invalid MongoDB database name; the connection succeeds but + // the dbStats command errors. + bytes, err := p.StorageBytes(ctx, "bad$name") + if err != nil || bytes != 0 { + t.Fatalf("dbStats error must fail open; got (%d,%v)", bytes, err) + } +}