Skip to content

Commit

Permalink
Adds AWS Region as a CLI flag for explicitly setting that value.
Browse files Browse the repository at this point in the history
Closes #160
Closes #161

Merge remote-tracking branch 'euchen-circle/feature_add_region_flag' into pr_161_euchen-circle
  • Loading branch information
monde committed Feb 10, 2024
2 parents 416a0f0 + 0f315a0 commit 11f890a
Show file tree
Hide file tree
Showing 9 changed files with 74 additions and 23 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,7 @@ These global settings are optional unless marked otherwise:

| Name | Description | Command line flag | ENV var and .env file value |
|-----|-----|-----|-----|
| AWS Region (**optional**) | AWS region (will override ENV VAR `AWS_REGION` and `AWS_DEFAULT_REGION`) e.g. `us-east-2` | `--aws-region [value]` | `OKTA_AWSCLI_AWS_REGION` |
| Okta Org Domain (**required**) | Full host and domain name of the Okta org e.g. `test.okta.com` or the custom domain value | `--org-domain [value]` | `OKTA_AWSCLI_ORG_DOMAIN` |
| OIDC Client ID (**required**) | For `web` the OIDC native application / [Allowed Web SSO Client ID](#allowed-web-sso-client-id), for `m2m` the API services app ID | `--oidc-client-id [value]` | `OKTA_AWSCLI_OIDC_CLIENT_ID` |
| AWS IAM Role ARN (**optional** for `web`, **required** for `m2m`) | For web preselects the role list to this preferred IAM role for the given IAM Identity Provider. For `m2m` | `--aws-iam-role [value]` | `OKTA_AWSCLI_IAM_ROLE` |
Expand Down
7 changes: 7 additions & 0 deletions cmd/root/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,13 @@ func init() {
Usage: "Output format. [env-var|aws-credentials|process-credentials]",
EnvVar: config.FormatEnvVar,
},
{
Name: config.AWSRegionFlag,
Short: "n",
Value: "",
Usage: "Preset AWS Region",
EnvVar: config.AWSRegionEnvVar,
},
{
Name: config.AWSCredentialsFlag,
Short: "w",
Expand Down
2 changes: 2 additions & 0 deletions internal/aws/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
// the different credentials formats
type CredentialContainer struct {
AccessKeyID string
Region string
SecretAccessKey string
SessionToken string
Expiration *time.Time
Expand All @@ -44,6 +45,7 @@ type EnvVarCredential struct {
// credentials file
type CredsFileCredential struct {
AccessKeyID string `ini:"aws_access_key_id"`
Region string `ini:"region"`
SecretAccessKey string `ini:"aws_secret_access_key"`
SessionToken string `ini:"aws_session_token"`

Expand Down
19 changes: 17 additions & 2 deletions internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ const (
AWSIAMIdPFlag = "aws-iam-idp"
// AWSIAMRoleFlag cli flag const
AWSIAMRoleFlag = "aws-iam-role"
// AWSRegion cli flag const
AWSRegionFlag = "aws-region"
// CustomScopeFlag cli flag const
CustomScopeFlag = "custom-scope"
// DebugFlag cli flag const
Expand Down Expand Up @@ -113,6 +115,8 @@ const (
AWSIAMRoleEnvVar = "OKTA_AWSCLI_IAM_ROLE"
// AWSSessionDurationEnvVar env var const
AWSSessionDurationEnvVar = "OKTA_AWSCLI_SESSION_DURATION"
// AWSRegionEnvVar env var const
AWSRegionEnvVar = "OKTA_AWSCLI_AWS_REGION"
// CacheAccessTokenEnvVar env var const
CacheAccessTokenEnvVar = "OKTA_AWSCLI_CACHE_ACCESS_TOKEN"
// CustomScopeEnvVar env var const
Expand Down Expand Up @@ -194,6 +198,7 @@ type Config struct {
awsCredentials string
awsIAMIdP string
awsIAMRole string
awsRegion string
awsSessionDuration int64
cacheAccessToken bool
customScope string
Expand Down Expand Up @@ -225,6 +230,7 @@ type Attributes struct {
AWSCredentials string
AWSIAMIdP string
AWSIAMRole string
AWSRegion string
AWSSessionDuration int64
CacheAccessToken bool
CustomScope string
Expand Down Expand Up @@ -268,8 +274,7 @@ func NewConfig(attrs *Attributes) (*Config, error) {
awsCredentials: attrs.AWSCredentials,
awsIAMIdP: attrs.AWSIAMIdP,
awsIAMRole: attrs.AWSIAMRole,
cacheAccessToken: attrs.CacheAccessToken,
customScope: attrs.CustomScope,
awsRegion: attrs.AWSRegion,
debug: attrs.Debug,
debugAPICalls: attrs.DebugAPICalls,
expiryAWSVariables: attrs.ExpiryAWSVariables,
Expand Down Expand Up @@ -318,6 +323,7 @@ func readConfig() (Attributes, error) {
AWSIAMIdP: viper.GetString(AWSIAMIdPFlag),
AWSIAMRole: viper.GetString(AWSIAMRoleFlag),
AWSSessionDuration: viper.GetInt64(SessionDurationFlag),
AWSRegion: viper.GetString(AWSRegionFlag),
CustomScope: viper.GetString(CustomScopeFlag),
Debug: viper.GetBool(DebugFlag),
DebugAPICalls: viper.GetBool(DebugAPICallsFlag),
Expand Down Expand Up @@ -551,6 +557,15 @@ func (c *Config) SetAWSIAMRole(role string) error {
return nil
}

func (c *Config) SetAWSRegion(region string) error {
c.awsRegion = region
return nil
}

func (c *Config) AWSRegion() string {
return c.awsRegion
}

// AWSSessionDuration --
func (c *Config) AWSSessionDuration() int64 {
return c.awsSessionDuration
Expand Down
4 changes: 4 additions & 0 deletions internal/m2mauth/m2mauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,10 @@ func (m *M2MAuthentication) EstablishIAMCredentials() error {

func (m *M2MAuthentication) awsAssumeRoleWithWebIdentity(at *okta.AccessToken) (cc *oaws.CredentialContainer, err error) {
awsCfg := aws.NewConfig().WithHTTPClient(m.config.HTTPClient())
region := m.config.AWSRegion()
if region != "" {
awsCfg = awsCfg.WithRegion(region)
}
sess, err := session.NewSession(awsCfg)
if err != nil {
return
Expand Down
27 changes: 21 additions & 6 deletions internal/output/aws_credentials_file.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ func ensureConfigExists(filename string, profile string) error {
return nil
}

func saveProfile(filename, profile string, cfc *oaws.CredsFileCredential, legacyVars, expiryVars bool, expiry string) error {
config, err := updateConfig(filename, profile, cfc, legacyVars, expiryVars, expiry)
func saveProfile(filename, profile string, cfc *oaws.CredsFileCredential, legacyVars, expiryVars bool, expiry string, regionVar string) error {
config, err := updateConfig(filename, profile, cfc, legacyVars, expiryVars, expiry, regionVar)
if err != nil {
return err
}
Expand All @@ -80,7 +80,7 @@ func saveProfile(filename, profile string, cfc *oaws.CredsFileCredential, legacy
return nil
}

func updateConfig(filename, profile string, cfc *oaws.CredsFileCredential, legacyVars, expiryVars bool, expiry string) (config *ini.File, err error) {
func updateConfig(filename, profile string, cfc *oaws.CredsFileCredential, legacyVars, expiryVars bool, expiry string, region string) (config *ini.File, err error) {
config, err = ini.Load(filename)
if err != nil {
return
Expand All @@ -102,11 +102,17 @@ func updateConfig(filename, profile string, cfc *oaws.CredsFileCredential, legac
if legacyVars {
builder.AddField(SecurityTokenField, "", `ini:"aws_security_token"`)
}
if region != "" {
builder.AddField(utils.Region, "", `ini:"region"`)
}

instance := builder.Build().New()
reflect.ValueOf(instance).Elem().FieldByName(utils.AccessKeyID).SetString(cfc.AccessKeyID)
reflect.ValueOf(instance).Elem().FieldByName(utils.SecretAccessKey).SetString(cfc.SecretAccessKey)
reflect.ValueOf(instance).Elem().FieldByName(utils.SessionToken).SetString(cfc.SessionToken)
if region != "" {
reflect.ValueOf(instance).Elem().FieldByName(utils.Region).SetString(region)
}

if expiryVars {
reflect.ValueOf(instance).Elem().FieldByName(ExpirationField).SetString(expiry)
Expand All @@ -120,12 +126,12 @@ func updateConfig(filename, profile string, cfc *oaws.CredsFileCredential, legac
return
}

return updateINI(config, profile, legacyVars, expiryVars)
return updateINI(config, profile, legacyVars, expiryVars, region)
}

// updateIni will comment out any keys that are not "aws_access_key_id",
// "aws_secret_access_key", "aws_session_token", "credential_process"
func updateINI(config *ini.File, profile string, legacyVars bool, expiryVars bool) (*ini.File, error) {
func updateINI(config *ini.File, profile string, legacyVars bool, expiryVars bool, region string) (*ini.File, error) {
ignore := []string{
"aws_access_key_id",
"aws_secret_access_key",
Expand All @@ -138,6 +144,9 @@ func updateINI(config *ini.File, profile string, legacyVars bool, expiryVars boo
if expiryVars {
ignore = append(ignore, "x_security_token_expires")
}
if region != "" {
ignore = append(ignore, "region")
}
section := config.Section(profile)
comments := []string{}
for _, name := range section.KeyStrings() {
Expand Down Expand Up @@ -171,6 +180,7 @@ type AWSCredentialsFile struct {
LegacyAWSVariables bool
ExpiryAWSVariables bool
Expiry string
Region string
}

// NewAWSCredentialsFile Creates a new
Expand Down Expand Up @@ -230,6 +240,11 @@ aws_session_token = %s
credArgs = append(credArgs, a.Expiry)
}

if c.AWSRegion() != "" {
creds = fmt.Sprintf("%sregion = %%s\n", creds)
credArgs = append(credArgs, c.AWSRegion())
}

creds = fmt.Sprintf(creds, credArgs...)

_, err = f.WriteString(creds)
Expand All @@ -255,7 +270,7 @@ func (a *AWSCredentialsFile) writeConfig(c *config.Config, cfc *oaws.CredsFileCr
return err
}

return saveProfile(filename, profile, cfc, a.LegacyAWSVariables, a.ExpiryAWSVariables, a.Expiry)
return saveProfile(filename, profile, cfc, a.LegacyAWSVariables, a.ExpiryAWSVariables, a.Expiry, c.AWSRegion())
}

func contains(ignore []string, name string) bool {
Expand Down
4 changes: 2 additions & 2 deletions internal/output/aws_credentials_file_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func TestINIFormatCredentialsContent(t *testing.T) {
SecretAccessKey: "e",
SessionToken: "f",
}
config, err := updateConfig(filename, "test", cfc, false, false, "")
config, err := updateConfig(filename, "test", cfc, false, false, "", "us-east-2")
assert.NoError(t, err)

err = config.SaveTo(filename)
Expand Down Expand Up @@ -141,7 +141,7 @@ aws_security_token = ghi
t.Run(test.name, func(t *testing.T) {
config, err := ini.Load(test.config)
require.NoError(t, err)
ini, err := updateINI(config, "default", test.legacy, false)
ini, err := updateINI(config, "default", test.legacy, false, "us-east-2")
require.NoError(t, err)
section := ini.Section(test.section)
require.Equal(t, len(test.want), len(section.KeyStrings()))
Expand Down
4 changes: 3 additions & 1 deletion internal/utils/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,10 @@ const (

// AccessKeyID AWS creds access key ID
AccessKeyID = "AccessKeyID"
// Region region
Region = "Region"
// SecretAccessKey AWS creds secret access key
SecretAccessKey = "SecretAccessKey"
// SessionToken AWS creds session tokne
// SessionToken AWS creds session token
SessionToken = "SessionToken"
)
29 changes: 17 additions & 12 deletions internal/webssoauth/webssoauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,9 @@ type WebSSOAuthentication struct {

// idpAndRole IdP and role pairs
type idpAndRole struct {
idp string
role string
idp string
role string
region string
}

var stderrIsOutAskOpt = func(options *survey.AskOptions) error {
Expand Down Expand Up @@ -171,7 +172,7 @@ func (w *WebSSOAuthentication) EstablishIAMCredentials() error {
}
if w.config.FedAppID() != "" {
// Alternate path when operator knows their AWS Fed app ID
err = w.establishTokenWithFedAppID(clientID, w.config.FedAppID(), at)
err = w.establishTokenWithFedAppID(clientID, w.config.FedAppID(), at, w.config.AWSRegion())
if at != nil && err != nil {
// possible bad cached access token, retry
at = nil
Expand Down Expand Up @@ -210,7 +211,7 @@ AWS Federation App with --aws-acct-fed-app-id FED_APP_ID
// special case, we're going to run the table and get all profiles for all apps
errArr := []error{}
for _, app := range apps {
if err = w.establishTokenWithFedAppID(clientID, app.ID, at); err != nil {
if err = w.establishTokenWithFedAppID(clientID, app.ID, at, w.config.AWSRegion()); err != nil {
errArr = append(errArr, err)
}
}
Expand All @@ -227,7 +228,7 @@ AWS Federation App with --aws-acct-fed-app-id FED_APP_ID
}
}

return w.establishTokenWithFedAppID(clientID, fedAppID, at)
return w.establishTokenWithFedAppID(clientID, fedAppID, at, w.config.AWSRegion())
}

// choiceFriendlyLabelIDP returns a friendly choice for pretty printing IDP
Expand Down Expand Up @@ -310,7 +311,7 @@ func (w *WebSSOAuthentication) selectFedApp(apps []*okta.Application) (string, e
return idps[selected].ID, nil
}

func (w *WebSSOAuthentication) establishTokenWithFedAppID(clientID, fedAppID string, at *okta.AccessToken) error {
func (w *WebSSOAuthentication) establishTokenWithFedAppID(clientID, fedAppID string, at *okta.AccessToken, region string) error {
at, err := w.fetchSSOWebToken(clientID, fedAppID, at)
if err != nil {
return err
Expand All @@ -331,8 +332,9 @@ func (w *WebSSOAuthentication) establishTokenWithFedAppID(clientID, fedAppID str
if err != nil {
return err
}
iar.region = region

cc, err := w.awsAssumeRoleWithSAML(iar, assertion)
cc, err := w.awsAssumeRoleWithSAML(iar, assertion, region)
if err != nil {
return err
}
Expand All @@ -349,7 +351,7 @@ func (w *WebSSOAuthentication) establishTokenWithFedAppID(clientID, fedAppID str
}
}
} else {
ccch := w.fetchAllAWSCredentialsWithSAMLRole(idpRolesMap, assertion)
ccch := w.fetchAllAWSCredentialsWithSAMLRole(idpRolesMap, assertion, region)
if err != nil {
return err
}
Expand All @@ -368,8 +370,11 @@ func (w *WebSSOAuthentication) establishTokenWithFedAppID(clientID, fedAppID str

// awsAssumeRoleWithSAML Get AWS Credentials with an STS Assume Role With SAML AWS
// API call.
func (w *WebSSOAuthentication) awsAssumeRoleWithSAML(iar *idpAndRole, assertion string) (cc *oaws.CredentialContainer, err error) {
func (w *WebSSOAuthentication) awsAssumeRoleWithSAML(iar *idpAndRole, assertion, region string) (cc *oaws.CredentialContainer, err error) {
awsCfg := aws.NewConfig().WithHTTPClient(w.config.HTTPClient())
if region != "" {
awsCfg = awsCfg.WithRegion(region)
}
sess, err := session.NewSession(awsCfg)
if err != nil {
return
Expand Down Expand Up @@ -1104,17 +1109,17 @@ func (w *WebSSOAuthentication) consolePrint(format string, a ...any) {
}

// fetchAllAWSCredentialsWithSAMLRole Gets all AWS Credentials with an STS Assume Role with SAML AWS API call.
func (w *WebSSOAuthentication) fetchAllAWSCredentialsWithSAMLRole(idpRolesMap map[string][]string, assertion string) <-chan *oaws.CredentialContainer {
func (w *WebSSOAuthentication) fetchAllAWSCredentialsWithSAMLRole(idpRolesMap map[string][]string, assertion, region string) <-chan *oaws.CredentialContainer {
ccch := make(chan *oaws.CredentialContainer)
var wg sync.WaitGroup

for idp, roles := range idpRolesMap {
for _, role := range roles {
iar := &idpAndRole{idp, role}
iar := &idpAndRole{idp, role, region}
wg.Add(1)
go func() {
defer wg.Done()
cc, err := w.awsAssumeRoleWithSAML(iar, assertion)
cc, err := w.awsAssumeRoleWithSAML(iar, assertion, region)
if err != nil {
fmt.Fprintf(os.Stderr, "failed to fetch AWS creds IdP %q, and Role %q, error:\n%+v\n", iar.idp, iar.role, err)
return
Expand Down

0 comments on commit 11f890a

Please sign in to comment.