Skip to content

Commit

Permalink
feat: UpdateAssetGroup, UpdateAssetGroupSelector, DeleteAssetGroupSel…
Browse files Browse the repository at this point in the history
…ector, CreateSAMLIdentityProvider, and UpdateSAMLIdentityProvider have audit log support (#374)

chore: remove unused RemoveAssetGroupSelector method
chore: update tests and mocks to account for interface changes
  • Loading branch information
superlinkx authored Jan 31, 2024
1 parent 9ff9831 commit 3902523
Show file tree
Hide file tree
Showing 8 changed files with 86 additions and 64 deletions.
4 changes: 2 additions & 2 deletions cmd/api/src/api/v2/agi.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ func (s Resources) UpdateAssetGroup(response http.ResponseWriter, request *http.
} else {
assetGroup.Name = updateAssetGroupRequest.Name

if err := s.DB.UpdateAssetGroup(assetGroup); err != nil {
if err := s.DB.UpdateAssetGroup(request.Context(), assetGroup); err != nil {
api.HandleDatabaseError(request, response, err)
} else {
api.WriteBasicResponse(request.Context(), assetGroup, http.StatusOK, response)
Expand Down Expand Up @@ -286,7 +286,7 @@ func (s Resources) DeleteAssetGroupSelector(response http.ResponseWriter, reques
api.HandleDatabaseError(request, response, err)
} else if assetGroupSelector.SystemSelector {
api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusConflict, "Cannot delete a system defined asset group selector.", request), response)
} else if err := s.DB.DeleteAssetGroupSelector(assetGroupSelector); err != nil {
} else if err := s.DB.DeleteAssetGroupSelector(request.Context(), assetGroupSelector); err != nil {
api.HandleDatabaseError(request, response, err)
} else {
response.WriteHeader(http.StatusOK)
Expand Down
8 changes: 4 additions & 4 deletions cmd/api/src/api/v2/agi_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ func TestResources_UpdateAssetGroup(t *testing.T) {

// UpdateAssetGroup DB fails
mockDB.EXPECT().GetAssetGroup(int32(1234)).Return(model.AssetGroup{}, nil)
mockDB.EXPECT().UpdateAssetGroup(model.AssetGroup{}).Return(fmt.Errorf("exploded"))
mockDB.EXPECT().UpdateAssetGroup(gomock.Any(), model.AssetGroup{}).Return(fmt.Errorf("exploded"))

requestTemplate.
WithURLPathVars(map[string]string{
Expand All @@ -376,7 +376,7 @@ func TestResources_UpdateAssetGroup(t *testing.T) {

// Success
mockDB.EXPECT().GetAssetGroup(int32(1234)).Return(model.AssetGroup{}, nil)
mockDB.EXPECT().UpdateAssetGroup(model.AssetGroup{}).Return(nil)
mockDB.EXPECT().UpdateAssetGroup(gomock.Any(), model.AssetGroup{}).Return(nil)

requestTemplate.
WithURLPathVars(map[string]string{
Expand Down Expand Up @@ -811,7 +811,7 @@ func TestResources_DeleteAssetGroupSelector(t *testing.T) {
// DeleteAssetGroupSelector DB fails
mockDB.EXPECT().GetAssetGroup(int32(1234)).Return(model.AssetGroup{}, nil)
mockDB.EXPECT().GetAssetGroupSelector(int32(1234)).Return(model.AssetGroupSelector{}, nil)
mockDB.EXPECT().DeleteAssetGroupSelector(model.AssetGroupSelector{}).Return(fmt.Errorf("exploded"))
mockDB.EXPECT().DeleteAssetGroupSelector(gomock.Any(), model.AssetGroupSelector{}).Return(fmt.Errorf("exploded"))

requestTemplate.
WithURLPathVars(map[string]string{
Expand All @@ -825,7 +825,7 @@ func TestResources_DeleteAssetGroupSelector(t *testing.T) {
// Success
mockDB.EXPECT().GetAssetGroup(int32(1234)).Return(model.AssetGroup{}, nil)
mockDB.EXPECT().GetAssetGroupSelector(int32(1234)).Return(model.AssetGroupSelector{}, nil)
mockDB.EXPECT().DeleteAssetGroupSelector(model.AssetGroupSelector{}).Return(nil)
mockDB.EXPECT().DeleteAssetGroupSelector(gomock.Any(), model.AssetGroupSelector{}).Return(nil)

requestTemplate.
WithURLPathVars(map[string]string{
Expand Down
2 changes: 1 addition & 1 deletion cmd/api/src/api/v2/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ func (s ManagementResource) CreateSAMLProviderMultipart(response http.ResponseWr
samlIdentityProvider.IssuerURI = metadata.EntityID
samlIdentityProvider.SingleSignOnURI = ssoURL

if newSAMLProvider, err := s.db.CreateSAMLIdentityProvider(samlIdentityProvider); err != nil {
if newSAMLProvider, err := s.db.CreateSAMLIdentityProvider(request.Context(), samlIdentityProvider); err != nil {
api.HandleDatabaseError(request, response, err)
} else {
api.WriteBasicResponse(request.Context(), newSAMLProvider, http.StatusOK, response)
Expand Down
41 changes: 32 additions & 9 deletions cmd/api/src/database/agi.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,17 @@ func (s *BloodhoundDB) CreateAssetGroup(ctx context.Context, name, tag string, s
})
}

func (s *BloodhoundDB) UpdateAssetGroup(assetGroup model.AssetGroup) error {
return CheckError(s.db.Save(&assetGroup))
func (s *BloodhoundDB) UpdateAssetGroup(ctx context.Context, assetGroup model.AssetGroup) error {
var (
auditEntry = model.AuditEntry{
Action: "UpdateAssetGroup",
Model: &assetGroup, // Pointer is required to ensure success log contains updated fields after transaction
}
)

return s.AuditableTransaction(ctx, auditEntry, func(tx *gorm.DB) error {
return CheckError(tx.Save(&assetGroup))
})
}

func (s *BloodhoundDB) DeleteAssetGroup(ctx context.Context, assetGroup model.AssetGroup) error {
Expand Down Expand Up @@ -169,16 +178,30 @@ func (s *BloodhoundDB) GetAssetGroupSelector(id int32) (model.AssetGroupSelector
return assetGroupSelector, CheckError(s.db.Find(&assetGroupSelector, id))
}

func (s *BloodhoundDB) UpdateAssetGroupSelector(selector model.AssetGroupSelector) error {
return CheckError(s.db.Save(&selector))
}
func (s *BloodhoundDB) UpdateAssetGroupSelector(ctx context.Context, selector model.AssetGroupSelector) error {
var (
auditEntry = model.AuditEntry{
Action: "UpdateAssetGroupSelector",
Model: &selector, // Pointer is required to ensure success log contains updated fields after transaction
}
)

func (s *BloodhoundDB) DeleteAssetGroupSelector(selector model.AssetGroupSelector) error {
return CheckError(s.db.Delete(&selector))
return s.AuditableTransaction(ctx, auditEntry, func(tx *gorm.DB) error {
return CheckError(tx.Save(&selector))
})
}

func (s *BloodhoundDB) RemoveAssetGroupSelector(selector model.AssetGroupSelector) error {
return CheckError(s.db.Where("asset_group_id=? AND name=?", selector.AssetGroupID, selector.Name).Delete(&model.AssetGroupSelector{}))
func (s *BloodhoundDB) DeleteAssetGroupSelector(ctx context.Context, selector model.AssetGroupSelector) error {
var (
auditEntry = model.AuditEntry{
Action: "DeleteAssetGroupSelector",
Model: &selector, // Pointer is required to ensure success log contains updated fields after transaction
}
)

return s.AuditableTransaction(ctx, auditEntry, func(tx *gorm.DB) error {
return CheckError(tx.Delete(&selector))
})
}

func (s *BloodhoundDB) CreateRawAssetGroupSelector(assetGroup model.AssetGroup, name, selector string) (model.AssetGroupSelector, error) {
Expand Down
28 changes: 21 additions & 7 deletions cmd/api/src/database/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -578,20 +578,34 @@ func (s *BloodhoundDB) DeleteAuthSecret(authSecret model.AuthSecret) error {

// CreateSAMLProvider creates a new saml_providers row using the data in the input struct
// INSERT INTO saml_identity_providers (...) VALUES (...)
func (s *BloodhoundDB) CreateSAMLIdentityProvider(samlProvider model.SAMLProvider) (model.SAMLProvider, error) {
func (s *BloodhoundDB) CreateSAMLIdentityProvider(ctx context.Context, samlProvider model.SAMLProvider) (model.SAMLProvider, error) {
var (
updatedSAMLProvider = samlProvider
result = s.db.Create(&updatedSAMLProvider)
auditEntry = model.AuditEntry{
Action: "CreateSAMLIdentityProvider",
Model: &samlProvider, // Pointer is required to ensure success log contains updated fields after transaction
}
)

return updatedSAMLProvider, CheckError(result)
err := s.AuditableTransaction(ctx, auditEntry, func(tx *gorm.DB) error {
return CheckError(tx.Create(&samlProvider))
})

return samlProvider, err
}

// CreateSAMLProvider updates a saml_providers row using the data in the input struct
// UPDATE saml_identity_providers SET (...) VALUES (...) WHERE id = ...
func (s *BloodhoundDB) UpdateSAMLIdentityProvider(provider model.SAMLProvider) error {
result := s.db.Save(&provider)
return CheckError(result)
func (s *BloodhoundDB) UpdateSAMLIdentityProvider(ctx context.Context, provider model.SAMLProvider) error {
var (
auditEntry = model.AuditEntry{
Action: "UpdateSAMLIdentityProvider",
Model: &provider, // Pointer is required to ensure success log contains updated fields after transaction
}
)

return s.AuditableTransaction(ctx, auditEntry, func(tx *gorm.DB) error {
return CheckError(tx.Save(&provider))
})
}

// LookupSAMLProviderByName returns a SAML provider corresponding to the name provided
Expand Down
2 changes: 1 addition & 1 deletion cmd/api/src/database/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ func TestDatabase_CreateSAMLProvider(t *testing.T) {
SingleSignOnURI: "https://idp.example.com/sso",
}

if newSAMLProvider, err := dbInst.CreateSAMLIdentityProvider(samlProvider); err != nil {
if newSAMLProvider, err := dbInst.CreateSAMLIdentityProvider(context.Background(), samlProvider); err != nil {
t.Fatalf("Failed to create SAML provider: %v", err)
} else {
user.SAMLProviderID = null.Int32From(newSAMLProvider.ID)
Expand Down
11 changes: 5 additions & 6 deletions cmd/api/src/database/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ type Database interface {
GetIngestTasksForJob(jobID int64) (model.IngestTasks, error)
GetUnfinishedIngestIDs() ([]int64, error)
CreateAssetGroup(ctx context.Context, name, tag string, systemGroup bool) (model.AssetGroup, error)
UpdateAssetGroup(assetGroup model.AssetGroup) error
UpdateAssetGroup(ctx context.Context, assetGroup model.AssetGroup) error
DeleteAssetGroup(ctx context.Context, assetGroup model.AssetGroup) error
GetAssetGroup(id int32) (model.AssetGroup, error)
GetAllAssetGroups(order string, filter model.SQLFilter) (model.AssetGroups, error)
Expand All @@ -72,9 +72,8 @@ type Database interface {
GetTimeRangedAssetGroupCollections(assetGroupID int32, from int64, to int64, order string) (model.AssetGroupCollections, error)
GetAllAssetGroupCollections() (model.AssetGroupCollections, error)
GetAssetGroupSelector(id int32) (model.AssetGroupSelector, error)
UpdateAssetGroupSelector(selector model.AssetGroupSelector) error
DeleteAssetGroupSelector(selector model.AssetGroupSelector) error
RemoveAssetGroupSelector(selector model.AssetGroupSelector) error
UpdateAssetGroupSelector(ctx context.Context, selector model.AssetGroupSelector) error
DeleteAssetGroupSelector(ctx context.Context, selector model.AssetGroupSelector) error
CreateRawAssetGroupSelector(assetGroup model.AssetGroup, name, selector string) (model.AssetGroupSelector, error)
CreateAssetGroupSelector(assetGroup model.AssetGroup, spec model.AssetGroupSelectorSpec, systemSelector bool) (model.AssetGroupSelector, error)
UpdateAssetGroupSelectors(ctx ctx.Context, assetGroup model.AssetGroup, selectorSpecs []model.AssetGroupSelectorSpec, systemSelector bool) (model.UpdatedAssetGroupSelectors, error)
Expand Down Expand Up @@ -117,8 +116,8 @@ type Database interface {
GetAuthSecret(id int32) (model.AuthSecret, error)
UpdateAuthSecret(ctx context.Context, authSecret model.AuthSecret) error
DeleteAuthSecret(authSecret model.AuthSecret) error
CreateSAMLIdentityProvider(samlProvider model.SAMLProvider) (model.SAMLProvider, error)
UpdateSAMLIdentityProvider(samlProvider model.SAMLProvider) error
CreateSAMLIdentityProvider(ctx context.Context, samlProvider model.SAMLProvider) (model.SAMLProvider, error)
UpdateSAMLIdentityProvider(ctx context.Context, samlProvider model.SAMLProvider) error
LookupSAMLProviderByName(name string) (model.SAMLProvider, error)
GetAllSAMLProviders() (model.SAMLProviders, error)
GetSAMLProvider(id int32) (model.SAMLProvider, error)
Expand Down
54 changes: 20 additions & 34 deletions cmd/api/src/database/mocks/db.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 3902523

Please sign in to comment.