Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions crypto/aes_errors_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package crypto_test

import (
"errors"
"strings"
"testing"

"instant.dev/common/crypto"
)

// Cover the Error/Unwrap methods of the typed AES errors. These are surfaced to
// callers who use errors.Is/errors.As to distinguish failure modes.

func TestErrEncrypt_Wrapping(t *testing.T) {
cause := errors.New("enc boom")
e := &crypto.ErrEncrypt{Cause: cause}
if !strings.Contains(e.Error(), "enc boom") {
t.Errorf("Error() = %q", e.Error())
}
if !errors.Is(e, cause) {
t.Error("Unwrap should return cause")
}
}

func TestErrDecrypt_Wrapping(t *testing.T) {
cause := errors.New("dec boom")
e := &crypto.ErrDecrypt{Cause: cause}
if !strings.Contains(e.Error(), "dec boom") {
t.Errorf("Error() = %q", e.Error())
}
if !errors.Is(e, cause) {
t.Error("Unwrap should return cause")
}
}

func TestEncrypt_BadKeyLen(t *testing.T) {
// 5-byte key — AES rejects.
_, err := crypto.Encrypt([]byte{1, 2, 3, 4, 5}, "x")
if err == nil {
t.Error("expected error for invalid key length")
}
}

func TestDecrypt_BadKeyLen(t *testing.T) {
_, err := crypto.Decrypt([]byte{1, 2, 3}, "AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBka")
if err == nil {
t.Error("expected error for invalid key length")
}
}
94 changes: 94 additions & 0 deletions crypto/fingerprint_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
package crypto_test

import (
"net"
"testing"

"instant.dev/common/crypto"
)

func TestFingerprint_IPv4(t *testing.T) {
ip := net.ParseIP("198.51.100.42")
fp := crypto.Fingerprint(ip, 12345)
if len(fp) != 32 {
t.Errorf("expected 32-char hex, got %d", len(fp))
}

// Same /24 subnet → same fingerprint
other := net.ParseIP("198.51.100.7")
fp2 := crypto.Fingerprint(other, 12345)
if fp != fp2 {
t.Errorf("same /24 should yield identical fingerprint, got %q vs %q", fp, fp2)
}

// Different ASN → different fingerprint
fp3 := crypto.Fingerprint(ip, 99999)
if fp == fp3 {
t.Errorf("different ASN should differ, both %q", fp)
}
}

func TestFingerprint_IPv6(t *testing.T) {
ip := net.ParseIP("2001:db8::1")
fp := crypto.Fingerprint(ip, 100)
if len(fp) != 32 {
t.Errorf("expected 32-char hex, got %d", len(fp))
}
other := net.ParseIP("2001:db8::ff")
fp2 := crypto.Fingerprint(other, 100)
if fp != fp2 {
t.Errorf("same /48 IPv6 should yield identical fingerprint")
}
}

func TestParseIP(t *testing.T) {
if ip := crypto.ParseIP("10.0.0.1"); ip == nil {
t.Error("ParseIP returned nil for valid IP")
}
if ip := crypto.ParseIP("not-an-ip"); ip != nil {
t.Errorf("ParseIP returned non-nil for garbage: %v", ip)
}
}

func TestFingerprintIP(t *testing.T) {
fp, err := crypto.FingerprintIP("198.51.100.42", "AS12345")
if err != nil {
t.Fatalf("FingerprintIP: %v", err)
}
if len(fp) != 32 {
t.Errorf("fp length = %d", len(fp))
}

// lowercase "as" prefix should also be stripped
fp2, err := crypto.FingerprintIP("198.51.100.42", "as12345")
if err != nil {
t.Fatalf("FingerprintIP lowercase: %v", err)
}
if fp != fp2 {
t.Errorf("AS vs as: should be equal, %q vs %q", fp, fp2)
}

// Empty ASN works too
fp3, err := crypto.FingerprintIP("198.51.100.42", "")
if err != nil {
t.Fatalf("FingerprintIP empty asn: %v", err)
}
// fp3 may differ from fp because ASN differs; just check it's well-formed.
if len(fp3) != 32 {
t.Errorf("fp3 length = %d", len(fp3))
}

// Invalid IP -> error
if _, err := crypto.FingerprintIP("garbage", ""); err == nil {
t.Error("expected error for invalid IP")
}

// Plain number ASN (no prefix)
fp4, err := crypto.FingerprintIP("10.0.0.1", "12345")
if err != nil {
t.Fatalf("plain asn: %v", err)
}
if len(fp4) != 32 {
t.Errorf("fp4 length = %d", len(fp4))
}
}
165 changes: 165 additions & 0 deletions crypto/jwt_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
package crypto_test

