From bf07910125d85d0790f8d7739d400aec9045ffb7 Mon Sep 17 00:00:00 2001 From: Dmitry Kolesnikov Date: Sun, 20 Oct 2024 14:32:25 +0300 Subject: [PATCH] (fix): pass IOContext from WebSocket --- broker/websocket/README.md | 2 +- broker/websocket/authorizer.go | 38 +++++++++++ broker/websocket/authorizer_test.go | 67 +++++++++++++++++++ broker/websocket/awscdk.go | 2 + .../examples/dequeue/typed/websocket.go | 2 + broker/websocket/examples/serverless/main.go | 1 + broker/websocket/go.mod | 2 +- broker/websocket/go.sum | 4 +- broker/websocket/lambda/auth/auth.go | 30 +++++++-- broker/websocket/version.go | 2 +- broker/websocket/websocket.go | 7 +- broker/websocket/websocket_test.go | 8 +++ 12 files changed, 153 insertions(+), 12 deletions(-) create mode 100644 broker/websocket/authorizer.go create mode 100644 broker/websocket/authorizer_test.go diff --git a/broker/websocket/README.md b/broker/websocket/README.md index 0ef5975..69d7958 100644 --- a/broker/websocket/README.md +++ b/broker/websocket/README.md @@ -7,6 +7,6 @@ Note the broker implements only infrastructure required for making serverless ap Use cli https://github.com/vi/websocat for testing purposes. ```bash -websocat wss://0000000000.execute-api.eu-west-1.amazonaws.com/ws/\?apikey=dGVzdDp0ZXN0 +websocat wss://0000000000.execute-api.eu-west-1.amazonaws.com/ws/\?apikey=dGVzdDp0ZXN0\&scope=test {"action":"User", "id":"xxx", "text":"some text"} ``` \ No newline at end of file diff --git a/broker/websocket/authorizer.go b/broker/websocket/authorizer.go new file mode 100644 index 0000000..0d2c311 --- /dev/null +++ b/broker/websocket/authorizer.go @@ -0,0 +1,38 @@ +package websocket + +import ( + "fmt" + + "github.com/aws/aws-lambda-go/events" + "github.com/fogfish/swarm" +) + +// Helper function to fetch principal identity from the context of websocket message +func PrincipalOf[T any](msg swarm.Msg[T]) (string, error) { + ctx, ok := msg.IOContext.(*events.APIGatewayWebsocketProxyRequestContext) + if !ok { + return "", fmt.Errorf("invalid message context: %T", msg.IOContext) + } + + // Note: Authorizer context is defined by authorizer function (see auth lambda). + // The context value depends on the auth type e.g. basic, jwt, etc). + // There are mandatory fields exists in any context: `auth`, `principalId`. + // In case of `jwt` the following fields exists: `iss`, `sub`, `exp`, `nbf`, `iat`, `scope`. + // In case of `basic` the fields `sub`, `scope` is simulated. + auth, ok := ctx.Authorizer.(map[string]any) + if !ok { + return "", fmt.Errorf("invalid authorizer context: %T", ctx.Authorizer) + } + + value, has := auth["principalId"] + if !has { + return "", fmt.Errorf("unknown principal in the authorizer context") + } + + principal, ok := value.(string) + if !ok { + return "", fmt.Errorf("invalid principal type: %T", value) + } + + return principal, nil +} diff --git a/broker/websocket/authorizer_test.go b/broker/websocket/authorizer_test.go new file mode 100644 index 0000000..0413596 --- /dev/null +++ b/broker/websocket/authorizer_test.go @@ -0,0 +1,67 @@ +package websocket + +import ( + "testing" + + "github.com/aws/aws-lambda-go/events" + "github.com/fogfish/it/v2" + "github.com/fogfish/swarm" +) + +func TestPrincipalOf(t *testing.T) { + t.Run("IOContext-None", func(t *testing.T) { + msg := swarm.Msg[string]{} + it.Then(t).Should( + it.Error(PrincipalOf(msg)).Contain("invalid message context"), + ) + }) + + t.Run("IOContext-Bad", func(t *testing.T) { + msg := swarm.Msg[string]{IOContext: "context"} + it.Then(t).Should( + it.Error(PrincipalOf(msg)).Contain("invalid message context"), + ) + }) + + t.Run("AuthContext-Bad", func(t *testing.T) { + msg := swarm.Msg[string]{IOContext: &events.APIGatewayWebsocketProxyRequestContext{ + Authorizer: "authroizer", + }} + it.Then(t).Should( + it.Error(PrincipalOf(msg)).Contain("invalid authorizer context"), + ) + }) + + t.Run("Principal-None", func(t *testing.T) { + msg := swarm.Msg[string]{IOContext: &events.APIGatewayWebsocketProxyRequestContext{ + Authorizer: map[string]any{}, + }} + it.Then(t).Should( + it.Error(PrincipalOf(msg)).Contain("unknown principal in the authorizer context"), + ) + }) + + t.Run("Principal-Bad", func(t *testing.T) { + msg := swarm.Msg[string]{IOContext: &events.APIGatewayWebsocketProxyRequestContext{ + Authorizer: map[string]any{ + "principalId": 1, + }, + }} + it.Then(t).Should( + it.Error(PrincipalOf(msg)).Contain("invalid principal type"), + ) + }) + + t.Run("Principal", func(t *testing.T) { + msg := swarm.Msg[string]{IOContext: &events.APIGatewayWebsocketProxyRequestContext{ + Authorizer: map[string]any{ + "principalId": "test", + }, + }} + val, err := PrincipalOf(msg) + it.Then(t).Should( + it.Nil(err), + it.Equal(val, "test"), + ) + }) +} diff --git a/broker/websocket/awscdk.go b/broker/websocket/awscdk.go index 14dbfe2..3be4b75 100644 --- a/broker/websocket/awscdk.go +++ b/broker/websocket/awscdk.go @@ -96,6 +96,7 @@ func NewBroker(scope constructs.Construct, id *string, props *BrokerProps) *Brok type AuthorizerApiKeyProps struct { Access string Secret string + Scope []string } func (broker *Broker) NewAuthorizerApiKey(props *AuthorizerApiKeyProps) awsapigatewayv2.IWebSocketRouteAuthorizer { @@ -115,6 +116,7 @@ func (broker *Broker) NewAuthorizerApiKey(props *AuthorizerApiKeyProps) awsapiga Environment: &map[string]*string{ "CONFIG_SWARM_WS_AUTHORIZER_ACCESS": jsii.String(props.Access), "CONFIG_SWARM_WS_AUTHORIZER_SECRET": jsii.String(props.Secret), + "CONFIG_SWARM_WS_AUTHORIZER_SCOPE": jsii.String(strings.Join(props.Scope, " ")), }, }, }, diff --git a/broker/websocket/examples/dequeue/typed/websocket.go b/broker/websocket/examples/dequeue/typed/websocket.go index 2c82290..3eab205 100644 --- a/broker/websocket/examples/dequeue/typed/websocket.go +++ b/broker/websocket/examples/dequeue/typed/websocket.go @@ -50,6 +50,8 @@ func (a *actor) handle(rcv <-chan swarm.Msg[User], ack chan<- swarm.Msg[User]) { for msg := range rcv { slog.Info("Event user", "data", msg.Object) + websocket.PrincipalOf(msg) + if err := a.emit.Enq(context.Background(), msg.Object, msg.Digest); err != nil { ack <- msg.Fail(err) continue diff --git a/broker/websocket/examples/serverless/main.go b/broker/websocket/examples/serverless/main.go index e3d903d..cc40513 100644 --- a/broker/websocket/examples/serverless/main.go +++ b/broker/websocket/examples/serverless/main.go @@ -34,6 +34,7 @@ func main() { &websocket.AuthorizerApiKeyProps{ Access: "test", Secret: "test", + Scope: []string{"test", "read"}, }, ) diff --git a/broker/websocket/go.mod b/broker/websocket/go.mod index 00aaf8d..77a29ea 100644 --- a/broker/websocket/go.mod +++ b/broker/websocket/go.mod @@ -14,7 +14,7 @@ require ( github.com/fogfish/it/v2 v2.0.2 github.com/fogfish/logger/v3 v3.1.1 github.com/fogfish/scud v0.10.2 - github.com/fogfish/swarm v0.20.1 + github.com/fogfish/swarm v0.20.2 ) require ( diff --git a/broker/websocket/go.sum b/broker/websocket/go.sum index f6f651e..caf8114 100644 --- a/broker/websocket/go.sum +++ b/broker/websocket/go.sum @@ -68,8 +68,8 @@ github.com/fogfish/logger/v3 v3.1.1 h1:awmTNpBWRvSj086H3RWIUnc+FSu9qXHJgBa49wYpN github.com/fogfish/logger/v3 v3.1.1/go.mod h1:hsucoJz/3OX90UdYrXykcKvjjteBnPcYSTr4Rie0ZqU= github.com/fogfish/scud v0.10.2 h1:cFupgZ4brqeGr/HCURnyDaBUNJIVEJTfKRwxEEUrO3w= github.com/fogfish/scud v0.10.2/go.mod h1:IVtHIfQMsb9lPKFeCI/OGcT2ssmd6onOZdpXgj/ORgs= -github.com/fogfish/swarm v0.20.1 h1:XzHkTHxgLVbctkTAcT4dVoGdp7mNKihdzz1Io6YGig0= -github.com/fogfish/swarm v0.20.1/go.mod h1:cdIviTojE3DT+FOIIOeOg6tyMqhyanfy2TZHTtKlOmo= +github.com/fogfish/swarm v0.20.2 h1:AiBs8yMw8hihFXiVCS5KaN8y5XjdUU9/aONx4mhF9/8= +github.com/fogfish/swarm v0.20.2/go.mod h1:cdIviTojE3DT+FOIIOeOg6tyMqhyanfy2TZHTtKlOmo= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= diff --git a/broker/websocket/lambda/auth/auth.go b/broker/websocket/lambda/auth/auth.go index c448fbd..afbd8b8 100644 --- a/broker/websocket/lambda/auth/auth.go +++ b/broker/websocket/lambda/auth/auth.go @@ -57,7 +57,8 @@ func main() { } if basic != nil { - principal, context, err := basic.Validate(tkn) + scope := evt.QueryStringParameters["scope"] + principal, context, err := basic.Validate(tkn, scope) if err != nil { return None, ErrForbidden } @@ -98,11 +99,15 @@ func AccessPolicy(principal, method string, context map[string]any) events.APIGa //------------------------------------------------------------------------------ -type AuthBasic struct{ access, secret string } +type AuthBasic struct { + access, secret string + scope []string +} func NewAuthBasic() (*AuthBasic, error) { access := os.Getenv("CONFIG_SWARM_WS_AUTHORIZER_ACCESS") secret := os.Getenv("CONFIG_SWARM_WS_AUTHORIZER_SECRET") + scope := os.Getenv("CONFIG_SWARM_WS_AUTHORIZER_SCOPE") if access == "" || secret == "" { return nil, errors.New("basic auth is not configured") @@ -111,10 +116,11 @@ func NewAuthBasic() (*AuthBasic, error) { return &AuthBasic{ access: access, secret: secret, + scope: strings.Split(scope, " "), }, nil } -func (auth *AuthBasic) Validate(apikey string) (string, map[string]any, error) { +func (auth *AuthBasic) Validate(apikey, scope string) (string, map[string]any, error) { c, err := base64.RawStdEncoding.DecodeString(apikey) if err != nil { return "", nil, ErrForbidden @@ -125,6 +131,22 @@ func (auth *AuthBasic) Validate(apikey string) (string, map[string]any, error) { return "", nil, ErrForbidden } + seq, err := url.QueryUnescape(scope) + if err != nil { + return "", nil, ErrForbidden + } + for _, sid := range strings.Split(seq, " ") { + has := false + for _, allowed := range auth.scope { + if allowed == sid { + has = true + } + } + if !has { + return "", nil, ErrForbidden + } + } + gaccess := sha256.Sum256([]byte(access)) gsecret := sha256.Sum256([]byte(secret)) haccess := sha256.Sum256([]byte(auth.access)) @@ -134,7 +156,7 @@ func (auth *AuthBasic) Validate(apikey string) (string, map[string]any, error) { secretMatch := (subtle.ConstantTimeCompare(gsecret[:], hsecret[:]) == 1) if accessMatch && secretMatch { - return access, map[string]any{"auth": "basic"}, nil + return access, map[string]any{"auth": "basic", "sub": access, "scope": scope}, nil } return "", nil, ErrForbidden diff --git a/broker/websocket/version.go b/broker/websocket/version.go index 7c12ee3..1232328 100644 --- a/broker/websocket/version.go +++ b/broker/websocket/version.go @@ -8,4 +8,4 @@ package websocket -const Version = "broker/websocket/v0.20.0" +const Version = "broker/websocket/v0.20.1" diff --git a/broker/websocket/websocket.go b/broker/websocket/websocket.go index 3743b4b..6552021 100644 --- a/broker/websocket/websocket.go +++ b/broker/websocket/websocket.go @@ -134,9 +134,10 @@ func (s bridge) Run() { lambda.Start(s.run) } func (s bridge) run(evt events.APIGatewayWebsocketProxyRequest) (events.APIGatewayProxyResponse, error) { bag := make([]swarm.Bag, 1) bag[0] = swarm.Bag{ - Category: evt.RequestContext.RouteKey, - Digest: evt.RequestContext.ConnectionID, - Object: []byte(evt.Body), + Category: evt.RequestContext.RouteKey, + Digest: evt.RequestContext.ConnectionID, + IOContext: &evt.RequestContext, + Object: []byte(evt.Body), } if err := s.Bridge.Dispatch(bag); err != nil { diff --git a/broker/websocket/websocket_test.go b/broker/websocket/websocket_test.go index b804619..027700a 100644 --- a/broker/websocket/websocket_test.go +++ b/broker/websocket/websocket_test.go @@ -58,6 +58,14 @@ func TestDequeuer(t *testing.T) { it.Equal(bag[0].Category, "test"), it.Equal(bag[0].Digest, "digest"), it.Equiv(bag[0].Object, []byte(`{"sut":"test"}`)), + ).ShouldNot( + it.Nil(bag[0].IOContext), + ) + + ctx := bag[0].IOContext.(*events.APIGatewayWebsocketProxyRequestContext) + it.Then(t).Should( + it.Equal(ctx.RouteKey, "test"), + it.Equal(ctx.ConnectionID, "digest"), ) }) }