Skip to content

Commit

Permalink
add caching of s2s tokens
Browse files Browse the repository at this point in the history
  • Loading branch information
woutslakhorst committed Nov 18, 2024
1 parent 2374bd5 commit 9ed0647
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 7 deletions.
46 changes: 46 additions & 0 deletions auth/api/iam/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@ import (
"bytes"
"context"
"crypto"
"crypto/sha256"
"embed"
"encoding/base64"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
Expand Down Expand Up @@ -722,6 +724,25 @@ func (r Wrapper) RequestServiceAccessToken(ctx context.Context, request RequestS
return nil, err
}

tokenCache := r.accessTokenCache()
cacheKey, err := accessTokenRequestCacheKey(request)
cacheToken := true
if err != nil {
cacheToken = false
// only log error, don't fail
log.Logger().WithError(err).Warnf("Failed to create cache key for access token request: %s", err.Error())
} else {
// try to retrieve token from cache
tokenResult := new(TokenResponse)
err = tokenCache.Get(cacheKey, tokenResult)
if err == nil {
return RequestServiceAccessToken200JSONResponse(*tokenResult), nil
} else if !errors.Is(err, storage.ErrNotFound) {
// only log error, don't fail
log.Logger().WithError(err).Warnf("Failed to retrieve access token from cache: %s", err.Error())
}
}

var credentials []VerifiableCredential
if request.Body.Credentials != nil {
credentials = *request.Body.Credentials
Expand All @@ -737,6 +758,13 @@ func (r Wrapper) RequestServiceAccessToken(ctx context.Context, request RequestS
// this can be an internal server error, a 400 oauth error or a 412 precondition failed if the wallet does not contain the required credentials
return nil, err
}
if cacheToken {
err = tokenCache.Put(cacheKey, tokenResult)
if err != nil {
// only log error, don't fail
log.Logger().WithError(err).Warnf("Failed to cache access token: %s", err.Error())
}
}
return RequestServiceAccessToken200JSONResponse(*tokenResult), nil
}

Expand Down Expand Up @@ -897,6 +925,12 @@ func (r Wrapper) accessTokenServerStore() storage.SessionStore {
return r.storageEngine.GetSessionDatabase().GetStore(accessTokenValidity, "serveraccesstoken")
}

// accessTokenClientStore is used by the client to cache access tokens
func (r Wrapper) accessTokenCache() storage.SessionStore {
// we use a slightly reduced validity to prevent the cache from being used after the token has expired
return r.storageEngine.GetSessionDatabase().GetStore(accessTokenValidity-30*time.Second, "accesstokencache")
}

// accessTokenServerStore is used by the Auth server to store issued access tokens
func (r Wrapper) authzRequestObjectStore() storage.SessionStore {
return r.storageEngine.GetSessionDatabase().GetStore(accessTokenValidity, oauthRequestObjectKey...)
Expand Down Expand Up @@ -946,3 +980,15 @@ func (r Wrapper) determineClientDID(ctx context.Context, authServerMetadata oaut
}
return &candidateDIDs[0], nil
}

// accessTokenRequestCacheKey creates a cache key for the access token request.
// it writes the JSON to a sha256 hash and returns the hex encoded hash.
func accessTokenRequestCacheKey(request RequestServiceAccessTokenRequestObject) (string, error) {
// create a hash of the request
hash := sha256.New()
err := json.NewEncoder(hash).Encode(request)
if err != nil {
return "", err
}
return hex.EncodeToString(hash.Sum(nil)), nil
}
34 changes: 27 additions & 7 deletions auth/api/iam/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,6 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/nuts-foundation/nuts-node/core/to"
"github.com/nuts-foundation/nuts-node/crypto/storage/spi"
test2 "github.com/nuts-foundation/nuts-node/crypto/test"
"github.com/nuts-foundation/nuts-node/http/user"
"github.com/nuts-foundation/nuts-node/test"
"github.com/nuts-foundation/nuts-node/vdr/didsubject"
"io"
"net/http"
"net/http/httptest"
Expand All @@ -51,10 +45,15 @@ import (
"github.com/nuts-foundation/nuts-node/auth/oauth"
oauthServices "github.com/nuts-foundation/nuts-node/auth/services/oauth"
"github.com/nuts-foundation/nuts-node/core"
"github.com/nuts-foundation/nuts-node/core/to"
cryptoNuts "github.com/nuts-foundation/nuts-node/crypto"
"github.com/nuts-foundation/nuts-node/crypto/storage/spi"
test2 "github.com/nuts-foundation/nuts-node/crypto/test"
"github.com/nuts-foundation/nuts-node/http/user"
"github.com/nuts-foundation/nuts-node/jsonld"
"github.com/nuts-foundation/nuts-node/policy"
"github.com/nuts-foundation/nuts-node/storage"
"github.com/nuts-foundation/nuts-node/test"
"github.com/nuts-foundation/nuts-node/vcr"
"github.com/nuts-foundation/nuts-node/vcr/credential"
"github.com/nuts-foundation/nuts-node/vcr/holder"
Expand All @@ -63,6 +62,7 @@ import (
"github.com/nuts-foundation/nuts-node/vcr/types"
"github.com/nuts-foundation/nuts-node/vcr/verifier"
"github.com/nuts-foundation/nuts-node/vdr"
"github.com/nuts-foundation/nuts-node/vdr/didsubject"
"github.com/nuts-foundation/nuts-node/vdr/resolver"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -865,11 +865,31 @@ func TestWrapper_RequestServiceAccessToken(t *testing.T) {

t.Run("ok", func(t *testing.T) {
ctx := newTestClient(t)
request := RequestServiceAccessTokenRequestObject{SubjectID: holderSubjectID, Body: body}
ctx.iamClient.EXPECT().RequestRFC021AccessToken(nil, holderClientID, holderSubjectID, verifierURL.String(), "first second", true, nil).Return(&oauth.TokenResponse{}, nil)

_, err := ctx.client.RequestServiceAccessToken(nil, RequestServiceAccessTokenRequestObject{SubjectID: holderSubjectID, Body: body})
token, err := ctx.client.RequestServiceAccessToken(nil, request)

require.NoError(t, err)

t.Run("is cached", func(t *testing.T) {
cachedToken, err := ctx.client.RequestServiceAccessToken(nil, request)

require.NoError(t, err)
assert.Equal(t, token, cachedToken)
})

t.Run("cache expired", func(t *testing.T) {
cacheKey, _ := accessTokenRequestCacheKey(request)
_ = ctx.client.accessTokenCache().Delete(cacheKey)
ctx.iamClient.EXPECT().RequestRFC021AccessToken(nil, holderClientID, holderSubjectID, verifierURL.String(), "first second", true, nil).Return(&oauth.TokenResponse{AccessToken: "other"}, nil)

otherToken, err := ctx.client.RequestServiceAccessToken(nil, request)

require.NoError(t, err)

assert.NotEqual(t, token, otherToken)
})
})
t.Run("ok - no DPoP", func(t *testing.T) {
ctx := newTestClient(t)
Expand Down

0 comments on commit 9ed0647

Please sign in to comment.