import (
"errors"
"strings"
"testing"
"time"

"github.com/golang-jwt/jwt/v4"

"instant.dev/common/crypto"
)

var jwtSecret = []byte("supersecret-test-key-32-byte-minimum-required-here-pad-zzzzzz")

func TestSignAndVerifyJWT_Roundtrip(t *testing.T) {
claims := crypto.InstantClaims{
Fingerprint: "fp1",
Country: "US",
CloudVendor: "aws",
Tokens: []string{"tok1"},
ResourceTypes: []string{"postgres"},
SuggestedPlan: "hobby",
}
signed, err := crypto.SignJWT(jwtSecret, claims)
if err != nil {
t.Fatalf("SignJWT: %v", err)
}
if signed == "" {
t.Fatal("expected signed token")
}

parsed, err := crypto.VerifyJWT(jwtSecret, signed)
if err != nil {
t.Fatalf("VerifyJWT: %v", err)
}
if parsed.Fingerprint != "fp1" || parsed.Country != "US" {
t.Errorf("parsed = %+v", parsed)
}
if parsed.ID == "" {
t.Error("expected auto-generated jti")
}
}

func TestVerifyJWT_BadSecret(t *testing.T) {
signed, _ := crypto.SignJWT(jwtSecret, crypto.InstantClaims{Fingerprint: "x"})
_, err := crypto.VerifyJWT([]byte("wrong-secret"), signed)
if err == nil {
t.Fatal("expected error from wrong secret")
}
}

func TestVerifyJWT_MalformedToken(t *testing.T) {
_, err := crypto.VerifyJWT(jwtSecret, "not-a-jwt")
if err == nil {
t.Fatal("expected error for malformed token")
}
}

func TestVerifyJWT_FutureIssuedAt(t *testing.T) {
claims := crypto.InstantClaims{Fingerprint: "fp"}
claims.IssuedAt = jwt.NewNumericDate(time.Now().UTC().Add(2 * time.Hour))
signed, err := crypto.SignJWT(jwtSecret, claims)
if err != nil {
t.Fatalf("SignJWT: %v", err)
}
_, err = crypto.VerifyJWT(jwtSecret, signed)
if err == nil {
t.Fatal("expected error for future iat")
}
}

func TestSignOnboardingJWT_Roundtrip(t *testing.T) {
claims := crypto.OnboardingClaims{
Fingerprint: "fp",
Tokens: []string{"a", "b"},
SuggestedPlan: "pro",
}
signed, jti, err := crypto.SignOnboardingJWT(jwtSecret, claims)
if err != nil {
t.Fatalf("SignOnboardingJWT: %v", err)
}
if jti == "" || signed == "" {
t.Error("expected non-empty jti + signed")
}
parsed, err := crypto.VerifyOnboardingJWT(jwtSecret, signed)
if err != nil {
t.Fatalf("VerifyOnboardingJWT: %v", err)
}
if parsed.ID != jti {
t.Errorf("ID = %q, want %q", parsed.ID, jti)
}
if len(parsed.Tokens) != 2 {
t.Errorf("Tokens = %v", parsed.Tokens)
}
// ExpiresAt should be ~7 days from now
if parsed.ExpiresAt == nil || time.Until(parsed.ExpiresAt.Time) < 6*24*time.Hour {
t.Errorf("expected ~7d expiry, got %v", parsed.ExpiresAt)
}
}

