Skip to content

Commit

Permalink
move discovery VP validation to background process to prevent errors (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
woutslakhorst authored Oct 16, 2024
1 parent feb6285 commit 7f2e6dd
Show file tree
Hide file tree
Showing 7 changed files with 550 additions and 106 deletions.
98 changes: 94 additions & 4 deletions discovery/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import (
"github.com/nuts-foundation/nuts-node/vcr/holder"
"github.com/nuts-foundation/nuts-node/vcr/pe"
"github.com/nuts-foundation/nuts-node/vcr/signature/proof"
"github.com/nuts-foundation/nuts-node/vcr/types"
"github.com/nuts-foundation/nuts-node/vdr/didsubject"
"github.com/nuts-foundation/nuts-node/vdr/resolver"
"slices"
Expand All @@ -47,6 +48,10 @@ type clientRegistrationManager interface {
deactivate(ctx context.Context, serviceID, subjectID string) error
// refresh checks which Verifiable Presentations that are about to expire, and should be refreshed on the Discovery Service.
refresh(ctx context.Context, now time.Time) error
// validate validates all presentations that are not yet validated
validate() error
// removeRevoked removes all revoked presentations from the store
removeRevoked() error
}

var _ clientRegistrationManager = &defaultClientRegistrationManager{}
Expand All @@ -58,16 +63,18 @@ type defaultClientRegistrationManager struct {
vcr vcr.VCR
subjectManager didsubject.Manager
didResolver resolver.DIDResolver
verifier presentationVerifier
}

func newRegistrationManager(services map[string]ServiceDefinition, store *sqlStore, client client.HTTPClient, vcr vcr.VCR, subjectManager didsubject.Manager, didResolver resolver.DIDResolver) *defaultClientRegistrationManager {
func newRegistrationManager(services map[string]ServiceDefinition, store *sqlStore, client client.HTTPClient, vcr vcr.VCR, subjectManager didsubject.Manager, didResolver resolver.DIDResolver, verifier presentationVerifier) *defaultClientRegistrationManager {
return &defaultClientRegistrationManager{
services: services,
store: store,
client: client,
vcr: vcr,
subjectManager: subjectManager,
didResolver: didResolver,
verifier: verifier,
}
}

Expand Down Expand Up @@ -330,6 +337,68 @@ func (r *defaultClientRegistrationManager) refresh(ctx context.Context, now time
return nil
}

func (r *defaultClientRegistrationManager) validate() error {
errMsg := "background verification of presentation failed (service: %s, id: %s)"
// find all unvalidated entries in store
presentations, err := r.store.allPresentations(false)
if err != nil {
return err
}
j := 0
for i, presentation := range presentations {
verifiablePresentation, err := vc.ParseVerifiablePresentation(presentation.PresentationRaw)
if err != nil {
log.Logger().WithError(err).Warnf(errMsg, presentation.ServiceID, presentation.ID)
continue
}
service, exists := r.services[presentation.ServiceID]
if !exists {
log.Logger().WithError(err).Warnf("service not found for background validation: %s", presentation.ServiceID)
continue
}
if err = r.verifier(service, *verifiablePresentation); err != nil {
log.Logger().WithError(err).Warnf(errMsg, presentation.ServiceID, presentation.ID)
continue
}
presentations[j] = presentations[i]
j++
}
// update flag in DB
if j > 0 {
return r.store.updateValidated(presentations[:j])
}
return nil
}

func (r *defaultClientRegistrationManager) removeRevoked() error {
errMsg := "background revocation check of presentation failed (id: %s)"
// find all validated entries in store
presentations, err := r.store.allPresentations(true)
if err != nil {
return err
}

for _, presentation := range presentations {
verifiablePresentation, err := vc.ParseVerifiablePresentation(presentation.PresentationRaw)
if err != nil {
log.Logger().WithError(err).Warnf(errMsg, presentation.ID)
continue
}
_, err = r.vcr.Verifier().VerifyVP(*verifiablePresentation, true, true, nil)
if err != nil && !errors.Is(err, types.ErrRevoked) {
log.Logger().WithError(err).Warnf(errMsg, presentation.ID)
continue
}
if errors.Is(err, types.ErrRevoked) {
log.Logger().WithError(err).Infof("removing revoked presentation (id: %s)", presentation.ID)
if err = r.store.deletePresentationRecord(presentation.ID); err != nil {
log.Logger().WithError(err).Warnf("failed to remove revoked presentation from discovery service (id: %s)", presentation.ID)
}
}
}
return nil
}

// clientUpdater is responsible for updating the local copy of Discovery Services
// Callers should only call update().
type clientUpdater struct {
Expand Down Expand Up @@ -377,13 +446,34 @@ func (u *clientUpdater) updateService(ctx context.Context, service ServiceDefini
return fmt.Errorf("failed to wipe on testSeed change (service=%s, testSeed=%s): %w", service.ID, seed, err)
}
for _, presentation := range presentations {
if err := u.verifier(service, presentation); err != nil {
log.Logger().WithError(err).Warnf("Presentation verification failed, not adding it (service=%s, id=%s)", service.ID, presentation.ID)
// Check if the presentation already exists
credentialSubjectID, err := credential.PresentationSigner(presentation)
if err != nil {
return err
}
exists, err := u.store.exists(service.ID, credentialSubjectID.String(), presentation.ID.String())
if err != nil {
return err
}
if exists {
continue
}
if err := u.store.add(service.ID, presentation, seed, serverTimestamp); err != nil {

// always add the presentation, even if it's not valid
// it won't be returned in a search if invalid
// the validator will set the validated flag to true when it's valid
// it'll also remove it from the store if it's invalidated later
if record, err := u.store.add(service.ID, presentation, seed, serverTimestamp); err != nil {
return fmt.Errorf("failed to store presentation (service=%s, id=%s): %w", service.ID, presentation.ID, err)
} else if err = u.verifier(service, presentation); err == nil {
// valid, immediately activate
if err = u.store.updateValidated([]presentationRecord{*record}); err != nil {
return fmt.Errorf("failed to update validated flag (service=%s, id=%s): %w", service.ID, presentation.ID, err)
}
} else {
log.Logger().WithError(err).Infof("failed to verify added presentation (service=%s, id=%s)", service.ID, presentation.ID)
}

log.Logger().
WithField("discoveryService", service.ID).
WithField("presentationID", presentation.ID).
Expand Down
147 changes: 132 additions & 15 deletions discovery/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ import (
"github.com/nuts-foundation/nuts-node/vcr/credential"
"github.com/nuts-foundation/nuts-node/vcr/holder"
"github.com/nuts-foundation/nuts-node/vcr/pe"
"github.com/nuts-foundation/nuts-node/vcr/types"
"github.com/nuts-foundation/nuts-node/vcr/verifier"
"github.com/nuts-foundation/nuts-node/vdr/didsubject"
"github.com/nuts-foundation/nuts-node/vdr/resolver"
"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -64,7 +66,7 @@ func newTestContext(t *testing.T) testContext {
wallet := holder.NewMockWallet(ctrl)
subjectManager := didsubject.NewMockManager(ctrl)
store := setupStore(t, storageEngine.GetSQLDatabase())
manager := newRegistrationManager(testDefinitions(), store, invoker, vcr, subjectManager, didResolver)
manager := newRegistrationManager(testDefinitions(), store, invoker, vcr, subjectManager, didResolver, alwaysOkVerifier)
vcr.EXPECT().Wallet().Return(wallet).AnyTimes()

return testContext{
Expand Down Expand Up @@ -181,7 +183,7 @@ func Test_defaultClientRegistrationManager_activate(t *testing.T) {
return &vpAlice, nil
})
ctx.subjectManager.EXPECT().ListDIDs(gomock.Any(), aliceSubject).Return([]did.DID{aliceDID}, nil)
ctx.manager = newRegistrationManager(emptyDefinition, ctx.store, ctx.invoker, ctx.vcr, ctx.subjectManager, ctx.didResolver)
ctx.manager = newRegistrationManager(emptyDefinition, ctx.store, ctx.invoker, ctx.vcr, ctx.subjectManager, ctx.didResolver, alwaysOkVerifier)

err := ctx.manager.activate(audit.TestContext(), testServiceID, aliceSubject, nil)

Expand Down Expand Up @@ -221,9 +223,10 @@ func Test_defaultClientRegistrationManager_deactivate(t *testing.T) {
ctx.invoker.EXPECT().Register(gomock.Any(), gomock.Any(), gomock.Any())
ctx.wallet.EXPECT().BuildPresentation(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), false).Return(&vpAlice, nil)
ctx.subjectManager.EXPECT().ListDIDs(gomock.Any(), aliceSubject).Return([]did.DID{aliceDID}, nil)
require.NoError(t, ctx.store.add(testServiceID, vpAlice, testSeed, 1))
_, err := ctx.store.add(testServiceID, vpAlice, testSeed, 1)
require.NoError(t, err)

err := ctx.manager.deactivate(audit.TestContext(), testServiceID, aliceSubject)
err = ctx.manager.deactivate(audit.TestContext(), testServiceID, aliceSubject)

assert.NoError(t, err)
})
Expand All @@ -236,9 +239,10 @@ func Test_defaultClientRegistrationManager_deactivate(t *testing.T) {
claims["retract_jti"] = vpAlice.ID.String()
vp.Type = append(vp.Type, retractionPresentationType)
}, vcAlice)
require.NoError(t, ctx.store.add(testServiceID, vpAliceDeactivated, testSeed, 1))
_, err := ctx.store.add(testServiceID, vpAliceDeactivated, testSeed, 1)
require.NoError(t, err)

err := ctx.manager.deactivate(audit.TestContext(), testServiceID, aliceSubject)
err = ctx.manager.deactivate(audit.TestContext(), testServiceID, aliceSubject)

assert.NoError(t, err)
})
Expand All @@ -255,9 +259,10 @@ func Test_defaultClientRegistrationManager_deactivate(t *testing.T) {
ctx.invoker.EXPECT().Register(gomock.Any(), gomock.Any(), gomock.Any()).Return(errors.New("remote error"))
ctx.wallet.EXPECT().BuildPresentation(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), false).Return(&vpAlice, nil)
ctx.subjectManager.EXPECT().ListDIDs(gomock.Any(), aliceSubject).Return([]did.DID{aliceDID}, nil)
require.NoError(t, ctx.store.add(testServiceID, vpAlice, testSeed, 1))
_, err := ctx.store.add(testServiceID, vpAlice, testSeed, 1)
require.NoError(t, err)

err := ctx.manager.deactivate(audit.TestContext(), testServiceID, aliceSubject)
err = ctx.manager.deactivate(audit.TestContext(), testServiceID, aliceSubject)

require.ErrorIs(t, err, ErrPresentationRegistrationFailed)
require.ErrorContains(t, err, "remote error")
Expand All @@ -266,9 +271,10 @@ func Test_defaultClientRegistrationManager_deactivate(t *testing.T) {
ctx := newTestContext(t)
ctx.wallet.EXPECT().BuildPresentation(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), false).Return(nil, assert.AnError)
ctx.subjectManager.EXPECT().ListDIDs(gomock.Any(), aliceSubject).Return([]did.DID{aliceDID}, nil)
require.NoError(t, ctx.store.add(testServiceID, vpAlice, testSeed, 1))
_, err := ctx.store.add(testServiceID, vpAlice, testSeed, 1)
require.NoError(t, err)

err := ctx.manager.deactivate(audit.TestContext(), testServiceID, aliceSubject)
err = ctx.manager.deactivate(audit.TestContext(), testServiceID, aliceSubject)

assert.ErrorIs(t, err, assert.AnError)
})
Expand Down Expand Up @@ -380,6 +386,104 @@ func Test_defaultClientRegistrationManager_refresh(t *testing.T) {
})
}

