diff --git a/client/client.go b/client/client.go index feb140e002c..ddeb52c0c26 100644 --- a/client/client.go +++ b/client/client.go @@ -200,6 +200,15 @@ type Client struct { Metadata sqlxx.JSONRawMessage `json:"metadata,omitempty" db:"metadata"` } +type AuthenticatedClient struct { + ClientID string `json:"client_id" db:"id"` + FrontChannelLogoutURI string `json:"frontchannel_logout_uri,omitempty" db:"frontchannel_logout_uri"` + FrontChannelLogoutSessionRequired bool `json:"frontchannel_logout_session_required,omitempty" db:"frontchannel_logout_session_required"` + BackChannelLogoutURI string `json:"backchannel_logout_uri,omitempty" db:"backchannel_logout_uri"` + BackChannelLogoutSessionRequired bool `json:"backchannel_logout_session_required,omitempty" db:"backchannel_logout_session_required"` + LoginSessionID string `json:"login_session_id,omitempty" db:"login_session_id"` +} + func (Client) TableName() string { return "hydra_client" } diff --git a/consent/handler.go b/consent/handler.go index 78a5897d5da..2d5717589e4 100644 --- a/consent/handler.go +++ b/consent/handler.go @@ -102,6 +102,9 @@ func (h *Handler) SetRoutes(admin *x.RouterAdmin) { func (h *Handler) DeleteConsentSession(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { subject := r.URL.Query().Get("subject") client := r.URL.Query().Get("client") + loginSessionId := r.URL.Query().Get("login_session_id") + triggerBackChannelLogout := r.URL.Query().Get("trigger_backchannel_logout") + allClients := r.URL.Query().Get("all") == "true" if subject == "" { h.r.Writer().WriteError(w, r, errorsx.WithStack(fosite.ErrInvalidRequest.WithHint(`Query parameter 'subject' is not defined but should have been.`))) @@ -110,11 +113,33 @@ func (h *Handler) DeleteConsentSession(w http.ResponseWriter, r *http.Request, p switch { case len(client) > 0: - if err := h.r.ConsentManager().RevokeSubjectClientConsentSession(r.Context(), subject, client); err != nil && !errors.Is(err, x.ErrNotFound) { - h.r.Writer().WriteError(w, r, err) - return + if len(loginSessionId) > 0 { + if triggerBackChannelLogout == "true" { + if err := h.r.ConsentStrategy().ExecuteBackChannelLogoutByClientSession(r.Context(), r, subject, client, loginSessionId); err != nil { + h.r.Logger().WithError(err).Warn("Unable to execute back channel logout") + } + } + if err := h.r.ConsentManager().RevokeSubjectClientLoginSessionConsentSession(r.Context(), subject, client, loginSessionId); err != nil && !errors.Is(err, x.ErrNotFound) { + h.r.Writer().WriteError(w, r, err) + return + } + } else { + if triggerBackChannelLogout == "true" { + if err := h.r.ConsentStrategy().ExecuteBackChannelLogoutByClient(r.Context(), r, subject, client); err != nil { + h.r.Logger().WithError(err).Warn("Unable to execute back channel logout") + } + } + if err := h.r.ConsentManager().RevokeSubjectClientConsentSession(r.Context(), subject, client); err != nil && !errors.Is(err, x.ErrNotFound) { + h.r.Writer().WriteError(w, r, err) + return + } } case allClients: + if triggerBackChannelLogout == "true" { + if err := h.r.ConsentStrategy().ExecuteBackChannelLogoutBySubject(r.Context(), r, subject); err != nil { + h.r.Logger().WithError(err).Warn("Unable to execute back channel logout") + } + } if err := h.r.ConsentManager().RevokeSubjectConsentSession(r.Context(), subject); err != nil && !errors.Is(err, x.ErrNotFound) { h.r.Writer().WriteError(w, r, err) return diff --git a/consent/manager.go b/consent/manager.go index f0fa286050b..5d230018d71 100644 --- a/consent/manager.go +++ b/consent/manager.go @@ -43,6 +43,7 @@ type Manager interface { HandleConsentRequest(ctx context.Context, challenge string, r *HandledConsentRequest) (*ConsentRequest, error) RevokeSubjectConsentSession(ctx context.Context, user string) error RevokeSubjectClientConsentSession(ctx context.Context, user, client string) error + RevokeSubjectClientLoginSessionConsentSession(ctx context.Context, user, client, loginSessionId string) error VerifyAndInvalidateConsentRequest(ctx context.Context, verifier string) (*HandledConsentRequest, error) FindGrantedAndRememberedConsentRequests(ctx context.Context, client, user string) ([]HandledConsentRequest, error) @@ -64,8 +65,9 @@ type Manager interface { CreateForcedObfuscatedLoginSession(ctx context.Context, session *ForcedObfuscatedLoginSession) error GetForcedObfuscatedLoginSession(ctx context.Context, client, obfuscated string) (*ForcedObfuscatedLoginSession, error) - ListUserAuthenticatedClientsWithFrontChannelLogout(ctx context.Context, subject, sid string) ([]client.Client, error) - ListUserAuthenticatedClientsWithBackChannelLogout(ctx context.Context, subject, sid string) ([]client.Client, error) + ListUserSessionAuthenticatedClientsWithFrontChannelLogout(ctx context.Context, subject, sid string) ([]client.AuthenticatedClient, error) + ListUserSessionAuthenticatedClientsWithBackChannelLogout(ctx context.Context, subject, sid string) ([]client.AuthenticatedClient, error) + ListUserAuthenticatedClientsWithBackChannelLogout(ctx context.Context, subject string) ([]client.AuthenticatedClient, error) CreateLogoutRequest(ctx context.Context, request *LogoutRequest) error GetLogoutRequest(ctx context.Context, challenge string) (*LogoutRequest, error) diff --git a/consent/manager_test_helpers.go b/consent/manager_test_helpers.go index 960fd50d865..e852d51f16b 100644 --- a/consent/manager_test_helpers.go +++ b/consent/manager_test_helpers.go @@ -736,7 +736,7 @@ func ManagerTests(m Manager, clientManager client.Manager, fositeManager x.Fosit } for _, ls := range sessions { - check := func(t *testing.T, expected map[string][]client.Client, actual []client.Client) { + check := func(t *testing.T, expected map[string][]client.Client, actual []client.AuthenticatedClient) { es, ok := expected[ls.ID] if !ok { require.Len(t, actual, 0) @@ -747,10 +747,10 @@ func ManagerTests(m Manager, clientManager client.Manager, fositeManager x.Fosit for _, e := range es { var found bool for _, a := range actual { - if e.OutfacingID == a.OutfacingID { + if e.OutfacingID == a.ClientID { found = true } - assert.Equal(t, e.OutfacingID, a.OutfacingID) + assert.Equal(t, e.OutfacingID, a.ClientID) assert.Equal(t, e.FrontChannelLogoutURI, a.FrontChannelLogoutURI) assert.Equal(t, e.BackChannelLogoutURI, a.BackChannelLogoutURI) } @@ -758,14 +758,14 @@ func ManagerTests(m Manager, clientManager client.Manager, fositeManager x.Fosit } } - t.Run(fmt.Sprintf("method=ListUserAuthenticatedClientsWithFrontChannelLogout/session=%s/subject=%s", ls.ID, ls.Subject), func(t *testing.T) { - actual, err := m.ListUserAuthenticatedClientsWithFrontChannelLogout(context.Background(), ls.Subject, ls.ID) + t.Run(fmt.Sprintf("method=ListUserSessionAuthenticatedClientsWithFrontChannelLogout/session=%s/subject=%s", ls.ID, ls.Subject), func(t *testing.T) { + actual, err := m.ListUserSessionAuthenticatedClientsWithFrontChannelLogout(context.Background(), ls.Subject, ls.ID) require.NoError(t, err) check(t, frontChannels, actual) }) - t.Run(fmt.Sprintf("method=ListUserAuthenticatedClientsWithBackChannelLogout/session=%s", ls.ID), func(t *testing.T) { - actual, err := m.ListUserAuthenticatedClientsWithBackChannelLogout(context.Background(), ls.Subject, ls.ID) + t.Run(fmt.Sprintf("method=ListUserSessionAuthenticatedClientsWithBackChannelLogout/session=%s", ls.ID), func(t *testing.T) { + actual, err := m.ListUserSessionAuthenticatedClientsWithBackChannelLogout(context.Background(), ls.Subject, ls.ID) require.NoError(t, err) check(t, backChannels, actual) }) diff --git a/consent/strategy.go b/consent/strategy.go index fa2f9eebfed..b8d62925e08 100644 --- a/consent/strategy.go +++ b/consent/strategy.go @@ -21,6 +21,7 @@ package consent import ( + "context" "net/http" "github.com/ory/fosite" @@ -31,4 +32,8 @@ var _ Strategy = new(DefaultStrategy) type Strategy interface { HandleOAuth2AuthorizationRequest(w http.ResponseWriter, r *http.Request, req fosite.AuthorizeRequester) (*HandledConsentRequest, error) HandleOpenIDConnectLogout(w http.ResponseWriter, r *http.Request) (*LogoutResult, error) + ExecuteBackChannelLogoutBySubject(ctx context.Context, r *http.Request, subject string) error + ExecuteBackChannelLogoutBySession(ctx context.Context, r *http.Request, subject, sid string) error + ExecuteBackChannelLogoutByClient(ctx context.Context, r *http.Request, subject, client string) error + ExecuteBackChannelLogoutByClientSession(ctx context.Context, r *http.Request, subject, client, sid string) error } diff --git a/consent/strategy_default.go b/consent/strategy_default.go index 6f67951fc05..bd5900c5be0 100644 --- a/consent/strategy_default.go +++ b/consent/strategy_default.go @@ -632,7 +632,7 @@ func (s *DefaultStrategy) verifyConsent(w http.ResponseWriter, r *http.Request, } func (s *DefaultStrategy) generateFrontChannelLogoutURLs(ctx context.Context, subject, sid string) ([]string, error) { - clients, err := s.r.ConsentManager().ListUserAuthenticatedClientsWithFrontChannelLogout(ctx, subject, sid) + clients, err := s.r.ConsentManager().ListUserSessionAuthenticatedClientsWithFrontChannelLogout(ctx, subject, sid) if err != nil { return nil, err } @@ -653,12 +653,49 @@ func (s *DefaultStrategy) generateFrontChannelLogoutURLs(ctx context.Context, su return urls, nil } -func (s *DefaultStrategy) executeBackChannelLogout(ctx context.Context, r *http.Request, subject, sid string) error { - clients, err := s.r.ConsentManager().ListUserAuthenticatedClientsWithBackChannelLogout(ctx, subject, sid) +func (s *DefaultStrategy) ExecuteBackChannelLogoutBySession(ctx context.Context, r *http.Request, subject, sid string) error { + clients, err := s.r.ConsentManager().ListUserSessionAuthenticatedClientsWithBackChannelLogout(ctx, subject, sid) if err != nil { return err } + return s.executeBackChannelLogout(ctx, r, clients) +} +func (s *DefaultStrategy) ExecuteBackChannelLogoutByClientSession(ctx context.Context, r *http.Request, subject, client, sid string) error { + clients, err := s.r.ConsentManager().ListUserSessionAuthenticatedClientsWithBackChannelLogout(ctx, subject, sid) + if err != nil { + return err + } + for i := len(clients) - 1; i >= 0; i-- { + if clients[i].ClientID != client { + clients = append(clients[:i], clients[i+1:]...) + } + } + return s.executeBackChannelLogout(ctx, r, clients) +} + +func (s *DefaultStrategy) ExecuteBackChannelLogoutByClient(ctx context.Context, r *http.Request, subject, client string) error { + clients, err := s.r.ConsentManager().ListUserAuthenticatedClientsWithBackChannelLogout(ctx, subject) + if err != nil { + return err + } + for i := len(clients) - 1; i >= 0; i-- { + if clients[i].ClientID != client { + clients = append(clients[:i], clients[i+1:]...) + } + } + return s.executeBackChannelLogout(ctx, r, clients) +} + +func (s *DefaultStrategy) ExecuteBackChannelLogoutBySubject(ctx context.Context, r *http.Request, subject string) error { + clients, err := s.r.ConsentManager().ListUserAuthenticatedClientsWithBackChannelLogout(ctx, subject) + if err != nil { + return err + } + return s.executeBackChannelLogout(ctx, r, clients) +} + +func (s *DefaultStrategy) executeBackChannelLogout(ctx context.Context, r *http.Request, clients []client.AuthenticatedClient) error { openIDKeyID, err := s.r.OpenIDJWTStrategy().GetPublicKeyID(ctx) if err != nil { return err @@ -678,14 +715,13 @@ func (s *DefaultStrategy) executeBackChannelLogout(ctx context.Context, r *http. // // s.r.ConsentManager().GetForcedObfuscatedLoginSession(context.Background(), subject, ) // sub := s.obfuscateSubjectIdentifier(c, subject, ) - t, _, err := s.r.OpenIDJWTStrategy().Generate(ctx, jwtgo.MapClaims{ "iss": s.c.IssuerURL().String(), - "aud": []string{c.OutfacingID}, + "aud": []string{c.ClientID}, "iat": time.Now().UTC().Unix(), "jti": uuid.New(), "events": map[string]struct{}{"http://schemas.openid.net/event/backchannel-logout": {}}, - "sid": sid, + "sid": c.LoginSessionID, }, &jwt.Headers{ Extra: map[string]interface{}{"kid": openIDKeyID}, }) @@ -693,7 +729,7 @@ func (s *DefaultStrategy) executeBackChannelLogout(ctx context.Context, r *http. return err } - tasks = append(tasks, task{url: c.BackChannelLogoutURI, clientID: c.OutfacingID, token: t}) + tasks = append(tasks, task{url: c.BackChannelLogoutURI, clientID: c.ClientID, token: t}) } var wg sync.WaitGroup @@ -964,7 +1000,7 @@ func (s *DefaultStrategy) completeLogout(w http.ResponseWriter, r *http.Request) return nil, err } - if err := s.executeBackChannelLogout(r.Context(), r, lr.Subject, lr.SessionID); err != nil { + if err := s.ExecuteBackChannelLogoutBySession(r.Context(), r, lr.Subject, lr.SessionID); err != nil { return nil, err } diff --git a/oauth2/oauth2_helper_test.go b/oauth2/oauth2_helper_test.go index 4fb665eabbf..23188ed3635 100644 --- a/oauth2/oauth2_helper_test.go +++ b/oauth2/oauth2_helper_test.go @@ -21,6 +21,7 @@ package oauth2_test import ( + "context" "net/http" "time" @@ -62,3 +63,19 @@ func (c *consentMock) HandleOAuth2AuthorizationRequest(w http.ResponseWriter, r func (c *consentMock) HandleOpenIDConnectLogout(w http.ResponseWriter, r *http.Request) (*consent.LogoutResult, error) { panic("not implemented") } + +func (c *consentMock) ExecuteBackChannelLogoutBySession(ctx context.Context, r *http.Request, subject, sid string) error { + panic("not implemented") +} + +func (c *consentMock) ExecuteBackChannelLogoutByClientSession(ctx context.Context, r *http.Request, subject, client, sid string) error { + panic("not implemented") +} + +func (c *consentMock) ExecuteBackChannelLogoutByClient(ctx context.Context, r *http.Request, subject, client string) error { + panic("not implemented") +} + +func (c *consentMock) ExecuteBackChannelLogoutBySubject(ctx context.Context, r *http.Request, subject string) error { + panic("not implemented") +} diff --git a/persistence/sql/persister_consent.go b/persistence/sql/persister_consent.go index b20dbaf7847..d72de569f9b 100644 --- a/persistence/sql/persister_consent.go +++ b/persistence/sql/persister_consent.go @@ -6,6 +6,8 @@ import ( "fmt" "time" + "github.com/ory/hydra/client" + "github.com/ory/x/sqlxx" "github.com/ory/x/errorsx" @@ -14,7 +16,6 @@ import ( "github.com/pkg/errors" "github.com/ory/fosite" - "github.com/ory/hydra/client" "github.com/ory/hydra/consent" "github.com/ory/hydra/x" "github.com/ory/x/sqlcon" @@ -30,6 +31,10 @@ func (p *Persister) RevokeSubjectClientConsentSession(ctx context.Context, user, return p.transaction(ctx, p.revokeConsentSession("r.subject = ? AND r.client_id = ?", user, client)) } +func (p *Persister) RevokeSubjectClientLoginSessionConsentSession(ctx context.Context, user, client, loginSessionId string) error { + return p.transaction(ctx, p.revokeConsentSession("r.subject = ? AND r.client_id = ? AND r.login_session_id = ?", user, client, loginSessionId)) +} + func (p *Persister) revokeConsentSession(whereStmt string, whereArgs ...interface{}) func(context.Context, *pop.Connection) error { return func(ctx context.Context, c *pop.Connection) error { hrs := make([]*consent.HandledConsentRequest, 0) @@ -363,20 +368,24 @@ func (p *Persister) resolveHandledConsentRequests(ctx context.Context, requests return result, nil } -func (p *Persister) ListUserAuthenticatedClientsWithFrontChannelLogout(ctx context.Context, subject, sid string) ([]client.Client, error) { - return p.listUserAuthenticatedClients(ctx, subject, sid, "front") +func (p *Persister) ListUserSessionAuthenticatedClientsWithFrontChannelLogout(ctx context.Context, subject, sid string) ([]client.AuthenticatedClient, error) { + return p.listUserSessionAuthenticatedClients(ctx, subject, sid, "front") } -func (p *Persister) ListUserAuthenticatedClientsWithBackChannelLogout(ctx context.Context, subject, sid string) ([]client.Client, error) { - return p.listUserAuthenticatedClients(ctx, subject, sid, "back") +func (p *Persister) ListUserSessionAuthenticatedClientsWithBackChannelLogout(ctx context.Context, subject, sid string) ([]client.AuthenticatedClient, error) { + return p.listUserSessionAuthenticatedClients(ctx, subject, sid, "back") } -func (p *Persister) listUserAuthenticatedClients(ctx context.Context, subject, sid, channel string) ([]client.Client, error) { - var cs []client.Client - return cs, p.transaction(ctx, func(ctx context.Context, c *pop.Connection) error { +func (p *Persister) ListUserAuthenticatedClientsWithBackChannelLogout(ctx context.Context, subject string) ([]client.AuthenticatedClient, error) { + return p.listUserAuthenticatedClients(ctx, subject, "back") +} + +func (p *Persister) listUserSessionAuthenticatedClients(ctx context.Context, subject, sid, channel string) ([]client.AuthenticatedClient, error) { + var cs []client.AuthenticatedClient + err := p.transaction(ctx, func(ctx context.Context, c *pop.Connection) error { if err := c.RawQuery( /* #nosec G201 - channel can either be "front" or "back" */ - fmt.Sprintf(`SELECT DISTINCT c.* FROM hydra_client as c JOIN hydra_oauth2_consent_request as r ON (c.id = r.client_id) WHERE r.subject=? AND c.%schannel_logout_uri!='' AND c.%schannel_logout_uri IS NOT NULL AND r.login_session_id = ?`, + fmt.Sprintf(`SELECT DISTINCT c.id, c.frontchannel_logout_uri, c.frontchannel_logout_session_required, c.backchannel_logout_uri, c.backchannel_logout_session_required FROM hydra_client as c JOIN hydra_oauth2_consent_request as r ON (c.id = r.client_id) WHERE r.subject=? AND c.%schannel_logout_uri!='' AND c.%schannel_logout_uri IS NOT NULL AND r.login_session_id = ?`, channel, channel, ), @@ -386,6 +395,32 @@ func (p *Persister) listUserAuthenticatedClients(ctx context.Context, subject, s return sqlcon.HandleError(err) } + return nil + }) + if err != nil { + return nil, err + } + + for i := range cs { + cs[i].LoginSessionID = sid + } + return cs, err +} + +func (p *Persister) listUserAuthenticatedClients(ctx context.Context, subject, channel string) ([]client.AuthenticatedClient, error) { + var cs []client.AuthenticatedClient + return cs, p.transaction(ctx, func(ctx context.Context, c *pop.Connection) error { + if err := c.RawQuery( + /* #nosec G201 - channel can either be "front" or "back" */ + fmt.Sprintf(`SELECT DISTINCT c.id, c.frontchannel_logout_uri, c.frontchannel_logout_session_required, c.backchannel_logout_uri, c.backchannel_logout_session_required, r.login_session_id FROM hydra_client as c JOIN hydra_oauth2_consent_request as r ON (c.id = r.client_id) WHERE r.subject=? AND c.%schannel_logout_uri!='' AND c.%schannel_logout_uri IS NOT NULL`, + channel, + channel, + ), + subject, + ).All(&cs); err != nil { + return sqlcon.HandleError(err) + } + return nil }) }