From 183febc35f4923ee9ea0e206029fdbdf88021970 Mon Sep 17 00:00:00 2001 From: Jocelyn Giroux Date: Wed, 13 Nov 2024 08:26:18 -0500 Subject: [PATCH] =?UTF-8?q?Fix=20bug=20in=20unmanaged=20arguments=20proces?= =?UTF-8?q?sing=E2=80=A0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Fix a bug that was causing some arguments to be eaten when handling Unmanaged arguments - Add some tests - Fix deprecated warnings --- app_test.go | 12 +++++ flags.go | 134 +++++++++++++++++++++++------------------------ parser.go | 9 +--- parser_test.go | 136 ++++++++++++++++++++++++------------------------ parsers_test.go | 3 +- 5 files changed, 150 insertions(+), 144 deletions(-) diff --git a/app_test.go b/app_test.go index df31719..be7ee1a 100644 --- a/app_test.go +++ b/app_test.go @@ -592,6 +592,10 @@ func TestUnmanaged(t *testing.T) { b: make([]bool, nbElements), s: make([]string, nbElements), } + + // We generate a set of valid flags: + // bool-1, bool-2, ... to bool-{nbElements} with short names -a, -b, ..., -e + // string-1, string-2, ... to string-{nbElements} with short names -A, -B, ..., -E for i := 0; i < nbElements; i++ { a.Flag(fmt.Sprintf("bool-%d", i+1), "").Short(rune('a' + i)).BoolVar(&a.b[i]) a.Flag(fmt.Sprintf("string-%d", i+1), "").Short(rune('A' + i)).StringVar(&a.s[i]) @@ -639,6 +643,14 @@ func TestUnmanaged(t *testing.T) { []bool{true, true, false, true, true}, []string{"", "", "", "", ""}, []string{"-cX"}, nil}, + {"Bad switch mixed with long", true, "-ab -cX-long -de", + []bool{true, true, false, true, true}, + []string{"", "", "", "", ""}, + []string{"-cX-long"}, nil}, + {"Bad switch mixed with very long", true, "-ab -cX-very-long -de", + []bool{true, true, false, true, true}, + []string{"", "", "", "", ""}, + []string{"-cX-very-long"}, nil}, {"Many bad switches with args", true, "-ab -var x=1 -var y=2 -de -var z=3 test", []bool{true, true, false, true, true}, []string{"", "", "", "", ""}, diff --git a/flags.go b/flags.go index 3e5e5bb..ab04fc9 100644 --- a/flags.go +++ b/flags.go @@ -2,7 +2,6 @@ package kingpin import ( "fmt" - "strings" ) type flagGroup struct { @@ -104,86 +103,87 @@ func (f *flagGroup) checkDuplicates() error { } func (f *flagGroup) parse(context *ParseContext) (*FlagClause, error) { - var token *Token - for { - token = context.Peek() - switch token.Type { - case TokenEOL: - return nil, nil + token := context.Peek() + switch token.Type { + case TokenEOL: + return nil, nil + + case TokenLong, TokenShort: + flagToken := token + defaultValue := "" + var flag *FlagClause + var ok bool + var err error + invert := false + + name := token.Value + if token.Type == TokenLong { + if flag, invert, err = f.getFlagAlias(name); err != nil { + return nil, err + } else if flag == nil { + err = fmt.Errorf("unknown long flag '%s'", flagToken) + } + } else if flag, ok = f.short[name]; !ok { + err = fmt.Errorf("unknown short flag '%s'", flagToken) + } - case TokenLong, TokenShort: - flagToken := token - defaultValue := "" - var flag *FlagClause - var ok bool - var err error - invert := false + if err != nil { + if context.appUnmanagedArgs == nil { + return nil, err + } - name := token.Value + // The current flag is not managed by the application, but we gather it anyway in the unmanaged args + current := context.current() if token.Type == TokenLong { - if flag, invert, err = f.getFlagAlias(name); err != nil { - return nil, err - } else if flag == nil { - err = fmt.Errorf("unknown long flag '%s'", flagToken) + context.Next() + } else { + remainingArgs := "-" + if len(context.args) > 0 && context.rawArgs[len(context.rawArgs)-len(context.args)] == current { + // There are more short flags in the current element + remainingArgs = context.args[0] } - } else if flag, ok = f.short[name]; !ok { - err = fmt.Errorf("unknown short flag '%s'", flagToken) - } - if err != nil { - if context.appUnmanagedArgs == nil { - return nil, err - } - current := context.current() - if token.Type == TokenLong { - context.Next() - } else { - // We have to remove all previous elements from the same short flag element - pos := strings.Index(current, token.Value) - 1 - if pos < 0 { - return nil, err - } - context.argi -= pos - context.Elements = context.Elements[:len(context.Elements)-pos] - for x := len(current) - pos - 1; x > 0; x-- { - // We skip all remaining elements of the group - purgedToken := context.Next() - // This error isn't supposed to be possible, but let's handle it anyway. - if purgedToken.Type == TokenLong { - err = fmt.Errorf("while skipping unmanaged shorts flags, skipped long flag '%s' from '%s'", purgedToken.Value, current) - return nil, err - } + // We remove all previous elements from the same short flags group + previousElementsCount := len(current) - len(remainingArgs) - 1 + context.Elements = context.Elements[:len(context.Elements)-previousElementsCount] + + // We consume the remaining short flags of the current group + for i := 0; i < len(remainingArgs); i++ { + // We consume the remaining short flags of the current group + if consumed := context.Next(); consumed.Type != TokenShort { + break } } - context.appUnmanagedArgs.Unmanaged = append(context.appUnmanagedArgs.Unmanaged, current) - return nil, nil } - context.Next() - flag.isSetByUser() + context.appUnmanagedArgs.Unmanaged = append(context.appUnmanagedArgs.Unmanaged, current) + return nil, nil + } - if fb, ok := flag.value.(boolFlag); ok && fb.IsBoolFlag() { - if invert { - defaultValue = "false" - } else { - defaultValue = "true" - } + context.Next() + flag.isSetByUser() + + if fb, ok := flag.value.(boolFlag); ok && fb.IsBoolFlag() { + if invert { + defaultValue = "false" } else { - token = context.Peek() - if token.Type != TokenArg { - context.Push(token) - return nil, fmt.Errorf("expected argument for flag '%s'", flagToken) - } - context.Next() - defaultValue = token.Value + defaultValue = "true" } + } else { + token = context.Peek() + if token.Type != TokenArg { + context.Push(token) + return nil, fmt.Errorf("expected argument for flag '%s'", flagToken) + } + context.Next() + defaultValue = token.Value + } - context.matchedFlag(flag, defaultValue) - return flag, nil + context.matchedFlag(flag, defaultValue) + return flag, nil - default: - return nil, nil - } + default: + return nil, nil } } diff --git a/parser.go b/parser.go index 08b7cb8..c1f17df 100644 --- a/parser.go +++ b/parser.go @@ -192,13 +192,8 @@ func (p *ParseContext) Next() *Token { return p.Next() } - if strings.HasPrefix(arg, "--") || (strings.HasPrefix(arg, "-") && strings.Count(arg, "-") > 1) { - var parts []string - if strings.HasPrefix(arg, "--") { - parts = strings.SplitN(arg[2:], "=", 2) - } else { - parts = strings.SplitN(arg[1:], "=", 2) - } + if strings.HasPrefix(arg, "--") { + parts := strings.SplitN(arg[2:], "=", 2) token := &Token{p.argi, TokenLong, parts[0]} if len(parts) == 2 { p.Push(&Token{p.argi, TokenArg, parts[1]}) diff --git a/parser_test.go b/parser_test.go index 0167a89..a113068 100644 --- a/parser_test.go +++ b/parser_test.go @@ -1,7 +1,6 @@ package kingpin import ( - "io/ioutil" "os" "testing" @@ -9,7 +8,7 @@ import ( ) func TestParserExpandFromFile(t *testing.T) { - f, err := ioutil.TempFile("", "") + f, err := os.CreateTemp("", "") assert.NoError(t, err) defer os.Remove(f.Name()) f.WriteString("hello\nworld\n") @@ -26,7 +25,7 @@ func TestParserExpandFromFile(t *testing.T) { } func TestParserExpandFromFileLeadingArg(t *testing.T) { - f, err := ioutil.TempFile("", "") + f, err := os.CreateTemp("", "") assert.NoError(t, err) defer os.Remove(f.Name()) f.WriteString("hello\nworld\n") @@ -45,7 +44,7 @@ func TestParserExpandFromFileLeadingArg(t *testing.T) { } func TestParserExpandFromFileTrailingArg(t *testing.T) { - f, err := ioutil.TempFile("", "") + f, err := os.CreateTemp("", "") assert.NoError(t, err) defer os.Remove(f.Name()) f.WriteString("hello\nworld\n") @@ -64,7 +63,7 @@ func TestParserExpandFromFileTrailingArg(t *testing.T) { } func TestParserExpandFromFileMultipleSurroundingArgs(t *testing.T) { - f, err := ioutil.TempFile("", "") + f, err := os.CreateTemp("", "") assert.NoError(t, err) defer os.Remove(f.Name()) f.WriteString("hello\nworld\n") @@ -85,7 +84,7 @@ func TestParserExpandFromFileMultipleSurroundingArgs(t *testing.T) { } func TestParserExpandFromFileMultipleFlags(t *testing.T) { - f, err := ioutil.TempFile("", "") + f, err := os.CreateTemp("", "") assert.NoError(t, err) defer os.Remove(f.Name()) f.WriteString("--flag1=f1\n--flag2=f2\n") @@ -121,66 +120,67 @@ func TestParseContextPush(t *testing.T) { assert.Equal(t, "bar", b.Value) } -func TestAppParseSingleThenDoubleDashFlags(t *testing.T) { - app := New("test", "") - app.allowUnmanaged = true - app.Command("foo", "") - - _, err := app.ParseContext([]string{"foo", "-single-dash", "--double-dash"}) - assert.Nil(t, err) - assert.Equal(t, []string{"-single-dash", "--double-dash"}, app.Unmanaged) -} - -func TestAppParseTwoSingleDashFlags(t *testing.T) { - app := New("test", "") - app.allowUnmanaged = true - app.Command("foo", "") - - _, err := app.ParseContext([]string{"foo", "-short-flag", "-verylongshort-flag"}) - assert.Nil(t, err) - assert.Equal(t, []string{"-short-flag", "-verylongshort-flag"}, app.Unmanaged) -} - -func TestAppParseDoubleThenSingleDashFlags(t *testing.T) { - app := New("test", "") - app.allowUnmanaged = true - app.Command("foo", "") - - _, err := app.ParseContext([]string{"foo", "--double-dash", "-single-dash"}) - assert.Nil(t, err) - assert.Equal(t, []string{"--double-dash", "-single-dash"}, app.Unmanaged) -} - -func TestAppParseVerboseVarFlags(t *testing.T) { - app := New("test", "") - app.allowUnmanaged = true - app.Command("foo", "") - app.Flag("verbose", "").Short('v').Bool() - - _, err := app.ParseContext([]string{"foo", "-v", "-var"}) - assert.Nil(t, err) - assert.Equal(t, []string{"-var"}, app.Unmanaged) -} - -func TestAppParseUnmanagedVarWithTwoManagedFlags(t *testing.T) { - app := New("test", "") - app.allowUnmanaged = true - app.Command("foo", "") - app.Flag("verbose", "").Short('v').Bool() - app.Flag("aflag", "").Short('a').Bool() - - _, err := app.ParseContext([]string{"foo", "-var"}) - assert.Nil(t, err) - assert.Equal(t, []string{"-var"}, app.Unmanaged) -} - -func TestAppParseShortLongFlags(t *testing.T) { - app := New("test", "") - app.allowUnmanaged = true - app.Command("foo", "") - app.Flag("verbose-level", "").Short('v').Bool() - - ctx, err := app.ParseContext([]string{"foo", "-verbose-level"}) - assert.Nil(t, err) - assert.Len(t, ctx.Elements, 2) +func TestAppParseFlags(t *testing.T) { + tests := []struct { + name string + args []string + unmanaged []string + elementsLen int + }{ + { + name: "Single then double dash flags", + args: []string{"foo", "-single-dash", "--double-dash"}, + unmanaged: []string{"-single-dash", "--double-dash"}, + }, + { + name: "Two single dash flags", + args: []string{"foo", "--", "-short-flag", "-verylongshort-flag"}, + unmanaged: []string{"-short-flag", "-verylongshort-flag"}, + }, + { + name: "Double then single dash flags", + args: []string{"foo", "--double-dash", "-single-dash"}, + unmanaged: []string{"--double-dash", "-single-dash"}, + }, + { + name: "Verbose var flags", + args: []string{"foo", "-v", "-var"}, + unmanaged: []string{"-var"}, + }, + { + name: "Unmanaged var", + args: []string{"foo", "-var"}, + unmanaged: []string{"-var"}, + }, + { + name: "Long flag as short flag", + args: []string{"foo", "-test", "-verbose-level", "-another-flag"}, + unmanaged: []string{"-test", "-verbose-level", "-another-flag"}, + }, + { + name: "Short pseudo long flags", + args: []string{"foo", "-this_is_a_very_long-flag", "-this is not really a flag"}, + unmanaged: []string{"-this_is_a_very_long-flag", "-this is not really a flag"}, + elementsLen: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + app := New("test", "") + app.allowUnmanaged = true + app.Command("foo", "") + app.Flag("verbose-level", "").Short('v').Alias("verbose").Bool() + app.Flag("aflag", "").Short('a').Bool() + + ctx, err := app.ParseContext(tt.args) + assert.Nil(t, err) + if tt.unmanaged != nil { + assert.Equal(t, tt.unmanaged, app.Unmanaged) + } + if tt.elementsLen > 0 { + assert.Len(t, ctx.Elements, tt.elementsLen) + } + }) + } } diff --git a/parsers_test.go b/parsers_test.go index 81708c7..fcb4be0 100644 --- a/parsers_test.go +++ b/parsers_test.go @@ -1,7 +1,6 @@ package kingpin import ( - "io/ioutil" "net" "net/url" "os" @@ -53,7 +52,7 @@ func TestParseURL(t *testing.T) { } func TestParseExistingFile(t *testing.T) { - f, err := ioutil.TempFile("", "") + f, err := os.CreateTemp("", "") if err != nil { t.Fatal(err) }