Skip to content

Commit

Permalink
Merge pull request #170 from bkneis/feature/implement-dial-context
Browse files Browse the repository at this point in the history
Add timeout to Dial to avoid infinite timeout
  • Loading branch information
gammazero authored May 4, 2019
2 parents 1dcc590 + 5090e35 commit 88fb9e6
Show file tree
Hide file tree
Showing 4 changed files with 137 additions and 25 deletions.
42 changes: 42 additions & 0 deletions client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -653,3 +653,45 @@ func (ks *serverKeyStore) AuthRole(authid string) (string, error) {
}
return "user", nil
}

// ---- network testing ----

func TestConnectContext(t *testing.T) {
const (
expect = "dial tcp: operation was canceled"
unixExpect = "dial unix /tmp/wamp.sock: operation was canceled"
)

cfg := Config{
Realm: testRealm,
ResponseTimeout: 500 * time.Millisecond,
Logger: logger,
Debug: false,
}
ctx, cancel := context.WithCancel(context.Background())
cancel()
_, err := ConnectNetContext(ctx, "http://localhost:9999/ws", cfg)
if err == nil || err.Error() != expect {
t.Fatalf("expected error %s, got %s", expect, err)
}

_, err = ConnectNetContext(ctx, "https://localhost:9999/ws", cfg)
if err == nil || err.Error() != expect {
t.Fatalf("expected error %s, got %s", expect, err)
}

_, err = ConnectNetContext(ctx, "tcp://localhost:9999", cfg)
if err == nil || err.Error() != expect {
t.Fatalf("expected error %s, got %s", expect, err)
}

_, err = ConnectNetContext(ctx, "tcps://localhost:9999", cfg)
if err == nil || err.Error() != expect {
t.Fatalf("expected error %s, got %s", expect, err)
}

_, err = ConnectNetContext(ctx, "unix:///tmp/wamp.sock", cfg)
if err == nil || err.Error() != unixExpect {
t.Fatalf("expected error %s, got %s", expect, err)
}
}
15 changes: 10 additions & 5 deletions client/network.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package client

