Skip to content

Commit

Permalink
Log and monitor failures to validate access tokens
Browse files Browse the repository at this point in the history
Signed-off-by: Andrew Dye <[email protected]>
  • Loading branch information
katrogan authored and andrewwdye committed Sep 24, 2024
1 parent 51acfd0 commit c127997
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 16 deletions.
29 changes: 27 additions & 2 deletions flyteadmin/auth/authzserver/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
fositeOAuth2 "github.com/ory/fosite/handler/oauth2"
"github.com/ory/fosite/storage"
"github.com/ory/fosite/token/jwt"
"github.com/prometheus/client_golang/prometheus"
"k8s.io/apimachinery/pkg/util/sets"

"github.com/flyteorg/flyte/flyteadmin/auth"
Expand All @@ -24,6 +25,7 @@ import (
"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/service"
"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core"
"github.com/flyteorg/flyte/flytestdlib/logger"
"github.com/flyteorg/flyte/flytestdlib/promutils"
)

const (
Expand All @@ -33,12 +35,18 @@ const (
KeyIDClaim = "key_id"
)

type providerMetrics struct {
InvalidTokens prometheus.Counter
ExpiredTokens prometheus.Counter
}

// Provider implements OAuth2 Authorization Server.
type Provider struct {
fosite.OAuth2Provider
cfg config.AuthorizationServer
publicKey []rsa.PublicKey
keySet jwk.Set
metrics providerMetrics
}

func (p Provider) PublicKeys() []rsa.PublicKey {
Expand Down Expand Up @@ -111,23 +119,36 @@ func (p Provider) ValidateAccessToken(ctx context.Context, expectedAudience, tok
})

if err != nil {
logger.Infof(ctx, "failed to parse token for audience '%s'. Error: %v", expectedAudience, err)
return nil, err
}

if !parsedToken.Valid {
if ve, ok := err.(*jwtgo.ValidationError); ok && ve.Is(jwtgo.ErrTokenExpired) {
logger.Infof(ctx, "parsed token for audience '%s' is expired", expectedAudience)
p.metrics.ExpiredTokens.Inc()
} else {
logger.Infof(ctx, "parsed token for audience '%s' is invalid: %+v", expectedAudience, err)
p.metrics.InvalidTokens.Inc()
}
return nil, fmt.Errorf("parsed token is invalid")
}

claimsRaw := parsedToken.Claims.(jwtgo.MapClaims)
return verifyClaims(sets.NewString(expectedAudience), claimsRaw)
identityCtx, err := verifyClaims(sets.NewString(expectedAudience), claimsRaw)
if err != nil {
logger.Infof(ctx, "failed to verify claims for audience: '%s'. Error: %v", expectedAudience, err)
return nil, err
}
return identityCtx, nil
}

// NewProvider creates a new OAuth2 Provider that is able to do OAuth 2-legged and 3-legged flows. It'll lookup
// config.SecretNameClaimSymmetricKey and config.SecretNameTokenSigningRSAKey secrets from the secret manager to use to
// sign and generate hashes for tokens. The RSA Private key is expected to be in PEM format with the public key embedded.
// Use auth.GetInitSecretsCommand() to generate new valid secrets that will be accepted by this provider.
// The config.SecretNameClaimSymmetricKey must be a 32-bytes long key in Base64Encoding.
func NewProvider(ctx context.Context, cfg config.AuthorizationServer, sm core.SecretManager) (Provider, error) {
func NewProvider(ctx context.Context, cfg config.AuthorizationServer, sm core.SecretManager, scope promutils.Scope) (Provider, error) {
// fosite requires four parameters for the server to get up and running:
// 1. config - for any enforcement you may desire, you can do this using `compose.Config`. You like PKCE, enforce it!
// 2. store - no auth service is generally useful unless it can remember clients and users.
Expand Down Expand Up @@ -230,5 +251,9 @@ func NewProvider(ctx context.Context, cfg config.AuthorizationServer, sm core.Se
OAuth2Provider: oauth2Provider,
publicKey: publicKeys,
keySet: keysSet,
metrics: providerMetrics{
ExpiredTokens: scope.MustNewCounter("expired_token", "The number of expired tokens"),
InvalidTokens: scope.MustNewCounter("invalid_tokens", "The number of invalid tokens"),
},
}, nil
}
7 changes: 4 additions & 3 deletions flyteadmin/auth/authzserver/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"github.com/flyteorg/flyte/flyteadmin/auth/config"
"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/service"
"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core/mocks"
"github.com/flyteorg/flyte/flytestdlib/promutils"
)

