Skip to content

Commit

Permalink
[client] Ignore case when matching domains in handler chain (#3133)
Browse files Browse the repository at this point in the history
  • Loading branch information
lixmal authored Dec 31, 2024
1 parent 18316be commit 43ef64c
Show file tree
Hide file tree
Showing 2 changed files with 178 additions and 11 deletions.
21 changes: 10 additions & 11 deletions client/internal/dns/handler_chain.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,17 +68,16 @@ func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority
c.mu.Lock()
defer c.mu.Unlock()

pattern = strings.ToLower(dns.Fqdn(pattern))
origPattern := pattern
isWildcard := strings.HasPrefix(pattern, "*.")
if isWildcard {
pattern = pattern[2:]
}
pattern = dns.Fqdn(pattern)
origPattern = dns.Fqdn(origPattern)

// First remove any existing handler with same original pattern and priority
// First remove any existing handler with same pattern (case-insensitive) and priority
for i := len(c.handlers) - 1; i >= 0; i-- {
if c.handlers[i].OrigPattern == origPattern && c.handlers[i].Priority == priority {
if strings.EqualFold(c.handlers[i].OrigPattern, origPattern) && c.handlers[i].Priority == priority {
if c.handlers[i].StopHandler != nil {
c.handlers[i].StopHandler.stop()
}
Expand Down Expand Up @@ -126,10 +125,10 @@ func (c *HandlerChain) RemoveHandler(pattern string, priority int) {

pattern = dns.Fqdn(pattern)

// Find and remove handlers matching both original pattern and priority
// Find and remove handlers matching both original pattern (case-insensitive) and priority
for i := len(c.handlers) - 1; i >= 0; i-- {
entry := c.handlers[i]
if entry.OrigPattern == pattern && entry.Priority == priority {
if strings.EqualFold(entry.OrigPattern, pattern) && entry.Priority == priority {
if entry.StopHandler != nil {
entry.StopHandler.stop()
}
Expand All @@ -144,9 +143,9 @@ func (c *HandlerChain) HasHandlers(pattern string) bool {
c.mu.RLock()
defer c.mu.RUnlock()

pattern = dns.Fqdn(pattern)
pattern = strings.ToLower(dns.Fqdn(pattern))
for _, entry := range c.handlers {
if entry.Pattern == pattern {
if strings.EqualFold(entry.Pattern, pattern) {
return true
}
}
Expand All @@ -158,7 +157,7 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
return
}

qname := r.Question[0].Name
qname := strings.ToLower(r.Question[0].Name)
log.Tracef("handling DNS request for domain=%s", qname)

c.mu.RLock()
Expand Down Expand Up @@ -187,9 +186,9 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
// If handler wants subdomain matching, allow suffix match
// Otherwise require exact match
if entry.MatchSubdomains {
matched = qname == entry.Pattern || strings.HasSuffix(qname, "."+entry.Pattern)
matched = strings.EqualFold(qname, entry.Pattern) || strings.HasSuffix(qname, "."+entry.Pattern)
} else {
matched = qname == entry.Pattern
matched = strings.EqualFold(qname, entry.Pattern)
}
}

Expand Down
168 changes: 168 additions & 0 deletions client/internal/dns/handler_chain_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -507,5 +507,173 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) {

// Test 4: Remove last handler
chain.RemoveHandler(testDomain, nbdns.PriorityDefault)

assert.False(t, chain.HasHandlers(testDomain))
}

func TestHandlerChain_CaseSensitivity(t *testing.T) {
tests := []struct {
name string
scenario string
addHandlers []struct {
pattern string
priority int
subdomains bool
shouldMatch bool
}
query string
expectedCalls int
}{
{
name: "case insensitive exact match",
scenario: "handler registered lowercase, query uppercase",
addHandlers: []struct {
pattern string
priority int
subdomains bool
shouldMatch bool
}{
{"example.com.", nbdns.PriorityDefault, false, true},
},
query: "EXAMPLE.COM.",
expectedCalls: 1,
},
{
name: "case insensitive wildcard match",
scenario: "handler registered mixed case wildcard, query different case",
addHandlers: []struct {
pattern string
priority int
subdomains bool
shouldMatch bool
}{
{"*.Example.Com.", nbdns.PriorityDefault, false, true},
},
query: "sub.EXAMPLE.COM.",
expectedCalls: 1,
},
{
name: "multiple handlers different case same domain",
scenario: "second handler should replace first despite case difference",
addHandlers: []struct {
pattern string
priority int
subdomains bool
shouldMatch bool
}{
{"EXAMPLE.COM.", nbdns.PriorityDefault, false, false},
{"example.com.", nbdns.PriorityDefault, false, true},
},
query: "ExAmPlE.cOm.",
expectedCalls: 1,
},
{
name: "subdomain matching case insensitive",
scenario: "handler with MatchSubdomains true should match regardless of case",
addHandlers: []struct {
pattern string
priority int
subdomains bool
shouldMatch bool
}{
{"example.com.", nbdns.PriorityDefault, true, true},
},
query: "SUB.EXAMPLE.COM.",
expectedCalls: 1,
},
{
name: "root zone case insensitive",
scenario: "root zone handler should match regardless of case",
addHandlers: []struct {
pattern string
priority int
subdomains bool
shouldMatch bool
}{
{".", nbdns.PriorityDefault, false, true},
},
query: "EXAMPLE.COM.",
expectedCalls: 1,
},
{
name: "multiple handlers different priority",
scenario: "should call higher priority handler despite case differences",
addHandlers: []struct {
pattern string
priority int
subdomains bool
shouldMatch bool
}{
{"EXAMPLE.COM.", nbdns.PriorityDefault, false, false},
{"example.com.", nbdns.PriorityMatchDomain, false, false},
{"Example.Com.", nbdns.PriorityDNSRoute, false, true},
},
query: "example.com.",
expectedCalls: 1,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
chain := nbdns.NewHandlerChain()
handlerCalls := make(map[string]bool) // track which patterns were called

// Add handlers according to test case
for _, h := range tt.addHandlers {
var handler dns.Handler
pattern := h.pattern // capture pattern for closure

if h.subdomains {
subHandler := &nbdns.MockSubdomainHandler{
Subdomains: true,
}
if h.shouldMatch {
subHandler.On("ServeDNS", mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
handlerCalls[pattern] = true
w := args.Get(0).(dns.ResponseWriter)
r := args.Get(1).(*dns.Msg)
resp := new(dns.Msg)
resp.SetRcode(r, dns.RcodeSuccess)
assert.NoError(t, w.WriteMsg(resp))
}).Once()
}
handler = subHandler
} else {
mockHandler := &nbdns.MockHandler{}
if h.shouldMatch {
mockHandler.On("ServeDNS", mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
handlerCalls[pattern] = true
w := args.Get(0).(dns.ResponseWriter)
r := args.Get(1).(*dns.Msg)
resp := new(dns.Msg)
resp.SetRcode(r, dns.RcodeSuccess)
assert.NoError(t, w.WriteMsg(resp))
}).Once()
}
handler = mockHandler
}

chain.AddHandler(pattern, handler, h.priority, nil)
}

// Execute request
r := new(dns.Msg)
r.SetQuestion(tt.query, dns.TypeA)
chain.ServeDNS(&mockResponseWriter{}, r)

// Verify each handler was called exactly as expected
for _, h := range tt.addHandlers {
wasCalled := handlerCalls[h.pattern]
assert.Equal(t, h.shouldMatch, wasCalled,
"Handler for pattern %q was %s when it should%s have been",
h.pattern,
map[bool]string{true: "called", false: "not called"}[wasCalled],
map[bool]string{true: "", false: " not"}[wasCalled == h.shouldMatch])
}

// Verify total number of calls
assert.Equal(t, tt.expectedCalls, len(handlerCalls),
"Wrong number of total handler calls")
})
}
}

0 comments on commit 43ef64c

Please sign in to comment.