Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[client] Add stateful userspace firewall and remove egress filters #3093

Merged
merged 30 commits into from
Dec 23, 2024
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
d1fad7f
Remove userspace egress filter
lixmal Dec 20, 2024
c81ac1a
Remove iptables egress filter
lixmal Dec 20, 2024
afb0343
Remove nftables egress filter
lixmal Dec 20, 2024
fbfb2cd
Remove unused code
lixmal Dec 20, 2024
4d14cf6
Still process outgoing udp hooks
lixmal Dec 20, 2024
a778e91
Add udp hook test
lixmal Dec 20, 2024
8216ab6
Add udp conn tracking
lixmal Dec 20, 2024
5d97cf8
Fix udp test
lixmal Dec 20, 2024
9d1702c
Fix corrupted IPs
lixmal Dec 20, 2024
49d1de2
Use slices.Clone consistently
lixmal Dec 20, 2024
2a5ef98
Add icmp tracker
lixmal Dec 20, 2024
f0c8c90
Reset icmp tracker
lixmal Dec 20, 2024
dadf64e
Clean up icmp tracker
lixmal Dec 20, 2024
fa38d8e
Add TCP tracker
lixmal Dec 20, 2024
0970b75
Use switch
lixmal Dec 21, 2024
c6aeb48
Move locks further down
lixmal Dec 21, 2024
b9767d4
Generally allow time exceeded and destination unreachable, disallow e…
lixmal Dec 21, 2024
8661ead
Add env to disable statefulness
lixmal Dec 21, 2024
1306da2
Add benchmarks
lixmal Dec 21, 2024
9ff0475
Add benchmark to compare routed network return traffic handling
lixmal Dec 21, 2024
1a9a82b
Improve TCP state handling
lixmal Dec 21, 2024
3c158d4
Fix races, improve performance and add benchmarks
lixmal Dec 22, 2024
8cfdc87
Merge branch 'main' into remove-egress-filters
lixmal Dec 22, 2024
07019d2
Fix udp test
lixmal Dec 22, 2024
ed77f48
Fix lint
lixmal Dec 22, 2024
802a9be
Fix remaining lint issues
lixmal Dec 22, 2024
84df403
Properly use sync pools (pointers)
lixmal Dec 22, 2024
9e00ea7
Add TCP RST test
lixmal Dec 22, 2024
6d0bf63
Add more comparison benchmarks (stateful vs stateless)
lixmal Dec 22, 2024
590614e
Merge branch 'main' into HEAD
lixmal Dec 23, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 0 additions & 6 deletions client/firewall/iptables/acl_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -332,18 +332,12 @@ func (m *aclManager) createDefaultChains() error {

// The OUTPUT chain gets an extra rule to allow traffic to any set up routes, the return traffic is handled by the INPUT related/established rule.
func (m *aclManager) seedInitialEntries() {

established := getConntrackEstablished()

m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", "DROP"})
m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", chainNameInputRules})
m.appendToEntries("INPUT", append([]string{"-i", m.wgIface.Name()}, established...))

m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "-j", "DROP"})
m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "-j", chainNameOutputRules})
m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "!", "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
m.appendToEntries("OUTPUT", append([]string{"-o", m.wgIface.Name()}, established...))

m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", "DROP"})
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", m.routingFwChainName})
m.appendToEntries("FORWARD", append([]string{"-o", m.wgIface.Name()}, established...))
Expand Down
14 changes: 2 additions & 12 deletions client/firewall/iptables/manager_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,19 +207,9 @@ func (m *Manager) AllowNetbird() error {
"",
)
if err != nil {
return fmt.Errorf("failed to allow netbird interface traffic: %w", err)
return fmt.Errorf("allow netbird interface traffic: %w", err)
}
_, err = m.AddPeerFiltering(
net.ParseIP("0.0.0.0"),
"all",
nil,
nil,
firewall.RuleDirectionOUT,
firewall.ActionAccept,
"",
"",
)
return err
return nil
}

