From 0b22d83cb22ef3d20bd31abe4ab834f0e4774a60 Mon Sep 17 00:00:00 2001 From: Nikos Date: Tue, 30 Jul 2024 14:56:56 +0300 Subject: [PATCH] fix: wrap db calls in transaction --- consent/manager.go | 3 +++ consent/strategy_default.go | 52 +++++++++++++++++++++---------------- oauth2/registry.go | 2 ++ 3 files changed, 35 insertions(+), 22 deletions(-) diff --git a/consent/manager.go b/consent/manager.go index f09c803c06b..577fffa27f1 100644 --- a/consent/manager.go +++ b/consent/manager.go @@ -6,6 +6,7 @@ package consent import ( "context" + "github.com/gobuffalo/pop/v6" "github.com/gofrs/uuid" "github.com/ory/hydra/v2/client" @@ -65,6 +66,8 @@ type ( GetDeviceUserAuthRequest(ctx context.Context, challenge string) (*flow.DeviceUserAuthRequest, error) HandleDeviceUserAuthRequest(ctx context.Context, f *flow.Flow, challenge string, r *flow.HandledDeviceUserAuthRequest) (*flow.DeviceUserAuthRequest, error) VerifyAndInvalidateDeviceUserAuthRequest(ctx context.Context, verifier string) (*flow.HandledDeviceUserAuthRequest, error) + + Transaction(context.Context, func(ctx context.Context, c *pop.Connection) error) error } ManagerProvider interface { diff --git a/consent/strategy_default.go b/consent/strategy_default.go index 1a5f846f36c..892508a31cb 100644 --- a/consent/strategy_default.go +++ b/consent/strategy_default.go @@ -13,6 +13,7 @@ import ( "strings" "time" + "github.com/gobuffalo/pop/v6" "github.com/gorilla/sessions" "github.com/hashicorp/go-retryablehttp" "github.com/pborman/uuid" @@ -39,8 +40,6 @@ import ( "github.com/ory/x/urlx" ) -type ctxKey int - const ( DeviceVerificationPath = "/oauth2/device/verify" CookieAuthenticationSIDName = "sid" @@ -1157,21 +1156,11 @@ func (s *DefaultStrategy) HandleOAuth2AuthorizationRequest( ctx, span := trace.SpanFromContext(ctx).TracerProvider().Tracer("").Start(ctx, "DefaultStrategy.HandleOAuth2AuthorizationRequest") defer otelx.End(span, &err) - return s.handleOAuth2AuthorizationRequest(ctx, w, r, req, nil) -} - -func (s *DefaultStrategy) handleOAuth2AuthorizationRequest( - ctx context.Context, - w http.ResponseWriter, - r *http.Request, - req fosite.AuthorizeRequester, - f *flow.Flow, -) (_ *flow.AcceptOAuth2ConsentRequest, _ *flow.Flow, err error) { loginVerifier := strings.TrimSpace(r.URL.Query().Get("login_verifier")) consentVerifier := strings.TrimSpace(r.URL.Query().Get("consent_verifier")) if loginVerifier == "" && consentVerifier == "" { // ok, we need to process this request and redirect to the original endpoint - return nil, nil, s.requestAuthentication(ctx, w, r, req, f) + return nil, nil, s.requestAuthentication(ctx, w, r, req, nil) } else if loginVerifier != "" { f, err := s.verifyAuthentication(ctx, w, r, req, loginVerifier) if err != nil { @@ -1195,7 +1184,10 @@ func (s *DefaultStrategy) HandleOAuth2DeviceAuthorizationRequest( ctx context.Context, w http.ResponseWriter, r *http.Request, -) (*flow.AcceptOAuth2ConsentRequest, *flow.Flow, error) { +) (_ *flow.AcceptOAuth2ConsentRequest, _ *flow.Flow, err error) { + ctx, span := trace.SpanFromContext(ctx).TracerProvider().Tracer("").Start(ctx, "DefaultStrategy.HandleOAuth2AuthorizationRequest") + defer otelx.End(span, &err) + deviceVerifier := strings.TrimSpace(r.URL.Query().Get("device_verifier")) loginVerifier := strings.TrimSpace(r.URL.Query().Get("login_verifier")) consentVerifier := strings.TrimSpace(r.URL.Query().Get("consent_verifier")) @@ -1233,16 +1225,32 @@ func (s *DefaultStrategy) HandleOAuth2DeviceAuthorizationRequest( ar.RequestedAudience = fosite.Arguments(deviceFlow.RequestedAudience) } - // TODO(nsklikas): wrap these 2 function calls in a transaction (one persists the flow and the other invalidates the user_code) - consentSession, f, err := s.handleOAuth2AuthorizationRequest(ctx, w, r, ar, deviceFlow) - if err != nil { - return nil, nil, err - } - err = s.r.OAuth2Storage().UpdateAndInvalidateUserCodeSessionByRequestID(r.Context(), string(f.DeviceCodeRequestID), f.ID) - if err != nil { - return nil, nil, err + if loginVerifier == "" && consentVerifier == "" { + // ok, we need to process this request and redirect to the authentication endpoint + return nil, nil, s.requestAuthentication(ctx, w, r, ar, deviceFlow) + } else if loginVerifier != "" { + f, err := s.verifyAuthentication(ctx, w, r, ar, loginVerifier) + if err != nil { + return nil, nil, err + } + + // ok, we need to process this request and redirect to consent endpoint + return nil, f, s.requestConsent(ctx, w, r, ar, f) } + var consentSession *flow.AcceptOAuth2ConsentRequest + var f *flow.Flow + + err = s.r.ConsentManager().Transaction(ctx, func(ctx context.Context, c *pop.Connection) error { + consentSession, f, err = s.verifyConsent(ctx, w, r, consentVerifier) + if err != nil { + return err + } + err = s.r.OAuth2Storage().UpdateAndInvalidateUserCodeSessionByRequestID(ctx, string(f.DeviceCodeRequestID), f.ID) + + return err + }) + return consentSession, f, err } diff --git a/oauth2/registry.go b/oauth2/registry.go index 4b7a19c402a..ffb7b642541 100644 --- a/oauth2/registry.go +++ b/oauth2/registry.go @@ -12,6 +12,7 @@ import ( "github.com/ory/hydra/v2/consent" "github.com/ory/hydra/v2/jwk" "github.com/ory/hydra/v2/oauth2/trust" + "github.com/ory/hydra/v2/persistence" "github.com/ory/hydra/v2/x" ) @@ -22,6 +23,7 @@ type InternalRegistry interface { x.RegistryWriter x.RegistryLogger consent.Registry + persistence.Provider Registry FlowCipher() *aead.XChaCha20Poly1305 }