diff --git a/.gitignore b/.gitignore index a1a027c..1a6d520 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,7 @@ # IDE Settings /.idea /.vscode -/.vs \ No newline at end of file +/.vs + +examples/basic/basic +examples/basic/basic.exe \ No newline at end of file diff --git a/enum_slice_var.go b/enum_slice_var.go new file mode 100644 index 0000000..137c2be --- /dev/null +++ b/enum_slice_var.go @@ -0,0 +1,30 @@ +package goflags + +import ( + "fmt" + "strings" +) + +type EnumSliceVar struct { + allowedTypes AllowdTypes + value *[]string +} + +func (e *EnumSliceVar) String() string { + if e.value != nil { + return strings.Join(*e.value, ",") + } + return "" +} + +func (e *EnumSliceVar) Set(value string) error { + values := strings.Split(value, ",") + for _, v := range values { + _, ok := e.allowedTypes[v] + if !ok { + return fmt.Errorf("allowed values are %v", e.allowedTypes.String()) + } + } + *e.value = values + return nil +} diff --git a/enum_slice_var_test.go b/enum_slice_var_test.go new file mode 100644 index 0000000..fe651f4 --- /dev/null +++ b/enum_slice_var_test.go @@ -0,0 +1,61 @@ +package goflags + +import ( + "os" + "os/exec" + "testing" + + "github.com/stretchr/testify/assert" +) + +var enumSliceData []string + +func TestEnumSliceVar(t *testing.T) { + t.Run("Test with single value", func(t *testing.T) { + flagSet := NewFlagSet() + flagSet.EnumSliceVar(&enumSliceData, "enum", []EnumVariable{Type1}, "enum", AllowdTypes{"type1": Type1, "type2": Type2}) + os.Args = []string{ + os.Args[0], + "--enum", "type1", + } + err := flagSet.Parse() + assert.Nil(t, err) + assert.Equal(t, []string{"type1"}, enumSliceData) + tearDown(t.Name()) + }) + + t.Run("Test with multiple value", func(t *testing.T) { + flagSet := NewFlagSet() + flagSet.EnumSliceVar(&enumSliceData, "enum", []EnumVariable{Type1}, "enum", AllowdTypes{"type1": Type1, "type2": Type2}) + os.Args = []string{ + os.Args[0], + "--enum", "type1,type2", + } + err := flagSet.Parse() + assert.Nil(t, err) + assert.Equal(t, []string{"type1", "type2"}, enumSliceData) + tearDown(t.Name()) + }) + + t.Run("Test with invalid value", func(t *testing.T) { + if os.Getenv("IS_SUB_PROCESS") == "1" { + flagSet := NewFlagSet() + + flagSet.EnumSliceVar(&enumSliceData, "enum", []EnumVariable{Nil}, "enum", AllowdTypes{"type1": Type1, "type2": Type2}) + os.Args = []string{ + os.Args[0], + "--enum", "type3", + } + _ = flagSet.Parse() + return + } + cmd := exec.Command(os.Args[0], "-test.run=TestFailEnumVar") + cmd.Env = append(os.Environ(), "IS_SUB_PROCESS=1") + err := cmd.Run() + if e, ok := err.(*exec.ExitError); ok && !e.Success() { + return + } + t.Fatalf("process ran with err %v, want exit error", err) + tearDown(t.Name()) + }) +} diff --git a/examples/basic/main.go b/examples/basic/main.go index 909504f..e7f3ba1 100644 --- a/examples/basic/main.go +++ b/examples/basic/main.go @@ -16,6 +16,7 @@ type Options struct { fileSize goflags.Size duration time.Duration rls goflags.RateLimitMap + severity []string } func main() { @@ -37,6 +38,7 @@ func main() { flagSet.CallbackVarP(CheckUpdate, "update", "ut", "update this tool to latest version"), flagSet.SizeVarP(&testOptions.fileSize, "max-size", "ms", "", "max file size"), flagSet.DurationVar(&testOptions.duration, "timeout", time.Hour, "timeout"), + flagSet.EnumSliceVarP(&testOptions.severity, "severity", "s", []goflags.EnumVariable{2}, "severity of the scan", goflags.AllowdTypes{"low": goflags.EnumVariable(0), "medium": goflags.EnumVariable(1), "high": goflags.EnumVariable(2)}), ) flagSet.SetCustomHelpText("EXAMPLE USAGE:\ngo run ./examples/basic [OPTIONS]") @@ -45,5 +47,11 @@ func main() { } // ratelimits value is - fmt.Printf("Got RateLimits: %+v\n", testOptions.rls) + if len(testOptions.rls.AsMap()) > 0 { + fmt.Printf("Got RateLimits: %+v\n", testOptions.rls) + } + + if len(testOptions.severity) > 0 { + fmt.Printf("Got Severity: %+v\n", testOptions.severity) + } } diff --git a/goflags.go b/goflags.go index 37ae4db..5d5acac 100644 --- a/goflags.go +++ b/goflags.go @@ -486,6 +486,41 @@ func (flagSet *FlagSet) EnumVarP(field *string, long, short string, defaultValue return flagData } +// EnumVar adds a enum flag with a longname +func (flagSet *FlagSet) EnumSliceVar(field *[]string, long string, defaultValues []EnumVariable, usage string, allowedTypes AllowdTypes) *FlagData { + return flagSet.EnumSliceVarP(field, long, "", defaultValues, usage, allowedTypes) +} + +// EnumVarP adds a enum flag with a shortname and longname +func (flagSet *FlagSet) EnumSliceVarP(field *[]string, long, short string, defaultValues []EnumVariable, usage string, allowedTypes AllowdTypes) *FlagData { + var defaults []string + for k, v := range allowedTypes { + for _, defaultValue := range defaultValues { + if v == defaultValue { + defaults = append(defaults, k) + } + } + } + if len(defaults) == 0 { + panic("undefined default value") + } + + *field = defaults + flagData := &FlagData{ + usage: usage, + long: long, + defaultValue: strings.Join(*field, ","), + } + if short != "" { + flagData.short = short + flagSet.CommandLine.Var(&EnumSliceVar{allowedTypes, field}, short, usage) + flagSet.flagKeys.Set(short, flagData) + } + flagSet.CommandLine.Var(&EnumSliceVar{allowedTypes, field}, long, usage) + flagSet.flagKeys.Set(long, flagData) + return flagData +} + func (flagSet *FlagSet) usageFunc() { var helpAsked bool diff --git a/goflags_test.go b/goflags_test.go index 4d49764..f818c01 100644 --- a/goflags_test.go +++ b/goflags_test.go @@ -3,6 +3,7 @@ package goflags import ( "bytes" "flag" + "fmt" "os" "reflect" "strconv" @@ -87,6 +88,7 @@ func TestUsageOrder(t *testing.T) { var intData int var boolData bool var enumData string + var enumSliceData []string flagSet.SetGroup("String", "String") flagSet.StringVar(&stringData, "string-value", "", "String example value example").Group("String") @@ -119,6 +121,12 @@ func TestUsageOrder(t *testing.T) { "two": EnumVariable(2), }).Group("Enum") + flagSet.EnumSliceVarP(&enumSliceData, "enum-slice-with-default-value", "esn", []EnumVariable{EnumVariable(0)}, "Enum with default value(zero/one/two)", AllowdTypes{ + "zero": EnumVariable(0), + "one": EnumVariable(1), + "two": EnumVariable(2), + }).Group("Enum") + flagSet.SetGroup("Update", "Update") flagSet.CallbackVar(func() {}, "update", "update tool_1 to the latest released version").Group("Update") flagSet.CallbackVarP(func() {}, "disable-update-check", "duc", "disable automatic update check").Group("Update") @@ -134,6 +142,7 @@ func TestUsageOrder(t *testing.T) { resultOutput := output.String() actual := resultOutput[strings.Index(resultOutput, "Flags:\n"):] + fmt.Println(actual) expected := `Flags: @@ -158,7 +167,8 @@ BOOLEAN: -bool-with-default-value Bool with default value example (default true) -bwdv, -bool-with-default-value2 Bool with default value example #2 (default true) ENUM: - -en, -enum-with-default-value value Enum with default value(zero/one/two) (default zero) + -en, -enum-with-default-value value Enum with default value(zero/one/two) (default zero) + -esn, -enum-slice-with-default-value value Enum with default value(zero/one/two) (default zero) UPDATE: -update update tool_1 to the latest released version -duc, -disable-update-check disable automatic update check