diff --git a/pkg/services/dns/dns.go b/pkg/services/dns/dns.go index 1cfe0b6cb..bda5711b9 100644 --- a/pkg/services/dns/dns.go +++ b/pkg/services/dns/dns.go @@ -1,13 +1,14 @@ package dns import ( - "context" "encoding/json" "fmt" "net" "net/http" + "os" "strings" "sync" + "time" "github.com/containers/gvisor-tap-vsock/pkg/types" "github.com/miekg/dns" @@ -15,8 +16,51 @@ import ( ) type dnsHandler struct { - zones []types.Zone - zonesLock sync.RWMutex + zones []types.Zone + zonesLock sync.RWMutex + dnsClient *dns.Client + nameserver string +} + +func newDnsHandler(zones []types.Zone) *dnsHandler { + + dnsClient, nameserver := readAndCreateClient() + + return &dnsHandler{ + zones: zones, + dnsClient: dnsClient, + nameserver: nameserver, + } + +} + +func readAndCreateClient() (*dns.Client, string) { + conf, err := dns.ClientConfigFromFile("/etc/resolv.conf") + if err != nil { + fmt.Fprintln(os.Stderr, err) + os.Exit(2) + } + nameserver := conf.Servers[0] + + // if the nameserver is from /etc/resolv.conf the [ and ] are already + // added, thereby breaking net.ParseIP. Check for this and don't + // fully qualify such a name + if nameserver[0] == '[' && nameserver[len(nameserver)-1] == ']' { + nameserver = nameserver[1 : len(nameserver)-1] + } + if i := net.ParseIP(nameserver); i != nil { + nameserver = net.JoinHostPort(nameserver, conf.Port) + } else { + nameserver = dns.Fqdn(nameserver) + ":" + conf.Port + } + client := new(dns.Client) + client.Net = "udp" + + client.DialTimeout = 2 * time.Second + client.ReadTimeout = 2 * time.Second + client.WriteTimeout = 2 * time.Second + + return client, nameserver } func (h *dnsHandler) handle(w dns.ResponseWriter, r *dns.Msg, responseMessageSize int) { @@ -85,116 +129,30 @@ func (h *dnsHandler) addAnswers(m *dns.Msg) { } } - resolver := net.Resolver{ - PreferGo: false, + // need to create new message struct, as reusing original message struct leading + // to request errors + message := &dns.Msg{ + MsgHdr: dns.MsgHdr{ + Authoritative: m.Authoritative, + AuthenticatedData: m.AuthenticatedData, + CheckingDisabled: m.CheckingDisabled, + RecursionDesired: m.RecursionDesired, + Opcode: m.Opcode, + }, + Question: make([]dns.Question, 1), } - switch q.Qtype { - case dns.TypeA: - ips, err := resolver.LookupIPAddr(context.TODO(), q.Name) - if err != nil { - m.Rcode = dns.RcodeNameError - return - } - for _, ip := range ips { - if len(ip.IP.To4()) != net.IPv4len { - continue - } - m.Answer = append(m.Answer, &dns.A{ - Hdr: dns.RR_Header{ - Name: q.Name, - Rrtype: dns.TypeA, - Class: dns.ClassINET, - Ttl: 0, - }, - A: ip.IP.To4(), - }) - } - case dns.TypeCNAME: - cname, err := resolver.LookupCNAME(context.TODO(), q.Name) - if err != nil { - m.Rcode = dns.RcodeNameError - return - } - m.Answer = append(m.Answer, &dns.CNAME{ - Hdr: dns.RR_Header{ - Name: q.Name, - Rrtype: dns.TypeCNAME, - Class: dns.ClassINET, - Ttl: 0, - }, - Target: cname, - }) - case dns.TypeMX: - records, err := resolver.LookupMX(context.TODO(), q.Name) - if err != nil { - m.Rcode = dns.RcodeNameError - return - } - for _, mx := range records { - m.Answer = append(m.Answer, &dns.MX{ - Hdr: dns.RR_Header{ - Name: q.Name, - Rrtype: dns.TypeMX, - Class: dns.ClassINET, - Ttl: 0, - }, - Mx: mx.Host, - Preference: mx.Pref, - }) - } - case dns.TypeNS: - records, err := resolver.LookupNS(context.TODO(), q.Name) - if err != nil { - m.Rcode = dns.RcodeNameError - return - } - for _, ns := range records { - m.Answer = append(m.Answer, &dns.NS{ - Hdr: dns.RR_Header{ - Name: q.Name, - Rrtype: dns.TypeNS, - Class: dns.ClassINET, - Ttl: 0, - }, - Ns: ns.Host, - }) - } - case dns.TypeSRV: - _, records, err := resolver.LookupSRV(context.TODO(), "", "", q.Name) - if err != nil { - m.Rcode = dns.RcodeNameError - return - } - for _, srv := range records { - m.Answer = append(m.Answer, &dns.SRV{ - Hdr: dns.RR_Header{ - Name: q.Name, - Rrtype: dns.TypeSRV, - Class: dns.ClassINET, - Ttl: 0, - }, - Port: srv.Port, - Priority: srv.Priority, - Target: srv.Target, - Weight: srv.Weight, - }) - } - case dns.TypeTXT: - records, err := resolver.LookupTXT(context.TODO(), q.Name) - if err != nil { - m.Rcode = dns.RcodeNameError - return - } - m.Answer = append(m.Answer, &dns.TXT{ - Hdr: dns.RR_Header{ - Name: q.Name, - Rrtype: dns.TypeTXT, - Class: dns.ClassINET, - Ttl: 0, - }, - Txt: records, - }) + message.Question[0] = q + message.Id = dns.Id() + + r, _, err := h.dnsClient.Exchange(message, h.nameserver) + + if err != nil { + m.Rcode = dns.RcodeNameError + fmt.Fprintf(os.Stderr, "Error: %v \n", err) + return } + + m.Answer = append(m.Answer, r.Answer...) } } @@ -205,7 +163,7 @@ type Server struct { } func New(udpConn net.PacketConn, tcpLn net.Listener, zones []types.Zone) (*Server, error) { - handler := &dnsHandler{zones: zones} + handler := newDnsHandler(zones) return &Server{udpConn: udpConn, tcpLn: tcpLn, handler: handler}, nil }