Skip to content

Commit

Permalink
Merge pull request #86 from vimeo/integral_type_slice_pflags
Browse files Browse the repository at this point in the history
pflag: add integral-slice and uintptr support
  • Loading branch information
dfinkel authored Jan 29, 2024
2 parents 023c057 + 7b6fcff commit 8efef7f
Show file tree
Hide file tree
Showing 2 changed files with 296 additions and 6 deletions.
67 changes: 61 additions & 6 deletions sources/pflag/pflag.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,29 @@ var (
int32Type = reflect.TypeOf(int32(0))
int64Type = reflect.TypeOf(int64(0))

uintType = reflect.TypeOf(uint(0))
uint8Type = reflect.TypeOf(uint8(0))
uint16Type = reflect.TypeOf(uint16(0))
uint32Type = reflect.TypeOf(uint32(0))
uint64Type = reflect.TypeOf(uint64(0))
uintType = reflect.TypeOf(uint(0))
uint8Type = reflect.TypeOf(uint8(0))
uint16Type = reflect.TypeOf(uint16(0))
uint32Type = reflect.TypeOf(uint32(0))
uint64Type = reflect.TypeOf(uint64(0))
uintptrType = reflect.TypeOf(uintptr(0))

complex64Type = reflect.TypeOf((*complex64)(nil))
complex128Type = reflect.TypeOf((*complex128)(nil))

intSliceType = reflect.SliceOf(intType)
int8SliceType = reflect.SliceOf(int8Type)
int16SliceType = reflect.SliceOf(int16Type)
int32SliceType = reflect.SliceOf(int32Type)
int64SliceType = reflect.SliceOf(int64Type)

uintSliceType = reflect.SliceOf(uintType)
uint8SliceType = reflect.SliceOf(uint8Type)
uint16SliceType = reflect.SliceOf(uint16Type)
uint32SliceType = reflect.SliceOf(uint32Type)
uint64SliceType = reflect.SliceOf(uint64Type)
uintptrSliceType = reflect.SliceOf(uintptrType)

// Verify that Set implements the dials.Source interface
_ dials.Source = (*Set)(nil)
)
Expand Down Expand Up @@ -313,6 +327,8 @@ func (s *Set) registerFlags(tmpl reflect.Value, ptyp reflect.Type) error {
f = s.Flags.Uint32P(name, shorthand, fieldVal.Convert(uint32Type).Interface().(uint32), help)
case reflect.Uint64:
f = s.Flags.Uint64P(name, shorthand, fieldVal.Convert(uint64Type).Interface().(uint64), help)
case reflect.Uintptr:
f = s.Flags.Uint64P(name, shorthand, uint64(fieldVal.Convert(uintptrType).Interface().(uintptr)), help)
case reflect.Slice, reflect.Map:
switch ft {
case stringSlice:
Expand All @@ -326,6 +342,44 @@ func (s *Set) registerFlags(tmpl reflect.Value, ptyp reflect.Type) error {
case stringSet:
f = fieldVal.Addr().Interface()
s.Flags.VarP(flaghelper.NewStringSetFlag(fieldVal.Addr().Interface().(*map[string]struct{})), name, shorthand, help)

// signed integral slices
case intSliceType:
f = fieldVal.Addr().Interface()
s.Flags.VarP(flaghelper.NewSignedIntegralSlice(f.(*[]int)), name, shorthand, help)
case int8SliceType:
f = fieldVal.Addr().Interface()
s.Flags.VarP(flaghelper.NewSignedIntegralSlice(f.(*[]int8)), name, shorthand, help)
case int16SliceType:
f = fieldVal.Addr().Interface()
s.Flags.VarP(flaghelper.NewSignedIntegralSlice(f.(*[]int16)), name, shorthand, help)
case int32SliceType:
f = fieldVal.Addr().Interface()
s.Flags.VarP(flaghelper.NewSignedIntegralSlice(f.(*[]int32)), name, shorthand, help)
case int64SliceType:
f = fieldVal.Addr().Interface()
s.Flags.VarP(flaghelper.NewSignedIntegralSlice(f.(*[]int64)), name, shorthand, help)

// unsigned integral slices
case uintSliceType:
f = fieldVal.Addr().Interface()
s.Flags.VarP(flaghelper.NewUnsignedIntegralSlice(f.(*[]uint)), name, shorthand, help)
case uint8SliceType:
f = fieldVal.Addr().Interface()
s.Flags.VarP(flaghelper.NewUnsignedIntegralSlice(f.(*[]uint8)), name, shorthand, help)
case uint16SliceType:
f = fieldVal.Addr().Interface()
s.Flags.VarP(flaghelper.NewUnsignedIntegralSlice(f.(*[]uint16)), name, shorthand, help)
case uint32SliceType:
f = fieldVal.Addr().Interface()
s.Flags.VarP(flaghelper.NewUnsignedIntegralSlice(f.(*[]uint32)), name, shorthand, help)
case uint64SliceType:
f = fieldVal.Addr().Interface()
s.Flags.VarP(flaghelper.NewUnsignedIntegralSlice(f.(*[]uint64)), name, shorthand, help)
case uintptrSliceType:
f = fieldVal.Addr().Interface()
s.Flags.VarP(flaghelper.NewUnsignedIntegralSlice(f.(*[]uintptr)), name, shorthand, help)

default:
// Unhandled type. Just keep going.
continue
Expand Down Expand Up @@ -431,7 +485,8 @@ func (s *Set) Value(_ context.Context, t *dials.Type) (reflect.Value, error) {
return
}

cfval := fval.Convert(stripTypePtr(ffield.Type()))
// fval is always a pointer, so dereference it before converting to the final type
cfval := fval.Elem().Convert(stripTypePtr(ffield.Type()))
switch ffield.Kind() {
case reflect.Ptr:
// common case
Expand Down
235 changes: 235 additions & 0 deletions sources/pflag/pflag_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ func TestDefaultVals(t *testing.T) {
type otherUint16 uint16
type otherUint32 uint32
type otherUint64 uint64
type otherUintptr uintptr
type otherFloat32 float32
type otherFloat64 float64
type otherComplex64 complex64
Expand All @@ -99,6 +100,7 @@ func TestDefaultVals(t *testing.T) {
OUint16 otherUint16
OUint32 otherUint32
OUint64 otherUint64
OUintptr otherUintptr
OFloat32 otherFloat32
OFloat64 otherFloat64
OComplex64 otherComplex64
Expand All @@ -118,6 +120,7 @@ func TestDefaultVals(t *testing.T) {
OUint16: 3,
OUint32: 4,
OUint64: 5,
OUintptr: 0xffff_f333_7777,
OFloat32: 6.0,
OFloat64: 7.0,
OComplex64: 8 + 2i,
Expand Down Expand Up @@ -200,6 +203,24 @@ func TestPFlags(t *testing.T) {
args: []string{"--a=42"},
expected: &struct{ A int }{A: 42},
},
{
name: "basic_int_slice_set",
tmplCB: func() (any, func(ctx context.Context, src *Set) (any, error)) {
cfg := struct{ A []int }{A: []int{4}}
return &cfg, testWrapDials(&cfg)
},
args: []string{"--a=42,33"},
expected: &struct{ A []int }{A: []int{42, 33}},
},
{
name: "basic_uint_slice_set",
tmplCB: func() (any, func(ctx context.Context, src *Set) (any, error)) {
cfg := struct{ A []uint }{A: []uint{4}}
return &cfg, testWrapDials(&cfg)
},
args: []string{"--a=42,33"},
expected: &struct{ A []uint }{A: []uint{42, 33}},
},
{
name: "basic_float32_set",
tmplCB: func() (any, func(ctx context.Context, src *Set) (any, error)) {
Expand Down Expand Up @@ -294,6 +315,202 @@ func TestPFlags(t *testing.T) {
expected: nil,
expErr: "failed to parse pflags: invalid argument \"1000000\" for \"--a\" flag: strconv.ParseInt: parsing \"1000000\": value out of range",
},
{
name: "basic_uint16_slice_default",
tmplCB: func() (any, func(ctx context.Context, src *Set) (any, error)) {
cfg := struct{ A []uint16 }{A: []uint16{10}}
return &cfg, testWrapDials(&cfg)
},
args: []string{},
expected: &struct{ A []uint16 }{A: []uint16{10}},
},
{
name: "basic_uint16_slice_set_nooverflow",
tmplCB: func() (any, func(ctx context.Context, src *Set) (any, error)) {
cfg := struct{ A []uint16 }{A: []uint16{10}}
return &cfg, testWrapDials(&cfg)
},
args: []string{"--a=128,32"},
expected: &struct{ A []uint16 }{A: []uint16{128, 32}},
},
{
name: "basic_uint16_slice_set_overflow",
tmplCB: func() (any, func(ctx context.Context, src *Set) (any, error)) {
cfg := struct{ A []uint16 }{A: []uint16{10}}
return &cfg, testWrapDials(&cfg)
},
args: []string{"--a=1000000"},
expected: nil,
expErr: "failed to parse pflags: invalid argument \"1000000\" for \"--a\" flag: failed to parse integer index 0: strconv.ParseUint: parsing \"1000000\": value out of range",
},
{
name: "basic_uint32_set_nooverflow",
tmplCB: func() (any, func(ctx context.Context, src *Set) (any, error)) {
cfg := struct{ A uint32 }{A: 10}
return &cfg, testWrapDials(&cfg)
},
args: []string{"--a=128"},
expected: &struct{ A uint32 }{A: 128},
},
{
name: "basic_uint32_set_overflow",
tmplCB: func() (any, func(ctx context.Context, src *Set) (any, error)) {
cfg := struct{ A uint32 }{A: 10}
return &cfg, testWrapDials(&cfg)
},
args: []string{"--a=100_000_000_000"},
expected: nil,
expErr: "failed to parse pflags: invalid argument \"100_000_000_000\" for \"--a\" flag: strconv.ParseUint: parsing \"100_000_000_000\": value out of range",
},
{
name: "basic_uint32_slice_default",
tmplCB: func() (any, func(ctx context.Context, src *Set) (any, error)) {
cfg := struct{ A []uint32 }{A: []uint32{10}}
return &cfg, testWrapDials(&cfg)
},
args: []string{},
expected: &struct{ A []uint32 }{A: []uint32{10}},
},
{
name: "basic_uint32_slice_set_nooverflow",
tmplCB: func() (any, func(ctx context.Context, src *Set) (any, error)) {
cfg := struct{ A []uint32 }{A: []uint32{10}}
return &cfg, testWrapDials(&cfg)
},
args: []string{"--a=128,32"},
expected: &struct{ A []uint32 }{A: []uint32{128, 32}},
},
{
name: "basic_uint8_default",
tmplCB: func() (any, func(ctx context.Context, src *Set) (any, error)) {
cfg := struct{ A uint8 }{A: 10}
return &cfg, testWrapDials(&cfg)
},
args: []string{},
expected: &struct{ A uint8 }{A: 10},
},
{
name: "basic_uint8_slice_default",
tmplCB: func() (any, func(ctx context.Context, src *Set) (any, error)) {
cfg := struct{ A []uint8 }{A: []uint8{10}}
return &cfg, testWrapDials(&cfg)
},
args: []string{},
expected: &struct{ A []uint8 }{A: []uint8{10}},
},
{
name: "basic_uint8_set_nooverflow",
tmplCB: func() (any, func(ctx context.Context, src *Set) (any, error)) {
cfg := struct{ A uint8 }{A: 10}
return &cfg, testWrapDials(&cfg)
},
args: []string{"--a=125"},
expected: &struct{ A uint8 }{A: 125},
},
{
name: "basic_uint8_slice_set_nooverflow",
tmplCB: func() (any, func(ctx context.Context, src *Set) (any, error)) {
cfg := struct{ A []uint8 }{A: []uint8{10}}
return &cfg, testWrapDials(&cfg)
},
args: []string{"--a=125"},
expected: &struct{ A []uint8 }{A: []uint8{125}},
},
{
name: "basic_uint8_set_overflow",
tmplCB: func() (any, func(ctx context.Context, src *Set) (any, error)) {
cfg := struct{ A uint8 }{A: 10}
return &cfg, testWrapDials(&cfg)
},
args: []string{"--a=1000000"},
expected: nil,
expErr: "failed to parse pflags: invalid argument \"1000000\" for \"--a\" flag: strconv.ParseUint: parsing \"1000000\": value out of range",
},
{
name: "basic_uint8_slice_set_overflow",
tmplCB: func() (any, func(ctx context.Context, src *Set) (any, error)) {
cfg := struct{ A []uint8 }{A: []uint8{10}}
return &cfg, testWrapDials(&cfg)
},
args: []string{"--a=1000000"},
expected: nil,
expErr: "failed to parse pflags: invalid argument \"1000000\" for \"--a\" flag: failed to parse integer index 0: strconv.ParseUint: parsing \"1000000\": value out of range",
},
{
name: "basic_uint64_set_nooverflow",
tmplCB: func() (any, func(ctx context.Context, src *Set) (any, error)) {
cfg := struct{ A uint64 }{A: 10}
return &cfg, testWrapDials(&cfg)
},
args: []string{"--a=128"},
expected: &struct{ A uint64 }{A: 128},
},
{
name: "basic_uint64_set_overflow",
tmplCB: func() (any, func(ctx context.Context, src *Set) (any, error)) {
cfg := struct{ A uint64 }{A: 10}
return &cfg, testWrapDials(&cfg)
},
args: []string{"--a=100_000_000_000_000_000_000"},
expected: nil,
expErr: "failed to parse pflags: invalid argument \"100_000_000_000_000_000_000\" for \"--a\" flag: strconv.ParseUint: parsing \"100_000_000_000_000_000_000\": value out of range",
},
{
name: "basic_uint64_slice_default",
tmplCB: func() (any, func(ctx context.Context, src *Set) (any, error)) {
cfg := struct{ A []uint64 }{A: []uint64{10}}
return &cfg, testWrapDials(&cfg)
},
args: []string{},
expected: &struct{ A []uint64 }{A: []uint64{10}},
},
{
name: "basic_uint64_slice_set_nooverflow",
tmplCB: func() (any, func(ctx context.Context, src *Set) (any, error)) {
cfg := struct{ A []uint64 }{A: []uint64{10}}
return &cfg, testWrapDials(&cfg)
},
args: []string{"--a=128,32"},
expected: &struct{ A []uint64 }{A: []uint64{128, 32}},
},
{
name: "basic_uintptr_set_nooverflow",
tmplCB: func() (any, func(ctx context.Context, src *Set) (any, error)) {
cfg := struct{ A uintptr }{A: 10}
return &cfg, testWrapDials(&cfg)
},
args: []string{"--a=128"},
expected: &struct{ A uintptr }{A: 128},
},
{
name: "basic_uintptr_set_overflow",
tmplCB: func() (any, func(ctx context.Context, src *Set) (any, error)) {
cfg := struct{ A uintptr }{A: 10}
return &cfg, testWrapDials(&cfg)
},
args: []string{"--a=100_000_000_000_000_000_000"},
expected: nil,
expErr: "failed to parse pflags: invalid argument \"100_000_000_000_000_000_000\" for \"--a\" flag: strconv.ParseUint: parsing \"100_000_000_000_000_000_000\": value out of range",
},
{
name: "basic_uintptr_slice_default",
tmplCB: func() (any, func(ctx context.Context, src *Set) (any, error)) {
cfg := struct{ A []uintptr }{A: []uintptr{10}}
return &cfg, testWrapDials(&cfg)
},
args: []string{},
expected: &struct{ A []uintptr }{A: []uintptr{10}},
},
{
name: "basic_uintptr_slice_set_nooverflow",
tmplCB: func() (any, func(ctx context.Context, src *Set) (any, error)) {
cfg := struct{ A []uintptr }{A: []uintptr{10}}
return &cfg, testWrapDials(&cfg)
},
args: []string{"--a=128,32"},
expected: &struct{ A []uintptr }{A: []uintptr{128, 32}},
},

{
name: "map_string_string_set",
tmplCB: func() (any, func(ctx context.Context, src *Set) (any, error)) {
Expand Down Expand Up @@ -348,6 +565,24 @@ func TestPFlags(t *testing.T) {
args: []string{},
expected: &struct{ A map[string]struct{} }{A: map[string]struct{}{"i": {}}},
},
{
name: "int_slice_default_val",
tmplCB: func() (any, func(ctx context.Context, src *Set) (any, error)) {
cfg := struct{ A []int }{A: []int{33, 22}}
return &cfg, testWrapDials(&cfg)
},
args: []string{},
expected: &struct{ A []int }{A: []int{33, 22}},
},
{
name: "int_slice_default_nil",
tmplCB: func() (any, func(ctx context.Context, src *Set) (any, error)) {
cfg := struct{ A []int }{A: []int(nil)}
return &cfg, testWrapDials(&cfg)
},
args: []string{},
expected: &struct{ A []int }{A: nil},
},
{
name: "complex128_default",
tmplCB: func() (any, func(ctx context.Context, src *Set) (any, error)) {
Expand Down

0 comments on commit 8efef7f

Please sign in to comment.