From 3902523fa9eb310b1c27cc93143a406f73c3c93e Mon Sep 17 00:00:00 2001 From: Alyx Holms Date: Wed, 31 Jan 2024 11:54:24 -0700 Subject: [PATCH] feat: UpdateAssetGroup, UpdateAssetGroupSelector, DeleteAssetGroupSelector, CreateSAMLIdentityProvider, and UpdateSAMLIdentityProvider have audit log support (#374) chore: remove unused RemoveAssetGroupSelector method chore: update tests and mocks to account for interface changes --- cmd/api/src/api/v2/agi.go | 4 +-- cmd/api/src/api/v2/agi_test.go | 8 ++--- cmd/api/src/api/v2/auth/auth.go | 2 +- cmd/api/src/database/agi.go | 41 +++++++++++++++++------ cmd/api/src/database/auth.go | 28 ++++++++++++---- cmd/api/src/database/auth_test.go | 2 +- cmd/api/src/database/db.go | 11 +++---- cmd/api/src/database/mocks/db.go | 54 ++++++++++++------------------- 8 files changed, 86 insertions(+), 64 deletions(-) diff --git a/cmd/api/src/api/v2/agi.go b/cmd/api/src/api/v2/agi.go index 00c6b2d9a4..11b6737ab2 100644 --- a/cmd/api/src/api/v2/agi.go +++ b/cmd/api/src/api/v2/agi.go @@ -187,7 +187,7 @@ func (s Resources) UpdateAssetGroup(response http.ResponseWriter, request *http. } else { assetGroup.Name = updateAssetGroupRequest.Name - if err := s.DB.UpdateAssetGroup(assetGroup); err != nil { + if err := s.DB.UpdateAssetGroup(request.Context(), assetGroup); err != nil { api.HandleDatabaseError(request, response, err) } else { api.WriteBasicResponse(request.Context(), assetGroup, http.StatusOK, response) @@ -286,7 +286,7 @@ func (s Resources) DeleteAssetGroupSelector(response http.ResponseWriter, reques api.HandleDatabaseError(request, response, err) } else if assetGroupSelector.SystemSelector { api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusConflict, "Cannot delete a system defined asset group selector.", request), response) - } else if err := s.DB.DeleteAssetGroupSelector(assetGroupSelector); err != nil { + } else if err := s.DB.DeleteAssetGroupSelector(request.Context(), assetGroupSelector); err != nil { api.HandleDatabaseError(request, response, err) } else { response.WriteHeader(http.StatusOK) diff --git a/cmd/api/src/api/v2/agi_test.go b/cmd/api/src/api/v2/agi_test.go index 32d5f68c17..d3841cc6b5 100644 --- a/cmd/api/src/api/v2/agi_test.go +++ b/cmd/api/src/api/v2/agi_test.go @@ -363,7 +363,7 @@ func TestResources_UpdateAssetGroup(t *testing.T) { // UpdateAssetGroup DB fails mockDB.EXPECT().GetAssetGroup(int32(1234)).Return(model.AssetGroup{}, nil) - mockDB.EXPECT().UpdateAssetGroup(model.AssetGroup{}).Return(fmt.Errorf("exploded")) + mockDB.EXPECT().UpdateAssetGroup(gomock.Any(), model.AssetGroup{}).Return(fmt.Errorf("exploded")) requestTemplate. WithURLPathVars(map[string]string{ @@ -376,7 +376,7 @@ func TestResources_UpdateAssetGroup(t *testing.T) { // Success mockDB.EXPECT().GetAssetGroup(int32(1234)).Return(model.AssetGroup{}, nil) - mockDB.EXPECT().UpdateAssetGroup(model.AssetGroup{}).Return(nil) + mockDB.EXPECT().UpdateAssetGroup(gomock.Any(), model.AssetGroup{}).Return(nil) requestTemplate. WithURLPathVars(map[string]string{ @@ -811,7 +811,7 @@ func TestResources_DeleteAssetGroupSelector(t *testing.T) { // DeleteAssetGroupSelector DB fails mockDB.EXPECT().GetAssetGroup(int32(1234)).Return(model.AssetGroup{}, nil) mockDB.EXPECT().GetAssetGroupSelector(int32(1234)).Return(model.AssetGroupSelector{}, nil) - mockDB.EXPECT().DeleteAssetGroupSelector(model.AssetGroupSelector{}).Return(fmt.Errorf("exploded")) + mockDB.EXPECT().DeleteAssetGroupSelector(gomock.Any(), model.AssetGroupSelector{}).Return(fmt.Errorf("exploded")) requestTemplate. WithURLPathVars(map[string]string{ @@ -825,7 +825,7 @@ func TestResources_DeleteAssetGroupSelector(t *testing.T) { // Success mockDB.EXPECT().GetAssetGroup(int32(1234)).Return(model.AssetGroup{}, nil) mockDB.EXPECT().GetAssetGroupSelector(int32(1234)).Return(model.AssetGroupSelector{}, nil) - mockDB.EXPECT().DeleteAssetGroupSelector(model.AssetGroupSelector{}).Return(nil) + mockDB.EXPECT().DeleteAssetGroupSelector(gomock.Any(), model.AssetGroupSelector{}).Return(nil) requestTemplate. WithURLPathVars(map[string]string{ diff --git a/cmd/api/src/api/v2/auth/auth.go b/cmd/api/src/api/v2/auth/auth.go index 571068559b..cd99a47b6a 100644 --- a/cmd/api/src/api/v2/auth/auth.go +++ b/cmd/api/src/api/v2/auth/auth.go @@ -152,7 +152,7 @@ func (s ManagementResource) CreateSAMLProviderMultipart(response http.ResponseWr samlIdentityProvider.IssuerURI = metadata.EntityID samlIdentityProvider.SingleSignOnURI = ssoURL - if newSAMLProvider, err := s.db.CreateSAMLIdentityProvider(samlIdentityProvider); err != nil { + if newSAMLProvider, err := s.db.CreateSAMLIdentityProvider(request.Context(), samlIdentityProvider); err != nil { api.HandleDatabaseError(request, response, err) } else { api.WriteBasicResponse(request.Context(), newSAMLProvider, http.StatusOK, response) diff --git a/cmd/api/src/database/agi.go b/cmd/api/src/database/agi.go index 1f2824c2fd..8d25fe769c 100644 --- a/cmd/api/src/database/agi.go +++ b/cmd/api/src/database/agi.go @@ -46,8 +46,17 @@ func (s *BloodhoundDB) CreateAssetGroup(ctx context.Context, name, tag string, s }) } -func (s *BloodhoundDB) UpdateAssetGroup(assetGroup model.AssetGroup) error { - return CheckError(s.db.Save(&assetGroup)) +func (s *BloodhoundDB) UpdateAssetGroup(ctx context.Context, assetGroup model.AssetGroup) error { + var ( + auditEntry = model.AuditEntry{ + Action: "UpdateAssetGroup", + Model: &assetGroup, // Pointer is required to ensure success log contains updated fields after transaction + } + ) + + return s.AuditableTransaction(ctx, auditEntry, func(tx *gorm.DB) error { + return CheckError(tx.Save(&assetGroup)) + }) } func (s *BloodhoundDB) DeleteAssetGroup(ctx context.Context, assetGroup model.AssetGroup) error { @@ -169,16 +178,30 @@ func (s *BloodhoundDB) GetAssetGroupSelector(id int32) (model.AssetGroupSelector return assetGroupSelector, CheckError(s.db.Find(&assetGroupSelector, id)) } -func (s *BloodhoundDB) UpdateAssetGroupSelector(selector model.AssetGroupSelector) error { - return CheckError(s.db.Save(&selector)) -} +func (s *BloodhoundDB) UpdateAssetGroupSelector(ctx context.Context, selector model.AssetGroupSelector) error { + var ( + auditEntry = model.AuditEntry{ + Action: "UpdateAssetGroupSelector", + Model: &selector, // Pointer is required to ensure success log contains updated fields after transaction + } + ) -func (s *BloodhoundDB) DeleteAssetGroupSelector(selector model.AssetGroupSelector) error { - return CheckError(s.db.Delete(&selector)) + return s.AuditableTransaction(ctx, auditEntry, func(tx *gorm.DB) error { + return CheckError(tx.Save(&selector)) + }) } -func (s *BloodhoundDB) RemoveAssetGroupSelector(selector model.AssetGroupSelector) error { - return CheckError(s.db.Where("asset_group_id=? AND name=?", selector.AssetGroupID, selector.Name).Delete(&model.AssetGroupSelector{})) +func (s *BloodhoundDB) DeleteAssetGroupSelector(ctx context.Context, selector model.AssetGroupSelector) error { + var ( + auditEntry = model.AuditEntry{ + Action: "DeleteAssetGroupSelector", + Model: &selector, // Pointer is required to ensure success log contains updated fields after transaction + } + ) + + return s.AuditableTransaction(ctx, auditEntry, func(tx *gorm.DB) error { + return CheckError(tx.Delete(&selector)) + }) } func (s *BloodhoundDB) CreateRawAssetGroupSelector(assetGroup model.AssetGroup, name, selector string) (model.AssetGroupSelector, error) { diff --git a/cmd/api/src/database/auth.go b/cmd/api/src/database/auth.go index 8f655a3b0f..f16c1e2738 100644 --- a/cmd/api/src/database/auth.go +++ b/cmd/api/src/database/auth.go @@ -578,20 +578,34 @@ func (s *BloodhoundDB) DeleteAuthSecret(authSecret model.AuthSecret) error { // CreateSAMLProvider creates a new saml_providers row using the data in the input struct // INSERT INTO saml_identity_providers (...) VALUES (...) -func (s *BloodhoundDB) CreateSAMLIdentityProvider(samlProvider model.SAMLProvider) (model.SAMLProvider, error) { +func (s *BloodhoundDB) CreateSAMLIdentityProvider(ctx context.Context, samlProvider model.SAMLProvider) (model.SAMLProvider, error) { var ( - updatedSAMLProvider = samlProvider - result = s.db.Create(&updatedSAMLProvider) + auditEntry = model.AuditEntry{ + Action: "CreateSAMLIdentityProvider", + Model: &samlProvider, // Pointer is required to ensure success log contains updated fields after transaction + } ) - return updatedSAMLProvider, CheckError(result) + err := s.AuditableTransaction(ctx, auditEntry, func(tx *gorm.DB) error { + return CheckError(tx.Create(&samlProvider)) + }) + + return samlProvider, err } // CreateSAMLProvider updates a saml_providers row using the data in the input struct // UPDATE saml_identity_providers SET (...) VALUES (...) WHERE id = ... -func (s *BloodhoundDB) UpdateSAMLIdentityProvider(provider model.SAMLProvider) error { - result := s.db.Save(&provider) - return CheckError(result) +func (s *BloodhoundDB) UpdateSAMLIdentityProvider(ctx context.Context, provider model.SAMLProvider) error { + var ( + auditEntry = model.AuditEntry{ + Action: "UpdateSAMLIdentityProvider", + Model: &provider, // Pointer is required to ensure success log contains updated fields after transaction + } + ) + + return s.AuditableTransaction(ctx, auditEntry, func(tx *gorm.DB) error { + return CheckError(tx.Save(&provider)) + }) } // LookupSAMLProviderByName returns a SAML provider corresponding to the name provided diff --git a/cmd/api/src/database/auth_test.go b/cmd/api/src/database/auth_test.go index 5c0903f82c..aae8af58b9 100644 --- a/cmd/api/src/database/auth_test.go +++ b/cmd/api/src/database/auth_test.go @@ -326,7 +326,7 @@ func TestDatabase_CreateSAMLProvider(t *testing.T) { SingleSignOnURI: "https://idp.example.com/sso", } - if newSAMLProvider, err := dbInst.CreateSAMLIdentityProvider(samlProvider); err != nil { + if newSAMLProvider, err := dbInst.CreateSAMLIdentityProvider(context.Background(), samlProvider); err != nil { t.Fatalf("Failed to create SAML provider: %v", err) } else { user.SAMLProviderID = null.Int32From(newSAMLProvider.ID) diff --git a/cmd/api/src/database/db.go b/cmd/api/src/database/db.go index 90384fc126..bdcf668530 100644 --- a/cmd/api/src/database/db.go +++ b/cmd/api/src/database/db.go @@ -62,7 +62,7 @@ type Database interface { GetIngestTasksForJob(jobID int64) (model.IngestTasks, error) GetUnfinishedIngestIDs() ([]int64, error) CreateAssetGroup(ctx context.Context, name, tag string, systemGroup bool) (model.AssetGroup, error) - UpdateAssetGroup(assetGroup model.AssetGroup) error + UpdateAssetGroup(ctx context.Context, assetGroup model.AssetGroup) error DeleteAssetGroup(ctx context.Context, assetGroup model.AssetGroup) error GetAssetGroup(id int32) (model.AssetGroup, error) GetAllAssetGroups(order string, filter model.SQLFilter) (model.AssetGroups, error) @@ -72,9 +72,8 @@ type Database interface { GetTimeRangedAssetGroupCollections(assetGroupID int32, from int64, to int64, order string) (model.AssetGroupCollections, error) GetAllAssetGroupCollections() (model.AssetGroupCollections, error) GetAssetGroupSelector(id int32) (model.AssetGroupSelector, error) - UpdateAssetGroupSelector(selector model.AssetGroupSelector) error - DeleteAssetGroupSelector(selector model.AssetGroupSelector) error - RemoveAssetGroupSelector(selector model.AssetGroupSelector) error + UpdateAssetGroupSelector(ctx context.Context, selector model.AssetGroupSelector) error + DeleteAssetGroupSelector(ctx context.Context, selector model.AssetGroupSelector) error CreateRawAssetGroupSelector(assetGroup model.AssetGroup, name, selector string) (model.AssetGroupSelector, error) CreateAssetGroupSelector(assetGroup model.AssetGroup, spec model.AssetGroupSelectorSpec, systemSelector bool) (model.AssetGroupSelector, error) UpdateAssetGroupSelectors(ctx ctx.Context, assetGroup model.AssetGroup, selectorSpecs []model.AssetGroupSelectorSpec, systemSelector bool) (model.UpdatedAssetGroupSelectors, error) @@ -117,8 +116,8 @@ type Database interface { GetAuthSecret(id int32) (model.AuthSecret, error) UpdateAuthSecret(ctx context.Context, authSecret model.AuthSecret) error DeleteAuthSecret(authSecret model.AuthSecret) error - CreateSAMLIdentityProvider(samlProvider model.SAMLProvider) (model.SAMLProvider, error) - UpdateSAMLIdentityProvider(samlProvider model.SAMLProvider) error + CreateSAMLIdentityProvider(ctx context.Context, samlProvider model.SAMLProvider) (model.SAMLProvider, error) + UpdateSAMLIdentityProvider(ctx context.Context, samlProvider model.SAMLProvider) error LookupSAMLProviderByName(name string) (model.SAMLProvider, error) GetAllSAMLProviders() (model.SAMLProviders, error) GetSAMLProvider(id int32) (model.SAMLProvider, error) diff --git a/cmd/api/src/database/mocks/db.go b/cmd/api/src/database/mocks/db.go index 25be487439..71311e6552 100644 --- a/cmd/api/src/database/mocks/db.go +++ b/cmd/api/src/database/mocks/db.go @@ -306,18 +306,18 @@ func (mr *MockDatabaseMockRecorder) CreateRole(arg0 interface{}) *gomock.Call { } // CreateSAMLIdentityProvider mocks base method. -func (m *MockDatabase) CreateSAMLIdentityProvider(arg0 model.SAMLProvider) (model.SAMLProvider, error) { +func (m *MockDatabase) CreateSAMLIdentityProvider(arg0 context.Context, arg1 model.SAMLProvider) (model.SAMLProvider, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CreateSAMLIdentityProvider", arg0) + ret := m.ctrl.Call(m, "CreateSAMLIdentityProvider", arg0, arg1) ret0, _ := ret[0].(model.SAMLProvider) ret1, _ := ret[1].(error) return ret0, ret1 } // CreateSAMLIdentityProvider indicates an expected call of CreateSAMLIdentityProvider. -func (mr *MockDatabaseMockRecorder) CreateSAMLIdentityProvider(arg0 interface{}) *gomock.Call { +func (mr *MockDatabaseMockRecorder) CreateSAMLIdentityProvider(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateSAMLIdentityProvider", reflect.TypeOf((*MockDatabase)(nil).CreateSAMLIdentityProvider), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateSAMLIdentityProvider", reflect.TypeOf((*MockDatabase)(nil).CreateSAMLIdentityProvider), arg0, arg1) } // CreateSavedQuery mocks base method. @@ -380,17 +380,17 @@ func (mr *MockDatabaseMockRecorder) DeleteAssetGroup(arg0, arg1 interface{}) *go } // DeleteAssetGroupSelector mocks base method. -func (m *MockDatabase) DeleteAssetGroupSelector(arg0 model.AssetGroupSelector) error { +func (m *MockDatabase) DeleteAssetGroupSelector(arg0 context.Context, arg1 model.AssetGroupSelector) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DeleteAssetGroupSelector", arg0) + ret := m.ctrl.Call(m, "DeleteAssetGroupSelector", arg0, arg1) ret0, _ := ret[0].(error) return ret0 } // DeleteAssetGroupSelector indicates an expected call of DeleteAssetGroupSelector. -func (mr *MockDatabaseMockRecorder) DeleteAssetGroupSelector(arg0 interface{}) *gomock.Call { +func (mr *MockDatabaseMockRecorder) DeleteAssetGroupSelector(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAssetGroupSelector", reflect.TypeOf((*MockDatabase)(nil).DeleteAssetGroupSelector), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAssetGroupSelector", reflect.TypeOf((*MockDatabase)(nil).DeleteAssetGroupSelector), arg0, arg1) } // DeleteAuthSecret mocks base method. @@ -1290,20 +1290,6 @@ func (mr *MockDatabaseMockRecorder) RawFirst(arg0 interface{}) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RawFirst", reflect.TypeOf((*MockDatabase)(nil).RawFirst), arg0) } -// RemoveAssetGroupSelector mocks base method. -func (m *MockDatabase) RemoveAssetGroupSelector(arg0 model.AssetGroupSelector) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "RemoveAssetGroupSelector", arg0) - ret0, _ := ret[0].(error) - return ret0 -} - -// RemoveAssetGroupSelector indicates an expected call of RemoveAssetGroupSelector. -func (mr *MockDatabaseMockRecorder) RemoveAssetGroupSelector(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveAssetGroupSelector", reflect.TypeOf((*MockDatabase)(nil).RemoveAssetGroupSelector), arg0) -} - // SavedQueryBelongsToUser mocks base method. func (m *MockDatabase) SavedQueryBelongsToUser(arg0 uuid.UUID, arg1 int) (bool, error) { m.ctrl.T.Helper() @@ -1372,31 +1358,31 @@ func (mr *MockDatabaseMockRecorder) SweepSessions() *gomock.Call { } // UpdateAssetGroup mocks base method. -func (m *MockDatabase) UpdateAssetGroup(arg0 model.AssetGroup) error { +func (m *MockDatabase) UpdateAssetGroup(arg0 context.Context, arg1 model.AssetGroup) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpdateAssetGroup", arg0) + ret := m.ctrl.Call(m, "UpdateAssetGroup", arg0, arg1) ret0, _ := ret[0].(error) return ret0 } // UpdateAssetGroup indicates an expected call of UpdateAssetGroup. -func (mr *MockDatabaseMockRecorder) UpdateAssetGroup(arg0 interface{}) *gomock.Call { +func (mr *MockDatabaseMockRecorder) UpdateAssetGroup(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAssetGroup", reflect.TypeOf((*MockDatabase)(nil).UpdateAssetGroup), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAssetGroup", reflect.TypeOf((*MockDatabase)(nil).UpdateAssetGroup), arg0, arg1) } // UpdateAssetGroupSelector mocks base method. -func (m *MockDatabase) UpdateAssetGroupSelector(arg0 model.AssetGroupSelector) error { +func (m *MockDatabase) UpdateAssetGroupSelector(arg0 context.Context, arg1 model.AssetGroupSelector) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpdateAssetGroupSelector", arg0) + ret := m.ctrl.Call(m, "UpdateAssetGroupSelector", arg0, arg1) ret0, _ := ret[0].(error) return ret0 } // UpdateAssetGroupSelector indicates an expected call of UpdateAssetGroupSelector. -func (mr *MockDatabaseMockRecorder) UpdateAssetGroupSelector(arg0 interface{}) *gomock.Call { +func (mr *MockDatabaseMockRecorder) UpdateAssetGroupSelector(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAssetGroupSelector", reflect.TypeOf((*MockDatabase)(nil).UpdateAssetGroupSelector), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAssetGroupSelector", reflect.TypeOf((*MockDatabase)(nil).UpdateAssetGroupSelector), arg0, arg1) } // UpdateAssetGroupSelectors mocks base method. @@ -1471,17 +1457,17 @@ func (mr *MockDatabaseMockRecorder) UpdateRole(arg0 interface{}) *gomock.Call { } // UpdateSAMLIdentityProvider mocks base method. -func (m *MockDatabase) UpdateSAMLIdentityProvider(arg0 model.SAMLProvider) error { +func (m *MockDatabase) UpdateSAMLIdentityProvider(arg0 context.Context, arg1 model.SAMLProvider) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpdateSAMLIdentityProvider", arg0) + ret := m.ctrl.Call(m, "UpdateSAMLIdentityProvider", arg0, arg1) ret0, _ := ret[0].(error) return ret0 } // UpdateSAMLIdentityProvider indicates an expected call of UpdateSAMLIdentityProvider. -func (mr *MockDatabaseMockRecorder) UpdateSAMLIdentityProvider(arg0 interface{}) *gomock.Call { +func (mr *MockDatabaseMockRecorder) UpdateSAMLIdentityProvider(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateSAMLIdentityProvider", reflect.TypeOf((*MockDatabase)(nil).UpdateSAMLIdentityProvider), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateSAMLIdentityProvider", reflect.TypeOf((*MockDatabase)(nil).UpdateSAMLIdentityProvider), arg0, arg1) } // UpdateUser mocks base method.