diff --git a/mflags.go b/mflags.go index ef2e814..d5a9c2f 100644 --- a/mflags.go +++ b/mflags.go @@ -1614,6 +1614,95 @@ func getTagValues(tag reflect.StructTag, key string) []string { return values } +// knownTags is the set of struct tag keys that FromStruct knows how to handle. +// Keep this in sync with the Tag.Get() and getTagValues() calls in FromStruct. +var knownTags = map[string]bool{ + "long": true, + "short": true, + "default": true, + "usage": true, + "description": true, + "choice": true, + "position": true, + "rest": true, + "unknown": true, + "group": true, +} + +// validateStructTags checks that every struct tag on exported fields is one +// that FromStruct actually reads. It returns an error listing all unrecognized +// tags so the caller can fix them all in one pass. +func validateStructTags(rt reflect.Type) error { + var errs []string + for i := 0; i < rt.NumField(); i++ { + field := rt.Field(i) + + // Skip the blank identifier used for group declarations + if field.Name == "_" { + continue + } + + if !field.IsExported() { + continue + } + + // Recurse into embedded structs + ft := field.Type + if ft.Kind() == reflect.Ptr { + ft = ft.Elem() + } + if field.Anonymous && ft.Kind() == reflect.Struct { + if err := validateStructTags(ft); err != nil { + errs = append(errs, err.Error()) + } + continue + } + + // Parse the raw tag string into key:"value" pairs and check each key + tagStr := string(field.Tag) + for tagStr != "" { + // Skip leading spaces + tagStr = strings.TrimLeft(tagStr, " ") + if tagStr == "" { + break + } + + // Find the key (everything before the colon) + colon := strings.Index(tagStr, ":") + if colon < 0 { + break + } + key := tagStr[:colon] + tagStr = tagStr[colon+1:] + + // Skip past the quoted value + if len(tagStr) == 0 || tagStr[0] != '"' { + break + } + end := 1 + for end < len(tagStr) && tagStr[end] != '"' { + if tagStr[end] == '\\' && end+1 < len(tagStr) { + end += 2 + } else { + end++ + } + } + if end >= len(tagStr) { + break + } + tagStr = tagStr[end+1:] + + if !knownTags[key] { + errs = append(errs, fmt.Sprintf("unknown struct tag %q on field %s in %s", key, field.Name, rt.Name())) + } + } + } + if len(errs) > 0 { + return fmt.Errorf("invalid struct tags:\n %s", strings.Join(errs, "\n ")) + } + return nil +} + // setFieldValue sets a string value to a reflect.Value based on its type func setFieldValue(fieldValue reflect.Value, value string) error { switch fieldValue.Kind() { @@ -1712,6 +1801,11 @@ func (f *FlagSet) FromStruct(v any, opts ...FromStructOption) error { rt := rv.Type() + // Validate that all struct tags are ones we know how to handle + if err := validateStructTags(rt); err != nil { + return err + } + // First pass: check for self-declared group via `_ struct{} \`group:"..."\`` for i := 0; i < rt.NumField(); i++ { field := rt.Field(i) diff --git a/mflags_test.go b/mflags_test.go index dc11738..de58a33 100644 --- a/mflags_test.go +++ b/mflags_test.go @@ -859,6 +859,90 @@ func TestFromStructErrors(t *testing.T) { assert.Contains(t, err.Error(), "pointer to a struct") } +func TestFromStructRejectsUnknownTags(t *testing.T) { + t.Run("single unknown tag", func(t *testing.T) { + type Opts struct { + Args struct{} `positional-args:"yes"` + } + fs := NewFlagSet("test") + err := fs.FromStruct(&Opts{}) + assert.Error(t, err) + assert.Contains(t, err.Error(), `unknown struct tag "positional-args"`) + assert.Contains(t, err.Error(), "Args") + }) + + t.Run("multiple unknown tags across fields", func(t *testing.T) { + type Opts struct { + Name string `long:"name" env:"MY_NAME"` + Port int `long:"port" required:"true"` + } + fs := NewFlagSet("test") + err := fs.FromStruct(&Opts{}) + assert.Error(t, err) + assert.Contains(t, err.Error(), `"env"`) + assert.Contains(t, err.Error(), `"required"`) + }) + + t.Run("non-mflags tag like json", func(t *testing.T) { + type Opts struct { + Name string `json:"name"` + } + fs := NewFlagSet("test") + err := fs.FromStruct(&Opts{}) + assert.Error(t, err) + assert.Contains(t, err.Error(), `"json"`) + }) + + t.Run("embedded struct with unknown tags", func(t *testing.T) { + type Inner struct { + Addr string `long:"addr" yaml:"addr"` + } + type Outer struct { + Inner + Verbose bool `long:"verbose"` + } + fs := NewFlagSet("test") + err := fs.FromStruct(&Outer{}) + assert.Error(t, err) + assert.Contains(t, err.Error(), `"yaml"`) + assert.Contains(t, err.Error(), "Addr") + }) + + t.Run("unexported fields with arbitrary tags are ignored", func(t *testing.T) { + type Opts struct { + Verbose bool `long:"verbose"` + internal string `json:"internal" xml:"internal"` + } + fs := NewFlagSet("test") + err := fs.FromStruct(&Opts{}) + assert.NoError(t, err) + }) + + t.Run("valid tags pass", func(t *testing.T) { + type Opts struct { + Name string `long:"name" short:"n" default:"world" description:"Your name"` + Verbose bool `long:"verbose" usage:"Enable verbose"` + Env string `long:"env" choice:"dev" choice:"prod"` + File string `position:"0"` + Rest []string `rest:"true"` + Unknown []string `unknown:"true"` + } + fs := NewFlagSet("test") + err := fs.FromStruct(&Opts{}) + assert.NoError(t, err) + }) + + t.Run("group tag on blank field is allowed", func(t *testing.T) { + type Opts struct { + _ struct{} `group:"Server Options"` + Verbose bool `long:"verbose"` + } + fs := NewFlagSet("test") + err := fs.FromStruct(&Opts{}) + assert.NoError(t, err) + }) +} + type CombinedUsageConfig struct { Verbose bool `long:"verbose" short:"v"` Files []string `long:"files" short:"f"`