// Flush doesn't need to be implemented for this manager
Expand Down
53 changes: 0 additions & 53 deletions client/firewall/nftables/acl_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"encoding/binary"
"fmt"
"net"
"net/netip"
"strconv"
"strings"
"time"
Expand All @@ -28,7 +27,6 @@ const (

// filter chains contains the rules that jump to the rules chains
chainNameInputFilter = "netbird-acl-input-filter"
chainNameOutputFilter = "netbird-acl-output-filter"
chainNameForwardFilter = "netbird-acl-forward-filter"
chainNamePrerouting = "netbird-rt-prerouting"

Expand Down Expand Up @@ -441,18 +439,6 @@ func (m *AclManager) createDefaultChains() (err error) {
return err
}

// netbird-acl-output-filter
// type filter hook output priority filter; policy accept;
chain = m.createFilterChainWithHook(chainNameOutputFilter, nftables.ChainHookOutput)
m.addFwdAllow(chain, expr.MetaKeyOIFNAME)
m.addJumpRule(chain, m.chainOutputRules.Name, expr.MetaKeyOIFNAME) // to netbird-acl-output-rules
m.addDropExpressions(chain, expr.MetaKeyOIFNAME)
err = m.rConn.Flush()
if err != nil {
log.Debugf("failed to create chain (%s): %s", chainNameOutputFilter, err)
return err
}

// netbird-acl-forward-filter
chainFwFilter := m.createFilterChainWithHook(chainNameForwardFilter, nftables.ChainHookForward)
m.addJumpRulesToRtForward(chainFwFilter) // to netbird-rt-fwd
Expand Down Expand Up @@ -619,45 +605,6 @@ func (m *AclManager) addDropExpressions(chain *nftables.Chain, ifaceKey expr.Met
return nil
}

func (m *AclManager) addFwdAllow(chain *nftables.Chain, iifname expr.MetaKey) {
ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To4())
dstOp := expr.CmpOpNeq
expressions := []expr.Any{
&expr.Meta{Key: iifname, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
},
&expr.Payload{
DestRegister: 2,
Base: expr.PayloadBaseNetworkHeader,
Offset: 16,
Len: 4,
},
&expr.Bitwise{
SourceRegister: 2,
DestRegister: 2,
Len: 4,
Xor: []byte{0x0, 0x0, 0x0, 0x0},
Mask: m.wgIface.Address().Network.Mask,
},
&expr.Cmp{
Op: dstOp,
Register: 2,
Data: ip.Unmap().AsSlice(),
},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
}
_ = m.rConn.AddRule(&nftables.Rule{
Table: chain.Table,
Chain: chain,
Exprs: expressions,
})
}

func (m *AclManager) addJumpRule(chain *nftables.Chain, to string, ifaceKey expr.MetaKey) {
expressions := []expr.Any{
&expr.Meta{Key: ifaceKey, Register: 1},
Expand Down
20 changes: 19 additions & 1 deletion client/firewall/uspfilter/allow_netbird.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@

package uspfilter

import "github.com/netbirdio/netbird/client/internal/statemanager"
import (
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/internal/statemanager"
)

// Reset firewall to the default state
func (m *Manager) Reset(stateManager *statemanager.Manager) error {
Expand All @@ -12,6 +15,21 @@ func (m *Manager) Reset(stateManager *statemanager.Manager) error {
m.outgoingRules = make(map[string]RuleSet)
m.incomingRules = make(map[string]RuleSet)

if m.udpTracker != nil {
m.udpTracker.Close()
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout)
}

if m.icmpTracker != nil {
m.icmpTracker.Close()
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout)
}

if m.tcpTracker != nil {
m.tcpTracker.Close()
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout)
}

if m.nativeFirewall != nil {
return m.nativeFirewall.Reset(stateManager)
}
Expand Down
16 changes: 16 additions & 0 deletions client/firewall/uspfilter/allow_netbird_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (

log "github.com/sirupsen/logrus"

"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/internal/statemanager"
)

Expand All @@ -26,6 +27,21 @@ func (m *Manager) Reset(*statemanager.Manager) error {
m.outgoingRules = make(map[string]RuleSet)
m.incomingRules = make(map[string]RuleSet)

if m.udpTracker != nil {
m.udpTracker.Close()
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout)
}

if m.icmpTracker != nil {
m.icmpTracker.Close()
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout)
}

if m.tcpTracker != nil {
m.tcpTracker.Close()
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout)
}

