diff --git a/README.md b/README.md index e6125b0..7f8fafe 100644 --- a/README.md +++ b/README.md @@ -32,6 +32,7 @@ import ( ) type HTTPServer struct { + Env string `env:"ENV"` Host string `env:"HOST" envDefault:"127.0.0.1"` Port int `env:"PORT" envRequired:"true"` Enabled bool `env:"ENABLED"` @@ -39,6 +40,16 @@ type HTTPServer struct { Headers map[string]string `env:"HEADERS"` // "k1=v1,k2=v2" } +func (cfg HTTPServer) Validate() error { + return envconfig.Assert( + envconfig.OneOf(cfg.Env, "ENVIRONMENT", "production", "prod"), + envconfig.Not( + envconfig.Range(cfg.Port, 0, 1023, "PORT"), + "PORT: must not be a reserved port (0-1023)", + ), + ) +} + func main() { var cfg HTTPServer @@ -49,6 +60,7 @@ func main() { fmt.Printf("%+v\n", cfg) } + ``` Example environment: diff --git a/assert.go b/assert.go new file mode 100644 index 0000000..6aae3e9 --- /dev/null +++ b/assert.go @@ -0,0 +1,422 @@ +package envconfig + +import ( + "fmt" + "net/url" + "os" + "regexp" + "strings" +) + +// AssertOpt represents a single validation check that returns an error if validation fails. +// AssertOpt functions are designed to be composable and can be passed to Assert() to perform +// multiple validations at once. +// +// Example: +// +// func (cfg MyConfig) Validate() error { +// return Assert( +// NotEmpty(cfg.Host, "HOST"), +// Range(cfg.Port, 1, 65535, "PORT"), +// ) +// } +type AssertOpt func() error + +// Assert runs all provided validation checks and collects any errors that occur. +// If any validation fails, it returns an ErrValidation containing all failures. +// If all validations pass, it returns nil. +// +// Assert is designed to be used with validation helper functions like NotEmpty, Range, etc. +// All validations are executed regardless of failures, allowing users to see all +// validation issues at once rather than fixing them one at a time. +// +// Example: +// +// func (cfg Config) Validate() error { +// return Assert( +// NotEmpty(cfg.APIKey, "API_KEY"), +// Range(cfg.Port, 1, 65535, "PORT"), +// OneOf(cfg.Environment, "ENVIRONMENT", "dev", "staging", "production"), +// ) +// } +func Assert(opts ...AssertOpt) error { + var errs []error + for _, opt := range opts { + if err := opt(); err != nil { + errs = append(errs, err) + } + } + if len(errs) == 0 { + return nil + } + return ErrValidation(errs) +} + +// ErrValidation is a collection of validation errors that occurred during Assert(). +// It implements the error interface and formats multiple errors into a single, +// human-readable error message. +type ErrValidation []error + +// Error returns a formatted string containing all validation errors, +// separated by semicolons. +func (e ErrValidation) Error() string { + if len(e) == 0 { + return "" + } + var sb strings.Builder + sb.WriteString("validation failed:") + for i, err := range e { + if i > 0 { + sb.WriteString(";") + } + sb.WriteString(" ") + sb.WriteString(err.Error()) + } + return sb.String() +} + +// NotEmpty validates that a string value is not empty. +// Returns an AssertOpt that fails if the value is an empty string. +// +// Parameters: +// - value: the string to validate +// - field: the name of the field (used in error messages) +// +// Example: +// +// NotEmpty(cfg.APIKey, "API_KEY") +func NotEmpty(value, field string) AssertOpt { + return func() error { + if value == "" { + return fmt.Errorf("%s: must not be empty", field) + } + return nil + } +} + +// Number is a constraint for all numeric types that can be compared +type Number interface { + ~int | ~int8 | ~int16 | ~int32 | ~int64 | + ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 | + ~float32 | ~float64 +} + +// Range validates that an integer value falls within a specified range (inclusive). +// Returns an AssertOpt that fails if the value is less than min or greater than max. +// +// Parameters: +// - value: the integer to validate +// - min: the minimum allowed value (inclusive) +// - max: the maximum allowed value (inclusive) +// - field: the name of the field (used in error messages) +// +// Example: +// +// Range(cfg.Port, 1, 65535, "PORT") +func Range[T Number](value, min, max T, field string) AssertOpt { + return func() error { + if value < min || value > max { + return fmt.Errorf("%s: must be between %v and %v, got %v", field, min, max, value) + } + return nil + } +} + +// Positive validates that an integer value is greater than zero. +// Returns an AssertOpt that fails if the value is less than or equal to zero. +// +// Parameters: +// - value: the integer to validate +// - field: the name of the field (used in error messages) +// +// Example: +// +// Positive(cfg.Workers, "WORKERS") +func Positive[T Number](value T, field string) AssertOpt { + return func() error { + if value <= 0 { + return fmt.Errorf("%s: must be positive, got %v", field, value) + } + return nil + } +} + +// NonNegative validates that an integer value is greater than or equal to zero. +// Returns an AssertOpt that fails if the value is negative. +// +// Parameters: +// - value: the integer to validate +// - field: the name of the field (used in error messages) +// +// Example: +// +// NonNegative(cfg.Retries, "RETRIES") +func NonNegative[T Number](value T, field string) AssertOpt { + return func() error { + if value < 0 { + return fmt.Errorf("%s: must be non-negative, got %v", field, value) + } + return nil + } +} + +// OneOf validates that a string value matches one of the allowed values. +// The comparison is case-sensitive. +// Returns an AssertOpt that fails if the value is not in the allowed list. +// +// Parameters: +// - value: the string to validate +// - field: the name of the field (used in error messages) +// - allowed: the list of allowed values +// +// Example: +// +// OneOf(cfg.LogLevel, "LOG_LEVEL", "debug", "info", "warn", "error") +func OneOf(value string, field string, allowed ...string) AssertOpt { + return func() error { + for _, a := range allowed { + if value == a { + return nil + } + } + return fmt.Errorf("%s: must be one of %v, got %q", field, allowed, value) + } +} + +// Custom validates a custom condition and returns a specified error message if it fails. +// This is a generic validator for arbitrary conditions that don't fit other helpers. +// Returns an AssertOpt that fails if the condition is false. +// +// Parameters: +// - condition: the boolean condition to check +// - field: the name of the field (used in error messages) +// - message: the error message to return if the condition is false +// +// Example: +// +// Custom(cfg.MaxRetries < cfg.Timeout, "TIMEOUT", "must be greater than MAX_RETRIES") +func Custom(condition bool, field, message string) AssertOpt { + return func() error { + if !condition { + return fmt.Errorf("%s: %s", field, message) + } + return nil + } +} + +// MinLength validates that a string has at least the specified minimum length. +// Returns an AssertOpt that fails if the string length is less than min. +// +// Parameters: +// - value: the string to validate +// - min: the minimum required length +// - field: the name of the field (used in error messages) +// +// Example: +// +// MinLength(cfg.Password, 8, "PASSWORD") +func MinLength(value string, min int, field string) AssertOpt { + return func() error { + if len(value) < min { + return fmt.Errorf("%s: minimum length is %d, got %d", field, min, len(value)) + } + return nil + } +} + +// MaxLength validates that a string does not exceed the specified maximum length. +// Returns an AssertOpt that fails if the string length is greater than max. +// +// Parameters: +// - value: the string to validate +// - max: the maximum allowed length +// - field: the name of the field (used in error messages) +// +// Example: +// +// MaxLength(cfg.Username, 50, "USERNAME") +func MaxLength(value string, max int, field string) AssertOpt { + return func() error { + if len(value) > max { + return fmt.Errorf("%s: maximum length is %d, got %d", field, max, len(value)) + } + return nil + } +} + +// Pattern validates that a string matches the specified regular expression pattern. +// Returns an AssertOpt that fails if the string does not match the pattern or if +// the pattern itself is invalid. +// +// Parameters: +// - value: the string to validate +// - field: the name of the field (used in error messages) +// - pattern: the regular expression pattern to match against +// +// Example: +// +// Pattern(cfg.Email, "EMAIL", `^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$`) +func Pattern(value, field, pattern string) AssertOpt { + return func() error { + matched, err := regexp.MatchString(pattern, value) + if err != nil { + return fmt.Errorf("%s: invalid pattern: %w", field, err) + } + if !matched { + return fmt.Errorf("%s: must match pattern %q", field, pattern) + } + return nil + } +} + +// URL validates that a string is a valid URL according to Go's url.Parse. +// Returns an AssertOpt that fails if the URL cannot be parsed. +// +// Parameters: +// - value: the URL string to validate +// - field: the name of the field (used in error messages) +// +// Example: +// +// URL(cfg.APIEndpoint, "API_ENDPOINT") +func URL(value, field string) AssertOpt { + return func() error { + if value == "" { + return fmt.Errorf("%s: must not be empty", field) + } + if _, err := url.Parse(value); err != nil { + return fmt.Errorf("%s: invalid URL: %w", field, err) + } + return nil + } +} + +// FileExists validates that a file or directory exists at the specified path. +// Returns an AssertOpt that fails if the path does not exist. +// +// Parameters: +// - path: the file or directory path to check +// - field: the name of the field (used in error messages) +// +// Example: +// +// FileExists(cfg.ConfigFile, "CONFIG_FILE") +func FileExists(path, field string) AssertOpt { + return func() error { + if _, err := os.Stat(path); os.IsNotExist(err) { + return fmt.Errorf("%s: file does not exist: %s", field, path) + } else if err != nil { + return fmt.Errorf("%s: cannot access file: %w", field, err) + } + return nil + } +} + +// MinSliceLen validates that a slice has at least the specified minimum length. +// Returns an AssertOpt that fails if the slice length is less than min. +// +// Parameters: +// - length: the actual length of the slice +// - min: the minimum required length +// - field: the name of the field (used in error messages) +// +// Example: +// +// MinSliceLen(len(cfg.Servers), 1, "SERVERS") +func MinSliceLen(length, min int, field string) AssertOpt { + return func() error { + if length < min { + return fmt.Errorf("%s: minimum length is %d, got %d", field, min, length) + } + return nil + } +} + +// MaxSliceLen validates that a slice does not exceed the specified maximum length. +// Returns an AssertOpt that fails if the slice length is greater than max. +// +// Parameters: +// - length: the actual length of the slice +// - max: the maximum allowed length +// - field: the name of the field (used in error messages) +// +// Example: +// +// MaxSliceLen(len(cfg.Tags), 10, "TAGS") +func MaxSliceLen(length, max int, field string) AssertOpt { + return func() error { + if length > max { + return fmt.Errorf("%s: maximum length is %d, got %d", field, max, length) + } + return nil + } +} + +// NotEquals validates that a value does not equal the forbidden value. +// This is a generic function that works with any comparable type. +// Returns an AssertOpt that fails if the value equals the forbidden value. +// +// Parameters: +// - value: the value to validate +// - forbidden: the value that should not be matched +// - field: the name of the field (used in error messages) +// +// Example: +// +// NotEquals(cfg.Port, 22, "PORT") // disallow SSH port +// NotEquals(cfg.Mode, "insecure", "MODE") +// NotEquals(cfg.AdminPassword, "admin", "ADMIN_PASSWORD") +func NotEquals[T comparable](value, forbidden T, field string) AssertOpt { + return func() error { + if value == forbidden { + return fmt.Errorf("%s: must not equal %v", field, forbidden) + } + return nil + } +} + +// NotBlank validates that a string is not empty and not just whitespace. +// NotEmpty is already defined, but here's NotBlank for completeness +// Returns an AssertOpt that fails if the value is empty or contains only whitespace. +// +// Parameters: +// - value: the string to validate +// - field: the name of the field (used in error messages) +// +// Example: +// +// NotBlank(cfg.APIKey, "API_KEY") +func NotBlank(value, field string) AssertOpt { + return func() error { + if strings.TrimSpace(value) == "" { + return fmt.Errorf("%s: must not be blank", field) + } + return nil + } +} + +// Not inverts any AssertOpt, making it fail when the original would succeed +// and succeed when the original would fail. +// This is useful for creating negative assertions from existing validators. +// Returns an AssertOpt that inverts the result of the provided validator. +// +// Parameters: +// - opt: the AssertOpt to invert +// - customMessage: optional custom error message (if empty, a generic message is used) +// +// Example: +// +// Not(OneOf(cfg.Environment, "ENV", "production", "staging"), "must not be production or staging") +// Not(Pattern(cfg.Username, "USERNAME", `^admin.*`), "username must not start with 'admin'") +func Not(opt AssertOpt, customMessage string) AssertOpt { + return func() error { + err := opt() + if err == nil { + if customMessage != "" { + return fmt.Errorf("%s", customMessage) + } + return fmt.Errorf("validation condition must not be true") + } + return nil + } +} diff --git a/assert_test.go b/assert_test.go new file mode 100644 index 0000000..7c27044 --- /dev/null +++ b/assert_test.go @@ -0,0 +1,801 @@ +package envconfig_test + +import ( + "errors" + "fmt" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/struct0x/envconfig" +) + +func TestAssert(t *testing.T) { + t.Run("no_errors", func(t *testing.T) { + err := envconfig.Assert( + envconfig.NotEmpty("value", "FIELD1"), + envconfig.Range(5, 1, 10, "FIELD2"), + envconfig.Positive(1, "FIELD3"), + ) + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + }) + + t.Run("errors", func(t *testing.T) { + err := envconfig.Assert( + envconfig.NotEmpty("", "FIELD1"), + envconfig.Range(100, 1, 10, "FIELD2"), + envconfig.Positive(-1, "FIELD3"), + ) + if err == nil { + t.Fatal("Expected error, got nil") + } + + var errVal envconfig.ErrValidation + if !errors.As(err, &errVal) { + t.Fatalf("Expected ErrValidation, got %T", err) + } + if len(errVal) != 3 { + t.Errorf("Expected 3 errors, got %d", len(errVal)) + } + + errStr := err.Error() + if !strings.Contains(errStr, "FIELD1") { + t.Errorf("Expected error to contain 'FIELD1', got %v", errStr) + } + if !strings.Contains(errStr, "FIELD2") { + t.Errorf("Expected error to contain 'FIELD2', got %v", errStr) + } + if !strings.Contains(errStr, "FIELD3") { + t.Errorf("Expected error to contain 'FIELD3', got %v", errStr) + } + }) + + t.Run("empty_opts", func(t *testing.T) { + err := envconfig.Assert() + if err != nil { + t.Errorf("Expected no error for empty opts, got %v", err) + } + }) +} + +func TestNotEmpty(t *testing.T) { + tests := []struct { + name string + value string + field string + wantError bool + }{ + {"valid", "value", "FIELD", false}, + {"empty", "", "FIELD", true}, + {"whitespace", " ", "FIELD", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := envconfig.NotEmpty(tt.value, tt.field)() + if (err != nil) != tt.wantError { + t.Errorf("NotEmpty() error = %v, wantError %v", err, tt.wantError) + } + + if err != nil && !strings.Contains(err.Error(), tt.field) { + t.Errorf("Error should contain field name %q, got %v", tt.field, err) + } + }) + } +} + +func TestRange(t *testing.T) { + tests := []struct { + name string + value int + min int + max int + field string + wantError bool + }{ + {"within_range", 5, 1, 10, "PORT", false}, + {"at_min", 1, 1, 10, "PORT", false}, + {"at_max", 10, 1, 10, "PORT", false}, + {"below_min", 0, 1, 10, "PORT", true}, + {"above_max", 11, 1, 10, "PORT", true}, + {"negative_range", -5, -10, -1, "FIELD", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := envconfig.Range(tt.value, tt.min, tt.max, tt.field)() + if (err != nil) != tt.wantError { + t.Errorf("Range() error = %v, wantError %v", err, tt.wantError) + } + if err != nil { + if !strings.Contains(err.Error(), tt.field) { + t.Errorf("Error should contain field name %q, got %v", tt.field, err) + } + if !strings.Contains(err.Error(), fmt.Sprintf("%d", tt.value)) { + t.Errorf("Error should contain value %d, got %v", tt.value, err) + } + } + }) + } +} + +func TestPositive(t *testing.T) { + tests := []struct { + name string + value int + field string + wantError bool + }{ + {"positive", 1, "COUNT", false}, + {"large_positive", 1000, "COUNT", false}, + {"zero", 0, "COUNT", true}, + {"negative", -1, "COUNT", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := envconfig.Positive(tt.value, tt.field)() + if (err != nil) != tt.wantError { + t.Errorf("Positive() error = %v, wantError %v", err, tt.wantError) + } + if err != nil && !strings.Contains(err.Error(), tt.field) { + t.Errorf("Error should contain field name %q, got %v", tt.field, err) + } + }) + } +} + +func TestNonNegative(t *testing.T) { + tests := []struct { + name string + value int + field string + wantError bool + }{ + {"positive", 1, "RETRIES", false}, + {"zero", 0, "RETRIES", false}, + {"negative", -1, "RETRIES", true}, + {"large_negative", -100, "RETRIES", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := envconfig.NonNegative(tt.value, tt.field)() + if (err != nil) != tt.wantError { + t.Errorf("NonNegative() error = %v, wantError %v", err, tt.wantError) + } + if err != nil && !strings.Contains(err.Error(), tt.field) { + t.Errorf("Error should contain field name %q, got %v", tt.field, err) + } + }) + } +} + +func TestOneOf(t *testing.T) { + tests := []struct { + name string + value string + field string + allowed []string + wantError bool + }{ + {"valid_first", "dev", "ENV", []string{"dev", "staging", "prod"}, false}, + {"valid_middle", "staging", "ENV", []string{"dev", "staging", "prod"}, false}, + {"valid_last", "prod", "ENV", []string{"dev", "staging", "prod"}, false}, + {"invalid", "test", "ENV", []string{"dev", "staging", "prod"}, true}, + {"case_sensitive", "Dev", "ENV", []string{"dev", "staging", "prod"}, true}, + {"empty_not_allowed", "", "ENV", []string{"dev", "staging", "prod"}, true}, + {"single_allowed", "only", "ENV", []string{"only"}, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := envconfig.OneOf(tt.value, tt.field, tt.allowed...)() + if (err != nil) != tt.wantError { + t.Errorf("OneOf() error = %v, wantError %v", err, tt.wantError) + } + if err != nil { + if !strings.Contains(err.Error(), tt.field) { + t.Errorf("Error should contain field name %q, got %v", tt.field, err) + } + if !strings.Contains(err.Error(), tt.value) { + t.Errorf("Error should contain value %q, got %v", tt.value, err) + } + } + }) + } +} + +func TestCustom(t *testing.T) { + tests := []struct { + name string + condition bool + field string + message string + wantError bool + }{ + {"true_condition", true, "FIELD", "custom message", false}, + {"false_condition", false, "FIELD", "custom message", true}, + {"complex_condition", 5 > 3, "FIELD", "should be greater", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := envconfig.Custom(tt.condition, tt.field, tt.message)() + if (err != nil) != tt.wantError { + t.Errorf("Custom() error = %v, wantError %v", err, tt.wantError) + } + if err != nil { + if !strings.Contains(err.Error(), tt.field) { + t.Errorf("Error should contain field name %q, got %v", tt.field, err) + } + if !strings.Contains(err.Error(), tt.message) { + t.Errorf("Error should contain message %q, got %v", tt.message, err) + } + } + }) + } +} + +func TestMinLength(t *testing.T) { + tests := []struct { + name string + value string + min int + field string + wantError bool + }{ + {"exact_min", "abc", 3, "PASSWORD", false}, + {"above_min", "abcde", 3, "PASSWORD", false}, + {"below_min", "ab", 3, "PASSWORD", true}, + {"empty_string", "", 1, "PASSWORD", true}, + {"zero_min", "", 0, "PASSWORD", false}, + {"unicode", "日本語", 3, "TEXT", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := envconfig.MinLength(tt.value, tt.min, tt.field)() + if (err != nil) != tt.wantError { + t.Errorf("MinLength() error = %v, wantError %v", err, tt.wantError) + } + if err != nil && !strings.Contains(err.Error(), tt.field) { + t.Errorf("Error should contain field name %q, got %v", tt.field, err) + } + }) + } +} + +func TestMaxLength(t *testing.T) { + tests := []struct { + name string + value string + max int + field string + wantError bool + }{ + {"exact_max", "abc", 3, "USERNAME", false}, + {"below_max", "ab", 3, "USERNAME", false}, + {"above_max", "abcd", 3, "USERNAME", true}, + {"empty_string", "", 10, "USERNAME", false}, + {"zero_max", "a", 0, "USERNAME", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := envconfig.MaxLength(tt.value, tt.max, tt.field)() + if (err != nil) != tt.wantError { + t.Errorf("MaxLength() error = %v, wantError %v", err, tt.wantError) + } + if err != nil && !strings.Contains(err.Error(), tt.field) { + t.Errorf("Error should contain field name %q, got %v", tt.field, err) + } + }) + } +} + +func TestPattern(t *testing.T) { + tests := []struct { + name string + value string + field string + pattern string + wantError bool + }{ + {"valid_email", "test@example.com", "EMAIL", `^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$`, false}, + {"invalid_email", "invalid.email", "EMAIL", `^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$`, true}, + {"digits_only", "12345", "CODE", `^\d+$`, false}, + {"letters_in_digits", "123a5", "CODE", `^\d+$`, true}, + {"invalid_pattern", "value", "FIELD", `[`, true}, + {"empty_value", "", "FIELD", `^.+$`, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := envconfig.Pattern(tt.value, tt.field, tt.pattern)() + if (err != nil) != tt.wantError { + t.Errorf("Pattern() error = %v, wantError %v", err, tt.wantError) + } + if err != nil && !strings.Contains(err.Error(), tt.field) { + t.Errorf("Error should contain field name %q, got %v", tt.field, err) + } + }) + } +} + +func TestURL(t *testing.T) { + tests := []struct { + name string + value string + field string + wantError bool + }{ + {"valid_http", "http://example.com", "API_URL", false}, + {"valid_https", "https://example.com", "API_URL", false}, + {"valid_with_path", "https://example.com/api/v1", "API_URL", false}, + {"valid_with_query", "https://example.com?key=value", "API_URL", false}, + {"valid_with_port", "http://localhost:8080", "API_URL", false}, + {"relative_url", "/api/v1", "API_URL", false}, // url.Parse accepts relative URLs + {"empty_string", "", "API_URL", true}, + {"invalid_url", "ht4tp://invalid", "API_URL", false}, // url.Parse is lenient + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := envconfig.URL(tt.value, tt.field)() + if (err != nil) != tt.wantError { + t.Errorf("URL() error = %v, wantError %v", err, tt.wantError) + } + if err != nil && !strings.Contains(err.Error(), tt.field) { + t.Errorf("Error should contain field name %q, got %v", tt.field, err) + } + }) + } +} + +func TestFileExists(t *testing.T) { + tempDir := t.TempDir() + existingFile := filepath.Join(tempDir, "exists.txt") + if err := os.WriteFile(existingFile, []byte("content"), 0644); err != nil { + t.Fatalf("Failed to create test file: %v", err) + } + t.Cleanup(func() { + if err := os.Remove(existingFile); err != nil { + t.Errorf("Failed to remove test file: %v", err) + } + }) + + tests := []struct { + name string + path string + field string + wantError bool + }{ + {"existing_file", existingFile, "CONFIG_FILE", false}, + {"existing_dir", tempDir, "CONFIG_DIR", false}, + {"non_existent", filepath.Join(tempDir, "missing.txt"), "CONFIG_FILE", true}, + {"empty_path", "", "CONFIG_FILE", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := envconfig.FileExists(tt.path, tt.field)() + if (err != nil) != tt.wantError { + t.Errorf("FileExists() error = %v, wantError %v", err, tt.wantError) + } + if err != nil && !strings.Contains(err.Error(), tt.field) { + t.Errorf("Error should contain field name %q, got %v", tt.field, err) + } + }) + } +} + +func TestMinSliceLen(t *testing.T) { + tests := []struct { + name string + length int + min int + field string + wantError bool + }{ + {"exact_min", 3, 3, "SERVERS", false}, + {"above_min", 5, 3, "SERVERS", false}, + {"below_min", 2, 3, "SERVERS", true}, + {"zero_length", 0, 1, "SERVERS", true}, + {"zero_min", 0, 0, "SERVERS", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := envconfig.MinSliceLen(tt.length, tt.min, tt.field)() + if (err != nil) != tt.wantError { + t.Errorf("MinSliceLen() error = %v, wantError %v", err, tt.wantError) + } + if err != nil && !strings.Contains(err.Error(), tt.field) { + t.Errorf("Error should contain field name %q, got %v", tt.field, err) + } + }) + } +} + +func TestMaxSliceLen(t *testing.T) { + tests := []struct { + name string + length int + max int + field string + wantError bool + }{ + {"exact_max", 3, 3, "TAGS", false}, + {"below_max", 2, 3, "TAGS", false}, + {"above_max", 4, 3, "TAGS", true}, + {"zero_length", 0, 10, "TAGS", false}, + {"zero_max", 1, 0, "TAGS", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := envconfig.MaxSliceLen(tt.length, tt.max, tt.field)() + if (err != nil) != tt.wantError { + t.Errorf("MaxSliceLen() error = %v, wantError %v", err, tt.wantError) + } + if err != nil && !strings.Contains(err.Error(), tt.field) { + t.Errorf("Error should contain field name %q, got %v", tt.field, err) + } + }) + } +} + +func TestComposableValidators(t *testing.T) { + requireHTTPS := func(urlStr, field string) envconfig.AssertOpt { + return func() error { + return envconfig.Assert( + envconfig.NotEmpty(urlStr, field), + envconfig.URL(urlStr, field), + envconfig.Custom(strings.HasPrefix(urlStr, "https://"), field, "must use HTTPS"), + ) + } + } + + t.Run("valid_https", func(t *testing.T) { + err := requireHTTPS("https://example.com", "API_URL")() + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + }) + + t.Run("invalid_http", func(t *testing.T) { + err := requireHTTPS("http://example.com", "API_URL")() + if err == nil { + t.Fatal("Expected error for HTTP URL") + } + if !strings.Contains(err.Error(), "HTTPS") { + t.Errorf("Expected error about HTTPS, got %v", err) + } + }) + + t.Run("empty_url", func(t *testing.T) { + err := requireHTTPS("", "API_URL")() + if err == nil { + t.Fatal("Expected error for empty URL") + } + }) +} + +func TestNotEquals(t *testing.T) { + t.Run("integers", func(t *testing.T) { + tests := []struct { + name string + value int + forbidden int + field string + wantError bool + }{ + {"different_values", 8080, 22, "PORT", false}, + {"equal_values", 22, 22, "PORT", true}, + {"zero_vs_nonzero", 0, 1, "COUNT", false}, + {"both_zero", 0, 0, "COUNT", true}, + {"negative_values", -1, -2, "OFFSET", false}, + {"equal_negatives", -5, -5, "OFFSET", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := envconfig.NotEquals(tt.value, tt.forbidden, tt.field)() + if (err != nil) != tt.wantError { + t.Errorf("NotEquals() error = %v, wantError %v", err, tt.wantError) + } + if err != nil { + if !strings.Contains(err.Error(), tt.field) { + t.Errorf("Error should contain field name %q, got %v", tt.field, err) + } + if !strings.Contains(err.Error(), fmt.Sprintf("%v", tt.forbidden)) { + t.Errorf("Error should contain forbidden value %v, got %v", tt.forbidden, err) + } + } + }) + } + }) + + t.Run("strings", func(t *testing.T) { + tests := []struct { + name string + value string + forbidden string + field string + wantError bool + }{ + {"different_strings", "allowed", "forbidden", "PASSWORD", false}, + {"equal_strings", "admin", "admin", "PASSWORD", true}, + {"case_sensitive", "Admin", "admin", "USERNAME", false}, + {"empty_vs_nonempty", "", "forbidden", "FIELD", false}, + {"both_empty", "", "", "FIELD", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := envconfig.NotEquals(tt.value, tt.forbidden, tt.field)() + if (err != nil) != tt.wantError { + t.Errorf("NotEquals() error = %v, wantError %v", err, tt.wantError) + } + if err != nil && !strings.Contains(err.Error(), tt.field) { + t.Errorf("Error should contain field name %q, got %v", tt.field, err) + } + }) + } + }) + + t.Run("floats", func(t *testing.T) { + tests := []struct { + name string + value float64 + forbidden float64 + field string + wantError bool + }{ + {"different_floats", 3.14, 2.71, "PI", false}, + {"equal_floats", 1.5, 1.5, "RATIO", true}, + {"zero_float", 0.0, 0.0, "VALUE", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := envconfig.NotEquals(tt.value, tt.forbidden, tt.field)() + if (err != nil) != tt.wantError { + t.Errorf("NotEquals() error = %v, wantError %v", err, tt.wantError) + } + }) + } + }) + + t.Run("booleans", func(t *testing.T) { + tests := []struct { + name string + value bool + forbidden bool + field string + wantError bool + }{ + {"true_vs_false", true, false, "FLAG", false}, + {"false_vs_true", false, true, "FLAG", false}, + {"both_true", true, true, "FLAG", true}, + {"both_false", false, false, "FLAG", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := envconfig.NotEquals(tt.value, tt.forbidden, tt.field)() + if (err != nil) != tt.wantError { + t.Errorf("NotEquals() error = %v, wantError %v", err, tt.wantError) + } + }) + } + }) +} + +func TestNot(t *testing.T) { + t.Run("invert_passing_validation", func(t *testing.T) { + // NotEmpty passes for non-empty string + opt := envconfig.NotEmpty("value", "FIELD") + err := opt() + if err != nil { + t.Fatalf("Setup failed: NotEmpty should pass, got %v", err) + } + + // Not() should make it fail + invertedOpt := envconfig.Not(opt, "custom error message") + err = invertedOpt() + if err == nil { + t.Fatal("Expected Not() to invert passing validation to fail") + } + if !strings.Contains(err.Error(), "custom error message") { + t.Errorf("Expected custom error message, got %v", err) + } + }) + + t.Run("invert_failing_validation", func(t *testing.T) { + // NotEmpty fails for empty string + opt := envconfig.NotEmpty("", "FIELD") + err := opt() + if err == nil { + t.Fatal("Setup failed: NotEmpty should fail for empty string") + } + + // Not() should make it pass + invertedOpt := envconfig.Not(opt, "should not appear") + err = invertedOpt() + if err != nil { + t.Errorf("Expected Not() to invert failing validation to pass, got %v", err) + } + }) + + t.Run("default_message_when_empty", func(t *testing.T) { + opt := envconfig.NotEmpty("value", "FIELD") + invertedOpt := envconfig.Not(opt, "") + err := invertedOpt() + if err == nil { + t.Fatal("Expected error") + } + if !strings.Contains(err.Error(), "must not be true") { + t.Errorf("Expected default message, got %v", err) + } + }) + + t.Run("invert_range_validation", func(t *testing.T) { + // Value is in range [1, 10] + opt := envconfig.Range(5, 1, 10, "PORT") + invertedOpt := envconfig.Not(opt, "PORT: must not be in range 1-10") + err := invertedOpt() + if err == nil { + t.Fatal("Expected Not(Range) to fail when value is in range") + } + }) + + t.Run("invert_oneof_validation", func(t *testing.T) { + opt := envconfig.OneOf("production", "ENV", "production", "staging") + invertedOpt := envconfig.Not(opt, "ENV: must not be production or staging") + err := invertedOpt() + if err == nil { + t.Fatal("Expected Not(OneOf) to fail when value is in list") + } + if !strings.Contains(err.Error(), "must not be production or staging") { + t.Errorf("Expected custom message, got %v", err) + } + + // Value is NOT in allowed list + opt2 := envconfig.OneOf("development", "ENV", "production", "staging") + invertedOpt2 := envconfig.Not(opt2, "ENV: must not be production or staging") + err2 := invertedOpt2() + if err2 != nil { + t.Errorf("Expected Not(OneOf) to pass when value is not in list, got %v", err2) + } + }) + + t.Run("invert_pattern_validation", func(t *testing.T) { + // Pattern matches + opt := envconfig.Pattern("admin123", "USERNAME", `^admin.*`) + invertedOpt := envconfig.Not(opt, "USERNAME: must not start with 'admin'") + err := invertedOpt() + if err == nil { + t.Fatal("Expected Not(Pattern) to fail when pattern matches") + } + + // Pattern doesn't match + opt2 := envconfig.Pattern("user123", "USERNAME", `^admin.*`) + invertedOpt2 := envconfig.Not(opt2, "USERNAME: must not start with 'admin'") + err2 := invertedOpt2() + if err2 != nil { + t.Errorf("Expected Not(Pattern) to pass when pattern doesn't match, got %v", err2) + } + }) +} + +func TestNotEqualsComposition(t *testing.T) { + type ServerConfig struct { + Port int + AdminPassword string + Environment string + } + + t.Run("valid_config", func(t *testing.T) { + cfg := ServerConfig{ + Port: 8080, + AdminPassword: "secure_p@ssw0rd", + Environment: "production", + } + + err := envconfig.Assert( + envconfig.NotEquals(cfg.Port, 22, "PORT"), + envconfig.NotEquals(cfg.Port, 3389, "PORT"), + envconfig.NotEquals(cfg.AdminPassword, "admin", "ADMIN_PASSWORD"), + envconfig.NotEquals(cfg.AdminPassword, "password", "ADMIN_PASSWORD"), + ) + + if err != nil { + t.Errorf("Expected valid config to pass, got error: %v", err) + } + }) + + t.Run("invalid_config_multiple_violations", func(t *testing.T) { + cfg := ServerConfig{ + Port: 22, + AdminPassword: "admin", + Environment: "production", + } + + err := envconfig.Assert( + envconfig.NotEquals(cfg.Port, 22, "PORT"), + envconfig.NotEquals(cfg.Port, 3389, "PORT"), + envconfig.NotEquals(cfg.AdminPassword, "admin", "ADMIN_PASSWORD"), + envconfig.NotEquals(cfg.AdminPassword, "password", "ADMIN_PASSWORD"), + ) + + if err == nil { + t.Fatal("Expected errors, got nil") + } + + var errVal envconfig.ErrValidation + ok := errors.As(err, &errVal) + if !ok { + t.Fatalf("Expected ErrValidation, got %T", err) + } + + if len(errVal) != 2 { + t.Errorf("Expected 2 errors (PORT and ADMIN_PASSWORD), got %d", len(errVal)) + } + }) +} + +func TestNotComposition(t *testing.T) { + t.Run("composed_validators", func(t *testing.T) { + // Create a validator that ensures port is NOT in reserved range (0-1023) + notReservedPort := func(port int, field string) envconfig.AssertOpt { + return envconfig.Not( + envconfig.Range(port, 0, 1023, field), + field+": must not be a reserved port (0-1023)", + ) + } + + // Test with non-reserved port + err := notReservedPort(8080, "PORT")() + if err != nil { + t.Errorf("Expected non-reserved port to pass, got %v", err) + } + + // Test with reserved port + err = notReservedPort(80, "PORT")() + if err == nil { + t.Fatal("Expected reserved port to fail") + } + if !strings.Contains(err.Error(), "reserved port") { + t.Errorf("Expected 'reserved port' in error, got %v", err) + } + }) + + t.Run("not_in_blocklist", func(t *testing.T) { + notInBlocklist := func(value string, field string, blocklist ...string) envconfig.AssertOpt { + return envconfig.Not( + envconfig.OneOf(value, field, blocklist...), + field+": value is in the blocklist", + ) + } + + // Test with allowed value + err := notInBlocklist("secure123", "PASSWORD", "password", "12345", "admin")() + if err != nil { + t.Errorf("Expected allowed password to pass, got %v", err) + } + + // Test with blocked value + err = notInBlocklist("admin", "PASSWORD", "password", "12345", "admin")() + if err == nil { + t.Fatal("Expected blocked password to fail") + } + if !strings.Contains(err.Error(), "blocklist") { + t.Errorf("Expected 'blocklist' in error, got %v", err) + } + }) +} diff --git a/envconfig.go b/envconfig.go index ae3b8ba..bb5ffdc 100644 --- a/envconfig.go +++ b/envconfig.go @@ -88,6 +88,7 @@ import ( // - Parsing/conversion failures return errors that include the env key. // - Unsupported leaf types (that do not implement a supported unmarshal // interface) cause an error. +// - any type can implement Validator interface, and it will be called as soon as value if populated. // // Note on empties: // @@ -114,6 +115,10 @@ func Read[T any](holder *T, lookupEnv ...func(string) (string, bool)) error { return read(lookupEnvFunc, "", holder) } +type Validator interface { + Validate() error +} + func read(le func(string) (string, bool), prefix string, holder any) error { if len(prefix) > 0 { if err, ok := tryUnmarshalKnownInterface(le, prefix, holder); ok { @@ -121,7 +126,8 @@ func read(le func(string) (string, bool), prefix string, holder any) error { } } - holderValue := reflect.ValueOf(holder).Elem() + holderPtr := reflect.ValueOf(holder) + holderValue := holderPtr.Elem() fields := reflect.VisibleFields(holderValue.Type()) for _, field := range fields { @@ -197,6 +203,18 @@ func read(le func(string) (string, bool), prefix string, holder any) error { if err := setValue(fieldVal, envVal); err != nil { return fmt.Errorf("envconfig: %q failed to populate: %w", field.Name, err) } + + if validator, ok := fieldVal.Interface().(Validator); ok { + if err := validator.Validate(); err != nil { + return fmt.Errorf("envconfig: %q failed to validate: %w", field.Name, err) + } + } + } + + if validator, ok := holderPtr.Interface().(Validator); ok { + if err := validator.Validate(); err != nil { + return fmt.Errorf("envconfig: failed to validate: %w", err) + } } return nil diff --git a/envconfig_test.go b/envconfig_test.go index c0a31db..f921caf 100644 --- a/envconfig_test.go +++ b/envconfig_test.go @@ -83,6 +83,21 @@ func TestReadValues(t *testing.T) { CustomTextUnmarshaler: CustomTextUnmarshaler{ Value: "***custom***", }, + CustomBinaryUnmarshaler: CustomBinaryUnmarshaler{ + Value: "***custom2***", + }, + CustomJSONUnmarshaler: CustomJSONUnmarshaler{ + Value: "***custom3***", + }, + CustomTextUnmarshaler2: CustomTextUnmarshaler{ + Value: "***custom***", + }, + CustomBinaryUnmarshaler2: CustomBinaryUnmarshaler{ + Value: "***custom2***", + }, + CustomJSONUnmarshaler2: CustomJSONUnmarshaler{ + Value: "***custom3***", + }, Duration: time.Hour, SliceDuration: []time.Duration{ time.Hour, @@ -281,10 +296,16 @@ type Config struct { StringDefault string `env:"MISSING" envDefault:"Default Value"` MissingValue string `env:"MISSING"` - CustomTextUnmarshaler CustomTextUnmarshaler `env:"CUSTOM"` - Duration time.Duration `env:"DURATION"` - SliceDuration []time.Duration `env:"SDUR"` - MapDuration map[string]time.Duration `env:"MDUR"` + CustomTextUnmarshaler CustomTextUnmarshaler `envPrefix:"CUSTOM"` + CustomBinaryUnmarshaler CustomBinaryUnmarshaler `envPrefix:"CUSTOM"` + CustomJSONUnmarshaler CustomJSONUnmarshaler `envPrefix:"CUSTOM"` + + CustomTextUnmarshaler2 CustomTextUnmarshaler `env:"CUSTOM"` + CustomBinaryUnmarshaler2 CustomBinaryUnmarshaler `env:"CUSTOM"` + CustomJSONUnmarshaler2 CustomJSONUnmarshaler `env:"CUSTOM"` + Duration time.Duration `env:"DURATION"` + SliceDuration []time.Duration `env:"SDUR"` + MapDuration map[string]time.Duration `env:"MDUR"` } type SubConfig struct { @@ -298,7 +319,7 @@ type SubSubConfig struct { } type CustomTextUnmarshaler struct { - Value string + Value string `env:"VALUE"` } func (c *CustomTextUnmarshaler) UnmarshalText(text []byte) error { @@ -306,6 +327,24 @@ func (c *CustomTextUnmarshaler) UnmarshalText(text []byte) error { return nil } +type CustomBinaryUnmarshaler struct { + Value string `env:"VALUE"` +} + +func (c *CustomBinaryUnmarshaler) UnmarshalBinary(text []byte) error { + c.Value = "***" + string(text) + "2***" + return nil +} + +type CustomJSONUnmarshaler struct { + Value string `env:"VALUE"` +} + +func (c *CustomJSONUnmarshaler) UnmarshalJSON(text []byte) error { + c.Value = "***" + string(text) + "3***" + return nil +} + func ptr[T any](t T) *T { return &t } @@ -416,3 +455,26 @@ func TestEmptySlice(t *testing.T) { t.Errorf("Expected empty string, got %q", cfg.Value) } } + +type ConfigWithValidation struct { + Value string `env:"VALUE"` +} + +func (c *ConfigWithValidation) Validate() error { + return envconfig.Assert( + envconfig.Custom(c.Value != "invalid", "VALUE", "invalid value"), + ) +} + +func TestValidation(t *testing.T) { + le := func(key string) (string, bool) { + return "invalid", true + } + + var cfg ConfigWithValidation + + err := envconfig.Read(&cfg, le) + if err == nil { + t.Errorf("Expected error") + } +}