diff --git a/router/websocketserver.go b/router/websocketserver.go index d184c453..633326c9 100644 --- a/router/websocketserver.go +++ b/router/websocketserver.go @@ -86,6 +86,12 @@ type WebsocketServer struct { // Details.transport.auth.request|*http.Request making it available to // authenticator and authorizer logic. EnableRequestCapture bool + // TrackingCookieSameSiteAttribute defines which SameSite attribute will be assigned to + // the next auth cookie if EnableTrackingCookie is true + TrackingCookieSameSiteAttribute http.SameSite + // TrackingCookieSecureAttribute defines whether the TrackingCookie should be marked as + // secure and thus only sent over encrypted connections. + TrackingCookieSecureAttribute bool // KeepAlive configures a websocket "ping/pong" heartbeat when set to a // non-zero value. KeepAlive is the interval between websocket "pings". @@ -306,8 +312,10 @@ func (s *WebsocketServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { if err == nil { // Create next auth cookie with 20 byte random value. nextCookie = &http.Cookie{ - Name: cookieName, - Value: base64.URLEncoding.EncodeToString(b), + Name: cookieName, + Value: base64.URLEncoding.EncodeToString(b), + Secure: s.TrackingCookieSecureAttribute, + SameSite: s.TrackingCookieSameSiteAttribute, } http.SetCookie(w, nextCookie) authDict["nextcookie"] = nextCookie diff --git a/router/websocketserver_test.go b/router/websocketserver_test.go index 9788c72e..61c01b67 100644 --- a/router/websocketserver_test.go +++ b/router/websocketserver_test.go @@ -83,6 +83,86 @@ func TestWSHandshakeMsgpack(t *testing.T) { require.True(t, ok, "expected WELCOME") } +func b2p(b bool) *bool { return &b } +func s2p(s http.SameSite) *http.SameSite { return &s } +func TestWSCookieAttributes(t *testing.T) { + // http library treats the samesite attribute for strings as if no attribute was set + sameSiteUnsetValue := http.SameSiteDefaultMode - 1 + tests := []struct { + name string + setSameSite *http.SameSite + setSecure *bool + wantSameSite http.SameSite + wantIsSecure bool + }{ + { + name: "default settings", + wantSameSite: sameSiteUnsetValue, + wantIsSecure: false, + }, + { + name: "same site strict", + setSameSite: s2p(http.SameSiteStrictMode), + wantSameSite: http.SameSiteStrictMode, + wantIsSecure: false, + }, + { + name: "same site none", + setSameSite: s2p(http.SameSiteNoneMode), + wantSameSite: http.SameSiteNoneMode, + wantIsSecure: false, + }, + { + name: "secure is true", + setSecure: b2p(true), + wantSameSite: sameSiteUnsetValue, + wantIsSecure: true, + }, + { + name: "secure is false", + setSecure: b2p(false), + wantSameSite: sameSiteUnsetValue, + wantIsSecure: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + checkGoLeaks(t) + r, err := NewRouter(routerConfig, nil) + require.NoError(t, err) + defer r.Close() + + wss := NewWebsocketServer(r) + wss.EnableTrackingCookie = true + if tt.setSecure != nil { + wss.TrackingCookieSecureAttribute = *tt.setSecure + } + if tt.setSameSite != nil { + wss.TrackingCookieSameSiteAttribute = *tt.setSameSite + } + closer, err := wss.ListenAndServe(wsAddr) + require.NoError(t, err) + defer closer.Close() + + dialer := websocket.Dialer{ + Subprotocols: []string{jsonWebsocketProtocol, cborWebsocketProtocol, msgpackWebsocketProtocol}, + TLSClientConfig: nil, + } + conn, rsp, err := dialer.DialContext(context.Background(), fmt.Sprintf("ws://%s/", wsAddr), nil) + require.NoError(t, err) + defer conn.Close() + + for _, c := range rsp.Cookies() { + if c.Name == "nexus-wamp-cookie" { + require.Equal(t, c.SameSite, tt.wantSameSite) + require.Equal(t, c.Secure, tt.wantIsSecure) + } + } + }) + } + +} + func TestAllowOrigins(t *testing.T) { s := &WebsocketServer{ Upgrader: &websocket.Upgrader{},