diff --git a/udp/udp.go b/udp/udp.go index 02485bd..f7dea64 100644 --- a/udp/udp.go +++ b/udp/udp.go @@ -1,12 +1,14 @@ package udp import ( - "github.com/wwqgtxx/wstunnel/config" "log" "net" + "time" + + "github.com/wwqgtxx/wstunnel/config" ) -const MaxUdpAge = 5 * 60 +const MaxUdpAge = 5 * time.Minute type Tunnel interface { Handle() diff --git a/udp/udp_mmsg.go b/udp/udp_mmsg.go index 5321728..bd24e87 100644 --- a/udp/udp_mmsg.go +++ b/udp/udp_mmsg.go @@ -7,9 +7,9 @@ import ( "log" "net" "sync" + "time" "github.com/wwqgtxx/wstunnel/config" - cache "github.com/wwqgtxx/wstunnel/utils/lrucache" ) const ( @@ -40,35 +40,21 @@ var WriteMsgsBufPool = sync.Pool{New: func() any { // This means we can use this struct to read from a socket that receives both IPv4 and IPv6 messages. var _ ipv4.Message = ipv6.Message{} -type MmsgNatItem struct { +type MmsgMapItem struct { net.Conn *ipv4.PacketConn sync.Mutex } type MmsgTunnel struct { - nat *cache.LruCache[string, *MmsgNatItem] + connMap sync.Map address string target string reserved []byte } func NewMmsgTunnel(udpConfig config.UdpConfig) Tunnel { - nat := cache.New[string, *MmsgNatItem]( - cache.WithAge[string, *MmsgNatItem](MaxUdpAge), - cache.WithUpdateAgeOnGet[string, *MmsgNatItem](), - cache.WithEvict[string, *MmsgNatItem](func(key string, value *MmsgNatItem) { - if conn := value.Conn; conn != nil { - log.Println("Delete", conn.LocalAddr(), "for", key, "to", conn.RemoteAddr()) - _ = conn.Close() - } - }), - cache.WithCreate[string, *MmsgNatItem](func(key string) *MmsgNatItem { - return &MmsgNatItem{} - }), - ) t := &MmsgTunnel{ - nat: nat, address: udpConfig.BindAddress, target: udpConfig.TargetAddress, reserved: slices.Clone(udpConfig.Reserved), @@ -141,22 +127,23 @@ func (t *MmsgTunnel) Handle() { } WriteMsgsBufPool.Put(wMsgs) }() - natItem, _ := t.nat.Get(addr) - natItem.Mutex.Lock() - remoteConn := natItem.Conn - remotePacketConn := natItem.PacketConn + v, _ := t.connMap.LoadOrStore(addr, &MmsgMapItem{}) + mapItem := v.(*MmsgMapItem) + mapItem.Mutex.Lock() + remoteConn := mapItem.Conn + remotePacketConn := mapItem.PacketConn if remoteConn == nil || remotePacketConn == nil { log.Println("Dial to", t.target, "for", addr) remoteConn, err = net.Dial("udp", t.target) if err != nil { - natItem.Mutex.Unlock() + mapItem.Mutex.Unlock() log.Println(err) return } log.Println("Associate from", addr, "to", remoteConn.RemoteAddr(), "by", remoteConn.LocalAddr()) remotePacketConn = ipv4.NewPacketConn(remoteConn.(*net.UDPConn)) - natItem.Conn = remoteConn - natItem.PacketConn = remotePacketConn + mapItem.Conn = remoteConn + mapItem.PacketConn = remotePacketConn go func() { rMsgs := ReadMsgsBufPool.Get().([]ipv4.Message) wMsgs := WriteMsgsBufPool.Get().([]ipv4.Message) @@ -165,10 +152,12 @@ func (t *MmsgTunnel) Handle() { WriteMsgsBufPool.Put(wMsgs) }() for { + _ = remoteConn.SetReadDeadline(time.Now().Add(MaxUdpAge)) // set timeout n, err := remotePacketConn.ReadBatch(rMsgs, 0) if err != nil { - t.nat.Delete(addr) // it will call conn.Close() inside - log.Println(err) + t.connMap.Delete(addr) + log.Println("Delete and close", remoteConn.LocalAddr(), "for", addr, "to", remoteConn.RemoteAddr(), "because", err) + _ = remoteConn.Close() return } for i := 0; i < n; i++ { @@ -184,28 +173,23 @@ func (t *MmsgTunnel) Handle() { wMsgsN := n if wMsgsN == 1 { // maybe faster _, err = udpConn.WriteTo(wMsgs[0].Buffers[0], nAddr) - if err != nil { - t.nat.Delete(addr) // it will call conn.Close() inside - log.Println(err) - return - } } else { var wN int wN, err = packetConn.WriteBatch(wMsgs[:wMsgsN], 0) - if err != nil { - t.nat.Delete(addr) // it will call conn.Close() inside - log.Println(err) - return - } - if wN != wMsgsN { + if err == nil && wN != wMsgsN { log.Println("warning wN=", wN, "wMsgsN=", wMsgsN) } } - t.nat.Get(addr) // refresh lru + if err != nil { + t.connMap.Delete(addr) + log.Println("Delete and close", remoteConn.LocalAddr(), "for", addr, "to", remoteConn.RemoteAddr(), "because", err) + _ = remoteConn.Close() + return + } } }() } - natItem.Mutex.Unlock() + mapItem.Mutex.Unlock() for _, wMsg := range wMsgs[:wMsgsN] { buf := wMsg.Buffers[0] @@ -217,21 +201,18 @@ func (t *MmsgTunnel) Handle() { if wMsgsN == 1 { // maybe faster _, err = remoteConn.Write(wMsgs[0].Buffers[0]) - if err != nil { - log.Println(err) - return - } } else { var wN int wN, err = remotePacketConn.WriteBatch(wMsgs[:wMsgsN], 0) - if err != nil { - log.Println(err) - return - } - if wN != wMsgsN { + if err == nil && wN != wMsgsN { log.Println("warning wN=", wN, "wMsgsN=", wMsgsN) } } + if err != nil { + log.Println(err) + return + } + _ = remoteConn.SetReadDeadline(time.Now().Add(MaxUdpAge)) // refresh timeout }() } diff --git a/udp/udp_std.go b/udp/udp_std.go index 31cbf19..e294920 100644 --- a/udp/udp_std.go +++ b/udp/udp_std.go @@ -5,11 +5,10 @@ import ( "golang.org/x/net/ipv4" "log" "net" - "net/netip" "sync" + "time" "github.com/wwqgtxx/wstunnel/config" - cache "github.com/wwqgtxx/wstunnel/utils/lrucache" ) const BufferSize = 16 * 1024 @@ -24,35 +23,21 @@ func ListenUdp(network, address string) (*net.UDPConn, error) { return pc.(*net.UDPConn), nil } -type StdNatItem struct { +type StdMapItem struct { net.Conn *ipv4.PacketConn sync.Mutex } type StdTunnel struct { - nat *cache.LruCache[netip.AddrPort, *StdNatItem] + connMap sync.Map address string target string reserved []byte } func NewStdTunnel(udpConfig config.UdpConfig) Tunnel { - nat := cache.New[netip.AddrPort, *StdNatItem]( - cache.WithAge[netip.AddrPort, *StdNatItem](MaxUdpAge), - cache.WithUpdateAgeOnGet[netip.AddrPort, *StdNatItem](), - cache.WithEvict[netip.AddrPort, *StdNatItem](func(key netip.AddrPort, value *StdNatItem) { - if conn := value.Conn; conn != nil { - log.Println("Delete", conn.LocalAddr(), "for", key, "to", conn.RemoteAddr()) - _ = conn.Close() - } - }), - cache.WithCreate[netip.AddrPort, *StdNatItem](func(key netip.AddrPort) *StdNatItem { - return &StdNatItem{} - }), - ) t := &StdTunnel{ - nat: nat, address: udpConfig.BindAddress, target: udpConfig.TargetAddress, reserved: slices.Clone(udpConfig.Reserved), @@ -78,27 +63,30 @@ func (t *StdTunnel) Handle() { go func() { defer BufPool.Put(buf) var err error - natItem, _ := t.nat.Get(addr) - natItem.Mutex.Lock() - remoteConn := natItem.Conn + v, _ := t.connMap.LoadOrStore(addr, &StdMapItem{}) + mapItem := v.(*StdMapItem) + mapItem.Mutex.Lock() + remoteConn := mapItem.Conn if remoteConn == nil { log.Println("Dial to", t.target, "for", addr) remoteConn, err = net.Dial("udp", t.target) if err != nil { - natItem.Mutex.Unlock() + mapItem.Mutex.Unlock() log.Println(err) return } log.Println("Associate from", addr, "to", remoteConn.RemoteAddr(), "by", remoteConn.LocalAddr()) - natItem.Conn = remoteConn + mapItem.Conn = remoteConn go func() { for { buf := BufPool.Get().([]byte) + _ = remoteConn.SetReadDeadline(time.Now().Add(MaxUdpAge)) // set timeout n, err := remoteConn.Read(buf) if err != nil { BufPool.Put(buf) - t.nat.Delete(addr) // it will call remoteConn.Close() inside - log.Println(err) + t.connMap.Delete(addr) + log.Println("Delete and close", remoteConn.LocalAddr(), "for", addr, "to", remoteConn.RemoteAddr(), "because", err) + _ = remoteConn.Close() return } if len(t.reserved) > 0 && n > len(t.reserved) { // wireguard reserved @@ -109,15 +97,15 @@ func (t *StdTunnel) Handle() { _, err = udpConn.WriteToUDPAddrPort(buf[:n], addr) BufPool.Put(buf) if err != nil { - t.nat.Delete(addr) // it will call remoteConn.Close() inside - log.Println(err) + t.connMap.Delete(addr) + log.Println("Delete and close", remoteConn.LocalAddr(), "for", addr, "to", remoteConn.RemoteAddr(), "because", err) + _ = remoteConn.Close() return } - t.nat.Get(addr) // refresh lru } }() } - natItem.Mutex.Unlock() + mapItem.Mutex.Unlock() if len(t.reserved) > 0 && n > len(t.reserved) { // wireguard reserved copy(buf[1:], t.reserved) } @@ -126,6 +114,7 @@ func (t *StdTunnel) Handle() { log.Println(err) return } + _ = remoteConn.SetReadDeadline(time.Now().Add(MaxUdpAge)) // refresh timeout }() }