if !isWindowsFirewallReachable() {
return nil
}
Expand Down
159 changes: 159 additions & 0 deletions client/firewall/uspfilter/conntrack/icmp.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
package conntrack

import (
"net"
"slices"
"sync"
"time"

"github.com/google/gopacket/layers"
)

const (
// DefaultICMPTimeout is the default timeout for ICMP connections
DefaultICMPTimeout = 30 * time.Second
// ICMPCleanupInterval is how often we check for stale ICMP connections
ICMPCleanupInterval = 15 * time.Second
)

// ICMPConnKey uniquely identifies an ICMP connection
type ICMPConnKey struct {
// Supports both IPv4 and IPv6
SrcIP [16]byte
DstIP [16]byte
Sequence uint16 // ICMP sequence number
ID uint16 // ICMP identifier
}

// ICMPConnTrack represents an ICMP connection state
type ICMPConnTrack struct {
SourceIP net.IP
DestIP net.IP
Sequence uint16
ID uint16
LastSeen time.Time
established bool
}

// ICMPTracker manages ICMP connection states
type ICMPTracker struct {
connections map[ICMPConnKey]*ICMPConnTrack
timeout time.Duration
cleanupTicker *time.Ticker
mutex sync.RWMutex
done chan struct{}
}

// NewICMPTracker creates a new ICMP connection tracker
func NewICMPTracker(timeout time.Duration) *ICMPTracker {
if timeout == 0 {
timeout = DefaultICMPTimeout
}

tracker := &ICMPTracker{
connections: make(map[ICMPConnKey]*ICMPConnTrack),
timeout: timeout,
cleanupTicker: time.NewTicker(ICMPCleanupInterval),
done: make(chan struct{}),
}

go tracker.cleanupRoutine()
return tracker
}

// TrackOutbound records an outbound ICMP Echo Request
func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) {
t.mutex.Lock()
defer t.mutex.Unlock()

key := makeICMPKey(srcIP, dstIP, id, seq)

t.connections[key] = &ICMPConnTrack{
SourceIP: slices.Clone(srcIP),
DestIP: slices.Clone(dstIP),
ID: id,
Sequence: seq,
LastSeen: time.Now(),
established: true,
}
}

// IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request
func (t *ICMPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, icmpType uint8) bool {
t.mutex.RLock()
defer t.mutex.RUnlock()

// Always allow Echo Request (type 8 for IPv4, 128 for IPv6)
if icmpType == uint8(layers.ICMPv4TypeEchoRequest) || icmpType == uint8(layers.ICMPv6TypeEchoRequest) {
return true
}

// For Echo Reply, check if we have a matching request
if icmpType != uint8(layers.ICMPv4TypeEchoReply) && icmpType != uint8(layers.ICMPv6TypeEchoReply) {
return false
}

key := makeICMPKey(dstIP, srcIP, id, seq)
conn, exists := t.connections[key]
if !exists {
return false
}

// Check if connection is still valid
if time.Since(conn.LastSeen) > t.timeout {
return false
}

if conn.established &&
conn.DestIP.Equal(srcIP) &&
conn.SourceIP.Equal(dstIP) &&
conn.ID == id &&
conn.Sequence == seq {

conn.LastSeen = time.Now()
return true
}

return false
}

func (t *ICMPTracker) cleanupRoutine() {
for {
select {
case <-t.cleanupTicker.C:
t.cleanup()
case <-t.done:
return
}
}
}

func (t *ICMPTracker) cleanup() {
t.mutex.Lock()
defer t.mutex.Unlock()

now := time.Now()
for key, conn := range t.connections {
if now.Sub(conn.LastSeen) > t.timeout {
delete(t.connections, key)
}
}
}

// Close stops the cleanup routine and releases resources
func (t *ICMPTracker) Close() {
t.cleanupTicker.Stop()
close(t.done)
}

func makeICMPKey(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) ICMPConnKey {
var srcAddr, dstAddr [16]byte
copy(srcAddr[:], srcIP.To16())
copy(dstAddr[:], dstIP.To16())
return ICMPConnKey{
SrcIP: srcAddr,
DstIP: dstAddr,
ID: id,
Sequence: seq,
}
}
Loading
Loading