From e6b54f758847b5b1a02c75ff431e42a0304ad588 Mon Sep 17 00:00:00 2001 From: Yevhen Vydolob Date: Tue, 26 Mar 2024 15:49:23 +0200 Subject: [PATCH] Split getting DNS configuration. For UNIX systems read and parse `/etc/resolv.conf` file. On Windows use 'GetAdaptersAddresses' syscall to get DNS configuration Signed-off-by: Yevhen Vydolob --- pkg/services/dns/dns.go | 19 +- pkg/services/dns/dns_config_unix.go | 28 +++ pkg/services/dns/dns_config_windows.go | 296 +++++++++++++++++++++++++ pkg/services/dns/dns_test.go | 25 +++ 4 files changed, 355 insertions(+), 13 deletions(-) create mode 100644 pkg/services/dns/dns_config_unix.go create mode 100644 pkg/services/dns/dns_config_windows.go diff --git a/pkg/services/dns/dns.go b/pkg/services/dns/dns.go index bda5711b9..b18a835aa 100644 --- a/pkg/services/dns/dns.go +++ b/pkg/services/dns/dns.go @@ -22,7 +22,7 @@ type dnsHandler struct { nameserver string } -func newDnsHandler(zones []types.Zone) *dnsHandler { +func newDNSHandler(zones []types.Zone) *dnsHandler { dnsClient, nameserver := readAndCreateClient() @@ -35,23 +35,16 @@ func newDnsHandler(zones []types.Zone) *dnsHandler { } func readAndCreateClient() (*dns.Client, string) { - conf, err := dns.ClientConfigFromFile("/etc/resolv.conf") + + nameserver, port, err := GetDNSHostAndPort() if err != nil { - fmt.Fprintln(os.Stderr, err) os.Exit(2) } - nameserver := conf.Servers[0] - // if the nameserver is from /etc/resolv.conf the [ and ] are already - // added, thereby breaking net.ParseIP. Check for this and don't - // fully qualify such a name - if nameserver[0] == '[' && nameserver[len(nameserver)-1] == ']' { - nameserver = nameserver[1 : len(nameserver)-1] - } if i := net.ParseIP(nameserver); i != nil { - nameserver = net.JoinHostPort(nameserver, conf.Port) + nameserver = net.JoinHostPort(nameserver, port) } else { - nameserver = dns.Fqdn(nameserver) + ":" + conf.Port + nameserver = dns.Fqdn(nameserver) + ":" + port } client := new(dns.Client) client.Net = "udp" @@ -163,7 +156,7 @@ type Server struct { } func New(udpConn net.PacketConn, tcpLn net.Listener, zones []types.Zone) (*Server, error) { - handler := newDnsHandler(zones) + handler := newDNSHandler(zones) return &Server{udpConn: udpConn, tcpLn: tcpLn, handler: handler}, nil } diff --git a/pkg/services/dns/dns_config_unix.go b/pkg/services/dns/dns_config_unix.go new file mode 100644 index 000000000..bd91a776d --- /dev/null +++ b/pkg/services/dns/dns_config_unix.go @@ -0,0 +1,28 @@ +//go:build !windows + +package dns + +import ( + "fmt" + "os" + + "github.com/miekg/dns" +) + +func GetDNSHostAndPort() (string, string, error) { + conf, err := dns.ClientConfigFromFile("/etc/resolv.conf") + if err != nil { + fmt.Fprintln(os.Stderr, err) + return "", "", err + } + nameserver := conf.Servers[0] + + // if the nameserver is from /etc/resolv.conf the [ and ] are already + // added, thereby breaking net.ParseIP. Check for this and don't + // fully qualify such a name + if nameserver[0] == '[' && nameserver[len(nameserver)-1] == ']' { + nameserver = nameserver[1 : len(nameserver)-1] + } + + return nameserver, conf.Port, nil +} diff --git a/pkg/services/dns/dns_config_windows.go b/pkg/services/dns/dns_config_windows.go new file mode 100644 index 000000000..3890f789e --- /dev/null +++ b/pkg/services/dns/dns_config_windows.go @@ -0,0 +1,296 @@ +//go:build windows + +package dns + +import ( + "errors" + "fmt" + "net/netip" + "strconv" + "syscall" + "unsafe" +) + +func GetDNSHostAndPort() (string, string, error) { + nameservers := getDNSServers() + + var nameserver netip.AddrPort + for _, n := range nameservers { + // return first non ipv6 nameserver + if n.Addr().Is4() { + nameserver = n + break + } + } + + return nameserver.Addr().String(), strconv.Itoa(int(nameserver.Port())), nil + +} + +// copied from https://github.com/qdm12/dns/blob/v2.0.0-beta/pkg/nameserver/getlocal_windows.go +// this function will use windows syscall to get DNS configuration +func getDNSServers() (nameservers []netip.AddrPort) { + const defaultDNSPort = 53 + defaultLocalNameservers := []netip.AddrPort{ + netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), defaultDNSPort), + netip.AddrPortFrom(netip.AddrFrom16([16]byte{0, 0, 0, 0, 0, 0, 0, 1}), defaultDNSPort), + } + + adapterAddresses, err := getAdapterAddresses() + if err != nil { + return defaultLocalNameservers + } + + for _, adapterAddress := range adapterAddresses { + const statusUp = 0x01 + if adapterAddress.operStatus != statusUp { + continue + } + + if adapterAddress.firstGatewayAddress == nil { + // Only search DNS servers for adapters having a gateway + continue + } + + dnsServerAddress := adapterAddress.firstDnsServerAddress + for dnsServerAddress != nil { + ip, ok := sockAddressToIP(dnsServerAddress.address.rawSockAddrAny) + if !ok || ipIsSiteLocalAnycast(ip) { + // fec0/10 IPv6 addresses are site local anycast DNS + // addresses Microsoft sets by default if no other + // IPv6 DNS address is set. Site local anycast is + // deprecated since 2004, see + // https://datatracker.ietf.org/doc/html/rfc3879 + dnsServerAddress = dnsServerAddress.next + continue + } + + nameserver := netip.AddrPortFrom(ip, defaultDNSPort) + nameservers = append(nameservers, nameserver) + dnsServerAddress = dnsServerAddress.next + } + } + + if len(nameservers) == 0 { + return defaultLocalNameservers + } + return nameservers +} + +var ( + errBufferOverflowUnexpected = errors.New("unexpected buffer overflowed because buffer was large enough") +) + +func getAdapterAddresses() ( + adapterAddresses []*ipAdapterAddresses, err error) { + var buffer []byte + const initialBufferLength uint32 = 15000 + sizeVar := initialBufferLength + + for { + buffer = make([]byte, sizeVar) + err := runProcGetAdaptersAddresses( + (*ipAdapterAddresses)(unsafe.Pointer(&buffer[0])), + &sizeVar) + if err != nil { + if err.(syscall.Errno) == syscall.ERROR_BUFFER_OVERFLOW { + if sizeVar <= uint32(len(buffer)) { + return nil, fmt.Errorf("%w: buffer size variable %d is "+ + "equal or lower to the buffer current length %d", + errBufferOverflowUnexpected, sizeVar, len(buffer)) + } + continue + } + return nil, fmt.Errorf("getting adapters addresses: %w", err) + } + + noDataFound := sizeVar == 0 + if noDataFound { + return nil, nil + } + break + } + + adapterAddress := (*ipAdapterAddresses)(unsafe.Pointer(&buffer[0])) + for adapterAddress != nil { + adapterAddresses = append(adapterAddresses, adapterAddress) + adapterAddress = adapterAddress.next + } + + return adapterAddresses, nil +} + +var ( + procGetAdaptersAddresses = syscall.NewLazyDLL("iphlpapi.dll"). + NewProc("GetAdaptersAddresses") +) + +func runProcGetAdaptersAddresses(adapterAddresses *ipAdapterAddresses, + sizePointer *uint32) (errcode error) { + const family = syscall.AF_UNSPEC + const GAA_FLAG_SKIP_UNICAST = 0x0001 + const GAA_FLAG_SKIP_ANYCAST = 0x0002 + const GAA_FLAG_SKIP_MULTICAST = 0x0004 + const GAA_FLAG_SKIP_FRIENDLY_NAME = 0x0020 + const GAA_FLAG_INCLUDE_GATEWAYS = 0x0080 + const flags = GAA_FLAG_SKIP_UNICAST | GAA_FLAG_SKIP_ANYCAST | + GAA_FLAG_SKIP_MULTICAST | GAA_FLAG_SKIP_FRIENDLY_NAME | + GAA_FLAG_INCLUDE_GATEWAYS + const reserved = 0 + // See https://learn.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-getadaptersaddresses + r1, _, err := syscall.SyscallN(procGetAdaptersAddresses.Addr(), + uintptr(family), uintptr(flags), uintptr(reserved), + uintptr(unsafe.Pointer(adapterAddresses)), + uintptr(unsafe.Pointer(sizePointer))) + switch { + case err != 0: + return err + case r1 != 0: + return syscall.Errno(r1) + default: + return nil + } +} + +func sockAddressToIP(rawSockAddress *syscall.RawSockaddrAny) (ip netip.Addr, ok bool) { + if rawSockAddress == nil { + return netip.Addr{}, false + } + + sockAddress, err := rawSockAddress.Sockaddr() + if err != nil { + return netip.Addr{}, false + } + + switch sockAddress := sockAddress.(type) { + case *syscall.SockaddrInet4: + return netip.AddrFrom4([4]byte{ + sockAddress.Addr[0], sockAddress.Addr[1], sockAddress.Addr[2], sockAddress.Addr[3]}), + true + case *syscall.SockaddrInet6: + return netip.AddrFrom16([16]byte{ + sockAddress.Addr[0], sockAddress.Addr[1], sockAddress.Addr[2], sockAddress.Addr[3], + sockAddress.Addr[4], sockAddress.Addr[5], sockAddress.Addr[6], sockAddress.Addr[7], + sockAddress.Addr[8], sockAddress.Addr[9], sockAddress.Addr[10], sockAddress.Addr[11], + sockAddress.Addr[12], sockAddress.Addr[13], sockAddress.Addr[14], sockAddress.Addr[15]}), + true + default: + return netip.Addr{}, false + } +} + +func ipIsSiteLocalAnycast(ip netip.Addr) bool { + if !ip.Is6() { + return false + } + + array := ip.As16() + return array[0] == 0xfe && array[1] == 0xc0 +} + +// See https://learn.microsoft.com/en-us/windows/win32/api/iptypes/ns-iptypes-ip_adapter_addresses_lh +type ipAdapterAddresses struct { + // The order of fields DOES matter since they are read + // raw from a bytes buffer. However, we are only interested + // in a few select fields, so unneeded fields are either + // named as "_" or removed if they are after the fields + // we are interested in. + _ uint32 + _ uint32 + next *ipAdapterAddresses + _ *byte + _ *ipAdapterUnicastAddress + _ *ipAdapterAnycastAddress + _ *ipAdapterMulticastAddress + firstDnsServerAddress *ipAdapterDnsServerAdapter + _ *uint16 + _ *uint16 + _ *uint16 + _ [syscall.MAX_ADAPTER_ADDRESS_LENGTH]byte + _ uint32 + _ uint32 + _ uint32 + _ uint32 + operStatus uint32 + _ uint32 + _ [16]uint32 + _ *ipAdapterPrefix + _ uint64 + _ uint64 + _ *ipAdapterWinsServerAddress + firstGatewayAddress *ipAdapterGatewayAddress + // Additional fields not needed here +} + +type ipAdapterUnicastAddress struct { + // The order of fields DOES matter since they are read raw + // from a bytes buffer. However, we are not interested in + // the value of any field, so they are all named as "_". + _ uint32 + _ uint32 + _ *ipAdapterUnicastAddress + _ ipAdapterSocketAddress + _ int32 + _ int32 + _ int32 + _ uint32 + _ uint32 + _ uint32 + _ uint8 +} + +type ipAdapterAnycastAddress struct { + // The order of fields DOES matter since they are read raw + // from a bytes buffer. However, we are not interested in + // the value of any field, so they are all named as "_". + _ uint32 + _ uint32 + _ *ipAdapterAnycastAddress + _ ipAdapterSocketAddress +} + +type ipAdapterMulticastAddress struct { + // The order of fields DOES matter since they are read raw + // from a bytes buffer. However, we are only interested in + // a few select fields, so unneeded fields are named as "_". + _ uint32 + _ uint32 + _ *ipAdapterMulticastAddress + _ ipAdapterSocketAddress +} + +type ipAdapterDnsServerAdapter struct { + // The order of fields DOES matter since they are read raw + // from a bytes buffer. However, we are only interested in + // a few select fields, so unneeded fields are named as "_". + _ uint32 + _ uint32 + next *ipAdapterDnsServerAdapter + address ipAdapterSocketAddress +} + +type ipAdapterPrefix struct { + _ uint32 + _ uint32 + _ *ipAdapterPrefix + _ ipAdapterSocketAddress + _ uint32 +} + +type ipAdapterWinsServerAddress struct { + _ uint32 + _ uint32 + _ *ipAdapterWinsServerAddress + _ ipAdapterSocketAddress +} + +type ipAdapterGatewayAddress struct { + _ uint32 + _ uint32 + _ *ipAdapterGatewayAddress + _ ipAdapterSocketAddress +} + +type ipAdapterSocketAddress struct { + rawSockAddrAny *syscall.RawSockaddrAny +} diff --git a/pkg/services/dns/dns_test.go b/pkg/services/dns/dns_test.go index f01488d23..bfae64709 100644 --- a/pkg/services/dns/dns_test.go +++ b/pkg/services/dns/dns_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/containers/gvisor-tap-vsock/pkg/types" + "github.com/miekg/dns" "github.com/onsi/ginkgo" "github.com/onsi/gomega" ) @@ -191,4 +192,28 @@ var _ = ginkgo.Describe("dns add test", func() { }, })) }) + + ginkgo.It("Should pass DNS requests to default system DNS server", func() { + m := &dns.Msg{ + MsgHdr: dns.MsgHdr{ + Authoritative: false, + AuthenticatedData: false, + CheckingDisabled: false, + RecursionDesired: true, + Opcode: 0, + }, + Question: make([]dns.Question, 1), + } + + m.Question[0] = dns.Question{ + Name: "redhat.com.", + Qtype: 1, + Qclass: 1, + } + + server.handler.addAnswers(m) + + gomega.Expect(m.Answer[0].Header().Name).To(gomega.Equal("redhat.com.")) + gomega.Expect(m.Answer[0].String()).To(gomega.SatisfyAny(gomega.ContainSubstring("34.235.198.240"), gomega.ContainSubstring("52.200.142.250"))) + }) })