Skip to content

Commit

Permalink
internal/cli: add support to overwrite config.toml values via cli fla…
Browse files Browse the repository at this point in the history
…gs (#1008)

* internal/cli: add support to overwrite config.toml via cli flags

* fix lint and refactor

* add extensive tests for flagset

* fix type conversion for big.Int

* add more tests for coverage

* add t.parallel

* internal/cli/flagset: handle flag conversion using interface

* internal/cli/flagset: fix test
  • Loading branch information
manav2401 authored Sep 21, 2023
1 parent ef38194 commit 8613ff1
Show file tree
Hide file tree
Showing 9 changed files with 550 additions and 117 deletions.
2 changes: 1 addition & 1 deletion internal/cli/dumpconfig.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func (c *DumpconfigCommand) Synopsis() string {
func (c *DumpconfigCommand) Run(args []string) int {
// Initialize an empty command instance to get flags
command := server.Command{}
flags := command.Flags()
flags := command.Flags(nil)

if err := flags.Parse(args); err != nil {
c.UI.Error(err.Error())
Expand Down
154 changes: 136 additions & 18 deletions internal/cli/flagset/flagset.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,33 +5,42 @@ import (
"fmt"
"math/big"
"sort"
"strconv"
"strings"
"time"
)

type Flagset struct {
flags []*FlagVar
flags map[string]*FlagVar
set *flag.FlagSet
}

func NewFlagSet(name string) *Flagset {
f := &Flagset{
flags: []*FlagVar{},
flags: make(map[string]*FlagVar, 0),
set: flag.NewFlagSet(name, flag.ContinueOnError),
}

return f
}

// Updatable is a minimalistic representation of a flag which has
// the method `UpdateValue` implemented which can be called while
// overwriting flags.
type Updatable interface {
UpdateValue(string)
}

type FlagVar struct {
Name string
Usage string
Group string
Default any
Value Updatable
}

func (f *Flagset) addFlag(fl *FlagVar) {
f.flags = append(f.flags, fl)
f.flags[fl.Name] = fl
}

func (f *Flagset) Help() string {
Expand All @@ -51,9 +60,12 @@ func (f *Flagset) Help() string {
}

func (f *Flagset) GetAllFlags() []string {
flags := []string{}
for _, flag := range f.flags {
flags = append(flags, flag.Name)
i := 0
flags := make([]string, 0, len(f.flags))

for name := range f.flags {
flags[i] = name
i++
}

return flags
Expand Down Expand Up @@ -110,6 +122,33 @@ func (f *Flagset) Args() []string {
return f.set.Args()
}

// UpdateValue updates the underlying value of a flag
// given the flag name and value to update using pointer.
func (f *Flagset) UpdateValue(names []string, values []string) {
for i, name := range names {
if flag, ok := f.flags[name]; ok {
value := values[i]

// Call the underlying flag's `UpdateValue` method
flag.Value.UpdateValue(value)
}
}
}

// Visit visits all the set flags and returns the name and value
// in string to set later.
func (f *Flagset) Visit() ([]string, []string) {
names := make([]string, 0, len(f.flags))
values := make([]string, 0, len(f.flags))

f.set.Visit(func(flag *flag.Flag) {
names = append(names, flag.Name)
values = append(values, flag.Value.String())
})

return names, values
}

type BoolFlag struct {
Name string
Usage string
Expand All @@ -118,12 +157,19 @@ type BoolFlag struct {
Group string
}

func (b *BoolFlag) UpdateValue(value string) {
v, _ := strconv.ParseBool(value)

*b.Value = v
}

func (f *Flagset) BoolFlag(b *BoolFlag) {
f.addFlag(&FlagVar{
Name: b.Name,
Usage: b.Usage,
Group: b.Group,
Default: b.Default,
Value: b,
})
f.set.BoolVar(b.Value, b.Name, b.Default, b.Usage)
}
Expand All @@ -137,20 +183,26 @@ type StringFlag struct {
HideDefaultFromDoc bool
}

func (b *StringFlag) UpdateValue(value string) {
*b.Value = value
}

func (f *Flagset) StringFlag(b *StringFlag) {
if b.Default == "" || b.HideDefaultFromDoc {
f.addFlag(&FlagVar{
Name: b.Name,
Usage: b.Usage,
Group: b.Group,
Default: nil,
Value: b,
})
} else {
f.addFlag(&FlagVar{
Name: b.Name,
Usage: b.Usage,
Group: b.Group,
Default: b.Default,
Value: b,
})
}

Expand All @@ -165,12 +217,19 @@ type IntFlag struct {
Group string
}

func (b *IntFlag) UpdateValue(value string) {
v, _ := strconv.ParseInt(value, 10, 64)

*b.Value = int(v)
}

func (f *Flagset) IntFlag(i *IntFlag) {
f.addFlag(&FlagVar{
Name: i.Name,
Usage: i.Usage,
Group: i.Group,
Default: i.Default,
Value: i,
})
f.set.IntVar(i.Value, i.Name, i.Default, i.Usage)
}
Expand All @@ -183,12 +242,19 @@ type Uint64Flag struct {
Group string
}

func (b *Uint64Flag) UpdateValue(value string) {
v, _ := strconv.ParseUint(value, 10, 64)

*b.Value = v
}

func (f *Flagset) Uint64Flag(i *Uint64Flag) {
f.addFlag(&FlagVar{
Name: i.Name,
Usage: i.Usage,
Group: i.Group,
Default: fmt.Sprintf("%d", i.Default),
Value: i,
})
f.set.Uint64Var(i.Value, i.Name, i.Default, i.Usage)
}
Expand All @@ -209,31 +275,47 @@ func (b *BigIntFlag) String() string {
return b.Value.String()
}

func (b *BigIntFlag) Set(value string) error {
func parseBigInt(value string) *big.Int {
num := new(big.Int)

var ok bool
if strings.HasPrefix(value, "0x") {
num, ok = num.SetString(value[2:], 16)
*b.Value = *num
num, _ = num.SetString(value[2:], 16)
} else {
num, ok = num.SetString(value, 10)
*b.Value = *num
num, _ = num.SetString(value, 10)
}

if !ok {
return num
}

func (b *BigIntFlag) Set(value string) error {
num := parseBigInt(value)

if num == nil {
return fmt.Errorf("failed to set big int")
}

*b.Value = *num

return nil
}

func (b *BigIntFlag) UpdateValue(value string) {
num := parseBigInt(value)

if num == nil {
return
}

*b.Value = *num
}

func (f *Flagset) BigIntFlag(b *BigIntFlag) {
f.addFlag(&FlagVar{
Name: b.Name,
Usage: b.Usage,
Group: b.Group,
Default: b.Default,
Value: b,
})
f.set.Var(b, b.Name, b.Usage)
}
Expand Down Expand Up @@ -273,20 +355,26 @@ func (i *SliceStringFlag) Set(value string) error {
return nil
}

func (i *SliceStringFlag) UpdateValue(value string) {
*i.Value = SplitAndTrim(value)
}

func (f *Flagset) SliceStringFlag(s *SliceStringFlag) {
if s.Default == nil || len(s.Default) == 0 {
f.addFlag(&FlagVar{
Name: s.Name,
Usage: s.Usage,
Group: s.Group,
Default: nil,
Value: s,
})
} else {
f.addFlag(&FlagVar{
Name: s.Name,
Usage: s.Usage,
Group: s.Group,
Default: strings.Join(s.Default, ","),
Value: s,
})
}

Expand All @@ -301,12 +389,19 @@ type DurationFlag struct {
Group string
}

func (d *DurationFlag) UpdateValue(value string) {
v, _ := time.ParseDuration(value)

*d.Value = v
}

func (f *Flagset) DurationFlag(d *DurationFlag) {
f.addFlag(&FlagVar{
Name: d.Name,
Usage: d.Usage,
Group: d.Group,
Default: d.Default,
Value: d,
})
f.set.DurationVar(d.Value, d.Name, d.Default, "")
}
Expand Down Expand Up @@ -336,38 +431,54 @@ func (m *MapStringFlag) String() string {
return formatMapString(*m.Value)
}

func (m *MapStringFlag) Set(value string) error {
if m.Value == nil {
m.Value = &map[string]string{}
}
func parseMap(value string) map[string]string {
m := make(map[string]string)

for _, t := range strings.Split(value, ",") {
if t != "" {
kv := strings.Split(t, "=")

if len(kv) == 2 {
(*m.Value)[kv[0]] = kv[1]
m[kv[0]] = kv[1]
}
}
}

return m
}

func (m *MapStringFlag) Set(value string) error {
if m.Value == nil {
m.Value = &map[string]string{}
}

m2 := parseMap(value)
*m.Value = m2

return nil
}

func (m *MapStringFlag) UpdateValue(value string) {
m2 := parseMap(value)
*m.Value = m2
}

func (f *Flagset) MapStringFlag(m *MapStringFlag) {
if m.Default == nil || len(m.Default) == 0 {
f.addFlag(&FlagVar{
Name: m.Name,
Usage: m.Usage,
Group: m.Group,
Default: nil,
Value: m,
})
} else {
f.addFlag(&FlagVar{
Name: m.Name,
Usage: m.Usage,
Group: m.Group,
Default: formatMapString(m.Default),
Value: m,
})
}

Expand All @@ -382,12 +493,19 @@ type Float64Flag struct {
Group string
}

func (f *Float64Flag) UpdateValue(value string) {
v, _ := strconv.ParseFloat(value, 64)

*f.Value = v
}

func (f *Flagset) Float64Flag(i *Float64Flag) {
f.addFlag(&FlagVar{
Name: i.Name,
Usage: i.Usage,
Group: i.Group,
Default: i.Default,
Value: i,
})
f.set.Float64Var(i.Value, i.Name, i.Default, "")
}
Loading

0 comments on commit 8613ff1

Please sign in to comment.