Skip to content

Commit

Permalink
(fix): pass IOContext from WebSocket
Browse files Browse the repository at this point in the history
  • Loading branch information
fogfish committed Oct 20, 2024
1 parent 6d34424 commit bf07910
Show file tree
Hide file tree
Showing 12 changed files with 153 additions and 12 deletions.
2 changes: 1 addition & 1 deletion broker/websocket/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,6 @@ Note the broker implements only infrastructure required for making serverless ap
Use cli https://github.com/vi/websocat for testing purposes.

```bash
websocat wss://0000000000.execute-api.eu-west-1.amazonaws.com/ws/\?apikey=dGVzdDp0ZXN0
websocat wss://0000000000.execute-api.eu-west-1.amazonaws.com/ws/\?apikey=dGVzdDp0ZXN0\&scope=test
{"action":"User", "id":"xxx", "text":"some text"}
```
38 changes: 38 additions & 0 deletions broker/websocket/authorizer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package websocket

import (
"fmt"

"github.com/aws/aws-lambda-go/events"
"github.com/fogfish/swarm"
)

// Helper function to fetch principal identity from the context of websocket message
func PrincipalOf[T any](msg swarm.Msg[T]) (string, error) {
ctx, ok := msg.IOContext.(*events.APIGatewayWebsocketProxyRequestContext)
if !ok {
return "", fmt.Errorf("invalid message context: %T", msg.IOContext)
}

// Note: Authorizer context is defined by authorizer function (see auth lambda).
// The context value depends on the auth type e.g. basic, jwt, etc).
// There are mandatory fields exists in any context: `auth`, `principalId`.
// In case of `jwt` the following fields exists: `iss`, `sub`, `exp`, `nbf`, `iat`, `scope`.
// In case of `basic` the fields `sub`, `scope` is simulated.
auth, ok := ctx.Authorizer.(map[string]any)
if !ok {
return "", fmt.Errorf("invalid authorizer context: %T", ctx.Authorizer)
}

value, has := auth["principalId"]
if !has {
return "", fmt.Errorf("unknown principal in the authorizer context")
}

principal, ok := value.(string)
if !ok {
return "", fmt.Errorf("invalid principal type: %T", value)
}

return principal, nil
}
67 changes: 67 additions & 0 deletions broker/websocket/authorizer_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package websocket

import (
"testing"

"github.com/aws/aws-lambda-go/events"
"github.com/fogfish/it/v2"
"github.com/fogfish/swarm"
)

func TestPrincipalOf(t *testing.T) {
t.Run("IOContext-None", func(t *testing.T) {
msg := swarm.Msg[string]{}
it.Then(t).Should(
it.Error(PrincipalOf(msg)).Contain("invalid message context"),
)
})

t.Run("IOContext-Bad", func(t *testing.T) {
msg := swarm.Msg[string]{IOContext: "context"}
it.Then(t).Should(
it.Error(PrincipalOf(msg)).Contain("invalid message context"),
)
})

t.Run("AuthContext-Bad", func(t *testing.T) {
msg := swarm.Msg[string]{IOContext: &events.APIGatewayWebsocketProxyRequestContext{
Authorizer: "authroizer",
}}
it.Then(t).Should(
it.Error(PrincipalOf(msg)).Contain("invalid authorizer context"),
)
})

t.Run("Principal-None", func(t *testing.T) {
msg := swarm.Msg[string]{IOContext: &events.APIGatewayWebsocketProxyRequestContext{
Authorizer: map[string]any{},
}}
it.Then(t).Should(
it.Error(PrincipalOf(msg)).Contain("unknown principal in the authorizer context"),
)
})

t.Run("Principal-Bad", func(t *testing.T) {
msg := swarm.Msg[string]{IOContext: &events.APIGatewayWebsocketProxyRequestContext{
Authorizer: map[string]any{
"principalId": 1,
},
}}
it.Then(t).Should(
it.Error(PrincipalOf(msg)).Contain("invalid principal type"),
)
})

t.Run("Principal", func(t *testing.T) {
msg := swarm.Msg[string]{IOContext: &events.APIGatewayWebsocketProxyRequestContext{
Authorizer: map[string]any{
"principalId": "test",
},
}}
val, err := PrincipalOf(msg)
it.Then(t).Should(
it.Nil(err),
it.Equal(val, "test"),
)
})
}
2 changes: 2 additions & 0 deletions broker/websocket/awscdk.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ func NewBroker(scope constructs.Construct, id *string, props *BrokerProps) *Brok
type AuthorizerApiKeyProps struct {
Access string
Secret string
Scope []string
}

