From 2dbce7f35400ac910d78a40abf30944ee0e1a1fe Mon Sep 17 00:00:00 2001 From: Nikos Date: Tue, 30 Jul 2024 13:37:33 +0300 Subject: [PATCH 1/3] refactor: rename transaction to Transaction --- persistence/sql/persister.go | 2 +- persistence/sql/persister_client.go | 2 +- persistence/sql/persister_consent.go | 6 +++--- persistence/sql/persister_grant_jwk.go | 4 ++-- persistence/sql/persister_jwk.go | 6 +++--- 5 files changed, 10 insertions(+), 10 deletions(-) diff --git a/persistence/sql/persister.go b/persistence/sql/persister.go index ae4f7ce1825..bb7ceb94e01 100644 --- a/persistence/sql/persister.go +++ b/persistence/sql/persister.go @@ -184,6 +184,6 @@ func (p *Persister) mustSetNetwork(nid uuid.UUID, v interface{}) interface{} { return v } -func (p *Persister) transaction(ctx context.Context, f func(ctx context.Context, c *pop.Connection) error) error { +func (p *Persister) Transaction(ctx context.Context, f func(ctx context.Context, c *pop.Connection) error) error { return popx.Transaction(ctx, p.conn, f) } diff --git a/persistence/sql/persister_client.go b/persistence/sql/persister_client.go index 422d651bf4e..c85893c1df8 100644 --- a/persistence/sql/persister_client.go +++ b/persistence/sql/persister_client.go @@ -38,7 +38,7 @@ func (p *Persister) UpdateClient(ctx context.Context, cl *client.Client) (err er ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.UpdateClient") defer otelx.End(span, &err) - return p.transaction(ctx, func(ctx context.Context, c *pop.Connection) error { + return p.Transaction(ctx, func(ctx context.Context, c *pop.Connection) error { o, err := p.GetConcreteClient(ctx, cl.GetID()) if err != nil { return err diff --git a/persistence/sql/persister_consent.go b/persistence/sql/persister_consent.go index 355618125de..e03097313a7 100644 --- a/persistence/sql/persister_consent.go +++ b/persistence/sql/persister_consent.go @@ -35,14 +35,14 @@ func (p *Persister) RevokeSubjectConsentSession(ctx context.Context, user string ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.RevokeSubjectConsentSession") defer span.End() - return p.transaction(ctx, p.revokeConsentSession("consent_challenge_id IS NOT NULL AND subject = ?", user)) + return p.Transaction(ctx, p.revokeConsentSession("consent_challenge_id IS NOT NULL AND subject = ?", user)) } func (p *Persister) RevokeSubjectClientConsentSession(ctx context.Context, user, client string) error { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.RevokeSubjectClientConsentSession") defer span.End() - return p.transaction(ctx, p.revokeConsentSession("consent_challenge_id IS NOT NULL AND subject = ? AND client_id = ?", user, client)) + return p.Transaction(ctx, p.revokeConsentSession("consent_challenge_id IS NOT NULL AND subject = ? AND client_id = ?", user, client)) } func (p *Persister) revokeConsentSession(whereStmt string, whereArgs ...interface{}) func(context.Context, *pop.Connection) error { @@ -117,7 +117,7 @@ func (p *Persister) CreateForcedObfuscatedLoginSession(ctx context.Context, sess ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.CreateForcedObfuscatedLoginSession") defer span.End() - return p.transaction(ctx, func(ctx context.Context, c *pop.Connection) error { + return p.Transaction(ctx, func(ctx context.Context, c *pop.Connection) error { nid := p.NetworkID(ctx) if err := c.RawQuery( "DELETE FROM hydra_oauth2_obfuscated_authentication_session WHERE nid = ? AND client_id = ? AND subject = ?", diff --git a/persistence/sql/persister_grant_jwk.go b/persistence/sql/persister_grant_jwk.go index 115fa58fa0f..66fc2ecee15 100644 --- a/persistence/sql/persister_grant_jwk.go +++ b/persistence/sql/persister_grant_jwk.go @@ -26,7 +26,7 @@ func (p *Persister) CreateGrant(ctx context.Context, g trust.Grant, publicKey jo ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.CreateGrant") defer otelx.End(span, &err) - return p.transaction(ctx, func(ctx context.Context, c *pop.Connection) error { + return p.Transaction(ctx, func(ctx context.Context, c *pop.Connection) error { // add key, if it doesn't exist if _, err := p.GetKey(ctx, g.PublicKey.Set, g.PublicKey.KeyID); err != nil { if !errors.Is(err, sqlcon.ErrNoRows) { @@ -59,7 +59,7 @@ func (p *Persister) DeleteGrant(ctx context.Context, id string) (err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.DeleteGrant") defer otelx.End(span, &err) - return p.transaction(ctx, func(ctx context.Context, c *pop.Connection) error { + return p.Transaction(ctx, func(ctx context.Context, c *pop.Connection) error { grant, err := p.GetConcreteGrant(ctx, id) if err != nil { return sqlcon.HandleError(err) diff --git a/persistence/sql/persister_jwk.go b/persistence/sql/persister_jwk.go index 1efdff13394..3f0dc9cf55d 100644 --- a/persistence/sql/persister_jwk.go +++ b/persistence/sql/persister_jwk.go @@ -64,7 +64,7 @@ func (p *Persister) AddKeySet(ctx context.Context, set string, keys *jose.JSONWe ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.AddKey") defer span.End() - return p.transaction(ctx, func(ctx context.Context, c *pop.Connection) error { + return p.Transaction(ctx, func(ctx context.Context, c *pop.Connection) error { for _, key := range keys.Keys { out, err := json.Marshal(key) if err != nil { @@ -94,7 +94,7 @@ func (p *Persister) UpdateKey(ctx context.Context, set string, key *jose.JSONWeb ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.UpdateKey") defer span.End() - return p.transaction(ctx, func(ctx context.Context, c *pop.Connection) error { + return p.Transaction(ctx, func(ctx context.Context, c *pop.Connection) error { if err := p.DeleteKey(ctx, set, key.KeyID); err != nil { return err } @@ -110,7 +110,7 @@ func (p *Persister) UpdateKeySet(ctx context.Context, set string, keySet *jose.J ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.UpdateKeySet") defer span.End() - return p.transaction(ctx, func(ctx context.Context, c *pop.Connection) error { + return p.Transaction(ctx, func(ctx context.Context, c *pop.Connection) error { if err := p.DeleteKeySet(ctx, set); err != nil { return err } From 1e1909c6f232e8e25bd9b14e39b435acc3ffd8e7 Mon Sep 17 00:00:00 2001 From: Nikos Date: Tue, 30 Jul 2024 14:55:08 +0300 Subject: [PATCH 2/3] chore: update config --- contrib/quickstart/5-min/hydra.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/contrib/quickstart/5-min/hydra.yml b/contrib/quickstart/5-min/hydra.yml index 10f7f39dc1a..551fb11651e 100644 --- a/contrib/quickstart/5-min/hydra.yml +++ b/contrib/quickstart/5-min/hydra.yml @@ -12,7 +12,8 @@ urls: consent: http://localhost:4455/ui/consent login: http://localhost:4455/ui/login logout: http://localhost:4455/ui/logout - device_verification: http://localhost:4455/ui/device + device_verification: http://localhost:4455/ui/device_code + post_device_done: http://localhost:4455/ui/device_complete secrets: system: From 0b22d83cb22ef3d20bd31abe4ab834f0e4774a60 Mon Sep 17 00:00:00 2001 From: Nikos Date: Tue, 30 Jul 2024 14:56:56 +0300 Subject: [PATCH 3/3] fix: wrap db calls in transaction --- consent/manager.go | 3 +++ consent/strategy_default.go | 52 +++++++++++++++++++++---------------- oauth2/registry.go | 2 ++ 3 files changed, 35 insertions(+), 22 deletions(-) diff --git a/consent/manager.go b/consent/manager.go index f09c803c06b..577fffa27f1 100644 --- a/consent/manager.go +++ b/consent/manager.go @@ -6,6 +6,7 @@ package consent import ( "context" + "github.com/gobuffalo/pop/v6" "github.com/gofrs/uuid" "github.com/ory/hydra/v2/client" @@ -65,6 +66,8 @@ type ( GetDeviceUserAuthRequest(ctx context.Context, challenge string) (*flow.DeviceUserAuthRequest, error) HandleDeviceUserAuthRequest(ctx context.Context, f *flow.Flow, challenge string, r *flow.HandledDeviceUserAuthRequest) (*flow.DeviceUserAuthRequest, error) VerifyAndInvalidateDeviceUserAuthRequest(ctx context.Context, verifier string) (*flow.HandledDeviceUserAuthRequest, error) + + Transaction(context.Context, func(ctx context.Context, c *pop.Connection) error) error } ManagerProvider interface { diff --git a/consent/strategy_default.go b/consent/strategy_default.go index 1a5f846f36c..892508a31cb 100644 --- a/consent/strategy_default.go +++ b/consent/strategy_default.go @@ -13,6 +13,7 @@ import ( "strings" "time" + "github.com/gobuffalo/pop/v6" "github.com/gorilla/sessions" "github.com/hashicorp/go-retryablehttp" "github.com/pborman/uuid" @@ -39,8 +40,6 @@ import ( "github.com/ory/x/urlx" ) -type ctxKey int - const ( DeviceVerificationPath = "/oauth2/device/verify" CookieAuthenticationSIDName = "sid" @@ -1157,21 +1156,11 @@ func (s *DefaultStrategy) HandleOAuth2AuthorizationRequest( ctx, span := trace.SpanFromContext(ctx).TracerProvider().Tracer("").Start(ctx, "DefaultStrategy.HandleOAuth2AuthorizationRequest") defer otelx.End(span, &err) - return s.handleOAuth2AuthorizationRequest(ctx, w, r, req, nil) -} - -func (s *DefaultStrategy) handleOAuth2AuthorizationRequest( - ctx context.Context, - w http.ResponseWriter, - r *http.Request, - req fosite.AuthorizeRequester, - f *flow.Flow, -) (_ *flow.AcceptOAuth2ConsentRequest, _ *flow.Flow, err error) { loginVerifier := strings.TrimSpace(r.URL.Query().Get("login_verifier")) consentVerifier := strings.TrimSpace(r.URL.Query().Get("consent_verifier")) if loginVerifier == "" && consentVerifier == "" { // ok, we need to process this request and redirect to the original endpoint - return nil, nil, s.requestAuthentication(ctx, w, r, req, f) + return nil, nil, s.requestAuthentication(ctx, w, r, req, nil) } else if loginVerifier != "" { f, err := s.verifyAuthentication(ctx, w, r, req, loginVerifier) if err != nil { @@ -1195,7 +1184,10 @@ func (s *DefaultStrategy) HandleOAuth2DeviceAuthorizationRequest( ctx context.Context, w http.ResponseWriter, r *http.Request, -) (*flow.AcceptOAuth2ConsentRequest, *flow.Flow, error) { +) (_ *flow.AcceptOAuth2ConsentRequest, _ *flow.Flow, err error) { + ctx, span := trace.SpanFromContext(ctx).TracerProvider().Tracer("").Start(ctx, "DefaultStrategy.HandleOAuth2AuthorizationRequest") + defer otelx.End(span, &err) + deviceVerifier := strings.TrimSpace(r.URL.Query().Get("device_verifier")) loginVerifier := strings.TrimSpace(r.URL.Query().Get("login_verifier")) consentVerifier := strings.TrimSpace(r.URL.Query().Get("consent_verifier")) @@ -1233,16 +1225,32 @@ func (s *DefaultStrategy) HandleOAuth2DeviceAuthorizationRequest( ar.RequestedAudience = fosite.Arguments(deviceFlow.RequestedAudience) } - // TODO(nsklikas): wrap these 2 function calls in a transaction (one persists the flow and the other invalidates the user_code) - consentSession, f, err := s.handleOAuth2AuthorizationRequest(ctx, w, r, ar, deviceFlow) - if err != nil { - return nil, nil, err - } - err = s.r.OAuth2Storage().UpdateAndInvalidateUserCodeSessionByRequestID(r.Context(), string(f.DeviceCodeRequestID), f.ID) - if err != nil { - return nil, nil, err + if loginVerifier == "" && consentVerifier == "" { + // ok, we need to process this request and redirect to the authentication endpoint + return nil, nil, s.requestAuthentication(ctx, w, r, ar, deviceFlow) + } else if loginVerifier != "" { + f, err := s.verifyAuthentication(ctx, w, r, ar, loginVerifier) + if err != nil { + return nil, nil, err + } + + // ok, we need to process this request and redirect to consent endpoint + return nil, f, s.requestConsent(ctx, w, r, ar, f) } + var consentSession *flow.AcceptOAuth2ConsentRequest + var f *flow.Flow + + err = s.r.ConsentManager().Transaction(ctx, func(ctx context.Context, c *pop.Connection) error { + consentSession, f, err = s.verifyConsent(ctx, w, r, consentVerifier) + if err != nil { + return err + } + err = s.r.OAuth2Storage().UpdateAndInvalidateUserCodeSessionByRequestID(ctx, string(f.DeviceCodeRequestID), f.ID) + + return err + }) + return consentSession, f, err } diff --git a/oauth2/registry.go b/oauth2/registry.go index 4b7a19c402a..ffb7b642541 100644 --- a/oauth2/registry.go +++ b/oauth2/registry.go @@ -12,6 +12,7 @@ import ( "github.com/ory/hydra/v2/consent" "github.com/ory/hydra/v2/jwk" "github.com/ory/hydra/v2/oauth2/trust" + "github.com/ory/hydra/v2/persistence" "github.com/ory/hydra/v2/x" ) @@ -22,6 +23,7 @@ type InternalRegistry interface { x.RegistryWriter x.RegistryLogger consent.Registry + persistence.Provider Registry FlowCipher() *aead.XChaCha20Poly1305 }