diff --git a/discovery/client.go b/discovery/client.go index e22d4b351..08fc8a4d6 100644 --- a/discovery/client.go +++ b/discovery/client.go @@ -33,6 +33,7 @@ import ( "github.com/nuts-foundation/nuts-node/vcr/holder" "github.com/nuts-foundation/nuts-node/vcr/pe" "github.com/nuts-foundation/nuts-node/vcr/signature/proof" + "github.com/nuts-foundation/nuts-node/vcr/types" "github.com/nuts-foundation/nuts-node/vdr/didsubject" "github.com/nuts-foundation/nuts-node/vdr/resolver" "slices" @@ -47,6 +48,10 @@ type clientRegistrationManager interface { deactivate(ctx context.Context, serviceID, subjectID string) error // refresh checks which Verifiable Presentations that are about to expire, and should be refreshed on the Discovery Service. refresh(ctx context.Context, now time.Time) error + // validate validates all presentations that are not yet validated + validate() error + // removeRevoked removes all revoked presentations from the store + removeRevoked() error } var _ clientRegistrationManager = &defaultClientRegistrationManager{} @@ -58,9 +63,10 @@ type defaultClientRegistrationManager struct { vcr vcr.VCR subjectManager didsubject.Manager didResolver resolver.DIDResolver + verifier presentationVerifier } -func newRegistrationManager(services map[string]ServiceDefinition, store *sqlStore, client client.HTTPClient, vcr vcr.VCR, subjectManager didsubject.Manager, didResolver resolver.DIDResolver) *defaultClientRegistrationManager { +func newRegistrationManager(services map[string]ServiceDefinition, store *sqlStore, client client.HTTPClient, vcr vcr.VCR, subjectManager didsubject.Manager, didResolver resolver.DIDResolver, verifier presentationVerifier) *defaultClientRegistrationManager { return &defaultClientRegistrationManager{ services: services, store: store, @@ -68,6 +74,7 @@ func newRegistrationManager(services map[string]ServiceDefinition, store *sqlSto vcr: vcr, subjectManager: subjectManager, didResolver: didResolver, + verifier: verifier, } } @@ -330,6 +337,68 @@ func (r *defaultClientRegistrationManager) refresh(ctx context.Context, now time return nil } +func (r *defaultClientRegistrationManager) validate() error { + errMsg := "background verification of presentation failed (service: %s, id: %s)" + // find all unvalidated entries in store + presentations, err := r.store.allPresentations(false) + if err != nil { + return err + } + j := 0 + for i, presentation := range presentations { + verifiablePresentation, err := vc.ParseVerifiablePresentation(presentation.PresentationRaw) + if err != nil { + log.Logger().WithError(err).Warnf(errMsg, presentation.ServiceID, presentation.ID) + continue + } + service, exists := r.services[presentation.ServiceID] + if !exists { + log.Logger().WithError(err).Warnf("service not found for background validation: %s", presentation.ServiceID) + continue + } + if err = r.verifier(service, *verifiablePresentation); err != nil { + log.Logger().WithError(err).Warnf(errMsg, presentation.ServiceID, presentation.ID) + continue + } + presentations[j] = presentations[i] + j++ + } + // update flag in DB + if j > 0 { + return r.store.updateValidated(presentations[:j]) + } + return nil +} + +func (r *defaultClientRegistrationManager) removeRevoked() error { + errMsg := "background revocation check of presentation failed (id: %s)" + // find all validated entries in store + presentations, err := r.store.allPresentations(true) + if err != nil { + return err + } + + for _, presentation := range presentations { + verifiablePresentation, err := vc.ParseVerifiablePresentation(presentation.PresentationRaw) + if err != nil { + log.Logger().WithError(err).Warnf(errMsg, presentation.ID) + continue + } + _, err = r.vcr.Verifier().VerifyVP(*verifiablePresentation, true, true, nil) + if err != nil && !errors.Is(err, types.ErrRevoked) { + log.Logger().WithError(err).Warnf(errMsg, presentation.ID) + continue + } + if errors.Is(err, types.ErrRevoked) { + log.Logger().WithError(err).Infof("removing revoked presentation (id: %s)", presentation.ID) + if err = r.store.deletePresentationRecord(presentation.ID); err != nil { + log.Logger().WithError(err).Warnf("failed to remove revoked presentation from discovery service (id: %s)", presentation.ID) + } + } + } + return nil +} + // clientUpdater is responsible for updating the local copy of Discovery Services // Callers should only call update(). type clientUpdater struct { @@ -377,13 +446,34 @@ func (u *clientUpdater) updateService(ctx context.Context, service ServiceDefini return fmt.Errorf("failed to wipe on testSeed change (service=%s, testSeed=%s): %w", service.ID, seed, err) } for _, presentation := range presentations { - if err := u.verifier(service, presentation); err != nil { - log.Logger().WithError(err).Warnf("Presentation verification failed, not adding it (service=%s, id=%s)", service.ID, presentation.ID) + // Check if the presentation already exists + credentialSubjectID, err := credential.PresentationSigner(presentation) + if err != nil { + return err + } + exists, err := u.store.exists(service.ID, credentialSubjectID.String(), presentation.ID.String()) + if err != nil { + return err + } + if exists { continue } - if err := u.store.add(service.ID, presentation, seed, serverTimestamp); err != nil { + + // always add the presentation, even if it's not valid + // it won't be returned in a search if invalid + // the validator will set the validated flag to true when it's valid + // it'll also remove it from the store if it's invalidated later + if record, err := u.store.add(service.ID, presentation, seed, serverTimestamp); err != nil { return fmt.Errorf("failed to store presentation (service=%s, id=%s): %w", service.ID, presentation.ID, err) + } else if err = u.verifier(service, presentation); err == nil { + // valid, immediately activate + if err = u.store.updateValidated([]presentationRecord{*record}); err != nil { + return fmt.Errorf("failed to update validated flag (service=%s, id=%s): %w", service.ID, presentation.ID, err) + } + } else { + log.Logger().WithError(err).Infof("failed to verify added presentation (service=%s, id=%s)", service.ID, presentation.ID) } + log.Logger(). WithField("discoveryService", service.ID). WithField("presentationID", presentation.ID). diff --git a/discovery/client_test.go b/discovery/client_test.go index 0ab14a8c4..b3209a43a 100644 --- a/discovery/client_test.go +++ b/discovery/client_test.go @@ -31,6 +31,8 @@ import ( "github.com/nuts-foundation/nuts-node/vcr/credential" "github.com/nuts-foundation/nuts-node/vcr/holder" "github.com/nuts-foundation/nuts-node/vcr/pe" + "github.com/nuts-foundation/nuts-node/vcr/types" + "github.com/nuts-foundation/nuts-node/vcr/verifier" "github.com/nuts-foundation/nuts-node/vdr/didsubject" "github.com/nuts-foundation/nuts-node/vdr/resolver" "github.com/stretchr/testify/assert" @@ -64,7 +66,7 @@ func newTestContext(t *testing.T) testContext { wallet := holder.NewMockWallet(ctrl) subjectManager := didsubject.NewMockManager(ctrl) store := setupStore(t, storageEngine.GetSQLDatabase()) - manager := newRegistrationManager(testDefinitions(), store, invoker, vcr, subjectManager, didResolver) + manager := newRegistrationManager(testDefinitions(), store, invoker, vcr, subjectManager, didResolver, alwaysOkVerifier) vcr.EXPECT().Wallet().Return(wallet).AnyTimes() return testContext{ @@ -181,7 +183,7 @@ func Test_defaultClientRegistrationManager_activate(t *testing.T) { return &vpAlice, nil }) ctx.subjectManager.EXPECT().ListDIDs(gomock.Any(), aliceSubject).Return([]did.DID{aliceDID}, nil) - ctx.manager = newRegistrationManager(emptyDefinition, ctx.store, ctx.invoker, ctx.vcr, ctx.subjectManager, ctx.didResolver) + ctx.manager = newRegistrationManager(emptyDefinition, ctx.store, ctx.invoker, ctx.vcr, ctx.subjectManager, ctx.didResolver, alwaysOkVerifier) err := ctx.manager.activate(audit.TestContext(), testServiceID, aliceSubject, nil) @@ -221,9 +223,10 @@ func Test_defaultClientRegistrationManager_deactivate(t *testing.T) { ctx.invoker.EXPECT().Register(gomock.Any(), gomock.Any(), gomock.Any()) ctx.wallet.EXPECT().BuildPresentation(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), false).Return(&vpAlice, nil) ctx.subjectManager.EXPECT().ListDIDs(gomock.Any(), aliceSubject).Return([]did.DID{aliceDID}, nil) - require.NoError(t, ctx.store.add(testServiceID, vpAlice, testSeed, 1)) + _, err := ctx.store.add(testServiceID, vpAlice, testSeed, 1) + require.NoError(t, err) - err := ctx.manager.deactivate(audit.TestContext(), testServiceID, aliceSubject) + err = ctx.manager.deactivate(audit.TestContext(), testServiceID, aliceSubject) assert.NoError(t, err) }) @@ -236,9 +239,10 @@ func Test_defaultClientRegistrationManager_deactivate(t *testing.T) { claims["retract_jti"] = vpAlice.ID.String() vp.Type = append(vp.Type, retractionPresentationType) }, vcAlice) - require.NoError(t, ctx.store.add(testServiceID, vpAliceDeactivated, testSeed, 1)) + _, err := ctx.store.add(testServiceID, vpAliceDeactivated, testSeed, 1) + require.NoError(t, err) - err := ctx.manager.deactivate(audit.TestContext(), testServiceID, aliceSubject) + err = ctx.manager.deactivate(audit.TestContext(), testServiceID, aliceSubject) assert.NoError(t, err) }) @@ -255,9 +259,10 @@ func Test_defaultClientRegistrationManager_deactivate(t *testing.T) { ctx.invoker.EXPECT().Register(gomock.Any(), gomock.Any(), gomock.Any()).Return(errors.New("remote error")) ctx.wallet.EXPECT().BuildPresentation(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), false).Return(&vpAlice, nil) ctx.subjectManager.EXPECT().ListDIDs(gomock.Any(), aliceSubject).Return([]did.DID{aliceDID}, nil) - require.NoError(t, ctx.store.add(testServiceID, vpAlice, testSeed, 1)) + _, err := ctx.store.add(testServiceID, vpAlice, testSeed, 1) + require.NoError(t, err) - err := ctx.manager.deactivate(audit.TestContext(), testServiceID, aliceSubject) + err = ctx.manager.deactivate(audit.TestContext(), testServiceID, aliceSubject) require.ErrorIs(t, err, ErrPresentationRegistrationFailed) require.ErrorContains(t, err, "remote error") @@ -266,9 +271,10 @@ func Test_defaultClientRegistrationManager_deactivate(t *testing.T) { ctx := newTestContext(t) ctx.wallet.EXPECT().BuildPresentation(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), false).Return(nil, assert.AnError) ctx.subjectManager.EXPECT().ListDIDs(gomock.Any(), aliceSubject).Return([]did.DID{aliceDID}, nil) - require.NoError(t, ctx.store.add(testServiceID, vpAlice, testSeed, 1)) + _, err := ctx.store.add(testServiceID, vpAlice, testSeed, 1) + require.NoError(t, err) - err := ctx.manager.deactivate(audit.TestContext(), testServiceID, aliceSubject) + err = ctx.manager.deactivate(audit.TestContext(), testServiceID, aliceSubject) assert.ErrorIs(t, err, assert.AnError) }) @@ -380,6 +386,104 @@ func Test_defaultClientRegistrationManager_refresh(t *testing.T) { }) } +func Test_defaultClientRegistrationManager_validate(t *testing.T) { + storageEngine := storage.NewTestStorageEngine(t) + require.NoError(t, storageEngine.Start()) + + tests := []struct { + name string + setupManager func(ctx testContext) *defaultClientRegistrationManager + expectedLen int + }{ + { + name: "ok", + setupManager: func(ctx testContext) *defaultClientRegistrationManager { + return ctx.manager + }, + expectedLen: 1, + }, + { + name: "verification failed", + setupManager: func(ctx testContext) *defaultClientRegistrationManager { + return newRegistrationManager(testDefinitions(), ctx.store, ctx.invoker, ctx.vcr, ctx.subjectManager, ctx.didResolver, func(service ServiceDefinition, vp vc.VerifiablePresentation) error { + return errors.New("verification failed") + }) + }, + expectedLen: 0, + }, + { + name: "registration for unknown service", + setupManager: func(ctx testContext) *defaultClientRegistrationManager { + return newRegistrationManager(map[string]ServiceDefinition{}, ctx.store, ctx.invoker, ctx.vcr, ctx.subjectManager, ctx.didResolver, alwaysOkVerifier) + }, + expectedLen: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := newTestContext(t) + _, err := ctx.store.add(testServiceID, vpAlice, testSeed, 1) + require.NoError(t, err) + manager := tt.setupManager(ctx) + + err = manager.validate() + require.NoError(t, err) + + presentations, err := ctx.store.allPresentations(true) + require.NoError(t, err) + assert.Len(t, presentations, tt.expectedLen) + }) + } +} + +func Test_defaultClientRegistrationManager_removeRevoked(t *testing.T) { + storageEngine := storage.NewTestStorageEngine(t) + require.NoError(t, storageEngine.Start()) + + tests := []struct { + name string + verifyVPError error + expectedLen int + }{ + { + name: "ok - not revoked", + verifyVPError: nil, + expectedLen: 1, + }, + { + name: "ok - revoked", + verifyVPError: types.ErrRevoked, + expectedLen: 0, + }, + { + name: "error", + verifyVPError: assert.AnError, + expectedLen: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := newTestContext(t) + _, err := ctx.store.add(testServiceID, vpAlice, testSeed, 1) + require.NoError(t, err) + require.NoError(t, ctx.manager.validate()) + + mockVerifier := verifier.NewMockVerifier(ctx.ctrl) + ctx.vcr.EXPECT().Verifier().Return(mockVerifier).AnyTimes() + mockVerifier.EXPECT().VerifyVP(gomock.Any(), true, true, nil).Return(nil, tt.verifyVPError) + + err = ctx.manager.removeRevoked() + require.NoError(t, err) + + presentations, err := ctx.store.allPresentations(true) + require.NoError(t, err) + assert.Len(t, presentations, tt.expectedLen) + }) + } +} + func Test_clientUpdater_updateService(t *testing.T) { storageEngine := storage.NewTestStorageEngine(t) require.NoError(t, storageEngine.Start()) @@ -408,11 +512,21 @@ func Test_clientUpdater_updateService(t *testing.T) { httpClient.EXPECT().Get(ctx, serviceDefinition.Endpoint, 0).Return(map[string]vc.VerifiablePresentation{"1": vpAlice}, testSeed, 1, nil) - err := updater.updateService(ctx, testDefinitions()[testServiceID]) + require.NoError(t, updater.updateService(ctx, testDefinitions()[testServiceID])) - require.NoError(t, err) + t.Run("ignores duplicates", func(t *testing.T) { + httpClient.EXPECT().Get(ctx, serviceDefinition.Endpoint, 1).Return(map[string]vc.VerifiablePresentation{"1": vpAlice}, testSeed, 1, nil) + + require.NoError(t, updater.updateService(ctx, testDefinitions()[testServiceID])) + + // check count + presentation, err := updater.store.allPresentations(true) + + require.NoError(t, err) + assert.Len(t, presentation, 1) + }) }) - t.Run("ignores invalid presentations", func(t *testing.T) { + t.Run("allows invalid presentations", func(t *testing.T) { resetStore(t, storageEngine.GetSQLDatabase()) ctrl := gomock.NewController(t) httpClient := client.NewMockHTTPClient(ctrl) @@ -428,13 +542,16 @@ func Test_clientUpdater_updateService(t *testing.T) { err := updater.updateService(ctx, testDefinitions()[testServiceID]) require.NoError(t, err) - // Bob's VP should exist, Alice's not + // Both should exist, 1 should be validated immediately exists, err := store.exists(testServiceID, bobDID.String(), vpBob.ID.String()) require.NoError(t, err) require.True(t, exists) exists, err = store.exists(testServiceID, aliceDID.String(), vpAlice.ID.String()) require.NoError(t, err) - require.False(t, exists) + require.True(t, exists) + validated, err := store.allPresentations(true) + require.NoError(t, err) + require.Len(t, validated, 1) }) t.Run("pass timestamp", func(t *testing.T) { resetStore(t, storageEngine.GetSQLDatabase()) diff --git a/discovery/module.go b/discovery/module.go index a34f45fed..a231d47ab 100644 --- a/discovery/module.go +++ b/discovery/module.go @@ -152,7 +152,7 @@ func (m *Module) Start() error { return err } m.clientUpdater = newClientUpdater(m.allDefinitions, m.store, m.verifyRegistration, m.httpClient) - m.registrationManager = newRegistrationManager(m.allDefinitions, m.store, m.httpClient, m.vcrInstance, m.subjectManager, m.didResolver) + m.registrationManager = newRegistrationManager(m.allDefinitions, m.store, m.httpClient, m.vcrInstance, m.subjectManager, m.didResolver, m.verifyRegistration) if m.config.Client.RefreshInterval > 0 { m.routines.Add(1) go func() { @@ -203,7 +203,28 @@ func (m *Module) Register(context context.Context, serviceID string, presentatio return err } - return m.store.add(serviceID, presentation, "", 0) + // Check if the presentation already exists + credentialSubjectID, err := credential.PresentationSigner(presentation) + if err != nil { + return err + } + exists, err := m.store.exists(definition.ID, credentialSubjectID.String(), presentation.ID.String()) + if err != nil { + return err + } + if exists { + return errors.Join(ErrInvalidPresentation, ErrPresentationAlreadyExists) + } + record, err := m.store.add(serviceID, presentation, "", 0) + if err != nil { + return err + } + // also update validated flag since validation is already done + if err = m.store.updateValidated([]presentationRecord{*record}); err != nil { + log.Logger().WithError(err).Errorf("failed to update validated flag for presentation (id: %s)", record.ID) + } + + return nil } func (m *Module) verifyRegistration(definition ServiceDefinition, presentation vc.VerifiablePresentation) error { @@ -235,15 +256,7 @@ func (m *Module) verifyRegistration(definition ServiceDefinition, presentation v return errors.Join(ErrInvalidPresentation, ErrDIDMethodsNotSupported) } - // Check if the presentation already exists - exists, err := m.store.exists(definition.ID, credentialSubjectID.String(), presentation.ID.String()) - if err != nil { - return err - } - if exists { - return errors.Join(ErrInvalidPresentation, ErrPresentationAlreadyExists) - } - // Depending on the presentation type, we need to validate different properties before storing it. + // Depending on the presentation type, we need to updateValidated different properties before storing it. if presentation.IsType(retractionPresentationType) { err = m.validateRetraction(definition.ID, presentation) } else { @@ -484,7 +497,7 @@ func (m *Module) Search(serviceID string, query map[string]string) ([]SearchResu if !exists { return nil, ErrServiceNotFound } - matchingVPs, err := m.store.search(serviceID, query) + matchingVPs, err := m.store.search(serviceID, query, false) if err != nil { return nil, err } @@ -557,6 +570,16 @@ func (m *Module) update() { if err != nil { log.Logger().WithError(err).Errorf("Failed to load latest Verifiable Presentations from Discovery Service") } + // updateValidated all presentations not yet validated + err = m.registrationManager.validate() + if err != nil { + log.Logger().WithError(err).Errorf("Failed to validate presentations") + } + // purge list + err = m.registrationManager.removeRevoked() + if err != nil { + log.Logger().WithError(err).Errorf("Failed to remove revoked presentations") + } } do() for { diff --git a/discovery/module_test.go b/discovery/module_test.go index 23f5f436e..f0569a9fc 100644 --- a/discovery/module_test.go +++ b/discovery/module_test.go @@ -41,6 +41,8 @@ import ( "go.uber.org/mock/gomock" "gorm.io/gorm" "os" + "sync" + "sync/atomic" "testing" "time" ) @@ -61,8 +63,10 @@ func Test_Module_Register(t *testing.T) { t.Run("registration", func(t *testing.T) { t.Run("ok", func(t *testing.T) { - m, testContext := setupModule(t, storageEngine) - testContext.verifier.EXPECT().VerifyVP(gomock.Any(), true, true, nil) + m, testContext := setupModule(t, storageEngine, func(module *Module) { + module.config.Client.RefreshInterval = 0 + }) + testContext.verifier.EXPECT().VerifyVP(gomock.Any(), true, true, nil).Times(2) err := m.Register(ctx, testServiceID, vpAlice) require.NoError(t, err) @@ -262,8 +266,11 @@ func Test_Module_Get(t *testing.T) { require.NoError(t, storageEngine.Start()) ctx := context.Background() t.Run("ok", func(t *testing.T) { - m, _ := setupModule(t, storageEngine) - require.NoError(t, m.store.add(testServiceID, vpAlice, testSeed, 0)) + m, _ := setupModule(t, storageEngine, func(module *Module) { + module.config.Client.RefreshInterval = 0 + }) + _, err := m.store.add(testServiceID, vpAlice, testSeed, 1) + require.NoError(t, err) presentations, seed, timestamp, err := m.Get(ctx, testServiceID, 0) assert.NoError(t, err) assert.Equal(t, map[string]vc.VerifiablePresentation{"1": vpAlice}, presentations) @@ -441,9 +448,13 @@ func TestModule_Search(t *testing.T) { storageEngine := storage.NewTestStorageEngine(t) require.NoError(t, storageEngine.Start()) t.Run("ok", func(t *testing.T) { - m, _ := setupModule(t, storageEngine) - - require.NoError(t, m.store.add(testServiceID, vpAlice, testSeed, 0)) + m, ctx := setupModule(t, storageEngine, func(module *Module) { + module.config.Client.RefreshInterval = 0 + }) + ctx.verifier.EXPECT().VerifyVP(gomock.Any(), true, true, nil) + _, err := m.store.add(testServiceID, vpAlice, testSeed, 1) + require.NoError(t, err) + require.NoError(t, m.registrationManager.validate()) results, err := m.Search(testServiceID, map[string]string{ "credentialSubject.person.givenName": "Alice", @@ -471,29 +482,78 @@ func TestModule_Search(t *testing.T) { func TestModule_update(t *testing.T) { storageEngine := storage.NewTestStorageEngine(t) require.NoError(t, storageEngine.Start()) - t.Run("Start() initiates update", func(t *testing.T) { - _, _ = setupModule(t, storageEngine, func(module *Module) { - // we want to assert the job runs, so make it run very often to make the test faster - module.config.Client.RefreshInterval = 1 * time.Millisecond - // overwrite httpClient mock for custom behavior assertions (we want to know how often HttpClient.Get() was called) - httpClient := client.NewMockHTTPClient(gomock.NewController(t)) - // Get() should be called at least twice (times the number of Service Definitions), once for the initial run on startup, then again after the refresh interval - httpClient.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, "", 0, nil).MinTimes(2 * len(module.allDefinitions)) - module.httpClient = httpClient - }) - time.Sleep(10 * time.Millisecond) - }) - t.Run("update() runs on node startup", func(t *testing.T) { - _, _ = setupModule(t, storageEngine, func(module *Module) { - // we want to assert the job immediately executes on node startup, even if the refresh interval hasn't passed - module.config.Client.RefreshInterval = time.Hour - // overwrite httpClient mock for custom behavior assertions (we want to know how often HttpClient.Get() was called) - httpClient := client.NewMockHTTPClient(gomock.NewController(t)) - // update causes call to HttpClient.Get(), once for each Service Definition - httpClient.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, "", 0, nil).Times(len(module.allDefinitions)) - module.httpClient = httpClient + + tests := []struct { + name string + refreshInterval time.Duration + expectedHTTPCalls int + expectedVerifyVPCalls int + }{ + { + name: "Start() initiates update", + refreshInterval: time.Millisecond, + expectedHTTPCalls: 2, + expectedVerifyVPCalls: 4, + }, + { + name: "update() runs on node startup", + refreshInterval: time.Hour, + expectedHTTPCalls: 1, + expectedVerifyVPCalls: 2, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + resetStore(t, storageEngine.GetSQLDatabase()) + ctrl := gomock.NewController(t) + mockVerifier := verifier.NewMockVerifier(ctrl) + mockVCR := vcr.NewMockVCR(ctrl) + mockVCR.EXPECT().Verifier().Return(mockVerifier).AnyTimes() + m := New(storageEngine, mockVCR, nil, nil) + m.config = DefaultConfig() + m.publicURL = test.MustParseURL("https://example.com") + m.config.Client.RefreshInterval = tt.refreshInterval + require.NoError(t, m.Configure(core.TestServerConfig())) + m.allDefinitions = testDefinitions() + httpClient := client.NewMockHTTPClient(ctrl) + httpWg := sync.WaitGroup{} + httpWg.Add(tt.expectedHTTPCalls * len(m.allDefinitions)) + httpCounter := atomic.Int64{} + httpCounter.Add(int64(tt.expectedHTTPCalls * len(m.allDefinitions))) + httpClient.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _, _ interface{}) (map[string]vc.VerifiablePresentation, string, int, error) { + if httpCounter.Load() != int64(0) { + httpWg.Done() + httpCounter.Add(int64(-1)) + } + return nil, testSeed, 0, nil + }).MinTimes(tt.expectedHTTPCalls * len(m.allDefinitions)) + m.httpClient = httpClient + m.store, _ = newSQLStore(m.storageInstance.GetSQLDatabase(), m.allDefinitions) + vpWg := sync.WaitGroup{} + vpWg.Add(tt.expectedVerifyVPCalls) + vpCounter := atomic.Int64{} + vpCounter.Add(int64(tt.expectedVerifyVPCalls)) + mockVerifier.EXPECT().VerifyVP(gomock.Any(), true, true, nil).DoAndReturn(func(_, _, _, _ interface{}) ([]vc.VerifiableCredential, error) { + if vpCounter.Load() != int64(0) { + vpWg.Done() + vpCounter.Add(int64(-1)) + } + return nil, nil + }).MinTimes(tt.expectedVerifyVPCalls) + _, err := m.store.add(testServiceID, vpAlice, testSeed, 1) + require.NoError(t, err) + + require.NoError(t, m.Start()) + + vpWg.Wait() + httpWg.Wait() + + t.Cleanup(func() { + _ = m.Shutdown() + }) }) - }) + } } func TestModule_ActivateServiceForSubject(t *testing.T) { @@ -627,11 +687,14 @@ func TestModule_GetServiceActivation(t *testing.T) { assert.Nil(t, presentation) }) t.Run("activated, with VP", func(t *testing.T) { - m, testContext := setupModule(t, storageEngine) + m, testContext := setupModule(t, storageEngine, func(module *Module) { + module.config.Client.RefreshInterval = 0 + }) testContext.subjectManager.EXPECT().ListDIDs(gomock.Any(), aliceSubject).Return([]did.DID{aliceDID}, nil).AnyTimes() next := time.Now() _ = m.store.updatePresentationRefreshTime(testServiceID, aliceSubject, nil, &next) - _ = m.store.add(testServiceID, vpAlice, testSeed, 0) + _, err := m.store.add(testServiceID, vpAlice, testSeed, 1) + require.NoError(t, err) activated, presentation, err := m.GetServiceActivation(context.Background(), testServiceID, aliceSubject) diff --git a/discovery/store.go b/discovery/store.go index 9859d0b67..9f7ae7e79 100644 --- a/discovery/store.go +++ b/discovery/store.go @@ -19,6 +19,7 @@ package discovery import ( + "database/sql/driver" "encoding/json" "errors" "fmt" @@ -48,6 +49,8 @@ func (s serviceRecord) TableName() string { var _ schema.Tabler = (*presentationRecord)(nil) +type SQLBool bool + type presentationRecord struct { ID string `gorm:"primaryKey"` ServiceID string @@ -56,6 +59,7 @@ type presentationRecord struct { PresentationID string PresentationRaw string PresentationExpiration int64 + Validated SQLBool Credentials []credentialRecord `gorm:"foreignKey:PresentationID;references:ID"` } @@ -63,6 +67,30 @@ func (s presentationRecord) TableName() string { return "discovery_presentation" } +func (b *SQLBool) Scan(value interface{}) error { + *b = false + if value != nil { + switch v := value.(type) { + case int64: + if v != 0 { + *b = true + } + } + } + return nil +} + +func (b SQLBool) Value() (driver.Value, error) { + if b { + return int64(1), nil + } + return int64(0), nil +} + +func (b SQLBool) Bool() bool { + return bool(b) +} + // credentialRecord is a Verifiable Credential, part of a presentation (entry) on a use case list. type credentialRecord struct { // ID is the unique identifier of the entry. @@ -136,15 +164,16 @@ func newSQLStore(db *gorm.DB, clientDefinitions map[string]ServiceDefinition) (* // add adds a presentation to the list of presentations. // If the given timestamp is 0, the server will assign a timestamp. -func (s *sqlStore) add(serviceID string, presentation vc.VerifiablePresentation, seed string, timestamp int) error { +func (s *sqlStore) add(serviceID string, presentation vc.VerifiablePresentation, seed string, timestamp int) (*presentationRecord, error) { credentialSubjectID, err := credential.PresentationSigner(presentation) if err != nil { - return err + return nil, err } if err := s.prune(); err != nil { - return err + return nil, err } - return s.db.Transaction(func(tx *gorm.DB) error { + var newPresentation *presentationRecord + return newPresentation, s.db.Transaction(func(tx *gorm.DB) error { if timestamp == 0 { var newTs *int if len(seed) == 0 { // default for server @@ -167,15 +196,16 @@ func (s *sqlStore) add(serviceID string, presentation vc.VerifiablePresentation, return err } - return storePresentation(tx, serviceID, timestamp, presentation) + newPresentation, err = storePresentation(tx, serviceID, timestamp, presentation) + return err }) } // storePresentation creates a presentationRecord from a VerifiablePresentation and stores it, with its credentials, in the database. -func storePresentation(tx *gorm.DB, serviceID string, timestamp int, presentation vc.VerifiablePresentation) error { +func storePresentation(tx *gorm.DB, serviceID string, timestamp int, presentation vc.VerifiablePresentation) (*presentationRecord, error) { credentialSubjectID, err := credential.PresentationSigner(presentation) if err != nil { - return err + return nil, err } newPresentation := presentationRecord{ @@ -192,7 +222,7 @@ func storePresentation(tx *gorm.DB, serviceID string, timestamp int, presentatio for _, verifiableCredential := range presentation.VerifiableCredential { cred, err := credentialStore.Store(tx, verifiableCredential) if err != nil { - return err + return nil, err } newPresentation.Credentials = append(newPresentation.Credentials, credentialRecord{ ID: uuid.NewString(), @@ -201,7 +231,8 @@ func storePresentation(tx *gorm.DB, serviceID string, timestamp int, presentatio }) } - return tx.Create(&newPresentation).Error + err = tx.Create(&newPresentation).Error + return &newPresentation, err } // get returns all presentations, registered on the given service, starting after the given timestamp. @@ -232,11 +263,14 @@ func (s *sqlStore) get(serviceID string, startAfter int) (map[string]vc.Verifiab // The query is a map of JSON paths and expected string values, matched against the presentation's credentials. // Wildcard matching is supported by prefixing or suffixing the value with an asterisk (*). // It returns the presentations which contain credentials that match the given query. -func (s *sqlStore) search(serviceID string, query map[string]string) ([]vc.VerifiablePresentation, error) { +func (s *sqlStore) search(serviceID string, query map[string]string, allowUnvalidated bool) ([]vc.VerifiablePresentation, error) { // first only select columns also used in group by clause // if the query is empty, there's no need to do a join stmt := s.db.Model(&presentationRecord{}). Where("service_id = ?", serviceID) + if !allowUnvalidated { + stmt = stmt.Where("validated != 0") + } if len(query) > 0 { stmt = stmt.Joins("inner join discovery_credential ON discovery_credential.presentation_id = discovery_presentation.id") stmt = store.CredentialStore{}.BuildSearchStatement(stmt, "discovery_credential.credential_id", query) @@ -344,6 +378,41 @@ func (s *sqlStore) removeExpired() (int, error) { return int(result.RowsAffected), nil } +// allPresentations returns all presentations, the validated param can be used to select validated or unvalidated presentations +func (s *sqlStore) allPresentations(validated bool) ([]presentationRecord, error) { + result := make([]presentationRecord, 0) + stmt := s.db + if validated { + stmt = stmt.Where("validated != 0") + } else { + stmt = stmt.Where("validated = 0") + } + err := stmt.Find(&result).Error + if err != nil { + return nil, err + } + return result, nil +} + +// updateValidated sets the validated flag for the given presentations +func (s *sqlStore) updateValidated(records []presentationRecord) error { + return s.db.Transaction(func(tx *gorm.DB) error { + for _, record := range records { + if err := tx.Model(&presentationRecord{}).Where("id = ?", record.ID).Update("validated", true).Error; err != nil { + return err + } + } + return nil + }) +} + +// deletePresentationRecord removes a presentationRecord from the store based on its ID +func (s *sqlStore) deletePresentationRecord(id string) error { + return s.db.Transaction(func(tx *gorm.DB) error { + return tx.Delete(&presentationRecord{}, "id = ?", id).Error + }) +} + // updatePresentationRefreshTime creates/updates the next refresh time for a Verifiable Presentation on a Discovery Service. // If nextRegistration is nil, the entry will be removed from the database. func (s *sqlStore) updatePresentationRefreshTime(serviceID string, subjectID string, parameters map[string]interface{}, nextRefresh *time.Time) error { @@ -466,7 +535,7 @@ func (s *sqlStore) getSubjectVPsOnService(serviceID string, subjectDIDs []did.DI for _, subjectDID := range subjectDIDs { loopVPs, err := s.search(serviceID, map[string]string{ "credentialSubject.id": subjectDID.String(), - }) + }, true) if err != nil { return nil, err } diff --git a/discovery/store_test.go b/discovery/store_test.go index 812bba212..9daa0c9f6 100644 --- a/discovery/store_test.go +++ b/discovery/store_test.go @@ -45,21 +45,24 @@ func Test_sqlStore_exists(t *testing.T) { }) t.Run("non-empty list, no match (other subject and ID)", func(t *testing.T) { m := setupStore(t, storageEngine.GetSQLDatabase()) - require.NoError(t, m.add(testServiceID, vpBob, testSeed, 0)) + _, err := m.add(testServiceID, vpBob, testSeed, 0) + require.NoError(t, err) exists, err := m.exists(testServiceID, aliceDID.String(), vpAlice.ID.String()) assert.NoError(t, err) assert.False(t, exists) }) t.Run("non-empty list, no match (other list)", func(t *testing.T) { m := setupStore(t, storageEngine.GetSQLDatabase()) - require.NoError(t, m.add(testServiceID, vpAlice, testSeed, 0)) + _, err := m.add(testServiceID, vpAlice, testSeed, 0) + require.NoError(t, err) exists, err := m.exists("other", aliceDID.String(), vpAlice.ID.String()) assert.NoError(t, err) assert.False(t, exists) }) t.Run("non-empty list, match", func(t *testing.T) { m := setupStore(t, storageEngine.GetSQLDatabase()) - require.NoError(t, m.add(testServiceID, vpAlice, testSeed, 0)) + _, err := m.add(testServiceID, vpAlice, testSeed, 0) + require.NoError(t, err) exists, err := m.exists(testServiceID, aliceDID.String(), vpAlice.ID.String()) assert.NoError(t, err) assert.True(t, exists) @@ -72,14 +75,15 @@ func Test_sqlStore_add(t *testing.T) { t.Run("no credentials in presentation", func(t *testing.T) { m := setupStore(t, storageEngine.GetSQLDatabase()) - err := m.add(testServiceID, createPresentation(aliceDID), testSeed, 0) + _, err := m.add(testServiceID, createPresentation(aliceDID), testSeed, 0) assert.NoError(t, err) }) t.Run("seed", func(t *testing.T) { t.Run("passing seed updates last_seed", func(t *testing.T) { m := setupStore(t, storageEngine.GetSQLDatabase()) - require.NoError(t, m.add(testServiceID, createPresentation(aliceDID), testSeed, 0)) + _, err := m.add(testServiceID, createPresentation(aliceDID), testSeed, 0) + require.NoError(t, err) _, seed, _, err := m.get(testServiceID, 0) @@ -88,7 +92,8 @@ func Test_sqlStore_add(t *testing.T) { }) t.Run("generated seed", func(t *testing.T) { m := setupStore(t, storageEngine.GetSQLDatabase()) - require.NoError(t, m.add(testServiceID, createPresentation(aliceDID), "", 0)) + _, err := m.add(testServiceID, createPresentation(aliceDID), "", 0) + require.NoError(t, err) _, seed, _, err := m.get(testServiceID, 0) @@ -99,7 +104,7 @@ func Test_sqlStore_add(t *testing.T) { t.Run("passing timestamp updates last_timestamp", func(t *testing.T) { m := setupStore(t, storageEngine.GetSQLDatabase()) - err := m.add(testServiceID, createPresentation(aliceDID), testSeed, 1) + _, err := m.add(testServiceID, createPresentation(aliceDID), testSeed, 1) require.NoError(t, err) timestamp, err := m.getTimestamp(testServiceID) @@ -112,8 +117,10 @@ func Test_sqlStore_add(t *testing.T) { m := setupStore(t, storageEngine.GetSQLDatabase()) secondVP := createPresentation(aliceDID, vcAlice) - require.NoError(t, m.add(testServiceID, vpAlice, testSeed, 0)) - require.NoError(t, m.add(testServiceID, secondVP, testSeed, 0)) + _, err := m.add(testServiceID, vpAlice, testSeed, 0) + require.NoError(t, err) + _, err = m.add(testServiceID, secondVP, testSeed, 0) + require.NoError(t, err) // First VP should not exist exists, err := m.exists(testServiceID, aliceDID.String(), vpAlice.ID.String()) @@ -141,7 +148,8 @@ func Test_sqlStore_get(t *testing.T) { }) t.Run("1 entry, 0 timestamp", func(t *testing.T) { m := setupStore(t, storageEngine.GetSQLDatabase()) - require.NoError(t, m.add(testServiceID, vpAlice, testSeed, 0)) + _, err := m.add(testServiceID, vpAlice, testSeed, 0) + require.NoError(t, err) presentations, seed, timestamp, err := m.get(testServiceID, 0) assert.NoError(t, err) assert.Equal(t, map[string]vc.VerifiablePresentation{"1": vpAlice}, presentations) @@ -150,8 +158,10 @@ func Test_sqlStore_get(t *testing.T) { }) t.Run("2 entries, 0 timestamp", func(t *testing.T) { m := setupStore(t, storageEngine.GetSQLDatabase()) - require.NoError(t, m.add(testServiceID, vpAlice, testSeed, 0)) - require.NoError(t, m.add(testServiceID, vpBob, testSeed, 0)) + _, err := m.add(testServiceID, vpAlice, testSeed, 0) + require.NoError(t, err) + _, err = m.add(testServiceID, vpBob, testSeed, 0) + require.NoError(t, err) presentations, _, timestamp, err := m.get(testServiceID, 0) assert.NoError(t, err) assert.Equal(t, map[string]vc.VerifiablePresentation{"1": vpAlice, "2": vpBob}, presentations) @@ -159,8 +169,10 @@ func Test_sqlStore_get(t *testing.T) { }) t.Run("2 entries, start after first", func(t *testing.T) { m := setupStore(t, storageEngine.GetSQLDatabase()) - require.NoError(t, m.add(testServiceID, vpAlice, testSeed, 0)) - require.NoError(t, m.add(testServiceID, vpBob, testSeed, 0)) + _, err := m.add(testServiceID, vpAlice, testSeed, 0) + require.NoError(t, err) + _, err = m.add(testServiceID, vpBob, testSeed, 0) + require.NoError(t, err) presentations, _, timestamp, err := m.get(testServiceID, 1) assert.NoError(t, err) assert.Equal(t, map[string]vc.VerifiablePresentation{"2": vpBob}, presentations) @@ -168,8 +180,9 @@ func Test_sqlStore_get(t *testing.T) { }) t.Run("2 entries, start at end", func(t *testing.T) { m := setupStore(t, storageEngine.GetSQLDatabase()) - require.NoError(t, m.add(testServiceID, vpAlice, testSeed, 0)) - require.NoError(t, m.add(testServiceID, vpBob, testSeed, 0)) + _, err := m.add(testServiceID, vpAlice, testSeed, 0) + require.NoError(t, err) + _, err = m.add(testServiceID, vpBob, testSeed, 0) presentations, _, timestamp, err := m.get(testServiceID, 2) assert.NoError(t, err) assert.Equal(t, map[string]vc.VerifiablePresentation{}, presentations) @@ -182,7 +195,7 @@ func Test_sqlStore_get(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - err := c.add(testServiceID, createPresentation(aliceDID, vcAlice), testSeed, 0) + _, err := c.add(testServiceID, createPresentation(aliceDID, vcAlice), testSeed, 0) require.NoError(t, err) }() } @@ -200,7 +213,7 @@ func Test_sqlStore_search(t *testing.T) { t.Run("empty database", func(t *testing.T) { c := setupStore(t, storageEngine.GetSQLDatabase()) - actualVPs, err := c.search(testServiceID, map[string]string{}) + actualVPs, err := c.search(testServiceID, map[string]string{}, true) require.NoError(t, err) require.Len(t, actualVPs, 0) }) @@ -208,13 +221,13 @@ func Test_sqlStore_search(t *testing.T) { vps := []vc.VerifiablePresentation{vpAlice} c := setupStore(t, storageEngine.GetSQLDatabase()) for _, vp := range vps { - err := c.add(testServiceID, vp, testSeed, 0) + _, err := c.add(testServiceID, vp, testSeed, 0) require.NoError(t, err) } actualVPs, err := c.search(testServiceID, map[string]string{ "credentialSubject.person.givenName": "Alice", - }) + }, true) require.NoError(t, err) require.Len(t, actualVPs, 1) assert.Equal(t, vpAlice.ID.String(), actualVPs[0].ID.String()) @@ -223,24 +236,30 @@ func Test_sqlStore_search(t *testing.T) { vps := []vc.VerifiablePresentation{vpAlice, vpBob} c := setupStore(t, storageEngine.GetSQLDatabase()) for _, vp := range vps { - err := c.add(testServiceID, vp, testSeed, 0) + _, err := c.add(testServiceID, vp, testSeed, 0) require.NoError(t, err) } - actualVPs, err := c.search(testServiceID, map[string]string{}) + actualVPs, err := c.search(testServiceID, map[string]string{}, true) require.NoError(t, err) require.Len(t, actualVPs, 2) + + t.Run("validated", func(t *testing.T) { + actualVPs, err = c.search(testServiceID, map[string]string{}, false) + require.NoError(t, err) + require.Len(t, actualVPs, 0) + }) }) t.Run("not found", func(t *testing.T) { vps := []vc.VerifiablePresentation{vpAlice, vpBob} c := setupStore(t, storageEngine.GetSQLDatabase()) for _, vp := range vps { - err := c.add(testServiceID, vp, testSeed, 0) + _, err := c.add(testServiceID, vp, testSeed, 0) require.NoError(t, err) } actualVPs, err := c.search(testServiceID, map[string]string{ "credentialSubject.person.givenName": "Charlie", - }) + }, true) require.NoError(t, err) require.Len(t, actualVPs, 0) }) @@ -345,7 +364,7 @@ func Test_sqlStore_setPresentationRefreshError(t *testing.T) { assert.Equal(t, refreshError.Error, assert.AnError.Error()) assert.True(t, refreshError.LastOccurrence > int(time.Now().Add(-1*time.Second).Unix())) }) - t.Run("delete", func(t *testing.T) { + t.Run("deletePresentationRecord", func(t *testing.T) { c := setupStore(t, storageEngine.GetSQLDatabase()) require.NoError(t, c.updatePresentationRefreshTime(testServiceID, aliceSubject, nil, to.Ptr(time.Now().Add(time.Second)))) require.NoError(t, c.setPresentationRefreshError(testServiceID, aliceSubject, assert.AnError)) @@ -373,8 +392,10 @@ func Test_sqlStore_getSubjectVPsOnService(t *testing.T) { _ = storageEngine.Shutdown() }) c := setupStore(t, storageEngine.GetSQLDatabase()) - require.NoError(t, c.add(testServiceID, vpAlice2, testSeed, 0)) - require.NoError(t, c.add(testServiceID, vpBob2, testSeed, 0)) + _, err := c.add(testServiceID, vpAlice2, testSeed, 0) + require.NoError(t, err) + _, err = c.add(testServiceID, vpBob2, testSeed, 0) + require.NoError(t, err) t.Run("ok - single", func(t *testing.T) { vps, err := c.getSubjectVPsOnService(testServiceID, []did.DID{aliceDID}) @@ -403,21 +424,76 @@ func Test_sqlStore_wipeOnSeedChange(t *testing.T) { }) t.Run("1 entry wiped, 1 remains", func(t *testing.T) { c := setupStore(t, storageEngine.GetSQLDatabase()) - require.NoError(t, c.add(testServiceID, vpAlice, testSeed, 0)) - require.NoError(t, c.add("other", vpAlice, testSeed, 0)) + _, err := c.add(testServiceID, vpAlice, testSeed, 0) + require.NoError(t, err) + _, err = c.add("other", vpAlice, testSeed, 0) + require.NoError(t, err) - err := c.wipeOnSeedChange(testServiceID, "other") + err = c.wipeOnSeedChange(testServiceID, "other") require.NoError(t, err) - vps, err := c.search(testServiceID, map[string]string{}) + vps, err := c.search(testServiceID, map[string]string{}, true) require.NoError(t, err) require.Len(t, vps, 0) - vps, err = c.search("other", map[string]string{}) + vps, err = c.search("other", map[string]string{}, true) require.NoError(t, err) require.Len(t, vps, 1) }) } +func Test_sqlStore_updateValidated(t *testing.T) { + storageEngine := storage.NewTestStorageEngine(t) + require.NoError(t, storageEngine.Start()) + t.Cleanup(func() { + _ = storageEngine.Shutdown() + }) + + c := setupStore(t, storageEngine.GetSQLDatabase()) + _, err := c.add(testServiceID, vpAlice, testSeed, 0) + require.NoError(t, err) + + result, err := c.allPresentations(true) + require.NoError(t, err) + assert.Len(t, result, 0) + result, err = c.allPresentations(false) + require.NoError(t, err) + assert.Len(t, result, 1) + + t.Run("validated", func(t *testing.T) { + err = c.updateValidated(result) + require.NoError(t, err) + + result, err = c.allPresentations(false) + require.NoError(t, err) + assert.Len(t, result, 0) + result, err = c.allPresentations(true) + require.NoError(t, err) + assert.Len(t, result, 1) + }) +} + +func Test_sqlStore_delete(t *testing.T) { + storageEngine := storage.NewTestStorageEngine(t) + require.NoError(t, storageEngine.Start()) + t.Cleanup(func() { + _ = storageEngine.Shutdown() + }) + + c := setupStore(t, storageEngine.GetSQLDatabase()) + _, err := c.add(testServiceID, vpAlice, testSeed, 0) + require.NoError(t, err) + presentations, _ := c.allPresentations(false) + require.Len(t, presentations, 1) + + err = c.deletePresentationRecord(presentations[0].ID) + + require.NoError(t, err) + + result, err := c.allPresentations(false) + require.NoError(t, err) + assert.Len(t, result, 0) +} + func setupStore(t *testing.T, db *gorm.DB) *sqlStore { resetStore(t, db) defs := testDefinitions() @@ -427,7 +503,7 @@ func setupStore(t *testing.T, db *gorm.DB) *sqlStore { } func resetStore(t *testing.T, db *gorm.DB) { - // related tables are emptied due to on-delete-cascade clause + // related tables are emptied due to on-deletePresentationRecord-cascade clause tableNames := []string{"discovery_service", "discovery_presentation", "discovery_credential", "credential", "credential_prop"} for _, tableName := range tableNames { require.NoError(t, db.Exec("DELETE FROM "+tableName).Error) diff --git a/storage/sql_migrations/010_discoverypresentation_validation.sql b/storage/sql_migrations/010_discoverypresentation_validation.sql new file mode 100644 index 000000000..381a6b20d --- /dev/null +++ b/storage/sql_migrations/010_discoverypresentation_validation.sql @@ -0,0 +1,6 @@ +-- +goose Up +-- discovery_presentation: add validated column +alter table discovery_presentation add validated SMALLINT NOT NULL DEFAULT 0; + +-- +goose Down +alter table discovery_presentation drop column validated;