Skip to content

Commit

Permalink
Replace golang net.Resolver with dns.Client implementation.
Browse files Browse the repository at this point in the history
This is simplified implementation if resolving dns requests

Signed-off-by: Yevhen Vydolob <[email protected]>
  • Loading branch information
evidolob committed Mar 19, 2024
1 parent 80594f5 commit 0416600
Showing 1 changed file with 70 additions and 112 deletions.
182 changes: 70 additions & 112 deletions pkg/services/dns/dns.go
Original file line number Diff line number Diff line change
@@ -1,22 +1,66 @@
package dns

import (
"context"
"encoding/json"
"fmt"
"net"
"net/http"
"os"
"strings"
"sync"
"time"

"github.com/containers/gvisor-tap-vsock/pkg/types"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
)

type dnsHandler struct {
zones []types.Zone
zonesLock sync.RWMutex
zones []types.Zone
zonesLock sync.RWMutex
dnsClient *dns.Client
nameserver string
}

func newDnsHandler(zones []types.Zone) *dnsHandler {

dnsClient, nameserver := readAndCreateClient()

return &dnsHandler{
zones: zones,
dnsClient: dnsClient,
nameserver: nameserver,
}

}

func readAndCreateClient() (*dns.Client, string) {
conf, err := dns.ClientConfigFromFile("/etc/resolv.conf")
if err != nil {
fmt.Fprintln(os.Stderr, err)
os.Exit(2)
}
nameserver := conf.Servers[0]

// if the nameserver is from /etc/resolv.conf the [ and ] are already
// added, thereby breaking net.ParseIP. Check for this and don't
// fully qualify such a name
if nameserver[0] == '[' && nameserver[len(nameserver)-1] == ']' {
nameserver = nameserver[1 : len(nameserver)-1]
}
if i := net.ParseIP(nameserver); i != nil {
nameserver = net.JoinHostPort(nameserver, conf.Port)
} else {
nameserver = dns.Fqdn(nameserver) + ":" + conf.Port
}
client := new(dns.Client)
client.Net = "udp"

client.DialTimeout = 2 * time.Second
client.ReadTimeout = 2 * time.Second
client.WriteTimeout = 2 * time.Second

return client, nameserver
}

func (h *dnsHandler) handle(w dns.ResponseWriter, r *dns.Msg, responseMessageSize int) {
Expand Down Expand Up @@ -85,116 +129,30 @@ func (h *dnsHandler) addAnswers(m *dns.Msg) {
}
}

resolver := net.Resolver{
PreferGo: false,
// need to create new message struct, as reusing original message struct leading
// to request errors
message := &dns.Msg{
MsgHdr: dns.MsgHdr{
Authoritative: m.Authoritative,
AuthenticatedData: m.AuthenticatedData,
CheckingDisabled: m.CheckingDisabled,
RecursionDesired: m.RecursionDesired,
Opcode: m.Opcode,
},
Question: make([]dns.Question, 1),
}
switch q.Qtype {
case dns.TypeA:
ips, err := resolver.LookupIPAddr(context.TODO(), q.Name)
if err != nil {
m.Rcode = dns.RcodeNameError
return
}
for _, ip := range ips {
if len(ip.IP.To4()) != net.IPv4len {
continue
}
m.Answer = append(m.Answer, &dns.A{
Hdr: dns.RR_Header{
Name: q.Name,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 0,
},
A: ip.IP.To4(),
})
}
case dns.TypeCNAME:
cname, err := resolver.LookupCNAME(context.TODO(), q.Name)
if err != nil {
m.Rcode = dns.RcodeNameError
return
}
m.Answer = append(m.Answer, &dns.CNAME{
Hdr: dns.RR_Header{
Name: q.Name,
Rrtype: dns.TypeCNAME,
Class: dns.ClassINET,
Ttl: 0,
},
Target: cname,
})
case dns.TypeMX:
records, err := resolver.LookupMX(context.TODO(), q.Name)
if err != nil {
m.Rcode = dns.RcodeNameError
return
}
for _, mx := range records {
m.Answer = append(m.Answer, &dns.MX{
Hdr: dns.RR_Header{
Name: q.Name,
Rrtype: dns.TypeMX,
Class: dns.ClassINET,
Ttl: 0,
},
Mx: mx.Host,
Preference: mx.Pref,
})
}
case dns.TypeNS:
records, err := resolver.LookupNS(context.TODO(), q.Name)
if err != nil {
m.Rcode = dns.RcodeNameError
return
}
for _, ns := range records {
m.Answer = append(m.Answer, &dns.NS{
Hdr: dns.RR_Header{
Name: q.Name,
Rrtype: dns.TypeNS,
Class: dns.ClassINET,
Ttl: 0,
},
Ns: ns.Host,
})
}
case dns.TypeSRV:
_, records, err := resolver.LookupSRV(context.TODO(), "", "", q.Name)
if err != nil {
m.Rcode = dns.RcodeNameError
return
}
for _, srv := range records {
m.Answer = append(m.Answer, &dns.SRV{
Hdr: dns.RR_Header{
Name: q.Name,
Rrtype: dns.TypeSRV,
Class: dns.ClassINET,
Ttl: 0,
},
Port: srv.Port,
Priority: srv.Priority,
Target: srv.Target,
Weight: srv.Weight,
})
}
case dns.TypeTXT:
records, err := resolver.LookupTXT(context.TODO(), q.Name)
if err != nil {
m.Rcode = dns.RcodeNameError
return
}
m.Answer = append(m.Answer, &dns.TXT{
Hdr: dns.RR_Header{
Name: q.Name,
Rrtype: dns.TypeTXT,
Class: dns.ClassINET,
Ttl: 0,
},
Txt: records,
})
message.Question[0] = q
message.Id = dns.Id()

r, _, err := h.dnsClient.Exchange(message, h.nameserver)

if err != nil {
m.Rcode = dns.RcodeNameError
fmt.Fprintf(os.Stderr, "Error: %v \n", err)
return
}

m.Answer = append(m.Answer, r.Answer...)
}
}

Expand All @@ -205,7 +163,7 @@ type Server struct {
}

func New(udpConn net.PacketConn, tcpLn net.Listener, zones []types.Zone) (*Server, error) {
handler := &dnsHandler{zones: zones}
handler := newDnsHandler(zones)
return &Server{udpConn: udpConn, tcpLn: tcpLn, handler: handler}, nil
}

Expand Down

0 comments on commit 0416600

Please sign in to comment.