diff --git a/fositex/config.go b/fositex/config.go index bec3770af61..49ce459b702 100644 --- a/fositex/config.go +++ b/fositex/config.go @@ -65,6 +65,7 @@ var defaultFactories = []Factory{ compose.OIDCUserinfoVerifiableCredentialFactory, compose.RFC8628DeviceFactory, compose.RFC8628DeviceAuthorizationTokenFactory, + compose.OpenIDConnectDeviceFactory, } func NewConfig(deps configDependencies) *Config { diff --git a/go.mod b/go.mod index d48747dfece..5534d5d0c6b 100644 --- a/go.mod +++ b/go.mod @@ -255,4 +255,4 @@ require ( gopkg.in/yaml.v3 v3.0.1 // indirect ) -replace github.com/ory/fosite => github.com/canonical/fosite v0.0.0-20240329132814-3be772246a38 +replace github.com/ory/fosite => github.com/canonical/fosite v0.0.0-20240412170332-7fe9b8979dd3 diff --git a/go.sum b/go.sum index b62810f1c3b..97eb49b7d2d 100644 --- a/go.sum +++ b/go.sum @@ -74,8 +74,8 @@ github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869 h1:DDGfHa7BWjL4Yn github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869/go.mod h1:Ekp36dRnpXw/yCqJaO+ZrUyxD+3VXMFFr56k5XYrpB4= github.com/bradleyjkemp/cupaloy/v2 v2.8.0 h1:any4BmKE+jGIaMpnU8YgH/I2LPiLBufr6oMMlVBbn9M= github.com/bradleyjkemp/cupaloy/v2 v2.8.0/go.mod h1:bm7JXdkRd4BHJk9HpwqAI8BoAY1lps46Enkdqw6aRX0= -github.com/canonical/fosite v0.0.0-20240329132814-3be772246a38 h1:tM/abV0wyvC6eekGfGIu2tyTemN3xGKhFpHFuN7wYH8= -github.com/canonical/fosite v0.0.0-20240329132814-3be772246a38/go.mod h1:G5iZOjyC42o5uZaZK4GQdrqQeLxWZ4NZpD3rDRYM0Mc= +github.com/canonical/fosite v0.0.0-20240412170332-7fe9b8979dd3 h1:ZDkf+uEuw7eOY/JcRUoncTbt+WWG0TwoIjJ8hHU5Uuw= +github.com/canonical/fosite v0.0.0-20240412170332-7fe9b8979dd3/go.mod h1:G5iZOjyC42o5uZaZK4GQdrqQeLxWZ4NZpD3rDRYM0Mc= github.com/cenkalti/backoff/v3 v3.2.2 h1:cfUAAO3yvKMYKPrvhDuHSwQnhZNk/RMHKdZqKTxfm6M= github.com/cenkalti/backoff/v3 v3.2.2/go.mod h1:cIeZDE3IrqwwJl6VUwCN6trj1oXrTS4rc0ij+ULvLYs= github.com/cenkalti/backoff/v4 v4.2.1 h1:y4OZtCnogmCPw98Zjyt5a6+QwPLGkiQsYW5oUqylYbM= diff --git a/oauth2/handler.go b/oauth2/handler.go index 462adc717f1..a3f27a340a0 100644 --- a/oauth2/handler.go +++ b/oauth2/handler.go @@ -721,16 +721,19 @@ func (h *Handler) getOidcUserInfo(w http.ResponseWriter, r *http.Request) { func (h *Handler) performOAuth2DeviceVerificationFlow(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { ctx := r.Context() - consentSession, flow, err := h.r.ConsentStrategy().HandleOAuth2DeviceAuthorizationRequest(ctx, w, r) + consentSession, f, err := h.r.ConsentStrategy().HandleOAuth2DeviceAuthorizationRequest(ctx, w, r) if errors.Is(err, consent.ErrAbortOAuth2Request) { x.LogAudit(r, nil, h.r.AuditLogger()) - // do nothing return - } else if e := &(fosite.RFC6749Error{}); errors.As(err, &e) { + } + + if e := &(fosite.RFC6749Error{}); errors.As(err, &e) { x.LogAudit(r, err, h.r.AuditLogger()) h.r.Writer().WriteError(w, r, err) return - } else if err != nil { + } + + if err != nil { x.LogError(r, err, h.r.Logger()) h.r.Writer().WriteError(w, r, err) return @@ -738,23 +741,38 @@ func (h *Handler) performOAuth2DeviceVerificationFlow(w http.ResponseWriter, r * req := fosite.NewDeviceRequest() req.Client = consentSession.ConsentRequest.Client - session, err := h.updateSessionWithRequest(ctx, consentSession, flow, r, req) + session, err := h.updateSessionWithRequest(ctx, consentSession, f, r, req) if err != nil { h.r.Writer().WriteError(w, r, err) return } req.SetSession(session) - // We update the device_code session with the claims that the user gave consent for, this - // marks it as ready to be used for the token endpoint - err = h.r.OAuth2Storage().UpdateDeviceCodeSessionByRequestID(ctx, flow.DeviceCodeRequestID.String(), req) + // Update the device code session with + // - the claims for which the user gave consent + // - the granted scopes + // - the granted audiences + // This marks it as ready to be used for the token exchange endpoint. + err = h.r.OAuth2Storage().UpdateDeviceCodeSessionByRequestID(ctx, f.DeviceCodeRequestID.String(), req) if err != nil { x.LogError(r, err, h.r.Logger()) h.r.Writer().WriteError(w, r, err) return } - http.Redirect(w, r, urlx.SetQuery(h.c.DeviceDoneURL(ctx), url.Values{"consent_verifier": {string(flow.ConsentVerifier)}}).String(), http.StatusFound) + // TODO evaluate if an OpenID Connect session is necessary for device flow. + // Update the OpenID Connect session if "openid" scope is granted + if req.GetGrantedScopes().Has("openid") { + err = h.r.OAuth2Storage().UpdateOpenIDConnectSessionByRequestID(ctx, f.DeviceCodeRequestID.String(), req) + if err != nil { + x.LogError(r, err, h.r.Logger()) + h.r.Writer().WriteError(w, r, err) + return + } + } + + redirectURL := urlx.SetQuery(h.c.DeviceDoneURL(ctx), url.Values{"consent_verifier": {string(f.ConsentVerifier)}}).String() + http.Redirect(w, r, redirectURL, http.StatusFound) } // OAuth2 Device Flow diff --git a/oauth2/oauth2_device_code_test.go b/oauth2/oauth2_device_code_test.go index 022b79f45f6..758b916be8a 100644 --- a/oauth2/oauth2_device_code_test.go +++ b/oauth2/oauth2_device_code_test.go @@ -9,6 +9,10 @@ import ( "testing" "time" + "github.com/pborman/uuid" + + "github.com/ory/fosite/token/jwt" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/oauth2" @@ -27,22 +31,26 @@ func TestDeviceAuthRequest(t *testing.T) { reg := internal.NewMockedRegistry(t, &contextx.Default{}) testhelpers.NewOAuth2Server(ctx, t, reg) + secret := uuid.New() c := &client.Client{ - ResponseTypes: []string{"id_token", "code", "token"}, + ID: "device-client", + Secret: secret, GrantTypes: []string{ string(fosite.GrantTypeDeviceCode), }, Scope: "hydra offline openid", Audience: []string{"https://api.ory.sh/"}, - TokenEndpointAuthMethod: "none", + TokenEndpointAuthMethod: "client_secret_post", } require.NoError(t, reg.ClientManager().CreateClient(ctx, c)) oauthClient := &oauth2.Config{ - ClientID: c.GetID(), + ClientID: c.GetID(), + ClientSecret: secret, Endpoint: oauth2.Endpoint{ DeviceAuthURL: reg.Config().OAuth2DeviceAuthorisationURL(ctx).String(), TokenURL: reg.Config().OAuth2TokenURL(ctx).String(), + AuthStyle: oauth2.AuthStyleInParams, }, Scopes: strings.Split(c.Scope, " "), } @@ -71,7 +79,7 @@ func TestDeviceAuthRequest(t *testing.T) { testCase.setUp() } - resp, err := oauthClient.DeviceAuth(context.Background()) + resp, err := oauthClient.DeviceAuth(context.Background(), []oauth2.AuthCodeOption{oauth2.SetAuthURLParam("client_secret", secret)}...) if testCase.check != nil { testCase.check(t, resp, err) @@ -89,44 +97,85 @@ func TestDeviceTokenRequest(t *testing.T) { reg := internal.NewMockedRegistry(t, &contextx.Default{}) testhelpers.NewOAuth2Server(ctx, t, reg) + secret := uuid.New() c := &client.Client{ + ID: "device-client", + Secret: secret, GrantTypes: []string{ string(fosite.GrantTypeDeviceCode), + string(fosite.GrantTypeRefreshToken), }, - Scope: "hydra offline openid", - Audience: []string{"https://api.ory.sh/"}, - TokenEndpointAuthMethod: "none", + Scope: "hydra offline openid", + Audience: []string{"https://api.ory.sh/"}, } require.NoError(t, reg.ClientManager().CreateClient(ctx, c)) oauthClient := &oauth2.Config{ - ClientID: c.GetID(), + ClientID: c.GetID(), + ClientSecret: secret, Endpoint: oauth2.Endpoint{ DeviceAuthURL: reg.Config().OAuth2DeviceAuthorisationURL(ctx).String(), TokenURL: reg.Config().OAuth2TokenURL(ctx).String(), + AuthStyle: oauth2.AuthStyleInHeader, }, Scopes: strings.Split(c.Scope, " "), } - var code, signature string - var err error - code, signature, err = reg.RFC8628HMACStrategy().GenerateDeviceCode(context.TODO()) - require.NoError(t, err) - testCases := []struct { description string - setUp func() + setUp func(signature string) check func(t *testing.T, token *oauth2.Token, err error) cleanUp func() }{ { - description: "should pass", - setUp: func() { + description: "should pass with refresh token", + setUp: func(signature string) { + authreq := &fosite.DeviceRequest{ + Request: fosite.Request{ + Client: &fosite.DefaultClient{ + ID: c.GetID(), + GrantTypes: []string{string(fosite.GrantTypeDeviceCode)}, + }, + RequestedScope: []string{"hydra", "offline"}, + GrantedScope: []string{"hydra", "offline"}, + Session: &hydraoauth2.Session{ + DefaultSession: &openid.DefaultSession{ + Claims: &jwt.IDTokenClaims{ + Subject: "hydra", + }, + ExpiresAt: map[fosite.TokenType]time.Time{ + fosite.DeviceCode: time.Now().Add(time.Hour).UTC(), + }, + }, + BrowserFlowCompleted: true, + }, + RequestedAt: time.Now(), + }, + } + + require.NoError(t, reg.OAuth2Storage().CreateDeviceCodeSession(context.TODO(), signature, authreq)) + }, + check: func(t *testing.T, token *oauth2.Token, err error) { + assert.NotEmpty(t, token.AccessToken) + assert.NotEmpty(t, token.RefreshToken) + }, + }, + { + description: "should pass with ID token", + setUp: func(signature string) { authreq := &fosite.DeviceRequest{ Request: fosite.Request{ - Client: &fosite.DefaultClient{ID: c.GetID(), GrantTypes: []string{string(fosite.GrantTypeDeviceCode)}}, + Client: &fosite.DefaultClient{ + ID: c.GetID(), + GrantTypes: []string{string(fosite.GrantTypeDeviceCode)}, + }, + RequestedScope: []string{"hydra", "offline", "openid"}, + GrantedScope: []string{"hydra", "offline", "openid"}, Session: &hydraoauth2.Session{ DefaultSession: &openid.DefaultSession{ + Claims: &jwt.IDTokenClaims{ + Subject: "hydra", + }, ExpiresAt: map[fosite.TokenType]time.Time{ fosite.DeviceCode: time.Now().Add(time.Hour).UTC(), }, @@ -138,17 +187,23 @@ func TestDeviceTokenRequest(t *testing.T) { } require.NoError(t, reg.OAuth2Storage().CreateDeviceCodeSession(context.TODO(), signature, authreq)) + require.NoError(t, reg.OAuth2Storage().CreateOpenIDConnectSession(context.TODO(), signature, authreq)) }, check: func(t *testing.T, token *oauth2.Token, err error) { assert.NotEmpty(t, token.AccessToken) + assert.NotEmpty(t, token.RefreshToken) + assert.NotEmpty(t, token.Extra("id_token")) }, }, } for _, testCase := range testCases { t.Run("case="+testCase.description, func(t *testing.T) { + code, signature, err := reg.RFC8628HMACStrategy().GenerateDeviceCode(context.TODO()) + require.NoError(t, err) + if testCase.setUp != nil { - testCase.setUp() + testCase.setUp(signature) } var token *oauth2.Token diff --git a/persistence/sql/persister_oauth2.go b/persistence/sql/persister_oauth2.go index e8422845211..808c48faeb2 100644 --- a/persistence/sql/persister_oauth2.go +++ b/persistence/sql/persister_oauth2.go @@ -446,6 +446,30 @@ func (p *Persister) CreateOpenIDConnectSession(ctx context.Context, signature st return p.createSession(ctx, signature, requester, sqlTableOpenID) } +// UpdateOpenIDConnectSessionByRequestID updates an OpenID session by requestID +func (p *Persister) UpdateOpenIDConnectSessionByRequestID(ctx context.Context, requestID string, requester fosite.Requester) (err error) { + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.UpdateOpenIDConnectSessionByRequestID") + defer otelx.End(span, &err) + + req, err := p.sqlSchemaFromRequest(ctx, requestID, requester, sqlTableOpenID) + if err != nil { + return err + } + + stmt := fmt.Sprintf( + "UPDATE %s SET granted_scope=?, granted_audience=?, session_data=? WHERE request_id=? AND nid = ?", + OAuth2RequestSQL{Table: sqlTableOpenID}.TableName(), + ) + + /* #nosec G201 table is static */ + err = p.Connection(ctx).RawQuery(stmt, req.GrantedScope, req.GrantedAudience, req.Session, requestID, p.NetworkID(ctx)).Exec() + if err != nil { + return sqlcon.HandleError(err) + } + + return nil +} + func (p *Persister) GetOpenIDConnectSession(ctx context.Context, signature string, requester fosite.Requester) (_ fosite.Requester, err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetOpenIDConnectSession") defer otelx.End(span, &err) @@ -560,22 +584,24 @@ func (p *Persister) CreateDeviceCodeSession(ctx context.Context, signature strin func (p *Persister) UpdateDeviceCodeSessionByRequestID(ctx context.Context, requestID string, requester fosite.Requester) (err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.UpdateDeviceCodeSessionByRequestID") defer otelx.End(span, &err) + req, err := p.sqlSchemaFromRequest(ctx, requestID, requester, sqlTableDeviceCode) if err != nil { - return + return err } - /* #nosec G201 table is static */ - return sqlcon.HandleError( - p.Connection(ctx). - RawQuery( - fmt.Sprintf("UPDATE %s SET session_data=? WHERE request_id=? AND nid = ?", OAuth2RequestSQL{Table: sqlTableDeviceCode}.TableName()), - req.Session, - requestID, - p.NetworkID(ctx), - ). - Exec(), + stmt := fmt.Sprintf( + "UPDATE %s SET granted_scope=?, granted_audience=?, session_data=? WHERE request_id=? AND nid = ?", + OAuth2RequestSQL{Table: sqlTableDeviceCode}.TableName(), ) + + /* #nosec G201 table is static */ + err = p.Connection(ctx).RawQuery(stmt, req.GrantedScope, req.GrantedAudience, req.Session, requestID, p.NetworkID(ctx)).Exec() + if err != nil { + return sqlcon.HandleError(err) + } + + return nil } // GetDeviceCodeSession returns a device code session from the database diff --git a/x/fosite_storer.go b/x/fosite_storer.go index 2d4d33f7de4..088a222baca 100644 --- a/x/fosite_storer.go +++ b/x/fosite_storer.go @@ -43,6 +43,8 @@ type FositeStorer interface { FlushInactiveRefreshTokens(ctx context.Context, notAfter time.Time, limit int, batchSize int) error + UpdateOpenIDConnectSessionByRequestID(ctx context.Context, requestID string, requester fosite.Requester) error + // DeleteOpenIDConnectSession deletes an OpenID Connect session. // This is duplicated from Ory Fosite to help against deprecation linting errors. DeleteOpenIDConnectSession(ctx context.Context, authorizeCode string) error