From 2c6e780b60b58a2ab74d7d3b9fe0d8f9ae71a30a Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Sun, 22 Dec 2024 13:42:07 +0100 Subject: [PATCH] Fix udp test --- client/firewall/uspfilter/conntrack/common.go | 8 +- .../uspfilter/conntrack/common_test.go | 6 +- client/firewall/uspfilter/conntrack/icmp.go | 8 +- client/firewall/uspfilter/conntrack/udp.go | 35 +---- .../firewall/uspfilter/conntrack/udp_test.go | 2 +- client/firewall/uspfilter/uspfilter_test.go | 124 ++++++++++-------- 6 files changed, 85 insertions(+), 98 deletions(-) diff --git a/client/firewall/uspfilter/conntrack/common.go b/client/firewall/uspfilter/conntrack/common.go index 079a0175f2f..11d6599315d 100644 --- a/client/firewall/uspfilter/conntrack/common.go +++ b/client/firewall/uspfilter/conntrack/common.go @@ -50,8 +50,8 @@ func (b *BaseConnTrack) timeoutExceeded(timeout time.Duration) bool { // IPAddr is a fixed-size IP address to avoid allocations type IPAddr [16]byte -// makeIPAddr creates an IPAddr from net.IP -func makeIPAddr(ip net.IP) (addr IPAddr) { +// MakeIPAddr creates an IPAddr from net.IP +func MakeIPAddr(ip net.IP) (addr IPAddr) { // Optimization: check for v4 first as it's more common if ip4 := ip.To4(); ip4 != nil { copy(addr[12:], ip4) @@ -72,8 +72,8 @@ type ConnKey struct { // makeConnKey creates a connection key func makeConnKey(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) ConnKey { return ConnKey{ - SrcIP: makeIPAddr(srcIP), - DstIP: makeIPAddr(dstIP), + SrcIP: MakeIPAddr(srcIP), + DstIP: MakeIPAddr(dstIP), SrcPort: srcPort, DstPort: dstPort, } diff --git a/client/firewall/uspfilter/conntrack/common_test.go b/client/firewall/uspfilter/conntrack/common_test.go index a337f649b47..72d006def57 100644 --- a/client/firewall/uspfilter/conntrack/common_test.go +++ b/client/firewall/uspfilter/conntrack/common_test.go @@ -6,18 +6,18 @@ import ( ) func BenchmarkIPOperations(b *testing.B) { - b.Run("makeIPAddr", func(b *testing.B) { + b.Run("MakeIPAddr", func(b *testing.B) { ip := net.ParseIP("192.168.1.1") b.ResetTimer() for i := 0; i < b.N; i++ { - _ = makeIPAddr(ip) + _ = MakeIPAddr(ip) } }) b.Run("ValidateIPs", func(b *testing.B) { ip1 := net.ParseIP("192.168.1.1") ip2 := net.ParseIP("192.168.1.1") - addr := makeIPAddr(ip1) + addr := MakeIPAddr(ip1) b.ResetTimer() for i := 0; i < b.N; i++ { _ = ValidateIPs(addr, ip2) diff --git a/client/firewall/uspfilter/conntrack/icmp.go b/client/firewall/uspfilter/conntrack/icmp.go index 4cab4cb0e72..e0a971678f1 100644 --- a/client/firewall/uspfilter/conntrack/icmp.go +++ b/client/firewall/uspfilter/conntrack/icmp.go @@ -116,8 +116,8 @@ func (t *ICMPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, id uint16, seq } return conn.IsEstablished() && - ValidateIPs(makeIPAddr(srcIP), conn.DestIP) && - ValidateIPs(makeIPAddr(dstIP), conn.SourceIP) && + ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) && + ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) && conn.ID == id && conn.Sequence == seq } @@ -162,8 +162,8 @@ func (t *ICMPTracker) Close() { // makeICMPKey creates an ICMP connection key func makeICMPKey(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) ICMPConnKey { return ICMPConnKey{ - SrcIP: makeIPAddr(srcIP), - DstIP: makeIPAddr(dstIP), + SrcIP: MakeIPAddr(srcIP), + DstIP: MakeIPAddr(dstIP), ID: id, Sequence: seq, } diff --git a/client/firewall/uspfilter/conntrack/udp.go b/client/firewall/uspfilter/conntrack/udp.go index 4d55ec0df83..a969a4e8425 100644 --- a/client/firewall/uspfilter/conntrack/udp.go +++ b/client/firewall/uspfilter/conntrack/udp.go @@ -93,8 +93,8 @@ func (t *UDPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, } return conn.IsEstablished() && - ValidateIPs(makeIPAddr(srcIP), conn.DestIP) && - ValidateIPs(makeIPAddr(dstIP), conn.SourceIP) && + ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) && + ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) && conn.DestPort == srcPort && conn.SourcePort == dstPort } @@ -149,39 +149,10 @@ func (t *UDPTracker) GetConnection(srcIP net.IP, srcPort uint16, dstIP net.IP, d return nil, false } - // Create a copy with new IP allocations - srcIPCopy := t.ipPool.Get() - dstIPCopy := t.ipPool.Get() - copyIP(srcIPCopy, conn.SourceIP) - copyIP(dstIPCopy, conn.DestIP) - - connCopy := &UDPConnTrack{ - BaseConnTrack: BaseConnTrack{ - SourceIP: srcIPCopy, - DestIP: dstIPCopy, - SourcePort: conn.SourcePort, - DestPort: conn.DestPort, - }, - } - connCopy.lastSeen.Store(conn.lastSeen.Load()) - connCopy.established.Store(conn.IsEstablished()) - - return connCopy, true + return conn, true } // Timeout returns the configured timeout duration for the tracker func (t *UDPTracker) Timeout() time.Duration { return t.timeout } - -func makeKey(srcIP net.IP, srcPort uint16, dstIP net.IP, dstPort uint16) ConnKey { - var srcAddr, dstAddr [16]byte - copy(srcAddr[:], srcIP.To16()) // Ensure 16-byte representation - copy(dstAddr[:], dstIP.To16()) - return ConnKey{ - SrcIP: srcAddr, - SrcPort: srcPort, - DstIP: dstAddr, - DstPort: dstPort, - } -} diff --git a/client/firewall/uspfilter/conntrack/udp_test.go b/client/firewall/uspfilter/conntrack/udp_test.go index 1a8afc21a01..19e607545f5 100644 --- a/client/firewall/uspfilter/conntrack/udp_test.go +++ b/client/firewall/uspfilter/conntrack/udp_test.go @@ -51,7 +51,7 @@ func TestUDPTracker_TrackOutbound(t *testing.T) { tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort) // Verify connection was tracked - key := makeKey(srcIP, srcPort, dstIP, dstPort) + key := makeConnKey(srcIP, dstIP, srcPort, dstPort) conn, exists := tracker.connections[key] require.True(t, exists) assert.True(t, conn.SourceIP.Equal(srcIP)) diff --git a/client/firewall/uspfilter/uspfilter_test.go b/client/firewall/uspfilter/uspfilter_test.go index ea78a013abc..d3563e6f251 100644 --- a/client/firewall/uspfilter/uspfilter_test.go +++ b/client/firewall/uspfilter/uspfilter_test.go @@ -187,10 +187,10 @@ func TestAddUDPPacketHook(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - manager := &Manager{ - incomingRules: map[string]RuleSet{}, - outgoingRules: map[string]RuleSet{}, - } + manager, err := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }) + require.NoError(t, err) manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook) @@ -315,7 +315,7 @@ func TestNotMatchByIP(t *testing.T) { t.Errorf("failed to set network layer for checksum: %v", err) return } - payload := gopacket.Payload([]byte("test")) + payload := gopacket.Payload("test") buf := gopacket.NewSerializeBuffer() opts := gopacket.SerializeOptions{ @@ -350,6 +350,9 @@ func TestRemovePacketHook(t *testing.T) { if err != nil { t.Fatalf("Failed to create Manager: %s", err) } + defer func() { + require.NoError(t, manager.Reset(nil)) + }() // Add a UDP packet hook hookFunc := func(data []byte) bool { return true } @@ -387,26 +390,33 @@ func TestRemovePacketHook(t *testing.T) { } func TestProcessOutgoingHooks(t *testing.T) { - manager := &Manager{ - outgoingRules: map[string]RuleSet{}, - wgNetwork: &net.IPNet{ - IP: net.ParseIP("100.10.0.0"), - Mask: net.CIDRMask(16, 32), - }, - decoders: sync.Pool{ - New: func() any { - d := &decoder{ - decoded: []gopacket.LayerType{}, - } - d.parser = gopacket.NewDecodingLayerParser( - layers.LayerTypeIPv4, - &d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp, - ) - d.parser.IgnoreUnsupported = true - return d - }, + manager, err := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }) + require.NoError(t, err) + + manager.wgNetwork = &net.IPNet{ + IP: net.ParseIP("100.10.0.0"), + Mask: net.CIDRMask(16, 32), + } + manager.udpTracker.Close() + manager.udpTracker = conntrack.NewUDPTracker(100 * time.Millisecond) + defer func() { + require.NoError(t, manager.Reset(nil)) + }() + + manager.decoders = sync.Pool{ + New: func() any { + d := &decoder{ + decoded: []gopacket.LayerType{}, + } + d.parser = gopacket.NewDecodingLayerParser( + layers.LayerTypeIPv4, + &d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp, + ) + d.parser.IgnoreUnsupported = true + return d }, - udpTracker: conntrack.NewUDPTracker(100 * time.Millisecond), } hookCalled := false @@ -434,9 +444,9 @@ func TestProcessOutgoingHooks(t *testing.T) { DstPort: 53, } - err := udp.SetNetworkLayerForChecksum(ipv4) + err = udp.SetNetworkLayerForChecksum(ipv4) require.NoError(t, err) - payload := gopacket.Payload([]byte("test")) + payload := gopacket.Payload("test") buf := gopacket.NewSerializeBuffer() opts := gopacket.SerializeOptions{ @@ -497,29 +507,34 @@ func TestUSPFilterCreatePerformance(t *testing.T) { } func TestStatefulFirewall_UDPTracking(t *testing.T) { - manager := &Manager{ - outgoingRules: map[string]RuleSet{}, - incomingRules: map[string]RuleSet{}, - wgNetwork: &net.IPNet{ - IP: net.ParseIP("100.10.0.0"), - Mask: net.CIDRMask(16, 32), - }, - decoders: sync.Pool{ - New: func() any { - d := &decoder{ - decoded: []gopacket.LayerType{}, - } - d.parser = gopacket.NewDecodingLayerParser( - layers.LayerTypeIPv4, - &d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp, - ) - d.parser.IgnoreUnsupported = true - return d - }, + manager, err := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }) + require.NoError(t, err) + + manager.wgNetwork = &net.IPNet{ + IP: net.ParseIP("100.10.0.0"), + Mask: net.CIDRMask(16, 32), + } + + manager.udpTracker.Close() // Close the existing tracker + manager.udpTracker = conntrack.NewUDPTracker(200 * time.Millisecond) + manager.decoders = sync.Pool{ + New: func() any { + d := &decoder{ + decoded: []gopacket.LayerType{}, + } + d.parser = gopacket.NewDecodingLayerParser( + layers.LayerTypeIPv4, + &d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp, + ) + d.parser.IgnoreUnsupported = true + return d }, - udpTracker: conntrack.NewUDPTracker(200 * time.Millisecond), } - defer manager.udpTracker.Close() + defer func() { + require.NoError(t, manager.Reset(nil)) + }() // Set up packet parameters srcIP := net.ParseIP("100.10.0.1") @@ -540,7 +555,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) { DstPort: layers.UDPPort(dstPort), } - err := outboundUDP.SetNetworkLayerForChecksum(outboundIPv4) + err = outboundUDP.SetNetworkLayerForChecksum(outboundIPv4) require.NoError(t, err) outboundBuf := gopacket.NewSerializeBuffer() @@ -552,19 +567,20 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) { err = gopacket.SerializeLayers(outboundBuf, opts, outboundIPv4, outboundUDP, - gopacket.Payload([]byte("test")), + gopacket.Payload("test"), ) require.NoError(t, err) // Process outbound packet and verify connection tracking - drop := manager.processOutgoingHooks(outboundBuf.Bytes()) + drop := manager.DropOutgoing(outboundBuf.Bytes()) require.False(t, drop, "Initial outbound packet should not be dropped") // Verify connection was tracked conn, exists := manager.udpTracker.GetConnection(srcIP, srcPort, dstIP, dstPort) + require.True(t, exists, "Connection should be tracked after outbound packet") - require.True(t, conn.SourceIP.Equal(srcIP), "Source IP should match") - require.True(t, conn.DestIP.Equal(dstIP), "Destination IP should match") + require.True(t, conntrack.ValidateIPs(conntrack.MakeIPAddr(srcIP), conn.SourceIP), "Source IP should match") + require.True(t, conntrack.ValidateIPs(conntrack.MakeIPAddr(dstIP), conn.DestIP), "Destination IP should match") require.Equal(t, srcPort, conn.SourcePort, "Source port should match") require.Equal(t, dstPort, conn.DestPort, "Destination port should match") @@ -588,7 +604,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) { err = gopacket.SerializeLayers(inboundBuf, opts, inboundIPv4, inboundUDP, - gopacket.Payload([]byte("response")), + gopacket.Payload("response"), ) require.NoError(t, err) // Test roundtrip response handling over time @@ -689,7 +705,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) { err = gopacket.SerializeLayers(testBuf, opts, &testIPv4, &testUDP, - gopacket.Payload([]byte("response")), + gopacket.Payload("response"), ) require.NoError(t, err)