diff --git a/README.md b/README.md index e9075ba..cc99897 100644 --- a/README.md +++ b/README.md @@ -269,6 +269,28 @@ fmt.Println(args.UserIDs) map[john:123 mary:456] ``` +### Counting arguments + +```go +var args struct { + Verbosity int `arg:"repeated"` +} +arg.MustParse(&args) +fmt.Println(args.Verbosity) +``` + +```shell +./example -v -v -v # increments each time +3 +./example -vvvv # sets to the length of the option repeat +4 +./example -v=5 # sets directly to the value +5 +``` + +The field must be `int`-like according to `reflect.CanInt()` (e.g. `int`, `int32`, `int64`). A short option must be +provided. Note that you cannot do `-v 5` to set the value, it must be `-v=5`. + ### Version strings ```go diff --git a/parse.go b/parse.go index bf6784a..4596e60 100644 --- a/parse.go +++ b/parse.go @@ -9,6 +9,7 @@ import ( "os" "path/filepath" "reflect" + "strconv" "strings" scalar "github.com/alexflint/go-scalar" @@ -57,6 +58,8 @@ type spec struct { defaultValue reflect.Value // default value for this option defaultString string // default value for this option, in string form to be displayed in help text placeholder string // placeholder string in help + noarg bool // whether this option has an argument (basically cheats cardinality check) + repeated bool // whether this is a `-xxxx` or `-x -x -x` counting option } // command represents a named subcommand, or the top-level command @@ -76,6 +79,15 @@ var ErrHelp = errors.New("help requested by user") // ErrVersion indicates that the builtin --version was provided var ErrVersion = errors.New("version requested by user") +// ErrRepeat indicates that a repeated option was not well-formed +var ErrRepeat = errors.New("mismatched repeat") + +// ErrNotInt indicates that a repeated option was not an `int`able field +var ErrNotInt = errors.New("repeats must be int") + +// ErrNoShortOption indicates that a repeated option was missing a short name +var ErrNoShortOption = errors.New("short option missing") + // for monkey patching in example and test code var mustParseExit = os.Exit var mustParseOut io.Writer = os.Stdout @@ -365,6 +377,8 @@ func cmdFromStruct(name string, dest path, t reflect.Type, envPrefix string) (*c case strings.HasPrefix(key, "--"): spec.long = key[2:] case strings.HasPrefix(key, "-"): + // This is tricky to handle - `repeated` must be before the short argument + // Or handle it as a post-hoc check. if len(key) > 2 { errs = append(errs, fmt.Sprintf("%s.%s: short arguments must be one character only", t.Name(), field.Name)) @@ -411,12 +425,31 @@ func cmdFromStruct(name string, dest path, t reflect.Type, envPrefix string) (*c cmd.subcommands = append(cmd.subcommands, subcmd) isSubcommand = true + case key == "noarg": + spec.noarg = true + case key == "repeated": + if !isIntable(field) { + errs = append(errs, ErrNotInt.Error()) // fmt.Sprintf("repeat only works on int-able fields: %s", field.Name)) + return false + } + spec.repeated = true default: errs = append(errs, fmt.Sprintf("unrecognized tag '%s' on field %s", key, tag)) return false } } + if spec.repeated { + // If you don't specify an explicit short option, it'll be in `spec.long`. + if spec.short == "" && len(spec.long) > 1 { + errs = append(errs, ErrNoShortOption.Error()) + return false + } + + // Copy `long` to `short` and remove `long` + spec.short = spec.long + } + // placeholder is the string used in the help text like this: "--somearg PLACEHOLDER" placeholder, hasPlaceholder := field.Tag.Lookup("placeholder") if hasPlaceholder { @@ -444,6 +477,11 @@ func cmdFromStruct(name string, dest path, t reflect.Type, envPrefix string) (*c return false } + // special case - `noarg` and `repeated` have cardinality of `zero` + if spec.noarg || spec.repeated { + spec.cardinality = zero + } + defaultString, hasDefault := field.Tag.Lookup("default") if hasDefault { // we do not support default values for maps and slices @@ -717,6 +755,39 @@ func (p *Parser) process(args []string) error { continue } + // `-x=2` for a `repeat` flag sets the value directly. + if spec.repeated && value == "" { + // Bit faffy and cargo-culted from `scalar.ParseValue` + t := p.val(spec.dest) + if t.Kind() == reflect.Ptr { + if t.IsNil() { + t.Set(reflect.New(t.Type().Elem())) + } + t = t.Elem() + } + + // Check whether we're an `int` field before we use it as such. + if !t.CanInt() { + return ErrNotInt + } + + i := int(t.Int()) + + // Check for mismatches in `-xx...x` options. + if len(opt) >= 2 && strings.Count(opt, opt[0:1]) != len(opt) { + return ErrRepeat + } + + switch { + // Simple `-x` case means increment by one + case len(opt) == 1: + value = strconv.Itoa(i + 1) + // Must be `-xx...x` which sets the length. + default: + value = strconv.Itoa(len(opt)) + } + } + // if it's a flag and it has no value then set the value to true // use boolean because this takes account of TextUnmarshaler if spec.cardinality == zero && value == "" { @@ -845,6 +916,12 @@ func findOption(specs []*spec, name string) *spec { if spec.long == name || spec.short == name { return spec } + // Let's us find `-v` from `-vvvv`. We don't need to worry about + // finding `-v` from, e.g., `-vavavoom` because that'll be blocked later + // as a mismatched repeat. + if spec.repeated && spec.short == name[0:1] { + return spec + } } return nil } @@ -863,3 +940,12 @@ func findSubcommand(cmds []*command, name string) *command { } return nil } + +func isIntable(f reflect.StructField) bool { + ft := f.Type + z := reflect.Zero(ft) + if ft.Kind() == reflect.Ptr { + z = reflect.Zero(ft.Elem()) + } + return z.CanInt() +} diff --git a/parse_test.go b/parse_test.go index 249cbf3..fb1cc5c 100644 --- a/parse_test.go +++ b/parse_test.go @@ -3,6 +3,7 @@ package arg import ( "bytes" "encoding/json" + "errors" "fmt" "net" "net/mail" @@ -1779,3 +1780,132 @@ func TestExitFunctionAndOutStreamGetFilledIn(t *testing.T) { assert.NotNil(t, p.config.Exit) // go prohibits function pointer comparison assert.Equal(t, p.config.Out, os.Stdout) } + +type RepeatedTest struct { + optstring string + count_a int + count_c int + err error +} + +var reptests = []RepeatedTest{ + {"-a", 1, 0, nil}, + {"-aa", 2, 0, nil}, + {"-aaa", 3, 0, nil}, + {"-a -a -a", 3, 0, nil}, + {"-a=3", 3, 0, nil}, + {"-ac", 2, 0, errors.New("mismatched repeat")}, + {"-a -c", 1, 1, nil}, + {"-a -cc", 1, 2, nil}, + {"-a -aa -c -cc -ccc", 2, 3, nil}, // last option wins for "long" version + {"-bb", 0, 0, errors.New("unknown argument -bb")}, + {"-aab", 0, 0, errors.New("mismatched repeat")}, + {"-abba", 0, 0, errors.New("mismatched repeat")}, + {"-a -a -c -c -a -c", 3, 3, nil}, + {"-a -a -c -c -aa -cccc", 2, 4, nil}, + {"-aa -cc -a -a -c", 4, 3, nil}, + {"-aa -cc -a -a -c -aa -cc", 2, 2, nil}, + {"-aa -cc -a -a -c -a=1 -c=1", 1, 1, nil}, + {"-aa -cc -a -a -c -a=9 -c=7", 9, 7, nil}, + {"-aa -cc -a -a -c -a=0 -c=1", 0, 1, nil}, + {"-a=0 -c=1 -a -c", 1, 2, nil}, + {"-a=0 -c=1 -aa -ccc", 2, 3, nil}, + {"-a=0 -c=1 -aa -ccc -a -c", 3, 4, nil}, +} + +// TestRepeatedShort tests our counter parsing +func TestRepeatedShort(t *testing.T) { + for _, v := range reptests { + t.Run(fmt.Sprintf("repeat opts=%s counts=%d:%d", v.optstring, v.count_a, v.count_c), func(t *testing.T) { + var args struct { + A int `arg:"repeated,env"` + C *int `arg:"repeated,env"` + D int + F float64 + } + + err := parse(v.optstring, &args) + if v.err == nil { + require.NoError(t, err) + assert.Equal(t, v.count_a, args.A) + assert.Equal(t, 0, args.D) + + // If an option with `*int` type isn't encountered, the struct + // field will remain `nil` which is unhelpful here yet idiomatic. + if v.count_c > 0 { + assert.Equal(t, v.count_c, *args.C) + } else { + assert.Nil(t, args.C) + } + } else { + require.Error(t, err) + // Not ideal but you can't match two `errors.New(X)` even if `X` is identical. + require.Equal(t, v.err.Error(), err.Error()) + } + }) + } +} + +// TestRepeatedShortInt64 checks whether our counters work with `int64` too +func TestRepeatedShortInt64(t *testing.T) { + for _, v := range reptests { + t.Run(fmt.Sprintf("repeat opts=%s counts=%d:%d", v.optstring, v.count_a, v.count_c), func(t *testing.T) { + var args struct { + A int64 `arg:"repeated,env"` + C *int64 `arg:"repeated,env"` + D int + F float64 + } + + err := parse(v.optstring, &args) + if v.err == nil { + require.NoError(t, err) + assert.Equal(t, int64(v.count_a), args.A) + assert.Equal(t, 0, args.D) + + // If an option with `*int` type isn't encountered, the struct + // field will remain `nil` which is unhelpful here yet idiomatic. + if v.count_c > 0 { + assert.Equal(t, int64(v.count_c), *args.C) + } else { + assert.Nil(t, args.C) + } + } else { + require.Error(t, err) + // Not ideal but you can't match two `errors.New(X)` even if `X` is identical. + require.Equal(t, v.err.Error(), err.Error()) + } + }) + } +} + +// TestRepeatedNotInt tests our error handling for non-int repeats +func TestRepeatedNotInt(t *testing.T) { + var args struct { + A int `arg:"repeated,env"` + C *int `arg:"repeated,env"` + D int + F float64 `arg:"repeated"` + } + optstring := "-f" + + err := parse(optstring, &args) + require.Error(t, err) + require.Equal(t, ErrNotInt.Error(), err.Error()) +} + +// TestRepeatedLongNames tests what happens with no short option specified +func TestRepeatedLongNames(t *testing.T) { + var args struct { + Apples int `arg:"repeated,env"` + Cheese *int `arg:"repeated,env,-c"` + Durian int + F float64 + } + // Fails because `Apples` maps to `--apples` and we have no short option + optstring := "-a" + + err := parse(optstring, &args) + require.Error(t, err) + require.Equal(t, ErrNoShortOption.Error(), err.Error()) +}