diff --git a/cmd/app/auth_commands.go b/cmd/app/auth_commands.go index c279a8a..bcbcae0 100644 --- a/cmd/app/auth_commands.go +++ b/cmd/app/auth_commands.go @@ -1,3 +1,4 @@ +// Package main provides the CLI command definitions for the application. package main import ( @@ -7,9 +8,9 @@ import ( "github.com/allisson/secrets/cmd/app/commands" "github.com/allisson/secrets/internal/app" - "github.com/allisson/secrets/internal/config" ) +// getAuthCommands returns the authentication-related CLI commands. func getAuthCommands() []*cli.Command { return []*cli.Command{ { @@ -36,23 +37,24 @@ func getAuthCommands() []*cli.Command { }, }, Action: func(ctx context.Context, cmd *cli.Command) error { - cfg := config.Load() - container := app.NewContainer(cfg) - defer func() { _ = container.Shutdown(ctx) }() - - tokenizationUseCase, err := container.TokenizationUseCase() - if err != nil { - return err - } - - return commands.RunCleanExpiredTokens( + return commands.ExecuteWithContainer( ctx, - tokenizationUseCase, - container.Logger(), - commands.DefaultIO().Writer, - int(cmd.Int("days")), - cmd.Bool("dry-run"), - cmd.String("format"), + func(ctx context.Context, container *app.Container) error { + tokenizationUseCase, err := container.TokenizationUseCase() + if err != nil { + return err + } + + return commands.RunCleanExpiredTokens( + ctx, + tokenizationUseCase, + container.Logger(), + commands.DefaultIO().Writer, + int(cmd.Int("days")), + cmd.Bool("dry-run"), + cmd.String("format"), + ) + }, ) }, }, @@ -85,24 +87,25 @@ func getAuthCommands() []*cli.Command { }, }, Action: func(ctx context.Context, cmd *cli.Command) error { - cfg := config.Load() - container := app.NewContainer(cfg) - defer func() { _ = container.Shutdown(ctx) }() - - clientUseCase, err := container.ClientUseCase() - if err != nil { - return err - } - - return commands.RunCreateClient( + return commands.ExecuteWithContainer( ctx, - clientUseCase, - container.Logger(), - cmd.String("name"), - cmd.Bool("active"), - cmd.String("policies"), - cmd.String("format"), - commands.DefaultIO(), + func(ctx context.Context, container *app.Container) error { + clientUseCase, err := container.ClientUseCase() + if err != nil { + return err + } + + return commands.RunCreateClient( + ctx, + clientUseCase, + container.Logger(), + cmd.String("name"), + cmd.Bool("active"), + cmd.String("policies"), + cmd.String("format"), + commands.DefaultIO(), + ) + }, ) }, }, @@ -141,25 +144,26 @@ func getAuthCommands() []*cli.Command { }, }, Action: func(ctx context.Context, cmd *cli.Command) error { - cfg := config.Load() - container := app.NewContainer(cfg) - defer func() { _ = container.Shutdown(ctx) }() - - clientUseCase, err := container.ClientUseCase() - if err != nil { - return err - } - - return commands.RunUpdateClient( + return commands.ExecuteWithContainer( ctx, - clientUseCase, - container.Logger(), - commands.DefaultIO(), - cmd.String("id"), - cmd.String("name"), - cmd.Bool("active"), - cmd.String("policies"), - cmd.String("format"), + func(ctx context.Context, container *app.Container) error { + clientUseCase, err := container.ClientUseCase() + if err != nil { + return err + } + + return commands.RunUpdateClient( + ctx, + clientUseCase, + container.Logger(), + commands.DefaultIO(), + cmd.String("id"), + cmd.String("name"), + cmd.Bool("active"), + cmd.String("policies"), + cmd.String("format"), + ) + }, ) }, }, diff --git a/cmd/app/commands.go b/cmd/app/commands.go index 8dcafdc..d1d3a6e 100644 --- a/cmd/app/commands.go +++ b/cmd/app/commands.go @@ -1,9 +1,11 @@ +// Package main provides the CLI command definitions for the application. package main import ( "github.com/urfave/cli/v3" ) +// getCommands aggregates and returns all CLI commands for the application. func getCommands(version string) []*cli.Command { cmds := []*cli.Command{} cmds = append(cmds, getSystemCommands(version)...) diff --git a/cmd/app/commands/clean_audit_logs.go b/cmd/app/commands/clean_audit_logs.go index 60d5726..86976a8 100644 --- a/cmd/app/commands/clean_audit_logs.go +++ b/cmd/app/commands/clean_audit_logs.go @@ -10,6 +10,33 @@ import ( authUseCase "github.com/allisson/secrets/internal/auth/usecase" ) +// CleanAuditLogsResult holds the result of the audit log cleanup operation. +type CleanAuditLogsResult struct { + Count int64 `json:"count"` + Days int `json:"days"` + DryRun bool `json:"dry_run"` +} + +// ToText returns a human-readable representation of the cleanup result. +func (r *CleanAuditLogsResult) ToText() string { + if r.DryRun { + return fmt.Sprintf( + "Dry-run mode: Would delete %d audit log(s) older than %d day(s)", + r.Count, + r.Days, + ) + } + return fmt.Sprintf("Successfully deleted %d audit log(s) older than %d day(s)", r.Count, r.Days) +} + +// ToJSON returns a JSON representation of the cleanup result. +func (r *CleanAuditLogsResult) ToJSON() string { + jsonBytes, _ := json.MarshalIndent(r, "", " ") + return string(jsonBytes) +} + +// RunCleanAuditLogs deletes audit logs older than the specified number of days. +// Supports dry-run mode and multiple output formats. func RunCleanAuditLogs( ctx context.Context, auditLogUseCase authUseCase.AuditLogUseCase, @@ -35,12 +62,13 @@ func RunCleanAuditLogs( return fmt.Errorf("failed to delete audit logs: %w", err) } - // Output result based on format - if format == "json" { - outputCleanAuditLogsJSON(writer, count, days, dryRun) - } else { - outputCleanAuditLogsText(writer, count, days, dryRun) + // Output result + result := &CleanAuditLogsResult{ + Count: count, + Days: days, + DryRun: dryRun, } + WriteOutput(writer, format, result) logger.Info("cleanup completed", slog.Int64("count", count), @@ -50,33 +78,3 @@ func RunCleanAuditLogs( return nil } - -// outputCleanAuditLogsText outputs the result in human-readable text format. -func outputCleanAuditLogsText(writer io.Writer, count int64, days int, dryRun bool) { - if dryRun { - _, _ = fmt.Fprintf( - writer, - "Dry-run mode: Would delete %d audit log(s) older than %d day(s)\n", - count, - days, - ) - } else { - _, _ = fmt.Fprintf(writer, "Successfully deleted %d audit log(s) older than %d day(s)\n", count, days) - } -} - -// outputCleanAuditLogsJSON outputs the result in JSON format for machine consumption. -func outputCleanAuditLogsJSON(writer io.Writer, count int64, days int, dryRun bool) { - result := map[string]interface{}{ - "count": count, - "days": days, - "dry_run": dryRun, - } - - jsonBytes, err := json.MarshalIndent(result, "", " ") - if err != nil { - return - } - - _, _ = fmt.Fprintln(writer, string(jsonBytes)) -} diff --git a/cmd/app/commands/clean_expired_tokens.go b/cmd/app/commands/clean_expired_tokens.go index 210d154..6eecd23 100644 --- a/cmd/app/commands/clean_expired_tokens.go +++ b/cmd/app/commands/clean_expired_tokens.go @@ -10,6 +10,37 @@ import ( tokenizationUseCase "github.com/allisson/secrets/internal/tokenization/usecase" ) +// CleanExpiredTokensResult holds the result of the expired token cleanup operation. +type CleanExpiredTokensResult struct { + Count int64 `json:"count"` + Days int `json:"days"` + DryRun bool `json:"dry_run"` +} + +// ToText returns a human-readable representation of the cleanup result. +func (r *CleanExpiredTokensResult) ToText() string { + if r.DryRun { + return fmt.Sprintf( + "Dry-run mode: Would delete %d expired token(s) older than %d day(s)", + r.Count, + r.Days, + ) + } + return fmt.Sprintf( + "Successfully deleted %d expired token(s) older than %d day(s)", + r.Count, + r.Days, + ) +} + +// ToJSON returns a JSON representation of the cleanup result. +func (r *CleanExpiredTokensResult) ToJSON() string { + jsonBytes, _ := json.MarshalIndent(r, "", " ") + return string(jsonBytes) +} + +// RunCleanExpiredTokens deletes expired tokens older than the specified number of days. +// Supports dry-run mode and multiple output formats. func RunCleanExpiredTokens( ctx context.Context, tokenizationUseCase tokenizationUseCase.TokenizationUseCase, @@ -35,12 +66,13 @@ func RunCleanExpiredTokens( return fmt.Errorf("failed to cleanup expired tokens: %w", err) } - // Output result based on format - if format == "json" { - outputCleanExpiredJSON(writer, count, days, dryRun) - } else { - outputCleanExpiredText(writer, count, days, dryRun) + // Output result + result := &CleanExpiredTokensResult{ + Count: count, + Days: days, + DryRun: dryRun, } + WriteOutput(writer, format, result) logger.Info("cleanup completed", slog.Int64("count", count), @@ -50,38 +82,3 @@ func RunCleanExpiredTokens( return nil } - -// outputCleanExpiredText outputs the result in human-readable text format. -func outputCleanExpiredText(writer io.Writer, count int64, days int, dryRun bool) { - if dryRun { - _, _ = fmt.Fprintf( - writer, - "Dry-run mode: Would delete %d expired token(s) older than %d day(s)\n", - count, - days, - ) - } else { - _, _ = fmt.Fprintf( - writer, - "Successfully deleted %d expired token(s) older than %d day(s)\n", - count, - days, - ) - } -} - -// outputCleanExpiredJSON outputs the result in JSON format for machine consumption. -func outputCleanExpiredJSON(writer io.Writer, count int64, days int, dryRun bool) { - result := map[string]interface{}{ - "count": count, - "days": days, - "dry_run": dryRun, - } - - jsonBytes, err := json.MarshalIndent(result, "", " ") - if err != nil { - return - } - - _, _ = fmt.Fprintln(writer, string(jsonBytes)) -} diff --git a/cmd/app/commands/clean_expired_tokens_test.go b/cmd/app/commands/clean_expired_tokens_test.go index ea04058..7e95cd6 100644 --- a/cmd/app/commands/clean_expired_tokens_test.go +++ b/cmd/app/commands/clean_expired_tokens_test.go @@ -18,25 +18,25 @@ func TestRunCleanExpiredTokens(t *testing.T) { t.Run("text-output", func(t *testing.T) { mockUseCase := &tokenizationMocks.MockTokenizationUseCase{} - mockUseCase.On("CleanupExpired", ctx, days, false).Return(int64(10), nil) + mockUseCase.On("CleanupExpired", ctx, days, false).Return(int64(100), nil) var out bytes.Buffer err := RunCleanExpiredTokens(ctx, mockUseCase, logger, &out, days, false, "text") require.NoError(t, err) - require.Contains(t, out.String(), "Successfully deleted 10 expired token(s)") + require.Contains(t, out.String(), "Successfully deleted 100 expired token(s)") mockUseCase.AssertExpectations(t) }) t.Run("json-output", func(t *testing.T) { mockUseCase := &tokenizationMocks.MockTokenizationUseCase{} - mockUseCase.On("CleanupExpired", ctx, days, true).Return(int64(5), nil) + mockUseCase.On("CleanupExpired", ctx, days, true).Return(int64(50), nil) var out bytes.Buffer err := RunCleanExpiredTokens(ctx, mockUseCase, logger, &out, days, true, "json") require.NoError(t, err) - require.Contains(t, out.String(), `"count": 5`) + require.Contains(t, out.String(), `"count": 50`) require.Contains(t, out.String(), `"dry_run": true`) mockUseCase.AssertExpectations(t) }) diff --git a/cmd/app/commands/create_client.go b/cmd/app/commands/create_client.go index 5fb20d2..0e48b0a 100644 --- a/cmd/app/commands/create_client.go +++ b/cmd/app/commands/create_client.go @@ -1,25 +1,43 @@ package commands import ( - "bufio" "context" "encoding/json" "fmt" - "io" "log/slog" - "os" - "strings" authDomain "github.com/allisson/secrets/internal/auth/domain" authUseCase "github.com/allisson/secrets/internal/auth/usecase" + "github.com/allisson/secrets/internal/ui" ) +// CreateClientResult holds the result of the client creation operation. +type CreateClientResult struct { + ID string `json:"client_id"` + // #nosec G117 + PlainSecret string `json:"secret"` +} + +// ToText returns a human-readable representation of the creation result. +func (r *CreateClientResult) ToText() string { + var sb fmt.Stringer + output := "\nClient created successfully!\n" + output += fmt.Sprintf("Client ID: %s\n", r.ID) + output += fmt.Sprintf("Secret: %s\n", r.PlainSecret) + output += "\nIMPORTANT: The secret is shown only once. Store it securely." + _ = sb + return output +} + +// ToJSON returns a JSON representation of the creation result. +func (r *CreateClientResult) ToJSON() string { + jsonBytes, _ := json.MarshalIndent(r, "", " ") + return string(jsonBytes) +} + // RunCreateClient creates a new authentication client with policies. // Supports both interactive mode (when policiesJSON is empty) and non-interactive -// mode (when policiesJSON is provided). Outputs client ID and plain secret in -// either text or JSON format. -// -// Requirements: Database must be migrated and accessible. +// mode (when policiesJSON is provided). func RunCreateClient( ctx context.Context, clientUseCase authUseCase.ClientUseCase, @@ -38,7 +56,7 @@ func RunCreateClient( if policiesJSON == "" { // Interactive mode - policies, err = promptForPolicies(io) + policies, err = ui.PromptForPolicies(io.Reader, io.Writer) if err != nil { return fmt.Errorf("failed to get policies: %w", err) } @@ -67,12 +85,12 @@ func RunCreateClient( return fmt.Errorf("failed to create client: %w", err) } - // Output result based on format - if format == "json" { - outputJSON(output, io.Writer) - } else { - outputText(output, io.Writer) + // Output result + result := &CreateClientResult{ + ID: output.ID.String(), + PlainSecret: output.PlainSecret, } + WriteOutput(io.Writer, format, result) logger.Info("client created successfully", slog.String("client_id", output.ID.String()), @@ -82,115 +100,3 @@ func RunCreateClient( return nil } - -// promptForPolicies interactively prompts the user to enter policy documents. -// Shows available capabilities and accepts multiple policies until user declines. -func promptForPolicies(io IOTuple) ([]authDomain.PolicyDocument, error) { - reader := bufio.NewReader(io.Reader) - writer := io.Writer - var policies []authDomain.PolicyDocument - - _, _ = fmt.Fprintln(writer, "\nEnter policies for the client") - _, _ = fmt.Fprintln(writer, "Available capabilities: read, write, delete, encrypt, decrypt, rotate") - _, _ = fmt.Fprintln(writer) - - policyNum := 1 - for { - _, _ = fmt.Fprintf(writer, "Policy #%d\n", policyNum) - - // Get path - _, _ = fmt.Fprint(writer, "Enter path pattern (e.g., 'secret/*' or '*'): ") - path, err := reader.ReadString('\n') - if err != nil { - return nil, fmt.Errorf("failed to read path: %w", err) - } - path = strings.TrimSpace(path) - - if path == "" { - return nil, fmt.Errorf("path cannot be empty") - } - - // Get capabilities - _, _ = fmt.Fprint(writer, "Enter capabilities (comma-separated, e.g., 'read,write'): ") - capsInput, err := reader.ReadString('\n') - if err != nil { - return nil, fmt.Errorf("failed to read capabilities: %w", err) - } - capsInput = strings.TrimSpace(capsInput) - - if capsInput == "" { - return nil, fmt.Errorf("capabilities cannot be empty") - } - - capabilities, err := parseCapabilities(capsInput) - if err != nil { - return nil, err - } - - // Add policy - policies = append(policies, authDomain.PolicyDocument{ - Path: path, - Capabilities: capabilities, - }) - - // Ask if user wants to add another - _, _ = fmt.Fprint(writer, "Add another policy? (y/n): ") - addAnother, err := reader.ReadString('\n') - if err != nil { - return nil, fmt.Errorf("failed to read input: %w", err) - } - addAnother = strings.ToLower(strings.TrimSpace(addAnother)) - - if addAnother != "y" && addAnother != "yes" { - break - } - - _, _ = fmt.Fprintln(writer) - policyNum++ - } - - return policies, nil -} - -// parseCapabilities converts a comma-separated string into a slice of Capability. -func parseCapabilities(input string) ([]authDomain.Capability, error) { - parts := strings.Split(input, ",") - capabilities := make([]authDomain.Capability, 0, len(parts)) - - for _, part := range parts { - cap := authDomain.Capability(strings.TrimSpace(part)) - if cap != "" { - capabilities = append(capabilities, cap) - } - } - - if len(capabilities) == 0 { - return nil, fmt.Errorf("at least one capability is required") - } - - return capabilities, nil -} - -// outputText outputs the result in human-readable text format. -func outputText(output *authDomain.CreateClientOutput, writer io.Writer) { - _, _ = fmt.Fprintln(writer, "\nClient created successfully!") - _, _ = fmt.Fprintf(writer, "Client ID: %s\n", output.ID.String()) - _, _ = fmt.Fprintf(writer, "Secret: %s\n", output.PlainSecret) - _, _ = fmt.Fprintln(writer, "\nIMPORTANT: The secret is shown only once. Store it securely.") -} - -// outputJSON outputs the result in JSON format for machine consumption. -func outputJSON(output *authDomain.CreateClientOutput, writer io.Writer) { - result := map[string]string{ - "client_id": output.ID.String(), - "secret": output.PlainSecret, - } - - jsonBytes, err := json.MarshalIndent(result, "", " ") - if err != nil { - _, _ = fmt.Fprintf(os.Stderr, "failed to marshal JSON: %v\n", err) - return - } - - _, _ = fmt.Fprintln(writer, string(jsonBytes)) -} diff --git a/cmd/app/commands/create_client_test.go b/cmd/app/commands/create_client_test.go index a9bb98a..97cc422 100644 --- a/cmd/app/commands/create_client_test.go +++ b/cmd/app/commands/create_client_test.go @@ -3,10 +3,12 @@ package commands import ( "bytes" "context" + "encoding/json" "log/slog" "testing" "github.com/google/uuid" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" authDomain "github.com/allisson/secrets/internal/auth/domain" @@ -16,94 +18,65 @@ import ( func TestRunCreateClient(t *testing.T) { ctx := context.Background() logger := slog.Default() + name := "test-client" + policies := []authDomain.PolicyDocument{ + { + Path: "*", + Capabilities: []authDomain.Capability{"read"}, + }, + } + policiesJSON, _ := json.Marshal(policies) clientID := uuid.New() - plainSecret := "test-secret" + plainSecret := "plain-secret" t.Run("non-interactive-text", func(t *testing.T) { mockUseCase := &authMocks.MockClientUseCase{} - input := &authDomain.CreateClientInput{ - Name: "test-client", - IsActive: true, - Policies: []authDomain.PolicyDocument{ - {Path: "*", Capabilities: []authDomain.Capability{"read"}}, - }, - } - output := &authDomain.CreateClientOutput{ + mockUseCase.On("Create", ctx, mock.Anything).Return(&authDomain.CreateClientOutput{ ID: clientID, PlainSecret: plainSecret, - } - - mockUseCase.On("Create", ctx, input).Return(output, nil) + }, nil) var out bytes.Buffer - io := IOTuple{ - Reader: nil, - Writer: &out, - } - - err := RunCreateClient( - ctx, - mockUseCase, - logger, - "test-client", - true, - `[{"path":"*","capabilities":["read"]}]`, - "text", - io, - ) + io := IOTuple{Writer: &out} + err := RunCreateClient(ctx, mockUseCase, logger, name, true, string(policiesJSON), "text", io) require.NoError(t, err) + require.Contains(t, out.String(), "Client created successfully!") require.Contains(t, out.String(), clientID.String()) require.Contains(t, out.String(), plainSecret) mockUseCase.AssertExpectations(t) }) - t.Run("interactive-json", func(t *testing.T) { + t.Run("non-interactive-json", func(t *testing.T) { mockUseCase := &authMocks.MockClientUseCase{} - input := &authDomain.CreateClientInput{ - Name: "test-client", - IsActive: true, - Policies: []authDomain.PolicyDocument{ - {Path: "secret/*", Capabilities: []authDomain.Capability{"read", "write"}}, - }, - } - output := &authDomain.CreateClientOutput{ + mockUseCase.On("Create", ctx, mock.Anything).Return(&authDomain.CreateClientOutput{ ID: clientID, PlainSecret: plainSecret, - } - - mockUseCase.On("Create", ctx, input).Return(output, nil) + }, nil) - // Simulate interactive input: - // 1. Path: secret/* - // 2. Caps: read,write - // 3. Add another: n - userInput := "secret/*\nread,write\nn\n" var out bytes.Buffer - io := IOTuple{ - Reader: bytes.NewBufferString(userInput), - Writer: &out, - } - - err := RunCreateClient(ctx, mockUseCase, logger, "test-client", true, "", "json", io) + io := IOTuple{Writer: &out} + err := RunCreateClient(ctx, mockUseCase, logger, name, true, string(policiesJSON), "json", io) require.NoError(t, err) require.Contains(t, out.String(), clientID.String()) require.Contains(t, out.String(), plainSecret) - require.Contains(t, out.String(), "{") // Should be JSON mockUseCase.AssertExpectations(t) }) t.Run("invalid-policies-json", func(t *testing.T) { mockUseCase := &authMocks.MockClientUseCase{} - io := IOTuple{ - Reader: nil, - Writer: &bytes.Buffer{}, - } - - err := RunCreateClient(ctx, mockUseCase, logger, "test-client", true, `invalid-json`, "text", io) + err := RunCreateClient(ctx, mockUseCase, logger, name, true, "invalid-json", "text", IOTuple{}) require.Error(t, err) require.Contains(t, err.Error(), "failed to parse policies JSON") }) + + t.Run("empty-policies", func(t *testing.T) { + mockUseCase := &authMocks.MockClientUseCase{} + err := RunCreateClient(ctx, mockUseCase, logger, name, true, "[]", "text", IOTuple{}) + + require.Error(t, err) + require.Contains(t, err.Error(), "at least one policy is required") + }) } diff --git a/cmd/app/commands/create_kek.go b/cmd/app/commands/create_kek.go index 26e057a..ea3da48 100644 --- a/cmd/app/commands/create_kek.go +++ b/cmd/app/commands/create_kek.go @@ -9,11 +9,8 @@ import ( cryptoUseCase "github.com/allisson/secrets/internal/crypto/usecase" ) -// RunCreateKek creates a new Key Encryption Key using the specified algorithm. -// Should only be run once during initial system setup. The KEK is encrypted using -// the active master key from MASTER_KEYS environment variable. -// -// Requirements: Database must be migrated, MASTER_KEYS and ACTIVE_MASTER_KEY_ID must be set. +// RunCreateKek creates a new Key Encryption Key (KEK) and encrypts it with the master key. +// The new KEK will be stored in the database and marked as active for its algorithm. func RunCreateKek( ctx context.Context, kekUseCase cryptoUseCase.KekUseCase, @@ -24,15 +21,11 @@ func RunCreateKek( logger.Info("creating new KEK", slog.String("algorithm", algorithmStr)) // Parse algorithm - algorithm, err := parseAlgorithm(algorithmStr) + algorithm, err := ParseAlgorithm(algorithmStr) if err != nil { return err } - logger.Info("master key chain loaded", - slog.String("active_master_key_id", masterKeyChain.ActiveMasterKeyID()), - ) - // Create the KEK if err := kekUseCase.Create(ctx, masterKeyChain, algorithm); err != nil { return fmt.Errorf("failed to create KEK: %w", err) @@ -45,19 +38,3 @@ func RunCreateKek( return nil } - -// parseAlgorithm converts algorithm string to cryptoDomain.Algorithm type. -// Returns an error if the algorithm string is invalid. -func parseAlgorithm(algorithmStr string) (cryptoDomain.Algorithm, error) { - switch algorithmStr { - case "aes-gcm": - return cryptoDomain.AESGCM, nil - case "chacha20-poly1305": - return cryptoDomain.ChaCha20, nil - default: - return "", fmt.Errorf( - "invalid algorithm: %s (valid options: aes-gcm, chacha20-poly1305)", - algorithmStr, - ) - } -} diff --git a/cmd/app/commands/create_kek_test.go b/cmd/app/commands/create_kek_test.go index f1db1ce..1ddba39 100644 --- a/cmd/app/commands/create_kek_test.go +++ b/cmd/app/commands/create_kek_test.go @@ -3,7 +3,6 @@ package commands import ( "context" "log/slog" - "os" "testing" "github.com/stretchr/testify/require" @@ -14,14 +13,15 @@ import ( func TestRunCreateKek(t *testing.T) { ctx := context.Background() - logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) - masterKeyChain := cryptoDomain.NewMasterKeyChain("test-master-key") + logger := slog.Default() + masterKeyChain := cryptoDomain.NewMasterKeyChain("key1") t.Run("success", func(t *testing.T) { mockUseCase := &cryptoMocks.MockKekUseCase{} mockUseCase.On("Create", ctx, masterKeyChain, cryptoDomain.AESGCM).Return(nil) err := RunCreateKek(ctx, mockUseCase, masterKeyChain, logger, "aes-gcm") + require.NoError(t, err) mockUseCase.AssertExpectations(t) }) @@ -29,6 +29,7 @@ func TestRunCreateKek(t *testing.T) { t.Run("invalid-algorithm", func(t *testing.T) { mockUseCase := &cryptoMocks.MockKekUseCase{} err := RunCreateKek(ctx, mockUseCase, masterKeyChain, logger, "invalid") + require.Error(t, err) require.Contains(t, err.Error(), "invalid algorithm") }) diff --git a/cmd/app/commands/create_tokenization_key.go b/cmd/app/commands/create_tokenization_key.go index 65e542c..503f60b 100644 --- a/cmd/app/commands/create_tokenization_key.go +++ b/cmd/app/commands/create_tokenization_key.go @@ -10,32 +10,30 @@ import ( // RunCreateTokenizationKey creates a new tokenization key with the specified parameters. // Should be run during initial setup or when adding new tokenization formats. -// -// Requirements: Database must be migrated, MASTER_KEYS and ACTIVE_MASTER_KEY_ID must be set. func RunCreateTokenizationKey( ctx context.Context, tokenizationKeyUseCase tokenizationUseCase.TokenizationKeyUseCase, logger *slog.Logger, name string, - formatType string, + formatTypeStr string, isDeterministic bool, algorithmStr string, ) error { logger.Info("creating new tokenization key", slog.String("name", name), - slog.String("format_type", formatType), + slog.String("format_type", formatTypeStr), slog.Bool("is_deterministic", isDeterministic), slog.String("algorithm", algorithmStr), ) // Parse format type - format, err := parseFormatType(formatType) + format, err := ParseFormatType(formatTypeStr) if err != nil { return err } // Parse algorithm - algorithm, err := parseAlgorithm(algorithmStr) + algorithm, err := ParseAlgorithm(algorithmStr) if err != nil { return err } diff --git a/cmd/app/commands/create_tokenization_key_test.go b/cmd/app/commands/create_tokenization_key_test.go index 48c6c77..eb39d24 100644 --- a/cmd/app/commands/create_tokenization_key_test.go +++ b/cmd/app/commands/create_tokenization_key_test.go @@ -16,27 +16,25 @@ import ( func TestRunCreateTokenizationKey(t *testing.T) { ctx := context.Background() logger := slog.Default() + name := "test-token-key" t.Run("success", func(t *testing.T) { mockUseCase := &tokenizationMocks.MockTokenizationKeyUseCase{} - expectedKey := &tokenizationDomain.TokenizationKey{ - ID: uuid.New(), - Name: "test-token", - FormatType: tokenizationDomain.FormatUUID, - IsDeterministic: true, - Version: 1, - } - mockUseCase.On("Create", ctx, "test-token", tokenizationDomain.FormatUUID, true, cryptoDomain.AESGCM). - Return(expectedKey, nil) - - err := RunCreateTokenizationKey(ctx, mockUseCase, logger, "test-token", "uuid", true, "aes-gcm") + mockUseCase.On("Create", ctx, name, tokenizationDomain.FormatUUID, false, cryptoDomain.AESGCM). + Return(&tokenizationDomain.TokenizationKey{ + ID: uuid.New(), + }, nil) + + err := RunCreateTokenizationKey(ctx, mockUseCase, logger, name, "uuid", false, "aes-gcm") + require.NoError(t, err) mockUseCase.AssertExpectations(t) }) t.Run("invalid-format", func(t *testing.T) { mockUseCase := &tokenizationMocks.MockTokenizationKeyUseCase{} - err := RunCreateTokenizationKey(ctx, mockUseCase, logger, "test", "invalid", true, "aes-gcm") + err := RunCreateTokenizationKey(ctx, mockUseCase, logger, name, "invalid", false, "aes-gcm") + require.Error(t, err) require.Contains(t, err.Error(), "invalid format type") }) diff --git a/cmd/app/commands/helpers.go b/cmd/app/commands/helpers.go index 15d44c5..816ecbe 100644 --- a/cmd/app/commands/helpers.go +++ b/cmd/app/commands/helpers.go @@ -4,14 +4,18 @@ package commands import ( "context" "fmt" + "io" "log/slog" + "os" "github.com/golang-migrate/migrate/v4" - - "io" - "os" + _ "github.com/golang-migrate/migrate/v4/database/mysql" + _ "github.com/golang-migrate/migrate/v4/database/postgres" + _ "github.com/golang-migrate/migrate/v4/source/file" "github.com/allisson/secrets/internal/app" + "github.com/allisson/secrets/internal/config" + cryptoDomain "github.com/allisson/secrets/internal/crypto/domain" tokenizationDomain "github.com/allisson/secrets/internal/tokenization/domain" ) @@ -29,15 +33,47 @@ func DefaultIO() IOTuple { } } -// closeContainer closes all resources in the container and logs any errors. -func closeContainer(container *app.Container, logger *slog.Logger) { +// Formatter defines the interface for data that can be output in multiple formats. +type Formatter interface { + ToText() string + ToJSON() string +} + +// WriteOutput writes the formatted data to the provided writer based on the specified format. +func WriteOutput(writer io.Writer, format string, data Formatter) { + if format == "json" { + _, _ = fmt.Fprintln(writer, data.ToJSON()) + } else { + _, _ = fmt.Fprintln(writer, data.ToText()) + } +} + +// ExecuteWithContainer encapsulates the standard CLI command execution pattern: +// loading configuration, initializing the DI container, and ensuring graceful shutdown. +func ExecuteWithContainer( + ctx context.Context, + fn func(ctx context.Context, container *app.Container) error, +) error { + cfg := config.Load() + container := app.NewContainer(cfg) + defer func() { + if err := container.Shutdown(ctx); err != nil { + container.Logger().Error("failed to shutdown container", slog.Any("error", err)) + } + }() + + return fn(ctx, container) +} + +// CloseContainer closes all resources in the container and logs any errors. +func CloseContainer(container *app.Container, logger *slog.Logger) { if err := container.Shutdown(context.Background()); err != nil { logger.Error("failed to shutdown container", slog.Any("error", err)) } } -// closeMigrate closes the migration instance and logs any errors. -func closeMigrate(migrate *migrate.Migrate, logger *slog.Logger) { +// CloseMigrate closes the migration instance and logs any errors. +func CloseMigrate(migrate *migrate.Migrate, logger *slog.Logger) { sourceError, databaseError := migrate.Close() if sourceError != nil || databaseError != nil { logger.Error( @@ -48,9 +84,23 @@ func closeMigrate(migrate *migrate.Migrate, logger *slog.Logger) { } } -// parseFormatType converts format type string to tokenizationDomain.FormatType. -// Returns an error if the format type string is invalid. -func parseFormatType(formatType string) (tokenizationDomain.FormatType, error) { +// ParseAlgorithm converts algorithm string to cryptoDomain.Algorithm type. +func ParseAlgorithm(algorithmStr string) (cryptoDomain.Algorithm, error) { + switch algorithmStr { + case "aes-gcm": + return cryptoDomain.AESGCM, nil + case "chacha20-poly1305": + return cryptoDomain.ChaCha20, nil + default: + return "", fmt.Errorf( + "invalid algorithm: %s (valid options: aes-gcm, chacha20-poly1305)", + algorithmStr, + ) + } +} + +// ParseFormatType converts format type string to tokenizationDomain.FormatType. +func ParseFormatType(formatType string) (tokenizationDomain.FormatType, error) { switch formatType { case "uuid": return tokenizationDomain.FormatUUID, nil diff --git a/cmd/app/commands/helpers_test.go b/cmd/app/commands/helpers_test.go new file mode 100644 index 0000000..f4a3f3b --- /dev/null +++ b/cmd/app/commands/helpers_test.go @@ -0,0 +1,88 @@ +package commands + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/require" + + cryptoDomain "github.com/allisson/secrets/internal/crypto/domain" + tokenizationDomain "github.com/allisson/secrets/internal/tokenization/domain" +) + +type mockFormatter struct { + text string + json string +} + +func (m *mockFormatter) ToText() string { return m.text } +func (m *mockFormatter) ToJSON() string { return m.json } + +func TestWriteOutput(t *testing.T) { + data := &mockFormatter{ + text: "text output", + json: `{"output": "json"}`, + } + + t.Run("text", func(t *testing.T) { + var out bytes.Buffer + WriteOutput(&out, "text", data) + require.Equal(t, "text output\n", out.String()) + }) + + t.Run("json", func(t *testing.T) { + var out bytes.Buffer + WriteOutput(&out, "json", data) + require.Equal(t, "{\"output\": \"json\"}\n", out.String()) + }) +} + +func TestParseAlgorithm(t *testing.T) { + tests := []struct { + input string + expected cryptoDomain.Algorithm + wantErr bool + }{ + {"aes-gcm", cryptoDomain.AESGCM, false}, + {"chacha20-poly1305", cryptoDomain.ChaCha20, false}, + {"invalid", "", true}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + got, err := ParseAlgorithm(tt.input) + if tt.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + require.Equal(t, tt.expected, got) + } + }) + } +} + +func TestParseFormatType(t *testing.T) { + tests := []struct { + input string + expected tokenizationDomain.FormatType + wantErr bool + }{ + {"uuid", tokenizationDomain.FormatUUID, false}, + {"numeric", tokenizationDomain.FormatNumeric, false}, + {"luhn-preserving", tokenizationDomain.FormatLuhnPreserving, false}, + {"alphanumeric", tokenizationDomain.FormatAlphanumeric, false}, + {"invalid", "", true}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + got, err := ParseFormatType(tt.input) + if tt.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + require.Equal(t, tt.expected, got) + } + }) + } +} diff --git a/cmd/app/commands/migrations.go b/cmd/app/commands/migrations.go index 7f9ca19..42ff694 100644 --- a/cmd/app/commands/migrations.go +++ b/cmd/app/commands/migrations.go @@ -29,7 +29,7 @@ func RunMigrations(logger *slog.Logger, dbDriver, dbConnectionString string) err if err != nil { return fmt.Errorf("failed to create migrate instance: %w", err) } - defer closeMigrate(m, logger) + defer CloseMigrate(m, logger) if err := m.Up(); err != nil && !errors.Is(err, migrate.ErrNoChange) { return fmt.Errorf("failed to run migrations: %w", err) diff --git a/cmd/app/commands/rotate_kek.go b/cmd/app/commands/rotate_kek.go index 4e75a9e..8175d17 100644 --- a/cmd/app/commands/rotate_kek.go +++ b/cmd/app/commands/rotate_kek.go @@ -9,6 +9,9 @@ import ( cryptoUseCase "github.com/allisson/secrets/internal/crypto/usecase" ) +// RunRotateKek rotates the Key Encryption Key (KEK) for a specific algorithm. +// Generates a new KEK version and marks it as active. Existing secrets encrypted +// with old KEKs remain valid until rewrapped. func RunRotateKek( ctx context.Context, kekUseCase cryptoUseCase.KekUseCase, @@ -19,15 +22,11 @@ func RunRotateKek( logger.Info("rotating KEK", slog.String("algorithm", algorithmStr)) // Parse algorithm - algorithm, err := parseAlgorithm(algorithmStr) + algorithm, err := ParseAlgorithm(algorithmStr) if err != nil { return err } - logger.Info("master key chain loaded", - slog.String("active_master_key_id", masterKeyChain.ActiveMasterKeyID()), - ) - // Rotate the KEK if err := kekUseCase.Rotate(ctx, masterKeyChain, algorithm); err != nil { return fmt.Errorf("failed to rotate KEK: %w", err) diff --git a/cmd/app/commands/rotate_kek_test.go b/cmd/app/commands/rotate_kek_test.go index d45b123..aebb527 100644 --- a/cmd/app/commands/rotate_kek_test.go +++ b/cmd/app/commands/rotate_kek_test.go @@ -16,7 +16,7 @@ func TestRunRotateKek(t *testing.T) { logger := slog.Default() masterKeyChain := cryptoDomain.NewMasterKeyChain("key1") - t.Run("success-aes-gcm", func(t *testing.T) { + t.Run("success", func(t *testing.T) { mockUseCase := &cryptoMocks.MockKekUseCase{} mockUseCase.On("Rotate", ctx, masterKeyChain, cryptoDomain.AESGCM).Return(nil) @@ -26,16 +26,6 @@ func TestRunRotateKek(t *testing.T) { mockUseCase.AssertExpectations(t) }) - t.Run("success-chacha20", func(t *testing.T) { - mockUseCase := &cryptoMocks.MockKekUseCase{} - mockUseCase.On("Rotate", ctx, masterKeyChain, cryptoDomain.ChaCha20).Return(nil) - - err := RunRotateKek(ctx, mockUseCase, masterKeyChain, logger, "chacha20-poly1305") - - require.NoError(t, err) - mockUseCase.AssertExpectations(t) - }) - t.Run("invalid-algorithm", func(t *testing.T) { mockUseCase := &cryptoMocks.MockKekUseCase{} err := RunRotateKek(ctx, mockUseCase, masterKeyChain, logger, "invalid") diff --git a/cmd/app/commands/rotate_master_key.go b/cmd/app/commands/rotate_master_key.go index b33875b..d419c60 100644 --- a/cmd/app/commands/rotate_master_key.go +++ b/cmd/app/commands/rotate_master_key.go @@ -7,6 +7,7 @@ import ( "fmt" "io" "log/slog" + "strings" "time" cryptoService "github.com/allisson/secrets/internal/crypto/service" @@ -34,6 +35,11 @@ func RunRotateMasterKey( return fmt.Errorf("ACTIVE_MASTER_KEY_ID is not set") } + // Ensure active key ID exists in the master keys string + if !strings.Contains(existingMasterKeys, existingActiveKeyID+":") { + return fmt.Errorf("ACTIVE_MASTER_KEY_ID '%s' not found in MASTER_KEYS", existingActiveKeyID) + } + // Generate default key ID if not provided if keyID == "" { keyID = fmt.Sprintf("master-key-%s", time.Now().Format("2006-01-02")) diff --git a/cmd/app/commands/rotate_master_key_test.go b/cmd/app/commands/rotate_master_key_test.go index 1919225..c2add3c 100644 --- a/cmd/app/commands/rotate_master_key_test.go +++ b/cmd/app/commands/rotate_master_key_test.go @@ -3,67 +3,27 @@ package commands import ( "bytes" "context" - "errors" "log/slog" "testing" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" - - cryptoDomain "github.com/allisson/secrets/internal/crypto/domain" ) -type mockKMSKeeper struct { - mock.Mock -} - -func (m *mockKMSKeeper) Decrypt(ctx context.Context, ciphertext []byte) ([]byte, error) { - args := m.Called(ctx, ciphertext) - if args.Get(0) == nil { - return nil, args.Error(1) - } - return args.Get(0).([]byte), args.Error(1) -} - -func (m *mockKMSKeeper) Encrypt(ctx context.Context, plaintext []byte) ([]byte, error) { - args := m.Called(ctx, plaintext) - if args.Get(0) == nil { - return nil, args.Error(1) - } - return args.Get(0).([]byte), args.Error(1) -} - -func (m *mockKMSKeeper) Close() error { - args := m.Called() - return args.Error(0) -} - -type mockKMSService struct { - mock.Mock -} - -func (m *mockKMSService) OpenKeeper(ctx context.Context, keyURI string) (cryptoDomain.KMSKeeper, error) { - args := m.Called(ctx, keyURI) - if args.Get(0) == nil { - return nil, args.Error(1) - } - return args.Get(0).(cryptoDomain.KMSKeeper), args.Error(1) -} - func TestRunRotateMasterKey(t *testing.T) { ctx := context.Background() logger := slog.Default() kmsProvider := "localsecrets" kmsKeyURI := "base64key://YWJjZGVmZ2hpamtsbW5vcHFyc3R1dnd4eXoxMjM0NTY=" - existingMasterKeys := "old-key:YWJjZGVmZ2hpamtsbW5vcA==" + existingMasterKeys := "old-key:ciphertext" existingActiveKeyID := "old-key" t.Run("success", func(t *testing.T) { - mockKMSService := &mockKMSService{} - mockKeeper := &mockKMSKeeper{} + mockKMSService := &MockKMSService{} + mockKeeper := &MockKMSKeeper{} mockKMSService.On("OpenKeeper", ctx, kmsKeyURI).Return(mockKeeper, nil) - mockKeeper.On("Encrypt", ctx, mock.AnythingOfType("[]uint8")).Return([]byte("encrypted-key"), nil) + mockKeeper.On("Encrypt", ctx, mock.Anything).Return([]byte("new-ciphertext"), nil) mockKeeper.On("Close").Return(nil) var out bytes.Buffer @@ -80,43 +40,53 @@ func TestRunRotateMasterKey(t *testing.T) { ) require.NoError(t, err) - require.Contains(t, out.String(), "KMS_PROVIDER=\"localsecrets\"") - require.Contains( - t, - out.String(), - "MASTER_KEYS=\"old-key:YWJjZGVmZ2hpamtsbW5vcA==,new-key:ZW5jcnlwdGVkLWtleQ==\"", - ) + require.Contains(t, out.String(), "MASTER_KEYS=\"old-key:ciphertext,new-key:bmV3LWNpcGhlcnRleHQ=\"") require.Contains(t, out.String(), "ACTIVE_MASTER_KEY_ID=\"new-key\"") - mockKMSService.AssertExpectations(t) mockKeeper.AssertExpectations(t) }) - t.Run("kms-open-error", func(t *testing.T) { - mockKMSService := &mockKMSService{} - mockKMSService.On("OpenKeeper", ctx, kmsKeyURI).Return(nil, errors.New("kms error")) + t.Run("missing-kms-params", func(t *testing.T) { + mockKMSService := &MockKMSService{} + err := RunRotateMasterKey(ctx, mockKMSService, logger, &bytes.Buffer{}, "new-key", "", "", "", "") - var out bytes.Buffer + require.Error(t, err) + require.Contains(t, err.Error(), "KMS_PROVIDER and KMS_KEY_URI are required") + }) + + t.Run("missing-existing-keys", func(t *testing.T) { + mockKMSService := &MockKMSService{} err := RunRotateMasterKey( ctx, mockKMSService, logger, - &out, + &bytes.Buffer{}, "new-key", kmsProvider, kmsKeyURI, - existingMasterKeys, - existingActiveKeyID, + "", + "", ) require.Error(t, err) - require.Contains(t, err.Error(), "kms error") + require.Contains(t, err.Error(), "MASTER_KEYS is not set") }) - t.Run("missing-kms-params", func(t *testing.T) { - mockKMSService := &mockKMSService{} - err := RunRotateMasterKey(ctx, mockKMSService, logger, &bytes.Buffer{}, "new-key", "", "", "", "") + t.Run("invalid-active-key-id", func(t *testing.T) { + mockKMSService := &MockKMSService{} + err := RunRotateMasterKey( + ctx, + mockKMSService, + logger, + &bytes.Buffer{}, + "new-key", + kmsProvider, + kmsKeyURI, + existingMasterKeys, + "invalid-key", + ) + require.Error(t, err) - require.Contains(t, err.Error(), "KMS_PROVIDER and KMS_KEY_URI are required") + require.Contains(t, err.Error(), "not found in MASTER_KEYS") }) } diff --git a/cmd/app/commands/rotate_tokenization_key.go b/cmd/app/commands/rotate_tokenization_key.go index b6ff71b..c37e0ff 100644 --- a/cmd/app/commands/rotate_tokenization_key.go +++ b/cmd/app/commands/rotate_tokenization_key.go @@ -8,35 +8,32 @@ import ( tokenizationUseCase "github.com/allisson/secrets/internal/tokenization/usecase" ) -// RunRotateTokenizationKey creates a new version of an existing tokenization key. -// Increments the version number and generates a new DEK while preserving old versions -// for detokenization of previously issued tokens. -// -// Requirements: Database must be migrated, named tokenization key must exist. +// RunRotateTokenizationKey rotates an existing tokenization key to a new version. +// Updates format and deterministic settings. Existing tokens remain valid until rotated. func RunRotateTokenizationKey( ctx context.Context, tokenizationKeyUseCase tokenizationUseCase.TokenizationKeyUseCase, logger *slog.Logger, name string, - formatType string, + formatTypeStr string, isDeterministic bool, algorithmStr string, ) error { logger.Info("rotating tokenization key", slog.String("name", name), - slog.String("format_type", formatType), + slog.String("format_type", formatTypeStr), slog.Bool("is_deterministic", isDeterministic), slog.String("algorithm", algorithmStr), ) // Parse format type - format, err := parseFormatType(formatType) + format, err := ParseFormatType(formatTypeStr) if err != nil { return err } // Parse algorithm - algorithm, err := parseAlgorithm(algorithmStr) + algorithm, err := ParseAlgorithm(algorithmStr) if err != nil { return err } diff --git a/cmd/app/commands/rotate_tokenization_key_test.go b/cmd/app/commands/rotate_tokenization_key_test.go index 6c8fdcc..cb37cd3 100644 --- a/cmd/app/commands/rotate_tokenization_key_test.go +++ b/cmd/app/commands/rotate_tokenization_key_test.go @@ -16,21 +16,26 @@ import ( func TestRunRotateTokenizationKey(t *testing.T) { ctx := context.Background() logger := slog.Default() + name := "test-token-key" t.Run("success", func(t *testing.T) { mockUseCase := &tokenizationMocks.MockTokenizationKeyUseCase{} - expectedKey := &tokenizationDomain.TokenizationKey{ - ID: uuid.New(), - Name: "test-token", - FormatType: tokenizationDomain.FormatUUID, - IsDeterministic: true, - Version: 2, - } - mockUseCase.On("Rotate", ctx, "test-token", tokenizationDomain.FormatUUID, true, cryptoDomain.AESGCM). - Return(expectedKey, nil) - - err := RunRotateTokenizationKey(ctx, mockUseCase, logger, "test-token", "uuid", true, "aes-gcm") + mockUseCase.On("Rotate", ctx, name, tokenizationDomain.FormatUUID, false, cryptoDomain.AESGCM). + Return(&tokenizationDomain.TokenizationKey{ + ID: uuid.New(), + }, nil) + + err := RunRotateTokenizationKey(ctx, mockUseCase, logger, name, "uuid", false, "aes-gcm") + require.NoError(t, err) mockUseCase.AssertExpectations(t) }) + + t.Run("invalid-format", func(t *testing.T) { + mockUseCase := &tokenizationMocks.MockTokenizationKeyUseCase{} + err := RunRotateTokenizationKey(ctx, mockUseCase, logger, name, "invalid", false, "aes-gcm") + + require.Error(t, err) + require.Contains(t, err.Error(), "invalid format type") + }) } diff --git a/cmd/app/commands/server.go b/cmd/app/commands/server.go index 59ea16c..bea3b4f 100644 --- a/cmd/app/commands/server.go +++ b/cmd/app/commands/server.go @@ -34,7 +34,7 @@ func RunServer(ctx context.Context, version string) error { logger.Info("starting server", slog.String("version", version)) // Ensure cleanup on exit - defer closeContainer(container, logger) + defer CloseContainer(container, logger) // Get HTTP server from container (this initializes all dependencies) server, err := container.HTTPServer() @@ -72,7 +72,7 @@ func RunServer(ctx context.Context, version string) error { select { case <-ctx.Done(): logger.Info("shutdown signal received") - shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), cfg.DBConnMaxLifetime) + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), cfg.ServerShutdownTimeout) defer shutdownCancel() var shutdownErrors []error @@ -93,7 +93,7 @@ func RunServer(ctx context.Context, version string) error { case err := <-serverErr: // Attempt graceful shutdown if one server fails logger.Error("server error, initiating shutdown", slog.Any("error", err)) - shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), cfg.DBConnMaxLifetime) + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), cfg.ServerShutdownTimeout) defer shutdownCancel() var shutdownErrors []error diff --git a/cmd/app/commands/update_client.go b/cmd/app/commands/update_client.go index 79d2f43..7f0caa7 100644 --- a/cmd/app/commands/update_client.go +++ b/cmd/app/commands/update_client.go @@ -1,26 +1,43 @@ package commands import ( - "bufio" "context" "encoding/json" "fmt" - "io" "log/slog" - "strings" "github.com/google/uuid" authDomain "github.com/allisson/secrets/internal/auth/domain" authUseCase "github.com/allisson/secrets/internal/auth/usecase" + "github.com/allisson/secrets/internal/ui" ) +// UpdateClientResult holds the result of the client update operation. +type UpdateClientResult struct { + ID string `json:"client_id"` + Name string `json:"name"` + IsActive bool `json:"is_active"` +} + +// ToText returns a human-readable representation of the update result. +func (r *UpdateClientResult) ToText() string { + output := "\nClient updated successfully!\n" + output += fmt.Sprintf("Client ID: %s\n", r.ID) + output += fmt.Sprintf("Name: %s\n", r.Name) + output += fmt.Sprintf("Active: %t", r.IsActive) + return output +} + +// ToJSON returns a JSON representation of the update result. +func (r *UpdateClientResult) ToJSON() string { + jsonBytes, _ := json.MarshalIndent(r, "", " ") + return string(jsonBytes) +} + // RunUpdateClient updates an existing authentication client's configuration. // Supports both interactive mode (when policiesJSON is empty) and non-interactive -// mode (when policiesJSON is provided). Only Name, IsActive, and Policies can be -// updated. The client ID and secret remain unchanged. -// -// Requirements: Database must be migrated and the client must exist. +// mode (when policiesJSON is provided). func RunUpdateClient( ctx context.Context, clientUseCase authUseCase.ClientUseCase, @@ -51,7 +68,7 @@ func RunUpdateClient( if policiesJSON == "" { // Interactive mode - show current policies and prompt for new ones - policies, err = promptForPoliciesUpdate(io, existingClient.Policies) + policies, err = ui.PromptForPoliciesUpdate(io.Reader, io.Writer, existingClient.Policies) if err != nil { return fmt.Errorf("failed to get policies: %w", err) } @@ -79,12 +96,13 @@ func RunUpdateClient( return fmt.Errorf("failed to update client: %w", err) } - // Output result based on format - if format == "json" { - outputUpdateJSON(io.Writer, clientID, name, isActive) - } else { - outputUpdateText(io.Writer, clientID, name, isActive) + // Output result + result := &UpdateClientResult{ + ID: clientID.String(), + Name: name, + IsActive: isActive, } + WriteOutput(io.Writer, format, result) logger.Info("client updated successfully", slog.String("client_id", clientID.String()), @@ -94,113 +112,3 @@ func RunUpdateClient( return nil } - -// promptForPoliciesUpdate interactively prompts the user to enter policy documents. -// Shows current policies and available capabilities. Accepts multiple policies until user declines. -func promptForPoliciesUpdate( - io IOTuple, - currentPolicies []authDomain.PolicyDocument, -) ([]authDomain.PolicyDocument, error) { - reader := bufio.NewReader(io.Reader) - var policies []authDomain.PolicyDocument - - _, _ = fmt.Fprintln(io.Writer, "\nCurrent policies:") - for i, policy := range currentPolicies { - capsStr := make([]string, len(policy.Capabilities)) - for j, cap := range policy.Capabilities { - capsStr[j] = string(cap) - } - _, _ = fmt.Fprintf( - io.Writer, - " %d. Path: %s, Capabilities: [%s]\n", - i+1, - policy.Path, - strings.Join(capsStr, ", "), - ) - } - - _, _ = fmt.Fprintln(io.Writer, "\nEnter new policies for the client") - _, _ = fmt.Fprintln(io.Writer, "Available capabilities: read, write, delete, encrypt, decrypt, rotate") - _, _ = fmt.Fprintln(io.Writer) - - policyNum := 1 - for { - _, _ = fmt.Fprintf(io.Writer, "Policy #%d\n", policyNum) - - // Get path - _, _ = fmt.Fprint(io.Writer, "Enter path pattern (e.g., 'secret/*' or '*'): ") - path, err := reader.ReadString('\n') - if err != nil { - return nil, fmt.Errorf("failed to read path: %w", err) - } - path = strings.TrimSpace(path) - - if path == "" { - return nil, fmt.Errorf("path cannot be empty") - } - - // Get capabilities - _, _ = fmt.Fprint(io.Writer, "Enter capabilities (comma-separated, e.g., 'read,write'): ") - capsInput, err := reader.ReadString('\n') - if err != nil { - return nil, fmt.Errorf("failed to read capabilities: %w", err) - } - capsInput = strings.TrimSpace(capsInput) - - if capsInput == "" { - return nil, fmt.Errorf("capabilities cannot be empty") - } - - capabilities, err := parseCapabilities(capsInput) - if err != nil { - return nil, err - } - - // Add policy - policies = append(policies, authDomain.PolicyDocument{ - Path: path, - Capabilities: capabilities, - }) - - // Ask if user wants to add another - _, _ = fmt.Fprint(io.Writer, "Add another policy? (y/n): ") - addAnother, err := reader.ReadString('\n') - if err != nil { - return nil, fmt.Errorf("failed to read input: %w", err) - } - addAnother = strings.ToLower(strings.TrimSpace(addAnother)) - - if addAnother != "y" && addAnother != "yes" { - break - } - - _, _ = fmt.Fprintln(io.Writer) - policyNum++ - } - - return policies, nil -} - -// outputUpdateText outputs the result in human-readable text format. -func outputUpdateText(writer io.Writer, clientID uuid.UUID, name string, isActive bool) { - _, _ = fmt.Fprintln(writer, "\nClient updated successfully!") - _, _ = fmt.Fprintf(writer, "Client ID: %s\n", clientID.String()) - _, _ = fmt.Fprintf(writer, "Name: %s\n", name) - _, _ = fmt.Fprintf(writer, "Active: %t\n", isActive) -} - -// outputUpdateJSON outputs the result in JSON format for machine consumption. -func outputUpdateJSON(writer io.Writer, clientID uuid.UUID, name string, isActive bool) { - result := map[string]interface{}{ - "client_id": clientID.String(), - "name": name, - "is_active": isActive, - } - - jsonBytes, err := json.MarshalIndent(result, "", " ") - if err != nil { - return - } - - _, _ = fmt.Fprintln(writer, string(jsonBytes)) -} diff --git a/cmd/app/commands/update_client_test.go b/cmd/app/commands/update_client_test.go index e6efe4e..b437bcf 100644 --- a/cmd/app/commands/update_client_test.go +++ b/cmd/app/commands/update_client_test.go @@ -4,7 +4,7 @@ import ( "bytes" "context" "encoding/json" - "errors" + "fmt" "log/slog" "testing" @@ -20,119 +20,101 @@ func TestRunUpdateClient(t *testing.T) { ctx := context.Background() logger := slog.Default() clientID := uuid.New() - clientIDStr := clientID.String() - - existingClient := &authDomain.Client{ - ID: clientID, - Name: "old-name", - IsActive: true, - Policies: []authDomain.PolicyDocument{ - {Path: "secret/*", Capabilities: []authDomain.Capability{authDomain.ReadCapability}}, + name := "updated-client" + policies := []authDomain.PolicyDocument{ + { + Path: "*", + Capabilities: []authDomain.Capability{"read", "write"}, }, } + policiesJSON, _ := json.Marshal(policies) - t.Run("success-non-interactive-text", func(t *testing.T) { + t.Run("non-interactive-text", func(t *testing.T) { mockUseCase := &authMocks.MockClientUseCase{} - mockUseCase.On("Get", ctx, clientID).Return(existingClient, nil) - mockUseCase.On("Update", ctx, clientID, mock.AnythingOfType("*domain.UpdateClientInput")).Return(nil) + mockUseCase.On("Get", ctx, clientID).Return(&authDomain.Client{ID: clientID}, nil) + mockUseCase.On("Update", ctx, clientID, mock.Anything).Return(nil) var out bytes.Buffer - io := IOTuple{Reader: &bytes.Buffer{}, Writer: &out} - - policiesJSON := `[{"path": "secret/*", "capabilities": ["read", "write"]}]` + io := IOTuple{Writer: &out} err := RunUpdateClient( ctx, mockUseCase, logger, io, - clientIDStr, - "new-name", + clientID.String(), + name, true, - policiesJSON, + string(policiesJSON), "text", ) require.NoError(t, err) require.Contains(t, out.String(), "Client updated successfully!") - require.Contains(t, out.String(), "Name: new-name") - + require.Contains(t, out.String(), clientID.String()) + require.Contains(t, out.String(), name) mockUseCase.AssertExpectations(t) }) - t.Run("success-non-interactive-json", func(t *testing.T) { + t.Run("non-interactive-json", func(t *testing.T) { mockUseCase := &authMocks.MockClientUseCase{} - mockUseCase.On("Get", ctx, clientID).Return(existingClient, nil) - mockUseCase.On("Update", ctx, clientID, mock.AnythingOfType("*domain.UpdateClientInput")).Return(nil) + mockUseCase.On("Get", ctx, clientID).Return(&authDomain.Client{ID: clientID}, nil) + mockUseCase.On("Update", ctx, clientID, mock.Anything).Return(nil) var out bytes.Buffer - io := IOTuple{Reader: &bytes.Buffer{}, Writer: &out} - - policiesJSON := `[{"path": "secret/*", "capabilities": ["read", "write"]}]` + io := IOTuple{Writer: &out} err := RunUpdateClient( ctx, mockUseCase, logger, io, - clientIDStr, - "new-name", + clientID.String(), + name, true, - policiesJSON, + string(policiesJSON), "json", ) require.NoError(t, err) - - var result map[string]interface{} - err = json.Unmarshal(out.Bytes(), &result) - require.NoError(t, err) - require.Equal(t, clientIDStr, result["client_id"]) - require.Equal(t, "new-name", result["name"]) - require.Equal(t, true, result["is_active"]) - + require.Contains(t, out.String(), clientID.String()) + require.Contains(t, out.String(), name) mockUseCase.AssertExpectations(t) }) - t.Run("invalid-client-id", func(t *testing.T) { + t.Run("invalid-id", func(t *testing.T) { mockUseCase := &authMocks.MockClientUseCase{} err := RunUpdateClient( ctx, mockUseCase, logger, - DefaultIO(), - "invalid-uuid", - "name", + IOTuple{}, + "invalid-id", + name, true, - "", + string(policiesJSON), "text", ) + require.Error(t, err) require.Contains(t, err.Error(), "invalid client ID format") }) - t.Run("client-not-found", func(t *testing.T) { + t.Run("not-found", func(t *testing.T) { mockUseCase := &authMocks.MockClientUseCase{} - mockUseCase.On("Get", ctx, clientID).Return(nil, errors.New("client not found")) + mockUseCase.On("Get", ctx, clientID).Return(nil, fmt.Errorf("not found")) + + err := RunUpdateClient( + ctx, + mockUseCase, + logger, + IOTuple{}, + clientID.String(), + name, + true, + string(policiesJSON), + "text", + ) - err := RunUpdateClient(ctx, mockUseCase, logger, DefaultIO(), clientIDStr, "name", true, "", "text") require.Error(t, err) require.Contains(t, err.Error(), "failed to get existing client") }) - - t.Run("interactive-success", func(t *testing.T) { - mockUseCase := &authMocks.MockClientUseCase{} - mockUseCase.On("Get", ctx, clientID).Return(existingClient, nil) - mockUseCase.On("Update", ctx, clientID, mock.AnythingOfType("*domain.UpdateClientInput")).Return(nil) - - // Mock user input: path, capabilities, another policy? (n) - input := bytes.NewBufferString("secret/test\nread,write\nn\n") - var out bytes.Buffer - io := IOTuple{Reader: input, Writer: &out} - - err := RunUpdateClient(ctx, mockUseCase, logger, io, clientIDStr, "new-name", true, "", "text") - - require.NoError(t, err) - require.Contains(t, out.String(), "Client updated successfully!") - - mockUseCase.AssertExpectations(t) - }) } diff --git a/cmd/app/commands/verify_audit_logs.go b/cmd/app/commands/verify_audit_logs.go index 122621b..748357f 100644 --- a/cmd/app/commands/verify_audit_logs.go +++ b/cmd/app/commands/verify_audit_logs.go @@ -11,25 +11,73 @@ import ( authUseCase "github.com/allisson/secrets/internal/auth/usecase" ) +// VerifyAuditLogsResult holds the result of the audit log verification operation. +type VerifyAuditLogsResult struct { + TotalChecked int64 `json:"total_checked"` + SignedCount int64 `json:"signed_count"` + UnsignedCount int64 `json:"unsigned_count"` + ValidCount int64 `json:"valid_count"` + InvalidCount int64 `json:"invalid_count"` + InvalidLogs []string `json:"invalid_logs"` + Passed bool `json:"passed"` + StartDate time.Time `json:"start_date"` + EndDate time.Time `json:"end_date"` +} + +// ToText returns a human-readable representation of the verification result. +func (r *VerifyAuditLogsResult) ToText() string { + output := "Audit Log Integrity Verification\n" + output += "=================================\n\n" + output += fmt.Sprintf( + "Time Range: %s to %s\n\n", + r.StartDate.Format("2006-01-02 15:04:05"), + r.EndDate.Format("2006-01-02 15:04:05"), + ) + + output += fmt.Sprintf("Total Checked: %d\n", r.TotalChecked) + output += fmt.Sprintf("Signed: %d\n", r.SignedCount) + output += fmt.Sprintf("Unsigned: %d (legacy)\n", r.UnsignedCount) + output += fmt.Sprintf("Valid: %d\n", r.ValidCount) + output += fmt.Sprintf("Invalid: %d\n\n", r.InvalidCount) + + switch { + case r.InvalidCount > 0: + output += fmt.Sprintf("WARNING: %d log(s) failed integrity check!\n\n", r.InvalidCount) + output += "Invalid Log IDs:\n" + for _, id := range r.InvalidLogs { + output += fmt.Sprintf(" - %s\n", id) + } + output += "\nStatus: FAILED ❌" + case r.TotalChecked == 0: + output += "Status: No logs found in specified time range" + default: + output += "Status: PASSED ✓" + } + return output +} + +// ToJSON returns a JSON representation of the verification result. +func (r *VerifyAuditLogsResult) ToJSON() string { + jsonBytes, _ := json.MarshalIndent(r, "", " ") + return string(jsonBytes) +} + // RunVerifyAuditLogs verifies cryptographic integrity of audit logs within a time range. -// Validates HMAC-SHA256 signatures against KEK-derived signing keys for tamper detection. -// -// Requirements: Database must be migrated with signature columns and KEK chain loaded. func RunVerifyAuditLogs( ctx context.Context, auditLogUseCase authUseCase.AuditLogUseCase, logger *slog.Logger, writer io.Writer, - startDate, endDate string, + startDateStr, endDateStr string, format string, ) error { // Parse date strings to time.Time - start, err := parseDate(startDate) + start, err := parseDate(startDateStr) if err != nil { return fmt.Errorf("invalid start date: %w", err) } - end, err := parseDate(endDate) + end, err := parseDate(endDateStr) if err != nil { return fmt.Errorf("invalid end date: %w", err) } @@ -50,14 +98,25 @@ func RunVerifyAuditLogs( return fmt.Errorf("failed to verify audit logs: %w", err) } - // Output result based on format - if format == "json" { - if err := outputVerifyJSON(writer, report); err != nil { - return fmt.Errorf("failed to output JSON: %w", err) - } - } else { - outputVerifyText(writer, report, start, end) + // Convert UUIDs to strings + invalidLogs := make([]string, len(report.InvalidLogs)) + for i, id := range report.InvalidLogs { + invalidLogs[i] = id.String() + } + + // Output result + result := &VerifyAuditLogsResult{ + TotalChecked: report.TotalChecked, + SignedCount: report.SignedCount, + UnsignedCount: report.UnsignedCount, + ValidCount: report.ValidCount, + InvalidCount: report.InvalidCount, + InvalidLogs: invalidLogs, + Passed: report.InvalidCount == 0, + StartDate: start, + EndDate: end, } + WriteOutput(writer, format, result) // Log summary logger.Info("verification completed", @@ -94,55 +153,3 @@ func parseDate(dateStr string) (time.Time, error) { return t, nil } - -// outputVerifyText outputs the verification result in human-readable text format. -func outputVerifyText(writer io.Writer, report *authUseCase.VerificationReport, start, end time.Time) { - _, _ = fmt.Fprintf(writer, "Audit Log Integrity Verification\n") - _, _ = fmt.Fprintf(writer, "=================================\n\n") - _, _ = fmt.Fprintf(writer, - "Time Range: %s to %s\n\n", - start.Format("2006-01-02 15:04:05"), - end.Format("2006-01-02 15:04:05"), - ) - - _, _ = fmt.Fprintf(writer, "Total Checked: %d\n", report.TotalChecked) - _, _ = fmt.Fprintf(writer, "Signed: %d\n", report.SignedCount) - _, _ = fmt.Fprintf(writer, "Unsigned: %d (legacy)\n", report.UnsignedCount) - _, _ = fmt.Fprintf(writer, "Valid: %d\n", report.ValidCount) - _, _ = fmt.Fprintf(writer, "Invalid: %d\n\n", report.InvalidCount) - - switch { - case report.InvalidCount > 0: - _, _ = fmt.Fprintf(writer, "WARNING: %d log(s) failed integrity check!\n\n", report.InvalidCount) - _, _ = fmt.Fprintf(writer, "Invalid Log IDs:\n") - for _, id := range report.InvalidLogs { - _, _ = fmt.Fprintf(writer, " - %s\n", id) - } - _, _ = fmt.Fprintf(writer, "\nStatus: FAILED ❌\n") - case report.TotalChecked == 0: - _, _ = fmt.Fprintf(writer, "Status: No logs found in specified time range\n") - default: - _, _ = fmt.Fprintf(writer, "Status: PASSED ✓\n") - } -} - -// outputVerifyJSON outputs the verification result in JSON format for machine consumption. -func outputVerifyJSON(writer io.Writer, report *authUseCase.VerificationReport) error { - result := map[string]interface{}{ - "total_checked": report.TotalChecked, - "signed_count": report.SignedCount, - "unsigned_count": report.UnsignedCount, - "valid_count": report.ValidCount, - "invalid_count": report.InvalidCount, - "invalid_logs": report.InvalidLogs, - "passed": report.InvalidCount == 0, - } - - jsonBytes, err := json.MarshalIndent(result, "", " ") - if err != nil { - return fmt.Errorf("failed to marshal JSON: %w", err) - } - - _, _ = fmt.Fprintln(writer, string(jsonBytes)) - return nil -} diff --git a/cmd/app/commands/verify_audit_logs_test.go b/cmd/app/commands/verify_audit_logs_test.go index a96ea1f..c3af6c0 100644 --- a/cmd/app/commands/verify_audit_logs_test.go +++ b/cmd/app/commands/verify_audit_logs_test.go @@ -3,7 +3,6 @@ package commands import ( "bytes" "context" - "encoding/json" "log/slog" "testing" @@ -18,63 +17,74 @@ import ( func TestRunVerifyAuditLogs(t *testing.T) { ctx := context.Background() logger := slog.Default() - startDate := "2025-01-01" - endDate := "2025-01-02" + startDate := "2023-01-01" + endDate := "2023-01-02" - report := &authUseCase.VerificationReport{ - TotalChecked: 10, - SignedCount: 10, - ValidCount: 10, - } - - t.Run("success-text", func(t *testing.T) { + t.Run("text-output-pass", func(t *testing.T) { mockUseCase := &authMocks.MockAuditLogUseCase{} - mockUseCase.On("VerifyBatch", ctx, mock.AnythingOfType("time.Time"), mock.AnythingOfType("time.Time")). - Return(report, nil) + mockUseCase.On("VerifyBatch", ctx, mock.Anything, mock.Anything). + Return(&authUseCase.VerificationReport{ + TotalChecked: 10, + ValidCount: 10, + }, nil) var out bytes.Buffer err := RunVerifyAuditLogs(ctx, mockUseCase, logger, &out, startDate, endDate, "text") + require.NoError(t, err) - require.Contains(t, out.String(), "Audit Log Integrity Verification") + require.Contains(t, out.String(), "Status: PASSED") + mockUseCase.AssertExpectations(t) + }) + + t.Run("text-output-fail", func(t *testing.T) { + mockUseCase := &authMocks.MockAuditLogUseCase{} + invalidLogs := []uuid.UUID{uuid.New(), uuid.New()} + mockUseCase.On("VerifyBatch", ctx, mock.Anything, mock.Anything). + Return(&authUseCase.VerificationReport{ + TotalChecked: 10, + InvalidCount: 2, + InvalidLogs: invalidLogs, + }, nil) + + var out bytes.Buffer + err := RunVerifyAuditLogs(ctx, mockUseCase, logger, &out, startDate, endDate, "text") + + require.Error(t, err) + require.Contains(t, out.String(), "Status: FAILED") + require.Contains(t, out.String(), invalidLogs[0].String()) mockUseCase.AssertExpectations(t) }) - t.Run("success-json", func(t *testing.T) { + t.Run("json-output", func(t *testing.T) { mockUseCase := &authMocks.MockAuditLogUseCase{} - mockUseCase.On("VerifyBatch", ctx, mock.AnythingOfType("time.Time"), mock.AnythingOfType("time.Time")). - Return(report, nil) + mockUseCase.On("VerifyBatch", ctx, mock.Anything, mock.Anything). + Return(&authUseCase.VerificationReport{ + TotalChecked: 5, + ValidCount: 5, + }, nil) var out bytes.Buffer err := RunVerifyAuditLogs(ctx, mockUseCase, logger, &out, startDate, endDate, "json") - require.NoError(t, err) - var result map[string]interface{} - err = json.Unmarshal(out.Bytes(), &result) require.NoError(t, err) - require.Equal(t, float64(10), result["total_checked"]) + require.Contains(t, out.String(), `"total_checked": 5`) + require.Contains(t, out.String(), `"passed": true`) mockUseCase.AssertExpectations(t) }) - t.Run("invalid-dates", func(t *testing.T) { - err := RunVerifyAuditLogs(ctx, nil, logger, nil, "invalid", endDate, "text") + t.Run("invalid-date", func(t *testing.T) { + mockUseCase := &authMocks.MockAuditLogUseCase{} + err := RunVerifyAuditLogs(ctx, mockUseCase, logger, &bytes.Buffer{}, "invalid", endDate, "text") + require.Error(t, err) require.Contains(t, err.Error(), "invalid start date") }) - t.Run("integrity-failure", func(t *testing.T) { + t.Run("invalid-range", func(t *testing.T) { mockUseCase := &authMocks.MockAuditLogUseCase{} - failureReport := &authUseCase.VerificationReport{ - TotalChecked: 10, - InvalidCount: 2, - InvalidLogs: []uuid.UUID{uuid.New(), uuid.New()}, - } - mockUseCase.On("VerifyBatch", ctx, mock.AnythingOfType("time.Time"), mock.AnythingOfType("time.Time")). - Return(failureReport, nil) + err := RunVerifyAuditLogs(ctx, mockUseCase, logger, &bytes.Buffer{}, endDate, startDate, "text") - var out bytes.Buffer - err := RunVerifyAuditLogs(ctx, mockUseCase, logger, &out, startDate, endDate, "text") require.Error(t, err) - require.Contains(t, err.Error(), "integrity check failed") - require.Contains(t, out.String(), "WARNING: 2 log(s) failed integrity check!") + require.Contains(t, err.Error(), "end date must be after start date") }) } diff --git a/cmd/app/commands_test.go b/cmd/app/commands_test.go new file mode 100644 index 0000000..d01231b --- /dev/null +++ b/cmd/app/commands_test.go @@ -0,0 +1,84 @@ +package main + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestGetCommands(t *testing.T) { + version := "v0.1.0" + cmds := getCommands(version) + + require.NotEmpty(t, cmds) + + // Check if key commands exist + foundCreateMasterKey := false + for _, cmd := range cmds { + if cmd.Name == "create-master-key" { + foundCreateMasterKey = true + break + } + } + require.True(t, foundCreateMasterKey) +} + +func TestGetAuthCommands(t *testing.T) { + cmds := getAuthCommands() + require.NotEmpty(t, cmds) + + expectedCmds := []string{"clean-expired-tokens", "create-client", "update-client"} + for _, name := range expectedCmds { + found := false + for _, cmd := range cmds { + if cmd.Name == name { + found = true + break + } + } + require.Truef(t, found, "command %s not found", name) + } +} + +func TestGetKeyCommands(t *testing.T) { + cmds := getKeyCommands() + require.NotEmpty(t, cmds) + + expectedCmds := []string{ + "create-master-key", + "rotate-master-key", + "create-kek", + "rotate-kek", + "rewrap-deks", + "create-tokenization-key", + "rotate-tokenization-key", + } + for _, name := range expectedCmds { + found := false + for _, cmd := range cmds { + if cmd.Name == name { + found = true + break + } + } + require.Truef(t, found, "command %s not found", name) + } +} + +func TestGetSystemCommands(t *testing.T) { + version := "v0.1.0" + cmds := getSystemCommands(version) + require.NotEmpty(t, cmds) + + expectedCmds := []string{"server", "migrate", "clean-audit-logs", "verify-audit-logs"} + for _, name := range expectedCmds { + found := false + for _, cmd := range cmds { + if cmd.Name == name { + found = true + break + } + } + require.Truef(t, found, "command %s not found", name) + } +} diff --git a/cmd/app/key_commands.go b/cmd/app/key_commands.go index 38d50f0..d38b5d3 100644 --- a/cmd/app/key_commands.go +++ b/cmd/app/key_commands.go @@ -1,3 +1,4 @@ +// Package main provides the CLI command definitions for the application. package main import ( @@ -8,10 +9,10 @@ import ( "github.com/allisson/secrets/cmd/app/commands" "github.com/allisson/secrets/internal/app" - "github.com/allisson/secrets/internal/config" cryptoService "github.com/allisson/secrets/internal/crypto/service" ) +// getKeyCommands returns the key-related CLI commands. func getKeyCommands() []*cli.Command { return []*cli.Command{ { @@ -38,18 +39,19 @@ func getKeyCommands() []*cli.Command { }, }, Action: func(ctx context.Context, cmd *cli.Command) error { - cfg := config.Load() - container := app.NewContainer(cfg) - defer func() { _ = container.Shutdown(ctx) }() - - return commands.RunCreateMasterKey( + return commands.ExecuteWithContainer( ctx, - cryptoService.NewKMSService(), - container.Logger(), - commands.DefaultIO().Writer, - cmd.String("id"), - cmd.String("kms-provider"), - cmd.String("kms-key-uri"), + func(ctx context.Context, container *app.Container) error { + return commands.RunCreateMasterKey( + ctx, + cryptoService.NewKMSService(), + container.Logger(), + commands.DefaultIO().Writer, + cmd.String("id"), + cmd.String("kms-provider"), + cmd.String("kms-key-uri"), + ) + }, ) }, }, @@ -77,20 +79,21 @@ func getKeyCommands() []*cli.Command { }, }, Action: func(ctx context.Context, cmd *cli.Command) error { - cfg := config.Load() - container := app.NewContainer(cfg) - defer func() { _ = container.Shutdown(ctx) }() - - return commands.RunRotateMasterKey( + return commands.ExecuteWithContainer( ctx, - cryptoService.NewKMSService(), - container.Logger(), - commands.DefaultIO().Writer, - cmd.String("id"), - cmd.String("kms-provider"), - cmd.String("kms-key-uri"), - os.Getenv("MASTER_KEYS"), - os.Getenv("ACTIVE_MASTER_KEY_ID"), + func(ctx context.Context, container *app.Container) error { + return commands.RunRotateMasterKey( + ctx, + cryptoService.NewKMSService(), + container.Logger(), + commands.DefaultIO().Writer, + cmd.String("id"), + cmd.String("kms-provider"), + cmd.String("kms-key-uri"), + os.Getenv("MASTER_KEYS"), + os.Getenv("ACTIVE_MASTER_KEY_ID"), + ) + }, ) }, }, @@ -106,26 +109,27 @@ func getKeyCommands() []*cli.Command { }, }, Action: func(ctx context.Context, cmd *cli.Command) error { - cfg := config.Load() - container := app.NewContainer(cfg) - defer func() { _ = container.Shutdown(ctx) }() - - kekUseCase, err := container.KekUseCase() - if err != nil { - return err - } + return commands.ExecuteWithContainer( + ctx, + func(ctx context.Context, container *app.Container) error { + kekUseCase, err := container.KekUseCase() + if err != nil { + return err + } - masterKeyChain, err := container.MasterKeyChain() - if err != nil { - return err - } + masterKeyChain, err := container.MasterKeyChain() + if err != nil { + return err + } - return commands.RunCreateKek( - ctx, - kekUseCase, - masterKeyChain, - container.Logger(), - cmd.String("algorithm"), + return commands.RunCreateKek( + ctx, + kekUseCase, + masterKeyChain, + container.Logger(), + cmd.String("algorithm"), + ) + }, ) }, }, @@ -141,26 +145,27 @@ func getKeyCommands() []*cli.Command { }, }, Action: func(ctx context.Context, cmd *cli.Command) error { - cfg := config.Load() - container := app.NewContainer(cfg) - defer func() { _ = container.Shutdown(ctx) }() - - kekUseCase, err := container.KekUseCase() - if err != nil { - return err - } + return commands.ExecuteWithContainer( + ctx, + func(ctx context.Context, container *app.Container) error { + kekUseCase, err := container.KekUseCase() + if err != nil { + return err + } - masterKeyChain, err := container.MasterKeyChain() - if err != nil { - return err - } + masterKeyChain, err := container.MasterKeyChain() + if err != nil { + return err + } - return commands.RunRotateKek( - ctx, - kekUseCase, - masterKeyChain, - container.Logger(), - cmd.String("algorithm"), + return commands.RunRotateKek( + ctx, + kekUseCase, + masterKeyChain, + container.Logger(), + cmd.String("algorithm"), + ) + }, ) }, }, @@ -182,33 +187,34 @@ func getKeyCommands() []*cli.Command { }, }, Action: func(ctx context.Context, cmd *cli.Command) error { - cfg := config.Load() - container := app.NewContainer(cfg) - defer func() { _ = container.Shutdown(ctx) }() - - masterKeyChain, err := container.MasterKeyChain() - if err != nil { - return err - } + return commands.ExecuteWithContainer( + ctx, + func(ctx context.Context, container *app.Container) error { + masterKeyChain, err := container.MasterKeyChain() + if err != nil { + return err + } - kekUseCase, err := container.KekUseCase() - if err != nil { - return err - } + kekUseCase, err := container.KekUseCase() + if err != nil { + return err + } - dekUseCase, err := container.CryptoDekUseCase() - if err != nil { - return err - } + dekUseCase, err := container.CryptoDekUseCase() + if err != nil { + return err + } - return commands.RunRewrapDeks( - ctx, - masterKeyChain, - kekUseCase, - dekUseCase, - container.Logger(), - cmd.String("kek-id"), - int(cmd.Int("batch-size")), + return commands.RunRewrapDeks( + ctx, + masterKeyChain, + kekUseCase, + dekUseCase, + container.Logger(), + cmd.String("kek-id"), + int(cmd.Int("batch-size")), + ) + }, ) }, }, @@ -242,23 +248,24 @@ func getKeyCommands() []*cli.Command { }, }, Action: func(ctx context.Context, cmd *cli.Command) error { - cfg := config.Load() - container := app.NewContainer(cfg) - defer func() { _ = container.Shutdown(ctx) }() - - tokenizationKeyUseCase, err := container.TokenizationKeyUseCase() - if err != nil { - return err - } - - return commands.RunCreateTokenizationKey( + return commands.ExecuteWithContainer( ctx, - tokenizationKeyUseCase, - container.Logger(), - cmd.String("name"), - cmd.String("format"), - cmd.Bool("deterministic"), - cmd.String("algorithm"), + func(ctx context.Context, container *app.Container) error { + tokenizationKeyUseCase, err := container.TokenizationKeyUseCase() + if err != nil { + return err + } + + return commands.RunCreateTokenizationKey( + ctx, + tokenizationKeyUseCase, + container.Logger(), + cmd.String("name"), + cmd.String("format"), + cmd.Bool("deterministic"), + cmd.String("algorithm"), + ) + }, ) }, }, @@ -292,23 +299,24 @@ func getKeyCommands() []*cli.Command { }, }, Action: func(ctx context.Context, cmd *cli.Command) error { - cfg := config.Load() - container := app.NewContainer(cfg) - defer func() { _ = container.Shutdown(ctx) }() - - tokenizationKeyUseCase, err := container.TokenizationKeyUseCase() - if err != nil { - return err - } - - return commands.RunRotateTokenizationKey( + return commands.ExecuteWithContainer( ctx, - tokenizationKeyUseCase, - container.Logger(), - cmd.String("name"), - cmd.String("format"), - cmd.Bool("deterministic"), - cmd.String("algorithm"), + func(ctx context.Context, container *app.Container) error { + tokenizationKeyUseCase, err := container.TokenizationKeyUseCase() + if err != nil { + return err + } + + return commands.RunRotateTokenizationKey( + ctx, + tokenizationKeyUseCase, + container.Logger(), + cmd.String("name"), + cmd.String("format"), + cmd.Bool("deterministic"), + cmd.String("algorithm"), + ) + }, ) }, }, diff --git a/cmd/app/main.go b/cmd/app/main.go index 8f0aff6..cb5e336 100644 --- a/cmd/app/main.go +++ b/cmd/app/main.go @@ -33,7 +33,9 @@ func main() { } if err := cmd.Run(context.Background(), os.Args); err != nil { - slog.Error("application error", slog.Any("error", err)) + // Use a basic JSON logger for early errors if they occur before container initialization + logger := slog.New(slog.NewJSONHandler(os.Stderr, nil)) + logger.Error("application error", slog.Any("error", err)) os.Exit(1) } } diff --git a/cmd/app/system_commands.go b/cmd/app/system_commands.go index 1c88eee..b7187de 100644 --- a/cmd/app/system_commands.go +++ b/cmd/app/system_commands.go @@ -1,3 +1,4 @@ +// Package main provides the CLI command definitions for the application. package main import ( @@ -10,6 +11,7 @@ import ( "github.com/allisson/secrets/internal/config" ) +// getSystemCommands returns the system-related CLI commands. func getSystemCommands(version string) []*cli.Command { return []*cli.Command{ { @@ -23,11 +25,17 @@ func getSystemCommands(version string) []*cli.Command { Name: "migrate", Usage: "Run database migrations", Action: func(ctx context.Context, cmd *cli.Command) error { - cfg := config.Load() - container := app.NewContainer(cfg) - defer func() { _ = container.Shutdown(ctx) }() - - return commands.RunMigrations(container.Logger(), cfg.DBDriver, cfg.DBConnectionString) + return commands.ExecuteWithContainer( + ctx, + func(ctx context.Context, container *app.Container) error { + cfg := config.Load() + return commands.RunMigrations( + container.Logger(), + cfg.DBDriver, + cfg.DBConnectionString, + ) + }, + ) }, }, { @@ -54,23 +62,24 @@ func getSystemCommands(version string) []*cli.Command { }, }, Action: func(ctx context.Context, cmd *cli.Command) error { - cfg := config.Load() - container := app.NewContainer(cfg) - defer func() { _ = container.Shutdown(ctx) }() - - auditLogUseCase, err := container.AuditLogUseCase() - if err != nil { - return err - } - - return commands.RunCleanAuditLogs( + return commands.ExecuteWithContainer( ctx, - auditLogUseCase, - container.Logger(), - commands.DefaultIO().Writer, - int(cmd.Int("days")), - cmd.Bool("dry-run"), - cmd.String("format"), + func(ctx context.Context, container *app.Container) error { + auditLogUseCase, err := container.AuditLogUseCase() + if err != nil { + return err + } + + return commands.RunCleanAuditLogs( + ctx, + auditLogUseCase, + container.Logger(), + commands.DefaultIO().Writer, + int(cmd.Int("days")), + cmd.Bool("dry-run"), + cmd.String("format"), + ) + }, ) }, }, @@ -98,23 +107,24 @@ func getSystemCommands(version string) []*cli.Command { }, }, Action: func(ctx context.Context, cmd *cli.Command) error { - cfg := config.Load() - container := app.NewContainer(cfg) - defer func() { _ = container.Shutdown(ctx) }() - - auditLogUseCase, err := container.AuditLogUseCase() - if err != nil { - return err - } - - return commands.RunVerifyAuditLogs( + return commands.ExecuteWithContainer( ctx, - auditLogUseCase, - container.Logger(), - commands.DefaultIO().Writer, - cmd.String("start-date"), - cmd.String("end-date"), - cmd.String("format"), + func(ctx context.Context, container *app.Container) error { + auditLogUseCase, err := container.AuditLogUseCase() + if err != nil { + return err + } + + return commands.RunVerifyAuditLogs( + ctx, + auditLogUseCase, + container.Logger(), + commands.DefaultIO().Writer, + cmd.String("start-date"), + cmd.String("end-date"), + cmd.String("format"), + ) + }, ) }, }, diff --git a/internal/config/config.go b/internal/config/config.go index 55605a8..fd22bac 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -16,6 +16,8 @@ type Config struct { ServerHost string // ServerPort is the port number the server will listen on. ServerPort int + // ServerShutdownTimeout is the maximum time to wait for the server to gracefully shutdown. + ServerShutdownTimeout time.Duration // DBDriver is the database driver to use (e.g., "postgres", "mysql"). DBDriver string @@ -78,8 +80,9 @@ func Load() *Config { return &Config{ // Server configuration - ServerHost: env.GetString("SERVER_HOST", "0.0.0.0"), - ServerPort: env.GetInt("SERVER_PORT", 8080), + ServerHost: env.GetString("SERVER_HOST", "0.0.0.0"), + ServerPort: env.GetInt("SERVER_PORT", 8080), + ServerShutdownTimeout: env.GetDuration("SERVER_SHUTDOWN_TIMEOUT", 10, time.Second), // Database configuration DBDriver: env.GetString("DB_DRIVER", "postgres"), diff --git a/internal/ui/policies.go b/internal/ui/policies.go new file mode 100644 index 0000000..dd6f975 --- /dev/null +++ b/internal/ui/policies.go @@ -0,0 +1,123 @@ +// Package ui provides interactive CLI components and input validation for the application. +package ui + +import ( + "bufio" + "fmt" + "io" + "strings" + + authDomain "github.com/allisson/secrets/internal/auth/domain" +) + +// PromptForPolicies interactively prompts the user to enter policy documents. +// Shows available capabilities and accepts multiple policies until user declines. +func PromptForPolicies(input io.Reader, output io.Writer) ([]authDomain.PolicyDocument, error) { + reader := bufio.NewReader(input) + var policies []authDomain.PolicyDocument + + _, _ = fmt.Fprintln(output, "\nEnter policies for the client") + _, _ = fmt.Fprintln(output, "Available capabilities: read, write, delete, encrypt, decrypt, rotate") + _, _ = fmt.Fprintln(output) + + policyNum := 1 + for { + _, _ = fmt.Fprintf(output, "Policy #%d\n", policyNum) + + // Get path + _, _ = fmt.Fprint(output, "Enter path pattern (e.g., 'secret/*' or '*'): ") + path, err := reader.ReadString('\n') + if err != nil { + return nil, fmt.Errorf("failed to read path: %w", err) + } + path = strings.TrimSpace(path) + + if path == "" { + return nil, fmt.Errorf("path cannot be empty") + } + + // Get capabilities + _, _ = fmt.Fprint(output, "Enter capabilities (comma-separated, e.g., 'read,write'): ") + capsInput, err := reader.ReadString('\n') + if err != nil { + return nil, fmt.Errorf("failed to read capabilities: %w", err) + } + capsInput = strings.TrimSpace(capsInput) + + if capsInput == "" { + return nil, fmt.Errorf("capabilities cannot be empty") + } + + capabilities, err := ParseCapabilities(capsInput) + if err != nil { + return nil, err + } + + // Add policy + policies = append(policies, authDomain.PolicyDocument{ + Path: path, + Capabilities: capabilities, + }) + + // Ask if user wants to add another + _, _ = fmt.Fprint(output, "Add another policy? (y/n): ") + addAnother, err := reader.ReadString('\n') + if err != nil { + return nil, fmt.Errorf("failed to read input: %w", err) + } + addAnother = strings.ToLower(strings.TrimSpace(addAnother)) + + if addAnother != "y" && addAnother != "yes" { + break + } + + _, _ = fmt.Fprintln(output) + policyNum++ + } + + return policies, nil +} + +// PromptForPoliciesUpdate interactively prompts the user to enter policy documents during an update. +// Shows current policies and available capabilities. +func PromptForPoliciesUpdate( + input io.Reader, + output io.Writer, + currentPolicies []authDomain.PolicyDocument, +) ([]authDomain.PolicyDocument, error) { + _, _ = fmt.Fprintln(output, "\nCurrent policies:") + for i, policy := range currentPolicies { + capsStr := make([]string, len(policy.Capabilities)) + for j, cap := range policy.Capabilities { + capsStr[j] = string(cap) + } + _, _ = fmt.Fprintf( + output, + " %d. Path: %s, Capabilities: [%s]\n", + i+1, + policy.Path, + strings.Join(capsStr, ", "), + ) + } + + return PromptForPolicies(input, output) +} + +// ParseCapabilities converts a comma-separated string into a slice of Capability. +func ParseCapabilities(input string) ([]authDomain.Capability, error) { + parts := strings.Split(input, ",") + capabilities := make([]authDomain.Capability, 0, len(parts)) + + for _, part := range parts { + cap := authDomain.Capability(strings.TrimSpace(part)) + if cap != "" { + capabilities = append(capabilities, cap) + } + } + + if len(capabilities) == 0 { + return nil, fmt.Errorf("at least one capability is required") + } + + return capabilities, nil +} diff --git a/internal/ui/policies_test.go b/internal/ui/policies_test.go new file mode 100644 index 0000000..9773e32 --- /dev/null +++ b/internal/ui/policies_test.go @@ -0,0 +1,69 @@ +package ui + +import ( + "bytes" + "strings" + "testing" + + "github.com/stretchr/testify/require" + + authDomain "github.com/allisson/secrets/internal/auth/domain" +) + +func TestParseCapabilities(t *testing.T) { + tests := []struct { + input string + expected []authDomain.Capability + wantErr bool + }{ + {"read,write", []authDomain.Capability{"read", "write"}, false}, + {"read , write ", []authDomain.Capability{"read", "write"}, false}, + {"", nil, true}, + {" , ", nil, true}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + got, err := ParseCapabilities(tt.input) + if tt.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + require.Equal(t, tt.expected, got) + } + }) + } +} + +func TestPromptForPolicies(t *testing.T) { + t.Run("success-single-policy", func(t *testing.T) { + input := "secret/*\nread,write\nn\n" + var output bytes.Buffer + policies, err := PromptForPolicies(strings.NewReader(input), &output) + + require.NoError(t, err) + require.Len(t, policies, 1) + require.Equal(t, "secret/*", policies[0].Path) + require.Equal(t, []authDomain.Capability{"read", "write"}, policies[0].Capabilities) + }) + + t.Run("success-multiple-policies", func(t *testing.T) { + input := "secret/*\nread\ny\nother/*\nwrite\nn\n" + var output bytes.Buffer + policies, err := PromptForPolicies(strings.NewReader(input), &output) + + require.NoError(t, err) + require.Len(t, policies, 2) + require.Equal(t, "secret/*", policies[0].Path) + require.Equal(t, "other/*", policies[1].Path) + }) + + t.Run("empty-path", func(t *testing.T) { + input := "\n" + var output bytes.Buffer + _, err := PromptForPolicies(strings.NewReader(input), &output) + + require.Error(t, err) + require.Contains(t, err.Error(), "path cannot be empty") + }) +}