Skip to content

Commit

Permalink
Multiple fixes to gRPC proxy including panics, data loss, resiliency not working (dapr#5751)
Browse files Browse the repository at this point in the history

* Multiple fixes to gRPC proxy including panics, data loss

Signed-off-by: ItalyPaleAle <[email protected]>

* 💄

Signed-off-by: ItalyPaleAle <[email protected]>

* Better 💄

Signed-off-by: ItalyPaleAle <[email protected]>

* Added more unit tests

Signed-off-by: ItalyPaleAle <[email protected]>

* 💄

Signed-off-by: ItalyPaleAle <[email protected]>

Signed-off-by: ItalyPaleAle <[email protected]>
  • Loading branch information
ItalyPaleAle authored Jan 13, 2023
1 parent ccce9e4 commit 3e32456
Show file tree
Hide file tree
Showing 7 changed files with 306 additions and 123 deletions.
1 change: 1 addition & 0 deletions .golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ linters-settings:
- "github.com/Sirupsen/logrus": "must use github.com/dapr/kit/logger"
- "github.com/agrea/ptr": "must use github.com/dapr/kit/ptr"
- "go.uber.org/atomic": "must use sync/atomic"
- "golang.org/x/net/context": "must use context"
- "github.com/pkg/errors": "must use standard library (errors package and/or fmt.Errorf)"
- "github.com/cenkalti/backoff": "must use github.com/cenkalti/backoff/v4"
- "github.com/cenkalti/backoff/v2": "must use github.com/cenkalti/backoff/v4"
Expand Down
28 changes: 15 additions & 13 deletions pkg/grpc/proxy/codec/codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,12 @@ type Frame struct {
}

// ProtoMessage tags a frame as a valid proto message.
func (f *Frame) ProtoMessage() {}
func (f *Frame) ProtoMessage() {
// nop
}

// Marshal implements the encoding.Codec interface method.
func (p *Proxy) Marshal(v interface{}) ([]byte, error) {
func (p *Proxy) Marshal(v any) ([]byte, error) {
out, ok := v.(*Frame)
if !ok {
return p.parentCodec.Marshal(v)
Expand All @@ -69,7 +71,7 @@ func (p *Proxy) Marshal(v interface{}) ([]byte, error) {
}

// Unmarshal implements the encoding.Codec interface method.
func (p *Proxy) Unmarshal(data []byte, v interface{}) error {
func (p *Proxy) Unmarshal(data []byte, v any) error {
dst, ok := v.(*Frame)
if !ok {
return p.parentCodec.Unmarshal(data, v)
Expand All @@ -86,25 +88,25 @@ func (*Proxy) Name() string {
// protoCodec is a Codec implementation with protobuf. It is the default rawCodec for gRPC.
type protoCodec struct{}

func (*protoCodec) Marshal(v interface{}) ([]byte, error) {
switch t := v.(type) {
func (*protoCodec) Marshal(v any) ([]byte, error) {
switch x := v.(type) {
case proto.Message:
return proto.Marshal(v.(proto.Message))
return proto.Marshal(x)
case protoV1.Message:
return protoV1.Marshal(v.(protoV1.Message))
return protoV1.Marshal(x)
default:
return nil, fmt.Errorf("failed to marshal, message is %T, want proto.Message", t)
return nil, fmt.Errorf("failed to marshal: message is %T, want proto.Message", x)
}
}

func (*protoCodec) Unmarshal(data []byte, v interface{}) error {
switch t := v.(type) {
func (*protoCodec) Unmarshal(data []byte, v any) error {
switch x := v.(type) {
case proto.Message:
return proto.Unmarshal(data, v.(proto.Message))
return proto.Unmarshal(data, x)
case protoV1.Message:
return protoV1.Unmarshal(data, v.(protoV1.Message))
return protoV1.Unmarshal(data, x)
default:
return fmt.Errorf("failed to unmarshal, message is %T, want proto.Message", t)
return fmt.Errorf("failed to unmarshal: message is %T, want proto.Message", x)
}
}

Expand Down
3 changes: 2 additions & 1 deletion pkg/grpc/proxy/director.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
package proxy

import (
"golang.org/x/net/context"
"context"

"google.golang.org/grpc"
)

Expand Down
108 changes: 67 additions & 41 deletions pkg/grpc/proxy/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,17 @@
package proxy

import (
"fmt"
"context"
"io"
"sync"
"sync/atomic"

"golang.org/x/net/context"
"github.com/google/uuid"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"

"github.com/google/uuid"

"github.com/dapr/dapr/pkg/diagnostics"
"github.com/dapr/dapr/pkg/grpc/proxy/codec"
"github.com/dapr/dapr/pkg/resiliency"
Expand All @@ -37,7 +36,7 @@ func RegisterService(server *grpc.Server, director StreamDirector, resiliency re
}
fakeDesc := &grpc.ServiceDesc{
ServiceName: serviceName,
HandlerType: (*interface{})(nil),
HandlerType: (*any)(nil),
}
for _, m := range methodNames {
streamDesc := grpc.StreamDesc{
Expand All @@ -58,11 +57,10 @@ func RegisterService(server *grpc.Server, director StreamDirector, resiliency re
// This can *only* be used if the `server` also uses grpcproxy.CodecForServer() ServerOption.
func TransparentHandler(director StreamDirector, resiliency resiliency.Provider, isLocalFn func(string) (bool, error), connFactory DirectorConnectionFactory) grpc.StreamHandler {
streamer := &handler{
director: director,
resiliency: resiliency,
isLocalFn: isLocalFn,
bufferedCalls: sync.Map{},
connFactory: connFactory,
director: director,
resiliency: resiliency,
isLocalFn: isLocalFn,
connFactory: connFactory,
}
return streamer.handler
}
Expand All @@ -72,27 +70,25 @@ type handler struct {
resiliency resiliency.Provider
isLocalFn func(string) (bool, error)
bufferedCalls sync.Map
headersSent sync.Map
connFactory DirectorConnectionFactory
}

// handler is where the real magic of proxying happens.
// It is invoked like any gRPC server stream and uses the gRPC server framing to send and receive bytes on the wire,
// forwarding them to a ClientStream established against the relevant ClientConn.
func (s *handler) handler(srv interface{}, serverStream grpc.ServerStream) error {
func (s *handler) handler(srv any, serverStream grpc.ServerStream) error {
// Create buffered calls for this request.
requestIDObj, err := uuid.NewRandom()
if err != nil {
return status.Errorf(codes.Internal, "failed to generate UUID: %v", err)
}
requestID := requestIDObj.String()
s.bufferedCalls.Store(requestID, []interface{}{})
s.headersSent.Store(requestID, false)
s.bufferedCalls.Store(requestID, []any{})

// little bit of gRPC internals never hurt anyone
fullMethodName, ok := grpc.MethodFromServerStream(serverStream)
if !ok {
return status.Errorf(codes.Internal, "lowLevelServerStream not exists in context")
return status.Errorf(codes.Internal, "full method name not found in stream")
}

// Fetch the AppId so we can reference it for resiliency.
Expand All @@ -107,8 +103,8 @@ func (s *handler) handler(srv interface{}, serverStream grpc.ServerStream) error
policyDef = noOp.EndpointPolicy("", "")
} else {
isLocal, err := s.isLocalFn(v[0])
if err == nil && isLocal {
policyDef = s.resiliency.EndpointPolicy(v[0], fmt.Sprintf("%s:%s", v[0], fullMethodName))
if err == nil && !isLocal {
policyDef = s.resiliency.EndpointPolicy(v[0], v[0]+":"+fullMethodName)
} else {
noOp := resiliency.NoOp{}
policyDef = noOp.EndpointPolicy("", "")
Expand All @@ -119,7 +115,12 @@ func (s *handler) handler(srv interface{}, serverStream grpc.ServerStream) error
clientStreamOpts := []grpc.CallOption{
grpc.CallContentSubtype((&codec.Proxy{}).Name()),
}
headersSent := &atomic.Bool{}
counter := atomic.Int32{}
_, cErr := policyRunner(func(ctx context.Context) (struct{}, error) {
// Get the current iteration count
iter := counter.Add(1)

// We require that the director's returned context inherits from the ctx.
outgoingCtx, backendConn, target, teardown, err := s.director(ctx, fullMethodName)
defer teardown(false)
Expand All @@ -128,12 +129,26 @@ func (s *handler) handler(srv interface{}, serverStream grpc.ServerStream) error
}

clientCtx, clientCancel := context.WithCancel(outgoingCtx)
defer clientCancel()

// TODO(mwitkow): Add a `forwarded` header to metadata, https://en.wikipedia.org/wiki/X-Forwarded-For.
clientStream, err := grpc.NewClientStream(clientCtx, clientStreamDescForProxying, backendConn, fullMethodName, clientStreamOpts...)
clientStream, err := grpc.NewClientStream(
clientCtx,
clientStreamDescForProxying,
backendConn,
fullMethodName,
clientStreamOpts...,
)
if err != nil {
code := status.Code(err)
if target != nil && (code == codes.Unavailable || code == codes.Unauthenticated) {
// It's possible that we get to this point while another goroutine is executing the same policy function.
// For example, this could happen if this iteration has timed out and "policyRunner" has triggered a new execution already.
// In this case, we should not tear down the connection because it could be in use by the next execution. So just return and move on.
if counter.Load() != iter {
return struct{}{}, err
}

// Destroy the connection so it can be recreated
teardown(true)

Expand All @@ -148,29 +163,30 @@ func (s *handler) handler(srv interface{}, serverStream grpc.ServerStream) error
if err != nil {
return struct{}{}, err
}
} else {
return struct{}{}, err
}
}

// Explicitly *do not close* s2cErrChan and c2sErrChan, otherwise the select below will not terminate.
// Channels do not have to be closed, it is just a control flow mechanism, see
// https://groups.google.com/forum/#!msg/golang-nuts/pZwdYRGxCIk/qpbHxRRPJdUJ
s2cErrChan := s.forwardServerToClient(serverStream, clientStream, requestID)
c2sErrChan := s.forwardClientToServer(clientStream, serverStream, requestID)
c2sErrChan := s.forwardClientToServer(clientStream, serverStream, headersSent)
// We don't know which side is going to stop sending first, so we need a select between the two.
for i := 0; i < 2; i++ {
select {
case s2cErr := <-s2cErrChan:
if s2cErr == io.EOF {
// this is the happy case where the sender has encountered io.EOF, and won't be sending anymore./
// this is the happy case where the sender has encountered io.EOF, and won't be sending anymore.
// the clientStream>serverStream may continue pumping though.
clientStream.CloseSend()
continue
} else {
// however, we may have gotten a receive error (stream disconnected, a read error etc) in which case we need
// to cancel the clientStream to the backend, let all of its goroutines be freed up by the CancelFunc and
// exit with an error to the stack
clientCancel()
return struct{}{}, status.Errorf(codes.Internal, "failed proxying s2c: %v", s2cErr)
return struct{}{}, status.Error(codes.Internal, "failed proxying s2c: "+s2cErr.Error())
}
case c2sErr := <-c2sErrChan:
// This happens when the clientStream has nothing else to offer (io.EOF) or has returned a gRPC error. In those two
Expand All @@ -184,42 +200,44 @@ func (s *handler) handler(srv interface{}, serverStream grpc.ServerStream) error
return struct{}{}, nil
}
}
return struct{}{}, status.Errorf(codes.Internal, "gRPC proxying should never reach this stage.")

return struct{}{}, status.Error(codes.Internal, "gRPC proxying should never reach this stage")
})

// Clear the request's buffered calls.
s.bufferedCalls.Delete(requestID)
s.headersSent.Delete(requestID)
return cErr
}

func (s *handler) forwardClientToServer(src grpc.ClientStream, dst grpc.ServerStream, requestID string) chan error {
func (s *handler) forwardClientToServer(src grpc.ClientStream, dst grpc.ServerStream, headersSent *atomic.Bool) chan error {
ret := make(chan error, 1)
go func() {
var err error
f := &codec.Frame{}
syncMapValue, _ := s.headersSent.Load(requestID)
localHeaders := syncMapValue.(bool)
for i := 0; ; i++ {
if err := src.RecvMsg(f); err != nil {

for src.Context().Err() == nil && dst.Context().Err() == nil {
err = src.RecvMsg(f)
if err != nil {
ret <- err // this can be io.EOF which is happy case
break
}
// In the case of retries, don't resend the headers.
if i == 0 && !localHeaders {
if headersSent.CompareAndSwap(false, true) {
// This is a bit of a hack, but client to server headers are only readable after first client msg is
// received but must be written to server stream before the first msg is flushed.
// This is the only place to do it nicely.
md, err := src.Header()
var md metadata.MD
md, err = src.Header()
if err != nil {
break
}
if err := dst.SendHeader(md); err != nil {
err = dst.SendHeader(md)
if err != nil {
break
}
localHeaders = true
s.headersSent.Store(requestID, true)
}
if err := dst.SendMsg(f); err != nil {
err = dst.SendMsg(f)
if err != nil {
break
}
}
Expand All @@ -230,23 +248,31 @@ func (s *handler) forwardClientToServer(src grpc.ClientStream, dst grpc.ServerSt
func (s *handler) forwardServerToClient(src grpc.ServerStream, dst grpc.ClientStream, requestID string) chan error {
ret := make(chan error, 1)
go func() {
f := &codec.Frame{}
var err error

// Start by sending buffered messages
syncMapValue, _ := s.bufferedCalls.Load(requestID)
bufferedFrames := syncMapValue.([]interface{})
bufferedFrames := syncMapValue.([]any)
for _, msg := range bufferedFrames {
if err := dst.SendMsg(msg); err != nil {
err = dst.SendMsg(msg)
if err != nil {
ret <- err
return
}
}
for i := 0; ; i++ {
if err := src.RecvMsg(f); err != nil {

// Receive messages from the source stream and forward them to the destination stream
for src.Context().Err() == nil && dst.Context().Err() == nil {
f := &codec.Frame{}
err = src.RecvMsg(f)
if err != nil {
s.bufferedCalls.Store(requestID, bufferedFrames)
ret <- err // this can be io.EOF which is happy case
break
}
bufferedFrames = append(bufferedFrames, f)
if err := dst.SendMsg(f); err != nil {
err = dst.SendMsg(f)
if err != nil {
s.bufferedCalls.Store(requestID, bufferedFrames)
break
}
Expand Down
Loading

0 comments on commit 3e32456

Please sign in to comment.