Skip to content

Commit

Permalink
Merge pull request #1864 from vmware-tanzu/cli_use_cached_access_token
Browse files Browse the repository at this point in the history
login oidc cmd checks access token expiry before doing token exchange
  • Loading branch information
cfryanr authored Feb 9, 2024
2 parents 492dfa8 + dce9409 commit 485b227
Show file tree
Hide file tree
Showing 6 changed files with 365 additions and 150 deletions.
14 changes: 7 additions & 7 deletions cmd/pinniped/cmd/login_oidc.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2020-2023 the Pinniped contributors. All Rights Reserved.
// Copyright 2020-2024 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

package cmd
Expand Down Expand Up @@ -241,12 +241,12 @@ func runOIDCLogin(cmd *cobra.Command, deps oidcLoginCommandDeps, flags oidcLogin
}

pLogger.Debug("Performing OIDC login", "issuer", flags.issuer, "client id", flags.clientID)
// Do the basic login to get an OIDC token.
// Do the basic login to get an OIDC token. Although this can return several tokens, we only need the ID token here.
token, err := deps.login(flags.issuer, flags.clientID, opts...)
if err != nil {
return fmt.Errorf("could not complete Pinniped login: %w", err)
}
cred := tokenCredential(token)
cred := tokenCredential(token.IDToken)

