Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[client] Add stateful userspace firewall and remove egress filters #3093

Merged
merged 30 commits into from
Dec 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
d1fad7f
Remove userspace egress filter
lixmal Dec 20, 2024
c81ac1a
Remove iptables egress filter
lixmal Dec 20, 2024
afb0343
Remove nftables egress filter
lixmal Dec 20, 2024
fbfb2cd
Remove unused code
lixmal Dec 20, 2024
4d14cf6
Still process outgoing udp hooks
lixmal Dec 20, 2024
a778e91
Add udp hook test
lixmal Dec 20, 2024
8216ab6
Add udp conn tracking
lixmal Dec 20, 2024
5d97cf8
Fix udp test
lixmal Dec 20, 2024
9d1702c
Fix corrupted IPs
lixmal Dec 20, 2024
49d1de2
Use slices.Clone consistently
lixmal Dec 20, 2024
2a5ef98
Add icmp tracker
lixmal Dec 20, 2024
f0c8c90
Reset icmp tracker
lixmal Dec 20, 2024
dadf64e
Clean up icmp tracker
lixmal Dec 20, 2024
fa38d8e
Add TCP tracker
lixmal Dec 20, 2024
0970b75
Use switch
lixmal Dec 21, 2024
c6aeb48
Move locks further down
lixmal Dec 21, 2024
b9767d4
Generally allow time exceeded and destination unreachable, disallow e…
lixmal Dec 21, 2024
8661ead
Add env to disable statefulness
lixmal Dec 21, 2024
1306da2
Add benchmarks
lixmal Dec 21, 2024
9ff0475
Add benchmark to compare routed network return traffic handling
lixmal Dec 21, 2024
1a9a82b
Improve TCP state handling
lixmal Dec 21, 2024
3c158d4
Fix races, improve performance and add benchmarks
lixmal Dec 22, 2024
8cfdc87
Merge branch 'main' into remove-egress-filters
lixmal Dec 22, 2024
07019d2
Fix udp test
lixmal Dec 22, 2024
ed77f48
Fix lint
lixmal Dec 22, 2024
802a9be
Fix remaining lint issues
lixmal Dec 22, 2024
84df403
Properly use sync pools (pointers)
lixmal Dec 22, 2024
9e00ea7
Add TCP RST test
lixmal Dec 22, 2024
6d0bf63
Add more comparison benchmarks (stateful vs stateless)
lixmal Dec 22, 2024
590614e
Merge branch 'main' into HEAD
lixmal Dec 23, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 0 additions & 6 deletions client/firewall/iptables/acl_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -332,18 +332,12 @@ func (m *aclManager) createDefaultChains() error {

// The OUTPUT chain gets an extra rule to allow traffic to any set up routes, the return traffic is handled by the INPUT related/established rule.
func (m *aclManager) seedInitialEntries() {

established := getConntrackEstablished()

m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", "DROP"})
m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", chainNameInputRules})
m.appendToEntries("INPUT", append([]string{"-i", m.wgIface.Name()}, established...))

m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "-j", "DROP"})
m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "-j", chainNameOutputRules})
m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "!", "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
m.appendToEntries("OUTPUT", append([]string{"-o", m.wgIface.Name()}, established...))

m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", "DROP"})
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", m.routingFwChainName})
m.appendToEntries("FORWARD", append([]string{"-o", m.wgIface.Name()}, established...))
Expand Down
14 changes: 2 additions & 12 deletions client/firewall/iptables/manager_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,19 +207,9 @@ func (m *Manager) AllowNetbird() error {
"",
)
if err != nil {
return fmt.Errorf("failed to allow netbird interface traffic: %w", err)
return fmt.Errorf("allow netbird interface traffic: %w", err)
}
_, err = m.AddPeerFiltering(
net.ParseIP("0.0.0.0"),
"all",
nil,
nil,
firewall.RuleDirectionOUT,
firewall.ActionAccept,
"",
"",
)
return err
return nil
}

