diff --git a/internal/backend/postgres/backend_seam_test.go b/internal/backend/postgres/backend_seam_test.go new file mode 100644 index 0000000..1975b6d --- /dev/null +++ b/internal/backend/postgres/backend_seam_test.go @@ -0,0 +1,308 @@ +package postgres + +// backend_seam_test.go — coverage for the NewBackend factory, the goredis +// helper aliases, k8sEnv/k8sEnvInt, and the cluster-router paths the existing +// tests don't reach (at-capacity Pick, refreshCounts/dbCount, pollLoop, +// ProviderResourceID). + +import ( + "context" + "os" + "testing" + "time" +) + +// osWriteFileBackend writes a kubeconfig fixture for the NewBackend k8s tests. +func osWriteFileBackend(path, content string) error { + return os.WriteFile(path, []byte(content), 0o600) +} + +func TestK8sEnv_Seam(t *testing.T) { + t.Setenv("K8S_TEST_KEY", "v") + if k8sEnv("K8S_TEST_KEY", "fb") != "v" { + t.Error("should return env value") + } + if k8sEnv("K8S_UNSET_KEY_XYZ", "fb") != "fb" { + t.Error("should return fallback") + } +} + +func TestK8sEnvInt_Seam(t *testing.T) { + t.Setenv("K8S_INT_KEY", "42") + if k8sEnvInt("K8S_INT_KEY", 7) != 42 { + t.Error("should parse env int") + } + t.Setenv("K8S_INT_BAD", "notanint") + if k8sEnvInt("K8S_INT_BAD", 7) != 7 { + t.Error("bad int should fall back") + } + if k8sEnvInt("K8S_INT_UNSET_XYZ", 9) != 9 { + t.Error("unset should fall back") + } +} + +func TestGoredisAliases_Seam(t *testing.T) { + if _, err := goredisParseURL("not-a-redis-url"); err == nil { + t.Error("expected parse error") + } + opt, err := goredisParseURL("redis://127.0.0.1:6379") + if err != nil { + t.Fatalf("parse: %v", err) + } + if c := goredisNewClient(opt); c == nil { + t.Error("nil client") + } +} + +func TestNewBackend_Neon_Seam(t *testing.T) { + b := NewBackend("neon", "", "", "apikey", "region") + if _, ok := b.(*NeonBackend); !ok { + t.Errorf("want *NeonBackend, got %T", b) + } +} + +func TestNewBackend_DefaultLocal(t *testing.T) { + b := NewBackend("", "postgres://u@h/db", "", "", "") + if _, ok := b.(*LocalBackend); !ok { + t.Errorf("want *LocalBackend, got %T", b) + } +} + +func TestNewBackend_ClusterURLs(t *testing.T) { + b := NewBackend("", "", "url0,url1, ,url2,", "", "") + lb, ok := b.(*LocalBackend) + if !ok { + t.Fatalf("want *LocalBackend, got %T", b) + } + // trailing empty + whitespace entries filtered → 3 clusters + if len(lb.router.adminURLs) != 3 { + t.Errorf("adminURLs = %v; want 3 filtered", lb.router.adminURLs) + } +} + +func TestNewBackend_ClusterURLs_AllEmpty_FallsBack(t *testing.T) { + b := NewBackend("", "cust", " , ,", "", "") + lb := b.(*LocalBackend) + if len(lb.router.adminURLs) != 1 { + t.Errorf("all-empty cluster list should fall back to single; got %v", lb.router.adminURLs) + } +} + +// k8s backend with no kubeconfig + no in-cluster config → newK8sBackend fails → +// NewBackend falls back to local. Covers the fallback branch in the factory. +func TestNewBackend_K8s_FallbackToLocal(t *testing.T) { + t.Setenv("K8S_KUBECONFIG", "/nonexistent/kubeconfig-path") + b := NewBackend("k8s", "cust-url", "", "", "") + if _, ok := b.(*LocalBackend); !ok { + t.Errorf("k8s init failure should fall back to *LocalBackend, got %T", b) + } +} + +// NewBackend("k8s") with a valid kubeconfig + a parseable REDIS_URL_FOR_ROUTES +// exercises the route-registry-enabled block (the only sub-95 gap in NewBackend). +func TestNewBackend_K8s_RouteRegistryEnabled(t *testing.T) { + dir := t.TempDir() + kc := dir + "/kubeconfig" + if err := osWriteFileBackend(kc, minimalKubeconfig); err != nil { + t.Fatalf("write kubeconfig: %v", err) + } + t.Setenv("K8S_KUBECONFIG", kc) + t.Setenv("REDIS_URL_FOR_ROUTES", "redis://127.0.0.1:6379") + b := NewBackend("k8s", "cust", "", "", "") + kb, ok := b.(*K8sBackend) + if !ok { + t.Fatalf("want *K8sBackend, got %T", b) + } + if kb.rdb == nil { + t.Error("route registry should be enabled when REDIS_URL_FOR_ROUTES parses") + } +} + +// NewBackend("k8s") with a valid kubeconfig but an UNPARSEABLE route Redis URL +// exercises the route-registry-disabled (warn) branch. +func TestNewBackend_K8s_RouteRegistryBadURL(t *testing.T) { + dir := t.TempDir() + kc := dir + "/kubeconfig" + if err := osWriteFileBackend(kc, minimalKubeconfig); err != nil { + t.Fatalf("write kubeconfig: %v", err) + } + t.Setenv("K8S_KUBECONFIG", kc) + t.Setenv("REDIS_URL_FOR_ROUTES", "::::not-a-redis-url") + b := NewBackend("k8s", "cust", "", "", "") + kb, ok := b.(*K8sBackend) + if !ok { + t.Fatalf("want *K8sBackend, got %T", b) + } + if kb.rdb != nil { + t.Error("route registry should stay disabled when the URL fails to parse") + } +} + +func TestNewDedicatedBackend_Seam(t *testing.T) { + b := NewDedicatedBackend("dsn", "") + if _, ok := b.(*DedicatedProvider); !ok { + t.Errorf("want *DedicatedProvider, got %T", b) + } +} + +// --- cluster_router uncovered paths --- + +func TestProviderResourceID(t *testing.T) { + r := newClusterRouter([]string{"u0", "u1"}, 0) + if r.ProviderResourceID(1) != "local:1" { + t.Errorf("got %q", r.ProviderResourceID(1)) + } +} + +func TestPick_AllAtCapacity_FallsBackToZero(t *testing.T) { + r := newClusterRouter([]string{"u0", "u1"}, 1) + // Saturate both clusters' counts so headroom <= 0 everywhere. + r.mu.Lock() + r.counts[0] = 5 + r.counts[1] = 5 + r.mu.Unlock() + idx, url, err := r.Pick() + if err != nil { + t.Fatalf("Pick: %v", err) + } + if idx != 0 || url != "u0" { + t.Errorf("at-capacity Pick should fall back to index 0; got %d/%q", idx, url) + } +} + +func TestPick_AllURLsEmpty_BestNegativeFallback(t *testing.T) { + // Non-empty slice but every URL blank → loop never sets best → best<0 path. + r := newClusterRouter([]string{"", ""}, 0) + idx, _, err := r.Pick() + if err != nil { + t.Fatalf("Pick: %v", err) + } + if idx != 0 { + t.Errorf("best<0 fallback should pick index 0; got %d", idx) + } +} + +func TestPick_NoClusters_Error(t *testing.T) { + r := newClusterRouter(nil, 0) + if _, _, err := r.Pick(); err == nil { + t.Error("expected no-clusters error") + } +} + +func TestRefreshCounts_ConnectFails_KeepsPrevious(t *testing.T) { + // Unreachable admin URL → dbCount errors → previous count retained. + r := newClusterRouter([]string{"postgres://x@127.0.0.1:1/none", ""}, 0) + r.mu.Lock() + r.counts[0] = 3 + r.mu.Unlock() + r.refreshCounts(context.Background()) + r.mu.RLock() + got := r.counts[0] + r.mu.RUnlock() + if got != 3 { + t.Errorf("count after failed poll = %d; want previous 3", got) + } +} + +func TestDbCount_ConnectError(t *testing.T) { + r := newClusterRouter([]string{"x"}, 0) + if _, err := r.dbCount(context.Background(), "postgres://x@127.0.0.1:1/none"); err == nil { + t.Error("expected connect error") + } +} + +func TestDbCount_Success_ViaSeam(t *testing.T) { + fc := &fakePGConn{scanInt64: 11} + withPGXConnect(t, fc, nil) + r := newClusterRouter([]string{"u0"}, 0) + n, err := r.dbCount(context.Background(), "u0") + if err != nil || n != 11 { + t.Errorf("dbCount = %d, %v", n, err) + } +} + +func TestDbCount_ScanError_ViaSeam(t *testing.T) { + fc := &fakePGConn{queryRowErr: errSeam} + withPGXConnect(t, fc, nil) + r := newClusterRouter([]string{"u0"}, 0) + if _, err := r.dbCount(context.Background(), "u0"); err == nil { + t.Error("expected scan error") + } +} + +// pollLoop ticker branch: shrink the poll interval so ticker.C fires and the +// periodic refreshCounts runs, then cancel. +func TestPollLoop_TickerFires(t *testing.T) { + fc := &fakePGConn{scanInt64: 1} + withPGXConnect(t, fc, nil) + r := newClusterRouter([]string{"u0"}, 0) + r.pollInterval = 5 * time.Millisecond // per-instance, no shared-global race + ctx, cancel := context.WithCancel(context.Background()) + done := make(chan struct{}) + go func() { r.pollLoop(ctx); close(done) }() + time.Sleep(40 * time.Millisecond) // let several ticks fire + cancel() + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("pollLoop did not return after cancel") + } +} + +// pollLoop with a non-positive pollInterval falls back to the default — covers +// the interval<=0 guard. We cancel immediately after start so the default-60s +// ticker never actually fires (we only need the guard line executed). +func TestPollLoop_ZeroIntervalFallsBackToDefault(t *testing.T) { + fc := &fakePGConn{scanInt64: 1} + withPGXConnect(t, fc, nil) + r := newClusterRouter([]string{"u0"}, 0) + r.pollInterval = 0 // → guard sets interval = defaultClusterPollInterval + ctx, cancel := context.WithCancel(context.Background()) + cancel() // pre-cancel: after the immediate refresh + ticker setup, return + done := make(chan struct{}) + go func() { r.pollLoop(ctx); close(done) }() + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("pollLoop did not return") + } +} + +// pollLoop ctx.Done() return path: call pollLoop directly with an +// already-cancelled context and a fresh (never-closed) done channel so the +// select can only fire on ctx.Done(). Deterministic — no race with done. +func TestPollLoop_CtxDoneReturns(t *testing.T) { + fc := &fakePGConn{scanInt64: 1} + withPGXConnect(t, fc, nil) + r := newClusterRouter([]string{"u0"}, 0) + ctx, cancel := context.WithCancel(context.Background()) + cancel() // pre-cancel: after the immediate refresh, select hits ctx.Done() + done := make(chan struct{}) + go func() { r.pollLoop(ctx); close(done) }() + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("pollLoop did not return on ctx.Done()") + } +} + +// pollLoop done-channel return path: drive directly and signal exit via +// Shutdown (closes done), joining the goroutine so it can't leak. +func TestPollLoop_ShutdownReturns(t *testing.T) { + fc := &fakePGConn{scanInt64: 1} + withPGXConnect(t, fc, nil) + r := newClusterRouter([]string{"u0"}, 0) + done := make(chan struct{}) + go func() { r.pollLoop(context.Background()); close(done) }() + // Wait until the poller is up, then Shutdown (closes r.done) → return. + deadline := time.Now().Add(2 * time.Second) + for r.pollStarts.Load() == 0 && time.Now().Before(deadline) { + time.Sleep(time.Millisecond) + } + r.Shutdown() + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("pollLoop did not return after Shutdown") + } +} diff --git a/internal/backend/postgres/cluster_router.go b/internal/backend/postgres/cluster_router.go index b1a3055..2efc480 100644 --- a/internal/backend/postgres/cluster_router.go +++ b/internal/backend/postgres/cluster_router.go @@ -20,8 +20,6 @@ import ( "sync" "sync/atomic" "time" - - "github.com/jackc/pgx/v5" ) // ClusterRouter picks the admin DSN of the least-loaded shared Postgres cluster. @@ -48,8 +46,23 @@ type ClusterRouter struct { // regression test to assert exactly one poller runs across N Start calls. pollStarts atomic.Int32 done chan struct{} + // pollWG tracks the poll goroutine so Shutdown can block until it has + // fully returned. Without this, Shutdown only signals done and returns — + // the goroutine may still be mid-pgxConnect, so the polling connection + // (and any seam it reads) outlives Shutdown. Joining here makes Shutdown a + // true barrier: no router goroutine touches pgxConnect after it returns. + pollWG sync.WaitGroup + + // pollInterval is the refresh cadence. Defaults to defaultClusterPollInterval; + // a test sets it on its own router instance (no shared global) to exercise the + // ticker branch quickly without a 60s wait. + pollInterval time.Duration } +// defaultClusterPollInterval is the production cadence at which pollLoop +// refreshes per-cluster database counts. +const defaultClusterPollInterval = 60 * time.Second + // newClusterRouter creates a ClusterRouter for the given admin DSNs. // maxPerCluster sets the database capacity cap. Pass 0 to use the default (400). func newClusterRouter(adminURLs []string, maxPerCluster int) *ClusterRouter { @@ -61,11 +74,12 @@ func newClusterRouter(adminURLs []string, maxPerCluster int) *ClusterRouter { caps[i] = maxPerCluster } return &ClusterRouter{ - adminURLs: adminURLs, - maxDBs: caps, - counts: make([]int, len(adminURLs)), - inflight: make([]int, len(adminURLs)), - done: make(chan struct{}), + adminURLs: adminURLs, + maxDBs: caps, + counts: make([]int, len(adminURLs)), + inflight: make([]int, len(adminURLs)), + done: make(chan struct{}), + pollInterval: defaultClusterPollInterval, } } @@ -75,17 +89,26 @@ func newClusterRouter(adminURLs []string, maxPerCluster int) *ClusterRouter { // spawn a second poller. Call Shutdown to stop. func (r *ClusterRouter) Start(ctx context.Context) { r.startOnce.Do(func() { - go r.pollLoop(ctx) + r.pollWG.Add(1) + go func() { + defer r.pollWG.Done() + r.pollLoop(ctx) + }() }) } -// Shutdown stops the background polling goroutine. +// Shutdown stops the background polling goroutine and blocks until it has +// returned. Joining (rather than only signalling done) guarantees no poll +// connection is in flight once Shutdown returns — important both in prod (clean +// teardown) and in tests (a leaked poller would otherwise call pgxConnect after +// a test restores the seam, racing later tests). func (r *ClusterRouter) Shutdown() { select { case <-r.done: default: close(r.done) } + r.pollWG.Wait() } // Pick returns the index and admin DSN of the cluster with the most available @@ -198,7 +221,11 @@ func (r *ClusterRouter) pollLoop(ctx context.Context) { // many times Start is called — the once-guard regression test asserts it. r.pollStarts.Add(1) - ticker := time.NewTicker(60 * time.Second) + interval := r.pollInterval + if interval <= 0 { + interval = defaultClusterPollInterval + } + ticker := time.NewTicker(interval) defer ticker.Stop() // Immediate first poll so counts are populated before the first provision. @@ -248,7 +275,7 @@ func (r *ClusterRouter) dbCount(ctx context.Context, adminURL string) (int, erro ctx, cancel := context.WithTimeout(ctx, 5*time.Second) defer cancel() - conn, err := pgx.Connect(ctx, adminURL) + conn, err := pgxConnect(ctx, adminURL) if err != nil { return 0, fmt.Errorf("connect: %w", err) } diff --git a/internal/backend/postgres/cluster_router_startonce_test.go b/internal/backend/postgres/cluster_router_startonce_test.go index 4b43a5d..5257ff2 100644 --- a/internal/backend/postgres/cluster_router_startonce_test.go +++ b/internal/backend/postgres/cluster_router_startonce_test.go @@ -36,6 +36,9 @@ func TestStart_OnceGuard_SinglePoller(t *testing.T) { cancel() r := newClusterRouter([]string{"postgres://bogus@127.0.0.1:1/none"}, 0) + // Join the poller on exit so it cannot outlive this test and call the real + // pgxConnect after a later test installs a fake seam (Shutdown now Waits). + defer r.Shutdown() for i := 0; i < 5; i++ { r.Start(ctx) diff --git a/internal/backend/postgres/dedicated.go b/internal/backend/postgres/dedicated.go index 1ca02d8..e87b5d7 100644 --- a/internal/backend/postgres/dedicated.go +++ b/internal/backend/postgres/dedicated.go @@ -15,11 +15,8 @@ import ( "context" "encoding/json" "fmt" - "io" "log/slog" "net/http" - - "github.com/jackc/pgx/v5" ) const dedicatedNeonRegion = "aws-us-east-2" @@ -103,12 +100,12 @@ func (p *DedicatedProvider) provisionNeon(ctx context.Context, token, tier strin "region_id": dedicatedNeonRegion, }, } - bodyBytes, err := json.Marshal(body) + bodyBytes, err := jsonMarshal(body) if err != nil { return nil, fmt.Errorf("db.dedicated.provisionNeon: marshal: %w", err) } - req, err := http.NewRequestWithContext(ctx, http.MethodPost, + req, err := httpNewRequestWithContext(ctx, http.MethodPost, p.neonBaseURL+"/projects", bytes.NewReader(bodyBytes)) if err != nil { return nil, fmt.Errorf("db.dedicated.provisionNeon: new request: %w", err) @@ -122,7 +119,7 @@ func (p *DedicatedProvider) provisionNeon(ctx context.Context, token, tier strin } defer resp.Body.Close() - respBytes, err := io.ReadAll(resp.Body) + respBytes, err := ioReadAll(resp.Body) if err != nil { return nil, fmt.Errorf("db.dedicated.provisionNeon: read body: %w", err) } @@ -165,7 +162,7 @@ func (p *DedicatedProvider) neonStorageBytes(ctx context.Context, providerResour if providerResourceID == "" { return 0, fmt.Errorf("db.dedicated.neonStorageBytes: empty providerResourceID") } - req, err := http.NewRequestWithContext(ctx, http.MethodGet, + req, err := httpNewRequestWithContext(ctx, http.MethodGet, p.neonBaseURL+"/projects/"+providerResourceID, nil) if err != nil { return 0, fmt.Errorf("db.dedicated.neonStorageBytes: new request: %w", err) @@ -178,7 +175,7 @@ func (p *DedicatedProvider) neonStorageBytes(ctx context.Context, providerResour } defer resp.Body.Close() - respBytes, err := io.ReadAll(resp.Body) + respBytes, err := ioReadAll(resp.Body) if err != nil { return 0, fmt.Errorf("db.dedicated.neonStorageBytes: read body: %w", err) } @@ -203,7 +200,7 @@ func (p *DedicatedProvider) deprovisionNeon(ctx context.Context, token, provider if providerResourceID == "" { return fmt.Errorf("db.dedicated.deprovisionNeon: empty providerResourceID") } - req, err := http.NewRequestWithContext(ctx, http.MethodDelete, + req, err := httpNewRequestWithContext(ctx, http.MethodDelete, p.neonBaseURL+"/projects/"+providerResourceID, nil) if err != nil { return fmt.Errorf("db.dedicated.deprovisionNeon: new request: %w", err) @@ -217,7 +214,7 @@ func (p *DedicatedProvider) deprovisionNeon(ctx context.Context, token, provider defer resp.Body.Close() if resp.StatusCode < 200 || resp.StatusCode >= 300 { - body, _ := io.ReadAll(resp.Body) + body, _ := ioReadAll(resp.Body) return fmt.Errorf("db.dedicated.deprovisionNeon: status %d: %s", resp.StatusCode, string(body)) } slog.Info("db.dedicated.deprovisionNeon: deprovisioned", "token", token, "project_id", providerResourceID) @@ -246,7 +243,7 @@ func (p *DedicatedProvider) provisionLocal(ctx context.Context, token, tier stri } adminDSN := p.localAdminDSN() - conn, err := pgx.Connect(ctx, adminDSN) + conn, err := pgxConnect(ctx, adminDSN) if err != nil { return nil, fmt.Errorf("db.dedicated.provisionLocal: connect: %w", err) } @@ -267,7 +264,7 @@ func (p *DedicatedProvider) provisionLocal(ctx context.Context, token, tier stri } // Grant schema privileges on the new database. - adminNewDB, err := pgx.Connect(ctx, buildAdminNewDBURL(adminDSN, dbName)) + adminNewDB, err := pgxConnect(ctx, buildAdminNewDBURL(adminDSN, dbName)) if err != nil { slog.Error("db.dedicated.provisionLocal: connect new db for schema grant (non-fatal)", "error", err) } else { @@ -300,7 +297,7 @@ func (p *DedicatedProvider) provisionLocal(ctx context.Context, token, tier stri func (p *DedicatedProvider) localStorageBytes(ctx context.Context, token string) (int64, error) { adminDSN := p.localAdminDSN() - conn, err := pgx.Connect(ctx, adminDSN) + conn, err := pgxConnect(ctx, adminDSN) if err != nil { return 0, fmt.Errorf("db.dedicated.localStorageBytes: connect: %w", err) } @@ -323,7 +320,7 @@ func (p *DedicatedProvider) deprovisionLocal(ctx context.Context, token string) username := "dedicated_usr_" + token adminDSN := p.localAdminDSN() - conn, err := pgx.Connect(ctx, adminDSN) + conn, err := pgxConnect(ctx, adminDSN) if err != nil { return fmt.Errorf("db.dedicated.deprovisionLocal: connect: %w", err) } diff --git a/internal/backend/postgres/dedicated_seam_test.go b/internal/backend/postgres/dedicated_seam_test.go new file mode 100644 index 0000000..ed5b92e --- /dev/null +++ b/internal/backend/postgres/dedicated_seam_test.go @@ -0,0 +1,403 @@ +package postgres + +// dedicated_seam_test.go — coverage for dedicated.go: the Neon-API path +// (provision/storage/deprovision + every error wrap) and the local-admin path +// (driven via the pgConn seam). + +import ( + "context" + "encoding/json" + "io" + "math/big" + "net/http" + "net/http/httptest" + "testing" +) + +func dedicatedNeonProvider(t *testing.T, h http.HandlerFunc) *DedicatedProvider { + t.Helper() + srv := httptest.NewServer(h) + t.Cleanup(srv.Close) + return &DedicatedProvider{neonAPIKey: "key", neonBaseURL: srv.URL, httpClient: srv.Client()} +} + +func TestNewDedicatedProvider_Defaults(t *testing.T) { + p := NewDedicatedProvider("dsn", "key") + if p.adminDSN != "dsn" || p.neonAPIKey != "key" || p.neonBaseURL != neonAPIBase || p.httpClient == nil { + t.Errorf("provider = %+v", p) + } +} + +func TestDedicated_Provision_NeonSuccess(t *testing.T) { + p := dedicatedNeonProvider(t, func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusCreated) + _ = json.NewEncoder(w).Encode(map[string]any{ + "project": map[string]string{"id": "dp1"}, + "connection_uris": []map[string]string{{"connection_uri": "postgres://d"}}, + }) + }) + creds, err := p.Provision(context.Background(), "this-is-a-very-long-token-1234567890", "team", -1) + if err != nil { + t.Fatalf("Provision: %v", err) + } + if creds.ProviderResourceID != "dp1" { + t.Errorf("PRID = %q", creds.ProviderResourceID) + } +} + +func TestDedicated_ProvisionNeon_MarshalError(t *testing.T) { + p := dedicatedNeonProvider(t, func(w http.ResponseWriter, r *http.Request) {}) + orig := jsonMarshal + jsonMarshal = func(any) ([]byte, error) { return nil, errSeam } + t.Cleanup(func() { jsonMarshal = orig }) + if _, err := p.Provision(context.Background(), "t", "team", -1); err == nil { + t.Error("expected marshal error") + } +} + +func TestDedicated_ProvisionNeon_NewRequestError(t *testing.T) { + p := dedicatedNeonProvider(t, func(w http.ResponseWriter, r *http.Request) {}) + orig := httpNewRequestWithContext + httpNewRequestWithContext = func(context.Context, string, string, io.Reader) (*http.Request, error) { + return nil, errSeam + } + t.Cleanup(func() { httpNewRequestWithContext = orig }) + if _, err := p.Provision(context.Background(), "t", "team", -1); err == nil { + t.Error("expected new-request error") + } +} + +func TestDedicated_ProvisionNeon_HTTPError(t *testing.T) { + p := &DedicatedProvider{neonAPIKey: "k", neonBaseURL: "http://127.0.0.1:1", httpClient: &http.Client{}} + if _, err := p.Provision(context.Background(), "t", "team", -1); err == nil { + t.Error("expected http error") + } +} + +func TestDedicated_ProvisionNeon_ReadBodyError(t *testing.T) { + p := dedicatedNeonProvider(t, func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusCreated) + _, _ = w.Write([]byte("{}")) + }) + orig := ioReadAll + ioReadAll = func(io.Reader) ([]byte, error) { return nil, errSeam } + t.Cleanup(func() { ioReadAll = orig }) + if _, err := p.Provision(context.Background(), "t", "team", -1); err == nil { + t.Error("expected read-body error") + } +} + +func TestDedicated_ProvisionNeon_Non2xx(t *testing.T) { + p := dedicatedNeonProvider(t, func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "boom", http.StatusInternalServerError) + }) + if _, err := p.Provision(context.Background(), "t", "team", -1); err == nil { + t.Error("expected non-2xx error") + } +} + +func TestDedicated_ProvisionNeon_UnmarshalError(t *testing.T) { + p := dedicatedNeonProvider(t, func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusCreated) + _, _ = w.Write([]byte("xx")) + }) + if _, err := p.Provision(context.Background(), "t", "team", -1); err == nil { + t.Error("expected unmarshal error") + } +} + +func TestDedicated_ProvisionNeon_EmptyProjectID(t *testing.T) { + p := dedicatedNeonProvider(t, func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusCreated) + _ = json.NewEncoder(w).Encode(map[string]any{"project": map[string]string{"id": ""}}) + }) + if _, err := p.Provision(context.Background(), "t", "team", -1); err == nil { + t.Error("expected empty-id error") + } +} + +func TestDedicated_ProvisionNeon_NoConnectionURI(t *testing.T) { + p := dedicatedNeonProvider(t, func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusCreated) + _ = json.NewEncoder(w).Encode(map[string]any{"project": map[string]string{"id": "x"}}) + }) + if _, err := p.Provision(context.Background(), "t", "team", -1); err == nil { + t.Error("expected no-uri error") + } +} + +func TestDedicated_StorageBytes_Neon(t *testing.T) { + p := dedicatedNeonProvider(t, func(w http.ResponseWriter, r *http.Request) { + _ = json.NewEncoder(w).Encode(map[string]any{ + "project": map[string]any{"usage": map[string]any{"data_storage_bytes_hour": 7}}, + }) + }) + got, err := p.StorageBytes(context.Background(), "tok", "p1") + if err != nil || got != 7 { + t.Errorf("StorageBytes = %d, %v", got, err) + } +} + +func TestDedicated_neonStorageBytes_ErrorBranches(t *testing.T) { + // empty PRID + p := &DedicatedProvider{neonAPIKey: "k", httpClient: &http.Client{}, neonBaseURL: "http://x"} + if _, err := p.neonStorageBytes(context.Background(), ""); err == nil { + t.Error("expected empty-PRID error") + } + + // new-request error + orig := httpNewRequestWithContext + httpNewRequestWithContext = func(context.Context, string, string, io.Reader) (*http.Request, error) { + return nil, errSeam + } + if _, err := p.neonStorageBytes(context.Background(), "p1"); err == nil { + t.Error("expected new-request error") + } + httpNewRequestWithContext = orig + + // http error + p2 := &DedicatedProvider{neonAPIKey: "k", httpClient: &http.Client{}, neonBaseURL: "http://127.0.0.1:1"} + if _, err := p2.neonStorageBytes(context.Background(), "p1"); err == nil { + t.Error("expected http error") + } + + // non-2xx + p3 := dedicatedNeonProvider(t, func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "no", http.StatusNotFound) + }) + if _, err := p3.neonStorageBytes(context.Background(), "p1"); err == nil { + t.Error("expected non-2xx error") + } + + // read-body error + p4 := dedicatedNeonProvider(t, func(w http.ResponseWriter, r *http.Request) { _, _ = w.Write([]byte("{}")) }) + io2 := ioReadAll + ioReadAll = func(io.Reader) ([]byte, error) { return nil, errSeam } + if _, err := p4.neonStorageBytes(context.Background(), "p1"); err == nil { + t.Error("expected read-body error") + } + ioReadAll = io2 + + // unmarshal error + p5 := dedicatedNeonProvider(t, func(w http.ResponseWriter, r *http.Request) { _, _ = w.Write([]byte("xx")) }) + if _, err := p5.neonStorageBytes(context.Background(), "p1"); err == nil { + t.Error("expected unmarshal error") + } +} + +func TestDedicated_Deprovision_Neon(t *testing.T) { + p := dedicatedNeonProvider(t, func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) }) + if err := p.Deprovision(context.Background(), "tok", "p1"); err != nil { + t.Fatalf("Deprovision: %v", err) + } +} + +func TestDedicated_deprovisionNeon_ErrorBranches(t *testing.T) { + p := &DedicatedProvider{neonAPIKey: "k", httpClient: &http.Client{}, neonBaseURL: "http://x"} + if err := p.deprovisionNeon(context.Background(), "tok", ""); err == nil { + t.Error("expected empty-PRID error") + } + + orig := httpNewRequestWithContext + httpNewRequestWithContext = func(context.Context, string, string, io.Reader) (*http.Request, error) { + return nil, errSeam + } + if err := p.deprovisionNeon(context.Background(), "tok", "p1"); err == nil { + t.Error("expected new-request error") + } + httpNewRequestWithContext = orig + + p2 := &DedicatedProvider{neonAPIKey: "k", httpClient: &http.Client{}, neonBaseURL: "http://127.0.0.1:1"} + if err := p2.deprovisionNeon(context.Background(), "tok", "p1"); err == nil { + t.Error("expected http error") + } + + p3 := dedicatedNeonProvider(t, func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "no", http.StatusBadRequest) + }) + if err := p3.deprovisionNeon(context.Background(), "tok", "p1"); err == nil { + t.Error("expected non-2xx error") + } +} + +func TestDedicated_Regrade_NoOp(t *testing.T) { + p := &DedicatedProvider{} + res, err := p.Regrade(context.Background(), "tok", "p1", 5) + if err != nil || res.Applied { + t.Errorf("Regrade = %+v, %v", res, err) + } +} + +// --- local-admin path (neonAPIKey == "") --- + +func TestDedicated_localAdminDSN_Fallback(t *testing.T) { + p := &DedicatedProvider{} + if p.localAdminDSN() != defaultCustomersURL { + t.Errorf("localAdminDSN = %q; want default", p.localAdminDSN()) + } + p2 := &DedicatedProvider{adminDSN: "postgres://a@h/x"} + if p2.localAdminDSN() != "postgres://a@h/x" { + t.Errorf("localAdminDSN = %q", p2.localAdminDSN()) + } +} + +func TestDedicated_ProvisionLocal_Success(t *testing.T) { + fc := &fakePGConn{} + withPGXConnect(t, fc, nil) + p := &DedicatedProvider{adminDSN: "postgres://a:b@h:5432/postgres"} + creds, err := p.Provision(context.Background(), "tok", "team", -1) + if err != nil { + t.Fatalf("Provision local: %v", err) + } + if creds.DatabaseName != "dedicated_db_tok" || creds.Username != "dedicated_usr_tok" { + t.Errorf("creds = %+v", creds) + } +} + +func TestDedicated_ProvisionLocal_GenPasswordError(t *testing.T) { + origRI := randInt + randInt = func(_ io.Reader, _ *big.Int) (*big.Int, error) { return nil, errSeam } + t.Cleanup(func() { randInt = origRI }) + p := &DedicatedProvider{adminDSN: "x"} + if _, err := p.Provision(context.Background(), "tok", "team", -1); err == nil { + t.Error("expected generatePassword error") + } +} + +func TestDedicated_ProvisionLocal_CloseErrors(t *testing.T) { + // closeErr on both the admin conn and the new-db conn → both deferred-close + // error-log branches run; Provision still succeeds. + fc := &fakePGConn{closeErr: errSeam} + withPGXConnect(t, fc, nil) + p := &DedicatedProvider{adminDSN: "x"} + if _, err := p.Provision(context.Background(), "tok", "team", -1); err != nil { + t.Errorf("Close errors must be non-fatal: %v", err) + } +} + +func TestDedicated_localStorageBytes_CloseError(t *testing.T) { + fc := &fakePGConn{scanInt64: 1, closeErr: errSeam} + withPGXConnect(t, fc, nil) + p := &DedicatedProvider{adminDSN: "x"} + if _, err := p.StorageBytes(context.Background(), "tok", ""); err != nil { + t.Errorf("Close error must be non-fatal: %v", err) + } +} + +func TestDedicated_DeprovisionLocal_CloseError(t *testing.T) { + fc := &fakePGConn{closeErr: errSeam} + withPGXConnect(t, fc, nil) + p := &DedicatedProvider{adminDSN: "x"} + if err := p.Deprovision(context.Background(), "tok", ""); err != nil { + t.Errorf("Close error must be non-fatal: %v", err) + } +} + +func TestDedicated_ProvisionLocal_ConnectError(t *testing.T) { + withPGXConnect(t, nil, errSeam) + p := &DedicatedProvider{adminDSN: "x"} + if _, err := p.Provision(context.Background(), "tok", "team", -1); err == nil { + t.Error("expected connect error") + } +} + +func TestDedicated_ProvisionLocal_ExecErrorBranches(t *testing.T) { + for _, sub := range []string{"CREATE DATABASE", "CREATE USER", "GRANT ALL PRIVILEGES ON DATABASE"} { + fc := &fakePGConn{execErrFor: map[string]error{sub: errSeam}} + withPGXConnect(t, fc, nil) + p := &DedicatedProvider{adminDSN: "x"} + if _, err := p.Provision(context.Background(), "tok", "team", -1); err == nil { + t.Errorf("expected error when %q fails", sub) + } + } +} + +func TestDedicated_ProvisionLocal_SchemaGrantNonFatal(t *testing.T) { + fc := &fakePGConn{execErrFor: map[string]error{"GRANT ALL ON SCHEMA": errSeam}} + withPGXConnect(t, fc, nil) + p := &DedicatedProvider{adminDSN: "x"} + if _, err := p.Provision(context.Background(), "tok", "team", -1); err != nil { + t.Errorf("schema-grant failure must be non-fatal: %v", err) + } +} + +func TestDedicated_ProvisionLocal_NewDBConnectError_NonFatal(t *testing.T) { + var calls int + fc := &fakePGConn{} + withPGXConnectFunc(t, func(ctx context.Context, dsn string) (pgConn, error) { + calls++ + if calls == 2 { + return nil, errSeam + } + return fc, nil + }) + p := &DedicatedProvider{adminDSN: "x"} + if _, err := p.Provision(context.Background(), "tok", "team", -1); err != nil { + t.Errorf("new-db connect failure must be non-fatal: %v", err) + } +} + +func TestDedicated_localStorageBytes(t *testing.T) { + fc := &fakePGConn{scanInt64: 555} + withPGXConnect(t, fc, nil) + p := &DedicatedProvider{adminDSN: "x"} + got, err := p.StorageBytes(context.Background(), "tok", "") + if err != nil || got != 555 { + t.Errorf("StorageBytes = %d, %v", got, err) + } +} + +func TestDedicated_localStorageBytes_ConnectError(t *testing.T) { + withPGXConnect(t, nil, errSeam) + p := &DedicatedProvider{adminDSN: "x"} + if _, err := p.StorageBytes(context.Background(), "tok", ""); err == nil { + t.Error("expected connect error") + } +} + +func TestDedicated_localStorageBytes_ScanError(t *testing.T) { + fc := &fakePGConn{queryRowErr: errSeam} + withPGXConnect(t, fc, nil) + p := &DedicatedProvider{adminDSN: "x"} + if _, err := p.StorageBytes(context.Background(), "tok", ""); err == nil { + t.Error("expected scan error") + } +} + +func TestDedicated_DeprovisionLocal_Success(t *testing.T) { + fc := &fakePGConn{} + withPGXConnect(t, fc, nil) + p := &DedicatedProvider{adminDSN: "x"} + if err := p.Deprovision(context.Background(), "tok", ""); err != nil { + t.Fatalf("Deprovision local: %v", err) + } +} + +func TestDedicated_DeprovisionLocal_ConnectError(t *testing.T) { + withPGXConnect(t, nil, errSeam) + p := &DedicatedProvider{adminDSN: "x"} + if err := p.Deprovision(context.Background(), "tok", ""); err == nil { + t.Error("expected connect error") + } +} + +func TestDedicated_DeprovisionLocal_TerminateAndDropUserNonFatal(t *testing.T) { + fc := &fakePGConn{execErrFor: map[string]error{ + "pg_terminate_backend": errSeam, + "DROP USER": errSeam, + }} + withPGXConnect(t, fc, nil) + p := &DedicatedProvider{adminDSN: "x"} + if err := p.Deprovision(context.Background(), "tok", ""); err != nil { + t.Errorf("terminate/drop-user failures must be non-fatal: %v", err) + } +} + +func TestDedicated_DeprovisionLocal_DropDBError(t *testing.T) { + fc := &fakePGConn{execErrFor: map[string]error{"DROP DATABASE": errSeam}} + withPGXConnect(t, fc, nil) + p := &DedicatedProvider{adminDSN: "x"} + if err := p.Deprovision(context.Background(), "tok", ""); err == nil { + t.Error("expected DROP DATABASE error") + } +} diff --git a/internal/backend/postgres/k8s.go b/internal/backend/postgres/k8s.go index caf50ae..57d791e 100644 --- a/internal/backend/postgres/k8s.go +++ b/internal/backend/postgres/k8s.go @@ -40,14 +40,12 @@ import ( goredis "github.com/redis/go-redis/v9" "context" - "crypto/rand" "encoding/hex" "fmt" "log/slog" "strings" "time" - "github.com/jackc/pgx/v5" appsv1 "k8s.io/api/apps/v1" corev1 "k8s.io/api/core/v1" networkingv1 "k8s.io/api/networking/v1" @@ -64,11 +62,9 @@ import ( ) const ( - k8sNsPrefix = "instant-customer-" - k8sRoleLabel = "instant.dev/role" - k8sRoleValue = "customer-resource" - k8sReadyTimeout = 3 * time.Minute - k8sReadyInterval = 3 * time.Second + k8sNsPrefix = "instant-customer-" + k8sRoleLabel = "instant.dev/role" + k8sRoleValue = "customer-resource" // k8sOwnerTeamLabel is applied to dedicated customer namespaces to record // the owning team UUID. The deploy-side NetworkPolicy in the api repo @@ -77,6 +73,19 @@ const ( k8sOwnerTeamLabel = "instant.dev/owner-team" ) +// k8sReadyTimeout / k8sReadyInterval bound waitPodReady. They are package vars +// (not consts) only so tests can shrink them to milliseconds to reach the +// timeout branch without a 3-minute wait. Production values are unchanged. +var ( + k8sReadyTimeout = 3 * time.Minute + k8sReadyInterval = 3 * time.Second + + // k8sNsTerminateTimeout / k8sNsTerminatePoll bound the Terminating-namespace + // wait loop in applyNamespace. Package vars for the same test-shrink reason. + k8sNsTerminateTimeout = 2 * time.Minute + k8sNsTerminatePoll = 3 * time.Second +) + // tierSizing maps a billing tier to k8s resource sizing for the provisioned pod. // Anonymous (24h trial) gets the smallest viable pod — still a real, dedicated // Postgres, just configured for low cost so the free tier scales. Each step up @@ -108,7 +117,7 @@ func sizingForTier(tier string) tierSizing { pvcGi: 0, qCPURequests: "100m", qMemRequests: "256Mi", qCPULimits: "500m", qMemLimits: "512Mi", - connLimit: 2, + connLimit: 2, } case "hobby": return tierSizing{ @@ -117,7 +126,7 @@ func sizingForTier(tier string) tierSizing { pvcGi: 5, qCPURequests: "200m", qMemRequests: "512Mi", qCPULimits: "1", qMemLimits: "2Gi", - connLimit: 5, + connLimit: 5, } case "pro": return tierSizing{ @@ -126,7 +135,7 @@ func sizingForTier(tier string) tierSizing { pvcGi: 50, qCPURequests: "500m", qMemRequests: "2Gi", qCPULimits: "4", qMemLimits: "8Gi", - connLimit: 20, + connLimit: 20, } case "team", "growth": return tierSizing{ @@ -135,7 +144,7 @@ func sizingForTier(tier string) tierSizing { pvcGi: 200, qCPURequests: "1", qMemRequests: "4Gi", qCPULimits: "8", qMemLimits: "16Gi", - connLimit: -1, // unlimited; capped only by pod max_connections + connLimit: -1, // unlimited; capped only by pod max_connections } default: // Unknown tier → conservative hobby-equivalent sizing rather than fail-open. @@ -147,10 +156,10 @@ func sizingForTier(tier string) tierSizing { // All configuration comes from environment variables — see config.go for the full list. type K8sBackend struct { cs kubernetes.Interface // kubernetes.Interface allows fake.Clientset in tests - storageClass string // K8S_STORAGE_CLASS: "gp3" (EKS) or "local-path" (dev) - image string // K8S_POSTGRES_IMAGE: "pgvector/pgvector:pg16" (default) - externalHost string // K8S_EXTERNAL_HOST: node IP, LB DNS, or proxy hostname - storageSizeGi int // K8S_POSTGRES_STORAGE_GI: default 50 + storageClass string // K8S_STORAGE_CLASS: "gp3" (EKS) or "local-path" (dev) + image string // K8S_POSTGRES_IMAGE: "pgvector/pgvector:pg16" (default) + externalHost string // K8S_EXTERNAL_HOST: node IP, LB DNS, or proxy hostname + storageSizeGi int // K8S_POSTGRES_STORAGE_GI: default 50 // Route registration for the pg-proxy. When rdb is set, Provision writes // `` → `:5432` so the proxy can route // new client connections to this pod. Deprovision deletes the key. @@ -348,7 +357,7 @@ func (b *K8sBackend) StorageBytes(ctx context.Context, token, providerResourceID adminPass := string(secret.Data["POSTGRES_PASSWORD"]) dsn := fmt.Sprintf("postgres://%s:%s@%s:5432/postgres?sslmode=disable", adminUser, adminPass, svc.Spec.ClusterIP) - conn, err := pgx.Connect(ctx, dsn) + conn, err := pgxConnect(ctx, dsn) if err != nil { return 0, fmt.Errorf("k8s postgres.StorageBytes: connect: %w", err) } @@ -447,7 +456,7 @@ func (b *K8sBackend) Regrade(ctx context.Context, token, providerResourceID stri adminPass := string(secret.Data["POSTGRES_PASSWORD"]) dsn := fmt.Sprintf("postgres://%s:%s@%s:5432/postgres?sslmode=disable", adminUser, adminPass, svc.Spec.ClusterIP) - conn, err := pgx.Connect(ctx, dsn) + conn, err := pgxConnect(ctx, dsn) if err != nil { return RegradeResult{Applied: false, SkipReason: fmt.Sprintf("resource not reachable: connect: %v", err)}, nil } @@ -518,12 +527,12 @@ func (b *K8sBackend) applyNamespace(ctx context.Context, ns string) error { if getErr != nil || existing.Status.Phase != corev1.NamespaceTerminating { return err // not terminating — surface the original AlreadyExists error } - deadline := time.Now().Add(2 * time.Minute) + deadline := time.Now().Add(k8sNsTerminateTimeout) for time.Now().Before(deadline) { select { case <-ctx.Done(): return ctx.Err() - case <-time.After(3 * time.Second): + case <-time.After(k8sNsTerminatePoll): } _, getErr = b.cs.CoreV1().Namespaces().Get(ctx, ns, metav1.GetOptions{}) if k8serrors.IsNotFound(getErr) { @@ -531,7 +540,7 @@ func (b *K8sBackend) applyNamespace(ctx context.Context, ns string) error { return err } } - return fmt.Errorf("namespace %s still terminating after 2 minutes", ns) + return fmt.Errorf("namespace %s still terminating after %s", ns, k8sNsTerminateTimeout) } // applyNetworkPolicy creates a deny-all policy with targeted allow rules. @@ -732,7 +741,7 @@ func (b *K8sBackend) waitPodReady(ctx context.Context, ns, labelSelector string) } func (b *K8sBackend) initDatabase(ctx context.Context, adminDSN, dbName, appUser, appPass string, connLimit int) error { - conn, err := pgx.Connect(ctx, adminDSN) + conn, err := pgxConnect(ctx, adminDSN) if err != nil { return fmt.Errorf("connect: %w", err) } @@ -759,7 +768,7 @@ func (b *K8sBackend) initDatabase(ctx context.Context, adminDSN, dbName, appUser // // adminDSN connects to the "postgres" DB; replace it with dbName for this step. dbDSN := strings.Replace(adminDSN, "/postgres?", "/"+dbName+"?", 1) - if dbConn, dbErr := pgx.Connect(ctx, dbDSN); dbErr == nil { + if dbConn, dbErr := pgxConnect(ctx, dbDSN); dbErr == nil { _, _ = dbConn.Exec(ctx, `CREATE EXTENSION IF NOT EXISTS vector`) _, _ = dbConn.Exec(ctx, fmt.Sprintf(`ALTER EXTENSION vector OWNER TO %q`, appUser)) dbConn.Close(ctx) @@ -772,7 +781,7 @@ func (b *K8sBackend) initDatabase(ctx context.Context, adminDSN, dbName, appUser // k8sRandHex returns a cryptographically random hex string of length n*2. func k8sRandHex(n int) (string, error) { b := make([]byte, n) - if _, err := rand.Read(b); err != nil { + if _, err := randRead(b); err != nil { return "", err } return hex.EncodeToString(b), nil diff --git a/internal/backend/postgres/k8s_seam_test.go b/internal/backend/postgres/k8s_seam_test.go new file mode 100644 index 0000000..c919804 --- /dev/null +++ b/internal/backend/postgres/k8s_seam_test.go @@ -0,0 +1,788 @@ +package postgres + +// k8s_seam_test.go — seam-driven coverage for k8s.go. Uses fake.Clientset for +// the apiserver, the pgConn seam for the in-pod SQL, shrunk poll/timeout vars, +// and a preloaded Ready pod so the full Provision happy path runs in ms. + +import ( + "context" + "os" + "testing" + "time" + + goredis "github.com/redis/go-redis/v9" + corev1 "k8s.io/api/core/v1" + apierrors "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/runtime/schema" + "k8s.io/client-go/kubernetes/fake" + k8stesting "k8s.io/client-go/testing" + + "instant.dev/provisioner/internal/ctxkeys" +) + +var nsGroupResource = schema.GroupResource{Group: "", Resource: "namespaces"} + +func k8sNotFound(name string) error { return apierrors.NewNotFound(nsGroupResource, name) } +func k8sAlreadyExists(name string) error { return apierrors.NewAlreadyExists(nsGroupResource, name) } + +// shrinkK8sTimers shrinks the pod-ready and namespace-terminate timers so +// timeout branches are reached in milliseconds. +func shrinkK8sTimers(t *testing.T) { + t.Helper() + rt, ri := k8sReadyTimeout, k8sReadyInterval + nt, np := k8sNsTerminateTimeout, k8sNsTerminatePoll + k8sReadyTimeout, k8sReadyInterval = 80*time.Millisecond, 5*time.Millisecond + k8sNsTerminateTimeout, k8sNsTerminatePoll = 80*time.Millisecond, 5*time.Millisecond + t.Cleanup(func() { + k8sReadyTimeout, k8sReadyInterval = rt, ri + k8sNsTerminateTimeout, k8sNsTerminatePoll = nt, np + }) +} + +// readyPod returns a pod in the given namespace with the app=postgres label and +// a PodReady=True condition. +func readyPostgresPod(ns string) *corev1.Pod { + return &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{Name: "postgres-1", Namespace: ns, Labels: map[string]string{"app": "postgres"}}, + Status: corev1.PodStatus{ + Conditions: []corev1.PodCondition{{Type: corev1.PodReady, Status: corev1.ConditionTrue}}, + }, + } +} + +func TestSizingForTier_AllTiers(t *testing.T) { + for _, tier := range []string{"anonymous", "hobby", "pro", "team", "growth", "unknown-x"} { + sz := sizingForTier(tier) + if sz.qCPURequests == "" { + t.Errorf("tier %q: empty sizing", tier) + } + } +} + +func TestK8sRandHex_Success(t *testing.T) { + s, err := k8sRandHex(16) + if err != nil || len(s) != 32 { + t.Errorf("k8sRandHex = %q, %v", s, err) + } +} + +func TestK8sRandHex_RandError(t *testing.T) { + orig := randRead + randRead = func([]byte) (int, error) { return 0, errSeam } + t.Cleanup(func() { randRead = orig }) + if _, err := k8sRandHex(16); err == nil { + t.Error("expected rand error") + } +} + +func TestBoolPtr_Seam(t *testing.T) { + if *boolPtr(true) != true || *boolPtr(false) != false { + t.Error("boolPtr broken") + } +} + +func TestMin_Seam(t *testing.T) { + if min(2, 5) != 2 || min(5, 2) != 2 { + t.Error("min broken") + } +} + +func TestPgDataVolumeSource(t *testing.T) { + if pgDataVolumeSource(tierSizing{pvcGi: 10}).PersistentVolumeClaim == nil { + t.Error("pvc>0 should use PVC source") + } + if pgDataVolumeSource(tierSizing{pvcGi: 0}).EmptyDir == nil { + t.Error("pvc==0 should use emptyDir") + } +} + +func TestNewK8sDedicatedBackend_BadKubeconfig(t *testing.T) { + if _, err := NewK8sDedicatedBackend("/nonexistent/kubeconfig", "", "", "", 0); err == nil { + t.Error("expected build-config error for missing kubeconfig") + } +} + +// minimalKubeconfig writes a syntactically valid kubeconfig pointing at a dummy +// API host. BuildConfigFromFlags + NewForConfig both succeed without contacting +// the server, so newK8sBackend's success + default-fill branches are exercised. +const minimalKubeconfig = `apiVersion: v1 +kind: Config +clusters: +- cluster: + server: https://127.0.0.1:6443 + name: test +contexts: +- context: + cluster: test + user: test + name: test +current-context: test +users: +- name: test + user: + token: abc +` + +func TestNewK8sBackend_Success_FillsDefaults(t *testing.T) { + dir := t.TempDir() + kc := dir + "/kubeconfig" + if err := os.WriteFile(kc, []byte(minimalKubeconfig), 0o600); err != nil { + t.Fatalf("write kubeconfig: %v", err) + } + // Empty storageClass/image + storageSizeGi<=0 → all default-fill branches run. + b, err := newK8sBackend(kc, "", "", "ext.host", 0) + if err != nil { + t.Fatalf("newK8sBackend: %v", err) + } + if b.storageClass != "gp3" { + t.Errorf("storageClass default = %q; want gp3", b.storageClass) + } + if b.image != "pgvector/pgvector:pg16" { + t.Errorf("image default = %q", b.image) + } + if b.storageSizeGi != 50 { + t.Errorf("storageSizeGi default = %d; want 50", b.storageSizeGi) + } +} + +// A kubeconfig with an unknown auth-provider parses fine (BuildConfigFromFlags +// succeeds) but kubernetes.NewForConfig rejects it — covers the NewForConfig +// error branch in newK8sBackend. +const badAuthKubeconfig = `apiVersion: v1 +kind: Config +clusters: +- cluster: + server: https://127.0.0.1:6443 + name: test +contexts: +- context: + cluster: test + user: test + name: test +current-context: test +users: +- name: test + user: + auth-provider: + name: this-provider-does-not-exist +` + +func TestNewK8sBackend_NewForConfigError(t *testing.T) { + dir := t.TempDir() + kc := dir + "/kubeconfig" + if err := os.WriteFile(kc, []byte(badAuthKubeconfig), 0o600); err != nil { + t.Fatalf("write kubeconfig: %v", err) + } + if _, err := newK8sBackend(kc, "", "", "", 0); err == nil { + t.Error("expected NewForConfig error for unknown auth-provider") + } +} + +func TestNewK8sBackend_InClusterConfigError(t *testing.T) { + // Empty kubeconfig → rest.InClusterConfig is used. Outside a pod it fails. + // Skip if the test happens to run inside a real pod (in-cluster config valid). + if os.Getenv("KUBERNETES_SERVICE_HOST") != "" { + t.Skip("running inside a k8s pod — in-cluster config would succeed") + } + if _, err := newK8sBackend("", "", "", "", 0); err == nil { + t.Error("expected in-cluster config error outside a pod") + } +} + +func TestK8sBackend_EnableRouteRegistry(t *testing.T) { + b := &K8sBackend{} + b.EnableRouteRegistry(&goredis.Client{}, "") + if b.routePrefix != "pg_route:" { + t.Errorf("default prefix = %q", b.routePrefix) + } + b.EnableRouteRegistry(&goredis.Client{}, "custom:") + if b.routePrefix != "custom:" { + t.Errorf("prefix = %q", b.routePrefix) + } +} + +// Provision happy path with PVC (hobby tier) — exercises every apply* helper, +// waitPodReady success, and initDatabase via the pgConn seam. +func TestK8sBackend_Provision_HappyPath_WithPVC(t *testing.T) { + shrinkK8sTimers(t) + const token = "happytoken" + ns := k8sNsPrefix + token + cs := fake.NewClientset(readyPostgresPod(ns)) + // Stamp a ClusterIP on the Service when it's created (fake doesn't auto-assign). + cs.PrependReactor("create", "services", func(a k8stesting.Action) (bool, runtime.Object, error) { + svc := a.(k8stesting.CreateAction).GetObject().(*corev1.Service) + svc.Spec.ClusterIP = "10.0.0.5" + if len(svc.Spec.Ports) > 0 { + svc.Spec.Ports[0].NodePort = 30111 + } + return false, nil, nil // let the default reactor store it + }) + fc := &fakePGConn{} + withPGXConnect(t, fc, nil) + + b := &K8sBackend{cs: cs, storageClass: "local-path", image: "pgvector/pgvector:pg16", externalHost: "127.0.0.1", storageSizeGi: 50} + creds, err := b.Provision(context.Background(), token, "hobby", 5) + if err != nil { + t.Fatalf("Provision: %v", err) + } + if creds.DatabaseName != k8sDBName(token) || creds.ProviderResourceID != ns { + t.Errorf("creds = %+v", creds) + } +} + +// Provision anonymous tier (pvcGi==0 → no PVC, emptyDir) + route registry on. +func TestK8sBackend_Provision_Anonymous_RouteRegistry(t *testing.T) { + shrinkK8sTimers(t) + const token = "anontoken" + ns := k8sNsPrefix + token + cs := fake.NewClientset(readyPostgresPod(ns)) + cs.PrependReactor("create", "services", func(a k8stesting.Action) (bool, runtime.Object, error) { + svc := a.(k8stesting.CreateAction).GetObject().(*corev1.Service) + svc.Spec.ClusterIP = "10.0.0.6" + if len(svc.Spec.Ports) > 0 { + svc.Spec.Ports[0].NodePort = 30222 + } + return false, nil, nil + }) + fc := &fakePGConn{} + withPGXConnect(t, fc, nil) + + // rdb pointed at a closed port: Set will fail → exercises the route-register + // failure (non-fatal) branch without a real Redis. + rdb := goredis.NewClient(&goredis.Options{Addr: "127.0.0.1:1"}) + b := &K8sBackend{cs: cs, storageClass: "local-path", externalHost: "proxy.host", rdb: rdb, routePrefix: "pg_route:"} + ctx := context.WithValue(context.Background(), ctxkeys.TeamIDKey, "team-xyz") + creds, err := b.Provision(ctx, token, "anonymous", 2) + if err != nil { + t.Fatalf("Provision: %v", err) + } + if creds.URL == "" { + t.Error("empty URL") + } +} + +func TestK8sBackend_Provision_RandError(t *testing.T) { + orig := randRead + randRead = func([]byte) (int, error) { return 0, errSeam } + t.Cleanup(func() { randRead = orig }) + b := &K8sBackend{cs: fake.NewClientset()} + if _, err := b.Provision(context.Background(), "t", "hobby", 5); err == nil { + t.Error("expected rand error from admin pass") + } +} + +// Provision app-pass rand error: first k8sRandHex (admin) succeeds, second +// (app) fails — covers the second rand-error branch (k8s.go:236). +func TestK8sBackend_Provision_AppPassRandError(t *testing.T) { + orig := randRead + var n int + randRead = func(b []byte) (int, error) { + n++ + if n >= 2 { + return 0, errSeam + } + return orig(b) + } + t.Cleanup(func() { randRead = orig }) + b := &K8sBackend{cs: fake.NewClientset()} + if _, err := b.Provision(context.Background(), "t", "hobby", 5); err == nil { + t.Error("expected rand error from app pass") + } +} + +// Each apply* step failing must trigger the rollback path (namespace delete + +// wrapped error). Driven via PrependReactor on the matching resource create. +func TestK8sBackend_Provision_RollbackBranches(t *testing.T) { + cases := []struct { + name string + resource string + }{ + {"network_policy", "networkpolicies"}, + {"resource_quota", "resourcequotas"}, + {"admin_secret", "secrets"}, + {"pvc", "persistentvolumeclaims"}, + {"deployment", "deployments"}, + {"service", "services"}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + shrinkK8sTimers(t) + cs := fake.NewClientset() + cs.PrependReactor("create", tc.resource, func(k8stesting.Action) (bool, runtime.Object, error) { + return true, nil, errSeam + }) + // hobby tier has pvcGi>0 so the PVC step runs. + b := &K8sBackend{cs: cs, storageClass: "local-path", externalHost: "127.0.0.1", storageSizeGi: 10} + if _, err := b.Provision(context.Background(), "rb"+tc.name, "hobby", 5); err == nil { + t.Errorf("expected rollback error when %q create fails", tc.resource) + } + }) + } +} + +// applyNamespace failing (not via rollback — it's the first step) returns the +// namespace error directly. +func TestK8sBackend_Provision_NamespaceError(t *testing.T) { + cs := fake.NewClientset() + cs.PrependReactor("create", "namespaces", func(k8stesting.Action) (bool, runtime.Object, error) { + return true, nil, errSeam + }) + b := &K8sBackend{cs: cs, storageClass: "local-path", externalHost: "127.0.0.1", storageSizeGi: 10} + if _, err := b.Provision(context.Background(), "nserr", "hobby", 5); err == nil { + t.Error("expected namespace error") + } +} + +func TestK8sBackend_Provision_WaitReadyTimeout_Rollback(t *testing.T) { + shrinkK8sTimers(t) + const token = "neverready" + cs := fake.NewClientset() // no ready pod → waitPodReady times out + cs.PrependReactor("create", "services", func(a k8stesting.Action) (bool, runtime.Object, error) { + svc := a.(k8stesting.CreateAction).GetObject().(*corev1.Service) + svc.Spec.ClusterIP = "10.0.0.7" + if len(svc.Spec.Ports) > 0 { + svc.Spec.Ports[0].NodePort = 30333 + } + return false, nil, nil + }) + b := &K8sBackend{cs: cs, storageClass: "local-path", externalHost: "127.0.0.1", storageSizeGi: 10} + if _, err := b.Provision(context.Background(), token, "anonymous", 2); err == nil { + t.Error("expected wait-ready timeout error") + } +} + +func TestK8sBackend_Provision_InitDBError_Rollback(t *testing.T) { + shrinkK8sTimers(t) + const token = "initfail" + ns := k8sNsPrefix + token + cs := fake.NewClientset(readyPostgresPod(ns)) + cs.PrependReactor("create", "services", func(a k8stesting.Action) (bool, runtime.Object, error) { + svc := a.(k8stesting.CreateAction).GetObject().(*corev1.Service) + svc.Spec.ClusterIP = "10.0.0.8" + if len(svc.Spec.Ports) > 0 { + svc.Spec.Ports[0].NodePort = 30444 + } + return false, nil, nil + }) + withPGXConnect(t, nil, errSeam) // initDatabase connect fails + b := &K8sBackend{cs: cs, storageClass: "local-path", externalHost: "127.0.0.1", storageSizeGi: 10} + if _, err := b.Provision(context.Background(), token, "anonymous", 2); err == nil { + t.Error("expected init-database error") + } +} + +func TestK8sBackend_initDatabase_Success(t *testing.T) { + fc := &fakePGConn{} + withPGXConnect(t, fc, nil) + b := &K8sBackend{} + if err := b.initDatabase(context.Background(), "postgres://a:b@h:5432/postgres?sslmode=disable", "db_x", "usr_x", "p", 5); err != nil { + t.Fatalf("initDatabase: %v", err) + } +} + +func TestK8sBackend_initDatabase_ConnectError(t *testing.T) { + withPGXConnect(t, nil, errSeam) + b := &K8sBackend{} + if err := b.initDatabase(context.Background(), "dsn", "db_x", "usr_x", "p", 5); err == nil { + t.Error("expected connect error") + } +} + +func TestK8sBackend_initDatabase_ExecError(t *testing.T) { + fc := &fakePGConn{execErrFor: map[string]error{"CREATE USER": errSeam}} + withPGXConnect(t, fc, nil) + b := &K8sBackend{} + if err := b.initDatabase(context.Background(), "dsn", "db_x", "usr_x", "p", 5); err == nil { + t.Error("expected exec error") + } +} + +// initDatabase second connection (to the new DB for vector ext) failing is +// non-fatal — covered by returning a conn for the first connect, error for the +// second. +func TestK8sBackend_initDatabase_VectorExtConnFail_NonFatal(t *testing.T) { + var calls int + fc := &fakePGConn{} + withPGXConnectFunc(t, func(ctx context.Context, dsn string) (pgConn, error) { + calls++ + if calls == 2 { + return nil, errSeam + } + return fc, nil + }) + b := &K8sBackend{} + if err := b.initDatabase(context.Background(), "postgres://a@h/postgres?x", "db_x", "usr_x", "p", -1); err != nil { + t.Errorf("vector-ext connect failure must be non-fatal: %v", err) + } +} + +func TestK8sBackend_waitPodReady_ListError(t *testing.T) { + shrinkK8sTimers(t) + cs := fake.NewClientset() + cs.PrependReactor("list", "pods", func(k8stesting.Action) (bool, runtime.Object, error) { + return true, nil, errSeam + }) + b := &K8sBackend{cs: cs} + if err := b.waitPodReady(context.Background(), "ns", "app=postgres"); err == nil { + t.Error("expected list error") + } +} + +func TestK8sBackend_waitPodReady_CtxCancel(t *testing.T) { + rt, ri := k8sReadyTimeout, k8sReadyInterval + k8sReadyTimeout, k8sReadyInterval = time.Minute, 50*time.Millisecond + t.Cleanup(func() { k8sReadyTimeout, k8sReadyInterval = rt, ri }) + cs := fake.NewClientset() // pod never ready + b := &K8sBackend{cs: cs} + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond) + defer cancel() + if err := b.waitPodReady(ctx, "ns", "app=postgres"); err == nil { + t.Error("expected ctx-cancel error") + } +} + +// --- StorageBytes --- + +func TestK8sBackend_StorageBytes_Success(t *testing.T) { + const ns = "instant-customer-sb" + secret := &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{Name: "postgres-admin", Namespace: ns}, + Data: map[string][]byte{"POSTGRES_USER": []byte("pgadmin"), "POSTGRES_PASSWORD": []byte("pw")}, + } + svc := &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{Name: "postgres", Namespace: ns}, + Spec: corev1.ServiceSpec{ClusterIP: "10.0.0.9"}, + } + cs := fake.NewClientset(secret, svc) + fc := &fakePGConn{scanInt64: 1234} + withPGXConnect(t, fc, nil) + b := &K8sBackend{cs: cs} + got, err := b.StorageBytes(context.Background(), "sb", ns) + if err != nil || got != 1234 { + t.Errorf("StorageBytes = %d, %v", got, err) + } +} + +func TestK8sBackend_StorageBytes_LegacyMissingSecret(t *testing.T) { + cs := fake.NewClientset() // secret not found + b := &K8sBackend{cs: cs} + got, err := b.StorageBytes(context.Background(), "tok", "instant-customer-tok") + if err != nil || got != 0 { + t.Errorf("missing secret should fail-soft to 0,nil; got %d, %v", got, err) + } +} + +func TestK8sBackend_StorageBytes_MissingService(t *testing.T) { + const ns = "instant-customer-nosvc" + secret := &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{Name: "postgres-admin", Namespace: ns}, + Data: map[string][]byte{"POSTGRES_USER": []byte("u"), "POSTGRES_PASSWORD": []byte("p")}, + } + cs := fake.NewClientset(secret) // no service + b := &K8sBackend{cs: cs} + got, err := b.StorageBytes(context.Background(), "tok", ns) + if err != nil || got != 0 { + t.Errorf("missing service should fail-soft to 0,nil; got %d, %v", got, err) + } +} + +func TestK8sBackend_StorageBytes_DefaultNamespace(t *testing.T) { + // providerResourceID empty → derive ns from token; missing secret fail-soft. + cs := fake.NewClientset() + b := &K8sBackend{cs: cs} + if _, err := b.StorageBytes(context.Background(), "tok", ""); err != nil { + t.Errorf("unexpected error: %v", err) + } +} + +// Non-NotFound get errors propagate (they are NOT the legacy fail-soft path). +func TestK8sBackend_StorageBytes_SecretGetError(t *testing.T) { + cs := fake.NewClientset() + cs.PrependReactor("get", "secrets", func(k8stesting.Action) (bool, runtime.Object, error) { + return true, nil, errSeam // generic error, not NotFound + }) + b := &K8sBackend{cs: cs} + if _, err := b.StorageBytes(context.Background(), "tok", "instant-customer-tok"); err == nil { + t.Error("expected non-NotFound secret-get error to propagate") + } +} + +func TestK8sBackend_StorageBytes_ServiceGetError(t *testing.T) { + const ns = "instant-customer-svcerr" + secret := &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{Name: "postgres-admin", Namespace: ns}, + Data: map[string][]byte{"POSTGRES_USER": []byte("u"), "POSTGRES_PASSWORD": []byte("p")}, + } + cs := fake.NewClientset(secret) + cs.PrependReactor("get", "services", func(k8stesting.Action) (bool, runtime.Object, error) { + return true, nil, errSeam + }) + b := &K8sBackend{cs: cs} + if _, err := b.StorageBytes(context.Background(), "tok", ns); err == nil { + t.Error("expected non-NotFound service-get error to propagate") + } +} + +func TestK8sBackend_StorageBytes_ConnectError_Seam(t *testing.T) { + const ns = "instant-customer-ce" + secret := &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{Name: "postgres-admin", Namespace: ns}, + Data: map[string][]byte{"POSTGRES_USER": []byte("u"), "POSTGRES_PASSWORD": []byte("p")}, + } + svc := &corev1.Service{ObjectMeta: metav1.ObjectMeta{Name: "postgres", Namespace: ns}, Spec: corev1.ServiceSpec{ClusterIP: "1.2.3.4"}} + cs := fake.NewClientset(secret, svc) + withPGXConnect(t, nil, errSeam) + b := &K8sBackend{cs: cs} + if _, err := b.StorageBytes(context.Background(), "tok", ns); err == nil { + t.Error("expected connect error") + } +} + +func TestK8sBackend_StorageBytes_AllCandidatesMiss(t *testing.T) { + const ns = "instant-customer-miss" + secret := &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{Name: "postgres-admin", Namespace: ns}, + Data: map[string][]byte{"POSTGRES_USER": []byte("u"), "POSTGRES_PASSWORD": []byte("p")}, + } + svc := &corev1.Service{ObjectMeta: metav1.ObjectMeta{Name: "postgres", Namespace: ns}, Spec: corev1.ServiceSpec{ClusterIP: "1.2.3.4"}} + cs := fake.NewClientset(secret, svc) + fc := &fakePGConn{queryRowErr: errSeam} + withPGXConnect(t, fc, nil) + b := &K8sBackend{cs: cs} + if _, err := b.StorageBytes(context.Background(), "tok", ns); err == nil { + t.Error("expected all-candidates-miss error") + } +} + +// --- Deprovision --- + +func TestK8sBackend_Deprovision_Success(t *testing.T) { + const ns = "instant-customer-dp" + nsObj := &corev1.Namespace{ObjectMeta: metav1.ObjectMeta{Name: ns}} + cs := fake.NewClientset(nsObj) + b := &K8sBackend{cs: cs} + if err := b.Deprovision(context.Background(), "dp", ns); err != nil { + t.Fatalf("Deprovision: %v", err) + } +} + +func TestK8sBackend_Deprovision_AlreadyGone(t *testing.T) { + cs := fake.NewClientset() // namespace not found → idempotent success + b := &K8sBackend{cs: cs} + if err := b.Deprovision(context.Background(), "dp", "instant-customer-gone"); err != nil { + t.Errorf("already-gone should be success: %v", err) + } +} + +func TestK8sBackend_Deprovision_DeleteError_Seam(t *testing.T) { + cs := fake.NewClientset() + cs.PrependReactor("delete", "namespaces", func(k8stesting.Action) (bool, runtime.Object, error) { + return true, nil, errSeam + }) + b := &K8sBackend{cs: cs} + if err := b.Deprovision(context.Background(), "dp", "instant-customer-x"); err == nil { + t.Error("expected delete error") + } +} + +func TestK8sBackend_Deprovision_RouteUnregister(t *testing.T) { + const ns = "instant-customer-route" + cs := fake.NewClientset(&corev1.Namespace{ObjectMeta: metav1.ObjectMeta{Name: ns}}) + rdb := goredis.NewClient(&goredis.Options{Addr: "127.0.0.1:1"}) // Del fails, non-fatal + b := &K8sBackend{cs: cs, rdb: rdb, routePrefix: "pg_route:"} + if err := b.Deprovision(context.Background(), "route", ns); err != nil { + t.Errorf("route-unregister failure must be non-fatal: %v", err) + } +} + +func TestK8sBackend_Deprovision_DefaultNamespace(t *testing.T) { + cs := fake.NewClientset() + b := &K8sBackend{cs: cs} + if err := b.Deprovision(context.Background(), "tok", ""); err != nil { + t.Errorf("unexpected error: %v", err) + } +} + +// --- Regrade --- + +func k8sRegradeFixture(ns string) (*corev1.Secret, *corev1.Service) { + secret := &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{Name: "postgres-admin", Namespace: ns}, + Data: map[string][]byte{"POSTGRES_USER": []byte("u"), "POSTGRES_PASSWORD": []byte("p")}, + } + svc := &corev1.Service{ObjectMeta: metav1.ObjectMeta{Name: "postgres", Namespace: ns}, Spec: corev1.ServiceSpec{ClusterIP: "1.2.3.4"}} + return secret, svc +} + +func TestK8sBackend_Regrade_Success(t *testing.T) { + const ns = "instant-customer-rg" + secret, svc := k8sRegradeFixture(ns) + cs := fake.NewClientset(secret, svc) + fc := &fakePGConn{} + withPGXConnect(t, fc, nil) + b := &K8sBackend{cs: cs} + res, err := b.Regrade(context.Background(), "rg", ns, 20) + if err != nil || !res.Applied || res.AppliedConnLimit != 20 { + t.Errorf("Regrade = %+v, %v", res, err) + } +} + +func TestK8sBackend_Regrade_MissingSecret_Skip(t *testing.T) { + cs := fake.NewClientset() + b := &K8sBackend{cs: cs} + res, err := b.Regrade(context.Background(), "tok", "instant-customer-tok", 5) + if err != nil || res.Applied { + t.Errorf("missing secret should skip without error; got %+v, %v", res, err) + } +} + +func TestK8sBackend_Regrade_MissingService_Skip(t *testing.T) { + const ns = "instant-customer-rgnosvc" + secret, _ := k8sRegradeFixture(ns) + cs := fake.NewClientset(secret) + b := &K8sBackend{cs: cs} + res, err := b.Regrade(context.Background(), "tok", ns, 5) + if err != nil || res.Applied { + t.Errorf("missing service should skip; got %+v, %v", res, err) + } +} + +// Non-NotFound get errors → skip with a "not reachable" reason (no error). +func TestK8sBackend_Regrade_SecretGetError_Skip(t *testing.T) { + cs := fake.NewClientset() + cs.PrependReactor("get", "secrets", func(k8stesting.Action) (bool, runtime.Object, error) { + return true, nil, errSeam + }) + b := &K8sBackend{cs: cs} + res, err := b.Regrade(context.Background(), "tok", "instant-customer-tok", 5) + if err != nil || res.Applied { + t.Errorf("secret-get error should skip without error; got %+v, %v", res, err) + } +} + +func TestK8sBackend_Regrade_ServiceGetError_Skip(t *testing.T) { + const ns = "instant-customer-rgsvcerr" + secret, _ := k8sRegradeFixture(ns) + cs := fake.NewClientset(secret) + cs.PrependReactor("get", "services", func(k8stesting.Action) (bool, runtime.Object, error) { + return true, nil, errSeam + }) + b := &K8sBackend{cs: cs} + res, err := b.Regrade(context.Background(), "tok", ns, 5) + if err != nil || res.Applied { + t.Errorf("service-get error should skip without error; got %+v, %v", res, err) + } +} + +func TestK8sBackend_Regrade_ConnectError_Skip(t *testing.T) { + const ns = "instant-customer-rgconn" + secret, svc := k8sRegradeFixture(ns) + cs := fake.NewClientset(secret, svc) + withPGXConnect(t, nil, errSeam) + b := &K8sBackend{cs: cs} + res, err := b.Regrade(context.Background(), "tok", ns, 5) + if err != nil || res.Applied { + t.Errorf("connect error should skip; got %+v, %v", res, err) + } +} + +func TestK8sBackend_Regrade_AlterRoleAllMiss_Skip(t *testing.T) { + const ns = "instant-customer-rgalter" + secret, svc := k8sRegradeFixture(ns) + cs := fake.NewClientset(secret, svc) + fc := &fakePGConn{execErrFor: map[string]error{"ALTER ROLE": errSeam}} + withPGXConnect(t, fc, nil) + b := &K8sBackend{cs: cs} + res, err := b.Regrade(context.Background(), "tok", ns, 5) + if err != nil || res.Applied { + t.Errorf("alter-role miss should skip; got %+v, %v", res, err) + } +} + +func TestK8sBackend_Regrade_DefaultNamespace(t *testing.T) { + cs := fake.NewClientset() // missing secret → skip path, ns derived from token + b := &K8sBackend{cs: cs} + if _, err := b.Regrade(context.Background(), "tok", "", 5); err != nil { + t.Errorf("unexpected error: %v", err) + } +} + +// --- applyNamespace edge paths --- + +func TestApplyNamespace_AlreadyExists_NotTerminating(t *testing.T) { + const ns = "instant-customer-exists" + existing := &corev1.Namespace{ + ObjectMeta: metav1.ObjectMeta{Name: ns}, + Status: corev1.NamespaceStatus{Phase: corev1.NamespaceActive}, + } + cs := fake.NewClientset(existing) + b := &K8sBackend{cs: cs} + if err := b.applyNamespace(context.Background(), ns); err == nil { + t.Error("expected AlreadyExists error surfaced for active namespace") + } +} + +func TestApplyNamespace_Terminating_TimesOut(t *testing.T) { + shrinkK8sTimers(t) + const ns = "instant-customer-term" + existing := &corev1.Namespace{ + ObjectMeta: metav1.ObjectMeta{Name: ns}, + Status: corev1.NamespaceStatus{Phase: corev1.NamespaceTerminating}, + } + cs := fake.NewClientset(existing) + b := &K8sBackend{cs: cs} + if err := b.applyNamespace(context.Background(), ns); err == nil { + t.Error("expected still-terminating timeout error") + } +} + +// Terminating namespace that disappears (Get → NotFound) on the next poll → +// applyNamespace recreates it successfully. Covers the recreate-success branch. +func TestApplyNamespace_Terminating_RecreatesAfterGone(t *testing.T) { + shrinkK8sTimers(t) + const ns = "instant-customer-recreate" + existing := &corev1.Namespace{ + ObjectMeta: metav1.ObjectMeta{Name: ns}, + Status: corev1.NamespaceStatus{Phase: corev1.NamespaceTerminating}, + } + cs := fake.NewClientset() + + var gets int + cs.PrependReactor("get", "namespaces", func(k8stesting.Action) (bool, runtime.Object, error) { + gets++ + if gets == 1 { + // First Get (AlreadyExists branch) sees the terminating namespace. + return true, existing, nil + } + // Subsequent Get (poll loop) → NotFound so it recreates. + return true, nil, k8sNotFound(ns) + }) + var creates int + cs.PrependReactor("create", "namespaces", func(k8stesting.Action) (bool, runtime.Object, error) { + creates++ + if creates == 1 { + // First create → AlreadyExists so we enter the terminating path. + return true, nil, k8sAlreadyExists(ns) + } + // Recreate after gone → success (reactor fully handles it). + return true, &corev1.Namespace{ObjectMeta: metav1.ObjectMeta{Name: ns}}, nil + }) + + b := &K8sBackend{cs: cs} + if err := b.applyNamespace(context.Background(), ns); err != nil { + t.Errorf("recreate-after-gone should succeed: %v", err) + } +} + +func TestApplyNamespace_CreateError(t *testing.T) { + cs := fake.NewClientset() + cs.PrependReactor("create", "namespaces", func(k8stesting.Action) (bool, runtime.Object, error) { + return true, nil, errSeam + }) + b := &K8sBackend{cs: cs} + if err := b.applyNamespace(context.Background(), "instant-customer-cerr"); err == nil { + t.Error("expected create error") + } +} diff --git a/internal/backend/postgres/local.go b/internal/backend/postgres/local.go index 9a63ad4..af08a3e 100644 --- a/internal/backend/postgres/local.go +++ b/internal/backend/postgres/local.go @@ -19,8 +19,6 @@ import ( "strings" "time" - "github.com/jackc/pgx/v5" - "instant.dev/provisioner/internal/poolident" ) @@ -47,8 +45,10 @@ const sqlCreateExtensionVector = "CREATE EXTENSION IF NOT EXISTS vector" // itself, so one or two attempts is normally enough. const deprovisionDropDBAttempts = 3 -// deprovisionDropDBRetryDelay is the pause between DROP DATABASE attempts. -const deprovisionDropDBRetryDelay = 500 * time.Millisecond +// deprovisionDropDBRetryDelay is the pause between DROP DATABASE attempts. It is +// a package var (not a const) only so tests can shrink it to avoid a real +// 500ms*N wait while still exercising the retry loop. Production value unchanged. +var deprovisionDropDBRetryDelay = 500 * time.Millisecond // pgDatabaseInUseMarker is the Postgres error-message fragment for SQLSTATE // 55006 (object_in_use) raised by DROP DATABASE when a backend is still @@ -109,7 +109,7 @@ func generatePassword(n int) (string, error) { buf := make([]byte, n) charsetLen := big.NewInt(int64(len(alphanumChars))) for i := range buf { - idx, err := rand.Int(rand.Reader, charsetLen) + idx, err := randInt(rand.Reader, charsetLen) if err != nil { return "", fmt.Errorf("generatePassword: %w", err) } @@ -142,7 +142,7 @@ func (b *LocalBackend) Provision(ctx context.Context, token, tier string, connLi defer b.router.ReleasePick(clusterIdx) // Connect as admin. - conn, err := pgx.Connect(ctx, adminURL) + conn, err := pgxConnect(ctx, adminURL) if err != nil { return nil, fmt.Errorf("db.local.Provision: connect: %w", err) } @@ -183,7 +183,7 @@ func (b *LocalBackend) Provision(ctx context.Context, token, tier string, connLi // Connect to the new database to grant schema privileges. // Build the new DB URL by substituting the database name in the admin URL. newDBURL := buildDBURL(adminURL, username, pass, dbName) - adminNewDB, err := pgx.Connect(ctx, buildAdminNewDBURL(adminURL, dbName)) + adminNewDB, err := pgxConnect(ctx, buildAdminNewDBURL(adminURL, dbName)) if err != nil { slog.Error("db.local.Provision: connect new db for schema grant (non-fatal)", "error", err) } else { @@ -238,7 +238,7 @@ func (b *LocalBackend) Provision(ctx context.Context, token, tier string, connLi // segment is stripped before the router parses it. func (b *LocalBackend) StorageBytes(ctx context.Context, token, providerResourceID string) (int64, error) { adminURL := b.router.AdminURLForResource(poolident.BasePRID(providerResourceID)) - conn, err := pgx.Connect(ctx, adminURL) + conn, err := pgxConnect(ctx, adminURL) if err != nil { return 0, fmt.Errorf("db.local.StorageBytes: connect: %w", err) } @@ -276,7 +276,7 @@ func (b *LocalBackend) Deprovision(ctx context.Context, token, providerResourceI username := userNamePrefix + namingToken adminURL := b.router.AdminURLForResource(poolident.BasePRID(providerResourceID)) - conn, err := pgx.Connect(ctx, adminURL) + conn, err := pgxConnect(ctx, adminURL) if err != nil { return fmt.Errorf("db.local.Deprovision: connect: %w", err) } @@ -353,7 +353,7 @@ func (b *LocalBackend) Regrade(ctx context.Context, token, providerResourceID st username := userNamePrefix + poolident.NamingToken(token, providerResourceID) adminURL := b.router.AdminURLForResource(poolident.BasePRID(providerResourceID)) - conn, err := pgx.Connect(ctx, adminURL) + conn, err := pgxConnect(ctx, adminURL) if err != nil { return RegradeResult{Applied: false}, fmt.Errorf("db.local.Regrade: connect: %w", err) } diff --git a/internal/backend/postgres/local_seam_test.go b/internal/backend/postgres/local_seam_test.go new file mode 100644 index 0000000..bc6ad53 --- /dev/null +++ b/internal/backend/postgres/local_seam_test.go @@ -0,0 +1,402 @@ +package postgres + +// local_seam_test.go — seam-driven coverage for local.go: generatePassword, +// Provision (success + every Exec-error branch), StorageBytes, Deprovision +// (success, retry loop, terminal error, ctx cancel), Regrade, and the small +// URL-building helpers. + +import ( + "context" + "errors" + "io" + "math/big" + "testing" + "time" +) + +func TestGeneratePassword_Success(t *testing.T) { + got, err := generatePassword(16) + if err != nil { + t.Fatalf("generatePassword: %v", err) + } + if len(got) != 16 { + t.Errorf("len = %d; want 16", len(got)) + } +} + +func TestGeneratePassword_RandError(t *testing.T) { + orig := randInt + randInt = func(_ io.Reader, _ *big.Int) (*big.Int, error) { return nil, errSeam } + t.Cleanup(func() { randInt = orig }) + + if _, err := generatePassword(8); err == nil { + t.Error("expected error when randInt fails") + } +} + +func TestLocalBackend_Provision_Success(t *testing.T) { + fc := &fakePGConn{} + withPGXConnect(t, fc, nil) + + b := newLocalBackend("postgres://u:p@h:5432/postgres?sslmode=disable") + creds, err := b.Provision(context.Background(), "tok-123", "hobby", 8) + if err != nil { + t.Fatalf("Provision: %v", err) + } + if creds.DatabaseName != "db_tok-123" || creds.Username != "usr_tok-123" { + t.Errorf("creds = %+v", creds) + } + if creds.ProviderResourceID != "local:0" { + t.Errorf("PRID = %q; want local:0", creds.ProviderResourceID) + } +} + +func TestLocalBackend_Provision_GenPasswordError(t *testing.T) { + orig := randInt + randInt = func(_ io.Reader, _ *big.Int) (*big.Int, error) { return nil, errSeam } + t.Cleanup(func() { randInt = orig }) + b := newLocalBackend("") + if _, err := b.Provision(context.Background(), "t", "hobby", -1); err == nil { + t.Error("expected generatePassword error") + } +} + +func TestLocalBackend_Provision_PickError(t *testing.T) { + // A router with no clusters → Pick returns an error. + b := &LocalBackend{router: newClusterRouter(nil, 0)} + if _, err := b.Provision(context.Background(), "t", "hobby", -1); err == nil { + t.Error("expected Pick error when no clusters configured") + } +} + +func TestLocalBackend_Provision_ConnectError_Seam(t *testing.T) { + withPGXConnect(t, nil, errSeam) + b := newLocalBackend("") + if _, err := b.Provision(context.Background(), "t", "hobby", -1); err == nil { + t.Error("expected connect error") + } +} + +func TestLocalBackend_Provision_ExecErrorBranches(t *testing.T) { + cases := []struct { + name string + sub string + }{ + {"create_database", "CREATE DATABASE"}, + {"create_user", "CREATE USER"}, + {"grant_database", "GRANT ALL PRIVILEGES ON DATABASE"}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + fc := &fakePGConn{execErrFor: map[string]error{tc.sub: errSeam}} + withPGXConnect(t, fc, nil) + b := newLocalBackend("") + if _, err := b.Provision(context.Background(), "t", "hobby", 5); err == nil { + t.Errorf("expected error when %q fails", tc.sub) + } + }) + } +} + +// Non-fatal branches: REVOKE CONNECT and (post-new-db) GRANT SCHEMA / CREATE +// EXTENSION failures are logged but Provision still succeeds. +func TestLocalBackend_Provision_NonFatalBranches(t *testing.T) { + fc := &fakePGConn{execErrFor: map[string]error{ + "REVOKE CONNECT": errSeam, + "GRANT ALL ON SCHEMA": errSeam, + "CREATE EXTENSION": errSeam, + }} + withPGXConnect(t, fc, nil) + b := newLocalBackend("") + if _, err := b.Provision(context.Background(), "t", "hobby", -1); err != nil { + t.Errorf("non-fatal failures should not fail Provision: %v", err) + } +} + +// When the second connection (to the new DB for schema grant) fails, Provision +// logs and continues — the schema grant is best-effort. +func TestLocalBackend_Provision_NewDBConnectError_NonFatal(t *testing.T) { + var calls int + fc := &fakePGConn{} + withPGXConnectFunc(t, func(ctx context.Context, dsn string) (pgConn, error) { + calls++ + if calls == 2 { // second connect = new-db schema grant + return nil, errSeam + } + return fc, nil + }) + b := newLocalBackend("") + if _, err := b.Provision(context.Background(), "t", "hobby", -1); err != nil { + t.Errorf("new-db connect failure must be non-fatal: %v", err) + } +} + +func TestLocalBackend_Provision_CloseError_NonFatal(t *testing.T) { + fc := &fakePGConn{closeErr: errSeam} + withPGXConnect(t, fc, nil) + b := newLocalBackend("") + if _, err := b.Provision(context.Background(), "t", "hobby", -1); err != nil { + t.Errorf("Close error must be non-fatal: %v", err) + } + if fc.closeCalls == 0 { + t.Error("Close was never called") + } +} + +func TestLocalBackend_StorageBytes_Success(t *testing.T) { + fc := &fakePGConn{scanInt64: 4242} + withPGXConnect(t, fc, nil) + b := newLocalBackend("") + got, err := b.StorageBytes(context.Background(), "tok", "local:0") + if err != nil { + t.Fatalf("StorageBytes: %v", err) + } + if got != 4242 { + t.Errorf("got %d; want 4242", got) + } +} + +func TestLocalBackend_StorageBytes_ConnectError_Seam(t *testing.T) { + withPGXConnect(t, nil, errSeam) + b := newLocalBackend("") + if _, err := b.StorageBytes(context.Background(), "tok", ""); err == nil { + t.Error("expected connect error") + } +} + +func TestLocalBackend_StorageBytes_ScanError(t *testing.T) { + fc := &fakePGConn{queryRowErr: errSeam} + withPGXConnect(t, fc, nil) + b := newLocalBackend("") + if _, err := b.StorageBytes(context.Background(), "tok", ""); err == nil { + t.Error("expected scan error") + } +} + +func TestLocalBackend_Deprovision_Success(t *testing.T) { + fc := &fakePGConn{} + withPGXConnect(t, fc, nil) + b := newLocalBackend("") + if err := b.Deprovision(context.Background(), "tok", "local:0"); err != nil { + t.Fatalf("Deprovision: %v", err) + } +} + +func TestLocalBackend_Deprovision_ConnectError_Seam(t *testing.T) { + withPGXConnect(t, nil, errSeam) + b := newLocalBackend("") + if err := b.Deprovision(context.Background(), "tok", ""); err == nil { + t.Error("expected connect error") + } +} + +// Non-fatal: REVOKE CONNECT, terminate, DROP USER all log-and-continue. +func TestLocalBackend_Deprovision_NonFatalBranches(t *testing.T) { + fc := &fakePGConn{execErrFor: map[string]error{ + "REVOKE CONNECT": errSeam, + "pg_terminate_backend": errSeam, + "DROP USER": errSeam, + }} + withPGXConnect(t, fc, nil) + b := newLocalBackend("") + if err := b.Deprovision(context.Background(), "tok", ""); err != nil { + t.Errorf("non-fatal failures should not fail Deprovision: %v", err) + } +} + +// Terminal DROP DATABASE error (not the in-use race) breaks the loop on attempt 1. +func TestLocalBackend_Deprovision_DropDBTerminalError(t *testing.T) { + fc := &fakePGConn{execErrFor: map[string]error{"DROP DATABASE": errors.New("permission denied")}} + withPGXConnect(t, fc, nil) + b := newLocalBackend("") + if err := b.Deprovision(context.Background(), "tok", ""); err == nil { + t.Error("expected terminal DROP DATABASE error") + } +} + +// The in-use race retries deprovisionDropDBAttempts times, then returns the error. +func TestLocalBackend_Deprovision_DropDBInUseRetries(t *testing.T) { + orig := deprovisionDropDBRetryDelay + deprovisionDropDBRetryDelay = time.Millisecond + t.Cleanup(func() { deprovisionDropDBRetryDelay = orig }) + + fc := &fakePGConn{execErrFor: map[string]error{ + "DROP DATABASE": errors.New("database is being accessed by other users"), + }} + withPGXConnect(t, fc, nil) + b := newLocalBackend("") + if err := b.Deprovision(context.Background(), "tok", ""); err == nil { + t.Error("expected in-use error after exhausting retries") + } +} + +// ctx cancellation inside the retry loop breaks out via ctx.Done(). +func TestLocalBackend_Deprovision_CtxCancelDuringRetry(t *testing.T) { + orig := deprovisionDropDBRetryDelay + deprovisionDropDBRetryDelay = time.Second + t.Cleanup(func() { deprovisionDropDBRetryDelay = orig }) + + fc := &fakePGConn{execErrFor: map[string]error{ + "DROP DATABASE": errors.New("database is being accessed by other users"), + }} + withPGXConnect(t, fc, nil) + b := newLocalBackend("") + ctx, cancel := context.WithCancel(context.Background()) + cancel() // already cancelled → loop's select hits ctx.Done immediately + if err := b.Deprovision(ctx, "tok", ""); err == nil { + t.Error("expected error after ctx cancel") + } +} + +func TestLocalBackend_Regrade_Success(t *testing.T) { + fc := &fakePGConn{} + withPGXConnect(t, fc, nil) + b := newLocalBackend("") + res, err := b.Regrade(context.Background(), "tok", "local:0", 20) + if err != nil { + t.Fatalf("Regrade: %v", err) + } + if !res.Applied || res.AppliedConnLimit != 20 { + t.Errorf("res = %+v; want Applied with 20", res) + } +} + +func TestLocalBackend_Regrade_ZeroNormalizedToUnlimited(t *testing.T) { + fc := &fakePGConn{} + withPGXConnect(t, fc, nil) + b := newLocalBackend("") + res, err := b.Regrade(context.Background(), "tok", "", 0) + if err != nil { + t.Fatalf("Regrade: %v", err) + } + if res.AppliedConnLimit != -1 { + t.Errorf("0 should normalize to -1; got %d", res.AppliedConnLimit) + } +} + +func TestLocalBackend_Regrade_ConnectError_Seam(t *testing.T) { + withPGXConnect(t, nil, errSeam) + b := newLocalBackend("") + if _, err := b.Regrade(context.Background(), "tok", "", 5); err == nil { + t.Error("expected connect error") + } +} + +func TestLocalBackend_Regrade_AlterRoleError(t *testing.T) { + fc := &fakePGConn{execErrFor: map[string]error{"ALTER ROLE": errSeam}} + withPGXConnect(t, fc, nil) + b := newLocalBackend("") + if _, err := b.Regrade(context.Background(), "tok", "", 5); err == nil { + t.Error("expected ALTER ROLE error") + } +} + +// Close-error branches in the deferred disconnect of StorageBytes and Regrade. +func TestLocalBackend_StorageBytes_CloseErrorLogged(t *testing.T) { + fc := &fakePGConn{scanInt64: 1, closeErr: errSeam} + withPGXConnect(t, fc, nil) + b := newLocalBackend("") + if _, err := b.StorageBytes(context.Background(), "tok", ""); err != nil { + t.Errorf("Close error must be non-fatal: %v", err) + } +} + +func TestLocalBackend_Regrade_CloseErrorLogged(t *testing.T) { + fc := &fakePGConn{closeErr: errSeam} + withPGXConnect(t, fc, nil) + b := newLocalBackend("") + if _, err := b.Regrade(context.Background(), "tok", "", 5); err != nil { + t.Errorf("Close error must be non-fatal: %v", err) + } +} + +func TestLocalBackend_StartShutdown_Seam(t *testing.T) { + // Install a fast, deterministic seam so the immediate first refreshCounts + // poll returns at once rather than dialing a real (absent) Postgres for the + // full 5s connect timeout. Shutdown joins the poll goroutine, so this also + // guarantees no router goroutine survives the test to race the global seam. + withPGXConnect(t, &fakePGConn{}, nil) + b := newLocalBackend("") + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + b.Start(ctx) + b.Shutdown() +} + +func TestBuildDBURL_PublicHostPort(t *testing.T) { + t.Setenv("POSTGRES_PUBLIC_HOST_PORT", "pg.example.com:6543") + got := buildDBURL("postgres://a:b@internal:5432/postgres", "u", "p", "db_x") + if got != "postgres://u:p@pg.example.com:6543/db_x?sslmode=disable" { + t.Errorf("got %q", got) + } +} + +func TestBuildDBURL_PublicHostDefaultPort(t *testing.T) { + t.Setenv("POSTGRES_PUBLIC_HOST_PORT", "") + t.Setenv("POSTGRES_PUBLIC_HOST", "pg.example.com") + t.Setenv("POSTGRES_PUBLIC_PORT", "") + got := buildDBURL("postgres://a:b@internal:5432/postgres", "u", "p", "db_x") + if got != "postgres://u:p@pg.example.com:5432/db_x?sslmode=disable" { + t.Errorf("got %q", got) + } +} + +func TestBuildDBURL_PublicHostExplicitPort(t *testing.T) { + t.Setenv("POSTGRES_PUBLIC_HOST_PORT", "") + t.Setenv("POSTGRES_PUBLIC_HOST", "pg.example.com") + t.Setenv("POSTGRES_PUBLIC_PORT", "7777") + got := buildDBURL("postgres://a:b@internal:5432/postgres", "u", "p", "db_x") + if got != "postgres://u:p@pg.example.com:7777/db_x?sslmode=disable" { + t.Errorf("got %q", got) + } +} + +func TestBuildDBURL_FallbackToAdminHost(t *testing.T) { + t.Setenv("POSTGRES_PUBLIC_HOST_PORT", "") + t.Setenv("POSTGRES_PUBLIC_HOST", "") + got := buildDBURL("postgres://a:b@internal-host:5432/postgres", "u", "p", "db_x") + if got != "postgres://u:p@internal-host:5432/db_x?sslmode=disable" { + t.Errorf("got %q", got) + } +} + +func TestBuildAdminNewDBURL_Seam(t *testing.T) { + if got := buildAdminNewDBURL("postgres://a:b@h:5432/postgres?x=1", "db_y"); got != "postgres://a:b@h:5432/db_y" { + t.Errorf("got %q", got) + } + // No slash → appends. + if got := buildAdminNewDBURL("noslash", "db_y"); got != "noslash/db_y" { + t.Errorf("got %q", got) + } +} + +func TestExtractHost_Seam(t *testing.T) { + cases := map[string]string{ + "postgres://u:p@h:5432/db": "h:5432", + "postgres://u:p@h/db": "h", + "postgres://h:5432": "h:5432", + "h-no-prefix": "h-no-prefix", + } + for in, want := range cases { + if got := extractHost(in); got != want { + t.Errorf("extractHost(%q) = %q; want %q", in, got, want) + } + } +} + +func TestIndexOf_Seam(t *testing.T) { + if indexOf("abc@def", '@') != 3 { + t.Error("indexOf @ wrong") + } + if indexOf("abc", '@') != -1 { + t.Error("indexOf missing should be -1") + } +} + +func TestPublicHostPort_Empty(t *testing.T) { + t.Setenv("POSTGRES_PUBLIC_HOST_PORT", "") + t.Setenv("POSTGRES_PUBLIC_HOST", "") + if got := publicHostPort(); got != "" { + t.Errorf("got %q; want empty", got) + } +} diff --git a/internal/backend/postgres/neon.go b/internal/backend/postgres/neon.go index 9387006..5c216c1 100644 --- a/internal/backend/postgres/neon.go +++ b/internal/backend/postgres/neon.go @@ -8,7 +8,6 @@ import ( "context" "encoding/json" "fmt" - "io" "log/slog" "net/http" "time" @@ -100,12 +99,12 @@ func (b *NeonBackend) Provision(ctx context.Context, token, tier string, connLim "pg_version": 16, }, } - bodyBytes, err := json.Marshal(body) + bodyBytes, err := jsonMarshal(body) if err != nil { return nil, fmt.Errorf("db.neon.Provision: marshal: %w", err) } - req, err := http.NewRequestWithContext(ctx, http.MethodPost, b.base()+"/projects", bytes.NewReader(bodyBytes)) + req, err := httpNewRequestWithContext(ctx, http.MethodPost, b.base()+"/projects", bytes.NewReader(bodyBytes)) if err != nil { return nil, fmt.Errorf("db.neon.Provision: new request: %w", err) } @@ -118,7 +117,7 @@ func (b *NeonBackend) Provision(ctx context.Context, token, tier string, connLim } defer resp.Body.Close() - respBytes, err := io.ReadAll(resp.Body) + respBytes, err := ioReadAll(resp.Body) if err != nil { return nil, fmt.Errorf("db.neon.Provision: read body: %w", err) } @@ -167,7 +166,7 @@ func (b *NeonBackend) StorageBytes(ctx context.Context, token, providerResourceI return 0, fmt.Errorf("db.neon.StorageBytes: empty providerResourceID") } - req, err := http.NewRequestWithContext(ctx, http.MethodGet, + req, err := httpNewRequestWithContext(ctx, http.MethodGet, b.base()+"/projects/"+providerResourceID, nil) if err != nil { return 0, fmt.Errorf("db.neon.StorageBytes: new request: %w", err) @@ -180,7 +179,7 @@ func (b *NeonBackend) StorageBytes(ctx context.Context, token, providerResourceI } defer resp.Body.Close() - respBytes, err := io.ReadAll(resp.Body) + respBytes, err := ioReadAll(resp.Body) if err != nil { return 0, fmt.Errorf("db.neon.StorageBytes: read body: %w", err) } @@ -210,7 +209,7 @@ func (b *NeonBackend) Deprovision(ctx context.Context, token, providerResourceID return fmt.Errorf("db.neon.Deprovision: empty providerResourceID") } - req, err := http.NewRequestWithContext(ctx, http.MethodDelete, + req, err := httpNewRequestWithContext(ctx, http.MethodDelete, b.base()+"/projects/"+providerResourceID, nil) if err != nil { return fmt.Errorf("db.neon.Deprovision: new request: %w", err) @@ -224,7 +223,7 @@ func (b *NeonBackend) Deprovision(ctx context.Context, token, providerResourceID defer resp.Body.Close() if resp.StatusCode < 200 || resp.StatusCode >= 300 { - body, _ := io.ReadAll(resp.Body) + body, _ := ioReadAll(resp.Body) return fmt.Errorf("db.neon.Deprovision: unexpected status %d: %s", resp.StatusCode, string(body)) } @@ -245,7 +244,7 @@ func (b *NeonBackend) Regrade(ctx context.Context, token, providerResourceID str // so the caller can decide whether to proceed with a create. // GET https://console.neon.tech/api/v2/projects func (b *NeonBackend) findProjectByName(ctx context.Context, projectName string) (string, error) { - req, err := http.NewRequestWithContext(ctx, http.MethodGet, b.base()+"/projects", nil) + req, err := httpNewRequestWithContext(ctx, http.MethodGet, b.base()+"/projects", nil) if err != nil { return "", fmt.Errorf("db.neon.findProjectByName: new request: %w", err) } @@ -257,7 +256,7 @@ func (b *NeonBackend) findProjectByName(ctx context.Context, projectName string) } defer resp.Body.Close() - respBytes, err := io.ReadAll(resp.Body) + respBytes, err := ioReadAll(resp.Body) if err != nil { return "", fmt.Errorf("db.neon.findProjectByName: read body: %w", err) } diff --git a/internal/backend/postgres/neon_seam_test.go b/internal/backend/postgres/neon_seam_test.go new file mode 100644 index 0000000..6da80ca --- /dev/null +++ b/internal/backend/postgres/neon_seam_test.go @@ -0,0 +1,375 @@ +package postgres + +// neon_seam_test.go — coverage for neon.go paths the existing idempotency tests +// don't reach: StorageBytes, Deprovision, findProjectByName, the base() default, +// and the marshal / new-request / read-body / non-2xx / unmarshal error wraps +// (driven via the json/http/io seams). + +import ( + "context" + "encoding/json" + "errors" + "io" + "net/http" + "net/http/httptest" + "testing" +) + +func neonTestServer(t *testing.T, h http.HandlerFunc) *NeonBackend { + t.Helper() + srv := httptest.NewServer(h) + t.Cleanup(srv.Close) + return &NeonBackend{apiKey: "key", regionID: defaultNeonRegion, client: srv.Client(), apiBase: srv.URL} +} + +func TestNeon_base_DefaultsWhenEmpty(t *testing.T) { + b := &NeonBackend{} + if b.base() != neonAPIBase { + t.Errorf("base() = %q; want %q", b.base(), neonAPIBase) + } +} + +func TestNewNeonBackend_DefaultRegion(t *testing.T) { + b := newNeonBackend("k", "") + if b.regionID != defaultNeonRegion { + t.Errorf("regionID = %q; want default", b.regionID) + } + b2 := newNeonBackend("k", "eu-x") + if b2.regionID != "eu-x" { + t.Errorf("regionID = %q; want eu-x", b2.regionID) + } +} + +func TestNeon_Provision_CreateSuccess_FullDecode(t *testing.T) { + b := neonTestServer(t, func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + _ = json.NewEncoder(w).Encode(map[string]any{"projects": []any{}}) + return + } + w.WriteHeader(http.StatusCreated) + _ = json.NewEncoder(w).Encode(map[string]any{ + "project": map[string]string{"id": "p1"}, + "connection_uris": []map[string]string{{"connection_uri": "postgres://c"}}, + }) + }) + creds, err := b.Provision(context.Background(), "tok", "team", -1) + if err != nil { + t.Fatalf("Provision: %v", err) + } + if creds.ProviderResourceID != "p1" || creds.URL != "postgres://c" { + t.Errorf("creds = %+v", creds) + } +} + +func TestNeon_Provision_MarshalError(t *testing.T) { + b := neonTestServer(t, func(w http.ResponseWriter, r *http.Request) { + _ = json.NewEncoder(w).Encode(map[string]any{"projects": []any{}}) + }) + orig := jsonMarshal + jsonMarshal = func(any) ([]byte, error) { return nil, errSeam } + t.Cleanup(func() { jsonMarshal = orig }) + if _, err := b.Provision(context.Background(), "tok", "team", -1); err == nil { + t.Error("expected marshal error") + } +} + +func TestNeon_Provision_NewRequestError(t *testing.T) { + b := neonTestServer(t, func(w http.ResponseWriter, r *http.Request) { + _ = json.NewEncoder(w).Encode(map[string]any{"projects": []any{}}) + }) + orig := httpNewRequestWithContext + httpNewRequestWithContext = func(context.Context, string, string, io.Reader) (*http.Request, error) { + return nil, errSeam + } + t.Cleanup(func() { httpNewRequestWithContext = orig }) + if _, err := b.Provision(context.Background(), "tok", "team", -1); err == nil { + t.Error("expected new-request error") + } +} + +func TestNeon_Provision_HTTPError(t *testing.T) { + // Point at an unroutable base so client.Do fails. + b := &NeonBackend{apiKey: "k", client: &http.Client{}, apiBase: "http://127.0.0.1:1"} + if _, err := b.Provision(context.Background(), "tok", "team", -1); err == nil { + t.Error("expected http error") + } +} + +func TestNeon_Provision_Non2xx(t *testing.T) { + b := neonTestServer(t, func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + _ = json.NewEncoder(w).Encode(map[string]any{"projects": []any{}}) + return + } + http.Error(w, "boom", http.StatusInternalServerError) + }) + if _, err := b.Provision(context.Background(), "tok", "team", -1); err == nil { + t.Error("expected non-2xx error") + } +} + +func TestNeon_Provision_ReadBodyError(t *testing.T) { + b := neonTestServer(t, func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + _ = json.NewEncoder(w).Encode(map[string]any{"projects": []any{}}) + return + } + w.WriteHeader(http.StatusCreated) + _, _ = w.Write([]byte("{}")) + }) + orig := ioReadAll + var n int + ioReadAll = func(r io.Reader) ([]byte, error) { + n++ + if n >= 2 { // first ReadAll is the GET list; fail on the POST create body + return nil, errSeam + } + return io.ReadAll(r) + } + t.Cleanup(func() { ioReadAll = orig }) + if _, err := b.Provision(context.Background(), "tok", "team", -1); err == nil { + t.Error("expected read-body error") + } +} + +func TestNeon_Provision_UnmarshalError(t *testing.T) { + b := neonTestServer(t, func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + _ = json.NewEncoder(w).Encode(map[string]any{"projects": []any{}}) + return + } + w.WriteHeader(http.StatusCreated) + _, _ = w.Write([]byte("not-json")) + }) + if _, err := b.Provision(context.Background(), "tok", "team", -1); err == nil { + t.Error("expected unmarshal error") + } +} + +func TestNeon_Provision_EmptyProjectID(t *testing.T) { + b := neonTestServer(t, func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + _ = json.NewEncoder(w).Encode(map[string]any{"projects": []any{}}) + return + } + w.WriteHeader(http.StatusCreated) + _ = json.NewEncoder(w).Encode(map[string]any{"project": map[string]string{"id": ""}}) + }) + if _, err := b.Provision(context.Background(), "tok", "team", -1); err == nil { + t.Error("expected empty-project-id error") + } +} + +func TestNeon_Provision_NoConnectionURI(t *testing.T) { + b := neonTestServer(t, func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + _ = json.NewEncoder(w).Encode(map[string]any{"projects": []any{}}) + return + } + w.WriteHeader(http.StatusCreated) + _ = json.NewEncoder(w).Encode(map[string]any{"project": map[string]string{"id": "p1"}}) + }) + if _, err := b.Provision(context.Background(), "tok", "team", -1); err == nil { + t.Error("expected no-connection-uri error") + } +} + +func TestNeon_StorageBytes_Success(t *testing.T) { + b := neonTestServer(t, func(w http.ResponseWriter, r *http.Request) { + _ = json.NewEncoder(w).Encode(map[string]any{ + "project": map[string]any{"usage": map[string]any{"data_storage_bytes_hour": 9999}}, + }) + }) + got, err := b.StorageBytes(context.Background(), "tok", "p1") + if err != nil { + t.Fatalf("StorageBytes: %v", err) + } + if got != 9999 { + t.Errorf("got %d; want 9999", got) + } +} + +func TestNeon_StorageBytes_EmptyPRID(t *testing.T) { + b := &NeonBackend{} + if _, err := b.StorageBytes(context.Background(), "tok", ""); err == nil { + t.Error("expected empty-PRID error") + } +} + +func TestNeon_StorageBytes_NewRequestError(t *testing.T) { + orig := httpNewRequestWithContext + httpNewRequestWithContext = func(context.Context, string, string, io.Reader) (*http.Request, error) { + return nil, errSeam + } + t.Cleanup(func() { httpNewRequestWithContext = orig }) + b := &NeonBackend{apiKey: "k", client: &http.Client{}, apiBase: "http://x"} + if _, err := b.StorageBytes(context.Background(), "tok", "p1"); err == nil { + t.Error("expected new-request error") + } +} + +func TestNeon_StorageBytes_HTTPError(t *testing.T) { + b := &NeonBackend{apiKey: "k", client: &http.Client{}, apiBase: "http://127.0.0.1:1"} + if _, err := b.StorageBytes(context.Background(), "tok", "p1"); err == nil { + t.Error("expected http error") + } +} + +func TestNeon_StorageBytes_Non2xx(t *testing.T) { + b := neonTestServer(t, func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "no", http.StatusNotFound) + }) + if _, err := b.StorageBytes(context.Background(), "tok", "p1"); err == nil { + t.Error("expected non-2xx error") + } +} + +func TestNeon_StorageBytes_ReadBodyError(t *testing.T) { + b := neonTestServer(t, func(w http.ResponseWriter, r *http.Request) { _, _ = w.Write([]byte("{}")) }) + orig := ioReadAll + ioReadAll = func(io.Reader) ([]byte, error) { return nil, errSeam } + t.Cleanup(func() { ioReadAll = orig }) + if _, err := b.StorageBytes(context.Background(), "tok", "p1"); err == nil { + t.Error("expected read-body error") + } +} + +func TestNeon_StorageBytes_UnmarshalError(t *testing.T) { + b := neonTestServer(t, func(w http.ResponseWriter, r *http.Request) { _, _ = w.Write([]byte("xx")) }) + if _, err := b.StorageBytes(context.Background(), "tok", "p1"); err == nil { + t.Error("expected unmarshal error") + } +} + +func TestNeon_Deprovision_Success(t *testing.T) { + b := neonTestServer(t, func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) }) + if err := b.Deprovision(context.Background(), "tok", "p1"); err != nil { + t.Fatalf("Deprovision: %v", err) + } +} + +func TestNeon_Deprovision_EmptyPRID(t *testing.T) { + b := &NeonBackend{} + if err := b.Deprovision(context.Background(), "tok", ""); err == nil { + t.Error("expected empty-PRID error") + } +} + +func TestNeon_Deprovision_NewRequestError(t *testing.T) { + orig := httpNewRequestWithContext + httpNewRequestWithContext = func(context.Context, string, string, io.Reader) (*http.Request, error) { + return nil, errSeam + } + t.Cleanup(func() { httpNewRequestWithContext = orig }) + b := &NeonBackend{apiKey: "k", client: &http.Client{}, apiBase: "http://x"} + if err := b.Deprovision(context.Background(), "tok", "p1"); err == nil { + t.Error("expected new-request error") + } +} + +func TestNeon_Deprovision_HTTPError(t *testing.T) { + b := &NeonBackend{apiKey: "k", client: &http.Client{}, apiBase: "http://127.0.0.1:1"} + if err := b.Deprovision(context.Background(), "tok", "p1"); err == nil { + t.Error("expected http error") + } +} + +func TestNeon_Deprovision_Non2xx(t *testing.T) { + b := neonTestServer(t, func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "no", http.StatusBadRequest) + }) + if err := b.Deprovision(context.Background(), "tok", "p1"); err == nil { + t.Error("expected non-2xx error") + } +} + +func TestNeon_Regrade_NoOp(t *testing.T) { + b := &NeonBackend{} + res, err := b.Regrade(context.Background(), "tok", "p1", 5) + if err != nil || res.Applied { + t.Errorf("Regrade = %+v, %v; want no-op skip", res, err) + } +} + +func TestNeon_findProjectByName_NewRequestError(t *testing.T) { + orig := httpNewRequestWithContext + httpNewRequestWithContext = func(context.Context, string, string, io.Reader) (*http.Request, error) { + return nil, errSeam + } + t.Cleanup(func() { httpNewRequestWithContext = orig }) + b := &NeonBackend{apiKey: "k", client: &http.Client{}, apiBase: "http://x"} + if _, err := b.findProjectByName(context.Background(), "n"); err == nil { + t.Error("expected new-request error") + } +} + +func TestNeon_findProjectByName_HTTPError(t *testing.T) { + b := &NeonBackend{apiKey: "k", client: &http.Client{}, apiBase: "http://127.0.0.1:1"} + if _, err := b.findProjectByName(context.Background(), "n"); err == nil { + t.Error("expected http error") + } +} + +func TestNeon_findProjectByName_Non2xx(t *testing.T) { + b := neonTestServer(t, func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "no", http.StatusForbidden) + }) + if _, err := b.findProjectByName(context.Background(), "n"); err == nil { + t.Error("expected non-2xx error") + } +} + +func TestNeon_findProjectByName_ReadBodyError(t *testing.T) { + b := neonTestServer(t, func(w http.ResponseWriter, r *http.Request) { _, _ = w.Write([]byte("{}")) }) + orig := ioReadAll + ioReadAll = func(io.Reader) ([]byte, error) { return nil, errSeam } + t.Cleanup(func() { ioReadAll = orig }) + if _, err := b.findProjectByName(context.Background(), "n"); err == nil { + t.Error("expected read-body error") + } +} + +func TestNeon_findProjectByName_UnmarshalError(t *testing.T) { + b := neonTestServer(t, func(w http.ResponseWriter, r *http.Request) { _, _ = w.Write([]byte("xx")) }) + if _, err := b.findProjectByName(context.Background(), "n"); err == nil { + t.Error("expected unmarshal error") + } +} + +func TestNeon_findProjectByName_NotFoundReturnsEmpty(t *testing.T) { + b := neonTestServer(t, func(w http.ResponseWriter, r *http.Request) { + _ = json.NewEncoder(w).Encode(map[string]any{ + "projects": []map[string]string{{"id": "other", "name": "instant-zzz"}}, + }) + }) + id, err := b.findProjectByName(context.Background(), "instant-target") + if err != nil || id != "" { + t.Errorf("findProjectByName = %q, %v; want empty, nil", id, err) + } +} + +// Provision lookup-error path: findProjectByName errors → Provision logs and +// proceeds to create. +func TestNeon_Provision_LookupError_ProceedsToCreate(t *testing.T) { + var created bool + b := neonTestServer(t, func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + http.Error(w, "list-fail", http.StatusInternalServerError) + return + } + created = true + w.WriteHeader(http.StatusCreated) + _ = json.NewEncoder(w).Encode(map[string]any{ + "project": map[string]string{"id": "p9"}, + "connection_uris": []map[string]string{{"connection_uri": "postgres://u"}}, + }) + }) + if _, err := b.Provision(context.Background(), "tok", "team", -1); err != nil { + t.Fatalf("Provision: %v", err) + } + if !created { + t.Error("expected create to proceed after lookup error") + } +} + +var _ = errors.New diff --git a/internal/backend/postgres/seams.go b/internal/backend/postgres/seams.go new file mode 100644 index 0000000..bb2dafc --- /dev/null +++ b/internal/backend/postgres/seams.go @@ -0,0 +1,70 @@ +package postgres + +// seams.go — test seams for the postgres backends. +// +// These package-level function variables and the narrow pgConn interface let +// the test suite drive the error and success branches of code paths that would +// otherwise require a live Postgres cluster, a working crypto/rand, or a +// successful net/http construction. In production every seam defaults to the +// real implementation, so behaviour is identical to calling pgx.Connect / +// rand.Read / json.Marshal / http.NewRequestWithContext directly. +// +// Why a seam instead of a live cluster: the provisioner's coverage CI job is +// mock-only (no service containers — see .github/workflows/coverage.yml). The +// redis backend reaches ≥95% the same way: fakes + connection-failure branches. +// For postgres the SQL happy path (CREATE DATABASE / CREATE USER success) is +// the bulk of the statements, so a fake pgConn is the only way to execute those +// lines deterministically in CI without a real database. + +import ( + "context" + "crypto/rand" + "encoding/json" + "io" + "math/big" + "net/http" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" +) + +// pgConn is the narrow subset of *pgx.Conn that the postgres backends use. +// *pgx.Conn satisfies it; tests inject a fake. +type pgConn interface { + Exec(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error) + QueryRow(ctx context.Context, sql string, args ...any) pgx.Row + Close(ctx context.Context) error +} + +// pgxConnect is the seam for pgx.Connect. Production wraps the real connection +// so it satisfies the pgConn interface. A test overrides this var to return a +// fake pgConn (success path) or an error (connect-failure path). +var pgxConnect = func(ctx context.Context, connString string) (pgConn, error) { + c, err := pgx.Connect(ctx, connString) + if err != nil { + return nil, err + } + return c, nil +} + +// randRead is the seam for crypto/rand.Read. Overriding it to return an error +// covers the rand-failure branch in k8sRandHex without weakening production +// (which uses the real CSPRNG). +var randRead = rand.Read + +// randInt is the seam for crypto/rand.Int, used by generatePassword. A test +// overrides it to force the rand-failure branch. +var randInt func(rand io.Reader, max *big.Int) (*big.Int, error) = rand.Int + +// jsonMarshal is the seam for encoding/json.Marshal, used by the Neon/dedicated +// request bodies. Overriding it forces the marshal-error wrap branch. +var jsonMarshal = json.Marshal + +// httpNewRequestWithContext is the seam for http.NewRequestWithContext, used by +// every Neon/dedicated API call. Overriding it forces the construction-error +// wrap branch (otherwise unreachable — a valid method + URL never errors). +var httpNewRequestWithContext = http.NewRequestWithContext + +// ioReadAll is the seam for io.ReadAll on an HTTP response body. Overriding it +// forces the read-body-error wrap branch. +var ioReadAll = io.ReadAll diff --git a/internal/backend/postgres/seams_test.go b/internal/backend/postgres/seams_test.go new file mode 100644 index 0000000..a5b0956 --- /dev/null +++ b/internal/backend/postgres/seams_test.go @@ -0,0 +1,101 @@ +package postgres + +// seams_test.go — shared test doubles for the seam-driven coverage tests. +// +// fakePGConn implements the pgConn interface so the SQL happy paths and each +// individual Exec/QueryRow error branch can be exercised in CI without a live +// Postgres cluster (the provisioner coverage job is mock-only). withPGXConnect +// and the rand/json/http overrides install/restore the package seams. + +import ( + "context" + "errors" + "strings" + "testing" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" +) + +// fakePGConn is a programmable pgConn double. +type fakePGConn struct { + // execErrFor returns a non-nil error when the executed SQL contains the + // substring key. The first matching key wins. nil → success. + execErrFor map[string]error + // queryRowErr is returned by the Row.Scan of any QueryRow call. + queryRowErr error + // scanInt64 is written into the first *int64 Scan destination on success. + scanInt64 int64 + // closeErr is returned by Close. + closeErr error + + execCalls []string + queryCalls []string + closeCalls int +} + +func (f *fakePGConn) Exec(_ context.Context, sql string, _ ...any) (pgconn.CommandTag, error) { + f.execCalls = append(f.execCalls, sql) + for sub, err := range f.execErrFor { + if err != nil && strings.Contains(sql, sub) { + return pgconn.CommandTag{}, err + } + } + return pgconn.CommandTag{}, nil +} + +func (f *fakePGConn) QueryRow(_ context.Context, sql string, _ ...any) pgx.Row { + f.queryCalls = append(f.queryCalls, sql) + return &fakeRow{err: f.queryRowErr, v: f.scanInt64} +} + +func (f *fakePGConn) Close(_ context.Context) error { + f.closeCalls++ + return f.closeErr +} + +// fakeRow implements pgx.Row. +type fakeRow struct { + err error + v int64 +} + +func (r *fakeRow) Scan(dest ...any) error { + if r.err != nil { + return r.err + } + if len(dest) > 0 { + switch p := dest[0].(type) { + case *int64: + *p = r.v + case *int: + *p = int(r.v) + } + } + return nil +} + +// withPGXConnect installs a pgxConnect seam returning the given conn (or err) +// and restores the original on test cleanup. +func withPGXConnect(t *testing.T, conn pgConn, err error) { + t.Helper() + orig := pgxConnect + pgxConnect = func(context.Context, string) (pgConn, error) { + if err != nil { + return nil, err + } + return conn, nil + } + t.Cleanup(func() { pgxConnect = orig }) +} + +// withPGXConnectFunc installs a fully custom pgxConnect seam (e.g. to vary the +// returned conn by call order) and restores it on cleanup. +func withPGXConnectFunc(t *testing.T, fn func(ctx context.Context, dsn string) (pgConn, error)) { + t.Helper() + orig := pgxConnect + pgxConnect = fn + t.Cleanup(func() { pgxConnect = orig }) +} + +var errSeam = errors.New("seam-induced error")