diff --git a/network/dag/consistency_test.go b/network/dag/consistency_test.go
index 09f9614cf9..a13387b9df 100644
--- a/network/dag/consistency_test.go
+++ b/network/dag/consistency_test.go
@@ -35,9 +35,7 @@ import (
)
func TestXorTreeRepair(t *testing.T) {
- t.Cleanup(func() {
- goleak.VerifyNone(t)
- })
+ defer goleak.VerifyNone(t, goleak.IgnoreCurrent())
tx, _, _ := CreateTestTransaction(1)
t.Run("xor tree repaired after 2 signals", func(t *testing.T) {
diff --git a/network/transport/grpc/connection_manager_test.go b/network/transport/grpc/connection_manager_test.go
index acc5080394..2db69baa0b 100644
--- a/network/transport/grpc/connection_manager_test.go
+++ b/network/transport/grpc/connection_manager_test.go
@@ -215,7 +215,7 @@ func Test_grpcConnectionManager_hasActiveConnection(t *testing.T) {
func Test_grpcConnectionManager_dialerLoop(t *testing.T) {
// make sure connectLoop only returns after all of its goroutines are closed
- defer goleak.VerifyNone(t)
+ defer goleak.VerifyNone(t, goleak.IgnoreCurrent())
targetAddress := "bootstrap"
var capturedAddress string
diff --git a/network/transport/v2/gossip/manager_test.go b/network/transport/v2/gossip/manager_test.go
index d92d8a56a9..85c9b6e5d4 100644
--- a/network/transport/v2/gossip/manager_test.go
+++ b/network/transport/v2/gossip/manager_test.go
@@ -97,7 +97,7 @@ func TestManager_PeerDisconnected(t *testing.T) {
t.Run("stops ticker", func(t *testing.T) {
// Use uber/goleak to assert the goroutine started by PeerConnected is stopped when PeerDisconnected is called
- defer goleak.VerifyNone(t)
+ defer goleak.VerifyNone(t, goleak.IgnoreCurrent())
gMan := giveMeAgMan(t)
gMan.interval = time.Millisecond
diff --git a/network/transport/v2/protocol_test.go b/network/transport/v2/protocol_test.go
index 1912e6b273..00a28f43c7 100644
--- a/network/transport/v2/protocol_test.go
+++ b/network/transport/v2/protocol_test.go
@@ -202,7 +202,7 @@ func TestProtocol_Start(t *testing.T) {
func TestProtocol_Stop(t *testing.T) {
t.Run("waits until goroutines have finished", func(t *testing.T) {
- defer goleak.VerifyNone(t)
+ defer goleak.VerifyNone(t, goleak.IgnoreCurrent())
// Use waitgroup to make sure the goroutine that blocks has started
wg := &sync.WaitGroup{}
diff --git a/network/transport/v2/transactionlist_handler_test.go b/network/transport/v2/transactionlist_handler_test.go
index 16142ce84a..0971900cbc 100644
--- a/network/transport/v2/transactionlist_handler_test.go
+++ b/network/transport/v2/transactionlist_handler_test.go
@@ -37,9 +37,7 @@ import (
)
func TestTransactionListHandler(t *testing.T) {
- t.Cleanup(func() {
- goleak.VerifyNone(t)
- })
+ defer goleak.VerifyNone(t, goleak.IgnoreCurrent())
t.Run("fn is called", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
diff --git a/pki/validator_test.go b/pki/validator_test.go
index 8115e426d2..c2581cece7 100644
--- a/pki/validator_test.go
+++ b/pki/validator_test.go
@@ -57,7 +57,7 @@ var crlPathMap = map[string]string{
}
func TestValidator_Start(t *testing.T) {
- defer goleak.VerifyNone(t)
+ defer goleak.VerifyNone(t, goleak.IgnoreCurrent())
store, err := core.LoadTrustStore(truststorePKIo)
require.NoError(t, err)
ctx, cancel := context.WithCancel(context.Background())
diff --git a/storage/engine.go b/storage/engine.go
index 5c8fdd57bb..3a00578de3 100644
--- a/storage/engine.go
+++ b/storage/engine.go
@@ -36,17 +36,19 @@ const storeShutdownTimeout = 5 * time.Second
// New creates a new instance of the storage engine.
func New() Engine {
return &engine{
- storesMux: &sync.Mutex{},
- stores: map[string]stoabs.Store{},
+ storesMux: &sync.Mutex{},
+ stores: map[string]stoabs.Store{},
+ sessionDatabase: NewInMemorySessionDatabase(),
}
}
type engine struct {
- datadir string
- storesMux *sync.Mutex
- stores map[string]stoabs.Store
- databases []database
- config Config
+ datadir string
+ storesMux *sync.Mutex
+ stores map[string]stoabs.Store
+ databases []database
+ sessionDatabase SessionDatabase
+ config Config
}
func (e *engine) Config() interface{} {
@@ -84,9 +86,13 @@ func (e engine) Shutdown() error {
failures = true
}
}
+
if failures {
return errors.New("one or more stores failed to close")
}
+
+ e.sessionDatabase.close()
+
return nil
}
@@ -108,6 +114,7 @@ func (e *engine) Configure(config core.ServerConfig) error {
return fmt.Errorf("unable to configure BBolt database: %w", err)
}
e.databases = append(e.databases, bboltDB)
+
return nil
}
@@ -118,6 +125,10 @@ func (e *engine) GetProvider(moduleName string) Provider {
}
}
+func (e *engine) GetSessionDatabase() SessionDatabase {
+ return e.sessionDatabase
+}
+
type provider struct {
moduleName string
engine *engine
diff --git a/storage/interface.go b/storage/interface.go
index f9b3f9b662..e23888542f 100644
--- a/storage/interface.go
+++ b/storage/interface.go
@@ -19,6 +19,7 @@
package storage
import (
+ "errors"
"github.com/nuts-foundation/go-stoabs"
"github.com/nuts-foundation/nuts-node/core"
"time"
@@ -34,6 +35,8 @@ type Engine interface {
// GetProvider returns the Provider for the given module.
GetProvider(moduleName string) Provider
+ // GetSessionDatabase returns the SessionDatabase
+ GetSessionDatabase() SessionDatabase
}
// Provider lets callers get access to stores.
@@ -59,3 +62,32 @@ type database interface {
getClass() Class
close()
}
+
+var ErrNotFound = errors.New("not found")
+
+// SessionDatabase is a non-persistent database that holds session data on a KV basis.
+// Keys could be access tokens, nonce's, authorization codes, etc.
+// All entries are stored with a TTL, so they will be removed automatically.
+type SessionDatabase interface {
+ // GetStore returns a SessionStore with the given keys as key prefixes.
+ // The keys are used to logically partition the store, eg: tenants and/or flows that are not allowed to overlap like credential issuance and verification.
+ // The TTL is the time-to-live for the entries in the store.
+ GetStore(ttl time.Duration, keys ...string) SessionStore
+ // close stops any background processes and closes the database.
+ close()
+}
+
+// SessionStore is a key-value store that holds session data.
+// The SessionStore is an abstraction for underlying storage, it automatically adds prefixes for logical partitions.
+type SessionStore interface {
+ // Delete deletes the entry for the given key.
+ // It does not return an error if the key does not exist.
+ Delete(key string) error
+ // Exists returns true if the key exists.
+ Exists(key string) bool
+ // Get returns the value for the given key.
+ // Returns ErrNotFound if the key does not exist.
+ Get(key string, target interface{}) error
+ // Put stores the given value for the given key.
+ Put(key string, value interface{}) error
+}
diff --git a/storage/mock.go b/storage/mock.go
index f67a88e420..de364cd577 100644
--- a/storage/mock.go
+++ b/storage/mock.go
@@ -10,6 +10,7 @@ package storage
import (
reflect "reflect"
+ time "time"
stoabs "github.com/nuts-foundation/go-stoabs"
core "github.com/nuts-foundation/nuts-node/core"
@@ -67,6 +68,20 @@ func (mr *MockEngineMockRecorder) GetProvider(moduleName any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProvider", reflect.TypeOf((*MockEngine)(nil).GetProvider), moduleName)
}
+// GetSessionDatabase mocks base method.
+func (m *MockEngine) GetSessionDatabase() SessionDatabase {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "GetSessionDatabase")
+ ret0, _ := ret[0].(SessionDatabase)
+ return ret0
+}
+
+// GetSessionDatabase indicates an expected call of GetSessionDatabase.
+func (mr *MockEngineMockRecorder) GetSessionDatabase() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSessionDatabase", reflect.TypeOf((*MockEngine)(nil).GetSessionDatabase))
+}
+
// Shutdown mocks base method.
func (m *MockEngine) Shutdown() error {
m.ctrl.T.Helper()
@@ -196,3 +211,136 @@ func (mr *MockdatabaseMockRecorder) getClass() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "getClass", reflect.TypeOf((*Mockdatabase)(nil).getClass))
}
+
+// MockSessionDatabase is a mock of SessionDatabase interface.
+type MockSessionDatabase struct {
+ ctrl *gomock.Controller
+ recorder *MockSessionDatabaseMockRecorder
+}
+
+// MockSessionDatabaseMockRecorder is the mock recorder for MockSessionDatabase.
+type MockSessionDatabaseMockRecorder struct {
+ mock *MockSessionDatabase
+}
+
+// NewMockSessionDatabase creates a new mock instance.
+func NewMockSessionDatabase(ctrl *gomock.Controller) *MockSessionDatabase {
+ mock := &MockSessionDatabase{ctrl: ctrl}
+ mock.recorder = &MockSessionDatabaseMockRecorder{mock}
+ return mock
+}
+
+// EXPECT returns an object that allows the caller to indicate expected use.
+func (m *MockSessionDatabase) EXPECT() *MockSessionDatabaseMockRecorder {
+ return m.recorder
+}
+
+// GetStore mocks base method.
+func (m *MockSessionDatabase) GetStore(ttl time.Duration, keys ...string) SessionStore {
+ m.ctrl.T.Helper()
+ varargs := []interface{}{ttl}
+ for _, a := range keys {
+ varargs = append(varargs, a)
+ }
+ ret := m.ctrl.Call(m, "GetStore", varargs...)
+ ret0, _ := ret[0].(SessionStore)
+ return ret0
+}
+
+// GetStore indicates an expected call of GetStore.
+func (mr *MockSessionDatabaseMockRecorder) GetStore(ttl interface{}, keys ...interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ varargs := append([]interface{}{ttl}, keys...)
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStore", reflect.TypeOf((*MockSessionDatabase)(nil).GetStore), varargs...)
+}
+
+// close mocks base method.
+func (m *MockSessionDatabase) close() {
+ m.ctrl.T.Helper()
+ m.ctrl.Call(m, "close")
+}
+
+// close indicates an expected call of close.
+func (mr *MockSessionDatabaseMockRecorder) close() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "close", reflect.TypeOf((*MockSessionDatabase)(nil).close))
+}
+
+// MockSessionStore is a mock of SessionStore interface.
+type MockSessionStore struct {
+ ctrl *gomock.Controller
+ recorder *MockSessionStoreMockRecorder
+}
+
+// MockSessionStoreMockRecorder is the mock recorder for MockSessionStore.
+type MockSessionStoreMockRecorder struct {
+ mock *MockSessionStore
+}
+
+// NewMockSessionStore creates a new mock instance.
+func NewMockSessionStore(ctrl *gomock.Controller) *MockSessionStore {
+ mock := &MockSessionStore{ctrl: ctrl}
+ mock.recorder = &MockSessionStoreMockRecorder{mock}
+ return mock
+}
+
+// EXPECT returns an object that allows the caller to indicate expected use.
+func (m *MockSessionStore) EXPECT() *MockSessionStoreMockRecorder {
+ return m.recorder
+}
+
+// Delete mocks base method.
+func (m *MockSessionStore) Delete(key string) error {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "Delete", key)
+ ret0, _ := ret[0].(error)
+ return ret0
+}
+
+// Delete indicates an expected call of Delete.
+func (mr *MockSessionStoreMockRecorder) Delete(key interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockSessionStore)(nil).Delete), key)
+}
+
+// Exists mocks base method.
+func (m *MockSessionStore) Exists(key string) bool {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "Exists", key)
+ ret0, _ := ret[0].(bool)
+ return ret0
+}
+
+// Exists indicates an expected call of Exists.
+func (mr *MockSessionStoreMockRecorder) Exists(key interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exists", reflect.TypeOf((*MockSessionStore)(nil).Exists), key)
+}
+
+// Get mocks base method.
+func (m *MockSessionStore) Get(key string, target interface{}) error {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "Get", key, target)
+ ret0, _ := ret[0].(error)
+ return ret0
+}
+
+// Get indicates an expected call of Get.
+func (mr *MockSessionStoreMockRecorder) Get(key, target interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockSessionStore)(nil).Get), key, target)
+}
+
+// Put mocks base method.
+func (m *MockSessionStore) Put(key string, value interface{}) error {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "Put", key, value)
+ ret0, _ := ret[0].(error)
+ return ret0
+}
+
+// Put indicates an expected call of Put.
+func (mr *MockSessionStoreMockRecorder) Put(key, value interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Put", reflect.TypeOf((*MockSessionStore)(nil).Put), key, value)
+}
diff --git a/storage/session.go b/storage/session.go
new file mode 100644
index 0000000000..9ec851284b
--- /dev/null
+++ b/storage/session.go
@@ -0,0 +1,169 @@
+/*
+ * Copyright (C) 2023 Nuts community
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see .
+ *
+ */
+
+package storage
+
+import (
+ "encoding/json"
+ "github.com/nuts-foundation/nuts-node/storage/log"
+ "strings"
+ "sync"
+ "time"
+)
+
+var _ SessionDatabase = (*InMemorySessionDatabase)(nil)
+var _ SessionStore = (*InMemorySessionStore)(nil)
+
+var sessionStorePruneInterval = 10 * time.Minute
+
+type expiringEntry struct {
+ // Value stores the actual value as JSON
+ Value string
+ Expiry time.Time
+}
+
+// InMemorySessionDatabase is an in memory database that holds session data on a KV basis.
+// Keys could be access tokens, nonce's, authorization codes, etc.
+// All entries are stored with a TTL, so they will be removed automatically.
+type InMemorySessionDatabase struct {
+ done chan struct{}
+ mux sync.RWMutex
+ routines sync.WaitGroup
+ entries map[string]expiringEntry
+}
+
+// NewInMemorySessionDatabase creates a new in memory session database.
+func NewInMemorySessionDatabase() *InMemorySessionDatabase {
+ result := &InMemorySessionDatabase{
+ entries: map[string]expiringEntry{},
+ done: make(chan struct{}, 10),
+ }
+ result.startPruning(sessionStorePruneInterval)
+ return result
+}
+
+func (i *InMemorySessionDatabase) GetStore(ttl time.Duration, keys ...string) SessionStore {
+ return InMemorySessionStore{
+ ttl: ttl,
+ prefixes: keys,
+ db: i,
+ }
+}
+
+func (i *InMemorySessionDatabase) close() {
+ // Signal pruner to stop and wait for it to finish
+ i.done <- struct{}{}
+}
+
+func (i *InMemorySessionDatabase) startPruning(interval time.Duration) {
+ ticker := time.NewTicker(interval)
+ i.routines.Add(1)
+ go func() {
+ defer i.routines.Done()
+ for {
+ select {
+ case <-i.done:
+ ticker.Stop()
+ return
+ case <-ticker.C:
+ valsPruned := i.prune()
+ if valsPruned > 0 {
+ log.Logger().Debugf("Pruned %d expired session variables", valsPruned)
+ }
+ }
+ }
+ }()
+}
+
+func (i *InMemorySessionDatabase) prune() int {
+ i.mux.Lock()
+ defer i.mux.Unlock()
+
+ moment := time.Now()
+
+ // Find expired flows and delete them
+ var count int
+ for key, entry := range i.entries {
+ if entry.Expiry.Before(moment) {
+ count++
+ delete(i.entries, key)
+ }
+ }
+
+ return count
+}
+
+type InMemorySessionStore struct {
+ ttl time.Duration
+ prefixes []string
+ db *InMemorySessionDatabase
+}
+
+func (i InMemorySessionStore) Delete(key string) error {
+ i.db.mux.Lock()
+ defer i.db.mux.Unlock()
+
+ delete(i.db.entries, i.getFullKey(key))
+ return nil
+}
+
+func (i InMemorySessionStore) Exists(key string) bool {
+ i.db.mux.Lock()
+ defer i.db.mux.Unlock()
+
+ _, ok := i.db.entries[i.getFullKey(key)]
+ return ok
+}
+
+func (i InMemorySessionStore) Get(key string, target interface{}) error {
+ i.db.mux.Lock()
+ defer i.db.mux.Unlock()
+
+ fullKey := i.getFullKey(key)
+ entry, ok := i.db.entries[fullKey]
+ if !ok {
+ return ErrNotFound
+ }
+ if entry.Expiry.Before(time.Now()) {
+ delete(i.db.entries, fullKey)
+ return ErrNotFound
+ }
+
+ return json.Unmarshal([]byte(entry.Value), target)
+}
+
+func (i InMemorySessionStore) Put(key string, value interface{}) error {
+ i.db.mux.Lock()
+ defer i.db.mux.Unlock()
+
+ bytes, err := json.Marshal(value)
+ if err != nil {
+ return err
+ }
+ entry := expiringEntry{
+ Value: string(bytes),
+ Expiry: time.Now().Add(i.ttl),
+ }
+
+ i.db.entries[i.getFullKey(key)] = entry
+ return nil
+}
+
+func (i InMemorySessionStore) getFullKey(key string) string {
+ return strings.Join(append(i.prefixes, key), "/")
+}
diff --git a/storage/session_test.go b/storage/session_test.go
new file mode 100644
index 0000000000..473b63c775
--- /dev/null
+++ b/storage/session_test.go
@@ -0,0 +1,239 @@
+/*
+ * Copyright (C) 2023 Nuts community
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see .
+ *
+ */
+
+package storage
+
+import (
+ "github.com/nuts-foundation/nuts-node/test"
+ "go.uber.org/goleak"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestNewInMemorySessionDatabase(t *testing.T) {
+ db := createDatabase(t)
+
+ assert.NotNil(t, db)
+}
+
+func TestInMemorySessionDatabase_GetStore(t *testing.T) {
+ db := createDatabase(t)
+
+ store := db.GetStore(time.Minute, "key1", "key2").(InMemorySessionStore)
+
+ require.NotNil(t, store)
+ assert.Equal(t, time.Minute, store.ttl)
+ assert.Equal(t, []string{"key1", "key2"}, store.prefixes)
+}
+
+func TestInMemorySessionStore_Put(t *testing.T) {
+ db := createDatabase(t)
+ store := db.GetStore(time.Minute, "prefix").(InMemorySessionStore)
+
+ t.Run("string value is stored", func(t *testing.T) {
+ err := store.Put("key", "value")
+
+ require.NoError(t, err)
+ assert.Equal(t, `"value"`, store.db.entries["prefix/key"].Value)
+ })
+
+ t.Run("float value is stored", func(t *testing.T) {
+ err := store.Put("key", 1.23)
+
+ require.NoError(t, err)
+ assert.Equal(t, "1.23", store.db.entries["prefix/key"].Value)
+ })
+
+ t.Run("struct value is stored", func(t *testing.T) {
+ value := testStruct{
+ Field1: "value",
+ }
+
+ err := store.Put("key", value)
+
+ require.NoError(t, err)
+ assert.Equal(t, "{\"field1\":\"value\"}", store.db.entries["prefix/key"].Value)
+ })
+
+ t.Run("value is not JSON", func(t *testing.T) {
+ err := store.Put("key", make(chan int))
+
+ assert.Error(t, err)
+ })
+}
+
+func TestInMemorySessionStore_Get(t *testing.T) {
+ db := createDatabase(t)
+ store := db.GetStore(time.Minute, "prefix").(InMemorySessionStore)
+
+ t.Run("string value is retrieved correctly", func(t *testing.T) {
+ _ = store.Put(t.Name(), "value")
+ var actual string
+
+ err := store.Get(t.Name(), &actual)
+
+ require.NoError(t, err)
+ assert.Equal(t, "value", actual)
+ })
+
+ t.Run("float value is retrieved correctly", func(t *testing.T) {
+ _ = store.Put(t.Name(), 1.23)
+ var actual float64
+
+ err := store.Get(t.Name(), &actual)
+
+ require.NoError(t, err)
+ assert.Equal(t, 1.23, actual)
+ })
+
+ t.Run("struct value is retrieved correctly", func(t *testing.T) {
+ value := testStruct{
+ Field1: "value",
+ }
+ _ = store.Put(t.Name(), value)
+ var actual testStruct
+
+ err := store.Get(t.Name(), &actual)
+
+ require.NoError(t, err)
+ assert.Equal(t, value, actual)
+ })
+
+ t.Run("value is not found", func(t *testing.T) {
+ var actual string
+
+ err := store.Get(t.Name(), &actual)
+
+ assert.Equal(t, ErrNotFound, err)
+ })
+
+ t.Run("value is expired", func(t *testing.T) {
+ store.db.entries["prefix/key"] = expiringEntry{
+ Value: "",
+ Expiry: time.Now().Add(-time.Minute),
+ }
+ var actual string
+
+ err := store.Get("key", &actual)
+
+ assert.Equal(t, ErrNotFound, err)
+ })
+
+ t.Run("value is not JSON", func(t *testing.T) {
+ store.db.entries["prefix/key"] = expiringEntry{
+ Value: "not JSON",
+ Expiry: time.Now().Add(time.Minute),
+ }
+ var actual string
+
+ err := store.Get("key", &actual)
+
+ assert.Error(t, err)
+ })
+
+ t.Run("value is not a pointer", func(t *testing.T) {
+ _ = store.Put(t.Name(), "value")
+
+ err := store.Get(t.Name(), "not a pointer")
+
+ assert.Error(t, err)
+ })
+}
+
+func TestInMemorySessionStore_Delete(t *testing.T) {
+ db := createDatabase(t)
+ store := db.GetStore(time.Minute, "prefix").(InMemorySessionStore)
+
+ t.Run("value is deleted", func(t *testing.T) {
+ _ = store.Put(t.Name(), "value")
+
+ err := store.Delete(t.Name())
+
+ require.NoError(t, err)
+ _, ok := store.db.entries["prefix/key"]
+ assert.False(t, ok)
+ })
+
+ t.Run("value is not found", func(t *testing.T) {
+ err := store.Delete(t.Name())
+
+ assert.NoError(t, err)
+ })
+}
+
+func TestInMemorySessionDatabase_Close(t *testing.T) {
+ defer goleak.VerifyNone(t, goleak.IgnoreCurrent())
+
+ t.Run("assert Close() waits for pruning to finish to avoid leaking goroutines", func(t *testing.T) {
+ sessionStorePruneInterval = 10 * time.Millisecond
+ defer func() {
+ sessionStorePruneInterval = 10 * time.Minute
+ }()
+ store := NewInMemorySessionDatabase()
+ time.Sleep(50 * time.Millisecond) // make sure pruning is running
+ store.close()
+ })
+}
+
+func Test_memoryStore_prune(t *testing.T) {
+ t.Run("automatic", func(t *testing.T) {
+ store := createDatabase(t)
+ // we call startPruning a second time ourselves to speed things up, make sure not to leak the original goroutine
+ defer func() {
+ store.done <- struct{}{}
+ }()
+ store.startPruning(10 * time.Millisecond)
+
+ err := store.GetStore(time.Millisecond).Put("key", "value")
+ require.NoError(t, err)
+
+ test.WaitFor(t, func() (bool, error) {
+ store.mux.Lock()
+ defer store.mux.Unlock()
+ _, exists := store.entries["key"]
+ return !exists, nil
+ }, time.Second, "time-out waiting for entry to be pruned")
+ })
+ t.Run("prunes expired flows", func(t *testing.T) {
+ store := createDatabase(t)
+ defer store.close()
+
+ _ = store.GetStore(0).Put("key1", "value")
+ _ = store.GetStore(time.Minute).Put("key2", "value")
+
+ count := store.prune()
+
+ assert.Equal(t, 1, count)
+
+ // Second round to assert there's nothing to prune now
+ count = store.prune()
+
+ assert.Equal(t, 0, count)
+ })
+}
+
+type testStruct struct {
+ Field1 string `json:"field1"`
+}
+
+func createDatabase(t *testing.T) *InMemorySessionDatabase {
+ return NewTestInMemorySessionDatabase(t)
+}
diff --git a/storage/test.go b/storage/test.go
index eba2f9a337..d1c6c07116 100644
--- a/storage/test.go
+++ b/storage/test.go
@@ -67,3 +67,11 @@ func (p *StaticKVStoreProvider) GetKVStore(_ string, _ Class) (stoabs.KVStore, e
}
return p.Store, nil
}
+
+func NewTestInMemorySessionDatabase(t *testing.T) *InMemorySessionDatabase {
+ db := NewInMemorySessionDatabase()
+ t.Cleanup(func() {
+ db.close()
+ })
+ return db
+}
diff --git a/vcr/issuer/openid.go b/vcr/issuer/openid.go
index a8ee1ec298..18b3d68752 100644
--- a/vcr/issuer/openid.go
+++ b/vcr/issuer/openid.go
@@ -34,6 +34,7 @@ import (
"github.com/nuts-foundation/nuts-node/audit"
"github.com/nuts-foundation/nuts-node/core"
"github.com/nuts-foundation/nuts-node/crypto"
+ "github.com/nuts-foundation/nuts-node/storage"
"github.com/nuts-foundation/nuts-node/vcr/issuer/assets"
"github.com/nuts-foundation/nuts-node/vcr/log"
"github.com/nuts-foundation/nuts-node/vcr/openid4vci"
@@ -57,14 +58,6 @@ type Flow struct {
// Credentials is the list of Verifiable Credentials that be issued to the wallet through this flow.
// It might be pre-determined (in the issuer-initiated flow) or determined during the flow execution (in the wallet-initiated flow).
Credentials []vc.VerifiableCredential `json:"credentials"`
- Expiry time.Time `json:"exp"`
-}
-
-// Nonce is a nonce that has been issued for an OpenID4VCI flow, to be used by the wallet when requesting credentials.
-// A nonce can only be used once (doh), and is only valid for a certain period of time.
-type Nonce struct {
- Nonce string `json:"nonce"`
- Expiry time.Time `json:"exp"`
}
// Grant is a grant that has been issued for an OAuth2 state.
@@ -75,8 +68,6 @@ type Grant struct {
Params map[string]interface{} `json:"params"`
}
-// ErrUnknownIssuer is returned when the given issuer is unknown.
-var ErrUnknownIssuer = errors.New("unknown OpenID4VCI issuer")
var _ OpenIDHandler = (*openidHandler)(nil)
// TokenTTL is the time-to-live for issuance flows, access tokens and nonces.
@@ -105,7 +96,7 @@ type OpenIDHandler interface {
}
// NewOpenIDHandler creates a new OpenIDHandler instance. The identifier is the Credential Issuer Identifier, e.g. https://example.com/issuer/
-func NewOpenIDHandler(issuerDID did.DID, issuerIdentifierURL string, definitionsDIR string, httpClient core.HTTPRequestDoer, keyResolver resolver.KeyResolver, store OpenIDStore) (OpenIDHandler, error) {
+func NewOpenIDHandler(issuerDID did.DID, issuerIdentifierURL string, definitionsDIR string, httpClient core.HTTPRequestDoer, keyResolver resolver.KeyResolver, sessionDatabase storage.SessionDatabase) (OpenIDHandler, error) {
i := &openidHandler{
issuerIdentifierURL: issuerIdentifierURL,
issuerDID: issuerDID,
@@ -113,7 +104,7 @@ func NewOpenIDHandler(issuerDID did.DID, issuerIdentifierURL string, definitions
httpClient: httpClient,
keyResolver: keyResolver,
walletClientCreator: openid4vci.NewWalletAPIClient,
- store: store,
+ store: NewOpenIDMemoryStore(sessionDatabase),
}
// load the credential definitions. This is done to halt startup procedure if needed.
@@ -174,12 +165,12 @@ func (i *openidHandler) HandleAccessTokenRequest(ctx context.Context, preAuthori
}
}
accessToken := generateCode()
- err = i.store.StoreReference(ctx, flow.ID, accessTokenRefType, accessToken, time.Now().Add(TokenTTL))
+ err = i.store.StoreReference(ctx, flow.ID, accessTokenRefType, accessToken)
if err != nil {
return "", "", err
}
cNonce := generateCode()
- err = i.store.StoreReference(ctx, flow.ID, cNonceRefType, cNonce, time.Now().Add(TokenTTL))
+ err = i.store.StoreReference(ctx, flow.ID, cNonceRefType, cNonce)
if err != nil {
return "", "", err
}
@@ -294,7 +285,7 @@ func (i *openidHandler) validateProof(ctx context.Context, flow *Flow, request o
// augment invalid_proof errors according to ยง7.3.2 of openid4vci spec
generateProofError := func(err openid4vci.Error) error {
cnonce := generateCode()
- if err := i.store.StoreReference(ctx, flow.ID, cNonceRefType, cnonce, time.Now().Add(TokenTTL)); err != nil {
+ if err := i.store.StoreReference(ctx, flow.ID, cNonceRefType, cnonce); err != nil {
return err
}
expiry := int(TokenTTL.Seconds())
@@ -438,7 +429,6 @@ func (i *openidHandler) createOffer(ctx context.Context, credential vc.Verifiabl
ID: uuid.NewString(),
IssuerID: credential.Issuer.String(),
WalletID: subjectDID.String(),
- Expiry: time.Now().Add(TokenTTL),
Credentials: []vc.VerifiableCredential{credential},
Grants: []Grant{
{
@@ -449,7 +439,7 @@ func (i *openidHandler) createOffer(ctx context.Context, credential vc.Verifiabl
}
err := i.store.Store(ctx, flow)
if err == nil {
- err = i.store.StoreReference(ctx, flow.ID, preAuthCodeRefType, preAuthorizedCode, time.Now().Add(TokenTTL))
+ err = i.store.StoreReference(ctx, flow.ID, preAuthCodeRefType, preAuthorizedCode)
}
if err != nil {
return nil, fmt.Errorf("unable to store credential offer: %w", err)
diff --git a/vcr/issuer/openid_store.go b/vcr/issuer/openid_store.go
index 689556cae8..0471301164 100644
--- a/vcr/issuer/openid_store.go
+++ b/vcr/issuer/openid_store.go
@@ -21,9 +21,7 @@ package issuer
import (
"context"
"errors"
- "github.com/nuts-foundation/nuts-node/vcr/log"
- "sync"
- "time"
+ "github.com/nuts-foundation/nuts-node/storage"
)
// OpenIDStore defines the storage API for OpenID Credential Issuance flows.
@@ -35,164 +33,71 @@ type OpenIDStore interface {
// like a database index. The reference must be unique for all flows.
// The expiry is the time-to-live for the reference. After this time, the reference is automatically deleted.
// If the flow does not exist, or the reference does already exist, it returns an error.
- StoreReference(ctx context.Context, flowID string, refType string, reference string, expiry time.Time) error
+ StoreReference(ctx context.Context, flowID string, refType string, reference string) error
// FindByReference finds a Flow by its reference.
// If the flow does not exist, it returns nil.
FindByReference(ctx context.Context, refType string, reference string) (*Flow, error)
// DeleteReference deletes the reference from the store.
// It does not return an error if it doesn't exist anymore.
DeleteReference(ctx context.Context, refType string, reference string) error
- // Close signals the store to close any owned resources.
- Close()
}
var _ OpenIDStore = (*openidMemoryStore)(nil)
-var openidStorePruneInterval = 10 * time.Minute
-
type openidMemoryStore struct {
- mux *sync.RWMutex
- flows map[string]Flow
- refs map[string]map[string]referenceValue
- routines *sync.WaitGroup
- ctx context.Context
- cancel context.CancelFunc
+ sessionDatabase storage.SessionDatabase
}
// NewOpenIDMemoryStore creates a new in-memory OpenIDStore.
-func NewOpenIDMemoryStore() OpenIDStore {
- result := &openidMemoryStore{
- mux: &sync.RWMutex{},
- flows: map[string]Flow{},
- refs: map[string]map[string]referenceValue{},
- routines: &sync.WaitGroup{},
+func NewOpenIDMemoryStore(sessionDatabase storage.SessionDatabase) OpenIDStore {
+ return &openidMemoryStore{
+ sessionDatabase: sessionDatabase,
}
- result.ctx, result.cancel = context.WithCancel(context.Background())
- result.startPruning(openidStorePruneInterval)
- return result
-}
-
-type referenceValue struct {
- FlowID string `json:"flow_id"`
- Expiry time.Time `json:"exp"`
}
func (o *openidMemoryStore) Store(_ context.Context, flow Flow) error {
if len(flow.ID) == 0 {
return errors.New("invalid flow ID")
}
- o.mux.Lock()
- defer o.mux.Unlock()
- if o.flows[flow.ID].ID != "" {
+ store := o.sessionDatabase.GetStore(TokenTTL, "openid4vci", "flow")
+ if store.Exists(flow.ID) {
return errors.New("OAuth2 flow with this ID already exists")
}
- o.flows[flow.ID] = flow
- return nil
+ return store.Put(flow.ID, flow)
}
-func (o *openidMemoryStore) StoreReference(_ context.Context, flowID string, refType string, reference string, expiry time.Time) error {
+func (o *openidMemoryStore) StoreReference(_ context.Context, flowID string, refType string, reference string) error {
if len(reference) == 0 {
return errors.New("invalid reference")
}
- o.mux.Lock()
- defer o.mux.Unlock()
- if o.flows[flowID].ID == "" {
+ refStore := o.sessionDatabase.GetStore(TokenTTL, "openid4vci", refType)
+ flowStore := o.sessionDatabase.GetStore(TokenTTL, "openid4vci", "flow")
+ if !flowStore.Exists(flowID) {
return errors.New("OAuth2 flow with this ID does not exist")
}
- if o.refs[refType] == nil {
- o.refs[refType] = map[string]referenceValue{}
- }
- if _, ok := o.refs[refType][reference]; ok {
+ if refStore.Exists(reference) {
return errors.New("reference already exists")
}
- o.refs[refType][reference] = referenceValue{FlowID: flowID, Expiry: expiry}
- return nil
+ return refStore.Put(reference, flowID)
}
func (o *openidMemoryStore) FindByReference(_ context.Context, refType string, reference string) (*Flow, error) {
- o.mux.RLock()
- defer o.mux.RUnlock()
-
- refMap := o.refs[refType]
- if refMap == nil {
- return nil, nil
- }
- value, ok := refMap[reference]
- if !ok {
- return nil, nil
- }
- if value.Expiry.Before(time.Now()) {
+ refStore := o.sessionDatabase.GetStore(TokenTTL, "openid4vci", refType)
+ flowStore := o.sessionDatabase.GetStore(TokenTTL, "openid4vci", "flow")
+ if !refStore.Exists(reference) {
return nil, nil
}
-
- flow := o.flows[value.FlowID]
- if flow.Expiry.Before(time.Now()) {
- return nil, nil
+ var flowID string
+ err := refStore.Get(reference, &flowID)
+ if err != nil {
+ return nil, err
}
- return &flow, nil
+ var flow Flow
+ err = flowStore.Get(flowID, &flow)
+ return &flow, err
}
func (o *openidMemoryStore) DeleteReference(_ context.Context, refType string, reference string) error {
- o.mux.Lock()
- defer o.mux.Unlock()
-
- if o.refs[refType] == nil {
- return nil
- }
- delete(o.refs[refType], reference)
- return nil
-}
-
-func (o *openidMemoryStore) Close() {
- // Signal pruner to stop and wait for it to finish
- o.cancel()
- o.routines.Wait()
-}
-
-func (o *openidMemoryStore) startPruning(interval time.Duration) {
- ticker := time.NewTicker(interval)
- o.routines.Add(1)
- go func(ctx context.Context) {
- defer o.routines.Done()
- for {
- select {
- case <-ctx.Done():
- ticker.Stop()
- return
- case <-ticker.C:
- flowsPruned, refsPruned := o.prune()
- if flowsPruned > 0 || refsPruned > 0 {
- log.Logger().Debugf("Pruned %d expired OpenID4VCI flows and %d expired refs", flowsPruned, refsPruned)
- }
- }
- }
- }(o.ctx)
-}
-
-func (o *openidMemoryStore) prune() (int, int) {
- o.mux.Lock()
- defer o.mux.Unlock()
-
- moment := time.Now()
-
- // Find expired flows and delete them
- var flowCount int
- for id, flow := range o.flows {
- if flow.Expiry.Before(moment) {
- flowCount++
- delete(o.flows, id)
- }
- }
- // Find expired refs and delete them
- var refCount int
- for _, refMap := range o.refs {
- for reference, value := range refMap {
- if value.Expiry.Before(moment) {
- refCount++
- delete(refMap, reference)
- }
- }
- }
-
- return flowCount, refCount
+ refStore := o.sessionDatabase.GetStore(TokenTTL, "openid4vci", refType)
+ return refStore.Delete(reference)
}
diff --git a/vcr/issuer/openid_store_test.go b/vcr/issuer/openid_store_test.go
index a779e949ce..9fcbf80109 100644
--- a/vcr/issuer/openid_store_test.go
+++ b/vcr/issuer/openid_store_test.go
@@ -20,11 +20,9 @@ package issuer
import (
"context"
- "github.com/nuts-foundation/nuts-node/test"
+ "github.com/nuts-foundation/nuts-node/storage"
"github.com/stretchr/testify/assert"
- "github.com/stretchr/testify/require"
"testing"
- "time"
)
const refType = "ref-type"
@@ -34,12 +32,11 @@ func Test_memoryStore_DeleteReference(t *testing.T) {
t.Run("ok", func(t *testing.T) {
store := createStore(t)
expected := Flow{
- ID: "flow-id",
- Expiry: futureExpiry(),
+ ID: "flow-id",
}
err := store.Store(context.Background(), expected)
assert.NoError(t, err)
- err = store.StoreReference(context.Background(), expected.ID, refType, ref, futureExpiry())
+ err = store.StoreReference(context.Background(), expected.ID, refType, ref)
assert.NoError(t, err)
err = store.DeleteReference(context.Background(), refType, ref)
@@ -63,49 +60,31 @@ func Test_memoryStore_FindByReference(t *testing.T) {
t.Run("reference already exists", func(t *testing.T) {
store := createStore(t)
expected := Flow{
- ID: "flow-id",
- Expiry: futureExpiry(),
+ ID: "flow-id",
}
err := store.Store(context.Background(), expected)
assert.NoError(t, err)
- err = store.StoreReference(context.Background(), expected.ID, refType, ref, futureExpiry())
+ err = store.StoreReference(context.Background(), expected.ID, refType, ref)
assert.NoError(t, err)
- err = store.StoreReference(context.Background(), expected.ID, refType, ref, futureExpiry())
+ err = store.StoreReference(context.Background(), expected.ID, refType, ref)
assert.EqualError(t, err, "reference already exists")
})
t.Run("invalid reference", func(t *testing.T) {
store := createStore(t)
- err := store.StoreReference(context.Background(), "unknown", refType, "", futureExpiry())
+ err := store.StoreReference(context.Background(), "unknown", refType, "")
assert.EqualError(t, err, "invalid reference")
})
t.Run("unknown flow", func(t *testing.T) {
store := createStore(t)
- err := store.StoreReference(context.Background(), "unknown", refType, ref, futureExpiry())
+ err := store.StoreReference(context.Background(), "unknown", refType, ref)
assert.EqualError(t, err, "OAuth2 flow with this ID does not exist")
})
- t.Run("reference has expired", func(t *testing.T) {
- store := createStore(t)
- expected := Flow{
- ID: "flow-id",
- Expiry: futureExpiry(),
- }
-
- err := store.Store(context.Background(), expected)
- assert.NoError(t, err)
- // We need a reference to resolve it
- err = store.StoreReference(context.Background(), expected.ID, refType, ref, pastExpiry())
- assert.NoError(t, err)
-
- actual, err := store.FindByReference(context.Background(), refType, ref)
- assert.NoError(t, err)
- assert.Nil(t, actual)
- })
}
func Test_memoryStore_Store(t *testing.T) {
@@ -113,14 +92,13 @@ func Test_memoryStore_Store(t *testing.T) {
t.Run("write, then read", func(t *testing.T) {
store := createStore(t)
expected := Flow{
- ID: "flow-id",
- Expiry: futureExpiry(),
+ ID: "flow-id",
}
err := store.Store(ctx, expected)
assert.NoError(t, err)
// We need a reference to resolve it
- err = store.StoreReference(ctx, expected.ID, refType, ref, futureExpiry())
+ err = store.StoreReference(ctx, expected.ID, refType, ref)
assert.NoError(t, err)
actual, err := store.FindByReference(ctx, refType, ref)
@@ -130,8 +108,7 @@ func Test_memoryStore_Store(t *testing.T) {
t.Run("already exists", func(t *testing.T) {
store := createStore(t)
expected := Flow{
- ID: "flow-id",
- Expiry: futureExpiry(),
+ ID: "flow-id",
}
err := store.Store(ctx, expected)
@@ -140,124 +117,10 @@ func Test_memoryStore_Store(t *testing.T) {
assert.EqualError(t, err, "OAuth2 flow with this ID already exists")
})
- t.Run("flow has expired", func(t *testing.T) {
- store := createStore(t)
- expected := Flow{
- ID: "flow-id",
- Expiry: pastExpiry(),
- }
-
- err := store.Store(ctx, expected)
- assert.NoError(t, err)
- // We need a reference to resolve it
- err = store.StoreReference(ctx, expected.ID, refType, ref, futureExpiry())
- assert.NoError(t, err)
-
- actual, err := store.FindByReference(ctx, refType, ref)
- assert.NoError(t, err)
- assert.Nil(t, actual)
- })
-}
-
-func Test_memoryStore_Close(t *testing.T) {
- t.Run("assert Close() waits for pruning to finish to avoid leaking goroutines", func(t *testing.T) {
- openidStorePruneInterval = 10 * time.Millisecond
- store := createStore(t)
- time.Sleep(50 * time.Millisecond) // make sure pruning is running
- store.Close()
- })
-}
-
-func Test_memoryStore_prune(t *testing.T) {
- ctx := context.Background()
- t.Run("automatic", func(t *testing.T) {
- store := createStore(t)
- // we call startPruning a second time ourselves, make sure not to leak the original goroutine
- cancelFunc := store.cancel
- defer cancelFunc()
- store.startPruning(10 * time.Millisecond)
-
- // Feed it something to prune
- expiredFlow := Flow{
- ID: "expired",
- }
- err := store.Store(ctx, expiredFlow)
- require.NoError(t, err)
-
- test.WaitFor(t, func() (bool, error) {
- store.mux.Lock()
- defer store.mux.Unlock()
- _, exists := store.flows[expiredFlow.ID]
- return !exists, nil
- }, time.Second, "time-out waiting for flow to be pruned")
- })
- t.Run("prunes expired flows", func(t *testing.T) {
- store := createStore(t)
-
- expiredFlow := Flow{
- ID: "expired",
- }
- unexpiredFlow := Flow{
- ID: "unexpired",
- Expiry: futureExpiry(),
- }
- _ = store.Store(ctx, expiredFlow)
- _ = store.Store(ctx, unexpiredFlow)
-
- flows, refs := store.prune()
-
- assert.Equal(t, 1, flows)
- assert.Equal(t, 0, refs)
-
- // Second round to assert there's nothing to prune now
- flows, refs = store.prune()
-
- assert.Equal(t, 0, flows)
- assert.Equal(t, 0, refs)
- })
- t.Run("prunes expired refs", func(t *testing.T) {
- store := createStore(t)
-
- flow := Flow{
- ID: "f",
- Expiry: futureExpiry(),
- }
- err := store.Store(ctx, flow)
- require.NoError(t, err)
- err = store.StoreReference(ctx, flow.ID, refType, "expired", pastExpiry())
- require.NoError(t, err)
- err = store.StoreReference(ctx, flow.ID, refType, "unexpired", futureExpiry())
- require.NoError(t, err)
-
- flows, refs := store.prune()
-
- assert.Equal(t, 0, flows)
- assert.Equal(t, 1, refs)
-
- // Second round to assert there's nothing to prune now
- flows, refs = store.prune()
-
- assert.NoError(t, err)
- assert.Equal(t, 0, flows)
- assert.Equal(t, 0, refs)
- })
}
func createStore(t *testing.T) *openidMemoryStore {
- store := NewOpenIDMemoryStore().(*openidMemoryStore)
- t.Cleanup(store.Close)
+ storageDatabase := storage.NewTestInMemorySessionDatabase(t)
+ store := NewOpenIDMemoryStore(storageDatabase).(*openidMemoryStore)
return store
}
-
-func moment() time.Time {
- return time.Now().In(time.UTC)
-}
-
-func pastExpiry() time.Time {
- return moment().Add(-time.Hour)
-}
-
-func futureExpiry() time.Time {
- // truncating makes assertion easier
- return moment().Add(time.Hour).Truncate(time.Second)
-}
diff --git a/vcr/issuer/openid_test.go b/vcr/issuer/openid_test.go
index c19199c7e5..281eeaaa0f 100644
--- a/vcr/issuer/openid_test.go
+++ b/vcr/issuer/openid_test.go
@@ -28,6 +28,7 @@ import (
"github.com/nuts-foundation/nuts-node/audit"
"github.com/nuts-foundation/nuts-node/core"
"github.com/nuts-foundation/nuts-node/crypto"
+ "github.com/nuts-foundation/nuts-node/storage"
"github.com/nuts-foundation/nuts-node/vcr/openid4vci"
"github.com/nuts-foundation/nuts-node/vdr/resolver"
"github.com/stretchr/testify/assert"
@@ -65,21 +66,21 @@ var issuedVC = vc.VerifiableCredential{
func TestNew(t *testing.T) {
t.Run("custom definitions", func(t *testing.T) {
- iss, err := NewOpenIDHandler(issuerDID, issuerIdentifier, "./test/valid", nil, nil, NewOpenIDMemoryStore())
+ iss, err := NewOpenIDHandler(issuerDID, issuerIdentifier, "./test/valid", nil, nil, storage.NewTestInMemorySessionDatabase(t))
require.NoError(t, err)
assert.Len(t, iss.(*openidHandler).credentialsSupported, 3)
})
t.Run("error - invalid json", func(t *testing.T) {
- _, err := NewOpenIDHandler(issuerDID, issuerIdentifier, "./test/invalid", nil, nil, NewOpenIDMemoryStore())
+ _, err := NewOpenIDHandler(issuerDID, issuerIdentifier, "./test/invalid", nil, nil, storage.NewTestInMemorySessionDatabase(t))
require.Error(t, err)
assert.EqualError(t, err, "failed to parse credential definition from test/invalid/invalid.json: unexpected end of JSON input")
})
t.Run("error - invalid directory", func(t *testing.T) {
- _, err := NewOpenIDHandler(issuerDID, issuerIdentifier, "./test/non_existing", nil, nil, NewOpenIDMemoryStore())
+ _, err := NewOpenIDHandler(issuerDID, issuerIdentifier, "./test/non_existing", nil, nil, storage.NewTestInMemorySessionDatabase(t))
require.Error(t, err)
assert.EqualError(t, err, "failed to load credential definitions: lstat ./test/non_existing: no such file or directory")
@@ -396,7 +397,7 @@ func Test_memoryIssuer_HandleAccessTokenRequest(t *testing.T) {
assert.NotEmpty(t, accessToken)
})
t.Run("pre-authorized code issued by other issuer", func(t *testing.T) {
- store := NewOpenIDMemoryStore()
+ store := storage.NewTestInMemorySessionDatabase(t)
service, err := NewOpenIDHandler(issuerDID, issuerIdentifier, definitionsDIR, &http.Client{}, nil, store)
require.NoError(t, err)
_, err = service.(*openidHandler).createOffer(ctx, issuedVC, "code")
@@ -435,7 +436,7 @@ func assertProtocolError(t *testing.T, err error, statusCode int, message string
}
func requireNewTestHandler(t *testing.T, keyResolver resolver.KeyResolver) *openidHandler {
- service, err := NewOpenIDHandler(issuerDID, issuerIdentifier, definitionsDIR, &http.Client{}, keyResolver, NewOpenIDMemoryStore())
+ service, err := NewOpenIDHandler(issuerDID, issuerIdentifier, definitionsDIR, &http.Client{}, keyResolver, storage.NewTestInMemorySessionDatabase(t))
require.NoError(t, err)
return service.(*openidHandler)
}
diff --git a/vcr/vcr.go b/vcr/vcr.go
index 81fe82d4b5..ea44364c24 100644
--- a/vcr/vcr.go
+++ b/vcr/vcr.go
@@ -99,7 +99,7 @@ type vcr struct {
jsonldManager jsonld.JSONLD
eventManager events.Event
storageClient storage.Engine
- openidIsssuerStore issuer.OpenIDStore
+ openidSessionStore storage.SessionDatabase
localWalletResolver openid4vci.IdentifierResolver
issuerHttpClient core.HTTPRequestDoer
walletHttpClient core.HTTPRequestDoer
@@ -112,7 +112,7 @@ func (c *vcr) GetOpenIDIssuer(ctx context.Context, id did.DID) (issuer.OpenIDHan
if err != nil {
return nil, err
}
- return issuer.NewOpenIDHandler(id, identifier, c.config.OpenID4VCI.DefinitionsDIR, c.issuerHttpClient, c.keyResolver, c.openidIsssuerStore)
+ return issuer.NewOpenIDHandler(id, identifier, c.config.OpenID4VCI.DefinitionsDIR, c.issuerHttpClient, c.keyResolver, c.openidSessionStore)
}
func (c *vcr) GetOpenIDHolder(ctx context.Context, id did.DID) (holder.OpenIDHandler, error) {
@@ -269,7 +269,7 @@ func (c *vcr) Configure(config core.ServerConfig) error {
Timeout: c.config.OpenID4VCI.Timeout,
Transport: walletTransport,
})
- c.openidIsssuerStore = issuer.NewOpenIDMemoryStore()
+ c.openidSessionStore = c.storageClient.GetSessionDatabase()
}
c.issuer = issuer.NewIssuer(c.issuerStore, c, networkPublisher, openidHandlerFn, didResolver, c.keyStore, c.jsonldManager, c.trustConfig)
c.verifier = verifier.NewVerifier(c.verifierStore, didResolver, c.keyResolver, c.jsonldManager, c.trustConfig)
@@ -329,9 +329,6 @@ func (c *vcr) Start() error {
}
func (c *vcr) Shutdown() error {
- if c.openidIsssuerStore != nil {
- c.openidIsssuerStore.Close()
- }
err := c.issuerStore.Close()
if err != nil {
log.Logger().