diff --git a/CHANGELOG.md b/CHANGELOG.md index 0810b0f..8af8fab 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,22 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [0.22.0] - 2026-02-28 + +### Added +- Added `MaxTransitKeyNameLength` (255 characters) constraint for transit key names to ensure database compatibility. +- Added metrics decoration for `transit` and `tokenization` usecases for improved observability. +- New internal testing helpers and DSN getter functions in the integration test suite. + +### Changed +- Refactored `tokenization` domain models, repositories, and generators (Alphanumeric, Luhn, Numeric) for better maintainability and performance. +- Reorganized `transit` domain models and added comprehensive unit tests for `TransitKey`. +- Updated Go version to 1.26.0 in CI workflows and documentation. + +### Fixed +- Corrected `rotate-master-key` CLI flags and documentation in scaling guides. +- Improved error handling in `transit` cryptographic operations. + ## [0.21.0] - 2026-02-28 ### Added diff --git a/cmd/app/main.go b/cmd/app/main.go index 9a081a2..5755882 100644 --- a/cmd/app/main.go +++ b/cmd/app/main.go @@ -12,7 +12,7 @@ import ( // Build-time version information (injected via ldflags during build). var ( - version = "v0.21.0" // Semantic version with "v" prefix (e.g., "v0.12.0") + version = "v0.22.0" // Semantic version with "v" prefix (e.g., "v0.12.0") buildDate = "unknown" // ISO 8601 build timestamp commitSHA = "unknown" // Git commit SHA ) diff --git a/docs/README.md b/docs/README.md index 7524139..119d4f0 100644 --- a/docs/README.md +++ b/docs/README.md @@ -105,7 +105,7 @@ Welcome to the full documentation for Secrets. Pick a path and dive in 🚀 OpenAPI scope note: -- `openapi.yaml` is a baseline subset for common API flows in the current release (v0.20.0, see `docs/metadata.json`) +- `openapi.yaml` is a baseline subset for common API flows in the current release (v0.22.0, see `docs/metadata.json`) - Full endpoint behavior is documented in the endpoint pages under `docs/api/` - Tokenization endpoints are included in `openapi.yaml` for the current release diff --git a/docs/api/data/transit.md b/docs/api/data/transit.md index b99335c..8fce9cd 100644 --- a/docs/api/data/transit.md +++ b/docs/api/data/transit.md @@ -201,7 +201,7 @@ Example decrypt response (`200 OK`): | Endpoint | 401 | 403 | 404 | 409 | 422 | 429 | | --- | --- | --- | --- | --- | --- | --- | -| `POST /v1/transit/keys` | missing/invalid token | missing `write` capability | - | key name already initialized (`version=1`) | invalid create payload | per-client rate limit exceeded | +| `POST /v1/transit/keys` | missing/invalid token | missing `write` capability | - | key name already initialized (`version=1`) | invalid create payload (e.g., `name` exceeds `MaxTransitKeyNameLength` constraint) | per-client rate limit exceeded | | `POST /v1/transit/keys/:name/rotate` | missing/invalid token | missing `rotate` capability | key name not found | - | invalid rotate payload | per-client rate limit exceeded | | `POST /v1/transit/keys/:name/encrypt` | missing/invalid token | missing `encrypt` capability | key name not found | - | `plaintext` missing/invalid base64 | per-client rate limit exceeded | | `POST /v1/transit/keys/:name/decrypt` | missing/invalid token | missing `decrypt` capability | key/version not found | - | malformed `:` | per-client rate limit exceeded | diff --git a/docs/metadata.json b/docs/metadata.json index 5c64ff5..d049065 100644 --- a/docs/metadata.json +++ b/docs/metadata.json @@ -1,5 +1,5 @@ { - "current_release": "v0.20.0", + "current_release": "v0.22.0", "api_version": "v1", "last_docs_refresh": "2026-02-28" } \ No newline at end of file diff --git a/docs/openapi.yaml b/docs/openapi.yaml index 391495d..c59e9e9 100644 --- a/docs/openapi.yaml +++ b/docs/openapi.yaml @@ -2,11 +2,7 @@ openapi: 3.0.3 info: title: Secrets API version: v1 - description: >- - Baseline OpenAPI specification for Secrets API v1. This is intentionally concise - and focuses on high-traffic endpoints and common payloads. OpenAPI path templates - use `{param}` syntax while runtime router/metrics labels may expose `:param` or - wildcard forms such as `*path`. + description: Lightweight secrets manager for simplicity and security. servers: - url: http://localhost:8080 description: Local development diff --git a/docs/releases/RELEASES.md b/docs/releases/RELEASES.md index 399ba43..4ae5d99 100644 --- a/docs/releases/RELEASES.md +++ b/docs/releases/RELEASES.md @@ -6,10 +6,13 @@ This document contains release notes for all versions of Secrets. ## 📑 Quick Navigation -**Latest Release**: [v0.19.0](#0190---2026-02-27) +**Latest Release**: [v0.22.0](#0220---2026-02-28) **All Releases**: +- [v0.22.0 (2026-02-28)](#0220---2026-02-28) - Metrics, Transit and Tokenization improvements +- [v0.21.0 (2026-02-28)](#0210---2026-02-28) - CLI command structure refactor +- [v0.20.0 (2026-02-28)](#0200---2026-02-28) - Go 1.26.0 and doc fixes - [v0.19.0 (2026-02-27)](#0190---2026-02-27) - ⚠️ **Breaking Change**: KMS mode required - [v0.18.0 (2026-02-27)](#0180---2026-02-27) - Repository layer refactoring @@ -56,6 +59,53 @@ This document contains release notes for all versions of Secrets. --- +## [0.22.0] - 2026-02-28 + +### Added + +- Added `MaxTransitKeyNameLength` (255 characters) constraint for transit key names to ensure database compatibility. +- Added metrics decoration for `transit` and `tokenization` usecases for improved observability. +- New internal testing helpers and DSN getter functions in the integration test suite. + +### Changed + +- Refactored `tokenization` domain models, repositories, and generators (Alphanumeric, Luhn, Numeric) for better maintainability and performance. +- Reorganized `transit` domain models and added comprehensive unit tests for `TransitKey`. +- Updated Go version to 1.26.0 in CI workflows and documentation. + +### Fixed + +- Corrected `rotate-master-key` CLI flags and documentation in scaling guides. +- Improved error handling in `transit` cryptographic operations. + +--- + +## [0.21.0] - 2026-02-28 + +### Added + +- Binary releases are now officially supported as a primary installation method. + +### Changed + +- Refactored CLI command structure: individual command files moved to `cmd/app/commands/` for better maintainability. + +--- + +## [0.20.0] - 2026-02-28 + +### Added + +- Upgraded to Go 1.26.0 + +### Fixed + +- Corrected `verify-audit-logs` CLI documentation for time range validation and output format consistency +- Fixed documentation for master key rotation to clarify environment variable update workflow +- Corrected outdated environment variable names and default values in scaling guides + +--- + ## [0.19.0] - 2026-02-27 ### ⚠️ BREAKING CHANGES diff --git a/go.mod b/go.mod index bf55e4f..51ecc12 100644 --- a/go.mod +++ b/go.mod @@ -4,10 +4,10 @@ go 1.26.0 require ( github.com/allisson/go-env v0.6.0 - github.com/allisson/go-pwdhash v0.3.1 + github.com/allisson/go-pwdhash v0.4.0 github.com/gin-contrib/cors v1.7.6 github.com/gin-contrib/requestid v1.0.5 - github.com/gin-gonic/gin v1.11.0 + github.com/gin-gonic/gin v1.12.0 github.com/go-sql-driver/mysql v1.9.3 github.com/golang-migrate/migrate/v4 v4.19.1 github.com/google/uuid v1.6.0 @@ -20,6 +20,7 @@ require ( go.opentelemetry.io/otel v1.40.0 go.opentelemetry.io/otel/exporters/prometheus v0.62.0 go.opentelemetry.io/otel/metric v1.40.0 + go.opentelemetry.io/otel/sdk v1.40.0 go.opentelemetry.io/otel/sdk/metric v1.40.0 gocloud.dev v0.44.0 gocloud.dev/secrets/hashivault v0.44.0 @@ -58,24 +59,25 @@ require ( github.com/aws/aws-sdk-go-v2/service/sts v1.39.1 // indirect github.com/aws/smithy-go v1.23.2 // indirect github.com/beorn7/perks v1.0.1 // indirect - github.com/bytedance/sonic v1.14.0 // indirect - github.com/bytedance/sonic/loader v0.3.0 // indirect + github.com/bytedance/gopkg v0.1.3 // indirect + github.com/bytedance/sonic v1.15.0 // indirect + github.com/bytedance/sonic/loader v0.5.0 // indirect github.com/ccoveille/go-safecast/v2 v2.0.0 // indirect github.com/cenkalti/backoff/v4 v4.3.0 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/cloudwego/base64x v0.1.6 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/felixge/httpsnoop v1.0.4 // indirect - github.com/gabriel-vasile/mimetype v1.4.9 // indirect + github.com/gabriel-vasile/mimetype v1.4.12 // indirect github.com/gin-contrib/sse v1.1.0 // indirect github.com/go-jose/go-jose/v4 v4.1.1 // indirect github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect - github.com/go-playground/validator/v10 v10.27.0 // indirect + github.com/go-playground/validator/v10 v10.30.1 // indirect github.com/goccy/go-json v0.10.5 // indirect - github.com/goccy/go-yaml v1.18.0 // indirect + github.com/goccy/go-yaml v1.19.2 // indirect github.com/golang-jwt/jwt/v5 v5.3.0 // indirect github.com/google/s2a-go v0.1.9 // indirect github.com/google/wire v0.7.0 // indirect @@ -108,26 +110,23 @@ require ( github.com/prometheus/common v0.67.5 // indirect github.com/prometheus/otlptranslator v1.0.0 // indirect github.com/prometheus/procfs v0.19.2 // indirect - github.com/quic-go/qpack v0.5.1 // indirect - github.com/quic-go/quic-go v0.54.0 // indirect + github.com/quic-go/qpack v0.6.0 // indirect + github.com/quic-go/quic-go v0.59.0 // indirect github.com/ryanuber/go-glob v1.0.0 // indirect github.com/stretchr/objx v0.5.2 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect - github.com/ugorji/go/codec v1.3.0 // indirect + github.com/ugorji/go/codec v1.3.1 // indirect + go.mongodb.org/mongo-driver/v2 v2.5.0 // indirect go.opentelemetry.io/auto/sdk v1.2.1 // indirect go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.62.0 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.62.0 // indirect - go.opentelemetry.io/otel/sdk v1.40.0 // indirect go.opentelemetry.io/otel/trace v1.40.0 // indirect - go.uber.org/mock v0.5.0 // indirect go.yaml.in/yaml/v2 v2.4.3 // indirect - golang.org/x/arch v0.20.0 // indirect - golang.org/x/mod v0.32.0 // indirect - golang.org/x/net v0.49.0 // indirect + golang.org/x/arch v0.22.0 // indirect + golang.org/x/net v0.51.0 // indirect golang.org/x/oauth2 v0.34.0 // indirect golang.org/x/sys v0.41.0 // indirect golang.org/x/text v0.34.0 // indirect - golang.org/x/tools v0.41.0 // indirect golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da // indirect google.golang.org/api v0.247.0 // indirect google.golang.org/genproto v0.0.0-20250715232539-7130f93afb79 // indirect diff --git a/go.sum b/go.sum index 81f6f32..f0e9d3b 100644 --- a/go.sum +++ b/go.sum @@ -36,8 +36,8 @@ github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERo github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= github.com/allisson/go-env v0.6.0 h1:YaWmnOjhF+0c7GjgJef4LC0XymV12EIoVxJHpHGnGnU= github.com/allisson/go-env v0.6.0/go.mod h1:9XxzBNupzMpZ329C9ZPKIhyI7uCIyhST+/rOFvJpdjQ= -github.com/allisson/go-pwdhash v0.3.1 h1:UzR/0V77E6l63fV6EuAUj0nj1S2jdGADzgoO7UBgaT0= -github.com/allisson/go-pwdhash v0.3.1/go.mod h1:qMlMlCyJ2zwSV8Df406IKgY4VC/39FpiaLamOmZezYU= +github.com/allisson/go-pwdhash v0.4.0 h1:mmiKeXJbykz7xfEOZO+hlqbsxky6OjU3GEiSAbAUrhk= +github.com/allisson/go-pwdhash v0.4.0/go.mod h1:9gKQarlO+tD+rUst0gRT9PD8fgIKsjfV6KeByX9qogM= github.com/asaskevich/govalidator v0.0.0-20210307081110-f21760c49a8d h1:Byv0BzEl3/e6D5CLfI0j/7hiIEtvGVFPCZ7Ei2oq8iQ= github.com/asaskevich/govalidator v0.0.0-20210307081110-f21760c49a8d/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw= github.com/aws/aws-sdk-go-v2 v1.39.6 h1:2JrPCVgWJm7bm83BDwY5z8ietmeJUbh3O2ACnn+Xsqk= @@ -70,10 +70,12 @@ github.com/aws/smithy-go v1.23.2 h1:Crv0eatJUQhaManss33hS5r40CG3ZFH+21XSkqMrIUM= github.com/aws/smithy-go v1.23.2/go.mod h1:LEj2LM3rBRQJxPZTB4KuzZkaZYnZPnvgIhb4pu07mx0= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= -github.com/bytedance/sonic v1.14.0 h1:/OfKt8HFw0kh2rj8N0F6C/qPGRESq0BbaNZgcNXXzQQ= -github.com/bytedance/sonic v1.14.0/go.mod h1:WoEbx8WTcFJfzCe0hbmyTGrfjt8PzNEBdxlNUO24NhA= -github.com/bytedance/sonic/loader v0.3.0 h1:dskwH8edlzNMctoruo8FPTJDF3vLtDT0sXZwvZJyqeA= -github.com/bytedance/sonic/loader v0.3.0/go.mod h1:N8A3vUdtUebEY2/VQC0MyhYeKUFosQU6FxH2JmUe6VI= +github.com/bytedance/gopkg v0.1.3 h1:TPBSwH8RsouGCBcMBktLt1AymVo2TVsBVCY4b6TnZ/M= +github.com/bytedance/gopkg v0.1.3/go.mod h1:576VvJ+eJgyCzdjS+c4+77QF3p7ubbtiKARP3TxducM= +github.com/bytedance/sonic v1.15.0 h1:/PXeWFaR5ElNcVE84U0dOHjiMHQOwNIx3K4ymzh/uSE= +github.com/bytedance/sonic v1.15.0/go.mod h1:tFkWrPz0/CUCLEF4ri4UkHekCIcdnkqXw9VduqpJh0k= +github.com/bytedance/sonic/loader v0.5.0 h1:gXH3KVnatgY7loH5/TkeVyXPfESoqSBSBEiDd5VjlgE= +github.com/bytedance/sonic/loader v0.5.0/go.mod h1:AR4NYCk5DdzZizZ5djGqQ92eEhCCcdf5x77udYiSJRo= github.com/ccoveille/go-safecast/v2 v2.0.0 h1:+5eyITXAUj3wMjad6cRVJKGnC7vDS55zk0INzJagub0= github.com/ccoveille/go-safecast/v2 v2.0.0/go.mod h1:JIYA4CAR33blIDuE6fSwCp2sz1oOBahXnvmdBhOAABs= github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= @@ -106,16 +108,16 @@ github.com/fatih/color v1.16.0 h1:zmkK9Ngbjj+K0yRhTVONQh1p/HknKYSlNT+vZCzyokM= github.com/fatih/color v1.16.0/go.mod h1:fL2Sau1YI5c0pdGEVCbKQbLXB6edEj1ZgiY4NijnWvE= github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= -github.com/gabriel-vasile/mimetype v1.4.9 h1:5k+WDwEsD9eTLL8Tz3L0VnmVh9QxGjRmjBvAG7U/oYY= -github.com/gabriel-vasile/mimetype v1.4.9/go.mod h1:WnSQhFKJuBlRyLiKohA/2DtIlPFAbguNaG7QCHcyGok= +github.com/gabriel-vasile/mimetype v1.4.12 h1:e9hWvmLYvtp846tLHam2o++qitpguFiYCKbn0w9jyqw= +github.com/gabriel-vasile/mimetype v1.4.12/go.mod h1:d+9Oxyo1wTzWdyVUPMmXFvp4F9tea18J8ufA774AB3s= github.com/gin-contrib/cors v1.7.6 h1:3gQ8GMzs1Ylpf70y8bMw4fVpycXIeX1ZemuSQIsnQQY= github.com/gin-contrib/cors v1.7.6/go.mod h1:Ulcl+xN4jel9t1Ry8vqph23a60FwH9xVLd+3ykmTjOk= github.com/gin-contrib/requestid v1.0.5 h1:oye4jWPpTmJHLepQWzb36lFZkKzl+gf8R0K/ButxJUY= github.com/gin-contrib/requestid v1.0.5/go.mod h1:vkfMTJPx8IBXnavnuQSM9j5isaQfNja1f1hTB516ilU= github.com/gin-contrib/sse v1.1.0 h1:n0w2GMuUpWDVp7qSpvze6fAu9iRxJY4Hmj6AmBOU05w= github.com/gin-contrib/sse v1.1.0/go.mod h1:hxRZ5gVpWMT7Z0B0gSNYqqsSCNIJMjzvm6fqCz9vjwM= -github.com/gin-gonic/gin v1.11.0 h1:OW/6PLjyusp2PPXtyxKHU0RbX6I/l28FTdDlae5ueWk= -github.com/gin-gonic/gin v1.11.0/go.mod h1:+iq/FyxlGzII0KHiBGjuNn4UNENUlKbGlNmc+W50Dls= +github.com/gin-gonic/gin v1.12.0 h1:b3YAbrZtnf8N//yjKeU2+MQsh2mY5htkZidOM7O0wG8= +github.com/gin-gonic/gin v1.12.0/go.mod h1:VxccKfsSllpKshkBWgVgRniFFAzFb9csfngsqANjnLc= github.com/go-jose/go-jose/v4 v4.1.1 h1:JYhSgy4mXXzAdF3nUx3ygx347LRXJRrpgyU3adRmkAI= github.com/go-jose/go-jose/v4 v4.1.1/go.mod h1:BdsZGqgdO3b6tTc6LSE56wcDbMMLuPsw5d4ZD5f94kA= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= @@ -129,16 +131,16 @@ github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/o github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= -github.com/go-playground/validator/v10 v10.27.0 h1:w8+XrWVMhGkxOaaowyKH35gFydVHOvC0/uWoy2Fzwn4= -github.com/go-playground/validator/v10 v10.27.0/go.mod h1:I5QpIEbmr8On7W0TktmJAumgzX4CA1XNl4ZmDuVHKKo= +github.com/go-playground/validator/v10 v10.30.1 h1:f3zDSN/zOma+w6+1Wswgd9fLkdwy06ntQJp0BBvFG0w= +github.com/go-playground/validator/v10 v10.30.1/go.mod h1:oSuBIQzuJxL//3MelwSLD5hc2Tu889bF0Idm9Dg26cM= github.com/go-sql-driver/mysql v1.9.3 h1:U/N249h2WzJ3Ukj8SowVFjdtZKfu9vlLZxjPXV1aweo= github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU= github.com/go-test/deep v1.0.2 h1:onZX1rnHT3Wv6cqNgYyFOOlgVKJrksuCMCRvJStbMYw= github.com/go-test/deep v1.0.2/go.mod h1:wGDj63lr65AM2AQyKZd/NYHGb0R+1RLqB8NKt3aSFNA= github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4= github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= -github.com/goccy/go-yaml v1.18.0 h1:8W7wMFS12Pcas7KU+VVkaiCng+kG8QiFeFwzFb+rwuw= -github.com/goccy/go-yaml v1.18.0/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA= +github.com/goccy/go-yaml v1.19.2 h1:PmFC1S6h8ljIz6gMRBopkjP1TVT7xuwrButHID66PoM= +github.com/goccy/go-yaml v1.19.2/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= @@ -255,10 +257,10 @@ github.com/prometheus/otlptranslator v1.0.0 h1:s0LJW/iN9dkIH+EnhiD3BlkkP5QVIUVEo github.com/prometheus/otlptranslator v1.0.0/go.mod h1:vRYWnXvI6aWGpsdY/mOT/cbeVRBlPWtBNDb7kGR3uKM= github.com/prometheus/procfs v0.19.2 h1:zUMhqEW66Ex7OXIiDkll3tl9a1ZdilUOd/F6ZXw4Vws= github.com/prometheus/procfs v0.19.2/go.mod h1:M0aotyiemPhBCM0z5w87kL22CxfcH05ZpYlu+b4J7mw= -github.com/quic-go/qpack v0.5.1 h1:giqksBPnT/HDtZ6VhtFKgoLOWmlyo9Ei6u9PqzIMbhI= -github.com/quic-go/qpack v0.5.1/go.mod h1:+PC4XFrEskIVkcLzpEkbLqq1uCoxPhQuvK5rH1ZgaEg= -github.com/quic-go/quic-go v0.54.0 h1:6s1YB9QotYI6Ospeiguknbp2Znb/jZYjZLRXn9kMQBg= -github.com/quic-go/quic-go v0.54.0/go.mod h1:e68ZEaCdyviluZmy44P6Iey98v/Wfz6HCjQEm+l8zTY= +github.com/quic-go/qpack v0.6.0 h1:g7W+BMYynC1LbYLSqRt8PBg5Tgwxn214ZZR34VIOjz8= +github.com/quic-go/qpack v0.6.0/go.mod h1:lUpLKChi8njB4ty2bFLX2x4gzDqXwUpaO1DP9qMDZII= +github.com/quic-go/quic-go v0.59.0 h1:OLJkp1Mlm/aS7dpKgTc6cnpynnD2Xg7C1pwL6vy/SAw= +github.com/quic-go/quic-go v0.59.0/go.mod h1:upnsH4Ju1YkqpLXC305eW3yDZ4NfnNbmQRCMWS58IKU= github.com/redis/go-redis/v9 v9.8.0 h1:q3nRvjrlge/6UD7eTu/DSg2uYiU2mCL0G/uzBWqhicI= github.com/redis/go-redis/v9 v9.8.0/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw= github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= @@ -274,15 +276,18 @@ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= -github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= -github.com/ugorji/go/codec v1.3.0 h1:Qd2W2sQawAfG8XSvzwhBeoGq71zXOC/Q1E9y/wUcsUA= -github.com/ugorji/go/codec v1.3.0/go.mod h1:pRBVtBSKl77K30Bv8R2P+cLSGaTtex6fsA2Wjqmfxj4= +github.com/ugorji/go/codec v1.3.1 h1:waO7eEiFDwidsBN6agj1vJQ4AG7lh2yqXyOXqhgQuyY= +github.com/ugorji/go/codec v1.3.1/go.mod h1:pRBVtBSKl77K30Bv8R2P+cLSGaTtex6fsA2Wjqmfxj4= github.com/urfave/cli/v3 v3.6.2 h1:lQuqiPrZ1cIz8hz+HcrG0TNZFxU70dPZ3Yl+pSrH9A8= github.com/urfave/cli/v3 v3.6.2/go.mod h1:ysVLtOEmg2tOy6PknnYVhDoouyC/6N42TMeoMzskhso= +go.mongodb.org/mongo-driver/v2 v2.5.0 h1:yXUhImUjjAInNcpTcAlPHiT7bIXhshCTL3jVBkF3xaE= +go.mongodb.org/mongo-driver/v2 v2.5.0/go.mod h1:yOI9kBsufol30iFsl1slpdq1I0eHPzybRWdyYUs8K/0= go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64= go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y= go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.62.0 h1:rbRJ8BBoVMsQShESYZ0FkvcITu8X8QNwJogcLUmDNNw= @@ -303,22 +308,20 @@ go.opentelemetry.io/otel/trace v1.40.0 h1:WA4etStDttCSYuhwvEa8OP8I5EWu24lkOzp+ZY go.opentelemetry.io/otel/trace v1.40.0/go.mod h1:zeAhriXecNGP/s2SEG3+Y8X9ujcJOTqQ5RgdEJcawiA= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= -go.uber.org/mock v0.5.0 h1:KAMbZvZPyBPWgD14IrIQ38QCyjwpvVVV6K/bHl1IwQU= -go.uber.org/mock v0.5.0/go.mod h1:ge71pBPLYDk7QIi1LupWxdAykm7KIEFchiOqd6z7qMM= +go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y= +go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU= go.yaml.in/yaml/v2 v2.4.3 h1:6gvOSjQoTB3vt1l+CU+tSyi/HOjfOjRLJ4YwYZGwRO0= go.yaml.in/yaml/v2 v2.4.3/go.mod h1:zSxWcmIDjOzPXpjlTTbAsKokqkDNAVtZO0WOMiT90s8= gocloud.dev v0.44.0 h1:iVyMAqFl2r6xUy7M4mfqwlN+21UpJoEtgHEcfiLMUXs= gocloud.dev v0.44.0/go.mod h1:ZmjROXGdC/eKZLF1N+RujDlFRx3D+4Av2thREKDMVxY= gocloud.dev/secrets/hashivault v0.44.0 h1:Zwd+EdSQ30BIaGS1w5aRTQUZDNCL133ollkns2RIzFo= gocloud.dev/secrets/hashivault v0.44.0/go.mod h1:GRdIFK5paMZbXftw36rMJQ95CtZ/rEn/+G0G15XTMXU= -golang.org/x/arch v0.20.0 h1:dx1zTU0MAE98U+TQ8BLl7XsJbgze2WnNKF/8tGp/Q6c= -golang.org/x/arch v0.20.0/go.mod h1:bdwinDaKcfZUGpH09BB7ZmOfhalA8lQdzl62l8gGWsk= +golang.org/x/arch v0.22.0 h1:c/Zle32i5ttqRXjdLyyHZESLD/bB90DCU1g9l/0YBDI= +golang.org/x/arch v0.22.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A= golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts= golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos= -golang.org/x/mod v0.32.0 h1:9F4d3PHLljb6x//jOyokMv3eX+YDeepZSEo3mFJy93c= -golang.org/x/mod v0.32.0/go.mod h1:SgipZ/3h2Ci89DlEtEXWUk/HteuRin+HHhN+WbNhguU= -golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o= -golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8= +golang.org/x/net v0.51.0 h1:94R/GTO7mt3/4wIKpcR5gkGmRLOuE/2hNGeWq/GBIFo= +golang.org/x/net v0.51.0/go.mod h1:aamm+2QF5ogm02fjy5Bb7CQ0WMt1/WVM7FtyaTLlA9Y= golang.org/x/oauth2 v0.34.0 h1:hqK/t4AKgbqWkdkcAeI8XLmbK+4m4G5YeQRrmiotGlw= golang.org/x/oauth2 v0.34.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= @@ -331,8 +334,6 @@ golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk= golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA= golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI= golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4= -golang.org/x/tools v0.41.0 h1:a9b8iMweWG+S0OBnlU36rzLp20z1Rp10w+IY2czHTQc= -golang.org/x/tools v0.41.0/go.mod h1:XSY6eDqxVNiYgezAVqqCeihT4j1U2CCsqvH3WhQpnlg= golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da h1:noIWHXmPHxILtqtCOPIhSt0ABwskkZKjD3bXGnZGpNY= golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da/go.mod h1:NDW/Ps6MPRej6fsCIbMTohpP40sJ/P/vI1MoTEGwX90= google.golang.org/api v0.247.0 h1:tSd/e0QrUlLsrwMKmkbQhYVa109qIintOls2Wh6bngc= diff --git a/internal/app/di.go b/internal/app/di.go index e8f3aab..3695e71 100644 --- a/internal/app/di.go +++ b/internal/app/di.go @@ -10,30 +10,20 @@ import ( "sync" authHTTP "github.com/allisson/secrets/internal/auth/http" - authMySQL "github.com/allisson/secrets/internal/auth/repository/mysql" - authPostgreSQL "github.com/allisson/secrets/internal/auth/repository/postgresql" authService "github.com/allisson/secrets/internal/auth/service" authUseCase "github.com/allisson/secrets/internal/auth/usecase" "github.com/allisson/secrets/internal/config" cryptoDomain "github.com/allisson/secrets/internal/crypto/domain" - cryptoMySQL "github.com/allisson/secrets/internal/crypto/repository/mysql" - cryptoPostgreSQL "github.com/allisson/secrets/internal/crypto/repository/postgresql" cryptoService "github.com/allisson/secrets/internal/crypto/service" cryptoUseCase "github.com/allisson/secrets/internal/crypto/usecase" "github.com/allisson/secrets/internal/database" "github.com/allisson/secrets/internal/http" "github.com/allisson/secrets/internal/metrics" secretsHTTP "github.com/allisson/secrets/internal/secrets/http" - secretsMySQL "github.com/allisson/secrets/internal/secrets/repository/mysql" - secretsPostgreSQL "github.com/allisson/secrets/internal/secrets/repository/postgresql" secretsUseCase "github.com/allisson/secrets/internal/secrets/usecase" tokenizationHTTP "github.com/allisson/secrets/internal/tokenization/http" - tokenizationMySQL "github.com/allisson/secrets/internal/tokenization/repository/mysql" - tokenizationPostgreSQL "github.com/allisson/secrets/internal/tokenization/repository/postgresql" tokenizationUseCase "github.com/allisson/secrets/internal/tokenization/usecase" transitHTTP "github.com/allisson/secrets/internal/transit/http" - transitMySQL "github.com/allisson/secrets/internal/transit/repository/mysql" - transitPostgreSQL "github.com/allisson/secrets/internal/transit/repository/postgresql" transitUseCase "github.com/allisson/secrets/internal/transit/usecase" ) @@ -100,8 +90,7 @@ type Container struct { httpServer *http.Server metricsServer *http.MetricsServer - // Initialization flags and mutex for thread-safety - mu sync.Mutex + // Initialization flags and sync.Once for thread-safety loggerInit sync.Once dbInit sync.Once masterKeyChainInit sync.Once @@ -186,24 +175,6 @@ func (c *Container) DB() (*sql.DB, error) { return c.db, nil } -// MasterKeyChain returns the master key chain loaded from environment variables. -func (c *Container) MasterKeyChain() (*cryptoDomain.MasterKeyChain, error) { - var err error - c.masterKeyChainInit.Do(func() { - c.masterKeyChain, err = c.initMasterKeyChain() - if err != nil { - c.initErrors["masterKeyChain"] = err - } - }) - if err != nil { - return nil, err - } - if storedErr, exists := c.initErrors["masterKeyChain"]; exists { - return nil, storedErr - } - return c.masterKeyChain, nil -} - // TxManager returns the transaction manager. func (c *Container) TxManager() (database.TxManager, error) { var err error @@ -258,102 +229,6 @@ func (c *Container) BusinessMetrics() (metrics.BusinessMetrics, error) { return c.businessMetrics, nil } -// AEADManager returns the AEAD manager service. -func (c *Container) AEADManager() cryptoService.AEADManager { - c.aeadManagerInit.Do(func() { - c.aeadManager = c.initAEADManager() - }) - return c.aeadManager -} - -// KeyManager returns the key manager service. -func (c *Container) KeyManager() cryptoService.KeyManager { - c.keyManagerInit.Do(func() { - c.keyManager = c.initKeyManager() - }) - return c.keyManager -} - -// KMSService returns the KMS service. -func (c *Container) KMSService() cryptoService.KMSService { - c.kmsServiceInit.Do(func() { - c.kmsService = c.initKMSService() - }) - return c.kmsService -} - -// KekRepository returns the KEK repository. -func (c *Container) KekRepository() (cryptoUseCase.KekRepository, error) { - var err error - c.kekRepositoryInit.Do(func() { - c.kekRepository, err = c.initKekRepository() - if err != nil { - c.initErrors["kekRepository"] = err - } - }) - if err != nil { - return nil, err - } - if storedErr, exists := c.initErrors["kekRepository"]; exists { - return nil, storedErr - } - return c.kekRepository, nil -} - -// KekUseCase returns the KEK use case. -func (c *Container) KekUseCase() (cryptoUseCase.KekUseCase, error) { - var err error - c.kekUseCaseInit.Do(func() { - c.kekUseCase, err = c.initKekUseCase() - if err != nil { - c.initErrors["kekUseCase"] = err - } - }) - if err != nil { - return nil, err - } - if storedErr, exists := c.initErrors["kekUseCase"]; exists { - return nil, storedErr - } - return c.kekUseCase, nil -} - -// CryptoDekRepository returns the DEK repository for the crypto use case based on database driver. -func (c *Container) CryptoDekRepository() (cryptoUseCase.DekRepository, error) { - var err error - c.cryptoDekRepositoryInit.Do(func() { - c.cryptoDekRepository, err = c.initCryptoDekRepository() - if err != nil { - c.initErrors["cryptoDekRepository"] = err - } - }) - if err != nil { - return nil, err - } - if storedErr, exists := c.initErrors["cryptoDekRepository"]; exists { - return nil, storedErr - } - return c.cryptoDekRepository, nil -} - -// CryptoDekUseCase returns the DEK use case for the crypto module. -func (c *Container) CryptoDekUseCase() (cryptoUseCase.DekUseCase, error) { - var err error - c.cryptoDekUseCaseInit.Do(func() { - c.cryptoDekUseCase, err = c.initCryptoDekUseCase() - if err != nil { - c.initErrors["cryptoDekUseCase"] = err - } - }) - if err != nil { - return nil, err - } - if storedErr, exists := c.initErrors["cryptoDekUseCase"]; exists { - return nil, storedErr - } - return c.cryptoDekUseCase, nil -} - // HTTPServer returns the HTTP server instance. func (c *Container) HTTPServer() (*http.Server, error) { var err error @@ -390,261 +265,8 @@ func (c *Container) MetricsServer() (*http.MetricsServer, error) { return c.metricsServer, nil } -// SecretService returns the secret service for authentication operations. -func (c *Container) SecretService() authService.SecretService { - c.secretServiceInit.Do(func() { - c.secretService = c.initSecretService() - }) - return c.secretService -} - -// ClientRepository returns the client repository based on database driver. -func (c *Container) ClientRepository() (authUseCase.ClientRepository, error) { - var err error - c.clientRepositoryInit.Do(func() { - c.clientRepository, err = c.initClientRepository() - if err != nil { - c.initErrors["clientRepository"] = err - } - }) - if err != nil { - return nil, err - } - if storedErr, exists := c.initErrors["clientRepository"]; exists { - return nil, storedErr - } - return c.clientRepository, nil -} - -// ClientUseCase returns the client use case. -func (c *Container) ClientUseCase() (authUseCase.ClientUseCase, error) { - var err error - c.clientUseCaseInit.Do(func() { - c.clientUseCase, err = c.initClientUseCase() - if err != nil { - c.initErrors["clientUseCase"] = err - } - }) - if err != nil { - return nil, err - } - if storedErr, exists := c.initErrors["clientUseCase"]; exists { - return nil, storedErr - } - return c.clientUseCase, nil -} - -// TokenService returns the token service for authentication operations. -func (c *Container) TokenService() authService.TokenService { - c.tokenServiceInit.Do(func() { - c.tokenService = c.initTokenService() - }) - return c.tokenService -} - -// TokenRepository returns the token repository based on database driver. -func (c *Container) TokenRepository() (authUseCase.TokenRepository, error) { - var err error - c.tokenRepositoryInit.Do(func() { - c.tokenRepository, err = c.initTokenRepository() - if err != nil { - c.initErrors["tokenRepository"] = err - } - }) - if err != nil { - return nil, err - } - if storedErr, exists := c.initErrors["tokenRepository"]; exists { - return nil, storedErr - } - return c.tokenRepository, nil -} - -// AuditLogRepository returns the audit log repository based on database driver. -func (c *Container) AuditLogRepository() (authUseCase.AuditLogRepository, error) { - var err error - c.auditLogRepositoryInit.Do(func() { - c.auditLogRepository, err = c.initAuditLogRepository() - if err != nil { - c.initErrors["auditLogRepository"] = err - } - }) - if err != nil { - return nil, err - } - if storedErr, exists := c.initErrors["auditLogRepository"]; exists { - return nil, storedErr - } - return c.auditLogRepository, nil -} - -// TokenUseCase returns the token use case. -func (c *Container) TokenUseCase() (authUseCase.TokenUseCase, error) { - var err error - c.tokenUseCaseInit.Do(func() { - c.tokenUseCase, err = c.initTokenUseCase() - if err != nil { - c.initErrors["tokenUseCase"] = err - } - }) - if err != nil { - return nil, err - } - if storedErr, exists := c.initErrors["tokenUseCase"]; exists { - return nil, storedErr - } - return c.tokenUseCase, nil -} - -// AuditLogUseCase returns the audit log use case. -func (c *Container) AuditLogUseCase() (authUseCase.AuditLogUseCase, error) { - var err error - c.auditLogUseCaseInit.Do(func() { - c.auditLogUseCase, err = c.initAuditLogUseCase() - if err != nil { - c.initErrors["auditLogUseCase"] = err - } - }) - if err != nil { - return nil, err - } - if storedErr, exists := c.initErrors["auditLogUseCase"]; exists { - return nil, storedErr - } - return c.auditLogUseCase, nil -} - -// ClientHandler returns the HTTP handler for client management operations. -func (c *Container) ClientHandler() (*authHTTP.ClientHandler, error) { - var err error - c.clientHandlerInit.Do(func() { - c.clientHandler, err = c.initClientHandler() - if err != nil { - c.initErrors["clientHandler"] = err - } - }) - if err != nil { - return nil, err - } - if storedErr, exists := c.initErrors["clientHandler"]; exists { - return nil, storedErr - } - return c.clientHandler, nil -} - -// TokenHandler returns the HTTP handler for token operations. -func (c *Container) TokenHandler() (*authHTTP.TokenHandler, error) { - var err error - c.tokenHandlerInit.Do(func() { - c.tokenHandler, err = c.initTokenHandler() - if err != nil { - c.initErrors["tokenHandler"] = err - } - }) - if err != nil { - return nil, err - } - if storedErr, exists := c.initErrors["tokenHandler"]; exists { - return nil, storedErr - } - return c.tokenHandler, nil -} - -// AuditLogHandler returns the HTTP handler for audit log operations. -func (c *Container) AuditLogHandler() (*authHTTP.AuditLogHandler, error) { - var err error - c.auditLogHandlerInit.Do(func() { - c.auditLogHandler, err = c.initAuditLogHandler() - if err != nil { - c.initErrors["auditLogHandler"] = err - } - }) - if err != nil { - return nil, err - } - if storedErr, exists := c.initErrors["auditLogHandler"]; exists { - return nil, storedErr - } - return c.auditLogHandler, nil -} - -// DekRepository returns the DEK repository based on database driver. -func (c *Container) DekRepository() (secretsUseCase.DekRepository, error) { - var err error - c.dekRepositoryInit.Do(func() { - c.dekRepository, err = c.initDekRepository() - if err != nil { - c.initErrors["dekRepository"] = err - } - }) - if err != nil { - return nil, err - } - if storedErr, exists := c.initErrors["dekRepository"]; exists { - return nil, storedErr - } - return c.dekRepository, nil -} - -// SecretRepository returns the secret repository based on database driver. -func (c *Container) SecretRepository() (secretsUseCase.SecretRepository, error) { - var err error - c.secretRepositoryInit.Do(func() { - c.secretRepository, err = c.initSecretRepository() - if err != nil { - c.initErrors["secretRepository"] = err - } - }) - if err != nil { - return nil, err - } - if storedErr, exists := c.initErrors["secretRepository"]; exists { - return nil, storedErr - } - return c.secretRepository, nil -} - -// SecretUseCase returns the secret use case. -func (c *Container) SecretUseCase() (secretsUseCase.SecretUseCase, error) { - var err error - c.secretUseCaseInit.Do(func() { - c.secretUseCase, err = c.initSecretUseCase() - if err != nil { - c.initErrors["secretUseCase"] = err - } - }) - if err != nil { - return nil, err - } - if storedErr, exists := c.initErrors["secretUseCase"]; exists { - return nil, storedErr - } - return c.secretUseCase, nil -} - -// SecretHandler returns the HTTP handler for secret management operations. -func (c *Container) SecretHandler() (*secretsHTTP.SecretHandler, error) { - var err error - c.secretHandlerInit.Do(func() { - c.secretHandler, err = c.initSecretHandler() - if err != nil { - c.initErrors["secretHandler"] = err - } - }) - if err != nil { - return nil, err - } - if storedErr, exists := c.initErrors["secretHandler"]; exists { - return nil, storedErr - } - return c.secretHandler, nil -} - // Shutdown performs cleanup of all initialized resources. func (c *Container) Shutdown(ctx context.Context) error { - c.mu.Lock() - defer c.mu.Unlock() - var shutdownErrors []error // Shutdown HTTP server if initialized @@ -726,25 +348,6 @@ func (c *Container) initDB() (*sql.DB, error) { return db, nil } -// initMasterKeyChain loads the master key chain from environment variables. -func (c *Container) initMasterKeyChain() (*cryptoDomain.MasterKeyChain, error) { - // Get KMS service and logger - kmsService := c.KMSService() - logger := c.Logger() - - // Load master key chain with KMS support and fail-fast validation - masterKeyChain, err := cryptoDomain.LoadMasterKeyChain( - context.Background(), - c.config, - kmsService, - logger, - ) - if err != nil { - return nil, fmt.Errorf("failed to load master key chain: %w", err) - } - return masterKeyChain, nil -} - // initTxManager creates the transaction manager using the database connection. func (c *Container) initTxManager() (database.TxManager, error) { db, err := c.DB() @@ -883,894 +486,6 @@ func (c *Container) initHTTPServer() (*http.Server, error) { return server, nil } -// initAEADManager creates the AEAD manager service. -func (c *Container) initAEADManager() cryptoService.AEADManager { - return cryptoService.NewAEADManager() -} - -// initKeyManager creates the key manager service using the AEAD manager. -func (c *Container) initKeyManager() cryptoService.KeyManager { - aeadManager := c.AEADManager() - return cryptoService.NewKeyManager(aeadManager) -} - -// initKMSService creates the KMS service for encrypting/decrypting master keys. -func (c *Container) initKMSService() cryptoService.KMSService { - return cryptoService.NewKMSService() -} - -// initKekRepository creates the KEK repository based on the database driver. -func (c *Container) initKekRepository() (cryptoUseCase.KekRepository, error) { - db, err := c.DB() - if err != nil { - return nil, fmt.Errorf("failed to get database for kek repository: %w", err) - } - - switch c.config.DBDriver { - case "postgres": - return cryptoPostgreSQL.NewPostgreSQLKekRepository(db), nil - case "mysql": - return cryptoMySQL.NewMySQLKekRepository(db), nil - default: - return nil, fmt.Errorf("unsupported database driver: %s", c.config.DBDriver) - } -} - -// initKekUseCase creates the KEK use case with all its dependencies. -func (c *Container) initKekUseCase() (cryptoUseCase.KekUseCase, error) { - txManager, err := c.TxManager() - if err != nil { - return nil, fmt.Errorf("failed to get tx manager for kek use case: %w", err) - } - - kekRepository, err := c.KekRepository() - if err != nil { - return nil, fmt.Errorf("failed to get kek repository for kek use case: %w", err) - } - - keyManager := c.KeyManager() - - return cryptoUseCase.NewKekUseCase(txManager, kekRepository, keyManager), nil -} - -// initSecretService creates the secret service for authentication. -func (c *Container) initSecretService() authService.SecretService { - return authService.NewSecretService() -} - -// initClientRepository creates the client repository based on the database driver. -func (c *Container) initClientRepository() (authUseCase.ClientRepository, error) { - db, err := c.DB() - if err != nil { - return nil, fmt.Errorf("failed to get database for client repository: %w", err) - } - - switch c.config.DBDriver { - case "postgres": - return authPostgreSQL.NewPostgreSQLClientRepository(db), nil - case "mysql": - return authMySQL.NewMySQLClientRepository(db), nil - default: - return nil, fmt.Errorf("unsupported database driver: %s", c.config.DBDriver) - } -} - -// initClientUseCase creates the client use case with all its dependencies. -func (c *Container) initClientUseCase() (authUseCase.ClientUseCase, error) { - txManager, err := c.TxManager() - if err != nil { - return nil, fmt.Errorf("failed to get tx manager for client use case: %w", err) - } - - clientRepository, err := c.ClientRepository() - if err != nil { - return nil, fmt.Errorf("failed to get client repository for client use case: %w", err) - } - - secretService := c.SecretService() - - baseUseCase := authUseCase.NewClientUseCase(txManager, clientRepository, secretService) - - // Wrap with metrics if enabled - if c.config.MetricsEnabled { - businessMetrics, err := c.BusinessMetrics() - if err != nil { - return nil, fmt.Errorf("failed to get business metrics for client use case: %w", err) - } - return authUseCase.NewClientUseCaseWithMetrics(baseUseCase, businessMetrics), nil - } - - return baseUseCase, nil -} - -// initTokenService creates the token service for authentication. -func (c *Container) initTokenService() authService.TokenService { - return authService.NewTokenService() -} - -// initTokenRepository creates the token repository based on the database driver. -func (c *Container) initTokenRepository() (authUseCase.TokenRepository, error) { - db, err := c.DB() - if err != nil { - return nil, fmt.Errorf("failed to get database for token repository: %w", err) - } - - switch c.config.DBDriver { - case "postgres": - return authPostgreSQL.NewPostgreSQLTokenRepository(db), nil - case "mysql": - return authMySQL.NewMySQLTokenRepository(db), nil - default: - return nil, fmt.Errorf("unsupported database driver: %s", c.config.DBDriver) - } -} - -// initAuditLogRepository creates the audit log repository based on the database driver. -func (c *Container) initAuditLogRepository() (authUseCase.AuditLogRepository, error) { - db, err := c.DB() - if err != nil { - return nil, fmt.Errorf("failed to get database for audit log repository: %w", err) - } - - switch c.config.DBDriver { - case "postgres": - return authPostgreSQL.NewPostgreSQLAuditLogRepository(db), nil - case "mysql": - return authMySQL.NewMySQLAuditLogRepository(db), nil - default: - return nil, fmt.Errorf("unsupported database driver: %s", c.config.DBDriver) - } -} - -// initTokenUseCase creates the token use case with all its dependencies. -func (c *Container) initTokenUseCase() (authUseCase.TokenUseCase, error) { - clientRepository, err := c.ClientRepository() - if err != nil { - return nil, fmt.Errorf("failed to get client repository for token use case: %w", err) - } - - tokenRepository, err := c.TokenRepository() - if err != nil { - return nil, fmt.Errorf("failed to get token repository for token use case: %w", err) - } - - secretService := c.SecretService() - tokenService := c.TokenService() - - baseUseCase := authUseCase.NewTokenUseCase( - c.config, - clientRepository, - tokenRepository, - secretService, - tokenService, - ) - - // Wrap with metrics if enabled - if c.config.MetricsEnabled { - businessMetrics, err := c.BusinessMetrics() - if err != nil { - return nil, fmt.Errorf("failed to get business metrics for token use case: %w", err) - } - return authUseCase.NewTokenUseCaseWithMetrics(baseUseCase, businessMetrics), nil - } - - return baseUseCase, nil -} - -// initAuditLogUseCase creates the audit log use case with all its dependencies. -func (c *Container) initAuditLogUseCase() (authUseCase.AuditLogUseCase, error) { - auditLogRepository, err := c.AuditLogRepository() - if err != nil { - return nil, fmt.Errorf("failed to get audit log repository for audit log use case: %w", err) - } - - // Create audit signer service - auditSigner := authService.NewAuditSigner() - - // Load KEK chain for signature verification - kekChain, err := c.loadKekChain() - if err != nil { - return nil, fmt.Errorf("failed to load kek chain for audit log use case: %w", err) - } - - baseUseCase := authUseCase.NewAuditLogUseCase(auditLogRepository, auditSigner, kekChain) - - // Wrap with metrics if enabled - if c.config.MetricsEnabled { - businessMetrics, err := c.BusinessMetrics() - if err != nil { - return nil, fmt.Errorf("failed to get business metrics for audit log use case: %w", err) - } - return authUseCase.NewAuditLogUseCaseWithMetrics(baseUseCase, businessMetrics), nil - } - - return baseUseCase, nil -} - -// initClientHandler creates the client HTTP handler with all its dependencies. -func (c *Container) initClientHandler() (*authHTTP.ClientHandler, error) { - clientUseCase, err := c.ClientUseCase() - if err != nil { - return nil, fmt.Errorf("failed to get client use case for client handler: %w", err) - } - - auditLogUseCase, err := c.AuditLogUseCase() - if err != nil { - return nil, fmt.Errorf("failed to get audit log use case for client handler: %w", err) - } - - logger := c.Logger() - - return authHTTP.NewClientHandler(clientUseCase, auditLogUseCase, logger), nil -} - -// initTokenHandler creates the token HTTP handler with all its dependencies. -func (c *Container) initTokenHandler() (*authHTTP.TokenHandler, error) { - tokenUseCase, err := c.TokenUseCase() - if err != nil { - return nil, fmt.Errorf("failed to get token use case for token handler: %w", err) - } - - logger := c.Logger() - - return authHTTP.NewTokenHandler(tokenUseCase, logger), nil -} - -// initAuditLogHandler creates the audit log HTTP handler with all its dependencies. -func (c *Container) initAuditLogHandler() (*authHTTP.AuditLogHandler, error) { - auditLogUseCase, err := c.AuditLogUseCase() - if err != nil { - return nil, fmt.Errorf("failed to get audit log use case for audit log handler: %w", err) - } - - logger := c.Logger() - - return authHTTP.NewAuditLogHandler(auditLogUseCase, logger), nil -} - -// initCryptoDekRepository creates the DEK repository for crypto use case based on the database driver. -func (c *Container) initCryptoDekRepository() (cryptoUseCase.DekRepository, error) { - db, err := c.DB() - if err != nil { - return nil, fmt.Errorf("failed to get database: %w", err) - } - - switch c.config.DBDriver { - case "postgres": - return cryptoPostgreSQL.NewPostgreSQLDekRepository(db), nil - case "mysql": - return cryptoMySQL.NewMySQLDekRepository(db), nil - default: - return nil, fmt.Errorf("unsupported database driver: %s", c.config.DBDriver) - } -} - -// initCryptoDekUseCase creates the DEK use case for the crypto module. -func (c *Container) initCryptoDekUseCase() (cryptoUseCase.DekUseCase, error) { - txManager, err := c.TxManager() - if err != nil { - return nil, fmt.Errorf("failed to get tx manager: %w", err) - } - - dekRepo, err := c.CryptoDekRepository() - if err != nil { - return nil, fmt.Errorf("failed to get crypto dek repository: %w", err) - } - - keyManager := c.KeyManager() - - return cryptoUseCase.NewDekUseCase(txManager, dekRepo, keyManager), nil -} - -// initDekRepository creates the DEK repository based on the database driver. -func (c *Container) initDekRepository() (secretsUseCase.DekRepository, error) { - db, err := c.DB() - if err != nil { - return nil, fmt.Errorf("failed to get database for dek repository: %w", err) - } - - switch c.config.DBDriver { - case "postgres": - return cryptoPostgreSQL.NewPostgreSQLDekRepository(db), nil - case "mysql": - return cryptoMySQL.NewMySQLDekRepository(db), nil - default: - return nil, fmt.Errorf("unsupported database driver: %s", c.config.DBDriver) - } -} - -// initSecretRepository creates the secret repository based on the database driver. -func (c *Container) initSecretRepository() (secretsUseCase.SecretRepository, error) { - db, err := c.DB() - if err != nil { - return nil, fmt.Errorf("failed to get database for secret repository: %w", err) - } - - switch c.config.DBDriver { - case "postgres": - return secretsPostgreSQL.NewPostgreSQLSecretRepository(db), nil - case "mysql": - return secretsMySQL.NewMySQLSecretRepository(db), nil - default: - return nil, fmt.Errorf("unsupported database driver: %s", c.config.DBDriver) - } -} - -// initSecretUseCase creates the secret use case with all its dependencies. -func (c *Container) initSecretUseCase() (secretsUseCase.SecretUseCase, error) { - txManager, err := c.TxManager() - if err != nil { - return nil, fmt.Errorf("failed to get tx manager for secret use case: %w", err) - } - - dekRepository, err := c.DekRepository() - if err != nil { - return nil, fmt.Errorf("failed to get dek repository for secret use case: %w", err) - } - - secretRepository, err := c.SecretRepository() - if err != nil { - return nil, fmt.Errorf("failed to get secret repository for secret use case: %w", err) - } - - kekChain, err := c.loadKekChain() - if err != nil { - return nil, fmt.Errorf("failed to load kek chain for secret use case: %w", err) - } - - aeadManager := c.AEADManager() - keyManager := c.KeyManager() - - baseUseCase := secretsUseCase.NewSecretUseCase( - txManager, - dekRepository, - secretRepository, - kekChain, - aeadManager, - keyManager, - cryptoDomain.AESGCM, - ) - - // Wrap with metrics if enabled - if c.config.MetricsEnabled { - businessMetrics, err := c.BusinessMetrics() - if err != nil { - return nil, fmt.Errorf("failed to get business metrics for secret use case: %w", err) - } - return secretsUseCase.NewSecretUseCaseWithMetrics(baseUseCase, businessMetrics), nil - } - - return baseUseCase, nil -} - -// initSecretHandler creates the secret HTTP handler with all its dependencies. -func (c *Container) initSecretHandler() (*secretsHTTP.SecretHandler, error) { - secretUseCase, err := c.SecretUseCase() - if err != nil { - return nil, fmt.Errorf("failed to get secret use case for secret handler: %w", err) - } - - auditLogUseCase, err := c.AuditLogUseCase() - if err != nil { - return nil, fmt.Errorf("failed to get audit log use case for secret handler: %w", err) - } - - logger := c.Logger() - - return secretsHTTP.NewSecretHandler(secretUseCase, auditLogUseCase, logger), nil -} - -// loadKekChain loads all KEKs from the database and creates a KEK chain. -func (c *Container) loadKekChain() (*cryptoDomain.KekChain, error) { - kekUseCase, err := c.KekUseCase() - if err != nil { - return nil, fmt.Errorf("failed to get kek use case: %w", err) - } - - masterKeyChain, err := c.MasterKeyChain() - if err != nil { - return nil, fmt.Errorf("failed to get master key chain: %w", err) - } - - // Unwrap all KEKs using the master key chain - kekChain, err := kekUseCase.Unwrap(context.Background(), masterKeyChain) - if err != nil { - return nil, fmt.Errorf("failed to unwrap keks: %w", err) - } - - return kekChain, nil -} - -// TransitKeyRepository returns the transit key repository instance. -func (c *Container) TransitKeyRepository() (transitUseCase.TransitKeyRepository, error) { - var err error - c.transitKeyRepositoryInit.Do(func() { - c.transitKeyRepository, err = c.initTransitKeyRepository() - if err != nil { - c.initErrors["transitKeyRepository"] = err - } - }) - if err != nil { - return nil, err - } - if storedErr, exists := c.initErrors["transitKeyRepository"]; exists { - return nil, storedErr - } - return c.transitKeyRepository, nil -} - -// TransitDekRepository returns the DEK repository for transit use case. -func (c *Container) TransitDekRepository() (transitUseCase.DekRepository, error) { - var err error - c.transitDekRepositoryInit.Do(func() { - c.transitDekRepository, err = c.initTransitDekRepository() - if err != nil { - c.initErrors["transitDekRepository"] = err - } - }) - if err != nil { - return nil, err - } - if storedErr, exists := c.initErrors["transitDekRepository"]; exists { - return nil, storedErr - } - return c.transitDekRepository, nil -} - -// TransitKeyUseCase returns the transit key use case instance. -func (c *Container) TransitKeyUseCase() (transitUseCase.TransitKeyUseCase, error) { - var err error - c.transitKeyUseCaseInit.Do(func() { - c.transitKeyUseCase, err = c.initTransitKeyUseCase() - if err != nil { - c.initErrors["transitKeyUseCase"] = err - } - }) - if err != nil { - return nil, err - } - if storedErr, exists := c.initErrors["transitKeyUseCase"]; exists { - return nil, storedErr - } - return c.transitKeyUseCase, nil -} - -// TransitKeyHandler returns the transit key HTTP handler instance. -func (c *Container) TransitKeyHandler() (*transitHTTP.TransitKeyHandler, error) { - var err error - c.transitKeyHandlerInit.Do(func() { - c.transitKeyHandler, err = c.initTransitKeyHandler() - if err != nil { - c.initErrors["transitKeyHandler"] = err - } - }) - if err != nil { - return nil, err - } - if storedErr, exists := c.initErrors["transitKeyHandler"]; exists { - return nil, storedErr - } - return c.transitKeyHandler, nil -} - -// CryptoHandler returns the crypto HTTP handler instance. -func (c *Container) CryptoHandler() (*transitHTTP.CryptoHandler, error) { - var err error - c.cryptoHandlerInit.Do(func() { - c.cryptoHandler, err = c.initCryptoHandler() - if err != nil { - c.initErrors["cryptoHandler"] = err - } - }) - if err != nil { - return nil, err - } - if storedErr, exists := c.initErrors["cryptoHandler"]; exists { - return nil, storedErr - } - return c.cryptoHandler, nil -} - -// initTransitKeyRepository creates the transit key repository based on the database driver. -func (c *Container) initTransitKeyRepository() (transitUseCase.TransitKeyRepository, error) { - db, err := c.DB() - if err != nil { - return nil, fmt.Errorf("failed to get database for transit key repository: %w", err) - } - - switch c.config.DBDriver { - case "postgres": - return transitPostgreSQL.NewPostgreSQLTransitKeyRepository(db), nil - case "mysql": - return transitMySQL.NewMySQLTransitKeyRepository(db), nil - default: - return nil, fmt.Errorf("unsupported database driver: %s", c.config.DBDriver) - } -} - -// initTransitDekRepository creates the DEK repository for transit use case. -func (c *Container) initTransitDekRepository() (transitUseCase.DekRepository, error) { - db, err := c.DB() - if err != nil { - return nil, fmt.Errorf("failed to get database for transit dek repository: %w", err) - } - - switch c.config.DBDriver { - case "postgres": - return cryptoPostgreSQL.NewPostgreSQLDekRepository(db), nil - case "mysql": - return cryptoMySQL.NewMySQLDekRepository(db), nil - default: - return nil, fmt.Errorf("unsupported database driver: %s", c.config.DBDriver) - } -} - -// initTransitKeyUseCase creates the transit key use case with all its dependencies. -func (c *Container) initTransitKeyUseCase() (transitUseCase.TransitKeyUseCase, error) { - txManager, err := c.TxManager() - if err != nil { - return nil, fmt.Errorf("failed to get tx manager for transit key use case: %w", err) - } - - transitKeyRepository, err := c.TransitKeyRepository() - if err != nil { - return nil, fmt.Errorf("failed to get transit key repository for transit key use case: %w", err) - } - - dekRepository, err := c.TransitDekRepository() - if err != nil { - return nil, fmt.Errorf("failed to get dek repository for transit key use case: %w", err) - } - - kekChain, err := c.loadKekChain() - if err != nil { - return nil, fmt.Errorf("failed to load kek chain for transit key use case: %w", err) - } - - keyManager := c.KeyManager() - aeadManager := c.AEADManager() - - baseUseCase := transitUseCase.NewTransitKeyUseCase( - txManager, - transitKeyRepository, - dekRepository, - keyManager, - aeadManager, - kekChain, - ) - - // Wrap with metrics if enabled - if c.config.MetricsEnabled { - businessMetrics, err := c.BusinessMetrics() - if err != nil { - return nil, fmt.Errorf("failed to get business metrics for transit key use case: %w", err) - } - return transitUseCase.NewTransitKeyUseCaseWithMetrics(baseUseCase, businessMetrics), nil - } - - return baseUseCase, nil -} - -// initTransitKeyHandler creates the transit key HTTP handler with all its dependencies. -func (c *Container) initTransitKeyHandler() (*transitHTTP.TransitKeyHandler, error) { - transitKeyUseCase, err := c.TransitKeyUseCase() - if err != nil { - return nil, fmt.Errorf("failed to get transit key use case for transit key handler: %w", err) - } - - auditLogUseCase, err := c.AuditLogUseCase() - if err != nil { - return nil, fmt.Errorf("failed to get audit log use case for transit key handler: %w", err) - } - - logger := c.Logger() - - return transitHTTP.NewTransitKeyHandler(transitKeyUseCase, auditLogUseCase, logger), nil -} - -// initCryptoHandler creates the crypto HTTP handler with all its dependencies. -func (c *Container) initCryptoHandler() (*transitHTTP.CryptoHandler, error) { - transitKeyUseCase, err := c.TransitKeyUseCase() - if err != nil { - return nil, fmt.Errorf("failed to get transit key use case for crypto handler: %w", err) - } - - auditLogUseCase, err := c.AuditLogUseCase() - if err != nil { - return nil, fmt.Errorf("failed to get audit log use case for crypto handler: %w", err) - } - - logger := c.Logger() - - return transitHTTP.NewCryptoHandler(transitKeyUseCase, auditLogUseCase, logger), nil -} - -// TokenizationKeyRepository returns the tokenization key repository. -func (c *Container) TokenizationKeyRepository() (tokenizationUseCase.TokenizationKeyRepository, error) { - var err error - c.tokenizationKeyRepositoryInit.Do(func() { - c.tokenizationKeyRepository, err = c.initTokenizationKeyRepository() - if err != nil { - c.initErrors["tokenizationKeyRepository"] = err - } - }) - if err != nil { - return nil, err - } - return c.tokenizationKeyRepository, c.initErrors["tokenizationKeyRepository"] -} - -// TokenizationTokenRepository returns the tokenization token repository. -func (c *Container) TokenizationTokenRepository() (tokenizationUseCase.TokenRepository, error) { - var err error - c.tokenizationTokenRepositoryInit.Do(func() { - c.tokenizationTokenRepository, err = c.initTokenizationTokenRepository() - if err != nil { - c.initErrors["tokenizationTokenRepository"] = err - } - }) - if err != nil { - return nil, err - } - return c.tokenizationTokenRepository, c.initErrors["tokenizationTokenRepository"] -} - -// TokenizationDekRepository returns the DEK repository for tokenization use case. -func (c *Container) TokenizationDekRepository() (tokenizationUseCase.DekRepository, error) { - var err error - c.tokenizationDekRepositoryInit.Do(func() { - c.tokenizationDekRepository, err = c.initTokenizationDekRepository() - if err != nil { - c.initErrors["tokenizationDekRepository"] = err - } - }) - if err != nil { - return nil, err - } - return c.tokenizationDekRepository, c.initErrors["tokenizationDekRepository"] -} - -// TokenizationKeyUseCase returns the tokenization key use case. -func (c *Container) TokenizationKeyUseCase() (tokenizationUseCase.TokenizationKeyUseCase, error) { - var err error - c.tokenizationKeyUseCaseInit.Do(func() { - c.tokenizationKeyUseCase, err = c.initTokenizationKeyUseCase() - if err != nil { - c.initErrors["tokenizationKeyUseCase"] = err - } - }) - if err != nil { - return nil, err - } - return c.tokenizationKeyUseCase, c.initErrors["tokenizationKeyUseCase"] -} - -// TokenizationUseCase returns the tokenization use case. -func (c *Container) TokenizationUseCase() (tokenizationUseCase.TokenizationUseCase, error) { - var err error - c.tokenizationUseCaseInit.Do(func() { - c.tokenizationUseCase, err = c.initTokenizationUseCase() - if err != nil { - c.initErrors["tokenizationUseCase"] = err - } - }) - if err != nil { - return nil, err - } - return c.tokenizationUseCase, c.initErrors["tokenizationUseCase"] -} - -// TokenizationKeyHandler returns the tokenization key HTTP handler. -func (c *Container) TokenizationKeyHandler() (*tokenizationHTTP.TokenizationKeyHandler, error) { - var err error - c.tokenizationKeyHandlerInit.Do(func() { - c.tokenizationKeyHandler, err = c.initTokenizationKeyHandler() - if err != nil { - c.initErrors["tokenizationKeyHandler"] = err - } - }) - if err != nil { - return nil, err - } - return c.tokenizationKeyHandler, c.initErrors["tokenizationKeyHandler"] -} - -// TokenizationHandler returns the tokenization HTTP handler. -func (c *Container) TokenizationHandler() (*tokenizationHTTP.TokenizationHandler, error) { - var err error - c.tokenizationHandlerInit.Do(func() { - c.tokenizationHandler, err = c.initTokenizationHandler() - if err != nil { - c.initErrors["tokenizationHandler"] = err - } - }) - if err != nil { - return nil, err - } - return c.tokenizationHandler, c.initErrors["tokenizationHandler"] -} - -// initTokenizationKeyRepository creates the tokenization key repository. -func (c *Container) initTokenizationKeyRepository() (tokenizationUseCase.TokenizationKeyRepository, error) { - db, err := c.DB() - if err != nil { - return nil, fmt.Errorf("failed to get database for tokenization key repository: %w", err) - } - - switch c.config.DBDriver { - case "postgres": - return tokenizationPostgreSQL.NewPostgreSQLTokenizationKeyRepository(db), nil - case "mysql": - return tokenizationMySQL.NewMySQLTokenizationKeyRepository(db), nil - default: - return nil, fmt.Errorf("unsupported database driver: %s", c.config.DBDriver) - } -} - -// initTokenizationTokenRepository creates the tokenization token repository. -func (c *Container) initTokenizationTokenRepository() (tokenizationUseCase.TokenRepository, error) { - db, err := c.DB() - if err != nil { - return nil, fmt.Errorf("failed to get database for tokenization token repository: %w", err) - } - - switch c.config.DBDriver { - case "postgres": - return tokenizationPostgreSQL.NewPostgreSQLTokenRepository(db), nil - case "mysql": - return tokenizationMySQL.NewMySQLTokenRepository(db), nil - default: - return nil, fmt.Errorf("unsupported database driver: %s", c.config.DBDriver) - } -} - -// initTokenizationDekRepository creates the DEK repository for tokenization use case. -func (c *Container) initTokenizationDekRepository() (tokenizationUseCase.DekRepository, error) { - db, err := c.DB() - if err != nil { - return nil, fmt.Errorf("failed to get database for tokenization dek repository: %w", err) - } - - switch c.config.DBDriver { - case "postgres": - return cryptoPostgreSQL.NewPostgreSQLDekRepository(db), nil - case "mysql": - return cryptoMySQL.NewMySQLDekRepository(db), nil - default: - return nil, fmt.Errorf("unsupported database driver: %s", c.config.DBDriver) - } -} - -// initTokenizationKeyUseCase creates the tokenization key use case. -func (c *Container) initTokenizationKeyUseCase() (tokenizationUseCase.TokenizationKeyUseCase, error) { - txManager, err := c.TxManager() - if err != nil { - return nil, fmt.Errorf("failed to get tx manager for tokenization key use case: %w", err) - } - - tokenizationKeyRepository, err := c.TokenizationKeyRepository() - if err != nil { - return nil, fmt.Errorf( - "failed to get tokenization key repository for tokenization key use case: %w", - err, - ) - } - - dekRepository, err := c.TokenizationDekRepository() - if err != nil { - return nil, fmt.Errorf("failed to get dek repository for tokenization key use case: %w", err) - } - - keyManager := c.KeyManager() - - kekChain, err := c.loadKekChain() - if err != nil { - return nil, fmt.Errorf("failed to load kek chain for tokenization key use case: %w", err) - } - - baseUseCase := tokenizationUseCase.NewTokenizationKeyUseCase( - txManager, - tokenizationKeyRepository, - dekRepository, - keyManager, - kekChain, - ) - - // Wrap with metrics if enabled - if c.config.MetricsEnabled { - businessMetrics, err := c.BusinessMetrics() - if err != nil { - return nil, fmt.Errorf("failed to get business metrics for tokenization key use case: %w", err) - } - return tokenizationUseCase.NewTokenizationKeyUseCaseWithMetrics(baseUseCase, businessMetrics), nil - } - - return baseUseCase, nil -} - -// initTokenizationUseCase creates the tokenization use case. -func (c *Container) initTokenizationUseCase() (tokenizationUseCase.TokenizationUseCase, error) { - txManager, err := c.TxManager() - if err != nil { - return nil, fmt.Errorf("failed to get tx manager for tokenization use case: %w", err) - } - - tokenizationKeyRepository, err := c.TokenizationKeyRepository() - if err != nil { - return nil, fmt.Errorf("failed to get tokenization key repository for tokenization use case: %w", err) - } - - tokenRepository, err := c.TokenizationTokenRepository() - if err != nil { - return nil, fmt.Errorf("failed to get token repository for tokenization use case: %w", err) - } - - dekRepository, err := c.TokenizationDekRepository() - if err != nil { - return nil, fmt.Errorf("failed to get dek repository for tokenization use case: %w", err) - } - - aeadManager := c.AEADManager() - - keyManager := c.KeyManager() - - hashService := tokenizationUseCase.NewSHA256HashService() - - kekChain, err := c.loadKekChain() - if err != nil { - return nil, fmt.Errorf("failed to load kek chain for tokenization use case: %w", err) - } - - baseUseCase := tokenizationUseCase.NewTokenizationUseCase( - txManager, - tokenizationKeyRepository, - tokenRepository, - dekRepository, - aeadManager, - keyManager, - hashService, - kekChain, - ) - - // Wrap with metrics if enabled - if c.config.MetricsEnabled { - businessMetrics, err := c.BusinessMetrics() - if err != nil { - return nil, fmt.Errorf("failed to get business metrics for tokenization use case: %w", err) - } - return tokenizationUseCase.NewTokenizationUseCaseWithMetrics(baseUseCase, businessMetrics), nil - } - - return baseUseCase, nil -} - -// initTokenizationKeyHandler creates the tokenization key HTTP handler. -func (c *Container) initTokenizationKeyHandler() (*tokenizationHTTP.TokenizationKeyHandler, error) { - tokenizationKeyUseCase, err := c.TokenizationKeyUseCase() - if err != nil { - return nil, fmt.Errorf( - "failed to get tokenization key use case for tokenization key handler: %w", - err, - ) - } - - logger := c.Logger() - - return tokenizationHTTP.NewTokenizationKeyHandler(tokenizationKeyUseCase, logger), nil -} - -// initTokenizationHandler creates the tokenization HTTP handler. -func (c *Container) initTokenizationHandler() (*tokenizationHTTP.TokenizationHandler, error) { - tokenizationUseCase, err := c.TokenizationUseCase() - if err != nil { - return nil, fmt.Errorf("failed to get tokenization use case for tokenization handler: %w", err) - } - - logger := c.Logger() - - return tokenizationHTTP.NewTokenizationHandler(tokenizationUseCase, logger), nil -} - // initMetricsServer creates the Metrics server if metrics are enabled. func (c *Container) initMetricsServer() (*http.MetricsServer, error) { if !c.config.MetricsEnabled { diff --git a/internal/app/di_auth.go b/internal/app/di_auth.go new file mode 100644 index 0000000..aff4181 --- /dev/null +++ b/internal/app/di_auth.go @@ -0,0 +1,384 @@ +package app + +import ( + "fmt" + + authHTTP "github.com/allisson/secrets/internal/auth/http" + authMySQL "github.com/allisson/secrets/internal/auth/repository/mysql" + authPostgreSQL "github.com/allisson/secrets/internal/auth/repository/postgresql" + authService "github.com/allisson/secrets/internal/auth/service" + authUseCase "github.com/allisson/secrets/internal/auth/usecase" +) + +// SecretService returns the secret service for authentication operations. +func (c *Container) SecretService() authService.SecretService { + c.secretServiceInit.Do(func() { + c.secretService = c.initSecretService() + }) + return c.secretService +} + +// ClientRepository returns the client repository based on database driver. +func (c *Container) ClientRepository() (authUseCase.ClientRepository, error) { + var err error + c.clientRepositoryInit.Do(func() { + c.clientRepository, err = c.initClientRepository() + if err != nil { + c.initErrors["clientRepository"] = err + } + }) + if err != nil { + return nil, err + } + if storedErr, exists := c.initErrors["clientRepository"]; exists { + return nil, storedErr + } + return c.clientRepository, nil +} + +// ClientUseCase returns the client use case. +func (c *Container) ClientUseCase() (authUseCase.ClientUseCase, error) { + var err error + c.clientUseCaseInit.Do(func() { + c.clientUseCase, err = c.initClientUseCase() + if err != nil { + c.initErrors["clientUseCase"] = err + } + }) + if err != nil { + return nil, err + } + if storedErr, exists := c.initErrors["clientUseCase"]; exists { + return nil, storedErr + } + return c.clientUseCase, nil +} + +// TokenService returns the token service for authentication operations. +func (c *Container) TokenService() authService.TokenService { + c.tokenServiceInit.Do(func() { + c.tokenService = c.initTokenService() + }) + return c.tokenService +} + +// TokenRepository returns the token repository based on database driver. +func (c *Container) TokenRepository() (authUseCase.TokenRepository, error) { + var err error + c.tokenRepositoryInit.Do(func() { + c.tokenRepository, err = c.initTokenRepository() + if err != nil { + c.initErrors["tokenRepository"] = err + } + }) + if err != nil { + return nil, err + } + if storedErr, exists := c.initErrors["tokenRepository"]; exists { + return nil, storedErr + } + return c.tokenRepository, nil +} + +// AuditLogRepository returns the audit log repository based on database driver. +func (c *Container) AuditLogRepository() (authUseCase.AuditLogRepository, error) { + var err error + c.auditLogRepositoryInit.Do(func() { + c.auditLogRepository, err = c.initAuditLogRepository() + if err != nil { + c.initErrors["auditLogRepository"] = err + } + }) + if err != nil { + return nil, err + } + if storedErr, exists := c.initErrors["auditLogRepository"]; exists { + return nil, storedErr + } + return c.auditLogRepository, nil +} + +// TokenUseCase returns the token use case. +func (c *Container) TokenUseCase() (authUseCase.TokenUseCase, error) { + var err error + c.tokenUseCaseInit.Do(func() { + c.tokenUseCase, err = c.initTokenUseCase() + if err != nil { + c.initErrors["tokenUseCase"] = err + } + }) + if err != nil { + return nil, err + } + if storedErr, exists := c.initErrors["tokenUseCase"]; exists { + return nil, storedErr + } + return c.tokenUseCase, nil +} + +// AuditLogUseCase returns the audit log use case. +func (c *Container) AuditLogUseCase() (authUseCase.AuditLogUseCase, error) { + var err error + c.auditLogUseCaseInit.Do(func() { + c.auditLogUseCase, err = c.initAuditLogUseCase() + if err != nil { + c.initErrors["auditLogUseCase"] = err + } + }) + if err != nil { + return nil, err + } + if storedErr, exists := c.initErrors["auditLogUseCase"]; exists { + return nil, storedErr + } + return c.auditLogUseCase, nil +} + +// ClientHandler returns the HTTP handler for client management operations. +func (c *Container) ClientHandler() (*authHTTP.ClientHandler, error) { + var err error + c.clientHandlerInit.Do(func() { + c.clientHandler, err = c.initClientHandler() + if err != nil { + c.initErrors["clientHandler"] = err + } + }) + if err != nil { + return nil, err + } + if storedErr, exists := c.initErrors["clientHandler"]; exists { + return nil, storedErr + } + return c.clientHandler, nil +} + +// TokenHandler returns the HTTP handler for token operations. +func (c *Container) TokenHandler() (*authHTTP.TokenHandler, error) { + var err error + c.tokenHandlerInit.Do(func() { + c.tokenHandler, err = c.initTokenHandler() + if err != nil { + c.initErrors["tokenHandler"] = err + } + }) + if err != nil { + return nil, err + } + if storedErr, exists := c.initErrors["tokenHandler"]; exists { + return nil, storedErr + } + return c.tokenHandler, nil +} + +// AuditLogHandler returns the HTTP handler for audit log operations. +func (c *Container) AuditLogHandler() (*authHTTP.AuditLogHandler, error) { + var err error + c.auditLogHandlerInit.Do(func() { + c.auditLogHandler, err = c.initAuditLogHandler() + if err != nil { + c.initErrors["auditLogHandler"] = err + } + }) + if err != nil { + return nil, err + } + if storedErr, exists := c.initErrors["auditLogHandler"]; exists { + return nil, storedErr + } + return c.auditLogHandler, nil +} + +// initSecretService creates the secret service for authentication. +func (c *Container) initSecretService() authService.SecretService { + return authService.NewSecretService() +} + +// initClientRepository creates the client repository based on the database driver. +func (c *Container) initClientRepository() (authUseCase.ClientRepository, error) { + db, err := c.DB() + if err != nil { + return nil, fmt.Errorf("failed to get database for client repository: %w", err) + } + + switch c.config.DBDriver { + case "postgres": + return authPostgreSQL.NewPostgreSQLClientRepository(db), nil + case "mysql": + return authMySQL.NewMySQLClientRepository(db), nil + default: + return nil, fmt.Errorf("unsupported database driver: %s", c.config.DBDriver) + } +} + +// initClientUseCase creates the client use case with all its dependencies. +func (c *Container) initClientUseCase() (authUseCase.ClientUseCase, error) { + txManager, err := c.TxManager() + if err != nil { + return nil, fmt.Errorf("failed to get tx manager for client use case: %w", err) + } + + clientRepository, err := c.ClientRepository() + if err != nil { + return nil, fmt.Errorf("failed to get client repository for client use case: %w", err) + } + + secretService := c.SecretService() + + baseUseCase := authUseCase.NewClientUseCase(txManager, clientRepository, secretService) + + // Wrap with metrics if enabled + if c.config.MetricsEnabled { + businessMetrics, err := c.BusinessMetrics() + if err != nil { + return nil, fmt.Errorf("failed to get business metrics for client use case: %w", err) + } + return authUseCase.NewClientUseCaseWithMetrics(baseUseCase, businessMetrics), nil + } + + return baseUseCase, nil +} + +// initTokenService creates the token service for authentication. +func (c *Container) initTokenService() authService.TokenService { + return authService.NewTokenService() +} + +// initTokenRepository creates the token repository based on the database driver. +func (c *Container) initTokenRepository() (authUseCase.TokenRepository, error) { + db, err := c.DB() + if err != nil { + return nil, fmt.Errorf("failed to get database for token repository: %w", err) + } + + switch c.config.DBDriver { + case "postgres": + return authPostgreSQL.NewPostgreSQLTokenRepository(db), nil + case "mysql": + return authMySQL.NewMySQLTokenRepository(db), nil + default: + return nil, fmt.Errorf("unsupported database driver: %s", c.config.DBDriver) + } +} + +// initAuditLogRepository creates the audit log repository based on the database driver. +func (c *Container) initAuditLogRepository() (authUseCase.AuditLogRepository, error) { + db, err := c.DB() + if err != nil { + return nil, fmt.Errorf("failed to get database for audit log repository: %w", err) + } + + switch c.config.DBDriver { + case "postgres": + return authPostgreSQL.NewPostgreSQLAuditLogRepository(db), nil + case "mysql": + return authMySQL.NewMySQLAuditLogRepository(db), nil + default: + return nil, fmt.Errorf("unsupported database driver: %s", c.config.DBDriver) + } +} + +// initTokenUseCase creates the token use case with all its dependencies. +func (c *Container) initTokenUseCase() (authUseCase.TokenUseCase, error) { + clientRepository, err := c.ClientRepository() + if err != nil { + return nil, fmt.Errorf("failed to get client repository for token use case: %w", err) + } + + tokenRepository, err := c.TokenRepository() + if err != nil { + return nil, fmt.Errorf("failed to get token repository for token use case: %w", err) + } + + secretService := c.SecretService() + tokenService := c.TokenService() + + baseUseCase := authUseCase.NewTokenUseCase( + c.config, + clientRepository, + tokenRepository, + secretService, + tokenService, + ) + + // Wrap with metrics if enabled + if c.config.MetricsEnabled { + businessMetrics, err := c.BusinessMetrics() + if err != nil { + return nil, fmt.Errorf("failed to get business metrics for token use case: %w", err) + } + return authUseCase.NewTokenUseCaseWithMetrics(baseUseCase, businessMetrics), nil + } + + return baseUseCase, nil +} + +// initAuditLogUseCase creates the audit log use case with all its dependencies. +func (c *Container) initAuditLogUseCase() (authUseCase.AuditLogUseCase, error) { + auditLogRepository, err := c.AuditLogRepository() + if err != nil { + return nil, fmt.Errorf("failed to get audit log repository for audit log use case: %w", err) + } + + // Create audit signer service + auditSigner := authService.NewAuditSigner() + + // Load KEK chain for signature verification + kekChain, err := c.loadKekChain() + if err != nil { + return nil, fmt.Errorf("failed to load kek chain for audit log use case: %w", err) + } + + baseUseCase := authUseCase.NewAuditLogUseCase(auditLogRepository, auditSigner, kekChain) + + // Wrap with metrics if enabled + if c.config.MetricsEnabled { + businessMetrics, err := c.BusinessMetrics() + if err != nil { + return nil, fmt.Errorf("failed to get business metrics for audit log use case: %w", err) + } + return authUseCase.NewAuditLogUseCaseWithMetrics(baseUseCase, businessMetrics), nil + } + + return baseUseCase, nil +} + +// initClientHandler creates the client HTTP handler with all its dependencies. +func (c *Container) initClientHandler() (*authHTTP.ClientHandler, error) { + clientUseCase, err := c.ClientUseCase() + if err != nil { + return nil, fmt.Errorf("failed to get client use case for client handler: %w", err) + } + + auditLogUseCase, err := c.AuditLogUseCase() + if err != nil { + return nil, fmt.Errorf("failed to get audit log use case for client handler: %w", err) + } + + logger := c.Logger() + + return authHTTP.NewClientHandler(clientUseCase, auditLogUseCase, logger), nil +} + +// initTokenHandler creates the token HTTP handler with all its dependencies. +func (c *Container) initTokenHandler() (*authHTTP.TokenHandler, error) { + tokenUseCase, err := c.TokenUseCase() + if err != nil { + return nil, fmt.Errorf("failed to get token use case for token handler: %w", err) + } + + logger := c.Logger() + + return authHTTP.NewTokenHandler(tokenUseCase, logger), nil +} + +// initAuditLogHandler creates the audit log HTTP handler with all its dependencies. +func (c *Container) initAuditLogHandler() (*authHTTP.AuditLogHandler, error) { + auditLogUseCase, err := c.AuditLogUseCase() + if err != nil { + return nil, fmt.Errorf("failed to get audit log use case for audit log handler: %w", err) + } + + logger := c.Logger() + + return authHTTP.NewAuditLogHandler(auditLogUseCase, logger), nil +} diff --git a/internal/app/di_crypto.go b/internal/app/di_crypto.go new file mode 100644 index 0000000..5416bca --- /dev/null +++ b/internal/app/di_crypto.go @@ -0,0 +1,250 @@ +package app + +import ( + "context" + "fmt" + + cryptoDomain "github.com/allisson/secrets/internal/crypto/domain" + cryptoMySQL "github.com/allisson/secrets/internal/crypto/repository/mysql" + cryptoPostgreSQL "github.com/allisson/secrets/internal/crypto/repository/postgresql" + cryptoService "github.com/allisson/secrets/internal/crypto/service" + cryptoUseCase "github.com/allisson/secrets/internal/crypto/usecase" +) + +// MasterKeyChain returns the master key chain loaded from environment variables. +func (c *Container) MasterKeyChain() (*cryptoDomain.MasterKeyChain, error) { + var err error + c.masterKeyChainInit.Do(func() { + c.masterKeyChain, err = c.initMasterKeyChain() + if err != nil { + c.initErrors["masterKeyChain"] = err + } + }) + if err != nil { + return nil, err + } + if storedErr, exists := c.initErrors["masterKeyChain"]; exists { + return nil, storedErr + } + return c.masterKeyChain, nil +} + +// AEADManager returns the AEAD manager service. +func (c *Container) AEADManager() cryptoService.AEADManager { + c.aeadManagerInit.Do(func() { + c.aeadManager = c.initAEADManager() + }) + return c.aeadManager +} + +// KeyManager returns the key manager service. +func (c *Container) KeyManager() cryptoService.KeyManager { + c.keyManagerInit.Do(func() { + c.keyManager = c.initKeyManager() + }) + return c.keyManager +} + +// KMSService returns the KMS service. +func (c *Container) KMSService() cryptoService.KMSService { + c.kmsServiceInit.Do(func() { + c.kmsService = c.initKMSService() + }) + return c.kmsService +} + +// KekRepository returns the KEK repository. +func (c *Container) KekRepository() (cryptoUseCase.KekRepository, error) { + var err error + c.kekRepositoryInit.Do(func() { + c.kekRepository, err = c.initKekRepository() + if err != nil { + c.initErrors["kekRepository"] = err + } + }) + if err != nil { + return nil, err + } + if storedErr, exists := c.initErrors["kekRepository"]; exists { + return nil, storedErr + } + return c.kekRepository, nil +} + +// KekUseCase returns the KEK use case. +func (c *Container) KekUseCase() (cryptoUseCase.KekUseCase, error) { + var err error + c.kekUseCaseInit.Do(func() { + c.kekUseCase, err = c.initKekUseCase() + if err != nil { + c.initErrors["kekUseCase"] = err + } + }) + if err != nil { + return nil, err + } + if storedErr, exists := c.initErrors["kekUseCase"]; exists { + return nil, storedErr + } + return c.kekUseCase, nil +} + +// CryptoDekRepository returns the DEK repository for the crypto use case based on database driver. +func (c *Container) CryptoDekRepository() (cryptoUseCase.DekRepository, error) { + var err error + c.cryptoDekRepositoryInit.Do(func() { + c.cryptoDekRepository, err = c.initCryptoDekRepository() + if err != nil { + c.initErrors["cryptoDekRepository"] = err + } + }) + if err != nil { + return nil, err + } + if storedErr, exists := c.initErrors["cryptoDekRepository"]; exists { + return nil, storedErr + } + return c.cryptoDekRepository, nil +} + +// CryptoDekUseCase returns the DEK use case for the crypto module. +func (c *Container) CryptoDekUseCase() (cryptoUseCase.DekUseCase, error) { + var err error + c.cryptoDekUseCaseInit.Do(func() { + c.cryptoDekUseCase, err = c.initCryptoDekUseCase() + if err != nil { + c.initErrors["cryptoDekUseCase"] = err + } + }) + if err != nil { + return nil, err + } + if storedErr, exists := c.initErrors["cryptoDekUseCase"]; exists { + return nil, storedErr + } + return c.cryptoDekUseCase, nil +} + +// initMasterKeyChain loads the master key chain from environment variables. +func (c *Container) initMasterKeyChain() (*cryptoDomain.MasterKeyChain, error) { + // Get KMS service and logger + kmsService := c.KMSService() + logger := c.Logger() + + // Load master key chain with KMS support and fail-fast validation + masterKeyChain, err := cryptoDomain.LoadMasterKeyChain( + context.Background(), + c.config, + kmsService, + logger, + ) + if err != nil { + return nil, fmt.Errorf("failed to load master key chain: %w", err) + } + return masterKeyChain, nil +} + +// initAEADManager creates the AEAD manager service. +func (c *Container) initAEADManager() cryptoService.AEADManager { + return cryptoService.NewAEADManager() +} + +// initKeyManager creates the key manager service using the AEAD manager. +func (c *Container) initKeyManager() cryptoService.KeyManager { + aeadManager := c.AEADManager() + return cryptoService.NewKeyManager(aeadManager) +} + +// initKMSService creates the KMS service for encrypting/decrypting master keys. +func (c *Container) initKMSService() cryptoService.KMSService { + return cryptoService.NewKMSService() +} + +// initKekRepository creates the KEK repository based on the database driver. +func (c *Container) initKekRepository() (cryptoUseCase.KekRepository, error) { + db, err := c.DB() + if err != nil { + return nil, fmt.Errorf("failed to get database for kek repository: %w", err) + } + + switch c.config.DBDriver { + case "postgres": + return cryptoPostgreSQL.NewPostgreSQLKekRepository(db), nil + case "mysql": + return cryptoMySQL.NewMySQLKekRepository(db), nil + default: + return nil, fmt.Errorf("unsupported database driver: %s", c.config.DBDriver) + } +} + +// initKekUseCase creates the KEK use case with all its dependencies. +func (c *Container) initKekUseCase() (cryptoUseCase.KekUseCase, error) { + txManager, err := c.TxManager() + if err != nil { + return nil, fmt.Errorf("failed to get tx manager for kek use case: %w", err) + } + + kekRepository, err := c.KekRepository() + if err != nil { + return nil, fmt.Errorf("failed to get kek repository for kek use case: %w", err) + } + + keyManager := c.KeyManager() + + return cryptoUseCase.NewKekUseCase(txManager, kekRepository, keyManager), nil +} + +// initCryptoDekRepository creates the DEK repository for crypto use case based on the database driver. +func (c *Container) initCryptoDekRepository() (cryptoUseCase.DekRepository, error) { + db, err := c.DB() + if err != nil { + return nil, fmt.Errorf("failed to get database: %w", err) + } + + switch c.config.DBDriver { + case "postgres": + return cryptoPostgreSQL.NewPostgreSQLDekRepository(db), nil + case "mysql": + return cryptoMySQL.NewMySQLDekRepository(db), nil + default: + return nil, fmt.Errorf("unsupported database driver: %s", c.config.DBDriver) + } +} + +// initCryptoDekUseCase creates the DEK use case for the crypto module. +func (c *Container) initCryptoDekUseCase() (cryptoUseCase.DekUseCase, error) { + txManager, err := c.TxManager() + if err != nil { + return nil, fmt.Errorf("failed to get tx manager: %w", err) + } + + dekRepo, err := c.CryptoDekRepository() + if err != nil { + return nil, fmt.Errorf("failed to get crypto dek repository: %w", err) + } + + keyManager := c.KeyManager() + + return cryptoUseCase.NewDekUseCase(txManager, dekRepo, keyManager), nil +} + +// loadKekChain loads all KEKs from the database and creates a KEK chain. +func (c *Container) loadKekChain() (*cryptoDomain.KekChain, error) { + kekUseCase, err := c.KekUseCase() + if err != nil { + return nil, fmt.Errorf("failed to get kek use case: %w", err) + } + + masterKeyChain, err := c.MasterKeyChain() + if err != nil { + return nil, fmt.Errorf("failed to get master key chain: %w", err) + } + + // Unwrap all KEKs using the master key chain + kekChain, err := kekUseCase.Unwrap(context.Background(), masterKeyChain) + if err != nil { + return nil, fmt.Errorf("failed to unwrap keks: %w", err) + } + + return kekChain, nil +} diff --git a/internal/app/di_secrets.go b/internal/app/di_secrets.go new file mode 100644 index 0000000..9cfacc2 --- /dev/null +++ b/internal/app/di_secrets.go @@ -0,0 +1,183 @@ +package app + +import ( + "fmt" + + cryptoDomain "github.com/allisson/secrets/internal/crypto/domain" + cryptoMySQL "github.com/allisson/secrets/internal/crypto/repository/mysql" + cryptoPostgreSQL "github.com/allisson/secrets/internal/crypto/repository/postgresql" + secretsHTTP "github.com/allisson/secrets/internal/secrets/http" + secretsMySQL "github.com/allisson/secrets/internal/secrets/repository/mysql" + secretsPostgreSQL "github.com/allisson/secrets/internal/secrets/repository/postgresql" + secretsUseCase "github.com/allisson/secrets/internal/secrets/usecase" +) + +// DekRepository returns the DEK repository based on database driver. +func (c *Container) DekRepository() (secretsUseCase.DekRepository, error) { + var err error + c.dekRepositoryInit.Do(func() { + c.dekRepository, err = c.initDekRepository() + if err != nil { + c.initErrors["dekRepository"] = err + } + }) + if err != nil { + return nil, err + } + if storedErr, exists := c.initErrors["dekRepository"]; exists { + return nil, storedErr + } + return c.dekRepository, nil +} + +// SecretRepository returns the secret repository based on database driver. +func (c *Container) SecretRepository() (secretsUseCase.SecretRepository, error) { + var err error + c.secretRepositoryInit.Do(func() { + c.secretRepository, err = c.initSecretRepository() + if err != nil { + c.initErrors["secretRepository"] = err + } + }) + if err != nil { + return nil, err + } + if storedErr, exists := c.initErrors["secretRepository"]; exists { + return nil, storedErr + } + return c.secretRepository, nil +} + +// SecretUseCase returns the secret use case. +func (c *Container) SecretUseCase() (secretsUseCase.SecretUseCase, error) { + var err error + c.secretUseCaseInit.Do(func() { + c.secretUseCase, err = c.initSecretUseCase() + if err != nil { + c.initErrors["secretUseCase"] = err + } + }) + if err != nil { + return nil, err + } + if storedErr, exists := c.initErrors["secretUseCase"]; exists { + return nil, storedErr + } + return c.secretUseCase, nil +} + +// SecretHandler returns the HTTP handler for secret management operations. +func (c *Container) SecretHandler() (*secretsHTTP.SecretHandler, error) { + var err error + c.secretHandlerInit.Do(func() { + c.secretHandler, err = c.initSecretHandler() + if err != nil { + c.initErrors["secretHandler"] = err + } + }) + if err != nil { + return nil, err + } + if storedErr, exists := c.initErrors["secretHandler"]; exists { + return nil, storedErr + } + return c.secretHandler, nil +} + +// initDekRepository creates the DEK repository based on the database driver. +func (c *Container) initDekRepository() (secretsUseCase.DekRepository, error) { + db, err := c.DB() + if err != nil { + return nil, fmt.Errorf("failed to get database for dek repository: %w", err) + } + + switch c.config.DBDriver { + case "postgres": + return cryptoPostgreSQL.NewPostgreSQLDekRepository(db), nil + case "mysql": + return cryptoMySQL.NewMySQLDekRepository(db), nil + default: + return nil, fmt.Errorf("unsupported database driver: %s", c.config.DBDriver) + } +} + +// initSecretRepository creates the secret repository based on the database driver. +func (c *Container) initSecretRepository() (secretsUseCase.SecretRepository, error) { + db, err := c.DB() + if err != nil { + return nil, fmt.Errorf("failed to get database for secret repository: %w", err) + } + + switch c.config.DBDriver { + case "postgres": + return secretsPostgreSQL.NewPostgreSQLSecretRepository(db), nil + case "mysql": + return secretsMySQL.NewMySQLSecretRepository(db), nil + default: + return nil, fmt.Errorf("unsupported database driver: %s", c.config.DBDriver) + } +} + +// initSecretUseCase creates the secret use case with all its dependencies. +func (c *Container) initSecretUseCase() (secretsUseCase.SecretUseCase, error) { + txManager, err := c.TxManager() + if err != nil { + return nil, fmt.Errorf("failed to get tx manager for secret use case: %w", err) + } + + dekRepository, err := c.DekRepository() + if err != nil { + return nil, fmt.Errorf("failed to get dek repository for secret use case: %w", err) + } + + secretRepository, err := c.SecretRepository() + if err != nil { + return nil, fmt.Errorf("failed to get secret repository for secret use case: %w", err) + } + + kekChain, err := c.loadKekChain() + if err != nil { + return nil, fmt.Errorf("failed to load kek chain for secret use case: %w", err) + } + + aeadManager := c.AEADManager() + keyManager := c.KeyManager() + + baseUseCase := secretsUseCase.NewSecretUseCase( + txManager, + dekRepository, + secretRepository, + kekChain, + aeadManager, + keyManager, + cryptoDomain.AESGCM, + ) + + // Wrap with metrics if enabled + if c.config.MetricsEnabled { + businessMetrics, err := c.BusinessMetrics() + if err != nil { + return nil, fmt.Errorf("failed to get business metrics for secret use case: %w", err) + } + return secretsUseCase.NewSecretUseCaseWithMetrics(baseUseCase, businessMetrics), nil + } + + return baseUseCase, nil +} + +// initSecretHandler creates the secret HTTP handler with all its dependencies. +func (c *Container) initSecretHandler() (*secretsHTTP.SecretHandler, error) { + secretUseCase, err := c.SecretUseCase() + if err != nil { + return nil, fmt.Errorf("failed to get secret use case for secret handler: %w", err) + } + + auditLogUseCase, err := c.AuditLogUseCase() + if err != nil { + return nil, fmt.Errorf("failed to get audit log use case for secret handler: %w", err) + } + + logger := c.Logger() + + return secretsHTTP.NewSecretHandler(secretUseCase, auditLogUseCase, logger), nil +} diff --git a/internal/app/di_test.go b/internal/app/di_test.go index 5414638..1932828 100644 --- a/internal/app/di_test.go +++ b/internal/app/di_test.go @@ -537,3 +537,136 @@ func TestContainerShutdownWithMasterKeyChain(t *testing.T) { // After shutdown, the key chain should be closed (keys should be zeroed) // We can't directly verify that keys are zeroed, but we verify that Shutdown ran without panic } + +// TestContainerAuthComponents verifies that auth components can be retrieved from the container. +func TestContainerAuthComponents(t *testing.T) { + cfg := &config.Config{ + LogLevel: "info", + } + + container := NewContainer(cfg) + + // SecretService + secretService := container.SecretService() + if secretService == nil { + t.Error("expected non-nil secret service") + } + + // TokenService + tokenService := container.TokenService() + if tokenService == nil { + t.Error("expected non-nil token service") + } +} + +// TestContainerSecretsComponents verifies that secrets components can be retrieved from the container. +func TestContainerSecretsComponents(t *testing.T) { + cfg := &config.Config{ + LogLevel: "info", + } + + container := NewContainer(cfg) + + // Since repositories need a DB, we expect errors if DB is not and cannot be connected + cfg.DBDriver = "invalid" + + _, err := container.DekRepository() + if err == nil { + t.Error("expected error for dek repository with invalid db config") + } + + _, err = container.SecretRepository() + if err == nil { + t.Error("expected error for secret repository with invalid db config") + } + + _, err = container.SecretUseCase() + if err == nil { + t.Error("expected error for secret use case with invalid db config") + } + + _, err = container.SecretHandler() + if err == nil { + t.Error("expected error for secret handler with invalid db config") + } +} + +// TestContainerTransitComponents verifies that transit components can be retrieved from the container. +func TestContainerTransitComponents(t *testing.T) { + cfg := &config.Config{ + LogLevel: "info", + DBDriver: "invalid", + } + + container := NewContainer(cfg) + + _, err := container.TransitKeyRepository() + if err == nil { + t.Error("expected error for transit key repository with invalid db config") + } + + _, err = container.TransitDekRepository() + if err == nil { + t.Error("expected error for transit dek repository with invalid db config") + } + + _, err = container.TransitKeyUseCase() + if err == nil { + t.Error("expected error for transit key use case with invalid db config") + } + + _, err = container.TransitKeyHandler() + if err == nil { + t.Error("expected error for transit key handler with invalid db config") + } + + _, err = container.CryptoHandler() + if err == nil { + t.Error("expected error for crypto handler with invalid db config") + } +} + +// TestContainerTokenizationComponents verifies that tokenization components can be retrieved from the container. +func TestContainerTokenizationComponents(t *testing.T) { + cfg := &config.Config{ + LogLevel: "info", + DBDriver: "invalid", + } + + container := NewContainer(cfg) + + _, err := container.TokenizationKeyRepository() + if err == nil { + t.Error("expected error for tokenization key repository with invalid db config") + } + + _, err = container.TokenizationTokenRepository() + if err == nil { + t.Error("expected error for tokenization token repository with invalid db config") + } + + _, err = container.TokenizationDekRepository() + if err == nil { + t.Error("expected error for tokenization dek repository with invalid db config") + } + + _, err = container.TokenizationKeyUseCase() + if err == nil { + t.Error("expected error for tokenization key use case with invalid db config") + } + + _, err = container.TokenizationUseCase() + if err == nil { + t.Error("expected error for tokenization use case with invalid db config") + } + + _, err = container.TokenizationKeyHandler() + if err == nil { + t.Error("expected error for tokenization key handler with invalid db config") + } + + _, err = container.TokenizationHandler() + if err == nil { + t.Error("expected error for tokenization handler with invalid db config") + } +} diff --git a/internal/app/di_tokenization.go b/internal/app/di_tokenization.go new file mode 100644 index 0000000..19932fc --- /dev/null +++ b/internal/app/di_tokenization.go @@ -0,0 +1,308 @@ +package app + +import ( + "fmt" + + cryptoMySQL "github.com/allisson/secrets/internal/crypto/repository/mysql" + cryptoPostgreSQL "github.com/allisson/secrets/internal/crypto/repository/postgresql" + tokenizationHTTP "github.com/allisson/secrets/internal/tokenization/http" + tokenizationMySQL "github.com/allisson/secrets/internal/tokenization/repository/mysql" + tokenizationPostgreSQL "github.com/allisson/secrets/internal/tokenization/repository/postgresql" + tokenizationUseCase "github.com/allisson/secrets/internal/tokenization/usecase" +) + +// TokenizationKeyRepository returns the tokenization key repository. +func (c *Container) TokenizationKeyRepository() (tokenizationUseCase.TokenizationKeyRepository, error) { + var err error + c.tokenizationKeyRepositoryInit.Do(func() { + c.tokenizationKeyRepository, err = c.initTokenizationKeyRepository() + if err != nil { + c.initErrors["tokenizationKeyRepository"] = err + } + }) + if err != nil { + return nil, err + } + if storedErr, exists := c.initErrors["tokenizationKeyRepository"]; exists { + return nil, storedErr + } + return c.tokenizationKeyRepository, nil +} + +// TokenizationTokenRepository returns the tokenization token repository. +func (c *Container) TokenizationTokenRepository() (tokenizationUseCase.TokenRepository, error) { + var err error + c.tokenizationTokenRepositoryInit.Do(func() { + c.tokenizationTokenRepository, err = c.initTokenizationTokenRepository() + if err != nil { + c.initErrors["tokenizationTokenRepository"] = err + } + }) + if err != nil { + return nil, err + } + if storedErr, exists := c.initErrors["tokenizationTokenRepository"]; exists { + return nil, storedErr + } + return c.tokenizationTokenRepository, nil +} + +// TokenizationDekRepository returns the DEK repository for tokenization use case. +func (c *Container) TokenizationDekRepository() (tokenizationUseCase.DekRepository, error) { + var err error + c.tokenizationDekRepositoryInit.Do(func() { + c.tokenizationDekRepository, err = c.initTokenizationDekRepository() + if err != nil { + c.initErrors["tokenizationDekRepository"] = err + } + }) + if err != nil { + return nil, err + } + if storedErr, exists := c.initErrors["tokenizationDekRepository"]; exists { + return nil, storedErr + } + return c.tokenizationDekRepository, nil +} + +// TokenizationKeyUseCase returns the tokenization key use case. +func (c *Container) TokenizationKeyUseCase() (tokenizationUseCase.TokenizationKeyUseCase, error) { + var err error + c.tokenizationKeyUseCaseInit.Do(func() { + c.tokenizationKeyUseCase, err = c.initTokenizationKeyUseCase() + if err != nil { + c.initErrors["tokenizationKeyUseCase"] = err + } + }) + if err != nil { + return nil, err + } + if storedErr, exists := c.initErrors["tokenizationKeyUseCase"]; exists { + return nil, storedErr + } + return c.tokenizationKeyUseCase, nil +} + +// TokenizationUseCase returns the tokenization use case. +func (c *Container) TokenizationUseCase() (tokenizationUseCase.TokenizationUseCase, error) { + var err error + c.tokenizationUseCaseInit.Do(func() { + c.tokenizationUseCase, err = c.initTokenizationUseCase() + if err != nil { + c.initErrors["tokenizationUseCase"] = err + } + }) + if err != nil { + return nil, err + } + if storedErr, exists := c.initErrors["tokenizationUseCase"]; exists { + return nil, storedErr + } + return c.tokenizationUseCase, nil +} + +// TokenizationKeyHandler returns the tokenization key HTTP handler. +func (c *Container) TokenizationKeyHandler() (*tokenizationHTTP.TokenizationKeyHandler, error) { + var err error + c.tokenizationKeyHandlerInit.Do(func() { + c.tokenizationKeyHandler, err = c.initTokenizationKeyHandler() + if err != nil { + c.initErrors["tokenizationKeyHandler"] = err + } + }) + if err != nil { + return nil, err + } + if storedErr, exists := c.initErrors["tokenizationKeyHandler"]; exists { + return nil, storedErr + } + return c.tokenizationKeyHandler, nil +} + +// TokenizationHandler returns the tokenization HTTP handler. +func (c *Container) TokenizationHandler() (*tokenizationHTTP.TokenizationHandler, error) { + var err error + c.tokenizationHandlerInit.Do(func() { + c.tokenizationHandler, err = c.initTokenizationHandler() + if err != nil { + c.initErrors["tokenizationHandler"] = err + } + }) + if err != nil { + return nil, err + } + if storedErr, exists := c.initErrors["tokenizationHandler"]; exists { + return nil, storedErr + } + return c.tokenizationHandler, nil +} + +// initTokenizationKeyRepository creates the tokenization key repository. +func (c *Container) initTokenizationKeyRepository() (tokenizationUseCase.TokenizationKeyRepository, error) { + db, err := c.DB() + if err != nil { + return nil, fmt.Errorf("failed to get database for tokenization key repository: %w", err) + } + + switch c.config.DBDriver { + case "postgres": + return tokenizationPostgreSQL.NewPostgreSQLTokenizationKeyRepository(db), nil + case "mysql": + return tokenizationMySQL.NewMySQLTokenizationKeyRepository(db), nil + default: + return nil, fmt.Errorf("unsupported database driver: %s", c.config.DBDriver) + } +} + +// initTokenizationTokenRepository creates the tokenization token repository. +func (c *Container) initTokenizationTokenRepository() (tokenizationUseCase.TokenRepository, error) { + db, err := c.DB() + if err != nil { + return nil, fmt.Errorf("failed to get database for tokenization token repository: %w", err) + } + + switch c.config.DBDriver { + case "postgres": + return tokenizationPostgreSQL.NewPostgreSQLTokenRepository(db), nil + case "mysql": + return tokenizationMySQL.NewMySQLTokenRepository(db), nil + default: + return nil, fmt.Errorf("unsupported database driver: %s", c.config.DBDriver) + } +} + +// initTokenizationDekRepository creates the DEK repository for tokenization use case. +func (c *Container) initTokenizationDekRepository() (tokenizationUseCase.DekRepository, error) { + db, err := c.DB() + if err != nil { + return nil, fmt.Errorf("failed to get database for tokenization dek repository: %w", err) + } + + switch c.config.DBDriver { + case "postgres": + return cryptoPostgreSQL.NewPostgreSQLDekRepository(db), nil + case "mysql": + return cryptoMySQL.NewMySQLDekRepository(db), nil + default: + return nil, fmt.Errorf("unsupported database driver: %s", c.config.DBDriver) + } +} + +// initTokenizationKeyUseCase creates the tokenization key use case. +func (c *Container) initTokenizationKeyUseCase() (tokenizationUseCase.TokenizationKeyUseCase, error) { + txManager, err := c.TxManager() + if err != nil { + return nil, fmt.Errorf("failed to get tx manager for tokenization key use case: %w", err) + } + + tokenizationKeyRepository, err := c.TokenizationKeyRepository() + if err != nil { + return nil, fmt.Errorf( + "failed to get tokenization key repository for tokenization key use case: %w", + err, + ) + } + + dekRepository, err := c.TokenizationDekRepository() + if err != nil { + return nil, fmt.Errorf("failed to get dek repository for tokenization key use case: %w", err) + } + + kekChain, err := c.loadKekChain() + if err != nil { + return nil, fmt.Errorf("failed to load kek chain for tokenization key use case: %w", err) + } + + keyManager := c.KeyManager() + + return tokenizationUseCase.NewTokenizationKeyUseCase( + txManager, + tokenizationKeyRepository, + dekRepository, + keyManager, + kekChain, + ), nil +} + +// initTokenizationUseCase creates the tokenization use case. +func (c *Container) initTokenizationUseCase() (tokenizationUseCase.TokenizationUseCase, error) { + txManager, err := c.TxManager() + if err != nil { + return nil, fmt.Errorf("failed to get tx manager for tokenization use case: %w", err) + } + + tokenizationKeyRepository, err := c.TokenizationKeyRepository() + if err != nil { + return nil, fmt.Errorf("failed to get tokenization key repository for tokenization use case: %w", err) + } + + tokenRepository, err := c.TokenizationTokenRepository() + if err != nil { + return nil, fmt.Errorf("failed to get token repository for tokenization use case: %w", err) + } + + dekRepository, err := c.TokenizationDekRepository() + if err != nil { + return nil, fmt.Errorf("failed to get dek repository for tokenization use case: %w", err) + } + + aeadManager := c.AEADManager() + + keyManager := c.KeyManager() + + hashService := tokenizationUseCase.NewSHA256HashService() + + kekChain, err := c.loadKekChain() + if err != nil { + return nil, fmt.Errorf("failed to load kek chain for tokenization use case: %w", err) + } + + baseUseCase := tokenizationUseCase.NewTokenizationUseCase( + txManager, + tokenizationKeyRepository, + tokenRepository, + dekRepository, + aeadManager, + keyManager, + hashService, + kekChain, + ) + + // Wrap with metrics if enabled + if c.config.MetricsEnabled { + businessMetrics, err := c.BusinessMetrics() + if err != nil { + return nil, fmt.Errorf("failed to get business metrics for tokenization use case: %w", err) + } + return tokenizationUseCase.NewTokenizationUseCaseWithMetrics(baseUseCase, businessMetrics), nil + } + + return baseUseCase, nil +} + +// initTokenizationKeyHandler creates the tokenization key HTTP handler. +func (c *Container) initTokenizationKeyHandler() (*tokenizationHTTP.TokenizationKeyHandler, error) { + tokenizationKeyUseCase, err := c.TokenizationKeyUseCase() + if err != nil { + return nil, fmt.Errorf( + "failed to get tokenization key use case for tokenization key handler: %w", + err, + ) + } + + logger := c.Logger() + + return tokenizationHTTP.NewTokenizationKeyHandler(tokenizationKeyUseCase, logger), nil +} + +// initTokenizationHandler creates the tokenization HTTP handler. +func (c *Container) initTokenizationHandler() (*tokenizationHTTP.TokenizationHandler, error) { + tokenizationUseCase, err := c.TokenizationUseCase() + if err != nil { + return nil, fmt.Errorf("failed to get tokenization use case for tokenization handler: %w", err) + } + + logger := c.Logger() + + return tokenizationHTTP.NewTokenizationHandler(tokenizationUseCase, logger), nil +} diff --git a/internal/app/di_transit.go b/internal/app/di_transit.go new file mode 100644 index 0000000..b20ff60 --- /dev/null +++ b/internal/app/di_transit.go @@ -0,0 +1,206 @@ +package app + +import ( + "fmt" + + cryptoMySQL "github.com/allisson/secrets/internal/crypto/repository/mysql" + cryptoPostgreSQL "github.com/allisson/secrets/internal/crypto/repository/postgresql" + transitHTTP "github.com/allisson/secrets/internal/transit/http" + transitMySQL "github.com/allisson/secrets/internal/transit/repository/mysql" + transitPostgreSQL "github.com/allisson/secrets/internal/transit/repository/postgresql" + transitUseCase "github.com/allisson/secrets/internal/transit/usecase" +) + +// TransitKeyRepository returns the transit key repository instance. +func (c *Container) TransitKeyRepository() (transitUseCase.TransitKeyRepository, error) { + var err error + c.transitKeyRepositoryInit.Do(func() { + c.transitKeyRepository, err = c.initTransitKeyRepository() + if err != nil { + c.initErrors["transitKeyRepository"] = err + } + }) + if err != nil { + return nil, err + } + if storedErr, exists := c.initErrors["transitKeyRepository"]; exists { + return nil, storedErr + } + return c.transitKeyRepository, nil +} + +// TransitDekRepository returns the DEK repository for transit use case. +func (c *Container) TransitDekRepository() (transitUseCase.DekRepository, error) { + var err error + c.transitDekRepositoryInit.Do(func() { + c.transitDekRepository, err = c.initTransitDekRepository() + if err != nil { + c.initErrors["transitDekRepository"] = err + } + }) + if err != nil { + return nil, err + } + if storedErr, exists := c.initErrors["transitDekRepository"]; exists { + return nil, storedErr + } + return c.transitDekRepository, nil +} + +// TransitKeyUseCase returns the transit key use case instance. +func (c *Container) TransitKeyUseCase() (transitUseCase.TransitKeyUseCase, error) { + var err error + c.transitKeyUseCaseInit.Do(func() { + c.transitKeyUseCase, err = c.initTransitKeyUseCase() + if err != nil { + c.initErrors["transitKeyUseCase"] = err + } + }) + if err != nil { + return nil, err + } + if storedErr, exists := c.initErrors["transitKeyUseCase"]; exists { + return nil, storedErr + } + return c.transitKeyUseCase, nil +} + +// TransitKeyHandler returns the transit key HTTP handler instance. +func (c *Container) TransitKeyHandler() (*transitHTTP.TransitKeyHandler, error) { + var err error + c.transitKeyHandlerInit.Do(func() { + c.transitKeyHandler, err = c.initTransitKeyHandler() + if err != nil { + c.initErrors["transitKeyHandler"] = err + } + }) + if err != nil { + return nil, err + } + if storedErr, exists := c.initErrors["transitKeyHandler"]; exists { + return nil, storedErr + } + return c.transitKeyHandler, nil +} + +// CryptoHandler returns the crypto HTTP handler instance. +func (c *Container) CryptoHandler() (*transitHTTP.CryptoHandler, error) { + var err error + c.cryptoHandlerInit.Do(func() { + c.cryptoHandler, err = c.initCryptoHandler() + if err != nil { + c.initErrors["cryptoHandler"] = err + } + }) + if err != nil { + return nil, err + } + if storedErr, exists := c.initErrors["cryptoHandler"]; exists { + return nil, storedErr + } + return c.cryptoHandler, nil +} + +// initTransitKeyRepository creates the transit key repository based on the database driver. +func (c *Container) initTransitKeyRepository() (transitUseCase.TransitKeyRepository, error) { + db, err := c.DB() + if err != nil { + return nil, fmt.Errorf("failed to get database for transit key repository: %w", err) + } + + switch c.config.DBDriver { + case "postgres": + return transitPostgreSQL.NewPostgreSQLTransitKeyRepository(db), nil + case "mysql": + return transitMySQL.NewMySQLTransitKeyRepository(db), nil + default: + return nil, fmt.Errorf("unsupported database driver: %s", c.config.DBDriver) + } +} + +// initTransitDekRepository creates the DEK repository for transit use case. +func (c *Container) initTransitDekRepository() (transitUseCase.DekRepository, error) { + db, err := c.DB() + if err != nil { + return nil, fmt.Errorf("failed to get database for transit dek repository: %w", err) + } + + switch c.config.DBDriver { + case "postgres": + return cryptoPostgreSQL.NewPostgreSQLDekRepository(db), nil + case "mysql": + return cryptoMySQL.NewMySQLDekRepository(db), nil + default: + return nil, fmt.Errorf("unsupported database driver: %s", c.config.DBDriver) + } +} + +// initTransitKeyUseCase creates the transit key use case with all its dependencies. +func (c *Container) initTransitKeyUseCase() (transitUseCase.TransitKeyUseCase, error) { + txManager, err := c.TxManager() + if err != nil { + return nil, fmt.Errorf("failed to get tx manager for transit key use case: %w", err) + } + + transitKeyRepository, err := c.TransitKeyRepository() + if err != nil { + return nil, fmt.Errorf("failed to get transit key repository for transit key use case: %w", err) + } + + dekRepository, err := c.TransitDekRepository() + if err != nil { + return nil, fmt.Errorf("failed to get dek repository for transit key use case: %w", err) + } + + kekChain, err := c.loadKekChain() + if err != nil { + return nil, fmt.Errorf("failed to load kek chain for transit key use case: %w", err) + } + + keyManager := c.KeyManager() + aeadManager := c.AEADManager() + + baseUseCase := transitUseCase.NewTransitKeyUseCase( + txManager, + transitKeyRepository, + dekRepository, + keyManager, + aeadManager, + kekChain, + ) + + // Wrap with metrics if enabled + if c.config.MetricsEnabled { + businessMetrics, err := c.BusinessMetrics() + if err != nil { + return nil, fmt.Errorf("failed to get business metrics for transit key use case: %w", err) + } + return transitUseCase.NewTransitKeyUseCaseWithMetrics(baseUseCase, businessMetrics), nil + } + + return baseUseCase, nil +} + +// initTransitKeyHandler creates the transit key HTTP handler with all its dependencies. +func (c *Container) initTransitKeyHandler() (*transitHTTP.TransitKeyHandler, error) { + transitKeyUseCase, err := c.TransitKeyUseCase() + if err != nil { + return nil, fmt.Errorf("failed to get transit key use case for transit key handler: %w", err) + } + + logger := c.Logger() + + return transitHTTP.NewTransitKeyHandler(transitKeyUseCase, logger), nil +} + +// initCryptoHandler creates the crypto HTTP handler with all its dependencies. +func (c *Container) initCryptoHandler() (*transitHTTP.CryptoHandler, error) { + transitKeyUseCase, err := c.TransitKeyUseCase() + if err != nil { + return nil, fmt.Errorf("failed to get transit key use case for crypto handler: %w", err) + } + + logger := c.Logger() + + return transitHTTP.NewCryptoHandler(transitKeyUseCase, logger), nil +} diff --git a/internal/auth/domain/audit_log_test.go b/internal/auth/domain/audit_log_test.go new file mode 100644 index 0000000..2bb9c71 --- /dev/null +++ b/internal/auth/domain/audit_log_test.go @@ -0,0 +1,105 @@ +package domain + +import ( + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" +) + +func TestAuditLog_HasValidSignature(t *testing.T) { + kekID := uuid.New() + + tests := []struct { + name string + log AuditLog + expected bool + }{ + { + name: "Valid signature", + log: AuditLog{ + IsSigned: true, + KekID: &kekID, + Signature: make([]byte, 32), + }, + expected: true, + }, + { + name: "Not signed", + log: AuditLog{ + IsSigned: false, + KekID: &kekID, + Signature: make([]byte, 32), + }, + expected: false, + }, + { + name: "Nil KekID", + log: AuditLog{ + IsSigned: true, + KekID: nil, + Signature: make([]byte, 32), + }, + expected: false, + }, + { + name: "Invalid signature length", + log: AuditLog{ + IsSigned: true, + KekID: &kekID, + Signature: make([]byte, 31), + }, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.expected, tt.log.HasValidSignature()) + }) + } +} + +func TestAuditLog_IsLegacy(t *testing.T) { + kekID := uuid.New() + + tests := []struct { + name string + log AuditLog + expected bool + }{ + { + name: "Legacy log", + log: AuditLog{ + IsSigned: false, + KekID: nil, + Signature: nil, + }, + expected: true, + }, + { + name: "Signed log", + log: AuditLog{ + IsSigned: true, + KekID: &kekID, + Signature: make([]byte, 32), + }, + expected: false, + }, + { + name: "Mixed state (not legacy)", + log: AuditLog{ + IsSigned: false, + KekID: &kekID, + Signature: nil, + }, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.expected, tt.log.IsLegacy()) + }) + } +} diff --git a/internal/auth/http/dto/request.go b/internal/auth/http/dto/request.go index 6cae892..8621cb2 100644 --- a/internal/auth/http/dto/request.go +++ b/internal/auth/http/dto/request.go @@ -10,9 +10,9 @@ import ( // CreateClientRequest contains the parameters for creating a new authentication client. type CreateClientRequest struct { - Name string `json:"name"` - IsActive bool `json:"is_active"` - Policies []authDomain.PolicyDocument `json:"policies"` + Name string `json:"name"` // Unique name for the client + IsActive bool `json:"is_active"` // Whether the client is active + Policies []authDomain.PolicyDocument `json:"policies"` // Access control policies } // Validate checks if the create client request is valid. @@ -32,9 +32,9 @@ func (r *CreateClientRequest) Validate() error { // UpdateClientRequest contains the parameters for updating an existing client. type UpdateClientRequest struct { - Name string `json:"name"` - IsActive bool `json:"is_active"` - Policies []authDomain.PolicyDocument `json:"policies"` + Name string `json:"name"` // New unique name for the client + IsActive bool `json:"is_active"` // New active status + Policies []authDomain.PolicyDocument `json:"policies"` // New access control policies } // Validate checks if the update client request is valid. @@ -74,7 +74,7 @@ func validatePolicyDocument(value interface{}) error { // IssueTokenRequest contains the parameters for issuing an authentication token. type IssueTokenRequest struct { - ClientID string `json:"client_id"` + ClientID string `json:"client_id"` // Client ID (UUID) ClientSecret string `json:"client_secret"` //nolint:gosec // API authentication field } diff --git a/internal/auth/http/dto/response.go b/internal/auth/http/dto/response.go index b8a6bfd..d546692 100644 --- a/internal/auth/http/dto/response.go +++ b/internal/auth/http/dto/response.go @@ -11,17 +11,17 @@ import ( // CreateClientResponse contains the result of creating a new client. // SECURITY: The secret is only returned once and must be saved securely. type CreateClientResponse struct { - ID string `json:"id"` + ID string `json:"id"` // Unique identifier (UUID) Secret string `json:"secret"` //nolint:gosec // returned once on creation } // ClientResponse represents a client in API responses (excludes secret). type ClientResponse struct { - ID string `json:"id"` - Name string `json:"name"` - IsActive bool `json:"is_active"` - Policies []authDomain.PolicyDocument `json:"policies"` - CreatedAt time.Time `json:"created_at"` + ID string `json:"id"` // Unique identifier (UUID) + Name string `json:"name"` // Unique name + IsActive bool `json:"is_active"` // Whether the client is active + Policies []authDomain.PolicyDocument `json:"policies"` // Access control policies + CreatedAt time.Time `json:"created_at"` // Creation timestamp } // MapClientToResponse converts a domain client to an API response. @@ -54,8 +54,8 @@ func MapClientsToListResponse(clients []*authDomain.Client) ListClientsResponse // IssueTokenResponse contains the result of issuing a token. // SECURITY: The token is only returned once and must be saved securely. type IssueTokenResponse struct { - Token string `json:"token"` - ExpiresAt time.Time `json:"expires_at"` + Token string `json:"token"` // Authentication token string + ExpiresAt time.Time `json:"expires_at"` // Token expiration timestamp } // AuditLogResponse represents an audit log entry in API responses. diff --git a/internal/auth/usecase/metrics_decorator_test.go b/internal/auth/usecase/metrics_decorator_test.go new file mode 100644 index 0000000..8de4239 --- /dev/null +++ b/internal/auth/usecase/metrics_decorator_test.go @@ -0,0 +1,127 @@ +package usecase_test + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + + authDomain "github.com/allisson/secrets/internal/auth/domain" + "github.com/allisson/secrets/internal/auth/usecase" + usecaseMocks "github.com/allisson/secrets/internal/auth/usecase/mocks" +) + +// mockBusinessMetrics is a local mock for metrics.BusinessMetrics to avoid dependency issues. +type mockBusinessMetrics struct { + mock.Mock +} + +func (m *mockBusinessMetrics) RecordOperation(ctx context.Context, domain, operation, status string) { + m.Called(ctx, domain, operation, status) +} + +func (m *mockBusinessMetrics) RecordDuration( + ctx context.Context, + domain, operation string, + duration time.Duration, + status string, +) { + m.Called(ctx, domain, operation, duration, status) +} + +func TestClientUseCaseWithMetrics(t *testing.T) { + mockNext := &usecaseMocks.MockClientUseCase{} + mockMetrics := &mockBusinessMetrics{} + uc := usecase.NewClientUseCaseWithMetrics(mockNext, mockMetrics) + + ctx := context.Background() + clientID := uuid.New() + + t.Run("Create success", func(t *testing.T) { + input := &authDomain.CreateClientInput{Name: "test"} + output := &authDomain.CreateClientOutput{ID: clientID} + + mockNext.On("Create", ctx, input).Return(output, nil).Once() + mockMetrics.On("RecordOperation", ctx, "auth", "client_create", "success").Return().Once() + mockMetrics.On("RecordDuration", ctx, "auth", "client_create", mock.AnythingOfType("time.Duration"), "success"). + Return(). + Once() + + res, err := uc.Create(ctx, input) + assert.NoError(t, err) + assert.Equal(t, output, res) + mockNext.AssertExpectations(t) + mockMetrics.AssertExpectations(t) + }) + + t.Run("Create error", func(t *testing.T) { + input := &authDomain.CreateClientInput{Name: "test"} + expectedErr := errors.New("error") + + mockNext.On("Create", ctx, input).Return(nil, expectedErr).Once() + mockMetrics.On("RecordOperation", ctx, "auth", "client_create", "error").Return().Once() + mockMetrics.On("RecordDuration", ctx, "auth", "client_create", mock.AnythingOfType("time.Duration"), "error"). + Return(). + Once() + + res, err := uc.Create(ctx, input) + assert.Error(t, err) + assert.Nil(t, res) + mockNext.AssertExpectations(t) + mockMetrics.AssertExpectations(t) + }) +} + +func TestTokenUseCaseWithMetrics(t *testing.T) { + mockNext := &usecaseMocks.MockTokenUseCase{} + mockMetrics := &mockBusinessMetrics{} + uc := usecase.NewTokenUseCaseWithMetrics(mockNext, mockMetrics) + + ctx := context.Background() + + t.Run("Issue success", func(t *testing.T) { + input := &authDomain.IssueTokenInput{ClientID: uuid.New()} + output := &authDomain.IssueTokenOutput{PlainToken: "token"} + + mockNext.On("Issue", ctx, input).Return(output, nil).Once() + mockMetrics.On("RecordOperation", ctx, "auth", "token_issue", "success").Return().Once() + mockMetrics.On("RecordDuration", ctx, "auth", "token_issue", mock.AnythingOfType("time.Duration"), "success"). + Return(). + Once() + + res, err := uc.Issue(ctx, input) + assert.NoError(t, err) + assert.Equal(t, output, res) + mockNext.AssertExpectations(t) + mockMetrics.AssertExpectations(t) + }) +} + +func TestAuditLogUseCaseWithMetrics(t *testing.T) { + mockNext := &usecaseMocks.MockAuditLogUseCase{} + mockMetrics := &mockBusinessMetrics{} + uc := usecase.NewAuditLogUseCaseWithMetrics(mockNext, mockMetrics) + + ctx := context.Background() + + t.Run("Create success", func(t *testing.T) { + requestID := uuid.New() + clientID := uuid.New() + mockNext.On("Create", ctx, requestID, clientID, authDomain.ReadCapability, "/test", mock.Anything). + Return(nil). + Once() + mockMetrics.On("RecordOperation", ctx, "auth", "audit_log_create", "success").Return().Once() + mockMetrics.On("RecordDuration", ctx, "auth", "audit_log_create", mock.AnythingOfType("time.Duration"), "success"). + Return(). + Once() + + err := uc.Create(ctx, requestID, clientID, authDomain.ReadCapability, "/test", nil) + assert.NoError(t, err) + mockNext.AssertExpectations(t) + mockMetrics.AssertExpectations(t) + }) +} diff --git a/internal/config/config.go b/internal/config/config.go index 1219770..55605a8 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -12,49 +12,63 @@ import ( // Config holds all application configuration. type Config struct { - // Server configuration + // ServerHost is the host address the server will bind to. ServerHost string + // ServerPort is the port number the server will listen on. ServerPort int - // Database configuration - DBDriver string - DBConnectionString string + // DBDriver is the database driver to use (e.g., "postgres", "mysql"). + DBDriver string + // DBConnectionString is the connection string for the database. + DBConnectionString string + // DBMaxOpenConnections is the maximum number of open connections to the database. DBMaxOpenConnections int + // DBMaxIdleConnections is the maximum number of idle connections in the database pool. DBMaxIdleConnections int - DBConnMaxLifetime time.Duration + // DBConnMaxLifetime is the maximum amount of time a connection may be reused. + DBConnMaxLifetime time.Duration - // Logging + // LogLevel is the logging level (e.g., "debug", "info", "warn", "error"). LogLevel string - // Auth + // AuthTokenExpiration is the duration after which an authentication token expires. AuthTokenExpiration time.Duration - // Rate Limiting (authenticated endpoints) - RateLimitEnabled bool + // RateLimitEnabled indicates whether rate limiting for authenticated endpoints is enabled. + RateLimitEnabled bool + // RateLimitRequestsPerSec is the number of requests allowed per second for authenticated endpoints. RateLimitRequestsPerSec float64 - RateLimitBurst int + // RateLimitBurst is the burst size for authenticated endpoints rate limiting. + RateLimitBurst int - // Rate Limiting for Token Endpoint (IP-based, unauthenticated) - RateLimitTokenEnabled bool + // RateLimitTokenEnabled indicates whether rate limiting for the token endpoint is enabled. + RateLimitTokenEnabled bool + // RateLimitTokenRequestsPerSec is the number of requests allowed per second for the token endpoint. RateLimitTokenRequestsPerSec float64 - RateLimitTokenBurst int + // RateLimitTokenBurst is the burst size for the token endpoint rate limiting. + RateLimitTokenBurst int - // CORS - CORSEnabled bool + // CORSEnabled indicates whether CORS is enabled. + CORSEnabled bool + // CORSAllowOrigins is a comma-separated list of allowed origins for CORS. CORSAllowOrigins string - // Metrics - MetricsEnabled bool + // MetricsEnabled indicates whether metrics collection is enabled. + MetricsEnabled bool + // MetricsNamespace is the namespace for the application metrics. MetricsNamespace string - MetricsPort int + // MetricsPort is the port number for the metrics server. + MetricsPort int - // KMS configuration + // KMSProvider is the KMS provider to use (e.g., "google", "aws", "azure"). KMSProvider string - KMSKeyURI string + // KMSKeyURI is the URI for the master key in the KMS. + KMSKeyURI string - // Account Lockout + // LockoutMaxAttempts is the maximum number of failed login attempts before a lockout. LockoutMaxAttempts int - LockoutDuration time.Duration + // LockoutDuration is the duration for which an account is locked out after maximum attempts. + LockoutDuration time.Duration } // Load loads configuration from environment variables and .env file. diff --git a/internal/config/config_test.go b/internal/config/config_test.go index a7beba0..8402d51 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -2,6 +2,7 @@ package config import ( "os" + "path/filepath" "testing" "time" @@ -116,10 +117,51 @@ func TestLoad(t *testing.T) { envVars: map[string]string{ "METRICS_ENABLED": "false", "METRICS_NAMESPACE": "custom", + "METRICS_PORT": "9091", }, validate: func(t *testing.T, cfg *Config) { assert.Equal(t, false, cfg.MetricsEnabled) assert.Equal(t, "custom", cfg.MetricsNamespace) + assert.Equal(t, 9091, cfg.MetricsPort) + }, + }, + { + name: "load custom rate limit token configuration", + envVars: map[string]string{ + "RATE_LIMIT_TOKEN_ENABLED": "false", + "RATE_LIMIT_TOKEN_REQUESTS_PER_SEC": "2.5", + "RATE_LIMIT_TOKEN_BURST": "5", + }, + validate: func(t *testing.T, cfg *Config) { + assert.Equal(t, false, cfg.RateLimitTokenEnabled) + assert.Equal(t, 2.5, cfg.RateLimitTokenRequestsPerSec) + assert.Equal(t, 5, cfg.RateLimitTokenBurst) + }, + }, + { + name: "load custom KMS configuration", + envVars: map[string]string{ + "KMS_PROVIDER": "google", + "KMS_KEY_URI": "gcpkms://projects/my-project/locations/global/keyRings/my-keyring/cryptoKeys/my-key", + }, + validate: func(t *testing.T, cfg *Config) { + assert.Equal(t, "google", cfg.KMSProvider) + assert.Equal( + t, + "gcpkms://projects/my-project/locations/global/keyRings/my-keyring/cryptoKeys/my-key", + cfg.KMSKeyURI, + ) + }, + }, + { + name: "load custom lockout configuration", + envVars: map[string]string{ + "LOCKOUT_MAX_ATTEMPTS": "5", + "LOCKOUT_DURATION_MINUTES": "15", + }, + validate: func(t *testing.T, cfg *Config) { + assert.Equal(t, 5, cfg.LockoutMaxAttempts) + assert.Equal(t, 15*time.Minute, cfg.LockoutDuration) }, }, } @@ -143,3 +185,62 @@ func TestLoad(t *testing.T) { }) } } + +func TestGetGinMode(t *testing.T) { + tests := []struct { + logLevel string + expected string + }{ + {"debug", "debug"}, + {"info", "release"}, + {"warn", "release"}, + {"error", "release"}, + {"fatal", "release"}, + {"panic", "release"}, + {"unknown", "release"}, + {"", "release"}, + } + + for _, tt := range tests { + t.Run(tt.logLevel, func(t *testing.T) { + cfg := &Config{LogLevel: tt.logLevel} + assert.Equal(t, tt.expected, cfg.GetGinMode()) + }) + } +} + +func TestLoadDotEnv(t *testing.T) { + // Create a temporary directory structure + tmpDir, err := os.MkdirTemp("", "config_test") + require.NoError(t, err) + defer func() { + _ = os.RemoveAll(tmpDir) + }() + + // Create a .env file in the temp root + err = os.WriteFile(filepath.Join(tmpDir, ".env"), []byte("TEST_ENV_VAR=found"), 0600) + require.NoError(t, err) + + // Create a child directory + childDir := filepath.Join(tmpDir, "child", "grandchild") + err = os.MkdirAll(childDir, 0700) + require.NoError(t, err) + + // Change working directory to childDir + oldCwd, err := os.Getwd() + require.NoError(t, err) + defer func() { + _ = os.Chdir(oldCwd) + }() + + err = os.Chdir(childDir) + require.NoError(t, err) + + // Load .env + loadDotEnv() + + // Verify the env var was loaded + assert.Equal(t, "found", os.Getenv("TEST_ENV_VAR")) + err = os.Unsetenv("TEST_ENV_VAR") + require.NoError(t, err) +} diff --git a/internal/crypto/domain/dek_test.go b/internal/crypto/domain/dek_test.go new file mode 100644 index 0000000..b62d1e2 --- /dev/null +++ b/internal/crypto/domain/dek_test.go @@ -0,0 +1,35 @@ +package domain + +import ( + "testing" + "time" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" +) + +func TestDek(t *testing.T) { + t.Run("dek initialization", func(t *testing.T) { + id := uuid.New() + kekID := uuid.New() + now := time.Now() + encryptedKey := []byte("encrypted-key") + nonce := []byte("nonce") + + dek := Dek{ + ID: id, + KekID: kekID, + Algorithm: AESGCM, + EncryptedKey: encryptedKey, + Nonce: nonce, + CreatedAt: now, + } + + assert.Equal(t, id, dek.ID) + assert.Equal(t, kekID, dek.KekID) + assert.Equal(t, AESGCM, dek.Algorithm) + assert.Equal(t, encryptedKey, dek.EncryptedKey) + assert.Equal(t, nonce, dek.Nonce) + assert.Equal(t, now, dek.CreatedAt) + }) +} diff --git a/internal/crypto/domain/kek_test.go b/internal/crypto/domain/kek_test.go new file mode 100644 index 0000000..5d945f3 --- /dev/null +++ b/internal/crypto/domain/kek_test.go @@ -0,0 +1,63 @@ +package domain + +import ( + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" +) + +func TestKekChain(t *testing.T) { + kek1 := &Kek{ID: uuid.New(), Key: []byte("key1-data-1234567890123456789012")} + kek2 := &Kek{ID: uuid.New(), Key: []byte("key2-data-1234567890123456789012")} + + t.Run("NewKekChain and ActiveKekID", func(t *testing.T) { + keks := []*Kek{kek1, kek2} + kc := NewKekChain(keks) + assert.Equal(t, kek1.ID, kc.ActiveKekID()) + }) + + t.Run("Get KEK", func(t *testing.T) { + kc := NewKekChain([]*Kek{kek1, kek2}) + + k, ok := kc.Get(kek1.ID) + assert.True(t, ok) + assert.Equal(t, kek1, k) + + k, ok = kc.Get(kek2.ID) + assert.True(t, ok) + assert.Equal(t, kek2, k) + + k, ok = kc.Get(uuid.New()) + assert.False(t, ok) + assert.Nil(t, k) + }) + + t.Run("Close zeros all keys", func(t *testing.T) { + k1Data := make([]byte, 32) + copy(k1Data, []byte("key1-data-1234567890123456789012")) + k2Data := make([]byte, 32) + copy(k2Data, []byte("key2-data-1234567890123456789012")) + + k1 := &Kek{ID: uuid.New(), Key: k1Data} + k2 := &Kek{ID: uuid.New(), Key: k2Data} + + kc := NewKekChain([]*Kek{k1, k2}) + kc.Close() + + assert.Equal(t, uuid.Nil, kc.ActiveKekID()) + _, ok := kc.Get(k1.ID) + assert.False(t, ok) + + expectedZero := make([]byte, 32) + assert.Equal(t, expectedZero, k1.Key) + assert.Equal(t, expectedZero, k2.Key) + }) + + t.Run("NewKekChain with empty slice", func(t *testing.T) { + kc := NewKekChain([]*Kek{}) + assert.Equal(t, uuid.Nil, kc.ActiveKekID()) + _, ok := kc.Get(uuid.New()) + assert.False(t, ok) + }) +} diff --git a/internal/crypto/domain/zero_test.go b/internal/crypto/domain/zero_test.go new file mode 100644 index 0000000..8def49b --- /dev/null +++ b/internal/crypto/domain/zero_test.go @@ -0,0 +1,39 @@ +package domain + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestZero(t *testing.T) { + t.Run("zero non-empty slice", func(t *testing.T) { + b := []byte{1, 2, 3, 4, 5} + Zero(b) + for _, v := range b { + assert.Equal(t, byte(0), v) + } + }) + + t.Run("zero empty slice", func(t *testing.T) { + b := []byte{} + Zero(b) + assert.Equal(t, 0, len(b)) + }) + + t.Run("zero nil slice", func(t *testing.T) { + var b []byte + assert.NotPanics(t, func() { Zero(b) }) + }) + + t.Run("zero large slice", func(t *testing.T) { + b := make([]byte, 1024) + for i := range b { + b[i] = byte(i % 256) + } + Zero(b) + for _, v := range b { + assert.Equal(t, byte(0), v) + } + }) +} diff --git a/internal/crypto/repository/postgresql/postgresql_kek_repository.go b/internal/crypto/repository/postgresql/postgresql_kek_repository.go index e71e2a7..55b5597 100644 --- a/internal/crypto/repository/postgresql/postgresql_kek_repository.go +++ b/internal/crypto/repository/postgresql/postgresql_kek_repository.go @@ -75,7 +75,9 @@ func (p *PostgreSQLKekRepository) Update(ctx context.Context, kek *cryptoDomain. func (p *PostgreSQLKekRepository) List(ctx context.Context) ([]*cryptoDomain.Kek, error) { querier := database.GetTx(ctx, p.db) - query := `SELECT * FROM keks ORDER BY version DESC` + query := `SELECT id, master_key_id, algorithm, encrypted_key, nonce, version, created_at + FROM keks + ORDER BY version DESC` rows, err := querier.QueryContext(ctx, query) if err != nil { diff --git a/internal/crypto/service/mocks/mocks.go b/internal/crypto/service/mocks/mocks.go index f2e830e..a59f7ff 100644 --- a/internal/crypto/service/mocks/mocks.go +++ b/internal/crypto/service/mocks/mocks.go @@ -21,8 +21,6 @@ func NewMockAEAD(t interface { t.Cleanup(func() { mock.AssertExpectations(t) }) - mock.EXPECT().NonceSize().Return(12).Maybe() - return mock } @@ -67,24 +65,6 @@ func (_mock *MockAEAD) Decrypt(ciphertext []byte, nonce []byte, aad []byte) ([]b return r0, r1 } -// NonceSize provides a mock function for the type MockAEAD -func (_mock *MockAEAD) NonceSize() int { - ret := _mock.Called() - - if len(ret) == 0 { - return 12 - } - - var r0 int - if returnFunc, ok := ret.Get(0).(func() int); ok { - r0 = returnFunc() - } else { - r0 = ret.Get(0).(int) - } - - return r0 -} - // MockAEAD_Decrypt_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Decrypt' type MockAEAD_Decrypt_Call struct { *mock.Call @@ -98,33 +78,6 @@ func (_e *MockAEAD_Expecter) Decrypt(ciphertext interface{}, nonce interface{}, return &MockAEAD_Decrypt_Call{Call: _e.mock.On("Decrypt", ciphertext, nonce, aad)} } -// NonceSize is a helper method to define mock.On call -func (_e *MockAEAD_Expecter) NonceSize() *MockAEAD_NonceSize_Call { - return &MockAEAD_NonceSize_Call{Call: _e.mock.On("NonceSize")} -} - -// MockAEAD_NonceSize_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'NonceSize' -type MockAEAD_NonceSize_Call struct { - *mock.Call -} - -func (_c *MockAEAD_NonceSize_Call) Run(run func()) *MockAEAD_NonceSize_Call { - _c.Call.Run(func(args mock.Arguments) { - run() - }) - return _c -} - -func (_c *MockAEAD_NonceSize_Call) Return(_a0 int) *MockAEAD_NonceSize_Call { - _c.Call.Return(_a0) - return _c -} - -func (_c *MockAEAD_NonceSize_Call) RunAndReturn(run func() int) *MockAEAD_NonceSize_Call { - _c.Call.Return(run) - return _c -} - func (_c *MockAEAD_Decrypt_Call) Run(run func(ciphertext []byte, nonce []byte, aad []byte)) *MockAEAD_Decrypt_Call { _c.Call.Run(func(args mock.Arguments) { var arg0 []byte @@ -234,6 +187,50 @@ func (_c *MockAEAD_Encrypt_Call) RunAndReturn(run func(plaintext []byte, aad []b return _c } +// NonceSize provides a mock function for the type MockAEAD +func (_mock *MockAEAD) NonceSize() int { + ret := _mock.Called() + + if len(ret) == 0 { + panic("no return value specified for NonceSize") + } + + var r0 int + if returnFunc, ok := ret.Get(0).(func() int); ok { + r0 = returnFunc() + } else { + r0 = ret.Get(0).(int) + } + return r0 +} + +// MockAEAD_NonceSize_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'NonceSize' +type MockAEAD_NonceSize_Call struct { + *mock.Call +} + +// NonceSize is a helper method to define mock.On call +func (_e *MockAEAD_Expecter) NonceSize() *MockAEAD_NonceSize_Call { + return &MockAEAD_NonceSize_Call{Call: _e.mock.On("NonceSize")} +} + +func (_c *MockAEAD_NonceSize_Call) Run(run func()) *MockAEAD_NonceSize_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockAEAD_NonceSize_Call) Return(n int) *MockAEAD_NonceSize_Call { + _c.Call.Return(n) + return _c +} + +func (_c *MockAEAD_NonceSize_Call) RunAndReturn(run func() int) *MockAEAD_NonceSize_Call { + _c.Call.Return(run) + return _c +} + // NewMockAEADManager creates a new instance of MockAEADManager. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. // The first argument is typically a *testing.T value. func NewMockAEADManager(t interface { diff --git a/internal/crypto/usecase/dek_usecase.go b/internal/crypto/usecase/dek_usecase.go index 02c9a57..3447ce4 100644 --- a/internal/crypto/usecase/dek_usecase.go +++ b/internal/crypto/usecase/dek_usecase.go @@ -10,6 +10,8 @@ import ( "github.com/allisson/secrets/internal/database" ) +// dekUseCase implements business logic for Data Encryption Key management. +// Orchestrates DEK rewrapping during KEK rotation. type dekUseCase struct { txManager database.TxManager dekRepo DekRepository diff --git a/internal/database/database.go b/internal/database/database.go index 9518856..b5de103 100644 --- a/internal/database/database.go +++ b/internal/database/database.go @@ -12,14 +12,15 @@ import ( // Config holds database configuration settings. type Config struct { - Driver string - ConnectionString string - MaxOpenConnections int - MaxIdleConnections int - ConnMaxLifetime time.Duration + Driver string // Database driver name (e.g., "postgres", "mysql"). + ConnectionString string // Connection string for the database. + MaxOpenConnections int // Maximum number of open connections to the database. + MaxIdleConnections int // Maximum number of idle connections in the pool. + ConnMaxLifetime time.Duration // Maximum amount of time a connection may be reused. } // Connect establishes a database connection with the given configuration. +// It sets connection pool settings and verifies the connection with a ping. func Connect(cfg Config) (*sql.DB, error) { db, err := sql.Open(cfg.Driver, cfg.ConnectionString) if err != nil { diff --git a/internal/database/database_test.go b/internal/database/database_test.go new file mode 100644 index 0000000..c9b33e6 --- /dev/null +++ b/internal/database/database_test.go @@ -0,0 +1,23 @@ +package database + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestConnect_Error(t *testing.T) { + cfg := Config{ + Driver: "invalid", + ConnectionString: "invalid", + MaxOpenConnections: 10, + MaxIdleConnections: 5, + ConnMaxLifetime: time.Hour, + } + + db, err := Connect(cfg) + assert.Error(t, err) + assert.Nil(t, db) + assert.Contains(t, err.Error(), "sql: unknown driver") +} diff --git a/internal/database/txmanager.go b/internal/database/txmanager.go index 9ac07f0..c487f9c 100644 --- a/internal/database/txmanager.go +++ b/internal/database/txmanager.go @@ -4,6 +4,7 @@ package database import ( "context" "database/sql" + "fmt" ) // txKey is a context key type for storing database transactions. @@ -38,16 +39,30 @@ func (m *sqlTxManager) WithTx(ctx context.Context, fn func(ctx context.Context) return err } + defer func() { + if r := recover(); r != nil { + _ = tx.Rollback() + panic(r) + } + }() + ctx = context.WithValue(ctx, txKey{}, tx) if err := fn(ctx); err != nil { if rbErr := tx.Rollback(); rbErr != nil { - return rbErr + return fmt.Errorf("error executing function (original error: %v): %w", err, rbErr) + } + return err + } + + if err := tx.Commit(); err != nil { + if rbErr := tx.Rollback(); rbErr != nil { + return fmt.Errorf("error committing transaction (commit error: %v): %w", err, rbErr) } return err } - return tx.Commit() + return nil } // GetTx retrieves a transaction from context, or returns the DB connection. diff --git a/internal/database/txmanager_test.go b/internal/database/txmanager_test.go index 96b1a37..b074a5d 100644 --- a/internal/database/txmanager_test.go +++ b/internal/database/txmanager_test.go @@ -93,3 +93,16 @@ func TestGetTx_WithoutTransaction(t *testing.T) { assert.NotNil(t, querier) assert.Equal(t, db, querier) } +func TestWithTx_Panic(t *testing.T) { + db := testutil.SetupPostgresDB(t) + defer testutil.TeardownDB(t, db) + + txManager := NewTxManager(db) + ctx := context.Background() + + assert.Panics(t, func() { + _ = txManager.WithTx(ctx, func(ctx context.Context) error { + panic("something went wrong") + }) + }) +} diff --git a/internal/errors/errors.go b/internal/errors/errors.go index cfa3e0b..2d1310a 100644 --- a/internal/errors/errors.go +++ b/internal/errors/errors.go @@ -40,6 +40,15 @@ func Wrap(err error, message string) error { return fmt.Errorf("%s: %w", message, err) } +// Wrapf wraps an error with additional context using a format string. +func Wrapf(err error, format string, args ...any) error { + if err == nil { + return nil + } + message := fmt.Sprintf(format, args...) + return fmt.Errorf("%s: %w", message, err) +} + // Is reports whether any error in err's tree matches target. func Is(err, target error) bool { return errors.Is(err, target) diff --git a/internal/errors/errors_test.go b/internal/errors/errors_test.go new file mode 100644 index 0000000..dfbd003 --- /dev/null +++ b/internal/errors/errors_test.go @@ -0,0 +1,120 @@ +package errors + +import ( + "errors" + "testing" +) + +type customError struct { + Msg string +} + +func (e customError) Error() string { return e.Msg } + +func TestNew(t *testing.T) { + err := New("test error") + if err == nil { + t.Fatal("expected error, got nil") + } + if err.Error() != "test error" { + t.Errorf("expected 'test error', got '%s'", err.Error()) + } +} + +func TestWrap(t *testing.T) { + baseErr := errors.New("base error") + + t.Run("wrap non-nil error", func(t *testing.T) { + wrapped := Wrap(baseErr, "wrapped") + if wrapped == nil { + t.Fatal("expected wrapped error, got nil") + } + expected := "wrapped: base error" + if wrapped.Error() != expected { + t.Errorf("expected '%s', got '%s'", expected, wrapped.Error()) + } + if !errors.Is(wrapped, baseErr) { + t.Error("expected wrapped error to wrap baseErr") + } + }) + + t.Run("wrap nil error", func(t *testing.T) { + wrapped := Wrap(nil, "wrapped") + if wrapped != nil { + t.Errorf("expected nil, got %v", wrapped) + } + }) +} + +func TestWrapf(t *testing.T) { + baseErr := errors.New("base error") + + t.Run("wrapf non-nil error", func(t *testing.T) { + wrapped := Wrapf(baseErr, "wrapped %d", 123) + if wrapped == nil { + t.Fatal("expected wrapped error, got nil") + } + expected := "wrapped 123: base error" + if wrapped.Error() != expected { + t.Errorf("expected '%s', got '%s'", expected, wrapped.Error()) + } + if !errors.Is(wrapped, baseErr) { + t.Error("expected wrapped error to wrap baseErr") + } + }) + + t.Run("wrapf nil error", func(t *testing.T) { + wrapped := Wrapf(nil, "wrapped %d", 123) + if wrapped != nil { + t.Errorf("expected nil, got %v", wrapped) + } + }) +} + +func TestIs(t *testing.T) { + if !Is(ErrNotFound, ErrNotFound) { + t.Error("expected ErrNotFound to be ErrNotFound") + } + + wrapped := Wrap(ErrNotFound, "context") + if !Is(wrapped, ErrNotFound) { + t.Error("expected wrapped ErrNotFound to be ErrNotFound") + } + + if Is(ErrNotFound, ErrConflict) { + t.Error("expected ErrNotFound NOT to be ErrConflict") + } +} + +func TestAs(t *testing.T) { + custom := customError{Msg: "custom"} + wrapped := Wrap(custom, "context") + + var target customError + if !As(wrapped, &target) { + t.Fatal("expected wrapped error to be able to extract target") + } + if target.Msg != "custom" { + t.Errorf("expected 'custom', got '%s'", target.Msg) + } +} + +func TestStandardErrors(t *testing.T) { + tests := []struct { + err error + text string + }{ + {ErrNotFound, "not found"}, + {ErrConflict, "conflict"}, + {ErrInvalidInput, "invalid input"}, + {ErrUnauthorized, "unauthorized"}, + {ErrForbidden, "forbidden"}, + {ErrLocked, "locked"}, + } + + for _, tt := range tests { + if tt.err.Error() != tt.text { + t.Errorf("expected text '%s' for error, got '%s'", tt.text, tt.err.Error()) + } + } +} diff --git a/internal/http/http_test.go b/internal/http/http_test.go index 751d64a..1dc4dcf 100644 --- a/internal/http/http_test.go +++ b/internal/http/http_test.go @@ -271,24 +271,10 @@ func TestMetricsServer_Endpoints(t *testing.T) { metricsServer := NewMetricsServer("localhost", 8081, logger, provider) require.NotNil(t, metricsServer) - // We need to test the handler directly since we can't easily start the http.Server in a test without binding ports - // The internal router is not exposed, but we can verify the behavior by creating a similar router - // or by trusting that NewMetricsServer uses the same logic. - // However, for unit testing the logic inside NewMetricsServer, we can just recreate the router logic here - // or relies on integration tests. - // Better approach: Test NewMetricsServer initialization and ensure it doesn't panic. - // But we want to test that /metrics is registered. - // Let's modify NewMetricsServer to allow accessing the handler for testing or just test the router construction. - - // Since we cannot access the router inside metricsServer (it's private in http.Server), - // we will replicate the router construction logic effectively testing the configuration. - router := gin.New() - router.Use(gin.Recovery()) - router.GET("/metrics", gin.WrapH(provider.Handler())) - + // Test the handler from metricsServer exactly as it's configured w := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/metrics", nil) - router.ServeHTTP(w, req) + metricsServer.GetHandler().ServeHTTP(w, req) assert.Equal(t, http.StatusOK, w.Code) assert.Contains(t, w.Header().Get("Content-Type"), "text/plain") diff --git a/internal/http/metrics_server.go b/internal/http/metrics_server.go index 6eb230f..916ec59 100644 --- a/internal/http/metrics_server.go +++ b/internal/http/metrics_server.go @@ -27,6 +27,7 @@ func NewMetricsServer( ) *MetricsServer { router := gin.New() router.Use(gin.Recovery()) + router.Use(CustomLoggerMiddleware(logger)) if metricsProvider != nil { router.GET("/metrics", gin.WrapH(metricsProvider.Handler())) @@ -44,6 +45,11 @@ func NewMetricsServer( } } +// GetHandler returns the http.Handler for testing purposes. +func (s *MetricsServer) GetHandler() http.Handler { + return s.server.Handler +} + // Start starts the metrics HTTP server. func (s *MetricsServer) Start(ctx context.Context) error { s.logger.Info("starting metrics server", slog.String("addr", s.server.Addr)) diff --git a/internal/http/server.go b/internal/http/server.go index 3aca5cf..55abe60 100644 --- a/internal/http/server.go +++ b/internal/http/server.go @@ -107,23 +107,72 @@ func (s *Server) SetupRouter( router.GET("/health", s.healthHandler) router.GET("/ready", s.readinessHandler) - // Create authentication middleware - authMiddleware := authHTTP.AuthenticationMiddleware( - tokenUseCase, - tokenService, - s.logger, - ) - - // Create rate limit middleware (applied to authenticated routes only) - var rateLimitMiddleware gin.HandlerFunc - if cfg.RateLimitEnabled { - rateLimitMiddleware = authHTTP.RateLimitMiddleware( - cfg.RateLimitRequestsPerSec, - cfg.RateLimitBurst, + // API v1 routes + v1 := router.Group("/v1") + { + // Create authentication middleware + authMiddleware := authHTTP.AuthenticationMiddleware( + tokenUseCase, + tokenService, s.logger, ) + + // Create rate limit middleware + var rateLimitMiddleware gin.HandlerFunc + if cfg.RateLimitEnabled { + rateLimitMiddleware = authHTTP.RateLimitMiddleware( + cfg.RateLimitRequestsPerSec, + cfg.RateLimitBurst, + s.logger, + ) + } + + s.registerAuthRoutes( + v1, + cfg, + clientHandler, + tokenHandler, + auditLogHandler, + tokenUseCase, + tokenService, + auditLogUseCase, + authMiddleware, + rateLimitMiddleware, + ) + s.registerSecretRoutes(v1, secretHandler, authMiddleware, rateLimitMiddleware, auditLogUseCase) + s.registerTransitRoutes( + v1, + transitKeyHandler, + cryptoHandler, + authMiddleware, + rateLimitMiddleware, + auditLogUseCase, + ) + s.registerTokenizationRoutes( + v1, + tokenizationKeyHandler, + tokenizationHandler, + authMiddleware, + rateLimitMiddleware, + auditLogUseCase, + ) } + s.router = router +} + +func (s *Server) registerAuthRoutes( + v1 *gin.RouterGroup, + cfg *config.Config, + clientHandler *authHTTP.ClientHandler, + tokenHandler *authHTTP.TokenHandler, + auditLogHandler *authHTTP.AuditLogHandler, + tokenUseCase authUseCase.TokenUseCase, + tokenService authService.TokenService, + auditLogUseCase authUseCase.AuditLogUseCase, + authMiddleware gin.HandlerFunc, + rateLimitMiddleware gin.HandlerFunc, +) { // Create token rate limit middleware (IP-based, for unauthenticated token endpoint) var tokenRateLimitMiddleware gin.HandlerFunc if cfg.RateLimitTokenEnabled { @@ -134,195 +183,215 @@ func (s *Server) SetupRouter( ) } - // API v1 routes - v1 := router.Group("/v1") + // Token issuance endpoint (no authentication required, IP-based rate limiting) + if tokenRateLimitMiddleware != nil { + v1.POST("/token", tokenRateLimitMiddleware, tokenHandler.IssueTokenHandler) + } else { + v1.POST("/token", tokenHandler.IssueTokenHandler) + } + + // Client management endpoints + clients := v1.Group("/clients") + clients.Use(authMiddleware) + if rateLimitMiddleware != nil { + clients.Use(rateLimitMiddleware) + } { - // Token issuance endpoint (no authentication required, IP-based rate limiting) - if tokenRateLimitMiddleware != nil { - v1.POST("/token", tokenRateLimitMiddleware, tokenHandler.IssueTokenHandler) - } else { - v1.POST("/token", tokenHandler.IssueTokenHandler) - } + clients.POST("", + authHTTP.AuthorizationMiddleware(authDomain.WriteCapability, auditLogUseCase, s.logger), + clientHandler.CreateHandler, + ) + clients.GET("", + authHTTP.AuthorizationMiddleware(authDomain.ReadCapability, auditLogUseCase, s.logger), + clientHandler.ListHandler, + ) + clients.GET("/:id", + authHTTP.AuthorizationMiddleware(authDomain.ReadCapability, auditLogUseCase, s.logger), + clientHandler.GetHandler, + ) + clients.PUT("/:id", + authHTTP.AuthorizationMiddleware(authDomain.WriteCapability, auditLogUseCase, s.logger), + clientHandler.UpdateHandler, + ) + clients.DELETE("/:id", + authHTTP.AuthorizationMiddleware(authDomain.DeleteCapability, auditLogUseCase, s.logger), + clientHandler.DeleteHandler, + ) + clients.POST("/:id/unlock", + authHTTP.AuthorizationMiddleware(authDomain.WriteCapability, auditLogUseCase, s.logger), + clientHandler.UnlockHandler, + ) + } - // Client management endpoints - clients := v1.Group("/clients") - clients.Use(authMiddleware) // All client routes require authentication - if rateLimitMiddleware != nil { - clients.Use(rateLimitMiddleware) // Apply rate limiting to authenticated clients - } + // Audit log endpoints + auditLogs := v1.Group("/audit-logs") + auditLogs.Use(authMiddleware) + if rateLimitMiddleware != nil { + auditLogs.Use(rateLimitMiddleware) + } + { + auditLogs.GET("", + authHTTP.AuthorizationMiddleware(authDomain.ReadCapability, auditLogUseCase, s.logger), + auditLogHandler.ListHandler, + ) + } +} + +func (s *Server) registerSecretRoutes( + v1 *gin.RouterGroup, + secretHandler *secretsHTTP.SecretHandler, + authMiddleware gin.HandlerFunc, + rateLimitMiddleware gin.HandlerFunc, + auditLogUseCase authUseCase.AuditLogUseCase, +) { + // Secret management endpoints + secrets := v1.Group("/secrets") + secrets.Use(authMiddleware) + if rateLimitMiddleware != nil { + secrets.Use(rateLimitMiddleware) + } + { + secrets.GET("", + authHTTP.AuthorizationMiddleware(authDomain.ReadCapability, auditLogUseCase, s.logger), + secretHandler.ListHandler, + ) + secrets.POST("/*path", + authHTTP.AuthorizationMiddleware(authDomain.EncryptCapability, auditLogUseCase, s.logger), + secretHandler.CreateOrUpdateHandler, + ) + secrets.GET("/*path", + authHTTP.AuthorizationMiddleware(authDomain.DecryptCapability, auditLogUseCase, s.logger), + secretHandler.GetHandler, + ) + secrets.DELETE("/*path", + authHTTP.AuthorizationMiddleware(authDomain.DeleteCapability, auditLogUseCase, s.logger), + secretHandler.DeleteHandler, + ) + } +} + +func (s *Server) registerTransitRoutes( + v1 *gin.RouterGroup, + transitKeyHandler *transitHTTP.TransitKeyHandler, + cryptoHandler *transitHTTP.CryptoHandler, + authMiddleware gin.HandlerFunc, + rateLimitMiddleware gin.HandlerFunc, + auditLogUseCase authUseCase.AuditLogUseCase, +) { + // Transit encryption endpoints + transit := v1.Group("/transit") + transit.Use(authMiddleware) + if rateLimitMiddleware != nil { + transit.Use(rateLimitMiddleware) + } + { + keys := transit.Group("/keys") { - clients.POST("", - authHTTP.AuthorizationMiddleware(authDomain.WriteCapability, auditLogUseCase, s.logger), - clientHandler.CreateHandler, - ) - clients.GET("", + // List transit keys + keys.GET("", authHTTP.AuthorizationMiddleware(authDomain.ReadCapability, auditLogUseCase, s.logger), - clientHandler.ListHandler, + transitKeyHandler.ListHandler, ) - clients.GET("/:id", - authHTTP.AuthorizationMiddleware(authDomain.ReadCapability, auditLogUseCase, s.logger), - clientHandler.GetHandler, - ) - clients.PUT("/:id", - authHTTP.AuthorizationMiddleware(authDomain.WriteCapability, auditLogUseCase, s.logger), - clientHandler.UpdateHandler, - ) - clients.DELETE("/:id", - authHTTP.AuthorizationMiddleware(authDomain.DeleteCapability, auditLogUseCase, s.logger), - clientHandler.DeleteHandler, - ) - clients.POST("/:id/unlock", + + // Create new transit key + keys.POST("", authHTTP.AuthorizationMiddleware(authDomain.WriteCapability, auditLogUseCase, s.logger), - clientHandler.UnlockHandler, + transitKeyHandler.CreateHandler, ) - } - // Audit log endpoints - auditLogs := v1.Group("/audit-logs") - auditLogs.Use(authMiddleware) // All audit log routes require authentication - if rateLimitMiddleware != nil { - auditLogs.Use(rateLimitMiddleware) // Apply rate limiting to authenticated clients - } - { - auditLogs.GET("", - authHTTP.AuthorizationMiddleware(authDomain.ReadCapability, auditLogUseCase, s.logger), - auditLogHandler.ListHandler, + // Rotate transit key to new version + keys.POST("/:name/rotate", + authHTTP.AuthorizationMiddleware(authDomain.RotateCapability, auditLogUseCase, s.logger), + transitKeyHandler.RotateHandler, ) - } - // Secret management endpoints - secrets := v1.Group("/secrets") - secrets.Use(authMiddleware) // All secret routes require authentication - if rateLimitMiddleware != nil { - secrets.Use(rateLimitMiddleware) // Apply rate limiting to authenticated clients - } - { - secrets.GET("", - authHTTP.AuthorizationMiddleware(authDomain.ReadCapability, auditLogUseCase, s.logger), - secretHandler.ListHandler, + // Delete transit key + keys.DELETE("/:id", + authHTTP.AuthorizationMiddleware(authDomain.DeleteCapability, auditLogUseCase, s.logger), + transitKeyHandler.DeleteHandler, ) - secrets.POST("/*path", + + // Encrypt plaintext with transit key + keys.POST("/:name/encrypt", authHTTP.AuthorizationMiddleware(authDomain.EncryptCapability, auditLogUseCase, s.logger), - secretHandler.CreateOrUpdateHandler, + cryptoHandler.EncryptHandler, ) - secrets.GET("/*path", + + // Decrypt ciphertext with transit key + keys.POST("/:name/decrypt", authHTTP.AuthorizationMiddleware(authDomain.DecryptCapability, auditLogUseCase, s.logger), - secretHandler.GetHandler, - ) - secrets.DELETE("/*path", - authHTTP.AuthorizationMiddleware(authDomain.DeleteCapability, auditLogUseCase, s.logger), - secretHandler.DeleteHandler, + cryptoHandler.DecryptHandler, ) } + } +} - // Transit encryption endpoints - transit := v1.Group("/transit") - transit.Use(authMiddleware) // All transit routes require authentication - if rateLimitMiddleware != nil { - transit.Use(rateLimitMiddleware) // Apply rate limiting to authenticated clients - } +func (s *Server) registerTokenizationRoutes( + v1 *gin.RouterGroup, + tokenizationKeyHandler *tokenizationHTTP.TokenizationKeyHandler, + tokenizationHandler *tokenizationHTTP.TokenizationHandler, + authMiddleware gin.HandlerFunc, + rateLimitMiddleware gin.HandlerFunc, + auditLogUseCase authUseCase.AuditLogUseCase, +) { + // Tokenization endpoints + tokenization := v1.Group("/tokenization") + tokenization.Use(authMiddleware) + if rateLimitMiddleware != nil { + tokenization.Use(rateLimitMiddleware) + } + { + keys := tokenization.Group("/keys") { - keys := transit.Group("/keys") - { - // List transit keys - keys.GET("", - authHTTP.AuthorizationMiddleware(authDomain.ReadCapability, auditLogUseCase, s.logger), - transitKeyHandler.ListHandler, - ) - - // Create new transit key - keys.POST("", - authHTTP.AuthorizationMiddleware(authDomain.WriteCapability, auditLogUseCase, s.logger), - transitKeyHandler.CreateHandler, - ) - - // Rotate transit key to new version - keys.POST("/:name/rotate", - authHTTP.AuthorizationMiddleware(authDomain.RotateCapability, auditLogUseCase, s.logger), - transitKeyHandler.RotateHandler, - ) - - // Delete transit key - keys.DELETE("/:id", - authHTTP.AuthorizationMiddleware(authDomain.DeleteCapability, auditLogUseCase, s.logger), - transitKeyHandler.DeleteHandler, - ) - - // Encrypt plaintext with transit key - keys.POST("/:name/encrypt", - authHTTP.AuthorizationMiddleware(authDomain.EncryptCapability, auditLogUseCase, s.logger), - cryptoHandler.EncryptHandler, - ) - - // Decrypt ciphertext with transit key - keys.POST("/:name/decrypt", - authHTTP.AuthorizationMiddleware(authDomain.DecryptCapability, auditLogUseCase, s.logger), - cryptoHandler.DecryptHandler, - ) - } - } + // List tokenization keys + keys.GET("", + authHTTP.AuthorizationMiddleware(authDomain.ReadCapability, auditLogUseCase, s.logger), + tokenizationKeyHandler.ListHandler, + ) - // Tokenization endpoints - tokenization := v1.Group("/tokenization") - tokenization.Use(authMiddleware) // All tokenization routes require authentication - if rateLimitMiddleware != nil { - tokenization.Use(rateLimitMiddleware) // Apply rate limiting to authenticated clients - } - { - keys := tokenization.Group("/keys") - { - // List tokenization keys - keys.GET("", - authHTTP.AuthorizationMiddleware(authDomain.ReadCapability, auditLogUseCase, s.logger), - tokenizationKeyHandler.ListHandler, - ) - - // Create new tokenization key - keys.POST("", - authHTTP.AuthorizationMiddleware(authDomain.WriteCapability, auditLogUseCase, s.logger), - tokenizationKeyHandler.CreateHandler, - ) - - // Rotate tokenization key to new version - keys.POST("/:name/rotate", - authHTTP.AuthorizationMiddleware(authDomain.RotateCapability, auditLogUseCase, s.logger), - tokenizationKeyHandler.RotateHandler, - ) - - // Delete tokenization key - keys.DELETE("/:id", - authHTTP.AuthorizationMiddleware(authDomain.DeleteCapability, auditLogUseCase, s.logger), - tokenizationKeyHandler.DeleteHandler, - ) - - // Tokenize plaintext with tokenization key - keys.POST("/:name/tokenize", - authHTTP.AuthorizationMiddleware(authDomain.EncryptCapability, auditLogUseCase, s.logger), - tokenizationHandler.TokenizeHandler, - ) - } - - // Detokenize token to retrieve plaintext - tokenization.POST("/detokenize", - authHTTP.AuthorizationMiddleware(authDomain.DecryptCapability, auditLogUseCase, s.logger), - tokenizationHandler.DetokenizeHandler, + // Create new tokenization key + keys.POST("", + authHTTP.AuthorizationMiddleware(authDomain.WriteCapability, auditLogUseCase, s.logger), + tokenizationKeyHandler.CreateHandler, ) - // Validate token existence and validity - tokenization.POST("/validate", - authHTTP.AuthorizationMiddleware(authDomain.ReadCapability, auditLogUseCase, s.logger), - tokenizationHandler.ValidateHandler, + // Rotate tokenization key to new version + keys.POST("/:name/rotate", + authHTTP.AuthorizationMiddleware(authDomain.RotateCapability, auditLogUseCase, s.logger), + tokenizationKeyHandler.RotateHandler, ) - // Revoke token to prevent further detokenization - tokenization.POST("/revoke", + // Delete tokenization key + keys.DELETE("/:id", authHTTP.AuthorizationMiddleware(authDomain.DeleteCapability, auditLogUseCase, s.logger), - tokenizationHandler.RevokeHandler, + tokenizationKeyHandler.DeleteHandler, + ) + + // Tokenize plaintext with tokenization key + keys.POST("/:name/tokenize", + authHTTP.AuthorizationMiddleware(authDomain.EncryptCapability, auditLogUseCase, s.logger), + tokenizationHandler.TokenizeHandler, ) } - } - s.router = router + // Detokenize token to retrieve plaintext + tokenization.POST("/detokenize", + authHTTP.AuthorizationMiddleware(authDomain.DecryptCapability, auditLogUseCase, s.logger), + tokenizationHandler.DetokenizeHandler, + ) + + // Validate token existence and validity + tokenization.POST("/validate", + authHTTP.AuthorizationMiddleware(authDomain.ReadCapability, auditLogUseCase, s.logger), + tokenizationHandler.ValidateHandler, + ) + + // Revoke token to prevent further detokenization + tokenization.POST("/revoke", + authHTTP.AuthorizationMiddleware(authDomain.DeleteCapability, auditLogUseCase, s.logger), + tokenizationHandler.RevokeHandler, + ) + } } // GetHandler returns the http.Handler for testing purposes. @@ -371,7 +440,7 @@ type readinessResponse struct { // readinessHandler returns a simple readiness check response. func (s *Server) readinessHandler(c *gin.Context) { v, _, _ := s.reqGroup.Do("readiness", func() (interface{}, error) { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := context.WithTimeout(c.Request.Context(), 5*time.Second) defer cancel() dbStatus := "ok" diff --git a/internal/httputil/response.go b/internal/httputil/response.go index 405ded3..40b051c 100644 --- a/internal/httputil/response.go +++ b/internal/httputil/response.go @@ -10,15 +10,14 @@ import ( apperrors "github.com/allisson/secrets/internal/errors" ) -// ErrorResponse represents a structured error response. +// ErrorResponse represents a structured error response returned by the API. type ErrorResponse struct { - Error string `json:"error"` - Message string `json:"message,omitempty"` - Code string `json:"code,omitempty"` + Error string `json:"error"` // Machine-readable error code + Message string `json:"message,omitempty"` // Human-readable error message + Code string `json:"code,omitempty"` // Additional error details (optional) } // HandleErrorGin maps domain errors to HTTP status codes and returns a JSON response using Gin. -// This is an adapter for Gin's context that maintains the same error handling logic. func HandleErrorGin(c *gin.Context, err error, logger *slog.Logger) { if err == nil { return @@ -27,7 +26,7 @@ func HandleErrorGin(c *gin.Context, err error, logger *slog.Logger) { var statusCode int var errorResponse ErrorResponse - // Map domain errors to HTTP status codes (same logic as HandleError) + // Map domain errors to HTTP status codes based on apperrors. switch { case apperrors.Is(err, apperrors.ErrNotFound): statusCode = http.StatusNotFound diff --git a/internal/httputil/response_test.go b/internal/httputil/response_test.go new file mode 100644 index 0000000..649a89f --- /dev/null +++ b/internal/httputil/response_test.go @@ -0,0 +1,116 @@ +package httputil_test + +import ( + "errors" + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" + + apperrors "github.com/allisson/secrets/internal/errors" + "github.com/allisson/secrets/internal/httputil" +) + +func TestHandleErrorGin(t *testing.T) { + gin.SetMode(gin.TestMode) + + tests := []struct { + name string + err error + expectedStatus int + expectedErrCode string + expectedErrMessage string + }{ + { + name: "nil error", + err: nil, + expectedStatus: http.StatusOK, // Context status remains unchanged or is 200 by default in test context + }, + { + name: "not found error", + err: apperrors.ErrNotFound, + expectedStatus: http.StatusNotFound, + expectedErrCode: "not_found", + expectedErrMessage: "The requested resource was not found", + }, + { + name: "conflict error", + err: apperrors.ErrConflict, + expectedStatus: http.StatusConflict, + expectedErrCode: "conflict", + expectedErrMessage: "A conflict occurred with existing data", + }, + { + name: "invalid input error", + err: errors.Join(apperrors.ErrInvalidInput, errors.New("custom detail")), + expectedStatus: http.StatusUnprocessableEntity, + expectedErrCode: "invalid_input", + expectedErrMessage: "invalid input: custom detail", + }, + { + name: "unauthorized error", + err: apperrors.ErrUnauthorized, + expectedStatus: http.StatusUnauthorized, + expectedErrCode: "unauthorized", + expectedErrMessage: "Authentication is required", + }, + { + name: "locked error", + err: apperrors.ErrLocked, + expectedStatus: http.StatusLocked, + expectedErrCode: "client_locked", + expectedErrMessage: "Account is locked due to too many failed authentication attempts", + }, + { + name: "forbidden error", + err: apperrors.ErrForbidden, + expectedStatus: http.StatusForbidden, + expectedErrCode: "forbidden", + expectedErrMessage: "You don't have permission to access this resource", + }, + { + name: "unknown error", + err: errors.New("something went wrong"), + expectedStatus: http.StatusInternalServerError, + expectedErrCode: "internal_error", + expectedErrMessage: "An internal error occurred", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + httputil.HandleErrorGin(c, tt.err, nil) + + if tt.err != nil { + assert.Equal(t, tt.expectedStatus, w.Code) + } + }) + } +} + +func TestHandleBadRequestGin(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + err := errors.New("bad json") + httputil.HandleBadRequestGin(c, err, nil) + + assert.Equal(t, http.StatusBadRequest, w.Code) +} + +func TestHandleValidationErrorGin(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + err := errors.New("validation failed") + httputil.HandleValidationErrorGin(c, err, nil) + + assert.Equal(t, http.StatusUnprocessableEntity, w.Code) +} diff --git a/internal/metrics/business_test.go b/internal/metrics/business_test.go index d402bbb..3260e83 100644 --- a/internal/metrics/business_test.go +++ b/internal/metrics/business_test.go @@ -2,6 +2,8 @@ package metrics import ( "context" + "net/http" + "net/http/httptest" "testing" "time" @@ -9,6 +11,15 @@ import ( "github.com/stretchr/testify/require" ) +// assertBizMetricLine checks that the Prometheus output contains a business metric +// matching the given name, partial label pattern, and value. Uses regex to handle +// extra OTel scope labels injected by the Prometheus exporter. +func assertBizMetricLine(t *testing.T, output, name, labels, value string) { + t.Helper() + pattern := name + `\{[^}]*` + labels + `[^}]*\} ` + value + assert.Regexp(t, pattern, output) +} + func TestNewBusinessMetrics(t *testing.T) { t.Run("Success_CreateBusinessMetrics", func(t *testing.T) { provider, err := NewProvider("test_app") @@ -124,5 +135,49 @@ func TestBusinessMetrics_Integration(t *testing.T) { bm.RecordDuration(ctx, "transit", "rotate_key", 150*time.Millisecond, "success") // Metrics should be recorded without errors - // Actual metric values are tested through Prometheus scraping + // Verify metrics in Prometheus registry + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/metrics", nil) + provider.Handler().ServeHTTP(w, req) + + output := w.Body.String() + + // Check operation counts + assertBizMetricLine( + t, + output, + `integration_test_operations_total`, + `domain="auth".*operation="create_client".*status="success"`, + `2`, + ) + assertBizMetricLine( + t, + output, + `integration_test_operations_total`, + `domain="auth".*operation="create_client".*status="error"`, + `1`, + ) + assertBizMetricLine( + t, + output, + `integration_test_operations_total`, + `domain="secrets".*operation="encrypt".*status="success"`, + `1`, + ) + + // Check durations (existence) + assertBizMetricLine( + t, + output, + `integration_test_operation_duration_seconds_count`, + `domain="auth".*operation="create_client".*status="success"`, + `2`, + ) + assertBizMetricLine( + t, + output, + `integration_test_operation_duration_seconds_sum`, + `domain="auth".*operation="create_client".*status="success"`, + ``, + ) } diff --git a/internal/metrics/http_test.go b/internal/metrics/http_test.go index 8edacbc..0bdef5e 100644 --- a/internal/metrics/http_test.go +++ b/internal/metrics/http_test.go @@ -11,6 +11,16 @@ import ( "github.com/stretchr/testify/require" ) +// assertMetricContains checks that the Prometheus output contains a metric matching +// the given name, partial label values, and value. OTel adds extra scope labels so +// we use a regex rather than an exact string match. +func assertMetricLine(t *testing.T, output, name, labels, value string) { + t.Helper() + // The regex allows any number of extra labels between the user-defined ones. + pattern := name + `\{[^}]*` + labels + `[^}]*\} ` + value + assert.Regexp(t, pattern, output) +} + func TestHTTPMetricsMiddleware(t *testing.T) { gin.SetMode(gin.TestMode) @@ -32,8 +42,28 @@ func TestHTTPMetricsMiddleware(t *testing.T) { w := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/test", nil) router.ServeHTTP(w, req) - assert.Equal(t, http.StatusOK, w.Code) + + // Verify metric + w = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodGet, "/metrics", nil) + provider.Handler().ServeHTTP(w, req) + + output := w.Body.String() + assertMetricLine( + t, + output, + `test_app_http_requests_total`, + `method="GET".*path="/test".*status_code="200"`, + `1`, + ) + assertMetricLine( + t, + output, + `test_app_http_request_duration_seconds_count`, + `method="GET".*path="/test".*status_code="200"`, + `1`, + ) }) t.Run("Success_RecordMultipleRequests", func(t *testing.T) { @@ -76,6 +106,34 @@ func TestHTTPMetricsMiddleware(t *testing.T) { req = httptest.NewRequest(http.MethodGet, "/error", nil) router.ServeHTTP(w, req) assert.Equal(t, http.StatusInternalServerError, w.Code) + + // Verify metrics + w = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodGet, "/metrics", nil) + provider.Handler().ServeHTTP(w, req) + + output := w.Body.String() + assertMetricLine( + t, + output, + `test_app_http_requests_total`, + `method="GET".*path="/test".*status_code="200"`, + `5`, + ) + assertMetricLine( + t, + output, + `test_app_http_requests_total`, + `method="POST".*path="/test".*status_code="201"`, + `1`, + ) + assertMetricLine( + t, + output, + `test_app_http_requests_total`, + `method="GET".*path="/error".*status_code="500"`, + `1`, + ) }) t.Run("Success_RecordWithPathParams", func(t *testing.T) { diff --git a/internal/metrics/provider.go b/internal/metrics/provider.go index f775745..5b5e420 100644 --- a/internal/metrics/provider.go +++ b/internal/metrics/provider.go @@ -11,6 +11,8 @@ import ( "github.com/prometheus/client_golang/prometheus/promhttp" promexporter "go.opentelemetry.io/otel/exporters/prometheus" "go.opentelemetry.io/otel/sdk/metric" + "go.opentelemetry.io/otel/sdk/resource" + semconv "go.opentelemetry.io/otel/semconv/v1.24.0" ) // Provider manages the OpenTelemetry meter provider and Prometheus exporter. @@ -36,9 +38,19 @@ func NewProvider(namespace string) (*Provider, error) { return nil, fmt.Errorf("failed to create prometheus exporter: %w", err) } - // Create meter provider with Prometheus exporter + // Create meter provider with Prometheus exporter and resource attributes + res, err := resource.New(context.Background(), + resource.WithAttributes( + semconv.ServiceNamespace(namespace), + ), + ) + if err != nil { + return nil, fmt.Errorf("failed to create resource: %w", err) + } + meterProvider := metric.NewMeterProvider( metric.WithReader(exporter), + metric.WithResource(res), ) return &Provider{ diff --git a/internal/secrets/domain/secret.go b/internal/secrets/domain/secret.go index 00198e2..537f59c 100644 --- a/internal/secrets/domain/secret.go +++ b/internal/secrets/domain/secret.go @@ -9,14 +9,24 @@ import ( "github.com/google/uuid" ) +// Secret represents an encrypted secret with versioning and metadata. type Secret struct { - ID uuid.UUID - Path string - Version uint - DekID uuid.UUID + // ID is the unique identifier for this specific secret version. + ID uuid.UUID + // Path is the logical key used to access the secret (e.g., "/app/db-password"). + Path string + // Version is the monotonically increasing version number for this path. + Version uint + // DekID references the Data Encryption Key used to encrypt this secret version. + DekID uuid.UUID + // Ciphertext contains the encrypted secret data. Ciphertext []byte - Plaintext []byte `json:"-"` // In memory only - Nonce []byte - CreatedAt time.Time - DeletedAt *time.Time + // Plaintext holds the decrypted secret value in memory only; must be zeroed after use. + Plaintext []byte `json:"-"` + // Nonce is the random value used during AEAD encryption. + Nonce []byte + // CreatedAt is the UTC timestamp when this version was created. + CreatedAt time.Time + // DeletedAt marks when this secret was soft-deleted (nil if active). + DeletedAt *time.Time } diff --git a/internal/secrets/http/dto/list_secrets_response.go b/internal/secrets/http/dto/list_secrets_response.go index 4915c13..0ce6564 100644 --- a/internal/secrets/http/dto/list_secrets_response.go +++ b/internal/secrets/http/dto/list_secrets_response.go @@ -1,3 +1,4 @@ +// Package dto provides data transfer objects for HTTP request and response handling. package dto import ( diff --git a/internal/secrets/http/secret_handler.go b/internal/secrets/http/secret_handler.go index c3f77c1..582a227 100644 --- a/internal/secrets/http/secret_handler.go +++ b/internal/secrets/http/secret_handler.go @@ -4,6 +4,7 @@ package http import ( "encoding/base64" + "errors" "fmt" "log/slog" "net/http" @@ -42,18 +43,28 @@ func NewSecretHandler( } } -// CreateOrUpdateHandler creates a new secret or updates an existing one. -// POST /v1/secrets/*path - Requires EncryptCapability. -// Returns 201 Created with secret metadata (excludes plaintext value for security). -func (h *SecretHandler) CreateOrUpdateHandler(c *gin.Context) { - // Extract and validate path from URL parameter +// extractPath extracts and validates the path parameter from the request. +// Returns the cleaned path and true if valid, empty string and false if invalid. +func (h *SecretHandler) extractPath(c *gin.Context) (string, bool) { path := strings.TrimPrefix(c.Param("path"), "/") if path == "" { httputil.HandleBadRequestGin( c, - fmt.Errorf("path cannot be empty"), + errors.New("path cannot be empty"), h.logger, ) + return "", false + } + return path, true +} + +// CreateOrUpdateHandler creates a new secret or updates an existing one. +// POST /v1/secrets/*path - Requires EncryptCapability. +// Returns 201 Created with secret metadata (excludes plaintext value for security). +func (h *SecretHandler) CreateOrUpdateHandler(c *gin.Context) { + // Extract and validate path from URL parameter + path, ok := h.extractPath(c) + if !ok { return } @@ -99,13 +110,8 @@ func (h *SecretHandler) CreateOrUpdateHandler(c *gin.Context) { // Returns 200 OK with plaintext value. SECURITY: Plaintext is zeroed after response. func (h *SecretHandler) GetHandler(c *gin.Context) { // Extract and validate path from URL parameter - path := strings.TrimPrefix(c.Param("path"), "/") - if path == "" { - httputil.HandleBadRequestGin( - c, - fmt.Errorf("path cannot be empty"), - h.logger, - ) + path, ok := h.extractPath(c) + if !ok { return } @@ -119,7 +125,15 @@ func (h *SecretHandler) GetHandler(c *gin.Context) { if parseErr != nil { httputil.HandleBadRequestGin( c, - fmt.Errorf("invalid version parameter: must be a positive integer"), + errors.New("invalid version parameter: must be a positive integer"), + h.logger, + ) + return + } + if version == 0 { + httputil.HandleBadRequestGin( + c, + errors.New("version must be greater than 0"), h.logger, ) return @@ -147,13 +161,8 @@ func (h *SecretHandler) GetHandler(c *gin.Context) { // Returns 204 No Content. func (h *SecretHandler) DeleteHandler(c *gin.Context) { // Extract and validate path from URL parameter - path := strings.TrimPrefix(c.Param("path"), "/") - if path == "" { - httputil.HandleBadRequestGin( - c, - fmt.Errorf("path cannot be empty"), - h.logger, - ) + path, ok := h.extractPath(c) + if !ok { return } diff --git a/internal/secrets/http/secret_handler_test.go b/internal/secrets/http/secret_handler_test.go index 3e367b9..4903607 100644 --- a/internal/secrets/http/secret_handler_test.go +++ b/internal/secrets/http/secret_handler_test.go @@ -37,7 +37,9 @@ func setupTestHandler(t *testing.T) (*SecretHandler, *mocks.MockSecretUseCase) { } func TestSecretHandler_CreateOrUpdateHandler(t *testing.T) { + t.Parallel() t.Run("Success_ValidRequest", func(t *testing.T) { + t.Parallel() handler, mockUseCase := setupTestHandler(t) secretID := uuid.Must(uuid.NewV7()) @@ -78,6 +80,7 @@ func TestSecretHandler_CreateOrUpdateHandler(t *testing.T) { }) t.Run("Success_NestedPath", func(t *testing.T) { + t.Parallel() handler, mockUseCase := setupTestHandler(t) secretID := uuid.Must(uuid.NewV7()) @@ -116,6 +119,7 @@ func TestSecretHandler_CreateOrUpdateHandler(t *testing.T) { }) t.Run("Error_InvalidJSON", func(t *testing.T) { + t.Parallel() handler, _ := setupTestHandler(t) c, w := createTestContext(http.MethodPost, "/v1/secrets/database/password", nil) @@ -133,6 +137,7 @@ func TestSecretHandler_CreateOrUpdateHandler(t *testing.T) { }) t.Run("Error_EmptyValue", func(t *testing.T) { + t.Parallel() handler, _ := setupTestHandler(t) request := dto.CreateOrUpdateSecretRequest{ @@ -153,6 +158,7 @@ func TestSecretHandler_CreateOrUpdateHandler(t *testing.T) { }) t.Run("Error_InvalidBase64", func(t *testing.T) { + t.Parallel() handler, _ := setupTestHandler(t) request := dto.CreateOrUpdateSecretRequest{ @@ -173,6 +179,7 @@ func TestSecretHandler_CreateOrUpdateHandler(t *testing.T) { }) t.Run("Error_EmptyPath", func(t *testing.T) { + t.Parallel() handler, _ := setupTestHandler(t) request := dto.CreateOrUpdateSecretRequest{ @@ -194,6 +201,7 @@ func TestSecretHandler_CreateOrUpdateHandler(t *testing.T) { }) t.Run("Error_UseCaseError", func(t *testing.T) { + t.Parallel() handler, mockUseCase := setupTestHandler(t) path := "database/password" @@ -223,7 +231,9 @@ func TestSecretHandler_CreateOrUpdateHandler(t *testing.T) { } func TestSecretHandler_GetHandler(t *testing.T) { + t.Parallel() t.Run("Success_GetLatestVersion", func(t *testing.T) { + t.Parallel() handler, mockUseCase := setupTestHandler(t) secretID := uuid.Must(uuid.NewV7()) @@ -262,6 +272,7 @@ func TestSecretHandler_GetHandler(t *testing.T) { }) t.Run("Success_GetSpecificVersion", func(t *testing.T) { + t.Parallel() handler, mockUseCase := setupTestHandler(t) secretID := uuid.Must(uuid.NewV7()) @@ -302,6 +313,7 @@ func TestSecretHandler_GetHandler(t *testing.T) { }) t.Run("Error_InvalidVersionParameter", func(t *testing.T) { + t.Parallel() handler, _ := setupTestHandler(t) path := "database/password" @@ -322,6 +334,7 @@ func TestSecretHandler_GetHandler(t *testing.T) { }) t.Run("Error_NotFound", func(t *testing.T) { + t.Parallel() handler, mockUseCase := setupTestHandler(t) path := "nonexistent/secret" @@ -345,6 +358,7 @@ func TestSecretHandler_GetHandler(t *testing.T) { }) t.Run("Error_EmptyPath", func(t *testing.T) { + t.Parallel() handler, _ := setupTestHandler(t) c, w := createTestContext(http.MethodGet, "/v1/secrets/", nil) @@ -363,7 +377,9 @@ func TestSecretHandler_GetHandler(t *testing.T) { } func TestSecretHandler_DeleteHandler(t *testing.T) { + t.Parallel() t.Run("Success_DeleteSecret", func(t *testing.T) { + t.Parallel() handler, mockUseCase := setupTestHandler(t) path := "database/password" @@ -383,6 +399,7 @@ func TestSecretHandler_DeleteHandler(t *testing.T) { }) t.Run("Success_NestedPath", func(t *testing.T) { + t.Parallel() handler, mockUseCase := setupTestHandler(t) path := "my/nested/secret/path" @@ -401,6 +418,7 @@ func TestSecretHandler_DeleteHandler(t *testing.T) { }) t.Run("Error_NotFound", func(t *testing.T) { + t.Parallel() handler, mockUseCase := setupTestHandler(t) path := "nonexistent/secret" @@ -424,6 +442,7 @@ func TestSecretHandler_DeleteHandler(t *testing.T) { }) t.Run("Error_EmptyPath", func(t *testing.T) { + t.Parallel() handler, _ := setupTestHandler(t) c, w := createTestContext(http.MethodDelete, "/v1/secrets/", nil) @@ -442,7 +461,9 @@ func TestSecretHandler_DeleteHandler(t *testing.T) { } func TestSecretHandler_ListHandler(t *testing.T) { + t.Parallel() t.Run("Success_ListSecrets", func(t *testing.T) { + t.Parallel() handler, mockUseCase := setupTestHandler(t) now := time.Now().UTC() @@ -481,6 +502,7 @@ func TestSecretHandler_ListHandler(t *testing.T) { }) t.Run("Error_InvalidPaginationParams", func(t *testing.T) { + t.Parallel() handler, _ := setupTestHandler(t) c, w := createTestContext(http.MethodGet, "/v1/secrets?offset=invalid", nil) diff --git a/internal/secrets/repository/mysql/mysql_secret_repository.go b/internal/secrets/repository/mysql/mysql_secret_repository.go index a28b2dc..e8eb469 100644 --- a/internal/secrets/repository/mysql/mysql_secret_repository.go +++ b/internal/secrets/repository/mysql/mysql_secret_repository.go @@ -1,12 +1,13 @@ +// Package mysql implements secret persistence for MySQL databases. +// It uses binary UUID marshalling for MySQL compatibility. package mysql import ( "context" "database/sql" + "errors" "time" - "github.com/google/uuid" - "github.com/allisson/secrets/internal/database" apperrors "github.com/allisson/secrets/internal/errors" secretsDomain "github.com/allisson/secrets/internal/secrets/domain" @@ -80,7 +81,7 @@ func (m *MySQLSecretRepository) GetByPath( &secret.DeletedAt, ) if err != nil { - if err == sql.ErrNoRows { + if errors.Is(err, sql.ErrNoRows) { return nil, secretsDomain.ErrSecretNotFound } return nil, apperrors.Wrap(err, "failed to get secret by path") @@ -124,7 +125,7 @@ func (m *MySQLSecretRepository) GetByPathAndVersion( &secret.DeletedAt, ) if err != nil { - if err == sql.ErrNoRows { + if errors.Is(err, sql.ErrNoRows) { return nil, secretsDomain.ErrSecretNotFound } return nil, apperrors.Wrap(err, "failed to get secret by path and version") @@ -141,29 +142,26 @@ func (m *MySQLSecretRepository) GetByPathAndVersion( return &secret, nil } -// Delete performs a soft delete on a secret by setting the DeletedAt timestamp. -func (m *MySQLSecretRepository) Delete(ctx context.Context, secretID uuid.UUID) error { +// Delete performs a soft delete on all versions of a secret by path. +func (m *MySQLSecretRepository) Delete(ctx context.Context, path string) error { querier := database.GetTx(ctx, m.db) query := `UPDATE secrets SET deleted_at = ? - WHERE id = ?` + WHERE path = ? AND deleted_at IS NULL` - id, err := secretID.MarshalBinary() - if err != nil { - return apperrors.Wrap(err, "failed to marshal secret id") - } - - _, err = querier.ExecContext( + _, err := querier.ExecContext( ctx, query, time.Now().UTC(), - id, + path, ) if err != nil { return apperrors.Wrap(err, "failed to delete secret") } + // Note: We intentionally don't check rowsAffected to make Delete idempotent. + // Deleting a non-existent or already-deleted secret is not an error. return nil } diff --git a/internal/secrets/repository/mysql/mysql_secret_repository_test.go b/internal/secrets/repository/mysql/mysql_secret_repository_test.go index dbcc035..df8e71c 100644 --- a/internal/secrets/repository/mysql/mysql_secret_repository_test.go +++ b/internal/secrets/repository/mysql/mysql_secret_repository_test.go @@ -489,7 +489,7 @@ func TestMySQLSecretRepository_Delete(t *testing.T) { assert.Nil(t, deletedAt, "secret should not be deleted initially") // Delete the secret (soft delete) - err = repo.Delete(ctx, secret.ID) + err = repo.Delete(ctx, secret.Path) require.NoError(t, err) // Verify the secret still exists but has deleted_at set @@ -514,10 +514,8 @@ func TestMySQLSecretRepository_Delete_NonExistent(t *testing.T) { ctx := context.Background() // Try to delete a non-existent secret - nonExistentID := uuid.Must(uuid.NewV7()) - // Delete should not return an error even if no rows are affected - err := repo.Delete(ctx, nonExistentID) + err := repo.Delete(ctx, "/nonexistent/path") assert.NoError(t, err) } @@ -546,7 +544,7 @@ func TestMySQLSecretRepository_Delete_AlreadyDeleted(t *testing.T) { require.NoError(t, err) // Delete the secret first time - err = repo.Delete(ctx, secret.ID) + err = repo.Delete(ctx, secret.Path) require.NoError(t, err) // Get the first deletion timestamp @@ -560,8 +558,8 @@ func TestMySQLSecretRepository_Delete_AlreadyDeleted(t *testing.T) { time.Sleep(100 * time.Millisecond) - // Delete the secret second time (should update deleted_at) - err = repo.Delete(ctx, secret.ID) + // Delete the secret second time (should be idempotent - no timestamp update) + err = repo.Delete(ctx, secret.Path) require.NoError(t, err) // Get the second deletion timestamp @@ -570,8 +568,13 @@ func TestMySQLSecretRepository_Delete_AlreadyDeleted(t *testing.T) { require.NoError(t, err) require.NotNil(t, secondDeletedAt) - // The second deletion should have a newer timestamp - assert.True(t, secondDeletedAt.After(*firstDeletedAt), "second delete should update timestamp") + // The second deletion should NOT update the timestamp (idempotent behavior) + assert.Equal( + t, + firstDeletedAt.Unix(), + secondDeletedAt.Unix(), + "second delete should not update timestamp (idempotent)", + ) } func TestMySQLSecretRepository_Delete_MultipleSecrets(t *testing.T) { @@ -604,10 +607,10 @@ func TestMySQLSecretRepository_Delete_MultipleSecrets(t *testing.T) { } // Delete only the first and third secrets - err := repo.Delete(ctx, secretIDs[0]) + err := repo.Delete(ctx, "/app/secret-0") require.NoError(t, err) - err = repo.Delete(ctx, secretIDs[2]) + err = repo.Delete(ctx, "/app/secret-2") require.NoError(t, err) // Verify deletion status @@ -898,7 +901,7 @@ func TestMySQLSecretRepository_GetByPath_WithDeletedSecret(t *testing.T) { require.NoError(t, err) // Delete the secret - err = repo.Delete(ctx, secret.ID) + err = repo.Delete(ctx, secret.Path) require.NoError(t, err) // GetByPath should return ErrNotFound for deleted secrets @@ -908,72 +911,8 @@ func TestMySQLSecretRepository_GetByPath_WithDeletedSecret(t *testing.T) { assert.ErrorIs(t, err, apperrors.ErrNotFound) } -func TestMySQLSecretRepository_GetByPath_MultipleVersions_LatestDeleted(t *testing.T) { - db := testutil.SetupMySQLDB(t) - defer testutil.TeardownDB(t, db) - defer testutil.CleanupMySQLDB(t, db) - - repo := NewMySQLSecretRepository(db) - ctx := context.Background() - - _, dekID := createMySQLKekAndDek(t, db) - - path := "/app/versioned-secret" - - // Create version 1 - secret1 := &secretsDomain.Secret{ - ID: uuid.Must(uuid.NewV7()), - Path: path, - Version: 1, - DekID: dekID, - Ciphertext: []byte("encrypted-v1"), - Nonce: []byte("nonce-v1"), - CreatedAt: time.Now().UTC(), - } - err := repo.Create(ctx, secret1) - require.NoError(t, err) - - // Create version 2 - time.Sleep(time.Millisecond) - secret2 := &secretsDomain.Secret{ - ID: uuid.Must(uuid.NewV7()), - Path: path, - Version: 2, - DekID: dekID, - Ciphertext: []byte("encrypted-v2"), - Nonce: []byte("nonce-v2"), - CreatedAt: time.Now().UTC(), - } - err = repo.Create(ctx, secret2) - require.NoError(t, err) - - // Create version 3 - time.Sleep(time.Millisecond) - secret3 := &secretsDomain.Secret{ - ID: uuid.Must(uuid.NewV7()), - Path: path, - Version: 3, - DekID: dekID, - Ciphertext: []byte("encrypted-v3"), - Nonce: []byte("nonce-v3"), - CreatedAt: time.Now().UTC(), - } - err = repo.Create(ctx, secret3) - require.NoError(t, err) - - // Delete version 3 (the latest) - err = repo.Delete(ctx, secret3.ID) - require.NoError(t, err) - - // GetByPath should return version 2 (the latest non-deleted version) - retrievedSecret, err := repo.GetByPath(ctx, path) - require.NoError(t, err) - assert.NotNil(t, retrievedSecret) - assert.Equal(t, secret2.ID, retrievedSecret.ID) - assert.Equal(t, uint(2), retrievedSecret.Version) - assert.Equal(t, []byte("encrypted-v2"), retrievedSecret.Ciphertext) - assert.Nil(t, retrievedSecret.DeletedAt, "returned secret should not be deleted") -} +// Note: TestMySQLSecretRepository_GetByPath_MultipleVersions_LatestDeleted was removed +// because Delete now deletes ALL versions by path (Issue #2), not individual versions. func TestMySQLSecretRepository_GetByPath_MultipleVersions_AllDeleted(t *testing.T) { db := testutil.SetupMySQLDB(t) @@ -1015,9 +954,9 @@ func TestMySQLSecretRepository_GetByPath_MultipleVersions_AllDeleted(t *testing. require.NoError(t, err) // Delete both versions - err = repo.Delete(ctx, secret1.ID) + err = repo.Delete(ctx, secret1.Path) require.NoError(t, err) - err = repo.Delete(ctx, secret2.ID) + err = repo.Delete(ctx, secret2.Path) require.NoError(t, err) // GetByPath should return ErrNotFound when all versions are deleted @@ -1244,7 +1183,7 @@ func TestMySQLSecretRepository_GetByPathAndVersion_DeletedSecret(t *testing.T) { assert.NotNil(t, retrievedSecret) // Delete the secret - err = repo.Delete(ctx, secret.ID) + err = repo.Delete(ctx, secret.Path) require.NoError(t, err) // GetByPathAndVersion should return ErrNotFound for deleted secrets @@ -1254,57 +1193,8 @@ func TestMySQLSecretRepository_GetByPathAndVersion_DeletedSecret(t *testing.T) { assert.ErrorIs(t, err, apperrors.ErrNotFound) } -func TestMySQLSecretRepository_GetByPathAndVersion_MultipleVersions_OneDeleted(t *testing.T) { - db := testutil.SetupMySQLDB(t) - defer testutil.TeardownDB(t, db) - defer testutil.CleanupMySQLDB(t, db) - - repo := NewMySQLSecretRepository(db) - ctx := context.Background() - - _, dekID := createMySQLKekAndDek(t, db) - - path := "/app/mixed-versions" - - // Create versions 1, 2, and 3 - secrets := make([]*secretsDomain.Secret, 3) - for i := uint(0); i < 3; i++ { - time.Sleep(time.Millisecond) - secrets[i] = &secretsDomain.Secret{ - ID: uuid.Must(uuid.NewV7()), - Path: path, - Version: i + 1, - DekID: dekID, - Ciphertext: []byte(fmt.Sprintf("encrypted-v%d", i+1)), - Nonce: []byte(fmt.Sprintf("nonce-v%d", i+1)), - CreatedAt: time.Now().UTC(), - } - err := repo.Create(ctx, secrets[i]) - require.NoError(t, err) - } - - // Delete version 2 - err := repo.Delete(ctx, secrets[1].ID) - require.NoError(t, err) - - // Version 1 should still be accessible - v1, err := repo.GetByPathAndVersion(ctx, path, 1) - require.NoError(t, err) - assert.NotNil(t, v1) - assert.Equal(t, uint(1), v1.Version) - - // Version 2 should not be accessible (deleted) - v2, err := repo.GetByPathAndVersion(ctx, path, 2) - assert.Error(t, err) - assert.Nil(t, v2) - assert.ErrorIs(t, err, apperrors.ErrNotFound) - - // Version 3 should still be accessible - v3, err := repo.GetByPathAndVersion(ctx, path, 3) - require.NoError(t, err) - assert.NotNil(t, v3) - assert.Equal(t, uint(3), v3.Version) -} +// Note: TestMySQLSecretRepository_GetByPathAndVersion_MultipleVersions_OneDeleted was removed +// because Delete now deletes ALL versions by path (Issue #2), not individual versions. func TestMySQLSecretRepository_GetByPathAndVersion_WithTransaction(t *testing.T) { db := testutil.SetupMySQLDB(t) diff --git a/internal/secrets/repository/postgresql/postgresql_secret_repository.go b/internal/secrets/repository/postgresql/postgresql_secret_repository.go index 1e9205f..ae58b2a 100644 --- a/internal/secrets/repository/postgresql/postgresql_secret_repository.go +++ b/internal/secrets/repository/postgresql/postgresql_secret_repository.go @@ -5,10 +5,9 @@ package postgresql import ( "context" "database/sql" + "errors" "time" - "github.com/google/uuid" - "github.com/allisson/secrets/internal/database" apperrors "github.com/allisson/secrets/internal/errors" secretsDomain "github.com/allisson/secrets/internal/secrets/domain" @@ -69,7 +68,7 @@ func (p *PostgreSQLSecretRepository) GetByPath( &secret.DeletedAt, ) if err != nil { - if err == sql.ErrNoRows { + if errors.Is(err, sql.ErrNoRows) { return nil, secretsDomain.ErrSecretNotFound } return nil, apperrors.Wrap(err, "failed to get secret by path") @@ -103,7 +102,7 @@ func (p *PostgreSQLSecretRepository) GetByPathAndVersion( &secret.DeletedAt, ) if err != nil { - if err == sql.ErrNoRows { + if errors.Is(err, sql.ErrNoRows) { return nil, secretsDomain.ErrSecretNotFound } return nil, apperrors.Wrap(err, "failed to get secret by path and version") @@ -112,24 +111,26 @@ func (p *PostgreSQLSecretRepository) GetByPathAndVersion( return &secret, nil } -// Delete performs a soft delete on a secret by setting the DeletedAt timestamp. -func (p *PostgreSQLSecretRepository) Delete(ctx context.Context, secretID uuid.UUID) error { +// Delete performs a soft delete on all versions of a secret by path. +func (p *PostgreSQLSecretRepository) Delete(ctx context.Context, path string) error { querier := database.GetTx(ctx, p.db) query := `UPDATE secrets SET deleted_at = $1 - WHERE id = $2` + WHERE path = $2 AND deleted_at IS NULL` _, err := querier.ExecContext( ctx, query, time.Now().UTC(), - secretID, + path, ) if err != nil { return apperrors.Wrap(err, "failed to delete secret") } + // Note: We intentionally don't check rowsAffected to make Delete idempotent. + // Deleting a non-existent or already-deleted secret is not an error. return nil } diff --git a/internal/secrets/repository/postgresql/postgresql_secret_repository_test.go b/internal/secrets/repository/postgresql/postgresql_secret_repository_test.go index 5609ba4..5606eb4 100644 --- a/internal/secrets/repository/postgresql/postgresql_secret_repository_test.go +++ b/internal/secrets/repository/postgresql/postgresql_secret_repository_test.go @@ -455,7 +455,7 @@ func TestPostgreSQLSecretRepository_Delete(t *testing.T) { assert.Nil(t, deletedAt, "secret should not be deleted initially") // Delete the secret (soft delete) - err = repo.Delete(ctx, secret.ID) + err = repo.Delete(ctx, secret.Path) require.NoError(t, err) // Verify the secret still exists but has deleted_at set @@ -480,10 +480,8 @@ func TestPostgreSQLSecretRepository_Delete_NonExistent(t *testing.T) { ctx := context.Background() // Try to delete a non-existent secret - nonExistentID := uuid.Must(uuid.NewV7()) - // Delete should not return an error even if no rows are affected - err := repo.Delete(ctx, nonExistentID) + err := repo.Delete(ctx, "/nonexistent/path") assert.NoError(t, err) } @@ -512,7 +510,7 @@ func TestPostgreSQLSecretRepository_Delete_AlreadyDeleted(t *testing.T) { require.NoError(t, err) // Delete the secret first time - err = repo.Delete(ctx, secret.ID) + err = repo.Delete(ctx, secret.Path) require.NoError(t, err) // Get the first deletion timestamp @@ -524,8 +522,8 @@ func TestPostgreSQLSecretRepository_Delete_AlreadyDeleted(t *testing.T) { time.Sleep(100 * time.Millisecond) - // Delete the secret second time (should update deleted_at) - err = repo.Delete(ctx, secret.ID) + // Delete the secret second time (should be idempotent - no timestamp update) + err = repo.Delete(ctx, secret.Path) require.NoError(t, err) // Get the second deletion timestamp @@ -534,8 +532,13 @@ func TestPostgreSQLSecretRepository_Delete_AlreadyDeleted(t *testing.T) { require.NoError(t, err) require.NotNil(t, secondDeletedAt) - // The second deletion should have a newer timestamp - assert.True(t, secondDeletedAt.After(*firstDeletedAt), "second delete should update timestamp") + // The second deletion should NOT update the timestamp (idempotent behavior) + assert.Equal( + t, + firstDeletedAt.Unix(), + secondDeletedAt.Unix(), + "second delete should not update timestamp (idempotent)", + ) } func TestPostgreSQLSecretRepository_Delete_MultipleSecrets(t *testing.T) { @@ -568,10 +571,10 @@ func TestPostgreSQLSecretRepository_Delete_MultipleSecrets(t *testing.T) { } // Delete only the first and third secrets - err := repo.Delete(ctx, secretIDs[0]) + err := repo.Delete(ctx, "/app/secret-0") require.NoError(t, err) - err = repo.Delete(ctx, secretIDs[2]) + err = repo.Delete(ctx, "/app/secret-2") require.NoError(t, err) // Verify deletion status @@ -895,7 +898,7 @@ func TestPostgreSQLSecretRepository_GetByPath_WithDeletedSecret(t *testing.T) { require.NoError(t, err) // Delete the secret - err = repo.Delete(ctx, secret.ID) + err = repo.Delete(ctx, secret.Path) require.NoError(t, err) // GetByPath should return ErrNotFound for deleted secrets @@ -905,72 +908,8 @@ func TestPostgreSQLSecretRepository_GetByPath_WithDeletedSecret(t *testing.T) { assert.ErrorIs(t, err, apperrors.ErrNotFound) } -func TestPostgreSQLSecretRepository_GetByPath_MultipleVersions_LatestDeleted(t *testing.T) { - db := testutil.SetupPostgresDB(t) - defer testutil.TeardownDB(t, db) - defer testutil.CleanupPostgresDB(t, db) - - repo := NewPostgreSQLSecretRepository(db) - ctx := context.Background() - - _, dekID := createKekAndDek(t, db) - - path := "/app/versioned-secret" - - // Create version 1 - secret1 := &secretsDomain.Secret{ - ID: uuid.Must(uuid.NewV7()), - Path: path, - Version: 1, - DekID: dekID, - Ciphertext: []byte("encrypted-v1"), - Nonce: []byte("nonce-v1"), - CreatedAt: time.Now().UTC(), - } - err := repo.Create(ctx, secret1) - require.NoError(t, err) - - // Create version 2 - time.Sleep(time.Millisecond) - secret2 := &secretsDomain.Secret{ - ID: uuid.Must(uuid.NewV7()), - Path: path, - Version: 2, - DekID: dekID, - Ciphertext: []byte("encrypted-v2"), - Nonce: []byte("nonce-v2"), - CreatedAt: time.Now().UTC(), - } - err = repo.Create(ctx, secret2) - require.NoError(t, err) - - // Create version 3 - time.Sleep(time.Millisecond) - secret3 := &secretsDomain.Secret{ - ID: uuid.Must(uuid.NewV7()), - Path: path, - Version: 3, - DekID: dekID, - Ciphertext: []byte("encrypted-v3"), - Nonce: []byte("nonce-v3"), - CreatedAt: time.Now().UTC(), - } - err = repo.Create(ctx, secret3) - require.NoError(t, err) - - // Delete version 3 (the latest) - err = repo.Delete(ctx, secret3.ID) - require.NoError(t, err) - - // GetByPath should return version 2 (the latest non-deleted version) - retrievedSecret, err := repo.GetByPath(ctx, path) - require.NoError(t, err) - assert.NotNil(t, retrievedSecret) - assert.Equal(t, secret2.ID, retrievedSecret.ID) - assert.Equal(t, uint(2), retrievedSecret.Version) - assert.Equal(t, []byte("encrypted-v2"), retrievedSecret.Ciphertext) - assert.Nil(t, retrievedSecret.DeletedAt, "returned secret should not be deleted") -} +// Note: TestPostgreSQLSecretRepository_GetByPath_MultipleVersions_LatestDeleted was removed +// because Delete now deletes ALL versions by path (Issue #2), not individual versions. func TestPostgreSQLSecretRepository_GetByPath_MultipleVersions_AllDeleted(t *testing.T) { db := testutil.SetupPostgresDB(t) @@ -1012,9 +951,9 @@ func TestPostgreSQLSecretRepository_GetByPath_MultipleVersions_AllDeleted(t *tes require.NoError(t, err) // Delete both versions - err = repo.Delete(ctx, secret1.ID) + err = repo.Delete(ctx, secret1.Path) require.NoError(t, err) - err = repo.Delete(ctx, secret2.ID) + err = repo.Delete(ctx, secret2.Path) require.NoError(t, err) // GetByPath should return ErrNotFound when all versions are deleted @@ -1241,7 +1180,7 @@ func TestPostgreSQLSecretRepository_GetByPathAndVersion_DeletedSecret(t *testing assert.NotNil(t, retrievedSecret) // Delete the secret - err = repo.Delete(ctx, secret.ID) + err = repo.Delete(ctx, secret.Path) require.NoError(t, err) // GetByPathAndVersion should return ErrNotFound for deleted secrets @@ -1251,57 +1190,8 @@ func TestPostgreSQLSecretRepository_GetByPathAndVersion_DeletedSecret(t *testing assert.ErrorIs(t, err, apperrors.ErrNotFound) } -func TestPostgreSQLSecretRepository_GetByPathAndVersion_MultipleVersions_OneDeleted(t *testing.T) { - db := testutil.SetupPostgresDB(t) - defer testutil.TeardownDB(t, db) - defer testutil.CleanupPostgresDB(t, db) - - repo := NewPostgreSQLSecretRepository(db) - ctx := context.Background() - - _, dekID := createKekAndDek(t, db) - - path := "/app/mixed-versions" - - // Create versions 1, 2, and 3 - secrets := make([]*secretsDomain.Secret, 3) - for i := uint(0); i < 3; i++ { - time.Sleep(time.Millisecond) - secrets[i] = &secretsDomain.Secret{ - ID: uuid.Must(uuid.NewV7()), - Path: path, - Version: i + 1, - DekID: dekID, - Ciphertext: []byte(fmt.Sprintf("encrypted-v%d", i+1)), - Nonce: []byte(fmt.Sprintf("nonce-v%d", i+1)), - CreatedAt: time.Now().UTC(), - } - err := repo.Create(ctx, secrets[i]) - require.NoError(t, err) - } - - // Delete version 2 - err := repo.Delete(ctx, secrets[1].ID) - require.NoError(t, err) - - // Version 1 should still be accessible - v1, err := repo.GetByPathAndVersion(ctx, path, 1) - require.NoError(t, err) - assert.NotNil(t, v1) - assert.Equal(t, uint(1), v1.Version) - - // Version 2 should not be accessible (deleted) - v2, err := repo.GetByPathAndVersion(ctx, path, 2) - assert.Error(t, err) - assert.Nil(t, v2) - assert.ErrorIs(t, err, apperrors.ErrNotFound) - - // Version 3 should still be accessible - v3, err := repo.GetByPathAndVersion(ctx, path, 3) - require.NoError(t, err) - assert.NotNil(t, v3) - assert.Equal(t, uint(3), v3.Version) -} +// Note: TestPostgreSQLSecretRepository_GetByPathAndVersion_MultipleVersions_OneDeleted was removed +// because Delete now deletes ALL versions by path (Issue #2), not individual versions. func TestPostgreSQLSecretRepository_GetByPathAndVersion_WithTransaction(t *testing.T) { db := testutil.SetupPostgresDB(t) diff --git a/internal/secrets/usecase/interface.go b/internal/secrets/usecase/interface.go index 529a4fb..ba3ce00 100644 --- a/internal/secrets/usecase/interface.go +++ b/internal/secrets/usecase/interface.go @@ -26,8 +26,8 @@ type SecretRepository interface { // Create stores a new secret in the repository using transaction support from context. Create(ctx context.Context, secret *secretsDomain.Secret) error - // Delete soft deletes a secret by marking it with DeletedAt timestamp. - Delete(ctx context.Context, secretID uuid.UUID) error + // Delete soft deletes all versions of a secret by path, marking them with DeletedAt timestamp. + Delete(ctx context.Context, path string) error // GetByPath retrieves the latest version of a secret by its path. Returns ErrSecretNotFound if not found. GetByPath(ctx context.Context, path string) (*secretsDomain.Secret, error) diff --git a/internal/secrets/usecase/metrics_decorator_test.go b/internal/secrets/usecase/metrics_decorator_test.go new file mode 100644 index 0000000..7f2f51c --- /dev/null +++ b/internal/secrets/usecase/metrics_decorator_test.go @@ -0,0 +1,438 @@ +package usecase + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + + "github.com/allisson/secrets/internal/metrics" + secretsDomain "github.com/allisson/secrets/internal/secrets/domain" + secretsUsecaseMocks "github.com/allisson/secrets/internal/secrets/usecase/mocks" +) + +// mockBusinessMetrics is a mock implementation of metrics.BusinessMetrics for testing. +type mockBusinessMetrics struct { + mock.Mock +} + +func (m *mockBusinessMetrics) RecordOperation(ctx context.Context, domain, operation, status string) { + m.Called(ctx, domain, operation, status) +} + +func (m *mockBusinessMetrics) RecordDuration( + ctx context.Context, + domain, operation string, + duration time.Duration, + status string, +) { + m.Called(ctx, domain, operation, duration, status) +} + +var _ metrics.BusinessMetrics = (*mockBusinessMetrics)(nil) + +// TestNewSecretUseCaseWithMetrics tests the metrics decorator constructor. +func TestNewSecretUseCaseWithMetrics(t *testing.T) { + t.Parallel() + + mockUseCase := secretsUsecaseMocks.NewMockSecretUseCase(t) + mockMetrics := &mockBusinessMetrics{} + + decorator := NewSecretUseCaseWithMetrics(mockUseCase, mockMetrics) + + assert.NotNil(t, decorator) + assert.Implements(t, (*SecretUseCase)(nil), decorator) +} + +// TestMetricsDecorator_CreateOrUpdate tests the CreateOrUpdate method with metrics. +func TestMetricsDecorator_CreateOrUpdate(t *testing.T) { + t.Parallel() + ctx := context.Background() + + t.Run("Success_RecordsSuccessMetrics", func(t *testing.T) { + t.Parallel() + // Setup mocks + mockUseCase := secretsUsecaseMocks.NewMockSecretUseCase(t) + mockMetrics := &mockBusinessMetrics{} + + path := "/app/api-key" + value := []byte("secret-value") + expectedSecret := &secretsDomain.Secret{ + ID: uuid.Must(uuid.NewV7()), + Path: path, + Version: 1, + CreatedAt: time.Now().UTC(), + } + + // Setup expectations + mockUseCase.EXPECT(). + CreateOrUpdate(ctx, path, value). + Return(expectedSecret, nil). + Once() + + mockMetrics.On("RecordOperation", ctx, "secrets", "secret_create", "success"). + Return(). + Once() + + mockMetrics.On("RecordDuration", ctx, "secrets", "secret_create", mock.AnythingOfType("time.Duration"), "success"). + Return(). + Once() + + // Execute + decorator := NewSecretUseCaseWithMetrics(mockUseCase, mockMetrics) + result, err := decorator.CreateOrUpdate(ctx, path, value) + + // Assert + assert.NoError(t, err) + assert.Equal(t, expectedSecret, result) + }) + + t.Run("Error_RecordsErrorMetrics", func(t *testing.T) { + t.Parallel() + // Setup mocks + mockUseCase := secretsUsecaseMocks.NewMockSecretUseCase(t) + mockMetrics := &mockBusinessMetrics{} + + path := "/app/api-key" + value := []byte("secret-value") + expectedError := errors.New("database error") + + // Setup expectations + mockUseCase.EXPECT(). + CreateOrUpdate(ctx, path, value). + Return(nil, expectedError). + Once() + + mockMetrics.On("RecordOperation", ctx, "secrets", "secret_create", "error"). + Return(). + Once() + + mockMetrics.On("RecordDuration", ctx, "secrets", "secret_create", mock.AnythingOfType("time.Duration"), "error"). + Return(). + Once() + + // Execute + decorator := NewSecretUseCaseWithMetrics(mockUseCase, mockMetrics) + result, err := decorator.CreateOrUpdate(ctx, path, value) + + // Assert + assert.Error(t, err) + assert.Nil(t, result) + assert.Equal(t, expectedError, err) + }) +} + +// TestMetricsDecorator_Get tests the Get method with metrics. +func TestMetricsDecorator_Get(t *testing.T) { + t.Parallel() + ctx := context.Background() + + t.Run("Success_RecordsSuccessMetrics", func(t *testing.T) { + t.Parallel() + // Setup mocks + mockUseCase := secretsUsecaseMocks.NewMockSecretUseCase(t) + mockMetrics := &mockBusinessMetrics{} + + path := "/app/api-key" + expectedSecret := &secretsDomain.Secret{ + ID: uuid.Must(uuid.NewV7()), + Path: path, + Version: 1, + Plaintext: []byte("decrypted-value"), + CreatedAt: time.Now().UTC(), + } + + // Setup expectations + mockUseCase.EXPECT(). + Get(ctx, path). + Return(expectedSecret, nil). + Once() + + mockMetrics.On("RecordOperation", ctx, "secrets", "secret_get", "success"). + Return(). + Once() + + mockMetrics.On("RecordDuration", ctx, "secrets", "secret_get", mock.AnythingOfType("time.Duration"), "success"). + Return(). + Once() + + // Execute + decorator := NewSecretUseCaseWithMetrics(mockUseCase, mockMetrics) + result, err := decorator.Get(ctx, path) + + // Assert + assert.NoError(t, err) + assert.Equal(t, expectedSecret, result) + }) + + t.Run("Error_RecordsErrorMetrics", func(t *testing.T) { + t.Parallel() + // Setup mocks + mockUseCase := secretsUsecaseMocks.NewMockSecretUseCase(t) + mockMetrics := &mockBusinessMetrics{} + + path := "/app/nonexistent" + expectedError := secretsDomain.ErrSecretNotFound + + // Setup expectations + mockUseCase.EXPECT(). + Get(ctx, path). + Return(nil, expectedError). + Once() + + mockMetrics.On("RecordOperation", ctx, "secrets", "secret_get", "error"). + Return(). + Once() + + mockMetrics.On("RecordDuration", ctx, "secrets", "secret_get", mock.AnythingOfType("time.Duration"), "error"). + Return(). + Once() + + // Execute + decorator := NewSecretUseCaseWithMetrics(mockUseCase, mockMetrics) + result, err := decorator.Get(ctx, path) + + // Assert + assert.Error(t, err) + assert.Nil(t, result) + assert.Equal(t, expectedError, err) + }) +} + +// TestMetricsDecorator_GetByVersion tests the GetByVersion method with metrics. +func TestMetricsDecorator_GetByVersion(t *testing.T) { + t.Parallel() + ctx := context.Background() + + t.Run("Success_RecordsSuccessMetrics", func(t *testing.T) { + t.Parallel() + // Setup mocks + mockUseCase := secretsUsecaseMocks.NewMockSecretUseCase(t) + mockMetrics := &mockBusinessMetrics{} + + path := "/app/api-key" + version := uint(2) + expectedSecret := &secretsDomain.Secret{ + ID: uuid.Must(uuid.NewV7()), + Path: path, + Version: version, + Plaintext: []byte("decrypted-value"), + CreatedAt: time.Now().UTC(), + } + + // Setup expectations + mockUseCase.EXPECT(). + GetByVersion(ctx, path, version). + Return(expectedSecret, nil). + Once() + + mockMetrics.On("RecordOperation", ctx, "secrets", "secret_get_version", "success"). + Return(). + Once() + + mockMetrics.On("RecordDuration", ctx, "secrets", "secret_get_version", mock.AnythingOfType("time.Duration"), "success"). + Return(). + Once() + + // Execute + decorator := NewSecretUseCaseWithMetrics(mockUseCase, mockMetrics) + result, err := decorator.GetByVersion(ctx, path, version) + + // Assert + assert.NoError(t, err) + assert.Equal(t, expectedSecret, result) + }) + + t.Run("Error_RecordsErrorMetrics", func(t *testing.T) { + t.Parallel() + // Setup mocks + mockUseCase := secretsUsecaseMocks.NewMockSecretUseCase(t) + mockMetrics := &mockBusinessMetrics{} + + path := "/app/api-key" + version := uint(999) + expectedError := secretsDomain.ErrSecretNotFound + + // Setup expectations + mockUseCase.EXPECT(). + GetByVersion(ctx, path, version). + Return(nil, expectedError). + Once() + + mockMetrics.On("RecordOperation", ctx, "secrets", "secret_get_version", "error"). + Return(). + Once() + + mockMetrics.On("RecordDuration", ctx, "secrets", "secret_get_version", mock.AnythingOfType("time.Duration"), "error"). + Return(). + Once() + + // Execute + decorator := NewSecretUseCaseWithMetrics(mockUseCase, mockMetrics) + result, err := decorator.GetByVersion(ctx, path, version) + + // Assert + assert.Error(t, err) + assert.Nil(t, result) + assert.Equal(t, expectedError, err) + }) +} + +// TestMetricsDecorator_Delete tests the Delete method with metrics. +func TestMetricsDecorator_Delete(t *testing.T) { + t.Parallel() + ctx := context.Background() + + t.Run("Success_RecordsSuccessMetrics", func(t *testing.T) { + t.Parallel() + // Setup mocks + mockUseCase := secretsUsecaseMocks.NewMockSecretUseCase(t) + mockMetrics := &mockBusinessMetrics{} + + path := "/app/api-key" + + // Setup expectations + mockUseCase.EXPECT(). + Delete(ctx, path). + Return(nil). + Once() + + mockMetrics.On("RecordOperation", ctx, "secrets", "secret_delete", "success"). + Return(). + Once() + + mockMetrics.On("RecordDuration", ctx, "secrets", "secret_delete", mock.AnythingOfType("time.Duration"), "success"). + Return(). + Once() + + // Execute + decorator := NewSecretUseCaseWithMetrics(mockUseCase, mockMetrics) + err := decorator.Delete(ctx, path) + + // Assert + assert.NoError(t, err) + }) + + t.Run("Error_RecordsErrorMetrics", func(t *testing.T) { + t.Parallel() + // Setup mocks + mockUseCase := secretsUsecaseMocks.NewMockSecretUseCase(t) + mockMetrics := &mockBusinessMetrics{} + + path := "/app/nonexistent" + expectedError := secretsDomain.ErrSecretNotFound + + // Setup expectations + mockUseCase.EXPECT(). + Delete(ctx, path). + Return(expectedError). + Once() + + mockMetrics.On("RecordOperation", ctx, "secrets", "secret_delete", "error"). + Return(). + Once() + + mockMetrics.On("RecordDuration", ctx, "secrets", "secret_delete", mock.AnythingOfType("time.Duration"), "error"). + Return(). + Once() + + // Execute + decorator := NewSecretUseCaseWithMetrics(mockUseCase, mockMetrics) + err := decorator.Delete(ctx, path) + + // Assert + assert.Error(t, err) + assert.Equal(t, expectedError, err) + }) +} + +// TestMetricsDecorator_List tests the List method with metrics. +func TestMetricsDecorator_List(t *testing.T) { + t.Parallel() + ctx := context.Background() + + t.Run("Success_RecordsSuccessMetrics", func(t *testing.T) { + t.Parallel() + // Setup mocks + mockUseCase := secretsUsecaseMocks.NewMockSecretUseCase(t) + mockMetrics := &mockBusinessMetrics{} + + offset := 0 + limit := 10 + expectedSecrets := []*secretsDomain.Secret{ + { + ID: uuid.Must(uuid.NewV7()), + Path: "/app/key1", + Version: 1, + CreatedAt: time.Now().UTC(), + }, + { + ID: uuid.Must(uuid.NewV7()), + Path: "/app/key2", + Version: 1, + CreatedAt: time.Now().UTC(), + }, + } + + // Setup expectations + mockUseCase.EXPECT(). + List(ctx, offset, limit). + Return(expectedSecrets, nil). + Once() + + mockMetrics.On("RecordOperation", ctx, "secrets", "secret_list", "success"). + Return(). + Once() + + mockMetrics.On("RecordDuration", ctx, "secrets", "secret_list", mock.AnythingOfType("time.Duration"), "success"). + Return(). + Once() + + // Execute + decorator := NewSecretUseCaseWithMetrics(mockUseCase, mockMetrics) + result, err := decorator.List(ctx, offset, limit) + + // Assert + assert.NoError(t, err) + assert.Equal(t, expectedSecrets, result) + assert.Len(t, result, 2) + }) + + t.Run("Error_RecordsErrorMetrics", func(t *testing.T) { + t.Parallel() + // Setup mocks + mockUseCase := secretsUsecaseMocks.NewMockSecretUseCase(t) + mockMetrics := &mockBusinessMetrics{} + + offset := 0 + limit := 10 + expectedError := errors.New("database error") + + // Setup expectations + mockUseCase.EXPECT(). + List(ctx, offset, limit). + Return(nil, expectedError). + Once() + + mockMetrics.On("RecordOperation", ctx, "secrets", "secret_list", "error"). + Return(). + Once() + + mockMetrics.On("RecordDuration", ctx, "secrets", "secret_list", mock.AnythingOfType("time.Duration"), "error"). + Return(). + Once() + + // Execute + decorator := NewSecretUseCaseWithMetrics(mockUseCase, mockMetrics) + result, err := decorator.List(ctx, offset, limit) + + // Assert + assert.Error(t, err) + assert.Nil(t, result) + assert.Equal(t, expectedError, err) + }) +} diff --git a/internal/secrets/usecase/mocks/mocks.go b/internal/secrets/usecase/mocks/mocks.go index e6fe236..1be2112 100644 --- a/internal/secrets/usecase/mocks/mocks.go +++ b/internal/secrets/usecase/mocks/mocks.go @@ -250,16 +250,16 @@ func (_c *MockSecretRepository_Create_Call) RunAndReturn(run func(ctx context.Co } // Delete provides a mock function for the type MockSecretRepository -func (_mock *MockSecretRepository) Delete(ctx context.Context, secretID uuid.UUID) error { - ret := _mock.Called(ctx, secretID) +func (_mock *MockSecretRepository) Delete(ctx context.Context, path string) error { + ret := _mock.Called(ctx, path) if len(ret) == 0 { panic("no return value specified for Delete") } var r0 error - if returnFunc, ok := ret.Get(0).(func(context.Context, uuid.UUID) error); ok { - r0 = returnFunc(ctx, secretID) + if returnFunc, ok := ret.Get(0).(func(context.Context, string) error); ok { + r0 = returnFunc(ctx, path) } else { r0 = ret.Error(0) } @@ -273,20 +273,20 @@ type MockSecretRepository_Delete_Call struct { // Delete is a helper method to define mock.On call // - ctx context.Context -// - secretID uuid.UUID -func (_e *MockSecretRepository_Expecter) Delete(ctx interface{}, secretID interface{}) *MockSecretRepository_Delete_Call { - return &MockSecretRepository_Delete_Call{Call: _e.mock.On("Delete", ctx, secretID)} +// - path string +func (_e *MockSecretRepository_Expecter) Delete(ctx interface{}, path interface{}) *MockSecretRepository_Delete_Call { + return &MockSecretRepository_Delete_Call{Call: _e.mock.On("Delete", ctx, path)} } -func (_c *MockSecretRepository_Delete_Call) Run(run func(ctx context.Context, secretID uuid.UUID)) *MockSecretRepository_Delete_Call { +func (_c *MockSecretRepository_Delete_Call) Run(run func(ctx context.Context, path string)) *MockSecretRepository_Delete_Call { _c.Call.Run(func(args mock.Arguments) { var arg0 context.Context if args[0] != nil { arg0 = args[0].(context.Context) } - var arg1 uuid.UUID + var arg1 string if args[1] != nil { - arg1 = args[1].(uuid.UUID) + arg1 = args[1].(string) } run( arg0, @@ -301,7 +301,7 @@ func (_c *MockSecretRepository_Delete_Call) Return(err error) *MockSecretReposit return _c } -func (_c *MockSecretRepository_Delete_Call) RunAndReturn(run func(ctx context.Context, secretID uuid.UUID) error) *MockSecretRepository_Delete_Call { +func (_c *MockSecretRepository_Delete_Call) RunAndReturn(run func(ctx context.Context, path string) error) *MockSecretRepository_Delete_Call { _c.Call.Return(run) return _c } diff --git a/internal/secrets/usecase/secret_usecase.go b/internal/secrets/usecase/secret_usecase.go index eb698bc..1bb9c4b 100644 --- a/internal/secrets/usecase/secret_usecase.go +++ b/internal/secrets/usecase/secret_usecase.go @@ -48,20 +48,20 @@ func (s *secretUseCase) createOrUpdateSecret( value []byte, kek *cryptoDomain.Kek, ) (*secretsDomain.Secret, error) { - var version uint = 1 - - // Check if secret already exists to determine the version - existingSecret, err := s.secretRepo.GetByPath(ctx, path) - if err != nil && !errors.Is(err, secretsDomain.ErrSecretNotFound) { - return nil, err - } - if existingSecret != nil { - version = existingSecret.Version + 1 - } - // Execute the creation within a transaction var newSecret *secretsDomain.Secret - err = s.txManager.WithTx(ctx, func(txCtx context.Context) error { + err := s.txManager.WithTx(ctx, func(txCtx context.Context) error { + var version uint = 1 + + // Check if secret already exists to determine the version + // This must happen inside the transaction to prevent race conditions + existingSecret, err := s.secretRepo.GetByPath(txCtx, path) + if err != nil && !errors.Is(err, secretsDomain.ErrSecretNotFound) { + return err + } + if existingSecret != nil { + version = existingSecret.Version + 1 + } // Create a new DEK for this secret dek, err := s.keyManager.CreateDek(kek, s.dekAlgorithm) if err != nil { @@ -186,16 +186,10 @@ func (s *secretUseCase) decryptSecret( return secret, nil } -// Delete performs a soft delete on a secret by its path. +// Delete performs a soft delete on all versions of a secret by its path. func (s *secretUseCase) Delete(ctx context.Context, path string) error { - // Retrieve the secret by path to get its ID - secret, err := s.secretRepo.GetByPath(ctx, path) - if err != nil { - return err - } - - // Perform soft delete - return s.secretRepo.Delete(ctx, secret.ID) + // Perform soft delete on all versions + return s.secretRepo.Delete(ctx, path) } // List retrieves secrets without their values, ordered by path with pagination. diff --git a/internal/secrets/usecase/secret_usecase_test.go b/internal/secrets/usecase/secret_usecase_test.go index e4ab84b..ec0a343 100644 --- a/internal/secrets/usecase/secret_usecase_test.go +++ b/internal/secrets/usecase/secret_usecase_test.go @@ -20,9 +20,11 @@ import ( // TestSecretUseCase_CreateOrUpdate tests the CreateOrUpdate method of secretUseCase. func TestSecretUseCase_CreateOrUpdate(t *testing.T) { + t.Parallel() ctx := context.Background() t.Run("Success_CreateNewSecret", func(t *testing.T) { + t.Parallel() // Setup mocks mockTxManager := databaseMocks.NewMockTxManager(t) mockDekRepo := secretsUsecaseMocks.NewMockDekRepository(t) @@ -132,6 +134,7 @@ func TestSecretUseCase_CreateOrUpdate(t *testing.T) { }) t.Run("Success_UpdateExistingSecret", func(t *testing.T) { + t.Parallel() // Setup mocks mockTxManager := databaseMocks.NewMockTxManager(t) mockDekRepo := secretsUsecaseMocks.NewMockDekRepository(t) @@ -251,6 +254,7 @@ func TestSecretUseCase_CreateOrUpdate(t *testing.T) { }) t.Run("Error_ActiveKekNotFound", func(t *testing.T) { + t.Parallel() // Setup mocks mockTxManager := databaseMocks.NewMockTxManager(t) mockDekRepo := secretsUsecaseMocks.NewMockDekRepository(t) @@ -284,6 +288,7 @@ func TestSecretUseCase_CreateOrUpdate(t *testing.T) { }) t.Run("Error_SecretRepoGetByPathFails", func(t *testing.T) { + t.Parallel() // Setup mocks mockTxManager := databaseMocks.NewMockTxManager(t) mockDekRepo := secretsUsecaseMocks.NewMockDekRepository(t) @@ -310,6 +315,13 @@ func TestSecretUseCase_CreateOrUpdate(t *testing.T) { expectedError := errors.New("database error") // Setup expectations + mockTxManager.EXPECT(). + WithTx(ctx, mock.AnythingOfType("func(context.Context) error")). + RunAndReturn(func(ctx context.Context, fn func(context.Context) error) error { + return fn(ctx) + }). + Once() + mockSecretRepo.EXPECT(). GetByPath(mock.Anything, path). Return(nil, expectedError). @@ -334,6 +346,7 @@ func TestSecretUseCase_CreateOrUpdate(t *testing.T) { }) t.Run("Error_CreateDekFails", func(t *testing.T) { + t.Parallel() // Setup mocks mockTxManager := databaseMocks.NewMockTxManager(t) mockDekRepo := secretsUsecaseMocks.NewMockDekRepository(t) @@ -399,9 +412,11 @@ func TestSecretUseCase_CreateOrUpdate(t *testing.T) { // TestSecretUseCase_Get tests the Get method of secretUseCase. func TestSecretUseCase_Get(t *testing.T) { + t.Parallel() ctx := context.Background() t.Run("Success_GetAndDecryptSecret", func(t *testing.T) { + t.Parallel() // Setup mocks mockTxManager := databaseMocks.NewMockTxManager(t) mockDekRepo := secretsUsecaseMocks.NewMockDekRepository(t) @@ -498,6 +513,7 @@ func TestSecretUseCase_Get(t *testing.T) { }) t.Run("Error_SecretNotFound", func(t *testing.T) { + t.Parallel() // Setup mocks mockTxManager := databaseMocks.NewMockTxManager(t) mockDekRepo := secretsUsecaseMocks.NewMockDekRepository(t) @@ -546,6 +562,7 @@ func TestSecretUseCase_Get(t *testing.T) { }) t.Run("Error_DekNotFound", func(t *testing.T) { + t.Parallel() // Setup mocks mockTxManager := databaseMocks.NewMockTxManager(t) mockDekRepo := secretsUsecaseMocks.NewMockDekRepository(t) @@ -610,6 +627,7 @@ func TestSecretUseCase_Get(t *testing.T) { }) t.Run("Error_KekNotFound", func(t *testing.T) { + t.Parallel() // Setup mocks mockTxManager := databaseMocks.NewMockTxManager(t) mockDekRepo := secretsUsecaseMocks.NewMockDekRepository(t) @@ -684,6 +702,7 @@ func TestSecretUseCase_Get(t *testing.T) { }) t.Run("Error_DecryptionFailed", func(t *testing.T) { + t.Parallel() // Setup mocks mockTxManager := databaseMocks.NewMockTxManager(t) mockDekRepo := secretsUsecaseMocks.NewMockDekRepository(t) @@ -780,9 +799,106 @@ func TestSecretUseCase_Get(t *testing.T) { // TestSecretUseCase_Delete tests the Delete method of secretUseCase. func TestSecretUseCase_Delete(t *testing.T) { + t.Parallel() ctx := context.Background() t.Run("Success_DeleteSecret", func(t *testing.T) { + t.Parallel() + // Setup mocks + mockTxManager := databaseMocks.NewMockTxManager(t) + mockDekRepo := secretsUsecaseMocks.NewMockDekRepository(t) + mockSecretRepo := secretsUsecaseMocks.NewMockSecretRepository(t) + mockAEADManager := cryptoServiceMocks.NewMockAEADManager(t) + mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) + + kekID := uuid.Must(uuid.NewV7()) + kek := &cryptoDomain.Kek{ + ID: kekID, + MasterKeyID: "test-master-key", + Algorithm: cryptoDomain.AESGCM, + Key: make([]byte, 32), + EncryptedKey: []byte("encrypted-kek"), + Nonce: []byte("kek-nonce"), + Version: 1, + CreatedAt: time.Now().UTC(), + } + kekChain := createKekChain([]*cryptoDomain.Kek{kek}) + defer kekChain.Close() + + path := "/app/api-key" + + // Setup expectations + mockSecretRepo.EXPECT(). + Delete(ctx, path). + Return(nil). + Once() + + // Execute + uc := NewSecretUseCase( + mockTxManager, + mockDekRepo, + mockSecretRepo, + kekChain, + mockAEADManager, + mockKeyManager, + cryptoDomain.AESGCM, + ) + err := uc.Delete(ctx, path) + + // Assert + assert.NoError(t, err) + }) + + t.Run("Error_SecretNotFound", func(t *testing.T) { + t.Parallel() + // Setup mocks + mockTxManager := databaseMocks.NewMockTxManager(t) + mockDekRepo := secretsUsecaseMocks.NewMockDekRepository(t) + mockSecretRepo := secretsUsecaseMocks.NewMockSecretRepository(t) + mockAEADManager := cryptoServiceMocks.NewMockAEADManager(t) + mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) + + kekID := uuid.Must(uuid.NewV7()) + kek := &cryptoDomain.Kek{ + ID: kekID, + MasterKeyID: "test-master-key", + Algorithm: cryptoDomain.AESGCM, + Key: make([]byte, 32), + EncryptedKey: []byte("encrypted-kek"), + Nonce: []byte("kek-nonce"), + Version: 1, + CreatedAt: time.Now().UTC(), + } + kekChain := createKekChain([]*cryptoDomain.Kek{kek}) + defer kekChain.Close() + + path := "/app/nonexistent" + + // Setup expectations + mockSecretRepo.EXPECT(). + Delete(ctx, path). + Return(secretsDomain.ErrSecretNotFound). + Once() + + // Execute + uc := NewSecretUseCase( + mockTxManager, + mockDekRepo, + mockSecretRepo, + kekChain, + mockAEADManager, + mockKeyManager, + cryptoDomain.AESGCM, + ) + err := uc.Delete(ctx, path) + + // Assert + assert.Error(t, err) + assert.True(t, errors.Is(err, apperrors.ErrNotFound)) + }) + + t.Run("Error_DeleteFails", func(t *testing.T) { + t.Parallel() // Setup mocks mockTxManager := databaseMocks.NewMockTxManager(t) mockDekRepo := secretsUsecaseMocks.NewMockDekRepository(t) @@ -805,27 +921,114 @@ func TestSecretUseCase_Delete(t *testing.T) { defer kekChain.Close() path := "/app/api-key" - secretID := uuid.Must(uuid.NewV7()) + expectedError := errors.New("database error") + + // Setup expectations + mockSecretRepo.EXPECT(). + Delete(ctx, path). + Return(expectedError). + Once() + + // Execute + uc := NewSecretUseCase( + mockTxManager, + mockDekRepo, + mockSecretRepo, + kekChain, + mockAEADManager, + mockKeyManager, + cryptoDomain.AESGCM, + ) + err := uc.Delete(ctx, path) + + // Assert + assert.Error(t, err) + assert.Equal(t, expectedError, err) + }) +} + +// TestSecretUseCase_GetByVersion tests the GetByVersion method of secretUseCase. +func TestSecretUseCase_GetByVersion(t *testing.T) { + t.Parallel() + ctx := context.Background() + + t.Run("Success_GetSpecificVersion", func(t *testing.T) { + t.Parallel() + // Setup mocks + mockTxManager := databaseMocks.NewMockTxManager(t) + mockDekRepo := secretsUsecaseMocks.NewMockDekRepository(t) + mockSecretRepo := secretsUsecaseMocks.NewMockSecretRepository(t) + mockAEADManager := cryptoServiceMocks.NewMockAEADManager(t) + mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) + mockCipher := cryptoServiceMocks.NewMockAEAD(t) + + // Create test data + kekID := uuid.Must(uuid.NewV7()) + kek := &cryptoDomain.Kek{ + ID: kekID, + MasterKeyID: "test-master-key", + Algorithm: cryptoDomain.AESGCM, + Key: make([]byte, 32), + EncryptedKey: []byte("encrypted-kek"), + Nonce: []byte("kek-nonce"), + Version: 1, + CreatedAt: time.Now().UTC(), + } + kekChain := createKekChain([]*cryptoDomain.Kek{kek}) + defer kekChain.Close() + + path := "/app/api-key" + version := uint(2) + dekID := uuid.Must(uuid.NewV7()) + ciphertext := []byte("encrypted-secret") + nonce := []byte("secret-nonce") + plaintext := []byte("secret-value") secret := &secretsDomain.Secret{ - ID: secretID, + ID: uuid.Must(uuid.NewV7()), Path: path, - Version: 1, - DekID: uuid.Must(uuid.NewV7()), - Ciphertext: []byte("encrypted-secret"), - Nonce: []byte("secret-nonce"), + Version: version, + DekID: dekID, + Ciphertext: ciphertext, + Nonce: nonce, CreatedAt: time.Now().UTC(), } + dek := &cryptoDomain.Dek{ + ID: dekID, + KekID: kekID, + Algorithm: cryptoDomain.AESGCM, + EncryptedKey: []byte("encrypted-dek"), + Nonce: []byte("dek-nonce"), + CreatedAt: time.Now().UTC(), + } + + dekKey := make([]byte, 32) + // Setup expectations mockSecretRepo.EXPECT(). - GetByPath(ctx, path). + GetByPathAndVersion(ctx, path, version). Return(secret, nil). Once() - mockSecretRepo.EXPECT(). - Delete(ctx, secretID). - Return(nil). + mockDekRepo.EXPECT(). + Get(ctx, dekID). + Return(dek, nil). + Once() + + mockKeyManager.EXPECT(). + DecryptDek(dek, kek). + Return(dekKey, nil). + Once() + + mockAEADManager.EXPECT(). + CreateCipher(dekKey, cryptoDomain.AESGCM). + Return(mockCipher, nil). + Once() + + mockCipher.EXPECT(). + Decrypt(ciphertext, nonce, mock.Anything). + Return(plaintext, nil). Once() // Execute @@ -838,13 +1041,18 @@ func TestSecretUseCase_Delete(t *testing.T) { mockKeyManager, cryptoDomain.AESGCM, ) - err := uc.Delete(ctx, path) + result, err := uc.GetByVersion(ctx, path, version) // Assert assert.NoError(t, err) + assert.NotNil(t, result) + assert.Equal(t, path, result.Path) + assert.Equal(t, version, result.Version) + assert.Equal(t, plaintext, result.Plaintext) }) t.Run("Error_SecretNotFound", func(t *testing.T) { + t.Parallel() // Setup mocks mockTxManager := databaseMocks.NewMockTxManager(t) mockDekRepo := secretsUsecaseMocks.NewMockDekRepository(t) @@ -867,10 +1075,11 @@ func TestSecretUseCase_Delete(t *testing.T) { defer kekChain.Close() path := "/app/nonexistent" + version := uint(1) // Setup expectations mockSecretRepo.EXPECT(). - GetByPath(ctx, path). + GetByPathAndVersion(ctx, path, version). Return(nil, secretsDomain.ErrSecretNotFound). Once() @@ -884,20 +1093,23 @@ func TestSecretUseCase_Delete(t *testing.T) { mockKeyManager, cryptoDomain.AESGCM, ) - err := uc.Delete(ctx, path) + result, err := uc.GetByVersion(ctx, path, version) // Assert assert.Error(t, err) + assert.Nil(t, result) assert.True(t, errors.Is(err, apperrors.ErrNotFound)) }) - t.Run("Error_DeleteFails", func(t *testing.T) { + t.Run("Error_DecryptionFailed", func(t *testing.T) { + t.Parallel() // Setup mocks mockTxManager := databaseMocks.NewMockTxManager(t) mockDekRepo := secretsUsecaseMocks.NewMockDekRepository(t) mockSecretRepo := secretsUsecaseMocks.NewMockSecretRepository(t) mockAEADManager := cryptoServiceMocks.NewMockAEADManager(t) mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) + mockCipher := cryptoServiceMocks.NewMockAEAD(t) kekID := uuid.Must(uuid.NewV7()) kek := &cryptoDomain.Kek{ @@ -914,14 +1126,108 @@ func TestSecretUseCase_Delete(t *testing.T) { defer kekChain.Close() path := "/app/api-key" - secretID := uuid.Must(uuid.NewV7()) - expectedError := errors.New("database error") + version := uint(1) + dekID := uuid.Must(uuid.NewV7()) + ciphertext := []byte("encrypted-secret") + nonce := []byte("secret-nonce") secret := &secretsDomain.Secret{ - ID: secretID, + ID: uuid.Must(uuid.NewV7()), Path: path, - Version: 1, - DekID: uuid.Must(uuid.NewV7()), + Version: version, + DekID: dekID, + Ciphertext: ciphertext, + Nonce: nonce, + CreatedAt: time.Now().UTC(), + } + + dek := &cryptoDomain.Dek{ + ID: dekID, + KekID: kekID, + Algorithm: cryptoDomain.AESGCM, + EncryptedKey: []byte("encrypted-dek"), + Nonce: []byte("dek-nonce"), + CreatedAt: time.Now().UTC(), + } + + dekKey := make([]byte, 32) + + // Setup expectations + mockSecretRepo.EXPECT(). + GetByPathAndVersion(ctx, path, version). + Return(secret, nil). + Once() + + mockDekRepo.EXPECT(). + Get(ctx, dekID). + Return(dek, nil). + Once() + + mockKeyManager.EXPECT(). + DecryptDek(dek, kek). + Return(dekKey, nil). + Once() + + mockAEADManager.EXPECT(). + CreateCipher(dekKey, cryptoDomain.AESGCM). + Return(mockCipher, nil). + Once() + + mockCipher.EXPECT(). + Decrypt(ciphertext, nonce, mock.Anything). + Return(nil, errors.New("decryption failed")). + Once() + + // Execute + uc := NewSecretUseCase( + mockTxManager, + mockDekRepo, + mockSecretRepo, + kekChain, + mockAEADManager, + mockKeyManager, + cryptoDomain.AESGCM, + ) + result, err := uc.GetByVersion(ctx, path, version) + + // Assert + assert.Error(t, err) + assert.Nil(t, result) + assert.True(t, errors.Is(err, cryptoDomain.ErrDecryptionFailed)) + }) + + t.Run("Error_DekNotFound", func(t *testing.T) { + t.Parallel() + // Setup mocks + mockTxManager := databaseMocks.NewMockTxManager(t) + mockDekRepo := secretsUsecaseMocks.NewMockDekRepository(t) + mockSecretRepo := secretsUsecaseMocks.NewMockSecretRepository(t) + mockAEADManager := cryptoServiceMocks.NewMockAEADManager(t) + mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) + + kekID := uuid.Must(uuid.NewV7()) + kek := &cryptoDomain.Kek{ + ID: kekID, + MasterKeyID: "test-master-key", + Algorithm: cryptoDomain.AESGCM, + Key: make([]byte, 32), + EncryptedKey: []byte("encrypted-kek"), + Nonce: []byte("kek-nonce"), + Version: 1, + CreatedAt: time.Now().UTC(), + } + kekChain := createKekChain([]*cryptoDomain.Kek{kek}) + defer kekChain.Close() + + path := "/app/api-key" + version := uint(1) + dekID := uuid.Must(uuid.NewV7()) + + secret := &secretsDomain.Secret{ + ID: uuid.Must(uuid.NewV7()), + Path: path, + Version: version, + DekID: dekID, Ciphertext: []byte("encrypted-secret"), Nonce: []byte("secret-nonce"), CreatedAt: time.Now().UTC(), @@ -929,13 +1235,89 @@ func TestSecretUseCase_Delete(t *testing.T) { // Setup expectations mockSecretRepo.EXPECT(). - GetByPath(ctx, path). + GetByPathAndVersion(ctx, path, version). Return(secret, nil). Once() + mockDekRepo.EXPECT(). + Get(ctx, dekID). + Return(nil, cryptoDomain.ErrDekNotFound). + Once() + + // Execute + uc := NewSecretUseCase( + mockTxManager, + mockDekRepo, + mockSecretRepo, + kekChain, + mockAEADManager, + mockKeyManager, + cryptoDomain.AESGCM, + ) + result, err := uc.GetByVersion(ctx, path, version) + + // Assert + assert.Error(t, err) + assert.Nil(t, result) + assert.True(t, errors.Is(err, apperrors.ErrNotFound)) + }) + + t.Run("Error_KekNotFound", func(t *testing.T) { + t.Parallel() + // Setup mocks + mockTxManager := databaseMocks.NewMockTxManager(t) + mockDekRepo := secretsUsecaseMocks.NewMockDekRepository(t) + mockSecretRepo := secretsUsecaseMocks.NewMockSecretRepository(t) + mockAEADManager := cryptoServiceMocks.NewMockAEADManager(t) + mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) + + kekID := uuid.Must(uuid.NewV7()) + kek := &cryptoDomain.Kek{ + ID: kekID, + MasterKeyID: "test-master-key", + Algorithm: cryptoDomain.AESGCM, + Key: make([]byte, 32), + EncryptedKey: []byte("encrypted-kek"), + Nonce: []byte("kek-nonce"), + Version: 1, + CreatedAt: time.Now().UTC(), + } + kekChain := createKekChain([]*cryptoDomain.Kek{kek}) + defer kekChain.Close() + + path := "/app/api-key" + version := uint(1) + dekID := uuid.Must(uuid.NewV7()) + differentKekID := uuid.Must(uuid.NewV7()) // Different KEK ID + + secret := &secretsDomain.Secret{ + ID: uuid.Must(uuid.NewV7()), + Path: path, + Version: version, + DekID: dekID, + Ciphertext: []byte("encrypted-secret"), + Nonce: []byte("secret-nonce"), + CreatedAt: time.Now().UTC(), + } + + dek := &cryptoDomain.Dek{ + ID: dekID, + KekID: differentKekID, // KEK not in chain + Algorithm: cryptoDomain.AESGCM, + EncryptedKey: []byte("encrypted-dek"), + Nonce: []byte("dek-nonce"), + CreatedAt: time.Now().UTC(), + } + + // Setup expectations mockSecretRepo.EXPECT(). - Delete(ctx, secretID). - Return(expectedError). + GetByPathAndVersion(ctx, path, version). + Return(secret, nil). + Once() + + mockDekRepo.EXPECT(). + Get(ctx, dekID). + Return(dek, nil). Once() // Execute @@ -948,19 +1330,22 @@ func TestSecretUseCase_Delete(t *testing.T) { mockKeyManager, cryptoDomain.AESGCM, ) - err := uc.Delete(ctx, path) + result, err := uc.GetByVersion(ctx, path, version) // Assert assert.Error(t, err) - assert.Equal(t, expectedError, err) + assert.Nil(t, result) + assert.True(t, errors.Is(err, cryptoDomain.ErrKekNotFound)) }) } // TestSecretUseCase_List tests the List method of secretUseCase. func TestSecretUseCase_List(t *testing.T) { + t.Parallel() ctx := context.Background() t.Run("Success_ListSecrets", func(t *testing.T) { + t.Parallel() // Setup mocks mockTxManager := databaseMocks.NewMockTxManager(t) mockDekRepo := secretsUsecaseMocks.NewMockDekRepository(t) @@ -1012,6 +1397,7 @@ func TestSecretUseCase_List(t *testing.T) { }) t.Run("Error_RepositoryFails", func(t *testing.T) { + t.Parallel() // Setup mocks mockTxManager := databaseMocks.NewMockTxManager(t) mockDekRepo := secretsUsecaseMocks.NewMockDekRepository(t) @@ -1051,9 +1437,5 @@ func TestSecretUseCase_List(t *testing.T) { // createKekChain is a helper function to create a KEK chain for testing. func createKekChain(keks []*cryptoDomain.Kek) *cryptoDomain.KekChain { - if len(keks) == 0 { - // Create a dummy KEK chain with a nil active ID - return &cryptoDomain.KekChain{} - } return cryptoDomain.NewKekChain(keks) } diff --git a/internal/testutil/database.go b/internal/testutil/database.go index 82d17b9..cace1d6 100644 --- a/internal/testutil/database.go +++ b/internal/testutil/database.go @@ -1,5 +1,11 @@ // Package testutil provides testing utilities for database integration tests. // +// Environment Variables: +// +// Database connection strings can be customized via environment variables: +// - TEST_POSTGRES_DSN: PostgreSQL connection string (default: postgres://testuser:testpassword@localhost:5433/testdb?sslmode=disable) +// - TEST_MYSQL_DSN: MySQL connection string (default: testuser:testpassword@tcp(localhost:3307)/testdb?parseTime=true&multiStatements=true) +// // Database Setup: // // db := testutil.SetupPostgresDB(t) @@ -13,6 +19,11 @@ // // // Or both: // clientID, kekID := testutil.CreateTestClientAndKek(t, db, "postgres", "my-test") +// +// Migration Path: +// +// Migrations are automatically discovered by walking up from the current +// working directory until a "migrations/{dbType}" directory is found. package testutil import ( @@ -35,17 +46,34 @@ import ( ) const ( + // Default test database DSNs (can be overridden via environment variables) //nolint:gosec // test database credentials - PostgresTestDSN = "postgres://testuser:testpassword@localhost:5433/testdb?sslmode=disable" + defaultPostgresTestDSN = "postgres://testuser:testpassword@localhost:5433/testdb?sslmode=disable" //nolint:gosec // test database credentials - MySQLTestDSN = "testuser:testpassword@tcp(localhost:3307)/testdb?parseTime=true&multiStatements=true" + defaultMySQLTestDSN = "testuser:testpassword@tcp(localhost:3307)/testdb?parseTime=true&multiStatements=true" ) +// GetPostgresTestDSN returns the PostgreSQL test DSN, checking environment variable first. +func GetPostgresTestDSN() string { + if dsn := os.Getenv("TEST_POSTGRES_DSN"); dsn != "" { + return dsn + } + return defaultPostgresTestDSN +} + +// GetMySQLTestDSN returns the MySQL test DSN, checking environment variable first. +func GetMySQLTestDSN() string { + if dsn := os.Getenv("TEST_MYSQL_DSN"); dsn != "" { + return dsn + } + return defaultMySQLTestDSN +} + // SetupPostgresDB creates a new PostgreSQL database connection and runs migrations. func SetupPostgresDB(t *testing.T) *sql.DB { t.Helper() - db, err := sql.Open("postgres", PostgresTestDSN) + db, err := sql.Open("postgres", GetPostgresTestDSN()) require.NoError(t, err, "failed to connect to postgres") err = db.Ping() @@ -64,7 +92,7 @@ func SetupPostgresDB(t *testing.T) *sql.DB { func SetupMySQLDB(t *testing.T) *sql.DB { t.Helper() - db, err := sql.Open("mysql", MySQLTestDSN) + db, err := sql.Open("mysql", GetMySQLTestDSN()) require.NoError(t, err, "failed to connect to mysql") err = db.Ping() @@ -147,18 +175,25 @@ func runPostgresMigrations(t *testing.T, db *sql.DB) { driver, err := postgres.WithInstance(db, &postgres.Config{}) require.NoError(t, err, "failed to create postgres driver") - migrationsPath := getMigrationsPath("postgresql") + migrationsPath, err := getMigrationsPath("postgresql") + require.NoError(t, err, "failed to find postgresql migrations path") + m, err := migrate.NewWithDatabaseInstance( fmt.Sprintf("file://%s", migrationsPath), "postgres", driver, ) - require.NoError(t, err, "failed to create migrate instance") + require.NoError(t, err, "failed to create migrate instance for postgres") + + // Note: We intentionally do NOT close the migrate instance here because we're using + // WithInstance() with an existing database connection that we don't own. Closing the + // migrate instance would close the underlying database connection, which is managed + // by the caller. The file source driver will be garbage collected automatically. // Run migrations up err = m.Up() if err != nil && err != migrate.ErrNoChange { - require.NoError(t, err, "failed to run postgres migrations") + require.NoError(t, err, fmt.Sprintf("failed to run postgres migrations from %s", migrationsPath)) } } @@ -169,46 +204,64 @@ func runMySQLMigrations(t *testing.T, db *sql.DB) { driver, err := mysql.WithInstance(db, &mysql.Config{}) require.NoError(t, err, "failed to create mysql driver") - migrationsPath := getMigrationsPath("mysql") + migrationsPath, err := getMigrationsPath("mysql") + require.NoError(t, err, "failed to find mysql migrations path") + m, err := migrate.NewWithDatabaseInstance( fmt.Sprintf("file://%s", migrationsPath), "mysql", driver, ) - require.NoError(t, err, "failed to create migrate instance") + require.NoError(t, err, "failed to create migrate instance for mysql") + + // Note: We intentionally do NOT close the migrate instance here because we're using + // WithInstance() with an existing database connection that we don't own. Closing the + // migrate instance would close the underlying database connection, which is managed + // by the caller. The file source driver will be garbage collected automatically. // Run migrations up err = m.Up() if err != nil && err != migrate.ErrNoChange { - require.NoError(t, err, "failed to run mysql migrations") + require.NoError(t, err, fmt.Sprintf("failed to run mysql migrations from %s", migrationsPath)) } } // getMigrationsPath resolves the absolute path to migration files for the specified database type. // Walks up the directory tree from current working directory to find the migrations folder. -func getMigrationsPath(dbType string) string { +// Returns an error if the working directory cannot be determined or migrations are not found. +func getMigrationsPath(dbType string) (string, error) { // Get the project root by walking up from the current directory dir, err := os.Getwd() if err != nil { - panic(fmt.Sprintf("failed to get working directory: %v", err)) + return "", fmt.Errorf("failed to get working directory: %w", err) } // Walk up the directory tree until we find the migrations directory for { migrationsPath := filepath.Join(dir, "migrations", dbType) if _, err := os.Stat(migrationsPath); err == nil { - return migrationsPath + return migrationsPath, nil } parent := filepath.Dir(dir) if parent == dir { // Reached the root directory - panic("migrations directory not found") + return "", fmt.Errorf("migrations directory not found for %s (started from %s)", dbType, dir) } dir = parent } } +// uuidToDriverValue converts a UUID to the appropriate value for the database driver. +// PostgreSQL uses UUID natively, MySQL requires binary encoding. +func uuidToDriverValue(id uuid.UUID, driver string) (interface{}, error) { + if driver == "postgres" { + return id, nil + } + // MySQL needs binary format + return id.MarshalBinary() +} + // CreateTestClient creates a minimal active test client for repository tests. // Returns the client ID for use in foreign key relationships. The client is // created with a wildcard policy allowing all capabilities on all paths. @@ -233,12 +286,12 @@ func CreateTestClient(t *testing.T, db *sql.DB, driver, name string) uuid.UUID { policiesJSON, ) } else { // mysql - idBinary, marshalErr := clientID.MarshalBinary() - require.NoError(t, marshalErr, "failed to marshal client UUID") + idValue, marshalErr := uuidToDriverValue(clientID, driver) + require.NoError(t, marshalErr, "failed to convert client UUID for driver "+driver) _, err = db.ExecContext(ctx, `INSERT INTO clients (id, secret, name, is_active, policies, created_at) VALUES (?, ?, ?, ?, ?, NOW())`, - idBinary, + idValue, "test-secret-hash", name, true, @@ -252,6 +305,7 @@ func CreateTestClient(t *testing.T, db *sql.DB, driver, name string) uuid.UUID { // CreateTestKek creates a minimal test KEK for repository tests that need // to reference a KEK (e.g., signed audit logs). Returns the KEK ID. +// The KEK is created with algorithm 'aes-gcm' and random encrypted key data. func CreateTestKek(t *testing.T, db *sql.DB, driver, name string) uuid.UUID { t.Helper() @@ -263,22 +317,33 @@ func CreateTestKek(t *testing.T, db *sql.DB, driver, name string) uuid.UUID { _, err := rand.Read(encryptedKey) require.NoError(t, err, "failed to generate random KEK data") + // Generate nonce (12 bytes for AES-GCM) + nonce := make([]byte, 12) + _, err = rand.Read(nonce) + require.NoError(t, err, "failed to generate random nonce") + + masterKeyID := "test-master-key" + var execErr error if driver == "postgres" { _, execErr = db.ExecContext(ctx, - `INSERT INTO keks (id, version, algorithm, encrypted_key, created_at) - VALUES ($1, 1, 'aes-gcm', $2, NOW())`, + `INSERT INTO keks (id, master_key_id, version, algorithm, encrypted_key, nonce, created_at) + VALUES ($1, $2, 1, 'aes-gcm', $3, $4, NOW())`, kekID, + masterKeyID, encryptedKey, + nonce, ) } else { // mysql - idBinary, marshalErr := kekID.MarshalBinary() - require.NoError(t, marshalErr, "failed to marshal KEK UUID") + idValue, marshalErr := uuidToDriverValue(kekID, driver) + require.NoError(t, marshalErr, "failed to convert KEK UUID for driver "+driver) _, execErr = db.ExecContext(ctx, - `INSERT INTO keks (id, version, algorithm, encrypted_key, created_at) - VALUES (?, 1, 'aes-gcm', ?, NOW())`, - idBinary, + `INSERT INTO keks (id, master_key_id, version, algorithm, encrypted_key, nonce, created_at) + VALUES (?, ?, 1, 'aes-gcm', ?, ?, NOW())`, + idValue, + masterKeyID, encryptedKey, + nonce, ) } @@ -294,3 +359,134 @@ func CreateTestClientAndKek(t *testing.T, db *sql.DB, driver, baseName string) ( kekID = CreateTestKek(t, db, driver, baseName+"-kek") return clientID, kekID } + +// SkipIfNoPostgres skips the test if PostgreSQL test database is not available. +// Useful for running tests in environments without database access. +func SkipIfNoPostgres(t *testing.T) { + t.Helper() + db, err := sql.Open("postgres", GetPostgresTestDSN()) + if err != nil { + t.Skipf("PostgreSQL not available: %v", err) + } + defer func() { + _ = db.Close() // Ignore close error in skip helper + }() + + if err := db.Ping(); err != nil { + t.Skipf("PostgreSQL not available: %v", err) + } +} + +// SkipIfNoMySQL skips the test if MySQL test database is not available. +// Useful for running tests in environments without database access. +func SkipIfNoMySQL(t *testing.T) { + t.Helper() + db, err := sql.Open("mysql", GetMySQLTestDSN()) + if err != nil { + t.Skipf("MySQL not available: %v", err) + } + defer func() { + _ = db.Close() // Ignore close error in skip helper + }() + + if err := db.Ping(); err != nil { + t.Skipf("MySQL not available: %v", err) + } +} + +// CreateTestDek creates a minimal test DEK (Data Encryption Key) for repository tests. +// Returns the DEK ID. The DEK is associated with the provided KEK ID. +func CreateTestDek(t *testing.T, db *sql.DB, driver, name string, kekID uuid.UUID) uuid.UUID { + t.Helper() + + dekID := uuid.Must(uuid.NewV7()) + ctx := context.Background() + + // Dummy encrypted DEK data (32 bytes for AES-256) + encryptedKey := make([]byte, 32) + _, err := rand.Read(encryptedKey) + require.NoError(t, err, "failed to generate random DEK data") + + // Generate nonce (12 bytes for AES-GCM) + nonce := make([]byte, 12) + _, err = rand.Read(nonce) + require.NoError(t, err, "failed to generate random nonce") + + var execErr error + if driver == "postgres" { + _, execErr = db.ExecContext(ctx, + `INSERT INTO deks (id, kek_id, algorithm, encrypted_key, nonce, created_at) + VALUES ($1, $2, 'aes-gcm', $3, $4, NOW())`, + dekID, + kekID, + encryptedKey, + nonce, + ) + } else { // mysql + dekIDValue, marshalErr := uuidToDriverValue(dekID, driver) + require.NoError(t, marshalErr, "failed to convert DEK UUID for driver "+driver) + + kekIDValue, marshalErr := uuidToDriverValue(kekID, driver) + require.NoError(t, marshalErr, "failed to convert KEK UUID for driver "+driver) + + _, execErr = db.ExecContext(ctx, + `INSERT INTO deks (id, kek_id, algorithm, encrypted_key, nonce, created_at) + VALUES (?, ?, 'aes-gcm', ?, ?, NOW())`, + dekIDValue, + kekIDValue, + encryptedKey, + nonce, + ) + } + + require.NoError(t, execErr, "failed to create test DEK: "+name) + return dekID +} + +// ValidateTestClient verifies that a test client was created with expected values. +// Returns true if the client exists and is active, false otherwise. +func ValidateTestClient(t *testing.T, db *sql.DB, driver string, clientID uuid.UUID) bool { + t.Helper() + + ctx := context.Background() + var isActive bool + var err error + + if driver == "postgres" { + err = db.QueryRowContext(ctx, `SELECT is_active FROM clients WHERE id = $1`, clientID).Scan(&isActive) + } else { // mysql + idValue, marshalErr := uuidToDriverValue(clientID, driver) + require.NoError(t, marshalErr, "failed to convert client UUID for validation") + err = db.QueryRowContext(ctx, `SELECT is_active FROM clients WHERE id = ?`, idValue).Scan(&isActive) + } + + if err != nil { + return false + } + + return isActive +} + +// ValidateTestKek verifies that a test KEK was created with expected values. +// Returns true if the KEK exists, false otherwise. +func ValidateTestKek(t *testing.T, db *sql.DB, driver string, kekID uuid.UUID) bool { + t.Helper() + + ctx := context.Background() + var version int + var err error + + if driver == "postgres" { + err = db.QueryRowContext(ctx, `SELECT version FROM keks WHERE id = $1`, kekID).Scan(&version) + } else { // mysql + idValue, marshalErr := uuidToDriverValue(kekID, driver) + require.NoError(t, marshalErr, "failed to convert KEK UUID for validation") + err = db.QueryRowContext(ctx, `SELECT version FROM keks WHERE id = ?`, idValue).Scan(&version) + } + + if err != nil { + return false + } + + return version > 0 +} diff --git a/internal/testutil/database_test.go b/internal/testutil/database_test.go new file mode 100644 index 0000000..b6d6673 --- /dev/null +++ b/internal/testutil/database_test.go @@ -0,0 +1,560 @@ +package testutil + +import ( + "database/sql" + "os" + "path/filepath" + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGetPostgresTestDSN(t *testing.T) { + tests := []struct { + name string + envValue string + want string + }{ + { + name: "default DSN when env var not set", + envValue: "", + want: defaultPostgresTestDSN, + }, + //nolint:gosec // test credentials are safe in tests + { + name: "custom DSN from env var", + envValue: "postgres://custom:password@localhost:5432/customdb", + want: "postgres://custom:password@localhost:5432/customdb", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Save original env var + original := os.Getenv("TEST_POSTGRES_DSN") + defer func() { + if original != "" { + _ = os.Setenv("TEST_POSTGRES_DSN", original) + } else { + _ = os.Unsetenv("TEST_POSTGRES_DSN") + } + }() + + // Set test env var + if tt.envValue != "" { + _ = os.Setenv("TEST_POSTGRES_DSN", tt.envValue) + } else { + _ = os.Unsetenv("TEST_POSTGRES_DSN") + } + + got := GetPostgresTestDSN() + assert.Equal(t, tt.want, got) + }) + } +} + +func TestGetMySQLTestDSN(t *testing.T) { + tests := []struct { + name string + envValue string + want string + }{ + { + name: "default DSN when env var not set", + envValue: "", + want: defaultMySQLTestDSN, + }, + { + name: "custom DSN from env var", + envValue: "custom:password@tcp(localhost:3306)/customdb", + want: "custom:password@tcp(localhost:3306)/customdb", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Save original env var + original := os.Getenv("TEST_MYSQL_DSN") + defer func() { + if original != "" { + _ = os.Setenv("TEST_MYSQL_DSN", original) + } else { + _ = os.Unsetenv("TEST_MYSQL_DSN") + } + }() + + // Set test env var + if tt.envValue != "" { + _ = os.Setenv("TEST_MYSQL_DSN", tt.envValue) + } else { + _ = os.Unsetenv("TEST_MYSQL_DSN") + } + + got := GetMySQLTestDSN() + assert.Equal(t, tt.want, got) + }) + } +} + +func TestGetMigrationsPath(t *testing.T) { + tests := []struct { + name string + dbType string + wantErr bool + }{ + { + name: "find postgresql migrations", + dbType: "postgresql", + wantErr: false, + }, + { + name: "find mysql migrations", + dbType: "mysql", + wantErr: false, + }, + { + name: "non-existent database type", + dbType: "nonexistent", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := getMigrationsPath(tt.dbType) + if tt.wantErr { + assert.Error(t, err) + assert.Empty(t, got) + } else { + assert.NoError(t, err) + assert.NotEmpty(t, got) + // Verify the path exists + _, statErr := os.Stat(got) + assert.NoError(t, statErr, "migrations path should exist") + // Verify it contains the expected database type + assert.Contains(t, got, tt.dbType) + } + }) + } +} + +func TestGetMigrationsPathFromDifferentWorkingDir(t *testing.T) { + // Save original working directory + originalWd, err := os.Getwd() + require.NoError(t, err) + defer func() { + _ = os.Chdir(originalWd) // Restore working directory + }() + + // Change to a subdirectory within the project + // This simulates running tests from a deeper directory + subDir := filepath.Join(originalWd, "testdata") + //nolint:gosec // 0755 is appropriate for test directories + err = os.MkdirAll(subDir, 0755) + require.NoError(t, err) + defer func() { + _ = os.RemoveAll(subDir) // Clean up test directory + }() + + err = os.Chdir(subDir) + require.NoError(t, err) + + // Should still find migrations by walking up from the subdirectory + path, err := getMigrationsPath("postgresql") + assert.NoError(t, err) + assert.NotEmpty(t, path) + assert.Contains(t, path, "postgresql") +} + +func TestUuidToDriverValue(t *testing.T) { + testID := uuid.Must(uuid.NewV7()) + + tests := []struct { + name string + id uuid.UUID + driver string + wantErr bool + checkValue func(t *testing.T, value interface{}) + }{ + { + name: "postgres returns UUID directly", + id: testID, + driver: "postgres", + wantErr: false, + checkValue: func(t *testing.T, value interface{}) { + gotUUID, ok := value.(uuid.UUID) + assert.True(t, ok, "value should be uuid.UUID") + assert.Equal(t, testID, gotUUID) + }, + }, + { + name: "mysql returns binary", + id: testID, + driver: "mysql", + wantErr: false, + checkValue: func(t *testing.T, value interface{}) { + gotBytes, ok := value.([]byte) + assert.True(t, ok, "value should be []byte") + assert.Len(t, gotBytes, 16, "UUID binary should be 16 bytes") + }, + }, + { + name: "unknown driver defaults to mysql behavior", + id: testID, + driver: "unknown", + wantErr: false, + checkValue: func(t *testing.T, value interface{}) { + _, ok := value.([]byte) + assert.True(t, ok, "value should be []byte for unknown driver") + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + value, err := uuidToDriverValue(tt.id, tt.driver) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + if tt.checkValue != nil { + tt.checkValue(t, value) + } + } + }) + } +} + +func TestSetupPostgresDB(t *testing.T) { + // Skip if PostgreSQL is not available + SkipIfNoPostgres(t) + + db := SetupPostgresDB(t) + defer TeardownDB(t, db) + + // Verify database connection is working + err := db.Ping() + assert.NoError(t, err) + + // Verify database is clean (no clients should exist) + var count int + err = db.QueryRow("SELECT COUNT(*) FROM clients").Scan(&count) + assert.NoError(t, err) + assert.Equal(t, 0, count, "database should be clean after setup") +} + +func TestSetupMySQLDB(t *testing.T) { + // Skip if MySQL is not available + SkipIfNoMySQL(t) + + db := SetupMySQLDB(t) + defer TeardownDB(t, db) + + // Verify database connection is working + err := db.Ping() + assert.NoError(t, err) + + // Verify database is clean (no clients should exist) + var count int + err = db.QueryRow("SELECT COUNT(*) FROM clients").Scan(&count) + assert.NoError(t, err) + assert.Equal(t, 0, count, "database should be clean after setup") +} + +func TestTeardownDB(t *testing.T) { + SkipIfNoPostgres(t) + + db := SetupPostgresDB(t) + require.NotNil(t, db) + + // Teardown should close the connection + TeardownDB(t, db) + + // Attempting to ping after teardown should fail + err := db.Ping() + assert.Error(t, err, "database should be closed after teardown") +} + +func TestTeardownDBWithNilDB(t *testing.T) { + // Should not panic with nil database + assert.NotPanics(t, func() { + TeardownDB(t, nil) + }) +} + +func TestCleanupPostgresDB(t *testing.T) { + SkipIfNoPostgres(t) + + db := SetupPostgresDB(t) + defer TeardownDB(t, db) + + // Create test data + clientID := CreateTestClient(t, db, "postgres", "test-cleanup-client") + require.NotEqual(t, uuid.Nil, clientID) + + // Verify data exists + var count int + err := db.QueryRow("SELECT COUNT(*) FROM clients").Scan(&count) + require.NoError(t, err) + assert.Equal(t, 1, count) + + // Cleanup should remove all data + CleanupPostgresDB(t, db) + + // Verify data is removed + err = db.QueryRow("SELECT COUNT(*) FROM clients").Scan(&count) + require.NoError(t, err) + assert.Equal(t, 0, count, "cleanup should remove all data") +} + +func TestCleanupMySQLDB(t *testing.T) { + SkipIfNoMySQL(t) + + db := SetupMySQLDB(t) + defer TeardownDB(t, db) + + // Create test data + clientID := CreateTestClient(t, db, "mysql", "test-cleanup-client") + require.NotEqual(t, uuid.Nil, clientID) + + // Verify data exists + var count int + err := db.QueryRow("SELECT COUNT(*) FROM clients").Scan(&count) + require.NoError(t, err) + assert.Equal(t, 1, count) + + // Cleanup should remove all data + CleanupMySQLDB(t, db) + + // Verify data is removed + err = db.QueryRow("SELECT COUNT(*) FROM clients").Scan(&count) + require.NoError(t, err) + assert.Equal(t, 0, count, "cleanup should remove all data") +} + +func TestCreateTestClient(t *testing.T) { + tests := []struct { + name string + driver string + setup func(t *testing.T) *sql.DB + skip func(t *testing.T) + }{ + { + name: "create client in postgres", + driver: "postgres", + setup: SetupPostgresDB, + skip: SkipIfNoPostgres, + }, + { + name: "create client in mysql", + driver: "mysql", + setup: SetupMySQLDB, + skip: SkipIfNoMySQL, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.skip(t) + + db := tt.setup(t) + defer TeardownDB(t, db) + + clientID := CreateTestClient(t, db, tt.driver, "test-client") + assert.NotEqual(t, uuid.Nil, clientID) + + // Verify client was created + valid := ValidateTestClient(t, db, tt.driver, clientID) + assert.True(t, valid, "client should exist and be active") + }) + } +} + +func TestCreateTestKek(t *testing.T) { + tests := []struct { + name string + driver string + setup func(t *testing.T) *sql.DB + skip func(t *testing.T) + }{ + { + name: "create KEK in postgres", + driver: "postgres", + setup: SetupPostgresDB, + skip: SkipIfNoPostgres, + }, + { + name: "create KEK in mysql", + driver: "mysql", + setup: SetupMySQLDB, + skip: SkipIfNoMySQL, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.skip(t) + + db := tt.setup(t) + defer TeardownDB(t, db) + + kekID := CreateTestKek(t, db, tt.driver, "test-kek") + assert.NotEqual(t, uuid.Nil, kekID) + + // Verify KEK was created + valid := ValidateTestKek(t, db, tt.driver, kekID) + assert.True(t, valid, "KEK should exist") + }) + } +} + +func TestCreateTestClientAndKek(t *testing.T) { + tests := []struct { + name string + driver string + setup func(t *testing.T) *sql.DB + skip func(t *testing.T) + }{ + { + name: "create client and KEK in postgres", + driver: "postgres", + setup: SetupPostgresDB, + skip: SkipIfNoPostgres, + }, + { + name: "create client and KEK in mysql", + driver: "mysql", + setup: SetupMySQLDB, + skip: SkipIfNoMySQL, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.skip(t) + + db := tt.setup(t) + defer TeardownDB(t, db) + + clientID, kekID := CreateTestClientAndKek(t, db, tt.driver, "test-fixtures") + + assert.NotEqual(t, uuid.Nil, clientID) + assert.NotEqual(t, uuid.Nil, kekID) + assert.NotEqual(t, clientID, kekID, "client ID and KEK ID should be different") + + // Verify both were created + clientValid := ValidateTestClient(t, db, tt.driver, clientID) + assert.True(t, clientValid, "client should exist") + + kekValid := ValidateTestKek(t, db, tt.driver, kekID) + assert.True(t, kekValid, "KEK should exist") + }) + } +} + +func TestCreateTestDek(t *testing.T) { + tests := []struct { + name string + driver string + setup func(t *testing.T) *sql.DB + skip func(t *testing.T) + }{ + { + name: "create DEK in postgres", + driver: "postgres", + setup: SetupPostgresDB, + skip: SkipIfNoPostgres, + }, + { + name: "create DEK in mysql", + driver: "mysql", + setup: SetupMySQLDB, + skip: SkipIfNoMySQL, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.skip(t) + + db := tt.setup(t) + defer TeardownDB(t, db) + + // Create prerequisites (only need KEK, not client) + kekID := CreateTestKek(t, db, tt.driver, "test-dek-kek") + + // Create DEK + dekID := CreateTestDek(t, db, tt.driver, "test-dek", kekID) + assert.NotEqual(t, uuid.Nil, dekID) + + // Verify DEK was created by checking it exists + var algorithm string + var err error + if tt.driver == "postgres" { + err = db.QueryRow("SELECT algorithm FROM deks WHERE id = $1", dekID).Scan(&algorithm) + } else { + idValue, marshalErr := uuidToDriverValue(dekID, tt.driver) + require.NoError(t, marshalErr) + err = db.QueryRow("SELECT algorithm FROM deks WHERE id = ?", idValue).Scan(&algorithm) + } + assert.NoError(t, err) + assert.Equal(t, "aes-gcm", algorithm, "DEK should have aes-gcm algorithm") + }) + } +} + +func TestValidateTestClient(t *testing.T) { + SkipIfNoPostgres(t) + + db := SetupPostgresDB(t) + defer TeardownDB(t, db) + + // Test with valid client + clientID := CreateTestClient(t, db, "postgres", "valid-client") + valid := ValidateTestClient(t, db, "postgres", clientID) + assert.True(t, valid, "should validate existing client") + + // Test with non-existent client + nonExistentID := uuid.Must(uuid.NewV7()) + valid = ValidateTestClient(t, db, "postgres", nonExistentID) + assert.False(t, valid, "should not validate non-existent client") +} + +func TestValidateTestKek(t *testing.T) { + SkipIfNoPostgres(t) + + db := SetupPostgresDB(t) + defer TeardownDB(t, db) + + // Test with valid KEK + kekID := CreateTestKek(t, db, "postgres", "valid-kek") + valid := ValidateTestKek(t, db, "postgres", kekID) + assert.True(t, valid, "should validate existing KEK") + + // Test with non-existent KEK + nonExistentID := uuid.Must(uuid.NewV7()) + valid = ValidateTestKek(t, db, "postgres", nonExistentID) + assert.False(t, valid, "should not validate non-existent KEK") +} + +func TestSkipIfNoPostgres(t *testing.T) { + // This test verifies that SkipIfNoPostgres doesn't panic + // We can't easily test the actual skipping behavior without mocking + t.Run("does not panic", func(t *testing.T) { + assert.NotPanics(t, func() { + SkipIfNoPostgres(t) + }) + }) +} + +func TestSkipIfNoMySQL(t *testing.T) { + // This test verifies that SkipIfNoMySQL doesn't panic + // We can't easily test the actual skipping behavior without mocking + t.Run("does not panic", func(t *testing.T) { + assert.NotPanics(t, func() { + SkipIfNoMySQL(t) + }) + }) +} diff --git a/internal/tokenization/doc.go b/internal/tokenization/doc.go new file mode 100644 index 0000000..506138c --- /dev/null +++ b/internal/tokenization/doc.go @@ -0,0 +1,142 @@ +/* +Package tokenization provides secure, format-preserving tokenization for sensitive data. + +The tokenization module enables replacing sensitive values (credit cards, SSNs, PII) with +non-sensitive tokens while maintaining format compatibility with existing systems. + +# Architecture + +The module follows Clean Architecture principles: + - domain: Core entities (TokenizationKey, Token) and business rules + - usecase: Business logic orchestration + - service: Token generation algorithms (UUID, Numeric, Luhn, Alphanumeric) + - repository: Data persistence (MySQL, PostgreSQL) + - http: HTTP handlers and DTOs + +# Security Model + +Uses a three-tier key hierarchy: + - Master Key (MK): Root key stored in HSM/KMS + - Key Encryption Key (KEK): Encrypts DEKs, derived from MK + - Data Encryption Key (DEK): Encrypts plaintexts, encrypted with KEK + +Plaintext is encrypted with AES-GCM or ChaCha20-Poly1305 AEAD ciphers. + +# Token Formats + + - UUID: UUIDv7 tokens (standard format) + - Numeric: Numeric-only tokens (configurable length) + - Luhn-Preserving: Maintains Luhn check digit (for credit cards) + - Alphanumeric: Alphanumeric tokens (A-Z, 0-9) + +# Deterministic vs Non-Deterministic + +Deterministic Mode (IsDeterministic: true): + - Same plaintext always produces same token + - Enables consistent token reuse + - Risk: Frequency analysis possible + +Non-Deterministic Mode (IsDeterministic: false): + - Same plaintext produces different tokens each time + - Maximum security (recommended) + - Prevents frequency analysis + +# Basic Usage + +Create a tokenization key: + + key, err := tokenizationKeyUseCase.Create( + ctx, + "credit-card-key", + domain.FormatNumeric, + false, // non-deterministic + cryptoDomain.AESGCM, + ) + +Tokenize sensitive data: + + plaintext := []byte("4532123456789012") + metadata := map[string]any{"last4": "9012"} + expiresAt := time.Now().Add(24 * time.Hour) + + token, err := tokenizationUseCase.Tokenize( + ctx, + "credit-card-key", + plaintext, + metadata, + &expiresAt, + ) + +Detokenize to retrieve original value: + + plaintext, metadata, err := tokenizationUseCase.Detokenize(ctx, token.Token) + defer cryptoDomain.Zero(plaintext) // CRITICAL: Zero plaintext after use + +# Key Rotation + +Create a new version of an existing key: + + newKey, err := tokenizationKeyUseCase.Rotate( + ctx, + "credit-card-key", + domain.FormatNumeric, + false, + cryptoDomain.AESGCM, + ) + // newKey.Version = 2 (old tokens still work with version 1) + +# Token Lifecycle + +Validate token: + + isValid, err := tokenizationUseCase.Validate(ctx, "1234567890123456") + +Revoke token: + + err = tokenizationUseCase.Revoke(ctx, "1234567890123456") + +Cleanup expired tokens: + + count, err := tokenizationUseCase.CleanupExpired(ctx, 30, false) + +# Security Best Practices + +1. Always zero plaintext after use: + + plaintext, _, err := tokenizationUseCase.Detokenize(ctx, token) + defer cryptoDomain.Zero(plaintext) + +2. Never store sensitive data in metadata: + + // ✅ Good: Only display data + metadata := map[string]any{"last4": "9012", "exp": "12/25"} + + // ❌ Bad: Sensitive data in metadata + metadata := map[string]any{"full_number": "4532123456789012"} + +3. Implement rate limiting on Tokenize(): + + // Recommended: 100 requests/minute per user + +4. Use appropriate determinism: + + // Use deterministic for analytics/joins + // Use non-deterministic for maximum security (default) + +# Constraints + + - Maximum plaintext size: 64 KB (enforced automatically) + - Maximum token length: 255 characters (format-preserving only) + - Minimum Luhn token length: 2 characters + +# Compliance + +Supports compliance with: + - PCI DSS 4.0 (Requirement 3) + - GDPR (Article 4, Article 25) + - HIPAA (PHI de-identification) + - CCPA (Consumer data protection) + +For complete documentation, see README.md. +*/ +package tokenization diff --git a/internal/tokenization/domain/const.go b/internal/tokenization/domain/const.go index d88a2d7..89727ff 100644 --- a/internal/tokenization/domain/const.go +++ b/internal/tokenization/domain/const.go @@ -16,6 +16,22 @@ const ( FormatAlphanumeric FormatType = "alphanumeric" ) +// Token length constraints +const ( + // MaxTokenLength is the maximum allowed token length for format-preserving tokens. + // This limit applies to Numeric, Luhn-Preserving, and Alphanumeric formats. + MaxTokenLength = 255 + + // MinLuhnTokenLength is the minimum token length required for Luhn algorithm validation. + // Luhn check requires at least 2 digits (payload + check digit). + MinLuhnTokenLength = 2 + + // MaxPlaintextSize is the maximum allowed plaintext size for tokenization (64 KB). + // This limit prevents DoS attacks from extremely large inputs and ensures reasonable + // encryption performance. + MaxPlaintextSize = 65536 // 64 KB +) + // Validate checks if the format type is valid. func (f FormatType) Validate() error { switch f { diff --git a/internal/tokenization/domain/errors.go b/internal/tokenization/domain/errors.go index 49dd674..c05eea9 100644 --- a/internal/tokenization/domain/errors.go +++ b/internal/tokenization/domain/errors.go @@ -28,4 +28,28 @@ var ( // ErrValueTooLong indicates the value exceeds the maximum allowed length. ErrValueTooLong = errors.Wrap(errors.ErrInvalidInput, "value exceeds maximum length") + + // ErrPlaintextTooLarge indicates the plaintext exceeds maximum allowed size. + ErrPlaintextTooLarge = errors.Wrap(errors.ErrInvalidInput, "plaintext exceeds maximum size of 64KB") + + // ErrPlaintextEmpty indicates the plaintext is empty. + ErrPlaintextEmpty = errors.Wrap(errors.ErrInvalidInput, "plaintext cannot be empty") + + // ErrTokenLengthInvalid indicates the token length is invalid for the format. + ErrTokenLengthInvalid = errors.Wrap(errors.ErrInvalidInput, "token length invalid for format type") + + // ErrTokenizationKeyNameEmpty indicates the tokenization key name is empty. + ErrTokenizationKeyNameEmpty = errors.Wrap(errors.ErrInvalidInput, "tokenization key name cannot be empty") + + // ErrTokenizationKeyVersionInvalid indicates the version is invalid (must be > 0). + ErrTokenizationKeyVersionInvalid = errors.Wrap( + errors.ErrInvalidInput, + "tokenization key version must be greater than 0", + ) + + // ErrTokenizationKeyDekIDInvalid indicates the DEK ID is invalid (nil UUID). + ErrTokenizationKeyDekIDInvalid = errors.Wrap( + errors.ErrInvalidInput, + "tokenization key DEK ID cannot be nil", + ) ) diff --git a/internal/tokenization/domain/errors_test.go b/internal/tokenization/domain/errors_test.go new file mode 100644 index 0000000..e73cf43 --- /dev/null +++ b/internal/tokenization/domain/errors_test.go @@ -0,0 +1,179 @@ +package domain + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + apperrors "github.com/allisson/secrets/internal/errors" +) + +func TestErrors_Wrapping(t *testing.T) { + tests := []struct { + name string + err error + expectedMsg string + }{ + { + name: "ErrTokenizationKeyNotFound", + err: ErrTokenizationKeyNotFound, + expectedMsg: "tokenization key not found", + }, + { + name: "ErrTokenizationKeyAlreadyExists", + err: ErrTokenizationKeyAlreadyExists, + expectedMsg: "tokenization key already exists", + }, + { + name: "ErrTokenNotFound", + err: ErrTokenNotFound, + expectedMsg: "token not found", + }, + { + name: "ErrTokenExpired", + err: ErrTokenExpired, + expectedMsg: "token has expired", + }, + { + name: "ErrTokenRevoked", + err: ErrTokenRevoked, + expectedMsg: "token has been revoked", + }, + { + name: "ErrInvalidFormatType", + err: ErrInvalidFormatType, + expectedMsg: "invalid format type", + }, + { + name: "ErrInvalidTokenLength", + err: ErrInvalidTokenLength, + expectedMsg: "invalid token length for format", + }, + { + name: "ErrValueTooLong", + err: ErrValueTooLong, + expectedMsg: "value exceeds maximum length", + }, + { + name: "ErrPlaintextTooLarge", + err: ErrPlaintextTooLarge, + expectedMsg: "plaintext exceeds maximum size of 64KB", + }, + { + name: "ErrPlaintextEmpty", + err: ErrPlaintextEmpty, + expectedMsg: "plaintext cannot be empty", + }, + { + name: "ErrTokenLengthInvalid", + err: ErrTokenLengthInvalid, + expectedMsg: "token length invalid for format type", + }, + { + name: "ErrTokenizationKeyNameEmpty", + err: ErrTokenizationKeyNameEmpty, + expectedMsg: "tokenization key name cannot be empty", + }, + { + name: "ErrTokenizationKeyVersionInvalid", + err: ErrTokenizationKeyVersionInvalid, + expectedMsg: "tokenization key version must be greater than 0", + }, + { + name: "ErrTokenizationKeyDekIDInvalid", + err: ErrTokenizationKeyDekIDInvalid, + expectedMsg: "tokenization key DEK ID cannot be nil", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Error(t, tt.err) + assert.Contains(t, tt.err.Error(), tt.expectedMsg) + }) + } +} + +func TestErrors_Types(t *testing.T) { + tests := []struct { + name string + err error + expectedType error + }{ + { + name: "ErrTokenizationKeyNotFound_IsNotFound", + err: ErrTokenizationKeyNotFound, + expectedType: apperrors.ErrNotFound, + }, + { + name: "ErrTokenizationKeyAlreadyExists_IsConflict", + err: ErrTokenizationKeyAlreadyExists, + expectedType: apperrors.ErrConflict, + }, + { + name: "ErrTokenNotFound_IsNotFound", + err: ErrTokenNotFound, + expectedType: apperrors.ErrNotFound, + }, + { + name: "ErrTokenExpired_IsInvalidInput", + err: ErrTokenExpired, + expectedType: apperrors.ErrInvalidInput, + }, + { + name: "ErrTokenRevoked_IsInvalidInput", + err: ErrTokenRevoked, + expectedType: apperrors.ErrInvalidInput, + }, + { + name: "ErrInvalidFormatType_IsInvalidInput", + err: ErrInvalidFormatType, + expectedType: apperrors.ErrInvalidInput, + }, + { + name: "ErrPlaintextTooLarge_IsInvalidInput", + err: ErrPlaintextTooLarge, + expectedType: apperrors.ErrInvalidInput, + }, + { + name: "ErrPlaintextEmpty_IsInvalidInput", + err: ErrPlaintextEmpty, + expectedType: apperrors.ErrInvalidInput, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.True(t, apperrors.Is(tt.err, tt.expectedType), + "expected %v to be of type %v", tt.err, tt.expectedType) + }) + } +} + +func TestErrors_Distinct(t *testing.T) { + // Verify that all errors are distinct + errors := []error{ + ErrTokenizationKeyNotFound, + ErrTokenizationKeyAlreadyExists, + ErrTokenNotFound, + ErrTokenExpired, + ErrTokenRevoked, + ErrInvalidFormatType, + ErrInvalidTokenLength, + ErrValueTooLong, + ErrPlaintextTooLarge, + ErrPlaintextEmpty, + ErrTokenLengthInvalid, + ErrTokenizationKeyNameEmpty, + ErrTokenizationKeyVersionInvalid, + ErrTokenizationKeyDekIDInvalid, + } + + // Check each error against all others + for i := 0; i < len(errors); i++ { + for j := i + 1; j < len(errors); j++ { + assert.NotEqual(t, errors[i].Error(), errors[j].Error(), + "errors at index %d and %d have the same message", i, j) + } + } +} diff --git a/internal/tokenization/domain/token.go b/internal/tokenization/domain/token.go index 1fc23c0..bea87e8 100644 --- a/internal/tokenization/domain/token.go +++ b/internal/tokenization/domain/token.go @@ -15,10 +15,15 @@ type Token struct { ValueHash *string Ciphertext []byte Nonce []byte - Metadata map[string]any - CreatedAt time.Time - ExpiresAt *time.Time - RevokedAt *time.Time + // Metadata stores optional unencrypted display data (e.g., last 4 digits, expiry date). + // Stored as JSON in the database with recommended maximum size of 1KB. + // Supported types: string, int, float64, bool, nil, and nested maps/slices of these types. + // Example: map[string]any{"last4": "1234", "exp": "12/25", "brand": "Visa"} + // WARNING: Do not store sensitive data in metadata as it is NOT encrypted. + Metadata map[string]any + CreatedAt time.Time + ExpiresAt *time.Time + RevokedAt *time.Time } // IsExpired checks if the token has expired. All time comparisons use UTC. diff --git a/internal/tokenization/domain/tokenization_key.go b/internal/tokenization/domain/tokenization_key.go index 74d7b28..2a2b0b2 100644 --- a/internal/tokenization/domain/tokenization_key.go +++ b/internal/tokenization/domain/tokenization_key.go @@ -9,12 +9,51 @@ import ( // TokenizationKey represents a versioned tokenization key configuration. // Each key defines the token format and deterministic behavior for tokenization operations. type TokenizationKey struct { - ID uuid.UUID - Name string - Version uint - FormatType FormatType + // ID is the unique identifier for this specific key version. + ID uuid.UUID + + // Name is the logical name for this tokenization key (e.g., "payment-cards", "ssn"). + // Multiple versions can share the same name. + Name string + + // Version is the key version number, starting at 1 and incremented on rotation. + // Higher versions are preferred for tokenization; all versions support detokenization. + Version uint + + // FormatType defines the token format (UUID, Numeric, Luhn-Preserving, Alphanumeric). + FormatType FormatType + + // IsDeterministic indicates whether the same plaintext always produces the same token. + // When true, enables efficient duplicate detection; when false, provides better privacy. IsDeterministic bool - DekID uuid.UUID - CreatedAt time.Time - DeletedAt *time.Time + + // DekID is the reference to the Data Encryption Key used to encrypt values for this version. + DekID uuid.UUID + + // CreatedAt is the timestamp when this key version was created (UTC). + CreatedAt time.Time + + // DeletedAt is the timestamp when this key was soft-deleted (nil if active). + DeletedAt *time.Time +} + +// Validate checks if the TokenizationKey has valid field values. +// Returns an error if any field constraint is violated. +func (tk *TokenizationKey) Validate() error { + if tk.Name == "" { + return ErrTokenizationKeyNameEmpty + } + if tk.Version == 0 { + return ErrTokenizationKeyVersionInvalid + } + if err := tk.FormatType.Validate(); err != nil { + return ErrInvalidFormatType + } + if tk.DekID == uuid.Nil { + return ErrTokenizationKeyDekIDInvalid + } + if tk.CreatedAt.IsZero() { + return ErrInvalidFormatType // Using existing error for now + } + return nil } diff --git a/internal/tokenization/domain/tokenization_key_test.go b/internal/tokenization/domain/tokenization_key_test.go new file mode 100644 index 0000000..253e6c3 --- /dev/null +++ b/internal/tokenization/domain/tokenization_key_test.go @@ -0,0 +1,194 @@ +package domain + +import ( + "testing" + "time" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" +) + +func TestTokenizationKey_Validate(t *testing.T) { + validID := uuid.Must(uuid.NewV7()) + validDekID := uuid.Must(uuid.NewV7()) + now := time.Now().UTC() + + tests := []struct { + name string + key *TokenizationKey + expectError bool + expectedErr error + }{ + { + name: "Success_ValidKey", + key: &TokenizationKey{ + ID: validID, + Name: "test-key", + Version: 1, + FormatType: FormatUUID, + IsDeterministic: false, + DekID: validDekID, + CreatedAt: now, + DeletedAt: nil, + }, + expectError: false, + }, + { + name: "Success_ValidKeyWithHighVersion", + key: &TokenizationKey{ + ID: validID, + Name: "test-key", + Version: 100, + FormatType: FormatNumeric, + IsDeterministic: true, + DekID: validDekID, + CreatedAt: now, + DeletedAt: nil, + }, + expectError: false, + }, + { + name: "Success_ValidKeyAllFormats", + key: &TokenizationKey{ + ID: validID, + Name: "test-key", + Version: 1, + FormatType: FormatLuhnPreserving, + IsDeterministic: false, + DekID: validDekID, + CreatedAt: now, + DeletedAt: nil, + }, + expectError: false, + }, + { + name: "Error_EmptyName", + key: &TokenizationKey{ + ID: validID, + Name: "", + Version: 1, + FormatType: FormatUUID, + IsDeterministic: false, + DekID: validDekID, + CreatedAt: now, + }, + expectError: true, + expectedErr: ErrTokenizationKeyNameEmpty, + }, + { + name: "Error_ZeroVersion", + key: &TokenizationKey{ + ID: validID, + Name: "test-key", + Version: 0, + FormatType: FormatUUID, + IsDeterministic: false, + DekID: validDekID, + CreatedAt: now, + }, + expectError: true, + expectedErr: ErrTokenizationKeyVersionInvalid, + }, + { + name: "Error_InvalidFormatType", + key: &TokenizationKey{ + ID: validID, + Name: "test-key", + Version: 1, + FormatType: FormatType("invalid"), + IsDeterministic: false, + DekID: validDekID, + CreatedAt: now, + }, + expectError: true, + expectedErr: ErrInvalidFormatType, + }, + { + name: "Error_NilDekID", + key: &TokenizationKey{ + ID: validID, + Name: "test-key", + Version: 1, + FormatType: FormatUUID, + IsDeterministic: false, + DekID: uuid.Nil, + CreatedAt: now, + }, + expectError: true, + expectedErr: ErrTokenizationKeyDekIDInvalid, + }, + { + name: "Error_ZeroCreatedAt", + key: &TokenizationKey{ + ID: validID, + Name: "test-key", + Version: 1, + FormatType: FormatUUID, + IsDeterministic: false, + DekID: validDekID, + CreatedAt: time.Time{}, + }, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.key.Validate() + + if tt.expectError { + assert.Error(t, err) + if tt.expectedErr != nil { + assert.ErrorIs(t, err, tt.expectedErr) + } + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestTokenizationKey_Fields(t *testing.T) { + t.Run("Success_AllFieldsSet", func(t *testing.T) { + id := uuid.Must(uuid.NewV7()) + dekID := uuid.Must(uuid.NewV7()) + now := time.Now().UTC() + deletedAt := now.Add(24 * time.Hour) + + key := &TokenizationKey{ + ID: id, + Name: "payment-cards", + Version: 5, + FormatType: FormatLuhnPreserving, + IsDeterministic: true, + DekID: dekID, + CreatedAt: now, + DeletedAt: &deletedAt, + } + + assert.Equal(t, id, key.ID) + assert.Equal(t, "payment-cards", key.Name) + assert.Equal(t, uint(5), key.Version) + assert.Equal(t, FormatLuhnPreserving, key.FormatType) + assert.True(t, key.IsDeterministic) + assert.Equal(t, dekID, key.DekID) + assert.Equal(t, now, key.CreatedAt) + assert.NotNil(t, key.DeletedAt) + assert.Equal(t, deletedAt, *key.DeletedAt) + }) + + t.Run("Success_DeletedAtNil", func(t *testing.T) { + key := &TokenizationKey{ + ID: uuid.Must(uuid.NewV7()), + Name: "test-key", + Version: 1, + FormatType: FormatUUID, + IsDeterministic: false, + DekID: uuid.Must(uuid.NewV7()), + CreatedAt: time.Now().UTC(), + DeletedAt: nil, + } + + assert.Nil(t, key.DeletedAt) + }) +} diff --git a/internal/tokenization/repository/mysql/mysql_repository.go b/internal/tokenization/repository/mysql/mysql_repository.go index 3717878..3ddafb9 100644 --- a/internal/tokenization/repository/mysql/mysql_repository.go +++ b/internal/tokenization/repository/mysql/mysql_repository.go @@ -7,6 +7,7 @@ import ( "errors" "time" + "github.com/go-sql-driver/mysql" "github.com/google/uuid" "github.com/allisson/secrets/internal/database" @@ -344,6 +345,11 @@ func (m *MySQLTokenRepository) Create( token.RevokedAt, ) if err != nil { + // Check for duplicate entry error (MySQL error number 1062) + var mysqlErr *mysql.MySQLError + if errors.As(err, &mysqlErr) && mysqlErr.Number == 1062 { + return apperrors.ErrConflict + } return apperrors.Wrap(err, "failed to create token") } return nil diff --git a/internal/tokenization/repository/postgresql/postgresql_repository.go b/internal/tokenization/repository/postgresql/postgresql_repository.go index 58697e8..44e488b 100644 --- a/internal/tokenization/repository/postgresql/postgresql_repository.go +++ b/internal/tokenization/repository/postgresql/postgresql_repository.go @@ -10,6 +10,7 @@ import ( "time" "github.com/google/uuid" + "github.com/lib/pq" "github.com/allisson/secrets/internal/database" apperrors "github.com/allisson/secrets/internal/errors" @@ -280,6 +281,11 @@ func (p *PostgreSQLTokenRepository) Create( token.RevokedAt, ) if err != nil { + // Check for unique constraint violation (PostgreSQL error code 23505) + var pqErr *pq.Error + if errors.As(err, &pqErr) && pqErr.Code == "23505" { + return apperrors.ErrConflict + } return apperrors.Wrap(err, "failed to create token") } return nil diff --git a/internal/tokenization/service/alphanumeric_generator.go b/internal/tokenization/service/alphanumeric_generator.go index 4a980e7..9a054a1 100644 --- a/internal/tokenization/service/alphanumeric_generator.go +++ b/internal/tokenization/service/alphanumeric_generator.go @@ -5,6 +5,8 @@ import ( "errors" "fmt" "math/big" + + tokenizationDomain "github.com/allisson/secrets/internal/tokenization/domain" ) const alphanumericChars = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789" @@ -23,7 +25,7 @@ func (g *alphanumericGenerator) Generate(length int) (string, error) { if length < 1 { return "", errors.New("length must be at least 1") } - if length > 255 { + if length > tokenizationDomain.MaxTokenLength { return "", errors.New("length must not exceed 255") } diff --git a/internal/tokenization/service/luhn_generator.go b/internal/tokenization/service/luhn_generator.go index 753b458..561e6b5 100644 --- a/internal/tokenization/service/luhn_generator.go +++ b/internal/tokenization/service/luhn_generator.go @@ -5,6 +5,8 @@ import ( "errors" "fmt" "math/big" + + tokenizationDomain "github.com/allisson/secrets/internal/tokenization/domain" ) type luhnGenerator struct{} @@ -18,10 +20,10 @@ func NewLuhnGenerator() TokenGenerator { // Generate creates a Luhn algorithm compliant numeric token of the specified length. // The last digit is calculated as the Luhn check digit. Returns an error if length is less than 2. func (g *luhnGenerator) Generate(length int) (string, error) { - if length < 2 { + if length < tokenizationDomain.MinLuhnTokenLength { return "", errors.New("length must be at least 2 for Luhn tokens") } - if length > 255 { + if length > tokenizationDomain.MaxTokenLength { return "", errors.New("length must not exceed 255") } @@ -73,6 +75,10 @@ func (g *luhnGenerator) Validate(token string) error { // calculateLuhnCheckDigit calculates the Luhn check digit for the given digits. // The digits slice should NOT include the check digit position. func calculateLuhnCheckDigit(digits []int) int { + if len(digits) == 0 { + return 0 // Return 0 for empty input + } + sum := 0 length := len(digits) diff --git a/internal/tokenization/service/numeric_generator.go b/internal/tokenization/service/numeric_generator.go index 6c2b15a..e4af796 100644 --- a/internal/tokenization/service/numeric_generator.go +++ b/internal/tokenization/service/numeric_generator.go @@ -5,6 +5,8 @@ import ( "errors" "fmt" "math/big" + + tokenizationDomain "github.com/allisson/secrets/internal/tokenization/domain" ) type numericGenerator struct{} @@ -21,7 +23,7 @@ func (g *numericGenerator) Generate(length int) (string, error) { if length < 1 { return "", errors.New("length must be at least 1") } - if length > 255 { + if length > tokenizationDomain.MaxTokenLength { return "", errors.New("length must not exceed 255") } diff --git a/internal/tokenization/testing/helpers.go b/internal/tokenization/testing/helpers.go new file mode 100644 index 0000000..4c359b6 --- /dev/null +++ b/internal/tokenization/testing/helpers.go @@ -0,0 +1,51 @@ +// Package testing provides shared test utilities for tokenization module tests. +package testing + +import ( + "github.com/google/uuid" + + cryptoDomain "github.com/allisson/secrets/internal/crypto/domain" +) + +// CreateMasterKey creates a test master key with a random 32-byte key. +func CreateMasterKey() *cryptoDomain.MasterKey { + return &cryptoDomain.MasterKey{ + ID: "test-master-key", + Key: make([]byte, 32), + } +} + +// CreateKekChain creates a test KEK chain with a single active KEK. +func CreateKekChain(masterKey *cryptoDomain.MasterKey) *cryptoDomain.KekChain { + kek := &cryptoDomain.Kek{ + ID: uuid.Must(uuid.NewV7()), + MasterKeyID: masterKey.ID, + Algorithm: cryptoDomain.AESGCM, + EncryptedKey: make([]byte, 32), + Key: make([]byte, 32), + Nonce: make([]byte, 12), + Version: 1, + } + return cryptoDomain.NewKekChain([]*cryptoDomain.Kek{kek}) +} + +// GetActiveKek retrieves the active KEK from a chain. +func GetActiveKek(kekChain *cryptoDomain.KekChain) *cryptoDomain.Kek { + activeID := kekChain.ActiveKekID() + kek, ok := kekChain.Get(activeID) + if !ok { + panic("active KEK not found in chain") + } + return kek +} + +// CreateTestDek creates a test DEK for the given KEK. +func CreateTestDek(kek *cryptoDomain.Kek) cryptoDomain.Dek { + return cryptoDomain.Dek{ + ID: uuid.Must(uuid.NewV7()), + KekID: kek.ID, + Algorithm: cryptoDomain.AESGCM, + EncryptedKey: []byte("encrypted-dek"), + Nonce: []byte("nonce"), + } +} diff --git a/internal/tokenization/usecase/hash_service_test.go b/internal/tokenization/usecase/hash_service_test.go new file mode 100644 index 0000000..6105625 --- /dev/null +++ b/internal/tokenization/usecase/hash_service_test.go @@ -0,0 +1,162 @@ +package usecase + +import ( + "crypto/sha256" + "encoding/hex" + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +// TestNewSHA256HashService tests the constructor. +func TestNewSHA256HashService(t *testing.T) { + hashService := NewSHA256HashService() + assert.NotNil(t, hashService) + assert.IsType(t, &sha256HashService{}, hashService) +} + +// TestSHA256HashService_Hash tests the Hash method. +func TestSHA256HashService_Hash(t *testing.T) { + hashService := NewSHA256HashService() + + t.Run("Success_HashEmptyInput", func(t *testing.T) { + // Empty input should produce the SHA-256 hash of empty string + input := []byte{} + result := hashService.Hash(input) + + // Verify result is non-empty and valid hex + assert.NotEmpty(t, result) + assert.Equal(t, 64, len(result)) // SHA-256 produces 32 bytes = 64 hex chars + + // Verify it matches expected SHA-256 hash of empty string + expectedHash := sha256.Sum256([]byte{}) + expected := hex.EncodeToString(expectedHash[:]) + assert.Equal(t, expected, result) + }) + + t.Run("Success_HashSmallInput", func(t *testing.T) { + input := []byte("hello") + result := hashService.Hash(input) + + // Verify result is non-empty and valid hex + assert.NotEmpty(t, result) + assert.Equal(t, 64, len(result)) + + // Verify it matches expected SHA-256 hash + expectedHash := sha256.Sum256(input) + expected := hex.EncodeToString(expectedHash[:]) + assert.Equal(t, expected, result) + }) + + t.Run("Success_HashLargeInput", func(t *testing.T) { + // Create a large input (10KB) + input := []byte(strings.Repeat("A", 10240)) + result := hashService.Hash(input) + + // Verify result is non-empty and valid hex + assert.NotEmpty(t, result) + assert.Equal(t, 64, len(result)) + + // Verify it matches expected SHA-256 hash + expectedHash := sha256.Sum256(input) + expected := hex.EncodeToString(expectedHash[:]) + assert.Equal(t, expected, result) + }) + + t.Run("Success_HashBinaryData", func(t *testing.T) { + // Test with binary data (not just ASCII) + input := []byte{0x00, 0x01, 0x02, 0xFF, 0xFE, 0xFD} + result := hashService.Hash(input) + + // Verify result is non-empty and valid hex + assert.NotEmpty(t, result) + assert.Equal(t, 64, len(result)) + + // Verify it matches expected SHA-256 hash + expectedHash := sha256.Sum256(input) + expected := hex.EncodeToString(expectedHash[:]) + assert.Equal(t, expected, result) + }) + + t.Run("Success_ConsistencyCheck", func(t *testing.T) { + // Same input should always produce the same hash + input := []byte("test-consistency") + result1 := hashService.Hash(input) + result2 := hashService.Hash(input) + result3 := hashService.Hash(input) + + assert.Equal(t, result1, result2) + assert.Equal(t, result2, result3) + }) + + t.Run("Success_DifferentInputsProduceDifferentHashes", func(t *testing.T) { + // Different inputs should produce different hashes + input1 := []byte("plaintext1") + input2 := []byte("plaintext2") + + result1 := hashService.Hash(input1) + result2 := hashService.Hash(input2) + + assert.NotEqual(t, result1, result2) + }) + + t.Run("Success_SensitivityToSmallChanges", func(t *testing.T) { + // Even a single bit change should produce a completely different hash + input1 := []byte("plaintext") + input2 := []byte("Plaintext") // Only first letter capitalized + + result1 := hashService.Hash(input1) + result2 := hashService.Hash(input2) + + assert.NotEqual(t, result1, result2) + }) + + t.Run("Success_ResultIsValidHexString", func(t *testing.T) { + input := []byte("test") + result := hashService.Hash(input) + + // Verify result contains only valid hex characters + for _, char := range result { + assert.True(t, + (char >= '0' && char <= '9') || (char >= 'a' && char <= 'f'), + "Result should only contain hex characters (0-9, a-f)") + } + }) + + t.Run("Success_KnownTestVector", func(t *testing.T) { + // Test with a known SHA-256 test vector + // SHA-256("abc") = ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad + input := []byte("abc") + result := hashService.Hash(input) + + expected := "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad" + assert.Equal(t, expected, result) + }) + + t.Run("Success_MultipleInstancesProduceSameHash", func(t *testing.T) { + // Multiple instances of the hash service should produce the same result + hashService1 := NewSHA256HashService() + hashService2 := NewSHA256HashService() + + input := []byte("test-multiple-instances") + result1 := hashService1.Hash(input) + result2 := hashService2.Hash(input) + + assert.Equal(t, result1, result2) + }) + + t.Run("Success_UnicodeInput", func(t *testing.T) { + // Test with Unicode characters + input := []byte("Hello 世界 🌍") + result := hashService.Hash(input) + + // Verify result is non-empty and valid hex + assert.NotEmpty(t, result) + assert.Equal(t, 64, len(result)) + + // Verify consistency + result2 := hashService.Hash(input) + assert.Equal(t, result, result2) + }) +} diff --git a/internal/tokenization/usecase/helpers.go b/internal/tokenization/usecase/helpers.go new file mode 100644 index 0000000..0a7dd07 --- /dev/null +++ b/internal/tokenization/usecase/helpers.go @@ -0,0 +1,17 @@ +package usecase + +import ( + "github.com/google/uuid" + + cryptoDomain "github.com/allisson/secrets/internal/crypto/domain" +) + +// getKek retrieves a KEK from the chain by its ID. +// Returns ErrKekNotFound if the KEK is not in the chain. +func getKek(kekChain *cryptoDomain.KekChain, kekID uuid.UUID) (*cryptoDomain.Kek, error) { + kek, ok := kekChain.Get(kekID) + if !ok { + return nil, cryptoDomain.ErrKekNotFound + } + return kek, nil +} diff --git a/internal/tokenization/usecase/tokenization_key_usecase.go b/internal/tokenization/usecase/tokenization_key_usecase.go index 7d252be..787c503 100644 --- a/internal/tokenization/usecase/tokenization_key_usecase.go +++ b/internal/tokenization/usecase/tokenization_key_usecase.go @@ -22,69 +22,94 @@ type tokenizationKeyUseCase struct { kekChain *cryptoDomain.KekChain } -// getKek retrieves a KEK from the chain by its ID. -func (t *tokenizationKeyUseCase) getKek(kekID uuid.UUID) (*cryptoDomain.Kek, error) { - kek, ok := t.kekChain.Get(kekID) - if !ok { - return nil, cryptoDomain.ErrKekNotFound - } - return kek, nil -} - -// Create generates and persists a new tokenization key with version 1. -// Returns ErrTokenizationKeyAlreadyExists if a key with the same name already exists. -func (t *tokenizationKeyUseCase) Create( +// createTokenizationKey is a helper that creates a tokenization key within an existing transaction context. +// It does NOT create its own transaction - the caller must handle transaction management. +func (t *tokenizationKeyUseCase) createTokenizationKey( ctx context.Context, name string, + version uint, formatType tokenizationDomain.FormatType, isDeterministic bool, alg cryptoDomain.Algorithm, ) (*tokenizationDomain.TokenizationKey, error) { - // Validate format type - if err := formatType.Validate(); err != nil { - return nil, tokenizationDomain.ErrInvalidFormatType - } - - // Check if tokenization key with version 1 already exists - existingKey, err := t.tokenizationKeyRepo.GetByNameAndVersion(ctx, name, 1) - if err != nil && !apperrors.Is(err, tokenizationDomain.ErrTokenizationKeyNotFound) { - return nil, err - } - if existingKey != nil { - return nil, tokenizationDomain.ErrTokenizationKeyAlreadyExists - } - // Get active KEK from chain - activeKek, err := t.getKek(t.kekChain.ActiveKekID()) + activeKek, err := getKek(t.kekChain, t.kekChain.ActiveKekID()) if err != nil { - return nil, err + return nil, apperrors.Wrap(err, "failed to get active KEK") } // Create DEK encrypted with active KEK dek, err := t.keyManager.CreateDek(activeKek, alg) if err != nil { - return nil, err + return nil, apperrors.Wrap(err, "failed to create DEK") } // Persist DEK to database if err := t.dekRepo.Create(ctx, &dek); err != nil { - return nil, err + return nil, apperrors.Wrap(err, "failed to persist DEK") } - // Create tokenization key with version 1 + // Create tokenization key + keyID, err := uuid.NewV7() + if err != nil { + return nil, apperrors.Wrap(err, "failed to generate UUID for tokenization key") + } tokenizationKey := &tokenizationDomain.TokenizationKey{ - ID: uuid.Must(uuid.NewV7()), + ID: keyID, Name: name, - Version: 1, + Version: version, FormatType: formatType, IsDeterministic: isDeterministic, DekID: dek.ID, CreatedAt: time.Now().UTC(), } + // Validate tokenization key fields + if err := tokenizationKey.Validate(); err != nil { + return nil, apperrors.Wrap(err, "tokenization key validation failed") + } + // Persist tokenization key if err := t.tokenizationKeyRepo.Create(ctx, tokenizationKey); err != nil { - return nil, err + return nil, apperrors.Wrap(err, "failed to persist tokenization key") + } + + return tokenizationKey, nil +} + +// Create generates and persists a new tokenization key with version 1. +// Returns ErrTokenizationKeyAlreadyExists if a key with the same name already exists. +func (t *tokenizationKeyUseCase) Create( + ctx context.Context, + name string, + formatType tokenizationDomain.FormatType, + isDeterministic bool, + alg cryptoDomain.Algorithm, +) (*tokenizationDomain.TokenizationKey, error) { + // Validate format type + if err := formatType.Validate(); err != nil { + return nil, tokenizationDomain.ErrInvalidFormatType + } + + // Check if tokenization key with version 1 already exists + existingKey, err := t.tokenizationKeyRepo.GetByNameAndVersion(ctx, name, 1) + if err != nil && !apperrors.Is(err, tokenizationDomain.ErrTokenizationKeyNotFound) { + return nil, apperrors.Wrap(err, "failed to check for existing tokenization key") + } + if existingKey != nil { + return nil, tokenizationDomain.ErrTokenizationKeyAlreadyExists + } + + var tokenizationKey *tokenizationDomain.TokenizationKey + + // Wrap DEK and tokenization key creation in a transaction + err = t.txManager.WithTx(ctx, func(txCtx context.Context) error { + tokenizationKey, err = t.createTokenizationKey(txCtx, name, 1, formatType, isDeterministic, alg) + return err + }) + + if err != nil { + return nil, apperrors.Wrap(err, "failed to create tokenization key") } return tokenizationKey, nil @@ -111,32 +136,36 @@ func (t *tokenizationKeyUseCase) Rotate( if err != nil { // If key doesn't exist, create first version if apperrors.Is(err, tokenizationDomain.ErrTokenizationKeyNotFound) { - newKey, err = t.Create(txCtx, name, formatType, isDeterministic, alg) + newKey, err = t.createTokenizationKey(txCtx, name, 1, formatType, isDeterministic, alg) return err } - return err + return apperrors.Wrap(err, "failed to get current tokenization key") } // Get active KEK from chain - activeKek, err := t.getKek(t.kekChain.ActiveKekID()) + activeKek, err := getKek(t.kekChain, t.kekChain.ActiveKekID()) if err != nil { - return err + return apperrors.Wrap(err, "failed to get active KEK") } // Create new DEK encrypted with active KEK dek, err := t.keyManager.CreateDek(activeKek, alg) if err != nil { - return err + return apperrors.Wrap(err, "failed to create DEK") } // Persist new DEK if err := t.dekRepo.Create(txCtx, &dek); err != nil { - return err + return apperrors.Wrap(err, "failed to persist DEK") } // Create new tokenization key with incremented version + keyID, err := uuid.NewV7() + if err != nil { + return apperrors.Wrap(err, "failed to generate UUID for tokenization key") + } newKey = &tokenizationDomain.TokenizationKey{ - ID: uuid.Must(uuid.NewV7()), + ID: keyID, Name: name, Version: currentKey.Version + 1, FormatType: formatType, @@ -145,12 +174,20 @@ func (t *tokenizationKeyUseCase) Rotate( CreatedAt: time.Now().UTC(), } + // Validate tokenization key fields + if err := newKey.Validate(); err != nil { + return apperrors.Wrap(err, "tokenization key validation failed") + } + // Persist new tokenization key - return t.tokenizationKeyRepo.Create(txCtx, newKey) + if err := t.tokenizationKeyRepo.Create(txCtx, newKey); err != nil { + return apperrors.Wrap(err, "failed to persist rotated tokenization key") + } + return nil }) if err != nil { - return nil, err + return nil, apperrors.Wrap(err, "failed to rotate tokenization key") } return newKey, nil @@ -158,7 +195,11 @@ func (t *tokenizationKeyUseCase) Rotate( // Delete soft-deletes a tokenization key by setting its deleted_at timestamp. func (t *tokenizationKeyUseCase) Delete(ctx context.Context, keyID uuid.UUID) error { - return t.tokenizationKeyRepo.Delete(ctx, keyID) + err := t.tokenizationKeyRepo.Delete(ctx, keyID) + if err != nil { + return apperrors.Wrap(err, "failed to delete tokenization key") + } + return nil } // List retrieves tokenization keys ordered by name ascending with pagination. @@ -166,7 +207,11 @@ func (t *tokenizationKeyUseCase) List( ctx context.Context, offset, limit int, ) ([]*tokenizationDomain.TokenizationKey, error) { - return t.tokenizationKeyRepo.List(ctx, offset, limit) + keys, err := t.tokenizationKeyRepo.List(ctx, offset, limit) + if err != nil { + return nil, apperrors.Wrap(err, "failed to list tokenization keys") + } + return keys, nil } // NewTokenizationKeyUseCase creates a new tokenization key use case instance. diff --git a/internal/tokenization/usecase/tokenization_key_usecase_test.go b/internal/tokenization/usecase/tokenization_key_usecase_test.go index 3f29170..07bae91 100644 --- a/internal/tokenization/usecase/tokenization_key_usecase_test.go +++ b/internal/tokenization/usecase/tokenization_key_usecase_test.go @@ -13,38 +13,10 @@ import ( cryptoServiceMocks "github.com/allisson/secrets/internal/crypto/service/mocks" databaseMocks "github.com/allisson/secrets/internal/database/mocks" tokenizationDomain "github.com/allisson/secrets/internal/tokenization/domain" + tokenizationTesting "github.com/allisson/secrets/internal/tokenization/testing" tokenizationMocks "github.com/allisson/secrets/internal/tokenization/usecase/mocks" ) -// createKekChain creates a test KEK chain for tokenization key tests. -func createKekChain(masterKey *cryptoDomain.MasterKey) *cryptoDomain.KekChain { - // Create a test KEK with plaintext key populated - kek := &cryptoDomain.Kek{ - ID: uuid.Must(uuid.NewV7()), - MasterKeyID: masterKey.ID, - Algorithm: cryptoDomain.AESGCM, - EncryptedKey: make([]byte, 32), - Key: make([]byte, 32), // Plaintext KEK key for testing - Nonce: make([]byte, 12), - Version: 1, - } - - // Create KEK chain with the test KEK (newest first) - kekChain := cryptoDomain.NewKekChain([]*cryptoDomain.Kek{kek}) - - return kekChain -} - -// getActiveKek is a helper to get the active KEK from a chain. -func getActiveKek(kekChain *cryptoDomain.KekChain) *cryptoDomain.Kek { - activeID := kekChain.ActiveKekID() - kek, ok := kekChain.Get(activeID) - if !ok { - panic("active KEK not found in chain") - } - return kek -} - // TestTokenizationKeyUseCase_Create tests the Create method. func TestTokenizationKeyUseCase_Create(t *testing.T) { ctx := context.Background() @@ -57,14 +29,11 @@ func TestTokenizationKeyUseCase_Create(t *testing.T) { mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) // Create test data - masterKey := &cryptoDomain.MasterKey{ - ID: "test-master-key", - Key: make([]byte, 32), - } - kekChain := createKekChain(masterKey) + masterKey := tokenizationTesting.CreateMasterKey() + kekChain := tokenizationTesting.CreateKekChain(masterKey) defer kekChain.Close() - activeKek := getActiveKek(kekChain) + activeKek := tokenizationTesting.GetActiveKek(kekChain) dek := cryptoDomain.Dek{ ID: uuid.Must(uuid.NewV7()), KekID: activeKek.ID, @@ -79,20 +48,29 @@ func TestTokenizationKeyUseCase_Create(t *testing.T) { Return(nil, tokenizationDomain.ErrTokenizationKeyNotFound). Once() + mockTxManager.EXPECT(). + WithTx(ctx, mock.AnythingOfType("func(context.Context) error")). + Run(func(ctx context.Context, fn func(context.Context) error) { + // Execute the transaction function + _ = fn(ctx) + }). + Return(nil). + Once() + mockKeyManager.EXPECT(). CreateDek(activeKek, cryptoDomain.AESGCM). Return(dek, nil). Once() mockDekRepo.EXPECT(). - Create(ctx, mock.MatchedBy(func(d *cryptoDomain.Dek) bool { + Create(mock.Anything, mock.MatchedBy(func(d *cryptoDomain.Dek) bool { return d.ID == dek.ID && d.KekID == dek.KekID })). Return(nil). Once() mockTokenizationKeyRepo.EXPECT(). - Create(ctx, mock.MatchedBy(func(key *tokenizationDomain.TokenizationKey) bool { + Create(mock.Anything, mock.MatchedBy(func(key *tokenizationDomain.TokenizationKey) bool { return key.Name == "test-key" && key.FormatType == tokenizationDomain.FormatUUID && key.Version == 1 && @@ -129,14 +107,11 @@ func TestTokenizationKeyUseCase_Create(t *testing.T) { mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) // Create test data - masterKey := &cryptoDomain.MasterKey{ - ID: "test-master-key", - Key: make([]byte, 32), - } - kekChain := createKekChain(masterKey) + masterKey := tokenizationTesting.CreateMasterKey() + kekChain := tokenizationTesting.CreateKekChain(masterKey) defer kekChain.Close() - activeKek := getActiveKek(kekChain) + activeKek := tokenizationTesting.GetActiveKek(kekChain) dek := cryptoDomain.Dek{ ID: uuid.Must(uuid.NewV7()), KekID: activeKek.ID, @@ -151,18 +126,27 @@ func TestTokenizationKeyUseCase_Create(t *testing.T) { Return(nil, tokenizationDomain.ErrTokenizationKeyNotFound). Once() + mockTxManager.EXPECT(). + WithTx(ctx, mock.AnythingOfType("func(context.Context) error")). + Run(func(ctx context.Context, fn func(context.Context) error) { + // Execute the transaction function + _ = fn(ctx) + }). + Return(nil). + Once() + mockKeyManager.EXPECT(). CreateDek(activeKek, cryptoDomain.ChaCha20). Return(dek, nil). Once() mockDekRepo.EXPECT(). - Create(ctx, mock.Anything). + Create(mock.Anything, mock.Anything). Return(nil). Once() mockTokenizationKeyRepo.EXPECT(). - Create(ctx, mock.MatchedBy(func(key *tokenizationDomain.TokenizationKey) bool { + Create(mock.Anything, mock.MatchedBy(func(key *tokenizationDomain.TokenizationKey) bool { return key.Name == "payment-cards" && key.FormatType == tokenizationDomain.FormatLuhnPreserving && key.Version == 1 && @@ -203,11 +187,8 @@ func TestTokenizationKeyUseCase_Create(t *testing.T) { mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) // Create test data - masterKey := &cryptoDomain.MasterKey{ - ID: "test-master-key", - Key: make([]byte, 32), - } - kekChain := createKekChain(masterKey) + masterKey := tokenizationTesting.CreateMasterKey() + kekChain := tokenizationTesting.CreateKekChain(masterKey) defer kekChain.Close() expectedError := errors.New("key manager error") @@ -218,6 +199,15 @@ func TestTokenizationKeyUseCase_Create(t *testing.T) { Return(nil, tokenizationDomain.ErrTokenizationKeyNotFound). Once() + mockTxManager.EXPECT(). + WithTx(ctx, mock.AnythingOfType("func(context.Context) error")). + Run(func(ctx context.Context, fn func(context.Context) error) { + // Execute the transaction function + _ = fn(ctx) + }). + Return(expectedError). + Once() + mockKeyManager.EXPECT(). CreateDek(mock.Anything, mock.Anything). Return(cryptoDomain.Dek{}, expectedError). @@ -236,7 +226,8 @@ func TestTokenizationKeyUseCase_Create(t *testing.T) { // Assert assert.Error(t, err) assert.Nil(t, key) - assert.Equal(t, expectedError, err) + assert.True(t, errors.Is(err, expectedError)) + assert.Contains(t, err.Error(), "failed to create tokenization key") }) t.Run("Error_DekRepositoryCreateFails", func(t *testing.T) { @@ -247,14 +238,11 @@ func TestTokenizationKeyUseCase_Create(t *testing.T) { mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) // Create test data - masterKey := &cryptoDomain.MasterKey{ - ID: "test-master-key", - Key: make([]byte, 32), - } - kekChain := createKekChain(masterKey) + masterKey := tokenizationTesting.CreateMasterKey() + kekChain := tokenizationTesting.CreateKekChain(masterKey) defer kekChain.Close() - activeKek := getActiveKek(kekChain) + activeKek := tokenizationTesting.GetActiveKek(kekChain) dek := cryptoDomain.Dek{ ID: uuid.Must(uuid.NewV7()), KekID: activeKek.ID, @@ -271,13 +259,22 @@ func TestTokenizationKeyUseCase_Create(t *testing.T) { Return(nil, tokenizationDomain.ErrTokenizationKeyNotFound). Once() + mockTxManager.EXPECT(). + WithTx(ctx, mock.AnythingOfType("func(context.Context) error")). + Run(func(ctx context.Context, fn func(context.Context) error) { + // Execute the transaction function + _ = fn(ctx) + }). + Return(expectedError). + Once() + mockKeyManager.EXPECT(). CreateDek(mock.Anything, mock.Anything). Return(dek, nil). Once() mockDekRepo.EXPECT(). - Create(ctx, mock.Anything). + Create(mock.Anything, mock.Anything). Return(expectedError). Once() @@ -294,7 +291,8 @@ func TestTokenizationKeyUseCase_Create(t *testing.T) { // Assert assert.Error(t, err) assert.Nil(t, key) - assert.Equal(t, expectedError, err) + assert.True(t, errors.Is(err, expectedError)) + assert.Contains(t, err.Error(), "failed to create tokenization key") }) t.Run("Error_TokenizationKeyRepositoryCreateFails", func(t *testing.T) { @@ -305,14 +303,11 @@ func TestTokenizationKeyUseCase_Create(t *testing.T) { mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) // Create test data - masterKey := &cryptoDomain.MasterKey{ - ID: "test-master-key", - Key: make([]byte, 32), - } - kekChain := createKekChain(masterKey) + masterKey := tokenizationTesting.CreateMasterKey() + kekChain := tokenizationTesting.CreateKekChain(masterKey) defer kekChain.Close() - activeKek := getActiveKek(kekChain) + activeKek := tokenizationTesting.GetActiveKek(kekChain) dek := cryptoDomain.Dek{ ID: uuid.Must(uuid.NewV7()), KekID: activeKek.ID, @@ -329,18 +324,27 @@ func TestTokenizationKeyUseCase_Create(t *testing.T) { Return(nil, tokenizationDomain.ErrTokenizationKeyNotFound). Once() + mockTxManager.EXPECT(). + WithTx(ctx, mock.AnythingOfType("func(context.Context) error")). + Run(func(ctx context.Context, fn func(context.Context) error) { + // Execute the transaction function + _ = fn(ctx) + }). + Return(expectedError). + Once() + mockKeyManager.EXPECT(). CreateDek(mock.Anything, mock.Anything). Return(dek, nil). Once() mockDekRepo.EXPECT(). - Create(ctx, mock.Anything). + Create(mock.Anything, mock.Anything). Return(nil). Once() mockTokenizationKeyRepo.EXPECT(). - Create(ctx, mock.Anything). + Create(mock.Anything, mock.Anything). Return(expectedError). Once() @@ -357,7 +361,8 @@ func TestTokenizationKeyUseCase_Create(t *testing.T) { // Assert assert.Error(t, err) assert.Nil(t, key) - assert.Equal(t, expectedError, err) + assert.True(t, errors.Is(err, expectedError)) + assert.Contains(t, err.Error(), "failed to create tokenization key") }) } @@ -373,11 +378,8 @@ func TestTokenizationKeyUseCase_Rotate(t *testing.T) { mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) // Create test data - masterKey := &cryptoDomain.MasterKey{ - ID: "test-master-key", - Key: make([]byte, 32), - } - kekChain := createKekChain(masterKey) + masterKey := tokenizationTesting.CreateMasterKey() + kekChain := tokenizationTesting.CreateKekChain(masterKey) defer kekChain.Close() existingKey := &tokenizationDomain.TokenizationKey{ @@ -389,7 +391,7 @@ func TestTokenizationKeyUseCase_Rotate(t *testing.T) { DekID: uuid.Must(uuid.NewV7()), } - activeKek := getActiveKek(kekChain) + activeKek := tokenizationTesting.GetActiveKek(kekChain) dek := cryptoDomain.Dek{ ID: uuid.Must(uuid.NewV7()), KekID: activeKek.ID, @@ -459,14 +461,11 @@ func TestTokenizationKeyUseCase_Rotate(t *testing.T) { mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) // Create test data - masterKey := &cryptoDomain.MasterKey{ - ID: "test-master-key", - Key: make([]byte, 32), - } - kekChain := createKekChain(masterKey) + masterKey := tokenizationTesting.CreateMasterKey() + kekChain := tokenizationTesting.CreateKekChain(masterKey) defer kekChain.Close() - activeKek := getActiveKek(kekChain) + activeKek := tokenizationTesting.GetActiveKek(kekChain) dek := cryptoDomain.Dek{ ID: uuid.Must(uuid.NewV7()), KekID: activeKek.ID, @@ -490,12 +489,7 @@ func TestTokenizationKeyUseCase_Rotate(t *testing.T) { Return(nil, tokenizationDomain.ErrTokenizationKeyNotFound). Once() - // Expectations for Create() call within transaction - mockTokenizationKeyRepo.EXPECT(). - GetByNameAndVersion(mock.Anything, "new-key", uint(1)). - Return(nil, tokenizationDomain.ErrTokenizationKeyNotFound). - Once() - + // Expectations for createTokenizationKey() call within transaction mockKeyManager.EXPECT(). CreateDek(activeKek, cryptoDomain.AESGCM). Return(dek, nil). @@ -545,11 +539,8 @@ func TestTokenizationKeyUseCase_Delete(t *testing.T) { mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) // Create test data - masterKey := &cryptoDomain.MasterKey{ - ID: "test-master-key", - Key: make([]byte, 32), - } - kekChain := createKekChain(masterKey) + masterKey := tokenizationTesting.CreateMasterKey() + kekChain := tokenizationTesting.CreateKekChain(masterKey) defer kekChain.Close() keyID := uuid.Must(uuid.NewV7()) @@ -582,11 +573,8 @@ func TestTokenizationKeyUseCase_Delete(t *testing.T) { mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) // Create test data - masterKey := &cryptoDomain.MasterKey{ - ID: "test-master-key", - Key: make([]byte, 32), - } - kekChain := createKekChain(masterKey) + masterKey := tokenizationTesting.CreateMasterKey() + kekChain := tokenizationTesting.CreateKekChain(masterKey) defer kekChain.Close() keyID := uuid.Must(uuid.NewV7()) @@ -610,7 +598,8 @@ func TestTokenizationKeyUseCase_Delete(t *testing.T) { // Assert assert.Error(t, err) - assert.Equal(t, expectedError, err) + assert.True(t, errors.Is(err, expectedError)) + assert.Contains(t, err.Error(), "failed to delete tokenization key") }) } @@ -697,6 +686,7 @@ func TestTokenizationKeyUseCase_List(t *testing.T) { // Assert assert.Error(t, err) assert.Nil(t, keys) - assert.Equal(t, expectedErr, err) + assert.True(t, errors.Is(err, expectedErr)) + assert.Contains(t, err.Error(), "failed to list tokenization keys") }) } diff --git a/internal/tokenization/usecase/tokenization_usecase.go b/internal/tokenization/usecase/tokenization_usecase.go index 057c753..592d660 100644 --- a/internal/tokenization/usecase/tokenization_usecase.go +++ b/internal/tokenization/usecase/tokenization_usecase.go @@ -18,6 +18,32 @@ import ( tokenizationService "github.com/allisson/secrets/internal/tokenization/service" ) +// validateTokenLength checks if the plaintext length is valid for the token format type. +func validateTokenLength(formatType tokenizationDomain.FormatType, length int) error { + // UUID format ignores length parameter + if formatType == tokenizationDomain.FormatUUID { + return nil + } + + // Luhn format requires at least 2 characters + if formatType == tokenizationDomain.FormatLuhnPreserving && + length < tokenizationDomain.MinLuhnTokenLength { + return tokenizationDomain.ErrTokenLengthInvalid + } + + // All format-preserving tokens have max length constraint + if length > tokenizationDomain.MaxTokenLength { + return tokenizationDomain.ErrTokenLengthInvalid + } + + // Minimum length is 1 for numeric/alphanumeric + if length < 1 { + return tokenizationDomain.ErrTokenLengthInvalid + } + + return nil +} + // tokenizationUseCase implements TokenizationUseCase for managing tokenization operations. type tokenizationUseCase struct { txManager database.TxManager @@ -30,18 +56,13 @@ type tokenizationUseCase struct { kekChain *cryptoDomain.KekChain } -// getKek retrieves a KEK from the chain by its ID. -func (t *tokenizationUseCase) getKek(kekID uuid.UUID) (*cryptoDomain.Kek, error) { - kek, ok := t.kekChain.Get(kekID) - if !ok { - return nil, cryptoDomain.ErrKekNotFound - } - return kek, nil -} - // Tokenize generates a token for the given plaintext value using the latest version of the named key. // In deterministic mode, returns the existing token if the value has been tokenized before. // Metadata is optional display data (e.g., last 4 digits) stored unencrypted. +// +// Rate Limiting: Production systems should implement rate limiting on this method to prevent abuse. +// Recommended: 100 requests per minute per user/API key for standard use cases. +// Adjust based on your specific security requirements and usage patterns. func (t *tokenizationUseCase) Tokenize( ctx context.Context, keyName string, @@ -49,10 +70,18 @@ func (t *tokenizationUseCase) Tokenize( metadata map[string]any, expiresAt *time.Time, ) (*tokenizationDomain.Token, error) { + // Validate plaintext size + if len(plaintext) == 0 { + return nil, tokenizationDomain.ErrPlaintextEmpty + } + if len(plaintext) > tokenizationDomain.MaxPlaintextSize { + return nil, tokenizationDomain.ErrPlaintextTooLarge + } + // Get latest tokenization key version tokenizationKey, err := t.tokenizationRepo.GetByName(ctx, keyName) if err != nil { - return nil, err + return nil, apperrors.Wrap(err, "failed to get tokenization key by name") } // In deterministic mode, check if token already exists for this value @@ -60,7 +89,7 @@ func (t *tokenizationUseCase) Tokenize( valueHash := t.hashService.Hash(plaintext) existingToken, err := t.tokenRepo.GetByValueHash(ctx, tokenizationKey.ID, valueHash) if err != nil && !apperrors.Is(err, tokenizationDomain.ErrTokenNotFound) { - return nil, err + return nil, apperrors.Wrap(err, "failed to check existing token in deterministic mode") } if existingToken != nil { // Return existing valid token @@ -74,26 +103,26 @@ func (t *tokenizationUseCase) Tokenize( // Get DEK by tokenization key's DekID dek, err := t.dekRepo.Get(ctx, tokenizationKey.DekID) if err != nil { - return nil, err + return nil, apperrors.Wrap(err, "failed to get DEK") } // Get KEK for decrypting DEK - kek, err := t.getKek(dek.KekID) + kek, err := getKek(t.kekChain, dek.KekID) if err != nil { - return nil, err + return nil, apperrors.Wrap(err, "failed to get KEK") } // Decrypt DEK with KEK dekKey, err := t.keyManager.DecryptDek(dek, kek) if err != nil { - return nil, err + return nil, apperrors.Wrap(err, "failed to decrypt DEK") } defer cryptoDomain.Zero(dekKey) // Create AEAD cipher with decrypted DEK cipher, err := t.aeadManager.CreateCipher(dekKey, dek.Algorithm) if err != nil { - return nil, err + return nil, apperrors.Wrap(err, "failed to create cipher") } // Encrypt plaintext @@ -110,14 +139,24 @@ func (t *tokenizationUseCase) Tokenize( // For format-preserving tokens, use plaintext length as hint tokenLength := len(plaintext) + + // Validate token length matches format requirements + if err := validateTokenLength(tokenizationKey.FormatType, tokenLength); err != nil { + return nil, err + } + tokenValue, err := generator.Generate(tokenLength) if err != nil { return nil, apperrors.Wrap(err, "failed to generate token") } // Create token record + tokenID, err := uuid.NewV7() + if err != nil { + return nil, apperrors.Wrap(err, "failed to generate UUID for token") + } token := &tokenizationDomain.Token{ - ID: uuid.Must(uuid.NewV7()), + ID: tokenID, TokenizationKeyID: tokenizationKey.ID, Token: tokenValue, ValueHash: nil, @@ -137,7 +176,21 @@ func (t *tokenizationUseCase) Tokenize( // Persist token if err := t.tokenRepo.Create(ctx, token); err != nil { - return nil, err + // In deterministic mode, handle race condition where another goroutine + // created the same token between our check and insert + if tokenizationKey.IsDeterministic && apperrors.Is(err, apperrors.ErrConflict) { + // Race detected: another concurrent request inserted this token + // Query again to get the token that was inserted + valueHash := t.hashService.Hash(plaintext) + existingToken, queryErr := t.tokenRepo.GetByValueHash(ctx, tokenizationKey.ID, valueHash) + if queryErr != nil { + // If query fails, return original create error + return nil, apperrors.Wrap(err, "failed to create token") + } + // Return the token created by the concurrent request + return existingToken, nil + } + return nil, apperrors.Wrap(err, "failed to create token") } return token, nil @@ -153,7 +206,7 @@ func (t *tokenizationUseCase) Detokenize( // Get token record tokenRecord, err := t.tokenRepo.GetByToken(ctx, token) if err != nil { - return nil, nil, err + return nil, nil, apperrors.Wrap(err, "failed to get token") } // Validate token is not expired @@ -169,38 +222,41 @@ func (t *tokenizationUseCase) Detokenize( // Get tokenization key to retrieve its DekID tokenizationKey, err := t.tokenizationRepo.Get(ctx, tokenRecord.TokenizationKeyID) if err != nil { - return nil, nil, err + return nil, nil, apperrors.Wrap(err, "failed to get tokenization key") } // Get DEK dek, err := t.dekRepo.Get(ctx, tokenizationKey.DekID) if err != nil { - return nil, nil, err + return nil, nil, apperrors.Wrap(err, "failed to get DEK") } // Get KEK for decrypting DEK - kek, err := t.getKek(dek.KekID) + kek, err := getKek(t.kekChain, dek.KekID) if err != nil { - return nil, nil, err + return nil, nil, apperrors.Wrap(err, "failed to get KEK") } // Decrypt DEK with KEK dekKey, err := t.keyManager.DecryptDek(dek, kek) if err != nil { - return nil, nil, err + return nil, nil, apperrors.Wrap(err, "failed to decrypt DEK") } defer cryptoDomain.Zero(dekKey) // Create AEAD cipher with decrypted DEK cipher, err := t.aeadManager.CreateCipher(dekKey, dek.Algorithm) if err != nil { - return nil, nil, err + return nil, nil, apperrors.Wrap(err, "failed to create cipher") } // Decrypt ciphertext with nonce plaintext, err = cipher.Decrypt(tokenRecord.Ciphertext, tokenRecord.Nonce, nil) if err != nil { - return nil, nil, cryptoDomain.ErrDecryptionFailed + return nil, nil, apperrors.Wrap( + cryptoDomain.ErrDecryptionFailed, + "failed to decrypt token ciphertext", + ) } return plaintext, tokenRecord.Metadata, nil @@ -214,7 +270,7 @@ func (t *tokenizationUseCase) Validate(ctx context.Context, token string) (bool, if apperrors.Is(err, tokenizationDomain.ErrTokenNotFound) { return false, nil } - return false, err + return false, apperrors.Wrap(err, "failed to validate token") } // Check if token is valid @@ -226,11 +282,15 @@ func (t *tokenizationUseCase) Revoke(ctx context.Context, token string) error { // Verify token exists first _, err := t.tokenRepo.GetByToken(ctx, token) if err != nil { - return err + return apperrors.Wrap(err, "failed to get token for revocation") } // Revoke the token - return t.tokenRepo.Revoke(ctx, token) + err = t.tokenRepo.Revoke(ctx, token) + if err != nil { + return apperrors.Wrap(err, "failed to revoke token") + } + return nil } // CleanupExpired deletes tokens that expired more than the specified number of days ago. diff --git a/internal/tokenization/usecase/tokenization_usecase_test.go b/internal/tokenization/usecase/tokenization_usecase_test.go index e4b8a13..db176c2 100644 --- a/internal/tokenization/usecase/tokenization_usecase_test.go +++ b/internal/tokenization/usecase/tokenization_usecase_test.go @@ -14,6 +14,7 @@ import ( cryptoServiceMocks "github.com/allisson/secrets/internal/crypto/service/mocks" databaseMocks "github.com/allisson/secrets/internal/database/mocks" tokenizationDomain "github.com/allisson/secrets/internal/tokenization/domain" + tokenizationTesting "github.com/allisson/secrets/internal/tokenization/testing" tokenizationMocks "github.com/allisson/secrets/internal/tokenization/usecase/mocks" ) @@ -32,14 +33,11 @@ func TestTokenizationUseCase_Tokenize(t *testing.T) { mockHashService := tokenizationMocks.NewMockHashService(t) // Create test data - masterKey := &cryptoDomain.MasterKey{ - ID: "test-master-key", - Key: make([]byte, 32), - } - kekChain := createKekChain(masterKey) + masterKey := tokenizationTesting.CreateMasterKey() + kekChain := tokenizationTesting.CreateKekChain(masterKey) defer kekChain.Close() - activeKek := getActiveKek(kekChain) + activeKek := tokenizationTesting.GetActiveKek(kekChain) dekID := uuid.Must(uuid.NewV7()) tokenizationKeyID := uuid.Must(uuid.NewV7()) @@ -146,14 +144,11 @@ func TestTokenizationUseCase_Tokenize(t *testing.T) { mockHashService := tokenizationMocks.NewMockHashService(t) // Create test data - masterKey := &cryptoDomain.MasterKey{ - ID: "test-master-key", - Key: make([]byte, 32), - } - kekChain := createKekChain(masterKey) + masterKey := tokenizationTesting.CreateMasterKey() + kekChain := tokenizationTesting.CreateKekChain(masterKey) defer kekChain.Close() - activeKek := getActiveKek(kekChain) + activeKek := tokenizationTesting.GetActiveKek(kekChain) dekID := uuid.Must(uuid.NewV7()) tokenizationKeyID := uuid.Must(uuid.NewV7()) @@ -273,11 +268,8 @@ func TestTokenizationUseCase_Tokenize(t *testing.T) { mockHashService := tokenizationMocks.NewMockHashService(t) // Create test data - masterKey := &cryptoDomain.MasterKey{ - ID: "test-master-key", - Key: make([]byte, 32), - } - kekChain := createKekChain(masterKey) + masterKey := tokenizationTesting.CreateMasterKey() + kekChain := tokenizationTesting.CreateKekChain(masterKey) defer kekChain.Close() tokenizationKeyID := uuid.Must(uuid.NewV7()) @@ -356,14 +348,11 @@ func TestTokenizationUseCase_Tokenize(t *testing.T) { mockHashService := tokenizationMocks.NewMockHashService(t) // Create test data - masterKey := &cryptoDomain.MasterKey{ - ID: "test-master-key", - Key: make([]byte, 32), - } - kekChain := createKekChain(masterKey) + masterKey := tokenizationTesting.CreateMasterKey() + kekChain := tokenizationTesting.CreateKekChain(masterKey) defer kekChain.Close() - activeKek := getActiveKek(kekChain) + activeKek := tokenizationTesting.GetActiveKek(kekChain) dekID := uuid.Must(uuid.NewV7()) tokenizationKeyID := uuid.Must(uuid.NewV7()) plaintext := []byte("test-value") @@ -490,11 +479,8 @@ func TestTokenizationUseCase_Tokenize(t *testing.T) { mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) mockHashService := tokenizationMocks.NewMockHashService(t) - masterKey := &cryptoDomain.MasterKey{ - ID: "test-master-key", - Key: make([]byte, 32), - } - kekChain := createKekChain(masterKey) + masterKey := tokenizationTesting.CreateMasterKey() + kekChain := tokenizationTesting.CreateKekChain(masterKey) defer kekChain.Close() // Setup expectations @@ -520,7 +506,9 @@ func TestTokenizationUseCase_Tokenize(t *testing.T) { // Assert assert.Nil(t, token) - assert.Equal(t, tokenizationDomain.ErrTokenizationKeyNotFound, err) + assert.Error(t, err) + assert.True(t, errors.Is(err, tokenizationDomain.ErrTokenizationKeyNotFound)) + assert.Contains(t, err.Error(), "failed to get tokenization key by name") }) t.Run("Error_DekNotFound", func(t *testing.T) { @@ -533,11 +521,8 @@ func TestTokenizationUseCase_Tokenize(t *testing.T) { mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) mockHashService := tokenizationMocks.NewMockHashService(t) - masterKey := &cryptoDomain.MasterKey{ - ID: "test-master-key", - Key: make([]byte, 32), - } - kekChain := createKekChain(masterKey) + masterKey := tokenizationTesting.CreateMasterKey() + kekChain := tokenizationTesting.CreateKekChain(masterKey) defer kekChain.Close() dekID := uuid.Must(uuid.NewV7()) @@ -578,7 +563,9 @@ func TestTokenizationUseCase_Tokenize(t *testing.T) { // Assert assert.Nil(t, token) - assert.Equal(t, cryptoDomain.ErrDekNotFound, err) + assert.Error(t, err) + assert.True(t, errors.Is(err, cryptoDomain.ErrDekNotFound)) + assert.Contains(t, err.Error(), "failed to get DEK") }) t.Run("Error_KekNotFound", func(t *testing.T) { @@ -591,11 +578,8 @@ func TestTokenizationUseCase_Tokenize(t *testing.T) { mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) mockHashService := tokenizationMocks.NewMockHashService(t) - masterKey := &cryptoDomain.MasterKey{ - ID: "test-master-key", - Key: make([]byte, 32), - } - kekChain := createKekChain(masterKey) + masterKey := tokenizationTesting.CreateMasterKey() + kekChain := tokenizationTesting.CreateKekChain(masterKey) defer kekChain.Close() dekID := uuid.Must(uuid.NewV7()) @@ -646,7 +630,9 @@ func TestTokenizationUseCase_Tokenize(t *testing.T) { // Assert assert.Nil(t, token) - assert.Equal(t, cryptoDomain.ErrKekNotFound, err) + assert.Error(t, err) + assert.True(t, errors.Is(err, cryptoDomain.ErrKekNotFound)) + assert.Contains(t, err.Error(), "failed to get KEK") }) t.Run("Error_EncryptionFails", func(t *testing.T) { @@ -659,14 +645,11 @@ func TestTokenizationUseCase_Tokenize(t *testing.T) { mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) mockHashService := tokenizationMocks.NewMockHashService(t) - masterKey := &cryptoDomain.MasterKey{ - ID: "test-master-key", - Key: make([]byte, 32), - } - kekChain := createKekChain(masterKey) + masterKey := tokenizationTesting.CreateMasterKey() + kekChain := tokenizationTesting.CreateKekChain(masterKey) defer kekChain.Close() - activeKek := getActiveKek(kekChain) + activeKek := tokenizationTesting.GetActiveKek(kekChain) dekID := uuid.Must(uuid.NewV7()) tokenizationKey := &tokenizationDomain.TokenizationKey{ @@ -755,14 +738,11 @@ func TestTokenizationUseCase_Detokenize(t *testing.T) { mockHashService := tokenizationMocks.NewMockHashService(t) // Create test data - masterKey := &cryptoDomain.MasterKey{ - ID: "test-master-key", - Key: make([]byte, 32), - } - kekChain := createKekChain(masterKey) + masterKey := tokenizationTesting.CreateMasterKey() + kekChain := tokenizationTesting.CreateKekChain(masterKey) defer kekChain.Close() - activeKek := getActiveKek(kekChain) + activeKek := tokenizationTesting.GetActiveKek(kekChain) dekID := uuid.Must(uuid.NewV7()) tokenizationKeyID := uuid.Must(uuid.NewV7()) tokenValue := "test-token-123" @@ -865,11 +845,8 @@ func TestTokenizationUseCase_Detokenize(t *testing.T) { mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) mockHashService := tokenizationMocks.NewMockHashService(t) - masterKey := &cryptoDomain.MasterKey{ - ID: "test-master-key", - Key: make([]byte, 32), - } - kekChain := createKekChain(masterKey) + masterKey := tokenizationTesting.CreateMasterKey() + kekChain := tokenizationTesting.CreateKekChain(masterKey) defer kekChain.Close() // Setup expectations @@ -896,7 +873,8 @@ func TestTokenizationUseCase_Detokenize(t *testing.T) { // Assert assert.Nil(t, plaintext) assert.Nil(t, metadata) - assert.Equal(t, tokenizationDomain.ErrTokenNotFound, err) + assert.True(t, errors.Is(err, tokenizationDomain.ErrTokenNotFound)) + assert.Contains(t, err.Error(), "failed to get token") }) t.Run("Error_TokenExpired", func(t *testing.T) { @@ -909,11 +887,8 @@ func TestTokenizationUseCase_Detokenize(t *testing.T) { mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) mockHashService := tokenizationMocks.NewMockHashService(t) - masterKey := &cryptoDomain.MasterKey{ - ID: "test-master-key", - Key: make([]byte, 32), - } - kekChain := createKekChain(masterKey) + masterKey := tokenizationTesting.CreateMasterKey() + kekChain := tokenizationTesting.CreateKekChain(masterKey) defer kekChain.Close() expiredTime := time.Now().UTC().Add(-1 * time.Hour) @@ -967,11 +942,8 @@ func TestTokenizationUseCase_Detokenize(t *testing.T) { mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) mockHashService := tokenizationMocks.NewMockHashService(t) - masterKey := &cryptoDomain.MasterKey{ - ID: "test-master-key", - Key: make([]byte, 32), - } - kekChain := createKekChain(masterKey) + masterKey := tokenizationTesting.CreateMasterKey() + kekChain := tokenizationTesting.CreateKekChain(masterKey) defer kekChain.Close() revokedTime := time.Now().UTC().Add(-30 * time.Minute) @@ -1025,14 +997,11 @@ func TestTokenizationUseCase_Detokenize(t *testing.T) { mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) mockHashService := tokenizationMocks.NewMockHashService(t) - masterKey := &cryptoDomain.MasterKey{ - ID: "test-master-key", - Key: make([]byte, 32), - } - kekChain := createKekChain(masterKey) + masterKey := tokenizationTesting.CreateMasterKey() + kekChain := tokenizationTesting.CreateKekChain(masterKey) defer kekChain.Close() - activeKek := getActiveKek(kekChain) + activeKek := tokenizationTesting.GetActiveKek(kekChain) dekID := uuid.Must(uuid.NewV7()) tokenizationKeyID := uuid.Must(uuid.NewV7()) tokenValue := "test-token" @@ -1118,7 +1087,8 @@ func TestTokenizationUseCase_Detokenize(t *testing.T) { // Assert assert.Nil(t, plaintext) assert.Nil(t, metadata) - assert.Equal(t, cryptoDomain.ErrDecryptionFailed, err) + assert.True(t, errors.Is(err, cryptoDomain.ErrDecryptionFailed)) + assert.Contains(t, err.Error(), "failed to decrypt token ciphertext") }) } @@ -1136,11 +1106,8 @@ func TestTokenizationUseCase_Validate(t *testing.T) { mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) mockHashService := tokenizationMocks.NewMockHashService(t) - masterKey := &cryptoDomain.MasterKey{ - ID: "test-master-key", - Key: make([]byte, 32), - } - kekChain := createKekChain(masterKey) + masterKey := tokenizationTesting.CreateMasterKey() + kekChain := tokenizationTesting.CreateKekChain(masterKey) defer kekChain.Close() tokenValue := "valid-token" @@ -1191,11 +1158,8 @@ func TestTokenizationUseCase_Validate(t *testing.T) { mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) mockHashService := tokenizationMocks.NewMockHashService(t) - masterKey := &cryptoDomain.MasterKey{ - ID: "test-master-key", - Key: make([]byte, 32), - } - kekChain := createKekChain(masterKey) + masterKey := tokenizationTesting.CreateMasterKey() + kekChain := tokenizationTesting.CreateKekChain(masterKey) defer kekChain.Close() expiredTime := time.Now().UTC().Add(-1 * time.Hour) @@ -1247,11 +1211,8 @@ func TestTokenizationUseCase_Validate(t *testing.T) { mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) mockHashService := tokenizationMocks.NewMockHashService(t) - masterKey := &cryptoDomain.MasterKey{ - ID: "test-master-key", - Key: make([]byte, 32), - } - kekChain := createKekChain(masterKey) + masterKey := tokenizationTesting.CreateMasterKey() + kekChain := tokenizationTesting.CreateKekChain(masterKey) defer kekChain.Close() // Setup expectations @@ -1290,11 +1251,8 @@ func TestTokenizationUseCase_Validate(t *testing.T) { mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) mockHashService := tokenizationMocks.NewMockHashService(t) - masterKey := &cryptoDomain.MasterKey{ - ID: "test-master-key", - Key: make([]byte, 32), - } - kekChain := createKekChain(masterKey) + masterKey := tokenizationTesting.CreateMasterKey() + kekChain := tokenizationTesting.CreateKekChain(masterKey) defer kekChain.Close() dbError := errors.New("database error") @@ -1322,7 +1280,9 @@ func TestTokenizationUseCase_Validate(t *testing.T) { // Assert assert.False(t, isValid) - assert.Equal(t, dbError, err) + assert.Error(t, err) + assert.True(t, errors.Is(err, dbError)) + assert.Contains(t, err.Error(), "failed to validate token") }) } @@ -1340,11 +1300,8 @@ func TestTokenizationUseCase_Revoke(t *testing.T) { mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) mockHashService := tokenizationMocks.NewMockHashService(t) - masterKey := &cryptoDomain.MasterKey{ - ID: "test-master-key", - Key: make([]byte, 32), - } - kekChain := createKekChain(masterKey) + masterKey := tokenizationTesting.CreateMasterKey() + kekChain := tokenizationTesting.CreateKekChain(masterKey) defer kekChain.Close() tokenValue := "token-to-revoke" @@ -1399,11 +1356,8 @@ func TestTokenizationUseCase_Revoke(t *testing.T) { mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) mockHashService := tokenizationMocks.NewMockHashService(t) - masterKey := &cryptoDomain.MasterKey{ - ID: "test-master-key", - Key: make([]byte, 32), - } - kekChain := createKekChain(masterKey) + masterKey := tokenizationTesting.CreateMasterKey() + kekChain := tokenizationTesting.CreateKekChain(masterKey) defer kekChain.Close() // Setup expectations @@ -1428,7 +1382,9 @@ func TestTokenizationUseCase_Revoke(t *testing.T) { err := uc.Revoke(ctx, "nonexistent-token") // Assert - assert.Equal(t, tokenizationDomain.ErrTokenNotFound, err) + assert.Error(t, err) + assert.True(t, errors.Is(err, tokenizationDomain.ErrTokenNotFound)) + assert.Contains(t, err.Error(), "failed to get token for revocation") }) t.Run("Error_RevokeFails", func(t *testing.T) { @@ -1441,11 +1397,8 @@ func TestTokenizationUseCase_Revoke(t *testing.T) { mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) mockHashService := tokenizationMocks.NewMockHashService(t) - masterKey := &cryptoDomain.MasterKey{ - ID: "test-master-key", - Key: make([]byte, 32), - } - kekChain := createKekChain(masterKey) + masterKey := tokenizationTesting.CreateMasterKey() + kekChain := tokenizationTesting.CreateKekChain(masterKey) defer kekChain.Close() tokenValue := "test-token" @@ -1489,7 +1442,9 @@ func TestTokenizationUseCase_Revoke(t *testing.T) { err := uc.Revoke(ctx, tokenValue) // Assert - assert.Equal(t, dbError, err) + assert.Error(t, err) + assert.True(t, errors.Is(err, dbError)) + assert.Contains(t, err.Error(), "failed to revoke token") }) } @@ -1507,11 +1462,8 @@ func TestTokenizationUseCase_CleanupExpired(t *testing.T) { mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) mockHashService := tokenizationMocks.NewMockHashService(t) - masterKey := &cryptoDomain.MasterKey{ - ID: "test-master-key", - Key: make([]byte, 32), - } - kekChain := createKekChain(masterKey) + masterKey := tokenizationTesting.CreateMasterKey() + kekChain := tokenizationTesting.CreateKekChain(masterKey) defer kekChain.Close() // Setup expectations @@ -1556,11 +1508,8 @@ func TestTokenizationUseCase_CleanupExpired(t *testing.T) { mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) mockHashService := tokenizationMocks.NewMockHashService(t) - masterKey := &cryptoDomain.MasterKey{ - ID: "test-master-key", - Key: make([]byte, 32), - } - kekChain := createKekChain(masterKey) + masterKey := tokenizationTesting.CreateMasterKey() + kekChain := tokenizationTesting.CreateKekChain(masterKey) defer kekChain.Close() // Setup expectations @@ -1605,11 +1554,8 @@ func TestTokenizationUseCase_CleanupExpired(t *testing.T) { mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) mockHashService := tokenizationMocks.NewMockHashService(t) - masterKey := &cryptoDomain.MasterKey{ - ID: "test-master-key", - Key: make([]byte, 32), - } - kekChain := createKekChain(masterKey) + masterKey := tokenizationTesting.CreateMasterKey() + kekChain := tokenizationTesting.CreateKekChain(masterKey) defer kekChain.Close() // Create use case @@ -1643,11 +1589,8 @@ func TestTokenizationUseCase_CleanupExpired(t *testing.T) { mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) mockHashService := tokenizationMocks.NewMockHashService(t) - masterKey := &cryptoDomain.MasterKey{ - ID: "test-master-key", - Key: make([]byte, 32), - } - kekChain := createKekChain(masterKey) + masterKey := tokenizationTesting.CreateMasterKey() + kekChain := tokenizationTesting.CreateKekChain(masterKey) defer kekChain.Close() dbError := errors.New("database error") diff --git a/internal/transit/domain/const.go b/internal/transit/domain/const.go new file mode 100644 index 0000000..2d6f380 --- /dev/null +++ b/internal/transit/domain/const.go @@ -0,0 +1,9 @@ +// Package domain defines core transit encryption domain models. +package domain + +const ( + // MaxTransitKeyNameLength is the maximum allowed length for transit key names. + // This limit aligns with database schema constraints (VARCHAR(255)) and prevents + // excessively long identifiers that could impact performance or cause display issues. + MaxTransitKeyNameLength = 255 +) diff --git a/internal/transit/domain/encrypted_blob.go b/internal/transit/domain/encrypted_blob.go index fba5da8..78a528f 100644 --- a/internal/transit/domain/encrypted_blob.go +++ b/internal/transit/domain/encrypted_blob.go @@ -2,6 +2,7 @@ package domain import ( "encoding/base64" + "errors" "fmt" "strconv" "strings" @@ -10,9 +11,9 @@ import ( // EncryptedBlob represents an encrypted data blob with version and ciphertext. // Format: "version:ciphertext-base64" type EncryptedBlob struct { - Version uint // Transit key version used for encryption - Ciphertext []byte // Encrypted data - Plaintext []byte // In memory only + Version uint // Transit key version used for this encryption/decryption operation + Ciphertext []byte // Encrypted data with nonce prepended (empty after decryption) + Plaintext []byte // Decrypted data (only populated after decryption, should be zeroed after use) } // NewEncryptedBlob creates an EncryptedBlob from string format "version:ciphertext-base64". @@ -51,3 +52,18 @@ func (eb EncryptedBlob) String() string { encodedCiphertext := base64.StdEncoding.EncodeToString(eb.Ciphertext) return fmt.Sprintf("%d:%s", eb.Version, encodedCiphertext) } + +// Validate checks if the encrypted blob contains valid data. +// Returns an error if any field violates domain constraints. +func (eb *EncryptedBlob) Validate() error { + if eb.Version == 0 { + return errors.New("encrypted blob version must be greater than 0") + } + + // Must have either ciphertext (for encryption result) or plaintext (for decryption result) + if len(eb.Ciphertext) == 0 && len(eb.Plaintext) == 0 { + return errors.New("encrypted blob must contain either ciphertext or plaintext") + } + + return nil +} diff --git a/internal/transit/domain/encrypted_blob_test.go b/internal/transit/domain/encrypted_blob_test.go index 0371f44..c11a7fb 100644 --- a/internal/transit/domain/encrypted_blob_test.go +++ b/internal/transit/domain/encrypted_blob_test.go @@ -317,3 +317,97 @@ func TestEncryptedBlob_String(t *testing.T) { assert.Equal(t, complexData, parsed.Ciphertext) }) } + +func TestEncryptedBlob_Validate(t *testing.T) { + t.Run("Success_ValidBlobWithCiphertext", func(t *testing.T) { + // Arrange + blob := &domain.EncryptedBlob{ + Version: 1, + Ciphertext: []byte("encrypted data"), + Plaintext: nil, + } + + // Act + err := blob.Validate() + + // Assert + assert.NoError(t, err) + }) + + t.Run("Success_ValidBlobWithPlaintext", func(t *testing.T) { + // Arrange + blob := &domain.EncryptedBlob{ + Version: 2, + Ciphertext: nil, + Plaintext: []byte("decrypted data"), + } + + // Act + err := blob.Validate() + + // Assert + assert.NoError(t, err) + }) + + t.Run("Success_ValidBlobWithBoth", func(t *testing.T) { + // Arrange + blob := &domain.EncryptedBlob{ + Version: 3, + Ciphertext: []byte("encrypted"), + Plaintext: []byte("plaintext"), + } + + // Act + err := blob.Validate() + + // Assert + assert.NoError(t, err) + }) + + t.Run("Error_ZeroVersion", func(t *testing.T) { + // Arrange + blob := &domain.EncryptedBlob{ + Version: 0, + Ciphertext: []byte("data"), + } + + // Act + err := blob.Validate() + + // Assert + assert.Error(t, err) + assert.Contains(t, err.Error(), "version must be greater than 0") + }) + + t.Run("Error_EmptyBlob", func(t *testing.T) { + // Arrange + blob := &domain.EncryptedBlob{ + Version: 1, + Ciphertext: []byte{}, + Plaintext: []byte{}, + } + + // Act + err := blob.Validate() + + // Assert + assert.Error(t, err) + assert.Contains(t, err.Error(), "must contain either ciphertext or plaintext") + }) + + t.Run("Error_BothNil", func(t *testing.T) { + // Arrange + blob := &domain.EncryptedBlob{ + Version: 1, + Ciphertext: nil, + Plaintext: nil, + } + + // Act + err := blob.Validate() + + // Assert + assert.Error(t, err) + assert.Contains(t, err.Error(), "must contain either ciphertext or plaintext") + }) +} diff --git a/internal/transit/domain/transit_key.go b/internal/transit/domain/transit_key.go index c8126a1..15dd884 100644 --- a/internal/transit/domain/transit_key.go +++ b/internal/transit/domain/transit_key.go @@ -1,6 +1,8 @@ package domain import ( + "errors" + "fmt" "time" "github.com/google/uuid" @@ -11,10 +13,36 @@ import ( // version (highest number) is used for encryption while older versions remain available // for decryption. Soft deletion via DeletedAt field preserves keys for historical decryption. type TransitKey struct { - ID uuid.UUID - Name string - Version uint - DekID uuid.UUID - CreatedAt time.Time - DeletedAt *time.Time + ID uuid.UUID // Unique identifier for this specific transit key version + Name string // Human-readable name (shared across all versions of this key) + Version uint // Key version number (increments with rotation, starts at 1) + DekID uuid.UUID // Reference to the Data Encryption Key used to encrypt this transit key + CreatedAt time.Time // Timestamp when this key version was created (UTC) + DeletedAt *time.Time // Soft deletion timestamp (nil if active, set when deleted) +} + +// Validate checks if the transit key contains valid data. +// Returns an error if any field violates domain constraints. +func (tk *TransitKey) Validate() error { + if tk.Name == "" { + return errors.New("transit key name cannot be empty") + } + + if len(tk.Name) > MaxTransitKeyNameLength { + return fmt.Errorf("transit key name exceeds maximum length of %d characters", MaxTransitKeyNameLength) + } + + if tk.Version == 0 { + return errors.New("transit key version must be greater than 0") + } + + if tk.DekID == uuid.Nil { + return errors.New("transit key must have a valid DEK ID") + } + + if tk.CreatedAt.IsZero() { + return errors.New("transit key must have a valid created_at timestamp") + } + + return nil } diff --git a/internal/transit/domain/transit_key_test.go b/internal/transit/domain/transit_key_test.go new file mode 100644 index 0000000..88c297e --- /dev/null +++ b/internal/transit/domain/transit_key_test.go @@ -0,0 +1,88 @@ +package domain + +import ( + "strings" + "testing" + "time" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" +) + +func TestTransitKey_Validate(t *testing.T) { + validTransitKey := &TransitKey{ + ID: uuid.Must(uuid.NewV7()), + Name: "test-key", + Version: 1, + DekID: uuid.Must(uuid.NewV7()), + CreatedAt: time.Now().UTC(), + DeletedAt: nil, + } + + t.Run("Success_ValidTransitKey", func(t *testing.T) { + err := validTransitKey.Validate() + assert.NoError(t, err) + }) + + t.Run("Error_EmptyName", func(t *testing.T) { + key := *validTransitKey + key.Name = "" + + err := key.Validate() + assert.Error(t, err) + assert.Contains(t, err.Error(), "name cannot be empty") + }) + + t.Run("Error_NameTooLong", func(t *testing.T) { + key := *validTransitKey + key.Name = strings.Repeat("a", MaxTransitKeyNameLength+1) + + err := key.Validate() + assert.Error(t, err) + assert.Contains(t, err.Error(), "exceeds maximum length") + }) + + t.Run("Success_NameAtMaxLength", func(t *testing.T) { + key := *validTransitKey + key.Name = strings.Repeat("a", MaxTransitKeyNameLength) + + err := key.Validate() + assert.NoError(t, err) + }) + + t.Run("Error_ZeroVersion", func(t *testing.T) { + key := *validTransitKey + key.Version = 0 + + err := key.Validate() + assert.Error(t, err) + assert.Contains(t, err.Error(), "version must be greater than 0") + }) + + t.Run("Error_NilDekID", func(t *testing.T) { + key := *validTransitKey + key.DekID = uuid.Nil + + err := key.Validate() + assert.Error(t, err) + assert.Contains(t, err.Error(), "valid DEK ID") + }) + + t.Run("Error_ZeroCreatedAt", func(t *testing.T) { + key := *validTransitKey + key.CreatedAt = time.Time{} + + err := key.Validate() + assert.Error(t, err) + assert.Contains(t, err.Error(), "valid created_at timestamp") + }) + + t.Run("Success_WithDeletedAt", func(t *testing.T) { + key := *validTransitKey + now := time.Now().UTC() + key.DeletedAt = &now + + err := key.Validate() + assert.NoError(t, err) + }) +} diff --git a/internal/transit/http/crypto_handler.go b/internal/transit/http/crypto_handler.go index 230155d..d7fa26b 100644 --- a/internal/transit/http/crypto_handler.go +++ b/internal/transit/http/crypto_handler.go @@ -9,7 +9,6 @@ import ( "github.com/gin-gonic/gin" - authUseCase "github.com/allisson/secrets/internal/auth/usecase" cryptoDomain "github.com/allisson/secrets/internal/crypto/domain" "github.com/allisson/secrets/internal/httputil" "github.com/allisson/secrets/internal/transit/http/dto" @@ -20,20 +19,17 @@ import ( // CryptoHandler handles HTTP requests for transit encryption and decryption operations. // It coordinates authentication, authorization, and audit logging with the TransitKeyUseCase. type CryptoHandler struct { - transitKeyUseCase transitUseCase.TransitKeyUseCase - auditLogUseCase authUseCase.AuditLogUseCase - logger *slog.Logger + transitKeyUseCase transitUseCase.TransitKeyUseCase // Business logic for encryption and decryption operations + logger *slog.Logger // Structured logger for request handling and error reporting } // NewCryptoHandler creates a new crypto handler with required dependencies. func NewCryptoHandler( transitKeyUseCase transitUseCase.TransitKeyUseCase, - auditLogUseCase authUseCase.AuditLogUseCase, logger *slog.Logger, ) *CryptoHandler { return &CryptoHandler{ transitKeyUseCase: transitKeyUseCase, - auditLogUseCase: auditLogUseCase, logger: logger, } } diff --git a/internal/transit/http/crypto_handler_test.go b/internal/transit/http/crypto_handler_test.go index 792ea43..2c433da 100644 --- a/internal/transit/http/crypto_handler_test.go +++ b/internal/transit/http/crypto_handler_test.go @@ -28,7 +28,7 @@ func setupTestCryptoHandler(t *testing.T) (*CryptoHandler, *mocks.MockTransitKey mockTransitKeyUseCase := mocks.NewMockTransitKeyUseCase(t) logger := slog.New(slog.NewTextHandler(io.Discard, nil)) - handler := NewCryptoHandler(mockTransitKeyUseCase, nil, logger) + handler := NewCryptoHandler(mockTransitKeyUseCase, logger) return handler, mockTransitKeyUseCase } diff --git a/internal/transit/http/dto/request.go b/internal/transit/http/dto/request.go index daa3683..54e35fe 100644 --- a/internal/transit/http/dto/request.go +++ b/internal/transit/http/dto/request.go @@ -7,6 +7,7 @@ import ( validation "github.com/jellydator/validation" cryptoDomain "github.com/allisson/secrets/internal/crypto/domain" + transitDomain "github.com/allisson/secrets/internal/transit/domain" customValidation "github.com/allisson/secrets/internal/validation" ) @@ -22,7 +23,7 @@ func (r *CreateTransitKeyRequest) Validate() error { validation.Field(&r.Name, validation.Required, customValidation.NotBlank, - validation.Length(1, 255), + validation.Length(1, transitDomain.MaxTransitKeyNameLength), ), validation.Field(&r.Algorithm, validation.Required, diff --git a/internal/transit/http/transit_key_handler.go b/internal/transit/http/transit_key_handler.go index 932e62a..00dc3c9 100644 --- a/internal/transit/http/transit_key_handler.go +++ b/internal/transit/http/transit_key_handler.go @@ -9,7 +9,6 @@ import ( "github.com/gin-gonic/gin" "github.com/google/uuid" - authUseCase "github.com/allisson/secrets/internal/auth/usecase" "github.com/allisson/secrets/internal/httputil" "github.com/allisson/secrets/internal/transit/http/dto" transitUseCase "github.com/allisson/secrets/internal/transit/usecase" @@ -19,20 +18,17 @@ import ( // TransitKeyHandler handles HTTP requests for transit key management operations. // It coordinates authentication, authorization, and audit logging with the TransitKeyUseCase. type TransitKeyHandler struct { - transitKeyUseCase transitUseCase.TransitKeyUseCase - auditLogUseCase authUseCase.AuditLogUseCase - logger *slog.Logger + transitKeyUseCase transitUseCase.TransitKeyUseCase // Business logic for transit key lifecycle operations + logger *slog.Logger // Structured logger for request handling and error reporting } // NewTransitKeyHandler creates a new transit key handler with required dependencies. func NewTransitKeyHandler( transitKeyUseCase transitUseCase.TransitKeyUseCase, - auditLogUseCase authUseCase.AuditLogUseCase, logger *slog.Logger, ) *TransitKeyHandler { return &TransitKeyHandler{ transitKeyUseCase: transitKeyUseCase, - auditLogUseCase: auditLogUseCase, logger: logger, } } diff --git a/internal/transit/http/transit_key_handler_test.go b/internal/transit/http/transit_key_handler_test.go index 7333f9e..88830ef 100644 --- a/internal/transit/http/transit_key_handler_test.go +++ b/internal/transit/http/transit_key_handler_test.go @@ -31,7 +31,7 @@ func setupTestTransitKeyHandler(t *testing.T) (*TransitKeyHandler, *mocks.MockTr mockTransitKeyUseCase := mocks.NewMockTransitKeyUseCase(t) logger := slog.New(slog.NewTextHandler(io.Discard, nil)) - handler := NewTransitKeyHandler(mockTransitKeyUseCase, nil, logger) + handler := NewTransitKeyHandler(mockTransitKeyUseCase, logger) return handler, mockTransitKeyUseCase } diff --git a/internal/transit/usecase/metrics_decorator_test.go b/internal/transit/usecase/metrics_decorator_test.go new file mode 100644 index 0000000..4858671 --- /dev/null +++ b/internal/transit/usecase/metrics_decorator_test.go @@ -0,0 +1,369 @@ +package usecase_test + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + + cryptoDomain "github.com/allisson/secrets/internal/crypto/domain" + transitDomain "github.com/allisson/secrets/internal/transit/domain" + "github.com/allisson/secrets/internal/transit/usecase" + usecaseMocks "github.com/allisson/secrets/internal/transit/usecase/mocks" +) + +// mockBusinessMetrics is a local mock for metrics.BusinessMetrics to avoid dependency issues. +type mockBusinessMetrics struct { + mock.Mock +} + +func (m *mockBusinessMetrics) RecordOperation(ctx context.Context, domain, operation, status string) { + m.Called(ctx, domain, operation, status) +} + +func (m *mockBusinessMetrics) RecordDuration( + ctx context.Context, + domain, operation string, + duration time.Duration, + status string, +) { + m.Called(ctx, domain, operation, duration, status) +} + +func TestTransitKeyUseCaseWithMetrics_Create(t *testing.T) { + mockNext := usecaseMocks.NewMockTransitKeyUseCase(t) + mockMetrics := &mockBusinessMetrics{} + uc := usecase.NewTransitKeyUseCaseWithMetrics(mockNext, mockMetrics) + + ctx := context.Background() + name := "test-key" + alg := cryptoDomain.AESGCM + + t.Run("Create_Success", func(t *testing.T) { + // Arrange + expectedKey := &transitDomain.TransitKey{ + ID: uuid.Must(uuid.NewV7()), + Name: name, + Version: 1, + DekID: uuid.Must(uuid.NewV7()), + CreatedAt: time.Now().UTC(), + } + + mockNext.EXPECT().Create(ctx, name, alg).Return(expectedKey, nil).Once() + mockMetrics.On("RecordOperation", ctx, "transit", "transit_key_create", "success").Return().Once() + mockMetrics.On("RecordDuration", ctx, "transit", "transit_key_create", mock.AnythingOfType("time.Duration"), "success"). + Return(). + Once() + + // Act + result, err := uc.Create(ctx, name, alg) + + // Assert + assert.NoError(t, err) + assert.Equal(t, expectedKey, result) + mockNext.AssertExpectations(t) + mockMetrics.AssertExpectations(t) + }) + + t.Run("Create_Error", func(t *testing.T) { + // Arrange + expectedErr := errors.New("create failed") + + mockNext.EXPECT().Create(ctx, name, alg).Return(nil, expectedErr).Once() + mockMetrics.On("RecordOperation", ctx, "transit", "transit_key_create", "error").Return().Once() + mockMetrics.On("RecordDuration", ctx, "transit", "transit_key_create", mock.AnythingOfType("time.Duration"), "error"). + Return(). + Once() + + // Act + result, err := uc.Create(ctx, name, alg) + + // Assert + assert.Error(t, err) + assert.Nil(t, result) + assert.Equal(t, expectedErr, err) + mockNext.AssertExpectations(t) + mockMetrics.AssertExpectations(t) + }) +} + +func TestTransitKeyUseCaseWithMetrics_Rotate(t *testing.T) { + mockNext := usecaseMocks.NewMockTransitKeyUseCase(t) + mockMetrics := &mockBusinessMetrics{} + uc := usecase.NewTransitKeyUseCaseWithMetrics(mockNext, mockMetrics) + + ctx := context.Background() + name := "test-key" + alg := cryptoDomain.ChaCha20 + + t.Run("Rotate_Success", func(t *testing.T) { + // Arrange + expectedKey := &transitDomain.TransitKey{ + ID: uuid.Must(uuid.NewV7()), + Name: name, + Version: 2, + DekID: uuid.Must(uuid.NewV7()), + CreatedAt: time.Now().UTC(), + } + + mockNext.EXPECT().Rotate(ctx, name, alg).Return(expectedKey, nil).Once() + mockMetrics.On("RecordOperation", ctx, "transit", "transit_key_rotate", "success").Return().Once() + mockMetrics.On("RecordDuration", ctx, "transit", "transit_key_rotate", mock.AnythingOfType("time.Duration"), "success"). + Return(). + Once() + + // Act + result, err := uc.Rotate(ctx, name, alg) + + // Assert + assert.NoError(t, err) + assert.Equal(t, expectedKey, result) + mockNext.AssertExpectations(t) + mockMetrics.AssertExpectations(t) + }) + + t.Run("Rotate_Error", func(t *testing.T) { + // Arrange + expectedErr := errors.New("rotation failed") + + mockNext.EXPECT().Rotate(ctx, name, alg).Return(nil, expectedErr).Once() + mockMetrics.On("RecordOperation", ctx, "transit", "transit_key_rotate", "error").Return().Once() + mockMetrics.On("RecordDuration", ctx, "transit", "transit_key_rotate", mock.AnythingOfType("time.Duration"), "error"). + Return(). + Once() + + // Act + result, err := uc.Rotate(ctx, name, alg) + + // Assert + assert.Error(t, err) + assert.Nil(t, result) + assert.Equal(t, expectedErr, err) + mockNext.AssertExpectations(t) + mockMetrics.AssertExpectations(t) + }) +} + +func TestTransitKeyUseCaseWithMetrics_Delete(t *testing.T) { + mockNext := usecaseMocks.NewMockTransitKeyUseCase(t) + mockMetrics := &mockBusinessMetrics{} + uc := usecase.NewTransitKeyUseCaseWithMetrics(mockNext, mockMetrics) + + ctx := context.Background() + transitKeyID := uuid.Must(uuid.NewV7()) + + t.Run("Delete_Success", func(t *testing.T) { + // Arrange + mockNext.EXPECT().Delete(ctx, transitKeyID).Return(nil).Once() + mockMetrics.On("RecordOperation", ctx, "transit", "transit_key_delete", "success").Return().Once() + mockMetrics.On("RecordDuration", ctx, "transit", "transit_key_delete", mock.AnythingOfType("time.Duration"), "success"). + Return(). + Once() + + // Act + err := uc.Delete(ctx, transitKeyID) + + // Assert + assert.NoError(t, err) + mockNext.AssertExpectations(t) + mockMetrics.AssertExpectations(t) + }) + + t.Run("Delete_Error", func(t *testing.T) { + // Arrange + expectedErr := errors.New("deletion failed") + + mockNext.EXPECT().Delete(ctx, transitKeyID).Return(expectedErr).Once() + mockMetrics.On("RecordOperation", ctx, "transit", "transit_key_delete", "error").Return().Once() + mockMetrics.On("RecordDuration", ctx, "transit", "transit_key_delete", mock.AnythingOfType("time.Duration"), "error"). + Return(). + Once() + + // Act + err := uc.Delete(ctx, transitKeyID) + + // Assert + assert.Error(t, err) + assert.Equal(t, expectedErr, err) + mockNext.AssertExpectations(t) + mockMetrics.AssertExpectations(t) + }) +} + +func TestTransitKeyUseCaseWithMetrics_Encrypt(t *testing.T) { + mockNext := usecaseMocks.NewMockTransitKeyUseCase(t) + mockMetrics := &mockBusinessMetrics{} + uc := usecase.NewTransitKeyUseCaseWithMetrics(mockNext, mockMetrics) + + ctx := context.Background() + name := "test-key" + plaintext := []byte("secret data") + + t.Run("Encrypt_Success", func(t *testing.T) { + // Arrange + expectedBlob := &transitDomain.EncryptedBlob{ + Version: 1, + Ciphertext: []byte("encrypted data"), + } + + mockNext.EXPECT().Encrypt(ctx, name, plaintext).Return(expectedBlob, nil).Once() + mockMetrics.On("RecordOperation", ctx, "transit", "transit_encrypt", "success").Return().Once() + mockMetrics.On("RecordDuration", ctx, "transit", "transit_encrypt", mock.AnythingOfType("time.Duration"), "success"). + Return(). + Once() + + // Act + result, err := uc.Encrypt(ctx, name, plaintext) + + // Assert + assert.NoError(t, err) + assert.Equal(t, expectedBlob, result) + mockNext.AssertExpectations(t) + mockMetrics.AssertExpectations(t) + }) + + t.Run("Encrypt_Error", func(t *testing.T) { + // Arrange + expectedErr := errors.New("encryption failed") + + mockNext.EXPECT().Encrypt(ctx, name, plaintext).Return(nil, expectedErr).Once() + mockMetrics.On("RecordOperation", ctx, "transit", "transit_encrypt", "error").Return().Once() + mockMetrics.On("RecordDuration", ctx, "transit", "transit_encrypt", mock.AnythingOfType("time.Duration"), "error"). + Return(). + Once() + + // Act + result, err := uc.Encrypt(ctx, name, plaintext) + + // Assert + assert.Error(t, err) + assert.Nil(t, result) + assert.Equal(t, expectedErr, err) + mockNext.AssertExpectations(t) + mockMetrics.AssertExpectations(t) + }) +} + +func TestTransitKeyUseCaseWithMetrics_Decrypt(t *testing.T) { + mockNext := usecaseMocks.NewMockTransitKeyUseCase(t) + mockMetrics := &mockBusinessMetrics{} + uc := usecase.NewTransitKeyUseCaseWithMetrics(mockNext, mockMetrics) + + ctx := context.Background() + name := "test-key" + ciphertext := "1:ZW5jcnlwdGVkIGRhdGE=" + + t.Run("Decrypt_Success", func(t *testing.T) { + // Arrange + expectedBlob := &transitDomain.EncryptedBlob{ + Version: 1, + Plaintext: []byte("secret data"), + } + + mockNext.EXPECT().Decrypt(ctx, name, ciphertext).Return(expectedBlob, nil).Once() + mockMetrics.On("RecordOperation", ctx, "transit", "transit_decrypt", "success").Return().Once() + mockMetrics.On("RecordDuration", ctx, "transit", "transit_decrypt", mock.AnythingOfType("time.Duration"), "success"). + Return(). + Once() + + // Act + result, err := uc.Decrypt(ctx, name, ciphertext) + + // Assert + assert.NoError(t, err) + assert.Equal(t, expectedBlob, result) + mockNext.AssertExpectations(t) + mockMetrics.AssertExpectations(t) + }) + + t.Run("Decrypt_Error", func(t *testing.T) { + // Arrange + expectedErr := errors.New("decryption failed") + + mockNext.EXPECT().Decrypt(ctx, name, ciphertext).Return(nil, expectedErr).Once() + mockMetrics.On("RecordOperation", ctx, "transit", "transit_decrypt", "error").Return().Once() + mockMetrics.On("RecordDuration", ctx, "transit", "transit_decrypt", mock.AnythingOfType("time.Duration"), "error"). + Return(). + Once() + + // Act + result, err := uc.Decrypt(ctx, name, ciphertext) + + // Assert + assert.Error(t, err) + assert.Nil(t, result) + assert.Equal(t, expectedErr, err) + mockNext.AssertExpectations(t) + mockMetrics.AssertExpectations(t) + }) +} + +func TestTransitKeyUseCaseWithMetrics_List(t *testing.T) { + mockNext := usecaseMocks.NewMockTransitKeyUseCase(t) + mockMetrics := &mockBusinessMetrics{} + uc := usecase.NewTransitKeyUseCaseWithMetrics(mockNext, mockMetrics) + + ctx := context.Background() + offset := 0 + limit := 50 + + t.Run("List_Success", func(t *testing.T) { + // Arrange + expectedKeys := []*transitDomain.TransitKey{ + { + ID: uuid.Must(uuid.NewV7()), + Name: "key-1", + Version: 1, + DekID: uuid.Must(uuid.NewV7()), + CreatedAt: time.Now().UTC(), + }, + { + ID: uuid.Must(uuid.NewV7()), + Name: "key-2", + Version: 1, + DekID: uuid.Must(uuid.NewV7()), + CreatedAt: time.Now().UTC(), + }, + } + + mockNext.EXPECT().List(ctx, offset, limit).Return(expectedKeys, nil).Once() + mockMetrics.On("RecordOperation", ctx, "transit", "transit_key_list", "success").Return().Once() + mockMetrics.On("RecordDuration", ctx, "transit", "transit_key_list", mock.AnythingOfType("time.Duration"), "success"). + Return(). + Once() + + // Act + result, err := uc.List(ctx, offset, limit) + + // Assert + assert.NoError(t, err) + assert.Equal(t, expectedKeys, result) + mockNext.AssertExpectations(t) + mockMetrics.AssertExpectations(t) + }) + + t.Run("List_Error", func(t *testing.T) { + // Arrange + expectedErr := errors.New("list failed") + + mockNext.EXPECT().List(ctx, offset, limit).Return(nil, expectedErr).Once() + mockMetrics.On("RecordOperation", ctx, "transit", "transit_key_list", "error").Return().Once() + mockMetrics.On("RecordDuration", ctx, "transit", "transit_key_list", mock.AnythingOfType("time.Duration"), "error"). + Return(). + Once() + + // Act + result, err := uc.List(ctx, offset, limit) + + // Assert + assert.Error(t, err) + assert.Nil(t, result) + assert.Equal(t, expectedErr, err) + mockNext.AssertExpectations(t) + mockMetrics.AssertExpectations(t) + }) +} diff --git a/internal/transit/usecase/transit_key_usecase_test.go b/internal/transit/usecase/transit_key_usecase_test.go index 006ce0d..392cec3 100644 --- a/internal/transit/usecase/transit_key_usecase_test.go +++ b/internal/transit/usecase/transit_key_usecase_test.go @@ -1196,6 +1196,11 @@ func TestTransitKeyUseCase_Decrypt(t *testing.T) { Return(mockCipher, nil). Once() + mockCipher.EXPECT(). + NonceSize(). + Return(12). + Maybe() + // Execute uc := NewTransitKeyUseCase( mockTxManager, mockTransitRepo, mockDekRepo, mockKeyManager, mockAeadManager, kekChain, diff --git a/test/integration/api_test.go b/test/integration/api_test.go index ca85545..c5875a0 100644 --- a/test/integration/api_test.go +++ b/test/integration/api_test.go @@ -241,10 +241,10 @@ func setupIntegrationTestWithKMS(t *testing.T, dbDriver string) *integrationTest var dsn string if dbDriver == "postgres" { db = testutil.SetupPostgresDB(t) - dsn = testutil.PostgresTestDSN + dsn = testutil.GetPostgresTestDSN() } else { db = testutil.SetupMySQLDB(t) - dsn = testutil.MySQLTestDSN + dsn = testutil.GetMySQLTestDSN() } // Generate KMS key URI and ephemeral master key @@ -353,10 +353,10 @@ func setupIntegrationTest(t *testing.T, dbDriver string) *integrationTestContext var dsn string if dbDriver == "postgres" { db = testutil.SetupPostgresDB(t) - dsn = testutil.PostgresTestDSN + dsn = testutil.GetPostgresTestDSN() } else { db = testutil.SetupMySQLDB(t) - dsn = testutil.MySQLTestDSN + dsn = testutil.GetMySQLTestDSN() } // Generate KMS key URI and ephemeral master key for testing @@ -1774,10 +1774,10 @@ func setupIntegrationTestWithLockout( var dsn string if dbDriver == "postgres" { db = testutil.SetupPostgresDB(t) - dsn = testutil.PostgresTestDSN + dsn = testutil.GetPostgresTestDSN() } else { db = testutil.SetupMySQLDB(t) - dsn = testutil.MySQLTestDSN + dsn = testutil.GetMySQLTestDSN() } // Generate KMS key URI and ephemeral master key for testing diff --git a/test/integration/audit_log_signature_test.go b/test/integration/audit_log_signature_test.go index d0474e4..77f4370 100644 --- a/test/integration/audit_log_signature_test.go +++ b/test/integration/audit_log_signature_test.go @@ -34,12 +34,12 @@ func TestAuditLogSignature_EndToEnd(t *testing.T) { { name: "PostgreSQL", driver: "postgres", - dsn: testutil.PostgresTestDSN, + dsn: testutil.GetPostgresTestDSN(), }, { name: "MySQL", driver: "mysql", - dsn: testutil.MySQLTestDSN, + dsn: testutil.GetMySQLTestDSN(), }, }