Skip to content

Commit

Permalink
fix: implement persister logic
Browse files Browse the repository at this point in the history
  • Loading branch information
nsklikas committed Mar 21, 2024
1 parent 7d12be3 commit 8c0238c
Showing 1 changed file with 26 additions and 10 deletions.
36 changes: 26 additions & 10 deletions persistence/sql/persister_consent.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ func (p *Persister) GetConsentRequest(ctx context.Context, challenge string) (*f
}

// CreateDeviceUserAuthRequest creates a new flow from a DeviceUserAuthRequest.
func (p *Persister) CreateDeviceUserAuthRequest(ctx context.Context, req *flow.DeviceUserAuthRequest) (*flow.DeviceFlow, error) {
func (p *Persister) CreateDeviceUserAuthRequest(ctx context.Context, req *flow.DeviceUserAuthRequest) (*flow.Flow, error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.CreateDeviceUserAuthRequest")
defer span.End()

Expand All @@ -237,7 +237,7 @@ func (p *Persister) GetDeviceUserAuthRequest(ctx context.Context, challenge stri
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetDeviceUserAuthRequest")
defer span.End()

f, err := flowctx.Decode[flow.DeviceFlow](ctx, p.r.FlowCipher(), challenge, flowctx.AsDeviceChallenge)
f, err := flowctx.Decode[flow.Flow](ctx, p.r.FlowCipher(), challenge, flowctx.AsDeviceChallenge)
if err != nil {
return nil, errorsx.WithStack(x.ErrNotFound.WithWrap(err))
}
Expand All @@ -253,7 +253,7 @@ func (p *Persister) GetDeviceUserAuthRequest(ctx context.Context, challenge stri
}

// HandleDeviceUserAuthRequest uses a HandledDeviceUserAuthRequest to update the flow and returns a DeviceUserAuthRequest.
func (p *Persister) HandleDeviceUserAuthRequest(ctx context.Context, f *flow.DeviceFlow, challenge string, r *flow.HandledDeviceUserAuthRequest) (*flow.DeviceUserAuthRequest, error) {
func (p *Persister) HandleDeviceUserAuthRequest(ctx context.Context, f *flow.Flow, challenge string, r *flow.HandledDeviceUserAuthRequest) (*flow.DeviceUserAuthRequest, error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.HandleDeviceUserAuthRequest")
defer span.End()

Expand All @@ -276,7 +276,7 @@ func (p *Persister) VerifyAndInvalidateDeviceUserAuthRequest(ctx context.Context
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.VerifyAndInvalidateDeviceUserAuthRequest")
defer span.End()

f, err := flowctx.Decode[flow.DeviceFlow](ctx, p.r.FlowCipher(), verifier, flowctx.AsDeviceVerifier)
f, err := flowctx.Decode[flow.Flow](ctx, p.r.FlowCipher(), verifier, flowctx.AsDeviceVerifier)
if err != nil {
return nil, errorsx.WithStack(fosite.ErrAccessDenied.WithHint("The device verifier has already been used, has not been granted, or is invalid."))
}
Expand All @@ -288,18 +288,34 @@ func (p *Persister) VerifyAndInvalidateDeviceUserAuthRequest(ctx context.Context
return nil, errorsx.WithStack(fosite.ErrInvalidRequest.WithDebug(err.Error()))
}

if err = p.Connection(ctx).Create(f); err != nil {
return nil, sqlcon.HandleError(err)
}

return f.GetHandledDeviceUserAuthRequest(), nil
}

func (p *Persister) CreateLoginRequest(ctx context.Context, req *flow.LoginRequest) (*flow.Flow, error) {
func (p *Persister) CreateLoginRequest(ctx context.Context, req *flow.LoginRequest, f *flow.Flow) (*flow.Flow, error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.CreateLoginRequest")
defer span.End()

f := flow.NewFlow(req)
if f == nil {
f = flow.NewFlow(req)
} else {
f.ID = req.ID
f.RequestedScope = req.RequestedScope
f.RequestedAudience = req.RequestedAudience
f.LoginSkip = req.Skip
f.Subject = req.Subject
f.OpenIDConnectContext = req.OpenIDConnectContext
f.Client = req.Client
f.ClientID = req.ClientID
f.RequestURL = req.RequestURL
f.SessionID = req.SessionID
f.LoginWasUsed = req.WasHandled
f.ForceSubjectIdentifier = req.ForceSubjectIdentifier
f.LoginVerifier = req.Verifier
f.LoginCSRF = req.CSRF
f.LoginAuthenticatedAt = req.AuthenticatedAt
f.RequestedAt = req.RequestedAt
f.State = flow.FlowStateLoginInitialized
}
nid := p.NetworkID(ctx)
if nid == uuid.Nil {
return nil, errorsx.WithStack(x.ErrNotFound)
Expand Down

0 comments on commit 8c0238c

Please sign in to comment.