diff --git a/kv/memberlist/dnsprovider.go b/kv/memberlist/dnsprovider.go index b51a5d055..82e90b31b 100644 --- a/kv/memberlist/dnsprovider.go +++ b/kv/memberlist/dnsprovider.go @@ -2,6 +2,8 @@ package memberlist import ( "context" + "net" + "sync" ) // DNSProvider supports storing or resolving a list of addresses. @@ -13,3 +15,31 @@ type DNSProvider interface { // Addresses returns the latest addresses present in the DNSProvider. Addresses() []string } + +type dnsProvider struct { + sync.Mutex + addr []string +} + +func NewDNSProvider() DNSProvider { return &dnsProvider{} } + +func (d *dnsProvider) Resolve(ctx context.Context, addrs []string) error { + d.Lock() + defer d.Unlock() + for _, a := range addrs { + ips, err := net.LookupIP(a) + if err != nil { + return err + } + for _, ip := range ips { + d.addr = append(d.addr, ip.String()) + } + } + return nil +} + +func (d *dnsProvider) Addresses() []string { + d.Lock() + defer d.Unlock() + return d.addr +} diff --git a/kv/memberlist/dnsprovider_test.go b/kv/memberlist/dnsprovider_test.go new file mode 100644 index 000000000..30d08699d --- /dev/null +++ b/kv/memberlist/dnsprovider_test.go @@ -0,0 +1,28 @@ +package memberlist + +import ( + "context" + "testing" +) + +func TestDNSProvider(t *testing.T) { + dns := &dnsProvider{} + if err := dns.Resolve(context.Background(), []string{"localhost"}); err != nil { + t.Fatal(err) + } + has127_0_0_1 := false + for _, addr := range dns.Addresses() { + if addr == "127.0.0.1" { + has127_0_0_1 = true + } + } + if !has127_0_0_1 { + t.Error("resolving localhost must result in 127.0.0.1 address", dns.Addresses()) + } + if err := dns.Resolve(context.Background(), []string{"invalid dns"}); err == nil { + t.Error("resolving and invalid address must result in an error") + } + if len(dns.Addresses()) == 0 { + t.Error("DNSProvider must keep recent addresses on failure") + } +}