From 17fb0ffde0931b47beba63fdf2c258bf74ed2463 Mon Sep 17 00:00:00 2001 From: Utkarsh Saxena Date: Wed, 14 Dec 2022 02:07:13 +0900 Subject: [PATCH 1/3] add(handler): token exchange rfc8693 in impersonation mode --- handler/rfc8693/handler.go | 268 ++++++++++++++++ handler/rfc8693/handler_test.go | 354 +++++++++++++++++++++ handler/rfc8693/storage.go | 28 ++ handler/rfc8693/strategy.go | 19 ++ internal/oauth2_token_exchange_storage.go | 85 +++++ internal/oauth2_token_exchange_strategy.go | 52 +++ oauth2.go | 3 +- 7 files changed, 808 insertions(+), 1 deletion(-) create mode 100644 handler/rfc8693/handler.go create mode 100644 handler/rfc8693/handler_test.go create mode 100644 handler/rfc8693/storage.go create mode 100644 handler/rfc8693/strategy.go create mode 100644 internal/oauth2_token_exchange_storage.go create mode 100644 internal/oauth2_token_exchange_strategy.go diff --git a/handler/rfc8693/handler.go b/handler/rfc8693/handler.go new file mode 100644 index 000000000..937c9a2d4 --- /dev/null +++ b/handler/rfc8693/handler.go @@ -0,0 +1,268 @@ +// Copyright © 2022 Ory Corp +// SPDX-License-Identifier: Apache-2.0 + +package rfc8693 + +import ( + "context" + "errors" + "fmt" + "time" + + "github.com/ory/fosite" + "github.com/ory/fosite/compose" + "github.com/ory/fosite/handler/oauth2" + "github.com/ory/fosite/token/jwt" + "github.com/ory/x/errorsx" +) + +// #nosec G101 +const ( + tokenTypeIDToken = "urn:ietf:params:oauth:token-type:id_token" + tokenTypeAT = "urn:ietf:params:oauth:token-type:access_token" +) + +func TokenExchangeGrantFactory(config *compose.CommonStrategy, storage, strategy interface{}) interface{} { + return nil +} + +type Handler struct { + Storage RFC8693Storage + Strategy ClientAuthenticationStrategy + ScopeStrategy fosite.ScopeStrategy + AudienceMatchingStrategy fosite.AudienceMatchingStrategy + RefreshTokenStrategy oauth2.RefreshTokenStrategy + RefreshTokenStorage oauth2.RefreshTokenStorage + fosite.RefreshTokenScopesProvider + + *oauth2.HandleHelper +} + +type tokenExchangeParams struct { + subjectToken string + subjectTokenType string +} + +func parseRequestParameter(requester fosite.AccessRequester) (*tokenExchangeParams, error) { + form := requester.GetRequestForm() + + // From https://tools.ietf.org/html/rfc8693#section-2.1: + // + // subject_token + // REQUIRED. A security token that represents the identity of the + // party on behalf of whom the request is being made. Typically, the + // subject of this token will be the subject of the security token + // issued in response to the request. + subjectToken := form.Get("subject_token") + if subjectToken == "" { + return nil, errorsx.WithStack(fosite.ErrInvalidRequest.WithHintf("subject_token is missing")) + } + + // From https://tools.ietf.org/html/rfc8693#section-2.1: + // + // subject_token_type + // REQUIRED. An identifier, that indicates the type of the + // security token in the "subject_token" parameter. + subjectTokenType := form.Get("subject_token_type") + switch subjectTokenType { + case tokenTypeIDToken, tokenTypeAT: + default: + return nil, errorsx.WithStack(fosite.ErrInvalidRequest.WithHintf("unsupported or missing subject_token_type %s", subjectTokenType)) + } + + // From https://tools.ietf.org/html/rfc8693#section-2.1: + // + // requested_token_type + // OPTIONAL. An identifier, for the type of the requested security token. + // If the requested type is unspecified, + // the issued token type is at the discretion of the authorization server and + // may be dictated by knowledge of the requirements of the service or + // resource indicated by the resource or audience parameter. + requestedTokenType := form.Get("requested_token_type") + switch requestedTokenType { + case tokenTypeAT, "": + default: + return nil, errorsx.WithStack(fosite.ErrInvalidRequest.WithHintf("unsupported requested_token_type %s", requestedTokenType)) + } + + // From https://tools.ietf.org/html/rfc8693#section-2.1: + // + // actor_token + // OPTIONAL . A security token that represents the identity of the acting party. + // Typically, this will be the party that is authorized to use the requested security + // token and act on behalf of the subject. + actorToken := form.Get("actor_token") + if actorToken != "" { + return nil, errorsx.WithStack(fosite.ErrInvalidRequest.WithHintf("'actor_token' was provided but delegation is currently not supported.")) + } + + // From https://tools.ietf.org/html/rfc8693#section-2.1: + // + // actor_token_type + // An identifier, as described in Section 3, that indicates the type of the security token + // in the actor_token parameter. This is REQUIRED when the actor_token parameter is present + // in the request but MUST NOT be included otherwise. + actorTokenType := form.Get("actor_token_type") + if actorTokenType != "" { + return nil, errorsx.WithStack(fosite.ErrInvalidRequest.WithHintf("'actor_token_type' was provided but delegation is currently not supported.")) + } + + return &tokenExchangeParams{ + subjectToken: subjectToken, + subjectTokenType: subjectTokenType, + }, nil +} + +func (c *Handler) HandleTokenEndpointRequest(ctx context.Context, requester fosite.AccessRequester) error { + if !c.CanHandleTokenEndpointRequest(requester) { + return errorsx.WithStack(fosite.ErrUnknownRequest) + } + + client := requester.GetClient() + if client.GetID() == "" { + return errorsx.WithStack(fosite.ErrUnauthorizedClient.WithHint("unauthenticated client")) + } + + // Check whether client is allowed to use token exchange. + if !client.GetGrantTypes().Has(string(fosite.GrantTypeTokenExchange)) { + return errorsx.WithStack(fosite.ErrUnauthorizedClient.WithHint("the client is not allowed to use token-exchange")) + } + + // Get request parameter related token exchange. + params, err := parseRequestParameter(requester) + if err != nil { + return err + } + + // Check and grant scope. + for _, scope := range requester.GetRequestedScopes() { + if !c.ScopeStrategy(client.GetScopes(), scope) { + return errorsx.WithStack(fosite.ErrInvalidScope.WithHintf("The OAuth 2.0 Client is not allowed to request scope '%s'.", scope)) + } + requester.GrantScope(scope) + } + + // Check and grant audience. + if err := c.AudienceMatchingStrategy(client.GetAudience(), requester.GetRequestedAudience()); err != nil { + return errorsx.WithStack(fosite.ErrInvalidRequest.WithHintf("audience not match: %v", err)) + } + for _, audience := range requester.GetRequestedAudience() { + requester.GrantAudience(audience) + } + + // Verify subject token. + switch params.subjectTokenType { + case tokenTypeIDToken: + claims := jwt.MapClaims{} + if _, err := jwt.ParseWithClaims(params.subjectToken, claims, c.keyFunc(ctx)); err != nil { + return errorsx.WithStack(fosite.ErrInvalidRequest.WithHintf("failed to verify JWT: %v", err)) + } + subject, err := c.Storage.GetImpersonateSubject(ctx, claims, requester) + if err != nil { + return errorsx.WithStack(fosite.ErrInvalidRequest.WithHintf("not allowed to token exchange by jwt: %v", err)) + } + requester.SetSession(&fosite.DefaultSession{ + Subject: subject, + }) + requester.GetSession().SetExpiresAt(fosite.AccessToken, time.Now().UTC().Add(c.Config.GetAccessTokenLifespan(ctx))) + return nil + case tokenTypeAT: + or, err := c.verifyAccessTokenAsSubjectToken(ctx, client.GetID(), params) + if err != nil { + return errorsx.WithStack(fosite.ErrInvalidRequest.WithHintf("not allowed to token exchange by at: %v", err)) + } + requester.SetSession(or.GetSession().Clone()) + // When the subject_type is AT, the expiration time is same with subject_token. + // Therefore, we don't need to set the expiresAt. + return nil + default: + return errorsx.WithStack(fosite.ErrInvalidRequest.WithHintf("unsupported subject_type %s", params.subjectTokenType)) + } +} + +func (c *Handler) PopulateTokenEndpointResponse(ctx context.Context, requester fosite.AccessRequester, responder fosite.AccessResponder) error { + if !c.CanHandleTokenEndpointRequest(requester) { + return errorsx.WithStack(fosite.ErrUnknownRequest) + } + + if !requester.GetClient().GetGrantTypes().Has(string(fosite.GrantTypeTokenExchange)) { + return errorsx.WithStack(fosite.ErrUnauthorizedClient.WithHintf("The OAuth 2.0 Client is not allowed to use authorization grant '%s'.", fosite.GrantTypeTokenExchange)) + } + + atLifespan := fosite.GetEffectiveLifespan(requester.GetClient(), fosite.GrantTypeTokenExchange, fosite.AccessToken, c.Config.GetAccessTokenLifespan(ctx)) + + if err := c.IssueAccessToken(ctx, atLifespan, requester, responder); err != nil { + return err + } + + if canIssueRefreshToken(ctx, c, requester) { + fmt.Println(requester) + refresh, refreshSignature, err := c.RefreshTokenStrategy.GenerateRefreshToken(ctx, requester) + if err != nil { + return errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error())) + } + if err := c.RefreshTokenStorage.CreateRefreshTokenSession(ctx, refreshSignature, requester); err != nil { + return errorsx.WithStack(fosite.ErrServerError.WithDebug(err.Error())) + } + + responder.SetExtra("refresh_token", refresh) + } + return nil +} + +func canIssueRefreshToken(ctx context.Context, c *Handler, requester fosite.Requester) bool { + scope := c.GetRefreshTokenScopes(ctx) + // Require one of the refresh token scopes, if set. + if len(scope) > 0 && !requester.GetGrantedScopes().HasOneOf(scope...) { + return false + } + // Do not issue a refresh token to clients that cannot use the refresh token grant type. + if !requester.GetClient().GetGrantTypes().Has("refresh_token") { + return false + } + return true +} + +func (c *Handler) CanSkipClientAuth(requester fosite.AccessRequester) bool { + return c.Strategy.CanSkipClientAuth(requester) +} + +func (c *Handler) keyFunc(ctx context.Context) jwt.Keyfunc { + return jwt.Keyfunc(func(t *jwt.Token) (interface{}, error) { + kid, ok := t.Header["kid"].(string) + if !ok { + return nil, errors.New("invalid kid") + } + iss, ok := t.Claims["iss"].(string) + if !ok { + return nil, errors.New("invalid iss") + } + return c.Storage.GetIDTokenPublicKey(ctx, iss, kid) + }) +} + +func (c *Handler) verifyAccessTokenAsSubjectToken(ctx context.Context, clientID string, params *tokenExchangeParams) (fosite.Requester, error) { + sig := c.HandleHelper.AccessTokenStrategy.AccessTokenSignature(ctx, params.subjectToken) + or, err := c.HandleHelper.AccessTokenStorage.GetAccessTokenSession(ctx, sig, nil) + if err != nil { + return nil, errorsx.WithStack(fosite.ErrRequestUnauthorized.WithWrap(err).WithDebug(err.Error())) + } else if err := c.AccessTokenStrategy.ValidateAccessToken(ctx, or, params.subjectToken); err != nil { + return nil, err + } + + allowClientIDs, err := c.Storage.GetAllowedClientIDs(ctx, clientID) + if err != nil { + return nil, err + } + + for _, cid := range allowClientIDs { + if or.GetClient().GetID() == cid { + return or, nil + } + } + return nil, fmt.Errorf("this access_token is not allowed to use token exchange based on AT: original_client:%s, request_client:%s ", or.GetClient().GetID(), clientID) +} + +func (c *Handler) CanHandleTokenEndpointRequest(requester fosite.AccessRequester) bool { + return requester.GetGrantTypes().ExactOne(string(fosite.GrantTypeTokenExchange)) +} diff --git a/handler/rfc8693/handler_test.go b/handler/rfc8693/handler_test.go new file mode 100644 index 000000000..89a5cf7a7 --- /dev/null +++ b/handler/rfc8693/handler_test.go @@ -0,0 +1,354 @@ +// Copyright © 2022 Ory Corp +// SPDX-License-Identifier: Apache-2.0 + +package rfc8693 + +import ( + "context" + "net/http" + "net/url" + "testing" + "time" + + "github.com/golang/mock/gomock" + "github.com/ory/fosite" + fositeOAuth2 "github.com/ory/fosite/handler/oauth2" + "github.com/ory/fosite/internal" + "github.com/ory/fosite/token/jwt" + "github.com/stretchr/testify/require" + "gopkg.in/square/go-jose.v2" +) + +func TestTokenExchange_HandleTokenEndpointRequest(t *testing.T) { + ctrl := gomock.NewController(t) + teStore := internal.NewMockRFC8693Storage(ctrl) + atStore := internal.NewMockAccessTokenStorage(ctrl) + rtStore := internal.NewMockRefreshTokenGrantStorage(ctrl) + chgen := internal.NewMockAccessTokenStrategy(ctrl) + areq := internal.NewMockAccessRequester(ctrl) + defer ctrl.Finish() + + h := Handler{ + Storage: teStore, + HandleHelper: &fositeOAuth2.HandleHelper{ + AccessTokenStorage: atStore, + AccessTokenStrategy: chgen, + Config: &fosite.Config{ + AccessTokenLifespan: time.Hour, + }, + }, + ScopeStrategy: fosite.HierarchicScopeStrategy, + AudienceMatchingStrategy: fosite.DefaultAudienceMatchingStrategy, + RefreshTokenStorage: rtStore, + } + + for _, c := range []struct { + name string + mock func() + req *http.Request + expectErr error + }{ + { + name: "should fail because granttype is missing", + expectErr: fosite.ErrUnknownRequest, + mock: func() { + areq.EXPECT().GetGrantTypes().Return(fosite.Arguments{""}) + }, + }, + { + name: "should fail because invalid client_id", + expectErr: fosite.ErrUnauthorizedClient, + mock: func() { + areq.EXPECT().GetGrantTypes().Return(fosite.Arguments{"urn:ietf:params:oauth:grant-type:token-exchange"}) + areq.EXPECT().GetClient().Return(&fosite.DefaultClient{}) + }, + }, + { + name: "should fail because grant_type is not valid", + expectErr: fosite.ErrUnauthorizedClient, + mock: func() { + areq.EXPECT().GetGrantTypes().Return(fosite.Arguments{"urn:ietf:params:oauth:grant-type:token-exchange"}) + areq.EXPECT().GetClient().Return(&fosite.DefaultClient{ + ID: "client", + GrantTypes: fosite.Arguments{""}, + }) + }, + }, + { + name: "should fail because no subject_token", + expectErr: fosite.ErrInvalidRequest, + mock: func() { + areq.EXPECT().GetGrantTypes().Return(fosite.Arguments{"urn:ietf:params:oauth:grant-type:token-exchange"}) + areq.EXPECT().GetClient().Return(&fosite.DefaultClient{ + ID: "client", + GrantTypes: fosite.Arguments{"urn:ietf:params:oauth:grant-type:token-exchange"}, + }) + areq.EXPECT().GetRequestForm().Return(url.Values{ + "subject_token": []string{""}, + }) + }, + }, + { + name: "should fail because unsupported subject_token_type", + expectErr: fosite.ErrInvalidRequest, + mock: func() { + areq.EXPECT().GetGrantTypes().Return(fosite.Arguments{"urn:ietf:params:oauth:grant-type:token-exchange"}) + areq.EXPECT().GetClient().Return(&fosite.DefaultClient{ + ID: "client", + GrantTypes: fosite.Arguments{"urn:ietf:params:oauth:grant-type:token-exchange"}, + }) + areq.EXPECT().GetRequestForm().Return(url.Values{ + "subject_token": []string{"subject_token"}, + "subject_token_type": []string{"unsupported_subject_token_type"}, + }) + }, + }, + { + name: "should fail because scope not valid", + expectErr: fosite.ErrInvalidScope, + mock: func() { + areq.EXPECT().GetGrantTypes().Return(fosite.Arguments{"urn:ietf:params:oauth:grant-type:token-exchange"}) + areq.EXPECT().GetClient().Return(&fosite.DefaultClient{ + ID: "client", + GrantTypes: fosite.Arguments{"urn:ietf:params:oauth:grant-type:token-exchange"}, + Scopes: []string{"none"}, + }) + areq.EXPECT().GetRequestForm().Return(url.Values{ + "subject_token": []string{"subject_token"}, + "subject_token_type": []string{"urn:ietf:params:oauth:token-type:access_token"}, + "requested_token_type": []string{"urn:ietf:params:oauth:token-type:access_token"}, + }) + areq.EXPECT().GetRequestedScopes().Return([]string{"foo"}) + }, + }, + { + name: "should pass as AT", + mock: func() { + areq.EXPECT().GetGrantTypes().Return(fosite.Arguments{"urn:ietf:params:oauth:grant-type:token-exchange"}) + areq.EXPECT().GetClient().Return(&fosite.DefaultClient{ + ID: "client", + GrantTypes: fosite.Arguments{"urn:ietf:params:oauth:grant-type:token-exchange"}, + Scopes: []string{"foo"}, + }) + areq.EXPECT().GetRequestForm().Return(url.Values{ + "subject_token": []string{"subject_token"}, + "subject_token_type": []string{"urn:ietf:params:oauth:token-type:access_token"}, + "requested_token_type": []string{"urn:ietf:params:oauth:token-type:access_token"}, + }) + + // scope and audience. + areq.EXPECT().GetRequestedScopes().Return([]string{"foo"}) + areq.EXPECT().GrantScope("foo") + areq.EXPECT().GetRequestedAudience().Return([]string{}) + areq.EXPECT().GetRequestedAudience().Return([]string{}) + chgen.EXPECT().AccessTokenSignature(gomock.Any(), gomock.Any()).Return("signature") + + // original request. + ar := internal.NewMockAccessRequester(ctrl) + atStore.EXPECT().GetAccessTokenSession(gomock.Any(), "signature", nil).Return(ar, nil) + chgen.EXPECT().ValidateAccessToken(gomock.Any(), ar, gomock.Any()).Return(nil) + + teStore.EXPECT().GetAllowedClientIDs(gomock.Any(), "client").Return([]string{"client2"}, nil) + ar.EXPECT().GetClient().Return(&fosite.DefaultClient{ + ID: "client2", + }) + ar.EXPECT().GetSession().Return(new(fosite.DefaultSession)) + areq.EXPECT().SetSession(gomock.Any()) + }, + }, + { + name: "should fail because of different key", + expectErr: fosite.ErrInvalidRequest, + mock: func() { + // ID Token JWT. + key := []byte("aabbbbccccddddddd") + token := jwt.Token{ + Header: map[string]interface{}{ + "kid": "12asd4q34daf", + }, + Claims: jwt.MapClaims{ + "sub": "foo", + "exp": time.Now().Add(time.Hour).Unix(), + "iss": "bar", + "jti": "12345", + "aud": "token-url", + }, + Method: jose.HS256, + } + tokenString, err := token.SignedString(key) + require.NoError(t, err) + + // request. + areq.EXPECT().GetGrantTypes().Return(fosite.Arguments{"urn:ietf:params:oauth:grant-type:token-exchange"}) + areq.EXPECT().GetClient().Return(&fosite.DefaultClient{ + ID: "client", + GrantTypes: fosite.Arguments{"urn:ietf:params:oauth:grant-type:token-exchange"}, + Scopes: []string{"foo"}, + }) + areq.EXPECT().GetRequestForm().Return(url.Values{ + "subject_token": []string{tokenString}, + "subject_token_type": []string{"urn:ietf:params:oauth:token-type:id_token"}, + "requested_token_type": []string{"urn:ietf:params:oauth:token-type:access_token"}, + }) + + // scope and audience. + areq.EXPECT().GetRequestedScopes().Return([]string{"foo"}) + areq.EXPECT().GrantScope("foo") + areq.EXPECT().GetRequestedAudience().Return([]string{}) + areq.EXPECT().GetRequestedAudience().Return([]string{}) + + // verify IDToken. + teStore.EXPECT().GetIDTokenPublicKey(gomock.Any(), "bar", "12asd4q34daf").Return(&jose.JSONWebKey{ + Key: []byte("differnet_key"), + }, nil) + }, + }, + { + name: "should pass as JWT", + mock: func() { + // ID Token JWT. + key := []byte("aaabbbbcccddd") + token := jwt.Token{ + Header: map[string]interface{}{ + "kid": "12asd4q34daf", + }, + Claims: jwt.MapClaims{ + "sub": "foo", + "exp": time.Now().Add(time.Hour).Unix(), + "iss": "bar", + "jti": "12345", + "aud": "token-url", + }, + Method: jose.HS256, + } + tokenString, err := token.SignedString(key) + require.NoError(t, err) + + // request. + areq.EXPECT().GetGrantTypes().Return(fosite.Arguments{"urn:ietf:params:oauth:grant-type:token-exchange"}) + areq.EXPECT().GetClient().Return(&fosite.DefaultClient{ + ID: "client", + GrantTypes: fosite.Arguments{"urn:ietf:params:oauth:grant-type:token-exchange"}, + Scopes: []string{"foo"}, + }) + areq.EXPECT().GetRequestForm().Return(url.Values{ + "subject_token": []string{tokenString}, + "subject_token_type": []string{"urn:ietf:params:oauth:token-type:id_token"}, + "requested_token_type": []string{"urn:ietf:params:oauth:token-type:access_token"}, + }) + + // scope and audience. + areq.EXPECT().GetRequestedScopes().Return([]string{"foo"}) + areq.EXPECT().GrantScope("foo") + areq.EXPECT().GetRequestedAudience().Return([]string{}) + areq.EXPECT().GetRequestedAudience().Return([]string{}) + + // verify IDToken. + teStore.EXPECT().GetIDTokenPublicKey(gomock.Any(), "bar", "12asd4q34daf").Return(&jose.JSONWebKey{ + Key: key, + }, nil) + teStore.EXPECT().GetImpersonateSubject(gomock.Any(), gomock.Any(), gomock.Any()).Return("client", nil) + + areq.EXPECT().SetSession(gomock.Any()) + areq.EXPECT().GetSession().Return(new(fosite.DefaultSession)) + }, + }, + } { + t.Run(c.name, func(t *testing.T) { + c.mock() + err := h.HandleTokenEndpointRequest(context.TODO(), areq) + if c.expectErr != nil { + require.EqualError(t, err, c.expectErr.Error()) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestTokenExchange_PopulateTokenEndpointResponse(t *testing.T) { + ctrl := gomock.NewController(t) + atStore := internal.NewMockAccessTokenStorage(ctrl) + chgen := internal.NewMockAccessTokenStrategy(ctrl) + + areq := fosite.NewAccessRequest(new(fosite.DefaultSession)) + aresp := fosite.NewAccessResponse() + rtStrategy := internal.NewMockRefreshTokenStrategy(ctrl) + rtStore := internal.NewMockRefreshTokenGrantStorage(ctrl) + + defer ctrl.Finish() + + h := Handler{ + HandleHelper: &fositeOAuth2.HandleHelper{ + AccessTokenStorage: atStore, + AccessTokenStrategy: chgen, + Config: &fosite.Config{ + AccessTokenLifespan: time.Hour, + }, + }, + ScopeStrategy: fosite.HierarchicScopeStrategy, + AudienceMatchingStrategy: fosite.DefaultAudienceMatchingStrategy, + RefreshTokenStrategy: rtStrategy, + RefreshTokenStorage: rtStore, + RefreshTokenScopesProvider: &fosite.Config{ + RefreshTokenScopes: []string{"offline", "offline_access"}, + }, + } + for _, c := range []struct { + name string + mock func() + req *http.Request + expectErr error + }{ + { + name: "should fail because not responsible", + expectErr: fosite.ErrUnknownRequest, + mock: func() { + areq.GrantTypes = fosite.Arguments{""} + }, + }, + { + name: "should fail because grant_type not allowed", + expectErr: fosite.ErrUnauthorizedClient, + mock: func() { + areq.GrantTypes = fosite.Arguments{"urn:ietf:params:oauth:grant-type:token-exchange"} + areq.Client = &fosite.DefaultClient{GrantTypes: fosite.Arguments{"authorization_code"}} + }, + }, + { + name: "should pass", + mock: func() { + areq.GrantTypes = fosite.Arguments{"urn:ietf:params:oauth:grant-type:token-exchange"} + areq.Session = &fosite.DefaultSession{} + areq.Client = &fosite.DefaultClient{GrantTypes: fosite.Arguments{"urn:ietf:params:oauth:grant-type:token-exchange"}} + chgen.EXPECT().GenerateAccessToken(gomock.Any(), areq).Return("tokenfoo.bar", "bar", nil) + atStore.EXPECT().CreateAccessTokenSession(gomock.Any(), "bar", gomock.Eq(areq.Sanitize([]string{}))).Return(nil) + }, + }, + { + name: "should populate both AT and RT", + mock: func() { + areq.GrantedScope = fosite.Arguments{"offline_access"} + areq.GrantTypes = fosite.Arguments{"urn:ietf:params:oauth:grant-type:token-exchange"} + areq.Session = &fosite.DefaultSession{} + areq.Client = &fosite.DefaultClient{ + GrantTypes: fosite.Arguments{"urn:ietf:params:oauth:grant-type:token-exchange", "refresh_token"}, + } + chgen.EXPECT().GenerateAccessToken(gomock.Any(), areq).Return("tokenfoo.bar", "bar", nil) + atStore.EXPECT().CreateAccessTokenSession(gomock.Any(), "bar", gomock.Eq(areq.Sanitize([]string{}))).Return(nil) + rtStrategy.EXPECT().GenerateRefreshToken(gomock.Any(), gomock.Any()).Return("refresh_token", "refresh_token_signature", nil) + rtStore.EXPECT().CreateRefreshTokenSession(gomock.Any(), "refresh_token_signature", gomock.Eq(areq)).Return(nil) + }, + }, + } { + t.Run(c.name, func(t *testing.T) { + c.mock() + err := h.PopulateTokenEndpointResponse(context.TODO(), areq, aresp) + if c.expectErr != nil { + require.EqualError(t, err, c.expectErr.Error()) + } else { + require.NoError(t, err) + } + }) + } +} diff --git a/handler/rfc8693/storage.go b/handler/rfc8693/storage.go new file mode 100644 index 000000000..528f720ba --- /dev/null +++ b/handler/rfc8693/storage.go @@ -0,0 +1,28 @@ +// Copyright © 2022 Ory Corp +// SPDX-License-Identifier: Apache-2.0 + +package rfc8693 + +//go:generate mockgen -source=storage.go -destination=../../internal/oauth2_token_exchange_storage.go -package=internal + +import ( + "context" + + "github.com/ory/fosite" + "github.com/ory/fosite/token/jwt" +) + +// RFC8693Storage hold information needed to perform token exchange. +type RFC8693Storage interface { + // GetAllowedClientIDs returns clientIDs that can be used for subject_token. + // The subject token is a security token that represents the identity of + // the party on behalf of whom the request is being made. + // https://datatracker.ietf.org/doc/html/rfc8693#section-2.1 + GetAllowedClientIDs(ctx context.Context, clientID string) ([]string, error) + + // GetIDTokenPublicKey returns the public key that can be used to verify ID Token. + GetIDTokenPublicKey(ctx context.Context, iss, kid string) (interface{}, error) + + // GetImpersonateSubject returns subject value to use the token based on a JWT. + GetImpersonateSubject(ctx context.Context, claims jwt.MapClaims, req fosite.Requester) (string, error) +} diff --git a/handler/rfc8693/strategy.go b/handler/rfc8693/strategy.go new file mode 100644 index 000000000..df2051628 --- /dev/null +++ b/handler/rfc8693/strategy.go @@ -0,0 +1,19 @@ +// Copyright © 2022 Ory Corp +// SPDX-License-Identifier: Apache-2.0 + +package rfc8693 + +//go:generate mockgen -source=strategy.go -destination=../../internal/oauth2_token_exchange_strategy.go -package=internal + +import "github.com/ory/fosite" + +type ClientAuthenticationStrategy interface { + CanSkipClientAuth(requester fosite.AccessRequester) bool +} + +// DefaultClientAuthenticationStrategy enforces client authentication for all the cases. +type DefaultClientAuthenticationStrategy struct{} + +func (s *DefaultClientAuthenticationStrategy) CanSkipClientAuth(requester fosite.Requester) bool { + return false +} diff --git a/internal/oauth2_token_exchange_storage.go b/internal/oauth2_token_exchange_storage.go new file mode 100644 index 000000000..a8db68cc6 --- /dev/null +++ b/internal/oauth2_token_exchange_storage.go @@ -0,0 +1,85 @@ +// Copyright © 2022 Ory Corp +// SPDX-License-Identifier: Apache-2.0 + +// Code generated by MockGen. DO NOT EDIT. +// Source: storage.go + +// Package internal is a generated GoMock package. +package internal + +import ( + context "context" + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + fosite "github.com/ory/fosite" + jwt "github.com/ory/fosite/token/jwt" +) + +// MockRFC8693Storage is a mock of RFC8693Storage interface. +type MockRFC8693Storage struct { + ctrl *gomock.Controller + recorder *MockRFC8693StorageMockRecorder +} + +// MockRFC8693StorageMockRecorder is the mock recorder for MockRFC8693Storage. +type MockRFC8693StorageMockRecorder struct { + mock *MockRFC8693Storage +} + +// NewMockRFC8693Storage creates a new mock instance. +func NewMockRFC8693Storage(ctrl *gomock.Controller) *MockRFC8693Storage { + mock := &MockRFC8693Storage{ctrl: ctrl} + mock.recorder = &MockRFC8693StorageMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockRFC8693Storage) EXPECT() *MockRFC8693StorageMockRecorder { + return m.recorder +} + +// GetAllowedClientIDs mocks base method. +func (m *MockRFC8693Storage) GetAllowedClientIDs(ctx context.Context, clientID string) ([]string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetAllowedClientIDs", ctx, clientID) + ret0, _ := ret[0].([]string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetAllowedClientIDs indicates an expected call of GetAllowedClientIDs. +func (mr *MockRFC8693StorageMockRecorder) GetAllowedClientIDs(ctx, clientID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllowedClientIDs", reflect.TypeOf((*MockRFC8693Storage)(nil).GetAllowedClientIDs), ctx, clientID) +} + +// GetIDTokenPublicKey mocks base method. +func (m *MockRFC8693Storage) GetIDTokenPublicKey(ctx context.Context, iss, kid string) (interface{}, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetIDTokenPublicKey", ctx, iss, kid) + ret0, _ := ret[0].(interface{}) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetIDTokenPublicKey indicates an expected call of GetIDTokenPublicKey. +func (mr *MockRFC8693StorageMockRecorder) GetIDTokenPublicKey(ctx, iss, kid interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetIDTokenPublicKey", reflect.TypeOf((*MockRFC8693Storage)(nil).GetIDTokenPublicKey), ctx, iss, kid) +} + +// GetImpersonateSubject mocks base method. +func (m *MockRFC8693Storage) GetImpersonateSubject(ctx context.Context, claims jwt.MapClaims, req fosite.Requester) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetImpersonateSubject", ctx, claims, req) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetImpersonateSubject indicates an expected call of GetImpersonateSubject. +func (mr *MockRFC8693StorageMockRecorder) GetImpersonateSubject(ctx, claims, req interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetImpersonateSubject", reflect.TypeOf((*MockRFC8693Storage)(nil).GetImpersonateSubject), ctx, claims, req) +} diff --git a/internal/oauth2_token_exchange_strategy.go b/internal/oauth2_token_exchange_strategy.go new file mode 100644 index 000000000..66875129f --- /dev/null +++ b/internal/oauth2_token_exchange_strategy.go @@ -0,0 +1,52 @@ +// Copyright © 2022 Ory Corp +// SPDX-License-Identifier: Apache-2.0 + +// Code generated by MockGen. DO NOT EDIT. +// Source: strategy.go + +// Package internal is a generated GoMock package. +package internal + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + fosite "github.com/ory/fosite" +) + +// MockClientAuthenticationStrategy is a mock of ClientAuthenticationStrategy interface. +type MockClientAuthenticationStrategy struct { + ctrl *gomock.Controller + recorder *MockClientAuthenticationStrategyMockRecorder +} + +// MockClientAuthenticationStrategyMockRecorder is the mock recorder for MockClientAuthenticationStrategy. +type MockClientAuthenticationStrategyMockRecorder struct { + mock *MockClientAuthenticationStrategy +} + +// NewMockClientAuthenticationStrategy creates a new mock instance. +func NewMockClientAuthenticationStrategy(ctrl *gomock.Controller) *MockClientAuthenticationStrategy { + mock := &MockClientAuthenticationStrategy{ctrl: ctrl} + mock.recorder = &MockClientAuthenticationStrategyMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockClientAuthenticationStrategy) EXPECT() *MockClientAuthenticationStrategyMockRecorder { + return m.recorder +} + +// CanSkipClientAuth mocks base method. +func (m *MockClientAuthenticationStrategy) CanSkipClientAuth(requester fosite.AccessRequester) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CanSkipClientAuth", requester) + ret0, _ := ret[0].(bool) + return ret0 +} + +// CanSkipClientAuth indicates an expected call of CanSkipClientAuth. +func (mr *MockClientAuthenticationStrategyMockRecorder) CanSkipClientAuth(requester interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CanSkipClientAuth", reflect.TypeOf((*MockClientAuthenticationStrategy)(nil).CanSkipClientAuth), requester) +} diff --git a/oauth2.go b/oauth2.go index c25abf65a..1b313a835 100644 --- a/oauth2.go +++ b/oauth2.go @@ -31,7 +31,8 @@ const ( GrantTypeAuthorizationCode GrantType = "authorization_code" GrantTypePassword GrantType = "password" GrantTypeClientCredentials GrantType = "client_credentials" - GrantTypeJWTBearer GrantType = "urn:ietf:params:oauth:grant-type:jwt-bearer" //nolint:gosec // this is not a hardcoded credential + GrantTypeJWTBearer GrantType = "urn:ietf:params:oauth:grant-type:jwt-bearer" //nolint:gosec // this is not a hardcoded credential + GrantTypeTokenExchange GrantType = "urn:ietf:params:oauth:grant-type:token-exchange" //nolint:gosec // this is not a hardcoded credential BearerAccessToken string = "bearer" ) From e0bd8f612fa7034c81656e43bb5d6949ee9772da Mon Sep 17 00:00:00 2001 From: Utkarsh Saxena Date: Sat, 22 Apr 2023 21:46:19 +0800 Subject: [PATCH 2/3] change(handler): token exchange related stategy should not be inside the handler package --- client_authentication.go | 3 + config.go | 6 ++ config_default.go | 88 ++++++++++++---------- handler/rfc8693/handler.go | 41 +++++----- handler/rfc8693/handler_test.go | 15 ++-- handler/rfc8693/strategy.go | 19 ----- internal/oauth2_token_exchange_strategy.go | 52 ------------- 7 files changed, 84 insertions(+), 140 deletions(-) delete mode 100644 handler/rfc8693/strategy.go delete mode 100644 internal/oauth2_token_exchange_strategy.go diff --git a/client_authentication.go b/client_authentication.go index c86483991..7394534c0 100644 --- a/client_authentication.go +++ b/client_authentication.go @@ -21,6 +21,9 @@ import ( "github.com/ory/fosite/token/jwt" ) +// CanSkipClientAuthenticationStrategy provides a method signature for checking if client authentication can be skipped. +type CanSkipClientAuthenticationStrategy func(context.Context, AccessRequester) bool + // ClientAuthenticationStrategy provides a method signature for authenticating a client request type ClientAuthenticationStrategy func(context.Context, *http.Request, url.Values) (Client, error) diff --git a/config.go b/config.go index 1b50eb70c..f58f0359b 100644 --- a/config.go +++ b/config.go @@ -150,6 +150,12 @@ type GrantTypeJWTBearerCanSkipClientAuthProvider interface { GetGrantTypeJWTBearerCanSkipClientAuth(ctx context.Context) bool } +// GrantTypeTokenExchangeCanSkipClientAuthProvider returns the provider for configuring the grant type Token Exchange can skip client auth. +type GrantTypeTokenExchangeCanSkipClientAuthProvider interface { + // GetGrantTypeTokenExchangeCanSkipClientAuth returns the grant type Token Exchange can skip client auth. + GetGrantTypeTokenExchangeCanSkipClientAuth(ctx context.Context) CanSkipClientAuthenticationStrategy +} + // GrantTypeJWTBearerIDOptionalProvider returns the provider for configuring the grant type JWT bearer ID optional. type GrantTypeJWTBearerIDOptionalProvider interface { // GetGrantTypeJWTBearerIDOptional returns the grant type JWT bearer ID optional. diff --git a/config_default.go b/config_default.go index 7f2e2487e..899df871a 100644 --- a/config_default.go +++ b/config_default.go @@ -23,45 +23,46 @@ const ( ) var ( - _ AuthorizeCodeLifespanProvider = (*Config)(nil) - _ RefreshTokenLifespanProvider = (*Config)(nil) - _ AccessTokenLifespanProvider = (*Config)(nil) - _ ScopeStrategyProvider = (*Config)(nil) - _ AudienceStrategyProvider = (*Config)(nil) - _ RedirectSecureCheckerProvider = (*Config)(nil) - _ RefreshTokenScopesProvider = (*Config)(nil) - _ DisableRefreshTokenValidationProvider = (*Config)(nil) - _ AccessTokenIssuerProvider = (*Config)(nil) - _ JWTScopeFieldProvider = (*Config)(nil) - _ AllowedPromptsProvider = (*Config)(nil) - _ OmitRedirectScopeParamProvider = (*Config)(nil) - _ MinParameterEntropyProvider = (*Config)(nil) - _ SanitationAllowedProvider = (*Config)(nil) - _ EnforcePKCEForPublicClientsProvider = (*Config)(nil) - _ EnablePKCEPlainChallengeMethodProvider = (*Config)(nil) - _ EnforcePKCEProvider = (*Config)(nil) - _ GrantTypeJWTBearerCanSkipClientAuthProvider = (*Config)(nil) - _ GrantTypeJWTBearerIDOptionalProvider = (*Config)(nil) - _ GrantTypeJWTBearerIssuedDateOptionalProvider = (*Config)(nil) - _ GetJWTMaxDurationProvider = (*Config)(nil) - _ IDTokenLifespanProvider = (*Config)(nil) - _ IDTokenIssuerProvider = (*Config)(nil) - _ JWKSFetcherStrategyProvider = (*Config)(nil) - _ ClientAuthenticationStrategyProvider = (*Config)(nil) - _ SendDebugMessagesToClientsProvider = (*Config)(nil) - _ ResponseModeHandlerExtensionProvider = (*Config)(nil) - _ MessageCatalogProvider = (*Config)(nil) - _ FormPostHTMLTemplateProvider = (*Config)(nil) - _ TokenURLProvider = (*Config)(nil) - _ GetSecretsHashingProvider = (*Config)(nil) - _ HTTPClientProvider = (*Config)(nil) - _ HMACHashingProvider = (*Config)(nil) - _ AuthorizeEndpointHandlersProvider = (*Config)(nil) - _ TokenEndpointHandlersProvider = (*Config)(nil) - _ TokenIntrospectionHandlersProvider = (*Config)(nil) - _ RevocationHandlersProvider = (*Config)(nil) - _ PushedAuthorizeRequestHandlersProvider = (*Config)(nil) - _ PushedAuthorizeRequestConfigProvider = (*Config)(nil) + _ AuthorizeCodeLifespanProvider = (*Config)(nil) + _ RefreshTokenLifespanProvider = (*Config)(nil) + _ AccessTokenLifespanProvider = (*Config)(nil) + _ ScopeStrategyProvider = (*Config)(nil) + _ AudienceStrategyProvider = (*Config)(nil) + _ RedirectSecureCheckerProvider = (*Config)(nil) + _ RefreshTokenScopesProvider = (*Config)(nil) + _ DisableRefreshTokenValidationProvider = (*Config)(nil) + _ AccessTokenIssuerProvider = (*Config)(nil) + _ JWTScopeFieldProvider = (*Config)(nil) + _ AllowedPromptsProvider = (*Config)(nil) + _ OmitRedirectScopeParamProvider = (*Config)(nil) + _ MinParameterEntropyProvider = (*Config)(nil) + _ SanitationAllowedProvider = (*Config)(nil) + _ EnforcePKCEForPublicClientsProvider = (*Config)(nil) + _ EnablePKCEPlainChallengeMethodProvider = (*Config)(nil) + _ EnforcePKCEProvider = (*Config)(nil) + _ GrantTypeTokenExchangeCanSkipClientAuthProvider = (*Config)(nil) + _ GrantTypeJWTBearerCanSkipClientAuthProvider = (*Config)(nil) + _ GrantTypeJWTBearerIDOptionalProvider = (*Config)(nil) + _ GrantTypeJWTBearerIssuedDateOptionalProvider = (*Config)(nil) + _ GetJWTMaxDurationProvider = (*Config)(nil) + _ IDTokenLifespanProvider = (*Config)(nil) + _ IDTokenIssuerProvider = (*Config)(nil) + _ JWKSFetcherStrategyProvider = (*Config)(nil) + _ ClientAuthenticationStrategyProvider = (*Config)(nil) + _ SendDebugMessagesToClientsProvider = (*Config)(nil) + _ ResponseModeHandlerExtensionProvider = (*Config)(nil) + _ MessageCatalogProvider = (*Config)(nil) + _ FormPostHTMLTemplateProvider = (*Config)(nil) + _ TokenURLProvider = (*Config)(nil) + _ GetSecretsHashingProvider = (*Config)(nil) + _ HTTPClientProvider = (*Config)(nil) + _ HMACHashingProvider = (*Config)(nil) + _ AuthorizeEndpointHandlersProvider = (*Config)(nil) + _ TokenEndpointHandlersProvider = (*Config)(nil) + _ TokenIntrospectionHandlersProvider = (*Config)(nil) + _ RevocationHandlersProvider = (*Config)(nil) + _ PushedAuthorizeRequestHandlersProvider = (*Config)(nil) + _ PushedAuthorizeRequestConfigProvider = (*Config)(nil) ) type Config struct { @@ -148,6 +149,9 @@ type Config struct { // GrantTypeJWTBearerMaxDuration sets the maximum time after JWT issued date, during which the JWT is considered valid. GrantTypeJWTBearerMaxDuration time.Duration + // GrantTypeTokenExchangeCanSkipClientAuth indicates the stretegy to check if client authentication can be skipped. + GrantTypeTokenExchangeCanSkipClientAuth CanSkipClientAuthenticationStrategy + // ClientAuthenticationStrategy indicates the Strategy to authenticate client requests ClientAuthenticationStrategy ClientAuthenticationStrategy @@ -299,6 +303,12 @@ func (c *Config) GetGrantTypeJWTBearerCanSkipClientAuth(ctx context.Context) boo return c.GrantTypeJWTBearerCanSkipClientAuth } +// GetGrantTypeTokenExchangeCanSkipClientAuth returns the GrantTypeTokenExchangeCanSkipClientAuth field. +// Defaults to nil, in which case TokenExchange follows the default behavior. +func (c *Config) GetGrantTypeTokenExchangeCanSkipClientAuth(ctx context.Context) CanSkipClientAuthenticationStrategy { + return c.GrantTypeTokenExchangeCanSkipClientAuth +} + // GetEnforcePKCE If set to true, public clients must use PKCE. func (c *Config) GetEnforcePKCE(ctx context.Context) bool { return c.EnforcePKCE diff --git a/handler/rfc8693/handler.go b/handler/rfc8693/handler.go index 937c9a2d4..774505068 100644 --- a/handler/rfc8693/handler.go +++ b/handler/rfc8693/handler.go @@ -10,7 +10,6 @@ import ( "time" "github.com/ory/fosite" - "github.com/ory/fosite/compose" "github.com/ory/fosite/handler/oauth2" "github.com/ory/fosite/token/jwt" "github.com/ory/x/errorsx" @@ -22,18 +21,17 @@ const ( tokenTypeAT = "urn:ietf:params:oauth:token-type:access_token" ) -func TokenExchangeGrantFactory(config *compose.CommonStrategy, storage, strategy interface{}) interface{} { - return nil -} - type Handler struct { - Storage RFC8693Storage - Strategy ClientAuthenticationStrategy - ScopeStrategy fosite.ScopeStrategy - AudienceMatchingStrategy fosite.AudienceMatchingStrategy - RefreshTokenStrategy oauth2.RefreshTokenStrategy - RefreshTokenStorage oauth2.RefreshTokenStorage - fosite.RefreshTokenScopesProvider + Storage RFC8693Storage + RefreshTokenStorage oauth2.RefreshTokenStorage + RefreshTokenStrategy oauth2.RefreshTokenStrategy + + Config interface { + fosite.GrantTypeTokenExchangeCanSkipClientAuthProvider + fosite.ScopeStrategyProvider + fosite.AudienceStrategyProvider + fosite.RefreshTokenScopesProvider + } *oauth2.HandleHelper } @@ -136,14 +134,14 @@ func (c *Handler) HandleTokenEndpointRequest(ctx context.Context, requester fosi // Check and grant scope. for _, scope := range requester.GetRequestedScopes() { - if !c.ScopeStrategy(client.GetScopes(), scope) { + if !c.Config.GetScopeStrategy(ctx)(client.GetScopes(), scope) { return errorsx.WithStack(fosite.ErrInvalidScope.WithHintf("The OAuth 2.0 Client is not allowed to request scope '%s'.", scope)) } requester.GrantScope(scope) } // Check and grant audience. - if err := c.AudienceMatchingStrategy(client.GetAudience(), requester.GetRequestedAudience()); err != nil { + if err := c.Config.GetAudienceStrategy(ctx)(client.GetAudience(), requester.GetRequestedAudience()); err != nil { return errorsx.WithStack(fosite.ErrInvalidRequest.WithHintf("audience not match: %v", err)) } for _, audience := range requester.GetRequestedAudience() { @@ -164,7 +162,7 @@ func (c *Handler) HandleTokenEndpointRequest(ctx context.Context, requester fosi requester.SetSession(&fosite.DefaultSession{ Subject: subject, }) - requester.GetSession().SetExpiresAt(fosite.AccessToken, time.Now().UTC().Add(c.Config.GetAccessTokenLifespan(ctx))) + requester.GetSession().SetExpiresAt(fosite.AccessToken, time.Now().UTC().Add(c.HandleHelper.Config.GetAccessTokenLifespan(ctx))) return nil case tokenTypeAT: or, err := c.verifyAccessTokenAsSubjectToken(ctx, client.GetID(), params) @@ -189,14 +187,13 @@ func (c *Handler) PopulateTokenEndpointResponse(ctx context.Context, requester f return errorsx.WithStack(fosite.ErrUnauthorizedClient.WithHintf("The OAuth 2.0 Client is not allowed to use authorization grant '%s'.", fosite.GrantTypeTokenExchange)) } - atLifespan := fosite.GetEffectiveLifespan(requester.GetClient(), fosite.GrantTypeTokenExchange, fosite.AccessToken, c.Config.GetAccessTokenLifespan(ctx)) + atLifespan := fosite.GetEffectiveLifespan(requester.GetClient(), fosite.GrantTypeTokenExchange, fosite.AccessToken, c.HandleHelper.Config.GetAccessTokenLifespan(ctx)) if err := c.IssueAccessToken(ctx, atLifespan, requester, responder); err != nil { return err } if canIssueRefreshToken(ctx, c, requester) { - fmt.Println(requester) refresh, refreshSignature, err := c.RefreshTokenStrategy.GenerateRefreshToken(ctx, requester) if err != nil { return errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error())) @@ -211,7 +208,7 @@ func (c *Handler) PopulateTokenEndpointResponse(ctx context.Context, requester f } func canIssueRefreshToken(ctx context.Context, c *Handler, requester fosite.Requester) bool { - scope := c.GetRefreshTokenScopes(ctx) + scope := c.Config.GetRefreshTokenScopes(ctx) // Require one of the refresh token scopes, if set. if len(scope) > 0 && !requester.GetGrantedScopes().HasOneOf(scope...) { return false @@ -223,8 +220,12 @@ func canIssueRefreshToken(ctx context.Context, c *Handler, requester fosite.Requ return true } -func (c *Handler) CanSkipClientAuth(requester fosite.AccessRequester) bool { - return c.Strategy.CanSkipClientAuth(requester) +func (c *Handler) CanSkipClientAuth(ctx context.Context, requester fosite.AccessRequester) bool { + if s := c.Config.GetGrantTypeTokenExchangeCanSkipClientAuth(ctx); s != nil { + return s(ctx, requester) + } + + return false } func (c *Handler) keyFunc(ctx context.Context) jwt.Keyfunc { diff --git a/handler/rfc8693/handler_test.go b/handler/rfc8693/handler_test.go index 89a5cf7a7..0d3ad9502 100644 --- a/handler/rfc8693/handler_test.go +++ b/handler/rfc8693/handler_test.go @@ -30,6 +30,7 @@ func TestTokenExchange_HandleTokenEndpointRequest(t *testing.T) { h := Handler{ Storage: teStore, + Config: &fosite.Config{}, HandleHelper: &fositeOAuth2.HandleHelper{ AccessTokenStorage: atStore, AccessTokenStrategy: chgen, @@ -37,9 +38,7 @@ func TestTokenExchange_HandleTokenEndpointRequest(t *testing.T) { AccessTokenLifespan: time.Hour, }, }, - ScopeStrategy: fosite.HierarchicScopeStrategy, - AudienceMatchingStrategy: fosite.DefaultAudienceMatchingStrategy, - RefreshTokenStorage: rtStore, + RefreshTokenStorage: rtStore, } for _, c := range []struct { @@ -286,13 +285,9 @@ func TestTokenExchange_PopulateTokenEndpointResponse(t *testing.T) { AccessTokenLifespan: time.Hour, }, }, - ScopeStrategy: fosite.HierarchicScopeStrategy, - AudienceMatchingStrategy: fosite.DefaultAudienceMatchingStrategy, - RefreshTokenStrategy: rtStrategy, - RefreshTokenStorage: rtStore, - RefreshTokenScopesProvider: &fosite.Config{ - RefreshTokenScopes: []string{"offline", "offline_access"}, - }, + Config: &fosite.Config{}, + RefreshTokenStrategy: rtStrategy, + RefreshTokenStorage: rtStore, } for _, c := range []struct { name string diff --git a/handler/rfc8693/strategy.go b/handler/rfc8693/strategy.go deleted file mode 100644 index df2051628..000000000 --- a/handler/rfc8693/strategy.go +++ /dev/null @@ -1,19 +0,0 @@ -// Copyright © 2022 Ory Corp -// SPDX-License-Identifier: Apache-2.0 - -package rfc8693 - -//go:generate mockgen -source=strategy.go -destination=../../internal/oauth2_token_exchange_strategy.go -package=internal - -import "github.com/ory/fosite" - -type ClientAuthenticationStrategy interface { - CanSkipClientAuth(requester fosite.AccessRequester) bool -} - -// DefaultClientAuthenticationStrategy enforces client authentication for all the cases. -type DefaultClientAuthenticationStrategy struct{} - -func (s *DefaultClientAuthenticationStrategy) CanSkipClientAuth(requester fosite.Requester) bool { - return false -} diff --git a/internal/oauth2_token_exchange_strategy.go b/internal/oauth2_token_exchange_strategy.go deleted file mode 100644 index 66875129f..000000000 --- a/internal/oauth2_token_exchange_strategy.go +++ /dev/null @@ -1,52 +0,0 @@ -// Copyright © 2022 Ory Corp -// SPDX-License-Identifier: Apache-2.0 - -// Code generated by MockGen. DO NOT EDIT. -// Source: strategy.go - -// Package internal is a generated GoMock package. -package internal - -import ( - reflect "reflect" - - gomock "github.com/golang/mock/gomock" - fosite "github.com/ory/fosite" -) - -// MockClientAuthenticationStrategy is a mock of ClientAuthenticationStrategy interface. -type MockClientAuthenticationStrategy struct { - ctrl *gomock.Controller - recorder *MockClientAuthenticationStrategyMockRecorder -} - -// MockClientAuthenticationStrategyMockRecorder is the mock recorder for MockClientAuthenticationStrategy. -type MockClientAuthenticationStrategyMockRecorder struct { - mock *MockClientAuthenticationStrategy -} - -// NewMockClientAuthenticationStrategy creates a new mock instance. -func NewMockClientAuthenticationStrategy(ctrl *gomock.Controller) *MockClientAuthenticationStrategy { - mock := &MockClientAuthenticationStrategy{ctrl: ctrl} - mock.recorder = &MockClientAuthenticationStrategyMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockClientAuthenticationStrategy) EXPECT() *MockClientAuthenticationStrategyMockRecorder { - return m.recorder -} - -// CanSkipClientAuth mocks base method. -func (m *MockClientAuthenticationStrategy) CanSkipClientAuth(requester fosite.AccessRequester) bool { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CanSkipClientAuth", requester) - ret0, _ := ret[0].(bool) - return ret0 -} - -// CanSkipClientAuth indicates an expected call of CanSkipClientAuth. -func (mr *MockClientAuthenticationStrategyMockRecorder) CanSkipClientAuth(requester interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CanSkipClientAuth", reflect.TypeOf((*MockClientAuthenticationStrategy)(nil).CanSkipClientAuth), requester) -} From 7fc351aa0c686e06e7b5d919ffc09f2cd9f46bdc Mon Sep 17 00:00:00 2001 From: Utkarsh Saxena Date: Sat, 22 Apr 2023 22:18:00 +0800 Subject: [PATCH 3/3] change(handler): update license to 2023 and fix mockgen --- generate-mocks.sh | 1 + handler/rfc8693/handler.go | 2 +- handler/rfc8693/handler_test.go | 2 +- handler/rfc8693/storage.go | 4 +--- internal/oauth2_token_exchange_storage.go | 28 +++++++++++------------ 5 files changed, 18 insertions(+), 19 deletions(-) diff --git a/generate-mocks.sh b/generate-mocks.sh index d4dded4ea..c01ad517f 100755 --- a/generate-mocks.sh +++ b/generate-mocks.sh @@ -28,5 +28,6 @@ mockgen -package internal -destination internal/access_request.go github.com/ory mockgen -package internal -destination internal/access_response.go github.com/ory/fosite AccessResponder mockgen -package internal -destination internal/authorize_request.go github.com/ory/fosite AuthorizeRequester mockgen -package internal -destination internal/authorize_response.go github.com/ory/fosite AuthorizeResponder +mockgen -package internal -destination internal/oauth2_token_exchange_storage.go github.com/ory/fosite/handler/rfc8693 RFC8693Storage goimports -w internal/ \ No newline at end of file diff --git a/handler/rfc8693/handler.go b/handler/rfc8693/handler.go index 774505068..f5f2cd2f8 100644 --- a/handler/rfc8693/handler.go +++ b/handler/rfc8693/handler.go @@ -1,4 +1,4 @@ -// Copyright © 2022 Ory Corp +// Copyright © 2023 Ory Corp // SPDX-License-Identifier: Apache-2.0 package rfc8693 diff --git a/handler/rfc8693/handler_test.go b/handler/rfc8693/handler_test.go index 0d3ad9502..0f1fdd3cc 100644 --- a/handler/rfc8693/handler_test.go +++ b/handler/rfc8693/handler_test.go @@ -1,4 +1,4 @@ -// Copyright © 2022 Ory Corp +// Copyright © 2023 Ory Corp // SPDX-License-Identifier: Apache-2.0 package rfc8693 diff --git a/handler/rfc8693/storage.go b/handler/rfc8693/storage.go index 528f720ba..7056b5fcb 100644 --- a/handler/rfc8693/storage.go +++ b/handler/rfc8693/storage.go @@ -1,10 +1,8 @@ -// Copyright © 2022 Ory Corp +// Copyright © 2023 Ory Corp // SPDX-License-Identifier: Apache-2.0 package rfc8693 -//go:generate mockgen -source=storage.go -destination=../../internal/oauth2_token_exchange_storage.go -package=internal - import ( "context" diff --git a/internal/oauth2_token_exchange_storage.go b/internal/oauth2_token_exchange_storage.go index a8db68cc6..e7aaac3a5 100644 --- a/internal/oauth2_token_exchange_storage.go +++ b/internal/oauth2_token_exchange_storage.go @@ -1,8 +1,8 @@ -// Copyright © 2022 Ory Corp +// Copyright © 2023 Ory Corp // SPDX-License-Identifier: Apache-2.0 // Code generated by MockGen. DO NOT EDIT. -// Source: storage.go +// Source: github.com/ory/fosite/handler/rfc8693 (interfaces: RFC8693Storage) // Package internal is a generated GoMock package. package internal @@ -40,46 +40,46 @@ func (m *MockRFC8693Storage) EXPECT() *MockRFC8693StorageMockRecorder { } // GetAllowedClientIDs mocks base method. -func (m *MockRFC8693Storage) GetAllowedClientIDs(ctx context.Context, clientID string) ([]string, error) { +func (m *MockRFC8693Storage) GetAllowedClientIDs(arg0 context.Context, arg1 string) ([]string, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetAllowedClientIDs", ctx, clientID) + ret := m.ctrl.Call(m, "GetAllowedClientIDs", arg0, arg1) ret0, _ := ret[0].([]string) ret1, _ := ret[1].(error) return ret0, ret1 } // GetAllowedClientIDs indicates an expected call of GetAllowedClientIDs. -func (mr *MockRFC8693StorageMockRecorder) GetAllowedClientIDs(ctx, clientID interface{}) *gomock.Call { +func (mr *MockRFC8693StorageMockRecorder) GetAllowedClientIDs(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllowedClientIDs", reflect.TypeOf((*MockRFC8693Storage)(nil).GetAllowedClientIDs), ctx, clientID) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllowedClientIDs", reflect.TypeOf((*MockRFC8693Storage)(nil).GetAllowedClientIDs), arg0, arg1) } // GetIDTokenPublicKey mocks base method. -func (m *MockRFC8693Storage) GetIDTokenPublicKey(ctx context.Context, iss, kid string) (interface{}, error) { +func (m *MockRFC8693Storage) GetIDTokenPublicKey(arg0 context.Context, arg1, arg2 string) (interface{}, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetIDTokenPublicKey", ctx, iss, kid) + ret := m.ctrl.Call(m, "GetIDTokenPublicKey", arg0, arg1, arg2) ret0, _ := ret[0].(interface{}) ret1, _ := ret[1].(error) return ret0, ret1 } // GetIDTokenPublicKey indicates an expected call of GetIDTokenPublicKey. -func (mr *MockRFC8693StorageMockRecorder) GetIDTokenPublicKey(ctx, iss, kid interface{}) *gomock.Call { +func (mr *MockRFC8693StorageMockRecorder) GetIDTokenPublicKey(arg0, arg1, arg2 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetIDTokenPublicKey", reflect.TypeOf((*MockRFC8693Storage)(nil).GetIDTokenPublicKey), ctx, iss, kid) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetIDTokenPublicKey", reflect.TypeOf((*MockRFC8693Storage)(nil).GetIDTokenPublicKey), arg0, arg1, arg2) } // GetImpersonateSubject mocks base method. -func (m *MockRFC8693Storage) GetImpersonateSubject(ctx context.Context, claims jwt.MapClaims, req fosite.Requester) (string, error) { +func (m *MockRFC8693Storage) GetImpersonateSubject(arg0 context.Context, arg1 jwt.MapClaims, arg2 fosite.Requester) (string, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetImpersonateSubject", ctx, claims, req) + ret := m.ctrl.Call(m, "GetImpersonateSubject", arg0, arg1, arg2) ret0, _ := ret[0].(string) ret1, _ := ret[1].(error) return ret0, ret1 } // GetImpersonateSubject indicates an expected call of GetImpersonateSubject. -func (mr *MockRFC8693StorageMockRecorder) GetImpersonateSubject(ctx, claims, req interface{}) *gomock.Call { +func (mr *MockRFC8693StorageMockRecorder) GetImpersonateSubject(arg0, arg1, arg2 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetImpersonateSubject", reflect.TypeOf((*MockRFC8693Storage)(nil).GetImpersonateSubject), ctx, claims, req) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetImpersonateSubject", reflect.TypeOf((*MockRFC8693Storage)(nil).GetImpersonateSubject), arg0, arg1, arg2) }