Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 79 additions & 0 deletions crypto/dek.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
package crypto

import (
"context"
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"errors"
"fmt"
)

const keyBytes = 32

var ErrMalformedCipherText = errors.New("invalid ciphertext, not encrypted with a DEK")

type (
// DEK defines a Data Encryption Key used to encrypt/decrypt payloads.
DEK struct {
key []byte
gcm cipher.AEAD
}

// DEKMaterial defines the material needed in order to decrypt a payload.
DEKMaterial struct {
KEKID string // The ID/URI of the KEK the encrypted the DEK.
EncryptedDEK string // The base64-encoded encrypted DEK.
}
)

// NewDEK generates a new random 256-bit Data Encryption Key.
func NewDEK() (*DEK, error) {
k := make([]byte, keyBytes)
if _, err := rand.Read(k); err != nil {
return nil, fmt.Errorf("failed to create DEK: %w", err)
}

return dekFromKey(k)
}

// Encrypt encrypts the plaintext pt using AES-256-GCM. The returned ciphertext
// is prefixed with the randomly generated nonce.
func (d *DEK) Encrypt(ctx context.Context, pt []byte) ([]byte, error) {

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ctx unused?

nonce := make([]byte, d.gcm.NonceSize())
if _, err := rand.Read(nonce); err != nil {
return nil, fmt.Errorf("failed to generate nonce: %w", err)
}

return d.gcm.Seal(nonce, nonce, pt, nil), nil
}

// Decrypt decrypts the ciphertext ct using AES-256-GCM. The ciphertext must be
// prefixed with the nonce, as produced by [DEK.Encrypt].
func (d *DEK) Decrypt(ctx context.Context, ct []byte) ([]byte, error) {
ns := d.gcm.NonceSize()
if len(ct) < ns {
return nil, ErrMalformedCipherText
}

pt, err := d.gcm.Open(nil, ct[:ns], ct[ns:], nil)
if err != nil {
return nil, fmt.Errorf("failed to decrypt ciphertext: %w", err)
}

return pt, nil
}

func dekFromKey(key []byte) (*DEK, error) {
block, err := aes.NewCipher(key)
if err != nil {
return nil, fmt.Errorf("failed to create cipher block: %w", err)
}

gcm, err := cipher.NewGCM(block)
if err != nil {
return nil, fmt.Errorf("failed to create cipher: %w", err)
}

return &DEK{key: key, gcm: gcm}, nil
}
130 changes: 130 additions & 0 deletions crypto/dek_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
package crypto_test

import (
"bytes"
"testing"

"github.com/stretchr/testify/require"

"github.com/temporalio/s2s-proxy/crypto"
)

func TestNewDEK(t *testing.T) {
t.Parallel()

t.Run("non-zero", func(t *testing.T) {
t.Parallel()

d, err := crypto.NewDEK()
require.NoError(t, err)
require.NotEqual(t, crypto.DEK{}, d)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

compare value against pointer?

})

t.Run("unique each call", func(t *testing.T) {
t.Parallel()

d1, err := crypto.NewDEK()
require.NoError(t, err)

d2, err := crypto.NewDEK()
require.NoError(t, err)

require.NotEqual(t, d1, d2)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this only compares pointers differ, should also compare contents differ?

})
}

func TestDEKEncryptDecrypt(t *testing.T) {
t.Parallel()

tests := []struct {
name string
plaintext []byte
}{
{"empty", []byte{}},
{"ascii", []byte("hello, world")},
{"binary", []byte{0x00, 0xFF, 0x42, 0x13}},
{"large", bytes.Repeat([]byte("a"), 64*1024)},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
ctx := t.Context()
dek, err := crypto.NewDEK()
require.NoError(t, err)

ct, err := dek.Encrypt(ctx, tc.plaintext)
require.NoError(t, err)

pt, err := dek.Decrypt(ctx, ct)
require.NoError(t, err)
require.True(t, bytes.Equal(tc.plaintext, pt))
})
}
}

func TestDEKEncrypt(t *testing.T) {
t.Parallel()

t.Run("unique ciphertext per call", func(t *testing.T) {
t.Parallel()
ctx := t.Context()
dek, err := crypto.NewDEK()
require.NoError(t, err)

pt := []byte("same plaintext")
ct1, err := dek.Encrypt(ctx, pt)
require.NoError(t, err)

ct2, err := dek.Encrypt(ctx, pt)
require.NoError(t, err)

require.NotEqual(t, ct1, ct2)
})
}

func TestDEKDecrypt(t *testing.T) {
t.Parallel()

t.Run("wrong key", func(t *testing.T) {
t.Parallel()

ctx := t.Context()
dek, err := crypto.NewDEK()
require.NoError(t, err)

ct, err := dek.Encrypt(ctx, []byte("secret"))
require.NoError(t, err)

altDEK, err := crypto.NewDEK()
require.NoError(t, err)
_, err = altDEK.Decrypt(ctx, ct)
require.Error(t, err)
})

t.Run("not encrypted with DEK", func(t *testing.T) {
t.Parallel()

dek, err := crypto.NewDEK()
require.NoError(t, err)

ct, err := dek.Decrypt(t.Context(), []byte("nope"))
require.ErrorIs(t, err, crypto.ErrMalformedCipherText)
require.Nil(t, ct)
})

t.Run("tampered ciphertext", func(t *testing.T) {
t.Parallel()

ctx := t.Context()
dek, err := crypto.NewDEK()
require.NoError(t, err)

ct, err := dek.Encrypt(ctx, []byte("secret"))
require.NoError(t, err)

ct[len(ct)-1] ^= 0xFF
_, err = dek.Decrypt(ctx, ct)
require.Error(t, err)
})
}
Loading