From f7f2cb0ae7b6763426320814398fbeac28cf63bb Mon Sep 17 00:00:00 2001 From: Mike Mondragon Date: Mon, 8 Jul 2024 14:41:08 -0700 Subject: [PATCH] Better retry for when the cached access token has been invalidated outside of okta-aws-cli's control. Closes #207 Closes #198 --- cmd/root/web/web.go | 26 ++++++++-- internal/okta/apierror.go | 83 +++++++++++++++++++++++++++++- internal/paginator/paginator.go | 84 +++++++------------------------ internal/webssoauth/webssoauth.go | 64 ++++++++++++++--------- 4 files changed, 161 insertions(+), 96 deletions(-) diff --git a/cmd/root/web/web.go b/cmd/root/web/web.go index 7df3b96..c0a8f30 100644 --- a/cmd/root/web/web.go +++ b/cmd/root/web/web.go @@ -21,6 +21,7 @@ import ( "github.com/okta/okta-aws-cli/internal/config" cliFlag "github.com/okta/okta-aws-cli/internal/flag" + "github.com/okta/okta-aws-cli/internal/okta" "github.com/okta/okta-aws-cli/internal/webssoauth" ) @@ -82,16 +83,33 @@ func NewWebCommand() *cobra.Command { if err != nil { return err } + err = cliFlag.CheckRequiredFlags(requiredFlags) if err != nil { return err } - wsa, err := webssoauth.NewWebSSOAuthentication(config) - if err != nil { - return err + for attempt := 1; attempt <= 2; attempt++ { + wsa, err := webssoauth.NewWebSSOAuthentication(config) + if err != nil { + break + } + + err = wsa.EstablishIAMCredentials() + if err == nil { + break + } + + if apiErr, ok := err.(*okta.APIError); ok { + if apiErr.ErrorType == "invalid_grant" && webssoauth.RemoveCachedAccessToken() { + webssoauth.ConsolePrint(config, "\nCached access token appears to be stale, removing token and retrying device authorization ...\n\n") + continue + } + break + } } - return wsa.EstablishIAMCredentials() + + return err }, } diff --git a/internal/okta/apierror.go b/internal/okta/apierror.go index 337e9ed..618c28a 100644 --- a/internal/okta/apierror.go +++ b/internal/okta/apierror.go @@ -16,8 +16,87 @@ package okta +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + + "github.com/BurntSushi/toml" +) + +const ( + // APIErrorMessageBase base API error message + APIErrorMessageBase = "the API returned an unknown error" + // APIErrorMessageWithErrorDescription API error message with description + APIErrorMessageWithErrorDescription = "the API returned an error: %s" + // APIErrorMessageWithErrorSummary API error message with summary + APIErrorMessageWithErrorSummary = "the API returned an error: %s" + // HTTPHeaderWwwAuthenticate Www-Authenticate header + HTTPHeaderWwwAuthenticate = "Www-Authenticate" +) + // APIError Wrapper for Okta API error type APIError struct { - Error string `json:"error,omitempty"` - ErrorDescription string `json:"error_description,omitempty"` + ErrorType string `json:"error"` + ErrorDescription string `json:"error_description"` + ErrorCode string `json:"errorCode,omitempty"` + ErrorSummary string `json:"errorSummary,omitempty" toml:"error_description"` + ErrorLink string `json:"errorLink,omitempty"` + ErrorID string `json:"errorId,omitempty"` + ErrorCauses []map[string]interface{} `json:"errorCauses,omitempty"` +} + +// Error String-ify the Error +func (e *APIError) Error() string { + formattedErr := APIErrorMessageBase + if e.ErrorDescription != "" { + formattedErr = fmt.Sprintf(APIErrorMessageWithErrorDescription, e.ErrorDescription) + } else if e.ErrorSummary != "" { + formattedErr = fmt.Sprintf(APIErrorMessageWithErrorSummary, e.ErrorSummary) + } + if len(e.ErrorCauses) > 0 { + var causes []string + for _, cause := range e.ErrorCauses { + for key, val := range cause { + causes = append(causes, fmt.Sprintf("%s: %v", key, val)) + } + } + formattedErr = fmt.Sprintf("%s. Causes: %s", formattedErr, strings.Join(causes, ", ")) + } + return formattedErr +} + +// NewAPIError Constructor for Okta API error, will return nil if the response +// is not an error. +func NewAPIError(resp *http.Response) error { + statusCode := resp.StatusCode + if statusCode >= http.StatusOK && statusCode < http.StatusBadRequest { + return nil + } + e := APIError{} + if (statusCode == http.StatusUnauthorized || statusCode == http.StatusForbidden) && + strings.Contains(resp.Header.Get(HTTPHeaderWwwAuthenticate), "Bearer") { + for _, v := range strings.Split(resp.Header.Get(HTTPHeaderWwwAuthenticate), ", ") { + if strings.Contains(v, "error_description") { + _, err := toml.Decode(v, &e) + if err != nil { + e.ErrorSummary = "unauthorized" + } + return &e + } + } + } + bodyBytes, _ := io.ReadAll(resp.Body) + copyBodyBytes := make([]byte, len(bodyBytes)) + copy(copyBodyBytes, bodyBytes) + _ = resp.Body.Close() + resp.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) + _ = json.NewDecoder(bytes.NewReader(copyBodyBytes)).Decode(&e) + if statusCode == http.StatusInternalServerError { + e.ErrorSummary += fmt.Sprintf(", x-okta-request-id=%s", resp.Header.Get("x-okta-request-id")) + } + return &e } diff --git a/internal/paginator/paginator.go b/internal/paginator/paginator.go index c80d858..2cf77e3 100644 --- a/internal/paginator/paginator.go +++ b/internal/paginator/paginator.go @@ -1,3 +1,19 @@ +/* + * Copyright (c) 2024-Present, Okta, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package paginator import ( @@ -5,20 +21,17 @@ import ( "encoding/json" "encoding/xml" "errors" - "fmt" "io" "net/http" "net/url" "strings" - "github.com/BurntSushi/toml" + "github.com/okta/okta-aws-cli/internal/okta" ) const ( // HTTPHeaderWwwAuthenticate Www-Authenticate header HTTPHeaderWwwAuthenticate = "Www-Authenticate" - // APIErrorMessageBase base API error message - APIErrorMessageBase = "the API returned an unknown error" // APIErrorMessageWithErrorDescription API error message with description APIErrorMessageWithErrorDescription = "the API returned an error: %s" // APIErrorMessageWithErrorSummary API error message with summary @@ -136,7 +149,7 @@ func newPaginateResponse(r *http.Response, pgntr *Paginator) *PaginateResponse { func buildPaginateResponse(resp *http.Response, pgntr *Paginator, v interface{}) (*PaginateResponse, error) { ct := resp.Header.Get("Content-Type") response := newPaginateResponse(resp, pgntr) - err := checkResponseForError(resp) + err := okta.NewAPIError(resp) if err != nil { return response, err } @@ -167,64 +180,3 @@ func buildPaginateResponse(resp *http.Response, pgntr *Paginator, v interface{}) } return response, nil } - -func checkResponseForError(resp *http.Response) error { - statusCode := resp.StatusCode - if statusCode >= http.StatusOK && statusCode < http.StatusBadRequest { - return nil - } - e := Error{} - if (statusCode == http.StatusUnauthorized || statusCode == http.StatusForbidden) && - strings.Contains(resp.Header.Get(HTTPHeaderWwwAuthenticate), "Bearer") { - for _, v := range strings.Split(resp.Header.Get(HTTPHeaderWwwAuthenticate), ", ") { - if strings.Contains(v, "error_description") { - _, err := toml.Decode(v, &e) - if err != nil { - e.ErrorSummary = "unauthorized" - } - return &e - } - } - } - bodyBytes, _ := io.ReadAll(resp.Body) - copyBodyBytes := make([]byte, len(bodyBytes)) - copy(copyBodyBytes, bodyBytes) - _ = resp.Body.Close() - resp.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) - _ = json.NewDecoder(bytes.NewReader(copyBodyBytes)).Decode(&e) - if statusCode == http.StatusInternalServerError { - e.ErrorSummary += fmt.Sprintf(", x-okta-request-id=%s", resp.Header.Get("x-okta-request-id")) - } - return &e -} - -// Error A struct for marshalling Okta's API error response bodies -type Error struct { - ErrorMessage string `json:"error"` - ErrorDescription string `json:"error_description"` - ErrorCode string `json:"errorCode,omitempty"` - ErrorSummary string `json:"errorSummary,omitempty" toml:"error_description"` - ErrorLink string `json:"errorLink,omitempty"` - ErrorID string `json:"errorId,omitempty"` - ErrorCauses []map[string]interface{} `json:"errorCauses,omitempty"` -} - -// Error String-ify the Error -func (e *Error) Error() string { - formattedErr := APIErrorMessageBase - if e.ErrorDescription != "" { - formattedErr = fmt.Sprintf(APIErrorMessageWithErrorDescription, e.ErrorDescription) - } else if e.ErrorSummary != "" { - formattedErr = fmt.Sprintf(APIErrorMessageWithErrorSummary, e.ErrorSummary) - } - if len(e.ErrorCauses) > 0 { - var causes []string - for _, cause := range e.ErrorCauses { - for key, val := range cause { - causes = append(causes, fmt.Sprintf("%s: %v", key, val)) - } - } - formattedErr = fmt.Sprintf("%s. Causes: %s", formattedErr, strings.Join(causes, ", ")) - } - return formattedErr -} diff --git a/internal/webssoauth/webssoauth.go b/internal/webssoauth/webssoauth.go index 0d55723..5698c9a 100644 --- a/internal/webssoauth/webssoauth.go +++ b/internal/webssoauth/webssoauth.go @@ -28,7 +28,6 @@ import ( "net/url" "os" osexec "os/exec" - "os/user" "path/filepath" "regexp" "strings" @@ -723,16 +722,9 @@ func (w *WebSSOAuthentication) fetchSSOWebToken(clientID, awsFedAppID string, at return nil, err } - if resp.StatusCode != http.StatusOK { - baseErrStr := "fetching SSO web token received API response %q" - - var apiErr okta.APIError - err = json.NewDecoder(resp.Body).Decode(&apiErr) - if err != nil { - return nil, fmt.Errorf(baseErrStr, resp.Status) - } - - return nil, fmt.Errorf(baseErrStr+okta.AccessTokenErrorFormat, resp.Status, apiErr.Error, apiErr.ErrorDescription) + err = okta.NewAPIError(resp) + if err != nil { + return nil, err } token = &okta.AccessToken{} @@ -956,8 +948,8 @@ func (w *WebSSOAuthentication) accessToken(deviceAuth *okta.DeviceAuthorization) if err != nil { return backoff.Permanent(fmt.Errorf("fetching access token polling received unexpected API error body %q", string(bodyBytes))) } - if apiErr.Error != "authorization_pending" { - return backoff.Permanent(fmt.Errorf("fetching access token polling received unexpected API polling error %q - %q", apiErr.Error, apiErr.ErrorDescription)) + if apiErr.ErrorType != "authorization_pending" { + return backoff.Permanent(fmt.Errorf("fetching access token polling received unexpected API polling error %q - %q", apiErr.ErrorType, apiErr.ErrorDescription)) } return errors.New("continue polling") @@ -1120,15 +1112,37 @@ func (w *WebSSOAuthentication) isClassicOrg() bool { return false } +// cachedAccessTokenPath Path to the cached access token in $HOME/.okta/awscli-access-token.json +func cachedAccessTokenPath() (string, error) { + homeDir, err := os.UserHomeDir() + if err != nil { + return "", err + } + return filepath.Join(homeDir, dotOktaDir, tokenFileName), nil +} + +// RemoveCachedAccessToken Remove cached access token if it exists. Returns true +// if the file exists was reremoved, swallows errors otherwise. +func RemoveCachedAccessToken() bool { + accessTokenPath, err := cachedAccessTokenPath() + if err != nil { + return false + } + if os.Remove(accessTokenPath) != nil { + return false + } + + return true +} + // cachedAccessToken will returned the cached access token if it exists and is // not expired. func (w *WebSSOAuthentication) cachedAccessToken() (at *okta.AccessToken) { - homeDir, err := os.UserHomeDir() + accessTokenPath, err := cachedAccessTokenPath() if err != nil { return } - configPath := filepath.Join(homeDir, dotOktaDir, tokenFileName) - atJSON, err := os.ReadFile(configPath) + atJSON, err := os.ReadFile(accessTokenPath) if err != nil { return } @@ -1158,15 +1172,12 @@ func (w *WebSSOAuthentication) cacheAccessToken(at *okta.AccessToken) { return } - cUser, err := user.Current() + homeDir, err := os.UserHomeDir() if err != nil { return } - if cUser.HomeDir == "" { - return - } - oktaDir := filepath.Join(cUser.HomeDir, dotOktaDir) + oktaDir := filepath.Join(homeDir, dotOktaDir) // noop if dir exists err = os.MkdirAll(oktaDir, 0o700) if err != nil { @@ -1178,18 +1189,23 @@ func (w *WebSSOAuthentication) cacheAccessToken(at *okta.AccessToken) { return } - configPath := filepath.Join(cUser.HomeDir, dotOktaDir, tokenFileName) + configPath := filepath.Join(homeDir, dotOktaDir, tokenFileName) _ = os.WriteFile(configPath, atJSON, 0o600) } -func (w *WebSSOAuthentication) consolePrint(format string, a ...any) { - if w.config.IsProcessCredentialsFormat() { +// ConsolePrint printf formatted warning messages. +func ConsolePrint(config *config.Config, format string, a ...any) { + if config.IsProcessCredentialsFormat() { return } fmt.Fprintf(os.Stderr, format, a...) } +func (w *WebSSOAuthentication) consolePrint(format string, a ...any) { + ConsolePrint(w.config, format, a...) +} + // fetchAllAWSCredentialsWithSAMLRole Gets all AWS Credentials with an STS Assume Role with SAML AWS API call. func (w *WebSSOAuthentication) fetchAllAWSCredentialsWithSAMLRole(idpRolesMap map[string][]string, assertion, region string) <-chan *oaws.CredentialContainer { ccch := make(chan *oaws.CredentialContainer)