diff --git a/services/mpa/mpahooks/mpahooks.go b/services/mpa/mpahooks/mpahooks.go index ab7d2468..bc9bf15a 100644 --- a/services/mpa/mpahooks/mpahooks.go +++ b/services/mpa/mpahooks/mpahooks.go @@ -85,14 +85,15 @@ func ActionMatchesInput(ctx context.Context, action *mpa.Action, input *rpcauth. if err := msg.MarshalFrom(m2); err != nil { return fmt.Errorf("unable to marshal into anyproto: %v", err) } - if input.Peer == nil || input.Peer.Principal == nil { - return fmt.Errorf("missing peer information") - } // Prefer using a proxied identity if provided - user := input.Peer.Principal.ID + var user string if p := proxiedidentity.FromContext(ctx); p != nil { user = p.ID + } else if input.Peer != nil && input.Peer.Principal != nil { + user = input.Peer.Principal.ID + } else { + return fmt.Errorf("missing peer information") } sentAct := &mpa.Action{ diff --git a/services/mpa/mpahooks/mpahooks_test.go b/services/mpa/mpahooks/mpahooks_test.go index 859cde68..f0bcae15 100644 --- a/services/mpa/mpahooks/mpahooks_test.go +++ b/services/mpa/mpahooks/mpahooks_test.go @@ -42,6 +42,7 @@ import ( "golang.org/x/sync/errgroup" "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" "google.golang.org/protobuf/types/known/anypb" "google.golang.org/protobuf/types/known/emptypb" @@ -60,14 +61,26 @@ func mustAny(a *anypb.Any, err error) *anypb.Any { func TestActionMatchesInput(t *testing.T) { ctx := context.Background() + var ctxWithIdentity context.Context + if _, err := proxiedidentity.ServerProxiedIdentityUnaryInterceptor()( + metadata.NewIncomingContext(ctx, metadata.Pairs("proxied-sansshell-identity", `{"id":"proxied"}`)), + nil, nil, func(ctx context.Context, req any) (any, error) { + ctxWithIdentity = ctx + return nil, nil + }); err != nil { + t.Fatal(err) + } + for _, tc := range []struct { desc string + ctx context.Context action *mpa.Action input *rpcauth.RPCAuthInput matches bool }{ { desc: "basic action", + ctx: ctx, action: &mpa.Action{ User: "requester", Method: "foobar", @@ -87,6 +100,7 @@ func TestActionMatchesInput(t *testing.T) { }, { desc: "missing auth info", + ctx: ctx, action: &mpa.Action{ User: "requester", Method: "foobar", @@ -101,6 +115,7 @@ func TestActionMatchesInput(t *testing.T) { }, { desc: "wrong message", + ctx: ctx, action: &mpa.Action{ User: "requester", Method: "foobar", @@ -118,9 +133,44 @@ func TestActionMatchesInput(t *testing.T) { }, matches: false, }, + { + desc: "proxied identity with peer", + ctx: ctxWithIdentity, + action: &mpa.Action{ + User: "proxied", + Method: "foobar", + Message: mustAny(anypb.New(&emptypb.Empty{})), + }, + input: &rpcauth.RPCAuthInput{ + Method: "foobar", + MessageType: "google.protobuf.Empty", + Message: []byte("{}"), + Peer: &rpcauth.PeerAuthInput{ + Principal: &rpcauth.PrincipalAuthInput{ + ID: "requester", + }, + }, + }, + matches: true, + }, + { + desc: "proxied identity without peer", + ctx: ctxWithIdentity, + action: &mpa.Action{ + User: "proxied", + Method: "foobar", + Message: mustAny(anypb.New(&emptypb.Empty{})), + }, + input: &rpcauth.RPCAuthInput{ + Method: "foobar", + MessageType: "google.protobuf.Empty", + Message: []byte("{}"), + }, + matches: true, + }, } { t.Run(tc.desc, func(t *testing.T) { - err := mpahooks.ActionMatchesInput(ctx, tc.action, tc.input) + err := mpahooks.ActionMatchesInput(tc.ctx, tc.action, tc.input) if err != nil && tc.matches { t.Errorf("expected match: %v", err) }