From a2623948050a4ef66be721a78800cf788b755bda Mon Sep 17 00:00:00 2001 From: Wout Slakhorst Date: Mon, 18 Nov 2024 12:03:29 +0100 Subject: [PATCH] allow for custom ttl in session store --- auth/api/iam/api.go | 58 ++++++++++++------------ auth/api/iam/api_test.go | 77 +++++++++++++++++++++++++++++++- storage/engine.go | 2 +- storage/interface.go | 21 +++++++-- storage/mock.go | 37 ++++++++------- storage/session_inmemory.go | 27 ++++++++--- storage/session_inmemory_test.go | 34 +++++++++++++- storage/session_redis.go | 17 +++++-- storage/session_redis_test.go | 14 +++--- storage/test.go | 2 +- 10 files changed, 220 insertions(+), 69 deletions(-) diff --git a/auth/api/iam/api.go b/auth/api/iam/api.go index 7b6b84edbf..34e1469abe 100644 --- a/auth/api/iam/api.go +++ b/auth/api/iam/api.go @@ -74,6 +74,10 @@ type httpRequestContextKey struct{} // TODO: Might want to make this configurable at some point const accessTokenValidity = 15 * time.Minute +// accessTokenCacheOffset is used to reduce the ttl of the access token to ensure it is still valid when the client receives it. +// this to offset clock skew and roundtrip times +const accessTokenCacheOffset = 30 * time.Second + // cacheControlMaxAgeURLs holds API endpoints that should have a max-age cache control header set. var cacheControlMaxAgeURLs = []string{ "/oauth2/:subjectID/presentation_definition", @@ -725,22 +729,16 @@ func (r Wrapper) RequestServiceAccessToken(ctx context.Context, request RequestS } tokenCache := r.accessTokenCache() - cacheKey, err := accessTokenRequestCacheKey(request) - cacheToken := true - if err != nil { - cacheToken = false + cacheKey := accessTokenRequestCacheKey(request) + + // try to retrieve token from cache + tokenResult := new(TokenResponse) + err = tokenCache.Get(cacheKey, tokenResult) + if err == nil { + return RequestServiceAccessToken200JSONResponse(*tokenResult), nil + } else if !errors.Is(err, storage.ErrNotFound) { // only log error, don't fail - log.Logger().WithError(err).Warnf("Failed to create cache key for access token request: %s", err.Error()) - } else { - // try to retrieve token from cache - tokenResult := new(TokenResponse) - err = tokenCache.Get(cacheKey, tokenResult) - if err == nil { - return RequestServiceAccessToken200JSONResponse(*tokenResult), nil - } else if !errors.Is(err, storage.ErrNotFound) { - // only log error, don't fail - log.Logger().WithError(err).Warnf("Failed to retrieve access token from cache: %s", err.Error()) - } + log.Logger().WithError(err).Warnf("Failed to retrieve access token from cache: %s", err.Error()) } var credentials []VerifiableCredential @@ -753,17 +751,21 @@ func (r Wrapper) RequestServiceAccessToken(ctx context.Context, request RequestS useDPoP = false } clientID := r.subjectToBaseURL(request.SubjectID) - tokenResult, err := r.auth.IAMClient().RequestRFC021AccessToken(ctx, clientID.String(), request.SubjectID, request.Body.AuthorizationServer, request.Body.Scope, useDPoP, credentials) + tokenResult, err = r.auth.IAMClient().RequestRFC021AccessToken(ctx, clientID.String(), request.SubjectID, request.Body.AuthorizationServer, request.Body.Scope, useDPoP, credentials) if err != nil { // this can be an internal server error, a 400 oauth error or a 412 precondition failed if the wallet does not contain the required credentials return nil, err } - if cacheToken { - err = tokenCache.Put(cacheKey, tokenResult) - if err != nil { - // only log error, don't fail - log.Logger().WithError(err).Warnf("Failed to cache access token: %s", err.Error()) - } + ttl := accessTokenValidity + if tokenResult.ExpiresIn != nil { + ttl = time.Second * time.Duration(*tokenResult.ExpiresIn) + } + // we reduce the ttl by accessTokenCacheOffset to make sure the token is expired when the cache expires + ttl -= accessTokenCacheOffset + err = tokenCache.Put(cacheKey, tokenResult, storage.WithTTL(ttl)) + if err != nil { + // only log error, don't fail + log.Logger().WithError(err).Warnf("Failed to cache access token: %s", err.Error()) } return RequestServiceAccessToken200JSONResponse(*tokenResult), nil } @@ -928,7 +930,7 @@ func (r Wrapper) accessTokenServerStore() storage.SessionStore { // accessTokenClientStore is used by the client to cache access tokens func (r Wrapper) accessTokenCache() storage.SessionStore { // we use a slightly reduced validity to prevent the cache from being used after the token has expired - return r.storageEngine.GetSessionDatabase().GetStore(accessTokenValidity-30*time.Second, "accesstokencache") + return r.storageEngine.GetSessionDatabase().GetStore(accessTokenValidity-accessTokenCacheOffset, "accesstokencache") } // accessTokenServerStore is used by the Auth server to store issued access tokens @@ -983,12 +985,8 @@ func (r Wrapper) determineClientDID(ctx context.Context, authServerMetadata oaut // accessTokenRequestCacheKey creates a cache key for the access token request. // it writes the JSON to a sha256 hash and returns the hex encoded hash. -func accessTokenRequestCacheKey(request RequestServiceAccessTokenRequestObject) (string, error) { - // create a hash of the request +func accessTokenRequestCacheKey(request RequestServiceAccessTokenRequestObject) string { hash := sha256.New() - err := json.NewEncoder(hash).Encode(request) - if err != nil { - return "", err - } - return hex.EncodeToString(hash.Sum(nil)), nil + _ = json.NewEncoder(hash).Encode(request) + return hex.EncodeToString(hash.Sum(nil)) } diff --git a/auth/api/iam/api_test.go b/auth/api/iam/api_test.go index aaac9ad8eb..0b233cd9d9 100644 --- a/auth/api/iam/api_test.go +++ b/auth/api/iam/api_test.go @@ -880,7 +880,7 @@ func TestWrapper_RequestServiceAccessToken(t *testing.T) { }) t.Run("cache expired", func(t *testing.T) { - cacheKey, _ := accessTokenRequestCacheKey(request) + cacheKey := accessTokenRequestCacheKey(request) _ = ctx.client.accessTokenCache().Delete(cacheKey) ctx.iamClient.EXPECT().RequestRFC021AccessToken(nil, holderClientID, holderSubjectID, verifierURL.String(), "first second", true, nil).Return(&oauth.TokenResponse{AccessToken: "other"}, nil) @@ -905,6 +905,16 @@ func TestWrapper_RequestServiceAccessToken(t *testing.T) { require.NoError(t, err) }) + t.Run("ok with expired cache by ttl", func(t *testing.T) { + ctx := newTestClient(t) + request := RequestServiceAccessTokenRequestObject{SubjectID: holderSubjectID, Body: body} + ctx.iamClient.EXPECT().RequestRFC021AccessToken(nil, holderClientID, holderSubjectID, verifierURL.String(), "first second", true, nil).Return(&oauth.TokenResponse{ExpiresIn: to.Ptr(0)}, nil) + + _, err := ctx.client.RequestServiceAccessToken(nil, request) + + require.NoError(t, err) + assert.False(t, ctx.client.accessTokenCache().Exists(accessTokenRequestCacheKey(request))) + }) t.Run("error - no matching credentials", func(t *testing.T) { ctx := newTestClient(t) ctx.iamClient.EXPECT().RequestRFC021AccessToken(nil, holderClientID, holderSubjectID, verifierURL.String(), "first second", true, nil).Return(nil, pe.ErrNoCredentials) @@ -915,6 +925,23 @@ func TestWrapper_RequestServiceAccessToken(t *testing.T) { assert.Equal(t, err, pe.ErrNoCredentials) assert.Equal(t, http.StatusPreconditionFailed, statusCodeFrom(err)) }) + t.Run("broken cache", func(t *testing.T) { + ctx := newTestClient(t) + mockStorage := storage.NewMockEngine(ctx.ctrl) + mockStorage.EXPECT().GetSessionDatabase().Return(errorSessionDatabase{err: assert.AnError}).AnyTimes() + ctx.client.storageEngine = mockStorage + + request := RequestServiceAccessTokenRequestObject{SubjectID: holderSubjectID, Body: body} + ctx.iamClient.EXPECT().RequestRFC021AccessToken(nil, holderClientID, holderSubjectID, verifierURL.String(), "first second", true, nil).Return(&oauth.TokenResponse{AccessToken: "first"}, nil) + ctx.iamClient.EXPECT().RequestRFC021AccessToken(nil, holderClientID, holderSubjectID, verifierURL.String(), "first second", true, nil).Return(&oauth.TokenResponse{AccessToken: "second"}, nil) + + token1, err := ctx.client.RequestServiceAccessToken(nil, request) + require.NoError(t, err) + token2, err := ctx.client.RequestServiceAccessToken(nil, request) + require.NoError(t, err) + + assert.NotEqual(t, token1, token2) + }) } func TestWrapper_RequestUserAccessToken(t *testing.T) { @@ -1340,6 +1367,15 @@ func TestWrapper_subjectOwns(t *testing.T) { }) } +func TestWrapper_accessTokenRequestCacheKey(t *testing.T) { + expected := "0cc6fbbd972c72de7bc86c6147347bdd54bcb41fe23cea3d8f61d6ddd75dbf86" + key := accessTokenRequestCacheKey(RequestServiceAccessTokenRequestObject{SubjectID: holderSubjectID, Body: &RequestServiceAccessTokenJSONRequestBody{Scope: "test"}}) + other := accessTokenRequestCacheKey(RequestServiceAccessTokenRequestObject{SubjectID: holderSubjectID, Body: &RequestServiceAccessTokenJSONRequestBody{Scope: "test2"}}) + + assert.Equal(t, expected, key) + assert.NotEqual(t, key, other) +} + func createIssuerCredential(issuerDID did.DID, holderDID did.DID) *vc.VerifiableCredential { privateKey, _ := spi.GenerateKeyPair() credType := ssi.MustParseURI("ExampleType") @@ -1509,3 +1545,42 @@ func newCustomTestClient(t testing.TB, publicURL *url.URL, authEndpointEnabled b client: client, } } + +var _ storage.SessionDatabase = (*errorSessionDatabase)(nil) +var _ storage.SessionStore = (*errorSessionStore)(nil) + +type errorSessionDatabase struct { + err error +} + +type errorSessionStore struct { + err error +} + +func (e errorSessionDatabase) GetStore(ttl time.Duration, keys ...string) storage.SessionStore { + return errorSessionStore{err: e.err} +} + +func (e errorSessionDatabase) Close() { + // nop +} + +func (e errorSessionStore) Delete(key string) error { + return e.err +} + +func (e errorSessionStore) Exists(key string) bool { + return false +} + +func (e errorSessionStore) Get(key string, target interface{}) error { + return e.err +} + +func (e errorSessionStore) Put(key string, value interface{}, options ...storage.SessionOption) error { + return e.err +} + +func (e errorSessionStore) GetAndDelete(key string, target interface{}) error { + return e.err +} diff --git a/storage/engine.go b/storage/engine.go index 6826d0a3be..1ab48485c8 100644 --- a/storage/engine.go +++ b/storage/engine.go @@ -135,7 +135,7 @@ func (e *engine) Shutdown() error { } // Close session database - e.sessionDatabase.close() + e.sessionDatabase.Close() // Close SQL db if e.sqlDB != nil { underlyingDB, err := e.sqlDB.DB() diff --git a/storage/interface.go b/storage/interface.go index b6aed53097..c978869458 100644 --- a/storage/interface.go +++ b/storage/interface.go @@ -78,8 +78,8 @@ type SessionDatabase interface { // The keys are used to logically partition the store, eg: tenants and/or flows that are not allowed to overlap like credential issuance and verification. // The TTL is the time-to-live for the entries in the store. GetStore(ttl time.Duration, keys ...string) SessionStore - // close stops any background processes and closes the database. - close() + // Close stops any background processes and closes the database. + Close() } // SessionStore is a key-value store that holds session data. @@ -94,10 +94,25 @@ type SessionStore interface { // Returns ErrNotFound if the key does not exist. Get(key string, target interface{}) error // Put stores the given value for the given key. - Put(key string, value interface{}) error + // options can be used to fine-tune the storage of the item. + Put(key string, value interface{}, options ...SessionOption) error // GetAndDelete combines Get and Delete as a convenience for burning nonce entries. GetAndDelete(key string, target interface{}) error } // TransactionKey is the key used to store the SQL transaction in the context. type TransactionKey struct{} + +// SessionOption is an option that can be given when storing items. +type SessionOption func(target *sessionOptions) + +type sessionOptions struct { + ttl time.Duration +} + +// WithTTL sets the time-to-live for the stored item. +func WithTTL(ttl time.Duration) SessionOption { + return func(target *sessionOptions) { + target.ttl = ttl + } +} diff --git a/storage/mock.go b/storage/mock.go index 3cf42b60ad..272de38f3f 100644 --- a/storage/mock.go +++ b/storage/mock.go @@ -255,6 +255,18 @@ func (m *MockSessionDatabase) EXPECT() *MockSessionDatabaseMockRecorder { return m.recorder } +// Close mocks base method. +func (m *MockSessionDatabase) Close() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Close") +} + +// Close indicates an expected call of Close. +func (mr *MockSessionDatabaseMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockSessionDatabase)(nil).Close)) +} + // GetStore mocks base method. func (m *MockSessionDatabase) GetStore(ttl time.Duration, keys ...string) SessionStore { m.ctrl.T.Helper() @@ -274,18 +286,6 @@ func (mr *MockSessionDatabaseMockRecorder) GetStore(ttl any, keys ...any) *gomoc return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStore", reflect.TypeOf((*MockSessionDatabase)(nil).GetStore), varargs...) } -// close mocks base method. -func (m *MockSessionDatabase) close() { - m.ctrl.T.Helper() - m.ctrl.Call(m, "close") -} - -// close indicates an expected call of close. -func (mr *MockSessionDatabaseMockRecorder) close() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "close", reflect.TypeOf((*MockSessionDatabase)(nil).close)) -} - // MockSessionStore is a mock of SessionStore interface. type MockSessionStore struct { ctrl *gomock.Controller @@ -367,15 +367,20 @@ func (mr *MockSessionStoreMockRecorder) GetAndDelete(key, target any) *gomock.Ca } // Put mocks base method. -func (m *MockSessionStore) Put(key string, value any) error { +func (m *MockSessionStore) Put(key string, value any, options ...SessionOption) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Put", key, value) + varargs := []any{key, value} + for _, a := range options { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Put", varargs...) ret0, _ := ret[0].(error) return ret0 } // Put indicates an expected call of Put. -func (mr *MockSessionStoreMockRecorder) Put(key, value any) *gomock.Call { +func (mr *MockSessionStoreMockRecorder) Put(key, value any, options ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Put", reflect.TypeOf((*MockSessionStore)(nil).Put), key, value) + varargs := append([]any{key, value}, options...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Put", reflect.TypeOf((*MockSessionStore)(nil).Put), varargs...) } diff --git a/storage/session_inmemory.go b/storage/session_inmemory.go index c5dfa96ccb..b600edb426 100644 --- a/storage/session_inmemory.go +++ b/storage/session_inmemory.go @@ -66,7 +66,7 @@ func (i *InMemorySessionDatabase) GetStore(ttl time.Duration, keys ...string) Se } } -func (i *InMemorySessionDatabase) close() { +func (i *InMemorySessionDatabase) Close() { // Signal pruner to stop and wait for it to finish i.done <- struct{}{} } @@ -127,8 +127,14 @@ func (i InMemorySessionStore) Exists(key string) bool { i.db.mux.Lock() defer i.db.mux.Unlock() - _, ok := i.db.entries[i.getFullKey(key)] - return ok + entry, ok := i.db.entries[i.getFullKey(key)] + if !ok { + return false + } + if entry.Expiry.Before(time.Now()) { + return false + } + return true } func (i InMemorySessionStore) Get(key string, target interface{}) error { @@ -151,7 +157,12 @@ func (i InMemorySessionStore) get(key string, target interface{}) error { return json.Unmarshal([]byte(entry.Value), target) } -func (i InMemorySessionStore) Put(key string, value interface{}) error { +func (i InMemorySessionStore) Put(key string, value interface{}, options ...SessionOption) error { + defaultOptions := i.defaultOptions() + for _, option := range options { + option(&defaultOptions) + } + i.db.mux.Lock() defer i.db.mux.Unlock() @@ -161,7 +172,7 @@ func (i InMemorySessionStore) Put(key string, value interface{}) error { } entry := expiringEntry{ Value: string(bytes), - Expiry: time.Now().Add(i.ttl), + Expiry: time.Now().Add(defaultOptions.ttl), } i.db.entries[i.getFullKey(key)] = entry @@ -180,3 +191,9 @@ func (i InMemorySessionStore) GetAndDelete(key string, target interface{}) error func (i InMemorySessionStore) getFullKey(key string) string { return strings.Join(append(i.prefixes, key), "/") } + +func (i InMemorySessionStore) defaultOptions() sessionOptions { + return sessionOptions{ + ttl: i.ttl, + } +} diff --git a/storage/session_inmemory_test.go b/storage/session_inmemory_test.go index bfc29aa77a..2d91c4db22 100644 --- a/storage/session_inmemory_test.go +++ b/storage/session_inmemory_test.go @@ -45,6 +45,36 @@ func TestInMemorySessionDatabase_GetStore(t *testing.T) { assert.Equal(t, []string{"key1", "key2"}, store.prefixes) } +func TestInMemorySessionStore_Exists(t *testing.T) { + db := createDatabase(t) + store := db.GetStore(time.Minute, "prefix").(InMemorySessionStore) + + t.Run("value exists", func(t *testing.T) { + _ = store.Put(t.Name(), "value") + + exists := store.Exists(t.Name()) + + assert.True(t, exists) + }) + + t.Run("value does not exist", func(t *testing.T) { + exists := store.Exists(t.Name()) + + assert.False(t, exists) + }) + + t.Run("value is expired", func(t *testing.T) { + store.db.entries["prefix/key"] = expiringEntry{ + Value: "", + Expiry: time.Now().Add(-time.Minute), + } + + exists := store.Exists("key") + + assert.False(t, exists) + }) +} + func TestInMemorySessionStore_Put(t *testing.T) { db := createDatabase(t) store := db.GetStore(time.Minute, "prefix").(InMemorySessionStore) @@ -210,7 +240,7 @@ func TestInMemorySessionDatabase_Close(t *testing.T) { }() store := NewInMemorySessionDatabase() time.Sleep(50 * time.Millisecond) // make sure pruning is running - store.close() + store.Close() }) } @@ -235,7 +265,7 @@ func Test_memoryStore_prune(t *testing.T) { }) t.Run("prunes expired flows", func(t *testing.T) { store := createDatabase(t) - defer store.close() + defer store.Close() _ = store.GetStore(0).Put("key1", "value") _ = store.GetStore(time.Minute).Put("key2", "value") diff --git a/storage/session_redis.go b/storage/session_redis.go index 83ca50ffa6..f59c6cf6d4 100644 --- a/storage/session_redis.go +++ b/storage/session_redis.go @@ -54,7 +54,7 @@ func (s redisSessionDatabase) GetStore(ttl time.Duration, keys ...string) Sessio } } -func (s redisSessionDatabase) close() { +func (s redisSessionDatabase) Close() { err := s.client.Close() if err != nil { log.Logger().WithError(err).Error("Failed to close redis client") @@ -91,12 +91,17 @@ func (s redisSessionStore) Get(key string, target interface{}) error { return json.Unmarshal([]byte(result), target) } -func (s redisSessionStore) Put(key string, value interface{}) error { +func (s redisSessionStore) Put(key string, value interface{}, options ...SessionOption) error { + defaultOptions := s.defaultOptions() + for _, option := range options { + option(&defaultOptions) + } + marshal, err := json.Marshal(value) if err != nil { return err } - return s.client.Set(context.Background(), s.toRedisKey(key), marshal, s.ttl).Err() + return s.client.Set(context.Background(), s.toRedisKey(key), marshal, defaultOptions.ttl).Err() } func (s redisSessionStore) GetAndDelete(key string, target interface{}) error { @@ -117,3 +122,9 @@ func (s redisSessionStore) toRedisKey(key string) string { } return key } + +func (i redisSessionStore) defaultOptions() sessionOptions { + return sessionOptions{ + ttl: i.ttl, + } +} diff --git a/storage/session_redis_test.go b/storage/session_redis_test.go index b957c15b55..500033c50f 100644 --- a/storage/session_redis_test.go +++ b/storage/session_redis_test.go @@ -45,7 +45,7 @@ func TestRedisSessionStore(t *testing.T) { store, _ := NewTestStorageEngineRedis(t) require.NoError(t, store.Start()) sessions := store.GetSessionDatabase() - defer sessions.close() + defer sessions.Close() t.Run("lifecycle", func(t *testing.T) { store := sessions.GetStore(time.Second, "unit") @@ -66,7 +66,7 @@ func TestRedisSessionStore_Get(t *testing.T) { storageEngine, miniRedis := NewTestStorageEngineRedis(t) require.NoError(t, storageEngine.Start()) sessions := storageEngine.GetSessionDatabase() - defer sessions.close() + defer sessions.Close() var actual testType t.Run("non-existing key", func(t *testing.T) { @@ -90,7 +90,7 @@ func TestRedisSessionStore_Delete(t *testing.T) { store, _ := NewTestStorageEngineRedis(t) require.NoError(t, store.Start()) sessions := store.GetSessionDatabase() - defer sessions.close() + defer sessions.Close() // We make sure the value exists in another store, // to test partitioning otherStore := sessions.GetStore(time.Second, "unit_other") @@ -111,7 +111,7 @@ func TestRedisSessionStore_GetAndDelete(t *testing.T) { storageEngine, miniRedis := NewTestStorageEngineRedis(t) require.NoError(t, storageEngine.Start()) sessions := storageEngine.GetSessionDatabase() - defer sessions.close() + defer sessions.Close() t.Run("ok", func(t *testing.T) { var actual testType @@ -152,7 +152,7 @@ func TestRedisSessionStore_Exists(t *testing.T) { store, miniRedis := NewTestStorageEngineRedis(t) require.NoError(t, store.Start()) sessions := store.GetSessionDatabase() - defer sessions.close() + defer sessions.Close() // We make sure the value exists in another store, // to test partitioning otherStore := sessions.GetStore(time.Second, "unit_other") @@ -177,7 +177,7 @@ func TestRedisSessionStore_Put(t *testing.T) { store, _ := NewTestStorageEngineRedis(t) require.NoError(t, store.Start()) sessions := store.GetSessionDatabase() - defer sessions.close() + defer sessions.Close() // We make sure the value exists in another store, // to test partitioning otherStore := sessions.GetStore(time.Second, "unit_other") @@ -209,7 +209,7 @@ func TestRedisSessionStore_Pruning(t *testing.T) { store, miniRedis := NewTestStorageEngineRedis(t) require.NoError(t, store.Start()) sessions := store.GetSessionDatabase() - defer sessions.close() + defer sessions.Close() // We make sure the value exists in another store, // to test partitioning otherStore := sessions.GetStore(time.Second*1, "unit_other") diff --git a/storage/test.go b/storage/test.go index 549b564cdd..526f11decc 100644 --- a/storage/test.go +++ b/storage/test.go @@ -121,7 +121,7 @@ func (p *StaticKVStoreProvider) GetKVStore(_ string, _ Class) (stoabs.KVStore, e func NewTestInMemorySessionDatabase(t *testing.T) *InMemorySessionDatabase { db := NewInMemorySessionDatabase() t.Cleanup(func() { - db.close() + db.Close() }) return db }