func Test_defaultClientRegistrationManager_validate(t *testing.T) {
storageEngine := storage.NewTestStorageEngine(t)
require.NoError(t, storageEngine.Start())

tests := []struct {
name string
setupManager func(ctx testContext) *defaultClientRegistrationManager
expectedLen int
}{
{
name: "ok",
setupManager: func(ctx testContext) *defaultClientRegistrationManager {
return ctx.manager
},
expectedLen: 1,
},
{
name: "verification failed",
setupManager: func(ctx testContext) *defaultClientRegistrationManager {
return newRegistrationManager(testDefinitions(), ctx.store, ctx.invoker, ctx.vcr, ctx.subjectManager, ctx.didResolver, func(service ServiceDefinition, vp vc.VerifiablePresentation) error {
return errors.New("verification failed")
})
},
expectedLen: 0,
},
{
name: "registration for unknown service",
setupManager: func(ctx testContext) *defaultClientRegistrationManager {
return newRegistrationManager(map[string]ServiceDefinition{}, ctx.store, ctx.invoker, ctx.vcr, ctx.subjectManager, ctx.didResolver, alwaysOkVerifier)
},
expectedLen: 0,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := newTestContext(t)
_, err := ctx.store.add(testServiceID, vpAlice, testSeed, 1)
require.NoError(t, err)
manager := tt.setupManager(ctx)

err = manager.validate()
require.NoError(t, err)

presentations, err := ctx.store.allPresentations(true)
require.NoError(t, err)
assert.Len(t, presentations, tt.expectedLen)
})
}
}

