Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adds OIDC dynamic SDK check flag, fix ptr ref to unauthed handler #520

Merged
merged 1 commit into from
Mar 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 62 additions & 26 deletions edge-apis/authwrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -220,27 +220,40 @@ func (a *ApiSessionOidc) GetExpiresAt() *time.Time {
// functionality to the alias type to implement the AuthEnabledApi interface.
type ZitiEdgeManagement struct {
*rest_management_api_client.ZitiEdgeManagement
useOidc bool
versionOnce sync.Once
versionInfo *rest_model.Version
oidcExplicitSet bool
apiUrl *url.URL
// useOidc tracks if OIDC auth should be used
useOidc bool

// useOidcExplicitlySet signals if useOidc was set from an external caller and should be used as is
useOidcExplicitlySet bool

// oidcDynamicallyEnabled will cause the client to check the controller for OIDC support and use if possible as long as useOidc was not explicitly set
oidcDynamicallyEnabled bool //currently defaults false till HA release

versionOnce sync.Once
versionInfo *rest_model.Version

apiUrl *url.URL

TotpCallback func(chan string)
}

func (self *ZitiEdgeManagement) Authenticate(credentials Credentials, configTypes []string, httpClient *http.Client) (ApiSession, error) {
self.versionOnce.Do(func() {
if self.oidcExplicitSet {
if self.useOidcExplicitlySet {
return
}
versionParams := manInfo.NewListVersionParams()

versionResp, _ := self.Informational.ListVersion(versionParams)
if self.oidcDynamicallyEnabled {
versionParams := manInfo.NewListVersionParams()

versionResp, _ := self.Informational.ListVersion(versionParams)

if versionResp != nil {
self.versionInfo = versionResp.Payload.Data
self.useOidc = stringz.Contains(self.versionInfo.Capabilities, string(rest_model.CapabilitiesOIDCAUTH))
if versionResp != nil {
self.versionInfo = versionResp.Payload.Data
self.useOidc = stringz.Contains(self.versionInfo.Capabilities, string(rest_model.CapabilitiesOIDCAUTH))
}
} else {
self.useOidc = false
}
})

Expand Down Expand Up @@ -279,10 +292,14 @@ func (self *ZitiEdgeManagement) oidcAuth(credentials Credentials, configTypes []
}

func (self *ZitiEdgeManagement) SetUseOidc(use bool) {
self.oidcExplicitSet = true
self.useOidcExplicitlySet = true
self.useOidc = use
}

func (self *ZitiEdgeManagement) SetAllowOidcDynamicallyEnabled(allow bool) {
self.oidcDynamicallyEnabled = allow
}

func (self *ZitiEdgeManagement) RefreshApiSession(apiSession ApiSession) (ApiSession, error) {
switch s := apiSession.(type) {
case *ApiSessionLegacy:
Expand Down Expand Up @@ -317,28 +334,39 @@ func (self *ZitiEdgeManagement) ExchangeTokens(curTokens *oidc.Tokens[*oidc.IDTo
// functionality to the alias type to implement the AuthEnabledApi interface.
type ZitiEdgeClient struct {
*rest_client_api_client.ZitiEdgeClient
useOidc bool
versionInfo *rest_model.Version
versionOnce sync.Once
oidcExplicitSet bool
apiUrl *url.URL
// useOidc tracks if OIDC auth should be used
useOidc bool

// useOidcExplicitlySet signals if useOidc was set from an external caller and should be used as is
useOidcExplicitlySet bool

// oidcDynamicallyEnabled will cause the client to check the controller for OIDC support and use if possible as long as useOidc was not explicitly set.
oidcDynamicallyEnabled bool //currently defaults false till HA release

versionInfo *rest_model.Version
versionOnce sync.Once
apiUrl *url.URL

TotpCallback func(chan string)
}

func (self *ZitiEdgeClient) Authenticate(credentials Credentials, configTypes []string, httpClient *http.Client) (ApiSession, error) {
self.versionOnce.Do(func() {
if self.oidcExplicitSet {
if self.useOidcExplicitlySet {
return
}

versionParams := clientInfo.NewListVersionParams()
if self.oidcDynamicallyEnabled {
versionParams := clientInfo.NewListVersionParams()

versionResp, _ := self.Informational.ListVersion(versionParams)
versionResp, _ := self.Informational.ListVersion(versionParams)

if versionResp != nil {
self.versionInfo = versionResp.Payload.Data
self.useOidc = stringz.Contains(self.versionInfo.Capabilities, string(rest_model.CapabilitiesOIDCAUTH))
if versionResp != nil {
self.versionInfo = versionResp.Payload.Data
self.useOidc = stringz.Contains(self.versionInfo.Capabilities, string(rest_model.CapabilitiesOIDCAUTH))
}
} else {
self.useOidc = false
}
})

Expand Down Expand Up @@ -377,21 +405,29 @@ func (self *ZitiEdgeClient) oidcAuth(credentials Credentials, configTypes []stri
}

func (self *ZitiEdgeClient) SetUseOidc(use bool) {
self.oidcExplicitSet = true
self.useOidcExplicitlySet = true
self.useOidc = use
}

func (self *ZitiEdgeClient) SetAllowOidcDynamicallyEnabled(allow bool) {
self.oidcDynamicallyEnabled = allow
}

func (self *ZitiEdgeClient) RefreshApiSession(apiSession ApiSession) (ApiSession, error) {
switch s := apiSession.(type) {
case *ApiSessionLegacy:
params := clientApiSession.NewGetCurrentAPISessionParams()
_, err := self.CurrentAPISession.GetCurrentAPISession(params, s)
newApiSessionDetail, err := self.CurrentAPISession.GetCurrentAPISession(params, s)

if err != nil {
return nil, rest_util.WrapErr(err)
}

return s, nil
newApiSession := &ApiSessionLegacy{
Detail: newApiSessionDetail.Payload.Data,
}

return newApiSession, nil
case *ApiSessionOidc:
tokens, err := self.ExchangeTokens(s.OidcTokens)

Expand Down
12 changes: 12 additions & 0 deletions edge-apis/clients.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,13 @@ type ApiType interface {
}

type OidcEnabledApi interface {
// SetUseOidc forces an API Client to operate in OIDC mode (true) or legacy mode (false). The state of the controller
// is ignored and dynamic enable/disable of OIDC support is suspended.
SetUseOidc(use bool)

// SetAllowOidcDynamicallyEnabled sets whether clients will check the controller for OIDC support or not. If supported
// OIDC is favored over legacy authentication.
SetAllowOidcDynamicallyEnabled(allow bool)
}

// BaseClient implements the Client interface specifically for the types specified in the ApiType constraint. It
Expand Down Expand Up @@ -64,6 +70,12 @@ func (self *BaseClient[A]) SetUseOidc(use bool) {
apiType.SetUseOidc(use)
}

func (self *BaseClient[A]) SetAllowOidcDynamicallyEnabled(allow bool) {
v := any(self.API)
apiType := v.(OidcEnabledApi)
apiType.SetAllowOidcDynamicallyEnabled(allow)
}

// Authenticate will attempt to use the provided credentials to authenticate via the underlying ApiType. On success
// the API Session details will be returned and the current client will make authenticated requests on future
// calls. On an error the API Session in use will be cleared and subsequent requests will become/continue to be
Expand Down
2 changes: 1 addition & 1 deletion ziti/contexts.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ func NewContextWithOpts(cfg *Config, options *Options) (Context, error) {
ConfigTypes: cfg.ConfigTypes,
}

newContext.CtrlClt.ClientApiClient.SetUseOidc(cfg.EnableHa)
newContext.CtrlClt.ClientApiClient.SetAllowOidcDynamicallyEnabled(cfg.EnableHa)
newContext.CtrlClt.PostureCache = posture.NewCache(newContext.CtrlClt, newContext.closeNotify)

return newContext, nil
Expand Down
22 changes: 12 additions & 10 deletions ziti/ziti.go
Original file line number Diff line number Diff line change
Expand Up @@ -410,14 +410,15 @@ func (context *ContextImpl) AddAuthenticationStateFullListener(handler func(Cont

func (context *ContextImpl) AddAuthenticationStateUnauthenticatedListener(handler func(Context, apis.ApiSession)) func() {
listener := func(args ...interface{}) {
apiSession, ok := args[0].(apis.ApiSession)
var apiSession apis.ApiSession

if !ok {
pfxlog.Logger().Fatalf("could not convert args[0] to %T was %T", apiSession, args[0])
}
if args[0] != nil {
var ok bool
apiSession, ok = args[0].(apis.ApiSession)

if apiSession == nil {
pfxlog.Logger().Fatalf("expected arg[0] was nil, unexpected")
if !ok {
pfxlog.Logger().Fatalf("could not convert args[0] to %T was %T", apiSession, args[0])
}
}

handler(context, apiSession)
Expand Down Expand Up @@ -823,16 +824,16 @@ func (context *ContextImpl) GetCurrentIdentityWithBackoff() (*rest_model.Identit
}

func (context *ContextImpl) setUnauthenticated() {
prevApiSession := context.CtrlClt.ApiSession.Swap(nil)
willEmit := prevApiSession != nil
prevApiSessionPtr := context.CtrlClt.ApiSession.Swap(nil)
willEmit := prevApiSessionPtr != nil

context.CtrlClt.ApiSessionCertificate = nil

context.CloseAllEdgeRouterConns()
context.sessions.Clear()

if willEmit {
context.Emit(EventAuthenticationStateUnauthenticated, prevApiSession)
context.Emit(EventAuthenticationStateUnauthenticated, *prevApiSessionPtr)
}
}

Expand Down Expand Up @@ -968,8 +969,9 @@ func (context *ContextImpl) authenticateMfa(code string) error {
if _, err := context.CtrlClt.Refresh(); err != nil {
return err
}
apiSession := context.CtrlClt.GetCurrentApiSession()

if apiSession := context.CtrlClt.GetCurrentApiSession(); apiSession != nil && len(apiSession.GetAuthQueries()) == 0 {
if apiSession != nil && len(apiSession.GetAuthQueries()) == 0 {
return context.onFullAuth(apiSession)
}

Expand Down
Loading