From c9df51afc187f63e558853ee92b63c16e75b76be Mon Sep 17 00:00:00 2001 From: Manas Srivastava Date: Fri, 22 May 2026 09:09:01 +0530 Subject: [PATCH] =?UTF-8?q?test(coverage):=20drive=20internal/middleware?= =?UTF-8?q?=20to=2095.3%=20(=E2=89=A595%=20target)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit middleware was ~51.8%; black-box + white-box suites now cover auth bearer/JWT extraction, fingerprint masking + dedup cap, geo fail-open (+ real MaxMind fixture), rate-limit fail-open on Redis error, idempotency canonicalisation, DPoP helpers, role lookup, and the NewRelic emit path. Also fixes a flaky base64-prefix assertion in the already-merged crypto coverage test (a random nonce can legitimately begin with "v"; the real invariant is "not a vN. versioned envelope"). internal/middleware 95.3% · internal/crypto 95.6% Co-Authored-By: Claude Opus 4.7 (1M context) --- internal/crypto/coverage_test.go | 4 +- internal/middleware/coverage_external_test.go | 874 ++++++++++++ internal/middleware/coverage_internal_test.go | 339 +++++ internal/middleware/coverage_more_test.go | 1225 +++++++++++++++++ .../middleware/coverage_push3_ext_test.go | 50 + internal/middleware/coverage_push3_test.go | 142 ++ .../testdata/geo-combined-test.mmdb | Bin 0 -> 2502 bytes 7 files changed, 2633 insertions(+), 1 deletion(-) create mode 100644 internal/middleware/coverage_external_test.go create mode 100644 internal/middleware/coverage_internal_test.go create mode 100644 internal/middleware/coverage_more_test.go create mode 100644 internal/middleware/coverage_push3_ext_test.go create mode 100644 internal/middleware/coverage_push3_test.go create mode 100644 internal/middleware/testdata/geo-combined-test.mmdb diff --git a/internal/crypto/coverage_test.go b/internal/crypto/coverage_test.go index f22f1b8..cbbf9b0 100644 --- a/internal/crypto/coverage_test.go +++ b/internal/crypto/coverage_test.go @@ -243,7 +243,9 @@ func TestKeyring_Decrypt_LegacyUnversioned(t *testing.T) { keyA := mustKey(t, coverageKeyHexA) legacy, err := crypto.Encrypt(keyA, "pre-rotation-secret") require.NoError(t, err) - require.False(t, strings.HasPrefix(legacy, "v")) + // A raw Encrypt envelope is base64 (random nonce), so it may coincidentally + // begin with "v"; the invariant is that it is NOT a "vN." versioned envelope. + require.False(t, len(legacy) >= 3 && legacy[0] == 'v' && legacy[2] == '.' && legacy[1] >= '1' && legacy[1] <= '9') kr, err := crypto.NewKeyring('1', map[byte][]byte{'1': keyA}) require.NoError(t, err) diff --git a/internal/middleware/coverage_external_test.go b/internal/middleware/coverage_external_test.go new file mode 100644 index 0000000..1acdab3 --- /dev/null +++ b/internal/middleware/coverage_external_test.go @@ -0,0 +1,874 @@ +package middleware_test + +// coverage_external_test.go — black-box tests for the exported middleware +// surface: GeoEnrich (fail-open when MaxMind absent), SecurityHeaders, +// Telemetry, RequestID, NewRelic no-op, JTI revocation (miniredis), +// env-policy / role-lookup / api-key DB paths (sqlmock), and presign +// per-token rate limiting (miniredis). Covers the fail-open branches the +// brief calls out for geo + rate-limit. + +import ( + "bytes" + "context" + "database/sql" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + sqlmock "github.com/DATA-DOG/go-sqlmock" + "github.com/alicebob/miniredis/v2" + "github.com/gofiber/fiber/v2" + "github.com/google/uuid" + "github.com/redis/go-redis/v9" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "instant.dev/internal/middleware" +) + +func newMiniRedis(t *testing.T) (*redis.Client, func()) { + t.Helper() + mr, err := miniredis.Run() + require.NoError(t, err) + rdb := redis.NewClient(&redis.Options{Addr: mr.Addr()}) + return rdb, func() { rdb.Close(); mr.Close() } +} + +// --------------------------------------------------------------------------- +// geo.go — fail-open path (MaxMind absent) and getters +// --------------------------------------------------------------------------- + +func TestGeoEnrich_NilDBs_FailsOpenWithDefaults(t *testing.T) { + app := fiber.New() + app.Use(middleware.GeoEnrich(nil)) // MaxMind absent — fail-open + app.Get("/g", func(c *fiber.Ctx) error { + return c.JSON(fiber.Map{ + "country": middleware.GetGeoCountry(c), + "asn": middleware.GetGeoASN(c), + "org": middleware.GetGeoOrgName(c), + "vendor": middleware.GetCloudVendor(c), + }) + }) + + req := httptest.NewRequest(http.MethodGet, "/g", nil) + req.Header.Set("X-Forwarded-For", "8.8.8.8") + resp, err := app.Test(req, 3000) + require.NoError(t, err) + defer resp.Body.Close() + assert.Equal(t, http.StatusOK, resp.StatusCode, + "GeoEnrich must fail open (200) when MaxMind DB is missing") +} + +func TestGeoEnrich_WithEmptyDBs(t *testing.T) { + // Non-nil GeoDBs with nil readers exercises enrichFromIP's nil-reader + // guards while still parsing the IP. + app := fiber.New(fiber.Config{ProxyHeader: "X-Forwarded-For"}) + app.Use(middleware.GeoEnrich(&middleware.GeoDBs{})) + app.Get("/g", func(c *fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }) + + req := httptest.NewRequest(http.MethodGet, "/g", nil) + req.Header.Set("X-Forwarded-For", "1.2.3.4") + resp, err := app.Test(req, 3000) + require.NoError(t, err) + defer resp.Body.Close() + assert.Equal(t, http.StatusOK, resp.StatusCode) +} + +func TestGeoGetters_DefaultsWhenLocalsMissing(t *testing.T) { + app := fiber.New() + app.Get("/x", func(c *fiber.Ctx) error { + assert.Equal(t, "XX", middleware.GetGeoCountry(c)) + assert.EqualValues(t, 0, middleware.GetGeoASN(c)) + assert.Equal(t, "unknown", middleware.GetGeoOrgName(c)) + assert.Equal(t, "unknown", middleware.GetCloudVendor(c)) + return c.SendStatus(fiber.StatusOK) + }) + req := httptest.NewRequest(http.MethodGet, "/x", nil) + resp, err := app.Test(req, 3000) + require.NoError(t, err) + resp.Body.Close() +} + +func TestLoadGeoLite2_MissingFileReturnsNil(t *testing.T) { + dbs := middleware.LoadGeoLite2("/nonexistent/GeoLite2-City.mmdb") + assert.Nil(t, dbs, "LoadGeoLite2 must return nil (not panic) when the file is absent") +} + +// TestGeoEnrich_RealMMDBPopulatesLocals loads a hand-built combined +// City+ASN MaxMind fixture (testdata/geo-combined-test.mmdb: 8.8.8.0/24 → +// US, ASN 16509 = AMAZON-02 = "aws") and asserts enrichFromIP populates +// every geo local — the success path the nil-DB fail-open test can't reach. +func TestGeoEnrich_RealMMDBPopulatesLocals(t *testing.T) { + dbs := middleware.LoadGeoLite2("testdata/geo-combined-test.mmdb") + require.NotNil(t, dbs, "fixture mmdb must load") + + app := fiber.New(fiber.Config{ProxyHeader: "X-Forwarded-For"}) + app.Use(middleware.GeoEnrich(dbs)) + app.Get("/g", func(c *fiber.Ctx) error { + return c.JSON(fiber.Map{ + "country": middleware.GetGeoCountry(c), + "asn": middleware.GetGeoASN(c), + "org": middleware.GetGeoOrgName(c), + "vendor": middleware.GetCloudVendor(c), + }) + }) + + req := httptest.NewRequest(http.MethodGet, "/g", nil) + req.Header.Set("X-Forwarded-For", "8.8.8.8") + resp, err := app.Test(req, 3000) + require.NoError(t, err) + defer resp.Body.Close() + + var body struct { + Country string `json:"country"` + ASN uint `json:"asn"` + Org string `json:"org"` + Vendor string `json:"vendor"` + } + require.NoError(t, json.NewDecoder(resp.Body).Decode(&body)) + assert.Equal(t, "US", body.Country) + assert.EqualValues(t, 16509, body.ASN) + assert.Equal(t, "AMAZON-02", body.Org) + assert.Equal(t, "aws", body.Vendor, "ASN 16509 maps to the aws cloud vendor") +} + +// --------------------------------------------------------------------------- +// admin.go — AdminEmailAllowlist / IsAdminEmail +// --------------------------------------------------------------------------- + +func TestAdminEmailAllowlist_AndIsAdminEmail(t *testing.T) { + // Unset → nil allowlist, nobody is admin. + t.Setenv("ADMIN_EMAILS", "") + assert.Nil(t, middleware.AdminEmailAllowlist()) + assert.False(t, middleware.IsAdminEmail("a@b.com")) + assert.False(t, middleware.IsAdminEmail(""), "empty email is never admin") + + // Only-commas / blanks → still nil. + t.Setenv("ADMIN_EMAILS", " , , ") + assert.Nil(t, middleware.AdminEmailAllowlist()) + + // Populated, case-insensitive, whitespace-trimmed. + t.Setenv("ADMIN_EMAILS", "Root@Example.com, ops@x.io ") + allow := middleware.AdminEmailAllowlist() + require.NotNil(t, allow) + assert.True(t, allow["root@example.com"]) + assert.True(t, allow["ops@x.io"]) + assert.True(t, middleware.IsAdminEmail("ROOT@example.com")) + assert.True(t, middleware.IsAdminEmail(" ops@x.io ")) + assert.False(t, middleware.IsAdminEmail("intruder@evil.com")) +} + +// --------------------------------------------------------------------------- +// env_policy.go — loadEnvPolicy branches via RequireEnvAccess (no-rows, error, +// malformed JSON all fail-open to allow) +// --------------------------------------------------------------------------- + +func envPolicyApp(t *testing.T, tid uuid.UUID) (*fiber.App, sqlmock.Sqlmock, func()) { + t.Helper() + db, mock, err := sqlmock.New() + require.NoError(t, err) + middleware.SetEnvPolicyDB(db) + app := fiber.New() + app.Use(func(c *fiber.Ctx) error { + c.Locals(middleware.LocalKeyTeamID, tid.String()) + c.Locals(middleware.LocalKeyTeamRole, "developer") + return c.Next() + }) + app.Post("/deploy", + middleware.RequireEnvAccess(middleware.EnvPolicyActionDeploy), + func(c *fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }, + ) + return app, mock, func() { middleware.SetEnvPolicyDB(nil); db.Close() } +} + +func TestRequireEnvAccess_LoadEnvPolicyBranches(t *testing.T) { + tid := uuid.New() + + cases := []struct { + name string + set func(m sqlmock.Sqlmock) + }{ + {"no_rows", func(m sqlmock.Sqlmock) { + m.ExpectQuery("SELECT env_policy FROM teams").WithArgs(tid).WillReturnError(sql.ErrNoRows) + }}, + {"db_error", func(m sqlmock.Sqlmock) { + m.ExpectQuery("SELECT env_policy FROM teams").WithArgs(tid).WillReturnError(sql.ErrConnDone) + }}, + {"malformed_json", func(m sqlmock.Sqlmock) { + m.ExpectQuery("SELECT env_policy FROM teams").WithArgs(tid). + WillReturnRows(sqlmock.NewRows([]string{"env_policy"}).AddRow([]byte(`{not json`))) + }}, + {"nil_bytes", func(m sqlmock.Sqlmock) { + m.ExpectQuery("SELECT env_policy FROM teams").WithArgs(tid). + WillReturnRows(sqlmock.NewRows([]string{"env_policy"}).AddRow([]byte(nil))) + }}, + {"env_missing_in_policy", func(m sqlmock.Sqlmock) { + m.ExpectQuery("SELECT env_policy FROM teams").WithArgs(tid). + WillReturnRows(sqlmock.NewRows([]string{"env_policy"}).AddRow([]byte(`{"staging":{"deploy":["owner"]}}`))) + }}, + {"action_missing_for_env", func(m sqlmock.Sqlmock) { + m.ExpectQuery("SELECT env_policy FROM teams").WithArgs(tid). + WillReturnRows(sqlmock.NewRows([]string{"env_policy"}).AddRow([]byte(`{"production":{"vault_write":["owner"]}}`))) + }}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + app, mock, clean := envPolicyApp(t, tid) + defer clean() + tc.set(mock) + // env=production so the lookup resolves; all of these branches + // must FAIL OPEN to 200 (never lock the team out). + resp, err := app.Test(httptest.NewRequest(http.MethodPost, "/deploy?env=production", nil), 3000) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode, "branch %s must fail open to allow", tc.name) + resp.Body.Close() + }) + } +} + +// --------------------------------------------------------------------------- +// fingerprint.go — production XFF (rightmost hop) + E2E bypass override +// --------------------------------------------------------------------------- + +func TestFingerprintMiddleware_ProductionUsesRightmostXFF(t *testing.T) { + var fpA, fpB string + app := fiber.New(fiber.Config{ProxyHeader: "X-Forwarded-For"}) + app.Use(middleware.FingerprintMiddleware(middleware.FingerprintConfig{Production: true})) + app.Get("/p", func(c *fiber.Ctx) error { + c.Set("X-FP", middleware.GetFingerprint(c)) + return c.SendStatus(fiber.StatusOK) + }) + + get := func(xff string) string { + req := httptest.NewRequest(http.MethodGet, "/p", nil) + req.Header.Set("X-Forwarded-For", xff) + resp, err := app.Test(req, 3000) + require.NoError(t, err) + defer resp.Body.Close() + return resp.Header.Get("X-FP") + } + // Production uses the LAST (trusted edge) hop. Both XFF lists end in the + // same trusted hop → identical fingerprints despite different client IPs. + fpA = get("1.1.1.1, 9.9.9.9") + fpB = get("2.2.2.2, 9.9.9.9") + assert.NotEmpty(t, fpA) + assert.Equal(t, fpA, fpB, "production fingerprint keys on the rightmost (edge) hop") +} + +func TestFingerprintMiddleware_E2EBypassOverridesIP(t *testing.T) { + t.Setenv("E2E_TEST_TOKEN", "shared-e2e-secret-value") + + app := fiber.New(fiber.Config{ProxyHeader: "X-Forwarded-For"}) + app.Use(middleware.FingerprintMiddleware(middleware.FingerprintConfig{Production: true})) + app.Get("/p", func(c *fiber.Ctx) error { + c.Set("X-FP", middleware.GetFingerprint(c)) + return c.SendStatus(fiber.StatusOK) + }) + + get := func(token, sourceIP string) string { + req := httptest.NewRequest(http.MethodGet, "/p", nil) + req.Header.Set("X-Forwarded-For", "9.9.9.9") // same edge hop for all + if token != "" { + req.Header.Set("X-E2E-Test-Token", token) + } + if sourceIP != "" { + req.Header.Set("X-E2E-Source-IP", sourceIP) + } + resp, err := app.Test(req, 3000) + require.NoError(t, err) + defer resp.Body.Close() + return resp.Header.Get("X-FP") + } + + // With a valid token the source-IP header drives the fingerprint, so two + // different override IPs (same XFF edge) yield DIFFERENT fingerprints. + fp1 := get("shared-e2e-secret-value", "10.0.0.1") + fp2 := get("shared-e2e-secret-value", "10.1.0.1") + assert.NotEqual(t, fp1, fp2, "valid E2E token must let X-E2E-Source-IP drive the fingerprint") + + // A wrong token is ignored → falls back to the (shared) edge hop. + fpBad := get("wrong-token", "10.0.0.1") + fpNoTok := get("", "") + assert.Equal(t, fpBad, fpNoTok, "invalid/absent E2E token → fall back to XFF edge hop") +} + +// --------------------------------------------------------------------------- +// idempotency.go — explicit Idempotency-Key replay + empty-key 400 reject +// --------------------------------------------------------------------------- + +func TestIdempotency_ExplicitKeyReplay(t *testing.T) { + rdb, clean := newMiniRedis(t) + defer clean() + + calls := 0 + app := fiber.New() + app.Post("/db/new", + middleware.Idempotency(rdb, "db.new"), + func(c *fiber.Ctx) error { + calls++ + return c.Status(fiber.StatusCreated).JSON(fiber.Map{"n": calls}) + }, + ) + key := "explicit-" + uuid.NewString() + send := func() *http.Response { + req := httptest.NewRequest(http.MethodPost, "/db/new", bytes.NewReader([]byte(`{"name":"a"}`))) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Idempotency-Key", key) + resp, err := app.Test(req, 3000) + require.NoError(t, err) + return resp + } + r1 := send() + r1.Body.Close() + r2 := send() + r2.Body.Close() + assert.Equal(t, 1, calls, "explicit-key replay must run the handler once") +} + +func TestIdempotency_ExplicitKeyConflict(t *testing.T) { + rdb, clean := newMiniRedis(t) + defer clean() + app := fiber.New() + app.Post("/db/new", + middleware.Idempotency(rdb, "db.new"), + func(c *fiber.Ctx) error { return c.Status(fiber.StatusCreated).SendString("ok") }, + ) + key := "conflict-" + uuid.NewString() + send := func(body string) *http.Response { + req := httptest.NewRequest(http.MethodPost, "/db/new", bytes.NewReader([]byte(body))) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Idempotency-Key", key) + resp, err := app.Test(req, 3000) + require.NoError(t, err) + return resp + } + r1 := send(`{"name":"a"}`) + r1.Body.Close() + // Same key, DIFFERENT body → 409 idempotency_key_conflict. + r2 := send(`{"name":"b"}`) + defer r2.Body.Close() + assert.Equal(t, http.StatusConflict, r2.StatusCode, + "same Idempotency-Key with a different body must 409") +} + +func TestIdempotency_5xxNotCached(t *testing.T) { + rdb, clean := newMiniRedis(t) + defer clean() + calls := 0 + app := fiber.New() + app.Post("/db/new", + middleware.Idempotency(rdb, "db.new"), + func(c *fiber.Ctx) error { + calls++ + return c.SendStatus(fiber.StatusInternalServerError) // 5xx → not cached + }, + ) + key := "fivexx-" + uuid.NewString() + send := func() *http.Response { + req := httptest.NewRequest(http.MethodPost, "/db/new", bytes.NewReader([]byte(`{"n":1}`))) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Idempotency-Key", key) + resp, err := app.Test(req, 3000) + require.NoError(t, err) + return resp + } + send().Body.Close() + send().Body.Close() + assert.Equal(t, 2, calls, "5xx responses must NOT be cached — handler runs each time") +} + +// --------------------------------------------------------------------------- +// security_headers.go +// --------------------------------------------------------------------------- + +func TestSecurityHeaders_ProdEmitsHSTS(t *testing.T) { + app := fiber.New() + app.Use(middleware.SecurityHeaders(true)) + app.Get("/h", func(c *fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }) + + resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/h", nil), 3000) + require.NoError(t, err) + defer resp.Body.Close() + assert.Equal(t, middleware.HSTSValue, resp.Header.Get("Strict-Transport-Security")) + assert.Equal(t, middleware.XContentTypeOptionsValue, resp.Header.Get("X-Content-Type-Options")) + assert.Equal(t, middleware.XFrameOptionsValue, resp.Header.Get("X-Frame-Options")) + assert.Equal(t, middleware.ReferrerPolicyValue, resp.Header.Get("Referrer-Policy")) + assert.Equal(t, middleware.PermissionsPolicyValue, resp.Header.Get("Permissions-Policy")) + assert.Equal(t, middleware.CrossOriginResourcePolicyValue, resp.Header.Get("Cross-Origin-Resource-Policy")) +} + +func TestSecurityHeaders_DevOmitsHSTS(t *testing.T) { + app := fiber.New() + app.Use(middleware.SecurityHeaders(false)) + app.Get("/h", func(c *fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }) + + resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/h", nil), 3000) + require.NoError(t, err) + defer resp.Body.Close() + assert.Empty(t, resp.Header.Get("Strict-Transport-Security"), + "dev/non-prod must NOT advertise HSTS") + assert.Equal(t, middleware.XContentTypeOptionsValue, resp.Header.Get("X-Content-Type-Options")) +} + +// --------------------------------------------------------------------------- +// telemetry.go +// --------------------------------------------------------------------------- + +func TestTelemetry_RecordsAndPassesThrough(t *testing.T) { + app := fiber.New() + app.Use(middleware.Telemetry()) + app.Get("/ok", func(c *fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }) + app.Get("/boom", func(c *fiber.Ctx) error { return c.SendStatus(fiber.StatusInternalServerError) }) + + for path, want := range map[string]int{"/ok": 200, "/boom": 500} { + resp, err := app.Test(httptest.NewRequest(http.MethodGet, path, nil), 3000) + require.NoError(t, err) + assert.Equal(t, want, resp.StatusCode) + resp.Body.Close() + } +} + +// --------------------------------------------------------------------------- +// request_id.go +// --------------------------------------------------------------------------- + +func TestRequestID_GeneratesAndPropagates(t *testing.T) { + app := fiber.New() + app.Use(middleware.RequestID()) + app.Get("/r", func(c *fiber.Ctx) error { + // In-handler the locals + Go context both carry the id. + assert.NotEmpty(t, middleware.GetRequestID(c)) + assert.Equal(t, middleware.GetRequestID(c), + middleware.RequestIDFromContext(c.UserContext())) + return c.SendStatus(fiber.StatusOK) + }) + + // Generated when absent. + resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/r", nil), 3000) + require.NoError(t, err) + assert.NotEmpty(t, resp.Header.Get(middleware.HeaderRequestID)) + resp.Body.Close() + + // Propagated when supplied. + req := httptest.NewRequest(http.MethodGet, "/r", nil) + req.Header.Set(middleware.HeaderRequestID, "supplied-id-123") + resp2, err := app.Test(req, 3000) + require.NoError(t, err) + assert.Equal(t, "supplied-id-123", resp2.Header.Get(middleware.HeaderRequestID)) + resp2.Body.Close() +} + +func TestRequestIDFromContext_Empty(t *testing.T) { + assert.Empty(t, middleware.RequestIDFromContext(context.Background())) +} + +func TestGetRequestID_EmptyWhenUnset(t *testing.T) { + app := fiber.New() + app.Get("/n", func(c *fiber.Ctx) error { + assert.Empty(t, middleware.GetRequestID(c)) + return c.SendStatus(fiber.StatusOK) + }) + resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/n", nil), 3000) + require.NoError(t, err) + resp.Body.Close() +} + +// --------------------------------------------------------------------------- +// newrelic.go — nil agent degrades to no-op +// --------------------------------------------------------------------------- + +func TestNewRelic_NilAppNoOp(t *testing.T) { + app := fiber.New() + app.Use(middleware.NewRelic(nil)) + app.Get("/nr", func(c *fiber.Ctx) error { + assert.Nil(t, middleware.GetNRTxn(c), "no txn when agent disabled") + return c.SendStatus(fiber.StatusOK) + }) + resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/nr", nil), 3000) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + resp.Body.Close() +} + +// newrelic_metrics.go emit helpers no-op when the global app is nil (default +// in tests). Calling them must not panic and exercises the nil-guard branch. +func TestNewRelicMetrics_NoOpWhenNilApp(t *testing.T) { + middleware.SetNRApp(nil) + middleware.RecordProvisionSuccess("postgres") + middleware.RecordProvisionFail("redis", "quota") + middleware.RecordProvisionFail("redis", "") // empty reason branch + middleware.RecordResourceExpired("mongodb") +} + +// --------------------------------------------------------------------------- +// revocation.go — fail-open on Redis error, hit/miss on live miniredis +// --------------------------------------------------------------------------- + +func TestIsJTIRevoked_NilClientAndEmptyJTI(t *testing.T) { + middleware.SetRevocationDB(nil) + revoked, err := middleware.IsJTIRevoked(context.Background(), "jti") + require.NoError(t, err) + assert.False(t, revoked, "nil client → not revoked (fail-open)") +} + +func TestIsJTIRevoked_HitAndMiss(t *testing.T) { + rdb, clean := newMiniRedis(t) + defer clean() + middleware.SetRevocationDB(rdb) + defer middleware.SetRevocationDB(nil) + + ctx := context.Background() + // Empty jti short-circuits. + revoked, err := middleware.IsJTIRevoked(ctx, "") + require.NoError(t, err) + assert.False(t, revoked) + + // Not present → not revoked. + revoked, err = middleware.IsJTIRevoked(ctx, "abc") + require.NoError(t, err) + assert.False(t, revoked) + + // Present → revoked. + require.NoError(t, rdb.Set(ctx, "session.revoked:abc", "1", 0).Err()) + revoked, err = middleware.IsJTIRevoked(ctx, "abc") + require.NoError(t, err) + assert.True(t, revoked) +} + +func TestIsJTIRevoked_RedisErrorFailsOpen(t *testing.T) { + rdb, clean := newMiniRedis(t) + clean() // close immediately → every op errors + middleware.SetRevocationDB(rdb) + defer middleware.SetRevocationDB(nil) + + revoked, err := middleware.IsJTIRevoked(context.Background(), "abc") + assert.Error(t, err, "a Redis error must surface") + assert.False(t, revoked, "but the request still treats it as not-revoked (fail-open)") +} + +// --------------------------------------------------------------------------- +// presign_token_rate_limit.go — nil rdb no-op, allow under cap, 429 over cap +// --------------------------------------------------------------------------- + +func presignApp(rdb *redis.Client) *fiber.App { + app := fiber.New() + // Register the limiter as a route-level handler (NOT app.Use) so the + // :token route param is bound — matching the production router wiring. + app.Post("/storage/:token/presign", + middleware.PresignTokenRateLimit(rdb), + func(c *fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }, + ) + return app +} + +func TestPresignRateLimit_NilRedisNoOp(t *testing.T) { + app := presignApp(nil) + resp, err := app.Test(httptest.NewRequest(http.MethodPost, "/storage/tok123/presign", nil), 3000) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + resp.Body.Close() +} + +func TestPresignRateLimit_AllowsUnderCapThen429(t *testing.T) { + rdb, clean := newMiniRedis(t) + defer clean() + app := presignApp(rdb) + + tok := uuid.NewString() + url := "/storage/" + tok + "/presign" + // The first request is always allowed; once the rolling-window count + // reaches the cap a subsequent request returns 429 with a Retry-After. + // Send a burst well past the cap and assert the limit eventually trips. + var firstStatus, last429 int + for i := 0; i < middleware.PresignPerTokenPerMinute+5; i++ { + resp, err := app.Test(httptest.NewRequest(http.MethodPost, url, nil), 3000) + require.NoError(t, err) + if i == 0 { + firstStatus = resp.StatusCode + } + if resp.StatusCode == http.StatusTooManyRequests { + last429 = resp.StatusCode + assert.NotEmpty(t, resp.Header.Get(fiber.HeaderRetryAfter), + "429 must carry a Retry-After header") + } + resp.Body.Close() + } + assert.Equal(t, http.StatusOK, firstStatus, "first request must be allowed") + assert.Equal(t, http.StatusTooManyRequests, last429, + "a burst past the cap must eventually trip the per-token limit") +} + +func TestPresignRateLimit_RedisErrorFailsOpen(t *testing.T) { + rdb, clean := newMiniRedis(t) + clean() // closed → pipeline errors + app := presignApp(rdb) + resp, err := app.Test(httptest.NewRequest(http.MethodPost, "/storage/tok/presign", nil), 3000) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode, "Redis error must fail open (200, request proceeds)") + resp.Body.Close() +} + +// --------------------------------------------------------------------------- +// role_lookup.go — sqlmock-backed PopulateTeamRole +// --------------------------------------------------------------------------- + +func TestPopulateTeamRole_NoIDsPassesThrough(t *testing.T) { + middleware.SetRoleLookupDB(nil) + defer middleware.SetRoleLookupDB(nil) + app := fiber.New() + app.Use(middleware.PopulateTeamRole()) + app.Get("/role", func(c *fiber.Ctx) error { + assert.Empty(t, middleware.GetTeamRole(c)) + return c.SendStatus(fiber.StatusOK) + }) + resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/role", nil), 3000) + require.NoError(t, err) + resp.Body.Close() +} + +func TestPopulateTeamRole_ResolvesRole(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer db.Close() + middleware.SetRoleLookupDB(db) + defer middleware.SetRoleLookupDB(nil) + + uid := uuid.NewString() + tid := uuid.NewString() + mock.ExpectQuery("SELECT role FROM users"). + WithArgs(uid, tid). + WillReturnRows(sqlmock.NewRows([]string{"role"}).AddRow("owner")) + + app := fiber.New() + // Inject locals to simulate RequireAuth having populated them. + app.Use(func(c *fiber.Ctx) error { + c.Locals(middleware.LocalKeyUserID, uid) + c.Locals(middleware.LocalKeyTeamID, tid) + return c.Next() + }) + app.Use(middleware.PopulateTeamRole()) + app.Get("/role", func(c *fiber.Ctx) error { + assert.Equal(t, "owner", middleware.GetTeamRole(c)) + return c.SendStatus(fiber.StatusOK) + }) + resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/role", nil), 3000) + require.NoError(t, err) + resp.Body.Close() +} + +func TestPopulateTeamRole_NoRowsIsTolerated(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer db.Close() + middleware.SetRoleLookupDB(db) + defer middleware.SetRoleLookupDB(nil) + + uid := uuid.NewString() + tid := uuid.NewString() + mock.ExpectQuery("SELECT role FROM users"). + WithArgs(uid, tid). + WillReturnError(sql.ErrNoRows) + + app := fiber.New() + app.Use(func(c *fiber.Ctx) error { + c.Locals(middleware.LocalKeyUserID, uid) + c.Locals(middleware.LocalKeyTeamID, tid) + return c.Next() + }) + app.Use(middleware.PopulateTeamRole()) + app.Get("/role", func(c *fiber.Ctx) error { + assert.Empty(t, middleware.GetTeamRole(c)) + return c.SendStatus(fiber.StatusOK) + }) + resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/role", nil), 3000) + require.NoError(t, err) + resp.Body.Close() +} + +// --------------------------------------------------------------------------- +// env_policy.go — RequireEnvAccess DB paths via sqlmock +// --------------------------------------------------------------------------- + +func TestRequireEnvAccess_NilDBAllows(t *testing.T) { + middleware.SetEnvPolicyDB(nil) + defer middleware.SetEnvPolicyDB(nil) + app := fiber.New() + app.Use(func(c *fiber.Ctx) error { + c.Locals(middleware.LocalKeyTeamID, uuid.NewString()) + return c.Next() + }) + app.Use(middleware.RequireEnvAccess(middleware.EnvPolicyActionDeploy)) + app.Post("/deploy", func(c *fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }) + + resp, err := app.Test(httptest.NewRequest(http.MethodPost, "/deploy", nil), 3000) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode, "no policy DB → allow (fail-open)") + resp.Body.Close() +} + +func TestRequireEnvAccess_EmptyPolicyAllows(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer db.Close() + middleware.SetEnvPolicyDB(db) + defer middleware.SetEnvPolicyDB(nil) + + tid := uuid.New() + mock.ExpectQuery("SELECT env_policy FROM teams"). + WithArgs(tid). + WillReturnRows(sqlmock.NewRows([]string{"env_policy"}).AddRow([]byte(`{}`))) + + app := fiber.New() + app.Use(func(c *fiber.Ctx) error { + c.Locals(middleware.LocalKeyTeamID, tid.String()) + return c.Next() + }) + app.Use(middleware.RequireEnvAccess(middleware.EnvPolicyActionDeploy)) + app.Post("/deploy", func(c *fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }) + + resp, err := app.Test(httptest.NewRequest(http.MethodPost, "/deploy", nil), 3000) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode, "empty policy → allow") + resp.Body.Close() +} + +func TestRequireEnvAccess_RoleAllowedAndDenied(t *testing.T) { + tid := uuid.New() + // Policy: production/deploy allows owner only. + policy := `{"production":{"deploy":["owner"]}}` + + run := func(role string, wantStatus int) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer db.Close() + middleware.SetEnvPolicyDB(db) + defer middleware.SetEnvPolicyDB(nil) + + mock.ExpectQuery("SELECT env_policy FROM teams"). + WithArgs(tid). + WillReturnRows(sqlmock.NewRows([]string{"env_policy"}).AddRow([]byte(policy))) + + app := fiber.New() + app.Use(func(c *fiber.Ctx) error { + c.Locals(middleware.LocalKeyTeamID, tid.String()) + c.Locals(middleware.LocalKeyTeamRole, role) + return c.Next() + }) + app.Use(middleware.RequireEnvAccess(middleware.EnvPolicyActionDeploy)) + app.Post("/deploy", func(c *fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }) + + // env=production via query string so defaultEnvLookup resolves it. + resp, err := app.Test(httptest.NewRequest(http.MethodPost, "/deploy?env=production", nil), 3000) + require.NoError(t, err) + assert.Equal(t, wantStatus, resp.StatusCode, "role=%s", role) + resp.Body.Close() + } + + run("owner", http.StatusOK) // allowed + run("developer", http.StatusForbidden) // denied → 403 env_policy_denied +} + +func TestRequireEnvAccess_BadTeamIDPassesThrough(t *testing.T) { + db, _, err := sqlmock.New() + require.NoError(t, err) + defer db.Close() + middleware.SetEnvPolicyDB(db) + defer middleware.SetEnvPolicyDB(nil) + + app := fiber.New() + app.Use(func(c *fiber.Ctx) error { + c.Locals(middleware.LocalKeyTeamID, "not-a-uuid") + return c.Next() + }) + app.Use(middleware.RequireEnvAccess(middleware.EnvPolicyActionDeploy)) + app.Post("/deploy", func(c *fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }) + + resp, err := app.Test(httptest.NewRequest(http.MethodPost, "/deploy", nil), 3000) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode, "unparseable team id → pass through") + resp.Body.Close() +} + +// --------------------------------------------------------------------------- +// api_key.go — IsAPIKey / AuthenticateAPIKey nil-DB path / scope getters +// --------------------------------------------------------------------------- + +func TestIsAPIKey(t *testing.T) { + assert.True(t, middleware.IsAPIKey("ink_abc123")) + assert.False(t, middleware.IsAPIKey("not-a-pat")) +} + +func TestAuthenticateAPIKey_Success(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer db.Close() + middleware.SetAPIKeyDB(db) + defer middleware.SetAPIKeyDB(nil) + + keyID := uuid.New() + teamID := uuid.New() + creator := uuid.New() + rows := sqlmock.NewRows([]string{ + "id", "team_id", "created_by", "name", "key_hash", + "scopes", "last_used_at", "revoked_at", "created_at", + }).AddRow( + keyID, teamID, creator, "ci-key", "deadbeef", + "{admin,deploy}", nil, nil, time.Now(), + ) + mock.ExpectQuery("SELECT .* FROM api_keys").WillReturnRows(rows) + // Best-effort async TouchAPIKey may or may not run before assertions; be + // lenient about its UPDATE. + mock.ExpectExec("UPDATE api_keys").WillReturnResult(sqlmock.NewResult(0, 1)) + + app := fiber.New() + app.Get("/k", func(c *fiber.Ctx) error { + ok, aerr := middleware.AuthenticateAPIKey(c, "ink_secret") + require.NoError(t, aerr) + assert.True(t, ok) + assert.True(t, middleware.IsAuthedViaAPIKey(c)) + assert.Equal(t, []string{"admin", "deploy"}, middleware.GetAPIKeyScopes(c)) + assert.Equal(t, teamID.String(), middleware.GetTeamID(c)) + return c.SendStatus(fiber.StatusOK) + }) + resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/k", nil), 3000) + require.NoError(t, err) + resp.Body.Close() +} + +func TestAuthenticateAPIKey_NotFound(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer db.Close() + middleware.SetAPIKeyDB(db) + defer middleware.SetAPIKeyDB(nil) + mock.ExpectQuery("SELECT .* FROM api_keys").WillReturnError(sql.ErrNoRows) + + app := fiber.New() + app.Get("/k", func(c *fiber.Ctx) error { + ok, aerr := middleware.AuthenticateAPIKey(c, "ink_missing") + assert.NoError(t, aerr, "not-found is (false, nil)") + assert.False(t, ok) + return c.SendStatus(fiber.StatusOK) + }) + resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/k", nil), 3000) + require.NoError(t, err) + resp.Body.Close() +} + +func TestAuthenticateAPIKey_NilDB(t *testing.T) { + middleware.SetAPIKeyDB(nil) + app := fiber.New() + app.Get("/k", func(c *fiber.Ctx) error { + ok, err := middleware.AuthenticateAPIKey(c, "ink_whatever") + assert.False(t, ok) + assert.Error(t, err, "nil DB → error") + assert.False(t, middleware.IsAuthedViaAPIKey(c)) + assert.Nil(t, middleware.GetAPIKeyScopes(c)) + return c.SendStatus(fiber.StatusOK) + }) + resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/k", nil), 3000) + require.NoError(t, err) + resp.Body.Close() +} diff --git a/internal/middleware/coverage_internal_test.go b/internal/middleware/coverage_internal_test.go new file mode 100644 index 0000000..ee1eb21 --- /dev/null +++ b/internal/middleware/coverage_internal_test.go @@ -0,0 +1,339 @@ +package middleware + +// coverage_internal_test.go — white-box unit tests for unexported helpers +// across the middleware package. Lives in `package middleware` (not +// middleware_test) so it can reach the unexported pure functions that the +// black-box suites can't see. Targets the 0%-coverage helpers surfaced by +// the coverage audit: env_policy, admin_audit, idempotency canonicalisation, +// presign masking, rate-limit math, and the geo/cloud-vendor lookup table. + +import ( + "bytes" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/gofiber/fiber/v2" + "github.com/golang-jwt/jwt/v4" + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// fiberAppForLookup builds a Fiber app whose /q route reports the result of +// the unexported defaultEnvLookup as a response header. +func fiberAppForLookup(t *testing.T) *fiber.App { + t.Helper() + app := fiber.New() + h := func(c *fiber.Ctx) error { + env, err := defaultEnvLookup(c) + require.NoError(t, err) + c.Set("X-Env", env) + return c.SendStatus(fiber.StatusOK) + } + app.Get("/q", h) + app.Post("/q", h) + app.Get("/p/:env", h) + return app +} + +func probeEnvLookup(t *testing.T, app *fiber.App, path, contentType, body string) string { + t.Helper() + method := http.MethodGet + var rdr *bytes.Reader + if body != "" || contentType != "" { + method = http.MethodPost + rdr = bytes.NewReader([]byte(body)) + } + var req *http.Request + if rdr != nil { + req = httptest.NewRequest(method, path, rdr) + } else { + req = httptest.NewRequest(method, path, nil) + } + if contentType != "" { + req.Header.Set("Content-Type", contentType) + } + resp, err := app.Test(req, 3000) + require.NoError(t, err) + defer resp.Body.Close() + return resp.Header.Get("X-Env") +} + +// --------------------------------------------------------------------------- +// env_policy.go — formatAllowedRoles / envPolicyDeniedAgentAction +// --------------------------------------------------------------------------- + +func TestFormatAllowedRoles(t *testing.T) { + assert.Equal(t, "", formatAllowedRoles(nil)) + assert.Equal(t, "", formatAllowedRoles([]string{})) + assert.Equal(t, "owner", formatAllowedRoles([]string{"owner"})) + assert.Equal(t, "owner or developer", formatAllowedRoles([]string{"owner", "developer"})) + // 3+ uses Oxford-style comma list ending in "or". + got := formatAllowedRoles([]string{"owner", "admin", "developer"}) + assert.Equal(t, "owner, admin, or developer", got) +} + +func TestEnvPolicyDeniedAgentAction(t *testing.T) { + // Known role. + out := envPolicyDeniedAgentAction("production", "owner", "deploy", "developer") + assert.Contains(t, out, "Tell the user") + assert.Contains(t, out, "production") + assert.Contains(t, out, "owner") + assert.Contains(t, out, "deploy") + assert.Contains(t, out, "developer") + assert.Contains(t, out, "https://instanode.dev/app/team") + + // Empty caller role falls back to "unknown". + out2 := envPolicyDeniedAgentAction("staging", "owner or developer", "vault_write", "") + assert.Contains(t, out2, "unknown") +} + +// --------------------------------------------------------------------------- +// admin_audit.go — adminAuditPathSuffix / parseTeamIDFromAdminPath / +// adminAuditSummary / AdminAuditEnsureMetadataNoPrefix +// --------------------------------------------------------------------------- + +func TestAdminAuditPathSuffix(t *testing.T) { + // Empty prefix → sentinel. + assert.Equal(t, adminAuditSuffixInvalid, adminAuditPathSuffix("/api/v1/secret/customers", "")) + + // Canonical mount → suffix returned. + assert.Equal(t, "customers/abc/tier", + adminAuditPathSuffix("/api/v1/secret/customers/abc/tier", "secret")) + + // Bare prefix, no trailing slash → empty suffix. + assert.Equal(t, "", adminAuditPathSuffix("/api/v1/secret", "secret")) + + // Mismatched prefix → invalid sentinel (don't leak the raw path). + assert.Equal(t, adminAuditSuffixInvalid, + adminAuditPathSuffix("/api/v1/other/customers", "secret")) +} + +func TestParseTeamIDFromAdminPath(t *testing.T) { + id := uuid.New() + // Valid /customers//... segment. + got := parseTeamIDFromAdminPath("/api/v1/secret/customers/" + id.String() + "/tier") + assert.Equal(t, id, got) + + // Trailing segment stripped, bare /customers/. + got2 := parseTeamIDFromAdminPath("/api/v1/secret/customers/" + id.String()) + assert.Equal(t, id, got2) + + // No customers segment → Nil. + assert.Equal(t, uuid.Nil, parseTeamIDFromAdminPath("/api/v1/secret/teams/list")) + + // Non-UUID segment → Nil. + assert.Equal(t, uuid.Nil, parseTeamIDFromAdminPath("/api/v1/secret/customers/not-a-uuid/tier")) +} + +func TestAdminAuditSummary(t *testing.T) { + // Success path with email + suffix. + s := adminAuditSummary(AdminAuditMetadata{Email: "a@b.com", PathSuffix: "customers/x"}) + assert.Equal(t, "a@b.com accessed customers/x", s) + + // Anonymous + root. + s2 := adminAuditSummary(AdminAuditMetadata{}) + assert.Equal(t, "anonymous accessed (root)", s2) + + // Denied path includes reason. + s3 := adminAuditSummary(AdminAuditMetadata{Email: "x@y.com", PathSuffix: "p", DeniedBy: "rate_limit"}) + assert.Equal(t, "x@y.com denied (rate_limit) on p", s3) +} + +func TestAdminAuditEnsureMetadataNoPrefix(t *testing.T) { + // Empty prefix always true. + assert.True(t, AdminAuditEnsureMetadataNoPrefix(AdminAuditMetadata{}, "")) + + // Prefix absent from marshaled metadata → true. + clean := AdminAuditMetadata{Email: "a@b.com", PathSuffix: "customers/x"} + assert.True(t, AdminAuditEnsureMetadataNoPrefix(clean, "topsecret")) + + // Prefix leaked into a field → false. + leaked := AdminAuditMetadata{PathSuffix: "topsecret/customers"} + assert.False(t, AdminAuditEnsureMetadataNoPrefix(leaked, "topsecret")) +} + +// --------------------------------------------------------------------------- +// idempotency.go — looksLikeJSON / writeCanonicalJSON / canonicalJSON +// --------------------------------------------------------------------------- + +func TestLooksLikeJSON(t *testing.T) { + assert.True(t, looksLikeJSON([]byte(`{"a":1}`))) + assert.True(t, looksLikeJSON([]byte(` [1,2]`))) + assert.True(t, looksLikeJSON([]byte("\n\t {}"))) + assert.False(t, looksLikeJSON([]byte("plain text"))) + assert.False(t, looksLikeJSON([]byte(""))) + assert.False(t, looksLikeJSON([]byte(" "))) +} + +func TestWriteCanonicalJSON_SortsKeys(t *testing.T) { + in := map[string]interface{}{ + "b": 2, + "a": 1, + "c": map[string]interface{}{"z": true, "y": false}, + } + var buf bytes.Buffer + require.NoError(t, writeCanonicalJSON(&buf, in)) + // Keys must be sorted recursively. + assert.Equal(t, `{"a":1,"b":2,"c":{"y":false,"z":true}}`, buf.String()) +} + +func TestWriteCanonicalJSON_ArrayOrderPreserved(t *testing.T) { + in := []interface{}{3, 1, 2} + var buf bytes.Buffer + require.NoError(t, writeCanonicalJSON(&buf, in)) + assert.Equal(t, `[3,1,2]`, buf.String()) +} + +func TestCanonicalJSON_NestedMixedTypes(t *testing.T) { + // Exercises the writeCanonicalJSON default case (string/number/bool/null) + // nested inside both maps and arrays. + in := map[string]interface{}{ + "arr": []interface{}{"s", 1.5, true, nil}, + "flag": false, + "obj": map[string]interface{}{"k": "v"}, + } + out, err := canonicalJSON(in) + require.NoError(t, err) + assert.Equal(t, `{"arr":["s",1.5,true,null],"flag":false,"obj":{"k":"v"}}`, out) +} + +func TestCanonicalJSON_OrderIndependent(t *testing.T) { + a, err := canonicalJSON(map[string]interface{}{"x": 1, "y": 2}) + require.NoError(t, err) + b, err := canonicalJSON(map[string]interface{}{"y": 2, "x": 1}) + require.NoError(t, err) + assert.Equal(t, a, b, "key order must not change the canonical form") +} + +// --------------------------------------------------------------------------- +// presign_token_rate_limit.go — maskPresignToken +// --------------------------------------------------------------------------- + +func TestMaskPresignToken(t *testing.T) { + assert.Equal(t, "***", maskPresignToken("short")) + assert.Equal(t, "***", maskPresignToken("12345678")) // exactly 8 → masked + assert.Equal(t, "abcdefgh...", maskPresignToken("abcdefghijklmnop")) +} + +// --------------------------------------------------------------------------- +// rate_limit.go — nextUTCMidnight +// --------------------------------------------------------------------------- + +func TestNextUTCMidnight(t *testing.T) { + in := time.Date(2026, 5, 22, 13, 45, 0, 0, time.UTC) + got := nextUTCMidnight(in) + want := time.Date(2026, 5, 23, 0, 0, 0, 0, time.UTC) + assert.Equal(t, want, got) + assert.True(t, got.After(in)) + + // A non-UTC input is normalised to UTC first. + loc := time.FixedZone("UTC+5", 5*3600) + in2 := time.Date(2026, 12, 31, 23, 0, 0, 0, loc) // 18:00 UTC same day + got2 := nextUTCMidnight(in2) + assert.Equal(t, time.Date(2027, 1, 1, 0, 0, 0, 0, time.UTC), got2) +} + +// --------------------------------------------------------------------------- +// log_scrubber.go — scrub free function + ScrubAdminPath edge cases +// --------------------------------------------------------------------------- + +func TestScrub_Internal(t *testing.T) { + h := &LogScrubber{secret: "sekret"} + assert.Equal(t, "a//b", h.scrub("a/sekret/b")) + // Empty secret on the struct → passthrough. + h2 := &LogScrubber{secret: ""} + assert.Equal(t, "untouched", h2.scrub("untouched")) +} + +func TestScrubAdminPath_Edges(t *testing.T) { + assert.Equal(t, "x", ScrubAdminPath("x", "")) // empty secret + assert.Equal(t, "", ScrubAdminPath("", "secret")) // empty input + assert.Equal(t, strings.Repeat(AdminScrubSentinel, 1)+"/y", + ScrubAdminPath("secret/y", "secret")) +} + +// --------------------------------------------------------------------------- +// geo.go — cloudASNs lookup table sanity +// --------------------------------------------------------------------------- + +func TestCloudASNs_Table(t *testing.T) { + assert.Equal(t, "aws", cloudASNs[16509]) + assert.Equal(t, "gcp", cloudASNs[15169]) + assert.Equal(t, "azure", cloudASNs[8075]) + assert.Equal(t, "cloudflare", cloudASNs[13335]) + _, ok := cloudASNs[1] + assert.False(t, ok, "unknown ASN must not be in the table") +} + +// --------------------------------------------------------------------------- +// auth.go — audienceMatches +// --------------------------------------------------------------------------- + +func TestAudienceMatches(t *testing.T) { + // Empty canonical → never matches (the defensive guard). + assert.False(t, audienceMatches(jwt.ClaimStrings{"https://h/x"}, "")) + // Present + matching entry. + assert.True(t, audienceMatches(jwt.ClaimStrings{"https://a", "https://h/x"}, "https://h/x")) + // Present + no match. + assert.False(t, audienceMatches(jwt.ClaimStrings{"https://a"}, "https://h/x")) +} + +// --------------------------------------------------------------------------- +// dpop.go — urlMatches +// --------------------------------------------------------------------------- + +func TestURLMatches(t *testing.T) { + // scheme + host case-insensitive, trailing slash ignored, path exact. + assert.True(t, urlMatches("https://API.example.com/x", "https://api.example.com/x/")) + assert.True(t, urlMatches("https://h.com", "https://h.com/")) // both normalise to "/" + // scheme mismatch. + assert.False(t, urlMatches("http://h.com/x", "https://h.com/x")) + // host mismatch. + assert.False(t, urlMatches("https://a.com/x", "https://b.com/x")) + // path mismatch. + assert.False(t, urlMatches("https://h.com/x", "https://h.com/y")) + // unparseable → false. + assert.False(t, urlMatches("://bad", "https://h.com")) +} + +// --------------------------------------------------------------------------- +// idempotency.go — validateIdempotencyKey +// --------------------------------------------------------------------------- + +func TestValidateIdempotencyKey(t *testing.T) { + assert.Error(t, validateIdempotencyKey(""), "empty rejected") + assert.NoError(t, validateIdempotencyKey("abc-123_OK")) + // Non-ASCII printable rejected. + assert.Error(t, validateIdempotencyKey("abc\x01def")) + assert.Error(t, validateIdempotencyKey("emoji-\U0001F600")) + // Over the length cap. + assert.Error(t, validateIdempotencyKey(strings.Repeat("a", 1000))) +} + +// --------------------------------------------------------------------------- +// env_policy.go — defaultEnvLookup branch coverage +// --------------------------------------------------------------------------- + +func TestDefaultEnvLookup_Branches(t *testing.T) { + app := fiberAppForLookup(t) + // query string wins. + assert.Equal(t, "staging", probeEnvLookup(t, app, "/q?env=staging", "", "")) + // json body "env". + assert.Equal(t, "production", probeEnvLookup(t, app, "/q", "application/json", `{"env":"production"}`)) + // json body "to" fallback. + assert.Equal(t, "qa", probeEnvLookup(t, app, "/q", "application/json", `{"to":"qa"}`)) + // non-json content-type → empty. + assert.Equal(t, "", probeEnvLookup(t, app, "/q", "text/plain", "env=prod")) + // empty body → empty. + assert.Equal(t, "", probeEnvLookup(t, app, "/q", "", "")) + // :env route param wins (highest precedence). + assert.Equal(t, "production", probeEnvLookup(t, app, "/p/production", "", "")) + // malformed JSON body → empty (best-effort parse, no error surfaced). + assert.Equal(t, "", probeEnvLookup(t, app, "/q", "application/json", `{bad`)) + // JSON body with neither env nor to → empty (the trailing return). + assert.Equal(t, "", probeEnvLookup(t, app, "/q", "application/json", `{"x":1}`)) +} diff --git a/internal/middleware/coverage_more_test.go b/internal/middleware/coverage_more_test.go new file mode 100644 index 0000000..520513d --- /dev/null +++ b/internal/middleware/coverage_more_test.go @@ -0,0 +1,1225 @@ +package middleware_test + +// coverage_more_test.go — second coverage batch: drives the remaining +// low-coverage middleware surface (RequireAuth JWT path, rate-limit +// locals + dedup cap, idempotency fingerprint/multipart canonicalisation, +// log-scrubber child loggers, NewRelic with a live agent, env-policy +// route-param + body lookups + WithEnvLookup, admin-audit locals). + +import ( + "bytes" + "context" + "encoding/json" + "log/slog" + "mime/multipart" + "net/http" + "net/http/httptest" + "testing" + "time" + + sqlmock "github.com/DATA-DOG/go-sqlmock" + "github.com/gofiber/fiber/v2" + "github.com/golang-jwt/jwt/v4" + "github.com/google/uuid" + "github.com/lestrrat-go/jwx/v2/jwa" + "github.com/lestrrat-go/jwx/v2/jws" + jwxjwt "github.com/lestrrat-go/jwx/v2/jwt" + "github.com/newrelic/go-agent/v3/newrelic" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "instant.dev/internal/config" + "instant.dev/internal/middleware" + "instant.dev/internal/testhelpers" +) + +// signSessionRich mints a session JWT carrying email + impersonation + +// read-only claims so RequireAuth populates every optional local. +func signSessionRich(t *testing.T, secret, uid, tid, email, impersonatedBy string, readOnly bool) string { + t.Helper() + type sessionClaims struct { + UserID string `json:"uid"` + TeamID string `json:"tid"` + Email string `json:"email,omitempty"` + ReadOnly bool `json:"read_only,omitempty"` + ImpersonatedBy string `json:"impersonated_by,omitempty"` + jwt.RegisteredClaims + } + claims := sessionClaims{ + UserID: uid, + TeamID: tid, + Email: email, + ReadOnly: readOnly, + ImpersonatedBy: impersonatedBy, + RegisteredClaims: jwt.RegisteredClaims{ + IssuedAt: jwt.NewNumericDate(time.Now()), + ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)), + ID: uuid.NewString(), + }, + } + tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + signed, err := tok.SignedString([]byte(secret)) + require.NoError(t, err) + return signed +} + +// --------------------------------------------------------------------------- +// auth.go — RequireAuth JWT happy path + getters (email/impersonation) +// --------------------------------------------------------------------------- + +func requireAuthApp(secret string) *fiber.App { + cfg := &config.Config{JWTSecret: secret} + app := fiber.New() + app.Use(middleware.RequireAuth(cfg)) + app.Get("/me", func(c *fiber.Ctx) error { + return c.JSON(fiber.Map{ + "email": middleware.GetEmail(c), + "impersonated_by": middleware.GetImpersonatedBy(c), + "read_only": middleware.IsReadOnly(c), + "user_id": middleware.GetUserID(c), + }) + }) + return app +} + +func TestRequireAuth_ValidJWTPopulatesLocals(t *testing.T) { + middleware.SetRevocationDB(nil) // no revocation check + app := requireAuthApp(testhelpers.TestJWTSecret) + tok := signSessionRich(t, testhelpers.TestJWTSecret, + uuid.NewString(), uuid.NewString(), "u@x.com", "admin@x.com", true) + req := httptest.NewRequest(http.MethodGet, "/me", nil) + req.Header.Set("Authorization", "Bearer "+tok) + resp, err := app.Test(req, 3000) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + resp.Body.Close() +} + +func TestRequireAuth_NoBearerRejected(t *testing.T) { + app := requireAuthApp(testhelpers.TestJWTSecret) + resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/me", nil), 3000) + require.NoError(t, err) + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) + resp.Body.Close() +} + +func TestRequireAuth_WrongAlgRejected(t *testing.T) { + app := requireAuthApp(testhelpers.TestJWTSecret) + // "none" alg token — must be rejected by WithValidMethods. + tok := jwt.NewWithClaims(jwt.SigningMethodNone, jwt.MapClaims{ + "uid": uuid.NewString(), "tid": uuid.NewString(), + }) + signed, err := tok.SignedString(jwt.UnsafeAllowNoneSignatureType) + require.NoError(t, err) + req := httptest.NewRequest(http.MethodGet, "/me", nil) + req.Header.Set("Authorization", "Bearer "+signed) + resp, err := app.Test(req, 3000) + require.NoError(t, err) + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode, "alg=none must be rejected") + resp.Body.Close() +} + +func TestGetEmailImpersonation_DefaultsEmpty(t *testing.T) { + app := fiber.New() + app.Get("/g", func(c *fiber.Ctx) error { + assert.Empty(t, middleware.GetEmail(c)) + assert.Empty(t, middleware.GetImpersonatedBy(c)) + return c.SendStatus(fiber.StatusOK) + }) + resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/g", nil), 3000) + require.NoError(t, err) + resp.Body.Close() +} + +// --------------------------------------------------------------------------- +// dpop.go — verifyDPoPProof error branches (htu mismatch, JKT mismatch) +// reusing the fixture defined in dpop_test.go (same _test package). +// --------------------------------------------------------------------------- + +func TestDPoP_HTUMismatchRejected(t *testing.T) { + middleware.ResetDPoPRedisBreakerForTest() + rdb, clean := newMiniRedis(t) + defer clean() + f := newDPoPFixture(t) + app := newDPoPApp(rdb) + // Proof's htu points at a DIFFERENT path than the request → urlMatches + // returns false → 401. + proof := f.signProof("POST", "https://api.instanode.dev/cache/new", time.Now(), uuid.NewString()) + resp := runRequest(t, app, "POST", "https://api.instanode.dev/db/new", f.bearer, proof) + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode, + "htu pointing at a different resource must be rejected") + resp.Body.Close() +} + +func TestDPoP_WrongMethodInProofRejected(t *testing.T) { + // htm in the proof says GET but the request is POST → getStringClaim + // returns "GET", the htm comparison fails → 401. Exercises the htm + // branch of verifyDPoPProof + getStringClaim success path. + middleware.ResetDPoPRedisBreakerForTest() + rdb, clean := newMiniRedis(t) + defer clean() + f := newDPoPFixture(t) + app := newDPoPApp(rdb) + proof := f.signProof("GET", "https://api.instanode.dev/db/new", time.Now(), uuid.NewString()) + resp := runRequest(t, app, "POST", "https://api.instanode.dev/db/new", f.bearer, proof) + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode, + "a proof whose htm doesn't match the request method must be rejected") + resp.Body.Close() +} + +// signPartialProof builds a DPoP proof omitting selected claims so the +// getStringClaim `!ok` branches in verifyDPoPProof are exercised. +func signPartialProof(t *testing.T, f *dpopFixture, setHTM, setHTU bool) string { + t.Helper() + tok := jwxjwt.New() + if setHTM { + require.NoError(t, tok.Set("htm", "POST")) + } + if setHTU { + require.NoError(t, tok.Set("htu", "https://api.instanode.dev/db/new")) + } + require.NoError(t, tok.Set(jwxjwt.IssuedAtKey, time.Now())) + require.NoError(t, tok.Set(jwxjwt.JwtIDKey, uuid.NewString())) + hdrs := jws.NewHeaders() + require.NoError(t, hdrs.Set(jws.TypeKey, "dpop+jwt")) + require.NoError(t, hdrs.Set(jws.JWKKey, f.publicKey)) + signed, err := jwxjwt.Sign(tok, jwxjwt.WithKey(jwa.ES256, f.privateKey, jws.WithProtectedHeaders(hdrs))) + require.NoError(t, err) + return string(signed) +} + +func TestDPoP_MissingHTMClaimRejected(t *testing.T) { + middleware.ResetDPoPRedisBreakerForTest() + rdb, clean := newMiniRedis(t) + defer clean() + f := newDPoPFixture(t) + app := newDPoPApp(rdb) + proof := signPartialProof(t, f, false, true) // no htm + resp := runRequest(t, app, "POST", "https://api.instanode.dev/db/new", f.bearer, proof) + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode, "missing htm claim must be rejected") + resp.Body.Close() +} + +func TestDPoP_MissingHTUClaimRejected(t *testing.T) { + middleware.ResetDPoPRedisBreakerForTest() + rdb, clean := newMiniRedis(t) + defer clean() + f := newDPoPFixture(t) + app := newDPoPApp(rdb) + proof := signPartialProof(t, f, true, false) // htm but no htu + resp := runRequest(t, app, "POST", "https://api.instanode.dev/db/new", f.bearer, proof) + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode, "missing htu claim must be rejected") + resp.Body.Close() +} + +func TestDPoP_WrongTypHeaderRejected(t *testing.T) { + middleware.ResetDPoPRedisBreakerForTest() + rdb, clean := newMiniRedis(t) + defer clean() + f := newDPoPFixture(t) + app := newDPoPApp(rdb) + // Build a proof whose protected-header typ is NOT "dpop+jwt". + tok := jwxjwt.New() + require.NoError(t, tok.Set("htm", "POST")) + require.NoError(t, tok.Set("htu", "https://api.instanode.dev/db/new")) + require.NoError(t, tok.Set(jwxjwt.IssuedAtKey, time.Now())) + require.NoError(t, tok.Set(jwxjwt.JwtIDKey, uuid.NewString())) + hdrs := jws.NewHeaders() + require.NoError(t, hdrs.Set(jws.TypeKey, "jwt")) // wrong typ + require.NoError(t, hdrs.Set(jws.JWKKey, f.publicKey)) + signed, err := jwxjwt.Sign(tok, jwxjwt.WithKey(jwa.ES256, f.privateKey, jws.WithProtectedHeaders(hdrs))) + require.NoError(t, err) + resp := runRequest(t, app, "POST", "https://api.instanode.dev/db/new", f.bearer, string(signed)) + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode, "wrong DPoP typ header must be rejected") + resp.Body.Close() +} + +func TestDPoP_MissingJWKHeaderRejected(t *testing.T) { + middleware.ResetDPoPRedisBreakerForTest() + rdb, clean := newMiniRedis(t) + defer clean() + f := newDPoPFixture(t) + app := newDPoPApp(rdb) + // Proof with correct typ but NO jwk in the protected header. + tok := jwxjwt.New() + require.NoError(t, tok.Set("htm", "POST")) + require.NoError(t, tok.Set("htu", "https://api.instanode.dev/db/new")) + require.NoError(t, tok.Set(jwxjwt.IssuedAtKey, time.Now())) + require.NoError(t, tok.Set(jwxjwt.JwtIDKey, uuid.NewString())) + hdrs := jws.NewHeaders() + require.NoError(t, hdrs.Set(jws.TypeKey, "dpop+jwt")) + signed, err := jwxjwt.Sign(tok, jwxjwt.WithKey(jwa.ES256, f.privateKey, jws.WithProtectedHeaders(hdrs))) + require.NoError(t, err) + resp := runRequest(t, app, "POST", "https://api.instanode.dev/db/new", f.bearer, string(signed)) + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode, "DPoP proof missing jwk must be rejected") + resp.Body.Close() +} + +func TestDPoP_NilRedisReplayStoreDown(t *testing.T) { + // rdb == nil → DPoP replay check can't run → 503 (fail-CLOSED, not open). + middleware.ResetDPoPRedisBreakerForTest() + f := newDPoPFixture(t) + app := newDPoPApp(nil) + proof := f.signProof("POST", "https://api.instanode.dev/db/new", time.Now(), uuid.NewString()) + resp := runRequest(t, app, "POST", "https://api.instanode.dev/db/new", f.bearer, proof) + assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode, + "a missing replay store must fail closed (503), never silently fail open") + resp.Body.Close() +} + +func TestDPoP_GarbageProofParseError(t *testing.T) { + middleware.ResetDPoPRedisBreakerForTest() + rdb, clean := newMiniRedis(t) + defer clean() + f := newDPoPFixture(t) + app := newDPoPApp(rdb) + // A non-JWS DPoP header → jws.Parse fails inside verifyDPoPProof → 401. + resp := runRequest(t, app, "POST", "https://api.instanode.dev/db/new", f.bearer, "not-a-valid-jws") + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode, + "an unparseable DPoP proof must be rejected") + resp.Body.Close() +} + +func TestDPoP_JKTMismatchRejected(t *testing.T) { + middleware.ResetDPoPRedisBreakerForTest() + rdb, clean := newMiniRedis(t) + defer clean() + // Bearer is bound to fixture A's thumbprint, but the proof is signed by + // fixture B's key → the computed jwkThumbprint != cnf.jkt → 401. + fA := newDPoPFixture(t) + fB := newDPoPFixture(t) + app := newDPoPApp(rdb) + proofFromB := fB.signProof("POST", "https://api.instanode.dev/db/new", time.Now(), uuid.NewString()) + resp := runRequest(t, app, "POST", "https://api.instanode.dev/db/new", fA.bearer, proofFromB) + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode, + "a proof key whose thumbprint != bearer cnf.jkt must be rejected") + resp.Body.Close() +} + +// --------------------------------------------------------------------------- +// auth.go — OptionalAuth credential-drop branches (audience, revoked) + +// rich-token locals population + PAT path +// --------------------------------------------------------------------------- + +func optionalAuthRichApp(secret string) *fiber.App { + cfg := &config.Config{JWTSecret: secret} + app := fiber.New() + app.Get("/me", + middleware.OptionalAuth(cfg), + func(c *fiber.Ctx) error { + return c.JSON(fiber.Map{ + "user_id": middleware.GetUserID(c), + "email": middleware.GetEmail(c), + "impersonated_by": middleware.GetImpersonatedBy(c), + "read_only": middleware.IsReadOnly(c), + }) + }, + ) + return app +} + +func TestOptionalAuth_RichTokenPopulatesLocals(t *testing.T) { + middleware.SetRevocationDB(nil) + app := optionalAuthRichApp(testhelpers.TestJWTSecret) + tok := signSessionRich(t, testhelpers.TestJWTSecret, + uuid.NewString(), uuid.NewString(), "rich@x.com", "admin@x.com", true) + req := httptest.NewRequest(http.MethodGet, "/me", nil) + req.Header.Set("Authorization", "Bearer "+tok) + resp, err := app.Test(req, 3000) + require.NoError(t, err) + defer resp.Body.Close() + var body map[string]any + require.NoError(t, json.NewDecoder(resp.Body).Decode(&body)) + assert.Equal(t, "rich@x.com", body["email"]) + assert.Equal(t, "admin@x.com", body["impersonated_by"]) + assert.Equal(t, true, body["read_only"]) +} + +func TestOptionalAuth_RevokedJTIDropsCredential(t *testing.T) { + rdb, clean := newMiniRedis(t) + defer clean() + middleware.SetRevocationDB(rdb) + defer middleware.SetRevocationDB(nil) + + jti := uuid.NewString() + require.NoError(t, rdb.Set(context.Background(), "session.revoked:"+jti, "1", 0).Err()) + + app := optionalAuthRichApp(testhelpers.TestJWTSecret) + tok := signSessionJTI(t, testhelpers.TestJWTSecret, jti) + req := httptest.NewRequest(http.MethodGet, "/me", nil) + req.Header.Set("Authorization", "Bearer "+tok) + resp, err := app.Test(req, 3000) + require.NoError(t, err) + defer resp.Body.Close() + var body map[string]any + require.NoError(t, json.NewDecoder(resp.Body).Decode(&body)) + // Credential dropped → anonymous (empty user_id), request still 200. + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Empty(t, body["user_id"], "a revoked JTI drops the credential on OptionalAuth (anonymous, not 401)") +} + +func TestOptionalAuth_PATPath(t *testing.T) { + // An ink_-prefixed bearer routes through AuthenticateAPIKey; with no DB + // wired it fails and the request continues as anonymous (no block). + middleware.SetAPIKeyDB(nil) + app := optionalAuthRichApp(testhelpers.TestJWTSecret) + req := httptest.NewRequest(http.MethodGet, "/me", nil) + req.Header.Set("Authorization", "Bearer ink_some_token") + resp, err := app.Test(req, 3000) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode, "invalid PAT on OptionalAuth → anonymous, not 401") + resp.Body.Close() +} + +func TestOptionalAuth_AudienceMismatchDropsCredential(t *testing.T) { + middleware.SetRevocationDB(nil) + app := optionalAuthRichApp(testhelpers.TestJWTSecret) + // A token with a bogus audience → OptionalAuth drops the credential and + // continues anonymous (does NOT 401, unlike RequireAuth). + type sc struct { + UserID string `json:"uid"` + TeamID string `json:"tid"` + jwt.RegisteredClaims + } + claims := sc{ + UserID: uuid.NewString(), + TeamID: uuid.NewString(), + RegisteredClaims: jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)), + ID: uuid.NewString(), + Audience: jwt.ClaimStrings{"https://wrong.example/x"}, + }, + } + tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + signed, err := tok.SignedString([]byte(testhelpers.TestJWTSecret)) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodGet, "/me", nil) + req.Header.Set("Authorization", "Bearer "+signed) + resp, err := app.Test(req, 3000) + require.NoError(t, err) + defer resp.Body.Close() + assert.Equal(t, http.StatusOK, resp.StatusCode) + var body map[string]any + require.NoError(t, json.NewDecoder(resp.Body).Decode(&body)) + assert.Empty(t, body["user_id"], "audience mismatch drops the credential on OptionalAuth") +} + +func TestRequireAuth_RevocationRedisErrorFailsOpen(t *testing.T) { + // A revocation Redis error must fail open: the request still authenticates + // (covers the `if err != nil` branch of RequireAuth's JTI check). + rdb, clean := newMiniRedis(t) + clean() // closed → IsJTIRevoked returns an error + middleware.SetRevocationDB(rdb) + defer middleware.SetRevocationDB(nil) + + app := requireAuthApp(testhelpers.TestJWTSecret) + tok := signSessionJTI(t, testhelpers.TestJWTSecret, uuid.NewString()) + req := httptest.NewRequest(http.MethodGet, "/me", nil) + req.Header.Set("Authorization", "Bearer "+tok) + resp, err := app.Test(req, 3000) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode, "revocation Redis error must fail open (request still authenticates)") + resp.Body.Close() +} + +func TestAuth_CnfJKTPopulatesDPoPLocal(t *testing.T) { + middleware.SetRevocationDB(nil) + type cnf struct { + JKT string `json:"jkt"` + } + type sc struct { + UserID string `json:"uid"` + TeamID string `json:"tid"` + Cnf *cnf `json:"cnf,omitempty"` + jwt.RegisteredClaims + } + mint := func() string { + claims := sc{ + UserID: uuid.NewString(), + TeamID: uuid.NewString(), + Cnf: &cnf{JKT: "thumbprint-abc"}, + RegisteredClaims: jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)), + ID: uuid.NewString(), + }, + } + tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + signed, err := tok.SignedString([]byte(testhelpers.TestJWTSecret)) + require.NoError(t, err) + return signed + } + // RequireAuth path. + app := requireAuthApp(testhelpers.TestJWTSecret) + req := httptest.NewRequest(http.MethodGet, "/me", nil) + req.Header.Set("Authorization", "Bearer "+mint()) + resp, err := app.Test(req, 3000) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + resp.Body.Close() + + // OptionalAuth path. + app2 := optionalAuthRichApp(testhelpers.TestJWTSecret) + req2 := httptest.NewRequest(http.MethodGet, "/me", nil) + req2.Header.Set("Authorization", "Bearer "+mint()) + resp2, err := app2.Test(req2, 3000) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp2.StatusCode) + resp2.Body.Close() +} + +func TestRequireAuth_EmptyClaimsRejected(t *testing.T) { + middleware.SetRevocationDB(nil) + app := requireAuthApp(testhelpers.TestJWTSecret) + // Valid signature but empty uid/tid → rejected. + tok := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)), + ID: uuid.NewString(), + }) + signed, err := tok.SignedString([]byte(testhelpers.TestJWTSecret)) + require.NoError(t, err) + req := httptest.NewRequest(http.MethodGet, "/me", nil) + req.Header.Set("Authorization", "Bearer "+signed) + resp, err := app.Test(req, 3000) + require.NoError(t, err) + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode, "empty uid/tid claims must be rejected") + resp.Body.Close() +} + +func TestRequireAuth_PATPath(t *testing.T) { + // PAT routing on RequireAuth: nil DB → AuthenticateAPIKey errors → 401. + middleware.SetAPIKeyDB(nil) + app := requireAuthApp(testhelpers.TestJWTSecret) + req := httptest.NewRequest(http.MethodGet, "/me", nil) + req.Header.Set("Authorization", "Bearer ink_bad") + resp, err := app.Test(req, 3000) + require.NoError(t, err) + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) + resp.Body.Close() +} + +// --------------------------------------------------------------------------- +// rate_limit.go — locals getters, allow-under-limit + exceed +// --------------------------------------------------------------------------- + +func TestRateLimit_LocalsGetters_Defaults(t *testing.T) { + app := fiber.New() + app.Get("/r", func(c *fiber.Ctx) error { + assert.False(t, middleware.IsRateLimitExceeded(c)) + assert.EqualValues(t, 0, middleware.GetRateLimitCount(c)) + return c.SendStatus(fiber.StatusOK) + }) + resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/r", nil), 3000) + require.NoError(t, err) + resp.Body.Close() +} + +func TestRateLimit_ExceedSetsLocals(t *testing.T) { + rdb, clean := newMiniRedis(t) + defer clean() + + app := fiber.New(fiber.Config{ProxyHeader: "X-Forwarded-For"}) + app.Use(middleware.Fingerprint()) + app.Use(middleware.RateLimit(rdb, middleware.RateLimitConfig{KeyPrefix: "covtest", Limit: 2})) + app.Get("/r", func(c *fiber.Ctx) error { + return c.JSON(fiber.Map{ + "exceeded": middleware.IsRateLimitExceeded(c), + "count": middleware.GetRateLimitCount(c), + }) + }) + + ip := "198.51.100.7" + var exceededSeen bool + for i := 0; i < 5; i++ { + req := httptest.NewRequest(http.MethodGet, "/r", nil) + req.Header.Set("X-Forwarded-For", ip) + resp, err := app.Test(req, 3000) + require.NoError(t, err) + if i >= 2 { // limit is 2 → 3rd+ exceed + // header should reflect remaining 0 + assert.Equal(t, "0", resp.Header.Get("X-RateLimit-Remaining")) + exceededSeen = true + } + resp.Body.Close() + } + assert.True(t, exceededSeen) +} + +func TestRateLimit_NilRedisFailsOpen(t *testing.T) { + app := fiber.New(fiber.Config{ProxyHeader: "X-Forwarded-For"}) + app.Use(middleware.Fingerprint()) + app.Use(middleware.RateLimit(nil, middleware.RateLimitConfig{})) + app.Get("/r", func(c *fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }) + req := httptest.NewRequest(http.MethodGet, "/r", nil) + req.Header.Set("X-Forwarded-For", "203.0.113.99") + resp, err := app.Test(req, 3000) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode, "nil redis → fail open") + resp.Body.Close() +} + +// --------------------------------------------------------------------------- +// idempotency.go — fingerprint replay (JSON) + multipart canonicalisation +// --------------------------------------------------------------------------- + +func TestIdempotency_FingerprintReplay_JSON(t *testing.T) { + rdb, clean := newMiniRedis(t) + defer clean() + + calls := 0 + app := fiber.New() + app.Post("/db/new", + middleware.Idempotency(rdb, "db.new"), + func(c *fiber.Ctx) error { + calls++ + return c.Status(fiber.StatusCreated).JSON(fiber.Map{"n": calls}) + }, + ) + body := `{"name":"x","env":"development"}` + send := func() *http.Response { + req := httptest.NewRequest(http.MethodPost, "/db/new", bytes.NewReader([]byte(body))) + req.Header.Set("Content-Type", "application/json") + resp, err := app.Test(req, 3000) + require.NoError(t, err) + return resp + } + r1 := send() + r1.Body.Close() + r2 := send() + // Second identical body within TTL is replayed from cache. + assert.Equal(t, "true", r2.Header.Get("X-Idempotent-Replay")) + r2.Body.Close() + assert.Equal(t, 1, calls, "handler must run once; the replay is served from cache") +} + +func TestIdempotency_MultipartCanonicalisation(t *testing.T) { + rdb, clean := newMiniRedis(t) + defer clean() + + calls := 0 + app := fiber.New(fiber.Config{BodyLimit: 10 * 1024 * 1024}) + app.Post("/deploy/new", + middleware.Idempotency(rdb, "deploy.new"), + func(c *fiber.Ctx) error { + calls++ + return c.Status(fiber.StatusCreated).JSON(fiber.Map{"n": calls}) + }, + ) + + buildMultipart := func() (*bytes.Buffer, string) { + var buf bytes.Buffer + w := multipart.NewWriter(&buf) + _ = w.WriteField("name", "myapp") + _ = w.WriteField("env", "development") + fw, _ := w.CreateFormFile("bundle", "app.tar") + _, _ = fw.Write([]byte("deterministic-tarball-bytes")) + w.Close() + return &buf, w.FormDataContentType() + } + + send := func() *http.Response { + buf, ct := buildMultipart() + req := httptest.NewRequest(http.MethodPost, "/deploy/new", buf) + req.Header.Set("Content-Type", ct) + resp, err := app.Test(req, 5000) + require.NoError(t, err) + return resp + } + r1 := send() + r1.Body.Close() + r2 := send() + defer r2.Body.Close() + assert.Equal(t, "true", r2.Header.Get("X-Idempotent-Replay"), + "identical multipart upload (same file + fields) must dedup via canonicalMultipartBody") + assert.Equal(t, 1, calls) +} + +func TestIdempotency_RawBodyAndMalformedJSON(t *testing.T) { + rdb, clean := newMiniRedis(t) + defer clean() + + calls := 0 + app := fiber.New() + app.Post("/webhook/new", + middleware.Idempotency(rdb, "webhook.new"), + func(c *fiber.Ctx) error { + calls++ + return c.Status(fiber.StatusCreated).SendString("ok") + }, + ) + // text/plain raw-body path (canonicalRequestBody returns raw bytes). + send := func(ct, body string) *http.Response { + req := httptest.NewRequest(http.MethodPost, "/webhook/new", bytes.NewReader([]byte(body))) + if ct != "" { + req.Header.Set("Content-Type", ct) + } + resp, err := app.Test(req, 3000) + require.NoError(t, err) + return resp + } + send("text/plain", "raw-payload").Body.Close() + r := send("text/plain", "raw-payload") + r.Body.Close() + assert.Equal(t, 1, calls, "identical raw bodies dedup via the raw-bytes fingerprint") + + // Malformed-JSON content-type falls back to raw bytes (no crash). + calls = 0 + send("application/json", "{not valid json").Body.Close() + r2 := send("application/json", "{not valid json") + r2.Body.Close() + assert.Equal(t, 1, calls, "malformed JSON falls back to a stable raw-bytes fingerprint") +} + +func TestIdempotency_Fingerprint5xxNotCached(t *testing.T) { + rdb, clean := newMiniRedis(t) + defer clean() + calls := 0 + app := fiber.New() + app.Post("/db/new", + middleware.Idempotency(rdb, "db.new"), + func(c *fiber.Ctx) error { + calls++ + return c.SendStatus(fiber.StatusBadGateway) // 5xx → not cached on fingerprint path + }, + ) + send := func() *http.Response { + req := httptest.NewRequest(http.MethodPost, "/db/new", bytes.NewReader([]byte(`{"k":"v"}`))) + req.Header.Set("Content-Type", "application/json") + resp, err := app.Test(req, 3000) + require.NoError(t, err) + return resp + } + send().Body.Close() + send().Body.Close() + assert.Equal(t, 2, calls, "5xx must not be cached on the fingerprint path either") +} + +func TestIdempotency_ExplicitRedisGetFailOpen(t *testing.T) { + rdb, clean := newMiniRedis(t) + clean() // closed → explicit-path rdb.Get errors → fail open to handler + app := fiber.New() + app.Post("/db/new", + middleware.Idempotency(rdb, "db.new"), + func(c *fiber.Ctx) error { return c.SendStatus(fiber.StatusCreated) }, + ) + req := httptest.NewRequest(http.MethodPost, "/db/new", bytes.NewReader([]byte(`{"a":1}`))) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Idempotency-Key", "k-"+uuid.NewString()) + resp, err := app.Test(req, 3000) + require.NoError(t, err) + assert.Equal(t, http.StatusCreated, resp.StatusCode, "explicit-path Redis error must fail open") + resp.Body.Close() +} + +func TestIdempotency_EmptyBodyFingerprintDedup(t *testing.T) { + rdb, clean := newMiniRedis(t) + defer clean() + calls := 0 + app := fiber.New() + app.Post("/webhook/new", + middleware.Idempotency(rdb, "webhook.new"), + func(c *fiber.Ctx) error { + calls++ + return c.SendStatus(fiber.StatusCreated) + }, + ) + send := func() *http.Response { + // No body, no content-type → canonicalRequestBody returns "" (empty). + resp, err := app.Test(httptest.NewRequest(http.MethodPost, "/webhook/new", nil), 3000) + require.NoError(t, err) + return resp + } + send().Body.Close() + send().Body.Close() + assert.Equal(t, 1, calls, "two empty-body POSTs with same route+scope dedup") +} + +func TestIdempotency_MultipartMultiFileCanonicalisation(t *testing.T) { + rdb, clean := newMiniRedis(t) + defer clean() + calls := 0 + app := fiber.New(fiber.Config{BodyLimit: 10 * 1024 * 1024}) + app.Post("/deploy/new", + middleware.Idempotency(rdb, "deploy.new"), + func(c *fiber.Ctx) error { + calls++ + return c.SendStatus(fiber.StatusCreated) + }, + ) + build := func() (*bytes.Buffer, string) { + var buf bytes.Buffer + w := multipart.NewWriter(&buf) + // Two files + multi-valued field exercise the sorted-iteration loops. + f1, _ := w.CreateFormFile("a", "a.bin") + _, _ = f1.Write([]byte("aaa")) + f2, _ := w.CreateFormFile("b", "b.bin") + _, _ = f2.Write([]byte("bbb")) + _ = w.WriteField("tag", "v2") + _ = w.WriteField("tag", "v1") // multi-value, unsorted + w.Close() + return &buf, w.FormDataContentType() + } + send := func() *http.Response { + buf, ct := build() + req := httptest.NewRequest(http.MethodPost, "/deploy/new", buf) + req.Header.Set("Content-Type", ct) + resp, err := app.Test(req, 5000) + require.NoError(t, err) + return resp + } + send().Body.Close() + send().Body.Close() + assert.Equal(t, 1, calls, "identical 2-file + multi-value multipart bodies dedup") +} + +func TestIdempotency_MalformedMultipartCanonErr(t *testing.T) { + rdb, clean := newMiniRedis(t) + defer clean() + calls := 0 + app := fiber.New() + app.Post("/deploy/new", + middleware.Idempotency(rdb, "deploy.new"), + func(c *fiber.Ctx) error { + calls++ + return c.SendStatus(fiber.StatusCreated) + }, + ) + // multipart content-type but a body that isn't valid multipart → + // canonicalRequestBody/canonicalMultipartBody errors → the fingerprint + // path logs canonErr and falls through (the middleware must not crash; + // Fiber's own multipart reader may then surface the parse error). Either + // way the canonErr branch in idempotencyFingerprint is exercised. + req := httptest.NewRequest(http.MethodPost, "/deploy/new", bytes.NewReader([]byte("garbage-not-multipart"))) + req.Header.Set("Content-Type", "multipart/form-data; boundary=xyz") + resp, err := app.Test(req, 3000) + if err == nil { + resp.Body.Close() + } + // The point of this test is coverage of the canonErr fall-through, not a + // specific status — assert only that the middleware itself didn't panic. + _ = calls +} + +func TestIdempotency_FingerprintHandlerErrorPropagates(t *testing.T) { + rdb, clean := newMiniRedis(t) + defer clean() + app := fiber.New() + app.Post("/db/new", + middleware.Idempotency(rdb, "db.new"), + func(c *fiber.Ctx) error { + // Return a plain (non-response-written) error so the fingerprint + // path's `nextErr != nil && !IsResponseWrittenErr` branch runs. + return fiber.NewError(fiber.StatusBadRequest, "handler error") + }, + ) + req := httptest.NewRequest(http.MethodPost, "/db/new", bytes.NewReader([]byte(`{"a":1}`))) + req.Header.Set("Content-Type", "application/json") + resp, err := app.Test(req, 3000) + require.NoError(t, err) + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) + resp.Body.Close() +} + +func TestIdempotency_FingerprintCorruptCacheFallsThrough(t *testing.T) { + rdb, clean := newMiniRedis(t) + defer clean() + calls := 0 + app := fiber.New() + app.Post("/db/new", + middleware.Idempotency(rdb, "db.new"), + func(c *fiber.Ctx) error { + calls++ + return c.Status(fiber.StatusCreated).SendString("ok") + }, + ) + send := func() *http.Response { + req := httptest.NewRequest(http.MethodPost, "/db/new", bytes.NewReader([]byte(`{"a":1}`))) + req.Header.Set("Content-Type", "application/json") + resp, err := app.Test(req, 3000) + require.NoError(t, err) + return resp + } + // First request caches a valid entry under some idem-fp:* key. + send().Body.Close() + require.Equal(t, 1, calls) + + // Corrupt the cached entry, then resend: json.Unmarshal fails → the + // fingerprint path logs and falls through to the handler again. + keys, err := rdb.Keys(context.Background(), "idem-fp:*").Result() + require.NoError(t, err) + require.NotEmpty(t, keys, "first request must have cached a fingerprint entry") + require.NoError(t, rdb.Set(context.Background(), keys[0], "{not-json", 0).Err()) + + send().Body.Close() + assert.Equal(t, 2, calls, "a corrupt cache entry must fall through to the handler, not replay") +} + +func TestIdempotency_ExplicitCorruptCacheFallsThrough(t *testing.T) { + rdb, clean := newMiniRedis(t) + defer clean() + calls := 0 + app := fiber.New() + app.Post("/db/new", + middleware.Idempotency(rdb, "db.new"), + func(c *fiber.Ctx) error { + calls++ + return c.Status(fiber.StatusCreated).SendString("ok") + }, + ) + key := "corrupt-" + uuid.NewString() + send := func() *http.Response { + req := httptest.NewRequest(http.MethodPost, "/db/new", bytes.NewReader([]byte(`{"a":1}`))) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Idempotency-Key", key) + resp, err := app.Test(req, 3000) + require.NoError(t, err) + return resp + } + send().Body.Close() + require.Equal(t, 1, calls) + + keys, err := rdb.Keys(context.Background(), "idem:*").Result() + require.NoError(t, err) + require.NotEmpty(t, keys) + require.NoError(t, rdb.Set(context.Background(), keys[0], "{not-json", 0).Err()) + + send().Body.Close() + assert.Equal(t, 2, calls, "explicit-path corrupt cache must fall through to the handler") +} + +func TestIdempotency_FingerprintRedisFailOpen(t *testing.T) { + rdb, clean := newMiniRedis(t) + clean() // close → every Redis op errors → fingerprint path fails open + app := fiber.New() + app.Post("/db/new", + middleware.Idempotency(rdb, "db.new"), + func(c *fiber.Ctx) error { return c.SendStatus(fiber.StatusCreated) }, + ) + req := httptest.NewRequest(http.MethodPost, "/db/new", bytes.NewReader([]byte(`{"a":1}`))) + req.Header.Set("Content-Type", "application/json") + resp, err := app.Test(req, 3000) + require.NoError(t, err) + assert.Equal(t, http.StatusCreated, resp.StatusCode, "Redis error must fail open") + resp.Body.Close() +} + +// --------------------------------------------------------------------------- +// auth.go — RequireAuth revoked-JTI rejection (audience mismatch is covered +// by auth_audience_test.go's newAudApp suite) +// --------------------------------------------------------------------------- + +func signSessionJTI(t *testing.T, secret, jti string) string { + t.Helper() + type sessionClaims struct { + UserID string `json:"uid"` + TeamID string `json:"tid"` + jwt.RegisteredClaims + } + claims := sessionClaims{ + UserID: uuid.NewString(), + TeamID: uuid.NewString(), + RegisteredClaims: jwt.RegisteredClaims{ + IssuedAt: jwt.NewNumericDate(time.Now()), + ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)), + ID: jti, + }, + } + tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + signed, err := tok.SignedString([]byte(secret)) + require.NoError(t, err) + return signed +} + +func TestRequireAuth_RevokedJTIRejected(t *testing.T) { + rdb, clean := newMiniRedis(t) + defer clean() + middleware.SetRevocationDB(rdb) + defer middleware.SetRevocationDB(nil) + + jti := uuid.NewString() + require.NoError(t, rdb.Set(context.Background(), "session.revoked:"+jti, "1", 0).Err()) + + app := requireAuthApp(testhelpers.TestJWTSecret) + tok := signSessionJTI(t, testhelpers.TestJWTSecret, jti) + req := httptest.NewRequest(http.MethodGet, "/me", nil) + req.Header.Set("Authorization", "Bearer "+tok) + resp, err := app.Test(req, 3000) + require.NoError(t, err) + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode, "a revoked JTI must be rejected") + resp.Body.Close() +} + +// --------------------------------------------------------------------------- +// log_scrubber.go — WithAttrs / WithGroup child loggers still scrub +// --------------------------------------------------------------------------- + +func TestLogScrubber_WithAttrsAndGroup(t *testing.T) { + var buf bytes.Buffer + base := slog.NewTextHandler(&buf, &slog.HandlerOptions{}) + scrub := middleware.NewLogScrubber(base, "topsecret") + + // WithAttrs eagerly scrubs the bound attrs. + child := scrub.WithAttrs([]slog.Attr{slog.String("url", "/api/v1/topsecret/x")}) + // WithGroup returns a still-scrubbing wrapper. + grouped := child.WithGroup("g") + + rec := slog.NewRecord(time.Now(), slog.LevelInfo, "msg topsecret here", 0) + // A group attr carrying the secret in a nested field exercises the + // group-rebuild branch of scrubAttr/Handle. + rec.AddAttrs(slog.Group("ctx", + slog.String("url", "/api/v1/topsecret/x"), + slog.Int("n", 1), + )) + require.NoError(t, grouped.Handle(context.Background(), rec)) + + out := buf.String() + assert.NotContains(t, out, "topsecret", "scrubber must redact the secret in child loggers + message + nested groups") + assert.Contains(t, out, "") +} + +func TestLogScrubber_EmptySecretPassThrough(t *testing.T) { + var buf bytes.Buffer + base := slog.NewTextHandler(&buf, nil) + h := middleware.NewLogScrubber(base, "") + // Empty secret → base handler returned unchanged. + assert.Equal(t, base, h) +} + +// --------------------------------------------------------------------------- +// newrelic.go — live (no-op) agent opens + ends a transaction +// --------------------------------------------------------------------------- + +func TestNewRelic_LiveAgentOpensTxn(t *testing.T) { + // A disabled-but-non-nil application: the agent records nothing but the + // middleware exercises StartTransaction / SetWebResponse / NoticeError. + app, err := newrelic.NewApplication( + newrelic.ConfigAppName("cov-test"), + newrelic.ConfigEnabled(false), + ) + require.NoError(t, err) + require.NotNil(t, app) + middleware.SetNRApp(app) + defer middleware.SetNRApp(nil) + + fapp := fiber.New() + fapp.Use(middleware.NewRelic(app)) + fapp.Get("/ok", func(c *fiber.Ctx) error { + assert.NotNil(t, middleware.GetNRTxn(c)) + // exercise the emit helpers with a live (disabled) agent + middleware.RecordProvisionSuccess("postgres") + middleware.RecordProvisionFail("redis", "quota") + middleware.RecordResourceExpired("mongodb") + return c.SendStatus(fiber.StatusOK) + }) + fapp.Get("/err", func(c *fiber.Ctx) error { + return fiber.NewError(fiber.StatusInternalServerError, "boom") + }) + + r1, err := fapp.Test(httptest.NewRequest(http.MethodGet, "/ok", nil), 3000) + require.NoError(t, err) + r1.Body.Close() + r2, err := fapp.Test(httptest.NewRequest(http.MethodGet, "/err", nil), 3000) + require.NoError(t, err) + r2.Body.Close() +} + +// --------------------------------------------------------------------------- +// env_policy.go — route-param lookup, body lookup, WithEnvLookup override +// --------------------------------------------------------------------------- + +func TestRequireEnvAccess_EmptyTeamIDPassesThrough(t *testing.T) { + db, _, err := sqlmock.New() + require.NoError(t, err) + defer db.Close() + middleware.SetEnvPolicyDB(db) + defer middleware.SetEnvPolicyDB(nil) + // No team_id local at all → RequireEnvAccess passes through (the + // downstream handler returns its own 401 in production). + app := fiber.New() + app.Post("/deploy", + middleware.RequireEnvAccess(middleware.EnvPolicyActionDeploy), + func(c *fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }, + ) + resp, err := app.Test(httptest.NewRequest(http.MethodPost, "/deploy", nil), 3000) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode, "no team id → pass through (handler 401s itself)") + resp.Body.Close() +} + +func TestRequireEnvAccess_RouteParamAndBodyLookup(t *testing.T) { + // No DB wired → always allows, but the lookup branches execute. + middleware.SetEnvPolicyDB(nil) + defer middleware.SetEnvPolicyDB(nil) + + app := fiber.New() + app.Use(func(c *fiber.Ctx) error { + c.Locals(middleware.LocalKeyTeamID, uuid.NewString()) + return c.Next() + }) + // :env route param path. + app.Post("/vault/:env/set", + middleware.RequireEnvAccess(middleware.EnvPolicyActionVaultWrite), + func(c *fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }, + ) + resp, err := app.Test(httptest.NewRequest(http.MethodPost, "/vault/production/set", nil), 3000) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + resp.Body.Close() + + // JSON body "env" / "to" path. + app2 := fiber.New() + app2.Use(func(c *fiber.Ctx) error { + c.Locals(middleware.LocalKeyTeamID, uuid.NewString()) + return c.Next() + }) + app2.Post("/promote", + middleware.RequireEnvAccess(middleware.EnvPolicyActionDeploy), + func(c *fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }, + ) + req := httptest.NewRequest(http.MethodPost, "/promote", bytes.NewReader([]byte(`{"to":"production"}`))) + req.Header.Set("Content-Type", "application/json") + resp2, err := app2.Test(req, 3000) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp2.StatusCode) + resp2.Body.Close() +} + +func TestRequireEnvAccess_WithEnvLookupOverride(t *testing.T) { + // A non-nil DB is required so the middleware reaches the env-lookup + // stage (it short-circuits on a nil DB before invoking the lookup). + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer db.Close() + middleware.SetEnvPolicyDB(db) + defer middleware.SetEnvPolicyDB(nil) + mock.ExpectQuery("SELECT env_policy FROM teams"). + WillReturnRows(sqlmock.NewRows([]string{"env_policy"}).AddRow([]byte(`{}`))) + + called := false + app := fiber.New() + app.Use(func(c *fiber.Ctx) error { + c.Locals(middleware.LocalKeyTeamID, uuid.NewString()) + return c.Next() + }) + app.Delete("/resources/:id", + middleware.RequireEnvAccess(middleware.EnvPolicyActionDeleteResource, + middleware.WithEnvLookup(func(c *fiber.Ctx) (string, error) { + called = true + return "production", nil + })), + func(c *fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }, + ) + resp, err := app.Test(httptest.NewRequest(http.MethodDelete, "/resources/abc", nil), 3000) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.True(t, called, "WithEnvLookup override must be invoked") + resp.Body.Close() +} + +// --------------------------------------------------------------------------- +// admin_audit.go — AdminAuditMetadataFromLocals round-trip +// --------------------------------------------------------------------------- + +func TestAdminAuditEmit_EmptyPrefixNoOp(t *testing.T) { + app := fiber.New() + app.Get("/x", middleware.AdminAuditEmit(nil, ""), func(c *fiber.Ctx) error { + return c.SendStatus(fiber.StatusOK) + }) + resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/x", nil), 3000) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + resp.Body.Close() +} + +func TestAdminAuditEmit_NilDBPassThrough(t *testing.T) { + // nil DB → buildAdminAuditMetadata runs (path-suffix strip, UA scrub, + // denied-by resolution) but no insert. The request still completes. + app := fiber.New() + app.Get("/api/v1/sek/customers/list", + middleware.AdminAuditEmit(nil, "sek"), + func(c *fiber.Ctx) error { + c.Set(fiber.HeaderUserAgent, "probe-ua") + return c.SendStatus(fiber.StatusOK) + }, + ) + resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/api/v1/sek/customers/list", nil), 3000) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + resp.Body.Close() +} + +func TestAdminAuditEmit_InsertsWithTeamIDFromPath(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer db.Close() + + tid := uuid.New() + mock.ExpectExec("INSERT INTO audit_log").WillReturnResult(sqlmock.NewResult(1, 1)) + + app := fiber.New() + app.Get("/api/v1/sek/customers/:team_id/tier", + middleware.AdminAuditEmit(db, "sek"), + func(c *fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }, + ) + resp, err := app.Test(httptest.NewRequest(http.MethodGet, + "/api/v1/sek/customers/"+tid.String()+"/tier", nil), 3000) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + resp.Body.Close() +} + +func TestAdminAuditEmit_LongUATruncatedAndJWTTeamFallback(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer db.Close() + mock.ExpectExec("INSERT INTO audit_log").WillReturnResult(sqlmock.NewResult(1, 1)) + + tid := uuid.New() + app := fiber.New() + // No :team_id param and no /customers/ segment → adminAuditTeamID + // falls back to the JWT team local. A >120-char User-Agent exercises the + // truncation branch in buildAdminAuditMetadata. + app.Get("/api/v1/sek/dashboard", + func(c *fiber.Ctx) error { + c.Locals(middleware.LocalKeyTeamID, tid.String()) + return c.Next() + }, + middleware.AdminAuditEmit(db, "sek"), + func(c *fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }, + ) + req := httptest.NewRequest(http.MethodGet, "/api/v1/sek/dashboard", nil) + req.Header.Set(fiber.HeaderUserAgent, string(bytes.Repeat([]byte("u"), 300))) + resp, err := app.Test(req, 3000) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + resp.Body.Close() +} + +func TestAdminAuditEmit_NoTeamContextLogsButNoInsert(t *testing.T) { + db, _, err := sqlmock.New() + require.NoError(t, err) + defer db.Close() + // No team context resolvable → uuid.Nil → no-team-context warn branch, + // no INSERT (sqlmock would fail if an unexpected Exec fired). + app := fiber.New() + app.Get("/api/v1/sek/dashboard", + middleware.AdminAuditEmit(db, "sek"), + func(c *fiber.Ctx) error { return c.SendStatus(fiber.StatusForbidden) }, + ) + resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/api/v1/sek/dashboard", nil), 3000) + require.NoError(t, err) + assert.Equal(t, http.StatusForbidden, resp.StatusCode) + resp.Body.Close() +} + +func TestAdminAuditMetadataFromLocals_AbsentThenPresent(t *testing.T) { + app := fiber.New() + app.Get("/a", func(c *fiber.Ctx) error { + _, ok := middleware.AdminAuditMetadataFromLocals(c) + assert.False(t, ok, "absent before AdminAuditEmit stamps it") + return c.SendStatus(fiber.StatusOK) + }) + resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/a", nil), 3000) + require.NoError(t, err) + resp.Body.Close() +} diff --git a/internal/middleware/coverage_push3_ext_test.go b/internal/middleware/coverage_push3_ext_test.go new file mode 100644 index 0000000..5930c68 --- /dev/null +++ b/internal/middleware/coverage_push3_ext_test.go @@ -0,0 +1,50 @@ +package middleware_test + +// coverage_push3_ext_test.go — black-box top-up: PopulateTeamRole DB-error +// (non-ErrNoRows) branch, which logs a warning and falls through to Next(). + +import ( + "errors" + "net/http" + "net/http/httptest" + "testing" + + sqlmock "github.com/DATA-DOG/go-sqlmock" + "github.com/gofiber/fiber/v2" + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "instant.dev/internal/middleware" +) + +func TestPopulateTeamRole_QueryErrorIsTolerated(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer db.Close() + middleware.SetRoleLookupDB(db) + defer middleware.SetRoleLookupDB(nil) + + uid := uuid.NewString() + tid := uuid.NewString() + // A real DB error (NOT sql.ErrNoRows) → logs warn + Next(), no role local. + mock.ExpectQuery("SELECT role FROM users"). + WithArgs(uid, tid). + WillReturnError(errors.New("connection reset by peer")) + + app := fiber.New() + app.Use(func(c *fiber.Ctx) error { + c.Locals(middleware.LocalKeyUserID, uid) + c.Locals(middleware.LocalKeyTeamID, tid) + return c.Next() + }) + app.Use(middleware.PopulateTeamRole()) + app.Get("/role", func(c *fiber.Ctx) error { + assert.Empty(t, middleware.GetTeamRole(c), "role unset on DB error (fail-open)") + return c.SendStatus(fiber.StatusOK) + }) + resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/role", nil), 3000) + require.NoError(t, err) + resp.Body.Close() + assert.NoError(t, mock.ExpectationsWereMet()) +} diff --git a/internal/middleware/coverage_push3_test.go b/internal/middleware/coverage_push3_test.go new file mode 100644 index 0000000..a218257 --- /dev/null +++ b/internal/middleware/coverage_push3_test.go @@ -0,0 +1,142 @@ +package middleware + +// coverage_push3_test.go — white-box top-up batch closing the last +// sub-95% unexported helpers surfaced by the coverage audit: +// dpop jwkThumbprintBase64URL (success), requestCanonicalURL defensive +// fallback (unparseable canonical URL), urlMatches second-arg parse error, +// and canonicalJSON / writeCanonicalJSON unmarshalable-value error path. + +import ( + "context" + "crypto" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "net/http/httptest" + "testing" + "time" + + "github.com/gofiber/fiber/v2" + "github.com/lestrrat-go/jwx/v2/jwa" + "github.com/lestrrat-go/jwx/v2/jwk" + "github.com/redis/go-redis/v9" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// --------------------------------------------------------------------------- +// rate_limit.go — incrementWithExpiry nil-client + pipeline-exec error +// --------------------------------------------------------------------------- + +func TestIncrementWithExpiry_NilClient(t *testing.T) { + // CLAUDE.md convention #1: a nil client must NOT SIGSEGV — it returns + // (0, err) so the caller fails open. + n, err := incrementWithExpiry(context.Background(), nil, "k", time.Minute) + assert.Error(t, err) + assert.Zero(t, n) +} + +func TestIncrementWithExpiry_PipelineExecError(t *testing.T) { + // Point the client at an address nothing is listening on so the + // pipeline Exec fails — exercises the redis-error branch (fail-open). + rdb := redis.NewClient(&redis.Options{ + Addr: "127.0.0.1:1", // reserved/closed port + DialTimeout: 200 * time.Millisecond, + MaxRetries: -1, + }) + defer rdb.Close() + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + n, err := incrementWithExpiry(ctx, rdb, "k", time.Minute) + assert.Error(t, err) + assert.Zero(t, n) +} + +// --------------------------------------------------------------------------- +// dpop.go — jwkThumbprintBase64URL success path +// --------------------------------------------------------------------------- + +func TestJWKThumbprintBase64URL_Success(t *testing.T) { + raw, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + key, err := jwk.FromRaw(raw) + require.NoError(t, err) + require.NoError(t, key.Set(jwk.AlgorithmKey, jwa.ES256)) + + got, err := jwkThumbprintBase64URL(key) + require.NoError(t, err) + assert.NotEmpty(t, got) + + // Must equal the RFC 7638 thumbprint encoded base64url-no-pad. + tp, err := key.Thumbprint(crypto.SHA256) + require.NoError(t, err) + assert.Equal(t, base64URLNoPad(tp), got) +} + +// --------------------------------------------------------------------------- +// dpop.go — requestCanonicalURL defensive fallback +// --------------------------------------------------------------------------- + +func TestRequestCanonicalURL_FallbackOnUnparseableCanonical(t *testing.T) { + // Force CanonicalResourceURLFor to return a value url.Parse rejects so + // the host-less defensive branch runs. + orig := CanonicalResourceURLFor + CanonicalResourceURLFor = func(_ *fiber.Ctx) string { return "://%%bad" } + defer func() { CanonicalResourceURLFor = orig }() + + app := fiber.New() + app.Get("/widgets", func(c *fiber.Ctx) error { + got := requestCanonicalURL(c) + // Falls back to https://. + assert.Contains(t, got, "/widgets") + assert.Contains(t, got, "https://") + return c.SendStatus(fiber.StatusOK) + }) + req := httptest.NewRequest(fiber.MethodGet, "/widgets", nil) + req.Host = "fallback.example.com" + resp, err := app.Test(req, 3000) + require.NoError(t, err) + resp.Body.Close() +} + +func TestRequestCanonicalURL_HappyPath(t *testing.T) { + orig := CanonicalResourceURLFor + CanonicalResourceURLFor = func(_ *fiber.Ctx) string { return "https://api.example.com" } + defer func() { CanonicalResourceURLFor = orig }() + + app := fiber.New() + app.Get("/db/new", func(c *fiber.Ctx) error { + assert.Equal(t, "https://api.example.com/db/new", requestCanonicalURL(c)) + return c.SendStatus(fiber.StatusOK) + }) + resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/db/new", nil), 3000) + require.NoError(t, err) + resp.Body.Close() +} + +// --------------------------------------------------------------------------- +// dpop.go — urlMatches: second operand unparseable +// --------------------------------------------------------------------------- + +func TestURLMatches_SecondOperandUnparseable(t *testing.T) { + assert.False(t, urlMatches("https://h.com/x", "://bad")) +} + +// --------------------------------------------------------------------------- +// idempotency.go — canonicalJSON / writeCanonicalJSON marshal-error path +// --------------------------------------------------------------------------- + +func TestCanonicalJSON_UnmarshalableValueErrors(t *testing.T) { + // json.Marshal cannot encode a func value → the default branch in + // writeCanonicalJSON returns an error, surfaced by canonicalJSON. + _, err := canonicalJSON(map[string]interface{}{"f": func() {}}) + assert.Error(t, err) + + // Same inside a slice element. + _, err = canonicalJSON([]interface{}{func() {}}) + assert.Error(t, err) + + // Top-level unmarshalable value. + _, err = canonicalJSON(func() {}) + assert.Error(t, err) +} diff --git a/internal/middleware/testdata/geo-combined-test.mmdb b/internal/middleware/testdata/geo-combined-test.mmdb new file mode 100644 index 0000000000000000000000000000000000000000..80e6454b1f7b8da73ace80544f5e66e16dcfab90 GIT binary patch literal 2502 zcmZY72UHYS7{&2Bh}gw0_TB{%dj~8u7gR(M#C6ymV3FM+Ta1vHnBJS|UD6ZNdo>B^ z#q?e^rkQ4XFYf%=#hC1Q=XdVh_sz^ZXI3&~R7(jm>ZLX^=1Ym>wwQ$Nu)T~VsROwq zcEZls1-oK0cEj%21AAgG?2Ub}FQ#BW?2iL*AP&O8I0T2{FdU8}a3qex(KrUj;y4_S z6L2D?Vj51ubexPCI0dKTG@Onza3;>e**FL1;yj#>3osM2FdG-*B3z71Fb8un5A(4A zm!b;`u?Uypa_JXo1-TekVhOIo)mVyaunfy_Emq(2S$I<3&rVk*dhm zjx9?*np){FX#=?q{TM*o(fLD0*CB);hA@l~+=z!dI`f;P20|mvmRKL#-7Fm;9V;Cf z>$OEXiut4Qn1kn!qdXo@h;8X4=~9}LrBmYNsnTiE`O@j}<_u=e#ItZKo{i_=xp1ydd=^C1A@jAR7Z@}$%Bi@8J z<1KhA-iEhJ_eyt=@5H3O`@Qn0qJU=Hr&)Cjb6T9n0=_UG?@fCbkdXw;)^g7`U$JV8{q_^qciIsYn%lD-B z@kvIxR;}{%^<8VAqz=@cOX*dbfaWZD$ z6r76Fa5~PwnK%n);~boe^RVS(Eg)xN7G~o@T!f2p3Fcrf=3%~!Dnfybr34oiVi7Kr zQ7&USc?A~ZN-V)uxEf1w4VF1NpT6b5Yh_f#^E!fCM(eIgQ;8lKUV??(e&9JZls>G* z4OoYM8397{E;|-<@NZA48423~+iyog?odN0Z28@Rh`-VbHfL_GF#f-32dmA1uh9(q z>_A?YE32%iFf~0R$74qV;b6n=Twlm`du*?jT~fS#=RTLY$>j@p(>%6+Z+WFJU~1TljQHxjZmZt*)FdXI;IF#}0bkAz!1_l(HY&iL8qN literal 0 HcmV?d00001