Skip to content

Commit

Permalink
allow for custom ttl in session store
Browse files Browse the repository at this point in the history
  • Loading branch information
woutslakhorst committed Nov 22, 2024
1 parent 271f85e commit 0e660b0
Show file tree
Hide file tree
Showing 12 changed files with 209 additions and 68 deletions.
58 changes: 28 additions & 30 deletions auth/api/iam/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,10 @@ type httpRequestContextKey struct{}
// TODO: Might want to make this configurable at some point
const accessTokenValidity = 15 * time.Minute

// accessTokenCacheOffset is used to reduce the ttl of the access token to ensure it is still valid when the client receives it.
// this to offset clock skew and roundtrip times
const accessTokenCacheOffset = 30 * time.Second

// cacheControlMaxAgeURLs holds API endpoints that should have a max-age cache control header set.
var cacheControlMaxAgeURLs = []string{
"/oauth2/:subjectID/presentation_definition",
Expand Down Expand Up @@ -725,22 +729,16 @@ func (r Wrapper) RequestServiceAccessToken(ctx context.Context, request RequestS
}

tokenCache := r.accessTokenCache()
cacheKey, err := accessTokenRequestCacheKey(request)
cacheToken := true
if err != nil {
cacheToken = false
cacheKey := accessTokenRequestCacheKey(request)

// try to retrieve token from cache
tokenResult := new(TokenResponse)
err = tokenCache.Get(cacheKey, tokenResult)
if err == nil {
return RequestServiceAccessToken200JSONResponse(*tokenResult), nil
} else if !errors.Is(err, storage.ErrNotFound) {
// only log error, don't fail
log.Logger().WithError(err).Warnf("Failed to create cache key for access token request: %s", err.Error())
} else {
// try to retrieve token from cache
tokenResult := new(TokenResponse)
err = tokenCache.Get(cacheKey, tokenResult)
if err == nil {
return RequestServiceAccessToken200JSONResponse(*tokenResult), nil
} else if !errors.Is(err, storage.ErrNotFound) {
// only log error, don't fail
log.Logger().WithError(err).Warnf("Failed to retrieve access token from cache: %s", err.Error())
}
log.Logger().WithError(err).Warnf("Failed to retrieve access token from cache: %s", err.Error())
}

var credentials []VerifiableCredential
Expand All @@ -753,17 +751,21 @@ func (r Wrapper) RequestServiceAccessToken(ctx context.Context, request RequestS
useDPoP = false
}
clientID := r.subjectToBaseURL(request.SubjectID)
tokenResult, err := r.auth.IAMClient().RequestRFC021AccessToken(ctx, clientID.String(), request.SubjectID, request.Body.AuthorizationServer, request.Body.Scope, useDPoP, credentials)
tokenResult, err = r.auth.IAMClient().RequestRFC021AccessToken(ctx, clientID.String(), request.SubjectID, request.Body.AuthorizationServer, request.Body.Scope, useDPoP, credentials)
if err != nil {
// this can be an internal server error, a 400 oauth error or a 412 precondition failed if the wallet does not contain the required credentials
return nil, err
}
if cacheToken {
err = tokenCache.Put(cacheKey, tokenResult)
if err != nil {
// only log error, don't fail
log.Logger().WithError(err).Warnf("Failed to cache access token: %s", err.Error())
}
ttl := accessTokenValidity
if tokenResult.ExpiresIn != nil {
ttl = time.Second * time.Duration(*tokenResult.ExpiresIn)
}
// we reduce the ttl by accessTokenCacheOffset to make sure the token is expired when the cache expires
ttl -= accessTokenCacheOffset
err = tokenCache.Put(cacheKey, tokenResult, storage.WithTTL(ttl))
if err != nil {
// only log error, don't fail
log.Logger().WithError(err).Warnf("Failed to cache access token: %s", err.Error())
}
return RequestServiceAccessToken200JSONResponse(*tokenResult), nil
}
Expand Down Expand Up @@ -928,7 +930,7 @@ func (r Wrapper) accessTokenServerStore() storage.SessionStore {
// accessTokenClientStore is used by the client to cache access tokens
func (r Wrapper) accessTokenCache() storage.SessionStore {
// we use a slightly reduced validity to prevent the cache from being used after the token has expired
return r.storageEngine.GetSessionDatabase().GetStore(accessTokenValidity-30*time.Second, "accesstokencache")
return r.storageEngine.GetSessionDatabase().GetStore(accessTokenValidity-accessTokenCacheOffset, "accesstokencache")
}

// accessTokenServerStore is used by the Auth server to store issued access tokens
Expand Down Expand Up @@ -983,12 +985,8 @@ func (r Wrapper) determineClientDID(ctx context.Context, authServerMetadata oaut

// accessTokenRequestCacheKey creates a cache key for the access token request.
// it writes the JSON to a sha256 hash and returns the hex encoded hash.
func accessTokenRequestCacheKey(request RequestServiceAccessTokenRequestObject) (string, error) {
// create a hash of the request
func accessTokenRequestCacheKey(request RequestServiceAccessTokenRequestObject) string {
hash := sha256.New()
err := json.NewEncoder(hash).Encode(request)
if err != nil {
return "", err
}
return hex.EncodeToString(hash.Sum(nil)), nil
_ = json.NewEncoder(hash).Encode(request)
return hex.EncodeToString(hash.Sum(nil))
}
39 changes: 38 additions & 1 deletion auth/api/iam/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -880,7 +880,7 @@ func TestWrapper_RequestServiceAccessToken(t *testing.T) {
})

t.Run("cache expired", func(t *testing.T) {
cacheKey, _ := accessTokenRequestCacheKey(request)
cacheKey := accessTokenRequestCacheKey(request)
_ = ctx.client.accessTokenCache().Delete(cacheKey)
ctx.iamClient.EXPECT().RequestRFC021AccessToken(nil, holderClientID, holderSubjectID, verifierURL.String(), "first second", true, nil).Return(&oauth.TokenResponse{AccessToken: "other"}, nil)

Expand All @@ -905,6 +905,16 @@ func TestWrapper_RequestServiceAccessToken(t *testing.T) {

require.NoError(t, err)
})
t.Run("ok with expired cache by ttl", func(t *testing.T) {
ctx := newTestClient(t)
request := RequestServiceAccessTokenRequestObject{SubjectID: holderSubjectID, Body: body}
ctx.iamClient.EXPECT().RequestRFC021AccessToken(nil, holderClientID, holderSubjectID, verifierURL.String(), "first second", true, nil).Return(&oauth.TokenResponse{ExpiresIn: to.Ptr(5)}, nil)

_, err := ctx.client.RequestServiceAccessToken(nil, request)

require.NoError(t, err)
assert.False(t, ctx.client.accessTokenCache().Exists(accessTokenRequestCacheKey(request)))
})
t.Run("error - no matching credentials", func(t *testing.T) {
ctx := newTestClient(t)
ctx.iamClient.EXPECT().RequestRFC021AccessToken(nil, holderClientID, holderSubjectID, verifierURL.String(), "first second", true, nil).Return(nil, pe.ErrNoCredentials)
Expand All @@ -915,6 +925,24 @@ func TestWrapper_RequestServiceAccessToken(t *testing.T) {
assert.Equal(t, err, pe.ErrNoCredentials)
assert.Equal(t, http.StatusPreconditionFailed, statusCodeFrom(err))
})
t.Run("broken cache", func(t *testing.T) {
ctx := newTestClient(t)
mockStorage := storage.NewMockEngine(ctx.ctrl)
errorSessionDatabase := storage.NewErrorSessionDatabase(assert.AnError)
mockStorage.EXPECT().GetSessionDatabase().Return(errorSessionDatabase).AnyTimes()
ctx.client.storageEngine = mockStorage

request := RequestServiceAccessTokenRequestObject{SubjectID: holderSubjectID, Body: body}
ctx.iamClient.EXPECT().RequestRFC021AccessToken(nil, holderClientID, holderSubjectID, verifierURL.String(), "first second", true, nil).Return(&oauth.TokenResponse{AccessToken: "first"}, nil)
ctx.iamClient.EXPECT().RequestRFC021AccessToken(nil, holderClientID, holderSubjectID, verifierURL.String(), "first second", true, nil).Return(&oauth.TokenResponse{AccessToken: "second"}, nil)

token1, err := ctx.client.RequestServiceAccessToken(nil, request)
require.NoError(t, err)
token2, err := ctx.client.RequestServiceAccessToken(nil, request)
require.NoError(t, err)

assert.NotEqual(t, token1, token2)
})
}

func TestWrapper_RequestUserAccessToken(t *testing.T) {
Expand Down Expand Up @@ -1340,6 +1368,15 @@ func TestWrapper_subjectOwns(t *testing.T) {
})
}

func TestWrapper_accessTokenRequestCacheKey(t *testing.T) {
expected := "0cc6fbbd972c72de7bc86c6147347bdd54bcb41fe23cea3d8f61d6ddd75dbf86"
key := accessTokenRequestCacheKey(RequestServiceAccessTokenRequestObject{SubjectID: holderSubjectID, Body: &RequestServiceAccessTokenJSONRequestBody{Scope: "test"}})
other := accessTokenRequestCacheKey(RequestServiceAccessTokenRequestObject{SubjectID: holderSubjectID, Body: &RequestServiceAccessTokenJSONRequestBody{Scope: "test2"}})

assert.Equal(t, expected, key)
assert.NotEqual(t, key, other)
}

func createIssuerCredential(issuerDID did.DID, holderDID did.DID) *vc.VerifiableCredential {
privateKey, _ := spi.GenerateKeyPair()
credType := ssi.MustParseURI("ExampleType")
Expand Down
2 changes: 1 addition & 1 deletion storage/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ func (e *engine) Shutdown() error {
}

// Close session database
e.sessionDatabase.close()
e.sessionDatabase.Close()
// Close SQL db
if e.sqlDB != nil {
underlyingDB, err := e.sqlDB.DB()
Expand Down
23 changes: 20 additions & 3 deletions storage/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,11 @@ type SessionDatabase interface {
// The keys are used to logically partition the store, eg: tenants and/or flows that are not allowed to overlap like credential issuance and verification.
// The TTL is the time-to-live for the entries in the store.
GetStore(ttl time.Duration, keys ...string) SessionStore
// close stops any background processes and closes the database.
close()
// getFullKey returns the full key for the given key and prefixes.
// the supported chars differ per backend.
getFullKey(prefixes []string, key string) string
// Close stops any background processes and closes the database.
Close()
}

// SessionStore is a key-value store that holds session data.
Expand All @@ -95,10 +97,25 @@ type SessionStore interface {
// Returns ErrNotFound if the key does not exist.
Get(key string, target interface{}) error
// Put stores the given value for the given key.
Put(key string, value interface{}) error
// options can be used to fine-tune the storage of the item.
Put(key string, value interface{}, options ...SessionOption) error
// GetAndDelete combines Get and Delete as a convenience for burning nonce entries.
GetAndDelete(key string, target interface{}) error
}

// TransactionKey is the key used to store the SQL transaction in the context.
type TransactionKey struct{}

// SessionOption is an option that can be given when storing items.
type SessionOption func(target *sessionOptions)

type sessionOptions struct {
ttl time.Duration
}

// WithTTL sets the time-to-live for the stored item.
func WithTTL(ttl time.Duration) SessionOption {
return func(target *sessionOptions) {
target.ttl = ttl
}
}
37 changes: 21 additions & 16 deletions storage/mock.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

19 changes: 17 additions & 2 deletions storage/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,16 +85,31 @@ func (s SessionStoreImpl[T]) Get(key string, target interface{}) error {
return json.Unmarshal([]byte(val), target)
}

func (s SessionStoreImpl[T]) Put(key string, value interface{}) error {
func (s SessionStoreImpl[T]) Put(key string, value interface{}, options ...SessionOption) error {
opts := s.defaultOptions()
for _, opt := range options {
opt(&opts)
}
// TTL can't go below 0 because that is translated to "no expiration" by the library
// in that case it should be 1 nanosecond
if opts.ttl < 0 {
opts.ttl = 1
}
bytes, err := json.Marshal(value)
if err != nil {
return err
}
return s.underlying.Set(context.Background(), s.db.getFullKey(s.prefixes, key), T(bytes), store.WithExpiration(s.ttl))
return s.underlying.Set(context.Background(), s.db.getFullKey(s.prefixes, key), T(bytes), store.WithExpiration(opts.ttl))
}
func (s SessionStoreImpl[T]) GetAndDelete(key string, target interface{}) error {
if err := s.Get(key, target); err != nil {
return err
}
return s.underlying.Delete(context.Background(), s.db.getFullKey(s.prefixes, key))
}

func (s SessionStoreImpl[T]) defaultOptions() sessionOptions {
return sessionOptions{
ttl: s.ttl,
}
}
2 changes: 1 addition & 1 deletion storage/session_inmemory.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func (s *InMemorySessionDatabase) GetStore(ttl time.Duration, keys ...string) Se
}
}

func (s *InMemorySessionDatabase) close() {
func (s *InMemorySessionDatabase) Close() {
// NOP
}

Expand Down
65 changes: 65 additions & 0 deletions storage/session_inmemory_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,71 @@ func TestInMemorySessionDatabase_GetStore(t *testing.T) {
assert.Equal(t, []string{"key1", "key2"}, store.prefixes)
}

func TestInMemorySessionStore_Exists(t *testing.T) {
db := createDatabase(t)
store := db.GetStore(time.Minute, "prefix")

t.Run("value exists", func(t *testing.T) {
_ = store.Put(t.Name(), "value")

exists := store.Exists(t.Name())

assert.True(t, exists)
})

t.Run("value does not exist", func(t *testing.T) {
exists := store.Exists(t.Name())

assert.False(t, exists)
})
}

func TestInMemorySessionStore_Put(t *testing.T) {
db := createDatabase(t)
store := db.GetStore(time.Minute, "prefix")

t.Run("string value is stored", func(t *testing.T) {
err := store.Put("key", "value")

require.NoError(t, err)

var val string
err = store.Get("key", &val)
require.NoError(t, err)
assert.Equal(t, "value", val)
})

t.Run("float value is stored", func(t *testing.T) {
err := store.Put("key", 1.23)

require.NoError(t, err)

var val float64
err = store.Get("key", &val)
assert.Equal(t, 1.23, val)
})

t.Run("struct value is stored", func(t *testing.T) {
value := testStruct{
Field1: "value",
}

err := store.Put("key", value)

require.NoError(t, err)

var val testStruct
err = store.Get("key", &val)
assert.Equal(t, value, val)
})

t.Run("value is not JSON", func(t *testing.T) {
err := store.Put("key", make(chan int))

assert.Error(t, err)
})
}

func TestInMemorySessionStore_Get(t *testing.T) {
db := createDatabase(t)
store := db.GetStore(time.Minute, "prefix").(SessionStoreImpl[[]byte])
Expand Down
Loading

0 comments on commit 0e660b0

Please sign in to comment.