diff --git a/README.md b/README.md index 9c1519f..0bab329 100644 --- a/README.md +++ b/README.md @@ -358,6 +358,7 @@ These global settings are optional unless marked otherwise: | Name | Description | Command line flag | ENV var and .env file value | |-----|-----|-----|-----| +| AWS Region (**optional**) | AWS region (will override ENV VAR `AWS_REGION` and `AWS_DEFAULT_REGION`) e.g. `us-east-2` | `--aws-region [value]` | `OKTA_AWSCLI_AWS_REGION` | | Okta Org Domain (**required**) | Full host and domain name of the Okta org e.g. `test.okta.com` or the custom domain value | `--org-domain [value]` | `OKTA_AWSCLI_ORG_DOMAIN` | | OIDC Client ID (**required**) | For `web` the OIDC native application / [Allowed Web SSO Client ID](#allowed-web-sso-client-id), for `m2m` the API services app ID | `--oidc-client-id [value]` | `OKTA_AWSCLI_OIDC_CLIENT_ID` | | AWS IAM Role ARN (**optional** for `web`, **required** for `m2m`) | For web preselects the role list to this preferred IAM role for the given IAM Identity Provider. For `m2m` | `--aws-iam-role [value]` | `OKTA_AWSCLI_IAM_ROLE` | diff --git a/cmd/root/root.go b/cmd/root/root.go index 473f58a..9cfb5f6 100644 --- a/cmd/root/root.go +++ b/cmd/root/root.go @@ -85,6 +85,13 @@ func init() { Usage: "Output format. [env-var|aws-credentials|process-credentials]", EnvVar: config.FormatEnvVar, }, + { + Name: config.AWSRegionFlag, + Short: "n", + Value: "", + Usage: "Preset AWS Region", + EnvVar: config.AWSRegionEnvVar, + }, { Name: config.AWSCredentialsFlag, Short: "w", diff --git a/internal/aws/aws.go b/internal/aws/aws.go index 6d3abd0..7bcaac4 100644 --- a/internal/aws/aws.go +++ b/internal/aws/aws.go @@ -25,6 +25,7 @@ import ( // the different credentials formats type CredentialContainer struct { AccessKeyID string + Region string SecretAccessKey string SessionToken string Expiration *time.Time @@ -44,6 +45,7 @@ type EnvVarCredential struct { // credentials file type CredsFileCredential struct { AccessKeyID string `ini:"aws_access_key_id"` + Region string `ini:"region"` SecretAccessKey string `ini:"aws_secret_access_key"` SessionToken string `ini:"aws_session_token"` diff --git a/internal/config/config.go b/internal/config/config.go index 43727db..b3ed3ce 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -62,6 +62,8 @@ const ( AWSIAMIdPFlag = "aws-iam-idp" // AWSIAMRoleFlag cli flag const AWSIAMRoleFlag = "aws-iam-role" + // AWSRegion cli flag const + AWSRegionFlag = "aws-region" // CustomScopeFlag cli flag const CustomScopeFlag = "custom-scope" // DebugFlag cli flag const @@ -113,6 +115,8 @@ const ( AWSIAMRoleEnvVar = "OKTA_AWSCLI_IAM_ROLE" // AWSSessionDurationEnvVar env var const AWSSessionDurationEnvVar = "OKTA_AWSCLI_SESSION_DURATION" + // AWSRegionEnvVar env var const + AWSRegionEnvVar = "OKTA_AWSCLI_AWS_REGION" // CacheAccessTokenEnvVar env var const CacheAccessTokenEnvVar = "OKTA_AWSCLI_CACHE_ACCESS_TOKEN" // CustomScopeEnvVar env var const @@ -194,6 +198,7 @@ type Config struct { awsCredentials string awsIAMIdP string awsIAMRole string + awsRegion string awsSessionDuration int64 cacheAccessToken bool customScope string @@ -225,6 +230,7 @@ type Attributes struct { AWSCredentials string AWSIAMIdP string AWSIAMRole string + AWSRegion string AWSSessionDuration int64 CacheAccessToken bool CustomScope string @@ -268,8 +274,7 @@ func NewConfig(attrs *Attributes) (*Config, error) { awsCredentials: attrs.AWSCredentials, awsIAMIdP: attrs.AWSIAMIdP, awsIAMRole: attrs.AWSIAMRole, - cacheAccessToken: attrs.CacheAccessToken, - customScope: attrs.CustomScope, + awsRegion: attrs.AWSRegion, debug: attrs.Debug, debugAPICalls: attrs.DebugAPICalls, expiryAWSVariables: attrs.ExpiryAWSVariables, @@ -318,6 +323,7 @@ func readConfig() (Attributes, error) { AWSIAMIdP: viper.GetString(AWSIAMIdPFlag), AWSIAMRole: viper.GetString(AWSIAMRoleFlag), AWSSessionDuration: viper.GetInt64(SessionDurationFlag), + AWSRegion: viper.GetString(AWSRegionFlag), CustomScope: viper.GetString(CustomScopeFlag), Debug: viper.GetBool(DebugFlag), DebugAPICalls: viper.GetBool(DebugAPICallsFlag), @@ -551,6 +557,15 @@ func (c *Config) SetAWSIAMRole(role string) error { return nil } +func (c *Config) SetAWSRegion(region string) error { + c.awsRegion = region + return nil +} + +func (c *Config) AWSRegion() string { + return c.awsRegion +} + // AWSSessionDuration -- func (c *Config) AWSSessionDuration() int64 { return c.awsSessionDuration diff --git a/internal/m2mauth/m2mauth.go b/internal/m2mauth/m2mauth.go index 42f872c..f45adca 100644 --- a/internal/m2mauth/m2mauth.go +++ b/internal/m2mauth/m2mauth.go @@ -114,6 +114,10 @@ func (m *M2MAuthentication) EstablishIAMCredentials() error { func (m *M2MAuthentication) awsAssumeRoleWithWebIdentity(at *okta.AccessToken) (cc *oaws.CredentialContainer, err error) { awsCfg := aws.NewConfig().WithHTTPClient(m.config.HTTPClient()) + region := m.config.AWSRegion() + if region != "" { + awsCfg = awsCfg.WithRegion(region) + } sess, err := session.NewSession(awsCfg) if err != nil { return diff --git a/internal/output/aws_credentials_file.go b/internal/output/aws_credentials_file.go index ac49d36..69c08e8 100644 --- a/internal/output/aws_credentials_file.go +++ b/internal/output/aws_credentials_file.go @@ -65,8 +65,8 @@ func ensureConfigExists(filename string, profile string) error { return nil } -func saveProfile(filename, profile string, cfc *oaws.CredsFileCredential, legacyVars, expiryVars bool, expiry string) error { - config, err := updateConfig(filename, profile, cfc, legacyVars, expiryVars, expiry) +func saveProfile(filename, profile string, cfc *oaws.CredsFileCredential, legacyVars, expiryVars bool, expiry string, regionVar string) error { + config, err := updateConfig(filename, profile, cfc, legacyVars, expiryVars, expiry, regionVar) if err != nil { return err } @@ -80,7 +80,7 @@ func saveProfile(filename, profile string, cfc *oaws.CredsFileCredential, legacy return nil } -func updateConfig(filename, profile string, cfc *oaws.CredsFileCredential, legacyVars, expiryVars bool, expiry string) (config *ini.File, err error) { +func updateConfig(filename, profile string, cfc *oaws.CredsFileCredential, legacyVars, expiryVars bool, expiry string, region string) (config *ini.File, err error) { config, err = ini.Load(filename) if err != nil { return @@ -102,11 +102,17 @@ func updateConfig(filename, profile string, cfc *oaws.CredsFileCredential, legac if legacyVars { builder.AddField(SecurityTokenField, "", `ini:"aws_security_token"`) } + if region != "" { + builder.AddField(utils.Region, "", `ini:"region"`) + } instance := builder.Build().New() reflect.ValueOf(instance).Elem().FieldByName(utils.AccessKeyID).SetString(cfc.AccessKeyID) reflect.ValueOf(instance).Elem().FieldByName(utils.SecretAccessKey).SetString(cfc.SecretAccessKey) reflect.ValueOf(instance).Elem().FieldByName(utils.SessionToken).SetString(cfc.SessionToken) + if region != "" { + reflect.ValueOf(instance).Elem().FieldByName(utils.Region).SetString(region) + } if expiryVars { reflect.ValueOf(instance).Elem().FieldByName(ExpirationField).SetString(expiry) @@ -120,12 +126,12 @@ func updateConfig(filename, profile string, cfc *oaws.CredsFileCredential, legac return } - return updateINI(config, profile, legacyVars, expiryVars) + return updateINI(config, profile, legacyVars, expiryVars, region) } // updateIni will comment out any keys that are not "aws_access_key_id", // "aws_secret_access_key", "aws_session_token", "credential_process" -func updateINI(config *ini.File, profile string, legacyVars bool, expiryVars bool) (*ini.File, error) { +func updateINI(config *ini.File, profile string, legacyVars bool, expiryVars bool, region string) (*ini.File, error) { ignore := []string{ "aws_access_key_id", "aws_secret_access_key", @@ -138,6 +144,9 @@ func updateINI(config *ini.File, profile string, legacyVars bool, expiryVars boo if expiryVars { ignore = append(ignore, "x_security_token_expires") } + if region != "" { + ignore = append(ignore, "region") + } section := config.Section(profile) comments := []string{} for _, name := range section.KeyStrings() { @@ -171,6 +180,7 @@ type AWSCredentialsFile struct { LegacyAWSVariables bool ExpiryAWSVariables bool Expiry string + Region string } // NewAWSCredentialsFile Creates a new @@ -230,6 +240,11 @@ aws_session_token = %s credArgs = append(credArgs, a.Expiry) } + if c.AWSRegion() != "" { + creds = fmt.Sprintf("%sregion = %%s\n", creds) + credArgs = append(credArgs, c.AWSRegion()) + } + creds = fmt.Sprintf(creds, credArgs...) _, err = f.WriteString(creds) @@ -255,7 +270,7 @@ func (a *AWSCredentialsFile) writeConfig(c *config.Config, cfc *oaws.CredsFileCr return err } - return saveProfile(filename, profile, cfc, a.LegacyAWSVariables, a.ExpiryAWSVariables, a.Expiry) + return saveProfile(filename, profile, cfc, a.LegacyAWSVariables, a.ExpiryAWSVariables, a.Expiry, c.AWSRegion()) } func contains(ignore []string, name string) bool { diff --git a/internal/output/aws_credentials_file_test.go b/internal/output/aws_credentials_file_test.go index f2498cf..2b169d5 100644 --- a/internal/output/aws_credentials_file_test.go +++ b/internal/output/aws_credentials_file_test.go @@ -58,7 +58,7 @@ func TestINIFormatCredentialsContent(t *testing.T) { SecretAccessKey: "e", SessionToken: "f", } - config, err := updateConfig(filename, "test", cfc, false, false, "") + config, err := updateConfig(filename, "test", cfc, false, false, "", "us-east-2") assert.NoError(t, err) err = config.SaveTo(filename) @@ -141,7 +141,7 @@ aws_security_token = ghi t.Run(test.name, func(t *testing.T) { config, err := ini.Load(test.config) require.NoError(t, err) - ini, err := updateINI(config, "default", test.legacy, false) + ini, err := updateINI(config, "default", test.legacy, false, "us-east-2") require.NoError(t, err) section := ini.Section(test.section) require.Equal(t, len(test.want), len(section.KeyStrings())) diff --git a/internal/utils/utils.go b/internal/utils/utils.go index 742e83b..5891365 100644 --- a/internal/utils/utils.go +++ b/internal/utils/utils.go @@ -36,8 +36,10 @@ const ( // AccessKeyID AWS creds access key ID AccessKeyID = "AccessKeyID" + // Region region + Region = "Region" // SecretAccessKey AWS creds secret access key SecretAccessKey = "SecretAccessKey" - // SessionToken AWS creds session tokne + // SessionToken AWS creds session token SessionToken = "SessionToken" ) diff --git a/internal/webssoauth/webssoauth.go b/internal/webssoauth/webssoauth.go index e92ede5..9c55a9e 100644 --- a/internal/webssoauth/webssoauth.go +++ b/internal/webssoauth/webssoauth.go @@ -105,8 +105,9 @@ type WebSSOAuthentication struct { // idpAndRole IdP and role pairs type idpAndRole struct { - idp string - role string + idp string + role string + region string } var stderrIsOutAskOpt = func(options *survey.AskOptions) error { @@ -171,7 +172,7 @@ func (w *WebSSOAuthentication) EstablishIAMCredentials() error { } if w.config.FedAppID() != "" { // Alternate path when operator knows their AWS Fed app ID - err = w.establishTokenWithFedAppID(clientID, w.config.FedAppID(), at) + err = w.establishTokenWithFedAppID(clientID, w.config.FedAppID(), at, w.config.AWSRegion()) if at != nil && err != nil { // possible bad cached access token, retry at = nil @@ -210,7 +211,7 @@ AWS Federation App with --aws-acct-fed-app-id FED_APP_ID // special case, we're going to run the table and get all profiles for all apps errArr := []error{} for _, app := range apps { - if err = w.establishTokenWithFedAppID(clientID, app.ID, at); err != nil { + if err = w.establishTokenWithFedAppID(clientID, app.ID, at, w.config.AWSRegion()); err != nil { errArr = append(errArr, err) } } @@ -227,7 +228,7 @@ AWS Federation App with --aws-acct-fed-app-id FED_APP_ID } } - return w.establishTokenWithFedAppID(clientID, fedAppID, at) + return w.establishTokenWithFedAppID(clientID, fedAppID, at, w.config.AWSRegion()) } // choiceFriendlyLabelIDP returns a friendly choice for pretty printing IDP @@ -310,7 +311,7 @@ func (w *WebSSOAuthentication) selectFedApp(apps []*okta.Application) (string, e return idps[selected].ID, nil } -func (w *WebSSOAuthentication) establishTokenWithFedAppID(clientID, fedAppID string, at *okta.AccessToken) error { +func (w *WebSSOAuthentication) establishTokenWithFedAppID(clientID, fedAppID string, at *okta.AccessToken, region string) error { at, err := w.fetchSSOWebToken(clientID, fedAppID, at) if err != nil { return err @@ -331,8 +332,9 @@ func (w *WebSSOAuthentication) establishTokenWithFedAppID(clientID, fedAppID str if err != nil { return err } + iar.region = region - cc, err := w.awsAssumeRoleWithSAML(iar, assertion) + cc, err := w.awsAssumeRoleWithSAML(iar, assertion, region) if err != nil { return err } @@ -349,7 +351,7 @@ func (w *WebSSOAuthentication) establishTokenWithFedAppID(clientID, fedAppID str } } } else { - ccch := w.fetchAllAWSCredentialsWithSAMLRole(idpRolesMap, assertion) + ccch := w.fetchAllAWSCredentialsWithSAMLRole(idpRolesMap, assertion, region) if err != nil { return err } @@ -368,8 +370,11 @@ func (w *WebSSOAuthentication) establishTokenWithFedAppID(clientID, fedAppID str // awsAssumeRoleWithSAML Get AWS Credentials with an STS Assume Role With SAML AWS // API call. -func (w *WebSSOAuthentication) awsAssumeRoleWithSAML(iar *idpAndRole, assertion string) (cc *oaws.CredentialContainer, err error) { +func (w *WebSSOAuthentication) awsAssumeRoleWithSAML(iar *idpAndRole, assertion, region string) (cc *oaws.CredentialContainer, err error) { awsCfg := aws.NewConfig().WithHTTPClient(w.config.HTTPClient()) + if region != "" { + awsCfg = awsCfg.WithRegion(region) + } sess, err := session.NewSession(awsCfg) if err != nil { return @@ -1104,17 +1109,17 @@ func (w *WebSSOAuthentication) consolePrint(format string, a ...any) { } // fetchAllAWSCredentialsWithSAMLRole Gets all AWS Credentials with an STS Assume Role with SAML AWS API call. -func (w *WebSSOAuthentication) fetchAllAWSCredentialsWithSAMLRole(idpRolesMap map[string][]string, assertion string) <-chan *oaws.CredentialContainer { +func (w *WebSSOAuthentication) fetchAllAWSCredentialsWithSAMLRole(idpRolesMap map[string][]string, assertion, region string) <-chan *oaws.CredentialContainer { ccch := make(chan *oaws.CredentialContainer) var wg sync.WaitGroup for idp, roles := range idpRolesMap { for _, role := range roles { - iar := &idpAndRole{idp, role} + iar := &idpAndRole{idp, role, region} wg.Add(1) go func() { defer wg.Done() - cc, err := w.awsAssumeRoleWithSAML(iar, assertion) + cc, err := w.awsAssumeRoleWithSAML(iar, assertion, region) if err != nil { fmt.Fprintf(os.Stderr, "failed to fetch AWS creds IdP %q, and Role %q, error:\n%+v\n", iar.idp, iar.role, err) return