Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 94 additions & 0 deletions mflags.go
Original file line number Diff line number Diff line change
Expand Up @@ -1614,6 +1614,95 @@ func getTagValues(tag reflect.StructTag, key string) []string {
return values
}

// knownTags is the set of struct tag keys that FromStruct knows how to handle.
// Keep this in sync with the Tag.Get() and getTagValues() calls in FromStruct.
var knownTags = map[string]bool{
"long": true,
"short": true,
"default": true,
"usage": true,
"description": true,
"choice": true,
"position": true,
"rest": true,
"unknown": true,
"group": true,
}

// validateStructTags checks that every struct tag on exported fields is one
// that FromStruct actually reads. It returns an error listing all unrecognized
// tags so the caller can fix them all in one pass.
func validateStructTags(rt reflect.Type) error {
var errs []string
for i := 0; i < rt.NumField(); i++ {
field := rt.Field(i)

// Skip the blank identifier used for group declarations
if field.Name == "_" {
continue
}

if !field.IsExported() {
continue
}

// Recurse into embedded structs
ft := field.Type
if ft.Kind() == reflect.Ptr {
ft = ft.Elem()
}
if field.Anonymous && ft.Kind() == reflect.Struct {
if err := validateStructTags(ft); err != nil {
errs = append(errs, err.Error())
}
continue
}

// Parse the raw tag string into key:"value" pairs and check each key
tagStr := string(field.Tag)
for tagStr != "" {
// Skip leading spaces
tagStr = strings.TrimLeft(tagStr, " ")
if tagStr == "" {
break
}

// Find the key (everything before the colon)
colon := strings.Index(tagStr, ":")
if colon < 0 {
break
}
key := tagStr[:colon]
tagStr = tagStr[colon+1:]

// Skip past the quoted value
if len(tagStr) == 0 || tagStr[0] != '"' {
break
}
end := 1
for end < len(tagStr) && tagStr[end] != '"' {
if tagStr[end] == '\\' && end+1 < len(tagStr) {
end += 2
} else {
end++
}
}
if end >= len(tagStr) {
break
}
tagStr = tagStr[end+1:]

if !knownTags[key] {
errs = append(errs, fmt.Sprintf("unknown struct tag %q on field %s in %s", key, field.Name, rt.Name()))
}
}
}
if len(errs) > 0 {
return fmt.Errorf("invalid struct tags:\n %s", strings.Join(errs, "\n "))
}
return nil
}

// setFieldValue sets a string value to a reflect.Value based on its type
func setFieldValue(fieldValue reflect.Value, value string) error {
switch fieldValue.Kind() {
Expand Down Expand Up @@ -1712,6 +1801,11 @@ func (f *FlagSet) FromStruct(v any, opts ...FromStructOption) error {

rt := rv.Type()

// Validate that all struct tags are ones we know how to handle
if err := validateStructTags(rt); err != nil {
return err
}

// First pass: check for self-declared group via `_ struct{} \`group:"..."\``
for i := 0; i < rt.NumField(); i++ {
field := rt.Field(i)
Expand Down
84 changes: 84 additions & 0 deletions mflags_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -859,6 +859,90 @@ func TestFromStructErrors(t *testing.T) {
assert.Contains(t, err.Error(), "pointer to a struct")
}

func TestFromStructRejectsUnknownTags(t *testing.T) {
t.Run("single unknown tag", func(t *testing.T) {
type Opts struct {
Args struct{} `positional-args:"yes"`
}
fs := NewFlagSet("test")
err := fs.FromStruct(&Opts{})
assert.Error(t, err)
assert.Contains(t, err.Error(), `unknown struct tag "positional-args"`)
assert.Contains(t, err.Error(), "Args")
})

t.Run("multiple unknown tags across fields", func(t *testing.T) {
type Opts struct {
Name string `long:"name" env:"MY_NAME"`
Port int `long:"port" required:"true"`
}
fs := NewFlagSet("test")
err := fs.FromStruct(&Opts{})
assert.Error(t, err)
assert.Contains(t, err.Error(), `"env"`)
assert.Contains(t, err.Error(), `"required"`)
})

t.Run("non-mflags tag like json", func(t *testing.T) {
type Opts struct {
Name string `json:"name"`
}
fs := NewFlagSet("test")
err := fs.FromStruct(&Opts{})
assert.Error(t, err)
assert.Contains(t, err.Error(), `"json"`)
})

t.Run("embedded struct with unknown tags", func(t *testing.T) {
type Inner struct {
Addr string `long:"addr" yaml:"addr"`
}
type Outer struct {
Inner
Verbose bool `long:"verbose"`
}
fs := NewFlagSet("test")
err := fs.FromStruct(&Outer{})
assert.Error(t, err)
assert.Contains(t, err.Error(), `"yaml"`)
assert.Contains(t, err.Error(), "Addr")
})

t.Run("unexported fields with arbitrary tags are ignored", func(t *testing.T) {
type Opts struct {
Verbose bool `long:"verbose"`
internal string `json:"internal" xml:"internal"`
}
fs := NewFlagSet("test")
err := fs.FromStruct(&Opts{})
assert.NoError(t, err)
})

t.Run("valid tags pass", func(t *testing.T) {
type Opts struct {
Name string `long:"name" short:"n" default:"world" description:"Your name"`
Verbose bool `long:"verbose" usage:"Enable verbose"`
Env string `long:"env" choice:"dev" choice:"prod"`
File string `position:"0"`
Rest []string `rest:"true"`
Unknown []string `unknown:"true"`
}
fs := NewFlagSet("test")
err := fs.FromStruct(&Opts{})
assert.NoError(t, err)
})

t.Run("group tag on blank field is allowed", func(t *testing.T) {
type Opts struct {
_ struct{} `group:"Server Options"`
Verbose bool `long:"verbose"`
}
fs := NewFlagSet("test")
err := fs.FromStruct(&Opts{})
assert.NoError(t, err)
})
}

type CombinedUsageConfig struct {
Verbose bool `long:"verbose" short:"v"`
Files []string `long:"files" short:"f"`
Expand Down
Loading