Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 142 additions & 0 deletions crypto/kek.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
package crypto

import (
"context"
"encoding/base64"
"errors"
"fmt"
"io"
"slices"
"sync"
)

type (
// KEK defines an interface for Key Encryption Keys.
// These keys are used to encrypt/decrypt DEKs and are customer-managed (e.g. via AWS/GCP KMS).
KEK interface {
io.Closer

ID() string // A unique ID for this KEK, e.g. KMS ARN
Encrypt(context.Context, []byte) ([]byte, error)
Decrypt(context.Context, []byte) ([]byte, error)
}

// KEKEncryptor wraps a current [KEK] used to encrypt new DEKs and a set of retired
// KEKs retained for decryption only. It looks up KEKs by ID when opening
// [DEKMaterial] so DEKs encrypted with rotated-out keys remain readable.
KEKEncryptor struct {
currentKey KEK
retiredKeys []KEK
keys map[string]KEK // NB: fast lookup by id

closeOnce sync.Once
closeErr error
}

// KEKEncryptorOption configures a [KEKEncryptor] during construction.
KEKEncryptorOption func(*KEKEncryptor)

// nilKEK defines a [KEK] that does nothing.
nilKEK struct{}
)

// NewKEKEncryptor constructs a [KEKEncryptor] using the supplied opts.
func NewKEKEncryptor(opts ...KEKEncryptorOption) *KEKEncryptor {
r := &KEKEncryptor{
currentKey: new(nilKEK),
keys: make(map[string]KEK),
}
for _, opt := range opts {
opt(r)
}

r.keys[r.currentKey.ID()] = r.currentKey
for _, k := range r.retiredKeys {
r.keys[k.ID()] = k
}

return r
}

// WithKey sets the [KEK] used to encrypt new DEKs. Defaults to a no-op KEK when unset.
func WithKey(k KEK) KEKEncryptorOption {
return func(r *KEKEncryptor) {
if k != nil { // nil means fallback to nilKEK.
r.currentKey = k
}
}
}

// WithRetiredKey registers k for decryption only. It is added to the key-ID index so that DEKs
// encrypted with k can still be opened, but k is never selected for new DEK encryption.
func WithRetiredKey(k KEK) KEKEncryptorOption {
return func(r *KEKEncryptor) {
if !slices.Contains(r.retiredKeys, k) {
r.retiredKeys = append(r.retiredKeys, k)
}
}
}

// Encrypt encrypts the given DEK using the current KEK. It returns DEKMaterial
// containing the KEK ID and the base64-encoded ciphertext.
func (p *KEKEncryptor) Encrypt(ctx context.Context, dek *DEK) (*DEKMaterial, error) {
k := p.currentKey
ct, err := k.Encrypt(ctx, dek.key)
if err != nil {
return nil, fmt.Errorf("failed to encrypt message: %w", err)
}

return &DEKMaterial{
KEKID: k.ID(),
EncryptedDEK: base64.StdEncoding.EncodeToString(ct),
}, nil
}

// Decrypt decrypts the DEK described by m using the KEK identified by m.KEKID.
func (p *KEKEncryptor) Decrypt(ctx context.Context, m *DEKMaterial) (*DEK, error) {
k, ok := p.keys[m.KEKID]
if !ok {
return nil, fmt.Errorf("unknown key: %s", m.KEKID)
}

ct, err := base64.StdEncoding.DecodeString(m.EncryptedDEK)
if err != nil {
return nil, fmt.Errorf("failed to decode DEK: %s, %w", m.KEKID, err)
}

key, err := k.Decrypt(ctx, ct)
if err != nil {
return nil, fmt.Errorf("failed to decrypt using KEK: %s, %w", m.KEKID, err)
}

if len(key) != keyBytes {
return nil, fmt.Errorf("invalid DEK for KEK: %s", m.KEKID)
}

return dekFromKey(key)
}

