Skip to content

Commit

Permalink
Redact sensitive values in stream_data rpcauth logs (#355)
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-elinardi authored Oct 17, 2023
1 parent 7b660ab commit 3e76142
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 14 deletions.
69 changes: 58 additions & 11 deletions auth/opa/rpcauth/redact.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/reflect/protoregistry"
"google.golang.org/protobuf/types/descriptorpb"
"google.golang.org/protobuf/types/known/anypb"
)

func isMessage(descriptor protoreflect.FieldDescriptor) bool {
Expand All @@ -40,28 +41,65 @@ func isDebugRedactEnabled(fd protoreflect.FieldDescriptor) bool {
return opts.GetDebugRedact()
}

func redactListField(value protoreflect.Value) {
func redactListField(value protoreflect.Value) error {
for i := 0; i < value.List().Len(); i++ {
redactFields(value.List().Get(i).Message())
errRedact := redactFields(value.List().Get(i).Message())
if errRedact != nil {
return errRedact
}
}
return nil
}

func redactMapField(value protoreflect.Value) {
func redactMapField(value protoreflect.Value) error {
var err error
value.Map().Range(func(mapKey protoreflect.MapKey, mapValue protoreflect.Value) bool {
redactFields(mapValue.Message())
errRedact := redactFields(mapValue.Message())
if errRedact != nil {
err = errRedact
return false
}
return true
})
return err
}

func redactNestedMessage(message protoreflect.Message, descriptor protoreflect.FieldDescriptor, value protoreflect.Value) {
var anypbFullName = (&anypb.Any{}).ProtoReflect().Descriptor().FullName()

func redactAny(message protoreflect.Message, descriptor protoreflect.FieldDescriptor, value protoreflect.Value) error {
anyMsg, ok := value.Message().Interface().(*anypb.Any) // cast value to redactAny
if !ok {
return fmt.Errorf("failed to cast message into any")
}
originalMsg, errUnmarshal := anyMsg.UnmarshalNew() // unmarshal any into original message
if errUnmarshal != nil {
return fmt.Errorf("failed to unmarshal anypb: %v", errUnmarshal)
}
errRedact := redactFields(originalMsg.ProtoReflect()) // redact original message
if errRedact != nil {
return errRedact
}
redactedAny, errAny := anypb.New(originalMsg) // cast redacted message back to any
if errAny != nil {
return fmt.Errorf("failed to cast into anypb: %v", errAny)
}
message.Set(descriptor, protoreflect.ValueOf(redactedAny.ProtoReflect())) // set the redacted

return nil
}

func redactNestedField(message protoreflect.Message, descriptor protoreflect.FieldDescriptor, value protoreflect.Value) error {
switch {
case descriptor.IsList() && isMessage(descriptor):
redactListField(value)
return redactListField(value)
case descriptor.IsMap() && isMessage(descriptor):
redactMapField(value)
return redactMapField(value)
case descriptor.Message() != nil && descriptor.Message().FullName() == anypbFullName:
return redactAny(message, descriptor, value)
case !descriptor.IsMap() && isMessage(descriptor):
redactFields(value.Message())
return redactFields(value.Message())
}
return nil
}

func redactSingleField(message protoreflect.Message, descriptor protoreflect.FieldDescriptor) {
Expand All @@ -80,17 +118,23 @@ func redactSingleField(message protoreflect.Message, descriptor protoreflect.Fie
}
}

func redactFields(message protoreflect.Message) {
func redactFields(message protoreflect.Message) error {
var err error
message.Range(
func(descriptor protoreflect.FieldDescriptor, value protoreflect.Value) bool {
if isDebugRedactEnabled(descriptor) {
redactSingleField(message, descriptor)
return true
}
redactNestedMessage(message, descriptor, value)
errNested := redactNestedField(message, descriptor, value)
if errNested != nil {
err = errNested
return false
}
return true
},
)
return err
}

func getRedactedInput(input *RPCAuthInput) (RPCAuthInput, error) {
Expand Down Expand Up @@ -120,7 +164,10 @@ func getRedactedInput(input *RPCAuthInput) (RPCAuthInput, error) {
if err := protojson.Unmarshal([]byte(input.Message), redactedMessage); err != nil {
return RPCAuthInput{}, fmt.Errorf("could not marshal input into %v: %v", input.MessageType, err)
}
redactFields(redactedMessage.ProtoReflect())
errRedact := redactFields(redactedMessage.ProtoReflect())
if errRedact != nil {
return RPCAuthInput{}, fmt.Errorf("failed to redact message fields: %v", errRedact)
}
}
marshaled, err := protojson.MarshalOptions{UseProtoNames: true}.Marshal(redactedMessage)
if err != nil {
Expand Down
41 changes: 38 additions & 3 deletions auth/opa/rpcauth/redact_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@ import (
"context"
"testing"

proxypb "github.com/Snowflake-Labs/sansshell/proxy"
httppb "github.com/Snowflake-Labs/sansshell/services/httpoverrpc"
"github.com/stretchr/testify/assert"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/reflect/protoregistry"
"google.golang.org/protobuf/types/known/anypb"
)

func TestGetRedactedInput(t *testing.T) {
Expand All @@ -39,8 +41,18 @@ func TestGetRedactedInput(t *testing.T) {
},
},
}
mockInput, _ := NewRPCAuthInput(context.TODO(), "/HTTPOverRPC.HTTPOverRPC/Host", httpReq.ProtoReflect().Interface())
httpReqInput, _ := NewRPCAuthInput(context.TODO(), "/HTTPOverRPC.HTTPOverRPC/Host", httpReq.ProtoReflect().Interface())

payload, _ := anypb.New(httpReq.ProtoReflect().Interface())
proxyReq := &proxypb.ProxyRequest{
Request: &proxypb.ProxyRequest_StreamData{
StreamData: &proxypb.StreamData{
StreamIds: []uint64{1},
Payload: payload,
},
},
}
proxyReqInput, _ := NewRPCAuthInput(context.TODO(), "/Proxy.Proxy/Proxy", proxyReq.ProtoReflect().Interface())
for _, tc := range []struct {
name string
createInputFn func() *RPCAuthInput
Expand All @@ -50,10 +62,10 @@ func TestGetRedactedInput(t *testing.T) {
{
name: "redacted fields should be redacted",
createInputFn: func() *RPCAuthInput {
return mockInput
return httpReqInput
},
assertionFn: func(result RPCAuthInput) {
messageType, _ := protoregistry.GlobalTypes.FindMessageByURL(mockInput.MessageType)
messageType, _ := protoregistry.GlobalTypes.FindMessageByURL(httpReqInput.MessageType)
resultMessage := messageType.New().Interface()
err := protojson.Unmarshal([]byte(result.Message), resultMessage)
assert.NoError(t, err)
Expand All @@ -67,6 +79,29 @@ func TestGetRedactedInput(t *testing.T) {
assert.NoError(t, err)
},
},
{
name: "any containing redacted_fields should be redacted",
createInputFn: func() *RPCAuthInput {
return proxyReqInput
},
assertionFn: func(result RPCAuthInput) {
messageType, _ := protoregistry.GlobalTypes.FindMessageByURL(proxyReqInput.MessageType)
resultMessage := messageType.New().Interface()
err := protojson.Unmarshal([]byte(result.Message), resultMessage)
assert.NoError(t, err)

proxyReq := resultMessage.(*proxypb.ProxyRequest)
proxyReqPayload := proxyReq.GetStreamData().Payload
payloadMsg, _ := proxyReqPayload.UnmarshalNew()
httpReq := payloadMsg.(*httppb.HostHTTPRequest)

assert.Equal(t, "--REDACTED--", httpReq.Request.Headers[0].Values[0]) // field with debug_redact should be redacted
assert.Equal(t, "key0", httpReq.Request.Headers[0].Key) // field without debug_redact should not be redacted
},
errFunc: func(t *testing.T, err error) {
assert.NoError(t, err)
},
},
{
name: "malformed input should return err",
createInputFn: func() *RPCAuthInput {
Expand Down

0 comments on commit 3e76142

Please sign in to comment.