Skip to content

Commit

Permalink
fix: fix the OIDC token and refresh token issue for device flow
Browse files Browse the repository at this point in the history
  • Loading branch information
wood-push-melon committed Apr 11, 2024
1 parent d80cc46 commit 1f6ddcc
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 38 deletions.
1 change: 1 addition & 0 deletions fositex/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ var defaultFactories = []Factory{
compose.OIDCUserinfoVerifiableCredentialFactory,
compose.RFC8628DeviceFactory,
compose.RFC8628DeviceAuthorizationTokenFactory,
compose.OpenIDConnectDeviceFactory,

Check failure on line 68 in fositex/config.go

View workflow job for this annotation

GitHub Actions / Run end-to-end tests (postgres, --jwt)

undefined: compose.OpenIDConnectDeviceFactory

Check failure on line 68 in fositex/config.go

View workflow job for this annotation

GitHub Actions / Run end-to-end tests (memory)

undefined: compose.OpenIDConnectDeviceFactory

Check failure on line 68 in fositex/config.go

View workflow job for this annotation

GitHub Actions / Run tests and lints

undefined: compose.OpenIDConnectDeviceFactory) (typecheck)

Check failure on line 68 in fositex/config.go

View workflow job for this annotation

GitHub Actions / Run tests and lints

undefined: compose.OpenIDConnectDeviceFactory (typecheck)

Check failure on line 68 in fositex/config.go

View workflow job for this annotation

GitHub Actions / Run end-to-end tests (cockroach, --jwt)

undefined: compose.OpenIDConnectDeviceFactory

Check failure on line 68 in fositex/config.go

View workflow job for this annotation

GitHub Actions / Run end-to-end tests (cockroach)

undefined: compose.OpenIDConnectDeviceFactory

Check failure on line 68 in fositex/config.go

View workflow job for this annotation

GitHub Actions / Run end-to-end tests (memory, --jwt)

undefined: compose.OpenIDConnectDeviceFactory

Check failure on line 68 in fositex/config.go

View workflow job for this annotation

GitHub Actions / Run end-to-end tests (mysql)

undefined: compose.OpenIDConnectDeviceFactory

Check failure on line 68 in fositex/config.go

View workflow job for this annotation

GitHub Actions / Run HSM tests

undefined: compose.OpenIDConnectDeviceFactory

Check failure on line 68 in fositex/config.go

View workflow job for this annotation

GitHub Actions / Run HSM tests

undefined: compose.OpenIDConnectDeviceFactory

Check failure on line 68 in fositex/config.go

View workflow job for this annotation

GitHub Actions / Run HSM tests

undefined: compose.OpenIDConnectDeviceFactory

Check failure on line 68 in fositex/config.go

View workflow job for this annotation

GitHub Actions / Run HSM tests

undefined: compose.OpenIDConnectDeviceFactory

Check failure on line 68 in fositex/config.go

View workflow job for this annotation

GitHub Actions / Run HSM tests

undefined: compose.OpenIDConnectDeviceFactory

Check failure on line 68 in fositex/config.go

View workflow job for this annotation

GitHub Actions / Run HSM tests

undefined: compose.OpenIDConnectDeviceFactory

Check failure on line 68 in fositex/config.go

View workflow job for this annotation

GitHub Actions / Run HSM tests

undefined: compose.OpenIDConnectDeviceFactory

Check failure on line 68 in fositex/config.go

View workflow job for this annotation

GitHub Actions / Run end-to-end tests (postgres)

undefined: compose.OpenIDConnectDeviceFactory
}

