Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: fix the OIDC token and refresh token issue for device flow #10

Merged
merged 2 commits into from
Apr 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions fositex/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ var defaultFactories = []Factory{
compose.OIDCUserinfoVerifiableCredentialFactory,
compose.RFC8628DeviceFactory,
compose.RFC8628DeviceAuthorizationTokenFactory,
compose.OpenIDConnectDeviceFactory,
}

func NewConfig(deps configDependencies) *Config {
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -255,4 +255,4 @@ require (
gopkg.in/yaml.v3 v3.0.1 // indirect
)

replace github.com/ory/fosite => github.com/canonical/fosite v0.0.0-20240329132814-3be772246a38
replace github.com/ory/fosite => github.com/canonical/fosite v0.0.0-20240412170332-7fe9b8979dd3
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869 h1:DDGfHa7BWjL4Yn
github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869/go.mod h1:Ekp36dRnpXw/yCqJaO+ZrUyxD+3VXMFFr56k5XYrpB4=
github.com/bradleyjkemp/cupaloy/v2 v2.8.0 h1:any4BmKE+jGIaMpnU8YgH/I2LPiLBufr6oMMlVBbn9M=
github.com/bradleyjkemp/cupaloy/v2 v2.8.0/go.mod h1:bm7JXdkRd4BHJk9HpwqAI8BoAY1lps46Enkdqw6aRX0=
github.com/canonical/fosite v0.0.0-20240329132814-3be772246a38 h1:tM/abV0wyvC6eekGfGIu2tyTemN3xGKhFpHFuN7wYH8=
github.com/canonical/fosite v0.0.0-20240329132814-3be772246a38/go.mod h1:G5iZOjyC42o5uZaZK4GQdrqQeLxWZ4NZpD3rDRYM0Mc=
github.com/canonical/fosite v0.0.0-20240412170332-7fe9b8979dd3 h1:ZDkf+uEuw7eOY/JcRUoncTbt+WWG0TwoIjJ8hHU5Uuw=
github.com/canonical/fosite v0.0.0-20240412170332-7fe9b8979dd3/go.mod h1:G5iZOjyC42o5uZaZK4GQdrqQeLxWZ4NZpD3rDRYM0Mc=
github.com/cenkalti/backoff/v3 v3.2.2 h1:cfUAAO3yvKMYKPrvhDuHSwQnhZNk/RMHKdZqKTxfm6M=
github.com/cenkalti/backoff/v3 v3.2.2/go.mod h1:cIeZDE3IrqwwJl6VUwCN6trj1oXrTS4rc0ij+ULvLYs=
github.com/cenkalti/backoff/v4 v4.2.1 h1:y4OZtCnogmCPw98Zjyt5a6+QwPLGkiQsYW5oUqylYbM=
Expand Down
36 changes: 27 additions & 9 deletions oauth2/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -721,40 +721,58 @@ func (h *Handler) getOidcUserInfo(w http.ResponseWriter, r *http.Request) {
func (h *Handler) performOAuth2DeviceVerificationFlow(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
ctx := r.Context()

consentSession, flow, err := h.r.ConsentStrategy().HandleOAuth2DeviceAuthorizationRequest(ctx, w, r)
consentSession, f, err := h.r.ConsentStrategy().HandleOAuth2DeviceAuthorizationRequest(ctx, w, r)
if errors.Is(err, consent.ErrAbortOAuth2Request) {
x.LogAudit(r, nil, h.r.AuditLogger())
// do nothing
return
} else if e := &(fosite.RFC6749Error{}); errors.As(err, &e) {
}

if e := &(fosite.RFC6749Error{}); errors.As(err, &e) {
x.LogAudit(r, err, h.r.AuditLogger())
h.r.Writer().WriteError(w, r, err)
return
} else if err != nil {
}

if err != nil {
x.LogError(r, err, h.r.Logger())
h.r.Writer().WriteError(w, r, err)
return
}

req := fosite.NewDeviceRequest()
req.Client = consentSession.ConsentRequest.Client
session, err := h.updateSessionWithRequest(ctx, consentSession, flow, r, req)
session, err := h.updateSessionWithRequest(ctx, consentSession, f, r, req)
if err != nil {
h.r.Writer().WriteError(w, r, err)
return
}

req.SetSession(session)
// We update the device_code session with the claims that the user gave consent for, this
// marks it as ready to be used for the token endpoint
err = h.r.OAuth2Storage().UpdateDeviceCodeSessionByRequestID(ctx, flow.DeviceCodeRequestID.String(), req)
// Update the device code session with
// - the claims for which the user gave consent
// - the granted scopes
// - the granted audiences
// This marks it as ready to be used for the token exchange endpoint.
err = h.r.OAuth2Storage().UpdateDeviceCodeSessionByRequestID(ctx, f.DeviceCodeRequestID.String(), req)
if err != nil {
x.LogError(r, err, h.r.Logger())
h.r.Writer().WriteError(w, r, err)
return
}

http.Redirect(w, r, urlx.SetQuery(h.c.DeviceDoneURL(ctx), url.Values{"consent_verifier": {string(flow.ConsentVerifier)}}).String(), http.StatusFound)
// TODO evaluate if an OpenID Connect session is necessary for device flow.
// Update the OpenID Connect session if "openid" scope is granted
if req.GetGrantedScopes().Has("openid") {
wood-push-melon marked this conversation as resolved.
Show resolved Hide resolved
err = h.r.OAuth2Storage().UpdateOpenIDConnectSessionByRequestID(ctx, f.DeviceCodeRequestID.String(), req)
if err != nil {
x.LogError(r, err, h.r.Logger())
h.r.Writer().WriteError(w, r, err)
return
}
}

redirectURL := urlx.SetQuery(h.c.DeviceDoneURL(ctx), url.Values{"consent_verifier": {string(f.ConsentVerifier)}}).String()
http.Redirect(w, r, redirectURL, http.StatusFound)
}

// OAuth2 Device Flow
Expand Down
91 changes: 73 additions & 18 deletions oauth2/oauth2_device_code_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ import (
"testing"
"time"

"github.com/pborman/uuid"

"github.com/ory/fosite/token/jwt"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/oauth2"
Expand All @@ -27,22 +31,26 @@ func TestDeviceAuthRequest(t *testing.T) {
reg := internal.NewMockedRegistry(t, &contextx.Default{})
testhelpers.NewOAuth2Server(ctx, t, reg)

secret := uuid.New()
c := &client.Client{
ResponseTypes: []string{"id_token", "code", "token"},
ID: "device-client",
Secret: secret,
GrantTypes: []string{
string(fosite.GrantTypeDeviceCode),
},
Scope: "hydra offline openid",
Audience: []string{"https://api.ory.sh/"},
TokenEndpointAuthMethod: "none",
TokenEndpointAuthMethod: "client_secret_post",
}
require.NoError(t, reg.ClientManager().CreateClient(ctx, c))

oauthClient := &oauth2.Config{
ClientID: c.GetID(),
ClientID: c.GetID(),
ClientSecret: secret,
Endpoint: oauth2.Endpoint{
DeviceAuthURL: reg.Config().OAuth2DeviceAuthorisationURL(ctx).String(),
TokenURL: reg.Config().OAuth2TokenURL(ctx).String(),
AuthStyle: oauth2.AuthStyleInParams,
},
Scopes: strings.Split(c.Scope, " "),
}
Expand Down Expand Up @@ -71,7 +79,7 @@ func TestDeviceAuthRequest(t *testing.T) {
testCase.setUp()
}

resp, err := oauthClient.DeviceAuth(context.Background())
resp, err := oauthClient.DeviceAuth(context.Background(), []oauth2.AuthCodeOption{oauth2.SetAuthURLParam("client_secret", secret)}...)

if testCase.check != nil {
testCase.check(t, resp, err)
Expand All @@ -89,44 +97,85 @@ func TestDeviceTokenRequest(t *testing.T) {
reg := internal.NewMockedRegistry(t, &contextx.Default{})
testhelpers.NewOAuth2Server(ctx, t, reg)

secret := uuid.New()
c := &client.Client{
ID: "device-client",
Secret: secret,
GrantTypes: []string{
string(fosite.GrantTypeDeviceCode),
string(fosite.GrantTypeRefreshToken),
},
Scope: "hydra offline openid",
Audience: []string{"https://api.ory.sh/"},
TokenEndpointAuthMethod: "none",
Scope: "hydra offline openid",
Audience: []string{"https://api.ory.sh/"},
}
require.NoError(t, reg.ClientManager().CreateClient(ctx, c))

oauthClient := &oauth2.Config{
ClientID: c.GetID(),
ClientID: c.GetID(),
ClientSecret: secret,
Endpoint: oauth2.Endpoint{
DeviceAuthURL: reg.Config().OAuth2DeviceAuthorisationURL(ctx).String(),
TokenURL: reg.Config().OAuth2TokenURL(ctx).String(),
AuthStyle: oauth2.AuthStyleInHeader,
},
Scopes: strings.Split(c.Scope, " "),
}

var code, signature string
var err error
code, signature, err = reg.RFC8628HMACStrategy().GenerateDeviceCode(context.TODO())
require.NoError(t, err)

testCases := []struct {
description string
setUp func()
setUp func(signature string)
check func(t *testing.T, token *oauth2.Token, err error)
cleanUp func()
}{
{
description: "should pass",
setUp: func() {
description: "should pass with refresh token",
setUp: func(signature string) {
authreq := &fosite.DeviceRequest{
Request: fosite.Request{
Client: &fosite.DefaultClient{
ID: c.GetID(),
GrantTypes: []string{string(fosite.GrantTypeDeviceCode)},
},
RequestedScope: []string{"hydra", "offline"},
GrantedScope: []string{"hydra", "offline"},
Session: &hydraoauth2.Session{
DefaultSession: &openid.DefaultSession{
Claims: &jwt.IDTokenClaims{
Subject: "hydra",
},
ExpiresAt: map[fosite.TokenType]time.Time{
fosite.DeviceCode: time.Now().Add(time.Hour).UTC(),
},
},
BrowserFlowCompleted: true,
},
RequestedAt: time.Now(),
},
}

require.NoError(t, reg.OAuth2Storage().CreateDeviceCodeSession(context.TODO(), signature, authreq))
},
check: func(t *testing.T, token *oauth2.Token, err error) {
assert.NotEmpty(t, token.AccessToken)
assert.NotEmpty(t, token.RefreshToken)
},
},
{
description: "should pass with ID token",
setUp: func(signature string) {
authreq := &fosite.DeviceRequest{
Request: fosite.Request{
Client: &fosite.DefaultClient{ID: c.GetID(), GrantTypes: []string{string(fosite.GrantTypeDeviceCode)}},
Client: &fosite.DefaultClient{
ID: c.GetID(),
GrantTypes: []string{string(fosite.GrantTypeDeviceCode)},
},
RequestedScope: []string{"hydra", "offline", "openid"},
GrantedScope: []string{"hydra", "offline", "openid"},
Session: &hydraoauth2.Session{
DefaultSession: &openid.DefaultSession{
Claims: &jwt.IDTokenClaims{
Subject: "hydra",
},
ExpiresAt: map[fosite.TokenType]time.Time{
fosite.DeviceCode: time.Now().Add(time.Hour).UTC(),
},
Expand All @@ -138,17 +187,23 @@ func TestDeviceTokenRequest(t *testing.T) {
}

require.NoError(t, reg.OAuth2Storage().CreateDeviceCodeSession(context.TODO(), signature, authreq))
require.NoError(t, reg.OAuth2Storage().CreateOpenIDConnectSession(context.TODO(), signature, authreq))
},
check: func(t *testing.T, token *oauth2.Token, err error) {
assert.NotEmpty(t, token.AccessToken)
assert.NotEmpty(t, token.RefreshToken)
assert.NotEmpty(t, token.Extra("id_token"))
},
},
}

for _, testCase := range testCases {
t.Run("case="+testCase.description, func(t *testing.T) {
code, signature, err := reg.RFC8628HMACStrategy().GenerateDeviceCode(context.TODO())
require.NoError(t, err)

if testCase.setUp != nil {
testCase.setUp()
testCase.setUp(signature)
}

var token *oauth2.Token
Expand Down
48 changes: 37 additions & 11 deletions persistence/sql/persister_oauth2.go
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,30 @@ func (p *Persister) CreateOpenIDConnectSession(ctx context.Context, signature st
return p.createSession(ctx, signature, requester, sqlTableOpenID)
}

// UpdateOpenIDConnectSessionByRequestID updates an OpenID session by requestID
func (p *Persister) UpdateOpenIDConnectSessionByRequestID(ctx context.Context, requestID string, requester fosite.Requester) (err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.UpdateOpenIDConnectSessionByRequestID")
defer otelx.End(span, &err)

req, err := p.sqlSchemaFromRequest(ctx, requestID, requester, sqlTableOpenID)
if err != nil {
return err
}

stmt := fmt.Sprintf(
"UPDATE %s SET granted_scope=?, granted_audience=?, session_data=? WHERE request_id=? AND nid = ?",
OAuth2RequestSQL{Table: sqlTableOpenID}.TableName(),
)

/* #nosec G201 table is static */
err = p.Connection(ctx).RawQuery(stmt, req.GrantedScope, req.GrantedAudience, req.Session, requestID, p.NetworkID(ctx)).Exec()
if err != nil {
return sqlcon.HandleError(err)
}

return nil
}

func (p *Persister) GetOpenIDConnectSession(ctx context.Context, signature string, requester fosite.Requester) (_ fosite.Requester, err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetOpenIDConnectSession")
defer otelx.End(span, &err)
Expand Down Expand Up @@ -560,22 +584,24 @@ func (p *Persister) CreateDeviceCodeSession(ctx context.Context, signature strin
func (p *Persister) UpdateDeviceCodeSessionByRequestID(ctx context.Context, requestID string, requester fosite.Requester) (err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.UpdateDeviceCodeSessionByRequestID")
defer otelx.End(span, &err)

req, err := p.sqlSchemaFromRequest(ctx, requestID, requester, sqlTableDeviceCode)
if err != nil {
return
return err
}

/* #nosec G201 table is static */
return sqlcon.HandleError(
p.Connection(ctx).
RawQuery(
fmt.Sprintf("UPDATE %s SET session_data=? WHERE request_id=? AND nid = ?", OAuth2RequestSQL{Table: sqlTableDeviceCode}.TableName()),
req.Session,
requestID,
p.NetworkID(ctx),
).
Exec(),
stmt := fmt.Sprintf(
"UPDATE %s SET granted_scope=?, granted_audience=?, session_data=? WHERE request_id=? AND nid = ?",
OAuth2RequestSQL{Table: sqlTableDeviceCode}.TableName(),
)

/* #nosec G201 table is static */
err = p.Connection(ctx).RawQuery(stmt, req.GrantedScope, req.GrantedAudience, req.Session, requestID, p.NetworkID(ctx)).Exec()
if err != nil {
return sqlcon.HandleError(err)
}

return nil
}

// GetDeviceCodeSession returns a device code session from the database
Expand Down
2 changes: 2 additions & 0 deletions x/fosite_storer.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ type FositeStorer interface {

FlushInactiveRefreshTokens(ctx context.Context, notAfter time.Time, limit int, batchSize int) error

UpdateOpenIDConnectSessionByRequestID(ctx context.Context, requestID string, requester fosite.Requester) error

// DeleteOpenIDConnectSession deletes an OpenID Connect session.
// This is duplicated from Ory Fosite to help against deprecation linting errors.
DeleteOpenIDConnectSession(ctx context.Context, authorizeCode string) error
Expand Down
Loading