Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 29 additions & 21 deletions help_doc.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,21 +36,25 @@ type ExampleDoc struct {

// FlagDoc describes a single flag in the help document.
type FlagDoc struct {
Name string `json:"name"`
Short string `json:"short"`
Type string `json:"type"`
Default string `json:"default,omitempty"`
Usage string `json:"usage"`
Group string `json:"group,omitempty"`
IsBool bool `json:"isBool,omitempty"`
Choices []string `json:"choices,omitempty"`
Name string `json:"name"`
Short string `json:"short"`
Type string `json:"type"`
Default string `json:"default,omitempty"`
Usage string `json:"usage"`
Group string `json:"group,omitempty"`
IsBool bool `json:"isBool,omitempty"`
Choices []string `json:"choices,omitempty"`
EnvVar string `json:"envVar,omitempty"`
Required bool `json:"required,omitempty"`
}

// PositionalDoc describes a positional argument in the help document.
type PositionalDoc struct {
Name string `json:"name"`
Usage string `json:"usage"`
Type string `json:"type"`
Name string `json:"name"`
Usage string `json:"usage"`
Type string `json:"type"`
EnvVar string `json:"envVar,omitempty"`
Required bool `json:"required,omitempty"`
}

// FlagSetDoc describes a standalone FlagSet's help information.
Expand All @@ -74,13 +78,15 @@ func (f *FlagSet) HelpDoc() *FlagSetDoc {

f.VisitAll(func(flag *Flag) {
fd := FlagDoc{
Name: flag.Name,
Type: flag.Value.Type(),
Default: flag.DefValue,
Usage: flag.Usage,
Group: flag.Group,
IsBool: flag.Value.IsBool(),
Choices: []string{},
Name: flag.Name,
Type: flag.Value.Type(),
Default: flag.DefValue,
Usage: flag.Usage,
Group: flag.Group,
IsBool: flag.Value.IsBool(),
Choices: []string{},
EnvVar: flag.EnvVar,
Required: flag.Required,
}
if flag.Short != 0 {
fd.Short = string(flag.Short)
Expand All @@ -97,9 +103,11 @@ func (f *FlagSet) HelpDoc() *FlagSetDoc {

for _, pf := range f.GetPositionalFields() {
doc.PositionalArgs = append(doc.PositionalArgs, PositionalDoc{
Name: pf.Name,
Usage: pf.Usage,
Type: pf.Type.String(),
Name: pf.Name,
Usage: pf.Usage,
Type: pf.Type.String(),
EnvVar: pf.EnvVar,
Required: pf.Required,
})
}

Expand Down
125 changes: 116 additions & 9 deletions mflags.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package mflags
import (
"errors"
"fmt"
"os"
"reflect"
"strconv"
"strings"
Expand All @@ -16,14 +17,18 @@ var (
ErrInvalidChoice = errors.New("invalid choice")
ErrHelp = errors.New("help requested")
ErrShowHelp = errors.New("show help") // Return from Command.Run to trigger help display
ErrRequired = errors.New("required flag not provided")
)

// PositionalField represents a positional argument field
type PositionalField struct {
Name string // Field name (e.g., "Command", "Target")
Usage string // Usage description for help output
Value reflect.Value // The reflect.Value of the field
Type reflect.Type // The type of the field
Name string // Field name (e.g., "Command", "Target")
Usage string // Usage description for help output
Value reflect.Value // The reflect.Value of the field
Type reflect.Type // The type of the field
EnvVar string // environment variable name (from env:"..." struct tag)
Required bool // whether this positional must be provided
HasValue bool // true if value was set by env var or CLI arg
Comment thread
phinze marked this conversation as resolved.
}

type FlagSet struct {
Expand All @@ -41,6 +46,8 @@ type FlagSet struct {
disableAutoHelp bool // If true, don't automatically handle -h/--help in Parse
currentGroup string // ambient group name set by FromStruct options or Group() calls
groupOrder []string // ordered list of distinct group names (insertion order)
requiredFlags []string // flag names marked required:"true"
requiredPos []int // positional indices marked required:"true"
Comment thread
phinze marked this conversation as resolved.
}

type Flag struct {
Expand All @@ -50,6 +57,9 @@ type Flag struct {
Value Value
DefValue string
Group string // group name for help rendering; empty = default "Options:"
EnvVar string // environment variable name (from env:"..." struct tag)
Required bool // whether this flag must be provided
HasValue bool // true if value was set by env var or CLI arg
}

type Value interface {
Expand Down Expand Up @@ -1424,6 +1434,7 @@ func (f *FlagSet) Parse(arguments []string) error {
if err := setFieldValue(field.Value, f.args[pos]); err != nil {
return fmt.Errorf("invalid value for position %d: %v", pos, err)
}
field.HasValue = true
}
}

Expand All @@ -1449,6 +1460,38 @@ func (f *FlagSet) Parse(arguments []string) error {
*f.unknownField = f.unknownFlags
}

if err := f.validateRequired(); err != nil {
return err
}

return nil
}

// validateRequired checks that all flags and positionals marked required:"true"
// have been provided a value (via CLI arg or env var).
func (f *FlagSet) validateRequired() error {
var missing []string

for _, name := range f.requiredFlags {
if flag, ok := f.flags[name]; ok {
if !flag.HasValue {
missing = append(missing, "--"+name)
}
}
}

for _, pos := range f.requiredPos {
if field, ok := f.posFields[pos]; ok {
if !field.HasValue {
missing = append(missing, field.Name)
}
}
}

if len(missing) > 0 {
return fmt.Errorf("%w: %s", ErrRequired, strings.Join(missing, ", "))
}

return nil
}

Expand Down Expand Up @@ -1492,6 +1535,9 @@ func (f *FlagSet) parseLongFlag(name string, args []string, index *int) (bool, e
return false, fmt.Errorf("%w: --%s: %v", ErrInvalidValue, name, err)
}

flag.HasValue = true


return true, nil
}

Expand All @@ -1515,6 +1561,8 @@ func (f *FlagSet) parseShortFlags(shortFlags string, args []string, index *int)
if err := flag.Value.Set("true"); err != nil {
return fmt.Errorf("%w: -%c: %v", ErrInvalidValue, r, err)
}
flag.HasValue = true

} else {
// Check if there are more characters after this flag
if i < len(runes)-1 {
Expand All @@ -1529,13 +1577,17 @@ func (f *FlagSet) parseShortFlags(shortFlags string, args []string, index *int)
if err := flag.Value.Set(value); err != nil {
return fmt.Errorf("%w: -%c: %v", ErrInvalidValue, r, err)
}
flag.HasValue = true

break
} else if *index+1 < len(args) {
value := args[*index+1]
*index++
if err := flag.Value.Set(value); err != nil {
return fmt.Errorf("%w: -%c: %v", ErrInvalidValue, r, err)
}
flag.HasValue = true

} else {
return fmt.Errorf("%w: -%c", ErrMissingValue, r)
}
Expand Down Expand Up @@ -1620,6 +1672,8 @@ var knownTags = map[string]bool{
"long": true,
"short": true,
"default": true,
"env": true,
"required": true,
"usage": true,
"description": true,
"choice": true,
Expand Down Expand Up @@ -1763,6 +1817,8 @@ func InGroup(name string) FromStructOption {
// - `long:"name"` - long flag name (defaults to lowercase field name)
// - `short:"x"` - short flag name (single character)
// - `default:"value"` - default value for the flag
// - `env:"VAR_NAME"` - populate default from an environment variable (overrides default, overridden by CLI)
// - `required:"true"` - return a parse error if the flag/positional wasn't provided
// - `usage:"description"` - usage description
// - `description:"description"` - alternate usage description
// - `choice:"value"` - constrain string field to specific values (can be repeated for multiple choices)
Expand Down Expand Up @@ -1858,11 +1914,32 @@ func (f *FlagSet) FromStruct(v any, opts ...FromStructOption) error {
if posUsage == "" {
posUsage = field.Tag.Get("description")
}
posEnvVar := field.Tag.Get("env")
posRequired := field.Tag.Get("required") == "true"
posHasValue := false

// Environment variable sets the positional default
if posEnvVar != "" {
if envVal, ok := os.LookupEnv(posEnvVar); ok {
if err := setFieldValue(fieldValue, envVal); err != nil {
return fmt.Errorf("invalid value for env var %s: %w", posEnvVar, err)
}
posHasValue = true
}
}

f.posFields[pos] = &PositionalField{
Name: field.Name,
Usage: posUsage,
Value: fieldValue,
Type: field.Type,
Name: field.Name,
Usage: posUsage,
Value: fieldValue,
Type: field.Type,
EnvVar: posEnvVar,
Required: posRequired,
HasValue: posHasValue,
}

if posRequired {
f.requiredPos = append(f.requiredPos, pos)
}
}
continue // Don't process position field as a flag
Expand Down Expand Up @@ -1902,6 +1979,9 @@ func (f *FlagSet) FromStruct(v any, opts ...FromStructOption) error {
}

defaultValue := field.Tag.Get("default")
envVar := field.Tag.Get("env")
required := field.Tag.Get("required") == "true"

usage := field.Tag.Get("usage")
if usage == "" {
usage = field.Tag.Get("description")
Expand All @@ -1910,6 +1990,10 @@ func (f *FlagSet) FromStruct(v any, opts ...FromStructOption) error {
}
}

if required {
f.requiredFlags = append(f.requiredFlags, longName)
}

// Register the flag based on field type
switch field.Type.Kind() {
case reflect.Bool:
Expand Down Expand Up @@ -2076,6 +2160,22 @@ func (f *FlagSet) FromStruct(v any, opts ...FromStructOption) error {
f.Var(&uint64PtrValue{p: p}, longName, short, usage)
}
}

// Set env/required metadata on the registered flag, and apply
// env var value through the flag's Value.Set path so it gets
// validated and works for all types including pointers and slices.
if flag, ok := f.flags[longName]; ok {
flag.EnvVar = envVar
flag.Required = required
if envVar != "" {
if envVal, ok := os.LookupEnv(envVar); ok {
if err := flag.Value.Set(envVal); err != nil {
return fmt.Errorf("invalid value for env var %s: %w", envVar, err)
}
flag.HasValue = true
}
}
}
}

return nil
Expand Down Expand Up @@ -2103,6 +2203,9 @@ func formatFlagLine(flag *Flag) string {
if flag.DefValue != "" && flag.DefValue != "false" && flag.DefValue != "0" {
line += fmt.Sprintf(" (default: %s)", flag.DefValue)
}
if flag.EnvVar != "" {
line += fmt.Sprintf(" (env: %s)", flag.EnvVar)
}
return line
}
return flagStr
Expand Down Expand Up @@ -2202,7 +2305,11 @@ func (f *FlagSet) ShowHelp() {
if field, ok := f.posFields[i]; ok {
argStr := fmt.Sprintf(" <%s>", strings.ToLower(field.Name))
if field.Usage != "" {
fmt.Printf("%-30s %s\n", argStr, field.Usage)
line := fmt.Sprintf("%-30s %s", argStr, field.Usage)
if field.EnvVar != "" {
line += fmt.Sprintf(" (env: %s)", field.EnvVar)
}
fmt.Println(line)
} else {
fmt.Println(argStr)
}
Expand Down
Loading
Loading