Skip to content
This repository has been archived by the owner on Oct 23, 2023. It is now read-only.

Config resolver #328

Draft
wants to merge 9 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 7 additions & 9 deletions clients/go/admin/auth_interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"fmt"

"github.com/flyteorg/flyteidl/clients/go/admin/cache"
"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/service"
"github.com/flyteorg/flytestdlib/logger"

"google.golang.org/grpc/codes"
Expand All @@ -22,20 +21,19 @@ func MaterializeCredentials(ctx context.Context, cfg *Config, tokenCache cache.T
return fmt.Errorf("failed to initialized Auth Metadata Client. Error: %w", err)
}

tokenSourceProvider, err := NewTokenSourceProvider(ctx, cfg, tokenCache, authMetadataClient)
cfgResolver := NewConfigResolver(cfg, authMetadataClient)
tokenSourceProvider, err := NewTokenSourceProvider(ctx, cfg, tokenCache, cfgResolver)
if err != nil {
return fmt.Errorf("failed to initialized token source provider. Err: %w", err)
}

authorizationMetadataKey := cfg.AuthorizationHeader
if len(authorizationMetadataKey) == 0 {
clientMetadata, err := authMetadataClient.GetPublicClientConfig(ctx, &service.PublicClientAuthConfigRequest{})
if err != nil {
return fmt.Errorf("failed to fetch client metadata. Error: %v", err)
}
authorizationMetadataKey = clientMetadata.AuthorizationMetadataKey
clientMetadata, err := cfgResolver.GetPublicClientConfig(ctx)
if err != nil {
return fmt.Errorf("failed to fetch client metadata. Error: %v", err)
}

authorizationMetadataKey := clientMetadata.AuthorizationMetadataKey

tokenSource, err := tokenSourceProvider.GetTokenSource(ctx)
if err != nil {
return err
Expand Down
14 changes: 6 additions & 8 deletions clients/go/admin/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,20 +81,18 @@ func GetAdditionalAdminClientConfigOptions(cfg *Config) []grpc.DialOption {
// This retrieves a DialOption that contains a source for generating JWTs for authentication with Flyte Admin. If
// the token endpoint is set in the config, that will be used, otherwise it'll attempt to make a metadata call.
func getAuthenticationDialOption(ctx context.Context, cfg *Config, tokenSourceProvider TokenSourceProvider,
authClient service.AuthMetadataServiceClient) (grpc.DialOption, error) {
authClient *ConfigResolver) (grpc.DialOption, error) {
if tokenSourceProvider == nil {
return nil, errors.New("can't create authenticated channel without a TokenSourceProvider")
}

authorizationMetadataKey := cfg.AuthorizationHeader
if len(authorizationMetadataKey) == 0 {
clientMetadata, err := authClient.GetPublicClientConfig(ctx, &service.PublicClientAuthConfigRequest{})
if err != nil {
return nil, fmt.Errorf("failed to fetch client metadata. Error: %v", err)
}
authorizationMetadataKey = clientMetadata.AuthorizationMetadataKey
clientMetadata, err := authClient.GetPublicClientConfig(ctx)
if err != nil {
return nil, fmt.Errorf("failed to fetch client metadata. Error: %v", err)
}

authorizationMetadataKey := clientMetadata.AuthorizationMetadataKey

tokenSource, err := tokenSourceProvider.GetTokenSource(ctx)
if err != nil {
return nil, err
Expand Down
6 changes: 5 additions & 1 deletion clients/go/admin/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ type Config struct {
TokenURL string `json:"tokenUrl" pflag:",OPTIONAL: Your IdP's token endpoint. It'll be discovered from flyte admin's OAuth Metadata endpoint if not provided."`

// See the implementation of the 'grpcAuthorizationHeader' option in Flyte Admin for more information. But
// basically we want to be able to use a different string to pass the token from this client to the the Admin service
// basically we want to be able to use a different string to pass the token from this client to the Admin service
// because things might be running in a service mesh (like Envoy) that already uses the default 'authorization' header
AuthorizationHeader string `json:"authorizationHeader" pflag:",Custom metadata header to pass JWT"`

Expand All @@ -72,6 +72,10 @@ type Config struct {

Command []string `json:"command" pflag:",Command for external authentication token generation"`

TokenAudience string `json:"tokenAudience" pflag:",OPTIONAL: Audience for token issuance requests. It'll be discovered from the flyte admin's OAuth Metadata endpoint if not provided."`

ServiceHttpEndpoint string `json:"serviceHttpEndpoint" pflag:",OPTIONAL: The http endpoint for FlyteAdmin if it's being served over a different port. It'll be discovered from the flyte admin's OAuth Metadata endpoint if not provided."`

// Set the gRPC service config formatted as a json string https://github.com/grpc/grpc/blob/master/doc/service_config.md
// eg. {"loadBalancingConfig": [{"round_robin":{}}], "methodConfig": [{"name":[{"service": "foo", "method": "bar"}, {"service": "baz"}], "timeout": "1.000000001s"}]}
// find the full schema here https://github.com/grpc/grpc-proto/blob/master/grpc/service_config/service_config.proto#L625
Expand Down
131 changes: 131 additions & 0 deletions clients/go/admin/config_resolver.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
package admin

import (
"context"
"github.com/flyteorg/flytestdlib/atomic"

"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/service"
)

type ConfigResolver struct {
cfg *Config
authMetadataClient service.AuthMetadataServiceClient
resolvedMetadata atomic.Generic[*service.OAuth2MetadataResponse]
resolvedPublicClientConfig atomic.Generic[*service.PublicClientAuthConfigResponse]
}

func (c *ConfigResolver) GetOAuth2Metadata(ctx context.Context) (*service.OAuth2MetadataResponse, error) {
if !c.resolvedMetadata.Empty() {
return c.resolvedMetadata.Load(), nil
}

r := &service.OAuth2MetadataResponse{}
var remote *service.OAuth2MetadataResponse
retrieveFromRemote := func() (*service.OAuth2MetadataResponse, error) {
if remote != nil {
return remote, nil
}

resp, err := c.authMetadataClient.GetOAuth2Metadata(ctx, &service.OAuth2MetadataRequest{})
if err != nil {
return nil, err
}

remote = resp
return remote, nil
}

}

func (c *ConfigResolver) GetPublicClientConfig(ctx context.Context) (*service.PublicClientAuthConfigResponse, error) {
if !c.resolvedPublicClientConfig.Empty() {
return c.resolvedPublicClientConfig.Load(), nil
}

r := &service.PublicClientAuthConfigResponse{}
var remote *service.PublicClientAuthConfigResponse
retrieveFromRemote := func() (*service.PublicClientAuthConfigResponse, error) {
if remote != nil {
return remote, nil
}

resp, err := c.authMetadataClient.GetPublicClientConfig(ctx, &service.PublicClientAuthConfigRequest{})
if err != nil {
return nil, err
}

remote = resp
return remote, nil
}

if len(c.cfg.ClientID) != 0 {
r.ClientId = c.cfg.ClientID
} else {
temp, err := retrieveFromRemote()
if err != nil {
return nil, err
}

r.ClientId = temp.ClientId
}

if len(c.cfg.Scopes) != 0 {
r.Scopes = c.cfg.Scopes
} else {
temp, err := retrieveFromRemote()
if err != nil {
return nil, err
}

r.Scopes = temp.Scopes
}

if len(c.cfg.DeprecatedAuthorizationHeader) != 0 {
r.AuthorizationMetadataKey = c.cfg.DeprecatedAuthorizationHeader
} else {
temp, err := retrieveFromRemote()
if err != nil {
return nil, err
}

r.AuthorizationMetadataKey = temp.AuthorizationMetadataKey
}

if len(c.cfg.TokenAudience) != 0 {
r.Audience = c.cfg.TokenAudience
} else {
temp, err := retrieveFromRemote()
if err != nil {
return nil, err
}

r.Audience = temp.Audience
}

if len(c.cfg.ServiceHttpEndpoint) != 0 {
r.ServiceHttpEndpoint = c.cfg.ServiceHttpEndpoint
} else {
temp, err := retrieveFromRemote()
if err != nil {
return nil, err
}

r.ServiceHttpEndpoint = temp.ServiceHttpEndpoint
}

swapped := c.resolvedPublicClientConfig.CompareAndSwap(nil, r)
if swapped {
return r, nil
}

return c.resolvedPublicClientConfig.Load(), nil
}

func NewConfigResolver(cfg *Config, authMetadataClient service.AuthMetadataServiceClient) *ConfigResolver {
return &ConfigResolver{
cfg: cfg,
authMetadataClient: authMetadataClient,
resolvedMetadata: atomic.NewGenericEmpty[*service.OAuth2MetadataResponse](),
resolvedPublicClientConfig: atomic.NewGenericEmpty[*service.PublicClientAuthConfigResponse](),
}
}
2 changes: 2 additions & 0 deletions clients/go/admin/deviceflow/payload.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ type DeviceAuthorizationRequest struct {
ClientID string `json:"client_id"`
// Scope is the scope parameter of the access request
Scope string `json:"scope"`
// Audience defines at which endpoints the token can be used.
Audience string `json:"audience"`
}

// DeviceAuthorizationResponse contains the information that the end user would use to authorize the app requesting the
Expand Down
14 changes: 9 additions & 5 deletions clients/go/admin/deviceflow/token_orchestrator.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ import (
)

const (
cliendID = "client_id"
audience = "audience"
clientID = "client_id"
deviceCode = "device_code"
grantType = "grant_type"
scope = "scope"
Expand All @@ -44,13 +45,15 @@ type TokenOrchestrator struct {

// StartDeviceAuthorization will initiate the OAuth2 device authorization flow.
func (t TokenOrchestrator) StartDeviceAuthorization(ctx context.Context, dareq DeviceAuthorizationRequest) (*DeviceAuthorizationResponse, error) {
v := url.Values{cliendID: {dareq.ClientID}, scope: {dareq.Scope}}
v := url.Values{clientID: {dareq.ClientID}, scope: {dareq.Scope}, audience: {dareq.Audience}}
httpReq, err := http.NewRequest("POST", t.ClientConfig.DeviceEndpoint, strings.NewReader(v.Encode()))
if err != nil {
return nil, err
}
httpReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")

logger.Debugf(ctx, "Sending the following request to start device authorization %v with body %v", httpReq.URL, v.Encode())

httpResp, err := ctxhttp.Do(ctx, nil, httpReq)
if err != nil {
return nil, err
Expand Down Expand Up @@ -86,19 +89,20 @@ func (t TokenOrchestrator) StartDeviceAuthorization(ctx context.Context, dareq D
// PollTokenEndpoint polls the token endpoint until the user authorizes/ denies the app or an error occurs other than slow_down or authorization_pending
func (t TokenOrchestrator) PollTokenEndpoint(ctx context.Context, tokReq DeviceAccessTokenRequest, pollInterval time.Duration) (*oauth2.Token, error) {
v := url.Values{
cliendID: {tokReq.ClientID},
clientID: {tokReq.ClientID},
grantType: {grantTypeValue},
deviceCode: {tokReq.DeviceCode},
}

for {

httpReq, err := http.NewRequest("POST", t.ClientConfig.Endpoint.TokenURL, strings.NewReader(v.Encode()))
if err != nil {
return nil, err
}
httpReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")

logger.Debugf(ctx, "Sending the following request to fetch the token %v with body %v", httpReq.URL, v.Encode())

httpResp, err := ctxhttp.Do(ctx, nil, httpReq)
if err != nil {
return nil, err
Expand Down Expand Up @@ -150,7 +154,7 @@ func (t TokenOrchestrator) FetchTokenFromAuthFlow(ctx context.Context) (*oauth2.
if len(t.ClientConfig.Scopes) > 0 {
scopes = strings.Join(t.ClientConfig.Scopes, " ")
}
daReq := DeviceAuthorizationRequest{ClientID: t.ClientConfig.ClientID, Scope: scopes}
daReq := DeviceAuthorizationRequest{ClientID: t.ClientConfig.ClientID, Scope: scopes, Audience: t.ClientConfig.Audience}
daResp, err := t.StartDeviceAuthorization(ctx, daReq)
if err != nil {
return nil, err
Expand Down
26 changes: 24 additions & 2 deletions clients/go/admin/deviceflow/token_orchestrator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"io"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"
Expand Down Expand Up @@ -46,9 +47,18 @@ func TestFetchFromAuthFlow(t *testing.T) {
body, err := io.ReadAll(r.Body)
assert.Nil(t, err)
isDeviceReq := strings.Contains(string(body), scope)
isTokReq := strings.Contains(string(body), deviceCode) && strings.Contains(string(body), grantType) && strings.Contains(string(body), cliendID)
isTokReq := strings.Contains(string(body), deviceCode) && strings.Contains(string(body), grantType) && strings.Contains(string(body), clientID)

if isDeviceReq {
for _, urlParm := range strings.Split(string(body), "&") {
paramKeyValue := strings.Split(urlParm, "=")
switch paramKeyValue[0] {
case audience:
assert.Equal(t, "abcd", paramKeyValue[1])
case clientID:
assert.Equal(t, clientID, paramKeyValue[1])
}
}
dar := DeviceAuthorizationResponse{
DeviceCode: "e1db31fe-3b23-4fce-b759-82bf8ea323d6",
UserCode: "RPBQZNRX",
Expand All @@ -62,6 +72,17 @@ func TestFetchFromAuthFlow(t *testing.T) {
assert.Nil(t, err)
return
} else if isTokReq {
for _, urlParm := range strings.Split(string(body), "&") {
paramKeyValue := strings.Split(urlParm, "=")
switch paramKeyValue[0] {
case grantType:
assert.Equal(t, url.QueryEscape(grantTypeValue), paramKeyValue[1])
case deviceCode:
assert.Equal(t, "e1db31fe-3b23-4fce-b759-82bf8ea323d6", paramKeyValue[1])
case clientID:
assert.Equal(t, clientID, paramKeyValue[1])
}
}
dar := DeviceAccessTokenResponse{
Token: oauth2.Token{
AccessToken: "access_token",
Expand All @@ -81,14 +102,15 @@ func TestFetchFromAuthFlow(t *testing.T) {
orchestrator, err := NewDeviceFlowTokenOrchestrator(tokenorchestrator.BaseTokenOrchestrator{
ClientConfig: &oauth.Config{
Config: &oauth2.Config{
ClientID: cliendID,
ClientID: clientID,
RedirectURL: "http://localhost:8089/redirect",
Scopes: []string{"code", "all"},
Endpoint: oauth2.Endpoint{
TokenURL: fakeServer.URL,
},
},
DeviceEndpoint: fakeServer.URL,
Audience: "abcd",
},
TokenCache: tokenCache,
}, Config{
Expand Down
10 changes: 7 additions & 3 deletions clients/go/admin/oauth/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package oauth

import (
"context"
"github.com/flyteorg/flyteidl/clients/go/admin"

"golang.org/x/oauth2"

Expand All @@ -12,17 +13,19 @@ import (
type Config struct {
*oauth2.Config
DeviceEndpoint string
// Audience value to be passed when requesting access token using device flow.This needs to be passed in the first request of the device flow currently and is configured in admin public client config.Required when auth server hasn't been configured with default audience"`
Audience string
}

// BuildConfigFromMetadataService builds OAuth2 config from information retrieved through the anonymous auth metadata service.
func BuildConfigFromMetadataService(ctx context.Context, authMetadataClient service.AuthMetadataServiceClient) (clientConf *Config, err error) {
func BuildConfigFromMetadataService(ctx context.Context, authMetadataClient *admin.ConfigResolver) (clientConf *Config, err error) {
var clientResp *service.PublicClientAuthConfigResponse
if clientResp, err = authMetadataClient.GetPublicClientConfig(ctx, &service.PublicClientAuthConfigRequest{}); err != nil {
if clientResp, err = authMetadataClient.GetPublicClientConfig(ctx); err != nil {
return nil, err
}

var oauthMetaResp *service.OAuth2MetadataResponse
if oauthMetaResp, err = authMetadataClient.GetOAuth2Metadata(ctx, &service.OAuth2MetadataRequest{}); err != nil {
if oauthMetaResp, err = authMetadataClient.GetOAuth2Metadata(ctx); err != nil {
return nil, err
}

Expand All @@ -37,6 +40,7 @@ func BuildConfigFromMetadataService(ctx context.Context, authMetadataClient serv
},
},
DeviceEndpoint: oauthMetaResp.DeviceAuthorizationEndpoint,
Audience: clientResp.Audience,
}

return clientConf, nil
Expand Down
2 changes: 2 additions & 0 deletions clients/go/admin/oauth/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ func TestGenerateClientConfig(t *testing.T) {
ClientId: "dummyClient",
RedirectUri: "dummyRedirectUri",
Scopes: []string{"dummyScopes"},
Audience: "dummyAudience",
}
oauthMetaDataResp := &service.OAuth2MetadataResponse{
Issuer: "dummyIssuer",
Expand All @@ -36,4 +37,5 @@ func TestGenerateClientConfig(t *testing.T) {
assert.Equal(t, "dummyTokenEndpoint", oauthConfig.Endpoint.TokenURL)
assert.Equal(t, "dummyAuthEndPoint", oauthConfig.Endpoint.AuthURL)
assert.Equal(t, "dummyDeviceEndpoint", oauthConfig.DeviceEndpoint)
assert.Equal(t, "dummyAudience", oauthConfig.Audience)
}
Loading