// Close closes all registered KEKs and releases their resources.
// Subsequent calls return the same error as the first call.
func (p *KEKEncryptor) Close() error {
// Blocking concurrent callers here is acceptable: Close is a shutdown
// operation; callers should not race to close, and if they do, waiting
// for a single authoritative result is the right behaviour.
p.closeOnce.Do(func() {
errs := make([]error, 0, len(p.keys)+1)
for id, kek := range p.keys {
if err := kek.Close(); err != nil {
errs = append(errs, fmt.Errorf("failed to close KEK: %s, %w", id, err))
}
}

p.closeErr = errors.Join(errs...)
})

return p.closeErr
}

func (k *nilKEK) ID() string { return "EMPTY_KEK" }
func (k *nilKEK) Encrypt(_ context.Context, pt []byte) ([]byte, error) { return pt, nil }
func (k *nilKEK) Decrypt(_ context.Context, ct []byte) ([]byte, error) { return ct, nil }
func (k *nilKEK) Close() error { return nil }
249 changes: 249 additions & 0 deletions crypto/kek_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,249 @@
package crypto_test

import (
"context"
"errors"
"sync"
"testing"

"github.com/stretchr/testify/require"

"github.com/temporalio/s2s-proxy/crypto"
)

type fakeKEK struct {
id string
closeCount int
closeErr error
encErr error
decErr error
decResult []byte // when non-nil, returned from Decrypt instead of ct
}

func TestKEKEncryptor_Encrypt(t *testing.T) {
t.Parallel()

dek, err := crypto.NewDEK()
require.NoError(t, err)

tests := []struct {
name string
opts []crypto.KEKEncryptorOption
wantErr bool
wantKEKID string
}{
{
name: "without default uses nilKEK",
wantKEKID: "EMPTY_KEK",
},
{
name: "",
opts: []crypto.KEKEncryptorOption{crypto.WithKey(&fakeKEK{id: "default"})},
wantKEKID: "default",
},
{
name: "kek encrypt error",
opts: []crypto.KEKEncryptorOption{crypto.WithKey(&fakeKEK{id: "k1", encErr: errors.New("kms unavailable")})},
wantErr: true,
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

r := crypto.NewKEKEncryptor(tc.opts...)
m, err := r.Encrypt(t.Context(), dek)
if tc.wantErr {
require.Error(t, err)
return
}

require.NoError(t, err)
require.Equal(t, tc.wantKEKID, m.KEKID)
require.NotEmpty(t, m.EncryptedDEK)
})
}
}

func TestKEKEncryptor_Decrypt(t *testing.T) {
t.Parallel()

tests := []struct {
name string
material *crypto.DEKMaterial
opts []crypto.KEKEncryptorOption
wantErr bool
}{
{
name: "unknown key id",
material: &crypto.DEKMaterial{KEKID: "missing"},
wantErr: true,
},
{
name: "invalid base64",
material: &crypto.DEKMaterial{KEKID: "k1", EncryptedDEK: "not-valid-base64!!!"},
opts: []crypto.KEKEncryptorOption{crypto.WithKey(&fakeKEK{id: "k1"})},
wantErr: true,
},
{
name: "kek decrypt error",
material: &crypto.DEKMaterial{KEKID: "k1", EncryptedDEK: "AAAA"},
opts: []crypto.KEKEncryptorOption{crypto.WithKey(&fakeKEK{id: "k1", decErr: errors.New("kms unavailable")})},
wantErr: true,
},
{
name: "wrong-length dek",
material: &crypto.DEKMaterial{KEKID: "k1", EncryptedDEK: "AAAA"},
opts: []crypto.KEKEncryptorOption{crypto.WithKey(&fakeKEK{id: "k1", decResult: []byte("too short")})},
wantErr: true,
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

r := crypto.NewKEKEncryptor(tc.opts...)
_, err := r.Decrypt(t.Context(), tc.material)
require.Error(t, err)
})
}
}