func Test_defaultClientRegistrationManager_removeRevoked(t *testing.T) {
storageEngine := storage.NewTestStorageEngine(t)
require.NoError(t, storageEngine.Start())

tests := []struct {
name string
verifyVPError error
expectedLen int
}{
{
name: "ok - not revoked",
verifyVPError: nil,
expectedLen: 1,
},
{
name: "ok - revoked",
verifyVPError: types.ErrRevoked,
expectedLen: 0,
},
{
name: "error",
verifyVPError: assert.AnError,
expectedLen: 1,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := newTestContext(t)
_, err := ctx.store.add(testServiceID, vpAlice, testSeed, 1)
require.NoError(t, err)
require.NoError(t, ctx.manager.validate())

mockVerifier := verifier.NewMockVerifier(ctx.ctrl)
ctx.vcr.EXPECT().Verifier().Return(mockVerifier).AnyTimes()
mockVerifier.EXPECT().VerifyVP(gomock.Any(), true, true, nil).Return(nil, tt.verifyVPError)

err = ctx.manager.removeRevoked()
require.NoError(t, err)

presentations, err := ctx.store.allPresentations(true)
require.NoError(t, err)
assert.Len(t, presentations, tt.expectedLen)
})
}
}

func Test_clientUpdater_updateService(t *testing.T) {
storageEngine := storage.NewTestStorageEngine(t)
require.NoError(t, storageEngine.Start())
Expand Down Expand Up @@ -408,11 +512,21 @@ func Test_clientUpdater_updateService(t *testing.T) {

httpClient.EXPECT().Get(ctx, serviceDefinition.Endpoint, 0).Return(map[string]vc.VerifiablePresentation{"1": vpAlice}, testSeed, 1, nil)

err := updater.updateService(ctx, testDefinitions()[testServiceID])
require.NoError(t, updater.updateService(ctx, testDefinitions()[testServiceID]))

require.NoError(t, err)
t.Run("ignores duplicates", func(t *testing.T) {
httpClient.EXPECT().Get(ctx, serviceDefinition.Endpoint, 1).Return(map[string]vc.VerifiablePresentation{"1": vpAlice}, testSeed, 1, nil)

require.NoError(t, updater.updateService(ctx, testDefinitions()[testServiceID]))

// check count
presentation, err := updater.store.allPresentations(true)

require.NoError(t, err)
assert.Len(t, presentation, 1)
})
})
t.Run("ignores invalid presentations", func(t *testing.T) {
t.Run("allows invalid presentations", func(t *testing.T) {
resetStore(t, storageEngine.GetSQLDatabase())
ctrl := gomock.NewController(t)
httpClient := client.NewMockHTTPClient(ctrl)
Expand All @@ -428,13 +542,16 @@ func Test_clientUpdater_updateService(t *testing.T) {
err := updater.updateService(ctx, testDefinitions()[testServiceID])

require.NoError(t, err)
// Bob's VP should exist, Alice's not
// Both should exist, 1 should be validated immediately
exists, err := store.exists(testServiceID, bobDID.String(), vpBob.ID.String())
require.NoError(t, err)
require.True(t, exists)
exists, err = store.exists(testServiceID, aliceDID.String(), vpAlice.ID.String())
require.NoError(t, err)
require.False(t, exists)
require.True(t, exists)
validated, err := store.allPresentations(true)
require.NoError(t, err)
require.Len(t, validated, 1)
})
t.Run("pass timestamp", func(t *testing.T) {
resetStore(t, storageEngine.GetSQLDatabase())
Expand Down
Loading

0 comments on commit 7f2e6dd

Please sign in to comment.