From 43ef64cf673fc785a2005f0c7b3f5616211f4065 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Tue, 31 Dec 2024 14:07:21 +0100 Subject: [PATCH] [client] Ignore case when matching domains in handler chain (#3133) --- client/internal/dns/handler_chain.go | 21 ++- client/internal/dns/handler_chain_test.go | 168 ++++++++++++++++++++++ 2 files changed, 178 insertions(+), 11 deletions(-) diff --git a/client/internal/dns/handler_chain.go b/client/internal/dns/handler_chain.go index 9302d50b171..5f63d1ab3f8 100644 --- a/client/internal/dns/handler_chain.go +++ b/client/internal/dns/handler_chain.go @@ -68,17 +68,16 @@ func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority c.mu.Lock() defer c.mu.Unlock() + pattern = strings.ToLower(dns.Fqdn(pattern)) origPattern := pattern isWildcard := strings.HasPrefix(pattern, "*.") if isWildcard { pattern = pattern[2:] } - pattern = dns.Fqdn(pattern) - origPattern = dns.Fqdn(origPattern) - // First remove any existing handler with same original pattern and priority + // First remove any existing handler with same pattern (case-insensitive) and priority for i := len(c.handlers) - 1; i >= 0; i-- { - if c.handlers[i].OrigPattern == origPattern && c.handlers[i].Priority == priority { + if strings.EqualFold(c.handlers[i].OrigPattern, origPattern) && c.handlers[i].Priority == priority { if c.handlers[i].StopHandler != nil { c.handlers[i].StopHandler.stop() } @@ -126,10 +125,10 @@ func (c *HandlerChain) RemoveHandler(pattern string, priority int) { pattern = dns.Fqdn(pattern) - // Find and remove handlers matching both original pattern and priority + // Find and remove handlers matching both original pattern (case-insensitive) and priority for i := len(c.handlers) - 1; i >= 0; i-- { entry := c.handlers[i] - if entry.OrigPattern == pattern && entry.Priority == priority { + if strings.EqualFold(entry.OrigPattern, pattern) && entry.Priority == priority { if entry.StopHandler != nil { entry.StopHandler.stop() } @@ -144,9 +143,9 @@ func (c *HandlerChain) HasHandlers(pattern string) bool { c.mu.RLock() defer c.mu.RUnlock() - pattern = dns.Fqdn(pattern) + pattern = strings.ToLower(dns.Fqdn(pattern)) for _, entry := range c.handlers { - if entry.Pattern == pattern { + if strings.EqualFold(entry.Pattern, pattern) { return true } } @@ -158,7 +157,7 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { return } - qname := r.Question[0].Name + qname := strings.ToLower(r.Question[0].Name) log.Tracef("handling DNS request for domain=%s", qname) c.mu.RLock() @@ -187,9 +186,9 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { // If handler wants subdomain matching, allow suffix match // Otherwise require exact match if entry.MatchSubdomains { - matched = qname == entry.Pattern || strings.HasSuffix(qname, "."+entry.Pattern) + matched = strings.EqualFold(qname, entry.Pattern) || strings.HasSuffix(qname, "."+entry.Pattern) } else { - matched = qname == entry.Pattern + matched = strings.EqualFold(qname, entry.Pattern) } } diff --git a/client/internal/dns/handler_chain_test.go b/client/internal/dns/handler_chain_test.go index 727b6e9087d..eb40c907fb9 100644 --- a/client/internal/dns/handler_chain_test.go +++ b/client/internal/dns/handler_chain_test.go @@ -507,5 +507,173 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) { // Test 4: Remove last handler chain.RemoveHandler(testDomain, nbdns.PriorityDefault) + assert.False(t, chain.HasHandlers(testDomain)) } + +func TestHandlerChain_CaseSensitivity(t *testing.T) { + tests := []struct { + name string + scenario string + addHandlers []struct { + pattern string + priority int + subdomains bool + shouldMatch bool + } + query string + expectedCalls int + }{ + { + name: "case insensitive exact match", + scenario: "handler registered lowercase, query uppercase", + addHandlers: []struct { + pattern string + priority int + subdomains bool + shouldMatch bool + }{ + {"example.com.", nbdns.PriorityDefault, false, true}, + }, + query: "EXAMPLE.COM.", + expectedCalls: 1, + }, + { + name: "case insensitive wildcard match", + scenario: "handler registered mixed case wildcard, query different case", + addHandlers: []struct { + pattern string + priority int + subdomains bool + shouldMatch bool + }{ + {"*.Example.Com.", nbdns.PriorityDefault, false, true}, + }, + query: "sub.EXAMPLE.COM.", + expectedCalls: 1, + }, + { + name: "multiple handlers different case same domain", + scenario: "second handler should replace first despite case difference", + addHandlers: []struct { + pattern string + priority int + subdomains bool + shouldMatch bool + }{ + {"EXAMPLE.COM.", nbdns.PriorityDefault, false, false}, + {"example.com.", nbdns.PriorityDefault, false, true}, + }, + query: "ExAmPlE.cOm.", + expectedCalls: 1, + }, + { + name: "subdomain matching case insensitive", + scenario: "handler with MatchSubdomains true should match regardless of case", + addHandlers: []struct { + pattern string + priority int + subdomains bool + shouldMatch bool + }{ + {"example.com.", nbdns.PriorityDefault, true, true}, + }, + query: "SUB.EXAMPLE.COM.", + expectedCalls: 1, + }, + { + name: "root zone case insensitive", + scenario: "root zone handler should match regardless of case", + addHandlers: []struct { + pattern string + priority int + subdomains bool + shouldMatch bool + }{ + {".", nbdns.PriorityDefault, false, true}, + }, + query: "EXAMPLE.COM.", + expectedCalls: 1, + }, + { + name: "multiple handlers different priority", + scenario: "should call higher priority handler despite case differences", + addHandlers: []struct { + pattern string + priority int + subdomains bool + shouldMatch bool + }{ + {"EXAMPLE.COM.", nbdns.PriorityDefault, false, false}, + {"example.com.", nbdns.PriorityMatchDomain, false, false}, + {"Example.Com.", nbdns.PriorityDNSRoute, false, true}, + }, + query: "example.com.", + expectedCalls: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + chain := nbdns.NewHandlerChain() + handlerCalls := make(map[string]bool) // track which patterns were called + + // Add handlers according to test case + for _, h := range tt.addHandlers { + var handler dns.Handler + pattern := h.pattern // capture pattern for closure + + if h.subdomains { + subHandler := &nbdns.MockSubdomainHandler{ + Subdomains: true, + } + if h.shouldMatch { + subHandler.On("ServeDNS", mock.Anything, mock.Anything).Run(func(args mock.Arguments) { + handlerCalls[pattern] = true + w := args.Get(0).(dns.ResponseWriter) + r := args.Get(1).(*dns.Msg) + resp := new(dns.Msg) + resp.SetRcode(r, dns.RcodeSuccess) + assert.NoError(t, w.WriteMsg(resp)) + }).Once() + } + handler = subHandler + } else { + mockHandler := &nbdns.MockHandler{} + if h.shouldMatch { + mockHandler.On("ServeDNS", mock.Anything, mock.Anything).Run(func(args mock.Arguments) { + handlerCalls[pattern] = true + w := args.Get(0).(dns.ResponseWriter) + r := args.Get(1).(*dns.Msg) + resp := new(dns.Msg) + resp.SetRcode(r, dns.RcodeSuccess) + assert.NoError(t, w.WriteMsg(resp)) + }).Once() + } + handler = mockHandler + } + + chain.AddHandler(pattern, handler, h.priority, nil) + } + + // Execute request + r := new(dns.Msg) + r.SetQuestion(tt.query, dns.TypeA) + chain.ServeDNS(&mockResponseWriter{}, r) + + // Verify each handler was called exactly as expected + for _, h := range tt.addHandlers { + wasCalled := handlerCalls[h.pattern] + assert.Equal(t, h.shouldMatch, wasCalled, + "Handler for pattern %q was %s when it should%s have been", + h.pattern, + map[bool]string{true: "called", false: "not called"}[wasCalled], + map[bool]string{true: "", false: " not"}[wasCalled == h.shouldMatch]) + } + + // Verify total number of calls + assert.Equal(t, tt.expectedCalls, len(handlerCalls), + "Wrong number of total handler calls") + }) + } +}