Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cache access tokens client side #3565

Merged
merged 3 commits into from
Nov 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 45 additions & 1 deletion auth/api/iam/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@ import (
"bytes"
"context"
"crypto"
"crypto/sha256"
"embed"
"encoding/base64"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
Expand Down Expand Up @@ -72,6 +74,10 @@ type httpRequestContextKey struct{}
// TODO: Might want to make this configurable at some point
const accessTokenValidity = 15 * time.Minute

// accessTokenCacheOffset is used to reduce the ttl of the access token to ensure it is still valid when the client receives it.
// this to offset clock skew and roundtrip times
const accessTokenCacheOffset = 30 * time.Second

// cacheControlMaxAgeURLs holds API endpoints that should have a max-age cache control header set.
var cacheControlMaxAgeURLs = []string{
"/oauth2/:subjectID/presentation_definition",
Expand Down Expand Up @@ -722,6 +728,19 @@ func (r Wrapper) RequestServiceAccessToken(ctx context.Context, request RequestS
return nil, err
}

tokenCache := r.accessTokenCache()
cacheKey := accessTokenRequestCacheKey(request)

// try to retrieve token from cache
tokenResult := new(TokenResponse)
err = tokenCache.Get(cacheKey, tokenResult)
if err == nil {
return RequestServiceAccessToken200JSONResponse(*tokenResult), nil
} else if !errors.Is(err, storage.ErrNotFound) {
// only log error, don't fail
log.Logger().WithError(err).Warnf("Failed to retrieve access token from cache: %s", err.Error())
}

var credentials []VerifiableCredential
if request.Body.Credentials != nil {
credentials = *request.Body.Credentials
Expand All @@ -732,11 +751,22 @@ func (r Wrapper) RequestServiceAccessToken(ctx context.Context, request RequestS
useDPoP = false
}
clientID := r.subjectToBaseURL(request.SubjectID)
tokenResult, err := r.auth.IAMClient().RequestRFC021AccessToken(ctx, clientID.String(), request.SubjectID, request.Body.AuthorizationServer, request.Body.Scope, useDPoP, credentials)
tokenResult, err = r.auth.IAMClient().RequestRFC021AccessToken(ctx, clientID.String(), request.SubjectID, request.Body.AuthorizationServer, request.Body.Scope, useDPoP, credentials)
if err != nil {
// this can be an internal server error, a 400 oauth error or a 412 precondition failed if the wallet does not contain the required credentials
return nil, err
}
ttl := accessTokenValidity
if tokenResult.ExpiresIn != nil {
ttl = time.Second * time.Duration(*tokenResult.ExpiresIn)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

when could it not be set 🤔 ?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a pointer, probably because the spec says it's optional. In our implementation it's not but you never know.

}
// we reduce the ttl by accessTokenCacheOffset to make sure the token is expired when the cache expires
ttl -= accessTokenCacheOffset
err = tokenCache.Put(cacheKey, tokenResult, storage.WithTTL(ttl))
if err != nil {
// only log error, don't fail
log.Logger().WithError(err).Warnf("Failed to cache access token: %s", err.Error())
}
return RequestServiceAccessToken200JSONResponse(*tokenResult), nil
}

Expand Down Expand Up @@ -897,6 +927,12 @@ func (r Wrapper) accessTokenServerStore() storage.SessionStore {
return r.storageEngine.GetSessionDatabase().GetStore(accessTokenValidity, "serveraccesstoken")
}

// accessTokenClientStore is used by the client to cache access tokens
func (r Wrapper) accessTokenCache() storage.SessionStore {
// we use a slightly reduced validity to prevent the cache from being used after the token has expired
return r.storageEngine.GetSessionDatabase().GetStore(accessTokenValidity-accessTokenCacheOffset, "accesstokencache")
}

// accessTokenServerStore is used by the Auth server to store issued access tokens
func (r Wrapper) authzRequestObjectStore() storage.SessionStore {
return r.storageEngine.GetSessionDatabase().GetStore(accessTokenValidity, oauthRequestObjectKey...)
Expand Down Expand Up @@ -946,3 +982,11 @@ func (r Wrapper) determineClientDID(ctx context.Context, authServerMetadata oaut
}
return &candidateDIDs[0], nil
}

// accessTokenRequestCacheKey creates a cache key for the access token request.
// it writes the JSON to a sha256 hash and returns the hex encoded hash.
func accessTokenRequestCacheKey(request RequestServiceAccessTokenRequestObject) string {
hash := sha256.New()
_ = json.NewEncoder(hash).Encode(request)
return hex.EncodeToString(hash.Sum(nil))
}
71 changes: 64 additions & 7 deletions auth/api/iam/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,6 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/nuts-foundation/nuts-node/core/to"
"github.com/nuts-foundation/nuts-node/crypto/storage/spi"
test2 "github.com/nuts-foundation/nuts-node/crypto/test"
"github.com/nuts-foundation/nuts-node/http/user"
"github.com/nuts-foundation/nuts-node/test"
"github.com/nuts-foundation/nuts-node/vdr/didsubject"
"io"
"net/http"
"net/http/httptest"
Expand All @@ -51,10 +45,15 @@ import (
"github.com/nuts-foundation/nuts-node/auth/oauth"
oauthServices "github.com/nuts-foundation/nuts-node/auth/services/oauth"
"github.com/nuts-foundation/nuts-node/core"
"github.com/nuts-foundation/nuts-node/core/to"
cryptoNuts "github.com/nuts-foundation/nuts-node/crypto"
"github.com/nuts-foundation/nuts-node/crypto/storage/spi"
test2 "github.com/nuts-foundation/nuts-node/crypto/test"
"github.com/nuts-foundation/nuts-node/http/user"
"github.com/nuts-foundation/nuts-node/jsonld"
"github.com/nuts-foundation/nuts-node/policy"
"github.com/nuts-foundation/nuts-node/storage"
"github.com/nuts-foundation/nuts-node/test"
"github.com/nuts-foundation/nuts-node/vcr"
"github.com/nuts-foundation/nuts-node/vcr/credential"
"github.com/nuts-foundation/nuts-node/vcr/holder"
Expand All @@ -63,6 +62,7 @@ import (
"github.com/nuts-foundation/nuts-node/vcr/types"
"github.com/nuts-foundation/nuts-node/vcr/verifier"
"github.com/nuts-foundation/nuts-node/vdr"
"github.com/nuts-foundation/nuts-node/vdr/didsubject"
"github.com/nuts-foundation/nuts-node/vdr/resolver"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -865,11 +865,31 @@ func TestWrapper_RequestServiceAccessToken(t *testing.T) {

t.Run("ok", func(t *testing.T) {
ctx := newTestClient(t)
request := RequestServiceAccessTokenRequestObject{SubjectID: holderSubjectID, Body: body}
ctx.iamClient.EXPECT().RequestRFC021AccessToken(nil, holderClientID, holderSubjectID, verifierURL.String(), "first second", true, nil).Return(&oauth.TokenResponse{}, nil)

_, err := ctx.client.RequestServiceAccessToken(nil, RequestServiceAccessTokenRequestObject{SubjectID: holderSubjectID, Body: body})
token, err := ctx.client.RequestServiceAccessToken(nil, request)

require.NoError(t, err)

t.Run("is cached", func(t *testing.T) {
cachedToken, err := ctx.client.RequestServiceAccessToken(nil, request)

require.NoError(t, err)
assert.Equal(t, token, cachedToken)
})

t.Run("cache expired", func(t *testing.T) {
cacheKey := accessTokenRequestCacheKey(request)
_ = ctx.client.accessTokenCache().Delete(cacheKey)
ctx.iamClient.EXPECT().RequestRFC021AccessToken(nil, holderClientID, holderSubjectID, verifierURL.String(), "first second", true, nil).Return(&oauth.TokenResponse{AccessToken: "other"}, nil)

otherToken, err := ctx.client.RequestServiceAccessToken(nil, request)

require.NoError(t, err)

assert.NotEqual(t, token, otherToken)
})
})
t.Run("ok - no DPoP", func(t *testing.T) {
ctx := newTestClient(t)
Expand All @@ -885,6 +905,16 @@ func TestWrapper_RequestServiceAccessToken(t *testing.T) {

require.NoError(t, err)
})
t.Run("ok with expired cache by ttl", func(t *testing.T) {
ctx := newTestClient(t)
request := RequestServiceAccessTokenRequestObject{SubjectID: holderSubjectID, Body: body}
ctx.iamClient.EXPECT().RequestRFC021AccessToken(nil, holderClientID, holderSubjectID, verifierURL.String(), "first second", true, nil).Return(&oauth.TokenResponse{ExpiresIn: to.Ptr(5)}, nil)

_, err := ctx.client.RequestServiceAccessToken(nil, request)

require.NoError(t, err)
assert.False(t, ctx.client.accessTokenCache().Exists(accessTokenRequestCacheKey(request)))
})
t.Run("error - no matching credentials", func(t *testing.T) {
ctx := newTestClient(t)
ctx.iamClient.EXPECT().RequestRFC021AccessToken(nil, holderClientID, holderSubjectID, verifierURL.String(), "first second", true, nil).Return(nil, pe.ErrNoCredentials)
Expand All @@ -895,6 +925,24 @@ func TestWrapper_RequestServiceAccessToken(t *testing.T) {
assert.Equal(t, err, pe.ErrNoCredentials)
assert.Equal(t, http.StatusPreconditionFailed, statusCodeFrom(err))
})
t.Run("broken cache", func(t *testing.T) {
ctx := newTestClient(t)
mockStorage := storage.NewMockEngine(ctx.ctrl)
errorSessionDatabase := storage.NewErrorSessionDatabase(assert.AnError)
mockStorage.EXPECT().GetSessionDatabase().Return(errorSessionDatabase).AnyTimes()
ctx.client.storageEngine = mockStorage

request := RequestServiceAccessTokenRequestObject{SubjectID: holderSubjectID, Body: body}
ctx.iamClient.EXPECT().RequestRFC021AccessToken(nil, holderClientID, holderSubjectID, verifierURL.String(), "first second", true, nil).Return(&oauth.TokenResponse{AccessToken: "first"}, nil)
ctx.iamClient.EXPECT().RequestRFC021AccessToken(nil, holderClientID, holderSubjectID, verifierURL.String(), "first second", true, nil).Return(&oauth.TokenResponse{AccessToken: "second"}, nil)

token1, err := ctx.client.RequestServiceAccessToken(nil, request)
require.NoError(t, err)
token2, err := ctx.client.RequestServiceAccessToken(nil, request)
require.NoError(t, err)

assert.NotEqual(t, token1, token2)
})
}

func TestWrapper_RequestUserAccessToken(t *testing.T) {
Expand Down Expand Up @@ -1320,6 +1368,15 @@ func TestWrapper_subjectOwns(t *testing.T) {
})
}

func TestWrapper_accessTokenRequestCacheKey(t *testing.T) {
expected := "0cc6fbbd972c72de7bc86c6147347bdd54bcb41fe23cea3d8f61d6ddd75dbf86"
key := accessTokenRequestCacheKey(RequestServiceAccessTokenRequestObject{SubjectID: holderSubjectID, Body: &RequestServiceAccessTokenJSONRequestBody{Scope: "test"}})
other := accessTokenRequestCacheKey(RequestServiceAccessTokenRequestObject{SubjectID: holderSubjectID, Body: &RequestServiceAccessTokenJSONRequestBody{Scope: "test2"}})

assert.Equal(t, expected, key)
assert.NotEqual(t, key, other)
}

func createIssuerCredential(issuerDID did.DID, holderDID did.DID) *vc.VerifiableCredential {
privateKey, _ := spi.GenerateKeyPair()
credType := ssi.MustParseURI("ExampleType")
Expand Down
2 changes: 1 addition & 1 deletion storage/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ func (e *engine) Shutdown() error {
}

// Close session database
e.sessionDatabase.close()
e.sessionDatabase.Close()
// Close SQL db
if e.sqlDB != nil {
underlyingDB, err := e.sqlDB.DB()
Expand Down
23 changes: 20 additions & 3 deletions storage/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,11 @@ type SessionDatabase interface {
// The keys are used to logically partition the store, eg: tenants and/or flows that are not allowed to overlap like credential issuance and verification.
// The TTL is the time-to-live for the entries in the store.
GetStore(ttl time.Duration, keys ...string) SessionStore
// close stops any background processes and closes the database.
close()
// getFullKey returns the full key for the given key and prefixes.
// the supported chars differ per backend.
getFullKey(prefixes []string, key string) string
// Close stops any background processes and closes the database.
Close()
}

// SessionStore is a key-value store that holds session data.
Expand All @@ -95,10 +97,25 @@ type SessionStore interface {
// Returns ErrNotFound if the key does not exist.
Get(key string, target interface{}) error
// Put stores the given value for the given key.
Put(key string, value interface{}) error
// options can be used to fine-tune the storage of the item.
Put(key string, value interface{}, options ...SessionOption) error
// GetAndDelete combines Get and Delete as a convenience for burning nonce entries.
GetAndDelete(key string, target interface{}) error
}

// TransactionKey is the key used to store the SQL transaction in the context.
type TransactionKey struct{}

// SessionOption is an option that can be given when storing items.
type SessionOption func(target *sessionOptions)

type sessionOptions struct {
ttl time.Duration
}

// WithTTL sets the time-to-live for the stored item.
func WithTTL(ttl time.Duration) SessionOption {
return func(target *sessionOptions) {
target.ttl = ttl
}
}
37 changes: 21 additions & 16 deletions storage/mock.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

19 changes: 17 additions & 2 deletions storage/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,16 +85,31 @@ func (s SessionStoreImpl[T]) Get(key string, target interface{}) error {
return json.Unmarshal([]byte(val), target)
}

func (s SessionStoreImpl[T]) Put(key string, value interface{}) error {
func (s SessionStoreImpl[T]) Put(key string, value interface{}, options ...SessionOption) error {
opts := s.defaultOptions()
for _, opt := range options {
opt(&opts)
}
// TTL can't go below 0 because that is translated to "no expiration" by the library
// so just don't cache
if opts.ttl <= 0 {
return nil
}
bytes, err := json.Marshal(value)
if err != nil {
return err
}
return s.underlying.Set(context.Background(), s.db.getFullKey(s.prefixes, key), T(bytes), store.WithExpiration(s.ttl))
return s.underlying.Set(context.Background(), s.db.getFullKey(s.prefixes, key), T(bytes), store.WithExpiration(opts.ttl))
}
func (s SessionStoreImpl[T]) GetAndDelete(key string, target interface{}) error {
if err := s.Get(key, target); err != nil {
return err
}
return s.underlying.Delete(context.Background(), s.db.getFullKey(s.prefixes, key))
}

func (s SessionStoreImpl[T]) defaultOptions() sessionOptions {
return sessionOptions{
ttl: s.ttl,
}
}
2 changes: 1 addition & 1 deletion storage/session_inmemory.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func (s *InMemorySessionDatabase) GetStore(ttl time.Duration, keys ...string) Se
}
}

func (s *InMemorySessionDatabase) close() {
func (s *InMemorySessionDatabase) Close() {
// NOP
}

Expand Down
Loading
Loading