func TestVerifyOnboardingJWT_Bad(t *testing.T) {
if _, err := crypto.VerifyOnboardingJWT(jwtSecret, "garbage"); err == nil {
t.Fatal("expected error")
}
}

func TestVerifyOnboardingJWT_FutureIssuedAt(t *testing.T) {
// Hand-craft an onboarding JWT with iat in the future.
claims := crypto.OnboardingClaims{Fingerprint: "fp"}
claims.RegisteredClaims = jwt.RegisteredClaims{
IssuedAt: jwt.NewNumericDate(time.Now().UTC().Add(2 * time.Hour)),
ExpiresAt: jwt.NewNumericDate(time.Now().UTC().Add(72 * time.Hour)),
ID: "jti-future",
}
tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
signed, err := tok.SignedString(jwtSecret)
if err != nil {
t.Fatalf("manual sign: %v", err)
}
if _, err := crypto.VerifyOnboardingJWT(jwtSecret, signed); err == nil {
t.Fatal("expected error for future iat")
}
}

// TestErrJWTSign_Wrapping exercises the Error/Unwrap on the typed errors.
func TestErrJWTSign_Wrapping(t *testing.T) {
cause := errors.New("underlying boom")
e := &crypto.ErrJWTSign{Cause: cause}
if !strings.Contains(e.Error(), "boom") {
t.Errorf("Error() = %q", e.Error())
}
if !errors.Is(e, cause) {
t.Errorf("Unwrap should return cause")
}
}

func TestErrJWTVerify_Wrapping(t *testing.T) {
cause := errors.New("verify boom")
e := &crypto.ErrJWTVerify{Cause: cause}
if !strings.Contains(e.Error(), "verify boom") {
t.Errorf("Error() = %q", e.Error())
}
if !errors.Is(e, cause) {
t.Errorf("Unwrap should return cause")
}
}

// TestVerifyJWT_WrongAlg verifies the alg-confusion guard rejects tokens signed
// with an unexpected method.
func TestVerifyJWT_WrongAlg(t *testing.T) {
// Sign with the "none" alg by forcing an unsigned token. The library refuses
// to sign with "none" by default, so build with an unsupported alg path:
// craft a token claiming alg=ES256 but signed with HS256 — the parser will
// reject it because keyfunc only returns the HMAC key.
tok := jwt.New(jwt.SigningMethodHS256)
tok.Method = jwt.SigningMethodRS256 // mismatch — verify must refuse
// Sign with HMAC anyway (force the wrong signature path).
// Easier: pre-craft a fixed header-payload-fake-sig string.
bad := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJmcCI6ImEifQ.sig"
_, err := crypto.VerifyJWT(jwtSecret, bad)
if err == nil {
t.Fatal("expected error for non-HMAC alg")
}
}
42 changes: 42 additions & 0 deletions crypto/token_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package crypto_test

import (
"errors"
"strings"
"testing"

"instant.dev/common/crypto"
)

func TestGenerateAPIKey_Prefix(t *testing.T) {
k1, err := crypto.GenerateAPIKey()
if err != nil {
t.Fatalf("GenerateAPIKey: %v", err)
}
if !strings.HasPrefix(k1, "inst_live_") {
t.Errorf("missing prefix: %q", k1)
}
// 32-byte body → base64url ~43 chars; full key length is 10 + ~43.
if len(k1) < 30 {
t.Errorf("key suspiciously short: %q", k1)
}
}

func TestGenerateAPIKey_Unique(t *testing.T) {
a, _ := crypto.GenerateAPIKey()
b, _ := crypto.GenerateAPIKey()
if a == b {
t.Errorf("two keys collided: %q == %q", a, b)
}
}

func TestErrTokenGenerate_Wrapping(t *testing.T) {
cause := errors.New("rng broke")
e := &crypto.ErrTokenGenerate{Cause: cause}
if !strings.Contains(e.Error(), "rng broke") {
t.Errorf("Error() = %q", e.Error())
}
if !errors.Is(e, cause) {
t.Errorf("Unwrap should return cause")
}
}
Loading
Loading