diff --git a/help_doc.go b/help_doc.go index 7b587be..f8cf6de 100644 --- a/help_doc.go +++ b/help_doc.go @@ -36,21 +36,25 @@ type ExampleDoc struct { // FlagDoc describes a single flag in the help document. type FlagDoc struct { - Name string `json:"name"` - Short string `json:"short"` - Type string `json:"type"` - Default string `json:"default,omitempty"` - Usage string `json:"usage"` - Group string `json:"group,omitempty"` - IsBool bool `json:"isBool,omitempty"` - Choices []string `json:"choices,omitempty"` + Name string `json:"name"` + Short string `json:"short"` + Type string `json:"type"` + Default string `json:"default,omitempty"` + Usage string `json:"usage"` + Group string `json:"group,omitempty"` + IsBool bool `json:"isBool,omitempty"` + Choices []string `json:"choices,omitempty"` + EnvVar string `json:"envVar,omitempty"` + Required bool `json:"required,omitempty"` } // PositionalDoc describes a positional argument in the help document. type PositionalDoc struct { - Name string `json:"name"` - Usage string `json:"usage"` - Type string `json:"type"` + Name string `json:"name"` + Usage string `json:"usage"` + Type string `json:"type"` + EnvVar string `json:"envVar,omitempty"` + Required bool `json:"required,omitempty"` } // FlagSetDoc describes a standalone FlagSet's help information. @@ -74,13 +78,15 @@ func (f *FlagSet) HelpDoc() *FlagSetDoc { f.VisitAll(func(flag *Flag) { fd := FlagDoc{ - Name: flag.Name, - Type: flag.Value.Type(), - Default: flag.DefValue, - Usage: flag.Usage, - Group: flag.Group, - IsBool: flag.Value.IsBool(), - Choices: []string{}, + Name: flag.Name, + Type: flag.Value.Type(), + Default: flag.DefValue, + Usage: flag.Usage, + Group: flag.Group, + IsBool: flag.Value.IsBool(), + Choices: []string{}, + EnvVar: flag.EnvVar, + Required: flag.Required, } if flag.Short != 0 { fd.Short = string(flag.Short) @@ -97,9 +103,11 @@ func (f *FlagSet) HelpDoc() *FlagSetDoc { for _, pf := range f.GetPositionalFields() { doc.PositionalArgs = append(doc.PositionalArgs, PositionalDoc{ - Name: pf.Name, - Usage: pf.Usage, - Type: pf.Type.String(), + Name: pf.Name, + Usage: pf.Usage, + Type: pf.Type.String(), + EnvVar: pf.EnvVar, + Required: pf.Required, }) } diff --git a/mflags.go b/mflags.go index d5a9c2f..bb24b58 100644 --- a/mflags.go +++ b/mflags.go @@ -3,6 +3,7 @@ package mflags import ( "errors" "fmt" + "os" "reflect" "strconv" "strings" @@ -16,14 +17,18 @@ var ( ErrInvalidChoice = errors.New("invalid choice") ErrHelp = errors.New("help requested") ErrShowHelp = errors.New("show help") // Return from Command.Run to trigger help display + ErrRequired = errors.New("required flag not provided") ) // PositionalField represents a positional argument field type PositionalField struct { - Name string // Field name (e.g., "Command", "Target") - Usage string // Usage description for help output - Value reflect.Value // The reflect.Value of the field - Type reflect.Type // The type of the field + Name string // Field name (e.g., "Command", "Target") + Usage string // Usage description for help output + Value reflect.Value // The reflect.Value of the field + Type reflect.Type // The type of the field + EnvVar string // environment variable name (from env:"..." struct tag) + Required bool // whether this positional must be provided + HasValue bool // true if value was set by env var or CLI arg } type FlagSet struct { @@ -41,6 +46,8 @@ type FlagSet struct { disableAutoHelp bool // If true, don't automatically handle -h/--help in Parse currentGroup string // ambient group name set by FromStruct options or Group() calls groupOrder []string // ordered list of distinct group names (insertion order) + requiredFlags []string // flag names marked required:"true" + requiredPos []int // positional indices marked required:"true" } type Flag struct { @@ -50,6 +57,9 @@ type Flag struct { Value Value DefValue string Group string // group name for help rendering; empty = default "Options:" + EnvVar string // environment variable name (from env:"..." struct tag) + Required bool // whether this flag must be provided + HasValue bool // true if value was set by env var or CLI arg } type Value interface { @@ -1424,6 +1434,7 @@ func (f *FlagSet) Parse(arguments []string) error { if err := setFieldValue(field.Value, f.args[pos]); err != nil { return fmt.Errorf("invalid value for position %d: %v", pos, err) } + field.HasValue = true } } @@ -1449,6 +1460,38 @@ func (f *FlagSet) Parse(arguments []string) error { *f.unknownField = f.unknownFlags } + if err := f.validateRequired(); err != nil { + return err + } + + return nil +} + +// validateRequired checks that all flags and positionals marked required:"true" +// have been provided a value (via CLI arg or env var). +func (f *FlagSet) validateRequired() error { + var missing []string + + for _, name := range f.requiredFlags { + if flag, ok := f.flags[name]; ok { + if !flag.HasValue { + missing = append(missing, "--"+name) + } + } + } + + for _, pos := range f.requiredPos { + if field, ok := f.posFields[pos]; ok { + if !field.HasValue { + missing = append(missing, field.Name) + } + } + } + + if len(missing) > 0 { + return fmt.Errorf("%w: %s", ErrRequired, strings.Join(missing, ", ")) + } + return nil } @@ -1492,6 +1535,9 @@ func (f *FlagSet) parseLongFlag(name string, args []string, index *int) (bool, e return false, fmt.Errorf("%w: --%s: %v", ErrInvalidValue, name, err) } + flag.HasValue = true + + return true, nil } @@ -1515,6 +1561,8 @@ func (f *FlagSet) parseShortFlags(shortFlags string, args []string, index *int) if err := flag.Value.Set("true"); err != nil { return fmt.Errorf("%w: -%c: %v", ErrInvalidValue, r, err) } + flag.HasValue = true + } else { // Check if there are more characters after this flag if i < len(runes)-1 { @@ -1529,6 +1577,8 @@ func (f *FlagSet) parseShortFlags(shortFlags string, args []string, index *int) if err := flag.Value.Set(value); err != nil { return fmt.Errorf("%w: -%c: %v", ErrInvalidValue, r, err) } + flag.HasValue = true + break } else if *index+1 < len(args) { value := args[*index+1] @@ -1536,6 +1586,8 @@ func (f *FlagSet) parseShortFlags(shortFlags string, args []string, index *int) if err := flag.Value.Set(value); err != nil { return fmt.Errorf("%w: -%c: %v", ErrInvalidValue, r, err) } + flag.HasValue = true + } else { return fmt.Errorf("%w: -%c", ErrMissingValue, r) } @@ -1620,6 +1672,8 @@ var knownTags = map[string]bool{ "long": true, "short": true, "default": true, + "env": true, + "required": true, "usage": true, "description": true, "choice": true, @@ -1763,6 +1817,8 @@ func InGroup(name string) FromStructOption { // - `long:"name"` - long flag name (defaults to lowercase field name) // - `short:"x"` - short flag name (single character) // - `default:"value"` - default value for the flag +// - `env:"VAR_NAME"` - populate default from an environment variable (overrides default, overridden by CLI) +// - `required:"true"` - return a parse error if the flag/positional wasn't provided // - `usage:"description"` - usage description // - `description:"description"` - alternate usage description // - `choice:"value"` - constrain string field to specific values (can be repeated for multiple choices) @@ -1858,11 +1914,32 @@ func (f *FlagSet) FromStruct(v any, opts ...FromStructOption) error { if posUsage == "" { posUsage = field.Tag.Get("description") } + posEnvVar := field.Tag.Get("env") + posRequired := field.Tag.Get("required") == "true" + posHasValue := false + + // Environment variable sets the positional default + if posEnvVar != "" { + if envVal, ok := os.LookupEnv(posEnvVar); ok { + if err := setFieldValue(fieldValue, envVal); err != nil { + return fmt.Errorf("invalid value for env var %s: %w", posEnvVar, err) + } + posHasValue = true + } + } + f.posFields[pos] = &PositionalField{ - Name: field.Name, - Usage: posUsage, - Value: fieldValue, - Type: field.Type, + Name: field.Name, + Usage: posUsage, + Value: fieldValue, + Type: field.Type, + EnvVar: posEnvVar, + Required: posRequired, + HasValue: posHasValue, + } + + if posRequired { + f.requiredPos = append(f.requiredPos, pos) } } continue // Don't process position field as a flag @@ -1902,6 +1979,9 @@ func (f *FlagSet) FromStruct(v any, opts ...FromStructOption) error { } defaultValue := field.Tag.Get("default") + envVar := field.Tag.Get("env") + required := field.Tag.Get("required") == "true" + usage := field.Tag.Get("usage") if usage == "" { usage = field.Tag.Get("description") @@ -1910,6 +1990,10 @@ func (f *FlagSet) FromStruct(v any, opts ...FromStructOption) error { } } + if required { + f.requiredFlags = append(f.requiredFlags, longName) + } + // Register the flag based on field type switch field.Type.Kind() { case reflect.Bool: @@ -2076,6 +2160,22 @@ func (f *FlagSet) FromStruct(v any, opts ...FromStructOption) error { f.Var(&uint64PtrValue{p: p}, longName, short, usage) } } + + // Set env/required metadata on the registered flag, and apply + // env var value through the flag's Value.Set path so it gets + // validated and works for all types including pointers and slices. + if flag, ok := f.flags[longName]; ok { + flag.EnvVar = envVar + flag.Required = required + if envVar != "" { + if envVal, ok := os.LookupEnv(envVar); ok { + if err := flag.Value.Set(envVal); err != nil { + return fmt.Errorf("invalid value for env var %s: %w", envVar, err) + } + flag.HasValue = true + } + } + } } return nil @@ -2103,6 +2203,9 @@ func formatFlagLine(flag *Flag) string { if flag.DefValue != "" && flag.DefValue != "false" && flag.DefValue != "0" { line += fmt.Sprintf(" (default: %s)", flag.DefValue) } + if flag.EnvVar != "" { + line += fmt.Sprintf(" (env: %s)", flag.EnvVar) + } return line } return flagStr @@ -2202,7 +2305,11 @@ func (f *FlagSet) ShowHelp() { if field, ok := f.posFields[i]; ok { argStr := fmt.Sprintf(" <%s>", strings.ToLower(field.Name)) if field.Usage != "" { - fmt.Printf("%-30s %s\n", argStr, field.Usage) + line := fmt.Sprintf("%-30s %s", argStr, field.Usage) + if field.EnvVar != "" { + line += fmt.Sprintf(" (env: %s)", field.EnvVar) + } + fmt.Println(line) } else { fmt.Println(argStr) } diff --git a/mflags_test.go b/mflags_test.go index de58a33..992ba39 100644 --- a/mflags_test.go +++ b/mflags_test.go @@ -872,15 +872,25 @@ func TestFromStructRejectsUnknownTags(t *testing.T) { }) t.Run("multiple unknown tags across fields", func(t *testing.T) { + type Opts struct { + Name string `long:"name" bogus:"yes"` + Port int `long:"port" nope:"true"` + } + fs := NewFlagSet("test") + err := fs.FromStruct(&Opts{}) + assert.Error(t, err) + assert.Contains(t, err.Error(), `"bogus"`) + assert.Contains(t, err.Error(), `"nope"`) + }) + + t.Run("env and required are recognized tags", func(t *testing.T) { type Opts struct { Name string `long:"name" env:"MY_NAME"` Port int `long:"port" required:"true"` } fs := NewFlagSet("test") err := fs.FromStruct(&Opts{}) - assert.Error(t, err) - assert.Contains(t, err.Error(), `"env"`) - assert.Contains(t, err.Error(), `"required"`) + assert.NoError(t, err) }) t.Run("non-mflags tag like json", func(t *testing.T) { @@ -3645,3 +3655,295 @@ func TestShowHelpWithGroups(t *testing.T) { assert.Contains(t, output, "-v, --verbose") assert.Contains(t, output, "-o, --output") } + +// --- env tag tests --- + +func TestFromStructEnvFlag(t *testing.T) { + type Config struct { + Name string `long:"name" env:"TEST_MFLAGS_NAME"` + } + + t.Run("picks up env var", func(t *testing.T) { + t.Setenv("TEST_MFLAGS_NAME", "from-env") + config := &Config{} + fs := NewFlagSet("test") + err := fs.FromStruct(config) + assert.NoError(t, err) + err = fs.Parse([]string{}) + assert.NoError(t, err) + assert.Equal(t, "from-env", config.Name) + }) + + t.Run("env overrides default", func(t *testing.T) { + type ConfigWithDefault struct { + Name string `long:"name" default:"hardcoded" env:"TEST_MFLAGS_NAME"` + } + t.Setenv("TEST_MFLAGS_NAME", "from-env") + config := &ConfigWithDefault{} + fs := NewFlagSet("test") + err := fs.FromStruct(config) + assert.NoError(t, err) + err = fs.Parse([]string{}) + assert.NoError(t, err) + assert.Equal(t, "from-env", config.Name) + }) + + t.Run("CLI arg overrides env", func(t *testing.T) { + t.Setenv("TEST_MFLAGS_NAME", "from-env") + config := &Config{} + fs := NewFlagSet("test") + err := fs.FromStruct(config) + assert.NoError(t, err) + err = fs.Parse([]string{"--name", "from-cli"}) + assert.NoError(t, err) + assert.Equal(t, "from-cli", config.Name) + }) + + t.Run("unset env uses default", func(t *testing.T) { + type ConfigWithDefault struct { + Name string `long:"name" default:"hardcoded" env:"TEST_MFLAGS_NAME_UNSET"` + } + config := &ConfigWithDefault{} + fs := NewFlagSet("test") + err := fs.FromStruct(config) + assert.NoError(t, err) + err = fs.Parse([]string{}) + assert.NoError(t, err) + assert.Equal(t, "hardcoded", config.Name) + }) + + t.Run("env with int flag", func(t *testing.T) { + type ConfigInt struct { + Count int `long:"count" env:"TEST_MFLAGS_COUNT"` + } + t.Setenv("TEST_MFLAGS_COUNT", "42") + config := &ConfigInt{} + fs := NewFlagSet("test") + err := fs.FromStruct(config) + assert.NoError(t, err) + err = fs.Parse([]string{}) + assert.NoError(t, err) + assert.Equal(t, 42, config.Count) + }) + + t.Run("env with invalid int errors", func(t *testing.T) { + type ConfigInt struct { + Count int `long:"count" env:"TEST_MFLAGS_COUNT"` + } + t.Setenv("TEST_MFLAGS_COUNT", "abc") + config := &ConfigInt{} + fs := NewFlagSet("test") + err := fs.FromStruct(config) + assert.Error(t, err) + assert.Contains(t, err.Error(), "TEST_MFLAGS_COUNT") + }) + + t.Run("env with pointer type", func(t *testing.T) { + type ConfigPtr struct { + Name *string `long:"name" env:"TEST_MFLAGS_PTR_NAME"` + } + t.Setenv("TEST_MFLAGS_PTR_NAME", "from-env") + config := &ConfigPtr{} + fs := NewFlagSet("test") + err := fs.FromStruct(config) + assert.NoError(t, err) + err = fs.Parse([]string{}) + assert.NoError(t, err) + assert.NotNil(t, config.Name) + assert.Equal(t, "from-env", *config.Name) + }) + + t.Run("env with bool flag", func(t *testing.T) { + type ConfigBool struct { + Verbose bool `long:"verbose" env:"TEST_MFLAGS_VERBOSE"` + } + t.Setenv("TEST_MFLAGS_VERBOSE", "true") + config := &ConfigBool{} + fs := NewFlagSet("test") + err := fs.FromStruct(config) + assert.NoError(t, err) + err = fs.Parse([]string{}) + assert.NoError(t, err) + assert.True(t, config.Verbose) + }) +} + +func TestFromStructEnvPositional(t *testing.T) { + type Config struct { + Target string `position:"0" env:"TEST_MFLAGS_TARGET"` + } + + t.Run("picks up env var", func(t *testing.T) { + t.Setenv("TEST_MFLAGS_TARGET", "from-env") + config := &Config{} + fs := NewFlagSet("test") + err := fs.FromStruct(config) + assert.NoError(t, err) + err = fs.Parse([]string{}) + assert.NoError(t, err) + assert.Equal(t, "from-env", config.Target) + }) + + t.Run("CLI arg overrides env", func(t *testing.T) { + t.Setenv("TEST_MFLAGS_TARGET", "from-env") + config := &Config{} + fs := NewFlagSet("test") + err := fs.FromStruct(config) + assert.NoError(t, err) + err = fs.Parse([]string{"from-cli"}) + assert.NoError(t, err) + assert.Equal(t, "from-cli", config.Target) + }) +} + +func TestEnvHelpText(t *testing.T) { + type Config struct { + Name string `long:"name" env:"MY_APP_NAME" usage:"Application name"` + } + + config := &Config{} + fs := NewFlagSet("test") + err := fs.FromStruct(config) + assert.NoError(t, err) + + old := os.Stdout + r, w, _ := os.Pipe() + os.Stdout = w + + fs.WriteFlagHelp() + + w.Close() + os.Stdout = old + + var buf bytes.Buffer + io.Copy(&buf, r) + output := buf.String() + + assert.Contains(t, output, "(env: MY_APP_NAME)") +} + +// --- required tag tests --- + +func TestFromStructRequiredFlag(t *testing.T) { + type Config struct { + Name string `long:"name" required:"true"` + } + + t.Run("errors when not provided", func(t *testing.T) { + config := &Config{} + fs := NewFlagSet("test") + err := fs.FromStruct(config) + assert.NoError(t, err) + err = fs.Parse([]string{}) + assert.Error(t, err) + assert.ErrorIs(t, err, ErrRequired) + assert.Contains(t, err.Error(), "--name") + }) + + t.Run("succeeds when provided via CLI", func(t *testing.T) { + config := &Config{} + fs := NewFlagSet("test") + err := fs.FromStruct(config) + assert.NoError(t, err) + err = fs.Parse([]string{"--name", "hello"}) + assert.NoError(t, err) + assert.Equal(t, "hello", config.Name) + }) + + t.Run("succeeds when provided via env", func(t *testing.T) { + type ConfigEnv struct { + Name string `long:"name" required:"true" env:"TEST_MFLAGS_REQ_NAME"` + } + t.Setenv("TEST_MFLAGS_REQ_NAME", "from-env") + config := &ConfigEnv{} + fs := NewFlagSet("test") + err := fs.FromStruct(config) + assert.NoError(t, err) + err = fs.Parse([]string{}) + assert.NoError(t, err) + assert.Equal(t, "from-env", config.Name) + }) + + t.Run("multiple missing reports all", func(t *testing.T) { + type ConfigMulti struct { + Name string `long:"name" required:"true"` + Host string `long:"host" required:"true"` + } + config := &ConfigMulti{} + fs := NewFlagSet("test") + err := fs.FromStruct(config) + assert.NoError(t, err) + err = fs.Parse([]string{}) + assert.Error(t, err) + assert.ErrorIs(t, err, ErrRequired) + assert.Contains(t, err.Error(), "--name") + assert.Contains(t, err.Error(), "--host") + }) +} + +func TestFromStructRequiredPositional(t *testing.T) { + type Config struct { + Command string `position:"0" required:"true"` + } + + t.Run("errors when not provided", func(t *testing.T) { + config := &Config{} + fs := NewFlagSet("test") + err := fs.FromStruct(config) + assert.NoError(t, err) + err = fs.Parse([]string{}) + assert.Error(t, err) + assert.ErrorIs(t, err, ErrRequired) + assert.Contains(t, err.Error(), "Command") + }) + + t.Run("succeeds when provided", func(t *testing.T) { + config := &Config{} + fs := NewFlagSet("test") + err := fs.FromStruct(config) + assert.NoError(t, err) + err = fs.Parse([]string{"deploy"}) + assert.NoError(t, err) + assert.Equal(t, "deploy", config.Command) + }) + + t.Run("succeeds when provided via env", func(t *testing.T) { + type ConfigEnv struct { + Command string `position:"0" required:"true" env:"TEST_MFLAGS_CMD"` + } + t.Setenv("TEST_MFLAGS_CMD", "from-env") + config := &ConfigEnv{} + fs := NewFlagSet("test") + err := fs.FromStruct(config) + assert.NoError(t, err) + err = fs.Parse([]string{}) + assert.NoError(t, err) + assert.Equal(t, "from-env", config.Command) + }) +} + +func TestRequiredBoolFlag(t *testing.T) { + type Config struct { + Accept bool `long:"accept" required:"true"` + } + + t.Run("errors when not provided", func(t *testing.T) { + config := &Config{} + fs := NewFlagSet("test") + err := fs.FromStruct(config) + assert.NoError(t, err) + err = fs.Parse([]string{}) + assert.Error(t, err) + assert.ErrorIs(t, err, ErrRequired) + }) + + t.Run("succeeds when provided", func(t *testing.T) { + config := &Config{} + fs := NewFlagSet("test") + err := fs.FromStruct(config) + assert.NoError(t, err) + err = fs.Parse([]string{"--accept"}) + assert.NoError(t, err) + assert.True(t, config.Accept) + }) +}