Skip to content

Commit

Permalink
Merge remote-tracking branch 'MatthewJohn/master' into pr_162_Matthew…
Browse files Browse the repository at this point in the history
…John
  • Loading branch information
monde committed Feb 13, 2024
2 parents d6f4461 + 2bdd95a commit 2992d2b
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 26 deletions.
62 changes: 36 additions & 26 deletions internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -315,34 +315,44 @@ func NewConfig(attrs *Attributes) (*Config, error) {
return cfg, nil
}

func getFlagNameFromProfile(awsProfile string, flag string) string {
profileKey := fmt.Sprintf("%s.%s", awsProfile, flag)
if awsProfile != "" && viper.IsSet(profileKey) == true {
return profileKey
}
return flag
}

func readConfig() (Attributes, error) {
awsProfile := viper.GetString(ProfileFlag)

attrs := Attributes{
AllProfiles: viper.GetBool(AllProfilesFlag),
AuthzID: viper.GetString(AuthzIDFlag),
AWSCredentials: viper.GetString(AWSCredentialsFlag),
AWSIAMIdP: viper.GetString(AWSIAMIdPFlag),
AWSIAMRole: viper.GetString(AWSIAMRoleFlag),
AWSSessionDuration: viper.GetInt64(SessionDurationFlag),
AWSRegion: viper.GetString(AWSRegionFlag),
CustomScope: viper.GetString(CustomScopeFlag),
Debug: viper.GetBool(DebugFlag),
DebugAPICalls: viper.GetBool(DebugAPICallsFlag),
Exec: viper.GetBool(ExecFlag),
FedAppID: viper.GetString(AWSAcctFedAppIDFlag),
Format: viper.GetString(FormatFlag),
LegacyAWSVariables: viper.GetBool(LegacyAWSVariablesFlag),
ExpiryAWSVariables: viper.GetBool(ExpiryAWSVariablesFlag),
CacheAccessToken: viper.GetBool(CacheAccessTokenFlag),
OIDCAppID: viper.GetString(OIDCClientIDFlag),
OpenBrowser: viper.GetBool(OpenBrowserFlag),
OpenBrowserCommand: viper.GetString(OpenBrowserCommandFlag),
OrgDomain: viper.GetString(OrgDomainFlag),
PrivateKey: viper.GetString(PrivateKeyFlag),
PrivateKeyFile: viper.GetString(PrivateKeyFileFlag),
KeyID: viper.GetString(KeyIDFlag),
Profile: viper.GetString(ProfileFlag),
QRCode: viper.GetBool(QRCodeFlag),
WriteAWSCredentials: viper.GetBool(WriteAWSCredentialsFlag),
AllProfiles: viper.GetBool(getFlagNameFromProfile(awsProfile, AllProfilesFlag)),
AuthzID: viper.GetString(getFlagNameFromProfile(awsProfile, AuthzIDFlag)),
AWSCredentials: viper.GetString(getFlagNameFromProfile(awsProfile, AWSCredentialsFlag)),
AWSIAMIdP: viper.GetString(getFlagNameFromProfile(awsProfile, AWSIAMIdPFlag)),
AWSIAMRole: viper.GetString(getFlagNameFromProfile(awsProfile, AWSIAMRoleFlag)),
AWSRegion: viper.GetString(getFlagNameFromProfile(awsProfile, AWSRegionFlag)),
AWSSessionDuration: viper.GetInt64(getFlagNameFromProfile(awsProfile, SessionDurationFlag)),
CustomScope: viper.GetString(getFlagNameFromProfile(awsProfile, CustomScopeFlag)),
Debug: viper.GetBool(getFlagNameFromProfile(awsProfile, DebugFlag)),
DebugAPICalls: viper.GetBool(getFlagNameFromProfile(awsProfile, DebugAPICallsFlag)),
Exec: viper.GetBool(getFlagNameFromProfile(awsProfile, ExecFlag)),
FedAppID: viper.GetString(getFlagNameFromProfile(awsProfile, AWSAcctFedAppIDFlag)),
Format: viper.GetString(getFlagNameFromProfile(awsProfile, FormatFlag)),
LegacyAWSVariables: viper.GetBool(getFlagNameFromProfile(awsProfile, LegacyAWSVariablesFlag)),
ExpiryAWSVariables: viper.GetBool(getFlagNameFromProfile(awsProfile, ExpiryAWSVariablesFlag)),
CacheAccessToken: viper.GetBool(getFlagNameFromProfile(awsProfile, CacheAccessTokenFlag)),
OIDCAppID: viper.GetString(getFlagNameFromProfile(awsProfile, OIDCClientIDFlag)),
OpenBrowser: viper.GetBool(getFlagNameFromProfile(awsProfile, OpenBrowserFlag)),
OpenBrowserCommand: viper.GetString(getFlagNameFromProfile(awsProfile, OpenBrowserCommandFlag)),
OrgDomain: viper.GetString(getFlagNameFromProfile(awsProfile, OrgDomainFlag)),
PrivateKey: viper.GetString(getFlagNameFromProfile(awsProfile, PrivateKeyFlag)),
PrivateKeyFile: viper.GetString(getFlagNameFromProfile(awsProfile, PrivateKeyFileFlag)),
KeyID: viper.GetString(getFlagNameFromProfile(awsProfile, KeyIDFlag)),
Profile: viper.GetString(getFlagNameFromProfile(awsProfile, ProfileFlag)),
QRCode: viper.GetBool(getFlagNameFromProfile(awsProfile, QRCodeFlag)),
WriteAWSCredentials: viper.GetBool(getFlagNameFromProfile(awsProfile, WriteAWSCredentialsFlag)),
}
if attrs.Format == "" {
attrs.Format = EnvVarFormat
Expand Down
24 changes: 24 additions & 0 deletions internal/flag/flag.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"errors"
"fmt"
"os"
"os/user"
"path/filepath"
"strings"

Expand Down Expand Up @@ -81,7 +82,30 @@ func MakeFlagBindings(cmd *cobra.Command, flags []Flag, persistent bool) {
if vipAwsRegion != "" && os.Getenv(awsRegionEnvVar) == "" {
_ = os.Setenv(awsRegionEnvVar, vipAwsRegion)
}
} else {
// Check if .okta-aws-cli/conifg.yml exists
usr, err := user.Current()
if err == nil {
oktaConfig := filepath.Join(usr.HomeDir, ".okta-aws-cli", "config.yml")
if _, err := os.Stat(oktaConfig); err == nil || !errors.Is(err, os.ErrNotExist) {
viper.AddConfigPath(filepath.Join(usr.HomeDir, ".okta-aws-cli"))
viper.SetConfigName("config.yml")
viper.SetConfigType("yml")

_ = viper.ReadInConfig()

// After viper reads in the dotenv file check if AWS_REGION is set
// there. The value will be keyed by lower case name. If it is, set
// AWS_REGION as an ENV VAR if it hasn't already been.
awsRegionEnvVar := "AWS_REGION"
vipAwsRegion := viper.GetString(strings.ToLower(awsRegionEnvVar))
if vipAwsRegion != "" && os.Getenv(awsRegionEnvVar) == "" {
_ = os.Setenv(awsRegionEnvVar, vipAwsRegion)
}
}
}
}

viper.AutomaticEnv()

// bind cli flags
Expand Down

0 comments on commit 2992d2b

Please sign in to comment.