diff --git a/auth/api/iam/api.go b/auth/api/iam/api.go index 7b6b84edb..34e1469ab 100644 --- a/auth/api/iam/api.go +++ b/auth/api/iam/api.go @@ -74,6 +74,10 @@ type httpRequestContextKey struct{} // TODO: Might want to make this configurable at some point const accessTokenValidity = 15 * time.Minute +// accessTokenCacheOffset is used to reduce the ttl of the access token to ensure it is still valid when the client receives it. +// this to offset clock skew and roundtrip times +const accessTokenCacheOffset = 30 * time.Second + // cacheControlMaxAgeURLs holds API endpoints that should have a max-age cache control header set. var cacheControlMaxAgeURLs = []string{ "/oauth2/:subjectID/presentation_definition", @@ -725,22 +729,16 @@ func (r Wrapper) RequestServiceAccessToken(ctx context.Context, request RequestS } tokenCache := r.accessTokenCache() - cacheKey, err := accessTokenRequestCacheKey(request) - cacheToken := true - if err != nil { - cacheToken = false + cacheKey := accessTokenRequestCacheKey(request) + + // try to retrieve token from cache + tokenResult := new(TokenResponse) + err = tokenCache.Get(cacheKey, tokenResult) + if err == nil { + return RequestServiceAccessToken200JSONResponse(*tokenResult), nil + } else if !errors.Is(err, storage.ErrNotFound) { // only log error, don't fail - log.Logger().WithError(err).Warnf("Failed to create cache key for access token request: %s", err.Error()) - } else { - // try to retrieve token from cache - tokenResult := new(TokenResponse) - err = tokenCache.Get(cacheKey, tokenResult) - if err == nil { - return RequestServiceAccessToken200JSONResponse(*tokenResult), nil - } else if !errors.Is(err, storage.ErrNotFound) { - // only log error, don't fail - log.Logger().WithError(err).Warnf("Failed to retrieve access token from cache: %s", err.Error()) - } + log.Logger().WithError(err).Warnf("Failed to retrieve access token from cache: %s", err.Error()) } var credentials []VerifiableCredential @@ -753,17 +751,21 @@ func (r Wrapper) RequestServiceAccessToken(ctx context.Context, request RequestS useDPoP = false } clientID := r.subjectToBaseURL(request.SubjectID) - tokenResult, err := r.auth.IAMClient().RequestRFC021AccessToken(ctx, clientID.String(), request.SubjectID, request.Body.AuthorizationServer, request.Body.Scope, useDPoP, credentials) + tokenResult, err = r.auth.IAMClient().RequestRFC021AccessToken(ctx, clientID.String(), request.SubjectID, request.Body.AuthorizationServer, request.Body.Scope, useDPoP, credentials) if err != nil { // this can be an internal server error, a 400 oauth error or a 412 precondition failed if the wallet does not contain the required credentials return nil, err } - if cacheToken { - err = tokenCache.Put(cacheKey, tokenResult) - if err != nil { - // only log error, don't fail - log.Logger().WithError(err).Warnf("Failed to cache access token: %s", err.Error()) - } + ttl := accessTokenValidity + if tokenResult.ExpiresIn != nil { + ttl = time.Second * time.Duration(*tokenResult.ExpiresIn) + } + // we reduce the ttl by accessTokenCacheOffset to make sure the token is expired when the cache expires + ttl -= accessTokenCacheOffset + err = tokenCache.Put(cacheKey, tokenResult, storage.WithTTL(ttl)) + if err != nil { + // only log error, don't fail + log.Logger().WithError(err).Warnf("Failed to cache access token: %s", err.Error()) } return RequestServiceAccessToken200JSONResponse(*tokenResult), nil } @@ -928,7 +930,7 @@ func (r Wrapper) accessTokenServerStore() storage.SessionStore { // accessTokenClientStore is used by the client to cache access tokens func (r Wrapper) accessTokenCache() storage.SessionStore { // we use a slightly reduced validity to prevent the cache from being used after the token has expired - return r.storageEngine.GetSessionDatabase().GetStore(accessTokenValidity-30*time.Second, "accesstokencache") + return r.storageEngine.GetSessionDatabase().GetStore(accessTokenValidity-accessTokenCacheOffset, "accesstokencache") } // accessTokenServerStore is used by the Auth server to store issued access tokens @@ -983,12 +985,8 @@ func (r Wrapper) determineClientDID(ctx context.Context, authServerMetadata oaut // accessTokenRequestCacheKey creates a cache key for the access token request. // it writes the JSON to a sha256 hash and returns the hex encoded hash. -func accessTokenRequestCacheKey(request RequestServiceAccessTokenRequestObject) (string, error) { - // create a hash of the request +func accessTokenRequestCacheKey(request RequestServiceAccessTokenRequestObject) string { hash := sha256.New() - err := json.NewEncoder(hash).Encode(request) - if err != nil { - return "", err - } - return hex.EncodeToString(hash.Sum(nil)), nil + _ = json.NewEncoder(hash).Encode(request) + return hex.EncodeToString(hash.Sum(nil)) } diff --git a/auth/api/iam/api_test.go b/auth/api/iam/api_test.go index aaac9ad8e..f78c9fb35 100644 --- a/auth/api/iam/api_test.go +++ b/auth/api/iam/api_test.go @@ -880,7 +880,7 @@ func TestWrapper_RequestServiceAccessToken(t *testing.T) { }) t.Run("cache expired", func(t *testing.T) { - cacheKey, _ := accessTokenRequestCacheKey(request) + cacheKey := accessTokenRequestCacheKey(request) _ = ctx.client.accessTokenCache().Delete(cacheKey) ctx.iamClient.EXPECT().RequestRFC021AccessToken(nil, holderClientID, holderSubjectID, verifierURL.String(), "first second", true, nil).Return(&oauth.TokenResponse{AccessToken: "other"}, nil) @@ -905,6 +905,16 @@ func TestWrapper_RequestServiceAccessToken(t *testing.T) { require.NoError(t, err) }) + t.Run("ok with expired cache by ttl", func(t *testing.T) { + ctx := newTestClient(t) + request := RequestServiceAccessTokenRequestObject{SubjectID: holderSubjectID, Body: body} + ctx.iamClient.EXPECT().RequestRFC021AccessToken(nil, holderClientID, holderSubjectID, verifierURL.String(), "first second", true, nil).Return(&oauth.TokenResponse{ExpiresIn: to.Ptr(5)}, nil) + + _, err := ctx.client.RequestServiceAccessToken(nil, request) + + require.NoError(t, err) + assert.False(t, ctx.client.accessTokenCache().Exists(accessTokenRequestCacheKey(request))) + }) t.Run("error - no matching credentials", func(t *testing.T) { ctx := newTestClient(t) ctx.iamClient.EXPECT().RequestRFC021AccessToken(nil, holderClientID, holderSubjectID, verifierURL.String(), "first second", true, nil).Return(nil, pe.ErrNoCredentials) @@ -915,6 +925,24 @@ func TestWrapper_RequestServiceAccessToken(t *testing.T) { assert.Equal(t, err, pe.ErrNoCredentials) assert.Equal(t, http.StatusPreconditionFailed, statusCodeFrom(err)) }) + t.Run("broken cache", func(t *testing.T) { + ctx := newTestClient(t) + mockStorage := storage.NewMockEngine(ctx.ctrl) + errorSessionDatabase := storage.NewErrorSessionDatabase(assert.AnError) + mockStorage.EXPECT().GetSessionDatabase().Return(errorSessionDatabase).AnyTimes() + ctx.client.storageEngine = mockStorage + + request := RequestServiceAccessTokenRequestObject{SubjectID: holderSubjectID, Body: body} + ctx.iamClient.EXPECT().RequestRFC021AccessToken(nil, holderClientID, holderSubjectID, verifierURL.String(), "first second", true, nil).Return(&oauth.TokenResponse{AccessToken: "first"}, nil) + ctx.iamClient.EXPECT().RequestRFC021AccessToken(nil, holderClientID, holderSubjectID, verifierURL.String(), "first second", true, nil).Return(&oauth.TokenResponse{AccessToken: "second"}, nil) + + token1, err := ctx.client.RequestServiceAccessToken(nil, request) + require.NoError(t, err) + token2, err := ctx.client.RequestServiceAccessToken(nil, request) + require.NoError(t, err) + + assert.NotEqual(t, token1, token2) + }) } func TestWrapper_RequestUserAccessToken(t *testing.T) { @@ -1340,6 +1368,15 @@ func TestWrapper_subjectOwns(t *testing.T) { }) } +func TestWrapper_accessTokenRequestCacheKey(t *testing.T) { + expected := "0cc6fbbd972c72de7bc86c6147347bdd54bcb41fe23cea3d8f61d6ddd75dbf86" + key := accessTokenRequestCacheKey(RequestServiceAccessTokenRequestObject{SubjectID: holderSubjectID, Body: &RequestServiceAccessTokenJSONRequestBody{Scope: "test"}}) + other := accessTokenRequestCacheKey(RequestServiceAccessTokenRequestObject{SubjectID: holderSubjectID, Body: &RequestServiceAccessTokenJSONRequestBody{Scope: "test2"}}) + + assert.Equal(t, expected, key) + assert.NotEqual(t, key, other) +} + func createIssuerCredential(issuerDID did.DID, holderDID did.DID) *vc.VerifiableCredential { privateKey, _ := spi.GenerateKeyPair() credType := ssi.MustParseURI("ExampleType") diff --git a/storage/engine.go b/storage/engine.go index def57482f..8f9790c82 100644 --- a/storage/engine.go +++ b/storage/engine.go @@ -147,7 +147,7 @@ func (e *engine) Shutdown() error { } // Close session database - e.sessionDatabase.close() + e.sessionDatabase.Close() // Close SQL db if e.sqlDB != nil { underlyingDB, err := e.sqlDB.DB() diff --git a/storage/interface.go b/storage/interface.go index 762d74b44..80dbbd816 100644 --- a/storage/interface.go +++ b/storage/interface.go @@ -78,9 +78,11 @@ type SessionDatabase interface { // The keys are used to logically partition the store, eg: tenants and/or flows that are not allowed to overlap like credential issuance and verification. // The TTL is the time-to-live for the entries in the store. GetStore(ttl time.Duration, keys ...string) SessionStore - // close stops any background processes and closes the database. - close() + // getFullKey returns the full key for the given key and prefixes. + // the supported chars differ per backend. getFullKey(prefixes []string, key string) string + // Close stops any background processes and closes the database. + Close() } // SessionStore is a key-value store that holds session data. @@ -95,10 +97,25 @@ type SessionStore interface { // Returns ErrNotFound if the key does not exist. Get(key string, target interface{}) error // Put stores the given value for the given key. - Put(key string, value interface{}) error + // options can be used to fine-tune the storage of the item. + Put(key string, value interface{}, options ...SessionOption) error // GetAndDelete combines Get and Delete as a convenience for burning nonce entries. GetAndDelete(key string, target interface{}) error } // TransactionKey is the key used to store the SQL transaction in the context. type TransactionKey struct{} + +// SessionOption is an option that can be given when storing items. +type SessionOption func(target *sessionOptions) + +type sessionOptions struct { + ttl time.Duration +} + +// WithTTL sets the time-to-live for the stored item. +func WithTTL(ttl time.Duration) SessionOption { + return func(target *sessionOptions) { + target.ttl = ttl + } +} diff --git a/storage/mock.go b/storage/mock.go index 3cf42b60a..272de38f3 100644 --- a/storage/mock.go +++ b/storage/mock.go @@ -255,6 +255,18 @@ func (m *MockSessionDatabase) EXPECT() *MockSessionDatabaseMockRecorder { return m.recorder } +// Close mocks base method. +func (m *MockSessionDatabase) Close() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Close") +} + +// Close indicates an expected call of Close. +func (mr *MockSessionDatabaseMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockSessionDatabase)(nil).Close)) +} + // GetStore mocks base method. func (m *MockSessionDatabase) GetStore(ttl time.Duration, keys ...string) SessionStore { m.ctrl.T.Helper() @@ -274,18 +286,6 @@ func (mr *MockSessionDatabaseMockRecorder) GetStore(ttl any, keys ...any) *gomoc return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStore", reflect.TypeOf((*MockSessionDatabase)(nil).GetStore), varargs...) } -// close mocks base method. -func (m *MockSessionDatabase) close() { - m.ctrl.T.Helper() - m.ctrl.Call(m, "close") -} - -// close indicates an expected call of close. -func (mr *MockSessionDatabaseMockRecorder) close() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "close", reflect.TypeOf((*MockSessionDatabase)(nil).close)) -} - // MockSessionStore is a mock of SessionStore interface. type MockSessionStore struct { ctrl *gomock.Controller @@ -367,15 +367,20 @@ func (mr *MockSessionStoreMockRecorder) GetAndDelete(key, target any) *gomock.Ca } // Put mocks base method. -func (m *MockSessionStore) Put(key string, value any) error { +func (m *MockSessionStore) Put(key string, value any, options ...SessionOption) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Put", key, value) + varargs := []any{key, value} + for _, a := range options { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Put", varargs...) ret0, _ := ret[0].(error) return ret0 } // Put indicates an expected call of Put. -func (mr *MockSessionStoreMockRecorder) Put(key, value any) *gomock.Call { +func (mr *MockSessionStoreMockRecorder) Put(key, value any, options ...any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Put", reflect.TypeOf((*MockSessionStore)(nil).Put), key, value) + varargs := append([]any{key, value}, options...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Put", reflect.TypeOf((*MockSessionStore)(nil).Put), varargs...) } diff --git a/storage/session.go b/storage/session.go index 77d90d5c5..106a7d72b 100644 --- a/storage/session.go +++ b/storage/session.go @@ -85,12 +85,21 @@ func (s SessionStoreImpl[T]) Get(key string, target interface{}) error { return json.Unmarshal([]byte(val), target) } -func (s SessionStoreImpl[T]) Put(key string, value interface{}) error { +func (s SessionStoreImpl[T]) Put(key string, value interface{}, options ...SessionOption) error { + opts := s.defaultOptions() + for _, opt := range options { + opt(&opts) + } + // TTL can't go below 0 because that is translated to "no expiration" by the library + // in that case it should be 1 nanosecond + if opts.ttl < 0 { + opts.ttl = 1 + } bytes, err := json.Marshal(value) if err != nil { return err } - return s.underlying.Set(context.Background(), s.db.getFullKey(s.prefixes, key), T(bytes), store.WithExpiration(s.ttl)) + return s.underlying.Set(context.Background(), s.db.getFullKey(s.prefixes, key), T(bytes), store.WithExpiration(opts.ttl)) } func (s SessionStoreImpl[T]) GetAndDelete(key string, target interface{}) error { if err := s.Get(key, target); err != nil { @@ -98,3 +107,9 @@ func (s SessionStoreImpl[T]) GetAndDelete(key string, target interface{}) error } return s.underlying.Delete(context.Background(), s.db.getFullKey(s.prefixes, key)) } + +func (s SessionStoreImpl[T]) defaultOptions() sessionOptions { + return sessionOptions{ + ttl: s.ttl, + } +} diff --git a/storage/session_inmemory.go b/storage/session_inmemory.go index c65bbfcdc..29cfd62f0 100644 --- a/storage/session_inmemory.go +++ b/storage/session_inmemory.go @@ -56,7 +56,7 @@ func (s *InMemorySessionDatabase) GetStore(ttl time.Duration, keys ...string) Se } } -func (s *InMemorySessionDatabase) close() { +func (s *InMemorySessionDatabase) Close() { // NOP } diff --git a/storage/session_inmemory_test.go b/storage/session_inmemory_test.go index 3222e3c35..b3511eafc 100644 --- a/storage/session_inmemory_test.go +++ b/storage/session_inmemory_test.go @@ -42,6 +42,71 @@ func TestInMemorySessionDatabase_GetStore(t *testing.T) { assert.Equal(t, []string{"key1", "key2"}, store.prefixes) } +func TestInMemorySessionStore_Exists(t *testing.T) { + db := createDatabase(t) + store := db.GetStore(time.Minute, "prefix") + + t.Run("value exists", func(t *testing.T) { + _ = store.Put(t.Name(), "value") + + exists := store.Exists(t.Name()) + + assert.True(t, exists) + }) + + t.Run("value does not exist", func(t *testing.T) { + exists := store.Exists(t.Name()) + + assert.False(t, exists) + }) +} + +func TestInMemorySessionStore_Put(t *testing.T) { + db := createDatabase(t) + store := db.GetStore(time.Minute, "prefix") + + t.Run("string value is stored", func(t *testing.T) { + err := store.Put("key", "value") + + require.NoError(t, err) + + var val string + err = store.Get("key", &val) + require.NoError(t, err) + assert.Equal(t, "value", val) + }) + + t.Run("float value is stored", func(t *testing.T) { + err := store.Put("key", 1.23) + + require.NoError(t, err) + + var val float64 + err = store.Get("key", &val) + assert.Equal(t, 1.23, val) + }) + + t.Run("struct value is stored", func(t *testing.T) { + value := testStruct{ + Field1: "value", + } + + err := store.Put("key", value) + + require.NoError(t, err) + + var val testStruct + err = store.Get("key", &val) + assert.Equal(t, value, val) + }) + + t.Run("value is not JSON", func(t *testing.T) { + err := store.Put("key", make(chan int)) + + assert.Error(t, err) + }) +} + func TestInMemorySessionStore_Get(t *testing.T) { db := createDatabase(t) store := db.GetStore(time.Minute, "prefix").(SessionStoreImpl[[]byte]) diff --git a/storage/session_memcached.go b/storage/session_memcached.go index eb1f61897..1cd3ee4ac 100644 --- a/storage/session_memcached.go +++ b/storage/session_memcached.go @@ -40,6 +40,7 @@ func NewMemcachedSessionDatabase(client *memcache.Client) *MemcachedSessionDatab memcachedStore := memcachestore.NewMemcache(client, store.WithExpiration(defaultSessionDataTTL)) return &MemcachedSessionDatabase{ underlying: cache.New[[]byte](memcachedStore), + client: client, } } @@ -52,8 +53,7 @@ func (s MemcachedSessionDatabase) GetStore(ttl time.Duration, keys ...string) Se } } -func (s MemcachedSessionDatabase) close() { - // noop +func (s MemcachedSessionDatabase) Close() { if s.client != nil { _ = s.client.Close() } diff --git a/storage/session_redis.go b/storage/session_redis.go index 2567a339b..df163c53f 100644 --- a/storage/session_redis.go +++ b/storage/session_redis.go @@ -30,6 +30,7 @@ import ( type redisSessionDatabase struct { underlying *cache.Cache[string] prefix string + client *redis.Client } func NewRedisSessionDatabase(client *redis.Client, prefix string) SessionDatabase { @@ -37,6 +38,7 @@ func NewRedisSessionDatabase(client *redis.Client, prefix string) SessionDatabas return redisSessionDatabase{ underlying: cache.New[string](redisStore), prefix: prefix, + client: client, } } @@ -54,8 +56,10 @@ func (s redisSessionDatabase) GetStore(ttl time.Duration, keys ...string) Sessio } } -func (s redisSessionDatabase) close() { - // nop +func (s redisSessionDatabase) Close() { + if s.client != nil { + _ = s.client.Close() + } } func (s redisSessionDatabase) getFullKey(prefixes []string, key string) string { diff --git a/storage/session_redis_test.go b/storage/session_redis_test.go index 224c372d4..a78b0bbc8 100644 --- a/storage/session_redis_test.go +++ b/storage/session_redis_test.go @@ -45,7 +45,7 @@ func TestRedisSessionStore(t *testing.T) { store, _ := NewTestStorageEngineRedis(t) require.NoError(t, store.Start()) sessions := store.GetSessionDatabase() - defer sessions.close() + defer sessions.Close() t.Run("lifecycle", func(t *testing.T) { store := sessions.GetStore(time.Second, "unit") @@ -66,7 +66,7 @@ func TestRedisSessionStore_Get(t *testing.T) { storageEngine, miniRedis := NewTestStorageEngineRedis(t) require.NoError(t, storageEngine.Start()) sessions := storageEngine.GetSessionDatabase() - defer sessions.close() + defer sessions.Close() var actual testType t.Run("non-existing key", func(t *testing.T) { @@ -90,7 +90,7 @@ func TestRedisSessionStore_Delete(t *testing.T) { store, _ := NewTestStorageEngineRedis(t) require.NoError(t, store.Start()) sessions := store.GetSessionDatabase() - defer sessions.close() + defer sessions.Close() // We make sure the value exists in another store, // to test partitioning otherStore := sessions.GetStore(time.Second, "unit_other") @@ -111,7 +111,7 @@ func TestRedisSessionStore_GetAndDelete(t *testing.T) { storageEngine, miniRedis := NewTestStorageEngineRedis(t) require.NoError(t, storageEngine.Start()) sessions := storageEngine.GetSessionDatabase() - defer sessions.close() + defer sessions.Close() t.Run("ok", func(t *testing.T) { var actual testType @@ -152,7 +152,7 @@ func TestRedisSessionStore_Exists(t *testing.T) { store, miniRedis := NewTestStorageEngineRedis(t) require.NoError(t, store.Start()) sessions := store.GetSessionDatabase() - defer sessions.close() + defer sessions.Close() // We make sure the value exists in another store, // to test partitioning otherStore := sessions.GetStore(time.Second, "unit_other") @@ -177,7 +177,7 @@ func TestRedisSessionStore_Put(t *testing.T) { store, _ := NewTestStorageEngineRedis(t) require.NoError(t, store.Start()) sessions := store.GetSessionDatabase() - defer sessions.close() + defer sessions.Close() // We make sure the value exists in another store, // to test partitioning otherStore := sessions.GetStore(time.Second, "unit_other") @@ -209,7 +209,7 @@ func TestRedisSessionStore_Pruning(t *testing.T) { store, miniRedis := NewTestStorageEngineRedis(t) require.NoError(t, store.Start()) sessions := store.GetSessionDatabase() - defer sessions.close() + defer sessions.Close() // We make sure the value exists in another store, // to test partitioning otherStore := sessions.GetStore(time.Second*1, "unit_other") diff --git a/storage/test.go b/storage/test.go index 164878511..b22d77411 100644 --- a/storage/test.go +++ b/storage/test.go @@ -122,7 +122,7 @@ func (p *StaticKVStoreProvider) GetKVStore(_ string, _ Class) (stoabs.KVStore, e func NewTestInMemorySessionDatabase(t *testing.T) *InMemorySessionDatabase { db := NewInMemorySessionDatabase() t.Cleanup(func() { - db.close() + db.Close() }) return db } @@ -164,7 +164,7 @@ func (e errorSessionDatabase) getFullKey(prefixes []string, key string) string { return "" } -func (e errorSessionDatabase) close() { +func (e errorSessionDatabase) Close() { // nop } @@ -180,7 +180,7 @@ func (e errorSessionStore) Get(key string, target interface{}) error { return e.err } -func (e errorSessionStore) Put(key string, value interface{}) error { +func (e errorSessionStore) Put(key string, value interface{}, options ...SessionOption) error { return e.err }