Skip to content

Commit

Permalink
Replace multiple "resolver.*" fn cals with single "dns.Exchange()" fn.
Browse files Browse the repository at this point in the history
This highly simplify resolving DNS code.
Also, DNS will work only for IPv4
Signed-off-by: Yevhen Vydolob <[email protected]>
Co-authored-by: Christophe Fergeau <[email protected]>
  • Loading branch information
evidolob committed Aug 5, 2024
1 parent 95677b9 commit f2ae1d9
Show file tree
Hide file tree
Showing 5 changed files with 240 additions and 140 deletions.
225 changes: 86 additions & 139 deletions pkg/services/dns/dns.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package dns

import (
"context"
"encoding/json"
"fmt"
"net"
Expand All @@ -15,15 +14,43 @@ import (
)

type dnsHandler struct {
zones []types.Zone
zonesLock sync.RWMutex
zones []types.Zone
zonesLock sync.RWMutex
dnsClient *dns.Client
nameserver string
}

func newDNSHandler(zones []types.Zone) (*dnsHandler, error) {

dnsClient, nameserver, err := readAndCreateClient()
if err != nil {
return nil, err
}

return &dnsHandler{
zones: zones,
dnsClient: dnsClient,
nameserver: nameserver,
}, nil

}

func readAndCreateClient() (*dns.Client, string, error) {

nameserver, port, err := GetDNSHostAndPort()
if err != nil {
return nil, "", err
}

nameserver = net.JoinHostPort(nameserver, port)

client := new(dns.Client)

return client, nameserver, nil
}

func (h *dnsHandler) handle(w dns.ResponseWriter, r *dns.Msg, responseMessageSize int) {
m := new(dns.Msg)
m.SetReply(r)
m.RecursionAvailable = true
h.addAnswers(m)
m := h.addAnswers(r)
edns0 := r.IsEdns0()
if edns0 != nil {
responseMessageSize = int(edns0.UDPSize())
Expand All @@ -35,167 +62,84 @@ func (h *dnsHandler) handle(w dns.ResponseWriter, r *dns.Msg, responseMessageSiz
}

func (h *dnsHandler) handleTCP(w dns.ResponseWriter, r *dns.Msg) {
// needs to be handled in a better way, handleTCP/handleUDP can run concurrently so this change is racy
// h.dnsClient.Net = "tcp"
h.handle(w, r, dns.MaxMsgSize)
}

func (h *dnsHandler) handleUDP(w dns.ResponseWriter, r *dns.Msg) {
// needs to be handled in a better way, handleTCP/handleUDP can run concurrently so this change is racy
// h.dnsClient.Net = "udp"
h.handle(w, r, dns.MinMsgSize)
}

func (h *dnsHandler) addAnswers(m *dns.Msg) {
func (h *dnsHandler) addLocalAnswers(m *dns.Msg, q dns.Question) bool {
h.zonesLock.RLock()
defer h.zonesLock.RUnlock()
for _, q := range m.Question {
for _, zone := range h.zones {
zoneSuffix := fmt.Sprintf(".%s", zone.Name)
if strings.HasSuffix(q.Name, zoneSuffix) {
if q.Qtype != dns.TypeA {
return
}
for _, record := range zone.Records {
withoutZone := strings.TrimSuffix(q.Name, zoneSuffix)
if (record.Name != "" && record.Name == withoutZone) ||
(record.Regexp != nil && record.Regexp.MatchString(withoutZone)) {
m.Answer = append(m.Answer, &dns.A{
Hdr: dns.RR_Header{
Name: q.Name,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 0,
},
A: record.IP,
})
return
}
}
if !zone.DefaultIP.Equal(net.IP("")) {

for _, zone := range h.zones {
zoneSuffix := fmt.Sprintf(".%s", zone.Name)
if strings.HasSuffix(q.Name, zoneSuffix) {
if q.Qtype != dns.TypeA {
return false
}
for _, record := range zone.Records {
withoutZone := strings.TrimSuffix(q.Name, zoneSuffix)
if (record.Name != "" && record.Name == withoutZone) ||
(record.Regexp != nil && record.Regexp.MatchString(withoutZone)) {
m.Answer = append(m.Answer, &dns.A{
Hdr: dns.RR_Header{
Name: q.Name,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 0,
},
A: zone.DefaultIP,
A: record.IP,
})
return
return true
}
m.Rcode = dns.RcodeNameError
return
}
}

resolver := net.Resolver{
PreferGo: false,
}
switch q.Qtype {
case dns.TypeA:
ips, err := resolver.LookupIPAddr(context.TODO(), q.Name)
if err != nil {
m.Rcode = dns.RcodeNameError
return
}
for _, ip := range ips {
if len(ip.IP.To4()) != net.IPv4len {
continue
}
if !zone.DefaultIP.Equal(net.IP("")) {
m.Answer = append(m.Answer, &dns.A{
Hdr: dns.RR_Header{
Name: q.Name,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 0,
},
A: ip.IP.To4(),
A: zone.DefaultIP,
})
return true
}
case dns.TypeCNAME:
cname, err := resolver.LookupCNAME(context.TODO(), q.Name)
if err != nil {
m.Rcode = dns.RcodeNameError
return
}
m.Answer = append(m.Answer, &dns.CNAME{
Hdr: dns.RR_Header{
Name: q.Name,
Rrtype: dns.TypeCNAME,
Class: dns.ClassINET,
Ttl: 0,
},
Target: cname,
})
case dns.TypeMX:
records, err := resolver.LookupMX(context.TODO(), q.Name)
if err != nil {
m.Rcode = dns.RcodeNameError
return
}
for _, mx := range records {
m.Answer = append(m.Answer, &dns.MX{
Hdr: dns.RR_Header{
Name: q.Name,
Rrtype: dns.TypeMX,
Class: dns.ClassINET,
Ttl: 0,
},
Mx: mx.Host,
Preference: mx.Pref,
})
}
case dns.TypeNS:
records, err := resolver.LookupNS(context.TODO(), q.Name)
if err != nil {
m.Rcode = dns.RcodeNameError
return
}
for _, ns := range records {
m.Answer = append(m.Answer, &dns.NS{
Hdr: dns.RR_Header{
Name: q.Name,
Rrtype: dns.TypeNS,
Class: dns.ClassINET,
Ttl: 0,
},
Ns: ns.Host,
})
}
case dns.TypeSRV:
_, records, err := resolver.LookupSRV(context.TODO(), "", "", q.Name)
if err != nil {
m.Rcode = dns.RcodeNameError
return
}
for _, srv := range records {
m.Answer = append(m.Answer, &dns.SRV{
Hdr: dns.RR_Header{
Name: q.Name,
Rrtype: dns.TypeSRV,
Class: dns.ClassINET,
Ttl: 0,
},
Port: srv.Port,
Priority: srv.Priority,
Target: srv.Target,
Weight: srv.Weight,
})
}
case dns.TypeTXT:
records, err := resolver.LookupTXT(context.TODO(), q.Name)
if err != nil {
m.Rcode = dns.RcodeNameError
return
}
m.Answer = append(m.Answer, &dns.TXT{
Hdr: dns.RR_Header{
Name: q.Name,
Rrtype: dns.TypeTXT,
Class: dns.ClassINET,
Ttl: 0,
},
Txt: records,
})
m.Rcode = dns.RcodeNameError
return true
}
}
return false
}

func (h *dnsHandler) addAnswers(r *dns.Msg) *dns.Msg {
m := new(dns.Msg)
m.SetReply(r)
m.RecursionAvailable = true
for _, q := range m.Question {
if done := h.addLocalAnswers(m, q); done {
return m

// ignore IPv6 request, we support only IPv4 requests for now
} else if q.Qtype == dns.TypeAAAA {
return m
}
}

r, _, err := h.dnsClient.Exchange(r, h.nameserver)
if err != nil {
log.Errorf("Error during DNS Exchange: %s", err)
m.Rcode = dns.RcodeNameError
return m
}

return r
}

type Server struct {
Expand All @@ -205,7 +149,10 @@ type Server struct {
}

func New(udpConn net.PacketConn, tcpLn net.Listener, zones []types.Zone) (*Server, error) {
handler := &dnsHandler{zones: zones}
handler, err := newDNSHandler(zones)
if err != nil {
return nil, err
}
return &Server{udpConn: udpConn, tcpLn: tcpLn, handler: handler}, nil
}

Expand Down
22 changes: 22 additions & 0 deletions pkg/services/dns/dns_config_unix.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
//go:build !windows

package dns

import (
"fmt"
"os"

"github.com/miekg/dns"
)

func GetDNSHostAndPort() (string, string, error) {
conf, err := dns.ClientConfigFromFile("/etc/resolv.conf")
if err != nil {
fmt.Fprintln(os.Stderr, err)
return "", "", err
}
// TODO: use all configured nameservers, instead just first one
nameserver := conf.Servers[0]

return nameserver, conf.Port, nil
}
26 changes: 26 additions & 0 deletions pkg/services/dns/dns_config_windows.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
//go:build windows

package dns

import (
"net/netip"
"strconv"

qdmDns "github.com/qdm12/dns/v2/pkg/nameserver"
)

func GetDNSHostAndPort() (string, string, error) {
nameservers := qdmDns.GetDNSServers()

var nameserver netip.AddrPort
for _, n := range nameservers {
// return first non ipv6 nameserver
if n.Addr().Is4() {
nameserver = n
break
}
}

return nameserver.Addr().String(), strconv.Itoa(int(nameserver.Port())), nil

}
Loading

0 comments on commit f2ae1d9

Please sign in to comment.