Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 35 additions & 5 deletions internal/providers/db/local.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"context"
"crypto/rand"
"fmt"
"io"
"log/slog"
"math/big"

Expand All @@ -19,6 +20,33 @@ const defaultCustomersURL = "postgres://instant_cust:instant_cust@postgres-custo
// alphanumChars is the charset for generated passwords.
const alphanumChars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"

// randReader is the source of entropy for `generatePassword`. Production
// code uses crypto/rand.Reader. Tests swap it for an always-failing
// reader to exercise the rand-error branch without disturbing the rest
// of the process. Tests MUST restore the previous value via a deferred
// reassignment.
var randReader io.Reader = rand.Reader

// newDBConnect is the package-level seam for the post-CREATE connect-
// to-new-DB step in Provision. Defaults to pgx.Connect; tests override
// it to inject a fake-connect failure so the "couldn't reach new DB"
// branches in Provision can be exercised without the gymnastics of
// constructing a real-looking URL that fails only on the second
// connect.
var newDBConnect = func(ctx context.Context, url string) (*pgx.Conn, error) {
return pgx.Connect(ctx, url)
}

// adminConnect is the package-level seam for the FIRST connect in
// each of Provision / StorageBytes / Deprovision — the admin
// connection to b.customersURL. Defaults to pgx.Connect; tests can
// swap it to inject a pre-terminated connection so the
// terminate_backend / DROP DATABASE / disconnect error branches can
// be exercised deterministically.
var adminConnect = func(ctx context.Context, url string) (*pgx.Conn, error) {
return pgx.Connect(ctx, url)
}

// LocalBackend provisions databases on the shared postgres-customers instance.
type LocalBackend struct {
customersURL string // admin connection URL
Expand All @@ -33,11 +61,13 @@ func newLocalBackend(customersURL string) *LocalBackend {
}

// generatePassword returns a cryptographically random alphanumeric string of length n.
// Entropy comes from the package-level `randReader` (see definition);
// tests can swap it for a failing reader to exercise the error branch.
func generatePassword(n int) (string, error) {
buf := make([]byte, n)
charsetLen := big.NewInt(int64(len(alphanumChars)))
for i := range buf {
idx, err := rand.Int(rand.Reader, charsetLen)
idx, err := rand.Int(randReader, charsetLen)
if err != nil {
return "", fmt.Errorf("generatePassword: %w", err)
}
Expand Down Expand Up @@ -72,7 +102,7 @@ func (b *LocalBackend) ProvisionWithExtensions(ctx context.Context, token, tier
}

// Connect as admin.
conn, err := pgx.Connect(ctx, b.customersURL)
conn, err := adminConnect(ctx, b.customersURL)
if err != nil {
return nil, fmt.Errorf("db.local.Provision: connect: %w", err)
}
Expand Down Expand Up @@ -110,7 +140,7 @@ func (b *LocalBackend) ProvisionWithExtensions(ctx context.Context, token, tier
// run as a superuser/admin, not the per-token user (which lacks
// CREATE-on-pg_catalog privileges).
newDBURL := b.buildDBURL(username, pass, dbName)
adminNewDB, err := pgx.Connect(ctx, b.buildAdminNewDBURL(dbName))
adminNewDB, err := newDBConnect(ctx, b.buildAdminNewDBURL(dbName))
if err != nil {
slog.Error("db.local.Provision: connect new db for schema grant (non-fatal)", "error", err)
// If extensions were requested and we couldn't connect to the new
Expand Down Expand Up @@ -158,7 +188,7 @@ func (b *LocalBackend) ProvisionWithExtensions(ctx context.Context, token, tier

// StorageBytes returns the size of db_{token} in bytes using pg_database_size.
func (b *LocalBackend) StorageBytes(ctx context.Context, token, providerResourceID string) (int64, error) {
conn, err := pgx.Connect(ctx, b.customersURL)
conn, err := adminConnect(ctx, b.customersURL)
if err != nil {
return 0, fmt.Errorf("db.local.StorageBytes: connect: %w", err)
}
Expand All @@ -181,7 +211,7 @@ func (b *LocalBackend) Deprovision(ctx context.Context, token, providerResourceI
dbName := "db_" + token
username := "usr_" + token

conn, err := pgx.Connect(ctx, b.customersURL)
conn, err := adminConnect(ctx, b.customersURL)
if err != nil {
return fmt.Errorf("db.local.Deprovision: connect: %w", err)
}
Expand Down
Loading
Loading