Skip to content

Commit

Permalink
Fix RecvMsg in the single target case. (#164)
Browse files Browse the repository at this point in the history
We need to allow this to be called via Recv() as well which means handling proto.Message at the return type too.

So check for both types and base this on # targets for which is valid (or not).

Write a test which catches this case since we missed it before.
  • Loading branch information
sfc-gh-jchacon authored Sep 23, 2022
1 parent 4ce1d37 commit 18d5130
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 7 deletions.
51 changes: 44 additions & 7 deletions proxy/proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -269,18 +269,41 @@ func (p *proxyStream) RecvMsg(m interface{}) error {
}

// Since the API is an interface{} we can change what this normally
// expects from a proto.Message to a *[]*ProxyRet instead.
// expects from a proto.Message to a *[]*Ret instead.
//
// Anything else is an error.
manyRet, ok := m.(*[]*Ret)
if !ok {
return status.Errorf(codes.InvalidArgument, "args for proxy RecvMsg must be a *[]*ProxyRet) - got %T", m)
// Anything else is an error if we have > 1 target. In the one target
// case validate it's a proto.Message and unwrap into that instead.
var replyMsg proto.Message
var manyRet *[]*Ret
switch v := m.(type) {
case *[]*Ret:
manyRet = v
case proto.Message:
if len(p.ids) != 1 {
return status.Errorf(codes.InvalidArgument, "args for proxy RecvMsg must be a *[]*Ret) when called in OneMany context - got %T", m)
}
replyMsg = v
default:
if len(p.ids) != 1 {
return status.Errorf(codes.InvalidArgument, "args for proxy RecvMsg must be a *[]*Ret) when called in OneMany context - got %T", m)
}
return status.Errorf(codes.InvalidArgument, "args for proxy RecvMsg must be proto.Message when called directly - got %T", m)
}

// If we have any pre-canned errors push them on now.
// Only send once or else the user gets spammed with errors for every Recv called.
if !p.sentErrors {
*manyRet = append(*manyRet, p.errors...)
// In non OneMany context just return this directly as an error.
// Any other calls to RecvMsg will fall through below and get whatever
// the stream returns at that point.
if len(p.errors) > 0 {
if replyMsg != nil {
p.sentErrors = true
return p.errors[0].Error
} else {
*manyRet = append(*manyRet, p.errors...)
}
}
p.sentErrors = true
}

Expand All @@ -302,7 +325,17 @@ func (p *proxyStream) RecvMsg(m interface{}) error {
}
p.ids[id].Resp = d.Payload
p.ids[id].Error = nil
*manyRet = append(*manyRet, p.ids[id])
if manyRet != nil {
*manyRet = append(*manyRet, p.ids[id])
}
if replyMsg != nil {
if err := d.Payload.UnmarshalTo(replyMsg); err != nil {
return status.Errorf(codes.Internal, "can't unmarshal reply: %v", err)
}
// We know there's only one due to the precheck when we construct
// replyMsg.
return nil
}
}
case cl != nil:
code := codes.Code(cl.GetStatus().GetCode())
Expand All @@ -325,6 +358,10 @@ func (p *proxyStream) RecvMsg(m interface{}) error {
if streamStatus.Code() != codes.OK {
closedErr = streamStatus.Err()
}
if replyMsg != nil {
// Easy case for Recv()
return closedErr
}
for _, id := range cl.StreamIds {
p.ids[id].Error = closedErr
p.ids[id].Resp = nil
Expand Down
35 changes: 35 additions & 0 deletions proxy/proxy/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,41 @@ func TestStreaming(t *testing.T) {
},
} {
tc := tc
t.Run(tc.name+" direct", func(t *testing.T) {
conn, err := proxy.Dial(tc.proxy, tc.targets, testutil.WithBufDialer(bufMap), grpc.WithTransportCredentials(insecure.NewCredentials()))
tu.FatalOnErr("Dial", err, t)

ts := tdpb.NewTestServiceClientProxy(conn)
stream, err := ts.TestBidiStream(context.Background())
tu.FatalOnErr("getting stream", err, t)

// We only care about validating Send/Recv work cleanly in 1:1 or error in 1:N

// Should always be able to Send
err = stream.Send(&tdpb.TestRequest{Input: "input"})
tu.FatalOnErr("Send", err, t)

// Now a normal recv should either work or fail depending on > 1 target (or not)
_, err = stream.Recv()
if len(tc.targets) > 1 {
tu.FatalOnNoErr("recv didn't fail for > 1 target", err, t)
} else {
tu.FatalOnErr("Recv", err, t)
}

// Now test the error case
err = stream.Send(&tdpb.TestRequest{Input: "error"})
tu.FatalOnErr("Send error", err, t)

// Shouldn't fail even we close send twice.
err = stream.CloseSend()
tu.FatalOnErr("CloseSend", err, t)
err = stream.CloseSend()
tu.FatalOnErr("CloseSend", err, t)
_, err = stream.Recv()
tu.FatalOnNoErr("recv should get error from send", err, t)
t.Log(err)
})
t.Run(tc.name, func(t *testing.T) {
conn, err := proxy.Dial(tc.proxy, tc.targets, testutil.WithBufDialer(bufMap), grpc.WithTransportCredentials(insecure.NewCredentials()))
tu.FatalOnErr("Dial", err, t)
Expand Down

0 comments on commit 18d5130

Please sign in to comment.