Skip to content

Commit

Permalink
Fix udp test
Browse files Browse the repository at this point in the history
  • Loading branch information
lixmal committed Dec 22, 2024
1 parent 8cfdc87 commit 2c6e780
Show file tree
Hide file tree
Showing 6 changed files with 85 additions and 98 deletions.
8 changes: 4 additions & 4 deletions client/firewall/uspfilter/conntrack/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ func (b *BaseConnTrack) timeoutExceeded(timeout time.Duration) bool {
// IPAddr is a fixed-size IP address to avoid allocations
type IPAddr [16]byte

// makeIPAddr creates an IPAddr from net.IP
func makeIPAddr(ip net.IP) (addr IPAddr) {
// MakeIPAddr creates an IPAddr from net.IP
func MakeIPAddr(ip net.IP) (addr IPAddr) {
// Optimization: check for v4 first as it's more common
if ip4 := ip.To4(); ip4 != nil {
copy(addr[12:], ip4)
Expand All @@ -72,8 +72,8 @@ type ConnKey struct {
// makeConnKey creates a connection key
func makeConnKey(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) ConnKey {
return ConnKey{
SrcIP: makeIPAddr(srcIP),
DstIP: makeIPAddr(dstIP),
SrcIP: MakeIPAddr(srcIP),
DstIP: MakeIPAddr(dstIP),
SrcPort: srcPort,
DstPort: dstPort,
}
Expand Down
6 changes: 3 additions & 3 deletions client/firewall/uspfilter/conntrack/common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,18 @@ import (
)

func BenchmarkIPOperations(b *testing.B) {
b.Run("makeIPAddr", func(b *testing.B) {
b.Run("MakeIPAddr", func(b *testing.B) {
ip := net.ParseIP("192.168.1.1")
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = makeIPAddr(ip)
_ = MakeIPAddr(ip)
}
})

b.Run("ValidateIPs", func(b *testing.B) {
ip1 := net.ParseIP("192.168.1.1")
ip2 := net.ParseIP("192.168.1.1")
addr := makeIPAddr(ip1)
addr := MakeIPAddr(ip1)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = ValidateIPs(addr, ip2)
Expand Down
8 changes: 4 additions & 4 deletions client/firewall/uspfilter/conntrack/icmp.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,8 @@ func (t *ICMPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, id uint16, seq
}

return conn.IsEstablished() &&
ValidateIPs(makeIPAddr(srcIP), conn.DestIP) &&
ValidateIPs(makeIPAddr(dstIP), conn.SourceIP) &&
ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) &&
ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) &&
conn.ID == id &&
conn.Sequence == seq
}
Expand Down Expand Up @@ -162,8 +162,8 @@ func (t *ICMPTracker) Close() {
// makeICMPKey creates an ICMP connection key
func makeICMPKey(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) ICMPConnKey {
return ICMPConnKey{
SrcIP: makeIPAddr(srcIP),
DstIP: makeIPAddr(dstIP),
SrcIP: MakeIPAddr(srcIP),
DstIP: MakeIPAddr(dstIP),
ID: id,
Sequence: seq,
}
Expand Down
35 changes: 3 additions & 32 deletions client/firewall/uspfilter/conntrack/udp.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,8 @@ func (t *UDPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16,
}

return conn.IsEstablished() &&
ValidateIPs(makeIPAddr(srcIP), conn.DestIP) &&
ValidateIPs(makeIPAddr(dstIP), conn.SourceIP) &&
ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) &&
ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) &&
conn.DestPort == srcPort &&
conn.SourcePort == dstPort
}
Expand Down Expand Up @@ -149,39 +149,10 @@ func (t *UDPTracker) GetConnection(srcIP net.IP, srcPort uint16, dstIP net.IP, d
return nil, false
}

// Create a copy with new IP allocations
srcIPCopy := t.ipPool.Get()
dstIPCopy := t.ipPool.Get()
copyIP(srcIPCopy, conn.SourceIP)
copyIP(dstIPCopy, conn.DestIP)

connCopy := &UDPConnTrack{
BaseConnTrack: BaseConnTrack{
SourceIP: srcIPCopy,
DestIP: dstIPCopy,
SourcePort: conn.SourcePort,
DestPort: conn.DestPort,
},
}
connCopy.lastSeen.Store(conn.lastSeen.Load())
connCopy.established.Store(conn.IsEstablished())

return connCopy, true
return conn, true
}

// Timeout returns the configured timeout duration for the tracker
func (t *UDPTracker) Timeout() time.Duration {
return t.timeout
}

func makeKey(srcIP net.IP, srcPort uint16, dstIP net.IP, dstPort uint16) ConnKey {
var srcAddr, dstAddr [16]byte
copy(srcAddr[:], srcIP.To16()) // Ensure 16-byte representation
copy(dstAddr[:], dstIP.To16())
return ConnKey{
SrcIP: srcAddr,
SrcPort: srcPort,
DstIP: dstAddr,
DstPort: dstPort,
}
}
2 changes: 1 addition & 1 deletion client/firewall/uspfilter/conntrack/udp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ func TestUDPTracker_TrackOutbound(t *testing.T) {
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort)

// Verify connection was tracked
key := makeKey(srcIP, srcPort, dstIP, dstPort)
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
conn, exists := tracker.connections[key]
require.True(t, exists)
assert.True(t, conn.SourceIP.Equal(srcIP))
Expand Down
124 changes: 70 additions & 54 deletions client/firewall/uspfilter/uspfilter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,10 +187,10 @@ func TestAddUDPPacketHook(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
manager := &Manager{
incomingRules: map[string]RuleSet{},
outgoingRules: map[string]RuleSet{},
}
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
})
require.NoError(t, err)

manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)

Expand Down Expand Up @@ -315,7 +315,7 @@ func TestNotMatchByIP(t *testing.T) {
t.Errorf("failed to set network layer for checksum: %v", err)
return
}
payload := gopacket.Payload([]byte("test"))
payload := gopacket.Payload("test")

buf := gopacket.NewSerializeBuffer()
opts := gopacket.SerializeOptions{
Expand Down Expand Up @@ -350,6 +350,9 @@ func TestRemovePacketHook(t *testing.T) {
if err != nil {
t.Fatalf("Failed to create Manager: %s", err)
}
defer func() {
require.NoError(t, manager.Reset(nil))
}()

// Add a UDP packet hook
hookFunc := func(data []byte) bool { return true }
Expand Down Expand Up @@ -387,26 +390,33 @@ func TestRemovePacketHook(t *testing.T) {
}

func TestProcessOutgoingHooks(t *testing.T) {
manager := &Manager{
outgoingRules: map[string]RuleSet{},
wgNetwork: &net.IPNet{
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
},
decoders: sync.Pool{
New: func() any {
d := &decoder{
decoded: []gopacket.LayerType{},
}
d.parser = gopacket.NewDecodingLayerParser(
layers.LayerTypeIPv4,
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
)
d.parser.IgnoreUnsupported = true
return d
},
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
})
require.NoError(t, err)

manager.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
}
manager.udpTracker.Close()
manager.udpTracker = conntrack.NewUDPTracker(100 * time.Millisecond)
defer func() {
require.NoError(t, manager.Reset(nil))
}()

manager.decoders = sync.Pool{
New: func() any {
d := &decoder{
decoded: []gopacket.LayerType{},
}
d.parser = gopacket.NewDecodingLayerParser(
layers.LayerTypeIPv4,
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
)
d.parser.IgnoreUnsupported = true
return d
},
udpTracker: conntrack.NewUDPTracker(100 * time.Millisecond),
}

hookCalled := false
Expand Down Expand Up @@ -434,9 +444,9 @@ func TestProcessOutgoingHooks(t *testing.T) {
DstPort: 53,
}

err := udp.SetNetworkLayerForChecksum(ipv4)
err = udp.SetNetworkLayerForChecksum(ipv4)
require.NoError(t, err)
payload := gopacket.Payload([]byte("test"))
payload := gopacket.Payload("test")

buf := gopacket.NewSerializeBuffer()
opts := gopacket.SerializeOptions{
Expand Down Expand Up @@ -497,29 +507,34 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
}

func TestStatefulFirewall_UDPTracking(t *testing.T) {
manager := &Manager{
outgoingRules: map[string]RuleSet{},
incomingRules: map[string]RuleSet{},
wgNetwork: &net.IPNet{
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
},
decoders: sync.Pool{
New: func() any {
d := &decoder{
decoded: []gopacket.LayerType{},
}
d.parser = gopacket.NewDecodingLayerParser(
layers.LayerTypeIPv4,
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
)
d.parser.IgnoreUnsupported = true
return d
},
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
})
require.NoError(t, err)

manager.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
}

manager.udpTracker.Close() // Close the existing tracker
manager.udpTracker = conntrack.NewUDPTracker(200 * time.Millisecond)
manager.decoders = sync.Pool{
New: func() any {
d := &decoder{
decoded: []gopacket.LayerType{},
}
d.parser = gopacket.NewDecodingLayerParser(
layers.LayerTypeIPv4,
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
)
d.parser.IgnoreUnsupported = true
return d
},
udpTracker: conntrack.NewUDPTracker(200 * time.Millisecond),
}
defer manager.udpTracker.Close()
defer func() {
require.NoError(t, manager.Reset(nil))
}()

// Set up packet parameters
srcIP := net.ParseIP("100.10.0.1")
Expand All @@ -540,7 +555,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
DstPort: layers.UDPPort(dstPort),
}

err := outboundUDP.SetNetworkLayerForChecksum(outboundIPv4)
err = outboundUDP.SetNetworkLayerForChecksum(outboundIPv4)
require.NoError(t, err)

outboundBuf := gopacket.NewSerializeBuffer()
Expand All @@ -552,19 +567,20 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
err = gopacket.SerializeLayers(outboundBuf, opts,
outboundIPv4,
outboundUDP,
gopacket.Payload([]byte("test")),
gopacket.Payload("test"),
)
require.NoError(t, err)

// Process outbound packet and verify connection tracking
drop := manager.processOutgoingHooks(outboundBuf.Bytes())
drop := manager.DropOutgoing(outboundBuf.Bytes())
require.False(t, drop, "Initial outbound packet should not be dropped")

// Verify connection was tracked
conn, exists := manager.udpTracker.GetConnection(srcIP, srcPort, dstIP, dstPort)

require.True(t, exists, "Connection should be tracked after outbound packet")
require.True(t, conn.SourceIP.Equal(srcIP), "Source IP should match")
require.True(t, conn.DestIP.Equal(dstIP), "Destination IP should match")
require.True(t, conntrack.ValidateIPs(conntrack.MakeIPAddr(srcIP), conn.SourceIP), "Source IP should match")
require.True(t, conntrack.ValidateIPs(conntrack.MakeIPAddr(dstIP), conn.DestIP), "Destination IP should match")
require.Equal(t, srcPort, conn.SourcePort, "Source port should match")
require.Equal(t, dstPort, conn.DestPort, "Destination port should match")

Expand All @@ -588,7 +604,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
err = gopacket.SerializeLayers(inboundBuf, opts,
inboundIPv4,
inboundUDP,
gopacket.Payload([]byte("response")),
gopacket.Payload("response"),
)
require.NoError(t, err)
// Test roundtrip response handling over time
Expand Down Expand Up @@ -689,7 +705,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
err = gopacket.SerializeLayers(testBuf, opts,
&testIPv4,
&testUDP,
gopacket.Payload([]byte("response")),
gopacket.Payload("response"),
)
require.NoError(t, err)

Expand Down

0 comments on commit 2c6e780

Please sign in to comment.