// Flush doesn't need to be implemented for this manager
Expand Down
53 changes: 0 additions & 53 deletions client/firewall/nftables/acl_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"encoding/binary"
"fmt"
"net"
"net/netip"
"strconv"
"strings"
"time"
Expand All @@ -28,7 +27,6 @@ const (

// filter chains contains the rules that jump to the rules chains
chainNameInputFilter = "netbird-acl-input-filter"
chainNameOutputFilter = "netbird-acl-output-filter"
chainNameForwardFilter = "netbird-acl-forward-filter"
chainNamePrerouting = "netbird-rt-prerouting"

Expand Down Expand Up @@ -441,18 +439,6 @@ func (m *AclManager) createDefaultChains() (err error) {
return err
}

// netbird-acl-output-filter
// type filter hook output priority filter; policy accept;
chain = m.createFilterChainWithHook(chainNameOutputFilter, nftables.ChainHookOutput)
m.addFwdAllow(chain, expr.MetaKeyOIFNAME)
m.addJumpRule(chain, m.chainOutputRules.Name, expr.MetaKeyOIFNAME) // to netbird-acl-output-rules
m.addDropExpressions(chain, expr.MetaKeyOIFNAME)
err = m.rConn.Flush()
if err != nil {
log.Debugf("failed to create chain (%s): %s", chainNameOutputFilter, err)
return err
}

// netbird-acl-forward-filter
chainFwFilter := m.createFilterChainWithHook(chainNameForwardFilter, nftables.ChainHookForward)
m.addJumpRulesToRtForward(chainFwFilter) // to netbird-rt-fwd
Expand Down Expand Up @@ -619,45 +605,6 @@ func (m *AclManager) addDropExpressions(chain *nftables.Chain, ifaceKey expr.Met
return nil
}

func (m *AclManager) addFwdAllow(chain *nftables.Chain, iifname expr.MetaKey) {
ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To4())
dstOp := expr.CmpOpNeq
expressions := []expr.Any{
&expr.Meta{Key: iifname, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
},
&expr.Payload{
DestRegister: 2,
Base: expr.PayloadBaseNetworkHeader,
Offset: 16,
Len: 4,
},
&expr.Bitwise{
SourceRegister: 2,
DestRegister: 2,
Len: 4,
Xor: []byte{0x0, 0x0, 0x0, 0x0},
Mask: m.wgIface.Address().Network.Mask,
},
&expr.Cmp{
Op: dstOp,
Register: 2,
Data: ip.Unmap().AsSlice(),
},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
}
_ = m.rConn.AddRule(&nftables.Rule{
Table: chain.Table,
Chain: chain,
Exprs: expressions,
})
}

func (m *AclManager) addJumpRule(chain *nftables.Chain, to string, ifaceKey expr.MetaKey) {
expressions := []expr.Any{
&expr.Meta{Key: ifaceKey, Register: 1},
Expand Down
20 changes: 19 additions & 1 deletion client/firewall/uspfilter/allow_netbird.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@

package uspfilter

import "github.com/netbirdio/netbird/client/internal/statemanager"
import (
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/internal/statemanager"
)

// Reset firewall to the default state
func (m *Manager) Reset(stateManager *statemanager.Manager) error {
Expand All @@ -12,6 +15,21 @@ func (m *Manager) Reset(stateManager *statemanager.Manager) error {
m.outgoingRules = make(map[string]RuleSet)
m.incomingRules = make(map[string]RuleSet)

if m.udpTracker != nil {
m.udpTracker.Close()
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout)
}

if m.icmpTracker != nil {
m.icmpTracker.Close()
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout)
}

if m.tcpTracker != nil {
m.tcpTracker.Close()
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout)
}

if m.nativeFirewall != nil {
return m.nativeFirewall.Reset(stateManager)
}
Expand Down
16 changes: 16 additions & 0 deletions client/firewall/uspfilter/allow_netbird_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (

log "github.com/sirupsen/logrus"

"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/internal/statemanager"
)

Expand All @@ -26,6 +27,21 @@ func (m *Manager) Reset(*statemanager.Manager) error {
m.outgoingRules = make(map[string]RuleSet)
m.incomingRules = make(map[string]RuleSet)

if m.udpTracker != nil {
m.udpTracker.Close()
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout)
}

if m.icmpTracker != nil {
m.icmpTracker.Close()
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout)
}

