Skip to content

Commit

Permalink
Add udp conn tracking
Browse files Browse the repository at this point in the history
  • Loading branch information
lixmal committed Dec 20, 2024
1 parent a778e91 commit 809d63f
Show file tree
Hide file tree
Showing 6 changed files with 650 additions and 8 deletions.
10 changes: 9 additions & 1 deletion client/firewall/uspfilter/allow_netbird.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@

package uspfilter

import "github.com/netbirdio/netbird/client/internal/statemanager"
import (
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/internal/statemanager"
)

// Reset firewall to the default state
func (m *Manager) Reset(stateManager *statemanager.Manager) error {
Expand All @@ -12,6 +15,11 @@ func (m *Manager) Reset(stateManager *statemanager.Manager) error {
m.outgoingRules = make(map[string]RuleSet)
m.incomingRules = make(map[string]RuleSet)

if m.udpTracker != nil {
m.udpTracker.Close()
m.udpTracker = conntrack.NewUDPTracker(udpTimeout)
}

if m.nativeFirewall != nil {
return m.nativeFirewall.Reset(stateManager)
}
Expand Down
6 changes: 6 additions & 0 deletions client/firewall/uspfilter/allow_netbird_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (

log "github.com/sirupsen/logrus"

"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/internal/statemanager"
)

Expand All @@ -26,6 +27,11 @@ func (m *Manager) Reset(*statemanager.Manager) error {
m.outgoingRules = make(map[string]RuleSet)
m.incomingRules = make(map[string]RuleSet)

if m.udpTracker != nil {
m.udpTracker.Close()
m.udpTracker = conntrack.NewUDPTracker(udpTimeout)
}

if !isWindowsFirewallReachable() {
return nil
}
Expand Down
152 changes: 152 additions & 0 deletions client/firewall/uspfilter/conntrack/conntrack.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
package conntrack

import (
"net"
"sync"
"time"
)

const (
// DefaultTimeout is the default timeout for UDP connections
DefaultTimeout = 30 * time.Second
// CleanupInterval is how often we check for stale connections
CleanupInterval = 15 * time.Second
)

// UDPConnTrack represents a UDP connection state
type UDPConnTrack struct {
SourceIP net.IP
DestIP net.IP
SourcePort uint16
DestPort uint16
LastSeen time.Time
established bool
}

// UDPTracker manages UDP connection states
type UDPTracker struct {
connections map[uint16]*UDPConnTrack
timeout time.Duration
cleanupTicker *time.Ticker
mutex sync.RWMutex
done chan struct{} // Channel to signal shutdown
}

// NewUDPTracker creates a new UDP connection tracker
func NewUDPTracker(timeout time.Duration) *UDPTracker {
if timeout == 0 {
timeout = DefaultTimeout
}

tracker := &UDPTracker{
connections: make(map[uint16]*UDPConnTrack),
timeout: timeout,
cleanupTicker: time.NewTicker(CleanupInterval),
done: make(chan struct{}),
}

go tracker.cleanupRoutine()
return tracker
}

// TrackOutbound records an outbound UDP connection
func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) {
t.mutex.Lock()
defer t.mutex.Unlock()

t.connections[srcPort] = &UDPConnTrack{
SourceIP: srcIP,
DestIP: dstIP,
SourcePort: srcPort,
DestPort: dstPort,
LastSeen: time.Now(),
established: true,
}
}

// IsValidInbound checks if an inbound packet matches a tracked connection
func (t *UDPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) bool {
t.mutex.RLock()
defer t.mutex.RUnlock()

conn, exists := t.connections[dstPort]
if !exists {
return false
}

// Check if connection is still valid
if time.Since(conn.LastSeen) > t.timeout {
return false
}

if conn.established &&
conn.DestIP.Equal(srcIP) &&
conn.SourceIP.Equal(dstIP) &&
conn.DestPort == srcPort &&
conn.SourcePort == dstPort {

conn.LastSeen = time.Now()

return true
}

return false
}

// cleanupRoutine periodically removes stale connections
func (t *UDPTracker) cleanupRoutine() {
for {
select {
case <-t.cleanupTicker.C:
t.cleanup()
case <-t.done:
return
}
}
}

func (t *UDPTracker) cleanup() {
t.mutex.Lock()
defer t.mutex.Unlock()

now := time.Now()
for srcPort, conn := range t.connections {
if now.Sub(conn.LastSeen) > t.timeout {
delete(t.connections, srcPort)
}
}
}

// Close stops the cleanup routine and releases resources
func (t *UDPTracker) Close() {
t.cleanupTicker.Stop()
close(t.done)
}

// GetConnection safely retrieves a connection state by source port.
func (t *UDPTracker) GetConnection(srcPort uint16) (*UDPConnTrack, bool) {
t.mutex.RLock()
defer t.mutex.RUnlock()

conn, exists := t.connections[srcPort]
if !exists {
return nil, false
}

// Return a copy to prevent potential race conditions
connCopy := &UDPConnTrack{
SourceIP: append(net.IP{}, conn.SourceIP...),
DestIP: append(net.IP{}, conn.DestIP...),
SourcePort: conn.SourcePort,
DestPort: conn.DestPort,
LastSeen: conn.LastSeen,
established: conn.established,
}

return connCopy, true
}

// Timeout returns the configured timeout duration for the tracker
func (t *UDPTracker) Timeout() time.Duration {
return t.timeout
}
Loading

0 comments on commit 809d63f

Please sign in to comment.