Skip to content

Commit

Permalink
fix: tests + mocks
Browse files Browse the repository at this point in the history
  • Loading branch information
mistahj67 committed Nov 22, 2024
1 parent 66420fe commit f058ff7
Show file tree
Hide file tree
Showing 4 changed files with 183 additions and 41 deletions.
99 changes: 85 additions & 14 deletions cmd/api/src/api/v2/auth/oidc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,13 @@ import (

"github.com/specterops/bloodhound/src/api/v2/apitest"
"github.com/specterops/bloodhound/src/api/v2/auth"
"github.com/specterops/bloodhound/src/database"
"github.com/specterops/bloodhound/src/model"
"github.com/specterops/bloodhound/src/utils/test"
"go.uber.org/mock/gomock"
)

func TestManagementResource_CreateOIDCProvider(t *testing.T) {
const (
url = "/api/v2/sso/providers/oidc"
)
var (
mockCtrl = gomock.NewController(t)
resources, mockDB = apitest.NewAuthManagementResource(mockCtrl)
Expand All @@ -45,8 +43,6 @@ func TestManagementResource_CreateOIDCProvider(t *testing.T) {
}, nil)

test.Request(t).
WithMethod(http.MethodPost).
WithURL(url).
WithBody(auth.UpsertOIDCProviderRequest{
Name: "Bloodhound gang",
Issuer: "https://localhost/auth",
Expand All @@ -59,18 +55,13 @@ func TestManagementResource_CreateOIDCProvider(t *testing.T) {

t.Run("error parsing body request", func(t *testing.T) {
test.Request(t).
WithMethod(http.MethodPost).
WithURL(url).
WithBody("").
OnHandlerFunc(resources.CreateOIDCProvider).
Require().
ResponseStatusCode(http.StatusBadRequest)
})

t.Run("error validating request field", func(t *testing.T) {
test.Request(t).
WithMethod(http.MethodPost).
WithURL(url).
WithBody(auth.UpsertOIDCProviderRequest{
Name: "test",
Issuer: "1234:not:a:url",
Expand All @@ -86,8 +77,6 @@ func TestManagementResource_CreateOIDCProvider(t *testing.T) {
Issuer: "12345:bloodhound",
}
test.Request(t).
WithMethod(http.MethodPost).
WithURL(url).
WithBody(request).
OnHandlerFunc(resources.CreateOIDCProvider).
Require().
Expand All @@ -98,8 +87,6 @@ func TestManagementResource_CreateOIDCProvider(t *testing.T) {
mockDB.EXPECT().CreateOIDCProvider(gomock.Any(), "test", "https://localhost/auth", "bloodhound").Return(model.OIDCProvider{}, fmt.Errorf("error"))

test.Request(t).
WithMethod(http.MethodPost).
WithURL(url).
WithBody(auth.UpsertOIDCProviderRequest{
Name: "test",
Issuer: "https://localhost/auth",
Expand All @@ -110,3 +97,87 @@ func TestManagementResource_CreateOIDCProvider(t *testing.T) {
ResponseStatusCode(http.StatusInternalServerError)
})
}

func TestManagementResource_UpdateOIDCProvider(t *testing.T) {
var (
mockCtrl = gomock.NewController(t)
resources, mockDB = apitest.NewAuthManagementResource(mockCtrl)
baseProvider = model.SSOProvider{
Type: model.SessionAuthProviderOIDC,
Name: "Gotham Net",
OIDCProvider: &model.OIDCProvider{
ClientID: "gotham-net",
Issuer: "https://gotham.net",
},
}
urlParams = map[string]string{"sso_provider_id": "1"}
)
defer mockCtrl.Finish()

t.Run("successfully update an OIDCProvider", func(t *testing.T) {
mockDB.EXPECT().GetSSOProviderById(gomock.Any(), int32(1)).Return(baseProvider, nil)
mockDB.EXPECT().UpdateOIDCProvider(gomock.Any(), gomock.Any())

test.Request(t).
WithURLPathVars(urlParams).
WithBody(auth.UpsertOIDCProviderRequest{
Name: "Gotham Net 2",
Issuer: "https://gotham-2.net",
ClientID: "gotham-net-2",
}).
OnHandlerFunc(resources.UpdateSSOProvider).
Require().
ResponseStatusCode(http.StatusOK)
})

t.Run("error not found while updating an unknown OIDCProvider", func(t *testing.T) {
mockDB.EXPECT().GetSSOProviderById(gomock.Any(), int32(1)).Return(model.SSOProvider{}, database.ErrNotFound)

test.Request(t).
WithURLPathVars(urlParams).
OnHandlerFunc(resources.UpdateSSOProvider).
Require().
ResponseStatusCode(http.StatusNotFound)
})

t.Run("error parsing body request", func(t *testing.T) {
mockDB.EXPECT().GetSSOProviderById(gomock.Any(), int32(1)).Return(baseProvider, nil)

test.Request(t).
WithURLPathVars(urlParams).
OnHandlerFunc(resources.UpdateSSOProvider).
Require().
ResponseStatusCode(http.StatusBadRequest)
})

t.Run("error validating request field", func(t *testing.T) {
mockDB.EXPECT().GetSSOProviderById(gomock.Any(), int32(1)).Return(baseProvider, nil)

test.Request(t).
WithURLPathVars(urlParams).
WithBody(auth.UpsertOIDCProviderRequest{
Name: "test",
Issuer: "1234:not:a:url",
ClientID: "bloodhound",
}).
OnHandlerFunc(resources.UpdateSSOProvider).
Require().
ResponseStatusCode(http.StatusBadRequest)
})

t.Run("error creating oidc provider db entry", func(t *testing.T) {
mockDB.EXPECT().GetSSOProviderById(gomock.Any(), int32(1)).Return(baseProvider, nil)
mockDB.EXPECT().UpdateOIDCProvider(gomock.Any(), gomock.Any()).Return(model.OIDCProvider{}, fmt.Errorf("error"))

test.Request(t).
WithURLPathVars(urlParams).
WithBody(auth.UpsertOIDCProviderRequest{
Name: "test",
Issuer: "https://localhost/auth",
ClientID: "bloodhound",
}).
OnHandlerFunc(resources.UpdateSSOProvider).
Require().
ResponseStatusCode(http.StatusInternalServerError)
})
}
35 changes: 19 additions & 16 deletions cmd/api/src/database/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -300,13 +300,12 @@ func TestDatabase_CreateGetDeleteAuthSecret(t *testing.T) {

func TestDatabase_CreateUpdateDeleteSAMLProvider(t *testing.T) {
var (
ctx = context.Background()
dbInst, user = initAndCreateUser(t)
samlProvider model.SAMLProvider
newSAMLProvider model.SAMLProvider
updatedUser model.User
updatedSAMLProvider model.SAMLProvider
err error
ctx = context.Background()
dbInst, user = initAndCreateUser(t)
samlProvider model.SAMLProvider
newSAMLProvider model.SAMLProvider
updatedUser model.User
err error
)
// Initialize the SAMLProvider without setting SSOProviderID
samlProvider = model.SAMLProvider{
Expand All @@ -333,18 +332,22 @@ func TestDatabase_CreateUpdateDeleteSAMLProvider(t *testing.T) {
} else if updatedUser.SSOProvider.SAMLProvider.IssuerURI != newSAMLProvider.IssuerURI {
t.Fatalf("Updated user has SAMLProvider URL %s when %s was expected", updatedUser.SSOProvider.SAMLProvider.IssuerURI, newSAMLProvider.IssuerURI)
} else {
updatedSAMLProvider = model.SAMLProvider{
Serial: model.Serial{
ID: newSAMLProvider.ID,
updatedSSOProvider := model.SSOProvider{
Name: "updated provider",
Type: model.SessionAuthProviderSAML,
SAMLProvider: &model.SAMLProvider{
Serial: model.Serial{
ID: newSAMLProvider.ID,
},
Name: "updated provider",
DisplayName: newSAMLProvider.DisplayName,
IssuerURI: newSAMLProvider.IssuerURI,
SingleSignOnURI: newSAMLProvider.SingleSignOnURI,
SSOProviderID: newSAMLProvider.SSOProviderID,
},
Name: "updated provider",
DisplayName: newSAMLProvider.DisplayName,
IssuerURI: newSAMLProvider.IssuerURI,
SingleSignOnURI: newSAMLProvider.SingleSignOnURI,
SSOProviderID: newSAMLProvider.SSOProviderID,
}

if err = dbInst.UpdateSAMLIdentityProvider(ctx, updatedSAMLProvider); err != nil {
if _, err = dbInst.UpdateSAMLIdentityProvider(ctx, updatedSSOProvider); err != nil {
t.Fatalf("Failed to update SAML provider: %v", err)
} else if err = test.VerifyAuditLogs(dbInst, model.AuditLogActionUpdateSAMLIdentityProvider, "saml_name", "updated provider"); err != nil {
t.Fatalf("Failed to validate UpdateSAMLIdentityProvider audit logs:\n%v", err)
Expand Down
51 changes: 48 additions & 3 deletions cmd/api/src/database/mocks/db.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

39 changes: 31 additions & 8 deletions cmd/api/src/database/oidc_providers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,27 +26,50 @@ import (

"github.com/specterops/bloodhound/src/model"
"github.com/specterops/bloodhound/src/test/integration"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestBloodhoundDB_CreateOIDCProvider(t *testing.T) {
func TestBloodhoundDB_CreateUpdateOIDCProvider(t *testing.T) {
var (
testCtx = context.Background()
dbInst = integration.SetupDB(t)
)
defer dbInst.Close(testCtx)

t.Run("successfully create an OIDC provider", func(t *testing.T) {
t.Run("successfully create and update an OIDC provider", func(t *testing.T) {
provider, err := dbInst.CreateOIDCProvider(testCtx, "test", "https://test.localhost.com/auth", "bloodhound")
require.NoError(t, err)

assert.Equal(t, "https://test.localhost.com/auth", provider.Issuer)
assert.Equal(t, "bloodhound", provider.ClientID)
assert.NotEmpty(t, provider.ID)
require.Equal(t, "https://test.localhost.com/auth", provider.Issuer)
require.Equal(t, "bloodhound", provider.ClientID)
require.EqualValues(t, 1, provider.ID)

_, count, err := dbInst.ListAuditLogs(testCtx, time.Now().Add(-time.Minute), time.Now().Add(time.Minute), 0, 10, "", model.SQLFilter{})
_, count, err := dbInst.ListAuditLogs(testCtx, time.Now().Add(time.Minute), time.Now().Add(-time.Minute), 0, 10, "", model.SQLFilter{})
require.NoError(t, err)
assert.Equal(t, 4, count)
require.Equal(t, 4, count)

updatedSSOProvider := model.SSOProvider{
Name: "updated provider",
Type: model.SessionAuthProviderOIDC,
OIDCProvider: &model.OIDCProvider{
Serial: model.Serial{
ID: provider.ID,
},
ClientID: "gotham-net",
Issuer: "https://gotham.net",
SSOProviderID: provider.SSOProviderID,
},
}

provider, err = dbInst.UpdateOIDCProvider(testCtx, updatedSSOProvider)
require.NoError(t, err)

require.Equal(t, updatedSSOProvider.OIDCProvider.Issuer, provider.Issuer)
require.Equal(t, updatedSSOProvider.OIDCProvider.ClientID, provider.ClientID)
require.EqualValues(t, updatedSSOProvider.OIDCProvider.ID, provider.ID)

_, count, err = dbInst.ListAuditLogs(testCtx, time.Now().Add(time.Minute), time.Now().Add(-time.Minute), 0, 10, "", model.SQLFilter{})
require.NoError(t, err)
require.Equal(t, 8, count)
})
}

0 comments on commit f058ff7

Please sign in to comment.