diff --git a/crypto/kek.go b/crypto/kek.go
new file mode 100644
index 0000000..f3b7e6c
--- /dev/null
+++ b/crypto/kek.go
@@ -0,0 +1,171 @@
+package crypto
+
+import (
+	"context"
+	"encoding/base64"
+	"errors"
+	"fmt"
+	"io"
+	"slices"
+	"sync"
+)
+
+type (
+	// KEK defines an interface for Key Encryption Keys.
+	// These keys are used to encrypt/decrypt DEKs and are customer-managed (e.g. via AWS/GCP KMS).
+	KEK interface {
+		io.Closer
+
+		ID() string // A unique ID for this KEK, e.g. KMS ARN
+		Encrypt(context.Context, []byte) ([]byte, error)
+		Decrypt(context.Context, []byte) ([]byte, error)
+	}
+
+	// KEKEncryptor wraps a current [KEK] used to encrypt new DEKs and a set of retired
+	// KEKs retained for decryption only. It looks up KEKs by ID when opening
+	// [DEKMaterial] so DEKs encrypted with rotated-out keys remain readable.
+	KEKEncryptor struct {
+		currentKey  KEK
+		retiredKeys []KEK
+		keys        map[string]KEK // NB: fast lookup by id; populated by NewKEKEncryptor
+
+		closeOnce sync.Once // guards the single pass over keys in Close
+		closeErr  error     // result of that pass, returned by every Close call
+	}
+
+	// KEKEncryptorOption configures a [KEKEncryptor] during construction.
+	KEKEncryptorOption func(*KEKEncryptor)
+
+	// nilKEK defines a [KEK] that does nothing: Encrypt and Decrypt return
+	// their input unchanged and Close is a no-op.
+	nilKEK struct{}
+)
+
+// NewKEKEncryptor constructs a [KEKEncryptor] using the supplied opts.
+//
+// The key-ID index is built after every option has run, so the relative order
+// of [WithKey] and [WithRetiredKey] does not matter. The current key is indexed
+// last: if a retired key reports the same ID, the current key wins the lookup,
+// so material produced by Encrypt is opened by the key that produced it.
+func NewKEKEncryptor(opts ...KEKEncryptorOption) *KEKEncryptor {
+	r := &KEKEncryptor{
+		currentKey: new(nilKEK),
+		keys:       make(map[string]KEK),
+	}
+	for _, opt := range opts {
+		if opt != nil { // tolerate nil options instead of panicking
+			opt(r)
+		}
+	}
+
+	for _, k := range r.retiredKeys {
+		r.keys[k.ID()] = k
+	}
+	// Indexed last so a retired key with a colliding ID cannot shadow it.
+	r.keys[r.currentKey.ID()] = r.currentKey
+
+	return r
+}
+
+// WithKey sets the [KEK] used to encrypt new DEKs. Defaults to a no-op KEK when unset.
+func WithKey(k KEK) KEKEncryptorOption {
+	return func(r *KEKEncryptor) {
+		if k != nil { // nil means fallback to nilKEK.
+			r.currentKey = k
+		}
+	}
+}
+
+// WithRetiredKey registers k for decryption only. It is added to the key-ID index so that DEKs
+// encrypted with k can still be opened, but k is never selected for new DEK encryption.
+// A nil k is ignored, as is a k that has already been registered as retired.
+func WithRetiredKey(k KEK) KEKEncryptorOption {
+	return func(r *KEKEncryptor) {
+		// Skip nil (mirrors WithKey): indexing it would panic on k.ID().
+		if k == nil || slices.Contains(r.retiredKeys, k) {
+			return
+		}
+
+		r.retiredKeys = append(r.retiredKeys, k)
+	}
+}
+
+// Encrypt encrypts the given DEK using the current KEK. It returns DEKMaterial
+// containing the KEK ID and the base64-encoded ciphertext.
+// It returns an error if dek is nil or if the KEK fails to encrypt.
+func (p *KEKEncryptor) Encrypt(ctx context.Context, dek *DEK) (*DEKMaterial, error) {
+	if dek == nil {
+		return nil, errors.New("failed to encrypt DEK: nil DEK")
+	}
+
+	k := p.currentKey
+	ct, err := k.Encrypt(ctx, dek.key)
+	if err != nil {
+		return nil, fmt.Errorf("failed to encrypt DEK using KEK: %s, %w", k.ID(), err)
+	}
+
+	return &DEKMaterial{
+		KEKID:        k.ID(),
+		EncryptedDEK: base64.StdEncoding.EncodeToString(ct),
+	}, nil
+}
+
+// Decrypt decrypts the DEK described by m using the KEK identified by m.KEKID.
+// It returns an error if m is nil, if no registered KEK has that ID, if
+// m.EncryptedDEK is not valid standard base64, if the KEK fails to decrypt, or
+// if the decrypted key is not keyBytes long.
+func (p *KEKEncryptor) Decrypt(ctx context.Context, m *DEKMaterial) (*DEK, error) {
+	if m == nil {
+		return nil, errors.New("failed to decrypt DEK: nil DEK material")
+	}
+
+	k, ok := p.keys[m.KEKID]
+	if !ok {
+		return nil, fmt.Errorf("unknown key: %s", m.KEKID)
+	}
+
+	ct, err := base64.StdEncoding.DecodeString(m.EncryptedDEK)
+	if err != nil {
+		return nil, fmt.Errorf("failed to decode DEK: %s, %w", m.KEKID, err)
+	}
+
+	key, err := k.Decrypt(ctx, ct)
+	if err != nil {
+		return nil, fmt.Errorf("failed to decrypt using KEK: %s, %w", m.KEKID, err)
+	}
+
+	// Reject wrong-sized output before building a DEK from it.
+	if len(key) != keyBytes {
+		return nil, fmt.Errorf("invalid DEK for KEK: %s", m.KEKID)
+	}
+
+	return dekFromKey(key)
+}
+
+// Close closes all registered KEKs and releases their resources.
+// Subsequent calls return the same error as the first call.
+// Only keys present in the key-ID index are closed; a key shadowed by another
+// with the same ID is not.
+func (p *KEKEncryptor) Close() error {
+	// Blocking concurrent callers here is acceptable: Close is a shutdown
+	// operation; callers should not race to close, and if they do, waiting
+	// for a single authoritative result is the right behaviour.
+	p.closeOnce.Do(func() {
+		var errs []error // stays nil on the common all-succeeded path
+		for id, kek := range p.keys {
+			if err := kek.Close(); err != nil {
+				errs = append(errs, fmt.Errorf("failed to close KEK: %s, %w", id, err))
+			}
+		}
+
+		// errors.Join returns nil when errs is empty.
+		p.closeErr = errors.Join(errs...)
+	})
+
+	return p.closeErr
+}
+
+func (k *nilKEK) ID() string                                           { return "EMPTY_KEK" }
+func (k *nilKEK) Encrypt(_ context.Context, pt []byte) ([]byte, error) { return pt, nil }
+func (k *nilKEK) Decrypt(_ context.Context, ct []byte) ([]byte, error) { return ct, nil }
+func (k *nilKEK) Close() error                                         { return nil }
diff --git a/crypto/kek_test.go b/crypto/kek_test.go
new file mode 100644
index 0000000..633472c
--- /dev/null
+++ b/crypto/kek_test.go
@@ -0,0 +1,308 @@
+package crypto_test
+
+import (
+	"context"
+	"errors"
+	"sync"
+	"testing"
+
+	"github.com/stretchr/testify/require"
+
+	"github.com/temporalio/s2s-proxy/crypto"
+)
+
+// fakeKEK is an in-memory KEK whose Encrypt and Decrypt return their input
+// unchanged unless one of the error/result overrides below is set.
+type fakeKEK struct {
+	id         string
+	closeCount int
+	closeErr   error
+	encErr     error
+	decErr     error
+	decResult  []byte // when non-nil, returned from Decrypt instead of ct
+}
+
+func TestKEKEncryptor_Encrypt(t *testing.T) {
+	t.Parallel()
+
+	dek, err := crypto.NewDEK()
+	require.NoError(t, err)
+
+	tests := []struct {
+		name      string
+		opts      []crypto.KEKEncryptorOption
+		wantErr   bool
+		wantKEKID string
+	}{
+		{
+			name:      "without default uses nilKEK",
+			wantKEKID: "EMPTY_KEK",
+		},
+		{
+			name:      "with key uses supplied KEK",
+			opts:      []crypto.KEKEncryptorOption{crypto.WithKey(&fakeKEK{id: "default"})},
+			wantKEKID: "default",
+		},
+		{
+			name:    "kek encrypt error",
+			opts:    []crypto.KEKEncryptorOption{crypto.WithKey(&fakeKEK{id: "k1", encErr: errors.New("kms unavailable")})},
+			wantErr: true,
+		},
+	}
+
+	for _, tc := range tests {
+		t.Run(tc.name, func(t *testing.T) {
+			t.Parallel()
+
+			r := crypto.NewKEKEncryptor(tc.opts...)
+			m, err := r.Encrypt(t.Context(), dek)
+			if tc.wantErr {
+				require.Error(t, err)
+				return
+			}
+
+			require.NoError(t, err)
+			require.Equal(t, tc.wantKEKID, m.KEKID)
+			require.NotEmpty(t, m.EncryptedDEK)
+		})
+	}
+}
+
+func TestKEKEncryptor_EncryptNilDEK(t *testing.T) {
+	t.Parallel()
+
+	r := crypto.NewKEKEncryptor(crypto.WithKey(&fakeKEK{id: "k1"}))
+	m, err := r.Encrypt(t.Context(), nil)
+	require.Error(t, err)
+	require.Nil(t, m)
+}
+
+func TestKEKEncryptor_Decrypt(t *testing.T) {
+	t.Parallel()
+
+	tests := []struct {
+		name     string
+		material *crypto.DEKMaterial
+		opts     []crypto.KEKEncryptorOption
+		wantErr  bool
+	}{
+		{
+			name:    "nil material",
+			wantErr: true,
+		},
+		{
+			name:     "unknown key id",
+			material: &crypto.DEKMaterial{KEKID: "missing"},
+			wantErr:  true,
+		},
+		{
+			name:     "invalid base64",
+			material: &crypto.DEKMaterial{KEKID: "k1", EncryptedDEK: "not-valid-base64!!!"},
+			opts:     []crypto.KEKEncryptorOption{crypto.WithKey(&fakeKEK{id: "k1"})},
+			wantErr:  true,
+		},
+		{
+			name:     "kek decrypt error",
+			material: &crypto.DEKMaterial{KEKID: "k1", EncryptedDEK: "AAAA"},
+			opts:     []crypto.KEKEncryptorOption{crypto.WithKey(&fakeKEK{id: "k1", decErr: errors.New("kms unavailable")})},
+			wantErr:  true,
+		},
+		{
+			name:     "wrong-length dek",
+			material: &crypto.DEKMaterial{KEKID: "k1", EncryptedDEK: "AAAA"},
+			opts:     []crypto.KEKEncryptorOption{crypto.WithKey(&fakeKEK{id: "k1", decResult: []byte("too short")})},
+			wantErr:  true,
+		},
+	}
+
+	for _, tc := range tests {
+		t.Run(tc.name, func(t *testing.T) {
+			t.Parallel()
+
+			r := crypto.NewKEKEncryptor(tc.opts...)
+			_, err := r.Decrypt(t.Context(), tc.material)
+			if tc.wantErr {
+				require.Error(t, err)
+				return
+			}
+
+			require.NoError(t, err)
+		})
+	}
+}
+
+func TestKEKEncryptor_RetiredKeyDecrypt(t *testing.T) {
+	t.Parallel()
+
+	original, err := crypto.NewDEK()
+	require.NoError(t, err)
+
+	active := &fakeKEK{id: "active"}
+	retired := &fakeKEK{id: "retired"}
+
+	// Encrypt with the retired key directly (simulating a DEK from before key rotation).
+	r := crypto.NewKEKEncryptor(crypto.WithKey(retired))
+	m, err := r.Encrypt(t.Context(), original)
+	require.NoError(t, err)
+	require.Equal(t, "retired", m.KEKID)
+
+	// New registry: active key for encryption, retired key for decryption only.
+	r2 := crypto.NewKEKEncryptor(
+		crypto.WithKey(active),
+		crypto.WithRetiredKey(retired),
+	)
+
+	t.Run("new DEKs use active key", func(t *testing.T) {
+		t.Parallel()
+
+		m2, err := r2.Encrypt(t.Context(), original)
+		require.NoError(t, err)
+		require.Equal(t, "active", m2.KEKID)
+	})
+
+	t.Run("old DEKs decrypt via retired key", func(t *testing.T) {
+		t.Parallel()
+
+		recovered, err := r2.Decrypt(t.Context(), m)
+		require.NoError(t, err)
+		require.Equal(t, original, recovered)
+	})
+}
+
+func TestKEKEncryptor_CurrentKeyWinsIDCollision(t *testing.T) {
+	t.Parallel()
+
+	original, err := crypto.NewDEK()
+	require.NoError(t, err)
+
+	// The retired key shares the current key's ID but always fails to decrypt.
+	current := &fakeKEK{id: "shared"}
+	shadow := &fakeKEK{id: "shared", decErr: errors.New("retired key must not be used")}
+	r := crypto.NewKEKEncryptor(
+		crypto.WithKey(current),
+		crypto.WithRetiredKey(shadow),
+	)
+
+	m, err := r.Encrypt(t.Context(), original)
+	require.NoError(t, err)
+
+	recovered, err := r.Decrypt(t.Context(), m)
+	require.NoError(t, err)
+	require.Equal(t, original, recovered)
+}
+
+func TestKEKEncryptor_NilRetiredKeyIgnored(t *testing.T) {
+	t.Parallel()
+
+	dek, err := crypto.NewDEK()
+	require.NoError(t, err)
+
+	var r *crypto.KEKEncryptor
+	require.NotPanics(t, func() {
+		r = crypto.NewKEKEncryptor(crypto.WithRetiredKey(nil))
+	})
+
+	m, err := r.Encrypt(t.Context(), dek)
+	require.NoError(t, err)
+	require.Equal(t, "EMPTY_KEK", m.KEKID)
+	require.NoError(t, r.Close())
+}
+
+func TestKEKEncryptor_Close(t *testing.T) {
+	t.Parallel()
+
+	t.Run("closes all keks", func(t *testing.T) {
+		t.Parallel()
+
+		k1 := &fakeKEK{id: "k1"}
+		k2 := &fakeKEK{id: "k2"}
+		r := crypto.NewKEKEncryptor(
+			crypto.WithKey(k1),
+			crypto.WithRetiredKey(k2),
+		)
+
+		require.NoError(t, r.Close())
+		require.Equal(t, 1, k1.closeCount)
+		require.Equal(t, 1, k2.closeCount)
+	})
+
+	t.Run("returns error on close failure", func(t *testing.T) {
+		t.Parallel()
+
+		kek := &fakeKEK{id: "k1", closeErr: errors.New("close failed")}
+		r := crypto.NewKEKEncryptor(crypto.WithKey(kek))
+		require.Error(t, r.Close())
+	})
+
+	t.Run("idempotent - same result on repeated close", func(t *testing.T) {
+		t.Parallel()
+
+		kek := &fakeKEK{id: "k1", closeErr: errors.New("close failed")}
+		r := crypto.NewKEKEncryptor(crypto.WithKey(kek))
+
+		err1 := r.Close()
+		err2 := r.Close()
+		require.Equal(t, err1, err2)
+		require.Equal(t, 1, kek.closeCount)
+	})
+
+	t.Run("closes retired keys", func(t *testing.T) {
+		t.Parallel()
+
+		active := &fakeKEK{id: "active"}
+		retired := &fakeKEK{id: "retired"}
+		r := crypto.NewKEKEncryptor(
+			crypto.WithKey(active),
+			crypto.WithRetiredKey(retired),
+		)
+
+		require.NoError(t, r.Close())
+		require.Equal(t, 1, active.closeCount)
+		require.Equal(t, 1, retired.closeCount)
+	})
+
+	t.Run("concurrent close calls kek once", func(t *testing.T) {
+		t.Parallel()
+
+		kek := &fakeKEK{id: "k1"}
+		r := crypto.NewKEKEncryptor(crypto.WithKey(kek))
+
+		errs := make([]error, 20)
+		var wg sync.WaitGroup
+		for i := range len(errs) {
+			wg.Go(func() {
+				errs[i] = r.Close()
+			})
+		}
+
+		wg.Wait()
+		require.Equal(t, 1, kek.closeCount)
+
+		for _, err := range errs {
+			require.NoError(t, err)
+		}
+	})
+}
+
+func (f *fakeKEK) ID() string { return f.id }
+
+func (f *fakeKEK) Encrypt(_ context.Context, pt []byte) ([]byte, error) {
+	if f.encErr != nil {
+		return nil, f.encErr
+	}
+
+	return pt, nil
+}
+
+func (f *fakeKEK) Decrypt(_ context.Context, ct []byte) ([]byte, error) {
+	if f.decResult != nil || f.decErr != nil {
+		return f.decResult, f.decErr
+	}
+
+	return ct, nil
+}
+
+func (f *fakeKEK) Close() error {
+	f.closeCount++
+	return f.closeErr
+}