diff --git a/mapper.go b/mapper.go index 584bb00..22013ad 100644 --- a/mapper.go +++ b/mapper.go @@ -7,9 +7,12 @@ import ( "fmt" "io" "math/bits" + "net" + "net/netip" "net/url" "os" "reflect" + "regexp" "strconv" "strings" "time" @@ -285,6 +288,11 @@ func (r *Registry) RegisterDefaults() *Registry { RegisterType(reflect.TypeOf(time.Duration(0)), durationDecoder()). RegisterType(reflect.TypeOf(&url.URL{}), urlMapper()). RegisterType(reflect.TypeOf(&os.File{}), fileMapper(r)). + RegisterType(reflect.TypeOf(®exp.Regexp{}), regexMapper()). + RegisterType(reflect.TypeOf(&net.IP{}), netIPMapper()). + RegisterType(reflect.TypeOf(&net.IPNet{}), netIPNetMapper()). + RegisterType(reflect.TypeOf(netip.Addr{}), netipAddrMapper()). + RegisterType(reflect.TypeOf(netip.Prefix{}), netipPrefixMapper()). RegisterName("path", pathMapper(r)). RegisterName("existingfile", existingFileMapper(r)). RegisterName("existingdir", existingDirMapper(r)). @@ -733,6 +741,102 @@ func fileContentMapper(r *Registry) MapperFunc { } } +func regexMapper() MapperFunc { + return func(ctx *DecodeContext, target reflect.Value) error { + t, err := ctx.Scan.PopValue("regex") + if err != nil { + return err + } + + var f *regexp.Regexp + switch v := t.Value.(type) { + case string: + f, err = regexp.Compile(v) + if err != nil { + return fmt.Errorf("expected regular expression but got %q: %w", v, err) + } + default: + return fmt.Errorf("expected string but got %q", v) + } + + target.Set(reflect.ValueOf(f)) + + return nil + } +} + +func netIPMapper() MapperFunc { + return func(ctx *DecodeContext, target reflect.Value) error { + var value string + if err := ctx.Scan.PopValueInto("ip", &value); err != nil { + return err + } + + ip := net.ParseIP(value) + if ip == nil { + return fmt.Errorf("expected ip addresss but got %q", value) + } + + target.Set(reflect.ValueOf(ip)) + + return nil + } +} + +func netIPNetMapper() MapperFunc { + return func(ctx *DecodeContext, target reflect.Value) error { + var value string + if err := ctx.Scan.PopValueInto("cidr", &value); err != nil { + return err + } + + _, ipnet, err := net.ParseCIDR(value) + if err != nil { + return fmt.Errorf("expected cidr but got %q: %w", value, err) + } + + target.Set(reflect.ValueOf(ipnet)) + + return nil + } +} + +func netipAddrMapper() MapperFunc { + return func(ctx *DecodeContext, target reflect.Value) error { + var value string + if err := ctx.Scan.PopValueInto("ip", &value); err != nil { + return err + } + + ip, err := netip.ParseAddr(value) + if err != nil { + return fmt.Errorf("expected ip addresss but got %q: %w", value, err) + } + + target.Set(reflect.ValueOf(ip)) + + return nil + } +} + +func netipPrefixMapper() MapperFunc { + return func(ctx *DecodeContext, target reflect.Value) error { + var value string + if err := ctx.Scan.PopValueInto("cidr", &value); err != nil { + return err + } + + prefix, err := netip.ParsePrefix(value) + if err != nil { + return fmt.Errorf("expected ipnet but got %q: %w", value, err) + } + + target.Set(reflect.ValueOf(prefix)) + + return nil + } +} + type ptrMapper struct { r *Registry } diff --git a/mapper_test.go b/mapper_test.go index 113e9f5..d919455 100644 --- a/mapper_test.go +++ b/mapper_test.go @@ -5,10 +5,13 @@ import ( "encoding/json" "fmt" "math" + "net" + "net/netip" "net/url" "os" "path/filepath" "reflect" + "regexp" "strings" "testing" "time" @@ -138,6 +141,167 @@ func TestDurationMapperJSONResolver(t *testing.T) { assert.Equal(t, time.Second*5, cli.Flag) } +func TestNetIP(t *testing.T) { + var cli struct { + Flag net.IP + } + k := mustNew(t, &cli) + _, err := k.Parse([]string{"--flag", "127.0.0.1"}) + assert.NoError(t, err) + assert.Equal(t, "127.0.0.1", cli.Flag.String()) + + _, err = k.Parse([]string{"--flag", "2001:db8:abcd:0012::1"}) + assert.NoError(t, err) + assert.Equal(t, "2001:db8:abcd:12::1", cli.Flag.String()) +} + +func TestNetIPSplice(t *testing.T) { + var cli struct { + Flag []net.IP + } + k := mustNew(t, &cli) + _, err := k.Parse([]string{ + "--flag", "127.0.0.1", + "--flag", "192.168.0.1", + "--flag", "2001:db8:abcd:0012::1", + }) + assert.NoError(t, err) + + assert.Equal(t, 3, len(cli.Flag)) + assert.Equal(t, "127.0.0.1", cli.Flag[0].String()) + assert.Equal(t, "192.168.0.1", cli.Flag[1].String()) + assert.Equal(t, "2001:db8:abcd:12::1", cli.Flag[2].String()) +} + +func TestIPNet(t *testing.T) { + var cli struct { + Flag *net.IPNet + } + k := mustNew(t, &cli) + _, err := k.Parse([]string{"--flag", "127.0.0.0/24"}) + assert.NoError(t, err) + assert.Equal(t, "127.0.0.0/24", cli.Flag.String()) + + _, err = k.Parse([]string{"--flag", "2001:db8:abcd:0012::0/64"}) + assert.NoError(t, err) + assert.Equal(t, "2001:db8:abcd:12::/64", cli.Flag.String()) +} + +func TestIPNetSlice(t *testing.T) { + var cli struct { + Test []*net.IPNet + } + + k := mustNew(t, &cli) + _, err := k.Parse([]string{ + "--test", "127.0.0.0/24", + "--test", "123.0.0.0/23", + "--test", "2001:db8:abcd:0012::0/64", + }) + assert.NoError(t, err) + + assert.Equal(t, 3, len(cli.Test)) + assert.Equal(t, "127.0.0.0/24", cli.Test[0].String()) + assert.Equal(t, "123.0.0.0/23", cli.Test[1].String()) + assert.Equal(t, "2001:db8:abcd:12::/64", cli.Test[2].String()) +} + +func TestNetipAddr(t *testing.T) { + var cli struct { + Flag netip.Addr + } + k := mustNew(t, &cli) + _, err := k.Parse([]string{"--flag", "127.0.0.1"}) + assert.NoError(t, err) + assert.Equal(t, "127.0.0.1", cli.Flag.String()) + + _, err = k.Parse([]string{"--flag", "2001:db8:abcd:0012::1"}) + assert.NoError(t, err) + assert.Equal(t, "2001:db8:abcd:12::1", cli.Flag.String()) +} + +func TestNetipAddrSplice(t *testing.T) { + var cli struct { + Flag []net.IP + } + k := mustNew(t, &cli) + _, err := k.Parse([]string{ + "--flag", "127.0.0.1", + "--flag", "192.168.0.1", + "--flag", "2001:db8:abcd:0012::1", + }) + assert.NoError(t, err) + + assert.Equal(t, 3, len(cli.Flag)) + assert.Equal(t, "127.0.0.1", cli.Flag[0].String()) + assert.Equal(t, "192.168.0.1", cli.Flag[1].String()) + assert.Equal(t, "2001:db8:abcd:12::1", cli.Flag[2].String()) +} + +func TestNetipPrefix(t *testing.T) { + var cli struct { + Flag netip.Prefix + } + k := mustNew(t, &cli) + + _, err := k.Parse([]string{"--flag", "127.0.0.0/24"}) + assert.NoError(t, err) + assert.Equal(t, "127.0.0.0/24", cli.Flag.String()) + assert.Equal(t, 24, cli.Flag.Bits()) + + _, err = k.Parse([]string{"--flag", "2001:db8:abcd:0012::0/64"}) + assert.NoError(t, err) + assert.Equal(t, "2001:db8:abcd:12::/64", cli.Flag.String()) +} + +func TestNetipPrefixSlice(t *testing.T) { + var cli struct { + Test []*net.IPNet `kong:"sep=','"` + } + + k := mustNew(t, &cli) + _, err := k.Parse([]string{ + "--test", "127.0.0.0/24", + "--test", "123.0.0.0/23", + "--test", "2001:db8:abcd:0012::0/64", + }) + assert.NoError(t, err) + + assert.Equal(t, 3, len(cli.Test)) + assert.Equal(t, "127.0.0.0/24", cli.Test[0].String()) + assert.Equal(t, "123.0.0.0/23", cli.Test[1].String()) + assert.Equal(t, "2001:db8:abcd:12::/64", cli.Test[2].String()) +} + +func TestRegex(t *testing.T) { + var cli struct { + Test *regexp.Regexp + } + + k := mustNew(t, &cli) + _, err := k.Parse([]string{"--test", "a.+[a-b]{2,4}"}) + assert.NoError(t, err) + assert.Equal(t, "a.+[a-b]{2,4}", cli.Test.String()) +} + +func TestRegexSlice(t *testing.T) { + var cli struct { + Test []*regexp.Regexp `kong:"sep='none'"` + } + + k := mustNew(t, &cli) + _, err := k.Parse([]string{ + "--test", "foo.+[b-r]{2,4}", + "--test", "foo=bar", + }) + assert.NoError(t, err) + + assert.Equal(t, 2, len(cli.Test)) + assert.Equal(t, "foo.+[b-r]{2,4}", cli.Test[0].String()) + assert.Equal(t, "foo=bar", cli.Test[1].String()) + +} + func TestSplitEscaped(t *testing.T) { assert.Equal(t, []string{"a", "b"}, kong.SplitEscaped("a,b", ',')) assert.Equal(t, []string{"a,b", "c"}, kong.SplitEscaped(`a\,b,c`, ','))