func (broker *Broker) NewAuthorizerApiKey(props *AuthorizerApiKeyProps) awsapigatewayv2.IWebSocketRouteAuthorizer {
Expand All @@ -115,6 +116,7 @@ func (broker *Broker) NewAuthorizerApiKey(props *AuthorizerApiKeyProps) awsapiga
Environment: &map[string]*string{
"CONFIG_SWARM_WS_AUTHORIZER_ACCESS": jsii.String(props.Access),
"CONFIG_SWARM_WS_AUTHORIZER_SECRET": jsii.String(props.Secret),
"CONFIG_SWARM_WS_AUTHORIZER_SCOPE": jsii.String(strings.Join(props.Scope, " ")),
},
},
},
Expand Down
2 changes: 2 additions & 0 deletions broker/websocket/examples/dequeue/typed/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ func (a *actor) handle(rcv <-chan swarm.Msg[User], ack chan<- swarm.Msg[User]) {
for msg := range rcv {
slog.Info("Event user", "data", msg.Object)

websocket.PrincipalOf(msg)

if err := a.emit.Enq(context.Background(), msg.Object, msg.Digest); err != nil {
ack <- msg.Fail(err)
continue
Expand Down
1 change: 1 addition & 0 deletions broker/websocket/examples/serverless/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ func main() {
&websocket.AuthorizerApiKeyProps{
Access: "test",
Secret: "test",
Scope: []string{"test", "read"},
},
)

Expand Down
2 changes: 1 addition & 1 deletion broker/websocket/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ require (
github.com/fogfish/it/v2 v2.0.2
github.com/fogfish/logger/v3 v3.1.1
github.com/fogfish/scud v0.10.2
github.com/fogfish/swarm v0.20.1
github.com/fogfish/swarm v0.20.2
)

require (
Expand Down
4 changes: 2 additions & 2 deletions broker/websocket/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ github.com/fogfish/logger/v3 v3.1.1 h1:awmTNpBWRvSj086H3RWIUnc+FSu9qXHJgBa49wYpN
github.com/fogfish/logger/v3 v3.1.1/go.mod h1:hsucoJz/3OX90UdYrXykcKvjjteBnPcYSTr4Rie0ZqU=
github.com/fogfish/scud v0.10.2 h1:cFupgZ4brqeGr/HCURnyDaBUNJIVEJTfKRwxEEUrO3w=
github.com/fogfish/scud v0.10.2/go.mod h1:IVtHIfQMsb9lPKFeCI/OGcT2ssmd6onOZdpXgj/ORgs=
github.com/fogfish/swarm v0.20.1 h1:XzHkTHxgLVbctkTAcT4dVoGdp7mNKihdzz1Io6YGig0=
github.com/fogfish/swarm v0.20.1/go.mod h1:cdIviTojE3DT+FOIIOeOg6tyMqhyanfy2TZHTtKlOmo=
github.com/fogfish/swarm v0.20.2 h1:AiBs8yMw8hihFXiVCS5KaN8y5XjdUU9/aONx4mhF9/8=
github.com/fogfish/swarm v0.20.2/go.mod h1:cdIviTojE3DT+FOIIOeOg6tyMqhyanfy2TZHTtKlOmo=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
Expand Down
30 changes: 26 additions & 4 deletions broker/websocket/lambda/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ func main() {
}

if basic != nil {
principal, context, err := basic.Validate(tkn)
scope := evt.QueryStringParameters["scope"]
principal, context, err := basic.Validate(tkn, scope)
if err != nil {
return None, ErrForbidden
}
Expand Down Expand Up @@ -98,11 +99,15 @@ func AccessPolicy(principal, method string, context map[string]any) events.APIGa

//------------------------------------------------------------------------------

type AuthBasic struct{ access, secret string }
type AuthBasic struct {
access, secret string
scope []string
}

func NewAuthBasic() (*AuthBasic, error) {
access := os.Getenv("CONFIG_SWARM_WS_AUTHORIZER_ACCESS")
secret := os.Getenv("CONFIG_SWARM_WS_AUTHORIZER_SECRET")
scope := os.Getenv("CONFIG_SWARM_WS_AUTHORIZER_SCOPE")

if access == "" || secret == "" {
return nil, errors.New("basic auth is not configured")
Expand All @@ -111,10 +116,11 @@ func NewAuthBasic() (*AuthBasic, error) {
return &AuthBasic{
access: access,
secret: secret,
scope: strings.Split(scope, " "),
}, nil
}

func (auth *AuthBasic) Validate(apikey string) (string, map[string]any, error) {
func (auth *AuthBasic) Validate(apikey, scope string) (string, map[string]any, error) {
c, err := base64.RawStdEncoding.DecodeString(apikey)
if err != nil {
return "", nil, ErrForbidden
Expand All @@ -125,6 +131,22 @@ func (auth *AuthBasic) Validate(apikey string) (string, map[string]any, error) {
return "", nil, ErrForbidden
}

seq, err := url.QueryUnescape(scope)
if err != nil {
return "", nil, ErrForbidden
}
for _, sid := range strings.Split(seq, " ") {
has := false
for _, allowed := range auth.scope {
if allowed == sid {
has = true
}
}
if !has {
return "", nil, ErrForbidden
}
}

gaccess := sha256.Sum256([]byte(access))
gsecret := sha256.Sum256([]byte(secret))
haccess := sha256.Sum256([]byte(auth.access))
Expand All @@ -134,7 +156,7 @@ func (auth *AuthBasic) Validate(apikey string) (string, map[string]any, error) {
secretMatch := (subtle.ConstantTimeCompare(gsecret[:], hsecret[:]) == 1)

if accessMatch && secretMatch {
return access, map[string]any{"auth": "basic"}, nil
return access, map[string]any{"auth": "basic", "sub": access, "scope": scope}, nil
}

return "", nil, ErrForbidden
Expand Down
2 changes: 1 addition & 1 deletion broker/websocket/version.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@

package websocket

const Version = "broker/websocket/v0.20.0"
const Version = "broker/websocket/v0.20.1"
7 changes: 4 additions & 3 deletions broker/websocket/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,9 +134,10 @@ func (s bridge) Run() { lambda.Start(s.run) }
func (s bridge) run(evt events.APIGatewayWebsocketProxyRequest) (events.APIGatewayProxyResponse, error) {
bag := make([]swarm.Bag, 1)
bag[0] = swarm.Bag{
Category: evt.RequestContext.RouteKey,
Digest: evt.RequestContext.ConnectionID,
Object: []byte(evt.Body),
Category: evt.RequestContext.RouteKey,
Digest: evt.RequestContext.ConnectionID,
IOContext: &evt.RequestContext,
Object: []byte(evt.Body),
}

if err := s.Bridge.Dispatch(bag); err != nil {
Expand Down
8 changes: 8 additions & 0 deletions broker/websocket/websocket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,14 @@ func TestDequeuer(t *testing.T) {
it.Equal(bag[0].Category, "test"),
it.Equal(bag[0].Digest, "digest"),
it.Equiv(bag[0].Object, []byte(`{"sut":"test"}`)),
).ShouldNot(
it.Nil(bag[0].IOContext),
)

ctx := bag[0].IOContext.(*events.APIGatewayWebsocketProxyRequestContext)
it.Then(t).Should(
it.Equal(ctx.RouteKey, "test"),
it.Equal(ctx.ConnectionID, "digest"),
)
})
}
Expand Down

0 comments on commit bf07910

Please sign in to comment.