diff --git a/proxy/server/target.go b/proxy/server/target.go index 9a2d7770..bb986461 100644 --- a/proxy/server/target.go +++ b/proxy/server/target.go @@ -198,13 +198,14 @@ func (s *TargetStream) Run(nonce uint32, replyChan chan *pb.ProxyReply) { } var err error defer cancel() - s.grpcConn, err = s.dialer.DialContext(dialCtx, s.target, opts...) + grpcConn, err := s.dialer.DialContext(dialCtx, s.target, opts...) if err != nil { // We cannot create a new stream to the target. So we need to cancel this stream. s.logger.Info("unable to create stream", "status", err) s.cancelFunc() return err } + s.grpcConn = grpcConn grpcStream, err := s.grpcConn.NewStream(s.ctx, s.serviceMethod.StreamDesc(), s.serviceMethod.FullName()) if err != nil { // We cannot create a new stream to the target. So we need to cancel this stream. @@ -325,8 +326,10 @@ func (s *TargetStream) Run(nonce uint32, replyChan chan *pb.ProxyReply) { // Once all calls are complete, we need to close our network connection // to the server. - if closeErr := s.grpcConn.Close(); err == nil && closeErr != nil { - err = closeErr + if s.grpcConn != nil { + if closeErr := s.grpcConn.Close(); err == nil && closeErr != nil { + err = closeErr + } } // The error status may by set/overidden if CloseWith was used to diff --git a/proxy/server/target_test.go b/proxy/server/target_test.go index 20e93a15..a5f0eb1e 100644 --- a/proxy/server/target_test.go +++ b/proxy/server/target_test.go @@ -35,7 +35,7 @@ import ( // A TargetDialer than returns an error for all Dials type dialErrTargetDialer codes.Code -func (e dialErrTargetDialer) DialContext(ctx context.Context, target string, dialOpts ...grpc.DialOption) (grpc.ClientConnInterface, error) { +func (e dialErrTargetDialer) DialContext(ctx context.Context, target string, dialOpts ...grpc.DialOption) (ClientConnCloser, error) { return nil, status.Error(codes.Code(e), "") } @@ -150,10 +150,14 @@ func (b blockingClientConn) NewStream(ctx context.Context, desc *grpc.StreamDesc return nil, ctx.Err() } +func (b blockingClientConn) Close() error { + return nil +} + // a context dialer that returns blockingClientConn type blockingClientDialer struct{} -func (b blockingClientDialer) DialContext(ctx context.Context, target string, dialOpts ...grpc.DialOption) (grpc.ClientConnInterface, error) { +func (b blockingClientDialer) DialContext(ctx context.Context, target string, dialOpts ...grpc.DialOption) (ClientConnCloser, error) { return blockingClientConn{}, nil }