Skip to content

Commit

Permalink
internal/cli/flagset: handle flag conversion using interface
Browse files Browse the repository at this point in the history
  • Loading branch information
manav2401 committed Sep 20, 2023
1 parent 6ce00f3 commit 95aca1e
Showing 1 changed file with 72 additions and 100 deletions.
172 changes: 72 additions & 100 deletions internal/cli/flagset/flagset.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,10 @@ import (
"flag"
"fmt"
"math/big"
"reflect"
"sort"
"strconv"
"strings"
"time"

"github.com/ethereum/go-ethereum/log"
)

type Flagset struct {
Expand All @@ -27,12 +24,19 @@ func NewFlagSet(name string) *Flagset {
return f
}

// Updatable is a minimalistic representation of a flag which has
// the method `UpdateValue` implemented which can be called while
// overwriting flags.
type Updatable interface {
UpdateValue(string)
}

type FlagVar struct {
Name string
Usage string
Group string
Default any
Value any
Value Updatable
}

func (f *Flagset) addFlag(fl *FlagVar) {
Expand Down Expand Up @@ -125,50 +129,8 @@ func (f *Flagset) UpdateValue(names []string, values []string) {
if flag, ok := f.flags[name]; ok {
value := values[i]

// Get the underlying value set in flag
old := reflect.ValueOf(flag.Value).Elem()
oldType := old.Type()

// Create the new value to set based on the kind of old value. Each
// type of flag supported needs to be parsed individually because
// we receive the value to set (in `values`) as string and it's
// not possible to convert them to the underlying type directly
// at runtime.
var newValue any

// nolint:exhaustive
switch oldType.Kind() {
// Handle default data types first
case reflect.Bool:
newValue = GetBool(value)
case reflect.String:
newValue = value
case reflect.Int:
newValue = GetInt(value)
case reflect.Uint64:
newValue = GetUint64(value)
case reflect.Float64:
newValue = GetFloat64(value)
default:
// Handle custom data types
switch oldType {
case reflect.TypeOf(big.Int{}):
newValue = GetBigInt(value)
case reflect.TypeOf([]string{}):
newValue = GetSliceString(value)
case reflect.TypeOf(time.Second):
newValue = GetDuration(value)
case reflect.TypeOf(map[string]string{}):
newValue = GetMapString(value)
default:
log.Info("Unable to parse the type while overriding flag, skipping", "flag", name, "got type", oldType)
continue
}
}

// Now that both old and new values are of same type, set the
// new value.
old.Set(reflect.ValueOf(newValue))
// Call the underlying flag's `UpdateValue` method
flag.Value.UpdateValue(value)
}
}
}
Expand All @@ -195,10 +157,10 @@ type BoolFlag struct {
Group string
}

func GetBool(value string) bool {
func (b *BoolFlag) UpdateValue(value string) {
v, _ := strconv.ParseBool(value)

return v
*b.Value = v
}

func (f *Flagset) BoolFlag(b *BoolFlag) {
Expand All @@ -207,7 +169,7 @@ func (f *Flagset) BoolFlag(b *BoolFlag) {
Usage: b.Usage,
Group: b.Group,
Default: b.Default,
Value: b.Value,
Value: b,
})
f.set.BoolVar(b.Value, b.Name, b.Default, b.Usage)
}
Expand All @@ -221,22 +183,26 @@ type StringFlag struct {
HideDefaultFromDoc bool
}

func (b *StringFlag) UpdateValue(value string) {
*b.Value = value
}

func (f *Flagset) StringFlag(b *StringFlag) {
if b.Default == "" || b.HideDefaultFromDoc {
f.addFlag(&FlagVar{
Name: b.Name,
Usage: b.Usage,
Group: b.Group,
Default: nil,
Value: b.Value,
Value: b,
})
} else {
f.addFlag(&FlagVar{
Name: b.Name,
Usage: b.Usage,
Group: b.Group,
Default: b.Default,
Value: b.Value,
Value: b,
})
}

Expand All @@ -251,10 +217,10 @@ type IntFlag struct {
Group string
}

func GetInt(value string) int {
func (b *IntFlag) UpdateValue(value string) {
v, _ := strconv.ParseInt(value, 10, 64)

return int(v)
*b.Value = int(v)
}

func (f *Flagset) IntFlag(i *IntFlag) {
Expand All @@ -263,7 +229,7 @@ func (f *Flagset) IntFlag(i *IntFlag) {
Usage: i.Usage,
Group: i.Group,
Default: i.Default,
Value: i.Value,
Value: i,
})
f.set.IntVar(i.Value, i.Name, i.Default, i.Usage)
}
Expand All @@ -276,10 +242,10 @@ type Uint64Flag struct {
Group string
}

func GetUint64(value string) uint64 {
func (b *Uint64Flag) UpdateValue(value string) {
v, _ := strconv.ParseUint(value, 10, 64)

return v
*b.Value = v
}

func (f *Flagset) Uint64Flag(i *Uint64Flag) {
Expand All @@ -288,7 +254,7 @@ func (f *Flagset) Uint64Flag(i *Uint64Flag) {
Usage: i.Usage,
Group: i.Group,
Default: fmt.Sprintf("%d", i.Default),
Value: i.Value,
Value: i,
})
f.set.Uint64Var(i.Value, i.Name, i.Default, i.Usage)
}
Expand All @@ -309,30 +275,38 @@ func (b *BigIntFlag) String() string {
return b.Value.String()
}

func (b *BigIntFlag) Set(value string) error {
func parseBigInt(value string) *big.Int {
num := new(big.Int)

var ok bool
if strings.HasPrefix(value, "0x") {
num, ok = num.SetString(value[2:], 16)
*b.Value = *num
num, _ = num.SetString(value[2:], 16)
} else {
num, ok = num.SetString(value, 10)
*b.Value = *num
num, _ = num.SetString(value, 10)
}

if !ok {
return num
}

func (b *BigIntFlag) Set(value string) error {
num := parseBigInt(value)

if num == nil {
return fmt.Errorf("failed to set big int")
}

*b.Value = *num

return nil
}

func GetBigInt(value string) big.Int {
v := new(big.Int)
v.SetString(value, 10)
func (b *BigIntFlag) UpdateValue(value string) {
num := parseBigInt(value)

if num == nil {
return
}

return *v
*b.Value = *num
}

func (f *Flagset) BigIntFlag(b *BigIntFlag) {
Expand All @@ -341,7 +315,7 @@ func (f *Flagset) BigIntFlag(b *BigIntFlag) {
Usage: b.Usage,
Group: b.Group,
Default: b.Default,
Value: b.Value,
Value: b,
})
f.set.Var(b, b.Name, b.Usage)
}
Expand Down Expand Up @@ -381,8 +355,8 @@ func (i *SliceStringFlag) Set(value string) error {
return nil
}

func GetSliceString(value string) []string {
return SplitAndTrim(value)
func (i *SliceStringFlag) UpdateValue(value string) {
*i.Value = SplitAndTrim(value)
}

func (f *Flagset) SliceStringFlag(s *SliceStringFlag) {
Expand All @@ -392,15 +366,15 @@ func (f *Flagset) SliceStringFlag(s *SliceStringFlag) {
Usage: s.Usage,
Group: s.Group,
Default: nil,
Value: s.Value,
Value: s,
})
} else {
f.addFlag(&FlagVar{
Name: s.Name,
Usage: s.Usage,
Group: s.Group,
Default: strings.Join(s.Default, ","),
Value: s.Value,
Value: s,
})
}

Expand All @@ -415,10 +389,10 @@ type DurationFlag struct {
Group string
}

func GetDuration(value string) time.Duration {
func (d *DurationFlag) UpdateValue(value string) {
v, _ := time.ParseDuration(value)

return v
*d.Value = v
}

func (f *Flagset) DurationFlag(d *DurationFlag) {
Expand All @@ -427,7 +401,7 @@ func (f *Flagset) DurationFlag(d *DurationFlag) {
Usage: d.Usage,
Group: d.Group,
Default: d.Default,
Value: d.Value,
Value: d,
})
f.set.DurationVar(d.Value, d.Name, d.Default, "")
}
Expand Down Expand Up @@ -457,38 +431,36 @@ func (m *MapStringFlag) String() string {
return formatMapString(*m.Value)
}

func (m *MapStringFlag) Set(value string) error {
if m.Value == nil {
m.Value = &map[string]string{}
}
func parseMap(value string) map[string]string {
m := make(map[string]string)

for _, t := range strings.Split(value, ",") {
if t != "" {
kv := strings.Split(t, "=")

if len(kv) == 2 {
(*m.Value)[kv[0]] = kv[1]
m[kv[0]] = kv[1]
}
}
}

return nil
return m
}

func GetMapString(value string) map[string]string {
m := make(map[string]string)
func (m *MapStringFlag) Set(value string) error {
if m.Value == nil {
m.Value = &map[string]string{}
}

for _, t := range strings.Split(value, ",") {
if t != "" {
kv := strings.Split(t, "=")
m2 := parseMap(value)
*m.Value = m2

if len(kv) == 2 {
m[kv[0]] = kv[1]
}
}
}
return nil
}

return m
func (m *MapStringFlag) UpdateValue(value string) {
m2 := parseMap(value)
*m.Value = m2
}

func (f *Flagset) MapStringFlag(m *MapStringFlag) {
Expand All @@ -498,15 +470,15 @@ func (f *Flagset) MapStringFlag(m *MapStringFlag) {
Usage: m.Usage,
Group: m.Group,
Default: nil,
Value: m.Value,
Value: m,
})
} else {
f.addFlag(&FlagVar{
Name: m.Name,
Usage: m.Usage,
Group: m.Group,
Default: formatMapString(m.Default),
Value: m.Value,
Value: m,
})
}

Expand All @@ -521,10 +493,10 @@ type Float64Flag struct {
Group string
}

func GetFloat64(value string) float64 {
func (f *Float64Flag) UpdateValue(value string) {
v, _ := strconv.ParseFloat(value, 64)

return v
*f.Value = v
}

func (f *Flagset) Float64Flag(i *Float64Flag) {
Expand All @@ -533,7 +505,7 @@ func (f *Flagset) Float64Flag(i *Float64Flag) {
Usage: i.Usage,
Group: i.Group,
Default: i.Default,
Value: i.Value,
Value: i,
})
f.set.Float64Var(i.Value, i.Name, i.Default, "")
}

0 comments on commit 95aca1e

Please sign in to comment.