From c4fe21c22c99fc8bbc9188692e656128d033a9f5 Mon Sep 17 00:00:00 2001 From: aeneasr <3372410+aeneasr@users.noreply.github.com> Date: Thu, 28 Nov 2024 11:54:13 +0100 Subject: [PATCH 1/3] refactor: refresh token rotation interfaces Previously, the refresh token handler was using a combination of delete/update storage primitives. This made optimizing and implementing the refresh token handling difficult. Going forward, the RefreshTokenStorage must implement `RotateRefreshToken`. Token creation continues to be separated. BREAKING CHANGES: Method `RevokeRefreshTokenMaybeGracePeriod` was removed from `handler/fosite/TokenRevocationStorage`. Interface `handler/fosite/RefreshTokenStorage` has changed: - `CreateRefreshToken` now takes an additional argument `accessSignature` to keep track of refresh/access token pairs: - A new method `RotateRefreshToken` was added, which revokes old refresh tokens and associated access tokens: ```patch // handler/fosite/storage.go type RefreshTokenStorage interface { - CreateRefreshTokenSession(ctx context.Context, signature string, request fosite.Requester) (err error) + CreateRefreshTokenSession(ctx context.Context, signature string, accessSignature string, request fosite.Requester) (err error) GetRefreshTokenSession(ctx context.Context, signature string, session fosite.Session) (request fosite.Requester, err error) DeleteRefreshTokenSession(ctx context.Context, signature string) (err error) + RotateRefreshToken(ctx context.Context, requestID string, refreshTokenSignature string) (err error) } ``` --- Makefile | 8 + handler/oauth2/flow_authorize_code_token.go | 2 +- .../oauth2/flow_authorize_code_token_test.go | 4 +- handler/oauth2/flow_client_credentials.go | 3 +- handler/oauth2/flow_refresh.go | 16 +- handler/oauth2/flow_refresh_test.go | 226 +++--------------- handler/oauth2/flow_resource_owner.go | 13 +- handler/oauth2/flow_resource_owner_test.go | 6 +- handler/oauth2/helper.go | 8 +- handler/oauth2/helper_test.go | 4 +- handler/oauth2/revocation_storage.go | 12 - handler/oauth2/storage.go | 4 +- handler/rfc7523/handler.go | 3 +- internal/access_request.go | 1 - internal/access_response.go | 1 - internal/access_token_storage.go | 1 - internal/access_token_strategy.go | 1 - internal/authorize_code_storage.go | 1 - internal/authorize_code_strategy.go | 1 - internal/authorize_handler.go | 1 - internal/authorize_request.go | 1 - internal/client.go | 28 --- internal/id_token_strategy.go | 1 - internal/introspector.go | 1 - internal/oauth2_client_storage.go | 1 - internal/oauth2_owner_storage.go | 22 +- internal/oauth2_revoke_storage.go | 21 +- internal/oauth2_storage.go | 23 +- internal/oauth2_strategy.go | 1 - internal/openid_id_token_storage.go | 1 - internal/pkce_storage_strategy.go | 1 - internal/refresh_token_strategy.go | 1 - internal/request.go | 1 - internal/revoke_handler.go | 1 - internal/storage.go | 1 - internal/token_handler.go | 1 - storage/memory.go | 21 +- 37 files changed, 127 insertions(+), 316 deletions(-) diff --git a/Makefile b/Makefile index baa915e96..dbbc657ab 100644 --- a/Makefile +++ b/Makefile @@ -1,3 +1,5 @@ +export PATH := .bin:${PATH} + format: .bin/goimports .bin/ory node_modules # formats the source code .bin/ory dev headers copyright --type=open-source .bin/goimports -w . @@ -18,6 +20,9 @@ test: # runs all tests .bin/licenses: Makefile curl https://raw.githubusercontent.com/ory/ci/master/licenses/install | sh +.bin/mockgen: + go build -o .bin/mockgen github.com/golang/mock/mockgen + .bin/ory: Makefile curl https://raw.githubusercontent.com/ory/meta/master/install.sh | bash -s -- -b .bin ory v0.1.48 touch .bin/ory @@ -26,4 +31,7 @@ node_modules: package-lock.json npm ci touch node_modules +gen: .bin/goimports .bin/mockgen # generates mocks + ./generate-mocks.sh + .DEFAULT_GOAL := help diff --git a/handler/oauth2/flow_authorize_code_token.go b/handler/oauth2/flow_authorize_code_token.go index dceb83001..3b808aabc 100644 --- a/handler/oauth2/flow_authorize_code_token.go +++ b/handler/oauth2/flow_authorize_code_token.go @@ -169,7 +169,7 @@ func (c *AuthorizeExplicitGrantHandler) PopulateTokenEndpointResponse(ctx contex } else if err = c.CoreStorage.CreateAccessTokenSession(ctx, accessSignature, requester.Sanitize([]string{})); err != nil { return errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error())) } else if refreshSignature != "" { - if err = c.CoreStorage.CreateRefreshTokenSession(ctx, refreshSignature, requester.Sanitize([]string{})); err != nil { + if err = c.CoreStorage.CreateRefreshTokenSession(ctx, refreshSignature, accessSignature, requester.Sanitize([]string{})); err != nil { return errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error())) } } diff --git a/handler/oauth2/flow_authorize_code_token_test.go b/handler/oauth2/flow_authorize_code_token_test.go index b90ca5246..94587111b 100644 --- a/handler/oauth2/flow_authorize_code_token_test.go +++ b/handler/oauth2/flow_authorize_code_token_test.go @@ -498,7 +498,7 @@ func TestAuthorizeCodeTransactional_HandleTokenEndpointRequest(t *testing.T) { Times(1) mockCoreStore. EXPECT(). - CreateRefreshTokenSession(propagatedContext, gomock.Any(), gomock.Any()). + CreateRefreshTokenSession(propagatedContext, gomock.Any(), gomock.Any(), gomock.Any()). Return(nil). Times(1) mockTransactional. @@ -627,7 +627,7 @@ func TestAuthorizeCodeTransactional_HandleTokenEndpointRequest(t *testing.T) { Times(1) mockCoreStore. EXPECT(). - CreateRefreshTokenSession(propagatedContext, gomock.Any(), gomock.Any()). + CreateRefreshTokenSession(propagatedContext, gomock.Any(), gomock.Any(), gomock.Any()). Return(nil). Times(1) mockTransactional. diff --git a/handler/oauth2/flow_client_credentials.go b/handler/oauth2/flow_client_credentials.go index 957d9b8e1..bcf3cdc11 100644 --- a/handler/oauth2/flow_client_credentials.go +++ b/handler/oauth2/flow_client_credentials.go @@ -64,7 +64,8 @@ func (c *ClientCredentialsGrantHandler) PopulateTokenEndpointResponse(ctx contex } atLifespan := fosite.GetEffectiveLifespan(request.GetClient(), fosite.GrantTypeClientCredentials, fosite.AccessToken, c.Config.GetAccessTokenLifespan(ctx)) - return c.IssueAccessToken(ctx, atLifespan, request, response) + _, err := c.IssueAccessToken(ctx, atLifespan, request, response) + return err } func (c *ClientCredentialsGrantHandler) CanSkipClientAuth(ctx context.Context, requester fosite.AccessRequester) bool { diff --git a/handler/oauth2/flow_refresh.go b/handler/oauth2/flow_refresh.go index 789f8d63e..297cd279c 100644 --- a/handler/oauth2/flow_refresh.go +++ b/handler/oauth2/flow_refresh.go @@ -69,7 +69,6 @@ func (c *RefreshTokenGrantHandler) HandleTokenEndpointRequest(ctx context.Contex scopeNames := strings.Join(c.Config.GetRefreshTokenScopes(ctx), " or ") hint := fmt.Sprintf("The OAuth 2.0 Client was not granted scope %s and may thus not perform the 'refresh_token' authorization grant.", scopeNames) return errorsx.WithStack(fosite.ErrScopeNotGranted.WithHint(hint)) - } // The authorization server MUST ... and ensure that the refresh token was issued to the authenticated client @@ -134,25 +133,18 @@ func (c *RefreshTokenGrantHandler) PopulateTokenEndpointResponse(ctx context.Con err = c.handleRefreshTokenEndpointStorageError(ctx, err) }() - ts, err := c.TokenRevocationStorage.GetRefreshTokenSession(ctx, signature, nil) - if err != nil { - return err - } else if err := c.TokenRevocationStorage.RevokeAccessToken(ctx, ts.GetID()); err != nil { - return err - } + storeReq := requester.Sanitize([]string{}) + storeReq.SetID(requester.GetID()) - if err := c.TokenRevocationStorage.RevokeRefreshTokenMaybeGracePeriod(ctx, ts.GetID(), signature); err != nil { + if err = c.TokenRevocationStorage.RotateRefreshToken(ctx, requester.GetID(), signature); err != nil { return err } - storeReq := requester.Sanitize([]string{}) - storeReq.SetID(ts.GetID()) - if err = c.TokenRevocationStorage.CreateAccessTokenSession(ctx, accessSignature, storeReq); err != nil { return err } - if err = c.TokenRevocationStorage.CreateRefreshTokenSession(ctx, refreshSignature, storeReq); err != nil { + if err = c.TokenRevocationStorage.CreateRefreshTokenSession(ctx, refreshSignature, accessSignature, storeReq); err != nil { return err } diff --git a/handler/oauth2/flow_refresh_test.go b/handler/oauth2/flow_refresh_test.go index f9b00526b..c9a016553 100644 --- a/handler/oauth2/flow_refresh_test.go +++ b/handler/oauth2/flow_refresh_test.go @@ -85,7 +85,7 @@ func TestRefreshFlow_HandleTokenEndpointRequest(t *testing.T) { require.NoError(t, err) areq.Form.Add("refresh_token", token) - err = store.CreateRefreshTokenSession(context.Background(), sig, &fosite.Request{ + err = store.CreateRefreshTokenSession(context.Background(), sig, "", &fosite.Request{ Client: &fosite.DefaultClient{ID: ""}, GrantedScope: []string{"offline"}, Session: sess, @@ -108,7 +108,7 @@ func TestRefreshFlow_HandleTokenEndpointRequest(t *testing.T) { require.NoError(t, err) areq.Form.Add("refresh_token", token) - err = store.CreateRefreshTokenSession(context.Background(), sig, &fosite.Request{ + err = store.CreateRefreshTokenSession(context.Background(), sig, "", &fosite.Request{ Client: areq.Client, GrantedScope: fosite.Arguments{"foo", "offline"}, RequestedScope: fosite.Arguments{"foo", "bar", "offline"}, @@ -133,7 +133,7 @@ func TestRefreshFlow_HandleTokenEndpointRequest(t *testing.T) { require.NoError(t, err) areq.Form.Add("refresh_token", token) - err = store.CreateRefreshTokenSession(context.Background(), sig, &fosite.Request{ + err = store.CreateRefreshTokenSession(context.Background(), sig, "", &fosite.Request{ Client: areq.Client, GrantedScope: fosite.Arguments{"foo", "offline"}, RequestedScope: fosite.Arguments{"foo", "offline"}, @@ -162,7 +162,7 @@ func TestRefreshFlow_HandleTokenEndpointRequest(t *testing.T) { orReqID := areq.GetID() + "_OR" areq.Form.Add("or_request_id", orReqID) - err = store.CreateRefreshTokenSession(context.Background(), sig, &fosite.Request{ + err = store.CreateRefreshTokenSession(context.Background(), sig, "", &fosite.Request{ ID: orReqID, Client: areq.Client, GrantedScope: fosite.Arguments{"foo", "offline"}, @@ -202,7 +202,7 @@ func TestRefreshFlow_HandleTokenEndpointRequest(t *testing.T) { require.NoError(t, err) areq.Form.Add("refresh_token", token) - err = store.CreateRefreshTokenSession(context.Background(), sig, &fosite.Request{ + err = store.CreateRefreshTokenSession(context.Background(), sig, "", &fosite.Request{ Client: areq.Client, GrantedScope: fosite.Arguments{"foo", "offline"}, RequestedScope: fosite.Arguments{"foo", "bar", "offline"}, @@ -236,7 +236,7 @@ func TestRefreshFlow_HandleTokenEndpointRequest(t *testing.T) { require.NoError(t, err) areq.Form.Add("refresh_token", token) - err = store.CreateRefreshTokenSession(context.Background(), sig, &fosite.Request{ + err = store.CreateRefreshTokenSession(context.Background(), sig, "", &fosite.Request{ Client: areq.Client, GrantedScope: fosite.Arguments{"foo"}, RequestedScope: fosite.Arguments{"foo", "bar"}, @@ -263,7 +263,7 @@ func TestRefreshFlow_HandleTokenEndpointRequest(t *testing.T) { require.NoError(t, err) areq.Form.Add("refresh_token", token) - err = store.CreateRefreshTokenSession(context.Background(), sig, &fosite.Request{ + err = store.CreateRefreshTokenSession(context.Background(), sig, "", &fosite.Request{ Client: areq.Client, GrantedScope: fosite.Arguments{"foo"}, RequestedScope: fosite.Arguments{"foo", "bar"}, @@ -305,7 +305,7 @@ func TestRefreshFlow_HandleTokenEndpointRequest(t *testing.T) { Form: url.Values{"foo": []string{"bar"}}, RequestedAt: time.Now().UTC().Add(-time.Hour).Round(time.Hour), } - err = store.CreateRefreshTokenSession(context.Background(), sig, req) + err = store.CreateRefreshTokenSession(context.Background(), sig, "", req) require.NoError(t, err) err = store.RevokeRefreshToken(context.Background(), req.ID) @@ -468,7 +468,7 @@ func TestRefreshFlow_PopulateTokenEndpointResponse(t *testing.T) { token, signature, err := strategy.GenerateRefreshToken(context.Background(), nil) require.NoError(t, err) - require.NoError(t, store.CreateRefreshTokenSession(context.Background(), signature, areq)) + require.NoError(t, store.CreateRefreshTokenSession(context.Background(), signature, "", areq)) areq.Form.Add("refresh_token", token) }, check: func(t *testing.T) { @@ -551,17 +551,7 @@ func TestRefreshFlowTransactional_PopulateTokenEndpointResponse(t *testing.T) { Times(1) mockRevocationStore. EXPECT(). - GetRefreshTokenSession(propagatedContext, gomock.Any(), nil). - Return(request, nil). - Times(1) - mockRevocationStore. - EXPECT(). - RevokeAccessToken(propagatedContext, gomock.Any()). - Return(nil). - Times(1) - mockRevocationStore. - EXPECT(). - RevokeRefreshTokenMaybeGracePeriod(propagatedContext, gomock.Any(), gomock.Any()). + RotateRefreshToken(propagatedContext, gomock.Any(), gomock.Any()). Return(nil). Times(1) mockRevocationStore. @@ -571,7 +561,7 @@ func TestRefreshFlowTransactional_PopulateTokenEndpointResponse(t *testing.T) { Times(1) mockRevocationStore. EXPECT(). - CreateRefreshTokenSession(propagatedContext, gomock.Any(), gomock.Any()). + CreateRefreshTokenSession(propagatedContext, gomock.Any(), gomock.Any(), gomock.Any()). Return(nil). Times(1) mockTransactional. @@ -581,51 +571,6 @@ func TestRefreshFlowTransactional_PopulateTokenEndpointResponse(t *testing.T) { Times(1) }, }, - { - description: "transaction should be rolled back if call to `GetRefreshTokenSession` results in an error", - setup: func() { - request.GrantTypes = fosite.Arguments{"refresh_token"} - mockTransactional. - EXPECT(). - BeginTX(propagatedContext). - Return(propagatedContext, nil). - Times(1) - mockRevocationStore. - EXPECT(). - GetRefreshTokenSession(propagatedContext, gomock.Any(), nil). - Return(nil, errors.New("Whoops, a nasty database error occurred!")). - Times(1) - mockTransactional. - EXPECT(). - Rollback(propagatedContext). - Return(nil). - Times(1) - }, - expectError: fosite.ErrServerError, - }, - { - description: "should result in a fosite.ErrInvalidRequest if `GetRefreshTokenSession` results in a " + - "fosite.ErrNotFound error", - setup: func() { - request.GrantTypes = fosite.Arguments{"refresh_token"} - mockTransactional. - EXPECT(). - BeginTX(propagatedContext). - Return(propagatedContext, nil). - Times(1) - mockRevocationStore. - EXPECT(). - GetRefreshTokenSession(propagatedContext, gomock.Any(), nil). - Return(nil, fosite.ErrNotFound). - Times(1) - mockTransactional. - EXPECT(). - Rollback(propagatedContext). - Return(nil). - Times(1) - }, - expectError: fosite.ErrInvalidRequest, - }, { description: "transaction should be rolled back if call to `RevokeAccessToken` results in an error", setup: func() { @@ -637,12 +582,7 @@ func TestRefreshFlowTransactional_PopulateTokenEndpointResponse(t *testing.T) { Times(1) mockRevocationStore. EXPECT(). - GetRefreshTokenSession(propagatedContext, gomock.Any(), nil). - Return(request, nil). - Times(1) - mockRevocationStore. - EXPECT(). - RevokeAccessToken(propagatedContext, gomock.Any()). + RotateRefreshToken(propagatedContext, gomock.Any(), gomock.Any()). Return(errors.New("Whoops, a nasty database error occurred!")). Times(1) mockTransactional. @@ -665,12 +605,7 @@ func TestRefreshFlowTransactional_PopulateTokenEndpointResponse(t *testing.T) { Times(1) mockRevocationStore. EXPECT(). - GetRefreshTokenSession(propagatedContext, gomock.Any(), nil). - Return(request, nil). - Times(1) - mockRevocationStore. - EXPECT(). - RevokeAccessToken(propagatedContext, gomock.Any()). + RotateRefreshToken(propagatedContext, gomock.Any(), gomock.Any()). Return(fosite.ErrSerializationFailure). Times(1) mockTransactional. @@ -682,8 +617,7 @@ func TestRefreshFlowTransactional_PopulateTokenEndpointResponse(t *testing.T) { expectError: fosite.ErrInvalidRequest, }, { - description: "should result in a fosite.ErrInactiveToken if call to `RevokeAccessToken` results in a " + - "fosite.ErrInvalidRequest error", + description: "transaction should be rolled back if call to `RotateRefreshToken` results in an error", setup: func() { request.GrantTypes = fosite.Arguments{"refresh_token"} mockTransactional. @@ -693,39 +627,7 @@ func TestRefreshFlowTransactional_PopulateTokenEndpointResponse(t *testing.T) { Times(1) mockRevocationStore. EXPECT(). - GetRefreshTokenSession(propagatedContext, gomock.Any(), nil). - Return(nil, fosite.ErrInactiveToken). - Times(1) - mockTransactional. - EXPECT(). - Rollback(propagatedContext). - Return(nil). - Times(1) - }, - expectError: fosite.ErrInvalidRequest, - }, - { - description: "transaction should be rolled back if call to `RevokeRefreshTokenMaybeGracePeriod` results in an error", - setup: func() { - request.GrantTypes = fosite.Arguments{"refresh_token"} - mockTransactional. - EXPECT(). - BeginTX(propagatedContext). - Return(propagatedContext, nil). - Times(1) - mockRevocationStore. - EXPECT(). - GetRefreshTokenSession(propagatedContext, gomock.Any(), nil). - Return(request, nil). - Times(1) - mockRevocationStore. - EXPECT(). - RevokeAccessToken(propagatedContext, gomock.Any()). - Return(nil). - Times(1) - mockRevocationStore. - EXPECT(). - RevokeRefreshTokenMaybeGracePeriod(propagatedContext, gomock.Any(), gomock.Any()). + RotateRefreshToken(propagatedContext, gomock.Any(), gomock.Any()). Return(errors.New("Whoops, a nasty database error occurred!")). Times(1) mockTransactional. @@ -737,7 +639,7 @@ func TestRefreshFlowTransactional_PopulateTokenEndpointResponse(t *testing.T) { expectError: fosite.ErrServerError, }, { - description: "should result in a fosite.ErrInvalidRequest if call to `RevokeRefreshTokenMaybeGracePeriod` results in a " + + description: "should result in a fosite.ErrInvalidRequest if call to `RotateRefreshToken` results in a " + "fosite.ErrSerializationFailure error", setup: func() { request.GrantTypes = fosite.Arguments{"refresh_token"} @@ -748,17 +650,7 @@ func TestRefreshFlowTransactional_PopulateTokenEndpointResponse(t *testing.T) { Times(1) mockRevocationStore. EXPECT(). - GetRefreshTokenSession(propagatedContext, gomock.Any(), nil). - Return(request, nil). - Times(1) - mockRevocationStore. - EXPECT(). - RevokeAccessToken(propagatedContext, gomock.Any()). - Return(nil). - Times(1) - mockRevocationStore. - EXPECT(). - RevokeRefreshTokenMaybeGracePeriod(propagatedContext, gomock.Any(), gomock.Any()). + RotateRefreshToken(propagatedContext, gomock.Any(), gomock.Any()). Return(fosite.ErrSerializationFailure). Times(1) mockTransactional. @@ -780,17 +672,7 @@ func TestRefreshFlowTransactional_PopulateTokenEndpointResponse(t *testing.T) { Times(1) mockRevocationStore. EXPECT(). - GetRefreshTokenSession(propagatedContext, gomock.Any(), nil). - Return(request, nil). - Times(1) - mockRevocationStore. - EXPECT(). - RevokeAccessToken(propagatedContext, gomock.Any()). - Return(nil). - Times(1) - mockRevocationStore. - EXPECT(). - RevokeRefreshTokenMaybeGracePeriod(propagatedContext, gomock.Any(), gomock.Any()). + RotateRefreshToken(propagatedContext, gomock.Any(), gomock.Any()). Return(nil). Times(1) mockRevocationStore. @@ -816,17 +698,7 @@ func TestRefreshFlowTransactional_PopulateTokenEndpointResponse(t *testing.T) { Times(1) mockRevocationStore. EXPECT(). - GetRefreshTokenSession(propagatedContext, gomock.Any(), nil). - Return(request, nil). - Times(1) - mockRevocationStore. - EXPECT(). - RevokeAccessToken(propagatedContext, gomock.Any()). - Return(nil). - Times(1) - mockRevocationStore. - EXPECT(). - RevokeRefreshTokenMaybeGracePeriod(propagatedContext, gomock.Any(), gomock.Any()). + RotateRefreshToken(propagatedContext, gomock.Any(), gomock.Any()). Return(nil). Times(1) mockRevocationStore. @@ -853,17 +725,7 @@ func TestRefreshFlowTransactional_PopulateTokenEndpointResponse(t *testing.T) { Times(1) mockRevocationStore. EXPECT(). - GetRefreshTokenSession(propagatedContext, gomock.Any(), nil). - Return(request, nil). - Times(1) - mockRevocationStore. - EXPECT(). - RevokeAccessToken(propagatedContext, gomock.Any()). - Return(nil). - Times(1) - mockRevocationStore. - EXPECT(). - RevokeRefreshTokenMaybeGracePeriod(propagatedContext, gomock.Any(), gomock.Any()). + RotateRefreshToken(propagatedContext, gomock.Any(), gomock.Any()). Return(nil). Times(1) mockRevocationStore. @@ -873,7 +735,7 @@ func TestRefreshFlowTransactional_PopulateTokenEndpointResponse(t *testing.T) { Times(1) mockRevocationStore. EXPECT(). - CreateRefreshTokenSession(propagatedContext, gomock.Any(), gomock.Any()). + CreateRefreshTokenSession(propagatedContext, gomock.Any(), gomock.Any(), gomock.Any()). Return(errors.New("Whoops, a nasty database error occurred!")). Times(1) mockTransactional. @@ -896,17 +758,7 @@ func TestRefreshFlowTransactional_PopulateTokenEndpointResponse(t *testing.T) { Times(1) mockRevocationStore. EXPECT(). - GetRefreshTokenSession(propagatedContext, gomock.Any(), nil). - Return(request, nil). - Times(1) - mockRevocationStore. - EXPECT(). - RevokeAccessToken(propagatedContext, gomock.Any()). - Return(nil). - Times(1) - mockRevocationStore. - EXPECT(). - RevokeRefreshTokenMaybeGracePeriod(propagatedContext, gomock.Any(), gomock.Any()). + RotateRefreshToken(propagatedContext, gomock.Any(), gomock.Any()). Return(nil). Times(1) mockRevocationStore. @@ -916,7 +768,7 @@ func TestRefreshFlowTransactional_PopulateTokenEndpointResponse(t *testing.T) { Times(1) mockRevocationStore. EXPECT(). - CreateRefreshTokenSession(propagatedContext, gomock.Any(), gomock.Any()). + CreateRefreshTokenSession(propagatedContext, gomock.Any(), gomock.Any(), gomock.Any()). Return(fosite.ErrSerializationFailure). Times(1) mockTransactional. @@ -950,8 +802,8 @@ func TestRefreshFlowTransactional_PopulateTokenEndpointResponse(t *testing.T) { Times(1) mockRevocationStore. EXPECT(). - GetRefreshTokenSession(propagatedContext, gomock.Any(), nil). - Return(nil, fosite.ErrNotFound). + RotateRefreshToken(propagatedContext, gomock.Any(), gomock.Any()). + Return(fosite.ErrNotFound). Times(1) mockTransactional. EXPECT(). @@ -972,17 +824,7 @@ func TestRefreshFlowTransactional_PopulateTokenEndpointResponse(t *testing.T) { Times(1) mockRevocationStore. EXPECT(). - GetRefreshTokenSession(propagatedContext, gomock.Any(), nil). - Return(request, nil). - Times(1) - mockRevocationStore. - EXPECT(). - RevokeAccessToken(propagatedContext, gomock.Any()). - Return(nil). - Times(1) - mockRevocationStore. - EXPECT(). - RevokeRefreshTokenMaybeGracePeriod(propagatedContext, gomock.Any(), gomock.Any()). + RotateRefreshToken(propagatedContext, gomock.Any(), gomock.Any()). Return(nil). Times(1) mockRevocationStore. @@ -992,7 +834,7 @@ func TestRefreshFlowTransactional_PopulateTokenEndpointResponse(t *testing.T) { Times(1) mockRevocationStore. EXPECT(). - CreateRefreshTokenSession(propagatedContext, gomock.Any(), gomock.Any()). + CreateRefreshTokenSession(propagatedContext, gomock.Any(), gomock.Any(), gomock.Any()). Return(nil). Times(1) mockTransactional. @@ -1020,17 +862,7 @@ func TestRefreshFlowTransactional_PopulateTokenEndpointResponse(t *testing.T) { Times(1) mockRevocationStore. EXPECT(). - GetRefreshTokenSession(propagatedContext, gomock.Any(), nil). - Return(request, nil). - Times(1) - mockRevocationStore. - EXPECT(). - RevokeAccessToken(propagatedContext, gomock.Any()). - Return(nil). - Times(1) - mockRevocationStore. - EXPECT(). - RevokeRefreshTokenMaybeGracePeriod(propagatedContext, gomock.Any(), gomock.Any()). + RotateRefreshToken(propagatedContext, gomock.Any(), gomock.Any()). Return(nil). Times(1) mockRevocationStore. @@ -1040,7 +872,7 @@ func TestRefreshFlowTransactional_PopulateTokenEndpointResponse(t *testing.T) { Times(1) mockRevocationStore. EXPECT(). - CreateRefreshTokenSession(propagatedContext, gomock.Any(), gomock.Any()). + CreateRefreshTokenSession(propagatedContext, gomock.Any(), gomock.Any(), gomock.Any()). Return(nil). Times(1) mockTransactional. diff --git a/handler/oauth2/flow_resource_owner.go b/handler/oauth2/flow_resource_owner.go index 8c1e04370..8474092c9 100644 --- a/handler/oauth2/flow_resource_owner.go +++ b/handler/oauth2/flow_resource_owner.go @@ -93,22 +93,23 @@ func (c *ResourceOwnerPasswordCredentialsGrantHandler) PopulateTokenEndpointResp return errorsx.WithStack(fosite.ErrUnknownRequest) } + atLifespan := fosite.GetEffectiveLifespan(requester.GetClient(), fosite.GrantTypePassword, fosite.AccessToken, c.Config.GetAccessTokenLifespan(ctx)) + accessTokenSignature, err := c.IssueAccessToken(ctx, atLifespan, requester, responder) + if err != nil { + return err + } + var refresh, refreshSignature string if len(c.Config.GetRefreshTokenScopes(ctx)) == 0 || requester.GetGrantedScopes().HasOneOf(c.Config.GetRefreshTokenScopes(ctx)...) { var err error refresh, refreshSignature, err = c.RefreshTokenStrategy.GenerateRefreshToken(ctx, requester) if err != nil { return errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error())) - } else if err := c.ResourceOwnerPasswordCredentialsGrantStorage.CreateRefreshTokenSession(ctx, refreshSignature, requester.Sanitize([]string{})); err != nil { + } else if err := c.ResourceOwnerPasswordCredentialsGrantStorage.CreateRefreshTokenSession(ctx, refreshSignature, accessTokenSignature, requester.Sanitize([]string{})); err != nil { return errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error())) } } - atLifespan := fosite.GetEffectiveLifespan(requester.GetClient(), fosite.GrantTypePassword, fosite.AccessToken, c.Config.GetAccessTokenLifespan(ctx)) - if err := c.IssueAccessToken(ctx, atLifespan, requester, responder); err != nil { - return err - } - if refresh != "" { responder.SetExtra("refresh_token", refresh) } diff --git a/handler/oauth2/flow_resource_owner_test.go b/handler/oauth2/flow_resource_owner_test.go index 6e8280acc..0bf201ed6 100644 --- a/handler/oauth2/flow_resource_owner_test.go +++ b/handler/oauth2/flow_resource_owner_test.go @@ -169,7 +169,7 @@ func TestResourceOwnerFlow_PopulateTokenEndpointResponse(t *testing.T) { areq.GrantTypes = fosite.Arguments{"password"} areq.GrantScope("offline") rtstr.EXPECT().GenerateRefreshToken(gomock.Any(), areq).Return(mockRT, "bar", nil) - store.EXPECT().CreateRefreshTokenSession(gomock.Any(), "bar", gomock.Eq(areq.Sanitize([]string{}))).Return(nil) + store.EXPECT().CreateRefreshTokenSession(gomock.Any(), "bar", "bar", gomock.Eq(areq.Sanitize([]string{}))).Return(nil) chgen.EXPECT().GenerateAccessToken(gomock.Any(), areq).Return(mockAT, "bar", nil) store.EXPECT().CreateAccessTokenSession(gomock.Any(), "bar", gomock.Eq(areq.Sanitize([]string{}))).Return(nil) }, @@ -182,10 +182,10 @@ func TestResourceOwnerFlow_PopulateTokenEndpointResponse(t *testing.T) { setup: func(config *fosite.Config) { config.RefreshTokenScopes = []string{} areq.GrantTypes = fosite.Arguments{"password"} - rtstr.EXPECT().GenerateRefreshToken(gomock.Any(), areq).Return(mockRT, "bar", nil) - store.EXPECT().CreateRefreshTokenSession(gomock.Any(), "bar", gomock.Eq(areq.Sanitize([]string{}))).Return(nil) chgen.EXPECT().GenerateAccessToken(gomock.Any(), areq).Return(mockAT, "bar", nil) store.EXPECT().CreateAccessTokenSession(gomock.Any(), "bar", gomock.Eq(areq.Sanitize([]string{}))).Return(nil) + rtstr.EXPECT().GenerateRefreshToken(gomock.Any(), areq).Return(mockRT, "bar", nil) + store.EXPECT().CreateRefreshTokenSession(gomock.Any(), "bar", "bar", gomock.Eq(areq.Sanitize([]string{}))).Return(nil) }, expect: func() { assert.NotNil(t, aresp.GetExtra("refresh_token"), "expected refresh token") diff --git a/handler/oauth2/helper.go b/handler/oauth2/helper.go index 436501a64..7edf805bc 100644 --- a/handler/oauth2/helper.go +++ b/handler/oauth2/helper.go @@ -21,19 +21,19 @@ type HandleHelper struct { Config HandleHelperConfigProvider } -func (h *HandleHelper) IssueAccessToken(ctx context.Context, defaultLifespan time.Duration, requester fosite.AccessRequester, responder fosite.AccessResponder) error { +func (h *HandleHelper) IssueAccessToken(ctx context.Context, defaultLifespan time.Duration, requester fosite.AccessRequester, responder fosite.AccessResponder) (signature string, err error) { token, signature, err := h.AccessTokenStrategy.GenerateAccessToken(ctx, requester) if err != nil { - return err + return "", err } else if err := h.AccessTokenStorage.CreateAccessTokenSession(ctx, signature, requester.Sanitize([]string{})); err != nil { - return err + return "", err } responder.SetAccessToken(token) responder.SetTokenType("bearer") responder.SetExpiresIn(getExpiresIn(requester, fosite.AccessToken, defaultLifespan, time.Now().UTC())) responder.SetScopes(requester.GetGrantedScopes()) - return nil + return signature, nil } func getExpiresIn(r fosite.Requester, key fosite.TokenType, defaultLifespan time.Duration, now time.Time) time.Duration { diff --git a/handler/oauth2/helper_test.go b/handler/oauth2/helper_test.go index 8f42aba99..fbce3ac20 100644 --- a/handler/oauth2/helper_test.go +++ b/handler/oauth2/helper_test.go @@ -70,10 +70,12 @@ func TestIssueAccessToken(t *testing.T) { }, } { c.mock() - err := helper.IssueAccessToken(context.Background(), helper.Config.GetAccessTokenLifespan(context.TODO()), areq, aresp) + signature, err := helper.IssueAccessToken(context.Background(), helper.Config.GetAccessTokenLifespan(context.TODO()), areq, aresp) require.Equal(t, err == nil, c.err == nil) if c.err != nil { assert.EqualError(t, err, c.err.Error(), "Case %d", k) + } else { + assert.NotEmpty(t, signature, "Case %d", k) } } } diff --git a/handler/oauth2/revocation_storage.go b/handler/oauth2/revocation_storage.go index 7d923bd56..33cf5b935 100644 --- a/handler/oauth2/revocation_storage.go +++ b/handler/oauth2/revocation_storage.go @@ -22,18 +22,6 @@ type TokenRevocationStorage interface { // grant (see Implementation Note). RevokeRefreshToken(ctx context.Context, requestID string) error - // RevokeRefreshTokenMaybeGracePeriod revokes a refresh token as specified in: - // https://tools.ietf.org/html/rfc7009#section-2.1 - // If the particular - // token is a refresh token and the authorization server supports the - // revocation of access tokens, then the authorization server SHOULD - // also invalidate all access tokens based on the same authorization - // grant (see Implementation Note). - // - // If the Refresh Token grace period is greater than zero in configuration the token - // will have its expiration time set as UTCNow + GracePeriod. - RevokeRefreshTokenMaybeGracePeriod(ctx context.Context, requestID string, signature string) error - // RevokeAccessToken revokes an access token as specified in: // https://tools.ietf.org/html/rfc7009#section-2.1 // If the token passed to the request diff --git a/handler/oauth2/storage.go b/handler/oauth2/storage.go index cc0b95831..fd9306624 100644 --- a/handler/oauth2/storage.go +++ b/handler/oauth2/storage.go @@ -42,9 +42,11 @@ type AccessTokenStorage interface { } type RefreshTokenStorage interface { - CreateRefreshTokenSession(ctx context.Context, signature string, request fosite.Requester) (err error) + CreateRefreshTokenSession(ctx context.Context, signature string, accessSignature string, request fosite.Requester) (err error) GetRefreshTokenSession(ctx context.Context, signature string, session fosite.Session) (request fosite.Requester, err error) DeleteRefreshTokenSession(ctx context.Context, signature string) (err error) + + RotateRefreshToken(ctx context.Context, requestID string, refreshTokenSignature string) (err error) } diff --git a/handler/rfc7523/handler.go b/handler/rfc7523/handler.go index 385549e94..4c7767e8a 100644 --- a/handler/rfc7523/handler.go +++ b/handler/rfc7523/handler.go @@ -124,7 +124,8 @@ func (c *Handler) PopulateTokenEndpointResponse(ctx context.Context, request fos } atLifespan := fosite.GetEffectiveLifespan(request.GetClient(), fosite.GrantTypeJWTBearer, fosite.AccessToken, c.Config.GetAccessTokenLifespan(ctx)) - return c.IssueAccessToken(ctx, atLifespan, request, response) + _, err := c.IssueAccessToken(ctx, atLifespan, request, response) + return err } func (c *Handler) CanSkipClientAuth(ctx context.Context, requester fosite.AccessRequester) bool { diff --git a/internal/access_request.go b/internal/access_request.go index fd05e5420..d2f42e326 100644 --- a/internal/access_request.go +++ b/internal/access_request.go @@ -13,7 +13,6 @@ import ( time "time" gomock "github.com/golang/mock/gomock" - fosite "github.com/ory/fosite" ) diff --git a/internal/access_response.go b/internal/access_response.go index 340c2476f..5b3f42f4c 100644 --- a/internal/access_response.go +++ b/internal/access_response.go @@ -12,7 +12,6 @@ import ( time "time" gomock "github.com/golang/mock/gomock" - fosite "github.com/ory/fosite" ) diff --git a/internal/access_token_storage.go b/internal/access_token_storage.go index 385424532..6d0a00512 100644 --- a/internal/access_token_storage.go +++ b/internal/access_token_storage.go @@ -12,7 +12,6 @@ import ( reflect "reflect" gomock "github.com/golang/mock/gomock" - fosite "github.com/ory/fosite" ) diff --git a/internal/access_token_strategy.go b/internal/access_token_strategy.go index d8c64c9bd..24e95187e 100644 --- a/internal/access_token_strategy.go +++ b/internal/access_token_strategy.go @@ -12,7 +12,6 @@ import ( reflect "reflect" gomock "github.com/golang/mock/gomock" - fosite "github.com/ory/fosite" ) diff --git a/internal/authorize_code_storage.go b/internal/authorize_code_storage.go index 1543b6825..4c80e3117 100644 --- a/internal/authorize_code_storage.go +++ b/internal/authorize_code_storage.go @@ -12,7 +12,6 @@ import ( reflect "reflect" gomock "github.com/golang/mock/gomock" - fosite "github.com/ory/fosite" ) diff --git a/internal/authorize_code_strategy.go b/internal/authorize_code_strategy.go index 44edd6440..cc0250a7a 100644 --- a/internal/authorize_code_strategy.go +++ b/internal/authorize_code_strategy.go @@ -12,7 +12,6 @@ import ( reflect "reflect" gomock "github.com/golang/mock/gomock" - fosite "github.com/ory/fosite" ) diff --git a/internal/authorize_handler.go b/internal/authorize_handler.go index d7420fed5..bbc919d8e 100644 --- a/internal/authorize_handler.go +++ b/internal/authorize_handler.go @@ -12,7 +12,6 @@ import ( reflect "reflect" gomock "github.com/golang/mock/gomock" - fosite "github.com/ory/fosite" ) diff --git a/internal/authorize_request.go b/internal/authorize_request.go index 5cceafff8..7e85bb057 100644 --- a/internal/authorize_request.go +++ b/internal/authorize_request.go @@ -13,7 +13,6 @@ import ( time "time" gomock "github.com/golang/mock/gomock" - fosite "github.com/ory/fosite" ) diff --git a/internal/client.go b/internal/client.go index 2923cf70b..e53cebf57 100644 --- a/internal/client.go +++ b/internal/client.go @@ -9,10 +9,8 @@ package internal import ( reflect "reflect" - time "time" gomock "github.com/golang/mock/gomock" - fosite "github.com/ory/fosite" ) @@ -137,20 +135,6 @@ func (mr *MockClientMockRecorder) GetScopes() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetScopes", reflect.TypeOf((*MockClient)(nil).GetScopes)) } -// GetTokenLifespan mocks base method. -func (m *MockClient) GetTokenLifespan(arg0 fosite.GrantType, arg1 fosite.TokenType, arg2 time.Duration) time.Duration { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetTokenLifespan", arg0, arg1, arg2) - ret0, _ := ret[0].(time.Duration) - return ret0 -} - -// GetTokenLifespan indicates an expected call of GetTokenLifespan. -func (mr *MockClientMockRecorder) GetTokenLifespan(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTokenLifespan", reflect.TypeOf((*MockClient)(nil).GetTokenLifespan), arg0, arg1, arg2) -} - // IsPublic mocks base method. func (m *MockClient) IsPublic() bool { m.ctrl.T.Helper() @@ -164,15 +148,3 @@ func (mr *MockClientMockRecorder) IsPublic() *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsPublic", reflect.TypeOf((*MockClient)(nil).IsPublic)) } - -// SetTokenLifespans mocks base method. -func (m *MockClient) SetTokenLifespans(arg0 map[fosite.TokenType]time.Duration) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "SetTokenLifespans", arg0) -} - -// SetTokenLifespans indicates an expected call of SetTokenLifespans. -func (mr *MockClientMockRecorder) SetTokenLifespans(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetTokenLifespans", reflect.TypeOf((*MockClient)(nil).SetTokenLifespans), arg0) -} diff --git a/internal/id_token_strategy.go b/internal/id_token_strategy.go index 330adeaee..d953d339b 100644 --- a/internal/id_token_strategy.go +++ b/internal/id_token_strategy.go @@ -13,7 +13,6 @@ import ( time "time" gomock "github.com/golang/mock/gomock" - fosite "github.com/ory/fosite" ) diff --git a/internal/introspector.go b/internal/introspector.go index 7b68fbf46..7122b93e2 100644 --- a/internal/introspector.go +++ b/internal/introspector.go @@ -12,7 +12,6 @@ import ( reflect "reflect" gomock "github.com/golang/mock/gomock" - fosite "github.com/ory/fosite" ) diff --git a/internal/oauth2_client_storage.go b/internal/oauth2_client_storage.go index d33dfd777..96eb41711 100644 --- a/internal/oauth2_client_storage.go +++ b/internal/oauth2_client_storage.go @@ -12,7 +12,6 @@ import ( reflect "reflect" gomock "github.com/golang/mock/gomock" - fosite "github.com/ory/fosite" ) diff --git a/internal/oauth2_owner_storage.go b/internal/oauth2_owner_storage.go index 7e79b1783..fca78900b 100644 --- a/internal/oauth2_owner_storage.go +++ b/internal/oauth2_owner_storage.go @@ -68,17 +68,17 @@ func (mr *MockResourceOwnerPasswordCredentialsGrantStorageMockRecorder) CreateAc } // CreateRefreshTokenSession mocks base method. -func (m *MockResourceOwnerPasswordCredentialsGrantStorage) CreateRefreshTokenSession(arg0 context.Context, arg1 string, arg2 fosite.Requester) error { +func (m *MockResourceOwnerPasswordCredentialsGrantStorage) CreateRefreshTokenSession(arg0 context.Context, arg1, arg2 string, arg3 fosite.Requester) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CreateRefreshTokenSession", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "CreateRefreshTokenSession", arg0, arg1, arg2, arg3) ret0, _ := ret[0].(error) return ret0 } // CreateRefreshTokenSession indicates an expected call of CreateRefreshTokenSession. -func (mr *MockResourceOwnerPasswordCredentialsGrantStorageMockRecorder) CreateRefreshTokenSession(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockResourceOwnerPasswordCredentialsGrantStorageMockRecorder) CreateRefreshTokenSession(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateRefreshTokenSession", reflect.TypeOf((*MockResourceOwnerPasswordCredentialsGrantStorage)(nil).CreateRefreshTokenSession), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateRefreshTokenSession", reflect.TypeOf((*MockResourceOwnerPasswordCredentialsGrantStorage)(nil).CreateRefreshTokenSession), arg0, arg1, arg2, arg3) } // DeleteAccessTokenSession mocks base method. @@ -138,3 +138,17 @@ func (mr *MockResourceOwnerPasswordCredentialsGrantStorageMockRecorder) GetRefre mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRefreshTokenSession", reflect.TypeOf((*MockResourceOwnerPasswordCredentialsGrantStorage)(nil).GetRefreshTokenSession), arg0, arg1, arg2) } + +// RotateRefreshToken mocks base method. +func (m *MockResourceOwnerPasswordCredentialsGrantStorage) RotateRefreshToken(arg0 context.Context, arg1, arg2 string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RotateRefreshToken", arg0, arg1, arg2) + ret0, _ := ret[0].(error) + return ret0 +} + +// RotateRefreshToken indicates an expected call of RotateRefreshToken. +func (mr *MockResourceOwnerPasswordCredentialsGrantStorageMockRecorder) RotateRefreshToken(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RotateRefreshToken", reflect.TypeOf((*MockResourceOwnerPasswordCredentialsGrantStorage)(nil).RotateRefreshToken), arg0, arg1, arg2) +} diff --git a/internal/oauth2_revoke_storage.go b/internal/oauth2_revoke_storage.go index 12580b4b5..9cdfad2fd 100644 --- a/internal/oauth2_revoke_storage.go +++ b/internal/oauth2_revoke_storage.go @@ -12,7 +12,6 @@ import ( reflect "reflect" gomock "github.com/golang/mock/gomock" - fosite "github.com/ory/fosite" ) @@ -54,17 +53,17 @@ func (mr *MockTokenRevocationStorageMockRecorder) CreateAccessTokenSession(arg0, } // CreateRefreshTokenSession mocks base method. -func (m *MockTokenRevocationStorage) CreateRefreshTokenSession(arg0 context.Context, arg1 string, arg2 fosite.Requester) error { +func (m *MockTokenRevocationStorage) CreateRefreshTokenSession(arg0 context.Context, arg1, arg2 string, arg3 fosite.Requester) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CreateRefreshTokenSession", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "CreateRefreshTokenSession", arg0, arg1, arg2, arg3) ret0, _ := ret[0].(error) return ret0 } // CreateRefreshTokenSession indicates an expected call of CreateRefreshTokenSession. -func (mr *MockTokenRevocationStorageMockRecorder) CreateRefreshTokenSession(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockTokenRevocationStorageMockRecorder) CreateRefreshTokenSession(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateRefreshTokenSession", reflect.TypeOf((*MockTokenRevocationStorage)(nil).CreateRefreshTokenSession), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateRefreshTokenSession", reflect.TypeOf((*MockTokenRevocationStorage)(nil).CreateRefreshTokenSession), arg0, arg1, arg2, arg3) } // DeleteAccessTokenSession mocks base method. @@ -153,16 +152,16 @@ func (mr *MockTokenRevocationStorageMockRecorder) RevokeRefreshToken(arg0, arg1 return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RevokeRefreshToken", reflect.TypeOf((*MockTokenRevocationStorage)(nil).RevokeRefreshToken), arg0, arg1) } -// RevokeRefreshTokenMaybeGracePeriod mocks base method. -func (m *MockTokenRevocationStorage) RevokeRefreshTokenMaybeGracePeriod(arg0 context.Context, arg1, arg2 string) error { +// RotateRefreshToken mocks base method. +func (m *MockTokenRevocationStorage) RotateRefreshToken(arg0 context.Context, arg1, arg2 string) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "RevokeRefreshTokenMaybeGracePeriod", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "RotateRefreshToken", arg0, arg1, arg2) ret0, _ := ret[0].(error) return ret0 } -// RevokeRefreshTokenMaybeGracePeriod indicates an expected call of RevokeRefreshTokenMaybeGracePeriod. -func (mr *MockTokenRevocationStorageMockRecorder) RevokeRefreshTokenMaybeGracePeriod(arg0, arg1, arg2 interface{}) *gomock.Call { +// RotateRefreshToken indicates an expected call of RotateRefreshToken. +func (mr *MockTokenRevocationStorageMockRecorder) RotateRefreshToken(arg0, arg1, arg2 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RevokeRefreshTokenMaybeGracePeriod", reflect.TypeOf((*MockTokenRevocationStorage)(nil).RevokeRefreshTokenMaybeGracePeriod), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RotateRefreshToken", reflect.TypeOf((*MockTokenRevocationStorage)(nil).RotateRefreshToken), arg0, arg1, arg2) } diff --git a/internal/oauth2_storage.go b/internal/oauth2_storage.go index a67815f7f..5524c6c7e 100644 --- a/internal/oauth2_storage.go +++ b/internal/oauth2_storage.go @@ -12,7 +12,6 @@ import ( reflect "reflect" gomock "github.com/golang/mock/gomock" - fosite "github.com/ory/fosite" ) @@ -68,17 +67,17 @@ func (mr *MockCoreStorageMockRecorder) CreateAuthorizeCodeSession(arg0, arg1, ar } // CreateRefreshTokenSession mocks base method. -func (m *MockCoreStorage) CreateRefreshTokenSession(arg0 context.Context, arg1 string, arg2 fosite.Requester) error { +func (m *MockCoreStorage) CreateRefreshTokenSession(arg0 context.Context, arg1, arg2 string, arg3 fosite.Requester) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CreateRefreshTokenSession", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "CreateRefreshTokenSession", arg0, arg1, arg2, arg3) ret0, _ := ret[0].(error) return ret0 } // CreateRefreshTokenSession indicates an expected call of CreateRefreshTokenSession. -func (mr *MockCoreStorageMockRecorder) CreateRefreshTokenSession(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockCoreStorageMockRecorder) CreateRefreshTokenSession(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateRefreshTokenSession", reflect.TypeOf((*MockCoreStorage)(nil).CreateRefreshTokenSession), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateRefreshTokenSession", reflect.TypeOf((*MockCoreStorage)(nil).CreateRefreshTokenSession), arg0, arg1, arg2, arg3) } // DeleteAccessTokenSession mocks base method. @@ -167,3 +166,17 @@ func (mr *MockCoreStorageMockRecorder) InvalidateAuthorizeCodeSession(arg0, arg1 mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InvalidateAuthorizeCodeSession", reflect.TypeOf((*MockCoreStorage)(nil).InvalidateAuthorizeCodeSession), arg0, arg1) } + +// RotateRefreshToken mocks base method. +func (m *MockCoreStorage) RotateRefreshToken(arg0 context.Context, arg1, arg2 string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RotateRefreshToken", arg0, arg1, arg2) + ret0, _ := ret[0].(error) + return ret0 +} + +// RotateRefreshToken indicates an expected call of RotateRefreshToken. +func (mr *MockCoreStorageMockRecorder) RotateRefreshToken(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RotateRefreshToken", reflect.TypeOf((*MockCoreStorage)(nil).RotateRefreshToken), arg0, arg1, arg2) +} diff --git a/internal/oauth2_strategy.go b/internal/oauth2_strategy.go index 539202052..aeb473fe5 100644 --- a/internal/oauth2_strategy.go +++ b/internal/oauth2_strategy.go @@ -12,7 +12,6 @@ import ( reflect "reflect" gomock "github.com/golang/mock/gomock" - fosite "github.com/ory/fosite" ) diff --git a/internal/openid_id_token_storage.go b/internal/openid_id_token_storage.go index bfcd0d628..2aa736b7c 100644 --- a/internal/openid_id_token_storage.go +++ b/internal/openid_id_token_storage.go @@ -12,7 +12,6 @@ import ( reflect "reflect" gomock "github.com/golang/mock/gomock" - fosite "github.com/ory/fosite" ) diff --git a/internal/pkce_storage_strategy.go b/internal/pkce_storage_strategy.go index 46de92ea0..bdaf4c624 100644 --- a/internal/pkce_storage_strategy.go +++ b/internal/pkce_storage_strategy.go @@ -12,7 +12,6 @@ import ( reflect "reflect" gomock "github.com/golang/mock/gomock" - fosite "github.com/ory/fosite" ) diff --git a/internal/refresh_token_strategy.go b/internal/refresh_token_strategy.go index 5338bfb71..2bafb3c90 100644 --- a/internal/refresh_token_strategy.go +++ b/internal/refresh_token_strategy.go @@ -12,7 +12,6 @@ import ( reflect "reflect" gomock "github.com/golang/mock/gomock" - fosite "github.com/ory/fosite" ) diff --git a/internal/request.go b/internal/request.go index c74969c95..762a43b38 100644 --- a/internal/request.go +++ b/internal/request.go @@ -13,7 +13,6 @@ import ( time "time" gomock "github.com/golang/mock/gomock" - fosite "github.com/ory/fosite" ) diff --git a/internal/revoke_handler.go b/internal/revoke_handler.go index 948178e19..0599be852 100644 --- a/internal/revoke_handler.go +++ b/internal/revoke_handler.go @@ -12,7 +12,6 @@ import ( reflect "reflect" gomock "github.com/golang/mock/gomock" - fosite "github.com/ory/fosite" ) diff --git a/internal/storage.go b/internal/storage.go index 14fa7c32c..44b199ac5 100644 --- a/internal/storage.go +++ b/internal/storage.go @@ -13,7 +13,6 @@ import ( time "time" gomock "github.com/golang/mock/gomock" - fosite "github.com/ory/fosite" ) diff --git a/internal/token_handler.go b/internal/token_handler.go index 9eb170678..519afa8aa 100644 --- a/internal/token_handler.go +++ b/internal/token_handler.go @@ -12,7 +12,6 @@ import ( reflect "reflect" gomock "github.com/golang/mock/gomock" - fosite "github.com/ory/fosite" ) diff --git a/storage/memory.go b/storage/memory.go index 5e88772a7..82fc28c85 100644 --- a/storage/memory.go +++ b/storage/memory.go @@ -89,7 +89,8 @@ type StoreAuthorizeCode struct { } type StoreRefreshToken struct { - active bool + active bool + accessTokenSignature string fosite.Requester } @@ -321,7 +322,7 @@ func (s *MemoryStore) DeleteAccessTokenSession(_ context.Context, signature stri return nil } -func (s *MemoryStore) CreateRefreshTokenSession(_ context.Context, signature string, req fosite.Requester) error { +func (s *MemoryStore) CreateRefreshTokenSession(_ context.Context, signature, accessTokenSignature string, req fosite.Requester) error { // We first lock refreshTokenRequestIDsMutex and then refreshTokensMutex because this is the same order // locking happens in RevokeRefreshToken and using the same order prevents deadlocks. s.refreshTokenRequestIDsMutex.Lock() @@ -329,7 +330,7 @@ func (s *MemoryStore) CreateRefreshTokenSession(_ context.Context, signature str s.refreshTokensMutex.Lock() defer s.refreshTokensMutex.Unlock() - s.RefreshTokens[signature] = StoreRefreshToken{active: true, Requester: req} + s.RefreshTokens[signature] = StoreRefreshToken{active: true, Requester: req, accessTokenSignature: accessTokenSignature} s.RefreshTokenRequestIDs[req.GetID()] = signature return nil } @@ -385,11 +386,6 @@ func (s *MemoryStore) RevokeRefreshToken(ctx context.Context, requestID string) return nil } -func (s *MemoryStore) RevokeRefreshTokenMaybeGracePeriod(ctx context.Context, requestID string, signature string) error { - // no configuration option is available; grace period is not available with memory store - return s.RevokeRefreshToken(ctx, requestID) -} - func (s *MemoryStore) RevokeAccessToken(ctx context.Context, requestID string) error { s.accessTokenRequestIDsMutex.RLock() defer s.accessTokenRequestIDsMutex.RUnlock() @@ -497,3 +493,12 @@ func (s *MemoryStore) DeletePARSession(ctx context.Context, requestURI string) ( delete(s.PARSessions, requestURI) return nil } + +func (s *MemoryStore) RotateRefreshToken(ctx context.Context, requestID string, refreshTokenSignature string) (err error) { + // Graceful token rotation can be implemented here but it's beyond the scope of this example. Check + // the Ory Hydra implementation for reference. + if err := s.RevokeRefreshToken(ctx, requestID); err != nil { + return err + } + return s.RevokeAccessToken(ctx, requestID) +} From b57570a26c3e6759e85d954e3e3fb3394088642e Mon Sep 17 00:00:00 2001 From: aeneasr <3372410+aeneasr@users.noreply.github.com> Date: Wed, 4 Dec 2024 09:34:51 +0100 Subject: [PATCH 2/3] feat: better refresh token debug-ability --- handler/oauth2/flow_refresh.go | 34 ++++++++++++++----------- handler/oauth2/flow_refresh_test.go | 4 +-- integration/refresh_token_grant_test.go | 4 +-- 3 files changed, 23 insertions(+), 19 deletions(-) diff --git a/handler/oauth2/flow_refresh.go b/handler/oauth2/flow_refresh.go index 297cd279c..987a403b0 100644 --- a/handler/oauth2/flow_refresh.go +++ b/handler/oauth2/flow_refresh.go @@ -51,18 +51,25 @@ func (c *RefreshTokenGrantHandler) HandleTokenEndpointRequest(ctx context.Contex return errorsx.WithStack(rErr) } - return errorsx.WithStack(fosite.ErrInactiveToken.WithWrap(err).WithDebug(err.Error())) + return fosite.ErrInvalidGrant.WithWrap(err). + WithHint("The refresh token was already used."). + WithDebugf("Refresh token re-use was detected. All related tokens have been revoked.") } else if errors.Is(err, fosite.ErrNotFound) { - return errorsx.WithStack(fosite.ErrInvalidGrant.WithWrap(err).WithDebugf("The refresh token has not been found: %s", err.Error())) + return fosite.ErrInvalidGrant.WithWrap(err). + WithHint("The refresh token is malformed or not valid."). + WithDebug("The refresh token can not be found.") } else if err != nil { - return errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error())) - } else if err := c.RefreshTokenStrategy.ValidateRefreshToken(ctx, originalRequest, refresh); err != nil { + return fosite.ErrInvalidGrant.WithWrap(err).WithDebug(err.Error()) + } + + if err := c.RefreshTokenStrategy.ValidateRefreshToken(ctx, originalRequest, refresh); err != nil { // The authorization server MUST ... validate the refresh token. // This needs to happen after store retrieval for the session to be hydrated properly if errors.Is(err, fosite.ErrTokenExpired) { - return errorsx.WithStack(fosite.ErrInvalidGrant.WithWrap(err).WithDebug(err.Error())) + return fosite.ErrInvalidGrant.WithWrap(err). + WithHint("The refresh token expired.") } - return errorsx.WithStack(fosite.ErrInvalidRequest.WithWrap(err).WithDebug(err.Error())) + return fosite.ErrInvalidRequest.WithWrap(err).WithDebug(err.Error()) } if !(len(c.Config.GetRefreshTokenScopes(ctx)) == 0 || originalRequest.GetGrantedScopes().HasOneOf(c.Config.GetRefreshTokenScopes(ctx)...)) { @@ -129,23 +136,20 @@ func (c *RefreshTokenGrantHandler) PopulateTokenEndpointResponse(ctx context.Con if err != nil { return errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error())) } - defer func() { - err = c.handleRefreshTokenEndpointStorageError(ctx, err) - }() storeReq := requester.Sanitize([]string{}) storeReq.SetID(requester.GetID()) if err = c.TokenRevocationStorage.RotateRefreshToken(ctx, requester.GetID(), signature); err != nil { - return err + return c.handleRefreshTokenEndpointStorageError(ctx, err) } if err = c.TokenRevocationStorage.CreateAccessTokenSession(ctx, accessSignature, storeReq); err != nil { - return err + return c.handleRefreshTokenEndpointStorageError(ctx, err) } if err = c.TokenRevocationStorage.CreateRefreshTokenSession(ctx, refreshSignature, accessSignature, storeReq); err != nil { - return err + return c.handleRefreshTokenEndpointStorageError(ctx, err) } responder.SetAccessToken(accessToken) @@ -156,7 +160,7 @@ func (c *RefreshTokenGrantHandler) PopulateTokenEndpointResponse(ctx context.Con responder.SetExtra("refresh_token", refreshToken) if err = storage.MaybeCommitTx(ctx, c.TokenRevocationStorage); err != nil { - return err + return c.handleRefreshTokenEndpointStorageError(ctx, err) } return nil @@ -214,14 +218,14 @@ func (c *RefreshTokenGrantHandler) handleRefreshTokenEndpointStorageError(ctx co return errorsx.WithStack(fosite.ErrInvalidRequest. WithDebugf(storageErr.Error()). WithWrap(storageErr). - WithHint("Failed to refresh token because of multiple concurrent requests using the same token which is not allowed.")) + WithHint("Failed to refresh token because of multiple concurrent requests using the same token. Please retry the request.")) } if errors.Is(storageErr, fosite.ErrNotFound) || errors.Is(storageErr, fosite.ErrInactiveToken) { return errorsx.WithStack(fosite.ErrInvalidRequest. WithDebugf(storageErr.Error()). WithWrap(storageErr). - WithHint("Failed to refresh token because of multiple concurrent requests using the same token which is not allowed.")) + WithHint("Failed to refresh token. Please retry the request.")) } return errorsx.WithStack(fosite.ErrServerError.WithWrap(storageErr).WithDebug(storageErr.Error())) diff --git a/handler/oauth2/flow_refresh_test.go b/handler/oauth2/flow_refresh_test.go index c9a016553..6abf80398 100644 --- a/handler/oauth2/flow_refresh_test.go +++ b/handler/oauth2/flow_refresh_test.go @@ -311,7 +311,7 @@ func TestRefreshFlow_HandleTokenEndpointRequest(t *testing.T) { err = store.RevokeRefreshToken(context.Background(), req.ID) require.NoError(t, err) }, - expectErr: fosite.ErrInactiveToken, + expectErr: fosite.ErrInvalidGrant, }, } { t.Run("case="+c.description, func(t *testing.T) { @@ -403,7 +403,7 @@ func TestRefreshFlowTransactional_HandleTokenEndpointRequest(t *testing.T) { Return(nil). Times(1) }, - expectError: fosite.ErrInactiveToken, + expectError: fosite.ErrInvalidGrant, }, } { t.Run(fmt.Sprintf("scenario=%s", testCase.description), func(t *testing.T) { diff --git a/integration/refresh_token_grant_test.go b/integration/refresh_token_grant_test.go index e30b2bdfc..1bfdc1d53 100644 --- a/integration/refresh_token_grant_test.go +++ b/integration/refresh_token_grant_test.go @@ -200,13 +200,13 @@ func TestRefreshTokenFlow(t *testing.T) { tokenSource := oauthClient.TokenSource(context.Background(), original) _, err := tokenSource.Token() require.Error(t, err) - require.Equal(t, http.StatusUnauthorized, err.(*oauth2.RetrieveError).Response.StatusCode) + require.Equal(t, http.StatusBadRequest, err.(*oauth2.RetrieveError).Response.StatusCode) refreshed.Expiry = refreshed.Expiry.Add(-time.Hour * 24) tokenSource = oauthClient.TokenSource(context.Background(), refreshed) _, err = tokenSource.Token() require.Error(t, err) - require.Equal(t, http.StatusUnauthorized, err.(*oauth2.RetrieveError).Response.StatusCode) + require.Equal(t, http.StatusBadRequest, err.(*oauth2.RetrieveError).Response.StatusCode) }, }, } { From 57cf545ec1a2969ff1e24471899cba504bf95a14 Mon Sep 17 00:00:00 2001 From: aeneasr <3372410+aeneasr@users.noreply.github.com> Date: Wed, 4 Dec 2024 22:40:45 +0100 Subject: [PATCH 3/3] ci: pin hydra version --- .github/workflows/oidc-conformity.yml | 2 +- handler/oauth2/flow_refresh.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/oidc-conformity.yml b/.github/workflows/oidc-conformity.yml index 6994fb5e0..79e5d11ea 100644 --- a/.github/workflows/oidc-conformity.yml +++ b/.github/workflows/oidc-conformity.yml @@ -14,7 +14,7 @@ jobs: with: fetch-depth: 2 repository: ory/hydra - ref: 2866a0499d02341ed0603601cfe4e63b24506fcb + ref: a35e78e364a26c4f87f37d9f545ef10b3ffa468a - uses: actions/setup-go@v2 with: go-version: "1.21" diff --git a/handler/oauth2/flow_refresh.go b/handler/oauth2/flow_refresh.go index 987a403b0..bffddad64 100644 --- a/handler/oauth2/flow_refresh.go +++ b/handler/oauth2/flow_refresh.go @@ -59,7 +59,7 @@ func (c *RefreshTokenGrantHandler) HandleTokenEndpointRequest(ctx context.Contex WithHint("The refresh token is malformed or not valid."). WithDebug("The refresh token can not be found.") } else if err != nil { - return fosite.ErrInvalidGrant.WithWrap(err).WithDebug(err.Error()) + return fosite.ErrServerError.WithWrap(err).WithDebug(err.Error()) } if err := c.RefreshTokenStrategy.ValidateRefreshToken(ctx, originalRequest, refresh); err != nil {