From 9ed0647ce1c86bf0a4a6dd1e84527825e5e1afad Mon Sep 17 00:00:00 2001 From: Wout Slakhorst Date: Mon, 18 Nov 2024 11:10:15 +0100 Subject: [PATCH] add caching of s2s tokens --- auth/api/iam/api.go | 46 ++++++++++++++++++++++++++++++++++++++++ auth/api/iam/api_test.go | 34 +++++++++++++++++++++++------ 2 files changed, 73 insertions(+), 7 deletions(-) diff --git a/auth/api/iam/api.go b/auth/api/iam/api.go index 1335973fb..7b6b84edb 100644 --- a/auth/api/iam/api.go +++ b/auth/api/iam/api.go @@ -22,8 +22,10 @@ import ( "bytes" "context" "crypto" + "crypto/sha256" "embed" "encoding/base64" + "encoding/hex" "encoding/json" "errors" "fmt" @@ -722,6 +724,25 @@ func (r Wrapper) RequestServiceAccessToken(ctx context.Context, request RequestS return nil, err } + tokenCache := r.accessTokenCache() + cacheKey, err := accessTokenRequestCacheKey(request) + cacheToken := true + if err != nil { + cacheToken = false + // only log error, don't fail + log.Logger().WithError(err).Warnf("Failed to create cache key for access token request: %s", err.Error()) + } else { + // try to retrieve token from cache + tokenResult := new(TokenResponse) + err = tokenCache.Get(cacheKey, tokenResult) + if err == nil { + return RequestServiceAccessToken200JSONResponse(*tokenResult), nil + } else if !errors.Is(err, storage.ErrNotFound) { + // only log error, don't fail + log.Logger().WithError(err).Warnf("Failed to retrieve access token from cache: %s", err.Error()) + } + } + var credentials []VerifiableCredential if request.Body.Credentials != nil { credentials = *request.Body.Credentials @@ -737,6 +758,13 @@ func (r Wrapper) RequestServiceAccessToken(ctx context.Context, request RequestS // this can be an internal server error, a 400 oauth error or a 412 precondition failed if the wallet does not contain the required credentials return nil, err } + if cacheToken { + err = tokenCache.Put(cacheKey, tokenResult) + if err != nil { + // only log error, don't fail + log.Logger().WithError(err).Warnf("Failed to cache access token: %s", err.Error()) + } + } return RequestServiceAccessToken200JSONResponse(*tokenResult), nil } @@ -897,6 +925,12 @@ func (r Wrapper) accessTokenServerStore() storage.SessionStore { return r.storageEngine.GetSessionDatabase().GetStore(accessTokenValidity, "serveraccesstoken") } +// accessTokenClientStore is used by the client to cache access tokens +func (r Wrapper) accessTokenCache() storage.SessionStore { + // we use a slightly reduced validity to prevent the cache from being used after the token has expired + return r.storageEngine.GetSessionDatabase().GetStore(accessTokenValidity-30*time.Second, "accesstokencache") +} + // accessTokenServerStore is used by the Auth server to store issued access tokens func (r Wrapper) authzRequestObjectStore() storage.SessionStore { return r.storageEngine.GetSessionDatabase().GetStore(accessTokenValidity, oauthRequestObjectKey...) @@ -946,3 +980,15 @@ func (r Wrapper) determineClientDID(ctx context.Context, authServerMetadata oaut } return &candidateDIDs[0], nil } + +// accessTokenRequestCacheKey creates a cache key for the access token request. +// it writes the JSON to a sha256 hash and returns the hex encoded hash. +func accessTokenRequestCacheKey(request RequestServiceAccessTokenRequestObject) (string, error) { + // create a hash of the request + hash := sha256.New() + err := json.NewEncoder(hash).Encode(request) + if err != nil { + return "", err + } + return hex.EncodeToString(hash.Sum(nil)), nil +} diff --git a/auth/api/iam/api_test.go b/auth/api/iam/api_test.go index 8e3ae8dbc..aaac9ad8e 100644 --- a/auth/api/iam/api_test.go +++ b/auth/api/iam/api_test.go @@ -24,12 +24,6 @@ import ( "encoding/json" "errors" "fmt" - "github.com/nuts-foundation/nuts-node/core/to" - "github.com/nuts-foundation/nuts-node/crypto/storage/spi" - test2 "github.com/nuts-foundation/nuts-node/crypto/test" - "github.com/nuts-foundation/nuts-node/http/user" - "github.com/nuts-foundation/nuts-node/test" - "github.com/nuts-foundation/nuts-node/vdr/didsubject" "io" "net/http" "net/http/httptest" @@ -51,10 +45,15 @@ import ( "github.com/nuts-foundation/nuts-node/auth/oauth" oauthServices "github.com/nuts-foundation/nuts-node/auth/services/oauth" "github.com/nuts-foundation/nuts-node/core" + "github.com/nuts-foundation/nuts-node/core/to" cryptoNuts "github.com/nuts-foundation/nuts-node/crypto" + "github.com/nuts-foundation/nuts-node/crypto/storage/spi" + test2 "github.com/nuts-foundation/nuts-node/crypto/test" + "github.com/nuts-foundation/nuts-node/http/user" "github.com/nuts-foundation/nuts-node/jsonld" "github.com/nuts-foundation/nuts-node/policy" "github.com/nuts-foundation/nuts-node/storage" + "github.com/nuts-foundation/nuts-node/test" "github.com/nuts-foundation/nuts-node/vcr" "github.com/nuts-foundation/nuts-node/vcr/credential" "github.com/nuts-foundation/nuts-node/vcr/holder" @@ -63,6 +62,7 @@ import ( "github.com/nuts-foundation/nuts-node/vcr/types" "github.com/nuts-foundation/nuts-node/vcr/verifier" "github.com/nuts-foundation/nuts-node/vdr" + "github.com/nuts-foundation/nuts-node/vdr/didsubject" "github.com/nuts-foundation/nuts-node/vdr/resolver" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -865,11 +865,31 @@ func TestWrapper_RequestServiceAccessToken(t *testing.T) { t.Run("ok", func(t *testing.T) { ctx := newTestClient(t) + request := RequestServiceAccessTokenRequestObject{SubjectID: holderSubjectID, Body: body} ctx.iamClient.EXPECT().RequestRFC021AccessToken(nil, holderClientID, holderSubjectID, verifierURL.String(), "first second", true, nil).Return(&oauth.TokenResponse{}, nil) - _, err := ctx.client.RequestServiceAccessToken(nil, RequestServiceAccessTokenRequestObject{SubjectID: holderSubjectID, Body: body}) + token, err := ctx.client.RequestServiceAccessToken(nil, request) require.NoError(t, err) + + t.Run("is cached", func(t *testing.T) { + cachedToken, err := ctx.client.RequestServiceAccessToken(nil, request) + + require.NoError(t, err) + assert.Equal(t, token, cachedToken) + }) + + t.Run("cache expired", func(t *testing.T) { + cacheKey, _ := accessTokenRequestCacheKey(request) + _ = ctx.client.accessTokenCache().Delete(cacheKey) + ctx.iamClient.EXPECT().RequestRFC021AccessToken(nil, holderClientID, holderSubjectID, verifierURL.String(), "first second", true, nil).Return(&oauth.TokenResponse{AccessToken: "other"}, nil) + + otherToken, err := ctx.client.RequestServiceAccessToken(nil, request) + + require.NoError(t, err) + + assert.NotEqual(t, token, otherToken) + }) }) t.Run("ok - no DPoP", func(t *testing.T) { ctx := newTestClient(t)