diff --git a/auth/opa/rpcauth/redact.go b/auth/opa/rpcauth/redact.go index 6173c2ba..5aa96085 100644 --- a/auth/opa/rpcauth/redact.go +++ b/auth/opa/rpcauth/redact.go @@ -26,6 +26,7 @@ import ( "google.golang.org/protobuf/reflect/protoreflect" "google.golang.org/protobuf/reflect/protoregistry" "google.golang.org/protobuf/types/descriptorpb" + "google.golang.org/protobuf/types/known/anypb" ) func isMessage(descriptor protoreflect.FieldDescriptor) bool { @@ -40,28 +41,65 @@ func isDebugRedactEnabled(fd protoreflect.FieldDescriptor) bool { return opts.GetDebugRedact() } -func redactListField(value protoreflect.Value) { +func redactListField(value protoreflect.Value) error { for i := 0; i < value.List().Len(); i++ { - redactFields(value.List().Get(i).Message()) + errRedact := redactFields(value.List().Get(i).Message()) + if errRedact != nil { + return errRedact + } } + return nil } -func redactMapField(value protoreflect.Value) { +func redactMapField(value protoreflect.Value) error { + var err error value.Map().Range(func(mapKey protoreflect.MapKey, mapValue protoreflect.Value) bool { - redactFields(mapValue.Message()) + errRedact := redactFields(mapValue.Message()) + if errRedact != nil { + err = errRedact + return false + } return true }) + return err } -func redactNestedMessage(message protoreflect.Message, descriptor protoreflect.FieldDescriptor, value protoreflect.Value) { +var anypbFullName = (&anypb.Any{}).ProtoReflect().Descriptor().FullName() + +func redactAny(message protoreflect.Message, descriptor protoreflect.FieldDescriptor, value protoreflect.Value) error { + anyMsg, ok := value.Message().Interface().(*anypb.Any) // cast value to redactAny + if !ok { + return fmt.Errorf("failed to cast message into any") + } + originalMsg, errUnmarshal := anyMsg.UnmarshalNew() // unmarshal any into original message + if errUnmarshal != nil { + return fmt.Errorf("failed to unmarshal anypb: %v", errUnmarshal) + } + errRedact := redactFields(originalMsg.ProtoReflect()) // redact original message + if errRedact != nil { + return errRedact + } + redactedAny, errAny := anypb.New(originalMsg) // cast redacted message back to any + if errAny != nil { + return fmt.Errorf("failed to cast into anypb: %v", errAny) + } + message.Set(descriptor, protoreflect.ValueOf(redactedAny.ProtoReflect())) // set the redacted + + return nil +} + +func redactNestedField(message protoreflect.Message, descriptor protoreflect.FieldDescriptor, value protoreflect.Value) error { switch { case descriptor.IsList() && isMessage(descriptor): - redactListField(value) + return redactListField(value) case descriptor.IsMap() && isMessage(descriptor): - redactMapField(value) + return redactMapField(value) + case descriptor.Message() != nil && descriptor.Message().FullName() == anypbFullName: + return redactAny(message, descriptor, value) case !descriptor.IsMap() && isMessage(descriptor): - redactFields(value.Message()) + return redactFields(value.Message()) } + return nil } func redactSingleField(message protoreflect.Message, descriptor protoreflect.FieldDescriptor) { @@ -80,17 +118,23 @@ func redactSingleField(message protoreflect.Message, descriptor protoreflect.Fie } } -func redactFields(message protoreflect.Message) { +func redactFields(message protoreflect.Message) error { + var err error message.Range( func(descriptor protoreflect.FieldDescriptor, value protoreflect.Value) bool { if isDebugRedactEnabled(descriptor) { redactSingleField(message, descriptor) return true } - redactNestedMessage(message, descriptor, value) + errNested := redactNestedField(message, descriptor, value) + if errNested != nil { + err = errNested + return false + } return true }, ) + return err } func getRedactedInput(input *RPCAuthInput) (RPCAuthInput, error) { @@ -120,7 +164,10 @@ func getRedactedInput(input *RPCAuthInput) (RPCAuthInput, error) { if err := protojson.Unmarshal([]byte(input.Message), redactedMessage); err != nil { return RPCAuthInput{}, fmt.Errorf("could not marshal input into %v: %v", input.MessageType, err) } - redactFields(redactedMessage.ProtoReflect()) + errRedact := redactFields(redactedMessage.ProtoReflect()) + if errRedact != nil { + return RPCAuthInput{}, fmt.Errorf("failed to redact message fields: %v", errRedact) + } } marshaled, err := protojson.MarshalOptions{UseProtoNames: true}.Marshal(redactedMessage) if err != nil { diff --git a/auth/opa/rpcauth/redact_test.go b/auth/opa/rpcauth/redact_test.go index 860199e8..9eb0d183 100644 --- a/auth/opa/rpcauth/redact_test.go +++ b/auth/opa/rpcauth/redact_test.go @@ -20,10 +20,12 @@ import ( "context" "testing" + proxypb "github.com/Snowflake-Labs/sansshell/proxy" httppb "github.com/Snowflake-Labs/sansshell/services/httpoverrpc" "github.com/stretchr/testify/assert" "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/reflect/protoregistry" + "google.golang.org/protobuf/types/known/anypb" ) func TestGetRedactedInput(t *testing.T) { @@ -39,8 +41,18 @@ func TestGetRedactedInput(t *testing.T) { }, }, } - mockInput, _ := NewRPCAuthInput(context.TODO(), "/HTTPOverRPC.HTTPOverRPC/Host", httpReq.ProtoReflect().Interface()) + httpReqInput, _ := NewRPCAuthInput(context.TODO(), "/HTTPOverRPC.HTTPOverRPC/Host", httpReq.ProtoReflect().Interface()) + payload, _ := anypb.New(httpReq.ProtoReflect().Interface()) + proxyReq := &proxypb.ProxyRequest{ + Request: &proxypb.ProxyRequest_StreamData{ + StreamData: &proxypb.StreamData{ + StreamIds: []uint64{1}, + Payload: payload, + }, + }, + } + proxyReqInput, _ := NewRPCAuthInput(context.TODO(), "/Proxy.Proxy/Proxy", proxyReq.ProtoReflect().Interface()) for _, tc := range []struct { name string createInputFn func() *RPCAuthInput @@ -50,10 +62,10 @@ func TestGetRedactedInput(t *testing.T) { { name: "redacted fields should be redacted", createInputFn: func() *RPCAuthInput { - return mockInput + return httpReqInput }, assertionFn: func(result RPCAuthInput) { - messageType, _ := protoregistry.GlobalTypes.FindMessageByURL(mockInput.MessageType) + messageType, _ := protoregistry.GlobalTypes.FindMessageByURL(httpReqInput.MessageType) resultMessage := messageType.New().Interface() err := protojson.Unmarshal([]byte(result.Message), resultMessage) assert.NoError(t, err) @@ -67,6 +79,29 @@ func TestGetRedactedInput(t *testing.T) { assert.NoError(t, err) }, }, + { + name: "any containing redacted_fields should be redacted", + createInputFn: func() *RPCAuthInput { + return proxyReqInput + }, + assertionFn: func(result RPCAuthInput) { + messageType, _ := protoregistry.GlobalTypes.FindMessageByURL(proxyReqInput.MessageType) + resultMessage := messageType.New().Interface() + err := protojson.Unmarshal([]byte(result.Message), resultMessage) + assert.NoError(t, err) + + proxyReq := resultMessage.(*proxypb.ProxyRequest) + proxyReqPayload := proxyReq.GetStreamData().Payload + payloadMsg, _ := proxyReqPayload.UnmarshalNew() + httpReq := payloadMsg.(*httppb.HostHTTPRequest) + + assert.Equal(t, "--REDACTED--", httpReq.Request.Headers[0].Values[0]) // field with debug_redact should be redacted + assert.Equal(t, "key0", httpReq.Request.Headers[0].Key) // field without debug_redact should not be redacted + }, + errFunc: func(t *testing.T, err error) { + assert.NoError(t, err) + }, + }, { name: "malformed input should return err", createInputFn: func() *RPCAuthInput {