diff --git a/cmd/api/src/api/v2/auth/saml.go b/cmd/api/src/api/v2/auth/saml.go index 3744b624b9..d1f9ac1a3f 100644 --- a/cmd/api/src/api/v2/auth/saml.go +++ b/cmd/api/src/api/v2/auth/saml.go @@ -23,6 +23,7 @@ import ( "io" "net/http" "strconv" + "strings" "github.com/crewjam/saml" "github.com/crewjam/saml/samlsp" @@ -144,9 +145,7 @@ func (s ManagementResource) CreateSAMLProviderMultipart(response http.ResponseWr api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusBadRequest, err.Error(), request), response) } else if metadata, err := samlsp.ParseMetadata(metadataXML); err != nil { api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusBadRequest, err.Error(), request), response) - } else if ssoDescriptor, err := auth.GetIDPSingleSignOnDescriptor(metadata, saml.HTTPPostBinding); err != nil { - api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusBadRequest, err.Error(), request), response) - } else if ssoURL, err := auth.GetIDPSingleSignOnServiceURL(ssoDescriptor, saml.HTTPPostBinding); err != nil { + } else if ssoURL, err := auth.GetIDPSingleSignOnServiceURL(metadata, saml.HTTPPostBinding); err != nil { api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusBadRequest, "metadata does not have a SSO service that supports HTTP POST binding", request), response) } else { samlIdentityProvider.Name = providerNames[0] @@ -211,9 +210,7 @@ func (s ManagementResource) UpdateSAMLProviderRequest(response http.ResponseWrit api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusBadRequest, err.Error(), request), response) } else if metadata, err := samlsp.ParseMetadata(metadataXML); err != nil { api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusBadRequest, err.Error(), request), response) - } else if ssoDescriptor, err := auth.GetIDPSingleSignOnDescriptor(metadata, saml.HTTPPostBinding); err != nil { - api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusBadRequest, err.Error(), request), response) - } else if ssoURL, err := auth.GetIDPSingleSignOnServiceURL(ssoDescriptor, saml.HTTPPostBinding); err != nil { + } else if ssoURL, err := auth.GetIDPSingleSignOnServiceURL(metadata, saml.HTTPPostBinding); err != nil { api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusBadRequest, "metadata does not have a SSO service that supports HTTP POST binding", request), response) } else { ssoProvider.Name = providerNames[0] @@ -224,6 +221,24 @@ func (s ManagementResource) UpdateSAMLProviderRequest(response http.ResponseWrit ssoProvider.SAMLProvider.IssuerURI = metadata.EntityID ssoProvider.SAMLProvider.SingleSignOnURI = ssoURL + // It's possible to update the ACS url which will be reflected in the metadataXML, we need to guarantee it is set to only what we expect if it is present + if acsUrl, err := auth.GetAssertionConsumerServiceURL(metadata, saml.HTTPPostBinding); err == nil { + api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusBadRequest, "metadata does not have an ACS location that supports HTTP POST binding", request), response) + if !strings.Contains(acsUrl, model.SAMLRootURIVersionMap[ssoProvider.SAMLProvider.RootURIVersion]) { + var validUri bool + for rootUriVersion, path := range model.SAMLRootURIVersionMap { + if strings.Contains(acsUrl, path) { + ssoProvider.SAMLProvider.RootURIVersion = rootUriVersion + validUri = true + break + } + } + if !validUri { + api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusBadRequest, "metadata does not have a valid ACS location", request), response) + } + } + } + if newSAMLProvider, err := s.db.UpdateSAMLIdentityProvider(request.Context(), ssoProvider); err != nil { api.HandleDatabaseError(request, response, err) } else { diff --git a/cmd/api/src/auth/saml.go b/cmd/api/src/auth/saml.go index 9d7a529a10..37bbd8fed7 100644 --- a/cmd/api/src/auth/saml.go +++ b/cmd/api/src/auth/saml.go @@ -28,26 +28,42 @@ import ( "github.com/specterops/bloodhound/src/model" ) -func GetIDPSingleSignOnServiceURL(idp saml.IDPSSODescriptor, bindingType string) (string, error) { - for _, singleSignOnService := range idp.SingleSignOnServices { - if singleSignOnService.Binding == bindingType { - return singleSignOnService.Location, nil +func getIDPSingleSignOnDescriptor(metadata *saml.EntityDescriptor, bindingType string) (saml.IDPSSODescriptor, error) { + for _, idpSSODescriptor := range metadata.IDPSSODescriptors { + for _, singleSignOnService := range idpSSODescriptor.SingleSignOnServices { + if singleSignOnService.Binding == bindingType { + return idpSSODescriptor, nil + } } } - return "", fmt.Errorf("no SSO service defined that supports the %s binding type", bindingType) + return saml.IDPSSODescriptor{}, fmt.Errorf("no SSO service defined that supports the %s binding type", bindingType) } -func GetIDPSingleSignOnDescriptor(metadata *saml.EntityDescriptor, bindingType string) (saml.IDPSSODescriptor, error) { - for _, idpSSODescriptor := range metadata.IDPSSODescriptors { - for _, singleSignOnService := range idpSSODescriptor.SingleSignOnServices { +func GetIDPSingleSignOnServiceURL(metadata *saml.EntityDescriptor, bindingType string) (string, error) { + if ssoDescriptor, err := getIDPSingleSignOnDescriptor(metadata, saml.HTTPPostBinding); err != nil { + return "", err + } else { + for _, singleSignOnService := range ssoDescriptor.SingleSignOnServices { if singleSignOnService.Binding == bindingType { - return idpSSODescriptor, nil + return singleSignOnService.Location, nil } } } + return "", fmt.Errorf("no SSO service defined that supports the %s binding type", bindingType) +} - return saml.IDPSSODescriptor{}, fmt.Errorf("no SSO service defined that supports the %s binding type", bindingType) +// GetAssertionConsumerServiceURL This may not be present, we return the first we find +func GetAssertionConsumerServiceURL(metadata *saml.EntityDescriptor, bindingType string) (string, error) { + for _, spSSODescriptor := range metadata.SPSSODescriptors { + for _, acs := range spSSODescriptor.AssertionConsumerServices { + if acs.Binding == bindingType { + return acs.Location, nil + } + } + } + + return "", fmt.Errorf("no SAML ascertion consumer service url defined in metadata xml") } func NewServiceProvider(hostUrl url.URL, cfg config.Configuration, samlProvider model.SAMLProvider) (saml.ServiceProvider, error) { diff --git a/cmd/api/src/model/samlprovider.go b/cmd/api/src/model/samlprovider.go index 124039c2a6..480e9b002f 100644 --- a/cmd/api/src/model/samlprovider.go +++ b/cmd/api/src/model/samlprovider.go @@ -43,8 +43,13 @@ var ( type SAMLRootURIVersion int var ( - SAMLRootURIVersion1 SAMLRootURIVersion = 1 // "/v2/login/saml/{slug}/" - SAMLRootURIVersion2 SAMLRootURIVersion = 2 // "/v2/sso/{slug}/" + SAMLRootURIVersion1 SAMLRootURIVersion = 1 + SAMLRootURIVersion2 SAMLRootURIVersion = 2 + + SAMLRootURIVersionMap = map[SAMLRootURIVersion]string { + SAMLRootURIVersion1: "/api/v1/login/saml", + SAMLRootURIVersion2: "/api/v2/sso", + } ) type SAMLProvider struct { @@ -142,14 +147,13 @@ func (s SAMLProvider) GetSAMLUserPrincipalNameFromAssertion(assertion *saml.Asse func (s *SAMLProvider) FormatSAMLProviderURLs(hostUrl url.URL) { root := hostUrl + root.Path = path.Join(SAMLRootURIVersionMap[s.RootURIVersion], s.Name) // To preserve existing IDP configurations, existing saml providers still use the old acs endpoint which redirects to the new callback handler switch s.RootURIVersion { case SAMLRootURIVersion1: - root.Path = path.Join("/api/v1/login/saml", s.Name) s.ServiceProviderACSURI = serde.FromURL(*root.JoinPath("acs")) case SAMLRootURIVersion2: - root.Path = path.Join("/api/v2/sso", s.Name) s.ServiceProviderACSURI = serde.FromURL(*root.JoinPath("callback")) }