Skip to content

Commit

Permalink
Merge pull request #220 from okta/reset-cached-access-token
Browse files Browse the repository at this point in the history
Reset cached access token on `invalid_grant`
  • Loading branch information
monde authored Jul 12, 2024
2 parents 8932921 + 6ed20d5 commit 2fb64ee
Show file tree
Hide file tree
Showing 8 changed files with 196 additions and 112 deletions.
14 changes: 14 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,19 @@
# Changelog

## 2.3.0 (July 12, 2024)

### ENHANCEMENTS

* New command `okta-aws-cli list-profiles` helper to inspect profiles in okta.yaml [#222](https://github.com/okta/okta-aws-cli/pull/222), thanks [@pmgalea](https://github.com/pmgalea)!
* GH releases publish Windows artifact to Chocolatey [#215](https://github.com/okta/okta-aws-cli/pull/215), thanks [@monde](https://github.com/monde)!
* Better retry for when the cached access token has been invalidated outside of okta-aws-cli's control. [#220](https://github.com/okta/okta-aws-cli/pull/220), thanks [@monde](https://github.com/monde)!
* Print a warning at first run if otka.yaml is malformed. [#220](https://github.com/okta/okta-aws-cli/pull/220), thanks [@monde](https://github.com/monde)!

### BUG FIXES

* Correct "default" profile flaw introduced in 2.2.0 release [#220](https://github.com/okta/okta-aws-cli/pull/220), thanks [@monde](https://github.com/monde)!
* Continue polling instead of exit on a 400 "slow_down" API error [#220](https://github.com/okta/okta-aws-cli/pull/220), thanks [@monde](https://github.com/monde)!

## 2.2.0 (July 3, 2024)

### ENHANCEMENTS
Expand Down
6 changes: 2 additions & 4 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,8 @@ build: fmtcheck
go build -o $(GOBIN)/okta-aws-cli cmd/okta-aws-cli/main.go

clean:
go clean -cache -testcache ./...

clean-all:
go clean -cache -testcache -modcache ./...
rm -fr dist/
go clean -testcache

fmt: tools # Format the code
@$(GOFMT) -l -w .
Expand Down
2 changes: 1 addition & 1 deletion cmd/root/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ func init() {
{
Name: config.ProfileFlag,
Short: "p",
Value: "default",
Value: "",
Usage: "AWS Profile",
EnvVar: config.ProfileEnvVar,
},
Expand Down
34 changes: 29 additions & 5 deletions cmd/root/web/web.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (

"github.com/okta/okta-aws-cli/internal/config"
cliFlag "github.com/okta/okta-aws-cli/internal/flag"
"github.com/okta/okta-aws-cli/internal/okta"
"github.com/okta/okta-aws-cli/internal/webssoauth"
)

Expand Down Expand Up @@ -78,20 +79,43 @@ func NewWebCommand() *cobra.Command {
Use: "web",
Short: "Human oriented authentication and device authorization",
RunE: func(cmd *cobra.Command, args []string) error {
config, err := config.EvaluateSettings()
cfg, err := config.EvaluateSettings()
if err != nil {
return err
}
err = cliFlag.CheckRequiredFlags(requiredFlags)

// Warn if there is an issue with okta.yaml
_, err = config.OktaConfig()
if err != nil {
return err
webssoauth.ConsolePrint(cfg, "WARNING: issue with %s file. Run `okta-aws-cli debug` command for additional diagnosis.\nError: %+v\n", config.OktaYaml, err)
}

wsa, err := webssoauth.NewWebSSOAuthentication(config)
err = cliFlag.CheckRequiredFlags(requiredFlags)
if err != nil {
return err
}
return wsa.EstablishIAMCredentials()

for attempt := 1; attempt <= 2; attempt++ {
wsa, err := webssoauth.NewWebSSOAuthentication(cfg)
if err != nil {
break
}

err = wsa.EstablishIAMCredentials()
if err == nil {
break
}

if apiErr, ok := err.(*okta.APIError); ok {
if apiErr.ErrorType == "invalid_grant" && webssoauth.RemoveCachedAccessToken() {
webssoauth.ConsolePrint(cfg, "\nCached access token appears to be stale, removing token and retrying device authorization ...\n\n")
continue
}
break
}
}

return err
},
}

Expand Down
21 changes: 11 additions & 10 deletions internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func init() {

const (
// Version app version
Version = "2.2.0"
Version = "2.3.0"

////////////////////////////////////////////////////////////
// FORMATS
Expand Down Expand Up @@ -420,7 +420,17 @@ func readConfig() (Attributes, error) {
}
}

// config loading order
// 1) command line flags 2) environment variables, 3) .env file
awsProfile := viper.GetString(ProfileFlag)
// mimic AWS CLI behavior, if profile value is not set by flag check
// the ENV VAR, else set to "default"
if awsProfile == "" {
awsProfile = viper.GetString(downCase(ProfileEnvVar))
}
if awsProfile == "" {
awsProfile = "default"
}

attrs := Attributes{
AllProfiles: viper.GetBool(getFlagNameFromProfile(awsProfile, AllProfilesFlag)),
Expand Down Expand Up @@ -454,15 +464,6 @@ func readConfig() (Attributes, error) {
attrs.Format = EnvVarFormat
}

// mimic AWS CLI behavior, if profile value is not set by flag check
// the ENV VAR, else set to "default"
if attrs.Profile == "" {
attrs.Profile = viper.GetString(downCase(ProfileEnvVar))
}
if attrs.Profile == "" {
attrs.Profile = "default"
}

// Viper binds ENV VARs to a lower snake version, set the configs with them
// if they haven't already been set by cli flag binding.
if attrs.OrgDomain == "" {
Expand Down
83 changes: 81 additions & 2 deletions internal/okta/apierror.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,87 @@

package okta

import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"

"github.com/BurntSushi/toml"
)

const (
// APIErrorMessageBase base API error message
APIErrorMessageBase = "the API returned an unknown error"
// APIErrorMessageWithErrorDescription API error message with description
APIErrorMessageWithErrorDescription = "the API returned an error: %s"
// APIErrorMessageWithErrorSummary API error message with summary
APIErrorMessageWithErrorSummary = "the API returned an error: %s"
// HTTPHeaderWwwAuthenticate Www-Authenticate header
HTTPHeaderWwwAuthenticate = "Www-Authenticate"
)

// APIError Wrapper for Okta API error
type APIError struct {
Error string `json:"error,omitempty"`
ErrorDescription string `json:"error_description,omitempty"`
ErrorType string `json:"error"`
ErrorDescription string `json:"error_description"`
ErrorCode string `json:"errorCode,omitempty"`
ErrorSummary string `json:"errorSummary,omitempty" toml:"error_description"`
ErrorLink string `json:"errorLink,omitempty"`
ErrorID string `json:"errorId,omitempty"`
ErrorCauses []map[string]interface{} `json:"errorCauses,omitempty"`
}

// Error String-ify the Error
func (e *APIError) Error() string {
formattedErr := APIErrorMessageBase
if e.ErrorDescription != "" {
formattedErr = fmt.Sprintf(APIErrorMessageWithErrorDescription, e.ErrorDescription)
} else if e.ErrorSummary != "" {
formattedErr = fmt.Sprintf(APIErrorMessageWithErrorSummary, e.ErrorSummary)
}
if len(e.ErrorCauses) > 0 {
var causes []string
for _, cause := range e.ErrorCauses {
for key, val := range cause {
causes = append(causes, fmt.Sprintf("%s: %v", key, val))
}
}
formattedErr = fmt.Sprintf("%s. Causes: %s", formattedErr, strings.Join(causes, ", "))
}
return formattedErr
}

// NewAPIError Constructor for Okta API error, will return nil if the response
// is not an error.
func NewAPIError(resp *http.Response) error {
statusCode := resp.StatusCode
if statusCode >= http.StatusOK && statusCode < http.StatusBadRequest {
return nil
}
e := APIError{}
if (statusCode == http.StatusUnauthorized || statusCode == http.StatusForbidden) &&
strings.Contains(resp.Header.Get(HTTPHeaderWwwAuthenticate), "Bearer") {
for _, v := range strings.Split(resp.Header.Get(HTTPHeaderWwwAuthenticate), ", ") {
if strings.Contains(v, "error_description") {
_, err := toml.Decode(v, &e)
if err != nil {
e.ErrorSummary = "unauthorized"
}
return &e
}
}
}
bodyBytes, _ := io.ReadAll(resp.Body)
copyBodyBytes := make([]byte, len(bodyBytes))
copy(copyBodyBytes, bodyBytes)
_ = resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
_ = json.NewDecoder(bytes.NewReader(copyBodyBytes)).Decode(&e)
if statusCode == http.StatusInternalServerError {
e.ErrorSummary += fmt.Sprintf(", x-okta-request-id=%s", resp.Header.Get("x-okta-request-id"))
}
return &e
}
84 changes: 18 additions & 66 deletions internal/paginator/paginator.go
Original file line number Diff line number Diff line change
@@ -1,24 +1,37 @@
/*
* Copyright (c) 2024-Present, Okta, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package paginator

import (
"bytes"
"encoding/json"
"encoding/xml"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strings"

"github.com/BurntSushi/toml"
"github.com/okta/okta-aws-cli/internal/okta"
)

const (
// HTTPHeaderWwwAuthenticate Www-Authenticate header
HTTPHeaderWwwAuthenticate = "Www-Authenticate"
// APIErrorMessageBase base API error message
APIErrorMessageBase = "the API returned an unknown error"
// APIErrorMessageWithErrorDescription API error message with description
APIErrorMessageWithErrorDescription = "the API returned an error: %s"
// APIErrorMessageWithErrorSummary API error message with summary
Expand Down Expand Up @@ -136,7 +149,7 @@ func newPaginateResponse(r *http.Response, pgntr *Paginator) *PaginateResponse {
func buildPaginateResponse(resp *http.Response, pgntr *Paginator, v interface{}) (*PaginateResponse, error) {
ct := resp.Header.Get("Content-Type")
response := newPaginateResponse(resp, pgntr)
err := checkResponseForError(resp)
err := okta.NewAPIError(resp)
if err != nil {
return response, err
}
Expand Down Expand Up @@ -167,64 +180,3 @@ func buildPaginateResponse(resp *http.Response, pgntr *Paginator, v interface{})
}
return response, nil
}

func checkResponseForError(resp *http.Response) error {
statusCode := resp.StatusCode
if statusCode >= http.StatusOK && statusCode < http.StatusBadRequest {
return nil
}
e := Error{}
if (statusCode == http.StatusUnauthorized || statusCode == http.StatusForbidden) &&
strings.Contains(resp.Header.Get(HTTPHeaderWwwAuthenticate), "Bearer") {
for _, v := range strings.Split(resp.Header.Get(HTTPHeaderWwwAuthenticate), ", ") {
if strings.Contains(v, "error_description") {
_, err := toml.Decode(v, &e)
if err != nil {
e.ErrorSummary = "unauthorized"
}
return &e
}
}
}
bodyBytes, _ := io.ReadAll(resp.Body)
copyBodyBytes := make([]byte, len(bodyBytes))
copy(copyBodyBytes, bodyBytes)
_ = resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
_ = json.NewDecoder(bytes.NewReader(copyBodyBytes)).Decode(&e)
if statusCode == http.StatusInternalServerError {
e.ErrorSummary += fmt.Sprintf(", x-okta-request-id=%s", resp.Header.Get("x-okta-request-id"))
}
return &e
}

// Error A struct for marshalling Okta's API error response bodies
type Error struct {
ErrorMessage string `json:"error"`
ErrorDescription string `json:"error_description"`
ErrorCode string `json:"errorCode,omitempty"`
ErrorSummary string `json:"errorSummary,omitempty" toml:"error_description"`
ErrorLink string `json:"errorLink,omitempty"`
ErrorID string `json:"errorId,omitempty"`
ErrorCauses []map[string]interface{} `json:"errorCauses,omitempty"`
}

// Error String-ify the Error
func (e *Error) Error() string {
formattedErr := APIErrorMessageBase
if e.ErrorDescription != "" {
formattedErr = fmt.Sprintf(APIErrorMessageWithErrorDescription, e.ErrorDescription)
} else if e.ErrorSummary != "" {
formattedErr = fmt.Sprintf(APIErrorMessageWithErrorSummary, e.ErrorSummary)
}
if len(e.ErrorCauses) > 0 {
var causes []string
for _, cause := range e.ErrorCauses {
for key, val := range cause {
causes = append(causes, fmt.Sprintf("%s: %v", key, val))
}
}
formattedErr = fmt.Sprintf("%s. Causes: %s", formattedErr, strings.Join(causes, ", "))
}
return formattedErr
}
Loading

0 comments on commit 2fb64ee

Please sign in to comment.