From d1fad7f04d285d9b08f838eac01a4eec8c558e81 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Fri, 20 Dec 2024 12:03:51 +0100 Subject: [PATCH 01/28] Remove userspace egress filter --- client/iface/device/device_filter.go | 23 ++--------------------- 1 file changed, 2 insertions(+), 21 deletions(-) diff --git a/client/iface/device/device_filter.go b/client/iface/device/device_filter.go index f87f104293c..113e9b5955f 100644 --- a/client/iface/device/device_filter.go +++ b/client/iface/device/device_filter.go @@ -45,27 +45,8 @@ func newDeviceFilter(device tun.Device) *FilteredDevice { // Read wraps read method with filtering feature func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) { - if n, err = d.Device.Read(bufs, sizes, offset); err != nil { - return 0, err - } - d.mutex.RLock() - filter := d.filter - d.mutex.RUnlock() - - if filter == nil { - return - } - - for i := 0; i < n; i++ { - if filter.DropOutgoing(bufs[i][offset : offset+sizes[i]]) { - bufs = append(bufs[:i], bufs[i+1:]...) - sizes = append(sizes[:i], sizes[i+1:]...) - n-- - i-- - } - } - - return n, nil + // outgoing traffic is not filtered + return d.Device.Read(bufs, sizes, offset) } // Write wraps write method with filtering feature From c81ac1a72868c5efa02388ce16dec9837a6e8df1 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Fri, 20 Dec 2024 12:12:55 +0100 Subject: [PATCH 02/28] Remove iptables egress filter --- client/firewall/iptables/acl_linux.go | 6 ------ 1 file changed, 6 deletions(-) diff --git a/client/firewall/iptables/acl_linux.go b/client/firewall/iptables/acl_linux.go index 1c0527ebc78..d774f45381b 100644 --- a/client/firewall/iptables/acl_linux.go +++ b/client/firewall/iptables/acl_linux.go @@ -332,18 +332,12 @@ func (m *aclManager) createDefaultChains() error { // The OUTPUT chain gets an extra rule to allow traffic to any set up routes, the return traffic is handled by the INPUT related/established rule. func (m *aclManager) seedInitialEntries() { - established := getConntrackEstablished() m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", "DROP"}) m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", chainNameInputRules}) m.appendToEntries("INPUT", append([]string{"-i", m.wgIface.Name()}, established...)) - m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "-j", "DROP"}) - m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "-j", chainNameOutputRules}) - m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "!", "-d", m.wgIface.Address().String(), "-j", "ACCEPT"}) - m.appendToEntries("OUTPUT", append([]string{"-o", m.wgIface.Name()}, established...)) - m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", "DROP"}) m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", m.routingFwChainName}) m.appendToEntries("FORWARD", append([]string{"-o", m.wgIface.Name()}, established...)) From afb034346d9aaf553399b0dbb966bd8f19284c0c Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Fri, 20 Dec 2024 12:22:17 +0100 Subject: [PATCH 03/28] Remove nftables egress filter --- client/firewall/nftables/acl_linux.go | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/client/firewall/nftables/acl_linux.go b/client/firewall/nftables/acl_linux.go index abe890fb9a1..c3ad349f60c 100644 --- a/client/firewall/nftables/acl_linux.go +++ b/client/firewall/nftables/acl_linux.go @@ -441,18 +441,6 @@ func (m *AclManager) createDefaultChains() (err error) { return err } - // netbird-acl-output-filter - // type filter hook output priority filter; policy accept; - chain = m.createFilterChainWithHook(chainNameOutputFilter, nftables.ChainHookOutput) - m.addFwdAllow(chain, expr.MetaKeyOIFNAME) - m.addJumpRule(chain, m.chainOutputRules.Name, expr.MetaKeyOIFNAME) // to netbird-acl-output-rules - m.addDropExpressions(chain, expr.MetaKeyOIFNAME) - err = m.rConn.Flush() - if err != nil { - log.Debugf("failed to create chain (%s): %s", chainNameOutputFilter, err) - return err - } - // netbird-acl-forward-filter chainFwFilter := m.createFilterChainWithHook(chainNameForwardFilter, nftables.ChainHookForward) m.addJumpRulesToRtForward(chainFwFilter) // to netbird-rt-fwd From fbfb2cd98ebaecc477ed675083c1b4891eb00977 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Fri, 20 Dec 2024 12:53:45 +0100 Subject: [PATCH 04/28] Remove unused code --- client/firewall/iptables/manager_linux.go | 14 ++------ client/firewall/nftables/acl_linux.go | 41 ----------------------- 2 files changed, 2 insertions(+), 53 deletions(-) diff --git a/client/firewall/iptables/manager_linux.go b/client/firewall/iptables/manager_linux.go index adb8f20ef5c..0e1e5836f39 100644 --- a/client/firewall/iptables/manager_linux.go +++ b/client/firewall/iptables/manager_linux.go @@ -207,19 +207,9 @@ func (m *Manager) AllowNetbird() error { "", ) if err != nil { - return fmt.Errorf("failed to allow netbird interface traffic: %w", err) + return fmt.Errorf("allow netbird interface traffic: %w", err) } - _, err = m.AddPeerFiltering( - net.ParseIP("0.0.0.0"), - "all", - nil, - nil, - firewall.RuleDirectionOUT, - firewall.ActionAccept, - "", - "", - ) - return err + return nil } // Flush doesn't need to be implemented for this manager diff --git a/client/firewall/nftables/acl_linux.go b/client/firewall/nftables/acl_linux.go index c3ad349f60c..852cfec8de6 100644 --- a/client/firewall/nftables/acl_linux.go +++ b/client/firewall/nftables/acl_linux.go @@ -5,7 +5,6 @@ import ( "encoding/binary" "fmt" "net" - "net/netip" "strconv" "strings" "time" @@ -28,7 +27,6 @@ const ( // filter chains contains the rules that jump to the rules chains chainNameInputFilter = "netbird-acl-input-filter" - chainNameOutputFilter = "netbird-acl-output-filter" chainNameForwardFilter = "netbird-acl-forward-filter" chainNamePrerouting = "netbird-rt-prerouting" @@ -607,45 +605,6 @@ func (m *AclManager) addDropExpressions(chain *nftables.Chain, ifaceKey expr.Met return nil } -func (m *AclManager) addFwdAllow(chain *nftables.Chain, iifname expr.MetaKey) { - ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To4()) - dstOp := expr.CmpOpNeq - expressions := []expr.Any{ - &expr.Meta{Key: iifname, Register: 1}, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: ifname(m.wgIface.Name()), - }, - &expr.Payload{ - DestRegister: 2, - Base: expr.PayloadBaseNetworkHeader, - Offset: 16, - Len: 4, - }, - &expr.Bitwise{ - SourceRegister: 2, - DestRegister: 2, - Len: 4, - Xor: []byte{0x0, 0x0, 0x0, 0x0}, - Mask: m.wgIface.Address().Network.Mask, - }, - &expr.Cmp{ - Op: dstOp, - Register: 2, - Data: ip.Unmap().AsSlice(), - }, - &expr.Verdict{ - Kind: expr.VerdictAccept, - }, - } - _ = m.rConn.AddRule(&nftables.Rule{ - Table: chain.Table, - Chain: chain, - Exprs: expressions, - }) -} - func (m *AclManager) addJumpRule(chain *nftables.Chain, to string, ifaceKey expr.MetaKey) { expressions := []expr.Any{ &expr.Meta{Key: ifaceKey, Register: 1}, From 4d14cf6b1ab012bbd1557abafbcea1bcde001527 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Fri, 20 Dec 2024 13:26:23 +0100 Subject: [PATCH 05/28] Still process outgoing udp hooks --- client/firewall/uspfilter/uspfilter.go | 57 ++++++++++++++++----- client/firewall/uspfilter/uspfilter_test.go | 2 +- client/iface/device/device_filter.go | 23 ++++++++- 3 files changed, 66 insertions(+), 16 deletions(-) diff --git a/client/firewall/uspfilter/uspfilter.go b/client/firewall/uspfilter/uspfilter.go index fb726395bef..e7c26b11874 100644 --- a/client/firewall/uspfilter/uspfilter.go +++ b/client/firewall/uspfilter/uspfilter.go @@ -249,16 +249,55 @@ func (m *Manager) Flush() error { return nil } // DropOutgoing filter outgoing packets func (m *Manager) DropOutgoing(packetData []byte) bool { - return m.dropFilter(packetData, m.outgoingRules, false) + return m.processOutgoingHooks(packetData) } // DropIncoming filter incoming packets func (m *Manager) DropIncoming(packetData []byte) bool { - return m.dropFilter(packetData, m.incomingRules, true) + return m.dropFilter(packetData, m.incomingRules) +} + +// processOutgoingHooks processes only UDP hooks for outgoing packets +func (m *Manager) processOutgoingHooks(packetData []byte) bool { + m.mutex.RLock() + defer m.mutex.RUnlock() + + d := m.decoders.Get().(*decoder) + defer m.decoders.Put(d) + + if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { + return false + } + + if len(d.decoded) < 2 || d.decoded[1] != layers.LayerTypeUDP { + return false + } + + var ip net.IP + switch d.decoded[0] { + case layers.LayerTypeIPv4: + ip = d.ip4.DstIP + case layers.LayerTypeIPv6: + ip = d.ip6.DstIP + default: + return false + } + + // Check specific IP rules first, then any-IP rules + for _, ipKey := range []string{ip.String(), "0.0.0.0", "::"} { + if rules, exists := m.outgoingRules[ipKey]; exists { + for _, rule := range rules { + if rule.udpHook != nil && (rule.dPort == 0 || rule.dPort == uint16(d.udp.DstPort)) { + return rule.udpHook(packetData) + } + } + } + } + return false } // dropFilter implements same logic for booth direction of the traffic -func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet, isIncomingPacket bool) bool { +func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet) bool { m.mutex.RLock() defer m.mutex.RUnlock() @@ -294,17 +333,9 @@ func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet, isInco var ip net.IP switch ipLayer { case layers.LayerTypeIPv4: - if isIncomingPacket { - ip = d.ip4.SrcIP - } else { - ip = d.ip4.DstIP - } + ip = d.ip4.SrcIP case layers.LayerTypeIPv6: - if isIncomingPacket { - ip = d.ip6.SrcIP - } else { - ip = d.ip6.DstIP - } + ip = d.ip6.SrcIP } filter, ok := validateRule(ip, packetData, rules[ip.String()], d) diff --git a/client/firewall/uspfilter/uspfilter_test.go b/client/firewall/uspfilter/uspfilter_test.go index d7c93cb7f99..5ad8ab4a270 100644 --- a/client/firewall/uspfilter/uspfilter_test.go +++ b/client/firewall/uspfilter/uspfilter_test.go @@ -325,7 +325,7 @@ func TestNotMatchByIP(t *testing.T) { return } - if m.dropFilter(buf.Bytes(), m.outgoingRules, false) { + if m.dropFilter(buf.Bytes(), m.outgoingRules) { t.Errorf("expected packet to be accepted") return } diff --git a/client/iface/device/device_filter.go b/client/iface/device/device_filter.go index 113e9b5955f..f87f104293c 100644 --- a/client/iface/device/device_filter.go +++ b/client/iface/device/device_filter.go @@ -45,8 +45,27 @@ func newDeviceFilter(device tun.Device) *FilteredDevice { // Read wraps read method with filtering feature func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) { - // outgoing traffic is not filtered - return d.Device.Read(bufs, sizes, offset) + if n, err = d.Device.Read(bufs, sizes, offset); err != nil { + return 0, err + } + d.mutex.RLock() + filter := d.filter + d.mutex.RUnlock() + + if filter == nil { + return + } + + for i := 0; i < n; i++ { + if filter.DropOutgoing(bufs[i][offset : offset+sizes[i]]) { + bufs = append(bufs[:i], bufs[i+1:]...) + sizes = append(sizes[:i], sizes[i+1:]...) + n-- + i-- + } + } + + return n, nil } // Write wraps write method with filtering feature From a778e91ca195754dccca3c64dee63db317c2ba1b Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Fri, 20 Dec 2024 13:38:50 +0100 Subject: [PATCH 06/28] Add udp hook test --- client/firewall/uspfilter/uspfilter_test.go | 75 +++++++++++++++++++++ 1 file changed, 75 insertions(+) diff --git a/client/firewall/uspfilter/uspfilter_test.go b/client/firewall/uspfilter/uspfilter_test.go index 5ad8ab4a270..4677c07c4d4 100644 --- a/client/firewall/uspfilter/uspfilter_test.go +++ b/client/firewall/uspfilter/uspfilter_test.go @@ -3,6 +3,7 @@ package uspfilter import ( "fmt" "net" + "sync" "testing" "time" @@ -384,6 +385,80 @@ func TestRemovePacketHook(t *testing.T) { } } +func TestProcessOutgoingHooks(t *testing.T) { + manager := &Manager{ + outgoingRules: map[string]RuleSet{}, + wgNetwork: &net.IPNet{ + IP: net.ParseIP("100.10.0.0"), + Mask: net.CIDRMask(16, 32), + }, + decoders: sync.Pool{ + New: func() any { + d := &decoder{ + decoded: []gopacket.LayerType{}, + } + d.parser = gopacket.NewDecodingLayerParser( + layers.LayerTypeIPv4, + &d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp, + ) + d.parser.IgnoreUnsupported = true + return d + }, + }, + } + + hookCalled := false + hookID := manager.AddUDPPacketHook( + false, + net.ParseIP("100.10.0.100"), + 53, + func([]byte) bool { + hookCalled = true + return true + }, + ) + require.NotEmpty(t, hookID) + + // Create test UDP packet + ipv4 := &layers.IPv4{ + TTL: 64, + Version: 4, + SrcIP: net.ParseIP("100.10.0.1"), + DstIP: net.ParseIP("100.10.0.100"), + Protocol: layers.IPProtocolUDP, + } + udp := &layers.UDP{ + SrcPort: 51334, + DstPort: 53, + } + + err := udp.SetNetworkLayerForChecksum(ipv4) + require.NoError(t, err) + payload := gopacket.Payload([]byte("test")) + + buf := gopacket.NewSerializeBuffer() + opts := gopacket.SerializeOptions{ + ComputeChecksums: true, + FixLengths: true, + } + err = gopacket.SerializeLayers(buf, opts, ipv4, udp, payload) + require.NoError(t, err) + + // Test hook gets called + result := manager.processOutgoingHooks(buf.Bytes()) + require.True(t, result) + require.True(t, hookCalled) + + // Test non-UDP packet is ignored + ipv4.Protocol = layers.IPProtocolTCP + buf = gopacket.NewSerializeBuffer() + err = gopacket.SerializeLayers(buf, opts, ipv4) + require.NoError(t, err) + + result = manager.processOutgoingHooks(buf.Bytes()) + require.False(t, result) +} + func TestUSPFilterCreatePerformance(t *testing.T) { for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} { t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) { From 8216ab65568b1524363861a680d2e07176983977 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Fri, 20 Dec 2024 15:08:33 +0100 Subject: [PATCH 07/28] Add udp conn tracking --- client/firewall/uspfilter/allow_netbird.go | 10 +- .../uspfilter/allow_netbird_windows.go | 6 + client/firewall/uspfilter/conntrack/udp.go | 176 +++++++++++++ .../firewall/uspfilter/conntrack/udp_test.go | 233 ++++++++++++++++++ client/firewall/uspfilter/uspfilter.go | 51 +++- client/firewall/uspfilter/uspfilter_test.go | 206 ++++++++++++++++ 6 files changed, 674 insertions(+), 8 deletions(-) create mode 100644 client/firewall/uspfilter/conntrack/udp.go create mode 100644 client/firewall/uspfilter/conntrack/udp_test.go diff --git a/client/firewall/uspfilter/allow_netbird.go b/client/firewall/uspfilter/allow_netbird.go index cefc81a3ce6..f5ca6ba286c 100644 --- a/client/firewall/uspfilter/allow_netbird.go +++ b/client/firewall/uspfilter/allow_netbird.go @@ -2,7 +2,10 @@ package uspfilter -import "github.com/netbirdio/netbird/client/internal/statemanager" +import ( + "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" + "github.com/netbirdio/netbird/client/internal/statemanager" +) // Reset firewall to the default state func (m *Manager) Reset(stateManager *statemanager.Manager) error { @@ -12,6 +15,11 @@ func (m *Manager) Reset(stateManager *statemanager.Manager) error { m.outgoingRules = make(map[string]RuleSet) m.incomingRules = make(map[string]RuleSet) + if m.udpTracker != nil { + m.udpTracker.Close() + m.udpTracker = conntrack.NewUDPTracker(udpTimeout) + } + if m.nativeFirewall != nil { return m.nativeFirewall.Reset(stateManager) } diff --git a/client/firewall/uspfilter/allow_netbird_windows.go b/client/firewall/uspfilter/allow_netbird_windows.go index d3732301ed5..ff9513cb137 100644 --- a/client/firewall/uspfilter/allow_netbird_windows.go +++ b/client/firewall/uspfilter/allow_netbird_windows.go @@ -7,6 +7,7 @@ import ( log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" "github.com/netbirdio/netbird/client/internal/statemanager" ) @@ -26,6 +27,11 @@ func (m *Manager) Reset(*statemanager.Manager) error { m.outgoingRules = make(map[string]RuleSet) m.incomingRules = make(map[string]RuleSet) + if m.udpTracker != nil { + m.udpTracker.Close() + m.udpTracker = conntrack.NewUDPTracker(udpTimeout) + } + if !isWindowsFirewallReachable() { return nil } diff --git a/client/firewall/uspfilter/conntrack/udp.go b/client/firewall/uspfilter/conntrack/udp.go new file mode 100644 index 00000000000..0a0a92e8d00 --- /dev/null +++ b/client/firewall/uspfilter/conntrack/udp.go @@ -0,0 +1,176 @@ +package conntrack + +import ( + "net" + "sync" + "time" +) + +const ( + // DefaultTimeout is the default timeout for UDP connections + DefaultTimeout = 30 * time.Second + // CleanupInterval is how often we check for stale connections + CleanupInterval = 15 * time.Second +) + +type ConnKey struct { + // Supports both IPv4 and IPv6 + SrcIP [16]byte + DstIP [16]byte + SrcPort uint16 + DstPort uint16 +} + +// UDPConnTrack represents a UDP connection state +type UDPConnTrack struct { + SourceIP net.IP + DestIP net.IP + SourcePort uint16 + DestPort uint16 + LastSeen time.Time + established bool +} + +// UDPTracker manages UDP connection states +type UDPTracker struct { + connections map[ConnKey]*UDPConnTrack + timeout time.Duration + cleanupTicker *time.Ticker + mutex sync.RWMutex + done chan struct{} +} + +// NewUDPTracker creates a new UDP connection tracker +func NewUDPTracker(timeout time.Duration) *UDPTracker { + if timeout == 0 { + timeout = DefaultTimeout + } + + tracker := &UDPTracker{ + connections: make(map[ConnKey]*UDPConnTrack), + timeout: timeout, + cleanupTicker: time.NewTicker(CleanupInterval), + done: make(chan struct{}), + } + + go tracker.cleanupRoutine() + return tracker +} + +// TrackOutbound records an outbound UDP connection +func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) { + t.mutex.Lock() + defer t.mutex.Unlock() + + key := makeKey(srcIP, srcPort, dstIP, dstPort) + + t.connections[key] = &UDPConnTrack{ + SourceIP: srcIP, + DestIP: dstIP, + SourcePort: srcPort, + DestPort: dstPort, + LastSeen: time.Now(), + established: true, + } +} + +// IsValidInbound checks if an inbound packet matches a tracked connection +func (t *UDPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) bool { + t.mutex.RLock() + defer t.mutex.RUnlock() + + key := makeKey(dstIP, dstPort, srcIP, srcPort) + conn, exists := t.connections[key] + if !exists { + return false + } + + // Check if connection is still valid + if time.Since(conn.LastSeen) > t.timeout { + return false + } + + if conn.established && + conn.DestIP.Equal(srcIP) && + conn.SourceIP.Equal(dstIP) && + conn.DestPort == srcPort && + conn.SourcePort == dstPort { + + conn.LastSeen = time.Now() + + return true + } + + return false +} + +// cleanupRoutine periodically removes stale connections +func (t *UDPTracker) cleanupRoutine() { + for { + select { + case <-t.cleanupTicker.C: + t.cleanup() + case <-t.done: + return + } + } +} + +func (t *UDPTracker) cleanup() { + t.mutex.Lock() + defer t.mutex.Unlock() + + now := time.Now() + for key, conn := range t.connections { + if now.Sub(conn.LastSeen) > t.timeout { + delete(t.connections, key) + } + } +} + +// Close stops the cleanup routine and releases resources +func (t *UDPTracker) Close() { + t.cleanupTicker.Stop() + close(t.done) +} + +// GetConnection safely retrieves a connection state +func (t *UDPTracker) GetConnection(srcIP net.IP, srcPort uint16, dstIP net.IP, dstPort uint16) (*UDPConnTrack, bool) { + t.mutex.RLock() + defer t.mutex.RUnlock() + + key := makeKey(srcIP, srcPort, dstIP, dstPort) + conn, exists := t.connections[key] + if !exists { + return nil, false + } + + // Return a copy to prevent potential race conditions + connCopy := &UDPConnTrack{ + SourceIP: append(net.IP{}, conn.SourceIP...), + DestIP: append(net.IP{}, conn.DestIP...), + SourcePort: conn.SourcePort, + DestPort: conn.DestPort, + LastSeen: conn.LastSeen, + established: conn.established, + } + + return connCopy, true +} + +// Timeout returns the configured timeout duration for the tracker +func (t *UDPTracker) Timeout() time.Duration { + return t.timeout +} + +func makeKey(srcIP net.IP, srcPort uint16, dstIP net.IP, dstPort uint16) ConnKey { + var srcAddr, dstAddr [16]byte + copy(srcAddr[:], srcIP.To16()) // Ensure 16-byte representation + copy(dstAddr[:], dstIP.To16()) + return ConnKey{ + SrcIP: srcAddr, + SrcPort: srcPort, + DstIP: dstAddr, + DstPort: dstPort, + } +} diff --git a/client/firewall/uspfilter/conntrack/udp_test.go b/client/firewall/uspfilter/conntrack/udp_test.go new file mode 100644 index 00000000000..9e15d310dc8 --- /dev/null +++ b/client/firewall/uspfilter/conntrack/udp_test.go @@ -0,0 +1,233 @@ +package conntrack + +import ( + "net" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewUDPTracker(t *testing.T) { + tests := []struct { + name string + timeout time.Duration + wantTimeout time.Duration + }{ + { + name: "with custom timeout", + timeout: 1 * time.Minute, + wantTimeout: 1 * time.Minute, + }, + { + name: "with zero timeout uses default", + timeout: 0, + wantTimeout: DefaultTimeout, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tracker := NewUDPTracker(tt.timeout) + assert.NotNil(t, tracker) + assert.Equal(t, tt.wantTimeout, tracker.timeout) + assert.NotNil(t, tracker.connections) + assert.NotNil(t, tracker.cleanupTicker) + assert.NotNil(t, tracker.done) + }) + } +} + +func TestUDPTracker_TrackOutbound(t *testing.T) { + tracker := NewUDPTracker(DefaultTimeout) + defer tracker.Close() + + srcIP := net.ParseIP("192.168.1.2") + dstIP := net.ParseIP("192.168.1.3") + srcPort := uint16(12345) + dstPort := uint16(53) + + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort) + + // Verify connection was tracked + conn, exists := tracker.connections[srcPort] + require.True(t, exists) + assert.True(t, conn.SourceIP.Equal(srcIP)) + assert.True(t, conn.DestIP.Equal(dstIP)) + assert.Equal(t, srcPort, conn.SourcePort) + assert.Equal(t, dstPort, conn.DestPort) + assert.True(t, conn.established) + assert.WithinDuration(t, time.Now(), conn.LastSeen, 1*time.Second) +} + +func TestUDPTracker_IsValidInbound(t *testing.T) { + tracker := NewUDPTracker(1 * time.Second) + defer tracker.Close() + + srcIP := net.ParseIP("192.168.1.2") + dstIP := net.ParseIP("192.168.1.3") + srcPort := uint16(12345) + dstPort := uint16(53) + + // Track outbound connection + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort) + + tests := []struct { + name string + srcIP net.IP + dstIP net.IP + srcPort uint16 + dstPort uint16 + sleep time.Duration + want bool + }{ + { + name: "valid inbound response", + srcIP: dstIP, // Original destination is now source + dstIP: srcIP, // Original source is now destination + srcPort: dstPort, // Original destination port is now source + dstPort: srcPort, // Original source port is now destination + sleep: 0, + want: true, + }, + { + name: "invalid source IP", + srcIP: net.ParseIP("192.168.1.4"), + dstIP: srcIP, + srcPort: dstPort, + dstPort: srcPort, + sleep: 0, + want: false, + }, + { + name: "invalid destination IP", + srcIP: dstIP, + dstIP: net.ParseIP("192.168.1.4"), + srcPort: dstPort, + dstPort: srcPort, + sleep: 0, + want: false, + }, + { + name: "invalid source port", + srcIP: dstIP, + dstIP: srcIP, + srcPort: 54321, + dstPort: srcPort, + sleep: 0, + want: false, + }, + { + name: "invalid destination port", + srcIP: dstIP, + dstIP: srcIP, + srcPort: dstPort, + dstPort: 54321, + sleep: 0, + want: false, + }, + { + name: "expired connection", + srcIP: dstIP, + dstIP: srcIP, + srcPort: dstPort, + dstPort: srcPort, + sleep: 2 * time.Second, // Longer than tracker timeout + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.sleep > 0 { + time.Sleep(tt.sleep) + } + got := tracker.IsValidInbound(tt.srcIP, tt.dstIP, tt.srcPort, tt.dstPort) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestUDPTracker_Cleanup(t *testing.T) { + // Use shorter intervals for testing + timeout := 50 * time.Millisecond + cleanupInterval := 25 * time.Millisecond + + // Create tracker with custom cleanup interval + tracker := &UDPTracker{ + connections: make(map[uint16]*UDPConnTrack), + timeout: timeout, + cleanupTicker: time.NewTicker(cleanupInterval), + done: make(chan struct{}), + } + + // Start cleanup routine + go tracker.cleanupRoutine() + defer tracker.Close() + + // Add some connections + connections := []struct { + srcIP net.IP + dstIP net.IP + srcPort uint16 + dstPort uint16 + }{ + { + srcIP: net.ParseIP("192.168.1.2"), + dstIP: net.ParseIP("192.168.1.3"), + srcPort: 12345, + dstPort: 53, + }, + { + srcIP: net.ParseIP("192.168.1.4"), + dstIP: net.ParseIP("192.168.1.5"), + srcPort: 12346, + dstPort: 53, + }, + } + + for _, conn := range connections { + tracker.TrackOutbound(conn.srcIP, conn.dstIP, conn.srcPort, conn.dstPort) + } + + // Verify initial connections + tracker.mutex.RLock() + assert.Len(t, tracker.connections, 2) + tracker.mutex.RUnlock() + + // Wait for connection timeout and cleanup interval + time.Sleep(timeout + 2*cleanupInterval) + + // Verify connections were cleaned up + tracker.mutex.RLock() + assert.Empty(t, tracker.connections) + tracker.mutex.RUnlock() + + // Add a new connection and verify it's not immediately cleaned up + tracker.TrackOutbound(connections[0].srcIP, connections[0].dstIP, + connections[0].srcPort, connections[0].dstPort) + + tracker.mutex.RLock() + assert.Len(t, tracker.connections, 1, "New connection should not be cleaned up immediately") + tracker.mutex.RUnlock() +} + +func TestUDPTracker_Close(t *testing.T) { + tracker := NewUDPTracker(DefaultTimeout) + + // Add a connection + tracker.TrackOutbound( + net.ParseIP("192.168.1.2"), + net.ParseIP("192.168.1.3"), + 12345, + 53, + ) + + // Close the tracker + tracker.Close() + + // Verify done channel is closed + _, ok := <-tracker.done + assert.False(t, ok, "done channel should be closed") +} diff --git a/client/firewall/uspfilter/uspfilter.go b/client/firewall/uspfilter/uspfilter.go index e7c26b11874..d0fc3c18000 100644 --- a/client/firewall/uspfilter/uspfilter.go +++ b/client/firewall/uspfilter/uspfilter.go @@ -5,6 +5,7 @@ import ( "net" "net/netip" "sync" + "time" "github.com/google/gopacket" "github.com/google/gopacket/layers" @@ -12,6 +13,7 @@ import ( log "github.com/sirupsen/logrus" firewall "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/internal/statemanager" @@ -19,6 +21,8 @@ import ( const layerTypeAll = 0 +const udpTimeout = 30 * time.Second + var ( errRouteNotSupported = fmt.Errorf("route not supported with userspace firewall") ) @@ -41,7 +45,8 @@ type Manager struct { wgIface IFaceMapper nativeFirewall firewall.Manager - mutex sync.RWMutex + mutex sync.RWMutex + udpTracker *conntrack.UDPTracker } // decoder for packages @@ -90,6 +95,7 @@ func create(iface IFaceMapper) (*Manager, error) { outgoingRules: make(map[string]RuleSet), incomingRules: make(map[string]RuleSet), wgIface: iface, + udpTracker: conntrack.NewUDPTracker(udpTimeout), } if err := iface.SetFilter(m); err != nil { @@ -273,18 +279,27 @@ func (m *Manager) processOutgoingHooks(packetData []byte) bool { return false } - var ip net.IP + var srcIP, dstIP net.IP switch d.decoded[0] { case layers.LayerTypeIPv4: - ip = d.ip4.DstIP + srcIP = d.ip4.SrcIP + dstIP = d.ip4.DstIP case layers.LayerTypeIPv6: - ip = d.ip6.DstIP + srcIP = d.ip6.SrcIP + dstIP = d.ip6.DstIP default: return false } - // Check specific IP rules first, then any-IP rules - for _, ipKey := range []string{ip.String(), "0.0.0.0", "::"} { + // Track outbound UDP connection + m.udpTracker.TrackOutbound( + srcIP, + dstIP, + uint16(d.udp.SrcPort), + uint16(d.udp.DstPort), + ) + + for _, ipKey := range []string{dstIP.String(), "0.0.0.0", "::"} { if rules, exists := m.outgoingRules[ipKey]; exists { for _, rule := range rules { if rule.udpHook != nil && (rule.dPort == 0 || rule.dPort == uint16(d.udp.DstPort)) { @@ -296,7 +311,7 @@ func (m *Manager) processOutgoingHooks(packetData []byte) bool { return false } -// dropFilter implements same logic for booth direction of the traffic +// dropFilter implements filtering logic for incoming packets func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet) bool { m.mutex.RLock() defer m.mutex.RUnlock() @@ -314,6 +329,28 @@ func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet) bool { return true } + // For UDP inbound packets, check if they match tracked connections + if d.decoded[1] == layers.LayerTypeUDP { + var srcIP, dstIP net.IP + switch d.decoded[0] { + case layers.LayerTypeIPv4: + srcIP = d.ip4.SrcIP + dstIP = d.ip4.DstIP + case layers.LayerTypeIPv6: + srcIP = d.ip6.SrcIP + dstIP = d.ip6.DstIP + } + + if m.udpTracker.IsValidInbound( + srcIP, + dstIP, + uint16(d.udp.SrcPort), + uint16(d.udp.DstPort), + ) { + return false + } + } + ipLayer := d.decoded[0] switch ipLayer { diff --git a/client/firewall/uspfilter/uspfilter_test.go b/client/firewall/uspfilter/uspfilter_test.go index 4677c07c4d4..23f575843a3 100644 --- a/client/firewall/uspfilter/uspfilter_test.go +++ b/client/firewall/uspfilter/uspfilter_test.go @@ -12,6 +12,7 @@ import ( "github.com/stretchr/testify/require" fw "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/device" ) @@ -405,6 +406,7 @@ func TestProcessOutgoingHooks(t *testing.T) { return d }, }, + udpTracker: conntrack.NewUDPTracker(100 * time.Millisecond), } hookCalled := false @@ -493,3 +495,207 @@ func TestUSPFilterCreatePerformance(t *testing.T) { }) } } + +func TestStatefulFirewall_UDPTracking(t *testing.T) { + manager := &Manager{ + outgoingRules: map[string]RuleSet{}, + incomingRules: map[string]RuleSet{}, + wgNetwork: &net.IPNet{ + IP: net.ParseIP("100.10.0.0"), + Mask: net.CIDRMask(16, 32), + }, + decoders: sync.Pool{ + New: func() any { + d := &decoder{ + decoded: []gopacket.LayerType{}, + } + d.parser = gopacket.NewDecodingLayerParser( + layers.LayerTypeIPv4, + &d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp, + ) + d.parser.IgnoreUnsupported = true + return d + }, + }, + udpTracker: conntrack.NewUDPTracker(200 * time.Millisecond), + } + defer manager.udpTracker.Close() + + // Set up packet parameters + srcIP := net.ParseIP("100.10.0.1") + dstIP := net.ParseIP("100.10.0.100") + srcPort := uint16(51334) + dstPort := uint16(53) + + // Create outbound packet + outboundIPv4 := &layers.IPv4{ + TTL: 64, + Version: 4, + SrcIP: srcIP, + DstIP: dstIP, + Protocol: layers.IPProtocolUDP, + } + outboundUDP := &layers.UDP{ + SrcPort: layers.UDPPort(srcPort), + DstPort: layers.UDPPort(dstPort), + } + + err := outboundUDP.SetNetworkLayerForChecksum(outboundIPv4) + require.NoError(t, err) + + outboundBuf := gopacket.NewSerializeBuffer() + opts := gopacket.SerializeOptions{ + ComputeChecksums: true, + FixLengths: true, + } + + err = gopacket.SerializeLayers(outboundBuf, opts, + outboundIPv4, + outboundUDP, + gopacket.Payload([]byte("test")), + ) + require.NoError(t, err) + + // Process outbound packet and verify connection tracking + drop := manager.processOutgoingHooks(outboundBuf.Bytes()) + require.False(t, drop, "Initial outbound packet should not be dropped") + + // Verify connection was tracked + conn, exists := manager.udpTracker.GetConnection(srcIP, srcPort, dstIP, dstPort) + require.True(t, exists, "Connection should be tracked after outbound packet") + require.True(t, conn.SourceIP.Equal(srcIP), "Source IP should match") + require.True(t, conn.DestIP.Equal(dstIP), "Destination IP should match") + require.Equal(t, srcPort, conn.SourcePort, "Source port should match") + require.Equal(t, dstPort, conn.DestPort, "Destination port should match") + + // Create valid inbound response packet + inboundIPv4 := &layers.IPv4{ + TTL: 64, + Version: 4, + SrcIP: dstIP, // Original destination is now source + DstIP: srcIP, // Original source is now destination + Protocol: layers.IPProtocolUDP, + } + inboundUDP := &layers.UDP{ + SrcPort: layers.UDPPort(dstPort), // Original destination port is now source + DstPort: layers.UDPPort(srcPort), // Original source port is now destination + } + + err = inboundUDP.SetNetworkLayerForChecksum(inboundIPv4) + require.NoError(t, err) + + inboundBuf := gopacket.NewSerializeBuffer() + err = gopacket.SerializeLayers(inboundBuf, opts, + inboundIPv4, + inboundUDP, + gopacket.Payload([]byte("response")), + ) + require.NoError(t, err) + // Test roundtrip response handling over time + checkPoints := []struct { + sleep time.Duration + shouldAllow bool + description string + }{ + { + sleep: 0, + shouldAllow: true, + description: "Immediate response should be allowed", + }, + { + sleep: 50 * time.Millisecond, + shouldAllow: true, + description: "Response within timeout should be allowed", + }, + { + sleep: 100 * time.Millisecond, + shouldAllow: true, + description: "Response at half timeout should be allowed", + }, + { + // tracker hasn't updated conn for 250ms -> greater than 200ms timeout + sleep: 250 * time.Millisecond, + shouldAllow: false, + description: "Response after timeout should be dropped", + }, + } + + for _, cp := range checkPoints { + time.Sleep(cp.sleep) + + drop = manager.dropFilter(inboundBuf.Bytes(), manager.incomingRules) + require.Equal(t, cp.shouldAllow, !drop, cp.description) + + // If the connection should still be valid, verify it exists + if cp.shouldAllow { + conn, exists := manager.udpTracker.GetConnection(srcIP, srcPort, dstIP, dstPort) + require.True(t, exists, "Connection should still exist during valid window") + require.True(t, time.Since(conn.LastSeen) < manager.udpTracker.Timeout(), + "LastSeen should be updated for valid responses") + } + } + + // Test invalid response packets (while connection is expired) + invalidCases := []struct { + name string + modifyFunc func(*layers.IPv4, *layers.UDP) + description string + }{ + { + name: "wrong source IP", + modifyFunc: func(ip *layers.IPv4, udp *layers.UDP) { + ip.SrcIP = net.ParseIP("100.10.0.101") + }, + description: "Response from wrong IP should be dropped", + }, + { + name: "wrong destination IP", + modifyFunc: func(ip *layers.IPv4, udp *layers.UDP) { + ip.DstIP = net.ParseIP("100.10.0.2") + }, + description: "Response to wrong IP should be dropped", + }, + { + name: "wrong source port", + modifyFunc: func(ip *layers.IPv4, udp *layers.UDP) { + udp.SrcPort = 54 + }, + description: "Response from wrong port should be dropped", + }, + { + name: "wrong destination port", + modifyFunc: func(ip *layers.IPv4, udp *layers.UDP) { + udp.DstPort = 51335 + }, + description: "Response to wrong port should be dropped", + }, + } + + // Create a new outbound connection for invalid tests + drop = manager.processOutgoingHooks(outboundBuf.Bytes()) + require.False(t, drop, "Second outbound packet should not be dropped") + + for _, tc := range invalidCases { + t.Run(tc.name, func(t *testing.T) { + testIPv4 := *inboundIPv4 + testUDP := *inboundUDP + + tc.modifyFunc(&testIPv4, &testUDP) + + err = testUDP.SetNetworkLayerForChecksum(&testIPv4) + require.NoError(t, err) + + testBuf := gopacket.NewSerializeBuffer() + err = gopacket.SerializeLayers(testBuf, opts, + &testIPv4, + &testUDP, + gopacket.Payload([]byte("response")), + ) + require.NoError(t, err) + + // Verify the invalid packet is dropped + drop = manager.dropFilter(testBuf.Bytes(), manager.incomingRules) + require.True(t, drop, tc.description) + }) + } +} From 5d97cf8567318aed7c6623db0441acf496bdce08 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Fri, 20 Dec 2024 17:56:57 +0100 Subject: [PATCH 08/28] Fix udp test --- client/firewall/uspfilter/conntrack/udp_test.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/client/firewall/uspfilter/conntrack/udp_test.go b/client/firewall/uspfilter/conntrack/udp_test.go index 9e15d310dc8..a19170c4481 100644 --- a/client/firewall/uspfilter/conntrack/udp_test.go +++ b/client/firewall/uspfilter/conntrack/udp_test.go @@ -51,7 +51,8 @@ func TestUDPTracker_TrackOutbound(t *testing.T) { tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort) // Verify connection was tracked - conn, exists := tracker.connections[srcPort] + key := makeKey(srcIP, srcPort, dstIP, dstPort) + conn, exists := tracker.connections[key] require.True(t, exists) assert.True(t, conn.SourceIP.Equal(srcIP)) assert.True(t, conn.DestIP.Equal(dstIP)) @@ -156,7 +157,7 @@ func TestUDPTracker_Cleanup(t *testing.T) { // Create tracker with custom cleanup interval tracker := &UDPTracker{ - connections: make(map[uint16]*UDPConnTrack), + connections: make(map[ConnKey]*UDPConnTrack), timeout: timeout, cleanupTicker: time.NewTicker(cleanupInterval), done: make(chan struct{}), From 9d1702c25d9b8834bf8d53900579be9a89af77a9 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Fri, 20 Dec 2024 19:25:49 +0100 Subject: [PATCH 09/28] Fix corrupted IPs --- client/firewall/uspfilter/conntrack/udp.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/client/firewall/uspfilter/conntrack/udp.go b/client/firewall/uspfilter/conntrack/udp.go index 0a0a92e8d00..cdb2f64cc04 100644 --- a/client/firewall/uspfilter/conntrack/udp.go +++ b/client/firewall/uspfilter/conntrack/udp.go @@ -2,6 +2,7 @@ package conntrack import ( "net" + "slices" "sync" "time" ) @@ -65,8 +66,8 @@ func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, d key := makeKey(srcIP, srcPort, dstIP, dstPort) t.connections[key] = &UDPConnTrack{ - SourceIP: srcIP, - DestIP: dstIP, + SourceIP: slices.Clone(srcIP), + DestIP: slices.Clone(dstIP), SourcePort: srcPort, DestPort: dstPort, LastSeen: time.Now(), From 49d1de267194655cf7807520488b47590a04b3e3 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Fri, 20 Dec 2024 20:27:22 +0100 Subject: [PATCH 10/28] Use slices.Clone consistently --- client/firewall/uspfilter/conntrack/udp.go | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/client/firewall/uspfilter/conntrack/udp.go b/client/firewall/uspfilter/conntrack/udp.go index cdb2f64cc04..40448ee7852 100644 --- a/client/firewall/uspfilter/conntrack/udp.go +++ b/client/firewall/uspfilter/conntrack/udp.go @@ -146,10 +146,9 @@ func (t *UDPTracker) GetConnection(srcIP net.IP, srcPort uint16, dstIP net.IP, d return nil, false } - // Return a copy to prevent potential race conditions connCopy := &UDPConnTrack{ - SourceIP: append(net.IP{}, conn.SourceIP...), - DestIP: append(net.IP{}, conn.DestIP...), + SourceIP: slices.Clone(conn.SourceIP), + DestIP: slices.Clone(conn.DestIP), SourcePort: conn.SourcePort, DestPort: conn.DestPort, LastSeen: conn.LastSeen, From 2a5ef98d81910d92cc3fad49ac2be0801942ca3d Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Fri, 20 Dec 2024 21:19:28 +0100 Subject: [PATCH 11/28] Add icmp tracker --- client/firewall/uspfilter/conntrack/icmp.go | 159 ++++++++++++++++++ client/firewall/uspfilter/conntrack/udp.go | 12 +- .../firewall/uspfilter/conntrack/udp_test.go | 6 +- client/firewall/uspfilter/uspfilter.go | 120 +++++++------ 4 files changed, 237 insertions(+), 60 deletions(-) create mode 100644 client/firewall/uspfilter/conntrack/icmp.go diff --git a/client/firewall/uspfilter/conntrack/icmp.go b/client/firewall/uspfilter/conntrack/icmp.go new file mode 100644 index 00000000000..1968ef6b951 --- /dev/null +++ b/client/firewall/uspfilter/conntrack/icmp.go @@ -0,0 +1,159 @@ +package conntrack + +import ( + "net" + "slices" + "sync" + "time" + + "github.com/google/gopacket/layers" +) + +const ( + // DefaultICMPTimeout is the default timeout for ICMP connections + DefaultICMPTimeout = 30 * time.Second + // ICMPCleanupInterval is how often we check for stale ICMP connections + ICMPCleanupInterval = 15 * time.Second +) + +// ICMPConnKey uniquely identifies an ICMP connection +type ICMPConnKey struct { + // Supports both IPv4 and IPv6 + SrcIP [16]byte + DstIP [16]byte + Sequence uint16 // ICMP sequence number + ID uint16 // ICMP identifier +} + +// ICMPConnTrack represents an ICMP connection state +type ICMPConnTrack struct { + SourceIP net.IP + DestIP net.IP + Sequence uint16 + ID uint16 + LastSeen time.Time + established bool +} + +// ICMPTracker manages ICMP connection states +type ICMPTracker struct { + connections map[ICMPConnKey]*ICMPConnTrack + timeout time.Duration + cleanupTicker *time.Ticker + mutex sync.RWMutex + done chan struct{} +} + +// NewICMPTracker creates a new ICMP connection tracker +func NewICMPTracker(timeout time.Duration) *ICMPTracker { + if timeout == 0 { + timeout = DefaultICMPTimeout + } + + tracker := &ICMPTracker{ + connections: make(map[ICMPConnKey]*ICMPConnTrack), + timeout: timeout, + cleanupTicker: time.NewTicker(ICMPCleanupInterval), + done: make(chan struct{}), + } + + go tracker.cleanupRoutine() + return tracker +} + +// TrackOutbound records an outbound ICMP Echo Request +func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) { + t.mutex.Lock() + defer t.mutex.Unlock() + + key := makeICMPKey(srcIP, dstIP, id, seq) + + t.connections[key] = &ICMPConnTrack{ + SourceIP: slices.Clone(srcIP), + DestIP: slices.Clone(dstIP), + ID: id, + Sequence: seq, + LastSeen: time.Now(), + established: true, + } +} + +// IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request +func (t *ICMPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, icmpType uint8) bool { + t.mutex.RLock() + defer t.mutex.RUnlock() + + // Always allow Echo Request (type 8 for IPv4, 128 for IPv6) + if icmpType == uint8(layers.ICMPv4TypeEchoRequest) || icmpType == uint8(layers.ICMPv6TypeEchoRequest) { + return true + } + + // For Echo Reply, check if we have a matching request + if icmpType != uint8(layers.ICMPv4TypeEchoReply) && icmpType != uint8(layers.ICMPv6TypeEchoReply) { + return false + } + + key := makeICMPKey(dstIP, srcIP, id, seq) + conn, exists := t.connections[key] + if !exists { + return false + } + + // Check if connection is still valid + if time.Since(conn.LastSeen) > t.timeout { + return false + } + + if conn.established && + conn.DestIP.Equal(srcIP) && + conn.SourceIP.Equal(dstIP) && + conn.ID == id && + conn.Sequence == seq { + + conn.LastSeen = time.Now() + return true + } + + return false +} + +func (t *ICMPTracker) cleanupRoutine() { + for { + select { + case <-t.cleanupTicker.C: + t.cleanup() + case <-t.done: + return + } + } +} + +func (t *ICMPTracker) cleanup() { + t.mutex.Lock() + defer t.mutex.Unlock() + + now := time.Now() + for key, conn := range t.connections { + if now.Sub(conn.LastSeen) > t.timeout { + delete(t.connections, key) + } + } +} + +// Close stops the cleanup routine and releases resources +func (t *ICMPTracker) Close() { + t.cleanupTicker.Stop() + close(t.done) +} + +func makeICMPKey(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) ICMPConnKey { + var srcAddr, dstAddr [16]byte + copy(srcAddr[:], srcIP.To16()) + copy(dstAddr[:], dstIP.To16()) + return ICMPConnKey{ + SrcIP: srcAddr, + DstIP: dstAddr, + ID: id, + Sequence: seq, + } +} diff --git a/client/firewall/uspfilter/conntrack/udp.go b/client/firewall/uspfilter/conntrack/udp.go index 40448ee7852..b4f1b898171 100644 --- a/client/firewall/uspfilter/conntrack/udp.go +++ b/client/firewall/uspfilter/conntrack/udp.go @@ -8,10 +8,10 @@ import ( ) const ( - // DefaultTimeout is the default timeout for UDP connections - DefaultTimeout = 30 * time.Second - // CleanupInterval is how often we check for stale connections - CleanupInterval = 15 * time.Second + // DefaultUDPTimeout is the default timeout for UDP connections + DefaultUDPTimeout = 30 * time.Second + // UDPCleanupInterval is how often we check for stale connections + UDPCleanupInterval = 15 * time.Second ) type ConnKey struct { @@ -44,13 +44,13 @@ type UDPTracker struct { // NewUDPTracker creates a new UDP connection tracker func NewUDPTracker(timeout time.Duration) *UDPTracker { if timeout == 0 { - timeout = DefaultTimeout + timeout = DefaultUDPTimeout } tracker := &UDPTracker{ connections: make(map[ConnKey]*UDPConnTrack), timeout: timeout, - cleanupTicker: time.NewTicker(CleanupInterval), + cleanupTicker: time.NewTicker(UDPCleanupInterval), done: make(chan struct{}), } diff --git a/client/firewall/uspfilter/conntrack/udp_test.go b/client/firewall/uspfilter/conntrack/udp_test.go index a19170c4481..938dc18ea59 100644 --- a/client/firewall/uspfilter/conntrack/udp_test.go +++ b/client/firewall/uspfilter/conntrack/udp_test.go @@ -23,7 +23,7 @@ func TestNewUDPTracker(t *testing.T) { { name: "with zero timeout uses default", timeout: 0, - wantTimeout: DefaultTimeout, + wantTimeout: DefaultUDPTimeout, }, } @@ -40,7 +40,7 @@ func TestNewUDPTracker(t *testing.T) { } func TestUDPTracker_TrackOutbound(t *testing.T) { - tracker := NewUDPTracker(DefaultTimeout) + tracker := NewUDPTracker(DefaultUDPTimeout) defer tracker.Close() srcIP := net.ParseIP("192.168.1.2") @@ -215,7 +215,7 @@ func TestUDPTracker_Cleanup(t *testing.T) { } func TestUDPTracker_Close(t *testing.T) { - tracker := NewUDPTracker(DefaultTimeout) + tracker := NewUDPTracker(DefaultUDPTimeout) // Add a connection tracker.TrackOutbound( diff --git a/client/firewall/uspfilter/uspfilter.go b/client/firewall/uspfilter/uspfilter.go index d0fc3c18000..45fd3b5e0b8 100644 --- a/client/firewall/uspfilter/uspfilter.go +++ b/client/firewall/uspfilter/uspfilter.go @@ -45,8 +45,9 @@ type Manager struct { wgIface IFaceMapper nativeFirewall firewall.Manager - mutex sync.RWMutex - udpTracker *conntrack.UDPTracker + mutex sync.RWMutex + udpTracker *conntrack.UDPTracker + icmpTracker *conntrack.ICMPTracker } // decoder for packages @@ -95,7 +96,8 @@ func create(iface IFaceMapper) (*Manager, error) { outgoingRules: make(map[string]RuleSet), incomingRules: make(map[string]RuleSet), wgIface: iface, - udpTracker: conntrack.NewUDPTracker(udpTimeout), + udpTracker: conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout), + icmpTracker: conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout), } if err := iface.SetFilter(m); err != nil { @@ -264,6 +266,7 @@ func (m *Manager) DropIncoming(packetData []byte) bool { } // processOutgoingHooks processes only UDP hooks for outgoing packets +// processOutgoingHooks processes UDP and ICMP hooks for outgoing packets func (m *Manager) processOutgoingHooks(packetData []byte) bool { m.mutex.RLock() defer m.mutex.RUnlock() @@ -275,7 +278,7 @@ func (m *Manager) processOutgoingHooks(packetData []byte) bool { return false } - if len(d.decoded) < 2 || d.decoded[1] != layers.LayerTypeUDP { + if len(d.decoded) < 2 { return false } @@ -291,23 +294,38 @@ func (m *Manager) processOutgoingHooks(packetData []byte) bool { return false } - // Track outbound UDP connection - m.udpTracker.TrackOutbound( - srcIP, - dstIP, - uint16(d.udp.SrcPort), - uint16(d.udp.DstPort), - ) - - for _, ipKey := range []string{dstIP.String(), "0.0.0.0", "::"} { - if rules, exists := m.outgoingRules[ipKey]; exists { - for _, rule := range rules { - if rule.udpHook != nil && (rule.dPort == 0 || rule.dPort == uint16(d.udp.DstPort)) { - return rule.udpHook(packetData) + switch d.decoded[1] { + case layers.LayerTypeUDP: + // Track outbound UDP connection + m.udpTracker.TrackOutbound( + srcIP, + dstIP, + uint16(d.udp.SrcPort), + uint16(d.udp.DstPort), + ) + + for _, ipKey := range []string{dstIP.String(), "0.0.0.0", "::"} { + if rules, exists := m.outgoingRules[ipKey]; exists { + for _, rule := range rules { + if rule.udpHook != nil && (rule.dPort == 0 || rule.dPort == uint16(d.udp.DstPort)) { + return rule.udpHook(packetData) + } } } } + + case layers.LayerTypeICMPv4: + // Track outbound ICMP Echo Request + if d.icmp4.TypeCode.Type() == layers.ICMPv4TypeEchoRequest { + m.icmpTracker.TrackOutbound( + srcIP, + dstIP, + d.icmp4.Id, + d.icmp4.Seq, + ) + } } + return false } @@ -329,18 +347,26 @@ func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet) bool { return true } - // For UDP inbound packets, check if they match tracked connections - if d.decoded[1] == layers.LayerTypeUDP { - var srcIP, dstIP net.IP - switch d.decoded[0] { - case layers.LayerTypeIPv4: - srcIP = d.ip4.SrcIP - dstIP = d.ip4.DstIP - case layers.LayerTypeIPv6: - srcIP = d.ip6.SrcIP - dstIP = d.ip6.DstIP - } + var srcIP, dstIP net.IP + switch d.decoded[0] { + case layers.LayerTypeIPv4: + srcIP = d.ip4.SrcIP + dstIP = d.ip4.DstIP + case layers.LayerTypeIPv6: + srcIP = d.ip6.SrcIP + dstIP = d.ip6.DstIP + default: + log.Errorf("unknown layer: %v", d.decoded[0]) + return true + } + if !m.wgNetwork.Contains(srcIP) || !m.wgNetwork.Contains(dstIP) { + return false + } + + switch d.decoded[1] { + case layers.LayerTypeUDP: + // Check if inbound UDP packet matches a tracked connection if m.udpTracker.IsValidInbound( srcIP, dstIP, @@ -349,41 +375,33 @@ func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet) bool { ) { return false } - } - ipLayer := d.decoded[0] - - switch ipLayer { - case layers.LayerTypeIPv4: - if !m.wgNetwork.Contains(d.ip4.SrcIP) || !m.wgNetwork.Contains(d.ip4.DstIP) { - return false - } - case layers.LayerTypeIPv6: - if !m.wgNetwork.Contains(d.ip6.SrcIP) || !m.wgNetwork.Contains(d.ip6.DstIP) { + case layers.LayerTypeICMPv4: + // Check if inbound ICMP packet is valid + if m.icmpTracker.IsValidInbound( + srcIP, + dstIP, + uint16(d.icmp4.Id), + uint16(d.icmp4.Seq), + uint8(d.icmp4.TypeCode.Type()), + ) { return false } - default: - log.Errorf("unknown layer: %v", d.decoded[0]) - return true - } - var ip net.IP - switch ipLayer { - case layers.LayerTypeIPv4: - ip = d.ip4.SrcIP - case layers.LayerTypeIPv6: - ip = d.ip6.SrcIP + // TODO: Handle icmpv6 + // TODO: Handle icmp destination unreachable and others + } - filter, ok := validateRule(ip, packetData, rules[ip.String()], d) + filter, ok := validateRule(srcIP, packetData, rules[srcIP.String()], d) if ok { return filter } - filter, ok = validateRule(ip, packetData, rules["0.0.0.0"], d) + filter, ok = validateRule(srcIP, packetData, rules["0.0.0.0"], d) if ok { return filter } - filter, ok = validateRule(ip, packetData, rules["::"], d) + filter, ok = validateRule(srcIP, packetData, rules["::"], d) if ok { return filter } From f0c8c90a6279f74058b969957ac07a89f12720ff Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Fri, 20 Dec 2024 23:56:11 +0100 Subject: [PATCH 12/28] Reset icmp tracker --- client/firewall/uspfilter/allow_netbird.go | 7 ++++++- client/firewall/uspfilter/allow_netbird_windows.go | 7 ++++++- client/firewall/uspfilter/conntrack/tcp.go | 1 + 3 files changed, 13 insertions(+), 2 deletions(-) create mode 100644 client/firewall/uspfilter/conntrack/tcp.go diff --git a/client/firewall/uspfilter/allow_netbird.go b/client/firewall/uspfilter/allow_netbird.go index f5ca6ba286c..9b4f77440cc 100644 --- a/client/firewall/uspfilter/allow_netbird.go +++ b/client/firewall/uspfilter/allow_netbird.go @@ -17,7 +17,12 @@ func (m *Manager) Reset(stateManager *statemanager.Manager) error { if m.udpTracker != nil { m.udpTracker.Close() - m.udpTracker = conntrack.NewUDPTracker(udpTimeout) + m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout) + } + + if m.icmpTracker != nil { + m.icmpTracker.Close() + m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout) } if m.nativeFirewall != nil { diff --git a/client/firewall/uspfilter/allow_netbird_windows.go b/client/firewall/uspfilter/allow_netbird_windows.go index ff9513cb137..7100690e7ff 100644 --- a/client/firewall/uspfilter/allow_netbird_windows.go +++ b/client/firewall/uspfilter/allow_netbird_windows.go @@ -29,7 +29,12 @@ func (m *Manager) Reset(*statemanager.Manager) error { if m.udpTracker != nil { m.udpTracker.Close() - m.udpTracker = conntrack.NewUDPTracker(udpTimeout) + m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout) + } + + if m.icmpTracker != nil { + m.icmpTracker.Close() + m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout) } if !isWindowsFirewallReachable() { diff --git a/client/firewall/uspfilter/conntrack/tcp.go b/client/firewall/uspfilter/conntrack/tcp.go new file mode 100644 index 00000000000..229e60c6704 --- /dev/null +++ b/client/firewall/uspfilter/conntrack/tcp.go @@ -0,0 +1 @@ +package conntrack From dadf64ed1558d8d5928ac67cc910ef617d427344 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Fri, 20 Dec 2024 21:19:28 +0100 Subject: [PATCH 13/28] Clean up icmp tracker --- client/firewall/uspfilter/uspfilter.go | 3 --- 1 file changed, 3 deletions(-) diff --git a/client/firewall/uspfilter/uspfilter.go b/client/firewall/uspfilter/uspfilter.go index 45fd3b5e0b8..f996b69c300 100644 --- a/client/firewall/uspfilter/uspfilter.go +++ b/client/firewall/uspfilter/uspfilter.go @@ -5,7 +5,6 @@ import ( "net" "net/netip" "sync" - "time" "github.com/google/gopacket" "github.com/google/gopacket/layers" @@ -21,8 +20,6 @@ import ( const layerTypeAll = 0 -const udpTimeout = 30 * time.Second - var ( errRouteNotSupported = fmt.Errorf("route not supported with userspace firewall") ) From fa38d8ec7fee6af4059410545a254ac44d71ebd6 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Sat, 21 Dec 2024 00:03:51 +0100 Subject: [PATCH 14/28] Add TCP tracker --- client/firewall/uspfilter/allow_netbird.go | 5 + .../uspfilter/allow_netbird_windows.go | 5 + client/firewall/uspfilter/conntrack/tcp.go | 310 ++++++++++++++++++ client/firewall/uspfilter/uspfilter.go | 214 +++++++----- 4 files changed, 459 insertions(+), 75 deletions(-) diff --git a/client/firewall/uspfilter/allow_netbird.go b/client/firewall/uspfilter/allow_netbird.go index 9b4f77440cc..cc07922559d 100644 --- a/client/firewall/uspfilter/allow_netbird.go +++ b/client/firewall/uspfilter/allow_netbird.go @@ -25,6 +25,11 @@ func (m *Manager) Reset(stateManager *statemanager.Manager) error { m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout) } + if m.tcpTracker != nil { + m.tcpTracker.Close() + m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout) + } + if m.nativeFirewall != nil { return m.nativeFirewall.Reset(stateManager) } diff --git a/client/firewall/uspfilter/allow_netbird_windows.go b/client/firewall/uspfilter/allow_netbird_windows.go index 7100690e7ff..0d55d62689c 100644 --- a/client/firewall/uspfilter/allow_netbird_windows.go +++ b/client/firewall/uspfilter/allow_netbird_windows.go @@ -37,6 +37,11 @@ func (m *Manager) Reset(*statemanager.Manager) error { m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout) } + if m.tcpTracker != nil { + m.tcpTracker.Close() + m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout) + } + if !isWindowsFirewallReachable() { return nil } diff --git a/client/firewall/uspfilter/conntrack/tcp.go b/client/firewall/uspfilter/conntrack/tcp.go index 229e60c6704..a4eef164e14 100644 --- a/client/firewall/uspfilter/conntrack/tcp.go +++ b/client/firewall/uspfilter/conntrack/tcp.go @@ -1 +1,311 @@ package conntrack + +// TODO: Send RST packets for invalid/timed-out connections + +import ( + "net" + "slices" + "sync" + "time" +) + +const ( + // MSL (Maximum Segment Lifetime) is typically 2 minutes + MSL = 2 * time.Minute + // TimeWaitTimeout (TIME-WAIT) should last 2*MSL + TimeWaitTimeout = 2 * MSL +) + +const ( + TCPSyn uint8 = 0x02 + TCPAck uint8 = 0x10 + TCPFin uint8 = 0x01 + TCPRst uint8 = 0x04 + TCPPush uint8 = 0x08 + TCPUrg uint8 = 0x20 +) + +const ( + // DefaultTCPTimeout is the default timeout for established TCP connections + DefaultTCPTimeout = 3 * time.Hour + // TCPHandshakeTimeout is timeout for TCP handshake completion + TCPHandshakeTimeout = 60 * time.Second + // TCPCleanupInterval is how often we check for stale connections + TCPCleanupInterval = 5 * time.Minute +) + +// TCPState represents the state of a TCP connection +type TCPState int + +const ( + TCPStateNew TCPState = iota + TCPStateSynSent + TCPStateSynReceived + TCPStateEstablished + TCPStateFinWait1 + TCPStateFinWait2 + TCPStateClosing + TCPStateTimeWait + TCPStateCloseWait + TCPStateLastAck + TCPStateClosed +) + +// TCPConnKey uniquely identifies a TCP connection +type TCPConnKey struct { + SrcIP [16]byte + DstIP [16]byte + SrcPort uint16 + DstPort uint16 +} + +// TCPConnTrack represents a TCP connection state +type TCPConnTrack struct { + SourceIP net.IP + DestIP net.IP + SourcePort uint16 + DestPort uint16 + State TCPState + LastSeen time.Time + established bool +} + +// TCPTracker manages TCP connection states +type TCPTracker struct { + connections map[TCPConnKey]*TCPConnTrack + mutex sync.RWMutex + cleanupTicker *time.Ticker + done chan struct{} + timeout time.Duration +} + +// NewTCPTracker creates a new TCP connection tracker +func NewTCPTracker(timeout time.Duration) *TCPTracker { + tracker := &TCPTracker{ + connections: make(map[TCPConnKey]*TCPConnTrack), + cleanupTicker: time.NewTicker(TCPCleanupInterval), + done: make(chan struct{}), + timeout: timeout, + } + + go tracker.cleanupRoutine() + return tracker +} + +// TrackOutbound processes an outbound TCP packet and updates connection state +func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) { + t.mutex.Lock() + defer t.mutex.Unlock() + + key := makeTCPKey(srcIP, dstIP, srcPort, dstPort) + now := time.Now() + + conn, exists := t.connections[key] + if !exists { + conn = &TCPConnTrack{ + SourceIP: slices.Clone(srcIP), + DestIP: slices.Clone(dstIP), + SourcePort: srcPort, + DestPort: dstPort, + State: TCPStateNew, + LastSeen: now, + established: false, + } + t.connections[key] = conn + } + + // Update connection state based on TCP flags + t.updateState(conn, flags, true) + conn.LastSeen = now +} + +// IsValidInbound checks if an inbound TCP packet matches a tracked connection +func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) bool { + t.mutex.Lock() + defer t.mutex.Unlock() + + // For SYN packets (new connection attempts), always allow + if flags&TCPSyn != 0 && flags&TCPAck == 0 { + key := makeTCPKey(dstIP, srcIP, dstPort, srcPort) + t.connections[key] = &TCPConnTrack{ + SourceIP: slices.Clone(dstIP), + DestIP: slices.Clone(srcIP), + SourcePort: dstPort, + DestPort: srcPort, + State: TCPStateSynReceived, + LastSeen: time.Now(), + established: false, + } + return true + } + + key := makeTCPKey(dstIP, srcIP, dstPort, srcPort) + conn, exists := t.connections[key] + if !exists { + return false + } + + // Update state and check validity + if flags&TCPRst != 0 { + conn.State = TCPStateClosed + conn.established = false + return true + } + + // Special handling for FIN state + if conn.State == TCPStateFinWait1 || conn.State == TCPStateFinWait2 { + t.updateState(conn, flags, false) + conn.LastSeen = time.Now() + return true + } + + t.updateState(conn, flags, false) + conn.LastSeen = time.Now() + + // Allow if established or in a valid state for the flags + return conn.established || t.isValidStateForFlags(conn.State, flags) +} + +// updateState updates the TCP connection state based on flags +func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound bool) { + // Handle RST flag specially - it always causes transition to closed + if flags&TCPRst != 0 { + conn.State = TCPStateClosed + conn.established = false + return + } + + switch conn.State { + case TCPStateNew: + if flags&TCPSyn != 0 && flags&TCPAck == 0 { + conn.State = TCPStateSynSent + } + + case TCPStateSynSent: + if flags&TCPSyn != 0 && flags&TCPAck != 0 { + if isOutbound { + conn.State = TCPStateSynReceived + } else { + // Simultaneous open + conn.State = TCPStateEstablished + conn.established = true + } + } + + case TCPStateSynReceived: + if flags&TCPAck != 0 && flags&TCPSyn == 0 { + conn.State = TCPStateEstablished + conn.established = true + } + + case TCPStateEstablished: + if flags&TCPFin != 0 { + if isOutbound { + conn.State = TCPStateFinWait1 + } else { + conn.State = TCPStateCloseWait + } + conn.established = false + } + + case TCPStateFinWait1: + if flags&TCPFin != 0 && flags&TCPAck != 0 { + // Simultaneous close + conn.State = TCPStateClosing + } else if flags&TCPFin != 0 { + conn.State = TCPStateFinWait2 + } else if flags&TCPAck != 0 { + conn.State = TCPStateFinWait2 + } + + case TCPStateFinWait2: + if flags&TCPFin != 0 { + conn.State = TCPStateTimeWait + } + + case TCPStateClosing: + if flags&TCPAck != 0 { + conn.State = TCPStateTimeWait + } + + case TCPStateCloseWait: + if flags&TCPFin != 0 { + conn.State = TCPStateLastAck + } + + case TCPStateLastAck: + if flags&TCPAck != 0 { + conn.State = TCPStateClosed + } + + case TCPStateTimeWait: + // Stay in TIME-WAIT for 2MSL before transitioning to closed + // This is handled by the cleanup routine + } +} + +// isValidStateForFlags checks if the TCP flags are valid for the current connection state +func (t *TCPTracker) isValidStateForFlags(state TCPState, flags uint8) bool { + switch state { + case TCPStateSynSent: + return flags&TCPSyn != 0 && flags&TCPAck != 0 + case TCPStateSynReceived: + return flags&TCPAck != 0 + case TCPStateEstablished: + return true // Allow all flags in established state + case TCPStateFinWait1, TCPStateFinWait2: + return flags&TCPFin != 0 || flags&TCPAck != 0 + } + return false +} + +func (t *TCPTracker) cleanupRoutine() { + for { + select { + case <-t.cleanupTicker.C: + t.cleanup() + case <-t.done: + return + } + } +} + +func (t *TCPTracker) cleanup() { + t.mutex.Lock() + defer t.mutex.Unlock() + + now := time.Now() + for key, conn := range t.connections { + var timeout time.Duration + switch { + case conn.State == TCPStateTimeWait: + timeout = TimeWaitTimeout + case conn.established: + timeout = t.timeout + default: + timeout = TCPHandshakeTimeout + } + + if now.Sub(conn.LastSeen) > timeout { + delete(t.connections, key) + } + } +} + +// Close stops the cleanup routine and releases resources +func (t *TCPTracker) Close() { + t.cleanupTicker.Stop() + close(t.done) +} + +func makeTCPKey(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) TCPConnKey { + var srcAddr, dstAddr [16]byte + copy(srcAddr[:], srcIP.To16()) + copy(dstAddr[:], dstIP.To16()) + return TCPConnKey{ + SrcIP: srcAddr, + DstIP: dstAddr, + SrcPort: srcPort, + DstPort: dstPort, + } +} diff --git a/client/firewall/uspfilter/uspfilter.go b/client/firewall/uspfilter/uspfilter.go index f996b69c300..b07a7ef14e7 100644 --- a/client/firewall/uspfilter/uspfilter.go +++ b/client/firewall/uspfilter/uspfilter.go @@ -45,6 +45,7 @@ type Manager struct { mutex sync.RWMutex udpTracker *conntrack.UDPTracker icmpTracker *conntrack.ICMPTracker + tcpTracker *conntrack.TCPTracker } // decoder for packages @@ -95,6 +96,7 @@ func create(iface IFaceMapper) (*Manager, error) { wgIface: iface, udpTracker: conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout), icmpTracker: conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout), + tcpTracker: conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout), } if err := iface.SetFilter(m); err != nil { @@ -262,8 +264,7 @@ func (m *Manager) DropIncoming(packetData []byte) bool { return m.dropFilter(packetData, m.incomingRules) } -// processOutgoingHooks processes only UDP hooks for outgoing packets -// processOutgoingHooks processes UDP and ICMP hooks for outgoing packets +// processOutgoingHooks processes UDP hooks for outgoing packets and tracks TCP/UDP/ICMP func (m *Manager) processOutgoingHooks(packetData []byte) bool { m.mutex.RLock() defer m.mutex.RUnlock() @@ -279,53 +280,102 @@ func (m *Manager) processOutgoingHooks(packetData []byte) bool { return false } - var srcIP, dstIP net.IP + srcIP, dstIP := m.extractIPs(d) + if srcIP == nil { + return false + } + + switch d.decoded[1] { + case layers.LayerTypeTCP: + m.trackTCPOutbound(d, srcIP, dstIP) + case layers.LayerTypeUDP: + m.trackUDPOutbound(d, srcIP, dstIP) + return m.checkUDPHooks(d, dstIP, packetData) + case layers.LayerTypeICMPv4: + m.trackICMPOutbound(d, srcIP, dstIP) + } + + return false +} + +func (m *Manager) extractIPs(d *decoder) (srcIP, dstIP net.IP) { switch d.decoded[0] { case layers.LayerTypeIPv4: - srcIP = d.ip4.SrcIP - dstIP = d.ip4.DstIP + return d.ip4.SrcIP, d.ip4.DstIP case layers.LayerTypeIPv6: - srcIP = d.ip6.SrcIP - dstIP = d.ip6.DstIP + return d.ip6.SrcIP, d.ip6.DstIP default: - return false + return nil, nil } +} - switch d.decoded[1] { - case layers.LayerTypeUDP: - // Track outbound UDP connection - m.udpTracker.TrackOutbound( - srcIP, - dstIP, - uint16(d.udp.SrcPort), - uint16(d.udp.DstPort), - ) +func (m *Manager) trackTCPOutbound(d *decoder, srcIP, dstIP net.IP) { + flags := getTCPFlags(&d.tcp) + m.tcpTracker.TrackOutbound( + srcIP, + dstIP, + uint16(d.tcp.SrcPort), + uint16(d.tcp.DstPort), + flags, + ) +} - for _, ipKey := range []string{dstIP.String(), "0.0.0.0", "::"} { - if rules, exists := m.outgoingRules[ipKey]; exists { - for _, rule := range rules { - if rule.udpHook != nil && (rule.dPort == 0 || rule.dPort == uint16(d.udp.DstPort)) { - return rule.udpHook(packetData) - } +func getTCPFlags(tcp *layers.TCP) uint8 { + var flags uint8 + if tcp.SYN { + flags |= conntrack.TCPSyn + } + if tcp.ACK { + flags |= conntrack.TCPAck + } + if tcp.FIN { + flags |= conntrack.TCPFin + } + if tcp.RST { + flags |= conntrack.TCPRst + } + if tcp.PSH { + flags |= conntrack.TCPPush + } + if tcp.URG { + flags |= conntrack.TCPUrg + } + return flags +} + +func (m *Manager) trackUDPOutbound(d *decoder, srcIP, dstIP net.IP) { + m.udpTracker.TrackOutbound( + srcIP, + dstIP, + uint16(d.udp.SrcPort), + uint16(d.udp.DstPort), + ) +} + +func (m *Manager) checkUDPHooks(d *decoder, dstIP net.IP, packetData []byte) bool { + for _, ipKey := range []string{dstIP.String(), "0.0.0.0", "::"} { + if rules, exists := m.outgoingRules[ipKey]; exists { + for _, rule := range rules { + if rule.udpHook != nil && (rule.dPort == 0 || rule.dPort == uint16(d.udp.DstPort)) { + return rule.udpHook(packetData) } } } - - case layers.LayerTypeICMPv4: - // Track outbound ICMP Echo Request - if d.icmp4.TypeCode.Type() == layers.ICMPv4TypeEchoRequest { - m.icmpTracker.TrackOutbound( - srcIP, - dstIP, - d.icmp4.Id, - d.icmp4.Seq, - ) - } } - return false } +func (m *Manager) trackICMPOutbound(d *decoder, srcIP, dstIP net.IP) { + if d.icmp4.TypeCode.Type() == layers.ICMPv4TypeEchoRequest { + m.icmpTracker.TrackOutbound( + srcIP, + dstIP, + d.icmp4.Id, + d.icmp4.Seq, + ) + } +} + // dropFilter implements filtering logic for incoming packets func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet) bool { m.mutex.RLock() @@ -334,76 +384,90 @@ func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet) bool { d := m.decoders.Get().(*decoder) defer m.decoders.Put(d) - if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { - log.Tracef("couldn't decode layer, err: %s", err) + if !m.isValidPacket(d, packetData) { return true } - if len(d.decoded) < 2 { - log.Tracef("not enough levels in network packet") + srcIP, dstIP := m.extractIPs(d) + if srcIP == nil { + log.Errorf("unknown layer: %v", d.decoded[0]) return true } - var srcIP, dstIP net.IP - switch d.decoded[0] { - case layers.LayerTypeIPv4: - srcIP = d.ip4.SrcIP - dstIP = d.ip4.DstIP - case layers.LayerTypeIPv6: - srcIP = d.ip6.SrcIP - dstIP = d.ip6.DstIP - default: - log.Errorf("unknown layer: %v", d.decoded[0]) - return true + if !m.isWireguardTraffic(srcIP, dstIP) { + return false + } + + if m.isValidTrackedConnection(d, srcIP, dstIP) { + return false + } + + return m.applyRules(srcIP, packetData, rules, d) +} + +func (m *Manager) isValidPacket(d *decoder, packetData []byte) bool { + if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { + log.Tracef("couldn't decode layer, err: %s", err) + return false } - if !m.wgNetwork.Contains(srcIP) || !m.wgNetwork.Contains(dstIP) { + if len(d.decoded) < 2 { + log.Tracef("not enough levels in network packet") return false } + return true +} + +func (m *Manager) isWireguardTraffic(srcIP, dstIP net.IP) bool { + return m.wgNetwork.Contains(srcIP) && m.wgNetwork.Contains(dstIP) +} +func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP) bool { switch d.decoded[1] { + case layers.LayerTypeTCP: + return m.tcpTracker.IsValidInbound( + srcIP, + dstIP, + uint16(d.tcp.SrcPort), + uint16(d.tcp.DstPort), + getTCPFlags(&d.tcp), + ) + case layers.LayerTypeUDP: - // Check if inbound UDP packet matches a tracked connection - if m.udpTracker.IsValidInbound( + return m.udpTracker.IsValidInbound( srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), - ) { - return false - } + ) case layers.LayerTypeICMPv4: - // Check if inbound ICMP packet is valid - if m.icmpTracker.IsValidInbound( + return m.icmpTracker.IsValidInbound( srcIP, dstIP, - uint16(d.icmp4.Id), - uint16(d.icmp4.Seq), - uint8(d.icmp4.TypeCode.Type()), - ) { - return false - } - - // TODO: Handle icmpv6 - // TODO: Handle icmp destination unreachable and others - + d.icmp4.Id, + d.icmp4.Seq, + d.icmp4.TypeCode.Type(), + ) } - filter, ok := validateRule(srcIP, packetData, rules[srcIP.String()], d) - if ok { + return false +} + +func (m *Manager) applyRules(srcIP net.IP, packetData []byte, rules map[string]RuleSet, d *decoder) bool { + if filter, ok := validateRule(srcIP, packetData, rules[srcIP.String()], d); ok { return filter } - filter, ok = validateRule(srcIP, packetData, rules["0.0.0.0"], d) - if ok { + + if filter, ok := validateRule(srcIP, packetData, rules["0.0.0.0"], d); ok { return filter } - filter, ok = validateRule(srcIP, packetData, rules["::"], d) - if ok { + + if filter, ok := validateRule(srcIP, packetData, rules["::"], d); ok { return filter } - // default policy is DROP ALL + // Default policy: DROP ALL return true } From 0970b75b0646ff7e8e899604a86a518f0fc6af85 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Sat, 21 Dec 2024 01:00:50 +0100 Subject: [PATCH 15/28] Use switch --- client/firewall/uspfilter/conntrack/tcp.go | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/client/firewall/uspfilter/conntrack/tcp.go b/client/firewall/uspfilter/conntrack/tcp.go index a4eef164e14..7679f102918 100644 --- a/client/firewall/uspfilter/conntrack/tcp.go +++ b/client/firewall/uspfilter/conntrack/tcp.go @@ -209,12 +209,13 @@ func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound boo } case TCPStateFinWait1: - if flags&TCPFin != 0 && flags&TCPAck != 0 { + switch { + case flags&TCPFin != 0 && flags&TCPAck != 0: // Simultaneous close conn.State = TCPStateClosing - } else if flags&TCPFin != 0 { + case flags&TCPFin != 0: conn.State = TCPStateFinWait2 - } else if flags&TCPAck != 0 { + case flags&TCPAck != 0: conn.State = TCPStateFinWait2 } From c6aeb48d983d145d1fc093127fdb30cae8eb626f Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Sat, 21 Dec 2024 12:59:13 +0100 Subject: [PATCH 16/28] Move locks further down --- client/firewall/uspfilter/conntrack/icmp.go | 4 ++-- client/firewall/uspfilter/conntrack/tcp.go | 5 +++-- client/firewall/uspfilter/conntrack/udp.go | 4 ++-- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/client/firewall/uspfilter/conntrack/icmp.go b/client/firewall/uspfilter/conntrack/icmp.go index 1968ef6b951..2d95206a658 100644 --- a/client/firewall/uspfilter/conntrack/icmp.go +++ b/client/firewall/uspfilter/conntrack/icmp.go @@ -63,11 +63,11 @@ func NewICMPTracker(timeout time.Duration) *ICMPTracker { // TrackOutbound records an outbound ICMP Echo Request func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) { + key := makeICMPKey(srcIP, dstIP, id, seq) + t.mutex.Lock() defer t.mutex.Unlock() - key := makeICMPKey(srcIP, dstIP, id, seq) - t.connections[key] = &ICMPConnTrack{ SourceIP: slices.Clone(srcIP), DestIP: slices.Clone(dstIP), diff --git a/client/firewall/uspfilter/conntrack/tcp.go b/client/firewall/uspfilter/conntrack/tcp.go index 7679f102918..5664191fd2b 100644 --- a/client/firewall/uspfilter/conntrack/tcp.go +++ b/client/firewall/uspfilter/conntrack/tcp.go @@ -94,12 +94,13 @@ func NewTCPTracker(timeout time.Duration) *TCPTracker { // TrackOutbound processes an outbound TCP packet and updates connection state func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) { - t.mutex.Lock() - defer t.mutex.Unlock() key := makeTCPKey(srcIP, dstIP, srcPort, dstPort) now := time.Now() + t.mutex.Lock() + defer t.mutex.Unlock() + conn, exists := t.connections[key] if !exists { conn = &TCPConnTrack{ diff --git a/client/firewall/uspfilter/conntrack/udp.go b/client/firewall/uspfilter/conntrack/udp.go index b4f1b898171..94ab9e15273 100644 --- a/client/firewall/uspfilter/conntrack/udp.go +++ b/client/firewall/uspfilter/conntrack/udp.go @@ -60,11 +60,11 @@ func NewUDPTracker(timeout time.Duration) *UDPTracker { // TrackOutbound records an outbound UDP connection func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) { + key := makeKey(srcIP, srcPort, dstIP, dstPort) + t.mutex.Lock() defer t.mutex.Unlock() - key := makeKey(srcIP, srcPort, dstIP, dstPort) - t.connections[key] = &UDPConnTrack{ SourceIP: slices.Clone(srcIP), DestIP: slices.Clone(dstIP), From b9767d498d613a1993134f2522c42139290d1402 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Sat, 21 Dec 2024 20:59:08 +0100 Subject: [PATCH 17/28] Generally allow time exceeded and destination unreachable, disallow echo request --- client/firewall/uspfilter/conntrack/icmp.go | 18 ++++++++++-------- client/firewall/uspfilter/uspfilter.go | 2 ++ 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/client/firewall/uspfilter/conntrack/icmp.go b/client/firewall/uspfilter/conntrack/icmp.go index 2d95206a658..9b76c370a46 100644 --- a/client/firewall/uspfilter/conntrack/icmp.go +++ b/client/firewall/uspfilter/conntrack/icmp.go @@ -80,20 +80,22 @@ func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq u // IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request func (t *ICMPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, icmpType uint8) bool { - t.mutex.RLock() - defer t.mutex.RUnlock() - - // Always allow Echo Request (type 8 for IPv4, 128 for IPv6) - if icmpType == uint8(layers.ICMPv4TypeEchoRequest) || icmpType == uint8(layers.ICMPv6TypeEchoRequest) { + switch icmpType { + // For Destination Unreachable and Time Exceeded, always allow + case uint8(layers.ICMPv4TypeDestinationUnreachable), uint8(layers.ICMPv4TypeTimeExceeded): return true - } - // For Echo Reply, check if we have a matching request - if icmpType != uint8(layers.ICMPv4TypeEchoReply) && icmpType != uint8(layers.ICMPv6TypeEchoReply) { + case uint8(layers.ICMPv4TypeEchoReply): + // continue further down + default: return false } key := makeICMPKey(dstIP, srcIP, id, seq) + + t.mutex.RLock() + defer t.mutex.RUnlock() + conn, exists := t.connections[key] if !exists { return false diff --git a/client/firewall/uspfilter/uspfilter.go b/client/firewall/uspfilter/uspfilter.go index b07a7ef14e7..b32a96a57c6 100644 --- a/client/firewall/uspfilter/uspfilter.go +++ b/client/firewall/uspfilter/uspfilter.go @@ -449,6 +449,8 @@ func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP) bool d.icmp4.Seq, d.icmp4.TypeCode.Type(), ) + + // TODO: ICMPv6 } return false From 8661eadad0411aa6e82ea0550a2aa41c004874da Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Sat, 21 Dec 2024 22:49:57 +0100 Subject: [PATCH 18/28] Add env to disable statefulness --- client/firewall/uspfilter/uspfilter.go | 49 +++++++++++++++++++------- 1 file changed, 37 insertions(+), 12 deletions(-) diff --git a/client/firewall/uspfilter/uspfilter.go b/client/firewall/uspfilter/uspfilter.go index b32a96a57c6..24cfd6e9691 100644 --- a/client/firewall/uspfilter/uspfilter.go +++ b/client/firewall/uspfilter/uspfilter.go @@ -4,6 +4,8 @@ import ( "fmt" "net" "net/netip" + "os" + "strconv" "sync" "github.com/google/gopacket" @@ -20,6 +22,8 @@ import ( const layerTypeAll = 0 +const EnvDisableConntrack = "NB_DISABLE_CONNTRACK" + var ( errRouteNotSupported = fmt.Errorf("route not supported with userspace firewall") ) @@ -42,7 +46,9 @@ type Manager struct { wgIface IFaceMapper nativeFirewall firewall.Manager - mutex sync.RWMutex + mutex sync.RWMutex + + stateful bool udpTracker *conntrack.UDPTracker icmpTracker *conntrack.ICMPTracker tcpTracker *conntrack.TCPTracker @@ -77,6 +83,8 @@ func CreateWithNativeFirewall(iface IFaceMapper, nativeFirewall firewall.Manager } func create(iface IFaceMapper) (*Manager, error) { + disableConntrack, _ := strconv.ParseBool(os.Getenv(EnvDisableConntrack)) + m := &Manager{ decoders: sync.Pool{ New: func() any { @@ -94,9 +102,16 @@ func create(iface IFaceMapper) (*Manager, error) { outgoingRules: make(map[string]RuleSet), incomingRules: make(map[string]RuleSet), wgIface: iface, - udpTracker: conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout), - icmpTracker: conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout), - tcpTracker: conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout), + stateful: !disableConntrack, + } + + // Only initialize trackers if stateful mode is enabled + if disableConntrack { + log.Info("conntrack is disabled") + } else { + m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout) + m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout) + m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout) } if err := iface.SetFilter(m); err != nil { @@ -285,14 +300,23 @@ func (m *Manager) processOutgoingHooks(packetData []byte) bool { return false } - switch d.decoded[1] { - case layers.LayerTypeTCP: - m.trackTCPOutbound(d, srcIP, dstIP) - case layers.LayerTypeUDP: - m.trackUDPOutbound(d, srcIP, dstIP) + // Always process UDP hooks + if d.decoded[1] == layers.LayerTypeUDP { + // Track UDP state only if enabled + if m.stateful { + m.trackUDPOutbound(d, srcIP, dstIP) + } return m.checkUDPHooks(d, dstIP, packetData) - case layers.LayerTypeICMPv4: - m.trackICMPOutbound(d, srcIP, dstIP) + } + + // Track other protocols only if stateful mode is enabled + if m.stateful { + switch d.decoded[1] { + case layers.LayerTypeTCP: + m.trackTCPOutbound(d, srcIP, dstIP) + case layers.LayerTypeICMPv4: + m.trackICMPOutbound(d, srcIP, dstIP) + } } return false @@ -398,7 +422,8 @@ func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet) bool { return false } - if m.isValidTrackedConnection(d, srcIP, dstIP) { + // Check connection state only if enabled + if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP) { return false } From 1306da2f910d82a9392bc0eaec975302cff14f3a Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Sat, 21 Dec 2024 23:35:48 +0100 Subject: [PATCH 19/28] Add benchmarks --- .../uspfilter/uspfilter_bench_test.go | 261 ++++++++++++++++++ 1 file changed, 261 insertions(+) create mode 100644 client/firewall/uspfilter/uspfilter_bench_test.go diff --git a/client/firewall/uspfilter/uspfilter_bench_test.go b/client/firewall/uspfilter/uspfilter_bench_test.go new file mode 100644 index 00000000000..6e748c85f2c --- /dev/null +++ b/client/firewall/uspfilter/uspfilter_bench_test.go @@ -0,0 +1,261 @@ +package uspfilter + +import ( + "fmt" + "math/rand" + "net" + "os" + "testing" + + "github.com/google/gopacket" + "github.com/google/gopacket/layers" + + fw "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/iface/device" +) + +// generateRandomIPs generates n different random IPs in the 100.64.0.0/10 range +func generateRandomIPs(n int) []net.IP { + ips := make([]net.IP, n) + seen := make(map[string]bool) + + for i := 0; i < n; { + ip := make(net.IP, 4) + ip[0] = 100 + ip[1] = byte(64 + rand.Intn(63)) // 64-126 + ip[2] = byte(rand.Intn(256)) + ip[3] = byte(1 + rand.Intn(254)) // avoid .0 and .255 + + key := ip.String() + if !seen[key] { + ips[i] = ip + seen[key] = true + i++ + } + } + return ips +} + +func generatePacket(srcIP, dstIP net.IP, srcPort, dstPort uint16, protocol layers.IPProtocol) []byte { + ipv4 := &layers.IPv4{ + TTL: 64, + Version: 4, + SrcIP: srcIP, + DstIP: dstIP, + Protocol: protocol, + } + + var transportLayer gopacket.SerializableLayer + switch protocol { + case layers.IPProtocolTCP: + tcp := &layers.TCP{ + SrcPort: layers.TCPPort(srcPort), + DstPort: layers.TCPPort(dstPort), + SYN: true, + } + tcp.SetNetworkLayerForChecksum(ipv4) + transportLayer = tcp + case layers.IPProtocolUDP: + udp := &layers.UDP{ + SrcPort: layers.UDPPort(srcPort), + DstPort: layers.UDPPort(dstPort), + } + udp.SetNetworkLayerForChecksum(ipv4) + transportLayer = udp + } + + buf := gopacket.NewSerializeBuffer() + opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true} + gopacket.SerializeLayers(buf, opts, ipv4, transportLayer, gopacket.Payload([]byte("test"))) + return buf.Bytes() +} + +// BenchmarkCoreFiltering focuses on the essential performance comparisons between +// stateful and stateless filtering approaches +func BenchmarkCoreFiltering(b *testing.B) { + scenarios := []struct { + name string + stateful bool + setupFunc func(*Manager) + desc string + }{ + { + name: "stateless_single_allow_all", + stateful: false, + setupFunc: func(m *Manager) { + // Single rule allowing all traffic + m.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolALL, nil, nil, + fw.RuleDirectionIN, fw.ActionAccept, "", "allow all") + }, + desc: "Baseline: Single 'allow all' rule without connection tracking", + }, + { + name: "stateful_no_rules", + stateful: true, + setupFunc: func(m *Manager) { + // No explicit rules - rely purely on connection tracking + }, + desc: "Pure connection tracking without any rules", + }, + { + name: "stateless_explicit_return", + stateful: false, + setupFunc: func(m *Manager) { + // Add explicit rules matching return traffic pattern + for i := 0; i < 1000; i++ { // Simulate realistic ruleset size + ip := generateRandomIPs(1)[0] + m.AddPeerFiltering(ip, fw.ProtocolTCP, + &fw.Port{Values: []int{1024 + i}}, + &fw.Port{Values: []int{80}}, + fw.RuleDirectionIN, fw.ActionAccept, "", "explicit return") + } + }, + desc: "Explicit rules matching return traffic patterns without state", + }, + { + name: "stateful_with_established", + stateful: true, + setupFunc: func(m *Manager) { + // Add some basic rules but rely on state for established connections + m.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, nil, nil, + fw.RuleDirectionIN, fw.ActionDrop, "", "default drop") + }, + desc: "Connection tracking with established connections", + }, + } + + // Test both TCP and UDP + protocols := []struct { + name string + proto layers.IPProtocol + }{ + {"TCP", layers.IPProtocolTCP}, + {"UDP", layers.IPProtocolUDP}, + } + + for _, sc := range scenarios { + for _, proto := range protocols { + b.Run(fmt.Sprintf("%s_%s", sc.name, proto.name), func(b *testing.B) { + // Configure stateful/stateless mode + if !sc.stateful { + os.Setenv("NB_DISABLE_CONNTRACK", "1") + } else { + os.Unsetenv("NB_DISABLE_CONNTRACK") + } + + // Create manager and basic setup + manager, _ := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }) + defer manager.Reset(nil) + + manager.wgNetwork = &net.IPNet{ + IP: net.ParseIP("100.64.0.0"), + Mask: net.CIDRMask(10, 32), + } + + // Apply scenario-specific setup + sc.setupFunc(manager) + + // Generate test packets + srcIP := generateRandomIPs(1)[0] + dstIP := generateRandomIPs(1)[0] + srcPort := uint16(1024 + b.N%60000) + dstPort := uint16(80) + + outbound := generatePacket(srcIP, dstIP, srcPort, dstPort, proto.proto) + inbound := generatePacket(dstIP, srcIP, dstPort, srcPort, proto.proto) + + // For stateful scenarios, establish the connection + if sc.stateful { + manager.processOutgoingHooks(outbound) + } + + // Measure inbound packet processing + b.ResetTimer() + for i := 0; i < b.N; i++ { + manager.dropFilter(inbound, manager.incomingRules) + } + }) + } + } +} + +// BenchmarkStateScaling measures how performance scales with connection table size +func BenchmarkStateScaling(b *testing.B) { + connCounts := []int{100, 1000, 10000, 100000} + + for _, count := range connCounts { + b.Run(fmt.Sprintf("conns_%d", count), func(b *testing.B) { + manager, _ := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }) + defer manager.Reset(nil) + + manager.wgNetwork = &net.IPNet{ + IP: net.ParseIP("100.64.0.0"), + Mask: net.CIDRMask(10, 32), + } + + // Pre-populate connection table + srcIPs := generateRandomIPs(count) + dstIPs := generateRandomIPs(count) + for i := 0; i < count; i++ { + outbound := generatePacket(srcIPs[i], dstIPs[i], + uint16(1024+i), 80, layers.IPProtocolTCP) + manager.processOutgoingHooks(outbound) + } + + // Test packet + testOut := generatePacket(srcIPs[0], dstIPs[0], 1024, 80, layers.IPProtocolTCP) + testIn := generatePacket(dstIPs[0], srcIPs[0], 80, 1024, layers.IPProtocolTCP) + + // First establish our test connection + manager.processOutgoingHooks(testOut) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + manager.dropFilter(testIn, manager.incomingRules) + } + }) + } +} + +// BenchmarkEstablishmentOverhead measures the overhead of connection establishment +func BenchmarkEstablishmentOverhead(b *testing.B) { + scenarios := []struct { + name string + established bool + }{ + {"established", true}, + {"new", false}, + } + + for _, sc := range scenarios { + b.Run(sc.name, func(b *testing.B) { + manager, _ := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }) + defer manager.Reset(nil) + + manager.wgNetwork = &net.IPNet{ + IP: net.ParseIP("100.64.0.0"), + Mask: net.CIDRMask(10, 32), + } + + srcIP := generateRandomIPs(1)[0] + dstIP := generateRandomIPs(1)[0] + outbound := generatePacket(srcIP, dstIP, 1024, 80, layers.IPProtocolTCP) + inbound := generatePacket(dstIP, srcIP, 80, 1024, layers.IPProtocolTCP) + + if sc.established { + manager.processOutgoingHooks(outbound) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + manager.dropFilter(inbound, manager.incomingRules) + } + }) + } +} From 9ff04750a864bc54eb7c27ffb5bc5c93cb840fc1 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Sat, 21 Dec 2024 23:52:19 +0100 Subject: [PATCH 20/28] Add benchmark to compare routed network return traffic handling --- .../uspfilter/uspfilter_bench_test.go | 242 ++++++++++++++++++ 1 file changed, 242 insertions(+) diff --git a/client/firewall/uspfilter/uspfilter_bench_test.go b/client/firewall/uspfilter/uspfilter_bench_test.go index 6e748c85f2c..4c44d076a8f 100644 --- a/client/firewall/uspfilter/uspfilter_bench_test.go +++ b/client/firewall/uspfilter/uspfilter_bench_test.go @@ -5,12 +5,14 @@ import ( "math/rand" "net" "os" + "strings" "testing" "github.com/google/gopacket" "github.com/google/gopacket/layers" fw "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" "github.com/netbirdio/netbird/client/iface/device" ) @@ -259,3 +261,243 @@ func BenchmarkEstablishmentOverhead(b *testing.B) { }) } } + +// BenchmarkRoutedNetworkReturn compares approaches for handling routed network return traffic +func BenchmarkRoutedNetworkReturn(b *testing.B) { + scenarios := []struct { + name string + proto layers.IPProtocol + state string // "new", "established", "post_handshake" (TCP only) + setupFunc func(*Manager) + genPackets func(net.IP, net.IP) ([]byte, []byte) // generates appropriate packets for the scenario + desc string + }{ + { + name: "allow_non_wg_tcp_new", + proto: layers.IPProtocolTCP, + state: "new", + setupFunc: func(m *Manager) { + m.wgNetwork = &net.IPNet{ + IP: net.ParseIP("100.64.0.0"), + Mask: net.CIDRMask(10, 32), + } + os.Setenv("NB_DISABLE_CONNTRACK", "1") + }, + genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { + return generatePacket(srcIP, dstIP, 1024, 80, layers.IPProtocolTCP), + generatePacket(dstIP, srcIP, 80, 1024, layers.IPProtocolTCP) + }, + desc: "Allow non-WG: TCP new connection", + }, + { + name: "allow_non_wg_tcp_established", + proto: layers.IPProtocolTCP, + state: "established", + setupFunc: func(m *Manager) { + m.wgNetwork = &net.IPNet{ + IP: net.ParseIP("100.64.0.0"), + Mask: net.CIDRMask(10, 32), + } + os.Setenv("NB_DISABLE_CONNTRACK", "1") + }, + genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { + // Generate packets with ACK flag for established connection + return generateTCPPacketWithFlags(srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck)), + generateTCPPacketWithFlags(dstIP, srcIP, 80, 1024, uint16(conntrack.TCPAck)) + }, + desc: "Allow non-WG: TCP established connection", + }, + { + name: "allow_non_wg_udp_new", + proto: layers.IPProtocolUDP, + state: "new", + setupFunc: func(m *Manager) { + m.wgNetwork = &net.IPNet{ + IP: net.ParseIP("100.64.0.0"), + Mask: net.CIDRMask(10, 32), + } + os.Setenv("NB_DISABLE_CONNTRACK", "1") + }, + genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { + return generatePacket(srcIP, dstIP, 1024, 80, layers.IPProtocolUDP), + generatePacket(dstIP, srcIP, 80, 1024, layers.IPProtocolUDP) + }, + desc: "Allow non-WG: UDP new connection", + }, + { + name: "allow_non_wg_udp_established", + proto: layers.IPProtocolUDP, + state: "established", + setupFunc: func(m *Manager) { + m.wgNetwork = &net.IPNet{ + IP: net.ParseIP("100.64.0.0"), + Mask: net.CIDRMask(10, 32), + } + os.Setenv("NB_DISABLE_CONNTRACK", "1") + }, + genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { + return generatePacket(srcIP, dstIP, 1024, 80, layers.IPProtocolUDP), + generatePacket(dstIP, srcIP, 80, 1024, layers.IPProtocolUDP) + }, + desc: "Allow non-WG: UDP established connection", + }, + { + name: "stateful_tcp_new", + proto: layers.IPProtocolTCP, + state: "new", + setupFunc: func(m *Manager) { + m.wgNetwork = &net.IPNet{ + IP: net.ParseIP("0.0.0.0"), + Mask: net.CIDRMask(0, 32), + } + os.Unsetenv("NB_DISABLE_CONNTRACK") + }, + genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { + return generatePacket(srcIP, dstIP, 1024, 80, layers.IPProtocolTCP), + generatePacket(dstIP, srcIP, 80, 1024, layers.IPProtocolTCP) + }, + desc: "Stateful: TCP new connection", + }, + { + name: "stateful_tcp_established", + proto: layers.IPProtocolTCP, + state: "established", + setupFunc: func(m *Manager) { + m.wgNetwork = &net.IPNet{ + IP: net.ParseIP("0.0.0.0"), + Mask: net.CIDRMask(0, 32), + } + os.Unsetenv("NB_DISABLE_CONNTRACK") + }, + genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { + // Generate established TCP packets (ACK flag) + return generateTCPPacketWithFlags(srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck)), + generateTCPPacketWithFlags(dstIP, srcIP, 80, 1024, uint16(conntrack.TCPAck)) + }, + desc: "Stateful: TCP established connection", + }, + { + name: "stateful_tcp_post_handshake", + proto: layers.IPProtocolTCP, + state: "post_handshake", + setupFunc: func(m *Manager) { + m.wgNetwork = &net.IPNet{ + IP: net.ParseIP("0.0.0.0"), + Mask: net.CIDRMask(0, 32), + } + os.Unsetenv("NB_DISABLE_CONNTRACK") + }, + genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { + // Generate packets with PSH+ACK flags for data transfer + return generateTCPPacketWithFlags(srcIP, dstIP, 1024, 80, uint16(conntrack.TCPPush|conntrack.TCPAck)), + generateTCPPacketWithFlags(dstIP, srcIP, 80, 1024, uint16(conntrack.TCPPush|conntrack.TCPAck)) + }, + desc: "Stateful: TCP post-handshake data transfer", + }, + { + name: "stateful_udp_new", + proto: layers.IPProtocolUDP, + state: "new", + setupFunc: func(m *Manager) { + m.wgNetwork = &net.IPNet{ + IP: net.ParseIP("0.0.0.0"), + Mask: net.CIDRMask(0, 32), + } + os.Unsetenv("NB_DISABLE_CONNTRACK") + }, + genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { + return generatePacket(srcIP, dstIP, 1024, 80, layers.IPProtocolUDP), + generatePacket(dstIP, srcIP, 80, 1024, layers.IPProtocolUDP) + }, + desc: "Stateful: UDP new connection", + }, + { + name: "stateful_udp_established", + proto: layers.IPProtocolUDP, + state: "established", + setupFunc: func(m *Manager) { + m.wgNetwork = &net.IPNet{ + IP: net.ParseIP("0.0.0.0"), + Mask: net.CIDRMask(0, 32), + } + os.Unsetenv("NB_DISABLE_CONNTRACK") + }, + genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { + return generatePacket(srcIP, dstIP, 1024, 80, layers.IPProtocolUDP), + generatePacket(dstIP, srcIP, 80, 1024, layers.IPProtocolUDP) + }, + desc: "Stateful: UDP established connection", + }, + } + + for _, sc := range scenarios { + b.Run(sc.name, func(b *testing.B) { + manager, _ := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }) + defer manager.Reset(nil) + + // Setup scenario + sc.setupFunc(manager) + + // Use IPs outside WG range for routed network simulation + srcIP := net.ParseIP("192.168.1.2") + dstIP := net.ParseIP("8.8.8.8") + outbound, inbound := sc.genPackets(srcIP, dstIP) + + // For stateful cases and established connections + if !strings.Contains(sc.name, "allow_non_wg") || + (strings.Contains(sc.state, "established") || sc.state == "post_handshake") { + manager.processOutgoingHooks(outbound) + + // For TCP post-handshake, simulate full handshake + if sc.state == "post_handshake" { + // SYN + syn := generateTCPPacketWithFlags(srcIP, dstIP, 1024, 80, uint16(conntrack.TCPSyn)) + manager.processOutgoingHooks(syn) + // SYN-ACK + synack := generateTCPPacketWithFlags(dstIP, srcIP, 80, 1024, uint16(conntrack.TCPSyn|conntrack.TCPAck)) + manager.dropFilter(synack, manager.incomingRules) + // ACK + ack := generateTCPPacketWithFlags(srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck)) + manager.processOutgoingHooks(ack) + } + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + manager.dropFilter(inbound, manager.incomingRules) + } + }) + } +} + +// generateTCPPacketWithFlags creates a TCP packet with specific flags +func generateTCPPacketWithFlags(srcIP, dstIP net.IP, srcPort, dstPort uint16, flags uint16) []byte { + ipv4 := &layers.IPv4{ + TTL: 64, + Version: 4, + SrcIP: srcIP, + DstIP: dstIP, + Protocol: layers.IPProtocolTCP, + } + + tcp := &layers.TCP{ + SrcPort: layers.TCPPort(srcPort), + DstPort: layers.TCPPort(dstPort), + } + + // Set TCP flags + tcp.SYN = (flags & uint16(conntrack.TCPSyn)) != 0 + tcp.ACK = (flags & uint16(conntrack.TCPAck)) != 0 + tcp.PSH = (flags & uint16(conntrack.TCPPush)) != 0 + tcp.RST = (flags & uint16(conntrack.TCPRst)) != 0 + tcp.FIN = (flags & uint16(conntrack.TCPFin)) != 0 + + tcp.SetNetworkLayerForChecksum(ipv4) + + buf := gopacket.NewSerializeBuffer() + opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true} + gopacket.SerializeLayers(buf, opts, ipv4, tcp, gopacket.Payload([]byte("test"))) + return buf.Bytes() +} From 1a9a82b56462d931f8dede0a4e2c4af3778967e7 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Sun, 22 Dec 2024 00:18:14 +0100 Subject: [PATCH 21/28] Improve TCP state handling --- client/firewall/uspfilter/conntrack/tcp.go | 85 +++++++-- .../firewall/uspfilter/conntrack/tcp_test.go | 162 ++++++++++++++++++ 2 files changed, 228 insertions(+), 19 deletions(-) create mode 100644 client/firewall/uspfilter/conntrack/tcp_test.go diff --git a/client/firewall/uspfilter/conntrack/tcp.go b/client/firewall/uspfilter/conntrack/tcp.go index 5664191fd2b..22c37184c5f 100644 --- a/client/firewall/uspfilter/conntrack/tcp.go +++ b/client/firewall/uspfilter/conntrack/tcp.go @@ -125,19 +125,27 @@ func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, t.mutex.Lock() defer t.mutex.Unlock() - // For SYN packets (new connection attempts), always allow - if flags&TCPSyn != 0 && flags&TCPAck == 0 { - key := makeTCPKey(dstIP, srcIP, dstPort, srcPort) - t.connections[key] = &TCPConnTrack{ - SourceIP: slices.Clone(dstIP), - DestIP: slices.Clone(srcIP), - SourcePort: dstPort, - DestPort: srcPort, - State: TCPStateSynReceived, - LastSeen: time.Now(), - established: false, + // Always validate flag combinations first + if !isValidFlagCombination(flags) { + return false + } + + // For SYN packets (new connection attempts), allow only pure SYN + if flags&TCPSyn != 0 { + if flags&TCPAck == 0 { + key := makeTCPKey(dstIP, srcIP, dstPort, srcPort) + t.connections[key] = &TCPConnTrack{ + SourceIP: slices.Clone(dstIP), + DestIP: slices.Clone(srcIP), + SourcePort: dstPort, + DestPort: srcPort, + State: TCPStateSynReceived, + LastSeen: time.Now(), + established: false, + } + return true } - return true + // If it's SYN+ACK, let it fall through to normal processing } key := makeTCPKey(dstIP, srcIP, dstPort, srcPort) @@ -146,11 +154,14 @@ func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, return false } - // Update state and check validity + // Handle RST packets - only allow for existing connections if flags&TCPRst != 0 { - conn.State = TCPStateClosed - conn.established = false - return true + if conn.established || conn.State == TCPStateSynSent || conn.State == TCPStateSynReceived { + conn.State = TCPStateClosed + conn.established = false + return true + } + return false } // Special handling for FIN state @@ -212,7 +223,7 @@ func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound boo case TCPStateFinWait1: switch { case flags&TCPFin != 0 && flags&TCPAck != 0: - // Simultaneous close + // Simultaneous close - both sides sent FIN conn.State = TCPStateClosing case flags&TCPFin != 0: conn.State = TCPStateFinWait2 @@ -228,6 +239,7 @@ func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound boo case TCPStateClosing: if flags&TCPAck != 0 { conn.State = TCPStateTimeWait + // Keep established = false from previous state } case TCPStateCloseWait: @@ -248,15 +260,36 @@ func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound boo // isValidStateForFlags checks if the TCP flags are valid for the current connection state func (t *TCPTracker) isValidStateForFlags(state TCPState, flags uint8) bool { + if !isValidFlagCombination(flags) { + return false + } + switch state { + case TCPStateNew: + return flags&TCPSyn != 0 && flags&TCPAck == 0 case TCPStateSynSent: return flags&TCPSyn != 0 && flags&TCPAck != 0 case TCPStateSynReceived: return flags&TCPAck != 0 case TCPStateEstablished: - return true // Allow all flags in established state - case TCPStateFinWait1, TCPStateFinWait2: + if flags&TCPRst != 0 { + return true + } + return flags&TCPAck != 0 + case TCPStateFinWait1: + return flags&TCPFin != 0 || flags&TCPAck != 0 + case TCPStateFinWait2: + return flags&TCPFin != 0 || flags&TCPAck != 0 + case TCPStateClosing: + // In CLOSING state, we should accept the final ACK + return flags&TCPAck != 0 + case TCPStateTimeWait: + // In TIME_WAIT, we might see retransmissions + return flags&TCPAck != 0 + case TCPStateCloseWait: return flags&TCPFin != 0 || flags&TCPAck != 0 + case TCPStateLastAck: + return flags&TCPAck != 0 } return false } @@ -311,3 +344,17 @@ func makeTCPKey(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) TCPC DstPort: dstPort, } } + +func isValidFlagCombination(flags uint8) bool { + // Invalid: SYN+FIN + if flags&TCPSyn != 0 && flags&TCPFin != 0 { + return false + } + + // Invalid: RST with SYN or FIN + if flags&TCPRst != 0 && (flags&TCPSyn != 0 || flags&TCPFin != 0) { + return false + } + + return true +} diff --git a/client/firewall/uspfilter/conntrack/tcp_test.go b/client/firewall/uspfilter/conntrack/tcp_test.go new file mode 100644 index 00000000000..7bca8d9960e --- /dev/null +++ b/client/firewall/uspfilter/conntrack/tcp_test.go @@ -0,0 +1,162 @@ +package conntrack + +import ( + "net" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestTCPStateMachine(t *testing.T) { + tracker := NewTCPTracker(DefaultTCPTimeout) + defer tracker.Close() + + srcIP := net.ParseIP("100.64.0.1") + dstIP := net.ParseIP("100.64.0.2") + srcPort := uint16(12345) + dstPort := uint16(80) + + t.Run("Security Tests", func(t *testing.T) { + tests := []struct { + name string + flags uint8 + wantDrop bool + desc string + }{ + { + name: "Block unsolicited SYN-ACK", + flags: TCPSyn | TCPAck, + wantDrop: true, + desc: "Should block SYN-ACK without prior SYN", + }, + { + name: "Block invalid SYN-FIN", + flags: TCPSyn | TCPFin, + wantDrop: true, + desc: "Should block invalid SYN-FIN combination", + }, + { + name: "Block unsolicited RST", + flags: TCPRst, + wantDrop: true, + desc: "Should block RST without connection", + }, + { + name: "Block unsolicited ACK", + flags: TCPAck, + wantDrop: true, + desc: "Should block ACK without connection", + }, + { + name: "Block data without connection", + flags: TCPAck | TCPPush, + wantDrop: true, + desc: "Should block data without established connection", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + isValid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, tt.flags) + require.Equal(t, !tt.wantDrop, isValid, tt.desc) + }) + } + }) + + t.Run("Connection Flow Tests", func(t *testing.T) { + tests := []struct { + name string + test func(*testing.T) + desc string + }{ + { + name: "Normal Handshake", + test: func(t *testing.T) { + // Send initial SYN + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn) + + // Receive SYN-ACK + valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck) + require.True(t, valid, "SYN-ACK should be allowed") + + // Send ACK + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck) + + // Test data transfer + valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck) + require.True(t, valid, "Data should be allowed after handshake") + }, + }, + { + name: "Normal Close", + test: func(t *testing.T) { + // First establish connection + establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort) + + // Send FIN + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck) + + // Receive ACK for FIN + valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck) + require.True(t, valid, "ACK for FIN should be allowed") + + // Receive FIN from other side + valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck) + require.True(t, valid, "FIN should be allowed") + + // Send final ACK + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck) + }, + }, + { + name: "RST During Connection", + test: func(t *testing.T) { + // First establish connection + establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort) + + // Receive RST + valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst) + require.True(t, valid, "RST should be allowed for established connection") + + // Verify connection is closed + valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck) + require.False(t, valid, "Data should be blocked after RST") + }, + }, + { + name: "Simultaneous Close", + test: func(t *testing.T) { + // First establish connection + establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort) + + // Both sides send FIN+ACK + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck) + valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck) + require.True(t, valid, "Simultaneous FIN should be allowed") + + // Both sides send final ACK + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck) + valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck) + require.True(t, valid, "Final ACKs should be allowed") + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tracker = NewTCPTracker(DefaultTCPTimeout) + tt.test(t) + }) + } + }) +} + +// Helper to establish a TCP connection +func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP net.IP, srcPort, dstPort uint16) { + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn) + + valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck) + require.True(t, valid, "SYN-ACK should be allowed") + + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck) +} From 3c158d4f29c4d51c4c11b636baa5ee4e53ae8601 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Sun, 22 Dec 2024 13:06:19 +0100 Subject: [PATCH 22/28] Fix races, improve performance and add benchmarks --- client/firewall/uspfilter/conntrack/common.go | 137 +++++++++++++++ .../uspfilter/conntrack/common_test.go | 115 +++++++++++++ client/firewall/uspfilter/conntrack/icmp.go | 93 +++++----- .../firewall/uspfilter/conntrack/icmp_test.go | 39 +++++ client/firewall/uspfilter/conntrack/tcp.go | 162 ++++++++++-------- .../firewall/uspfilter/conntrack/tcp_test.go | 103 +++++++++++ client/firewall/uspfilter/conntrack/udp.go | 111 ++++++------ .../firewall/uspfilter/conntrack/udp_test.go | 37 +++- client/firewall/uspfilter/uspfilter_test.go | 2 +- 9 files changed, 631 insertions(+), 168 deletions(-) create mode 100644 client/firewall/uspfilter/conntrack/common.go create mode 100644 client/firewall/uspfilter/conntrack/common_test.go create mode 100644 client/firewall/uspfilter/conntrack/icmp_test.go diff --git a/client/firewall/uspfilter/conntrack/common.go b/client/firewall/uspfilter/conntrack/common.go new file mode 100644 index 00000000000..079a0175f2f --- /dev/null +++ b/client/firewall/uspfilter/conntrack/common.go @@ -0,0 +1,137 @@ +// common.go +package conntrack + +import ( + "net" + "sync" + "sync/atomic" + "time" +) + +// BaseConnTrack provides common fields and locking for all connection types +type BaseConnTrack struct { + sync.RWMutex + SourceIP net.IP + DestIP net.IP + SourcePort uint16 + DestPort uint16 + lastSeen atomic.Int64 // Unix nano for atomic access + established atomic.Bool +} + +// these small methods will be inlined by the compiler + +// UpdateLastSeen safely updates the last seen timestamp +func (b *BaseConnTrack) UpdateLastSeen() { + b.lastSeen.Store(time.Now().UnixNano()) +} + +// IsEstablished safely checks if connection is established +func (b *BaseConnTrack) IsEstablished() bool { + return b.established.Load() +} + +// SetEstablished safely sets the established state +func (b *BaseConnTrack) SetEstablished(state bool) { + b.established.Store(state) +} + +// GetLastSeen safely gets the last seen timestamp +func (b *BaseConnTrack) GetLastSeen() time.Time { + return time.Unix(0, b.lastSeen.Load()) +} + +// timeoutExceeded checks if the connection has exceeded the given timeout +func (b *BaseConnTrack) timeoutExceeded(timeout time.Duration) bool { + lastSeen := time.Unix(0, b.lastSeen.Load()) + return time.Since(lastSeen) > timeout +} + +// IPAddr is a fixed-size IP address to avoid allocations +type IPAddr [16]byte + +// makeIPAddr creates an IPAddr from net.IP +func makeIPAddr(ip net.IP) (addr IPAddr) { + // Optimization: check for v4 first as it's more common + if ip4 := ip.To4(); ip4 != nil { + copy(addr[12:], ip4) + } else { + copy(addr[:], ip.To16()) + } + return addr +} + +// ConnKey uniquely identifies a connection +type ConnKey struct { + SrcIP IPAddr + DstIP IPAddr + SrcPort uint16 + DstPort uint16 +} + +// makeConnKey creates a connection key +func makeConnKey(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) ConnKey { + return ConnKey{ + SrcIP: makeIPAddr(srcIP), + DstIP: makeIPAddr(dstIP), + SrcPort: srcPort, + DstPort: dstPort, + } +} + +// ValidateIPs checks if IPs match without allocation +func ValidateIPs(connIP IPAddr, pktIP net.IP) bool { + if ip4 := pktIP.To4(); ip4 != nil { + // Compare IPv4 addresses (last 4 bytes) + for i := 0; i < 4; i++ { + if connIP[12+i] != ip4[i] { + return false + } + } + return true + } + // Compare full IPv6 addresses + ip6 := pktIP.To16() + for i := 0; i < 16; i++ { + if connIP[i] != ip6[i] { + return false + } + } + return true +} + +// PreallocatedIPs is a pool of IP byte slices to reduce allocations +type PreallocatedIPs struct { + sync.Pool +} + +// NewPreallocatedIPs creates a new IP pool +func NewPreallocatedIPs() *PreallocatedIPs { + return &PreallocatedIPs{ + Pool: sync.Pool{ + New: func() interface{} { + return make(net.IP, 16) + }, + }, + } +} + +// Get retrieves an IP from the pool +func (p *PreallocatedIPs) Get() net.IP { + return p.Pool.Get().(net.IP) +} + +// Put returns an IP to the pool +func (p *PreallocatedIPs) Put(ip net.IP) { + p.Pool.Put(ip) +} + +// copyIP copies an IP address efficiently +func copyIP(dst, src net.IP) { + if len(src) == 16 { + copy(dst, src) + } else { + // Handle IPv4 + copy(dst[12:], src.To4()) + } +} diff --git a/client/firewall/uspfilter/conntrack/common_test.go b/client/firewall/uspfilter/conntrack/common_test.go new file mode 100644 index 00000000000..a337f649b47 --- /dev/null +++ b/client/firewall/uspfilter/conntrack/common_test.go @@ -0,0 +1,115 @@ +package conntrack + +import ( + "net" + "testing" +) + +func BenchmarkIPOperations(b *testing.B) { + b.Run("makeIPAddr", func(b *testing.B) { + ip := net.ParseIP("192.168.1.1") + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = makeIPAddr(ip) + } + }) + + b.Run("ValidateIPs", func(b *testing.B) { + ip1 := net.ParseIP("192.168.1.1") + ip2 := net.ParseIP("192.168.1.1") + addr := makeIPAddr(ip1) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = ValidateIPs(addr, ip2) + } + }) + + b.Run("IPPool", func(b *testing.B) { + pool := NewPreallocatedIPs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + ip := pool.Get() + pool.Put(ip) + } + }) + +} +func BenchmarkAtomicOperations(b *testing.B) { + conn := &BaseConnTrack{} + b.Run("UpdateLastSeen", func(b *testing.B) { + for i := 0; i < b.N; i++ { + conn.UpdateLastSeen() + } + }) + + b.Run("IsEstablished", func(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = conn.IsEstablished() + } + }) + + b.Run("SetEstablished", func(b *testing.B) { + for i := 0; i < b.N; i++ { + conn.SetEstablished(i%2 == 0) + } + }) + + b.Run("GetLastSeen", func(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = conn.GetLastSeen() + } + }) +} + +// Memory pressure tests +func BenchmarkMemoryPressure(b *testing.B) { + b.Run("TCPHighLoad", func(b *testing.B) { + tracker := NewTCPTracker(DefaultTCPTimeout) + defer tracker.Close() + + // Generate different IPs + srcIPs := make([]net.IP, 100) + dstIPs := make([]net.IP, 100) + for i := 0; i < 100; i++ { + srcIPs[i] = net.IPv4(192, 168, byte(i/256), byte(i%256)) + dstIPs[i] = net.IPv4(10, 0, byte(i/256), byte(i%256)) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + srcIdx := i % len(srcIPs) + dstIdx := (i + 1) % len(dstIPs) + tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80, TCPSyn) + + // Simulate some valid inbound packets + if i%3 == 0 { + tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535), TCPAck) + } + } + }) + + b.Run("UDPHighLoad", func(b *testing.B) { + tracker := NewUDPTracker(DefaultUDPTimeout) + defer tracker.Close() + + // Generate different IPs + srcIPs := make([]net.IP, 100) + dstIPs := make([]net.IP, 100) + for i := 0; i < 100; i++ { + srcIPs[i] = net.IPv4(192, 168, byte(i/256), byte(i%256)) + dstIPs[i] = net.IPv4(10, 0, byte(i/256), byte(i%256)) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + srcIdx := i % len(srcIPs) + dstIdx := (i + 1) % len(dstIPs) + tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80) + + // Simulate some valid inbound packets + if i%3 == 0 { + tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535)) + } + } + }) +} diff --git a/client/firewall/uspfilter/conntrack/icmp.go b/client/firewall/uspfilter/conntrack/icmp.go index 9b76c370a46..4cab4cb0e72 100644 --- a/client/firewall/uspfilter/conntrack/icmp.go +++ b/client/firewall/uspfilter/conntrack/icmp.go @@ -2,7 +2,6 @@ package conntrack import ( "net" - "slices" "sync" "time" @@ -27,12 +26,9 @@ type ICMPConnKey struct { // ICMPConnTrack represents an ICMP connection state type ICMPConnTrack struct { - SourceIP net.IP - DestIP net.IP - Sequence uint16 - ID uint16 - LastSeen time.Time - established bool + BaseConnTrack + Sequence uint16 + ID uint16 } // ICMPTracker manages ICMP connection states @@ -42,6 +38,7 @@ type ICMPTracker struct { cleanupTicker *time.Ticker mutex sync.RWMutex done chan struct{} + ipPool *PreallocatedIPs } // NewICMPTracker creates a new ICMP connection tracker @@ -55,6 +52,7 @@ func NewICMPTracker(timeout time.Duration) *ICMPTracker { timeout: timeout, cleanupTicker: time.NewTicker(ICMPCleanupInterval), done: make(chan struct{}), + ipPool: NewPreallocatedIPs(), } go tracker.cleanupRoutine() @@ -64,29 +62,41 @@ func NewICMPTracker(timeout time.Duration) *ICMPTracker { // TrackOutbound records an outbound ICMP Echo Request func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) { key := makeICMPKey(srcIP, dstIP, id, seq) + now := time.Now().UnixNano() t.mutex.Lock() - defer t.mutex.Unlock() - - t.connections[key] = &ICMPConnTrack{ - SourceIP: slices.Clone(srcIP), - DestIP: slices.Clone(dstIP), - ID: id, - Sequence: seq, - LastSeen: time.Now(), - established: true, + conn, exists := t.connections[key] + if !exists { + srcIPCopy := t.ipPool.Get() + dstIPCopy := t.ipPool.Get() + copyIP(srcIPCopy, srcIP) + copyIP(dstIPCopy, dstIP) + + conn = &ICMPConnTrack{ + BaseConnTrack: BaseConnTrack{ + SourceIP: srcIPCopy, + DestIP: dstIPCopy, + }, + ID: id, + Sequence: seq, + } + conn.lastSeen.Store(now) + conn.established.Store(true) + t.connections[key] = conn } + t.mutex.Unlock() + + conn.lastSeen.Store(now) } // IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request func (t *ICMPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, icmpType uint8) bool { switch icmpType { - // For Destination Unreachable and Time Exceeded, always allow - case uint8(layers.ICMPv4TypeDestinationUnreachable), uint8(layers.ICMPv4TypeTimeExceeded): + case uint8(layers.ICMPv4TypeDestinationUnreachable), + uint8(layers.ICMPv4TypeTimeExceeded): return true - // For Echo Reply, check if we have a matching request case uint8(layers.ICMPv4TypeEchoReply): - // continue further down + // continue processing default: return false } @@ -94,29 +104,22 @@ func (t *ICMPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, id uint16, seq key := makeICMPKey(dstIP, srcIP, id, seq) t.mutex.RLock() - defer t.mutex.RUnlock() - conn, exists := t.connections[key] + t.mutex.RUnlock() + if !exists { return false } - // Check if connection is still valid - if time.Since(conn.LastSeen) > t.timeout { + if conn.timeoutExceeded(t.timeout) { return false } - if conn.established && - conn.DestIP.Equal(srcIP) && - conn.SourceIP.Equal(dstIP) && + return conn.IsEstablished() && + ValidateIPs(makeIPAddr(srcIP), conn.DestIP) && + ValidateIPs(makeIPAddr(dstIP), conn.SourceIP) && conn.ID == id && - conn.Sequence == seq { - - conn.LastSeen = time.Now() - return true - } - - return false + conn.Sequence == seq } func (t *ICMPTracker) cleanupRoutine() { @@ -129,14 +132,14 @@ func (t *ICMPTracker) cleanupRoutine() { } } } - func (t *ICMPTracker) cleanup() { t.mutex.Lock() defer t.mutex.Unlock() - now := time.Now() for key, conn := range t.connections { - if now.Sub(conn.LastSeen) > t.timeout { + if conn.timeoutExceeded(t.timeout) { + t.ipPool.Put(conn.SourceIP) + t.ipPool.Put(conn.DestIP) delete(t.connections, key) } } @@ -146,15 +149,21 @@ func (t *ICMPTracker) cleanup() { func (t *ICMPTracker) Close() { t.cleanupTicker.Stop() close(t.done) + + t.mutex.Lock() + for _, conn := range t.connections { + t.ipPool.Put(conn.SourceIP) + t.ipPool.Put(conn.DestIP) + } + t.connections = nil + t.mutex.Unlock() } +// makeICMPKey creates an ICMP connection key func makeICMPKey(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) ICMPConnKey { - var srcAddr, dstAddr [16]byte - copy(srcAddr[:], srcIP.To16()) - copy(dstAddr[:], dstIP.To16()) return ICMPConnKey{ - SrcIP: srcAddr, - DstIP: dstAddr, + SrcIP: makeIPAddr(srcIP), + DstIP: makeIPAddr(dstIP), ID: id, Sequence: seq, } diff --git a/client/firewall/uspfilter/conntrack/icmp_test.go b/client/firewall/uspfilter/conntrack/icmp_test.go new file mode 100644 index 00000000000..21176e719d4 --- /dev/null +++ b/client/firewall/uspfilter/conntrack/icmp_test.go @@ -0,0 +1,39 @@ +package conntrack + +import ( + "net" + "testing" +) + +func BenchmarkICMPTracker(b *testing.B) { + b.Run("TrackOutbound", func(b *testing.B) { + tracker := NewICMPTracker(DefaultICMPTimeout) + defer tracker.Close() + + srcIP := net.ParseIP("192.168.1.1") + dstIP := net.ParseIP("192.168.1.2") + + b.ResetTimer() + for i := 0; i < b.N; i++ { + tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), uint16(i%65535)) + } + }) + + b.Run("IsValidInbound", func(b *testing.B) { + tracker := NewICMPTracker(DefaultICMPTimeout) + defer tracker.Close() + + srcIP := net.ParseIP("192.168.1.1") + dstIP := net.ParseIP("192.168.1.2") + + // Pre-populate some connections + for i := 0; i < 1000; i++ { + tracker.TrackOutbound(srcIP, dstIP, uint16(i), uint16(i)) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + tracker.IsValidInbound(dstIP, srcIP, uint16(i%1000), uint16(i%1000), 0) + } + }) +} diff --git a/client/firewall/uspfilter/conntrack/tcp.go b/client/firewall/uspfilter/conntrack/tcp.go index 22c37184c5f..e8d20f41c67 100644 --- a/client/firewall/uspfilter/conntrack/tcp.go +++ b/client/firewall/uspfilter/conntrack/tcp.go @@ -4,7 +4,6 @@ package conntrack import ( "net" - "slices" "sync" "time" ) @@ -61,31 +60,28 @@ type TCPConnKey struct { // TCPConnTrack represents a TCP connection state type TCPConnTrack struct { - SourceIP net.IP - DestIP net.IP - SourcePort uint16 - DestPort uint16 - State TCPState - LastSeen time.Time - established bool + BaseConnTrack + State TCPState } // TCPTracker manages TCP connection states type TCPTracker struct { - connections map[TCPConnKey]*TCPConnTrack + connections map[ConnKey]*TCPConnTrack mutex sync.RWMutex cleanupTicker *time.Ticker done chan struct{} timeout time.Duration + ipPool *PreallocatedIPs } // NewTCPTracker creates a new TCP connection tracker func NewTCPTracker(timeout time.Duration) *TCPTracker { tracker := &TCPTracker{ - connections: make(map[TCPConnKey]*TCPConnTrack), + connections: make(map[ConnKey]*TCPConnTrack), cleanupTicker: time.NewTicker(TCPCleanupInterval), done: make(chan struct{}), timeout: timeout, + ipPool: NewPreallocatedIPs(), } go tracker.cleanupRoutine() @@ -94,88 +90,108 @@ func NewTCPTracker(timeout time.Duration) *TCPTracker { // TrackOutbound processes an outbound TCP packet and updates connection state func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) { - - key := makeTCPKey(srcIP, dstIP, srcPort, dstPort) - now := time.Now() + // Create key before lock + key := makeConnKey(srcIP, dstIP, srcPort, dstPort) + now := time.Now().UnixNano() t.mutex.Lock() - defer t.mutex.Unlock() - conn, exists := t.connections[key] if !exists { + // Use preallocated IPs + srcIPCopy := t.ipPool.Get() + dstIPCopy := t.ipPool.Get() + copyIP(srcIPCopy, srcIP) + copyIP(dstIPCopy, dstIP) + conn = &TCPConnTrack{ - SourceIP: slices.Clone(srcIP), - DestIP: slices.Clone(dstIP), - SourcePort: srcPort, - DestPort: dstPort, - State: TCPStateNew, - LastSeen: now, - established: false, + BaseConnTrack: BaseConnTrack{ + SourceIP: srcIPCopy, + DestIP: dstIPCopy, + SourcePort: srcPort, + DestPort: dstPort, + }, + State: TCPStateNew, } + conn.lastSeen.Store(now) + conn.established.Store(false) t.connections[key] = conn } + t.mutex.Unlock() - // Update connection state based on TCP flags + // Lock individual connection for state update + conn.Lock() t.updateState(conn, flags, true) - conn.LastSeen = now + conn.Unlock() + conn.lastSeen.Store(now) } // IsValidInbound checks if an inbound TCP packet matches a tracked connection func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) bool { - t.mutex.Lock() - defer t.mutex.Unlock() - - // Always validate flag combinations first if !isValidFlagCombination(flags) { return false } - // For SYN packets (new connection attempts), allow only pure SYN - if flags&TCPSyn != 0 { - if flags&TCPAck == 0 { - key := makeTCPKey(dstIP, srcIP, dstPort, srcPort) - t.connections[key] = &TCPConnTrack{ - SourceIP: slices.Clone(dstIP), - DestIP: slices.Clone(srcIP), - SourcePort: dstPort, - DestPort: srcPort, - State: TCPStateSynReceived, - LastSeen: time.Now(), - established: false, + // Handle new SYN packets + if flags&TCPSyn != 0 && flags&TCPAck == 0 { + key := makeConnKey(dstIP, srcIP, dstPort, srcPort) + t.mutex.Lock() + if _, exists := t.connections[key]; !exists { + // Use preallocated IPs + srcIPCopy := t.ipPool.Get() + dstIPCopy := t.ipPool.Get() + copyIP(srcIPCopy, dstIP) + copyIP(dstIPCopy, srcIP) + + conn := &TCPConnTrack{ + BaseConnTrack: BaseConnTrack{ + SourceIP: srcIPCopy, + DestIP: dstIPCopy, + SourcePort: dstPort, + DestPort: srcPort, + }, + State: TCPStateSynReceived, } - return true + conn.lastSeen.Store(time.Now().UnixNano()) + conn.established.Store(false) + t.connections[key] = conn } - // If it's SYN+ACK, let it fall through to normal processing + t.mutex.Unlock() + return true } - key := makeTCPKey(dstIP, srcIP, dstPort, srcPort) + // Look up existing connection + key := makeConnKey(dstIP, srcIP, dstPort, srcPort) + t.mutex.RLock() conn, exists := t.connections[key] + t.mutex.RUnlock() + if !exists { return false } - // Handle RST packets - only allow for existing connections + // Handle RST packets if flags&TCPRst != 0 { - if conn.established || conn.State == TCPStateSynSent || conn.State == TCPStateSynReceived { + conn.Lock() + isEstablished := conn.IsEstablished() + if isEstablished || conn.State == TCPStateSynSent || conn.State == TCPStateSynReceived { conn.State = TCPStateClosed - conn.established = false + conn.SetEstablished(false) + conn.Unlock() return true } + conn.Unlock() return false } - // Special handling for FIN state - if conn.State == TCPStateFinWait1 || conn.State == TCPStateFinWait2 { - t.updateState(conn, flags, false) - conn.LastSeen = time.Now() - return true - } - + // Update state + conn.Lock() t.updateState(conn, flags, false) - conn.LastSeen = time.Now() + conn.UpdateLastSeen() + isEstablished := conn.IsEstablished() + isValidState := t.isValidStateForFlags(conn.State, flags) + conn.Unlock() - // Allow if established or in a valid state for the flags - return conn.established || t.isValidStateForFlags(conn.State, flags) + return isEstablished || isValidState } // updateState updates the TCP connection state based on flags @@ -183,7 +199,7 @@ func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound boo // Handle RST flag specially - it always causes transition to closed if flags&TCPRst != 0 { conn.State = TCPStateClosed - conn.established = false + conn.SetEstablished(false) return } @@ -200,14 +216,14 @@ func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound boo } else { // Simultaneous open conn.State = TCPStateEstablished - conn.established = true + conn.SetEstablished(true) } } case TCPStateSynReceived: if flags&TCPAck != 0 && flags&TCPSyn == 0 { conn.State = TCPStateEstablished - conn.established = true + conn.SetEstablished(true) } case TCPStateEstablished: @@ -217,7 +233,7 @@ func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound boo } else { conn.State = TCPStateCloseWait } - conn.established = false + conn.SetEstablished(false) } case TCPStateFinWait1: @@ -309,19 +325,22 @@ func (t *TCPTracker) cleanup() { t.mutex.Lock() defer t.mutex.Unlock() - now := time.Now() for key, conn := range t.connections { var timeout time.Duration switch { case conn.State == TCPStateTimeWait: timeout = TimeWaitTimeout - case conn.established: + case conn.IsEstablished(): timeout = t.timeout default: timeout = TCPHandshakeTimeout } - if now.Sub(conn.LastSeen) > timeout { + lastSeen := conn.GetLastSeen() + if time.Since(lastSeen) > timeout { + // Return IPs to pool + t.ipPool.Put(conn.SourceIP) + t.ipPool.Put(conn.DestIP) delete(t.connections, key) } } @@ -331,18 +350,15 @@ func (t *TCPTracker) cleanup() { func (t *TCPTracker) Close() { t.cleanupTicker.Stop() close(t.done) -} -func makeTCPKey(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) TCPConnKey { - var srcAddr, dstAddr [16]byte - copy(srcAddr[:], srcIP.To16()) - copy(dstAddr[:], dstIP.To16()) - return TCPConnKey{ - SrcIP: srcAddr, - DstIP: dstAddr, - SrcPort: srcPort, - DstPort: dstPort, + // Clean up all remaining IPs + t.mutex.Lock() + for _, conn := range t.connections { + t.ipPool.Put(conn.SourceIP) + t.ipPool.Put(conn.DestIP) } + t.connections = nil + t.mutex.Unlock() } func isValidFlagCombination(flags uint8) bool { diff --git a/client/firewall/uspfilter/conntrack/tcp_test.go b/client/firewall/uspfilter/conntrack/tcp_test.go index 7bca8d9960e..42a2f708a47 100644 --- a/client/firewall/uspfilter/conntrack/tcp_test.go +++ b/client/firewall/uspfilter/conntrack/tcp_test.go @@ -3,6 +3,7 @@ package conntrack import ( "net" "testing" + "time" "github.com/stretchr/testify/require" ) @@ -160,3 +161,105 @@ func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP net.IP, tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck) } + +// Benchmarks for the optimized implementation +func (t *TCPTracker) benchmarkTrackOutbound(b *testing.B) { + srcIP := net.ParseIP("192.168.1.1") + dstIP := net.ParseIP("192.168.1.2") + + b.ResetTimer() + for i := 0; i < b.N; i++ { + t.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn) + } +} + +func (t *TCPTracker) benchmarkIsValidInbound(b *testing.B) { + srcIP := net.ParseIP("192.168.1.1") + dstIP := net.ParseIP("192.168.1.2") + + // Pre-populate some connections + for i := 0; i < 1000; i++ { + t.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + t.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), TCPAck) + } +} + +func BenchmarkTCPTracker(b *testing.B) { + b.Run("TrackOutbound", func(b *testing.B) { + tracker := NewTCPTracker(DefaultTCPTimeout) + defer tracker.Close() + + srcIP := net.ParseIP("192.168.1.1") + dstIP := net.ParseIP("192.168.1.2") + + b.ResetTimer() + for i := 0; i < b.N; i++ { + tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn) + } + }) + + b.Run("IsValidInbound", func(b *testing.B) { + tracker := NewTCPTracker(DefaultTCPTimeout) + defer tracker.Close() + + srcIP := net.ParseIP("192.168.1.1") + dstIP := net.ParseIP("192.168.1.2") + + // Pre-populate some connections + for i := 0; i < 1000; i++ { + tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), TCPAck) + } + }) + + b.Run("ConcurrentAccess", func(b *testing.B) { + tracker := NewTCPTracker(DefaultTCPTimeout) + defer tracker.Close() + + srcIP := net.ParseIP("192.168.1.1") + dstIP := net.ParseIP("192.168.1.2") + + b.RunParallel(func(pb *testing.PB) { + i := 0 + for pb.Next() { + if i%2 == 0 { + tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn) + } else { + tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%65535), TCPAck) + } + i++ + } + }) + }) +} + +// Benchmark connection cleanup +func BenchmarkCleanup(b *testing.B) { + b.Run("TCPCleanup", func(b *testing.B) { + tracker := NewTCPTracker(100 * time.Millisecond) // Short timeout for testing + defer tracker.Close() + + // Pre-populate with expired connections + srcIP := net.ParseIP("192.168.1.1") + dstIP := net.ParseIP("192.168.1.2") + for i := 0; i < 10000; i++ { + tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn) + } + + // Wait for connections to expire + time.Sleep(200 * time.Millisecond) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + tracker.cleanup() + } + }) +} diff --git a/client/firewall/uspfilter/conntrack/udp.go b/client/firewall/uspfilter/conntrack/udp.go index 94ab9e15273..4d55ec0df83 100644 --- a/client/firewall/uspfilter/conntrack/udp.go +++ b/client/firewall/uspfilter/conntrack/udp.go @@ -2,7 +2,6 @@ package conntrack import ( "net" - "slices" "sync" "time" ) @@ -14,22 +13,9 @@ const ( UDPCleanupInterval = 15 * time.Second ) -type ConnKey struct { - // Supports both IPv4 and IPv6 - SrcIP [16]byte - DstIP [16]byte - SrcPort uint16 - DstPort uint16 -} - // UDPConnTrack represents a UDP connection state type UDPConnTrack struct { - SourceIP net.IP - DestIP net.IP - SourcePort uint16 - DestPort uint16 - LastSeen time.Time - established bool + BaseConnTrack } // UDPTracker manages UDP connection states @@ -39,6 +25,7 @@ type UDPTracker struct { cleanupTicker *time.Ticker mutex sync.RWMutex done chan struct{} + ipPool *PreallocatedIPs } // NewUDPTracker creates a new UDP connection tracker @@ -52,6 +39,7 @@ func NewUDPTracker(timeout time.Duration) *UDPTracker { timeout: timeout, cleanupTicker: time.NewTicker(UDPCleanupInterval), done: make(chan struct{}), + ipPool: NewPreallocatedIPs(), } go tracker.cleanupRoutine() @@ -60,49 +48,55 @@ func NewUDPTracker(timeout time.Duration) *UDPTracker { // TrackOutbound records an outbound UDP connection func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) { - key := makeKey(srcIP, srcPort, dstIP, dstPort) + key := makeConnKey(srcIP, dstIP, srcPort, dstPort) + now := time.Now().UnixNano() t.mutex.Lock() - defer t.mutex.Unlock() - - t.connections[key] = &UDPConnTrack{ - SourceIP: slices.Clone(srcIP), - DestIP: slices.Clone(dstIP), - SourcePort: srcPort, - DestPort: dstPort, - LastSeen: time.Now(), - established: true, + conn, exists := t.connections[key] + if !exists { + srcIPCopy := t.ipPool.Get() + dstIPCopy := t.ipPool.Get() + copyIP(srcIPCopy, srcIP) + copyIP(dstIPCopy, dstIP) + + conn = &UDPConnTrack{ + BaseConnTrack: BaseConnTrack{ + SourceIP: srcIPCopy, + DestIP: dstIPCopy, + SourcePort: srcPort, + DestPort: dstPort, + }, + } + conn.lastSeen.Store(now) + conn.established.Store(true) + t.connections[key] = conn } + t.mutex.Unlock() + + conn.lastSeen.Store(now) } // IsValidInbound checks if an inbound packet matches a tracked connection func (t *UDPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) bool { - t.mutex.RLock() - defer t.mutex.RUnlock() + key := makeConnKey(dstIP, srcIP, dstPort, srcPort) - key := makeKey(dstIP, dstPort, srcIP, srcPort) + t.mutex.RLock() conn, exists := t.connections[key] + t.mutex.RUnlock() + if !exists { return false } - // Check if connection is still valid - if time.Since(conn.LastSeen) > t.timeout { + if conn.timeoutExceeded(t.timeout) { return false } - if conn.established && - conn.DestIP.Equal(srcIP) && - conn.SourceIP.Equal(dstIP) && + return conn.IsEstablished() && + ValidateIPs(makeIPAddr(srcIP), conn.DestIP) && + ValidateIPs(makeIPAddr(dstIP), conn.SourceIP) && conn.DestPort == srcPort && - conn.SourcePort == dstPort { - - conn.LastSeen = time.Now() - - return true - } - - return false + conn.SourcePort == dstPort } // cleanupRoutine periodically removes stale connections @@ -121,9 +115,10 @@ func (t *UDPTracker) cleanup() { t.mutex.Lock() defer t.mutex.Unlock() - now := time.Now() for key, conn := range t.connections { - if now.Sub(conn.LastSeen) > t.timeout { + if conn.timeoutExceeded(t.timeout) { + t.ipPool.Put(conn.SourceIP) + t.ipPool.Put(conn.DestIP) delete(t.connections, key) } } @@ -133,6 +128,14 @@ func (t *UDPTracker) cleanup() { func (t *UDPTracker) Close() { t.cleanupTicker.Stop() close(t.done) + + t.mutex.Lock() + for _, conn := range t.connections { + t.ipPool.Put(conn.SourceIP) + t.ipPool.Put(conn.DestIP) + } + t.connections = nil + t.mutex.Unlock() } // GetConnection safely retrieves a connection state @@ -140,20 +143,28 @@ func (t *UDPTracker) GetConnection(srcIP net.IP, srcPort uint16, dstIP net.IP, d t.mutex.RLock() defer t.mutex.RUnlock() - key := makeKey(srcIP, srcPort, dstIP, dstPort) + key := makeConnKey(srcIP, dstIP, srcPort, dstPort) conn, exists := t.connections[key] if !exists { return nil, false } + // Create a copy with new IP allocations + srcIPCopy := t.ipPool.Get() + dstIPCopy := t.ipPool.Get() + copyIP(srcIPCopy, conn.SourceIP) + copyIP(dstIPCopy, conn.DestIP) + connCopy := &UDPConnTrack{ - SourceIP: slices.Clone(conn.SourceIP), - DestIP: slices.Clone(conn.DestIP), - SourcePort: conn.SourcePort, - DestPort: conn.DestPort, - LastSeen: conn.LastSeen, - established: conn.established, + BaseConnTrack: BaseConnTrack{ + SourceIP: srcIPCopy, + DestIP: dstIPCopy, + SourcePort: conn.SourcePort, + DestPort: conn.DestPort, + }, } + connCopy.lastSeen.Store(conn.lastSeen.Load()) + connCopy.established.Store(conn.IsEstablished()) return connCopy, true } diff --git a/client/firewall/uspfilter/conntrack/udp_test.go b/client/firewall/uspfilter/conntrack/udp_test.go index 938dc18ea59..1a8afc21a01 100644 --- a/client/firewall/uspfilter/conntrack/udp_test.go +++ b/client/firewall/uspfilter/conntrack/udp_test.go @@ -58,8 +58,8 @@ func TestUDPTracker_TrackOutbound(t *testing.T) { assert.True(t, conn.DestIP.Equal(dstIP)) assert.Equal(t, srcPort, conn.SourcePort) assert.Equal(t, dstPort, conn.DestPort) - assert.True(t, conn.established) - assert.WithinDuration(t, time.Now(), conn.LastSeen, 1*time.Second) + assert.True(t, conn.IsEstablished()) + assert.WithinDuration(t, time.Now(), conn.GetLastSeen(), 1*time.Second) } func TestUDPTracker_IsValidInbound(t *testing.T) { @@ -232,3 +232,36 @@ func TestUDPTracker_Close(t *testing.T) { _, ok := <-tracker.done assert.False(t, ok, "done channel should be closed") } + +func BenchmarkUDPTracker(b *testing.B) { + b.Run("TrackOutbound", func(b *testing.B) { + tracker := NewUDPTracker(DefaultUDPTimeout) + defer tracker.Close() + + srcIP := net.ParseIP("192.168.1.1") + dstIP := net.ParseIP("192.168.1.2") + + b.ResetTimer() + for i := 0; i < b.N; i++ { + tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80) + } + }) + + b.Run("IsValidInbound", func(b *testing.B) { + tracker := NewUDPTracker(DefaultUDPTimeout) + defer tracker.Close() + + srcIP := net.ParseIP("192.168.1.1") + dstIP := net.ParseIP("192.168.1.2") + + // Pre-populate some connections + for i := 0; i < 1000; i++ { + tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000)) + } + }) +} diff --git a/client/firewall/uspfilter/uspfilter_test.go b/client/firewall/uspfilter/uspfilter_test.go index 23f575843a3..ea78a013abc 100644 --- a/client/firewall/uspfilter/uspfilter_test.go +++ b/client/firewall/uspfilter/uspfilter_test.go @@ -630,7 +630,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) { if cp.shouldAllow { conn, exists := manager.udpTracker.GetConnection(srcIP, srcPort, dstIP, dstPort) require.True(t, exists, "Connection should still exist during valid window") - require.True(t, time.Since(conn.LastSeen) < manager.udpTracker.Timeout(), + require.True(t, time.Since(conn.GetLastSeen()) < manager.udpTracker.Timeout(), "LastSeen should be updated for valid responses") } } From 07019d2d633a3560d62fc3e6bf9ca327537126e6 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Sun, 22 Dec 2024 13:42:07 +0100 Subject: [PATCH 23/28] Fix udp test --- client/firewall/uspfilter/conntrack/common.go | 8 +- .../uspfilter/conntrack/common_test.go | 6 +- client/firewall/uspfilter/conntrack/icmp.go | 8 +- client/firewall/uspfilter/conntrack/udp.go | 35 +---- .../firewall/uspfilter/conntrack/udp_test.go | 36 +---- client/firewall/uspfilter/uspfilter_test.go | 124 ++++++++++-------- 6 files changed, 90 insertions(+), 127 deletions(-) diff --git a/client/firewall/uspfilter/conntrack/common.go b/client/firewall/uspfilter/conntrack/common.go index 079a0175f2f..11d6599315d 100644 --- a/client/firewall/uspfilter/conntrack/common.go +++ b/client/firewall/uspfilter/conntrack/common.go @@ -50,8 +50,8 @@ func (b *BaseConnTrack) timeoutExceeded(timeout time.Duration) bool { // IPAddr is a fixed-size IP address to avoid allocations type IPAddr [16]byte -// makeIPAddr creates an IPAddr from net.IP -func makeIPAddr(ip net.IP) (addr IPAddr) { +// MakeIPAddr creates an IPAddr from net.IP +func MakeIPAddr(ip net.IP) (addr IPAddr) { // Optimization: check for v4 first as it's more common if ip4 := ip.To4(); ip4 != nil { copy(addr[12:], ip4) @@ -72,8 +72,8 @@ type ConnKey struct { // makeConnKey creates a connection key func makeConnKey(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) ConnKey { return ConnKey{ - SrcIP: makeIPAddr(srcIP), - DstIP: makeIPAddr(dstIP), + SrcIP: MakeIPAddr(srcIP), + DstIP: MakeIPAddr(dstIP), SrcPort: srcPort, DstPort: dstPort, } diff --git a/client/firewall/uspfilter/conntrack/common_test.go b/client/firewall/uspfilter/conntrack/common_test.go index a337f649b47..72d006def57 100644 --- a/client/firewall/uspfilter/conntrack/common_test.go +++ b/client/firewall/uspfilter/conntrack/common_test.go @@ -6,18 +6,18 @@ import ( ) func BenchmarkIPOperations(b *testing.B) { - b.Run("makeIPAddr", func(b *testing.B) { + b.Run("MakeIPAddr", func(b *testing.B) { ip := net.ParseIP("192.168.1.1") b.ResetTimer() for i := 0; i < b.N; i++ { - _ = makeIPAddr(ip) + _ = MakeIPAddr(ip) } }) b.Run("ValidateIPs", func(b *testing.B) { ip1 := net.ParseIP("192.168.1.1") ip2 := net.ParseIP("192.168.1.1") - addr := makeIPAddr(ip1) + addr := MakeIPAddr(ip1) b.ResetTimer() for i := 0; i < b.N; i++ { _ = ValidateIPs(addr, ip2) diff --git a/client/firewall/uspfilter/conntrack/icmp.go b/client/firewall/uspfilter/conntrack/icmp.go index 4cab4cb0e72..e0a971678f1 100644 --- a/client/firewall/uspfilter/conntrack/icmp.go +++ b/client/firewall/uspfilter/conntrack/icmp.go @@ -116,8 +116,8 @@ func (t *ICMPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, id uint16, seq } return conn.IsEstablished() && - ValidateIPs(makeIPAddr(srcIP), conn.DestIP) && - ValidateIPs(makeIPAddr(dstIP), conn.SourceIP) && + ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) && + ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) && conn.ID == id && conn.Sequence == seq } @@ -162,8 +162,8 @@ func (t *ICMPTracker) Close() { // makeICMPKey creates an ICMP connection key func makeICMPKey(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) ICMPConnKey { return ICMPConnKey{ - SrcIP: makeIPAddr(srcIP), - DstIP: makeIPAddr(dstIP), + SrcIP: MakeIPAddr(srcIP), + DstIP: MakeIPAddr(dstIP), ID: id, Sequence: seq, } diff --git a/client/firewall/uspfilter/conntrack/udp.go b/client/firewall/uspfilter/conntrack/udp.go index 4d55ec0df83..a969a4e8425 100644 --- a/client/firewall/uspfilter/conntrack/udp.go +++ b/client/firewall/uspfilter/conntrack/udp.go @@ -93,8 +93,8 @@ func (t *UDPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, } return conn.IsEstablished() && - ValidateIPs(makeIPAddr(srcIP), conn.DestIP) && - ValidateIPs(makeIPAddr(dstIP), conn.SourceIP) && + ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) && + ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) && conn.DestPort == srcPort && conn.SourcePort == dstPort } @@ -149,39 +149,10 @@ func (t *UDPTracker) GetConnection(srcIP net.IP, srcPort uint16, dstIP net.IP, d return nil, false } - // Create a copy with new IP allocations - srcIPCopy := t.ipPool.Get() - dstIPCopy := t.ipPool.Get() - copyIP(srcIPCopy, conn.SourceIP) - copyIP(dstIPCopy, conn.DestIP) - - connCopy := &UDPConnTrack{ - BaseConnTrack: BaseConnTrack{ - SourceIP: srcIPCopy, - DestIP: dstIPCopy, - SourcePort: conn.SourcePort, - DestPort: conn.DestPort, - }, - } - connCopy.lastSeen.Store(conn.lastSeen.Load()) - connCopy.established.Store(conn.IsEstablished()) - - return connCopy, true + return conn, true } // Timeout returns the configured timeout duration for the tracker func (t *UDPTracker) Timeout() time.Duration { return t.timeout } - -func makeKey(srcIP net.IP, srcPort uint16, dstIP net.IP, dstPort uint16) ConnKey { - var srcAddr, dstAddr [16]byte - copy(srcAddr[:], srcIP.To16()) // Ensure 16-byte representation - copy(dstAddr[:], dstIP.To16()) - return ConnKey{ - SrcIP: srcAddr, - SrcPort: srcPort, - DstIP: dstAddr, - DstPort: dstPort, - } -} diff --git a/client/firewall/uspfilter/conntrack/udp_test.go b/client/firewall/uspfilter/conntrack/udp_test.go index 1a8afc21a01..67172189069 100644 --- a/client/firewall/uspfilter/conntrack/udp_test.go +++ b/client/firewall/uspfilter/conntrack/udp_test.go @@ -51,7 +51,7 @@ func TestUDPTracker_TrackOutbound(t *testing.T) { tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort) // Verify connection was tracked - key := makeKey(srcIP, srcPort, dstIP, dstPort) + key := makeConnKey(srcIP, dstIP, srcPort, dstPort) conn, exists := tracker.connections[key] require.True(t, exists) assert.True(t, conn.SourceIP.Equal(srcIP)) @@ -161,11 +161,11 @@ func TestUDPTracker_Cleanup(t *testing.T) { timeout: timeout, cleanupTicker: time.NewTicker(cleanupInterval), done: make(chan struct{}), + ipPool: NewPreallocatedIPs(), } // Start cleanup routine go tracker.cleanupRoutine() - defer tracker.Close() // Add some connections connections := []struct { @@ -193,44 +193,20 @@ func TestUDPTracker_Cleanup(t *testing.T) { } // Verify initial connections - tracker.mutex.RLock() assert.Len(t, tracker.connections, 2) - tracker.mutex.RUnlock() // Wait for connection timeout and cleanup interval time.Sleep(timeout + 2*cleanupInterval) - // Verify connections were cleaned up tracker.mutex.RLock() - assert.Empty(t, tracker.connections) + connCount := len(tracker.connections) tracker.mutex.RUnlock() - // Add a new connection and verify it's not immediately cleaned up - tracker.TrackOutbound(connections[0].srcIP, connections[0].dstIP, - connections[0].srcPort, connections[0].dstPort) - - tracker.mutex.RLock() - assert.Len(t, tracker.connections, 1, "New connection should not be cleaned up immediately") - tracker.mutex.RUnlock() -} - -func TestUDPTracker_Close(t *testing.T) { - tracker := NewUDPTracker(DefaultUDPTimeout) - - // Add a connection - tracker.TrackOutbound( - net.ParseIP("192.168.1.2"), - net.ParseIP("192.168.1.3"), - 12345, - 53, - ) + // Verify connections were cleaned up + assert.Equal(t, 0, connCount, "Expected all connections to be cleaned up") - // Close the tracker + // Properly close the tracker tracker.Close() - - // Verify done channel is closed - _, ok := <-tracker.done - assert.False(t, ok, "done channel should be closed") } func BenchmarkUDPTracker(b *testing.B) { diff --git a/client/firewall/uspfilter/uspfilter_test.go b/client/firewall/uspfilter/uspfilter_test.go index ea78a013abc..d3563e6f251 100644 --- a/client/firewall/uspfilter/uspfilter_test.go +++ b/client/firewall/uspfilter/uspfilter_test.go @@ -187,10 +187,10 @@ func TestAddUDPPacketHook(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - manager := &Manager{ - incomingRules: map[string]RuleSet{}, - outgoingRules: map[string]RuleSet{}, - } + manager, err := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }) + require.NoError(t, err) manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook) @@ -315,7 +315,7 @@ func TestNotMatchByIP(t *testing.T) { t.Errorf("failed to set network layer for checksum: %v", err) return } - payload := gopacket.Payload([]byte("test")) + payload := gopacket.Payload("test") buf := gopacket.NewSerializeBuffer() opts := gopacket.SerializeOptions{ @@ -350,6 +350,9 @@ func TestRemovePacketHook(t *testing.T) { if err != nil { t.Fatalf("Failed to create Manager: %s", err) } + defer func() { + require.NoError(t, manager.Reset(nil)) + }() // Add a UDP packet hook hookFunc := func(data []byte) bool { return true } @@ -387,26 +390,33 @@ func TestRemovePacketHook(t *testing.T) { } func TestProcessOutgoingHooks(t *testing.T) { - manager := &Manager{ - outgoingRules: map[string]RuleSet{}, - wgNetwork: &net.IPNet{ - IP: net.ParseIP("100.10.0.0"), - Mask: net.CIDRMask(16, 32), - }, - decoders: sync.Pool{ - New: func() any { - d := &decoder{ - decoded: []gopacket.LayerType{}, - } - d.parser = gopacket.NewDecodingLayerParser( - layers.LayerTypeIPv4, - &d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp, - ) - d.parser.IgnoreUnsupported = true - return d - }, + manager, err := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }) + require.NoError(t, err) + + manager.wgNetwork = &net.IPNet{ + IP: net.ParseIP("100.10.0.0"), + Mask: net.CIDRMask(16, 32), + } + manager.udpTracker.Close() + manager.udpTracker = conntrack.NewUDPTracker(100 * time.Millisecond) + defer func() { + require.NoError(t, manager.Reset(nil)) + }() + + manager.decoders = sync.Pool{ + New: func() any { + d := &decoder{ + decoded: []gopacket.LayerType{}, + } + d.parser = gopacket.NewDecodingLayerParser( + layers.LayerTypeIPv4, + &d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp, + ) + d.parser.IgnoreUnsupported = true + return d }, - udpTracker: conntrack.NewUDPTracker(100 * time.Millisecond), } hookCalled := false @@ -434,9 +444,9 @@ func TestProcessOutgoingHooks(t *testing.T) { DstPort: 53, } - err := udp.SetNetworkLayerForChecksum(ipv4) + err = udp.SetNetworkLayerForChecksum(ipv4) require.NoError(t, err) - payload := gopacket.Payload([]byte("test")) + payload := gopacket.Payload("test") buf := gopacket.NewSerializeBuffer() opts := gopacket.SerializeOptions{ @@ -497,29 +507,34 @@ func TestUSPFilterCreatePerformance(t *testing.T) { } func TestStatefulFirewall_UDPTracking(t *testing.T) { - manager := &Manager{ - outgoingRules: map[string]RuleSet{}, - incomingRules: map[string]RuleSet{}, - wgNetwork: &net.IPNet{ - IP: net.ParseIP("100.10.0.0"), - Mask: net.CIDRMask(16, 32), - }, - decoders: sync.Pool{ - New: func() any { - d := &decoder{ - decoded: []gopacket.LayerType{}, - } - d.parser = gopacket.NewDecodingLayerParser( - layers.LayerTypeIPv4, - &d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp, - ) - d.parser.IgnoreUnsupported = true - return d - }, + manager, err := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }) + require.NoError(t, err) + + manager.wgNetwork = &net.IPNet{ + IP: net.ParseIP("100.10.0.0"), + Mask: net.CIDRMask(16, 32), + } + + manager.udpTracker.Close() // Close the existing tracker + manager.udpTracker = conntrack.NewUDPTracker(200 * time.Millisecond) + manager.decoders = sync.Pool{ + New: func() any { + d := &decoder{ + decoded: []gopacket.LayerType{}, + } + d.parser = gopacket.NewDecodingLayerParser( + layers.LayerTypeIPv4, + &d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp, + ) + d.parser.IgnoreUnsupported = true + return d }, - udpTracker: conntrack.NewUDPTracker(200 * time.Millisecond), } - defer manager.udpTracker.Close() + defer func() { + require.NoError(t, manager.Reset(nil)) + }() // Set up packet parameters srcIP := net.ParseIP("100.10.0.1") @@ -540,7 +555,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) { DstPort: layers.UDPPort(dstPort), } - err := outboundUDP.SetNetworkLayerForChecksum(outboundIPv4) + err = outboundUDP.SetNetworkLayerForChecksum(outboundIPv4) require.NoError(t, err) outboundBuf := gopacket.NewSerializeBuffer() @@ -552,19 +567,20 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) { err = gopacket.SerializeLayers(outboundBuf, opts, outboundIPv4, outboundUDP, - gopacket.Payload([]byte("test")), + gopacket.Payload("test"), ) require.NoError(t, err) // Process outbound packet and verify connection tracking - drop := manager.processOutgoingHooks(outboundBuf.Bytes()) + drop := manager.DropOutgoing(outboundBuf.Bytes()) require.False(t, drop, "Initial outbound packet should not be dropped") // Verify connection was tracked conn, exists := manager.udpTracker.GetConnection(srcIP, srcPort, dstIP, dstPort) + require.True(t, exists, "Connection should be tracked after outbound packet") - require.True(t, conn.SourceIP.Equal(srcIP), "Source IP should match") - require.True(t, conn.DestIP.Equal(dstIP), "Destination IP should match") + require.True(t, conntrack.ValidateIPs(conntrack.MakeIPAddr(srcIP), conn.SourceIP), "Source IP should match") + require.True(t, conntrack.ValidateIPs(conntrack.MakeIPAddr(dstIP), conn.DestIP), "Destination IP should match") require.Equal(t, srcPort, conn.SourcePort, "Source port should match") require.Equal(t, dstPort, conn.DestPort, "Destination port should match") @@ -588,7 +604,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) { err = gopacket.SerializeLayers(inboundBuf, opts, inboundIPv4, inboundUDP, - gopacket.Payload([]byte("response")), + gopacket.Payload("response"), ) require.NoError(t, err) // Test roundtrip response handling over time @@ -689,7 +705,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) { err = gopacket.SerializeLayers(testBuf, opts, &testIPv4, &testUDP, - gopacket.Payload([]byte("response")), + gopacket.Payload("response"), ) require.NoError(t, err) From ed77f4869dfd0dd5034951be7926674002ae2859 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Sun, 22 Dec 2024 19:08:40 +0100 Subject: [PATCH 24/28] Fix lint --- .../firewall/uspfilter/conntrack/tcp_test.go | 12 ++ .../uspfilter/uspfilter_bench_test.go | 119 ++++++++++-------- 2 files changed, 78 insertions(+), 53 deletions(-) diff --git a/client/firewall/uspfilter/conntrack/tcp_test.go b/client/firewall/uspfilter/conntrack/tcp_test.go index 42a2f708a47..cf9876a626a 100644 --- a/client/firewall/uspfilter/conntrack/tcp_test.go +++ b/client/firewall/uspfilter/conntrack/tcp_test.go @@ -73,6 +73,8 @@ func TestTCPStateMachine(t *testing.T) { { name: "Normal Handshake", test: func(t *testing.T) { + t.Helper() + // Send initial SYN tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn) @@ -91,6 +93,8 @@ func TestTCPStateMachine(t *testing.T) { { name: "Normal Close", test: func(t *testing.T) { + t.Helper() + // First establish connection establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort) @@ -112,6 +116,8 @@ func TestTCPStateMachine(t *testing.T) { { name: "RST During Connection", test: func(t *testing.T) { + t.Helper() + // First establish connection establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort) @@ -121,12 +127,16 @@ func TestTCPStateMachine(t *testing.T) { // Verify connection is closed valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck) + t.Helper() + require.False(t, valid, "Data should be blocked after RST") }, }, { name: "Simultaneous Close", test: func(t *testing.T) { + t.Helper() + // First establish connection establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort) @@ -145,6 +155,8 @@ func TestTCPStateMachine(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + t.Helper() + tracker = NewTCPTracker(DefaultTCPTimeout) tt.test(t) }) diff --git a/client/firewall/uspfilter/uspfilter_bench_test.go b/client/firewall/uspfilter/uspfilter_bench_test.go index 4c44d076a8f..6cf68980443 100644 --- a/client/firewall/uspfilter/uspfilter_bench_test.go +++ b/client/firewall/uspfilter/uspfilter_bench_test.go @@ -10,6 +10,7 @@ import ( "github.com/google/gopacket" "github.com/google/gopacket/layers" + "github.com/stretchr/testify/require" fw "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" @@ -38,7 +39,7 @@ func generateRandomIPs(n int) []net.IP { return ips } -func generatePacket(srcIP, dstIP net.IP, srcPort, dstPort uint16, protocol layers.IPProtocol) []byte { +func generatePacket(b *testing.B, srcIP, dstIP net.IP, srcPort, dstPort uint16, protocol layers.IPProtocol) []byte { ipv4 := &layers.IPv4{ TTL: 64, Version: 4, @@ -55,20 +56,21 @@ func generatePacket(srcIP, dstIP net.IP, srcPort, dstPort uint16, protocol layer DstPort: layers.TCPPort(dstPort), SYN: true, } - tcp.SetNetworkLayerForChecksum(ipv4) + require.NoError(b, tcp.SetNetworkLayerForChecksum(ipv4)) transportLayer = tcp case layers.IPProtocolUDP: udp := &layers.UDP{ SrcPort: layers.UDPPort(srcPort), DstPort: layers.UDPPort(dstPort), } - udp.SetNetworkLayerForChecksum(ipv4) + require.NoError(b, udp.SetNetworkLayerForChecksum(ipv4)) transportLayer = udp } buf := gopacket.NewSerializeBuffer() opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true} - gopacket.SerializeLayers(buf, opts, ipv4, transportLayer, gopacket.Payload([]byte("test"))) + err := gopacket.SerializeLayers(buf, opts, ipv4, transportLayer, gopacket.Payload("test")) + require.NoError(b, err) return buf.Bytes() } @@ -86,8 +88,9 @@ func BenchmarkCoreFiltering(b *testing.B) { stateful: false, setupFunc: func(m *Manager) { // Single rule allowing all traffic - m.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolALL, nil, nil, + _, err := m.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolALL, nil, nil, fw.RuleDirectionIN, fw.ActionAccept, "", "allow all") + require.NoError(b, err) }, desc: "Baseline: Single 'allow all' rule without connection tracking", }, @@ -106,10 +109,11 @@ func BenchmarkCoreFiltering(b *testing.B) { // Add explicit rules matching return traffic pattern for i := 0; i < 1000; i++ { // Simulate realistic ruleset size ip := generateRandomIPs(1)[0] - m.AddPeerFiltering(ip, fw.ProtocolTCP, + _, err := m.AddPeerFiltering(ip, fw.ProtocolTCP, &fw.Port{Values: []int{1024 + i}}, &fw.Port{Values: []int{80}}, fw.RuleDirectionIN, fw.ActionAccept, "", "explicit return") + require.NoError(b, err) } }, desc: "Explicit rules matching return traffic patterns without state", @@ -119,8 +123,9 @@ func BenchmarkCoreFiltering(b *testing.B) { stateful: true, setupFunc: func(m *Manager) { // Add some basic rules but rely on state for established connections - m.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, nil, nil, + _, err := m.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, nil, nil, fw.RuleDirectionIN, fw.ActionDrop, "", "default drop") + require.NoError(b, err) }, desc: "Connection tracking with established connections", }, @@ -140,16 +145,18 @@ func BenchmarkCoreFiltering(b *testing.B) { b.Run(fmt.Sprintf("%s_%s", sc.name, proto.name), func(b *testing.B) { // Configure stateful/stateless mode if !sc.stateful { - os.Setenv("NB_DISABLE_CONNTRACK", "1") + require.NoError(b, os.Setenv("NB_DISABLE_CONNTRACK", "1")) } else { - os.Unsetenv("NB_DISABLE_CONNTRACK") + require.NoError(b, os.Setenv("NB_CONNTRACK_TIMEOUT", "1m")) } // Create manager and basic setup manager, _ := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, }) - defer manager.Reset(nil) + defer b.Cleanup(func() { + require.NoError(b, manager.Reset(nil)) + }) manager.wgNetwork = &net.IPNet{ IP: net.ParseIP("100.64.0.0"), @@ -165,8 +172,8 @@ func BenchmarkCoreFiltering(b *testing.B) { srcPort := uint16(1024 + b.N%60000) dstPort := uint16(80) - outbound := generatePacket(srcIP, dstIP, srcPort, dstPort, proto.proto) - inbound := generatePacket(dstIP, srcIP, dstPort, srcPort, proto.proto) + outbound := generatePacket(b, srcIP, dstIP, srcPort, dstPort, proto.proto) + inbound := generatePacket(b, dstIP, srcIP, dstPort, srcPort, proto.proto) // For stateful scenarios, establish the connection if sc.stateful { @@ -192,7 +199,9 @@ func BenchmarkStateScaling(b *testing.B) { manager, _ := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, }) - defer manager.Reset(nil) + b.Cleanup(func() { + require.NoError(b, manager.Reset(nil)) + }) manager.wgNetwork = &net.IPNet{ IP: net.ParseIP("100.64.0.0"), @@ -203,14 +212,14 @@ func BenchmarkStateScaling(b *testing.B) { srcIPs := generateRandomIPs(count) dstIPs := generateRandomIPs(count) for i := 0; i < count; i++ { - outbound := generatePacket(srcIPs[i], dstIPs[i], + outbound := generatePacket(b, srcIPs[i], dstIPs[i], uint16(1024+i), 80, layers.IPProtocolTCP) manager.processOutgoingHooks(outbound) } // Test packet - testOut := generatePacket(srcIPs[0], dstIPs[0], 1024, 80, layers.IPProtocolTCP) - testIn := generatePacket(dstIPs[0], srcIPs[0], 80, 1024, layers.IPProtocolTCP) + testOut := generatePacket(b, srcIPs[0], dstIPs[0], 1024, 80, layers.IPProtocolTCP) + testIn := generatePacket(b, dstIPs[0], srcIPs[0], 80, 1024, layers.IPProtocolTCP) // First establish our test connection manager.processOutgoingHooks(testOut) @@ -238,7 +247,9 @@ func BenchmarkEstablishmentOverhead(b *testing.B) { manager, _ := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, }) - defer manager.Reset(nil) + b.Cleanup(func() { + require.NoError(b, manager.Reset(nil)) + }) manager.wgNetwork = &net.IPNet{ IP: net.ParseIP("100.64.0.0"), @@ -247,8 +258,8 @@ func BenchmarkEstablishmentOverhead(b *testing.B) { srcIP := generateRandomIPs(1)[0] dstIP := generateRandomIPs(1)[0] - outbound := generatePacket(srcIP, dstIP, 1024, 80, layers.IPProtocolTCP) - inbound := generatePacket(dstIP, srcIP, 80, 1024, layers.IPProtocolTCP) + outbound := generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolTCP) + inbound := generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolTCP) if sc.established { manager.processOutgoingHooks(outbound) @@ -281,11 +292,11 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) { IP: net.ParseIP("100.64.0.0"), Mask: net.CIDRMask(10, 32), } - os.Setenv("NB_DISABLE_CONNTRACK", "1") + require.NoError(b, os.Setenv("NB_DISABLE_CONNTRACK", "1")) }, genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { - return generatePacket(srcIP, dstIP, 1024, 80, layers.IPProtocolTCP), - generatePacket(dstIP, srcIP, 80, 1024, layers.IPProtocolTCP) + return generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolTCP), + generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolTCP) }, desc: "Allow non-WG: TCP new connection", }, @@ -298,12 +309,12 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) { IP: net.ParseIP("100.64.0.0"), Mask: net.CIDRMask(10, 32), } - os.Setenv("NB_DISABLE_CONNTRACK", "1") + require.NoError(b, os.Setenv("NB_DISABLE_CONNTRACK", "1")) }, genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { // Generate packets with ACK flag for established connection - return generateTCPPacketWithFlags(srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck)), - generateTCPPacketWithFlags(dstIP, srcIP, 80, 1024, uint16(conntrack.TCPAck)) + return generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck)), + generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPAck)) }, desc: "Allow non-WG: TCP established connection", }, @@ -316,11 +327,11 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) { IP: net.ParseIP("100.64.0.0"), Mask: net.CIDRMask(10, 32), } - os.Setenv("NB_DISABLE_CONNTRACK", "1") + require.NoError(b, os.Setenv("NB_DISABLE_CONNTRACK", "1")) }, genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { - return generatePacket(srcIP, dstIP, 1024, 80, layers.IPProtocolUDP), - generatePacket(dstIP, srcIP, 80, 1024, layers.IPProtocolUDP) + return generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolUDP), + generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolUDP) }, desc: "Allow non-WG: UDP new connection", }, @@ -333,11 +344,11 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) { IP: net.ParseIP("100.64.0.0"), Mask: net.CIDRMask(10, 32), } - os.Setenv("NB_DISABLE_CONNTRACK", "1") + require.NoError(b, os.Setenv("NB_DISABLE_CONNTRACK", "1")) }, genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { - return generatePacket(srcIP, dstIP, 1024, 80, layers.IPProtocolUDP), - generatePacket(dstIP, srcIP, 80, 1024, layers.IPProtocolUDP) + return generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolUDP), + generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolUDP) }, desc: "Allow non-WG: UDP established connection", }, @@ -350,11 +361,11 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) { IP: net.ParseIP("0.0.0.0"), Mask: net.CIDRMask(0, 32), } - os.Unsetenv("NB_DISABLE_CONNTRACK") + require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK")) }, genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { - return generatePacket(srcIP, dstIP, 1024, 80, layers.IPProtocolTCP), - generatePacket(dstIP, srcIP, 80, 1024, layers.IPProtocolTCP) + return generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolTCP), + generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolTCP) }, desc: "Stateful: TCP new connection", }, @@ -367,12 +378,12 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) { IP: net.ParseIP("0.0.0.0"), Mask: net.CIDRMask(0, 32), } - os.Unsetenv("NB_DISABLE_CONNTRACK") + require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK")) }, genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { // Generate established TCP packets (ACK flag) - return generateTCPPacketWithFlags(srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck)), - generateTCPPacketWithFlags(dstIP, srcIP, 80, 1024, uint16(conntrack.TCPAck)) + return generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck)), + generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPAck)) }, desc: "Stateful: TCP established connection", }, @@ -385,12 +396,12 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) { IP: net.ParseIP("0.0.0.0"), Mask: net.CIDRMask(0, 32), } - os.Unsetenv("NB_DISABLE_CONNTRACK") + require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK")) }, genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { // Generate packets with PSH+ACK flags for data transfer - return generateTCPPacketWithFlags(srcIP, dstIP, 1024, 80, uint16(conntrack.TCPPush|conntrack.TCPAck)), - generateTCPPacketWithFlags(dstIP, srcIP, 80, 1024, uint16(conntrack.TCPPush|conntrack.TCPAck)) + return generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPPush|conntrack.TCPAck)), + generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPPush|conntrack.TCPAck)) }, desc: "Stateful: TCP post-handshake data transfer", }, @@ -403,11 +414,11 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) { IP: net.ParseIP("0.0.0.0"), Mask: net.CIDRMask(0, 32), } - os.Unsetenv("NB_DISABLE_CONNTRACK") + require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK")) }, genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { - return generatePacket(srcIP, dstIP, 1024, 80, layers.IPProtocolUDP), - generatePacket(dstIP, srcIP, 80, 1024, layers.IPProtocolUDP) + return generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolUDP), + generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolUDP) }, desc: "Stateful: UDP new connection", }, @@ -420,11 +431,11 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) { IP: net.ParseIP("0.0.0.0"), Mask: net.CIDRMask(0, 32), } - os.Unsetenv("NB_DISABLE_CONNTRACK") + require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK")) }, genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { - return generatePacket(srcIP, dstIP, 1024, 80, layers.IPProtocolUDP), - generatePacket(dstIP, srcIP, 80, 1024, layers.IPProtocolUDP) + return generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolUDP), + generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolUDP) }, desc: "Stateful: UDP established connection", }, @@ -435,7 +446,9 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) { manager, _ := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, }) - defer manager.Reset(nil) + b.Cleanup(func() { + require.NoError(b, manager.Reset(nil)) + }) // Setup scenario sc.setupFunc(manager) @@ -453,13 +466,13 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) { // For TCP post-handshake, simulate full handshake if sc.state == "post_handshake" { // SYN - syn := generateTCPPacketWithFlags(srcIP, dstIP, 1024, 80, uint16(conntrack.TCPSyn)) + syn := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPSyn)) manager.processOutgoingHooks(syn) // SYN-ACK - synack := generateTCPPacketWithFlags(dstIP, srcIP, 80, 1024, uint16(conntrack.TCPSyn|conntrack.TCPAck)) + synack := generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPSyn|conntrack.TCPAck)) manager.dropFilter(synack, manager.incomingRules) // ACK - ack := generateTCPPacketWithFlags(srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck)) + ack := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck)) manager.processOutgoingHooks(ack) } } @@ -473,7 +486,7 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) { } // generateTCPPacketWithFlags creates a TCP packet with specific flags -func generateTCPPacketWithFlags(srcIP, dstIP net.IP, srcPort, dstPort uint16, flags uint16) []byte { +func generateTCPPacketWithFlags(b *testing.B, srcIP, dstIP net.IP, srcPort, dstPort, flags uint16) []byte { ipv4 := &layers.IPv4{ TTL: 64, Version: 4, @@ -494,10 +507,10 @@ func generateTCPPacketWithFlags(srcIP, dstIP net.IP, srcPort, dstPort uint16, fl tcp.RST = (flags & uint16(conntrack.TCPRst)) != 0 tcp.FIN = (flags & uint16(conntrack.TCPFin)) != 0 - tcp.SetNetworkLayerForChecksum(ipv4) + require.NoError(b, tcp.SetNetworkLayerForChecksum(ipv4)) buf := gopacket.NewSerializeBuffer() opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true} - gopacket.SerializeLayers(buf, opts, ipv4, tcp, gopacket.Payload([]byte("test"))) + require.NoError(b, gopacket.SerializeLayers(buf, opts, ipv4, tcp, gopacket.Payload("test"))) return buf.Bytes() } From 802a9be0c3106727b51a83de278b1a669a4dd20d Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Sun, 22 Dec 2024 19:18:32 +0100 Subject: [PATCH 25/28] Fix remaining lint issues --- .../firewall/uspfilter/conntrack/tcp_test.go | 28 ++----------------- .../uspfilter/uspfilter_bench_test.go | 12 +++++--- 2 files changed, 10 insertions(+), 30 deletions(-) diff --git a/client/firewall/uspfilter/conntrack/tcp_test.go b/client/firewall/uspfilter/conntrack/tcp_test.go index cf9876a626a..c87ab93befa 100644 --- a/client/firewall/uspfilter/conntrack/tcp_test.go +++ b/client/firewall/uspfilter/conntrack/tcp_test.go @@ -166,6 +166,8 @@ func TestTCPStateMachine(t *testing.T) { // Helper to establish a TCP connection func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP net.IP, srcPort, dstPort uint16) { + t.Helper() + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn) valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck) @@ -174,32 +176,6 @@ func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP net.IP, tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck) } -// Benchmarks for the optimized implementation -func (t *TCPTracker) benchmarkTrackOutbound(b *testing.B) { - srcIP := net.ParseIP("192.168.1.1") - dstIP := net.ParseIP("192.168.1.2") - - b.ResetTimer() - for i := 0; i < b.N; i++ { - t.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn) - } -} - -func (t *TCPTracker) benchmarkIsValidInbound(b *testing.B) { - srcIP := net.ParseIP("192.168.1.1") - dstIP := net.ParseIP("192.168.1.2") - - // Pre-populate some connections - for i := 0; i < 1000; i++ { - t.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn) - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - t.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), TCPAck) - } -} - func BenchmarkTCPTracker(b *testing.B) { b.Run("TrackOutbound", func(b *testing.B) { tracker := NewTCPTracker(DefaultTCPTimeout) diff --git a/client/firewall/uspfilter/uspfilter_bench_test.go b/client/firewall/uspfilter/uspfilter_bench_test.go index 6cf68980443..cb732641e2f 100644 --- a/client/firewall/uspfilter/uspfilter_bench_test.go +++ b/client/firewall/uspfilter/uspfilter_bench_test.go @@ -40,6 +40,8 @@ func generateRandomIPs(n int) []net.IP { } func generatePacket(b *testing.B, srcIP, dstIP net.IP, srcPort, dstPort uint16, protocol layers.IPProtocol) []byte { + b.Helper() + ipv4 := &layers.IPv4{ TTL: 64, Version: 4, @@ -292,7 +294,7 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) { IP: net.ParseIP("100.64.0.0"), Mask: net.CIDRMask(10, 32), } - require.NoError(b, os.Setenv("NB_DISABLE_CONNTRACK", "1")) + b.Setenv("NB_DISABLE_CONNTRACK", "1") }, genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { return generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolTCP), @@ -309,7 +311,7 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) { IP: net.ParseIP("100.64.0.0"), Mask: net.CIDRMask(10, 32), } - require.NoError(b, os.Setenv("NB_DISABLE_CONNTRACK", "1")) + b.Setenv("NB_DISABLE_CONNTRACK", "1") }, genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { // Generate packets with ACK flag for established connection @@ -327,7 +329,7 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) { IP: net.ParseIP("100.64.0.0"), Mask: net.CIDRMask(10, 32), } - require.NoError(b, os.Setenv("NB_DISABLE_CONNTRACK", "1")) + b.Setenv("NB_DISABLE_CONNTRACK", "1") }, genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { return generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolUDP), @@ -344,7 +346,7 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) { IP: net.ParseIP("100.64.0.0"), Mask: net.CIDRMask(10, 32), } - require.NoError(b, os.Setenv("NB_DISABLE_CONNTRACK", "1")) + b.Setenv("NB_DISABLE_CONNTRACK", "1") }, genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { return generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolUDP), @@ -487,6 +489,8 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) { // generateTCPPacketWithFlags creates a TCP packet with specific flags func generateTCPPacketWithFlags(b *testing.B, srcIP, dstIP net.IP, srcPort, dstPort, flags uint16) []byte { + b.Helper() + ipv4 := &layers.IPv4{ TTL: 64, Version: 4, From 84df403928e7e19116d8d9997e98c8ac66bf539b Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Sun, 22 Dec 2024 20:56:54 +0100 Subject: [PATCH 26/28] Properly use sync pools (pointers) --- client/firewall/uspfilter/conntrack/common.go | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/client/firewall/uspfilter/conntrack/common.go b/client/firewall/uspfilter/conntrack/common.go index 11d6599315d..a4b1971bf6e 100644 --- a/client/firewall/uspfilter/conntrack/common.go +++ b/client/firewall/uspfilter/conntrack/common.go @@ -110,7 +110,8 @@ func NewPreallocatedIPs() *PreallocatedIPs { return &PreallocatedIPs{ Pool: sync.Pool{ New: func() interface{} { - return make(net.IP, 16) + ip := make(net.IP, 16) + return &ip }, }, } @@ -118,12 +119,12 @@ func NewPreallocatedIPs() *PreallocatedIPs { // Get retrieves an IP from the pool func (p *PreallocatedIPs) Get() net.IP { - return p.Pool.Get().(net.IP) + return *p.Pool.Get().(*net.IP) } // Put returns an IP to the pool func (p *PreallocatedIPs) Put(ip net.IP) { - p.Pool.Put(ip) + p.Pool.Put(&ip) } // copyIP copies an IP address efficiently From 9e00ea7481631a5112ba4d233b3368447d0b8ea0 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Sun, 22 Dec 2024 20:57:06 +0100 Subject: [PATCH 27/28] Add TCP RST test --- .../firewall/uspfilter/conntrack/tcp_test.go | 58 +++++++++++++++++++ 1 file changed, 58 insertions(+) diff --git a/client/firewall/uspfilter/conntrack/tcp_test.go b/client/firewall/uspfilter/conntrack/tcp_test.go index c87ab93befa..3933c888943 100644 --- a/client/firewall/uspfilter/conntrack/tcp_test.go +++ b/client/firewall/uspfilter/conntrack/tcp_test.go @@ -164,6 +164,64 @@ func TestTCPStateMachine(t *testing.T) { }) } +func TestRSTHandling(t *testing.T) { + tracker := NewTCPTracker(DefaultTCPTimeout) + defer tracker.Close() + + srcIP := net.ParseIP("100.64.0.1") + dstIP := net.ParseIP("100.64.0.2") + srcPort := uint16(12345) + dstPort := uint16(80) + + tests := []struct { + name string + setupState func() + sendRST func() + wantValid bool + desc string + }{ + { + name: "RST in established", + setupState: func() { + // Establish connection first + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn) + tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck) + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck) + }, + sendRST: func() { + tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst) + }, + wantValid: true, + desc: "Should accept RST for established connection", + }, + { + name: "RST without connection", + setupState: func() {}, + sendRST: func() { + tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst) + }, + wantValid: false, + desc: "Should reject RST without connection", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.setupState() + tt.sendRST() + + // Verify connection state is as expected + key := makeConnKey(srcIP, dstIP, srcPort, dstPort) + conn := tracker.connections[key] + if tt.wantValid { + require.NotNil(t, conn) + require.Equal(t, TCPStateClosed, conn.State) + require.False(t, conn.IsEstablished()) + } + }) + } +} + // Helper to establish a TCP connection func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP net.IP, srcPort, dstPort uint16) { t.Helper() From 6d0bf6350b230771106d561f1530ba59b8fb2907 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Sun, 22 Dec 2024 22:15:58 +0100 Subject: [PATCH 28/28] Add more comparison benchmarks (stateful vs stateless) --- .../uspfilter/uspfilter_bench_test.go | 478 ++++++++++++++++++ 1 file changed, 478 insertions(+) diff --git a/client/firewall/uspfilter/uspfilter_bench_test.go b/client/firewall/uspfilter/uspfilter_bench_test.go index cb732641e2f..3c661e71c70 100644 --- a/client/firewall/uspfilter/uspfilter_bench_test.go +++ b/client/firewall/uspfilter/uspfilter_bench_test.go @@ -487,6 +487,484 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) { } } +var scenarios = []struct { + name string + stateful bool // Whether conntrack is enabled + rules bool // Whether to add return traffic rules + routed bool // Whether to test routed network traffic + connCount int // Number of concurrent connections + desc string +}{ + { + name: "stateless_with_rules_100conns", + stateful: false, + rules: true, + routed: false, + connCount: 100, + desc: "Pure stateless with return traffic rules, 100 conns", + }, + { + name: "stateless_with_rules_1000conns", + stateful: false, + rules: true, + routed: false, + connCount: 1000, + desc: "Pure stateless with return traffic rules, 1000 conns", + }, + { + name: "stateful_no_rules_100conns", + stateful: true, + rules: false, + routed: false, + connCount: 100, + desc: "Pure stateful tracking without rules, 100 conns", + }, + { + name: "stateful_no_rules_1000conns", + stateful: true, + rules: false, + routed: false, + connCount: 1000, + desc: "Pure stateful tracking without rules, 1000 conns", + }, + { + name: "stateful_with_rules_100conns", + stateful: true, + rules: true, + routed: false, + connCount: 100, + desc: "Combined stateful + rules (current implementation), 100 conns", + }, + { + name: "stateful_with_rules_1000conns", + stateful: true, + rules: true, + routed: false, + connCount: 1000, + desc: "Combined stateful + rules (current implementation), 1000 conns", + }, + { + name: "routed_network_100conns", + stateful: true, + rules: false, + routed: true, + connCount: 100, + desc: "Routed network traffic (non-WG), 100 conns", + }, + { + name: "routed_network_1000conns", + stateful: true, + rules: false, + routed: true, + connCount: 1000, + desc: "Routed network traffic (non-WG), 1000 conns", + }, +} + +// BenchmarkLongLivedConnections tests performance with realistic TCP traffic patterns +func BenchmarkLongLivedConnections(b *testing.B) { + for _, sc := range scenarios { + b.Run(sc.name, func(b *testing.B) { + // Configure stateful/stateless mode + if !sc.stateful { + b.Setenv("NB_DISABLE_CONNTRACK", "1") + } else { + require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK")) + } + + manager, _ := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }) + defer b.Cleanup(func() { + require.NoError(b, manager.Reset(nil)) + }) + + manager.SetNetwork(&net.IPNet{ + IP: net.ParseIP("100.64.0.0"), + Mask: net.CIDRMask(10, 32), + }) + + // Setup initial state based on scenario + if sc.rules { + // Single rule to allow all return traffic from port 80 + _, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, + &fw.Port{Values: []int{80}}, + nil, + fw.RuleDirectionIN, fw.ActionAccept, "", "return traffic") + require.NoError(b, err) + } + + // Generate IPs for connections + srcIPs := make([]net.IP, sc.connCount) + dstIPs := make([]net.IP, sc.connCount) + + for i := 0; i < sc.connCount; i++ { + if sc.routed { + srcIPs[i] = net.IPv4(192, 168, 1, byte(2+(i%250))).To4() + dstIPs[i] = net.IPv4(8, 8, byte((i/250)%255), byte(2+(i%250))).To4() + } else { + srcIPs[i] = generateRandomIPs(1)[0] + dstIPs[i] = generateRandomIPs(1)[0] + } + } + + // Create established connections + for i := 0; i < sc.connCount; i++ { + // Initial SYN + syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], + uint16(1024+i), 80, uint16(conntrack.TCPSyn)) + manager.processOutgoingHooks(syn) + + // SYN-ACK + synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i], + 80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck)) + manager.dropFilter(synack, manager.incomingRules) + + // ACK + ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], + uint16(1024+i), 80, uint16(conntrack.TCPAck)) + manager.processOutgoingHooks(ack) + } + + // Prepare test packets simulating bidirectional traffic + inPackets := make([][]byte, sc.connCount) + outPackets := make([][]byte, sc.connCount) + for i := 0; i < sc.connCount; i++ { + // Server -> Client (inbound) + inPackets[i] = generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i], + 80, uint16(1024+i), uint16(conntrack.TCPPush|conntrack.TCPAck)) + // Client -> Server (outbound) + outPackets[i] = generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], + uint16(1024+i), 80, uint16(conntrack.TCPPush|conntrack.TCPAck)) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + connIdx := i % sc.connCount + + // Simulate bidirectional traffic + // First outbound data + manager.processOutgoingHooks(outPackets[connIdx]) + // Then inbound response - this is what we're actually measuring + manager.dropFilter(inPackets[connIdx], manager.incomingRules) + } + }) + } +} + +// BenchmarkShortLivedConnections tests performance with many short-lived connections +func BenchmarkShortLivedConnections(b *testing.B) { + for _, sc := range scenarios { + b.Run(sc.name, func(b *testing.B) { + // Configure stateful/stateless mode + if !sc.stateful { + b.Setenv("NB_DISABLE_CONNTRACK", "1") + } else { + require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK")) + } + + manager, _ := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }) + defer b.Cleanup(func() { + require.NoError(b, manager.Reset(nil)) + }) + + manager.SetNetwork(&net.IPNet{ + IP: net.ParseIP("100.64.0.0"), + Mask: net.CIDRMask(10, 32), + }) + + // Setup initial state based on scenario + if sc.rules { + // Single rule to allow all return traffic from port 80 + _, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, + &fw.Port{Values: []int{80}}, + nil, + fw.RuleDirectionIN, fw.ActionAccept, "", "return traffic") + require.NoError(b, err) + } + + // Generate IPs for connections + srcIPs := make([]net.IP, sc.connCount) + dstIPs := make([]net.IP, sc.connCount) + + for i := 0; i < sc.connCount; i++ { + if sc.routed { + srcIPs[i] = net.IPv4(192, 168, 1, byte(2+(i%250))).To4() + dstIPs[i] = net.IPv4(8, 8, byte((i/250)%255), byte(2+(i%250))).To4() + } else { + srcIPs[i] = generateRandomIPs(1)[0] + dstIPs[i] = generateRandomIPs(1)[0] + } + } + + // Create packet patterns for a complete HTTP-like short connection: + // 1. Initial handshake (SYN, SYN-ACK, ACK) + // 2. HTTP Request (PSH+ACK from client) + // 3. HTTP Response (PSH+ACK from server) + // 4. Connection teardown (FIN+ACK, ACK, FIN+ACK, ACK) + type connPackets struct { + syn []byte + synAck []byte + ack []byte + request []byte + response []byte + finClient []byte + ackServer []byte + finServer []byte + ackClient []byte + } + + // Generate all possible connection patterns + patterns := make([]connPackets, sc.connCount) + for i := 0; i < sc.connCount; i++ { + patterns[i] = connPackets{ + // Handshake + syn: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], + uint16(1024+i), 80, uint16(conntrack.TCPSyn)), + synAck: generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i], + 80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck)), + ack: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], + uint16(1024+i), 80, uint16(conntrack.TCPAck)), + + // Data transfer + request: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], + uint16(1024+i), 80, uint16(conntrack.TCPPush|conntrack.TCPAck)), + response: generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i], + 80, uint16(1024+i), uint16(conntrack.TCPPush|conntrack.TCPAck)), + + // Connection teardown + finClient: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], + uint16(1024+i), 80, uint16(conntrack.TCPFin|conntrack.TCPAck)), + ackServer: generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i], + 80, uint16(1024+i), uint16(conntrack.TCPAck)), + finServer: generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i], + 80, uint16(1024+i), uint16(conntrack.TCPFin|conntrack.TCPAck)), + ackClient: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], + uint16(1024+i), 80, uint16(conntrack.TCPAck)), + } + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + // Each iteration creates a new short-lived connection + connIdx := i % sc.connCount + p := patterns[connIdx] + + // Connection establishment + manager.processOutgoingHooks(p.syn) + manager.dropFilter(p.synAck, manager.incomingRules) + manager.processOutgoingHooks(p.ack) + + // Data transfer + manager.processOutgoingHooks(p.request) + manager.dropFilter(p.response, manager.incomingRules) + + // Connection teardown + manager.processOutgoingHooks(p.finClient) + manager.dropFilter(p.ackServer, manager.incomingRules) + manager.dropFilter(p.finServer, manager.incomingRules) + manager.processOutgoingHooks(p.ackClient) + } + }) + } +} + +// BenchmarkParallelLongLivedConnections tests performance with realistic TCP traffic patterns in parallel +func BenchmarkParallelLongLivedConnections(b *testing.B) { + for _, sc := range scenarios { + b.Run(sc.name, func(b *testing.B) { + // Configure stateful/stateless mode + if !sc.stateful { + b.Setenv("NB_DISABLE_CONNTRACK", "1") + } else { + require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK")) + } + + manager, _ := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }) + defer b.Cleanup(func() { + require.NoError(b, manager.Reset(nil)) + }) + + manager.SetNetwork(&net.IPNet{ + IP: net.ParseIP("100.64.0.0"), + Mask: net.CIDRMask(10, 32), + }) + + // Setup initial state based on scenario + if sc.rules { + _, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, + &fw.Port{Values: []int{80}}, + nil, + fw.RuleDirectionIN, fw.ActionAccept, "", "return traffic") + require.NoError(b, err) + } + + // Generate IPs for connections + srcIPs := make([]net.IP, sc.connCount) + dstIPs := make([]net.IP, sc.connCount) + + for i := 0; i < sc.connCount; i++ { + if sc.routed { + srcIPs[i] = net.IPv4(192, 168, 1, byte(2+(i%250))).To4() + dstIPs[i] = net.IPv4(8, 8, byte((i/250)%255), byte(2+(i%250))).To4() + } else { + srcIPs[i] = generateRandomIPs(1)[0] + dstIPs[i] = generateRandomIPs(1)[0] + } + } + + // Create established connections + for i := 0; i < sc.connCount; i++ { + syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], + uint16(1024+i), 80, uint16(conntrack.TCPSyn)) + manager.processOutgoingHooks(syn) + + synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i], + 80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck)) + manager.dropFilter(synack, manager.incomingRules) + + ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], + uint16(1024+i), 80, uint16(conntrack.TCPAck)) + manager.processOutgoingHooks(ack) + } + + // Pre-generate test packets + inPackets := make([][]byte, sc.connCount) + outPackets := make([][]byte, sc.connCount) + for i := 0; i < sc.connCount; i++ { + inPackets[i] = generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i], + 80, uint16(1024+i), uint16(conntrack.TCPPush|conntrack.TCPAck)) + outPackets[i] = generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], + uint16(1024+i), 80, uint16(conntrack.TCPPush|conntrack.TCPAck)) + } + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + // Each goroutine gets its own counter to distribute load + counter := 0 + for pb.Next() { + connIdx := counter % sc.connCount + counter++ + + // Simulate bidirectional traffic + manager.processOutgoingHooks(outPackets[connIdx]) + manager.dropFilter(inPackets[connIdx], manager.incomingRules) + } + }) + }) + } +} + +// BenchmarkParallelShortLivedConnections tests performance with many short-lived connections in parallel +func BenchmarkParallelShortLivedConnections(b *testing.B) { + for _, sc := range scenarios { + b.Run(sc.name, func(b *testing.B) { + // Configure stateful/stateless mode + if !sc.stateful { + b.Setenv("NB_DISABLE_CONNTRACK", "1") + } else { + require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK")) + } + + manager, _ := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }) + defer b.Cleanup(func() { + require.NoError(b, manager.Reset(nil)) + }) + + manager.SetNetwork(&net.IPNet{ + IP: net.ParseIP("100.64.0.0"), + Mask: net.CIDRMask(10, 32), + }) + + if sc.rules { + _, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, + &fw.Port{Values: []int{80}}, + nil, + fw.RuleDirectionIN, fw.ActionAccept, "", "return traffic") + require.NoError(b, err) + } + + // Generate IPs and pre-generate all packet patterns + srcIPs := make([]net.IP, sc.connCount) + dstIPs := make([]net.IP, sc.connCount) + for i := 0; i < sc.connCount; i++ { + if sc.routed { + srcIPs[i] = net.IPv4(192, 168, 1, byte(2+(i%250))).To4() + dstIPs[i] = net.IPv4(8, 8, byte((i/250)%255), byte(2+(i%250))).To4() + } else { + srcIPs[i] = generateRandomIPs(1)[0] + dstIPs[i] = generateRandomIPs(1)[0] + } + } + + type connPackets struct { + syn []byte + synAck []byte + ack []byte + request []byte + response []byte + finClient []byte + ackServer []byte + finServer []byte + ackClient []byte + } + + patterns := make([]connPackets, sc.connCount) + for i := 0; i < sc.connCount; i++ { + patterns[i] = connPackets{ + syn: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], + uint16(1024+i), 80, uint16(conntrack.TCPSyn)), + synAck: generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i], + 80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck)), + ack: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], + uint16(1024+i), 80, uint16(conntrack.TCPAck)), + request: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], + uint16(1024+i), 80, uint16(conntrack.TCPPush|conntrack.TCPAck)), + response: generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i], + 80, uint16(1024+i), uint16(conntrack.TCPPush|conntrack.TCPAck)), + finClient: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], + uint16(1024+i), 80, uint16(conntrack.TCPFin|conntrack.TCPAck)), + ackServer: generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i], + 80, uint16(1024+i), uint16(conntrack.TCPAck)), + finServer: generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i], + 80, uint16(1024+i), uint16(conntrack.TCPFin|conntrack.TCPAck)), + ackClient: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], + uint16(1024+i), 80, uint16(conntrack.TCPAck)), + } + } + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + counter := 0 + for pb.Next() { + connIdx := counter % sc.connCount + counter++ + p := patterns[connIdx] + + // Full connection lifecycle + manager.processOutgoingHooks(p.syn) + manager.dropFilter(p.synAck, manager.incomingRules) + manager.processOutgoingHooks(p.ack) + + manager.processOutgoingHooks(p.request) + manager.dropFilter(p.response, manager.incomingRules) + + manager.processOutgoingHooks(p.finClient) + manager.dropFilter(p.ackServer, manager.incomingRules) + manager.dropFilter(p.finServer, manager.incomingRules) + manager.processOutgoingHooks(p.ackClient) + } + }) + }) + } +} + // generateTCPPacketWithFlags creates a TCP packet with specific flags func generateTCPPacketWithFlags(b *testing.B, srcIP, dstIP net.IP, srcPort, dstPort, flags uint16) []byte { b.Helper()