Skip to content

Commit

Permalink
fix: wrap db calls in transaction
Browse files Browse the repository at this point in the history
  • Loading branch information
nsklikas committed Sep 13, 2024
1 parent e5ff47b commit 702c45a
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 22 deletions.
3 changes: 3 additions & 0 deletions consent/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package consent
import (
"context"

"github.com/gobuffalo/pop/v6"
"github.com/gofrs/uuid"

"github.com/ory/hydra/v2/client"
Expand Down Expand Up @@ -65,6 +66,8 @@ type (
GetDeviceUserAuthRequest(ctx context.Context, challenge string) (*flow.DeviceUserAuthRequest, error)
HandleDeviceUserAuthRequest(ctx context.Context, f *flow.Flow, challenge string, r *flow.HandledDeviceUserAuthRequest) (*flow.DeviceUserAuthRequest, error)
VerifyAndInvalidateDeviceUserAuthRequest(ctx context.Context, verifier string) (*flow.HandledDeviceUserAuthRequest, error)

Transaction(context.Context, func(ctx context.Context, c *pop.Connection) error) error
}

ManagerProvider interface {
Expand Down
52 changes: 30 additions & 22 deletions consent/strategy_default.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"strings"
"time"

"github.com/gobuffalo/pop/v6"
"github.com/gorilla/sessions"
"github.com/hashicorp/go-retryablehttp"
"github.com/pborman/uuid"
Expand All @@ -39,8 +40,6 @@ import (
"github.com/ory/x/urlx"
)

type ctxKey int

const (
DeviceVerificationPath = "/oauth2/device/verify"
CookieAuthenticationSIDName = "sid"
Expand Down Expand Up @@ -1157,21 +1156,11 @@ func (s *DefaultStrategy) HandleOAuth2AuthorizationRequest(
ctx, span := trace.SpanFromContext(ctx).TracerProvider().Tracer("").Start(ctx, "DefaultStrategy.HandleOAuth2AuthorizationRequest")
defer otelx.End(span, &err)

return s.handleOAuth2AuthorizationRequest(ctx, w, r, req, nil)
}

func (s *DefaultStrategy) handleOAuth2AuthorizationRequest(
ctx context.Context,
w http.ResponseWriter,
r *http.Request,
req fosite.AuthorizeRequester,
f *flow.Flow,
) (_ *flow.AcceptOAuth2ConsentRequest, _ *flow.Flow, err error) {
loginVerifier := strings.TrimSpace(r.URL.Query().Get("login_verifier"))
consentVerifier := strings.TrimSpace(r.URL.Query().Get("consent_verifier"))
if loginVerifier == "" && consentVerifier == "" {
// ok, we need to process this request and redirect to the original endpoint
return nil, nil, s.requestAuthentication(ctx, w, r, req, f)
return nil, nil, s.requestAuthentication(ctx, w, r, req, nil)
} else if loginVerifier != "" {
f, err := s.verifyAuthentication(ctx, w, r, req, loginVerifier)
if err != nil {
Expand All @@ -1195,7 +1184,10 @@ func (s *DefaultStrategy) HandleOAuth2DeviceAuthorizationRequest(
ctx context.Context,
w http.ResponseWriter,
r *http.Request,
) (*flow.AcceptOAuth2ConsentRequest, *flow.Flow, error) {
) (_ *flow.AcceptOAuth2ConsentRequest, _ *flow.Flow, err error) {
ctx, span := trace.SpanFromContext(ctx).TracerProvider().Tracer("").Start(ctx, "DefaultStrategy.HandleOAuth2AuthorizationRequest")
defer otelx.End(span, &err)

deviceVerifier := strings.TrimSpace(r.URL.Query().Get("device_verifier"))
loginVerifier := strings.TrimSpace(r.URL.Query().Get("login_verifier"))
consentVerifier := strings.TrimSpace(r.URL.Query().Get("consent_verifier"))
Expand Down Expand Up @@ -1233,16 +1225,32 @@ func (s *DefaultStrategy) HandleOAuth2DeviceAuthorizationRequest(
ar.RequestedAudience = fosite.Arguments(deviceFlow.RequestedAudience)
}

// TODO(nsklikas): wrap these 2 function calls in a transaction (one persists the flow and the other invalidates the user_code)
consentSession, f, err := s.handleOAuth2AuthorizationRequest(ctx, w, r, ar, deviceFlow)
if err != nil {
return nil, nil, err
}
err = s.r.OAuth2Storage().UpdateAndInvalidateUserCodeSessionByRequestID(r.Context(), string(f.DeviceCodeRequestID), f.ID)
if err != nil {
return nil, nil, err
if loginVerifier == "" && consentVerifier == "" {
// ok, we need to process this request and redirect to the authentication endpoint
return nil, nil, s.requestAuthentication(ctx, w, r, ar, deviceFlow)
} else if loginVerifier != "" {
f, err := s.verifyAuthentication(ctx, w, r, ar, loginVerifier)
if err != nil {
return nil, nil, err
}

// ok, we need to process this request and redirect to consent endpoint
return nil, f, s.requestConsent(ctx, w, r, ar, f)
}

var consentSession *flow.AcceptOAuth2ConsentRequest
var f *flow.Flow

err = s.r.ConsentManager().Transaction(ctx, func(ctx context.Context, c *pop.Connection) error {
consentSession, f, err = s.verifyConsent(ctx, w, r, consentVerifier)
if err != nil {
return err
}
err = s.r.OAuth2Storage().UpdateAndInvalidateUserCodeSessionByRequestID(ctx, string(f.DeviceCodeRequestID), f.ID)

return err
})

return consentSession, f, err
}

Expand Down

0 comments on commit 702c45a

Please sign in to comment.