From 230145bb48cc84553ce8253e0885ef5db728c0c8 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Wed, 17 Jan 2024 14:16:03 +0100 Subject: [PATCH] Add stun request timeout and refactor code to make it more generic --- client/internal/engine.go | 29 +--------- client/internal/relay/relay.go | 96 ++++++++++++++++++++++++---------- 2 files changed, 69 insertions(+), 56 deletions(-) diff --git a/client/internal/engine.go b/client/internal/engine.go index 8353a39e215..bbce1dced9e 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -1276,34 +1276,9 @@ func (e *Engine) receiveProbeEvents() { } func (e *Engine) probeSTUNs() []relay.ProbeResult { - results := make([]relay.ProbeResult, len(e.STUNs)) - var wg sync.WaitGroup - for i, uri := range e.STUNs { - ctx, cancel := context.WithTimeout(e.ctx, 1*time.Second) - defer cancel() - - results[i].URI = uri - wg.Add(1) - go relay.ProbeSTUN(ctx, &wg, uri, &results[i]) - } - - wg.Wait() - - return results + return relay.ProbeAll(e.ctx, relay.ProbeSTUN, e.STUNs) } func (e *Engine) probeTURNs() []relay.ProbeResult { - results := make([]relay.ProbeResult, len(e.TURNs)) - var wg sync.WaitGroup - for i, uri := range e.TURNs { - ctx, cancel := context.WithTimeout(e.ctx, 1*time.Second) - defer cancel() - - results[i].URI = uri - wg.Add(1) - go relay.ProbeTURN(ctx, &wg, uri, &results[i]) - } - wg.Wait() - - return results + return relay.ProbeAll(e.ctx, relay.ProbeTURN, e.TURNs) } diff --git a/client/internal/relay/relay.go b/client/internal/relay/relay.go index c5511e280ae..1d8e6846d4e 100644 --- a/client/internal/relay/relay.go +++ b/client/internal/relay/relay.go @@ -5,66 +5,77 @@ import ( "fmt" "net" "sync" + "time" "github.com/pion/stun/v2" "github.com/pion/turn/v3" log "github.com/sirupsen/logrus" ) +// ProbeResult holds the info about the result of a relay probe request type ProbeResult struct { URI *stun.URI Err error Addr string } -func ProbeSTUN(_ context.Context, wg *sync.WaitGroup, uri *stun.URI, result *ProbeResult) { - defer wg.Done() - +// ProbeSTUN tries binding to the given STUN uri and acquiring an address +func ProbeSTUN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) { defer func() { - if result.Err != nil { - log.Debugf("stun probe error from %s: %s", uri, result.Err) + if probeErr != nil { + log.Debugf("stun probe error from %s: %s", uri, probeErr) } }() client, err := stun.DialURI(uri, &stun.DialConfig{}) if err != nil { - result.Err = fmt.Errorf("dial: %w", err) + probeErr = fmt.Errorf("dial: %w", err) return } defer func() { - if err := client.Close(); err != nil && result.Err == nil { - result.Err = fmt.Errorf("close: %w", err) + if err := client.Close(); err != nil && probeErr == nil { + probeErr = fmt.Errorf("close: %w", err) } }() - if err = client.Do(stun.MustBuild(stun.TransactionID, stun.BindingRequest), func(res stun.Event) { + done := make(chan struct{}) + if err = client.Start(stun.MustBuild(stun.TransactionID, stun.BindingRequest), func(res stun.Event) { if res.Error != nil { - result.Err = fmt.Errorf("request: %w", err) + probeErr = fmt.Errorf("request: %w", err) return } var xorAddr stun.XORMappedAddress if getErr := xorAddr.GetFrom(res.Message); getErr != nil { - result.Err = fmt.Errorf("get xor addr: %w", err) + probeErr = fmt.Errorf("get xor addr: %w", err) return } log.Debugf("stun probe received address from %s: %s", uri, xorAddr) - result.Addr = xorAddr.String() + addr = xorAddr.String() + + done <- struct{}{} }); err != nil { - result.Err = fmt.Errorf("client: %w", err) + probeErr = fmt.Errorf("client: %w", err) return } -} + select { + case <-ctx.Done(): + probeErr = fmt.Errorf("stun request: %w", ctx.Err()) + return + case <-done: + } -func ProbeTURN(ctx context.Context, wg *sync.WaitGroup, uri *stun.URI, result *ProbeResult) { - defer wg.Done() + return addr, nil +} +// ProbeTURN tries allocating a session from the given TURN URI +func ProbeTURN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) { defer func() { - if result.Err != nil { - log.Debugf("turn probe error from %s: %s", uri, result.Err) + if probeErr != nil { + log.Debugf("turn probe error from %s: %s", uri, probeErr) } }() @@ -76,25 +87,25 @@ func ProbeTURN(ctx context.Context, wg *sync.WaitGroup, uri *stun.URI, result *P var err error conn, err = net.ListenPacket("udp", "") if err != nil { - result.Err = fmt.Errorf("listen: %w", err) + probeErr = fmt.Errorf("listen: %w", err) return } case stun.ProtoTypeTCP: dialer := net.Dialer{} tcpConn, err := dialer.DialContext(ctx, "tcp", turnServerAddr) if err != nil { - result.Err = fmt.Errorf("dial: %w", err) + probeErr = fmt.Errorf("dial: %w", err) return } conn = turn.NewSTUNConn(tcpConn) default: - result.Err = fmt.Errorf("conn: unknown proto: %s", uri.Proto) + probeErr = fmt.Errorf("conn: unknown proto: %s", uri.Proto) return } defer func() { - if err := conn.Close(); err != nil && result.Err == nil { - result.Err = fmt.Errorf("conn close: %w", err) + if err := conn.Close(); err != nil && probeErr == nil { + probeErr = fmt.Errorf("conn close: %w", err) } }() @@ -107,27 +118,54 @@ func ProbeTURN(ctx context.Context, wg *sync.WaitGroup, uri *stun.URI, result *P } client, err := turn.NewClient(cfg) if err != nil { - result.Err = fmt.Errorf("create client: %w", err) + probeErr = fmt.Errorf("create client: %w", err) return } defer client.Close() if err := client.Listen(); err != nil { - result.Err = fmt.Errorf("client listen: %w", err) + probeErr = fmt.Errorf("client listen: %w", err) return } relayConn, err := client.Allocate() if err != nil { - result.Err = fmt.Errorf("allocate: %w", err) + probeErr = fmt.Errorf("allocate: %w", err) return } defer func() { - if err := relayConn.Close(); err != nil && result.Err == nil { - result.Err = fmt.Errorf("close relay conn: %w", err) + if err := relayConn.Close(); err != nil && probeErr == nil { + probeErr = fmt.Errorf("close relay conn: %w", err) } }() log.Debugf("turn probe relay address from %s: %s", uri, relayConn.LocalAddr()) - result.Addr = relayConn.LocalAddr().String() + + return relayConn.LocalAddr().String(), nil +} + +// ProbeAll probes all given servers asynchronously and returns the results +func ProbeAll( + ctx context.Context, + fn func(ctx context.Context, uri *stun.URI) (addr string, probeErr error), + relays []*stun.URI, +) []ProbeResult { + results := make([]ProbeResult, len(relays)) + + var wg sync.WaitGroup + for i, uri := range relays { + ctx, cancel := context.WithTimeout(ctx, 1*time.Second) + defer cancel() + + wg.Add(1) + go func(res *ProbeResult, stunURI *stun.URI) { + defer wg.Done() + res.URI = stunURI + res.Addr, res.Err = fn(ctx, stunURI) + }(&results[i], uri) + } + + wg.Wait() + + return results }