diff --git a/consent/handler.go b/consent/handler.go index 78a5897d5da..bd5751e66d4 100644 --- a/consent/handler.go +++ b/consent/handler.go @@ -579,6 +579,13 @@ func (h *Handler) AcceptConsentRequest(w http.ResponseWriter, r *http.Request, p h.r.Writer().WriteError(w, r, errorsx.WithStack(err)) return } else if hr.Skip { + if p.Remember && p.RememberFor > 0 { // TODO: Consider removing 'p.RememberFor > 0' to update consent validity in both ways (limited (RememberFor > 0) -> indefinitely (RememberFor = 0) and vice versa) + err = h.r.ConsentManager().ExtendConsentRequest(r.Context(), h.r.ScopeStrategy(), hr, p.RememberFor) + if err != nil { + h.r.Writer().WriteError(w, r, errorsx.WithStack(err)) + return + } + } p.Remember = false } diff --git a/consent/manager.go b/consent/manager.go index f0fa286050b..5a8aa55a9e1 100644 --- a/consent/manager.go +++ b/consent/manager.go @@ -24,6 +24,8 @@ import ( "context" "time" + "github.com/ory/fosite" + "github.com/ory/hydra/client" ) @@ -41,6 +43,7 @@ type Manager interface { CreateConsentRequest(ctx context.Context, req *ConsentRequest) error GetConsentRequest(ctx context.Context, challenge string) (*ConsentRequest, error) HandleConsentRequest(ctx context.Context, challenge string, r *HandledConsentRequest) (*ConsentRequest, error) + ExtendConsentRequest(ctx context.Context, scopeStrategy fosite.ScopeStrategy, cr *ConsentRequest, extendBy int) error RevokeSubjectConsentSession(ctx context.Context, user string) error RevokeSubjectClientConsentSession(ctx context.Context, user, client string) error diff --git a/persistence/sql/persister_consent.go b/persistence/sql/persister_consent.go index 923e951fa4a..f1288a8961e 100644 --- a/persistence/sql/persister_consent.go +++ b/persistence/sql/persister_consent.go @@ -194,6 +194,74 @@ func (p *Persister) HandleConsentRequest(ctx context.Context, challenge string, return p.GetConsentRequest(ctx, challenge) } +func (p *Persister) ExtendConsentRequest(ctx context.Context, scopeStrategy fosite.ScopeStrategy, cr *consent.ConsentRequest, extendBy int) error { + return p.transaction(ctx, func(ctx context.Context, c *pop.Connection) error { + tn := consent.HandledConsentRequest{}.TableName() + + var sessionHcr consent.HandledConsentRequest + if err := c. + Where(fmt.Sprintf("r.subject = ? AND r.client_id = ? AND r.login_session_id = ? AND r.skip=FALSE AND (%s.error='{}' AND %s.remember=TRUE)", tn, tn), cr.Subject, cr.ClientID, cr.LoginSessionID.String()). + Join("hydra_oauth2_consent_request AS r", fmt.Sprintf("%s.challenge = r.challenge", tn)). + Order(fmt.Sprintf("%s.requested_at DESC", tn)). + Limit(1). + First(&sessionHcr); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return errorsx.WithStack(consent.ErrNoPreviousConsentFound) + } + return sqlcon.HandleError(err) + } + + var latestHcr consent.HandledConsentRequest + if err := c. + Where(fmt.Sprintf("r.subject = ? AND r.client_id = ? AND r.skip=FALSE AND (%s.error='{}' AND %s.remember=TRUE)", tn, tn), cr.Subject, cr.ClientID). + Join("hydra_oauth2_consent_request AS r", fmt.Sprintf("%s.challenge = r.challenge", tn)). + Order(fmt.Sprintf("%s.requested_at DESC", tn)). + Limit(1). + First(&latestHcr); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return errorsx.WithStack(consent.ErrNoPreviousConsentFound) + } + return sqlcon.HandleError(err) + } + + if err := p.extendHandledConsentRequest(ctx, cr, scopeStrategy, sessionHcr, extendBy); err != nil { + return err + } + + if latestHcr.ID != sessionHcr.ID { + if err := p.extendHandledConsentRequest(ctx, cr, scopeStrategy, latestHcr, extendBy); err != nil { + return err + } + } + return nil + }) +} + +func (p *Persister) extendHandledConsentRequest(ctx context.Context, cr *consent.ConsentRequest, scopeStrategy fosite.ScopeStrategy, hcr consent.HandledConsentRequest, extendBy int) error { + for _, scope := range cr.RequestedScope { + if !scopeStrategy(hcr.GrantedScope, scope) { + return nil + } + } + + isConsentRequestExpired := hcr.RememberFor > 0 && hcr.RequestedAt.Add(time.Duration(hcr.RememberFor)*time.Second).Before(time.Now().UTC()) + if isConsentRequestExpired { + return nil + } + + remainingTime := hcr.RequestedAt.Unix() + int64(hcr.RememberFor) - time.Now().Unix() + if remainingTime > 0 { + hcr.RememberFor = hcr.RememberFor + extendBy - int(remainingTime) + } else { + hcr.RememberFor = hcr.RememberFor + extendBy + } + + if err := sqlcon.HandleError(p.Connection(ctx).Update(&hcr)); err != nil { + return err + } + return nil +} + func (p *Persister) VerifyAndInvalidateConsentRequest(ctx context.Context, verifier string) (*consent.HandledConsentRequest, error) { var r consent.HandledConsentRequest return &r, p.transaction(ctx, func(ctx context.Context, c *pop.Connection) error {