Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: add transaction to flow persistence #26

Merged
merged 3 commits into from
Aug 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions consent/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package consent
import (
"context"

"github.com/gobuffalo/pop/v6"
"github.com/gofrs/uuid"

"github.com/ory/hydra/v2/client"
Expand Down Expand Up @@ -65,6 +66,8 @@ type (
GetDeviceUserAuthRequest(ctx context.Context, challenge string) (*flow.DeviceUserAuthRequest, error)
HandleDeviceUserAuthRequest(ctx context.Context, f *flow.Flow, challenge string, r *flow.HandledDeviceUserAuthRequest) (*flow.DeviceUserAuthRequest, error)
VerifyAndInvalidateDeviceUserAuthRequest(ctx context.Context, verifier string) (*flow.HandledDeviceUserAuthRequest, error)

Transaction(context.Context, func(ctx context.Context, c *pop.Connection) error) error
}

ManagerProvider interface {
Expand Down
52 changes: 30 additions & 22 deletions consent/strategy_default.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"strings"
"time"

"github.com/gobuffalo/pop/v6"
"github.com/gorilla/sessions"
"github.com/hashicorp/go-retryablehttp"
"github.com/pborman/uuid"
Expand All @@ -39,8 +40,6 @@ import (
"github.com/ory/x/urlx"
)

type ctxKey int

const (
DeviceVerificationPath = "/oauth2/device/verify"
CookieAuthenticationSIDName = "sid"
Expand Down Expand Up @@ -1157,21 +1156,11 @@ func (s *DefaultStrategy) HandleOAuth2AuthorizationRequest(
ctx, span := trace.SpanFromContext(ctx).TracerProvider().Tracer("").Start(ctx, "DefaultStrategy.HandleOAuth2AuthorizationRequest")
defer otelx.End(span, &err)

return s.handleOAuth2AuthorizationRequest(ctx, w, r, req, nil)
}

func (s *DefaultStrategy) handleOAuth2AuthorizationRequest(
ctx context.Context,
w http.ResponseWriter,
r *http.Request,
req fosite.AuthorizeRequester,
f *flow.Flow,
) (_ *flow.AcceptOAuth2ConsentRequest, _ *flow.Flow, err error) {
loginVerifier := strings.TrimSpace(r.URL.Query().Get("login_verifier"))
consentVerifier := strings.TrimSpace(r.URL.Query().Get("consent_verifier"))
if loginVerifier == "" && consentVerifier == "" {
// ok, we need to process this request and redirect to the original endpoint
return nil, nil, s.requestAuthentication(ctx, w, r, req, f)
return nil, nil, s.requestAuthentication(ctx, w, r, req, nil)
} else if loginVerifier != "" {
f, err := s.verifyAuthentication(ctx, w, r, req, loginVerifier)
if err != nil {
Expand All @@ -1195,7 +1184,10 @@ func (s *DefaultStrategy) HandleOAuth2DeviceAuthorizationRequest(
ctx context.Context,
w http.ResponseWriter,
r *http.Request,
) (*flow.AcceptOAuth2ConsentRequest, *flow.Flow, error) {
) (_ *flow.AcceptOAuth2ConsentRequest, _ *flow.Flow, err error) {
ctx, span := trace.SpanFromContext(ctx).TracerProvider().Tracer("").Start(ctx, "DefaultStrategy.HandleOAuth2AuthorizationRequest")
defer otelx.End(span, &err)

deviceVerifier := strings.TrimSpace(r.URL.Query().Get("device_verifier"))
loginVerifier := strings.TrimSpace(r.URL.Query().Get("login_verifier"))
consentVerifier := strings.TrimSpace(r.URL.Query().Get("consent_verifier"))
Expand Down Expand Up @@ -1233,16 +1225,32 @@ func (s *DefaultStrategy) HandleOAuth2DeviceAuthorizationRequest(
ar.RequestedAudience = fosite.Arguments(deviceFlow.RequestedAudience)
}

// TODO(nsklikas): wrap these 2 function calls in a transaction (one persists the flow and the other invalidates the user_code)
consentSession, f, err := s.handleOAuth2AuthorizationRequest(ctx, w, r, ar, deviceFlow)
if err != nil {
return nil, nil, err
}
err = s.r.OAuth2Storage().UpdateAndInvalidateUserCodeSessionByRequestID(r.Context(), string(f.DeviceCodeRequestID), f.ID)
if err != nil {
return nil, nil, err
if loginVerifier == "" && consentVerifier == "" {
// ok, we need to process this request and redirect to the authentication endpoint
return nil, nil, s.requestAuthentication(ctx, w, r, ar, deviceFlow)
} else if loginVerifier != "" {
f, err := s.verifyAuthentication(ctx, w, r, ar, loginVerifier)
if err != nil {
return nil, nil, err
}

// ok, we need to process this request and redirect to consent endpoint
return nil, f, s.requestConsent(ctx, w, r, ar, f)
}

var consentSession *flow.AcceptOAuth2ConsentRequest
var f *flow.Flow

err = s.r.ConsentManager().Transaction(ctx, func(ctx context.Context, c *pop.Connection) error {
consentSession, f, err = s.verifyConsent(ctx, w, r, consentVerifier)
if err != nil {
return err
}
err = s.r.OAuth2Storage().UpdateAndInvalidateUserCodeSessionByRequestID(ctx, string(f.DeviceCodeRequestID), f.ID)

return err
})

return consentSession, f, err
}

Expand Down
3 changes: 2 additions & 1 deletion contrib/quickstart/5-min/hydra.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ urls:
consent: http://localhost:4455/ui/consent
login: http://localhost:4455/ui/login
logout: http://localhost:4455/ui/logout
device_verification: http://localhost:4455/ui/device
device_verification: http://localhost:4455/ui/device_code
post_device_done: http://localhost:4455/ui/device_complete

secrets:
system:
Expand Down
2 changes: 2 additions & 0 deletions oauth2/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"github.com/ory/hydra/v2/consent"
"github.com/ory/hydra/v2/jwk"
"github.com/ory/hydra/v2/oauth2/trust"
"github.com/ory/hydra/v2/persistence"
"github.com/ory/hydra/v2/x"
)

Expand All @@ -22,6 +23,7 @@ type InternalRegistry interface {
x.RegistryWriter
x.RegistryLogger
consent.Registry
persistence.Provider
Registry
FlowCipher() *aead.XChaCha20Poly1305
}
Expand Down
2 changes: 1 addition & 1 deletion persistence/sql/persister.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,6 @@ func (p *Persister) mustSetNetwork(nid uuid.UUID, v interface{}) interface{} {
return v
}

func (p *Persister) transaction(ctx context.Context, f func(ctx context.Context, c *pop.Connection) error) error {
func (p *Persister) Transaction(ctx context.Context, f func(ctx context.Context, c *pop.Connection) error) error {
return popx.Transaction(ctx, p.conn, f)
}
2 changes: 1 addition & 1 deletion persistence/sql/persister_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func (p *Persister) UpdateClient(ctx context.Context, cl *client.Client) (err er
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.UpdateClient")
defer otelx.End(span, &err)

return p.transaction(ctx, func(ctx context.Context, c *pop.Connection) error {
return p.Transaction(ctx, func(ctx context.Context, c *pop.Connection) error {
o, err := p.GetConcreteClient(ctx, cl.GetID())
if err != nil {
return err
Expand Down
6 changes: 3 additions & 3 deletions persistence/sql/persister_consent.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,14 @@ func (p *Persister) RevokeSubjectConsentSession(ctx context.Context, user string
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.RevokeSubjectConsentSession")
defer span.End()

return p.transaction(ctx, p.revokeConsentSession("consent_challenge_id IS NOT NULL AND subject = ?", user))
return p.Transaction(ctx, p.revokeConsentSession("consent_challenge_id IS NOT NULL AND subject = ?", user))
}

func (p *Persister) RevokeSubjectClientConsentSession(ctx context.Context, user, client string) error {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.RevokeSubjectClientConsentSession")
defer span.End()

return p.transaction(ctx, p.revokeConsentSession("consent_challenge_id IS NOT NULL AND subject = ? AND client_id = ?", user, client))
return p.Transaction(ctx, p.revokeConsentSession("consent_challenge_id IS NOT NULL AND subject = ? AND client_id = ?", user, client))
}

func (p *Persister) revokeConsentSession(whereStmt string, whereArgs ...interface{}) func(context.Context, *pop.Connection) error {
Expand Down Expand Up @@ -117,7 +117,7 @@ func (p *Persister) CreateForcedObfuscatedLoginSession(ctx context.Context, sess
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.CreateForcedObfuscatedLoginSession")
defer span.End()

return p.transaction(ctx, func(ctx context.Context, c *pop.Connection) error {
return p.Transaction(ctx, func(ctx context.Context, c *pop.Connection) error {
nid := p.NetworkID(ctx)
if err := c.RawQuery(
"DELETE FROM hydra_oauth2_obfuscated_authentication_session WHERE nid = ? AND client_id = ? AND subject = ?",
Expand Down
4 changes: 2 additions & 2 deletions persistence/sql/persister_grant_jwk.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ func (p *Persister) CreateGrant(ctx context.Context, g trust.Grant, publicKey jo
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.CreateGrant")
defer otelx.End(span, &err)

return p.transaction(ctx, func(ctx context.Context, c *pop.Connection) error {
return p.Transaction(ctx, func(ctx context.Context, c *pop.Connection) error {
// add key, if it doesn't exist
if _, err := p.GetKey(ctx, g.PublicKey.Set, g.PublicKey.KeyID); err != nil {
if !errors.Is(err, sqlcon.ErrNoRows) {
Expand Down Expand Up @@ -59,7 +59,7 @@ func (p *Persister) DeleteGrant(ctx context.Context, id string) (err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.DeleteGrant")
defer otelx.End(span, &err)

return p.transaction(ctx, func(ctx context.Context, c *pop.Connection) error {
return p.Transaction(ctx, func(ctx context.Context, c *pop.Connection) error {
grant, err := p.GetConcreteGrant(ctx, id)
if err != nil {
return sqlcon.HandleError(err)
Expand Down
6 changes: 3 additions & 3 deletions persistence/sql/persister_jwk.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ func (p *Persister) AddKeySet(ctx context.Context, set string, keys *jose.JSONWe
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.AddKey")
defer span.End()

return p.transaction(ctx, func(ctx context.Context, c *pop.Connection) error {
return p.Transaction(ctx, func(ctx context.Context, c *pop.Connection) error {
for _, key := range keys.Keys {
out, err := json.Marshal(key)
if err != nil {
Expand Down Expand Up @@ -94,7 +94,7 @@ func (p *Persister) UpdateKey(ctx context.Context, set string, key *jose.JSONWeb
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.UpdateKey")
defer span.End()

return p.transaction(ctx, func(ctx context.Context, c *pop.Connection) error {
return p.Transaction(ctx, func(ctx context.Context, c *pop.Connection) error {
if err := p.DeleteKey(ctx, set, key.KeyID); err != nil {
return err
}
Expand All @@ -110,7 +110,7 @@ func (p *Persister) UpdateKeySet(ctx context.Context, set string, keySet *jose.J
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.UpdateKeySet")
defer span.End()

return p.transaction(ctx, func(ctx context.Context, c *pop.Connection) error {
return p.Transaction(ctx, func(ctx context.Context, c *pop.Connection) error {
if err := p.DeleteKeySet(ctx, set); err != nil {
return err
}
Expand Down
Loading