diff --git a/cmd/app/commands/master_key.go b/cmd/app/commands/master_key.go index 2949e5e..468ad9f 100644 --- a/cmd/app/commands/master_key.go +++ b/cmd/app/commands/master_key.go @@ -9,7 +9,7 @@ import ( "log/slog" "time" - cryptoService "github.com/allisson/secrets/internal/crypto/service" + cryptoDomain "github.com/allisson/secrets/internal/crypto/domain" ) // RunCreateMasterKey generates a cryptographically secure 32-byte master key for envelope encryption. @@ -27,7 +27,7 @@ import ( // Security: Never use localsecrets provider in production. Use cloud KMS providers (gcpkms, awskms, azurekeyvault). func RunCreateMasterKey( ctx context.Context, - kmsService cryptoService.KMSService, + kmsService cryptoDomain.KMSService, logger *slog.Logger, writer io.Writer, keyID string, diff --git a/cmd/app/commands/rotate_master_key.go b/cmd/app/commands/rotate_master_key.go index d419c60..d9ff6dc 100644 --- a/cmd/app/commands/rotate_master_key.go +++ b/cmd/app/commands/rotate_master_key.go @@ -10,12 +10,12 @@ import ( "strings" "time" - cryptoService "github.com/allisson/secrets/internal/crypto/service" + cryptoDomain "github.com/allisson/secrets/internal/crypto/domain" ) func RunRotateMasterKey( ctx context.Context, - kmsService cryptoService.KMSService, + kmsService cryptoDomain.KMSService, logger *slog.Logger, writer io.Writer, keyID, kmsProvider, kmsKeyURI, existingMasterKeys, existingActiveKeyID string, diff --git a/internal/app/di.go b/internal/app/di.go index 3dd33f3..7f211bc 100644 --- a/internal/app/di.go +++ b/internal/app/di.go @@ -47,7 +47,7 @@ type Container struct { // Services aeadManager cryptoService.AEADManager keyManager cryptoService.KeyManager - kmsService cryptoService.KMSService + kmsService cryptoDomain.KMSService secretService authService.SecretService tokenService authService.TokenService diff --git a/internal/app/di_crypto.go b/internal/app/di_crypto.go index 661620a..b0852c2 100644 --- a/internal/app/di_crypto.go +++ b/internal/app/di_crypto.go @@ -46,7 +46,7 @@ func (c *Container) KeyManager() cryptoService.KeyManager { } // KMSService returns the KMS service. -func (c *Container) KMSService() cryptoService.KMSService { +func (c *Container) KMSService() cryptoDomain.KMSService { c.kmsServiceInit.Do(func() { c.kmsService = c.initKMSService() }) @@ -156,7 +156,7 @@ func (c *Container) initKeyManager() cryptoService.KeyManager { } // initKMSService creates the KMS service for encrypting/decrypting master keys. -func (c *Container) initKMSService() cryptoService.KMSService { +func (c *Container) initKMSService() cryptoDomain.KMSService { return cryptoService.NewKMSService() } diff --git a/internal/crypto/domain/master_key.go b/internal/crypto/domain/master_key.go index 2e64065..5b9c865 100644 --- a/internal/crypto/domain/master_key.go +++ b/internal/crypto/domain/master_key.go @@ -215,8 +215,12 @@ func loadMasterKeyChainFromKMS( ) // Make a copy of the key data before storing to prevent issues if the underlying - // slice is reused. The original 'key' slice ownership is transferred to the keychain. - mkc.keys.Store(id, &MasterKey{ID: id, Key: key}) + // slice is reused. The original 'key' slice from KMS is zeroed after copying. + keyCopy := make([]byte, len(key)) + copy(keyCopy, key) + Zero(key) + + mkc.keys.Store(id, &MasterKey{ID: id, Key: keyCopy}) } if _, ok := mkc.Get(active); !ok { diff --git a/internal/crypto/service/key_manager.go b/internal/crypto/service/key_manager.go index 5be4560..8583f40 100644 --- a/internal/crypto/service/key_manager.go +++ b/internal/crypto/service/key_manager.go @@ -33,6 +33,7 @@ func (km *KeyManagerService) CreateKek( if _, err := rand.Read(kekKey); err != nil { return cryptoDomain.Kek{}, fmt.Errorf("failed to generate KEK: %w", err) } + defer cryptoDomain.Zero(kekKey) // Create cipher using AEADManager aead, err := km.aeadManager.CreateCipher(masterKey.Key, alg) @@ -46,12 +47,17 @@ func (km *KeyManagerService) CreateKek( return cryptoDomain.Kek{}, fmt.Errorf("failed to encrypt KEK: %w", err) } + // The plaintext KEK is included in the returned struct for in-memory use (e.g. initial setup) + // but the local variable kekKey is zeroed upon return for security. + keyCopy := make([]byte, len(kekKey)) + copy(keyCopy, kekKey) + kek := cryptoDomain.Kek{ ID: uuid.Must(uuid.NewV7()), MasterKeyID: masterKey.ID, Algorithm: alg, EncryptedKey: encryptedKey, - Key: kekKey, + Key: keyCopy, Nonce: nonce, Version: 1, CreatedAt: time.Now().UTC(), @@ -92,6 +98,7 @@ func (km *KeyManagerService) CreateDek( if _, err := rand.Read(dekKey); err != nil { return cryptoDomain.Dek{}, fmt.Errorf("failed to generate DEK: %w", err) } + defer cryptoDomain.Zero(dekKey) // Create cipher using AEADManager with KEK's algorithm aead, err := km.aeadManager.CreateCipher(kek.Key, kek.Algorithm) diff --git a/internal/crypto/service/kms_service.go b/internal/crypto/service/kms_service.go index 1b6bac6..a212ca9 100644 --- a/internal/crypto/service/kms_service.go +++ b/internal/crypto/service/kms_service.go @@ -16,21 +16,14 @@ import ( _ "gocloud.dev/secrets/localsecrets" ) -// KMSService implements domain.KMSService for KMS operations using gocloud.dev/secrets. -type KMSService interface { - // OpenKeeper opens a secrets.Keeper for the configured KMS provider. - // Returns an error if the KMS provider URI is invalid or connection fails. - OpenKeeper(ctx context.Context, keyURI string) (cryptoDomain.KMSKeeper, error) -} - -// kmsService implements KMSService using gocloud.dev/secrets. -type kmsService struct{} - // NewKMSService creates a new KMS service instance. -func NewKMSService() KMSService { +func NewKMSService() cryptoDomain.KMSService { return &kmsService{} } +// kmsService implements domain.KMSService using gocloud.dev/secrets. +type kmsService struct{} + // OpenKeeper opens a secrets.Keeper for the configured KMS provider using the keyURI. // Supports: gcpkms://, awskms://, azurekeyvault://, hashivault://, base64key:// // Returns a KMSKeeper which *secrets.Keeper implements. diff --git a/internal/crypto/usecase/dek_usecase.go b/internal/crypto/usecase/dek_usecase.go index 3447ce4..925a29c 100644 --- a/internal/crypto/usecase/dek_usecase.go +++ b/internal/crypto/usecase/dek_usecase.go @@ -20,66 +20,72 @@ type dekUseCase struct { // Rewrap finds DEKs that are not encrypted with the specified KEK ID, // decrypts them using their old KEKs, and re-encrypts them with the new KEK. -// Returns the number of DEKs rewrapped in this batch. +// Returns the number of DEKs rewrapped in this batch. Executes in a transaction. func (d *dekUseCase) Rewrap( ctx context.Context, kekChain *cryptoDomain.KekChain, newKekID uuid.UUID, batchSize int, ) (int, error) { - // 1. Fetch batch of DEKs not using the new KEK ID - deks, err := d.dekRepo.GetBatchNotKekID(ctx, newKekID, batchSize) - if err != nil { - return 0, err - } + var rewrappedCount int - if len(deks) == 0 { - return 0, nil - } + err := d.txManager.WithTx(ctx, func(ctx context.Context) error { + // 1. Fetch batch of DEKs not using the new KEK ID + deks, err := d.dekRepo.GetBatchNotKekID(ctx, newKekID, batchSize) + if err != nil { + return err + } - // 2. Get the new KEK from the chain - newKek, ok := kekChain.Get(newKekID) - if !ok { - return 0, cryptoDomain.ErrKekNotFound - } - if newKek.Key == nil { - return 0, cryptoDomain.ErrDecryptionFailed // or another appropriate error indicating unwrapped KEK is needed - } + if len(deks) == 0 { + return nil + } - // 3. Process each DEK in the batch - for _, dek := range deks { - // Get the old KEK - oldKek, ok := kekChain.Get(dek.KekID) + // 2. Get the new KEK from the chain + newKek, ok := kekChain.Get(newKekID) if !ok { - return 0, cryptoDomain.ErrKekNotFound + return cryptoDomain.ErrKekNotFound } - - // Decrypt the DEK plaintext key using the old KEK - dekKey, err := d.keyManager.DecryptDek(dek, oldKek) - if err != nil { - return 0, err + if newKek.Key == nil { + return cryptoDomain.ErrDecryptionFailed } - // Encrypt the DEK plaintext key using the new KEK - encryptedKey, nonce, err := d.keyManager.EncryptDek(dekKey, newKek) - if err != nil { + // 3. Process each DEK in the batch + for _, dek := range deks { + // Get the old KEK + oldKek, ok := kekChain.Get(dek.KekID) + if !ok { + return cryptoDomain.ErrKekNotFound + } + + // Decrypt the DEK plaintext key using the old KEK + dekKey, err := d.keyManager.DecryptDek(dek, oldKek) + if err != nil { + return err + } + + // Encrypt the DEK plaintext key using the new KEK + encryptedKey, nonce, err := d.keyManager.EncryptDek(dekKey, newKek) cryptoDomain.Zero(dekKey) - return 0, err - } - cryptoDomain.Zero(dekKey) + if err != nil { + return err + } - // Update DEK entity - dek.KekID = newKekID - dek.EncryptedKey = encryptedKey - dek.Nonce = nonce + // Update DEK entity + dek.KekID = newKekID + dek.EncryptedKey = encryptedKey + dek.Nonce = nonce - // Save updated DEK - if err := d.dekRepo.Update(ctx, dek); err != nil { - return 0, err + // Save updated DEK + if err := d.dekRepo.Update(ctx, dek); err != nil { + return err + } } - } - return len(deks), nil + rewrappedCount = len(deks) + return nil + }) + + return rewrappedCount, err } // NewDekUseCase creates a new DekUseCase instance. diff --git a/internal/crypto/usecase/dek_usecase_test.go b/internal/crypto/usecase/dek_usecase_test.go index 6d393e9..38cd44f 100644 --- a/internal/crypto/usecase/dek_usecase_test.go +++ b/internal/crypto/usecase/dek_usecase_test.go @@ -62,8 +62,16 @@ func TestDekUseCase_Rewrap(t *testing.T) { batch := []*cryptoDomain.Dek{dek1} - // Setup mock expectations - dekRepo.EXPECT().GetBatchNotKekID(ctx, newKekID, batchSize).Return(batch, nil) + // Setup mock expectations for transaction + txManager.EXPECT(). + WithTx(ctx, mock.AnythingOfType("func(context.Context) error")). + Run(func(ctx context.Context, fn func(context.Context) error) { + _ = fn(ctx) + }). + Return(nil). + Once() + + dekRepo.EXPECT().GetBatchNotKekID(mock.Anything, newKekID, batchSize).Return(batch, nil) plainDek1Key := []byte("dek1-plaintext-key") keyManager.EXPECT().DecryptDek(dek1, oldKek).Return(plainDek1Key, nil) @@ -72,7 +80,7 @@ func TestDekUseCase_Rewrap(t *testing.T) { newNonceDek1 := []byte("dek1-nonce-new") keyManager.EXPECT().EncryptDek(plainDek1Key, newKek).Return(newEncDek1, newNonceDek1, nil) - dekRepo.EXPECT().Update(ctx, mock.MatchedBy(func(dek *cryptoDomain.Dek) bool { + dekRepo.EXPECT().Update(mock.Anything, mock.MatchedBy(func(dek *cryptoDomain.Dek) bool { return dek.ID == dek1.ID && dek.KekID == newKekID && string(dek.EncryptedKey) == "dek1-encrypted-new" && @@ -97,7 +105,18 @@ func TestDekUseCase_Rewrap(t *testing.T) { kekChain := cryptoDomain.NewKekChain([]*cryptoDomain.Kek{newKek}) batchSize := 10 - dekRepo.EXPECT().GetBatchNotKekID(ctx, newKekID, batchSize).Return([]*cryptoDomain.Dek{}, nil) + // Setup mock expectations for transaction + txManager.EXPECT(). + WithTx(ctx, mock.AnythingOfType("func(context.Context) error")). + Run(func(ctx context.Context, fn func(context.Context) error) { + _ = fn(ctx) + }). + Return(nil). + Once() + + dekRepo.EXPECT(). + GetBatchNotKekID(mock.Anything, newKekID, batchSize). + Return([]*cryptoDomain.Dek{}, nil) rewrapped, err := useCase.Rewrap(ctx, kekChain, newKekID, batchSize) @@ -119,7 +138,16 @@ func TestDekUseCase_Rewrap(t *testing.T) { dek1 := &cryptoDomain.Dek{ID: uuid.New(), KekID: uuid.New()} batch := []*cryptoDomain.Dek{dek1} - dekRepo.EXPECT().GetBatchNotKekID(ctx, newKekID, batchSize).Return(batch, nil) + // Setup mock expectations for transaction + txManager.EXPECT(). + WithTx(ctx, mock.AnythingOfType("func(context.Context) error")). + Run(func(ctx context.Context, fn func(context.Context) error) { + _ = fn(ctx) + }). + Return(cryptoDomain.ErrKekNotFound). + Once() + + dekRepo.EXPECT().GetBatchNotKekID(mock.Anything, newKekID, batchSize).Return(batch, nil) rewrapped, err := useCase.Rewrap(ctx, kekChain, newKekID, batchSize) @@ -144,7 +172,16 @@ func TestDekUseCase_Rewrap(t *testing.T) { dek1 := &cryptoDomain.Dek{ID: uuid.New(), KekID: oldKekID} batch := []*cryptoDomain.Dek{dek1} - dekRepo.EXPECT().GetBatchNotKekID(ctx, newKekID, batchSize).Return(batch, nil) + // Setup mock expectations for transaction + txManager.EXPECT(). + WithTx(ctx, mock.AnythingOfType("func(context.Context) error")). + Run(func(ctx context.Context, fn func(context.Context) error) { + _ = fn(ctx) + }). + Return(cryptoDomain.ErrKekNotFound). + Once() + + dekRepo.EXPECT().GetBatchNotKekID(mock.Anything, newKekID, batchSize).Return(batch, nil) rewrapped, err := useCase.Rewrap(ctx, kekChain, newKekID, batchSize) @@ -170,8 +207,18 @@ func TestDekUseCase_Rewrap(t *testing.T) { dek1 := &cryptoDomain.Dek{ID: uuid.New(), KekID: oldKekID} batch := []*cryptoDomain.Dek{dek1} - dekRepo.EXPECT().GetBatchNotKekID(ctx, newKekID, batchSize).Return(batch, nil) expectedErr := errors.New("decryption failed") + + // Setup mock expectations for transaction + txManager.EXPECT(). + WithTx(ctx, mock.AnythingOfType("func(context.Context) error")). + Run(func(ctx context.Context, fn func(context.Context) error) { + _ = fn(ctx) + }). + Return(expectedErr). + Once() + + dekRepo.EXPECT().GetBatchNotKekID(mock.Anything, newKekID, batchSize).Return(batch, nil) keyManager.EXPECT().DecryptDek(dek1, oldKek).Return(nil, expectedErr) rewrapped, err := useCase.Rewrap(ctx, kekChain, newKekID, batchSize) @@ -198,12 +245,22 @@ func TestDekUseCase_Rewrap(t *testing.T) { dek1 := &cryptoDomain.Dek{ID: uuid.New(), KekID: oldKekID} batch := []*cryptoDomain.Dek{dek1} - dekRepo.EXPECT().GetBatchNotKekID(ctx, newKekID, batchSize).Return(batch, nil) + expectedErr := errors.New("encryption failed") + + // Setup mock expectations for transaction + txManager.EXPECT(). + WithTx(ctx, mock.AnythingOfType("func(context.Context) error")). + Run(func(ctx context.Context, fn func(context.Context) error) { + _ = fn(ctx) + }). + Return(expectedErr). + Once() + + dekRepo.EXPECT().GetBatchNotKekID(mock.Anything, newKekID, batchSize).Return(batch, nil) plainDek1Key := []byte("plain-key") keyManager.EXPECT().DecryptDek(dek1, oldKek).Return(plainDek1Key, nil) - expectedErr := errors.New("encryption failed") keyManager.EXPECT().EncryptDek(plainDek1Key, newKek).Return(nil, nil, expectedErr) rewrapped, err := useCase.Rewrap(ctx, kekChain, newKekID, batchSize) @@ -230,14 +287,24 @@ func TestDekUseCase_Rewrap(t *testing.T) { dek1 := &cryptoDomain.Dek{ID: uuid.New(), KekID: oldKekID} batch := []*cryptoDomain.Dek{dek1} - dekRepo.EXPECT().GetBatchNotKekID(ctx, newKekID, batchSize).Return(batch, nil) + expectedErr := errors.New("update failed") + + // Setup mock expectations for transaction + txManager.EXPECT(). + WithTx(ctx, mock.AnythingOfType("func(context.Context) error")). + Run(func(ctx context.Context, fn func(context.Context) error) { + _ = fn(ctx) + }). + Return(expectedErr). + Once() + + dekRepo.EXPECT().GetBatchNotKekID(mock.Anything, newKekID, batchSize).Return(batch, nil) plainDek1Key := []byte("plain-key") keyManager.EXPECT().DecryptDek(dek1, oldKek).Return(plainDek1Key, nil) keyManager.EXPECT().EncryptDek(plainDek1Key, newKek).Return([]byte("enc"), []byte("nonce"), nil) - expectedErr := errors.New("update failed") - dekRepo.EXPECT().Update(ctx, dek1).Return(expectedErr) + dekRepo.EXPECT().Update(mock.Anything, dek1).Return(expectedErr) rewrapped, err := useCase.Rewrap(ctx, kekChain, newKekID, batchSize)