func NewConfig(deps configDependencies) *Config {
Expand Down
25 changes: 16 additions & 9 deletions oauth2/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -721,40 +721,47 @@ func (h *Handler) getOidcUserInfo(w http.ResponseWriter, r *http.Request) {
func (h *Handler) performOAuth2DeviceVerificationFlow(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
ctx := r.Context()

consentSession, flow, err := h.r.ConsentStrategy().HandleOAuth2DeviceAuthorizationRequest(ctx, w, r)
consentSession, f, err := h.r.ConsentStrategy().HandleOAuth2DeviceAuthorizationRequest(ctx, w, r)
if errors.Is(err, consent.ErrAbortOAuth2Request) {
x.LogAudit(r, nil, h.r.AuditLogger())
// do nothing
return
} else if e := &(fosite.RFC6749Error{}); errors.As(err, &e) {
}

if e := &(fosite.RFC6749Error{}); errors.As(err, &e) {
x.LogAudit(r, err, h.r.AuditLogger())
h.r.Writer().WriteError(w, r, err)
return
} else if err != nil {
}

if err != nil {
x.LogError(r, err, h.r.Logger())
h.r.Writer().WriteError(w, r, err)
return
}

req := fosite.NewDeviceRequest()
req.Client = consentSession.ConsentRequest.Client
session, err := h.updateSessionWithRequest(ctx, consentSession, flow, r, req)
session, err := h.updateSessionWithRequest(ctx, consentSession, f, r, req)
if err != nil {
h.r.Writer().WriteError(w, r, err)
return
}

req.SetSession(session)
// We update the device_code session with the claims that the user gave consent for, this
// marks it as ready to be used for the token endpoint
err = h.r.OAuth2Storage().UpdateDeviceCodeSessionByRequestID(ctx, flow.DeviceCodeRequestID.String(), req)
// Update the device code session with
// - the claims for which the user gave consent
// - the granted scopes
// - the granted audiences
// This marks it as ready to be used for the token exchange endpoint.
err = h.r.OAuth2Storage().UpdateDeviceCodeSessionByRequestID(ctx, f.DeviceCodeRequestID.String(), req)
if err != nil {
x.LogError(r, err, h.r.Logger())
h.r.Writer().WriteError(w, r, err)
return
}

http.Redirect(w, r, urlx.SetQuery(h.c.DeviceDoneURL(ctx), url.Values{"consent_verifier": {string(flow.ConsentVerifier)}}).String(), http.StatusFound)
redirectURL := urlx.SetQuery(h.c.DeviceDoneURL(ctx), url.Values{"consent_verifier": {string(f.ConsentVerifier)}}).String()
http.Redirect(w, r, redirectURL, http.StatusFound)
}

// OAuth2 Device Flow
Expand Down
90 changes: 72 additions & 18 deletions oauth2/oauth2_device_code_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ import (
"testing"
"time"

"github.com/ory/fosite/token/jwt"
"github.com/pborman/uuid"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/oauth2"
Expand All @@ -27,22 +30,26 @@ func TestDeviceAuthRequest(t *testing.T) {
reg := internal.NewMockedRegistry(t, &contextx.Default{})
testhelpers.NewOAuth2Server(ctx, t, reg)

secret := uuid.New()
c := &client.Client{
ResponseTypes: []string{"id_token", "code", "token"},
ID: "device-client",
Secret: secret,
GrantTypes: []string{
string(fosite.GrantTypeDeviceCode),
},
Scope: "hydra offline openid",
Audience: []string{"https://api.ory.sh/"},
TokenEndpointAuthMethod: "none",
TokenEndpointAuthMethod: "client_secret_post",
}
require.NoError(t, reg.ClientManager().CreateClient(ctx, c))

oauthClient := &oauth2.Config{
ClientID: c.GetID(),
ClientID: c.GetID(),
ClientSecret: secret,
Endpoint: oauth2.Endpoint{
DeviceAuthURL: reg.Config().OAuth2DeviceAuthorisationURL(ctx).String(),
TokenURL: reg.Config().OAuth2TokenURL(ctx).String(),
AuthStyle: oauth2.AuthStyleInParams,
},
Scopes: strings.Split(c.Scope, " "),
}
Expand Down Expand Up @@ -71,7 +78,7 @@ func TestDeviceAuthRequest(t *testing.T) {
testCase.setUp()
}

resp, err := oauthClient.DeviceAuth(context.Background())
resp, err := oauthClient.DeviceAuth(context.Background(), []oauth2.AuthCodeOption{oauth2.SetAuthURLParam("client_secret", secret)}...)

if testCase.check != nil {
testCase.check(t, resp, err)
Expand All @@ -89,44 +96,52 @@ func TestDeviceTokenRequest(t *testing.T) {
reg := internal.NewMockedRegistry(t, &contextx.Default{})
testhelpers.NewOAuth2Server(ctx, t, reg)

secret := uuid.New()
c := &client.Client{
ID: "device-client",
Secret: secret,
GrantTypes: []string{
string(fosite.GrantTypeDeviceCode),
string(fosite.GrantTypeRefreshToken),
},
Scope: "hydra offline openid",
Audience: []string{"https://api.ory.sh/"},
TokenEndpointAuthMethod: "none",
Scope: "hydra offline openid",
Audience: []string{"https://api.ory.sh/"},
}
require.NoError(t, reg.ClientManager().CreateClient(ctx, c))

oauthClient := &oauth2.Config{
ClientID: c.GetID(),
ClientID: c.GetID(),
ClientSecret: secret,
Endpoint: oauth2.Endpoint{
DeviceAuthURL: reg.Config().OAuth2DeviceAuthorisationURL(ctx).String(),
TokenURL: reg.Config().OAuth2TokenURL(ctx).String(),
AuthStyle: oauth2.AuthStyleInHeader,
},
Scopes: strings.Split(c.Scope, " "),
}

var code, signature string
var err error
code, signature, err = reg.RFC8628HMACStrategy().GenerateDeviceCode(context.TODO())
require.NoError(t, err)

testCases := []struct {
description string
setUp func()
setUp func(signature string)
check func(t *testing.T, token *oauth2.Token, err error)
cleanUp func()
}{
{
description: "should pass",
setUp: func() {
description: "should pass with refresh token",
setUp: func(signature string) {
authreq := &fosite.DeviceRequest{
Request: fosite.Request{
Client: &fosite.DefaultClient{ID: c.GetID(), GrantTypes: []string{string(fosite.GrantTypeDeviceCode)}},
Client: &fosite.DefaultClient{
ID: c.GetID(),
GrantTypes: []string{string(fosite.GrantTypeDeviceCode)},
},
RequestedScope: []string{"hydra", "offline"},
GrantedScope: []string{"hydra", "offline"},
Session: &hydraoauth2.Session{
DefaultSession: &openid.DefaultSession{
Claims: &jwt.IDTokenClaims{
Subject: "hydra",
},
ExpiresAt: map[fosite.TokenType]time.Time{
fosite.DeviceCode: time.Now().Add(time.Hour).UTC(),
},
Expand All @@ -141,14 +156,53 @@ func TestDeviceTokenRequest(t *testing.T) {
},
check: func(t *testing.T, token *oauth2.Token, err error) {
assert.NotEmpty(t, token.AccessToken)
assert.NotEmpty(t, token.RefreshToken)
},
},
{
description: "should pass with ID token",
setUp: func(signature string) {
authreq := &fosite.DeviceRequest{
Request: fosite.Request{
Client: &fosite.DefaultClient{
ID: c.GetID(),
GrantTypes: []string{string(fosite.GrantTypeDeviceCode)},
},
RequestedScope: []string{"hydra", "offline", "openid"},
GrantedScope: []string{"hydra", "offline", "openid"},
Session: &hydraoauth2.Session{
DefaultSession: &openid.DefaultSession{
Claims: &jwt.IDTokenClaims{
Subject: "hydra",
},
ExpiresAt: map[fosite.TokenType]time.Time{
fosite.DeviceCode: time.Now().Add(time.Hour).UTC(),
},
},
BrowserFlowCompleted: true,
},
RequestedAt: time.Now(),
},
}

require.NoError(t, reg.OAuth2Storage().CreateDeviceCodeSession(context.TODO(), signature, authreq))
require.NoError(t, reg.OAuth2Storage().CreateOpenIDConnectSession(context.TODO(), signature, authreq))
},
check: func(t *testing.T, token *oauth2.Token, err error) {
assert.NotEmpty(t, token.AccessToken)
assert.NotEmpty(t, token.RefreshToken)
assert.NotEmpty(t, token.Extra("id_token"))
},
},
}

for _, testCase := range testCases {
t.Run("case="+testCase.description, func(t *testing.T) {
code, signature, err := reg.RFC8628HMACStrategy().GenerateDeviceCode(context.TODO())
require.NoError(t, err)

if testCase.setUp != nil {
testCase.setUp()
testCase.setUp(signature)
}

var token *oauth2.Token
Expand Down
24 changes: 13 additions & 11 deletions persistence/sql/persister_oauth2.go
Original file line number Diff line number Diff line change
Expand Up @@ -560,22 +560,24 @@ func (p *Persister) CreateDeviceCodeSession(ctx context.Context, signature strin
func (p *Persister) UpdateDeviceCodeSessionByRequestID(ctx context.Context, requestID string, requester fosite.Requester) (err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.UpdateDeviceCodeSessionByRequestID")
defer otelx.End(span, &err)

req, err := p.sqlSchemaFromRequest(ctx, requestID, requester, sqlTableDeviceCode)
if err != nil {
return
return err
}

/* #nosec G201 table is static */
return sqlcon.HandleError(
p.Connection(ctx).
RawQuery(
fmt.Sprintf("UPDATE %s SET session_data=? WHERE request_id=? AND nid = ?", OAuth2RequestSQL{Table: sqlTableDeviceCode}.TableName()),
req.Session,
requestID,
p.NetworkID(ctx),
).
Exec(),
stmt := fmt.Sprintf(
"UPDATE %s SET granted_scope=?, granted_audience=?, session_data=? WHERE request_id=? AND nid = ?",
OAuth2RequestSQL{Table: sqlTableDeviceCode}.TableName(),
)

/* #nosec G201 table is static */
err = p.Connection(ctx).RawQuery(stmt, req.GrantedScope, req.GrantedAudience, req.Session, requestID, p.NetworkID(ctx)).Exec()
if err != nil {
return sqlcon.HandleError(err)
}

return nil
}

// GetDeviceCodeSession returns a device code session from the database
Expand Down

0 comments on commit 1f6ddcc

Please sign in to comment.