From cfc62d8cc52debf797e83ffe3066aec02fa19103 Mon Sep 17 00:00:00 2001 From: Edbert Linardi Date: Mon, 9 Oct 2023 13:04:26 -0700 Subject: [PATCH] unmarshal json rawmessage into proto message --- auth/opa/rpcauth/input.go | 13 ++++++++++--- auth/opa/rpcauth/rpcauth.go | 13 ++++++++++++- 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/auth/opa/rpcauth/input.go b/auth/opa/rpcauth/input.go index 64a2025d..edd8c9d7 100644 --- a/auth/opa/rpcauth/input.go +++ b/auth/opa/rpcauth/input.go @@ -22,9 +22,12 @@ import ( "encoding/json" "net" + "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" "google.golang.org/grpc/metadata" "google.golang.org/grpc/peer" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/proto" ) @@ -34,8 +37,8 @@ type RPCAuthInput struct { // The GRPC method name, as '/Package.Service/Method' Method string `json:"method"` - // The request protocol buffer message - Message proto.Message `json:"message"` + // The request protocol buffer, serialized as JSON. + Message json.RawMessage `json:"message"` // The message type as 'Package.Message' MessageType string `json:"type"` @@ -140,7 +143,11 @@ func NewRPCAuthInput(ctx context.Context, method string, req proto.Message) (*RP if req != nil { out.MessageType = string(proto.MessageName(req)) - out.Message = req + marshaled, err := protojson.MarshalOptions{UseProtoNames: true}.Marshal(req) + if err != nil { + return nil, status.Errorf(codes.Internal, "error marshalling request for auth: %v", err) + } + out.Message = json.RawMessage(marshaled) } out.Peer = PeerInputFromContext(ctx) return out, nil diff --git a/auth/opa/rpcauth/rpcauth.go b/auth/opa/rpcauth/rpcauth.go index 612b8ed4..949f1ab6 100644 --- a/auth/opa/rpcauth/rpcauth.go +++ b/auth/opa/rpcauth/rpcauth.go @@ -20,6 +20,7 @@ package rpcauth import ( "context" + "fmt" "strings" "github.com/go-logr/logr" @@ -27,8 +28,10 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/reflect/protoregistry" "google.golang.org/protobuf/types/descriptorpb" "github.com/Snowflake-Labs/sansshell/auth/opa" @@ -165,7 +168,15 @@ func (g *Authorizer) Eval(ctx context.Context, input *RPCAuthInput) error { recorder := metrics.RecorderFromContextOrNoop(ctx) var redactedInput protoreflect.ProtoMessage // use this for logging if input != nil { - redactedInput = proto.Clone(input.Message) + // Transform the rpcauth input into the original proto + messageType, err := protoregistry.GlobalTypes.FindMessageByURL(input.MessageType) + if err != nil { + return fmt.Errorf("unable to find proto type: %v", err) + } + redactedInput = messageType.New().Interface() + if err := protojson.Unmarshal([]byte(input.Message), redactedInput); err != nil { + return fmt.Errorf("could not marshal input into %v: %v", input.Message, err) + } redactFields(redactedInput.ProtoReflect()) } if input != nil {