diff --git a/.gitignore b/.gitignore index e549f085..e4a85bea 100644 --- a/.gitignore +++ b/.gitignore @@ -2,5 +2,5 @@ github_deploy_key *.json *.jwt -/go.work -/go.work.sum +go.work +go.work.sum diff --git a/edge-apis/authwrapper.go b/edge-apis/authwrapper.go index 956a6bd1..7ce55025 100644 --- a/edge-apis/authwrapper.go +++ b/edge-apis/authwrapper.go @@ -10,10 +10,12 @@ import ( "github.com/go-resty/resty/v2" "github.com/openziti/edge-api/rest_client_api_client" clientAuth "github.com/openziti/edge-api/rest_client_api_client/authentication" + clientControllers "github.com/openziti/edge-api/rest_client_api_client/controllers" clientApiSession "github.com/openziti/edge-api/rest_client_api_client/current_api_session" clientInfo "github.com/openziti/edge-api/rest_client_api_client/informational" "github.com/openziti/edge-api/rest_management_api_client" manAuth "github.com/openziti/edge-api/rest_management_api_client/authentication" + manControllers "github.com/openziti/edge-api/rest_management_api_client/controllers" manCurApiSession "github.com/openziti/edge-api/rest_management_api_client/current_api_session" manInfo "github.com/openziti/edge-api/rest_management_api_client/informational" "github.com/openziti/edge-api/rest_model" @@ -24,7 +26,6 @@ import ( "github.com/zitadel/oidc/v2/pkg/client/tokenexchange" "github.com/zitadel/oidc/v2/pkg/oidc" "golang.org/x/oauth2" - "net" "net/http" "net/url" "strings" @@ -45,6 +46,9 @@ type AuthEnabledApi interface { //http client if not provided. Authenticate(credentials Credentials, configTypes []string, httpClient *http.Client) (ApiSession, error) SetUseOidc(bool) + ListControllers() (*rest_model.ControllersList, error) + GetClientTransportPool() ClientTransportPool + SetClientTransportPool(ClientTransportPool) } type ApiSession interface { @@ -217,6 +221,8 @@ func (a *ApiSessionOidc) GetExpiresAt() *time.Time { return nil } +var _ AuthEnabledApi = (*ZitiEdgeManagement)(nil) + // ZitiEdgeManagement is an alias of the go-swagger generated client that allows this package to add additional // functionality to the alias type to implement the AuthEnabledApi interface. type ZitiEdgeManagement struct { @@ -233,9 +239,26 @@ type ZitiEdgeManagement struct { versionOnce sync.Once versionInfo *rest_model.Version - apiUrl *url.URL + TotpCallback func(chan string) + ClientTransportPool ClientTransportPool +} + +func (self *ZitiEdgeManagement) SetClientTransportPool(transportPool ClientTransportPool) { + self.ClientTransportPool = transportPool +} - TotpCallback func(chan string) +func (self *ZitiEdgeManagement) GetClientTransportPool() ClientTransportPool { + return self.ClientTransportPool +} + +func (self *ZitiEdgeManagement) ListControllers() (*rest_model.ControllersList, error) { + params := manControllers.NewListControllersParams() + resp, err := self.Controllers.ListControllers(params, nil) + if err != nil { + return nil, err + } + + return &resp.GetPayload().Data, nil } func (self *ZitiEdgeManagement) Authenticate(credentials Credentials, configTypes []string, httpClient *http.Client) (ApiSession, error) { @@ -288,8 +311,8 @@ func (self *ZitiEdgeManagement) legacyAuth(credentials Credentials, configTypes return &ApiSessionLegacy{Detail: resp.GetPayload().Data}, err } -func (self *ZitiEdgeManagement) oidcAuth(credentials Credentials, configTypes []string, httpClient *http.Client) (ApiSession, error) { - return oidcAuth(self.apiUrl.Host, credentials, configTypes, httpClient, self.TotpCallback) +func (self *ZitiEdgeManagement) oidcAuth(credentials Credentials, configTypeOverrides []string, httpClient *http.Client) (ApiSession, error) { + return oidcAuth(self.ClientTransportPool, credentials, configTypeOverrides, httpClient, self.TotpCallback) } func (self *ZitiEdgeManagement) SetUseOidc(use bool) { @@ -328,9 +351,11 @@ func (self *ZitiEdgeManagement) RefreshApiSession(apiSession ApiSession, httpCli } func (self *ZitiEdgeManagement) ExchangeTokens(curTokens *oidc.Tokens[*oidc.IDTokenClaims], httpClient *http.Client) (*oidc.Tokens[*oidc.IDTokenClaims], error) { - return exchangeTokens(getBaseUrl(self.apiUrl), curTokens, httpClient) + return exchangeTokens(self.ClientTransportPool, curTokens, httpClient) } +var _ AuthEnabledApi = (*ZitiEdgeClient)(nil) + // ZitiEdgeClient is an alias of the go-swagger generated client that allows this package to add additional // functionality to the alias type to implement the AuthEnabledApi interface. type ZitiEdgeClient struct { @@ -346,12 +371,30 @@ type ZitiEdgeClient struct { versionInfo *rest_model.Version versionOnce sync.Once - apiUrl *url.URL - TotpCallback func(chan string) + TotpCallback func(chan string) + ClientTransportPool ClientTransportPool +} + +func (self *ZitiEdgeClient) GetClientTransportPool() ClientTransportPool { + return self.ClientTransportPool +} + +func (self *ZitiEdgeClient) SetClientTransportPool(transportPool ClientTransportPool) { + self.ClientTransportPool = transportPool } -func (self *ZitiEdgeClient) Authenticate(credentials Credentials, configTypes []string, httpClient *http.Client) (ApiSession, error) { +func (self *ZitiEdgeClient) ListControllers() (*rest_model.ControllersList, error) { + params := clientControllers.NewListControllersParams() + resp, err := self.Controllers.ListControllers(params, nil) + if err != nil { + return nil, err + } + + return &resp.GetPayload().Data, nil +} + +func (self *ZitiEdgeClient) Authenticate(credentials Credentials, configTypesOverrides []string, httpClient *http.Client) (ApiSession, error) { self.versionOnce.Do(func() { if self.useOidcExplicitlySet { return @@ -372,10 +415,10 @@ func (self *ZitiEdgeClient) Authenticate(credentials Credentials, configTypes [] }) if self.useOidc { - return self.oidcAuth(credentials, configTypes, httpClient) + return self.oidcAuth(credentials, configTypesOverrides, httpClient) } - return self.legacyAuth(credentials, configTypes, httpClient) + return self.legacyAuth(credentials, configTypesOverrides, httpClient) } func (self *ZitiEdgeClient) legacyAuth(credentials Credentials, configTypes []string, httpClient *http.Client) (ApiSession, error) { @@ -401,8 +444,8 @@ func (self *ZitiEdgeClient) legacyAuth(credentials Credentials, configTypes []st return &ApiSessionLegacy{Detail: resp.GetPayload().Data}, err } -func (self *ZitiEdgeClient) oidcAuth(credentials Credentials, configTypes []string, httpClient *http.Client) (ApiSession, error) { - return oidcAuth(self.apiUrl.Host, credentials, configTypes, httpClient, self.TotpCallback) +func (self *ZitiEdgeClient) oidcAuth(credentials Credentials, configTypeOverrides []string, httpClient *http.Client) (ApiSession, error) { + return oidcAuth(self.ClientTransportPool, credentials, configTypeOverrides, httpClient, self.TotpCallback) } func (self *ZitiEdgeClient) SetUseOidc(use bool) { @@ -445,59 +488,67 @@ func (self *ZitiEdgeClient) RefreshApiSession(apiSession ApiSession, httpClient } func (self *ZitiEdgeClient) ExchangeTokens(curTokens *oidc.Tokens[*oidc.IDTokenClaims], httpClient *http.Client) (*oidc.Tokens[*oidc.IDTokenClaims], error) { - return exchangeTokens(getBaseUrl(self.apiUrl), curTokens, httpClient) + return exchangeTokens(self.ClientTransportPool, curTokens, httpClient) } -func getBaseUrl(apiUrl *url.URL) string { - urlCopy := *apiUrl - urlCopy.Path = "" - return urlCopy.String() -} +func exchangeTokens(clientTransportPool ClientTransportPool, curTokens *oidc.Tokens[*oidc.IDTokenClaims], client *http.Client) (*oidc.Tokens[*oidc.IDTokenClaims], error) { -func exchangeTokens(issuer string, curTokens *oidc.Tokens[*oidc.IDTokenClaims], client *http.Client) (*oidc.Tokens[*oidc.IDTokenClaims], error) { - te, err := tokenexchange.NewTokenExchanger(issuer, tokenexchange.WithHTTPClient(client)) + var outTokens *oidc.Tokens[*oidc.IDTokenClaims] - if err != nil { - return nil, err - } + _, err := clientTransportPool.TryTransportForF(func(transport *ApiClientTransport) (any, error) { + apiHost := transport.ApiUrl.Host + te, err := tokenexchange.NewTokenExchanger(apiHost, tokenexchange.WithHTTPClient(client)) - accessResp, err := tokenexchange.ExchangeToken(te, curTokens.RefreshToken, oidc.RefreshTokenType, "", "", nil, nil, nil, oidc.AccessTokenType) + if err != nil { + return nil, err + } - if err != nil { - return nil, err - } + accessResp, err := tokenexchange.ExchangeToken(te, curTokens.RefreshToken, oidc.RefreshTokenType, "", "", nil, nil, nil, oidc.AccessTokenType) + + if err != nil { + return nil, err + } - //TODO: be smarter, only refresh refresh token if the new access token lives beyond refresh - refreshResp, err := tokenexchange.ExchangeToken(te, curTokens.RefreshToken, oidc.RefreshTokenType, "", "", nil, nil, nil, oidc.RefreshTokenType) + //TODO: be smarter, only refresh refresh token if the new access token lives beyond refresh + refreshResp, err := tokenexchange.ExchangeToken(te, curTokens.RefreshToken, oidc.RefreshTokenType, "", "", nil, nil, nil, oidc.RefreshTokenType) - if err != nil { - return nil, err - } + if err != nil { + return nil, err + } - idResp, err := tokenexchange.ExchangeToken(te, curTokens.RefreshToken, oidc.RefreshTokenType, "", "", nil, nil, nil, oidc.IDTokenType) + idResp, err := tokenexchange.ExchangeToken(te, curTokens.RefreshToken, oidc.RefreshTokenType, "", "", nil, nil, nil, oidc.IDTokenType) - if err != nil { - return nil, err - } + if err != nil { + return nil, err + } - idClaims := &oidc.IDTokenClaims{} + idClaims := &oidc.IDTokenClaims{} + + err = json.Unmarshal([]byte(idResp.AccessToken), idClaims) + + if err != nil { + return nil, err + } + + outTokens = &oidc.Tokens[*oidc.IDTokenClaims]{ + Token: &oauth2.Token{ + AccessToken: accessResp.AccessToken, + TokenType: accessResp.TokenType, + RefreshToken: refreshResp.RefreshToken, + Expiry: time.Time{}, + }, + IDTokenClaims: idClaims, + IDToken: idResp.AccessToken, //access token is used to hold id token per zitadel comments + } - err = json.Unmarshal([]byte(idResp.AccessToken), idClaims) + return outTokens, nil + }) if err != nil { return nil, err } - return &oidc.Tokens[*oidc.IDTokenClaims]{ - Token: &oauth2.Token{ - AccessToken: accessResp.AccessToken, - TokenType: accessResp.TokenType, - RefreshToken: refreshResp.RefreshToken, - Expiry: time.Time{}, - }, - IDTokenClaims: idClaims, - IDToken: idResp.AccessToken, //access token is used to hold id token per zitadel comments - }, nil + return outTokens, nil } type authPayload struct { @@ -532,14 +583,18 @@ func (a *authPayload) toMap() map[string]string { return result } -func oidcAuth(issuer string, credentials Credentials, configTypes []string, httpClient *http.Client, totpCallback func(chan string)) (ApiSession, error) { +func oidcAuth(clientTransportPool ClientTransportPool, credentials Credentials, configTypeOverrides []string, httpClient *http.Client, totpCallback func(chan string)) (ApiSession, error) { payload := &authPayload{ Authenticate: credentials.Payload(), } method := credentials.Method() - payload.ConfigTypes = configTypes + + if configTypeOverrides != nil { + payload.ConfigTypes = configTypeOverrides + } certs := credentials.TlsCerts() + if len(certs) != 0 { if transport, ok := httpClient.Transport.(*http.Transport); ok { transport.TLSClientConfig.Certificates = certs @@ -547,112 +602,118 @@ func oidcAuth(issuer string, credentials Credentials, configTypes []string, http } } - rpServer, err := newLocalRpServer(issuer, method) - - if err != nil { - return nil, err - } - - rpServer.Start() - defer rpServer.Stop() - - client := resty.NewWithClient(httpClient) - apiHost := issuer - if host, _, err := net.SplitHostPort(issuer); err == nil { - apiHost = host - } - client.SetRedirectPolicy(resty.DomainCheckRedirectPolicy("127.0.0.1", "localhost", apiHost)) - resp, err := client.R().Get(rpServer.LoginUri) - - if err != nil { - return nil, err - } + var outTokens *oidc.Tokens[*oidc.IDTokenClaims] - if resp.StatusCode() != http.StatusOK { - return nil, fmt.Errorf("local rp login response is expected to be HTTP status %d got %d with body: %s", http.StatusOK, resp.StatusCode(), resp.Body()) - } - payload.AuthRequestId = resp.Header().Get(AuthRequestIdHeader) + _, err := clientTransportPool.TryTransportForF(func(transport *ApiClientTransport) (any, error) { + rpServer, err := newLocalRpServer(transport.ApiUrl.Host, method) - if payload.AuthRequestId == "" { - return nil, errors.New("could not find auth request id header") - } + if err != nil { + return nil, err + } - opLoginUri := "https://" + resp.RawResponse.Request.URL.Host + "/oidc/login/" + method - totpUri := "https://" + resp.RawResponse.Request.URL.Host + "/oidc/login/totp" + rpServer.Start() + defer rpServer.Stop() - formData := payload.toMap() + client := resty.NewWithClient(httpClient) + apiHost := transport.ApiUrl.Hostname() - req := client.R() - clientRequest := asClientRequest(req, client) + client.SetRedirectPolicy(resty.DomainCheckRedirectPolicy("127.0.0.1", "localhost", apiHost)) + resp, err := client.R().Get(rpServer.LoginUri) - err = credentials.AuthenticateRequest(clientRequest, strfmt.Default) + if err != nil { + return nil, err + } - if err != nil { - return nil, err - } + if resp.StatusCode() != http.StatusOK { + return nil, fmt.Errorf("local rp login response is expected to be HTTP status %d got %d with body: %s", http.StatusOK, resp.StatusCode(), resp.Body()) + } + payload.AuthRequestId = resp.Header().Get(AuthRequestIdHeader) - resp, err = req.SetFormData(formData).Post(opLoginUri) + if payload.AuthRequestId == "" { + return nil, errors.New("could not find auth request id header") + } - if err != nil { - return nil, err - } + opLoginUri := "https://" + resp.RawResponse.Request.URL.Host + "/oidc/login/" + method + totpUri := "https://" + resp.RawResponse.Request.URL.Host + "/oidc/login/totp" - if resp.StatusCode() != http.StatusOK { - return nil, fmt.Errorf("remote op login response is expected to be HTTP status %d got %d with body: %s", http.StatusOK, resp.StatusCode(), resp.Body()) - } + formData := payload.toMap() - authRequestId := resp.Header().Get(AuthRequestIdHeader) - totpRequiredHeader := resp.Header().Get(TotpRequiredHeader) - totpRequired := totpRequiredHeader != "" - totpCode := "" + req := client.R() + clientRequest := asClientRequest(req, client) - if totpRequired { + err = credentials.AuthenticateRequest(clientRequest, strfmt.Default) - if totpCallback == nil { - return nil, errors.New("totp is required but not totp callback was defined") - } - codeChan := make(chan string) - go totpCallback(codeChan) - - select { - case code := <-codeChan: - totpCode = code - case <-time.After(30 * time.Minute): - return nil, fmt.Errorf("timedout waiting for totpT callback") + if err != nil { + return nil, err } - resp, err = client.R().SetBody(&totpCodePayload{ - MfaCode: rest_model.MfaCode{ - Code: &totpCode, - }, - AuthRequestId: authRequestId, - }).Post(totpUri) + resp, err = req.SetFormData(formData).Post(opLoginUri) if err != nil { return nil, err } if resp.StatusCode() != http.StatusOK { - apiErr := &errorz.ApiError{} - err = json.Unmarshal(resp.Body(), apiErr) + return nil, fmt.Errorf("remote op login response is expected to be HTTP status %d got %d with body: %s", http.StatusOK, resp.StatusCode(), resp.Body()) + } + + authRequestId := resp.Header().Get(AuthRequestIdHeader) + totpRequiredHeader := resp.Header().Get(TotpRequiredHeader) + totpRequired := totpRequiredHeader != "" + totpCode := "" + + if totpRequired { + + if totpCallback == nil { + return nil, errors.New("totp is required but not totp callback was defined") + } + codeChan := make(chan string) + go totpCallback(codeChan) + + select { + case code := <-codeChan: + totpCode = code + case <-time.After(30 * time.Minute): + return nil, fmt.Errorf("timedout waiting for totpT callback") + } + + resp, err = client.R().SetBody(&totpCodePayload{ + MfaCode: rest_model.MfaCode{ + Code: &totpCode, + }, + AuthRequestId: authRequestId, + }).Post(totpUri) if err != nil { - return nil, fmt.Errorf("could not verify TOTP MFA code recieved %d - could not parse body: %s", resp.StatusCode(), string(resp.Body())) + return nil, err } - return nil, apiErr + if resp.StatusCode() != http.StatusOK { + apiErr := &errorz.ApiError{} + err = json.Unmarshal(resp.Body(), apiErr) + if err != nil { + return nil, fmt.Errorf("could not verify TOTP MFA code recieved %d - could not parse body: %s", resp.StatusCode(), string(resp.Body())) + } + + return nil, apiErr + + } } - } - var outTokens *oidc.Tokens[*oidc.IDTokenClaims] + tokens := <-rpServer.TokenChan - tokens := <-rpServer.TokenChan + if tokens == nil { + return nil, errors.New("authentication did not complete, received nil tokens") + } + outTokens = tokens - if tokens == nil { - return nil, errors.New("authentication did not complete, received nil tokens") + return nil, nil + }) + + if err != nil { + return nil, err } - outTokens = tokens return &ApiSessionOidc{ OidcTokens: outTokens, diff --git a/edge-apis/clients.go b/edge-apis/clients.go index 8e70645d..b7c01349 100644 --- a/edge-apis/clients.go +++ b/edge-apis/clients.go @@ -19,11 +19,14 @@ package edge_apis import ( "crypto/x509" "github.com/go-openapi/runtime" + openapiclient "github.com/go-openapi/runtime/client" "github.com/go-openapi/strfmt" + "github.com/michaelquigley/pfxlog" "github.com/openziti/edge-api/rest_client_api_client" "github.com/openziti/edge-api/rest_management_api_client" - "github.com/pkg/errors" + "net/http" "net/url" + "strings" "sync/atomic" ) @@ -47,11 +50,25 @@ type OidcEnabledApi interface { // BaseClient implements the Client interface specifically for the types specified in the ApiType constraint. It // provides shared functionality that all ApiType types require. type BaseClient[A ApiType] struct { - API *A + API *A + AuthEnabledApi AuthEnabledApi Components - AuthInfoWriter runtime.ClientAuthInfoWriter - ApiSession atomic.Pointer[ApiSession] - Credentials Credentials + AuthInfoWriter runtime.ClientAuthInfoWriter + ApiSession atomic.Pointer[ApiSession] + Credentials Credentials + ApiUrls []*url.URL + ApiBinding string + ApiVersion string + Schemes []string + onControllerListeners []func([]*url.URL) +} + +func (self *BaseClient[A]) Url() url.URL { + return *self.AuthEnabledApi.GetClientTransportPool().GetActiveTransport().ApiUrl +} + +func (self *BaseClient[A]) AddOnControllerUpdateListeners(listener func([]*url.URL)) { + self.onControllerListeners = append(self.onControllerListeners, listener) } // GetCurrentApiSession returns the ApiSession that is being used to authenticate requests. @@ -80,72 +97,109 @@ func (self *BaseClient[A]) SetAllowOidcDynamicallyEnabled(allow bool) { // the API Session details will be returned and the current client will make authenticated requests on future // calls. On an error the API Session in use will be cleared and subsequent requests will become/continue to be // made in an unauthenticated fashion. -func (self *BaseClient[A]) Authenticate(credentials Credentials, configTypes []string) (ApiSession, error) { - //casting to `any` works around golang error that happens when type asserting a generic typed field - myAny := any(self.API) - if a, ok := myAny.(AuthEnabledApi); ok { - self.Credentials = nil - self.ApiSession.Store(nil) - - if credCaPool := credentials.GetCaPool(); credCaPool != nil { - self.HttpTransport.TLSClientConfig.RootCAs = credCaPool - } else { - self.HttpTransport.TLSClientConfig.RootCAs = self.Components.CaPool - } - - apiSession, err := a.Authenticate(credentials, configTypes, self.HttpClient) +func (self *BaseClient[A]) Authenticate(credentials Credentials, configTypesOverride []string) (ApiSession, error) { - if err != nil { - return nil, err - } + self.Credentials = nil + self.ApiSession.Store(nil) - self.Credentials = credentials - self.ApiSession.Store(&apiSession) + if credCaPool := credentials.GetCaPool(); credCaPool != nil { + self.HttpTransport.TLSClientConfig.RootCAs = credCaPool + } else { + self.HttpTransport.TLSClientConfig.RootCAs = self.Components.CaPool + } - self.Runtime.DefaultAuthentication = runtime.ClientAuthInfoWriterFunc(func(request runtime.ClientRequest, registry strfmt.Registry) error { - currentSessionPtr := self.ApiSession.Load() - if currentSessionPtr != nil { - currentSession := *currentSessionPtr + apiSession, err := self.AuthEnabledApi.Authenticate(credentials, configTypesOverride, self.HttpClient) - if currentSession != nil && currentSession.GetToken() != nil { - if err := currentSession.AuthenticateRequest(request, registry); err != nil { - return err - } - } - } + if err != nil { + return nil, err + } - if self.Credentials != nil { - if err := self.Credentials.AuthenticateRequest(request, registry); err != nil { - return err - } - } + self.Credentials = credentials + self.ApiSession.Store(&apiSession) - return nil - }) + self.ProcessControllers(self.AuthEnabledApi) - return apiSession, nil - } - return nil, errors.New("authentication not supported") + return apiSession, nil } // initializeComponents assembles the lower level components necessary for the go-swagger/openapi facilities. -func (self *BaseClient[A]) initializeComponents(apiUrl *url.URL, schemes []string, authInfoWriter runtime.ClientAuthInfoWriter, caPool *x509.CertPool) { - components := NewComponents(apiUrl, schemes) +func (self *BaseClient[A]) initializeComponents(apiUrls []*url.URL, caPool *x509.CertPool) { + components := NewComponents() components.HttpTransport.TLSClientConfig.RootCAs = caPool - components.Runtime.DefaultAuthentication = authInfoWriter components.CaPool = caPool + self.Components = *components } +func NewRuntime(apiUrl *url.URL, schemes []string, httpClient *http.Client) *openapiclient.Runtime { + return openapiclient.NewWithClient(apiUrl.Host, apiUrl.Path, schemes, httpClient) +} + // AuthenticateRequest implements the openapi runtime.ClientAuthInfoWriter interface from the OpenAPI libraries. It is used // to authenticate outgoing requests. func (self *BaseClient[A]) AuthenticateRequest(request runtime.ClientRequest, registry strfmt.Registry) error { if self.AuthInfoWriter != nil { return self.AuthInfoWriter.AuthenticateRequest(request, registry) } + + // do not add auth to authenticating endpoints + if strings.Contains(request.GetPath(), "/oidc/auth") || strings.Contains(request.GetPath(), "/authenticate") { + return nil + } + + currentSessionPtr := self.ApiSession.Load() + if currentSessionPtr != nil { + currentSession := *currentSessionPtr + + if currentSession != nil && currentSession.GetToken() != nil { + if err := currentSession.AuthenticateRequest(request, registry); err != nil { + return err + } + } + } + + if self.Credentials != nil { + if err := self.Credentials.AuthenticateRequest(request, registry); err != nil { + return err + } + } + return nil } +func (self *BaseClient[A]) ProcessControllers(authEnabledApi AuthEnabledApi) { + list, err := authEnabledApi.ListControllers() + + if err != nil { + pfxlog.Logger().WithError(err).Error("error listing controllers, continuing with 1 default configured controller") + return + } + + if list == nil || len(*list) <= 1 { + pfxlog.Logger().Info("no additional controllers reported, continuing with 1 default configured controller") + return + } + + //look for matching api binding and versions + for _, controller := range *list { + apis := controller.APIAddresses[self.ApiBinding] + + for _, apiAddr := range apis { + if apiAddr.Version == self.ApiVersion { + apiUrl, parseErr := url.Parse(apiAddr.URL) + if parseErr == nil { + self.AuthEnabledApi.GetClientTransportPool().Add(apiUrl, NewRuntime(apiUrl, self.Schemes, self.HttpClient)) + } + } + } + } + + apis := self.AuthEnabledApi.GetClientTransportPool().GetApiUrls() + for _, listener := range self.onControllerListeners { + listener(apis) + } +} + // ManagementApiClient provides the ability to authenticate and interact with the Edge Management API. type ManagementApiClient struct { BaseClient[ZitiEdgeManagement] @@ -162,19 +216,31 @@ type ManagementApiClient struct { // For OpenZiti instances not using publicly signed certificates, `ziti.GetControllerWellKnownCaPool()` can be used // to obtain and verify the target controllers CAs. Tools should allow users to verify and accept new controllers // that have not been verified from an outside secret (such as an enrollment token). -func NewManagementApiClient(apiUrl *url.URL, caPool *x509.CertPool, totpCallback func(chan string)) *ManagementApiClient { +func NewManagementApiClient(apiUrls []*url.URL, caPool *x509.CertPool, totpCallback func(chan string)) *ManagementApiClient { ret := &ManagementApiClient{} + ret.Schemes = rest_management_api_client.DefaultSchemes + ret.ApiBinding = "edge-management" + ret.ApiVersion = "v1" + ret.ApiUrls = apiUrls + ret.initializeComponents(apiUrls, caPool) + + transportPool := NewClientTransportPoolRandom() + + for _, apiUrl := range apiUrls { + newRuntime := NewRuntime(apiUrl, ret.Schemes, ret.Components.HttpClient) + newRuntime.DefaultAuthentication = ret + transportPool.Add(apiUrl, newRuntime) + } - ret.initializeComponents(apiUrl, rest_management_api_client.DefaultSchemes, ret, caPool) - - newApi := rest_management_api_client.New(ret.Components.Runtime, nil) + newApi := rest_management_api_client.New(transportPool, nil) api := ZitiEdgeManagement{ - ZitiEdgeManagement: newApi, - apiUrl: apiUrl, - TotpCallback: totpCallback, + ZitiEdgeManagement: newApi, + TotpCallback: totpCallback, + ClientTransportPool: transportPool, } ret.API = &api + ret.AuthEnabledApi = &api return ret } @@ -194,18 +260,31 @@ type ClientApiClient struct { // For OpenZiti instances not using publicly signed certificates, `ziti.GetControllerWellKnownCaPool()` can be used // to obtain and verify the target controllers CAs. Tools should allow users to verify and accept new controllers // that have not been verified from an outside secret (such as an enrollment token). -func NewClientApiClient(apiUrl *url.URL, caPool *x509.CertPool, totpCallback func(chan string)) *ClientApiClient { +func NewClientApiClient(apiUrls []*url.URL, caPool *x509.CertPool, totpCallback func(chan string)) *ClientApiClient { ret := &ClientApiClient{} + ret.ApiBinding = "edge-client" + ret.ApiVersion = "v1" + ret.Schemes = rest_client_api_client.DefaultSchemes + ret.ApiUrls = apiUrls - ret.initializeComponents(apiUrl, rest_client_api_client.DefaultSchemes, ret, caPool) + ret.initializeComponents(apiUrls, caPool) + + transportPool := NewClientTransportPoolRandom() + + for _, apiUrl := range apiUrls { + newRuntime := NewRuntime(apiUrl, ret.Schemes, ret.Components.HttpClient) + newRuntime.DefaultAuthentication = ret + transportPool.Add(apiUrl, newRuntime) + } - newApi := rest_client_api_client.New(ret.Components.Runtime, nil) + newApi := rest_client_api_client.New(transportPool, nil) api := ZitiEdgeClient{ - ZitiEdgeClient: newApi, - apiUrl: apiUrl, - TotpCallback: totpCallback, + ZitiEdgeClient: newApi, + TotpCallback: totpCallback, + ClientTransportPool: transportPool, } ret.API = &api + ret.AuthEnabledApi = &api return ret } diff --git a/edge-apis/component.go b/edge-apis/component.go index 321ada93..2ce21e6b 100644 --- a/edge-apis/component.go +++ b/edge-apis/component.go @@ -2,11 +2,9 @@ package edge_apis import ( "crypto/x509" - openapiclient "github.com/go-openapi/runtime/client" "github.com/openziti/edge-api/rest_util" "net/http" "net/http/cookiejar" - "net/url" "time" ) @@ -14,14 +12,13 @@ import ( // components are interconnected and have references to each other. This struct is used to set, move, and manage // them as a set. type Components struct { - Runtime *openapiclient.Runtime HttpClient *http.Client HttpTransport *http.Transport CaPool *x509.CertPool } // NewComponents assembles a new set of components with reasonable production defaults. -func NewComponents(api *url.URL, schemes []string) *Components { +func NewComponents() *Components { tlsClientConfig, _ := rest_util.NewTlsConfig() httpTransport := &http.Transport{ @@ -43,10 +40,7 @@ func NewComponents(api *url.URL, schemes []string) *Components { Timeout: 10 * time.Second, } - apiRuntime := openapiclient.NewWithClient(api.Host, api.Path, schemes, httpClient) - return &Components{ - Runtime: apiRuntime, HttpClient: httpClient, HttpTransport: httpTransport, } diff --git a/edge-apis/pool.go b/edge-apis/pool.go new file mode 100644 index 00000000..859f3c25 --- /dev/null +++ b/edge-apis/pool.go @@ -0,0 +1,264 @@ +/* + Copyright 2019 NetFoundry Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package edge_apis + +import ( + "github.com/go-openapi/runtime" + "github.com/michaelquigley/pfxlog" + cmap "github.com/orcaman/concurrent-map/v2" + errors "github.com/pkg/errors" + "golang.org/x/exp/rand" + "net" + "net/url" + "sync/atomic" + "time" +) + +type ApiClientTransport struct { + runtime.ClientTransport + ApiUrl *url.URL +} + +// ClientTransportPool abstracts the concept of multiple `runtime.ClientTransport` (openapi interface) representing one +// target OpenZiti network. In situations where controllers are running in HA mode (multiple controllers) this +// interface can attempt to try different controller during outages or partitioning. +type ClientTransportPool interface { + runtime.ClientTransport + + Add(apiUrl *url.URL, transport runtime.ClientTransport) + Remove(apiUrl *url.URL) + + GetActiveTransport() *ApiClientTransport + SetActiveTransport(*ApiClientTransport) + GetApiUrls() []*url.URL + IterateTransportsRandomly() chan<- *ApiClientTransport + + TryTransportsForOp(operation *runtime.ClientOperation) (any, error) + TryTransportForF(cb func(*ApiClientTransport) (any, error)) (any, error) +} + +var _ runtime.ClientTransport = (ClientTransportPool)(nil) +var _ ClientTransportPool = (*ClientTransportPoolRandom)(nil) + +// ClientTransportPoolRandom selects a client transport (controller) at random until it is unreachable. Controllers +// are tried at random until a controller is reached. The newly connected controller is set for use on future requests +// until is too becomes unreachable. +type ClientTransportPoolRandom struct { + pool cmap.ConcurrentMap[string, *ApiClientTransport] + current atomic.Pointer[ApiClientTransport] +} + +func (c *ClientTransportPoolRandom) IterateTransportsRandomly() chan<- *ApiClientTransport { + channel := make(chan *ApiClientTransport, 1) + + go func() { + var transports []*ApiClientTransport + + for tpl := range c.pool.IterBuffered() { + transports = append(transports, tpl.Val) + } + + for len(transports) > 0 { + var selected *ApiClientTransport + selected, transports = selectAndRemoveRandom(transports, nil) + + if selected != nil { + channel <- selected + } + } + }() + + return channel +} + +func (c *ClientTransportPoolRandom) GetApiUrls() []*url.URL { + var result []*url.URL + + for tpl := range c.pool.IterBuffered() { + result = append(result, tpl.Val.ApiUrl) + } + + return result +} + +func (c *ClientTransportPoolRandom) GetActiveTransport() *ApiClientTransport { + active := c.current.Load() + if active == nil { + active = c.AnyTransport() + c.SetActiveTransport(active) + } + + return active +} + +func (c *ClientTransportPoolRandom) GetApiClientTransports() []*ApiClientTransport { + var result []*ApiClientTransport + + for tpl := range c.pool.IterBuffered() { + result = append(result, tpl.Val) + } + + return result +} + +func NewClientTransportPoolRandom() *ClientTransportPoolRandom { + return &ClientTransportPoolRandom{ + pool: cmap.New[*ApiClientTransport](), + current: atomic.Pointer[ApiClientTransport]{}, + } +} + +func (c *ClientTransportPoolRandom) SetActiveTransport(transport *ApiClientTransport) { + pfxlog.Logger().WithField("key", transport.ApiUrl.String()).Info("setting active controller") + c.current.Store(transport) +} + +func (c *ClientTransportPoolRandom) Add(apiUrl *url.URL, transport runtime.ClientTransport) { + c.pool.Set(apiUrl.String(), &ApiClientTransport{ + ClientTransport: transport, + ApiUrl: apiUrl, + }) +} + +func (c *ClientTransportPoolRandom) Remove(apiUrl *url.URL) { + c.pool.Remove(apiUrl.String()) +} + +func (c *ClientTransportPoolRandom) Submit(operation *runtime.ClientOperation) (any, error) { + return c.TryTransportsForOp(operation) +} + +func (c *ClientTransportPoolRandom) TryTransportsForOp(operation *runtime.ClientOperation) (any, error) { + result, err := c.TryTransportForF(func(transport *ApiClientTransport) (any, error) { + return transport.Submit(operation) + }) + + return result, err +} + +func (c *ClientTransportPoolRandom) IterateRandomTransport() <-chan *ApiClientTransport { + var transportsToTry []*cmap.Tuple[string, *ApiClientTransport] + for tpl := range c.pool.IterBuffered() { + transportsToTry = append(transportsToTry, &tpl) + } + + ch := make(chan *ApiClientTransport, len(transportsToTry)) + + go func() { + for len(transportsToTry) > 0 { + var transportTpl *cmap.Tuple[string, *ApiClientTransport] + transportTpl, transportsToTry = selectAndRemoveRandom(transportsToTry, nil) + ch <- transportTpl.Val + } + }() + + return ch +} + +func (c *ClientTransportPoolRandom) TryTransportForF(cb func(*ApiClientTransport) (any, error)) (any, error) { + //try active first if we have it + active := c.GetActiveTransport() + activeKey := "" + + if active != nil { + activeKey = active.ApiUrl.String() + result, err := cb(active) + + if err == nil || !errorIndicatesControllerSwap(err) { + pfxlog.Logger().WithError(err).Debugf("determine a network error did not occur on (%T) returning", err) + return result, err + } + + pfxlog.Logger().WithError(err).Debugf("encountered network error (%T) while submitting request", err) + + if c.pool.Count() == 1 { + pfxlog.Logger().Debug("active transport failed, only 1 transport in pool") + + return result, err + } + } + + // either no active or active failed, lets start trying them at random + pfxlog.Logger().Debug("trying random transports from pool") + + ch := c.IterateRandomTransport() + + var lastResult any + lastErr := errors.New("no transports to try, active transport already failed or was nil") //default err should never be returned + attempts := 0 + for transport := range ch { + // skip the already attempted active key + if activeKey != "" && transport.ApiUrl.String() == activeKey { + continue + } + + attempts = attempts + 1 + lastResult, lastErr = cb(transport) + + if lastErr == nil { + c.SetActiveTransport(transport) + return lastResult, nil + } + } + + return lastResult, lastErr +} + +func (c *ClientTransportPoolRandom) AnyTransport() *ApiClientTransport { + rand.Seed(uint64(time.Now().UnixNano())) + transportBuffer := c.pool.Items() + var keys []string + + for key := range transportBuffer { + keys = append(keys, key) + } + + if len(keys) == 0 { + return nil + } + index := rand.Intn(len(keys)) + return transportBuffer[keys[index]] +} + +var _ runtime.ClientTransport = (*ClientTransportPoolRandom)(nil) +var _ ClientTransportPool = (*ClientTransportPoolRandom)(nil) + +var opError = &net.OpError{} + +func errorIndicatesControllerSwap(err error) bool { + pfxlog.Logger().WithError(err).Debugf("checking for network errror on type (%T) and its wrapped errors", err) + + if errors.As(err, &opError) { + pfxlog.Logger().Debug("detected net.OpError") + return true + } + + //others? rate limiting? http timeout? + + return false +} + +func selectAndRemoveRandom[T any](slice []T, zero T) (selected T, modifiedSlice []T) { + rand.Seed(uint64(time.Now().UnixNano())) + if len(slice) == 0 { + return zero, slice + } + index := rand.Intn(len(slice)) + selected = slice[index] + modifiedSlice = append(slice[:index], slice[index+1:]...) + return selected, modifiedSlice +} diff --git a/ziti/client.go b/ziti/client.go index 36bece10..82034397 100644 --- a/ziti/client.go +++ b/ziti/client.go @@ -107,7 +107,7 @@ func (self *CtrlClient) Authenticate() (apis.ApiSession, error) { self.ApiSessionCertificate = nil - apiSession, err := self.ClientApiClient.Authenticate(self.Credentials, self.ConfigTypes) + apiSession, err := self.ClientApiClient.Authenticate(self.Credentials, nil) if err != nil { return nil, rest_util.WrapErr(err) diff --git a/ziti/contexts.go b/ziti/contexts.go index 69cef095..65dccab8 100644 --- a/ziti/contexts.go +++ b/ziti/contexts.go @@ -90,7 +90,9 @@ func NewContextWithOpts(cfg *Config, options *Options) (Context, error) { } if cfg.ID.Cert != "" && cfg.ID.Key != "" { - cfg.Credentials = edge_apis.NewIdentityCredentialsFromConfig(cfg.ID) + idCredentials := edge_apis.NewIdentityCredentialsFromConfig(cfg.ID) + idCredentials.ConfigTypes = cfg.ConfigTypes + cfg.Credentials = idCredentials } else if cfg.Credentials == nil { return nil, errors.New("either cfg.ID or cfg.Credentials must be provided") } @@ -114,7 +116,7 @@ func NewContextWithOpts(cfg *Config, options *Options) (Context, error) { } newContext.CtrlClt = &CtrlClient{ - ClientApiClient: edge_apis.NewClientApiClient(apiUrls[0], cfg.Credentials.GetCaPool(), func(codeCh chan string) { + ClientApiClient: edge_apis.NewClientApiClient(apiUrls, cfg.Credentials.GetCaPool(), func(codeCh chan string) { provider := rest_model.MfaProvidersZiti authQuery := &rest_model.AuthQueryDetail{ @@ -146,5 +148,9 @@ func NewContextWithOpts(cfg *Config, options *Options) (Context, error) { newContext.CtrlClt.ClientApiClient.SetAllowOidcDynamicallyEnabled(cfg.EnableHa) newContext.CtrlClt.PostureCache = posture.NewCache(newContext.CtrlClt, newContext.closeNotify) + newContext.CtrlClt.AddOnControllerUpdateListeners(func(urls []*url.URL) { + newContext.Emit(EventControllerUrlsUpdated, urls) + }) + return newContext, nil } diff --git a/ziti/events.go b/ziti/events.go index 65cd3887..0a4dfc0c 100644 --- a/ziti/events.go +++ b/ziti/events.go @@ -101,6 +101,13 @@ const ( // 1) Context - the context that triggered the listener // 2) apiSession *rest_model.CurrentApiSessionDetail - details of the invalid API Session EventAuthenticationStateUnauthenticated = events.EventName("auth-state-unauthenticated") + + // EventControllerUrlsUpdated is emitted when a new set of controllers is detected + // + // Arguments: + // 1) Context - the context that triggered the listener + // 2) apiUrls []*urls.URL - the URLs of the API for the available controllers + EventControllerUrlsUpdated = events.EventName("controller-urls-updated") ) // Eventer provides types methods for adding event listeners to a context and exposes some weakly typed functions diff --git a/ziti/ziti.go b/ziti/ziti.go index cf78a03b..3e32d41f 100644 --- a/ziti/ziti.go +++ b/ziti/ziti.go @@ -33,6 +33,7 @@ import ( "math" "math/rand" "net" + "net/url" "reflect" "strconv" "strings" @@ -431,6 +432,29 @@ func (context *ContextImpl) AddAuthenticationStateUnauthenticatedListener(handle } } +func (context *ContextImpl) AddControllerUrlsUpdateListener(handler func(Context, []*url.URL)) func() { + listener := func(args ...interface{}) { + var apiUrls []*url.URL + + if args[0] != nil { + var ok bool + apiUrls, ok = args[0].([]*url.URL) + + if !ok { + pfxlog.Logger().Fatalf("could not convert args[0] to %T was %T", apiUrls, args[0]) + } + } + + handler(context, apiUrls) + } + + context.AddListener(EventControllerUrlsUpdated, listener) + + return func() { + context.RemoveListener(EventAuthenticationStateUnauthenticated, listener) + } +} + func (context *ContextImpl) Events() Eventer { return context }