Skip to content

Commit

Permalink
Dial timeout client side implementation (#132)
Browse files Browse the repository at this point in the history
* Plumb dial options into targets/proxy.

Do this by allowing an optional semi-colon after the name with a time.Duration style.

Add tests in proxy client side to validate. Also address some holes we allowed before (like blank targets).

Plug into CLI and then add integration tests to prove it all works.
  • Loading branch information
sfc-gh-jchacon authored May 27, 2022
1 parent a3222fb commit 5d7694c
Show file tree
Hide file tree
Showing 4 changed files with 150 additions and 19 deletions.
30 changes: 23 additions & 7 deletions cmd/sanssh/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,23 @@ func hasPort(s string) bool {
return strings.LastIndex(s, "]") < strings.LastIndex(s, ":")
}

func validateAndAddPort(s string, port int) string {
// See if there's a duration appended and pull it off
p := strings.Split(s, ";")
if len(p) == 0 || len(p) > 2 || p[0] == "" {
log.Fatalf("Invalid address %q - should be of the form host[:port][;<duration>]", s)
}
new := s
if !hasPort(p[0]) {
new = fmt.Sprintf("%s:%d", p[0], port)
if len(p) == 2 {
// Add duration back if we pulled it off.
new = fmt.Sprintf("%s;%s", new, p[1])
}
}
return new
}

func main() {
// If this is blank it'll remain blank which is fine
// as that means just talk to --targets[0] instead.
Expand Down Expand Up @@ -142,16 +159,15 @@ func main() {
}
}

// Add the default proxy port (if needed).
if *proxyAddr != "" && !hasPort(*proxyAddr) {
*proxyAddr = fmt.Sprintf("%s:%d", *proxyAddr, defaultProxyPort)
// Validate and add the default proxy port (if needed).
if *proxyAddr != "" {
*proxyAddr = validateAndAddPort(*proxyAddr, defaultProxyPort)
}
// Add default target port (if needed).
// Validate and add the default target port (if needed) for each target.
for i, t := range *targetsFlag.Target {
if !hasPort(t) {
(*targetsFlag.Target)[i] = fmt.Sprintf("%s:%d", t, defaultTargetPort)
}
(*targetsFlag.Target)[i] = validateAndAddPort(t, defaultTargetPort)
}

clientPolicy := cmdUtil.ChoosePolicy(logr.Discard(), "", *clientPolicyFlag, *clientPolicyFile)

logOpts := log.Ldate | log.Ltime | log.Lshortfile
Expand Down
78 changes: 73 additions & 5 deletions proxy/proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,16 @@ import (
"context"
"io"
"log"
"strings"
"time"

"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/anypb"
"google.golang.org/protobuf/types/known/durationpb"

proxypb "github.com/Snowflake-Labs/sansshell/proxy"
)
Expand All @@ -40,6 +43,9 @@ type Conn struct {
// The targets we're proxying for currently.
Targets []string

// Possible dial timeouts for each target
timeouts []*time.Duration

// The RPC connection to the proxy.
cc *grpc.ClientConn

Expand Down Expand Up @@ -354,6 +360,9 @@ func (p *Conn) createStreams(ctx context.Context, method string) (proxypb.Proxy_
},
},
}
if p.timeouts[i] != nil {
req.GetStartStream().DialTimeout = durationpb.New(*p.timeouts[i])
}
err = stream.Send(req)

// If Send reports an error and is EOF we have to use Recv to get the actual error according to documentation
Expand Down Expand Up @@ -571,33 +580,92 @@ func (p *Conn) Close() error {
return p.cc.Close()
}

func parseTargets(targets []string) ([]string, []*time.Duration, error) {
var hostport []string
var timeouts []*time.Duration
for _, t := range targets {
if len(t) == 0 {
return nil, nil, status.Error(codes.InvalidArgument, "blank targets are not allowed")
}
// First pull off any possible duration
p := strings.Split(t, ";")
if len(p) > 2 {
return nil, nil, status.Errorf(codes.InvalidArgument, "target must be of the form host[:port][;N<time.Duration>] and %s is invalid", t)
}
if p[0] == "" {
return nil, nil, status.Error(codes.InvalidArgument, "blank targets are not allowed")
}
// That's the most we'll parse target here. The rest can happen in Invoke/NewStream later.
hostport = append(hostport, p[0])
if len(p) == 1 {
// No timeout, so just append a placeholder empty value.
timeouts = append(timeouts, nil)
continue
}
d, err := time.ParseDuration(p[1])
if err != nil {
return nil, nil, status.Errorf(codes.InvalidArgument, "%s invalid duration - %v", p[1], err)
}
timeouts = append(timeouts, &d)
}
return hostport, timeouts, nil
}

// Dial will connect to the given proxy and setup to send RPCs to the listed targets.
// If proxy is blank and there is only one target this will return a normal grpc connection object (*grpc.ClientConn).
// Otherwise this will return a *ProxyConn setup to act with the proxy.
// Otherwise this will return a *ProxyConn setup to act with the proxy. Targets is a list of normal gRPC style
// endpoint addresses with an optional dial timeout appended with a semi-colon in time.Duration format.
// i.e. host[:port][;Ns] for instance to set the dial timeout to N seconds. The proxy value can also specify a dial timeout
// in the same fashion.
func Dial(proxy string, targets []string, opts ...grpc.DialOption) (*Conn, error) {
return DialContext(context.Background(), proxy, targets, opts...)
}

// DialContext is the same as Dial except the context provided can be used to cancel or expire the pending connection.
// By default dial operations are non-blocking. See grpc.Dial for a complete explanation.
func DialContext(ctx context.Context, proxy string, targets []string, opts ...grpc.DialOption) (*Conn, error) {
ret := &Conn{}
parsedProxy := []string{proxy}
proxyTimeout := []*time.Duration{nil}

var err error
if proxy != "" {
parsedProxy, proxyTimeout, err = parseTargets([]string{proxy})
if err != nil {
return nil, err
}
}
hostport, timeouts, err := parseTargets(targets)
if err != nil {
return nil, err
}

// If there are no targets things will likely fail but this gives the ability to still send RPCs to the
// proxy itself.
dialTarget := proxy
ret := &Conn{}
dialTarget := parsedProxy[0]
dialTimeout := proxyTimeout[0]
if proxy == "" {
if len(targets) != 1 {
return nil, status.Error(codes.InvalidArgument, "no proxy specified but more than one target set")
}
dialTarget = targets[0]
dialTarget = hostport[0]
dialTimeout = timeouts[0]
ret.direct = true

}
var cancel context.CancelFunc
if dialTimeout != nil {
ctx, cancel = context.WithTimeout(ctx, *dialTimeout)
opts = append(opts, grpc.WithBlock())
defer cancel()
}
conn, err := grpc.DialContext(ctx, dialTarget, opts...)
if err != nil {
return nil, err
}
ret.cc = conn
// Make our own copy of these.
ret.Targets = append(ret.Targets, targets...)
ret.Targets = append(ret.Targets, hostport...)
ret.timeouts = append(ret.timeouts, timeouts...)
return ret, nil
}
46 changes: 44 additions & 2 deletions proxy/proxy/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ func startTestProxy(ctx context.Context, t *testing.T, targets map[string]*bufco
func TestDial(t *testing.T) {
ctx := context.Background()
testServerMap := testutil.StartTestDataServers(t, "foo:123", "bar:123")
startTestProxy(ctx, t, testServerMap)
bufMap := startTestProxy(ctx, t, testServerMap)

// This should fail since we don't set credentials
_, err := proxy.DialContext(ctx, "b", []string{"foo:123"})
Expand All @@ -82,6 +82,12 @@ func TestDial(t *testing.T) {
targets: []string{"foo:123", "bar:123"},
options: []grpc.DialOption{grpc.WithTransportCredentials(insecure.NewCredentials())},
},
{
name: "proxy with timeout and N hosts with timeouts",
proxy: "proxy;5s",
targets: []string{"foo:123;5s", "bar:123;5s"},
options: []grpc.DialOption{grpc.WithTransportCredentials(insecure.NewCredentials())},
},
{
name: "no proxy and a host",
targets: []string{"foo:123"},
Expand Down Expand Up @@ -115,10 +121,46 @@ func TestDial(t *testing.T) {
targets: []string{"foo:123", "bar:123"},
wantErr: true,
},
{
name: "proxy with a dial duration",
proxy: "proxy;5s",
targets: []string{"foo:123", "bar:123"},
options: []grpc.DialOption{grpc.WithTransportCredentials(insecure.NewCredentials())},
},
{
name: "proxy with bad duration",
proxy: "proxy;5p",
targets: []string{"foo:123", "bar:123"},
options: []grpc.DialOption{grpc.WithTransportCredentials(insecure.NewCredentials())},
wantErr: true,
},
{
name: "target with bad form",
proxy: "proxy",
targets: []string{"foo:123;5s;5s", "bar:123"},
options: []grpc.DialOption{grpc.WithTransportCredentials(insecure.NewCredentials())},
wantErr: true,
},
{
name: "blank target",
proxy: "proxy",
targets: []string{"", "bar:123"},
options: []grpc.DialOption{grpc.WithTransportCredentials(insecure.NewCredentials())},
wantErr: true,
},
{
name: "blank target2",
proxy: "proxy",
targets: []string{";5s", "bar:123"},
options: []grpc.DialOption{grpc.WithTransportCredentials(insecure.NewCredentials())},
wantErr: true,
},
} {
tc := tc
t.Run(tc.name, func(t *testing.T) {
_, err := proxy.Dial(tc.proxy, tc.targets, tc.options...)
opts := []grpc.DialOption{testutil.WithBufDialer(bufMap)}
opts = append(opts, tc.options...)
_, err := proxy.Dial(tc.proxy, tc.targets, opts...)
tu.WantErr(tc.name, err, tc.wantErr, t)
})
}
Expand Down
15 changes: 10 additions & 5 deletions testing/integrate.sh
Original file line number Diff line number Diff line change
Expand Up @@ -954,10 +954,15 @@ ${SANSSH_NOPROXY} ${SINGLE_TARGET} file mv ${LOGS}/testdir/file ${LOGS}/testdir/
check_status $? /dev/null mv failed
check_mv

echo "parallel work with some bad targets"
echo "parallel work with some bad targets and various timeouts"
mkdir -p "${LOGS}/parallel"
if ${SANSSH_PROXY} --timeout=5s --output-dir="${LOGS}/parallel" --targets=localhost,1.1.1.1,0.0.0.1,localhost healthcheck validate; then
check_status 1 /dev/null healtcheck did not error out
start=$(date +%s)
if ${SANSSH_NOPROXY} --proxy="localhost;2s" --timeout=10s --output-dir="${LOGS}/parallel" --targets="localhost:50042;3s,1.1.1.1;4s,0.0.0.1;5s,localhost;6s" healthcheck validate; then
check_status 1 /dev/null healthcheck did not error out
fi
end=$(date +%s)
if [ "$(expr \( "${end}" - "${start}" \) \> 9)" == "1" ]; then
check_status 1 /dev/null took to main deadline. should be no more than 6s
fi

echo "Logs from parallel work - debugging"
Expand All @@ -972,10 +977,10 @@ done
errors=$(cat "${LOGS}"/parallel/*.error | wc -l)
healthy=$(cat "${LOGS}"/parallel/? | grep -c -h -E "Target.*healthy")
if [ "${errors}" != 2 ]; then
check_status 1 /dev/null 2 targets should be unhealthy for various reasons
check_status 1 /dev/null 2 "targets should be unhealthy for various reasons"
fi
if [ "${healthy}" != 2 ]; then
check_status 1 /dev/null 2 targets should be healthy
check_status 1 /dev/null 2 "targets should be healthy"
fi

# TODO(jchacon): Provide a java binary for test{s
Expand Down

0 comments on commit 5d7694c

Please sign in to comment.