diff --git a/proxy/proxy/proxy.go b/proxy/proxy/proxy.go index 865db53b..0bf6ecd0 100644 --- a/proxy/proxy/proxy.go +++ b/proxy/proxy/proxy.go @@ -21,9 +21,11 @@ package proxy import ( "context" + "fmt" "io" "log" "strings" + "sync" "time" "google.golang.org/grpc" @@ -35,6 +37,7 @@ import ( "google.golang.org/protobuf/types/known/durationpb" proxypb "github.com/Snowflake-Labs/sansshell/proxy" + "github.com/go-logr/logr" ) // Conn is a grpc.ClientConnInterface which is connected to the proxy @@ -347,100 +350,129 @@ func (p *Conn) createStreams(ctx context.Context, method string) (proxypb.Proxy_ streamIds := make(map[uint64]*Ret) - // For every target we have to send a separate StartStream (with a nonce which in our case is the target index so clients can map too). - // We then validate the nonce matches and record the stream ID so later processing can match responses to the right targets. - // This needs to be 2 loops as we want the server to process N StartStreams in parallel and then we'll loop getting responses. - for i, t := range p.Targets { - req := &proxypb.ProxyRequest{ - Request: &proxypb.ProxyRequest_StartStream{ - StartStream: &proxypb.StartStream{ - Target: t, - MethodName: method, - Nonce: uint32(i), - }, - }, - } - if p.timeouts[i] != nil { - req.GetStartStream().DialTimeout = durationpb.New(*p.timeouts[i]) - } - err = stream.Send(req) - - // If Send reports an error and is EOF we have to use Recv to get the actual error according to documentation - // for SendMsg. However it appears SendMsg will return actual errors "sometimes" when it's the first stream - // a server has ever handled so account for that here. - if err != nil && err != io.EOF { - return nil, nil, errors, status.Errorf(codes.Internal, "can't send request for %s on stream - %v", method, err) - } - if err != nil { - _, err := stream.Recv() - return nil, nil, errors, status.Errorf(codes.Internal, "remote error from Send for %s - %v", method, err) - } - } + wg := &sync.WaitGroup{} - // We sent len(p.Targets) requests so loop for that many replies. If the server doesn't we'll have to wait until - // our context times out then. If the server attempts something invalid we'll catch and just abort (i.e, duplicate - // responses and/or out of range). We may encounter closes() in here mixed in with replies but we'll never get - // more than that (data can't start until we return). - replies := 0 - for replies != len(p.Targets) { - resp, err := stream.Recv() - if err != nil { - return nil, nil, errors, status.Errorf(codes.Internal, "can't get response for %s on stream - %v", method, err) - } - - // Validate we got an answer and it has expected reflected values. - - // These are all sanity checks for the entire session so an overall error is appropriate since we're likely - // dealing with a broken proxy of some sort. - switch t := resp.Reply.(type) { - case *proxypb.ProxyReply_StartStreamReply: - replies++ - r := t.StartStreamReply - // We want the returned Target+nonce to match what we sent and that it's one we know about. - if r.Nonce >= uint32(len(p.Targets)) { - return nil, nil, errors, status.Errorf(codes.Internal, "got back invalid nonce (out of range): %+v", r) + var sendErr, recvErr error + wg.Add(1) + go func() { + defer wg.Done() + // For every target we have to send a separate StartStream (with a nonce which in our case is the target index so clients can map too). + // We then validate the nonce matches and record the stream ID so later processing can match responses to the right targets. + // This needs to be 2 loops as we want the server to process N StartStreams in parallel and then we'll loop getting responses. + for i, t := range p.Targets { + req := &proxypb.ProxyRequest{ + Request: &proxypb.ProxyRequest_StartStream{ + StartStream: &proxypb.StartStream{ + Target: t, + MethodName: method, + Nonce: uint32(i), + }, + }, } - if p.Targets[r.Nonce] != r.Target { - return nil, nil, errors, status.Errorf(codes.Internal, "Target/nonce don't match. target %s(%d) is not %s: %+v", p.Targets[r.Nonce], r.Nonce, r.Target, r) + if p.timeouts[i] != nil { + req.GetStartStream().DialTimeout = durationpb.New(*p.timeouts[i]) } + err = stream.Send(req) - id := r.GetStreamId() - if streamIds[id] != nil { - return nil, nil, errors, status.Errorf(codes.Internal, "Duplicate response for target %s. Already have %+v for response %+v", r.Target, streamIds[id], r) + // If Send reports an error and is EOF we have to use Recv to get the actual error according to documentation + // for SendMsg. However it appears SendMsg will return actual errors "sometimes" when it's the first stream + // a server has ever handled so account for that here. The actual Recv for the error will get caught in the other + // routine below. + if err != nil { + if err != io.EOF { + sendErr = status.Errorf(codes.Internal, "can't send request for %s on stream - %v", method, err) + } + return } + } + }() - ret := &Ret{ - Target: r.GetTarget(), - Index: int(r.GetNonce()), - } - // If the target reported an error stick it in errors. - if s := r.GetErrorStatus(); s != nil { - ret.Error = status.Errorf(codes.Internal, "got reply error from stream. Code: %s Message: %s", codes.Code(s.Code), s.Message) - errors = append(errors, ret) - continue + wg.Add(1) + go func() { + defer wg.Done() + + // We sent len(p.Targets) requests so loop for that many start stream replies. If the server doesn't we'll have to wait until + // our context times out then. If the server attempts something invalid we'll catch and just abort (i.e, duplicate + // responses and/or out of range). We may encounter closes() in here mixed in with replies but we'll never get + // more than that (data can't start until we return). For closes() we discover we note them in errors so later code + // just skips them as they've already errored out. + replies := 0 + for replies != len(p.Targets) { + resp, err := stream.Recv() + if err != nil { + recvErr = status.Errorf(codes.Internal, "can't get response for %s on stream - %v", method, err) + return } - // Save stream ID/nonce for later matching. - streamIds[r.GetStreamId()] = ret - case *proxypb.ProxyReply_ServerClose: - c := t.ServerClose - // We've never sent any data so a close here has to be an error. - st := c.GetStatus() - if st == nil || st.Code == 0 { - return nil, nil, errors, status.Errorf(codes.Internal, "close with no data sent and no error? %+v", resp) - } - for _, id := range c.StreamIds { - if streamIds[id] == nil { - return nil, nil, errors, status.Errorf(codes.Internal, "close on invalid stream id: %+v", resp) + // Validate we got an answer and it has expected reflected values. + + // These are all sanity checks for the entire session so an overall error is appropriate since we're likely + // dealing with a broken proxy of some sort. + switch t := resp.Reply.(type) { + case *proxypb.ProxyReply_StartStreamReply: + replies++ + r := t.StartStreamReply + // We want the returned Target+nonce to match what we sent and that it's one we know about. + if r.Nonce >= uint32(len(p.Targets)) { + recvErr = status.Errorf(codes.Internal, "got back invalid nonce (out of range): %+v", r) + return + } + if p.Targets[r.Nonce] != r.Target { + recvErr = status.Errorf(codes.Internal, "Target/nonce don't match. target %s(%d) is not %s: %+v", p.Targets[r.Nonce], r.Nonce, r.Target, r) + return + } + + id := r.GetStreamId() + if streamIds[id] != nil { + recvErr = status.Errorf(codes.Internal, "Duplicate response for target %s. Already have %+v for response %+v", r.Target, streamIds[id], r) + return + } + + ret := &Ret{ + Target: r.GetTarget(), + Index: int(r.GetNonce()), } - streamIds[id].Error = status.Errorf(codes.Internal, "got close error from stream. Code: %s Message: %s", codes.Code(st.Code), st.Message) - errors = append(errors, streamIds[id]) - // If it's closed make sure we don't process it later on. - delete(streamIds, id) + // If the target reported an error stick it in errors. + if s := r.GetErrorStatus(); s != nil { + ret.Error = status.Errorf(codes.Internal, "got reply error from stream. Code: %s Message: %s", codes.Code(s.Code), s.Message) + errors = append(errors, ret) + continue + } + + // Save stream ID/nonce for later matching. + streamIds[r.GetStreamId()] = ret + case *proxypb.ProxyReply_ServerClose: + c := t.ServerClose + // We've never sent any data so a close here has to be an error. + st := c.GetStatus() + if st == nil || st.Code == 0 { + recvErr = status.Errorf(codes.Internal, "close with no data sent and no error? %+v", resp) + return + } + for _, id := range c.StreamIds { + if streamIds[id] == nil { + recvErr = status.Errorf(codes.Internal, "close on invalid stream id: %+v", resp) + return + } + streamIds[id].Error = status.Errorf(codes.Internal, "got close error from stream. Code: %s Message: %s", codes.Code(st.Code), st.Message) + errors = append(errors, streamIds[id]) + // If it's closed make sure we don't process it later on. + delete(streamIds, id) + } + default: + recvErr = status.Errorf(codes.Internal, "unexpected reply for %s on stream - %+v", method, resp) + return } - default: - return nil, nil, errors, status.Errorf(codes.Internal, "unexpected reply for %s on stream - %+v", method, resp) } + }() + + wg.Wait() + + log := logr.FromContextOrDiscard(ctx) + if sendErr != nil || recvErr != nil { + err := fmt.Errorf("Setting up streams errors: %v - %v", sendErr, recvErr) + log.Error(err, "Setup error") + return nil, nil, errors, err } return stream, streamIds, errors, nil }