diff --git a/usecase/doublylinkedlist.go b/usecase/doublylinkedlist.go new file mode 100644 index 0000000000..169043e517 --- /dev/null +++ b/usecase/doublylinkedlist.go @@ -0,0 +1,49 @@ +package usecase + +type doublyLinkedList[T any] struct { + head *item[T] + tail *item[T] +} + +type item[T any] struct { + prev *item[T] + next *item[T] + value T +} + +func (l *doublyLinkedList[T]) empty() bool { + return l.head == nil +} + +func (l *doublyLinkedList[T]) append(v T) *item[T] { + newItem := &item[T]{ + value: v, + prev: l.tail, + } + if l.head == nil { + // empty list, item becomes first and last item + l.head = newItem + } else { + // non-empty list, item becomes last item + l.tail.next = newItem + } + l.tail = newItem + return newItem +} + +func (l *doublyLinkedList[T]) remove(this *item[T]) { + if this.prev == nil { + // no prev, this is head, next becomes head + l.head = this.next + } else { + // there's items after this + this.prev.next = this.next + } + if this.next == nil { + // this is tail, so prev becomes tail + l.tail = this.prev + } else { + // this isn't tail (more items after this), so prev becomes tail + this.next.prev = this.prev + } +} diff --git a/usecase/doublylinkedlist_test.go b/usecase/doublylinkedlist_test.go new file mode 100644 index 0000000000..5004dae287 --- /dev/null +++ b/usecase/doublylinkedlist_test.go @@ -0,0 +1,82 @@ +package usecase + +import ( + "github.com/stretchr/testify/assert" + "testing" +) + +func Test_doublyLinkedList_append(t *testing.T) { + t.Run("1 entry", func(t *testing.T) { + list := new(doublyLinkedList[int]) + list.append(1) + assert.Equal(t, 1, list.head.value) + assert.Equal(t, 1, list.tail.value) + }) + t.Run("2 entries", func(t *testing.T) { + list := new(doublyLinkedList[int]) + list.append(1) + list.append(2) + assert.Equal(t, 1, list.head.value) + assert.Equal(t, 2, list.tail.value) + assert.Equal(t, list.tail, list.head.next) + assert.Equal(t, list.head, list.tail.prev) + }) +} + +func Test_doublyLinkedList_empty(t *testing.T) { + t.Run("empty", func(t *testing.T) { + list := new(doublyLinkedList[int]) + assert.True(t, list.empty()) + }) + t.Run("non-empty", func(t *testing.T) { + list := new(doublyLinkedList[int]) + list.append(1) + assert.False(t, list.empty()) + }) +} + +func Test_doublyLinkedList_remove(t *testing.T) { + t.Run("remove head", func(t *testing.T) { + list := new(doublyLinkedList[int]) + first := list.append(1) + list.append(2) + list.remove(first) + assert.Same(t, list.tail, list.head) + assert.Equal(t, 2, list.head.value) + assert.Nil(t, list.head.prev) + assert.Nil(t, list.head.next) + }) + t.Run("remove tail", func(t *testing.T) { + list := new(doublyLinkedList[int]) + list.append(1) + last := list.append(2) + list.remove(last) + assert.Same(t, list.tail, list.head) + assert.Equal(t, 1, list.head.value) + assert.Nil(t, list.head.prev) + assert.Nil(t, list.head.next) + }) + t.Run("remove middle", func(t *testing.T) { + list := new(doublyLinkedList[int]) + first := list.append(1) + middle := list.append(2) + last := list.append(3) + list.remove(middle) + assert.Equal(t, 1, list.head.value) + assert.Equal(t, 3, list.tail.value) + assert.Same(t, first, list.head) + assert.Same(t, last, list.tail) + assert.Same(t, last, list.head.next) + assert.Same(t, first, list.tail.prev) + }) + t.Run("empty list after remove of last item", func(t *testing.T) { + list := new(doublyLinkedList[int]) + first := list.append(1) + second := list.append(2) + list.remove(first) + list.remove(second) + assert.Nil(t, list.head) + assert.Nil(t, list.tail) + assert.True(t, list.empty()) + }) +} diff --git a/usecase/maintainer.go b/usecase/maintainer.go index 1aab8ce04c..be9d87fdf3 100644 --- a/usecase/maintainer.go +++ b/usecase/maintainer.go @@ -19,39 +19,41 @@ package usecase import ( + "crypto/md5" "errors" "fmt" "github.com/nuts-foundation/go-did/vc" "github.com/nuts-foundation/nuts-node/usecase/log" "strings" "sync" + "time" ) var _ ListWriter = &maintainer{} var ErrListNotFound = errors.New("list not found") var ErrPresentationAlreadyExists = errors.New("presentation already exists") -// listEntry is a singly-linked list entry, used to store the Verifiable Presentations in order they were added to the list. -type listEntry struct { +// listValue is a doubly-linked list entry value, used to store the Verifiable Presentations in order they were added to the list. +type listValue struct { // presentation is the Verifiable Presentation presentation vc.VerifiablePresentation - // next is the next entry in the list - next *listEntry - timestamp Timestamp + timestamp Timestamp } type list struct { definition Definition name string - // head is the first entry in the list - head *listEntry - // tail is the last entry in the list - tail *listEntry - lock sync.RWMutex + items doublyLinkedList[*listValue] + // index maps a presentation hash to the entry in the list + index map[[16]byte]*item[*listValue] + lock sync.RWMutex } func (l *list) exists(presentation vc.VerifiablePresentation) bool { - return false // TODO + l.lock.RLock() + defer l.lock.RUnlock() + _, exists := l.index[presentationHash(presentation)] + return exists } func (l *list) add(presentation vc.VerifiablePresentation) error { @@ -61,18 +63,17 @@ func (l *list) add(presentation vc.VerifiablePresentation) error { } l.lock.Lock() defer l.lock.Unlock() - newEntry := &listEntry{ + isEmpty := l.items.empty() + newEntry := &listValue{ presentation: presentation, timestamp: 1, } - if l.tail != nil { - newEntry.timestamp = l.tail.timestamp + 1 - l.tail.next = newEntry - } - l.tail = newEntry - if l.head == nil { - l.head = newEntry + addedItem := l.items.append(newEntry) + if !isEmpty { + // list wasn't empty, so we need to increment the timestamp + newEntry.timestamp = addedItem.prev.value.timestamp + 1 } + l.index[presentationHash(presentation)] = addedItem return nil } @@ -82,27 +83,50 @@ func (l *list) get(startAfter Timestamp) ([]vc.VerifiablePresentation, Timestamp result := make([]vc.VerifiablePresentation, 0) timestamp := startAfter - if l.head == nil { - // empty list + if l.items.empty() { return result, timestamp } - current := l.head + current := l.items.head for { if current == nil { // End of list break } - if current.timestamp > startAfter { + if current.value.timestamp > startAfter { // Client wants presentations after the given lamport clock - result = append(result, current.presentation) - timestamp = current.timestamp + result = append(result, current.value.presentation) + timestamp = current.value.timestamp } current = current.next } return result, timestamp } +func (l *list) prune(currentTime time.Time) { + l.lock.Lock() + defer l.lock.Unlock() + current := l.items.head + for { + if current == nil { + // End of list + break + } + token := current.value.presentation.JWT() + // TODO: check revocation status + if !token.Expiration().Before(currentTime) { + // expired, remove + l.items.remove(current) + delete(l.index, presentationHash(current.value.presentation)) + } + current = current.next + } +} + +func presentationHash(presentation vc.VerifiablePresentation) [16]byte { + return md5.Sum([]byte(presentation.Raw())) +} + func createList(definition Definition) (*list, error) { // name is derived from endpoint: it's the last path part of the definition endpoint // It is used to route HTTP GET requests to the correct list. @@ -114,6 +138,7 @@ func createList(definition Definition) (*list, error) { return &list{ definition: definition, name: name, + index: map[[16]byte]*item[*listValue]{}, lock: sync.RWMutex{}, }, nil } @@ -167,3 +192,11 @@ func (m *maintainer) Get(listName string, startAt Timestamp) ([]vc.VerifiablePre result, timestamp := l.(*list).get(startAt) return result, ×tamp, nil } + +func (m *maintainer) pruneLists(currentTime time.Time) { + m.lists.Range(func(_, value any) bool { + currentList := value.(*list) + currentList.prune(currentTime) + return true + }) +} diff --git a/usecase/maintainer_test.go b/usecase/maintainer_test.go index 9f7b4a9dfc..809bc5d9e1 100644 --- a/usecase/maintainer_test.go +++ b/usecase/maintainer_test.go @@ -41,21 +41,19 @@ func init() { } func Test_list_exists(t *testing.T) { - t.Skip("TODO") - t.Run("empty list", func(t *testing.T) { - l, err := createList(Definition{}) + l, err := createList(testDefinition) require.NoError(t, err) assert.False(t, l.exists(jwtVP)) }) t.Run("non-empty list, no match", func(t *testing.T) { - l, err := createList(Definition{}) + l, err := createList(testDefinition) require.NoError(t, err) require.NoError(t, l.add(jwtVP)) assert.False(t, l.exists(vc.VerifiablePresentation{})) }) t.Run("non-empty list, match", func(t *testing.T) { - l, err := createList(Definition{}) + l, err := createList(testDefinition) require.NoError(t, err) require.NoError(t, l.add(jwtVP)) assert.True(t, l.exists(jwtVP)) @@ -112,6 +110,20 @@ func Test_list_get(t *testing.T) { }) } +func Test_list_add(t *testing.T) { + vp1, err := vc.ParseVerifiablePresentation(`{"id": "did:example:issuer#1"}`) + require.NoError(t, err) + + t.Run("already exists", func(t *testing.T) { + l, err := createList(testDefinition) + require.NoError(t, err) + err = l.add(*vp1) + require.NoError(t, err) + err = l.add(*vp1) + assert.Equal(t, ErrPresentationAlreadyExists, err) + }) +} + func Test_maintainer_Add(t *testing.T) { t.Run("ok", func(t *testing.T) { m, err := newMaintainer("", []Definition{testDefinition}) @@ -124,6 +136,15 @@ func Test_maintainer_Add(t *testing.T) { assert.NoError(t, err) assert.Equal(t, Timestamp(1), *timestamp) }) + t.Run("already exists", func(t *testing.T) { + m, err := newMaintainer("", []Definition{testDefinition}) + require.NoError(t, err) + + err = m.Add("usecase", jwtVP) + assert.NoError(t, err) + err = m.Add("usecase", jwtVP) + assert.EqualError(t, err, "presentation already exists") + }) t.Run("list unknown", func(t *testing.T) { m, err := newMaintainer("", []Definition{testDefinition}) require.NoError(t, err) @@ -139,7 +160,7 @@ func Test_maintainer_Get(t *testing.T) { err = m.Add("usecase", jwtVP) assert.NoError(t, err) - vps, timestamp, err := m.Get("foo", 0) + vps, timestamp, err := m.Get("usecase", 0) assert.NoError(t, err) assert.Equal(t, []vc.VerifiablePresentation{jwtVP}, vps) assert.Equal(t, Timestamp(1), *timestamp)