if m.tcpTracker != nil {
m.tcpTracker.Close()
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout)
}

if !isWindowsFirewallReachable() {
return nil
}
Expand Down
138 changes: 138 additions & 0 deletions client/firewall/uspfilter/conntrack/common.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
// common.go
package conntrack

import (
"net"
"sync"
"sync/atomic"
"time"
)

// BaseConnTrack provides common fields and locking for all connection types
type BaseConnTrack struct {
sync.RWMutex
SourceIP net.IP
DestIP net.IP
SourcePort uint16
DestPort uint16
lastSeen atomic.Int64 // Unix nano for atomic access
established atomic.Bool
}

// these small methods will be inlined by the compiler

// UpdateLastSeen safely updates the last seen timestamp
func (b *BaseConnTrack) UpdateLastSeen() {
b.lastSeen.Store(time.Now().UnixNano())
}

// IsEstablished safely checks if connection is established
func (b *BaseConnTrack) IsEstablished() bool {
return b.established.Load()
}

// SetEstablished safely sets the established state
func (b *BaseConnTrack) SetEstablished(state bool) {
b.established.Store(state)
}

// GetLastSeen safely gets the last seen timestamp
func (b *BaseConnTrack) GetLastSeen() time.Time {
return time.Unix(0, b.lastSeen.Load())
}

// timeoutExceeded checks if the connection has exceeded the given timeout
func (b *BaseConnTrack) timeoutExceeded(timeout time.Duration) bool {
lastSeen := time.Unix(0, b.lastSeen.Load())
return time.Since(lastSeen) > timeout
}

// IPAddr is a fixed-size IP address to avoid allocations
type IPAddr [16]byte

// MakeIPAddr creates an IPAddr from net.IP
func MakeIPAddr(ip net.IP) (addr IPAddr) {
// Optimization: check for v4 first as it's more common
if ip4 := ip.To4(); ip4 != nil {
copy(addr[12:], ip4)
} else {
copy(addr[:], ip.To16())
}
return addr
}

// ConnKey uniquely identifies a connection
type ConnKey struct {
SrcIP IPAddr
DstIP IPAddr
SrcPort uint16
DstPort uint16
}

// makeConnKey creates a connection key
func makeConnKey(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) ConnKey {
return ConnKey{
SrcIP: MakeIPAddr(srcIP),
DstIP: MakeIPAddr(dstIP),
SrcPort: srcPort,
DstPort: dstPort,
}
}

// ValidateIPs checks if IPs match without allocation
func ValidateIPs(connIP IPAddr, pktIP net.IP) bool {
if ip4 := pktIP.To4(); ip4 != nil {
// Compare IPv4 addresses (last 4 bytes)
for i := 0; i < 4; i++ {
if connIP[12+i] != ip4[i] {
return false
}
}
return true
}
// Compare full IPv6 addresses
ip6 := pktIP.To16()
for i := 0; i < 16; i++ {
if connIP[i] != ip6[i] {
return false
}
}
return true
}

// PreallocatedIPs is a pool of IP byte slices to reduce allocations
type PreallocatedIPs struct {
sync.Pool
}

// NewPreallocatedIPs creates a new IP pool
func NewPreallocatedIPs() *PreallocatedIPs {
return &PreallocatedIPs{
Pool: sync.Pool{
New: func() interface{} {
ip := make(net.IP, 16)
return &ip
},
},
}
}

// Get retrieves an IP from the pool
func (p *PreallocatedIPs) Get() net.IP {
return *p.Pool.Get().(*net.IP)
}

// Put returns an IP to the pool
func (p *PreallocatedIPs) Put(ip net.IP) {
p.Pool.Put(&ip)
}

// copyIP copies an IP address efficiently
func copyIP(dst, src net.IP) {
if len(src) == 16 {
copy(dst, src)
} else {
// Handle IPv4
copy(dst[12:], src.To4())
}
}
Loading
Loading