Skip to content

Commit

Permalink
Add stun request timeout and refactor code to make it more generic
Browse files Browse the repository at this point in the history
  • Loading branch information
lixmal committed Jan 17, 2024
1 parent 679e162 commit 230145b
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 56 deletions.
29 changes: 2 additions & 27 deletions client/internal/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -1276,34 +1276,9 @@ func (e *Engine) receiveProbeEvents() {
}

func (e *Engine) probeSTUNs() []relay.ProbeResult {
results := make([]relay.ProbeResult, len(e.STUNs))
var wg sync.WaitGroup
for i, uri := range e.STUNs {
ctx, cancel := context.WithTimeout(e.ctx, 1*time.Second)
defer cancel()

results[i].URI = uri
wg.Add(1)
go relay.ProbeSTUN(ctx, &wg, uri, &results[i])
}

wg.Wait()

return results
return relay.ProbeAll(e.ctx, relay.ProbeSTUN, e.STUNs)
}

func (e *Engine) probeTURNs() []relay.ProbeResult {
results := make([]relay.ProbeResult, len(e.TURNs))
var wg sync.WaitGroup
for i, uri := range e.TURNs {
ctx, cancel := context.WithTimeout(e.ctx, 1*time.Second)
defer cancel()

results[i].URI = uri
wg.Add(1)
go relay.ProbeTURN(ctx, &wg, uri, &results[i])
}
wg.Wait()

return results
return relay.ProbeAll(e.ctx, relay.ProbeTURN, e.TURNs)
}
96 changes: 67 additions & 29 deletions client/internal/relay/relay.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,66 +5,77 @@ import (
"fmt"
"net"
"sync"
"time"

"github.com/pion/stun/v2"
"github.com/pion/turn/v3"
log "github.com/sirupsen/logrus"
)

// ProbeResult holds the info about the result of a relay probe request
type ProbeResult struct {
URI *stun.URI
Err error
Addr string
}

func ProbeSTUN(_ context.Context, wg *sync.WaitGroup, uri *stun.URI, result *ProbeResult) {
defer wg.Done()

// ProbeSTUN tries binding to the given STUN uri and acquiring an address
func ProbeSTUN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) {
defer func() {
if result.Err != nil {
log.Debugf("stun probe error from %s: %s", uri, result.Err)
if probeErr != nil {
log.Debugf("stun probe error from %s: %s", uri, probeErr)
}
}()

client, err := stun.DialURI(uri, &stun.DialConfig{})
if err != nil {
result.Err = fmt.Errorf("dial: %w", err)
probeErr = fmt.Errorf("dial: %w", err)
return
}

defer func() {
if err := client.Close(); err != nil && result.Err == nil {
result.Err = fmt.Errorf("close: %w", err)
if err := client.Close(); err != nil && probeErr == nil {
probeErr = fmt.Errorf("close: %w", err)
}
}()

if err = client.Do(stun.MustBuild(stun.TransactionID, stun.BindingRequest), func(res stun.Event) {
done := make(chan struct{})
if err = client.Start(stun.MustBuild(stun.TransactionID, stun.BindingRequest), func(res stun.Event) {
if res.Error != nil {
result.Err = fmt.Errorf("request: %w", err)
probeErr = fmt.Errorf("request: %w", err)
return
}

var xorAddr stun.XORMappedAddress
if getErr := xorAddr.GetFrom(res.Message); getErr != nil {
result.Err = fmt.Errorf("get xor addr: %w", err)
probeErr = fmt.Errorf("get xor addr: %w", err)
return
}

log.Debugf("stun probe received address from %s: %s", uri, xorAddr)
result.Addr = xorAddr.String()
addr = xorAddr.String()

done <- struct{}{}
}); err != nil {
result.Err = fmt.Errorf("client: %w", err)
probeErr = fmt.Errorf("client: %w", err)
return
}

}
select {
case <-ctx.Done():
probeErr = fmt.Errorf("stun request: %w", ctx.Err())
return
case <-done:
}

func ProbeTURN(ctx context.Context, wg *sync.WaitGroup, uri *stun.URI, result *ProbeResult) {
defer wg.Done()
return addr, nil
}

// ProbeTURN tries allocating a session from the given TURN URI
func ProbeTURN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) {
defer func() {
if result.Err != nil {
log.Debugf("turn probe error from %s: %s", uri, result.Err)
if probeErr != nil {
log.Debugf("turn probe error from %s: %s", uri, probeErr)
}
}()

Expand All @@ -76,25 +87,25 @@ func ProbeTURN(ctx context.Context, wg *sync.WaitGroup, uri *stun.URI, result *P
var err error
conn, err = net.ListenPacket("udp", "")
if err != nil {
result.Err = fmt.Errorf("listen: %w", err)
probeErr = fmt.Errorf("listen: %w", err)
return
}
case stun.ProtoTypeTCP:
dialer := net.Dialer{}
tcpConn, err := dialer.DialContext(ctx, "tcp", turnServerAddr)
if err != nil {
result.Err = fmt.Errorf("dial: %w", err)
probeErr = fmt.Errorf("dial: %w", err)
return
}
conn = turn.NewSTUNConn(tcpConn)
default:
result.Err = fmt.Errorf("conn: unknown proto: %s", uri.Proto)
probeErr = fmt.Errorf("conn: unknown proto: %s", uri.Proto)
return
}

defer func() {
if err := conn.Close(); err != nil && result.Err == nil {
result.Err = fmt.Errorf("conn close: %w", err)
if err := conn.Close(); err != nil && probeErr == nil {
probeErr = fmt.Errorf("conn close: %w", err)
}
}()

Expand All @@ -107,27 +118,54 @@ func ProbeTURN(ctx context.Context, wg *sync.WaitGroup, uri *stun.URI, result *P
}
client, err := turn.NewClient(cfg)
if err != nil {
result.Err = fmt.Errorf("create client: %w", err)
probeErr = fmt.Errorf("create client: %w", err)
return
}
defer client.Close()

if err := client.Listen(); err != nil {
result.Err = fmt.Errorf("client listen: %w", err)
probeErr = fmt.Errorf("client listen: %w", err)
return
}

relayConn, err := client.Allocate()
if err != nil {
result.Err = fmt.Errorf("allocate: %w", err)
probeErr = fmt.Errorf("allocate: %w", err)
return
}
defer func() {
if err := relayConn.Close(); err != nil && result.Err == nil {
result.Err = fmt.Errorf("close relay conn: %w", err)
if err := relayConn.Close(); err != nil && probeErr == nil {
probeErr = fmt.Errorf("close relay conn: %w", err)
}
}()

log.Debugf("turn probe relay address from %s: %s", uri, relayConn.LocalAddr())
result.Addr = relayConn.LocalAddr().String()

return relayConn.LocalAddr().String(), nil
}

// ProbeAll probes all given servers asynchronously and returns the results
func ProbeAll(
ctx context.Context,
fn func(ctx context.Context, uri *stun.URI) (addr string, probeErr error),
relays []*stun.URI,
) []ProbeResult {
results := make([]ProbeResult, len(relays))

var wg sync.WaitGroup
for i, uri := range relays {
ctx, cancel := context.WithTimeout(ctx, 1*time.Second)
defer cancel()

wg.Add(1)
go func(res *ProbeResult, stunURI *stun.URI) {
defer wg.Done()
res.URI = stunURI
res.Addr, res.Err = fn(ctx, stunURI)
}(&results[i], uri)
}

wg.Wait()

return results
}

0 comments on commit 230145b

Please sign in to comment.