Skip to content

Commit

Permalink
Merge pull request #308 from cfergeau/sshrace
Browse files Browse the repository at this point in the history
ssh: Add retries to setupProxy
  • Loading branch information
openshift-merge-bot[bot] authored Jan 11, 2024
2 parents f01fd1c + 8357aa4 commit 3cb88d9
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 16 deletions.
8 changes: 4 additions & 4 deletions pkg/sshclient/bastion.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,13 @@ func HostKey(host string) ssh.PublicKey {
return nil
}

func CreateBastion(_url *url.URL, passPhrase string, identity string, initial net.Conn, connect ConnectCallback) (Bastion, error) {
func CreateBastion(_url *url.URL, passPhrase string, identity string, initial net.Conn, connect ConnectCallback) (*Bastion, error) {
var authMethods []ssh.AuthMethod

if len(identity) > 0 {
s, err := PublicKey(identity, []byte(passPhrase))
if err != nil {
return Bastion{}, errors.Wrapf(err, "failed to parse identity %q", identity)
return nil, errors.Wrapf(err, "failed to parse identity %q", identity)
}
authMethods = append(authMethods, ssh.PublicKeys(s))
}
Expand All @@ -100,7 +100,7 @@ func CreateBastion(_url *url.URL, passPhrase string, identity string, initial ne
}

if len(authMethods) == 0 {
return Bastion{}, errors.New("No available auth methods")
return nil, errors.New("No available auth methods")
}

port := _url.Port()
Expand Down Expand Up @@ -149,7 +149,7 @@ func CreateBastion(_url *url.URL, passPhrase string, identity string, initial ne
}

bastion := Bastion{nil, config, _url.Hostname(), port, _url.Path, connect}
return bastion, bastion.reconnect(context.Background(), initial)
return &bastion, bastion.reconnect(context.Background(), initial)
}

func (bastion *Bastion) Reconnect(ctx context.Context) error {
Expand Down
38 changes: 26 additions & 12 deletions pkg/sshclient/ssh_forwarder.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package sshclient

import (
"context"
"fmt"
"io"
"net"
"net/url"
Expand Down Expand Up @@ -170,42 +171,55 @@ func setupProxy(ctx context.Context, socketURI *url.URL, dest *url.URL, identity
return &SSHForward{}, err
}

bastion, err := CreateBastion(dest, passphrase, identity, conn, connectFunc)
createBastion := func() (*Bastion, error) {
return CreateBastion(dest, passphrase, identity, conn, connectFunc)
}
bastion, err := retry(ctx, createBastion, "Waiting for sshd")
if err != nil {
return &SSHForward{}, err
return &SSHForward{}, fmt.Errorf("setupProxy failed: %w", err)
}

logrus.Debugf("Socket forward established: %s -> %s\n", socketURI.Path, dest.Path)

return &SSHForward{listener, &bastion, socketURI}, nil
return &SSHForward{listener, bastion, socketURI}, nil
}

func initialConnection(ctx context.Context, connectFunc ConnectCallback) (net.Conn, error) {
const maxRetries = 60
const initialBackoff = 100 * time.Millisecond

func retry[T comparable](ctx context.Context, retryFunc func() (T, error), retryMsg string) (T, error) {
var (
conn net.Conn
err error
returnVal T
err error
)

backoff := 100 * time.Millisecond
backoff := initialBackoff

loop:
for i := 0; i < 60; i++ {
for i := 0; i < maxRetries; i++ {
select {
case <-ctx.Done():
break loop
default:
// proceed
}

conn, err = connectFunc(ctx, nil)
returnVal, err = retryFunc()
if err == nil {
break
return returnVal, nil
}
logrus.Debugf("Waiting for sshd: %s", backoff)
logrus.Debugf("%s (%s)", retryMsg, backoff)
sleep(ctx, backoff)
backoff = backOff(backoff)
}
return conn, err
return returnVal, fmt.Errorf("timeout: %w", err)
}

func initialConnection(ctx context.Context, connectFunc ConnectCallback) (net.Conn, error) {
retryFunc := func() (net.Conn, error) {
return connectFunc(ctx, nil)
}
return retry(ctx, retryFunc, "Waiting for sshd socket")
}

func acceptConnection(ctx context.Context, listener net.Listener, bastion *Bastion, socketURI *url.URL) error {
Expand Down

0 comments on commit 3cb88d9

Please sign in to comment.