Skip to content

Commit

Permalink
feat: terminate existing user sessions on sso provider update
Browse files Browse the repository at this point in the history
  • Loading branch information
mistahj67 committed Nov 22, 2024
1 parent d71a5ee commit 66420fe
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 4 deletions.
6 changes: 5 additions & 1 deletion cmd/api/src/database/oidc_providers.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,12 @@ func (s *BloodhoundDB) UpdateOIDCProvider(ctx context.Context, ssoProvider model

if _, err := bhdb.UpdateSSOProvider(ctx, ssoProvider); err != nil {
return err
} else if err := CheckError(tx.WithContext(ctx).Exec(fmt.Sprintf("UPDATE %s SET client_id = ?, issuer = ?, updated_at = ? WHERE id = ?;", oidcProvidersTableName),
ssoProvider.OIDCProvider.ClientID, ssoProvider.OIDCProvider.Issuer, time.Now().UTC(), ssoProvider.OIDCProvider.ID)); err != nil {
return err
} else {
return CheckError(tx.WithContext(ctx).Exec(fmt.Sprintf("UPDATE %s SET client_id = ?, issuer = ?, updated_at = ? WHERE id = ?;", oidcProvidersTableName), ssoProvider.OIDCProvider.ClientID, ssoProvider.OIDCProvider.Issuer, time.Now().UTC(), ssoProvider.OIDCProvider.ID))
// Ensure all existing sessions are invalidated within the tx
return bhdb.TerminateUserSessionsBySSOProvider(ctx, ssoProvider)
}
})

Expand Down
10 changes: 7 additions & 3 deletions cmd/api/src/database/samlproviders.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,10 +108,14 @@ func (s *BloodhoundDB) UpdateSAMLIdentityProvider(ctx context.Context, ssoProvid

if _, err := bhdb.UpdateSSOProvider(ctx, ssoProvider); err != nil {
return err
} else if err := CheckError(tx.WithContext(ctx).Exec(
fmt.Sprintf("UPDATE %s SET name = ?, display_name = ?, issuer_uri = ?, single_sign_on_uri = ?, metadata_xml = ?, updated_at = ? WHERE id = ?;", samlProvidersTableName),
ssoProvider.SAMLProvider.Name, ssoProvider.SAMLProvider.DisplayName, ssoProvider.SAMLProvider.IssuerURI, ssoProvider.SAMLProvider.SingleSignOnURI, ssoProvider.SAMLProvider.MetadataXML, time.Now().UTC(), ssoProvider.SAMLProvider.ID),
); err != nil {
return err
} else {
return CheckError(tx.WithContext(ctx).Exec(
fmt.Sprintf("UPDATE %s SET name = ?, display_name = ?, issuer_uri = ?, single_sign_on_uri = ?, metadata_xml = ?, updated_at = ? WHERE id = ?;", samlProvidersTableName),
ssoProvider.SAMLProvider.Name, ssoProvider.SAMLProvider.DisplayName, ssoProvider.SAMLProvider.IssuerURI, ssoProvider.SAMLProvider.SingleSignOnURI, ssoProvider.SAMLProvider.MetadataXML, time.Now().UTC(), ssoProvider.SAMLProvider.ID))
// Ensure all existing sessions are invalidated within the tx
return bhdb.TerminateUserSessionsBySSOProvider(ctx, ssoProvider)
}
})

Expand Down
23 changes: 23 additions & 0 deletions cmd/api/src/database/sso_providers.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ type SSOProviderData interface {
GetSSOProviderById(ctx context.Context, id int32) (model.SSOProvider, error)
GetSSOProviderBySlug(ctx context.Context, slug string) (model.SSOProvider, error)
GetSSOProviderUsers(ctx context.Context, id int) (model.Users, error)
TerminateUserSessionsBySSOProvider(ctx context.Context, ssoProvider model.SSOProvider) error
UpdateSSOProvider(ctx context.Context, ssoProvider model.SSOProvider) (model.SSOProvider, error)
}

Expand Down Expand Up @@ -149,6 +150,28 @@ func (s *BloodhoundDB) GetSSOProviderById(ctx context.Context, id int32) (model.
return provider, CheckError(result)
}

// TerminateUserSessionsBySSOProvider terminates all sessions associated with a specific sso provider
func (s *BloodhoundDB) TerminateUserSessionsBySSOProvider(ctx context.Context, ssoProvider model.SSOProvider) error {
// TODO should be migrated to the SSO provider id instead of the child
var childId int32
switch ssoProvider.Type {
case model.SessionAuthProviderSAML:
if ssoProvider.SAMLProvider != nil {
childId = ssoProvider.SAMLProvider.ID
}
case model.SessionAuthProviderOIDC:
if ssoProvider.OIDCProvider != nil {
childId = ssoProvider.OIDCProvider.ID
}
}

if childId == 0 {
return ErrNotFound
}

return CheckError(s.db.WithContext(ctx).Table("user_sessions").Where("auth_provider_type = ? AND auth_provider_id = ?", ssoProvider.Type, childId).Update("expires_at", gorm.Expr("NOW()")))
}

// UpdateSSOProvider updates an entry in the sso_providers table
func (s *BloodhoundDB) UpdateSSOProvider(ctx context.Context, ssoProvider model.SSOProvider) (model.SSOProvider, error) {
// Update the slug
Expand Down

0 comments on commit 66420fe

Please sign in to comment.