func newMockProvider(t testing.TB) (Provider, auth.SecretsSet) {
Expand All @@ -36,7 +37,7 @@ func newMockProvider(t testing.TB) (Provider, auth.SecretsSet) {
sm.OnGet(ctx, config.SecretNameTokenSigningRSAKey).Return(buf.String(), nil)
sm.OnGet(ctx, config.SecretNameOldTokenSigningRSAKey).Return(buf.String(), nil)

p, err := NewProvider(ctx, config.DefaultConfig.AppAuth.SelfAuthServer, sm)
p, err := NewProvider(ctx, config.DefaultConfig.AppAuth.SelfAuthServer, sm, promutils.NewTestScope())
assert.NoError(t, err)
return p, secrets
}
Expand All @@ -58,7 +59,7 @@ func newInvalidMockProvider(ctx context.Context, t *testing.T, secrets auth.Secr
sm.OnGet(ctx, config.SecretNameOldTokenSigningRSAKey).Return(buf.String(), nil)

invalidFunc()
p, err := NewProvider(ctx, config.DefaultConfig.AppAuth.SelfAuthServer, sm)
p, err := NewProvider(ctx, config.DefaultConfig.AppAuth.SelfAuthServer, sm, promutils.NewTestScope())
assert.Error(t, err)
assert.ErrorContains(t, err, errorContains)
assert.Equal(t, Provider{}, p)
Expand Down Expand Up @@ -294,7 +295,7 @@ func TestProvider_ValidateAccessToken(t *testing.T) {
sm.OnGet(ctx, config.SecretNameTokenSigningRSAKey).Return(buf.String(), nil)
sm.OnGet(ctx, config.SecretNameOldTokenSigningRSAKey).Return(buf.String(), nil)

p, err := NewProvider(ctx, config.DefaultConfig.AppAuth.SelfAuthServer, sm)
p, err := NewProvider(ctx, config.DefaultConfig.AppAuth.SelfAuthServer, sm, promutils.NewTestScope())
assert.NoError(t, err)

// create a signer for rsa 256
Expand Down
12 changes: 7 additions & 5 deletions flyteadmin/auth/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -301,23 +301,25 @@ func GetAuthenticationInterceptor(authCtx interfaces.AuthenticationContext) func
fromHTTP := metautils.ExtractIncoming(ctx).Get(FromHTTPKey)
isFromHTTP := fromHTTP == FromHTTPVal

identityContext, err := GRPCGetIdentityFromAccessToken(ctx, authCtx)
if err == nil {
identityContext, accessTokenErr := GRPCGetIdentityFromAccessToken(ctx, authCtx)
if accessTokenErr == nil {
return SetContextForIdentity(ctx, identityContext), nil
}

logger.Infof(ctx, "Failed to parse Access Token from context. Will attempt to find IDToken. Error: %v", err)
logger.Infof(ctx, "Failed to parse Access Token from context. Will attempt to find IDToken. Error: %v", accessTokenErr)

identityContext, err = GRPCGetIdentityFromIDToken(ctx, authCtx.Options().UserAuth.OpenID.ClientID,
identityContext, idTokenErr := GRPCGetIdentityFromIDToken(ctx, authCtx.Options().UserAuth.OpenID.ClientID,
authCtx.OidcProvider())

if err == nil {
if idTokenErr == nil {
return SetContextForIdentity(ctx, identityContext), nil
}
logger.Debugf(ctx, "Failed to parse ID Token from context. Error: %v", idTokenErr)

// Only enforcement logic is present. The default case is to let things through.
if (isFromHTTP && !authCtx.Options().DisableForHTTP) ||
(!isFromHTTP && !authCtx.Options().DisableForGrpc) {
err := fmt.Errorf("id token err: %w, access token err: %w", fmt.Errorf("access token err: %w", accessTokenErr), idTokenErr)
return ctx, status.Errorf(codes.Unauthenticated, "token parse error %s", err)
}

Expand Down
4 changes: 2 additions & 2 deletions flyteadmin/pkg/server/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ func serveGatewayInsecure(ctx context.Context, pluginRegistry *plugins.Registry,
var oauth2Provider interfaces.OAuth2Provider
var oauth2ResourceServer interfaces.OAuth2ResourceServer
if authCfg.AppAuth.AuthServerType == authConfig.AuthorizationServerTypeSelf {
oauth2Provider, err = authzserver.NewProvider(ctx, authCfg.AppAuth.SelfAuthServer, sm)
oauth2Provider, err = authzserver.NewProvider(ctx, authCfg.AppAuth.SelfAuthServer, sm, scope.NewSubScope("auth_provider"))
if err != nil {
logger.Errorf(ctx, "Error creating authorization server %s", err)
return err
Expand Down Expand Up @@ -463,7 +463,7 @@ func serveGatewaySecure(ctx context.Context, pluginRegistry *plugins.Registry, c
var oauth2Provider interfaces.OAuth2Provider
var oauth2ResourceServer interfaces.OAuth2ResourceServer
if authCfg.AppAuth.AuthServerType == authConfig.AuthorizationServerTypeSelf {
oauth2Provider, err = authzserver.NewProvider(ctx, authCfg.AppAuth.SelfAuthServer, sm)
oauth2Provider, err = authzserver.NewProvider(ctx, authCfg.AppAuth.SelfAuthServer, sm, scope.NewSubScope("auth_provider"))
if err != nil {
logger.Errorf(ctx, "Error creating authorization server %s", err)
return err
Expand Down
30 changes: 26 additions & 4 deletions flyteidl/clients/go/admin/token_source_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ import (
"strings"
"sync"

"k8s.io/apimachinery/pkg/util/wait"
"k8s.io/client-go/util/retry"

"golang.org/x/oauth2"
"golang.org/x/oauth2/clientcredentials"

Expand Down Expand Up @@ -167,6 +170,7 @@ func GetPKCEAuthTokenSource(ctx context.Context, pkceTokenOrchestrator pkce.Toke
type ClientCredentialsTokenSourceProvider struct {
ccConfig clientcredentials.Config
tokenCache cache.TokenCache
cfg *Config
}

func NewClientCredentialsTokenSourceProvider(ctx context.Context, cfg *Config, scopes []string, tokenURL string,
Expand Down Expand Up @@ -198,7 +202,9 @@ func NewClientCredentialsTokenSourceProvider(ctx context.Context, cfg *Config, s
Scopes: scopes,
EndpointParams: endpointParams,
},
tokenCache: tokenCache}, nil
tokenCache: tokenCache,
cfg: cfg,
}, nil
}

func (p ClientCredentialsTokenSourceProvider) GetTokenSource(ctx context.Context) (oauth2.TokenSource, error) {
Expand All @@ -207,6 +213,7 @@ func (p ClientCredentialsTokenSourceProvider) GetTokenSource(ctx context.Context
new: p.ccConfig.TokenSource(ctx),
mu: sync.Mutex{},
tokenCache: p.tokenCache,
cfg: p.cfg,
}, nil
}

Expand All @@ -215,6 +222,7 @@ type customTokenSource struct {
mu sync.Mutex // guards everything else
new oauth2.TokenSource
tokenCache cache.TokenCache
cfg *Config
}

func (s *customTokenSource) Token() (*oauth2.Token, error) {
Expand All @@ -225,10 +233,24 @@ func (s *customTokenSource) Token() (*oauth2.Token, error) {
return token, nil
}

token, err := s.new.Token()
totalAttempts := s.cfg.MaxRetries + 1 // Add one for initial request attempt
backoff := wait.Backoff{
Duration: s.cfg.PerRetryTimeout.Duration,
Steps: totalAttempts,
}
var token *oauth2.Token
err := retry.OnError(backoff, func(err error) bool {
return err != nil
}, func() (err error) {
token, err = s.new.Token()
if err != nil {
logger.Infof(s.ctx, "failed to get token: %w", err)
return fmt.Errorf("failed to get token: %w", err)
}
return nil
})
if err != nil {
logger.Warnf(s.ctx, "failed to get token: %v", err)
return nil, fmt.Errorf("failed to get token: %w", err)
return nil, err
}
logger.Infof(s.ctx, "retrieved token with expiry %v", token.Expiry)

Expand Down

0 comments on commit c127997

Please sign in to comment.