import (
"context"
"fmt"
"log"
"net/url"
Expand All @@ -11,6 +12,10 @@ import (
"github.com/gammazero/nexus/wamp"
)

func ConnectNet(routerURL string, cfg Config) (*Client, error) {
return ConnectNetContext(context.Background(), routerURL, cfg)
}

// ConnectNet creates a new client connected a WAMP router over a websocket,
// TCP socket, or unix socket. The new client joins the realm specified in the
// Config.
Expand All @@ -31,7 +36,7 @@ import (
// For Unix socket clients, the routerURL has the form "unix://path". The path
// portion specifies a path on the local file system where the Unix socket is
// created. TLS is not used for unix socket.
func ConnectNet(routerURL string, cfg Config) (*Client, error) {
func ConnectNetContext(ctx context.Context, routerURL string, cfg Config) (*Client, error) {
if cfg.Logger == nil {
cfg.Logger = log.New(os.Stderr, "", 0)
}
Expand All @@ -51,17 +56,17 @@ func ConnectNet(routerURL string, cfg Config) (*Client, error) {
routerURL = u.String()
fallthrough
case "ws", "wss":
p, err = transport.ConnectWebsocketPeer(routerURL, cfg.Serialization,
p, err = transport.ConnectWebsocketPeerContext(ctx, routerURL, cfg.Serialization,
cfg.TlsCfg, cfg.Dial, cfg.Logger, &cfg.WsCfg)
case "tcp":
p, err = transport.ConnectRawSocketPeer(u.Scheme, u.Host,
p, err = transport.ConnectRawSocketPeerContext(ctx, u.Scheme, u.Host,
cfg.Serialization, cfg.Logger, cfg.RecvLimit)
case "tcps":
p, err = transport.ConnectTlsRawSocketPeer("tcp", u.Host,
p, err = transport.ConnectTlsRawSocketPeerContext(ctx, "tcp", u.Host,
cfg.Serialization, cfg.TlsCfg, cfg.Logger, cfg.RecvLimit)
case "unix":
path := strings.TrimRight(u.Host+u.Path, "/")
p, err = transport.ConnectRawSocketPeer(u.Scheme, path,
p, err = transport.ConnectRawSocketPeerContext(ctx, u.Scheme, path,
cfg.Serialization, cfg.Logger, cfg.RecvLimit)
default:
err = fmt.Errorf("invalid url: %s", routerURL)
Expand Down
74 changes: 60 additions & 14 deletions transport/rawsocketpeer.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,47 +49,74 @@ const (
magic = 0x7f
)

// ConnectRawSocketPeer creates a new rawSocketPeer with the specified config,
// and connects it to the WAMP router at the specified address. The network
// and address parameters are documented here:
// ConnectRawSocketPeer calls ConnectRawSocketPeerContext without a context.
func ConnectRawSocketPeer(network, address string, serialization serialize.Serialization, logger stdlog.StdLog, recvLimit int) (wamp.Peer, error) {
return ConnectRawSocketPeerContext(context.Background(), network, address, serialization, logger, recvLimit)
}

// ConnectRawSocketPeerContext creates a new rawSocketPeer with the specified
// config, and connects it to the WAMP router at the specified address. The
// network and address parameters are documented here:
// https://golang.org/pkg/net/#Dial
//
// The provided Context must be non-nil. If the context expires before the
// connection is complete, an error is returned. Once successfully connected,
// any expiration of the context will not affect the connection.
//
// If recvLimit is > 0, then the client will not receive messages with size
// larger than the nearest power of 2 greater than or equal to recvLimit. If
// recvLimit is <= 0, then the default of 16M is used.
func ConnectRawSocketPeer(network, address string, serialization serialize.Serialization, logger stdlog.StdLog, recvLimit int) (wamp.Peer, error) {
err := checkNetworkType(network)
func ConnectRawSocketPeerContext(ctx context.Context, network, address string, serialization serialize.Serialization, logger stdlog.StdLog, recvLimit int) (wamp.Peer, error) {
var (
protocol byte
conn net.Conn
err error
peer *rawSocketPeer
)

err = checkNetworkType(network)
if err != nil {
return nil, err
}

protocol, err := getProtoByte(serialization)
protocol, err = getProtoByte(serialization)
if err != nil {
return nil, err
}

conn, err := net.Dial(network, address)
var d net.Dialer
conn, err = d.DialContext(ctx, network, address)
if err != nil {
return nil, err
}

peer, err := clientHandshake(conn, logger, protocol, recvLimit)
peer, err = clientHandshake(conn, logger, protocol, recvLimit)
if err != nil {
conn.Close()
return nil, err
}
return peer, nil
}

// ConnectTlsRawSocketPeer creates a new rawSocketPeer with the specified
// config, and connects it, using TLS, to the WAMP router at the specified
// address. The network, address, and tlscfg parameters are documented here:
// https://golang.org/pkg/crypto/tls/#Dial
// ConnectTlsRawSocketPeer calls ConnectTlsRawSocketPeerContext without a Dial
// context.
func ConnectTlsRawSocketPeer(network, address string, serialization serialize.Serialization, tlsConfig *tls.Config, logger stdlog.StdLog, recvLimit int) (wamp.Peer, error) {
return ConnectTlsRawSocketPeerContext(context.Background(), network, address, serialization, tlsConfig, logger, recvLimit)
}

// ConnectTlsRawSocketPeerContext creates a new rawSocketPeer with the
// specified config, and connects it, using TLS, to the WAMP router at the
// specified address. The network, address, and tlscfg parameters are
// documented here: https://golang.org/pkg/crypto/tls/#Dial
//
// The provided Context must be non-nil. If the context expires before the
// connection is complete, an error is returned. Once successfully connected,
// any expiration of the context will not affect the connection.
//
// If recvLimit is > 0, then the client will not receive messages with size
// larger than the nearest power of 2 greater than or equal to recvLimit. If
// recvLimit is <= 0, then the default of 16M is used.
func ConnectTlsRawSocketPeer(network, address string, serialization serialize.Serialization, tlsConfig *tls.Config, logger stdlog.StdLog, recvLimit int) (wamp.Peer, error) {
func ConnectTlsRawSocketPeerContext(ctx context.Context, network, address string, serialization serialize.Serialization, tlsConfig *tls.Config, logger stdlog.StdLog, recvLimit int) (wamp.Peer, error) {
err := checkNetworkType(network)
if err != nil {
return nil, err
Expand All @@ -100,11 +127,30 @@ func ConnectTlsRawSocketPeer(network, address string, serialization serialize.Se
return nil, err
}

conn, err := tls.Dial(network, address, tlsConfig)
var d net.Dialer
rawConn, err := d.DialContext(ctx, network, address)
if err != nil {
return nil, err
}

conn := tls.Client(rawConn, tlsConfig)
errChannel := make(chan error, 1)

go func() {
errChannel <- conn.Handshake()
}()

// Wait for TLS handshake to complete or context to expire
select {
case err = <-errChannel:
if err != nil {
rawConn.Close()
return nil, err
}
case <-ctx.Done():
return nil, ctx.Err()
}

peer, err := clientHandshake(conn, logger, protocol, recvLimit)
if err != nil {
conn.Close()
Expand Down
31 changes: 25 additions & 6 deletions transport/websocketpeer.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,14 +71,32 @@ const (

type DialFunc func(network, addr string) (net.Conn, error)

// ConnectWebsocketPeer creates a new websocket client with the specified
// config, connects the client to the websocket server at the specified URL,
// and returns the connected websocket peer.
func ConnectWebsocketPeer(routerURL string, serialization serialize.Serialization, tlsConfig *tls.Config, dial DialFunc, logger stdlog.StdLog, wsCfg *WebsocketConfig) (wamp.Peer, error) {
// ConnectWebsocketPeer calls ConnectWebsocketPeerContext without a Dial
// context.
func ConnectWebsocketPeer(
routerURL string,
serialization serialize.Serialization,
tlsConfig *tls.Config,
dial DialFunc,
logger stdlog.StdLog,
wsCfg *WebsocketConfig) (wamp.Peer, error) {
return ConnectWebsocketPeerContext(context.Background(), routerURL, serialization, tlsConfig, dial, logger, wsCfg)
}

// ConnectWebsocketPeerContext creates a new websocket client with the
// specified config, connects the client to the websocket server at the
// specified URL, and returns the connected websocket peer.
//
// The provided Context must be non-nil. If the context expires before the
// connection is complete, an error is returned. Once successfully connected,
// any expiration of the context will not affect the connection.
func ConnectWebsocketPeerContext(ctx context.Context, routerURL string, serialization serialize.Serialization, tlsConfig *tls.Config, dial DialFunc, logger stdlog.StdLog, wsCfg *WebsocketConfig) (wamp.Peer, error) {
var (
protocol string
payloadType int
serializer serialize.Serializer
conn *websocket.Conn
err error
)

switch serialization {
Expand Down Expand Up @@ -107,7 +125,8 @@ func ConnectWebsocketPeer(routerURL string, serialization serialize.Serializatio

if wsCfg != nil {
if wsCfg.ProxyURL != "" {
proxyURL, err := url.Parse(wsCfg.ProxyURL)
var proxyURL *url.URL
proxyURL, err = url.Parse(wsCfg.ProxyURL)
if err != nil {
return nil, err
}
Expand All @@ -117,7 +136,7 @@ func ConnectWebsocketPeer(routerURL string, serialization serialize.Serializatio
dialer.EnableCompression = true
}

conn, _, err := dialer.Dial(routerURL, nil)
conn, _, err = dialer.DialContext(ctx, routerURL, nil)
if err != nil {
return nil, err
}
Expand Down

0 comments on commit 88fb9e6

Please sign in to comment.