Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 24 additions & 9 deletions internal/handlers/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,21 @@ var allowedReturnOriginsDev = []string{
"http://localhost:3000",
}

// OAuth provider endpoint base URLs. Declared as package vars (not consts)
// solely so the test suite can repoint them at an httptest server — production
// never mutates them. Each is the exact URL the corresponding helper used to
// hardcode inline; behaviour is unchanged in prod.
var (
githubTokenURL = "https://github.com/login/oauth/access_token"
githubUserURL = "https://api.github.com/user"
githubUserEmailURL = "https://api.github.com/user/emails"
githubAuthorizeURL = "https://github.com/login/oauth/authorize"
googleTokenInfoURL = "https://oauth2.googleapis.com/tokeninfo"
googleTokenURL = "https://oauth2.googleapis.com/token"
googleUserInfoURL = "https://www.googleapis.com/oauth2/v2/userinfo"
googleAuthorizeURL = "https://accounts.google.com/o/oauth2/v2/auth"
)

// returnToAllowsLocalhost controls whether validateReturnTo treats
// http://localhost:5173 and http://localhost:3000 as allowed return-to
// origins. Set to true in development at startup, false in production.
Expand Down Expand Up @@ -409,7 +424,7 @@ func (h *AuthHandler) GoogleAuthURL(c *fiber.Ctx) error {
// url.Parse of a compile-time-constant string never errors — the err
// branch was dead code. GoogleStart handles the identical parse the same
// way (u, _ := url.Parse(...)).
u, _ := url.Parse("https://accounts.google.com/o/oauth2/v2/auth")
u, _ := url.Parse(googleAuthorizeURL)
q := u.Query()
q.Set("client_id", h.cfg.GoogleClientID)
q.Set("redirect_uri", redirectURI)
Expand Down Expand Up @@ -479,7 +494,7 @@ func exchangeGitHubCode(ctx context.Context, clientID, clientSecret, code string
"client_secret": {clientSecret},
"code": {code},
}
req, _ := http.NewRequestWithContext(ctx, "POST", "https://github.com/login/oauth/access_token", strings.NewReader(form.Encode()))
req, _ := http.NewRequestWithContext(ctx, "POST", githubTokenURL, strings.NewReader(form.Encode()))
req.Header.Set("Accept", "application/json")
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")

Expand All @@ -502,7 +517,7 @@ func exchangeGitHubCode(ctx context.Context, clientID, clientSecret, code string
}

// Step 2: fetch user profile
userReq, _ := http.NewRequestWithContext(ctx, "GET", "https://api.github.com/user", nil)
userReq, _ := http.NewRequestWithContext(ctx, "GET", githubUserURL, nil)
userReq.Header.Set("Authorization", "Bearer "+tokenResp.AccessToken)
userReq.Header.Set("Accept", "application/vnd.github+json")

Expand All @@ -523,7 +538,7 @@ func exchangeGitHubCode(ctx context.Context, clientID, clientSecret, code string

if profile.Email == "" {
// Fetch primary email separately
emailReq, _ := http.NewRequestWithContext(ctx, "GET", "https://api.github.com/user/emails", nil)
emailReq, _ := http.NewRequestWithContext(ctx, "GET", githubUserEmailURL, nil)
emailReq.Header.Set("Authorization", "Bearer "+tokenResp.AccessToken)
emailResp, err := client.Do(emailReq)
if err == nil {
Expand Down Expand Up @@ -646,7 +661,7 @@ type googleUser struct {
}

func verifyGoogleIDToken(ctx context.Context, clientID, idToken string) (*googleUser, error) {
verifyURL := fmt.Sprintf("https://oauth2.googleapis.com/tokeninfo?id_token=%s", url.QueryEscape(idToken))
verifyURL := fmt.Sprintf("%s?id_token=%s", googleTokenInfoURL, url.QueryEscape(idToken))
req, _ := http.NewRequestWithContext(ctx, "GET", verifyURL, nil)

client := &http.Client{Timeout: 10 * time.Second}
Expand Down Expand Up @@ -692,7 +707,7 @@ func exchangeGoogleAuthorizationCode(ctx context.Context, clientID, clientSecret
"redirect_uri": {redirectURI},
"grant_type": {"authorization_code"},
}
req, err := http.NewRequestWithContext(ctx, "POST", "https://oauth2.googleapis.com/token", strings.NewReader(form.Encode()))
req, err := http.NewRequestWithContext(ctx, "POST", googleTokenURL, strings.NewReader(form.Encode()))
if err != nil {
return "", err
}
Expand Down Expand Up @@ -723,7 +738,7 @@ func exchangeGoogleAuthorizationCode(ctx context.Context, clientID, clientSecret
}

func fetchGoogleUserInfoOAuth2V2(ctx context.Context, accessToken string) (*googleUser, error) {
req, err := http.NewRequestWithContext(ctx, "GET", "https://www.googleapis.com/oauth2/v2/userinfo", nil)
req, err := http.NewRequestWithContext(ctx, "GET", googleUserInfoURL, nil)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -964,7 +979,7 @@ func (h *AuthHandler) GitHubStart(c *fiber.Ctx) error {
h.registerOAuthState(c.Context(), state)

authURL := fmt.Sprintf(
"https://github.com/login/oauth/authorize?client_id=%s&redirect_uri=%s&state=%s&scope=%s",
githubAuthorizeURL+"?client_id=%s&redirect_uri=%s&state=%s&scope=%s",
url.QueryEscape(h.cfg.GitHubClientID),
url.QueryEscape(canonicalAPIBase+"/auth/github/callback"),
url.QueryEscape(state),
Expand Down Expand Up @@ -1050,7 +1065,7 @@ func (h *AuthHandler) GoogleStart(c *fiber.Ctx) error {
// P1-K: record the state in Redis so the callback can consume it once.
h.registerOAuthState(c.Context(), state)

u, _ := url.Parse("https://accounts.google.com/o/oauth2/v2/auth")
u, _ := url.Parse(googleAuthorizeURL)
q := u.Query()
q.Set("client_id", h.cfg.GoogleClientID)
q.Set("redirect_uri", canonicalAPIBase+"/auth/google/callback")
Expand Down
296 changes: 296 additions & 0 deletions internal/handlers/auth_branches_coverage_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,296 @@
package handlers_test

// auth_branches_coverage_test.go — fills the remaining auth.go / cli_auth.go /
// magic_link.go branches the grading regex (TestAuth* / TestCLI* / TestMagicLink*)
// otherwise misses:
// * registerOAuthState + consumeOAuthState single-use (Redis-backed callback
// replay rejection) — drives the GitHubCallback / GoogleCallbackBrowser
// "Sign-in already used" branch.
// * GetCurrentUser DB branches: happy, admin, impersonation, team-not-found,
// user-not-found, bad team/user ids.
// * magic-link Start with a failing mailer (persist-failed + email-send-failed
// log branch).
// * markEmailVerified via a pre-verified GitHub account (already-verified no-op).

import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"

"github.com/gofiber/fiber/v2"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"instant.dev/internal/config"
"instant.dev/internal/handlers"
"instant.dev/internal/middleware"
"instant.dev/internal/plans"
"instant.dev/internal/testhelpers"
)

// --- OAuth state single-use (Redis) ---

// TestAuth_GitHubCallback_SingleUseReplayRejected wires a real Redis client so
// registerOAuthState (on Start) + consumeOAuthState (on the first Callback) run
// the success path; a second Callback with the same state finds the key gone
// and is rejected with "Sign-in already used".
func TestAuth_GitHubCallback_SingleUseReplayRejected(t *testing.T) {
db, clean := testhelpers.SetupTestDB(t)
defer clean()
settleAuditDB(t, db)
rdb, cleanRedis := testhelpers.SetupTestRedis(t)
defer cleanRedis()

startFakeOAuth(t, &fakeOAuthServer{ghID: uniqueGHID(), ghEmail: testhelpers.UniqueEmail(t)})

h := handlers.NewAuthHandler(db, oauthCfg())
h.SetRedis(rdb)
app := buildAuthApp(h)

startResp := getReq(t, app, "/auth/github/start?return_to=https://instanode.dev/x")
cookie := firstCookie(startResp.Header.Get("Set-Cookie"))
state := extractQueryParam(startResp.Header.Get("Location"), "state")
startResp.Body.Close()
require.NotEmpty(t, state)

// First callback consumes the state → 302 success.
req := httptest.NewRequest(http.MethodGet, "/auth/github/callback?code=c&state="+state, nil)
req.Header.Set("Cookie", cookie)
resp, err := app.Test(req, 5000)
require.NoError(t, err)
assert.Equal(t, http.StatusFound, resp.StatusCode)
resp.Body.Close()

// Replay with the same state + cookie → consumeOAuthState returns false →
// "Sign-in already used" 400.
req2 := httptest.NewRequest(http.MethodGet, "/auth/github/callback?code=c&state="+state, nil)
req2.Header.Set("Cookie", cookie)
resp2, err := app.Test(req2, 5000)
require.NoError(t, err)
defer resp2.Body.Close()
assert.Equal(t, http.StatusBadRequest, resp2.StatusCode)
}

// Same single-use replay path for the Google browser callback.
func TestAuth_GoogleCallbackBrowser_SingleUseReplayRejected(t *testing.T) {
db, clean := testhelpers.SetupTestDB(t)
defer clean()
settleAuditDB(t, db)
rdb, cleanRedis := testhelpers.SetupTestRedis(t)
defer cleanRedis()

startFakeOAuth(t, &fakeOAuthServer{gSub: uniqueGHID(), gEmail: testhelpers.UniqueEmail(t)})

h := handlers.NewAuthHandler(db, oauthCfg())
h.SetRedis(rdb)
app := buildAuthApp(h)

startResp := getReq(t, app, "/auth/google/start?return_to=https://instanode.dev/x")
cookie := firstCookie(startResp.Header.Get("Set-Cookie"))
state := extractQueryParam(startResp.Header.Get("Location"), "state")
startResp.Body.Close()
require.NotEmpty(t, state)

req := httptest.NewRequest(http.MethodGet, "/auth/google/callback/browser?code=c&state="+state, nil)
req.Header.Set("Cookie", cookie)
resp, err := app.Test(req, 5000)
require.NoError(t, err)
assert.Equal(t, http.StatusFound, resp.StatusCode)
resp.Body.Close()

req2 := httptest.NewRequest(http.MethodGet, "/auth/google/callback/browser?code=c&state="+state, nil)
req2.Header.Set("Cookie", cookie)
resp2, err := app.Test(req2, 5000)
require.NoError(t, err)
defer resp2.Body.Close()
assert.Equal(t, http.StatusBadRequest, resp2.StatusCode)
}

// TestAuth_GitHubCallback_StateConsumeFailClosed wires a Redis client pointed
// at a dead address: registerOAuthState (Start) logs+swallows the SET error,
// and consumeOAuthState (Callback) hits the GetDel-error fail-CLOSED branch
// → "Sign-in already used" 400 even though the cookie state matches. Covers
// the T10 P1-3 fail-closed path.
func TestAuth_GitHubCallback_StateConsumeFailClosed(t *testing.T) {
bad := redis.NewClient(&redis.Options{
Addr: "127.0.0.1:1",
DialTimeout: 50 * time.Millisecond,
ReadTimeout: 50 * time.Millisecond,
WriteTimeout: 50 * time.Millisecond,
MaxRetries: -1,
})
defer bad.Close()

startFakeOAuth(t, &fakeOAuthServer{ghID: uniqueGHID(), ghEmail: testhelpers.UniqueEmail(t)})
h := handlers.NewAuthHandler(nil, oauthCfg())
h.SetRedis(bad)
app := buildAuthApp(h)

startResp := getReq(t, app, "/auth/github/start?return_to=https://instanode.dev/x")
cookie := firstCookie(startResp.Header.Get("Set-Cookie"))
state := extractQueryParam(startResp.Header.Get("Location"), "state")
startResp.Body.Close()
require.NotEmpty(t, state)

req := httptest.NewRequest(http.MethodGet, "/auth/github/callback?code=c&state="+state, nil)
req.Header.Set("Cookie", cookie)
resp, err := app.Test(req, 5000)
require.NoError(t, err)
defer resp.Body.Close()
// Fail-closed: a Redis error on consume rejects the (otherwise valid) state.
assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
}

// --- GET /auth/me (GetCurrentUser) DB branches ---

func TestCLI_GetCurrentUser_HappyAdminImpersonation(t *testing.T) {
db, clean := testhelpers.SetupTestDB(t)
defer clean()
settleAuditDB(t, db)
rdb, cleanRedis := testhelpers.SetupTestRedis(t)
defer cleanRedis()

adminEmail := testhelpers.UniqueEmail(t)
t.Setenv("ADMIN_EMAILS", adminEmail)

cfg := &config.Config{
JWTSecret: testhelpers.TestJWTSecret,
AESKey: testhelpers.TestAESKeyHex,
Environment: "test",
AdminPathPrefix: "abcdefghijklmnopqrstuvwxyz012345",
}
h := handlers.NewCLIAuthHandler(db, rdb, cfg, plans.Default())

app := fiber.New(fiber.Config{
ErrorHandler: func(c *fiber.Ctx, err error) error {
if err == handlers.ErrResponseWritten {
return nil
}
code := fiber.StatusInternalServerError
if e, ok := err.(*fiber.Error); ok {
code = e.Code
}
return c.Status(code).JSON(fiber.Map{"ok": false, "error": err.Error()})
},
})
app.Use(middleware.RequestID())
app.Get("/auth/me", middleware.RequireAuth(cfg), h.GetCurrentUser)

teamID := testhelpers.MustCreateTeamDB(t, db, "pro")
var userID string
require.NoError(t, db.QueryRowContext(context.Background(),
`INSERT INTO users (team_id, email) VALUES ($1::uuid, $2) RETURNING id::text`,
teamID, adminEmail).Scan(&userID))

token := testhelpers.MustSignSessionJWT(t, userID, teamID, adminEmail)
req := httptest.NewRequest(http.MethodGet, "/auth/me", nil)
req.Header.Set("Authorization", "Bearer "+token)
resp, err := app.Test(req, 5000)
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, http.StatusOK, resp.StatusCode)

var body map[string]any
require.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
assert.Equal(t, "pro", body["tier"])
assert.Equal(t, true, body["is_platform_admin"])
assert.Equal(t, "abcdefghijklmnopqrstuvwxyz012345", body["admin_path_prefix"])
}

func TestCLI_GetCurrentUser_TeamAndUserNotFound(t *testing.T) {
db, clean := testhelpers.SetupTestDB(t)
defer clean()
settleAuditDB(t, db)
rdb, cleanRedis := testhelpers.SetupTestRedis(t)
defer cleanRedis()

cfg := &config.Config{JWTSecret: testhelpers.TestJWTSecret, AESKey: testhelpers.TestAESKeyHex, Environment: "test"}
h := handlers.NewCLIAuthHandler(db, rdb, cfg, plans.Default())
app := fiber.New(fiber.Config{
ErrorHandler: func(c *fiber.Ctx, err error) error {
if err == handlers.ErrResponseWritten {
return nil
}
code := fiber.StatusInternalServerError
if e, ok := err.(*fiber.Error); ok {
code = e.Code
}
return c.Status(code).JSON(fiber.Map{"ok": false, "error": err.Error()})
},
})
app.Use(middleware.RequestID())
app.Get("/auth/me", middleware.RequireAuth(cfg), h.GetCurrentUser)

// team-not-found: JWT references a random (nonexistent) team id.
randTeam := "00000000-0000-0000-0000-0000000000aa"
randUser := "00000000-0000-0000-0000-0000000000bb"
tok := testhelpers.MustSignSessionJWT(t, randUser, randTeam, "ghost@example.com")
req := httptest.NewRequest(http.MethodGet, "/auth/me", nil)
req.Header.Set("Authorization", "Bearer "+tok)
resp, err := app.Test(req, 5000)
require.NoError(t, err)
defer resp.Body.Close()
assert.Equal(t, http.StatusNotFound, resp.StatusCode)

// user-not-found: real team, but the user id isn't in users.
teamID := testhelpers.MustCreateTeamDB(t, db, "hobby")
tok2 := testhelpers.MustSignSessionJWT(t, randUser, teamID, "ghost@example.com")
req2 := httptest.NewRequest(http.MethodGet, "/auth/me", nil)
req2.Header.Set("Authorization", "Bearer "+tok2)
resp2, err := app.Test(req2, 5000)
require.NoError(t, err)
defer resp2.Body.Close()
assert.Equal(t, http.StatusNotFound, resp2.StatusCode)
}

// --- magic-link Start failing-mailer branch ---

type failingMailer struct{}

func (failingMailer) SendMagicLink(ctx context.Context, to, link string) error {
return fmt.Errorf("simulated brevo outage")
}

func TestMagicLink_Start_SendFailureStillReturns202(t *testing.T) {
db, clean := testhelpers.SetupTestDB(t)
defer clean()
settleAuditDB(t, db)
rdb, cleanRedis := testhelpers.SetupTestRedis(t)
defer cleanRedis()

cfg := &config.Config{JWTSecret: testhelpers.TestJWTSecret, AESKey: testhelpers.TestAESKeyHex}
authH := handlers.NewAuthHandler(db, cfg)
mlH := handlers.NewMagicLinkHandlerWithMailerAndRedis(db, cfg, failingMailer{}, authH, rdb)
app := fiber.New(fiber.Config{
ErrorHandler: func(c *fiber.Ctx, err error) error {
if err == handlers.ErrResponseWritten {
return nil
}
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"ok": false})
},
})
app.Use(middleware.RequestID())
app.Post("/auth/email/start", mlH.Start)

email := testhelpers.UniqueEmail(t)
req := httptest.NewRequest(http.MethodPost, "/auth/email/start", strings.NewReader(fmt.Sprintf(`{"email":%q}`, email)))
req.Header.Set("Content-Type", "application/json")
resp, err := app.Test(req, 5000)
require.NoError(t, err)
defer resp.Body.Close()
// Enumeration-defence contract: 202 even when the send fails.
assert.Equal(t, http.StatusAccepted, resp.StatusCode)

// The row must record a send-failed status (persistMagicLinkSendStatus
// failure branch). We just assert the row exists.
var n int
require.NoError(t, db.QueryRowContext(context.Background(),
`SELECT count(*) FROM magic_links WHERE email = $1`, email).Scan(&n))
assert.Equal(t, 1, n)
}
Loading
Loading