diff --git a/internal/providers/db/local.go b/internal/providers/db/local.go index b89ea3e..998d036 100644 --- a/internal/providers/db/local.go +++ b/internal/providers/db/local.go @@ -8,6 +8,7 @@ import ( "context" "crypto/rand" "fmt" + "io" "log/slog" "math/big" @@ -19,6 +20,33 @@ const defaultCustomersURL = "postgres://instant_cust:instant_cust@postgres-custo // alphanumChars is the charset for generated passwords. const alphanumChars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" +// randReader is the source of entropy for `generatePassword`. Production +// code uses crypto/rand.Reader. Tests swap it for an always-failing +// reader to exercise the rand-error branch without disturbing the rest +// of the process. Tests MUST restore the previous value via a deferred +// reassignment. +var randReader io.Reader = rand.Reader + +// newDBConnect is the package-level seam for the post-CREATE connect- +// to-new-DB step in Provision. Defaults to pgx.Connect; tests override +// it to inject a fake-connect failure so the "couldn't reach new DB" +// branches in Provision can be exercised without the gymnastics of +// constructing a real-looking URL that fails only on the second +// connect. +var newDBConnect = func(ctx context.Context, url string) (*pgx.Conn, error) { + return pgx.Connect(ctx, url) +} + +// adminConnect is the package-level seam for the FIRST connect in +// each of Provision / StorageBytes / Deprovision — the admin +// connection to b.customersURL. Defaults to pgx.Connect; tests can +// swap it to inject a pre-terminated connection so the +// terminate_backend / DROP DATABASE / disconnect error branches can +// be exercised deterministically. +var adminConnect = func(ctx context.Context, url string) (*pgx.Conn, error) { + return pgx.Connect(ctx, url) +} + // LocalBackend provisions databases on the shared postgres-customers instance. type LocalBackend struct { customersURL string // admin connection URL @@ -33,11 +61,13 @@ func newLocalBackend(customersURL string) *LocalBackend { } // generatePassword returns a cryptographically random alphanumeric string of length n. +// Entropy comes from the package-level `randReader` (see definition); +// tests can swap it for a failing reader to exercise the error branch. func generatePassword(n int) (string, error) { buf := make([]byte, n) charsetLen := big.NewInt(int64(len(alphanumChars))) for i := range buf { - idx, err := rand.Int(rand.Reader, charsetLen) + idx, err := rand.Int(randReader, charsetLen) if err != nil { return "", fmt.Errorf("generatePassword: %w", err) } @@ -72,7 +102,7 @@ func (b *LocalBackend) ProvisionWithExtensions(ctx context.Context, token, tier } // Connect as admin. - conn, err := pgx.Connect(ctx, b.customersURL) + conn, err := adminConnect(ctx, b.customersURL) if err != nil { return nil, fmt.Errorf("db.local.Provision: connect: %w", err) } @@ -110,7 +140,7 @@ func (b *LocalBackend) ProvisionWithExtensions(ctx context.Context, token, tier // run as a superuser/admin, not the per-token user (which lacks // CREATE-on-pg_catalog privileges). newDBURL := b.buildDBURL(username, pass, dbName) - adminNewDB, err := pgx.Connect(ctx, b.buildAdminNewDBURL(dbName)) + adminNewDB, err := newDBConnect(ctx, b.buildAdminNewDBURL(dbName)) if err != nil { slog.Error("db.local.Provision: connect new db for schema grant (non-fatal)", "error", err) // If extensions were requested and we couldn't connect to the new @@ -158,7 +188,7 @@ func (b *LocalBackend) ProvisionWithExtensions(ctx context.Context, token, tier // StorageBytes returns the size of db_{token} in bytes using pg_database_size. func (b *LocalBackend) StorageBytes(ctx context.Context, token, providerResourceID string) (int64, error) { - conn, err := pgx.Connect(ctx, b.customersURL) + conn, err := adminConnect(ctx, b.customersURL) if err != nil { return 0, fmt.Errorf("db.local.StorageBytes: connect: %w", err) } @@ -181,7 +211,7 @@ func (b *LocalBackend) Deprovision(ctx context.Context, token, providerResourceI dbName := "db_" + token username := "usr_" + token - conn, err := pgx.Connect(ctx, b.customersURL) + conn, err := adminConnect(ctx, b.customersURL) if err != nil { return fmt.Errorf("db.local.Deprovision: connect: %w", err) } diff --git a/internal/providers/db/local_integration_test.go b/internal/providers/db/local_integration_test.go new file mode 100644 index 0000000..03b628d --- /dev/null +++ b/internal/providers/db/local_integration_test.go @@ -0,0 +1,563 @@ +package db + +// Integration tests for LocalBackend — these talk to a real Postgres +// instance via the TEST_POSTGRES_CUSTOMERS_URL env var. CI provides this +// via the postgres service container in .github/workflows; local devs +// can point it at any postgres where the connecting role is SUPERUSER +// (CREATE DATABASE + CREATE ROLE both require superuser privileges). +// +// The tests skip gracefully when the env var is unset OR when the +// connection itself fails, so they don't break developers running the +// gate without Docker. + +import ( + "context" + "fmt" + "math/rand" + "os" + "strings" + "testing" + "time" + + "github.com/jackc/pgx/v5" +) + +// requireLocalBackend pulls TEST_POSTGRES_CUSTOMERS_URL from the env and +// builds a LocalBackend. If the env var is unset, the test skips. +// If the connection fails (e.g. the test-pg container died), the test +// also skips with a clear message — we want red builds to mean a +// regression, not a flaky local env. +func requireLocalBackend(t *testing.T) (*LocalBackend, string) { + t.Helper() + url := os.Getenv("TEST_POSTGRES_CUSTOMERS_URL") + if url == "" { + t.Skip("TEST_POSTGRES_CUSTOMERS_URL not set; skipping LocalBackend integration test") + } + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + conn, err := pgx.Connect(ctx, url) + if err != nil { + t.Skipf("TEST_POSTGRES_CUSTOMERS_URL unreachable (%v); skipping", err) + } + _ = conn.Close(ctx) + return newLocalBackend(url), url +} + +// genToken returns a short, valid-identifier token unique to this run. +// We avoid using a full uuid because CREATE DATABASE quotes the name — +// `db_uuid` works but the short prefix path is the common production +// shape. Mix a low-entropy random suffix so concurrent test runs against +// the same Postgres don't collide. +func genToken(t *testing.T) string { + t.Helper() + r := rand.New(rand.NewSource(time.Now().UnixNano())) + return fmt.Sprintf("itest%d%d", time.Now().UnixNano()%1_000_000, r.Intn(10_000)) +} + +// cleanupResources tears down a token's database + user. Run via t.Cleanup +// so a panicking test still drops its leftovers — important when a shared +// Postgres is reused across many runs. +func cleanupResources(t *testing.T, adminURL, token string) { + t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + conn, err := pgx.Connect(ctx, adminURL) + if err != nil { + t.Logf("cleanup: connect: %v", err) + return + } + defer conn.Close(ctx) + // kill leftover sessions, then drop. + _, _ = conn.Exec(ctx, + "SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE datname=$1 AND pid<>pg_backend_pid()", + "db_"+token, + ) + _, _ = conn.Exec(ctx, fmt.Sprintf(`DROP DATABASE IF EXISTS %q`, "db_"+token)) + _, _ = conn.Exec(ctx, fmt.Sprintf(`DROP USER IF EXISTS %q`, "usr_"+token)) +} + +// TestLocal_Provision_HappyPath — creates a database, the user can +// connect (via the URL returned), the database is visible via the admin +// URL. +func TestLocal_Provision_HappyPath(t *testing.T) { + b, adminURL := requireLocalBackend(t) + token := genToken(t) + t.Cleanup(func() { cleanupResources(t, adminURL, token) }) + + creds, err := b.Provision(context.Background(), token, "free") + if err != nil { + t.Fatalf("Provision: %v", err) + } + if creds.DatabaseName != "db_"+token { + t.Errorf("DatabaseName: got %q", creds.DatabaseName) + } + if creds.Username != "usr_"+token { + t.Errorf("Username: got %q", creds.Username) + } + if !strings.HasPrefix(creds.URL, "postgres://usr_"+token+":") { + t.Errorf("URL prefix: got %q", creds.URL) + } + if creds.ProviderResourceID != "" { + t.Errorf("ProviderResourceID: want empty, got %q", creds.ProviderResourceID) + } + + // The new user can connect to its database via the returned URL. + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + // Some test environments don't route the admin host externally — + // rewrite the host portion of creds.URL to match the admin URL host + // before dialing. + dialURL := strings.Replace(creds.URL, extractHost(creds.URL), extractHost(adminURL), 1) + conn, err := pgx.Connect(ctx, dialURL+"?sslmode=disable") + if err != nil { + t.Fatalf("user-facing dial: %v (url=%s)", err, dialURL) + } + defer conn.Close(ctx) + var n int + if err := conn.QueryRow(ctx, "SELECT 1").Scan(&n); err != nil || n != 1 { + t.Fatalf("SELECT 1 as new user: n=%d err=%v", n, err) + } +} + +// TestLocal_ProvisionWithExtensions_Vector — passing the "vector" +// extension installs pgvector in the new database (or surfaces the +// CREATE EXTENSION error if the cluster lacks the package). +func TestLocal_ProvisionWithExtensions_Vector(t *testing.T) { + b, adminURL := requireLocalBackend(t) + token := genToken(t) + t.Cleanup(func() { cleanupResources(t, adminURL, token) }) + + _, err := b.ProvisionWithExtensions(context.Background(), token, "pro", []string{"vector"}) + if err != nil { + // Most plain `postgres:16-alpine` images do NOT ship pgvector; + // in that case CREATE EXTENSION fails with the contrib-not- + // found error, which exercises the extension-error branch. + // Either outcome (success or this specific failure) is + // acceptable for the integration; what we don't tolerate is a + // silent passthrough. + if !strings.Contains(err.Error(), "CREATE EXTENSION") && + !strings.Contains(err.Error(), "extension") { + t.Fatalf("unexpected error: %v", err) + } + return + } + + // If extension install succeeded, confirm it via pg_extension. + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + newAdmin, err := pgx.Connect(ctx, b.buildAdminNewDBURL("db_"+token)) + if err != nil { + t.Fatalf("new-db admin connect: %v", err) + } + defer newAdmin.Close(ctx) + var has bool + if err := newAdmin.QueryRow(ctx, "SELECT EXISTS (SELECT 1 FROM pg_extension WHERE extname='vector')").Scan(&has); err != nil { + t.Fatalf("pg_extension probe: %v", err) + } + if !has { + t.Fatal("vector extension not installed despite no error") + } +} + +// TestLocal_ProvisionWithExtensions_RejectedExtension — disallowed +// extensions fail validation before any DDL runs. +func TestLocal_ProvisionWithExtensions_RejectedExtension(t *testing.T) { + b, _ := requireLocalBackend(t) + _, err := b.ProvisionWithExtensions(context.Background(), "tok-x", "pro", []string{"postgis"}) + if err == nil || !strings.Contains(err.Error(), "allowlist") { + t.Fatalf("want allowlist error, got %v", err) + } +} + +// TestLocal_Provision_DuplicateFails — second Provision call with the +// same token should fail at CREATE DATABASE (the database name already +// exists). +func TestLocal_Provision_DuplicateFails(t *testing.T) { + b, adminURL := requireLocalBackend(t) + token := genToken(t) + t.Cleanup(func() { cleanupResources(t, adminURL, token) }) + + if _, err := b.Provision(context.Background(), token, "free"); err != nil { + t.Fatalf("first Provision: %v", err) + } + if _, err := b.Provision(context.Background(), token, "free"); err == nil || + !strings.Contains(err.Error(), "CREATE DATABASE") { + t.Fatalf("duplicate provision: want CREATE DATABASE error, got %v", err) + } +} + +// TestLocal_Provision_ConnectFails — bad admin URL surfaces as a +// connect-error from Provision (covers the early-return branch). +func TestLocal_Provision_ConnectFails(t *testing.T) { + if os.Getenv("TEST_POSTGRES_CUSTOMERS_URL") == "" { + t.Skip("TEST_POSTGRES_CUSTOMERS_URL not set; skipping") + } + b := newLocalBackend("postgres://nobody:nobody@127.0.0.1:1/none?sslmode=disable&connect_timeout=1") + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + if _, err := b.Provision(ctx, "tok-deadhost", "free"); err == nil || + !strings.Contains(err.Error(), "connect") { + t.Fatalf("want connect error, got %v", err) + } +} + +// TestLocal_StorageBytes_HappyPath — pg_database_size returns a small +// positive integer for the just-created database. +func TestLocal_StorageBytes_HappyPath(t *testing.T) { + b, adminURL := requireLocalBackend(t) + token := genToken(t) + t.Cleanup(func() { cleanupResources(t, adminURL, token) }) + + if _, err := b.Provision(context.Background(), token, "free"); err != nil { + t.Fatalf("Provision: %v", err) + } + n, err := b.StorageBytes(context.Background(), token, "") + if err != nil { + t.Fatalf("StorageBytes: %v", err) + } + if n <= 0 { + t.Fatalf("StorageBytes: got %d, want > 0", n) + } +} + +// TestLocal_StorageBytes_MissingDB — pg_database_size on a non-existent +// db returns a query-execution error (the missing-DB branch). +func TestLocal_StorageBytes_MissingDB(t *testing.T) { + b, _ := requireLocalBackend(t) + if _, err := b.StorageBytes(context.Background(), "no-such-token-"+genToken(t), ""); err == nil || + !strings.Contains(err.Error(), "pg_database_size") { + t.Fatalf("want pg_database_size error, got %v", err) + } +} + +// TestLocal_StorageBytes_ConnectFails — bad admin URL surfaces as a +// connect-error from StorageBytes. +func TestLocal_StorageBytes_ConnectFails(t *testing.T) { + if os.Getenv("TEST_POSTGRES_CUSTOMERS_URL") == "" { + t.Skip("TEST_POSTGRES_CUSTOMERS_URL not set; skipping") + } + b := newLocalBackend("postgres://nobody:nobody@127.0.0.1:1/none?sslmode=disable&connect_timeout=1") + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + if _, err := b.StorageBytes(ctx, "tok", ""); err == nil || + !strings.Contains(err.Error(), "connect") { + t.Fatalf("want connect error, got %v", err) + } +} + +// TestLocal_Deprovision_HappyPath — Provision then Deprovision; database +// must be gone and user must be gone afterwards. +func TestLocal_Deprovision_HappyPath(t *testing.T) { + b, adminURL := requireLocalBackend(t) + token := genToken(t) + t.Cleanup(func() { cleanupResources(t, adminURL, token) }) + + if _, err := b.Provision(context.Background(), token, "free"); err != nil { + t.Fatalf("Provision: %v", err) + } + if err := b.Deprovision(context.Background(), token, ""); err != nil { + t.Fatalf("Deprovision: %v", err) + } + + // Probe: database and user must be gone. + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + conn, err := pgx.Connect(ctx, adminURL) + if err != nil { + t.Fatalf("admin reconnect: %v", err) + } + defer conn.Close(ctx) + + var dbCount int + if err := conn.QueryRow(ctx, "SELECT count(*) FROM pg_database WHERE datname=$1", "db_"+token).Scan(&dbCount); err != nil { + t.Fatalf("pg_database probe: %v", err) + } + if dbCount != 0 { + t.Fatalf("db still present after Deprovision: count=%d", dbCount) + } + var userCount int + if err := conn.QueryRow(ctx, "SELECT count(*) FROM pg_user WHERE usename=$1", "usr_"+token).Scan(&userCount); err != nil { + t.Fatalf("pg_user probe: %v", err) + } + if userCount != 0 { + t.Fatalf("user still present after Deprovision: count=%d", userCount) + } +} + +// TestLocal_Deprovision_NoSuchToken — DROP IF EXISTS makes Deprovision +// idempotent: calling it on a token that was never provisioned succeeds. +func TestLocal_Deprovision_NoSuchToken(t *testing.T) { + b, _ := requireLocalBackend(t) + if err := b.Deprovision(context.Background(), "never-existed-"+genToken(t), ""); err != nil { + t.Fatalf("idempotent Deprovision: %v", err) + } +} + +// TestLocal_Deprovision_ConnectFails — bad admin URL surfaces as a +// connect-error from Deprovision. +func TestLocal_Deprovision_ConnectFails(t *testing.T) { + if os.Getenv("TEST_POSTGRES_CUSTOMERS_URL") == "" { + t.Skip("TEST_POSTGRES_CUSTOMERS_URL not set; skipping") + } + b := newLocalBackend("postgres://nobody:nobody@127.0.0.1:1/none?sslmode=disable&connect_timeout=1") + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + if err := b.Deprovision(ctx, "tok", ""); err == nil || + !strings.Contains(err.Error(), "connect") { + t.Fatalf("want connect error, got %v", err) + } +} + +// TestLocal_Provision_GrantSchemaFails_AndCloseError — the +// `newDBConnect` seam returns a connection that we close BEFORE +// Provision's GRANT SCHEMA runs, so: +// (1) the GRANT SCHEMA exec fails -> exercises the "GRANT SCHEMA +// (non-fatal)" branch +// (2) the deferred conn.Close runs on the already-closed conn -> +// exercises the "disconnect new db" defer-error branch +// +// Provision must still return success because the GRANT SCHEMA error +// is logged-and-continued. +func TestLocal_Provision_GrantSchemaFails_AndCloseError(t *testing.T) { + b, adminURL := requireLocalBackend(t) + token := genToken(t) + t.Cleanup(func() { cleanupResources(t, adminURL, token) }) + + orig := newDBConnect + t.Cleanup(func() { newDBConnect = orig }) + newDBConnect = func(ctx context.Context, url string) (*pgx.Conn, error) { + c, err := pgx.Connect(ctx, url) + if err != nil { + return nil, err + } + // Close the connection immediately. The caller will try to + // Exec on it (-> error) and Close again in defer (-> error). + _ = c.Close(context.Background()) + return c, nil + } + + creds, err := b.Provision(context.Background(), token, "free") + if err != nil { + t.Fatalf("Provision should still succeed: %v", err) + } + if creds.DatabaseName != "db_"+token { + t.Fatalf("DatabaseName: got %q", creds.DatabaseName) + } +} + +// TestLocal_Provision_NewDBConnectFails_NoExtensions — when the +// post-CREATE connect to the new database fails AND no extensions were +// requested, Provision logs and continues (returns success without the +// schema-grant). Uses the `newDBConnect` package seam so the failure +// is deterministic. +func TestLocal_Provision_NewDBConnectFails_NoExtensions(t *testing.T) { + b, adminURL := requireLocalBackend(t) + token := genToken(t) + t.Cleanup(func() { cleanupResources(t, adminURL, token) }) + + orig := newDBConnect + t.Cleanup(func() { newDBConnect = orig }) + newDBConnect = func(_ context.Context, _ string) (*pgx.Conn, error) { + return nil, fmt.Errorf("simulated new-db connect failure") + } + + creds, err := b.Provision(context.Background(), token, "free") + if err != nil { + t.Fatalf("Provision should still succeed: %v", err) + } + if creds.DatabaseName != "db_"+token { + t.Fatalf("DatabaseName: got %q", creds.DatabaseName) + } +} + +// TestLocal_Provision_NewDBConnectFails_WithExtensions — same seam, +// but with the "vector" extension requested. The new-db connect +// failure is fatal because we can't install the requested extension. +func TestLocal_Provision_NewDBConnectFails_WithExtensions(t *testing.T) { + b, adminURL := requireLocalBackend(t) + token := genToken(t) + t.Cleanup(func() { cleanupResources(t, adminURL, token) }) + + orig := newDBConnect + t.Cleanup(func() { newDBConnect = orig }) + newDBConnect = func(_ context.Context, _ string) (*pgx.Conn, error) { + return nil, fmt.Errorf("simulated new-db connect failure") + } + + _, err := b.ProvisionWithExtensions(context.Background(), token, "pro", []string{"vector"}) + if err == nil || !strings.Contains(err.Error(), "install extensions") { + t.Fatalf("want install-extensions error wrap, got %v", err) + } +} + +// TestLocal_Provision_GeneratePasswordFails — swap the package-level +// randReader to one that always errors so Provision's password-gen +// step fails and returns before any DB work happens. Exercises the +// `generatePassword: ...` wrap at line 70-72. +func TestLocal_Provision_GeneratePasswordFails(t *testing.T) { + orig := randReader + t.Cleanup(func() { randReader = orig }) + randReader = failingReader{err: errFakeRand} + + b := newLocalBackend("postgres://i:i@localhost:5432/x?sslmode=disable") + if _, err := b.Provision(context.Background(), "tok", "free"); err == nil || + !strings.Contains(err.Error(), "generatePassword") { + t.Fatalf("want generatePassword error wrap, got %v", err) + } +} + +// TestLocal_Deprovision_AdminConnPreKilled — the `adminConnect` seam +// returns a real connection that we've already self-terminated. The +// connection is "logged in" as far as the Go side knows, but every +// Exec returns "conn closed". This exercises (a) the +// terminate_backend `logged-and-continued` branch and (b) the +// DROP DATABASE error-return branch in one shot. +func TestLocal_Deprovision_AdminConnPreKilled(t *testing.T) { + b, adminURL := requireLocalBackend(t) + token := genToken(t) + t.Cleanup(func() { cleanupResources(t, adminURL, token) }) + + // Seed a real db_TOKEN so the DROP target exists. + if _, err := b.Provision(context.Background(), token, "free"); err != nil { + t.Fatalf("seed Provision: %v", err) + } + + orig := adminConnect + t.Cleanup(func() { adminConnect = orig }) + adminConnect = func(ctx context.Context, url string) (*pgx.Conn, error) { + c, err := pgx.Connect(ctx, url) + if err != nil { + return nil, err + } + // Self-terminate: backend pid runs SIGTERM on itself, leaving + // the pgx.Conn in a "logged-in but socket-dead" state. + _, _ = c.Exec(ctx, "SELECT pg_terminate_backend(pg_backend_pid())") + return c, nil + } + + err := b.Deprovision(context.Background(), token, "") + // terminate_backend Exec fails (logged-and-continued). + // DROP DATABASE Exec fails — returns error. + if err == nil || !strings.Contains(err.Error(), "DROP DATABASE") { + t.Fatalf("want DROP DATABASE error, got %v", err) + } +} + +// TestLocal_Provision_CreateUserFails — if a stale user with the same +// name already exists from a botched earlier Provision (database was +// dropped but user wasn't), CREATE USER hits a duplicate-role error. +// This exercises the CREATE USER error branch. +func TestLocal_Provision_CreateUserFails(t *testing.T) { + b, adminURL := requireLocalBackend(t) + token := genToken(t) + username := "usr_" + token + t.Cleanup(func() { cleanupResources(t, adminURL, token) }) + + // Pre-create the user so the database name is still free but the + // role isn't. This is exactly the post-crash state Provision must + // surface as an error. + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + conn, err := pgx.Connect(ctx, adminURL) + if err != nil { + t.Fatalf("admin connect: %v", err) + } + defer conn.Close(ctx) + if _, err := conn.Exec(ctx, fmt.Sprintf(`CREATE USER %q WITH PASSWORD 'x'`, username)); err != nil { + t.Fatalf("seed user: %v", err) + } + + if _, err := b.Provision(context.Background(), token, "free"); err == nil || + !strings.Contains(err.Error(), "CREATE USER") { + t.Fatalf("want CREATE USER error, got %v", err) + } +} + +// TestLocal_Deprovision_DropUserContinues — DROP USER fails for a role +// that owns objects (it can't be dropped while it owns anything). We +// re-assign and confirm Deprovision continues anyway (the function logs +// and returns nil — DROP USER errors are non-fatal). This proves the +// DROP-USER-continues branch. +func TestLocal_Deprovision_DropUserContinues(t *testing.T) { + b, adminURL := requireLocalBackend(t) + token := genToken(t) + username := "usr_" + token + t.Cleanup(func() { + // best-effort owner reassign so the user CAN be dropped. + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + conn, err := pgx.Connect(ctx, adminURL) + if err == nil { + _, _ = conn.Exec(ctx, fmt.Sprintf(`REASSIGN OWNED BY %q TO CURRENT_USER`, username)) + _, _ = conn.Exec(ctx, fmt.Sprintf(`DROP OWNED BY %q`, username)) + conn.Close(ctx) + } + cleanupResources(t, adminURL, token) + }) + + if _, err := b.Provision(context.Background(), token, "free"); err != nil { + t.Fatalf("Provision: %v", err) + } + + // Make the user own a stray table in the admin DB; DROP DATABASE + // removes the per-token db (and its objects), but the admin-DB + // table outlives it, so DROP USER fails with a "owns objects" error. + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + conn, err := pgx.Connect(ctx, adminURL) + if err != nil { + t.Fatalf("admin connect: %v", err) + } + defer conn.Close(ctx) + tbl := fmt.Sprintf("stray_%s", strings.ReplaceAll(token, "-", "_")) + if _, err := conn.Exec(ctx, fmt.Sprintf(`CREATE TABLE %q (id int)`, tbl)); err != nil { + t.Fatalf("create stray: %v", err) + } + if _, err := conn.Exec(ctx, fmt.Sprintf(`ALTER TABLE %q OWNER TO %q`, tbl, username)); err != nil { + t.Fatalf("chown stray: %v", err) + } + defer conn.Exec(ctx, fmt.Sprintf(`DROP TABLE IF EXISTS %q`, tbl)) + + // Deprovision returns nil even though DROP USER fails — the DROP + // USER error is logged-and-continued. + if err := b.Deprovision(context.Background(), token, ""); err != nil { + t.Fatalf("Deprovision (DROP USER fail should be non-fatal): %v", err) + } + // Confirm: user still exists. + var n int + if err := conn.QueryRow(ctx, `SELECT count(*) FROM pg_user WHERE usename=$1`, username).Scan(&n); err != nil { + t.Fatalf("pg_user probe: %v", err) + } + if n != 1 { + t.Fatalf("user count=%d (expected 1 — DROP USER should have failed but Deprovision returned ok)", n) + } +} + +// TestLocal_Provision_FullTokenAndShortPrefix — both the "full UUID" and +// "short prefix" naming conventions round-trip through Provision + +// Deprovision. The DB/user names are simple string concatenations on +// token, so any token that's a valid identifier should work. +func TestLocal_Provision_FullTokenAndShortPrefix(t *testing.T) { + b, adminURL := requireLocalBackend(t) + // Full-token shape: hex with dashes — Postgres needs the identifier + // quoted, which Provision does. + full := "deadbeef-cafe-0000-0000-" + strings.ReplaceAll(genToken(t), "itest", "") + // Short prefix shape: the common compact identifier. + short := genToken(t) + for _, token := range []string{full, short} { + token := token + t.Run("token="+token, func(t *testing.T) { + t.Cleanup(func() { cleanupResources(t, adminURL, token) }) + creds, err := b.Provision(context.Background(), token, "free") + if err != nil { + t.Fatalf("Provision(%q): %v", token, err) + } + if !strings.HasSuffix(creds.DatabaseName, token) { + t.Errorf("DatabaseName: got %q want suffix %q", creds.DatabaseName, token) + } + if err := b.Deprovision(context.Background(), token, ""); err != nil { + t.Fatalf("Deprovision: %v", err) + } + }) + } +} diff --git a/internal/providers/db/local_unit_test.go b/internal/providers/db/local_unit_test.go new file mode 100644 index 0000000..aac7fef --- /dev/null +++ b/internal/providers/db/local_unit_test.go @@ -0,0 +1,199 @@ +package db + +// Pure-unit coverage for the no-network helpers in local.go: the URL +// surgery (`extractHost` / `buildDBURL` / `buildAdminNewDBURL`), the +// `indexOf` byte-scan, `generatePassword`'s charset + length invariants, +// and `newLocalBackend`'s default-URL substitution. None of these touch +// Postgres — they exercise the helper layer that lives between +// `Provision` and the wire. + +import ( + "strings" + "testing" +) + +// Test_newLocalBackend_DefaultURL — passing "" must substitute the +// package-level default so callers running outside k8s don't accidentally +// connect to nothing. +func Test_newLocalBackend_DefaultURL(t *testing.T) { + b := newLocalBackend("") + if b.customersURL != defaultCustomersURL { + t.Fatalf("empty input: got %q want %q", b.customersURL, defaultCustomersURL) + } + + custom := "postgres://x:y@host:1234/db?sslmode=disable" + b2 := newLocalBackend(custom) + if b2.customersURL != custom { + t.Fatalf("custom input: got %q want %q", b2.customersURL, custom) + } +} + +// Test_generatePassword_LengthAndCharset — output is exactly n bytes long +// and every byte is in `alphanumChars`. n=0 is permitted (empty string). +func Test_generatePassword_LengthAndCharset(t *testing.T) { + for _, n := range []int{0, 1, 8, 16, 64} { + got, err := generatePassword(n) + if err != nil { + t.Fatalf("generatePassword(%d): %v", n, err) + } + if len(got) != n { + t.Fatalf("generatePassword(%d): len=%d want %d", n, len(got), n) + } + for i := 0; i < len(got); i++ { + if !strings.ContainsRune(alphanumChars, rune(got[i])) { + t.Fatalf("generatePassword(%d): byte %q at %d not in charset", n, got[i], i) + } + } + } +} + +// Test_generatePassword_RandomEnough — two consecutive 32-byte passwords +// must differ. Tiny smoke check, not a statistical test — the goal is +// only to prove crypto/rand was actually invoked. +func Test_generatePassword_RandomEnough(t *testing.T) { + a, err := generatePassword(32) + if err != nil { + t.Fatalf("a: %v", err) + } + b, err := generatePassword(32) + if err != nil { + t.Fatalf("b: %v", err) + } + if a == b { + t.Fatalf("two 32-byte passwords collided — crypto/rand not invoked? %q", a) + } +} + +// failingReader returns an error on every Read. Used to drive the +// `rand.Int` error branch in generatePassword without disturbing +// crypto/rand for other tests. +type failingReader struct{ err error } + +func (f failingReader) Read(_ []byte) (int, error) { return 0, f.err } + +// Test_generatePassword_RandErrorBranch — when the package-level +// randReader fails, generatePassword surfaces the error wrapped with +// the function name. +func Test_generatePassword_RandErrorBranch(t *testing.T) { + orig := randReader + t.Cleanup(func() { randReader = orig }) + randReader = failingReader{err: errFakeRand} + + got, err := generatePassword(8) + if err == nil { + t.Fatalf("want error, got %q", got) + } + if !strings.Contains(err.Error(), "generatePassword") { + t.Fatalf("err=%v want generatePassword-wrapped", err) + } +} + +// errFakeRand is a sentinel passed into failingReader so tests can match +// it back via errors.Is if needed. Kept package-level so test files +// outside this one can reuse it. +var errFakeRand = &fakeRandError{msg: "fake rand failed"} + +type fakeRandError struct{ msg string } + +func (e *fakeRandError) Error() string { return e.msg } + +// Test_indexOf_ByteScan — the package's own minimal `bytes.IndexByte` +// replacement. Cover both the hit (returns index) and miss (returns -1) +// branches. +func Test_indexOf_ByteScan(t *testing.T) { + cases := []struct { + s string + c byte + want int + }{ + {"abcd", 'c', 2}, + {"abcd", 'a', 0}, + {"abcd", 'd', 3}, + {"abcd", 'z', -1}, + {"", 'x', -1}, + {"@", '@', 0}, + } + for _, tc := range cases { + if got := indexOf(tc.s, tc.c); got != tc.want { + t.Errorf("indexOf(%q,%q)=%d want %d", tc.s, tc.c, got, tc.want) + } + } +} + +// Test_extractHost_Cases — covers every branch in extractHost: +// - prefix-trimmed input +// - URL with user:pass@host:port/db +// - URL with host only (no auth, no path) +// - URL with no '/' after host (returns rest of string) +// - empty string (returns empty) +// - input shorter than the postgres:// prefix (defensive branch) +func Test_extractHost_Cases(t *testing.T) { + cases := []struct { + in string + want string + }{ + {"postgres://user:pass@host:5432/db", "host:5432"}, + {"postgres://user:pass@host/db", "host"}, + {"postgres://host:5432/db", "host:5432"}, + {"postgres://host", "host"}, + {"postgres://user:pass@host:5432/db?sslmode=disable", "host:5432"}, + {"", ""}, + // Shorter than `postgres://` — defensive branch where the + // prefix-trim is skipped. + {"pg", "pg"}, + } + for _, tc := range cases { + if got := extractHost(tc.in); got != tc.want { + t.Errorf("extractHost(%q)=%q want %q", tc.in, got, tc.want) + } + } +} + +// Test_buildDBURL_RoundtripsHost — buildDBURL composes a per-tenant +// user-facing URL by stripping the admin URL down to host:port and +// re-attaching the new credentials + database name. +func Test_buildDBURL_RoundtripsHost(t *testing.T) { + b := &LocalBackend{customersURL: "postgres://admin:adminpw@db.internal:5432/instant_customers?sslmode=disable"} + got := b.buildDBURL("usr_abc", "secret", "db_abc") + want := "postgres://usr_abc:secret@db.internal:5432/db_abc" + if got != want { + t.Fatalf("buildDBURL: got %q want %q", got, want) + } +} + +// Test_buildAdminNewDBURL_SwapsTrailingDatabase — strips the trailing +// `/...` from the admin URL and re-attaches the new database name. +func Test_buildAdminNewDBURL_SwapsTrailingDatabase(t *testing.T) { + cases := []struct { + admin string + db string + want string + }{ + { + "postgres://admin:pw@host:5432/instant_customers", + "db_xyz", + "postgres://admin:pw@host:5432/db_xyz", + }, + { + "postgres://admin:pw@host:5432/instant_customers?sslmode=disable", + // query-string lives AFTER the database name and is intentionally + // truncated by the simple "find last '/'" rewrite. The agent- + // visible URL uses sslmode=disable by default at the caller level; + // this test pins the documented behaviour. + "db_xyz", + "postgres://admin:pw@host:5432/db_xyz", + }, + { + // Defensive branch: no '/' anywhere — fallback appends "/db_xyz". + "postgres-no-slashes", + "db_xyz", + "postgres-no-slashes/db_xyz", + }, + } + for _, tc := range cases { + b := &LocalBackend{customersURL: tc.admin} + if got := b.buildAdminNewDBURL(tc.db); got != tc.want { + t.Errorf("buildAdminNewDBURL(%q,%q)=%q want %q", tc.admin, tc.db, got, tc.want) + } + } +} diff --git a/internal/providers/db/neon_test.go b/internal/providers/db/neon_test.go new file mode 100644 index 0000000..9a542a3 --- /dev/null +++ b/internal/providers/db/neon_test.go @@ -0,0 +1,428 @@ +package db + +// Neon backend tests. We never hit the real https://console.neon.tech +// surface; instead the test stands up an httptest.Server and redirects +// the backend at it by overriding the embedded *http.Client's Transport. +// This proves the per-method request shape (method, path, Bearer header, +// JSON body) and the response-parsing branches: happy-path, non-2xx, +// malformed JSON, missing project_id, missing connection_uri, empty +// providerResourceID guard, and the extensions-not-supported error. + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +// rewriteTransport rewrites every outbound request's URL host so that +// requests targeted at neon.tech land on our local httptest server. +type rewriteTransport struct { + target string // "http://127.0.0.1:NNNN" + t *testing.T +} + +func (rt *rewriteTransport) RoundTrip(r *http.Request) (*http.Response, error) { + // Replace scheme+host of neonAPIBase with the httptest target. + rewritten := strings.Replace(r.URL.String(), neonAPIBase, rt.target+"/api/v2", 1) + req2, err := http.NewRequestWithContext(r.Context(), r.Method, rewritten, r.Body) + if err != nil { + return nil, err + } + req2.Header = r.Header.Clone() + return http.DefaultTransport.RoundTrip(req2) +} + +// newNeonBackendForTest wires a NeonBackend at a local httptest server. +func newNeonBackendForTest(t *testing.T, srv *httptest.Server, apiKey, regionID string) *NeonBackend { + t.Helper() + b := newNeonBackend(apiKey, regionID) + b.client.Transport = &rewriteTransport{target: srv.URL, t: t} + return b +} + +// TestNewNeonBackend_DefaultRegion — empty regionID substitutes the +// package-level default. +func TestNewNeonBackend_DefaultRegion(t *testing.T) { + b := newNeonBackend("k", "") + if b.regionID != defaultNeonRegion { + t.Fatalf("default region: got %q want %q", b.regionID, defaultNeonRegion) + } + b2 := newNeonBackend("k", "aws-eu-west-1") + if b2.regionID != "aws-eu-west-1" { + t.Fatalf("explicit region: got %q want aws-eu-west-1", b2.regionID) + } +} + +// TestNeon_Provision_HappyPath — Server returns a valid project payload; +// backend should fill Credentials.URL + ProviderResourceID. +func TestNeon_Provision_HappyPath(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Per-request invariants. + if r.Method != http.MethodPost { + t.Errorf("method: got %q want POST", r.Method) + } + if r.URL.Path != "/api/v2/projects" { + t.Errorf("path: got %q want /api/v2/projects", r.URL.Path) + } + if got := r.Header.Get("Authorization"); got != "Bearer test-key" { + t.Errorf("auth: got %q", got) + } + if r.Header.Get("Content-Type") != "application/json" { + t.Errorf("content-type: got %q", r.Header.Get("Content-Type")) + } + var body map[string]any + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + t.Errorf("decode body: %v", err) + } + proj, ok := body["project"].(map[string]any) + if !ok { + t.Errorf("project envelope missing") + } + if name := proj["name"].(string); name != "instant-tok-9" { + t.Errorf("name: got %q", name) + } + w.WriteHeader(http.StatusCreated) + _, _ = io.WriteString(w, `{ + "project":{"id":"proj_abc123"}, + "connection_uris":[{"connection_uri":"postgres://user:pass@neon/db"}] + }`) + })) + defer srv.Close() + + b := newNeonBackendForTest(t, srv, "test-key", "aws-us-east-1") + creds, err := b.Provision(context.Background(), "tok-9", "pro") + if err != nil { + t.Fatalf("Provision: %v", err) + } + if creds.URL != "postgres://user:pass@neon/db" { + t.Errorf("url: got %q", creds.URL) + } + if creds.ProviderResourceID != "proj_abc123" { + t.Errorf("prid: got %q", creds.ProviderResourceID) + } + if creds.DatabaseName != "neondb" || creds.Username != "" { + t.Errorf("db/username defaults: %+v", creds) + } +} + +// TestNeon_Provision_ErrorBranches — every error branch of Provision. +func TestNeon_Provision_ErrorBranches(t *testing.T) { + type tc struct { + name string + handler http.HandlerFunc + wantSub string + } + cases := []tc{ + { + "non_2xx_status", + func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + _, _ = io.WriteString(w, `{"error":"bad key"}`) + }, + "unexpected status 401", + }, + { + "non_json_body", + func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = io.WriteString(w, `not-json`) + }, + "unmarshal", + }, + { + "empty_project_id", + func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = io.WriteString(w, `{"project":{"id":""},"connection_uris":[{"connection_uri":"u"}]}`) + }, + "empty project ID", + }, + { + "no_connection_uris", + func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = io.WriteString(w, `{"project":{"id":"p1"},"connection_uris":[]}`) + }, + "no connection URI", + }, + { + "empty_connection_uri", + func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = io.WriteString(w, `{"project":{"id":"p1"},"connection_uris":[{"connection_uri":""}]}`) + }, + "no connection URI", + }, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + srv := httptest.NewServer(c.handler) + defer srv.Close() + b := newNeonBackendForTest(t, srv, "k", "") + _, err := b.Provision(context.Background(), "tok", "free") + if err == nil { + t.Fatalf("want error, got nil") + } + if !strings.Contains(err.Error(), c.wantSub) { + t.Fatalf("err=%q want substr %q", err.Error(), c.wantSub) + } + }) + } +} + +// TestNeon_Provision_HTTPDoFails — exercise the request-do error branch +// (network unreachable). We close the server before the call. +func TestNeon_Provision_HTTPDoFails(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})) + srv.Close() // immediately + b := newNeonBackendForTest(t, srv, "k", "") + _, err := b.Provision(context.Background(), "tok", "free") + if err == nil || !strings.Contains(err.Error(), "http") { + t.Fatalf("want http error, got %v", err) + } +} + +// TestNeon_ProvisionWithExtensions — the vector path returns the docs +// "not yet supported" error AND the underlying creds (so callers see +// what got created in spite of the missing extension). +func TestNeon_ProvisionWithExtensions(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusCreated) + _, _ = io.WriteString(w, `{"project":{"id":"p"},"connection_uris":[{"connection_uri":"u"}]}`) + })) + defer srv.Close() + b := newNeonBackendForTest(t, srv, "k", "") + + // Disallowed extension is rejected at the validator BEFORE any HTTP call. + if _, err := b.ProvisionWithExtensions(context.Background(), "tok", "team", []string{"postgis"}); err == nil { + t.Fatal("want error for disallowed extension") + } + + // Allowed extension still errors (Neon path not wired) but returns creds. + creds, err := b.ProvisionWithExtensions(context.Background(), "tok", "team", []string{"vector"}) + if err == nil { + t.Fatal("want extensions-unsupported error on Neon") + } + if !strings.Contains(err.Error(), "not yet supported") { + t.Fatalf("err=%v", err) + } + if creds == nil { + t.Fatal("expected creds returned alongside error") + } + + // No-extensions path is a transparent passthrough to Provision. + creds, err = b.ProvisionWithExtensions(context.Background(), "tok", "team", nil) + if err != nil { + t.Fatalf("nil exts: %v", err) + } + if creds.ProviderResourceID != "p" { + t.Fatalf("creds: %+v", creds) + } +} + +// TestNeon_ProvisionWithExtensions_InnerProvisionFails — when the inner +// Provision call errors (e.g. Neon returns 5xx), ProvisionWithExtensions +// must surface that error verbatim (line 52 branch). +func TestNeon_ProvisionWithExtensions_InnerProvisionFails(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusBadGateway) + _, _ = io.WriteString(w, `gateway down`) + })) + defer srv.Close() + b := newNeonBackendForTest(t, srv, "k", "") + // nil extensions — validator passes, Provision is called, it errors. + _, err := b.ProvisionWithExtensions(context.Background(), "tok", "team", nil) + if err == nil || !strings.Contains(err.Error(), "502") { + t.Fatalf("want 502 propagated, got %v", err) + } +} + +// TestNeon_BodyReadFails — when the upstream closes the connection +// mid-stream during ReadAll, Provision should surface the read error. +// We use a Hijack-and-close handler to slam the socket after writing +// only the headers. +func TestNeon_BodyReadFails(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + // Hijack the underlying conn and write a malformed Content-Length + // header, then close before sending any body bytes. ReadAll on + // the resulting body will return an unexpected-EOF error. + hj, ok := w.(http.Hijacker) + if !ok { + t.Errorf("ResponseWriter does not support Hijack") + return + } + conn, _, err := hj.Hijack() + if err != nil { + t.Errorf("hijack: %v", err) + return + } + // Tell the client we'll send 100 bytes then bail. + _, _ = conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 100\r\n\r\n")) + _ = conn.Close() + })) + defer srv.Close() + b := newNeonBackendForTest(t, srv, "k", "") + _, err := b.Provision(context.Background(), "tok", "free") + if err == nil { + t.Fatal("want read-body error") + } + // We don't pin the exact substring — different Go versions report + // the truncation differently — but it must be one of the read-body + // or http-do error wrappers. + if !strings.Contains(err.Error(), "read body") && !strings.Contains(err.Error(), "http") { + t.Fatalf("err=%v", err) + } +} + +// TestNeon_StorageBytes_BodyReadFails — same connection-hang scenario, +// but for the GET /projects/:id path. +func TestNeon_StorageBytes_BodyReadFails(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + hj, ok := w.(http.Hijacker) + if !ok { + t.Errorf("ResponseWriter does not support Hijack") + return + } + conn, _, err := hj.Hijack() + if err != nil { + t.Errorf("hijack: %v", err) + return + } + _, _ = conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 100\r\n\r\n")) + _ = conn.Close() + })) + defer srv.Close() + b := newNeonBackendForTest(t, srv, "k", "") + if _, err := b.StorageBytes(context.Background(), "tok", "p1"); err == nil { + t.Fatal("want read-body error") + } +} + +// TestNeon_StorageBytes_HappyPath — GET project, parse usage. +func TestNeon_StorageBytes_HappyPath(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet || !strings.HasSuffix(r.URL.Path, "/projects/p1") { + t.Errorf("unexpected req: %s %s", r.Method, r.URL.Path) + } + w.WriteHeader(http.StatusOK) + _, _ = io.WriteString(w, `{"project":{"usage":{"data_storage_bytes_hour":98765}}}`) + })) + defer srv.Close() + b := newNeonBackendForTest(t, srv, "k", "") + got, err := b.StorageBytes(context.Background(), "tok", "p1") + if err != nil { + t.Fatalf("StorageBytes: %v", err) + } + if got != 98765 { + t.Fatalf("got %d want 98765", got) + } +} + +// TestNeon_StorageBytes_ErrorBranches — exhaust every error path. +func TestNeon_StorageBytes_ErrorBranches(t *testing.T) { + // 1) empty providerResourceID — short-circuits before any HTTP call. + b := newNeonBackend("k", "") + if _, err := b.StorageBytes(context.Background(), "tok", ""); err == nil { + t.Fatal("empty prid: want error") + } + + // 2) non-2xx status. + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + _, _ = io.WriteString(w, `boom`) + })) + defer srv.Close() + b2 := newNeonBackendForTest(t, srv, "k", "") + if _, err := b2.StorageBytes(context.Background(), "tok", "p1"); err == nil || + !strings.Contains(err.Error(), "500") { + t.Fatalf("non-2xx: err=%v", err) + } + + // 3) malformed JSON. + srv2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = io.WriteString(w, `{bad}`) + })) + defer srv2.Close() + b3 := newNeonBackendForTest(t, srv2, "k", "") + if _, err := b3.StorageBytes(context.Background(), "tok", "p1"); err == nil || + !strings.Contains(err.Error(), "unmarshal") { + t.Fatalf("bad json: err=%v", err) + } + + // 4) network unreachable. + srvDead := httptest.NewServer(http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})) + srvDead.Close() + bDead := newNeonBackendForTest(t, srvDead, "k", "") + if _, err := bDead.StorageBytes(context.Background(), "tok", "p1"); err == nil { + t.Fatal("dead server: want error") + } +} + +// TestNeon_StorageBytes_NewRequestFails — providerResourceID containing +// a control character (newline) makes http.NewRequestWithContext bail +// at URL parse before any network call. Exercises the new-request error +// branch. +func TestNeon_StorageBytes_NewRequestFails(t *testing.T) { + b := newNeonBackend("k", "") + if _, err := b.StorageBytes(context.Background(), "tok", "bad\nid"); err == nil || + !strings.Contains(err.Error(), "new request") { + t.Fatalf("want new-request error, got %v", err) + } +} + +// TestNeon_Deprovision_NewRequestFails — same control-char trick on the +// DELETE path. +func TestNeon_Deprovision_NewRequestFails(t *testing.T) { + b := newNeonBackend("k", "") + if err := b.Deprovision(context.Background(), "tok", "bad\nid"); err == nil || + !strings.Contains(err.Error(), "new request") { + t.Fatalf("want new-request error, got %v", err) + } +} + +// TestNeon_Deprovision — DELETE happy-path + every error branch. +func TestNeon_Deprovision(t *testing.T) { + // 1) empty providerResourceID — short-circuits. + b := newNeonBackend("k", "") + if err := b.Deprovision(context.Background(), "tok", ""); err == nil { + t.Fatal("empty prid: want error") + } + + // 2) happy path. + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodDelete || !strings.HasSuffix(r.URL.Path, "/projects/p2") { + t.Errorf("unexpected: %s %s", r.Method, r.URL.Path) + } + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + bOk := newNeonBackendForTest(t, srv, "k", "") + if err := bOk.Deprovision(context.Background(), "tok", "p2"); err != nil { + t.Fatalf("happy: %v", err) + } + + // 3) non-2xx status. + srvErr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusNotFound) + _, _ = io.WriteString(w, `gone`) + })) + defer srvErr.Close() + bErr := newNeonBackendForTest(t, srvErr, "k", "") + if err := bErr.Deprovision(context.Background(), "tok", "p2"); err == nil || + !strings.Contains(err.Error(), "404") { + t.Fatalf("404: err=%v", err) + } + + // 4) network unreachable. + srvDead := httptest.NewServer(http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})) + srvDead.Close() + bDead := newNeonBackendForTest(t, srvDead, "k", "") + if err := bDead.Deprovision(context.Background(), "tok", "p3"); err == nil { + t.Fatal("dead: want error") + } +} diff --git a/internal/providers/db/provider_test.go b/internal/providers/db/provider_test.go new file mode 100644 index 0000000..73145f8 --- /dev/null +++ b/internal/providers/db/provider_test.go @@ -0,0 +1,136 @@ +package db + +// Tests for the public Provider facade in provider.go. The Provider type +// is a tiny dispatcher in front of a Backend; the goal here is to prove +// the right backend is selected by `New` and that every method on +// Provider forwards faithfully to the underlying Backend. + +import ( + "context" + "errors" + "testing" + + "instant.dev/internal/config" +) + +// stubBackend is a minimal Backend test-double. It records the arguments +// each method was called with and returns canned responses. Useful for +// asserting Provider.X delegates to backend.X without standing up a real +// Postgres / HTTP server. +type stubBackend struct { + gotToken string + gotTier string + gotExts []string + gotPRID string + provCreds *Credentials + provErr error + storageSize int64 + storageErr error + deprovErr error +} + +func (s *stubBackend) Provision(_ context.Context, token, tier string) (*Credentials, error) { + s.gotToken, s.gotTier, s.gotExts = token, tier, nil + return s.provCreds, s.provErr +} +func (s *stubBackend) ProvisionWithExtensions(_ context.Context, token, tier string, exts []string) (*Credentials, error) { + s.gotToken, s.gotTier, s.gotExts = token, tier, exts + return s.provCreds, s.provErr +} +func (s *stubBackend) StorageBytes(_ context.Context, token, prid string) (int64, error) { + s.gotToken, s.gotPRID = token, prid + return s.storageSize, s.storageErr +} +func (s *stubBackend) Deprovision(_ context.Context, token, prid string) error { + s.gotToken, s.gotPRID = token, prid + return s.deprovErr +} + +// TestNew_PicksLocalByDefault — empty / unknown backend strings default +// to the LocalBackend (the production agent-API behaviour); only the +// literal "neon" selects the Neon HTTP backend. +func TestNew_PicksLocalByDefault(t *testing.T) { + cases := []struct { + name string + cfg *config.Config + wantLocal bool + }{ + {"empty_default_local", &config.Config{PostgresProvisionBackend: ""}, true}, + {"explicit_local", &config.Config{PostgresProvisionBackend: "local"}, true}, + {"unknown_falls_back_to_local", &config.Config{PostgresProvisionBackend: "bogus"}, true}, + {"neon_selects_neon", &config.Config{PostgresProvisionBackend: "neon", NeonAPIKey: "k", NeonRegionID: "aws-us-east-1"}, false}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + p := New(tc.cfg, "") + if p == nil || p.backend == nil { + t.Fatalf("New: returned nil backend") + } + _, isLocal := p.backend.(*LocalBackend) + _, isNeon := p.backend.(*NeonBackend) + if tc.wantLocal && !isLocal { + t.Fatalf("want LocalBackend, got %T", p.backend) + } + if !tc.wantLocal && !isNeon { + t.Fatalf("want NeonBackend, got %T", p.backend) + } + }) + } +} + +// TestProvider_ForwardsToBackend — the four Provider methods must +// faithfully delegate to the configured Backend. We swap in a stub and +// assert (a) arguments arrive as-passed, (b) returns propagate as-is, +// (c) errors propagate as-is. +func TestProvider_ForwardsToBackend(t *testing.T) { + creds := &Credentials{URL: "postgres://x@h/y", DatabaseName: "y", Username: "x"} + wantErr := errors.New("boom") + s := &stubBackend{ + provCreds: creds, + storageSize: 12345, + storageErr: nil, + provErr: nil, + deprovErr: wantErr, + } + p := &Provider{backend: s} + + t.Run("Provision", func(t *testing.T) { + got, err := p.Provision(context.Background(), "tok-1", "pro") + if err != nil || got != creds { + t.Fatalf("Provision: got=%v err=%v", got, err) + } + if s.gotToken != "tok-1" || s.gotTier != "pro" || s.gotExts != nil { + t.Fatalf("Provision: backend got token=%q tier=%q exts=%v", s.gotToken, s.gotTier, s.gotExts) + } + }) + + t.Run("ProvisionWithExtensions", func(t *testing.T) { + got, err := p.ProvisionWithExtensions(context.Background(), "tok-2", "team", []string{"vector"}) + if err != nil || got != creds { + t.Fatalf("ProvisionWithExtensions: got=%v err=%v", got, err) + } + if s.gotToken != "tok-2" || s.gotTier != "team" || len(s.gotExts) != 1 || s.gotExts[0] != "vector" { + t.Fatalf("ProvisionWithExtensions: backend got token=%q tier=%q exts=%v", s.gotToken, s.gotTier, s.gotExts) + } + }) + + t.Run("StorageBytes", func(t *testing.T) { + n, err := p.StorageBytes(context.Background(), "tok-3", "prid-3") + if err != nil || n != 12345 { + t.Fatalf("StorageBytes: n=%d err=%v", n, err) + } + if s.gotToken != "tok-3" || s.gotPRID != "prid-3" { + t.Fatalf("StorageBytes: backend got token=%q prid=%q", s.gotToken, s.gotPRID) + } + }) + + t.Run("Deprovision_PropagatesError", func(t *testing.T) { + err := p.Deprovision(context.Background(), "tok-4", "prid-4") + if !errors.Is(err, wantErr) { + t.Fatalf("Deprovision: want %v got %v", wantErr, err) + } + if s.gotToken != "tok-4" || s.gotPRID != "prid-4" { + t.Fatalf("Deprovision: backend got token=%q prid=%q", s.gotToken, s.gotPRID) + } + }) +}