Skip to content

Commit

Permalink
Add interceptors to the proxy connection
Browse files Browse the repository at this point in the history
This is heavily based on gRPC interceptors from https://github.com/grpc/grpc-go/blob/master/examples/features/interceptor/README.md and serve the same purpose for proxy-level interactions. These interceptors allow hooking into calls before they get fanned out to many targets. The public interface is new fields on the `proxy.Conn` struct because I'm introducing these in a backwards-compatible way.

I'm planning on using this for MPA requests in sanssh. The interceptor lets us hook into every call and turn it into a series of calls that perform the MPA request flow and then pass on the MPA request as metadata into the call. I'm not completely happy with separate calls for unary and streaming, especially because a message with a single request and a streamed response counts as "streaming", but it matches what gRPC does in its calls.

It's theoretically possible to avoid adding this interceptor and instead do what I'd like via a grpc.StreamClientInterceptor. Doing so would require writing something tightly coupled to the logic in the proxy client code and more prone to bugs. It would involve weird things like lying to the client and claiming that we've connected to targets before we connect to any because the message containing the rpc request doesn't get sent until the proxy client code thinks it has established all connections.

Part of #346
  • Loading branch information
stvnrhodes committed Oct 20, 2023
1 parent d7d6050 commit e05d031
Show file tree
Hide file tree
Showing 2 changed files with 171 additions and 25 deletions.
66 changes: 58 additions & 8 deletions proxy/proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,26 @@ import (
proxypb "github.com/Snowflake-Labs/sansshell/proxy"
)

// UnaryInterceptor intercepts the execution of an RPC through the proxy. Interceptors
// can be added to a Conn by modifying its Interceptors field. When an interceptor
// is set on a ClientConn, all proxy RPC invocations are delegated to the interceptor,
// and it is the responsibility of the interceptor to call invoker to complete the
// processing of the RPC.
type UnaryInterceptor func(ctx context.Context, conn *Conn, method string, args any, invoker UnaryInvoker, opts ...grpc.CallOption) (<-chan *Ret, error)

// UnaryInvoker is called by UnaryInterceptor to complete RPCs.
type UnaryInvoker func(ctx context.Context, method string, args any, opts ...grpc.CallOption) (<-chan *Ret, error)

// StreamInterceptor intercepts the execution of an RPC through the proxy. Interceptors
// can be added to a Conn by modifying its Interceptors field. When an interceptor
// is set on a ClientConn, all proxy RPC invocations are delegated to the interceptor,
// and it is the responsibility of the interceptor to call invoker to complete the
// processing of the RPC.
type StreamInterceptor func(ctx context.Context, desc *grpc.StreamDesc, cc *Conn, method string, streamer Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error)

// Streamer is called by StreamInterceptor to complete RPCs.
type Streamer func(ctx context.Context, desc *grpc.StreamDesc, method string, opts ...grpc.CallOption) (grpc.ClientStream, error)

