From b5e01de6b7271580396a1f39b3a73d9023bfc239 Mon Sep 17 00:00:00 2001 From: Mike Mondragon Date: Fri, 22 Sep 2023 16:55:31 -0700 Subject: [PATCH] M2M auth access token request --- README.md | 2 + cmd/root/m2m/m2m.go | 33 +- cmd/root/web/web.go | 7 + go.mod | 4 + go.sum | 6 + internal/config/config.go | 309 +++++++++++------- internal/m2mauth/m2mauth.go | 177 ++++++++++ internal/m2mauth/m2mauth_test.go | 155 +++++++++ internal/okta/accesstoken.go | 30 ++ internal/okta/apierror.go | 23 ++ internal/okta/client_assertion_claims.go | 31 ++ internal/okta/okta.go | 25 ++ internal/sessiontoken/sessiontoken.go | 79 ++--- internal/testutils/testutils.go | 157 +++++++++ internal/utils/utils.go | 26 ++ test/fixtures/vcr/TestM2MAuthAccessToken.yaml | 42 +++ .../vcr/TestM2MAuthMakeClientAssertion.yaml | 3 + 17 files changed, 925 insertions(+), 184 deletions(-) create mode 100644 internal/m2mauth/m2mauth.go create mode 100644 internal/m2mauth/m2mauth_test.go create mode 100644 internal/okta/accesstoken.go create mode 100644 internal/okta/apierror.go create mode 100644 internal/okta/client_assertion_claims.go create mode 100644 internal/okta/okta.go create mode 100644 internal/testutils/testutils.go create mode 100644 internal/utils/utils.go create mode 100644 test/fixtures/vcr/TestM2MAuthAccessToken.yaml create mode 100644 test/fixtures/vcr/TestM2MAuthMakeClientAssertion.yaml diff --git a/README.md b/README.md index 5965800..74bb66b 100644 --- a/README.md +++ b/README.md @@ -427,6 +427,8 @@ These settings are optional unless marked otherwise: | Name | Description | Command line flag | ENV var and .env file value | |-----|-----|-----|-----| +| Custom Authorization Server ID (**required**) | The ID of the Okta custom authorization server | `--authz-id [value]` | `OKTA_AUTHZ_ID` | +| Key ID (kid) (**required**) | The ID of the key stored in the service app | `--key-id [value]` | `OKTA_AWSCLI_KEY_ID` | | Private Key (**required**) | PEM or JWKS format private key whose public key is stored on the service app | `--private-key [value]` | `OKTA_AWSCLI_PRIVATE_KEY` | | Custom scope name | The custom scope established in the custom authorization server. Default `okta-aws-cli` | `--custom-scope [value]` | `OKTA_AWSCLI_CUSTOM_SCOPE` | diff --git a/cmd/root/m2m/m2m.go b/cmd/root/m2m/m2m.go index f83f98c..23bff67 100644 --- a/cmd/root/m2m/m2m.go +++ b/cmd/root/m2m/m2m.go @@ -17,17 +17,22 @@ package m2m import ( - "fmt" - "os" - "github.com/spf13/cobra" "github.com/okta/okta-aws-cli/internal/config" cliFlag "github.com/okta/okta-aws-cli/internal/flag" + "github.com/okta/okta-aws-cli/internal/m2mauth" ) var ( flags = []cliFlag.Flag{ + { + Name: config.KeyIDFlag, + Short: "i", + Value: "", + Usage: "Key ID", + EnvVar: config.KeyIDEnvVar, + }, { Name: config.PrivateKeyFlag, Short: "k", @@ -42,8 +47,15 @@ var ( Usage: "Custom Scope", EnvVar: config.CustomScopeEnvVar, }, + { + Name: config.AuthzIDFlag, + Short: "u", + Value: "", + Usage: "Custom Authorization Server ID", + EnvVar: config.AuthzIDEnvVar, + }, } - requiredFlags = []string{"org-domain", "oidc-client-id", "aws-iam-role", "private-key"} + requiredFlags = []string{"org-domain", "oidc-client-id", "aws-iam-role", "key-id", "private-key", "authz-id"} ) // NewM2MCommand Sets up the m2m cobra sub command @@ -61,14 +73,11 @@ func NewM2MCommand() *cobra.Command { return err } - fmt.Fprintf(os.Stderr, "WIP - m2m, get to work!\n") - fmt.Fprintf(os.Stderr, "Okta Org Domain: %s\n", config.OrgDomain()) - fmt.Fprintf(os.Stderr, "OIDC Client ID: %s\n", config.OIDCAppID()) - fmt.Fprintf(os.Stderr, "IAM Role ARN: %s\n", config.AWSIAMRole()) - fmt.Fprintf(os.Stderr, "Private Key: %s\n", config.PrivateKey()) - fmt.Fprintf(os.Stderr, "Custom Scope: %s\n", config.CustomScope()) - - return nil + m2mAuth, err := m2mauth.NewM2MAuthentication(config) + if err != nil { + return err + } + return m2mAuth.EstablishIAMCredentials() }, } diff --git a/cmd/root/web/web.go b/cmd/root/web/web.go index 80d3150..5316f08 100644 --- a/cmd/root/web/web.go +++ b/cmd/root/web/web.go @@ -73,6 +73,13 @@ func NewWebCommand() *cobra.Command { return err } + // TODO refactor the naming convention + // webAuth, err := webauth.NewWebSSOAuthentication(config) + // if err != nil { + // return err + // } + // return webAuth.EstablishIAMCredentials() + st, err := sessiontoken.NewSessionToken(config) if err != nil { return err diff --git a/go.mod b/go.mod index ca0bbe3..8371208 100644 --- a/go.mod +++ b/go.mod @@ -17,10 +17,14 @@ require ( github.com/tidwall/pretty v1.2.0 golang.org/x/net v0.7.0 golang.org/x/sys v0.5.0 + gopkg.in/dnaeon/go-vcr.v3 v3.1.2 gopkg.in/ini.v1 v1.67.0 + gopkg.in/square/go-jose.v2 v2.6.0 gopkg.in/yaml.v2 v2.4.0 ) +require golang.org/x/crypto v0.0.0-20220525230936-793ad666bf5e // indirect + require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/fsnotify/fsnotify v1.6.0 // indirect diff --git a/go.sum b/go.sum index f9b9dfa..482007f 100644 --- a/go.sum +++ b/go.sum @@ -234,6 +234,8 @@ golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8U golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.0.0-20211108221036-ceb1ce70b4fa/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.0.0-20220525230936-793ad666bf5e h1:T8NU3HyQ8ClP4SEE+KbFlg6n0NhuTsN4MyznaarGsZM= +golang.org/x/crypto v0.0.0-20220525230936-793ad666bf5e/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= @@ -523,9 +525,13 @@ google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlba gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/dnaeon/go-vcr.v3 v3.1.2 h1:F1smfXBqQqwpVifDfUBQG6zzaGjzT+EnVZakrOdr5wA= +gopkg.in/dnaeon/go-vcr.v3 v3.1.2/go.mod h1:2IMOnnlx9I6u9x+YBsM3tAMx6AlOxnJ0pWxQAzZ79Ag= gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA= gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= +gopkg.in/square/go-jose.v2 v2.6.0 h1:NGk74WTnPKBNUhNzQX7PYcTLUjoq7mzKk2OKbvwk2iI= +gopkg.in/square/go-jose.v2 v2.6.0/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= diff --git a/internal/config/config.go b/internal/config/config.go index 0ff910c..9413bae 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -38,6 +38,8 @@ const ( // EnvVarFormat format const EnvVarFormat = "env-var" + // AuthzIDFlag cli flag const + AuthzIDFlag = "authz-id" // AWSAcctFedAppIDFlag cli flag const AWSAcctFedAppIDFlag = "aws-acct-fed-app-id" // AWSCredentialsFlag cli flag const @@ -62,6 +64,8 @@ const ( OrgDomainFlag = "org-domain" // PrivateKeyFlag cli flag const PrivateKeyFlag = "private-key" + // KeyIDFlag cli flag const + KeyIDFlag = "key-id" // ProfileFlag cli flag const ProfileFlag = "profile" // QRCodeFlag cli flag const @@ -77,6 +81,8 @@ const ( // CacheAccessTokenFlag cli flag const CacheAccessTokenFlag = "cache-access-token" + // AuthzIDEnvVar env var const + AuthzIDEnvVar = "OKTA_AUTHZ_ID" // AWSCredentialsEnvVar env var const AWSCredentialsEnvVar = "OKTA_AWSCLI_AWS_CREDENTIALS" // AWSIAMIdPEnvVar env var const @@ -109,6 +115,8 @@ const ( OpenBrowserEnvVar = "OKTA_AWSCLI_OPEN_BROWSER" // PrivateKeyEnvVar env var const PrivateKeyEnvVar = "OKTA_AWSCLI_PRIVATE_KEY" + // KeyIDEnvVar env var const + KeyIDEnvVar = "OKTA_AWSCLI_KEY_ID" // ProfileEnvVar env var const ProfileEnvVar = "OKTA_AWSCLI_PROFILE" // QRCodeEnvVar env var const @@ -135,6 +143,11 @@ type OktaYamlConfig struct { } `yaml:"awscli"` } +// Clock interface to abstract time operations +type Clock interface { + Now() time.Time +} + // Config A config object for the CLI // // External consumers of Config use its setters and getters to interact with the @@ -142,6 +155,7 @@ type OktaYamlConfig struct { // control data access, be concerned with evaluation, validation, and not // allowing direct access to values as is done on structs in the generic case. type Config struct { + authzID string awsCredentials string awsIAMIdP string awsIAMRole string @@ -154,6 +168,7 @@ type Config struct { fedAppID string format string httpClient *http.Client + keyID string legacyAWSVariables bool oidcAppID string openBrowser bool @@ -162,29 +177,32 @@ type Config struct { profile string qrCode bool writeAWSCredentials bool -} - -// attributes config construction -type attributes struct { - awsCredentials string - awsIAMIdP string - awsIAMRole string - awsSessionDuration int64 - cacheAccessToken bool - customScope string - debug bool - debugAPICalls bool - expiryAWSVariables bool - fedAppID string - format string - legacyAWSVariables bool - oidcAppID string - openBrowser bool - orgDomain string - privateKey string - profile string - qrCode bool - writeAWSCredentials bool + clock Clock +} + +// Attributes config construction +type Attributes struct { + AuthzID string + AWSCredentials string + AWSIAMIdP string + AWSIAMRole string + AWSSessionDuration int64 + CacheAccessToken bool + CustomScope string + Debug bool + DebugAPICalls bool + ExpiryAWSVariables bool + FedAppID string + Format string + KeyID string + LegacyAWSVariables bool + OIDCAppID string + OpenBrowser bool + OrgDomain string + PrivateKey string + Profile string + QRCode bool + WriteAWSCredentials bool } // EvaluateSettings Returns a new config gathering values in this order of precedence: @@ -196,39 +214,41 @@ func EvaluateSettings() (*Config, error) { if err != nil { return nil, err } - return NewConfig(cfgAttrs) + return NewConfig(&cfgAttrs) } // NewConfig create config from attributes -func NewConfig(attrs attributes) (*Config, error) { +func NewConfig(attrs *Attributes) (*Config, error) { var err error cfg := &Config{ - awsCredentials: attrs.awsCredentials, - awsIAMIdP: attrs.awsIAMIdP, - awsIAMRole: attrs.awsIAMRole, - cacheAccessToken: attrs.cacheAccessToken, - customScope: attrs.customScope, - debug: attrs.debug, - debugAPICalls: attrs.debugAPICalls, - expiryAWSVariables: attrs.expiryAWSVariables, - fedAppID: attrs.fedAppID, - format: attrs.format, - legacyAWSVariables: attrs.legacyAWSVariables, - openBrowser: attrs.openBrowser, - privateKey: attrs.privateKey, - profile: attrs.profile, - qrCode: attrs.qrCode, - writeAWSCredentials: attrs.writeAWSCredentials, - } - err = cfg.SetOrgDomain(attrs.orgDomain) + authzID: attrs.AuthzID, + awsCredentials: attrs.AWSCredentials, + awsIAMIdP: attrs.AWSIAMIdP, + awsIAMRole: attrs.AWSIAMRole, + cacheAccessToken: attrs.CacheAccessToken, + customScope: attrs.CustomScope, + debug: attrs.Debug, + debugAPICalls: attrs.DebugAPICalls, + expiryAWSVariables: attrs.ExpiryAWSVariables, + fedAppID: attrs.FedAppID, + format: attrs.Format, + legacyAWSVariables: attrs.LegacyAWSVariables, + openBrowser: attrs.OpenBrowser, + privateKey: attrs.PrivateKey, + keyID: attrs.KeyID, + profile: attrs.Profile, + qrCode: attrs.QRCode, + writeAWSCredentials: attrs.WriteAWSCredentials, + } + err = cfg.SetOrgDomain(attrs.OrgDomain) if err != nil { return nil, err } - err = cfg.SetOIDCAppID(attrs.oidcAppID) + err = cfg.SetOIDCAppID(attrs.OIDCAppID) if err != nil { return nil, err } - err = cfg.SetAWSSessionDuration(attrs.awsSessionDuration) + err = cfg.SetAWSSessionDuration(attrs.AWSSessionDuration) if err != nil { return nil, err } @@ -240,133 +260,142 @@ func NewConfig(attrs attributes) (*Config, error) { if err != nil { return nil, err } + cfg.clock = &realClock{} return cfg, nil } -func readConfig() (attributes, error) { - attrs := attributes{ - awsCredentials: viper.GetString(AWSCredentialsFlag), - awsIAMIdP: viper.GetString(AWSIAMIdPFlag), - awsIAMRole: viper.GetString(AWSIAMRoleFlag), - awsSessionDuration: viper.GetInt64(SessionDurationFlag), - customScope: viper.GetString(CustomScopeFlag), - debug: viper.GetBool(DebugFlag), - debugAPICalls: viper.GetBool(DebugAPICallsFlag), - fedAppID: viper.GetString(AWSAcctFedAppIDFlag), - format: viper.GetString(FormatFlag), - legacyAWSVariables: viper.GetBool(LegacyAWSVariablesFlag), - expiryAWSVariables: viper.GetBool(ExpiryAWSVariablesFlag), - cacheAccessToken: viper.GetBool(CacheAccessTokenFlag), - oidcAppID: viper.GetString(OIDCClientIDFlag), - openBrowser: viper.GetBool(OpenBrowserFlag), - orgDomain: viper.GetString(OrgDomainFlag), - privateKey: viper.GetString(PrivateKeyFlag), - profile: viper.GetString(ProfileFlag), - qrCode: viper.GetBool(QRCodeFlag), - writeAWSCredentials: viper.GetBool(WriteAWSCredentialsFlag), - } - if attrs.format == "" { - attrs.format = EnvVarFormat +func readConfig() (Attributes, error) { + attrs := Attributes{ + AuthzID: viper.GetString(AuthzIDFlag), + AWSCredentials: viper.GetString(AWSCredentialsFlag), + AWSIAMIdP: viper.GetString(AWSIAMIdPFlag), + AWSIAMRole: viper.GetString(AWSIAMRoleFlag), + AWSSessionDuration: viper.GetInt64(SessionDurationFlag), + CustomScope: viper.GetString(CustomScopeFlag), + Debug: viper.GetBool(DebugFlag), + DebugAPICalls: viper.GetBool(DebugAPICallsFlag), + FedAppID: viper.GetString(AWSAcctFedAppIDFlag), + Format: viper.GetString(FormatFlag), + LegacyAWSVariables: viper.GetBool(LegacyAWSVariablesFlag), + ExpiryAWSVariables: viper.GetBool(ExpiryAWSVariablesFlag), + CacheAccessToken: viper.GetBool(CacheAccessTokenFlag), + OIDCAppID: viper.GetString(OIDCClientIDFlag), + OpenBrowser: viper.GetBool(OpenBrowserFlag), + OrgDomain: viper.GetString(OrgDomainFlag), + PrivateKey: viper.GetString(PrivateKeyFlag), + KeyID: viper.GetString(KeyIDFlag), + Profile: viper.GetString(ProfileFlag), + QRCode: viper.GetBool(QRCodeFlag), + WriteAWSCredentials: viper.GetBool(WriteAWSCredentialsFlag), + } + if attrs.Format == "" { + attrs.Format = EnvVarFormat } // mimic AWS CLI behavior, if profile value is not set by flag check // the ENV VAR, else set to "default" - if attrs.profile == "" { - attrs.profile = viper.GetString(downCase(ProfileEnvVar)) + if attrs.Profile == "" { + attrs.Profile = viper.GetString(downCase(ProfileEnvVar)) } - if attrs.profile == "" { - attrs.profile = "default" + if attrs.Profile == "" { + attrs.Profile = "default" } // Viper binds ENV VARs to a lower snake version, set the configs with them // if they haven't already been set by cli flag binding. - if attrs.orgDomain == "" { - attrs.orgDomain = viper.GetString(downCase(OktaOrgDomainEnvVar)) + if attrs.OrgDomain == "" { + attrs.OrgDomain = viper.GetString(downCase(OktaOrgDomainEnvVar)) + } + if attrs.OIDCAppID == "" { + attrs.OIDCAppID = viper.GetString(downCase(OktaOIDCClientIDEnvVar)) } - if attrs.oidcAppID == "" { - attrs.oidcAppID = viper.GetString(downCase(OktaOIDCClientIDEnvVar)) + if attrs.FedAppID == "" { + attrs.FedAppID = viper.GetString(downCase(OktaAWSAccountFederationAppIDEnvVar)) } - if attrs.fedAppID == "" { - attrs.fedAppID = viper.GetString(downCase(OktaAWSAccountFederationAppIDEnvVar)) + if attrs.AWSIAMIdP == "" { + attrs.AWSIAMIdP = viper.GetString(downCase(AWSIAMIdPEnvVar)) } - if attrs.awsIAMIdP == "" { - attrs.awsIAMIdP = viper.GetString(downCase(AWSIAMIdPEnvVar)) + if attrs.AWSIAMRole == "" { + attrs.AWSIAMRole = viper.GetString(downCase(AWSIAMRoleEnvVar)) } - if attrs.awsIAMRole == "" { - attrs.awsIAMRole = viper.GetString(downCase(AWSIAMRoleEnvVar)) + if !attrs.QRCode { + attrs.QRCode = viper.GetBool(downCase(QRCodeEnvVar)) } - if !attrs.qrCode { - attrs.qrCode = viper.GetBool(downCase(QRCodeEnvVar)) + if attrs.PrivateKey == "" { + attrs.PrivateKey = viper.GetString(downCase(PrivateKeyEnvVar)) } - if attrs.privateKey == "" { - attrs.privateKey = viper.GetString(downCase(PrivateKeyEnvVar)) + if attrs.KeyID == "" { + attrs.KeyID = viper.GetString(downCase(KeyIDEnvVar)) } - if attrs.customScope == "" { - attrs.customScope = viper.GetString(downCase(CustomScopeEnvVar)) + if attrs.CustomScope == "" { + attrs.CustomScope = viper.GetString(downCase(CustomScopeEnvVar)) + } + if attrs.AuthzID == "" { + attrs.AuthzID = viper.GetString(downCase(AuthzIDEnvVar)) } // if session duration is 0, inspect the ENV VAR for a value, else set // a default of 3600 - if attrs.awsSessionDuration == 0 { - attrs.awsSessionDuration = viper.GetInt64(downCase(AWSSessionDurationEnvVar)) + if attrs.AWSSessionDuration == 0 { + attrs.AWSSessionDuration = viper.GetInt64(downCase(AWSSessionDurationEnvVar)) } - if attrs.awsSessionDuration == 0 { - attrs.awsSessionDuration = 3600 + if attrs.AWSSessionDuration == 0 { + attrs.AWSSessionDuration = 3600 } // correct org domain if it's in admin form - orgDomain := strings.Replace(attrs.orgDomain, "-admin", "", -1) - if orgDomain != attrs.orgDomain { - fmt.Fprintf(os.Stderr, "WARNING: proactively correcting org domain %q to non-admin form %q.\n\n", attrs.orgDomain, orgDomain) - attrs.orgDomain = orgDomain + orgDomain := strings.Replace(attrs.OrgDomain, "-admin", "", -1) + if orgDomain != attrs.OrgDomain { + fmt.Fprintf(os.Stderr, "WARNING: proactively correcting org domain %q to non-admin form %q.\n\n", attrs.OrgDomain, orgDomain) + attrs.OrgDomain = orgDomain } - if strings.HasPrefix(attrs.orgDomain, "http") { - u, err := url.Parse(attrs.orgDomain) + if strings.HasPrefix(attrs.OrgDomain, "http") { + u, err := url.Parse(attrs.OrgDomain) // try to help correct org domain value if parsing occurs correctly, // else let the CLI error out else where if err == nil { orgDomain = u.Hostname() - fmt.Fprintf(os.Stderr, "WARNING: proactively correcting URL format org domain %q value to hostname only form %q.\n\n", attrs.orgDomain, orgDomain) - attrs.orgDomain = orgDomain + fmt.Fprintf(os.Stderr, "WARNING: proactively correcting URL format org domain %q value to hostname only form %q.\n\n", attrs.OrgDomain, orgDomain) + attrs.OrgDomain = orgDomain } } - if strings.HasSuffix(attrs.orgDomain, "/") { - orgDomain = string([]byte(attrs.orgDomain)[0 : len(attrs.orgDomain)-1]) + if strings.HasSuffix(attrs.OrgDomain, "/") { + orgDomain = string([]byte(attrs.OrgDomain)[0 : len(attrs.OrgDomain)-1]) // try to help correct malformed org domain value - fmt.Fprintf(os.Stderr, "WARNING: proactively correcting malformed org domain %q value to hostname only form %q.\n\n", attrs.orgDomain, orgDomain) - attrs.orgDomain = orgDomain + fmt.Fprintf(os.Stderr, "WARNING: proactively correcting malformed org domain %q value to hostname only form %q.\n\n", attrs.OrgDomain, orgDomain) + attrs.OrgDomain = orgDomain } // There is always a default aws credentials path set in root.go's init // function so overwrite the config value if the operator is attempting to // set it by ENV VAR value. if viper.GetString(downCase(AWSCredentialsEnvVar)) != "" { - attrs.awsCredentials = viper.GetString(downCase(AWSCredentialsEnvVar)) + attrs.AWSCredentials = viper.GetString(downCase(AWSCredentialsEnvVar)) } - if !attrs.writeAWSCredentials { - attrs.writeAWSCredentials = viper.GetBool(downCase(WriteAWSCredentialsEnvVar)) + if !attrs.WriteAWSCredentials { + attrs.WriteAWSCredentials = viper.GetBool(downCase(WriteAWSCredentialsEnvVar)) } - if attrs.writeAWSCredentials { + if attrs.WriteAWSCredentials { // writing aws creds option implies "aws-credentials" format - attrs.format = AWSCredentialsFormat + attrs.Format = AWSCredentialsFormat } - if !attrs.openBrowser { - attrs.openBrowser = viper.GetBool(downCase(OpenBrowserEnvVar)) + if !attrs.OpenBrowser { + attrs.OpenBrowser = viper.GetBool(downCase(OpenBrowserEnvVar)) } - if !attrs.debug { - attrs.debug = viper.GetBool(downCase(DebugEnvVar)) + if !attrs.Debug { + attrs.Debug = viper.GetBool(downCase(DebugEnvVar)) } - if !attrs.debugAPICalls { - attrs.debugAPICalls = viper.GetBool(downCase(DebugAPICallsEnvVar)) + if !attrs.DebugAPICalls { + attrs.DebugAPICalls = viper.GetBool(downCase(DebugAPICallsEnvVar)) } - if !attrs.legacyAWSVariables { - attrs.legacyAWSVariables = viper.GetBool(downCase(LegacyAWSVariablesEnvVar)) + if !attrs.LegacyAWSVariables { + attrs.LegacyAWSVariables = viper.GetBool(downCase(LegacyAWSVariablesEnvVar)) } - if !attrs.expiryAWSVariables { - attrs.expiryAWSVariables = viper.GetBool(downCase(ExpiryAWSVariablesEnvVar)) + if !attrs.ExpiryAWSVariables { + attrs.ExpiryAWSVariables = viper.GetBool(downCase(ExpiryAWSVariablesEnvVar)) } - if !attrs.cacheAccessToken { - attrs.cacheAccessToken = viper.GetBool(downCase(CacheAccessTokenEnvVar)) + if !attrs.CacheAccessToken { + attrs.CacheAccessToken = viper.GetBool(downCase(CacheAccessTokenEnvVar)) } return attrs, nil } @@ -376,6 +405,17 @@ func downCase(s string) string { return strings.ToLower(s) } +// AuthzID -- +func (c *Config) AuthzID() string { + return c.authzID +} + +// SetAuthzID -- +func (c *Config) SetAuthzID(authzID string) error { + c.authzID = authzID + return nil +} + // AWSCredentials -- func (c *Config) AWSCredentials() string { return c.awsCredentials @@ -442,6 +482,16 @@ func (c *Config) SetCacheAccessToken(cacheAccessToken bool) error { return nil } +// Clock -- +func (c *Config) Clock() Clock { + return c.clock +} + +// SetClock -- +func (c *Config) SetClock(clock Clock) { + c.clock = clock +} + // CustomScope -- func (c *Config) CustomScope() string { return c.customScope @@ -574,6 +624,17 @@ func (c *Config) SetPrivateKey(privateKey string) error { return nil } +// KeyID -- +func (c *Config) KeyID() string { + return c.keyID +} + +// SetKeyID -- +func (c *Config) SetKeyID(keyID string) error { + c.keyID = keyID + return nil +} + // Profile -- func (c *Config) Profile() string { return c.profile @@ -744,3 +805,7 @@ awscli: fmt.Fprintf(os.Stderr, "okta.yaml is OK\n") return nil } + +type realClock struct{} + +func (realClock) Now() time.Time { return time.Now() } diff --git a/internal/m2mauth/m2mauth.go b/internal/m2mauth/m2mauth.go new file mode 100644 index 0000000..0d991a0 --- /dev/null +++ b/internal/m2mauth/m2mauth.go @@ -0,0 +1,177 @@ +/* + * Copyright (c) 2023-Present, Okta, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package m2mauth + +import ( + "crypto/ecdsa" + "crypto/rsa" + "crypto/x509" + "encoding/json" + "encoding/pem" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" + + "github.com/okta/okta-aws-cli/internal/config" + "github.com/okta/okta-aws-cli/internal/okta" + "github.com/okta/okta-aws-cli/internal/utils" + "gopkg.in/square/go-jose.v2" + "gopkg.in/square/go-jose.v2/jwt" +) + +// M2MAuthentication Object structure for headless authentication +type M2MAuthentication struct { + config *config.Config +} + +// NewM2MAuthentication New M2M Authenticator constructor +func NewM2MAuthentication(config *config.Config) (*M2MAuthentication, error) { + m := M2MAuthentication{ + config: config, + } + return &m, nil +} + +// EstablishIAMCredentials Full operation to fetch temporary IAM credentials and +// output them to preferred format. +func (m *M2MAuthentication) EstablishIAMCredentials() error { + _, err := m.AccessToken() + if err != nil { + return err + } + // WIP + // out, err := m.AssumeRole(at) (*sts.AssumeRoleWithWebIdentityOutput, error) { + // err = m.OutputCredentials(out) + return nil +} + +func (m *M2MAuthentication) createKeySigner() (jose.Signer, error) { + signerOptions := (&jose.SignerOptions{}).WithHeader("kid", m.config.KeyID()) + priv := []byte(strings.ReplaceAll(m.config.PrivateKey(), `\n`, "\n")) + + privPem, _ := pem.Decode(priv) + if privPem == nil { + return nil, errors.New("invalid private key") + } + + if privPem.Type == "RSA PRIVATE KEY" { + parsedKey, err := x509.ParsePKCS1PrivateKey(privPem.Bytes) + if err != nil { + return nil, err + } + return jose.NewSigner(jose.SigningKey{Algorithm: jose.RS256, Key: parsedKey}, signerOptions) + } + if privPem.Type == "PRIVATE KEY" { + parsedKey, err := x509.ParsePKCS8PrivateKey(privPem.Bytes) + if err != nil { + return nil, err + } + var alg jose.SignatureAlgorithm + switch parsedKey.(type) { + case *rsa.PrivateKey: + alg = jose.RS256 + case *ecdsa.PrivateKey: + alg = jose.ES256 // TODO handle ES384 or ES512 ? + default: + // TODO are either of these also valid? + // ed25519.PrivateKey: + // *ecdh.PrivateKey + return nil, fmt.Errorf("private key %q is unknown pkcs#8 format type", privPem.Type) + } + return jose.NewSigner(jose.SigningKey{Algorithm: alg, Key: parsedKey}, signerOptions) + } + + return nil, fmt.Errorf("private key %q is not pkcs#1 or pkcs#8 format", privPem.Type) +} + +func (m *M2MAuthentication) makeClientAssertion() (string, error) { + privateKeySinger, err := m.createKeySigner() + if err != nil { + return "", err + } + + tokenRequestURL := fmt.Sprintf(okta.CustomAuthzV1TokenEndpointFormat, m.config.OrgDomain(), m.config.AuthzID()) + now := m.config.Clock().Now() + claims := okta.ClientAssertionClaims{ + Subject: m.config.OIDCAppID(), + IssuedAt: jwt.NewNumericDate(now), + Expiry: jwt.NewNumericDate(now.Add(time.Hour * time.Duration(1))), + Issuer: m.config.OIDCAppID(), + Audience: tokenRequestURL, + } + + jwtBuilder := jwt.Signed(privateKeySinger).Claims(claims) + return jwtBuilder.CompactSerialize() +} + +// AccessToken Takes okta-aws-cli private key and presents a client_credentials +// flow assertion to /oauth2/{authzServerID}/v1/token to gather an access token. +func (m *M2MAuthentication) AccessToken() (*okta.AccessToken, error) { + clientAssertion, err := m.makeClientAssertion() + if err != nil { + return nil, err + } + + var tokenRequestBuff io.ReadWriter + query := url.Values{} + tokenRequestURL := fmt.Sprintf(okta.CustomAuthzV1TokenEndpointFormat, m.config.OrgDomain(), m.config.AuthzID()) + + query.Add("grant_type", "client_credentials") + query.Add("scope", m.config.CustomScope()) + query.Add("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer") + query.Add("client_assertion", clientAssertion) + tokenRequestURL += "?" + query.Encode() + tokenRequest, err := http.NewRequest("POST", tokenRequestURL, tokenRequestBuff) + if err != nil { + return nil, err + } + + tokenRequest.Header.Add("Accept", utils.ApplicationJSON) + tokenRequest.Header.Add(utils.ContentType, utils.ApplicationXFORM) + resp, err := m.config.HTTPClient().Do(tokenRequest) + if err != nil { + return nil, err + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + baseErrStr := "fetching access token received API response %q" + if err != nil { + return nil, fmt.Errorf(baseErrStr, resp.Status) + } + + var apiErr okta.APIError + err = json.NewDecoder(resp.Body).Decode(&apiErr) + if err != nil { + return nil, fmt.Errorf(baseErrStr, resp.Status) + } + + return nil, fmt.Errorf(baseErrStr+", error: %q, description: %q", resp.Status, apiErr.Error, apiErr.ErrorDescription) + } + + token := &okta.AccessToken{} + err = json.NewDecoder(resp.Body).Decode(token) + if err != nil { + return nil, err + } + + return token, nil +} diff --git a/internal/m2mauth/m2mauth_test.go b/internal/m2mauth/m2mauth_test.go new file mode 100644 index 0000000..dd945df --- /dev/null +++ b/internal/m2mauth/m2mauth_test.go @@ -0,0 +1,155 @@ +/* + * Copyright (c) 2023-Present, Okta, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package m2mauth + +import ( + "net/http" + "os" + "path" + "regexp" + "testing" + + "github.com/okta/okta-aws-cli/internal/config" + "github.com/okta/okta-aws-cli/internal/testutils" + "github.com/stretchr/testify/require" + "gopkg.in/dnaeon/go-vcr.v3/recorder" +) + +func TestMain(m *testing.M) { + var reset func() + reset = osSetEnvIfBlank("OKTA_ORG_DOMAIN", testutils.TestDomainName) + defer reset() + reset = osSetEnvIfBlank("OKTA_OIDC_CLIENT_ID", "0oaa4htg72TNrkTDr1d7") + defer reset() + reset = osSetEnvIfBlank("OKTA_AWSCLI_IAM_ROLE", "arn:aws:iam::123:role/RickRollNeverGonnaGiveYouUp") + defer reset() + reset = osSetEnvIfBlank("OKTA_AUTHZ_ID", "aus8w23r13NvyUwln1d7") + defer reset() + reset = osSetEnvIfBlank("OKTA_AWSCLI_CUSTOM_SCOPE", "okta-aws-cli") + defer reset() + reset = osSetEnvIfBlank("OKTA_AWSCLI_KEY_ID", "kid-rock") + defer reset() + + // NOTE: Okta Security this is just some random PK to unit test the client + // assertion generator in this app. PK was created with + // `openssl genrsa 512 | pbcopy` + reset = osSetEnvIfBlank("OKTA_AWSCLI_PRIVATE_KEY", ` +-----BEGIN PRIVATE KEY----- +MIIBVQIBADANBgkqhkiG9w0BAQEFAASCAT8wggE7AgEAAkEAzAZ73GY6TbcC0cQS +LQ+GfIkZxeTJjkW8+pdg0zmcGs4ZByZqp7oP02TbZ0UyLFHe8Eqik5rXR98mts5e +TuG2BwIDAQABAkEAmG2jrjdGCffYCGYnejjmLjaz5bCXkU6y8LmWIlkhMrg/F7uH +/yjmN3Hcj06F4b2DRczIIxWHpZVeFaqxvitZ6QIhAPlxhYIIpx4h+mf7cPXOlCZc +QDRqIa+pp3JH3Pgrz8mzAiEA0WNZP8acq251xTl2i+OrstH0o3YeYUmASv8bmyNs +0F0CIALSAsVunZ0cmz0zvZo55LjuUBeHn6vhyi/jmh8AN9A7AiEAoNtM1iTTeROb +4A7cFm2qGu8WnHkCr8SSjYrb/1vAnXUCIFgT6wGO6AFjQAahQlpVnqpppP9F8eSd +qrebTIkNMM8u +-----END PRIVATE KEY-----`) + defer reset() + + os.Exit(m.Run()) +} + +func osSetEnvIfBlank(key, value string) func() { + if os.Getenv(key) != "" { + return func() {} + } + _ = os.Setenv(key, value) + return func() { + _ = os.Unsetenv(key) + } +} + +func TestM2MAuthEstablishIAMCredentials(t *testing.T) { + t.Skip("TODO") +} + +// TestM2MAuthMakeClientAssertion Tests the private make client assertion method +// on m2mauth +func TestM2MAuthMakeClientAssertion(t *testing.T) { + config, teardownTest := setupTest(t) + config.SetClock(testutils.NewTestClock()) + defer teardownTest(t) + + m, err := NewM2MAuthentication(config) + require.NoError(t, err) + _, err = m.makeClientAssertion() + require.NoError(t, err) +} + +func TestM2MAuthAccessToken(t *testing.T) { + config, teardownTest := setupTest(t) + defer teardownTest(t) + + m, err := NewM2MAuthentication(config) + require.NoError(t, err) + + at, err := m.AccessToken() + require.NoError(t, err) + require.NotNil(t, at) + + require.Equal(t, "Bearer", at.TokenType) + require.Equal(t, int64(3600), at.ExpiresIn) + require.Equal(t, "okta-aws-cli", at.Scope) + require.Regexp(t, regexp.MustCompile("^eyJ"), at.AccessToken) +} + +func setupTest(t *testing.T) (*config.Config, func(t *testing.T)) { + attrs := &config.Attributes{ + OrgDomain: os.Getenv("OKTA_ORG_DOMAIN"), + OIDCAppID: os.Getenv("OKTA_OIDC_CLIENT_ID"), + AWSIAMRole: os.Getenv("OKTA_AWSCLI_IAM_ROLE"), + AuthzID: os.Getenv("OKTA_AUTHZ_ID"), + CustomScope: os.Getenv("OKTA_AWSCLI_CUSTOM_SCOPE"), + KeyID: os.Getenv("OKTA_AWSCLI_KEY_ID"), + PrivateKey: os.Getenv("OKTA_AWSCLI_PRIVATE_KEY"), + } + config, err := config.NewConfig(attrs) + require.NoError(t, err) + + rt := config.HTTPClient().Transport + vcr, err := newVCRRecorder(t, rt) + require.NoError(t, err) + rt = http.RoundTripper(vcr) + config.HTTPClient().Transport = rt + + tearDown := func(t *testing.T) { + err := vcr.Stop() + require.NoError(t, err) + } + + return config, tearDown +} + +func newVCRRecorder(t *testing.T, transport http.RoundTripper) (rec *recorder.Recorder, err error) { + dir, _ := os.Getwd() + vcrFixturesHome := path.Join(dir, "../../test/fixtures/vcr") + cassettesPath := path.Join(vcrFixturesHome, t.Name()) + rec, err = recorder.NewWithOptions(&recorder.Options{ + CassetteName: cassettesPath, + Mode: recorder.ModeRecordOnce, + SkipRequestLatency: true, // skip how vcr will mimic the real request latency that it can record allowing for fast playback + RealTransport: transport, + }) + if err != nil { + return + } + + rec.SetMatcher(testutils.VCROktaAPIRequestMatcher) + rec.AddHook(testutils.VCROktaAPIRequestHook, recorder.AfterCaptureHook) + + return +} diff --git a/internal/okta/accesstoken.go b/internal/okta/accesstoken.go new file mode 100644 index 0000000..af26db4 --- /dev/null +++ b/internal/okta/accesstoken.go @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2023-Present, Okta, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package okta + +// AccessToken Encapsulates an Okta access token +// https://developer.okta.com/docs/reference/api/oidc/#token +type AccessToken struct { + AccessToken string `json:"access_token,omitempty"` + IDToken string `json:"id_token,omitempty"` + TokenType string `json:"token_type,omitempty"` + Scope string `json:"scope,omitempty"` + ExpiresIn int64 `json:"expires_in,omitempty"` + RefreshToken string `json:"refresh_token,omitempty"` + DeviceSecret string `json:"device_secret,omitempty"` + Expiry string `json:"expiry"` +} diff --git a/internal/okta/apierror.go b/internal/okta/apierror.go new file mode 100644 index 0000000..337e9ed --- /dev/null +++ b/internal/okta/apierror.go @@ -0,0 +1,23 @@ +/* + * Copyright (c) 2023-Present, Okta, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package okta + +// APIError Wrapper for Okta API error +type APIError struct { + Error string `json:"error,omitempty"` + ErrorDescription string `json:"error_description,omitempty"` +} diff --git a/internal/okta/client_assertion_claims.go b/internal/okta/client_assertion_claims.go new file mode 100644 index 0000000..c006d09 --- /dev/null +++ b/internal/okta/client_assertion_claims.go @@ -0,0 +1,31 @@ +/* + * Copyright (c) 2023-Present, Okta, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package okta + +import ( + "gopkg.in/square/go-jose.v2/jwt" +) + +// ClientAssertionClaims Okta Client Assertion Claims model +type ClientAssertionClaims struct { + Issuer string `json:"iss,omitempty"` + Subject string `json:"sub,omitempty"` + Audience string `json:"aud,omitempty"` + Expiry *jwt.NumericDate `json:"exp,omitempty"` + IssuedAt *jwt.NumericDate `json:"iat,omitempty"` + ID string `json:"jti,omitempty"` +} diff --git a/internal/okta/okta.go b/internal/okta/okta.go new file mode 100644 index 0000000..5669e44 --- /dev/null +++ b/internal/okta/okta.go @@ -0,0 +1,25 @@ +/* + * Copyright (c) 2023-Present, Okta, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package okta + +const ( + // OAuthV1TokenEndpointFormat sprintf format string for base oauth server token endpoint + OAuthV1TokenEndpointFormat = "https://%s/oauth2/v1/token" + + // CustomAuthzV1TokenEndpointFormat sprintf format string for custom oauth server token endpoint + CustomAuthzV1TokenEndpointFormat = "https://%s/oauth2/%s/v1/token" +) diff --git a/internal/sessiontoken/sessiontoken.go b/internal/sessiontoken/sessiontoken.go index 83787e8..99cc00e 100644 --- a/internal/sessiontoken/sessiontoken.go +++ b/internal/sessiontoken/sessiontoken.go @@ -47,20 +47,18 @@ import ( oaws "github.com/okta/okta-aws-cli/internal/aws" boff "github.com/okta/okta-aws-cli/internal/backoff" "github.com/okta/okta-aws-cli/internal/config" + "github.com/okta/okta-aws-cli/internal/okta" "github.com/okta/okta-aws-cli/internal/output" + "github.com/okta/okta-aws-cli/internal/utils" ) const ( amazonAWS = "amazon_aws" accept = "Accept" - applicationJSON = "application/json" - applicationXWwwForm = "application/x-www-form-urlencoded" - contentType = "Content-Type" userAgent = "User-Agent" nameKey = "name" saml2Attribute = "saml2:attribute" samlAttributesRole = "https://aws.amazon.com/SAML/Attributes/Role" - oauthV1TokenEndpointFmt = "https://%s/oauth2/v1/token" askIDPError = "error asking for IdP selection: %w" noRoleError = "provider %q has no roles to choose from" noIDPsError = "no IdPs to choose from" @@ -90,19 +88,6 @@ type SessionToken struct { fedAppAlreadySelected bool } -// accessToken Encapsulates an Okta access token -// https://developer.okta.com/docs/reference/api/oidc/#token -type accessToken struct { - AccessToken string `json:"access_token,omitempty"` - IDToken string `json:"id_token,omitempty"` - TokenType string `json:"token_type,omitempty"` - Scope string `json:"scope,omitempty"` - ExpiresIn int64 `json:"expires_in,omitempty"` - RefreshToken string `json:"refresh_token,omitempty"` - DeviceSecret string `json:"device_secret,omitempty"` - Expiry string `json:"expiry"` -} - // deviceAuthorization Encapsulates Okta API result to // /oauth2/v1/device/authorize call type deviceAuthorization struct { @@ -129,12 +114,6 @@ type oktaApplication struct { } `json:"settings"` } -// apiError Wrapper for Okta API error -type apiError struct { - Error string `json:"error,omitempty"` - ErrorDescription string `json:"error_description,omitempty"` -} - // idpAndRole IdP and role pairs type idpAndRole struct { idp string @@ -168,7 +147,7 @@ func NewSessionToken(config *config.Config) (token *SessionToken, err error) { // token. func (s *SessionToken) EstablishToken() error { clientID := s.config.OIDCAppID() - var at *accessToken + var at *okta.AccessToken var apps []*oktaApplication var err error at = s.cachedAccessToken() @@ -309,7 +288,7 @@ func (s *SessionToken) selectFedApp(apps []*oktaApplication) (string, error) { return idps[selected].ID, nil } -func (s *SessionToken) establishTokenWithFedAppID(clientID, fedAppID string, at *accessToken) error { +func (s *SessionToken) establishTokenWithFedAppID(clientID, fedAppID string, at *okta.AccessToken) error { at, err := s.fetchSSOWebToken(clientID, fedAppID, at) if err != nil { return err @@ -575,7 +554,7 @@ func (s *SessionToken) extractIDPAndRolesMapFromAssertion(encoded string) (irmap } // fetchSAMLAssertion Gets the SAML assertion from Okta API /login/token/sso -func (s *SessionToken) fetchSAMLAssertion(at *accessToken) (assertion string, err error) { +func (s *SessionToken) fetchSAMLAssertion(at *okta.AccessToken) (assertion string, err error) { params := url.Values{"token": {at.AccessToken}} apiURL := fmt.Sprintf("https://%s/login/token/sso?%s", s.config.OrgDomain(), params.Encode()) @@ -605,8 +584,8 @@ func (s *SessionToken) fetchSAMLAssertion(at *accessToken) (assertion string, er // fetchSSOWebToken see: // https://developer.okta.com/docs/reference/api/oidc/#token -func (s *SessionToken) fetchSSOWebToken(clientID, awsFedAppID string, at *accessToken) (token *accessToken, err error) { - apiURL := fmt.Sprintf(oauthV1TokenEndpointFmt, s.config.OrgDomain()) +func (s *SessionToken) fetchSSOWebToken(clientID, awsFedAppID string, at *okta.AccessToken) (token *okta.AccessToken, err error) { + apiURL := fmt.Sprintf(okta.OAuthV1TokenEndpointFormat, s.config.OrgDomain()) data := url.Values{ "client_id": {clientID}, @@ -624,8 +603,8 @@ func (s *SessionToken) fetchSSOWebToken(clientID, awsFedAppID string, at *access if err != nil { return nil, err } - req.Header.Add(accept, applicationJSON) - req.Header.Add(contentType, applicationXWwwForm) + req.Header.Add(accept, utils.ApplicationJSON) + req.Header.Add(utils.ContentType, utils.ApplicationXFORM) req.Header.Add(userAgent, agent.NewUserAgent(config.Version).String()) resp, err := s.config.HTTPClient().Do(req) @@ -639,7 +618,7 @@ func (s *SessionToken) fetchSSOWebToken(clientID, awsFedAppID string, at *access return nil, fmt.Errorf(baseErrStr, resp.Status) } - var apiErr apiError + var apiErr okta.APIError err = json.NewDecoder(resp.Body).Decode(&apiErr) if err != nil { return nil, fmt.Errorf(baseErrStr, resp.Status) @@ -648,7 +627,7 @@ func (s *SessionToken) fetchSSOWebToken(clientID, awsFedAppID string, at *access return nil, fmt.Errorf(baseErrStr+", error: %q, description: %q", resp.Status, apiErr.Error, apiErr.ErrorDescription) } - token = &accessToken{} + token = &okta.AccessToken{} err = json.NewDecoder(resp.Body).Decode(token) if err != nil { return nil, err @@ -695,7 +674,7 @@ func (s *SessionToken) promptAuthentication(da *deviceAuthorization) { // after getting anything other than a 403 on /api/v1/apps will be wrapped as as // an error that is related having multiple fed apps available. Requires // assoicated OIDC app has been granted okta.apps.read to its scope. -func (s *SessionToken) listFedApps(clientID string, at *accessToken) (apps []*oktaApplication, err error) { +func (s *SessionToken) listFedApps(clientID string, at *okta.AccessToken) (apps []*oktaApplication, err error) { apiURL, err := url.Parse(fmt.Sprintf("https://%s/api/v1/apps", s.config.OrgDomain())) if err != nil { return nil, err @@ -710,8 +689,8 @@ func (s *SessionToken) listFedApps(clientID string, at *accessToken) (apps []*ok return nil, err } - req.Header.Add(accept, applicationJSON) - req.Header.Add(contentType, applicationJSON) + req.Header.Add(accept, utils.ApplicationJSON) + req.Header.Add(utils.ContentType, utils.ApplicationJSON) req.Header.Add(userAgent, agent.NewUserAgent(config.Version).String()) req.Header.Add("Authorization", fmt.Sprintf("%s %s", at.TokenType, at.AccessToken)) resp, err := s.config.HTTPClient().Do(req) @@ -751,15 +730,15 @@ func (s *SessionToken) listFedApps(clientID string, at *accessToken) (apps []*ok // fetchAccessToken see: // https://developer.okta.com/docs/reference/api/oidc/#token -func (s *SessionToken) fetchAccessToken(clientID string, deviceAuth *deviceAuthorization) (at *accessToken, err error) { - apiURL := fmt.Sprintf(oauthV1TokenEndpointFmt, s.config.OrgDomain()) +func (s *SessionToken) fetchAccessToken(clientID string, deviceAuth *deviceAuthorization) (at *okta.AccessToken, err error) { + apiURL := fmt.Sprintf(okta.OAuthV1TokenEndpointFormat, s.config.OrgDomain()) req, err := http.NewRequest(http.MethodPost, apiURL, nil) if err != nil { return nil, err } - req.Header.Add(accept, applicationJSON) - req.Header.Add(contentType, applicationXWwwForm) + req.Header.Add(accept, utils.ApplicationJSON) + req.Header.Add(utils.ContentType, utils.ApplicationXFORM) req.Header.Add(userAgent, agent.NewUserAgent(config.Version).String()) var bodyBytes []byte @@ -806,7 +785,7 @@ func (s *SessionToken) fetchAccessToken(clientID string, deviceAuth *deviceAutho return nil, err } - at = &accessToken{} + at = &okta.AccessToken{} err = json.NewDecoder(bytes.NewReader(bodyBytes)).Decode(at) if err != nil { return nil, err @@ -828,8 +807,8 @@ func (s *SessionToken) authorize(clientID string) (*deviceAuthorization, error) if err != nil { return nil, err } - req.Header.Add(accept, applicationJSON) - req.Header.Add(contentType, applicationXWwwForm) + req.Header.Add(accept, utils.ApplicationJSON) + req.Header.Add(utils.ContentType, utils.ApplicationXFORM) req.Header.Add(userAgent, agent.NewUserAgent(config.Version).String()) resp, err := s.config.HTTPClient().Do(req) @@ -840,8 +819,8 @@ func (s *SessionToken) authorize(clientID string) (*deviceAuthorization, error) return nil, fmt.Errorf("authorize received API response %q", resp.Status) } - ct := resp.Header.Get(contentType) - if !strings.Contains(ct, applicationJSON) { + ct := resp.Header.Get(utils.ContentType) + if !strings.Contains(ct, utils.ApplicationJSON) { return nil, fmt.Errorf("authorize non-JSON API response content type %q", ct) } @@ -913,8 +892,8 @@ func findSAMLRoleAttibute(n *html.Node) (node *html.Node, found bool) { return nil, false } -func apiErr(bodyBytes []byte) (ae *apiError, err error) { - ae = &apiError{} +func apiErr(bodyBytes []byte) (ae *okta.APIError, err error) { + ae = &okta.APIError{} err = json.NewDecoder(bytes.NewReader(bodyBytes)).Decode(ae) return } @@ -934,7 +913,7 @@ func (s *SessionToken) isClassicOrg() bool { if err != nil { return false } - req.Header.Add(accept, applicationJSON) + req.Header.Add(accept, utils.ApplicationJSON) req.Header.Add(userAgent, agent.NewUserAgent(config.Version).String()) resp, err := s.config.HTTPClient().Do(req) @@ -960,7 +939,7 @@ func (s *SessionToken) isClassicOrg() bool { // cachedAccessToken will returned the cached access token if it exists and is // not expired. -func (s *SessionToken) cachedAccessToken() (at *accessToken) { +func (s *SessionToken) cachedAccessToken() (at *okta.AccessToken) { homeDir, err := os.UserHomeDir() if err != nil { return @@ -971,7 +950,7 @@ func (s *SessionToken) cachedAccessToken() (at *accessToken) { return } - _at := accessToken{} + _at := okta.AccessToken{} err = json.Unmarshal(atJSON, &_at) if err != nil { return @@ -991,7 +970,7 @@ func (s *SessionToken) cachedAccessToken() (at *accessToken) { // cacheAccessToken will cache the access token for later use if enabled. Silent // if fails. -func (s *SessionToken) cacheAccessToken(at *accessToken) { +func (s *SessionToken) cacheAccessToken(at *okta.AccessToken) { if !s.config.CacheAccessToken() { return } diff --git a/internal/testutils/testutils.go b/internal/testutils/testutils.go new file mode 100644 index 0000000..d9e4926 --- /dev/null +++ b/internal/testutils/testutils.go @@ -0,0 +1,157 @@ +/* + * Copyright (c) 2023-Present, Okta, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package testutils + +import ( + "bytes" + "encoding/json" + "io" + "log" + "net/http" + "os" + "reflect" + "regexp" + "strings" + "time" + + "github.com/okta/okta-aws-cli/internal/config" + "github.com/okta/okta-aws-cli/internal/utils" + "gopkg.in/dnaeon/go-vcr.v3/cassette" +) + +const ( + // TestDomainName Fake domain name for tests / recordings + TestDomainName = "test.dne-okta.com" + // ClientAssertionNameValueRE client assertion regular expression format + ClientAssertionNameValueRE = "client_assertion=[^&]+" + // ClientAssertionNameValueValue client asserver name and value url encoded format + ClientAssertionNameValueValue = "client_assertion=abc123" +) + +// TestClock Is a test clock of the Clock interface +type TestClock struct{} + +// Now The test clock's now +func (TestClock) Now() time.Time { return time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC) } + +// NewTestClock New test clock constructor +func NewTestClock() config.Clock { + return &TestClock{} +} + +// VCROktaAPIRequestHook Modifies VCR recordings. +func VCROktaAPIRequestHook(i *cassette.Interaction) error { + // need to scrub Okta org strings and rewrite as test.dne-okta.com so that + // HTTP requests that escape VCR are bad. + + // test.dne-okta.com + vcrHostname := TestDomainName + // example.okta.com + orgHostname := os.Getenv("OKTA_ORG_DOMAIN") + + // save disk space, clean up what gets written to disk + i.Request.Headers.Del("User-Agent") + deleteResponseHeaders := []string{ + "Cache-Control", + "Content-Security-Policy", + "Content-Security-Policy-Report-Only", + "duration", + "Expect-Ct", + "Expires", + "P3p", + "Pragma", + "Public-Key-Pins-Report-Only", + "Server", + "Set-Cookie", + "Strict-Transport-Security", + "Vary", + } + for _, header := range deleteResponseHeaders { + i.Response.Headers.Del(header) + } + for name := range i.Response.Headers { + // delete all X-headers + if strings.HasPrefix(name, "X-") { + i.Response.Headers.Del(name) + continue + } + } + + // scrub client assertion out of token requests + m := regexp.MustCompile(ClientAssertionNameValueRE) + i.Request.URL = m.ReplaceAllString(i.Request.URL, ClientAssertionNameValueValue) + + // %s/example.okta.com/test.dne-okta.com/ + i.Request.Host = strings.ReplaceAll(i.Request.Host, orgHostname, vcrHostname) + + // %s/example.okta.com/test.dne-okta.com/ + i.Request.URL = strings.ReplaceAll(i.Request.URL, orgHostname, vcrHostname) + + // %s/example.okta.com/test.dne-okta.com/ + i.Request.Body = strings.ReplaceAll(i.Request.Body, orgHostname, vcrHostname) + + return nil +} + +// VCROktaAPIRequestMatcher Defines how VCR will match requests to responses. +func VCROktaAPIRequestMatcher(r *http.Request, i cassette.Request) bool { + // scrub access token for lookup + if r.URL.RawQuery != "" { + m := regexp.MustCompile(ClientAssertionNameValueRE) + r.URL.RawQuery = m.ReplaceAllString(r.URL.RawQuery, ClientAssertionNameValueValue) + } + // scrub host for lookup + r.URL.Host = TestDomainName + + // Default matcher compares method and URL only + if !cassette.DefaultMatcher(r, i) { + return false + } + // TODO: there might be header information we could inspect to make this more precise + if r.Body == nil { + return true + } + + var b bytes.Buffer + if _, err := b.ReadFrom(r.Body); err != nil { + log.Printf("[DEBUG] Failed to read request body from cassette: %v", err) + return false + } + r.Body = io.NopCloser(&b) + reqBody := b.String() + // If body matches identically, we are done + if reqBody == i.Body { + return true + } + + // JSON might be the same, but reordered. Try parsing json and comparing + contentType := r.Header.Get(utils.ContentType) + if strings.Contains(contentType, utils.ApplicationJSON) { + var reqJSON, cassetteJSON interface{} + if err := json.Unmarshal([]byte(reqBody), &reqJSON); err != nil { + log.Printf("[DEBUG] Failed to unmarshall request json: %v", err) + return false + } + if err := json.Unmarshal([]byte(i.Body), &cassetteJSON); err != nil { + log.Printf("[DEBUG] Failed to unmarshall cassette json: %v", err) + return false + } + return reflect.DeepEqual(reqJSON, cassetteJSON) + } + + return true +} diff --git a/internal/utils/utils.go b/internal/utils/utils.go new file mode 100644 index 0000000..65147dd --- /dev/null +++ b/internal/utils/utils.go @@ -0,0 +1,26 @@ +/* + * Copyright (c) 2023-Present, Okta, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package utils + +const ( + // ContentType http header content type + ContentType = "Content-Type" + // ApplicationJSON content value for json + ApplicationJSON = "application/json" + // ApplicationXFORM content type value for web form + ApplicationXFORM = "application/x-www-form-urlencoded" +) diff --git a/test/fixtures/vcr/TestM2MAuthAccessToken.yaml b/test/fixtures/vcr/TestM2MAuthAccessToken.yaml new file mode 100644 index 0000000..99864b5 --- /dev/null +++ b/test/fixtures/vcr/TestM2MAuthAccessToken.yaml @@ -0,0 +1,42 @@ +--- +version: 2 +interactions: + - id: 0 + request: + proto: HTTP/1.1 + proto_major: 1 + proto_minor: 1 + content_length: 0 + transfer_encoding: [] + trailer: {} + host: test.dne-okta.com + remote_addr: "" + request_uri: "" + body: "" + form: {} + headers: + Accept: + - application/json + Content-Type: + - application/x-www-form-urlencoded + url: https://test.dne-okta.com/oauth2/aus8w23r13NvyUwln1d7/v1/token?client_assertion=abc123&client_assertion_type=urn%3Aietf%3Aparams%3Aoauth%3Aclient-assertion-type%3Ajwt-bearer&grant_type=client_credentials&scope=okta-aws-cli + method: POST + response: + proto: HTTP/2.0 + proto_major: 2 + proto_minor: 0 + transfer_encoding: [] + trailer: {} + content_length: -1 + uncompressed: false + body: '{"token_type":"Bearer","expires_in":3600,"access_token":"eyJraWQiOiJjRnZ0bHRBbzF1cVhMT0ZLOGxFTVA2a3czd25pVVRHMVFSckhvQXBNTkF3IiwiYWxnIjoiUlMyNTYifQ.eyJ2ZXIiOjEsImp0aSI6IkFULjYxZTdVRU1oUGZWUkQ0LU51Wm9TbUsxSDJ2VGpjbXp2Njl4ZFd0VXpBVkkiLCJpc3MiOiJodHRwczovL21tb25kcmFnb24tYXdzLWNsaS0wMC5va3RhcHJldmlldy5jb20vb2F1dGgyL2F1czh3MjNyMTNOdnlVd2xuMWQ3IiwiYXVkIjoiaHR0cHM6Ly9va3RhLWF3cy1jbGktYXV0aG9yaXplciIsImlhdCI6MTY5NTc2NzAzNCwiZXhwIjoxNjk1NzcwNjM0LCJjaWQiOiIwb2FhNGh0ZzcyVE5ya1REcjFkNyIsInNjcCI6WyJva3RhLWF3cy1jbGkiXSwic3ViIjoiMG9hYTRodGc3MlROcmtURHIxZDcifQ.jE6sEw1acXo_pccQUaOXrT4uQ0KI9fLYKHsh23aCsXPBrdfaVYe_yEPdZM7GWg3VYpG9VQVo-I26IKb88Nqnxw11ABMIIglHXlUx0AJHHPZP7PXi8p91y0WG7lDoU2seiX9ce8DXX83R831qLSbQImUOKOz9aNemmvSzwPvDPnjnWNQq_Dmn_MDFiaS4cqMcWB_d_SFVAFVoa-ZC-Rli0kZ63-0ZAtmyv8unHAd1eLCyq3eikeFKXRuSKaAlAgdix2OUHnC9IL_gym9xiZDXDqASmKOqRdIcJ6Q0vn8ujvKwcO_LYPAZkkfkVkDeMEvm_ee43jgcPNF-xdmLJ3YnWg","scope":"okta-aws-cli"}' + headers: + Content-Type: + - application/json + Date: + - Tue, 26 Sep 2023 22:23:54 GMT + Report-To: + - '{"group":"csp","max_age":31536000,"endpoints":[{"url":"https://oktacsp.report-uri.com/a/t/g"}],"include_subdomains":true}' + status: 200 OK + code: 200 + duration: 620.35651ms diff --git a/test/fixtures/vcr/TestM2MAuthMakeClientAssertion.yaml b/test/fixtures/vcr/TestM2MAuthMakeClientAssertion.yaml new file mode 100644 index 0000000..2797c38 --- /dev/null +++ b/test/fixtures/vcr/TestM2MAuthMakeClientAssertion.yaml @@ -0,0 +1,3 @@ +--- +version: 2 +interactions: []