diff --git a/kong.go b/kong.go index b85e145..3cb2e40 100644 --- a/kong.go +++ b/kong.go @@ -91,7 +91,7 @@ func New(grammar any, options ...Option) (*Kong, error) { }, } - options = append(options, Bind(k)) + options = append(options, Bind(k), Resolvers(EnvResolver())) for _, option := range options { if err := option.Apply(k); err != nil { diff --git a/model.go b/model.go index 065fcdd..3913239 100644 --- a/model.go +++ b/model.go @@ -3,7 +3,6 @@ package kong import ( "fmt" "math" - "os" "reflect" "strconv" "strings" @@ -377,19 +376,6 @@ func (v *Value) ApplyDefault() error { // Does not include resolvers. func (v *Value) Reset() error { v.Target.Set(reflect.Zero(v.Target.Type())) - if len(v.Tag.Envs) != 0 { - for _, env := range v.Tag.Envs { - envar, ok := os.LookupEnv(env) - // Parse the first non-empty ENV in the list - if ok { - err := v.Parse(ScanFromTokens(Token{Type: FlagValueToken, Value: envar}), v.Target) - if err != nil { - return fmt.Errorf("%s (from envar %s=%q)", err, env, envar) - } - return nil - } - } - } if v.HasDefault { return v.Parse(ScanFromTokens(Token{Type: FlagValueToken, Value: v.Default}), v.Target) } diff --git a/resolver.go b/resolver.go index 29be1b9..f994159 100644 --- a/resolver.go +++ b/resolver.go @@ -2,7 +2,9 @@ package kong import ( "encoding/json" + "fmt" "io" + "os" "strings" ) @@ -66,3 +68,61 @@ func snakeCase(name string) string { name = strings.Join(strings.Split(strings.Title(name), "-"), "") return strings.ToLower(name[:1]) + name[1:] } + +func EnvResolver() Resolver { + // Resolvers are typically only invoked for flags, as shown here: + // https://github.com/alecthomas/kong/blob/v1.6.0/context.go#L567 + // However, environment variable annotations can also apply to arguments, + // as demonstrated in this test: + // https://github.com/alecthomas/kong/blob/v1.6.0/kong_test.go#L1226-L1244 + // To handle this, we ensure that arguments are resolved as well. + // Since the resolution only needs to happen once, we use this boolean + // to track whether the resolution process has already been performed. + argsResolved := false + return ResolverFunc(func(context *Context, parent *Path, flag *Flag) (interface{}, error) { + if !argsResolved { + resolveArgs(context.Path) + argsResolved = true + } + for _, env := range flag.Tag.Envs { + envar, ok := os.LookupEnv(env) + // Parse the first non-empty ENV in the list + if ok { + return envar, nil + } + } + return nil, nil + }) +} + +func resolveArgs(paths []*Path) error { + for _, path := range paths { + if path.Command == nil { + continue + } + for _, positional := range path.Command.Positional { + if positional.Tag == nil { + continue + } + visitValue(positional) + } + if path.Command.Argument != nil { + visitValue(path.Command.Argument) + } + } + return nil +} + +func visitValue(value *Value) error { + for _, env := range value.Tag.Envs { + envar, ok := os.LookupEnv(env) + if !ok { + continue + } + token := Token{Type: FlagValueToken, Value: envar} + if err := value.Parse(ScanFromTokens(token), value.Target); err != nil { + return fmt.Errorf("%s (from envar %s=%q)", err, env, envar) + } + } + return nil +}