Skip to content

Commit

Permalink
allow for custom ttl in session store
Browse files Browse the repository at this point in the history
  • Loading branch information
woutslakhorst committed Nov 18, 2024
1 parent 9ed0647 commit ea54414
Show file tree
Hide file tree
Showing 10 changed files with 220 additions and 69 deletions.
58 changes: 28 additions & 30 deletions auth/api/iam/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,10 @@ type httpRequestContextKey struct{}
// TODO: Might want to make this configurable at some point
const accessTokenValidity = 15 * time.Minute

// accessTokenCacheOffset is used to reduce the ttl of the access token to ensure it is still valid when the client receives it.
// this to offset clock skew and roundtrip times
const accessTokenCacheOffset = 30 * time.Second

// cacheControlMaxAgeURLs holds API endpoints that should have a max-age cache control header set.
var cacheControlMaxAgeURLs = []string{
"/oauth2/:subjectID/presentation_definition",
Expand Down Expand Up @@ -725,22 +729,16 @@ func (r Wrapper) RequestServiceAccessToken(ctx context.Context, request RequestS
}

tokenCache := r.accessTokenCache()
cacheKey, err := accessTokenRequestCacheKey(request)
cacheToken := true
if err != nil {
cacheToken = false
cacheKey := accessTokenRequestCacheKey(request)

// try to retrieve token from cache
tokenResult := new(TokenResponse)
err = tokenCache.Get(cacheKey, tokenResult)
if err == nil {
return RequestServiceAccessToken200JSONResponse(*tokenResult), nil
} else if !errors.Is(err, storage.ErrNotFound) {
// only log error, don't fail
log.Logger().WithError(err).Warnf("Failed to create cache key for access token request: %s", err.Error())
} else {
// try to retrieve token from cache
tokenResult := new(TokenResponse)
err = tokenCache.Get(cacheKey, tokenResult)
if err == nil {
return RequestServiceAccessToken200JSONResponse(*tokenResult), nil
} else if !errors.Is(err, storage.ErrNotFound) {
// only log error, don't fail
log.Logger().WithError(err).Warnf("Failed to retrieve access token from cache: %s", err.Error())
}
log.Logger().WithError(err).Warnf("Failed to retrieve access token from cache: %s", err.Error())
}

var credentials []VerifiableCredential
Expand All @@ -753,17 +751,21 @@ func (r Wrapper) RequestServiceAccessToken(ctx context.Context, request RequestS
useDPoP = false
}
clientID := r.subjectToBaseURL(request.SubjectID)
tokenResult, err := r.auth.IAMClient().RequestRFC021AccessToken(ctx, clientID.String(), request.SubjectID, request.Body.AuthorizationServer, request.Body.Scope, useDPoP, credentials)
tokenResult, err = r.auth.IAMClient().RequestRFC021AccessToken(ctx, clientID.String(), request.SubjectID, request.Body.AuthorizationServer, request.Body.Scope, useDPoP, credentials)
if err != nil {
// this can be an internal server error, a 400 oauth error or a 412 precondition failed if the wallet does not contain the required credentials
return nil, err
}
if cacheToken {
err = tokenCache.Put(cacheKey, tokenResult)
if err != nil {
// only log error, don't fail
log.Logger().WithError(err).Warnf("Failed to cache access token: %s", err.Error())
}
ttl := accessTokenValidity
if tokenResult.ExpiresIn != nil {
ttl = time.Second * time.Duration(*tokenResult.ExpiresIn)
}
// we reduce the ttl by accessTokenCacheOffset to make sure the token is expired when the cache expires
ttl -= accessTokenCacheOffset
err = tokenCache.Put(cacheKey, tokenResult, storage.WithTTL(ttl))
if err != nil {
// only log error, don't fail
log.Logger().WithError(err).Warnf("Failed to cache access token: %s", err.Error())
}
return RequestServiceAccessToken200JSONResponse(*tokenResult), nil
}
Expand Down Expand Up @@ -928,7 +930,7 @@ func (r Wrapper) accessTokenServerStore() storage.SessionStore {
// accessTokenClientStore is used by the client to cache access tokens
func (r Wrapper) accessTokenCache() storage.SessionStore {
// we use a slightly reduced validity to prevent the cache from being used after the token has expired
return r.storageEngine.GetSessionDatabase().GetStore(accessTokenValidity-30*time.Second, "accesstokencache")
return r.storageEngine.GetSessionDatabase().GetStore(accessTokenValidity-accessTokenCacheOffset, "accesstokencache")
}

// accessTokenServerStore is used by the Auth server to store issued access tokens
Expand Down Expand Up @@ -983,12 +985,8 @@ func (r Wrapper) determineClientDID(ctx context.Context, authServerMetadata oaut

// accessTokenRequestCacheKey creates a cache key for the access token request.
// it writes the JSON to a sha256 hash and returns the hex encoded hash.
func accessTokenRequestCacheKey(request RequestServiceAccessTokenRequestObject) (string, error) {
// create a hash of the request
func accessTokenRequestCacheKey(request RequestServiceAccessTokenRequestObject) string {
hash := sha256.New()
err := json.NewEncoder(hash).Encode(request)
if err != nil {
return "", err
}
return hex.EncodeToString(hash.Sum(nil)), nil
_ = json.NewEncoder(hash).Encode(request)
return hex.EncodeToString(hash.Sum(nil))
}
77 changes: 76 additions & 1 deletion auth/api/iam/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -880,7 +880,7 @@ func TestWrapper_RequestServiceAccessToken(t *testing.T) {
})

t.Run("cache expired", func(t *testing.T) {
cacheKey, _ := accessTokenRequestCacheKey(request)
cacheKey := accessTokenRequestCacheKey(request)
_ = ctx.client.accessTokenCache().Delete(cacheKey)
ctx.iamClient.EXPECT().RequestRFC021AccessToken(nil, holderClientID, holderSubjectID, verifierURL.String(), "first second", true, nil).Return(&oauth.TokenResponse{AccessToken: "other"}, nil)

Expand All @@ -905,6 +905,16 @@ func TestWrapper_RequestServiceAccessToken(t *testing.T) {

require.NoError(t, err)
})
t.Run("ok with expired cache by ttl", func(t *testing.T) {
ctx := newTestClient(t)
request := RequestServiceAccessTokenRequestObject{SubjectID: holderSubjectID, Body: body}
ctx.iamClient.EXPECT().RequestRFC021AccessToken(nil, holderClientID, holderSubjectID, verifierURL.String(), "first second", true, nil).Return(&oauth.TokenResponse{ExpiresIn: to.Ptr(5)}, nil)

_, err := ctx.client.RequestServiceAccessToken(nil, request)

require.NoError(t, err)
assert.False(t, ctx.client.accessTokenCache().Exists(accessTokenRequestCacheKey(request)))
})
t.Run("error - no matching credentials", func(t *testing.T) {
ctx := newTestClient(t)
ctx.iamClient.EXPECT().RequestRFC021AccessToken(nil, holderClientID, holderSubjectID, verifierURL.String(), "first second", true, nil).Return(nil, pe.ErrNoCredentials)
Expand All @@ -915,6 +925,23 @@ func TestWrapper_RequestServiceAccessToken(t *testing.T) {
assert.Equal(t, err, pe.ErrNoCredentials)
assert.Equal(t, http.StatusPreconditionFailed, statusCodeFrom(err))
})
t.Run("broken cache", func(t *testing.T) {
ctx := newTestClient(t)
mockStorage := storage.NewMockEngine(ctx.ctrl)
mockStorage.EXPECT().GetSessionDatabase().Return(errorSessionDatabase{err: assert.AnError}).AnyTimes()
ctx.client.storageEngine = mockStorage

request := RequestServiceAccessTokenRequestObject{SubjectID: holderSubjectID, Body: body}
ctx.iamClient.EXPECT().RequestRFC021AccessToken(nil, holderClientID, holderSubjectID, verifierURL.String(), "first second", true, nil).Return(&oauth.TokenResponse{AccessToken: "first"}, nil)
ctx.iamClient.EXPECT().RequestRFC021AccessToken(nil, holderClientID, holderSubjectID, verifierURL.String(), "first second", true, nil).Return(&oauth.TokenResponse{AccessToken: "second"}, nil)

token1, err := ctx.client.RequestServiceAccessToken(nil, request)
require.NoError(t, err)
token2, err := ctx.client.RequestServiceAccessToken(nil, request)
require.NoError(t, err)

assert.NotEqual(t, token1, token2)
})
}

func TestWrapper_RequestUserAccessToken(t *testing.T) {
Expand Down Expand Up @@ -1340,6 +1367,15 @@ func TestWrapper_subjectOwns(t *testing.T) {
})
}

func TestWrapper_accessTokenRequestCacheKey(t *testing.T) {
expected := "0cc6fbbd972c72de7bc86c6147347bdd54bcb41fe23cea3d8f61d6ddd75dbf86"
key := accessTokenRequestCacheKey(RequestServiceAccessTokenRequestObject{SubjectID: holderSubjectID, Body: &RequestServiceAccessTokenJSONRequestBody{Scope: "test"}})
other := accessTokenRequestCacheKey(RequestServiceAccessTokenRequestObject{SubjectID: holderSubjectID, Body: &RequestServiceAccessTokenJSONRequestBody{Scope: "test2"}})

assert.Equal(t, expected, key)
assert.NotEqual(t, key, other)
}

func createIssuerCredential(issuerDID did.DID, holderDID did.DID) *vc.VerifiableCredential {
privateKey, _ := spi.GenerateKeyPair()
credType := ssi.MustParseURI("ExampleType")
Expand Down Expand Up @@ -1509,3 +1545,42 @@ func newCustomTestClient(t testing.TB, publicURL *url.URL, authEndpointEnabled b
client: client,
}
}

var _ storage.SessionDatabase = (*errorSessionDatabase)(nil)
var _ storage.SessionStore = (*errorSessionStore)(nil)

type errorSessionDatabase struct {
err error
}

type errorSessionStore struct {
err error
}

func (e errorSessionDatabase) GetStore(ttl time.Duration, keys ...string) storage.SessionStore {
return errorSessionStore{err: e.err}
}

func (e errorSessionDatabase) Close() {
// nop
}

func (e errorSessionStore) Delete(key string) error {
return e.err
}

func (e errorSessionStore) Exists(key string) bool {
return false
}

func (e errorSessionStore) Get(key string, target interface{}) error {
return e.err
}

func (e errorSessionStore) Put(key string, value interface{}, options ...storage.SessionOption) error {
return e.err
}

func (e errorSessionStore) GetAndDelete(key string, target interface{}) error {
return e.err
}
2 changes: 1 addition & 1 deletion storage/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ func (e *engine) Shutdown() error {
}

// Close session database
e.sessionDatabase.close()
e.sessionDatabase.Close()
// Close SQL db
if e.sqlDB != nil {
underlyingDB, err := e.sqlDB.DB()
Expand Down
21 changes: 18 additions & 3 deletions storage/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,8 @@ type SessionDatabase interface {
// The keys are used to logically partition the store, eg: tenants and/or flows that are not allowed to overlap like credential issuance and verification.
// The TTL is the time-to-live for the entries in the store.
GetStore(ttl time.Duration, keys ...string) SessionStore
// close stops any background processes and closes the database.
close()
// Close stops any background processes and closes the database.
Close()
}

// SessionStore is a key-value store that holds session data.
Expand All @@ -94,10 +94,25 @@ type SessionStore interface {
// Returns ErrNotFound if the key does not exist.
Get(key string, target interface{}) error
// Put stores the given value for the given key.
Put(key string, value interface{}) error
// options can be used to fine-tune the storage of the item.
Put(key string, value interface{}, options ...SessionOption) error
// GetAndDelete combines Get and Delete as a convenience for burning nonce entries.
GetAndDelete(key string, target interface{}) error
}

// TransactionKey is the key used to store the SQL transaction in the context.
type TransactionKey struct{}

// SessionOption is an option that can be given when storing items.
type SessionOption func(target *sessionOptions)

type sessionOptions struct {
ttl time.Duration
}

// WithTTL sets the time-to-live for the stored item.
func WithTTL(ttl time.Duration) SessionOption {
return func(target *sessionOptions) {
target.ttl = ttl
}
}
37 changes: 21 additions & 16 deletions storage/mock.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

27 changes: 22 additions & 5 deletions storage/session_inmemory.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ func (i *InMemorySessionDatabase) GetStore(ttl time.Duration, keys ...string) Se
}
}

func (i *InMemorySessionDatabase) close() {
func (i *InMemorySessionDatabase) Close() {
// Signal pruner to stop and wait for it to finish
i.done <- struct{}{}
}
Expand Down Expand Up @@ -127,8 +127,14 @@ func (i InMemorySessionStore) Exists(key string) bool {
i.db.mux.Lock()
defer i.db.mux.Unlock()

_, ok := i.db.entries[i.getFullKey(key)]
return ok
entry, ok := i.db.entries[i.getFullKey(key)]
if !ok {
return false
}
if entry.Expiry.Before(time.Now()) {
return false
}
return true
}

func (i InMemorySessionStore) Get(key string, target interface{}) error {
Expand All @@ -151,7 +157,12 @@ func (i InMemorySessionStore) get(key string, target interface{}) error {
return json.Unmarshal([]byte(entry.Value), target)
}

func (i InMemorySessionStore) Put(key string, value interface{}) error {
func (i InMemorySessionStore) Put(key string, value interface{}, options ...SessionOption) error {
defaultOptions := i.defaultOptions()
for _, option := range options {
option(&defaultOptions)
}

i.db.mux.Lock()
defer i.db.mux.Unlock()

Expand All @@ -161,7 +172,7 @@ func (i InMemorySessionStore) Put(key string, value interface{}) error {
}
entry := expiringEntry{
Value: string(bytes),
Expiry: time.Now().Add(i.ttl),
Expiry: time.Now().Add(defaultOptions.ttl),
}

i.db.entries[i.getFullKey(key)] = entry
Expand All @@ -180,3 +191,9 @@ func (i InMemorySessionStore) GetAndDelete(key string, target interface{}) error
func (i InMemorySessionStore) getFullKey(key string) string {
return strings.Join(append(i.prefixes, key), "/")
}

func (i InMemorySessionStore) defaultOptions() sessionOptions {
return sessionOptions{
ttl: i.ttl,
}
}
Loading

0 comments on commit ea54414

Please sign in to comment.