diff --git a/crypto/dek.go b/crypto/dek.go new file mode 100644 index 0000000..91accd2 --- /dev/null +++ b/crypto/dek.go @@ -0,0 +1,79 @@ +package crypto + +import ( + "context" + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "errors" + "fmt" +) + +const keyBytes = 32 + +var ErrMalformedCipherText = errors.New("invalid ciphertext, not encrypted with a DEK") + +type ( + // DEK defines a Data Encryption Key used to encrypt/decrypt payloads. + DEK struct { + key []byte + gcm cipher.AEAD + } + + // DEKMaterial defines the material needed in order to decrypt a payload. + DEKMaterial struct { + KEKID string // The ID/URI of the KEK the encrypted the DEK. + EncryptedDEK string // The base64-encoded encrypted DEK. + } +) + +// NewDEK generates a new random 256-bit Data Encryption Key. +func NewDEK() (*DEK, error) { + k := make([]byte, keyBytes) + if _, err := rand.Read(k); err != nil { + return nil, fmt.Errorf("failed to create DEK: %w", err) + } + + return dekFromKey(k) +} + +// Encrypt encrypts the plaintext pt using AES-256-GCM. The returned ciphertext +// is prefixed with the randomly generated nonce. +func (d *DEK) Encrypt(ctx context.Context, pt []byte) ([]byte, error) { + nonce := make([]byte, d.gcm.NonceSize()) + if _, err := rand.Read(nonce); err != nil { + return nil, fmt.Errorf("failed to generate nonce: %w", err) + } + + return d.gcm.Seal(nonce, nonce, pt, nil), nil +} + +// Decrypt decrypts the ciphertext ct using AES-256-GCM. The ciphertext must be +// prefixed with the nonce, as produced by [DEK.Encrypt]. +func (d *DEK) Decrypt(ctx context.Context, ct []byte) ([]byte, error) { + ns := d.gcm.NonceSize() + if len(ct) < ns { + return nil, ErrMalformedCipherText + } + + pt, err := d.gcm.Open(nil, ct[:ns], ct[ns:], nil) + if err != nil { + return nil, fmt.Errorf("failed to decrypt ciphertext: %w", err) + } + + return pt, nil +} + +func dekFromKey(key []byte) (*DEK, error) { + block, err := aes.NewCipher(key) + if err != nil { + return nil, fmt.Errorf("failed to create cipher block: %w", err) + } + + gcm, err := cipher.NewGCM(block) + if err != nil { + return nil, fmt.Errorf("failed to create cipher: %w", err) + } + + return &DEK{key: key, gcm: gcm}, nil +} diff --git a/crypto/dek_test.go b/crypto/dek_test.go new file mode 100644 index 0000000..bd8b1e8 --- /dev/null +++ b/crypto/dek_test.go @@ -0,0 +1,130 @@ +package crypto_test + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/temporalio/s2s-proxy/crypto" +) + +func TestNewDEK(t *testing.T) { + t.Parallel() + + t.Run("non-zero", func(t *testing.T) { + t.Parallel() + + d, err := crypto.NewDEK() + require.NoError(t, err) + require.NotEqual(t, crypto.DEK{}, d) + }) + + t.Run("unique each call", func(t *testing.T) { + t.Parallel() + + d1, err := crypto.NewDEK() + require.NoError(t, err) + + d2, err := crypto.NewDEK() + require.NoError(t, err) + + require.NotEqual(t, d1, d2) + }) +} + +func TestDEKEncryptDecrypt(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + plaintext []byte + }{ + {"empty", []byte{}}, + {"ascii", []byte("hello, world")}, + {"binary", []byte{0x00, 0xFF, 0x42, 0x13}}, + {"large", bytes.Repeat([]byte("a"), 64*1024)}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + ctx := t.Context() + dek, err := crypto.NewDEK() + require.NoError(t, err) + + ct, err := dek.Encrypt(ctx, tc.plaintext) + require.NoError(t, err) + + pt, err := dek.Decrypt(ctx, ct) + require.NoError(t, err) + require.True(t, bytes.Equal(tc.plaintext, pt)) + }) + } +} + +func TestDEKEncrypt(t *testing.T) { + t.Parallel() + + t.Run("unique ciphertext per call", func(t *testing.T) { + t.Parallel() + ctx := t.Context() + dek, err := crypto.NewDEK() + require.NoError(t, err) + + pt := []byte("same plaintext") + ct1, err := dek.Encrypt(ctx, pt) + require.NoError(t, err) + + ct2, err := dek.Encrypt(ctx, pt) + require.NoError(t, err) + + require.NotEqual(t, ct1, ct2) + }) +} + +func TestDEKDecrypt(t *testing.T) { + t.Parallel() + + t.Run("wrong key", func(t *testing.T) { + t.Parallel() + + ctx := t.Context() + dek, err := crypto.NewDEK() + require.NoError(t, err) + + ct, err := dek.Encrypt(ctx, []byte("secret")) + require.NoError(t, err) + + altDEK, err := crypto.NewDEK() + require.NoError(t, err) + _, err = altDEK.Decrypt(ctx, ct) + require.Error(t, err) + }) + + t.Run("not encrypted with DEK", func(t *testing.T) { + t.Parallel() + + dek, err := crypto.NewDEK() + require.NoError(t, err) + + ct, err := dek.Decrypt(t.Context(), []byte("nope")) + require.ErrorIs(t, err, crypto.ErrMalformedCipherText) + require.Nil(t, ct) + }) + + t.Run("tampered ciphertext", func(t *testing.T) { + t.Parallel() + + ctx := t.Context() + dek, err := crypto.NewDEK() + require.NoError(t, err) + + ct, err := dek.Encrypt(ctx, []byte("secret")) + require.NoError(t, err) + + ct[len(ct)-1] ^= 0xFF + _, err = dek.Decrypt(ctx, ct) + require.Error(t, err) + }) +}