Skip to content

Commit

Permalink
using sync/map for udp conn mapping
Browse files Browse the repository at this point in the history
  • Loading branch information
wwqgtxx committed Apr 22, 2023
1 parent f787f0c commit 8160628
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 79 deletions.
6 changes: 4 additions & 2 deletions udp/udp.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
package udp

import (
"github.com/wwqgtxx/wstunnel/config"
"log"
"net"
"time"

"github.com/wwqgtxx/wstunnel/config"
)

const MaxUdpAge = 5 * 60
const MaxUdpAge = 5 * time.Minute

type Tunnel interface {
Handle()
Expand Down
77 changes: 29 additions & 48 deletions udp/udp_mmsg.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ import (
"log"
"net"
"sync"
"time"

"github.com/wwqgtxx/wstunnel/config"
cache "github.com/wwqgtxx/wstunnel/utils/lrucache"
)

const (
Expand Down Expand Up @@ -40,35 +40,21 @@ var WriteMsgsBufPool = sync.Pool{New: func() any {
// This means we can use this struct to read from a socket that receives both IPv4 and IPv6 messages.
var _ ipv4.Message = ipv6.Message{}

type MmsgNatItem struct {
type MmsgMapItem struct {
net.Conn
*ipv4.PacketConn
sync.Mutex
}

type MmsgTunnel struct {
nat *cache.LruCache[string, *MmsgNatItem]
connMap sync.Map
address string
target string
reserved []byte
}

func NewMmsgTunnel(udpConfig config.UdpConfig) Tunnel {
nat := cache.New[string, *MmsgNatItem](
cache.WithAge[string, *MmsgNatItem](MaxUdpAge),
cache.WithUpdateAgeOnGet[string, *MmsgNatItem](),
cache.WithEvict[string, *MmsgNatItem](func(key string, value *MmsgNatItem) {
if conn := value.Conn; conn != nil {
log.Println("Delete", conn.LocalAddr(), "for", key, "to", conn.RemoteAddr())
_ = conn.Close()
}
}),
cache.WithCreate[string, *MmsgNatItem](func(key string) *MmsgNatItem {
return &MmsgNatItem{}
}),
)
t := &MmsgTunnel{
nat: nat,
address: udpConfig.BindAddress,
target: udpConfig.TargetAddress,
reserved: slices.Clone(udpConfig.Reserved),
Expand Down Expand Up @@ -141,22 +127,23 @@ func (t *MmsgTunnel) Handle() {
}
WriteMsgsBufPool.Put(wMsgs)
}()
natItem, _ := t.nat.Get(addr)
natItem.Mutex.Lock()
remoteConn := natItem.Conn
remotePacketConn := natItem.PacketConn
v, _ := t.connMap.LoadOrStore(addr, &MmsgMapItem{})
mapItem := v.(*MmsgMapItem)
mapItem.Mutex.Lock()
remoteConn := mapItem.Conn
remotePacketConn := mapItem.PacketConn
if remoteConn == nil || remotePacketConn == nil {
log.Println("Dial to", t.target, "for", addr)
remoteConn, err = net.Dial("udp", t.target)
if err != nil {
natItem.Mutex.Unlock()
mapItem.Mutex.Unlock()
log.Println(err)
return
}
log.Println("Associate from", addr, "to", remoteConn.RemoteAddr(), "by", remoteConn.LocalAddr())
remotePacketConn = ipv4.NewPacketConn(remoteConn.(*net.UDPConn))
natItem.Conn = remoteConn
natItem.PacketConn = remotePacketConn
mapItem.Conn = remoteConn
mapItem.PacketConn = remotePacketConn
go func() {
rMsgs := ReadMsgsBufPool.Get().([]ipv4.Message)
wMsgs := WriteMsgsBufPool.Get().([]ipv4.Message)
Expand All @@ -165,10 +152,12 @@ func (t *MmsgTunnel) Handle() {
WriteMsgsBufPool.Put(wMsgs)
}()
for {
_ = remoteConn.SetReadDeadline(time.Now().Add(MaxUdpAge)) // set timeout
n, err := remotePacketConn.ReadBatch(rMsgs, 0)
if err != nil {
t.nat.Delete(addr) // it will call conn.Close() inside
log.Println(err)
t.connMap.Delete(addr)
log.Println("Delete and close", remoteConn.LocalAddr(), "for", addr, "to", remoteConn.RemoteAddr(), "because", err)
_ = remoteConn.Close()
return
}
for i := 0; i < n; i++ {
Expand All @@ -184,28 +173,23 @@ func (t *MmsgTunnel) Handle() {
wMsgsN := n
if wMsgsN == 1 { // maybe faster
_, err = udpConn.WriteTo(wMsgs[0].Buffers[0], nAddr)
if err != nil {
t.nat.Delete(addr) // it will call conn.Close() inside
log.Println(err)
return
}
} else {
var wN int
wN, err = packetConn.WriteBatch(wMsgs[:wMsgsN], 0)
if err != nil {
t.nat.Delete(addr) // it will call conn.Close() inside
log.Println(err)
return
}
if wN != wMsgsN {
if err == nil && wN != wMsgsN {
log.Println("warning wN=", wN, "wMsgsN=", wMsgsN)
}
}
t.nat.Get(addr) // refresh lru
if err != nil {
t.connMap.Delete(addr)
log.Println("Delete and close", remoteConn.LocalAddr(), "for", addr, "to", remoteConn.RemoteAddr(), "because", err)
_ = remoteConn.Close()
return
}
}
}()
}
natItem.Mutex.Unlock()
mapItem.Mutex.Unlock()

for _, wMsg := range wMsgs[:wMsgsN] {
buf := wMsg.Buffers[0]
Expand All @@ -217,21 +201,18 @@ func (t *MmsgTunnel) Handle() {

if wMsgsN == 1 { // maybe faster
_, err = remoteConn.Write(wMsgs[0].Buffers[0])
if err != nil {
log.Println(err)
return
}
} else {
var wN int
wN, err = remotePacketConn.WriteBatch(wMsgs[:wMsgsN], 0)
if err != nil {
log.Println(err)
return
}
if wN != wMsgsN {
if err == nil && wN != wMsgsN {
log.Println("warning wN=", wN, "wMsgsN=", wMsgsN)
}
}
if err != nil {
log.Println(err)
return
}
_ = remoteConn.SetReadDeadline(time.Now().Add(MaxUdpAge)) // refresh timeout

}()
}
Expand Down
47 changes: 18 additions & 29 deletions udp/udp_std.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,10 @@ import (
"golang.org/x/net/ipv4"
"log"
"net"
"net/netip"
"sync"
"time"

"github.com/wwqgtxx/wstunnel/config"
cache "github.com/wwqgtxx/wstunnel/utils/lrucache"
)

const BufferSize = 16 * 1024
Expand All @@ -24,35 +23,21 @@ func ListenUdp(network, address string) (*net.UDPConn, error) {
return pc.(*net.UDPConn), nil
}

type StdNatItem struct {
type StdMapItem struct {
net.Conn
*ipv4.PacketConn
sync.Mutex
}

type StdTunnel struct {
nat *cache.LruCache[netip.AddrPort, *StdNatItem]
connMap sync.Map
address string
target string
reserved []byte
}

func NewStdTunnel(udpConfig config.UdpConfig) Tunnel {
nat := cache.New[netip.AddrPort, *StdNatItem](
cache.WithAge[netip.AddrPort, *StdNatItem](MaxUdpAge),
cache.WithUpdateAgeOnGet[netip.AddrPort, *StdNatItem](),
cache.WithEvict[netip.AddrPort, *StdNatItem](func(key netip.AddrPort, value *StdNatItem) {
if conn := value.Conn; conn != nil {
log.Println("Delete", conn.LocalAddr(), "for", key, "to", conn.RemoteAddr())
_ = conn.Close()
}
}),
cache.WithCreate[netip.AddrPort, *StdNatItem](func(key netip.AddrPort) *StdNatItem {
return &StdNatItem{}
}),
)
t := &StdTunnel{
nat: nat,
address: udpConfig.BindAddress,
target: udpConfig.TargetAddress,
reserved: slices.Clone(udpConfig.Reserved),
Expand All @@ -78,27 +63,30 @@ func (t *StdTunnel) Handle() {
go func() {
defer BufPool.Put(buf)
var err error
natItem, _ := t.nat.Get(addr)
natItem.Mutex.Lock()
remoteConn := natItem.Conn
v, _ := t.connMap.LoadOrStore(addr, &StdMapItem{})
mapItem := v.(*StdMapItem)
mapItem.Mutex.Lock()
remoteConn := mapItem.Conn
if remoteConn == nil {
log.Println("Dial to", t.target, "for", addr)
remoteConn, err = net.Dial("udp", t.target)
if err != nil {
natItem.Mutex.Unlock()
mapItem.Mutex.Unlock()
log.Println(err)
return
}
log.Println("Associate from", addr, "to", remoteConn.RemoteAddr(), "by", remoteConn.LocalAddr())
natItem.Conn = remoteConn
mapItem.Conn = remoteConn
go func() {
for {
buf := BufPool.Get().([]byte)
_ = remoteConn.SetReadDeadline(time.Now().Add(MaxUdpAge)) // set timeout
n, err := remoteConn.Read(buf)
if err != nil {
BufPool.Put(buf)
t.nat.Delete(addr) // it will call remoteConn.Close() inside
log.Println(err)
t.connMap.Delete(addr)
log.Println("Delete and close", remoteConn.LocalAddr(), "for", addr, "to", remoteConn.RemoteAddr(), "because", err)
_ = remoteConn.Close()
return
}
if len(t.reserved) > 0 && n > len(t.reserved) { // wireguard reserved
Expand All @@ -109,15 +97,15 @@ func (t *StdTunnel) Handle() {
_, err = udpConn.WriteToUDPAddrPort(buf[:n], addr)
BufPool.Put(buf)
if err != nil {
t.nat.Delete(addr) // it will call remoteConn.Close() inside
log.Println(err)
t.connMap.Delete(addr)
log.Println("Delete and close", remoteConn.LocalAddr(), "for", addr, "to", remoteConn.RemoteAddr(), "because", err)
_ = remoteConn.Close()
return
}
t.nat.Get(addr) // refresh lru
}
}()
}
natItem.Mutex.Unlock()
mapItem.Mutex.Unlock()
if len(t.reserved) > 0 && n > len(t.reserved) { // wireguard reserved
copy(buf[1:], t.reserved)
}
Expand All @@ -126,6 +114,7 @@ func (t *StdTunnel) Handle() {
log.Println(err)
return
}
_ = remoteConn.SetReadDeadline(time.Now().Add(MaxUdpAge)) // refresh timeout
}()

}
Expand Down

0 comments on commit 8160628

Please sign in to comment.