Skip to content

Commit

Permalink
feat: revoke consent by session id. trigger back channel logout.
Browse files Browse the repository at this point in the history
  • Loading branch information
aarmam committed Mar 15, 2022
1 parent 924be24 commit 1726b54
Show file tree
Hide file tree
Showing 8 changed files with 158 additions and 29 deletions.
9 changes: 9 additions & 0 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,15 @@ type Client struct {
RegistrationClientURI string `json:"registration_client_uri,omitempty" db:"-"`
}

type AuthenticatedClient struct {
ClientID string `json:"client_id" db:"id"`
FrontChannelLogoutURI string `json:"frontchannel_logout_uri,omitempty" db:"frontchannel_logout_uri"`
FrontChannelLogoutSessionRequired bool `json:"frontchannel_logout_session_required,omitempty" db:"frontchannel_logout_session_required"`
BackChannelLogoutURI string `json:"backchannel_logout_uri,omitempty" db:"backchannel_logout_uri"`
BackChannelLogoutSessionRequired bool `json:"backchannel_logout_session_required,omitempty" db:"backchannel_logout_session_required"`
LoginSessionID string `json:"login_session_id,omitempty" db:"login_session_id"`
}

func (Client) TableName() string {
return "hydra_client"
}
Expand Down
31 changes: 28 additions & 3 deletions consent/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,9 @@ func (h *Handler) SetRoutes(admin *x.RouterAdmin) {
func (h *Handler) DeleteConsentSession(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
subject := r.URL.Query().Get("subject")
client := r.URL.Query().Get("client")
loginSessionId := r.URL.Query().Get("login_session_id")
triggerBackChannelLogout := r.URL.Query().Get("trigger_backchannel_logout")

allClients := r.URL.Query().Get("all") == "true"
if subject == "" {
h.r.Writer().WriteError(w, r, errorsx.WithStack(fosite.ErrInvalidRequest.WithHint(`Query parameter 'subject' is not defined but should have been.`)))
Expand All @@ -110,11 +113,33 @@ func (h *Handler) DeleteConsentSession(w http.ResponseWriter, r *http.Request, p

switch {
case len(client) > 0:
if err := h.r.ConsentManager().RevokeSubjectClientConsentSession(r.Context(), subject, client); err != nil && !errors.Is(err, x.ErrNotFound) {
h.r.Writer().WriteError(w, r, err)
return
if len(loginSessionId) > 0 {
if triggerBackChannelLogout == "true" {
if err := h.r.ConsentStrategy().ExecuteBackChannelLogoutByClientSession(r.Context(), r, subject, client, loginSessionId); err != nil {
h.r.Logger().WithError(err).Warn("Unable to execute back channel logout")
}
}
if err := h.r.ConsentManager().RevokeSubjectClientLoginSessionConsentSession(r.Context(), subject, client, loginSessionId); err != nil && !errors.Is(err, x.ErrNotFound) {
h.r.Writer().WriteError(w, r, err)
return
}
} else {
if triggerBackChannelLogout == "true" {
if err := h.r.ConsentStrategy().ExecuteBackChannelLogoutByClient(r.Context(), r, subject, client); err != nil {
h.r.Logger().WithError(err).Warn("Unable to execute back channel logout")
}
}
if err := h.r.ConsentManager().RevokeSubjectClientConsentSession(r.Context(), subject, client); err != nil && !errors.Is(err, x.ErrNotFound) {
h.r.Writer().WriteError(w, r, err)
return
}
}
case allClients:
if triggerBackChannelLogout == "true" {
if err := h.r.ConsentStrategy().ExecuteBackChannelLogoutBySubject(r.Context(), r, subject); err != nil {
h.r.Logger().WithError(err).Warn("Unable to execute back channel logout")
}
}
if err := h.r.ConsentManager().RevokeSubjectConsentSession(r.Context(), subject); err != nil && !errors.Is(err, x.ErrNotFound) {
h.r.Writer().WriteError(w, r, err)
return
Expand Down
6 changes: 4 additions & 2 deletions consent/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ type Manager interface {
HandleConsentRequest(ctx context.Context, challenge string, r *HandledConsentRequest) (*ConsentRequest, error)
RevokeSubjectConsentSession(ctx context.Context, user string) error
RevokeSubjectClientConsentSession(ctx context.Context, user, client string) error
RevokeSubjectClientLoginSessionConsentSession(ctx context.Context, user, client, loginSessionId string) error

VerifyAndInvalidateConsentRequest(ctx context.Context, verifier string) (*HandledConsentRequest, error)
FindGrantedAndRememberedConsentRequests(ctx context.Context, client, user string) ([]HandledConsentRequest, error)
Expand All @@ -64,8 +65,9 @@ type Manager interface {
CreateForcedObfuscatedLoginSession(ctx context.Context, session *ForcedObfuscatedLoginSession) error
GetForcedObfuscatedLoginSession(ctx context.Context, client, obfuscated string) (*ForcedObfuscatedLoginSession, error)

ListUserAuthenticatedClientsWithFrontChannelLogout(ctx context.Context, subject, sid string) ([]client.Client, error)
ListUserAuthenticatedClientsWithBackChannelLogout(ctx context.Context, subject, sid string) ([]client.Client, error)
ListUserSessionAuthenticatedClientsWithFrontChannelLogout(ctx context.Context, subject, sid string) ([]client.AuthenticatedClient, error)
ListUserSessionAuthenticatedClientsWithBackChannelLogout(ctx context.Context, subject, sid string) ([]client.AuthenticatedClient, error)
ListUserAuthenticatedClientsWithBackChannelLogout(ctx context.Context, subject string) ([]client.AuthenticatedClient, error)

CreateLogoutRequest(ctx context.Context, request *LogoutRequest) error
GetLogoutRequest(ctx context.Context, challenge string) (*LogoutRequest, error)
Expand Down
14 changes: 7 additions & 7 deletions consent/manager_test_helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -736,7 +736,7 @@ func ManagerTests(m Manager, clientManager client.Manager, fositeManager x.Fosit
}

for _, ls := range sessions {
check := func(t *testing.T, expected map[string][]client.Client, actual []client.Client) {
check := func(t *testing.T, expected map[string][]client.Client, actual []client.AuthenticatedClient) {
es, ok := expected[ls.ID]
if !ok {
require.Len(t, actual, 0)
Expand All @@ -747,25 +747,25 @@ func ManagerTests(m Manager, clientManager client.Manager, fositeManager x.Fosit
for _, e := range es {
var found bool
for _, a := range actual {
if e.OutfacingID == a.OutfacingID {
if e.OutfacingID == a.ClientID {
found = true
}
assert.Equal(t, e.OutfacingID, a.OutfacingID)
assert.Equal(t, e.OutfacingID, a.ClientID)
assert.Equal(t, e.FrontChannelLogoutURI, a.FrontChannelLogoutURI)
assert.Equal(t, e.BackChannelLogoutURI, a.BackChannelLogoutURI)
}
require.True(t, found)
}
}

t.Run(fmt.Sprintf("method=ListUserAuthenticatedClientsWithFrontChannelLogout/session=%s/subject=%s", ls.ID, ls.Subject), func(t *testing.T) {
actual, err := m.ListUserAuthenticatedClientsWithFrontChannelLogout(context.Background(), ls.Subject, ls.ID)
t.Run(fmt.Sprintf("method=ListUserSessionAuthenticatedClientsWithFrontChannelLogout/session=%s/subject=%s", ls.ID, ls.Subject), func(t *testing.T) {
actual, err := m.ListUserSessionAuthenticatedClientsWithFrontChannelLogout(context.Background(), ls.Subject, ls.ID)
require.NoError(t, err)
check(t, frontChannels, actual)
})

t.Run(fmt.Sprintf("method=ListUserAuthenticatedClientsWithBackChannelLogout/session=%s", ls.ID), func(t *testing.T) {
actual, err := m.ListUserAuthenticatedClientsWithBackChannelLogout(context.Background(), ls.Subject, ls.ID)
t.Run(fmt.Sprintf("method=ListUserSessionAuthenticatedClientsWithBackChannelLogout/session=%s", ls.ID), func(t *testing.T) {
actual, err := m.ListUserSessionAuthenticatedClientsWithBackChannelLogout(context.Background(), ls.Subject, ls.ID)
require.NoError(t, err)
check(t, backChannels, actual)
})
Expand Down
5 changes: 5 additions & 0 deletions consent/strategy.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
package consent

import (
"context"
"net/http"

"github.com/ory/fosite"
Expand All @@ -31,4 +32,8 @@ var _ Strategy = new(DefaultStrategy)
type Strategy interface {
HandleOAuth2AuthorizationRequest(w http.ResponseWriter, r *http.Request, req fosite.AuthorizeRequester) (*HandledConsentRequest, error)
HandleOpenIDConnectLogout(w http.ResponseWriter, r *http.Request) (*LogoutResult, error)
ExecuteBackChannelLogoutBySubject(ctx context.Context, r *http.Request, subject string) error
ExecuteBackChannelLogoutBySession(ctx context.Context, r *http.Request, subject, sid string) error
ExecuteBackChannelLogoutByClient(ctx context.Context, r *http.Request, subject, client string) error
ExecuteBackChannelLogoutByClientSession(ctx context.Context, r *http.Request, subject, client, sid string) error
}
52 changes: 44 additions & 8 deletions consent/strategy_default.go
Original file line number Diff line number Diff line change
Expand Up @@ -632,7 +632,7 @@ func (s *DefaultStrategy) verifyConsent(w http.ResponseWriter, r *http.Request,
}

func (s *DefaultStrategy) generateFrontChannelLogoutURLs(ctx context.Context, subject, sid string) ([]string, error) {
clients, err := s.r.ConsentManager().ListUserAuthenticatedClientsWithFrontChannelLogout(ctx, subject, sid)
clients, err := s.r.ConsentManager().ListUserSessionAuthenticatedClientsWithFrontChannelLogout(ctx, subject, sid)
if err != nil {
return nil, err
}
Expand All @@ -653,12 +653,49 @@ func (s *DefaultStrategy) generateFrontChannelLogoutURLs(ctx context.Context, su
return urls, nil
}

func (s *DefaultStrategy) executeBackChannelLogout(ctx context.Context, r *http.Request, subject, sid string) error {
clients, err := s.r.ConsentManager().ListUserAuthenticatedClientsWithBackChannelLogout(ctx, subject, sid)
func (s *DefaultStrategy) ExecuteBackChannelLogoutBySession(ctx context.Context, r *http.Request, subject, sid string) error {
clients, err := s.r.ConsentManager().ListUserSessionAuthenticatedClientsWithBackChannelLogout(ctx, subject, sid)
if err != nil {
return err
}
return s.executeBackChannelLogout(ctx, r, clients)
}

func (s *DefaultStrategy) ExecuteBackChannelLogoutByClientSession(ctx context.Context, r *http.Request, subject, client, sid string) error {
clients, err := s.r.ConsentManager().ListUserSessionAuthenticatedClientsWithBackChannelLogout(ctx, subject, sid)
if err != nil {
return err
}
for i := len(clients) - 1; i >= 0; i-- {
if clients[i].ClientID != client {
clients = append(clients[:i], clients[i+1:]...)
}
}
return s.executeBackChannelLogout(ctx, r, clients)
}

func (s *DefaultStrategy) ExecuteBackChannelLogoutByClient(ctx context.Context, r *http.Request, subject, client string) error {
clients, err := s.r.ConsentManager().ListUserAuthenticatedClientsWithBackChannelLogout(ctx, subject)
if err != nil {
return err
}
for i := len(clients) - 1; i >= 0; i-- {
if clients[i].ClientID != client {
clients = append(clients[:i], clients[i+1:]...)
}
}
return s.executeBackChannelLogout(ctx, r, clients)
}

func (s *DefaultStrategy) ExecuteBackChannelLogoutBySubject(ctx context.Context, r *http.Request, subject string) error {
clients, err := s.r.ConsentManager().ListUserAuthenticatedClientsWithBackChannelLogout(ctx, subject)
if err != nil {
return err
}
return s.executeBackChannelLogout(ctx, r, clients)
}

func (s *DefaultStrategy) executeBackChannelLogout(ctx context.Context, r *http.Request, clients []client.AuthenticatedClient) error {
openIDKeyID, err := s.r.OpenIDJWTStrategy().GetPublicKeyID(ctx)
if err != nil {
return err
Expand All @@ -678,22 +715,21 @@ func (s *DefaultStrategy) executeBackChannelLogout(ctx context.Context, r *http.
//
// s.r.ConsentManager().GetForcedObfuscatedLoginSession(context.Background(), subject, <missing>)
// sub := s.obfuscateSubjectIdentifier(c, subject, )

t, _, err := s.r.OpenIDJWTStrategy().Generate(ctx, jwtgo.MapClaims{
"iss": s.c.IssuerURL().String(),
"aud": []string{c.OutfacingID},
"aud": []string{c.ClientID},
"iat": time.Now().UTC().Unix(),
"jti": uuid.New(),
"events": map[string]struct{}{"http://schemas.openid.net/event/backchannel-logout": {}},
"sid": sid,
"sid": c.LoginSessionID,
}, &jwt.Headers{
Extra: map[string]interface{}{"kid": openIDKeyID},
})
if err != nil {
return err
}

tasks = append(tasks, task{url: c.BackChannelLogoutURI, clientID: c.OutfacingID, token: t})
tasks = append(tasks, task{url: c.BackChannelLogoutURI, clientID: c.ClientID, token: t})
}

var wg sync.WaitGroup
Expand Down Expand Up @@ -964,7 +1000,7 @@ func (s *DefaultStrategy) completeLogout(w http.ResponseWriter, r *http.Request)
return nil, err
}

if err := s.executeBackChannelLogout(r.Context(), r, lr.Subject, lr.SessionID); err != nil {
if err := s.ExecuteBackChannelLogoutBySession(r.Context(), r, lr.Subject, lr.SessionID); err != nil {
return nil, err
}

Expand Down
17 changes: 17 additions & 0 deletions oauth2/oauth2_helper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
package oauth2_test

import (
"context"
"net/http"
"time"

Expand Down Expand Up @@ -62,3 +63,19 @@ func (c *consentMock) HandleOAuth2AuthorizationRequest(w http.ResponseWriter, r
func (c *consentMock) HandleOpenIDConnectLogout(w http.ResponseWriter, r *http.Request) (*consent.LogoutResult, error) {
panic("not implemented")
}

func (c *consentMock) ExecuteBackChannelLogoutBySession(ctx context.Context, r *http.Request, subject, sid string) error {
panic("not implemented")
}

func (c *consentMock) ExecuteBackChannelLogoutByClientSession(ctx context.Context, r *http.Request, subject, client, sid string) error {
panic("not implemented")
}

func (c *consentMock) ExecuteBackChannelLogoutByClient(ctx context.Context, r *http.Request, subject, client string) error {
panic("not implemented")
}

func (c *consentMock) ExecuteBackChannelLogoutBySubject(ctx context.Context, r *http.Request, subject string) error {
panic("not implemented")
}
53 changes: 44 additions & 9 deletions persistence/sql/persister_consent.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import (
"fmt"
"time"

"github.com/ory/hydra/client"

"github.com/gobuffalo/pop/v6"

"github.com/ory/x/sqlxx"
Expand All @@ -15,7 +17,6 @@ import (
"github.com/pkg/errors"

"github.com/ory/fosite"
"github.com/ory/hydra/client"
"github.com/ory/hydra/consent"
"github.com/ory/hydra/x"
"github.com/ory/x/sqlcon"
Expand All @@ -31,6 +32,10 @@ func (p *Persister) RevokeSubjectClientConsentSession(ctx context.Context, user,
return p.transaction(ctx, p.revokeConsentSession("r.subject = ? AND r.client_id = ?", user, client))
}

func (p *Persister) RevokeSubjectClientLoginSessionConsentSession(ctx context.Context, user, client, loginSessionId string) error {
return p.transaction(ctx, p.revokeConsentSession("r.subject = ? AND r.client_id = ? AND r.login_session_id = ?", user, client, loginSessionId))
}

func (p *Persister) revokeConsentSession(whereStmt string, whereArgs ...interface{}) func(context.Context, *pop.Connection) error {
return func(ctx context.Context, c *pop.Connection) error {
hrs := make([]*consent.HandledConsentRequest, 0)
Expand Down Expand Up @@ -364,20 +369,24 @@ func (p *Persister) resolveHandledConsentRequests(ctx context.Context, requests
return result, nil
}

func (p *Persister) ListUserAuthenticatedClientsWithFrontChannelLogout(ctx context.Context, subject, sid string) ([]client.Client, error) {
return p.listUserAuthenticatedClients(ctx, subject, sid, "front")
func (p *Persister) ListUserSessionAuthenticatedClientsWithFrontChannelLogout(ctx context.Context, subject, sid string) ([]client.AuthenticatedClient, error) {
return p.listUserSessionAuthenticatedClients(ctx, subject, sid, "front")
}

func (p *Persister) ListUserAuthenticatedClientsWithBackChannelLogout(ctx context.Context, subject, sid string) ([]client.Client, error) {
return p.listUserAuthenticatedClients(ctx, subject, sid, "back")
func (p *Persister) ListUserSessionAuthenticatedClientsWithBackChannelLogout(ctx context.Context, subject, sid string) ([]client.AuthenticatedClient, error) {
return p.listUserSessionAuthenticatedClients(ctx, subject, sid, "back")
}

func (p *Persister) listUserAuthenticatedClients(ctx context.Context, subject, sid, channel string) ([]client.Client, error) {
var cs []client.Client
return cs, p.transaction(ctx, func(ctx context.Context, c *pop.Connection) error {
func (p *Persister) ListUserAuthenticatedClientsWithBackChannelLogout(ctx context.Context, subject string) ([]client.AuthenticatedClient, error) {
return p.listUserAuthenticatedClients(ctx, subject, "back")
}

func (p *Persister) listUserSessionAuthenticatedClients(ctx context.Context, subject, sid, channel string) ([]client.AuthenticatedClient, error) {
var cs []client.AuthenticatedClient
err := p.transaction(ctx, func(ctx context.Context, c *pop.Connection) error {
if err := c.RawQuery(
/* #nosec G201 - channel can either be "front" or "back" */
fmt.Sprintf(`SELECT DISTINCT c.* FROM hydra_client as c JOIN hydra_oauth2_consent_request as r ON (c.id = r.client_id) WHERE r.subject=? AND c.%schannel_logout_uri!='' AND c.%schannel_logout_uri IS NOT NULL AND r.login_session_id = ?`,
fmt.Sprintf(`SELECT DISTINCT c.id, c.frontchannel_logout_uri, c.frontchannel_logout_session_required, c.backchannel_logout_uri, c.backchannel_logout_session_required FROM hydra_client as c JOIN hydra_oauth2_consent_request as r ON (c.id = r.client_id) WHERE r.subject=? AND c.%schannel_logout_uri!='' AND c.%schannel_logout_uri IS NOT NULL AND r.login_session_id = ?`,
channel,
channel,
),
Expand All @@ -387,6 +396,32 @@ func (p *Persister) listUserAuthenticatedClients(ctx context.Context, subject, s
return sqlcon.HandleError(err)
}

return nil
})
if err != nil {
return nil, err
}

for i := range cs {
cs[i].LoginSessionID = sid
}
return cs, err
}

func (p *Persister) listUserAuthenticatedClients(ctx context.Context, subject, channel string) ([]client.AuthenticatedClient, error) {
var cs []client.AuthenticatedClient
return cs, p.transaction(ctx, func(ctx context.Context, c *pop.Connection) error {
if err := c.RawQuery(
/* #nosec G201 - channel can either be "front" or "back" */
fmt.Sprintf(`SELECT DISTINCT c.id, c.frontchannel_logout_uri, c.frontchannel_logout_session_required, c.backchannel_logout_uri, c.backchannel_logout_session_required, r.login_session_id FROM hydra_client as c JOIN hydra_oauth2_consent_request as r ON (c.id = r.client_id) WHERE r.subject=? AND c.%schannel_logout_uri!='' AND c.%schannel_logout_uri IS NOT NULL`,
channel,
channel,
),
subject,
).All(&cs); err != nil {
return sqlcon.HandleError(err)
}

return nil
})
}
Expand Down

0 comments on commit 1726b54

Please sign in to comment.