From d48e5dac2d4dd3dd4caa56b50fe9dbe5328835db Mon Sep 17 00:00:00 2001 From: Paul Hinze Date: Wed, 8 Apr 2026 16:43:33 -0500 Subject: [PATCH] Split mflags.go into smaller, focused files mflags.go had grown to ~2300 lines covering four distinct concerns. Pulled each one into its own file so future work (more value types, richer help formatting, new struct tags) lands somewhere obvious: mflags.go (823) core types, registration, parsing values.go (806) all xxxValue/xxxPtrValue implementations fromstruct.go (580) FromStruct and struct tag reflection help.go (146) ShowHelp, WriteFlagHelp, formatFlagLine Pure file reorganization, no behavior changes. Also removed some stray json/xml tags from an unexported test struct field that was tripping go vet. Closes MIR-988 --- fromstruct.go | 580 ++++++++++++++++++ help.go | 146 +++++ mflags.go | 1516 +----------------------------------------------- mflags_test.go | 2 +- values.go | 806 +++++++++++++++++++++++++ 5 files changed, 1536 insertions(+), 1514 deletions(-) create mode 100644 fromstruct.go create mode 100644 help.go create mode 100644 values.go diff --git a/fromstruct.go b/fromstruct.go new file mode 100644 index 0000000..2d51791 --- /dev/null +++ b/fromstruct.go @@ -0,0 +1,580 @@ +package mflags + +import ( + "fmt" + "os" + "reflect" + "strconv" + "strings" + "time" +) + +// getTagValues extracts all values for a given key from a struct tag. +// This is needed because Go's tag.Get() only returns the first value, +// but we need to support multiple values (e.g., multiple choice tags). +func getTagValues(tag reflect.StructTag, key string) []string { + var values []string + tagStr := string(tag) + searchKey := key + `:` + + for { + idx := strings.Index(tagStr, searchKey) + if idx < 0 { + break + } + + // Move past the key and colon + tagStr = tagStr[idx+len(searchKey):] + + // Find the quoted value + if len(tagStr) == 0 || tagStr[0] != '"' { + break + } + + // Find the closing quote + endIdx := 1 + for endIdx < len(tagStr) && tagStr[endIdx] != '"' { + if tagStr[endIdx] == '\\' && endIdx+1 < len(tagStr) { + endIdx += 2 // Skip escaped character + } else { + endIdx++ + } + } + + if endIdx < len(tagStr) { + value := tagStr[1:endIdx] + values = append(values, value) + tagStr = tagStr[endIdx+1:] + } else { + break + } + } + + return values +} + +// knownTags is the set of struct tag keys that FromStruct knows how to handle. +// Keep this in sync with the Tag.Get() and getTagValues() calls in FromStruct. +var knownTags = map[string]bool{ + "long": true, + "short": true, + "default": true, + "env": true, + "required": true, + "usage": true, + "description": true, + "choice": true, + "position": true, + "rest": true, + "unknown": true, + "group": true, +} + +// validateStructTags checks that every struct tag on exported fields is one +// that FromStruct actually reads. It returns an error listing all unrecognized +// tags so the caller can fix them all in one pass. +func validateStructTags(rt reflect.Type) error { + var errs []string + for i := 0; i < rt.NumField(); i++ { + field := rt.Field(i) + + // Skip the blank identifier used for group declarations + if field.Name == "_" { + continue + } + + if !field.IsExported() { + continue + } + + // Recurse into embedded structs + ft := field.Type + if ft.Kind() == reflect.Ptr { + ft = ft.Elem() + } + if field.Anonymous && ft.Kind() == reflect.Struct { + if err := validateStructTags(ft); err != nil { + errs = append(errs, err.Error()) + } + continue + } + + // Parse the raw tag string into key:"value" pairs and check each key + tagStr := string(field.Tag) + for tagStr != "" { + // Skip leading spaces + tagStr = strings.TrimLeft(tagStr, " ") + if tagStr == "" { + break + } + + // Find the key (everything before the colon) + colon := strings.Index(tagStr, ":") + if colon < 0 { + break + } + key := tagStr[:colon] + tagStr = tagStr[colon+1:] + + // Skip past the quoted value + if len(tagStr) == 0 || tagStr[0] != '"' { + break + } + end := 1 + for end < len(tagStr) && tagStr[end] != '"' { + if tagStr[end] == '\\' && end+1 < len(tagStr) { + end += 2 + } else { + end++ + } + } + if end >= len(tagStr) { + break + } + tagStr = tagStr[end+1:] + + if !knownTags[key] { + errs = append(errs, fmt.Sprintf("unknown struct tag %q on field %s in %s", key, field.Name, rt.Name())) + } + } + } + if len(errs) > 0 { + return fmt.Errorf("invalid struct tags:\n %s", strings.Join(errs, "\n ")) + } + return nil +} + +// setFieldValue sets a string value to a reflect.Value based on its type +func setFieldValue(fieldValue reflect.Value, value string) error { + switch fieldValue.Kind() { + case reflect.String: + fieldValue.SetString(value) + case reflect.Bool: + b, err := strconv.ParseBool(value) + if err != nil { + return err + } + fieldValue.SetBool(b) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + if fieldValue.Type() == reflect.TypeOf(time.Duration(0)) { + d, err := time.ParseDuration(value) + if err != nil { + return err + } + fieldValue.SetInt(int64(d)) + } else { + i, err := strconv.ParseInt(value, 10, fieldValue.Type().Bits()) + if err != nil { + return err + } + fieldValue.SetInt(i) + } + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + u, err := strconv.ParseUint(value, 10, fieldValue.Type().Bits()) + if err != nil { + return err + } + fieldValue.SetUint(u) + case reflect.Float32, reflect.Float64: + f, err := strconv.ParseFloat(value, fieldValue.Type().Bits()) + if err != nil { + return err + } + fieldValue.SetFloat(f) + default: + return fmt.Errorf("unsupported type: %v", fieldValue.Type()) + } + return nil +} + +// FromStructOption configures how FromStruct processes a struct. +type FromStructOption func(*fromStructConfig) + +type fromStructConfig struct { + group string +} + +// InGroup sets the group name for all flags created by FromStruct. +func InGroup(name string) FromStructOption { + return func(c *fromStructConfig) { c.group = name } +} + +// FromStruct creates flag definitions from a struct's fields using struct tags. +// The argument must be a pointer to a struct. Struct tags control how fields are parsed: +// - `long:"name"` - long flag name (defaults to lowercase field name) +// - `short:"x"` - short flag name (single character) +// - `default:"value"` - default value for the flag +// - `env:"VAR_NAME"` - populate default from an environment variable (overrides default, overridden by CLI) +// - `required:"true"` - return a parse error if the flag/positional wasn't provided +// - `usage:"description"` - usage description +// - `description:"description"` - alternate usage description +// - `choice:"value"` - constrain string field to specific values (can be repeated for multiple choices) +// - `position:"0"` - positional argument at index 0 +// - `rest:"true"` - capture all remaining arguments in a []string field +// - `unknown:"true"` - capture unknown flags in a []string field (automatically enables AllowUnknownFlags) +// - `group:"name"` - on a `_ struct{}` field, declares the group for all flags in the struct +// - `group:"name"` - on an embedded struct field, overrides the embedded struct's self-declared group +// +// Supports bool, string, int, []string, and time.Duration field types. +// Anonymous embedded structs are recursively processed. +func (f *FlagSet) FromStruct(v any, opts ...FromStructOption) error { + rv := reflect.ValueOf(v) + if rv.Kind() != reflect.Ptr || rv.IsNil() { + return fmt.Errorf("FromStruct requires a non-nil pointer to a struct") + } + + rv = rv.Elem() + if rv.Kind() != reflect.Struct { + return fmt.Errorf("FromStruct requires a pointer to a struct") + } + + // Apply options + var cfg fromStructConfig + for _, opt := range opts { + opt(&cfg) + } + + // Save and restore currentGroup + prevGroup := f.currentGroup + defer func() { f.currentGroup = prevGroup }() + + if cfg.group != "" { + f.currentGroup = cfg.group + } + + rt := rv.Type() + + // Validate that all struct tags are ones we know how to handle + if err := validateStructTags(rt); err != nil { + return err + } + + // First pass: check for self-declared group via `_ struct{} \`group:"..."\`` + for i := 0; i < rt.NumField(); i++ { + field := rt.Field(i) + if field.Name == "_" && field.Type == reflect.TypeOf(struct{}{}) { + if groupTag := field.Tag.Get("group"); groupTag != "" && f.currentGroup == "" { + f.currentGroup = groupTag + } + break + } + } + + for i := 0; i < rt.NumField(); i++ { + field := rt.Field(i) + + // Skip the `_` group declaration field + if field.Name == "_" { + continue + } + + if !field.IsExported() { + continue + } + + fieldValue := rv.Field(i) + if !fieldValue.CanAddr() { + continue + } + + // Check for anonymous/embedded struct fields and descend into them + if field.Anonymous && field.Type.Kind() == reflect.Struct { + // Check for group tag on the embedding site + if groupTag := field.Tag.Get("group"); groupTag != "" { + if err := f.FromStruct(fieldValue.Addr().Interface(), InGroup(groupTag)); err != nil { + return err + } + } else { + if err := f.FromStruct(fieldValue.Addr().Interface()); err != nil { + return err + } + } + continue + } + + // Check for "position" tag - capture positional argument + if posStr := field.Tag.Get("position"); posStr != "" { + pos, err := strconv.Atoi(posStr) + if err == nil && pos >= 0 { + // Get usage from either "usage" or "description" tag + posUsage := field.Tag.Get("usage") + if posUsage == "" { + posUsage = field.Tag.Get("description") + } + posEnvVar := field.Tag.Get("env") + posRequired := field.Tag.Get("required") == "true" + posHasValue := false + + // Environment variable sets the positional default + if posEnvVar != "" { + if envVal, ok := os.LookupEnv(posEnvVar); ok { + if err := setFieldValue(fieldValue, envVal); err != nil { + return fmt.Errorf("invalid value for env var %s: %w", posEnvVar, err) + } + posHasValue = true + } + } + + f.posFields[pos] = &PositionalField{ + Name: field.Name, + Usage: posUsage, + Value: fieldValue, + Type: field.Type, + EnvVar: posEnvVar, + Required: posRequired, + HasValue: posHasValue, + } + + if posRequired { + f.requiredPos = append(f.requiredPos, pos) + } + } + continue // Don't process position field as a flag + } + + // Check for "rest" tag - special handling for remaining arguments + if field.Tag.Get("rest") != "" { + if field.Type.Kind() == reflect.Slice && field.Type.Elem().Kind() == reflect.String { + f.restField = fieldValue.Addr().Interface().(*[]string) + } + continue // Don't process rest field as a flag + } + + // Check for "unknown" tag - special handling for unknown flags + if field.Tag.Get("unknown") != "" { + if field.Type.Kind() == reflect.Slice && field.Type.Elem().Kind() == reflect.String { + f.unknownField = fieldValue.Addr().Interface().(*[]string) + f.allowUnknownFlags = true // Automatically enable unknown flag handling + } + continue // Don't process unknown field as a flag + } + + // Parse struct tags + longName := field.Tag.Get("long") + if longName == "" { + longName = strings.ToLower(field.Name) + } + + shortName := field.Tag.Get("short") + var short rune + if shortName != "" && len(shortName) == 1 { + short = rune(shortName[0]) + } + + if longName == "" && short == 0 { + continue // No flag name provided + } + + defaultValue := field.Tag.Get("default") + envVar := field.Tag.Get("env") + required := field.Tag.Get("required") == "true" + + usage := field.Tag.Get("usage") + if usage == "" { + usage = field.Tag.Get("description") + if usage == "" { + usage = fmt.Sprintf("%s value", field.Name) + } + } + + if required { + f.requiredFlags = append(f.requiredFlags, longName) + } + + // Register the flag based on field type + switch field.Type.Kind() { + case reflect.Bool: + var defVal bool + if defaultValue != "" { + defVal, _ = strconv.ParseBool(defaultValue) + } + f.BoolVar(fieldValue.Addr().Interface().(*bool), longName, short, defVal, usage) + + case reflect.String: + // Check for choice tags - if present, use ChoiceVar + choices := getTagValues(field.Tag, "choice") + if len(choices) > 0 { + f.ChoiceVar(fieldValue.Addr().Interface().(*string), longName, short, defaultValue, choices, usage) + } else { + f.StringVar(fieldValue.Addr().Interface().(*string), longName, short, defaultValue, usage) + } + + case reflect.Int: + var defVal int + if defaultValue != "" { + defVal, _ = strconv.Atoi(defaultValue) + } + f.IntVar(fieldValue.Addr().Interface().(*int), longName, short, defVal, usage) + + case reflect.Slice: + switch field.Type.Elem().Kind() { + case reflect.String: + var defVal []string + if defaultValue != "" { + defVal = strings.Split(defaultValue, ",") + } + f.StringArrayVar(fieldValue.Addr().Interface().(*[]string), longName, short, defVal, usage) + case reflect.Bool: + f.BoolArrayVar(fieldValue.Addr().Interface().(*[]bool), longName, short, usage) + case reflect.Int: + f.IntArrayVar(fieldValue.Addr().Interface().(*[]int), longName, short, usage) + } + + case reflect.Int64: + // Check if it's a time.Duration + if field.Type == reflect.TypeOf(time.Duration(0)) { + var defVal time.Duration + if defaultValue != "" { + defVal, _ = time.ParseDuration(defaultValue) + } + f.DurationVar(fieldValue.Addr().Interface().(*time.Duration), longName, short, defVal, usage) + } else { + var defVal int64 + if defaultValue != "" { + defVal, _ = strconv.ParseInt(defaultValue, 10, 64) + } + f.Int64Var(fieldValue.Addr().Interface().(*int64), longName, short, defVal, usage) + } + + case reflect.Int8: + var defVal int8 + if defaultValue != "" { + v, _ := strconv.ParseInt(defaultValue, 10, 8) + defVal = int8(v) + } + f.Int8Var(fieldValue.Addr().Interface().(*int8), longName, short, defVal, usage) + + case reflect.Int16: + var defVal int16 + if defaultValue != "" { + v, _ := strconv.ParseInt(defaultValue, 10, 16) + defVal = int16(v) + } + f.Int16Var(fieldValue.Addr().Interface().(*int16), longName, short, defVal, usage) + + case reflect.Int32: + var defVal int32 + if defaultValue != "" { + v, _ := strconv.ParseInt(defaultValue, 10, 32) + defVal = int32(v) + } + f.Int32Var(fieldValue.Addr().Interface().(*int32), longName, short, defVal, usage) + + case reflect.Uint: + var defVal uint + if defaultValue != "" { + v, _ := strconv.ParseUint(defaultValue, 10, 64) + defVal = uint(v) + } + f.UintVar(fieldValue.Addr().Interface().(*uint), longName, short, defVal, usage) + + case reflect.Uint8: + var defVal uint8 + if defaultValue != "" { + v, _ := strconv.ParseUint(defaultValue, 10, 8) + defVal = uint8(v) + } + f.Uint8Var(fieldValue.Addr().Interface().(*uint8), longName, short, defVal, usage) + + case reflect.Uint16: + var defVal uint16 + if defaultValue != "" { + v, _ := strconv.ParseUint(defaultValue, 10, 16) + defVal = uint16(v) + } + f.Uint16Var(fieldValue.Addr().Interface().(*uint16), longName, short, defVal, usage) + + case reflect.Uint32: + var defVal uint32 + if defaultValue != "" { + v, _ := strconv.ParseUint(defaultValue, 10, 32) + defVal = uint32(v) + } + f.Uint32Var(fieldValue.Addr().Interface().(*uint32), longName, short, defVal, usage) + + case reflect.Uint64: + var defVal uint64 + if defaultValue != "" { + defVal, _ = strconv.ParseUint(defaultValue, 10, 64) + } + f.Uint64Var(fieldValue.Addr().Interface().(*uint64), longName, short, defVal, usage) + + case reflect.Ptr: + // Handle pointer types - allows distinguishing "not set" from "zero value" + elemKind := field.Type.Elem().Kind() + switch elemKind { + case reflect.Bool: + p := fieldValue.Addr().Interface().(**bool) + f.Var(&boolPtrValue{p: p}, longName, short, usage) + case reflect.String: + p := fieldValue.Addr().Interface().(**string) + f.Var(&stringPtrValue{p: p}, longName, short, usage) + case reflect.Int: + p := fieldValue.Addr().Interface().(**int) + f.Var(&intPtrValue{p: p}, longName, short, usage) + case reflect.Int64: + // Check if it's a *time.Duration + if field.Type.Elem() == reflect.TypeOf(time.Duration(0)) { + p := fieldValue.Addr().Interface().(**time.Duration) + f.Var(&durationPtrValue{p: p}, longName, short, usage) + } else { + p := fieldValue.Addr().Interface().(**int64) + f.Var(&int64PtrValue{p: p}, longName, short, usage) + } + case reflect.Int8: + p := fieldValue.Addr().Interface().(**int8) + f.Var(&int8PtrValue{p: p}, longName, short, usage) + case reflect.Int16: + p := fieldValue.Addr().Interface().(**int16) + f.Var(&int16PtrValue{p: p}, longName, short, usage) + case reflect.Int32: + p := fieldValue.Addr().Interface().(**int32) + f.Var(&int32PtrValue{p: p}, longName, short, usage) + case reflect.Uint: + p := fieldValue.Addr().Interface().(**uint) + f.Var(&uintPtrValue{p: p}, longName, short, usage) + case reflect.Uint8: + p := fieldValue.Addr().Interface().(**uint8) + f.Var(&uint8PtrValue{p: p}, longName, short, usage) + case reflect.Uint16: + p := fieldValue.Addr().Interface().(**uint16) + f.Var(&uint16PtrValue{p: p}, longName, short, usage) + case reflect.Uint32: + p := fieldValue.Addr().Interface().(**uint32) + f.Var(&uint32PtrValue{p: p}, longName, short, usage) + case reflect.Uint64: + p := fieldValue.Addr().Interface().(**uint64) + f.Var(&uint64PtrValue{p: p}, longName, short, usage) + } + } + + // Set env/required metadata on the registered flag, and apply + // env var value through the flag's Value.Set path so it gets + // validated and works for all types including pointers and slices. + if flag, ok := f.flags[longName]; ok { + flag.EnvVar = envVar + flag.Required = required + if envVar != "" { + if envVal, ok := os.LookupEnv(envVar); ok { + if err := flag.Value.Set(envVal); err != nil { + return fmt.Errorf("invalid value for env var %s: %w", envVar, err) + } + flag.HasValue = true + } + } + } + } + + return nil +} + +// ParseStruct parses command line arguments and updates the struct fields. +// This is a convenience function that creates a FlagSet, calls FromStruct, and parses the arguments. +// See FromStruct for documentation on supported struct tags and field types. +func ParseStruct(v any, arguments []string) error { + fs := NewFlagSet("") + if err := fs.FromStruct(v); err != nil { + return err + } + return fs.Parse(arguments) +} diff --git a/help.go b/help.go new file mode 100644 index 0000000..f3a0620 --- /dev/null +++ b/help.go @@ -0,0 +1,146 @@ +package mflags + +import ( + "fmt" + "strings" +) + +// formatFlagLine formats a single flag for help output. +func formatFlagLine(flag *Flag) string { + var flagStr string + if flag.Short != 0 && flag.Name != "" { + flagStr = fmt.Sprintf(" -%c, --%s", flag.Short, flag.Name) + } else if flag.Short != 0 { + flagStr = fmt.Sprintf(" -%c", flag.Short) + } else { + flagStr = fmt.Sprintf(" --%s", flag.Name) + } + + // Add value placeholder for non-boolean flags + if !flag.Value.IsBool() { + flagStr += fmt.Sprintf(" <%s>", flag.Value.Type()) + } + + // Format with usage + if flag.Usage != "" { + line := fmt.Sprintf("%-30s %s", flagStr, flag.Usage) + if flag.DefValue != "" && flag.DefValue != "false" && flag.DefValue != "0" { + line += fmt.Sprintf(" (default: %s)", flag.DefValue) + } + if flag.EnvVar != "" { + line += fmt.Sprintf(" (env: %s)", flag.EnvVar) + } + return line + } + return flagStr +} + +// WriteFlagHelp writes group-aware flag help output to stdout. +// Named groups are printed first in insertion order, followed by the default +// (unnamed) group under "Options:". +func (f *FlagSet) WriteFlagHelp() { + // Collect flags into groups, preserving VisitAll sort order + type groupFlags struct { + name string + flags []*Flag + } + + groupMap := make(map[string]*groupFlags) + var defaultFlags []*Flag + + f.VisitAll(func(flag *Flag) { + if flag.Group == "" { + defaultFlags = append(defaultFlags, flag) + } else { + gf, ok := groupMap[flag.Group] + if !ok { + gf = &groupFlags{name: flag.Group} + groupMap[flag.Group] = gf + } + gf.flags = append(gf.flags, flag) + } + }) + + // Print named groups in insertion order + for _, groupName := range f.groupOrder { + gf, ok := groupMap[groupName] + if !ok || len(gf.flags) == 0 { + continue + } + fmt.Printf("\n%s:\n", gf.name) + for _, flag := range gf.flags { + fmt.Println(formatFlagLine(flag)) + } + } + + // Print default group last + if len(defaultFlags) > 0 { + fmt.Println("\nOptions:") + for _, flag := range defaultFlags { + fmt.Println(formatFlagLine(flag)) + } + } +} + +// ShowHelp displays help information for the flag set, including all defined flags +// and their usage information. +func (f *FlagSet) ShowHelp() { + if f.name != "" { + fmt.Printf("Usage: %s [options]", f.name) + // Show positional arguments by name + if len(f.posFields) > 0 { + // Find max position to iterate in order + maxPos := -1 + for pos := range f.posFields { + if pos > maxPos { + maxPos = pos + } + } + // Print each positional argument name + for i := 0; i <= maxPos; i++ { + if field, ok := f.posFields[i]; ok { + fmt.Printf(" <%s>", strings.ToLower(field.Name)) + } + } + } + if f.restField != nil { + fmt.Print(" [arguments...]") + } + fmt.Println() + } + + // Show positional arguments with descriptions if any have usage text + if len(f.posFields) > 0 { + // Find max position to iterate in order + maxPos := -1 + hasUsage := false + for pos, field := range f.posFields { + if pos > maxPos { + maxPos = pos + } + if field.Usage != "" { + hasUsage = true + } + } + + if hasUsage { + fmt.Println("\nArguments:") + for i := 0; i <= maxPos; i++ { + if field, ok := f.posFields[i]; ok { + argStr := fmt.Sprintf(" <%s>", strings.ToLower(field.Name)) + if field.Usage != "" { + line := fmt.Sprintf("%-30s %s", argStr, field.Usage) + if field.EnvVar != "" { + line += fmt.Sprintf(" (env: %s)", field.EnvVar) + } + fmt.Println(line) + } else { + fmt.Println(argStr) + } + } + } + } + } + + f.WriteFlagHelp() +} diff --git a/mflags.go b/mflags.go index bb24b58..30b270f 100644 --- a/mflags.go +++ b/mflags.go @@ -3,9 +3,7 @@ package mflags import ( "errors" "fmt" - "os" "reflect" - "strconv" "strings" "time" ) @@ -69,804 +67,6 @@ type Value interface { Type() string } -type boolValue bool - -func (b *boolValue) Set(s string) error { - v, err := strconv.ParseBool(s) - if err != nil { - return err - } - *b = boolValue(v) - return nil -} - -func (b *boolValue) String() string { - return strconv.FormatBool(bool(*b)) -} - -func (b *boolValue) IsBool() bool { - return true -} - -func (b *boolValue) Type() string { - return "bool" -} - -type boolArrayValue []bool - -func (b *boolArrayValue) Set(s string) error { - v, err := strconv.ParseBool(s) - if err != nil { - return err - } - *b = append(*b, v) - return nil -} - -func (b *boolArrayValue) String() string { - if len(*b) == 0 { - return "" - } - strs := make([]string, len(*b)) - for i, v := range *b { - strs[i] = strconv.FormatBool(v) - } - return strings.Join(strs, ",") -} - -func (b *boolArrayValue) IsBool() bool { - return true -} - -func (b *boolArrayValue) Type() string { - return "bool" -} - -type stringValue string - -func (s *stringValue) Set(val string) error { - *s = stringValue(val) - return nil -} - -func (s *stringValue) String() string { - return string(*s) -} - -func (s *stringValue) IsBool() bool { - return false -} - -func (s *stringValue) Type() string { - return "string" -} - -type intValue int - -func (i *intValue) Set(s string) error { - v, err := strconv.Atoi(s) - if err != nil { - return err - } - *i = intValue(v) - return nil -} - -func (i *intValue) String() string { - return strconv.Itoa(int(*i)) -} - -func (i *intValue) IsBool() bool { - return false -} - -func (i *intValue) Type() string { - return "int" -} - -type int64Value int64 - -func (i *int64Value) Set(s string) error { - v, err := strconv.ParseInt(s, 10, 64) - if err != nil { - return err - } - *i = int64Value(v) - return nil -} - -func (i *int64Value) String() string { - return strconv.FormatInt(int64(*i), 10) -} - -func (i *int64Value) IsBool() bool { - return false -} - -func (i *int64Value) Type() string { - return "int" -} - -type int8Value int8 - -func (i *int8Value) Set(s string) error { - v, err := strconv.ParseInt(s, 10, 8) - if err != nil { - return err - } - *i = int8Value(v) - return nil -} - -func (i *int8Value) String() string { - return strconv.FormatInt(int64(*i), 10) -} - -func (i *int8Value) IsBool() bool { - return false -} - -func (i *int8Value) Type() string { - return "int" -} - -type int16Value int16 - -func (i *int16Value) Set(s string) error { - v, err := strconv.ParseInt(s, 10, 16) - if err != nil { - return err - } - *i = int16Value(v) - return nil -} - -func (i *int16Value) String() string { - return strconv.FormatInt(int64(*i), 10) -} - -func (i *int16Value) IsBool() bool { - return false -} - -func (i *int16Value) Type() string { - return "int" -} - -type int32Value int32 - -func (i *int32Value) Set(s string) error { - v, err := strconv.ParseInt(s, 10, 32) - if err != nil { - return err - } - *i = int32Value(v) - return nil -} - -func (i *int32Value) String() string { - return strconv.FormatInt(int64(*i), 10) -} - -func (i *int32Value) IsBool() bool { - return false -} - -func (i *int32Value) Type() string { - return "int" -} - -type uintValue uint - -func (i *uintValue) Set(s string) error { - v, err := strconv.ParseUint(s, 10, 64) - if err != nil { - return err - } - *i = uintValue(v) - return nil -} - -func (i *uintValue) String() string { - return strconv.FormatUint(uint64(*i), 10) -} - -func (i *uintValue) IsBool() bool { - return false -} - -func (i *uintValue) Type() string { - return "uint" -} - -type uint8Value uint8 - -func (i *uint8Value) Set(s string) error { - v, err := strconv.ParseUint(s, 10, 8) - if err != nil { - return err - } - *i = uint8Value(v) - return nil -} - -func (i *uint8Value) String() string { - return strconv.FormatUint(uint64(*i), 10) -} - -func (i *uint8Value) IsBool() bool { - return false -} - -func (i *uint8Value) Type() string { - return "uint" -} - -type uint16Value uint16 - -func (i *uint16Value) Set(s string) error { - v, err := strconv.ParseUint(s, 10, 16) - if err != nil { - return err - } - *i = uint16Value(v) - return nil -} - -func (i *uint16Value) String() string { - return strconv.FormatUint(uint64(*i), 10) -} - -func (i *uint16Value) IsBool() bool { - return false -} - -func (i *uint16Value) Type() string { - return "uint" -} - -type uint32Value uint32 - -func (i *uint32Value) Set(s string) error { - v, err := strconv.ParseUint(s, 10, 32) - if err != nil { - return err - } - *i = uint32Value(v) - return nil -} - -func (i *uint32Value) String() string { - return strconv.FormatUint(uint64(*i), 10) -} - -func (i *uint32Value) IsBool() bool { - return false -} - -func (i *uint32Value) Type() string { - return "uint" -} - -type uint64Value uint64 - -func (i *uint64Value) Set(s string) error { - v, err := strconv.ParseUint(s, 10, 64) - if err != nil { - return err - } - *i = uint64Value(v) - return nil -} - -func (i *uint64Value) String() string { - return strconv.FormatUint(uint64(*i), 10) -} - -func (i *uint64Value) IsBool() bool { - return false -} - -func (i *uint64Value) Type() string { - return "uint" -} - -type intArrayValue []int - -func (i *intArrayValue) Set(s string) error { - v, err := strconv.Atoi(s) - if err != nil { - return err - } - *i = append(*i, v) - return nil -} - -func (i *intArrayValue) String() string { - if len(*i) == 0 { - return "" - } - strs := make([]string, len(*i)) - for idx, v := range *i { - strs[idx] = strconv.Itoa(v) - } - return strings.Join(strs, ",") -} - -func (i *intArrayValue) IsBool() bool { - return false -} - -func (i *intArrayValue) Type() string { - return "int" -} - -type stringArrayValue struct { - values *[]string - hasBeenSet bool -} - -func (s *stringArrayValue) Set(val string) error { - // On first Set call, clear any default values - if !s.hasBeenSet { - *s.values = nil - s.hasBeenSet = true - } - *s.values = append(*s.values, strings.Split(val, ",")...) - return nil -} - -func (s *stringArrayValue) String() string { - if s.values == nil { - return "" - } - return strings.Join(*s.values, ",") -} - -func (s *stringArrayValue) IsBool() bool { - return false -} - -func (s *stringArrayValue) Type() string { - return "value,..." -} - -type durationValue time.Duration - -func (d *durationValue) Set(s string) error { - v, err := time.ParseDuration(s) - if err != nil { - return err - } - *d = durationValue(v) - return nil -} - -func (d *durationValue) String() string { - return time.Duration(*d).String() -} - -func (d *durationValue) IsBool() bool { - return false -} - -func (d *durationValue) Type() string { - return "duration" -} - -// Pointer value types - these allocate the pointed-to value on first Set, -// allowing code to distinguish between "not set" (nil) and "set to zero value" - -type boolPtrValue struct { - p **bool -} - -func (b *boolPtrValue) Set(s string) error { - v, err := strconv.ParseBool(s) - if err != nil { - return err - } - *b.p = new(bool) - **b.p = v - return nil -} - -func (b *boolPtrValue) String() string { - if *b.p == nil { - return "" - } - return strconv.FormatBool(**b.p) -} - -func (b *boolPtrValue) IsBool() bool { - return true -} - -func (b *boolPtrValue) Type() string { - return "bool" -} - -type stringPtrValue struct { - p **string -} - -func (s *stringPtrValue) Set(val string) error { - *s.p = new(string) - **s.p = val - return nil -} - -func (s *stringPtrValue) String() string { - if *s.p == nil { - return "" - } - return **s.p -} - -func (s *stringPtrValue) IsBool() bool { - return false -} - -func (s *stringPtrValue) Type() string { - return "string" -} - -type intPtrValue struct { - p **int -} - -func (i *intPtrValue) Set(s string) error { - v, err := strconv.Atoi(s) - if err != nil { - return err - } - *i.p = new(int) - **i.p = v - return nil -} - -func (i *intPtrValue) String() string { - if *i.p == nil { - return "" - } - return strconv.Itoa(**i.p) -} - -func (i *intPtrValue) IsBool() bool { - return false -} - -func (i *intPtrValue) Type() string { - return "int" -} - -type int64PtrValue struct { - p **int64 -} - -func (i *int64PtrValue) Set(s string) error { - v, err := strconv.ParseInt(s, 10, 64) - if err != nil { - return err - } - *i.p = new(int64) - **i.p = v - return nil -} - -func (i *int64PtrValue) String() string { - if *i.p == nil { - return "" - } - return strconv.FormatInt(**i.p, 10) -} - -func (i *int64PtrValue) IsBool() bool { - return false -} - -func (i *int64PtrValue) Type() string { - return "int" -} - -type int8PtrValue struct { - p **int8 -} - -func (i *int8PtrValue) Set(s string) error { - v, err := strconv.ParseInt(s, 10, 8) - if err != nil { - return err - } - *i.p = new(int8) - **i.p = int8(v) - return nil -} - -func (i *int8PtrValue) String() string { - if *i.p == nil { - return "" - } - return strconv.FormatInt(int64(**i.p), 10) -} - -func (i *int8PtrValue) IsBool() bool { - return false -} - -func (i *int8PtrValue) Type() string { - return "int" -} - -type int16PtrValue struct { - p **int16 -} - -func (i *int16PtrValue) Set(s string) error { - v, err := strconv.ParseInt(s, 10, 16) - if err != nil { - return err - } - *i.p = new(int16) - **i.p = int16(v) - return nil -} - -func (i *int16PtrValue) String() string { - if *i.p == nil { - return "" - } - return strconv.FormatInt(int64(**i.p), 10) -} - -func (i *int16PtrValue) IsBool() bool { - return false -} - -func (i *int16PtrValue) Type() string { - return "int" -} - -type int32PtrValue struct { - p **int32 -} - -func (i *int32PtrValue) Set(s string) error { - v, err := strconv.ParseInt(s, 10, 32) - if err != nil { - return err - } - *i.p = new(int32) - **i.p = int32(v) - return nil -} - -func (i *int32PtrValue) String() string { - if *i.p == nil { - return "" - } - return strconv.FormatInt(int64(**i.p), 10) -} - -func (i *int32PtrValue) IsBool() bool { - return false -} - -func (i *int32PtrValue) Type() string { - return "int" -} - -type uintPtrValue struct { - p **uint -} - -func (i *uintPtrValue) Set(s string) error { - v, err := strconv.ParseUint(s, 10, 64) - if err != nil { - return err - } - *i.p = new(uint) - **i.p = uint(v) - return nil -} - -func (i *uintPtrValue) String() string { - if *i.p == nil { - return "" - } - return strconv.FormatUint(uint64(**i.p), 10) -} - -func (i *uintPtrValue) IsBool() bool { - return false -} - -func (i *uintPtrValue) Type() string { - return "uint" -} - -type uint8PtrValue struct { - p **uint8 -} - -func (i *uint8PtrValue) Set(s string) error { - v, err := strconv.ParseUint(s, 10, 8) - if err != nil { - return err - } - *i.p = new(uint8) - **i.p = uint8(v) - return nil -} - -func (i *uint8PtrValue) String() string { - if *i.p == nil { - return "" - } - return strconv.FormatUint(uint64(**i.p), 10) -} - -func (i *uint8PtrValue) IsBool() bool { - return false -} - -func (i *uint8PtrValue) Type() string { - return "uint" -} - -type uint16PtrValue struct { - p **uint16 -} - -func (i *uint16PtrValue) Set(s string) error { - v, err := strconv.ParseUint(s, 10, 16) - if err != nil { - return err - } - *i.p = new(uint16) - **i.p = uint16(v) - return nil -} - -func (i *uint16PtrValue) String() string { - if *i.p == nil { - return "" - } - return strconv.FormatUint(uint64(**i.p), 10) -} - -func (i *uint16PtrValue) IsBool() bool { - return false -} - -func (i *uint16PtrValue) Type() string { - return "uint" -} - -type uint32PtrValue struct { - p **uint32 -} - -func (i *uint32PtrValue) Set(s string) error { - v, err := strconv.ParseUint(s, 10, 32) - if err != nil { - return err - } - *i.p = new(uint32) - **i.p = uint32(v) - return nil -} - -func (i *uint32PtrValue) String() string { - if *i.p == nil { - return "" - } - return strconv.FormatUint(uint64(**i.p), 10) -} - -func (i *uint32PtrValue) IsBool() bool { - return false -} - -func (i *uint32PtrValue) Type() string { - return "uint" -} - -type uint64PtrValue struct { - p **uint64 -} - -func (i *uint64PtrValue) Set(s string) error { - v, err := strconv.ParseUint(s, 10, 64) - if err != nil { - return err - } - *i.p = new(uint64) - **i.p = v - return nil -} - -func (i *uint64PtrValue) String() string { - if *i.p == nil { - return "" - } - return strconv.FormatUint(**i.p, 10) -} - -func (i *uint64PtrValue) IsBool() bool { - return false -} - -func (i *uint64PtrValue) Type() string { - return "uint" -} - -type durationPtrValue struct { - p **time.Duration -} - -func (d *durationPtrValue) Set(s string) error { - v, err := time.ParseDuration(s) - if err != nil { - return err - } - *d.p = new(time.Duration) - **d.p = v - return nil -} - -func (d *durationPtrValue) String() string { - if *d.p == nil { - return "" - } - return (**d.p).String() -} - -func (d *durationPtrValue) IsBool() bool { - return false -} - -func (d *durationPtrValue) Type() string { - return "duration" -} - -// choiceValue represents a string flag that only accepts specific values. -// It validates inputs against a predefined set of choices. -type choiceValue struct { - value *string - choices []string -} - -func (c *choiceValue) Set(s string) error { - for _, choice := range c.choices { - if s == choice { - *c.value = s - return nil - } - } - return fmt.Errorf("%w: %q (valid: %s)", ErrInvalidChoice, s, strings.Join(c.choices, ", ")) -} - -func (c *choiceValue) String() string { - if c.value == nil { - return "" - } - return *c.value -} - -func (c *choiceValue) IsBool() bool { - return false -} - -func (c *choiceValue) Type() string { - return strings.Join(c.choices, "|") -} - -// Choices returns the valid choices for this value -func (c *choiceValue) Choices() []string { - return c.choices -} - // NewFlagSet returns a new, empty flag set with the specified name. // The name is used for error messages and help output. func NewFlagSet(name string) *FlagSet { @@ -1562,7 +762,7 @@ func (f *FlagSet) parseShortFlags(shortFlags string, args []string, index *int) return fmt.Errorf("%w: -%c: %v", ErrInvalidValue, r, err) } flag.HasValue = true - + } else { // Check if there are more characters after this flag if i < len(runes)-1 { @@ -1578,7 +778,7 @@ func (f *FlagSet) parseShortFlags(shortFlags string, args []string, index *int) return fmt.Errorf("%w: -%c: %v", ErrInvalidValue, r, err) } flag.HasValue = true - + break } else if *index+1 < len(args) { value := args[*index+1] @@ -1587,7 +787,7 @@ func (f *FlagSet) parseShortFlags(shortFlags string, args []string, index *int) return fmt.Errorf("%w: -%c: %v", ErrInvalidValue, r, err) } flag.HasValue = true - + } else { return fmt.Errorf("%w: -%c", ErrMissingValue, r) } @@ -1621,713 +821,3 @@ func (f *FlagSet) AllowUnknownFlags(allow bool) { func (f *FlagSet) UnknownFlags() []string { return f.unknownFlags } - -// getTagValues extracts all values for a given key from a struct tag. -// This is needed because Go's tag.Get() only returns the first value, -// but we need to support multiple values (e.g., multiple choice tags). -func getTagValues(tag reflect.StructTag, key string) []string { - var values []string - tagStr := string(tag) - searchKey := key + `:` - - for { - idx := strings.Index(tagStr, searchKey) - if idx < 0 { - break - } - - // Move past the key and colon - tagStr = tagStr[idx+len(searchKey):] - - // Find the quoted value - if len(tagStr) == 0 || tagStr[0] != '"' { - break - } - - // Find the closing quote - endIdx := 1 - for endIdx < len(tagStr) && tagStr[endIdx] != '"' { - if tagStr[endIdx] == '\\' && endIdx+1 < len(tagStr) { - endIdx += 2 // Skip escaped character - } else { - endIdx++ - } - } - - if endIdx < len(tagStr) { - value := tagStr[1:endIdx] - values = append(values, value) - tagStr = tagStr[endIdx+1:] - } else { - break - } - } - - return values -} - -// knownTags is the set of struct tag keys that FromStruct knows how to handle. -// Keep this in sync with the Tag.Get() and getTagValues() calls in FromStruct. -var knownTags = map[string]bool{ - "long": true, - "short": true, - "default": true, - "env": true, - "required": true, - "usage": true, - "description": true, - "choice": true, - "position": true, - "rest": true, - "unknown": true, - "group": true, -} - -// validateStructTags checks that every struct tag on exported fields is one -// that FromStruct actually reads. It returns an error listing all unrecognized -// tags so the caller can fix them all in one pass. -func validateStructTags(rt reflect.Type) error { - var errs []string - for i := 0; i < rt.NumField(); i++ { - field := rt.Field(i) - - // Skip the blank identifier used for group declarations - if field.Name == "_" { - continue - } - - if !field.IsExported() { - continue - } - - // Recurse into embedded structs - ft := field.Type - if ft.Kind() == reflect.Ptr { - ft = ft.Elem() - } - if field.Anonymous && ft.Kind() == reflect.Struct { - if err := validateStructTags(ft); err != nil { - errs = append(errs, err.Error()) - } - continue - } - - // Parse the raw tag string into key:"value" pairs and check each key - tagStr := string(field.Tag) - for tagStr != "" { - // Skip leading spaces - tagStr = strings.TrimLeft(tagStr, " ") - if tagStr == "" { - break - } - - // Find the key (everything before the colon) - colon := strings.Index(tagStr, ":") - if colon < 0 { - break - } - key := tagStr[:colon] - tagStr = tagStr[colon+1:] - - // Skip past the quoted value - if len(tagStr) == 0 || tagStr[0] != '"' { - break - } - end := 1 - for end < len(tagStr) && tagStr[end] != '"' { - if tagStr[end] == '\\' && end+1 < len(tagStr) { - end += 2 - } else { - end++ - } - } - if end >= len(tagStr) { - break - } - tagStr = tagStr[end+1:] - - if !knownTags[key] { - errs = append(errs, fmt.Sprintf("unknown struct tag %q on field %s in %s", key, field.Name, rt.Name())) - } - } - } - if len(errs) > 0 { - return fmt.Errorf("invalid struct tags:\n %s", strings.Join(errs, "\n ")) - } - return nil -} - -// setFieldValue sets a string value to a reflect.Value based on its type -func setFieldValue(fieldValue reflect.Value, value string) error { - switch fieldValue.Kind() { - case reflect.String: - fieldValue.SetString(value) - case reflect.Bool: - b, err := strconv.ParseBool(value) - if err != nil { - return err - } - fieldValue.SetBool(b) - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - if fieldValue.Type() == reflect.TypeOf(time.Duration(0)) { - d, err := time.ParseDuration(value) - if err != nil { - return err - } - fieldValue.SetInt(int64(d)) - } else { - i, err := strconv.ParseInt(value, 10, fieldValue.Type().Bits()) - if err != nil { - return err - } - fieldValue.SetInt(i) - } - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - u, err := strconv.ParseUint(value, 10, fieldValue.Type().Bits()) - if err != nil { - return err - } - fieldValue.SetUint(u) - case reflect.Float32, reflect.Float64: - f, err := strconv.ParseFloat(value, fieldValue.Type().Bits()) - if err != nil { - return err - } - fieldValue.SetFloat(f) - default: - return fmt.Errorf("unsupported type: %v", fieldValue.Type()) - } - return nil -} - -// FromStructOption configures how FromStruct processes a struct. -type FromStructOption func(*fromStructConfig) - -type fromStructConfig struct { - group string -} - -// InGroup sets the group name for all flags created by FromStruct. -func InGroup(name string) FromStructOption { - return func(c *fromStructConfig) { c.group = name } -} - -// FromStruct creates flag definitions from a struct's fields using struct tags. -// The argument must be a pointer to a struct. Struct tags control how fields are parsed: -// - `long:"name"` - long flag name (defaults to lowercase field name) -// - `short:"x"` - short flag name (single character) -// - `default:"value"` - default value for the flag -// - `env:"VAR_NAME"` - populate default from an environment variable (overrides default, overridden by CLI) -// - `required:"true"` - return a parse error if the flag/positional wasn't provided -// - `usage:"description"` - usage description -// - `description:"description"` - alternate usage description -// - `choice:"value"` - constrain string field to specific values (can be repeated for multiple choices) -// - `position:"0"` - positional argument at index 0 -// - `rest:"true"` - capture all remaining arguments in a []string field -// - `unknown:"true"` - capture unknown flags in a []string field (automatically enables AllowUnknownFlags) -// - `group:"name"` - on a `_ struct{}` field, declares the group for all flags in the struct -// - `group:"name"` - on an embedded struct field, overrides the embedded struct's self-declared group -// -// Supports bool, string, int, []string, and time.Duration field types. -// Anonymous embedded structs are recursively processed. -func (f *FlagSet) FromStruct(v any, opts ...FromStructOption) error { - rv := reflect.ValueOf(v) - if rv.Kind() != reflect.Ptr || rv.IsNil() { - return fmt.Errorf("FromStruct requires a non-nil pointer to a struct") - } - - rv = rv.Elem() - if rv.Kind() != reflect.Struct { - return fmt.Errorf("FromStruct requires a pointer to a struct") - } - - // Apply options - var cfg fromStructConfig - for _, opt := range opts { - opt(&cfg) - } - - // Save and restore currentGroup - prevGroup := f.currentGroup - defer func() { f.currentGroup = prevGroup }() - - if cfg.group != "" { - f.currentGroup = cfg.group - } - - rt := rv.Type() - - // Validate that all struct tags are ones we know how to handle - if err := validateStructTags(rt); err != nil { - return err - } - - // First pass: check for self-declared group via `_ struct{} \`group:"..."\`` - for i := 0; i < rt.NumField(); i++ { - field := rt.Field(i) - if field.Name == "_" && field.Type == reflect.TypeOf(struct{}{}) { - if groupTag := field.Tag.Get("group"); groupTag != "" && f.currentGroup == "" { - f.currentGroup = groupTag - } - break - } - } - - for i := 0; i < rt.NumField(); i++ { - field := rt.Field(i) - - // Skip the `_` group declaration field - if field.Name == "_" { - continue - } - - if !field.IsExported() { - continue - } - - fieldValue := rv.Field(i) - if !fieldValue.CanAddr() { - continue - } - - // Check for anonymous/embedded struct fields and descend into them - if field.Anonymous && field.Type.Kind() == reflect.Struct { - // Check for group tag on the embedding site - if groupTag := field.Tag.Get("group"); groupTag != "" { - if err := f.FromStruct(fieldValue.Addr().Interface(), InGroup(groupTag)); err != nil { - return err - } - } else { - if err := f.FromStruct(fieldValue.Addr().Interface()); err != nil { - return err - } - } - continue - } - - // Check for "position" tag - capture positional argument - if posStr := field.Tag.Get("position"); posStr != "" { - pos, err := strconv.Atoi(posStr) - if err == nil && pos >= 0 { - // Get usage from either "usage" or "description" tag - posUsage := field.Tag.Get("usage") - if posUsage == "" { - posUsage = field.Tag.Get("description") - } - posEnvVar := field.Tag.Get("env") - posRequired := field.Tag.Get("required") == "true" - posHasValue := false - - // Environment variable sets the positional default - if posEnvVar != "" { - if envVal, ok := os.LookupEnv(posEnvVar); ok { - if err := setFieldValue(fieldValue, envVal); err != nil { - return fmt.Errorf("invalid value for env var %s: %w", posEnvVar, err) - } - posHasValue = true - } - } - - f.posFields[pos] = &PositionalField{ - Name: field.Name, - Usage: posUsage, - Value: fieldValue, - Type: field.Type, - EnvVar: posEnvVar, - Required: posRequired, - HasValue: posHasValue, - } - - if posRequired { - f.requiredPos = append(f.requiredPos, pos) - } - } - continue // Don't process position field as a flag - } - - // Check for "rest" tag - special handling for remaining arguments - if field.Tag.Get("rest") != "" { - if field.Type.Kind() == reflect.Slice && field.Type.Elem().Kind() == reflect.String { - f.restField = fieldValue.Addr().Interface().(*[]string) - } - continue // Don't process rest field as a flag - } - - // Check for "unknown" tag - special handling for unknown flags - if field.Tag.Get("unknown") != "" { - if field.Type.Kind() == reflect.Slice && field.Type.Elem().Kind() == reflect.String { - f.unknownField = fieldValue.Addr().Interface().(*[]string) - f.allowUnknownFlags = true // Automatically enable unknown flag handling - } - continue // Don't process unknown field as a flag - } - - // Parse struct tags - longName := field.Tag.Get("long") - if longName == "" { - longName = strings.ToLower(field.Name) - } - - shortName := field.Tag.Get("short") - var short rune - if shortName != "" && len(shortName) == 1 { - short = rune(shortName[0]) - } - - if longName == "" && short == 0 { - continue // No flag name provided - } - - defaultValue := field.Tag.Get("default") - envVar := field.Tag.Get("env") - required := field.Tag.Get("required") == "true" - - usage := field.Tag.Get("usage") - if usage == "" { - usage = field.Tag.Get("description") - if usage == "" { - usage = fmt.Sprintf("%s value", field.Name) - } - } - - if required { - f.requiredFlags = append(f.requiredFlags, longName) - } - - // Register the flag based on field type - switch field.Type.Kind() { - case reflect.Bool: - var defVal bool - if defaultValue != "" { - defVal, _ = strconv.ParseBool(defaultValue) - } - f.BoolVar(fieldValue.Addr().Interface().(*bool), longName, short, defVal, usage) - - case reflect.String: - // Check for choice tags - if present, use ChoiceVar - choices := getTagValues(field.Tag, "choice") - if len(choices) > 0 { - f.ChoiceVar(fieldValue.Addr().Interface().(*string), longName, short, defaultValue, choices, usage) - } else { - f.StringVar(fieldValue.Addr().Interface().(*string), longName, short, defaultValue, usage) - } - - case reflect.Int: - var defVal int - if defaultValue != "" { - defVal, _ = strconv.Atoi(defaultValue) - } - f.IntVar(fieldValue.Addr().Interface().(*int), longName, short, defVal, usage) - - case reflect.Slice: - switch field.Type.Elem().Kind() { - case reflect.String: - var defVal []string - if defaultValue != "" { - defVal = strings.Split(defaultValue, ",") - } - f.StringArrayVar(fieldValue.Addr().Interface().(*[]string), longName, short, defVal, usage) - case reflect.Bool: - f.BoolArrayVar(fieldValue.Addr().Interface().(*[]bool), longName, short, usage) - case reflect.Int: - f.IntArrayVar(fieldValue.Addr().Interface().(*[]int), longName, short, usage) - } - - case reflect.Int64: - // Check if it's a time.Duration - if field.Type == reflect.TypeOf(time.Duration(0)) { - var defVal time.Duration - if defaultValue != "" { - defVal, _ = time.ParseDuration(defaultValue) - } - f.DurationVar(fieldValue.Addr().Interface().(*time.Duration), longName, short, defVal, usage) - } else { - var defVal int64 - if defaultValue != "" { - defVal, _ = strconv.ParseInt(defaultValue, 10, 64) - } - f.Int64Var(fieldValue.Addr().Interface().(*int64), longName, short, defVal, usage) - } - - case reflect.Int8: - var defVal int8 - if defaultValue != "" { - v, _ := strconv.ParseInt(defaultValue, 10, 8) - defVal = int8(v) - } - f.Int8Var(fieldValue.Addr().Interface().(*int8), longName, short, defVal, usage) - - case reflect.Int16: - var defVal int16 - if defaultValue != "" { - v, _ := strconv.ParseInt(defaultValue, 10, 16) - defVal = int16(v) - } - f.Int16Var(fieldValue.Addr().Interface().(*int16), longName, short, defVal, usage) - - case reflect.Int32: - var defVal int32 - if defaultValue != "" { - v, _ := strconv.ParseInt(defaultValue, 10, 32) - defVal = int32(v) - } - f.Int32Var(fieldValue.Addr().Interface().(*int32), longName, short, defVal, usage) - - case reflect.Uint: - var defVal uint - if defaultValue != "" { - v, _ := strconv.ParseUint(defaultValue, 10, 64) - defVal = uint(v) - } - f.UintVar(fieldValue.Addr().Interface().(*uint), longName, short, defVal, usage) - - case reflect.Uint8: - var defVal uint8 - if defaultValue != "" { - v, _ := strconv.ParseUint(defaultValue, 10, 8) - defVal = uint8(v) - } - f.Uint8Var(fieldValue.Addr().Interface().(*uint8), longName, short, defVal, usage) - - case reflect.Uint16: - var defVal uint16 - if defaultValue != "" { - v, _ := strconv.ParseUint(defaultValue, 10, 16) - defVal = uint16(v) - } - f.Uint16Var(fieldValue.Addr().Interface().(*uint16), longName, short, defVal, usage) - - case reflect.Uint32: - var defVal uint32 - if defaultValue != "" { - v, _ := strconv.ParseUint(defaultValue, 10, 32) - defVal = uint32(v) - } - f.Uint32Var(fieldValue.Addr().Interface().(*uint32), longName, short, defVal, usage) - - case reflect.Uint64: - var defVal uint64 - if defaultValue != "" { - defVal, _ = strconv.ParseUint(defaultValue, 10, 64) - } - f.Uint64Var(fieldValue.Addr().Interface().(*uint64), longName, short, defVal, usage) - - case reflect.Ptr: - // Handle pointer types - allows distinguishing "not set" from "zero value" - elemKind := field.Type.Elem().Kind() - switch elemKind { - case reflect.Bool: - p := fieldValue.Addr().Interface().(**bool) - f.Var(&boolPtrValue{p: p}, longName, short, usage) - case reflect.String: - p := fieldValue.Addr().Interface().(**string) - f.Var(&stringPtrValue{p: p}, longName, short, usage) - case reflect.Int: - p := fieldValue.Addr().Interface().(**int) - f.Var(&intPtrValue{p: p}, longName, short, usage) - case reflect.Int64: - // Check if it's a *time.Duration - if field.Type.Elem() == reflect.TypeOf(time.Duration(0)) { - p := fieldValue.Addr().Interface().(**time.Duration) - f.Var(&durationPtrValue{p: p}, longName, short, usage) - } else { - p := fieldValue.Addr().Interface().(**int64) - f.Var(&int64PtrValue{p: p}, longName, short, usage) - } - case reflect.Int8: - p := fieldValue.Addr().Interface().(**int8) - f.Var(&int8PtrValue{p: p}, longName, short, usage) - case reflect.Int16: - p := fieldValue.Addr().Interface().(**int16) - f.Var(&int16PtrValue{p: p}, longName, short, usage) - case reflect.Int32: - p := fieldValue.Addr().Interface().(**int32) - f.Var(&int32PtrValue{p: p}, longName, short, usage) - case reflect.Uint: - p := fieldValue.Addr().Interface().(**uint) - f.Var(&uintPtrValue{p: p}, longName, short, usage) - case reflect.Uint8: - p := fieldValue.Addr().Interface().(**uint8) - f.Var(&uint8PtrValue{p: p}, longName, short, usage) - case reflect.Uint16: - p := fieldValue.Addr().Interface().(**uint16) - f.Var(&uint16PtrValue{p: p}, longName, short, usage) - case reflect.Uint32: - p := fieldValue.Addr().Interface().(**uint32) - f.Var(&uint32PtrValue{p: p}, longName, short, usage) - case reflect.Uint64: - p := fieldValue.Addr().Interface().(**uint64) - f.Var(&uint64PtrValue{p: p}, longName, short, usage) - } - } - - // Set env/required metadata on the registered flag, and apply - // env var value through the flag's Value.Set path so it gets - // validated and works for all types including pointers and slices. - if flag, ok := f.flags[longName]; ok { - flag.EnvVar = envVar - flag.Required = required - if envVar != "" { - if envVal, ok := os.LookupEnv(envVar); ok { - if err := flag.Value.Set(envVal); err != nil { - return fmt.Errorf("invalid value for env var %s: %w", envVar, err) - } - flag.HasValue = true - } - } - } - } - - return nil -} - -// formatFlagLine formats a single flag for help output. -func formatFlagLine(flag *Flag) string { - var flagStr string - if flag.Short != 0 && flag.Name != "" { - flagStr = fmt.Sprintf(" -%c, --%s", flag.Short, flag.Name) - } else if flag.Short != 0 { - flagStr = fmt.Sprintf(" -%c", flag.Short) - } else { - flagStr = fmt.Sprintf(" --%s", flag.Name) - } - - // Add value placeholder for non-boolean flags - if !flag.Value.IsBool() { - flagStr += fmt.Sprintf(" <%s>", flag.Value.Type()) - } - - // Format with usage - if flag.Usage != "" { - line := fmt.Sprintf("%-30s %s", flagStr, flag.Usage) - if flag.DefValue != "" && flag.DefValue != "false" && flag.DefValue != "0" { - line += fmt.Sprintf(" (default: %s)", flag.DefValue) - } - if flag.EnvVar != "" { - line += fmt.Sprintf(" (env: %s)", flag.EnvVar) - } - return line - } - return flagStr -} - -// WriteFlagHelp writes group-aware flag help output to stdout. -// Named groups are printed first in insertion order, followed by the default -// (unnamed) group under "Options:". -func (f *FlagSet) WriteFlagHelp() { - // Collect flags into groups, preserving VisitAll sort order - type groupFlags struct { - name string - flags []*Flag - } - - groupMap := make(map[string]*groupFlags) - var defaultFlags []*Flag - - f.VisitAll(func(flag *Flag) { - if flag.Group == "" { - defaultFlags = append(defaultFlags, flag) - } else { - gf, ok := groupMap[flag.Group] - if !ok { - gf = &groupFlags{name: flag.Group} - groupMap[flag.Group] = gf - } - gf.flags = append(gf.flags, flag) - } - }) - - // Print named groups in insertion order - for _, groupName := range f.groupOrder { - gf, ok := groupMap[groupName] - if !ok || len(gf.flags) == 0 { - continue - } - fmt.Printf("\n%s:\n", gf.name) - for _, flag := range gf.flags { - fmt.Println(formatFlagLine(flag)) - } - } - - // Print default group last - if len(defaultFlags) > 0 { - fmt.Println("\nOptions:") - for _, flag := range defaultFlags { - fmt.Println(formatFlagLine(flag)) - } - } -} - -// ShowHelp displays help information for the flag set, including all defined flags -// and their usage information. -func (f *FlagSet) ShowHelp() { - if f.name != "" { - fmt.Printf("Usage: %s [options]", f.name) - // Show positional arguments by name - if len(f.posFields) > 0 { - // Find max position to iterate in order - maxPos := -1 - for pos := range f.posFields { - if pos > maxPos { - maxPos = pos - } - } - // Print each positional argument name - for i := 0; i <= maxPos; i++ { - if field, ok := f.posFields[i]; ok { - fmt.Printf(" <%s>", strings.ToLower(field.Name)) - } - } - } - if f.restField != nil { - fmt.Print(" [arguments...]") - } - fmt.Println() - } - - // Show positional arguments with descriptions if any have usage text - if len(f.posFields) > 0 { - // Find max position to iterate in order - maxPos := -1 - hasUsage := false - for pos, field := range f.posFields { - if pos > maxPos { - maxPos = pos - } - if field.Usage != "" { - hasUsage = true - } - } - - if hasUsage { - fmt.Println("\nArguments:") - for i := 0; i <= maxPos; i++ { - if field, ok := f.posFields[i]; ok { - argStr := fmt.Sprintf(" <%s>", strings.ToLower(field.Name)) - if field.Usage != "" { - line := fmt.Sprintf("%-30s %s", argStr, field.Usage) - if field.EnvVar != "" { - line += fmt.Sprintf(" (env: %s)", field.EnvVar) - } - fmt.Println(line) - } else { - fmt.Println(argStr) - } - } - } - } - } - - f.WriteFlagHelp() -} - -// ParseStruct parses command line arguments and updates the struct fields. -// This is a convenience function that creates a FlagSet, calls FromStruct, and parses the arguments. -// See FromStruct for documentation on supported struct tags and field types. -func ParseStruct(v any, arguments []string) error { - fs := NewFlagSet("") - if err := fs.FromStruct(v); err != nil { - return err - } - return fs.Parse(arguments) -} diff --git a/mflags_test.go b/mflags_test.go index 992ba39..1a1eb13 100644 --- a/mflags_test.go +++ b/mflags_test.go @@ -921,7 +921,7 @@ func TestFromStructRejectsUnknownTags(t *testing.T) { t.Run("unexported fields with arbitrary tags are ignored", func(t *testing.T) { type Opts struct { Verbose bool `long:"verbose"` - internal string `json:"internal" xml:"internal"` + internal string //nolint:unused // present to verify unexported fields are skipped } fs := NewFlagSet("test") err := fs.FromStruct(&Opts{}) diff --git a/values.go b/values.go new file mode 100644 index 0000000..d9699a9 --- /dev/null +++ b/values.go @@ -0,0 +1,806 @@ +package mflags + +import ( + "fmt" + "strconv" + "strings" + "time" +) + +type boolValue bool + +func (b *boolValue) Set(s string) error { + v, err := strconv.ParseBool(s) + if err != nil { + return err + } + *b = boolValue(v) + return nil +} + +func (b *boolValue) String() string { + return strconv.FormatBool(bool(*b)) +} + +func (b *boolValue) IsBool() bool { + return true +} + +func (b *boolValue) Type() string { + return "bool" +} + +type boolArrayValue []bool + +func (b *boolArrayValue) Set(s string) error { + v, err := strconv.ParseBool(s) + if err != nil { + return err + } + *b = append(*b, v) + return nil +} + +func (b *boolArrayValue) String() string { + if len(*b) == 0 { + return "" + } + strs := make([]string, len(*b)) + for i, v := range *b { + strs[i] = strconv.FormatBool(v) + } + return strings.Join(strs, ",") +} + +func (b *boolArrayValue) IsBool() bool { + return true +} + +func (b *boolArrayValue) Type() string { + return "bool" +} + +type stringValue string + +func (s *stringValue) Set(val string) error { + *s = stringValue(val) + return nil +} + +func (s *stringValue) String() string { + return string(*s) +} + +func (s *stringValue) IsBool() bool { + return false +} + +func (s *stringValue) Type() string { + return "string" +} + +type intValue int + +func (i *intValue) Set(s string) error { + v, err := strconv.Atoi(s) + if err != nil { + return err + } + *i = intValue(v) + return nil +} + +func (i *intValue) String() string { + return strconv.Itoa(int(*i)) +} + +func (i *intValue) IsBool() bool { + return false +} + +func (i *intValue) Type() string { + return "int" +} + +type int64Value int64 + +func (i *int64Value) Set(s string) error { + v, err := strconv.ParseInt(s, 10, 64) + if err != nil { + return err + } + *i = int64Value(v) + return nil +} + +func (i *int64Value) String() string { + return strconv.FormatInt(int64(*i), 10) +} + +func (i *int64Value) IsBool() bool { + return false +} + +func (i *int64Value) Type() string { + return "int" +} + +type int8Value int8 + +func (i *int8Value) Set(s string) error { + v, err := strconv.ParseInt(s, 10, 8) + if err != nil { + return err + } + *i = int8Value(v) + return nil +} + +func (i *int8Value) String() string { + return strconv.FormatInt(int64(*i), 10) +} + +func (i *int8Value) IsBool() bool { + return false +} + +func (i *int8Value) Type() string { + return "int" +} + +type int16Value int16 + +func (i *int16Value) Set(s string) error { + v, err := strconv.ParseInt(s, 10, 16) + if err != nil { + return err + } + *i = int16Value(v) + return nil +} + +func (i *int16Value) String() string { + return strconv.FormatInt(int64(*i), 10) +} + +func (i *int16Value) IsBool() bool { + return false +} + +func (i *int16Value) Type() string { + return "int" +} + +type int32Value int32 + +func (i *int32Value) Set(s string) error { + v, err := strconv.ParseInt(s, 10, 32) + if err != nil { + return err + } + *i = int32Value(v) + return nil +} + +func (i *int32Value) String() string { + return strconv.FormatInt(int64(*i), 10) +} + +func (i *int32Value) IsBool() bool { + return false +} + +func (i *int32Value) Type() string { + return "int" +} + +type uintValue uint + +func (i *uintValue) Set(s string) error { + v, err := strconv.ParseUint(s, 10, 64) + if err != nil { + return err + } + *i = uintValue(v) + return nil +} + +func (i *uintValue) String() string { + return strconv.FormatUint(uint64(*i), 10) +} + +func (i *uintValue) IsBool() bool { + return false +} + +func (i *uintValue) Type() string { + return "uint" +} + +type uint8Value uint8 + +func (i *uint8Value) Set(s string) error { + v, err := strconv.ParseUint(s, 10, 8) + if err != nil { + return err + } + *i = uint8Value(v) + return nil +} + +func (i *uint8Value) String() string { + return strconv.FormatUint(uint64(*i), 10) +} + +func (i *uint8Value) IsBool() bool { + return false +} + +func (i *uint8Value) Type() string { + return "uint" +} + +type uint16Value uint16 + +func (i *uint16Value) Set(s string) error { + v, err := strconv.ParseUint(s, 10, 16) + if err != nil { + return err + } + *i = uint16Value(v) + return nil +} + +func (i *uint16Value) String() string { + return strconv.FormatUint(uint64(*i), 10) +} + +func (i *uint16Value) IsBool() bool { + return false +} + +func (i *uint16Value) Type() string { + return "uint" +} + +type uint32Value uint32 + +func (i *uint32Value) Set(s string) error { + v, err := strconv.ParseUint(s, 10, 32) + if err != nil { + return err + } + *i = uint32Value(v) + return nil +} + +func (i *uint32Value) String() string { + return strconv.FormatUint(uint64(*i), 10) +} + +func (i *uint32Value) IsBool() bool { + return false +} + +func (i *uint32Value) Type() string { + return "uint" +} + +type uint64Value uint64 + +func (i *uint64Value) Set(s string) error { + v, err := strconv.ParseUint(s, 10, 64) + if err != nil { + return err + } + *i = uint64Value(v) + return nil +} + +func (i *uint64Value) String() string { + return strconv.FormatUint(uint64(*i), 10) +} + +func (i *uint64Value) IsBool() bool { + return false +} + +func (i *uint64Value) Type() string { + return "uint" +} + +type intArrayValue []int + +func (i *intArrayValue) Set(s string) error { + v, err := strconv.Atoi(s) + if err != nil { + return err + } + *i = append(*i, v) + return nil +} + +func (i *intArrayValue) String() string { + if len(*i) == 0 { + return "" + } + strs := make([]string, len(*i)) + for idx, v := range *i { + strs[idx] = strconv.Itoa(v) + } + return strings.Join(strs, ",") +} + +func (i *intArrayValue) IsBool() bool { + return false +} + +func (i *intArrayValue) Type() string { + return "int" +} + +type stringArrayValue struct { + values *[]string + hasBeenSet bool +} + +func (s *stringArrayValue) Set(val string) error { + // On first Set call, clear any default values + if !s.hasBeenSet { + *s.values = nil + s.hasBeenSet = true + } + *s.values = append(*s.values, strings.Split(val, ",")...) + return nil +} + +func (s *stringArrayValue) String() string { + if s.values == nil { + return "" + } + return strings.Join(*s.values, ",") +} + +func (s *stringArrayValue) IsBool() bool { + return false +} + +func (s *stringArrayValue) Type() string { + return "value,..." +} + +type durationValue time.Duration + +func (d *durationValue) Set(s string) error { + v, err := time.ParseDuration(s) + if err != nil { + return err + } + *d = durationValue(v) + return nil +} + +func (d *durationValue) String() string { + return time.Duration(*d).String() +} + +func (d *durationValue) IsBool() bool { + return false +} + +func (d *durationValue) Type() string { + return "duration" +} + +// Pointer value types - these allocate the pointed-to value on first Set, +// allowing code to distinguish between "not set" (nil) and "set to zero value" + +type boolPtrValue struct { + p **bool +} + +func (b *boolPtrValue) Set(s string) error { + v, err := strconv.ParseBool(s) + if err != nil { + return err + } + *b.p = new(bool) + **b.p = v + return nil +} + +func (b *boolPtrValue) String() string { + if *b.p == nil { + return "" + } + return strconv.FormatBool(**b.p) +} + +func (b *boolPtrValue) IsBool() bool { + return true +} + +func (b *boolPtrValue) Type() string { + return "bool" +} + +type stringPtrValue struct { + p **string +} + +func (s *stringPtrValue) Set(val string) error { + *s.p = new(string) + **s.p = val + return nil +} + +func (s *stringPtrValue) String() string { + if *s.p == nil { + return "" + } + return **s.p +} + +func (s *stringPtrValue) IsBool() bool { + return false +} + +func (s *stringPtrValue) Type() string { + return "string" +} + +type intPtrValue struct { + p **int +} + +func (i *intPtrValue) Set(s string) error { + v, err := strconv.Atoi(s) + if err != nil { + return err + } + *i.p = new(int) + **i.p = v + return nil +} + +func (i *intPtrValue) String() string { + if *i.p == nil { + return "" + } + return strconv.Itoa(**i.p) +} + +func (i *intPtrValue) IsBool() bool { + return false +} + +func (i *intPtrValue) Type() string { + return "int" +} + +type int64PtrValue struct { + p **int64 +} + +func (i *int64PtrValue) Set(s string) error { + v, err := strconv.ParseInt(s, 10, 64) + if err != nil { + return err + } + *i.p = new(int64) + **i.p = v + return nil +} + +func (i *int64PtrValue) String() string { + if *i.p == nil { + return "" + } + return strconv.FormatInt(**i.p, 10) +} + +func (i *int64PtrValue) IsBool() bool { + return false +} + +func (i *int64PtrValue) Type() string { + return "int" +} + +type int8PtrValue struct { + p **int8 +} + +func (i *int8PtrValue) Set(s string) error { + v, err := strconv.ParseInt(s, 10, 8) + if err != nil { + return err + } + *i.p = new(int8) + **i.p = int8(v) + return nil +} + +func (i *int8PtrValue) String() string { + if *i.p == nil { + return "" + } + return strconv.FormatInt(int64(**i.p), 10) +} + +func (i *int8PtrValue) IsBool() bool { + return false +} + +func (i *int8PtrValue) Type() string { + return "int" +} + +type int16PtrValue struct { + p **int16 +} + +func (i *int16PtrValue) Set(s string) error { + v, err := strconv.ParseInt(s, 10, 16) + if err != nil { + return err + } + *i.p = new(int16) + **i.p = int16(v) + return nil +} + +func (i *int16PtrValue) String() string { + if *i.p == nil { + return "" + } + return strconv.FormatInt(int64(**i.p), 10) +} + +func (i *int16PtrValue) IsBool() bool { + return false +} + +func (i *int16PtrValue) Type() string { + return "int" +} + +type int32PtrValue struct { + p **int32 +} + +func (i *int32PtrValue) Set(s string) error { + v, err := strconv.ParseInt(s, 10, 32) + if err != nil { + return err + } + *i.p = new(int32) + **i.p = int32(v) + return nil +} + +func (i *int32PtrValue) String() string { + if *i.p == nil { + return "" + } + return strconv.FormatInt(int64(**i.p), 10) +} + +func (i *int32PtrValue) IsBool() bool { + return false +} + +func (i *int32PtrValue) Type() string { + return "int" +} + +type uintPtrValue struct { + p **uint +} + +func (i *uintPtrValue) Set(s string) error { + v, err := strconv.ParseUint(s, 10, 64) + if err != nil { + return err + } + *i.p = new(uint) + **i.p = uint(v) + return nil +} + +func (i *uintPtrValue) String() string { + if *i.p == nil { + return "" + } + return strconv.FormatUint(uint64(**i.p), 10) +} + +func (i *uintPtrValue) IsBool() bool { + return false +} + +func (i *uintPtrValue) Type() string { + return "uint" +} + +type uint8PtrValue struct { + p **uint8 +} + +func (i *uint8PtrValue) Set(s string) error { + v, err := strconv.ParseUint(s, 10, 8) + if err != nil { + return err + } + *i.p = new(uint8) + **i.p = uint8(v) + return nil +} + +func (i *uint8PtrValue) String() string { + if *i.p == nil { + return "" + } + return strconv.FormatUint(uint64(**i.p), 10) +} + +func (i *uint8PtrValue) IsBool() bool { + return false +} + +func (i *uint8PtrValue) Type() string { + return "uint" +} + +type uint16PtrValue struct { + p **uint16 +} + +func (i *uint16PtrValue) Set(s string) error { + v, err := strconv.ParseUint(s, 10, 16) + if err != nil { + return err + } + *i.p = new(uint16) + **i.p = uint16(v) + return nil +} + +func (i *uint16PtrValue) String() string { + if *i.p == nil { + return "" + } + return strconv.FormatUint(uint64(**i.p), 10) +} + +func (i *uint16PtrValue) IsBool() bool { + return false +} + +func (i *uint16PtrValue) Type() string { + return "uint" +} + +type uint32PtrValue struct { + p **uint32 +} + +func (i *uint32PtrValue) Set(s string) error { + v, err := strconv.ParseUint(s, 10, 32) + if err != nil { + return err + } + *i.p = new(uint32) + **i.p = uint32(v) + return nil +} + +func (i *uint32PtrValue) String() string { + if *i.p == nil { + return "" + } + return strconv.FormatUint(uint64(**i.p), 10) +} + +func (i *uint32PtrValue) IsBool() bool { + return false +} + +func (i *uint32PtrValue) Type() string { + return "uint" +} + +type uint64PtrValue struct { + p **uint64 +} + +func (i *uint64PtrValue) Set(s string) error { + v, err := strconv.ParseUint(s, 10, 64) + if err != nil { + return err + } + *i.p = new(uint64) + **i.p = v + return nil +} + +func (i *uint64PtrValue) String() string { + if *i.p == nil { + return "" + } + return strconv.FormatUint(**i.p, 10) +} + +func (i *uint64PtrValue) IsBool() bool { + return false +} + +func (i *uint64PtrValue) Type() string { + return "uint" +} + +type durationPtrValue struct { + p **time.Duration +} + +func (d *durationPtrValue) Set(s string) error { + v, err := time.ParseDuration(s) + if err != nil { + return err + } + *d.p = new(time.Duration) + **d.p = v + return nil +} + +func (d *durationPtrValue) String() string { + if *d.p == nil { + return "" + } + return (**d.p).String() +} + +func (d *durationPtrValue) IsBool() bool { + return false +} + +func (d *durationPtrValue) Type() string { + return "duration" +} + +// choiceValue represents a string flag that only accepts specific values. +// It validates inputs against a predefined set of choices. +type choiceValue struct { + value *string + choices []string +} + +func (c *choiceValue) Set(s string) error { + for _, choice := range c.choices { + if s == choice { + *c.value = s + return nil + } + } + return fmt.Errorf("%w: %q (valid: %s)", ErrInvalidChoice, s, strings.Join(c.choices, ", ")) +} + +func (c *choiceValue) String() string { + if c.value == nil { + return "" + } + return *c.value +} + +func (c *choiceValue) IsBool() bool { + return false +} + +func (c *choiceValue) Type() string { + return strings.Join(c.choices, "|") +} + +// Choices returns the valid choices for this value +func (c *choiceValue) Choices() []string { + return c.choices +}