func TestKEKEncryptor_RetiredKeyDecrypt(t *testing.T) {
t.Parallel()

original, err := crypto.NewDEK()
require.NoError(t, err)

active := &fakeKEK{id: "active"}
retired := &fakeKEK{id: "retired"}

// Encrypt with the retired key directly (simulating a DEK from before key rotation).
r := crypto.NewKEKEncryptor(crypto.WithKey(retired))
m, err := r.Encrypt(t.Context(), original)
require.NoError(t, err)
require.Equal(t, "retired", m.KEKID)

// New registry: active key for encryption, retired key for decryption only.
r2 := crypto.NewKEKEncryptor(
crypto.WithKey(active),
crypto.WithRetiredKey(retired),
)

t.Run("new DEKs use active key", func(t *testing.T) {
t.Parallel()

m2, err := r2.Encrypt(t.Context(), original)
require.NoError(t, err)
require.Equal(t, "active", m2.KEKID)
})

t.Run("old DEKs decrypt via retired key", func(t *testing.T) {
t.Parallel()

recovered, err := r2.Decrypt(t.Context(), m)
require.NoError(t, err)
require.Equal(t, original, recovered)
})
}

func TestKEKEncryptor_Close(t *testing.T) {
t.Parallel()

t.Run("closes all keks", func(t *testing.T) {
t.Parallel()

k1 := &fakeKEK{id: "k1"}
k2 := &fakeKEK{id: "k2"}
r := crypto.NewKEKEncryptor(
crypto.WithKey(k1),
crypto.WithRetiredKey(k2),
)

require.NoError(t, r.Close())
require.Equal(t, 1, k1.closeCount)
require.Equal(t, 1, k2.closeCount)
})

t.Run("returns error on close failure", func(t *testing.T) {
t.Parallel()

kek := &fakeKEK{id: "k1", closeErr: errors.New("close failed")}
r := crypto.NewKEKEncryptor(crypto.WithKey(kek))
require.Error(t, r.Close())
})

t.Run("idempotent - same result on repeated close", func(t *testing.T) {
t.Parallel()

kek := &fakeKEK{id: "k1", closeErr: errors.New("close failed")}
r := crypto.NewKEKEncryptor(crypto.WithKey(kek))

err1 := r.Close()
err2 := r.Close()
require.Equal(t, err1, err2)
require.Equal(t, 1, kek.closeCount)
})

t.Run("closes retired keys", func(t *testing.T) {
t.Parallel()

active := &fakeKEK{id: "active"}
retired := &fakeKEK{id: "retired"}
r := crypto.NewKEKEncryptor(
crypto.WithKey(active),
crypto.WithRetiredKey(retired),
)

require.NoError(t, r.Close())
require.Equal(t, 1, active.closeCount)
require.Equal(t, 1, retired.closeCount)
})

t.Run("concurrent close calls kek once", func(t *testing.T) {
t.Parallel()

kek := &fakeKEK{id: "k1"}
r := crypto.NewKEKEncryptor(crypto.WithKey(kek))

errs := make([]error, 20)
var wg sync.WaitGroup
for i := range len(errs) {
wg.Go(func() {
errs[i] = r.Close()
})
}

wg.Wait()
require.Equal(t, 1, kek.closeCount)

for _, err := range errs {
require.NoError(t, err)
}
})
}

func (f *fakeKEK) ID() string { return f.id }

func (f *fakeKEK) Encrypt(_ context.Context, pt []byte) ([]byte, error) {
if f.encErr != nil {
return nil, f.encErr
}

return pt, nil
}

func (f *fakeKEK) Decrypt(_ context.Context, ct []byte) ([]byte, error) {
if f.decResult != nil || f.decErr != nil {
return f.decResult, f.decErr
}

return ct, nil
}

func (f *fakeKEK) Close() error {
f.closeCount++
return f.closeErr
}
Loading