Skip to content

Commit

Permalink
Fix bug in unmanaged arguments processing†
Browse files Browse the repository at this point in the history
- Fix a bug that was causing some arguments to be eaten when handling Unmanaged arguments
- Add some tests
- Fix deprecated warnings
  • Loading branch information
jocgir committed Nov 13, 2024
1 parent e3bb762 commit 183febc
Show file tree
Hide file tree
Showing 5 changed files with 150 additions and 144 deletions.
12 changes: 12 additions & 0 deletions app_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -592,6 +592,10 @@ func TestUnmanaged(t *testing.T) {
b: make([]bool, nbElements),
s: make([]string, nbElements),
}

// We generate a set of valid flags:
// bool-1, bool-2, ... to bool-{nbElements} with short names -a, -b, ..., -e
// string-1, string-2, ... to string-{nbElements} with short names -A, -B, ..., -E
for i := 0; i < nbElements; i++ {
a.Flag(fmt.Sprintf("bool-%d", i+1), "").Short(rune('a' + i)).BoolVar(&a.b[i])
a.Flag(fmt.Sprintf("string-%d", i+1), "").Short(rune('A' + i)).StringVar(&a.s[i])
Expand Down Expand Up @@ -639,6 +643,14 @@ func TestUnmanaged(t *testing.T) {
[]bool{true, true, false, true, true},
[]string{"", "", "", "", ""},
[]string{"-cX"}, nil},
{"Bad switch mixed with long", true, "-ab -cX-long -de",
[]bool{true, true, false, true, true},
[]string{"", "", "", "", ""},
[]string{"-cX-long"}, nil},
{"Bad switch mixed with very long", true, "-ab -cX-very-long -de",
[]bool{true, true, false, true, true},
[]string{"", "", "", "", ""},
[]string{"-cX-very-long"}, nil},
{"Many bad switches with args", true, "-ab -var x=1 -var y=2 -de -var z=3 test",
[]bool{true, true, false, true, true},
[]string{"", "", "", "", ""},
Expand Down
134 changes: 67 additions & 67 deletions flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package kingpin

import (
"fmt"
"strings"
)

type flagGroup struct {
Expand Down Expand Up @@ -104,86 +103,87 @@ func (f *flagGroup) checkDuplicates() error {
}

func (f *flagGroup) parse(context *ParseContext) (*FlagClause, error) {
var token *Token
for {
token = context.Peek()
switch token.Type {
case TokenEOL:
return nil, nil
token := context.Peek()
switch token.Type {
case TokenEOL:
return nil, nil

case TokenLong, TokenShort:
flagToken := token
defaultValue := ""
var flag *FlagClause
var ok bool
var err error
invert := false

name := token.Value
if token.Type == TokenLong {
if flag, invert, err = f.getFlagAlias(name); err != nil {
return nil, err
} else if flag == nil {
err = fmt.Errorf("unknown long flag '%s'", flagToken)
}
} else if flag, ok = f.short[name]; !ok {
err = fmt.Errorf("unknown short flag '%s'", flagToken)
}

case TokenLong, TokenShort:
flagToken := token
defaultValue := ""
var flag *FlagClause
var ok bool
var err error
invert := false
if err != nil {
if context.appUnmanagedArgs == nil {
return nil, err
}

name := token.Value
// The current flag is not managed by the application, but we gather it anyway in the unmanaged args
current := context.current()
if token.Type == TokenLong {
if flag, invert, err = f.getFlagAlias(name); err != nil {
return nil, err
} else if flag == nil {
err = fmt.Errorf("unknown long flag '%s'", flagToken)
context.Next()
} else {
remainingArgs := "-"
if len(context.args) > 0 && context.rawArgs[len(context.rawArgs)-len(context.args)] == current {
// There are more short flags in the current element
remainingArgs = context.args[0]
}
} else if flag, ok = f.short[name]; !ok {
err = fmt.Errorf("unknown short flag '%s'", flagToken)
}

if err != nil {
if context.appUnmanagedArgs == nil {
return nil, err
}
current := context.current()
if token.Type == TokenLong {
context.Next()
} else {
// We have to remove all previous elements from the same short flag element
pos := strings.Index(current, token.Value) - 1
if pos < 0 {
return nil, err
}
context.argi -= pos
context.Elements = context.Elements[:len(context.Elements)-pos]
for x := len(current) - pos - 1; x > 0; x-- {
// We skip all remaining elements of the group
purgedToken := context.Next()
// This error isn't supposed to be possible, but let's handle it anyway.
if purgedToken.Type == TokenLong {
err = fmt.Errorf("while skipping unmanaged shorts flags, skipped long flag '%s' from '%s'", purgedToken.Value, current)
return nil, err
}
// We remove all previous elements from the same short flags group
previousElementsCount := len(current) - len(remainingArgs) - 1
context.Elements = context.Elements[:len(context.Elements)-previousElementsCount]

// We consume the remaining short flags of the current group
for i := 0; i < len(remainingArgs); i++ {
// We consume the remaining short flags of the current group
if consumed := context.Next(); consumed.Type != TokenShort {
break
}
}
context.appUnmanagedArgs.Unmanaged = append(context.appUnmanagedArgs.Unmanaged, current)
return nil, nil
}

context.Next()
flag.isSetByUser()
context.appUnmanagedArgs.Unmanaged = append(context.appUnmanagedArgs.Unmanaged, current)
return nil, nil
}

if fb, ok := flag.value.(boolFlag); ok && fb.IsBoolFlag() {
if invert {
defaultValue = "false"
} else {
defaultValue = "true"
}
context.Next()
flag.isSetByUser()

if fb, ok := flag.value.(boolFlag); ok && fb.IsBoolFlag() {
if invert {
defaultValue = "false"
} else {
token = context.Peek()
if token.Type != TokenArg {
context.Push(token)
return nil, fmt.Errorf("expected argument for flag '%s'", flagToken)
}
context.Next()
defaultValue = token.Value
defaultValue = "true"
}
} else {
token = context.Peek()
if token.Type != TokenArg {
context.Push(token)
return nil, fmt.Errorf("expected argument for flag '%s'", flagToken)
}
context.Next()
defaultValue = token.Value
}

context.matchedFlag(flag, defaultValue)
return flag, nil
context.matchedFlag(flag, defaultValue)
return flag, nil

default:
return nil, nil
}
default:
return nil, nil
}
}

Expand Down
9 changes: 2 additions & 7 deletions parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -192,13 +192,8 @@ func (p *ParseContext) Next() *Token {
return p.Next()
}

if strings.HasPrefix(arg, "--") || (strings.HasPrefix(arg, "-") && strings.Count(arg, "-") > 1) {
var parts []string
if strings.HasPrefix(arg, "--") {
parts = strings.SplitN(arg[2:], "=", 2)
} else {
parts = strings.SplitN(arg[1:], "=", 2)
}
if strings.HasPrefix(arg, "--") {
parts := strings.SplitN(arg[2:], "=", 2)
token := &Token{p.argi, TokenLong, parts[0]}
if len(parts) == 2 {
p.Push(&Token{p.argi, TokenArg, parts[1]})
Expand Down
136 changes: 68 additions & 68 deletions parser_test.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
package kingpin

import (
"io/ioutil"
"os"
"testing"

"github.com/stretchr/testify/assert"
)

func TestParserExpandFromFile(t *testing.T) {
f, err := ioutil.TempFile("", "")
f, err := os.CreateTemp("", "")
assert.NoError(t, err)
defer os.Remove(f.Name())
f.WriteString("hello\nworld\n")
Expand All @@ -26,7 +25,7 @@ func TestParserExpandFromFile(t *testing.T) {
}

func TestParserExpandFromFileLeadingArg(t *testing.T) {
f, err := ioutil.TempFile("", "")
f, err := os.CreateTemp("", "")
assert.NoError(t, err)
defer os.Remove(f.Name())
f.WriteString("hello\nworld\n")
Expand All @@ -45,7 +44,7 @@ func TestParserExpandFromFileLeadingArg(t *testing.T) {
}

func TestParserExpandFromFileTrailingArg(t *testing.T) {
f, err := ioutil.TempFile("", "")
f, err := os.CreateTemp("", "")
assert.NoError(t, err)
defer os.Remove(f.Name())
f.WriteString("hello\nworld\n")
Expand All @@ -64,7 +63,7 @@ func TestParserExpandFromFileTrailingArg(t *testing.T) {
}

func TestParserExpandFromFileMultipleSurroundingArgs(t *testing.T) {
f, err := ioutil.TempFile("", "")
f, err := os.CreateTemp("", "")
assert.NoError(t, err)
defer os.Remove(f.Name())
f.WriteString("hello\nworld\n")
Expand All @@ -85,7 +84,7 @@ func TestParserExpandFromFileMultipleSurroundingArgs(t *testing.T) {
}

func TestParserExpandFromFileMultipleFlags(t *testing.T) {
f, err := ioutil.TempFile("", "")
f, err := os.CreateTemp("", "")
assert.NoError(t, err)
defer os.Remove(f.Name())
f.WriteString("--flag1=f1\n--flag2=f2\n")
Expand Down Expand Up @@ -121,66 +120,67 @@ func TestParseContextPush(t *testing.T) {
assert.Equal(t, "bar", b.Value)
}

func TestAppParseSingleThenDoubleDashFlags(t *testing.T) {
app := New("test", "")
app.allowUnmanaged = true
app.Command("foo", "")

_, err := app.ParseContext([]string{"foo", "-single-dash", "--double-dash"})
assert.Nil(t, err)
assert.Equal(t, []string{"-single-dash", "--double-dash"}, app.Unmanaged)
}

func TestAppParseTwoSingleDashFlags(t *testing.T) {
app := New("test", "")
app.allowUnmanaged = true
app.Command("foo", "")

_, err := app.ParseContext([]string{"foo", "-short-flag", "-verylongshort-flag"})
assert.Nil(t, err)
assert.Equal(t, []string{"-short-flag", "-verylongshort-flag"}, app.Unmanaged)
}

func TestAppParseDoubleThenSingleDashFlags(t *testing.T) {
app := New("test", "")
app.allowUnmanaged = true
app.Command("foo", "")

_, err := app.ParseContext([]string{"foo", "--double-dash", "-single-dash"})
assert.Nil(t, err)
assert.Equal(t, []string{"--double-dash", "-single-dash"}, app.Unmanaged)
}

func TestAppParseVerboseVarFlags(t *testing.T) {
app := New("test", "")
app.allowUnmanaged = true
app.Command("foo", "")
app.Flag("verbose", "").Short('v').Bool()

_, err := app.ParseContext([]string{"foo", "-v", "-var"})
assert.Nil(t, err)
assert.Equal(t, []string{"-var"}, app.Unmanaged)
}

func TestAppParseUnmanagedVarWithTwoManagedFlags(t *testing.T) {
app := New("test", "")
app.allowUnmanaged = true
app.Command("foo", "")
app.Flag("verbose", "").Short('v').Bool()
app.Flag("aflag", "").Short('a').Bool()

_, err := app.ParseContext([]string{"foo", "-var"})
assert.Nil(t, err)
assert.Equal(t, []string{"-var"}, app.Unmanaged)
}

func TestAppParseShortLongFlags(t *testing.T) {
app := New("test", "")
app.allowUnmanaged = true
app.Command("foo", "")
app.Flag("verbose-level", "").Short('v').Bool()

ctx, err := app.ParseContext([]string{"foo", "-verbose-level"})
assert.Nil(t, err)
assert.Len(t, ctx.Elements, 2)
func TestAppParseFlags(t *testing.T) {
tests := []struct {
name string
args []string
unmanaged []string
elementsLen int
}{
{
name: "Single then double dash flags",
args: []string{"foo", "-single-dash", "--double-dash"},
unmanaged: []string{"-single-dash", "--double-dash"},
},
{
name: "Two single dash flags",
args: []string{"foo", "--", "-short-flag", "-verylongshort-flag"},
unmanaged: []string{"-short-flag", "-verylongshort-flag"},
},
{
name: "Double then single dash flags",
args: []string{"foo", "--double-dash", "-single-dash"},
unmanaged: []string{"--double-dash", "-single-dash"},
},
{
name: "Verbose var flags",
args: []string{"foo", "-v", "-var"},
unmanaged: []string{"-var"},
},
{
name: "Unmanaged var",
args: []string{"foo", "-var"},
unmanaged: []string{"-var"},
},
{
name: "Long flag as short flag",
args: []string{"foo", "-test", "-verbose-level", "-another-flag"},
unmanaged: []string{"-test", "-verbose-level", "-another-flag"},
},
{
name: "Short pseudo long flags",
args: []string{"foo", "-this_is_a_very_long-flag", "-this is not really a flag"},
unmanaged: []string{"-this_is_a_very_long-flag", "-this is not really a flag"},
elementsLen: 1,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
app := New("test", "")
app.allowUnmanaged = true
app.Command("foo", "")
app.Flag("verbose-level", "").Short('v').Alias("verbose").Bool()
app.Flag("aflag", "").Short('a').Bool()

ctx, err := app.ParseContext(tt.args)
assert.Nil(t, err)
if tt.unmanaged != nil {
assert.Equal(t, tt.unmanaged, app.Unmanaged)
}
if tt.elementsLen > 0 {
assert.Len(t, ctx.Elements, tt.elementsLen)
}
})
}
}
Loading

0 comments on commit 183febc

Please sign in to comment.