diff --git a/pkg/sshclient/ssh_forwarder.go b/pkg/sshclient/ssh_forwarder.go index 38075c230..2cd7ab235 100644 --- a/pkg/sshclient/ssh_forwarder.go +++ b/pkg/sshclient/ssh_forwarder.go @@ -2,6 +2,7 @@ package sshclient import ( "context" + "fmt" "io" "net" "net/url" @@ -180,16 +181,19 @@ func setupProxy(ctx context.Context, socketURI *url.URL, dest *url.URL, identity return &SSHForward{listener, &bastion, socketURI}, nil } -func initialConnection(ctx context.Context, connectFunc ConnectCallback) (net.Conn, error) { +const maxRetries = 60 +const initialBackoff = 100 * time.Millisecond + +func retry[T comparable](ctx context.Context, retryFunc func() (T, error), retryMsg string) (T, error) { var ( - conn net.Conn - err error + returnVal T + err error ) - backoff := 100 * time.Millisecond + backoff := initialBackoff loop: - for i := 0; i < 60; i++ { + for i := 0; i < maxRetries; i++ { select { case <-ctx.Done(): break loop @@ -197,15 +201,22 @@ loop: // proceed } - conn, err = connectFunc(ctx, nil) + returnVal, err = retryFunc() if err == nil { - break + return returnVal, nil } - logrus.Debugf("Waiting for sshd: %s", backoff) + logrus.Debugf("%s (%s)", retryMsg, backoff) sleep(ctx, backoff) backoff = backOff(backoff) } - return conn, err + return returnVal, fmt.Errorf("timeout: %w", err) +} + +func initialConnection(ctx context.Context, connectFunc ConnectCallback) (net.Conn, error) { + retryFunc := func() (net.Conn, error) { + return connectFunc(ctx, nil) + } + return retry(ctx, retryFunc, "Waiting for sshd socket") } func acceptConnection(ctx context.Context, listener net.Listener, bastion *Bastion, socketURI *url.URL) error {