// Conn is a grpc.ClientConnInterface which is connected to the proxy
// converting calls into RPC the proxy understands.
type Conn struct {
Expand All @@ -55,6 +75,15 @@ type Conn struct {

// If this is true we're not proxy but instead direct connect.
direct bool

// UnaryInterceptors allow intercepting Invoke and InvokeOneMany calls
// that go through a proxy.
// It is unsafe to modify Intercepters while calls are in progress.
UnaryInterceptors []UnaryInterceptor

// StreamInterceptors allow intercepting NewStream calls that go through a proxy.
// It is unsafe to modify Intercepters while calls are in progress.
StreamInterceptors []StreamInterceptor
}

// Ret defines the internal API for getting responses from the proxy.
Expand Down Expand Up @@ -91,7 +120,7 @@ type proxyStream struct {
}

// Invoke - see grpc.ClientConnInterface
func (p *Conn) Invoke(ctx context.Context, method string, args interface{}, reply interface{}, opts ...grpc.CallOption) error {
func (p *Conn) Invoke(ctx context.Context, method string, args any, reply any, opts ...grpc.CallOption) error {
if p.Direct() {
// TODO(jchacon): Add V1 style logging indicating pass through in use.
return p.cc.Invoke(ctx, method, args, reply, opts...)
Expand Down Expand Up @@ -149,7 +178,18 @@ func (p *Conn) NewStream(ctx context.Context, desc *grpc.StreamDesc, method stri
// TODO(jchacon): Add V1 style logging indicating pass through in use.
return p.cc.NewStream(ctx, desc, method, opts...)
}
stream := p.newStream
for i := len(p.StreamInterceptors) - 1; i >= 0; i-- {
intercept := p.StreamInterceptors[i]
inner := stream
stream = func(ctx context.Context, desc *grpc.StreamDesc, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) {
return intercept(ctx, desc, p, method, inner, opts...)
}
}
return stream(ctx, desc, method, opts...)
}

func (p *Conn) newStream(ctx context.Context, desc *grpc.StreamDesc, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) {
stream, streamIds, errors, err := p.createStreams(ctx, method)
if err != nil {
return nil, err
Expand Down Expand Up @@ -247,7 +287,7 @@ func (p *proxyStream) closeClients() error {
}

// see grpc.ClientStream
func (p *proxyStream) SendMsg(args interface{}) error {
func (p *proxyStream) SendMsg(args any) error {
if p.sendClosed {
return status.Error(codes.FailedPrecondition, "sending on a closed connection")
}
Expand All @@ -262,13 +302,13 @@ func (p *proxyStream) SendMsg(args interface{}) error {
}

// see grpc.ClientStream
func (p *proxyStream) RecvMsg(m interface{}) error {
func (p *proxyStream) RecvMsg(m any) error {
// Up front check for nothing left since we closed all streams.
if len(p.ids) == 0 {
return io.EOF
}

// Since the API is an interface{} we can change what this normally
// Since the API is an any we can change what this normally
// expects from a proto.Message to a *[]*Ret instead.
//
// Anything else is an error if we have > 1 target. In the one target
Expand Down Expand Up @@ -515,7 +555,7 @@ func (p *Conn) createStreams(ctx context.Context, method string) (proxypb.Proxy_
return stream, streamIds, errors, nil
}

// InvokeOneMany is used in proto generated code to implemened unary OneMany methods doing 1:N calls to the proxy.
// InvokeOneMany is used in proto generated code to implement unary OneMany methods doing 1:N calls to the proxy.
// This returns ProxyRet objects from the channel which contain anypb.Any so the caller (generally generated code)
// will need to convert those to the proper expected specific types.
//
Expand All @@ -525,9 +565,19 @@ func (p *Conn) createStreams(ctx context.Context, method string) (proxypb.Proxy_
// invoked with a context timeout lower than the remote server Dial timeout.
//
// NOTE: The returned channel must be read until it closes in order to avoid leaking goroutines.
//
// TODO(jchacon): Should add the ability to specify remote dial timeout in the connection to the proxy.
func (p *Conn) InvokeOneMany(ctx context.Context, method string, args interface{}, opts ...grpc.CallOption) (<-chan *Ret, error) {
func (p *Conn) InvokeOneMany(ctx context.Context, method string, args any, opts ...grpc.CallOption) (<-chan *Ret, error) {
invoke := p.invokeOneMany
for i := len(p.UnaryInterceptors) - 1; i >= 0; i-- {
intercept := p.UnaryInterceptors[i]
inner := invoke
invoke = func(ctx context.Context, method string, args any, opts ...grpc.CallOption) (<-chan *Ret, error) {
return intercept(ctx, p, method, args, inner, opts...)
}
}
return invoke(ctx, method, args, opts...)
}

func (p *Conn) invokeOneMany(ctx context.Context, method string, args any, opts ...grpc.CallOption) (<-chan *Ret, error) {
requestMsg, ok := args.(proto.Message)
if !ok {
return nil, status.Error(codes.InvalidArgument, "args must be a proto.Message")
Expand Down
130 changes: 113 additions & 17 deletions proxy/proxy/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ import (
"fmt"
"io"
"net"
"reflect"
"strings"
"testing"

"google.golang.org/grpc"
Expand Down Expand Up @@ -177,12 +179,15 @@ func TestUnary(t *testing.T) {
bufMap[k] = v
}

interceptorCnt := make(map[string]int)

for _, tc := range []struct {
name string
proxy string
targets []string
wantErrOneMany bool
wantErr bool
name string
proxy string
targets []string
unaryInterceptors []proxy.UnaryInterceptor
wantErrOneMany bool
wantErr bool
}{
{
name: "proxy N targets",
Expand All @@ -206,14 +211,57 @@ func TestUnary(t *testing.T) {
name: "no proxy 1 target",
targets: []string{"foo:123"},
},
{
name: "proxy interceptor error",
proxy: "proxy",
unaryInterceptors: []proxy.UnaryInterceptor{
func(ctx context.Context, conn *proxy.Conn, method string, args any, invoker proxy.UnaryInvoker, opts ...grpc.CallOption) (<-chan *proxy.Ret, error) {
if method == "bad_method" {
return invoker(ctx, method, args, opts...)
}
return nil, fmt.Errorf("interceptor err")
},
},
wantErr: true,
wantErrOneMany: true,
targets: []string{"foo:123"},
},
{
name: "no proxy no error",
unaryInterceptors: []proxy.UnaryInterceptor{
func(ctx context.Context, conn *proxy.Conn, method string, args any, invoker proxy.UnaryInvoker, opts ...grpc.CallOption) (<-chan *proxy.Ret, error) {
return nil, fmt.Errorf("interceptor err")
},
},
targets: []string{"foo:123"},
},
{
name: "chained proxy",
proxy: "proxy",
unaryInterceptors: []proxy.UnaryInterceptor{
func(ctx context.Context, conn *proxy.Conn, method string, args any, invoker proxy.UnaryInvoker, opts ...grpc.CallOption) (<-chan *proxy.Ret, error) {
interceptorCnt[method]++
return invoker(ctx, method+"chain", args, opts...)
},
func(ctx context.Context, conn *proxy.Conn, method string, args any, invoker proxy.UnaryInvoker, opts ...grpc.CallOption) (<-chan *proxy.Ret, error) {
if !strings.HasSuffix(method, "chain") {
return nil, fmt.Errorf("method should have chain: %v", method)
}
interceptorCnt[method]++
return invoker(ctx, strings.TrimSuffix(method, "chain"), args, opts...)
},
},
targets: []string{"foo:123"},
},
} {
tc := tc
t.Run(tc.name, func(t *testing.T) {
conn, err := proxy.Dial(tc.proxy, tc.targets, testutil.WithBufDialer(bufMap), grpc.WithTransportCredentials(insecure.NewCredentials()))
tu.FatalOnErr("Dial", err, t)
conn.UnaryInterceptors = tc.unaryInterceptors

ts := tdpb.NewTestServiceClientProxy(conn)
resp, err := ts.TestUnaryOneMany(context.Background(), &tdpb.TestRequest{Input: "input"})
resp, err := ts.TestUnaryOneMany(ctx, &tdpb.TestRequest{Input: "input"})
t.Log(err)

for r := range resp {
Expand All @@ -223,11 +271,11 @@ func TestUnary(t *testing.T) {
}
}

_, err = ts.TestUnary(context.Background(), &tdpb.TestRequest{Input: "input"})
_, err = ts.TestUnary(ctx, &tdpb.TestRequest{Input: "input"})
t.Log(err)
tu.WantErr(tc.name, err, tc.wantErr, t)

resp, err = ts.TestUnaryOneMany(context.Background(), &tdpb.TestRequest{Input: "error"})
resp, err = ts.TestUnaryOneMany(ctx, &tdpb.TestRequest{Input: "error"})
tu.FatalOnErr("TestUnaryOneMany error", err, t)
for r := range resp {
t.Logf("%+v", r)
Expand All @@ -236,12 +284,12 @@ func TestUnary(t *testing.T) {

// Check pass through cases
if tc.proxy != "" {
_, err = ts.TestUnary(context.Background(), &tdpb.TestRequest{Input: "error"})
_, err = ts.TestUnary(ctx, &tdpb.TestRequest{Input: "error"})
tu.FatalOnNoErr("TestUnary error", err, t)
}

// Do some direct calls against the conn to get at error cases
resp2, err := conn.InvokeOneMany(context.Background(), "bad_method", &tdpb.TestRequest{Input: "input"})
resp2, err := conn.InvokeOneMany(ctx, "bad_method", &tdpb.TestRequest{Input: "input"})
if tc.proxy == "" {
tu.FatalOnNoErr("InvokeOneMany bad msg", err, t)
} else {
Expand All @@ -250,13 +298,23 @@ func TestUnary(t *testing.T) {
tu.FatalOnNoErr("InvokeOneMany bad method", r.Error, t)
}
}
_, err = conn.InvokeOneMany(context.Background(), "/Testdata.TestService/TestUnary", nil)
_, err = conn.InvokeOneMany(ctx, "/Testdata.TestService/TestUnary", nil)
tu.FatalOnNoErr("InvokeOneMany bad msg", err, t)

err = conn.Close()
tu.FatalOnErr("conn Close()", err, t)
})
}

wantCalled := map[string]int{
"/Testdata.TestService/TestUnary": 5,
"/Testdata.TestService/TestUnarychain": 5,
"bad_method": 1,
"bad_methodchain": 1,
}
if !reflect.DeepEqual(wantCalled, interceptorCnt) {
t.Errorf("wrong callers from interceptors: got %v, want %v", interceptorCnt, wantCalled)
}
}

func TestStreaming(t *testing.T) {
Expand All @@ -269,12 +327,15 @@ func TestStreaming(t *testing.T) {
bufMap[k] = v
}

interceptorCnt := make(map[string]int)

for _, tc := range []struct {
name string
proxy string
targets []string
wantErrOneMany bool
wantErr bool
name string
proxy string
targets []string
streamInterceptors []proxy.StreamInterceptor
wantErrOneMany bool
wantErr bool
}{
{
name: "proxy N targets",
Expand All @@ -291,14 +352,42 @@ func TestStreaming(t *testing.T) {
name: "no proxy 1 target",
targets: []string{"foo:123"},
},
{
name: "no proxy no err",
streamInterceptors: []proxy.StreamInterceptor{
func(ctx context.Context, desc *grpc.StreamDesc, cc *proxy.Conn, method string, streamer proxy.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
return nil, fmt.Errorf("interceptor fail")
},
},
targets: []string{"foo:123"},
},
{
name: "proxy 1 target chained",
proxy: "proxy",
streamInterceptors: []proxy.StreamInterceptor{
func(ctx context.Context, desc *grpc.StreamDesc, cc *proxy.Conn, method string, streamer proxy.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
interceptorCnt[method]++
return streamer(ctx, desc, method+"chain", opts...)
},
func(ctx context.Context, desc *grpc.StreamDesc, cc *proxy.Conn, method string, streamer proxy.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
if !strings.HasSuffix(method, "chain") {
return nil, fmt.Errorf("method should have chain: %v", method)
}
interceptorCnt[method]++
return streamer(ctx, desc, strings.TrimSuffix(method, "chain"), opts...)
},
},
targets: []string{"foo:123"},
},
} {
tc := tc
t.Run(tc.name+" direct", func(t *testing.T) {
conn, err := proxy.Dial(tc.proxy, tc.targets, testutil.WithBufDialer(bufMap), grpc.WithTransportCredentials(insecure.NewCredentials()))
tu.FatalOnErr("Dial", err, t)
conn.StreamInterceptors = tc.streamInterceptors

ts := tdpb.NewTestServiceClientProxy(conn)
stream, err := ts.TestBidiStream(context.Background())
stream, err := ts.TestBidiStream(ctx)
tu.FatalOnErr("getting stream", err, t)

// We only care about validating Send/Recv work cleanly in 1:1 or error in 1:N
Expand Down Expand Up @@ -403,6 +492,13 @@ func TestStreaming(t *testing.T) {
})
}

wantCalled := map[string]int{
"/Testdata.TestService/TestBidiStream": 1,
"/Testdata.TestService/TestBidiStreamchain": 1,
}
if !reflect.DeepEqual(wantCalled, interceptorCnt) {
t.Errorf("wrong callers from interceptors: got %v, want %v", interceptorCnt, wantCalled)
}
}

type fakeProxy struct {
Expand Down

0 comments on commit e05d031

Please sign in to comment.