diff --git a/cmd/api/src/api/v2/auth/oidc_test.go b/cmd/api/src/api/v2/auth/oidc_test.go index d9f9db103c..72c88982ac 100644 --- a/cmd/api/src/api/v2/auth/oidc_test.go +++ b/cmd/api/src/api/v2/auth/oidc_test.go @@ -23,15 +23,13 @@ import ( "github.com/specterops/bloodhound/src/api/v2/apitest" "github.com/specterops/bloodhound/src/api/v2/auth" + "github.com/specterops/bloodhound/src/database" "github.com/specterops/bloodhound/src/model" "github.com/specterops/bloodhound/src/utils/test" "go.uber.org/mock/gomock" ) func TestManagementResource_CreateOIDCProvider(t *testing.T) { - const ( - url = "/api/v2/sso/providers/oidc" - ) var ( mockCtrl = gomock.NewController(t) resources, mockDB = apitest.NewAuthManagementResource(mockCtrl) @@ -45,8 +43,6 @@ func TestManagementResource_CreateOIDCProvider(t *testing.T) { }, nil) test.Request(t). - WithMethod(http.MethodPost). - WithURL(url). WithBody(auth.UpsertOIDCProviderRequest{ Name: "Bloodhound gang", Issuer: "https://localhost/auth", @@ -59,9 +55,6 @@ func TestManagementResource_CreateOIDCProvider(t *testing.T) { t.Run("error parsing body request", func(t *testing.T) { test.Request(t). - WithMethod(http.MethodPost). - WithURL(url). - WithBody(""). OnHandlerFunc(resources.CreateOIDCProvider). Require(). ResponseStatusCode(http.StatusBadRequest) @@ -69,8 +62,6 @@ func TestManagementResource_CreateOIDCProvider(t *testing.T) { t.Run("error validating request field", func(t *testing.T) { test.Request(t). - WithMethod(http.MethodPost). - WithURL(url). WithBody(auth.UpsertOIDCProviderRequest{ Name: "test", Issuer: "1234:not:a:url", @@ -86,8 +77,6 @@ func TestManagementResource_CreateOIDCProvider(t *testing.T) { Issuer: "12345:bloodhound", } test.Request(t). - WithMethod(http.MethodPost). - WithURL(url). WithBody(request). OnHandlerFunc(resources.CreateOIDCProvider). Require(). @@ -98,8 +87,6 @@ func TestManagementResource_CreateOIDCProvider(t *testing.T) { mockDB.EXPECT().CreateOIDCProvider(gomock.Any(), "test", "https://localhost/auth", "bloodhound").Return(model.OIDCProvider{}, fmt.Errorf("error")) test.Request(t). - WithMethod(http.MethodPost). - WithURL(url). WithBody(auth.UpsertOIDCProviderRequest{ Name: "test", Issuer: "https://localhost/auth", @@ -110,3 +97,87 @@ func TestManagementResource_CreateOIDCProvider(t *testing.T) { ResponseStatusCode(http.StatusInternalServerError) }) } + +func TestManagementResource_UpdateOIDCProvider(t *testing.T) { + var ( + mockCtrl = gomock.NewController(t) + resources, mockDB = apitest.NewAuthManagementResource(mockCtrl) + baseProvider = model.SSOProvider{ + Type: model.SessionAuthProviderOIDC, + Name: "Gotham Net", + OIDCProvider: &model.OIDCProvider{ + ClientID: "gotham-net", + Issuer: "https://gotham.net", + }, + } + urlParams = map[string]string{"sso_provider_id": "1"} + ) + defer mockCtrl.Finish() + + t.Run("successfully update an OIDCProvider", func(t *testing.T) { + mockDB.EXPECT().GetSSOProviderById(gomock.Any(), int32(1)).Return(baseProvider, nil) + mockDB.EXPECT().UpdateOIDCProvider(gomock.Any(), gomock.Any()) + + test.Request(t). + WithURLPathVars(urlParams). + WithBody(auth.UpsertOIDCProviderRequest{ + Name: "Gotham Net 2", + Issuer: "https://gotham-2.net", + ClientID: "gotham-net-2", + }). + OnHandlerFunc(resources.UpdateSSOProvider). + Require(). + ResponseStatusCode(http.StatusOK) + }) + + t.Run("error not found while updating an unknown OIDCProvider", func(t *testing.T) { + mockDB.EXPECT().GetSSOProviderById(gomock.Any(), int32(1)).Return(model.SSOProvider{}, database.ErrNotFound) + + test.Request(t). + WithURLPathVars(urlParams). + OnHandlerFunc(resources.UpdateSSOProvider). + Require(). + ResponseStatusCode(http.StatusNotFound) + }) + + t.Run("error parsing body request", func(t *testing.T) { + mockDB.EXPECT().GetSSOProviderById(gomock.Any(), int32(1)).Return(baseProvider, nil) + + test.Request(t). + WithURLPathVars(urlParams). + OnHandlerFunc(resources.UpdateSSOProvider). + Require(). + ResponseStatusCode(http.StatusBadRequest) + }) + + t.Run("error validating request field", func(t *testing.T) { + mockDB.EXPECT().GetSSOProviderById(gomock.Any(), int32(1)).Return(baseProvider, nil) + + test.Request(t). + WithURLPathVars(urlParams). + WithBody(auth.UpsertOIDCProviderRequest{ + Name: "test", + Issuer: "1234:not:a:url", + ClientID: "bloodhound", + }). + OnHandlerFunc(resources.UpdateSSOProvider). + Require(). + ResponseStatusCode(http.StatusBadRequest) + }) + + t.Run("error creating oidc provider db entry", func(t *testing.T) { + mockDB.EXPECT().GetSSOProviderById(gomock.Any(), int32(1)).Return(baseProvider, nil) + mockDB.EXPECT().UpdateOIDCProvider(gomock.Any(), gomock.Any()).Return(model.OIDCProvider{}, fmt.Errorf("error")) + + test.Request(t). + WithURLPathVars(urlParams). + WithBody(auth.UpsertOIDCProviderRequest{ + Name: "test", + Issuer: "https://localhost/auth", + ClientID: "bloodhound", + }). + OnHandlerFunc(resources.UpdateSSOProvider). + Require(). + ResponseStatusCode(http.StatusInternalServerError) + }) +} diff --git a/cmd/api/src/database/auth_test.go b/cmd/api/src/database/auth_test.go index 8ec2bcfeca..ec586c25e0 100644 --- a/cmd/api/src/database/auth_test.go +++ b/cmd/api/src/database/auth_test.go @@ -300,13 +300,12 @@ func TestDatabase_CreateGetDeleteAuthSecret(t *testing.T) { func TestDatabase_CreateUpdateDeleteSAMLProvider(t *testing.T) { var ( - ctx = context.Background() - dbInst, user = initAndCreateUser(t) - samlProvider model.SAMLProvider - newSAMLProvider model.SAMLProvider - updatedUser model.User - updatedSAMLProvider model.SAMLProvider - err error + ctx = context.Background() + dbInst, user = initAndCreateUser(t) + samlProvider model.SAMLProvider + newSAMLProvider model.SAMLProvider + updatedUser model.User + err error ) // Initialize the SAMLProvider without setting SSOProviderID samlProvider = model.SAMLProvider{ @@ -333,18 +332,22 @@ func TestDatabase_CreateUpdateDeleteSAMLProvider(t *testing.T) { } else if updatedUser.SSOProvider.SAMLProvider.IssuerURI != newSAMLProvider.IssuerURI { t.Fatalf("Updated user has SAMLProvider URL %s when %s was expected", updatedUser.SSOProvider.SAMLProvider.IssuerURI, newSAMLProvider.IssuerURI) } else { - updatedSAMLProvider = model.SAMLProvider{ - Serial: model.Serial{ - ID: newSAMLProvider.ID, + updatedSSOProvider := model.SSOProvider{ + Name: "updated provider", + Type: model.SessionAuthProviderSAML, + SAMLProvider: &model.SAMLProvider{ + Serial: model.Serial{ + ID: newSAMLProvider.ID, + }, + Name: "updated provider", + DisplayName: newSAMLProvider.DisplayName, + IssuerURI: newSAMLProvider.IssuerURI, + SingleSignOnURI: newSAMLProvider.SingleSignOnURI, + SSOProviderID: newSAMLProvider.SSOProviderID, }, - Name: "updated provider", - DisplayName: newSAMLProvider.DisplayName, - IssuerURI: newSAMLProvider.IssuerURI, - SingleSignOnURI: newSAMLProvider.SingleSignOnURI, - SSOProviderID: newSAMLProvider.SSOProviderID, } - if err = dbInst.UpdateSAMLIdentityProvider(ctx, updatedSAMLProvider); err != nil { + if _, err = dbInst.UpdateSAMLIdentityProvider(ctx, updatedSSOProvider); err != nil { t.Fatalf("Failed to update SAML provider: %v", err) } else if err = test.VerifyAuditLogs(dbInst, model.AuditLogActionUpdateSAMLIdentityProvider, "saml_name", "updated provider"); err != nil { t.Fatalf("Failed to validate UpdateSAMLIdentityProvider audit logs:\n%v", err) diff --git a/cmd/api/src/database/mocks/db.go b/cmd/api/src/database/mocks/db.go index 271430e0dc..2482616a01 100644 --- a/cmd/api/src/database/mocks/db.go +++ b/cmd/api/src/database/mocks/db.go @@ -1593,6 +1593,20 @@ func (mr *MockDatabaseMockRecorder) SweepSessions(arg0 interface{}) *gomock.Call return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SweepSessions", reflect.TypeOf((*MockDatabase)(nil).SweepSessions), arg0) } +// TerminateUserSessionsBySSOProvider mocks base method. +func (m *MockDatabase) TerminateUserSessionsBySSOProvider(arg0 context.Context, arg1 model.SSOProvider) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "TerminateUserSessionsBySSOProvider", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// TerminateUserSessionsBySSOProvider indicates an expected call of TerminateUserSessionsBySSOProvider. +func (mr *MockDatabaseMockRecorder) TerminateUserSessionsBySSOProvider(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TerminateUserSessionsBySSOProvider", reflect.TypeOf((*MockDatabase)(nil).TerminateUserSessionsBySSOProvider), arg0, arg1) +} + // UpdateAssetGroup mocks base method. func (m *MockDatabase) UpdateAssetGroup(arg0 context.Context, arg1 model.AssetGroup) error { m.ctrl.T.Helper() @@ -1664,12 +1678,28 @@ func (mr *MockDatabaseMockRecorder) UpdateFileUploadJob(arg0, arg1 interface{}) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateFileUploadJob", reflect.TypeOf((*MockDatabase)(nil).UpdateFileUploadJob), arg0, arg1) } +// UpdateOIDCProvider mocks base method. +func (m *MockDatabase) UpdateOIDCProvider(arg0 context.Context, arg1 model.SSOProvider) (model.OIDCProvider, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateOIDCProvider", arg0, arg1) + ret0, _ := ret[0].(model.OIDCProvider) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// UpdateOIDCProvider indicates an expected call of UpdateOIDCProvider. +func (mr *MockDatabaseMockRecorder) UpdateOIDCProvider(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateOIDCProvider", reflect.TypeOf((*MockDatabase)(nil).UpdateOIDCProvider), arg0, arg1) +} + // UpdateSAMLIdentityProvider mocks base method. -func (m *MockDatabase) UpdateSAMLIdentityProvider(arg0 context.Context, arg1 model.SAMLProvider) error { +func (m *MockDatabase) UpdateSAMLIdentityProvider(arg0 context.Context, arg1 model.SSOProvider) (model.SAMLProvider, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "UpdateSAMLIdentityProvider", arg0, arg1) - ret0, _ := ret[0].(error) - return ret0 + ret0, _ := ret[0].(model.SAMLProvider) + ret1, _ := ret[1].(error) + return ret0, ret1 } // UpdateSAMLIdentityProvider indicates an expected call of UpdateSAMLIdentityProvider. @@ -1678,6 +1708,21 @@ func (mr *MockDatabaseMockRecorder) UpdateSAMLIdentityProvider(arg0, arg1 interf return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateSAMLIdentityProvider", reflect.TypeOf((*MockDatabase)(nil).UpdateSAMLIdentityProvider), arg0, arg1) } +// UpdateSSOProvider mocks base method. +func (m *MockDatabase) UpdateSSOProvider(arg0 context.Context, arg1 model.SSOProvider) (model.SSOProvider, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateSSOProvider", arg0, arg1) + ret0, _ := ret[0].(model.SSOProvider) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// UpdateSSOProvider indicates an expected call of UpdateSSOProvider. +func (mr *MockDatabaseMockRecorder) UpdateSSOProvider(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateSSOProvider", reflect.TypeOf((*MockDatabase)(nil).UpdateSSOProvider), arg0, arg1) +} + // UpdateSavedQuery mocks base method. func (m *MockDatabase) UpdateSavedQuery(arg0 context.Context, arg1 model.SavedQuery) (model.SavedQuery, error) { m.ctrl.T.Helper() diff --git a/cmd/api/src/database/oidc_providers_test.go b/cmd/api/src/database/oidc_providers_test.go index 28b6d716aa..dfad1db7a8 100644 --- a/cmd/api/src/database/oidc_providers_test.go +++ b/cmd/api/src/database/oidc_providers_test.go @@ -26,27 +26,50 @@ import ( "github.com/specterops/bloodhound/src/model" "github.com/specterops/bloodhound/src/test/integration" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -func TestBloodhoundDB_CreateOIDCProvider(t *testing.T) { +func TestBloodhoundDB_CreateUpdateOIDCProvider(t *testing.T) { var ( testCtx = context.Background() dbInst = integration.SetupDB(t) ) defer dbInst.Close(testCtx) - t.Run("successfully create an OIDC provider", func(t *testing.T) { + t.Run("successfully create and update an OIDC provider", func(t *testing.T) { provider, err := dbInst.CreateOIDCProvider(testCtx, "test", "https://test.localhost.com/auth", "bloodhound") require.NoError(t, err) - assert.Equal(t, "https://test.localhost.com/auth", provider.Issuer) - assert.Equal(t, "bloodhound", provider.ClientID) - assert.NotEmpty(t, provider.ID) + require.Equal(t, "https://test.localhost.com/auth", provider.Issuer) + require.Equal(t, "bloodhound", provider.ClientID) + require.EqualValues(t, 1, provider.ID) - _, count, err := dbInst.ListAuditLogs(testCtx, time.Now().Add(-time.Minute), time.Now().Add(time.Minute), 0, 10, "", model.SQLFilter{}) + _, count, err := dbInst.ListAuditLogs(testCtx, time.Now().Add(time.Minute), time.Now().Add(-time.Minute), 0, 10, "", model.SQLFilter{}) require.NoError(t, err) - assert.Equal(t, 4, count) + require.Equal(t, 4, count) + + updatedSSOProvider := model.SSOProvider{ + Name: "updated provider", + Type: model.SessionAuthProviderOIDC, + OIDCProvider: &model.OIDCProvider{ + Serial: model.Serial{ + ID: provider.ID, + }, + ClientID: "gotham-net", + Issuer: "https://gotham.net", + SSOProviderID: provider.SSOProviderID, + }, + } + + provider, err = dbInst.UpdateOIDCProvider(testCtx, updatedSSOProvider) + require.NoError(t, err) + + require.Equal(t, updatedSSOProvider.OIDCProvider.Issuer, provider.Issuer) + require.Equal(t, updatedSSOProvider.OIDCProvider.ClientID, provider.ClientID) + require.EqualValues(t, updatedSSOProvider.OIDCProvider.ID, provider.ID) + + _, count, err = dbInst.ListAuditLogs(testCtx, time.Now().Add(time.Minute), time.Now().Add(-time.Minute), 0, 10, "", model.SQLFilter{}) + require.NoError(t, err) + require.Equal(t, 8, count) }) }