// If the concierge was configured, exchange the credential for a separate short-lived, cluster-specific credential.
if concierge != nil {
Expand Down Expand Up @@ -344,18 +344,18 @@ func makeClient(caBundlePaths []string, caBundleData []string) (*http.Client, er
return phttp.Default(pool), nil
}

func tokenCredential(token *oidctypes.Token) *clientauthv1beta1.ExecCredential {
func tokenCredential(idToken *oidctypes.IDToken) *clientauthv1beta1.ExecCredential {
cred := clientauthv1beta1.ExecCredential{
TypeMeta: metav1.TypeMeta{
Kind: "ExecCredential",
APIVersion: "client.authentication.k8s.io/v1beta1",
},
Status: &clientauthv1beta1.ExecCredentialStatus{
Token: token.IDToken.Token,
Token: idToken.Token,
},
}
if !token.IDToken.Expiry.IsZero() {
cred.Status.ExpirationTimestamp = &token.IDToken.Expiry
if !idToken.Expiry.IsZero() {
cred.Status.ExpirationTimestamp = &idToken.Expiry
}
return &cred
}
Expand Down
4 changes: 2 additions & 2 deletions cmd/pinniped/cmd/login_static.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2020-2023 the Pinniped contributors. All Rights Reserved.
// Copyright 2020-2024 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

package cmd
Expand Down Expand Up @@ -133,7 +133,7 @@ func runStaticLogin(cmd *cobra.Command, deps staticLoginDeps, flags staticLoginP
return fmt.Errorf("--token-env variable %q is empty", flags.staticTokenEnvName)
}
}
cred := tokenCredential(&oidctypes.Token{IDToken: &oidctypes.IDToken{Token: token}})
cred := tokenCredential(&oidctypes.IDToken{Token: token})

// Look up cached credentials based on a hash of all the CLI arguments, the current token value, and the cluster info.
cacheKey := struct {
Expand Down
7 changes: 2 additions & 5 deletions internal/federationdomain/endpoints/auth/auth_handler_test.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2020-2023 the Pinniped contributors. All Rights Reserved.
// Copyright 2020-2024 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

package auth
Expand Down Expand Up @@ -366,10 +366,7 @@ func TestAuthorizationEndpoint(t *testing.T) { //nolint:gocyclo
sadCSRFGenerator := func() (csrftoken.CSRFToken, error) { return "", fmt.Errorf("some csrf generator error") }
sadPKCEGenerator := func() (pkce.Code, error) { return "", fmt.Errorf("some PKCE generator error") }
sadNonceGenerator := func() (nonce.Nonce, error) { return "", fmt.Errorf("some nonce generator error") }

// This is the PKCE challenge which is calculated as base64(sha256("test-pkce")). For example:
// $ echo -n test-pkce | shasum -a 256 | cut -d" " -f1 | xxd -r -p | base64 | cut -d"=" -f1
expectedUpstreamCodeChallenge := "VVaezYqum7reIhoavCHD1n2d-piN3r_mywoYj7fCR7g"
expectedUpstreamCodeChallenge := testutil.SHA256("test-pkce")

var stateEncoderHashKey = []byte("fake-hash-secret")
var stateEncoderBlockKey = []byte("0123456789ABCDEF") // block encryption requires 16/24/32 bytes for AES
Expand Down
71 changes: 55 additions & 16 deletions pkg/oidcclient/login.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2020-2023 the Pinniped contributors. All Rights Reserved.
// Copyright 2020-2024 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

// Package oidcclient implements a CLI OIDC login flow.
Expand Down Expand Up @@ -44,10 +44,16 @@ import (

const (
// minIDTokenValidity is the minimum amount of time that a cached ID token must be still be valid to be considered.
// This is non-zero to ensure that most of the time, your ID token won't expire in the middle of a multi-step k8s
// This is non-zero to ensure that most of the time, your ID token won't expire in the middle of a multistep k8s
// API operation.
minIDTokenValidity = 10 * time.Minute

// minAccessTokenValidity is the minimum amount of time that a cached access token must be still be valid
// to be considered.
// This is non-zero to ensure that most of the time, your access token won't expire before we submit it for
// RFC8693 token exchange.
minAccessTokenValidity = 10 * time.Second

// httpRequestTimeout is the timeout for operations that involve one (or a few) non-interactive HTTPS requests.
// Since these don't involve any user interaction, they should always be roughly as fast as network latency.
httpRequestTimeout = 60 * time.Second
Expand Down Expand Up @@ -328,22 +334,53 @@ func Login(issuer string, clientID string, opts ...Option) (*oidctypes.Token, er
}

// Do the basic login to get an access and ID token issued to our main client ID.
baseToken, err := h.baseLogin()
token, err := h.baseLogin()
if err != nil {
return nil, err
}

// If there is no requested audience, or the requested audience matches the one we got, we're done.
if h.requestedAudience == "" || (baseToken.IDToken != nil && h.requestedAudience == baseToken.IDToken.Claims["aud"]) {
return baseToken, err
// Perform the RFC8693 token exchange, if needed. Note that the new ID token returned by this exchange
// does not need to be cached because the new ID token is intended to be a very short-lived token.
if h.needRFC8693TokenExchange(token) {
token, err = h.tokenExchangeRFC8693(token)
if err != nil {
return nil, fmt.Errorf("failed to exchange token: %w", err)
}
}

// Perform the RFC8693 token exchange.
exchangedToken, err := h.tokenExchangeRFC8693(baseToken)
if err != nil {
return nil, fmt.Errorf("failed to exchange token: %w", err)
return token, nil
}

func (h *handlerState) needRFC8693TokenExchange(token *oidctypes.Token) bool {
// Need a new ID token if there is a requested audience value and any of the following are true...
return h.requestedAudience != "" &&
// we don't have an ID token
(token.IDToken == nil ||
// or, our current ID token has expired or is close to expiring
idTokenExpiredOrCloseToExpiring(token.IDToken) ||
// or, our current ID token has a different audience
(h.requestedAudience != token.IDToken.Claims["aud"]))
}

func (h *handlerState) tokenValidForNearFuture(token *oidctypes.Token) (bool, string) {
if token == nil {
return false, ""
}
return exchangedToken, nil
// If we plan to do an RFC8693 token exchange, then we need an access token that will still be valid when we do the
// exchange (which will happen momentarily). Otherwise, we need an ID token that will be valid for a little while
// (long enough for multistep k8s API operations).
if h.needRFC8693TokenExchange(token) {
return !accessTokenExpiredOrCloseToExpiring(token.AccessToken), "access_token"
}
return !idTokenExpiredOrCloseToExpiring(token.IDToken), "id_token"
}

func accessTokenExpiredOrCloseToExpiring(accessToken *oidctypes.AccessToken) bool {
return accessToken == nil || time.Until(accessToken.Expiry.Time) <= minAccessTokenValidity
}

func idTokenExpiredOrCloseToExpiring(idToken *oidctypes.IDToken) bool {
return idToken == nil || time.Until(idToken.Expiry.Time) <= minIDTokenValidity
}

func (h *handlerState) baseLogin() (*oidctypes.Token, error) {
Expand All @@ -360,10 +397,11 @@ func (h *handlerState) baseLogin() (*oidctypes.Token, error) {
UpstreamProviderName: h.upstreamIdentityProviderName,
}

// If the ID token is still valid for a bit, return it immediately and skip the rest of the flow.
// If the cached tokens include the token type that we need, and that token is still valid for a bit,
// return the cached tokens immediately and skip the rest of the flow.
cached := h.cache.GetToken(cacheKey)
if cached != nil && cached.IDToken != nil && time.Until(cached.IDToken.Expiry.Time) > minIDTokenValidity {
h.logger.V(plog.KlogLevelDebug).Info("Pinniped: Found unexpired cached token.")
if valid, whichTokenWasValid := h.tokenValidForNearFuture(cached); valid {
h.logger.V(plog.KlogLevelDebug).Info("Pinniped: Found unexpired cached token.", "type", whichTokenWasValid)
return cached, nil
}

Expand All @@ -378,13 +416,14 @@ func (h *handlerState) baseLogin() (*oidctypes.Token, error) {
if err != nil {
return nil, err
}
// If we got a fresh token, we can update the cache and return it. Otherwise we fall through to the full refresh flow.
// If we got a fresh token, update the cache and return it. Otherwise, fall through to the full login flow.
if freshToken != nil {
h.cache.PutToken(cacheKey, freshToken)
return freshToken, nil
}
}

// We couldn't refresh, so now we need to perform a fresh login attempt.
// Prepare the common options for the authorization URL. We don't have the redirect URL yet though.
authorizeOptions := []oauth2.AuthCodeOption{
oauth2.AccessTypeOffline,
Expand Down Expand Up @@ -833,7 +872,7 @@ func (h *handlerState) tokenExchangeRFC8693(baseToken *oidctypes.Token) (*oidcty
}

func (h *handlerState) handleRefresh(ctx context.Context, refreshToken *oidctypes.RefreshToken) (*oidctypes.Token, error) {
h.logger.V(plog.KlogLevelDebug).Info("Pinniped: Refreshing cached token.")
h.logger.V(plog.KlogLevelDebug).Info("Pinniped: Refreshing cached tokens.")
upstreamOIDCIdentityProvider := h.getProvider(h.oauth2Config, h.provider, h.httpClient)

refreshed, err := upstreamOIDCIdentityProvider.PerformRefresh(ctx, refreshToken.Token)
Expand Down
Loading

0 comments on commit 485b227

Please sign in to comment.