Skip to content

Commit

Permalink
feat: added regex and net types: ip, ipnet, cidr
Browse files Browse the repository at this point in the history
  • Loading branch information
marsom committed May 20, 2024
1 parent 2ab5733 commit f449f3d
Show file tree
Hide file tree
Showing 2 changed files with 268 additions and 0 deletions.
104 changes: 104 additions & 0 deletions mapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,12 @@ import (
"fmt"
"io"
"math/bits"
"net"
"net/netip"
"net/url"
"os"
"reflect"
"regexp"
"strconv"
"strings"
"time"
Expand Down Expand Up @@ -285,6 +288,11 @@ func (r *Registry) RegisterDefaults() *Registry {
RegisterType(reflect.TypeOf(time.Duration(0)), durationDecoder()).
RegisterType(reflect.TypeOf(&url.URL{}), urlMapper()).
RegisterType(reflect.TypeOf(&os.File{}), fileMapper(r)).
RegisterType(reflect.TypeOf(&regexp.Regexp{}), regexMapper()).
RegisterType(reflect.TypeOf(&net.IP{}), netIPMapper()).
RegisterType(reflect.TypeOf(&net.IPNet{}), netIPNetMapper()).
RegisterType(reflect.TypeOf(netip.Addr{}), netipAddrMapper()).
RegisterType(reflect.TypeOf(netip.Prefix{}), netipPrefixMapper()).
RegisterName("path", pathMapper(r)).
RegisterName("existingfile", existingFileMapper(r)).
RegisterName("existingdir", existingDirMapper(r)).
Expand Down Expand Up @@ -733,6 +741,102 @@ func fileContentMapper(r *Registry) MapperFunc {
}
}

func regexMapper() MapperFunc {
return func(ctx *DecodeContext, target reflect.Value) error {
t, err := ctx.Scan.PopValue("regex")
if err != nil {
return err
}

var f *regexp.Regexp
switch v := t.Value.(type) {
case string:
f, err = regexp.Compile(v)
if err != nil {
return fmt.Errorf("expected regular expression but got %q: %w", v, err)
}
default:
return fmt.Errorf("expected string but got %q", v)
}

target.Set(reflect.ValueOf(f))

return nil
}
}

func netIPMapper() MapperFunc {
return func(ctx *DecodeContext, target reflect.Value) error {
var value string
if err := ctx.Scan.PopValueInto("ip", &value); err != nil {
return err
}

ip := net.ParseIP(value)
if ip == nil {
return fmt.Errorf("expected ip addresss but got %q", value)
}

target.Set(reflect.ValueOf(ip))

return nil
}
}

func netIPNetMapper() MapperFunc {
return func(ctx *DecodeContext, target reflect.Value) error {
var value string
if err := ctx.Scan.PopValueInto("cidr", &value); err != nil {
return err
}

_, ipnet, err := net.ParseCIDR(value)
if err != nil {
return fmt.Errorf("expected cidr but got %q: %w", value, err)
}

target.Set(reflect.ValueOf(ipnet))

return nil
}
}

func netipAddrMapper() MapperFunc {
return func(ctx *DecodeContext, target reflect.Value) error {
var value string
if err := ctx.Scan.PopValueInto("ip", &value); err != nil {
return err
}

ip, err := netip.ParseAddr(value)
if err != nil {
return fmt.Errorf("expected ip addresss but got %q: %w", value, err)
}

target.Set(reflect.ValueOf(ip))

return nil
}
}

func netipPrefixMapper() MapperFunc {
return func(ctx *DecodeContext, target reflect.Value) error {
var value string
if err := ctx.Scan.PopValueInto("cidr", &value); err != nil {
return err
}

prefix, err := netip.ParsePrefix(value)
if err != nil {
return fmt.Errorf("expected ipnet but got %q: %w", value, err)
}

target.Set(reflect.ValueOf(prefix))

return nil
}
}

type ptrMapper struct {
r *Registry
}
Expand Down
164 changes: 164 additions & 0 deletions mapper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,13 @@ import (
"encoding/json"
"fmt"
"math"
"net"
"net/netip"
"net/url"
"os"
"path/filepath"
"reflect"
"regexp"
"strings"
"testing"
"time"
Expand Down Expand Up @@ -138,6 +141,167 @@ func TestDurationMapperJSONResolver(t *testing.T) {
assert.Equal(t, time.Second*5, cli.Flag)
}

func TestNetIP(t *testing.T) {
var cli struct {
Flag net.IP
}
k := mustNew(t, &cli)
_, err := k.Parse([]string{"--flag", "127.0.0.1"})
assert.NoError(t, err)
assert.Equal(t, "127.0.0.1", cli.Flag.String())

_, err = k.Parse([]string{"--flag", "2001:db8:abcd:0012::1"})
assert.NoError(t, err)
assert.Equal(t, "2001:db8:abcd:12::1", cli.Flag.String())
}

func TestNetIPSplice(t *testing.T) {
var cli struct {
Flag []net.IP
}
k := mustNew(t, &cli)
_, err := k.Parse([]string{
"--flag", "127.0.0.1",
"--flag", "192.168.0.1",
"--flag", "2001:db8:abcd:0012::1",
})
assert.NoError(t, err)

assert.Equal(t, 3, len(cli.Flag))
assert.Equal(t, "127.0.0.1", cli.Flag[0].String())
assert.Equal(t, "192.168.0.1", cli.Flag[1].String())
assert.Equal(t, "2001:db8:abcd:12::1", cli.Flag[2].String())
}

func TestIPNet(t *testing.T) {
var cli struct {
Flag *net.IPNet
}
k := mustNew(t, &cli)
_, err := k.Parse([]string{"--flag", "127.0.0.0/24"})
assert.NoError(t, err)
assert.Equal(t, "127.0.0.0/24", cli.Flag.String())

_, err = k.Parse([]string{"--flag", "2001:db8:abcd:0012::0/64"})
assert.NoError(t, err)
assert.Equal(t, "2001:db8:abcd:12::/64", cli.Flag.String())
}

func TestIPNetSlice(t *testing.T) {
var cli struct {
Test []*net.IPNet
}

k := mustNew(t, &cli)
_, err := k.Parse([]string{
"--test", "127.0.0.0/24",
"--test", "123.0.0.0/23",
"--test", "2001:db8:abcd:0012::0/64",
})
assert.NoError(t, err)

assert.Equal(t, 3, len(cli.Test))
assert.Equal(t, "127.0.0.0/24", cli.Test[0].String())
assert.Equal(t, "123.0.0.0/23", cli.Test[1].String())
assert.Equal(t, "2001:db8:abcd:12::/64", cli.Test[2].String())
}

func TestNetipAddr(t *testing.T) {
var cli struct {
Flag netip.Addr
}
k := mustNew(t, &cli)
_, err := k.Parse([]string{"--flag", "127.0.0.1"})
assert.NoError(t, err)
assert.Equal(t, "127.0.0.1", cli.Flag.String())

_, err = k.Parse([]string{"--flag", "2001:db8:abcd:0012::1"})
assert.NoError(t, err)
assert.Equal(t, "2001:db8:abcd:12::1", cli.Flag.String())
}

func TestNetipAddrSplice(t *testing.T) {
var cli struct {
Flag []net.IP
}
k := mustNew(t, &cli)
_, err := k.Parse([]string{
"--flag", "127.0.0.1",
"--flag", "192.168.0.1",
"--flag", "2001:db8:abcd:0012::1",
})
assert.NoError(t, err)

assert.Equal(t, 3, len(cli.Flag))
assert.Equal(t, "127.0.0.1", cli.Flag[0].String())
assert.Equal(t, "192.168.0.1", cli.Flag[1].String())
assert.Equal(t, "2001:db8:abcd:12::1", cli.Flag[2].String())
}

func TestNetipPrefix(t *testing.T) {
var cli struct {
Flag netip.Prefix
}
k := mustNew(t, &cli)

_, err := k.Parse([]string{"--flag", "127.0.0.0/24"})
assert.NoError(t, err)
assert.Equal(t, "127.0.0.0/24", cli.Flag.String())
assert.Equal(t, 24, cli.Flag.Bits())

_, err = k.Parse([]string{"--flag", "2001:db8:abcd:0012::0/64"})
assert.NoError(t, err)
assert.Equal(t, "2001:db8:abcd:12::/64", cli.Flag.String())
}

func TestNetipPrefixSlice(t *testing.T) {
var cli struct {
Test []*net.IPNet `kong:"sep=','"`
}

k := mustNew(t, &cli)
_, err := k.Parse([]string{
"--test", "127.0.0.0/24",
"--test", "123.0.0.0/23",
"--test", "2001:db8:abcd:0012::0/64",
})
assert.NoError(t, err)

assert.Equal(t, 3, len(cli.Test))
assert.Equal(t, "127.0.0.0/24", cli.Test[0].String())
assert.Equal(t, "123.0.0.0/23", cli.Test[1].String())
assert.Equal(t, "2001:db8:abcd:12::/64", cli.Test[2].String())
}

func TestRegex(t *testing.T) {
var cli struct {
Test *regexp.Regexp
}

k := mustNew(t, &cli)
_, err := k.Parse([]string{"--test", "a.+[a-b]{2,4}"})
assert.NoError(t, err)
assert.Equal(t, "a.+[a-b]{2,4}", cli.Test.String())
}

func TestRegexSlice(t *testing.T) {
var cli struct {
Test []*regexp.Regexp `kong:"sep='none'"`
}

k := mustNew(t, &cli)
_, err := k.Parse([]string{
"--test", "foo.+[b-r]{2,4}",
"--test", "foo=bar",
})
assert.NoError(t, err)

assert.Equal(t, 2, len(cli.Test))
assert.Equal(t, "foo.+[b-r]{2,4}", cli.Test[0].String())
assert.Equal(t, "foo=bar", cli.Test[1].String())

}

func TestSplitEscaped(t *testing.T) {
assert.Equal(t, []string{"a", "b"}, kong.SplitEscaped("a,b", ','))
assert.Equal(t, []string{"a,b", "c"}, kong.SplitEscaped(`a\,b,c`, ','))
Expand Down

0 comments on commit f449f3d

Please sign in to comment.