diff --git a/.golangci.yml b/.golangci.yml
index d9bf14b2e13..a543cdb5914 100644
--- a/.golangci.yml
+++ b/.golangci.yml
@@ -127,6 +127,7 @@ linters-settings:
       - "github.com/Sirupsen/logrus": "must use github.com/dapr/kit/logger"
       - "github.com/agrea/ptr": "must use github.com/dapr/kit/ptr"
       - "go.uber.org/atomic": "must use sync/atomic"
+      - "golang.org/x/net/context": "must use context"
       - "github.com/pkg/errors": "must use standard library (errors package and/or fmt.Errorf)"
       - "github.com/cenkalti/backoff": "must use github.com/cenkalti/backoff/v4"
       - "github.com/cenkalti/backoff/v2": "must use github.com/cenkalti/backoff/v4"
diff --git a/pkg/grpc/proxy/codec/codec.go b/pkg/grpc/proxy/codec/codec.go
index e1567d7542c..abdc8395191 100644
--- a/pkg/grpc/proxy/codec/codec.go
+++ b/pkg/grpc/proxy/codec/codec.go
@@ -56,10 +56,12 @@ type Frame struct {
 }
 
 // ProtoMessage tags a frame as valid proto message.
-func (f *Frame) ProtoMessage() {}
+func (f *Frame) ProtoMessage() {
+    // nop
+}
 
 // Marshal implements the encoding.Codec interface method.
-func (p *Proxy) Marshal(v interface{}) ([]byte, error) {
+func (p *Proxy) Marshal(v any) ([]byte, error) {
     out, ok := v.(*Frame)
     if !ok {
         return p.parentCodec.Marshal(v)
@@ -69,7 +71,7 @@ func (p *Proxy) Marshal(v interface{}) ([]byte, error) {
 }
 
 // Unmarshal implements the encoding.Codec interface method.
-func (p *Proxy) Unmarshal(data []byte, v interface{}) error {
+func (p *Proxy) Unmarshal(data []byte, v any) error {
     dst, ok := v.(*Frame)
     if !ok {
         return p.parentCodec.Unmarshal(data, v)
@@ -86,25 +88,25 @@ func (*Proxy) Name() string {
 
 // protoCodec is a Codec implementation with protobuf. It is the default rawCodec for gRPC.
 type protoCodec struct{}
 
-func (*protoCodec) Marshal(v interface{}) ([]byte, error) {
-    switch t := v.(type) {
+func (*protoCodec) Marshal(v any) ([]byte, error) {
+    switch x := v.(type) {
     case proto.Message:
-        return proto.Marshal(v.(proto.Message))
+        return proto.Marshal(x)
     case protoV1.Message:
-        return protoV1.Marshal(v.(protoV1.Message))
+        return protoV1.Marshal(x)
     default:
-        return nil, fmt.Errorf("failed to marshal, message is %T, want proto.Message", t)
+        return nil, fmt.Errorf("failed to marshal: message is %T, want proto.Message", x)
     }
 }
 
-func (*protoCodec) Unmarshal(data []byte, v interface{}) error {
-    switch t := v.(type) {
+func (*protoCodec) Unmarshal(data []byte, v any) error {
+    switch x := v.(type) {
     case proto.Message:
-        return proto.Unmarshal(data, v.(proto.Message))
+        return proto.Unmarshal(data, x)
     case protoV1.Message:
-        return protoV1.Unmarshal(data, v.(protoV1.Message))
+        return protoV1.Unmarshal(data, x)
     default:
-        return fmt.Errorf("failed to unmarshal, message is %T, want proto.Message", t)
+        return fmt.Errorf("failed to unmarshal: message is %T, want proto.Message", x)
     }
 }
diff --git a/pkg/grpc/proxy/director.go b/pkg/grpc/proxy/director.go
index 2195c317bbe..bd35ac87b5f 100644
--- a/pkg/grpc/proxy/director.go
+++ b/pkg/grpc/proxy/director.go
@@ -4,7 +4,8 @@
 package proxy
 
 import (
-    "golang.org/x/net/context"
+    "context"
+
     "google.golang.org/grpc"
 )
 
diff --git a/pkg/grpc/proxy/handler.go b/pkg/grpc/proxy/handler.go
index e8c49c3159d..50e19ec10f4 100644
--- a/pkg/grpc/proxy/handler.go
+++ b/pkg/grpc/proxy/handler.go
@@ -4,18 +4,17 @@
 package proxy
 
 import (
-    "fmt"
+    "context"
     "io"
     "sync"
+    "sync/atomic"
 
-    "golang.org/x/net/context"
+    "github.com/google/uuid"
     "google.golang.org/grpc"
     "google.golang.org/grpc/codes"
     "google.golang.org/grpc/metadata"
     "google.golang.org/grpc/status"
 
-    "github.com/google/uuid"
-
     "github.com/dapr/dapr/pkg/diagnostics"
     "github.com/dapr/dapr/pkg/grpc/proxy/codec"
     "github.com/dapr/dapr/pkg/resiliency"
@@ -37,7 +36,7 @@ func RegisterService(server *grpc.Server, director StreamDirector, resiliency re
     }
     fakeDesc := &grpc.ServiceDesc{
         ServiceName: serviceName,
-        HandlerType: (*interface{})(nil),
+        HandlerType: (*any)(nil),
     }
     for _, m := range methodNames {
         streamDesc := grpc.StreamDesc{
@@ -58,11 +57,10 @@ func RegisterService(server *grpc.Server, director StreamDirector, resiliency re
 // This can *only* be used if the `server` also uses grpcproxy.CodecForServer() ServerOption.
 func TransparentHandler(director StreamDirector, resiliency resiliency.Provider, isLocalFn func(string) (bool, error), connFactory DirectorConnectionFactory) grpc.StreamHandler {
     streamer := &handler{
-        director:      director,
-        resiliency:    resiliency,
-        isLocalFn:     isLocalFn,
-        bufferedCalls: sync.Map{},
-        connFactory:   connFactory,
+        director:    director,
+        resiliency:  resiliency,
+        isLocalFn:   isLocalFn,
+        connFactory: connFactory,
     }
     return streamer.handler
 }
@@ -72,27 +70,25 @@ type handler struct {
     resiliency    resiliency.Provider
     isLocalFn     func(string) (bool, error)
     bufferedCalls sync.Map
-    headersSent   sync.Map
     connFactory   DirectorConnectionFactory
 }
 
 // handler is where the real magic of proxying happens.
 // It is invoked like any gRPC server stream and uses the gRPC server framing to get and receive bytes from the wire,
 // forwarding it to a ClientStream established against the relevant ClientConn.
-func (s *handler) handler(srv interface{}, serverStream grpc.ServerStream) error {
+func (s *handler) handler(srv any, serverStream grpc.ServerStream) error {
     // Create buffered calls for this request.
     requestIDObj, err := uuid.NewRandom()
     if err != nil {
         return status.Errorf(codes.Internal, "failed to generate UUID: %v", err)
     }
     requestID := requestIDObj.String()
-    s.bufferedCalls.Store(requestID, []interface{}{})
-    s.headersSent.Store(requestID, false)
+    s.bufferedCalls.Store(requestID, []any{})
 
     // little bit of gRPC internals never hurt anyone
     fullMethodName, ok := grpc.MethodFromServerStream(serverStream)
     if !ok {
-        return status.Errorf(codes.Internal, "lowLevelServerStream not exists in context")
+        return status.Errorf(codes.Internal, "full method name not found in stream")
     }
 
     // Fetch the AppId so we can reference it for resiliency.
@@ -107,8 +103,8 @@ func (s *handler) handler(srv interface{}, serverStream grpc.ServerStream) error
         policyDef = noOp.EndpointPolicy("", "")
     } else {
         isLocal, err := s.isLocalFn(v[0])
-        if err == nil && isLocal {
-            policyDef = s.resiliency.EndpointPolicy(v[0], fmt.Sprintf("%s:%s", v[0], fullMethodName))
+        if err == nil && !isLocal {
+            policyDef = s.resiliency.EndpointPolicy(v[0], v[0]+":"+fullMethodName)
         } else {
             noOp := resiliency.NoOp{}
             policyDef = noOp.EndpointPolicy("", "")
@@ -119,7 +115,12 @@ func (s *handler) handler(srv interface{}, serverStream grpc.ServerStream) error
     clientStreamOpts := []grpc.CallOption{
         grpc.CallContentSubtype((&codec.Proxy{}).Name()),
     }
+    headersSent := &atomic.Bool{}
+    counter := atomic.Int32{}
     _, cErr := policyRunner(func(ctx context.Context) (struct{}, error) {
+        // Get the current iteration count
+        iter := counter.Add(1)
+
         // We require that the director's returned context inherits from the ctx.
         outgoingCtx, backendConn, target, teardown, err := s.director(ctx, fullMethodName)
         defer teardown(false)
@@ -128,12 +129,26 @@ func (s *handler) handler(srv interface{}, serverStream grpc.ServerStream) error
 
         if err != nil {
             return struct{}{}, err
         }
 
         clientCtx, clientCancel := context.WithCancel(outgoingCtx)
+        defer clientCancel()
 
         // TODO(mwitkow): Add a `forwarded` header to metadata, https://en.wikipedia.org/wiki/X-Forwarded-For.
-        clientStream, err := grpc.NewClientStream(clientCtx, clientStreamDescForProxying, backendConn, fullMethodName, clientStreamOpts...)
+        clientStream, err := grpc.NewClientStream(
+            clientCtx,
+            clientStreamDescForProxying,
+            backendConn,
+            fullMethodName,
+            clientStreamOpts...,
+        )
         if err != nil {
             code := status.Code(err)
             if target != nil && (code == codes.Unavailable || code == codes.Unauthenticated) {
+                // It's possible that we get to this point while another goroutine is executing the same policy function.
+                // For example, this could happen if this iteration has timed out and "policyRunner" has triggered a new execution already.
+                // In this case, we should not tear down the connection because it could still be in use by the next execution. So just return and move on.
+                if counter.Load() != iter {
+                    return struct{}{}, err
+                }
+
+                // Destroy the connection so it can be recreated
                 teardown(true)
@@ -148,6 +163,8 @@ func (s *handler) handler(srv interface{}, serverStream grpc.ServerStream) error
                 if err != nil {
                     return struct{}{}, err
                 }
+            } else {
+                return struct{}{}, err
             }
         }
 
@@ -155,13 +172,13 @@ func (s *handler) handler(srv interface{}, serverStream grpc.ServerStream) error
         // Channels do not have to be closed, it is just a control flow mechanism, see
         // https://groups.google.com/forum/#!msg/golang-nuts/pZwdYRGxCIk/qpbHxRRPJdUJ
         s2cErrChan := s.forwardServerToClient(serverStream, clientStream, requestID)
-        c2sErrChan := s.forwardClientToServer(clientStream, serverStream, requestID)
+        c2sErrChan := s.forwardClientToServer(clientStream, serverStream, headersSent)
         // We don't know which side is going to stop sending first, so we need a select between the two.
         for i := 0; i < 2; i++ {
             select {
             case s2cErr := <-s2cErrChan:
                 if s2cErr == io.EOF {
-                    // this is the happy case where the sender has encountered io.EOF, and won't be sending anymore./
+                    // this is the happy case where the sender has encountered io.EOF, and won't be sending anymore.
                     // the clientStream>serverStream may continue pumping though.
                     clientStream.CloseSend()
                     continue
@@ -169,8 +186,7 @@ func (s *handler) handler(srv interface{}, serverStream grpc.ServerStream) error
                     // however, we may have gotten a receive error (stream disconnected, a read error etc) in which case we need
                     // to cancel the clientStream to the backend, let all of its goroutines be freed up by the CancelFunc and
                     // exit with an error to the stack
-                    clientCancel()
-                    return struct{}{}, status.Errorf(codes.Internal, "failed proxying s2c: %v", s2cErr)
+                    return struct{}{}, status.Error(codes.Internal, "failed proxying s2c: "+s2cErr.Error())
                 }
             case c2sErr := <-c2sErrChan:
                 // This happens when the clientStream has nothing else to offer (io.EOF), returned a gRPC error. In those two
                 // cases we may have received Trailers as part of the call. In case of other errors (stream closed) the trailers
                 // will be nil.
                 serverStream.SetTrailer(clientStream.Trailer())
                 // c2sErr will contain RPC error or io.EOF
                 if c2sErr != io.EOF {
                     return struct{}{}, c2sErr
                 }
@@ -184,42 +200,44 @@ func (s *handler) handler(srv interface{}, serverStream grpc.ServerStream) error
                 return struct{}{}, nil
             }
         }
-        return struct{}{}, status.Errorf(codes.Internal, "gRPC proxying should never reach this stage.")
+
+        return struct{}{}, status.Error(codes.Internal, "gRPC proxying should never reach this stage")
     })
 
     // Clear the request's buffered calls.
     s.bufferedCalls.Delete(requestID)
-    s.headersSent.Delete(requestID)
 
     return cErr
 }
 
-func (s *handler) forwardClientToServer(src grpc.ClientStream, dst grpc.ServerStream, requestID string) chan error {
+func (s *handler) forwardClientToServer(src grpc.ClientStream, dst grpc.ServerStream, headersSent *atomic.Bool) chan error {
     ret := make(chan error, 1)
     go func() {
+        var err error
         f := &codec.Frame{}
-        syncMapValue, _ := s.headersSent.Load(requestID)
-        localHeaders := syncMapValue.(bool)
-        for i := 0; ; i++ {
-            if err := src.RecvMsg(f); err != nil {
+
+        for src.Context().Err() == nil && dst.Context().Err() == nil {
+            err = src.RecvMsg(f)
+            if err != nil {
                 ret <- err // this can be io.EOF which is happy case
                 break
             }
             // In the case of retries, don't resend the headers.
-            if i == 0 && !localHeaders {
+            if headersSent.CompareAndSwap(false, true) {
                 // This is a bit of a hack, but client to server headers are only readable after first client msg is
                 // received but must be written to server stream before the first msg is flushed.
                 // This is the only place to do it nicely.
-                md, err := src.Header()
+                var md metadata.MD
+                md, err = src.Header()
                 if err != nil {
                     break
                 }
-                if err := dst.SendHeader(md); err != nil {
+                err = dst.SendHeader(md)
+                if err != nil {
                     break
                 }
-                localHeaders = true
-                s.headersSent.Store(requestID, true)
             }
-            if err := dst.SendMsg(f); err != nil {
+            err = dst.SendMsg(f)
+            if err != nil {
                 break
             }
         }
@@ -230,23 +248,31 @@ func (s *handler) forwardClientToServer(src grpc.ClientStream, dst grpc.ServerSt
 func (s *handler) forwardServerToClient(src grpc.ServerStream, dst grpc.ClientStream, requestID string) chan error {
     ret := make(chan error, 1)
     go func() {
-        f := &codec.Frame{}
+        var err error
+
+        // Start by sending buffered messages
         syncMapValue, _ := s.bufferedCalls.Load(requestID)
-        bufferedFrames := syncMapValue.([]interface{})
+        bufferedFrames := syncMapValue.([]any)
         for _, msg := range bufferedFrames {
-            if err := dst.SendMsg(msg); err != nil {
+            err = dst.SendMsg(msg)
+            if err != nil {
                 ret <- err
                 return
             }
         }
-        for i := 0; ; i++ {
-            if err := src.RecvMsg(f); err != nil {
+
+        // Receive messages from the source stream and forward them to the destination stream
+        for src.Context().Err() == nil && dst.Context().Err() == nil {
+            f := &codec.Frame{}
+            err = src.RecvMsg(f)
+            if err != nil {
                 s.bufferedCalls.Store(requestID, bufferedFrames)
                 ret <- err // this can be io.EOF which is happy case
                 break
             }
             bufferedFrames = append(bufferedFrames, f)
-            if err := dst.SendMsg(f); err != nil {
+            err = dst.SendMsg(f)
+            if err != nil {
                 s.bufferedCalls.Store(requestID, bufferedFrames)
                 break
             }
diff --git a/pkg/grpc/proxy/handler_test.go b/pkg/grpc/proxy/handler_test.go
index 57231f285e7..55b2e5113dc 100644
--- a/pkg/grpc/proxy/handler_test.go
+++ b/pkg/grpc/proxy/handler_test.go
@@ -17,18 +17,20 @@ limitations under the License.
 package proxy
 
 import (
+    "context"
     "fmt"
     "io"
     "net"
+    "strconv"
     "strings"
     "sync"
+    "sync/atomic"
     "testing"
     "time"
 
     "github.com/stretchr/testify/assert"
     "github.com/stretchr/testify/require"
     "github.com/stretchr/testify/suite"
-    "golang.org/x/net/context"
     "google.golang.org/grpc"
     "google.golang.org/grpc/codes"
     "google.golang.org/grpc/credentials/insecure"
@@ -56,7 +58,8 @@ const (
 // asserting service is implemented on the server side and serves as a handler for stuff.
 type assertingService struct {
     pb.UnimplementedTestServiceServer
-    t *testing.T
+    t                     *testing.T
+    expectPingStreamError *atomic.Bool
 }
 
 func (s *assertingService) PingEmpty(ctx context.Context, _ *pb.Empty) (*pb.PingResponse, error) {
@@ -97,7 +100,11 @@ func (s *assertingService) PingStream(stream pb.TestService_PingStreamServer) er
         if err == io.EOF {
             break
         } else if err != nil {
-            require.NoError(s.t, err, "can't fail reading stream")
+            if s.expectPingStreamError.Load() {
+                require.Error(s.t, err, "should have failed reading stream")
+            } else {
+                require.NoError(s.t, err, "can't fail reading stream")
+            }
             return err
         }
         pong := &pb.PingResponse{Value: ping.Value, Counter: counter}
@@ -110,8 +117,7 @@ func (s *assertingService) PingStream(stream pb.TestService_PingStreamServer) er
     return nil
 }
 
-// ProxyHappySuite tests the "happy" path of handling: that everything works in absence of connection issues.
-type ProxyHappySuite struct {
+type proxyTestSuite struct {
     suite.Suite
 
     serverListener net.Listener
@@ -119,38 +125,42 @@ type ProxyHappySuite struct {
     proxyListener    net.Listener
     proxy            *grpc.Server
     serverClientConn *grpc.ClientConn
+    service          *assertingService
+    lock             sync.Mutex
     client           *grpc.ClientConn
     testClient       pb.TestServiceClient
 }
 
-func (s *ProxyHappySuite) ctx() context.Context {
-    // Make all RPC calls last at most 1 sec, meaning all async issues or deadlock will not kill tests.
-    ctx, _ := context.WithTimeout(context.Background(), 120*time.Second)
-    return ctx
+func (s *proxyTestSuite) ctx() (context.Context, context.CancelFunc) {
+    // Make all RPC calls last at most 5 sec, meaning all async issues or deadlock will not kill tests.
+    return context.WithTimeout(context.Background(), 5*time.Second)
 }
 
-func (s *ProxyHappySuite) TestPingEmptyCarriesClientMetadata() {
-    // s.T().Skip()
-    ctx := metadata.NewOutgoingContext(s.ctx(), metadata.Pairs(clientMdKey, "true"))
+func (s *proxyTestSuite) TestPingEmptyCarriesClientMetadata() {
+    ctx, cancel := s.ctx()
+    defer cancel()
+    ctx = metadata.NewOutgoingContext(ctx, metadata.Pairs(clientMdKey, "true"))
     out, err := s.testClient.PingEmpty(ctx, &pb.Empty{})
     require.NoError(s.T(), err, "PingEmpty should succeed without errors")
     require.Equal(s.T(), pingDefaultValue, out.Value)
     require.Equal(s.T(), int32(42), out.Counter)
 }
 
-func (s *ProxyHappySuite) TestPingEmpty_StressTest() {
+func (s *proxyTestSuite) TestPingEmpty_StressTest() {
     for i := 0; i < 50; i++ {
         s.TestPingEmptyCarriesClientMetadata()
     }
 }
 
-func (s *ProxyHappySuite) TestPingCarriesServerHeadersAndTrailers() {
+func (s *proxyTestSuite) TestPingCarriesServerHeadersAndTrailers() {
     // s.T().Skip()
     headerMd := make(metadata.MD)
     trailerMd := make(metadata.MD)
+    ctx, cancel := s.ctx()
+    defer cancel()
     // This is an awkward calling convention... but meh.
-    out, err := s.testClient.Ping(s.ctx(), &pb.PingRequest{Value: "foo"}, grpc.Header(&headerMd), grpc.Trailer(&trailerMd))
+    out, err := s.testClient.Ping(ctx, &pb.PingRequest{Value: "foo"}, grpc.Header(&headerMd), grpc.Trailer(&trailerMd))
     require.NoError(s.T(), err, "Ping should succeed without errors")
     require.Equal(s.T(), "foo", out.Value)
     require.Equal(s.T(), int32(42), out.Counter)
@@ -158,8 +168,10 @@ func (s *ProxyHappySuite) TestPingCarriesServerHeadersAndTrailers() {
     assert.Len(s.T(), trailerMd, 1, "server response trailers must contain server data")
 }
 
-func (s *ProxyHappySuite) TestPingErrorPropagatesAppError() {
-    _, err := s.testClient.PingError(s.ctx(), &pb.PingRequest{Value: "foo"})
+func (s *proxyTestSuite) TestPingErrorPropagatesAppError() {
+    ctx, cancel := s.ctx()
+    defer cancel()
+    _, err := s.testClient.PingError(ctx, &pb.PingRequest{Value: "foo"})
     require.Error(s.T(), err, "PingError should never succeed")
     st, ok := status.FromError(err)
     require.True(s.T(), ok, "must get status from error")
@@ -167,9 +179,11 @@ func (s *ProxyHappySuite) TestPingErrorPropagatesAppError() {
     assert.Equal(s.T(), "Userspace error.", st.Message())
 }
 
-func (s *ProxyHappySuite) TestDirectorErrorIsPropagated() {
+func (s *proxyTestSuite) TestDirectorErrorIsPropagated() {
+    ctx, cancel := s.ctx()
+    defer cancel()
     // See SetupSuite where the StreamDirector has a special case.
-    ctx := metadata.NewOutgoingContext(s.ctx(), metadata.Pairs(rejectingMdKey, "true"))
+    ctx = metadata.NewOutgoingContext(ctx, metadata.Pairs(rejectingMdKey, "true"))
     _, err := s.testClient.Ping(ctx, &pb.PingRequest{Value: "foo"})
     require.Error(s.T(), err, "Director should reject this RPC")
     st, ok := status.FromError(err)
@@ -178,25 +192,16 @@ func (s *ProxyHappySuite) TestDirectorErrorIsPropagated() {
     assert.Equal(s.T(), "testing rejection", st.Message())
 }
 
-func (s *ProxyHappySuite) TestPingStream_FullDuplexWorks() {
-    stream, err := s.testClient.PingStream(s.ctx())
-    require.NoError(s.T(), err, "PingStream request should be successful.")
+func (s *proxyTestSuite) TestPingStream_FullDuplexWorks() {
+    ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
+    defer cancel()
+    stream, err := s.testClient.PingStream(ctx)
+    require.NoError(s.T(), err, "PingStream request should be successful")
     for i := 0; i < countListResponses; i++ {
-        ping := &pb.PingRequest{Value: fmt.Sprintf("foo:%d", i)}
-        require.NoError(s.T(), stream.Send(ping), "sending to PingStream must not fail")
-        resp, sErr := stream.Recv()
-        if sErr == io.EOF {
+        if s.sendPing(stream, i) {
             break
         }
-        if i == 0 {
-            // Check that the header arrives before all entries.
-            headerMd, hErr := stream.Header()
-            require.NoError(s.T(), hErr, "PingStream headers should not error.")
-            assert.Contains(s.T(), headerMd, serverHeaderMdKey, "PingStream response headers user contain metadata")
-        }
-        require.NotNil(s.T(), resp, "resp must not be nil")
-        assert.EqualValues(s.T(), i, resp.Counter, "ping roundtrip must succeed with the correct id")
     }
     require.NoError(s.T(), stream.CloseSend(), "no error on close send")
     _, err = stream.Recv()
@@ -206,13 +211,13 @@ func (s *ProxyHappySuite) TestPingStream_FullDuplexWorks() {
     assert.Len(s.T(), trailerMd, 1, "PingList trailer headers user contain metadata")
 }
 
-func (s *ProxyHappySuite) TestPingStream_StressTest() {
+func (s *proxyTestSuite) TestPingStream_StressTest() {
     for i := 0; i < 50; i++ {
         s.TestPingStream_FullDuplexWorks()
     }
 }
 
-func (s *ProxyHappySuite) TestPingStream_MultipleThreads() {
+func (s *proxyTestSuite) TestPingStream_MultipleThreads() {
     wg := sync.WaitGroup{}
     for i := 0; i < 4; i++ {
         wg.Add(1)
@@ -237,7 +242,138 @@ func (s *ProxyHappySuite) TestPingStream_MultipleThreads() {
     }
 }
 
-func (s *ProxyHappySuite) SetupSuite() {
+func (s *proxyTestSuite) TestRecoveryFromNetworkFailure() {
+    // Make sure everything works before we break things
+    s.TestPingEmptyCarriesClientMetadata()
+
+    s.T().Run("Fails when no server is running", func(t *testing.T) {
+        // Stop the server
+        s.stopServer(s.T())
+
+        ctx, cancel := s.ctx()
+        defer cancel()
+        ctx = metadata.NewOutgoingContext(ctx, metadata.Pairs(clientMdKey, "true"))
+        _, err := s.testClient.PingEmpty(ctx, &pb.Empty{})
+        require.Error(t, err, "must fail to ping when server is down")
+    })
+
+    s.T().Run("Reconnects to new server", func(t *testing.T) {
+        // Restart the server
+        s.restartServer(s.T())
+
+        s.TestPingEmptyCarriesClientMetadata()
+    })
+}
+
+func (s *proxyTestSuite) sendPing(stream pb.TestService_PingStreamClient, i int) (eof bool) {
+    ping := &pb.PingRequest{Value: fmt.Sprintf("foo:%d", i)}
+    err := stream.Send(ping)
+    require.NoError(s.T(), err, "sending to PingStream must not fail")
+    resp, err := stream.Recv()
+    if err == io.EOF {
+        return true
+    }
+    if i == 0 {
+        // Check that the header arrives before all entries.
+        headerMd, hErr := stream.Header()
+        require.NoError(s.T(), hErr, "PingStream headers should not error.")
+        assert.Contains(s.T(), headerMd, serverHeaderMdKey, "PingStream response headers must contain metadata")
+    }
+    require.NotNil(s.T(), resp, "resp must not be nil")
+    assert.EqualValues(s.T(), i, resp.Counter, "ping roundtrip must succeed with the correct id")
+    return false
+}
+
+func (s *proxyTestSuite) TestStreamConnectionInterrupted() {
+    ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
+    defer cancel()
+    stream, err := s.testClient.PingStream(ctx)
+    require.NoError(s.T(), err, "PingStream request should be successful")
+
+    // Send one message then interrupt the connection
+    eof := s.sendPing(stream, 0)
+    require.False(s.T(), eof)
+
+    s.service.expectPingStreamError.Store(true)
+    defer func() {
+        s.service.expectPingStreamError.Store(false)
+    }()
+    s.stopServer(s.T())
+
+    // Send another message, which should fail without resiliency
+    ping := &pb.PingRequest{Value: fmt.Sprintf("foo:%d", 1)}
+    err = stream.Send(ping)
+    require.Error(s.T(), err, "sending to PingStream must fail with a stopped server")
+
+    // Restart the server
+    s.restartServer(s.T())
+
+    // Pings should still fail with EOF because the stream is closed
+    err = stream.Send(ping)
+    require.Error(s.T(), err, "sending to PingStream must fail on a closed stream")
+    assert.ErrorIs(s.T(), err, io.EOF)
+}
+
+func (s *proxyTestSuite) initServer() {
+    s.server = grpc.NewServer()
+    pb.RegisterTestServiceServer(s.server, s.service)
+}
+
+func (s *proxyTestSuite) stopServer(t *testing.T) {
+    t.Helper()
+    s.server.Stop()
+    time.Sleep(250 * time.Millisecond)
+}
+
+func (s *proxyTestSuite) restartServer(t *testing.T) {
+    t.Helper()
+    var err error
+
+    srvPort := s.serverListener.Addr().(*net.TCPAddr).Port
+    s.serverListener, err = net.Listen("tcp", "127.0.0.1:"+strconv.Itoa(srvPort))
+    require.NoError(s.T(), err, "must not error while starting serverListener")
+
+    s.T().Logf("re-starting grpc.Server at: %v", s.serverListener.Addr().String())
+    s.initServer()
+    go s.server.Serve(s.serverListener)
+
+    time.Sleep(250 * time.Millisecond)
+}
+
+func (s *proxyTestSuite) getServerClientConn() (conn *grpc.ClientConn, teardown func(bool), err error) {
+    s.lock.Lock()
+    defer s.lock.Unlock()
+
+    teardown = func(destroy bool) {
+        s.lock.Lock()
+        defer s.lock.Unlock()
+
+        if destroy {
+            s.serverClientConn.Close()
+            s.serverClientConn = nil
+        }
+    }
+
+    if s.serverClientConn == nil {
+        ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+        defer cancel()
+        conn, err = grpc.DialContext(
+            ctx,
+            s.serverListener.Addr().String(),
+            grpc.WithTransportCredentials(insecure.NewCredentials()),
+            grpc.WithDefaultCallOptions(grpc.CallContentSubtype((&codec.Proxy{}).Name())),
+            grpc.WithBlock(),
+        )
+        if err != nil {
+            return nil, teardown, err
+        }
+        s.serverClientConn = conn
+    }
+
+    return s.serverClientConn, teardown, nil
+}
+
+func (s *proxyTestSuite) SetupSuite() {
     var err error
 
     pc := encoding.GetCodec((&codec.Proxy{}).Name())
@@ -252,30 +388,41 @@ func (s *ProxyHappySuite) SetupSuite() {
     grpclog.SetLoggerV2(testingLog{s.T()})
 
-    s.server = grpc.NewServer()
-    pb.RegisterTestServiceServer(s.server, &assertingService{t: s.T()})
+    s.service = &assertingService{
+        t:                     s.T(),
+        expectPingStreamError: &atomic.Bool{},
+    }
+
+    s.initServer()
 
     // Setup of the proxy's Director.
-    s.serverClientConn, err = grpc.Dial(
-        s.serverListener.Addr().String(),
-        grpc.WithTransportCredentials(insecure.NewCredentials()),
-        grpc.WithDefaultCallOptions(grpc.CallContentSubtype((&codec.Proxy{}).Name())),
-    )
-    require.NoError(s.T(), err, "must not error on deferred client Dial")
     director := func(ctx context.Context, fullName string) (context.Context, *grpc.ClientConn, *ProxyTarget, func(bool), error) {
+        teardown := func(bool) {}
+        target := &ProxyTarget{}
         md, ok := metadata.FromIncomingContext(ctx)
         if ok {
             if _, exists := md[rejectingMdKey]; exists {
-                return ctx, nil, nil, func(bool) {}, status.Errorf(codes.PermissionDenied, "testing rejection")
+                return ctx, nil, target, teardown, status.Errorf(codes.PermissionDenied, "testing rejection")
             }
         }
         // Explicitly copy the metadata, otherwise the tests will fail.
-        outCtx, _ := context.WithCancel(ctx)
-        outCtx = metadata.NewOutgoingContext(outCtx, md.Copy())
-        return outCtx, s.serverClientConn, nil, func(bool) {}, nil
+        outCtx := metadata.NewOutgoingContext(ctx, md.Copy())
+        conn, teardown, sErr := s.getServerClientConn()
+        if sErr != nil {
+            return ctx, nil, target, teardown, status.Errorf(codes.PermissionDenied, "testing rejection")
+        }
+        return outCtx, conn, target, teardown, nil
     }
+    th := TransparentHandler(
+        director,
+        resiliency.New(nil),
+        func(string) (bool, error) { return true, nil },
+        func(ctx context.Context, address, id, namespace string, customOpts ...grpc.DialOption) (*grpc.ClientConn, func(destroy bool), error) {
+            return s.getServerClientConn()
+        },
+    )
     s.proxy = grpc.NewServer(
-        grpc.UnknownServiceHandler(TransparentHandler(director, resiliency.New(nil), func(string) (bool, error) { return true, nil }, nil)),
+        grpc.UnknownServiceHandler(th),
     )
     // Ping handler is handled as an explicit registration and not as a TransparentHandler.
     RegisterService(s.proxy, director, resiliency.New(nil),
@@ -284,15 +431,11 @@ func (s *ProxyHappySuite) SetupSuite() {
     // Start the serving loops.
     s.T().Logf("starting grpc.Server at: %v", s.serverListener.Addr().String())
-    go func() {
-        s.server.Serve(s.serverListener)
-    }()
+    go s.server.Serve(s.serverListener)
 
     s.T().Logf("starting grpc.Proxy at: %v", s.proxyListener.Addr().String())
-    go func() {
-        s.proxy.Serve(s.proxyListener)
-    }()
+    go s.proxy.Serve(s.proxyListener)
 
-    time.Sleep(time.Second)
+    time.Sleep(500 * time.Millisecond)
 
     clientConn, err := grpc.DialContext(
         context.Background(),
@@ -304,7 +447,7 @@ func (s *ProxyHappySuite) SetupSuite() {
     s.testClient = pb.NewTestServiceClient(clientConn)
 }
 
-func (s *ProxyHappySuite) TearDownSuite() {
+func (s *proxyTestSuite) TearDownSuite() {
     if s.client != nil {
         s.client.Close()
     }
@@ -323,8 +466,8 @@ func (s *ProxyHappySuite) TearDownSuite() {
     }
 }
 
-func TestProxyHappySuite(t *testing.T) {
-    suite.Run(t, &ProxyHappySuite{})
+func TestProxySuite(t *testing.T) {
+    suite.Run(t, &proxyTestSuite{})
 }
 
 // Abstraction that allows us to pass the *testing.T as a grpclogger.
diff --git a/pkg/messaging/grpc_proxy.go b/pkg/messaging/grpc_proxy.go
index af3fad636d7..7370c4a053e 100644
--- a/pkg/messaging/grpc_proxy.go
+++ b/pkg/messaging/grpc_proxy.go
@@ -72,7 +72,9 @@ func NewProxy(opts ProxyOpts) Proxy {
 
 // Handler returns a Stream Handler for handling requests that arrive for services that are not recognized by the server.
 func (p *proxy) Handler() grpc.StreamHandler {
-    return grpcProxy.TransparentHandler(p.intercept, p.resiliency, p.IsLocal, grpcProxy.DirectorConnectionFactory(p.connectionFactory))
+    return grpcProxy.TransparentHandler(p.intercept, p.resiliency, p.IsLocal,
+        grpcProxy.DirectorConnectionFactory(p.connectionFactory),
+    )
 }
 
 func nopTeardown(destroy bool) {
@@ -117,11 +119,19 @@ func (p *proxy) intercept(ctx context.Context, fullName string) (context.Context
     }
 
     // proxy to a remote daprd
-    conn, teardown, cErr := p.connectionFactory(outCtx, target.address, target.id, target.namespace, grpc.WithDefaultCallOptions(grpc.CallContentSubtype((&codec.Proxy{}).Name())))
+    conn, teardown, cErr := p.connectionFactory(outCtx, target.address, target.id, target.namespace,
+        grpc.WithDefaultCallOptions(grpc.CallContentSubtype((&codec.Proxy{}).Name())),
+    )
     outCtx = p.telemetryFn(outCtx)
     outCtx = metadata.AppendToOutgoingContext(outCtx, invokev1.CallerIDHeader, p.appID, invokev1.CalleeIDHeader, target.id)
 
-    return outCtx, conn, &grpcProxy.ProxyTarget{ID: target.id, Namespace: target.namespace, Address: target.address}, teardown, cErr
+    pt := &grpcProxy.ProxyTarget{
+        ID:        target.id,
+        Namespace: target.namespace,
+        Address:   target.address,
+    }
+
+    return outCtx, conn, pt, teardown, cErr
 }
 
 // SetRemoteAppFn sets a function that helps the proxy resolve an app ID to an actual address.
@@ -135,9 +145,9 @@ func (p *proxy) SetTelemetryFn(spanFn func(context.Context) context.Context) {
 }
 
 // Expose the functionality to detect if apps are local or not.
-func (p *proxy) IsLocal(appID string) (bool, error) {
-    _, isLocal, err := p.isLocalInternal(appID)
-    return isLocal, err
+func (p *proxy) IsLocal(appID string) (isLocal bool, err error) {
+    _, isLocal, err = p.isLocalInternal(appID)
+    return
 }
 
 func (p *proxy) isLocalInternal(appID string) (remoteApp, bool, error) {
diff --git a/pkg/resiliency/policy.go b/pkg/resiliency/policy.go
index cabee618a55..3a2c9276795 100644
--- a/pkg/resiliency/policy.go
+++ b/pkg/resiliency/policy.go
@@ -166,8 +166,8 @@ func NewRunnerWithOptions[T any](ctx context.Context, def *PolicyDefinition, opt
             return rRes, rErr
         },
         b,
-        func(opErr error, _ time.Duration) {
-            def.log.Infof("Error processing operation %s. Retrying…", def.name)
+        func(opErr error, d time.Duration) {
+            def.log.Infof("Error processing operation %s. Retrying in %v…", def.name, d)
             def.log.Debugf("Error for operation %s was: %v", def.name, opErr)
         },
         func() {
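Illustrative note, not part of the diff: the sketch below condenses how the refactored transparent proxy is wired into a gRPC server, mirroring what SetupSuite in handler_test.go and Handler() in pkg/messaging/grpc_proxy.go already do. The helper name newProxyServer and the parameters passed to it are made up for the example; the TransparentHandler, StreamDirector, and DirectorConnectionFactory signatures are the ones used in this diff.

```go
package example

import (
	"google.golang.org/grpc"

	grpcProxy "github.com/dapr/dapr/pkg/grpc/proxy"
	"github.com/dapr/dapr/pkg/resiliency"
)

// newProxyServer is a hypothetical helper that builds a gRPC server which
// transparently proxies every unknown method through the supplied director.
// Resiliency policies apply only when isLocalFn reports the target app as
// remote, matching the new handler logic above.
func newProxyServer(
	director grpcProxy.StreamDirector,
	isLocalFn func(appID string) (bool, error),
	connFactory grpcProxy.DirectorConnectionFactory,
) *grpc.Server {
	handler := grpcProxy.TransparentHandler(
		director,
		resiliency.New(nil), // in daprd this would be the runtime's configured provider
		isLocalFn,
		connFactory,
	)
	// Any method not explicitly registered on the server is routed to the
	// transparent stream handler, which uses the director to pick a backend.
	return grpc.NewServer(grpc.UnknownServiceHandler(handler))
}
```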