From 1d653d4b399315df4432288ddc1f6181bbb158f2 Mon Sep 17 00:00:00 2001 From: Shahram Kalantari Date: Tue, 24 Sep 2024 14:42:42 +1000 Subject: [PATCH 01/20] chore: migrate azure-sdk-for-go/containerregistry to the latest release Signed-off-by: Shahram Kalantari --- go.mod | 7 ++++--- go.sum | 18 +++++++++-------- .../oras/authprovider/azure/azureidentity.go | 19 +++++++++++++++--- .../azure/azureworkloadidentity.go | 20 ++++++++++++++++--- 4 files changed, 47 insertions(+), 17 deletions(-) diff --git a/go.mod b/go.mod index 27d1f11c1..37cc55c60 100644 --- a/go.mod +++ b/go.mod @@ -10,8 +10,9 @@ retract ( require ( github.com/Azure/azure-sdk-for-go v68.0.0+incompatible - github.com/Azure/azure-sdk-for-go/sdk/azcore v1.11.1 - github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.6.0 + github.com/Azure/azure-sdk-for-go/sdk/azcore v1.14.0 + github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.7.0 + github.com/Azure/azure-sdk-for-go/sdk/containers/azcontainerregistry v0.2.2 github.com/AzureAD/microsoft-authentication-library-for-go v1.2.2 github.com/aws/aws-sdk-go-v2 v1.32.2 github.com/aws/aws-sdk-go-v2/config v1.27.43 @@ -130,7 +131,7 @@ require ( ) require ( - github.com/Azure/azure-sdk-for-go/sdk/internal v1.8.0 // indirect + github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0 // indirect github.com/Azure/go-autorest v14.2.0+incompatible // indirect github.com/Azure/go-autorest/autorest v0.11.29 github.com/Azure/go-autorest/autorest/adal v0.9.24 // indirect diff --git a/go.sum b/go.sum index d94ab83cf..8a05bb2d1 100644 --- a/go.sum +++ b/go.sum @@ -18,12 +18,14 @@ github.com/AliyunContainerService/ack-ram-tool/pkg/credentials/alibabacloudsdkgo github.com/AliyunContainerService/ack-ram-tool/pkg/credentials/alibabacloudsdkgo/helper v0.2.0/go.mod h1:GgeIE+1be8Ivm7Sh4RgwI42aTtC9qrcj+Y9Y6CjJhJs= github.com/Azure/azure-sdk-for-go v68.0.0+incompatible h1:fcYLmCpyNYRnvJbPerq7U0hS+6+I79yEDJBqVNcqUzU= github.com/Azure/azure-sdk-for-go v68.0.0+incompatible/go.mod h1:9XXNKU+eRnpl9moKnB4QOLf1HestfXbmab5FXxiDBjc= -github.com/Azure/azure-sdk-for-go/sdk/azcore v1.11.1 h1:E+OJmp2tPvt1W+amx48v1eqbjDYsgN+RzP4q16yV5eM= -github.com/Azure/azure-sdk-for-go/sdk/azcore v1.11.1/go.mod h1:a6xsAQUZg+VsS3TJ05SRp524Hs4pZ/AeFSr5ENf0Yjo= -github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.6.0 h1:U2rTu3Ef+7w9FHKIAXM6ZyqF3UOWJZ12zIm8zECAFfg= -github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.6.0/go.mod h1:9kIvujWAA58nmPmWB1m23fyWic1kYZMxD9CxaWn4Qpg= -github.com/Azure/azure-sdk-for-go/sdk/internal v1.8.0 h1:jBQA3cKT4L2rWMpgE7Yt3Hwh2aUj8KXjIGLxjHeYNNo= -github.com/Azure/azure-sdk-for-go/sdk/internal v1.8.0/go.mod h1:4OG6tQ9EOP/MT0NMjDlRzWoVFxfu9rN9B2X+tlSVktg= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.14.0 h1:nyQWyZvwGTvunIMxi1Y9uXkcyr+I7TeNrr/foo4Kpk8= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.14.0/go.mod h1:l38EPgmsp71HHLq9j7De57JcKOWPyhrsW1Awm1JS6K0= +github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.7.0 h1:tfLQ34V6F7tVSwoTf/4lH5sE0o6eCJuNDTmH09nDpbc= +github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.7.0/go.mod h1:9kIvujWAA58nmPmWB1m23fyWic1kYZMxD9CxaWn4Qpg= +github.com/Azure/azure-sdk-for-go/sdk/containers/azcontainerregistry v0.2.2 h1:wBx10efdJcl8FSewgc41kAW4AvHPgmJZmN7fpNxn8rc= +github.com/Azure/azure-sdk-for-go/sdk/containers/azcontainerregistry v0.2.2/go.mod h1:zzmu18cpAinSbhC86oWd47nmgbb91Fl+Yac2PE8NdYk= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0 h1:ywEEhmNahHBihViHepv3xPBn1663uRv2t2q/ESv9seY= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0/go.mod h1:iZDifYGJTIgIIkYRNWPENUnqx6bJ2xnSDFI2tjwZNuY= github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys v1.1.0 h1:DRiANoJTiW6obBQe3SqZizkuV1PEgfiiGivmVocDy64= github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys v1.1.0/go.mod h1:qLIye2hwb/ZouqhpSD9Zn3SJipvpEnz1Ywl3VUk9Y0s= github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/internal v1.0.0 h1:D3occbWoio4EBLkbkevetNMAVX197GkzbUMtqjGWn80= @@ -823,8 +825,8 @@ golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/net v0.11.0/go.mod h1:2L/ixqYpgIVXmeoSA/4Lu7BzTG4KIyPIryS4IsOd1oQ= golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= golang.org/x/net v0.22.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg= -golang.org/x/net v0.28.0 h1:a9JDOJc5GMUJ0+UDqmLT86WiEy7iWyIhz8gz8E4e5hE= -golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg= +golang.org/x/net v0.29.0 h1:5ORfpBpCs4HzDYoodCDBbwHzdR5UrLBZ3sOnUJmFoHo= +golang.org/x/net v0.29.0/go.mod h1:gLkgy8jTGERgjzMic6DS9+SP0ajcu6Xu3Orq/SpETg0= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.23.0 h1:PbgcYx2W7i4LvjJWEbf0ngHV6qJYr86PkAV3bXdLEbs= golang.org/x/oauth2 v0.23.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= diff --git a/pkg/common/oras/authprovider/azure/azureidentity.go b/pkg/common/oras/authprovider/azure/azureidentity.go index 0a5a00e5c..19fd78ca1 100644 --- a/pkg/common/oras/authprovider/azure/azureidentity.go +++ b/pkg/common/oras/authprovider/azure/azureidentity.go @@ -29,7 +29,7 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" "github.com/Azure/azure-sdk-for-go/sdk/azidentity" - "github.com/Azure/azure-sdk-for-go/services/preview/containerregistry/runtime/2019-08-15-preview/containerregistry" + "github.com/Azure/azure-sdk-for-go/sdk/containers/azcontainerregistry" ) type azureManagedIdentityProviderFactory struct{} @@ -135,8 +135,21 @@ func (d *azureManagedIdentityAuthProvider) Provide(ctx context.Context, artifact serverURL := "https://" + artifactHostName // create registry client and exchange AAD token for registry refresh token - refreshTokenClient := containerregistry.NewRefreshTokensClient(serverURL) - rt, err := refreshTokenClient.GetFromExchange(ctx, "access_token", artifactHostName, d.tenantID, "", d.identityToken.Token) + client, err := azcontainerregistry.NewAuthenticationClient(serverURL, nil) // &AuthenticationClientOptions{ClientOptions: options}) + if err != nil { + return provider.AuthConfig{}, re.ErrorCodeAuthDenied.NewError(re.AuthProvider, "", re.AzureWorkloadIdentityLink, err, "failed to create authentication client for container registry by azure managed identity token", re.HideStackTrace) + } + // refreshTokenClient := containerregistry.NewRefreshTokensClient(serverURL) + rt, err := client.ExchangeAADAccessTokenForACRRefreshToken( + context.Background(), + "access_token", + artifactHostName, + &azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenOptions{ + AccessToken: &d.identityToken.Token, + Tenant: &d.tenantID, + }, + ) + // rt, err := refreshTokenClient.GetFromExchange(ctx, "access_token", artifactHostName, d.tenantID, "", d.identityToken.Token) if err != nil { return provider.AuthConfig{}, re.ErrorCodeAuthDenied.NewError(re.AuthProvider, "", re.AzureManagedIdentityLink, err, "failed to get refresh token for container registry by azure managed identity token", re.HideStackTrace) } diff --git a/pkg/common/oras/authprovider/azure/azureworkloadidentity.go b/pkg/common/oras/authprovider/azure/azureworkloadidentity.go index a40ce4436..c5acbcee0 100644 --- a/pkg/common/oras/authprovider/azure/azureworkloadidentity.go +++ b/pkg/common/oras/authprovider/azure/azureworkloadidentity.go @@ -21,13 +21,13 @@ import ( "os" "time" + "github.com/Azure/azure-sdk-for-go/sdk/containers/azcontainerregistry" re "github.com/ratify-project/ratify/errors" "github.com/ratify-project/ratify/internal/logger" provider "github.com/ratify-project/ratify/pkg/common/oras/authprovider" "github.com/ratify-project/ratify/pkg/metrics" "github.com/ratify-project/ratify/pkg/utils/azureauth" - "github.com/Azure/azure-sdk-for-go/services/preview/containerregistry/runtime/2019-08-15-preview/containerregistry" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/confidential" ) @@ -130,9 +130,23 @@ func (d *azureWIAuthProvider) Provide(ctx context.Context, artifact string) (pro serverURL := "https://" + artifactHostName // create registry client and exchange AAD token for registry refresh token - refreshTokenClient := containerregistry.NewRefreshTokensClient(serverURL) + // TODO: Consider adding authentication client options for multicloud scenarios + client, err := azcontainerregistry.NewAuthenticationClient(serverURL, nil) // &AuthenticationClientOptions{ClientOptions: options}) + if err != nil { + return provider.AuthConfig{}, re.ErrorCodeAuthDenied.NewError(re.AuthProvider, "", re.AzureWorkloadIdentityLink, err, "failed to create authentication client for container registry", re.HideStackTrace) + } + // refreshTokenClient := azcontainerregistry.NewRefreshTokensClient(serverURL) startTime := time.Now() - rt, err := refreshTokenClient.GetFromExchange(context.Background(), "access_token", artifactHostName, d.tenantID, "", d.aadToken.AccessToken) + rt, err := client.ExchangeAADAccessTokenForACRRefreshToken( + context.Background(), + "access_token", + artifactHostName, + &azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenOptions{ + AccessToken: &d.aadToken.AccessToken, + Tenant: &d.tenantID, + }, + ) + // rt, err := refreshTokenClient.GetFromExchange(context.Background(), "access_token", artifactHostName, d.tenantID, "", d.aadToken.AccessToken) if err != nil { return provider.AuthConfig{}, re.ErrorCodeAuthDenied.NewError(re.AuthProvider, "", re.AzureWorkloadIdentityLink, err, "failed to get refresh token for container registry", re.HideStackTrace) } From c7137d87d74ac6f11fe19af2b9c9832f191d9d21 Mon Sep 17 00:00:00 2001 From: Shahram Kalantari Date: Sun, 29 Sep 2024 19:11:51 +1000 Subject: [PATCH 02/20] chore: refactor to enable mocking and add unit tests to azureworkloadidentity_test.go Signed-off-by: Shahram Kalantari --- go.mod | 1 + go.sum | 1 + .../azure/azureworkloadidentity.go | 65 +++++++++++++------ .../azure/azureworkloadidentity_test.go | 60 +++++++++++++++++ 4 files changed, 107 insertions(+), 20 deletions(-) diff --git a/go.mod b/go.mod index 37cc55c60..c580fa570 100644 --- a/go.mod +++ b/go.mod @@ -119,6 +119,7 @@ require ( github.com/sigstore/timestamp-authority v1.2.2 // indirect github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966 // indirect github.com/sourcegraph/conc v0.3.0 // indirect + github.com/stretchr/objx v0.5.2 // indirect github.com/tchap/go-patricia/v2 v2.3.1 // indirect github.com/thales-e-security/pool v0.0.2 // indirect github.com/tjfoc/gmsm v1.4.1 // indirect diff --git a/go.sum b/go.sum index 8a05bb2d1..ecb26633c 100644 --- a/go.sum +++ b/go.sum @@ -659,6 +659,7 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+ github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= diff --git a/pkg/common/oras/authprovider/azure/azureworkloadidentity.go b/pkg/common/oras/authprovider/azure/azureworkloadidentity.go index c5acbcee0..183b7f87a 100644 --- a/pkg/common/oras/authprovider/azure/azureworkloadidentity.go +++ b/pkg/common/oras/authprovider/azure/azureworkloadidentity.go @@ -21,7 +21,7 @@ import ( "os" "time" - "github.com/Azure/azure-sdk-for-go/sdk/containers/azcontainerregistry" + azcontainerregistry "github.com/Azure/azure-sdk-for-go/sdk/containers/azcontainerregistry" re "github.com/ratify-project/ratify/errors" "github.com/ratify-project/ratify/internal/logger" provider "github.com/ratify-project/ratify/pkg/common/oras/authprovider" @@ -33,9 +33,40 @@ import ( type AzureWIProviderFactory struct{} //nolint:revive // ignore linter to have unique type name type azureWIAuthProvider struct { - aadToken confidential.AuthResult - tenantID string - clientID string + aadToken confidential.AuthResult + tenantID string + clientID string + authClientFactory func(serverURL string, options *azcontainerregistry.AuthenticationClientOptions) (authClient, error) + getRegistryHost func(artifact string) (string, error) + getAADAccessToken func(ctx context.Context, tenantID, clientID, resource string) (confidential.AuthResult, error) + reportMetrics func(ctx context.Context, duration int64, artifactHostName string) +} + +type authenticationClientWrapper struct { + client *azcontainerregistry.AuthenticationClient +} + +func (w *authenticationClientWrapper) ExchangeAADAccessTokenForACRRefreshToken(ctx context.Context, grantType, service string, options *azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenOptions) (azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse, error) { + return w.client.ExchangeAADAccessTokenForACRRefreshToken(ctx, azcontainerregistry.PostContentSchemaGrantType(grantType), service, options) +} + +type authClient interface { + ExchangeAADAccessTokenForACRRefreshToken(ctx context.Context, grantType, service string, options *azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenOptions) (azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse, error) +} + +func NewAzureWIAuthProvider() *azureWIAuthProvider { + return &azureWIAuthProvider{ + authClientFactory: func(serverURL string, options *azcontainerregistry.AuthenticationClientOptions) (authClient, error) { + client, err := azcontainerregistry.NewAuthenticationClient(serverURL, options) + if err != nil { + return nil, err + } + return &authenticationClientWrapper{client: client}, nil + }, + getRegistryHost: provider.GetRegistryHostName, + getAADAccessToken: azureauth.GetAADAccessToken, + reportMetrics: metrics.ReportACRExchangeDuration, + } } type azureWIAuthProviderConf struct { @@ -103,22 +134,18 @@ func (d *azureWIAuthProvider) Enabled(_ context.Context) bool { return true } -// Provide returns the credentials for a specified artifact. -// Uses Azure Workload Identity to retrieve an AAD access token which can be -// exchanged for a valid ACR refresh token for login. func (d *azureWIAuthProvider) Provide(ctx context.Context, artifact string) (provider.AuthConfig, error) { if !d.Enabled(ctx) { return provider.AuthConfig{}, re.ErrorCodeConfigInvalid.WithComponentType(re.AuthProvider).WithDetail("azure workload identity auth provider is not properly enabled") } - // parse the artifact reference string to extract the registry host name - artifactHostName, err := provider.GetRegistryHostName(artifact) + + artifactHostName, err := d.getRegistryHost(artifact) if err != nil { return provider.AuthConfig{}, re.ErrorCodeHostNameInvalid.WithComponentType(re.AuthProvider) } - // need to refresh AAD token if it's expired if time.Now().Add(time.Minute * 5).After(d.aadToken.ExpiresOn) { - newToken, err := azureauth.GetAADAccessToken(ctx, d.tenantID, d.clientID, AADResource) + newToken, err := d.getAADAccessToken(ctx, d.tenantID, d.clientID, AADResource) if err != nil { return provider.AuthConfig{}, re.ErrorCodeAuthDenied.NewError(re.AuthProvider, "", re.AzureWorkloadIdentityLink, nil, "could not refresh AAD token", re.HideStackTrace) } @@ -126,19 +153,16 @@ func (d *azureWIAuthProvider) Provide(ctx context.Context, artifact string) (pro logger.GetLogger(ctx, logOpt).Info("successfully refreshed AAD token") } - // add protocol to generate complete URI serverURL := "https://" + artifactHostName - - // create registry client and exchange AAD token for registry refresh token // TODO: Consider adding authentication client options for multicloud scenarios - client, err := azcontainerregistry.NewAuthenticationClient(serverURL, nil) // &AuthenticationClientOptions{ClientOptions: options}) + client, err := d.authClientFactory(serverURL, nil) if err != nil { return provider.AuthConfig{}, re.ErrorCodeAuthDenied.NewError(re.AuthProvider, "", re.AzureWorkloadIdentityLink, err, "failed to create authentication client for container registry", re.HideStackTrace) } - // refreshTokenClient := azcontainerregistry.NewRefreshTokensClient(serverURL) + startTime := time.Now() - rt, err := client.ExchangeAADAccessTokenForACRRefreshToken( - context.Background(), + response, err := client.ExchangeAADAccessTokenForACRRefreshToken( + ctx, "access_token", artifactHostName, &azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenOptions{ @@ -146,11 +170,12 @@ func (d *azureWIAuthProvider) Provide(ctx context.Context, artifact string) (pro Tenant: &d.tenantID, }, ) - // rt, err := refreshTokenClient.GetFromExchange(context.Background(), "access_token", artifactHostName, d.tenantID, "", d.aadToken.AccessToken) if err != nil { return provider.AuthConfig{}, re.ErrorCodeAuthDenied.NewError(re.AuthProvider, "", re.AzureWorkloadIdentityLink, err, "failed to get refresh token for container registry", re.HideStackTrace) } - metrics.ReportACRExchangeDuration(ctx, time.Since(startTime).Milliseconds(), artifactHostName) + rt := response.ACRRefreshToken + + d.reportMetrics(ctx, time.Since(startTime).Milliseconds(), artifactHostName) refreshTokenExpiry := getACRExpiryIfEarlier(d.aadToken.ExpiresOn) authConfig := provider.AuthConfig{ diff --git a/pkg/common/oras/authprovider/azure/azureworkloadidentity_test.go b/pkg/common/oras/authprovider/azure/azureworkloadidentity_test.go index 3695ef65a..a10f374d3 100644 --- a/pkg/common/oras/authprovider/azure/azureworkloadidentity_test.go +++ b/pkg/common/oras/authprovider/azure/azureworkloadidentity_test.go @@ -22,9 +22,12 @@ import ( "testing" "time" + "github.com/Azure/azure-sdk-for-go/sdk/containers/azcontainerregistry" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/confidential" ratifyerrors "github.com/ratify-project/ratify/errors" "github.com/ratify-project/ratify/pkg/common/oras/authprovider" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" ) // Verifies that Enabled checks if tenantID is empty or AAD token is empty @@ -131,3 +134,60 @@ func TestAzureWIValidation_EnvironmentVariables_ExpectedResults(t *testing.T) { t.Fatalf("create auth provider should have failed: expected err %s, but got err %s", expectedErr, err) } } + +type mockAuthClient struct { + mock.Mock +} + +func (m *mockAuthClient) ExchangeAADAccessTokenForACRRefreshToken(ctx context.Context, grantType, service string, options *azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenOptions) (azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse, error) { + args := m.Called(ctx, grantType, service, options) + return args.Get(0).(azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse), args.Error(1) +} + +func TestProvide_Success(t *testing.T) { + mockClient := new(mockAuthClient) + expectedRefreshToken := "mocked_refresh_token" + mockClient.On("ExchangeAADAccessTokenForACRRefreshToken", mock.Anything, "access_token", "myregistry.azurecr.io", mock.Anything). + Return(azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse{ + ACRRefreshToken: azcontainerregistry.ACRRefreshToken{RefreshToken: &expectedRefreshToken}, + }, nil) + + provider := &azureWIAuthProvider{ + aadToken: confidential.AuthResult{ + AccessToken: "mockToken", + ExpiresOn: time.Now().Add(time.Hour), + }, + tenantID: "mockTenantID", + clientID: "mockClientID", + authClientFactory: func(serverURL string, options *azcontainerregistry.AuthenticationClientOptions) (authClient, error) { + return mockClient, nil + }, + getRegistryHost: func(artifact string) (string, error) { + return "myregistry.azurecr.io", nil + }, + getAADAccessToken: func(ctx context.Context, tenantID, clientID, resource string) (confidential.AuthResult, error) { + return confidential.AuthResult{ + AccessToken: "mockToken", + ExpiresOn: time.Now().Add(time.Hour), + }, nil + }, + reportMetrics: func(ctx context.Context, duration int64, artifactHostName string) {}, + } + + authConfig, err := provider.Provide(context.Background(), "artifact") + + assert.NoError(t, err) + // Assert that the returned refresh token matches the expected one + assert.Equal(t, expectedRefreshToken, authConfig.Password) +} + +func TestProvide_Failure_InvalidHostName(t *testing.T) { + provider := &azureWIAuthProvider{ + getRegistryHost: func(artifact string) (string, error) { + return "", errors.New("invalid hostname") + }, + } + + _, err := provider.Provide(context.Background(), "artifact") + assert.Error(t, err) +} From b1a397a1437eccc66d49cd9785c1e080b8cfed88 Mon Sep 17 00:00:00 2001 From: Shahram Kalantari Date: Sun, 29 Sep 2024 19:22:49 +1000 Subject: [PATCH 03/20] chore: lint Signed-off-by: Shahram Kalantari --- .../authprovider/azure/azureworkloadidentity.go | 12 ++++++------ .../azure/azureworkloadidentity_test.go | 16 ++++++++-------- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/pkg/common/oras/authprovider/azure/azureworkloadidentity.go b/pkg/common/oras/authprovider/azure/azureworkloadidentity.go index 183b7f87a..241696bcf 100644 --- a/pkg/common/oras/authprovider/azure/azureworkloadidentity.go +++ b/pkg/common/oras/authprovider/azure/azureworkloadidentity.go @@ -32,7 +32,7 @@ import ( ) type AzureWIProviderFactory struct{} //nolint:revive // ignore linter to have unique type name -type azureWIAuthProvider struct { +type WIAuthProvider struct { aadToken confidential.AuthResult tenantID string clientID string @@ -54,8 +54,8 @@ type authClient interface { ExchangeAADAccessTokenForACRRefreshToken(ctx context.Context, grantType, service string, options *azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenOptions) (azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse, error) } -func NewAzureWIAuthProvider() *azureWIAuthProvider { - return &azureWIAuthProvider{ +func NewAzureWIAuthProvider() *WIAuthProvider { + return &WIAuthProvider{ authClientFactory: func(serverURL string, options *azcontainerregistry.AuthenticationClientOptions) (authClient, error) { client, err := azcontainerregistry.NewAuthenticationClient(serverURL, options) if err != nil { @@ -114,7 +114,7 @@ func (s *AzureWIProviderFactory) Create(authProviderConfig provider.AuthProvider return nil, re.ErrorCodeAuthDenied.NewError(re.AuthProvider, "", re.AzureWorkloadIdentityLink, err, "", re.HideStackTrace) } - return &azureWIAuthProvider{ + return &WIAuthProvider{ aadToken: token, tenantID: tenant, clientID: clientID, @@ -122,7 +122,7 @@ func (s *AzureWIProviderFactory) Create(authProviderConfig provider.AuthProvider } // Enabled checks for non empty tenant ID and AAD access token -func (d *azureWIAuthProvider) Enabled(_ context.Context) bool { +func (d *WIAuthProvider) Enabled(_ context.Context) bool { if d.tenantID == "" || d.clientID == "" { return false } @@ -134,7 +134,7 @@ func (d *azureWIAuthProvider) Enabled(_ context.Context) bool { return true } -func (d *azureWIAuthProvider) Provide(ctx context.Context, artifact string) (provider.AuthConfig, error) { +func (d *WIAuthProvider) Provide(ctx context.Context, artifact string) (provider.AuthConfig, error) { if !d.Enabled(ctx) { return provider.AuthConfig{}, re.ErrorCodeConfigInvalid.WithComponentType(re.AuthProvider).WithDetail("azure workload identity auth provider is not properly enabled") } diff --git a/pkg/common/oras/authprovider/azure/azureworkloadidentity_test.go b/pkg/common/oras/authprovider/azure/azureworkloadidentity_test.go index a10f374d3..b4a9a1f7c 100644 --- a/pkg/common/oras/authprovider/azure/azureworkloadidentity_test.go +++ b/pkg/common/oras/authprovider/azure/azureworkloadidentity_test.go @@ -32,7 +32,7 @@ import ( // Verifies that Enabled checks if tenantID is empty or AAD token is empty func TestAzureWIEnabled_ExpectedResults(t *testing.T) { - azAuthProvider := azureWIAuthProvider{ + azAuthProvider := WIAuthProvider{ tenantID: "test_tenant", clientID: "test_client", aadToken: confidential.AuthResult{ @@ -152,26 +152,26 @@ func TestProvide_Success(t *testing.T) { ACRRefreshToken: azcontainerregistry.ACRRefreshToken{RefreshToken: &expectedRefreshToken}, }, nil) - provider := &azureWIAuthProvider{ + provider := &WIAuthProvider{ aadToken: confidential.AuthResult{ AccessToken: "mockToken", ExpiresOn: time.Now().Add(time.Hour), }, tenantID: "mockTenantID", clientID: "mockClientID", - authClientFactory: func(serverURL string, options *azcontainerregistry.AuthenticationClientOptions) (authClient, error) { + authClientFactory: func(_ string, _ *azcontainerregistry.AuthenticationClientOptions) (authClient, error) { return mockClient, nil }, - getRegistryHost: func(artifact string) (string, error) { + getRegistryHost: func(_ string) (string, error) { return "myregistry.azurecr.io", nil }, - getAADAccessToken: func(ctx context.Context, tenantID, clientID, resource string) (confidential.AuthResult, error) { + getAADAccessToken: func(_ context.Context, _, _, _ string) (confidential.AuthResult, error) { return confidential.AuthResult{ AccessToken: "mockToken", ExpiresOn: time.Now().Add(time.Hour), }, nil }, - reportMetrics: func(ctx context.Context, duration int64, artifactHostName string) {}, + reportMetrics: func(_ context.Context, _ int64, _ string) {}, } authConfig, err := provider.Provide(context.Background(), "artifact") @@ -182,8 +182,8 @@ func TestProvide_Success(t *testing.T) { } func TestProvide_Failure_InvalidHostName(t *testing.T) { - provider := &azureWIAuthProvider{ - getRegistryHost: func(artifact string) (string, error) { + provider := &WIAuthProvider{ + getRegistryHost: func(_ string) (string, error) { return "", errors.New("invalid hostname") }, } From 24d306d5415ca3900de9500b5b972c9596d7bcad Mon Sep 17 00:00:00 2001 From: Shahram Kalantari Date: Sun, 29 Sep 2024 19:35:43 +1000 Subject: [PATCH 04/20] chore: address comments Signed-off-by: Shahram Kalantari --- pkg/common/oras/authprovider/azure/azureworkloadidentity.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pkg/common/oras/authprovider/azure/azureworkloadidentity.go b/pkg/common/oras/authprovider/azure/azureworkloadidentity.go index 241696bcf..b35a4e6ec 100644 --- a/pkg/common/oras/authprovider/azure/azureworkloadidentity.go +++ b/pkg/common/oras/authprovider/azure/azureworkloadidentity.go @@ -155,7 +155,8 @@ func (d *WIAuthProvider) Provide(ctx context.Context, artifact string) (provider serverURL := "https://" + artifactHostName // TODO: Consider adding authentication client options for multicloud scenarios - client, err := d.authClientFactory(serverURL, nil) + var options *azcontainerregistry.AuthenticationClientOptions + client, err := d.authClientFactory(serverURL, options) if err != nil { return provider.AuthConfig{}, re.ErrorCodeAuthDenied.NewError(re.AuthProvider, "", re.AzureWorkloadIdentityLink, err, "failed to create authentication client for container registry", re.HideStackTrace) } From c5d87bf9086a3ad15a043bc9f98848ee34efd8bc Mon Sep 17 00:00:00 2001 From: Shahram Kalantari Date: Mon, 30 Sep 2024 10:46:16 +1000 Subject: [PATCH 05/20] chore: add comments Signed-off-by: Shahram Kalantari --- .../oras/authprovider/azure/azureworkloadidentity.go | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/pkg/common/oras/authprovider/azure/azureworkloadidentity.go b/pkg/common/oras/authprovider/azure/azureworkloadidentity.go index b35a4e6ec..df96cc0ba 100644 --- a/pkg/common/oras/authprovider/azure/azureworkloadidentity.go +++ b/pkg/common/oras/authprovider/azure/azureworkloadidentity.go @@ -134,16 +134,21 @@ func (d *WIAuthProvider) Enabled(_ context.Context) bool { return true } +// Provide returns the credentials for a specified artifact. +// Uses Azure Workload Identity to retrieve an AAD access token which can be +// exchanged for a valid ACR refresh token for login. func (d *WIAuthProvider) Provide(ctx context.Context, artifact string) (provider.AuthConfig, error) { if !d.Enabled(ctx) { return provider.AuthConfig{}, re.ErrorCodeConfigInvalid.WithComponentType(re.AuthProvider).WithDetail("azure workload identity auth provider is not properly enabled") } + // parse the artifact reference string to extract the registry host name artifactHostName, err := d.getRegistryHost(artifact) if err != nil { return provider.AuthConfig{}, re.ErrorCodeHostNameInvalid.WithComponentType(re.AuthProvider) } + // need to refresh AAD token if it's expired if time.Now().Add(time.Minute * 5).After(d.aadToken.ExpiresOn) { newToken, err := d.getAADAccessToken(ctx, d.tenantID, d.clientID, AADResource) if err != nil { @@ -153,7 +158,10 @@ func (d *WIAuthProvider) Provide(ctx context.Context, artifact string) (provider logger.GetLogger(ctx, logOpt).Info("successfully refreshed AAD token") } + // add protocol to generate complete URI serverURL := "https://" + artifactHostName + + // create registry client and exchange AAD token for registry refresh token // TODO: Consider adding authentication client options for multicloud scenarios var options *azcontainerregistry.AuthenticationClientOptions client, err := d.authClientFactory(serverURL, options) From 0d4e8a60c11e374c3203b956e98150ea84ee2051 Mon Sep 17 00:00:00 2001 From: Shahram Kalantari Date: Tue, 1 Oct 2024 17:36:52 +1000 Subject: [PATCH 06/20] chore: more unit tests Signed-off-by: Shahram Kalantari --- .../oras/authprovider/azure/azureidentity.go | 60 ++++--- .../authprovider/azure/azureidentity_test.go | 156 +++++++++++++++++- .../azure/azureworkloadidentity.go | 14 +- .../azure/azureworkloadidentity_test.go | 135 ++++++++++++++- 4 files changed, 329 insertions(+), 36 deletions(-) diff --git a/pkg/common/oras/authprovider/azure/azureidentity.go b/pkg/common/oras/authprovider/azure/azureidentity.go index 19fd78ca1..5cfc21b3c 100644 --- a/pkg/common/oras/authprovider/azure/azureidentity.go +++ b/pkg/common/oras/authprovider/azure/azureidentity.go @@ -33,10 +33,28 @@ import ( ) type azureManagedIdentityProviderFactory struct{} -type azureManagedIdentityAuthProvider struct { - identityToken azcore.AccessToken - clientID string - tenantID string +type MIAuthProvider struct { + identityToken azcore.AccessToken + clientID string + tenantID string + authClientFactory func(serverURL string, options *azcontainerregistry.AuthenticationClientOptions) (AuthClient, error) + getRegistryHost func(artifact string) (string, error) + getManagedIdentityToken func(ctx context.Context, clientID string) (azcore.AccessToken, error) +} + +// NewAzureWIAuthProvider is defined to enable mocking of some of the function in unit tests +func NewAzureMIAuthProvider() *MIAuthProvider { + return &MIAuthProvider{ + authClientFactory: func(serverURL string, options *azcontainerregistry.AuthenticationClientOptions) (AuthClient, error) { + client, err := azcontainerregistry.NewAuthenticationClient(serverURL, options) + if err != nil { + return nil, err + } + return &AuthenticationClientWrapper{client: client}, nil + }, + getRegistryHost: provider.GetRegistryHostName, + getManagedIdentityToken: getManagedIdentityToken, + } } type azureManagedIdentityAuthProviderConf struct { @@ -53,7 +71,7 @@ func init() { provider.Register(azureManagedIdentityAuthProviderName, &azureManagedIdentityProviderFactory{}) } -// Create returns an azureManagedIdentityAuthProvider +// Create returns an MIAuthProvider func (s *azureManagedIdentityProviderFactory) Create(authProviderConfig provider.AuthProviderConfig) (provider.AuthProvider, error) { conf := azureManagedIdentityAuthProviderConf{} authProviderConfigBytes, err := json.Marshal(authProviderConfig) @@ -85,7 +103,7 @@ func (s *azureManagedIdentityProviderFactory) Create(authProviderConfig provider return nil, re.ErrorCodeAuthDenied.NewError(re.AuthProvider, "", re.AzureManagedIdentityLink, err, "", re.HideStackTrace) } - return &azureManagedIdentityAuthProvider{ + return &MIAuthProvider{ identityToken: token, clientID: client, tenantID: tenant, @@ -93,7 +111,7 @@ func (s *azureManagedIdentityProviderFactory) Create(authProviderConfig provider } // Enabled checks for non empty tenant ID and AAD access token -func (d *azureManagedIdentityAuthProvider) Enabled(_ context.Context) bool { +func (d *MIAuthProvider) Enabled(_ context.Context) bool { if d.clientID == "" { return false } @@ -112,36 +130,39 @@ func (d *azureManagedIdentityAuthProvider) Enabled(_ context.Context) bool { // Provide returns the credentials for a specified artifact. // Uses Managed Identity to retrieve an AAD access token which can be // exchanged for a valid ACR refresh token for login. -func (d *azureManagedIdentityAuthProvider) Provide(ctx context.Context, artifact string) (provider.AuthConfig, error) { +func (d *MIAuthProvider) Provide(ctx context.Context, artifact string) (provider.AuthConfig, error) { if !d.Enabled(ctx) { return provider.AuthConfig{}, fmt.Errorf("azure managed identity provider is not properly enabled") } + // parse the artifact reference string to extract the registry host name - artifactHostName, err := provider.GetRegistryHostName(artifact) + artifactHostName, err := d.getRegistryHost(artifact) if err != nil { - return provider.AuthConfig{}, err + return provider.AuthConfig{}, re.ErrorCodeHostNameInvalid.WithComponentType(re.AuthProvider) } // need to refresh AAD token if it's expired if time.Now().Add(time.Minute * 5).After(d.identityToken.ExpiresOn) { - newToken, err := getManagedIdentityToken(ctx, d.clientID) + newToken, err := d.getManagedIdentityToken(ctx, d.clientID) if err != nil { return provider.AuthConfig{}, re.ErrorCodeAuthDenied.NewError(re.AuthProvider, "", re.AzureManagedIdentityLink, err, "could not refresh azure managed identity token", re.HideStackTrace) } d.identityToken = newToken logger.GetLogger(ctx, logOpt).Info("successfully refreshed azure managed identity token") } + // add protocol to generate complete URI serverURL := "https://" + artifactHostName - // create registry client and exchange AAD token for registry refresh token - client, err := azcontainerregistry.NewAuthenticationClient(serverURL, nil) // &AuthenticationClientOptions{ClientOptions: options}) + // TODO: Consider adding authentication client options for multicloud scenarios + var options *azcontainerregistry.AuthenticationClientOptions + client, err := d.authClientFactory(serverURL, options) if err != nil { return provider.AuthConfig{}, re.ErrorCodeAuthDenied.NewError(re.AuthProvider, "", re.AzureWorkloadIdentityLink, err, "failed to create authentication client for container registry by azure managed identity token", re.HideStackTrace) } - // refreshTokenClient := containerregistry.NewRefreshTokensClient(serverURL) - rt, err := client.ExchangeAADAccessTokenForACRRefreshToken( - context.Background(), + + response, err := client.ExchangeAADAccessTokenForACRRefreshToken( + ctx, "access_token", artifactHostName, &azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenOptions{ @@ -149,18 +170,17 @@ func (d *azureManagedIdentityAuthProvider) Provide(ctx context.Context, artifact Tenant: &d.tenantID, }, ) - // rt, err := refreshTokenClient.GetFromExchange(ctx, "access_token", artifactHostName, d.tenantID, "", d.identityToken.Token) if err != nil { return provider.AuthConfig{}, re.ErrorCodeAuthDenied.NewError(re.AuthProvider, "", re.AzureManagedIdentityLink, err, "failed to get refresh token for container registry by azure managed identity token", re.HideStackTrace) } + rt := response.ACRRefreshToken - expiresOn := getACRExpiryIfEarlier(d.identityToken.ExpiresOn) - + refreshTokenExpiry := getACRExpiryIfEarlier(d.identityToken.ExpiresOn) authConfig := provider.AuthConfig{ Username: dockerTokenLoginUsernameGUID, Password: *rt.RefreshToken, Provider: d, - ExpiresOn: expiresOn, + ExpiresOn: refreshTokenExpiry, } return authConfig, nil diff --git a/pkg/common/oras/authprovider/azure/azureidentity_test.go b/pkg/common/oras/authprovider/azure/azureidentity_test.go index 472e704b9..680b6b489 100644 --- a/pkg/common/oras/authprovider/azure/azureidentity_test.go +++ b/pkg/common/oras/authprovider/azure/azureidentity_test.go @@ -20,15 +20,28 @@ import ( "errors" "os" "testing" + "time" "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/containers/azcontainerregistry" ratifyerrors "github.com/ratify-project/ratify/errors" "github.com/ratify-project/ratify/pkg/common/oras/authprovider" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" ) +type MockGetManagedIdentityToken struct { + mock.Mock +} + +func (m *MockGetManagedIdentityToken) GetManagedIdentityToken(ctx context.Context, clientID string) (azcore.AccessToken, error) { + args := m.Called(ctx, clientID) + return args.Get(0).(azcore.AccessToken), args.Error(1) +} + // Verifies that Enabled checks if tenantID is empty or AAD token is empty func TestAzureMSIEnabled_ExpectedResults(t *testing.T) { - azAuthProvider := azureManagedIdentityAuthProvider{ + azAuthProvider := MIAuthProvider{ tenantID: "test_tenant", clientID: "test_client", identityToken: azcore.AccessToken{ @@ -89,3 +102,144 @@ func TestAzureMSIValidation_EnvironmentVariables_ExpectedResults(t *testing.T) { t.Fatalf("create auth provider should have failed: expected err %s, but got err %s", expectedErr, err) } } + +func TestNewAzureMIAuthProvider_AuthenticationClientError(t *testing.T) { + // Create a new mock client factory + mockFactory := new(MockAuthClientFactory) + + // Setup mock to return an error + mockFactory.On("NewAuthenticationClient", mock.Anything, mock.Anything). + Return(nil, errors.New("failed to create authentication client")) + + // Create a new WIAuthProvider instance + provider := NewAzureMIAuthProvider() + provider.authClientFactory = mockFactory.NewAuthenticationClient + + // Call authClientFactory to test error handling + _, err := provider.authClientFactory("https://myregistry.azurecr.io", nil) + + // Assert that an error is returned + assert.Error(t, err) + assert.Equal(t, "failed to create authentication client", err.Error()) + + // Verify that the mock was called + mockFactory.AssertCalled(t, "NewAuthenticationClient", "https://myregistry.azurecr.io", mock.Anything) +} + +func TestNewAzureMIAuthProvider_Success(t *testing.T) { + // Create a new mock client factory + mockFactory := new(MockAuthClientFactory) + + // Create a mock auth client to return from the factory + mockAuthClient := new(MockAuthClient) + + // Setup mock to return a successful auth client + mockFactory.On("NewAuthenticationClient", mock.Anything, mock.Anything). + Return(mockAuthClient, nil) + + // Create a new WIAuthProvider instance + provider := NewAzureMIAuthProvider() + + // Replace authClientFactory with the mock factory + provider.authClientFactory = mockFactory.NewAuthenticationClient + + // Call authClientFactory to test successful return + client, err := provider.authClientFactory("https://myregistry.azurecr.io", nil) + + // Assert that the client is returned without an error + assert.NoError(t, err) + assert.NotNil(t, client) + + // Assert that the returned client is of the expected type + _, ok := client.(*MockAuthClient) + assert.True(t, ok, "expected client to be of type *MockAuthClient") + + // Verify that the mock was called + mockFactory.AssertCalled(t, "NewAuthenticationClient", "https://myregistry.azurecr.io", mock.Anything) +} + +func TestMIProvide_Success(t *testing.T) { + const registryHost = "myregistry.azurecr.io" + mockClient := new(MockAuthClient) + expectedRefreshToken := "mocked_refresh_token" + mockClient.On("ExchangeAADAccessTokenForACRRefreshToken", mock.Anything, "access_token", registryHost, mock.Anything). + Return(azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse{ + ACRRefreshToken: azcontainerregistry.ACRRefreshToken{RefreshToken: &expectedRefreshToken}, + }, nil) + + provider := &MIAuthProvider{ + identityToken: azcore.AccessToken{ + Token: "mockToken", + ExpiresOn: time.Now().Add(time.Hour), + }, + tenantID: "mockTenantID", + clientID: "mockClientID", + authClientFactory: func(_ string, _ *azcontainerregistry.AuthenticationClientOptions) (AuthClient, error) { + return mockClient, nil + }, + getRegistryHost: func(_ string) (string, error) { + return registryHost, nil + }, + getManagedIdentityToken: func(_ context.Context, _ string) (azcore.AccessToken, error) { + return azcore.AccessToken{ + Token: "mockToken", + ExpiresOn: time.Now().Add(time.Hour), + }, nil + }, + } + + authConfig, err := provider.Provide(context.Background(), "artifact") + + assert.NoError(t, err) + // Assert that getManagedIdentityToken was not called + mockClient.AssertNotCalled(t, "getManagedIdentityToken", mock.Anything, mock.Anything) + // Assert that the returned refresh token matches the expected one + assert.Equal(t, expectedRefreshToken, authConfig.Password) +} + +func TestMIProvide_RefreshAAD(t *testing.T) { + const registryHost = "myregistry.azurecr.io" + // Arrange + mockClient := new(MockAuthClient) + + // Create a mock function for getManagedIdentityToken + mockGetManagedIdentityToken := new(MockGetManagedIdentityToken) + + provider := &MIAuthProvider{ + identityToken: azcore.AccessToken{ + Token: "mockToken", + ExpiresOn: time.Now(), // Expired token to force a refresh + }, + tenantID: "mockTenantID", + clientID: "mockClientID", + authClientFactory: func(_ string, _ *azcontainerregistry.AuthenticationClientOptions) (AuthClient, error) { + return mockClient, nil + }, + getRegistryHost: func(_ string) (string, error) { + return registryHost, nil + }, + getManagedIdentityToken: mockGetManagedIdentityToken.GetManagedIdentityToken, // Use the mock + } + + mockClient.On("ExchangeAADAccessTokenForACRRefreshToken", mock.Anything, "access_token", registryHost, mock.Anything). + Return(azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse{ + ACRRefreshToken: azcontainerregistry.ACRRefreshToken{RefreshToken: new(string)}, + }, nil) + + // Set up the expectation for the mocked method + mockGetManagedIdentityToken.On("GetManagedIdentityToken", mock.Anything, "mockClientID"). + Return(azcore.AccessToken{ + Token: "newMockToken", + ExpiresOn: time.Now().Add(time.Hour), + }, nil) + + ctx := context.TODO() + artifact := "testArtifact" + + // Act + _, err := provider.Provide(ctx, artifact) + + // Assert + assert.NoError(t, err) + mockGetManagedIdentityToken.AssertCalled(t, "GetManagedIdentityToken", mock.Anything, "mockClientID") // Assert that getManagedIdentityToken was called +} diff --git a/pkg/common/oras/authprovider/azure/azureworkloadidentity.go b/pkg/common/oras/authprovider/azure/azureworkloadidentity.go index df96cc0ba..77744622b 100644 --- a/pkg/common/oras/authprovider/azure/azureworkloadidentity.go +++ b/pkg/common/oras/authprovider/azure/azureworkloadidentity.go @@ -36,32 +36,33 @@ type WIAuthProvider struct { aadToken confidential.AuthResult tenantID string clientID string - authClientFactory func(serverURL string, options *azcontainerregistry.AuthenticationClientOptions) (authClient, error) + authClientFactory func(serverURL string, options *azcontainerregistry.AuthenticationClientOptions) (AuthClient, error) getRegistryHost func(artifact string) (string, error) getAADAccessToken func(ctx context.Context, tenantID, clientID, resource string) (confidential.AuthResult, error) reportMetrics func(ctx context.Context, duration int64, artifactHostName string) } -type authenticationClientWrapper struct { +type AuthenticationClientWrapper struct { client *azcontainerregistry.AuthenticationClient } -func (w *authenticationClientWrapper) ExchangeAADAccessTokenForACRRefreshToken(ctx context.Context, grantType, service string, options *azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenOptions) (azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse, error) { +func (w *AuthenticationClientWrapper) ExchangeAADAccessTokenForACRRefreshToken(ctx context.Context, grantType, service string, options *azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenOptions) (azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse, error) { return w.client.ExchangeAADAccessTokenForACRRefreshToken(ctx, azcontainerregistry.PostContentSchemaGrantType(grantType), service, options) } -type authClient interface { +type AuthClient interface { ExchangeAADAccessTokenForACRRefreshToken(ctx context.Context, grantType, service string, options *azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenOptions) (azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse, error) } +// NewAzureWIAuthProvider is defined to enable mocking of some of the function in unit tests func NewAzureWIAuthProvider() *WIAuthProvider { return &WIAuthProvider{ - authClientFactory: func(serverURL string, options *azcontainerregistry.AuthenticationClientOptions) (authClient, error) { + authClientFactory: func(serverURL string, options *azcontainerregistry.AuthenticationClientOptions) (AuthClient, error) { client, err := azcontainerregistry.NewAuthenticationClient(serverURL, options) if err != nil { return nil, err } - return &authenticationClientWrapper{client: client}, nil + return &AuthenticationClientWrapper{client: client}, nil }, getRegistryHost: provider.GetRegistryHostName, getAADAccessToken: azureauth.GetAADAccessToken, @@ -161,7 +162,6 @@ func (d *WIAuthProvider) Provide(ctx context.Context, artifact string) (provider // add protocol to generate complete URI serverURL := "https://" + artifactHostName - // create registry client and exchange AAD token for registry refresh token // TODO: Consider adding authentication client options for multicloud scenarios var options *azcontainerregistry.AuthenticationClientOptions client, err := d.authClientFactory(serverURL, options) diff --git a/pkg/common/oras/authprovider/azure/azureworkloadidentity_test.go b/pkg/common/oras/authprovider/azure/azureworkloadidentity_test.go index b4a9a1f7c..1f6d3743b 100644 --- a/pkg/common/oras/authprovider/azure/azureworkloadidentity_test.go +++ b/pkg/common/oras/authprovider/azure/azureworkloadidentity_test.go @@ -30,6 +30,36 @@ import ( "github.com/stretchr/testify/mock" ) +type MockAuthClient struct { + mock.Mock +} + +type MockAzureAuth struct { + mock.Mock +} + +type MockAuthClientFactory struct { + mock.Mock +} + +func (m *MockAuthClientFactory) NewAuthenticationClient(serverURL string, options *azcontainerregistry.AuthenticationClientOptions) (AuthClient, error) { + args := m.Called(serverURL, options) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(AuthClient), args.Error(1) +} + +func (m *MockAzureAuth) GetAADAccessToken(ctx context.Context, tenantID, clientID, resource string) (confidential.AuthResult, error) { + args := m.Called(ctx, tenantID, clientID, resource) + return args.Get(0).(confidential.AuthResult), args.Error(1) +} + +func (m *MockAuthClient) ExchangeAADAccessTokenForACRRefreshToken(ctx context.Context, grantType, service string, options *azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenOptions) (azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse, error) { + args := m.Called(ctx, grantType, service, options) + return args.Get(0).(azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse), args.Error(1) +} + // Verifies that Enabled checks if tenantID is empty or AAD token is empty func TestAzureWIEnabled_ExpectedResults(t *testing.T) { azAuthProvider := WIAuthProvider{ @@ -135,17 +165,63 @@ func TestAzureWIValidation_EnvironmentVariables_ExpectedResults(t *testing.T) { } } -type mockAuthClient struct { - mock.Mock +func TestNewAzureWIAuthProvider_AuthenticationClientError(t *testing.T) { + // Create a new mock client factory + mockFactory := new(MockAuthClientFactory) + + // Setup mock to return an error + mockFactory.On("NewAuthenticationClient", mock.Anything, mock.Anything). + Return(nil, errors.New("failed to create authentication client")) + + // Create a new WIAuthProvider instance + provider := NewAzureWIAuthProvider() + provider.authClientFactory = mockFactory.NewAuthenticationClient + + // Call authClientFactory to test error handling + _, err := provider.authClientFactory("https://myregistry.azurecr.io", nil) + + // Assert that an error is returned + assert.Error(t, err) + assert.Equal(t, "failed to create authentication client", err.Error()) + + // Verify that the mock was called + mockFactory.AssertCalled(t, "NewAuthenticationClient", "https://myregistry.azurecr.io", mock.Anything) } -func (m *mockAuthClient) ExchangeAADAccessTokenForACRRefreshToken(ctx context.Context, grantType, service string, options *azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenOptions) (azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse, error) { - args := m.Called(ctx, grantType, service, options) - return args.Get(0).(azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse), args.Error(1) +func TestNewAzureWIAuthProvider_Success(t *testing.T) { + // Create a new mock client factory + mockFactory := new(MockAuthClientFactory) + + // Create a mock auth client to return from the factory + mockAuthClient := new(MockAuthClient) + + // Setup mock to return a successful auth client + mockFactory.On("NewAuthenticationClient", mock.Anything, mock.Anything). + Return(mockAuthClient, nil) + + // Create a new WIAuthProvider instance + provider := NewAzureWIAuthProvider() + + // Replace authClientFactory with the mock factory + provider.authClientFactory = mockFactory.NewAuthenticationClient + + // Call authClientFactory to test successful return + client, err := provider.authClientFactory("https://myregistry.azurecr.io", nil) + + // Assert that the client is returned without an error + assert.NoError(t, err) + assert.NotNil(t, client) + + // Assert that the returned client is of the expected type + _, ok := client.(*MockAuthClient) + assert.True(t, ok, "expected client to be of type *MockAuthClient") + + // Verify that the mock was called + mockFactory.AssertCalled(t, "NewAuthenticationClient", "https://myregistry.azurecr.io", mock.Anything) } -func TestProvide_Success(t *testing.T) { - mockClient := new(mockAuthClient) +func TestWIProvide_Success(t *testing.T) { + mockClient := new(MockAuthClient) expectedRefreshToken := "mocked_refresh_token" mockClient.On("ExchangeAADAccessTokenForACRRefreshToken", mock.Anything, "access_token", "myregistry.azurecr.io", mock.Anything). Return(azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse{ @@ -159,7 +235,7 @@ func TestProvide_Success(t *testing.T) { }, tenantID: "mockTenantID", clientID: "mockClientID", - authClientFactory: func(_ string, _ *azcontainerregistry.AuthenticationClientOptions) (authClient, error) { + authClientFactory: func(_ string, _ *azcontainerregistry.AuthenticationClientOptions) (AuthClient, error) { return mockClient, nil }, getRegistryHost: func(_ string) (string, error) { @@ -177,10 +253,53 @@ func TestProvide_Success(t *testing.T) { authConfig, err := provider.Provide(context.Background(), "artifact") assert.NoError(t, err) + // Assert that GetAADAccessToken was not called + mockClient.AssertNotCalled(t, "GetAADAccessToken", mock.Anything, mock.Anything, mock.Anything, mock.Anything) // Assert that the returned refresh token matches the expected one assert.Equal(t, expectedRefreshToken, authConfig.Password) } +func TestWIProvide_RefreshAAD(t *testing.T) { + // Arrange + mockAzureAuth := new(MockAzureAuth) + mockClient := new(MockAuthClient) + + provider := &WIAuthProvider{ + aadToken: confidential.AuthResult{ + AccessToken: "mockToken", + ExpiresOn: time.Now(), + }, + tenantID: "mockTenantID", + clientID: "mockClientID", + authClientFactory: func(_ string, _ *azcontainerregistry.AuthenticationClientOptions) (AuthClient, error) { + return mockClient, nil + }, + getRegistryHost: func(_ string) (string, error) { + return "myregistry.azurecr.io", nil + }, + getAADAccessToken: mockAzureAuth.GetAADAccessToken, + reportMetrics: func(_ context.Context, _ int64, _ string) {}, + } + + mockAzureAuth.On("GetAADAccessToken", mock.Anything, mock.Anything, mock.Anything, mock.Anything). + Return(confidential.AuthResult{AccessToken: "newAccessToken", ExpiresOn: time.Now().Add(time.Hour)}, nil) + + mockClient.On("ExchangeAADAccessTokenForACRRefreshToken", mock.Anything, "access_token", "myregistry.azurecr.io", mock.Anything). + Return(azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse{ + ACRRefreshToken: azcontainerregistry.ACRRefreshToken{RefreshToken: new(string)}, + }, nil) + + ctx := context.TODO() + artifact := "testArtifact" + + // Act + _, err := provider.Provide(ctx, artifact) + + assert.NoError(t, err) + // Assert that GetAADAccessToken was not called + mockAzureAuth.AssertCalled(t, "GetAADAccessToken", mock.Anything, mock.Anything, mock.Anything, mock.Anything) +} + func TestProvide_Failure_InvalidHostName(t *testing.T) { provider := &WIAuthProvider{ getRegistryHost: func(_ string) (string, error) { From cf5e7f0123afc51b0b4a1c9b368f222eda6630ac Mon Sep 17 00:00:00 2001 From: Shahram Kalantari Date: Thu, 3 Oct 2024 12:21:40 +1000 Subject: [PATCH 07/20] chore: remove unnecessary functions Signed-off-by: Shahram Kalantari --- .../oras/authprovider/azure/azureidentity.go | 15 ----- .../authprovider/azure/azureidentity_test.go | 66 ++++--------------- .../azure/azureworkloadidentity.go | 17 ----- .../azure/azureworkloadidentity_test.go | 57 +--------------- 4 files changed, 12 insertions(+), 143 deletions(-) diff --git a/pkg/common/oras/authprovider/azure/azureidentity.go b/pkg/common/oras/authprovider/azure/azureidentity.go index 5cfc21b3c..380f692ad 100644 --- a/pkg/common/oras/authprovider/azure/azureidentity.go +++ b/pkg/common/oras/authprovider/azure/azureidentity.go @@ -42,21 +42,6 @@ type MIAuthProvider struct { getManagedIdentityToken func(ctx context.Context, clientID string) (azcore.AccessToken, error) } -// NewAzureWIAuthProvider is defined to enable mocking of some of the function in unit tests -func NewAzureMIAuthProvider() *MIAuthProvider { - return &MIAuthProvider{ - authClientFactory: func(serverURL string, options *azcontainerregistry.AuthenticationClientOptions) (AuthClient, error) { - client, err := azcontainerregistry.NewAuthenticationClient(serverURL, options) - if err != nil { - return nil, err - } - return &AuthenticationClientWrapper{client: client}, nil - }, - getRegistryHost: provider.GetRegistryHostName, - getManagedIdentityToken: getManagedIdentityToken, - } -} - type azureManagedIdentityAuthProviderConf struct { Name string `json:"name"` ClientID string `json:"clientID"` diff --git a/pkg/common/oras/authprovider/azure/azureidentity_test.go b/pkg/common/oras/authprovider/azure/azureidentity_test.go index 680b6b489..af33afe44 100644 --- a/pkg/common/oras/authprovider/azure/azureidentity_test.go +++ b/pkg/common/oras/authprovider/azure/azureidentity_test.go @@ -103,61 +103,6 @@ func TestAzureMSIValidation_EnvironmentVariables_ExpectedResults(t *testing.T) { } } -func TestNewAzureMIAuthProvider_AuthenticationClientError(t *testing.T) { - // Create a new mock client factory - mockFactory := new(MockAuthClientFactory) - - // Setup mock to return an error - mockFactory.On("NewAuthenticationClient", mock.Anything, mock.Anything). - Return(nil, errors.New("failed to create authentication client")) - - // Create a new WIAuthProvider instance - provider := NewAzureMIAuthProvider() - provider.authClientFactory = mockFactory.NewAuthenticationClient - - // Call authClientFactory to test error handling - _, err := provider.authClientFactory("https://myregistry.azurecr.io", nil) - - // Assert that an error is returned - assert.Error(t, err) - assert.Equal(t, "failed to create authentication client", err.Error()) - - // Verify that the mock was called - mockFactory.AssertCalled(t, "NewAuthenticationClient", "https://myregistry.azurecr.io", mock.Anything) -} - -func TestNewAzureMIAuthProvider_Success(t *testing.T) { - // Create a new mock client factory - mockFactory := new(MockAuthClientFactory) - - // Create a mock auth client to return from the factory - mockAuthClient := new(MockAuthClient) - - // Setup mock to return a successful auth client - mockFactory.On("NewAuthenticationClient", mock.Anything, mock.Anything). - Return(mockAuthClient, nil) - - // Create a new WIAuthProvider instance - provider := NewAzureMIAuthProvider() - - // Replace authClientFactory with the mock factory - provider.authClientFactory = mockFactory.NewAuthenticationClient - - // Call authClientFactory to test successful return - client, err := provider.authClientFactory("https://myregistry.azurecr.io", nil) - - // Assert that the client is returned without an error - assert.NoError(t, err) - assert.NotNil(t, client) - - // Assert that the returned client is of the expected type - _, ok := client.(*MockAuthClient) - assert.True(t, ok, "expected client to be of type *MockAuthClient") - - // Verify that the mock was called - mockFactory.AssertCalled(t, "NewAuthenticationClient", "https://myregistry.azurecr.io", mock.Anything) -} - func TestMIProvide_Success(t *testing.T) { const registryHost = "myregistry.azurecr.io" mockClient := new(MockAuthClient) @@ -243,3 +188,14 @@ func TestMIProvide_RefreshAAD(t *testing.T) { assert.NoError(t, err) mockGetManagedIdentityToken.AssertCalled(t, "GetManagedIdentityToken", mock.Anything, "mockClientID") // Assert that getManagedIdentityToken was called } + +func TestMIProvide_Failure_InvalidHostName(t *testing.T) { + provider := &MIAuthProvider{ + getRegistryHost: func(_ string) (string, error) { + return "", errors.New("invalid hostname") + }, + } + + _, err := provider.Provide(context.Background(), "artifact") + assert.Error(t, err) +} diff --git a/pkg/common/oras/authprovider/azure/azureworkloadidentity.go b/pkg/common/oras/authprovider/azure/azureworkloadidentity.go index 77744622b..a5ed6d9f2 100644 --- a/pkg/common/oras/authprovider/azure/azureworkloadidentity.go +++ b/pkg/common/oras/authprovider/azure/azureworkloadidentity.go @@ -25,7 +25,6 @@ import ( re "github.com/ratify-project/ratify/errors" "github.com/ratify-project/ratify/internal/logger" provider "github.com/ratify-project/ratify/pkg/common/oras/authprovider" - "github.com/ratify-project/ratify/pkg/metrics" "github.com/ratify-project/ratify/pkg/utils/azureauth" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/confidential" @@ -54,22 +53,6 @@ type AuthClient interface { ExchangeAADAccessTokenForACRRefreshToken(ctx context.Context, grantType, service string, options *azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenOptions) (azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse, error) } -// NewAzureWIAuthProvider is defined to enable mocking of some of the function in unit tests -func NewAzureWIAuthProvider() *WIAuthProvider { - return &WIAuthProvider{ - authClientFactory: func(serverURL string, options *azcontainerregistry.AuthenticationClientOptions) (AuthClient, error) { - client, err := azcontainerregistry.NewAuthenticationClient(serverURL, options) - if err != nil { - return nil, err - } - return &AuthenticationClientWrapper{client: client}, nil - }, - getRegistryHost: provider.GetRegistryHostName, - getAADAccessToken: azureauth.GetAADAccessToken, - reportMetrics: metrics.ReportACRExchangeDuration, - } -} - type azureWIAuthProviderConf struct { Name string `json:"name"` ClientID string `json:"clientID,omitempty"` diff --git a/pkg/common/oras/authprovider/azure/azureworkloadidentity_test.go b/pkg/common/oras/authprovider/azure/azureworkloadidentity_test.go index 1f6d3743b..bbfd29eb9 100644 --- a/pkg/common/oras/authprovider/azure/azureworkloadidentity_test.go +++ b/pkg/common/oras/authprovider/azure/azureworkloadidentity_test.go @@ -165,61 +165,6 @@ func TestAzureWIValidation_EnvironmentVariables_ExpectedResults(t *testing.T) { } } -func TestNewAzureWIAuthProvider_AuthenticationClientError(t *testing.T) { - // Create a new mock client factory - mockFactory := new(MockAuthClientFactory) - - // Setup mock to return an error - mockFactory.On("NewAuthenticationClient", mock.Anything, mock.Anything). - Return(nil, errors.New("failed to create authentication client")) - - // Create a new WIAuthProvider instance - provider := NewAzureWIAuthProvider() - provider.authClientFactory = mockFactory.NewAuthenticationClient - - // Call authClientFactory to test error handling - _, err := provider.authClientFactory("https://myregistry.azurecr.io", nil) - - // Assert that an error is returned - assert.Error(t, err) - assert.Equal(t, "failed to create authentication client", err.Error()) - - // Verify that the mock was called - mockFactory.AssertCalled(t, "NewAuthenticationClient", "https://myregistry.azurecr.io", mock.Anything) -} - -func TestNewAzureWIAuthProvider_Success(t *testing.T) { - // Create a new mock client factory - mockFactory := new(MockAuthClientFactory) - - // Create a mock auth client to return from the factory - mockAuthClient := new(MockAuthClient) - - // Setup mock to return a successful auth client - mockFactory.On("NewAuthenticationClient", mock.Anything, mock.Anything). - Return(mockAuthClient, nil) - - // Create a new WIAuthProvider instance - provider := NewAzureWIAuthProvider() - - // Replace authClientFactory with the mock factory - provider.authClientFactory = mockFactory.NewAuthenticationClient - - // Call authClientFactory to test successful return - client, err := provider.authClientFactory("https://myregistry.azurecr.io", nil) - - // Assert that the client is returned without an error - assert.NoError(t, err) - assert.NotNil(t, client) - - // Assert that the returned client is of the expected type - _, ok := client.(*MockAuthClient) - assert.True(t, ok, "expected client to be of type *MockAuthClient") - - // Verify that the mock was called - mockFactory.AssertCalled(t, "NewAuthenticationClient", "https://myregistry.azurecr.io", mock.Anything) -} - func TestWIProvide_Success(t *testing.T) { mockClient := new(MockAuthClient) expectedRefreshToken := "mocked_refresh_token" @@ -300,7 +245,7 @@ func TestWIProvide_RefreshAAD(t *testing.T) { mockAzureAuth.AssertCalled(t, "GetAADAccessToken", mock.Anything, mock.Anything, mock.Anything, mock.Anything) } -func TestProvide_Failure_InvalidHostName(t *testing.T) { +func TestWIProvide_Failure_InvalidHostName(t *testing.T) { provider := &WIAuthProvider{ getRegistryHost: func(_ string) (string, error) { return "", errors.New("invalid hostname") From f213c5f6ce9bc510b054c92a2b6dfe48f6554263 Mon Sep 17 00:00:00 2001 From: Shahram Kalantari Date: Thu, 3 Oct 2024 14:37:22 +1000 Subject: [PATCH 08/20] fix: fix the bugs in the unit tests Signed-off-by: Shahram Kalantari --- pkg/common/oras/authprovider/azure/azureidentity_test.go | 5 +++++ .../oras/authprovider/azure/azureworkloadidentity_test.go | 5 +++++ 2 files changed, 10 insertions(+) diff --git a/pkg/common/oras/authprovider/azure/azureidentity_test.go b/pkg/common/oras/authprovider/azure/azureidentity_test.go index af33afe44..11fb48f5f 100644 --- a/pkg/common/oras/authprovider/azure/azureidentity_test.go +++ b/pkg/common/oras/authprovider/azure/azureidentity_test.go @@ -191,6 +191,11 @@ func TestMIProvide_RefreshAAD(t *testing.T) { func TestMIProvide_Failure_InvalidHostName(t *testing.T) { provider := &MIAuthProvider{ + tenantID: "test_tenant", + clientID: "test_client", + identityToken: azcore.AccessToken{ + Token: "test_token", + }, getRegistryHost: func(_ string) (string, error) { return "", errors.New("invalid hostname") }, diff --git a/pkg/common/oras/authprovider/azure/azureworkloadidentity_test.go b/pkg/common/oras/authprovider/azure/azureworkloadidentity_test.go index bbfd29eb9..1b58e18fe 100644 --- a/pkg/common/oras/authprovider/azure/azureworkloadidentity_test.go +++ b/pkg/common/oras/authprovider/azure/azureworkloadidentity_test.go @@ -247,6 +247,11 @@ func TestWIProvide_RefreshAAD(t *testing.T) { func TestWIProvide_Failure_InvalidHostName(t *testing.T) { provider := &WIAuthProvider{ + aadToken: confidential.AuthResult{ + AccessToken: "mockToken", + ExpiresOn: time.Now(), + }, + tenantID: "mockTenantID", getRegistryHost: func(_ string) (string, error) { return "", errors.New("invalid hostname") }, From 3ec6a2d5361e0367f8e2f481921bea1299079790 Mon Sep 17 00:00:00 2001 From: Shahram Kalantari Date: Fri, 4 Oct 2024 12:07:46 +1000 Subject: [PATCH 09/20] fix: provide default implementation of the functions Signed-off-by: Shahram Kalantari --- .../oras/authprovider/azure/azureidentity.go | 8 +++--- .../azure/azureworkloadidentity.go | 25 ++++++++++++++++--- 2 files changed, 27 insertions(+), 6 deletions(-) diff --git a/pkg/common/oras/authprovider/azure/azureidentity.go b/pkg/common/oras/authprovider/azure/azureidentity.go index 380f692ad..85c7cf7ad 100644 --- a/pkg/common/oras/authprovider/azure/azureidentity.go +++ b/pkg/common/oras/authprovider/azure/azureidentity.go @@ -89,9 +89,11 @@ func (s *azureManagedIdentityProviderFactory) Create(authProviderConfig provider } return &MIAuthProvider{ - identityToken: token, - clientID: client, - tenantID: tenant, + identityToken: token, + clientID: client, + tenantID: tenant, + authClientFactory: defaultAuthClientFactory, + getManagedIdentityToken: getManagedIdentityToken, }, nil } diff --git a/pkg/common/oras/authprovider/azure/azureworkloadidentity.go b/pkg/common/oras/authprovider/azure/azureworkloadidentity.go index a5ed6d9f2..c0cbf8377 100644 --- a/pkg/common/oras/authprovider/azure/azureworkloadidentity.go +++ b/pkg/common/oras/authprovider/azure/azureworkloadidentity.go @@ -41,6 +41,22 @@ type WIAuthProvider struct { reportMetrics func(ctx context.Context, duration int64, artifactHostName string) } +func defaultAuthClientFactory(serverURL string, options *azcontainerregistry.AuthenticationClientOptions) (AuthClient, error) { + client, err := azcontainerregistry.NewAuthenticationClient(serverURL, options) + if err != nil { + return nil, err + } + return &AuthenticationClientWrapper{client: client}, nil +} + +func defaultGetAADAccessToken(ctx context.Context, tenantID, clientID, resource string) (confidential.AuthResult, error) { + return azureauth.GetAADAccessToken(ctx, tenantID, clientID, resource) +} + +func defaultReportMetrics(ctx context.Context, duration int64, artifactHostName string) { + logger.GetLogger(ctx, logOpt).Infof("Metrics Report: Duration=%dms, Host=%s", duration, artifactHostName) +} + type AuthenticationClientWrapper struct { client *azcontainerregistry.AuthenticationClient } @@ -99,9 +115,12 @@ func (s *AzureWIProviderFactory) Create(authProviderConfig provider.AuthProvider } return &WIAuthProvider{ - aadToken: token, - tenantID: tenant, - clientID: clientID, + aadToken: token, + tenantID: tenant, + clientID: clientID, + authClientFactory: defaultAuthClientFactory, + getAADAccessToken: defaultGetAADAccessToken, + reportMetrics: defaultReportMetrics, }, nil } From 23cf3d2a68de48a7b96ce3b6f66ebf8ef3a0971b Mon Sep 17 00:00:00 2001 From: Shahram Kalantari Date: Wed, 9 Oct 2024 10:25:23 +1000 Subject: [PATCH 10/20] chore: refactor Signed-off-by: Shahram Kalantari --- .../oras/authprovider/azure/azureidentity.go | 2 +- .../azure/azureworkloadidentity.go | 34 ++---------- .../azure/azureworkloadidentity_test.go | 12 ----- pkg/common/oras/authprovider/azure/helper.go | 53 +++++++++++++++++++ 4 files changed, 57 insertions(+), 44 deletions(-) create mode 100644 pkg/common/oras/authprovider/azure/helper.go diff --git a/pkg/common/oras/authprovider/azure/azureidentity.go b/pkg/common/oras/authprovider/azure/azureidentity.go index 85c7cf7ad..f9f4adb67 100644 --- a/pkg/common/oras/authprovider/azure/azureidentity.go +++ b/pkg/common/oras/authprovider/azure/azureidentity.go @@ -92,7 +92,7 @@ func (s *azureManagedIdentityProviderFactory) Create(authProviderConfig provider identityToken: token, clientID: client, tenantID: tenant, - authClientFactory: defaultAuthClientFactory, + authClientFactory: DefaultAuthClientFactory, getManagedIdentityToken: getManagedIdentityToken, }, nil } diff --git a/pkg/common/oras/authprovider/azure/azureworkloadidentity.go b/pkg/common/oras/authprovider/azure/azureworkloadidentity.go index c0cbf8377..57033100b 100644 --- a/pkg/common/oras/authprovider/azure/azureworkloadidentity.go +++ b/pkg/common/oras/authprovider/azure/azureworkloadidentity.go @@ -41,34 +41,6 @@ type WIAuthProvider struct { reportMetrics func(ctx context.Context, duration int64, artifactHostName string) } -func defaultAuthClientFactory(serverURL string, options *azcontainerregistry.AuthenticationClientOptions) (AuthClient, error) { - client, err := azcontainerregistry.NewAuthenticationClient(serverURL, options) - if err != nil { - return nil, err - } - return &AuthenticationClientWrapper{client: client}, nil -} - -func defaultGetAADAccessToken(ctx context.Context, tenantID, clientID, resource string) (confidential.AuthResult, error) { - return azureauth.GetAADAccessToken(ctx, tenantID, clientID, resource) -} - -func defaultReportMetrics(ctx context.Context, duration int64, artifactHostName string) { - logger.GetLogger(ctx, logOpt).Infof("Metrics Report: Duration=%dms, Host=%s", duration, artifactHostName) -} - -type AuthenticationClientWrapper struct { - client *azcontainerregistry.AuthenticationClient -} - -func (w *AuthenticationClientWrapper) ExchangeAADAccessTokenForACRRefreshToken(ctx context.Context, grantType, service string, options *azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenOptions) (azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse, error) { - return w.client.ExchangeAADAccessTokenForACRRefreshToken(ctx, azcontainerregistry.PostContentSchemaGrantType(grantType), service, options) -} - -type AuthClient interface { - ExchangeAADAccessTokenForACRRefreshToken(ctx context.Context, grantType, service string, options *azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenOptions) (azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse, error) -} - type azureWIAuthProviderConf struct { Name string `json:"name"` ClientID string `json:"clientID,omitempty"` @@ -118,9 +90,9 @@ func (s *AzureWIProviderFactory) Create(authProviderConfig provider.AuthProvider aadToken: token, tenantID: tenant, clientID: clientID, - authClientFactory: defaultAuthClientFactory, - getAADAccessToken: defaultGetAADAccessToken, - reportMetrics: defaultReportMetrics, + authClientFactory: DefaultAuthClientFactory, + getAADAccessToken: DefaultGetAADAccessToken, + reportMetrics: DefaultReportMetrics, }, nil } diff --git a/pkg/common/oras/authprovider/azure/azureworkloadidentity_test.go b/pkg/common/oras/authprovider/azure/azureworkloadidentity_test.go index 1b58e18fe..17ce740c8 100644 --- a/pkg/common/oras/authprovider/azure/azureworkloadidentity_test.go +++ b/pkg/common/oras/authprovider/azure/azureworkloadidentity_test.go @@ -38,18 +38,6 @@ type MockAzureAuth struct { mock.Mock } -type MockAuthClientFactory struct { - mock.Mock -} - -func (m *MockAuthClientFactory) NewAuthenticationClient(serverURL string, options *azcontainerregistry.AuthenticationClientOptions) (AuthClient, error) { - args := m.Called(serverURL, options) - if args.Get(0) == nil { - return nil, args.Error(1) - } - return args.Get(0).(AuthClient), args.Error(1) -} - func (m *MockAzureAuth) GetAADAccessToken(ctx context.Context, tenantID, clientID, resource string) (confidential.AuthResult, error) { args := m.Called(ctx, tenantID, clientID, resource) return args.Get(0).(confidential.AuthResult), args.Error(1) diff --git a/pkg/common/oras/authprovider/azure/helper.go b/pkg/common/oras/authprovider/azure/helper.go new file mode 100644 index 000000000..db679feb6 --- /dev/null +++ b/pkg/common/oras/authprovider/azure/helper.go @@ -0,0 +1,53 @@ +/* +Copyright The Ratify Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package azure + +import ( + "context" + + "github.com/Azure/azure-sdk-for-go/sdk/containers/azcontainerregistry" + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/confidential" + "github.com/ratify-project/ratify/internal/logger" + "github.com/ratify-project/ratify/pkg/utils/azureauth" +) + +func DefaultAuthClientFactory(serverURL string, options *azcontainerregistry.AuthenticationClientOptions) (AuthClient, error) { + client, err := azcontainerregistry.NewAuthenticationClient(serverURL, options) + if err != nil { + return nil, err + } + return &AuthenticationClientWrapper{client: client}, nil +} + +func DefaultGetAADAccessToken(ctx context.Context, tenantID, clientID, resource string) (confidential.AuthResult, error) { + return azureauth.GetAADAccessToken(ctx, tenantID, clientID, resource) +} + +func DefaultReportMetrics(ctx context.Context, duration int64, artifactHostName string) { + logger.GetLogger(ctx, logOpt).Infof("Metrics Report: Duration=%dms, Host=%s", duration, artifactHostName) +} + +type AuthenticationClientWrapper struct { + client *azcontainerregistry.AuthenticationClient +} + +func (w *AuthenticationClientWrapper) ExchangeAADAccessTokenForACRRefreshToken(ctx context.Context, grantType, service string, options *azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenOptions) (azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse, error) { + return w.client.ExchangeAADAccessTokenForACRRefreshToken(ctx, azcontainerregistry.PostContentSchemaGrantType(grantType), service, options) +} + +type AuthClient interface { + ExchangeAADAccessTokenForACRRefreshToken(ctx context.Context, grantType, service string, options *azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenOptions) (azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse, error) +} From 598ff9e521369ac4a190e00f1d830af9e344009a Mon Sep 17 00:00:00 2001 From: Shahram Kalantari Date: Thu, 10 Oct 2024 11:51:58 +1000 Subject: [PATCH 11/20] chore: refactor azureworkloadidentity Signed-off-by: Shahram Kalantari --- .../azure/azureworkloadidentity.go | 76 +++- .../azure/azureworkloadidentity_test.go | 420 +++++++++++++----- 2 files changed, 382 insertions(+), 114 deletions(-) diff --git a/pkg/common/oras/authprovider/azure/azureworkloadidentity.go b/pkg/common/oras/authprovider/azure/azureworkloadidentity.go index 57033100b..827740d34 100644 --- a/pkg/common/oras/authprovider/azure/azureworkloadidentity.go +++ b/pkg/common/oras/authprovider/azure/azureworkloadidentity.go @@ -25,20 +25,69 @@ import ( re "github.com/ratify-project/ratify/errors" "github.com/ratify-project/ratify/internal/logger" provider "github.com/ratify-project/ratify/pkg/common/oras/authprovider" - "github.com/ratify-project/ratify/pkg/utils/azureauth" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/confidential" ) +// AuthClientFactory defines an interface for creating an authentication client. +type AuthClientFactory interface { + CreateAuthClient(serverURL string, options *azcontainerregistry.AuthenticationClientOptions) (AuthClient, error) +} + +// RegistryHostGetter defines an interface for getting the registry host. +type RegistryHostGetter interface { + GetRegistryHost(artifact string) (string, error) +} + +// AADAccessTokenGetter defines an interface for getting an AAD access token. +type AADAccessTokenGetter interface { + GetAADAccessToken(ctx context.Context, tenantID, clientID, resource string) (confidential.AuthResult, error) +} + +// MetricsReporter defines an interface for reporting metrics. +type MetricsReporter interface { + ReportMetrics(ctx context.Context, duration int64, artifactHostName string) +} + +// DefaultAuthClientFactoryImpl is the default implementation of AuthClientFactory. +type DefaultAuthClientFactoryImpl struct{} + +func (f *DefaultAuthClientFactoryImpl) CreateAuthClient(serverURL string, options *azcontainerregistry.AuthenticationClientOptions) (AuthClient, error) { + return DefaultAuthClientFactory(serverURL, options) +} + +// DefaultRegistryHostGetterImpl is the default implementation of RegistryHostGetter. +type DefaultRegistryHostGetterImpl struct{} + +func (g *DefaultRegistryHostGetterImpl) GetRegistryHost(artifact string) (string, error) { + // Implement the logic to get the registry host + return provider.GetRegistryHostName(artifact) + // return artifactHost, nil // Replace with actual logic +} + +// DefaultAADAccessTokenGetterImpl is the default implementation of AADAccessTokenGetter. +type DefaultAADAccessTokenGetterImpl struct{} + +func (g *DefaultAADAccessTokenGetterImpl) GetAADAccessToken(ctx context.Context, tenantID, clientID, resource string) (confidential.AuthResult, error) { + return DefaultGetAADAccessToken(ctx, tenantID, clientID, resource) +} + +// DefaultMetricsReporterImpl is the default implementation of MetricsReporter. +type DefaultMetricsReporterImpl struct{} + +func (r *DefaultMetricsReporterImpl) ReportMetrics(ctx context.Context, duration int64, artifactHostName string) { + DefaultReportMetrics(ctx, duration, artifactHostName) +} + type AzureWIProviderFactory struct{} //nolint:revive // ignore linter to have unique type name type WIAuthProvider struct { aadToken confidential.AuthResult tenantID string clientID string - authClientFactory func(serverURL string, options *azcontainerregistry.AuthenticationClientOptions) (AuthClient, error) - getRegistryHost func(artifact string) (string, error) - getAADAccessToken func(ctx context.Context, tenantID, clientID, resource string) (confidential.AuthResult, error) - reportMetrics func(ctx context.Context, duration int64, artifactHostName string) + authClientFactory AuthClientFactory + getRegistryHost RegistryHostGetter + getAADAccessToken AADAccessTokenGetter + reportMetrics MetricsReporter } type azureWIAuthProviderConf struct { @@ -81,7 +130,7 @@ func (s *AzureWIProviderFactory) Create(authProviderConfig provider.AuthProvider } // retrieve an AAD Access token - token, err := azureauth.GetAADAccessToken(context.Background(), tenant, clientID, AADResource) + token, err := DefaultGetAADAccessToken(context.Background(), tenant, clientID, AADResource) if err != nil { return nil, re.ErrorCodeAuthDenied.NewError(re.AuthProvider, "", re.AzureWorkloadIdentityLink, err, "", re.HideStackTrace) } @@ -90,9 +139,10 @@ func (s *AzureWIProviderFactory) Create(authProviderConfig provider.AuthProvider aadToken: token, tenantID: tenant, clientID: clientID, - authClientFactory: DefaultAuthClientFactory, - getAADAccessToken: DefaultGetAADAccessToken, - reportMetrics: DefaultReportMetrics, + authClientFactory: &DefaultAuthClientFactoryImpl{}, // Concrete implementation + getRegistryHost: &DefaultRegistryHostGetterImpl{}, // Concrete implementation + getAADAccessToken: &DefaultAADAccessTokenGetterImpl{}, // Concrete implementation + reportMetrics: &DefaultMetricsReporterImpl{}, }, nil } @@ -118,14 +168,14 @@ func (d *WIAuthProvider) Provide(ctx context.Context, artifact string) (provider } // parse the artifact reference string to extract the registry host name - artifactHostName, err := d.getRegistryHost(artifact) + artifactHostName, err := d.getRegistryHost.GetRegistryHost(artifact) if err != nil { return provider.AuthConfig{}, re.ErrorCodeHostNameInvalid.WithComponentType(re.AuthProvider) } // need to refresh AAD token if it's expired if time.Now().Add(time.Minute * 5).After(d.aadToken.ExpiresOn) { - newToken, err := d.getAADAccessToken(ctx, d.tenantID, d.clientID, AADResource) + newToken, err := d.getAADAccessToken.GetAADAccessToken(ctx, d.tenantID, d.clientID, AADResource) if err != nil { return provider.AuthConfig{}, re.ErrorCodeAuthDenied.NewError(re.AuthProvider, "", re.AzureWorkloadIdentityLink, nil, "could not refresh AAD token", re.HideStackTrace) } @@ -138,7 +188,7 @@ func (d *WIAuthProvider) Provide(ctx context.Context, artifact string) (provider // TODO: Consider adding authentication client options for multicloud scenarios var options *azcontainerregistry.AuthenticationClientOptions - client, err := d.authClientFactory(serverURL, options) + client, err := d.authClientFactory.CreateAuthClient(serverURL, options) if err != nil { return provider.AuthConfig{}, re.ErrorCodeAuthDenied.NewError(re.AuthProvider, "", re.AzureWorkloadIdentityLink, err, "failed to create authentication client for container registry", re.HideStackTrace) } @@ -158,7 +208,7 @@ func (d *WIAuthProvider) Provide(ctx context.Context, artifact string) (provider } rt := response.ACRRefreshToken - d.reportMetrics(ctx, time.Since(startTime).Milliseconds(), artifactHostName) + d.reportMetrics.ReportMetrics(ctx, time.Since(startTime).Milliseconds(), artifactHostName) refreshTokenExpiry := getACRExpiryIfEarlier(d.aadToken.ExpiresOn) authConfig := provider.AuthConfig{ diff --git a/pkg/common/oras/authprovider/azure/azureworkloadidentity_test.go b/pkg/common/oras/authprovider/azure/azureworkloadidentity_test.go index 17ce740c8..c8be41b6f 100644 --- a/pkg/common/oras/authprovider/azure/azureworkloadidentity_test.go +++ b/pkg/common/oras/authprovider/azure/azureworkloadidentity_test.go @@ -22,32 +22,346 @@ import ( "testing" "time" - "github.com/Azure/azure-sdk-for-go/sdk/containers/azcontainerregistry" - "github.com/AzureAD/microsoft-authentication-library-for-go/apps/confidential" ratifyerrors "github.com/ratify-project/ratify/errors" "github.com/ratify-project/ratify/pkg/common/oras/authprovider" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + + azcontainerregistry "github.com/Azure/azure-sdk-for-go/sdk/containers/azcontainerregistry" + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/confidential" ) -type MockAuthClient struct { +// MockAuthClientFactory for creating AuthClient +type MockAuthClientFactory struct { mock.Mock } -type MockAzureAuth struct { +func (m *MockAuthClientFactory) CreateAuthClient(serverURL string, options *azcontainerregistry.AuthenticationClientOptions) (AuthClient, error) { + args := m.Called(serverURL, options) + return args.Get(0).(AuthClient), args.Error(1) +} + +// MockRegistryHostGetter for retrieving registry host +type MockRegistryHostGetter struct { mock.Mock } -func (m *MockAzureAuth) GetAADAccessToken(ctx context.Context, tenantID, clientID, resource string) (confidential.AuthResult, error) { +func (m *MockRegistryHostGetter) GetRegistryHost(artifact string) (string, error) { + args := m.Called(artifact) + return args.String(0), args.Error(1) +} + +// MockAADAccessTokenGetter for retrieving AAD access token +type MockAADAccessTokenGetter struct { + mock.Mock +} + +func (m *MockAADAccessTokenGetter) GetAADAccessToken(ctx context.Context, tenantID, clientID, resource string) (confidential.AuthResult, error) { args := m.Called(ctx, tenantID, clientID, resource) return args.Get(0).(confidential.AuthResult), args.Error(1) } +// MockMetricsReporter for reporting metrics +type MockMetricsReporter struct { + mock.Mock +} + +func (m *MockMetricsReporter) ReportMetrics(ctx context.Context, duration int64, artifactHostName string) { + m.Called(ctx, duration, artifactHostName) +} + +// MockAuthClient for the Azure auth client +type MockAuthClient struct { + mock.Mock +} + func (m *MockAuthClient) ExchangeAADAccessTokenForACRRefreshToken(ctx context.Context, grantType, service string, options *azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenOptions) (azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse, error) { args := m.Called(ctx, grantType, service, options) return args.Get(0).(azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse), args.Error(1) } +// Test for successful Provide function +func TestWIAuthProvider_Provide_Success(t *testing.T) { + // Mock all dependencies + mockAuthClientFactory := new(MockAuthClientFactory) + mockRegistryHostGetter := new(MockRegistryHostGetter) + mockAADAccessTokenGetter := new(MockAADAccessTokenGetter) + mockMetricsReporter := new(MockMetricsReporter) + mockAuthClient := new(MockAuthClient) + + // Mock AAD token + initialToken := confidential.AuthResult{AccessToken: "initial_token", ExpiresOn: time.Now().Add(10 * time.Minute)} + refreshTokenString := "new_refresh_token" + refreshToken := azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse{ + ACRRefreshToken: azcontainerregistry.ACRRefreshToken{RefreshToken: &refreshTokenString}, + } + + // Set expectations for mocked functions + mockRegistryHostGetter.On("GetRegistryHost", "artifact_name").Return("example.azurecr.io", nil) + mockAuthClientFactory.On("CreateAuthClient", "https://example.azurecr.io", mock.Anything).Return(mockAuthClient, nil) + mockAuthClient.On("ExchangeAADAccessTokenForACRRefreshToken", mock.Anything, "access_token", "example.azurecr.io", mock.Anything).Return(refreshToken, nil) + mockAADAccessTokenGetter.On("GetAADAccessToken", mock.Anything, "tenantID", "clientID", mock.Anything).Return(initialToken, nil) + mockMetricsReporter.On("ReportMetrics", mock.Anything, mock.Anything, "example.azurecr.io").Return() + + // Create WIAuthProvider + provider := WIAuthProvider{ + aadToken: initialToken, + tenantID: "tenantID", + clientID: "clientID", + authClientFactory: mockAuthClientFactory, + getRegistryHost: mockRegistryHostGetter, + getAADAccessToken: mockAADAccessTokenGetter, + reportMetrics: mockMetricsReporter, + } + + // Call Provide method + ctx := context.Background() + authConfig, err := provider.Provide(ctx, "artifact_name") + + // Assertions + assert.NoError(t, err) + assert.Equal(t, "new_refresh_token", authConfig.Password) +} + +// Test for AAD token refresh logic +func TestWIAuthProvider_Provide_RefreshToken(t *testing.T) { + // Mock all dependencies + mockAuthClientFactory := new(MockAuthClientFactory) + mockRegistryHostGetter := new(MockRegistryHostGetter) + mockAADAccessTokenGetter := new(MockAADAccessTokenGetter) + mockMetricsReporter := new(MockMetricsReporter) + mockAuthClient := new(MockAuthClient) + + // Mock expired AAD token, and refreshed token + expiredToken := confidential.AuthResult{AccessToken: "expired_token", ExpiresOn: time.Now().Add(-10 * time.Minute)} + newToken := confidential.AuthResult{AccessToken: "new_token", ExpiresOn: time.Now().Add(10 * time.Minute)} + refreshTokenString := "refreshed_token" + refreshToken := azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse{ + ACRRefreshToken: azcontainerregistry.ACRRefreshToken{RefreshToken: &refreshTokenString}, + } + + // Set expectations for mocked functions + mockRegistryHostGetter.On("GetRegistryHost", "artifact_name").Return("example.azurecr.io", nil) + mockAuthClientFactory.On("CreateAuthClient", "https://example.azurecr.io", mock.Anything).Return(mockAuthClient, nil) + mockAuthClient.On("ExchangeAADAccessTokenForACRRefreshToken", mock.Anything, "access_token", "example.azurecr.io", mock.Anything).Return(refreshToken, nil) + mockAADAccessTokenGetter.On("GetAADAccessToken", mock.Anything, "tenantID", "clientID", mock.Anything).Return(newToken, nil) + mockMetricsReporter.On("ReportMetrics", mock.Anything, mock.Anything, "example.azurecr.io").Return() + + // Create WIAuthProvider with expired token + provider := WIAuthProvider{ + aadToken: expiredToken, + tenantID: "tenantID", + clientID: "clientID", + authClientFactory: mockAuthClientFactory, + getRegistryHost: mockRegistryHostGetter, + getAADAccessToken: mockAADAccessTokenGetter, + reportMetrics: mockMetricsReporter, + } + + // Call Provide method + ctx := context.Background() + authConfig, err := provider.Provide(ctx, "artifact_name") + + // Assertions + assert.NoError(t, err) + assert.Equal(t, "refreshed_token", authConfig.Password) +} + +// Test for failure when GetAADAccessToken fails +func TestWIAuthProvider_Provide_AADTokenFailure(t *testing.T) { + // Mock all dependencies + mockAuthClientFactory := new(MockAuthClientFactory) + mockRegistryHostGetter := new(MockRegistryHostGetter) + mockAADAccessTokenGetter := new(MockAADAccessTokenGetter) + mockMetricsReporter := new(MockMetricsReporter) + + // Mock expired AAD token, and failure to refresh + expiredToken := confidential.AuthResult{AccessToken: "expired_token", ExpiresOn: time.Now().Add(-10 * time.Minute)} + + // Set expectations for mocked functions + mockRegistryHostGetter.On("GetRegistryHost", "artifact_name").Return("example.azurecr.io", nil) + mockAADAccessTokenGetter.On("GetAADAccessToken", mock.Anything, "tenantID", "clientID", mock.Anything).Return(confidential.AuthResult{}, errors.New("token refresh failed")) + + // Create WIAuthProvider with expired token + provider := WIAuthProvider{ + aadToken: expiredToken, + tenantID: "tenantID", + clientID: "clientID", + authClientFactory: mockAuthClientFactory, + getRegistryHost: mockRegistryHostGetter, + getAADAccessToken: mockAADAccessTokenGetter, + reportMetrics: mockMetricsReporter, + } + + // Call Provide method + ctx := context.Background() + _, err := provider.Provide(ctx, "artifact_name") + + // Assertions + assert.Error(t, err) + assert.Contains(t, err.Error(), "could not refresh AAD token") +} + +// Test when tenant ID is missing from the environment +func TestAzureWIProviderFactory_Create_NoTenantID(t *testing.T) { + // Clear the tenant ID environment variable + t.Setenv("AZURE_TENANT_ID", "") + + // Initialize provider factory + factory := &AzureWIProviderFactory{} + + // Call Create with minimal configuration + _, err := factory.Create(map[string]interface{}{}) + + // Expect error related to missing tenant ID + assert.Error(t, err) + assert.Contains(t, err.Error(), "azure tenant id environment variable is empty") +} + +// Test when client ID is missing from the environment +func TestAzureWIProviderFactory_Create_NoClientID(t *testing.T) { + // Set tenant ID but leave client ID empty + t.Setenv("AZURE_TENANT_ID", "tenantID") + t.Setenv("AZURE_CLIENT_ID", "") + + // Initialize provider factory + factory := &AzureWIProviderFactory{} + + // Call Create with minimal configuration + _, err := factory.Create(map[string]interface{}{}) + + // Expect error related to missing client ID + assert.Error(t, err) + assert.Contains(t, err.Error(), "no client ID provided and AZURE_CLIENT_ID environment variable is empty") +} + +// Test for successful token refresh +func TestWIAuthProvider_Provide_TokenRefresh_Success(t *testing.T) { + // Mock dependencies + mockAuthClientFactory := new(MockAuthClientFactory) + mockRegistryHostGetter := new(MockRegistryHostGetter) + mockAADAccessTokenGetter := new(MockAADAccessTokenGetter) + mockMetricsReporter := new(MockMetricsReporter) + mockAuthClient := new(MockAuthClient) + + // Mock expired AAD token and refreshed token + expiredToken := confidential.AuthResult{AccessToken: "expired_token", ExpiresOn: time.Now().Add(-10 * time.Minute)} + refreshTokenString := "refreshed_token" + newToken := confidential.AuthResult{AccessToken: "new_token", ExpiresOn: time.Now().Add(10 * time.Minute)} + refreshToken := azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse{ + ACRRefreshToken: azcontainerregistry.ACRRefreshToken{RefreshToken: &refreshTokenString}, + } + + // Set expectations + mockRegistryHostGetter.On("GetRegistryHost", "artifact_name").Return("example.azurecr.io", nil) + mockAuthClientFactory.On("CreateAuthClient", "https://example.azurecr.io", mock.Anything).Return(mockAuthClient, nil) + mockAuthClient.On("ExchangeAADAccessTokenForACRRefreshToken", mock.Anything, "access_token", "example.azurecr.io", mock.Anything).Return(refreshToken, nil) + mockAADAccessTokenGetter.On("GetAADAccessToken", mock.Anything, "tenantID", "clientID", mock.Anything).Return(newToken, nil) + mockMetricsReporter.On("ReportMetrics", mock.Anything, mock.Anything, "example.azurecr.io").Return() + + // Create WIAuthProvider with expired token + provider := WIAuthProvider{ + aadToken: expiredToken, + tenantID: "tenantID", + clientID: "clientID", + authClientFactory: mockAuthClientFactory, + getRegistryHost: mockRegistryHostGetter, + getAADAccessToken: mockAADAccessTokenGetter, + reportMetrics: mockMetricsReporter, + } + + // Call Provide method + ctx := context.Background() + authConfig, err := provider.Provide(ctx, "artifact_name") + + // Assertions + assert.NoError(t, err) + assert.Equal(t, "refreshed_token", authConfig.Password) +} + +// Test when token refresh fails +func TestWIAuthProvider_Provide_TokenRefreshFailure(t *testing.T) { + // Mock dependencies + mockAuthClientFactory := new(MockAuthClientFactory) + mockRegistryHostGetter := new(MockRegistryHostGetter) + mockAADAccessTokenGetter := new(MockAADAccessTokenGetter) + mockMetricsReporter := new(MockMetricsReporter) + + // Mock expired AAD token and failure to refresh + expiredToken := confidential.AuthResult{AccessToken: "expired_token", ExpiresOn: time.Now().Add(-10 * time.Minute)} + + // Set expectations + mockRegistryHostGetter.On("GetRegistryHost", "artifact_name").Return("example.azurecr.io", nil) + mockAADAccessTokenGetter.On("GetAADAccessToken", mock.Anything, "tenantID", "clientID", mock.Anything).Return(confidential.AuthResult{}, errors.New("token refresh failed")) + + // Create WIAuthProvider with expired token + provider := WIAuthProvider{ + aadToken: expiredToken, + tenantID: "tenantID", + clientID: "clientID", + authClientFactory: mockAuthClientFactory, + getRegistryHost: mockRegistryHostGetter, + getAADAccessToken: mockAADAccessTokenGetter, + reportMetrics: mockMetricsReporter, + } + + // Call Provide method + ctx := context.Background() + _, err := provider.Provide(ctx, "artifact_name") + + // Assertions + assert.Error(t, err) + assert.Contains(t, err.Error(), "could not refresh AAD token") +} + +// Test for handling empty AccessToken +func TestWIAuthProvider_Enabled_NoAccessToken(t *testing.T) { + // Create a provider with no AccessToken + provider := WIAuthProvider{ + tenantID: "tenantID", + clientID: "clientID", + aadToken: confidential.AuthResult{AccessToken: ""}, + } + + // Assert that provider is not enabled + enabled := provider.Enabled(context.Background()) + assert.False(t, enabled) +} + +// Test for invalid hostname retrieval +func TestWIAuthProvider_Provide_InvalidHostName(t *testing.T) { + // Mock dependencies + mockAuthClientFactory := new(MockAuthClientFactory) + mockRegistryHostGetter := new(MockRegistryHostGetter) + mockAADAccessTokenGetter := new(MockAADAccessTokenGetter) + mockMetricsReporter := new(MockMetricsReporter) + + // Mock valid AAD token + validToken := confidential.AuthResult{AccessToken: "valid_token", ExpiresOn: time.Now().Add(10 * time.Minute)} + + // Set expectations for an invalid hostname + mockRegistryHostGetter.On("GetRegistryHost", "artifact_name").Return("", errors.New("invalid hostname")) + + // Create WIAuthProvider with valid token + provider := WIAuthProvider{ + aadToken: validToken, + tenantID: "tenantID", + clientID: "clientID", + authClientFactory: mockAuthClientFactory, + getRegistryHost: mockRegistryHostGetter, + getAADAccessToken: mockAADAccessTokenGetter, + reportMetrics: mockMetricsReporter, + } + + // Call Provide method + ctx := context.Background() + _, err := provider.Provide(ctx, "artifact_name") + + // Assertions + assert.Error(t, err) + assert.Contains(t, err.Error(), "HOST_NAME_INVALID") +} + // Verifies that Enabled checks if tenantID is empty or AAD token is empty func TestAzureWIEnabled_ExpectedResults(t *testing.T) { azAuthProvider := WIAuthProvider{ @@ -152,99 +466,3 @@ func TestAzureWIValidation_EnvironmentVariables_ExpectedResults(t *testing.T) { t.Fatalf("create auth provider should have failed: expected err %s, but got err %s", expectedErr, err) } } - -func TestWIProvide_Success(t *testing.T) { - mockClient := new(MockAuthClient) - expectedRefreshToken := "mocked_refresh_token" - mockClient.On("ExchangeAADAccessTokenForACRRefreshToken", mock.Anything, "access_token", "myregistry.azurecr.io", mock.Anything). - Return(azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse{ - ACRRefreshToken: azcontainerregistry.ACRRefreshToken{RefreshToken: &expectedRefreshToken}, - }, nil) - - provider := &WIAuthProvider{ - aadToken: confidential.AuthResult{ - AccessToken: "mockToken", - ExpiresOn: time.Now().Add(time.Hour), - }, - tenantID: "mockTenantID", - clientID: "mockClientID", - authClientFactory: func(_ string, _ *azcontainerregistry.AuthenticationClientOptions) (AuthClient, error) { - return mockClient, nil - }, - getRegistryHost: func(_ string) (string, error) { - return "myregistry.azurecr.io", nil - }, - getAADAccessToken: func(_ context.Context, _, _, _ string) (confidential.AuthResult, error) { - return confidential.AuthResult{ - AccessToken: "mockToken", - ExpiresOn: time.Now().Add(time.Hour), - }, nil - }, - reportMetrics: func(_ context.Context, _ int64, _ string) {}, - } - - authConfig, err := provider.Provide(context.Background(), "artifact") - - assert.NoError(t, err) - // Assert that GetAADAccessToken was not called - mockClient.AssertNotCalled(t, "GetAADAccessToken", mock.Anything, mock.Anything, mock.Anything, mock.Anything) - // Assert that the returned refresh token matches the expected one - assert.Equal(t, expectedRefreshToken, authConfig.Password) -} - -func TestWIProvide_RefreshAAD(t *testing.T) { - // Arrange - mockAzureAuth := new(MockAzureAuth) - mockClient := new(MockAuthClient) - - provider := &WIAuthProvider{ - aadToken: confidential.AuthResult{ - AccessToken: "mockToken", - ExpiresOn: time.Now(), - }, - tenantID: "mockTenantID", - clientID: "mockClientID", - authClientFactory: func(_ string, _ *azcontainerregistry.AuthenticationClientOptions) (AuthClient, error) { - return mockClient, nil - }, - getRegistryHost: func(_ string) (string, error) { - return "myregistry.azurecr.io", nil - }, - getAADAccessToken: mockAzureAuth.GetAADAccessToken, - reportMetrics: func(_ context.Context, _ int64, _ string) {}, - } - - mockAzureAuth.On("GetAADAccessToken", mock.Anything, mock.Anything, mock.Anything, mock.Anything). - Return(confidential.AuthResult{AccessToken: "newAccessToken", ExpiresOn: time.Now().Add(time.Hour)}, nil) - - mockClient.On("ExchangeAADAccessTokenForACRRefreshToken", mock.Anything, "access_token", "myregistry.azurecr.io", mock.Anything). - Return(azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse{ - ACRRefreshToken: azcontainerregistry.ACRRefreshToken{RefreshToken: new(string)}, - }, nil) - - ctx := context.TODO() - artifact := "testArtifact" - - // Act - _, err := provider.Provide(ctx, artifact) - - assert.NoError(t, err) - // Assert that GetAADAccessToken was not called - mockAzureAuth.AssertCalled(t, "GetAADAccessToken", mock.Anything, mock.Anything, mock.Anything, mock.Anything) -} - -func TestWIProvide_Failure_InvalidHostName(t *testing.T) { - provider := &WIAuthProvider{ - aadToken: confidential.AuthResult{ - AccessToken: "mockToken", - ExpiresOn: time.Now(), - }, - tenantID: "mockTenantID", - getRegistryHost: func(_ string) (string, error) { - return "", errors.New("invalid hostname") - }, - } - - _, err := provider.Provide(context.Background(), "artifact") - assert.Error(t, err) -} From 03688b033d674747ddd9509909e7f21bc525d418 Mon Sep 17 00:00:00 2001 From: Shahram Kalantari Date: Thu, 10 Oct 2024 14:22:46 +1000 Subject: [PATCH 12/20] chore: refactor azureidentity.go Signed-off-by: Shahram Kalantari --- .../oras/authprovider/azure/azureidentity.go | 29 ++- .../authprovider/azure/azureidentity_test.go | 207 +++++++++++------- 2 files changed, 143 insertions(+), 93 deletions(-) diff --git a/pkg/common/oras/authprovider/azure/azureidentity.go b/pkg/common/oras/authprovider/azure/azureidentity.go index f9f4adb67..8f89a0f9d 100644 --- a/pkg/common/oras/authprovider/azure/azureidentity.go +++ b/pkg/common/oras/authprovider/azure/azureidentity.go @@ -33,13 +33,26 @@ import ( ) type azureManagedIdentityProviderFactory struct{} + +// ManagedIdentityTokenGetter defines an interface for getting a managed identity token. +type ManagedIdentityTokenGetter interface { + GetManagedIdentityToken(ctx context.Context, clientID string) (azcore.AccessToken, error) +} + +// DefaultManagedIdentityTokenGetterImpl is the default implementation of AADAccessTokenGetter. +type DefaultManagedIdentityTokenGetterImpl struct{} + +func (g *DefaultManagedIdentityTokenGetterImpl) GetManagedIdentityToken(ctx context.Context, clientID string) (azcore.AccessToken, error) { + return getManagedIdentityToken(ctx, clientID) +} + type MIAuthProvider struct { identityToken azcore.AccessToken clientID string tenantID string - authClientFactory func(serverURL string, options *azcontainerregistry.AuthenticationClientOptions) (AuthClient, error) - getRegistryHost func(artifact string) (string, error) - getManagedIdentityToken func(ctx context.Context, clientID string) (azcore.AccessToken, error) + authClientFactory AuthClientFactory + getRegistryHost RegistryHostGetter + getManagedIdentityToken ManagedIdentityTokenGetter } type azureManagedIdentityAuthProviderConf struct { @@ -92,8 +105,8 @@ func (s *azureManagedIdentityProviderFactory) Create(authProviderConfig provider identityToken: token, clientID: client, tenantID: tenant, - authClientFactory: DefaultAuthClientFactory, - getManagedIdentityToken: getManagedIdentityToken, + authClientFactory: &DefaultAuthClientFactoryImpl{}, // Concrete implementation + getManagedIdentityToken: &DefaultManagedIdentityTokenGetterImpl{}, // Concrete implementation }, nil } @@ -123,14 +136,14 @@ func (d *MIAuthProvider) Provide(ctx context.Context, artifact string) (provider } // parse the artifact reference string to extract the registry host name - artifactHostName, err := d.getRegistryHost(artifact) + artifactHostName, err := d.getRegistryHost.GetRegistryHost(artifact) if err != nil { return provider.AuthConfig{}, re.ErrorCodeHostNameInvalid.WithComponentType(re.AuthProvider) } // need to refresh AAD token if it's expired if time.Now().Add(time.Minute * 5).After(d.identityToken.ExpiresOn) { - newToken, err := d.getManagedIdentityToken(ctx, d.clientID) + newToken, err := d.getManagedIdentityToken.GetManagedIdentityToken(ctx, d.clientID) if err != nil { return provider.AuthConfig{}, re.ErrorCodeAuthDenied.NewError(re.AuthProvider, "", re.AzureManagedIdentityLink, err, "could not refresh azure managed identity token", re.HideStackTrace) } @@ -143,7 +156,7 @@ func (d *MIAuthProvider) Provide(ctx context.Context, artifact string) (provider // TODO: Consider adding authentication client options for multicloud scenarios var options *azcontainerregistry.AuthenticationClientOptions - client, err := d.authClientFactory(serverURL, options) + client, err := d.authClientFactory.CreateAuthClient(serverURL, options) if err != nil { return provider.AuthConfig{}, re.ErrorCodeAuthDenied.NewError(re.AuthProvider, "", re.AzureWorkloadIdentityLink, err, "failed to create authentication client for container registry by azure managed identity token", re.HideStackTrace) } diff --git a/pkg/common/oras/authprovider/azure/azureidentity_test.go b/pkg/common/oras/authprovider/azure/azureidentity_test.go index 11fb48f5f..8cb08b419 100644 --- a/pkg/common/oras/authprovider/azure/azureidentity_test.go +++ b/pkg/common/oras/authprovider/azure/azureidentity_test.go @@ -23,18 +23,20 @@ import ( "time" "github.com/Azure/azure-sdk-for-go/sdk/azcore" - "github.com/Azure/azure-sdk-for-go/sdk/containers/azcontainerregistry" + azcontainerregistry "github.com/Azure/azure-sdk-for-go/sdk/containers/azcontainerregistry" ratifyerrors "github.com/ratify-project/ratify/errors" "github.com/ratify-project/ratify/pkg/common/oras/authprovider" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" ) -type MockGetManagedIdentityToken struct { +// Mock types for external dependencies +type MockManagedIdentityTokenGetter struct { mock.Mock } -func (m *MockGetManagedIdentityToken) GetManagedIdentityToken(ctx context.Context, clientID string) (azcore.AccessToken, error) { +// Mock ManagedIdentityTokenGetter.GetManagedIdentityToken +func (m *MockManagedIdentityTokenGetter) GetManagedIdentityToken(ctx context.Context, clientID string) (azcore.AccessToken, error) { args := m.Called(ctx, clientID) return args.Get(0).(azcore.AccessToken), args.Error(1) } @@ -103,104 +105,139 @@ func TestAzureMSIValidation_EnvironmentVariables_ExpectedResults(t *testing.T) { } } -func TestMIProvide_Success(t *testing.T) { - const registryHost = "myregistry.azurecr.io" - mockClient := new(MockAuthClient) - expectedRefreshToken := "mocked_refresh_token" - mockClient.On("ExchangeAADAccessTokenForACRRefreshToken", mock.Anything, "access_token", registryHost, mock.Anything). - Return(azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse{ - ACRRefreshToken: azcontainerregistry.ACRRefreshToken{RefreshToken: &expectedRefreshToken}, - }, nil) +// Test for invalid configuration when tenant ID is missing +func TestAzureManagedIdentityProviderFactory_Create_NoTenantID(t *testing.T) { + t.Setenv("AZURE_TENANT_ID", "") - provider := &MIAuthProvider{ - identityToken: azcore.AccessToken{ - Token: "mockToken", - ExpiresOn: time.Now().Add(time.Hour), - }, - tenantID: "mockTenantID", - clientID: "mockClientID", - authClientFactory: func(_ string, _ *azcontainerregistry.AuthenticationClientOptions) (AuthClient, error) { - return mockClient, nil - }, - getRegistryHost: func(_ string) (string, error) { - return registryHost, nil - }, - getManagedIdentityToken: func(_ context.Context, _ string) (azcore.AccessToken, error) { - return azcore.AccessToken{ - Token: "mockToken", - ExpiresOn: time.Now().Add(time.Hour), - }, nil - }, - } + // Initialize factory + factory := &azureManagedIdentityProviderFactory{} - authConfig, err := provider.Provide(context.Background(), "artifact") + // Attempt to create MIAuthProvider with empty configuration + _, err := factory.Create(map[string]interface{}{}) - assert.NoError(t, err) - // Assert that getManagedIdentityToken was not called - mockClient.AssertNotCalled(t, "getManagedIdentityToken", mock.Anything, mock.Anything) - // Assert that the returned refresh token matches the expected one - assert.Equal(t, expectedRefreshToken, authConfig.Password) + // Validate the error + assert.Error(t, err) + assert.Contains(t, err.Error(), "AZURE_TENANT_ID environment variable is empty") } -func TestMIProvide_RefreshAAD(t *testing.T) { - const registryHost = "myregistry.azurecr.io" - // Arrange - mockClient := new(MockAuthClient) +// Test for missing client ID +func TestAzureManagedIdentityProviderFactory_Create_NoClientID(t *testing.T) { + t.Setenv("AZURE_TENANT_ID", "tenantID") + t.Setenv("AZURE_CLIENT_ID", "") - // Create a mock function for getManagedIdentityToken - mockGetManagedIdentityToken := new(MockGetManagedIdentityToken) + // Initialize factory + factory := &azureManagedIdentityProviderFactory{} - provider := &MIAuthProvider{ - identityToken: azcore.AccessToken{ - Token: "mockToken", - ExpiresOn: time.Now(), // Expired token to force a refresh - }, - tenantID: "mockTenantID", - clientID: "mockClientID", - authClientFactory: func(_ string, _ *azcontainerregistry.AuthenticationClientOptions) (AuthClient, error) { - return mockClient, nil - }, - getRegistryHost: func(_ string) (string, error) { - return registryHost, nil - }, - getManagedIdentityToken: mockGetManagedIdentityToken.GetManagedIdentityToken, // Use the mock - } + // Attempt to create MIAuthProvider with empty client ID + _, err := factory.Create(map[string]interface{}{}) - mockClient.On("ExchangeAADAccessTokenForACRRefreshToken", mock.Anything, "access_token", registryHost, mock.Anything). - Return(azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse{ - ACRRefreshToken: azcontainerregistry.ACRRefreshToken{RefreshToken: new(string)}, - }, nil) + // Validate the error + assert.Error(t, err) + assert.Contains(t, err.Error(), "AZURE_CLIENT_ID environment variable is empty") +} - // Set up the expectation for the mocked method - mockGetManagedIdentityToken.On("GetManagedIdentityToken", mock.Anything, "mockClientID"). - Return(azcore.AccessToken{ - Token: "newMockToken", - ExpiresOn: time.Now().Add(time.Hour), - }, nil) +// Test successful token refresh +func TestMIAuthProvider_Provide_TokenRefreshSuccess(t *testing.T) { + // Mock dependencies + mockAuthClientFactory := new(MockAuthClientFactory) + mockRegistryHostGetter := new(MockRegistryHostGetter) + mockManagedIdentityTokenGetter := new(MockManagedIdentityTokenGetter) + mockAuthClient := new(MockAuthClient) + + // Define token values + expiredToken := azcore.AccessToken{Token: "expired_token", ExpiresOn: time.Now().Add(-10 * time.Minute)} + newTokenString := "refreshed_token" + newAADToken := azcore.AccessToken{Token: "new_token", ExpiresOn: time.Now().Add(10 * time.Minute)} + refreshToken := azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse{ + ACRRefreshToken: azcontainerregistry.ACRRefreshToken{RefreshToken: &newTokenString}, + } - ctx := context.TODO() - artifact := "testArtifact" + // Setup mock expectations + mockRegistryHostGetter.On("GetRegistryHost", "artifact_name").Return("example.azurecr.io", nil) + mockAuthClientFactory.On("CreateAuthClient", "https://example.azurecr.io", mock.Anything).Return(mockAuthClient, nil) + mockAuthClient.On("ExchangeAADAccessTokenForACRRefreshToken", mock.Anything, "access_token", "example.azurecr.io", mock.Anything).Return(refreshToken, nil) + mockManagedIdentityTokenGetter.On("GetManagedIdentityToken", mock.Anything, "clientID").Return(newAADToken, nil) + + // Initialize provider with expired token + provider := MIAuthProvider{ + identityToken: expiredToken, + clientID: "clientID", + tenantID: "tenantID", + authClientFactory: mockAuthClientFactory, + getRegistryHost: mockRegistryHostGetter, + getManagedIdentityToken: mockManagedIdentityTokenGetter, + } - // Act - _, err := provider.Provide(ctx, artifact) + // Call Provide method + ctx := context.Background() + authConfig, err := provider.Provide(ctx, "artifact_name") - // Assert + // Validate success and token refresh assert.NoError(t, err) - mockGetManagedIdentityToken.AssertCalled(t, "GetManagedIdentityToken", mock.Anything, "mockClientID") // Assert that getManagedIdentityToken was called + assert.Equal(t, "refreshed_token", authConfig.Password) } -func TestMIProvide_Failure_InvalidHostName(t *testing.T) { - provider := &MIAuthProvider{ - tenantID: "test_tenant", - clientID: "test_client", - identityToken: azcore.AccessToken{ - Token: "test_token", - }, - getRegistryHost: func(_ string) (string, error) { - return "", errors.New("invalid hostname") - }, +// Test failed token refresh +func TestMIAuthProvider_Provide_TokenRefreshFailure(t *testing.T) { + // Mock dependencies + mockAuthClientFactory := new(MockAuthClientFactory) + mockRegistryHostGetter := new(MockRegistryHostGetter) + mockManagedIdentityTokenGetter := new(MockManagedIdentityTokenGetter) + + // Define token values + expiredToken := azcore.AccessToken{Token: "expired_token", ExpiresOn: time.Now().Add(-10 * time.Minute)} + + // Setup mock expectations + mockRegistryHostGetter.On("GetRegistryHost", "artifact_name").Return("example.azurecr.io", nil) + mockManagedIdentityTokenGetter.On("GetManagedIdentityToken", mock.Anything, "clientID").Return(azcore.AccessToken{}, errors.New("token refresh failed")) + + // Initialize provider with expired token + provider := MIAuthProvider{ + identityToken: expiredToken, + clientID: "clientID", + tenantID: "tenantID", + authClientFactory: mockAuthClientFactory, + getRegistryHost: mockRegistryHostGetter, + getManagedIdentityToken: mockManagedIdentityTokenGetter, + } + + // Call Provide method + ctx := context.Background() + _, err := provider.Provide(ctx, "artifact_name") + + // Validate failure + assert.Error(t, err) + assert.Contains(t, err.Error(), "could not refresh azure managed identity token") +} + +// Test for invalid hostname retrieval +func TestMIAuthProvider_Provide_InvalidHostName(t *testing.T) { + // Mock dependencies + mockAuthClientFactory := new(MockAuthClientFactory) + mockRegistryHostGetter := new(MockRegistryHostGetter) + mockManagedIdentityTokenGetter := new(MockManagedIdentityTokenGetter) + + // Define valid token + validToken := azcore.AccessToken{Token: "valid_token", ExpiresOn: time.Now().Add(10 * time.Minute)} + + // Setup mock expectations for invalid hostname + mockRegistryHostGetter.On("GetRegistryHost", "artifact_name").Return("", errors.New("invalid hostname")) + + // Initialize provider with valid token + provider := MIAuthProvider{ + identityToken: validToken, + clientID: "clientID", + tenantID: "tenantID", + authClientFactory: mockAuthClientFactory, + getRegistryHost: mockRegistryHostGetter, + getManagedIdentityToken: mockManagedIdentityTokenGetter, } - _, err := provider.Provide(context.Background(), "artifact") + // Call Provide method + ctx := context.Background() + _, err := provider.Provide(ctx, "artifact_name") + + // Validate failure assert.Error(t, err) + assert.Contains(t, err.Error(), "HOST_NAME_INVALID") } From 2a0c8d86715cecb60d64457abfb3032ed9d2b814 Mon Sep 17 00:00:00 2001 From: Shahram Kalantari Date: Fri, 11 Oct 2024 14:30:41 +1000 Subject: [PATCH 13/20] chore: move common code to helper.go Signed-off-by: Shahram Kalantari --- .../oras/authprovider/azure/azureidentity.go | 34 ++--- .../authprovider/azure/azureidentity_test.go | 4 +- .../azure/azureworkloadidentity.go | 46 ++---- .../azure/azureworkloadidentity_test.go | 30 ---- pkg/common/oras/authprovider/azure/helper.go | 37 +++-- .../oras/authprovider/azure/helper_test.go | 136 ++++++++++++++++++ 6 files changed, 196 insertions(+), 91 deletions(-) create mode 100644 pkg/common/oras/authprovider/azure/helper_test.go diff --git a/pkg/common/oras/authprovider/azure/azureidentity.go b/pkg/common/oras/authprovider/azure/azureidentity.go index 8f89a0f9d..fd8466762 100644 --- a/pkg/common/oras/authprovider/azure/azureidentity.go +++ b/pkg/common/oras/authprovider/azure/azureidentity.go @@ -32,20 +32,34 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/containers/azcontainerregistry" ) -type azureManagedIdentityProviderFactory struct{} - // ManagedIdentityTokenGetter defines an interface for getting a managed identity token. type ManagedIdentityTokenGetter interface { GetManagedIdentityToken(ctx context.Context, clientID string) (azcore.AccessToken, error) } -// DefaultManagedIdentityTokenGetterImpl is the default implementation of AADAccessTokenGetter. +// DefaultManagedIdentityTokenGetterImpl is the default implementation of getManagedIdentityToken. type DefaultManagedIdentityTokenGetterImpl struct{} func (g *DefaultManagedIdentityTokenGetterImpl) GetManagedIdentityToken(ctx context.Context, clientID string) (azcore.AccessToken, error) { return getManagedIdentityToken(ctx, clientID) } +func getManagedIdentityToken(ctx context.Context, clientID string) (azcore.AccessToken, error) { + id := azidentity.ClientID(clientID) + opts := azidentity.ManagedIdentityCredentialOptions{ID: id} + cred, err := azidentity.NewManagedIdentityCredential(&opts) + if err != nil { + return azcore.AccessToken{}, err + } + scopes := []string{AADResource} + if cred != nil { + return cred.GetToken(ctx, policy.TokenRequestOptions{Scopes: scopes}) + } + return azcore.AccessToken{}, re.ErrorCodeConfigInvalid.WithComponentType(re.AuthProvider).WithDetail("config is nil pointer for GetServicePrincipalToken") +} + +type azureManagedIdentityProviderFactory struct{} + type MIAuthProvider struct { identityToken azcore.AccessToken clientID string @@ -185,17 +199,3 @@ func (d *MIAuthProvider) Provide(ctx context.Context, artifact string) (provider return authConfig, nil } - -func getManagedIdentityToken(ctx context.Context, clientID string) (azcore.AccessToken, error) { - id := azidentity.ClientID(clientID) - opts := azidentity.ManagedIdentityCredentialOptions{ID: id} - cred, err := azidentity.NewManagedIdentityCredential(&opts) - if err != nil { - return azcore.AccessToken{}, err - } - scopes := []string{AADResource} - if cred != nil { - return cred.GetToken(ctx, policy.TokenRequestOptions{Scopes: scopes}) - } - return azcore.AccessToken{}, re.ErrorCodeConfigInvalid.WithComponentType(re.AuthProvider).WithDetail("config is nil pointer for GetServicePrincipalToken") -} diff --git a/pkg/common/oras/authprovider/azure/azureidentity_test.go b/pkg/common/oras/authprovider/azure/azureidentity_test.go index 8cb08b419..a47854252 100644 --- a/pkg/common/oras/authprovider/azure/azureidentity_test.go +++ b/pkg/common/oras/authprovider/azure/azureidentity_test.go @@ -146,7 +146,7 @@ func TestMIAuthProvider_Provide_TokenRefreshSuccess(t *testing.T) { // Define token values expiredToken := azcore.AccessToken{Token: "expired_token", ExpiresOn: time.Now().Add(-10 * time.Minute)} - newTokenString := "refreshed_token" + newTokenString := "refreshed" newAADToken := azcore.AccessToken{Token: "new_token", ExpiresOn: time.Now().Add(10 * time.Minute)} refreshToken := azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse{ ACRRefreshToken: azcontainerregistry.ACRRefreshToken{RefreshToken: &newTokenString}, @@ -174,7 +174,7 @@ func TestMIAuthProvider_Provide_TokenRefreshSuccess(t *testing.T) { // Validate success and token refresh assert.NoError(t, err) - assert.Equal(t, "refreshed_token", authConfig.Password) + assert.Equal(t, "refreshed", authConfig.Password) } // Test failed token refresh diff --git a/pkg/common/oras/authprovider/azure/azureworkloadidentity.go b/pkg/common/oras/authprovider/azure/azureworkloadidentity.go index 827740d34..bad1ca315 100644 --- a/pkg/common/oras/authprovider/azure/azureworkloadidentity.go +++ b/pkg/common/oras/authprovider/azure/azureworkloadidentity.go @@ -25,46 +25,16 @@ import ( re "github.com/ratify-project/ratify/errors" "github.com/ratify-project/ratify/internal/logger" provider "github.com/ratify-project/ratify/pkg/common/oras/authprovider" + "github.com/ratify-project/ratify/pkg/utils/azureauth" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/confidential" ) -// AuthClientFactory defines an interface for creating an authentication client. -type AuthClientFactory interface { - CreateAuthClient(serverURL string, options *azcontainerregistry.AuthenticationClientOptions) (AuthClient, error) -} - -// RegistryHostGetter defines an interface for getting the registry host. -type RegistryHostGetter interface { - GetRegistryHost(artifact string) (string, error) -} - // AADAccessTokenGetter defines an interface for getting an AAD access token. type AADAccessTokenGetter interface { GetAADAccessToken(ctx context.Context, tenantID, clientID, resource string) (confidential.AuthResult, error) } -// MetricsReporter defines an interface for reporting metrics. -type MetricsReporter interface { - ReportMetrics(ctx context.Context, duration int64, artifactHostName string) -} - -// DefaultAuthClientFactoryImpl is the default implementation of AuthClientFactory. -type DefaultAuthClientFactoryImpl struct{} - -func (f *DefaultAuthClientFactoryImpl) CreateAuthClient(serverURL string, options *azcontainerregistry.AuthenticationClientOptions) (AuthClient, error) { - return DefaultAuthClientFactory(serverURL, options) -} - -// DefaultRegistryHostGetterImpl is the default implementation of RegistryHostGetter. -type DefaultRegistryHostGetterImpl struct{} - -func (g *DefaultRegistryHostGetterImpl) GetRegistryHost(artifact string) (string, error) { - // Implement the logic to get the registry host - return provider.GetRegistryHostName(artifact) - // return artifactHost, nil // Replace with actual logic -} - // DefaultAADAccessTokenGetterImpl is the default implementation of AADAccessTokenGetter. type DefaultAADAccessTokenGetterImpl struct{} @@ -72,6 +42,15 @@ func (g *DefaultAADAccessTokenGetterImpl) GetAADAccessToken(ctx context.Context, return DefaultGetAADAccessToken(ctx, tenantID, clientID, resource) } +func DefaultGetAADAccessToken(ctx context.Context, tenantID, clientID, resource string) (confidential.AuthResult, error) { + return azureauth.GetAADAccessToken(ctx, tenantID, clientID, resource) +} + +// MetricsReporter defines an interface for reporting metrics. +type MetricsReporter interface { + ReportMetrics(ctx context.Context, duration int64, artifactHostName string) +} + // DefaultMetricsReporterImpl is the default implementation of MetricsReporter. type DefaultMetricsReporterImpl struct{} @@ -79,7 +58,12 @@ func (r *DefaultMetricsReporterImpl) ReportMetrics(ctx context.Context, duration DefaultReportMetrics(ctx, duration, artifactHostName) } +func DefaultReportMetrics(ctx context.Context, duration int64, artifactHostName string) { + logger.GetLogger(ctx, logOpt).Infof("Metrics Report: Duration=%dms, Host=%s", duration, artifactHostName) +} + type AzureWIProviderFactory struct{} //nolint:revive // ignore linter to have unique type name + type WIAuthProvider struct { aadToken confidential.AuthResult tenantID string diff --git a/pkg/common/oras/authprovider/azure/azureworkloadidentity_test.go b/pkg/common/oras/authprovider/azure/azureworkloadidentity_test.go index c8be41b6f..7924600ba 100644 --- a/pkg/common/oras/authprovider/azure/azureworkloadidentity_test.go +++ b/pkg/common/oras/authprovider/azure/azureworkloadidentity_test.go @@ -31,26 +31,6 @@ import ( "github.com/AzureAD/microsoft-authentication-library-for-go/apps/confidential" ) -// MockAuthClientFactory for creating AuthClient -type MockAuthClientFactory struct { - mock.Mock -} - -func (m *MockAuthClientFactory) CreateAuthClient(serverURL string, options *azcontainerregistry.AuthenticationClientOptions) (AuthClient, error) { - args := m.Called(serverURL, options) - return args.Get(0).(AuthClient), args.Error(1) -} - -// MockRegistryHostGetter for retrieving registry host -type MockRegistryHostGetter struct { - mock.Mock -} - -func (m *MockRegistryHostGetter) GetRegistryHost(artifact string) (string, error) { - args := m.Called(artifact) - return args.String(0), args.Error(1) -} - // MockAADAccessTokenGetter for retrieving AAD access token type MockAADAccessTokenGetter struct { mock.Mock @@ -70,16 +50,6 @@ func (m *MockMetricsReporter) ReportMetrics(ctx context.Context, duration int64, m.Called(ctx, duration, artifactHostName) } -// MockAuthClient for the Azure auth client -type MockAuthClient struct { - mock.Mock -} - -func (m *MockAuthClient) ExchangeAADAccessTokenForACRRefreshToken(ctx context.Context, grantType, service string, options *azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenOptions) (azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse, error) { - args := m.Called(ctx, grantType, service, options) - return args.Get(0).(azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse), args.Error(1) -} - // Test for successful Provide function func TestWIAuthProvider_Provide_Success(t *testing.T) { // Mock all dependencies diff --git a/pkg/common/oras/authprovider/azure/helper.go b/pkg/common/oras/authprovider/azure/helper.go index db679feb6..2e9da285f 100644 --- a/pkg/common/oras/authprovider/azure/helper.go +++ b/pkg/common/oras/authprovider/azure/helper.go @@ -19,11 +19,21 @@ import ( "context" "github.com/Azure/azure-sdk-for-go/sdk/containers/azcontainerregistry" - "github.com/AzureAD/microsoft-authentication-library-for-go/apps/confidential" - "github.com/ratify-project/ratify/internal/logger" - "github.com/ratify-project/ratify/pkg/utils/azureauth" + provider "github.com/ratify-project/ratify/pkg/common/oras/authprovider" ) +// AuthClientFactory defines an interface for creating an authentication client. +type AuthClientFactory interface { + CreateAuthClient(serverURL string, options *azcontainerregistry.AuthenticationClientOptions) (AuthClient, error) +} + +// DefaultAuthClientFactoryImpl is the default implementation of AuthClientFactory. +type DefaultAuthClientFactoryImpl struct{} + +func (f *DefaultAuthClientFactoryImpl) CreateAuthClient(serverURL string, options *azcontainerregistry.AuthenticationClientOptions) (AuthClient, error) { + return DefaultAuthClientFactory(serverURL, options) +} + func DefaultAuthClientFactory(serverURL string, options *azcontainerregistry.AuthenticationClientOptions) (AuthClient, error) { client, err := azcontainerregistry.NewAuthenticationClient(serverURL, options) if err != nil { @@ -32,14 +42,6 @@ func DefaultAuthClientFactory(serverURL string, options *azcontainerregistry.Aut return &AuthenticationClientWrapper{client: client}, nil } -func DefaultGetAADAccessToken(ctx context.Context, tenantID, clientID, resource string) (confidential.AuthResult, error) { - return azureauth.GetAADAccessToken(ctx, tenantID, clientID, resource) -} - -func DefaultReportMetrics(ctx context.Context, duration int64, artifactHostName string) { - logger.GetLogger(ctx, logOpt).Infof("Metrics Report: Duration=%dms, Host=%s", duration, artifactHostName) -} - type AuthenticationClientWrapper struct { client *azcontainerregistry.AuthenticationClient } @@ -51,3 +53,16 @@ func (w *AuthenticationClientWrapper) ExchangeAADAccessTokenForACRRefreshToken(c type AuthClient interface { ExchangeAADAccessTokenForACRRefreshToken(ctx context.Context, grantType, service string, options *azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenOptions) (azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse, error) } + +// RegistryHostGetter defines an interface for getting the registry host. +type RegistryHostGetter interface { + GetRegistryHost(artifact string) (string, error) +} + +// DefaultRegistryHostGetterImpl is the default implementation of RegistryHostGetter. +type DefaultRegistryHostGetterImpl struct{} + +func (g *DefaultRegistryHostGetterImpl) GetRegistryHost(artifact string) (string, error) { + // Implement the logic to get the registry host + return provider.GetRegistryHostName(artifact) +} diff --git a/pkg/common/oras/authprovider/azure/helper_test.go b/pkg/common/oras/authprovider/azure/helper_test.go new file mode 100644 index 000000000..d7d9b330f --- /dev/null +++ b/pkg/common/oras/authprovider/azure/helper_test.go @@ -0,0 +1,136 @@ +/* +Copyright The Ratify Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package azure + +import ( + "context" + + "github.com/Azure/azure-sdk-for-go/sdk/containers/azcontainerregistry" + "github.com/stretchr/testify/mock" +) + +// MockAuthClient is a mock implementation of AuthClient. +type MockAuthClient struct { + mock.Mock +} + +// Mock method for ExchangeAADAccessTokenForACRRefreshToken +func (m *MockAuthClient) ExchangeAADAccessTokenForACRRefreshToken(ctx context.Context, grantType, service string, options *azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenOptions) (azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse, error) { + args := m.Called(ctx, grantType, service, options) + return args.Get(0).(azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse), args.Error(1) +} + +// MockAuthClientFactory is a mock implementation of AuthClientFactory. +type MockAuthClientFactory struct { + mock.Mock +} + +// Mock method for CreateAuthClient +func (m *MockAuthClientFactory) CreateAuthClient(serverURL string, options *azcontainerregistry.AuthenticationClientOptions) (AuthClient, error) { + args := m.Called(serverURL, options) + return args.Get(0).(AuthClient), args.Error(1) +} + +// MockRegistryHostGetter is a mock implementation of RegistryHostGetter. +type MockRegistryHostGetter struct { + mock.Mock +} + +// Mock method for GetRegistryHost +func (m *MockRegistryHostGetter) GetRegistryHost(artifact string) (string, error) { + args := m.Called(artifact) + return args.String(0), args.Error(1) +} + +// // TestDefaultAuthClientFactoryImpl tests the default factory implementation. +// func TestDefaultAuthClientFactoryImpl(t *testing.T) { +// mockFactory := new(MockAuthClientFactory) +// mockAuthClient := new(MockAuthClient) + +// serverURL := "https://example.azurecr.io" +// options := &azcontainerregistry.AuthenticationClientOptions{} + +// // Set up expectations +// mockFactory.On("CreateAuthClient", serverURL, options).Return(mockAuthClient, nil) + +// factory := &DefaultAuthClientFactoryImpl{} +// client, err := factory.CreateAuthClient(serverURL, options) + +// // Verify expectations +// mockFactory.AssertCalled(t, "CreateAuthClient", serverURL, options) +// assert.NoError(t, err) +// assert.NotNil(t, client) +// } + +// // TestDefaultAuthClientFactory_Error tests error handling during client creation. +// func TestDefaultAuthClientFactory_Error(t *testing.T) { +// mockFactory := new(MockAuthClientFactory) + +// serverURL := "https://example.azurecr.io" +// options := &azcontainerregistry.AuthenticationClientOptions{} +// expectedError := errors.New("failed to create client") + +// // Set up expectations +// mockFactory.On("CreateAuthClient", serverURL, options).Return(nil, expectedError) + +// factory := &DefaultAuthClientFactoryImpl{} +// client, err := factory.CreateAuthClient(serverURL, options) + +// // Verify expectations +// mockFactory.AssertCalled(t, "CreateAuthClient", serverURL, options) +// assert.Error(t, err) +// assert.Nil(t, client) +// assert.Equal(t, expectedError, err) +// } + +// // TestGetRegistryHost tests the GetRegistryHost function. +// func TestGetRegistryHost(t *testing.T) { +// mockGetter := new(MockRegistryHostGetter) + +// artifact := "test/artifact" +// expectedHost := "example.azurecr.io" + +// // Set up expectations +// mockGetter.On("GetRegistryHost", artifact).Return(expectedHost, nil) + +// getter := &DefaultRegistryHostGetterImpl{} +// host, err := getter.GetRegistryHost(artifact) + +// // Verify expectations +// mockGetter.AssertCalled(t, "GetRegistryHost", artifact) +// assert.NoError(t, err) +// assert.Equal(t, expectedHost, host) +// } + +// // TestGetRegistryHost_Error tests error handling in GetRegistryHost. +// func TestGetRegistryHost_Error(t *testing.T) { +// mockGetter := new(MockRegistryHostGetter) + +// artifact := "test/artifact" +// expectedError := errors.New("failed to get registry host") + +// // Set up expectations +// mockGetter.On("GetRegistryHost", artifact).Return("", expectedError) + +// getter := &DefaultRegistryHostGetterImpl{} +// host, err := getter.GetRegistryHost(artifact) + +// // Verify expectations +// mockGetter.AssertCalled(t, "GetRegistryHost", artifact) +// assert.Error(t, err) +// assert.Empty(t, host) +// assert.Equal(t, expectedError, err) +// } From 947f28ca499014f38566833cae08cd68228168be Mon Sep 17 00:00:00 2001 From: Shahram Kalantari Date: Mon, 14 Oct 2024 09:24:53 +1000 Subject: [PATCH 14/20] chore: unit tests for the helper.go file Signed-off-by: Shahram Kalantari --- .../oras/authprovider/azure/azureidentity.go | 10 +- .../authprovider/azure/azureidentity_test.go | 31 ++++- .../azure/azureworkloadidentity.go | 2 +- .../azure/azureworkloadidentity_test.go | 6 +- pkg/common/oras/authprovider/azure/helper.go | 13 ++- .../oras/authprovider/azure/helper_test.go | 108 ++++++------------ 6 files changed, 84 insertions(+), 86 deletions(-) diff --git a/pkg/common/oras/authprovider/azure/azureidentity.go b/pkg/common/oras/authprovider/azure/azureidentity.go index fd8466762..eea251325 100644 --- a/pkg/common/oras/authprovider/azure/azureidentity.go +++ b/pkg/common/oras/authprovider/azure/azureidentity.go @@ -41,13 +41,13 @@ type ManagedIdentityTokenGetter interface { type DefaultManagedIdentityTokenGetterImpl struct{} func (g *DefaultManagedIdentityTokenGetterImpl) GetManagedIdentityToken(ctx context.Context, clientID string) (azcore.AccessToken, error) { - return getManagedIdentityToken(ctx, clientID) + return getManagedIdentityToken(ctx, clientID, azidentity.NewManagedIdentityCredential) } -func getManagedIdentityToken(ctx context.Context, clientID string) (azcore.AccessToken, error) { +func getManagedIdentityToken(ctx context.Context, clientID string, newCredentialFunc func(opts *azidentity.ManagedIdentityCredentialOptions) (*azidentity.ManagedIdentityCredential, error)) (azcore.AccessToken, error) { id := azidentity.ClientID(clientID) opts := azidentity.ManagedIdentityCredentialOptions{ID: id} - cred, err := azidentity.NewManagedIdentityCredential(&opts) + cred, err := newCredentialFunc(&opts) if err != nil { return azcore.AccessToken{}, err } @@ -110,7 +110,7 @@ func (s *azureManagedIdentityProviderFactory) Create(authProviderConfig provider return nil, err } // retrieve an AAD Access token - token, err := getManagedIdentityToken(context.Background(), client) + token, err := getManagedIdentityToken(context.Background(), client, azidentity.NewManagedIdentityCredential) if err != nil { return nil, re.ErrorCodeAuthDenied.NewError(re.AuthProvider, "", re.AzureManagedIdentityLink, err, "", re.HideStackTrace) } @@ -177,7 +177,7 @@ func (d *MIAuthProvider) Provide(ctx context.Context, artifact string) (provider response, err := client.ExchangeAADAccessTokenForACRRefreshToken( ctx, - "access_token", + azcontainerregistry.PostContentSchemaGrantType("access_token"), artifactHostName, &azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenOptions{ AccessToken: &d.identityToken.Token, diff --git a/pkg/common/oras/authprovider/azure/azureidentity_test.go b/pkg/common/oras/authprovider/azure/azureidentity_test.go index a47854252..1ab6ff820 100644 --- a/pkg/common/oras/authprovider/azure/azureidentity_test.go +++ b/pkg/common/oras/authprovider/azure/azureidentity_test.go @@ -23,6 +23,7 @@ import ( "time" "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azidentity" azcontainerregistry "github.com/Azure/azure-sdk-for-go/sdk/containers/azcontainerregistry" ratifyerrors "github.com/ratify-project/ratify/errors" "github.com/ratify-project/ratify/pkg/common/oras/authprovider" @@ -155,7 +156,7 @@ func TestMIAuthProvider_Provide_TokenRefreshSuccess(t *testing.T) { // Setup mock expectations mockRegistryHostGetter.On("GetRegistryHost", "artifact_name").Return("example.azurecr.io", nil) mockAuthClientFactory.On("CreateAuthClient", "https://example.azurecr.io", mock.Anything).Return(mockAuthClient, nil) - mockAuthClient.On("ExchangeAADAccessTokenForACRRefreshToken", mock.Anything, "access_token", "example.azurecr.io", mock.Anything).Return(refreshToken, nil) + mockAuthClient.On("ExchangeAADAccessTokenForACRRefreshToken", mock.Anything, azcontainerregistry.PostContentSchemaGrantType("access_token"), "example.azurecr.io", mock.Anything).Return(refreshToken, nil) mockManagedIdentityTokenGetter.On("GetManagedIdentityToken", mock.Anything, "clientID").Return(newAADToken, nil) // Initialize provider with expired token @@ -241,3 +242,31 @@ func TestMIAuthProvider_Provide_InvalidHostName(t *testing.T) { assert.Error(t, err) assert.Contains(t, err.Error(), "HOST_NAME_INVALID") } + +// Unit tests +func TestGetManagedIdentityToken(t *testing.T) { + ctx := context.Background() + clientID := "test-client-id" + expectedToken := azcore.AccessToken{Token: "test-token", ExpiresOn: time.Now().Add(time.Hour)} + + mockGetter := new(MockManagedIdentityTokenGetter) + mockGetter.On("GetManagedIdentityToken", ctx, clientID).Return(expectedToken, nil) + + token, err := mockGetter.GetManagedIdentityToken(ctx, clientID) + assert.Nil(t, err) + assert.Equal(t, expectedToken, token) +} + +func TestGetManagedIdentityToken_Error(t *testing.T) { + ctx := context.Background() + clientID := "test-client-id" + + // Mock the newCredentialFunc to return an error + mockNewCredentialFunc := func(_ *azidentity.ManagedIdentityCredentialOptions) (*azidentity.ManagedIdentityCredential, error) { + return nil, assert.AnError + } + + token, err := getManagedIdentityToken(ctx, clientID, mockNewCredentialFunc) + assert.NotNil(t, err) + assert.Equal(t, azcore.AccessToken{}, token) +} diff --git a/pkg/common/oras/authprovider/azure/azureworkloadidentity.go b/pkg/common/oras/authprovider/azure/azureworkloadidentity.go index bad1ca315..88c67f7d7 100644 --- a/pkg/common/oras/authprovider/azure/azureworkloadidentity.go +++ b/pkg/common/oras/authprovider/azure/azureworkloadidentity.go @@ -180,7 +180,7 @@ func (d *WIAuthProvider) Provide(ctx context.Context, artifact string) (provider startTime := time.Now() response, err := client.ExchangeAADAccessTokenForACRRefreshToken( ctx, - "access_token", + azcontainerregistry.PostContentSchemaGrantType("access_token"), artifactHostName, &azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenOptions{ AccessToken: &d.aadToken.AccessToken, diff --git a/pkg/common/oras/authprovider/azure/azureworkloadidentity_test.go b/pkg/common/oras/authprovider/azure/azureworkloadidentity_test.go index 7924600ba..4a7ccdf7e 100644 --- a/pkg/common/oras/authprovider/azure/azureworkloadidentity_test.go +++ b/pkg/common/oras/authprovider/azure/azureworkloadidentity_test.go @@ -69,7 +69,7 @@ func TestWIAuthProvider_Provide_Success(t *testing.T) { // Set expectations for mocked functions mockRegistryHostGetter.On("GetRegistryHost", "artifact_name").Return("example.azurecr.io", nil) mockAuthClientFactory.On("CreateAuthClient", "https://example.azurecr.io", mock.Anything).Return(mockAuthClient, nil) - mockAuthClient.On("ExchangeAADAccessTokenForACRRefreshToken", mock.Anything, "access_token", "example.azurecr.io", mock.Anything).Return(refreshToken, nil) + mockAuthClient.On("ExchangeAADAccessTokenForACRRefreshToken", mock.Anything, azcontainerregistry.PostContentSchemaGrantType("access_token"), "example.azurecr.io", mock.Anything).Return(refreshToken, nil) mockAADAccessTokenGetter.On("GetAADAccessToken", mock.Anything, "tenantID", "clientID", mock.Anything).Return(initialToken, nil) mockMetricsReporter.On("ReportMetrics", mock.Anything, mock.Anything, "example.azurecr.io").Return() @@ -113,7 +113,7 @@ func TestWIAuthProvider_Provide_RefreshToken(t *testing.T) { // Set expectations for mocked functions mockRegistryHostGetter.On("GetRegistryHost", "artifact_name").Return("example.azurecr.io", nil) mockAuthClientFactory.On("CreateAuthClient", "https://example.azurecr.io", mock.Anything).Return(mockAuthClient, nil) - mockAuthClient.On("ExchangeAADAccessTokenForACRRefreshToken", mock.Anything, "access_token", "example.azurecr.io", mock.Anything).Return(refreshToken, nil) + mockAuthClient.On("ExchangeAADAccessTokenForACRRefreshToken", mock.Anything, azcontainerregistry.PostContentSchemaGrantType("access_token"), "example.azurecr.io", mock.Anything).Return(refreshToken, nil) mockAADAccessTokenGetter.On("GetAADAccessToken", mock.Anything, "tenantID", "clientID", mock.Anything).Return(newToken, nil) mockMetricsReporter.On("ReportMetrics", mock.Anything, mock.Anything, "example.azurecr.io").Return() @@ -225,7 +225,7 @@ func TestWIAuthProvider_Provide_TokenRefresh_Success(t *testing.T) { // Set expectations mockRegistryHostGetter.On("GetRegistryHost", "artifact_name").Return("example.azurecr.io", nil) mockAuthClientFactory.On("CreateAuthClient", "https://example.azurecr.io", mock.Anything).Return(mockAuthClient, nil) - mockAuthClient.On("ExchangeAADAccessTokenForACRRefreshToken", mock.Anything, "access_token", "example.azurecr.io", mock.Anything).Return(refreshToken, nil) + mockAuthClient.On("ExchangeAADAccessTokenForACRRefreshToken", mock.Anything, azcontainerregistry.PostContentSchemaGrantType("access_token"), "example.azurecr.io", mock.Anything).Return(refreshToken, nil) mockAADAccessTokenGetter.On("GetAADAccessToken", mock.Anything, "tenantID", "clientID", mock.Anything).Return(newToken, nil) mockMetricsReporter.On("ReportMetrics", mock.Anything, mock.Anything, "example.azurecr.io").Return() diff --git a/pkg/common/oras/authprovider/azure/helper.go b/pkg/common/oras/authprovider/azure/helper.go index 2e9da285f..3ca75a5a6 100644 --- a/pkg/common/oras/authprovider/azure/helper.go +++ b/pkg/common/oras/authprovider/azure/helper.go @@ -42,16 +42,21 @@ func DefaultAuthClientFactory(serverURL string, options *azcontainerregistry.Aut return &AuthenticationClientWrapper{client: client}, nil } +// Define the interface for azcontainerregistry.AuthenticationClient methods used +type AuthenticationClientInterface interface { + ExchangeAADAccessTokenForACRRefreshToken(ctx context.Context, grantType azcontainerregistry.PostContentSchemaGrantType, service string, options *azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenOptions) (azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse, error) +} + type AuthenticationClientWrapper struct { - client *azcontainerregistry.AuthenticationClient + client AuthenticationClientInterface } -func (w *AuthenticationClientWrapper) ExchangeAADAccessTokenForACRRefreshToken(ctx context.Context, grantType, service string, options *azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenOptions) (azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse, error) { - return w.client.ExchangeAADAccessTokenForACRRefreshToken(ctx, azcontainerregistry.PostContentSchemaGrantType(grantType), service, options) +func (w *AuthenticationClientWrapper) ExchangeAADAccessTokenForACRRefreshToken(ctx context.Context, grantType azcontainerregistry.PostContentSchemaGrantType, service string, options *azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenOptions) (azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse, error) { + return w.client.ExchangeAADAccessTokenForACRRefreshToken(ctx, grantType, service, options) } type AuthClient interface { - ExchangeAADAccessTokenForACRRefreshToken(ctx context.Context, grantType, service string, options *azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenOptions) (azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse, error) + ExchangeAADAccessTokenForACRRefreshToken(ctx context.Context, grantType azcontainerregistry.PostContentSchemaGrantType, service string, options *azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenOptions) (azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse, error) } // RegistryHostGetter defines an interface for getting the registry host. diff --git a/pkg/common/oras/authprovider/azure/helper_test.go b/pkg/common/oras/authprovider/azure/helper_test.go index d7d9b330f..365c681c3 100644 --- a/pkg/common/oras/authprovider/azure/helper_test.go +++ b/pkg/common/oras/authprovider/azure/helper_test.go @@ -17,8 +17,10 @@ package azure import ( "context" + "testing" "github.com/Azure/azure-sdk-for-go/sdk/containers/azcontainerregistry" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" ) @@ -28,7 +30,7 @@ type MockAuthClient struct { } // Mock method for ExchangeAADAccessTokenForACRRefreshToken -func (m *MockAuthClient) ExchangeAADAccessTokenForACRRefreshToken(ctx context.Context, grantType, service string, options *azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenOptions) (azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse, error) { +func (m *MockAuthClient) ExchangeAADAccessTokenForACRRefreshToken(ctx context.Context, grantType azcontainerregistry.PostContentSchemaGrantType, service string, options *azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenOptions) (azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse, error) { args := m.Called(ctx, grantType, service, options) return args.Get(0).(azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse), args.Error(1) } @@ -55,82 +57,44 @@ func (m *MockRegistryHostGetter) GetRegistryHost(artifact string) (string, error return args.String(0), args.Error(1) } -// // TestDefaultAuthClientFactoryImpl tests the default factory implementation. -// func TestDefaultAuthClientFactoryImpl(t *testing.T) { -// mockFactory := new(MockAuthClientFactory) -// mockAuthClient := new(MockAuthClient) +func TestDefaultAuthClientFactoryImpl_CreateAuthClient(t *testing.T) { + factory := &DefaultAuthClientFactoryImpl{} + serverURL := "https://example.com" + options := &azcontainerregistry.AuthenticationClientOptions{} -// serverURL := "https://example.azurecr.io" -// options := &azcontainerregistry.AuthenticationClientOptions{} - -// // Set up expectations -// mockFactory.On("CreateAuthClient", serverURL, options).Return(mockAuthClient, nil) - -// factory := &DefaultAuthClientFactoryImpl{} -// client, err := factory.CreateAuthClient(serverURL, options) - -// // Verify expectations -// mockFactory.AssertCalled(t, "CreateAuthClient", serverURL, options) -// assert.NoError(t, err) -// assert.NotNil(t, client) -// } - -// // TestDefaultAuthClientFactory_Error tests error handling during client creation. -// func TestDefaultAuthClientFactory_Error(t *testing.T) { -// mockFactory := new(MockAuthClientFactory) - -// serverURL := "https://example.azurecr.io" -// options := &azcontainerregistry.AuthenticationClientOptions{} -// expectedError := errors.New("failed to create client") - -// // Set up expectations -// mockFactory.On("CreateAuthClient", serverURL, options).Return(nil, expectedError) - -// factory := &DefaultAuthClientFactoryImpl{} -// client, err := factory.CreateAuthClient(serverURL, options) - -// // Verify expectations -// mockFactory.AssertCalled(t, "CreateAuthClient", serverURL, options) -// assert.Error(t, err) -// assert.Nil(t, client) -// assert.Equal(t, expectedError, err) -// } - -// // TestGetRegistryHost tests the GetRegistryHost function. -// func TestGetRegistryHost(t *testing.T) { -// mockGetter := new(MockRegistryHostGetter) - -// artifact := "test/artifact" -// expectedHost := "example.azurecr.io" - -// // Set up expectations -// mockGetter.On("GetRegistryHost", artifact).Return(expectedHost, nil) + client, err := factory.CreateAuthClient(serverURL, options) + assert.Nil(t, err) + assert.NotNil(t, client) +} -// getter := &DefaultRegistryHostGetterImpl{} -// host, err := getter.GetRegistryHost(artifact) +func TestDefaultAuthClientFactory(t *testing.T) { + serverURL := "https://example.com" + options := &azcontainerregistry.AuthenticationClientOptions{} -// // Verify expectations -// mockGetter.AssertCalled(t, "GetRegistryHost", artifact) -// assert.NoError(t, err) -// assert.Equal(t, expectedHost, host) -// } + client, err := DefaultAuthClientFactory(serverURL, options) + assert.Nil(t, err) + assert.NotNil(t, client) +} -// // TestGetRegistryHost_Error tests error handling in GetRegistryHost. -// func TestGetRegistryHost_Error(t *testing.T) { -// mockGetter := new(MockRegistryHostGetter) +func TestDefaultRegistryHostGetterImpl_GetRegistryHost(t *testing.T) { + getter := &DefaultRegistryHostGetterImpl{} + artifact := "example.azurecr.io/myArtifact" -// artifact := "test/artifact" -// expectedError := errors.New("failed to get registry host") + host, err := getter.GetRegistryHost(artifact) + assert.Nil(t, err) + assert.Equal(t, "example.azurecr.io", host) +} -// // Set up expectations -// mockGetter.On("GetRegistryHost", artifact).Return("", expectedError) +func TestAuthenticationClientWrapper_ExchangeAADAccessTokenForACRRefreshToken(t *testing.T) { + mockClient := new(MockAuthClient) + wrapper := &AuthenticationClientWrapper{client: mockClient} + ctx := context.Background() + grantType := azcontainerregistry.PostContentSchemaGrantType("grantType") + service := "service" + options := &azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenOptions{} -// getter := &DefaultRegistryHostGetterImpl{} -// host, err := getter.GetRegistryHost(artifact) + mockClient.On("ExchangeAADAccessTokenForACRRefreshToken", ctx, grantType, service, options).Return(azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse{}, nil) -// // Verify expectations -// mockGetter.AssertCalled(t, "GetRegistryHost", artifact) -// assert.Error(t, err) -// assert.Empty(t, host) -// assert.Equal(t, expectedError, err) -// } + _, err := wrapper.ExchangeAADAccessTokenForACRRefreshToken(ctx, grantType, service, options) + assert.Nil(t, err) +} From d50d3068ac792ad1d231021e779ad979824b0b4c Mon Sep 17 00:00:00 2001 From: Shahram Kalantari Date: Tue, 15 Oct 2024 14:29:39 +1000 Subject: [PATCH 15/20] chore: create a const from the repetitive string Signed-off-by: Shahram Kalantari --- pkg/common/oras/authprovider/azure/azureidentity.go | 2 +- pkg/common/oras/authprovider/azure/azureidentity_test.go | 2 +- pkg/common/oras/authprovider/azure/azureworkloadidentity.go | 2 +- .../oras/authprovider/azure/azureworkloadidentity_test.go | 6 +++--- pkg/common/oras/authprovider/azure/helper.go | 2 ++ 5 files changed, 8 insertions(+), 6 deletions(-) diff --git a/pkg/common/oras/authprovider/azure/azureidentity.go b/pkg/common/oras/authprovider/azure/azureidentity.go index eea251325..b6304b4b4 100644 --- a/pkg/common/oras/authprovider/azure/azureidentity.go +++ b/pkg/common/oras/authprovider/azure/azureidentity.go @@ -177,7 +177,7 @@ func (d *MIAuthProvider) Provide(ctx context.Context, artifact string) (provider response, err := client.ExchangeAADAccessTokenForACRRefreshToken( ctx, - azcontainerregistry.PostContentSchemaGrantType("access_token"), + azcontainerregistry.PostContentSchemaGrantType(GrantTypeAccessToken), artifactHostName, &azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenOptions{ AccessToken: &d.identityToken.Token, diff --git a/pkg/common/oras/authprovider/azure/azureidentity_test.go b/pkg/common/oras/authprovider/azure/azureidentity_test.go index 1ab6ff820..eab8e66e7 100644 --- a/pkg/common/oras/authprovider/azure/azureidentity_test.go +++ b/pkg/common/oras/authprovider/azure/azureidentity_test.go @@ -156,7 +156,7 @@ func TestMIAuthProvider_Provide_TokenRefreshSuccess(t *testing.T) { // Setup mock expectations mockRegistryHostGetter.On("GetRegistryHost", "artifact_name").Return("example.azurecr.io", nil) mockAuthClientFactory.On("CreateAuthClient", "https://example.azurecr.io", mock.Anything).Return(mockAuthClient, nil) - mockAuthClient.On("ExchangeAADAccessTokenForACRRefreshToken", mock.Anything, azcontainerregistry.PostContentSchemaGrantType("access_token"), "example.azurecr.io", mock.Anything).Return(refreshToken, nil) + mockAuthClient.On("ExchangeAADAccessTokenForACRRefreshToken", mock.Anything, azcontainerregistry.PostContentSchemaGrantType(GrantTypeAccessToken), "example.azurecr.io", mock.Anything).Return(refreshToken, nil) mockManagedIdentityTokenGetter.On("GetManagedIdentityToken", mock.Anything, "clientID").Return(newAADToken, nil) // Initialize provider with expired token diff --git a/pkg/common/oras/authprovider/azure/azureworkloadidentity.go b/pkg/common/oras/authprovider/azure/azureworkloadidentity.go index 88c67f7d7..096473971 100644 --- a/pkg/common/oras/authprovider/azure/azureworkloadidentity.go +++ b/pkg/common/oras/authprovider/azure/azureworkloadidentity.go @@ -180,7 +180,7 @@ func (d *WIAuthProvider) Provide(ctx context.Context, artifact string) (provider startTime := time.Now() response, err := client.ExchangeAADAccessTokenForACRRefreshToken( ctx, - azcontainerregistry.PostContentSchemaGrantType("access_token"), + azcontainerregistry.PostContentSchemaGrantType(GrantTypeAccessToken), artifactHostName, &azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenOptions{ AccessToken: &d.aadToken.AccessToken, diff --git a/pkg/common/oras/authprovider/azure/azureworkloadidentity_test.go b/pkg/common/oras/authprovider/azure/azureworkloadidentity_test.go index 4a7ccdf7e..54577cc9b 100644 --- a/pkg/common/oras/authprovider/azure/azureworkloadidentity_test.go +++ b/pkg/common/oras/authprovider/azure/azureworkloadidentity_test.go @@ -69,7 +69,7 @@ func TestWIAuthProvider_Provide_Success(t *testing.T) { // Set expectations for mocked functions mockRegistryHostGetter.On("GetRegistryHost", "artifact_name").Return("example.azurecr.io", nil) mockAuthClientFactory.On("CreateAuthClient", "https://example.azurecr.io", mock.Anything).Return(mockAuthClient, nil) - mockAuthClient.On("ExchangeAADAccessTokenForACRRefreshToken", mock.Anything, azcontainerregistry.PostContentSchemaGrantType("access_token"), "example.azurecr.io", mock.Anything).Return(refreshToken, nil) + mockAuthClient.On("ExchangeAADAccessTokenForACRRefreshToken", mock.Anything, azcontainerregistry.PostContentSchemaGrantType(GrantTypeAccessToken), "example.azurecr.io", mock.Anything).Return(refreshToken, nil) mockAADAccessTokenGetter.On("GetAADAccessToken", mock.Anything, "tenantID", "clientID", mock.Anything).Return(initialToken, nil) mockMetricsReporter.On("ReportMetrics", mock.Anything, mock.Anything, "example.azurecr.io").Return() @@ -113,7 +113,7 @@ func TestWIAuthProvider_Provide_RefreshToken(t *testing.T) { // Set expectations for mocked functions mockRegistryHostGetter.On("GetRegistryHost", "artifact_name").Return("example.azurecr.io", nil) mockAuthClientFactory.On("CreateAuthClient", "https://example.azurecr.io", mock.Anything).Return(mockAuthClient, nil) - mockAuthClient.On("ExchangeAADAccessTokenForACRRefreshToken", mock.Anything, azcontainerregistry.PostContentSchemaGrantType("access_token"), "example.azurecr.io", mock.Anything).Return(refreshToken, nil) + mockAuthClient.On("ExchangeAADAccessTokenForACRRefreshToken", mock.Anything, azcontainerregistry.PostContentSchemaGrantType(GrantTypeAccessToken), "example.azurecr.io", mock.Anything).Return(refreshToken, nil) mockAADAccessTokenGetter.On("GetAADAccessToken", mock.Anything, "tenantID", "clientID", mock.Anything).Return(newToken, nil) mockMetricsReporter.On("ReportMetrics", mock.Anything, mock.Anything, "example.azurecr.io").Return() @@ -225,7 +225,7 @@ func TestWIAuthProvider_Provide_TokenRefresh_Success(t *testing.T) { // Set expectations mockRegistryHostGetter.On("GetRegistryHost", "artifact_name").Return("example.azurecr.io", nil) mockAuthClientFactory.On("CreateAuthClient", "https://example.azurecr.io", mock.Anything).Return(mockAuthClient, nil) - mockAuthClient.On("ExchangeAADAccessTokenForACRRefreshToken", mock.Anything, azcontainerregistry.PostContentSchemaGrantType("access_token"), "example.azurecr.io", mock.Anything).Return(refreshToken, nil) + mockAuthClient.On("ExchangeAADAccessTokenForACRRefreshToken", mock.Anything, azcontainerregistry.PostContentSchemaGrantType(GrantTypeAccessToken), "example.azurecr.io", mock.Anything).Return(refreshToken, nil) mockAADAccessTokenGetter.On("GetAADAccessToken", mock.Anything, "tenantID", "clientID", mock.Anything).Return(newToken, nil) mockMetricsReporter.On("ReportMetrics", mock.Anything, mock.Anything, "example.azurecr.io").Return() diff --git a/pkg/common/oras/authprovider/azure/helper.go b/pkg/common/oras/authprovider/azure/helper.go index 3ca75a5a6..a98561641 100644 --- a/pkg/common/oras/authprovider/azure/helper.go +++ b/pkg/common/oras/authprovider/azure/helper.go @@ -22,6 +22,8 @@ import ( provider "github.com/ratify-project/ratify/pkg/common/oras/authprovider" ) +const GrantTypeAccessToken = "access_token" + // AuthClientFactory defines an interface for creating an authentication client. type AuthClientFactory interface { CreateAuthClient(serverURL string, options *azcontainerregistry.AuthenticationClientOptions) (AuthClient, error) From ac8e29e9966d661fcaadefc12f38097182661d70 Mon Sep 17 00:00:00 2001 From: Shahram Kalantari Date: Tue, 15 Oct 2024 15:00:02 +1000 Subject: [PATCH 16/20] chore: address comments Signed-off-by: Shahram Kalantari --- .../oras/authprovider/azure/azureidentity.go | 10 +++---- .../azure/azureworkloadidentity.go | 30 +++++++++---------- pkg/common/oras/authprovider/azure/helper.go | 16 +++++----- .../oras/authprovider/azure/helper_test.go | 6 ++-- 4 files changed, 31 insertions(+), 31 deletions(-) diff --git a/pkg/common/oras/authprovider/azure/azureidentity.go b/pkg/common/oras/authprovider/azure/azureidentity.go index b6304b4b4..0485ec1fa 100644 --- a/pkg/common/oras/authprovider/azure/azureidentity.go +++ b/pkg/common/oras/authprovider/azure/azureidentity.go @@ -37,10 +37,10 @@ type ManagedIdentityTokenGetter interface { GetManagedIdentityToken(ctx context.Context, clientID string) (azcore.AccessToken, error) } -// DefaultManagedIdentityTokenGetterImpl is the default implementation of getManagedIdentityToken. -type DefaultManagedIdentityTokenGetterImpl struct{} +// defaultManagedIdentityTokenGetterImpl is the default implementation of getManagedIdentityToken. +type defaultManagedIdentityTokenGetterImpl struct{} -func (g *DefaultManagedIdentityTokenGetterImpl) GetManagedIdentityToken(ctx context.Context, clientID string) (azcore.AccessToken, error) { +func (g *defaultManagedIdentityTokenGetterImpl) GetManagedIdentityToken(ctx context.Context, clientID string) (azcore.AccessToken, error) { return getManagedIdentityToken(ctx, clientID, azidentity.NewManagedIdentityCredential) } @@ -119,8 +119,8 @@ func (s *azureManagedIdentityProviderFactory) Create(authProviderConfig provider identityToken: token, clientID: client, tenantID: tenant, - authClientFactory: &DefaultAuthClientFactoryImpl{}, // Concrete implementation - getManagedIdentityToken: &DefaultManagedIdentityTokenGetterImpl{}, // Concrete implementation + authClientFactory: &defaultAuthClientFactoryImpl{}, // Concrete implementation + getManagedIdentityToken: &defaultManagedIdentityTokenGetterImpl{}, // Concrete implementation }, nil } diff --git a/pkg/common/oras/authprovider/azure/azureworkloadidentity.go b/pkg/common/oras/authprovider/azure/azureworkloadidentity.go index 096473971..404f91670 100644 --- a/pkg/common/oras/authprovider/azure/azureworkloadidentity.go +++ b/pkg/common/oras/authprovider/azure/azureworkloadidentity.go @@ -35,14 +35,14 @@ type AADAccessTokenGetter interface { GetAADAccessToken(ctx context.Context, tenantID, clientID, resource string) (confidential.AuthResult, error) } -// DefaultAADAccessTokenGetterImpl is the default implementation of AADAccessTokenGetter. -type DefaultAADAccessTokenGetterImpl struct{} +// defaultAADAccessTokenGetterImpl is the default implementation of AADAccessTokenGetter. +type defaultAADAccessTokenGetterImpl struct{} -func (g *DefaultAADAccessTokenGetterImpl) GetAADAccessToken(ctx context.Context, tenantID, clientID, resource string) (confidential.AuthResult, error) { - return DefaultGetAADAccessToken(ctx, tenantID, clientID, resource) +func (g *defaultAADAccessTokenGetterImpl) GetAADAccessToken(ctx context.Context, tenantID, clientID, resource string) (confidential.AuthResult, error) { + return defaultGetAADAccessToken(ctx, tenantID, clientID, resource) } -func DefaultGetAADAccessToken(ctx context.Context, tenantID, clientID, resource string) (confidential.AuthResult, error) { +func defaultGetAADAccessToken(ctx context.Context, tenantID, clientID, resource string) (confidential.AuthResult, error) { return azureauth.GetAADAccessToken(ctx, tenantID, clientID, resource) } @@ -51,14 +51,14 @@ type MetricsReporter interface { ReportMetrics(ctx context.Context, duration int64, artifactHostName string) } -// DefaultMetricsReporterImpl is the default implementation of MetricsReporter. -type DefaultMetricsReporterImpl struct{} +// defaultMetricsReporterImpl is the default implementation of MetricsReporter. +type defaultMetricsReporterImpl struct{} -func (r *DefaultMetricsReporterImpl) ReportMetrics(ctx context.Context, duration int64, artifactHostName string) { - DefaultReportMetrics(ctx, duration, artifactHostName) +func (r *defaultMetricsReporterImpl) ReportMetrics(ctx context.Context, duration int64, artifactHostName string) { + defaultReportMetrics(ctx, duration, artifactHostName) } -func DefaultReportMetrics(ctx context.Context, duration int64, artifactHostName string) { +func defaultReportMetrics(ctx context.Context, duration int64, artifactHostName string) { logger.GetLogger(ctx, logOpt).Infof("Metrics Report: Duration=%dms, Host=%s", duration, artifactHostName) } @@ -114,7 +114,7 @@ func (s *AzureWIProviderFactory) Create(authProviderConfig provider.AuthProvider } // retrieve an AAD Access token - token, err := DefaultGetAADAccessToken(context.Background(), tenant, clientID, AADResource) + token, err := defaultGetAADAccessToken(context.Background(), tenant, clientID, AADResource) if err != nil { return nil, re.ErrorCodeAuthDenied.NewError(re.AuthProvider, "", re.AzureWorkloadIdentityLink, err, "", re.HideStackTrace) } @@ -123,10 +123,10 @@ func (s *AzureWIProviderFactory) Create(authProviderConfig provider.AuthProvider aadToken: token, tenantID: tenant, clientID: clientID, - authClientFactory: &DefaultAuthClientFactoryImpl{}, // Concrete implementation - getRegistryHost: &DefaultRegistryHostGetterImpl{}, // Concrete implementation - getAADAccessToken: &DefaultAADAccessTokenGetterImpl{}, // Concrete implementation - reportMetrics: &DefaultMetricsReporterImpl{}, + authClientFactory: &defaultAuthClientFactoryImpl{}, // Concrete implementation + getRegistryHost: &defaultRegistryHostGetterImpl{}, // Concrete implementation + getAADAccessToken: &defaultAADAccessTokenGetterImpl{}, // Concrete implementation + reportMetrics: &defaultMetricsReporterImpl{}, }, nil } diff --git a/pkg/common/oras/authprovider/azure/helper.go b/pkg/common/oras/authprovider/azure/helper.go index a98561641..b7fc61eba 100644 --- a/pkg/common/oras/authprovider/azure/helper.go +++ b/pkg/common/oras/authprovider/azure/helper.go @@ -29,14 +29,14 @@ type AuthClientFactory interface { CreateAuthClient(serverURL string, options *azcontainerregistry.AuthenticationClientOptions) (AuthClient, error) } -// DefaultAuthClientFactoryImpl is the default implementation of AuthClientFactory. -type DefaultAuthClientFactoryImpl struct{} +// defaultAuthClientFactoryImpl is the default implementation of AuthClientFactory. +type defaultAuthClientFactoryImpl struct{} -func (f *DefaultAuthClientFactoryImpl) CreateAuthClient(serverURL string, options *azcontainerregistry.AuthenticationClientOptions) (AuthClient, error) { - return DefaultAuthClientFactory(serverURL, options) +func (f *defaultAuthClientFactoryImpl) CreateAuthClient(serverURL string, options *azcontainerregistry.AuthenticationClientOptions) (AuthClient, error) { + return defaultAuthClientFactory(serverURL, options) } -func DefaultAuthClientFactory(serverURL string, options *azcontainerregistry.AuthenticationClientOptions) (AuthClient, error) { +func defaultAuthClientFactory(serverURL string, options *azcontainerregistry.AuthenticationClientOptions) (AuthClient, error) { client, err := azcontainerregistry.NewAuthenticationClient(serverURL, options) if err != nil { return nil, err @@ -66,10 +66,10 @@ type RegistryHostGetter interface { GetRegistryHost(artifact string) (string, error) } -// DefaultRegistryHostGetterImpl is the default implementation of RegistryHostGetter. -type DefaultRegistryHostGetterImpl struct{} +// defaultRegistryHostGetterImpl is the default implementation of RegistryHostGetter. +type defaultRegistryHostGetterImpl struct{} -func (g *DefaultRegistryHostGetterImpl) GetRegistryHost(artifact string) (string, error) { +func (g *defaultRegistryHostGetterImpl) GetRegistryHost(artifact string) (string, error) { // Implement the logic to get the registry host return provider.GetRegistryHostName(artifact) } diff --git a/pkg/common/oras/authprovider/azure/helper_test.go b/pkg/common/oras/authprovider/azure/helper_test.go index 365c681c3..49c811f0f 100644 --- a/pkg/common/oras/authprovider/azure/helper_test.go +++ b/pkg/common/oras/authprovider/azure/helper_test.go @@ -58,7 +58,7 @@ func (m *MockRegistryHostGetter) GetRegistryHost(artifact string) (string, error } func TestDefaultAuthClientFactoryImpl_CreateAuthClient(t *testing.T) { - factory := &DefaultAuthClientFactoryImpl{} + factory := &defaultAuthClientFactoryImpl{} serverURL := "https://example.com" options := &azcontainerregistry.AuthenticationClientOptions{} @@ -71,13 +71,13 @@ func TestDefaultAuthClientFactory(t *testing.T) { serverURL := "https://example.com" options := &azcontainerregistry.AuthenticationClientOptions{} - client, err := DefaultAuthClientFactory(serverURL, options) + client, err := defaultAuthClientFactory(serverURL, options) assert.Nil(t, err) assert.NotNil(t, client) } func TestDefaultRegistryHostGetterImpl_GetRegistryHost(t *testing.T) { - getter := &DefaultRegistryHostGetterImpl{} + getter := &defaultRegistryHostGetterImpl{} artifact := "example.azurecr.io/myArtifact" host, err := getter.GetRegistryHost(artifact) From 234492768e142643ea43fd04e8526c710935dea6 Mon Sep 17 00:00:00 2001 From: Shahram Kalantari Date: Thu, 17 Oct 2024 18:58:15 +1000 Subject: [PATCH 17/20] chore: address comments Signed-off-by: Shahram Kalantari --- pkg/common/oras/authprovider/azure/azureidentity.go | 3 ++- pkg/common/oras/authprovider/azure/azureworkloadidentity.go | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/pkg/common/oras/authprovider/azure/azureidentity.go b/pkg/common/oras/authprovider/azure/azureidentity.go index 0485ec1fa..0df0a2f8d 100644 --- a/pkg/common/oras/authprovider/azure/azureidentity.go +++ b/pkg/common/oras/authprovider/azure/azureidentity.go @@ -172,7 +172,8 @@ func (d *MIAuthProvider) Provide(ctx context.Context, artifact string) (provider var options *azcontainerregistry.AuthenticationClientOptions client, err := d.authClientFactory.CreateAuthClient(serverURL, options) if err != nil { - return provider.AuthConfig{}, re.ErrorCodeAuthDenied.NewError(re.AuthProvider, "", re.AzureWorkloadIdentityLink, err, "failed to create authentication client for container registry by azure managed identity token", re.HideStackTrace) + // return provider.AuthConfig{}, re.ErrorCodeAuthDenied.NewError(re.AuthProvider, "", re.AzureWorkloadIdentityLink, err, "failed to create authentication client for container registry by azure managed identity token", re.HideStackTrace) + return provider.AuthConfig{}, re.ErrorCodeAuthDenied.WithError(err).WithDetail("failed to create authentication client for container registry by azure managed identity token") } response, err := client.ExchangeAADAccessTokenForACRRefreshToken( diff --git a/pkg/common/oras/authprovider/azure/azureworkloadidentity.go b/pkg/common/oras/authprovider/azure/azureworkloadidentity.go index 404f91670..a79c98cc8 100644 --- a/pkg/common/oras/authprovider/azure/azureworkloadidentity.go +++ b/pkg/common/oras/authprovider/azure/azureworkloadidentity.go @@ -174,7 +174,8 @@ func (d *WIAuthProvider) Provide(ctx context.Context, artifact string) (provider var options *azcontainerregistry.AuthenticationClientOptions client, err := d.authClientFactory.CreateAuthClient(serverURL, options) if err != nil { - return provider.AuthConfig{}, re.ErrorCodeAuthDenied.NewError(re.AuthProvider, "", re.AzureWorkloadIdentityLink, err, "failed to create authentication client for container registry", re.HideStackTrace) + // return provider.AuthConfig{}, re.ErrorCodeAuthDenied.NewError(re.AuthProvider, "", re.AzureWorkloadIdentityLink, err, "failed to create authentication client for container registry", re.HideStackTrace) + return provider.AuthConfig{}, re.ErrorCodeAuthDenied.WithError(err).WithDetail("failed to create authentication client for container registry by azure managed identity token") } startTime := time.Now() From 10ea3e29d5bcc07f17db3b3af4b06aad9c662dfb Mon Sep 17 00:00:00 2001 From: Shahram Kalantari Date: Tue, 22 Oct 2024 09:37:19 +1000 Subject: [PATCH 18/20] chore: address comments Signed-off-by: Shahram Kalantari --- .../oras/authprovider/azure/azureidentity.go | 4 +- .../authprovider/azure/azureidentity_test.go | 6 +- .../azure/azureworkloadidentity.go | 30 +++---- .../azure/azureworkloadidentity_test.go | 84 +++++++++---------- pkg/common/oras/authprovider/azure/helper.go | 11 ++- 5 files changed, 72 insertions(+), 63 deletions(-) diff --git a/pkg/common/oras/authprovider/azure/azureidentity.go b/pkg/common/oras/authprovider/azure/azureidentity.go index 0df0a2f8d..97a687ee0 100644 --- a/pkg/common/oras/authprovider/azure/azureidentity.go +++ b/pkg/common/oras/authprovider/azure/azureidentity.go @@ -65,7 +65,7 @@ type MIAuthProvider struct { clientID string tenantID string authClientFactory AuthClientFactory - getRegistryHost RegistryHostGetter + registryHostGetter RegistryHostGetter getManagedIdentityToken ManagedIdentityTokenGetter } @@ -150,7 +150,7 @@ func (d *MIAuthProvider) Provide(ctx context.Context, artifact string) (provider } // parse the artifact reference string to extract the registry host name - artifactHostName, err := d.getRegistryHost.GetRegistryHost(artifact) + artifactHostName, err := d.registryHostGetter.GetRegistryHost(artifact) if err != nil { return provider.AuthConfig{}, re.ErrorCodeHostNameInvalid.WithComponentType(re.AuthProvider) } diff --git a/pkg/common/oras/authprovider/azure/azureidentity_test.go b/pkg/common/oras/authprovider/azure/azureidentity_test.go index eab8e66e7..8d466d3d1 100644 --- a/pkg/common/oras/authprovider/azure/azureidentity_test.go +++ b/pkg/common/oras/authprovider/azure/azureidentity_test.go @@ -165,7 +165,7 @@ func TestMIAuthProvider_Provide_TokenRefreshSuccess(t *testing.T) { clientID: "clientID", tenantID: "tenantID", authClientFactory: mockAuthClientFactory, - getRegistryHost: mockRegistryHostGetter, + registryHostGetter: mockRegistryHostGetter, getManagedIdentityToken: mockManagedIdentityTokenGetter, } @@ -198,7 +198,7 @@ func TestMIAuthProvider_Provide_TokenRefreshFailure(t *testing.T) { clientID: "clientID", tenantID: "tenantID", authClientFactory: mockAuthClientFactory, - getRegistryHost: mockRegistryHostGetter, + registryHostGetter: mockRegistryHostGetter, getManagedIdentityToken: mockManagedIdentityTokenGetter, } @@ -230,7 +230,7 @@ func TestMIAuthProvider_Provide_InvalidHostName(t *testing.T) { clientID: "clientID", tenantID: "tenantID", authClientFactory: mockAuthClientFactory, - getRegistryHost: mockRegistryHostGetter, + registryHostGetter: mockRegistryHostGetter, getManagedIdentityToken: mockManagedIdentityTokenGetter, } diff --git a/pkg/common/oras/authprovider/azure/azureworkloadidentity.go b/pkg/common/oras/authprovider/azure/azureworkloadidentity.go index a79c98cc8..1ef2c490d 100644 --- a/pkg/common/oras/authprovider/azure/azureworkloadidentity.go +++ b/pkg/common/oras/authprovider/azure/azureworkloadidentity.go @@ -65,13 +65,13 @@ func defaultReportMetrics(ctx context.Context, duration int64, artifactHostName type AzureWIProviderFactory struct{} //nolint:revive // ignore linter to have unique type name type WIAuthProvider struct { - aadToken confidential.AuthResult - tenantID string - clientID string - authClientFactory AuthClientFactory - getRegistryHost RegistryHostGetter - getAADAccessToken AADAccessTokenGetter - reportMetrics MetricsReporter + aadToken confidential.AuthResult + tenantID string + clientID string + authClientFactory AuthClientFactory + registryHostGetter RegistryHostGetter + getAADAccessToken AADAccessTokenGetter + reportMetrics MetricsReporter } type azureWIAuthProviderConf struct { @@ -120,13 +120,13 @@ func (s *AzureWIProviderFactory) Create(authProviderConfig provider.AuthProvider } return &WIAuthProvider{ - aadToken: token, - tenantID: tenant, - clientID: clientID, - authClientFactory: &defaultAuthClientFactoryImpl{}, // Concrete implementation - getRegistryHost: &defaultRegistryHostGetterImpl{}, // Concrete implementation - getAADAccessToken: &defaultAADAccessTokenGetterImpl{}, // Concrete implementation - reportMetrics: &defaultMetricsReporterImpl{}, + aadToken: token, + tenantID: tenant, + clientID: clientID, + authClientFactory: &defaultAuthClientFactoryImpl{}, // Concrete implementation + registryHostGetter: &defaultRegistryHostGetterImpl{}, // Concrete implementation + getAADAccessToken: &defaultAADAccessTokenGetterImpl{}, // Concrete implementation + reportMetrics: &defaultMetricsReporterImpl{}, }, nil } @@ -152,7 +152,7 @@ func (d *WIAuthProvider) Provide(ctx context.Context, artifact string) (provider } // parse the artifact reference string to extract the registry host name - artifactHostName, err := d.getRegistryHost.GetRegistryHost(artifact) + artifactHostName, err := d.registryHostGetter.GetRegistryHost(artifact) if err != nil { return provider.AuthConfig{}, re.ErrorCodeHostNameInvalid.WithComponentType(re.AuthProvider) } diff --git a/pkg/common/oras/authprovider/azure/azureworkloadidentity_test.go b/pkg/common/oras/authprovider/azure/azureworkloadidentity_test.go index 54577cc9b..b2ffaa0cd 100644 --- a/pkg/common/oras/authprovider/azure/azureworkloadidentity_test.go +++ b/pkg/common/oras/authprovider/azure/azureworkloadidentity_test.go @@ -75,13 +75,13 @@ func TestWIAuthProvider_Provide_Success(t *testing.T) { // Create WIAuthProvider provider := WIAuthProvider{ - aadToken: initialToken, - tenantID: "tenantID", - clientID: "clientID", - authClientFactory: mockAuthClientFactory, - getRegistryHost: mockRegistryHostGetter, - getAADAccessToken: mockAADAccessTokenGetter, - reportMetrics: mockMetricsReporter, + aadToken: initialToken, + tenantID: "tenantID", + clientID: "clientID", + authClientFactory: mockAuthClientFactory, + registryHostGetter: mockRegistryHostGetter, + getAADAccessToken: mockAADAccessTokenGetter, + reportMetrics: mockMetricsReporter, } // Call Provide method @@ -119,13 +119,13 @@ func TestWIAuthProvider_Provide_RefreshToken(t *testing.T) { // Create WIAuthProvider with expired token provider := WIAuthProvider{ - aadToken: expiredToken, - tenantID: "tenantID", - clientID: "clientID", - authClientFactory: mockAuthClientFactory, - getRegistryHost: mockRegistryHostGetter, - getAADAccessToken: mockAADAccessTokenGetter, - reportMetrics: mockMetricsReporter, + aadToken: expiredToken, + tenantID: "tenantID", + clientID: "clientID", + authClientFactory: mockAuthClientFactory, + registryHostGetter: mockRegistryHostGetter, + getAADAccessToken: mockAADAccessTokenGetter, + reportMetrics: mockMetricsReporter, } // Call Provide method @@ -154,13 +154,13 @@ func TestWIAuthProvider_Provide_AADTokenFailure(t *testing.T) { // Create WIAuthProvider with expired token provider := WIAuthProvider{ - aadToken: expiredToken, - tenantID: "tenantID", - clientID: "clientID", - authClientFactory: mockAuthClientFactory, - getRegistryHost: mockRegistryHostGetter, - getAADAccessToken: mockAADAccessTokenGetter, - reportMetrics: mockMetricsReporter, + aadToken: expiredToken, + tenantID: "tenantID", + clientID: "clientID", + authClientFactory: mockAuthClientFactory, + registryHostGetter: mockRegistryHostGetter, + getAADAccessToken: mockAADAccessTokenGetter, + reportMetrics: mockMetricsReporter, } // Call Provide method @@ -231,13 +231,13 @@ func TestWIAuthProvider_Provide_TokenRefresh_Success(t *testing.T) { // Create WIAuthProvider with expired token provider := WIAuthProvider{ - aadToken: expiredToken, - tenantID: "tenantID", - clientID: "clientID", - authClientFactory: mockAuthClientFactory, - getRegistryHost: mockRegistryHostGetter, - getAADAccessToken: mockAADAccessTokenGetter, - reportMetrics: mockMetricsReporter, + aadToken: expiredToken, + tenantID: "tenantID", + clientID: "clientID", + authClientFactory: mockAuthClientFactory, + registryHostGetter: mockRegistryHostGetter, + getAADAccessToken: mockAADAccessTokenGetter, + reportMetrics: mockMetricsReporter, } // Call Provide method @@ -266,13 +266,13 @@ func TestWIAuthProvider_Provide_TokenRefreshFailure(t *testing.T) { // Create WIAuthProvider with expired token provider := WIAuthProvider{ - aadToken: expiredToken, - tenantID: "tenantID", - clientID: "clientID", - authClientFactory: mockAuthClientFactory, - getRegistryHost: mockRegistryHostGetter, - getAADAccessToken: mockAADAccessTokenGetter, - reportMetrics: mockMetricsReporter, + aadToken: expiredToken, + tenantID: "tenantID", + clientID: "clientID", + authClientFactory: mockAuthClientFactory, + registryHostGetter: mockRegistryHostGetter, + getAADAccessToken: mockAADAccessTokenGetter, + reportMetrics: mockMetricsReporter, } // Call Provide method @@ -314,13 +314,13 @@ func TestWIAuthProvider_Provide_InvalidHostName(t *testing.T) { // Create WIAuthProvider with valid token provider := WIAuthProvider{ - aadToken: validToken, - tenantID: "tenantID", - clientID: "clientID", - authClientFactory: mockAuthClientFactory, - getRegistryHost: mockRegistryHostGetter, - getAADAccessToken: mockAADAccessTokenGetter, - reportMetrics: mockMetricsReporter, + aadToken: validToken, + tenantID: "tenantID", + clientID: "clientID", + authClientFactory: mockAuthClientFactory, + registryHostGetter: mockRegistryHostGetter, + getAADAccessToken: mockAADAccessTokenGetter, + reportMetrics: mockMetricsReporter, } // Call Provide method diff --git a/pkg/common/oras/authprovider/azure/helper.go b/pkg/common/oras/authprovider/azure/helper.go index b7fc61eba..beafa4db0 100644 --- a/pkg/common/oras/authprovider/azure/helper.go +++ b/pkg/common/oras/authprovider/azure/helper.go @@ -32,10 +32,13 @@ type AuthClientFactory interface { // defaultAuthClientFactoryImpl is the default implementation of AuthClientFactory. type defaultAuthClientFactoryImpl struct{} +// creates an AuthClient using the default factory implementation. +// Return an AuthClient and an error if the client creation fails. func (f *defaultAuthClientFactoryImpl) CreateAuthClient(serverURL string, options *azcontainerregistry.AuthenticationClientOptions) (AuthClient, error) { return defaultAuthClientFactory(serverURL, options) } +// Define a helper function that creates an instance of AuthenticationClientWrapper. func defaultAuthClientFactory(serverURL string, options *azcontainerregistry.AuthenticationClientOptions) (AuthClient, error) { client, err := azcontainerregistry.NewAuthenticationClient(serverURL, options) if err != nil { @@ -49,14 +52,19 @@ type AuthenticationClientInterface interface { ExchangeAADAccessTokenForACRRefreshToken(ctx context.Context, grantType azcontainerregistry.PostContentSchemaGrantType, service string, options *azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenOptions) (azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse, error) } +// Define the wrapper for AuthenticationClientInterface type AuthenticationClientWrapper struct { client AuthenticationClientInterface } +// A wrapper method that calls the underlying AuthenticationClientInterface's method. +// Exchanges an AAD access token for an ACR refresh token. func (w *AuthenticationClientWrapper) ExchangeAADAccessTokenForACRRefreshToken(ctx context.Context, grantType azcontainerregistry.PostContentSchemaGrantType, service string, options *azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenOptions) (azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse, error) { return w.client.ExchangeAADAccessTokenForACRRefreshToken(ctx, grantType, service, options) } +// define the interface for authentication operations. +// It includes the method for exchanging an AAD access token for an ACR refresh token. type AuthClient interface { ExchangeAADAccessTokenForACRRefreshToken(ctx context.Context, grantType azcontainerregistry.PostContentSchemaGrantType, service string, options *azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenOptions) (azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse, error) } @@ -69,7 +77,8 @@ type RegistryHostGetter interface { // defaultRegistryHostGetterImpl is the default implementation of RegistryHostGetter. type defaultRegistryHostGetterImpl struct{} +// Retrieves the registry host name for a given artifact. +// It utilizes the provider's GetRegistryHostName function to perform the lookup. func (g *defaultRegistryHostGetterImpl) GetRegistryHost(artifact string) (string, error) { - // Implement the logic to get the registry host return provider.GetRegistryHostName(artifact) } From 94f899ded0d6ddb619f3e27871ff6e89f215c939 Mon Sep 17 00:00:00 2001 From: Shahram Kalantari Date: Tue, 22 Oct 2024 09:48:34 +1000 Subject: [PATCH 19/20] fix: ran go mod tidy Signed-off-by: Shahram Kalantari --- go.mod | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/go.mod b/go.mod index c580fa570..be3607472 100644 --- a/go.mod +++ b/go.mod @@ -239,7 +239,7 @@ require ( golang.org/x/crypto v0.28.0 golang.org/x/exp v0.0.0-20240112132812-db7319d0e0e3 // indirect golang.org/x/mod v0.20.0 // indirect - golang.org/x/net v0.28.0 // indirect + golang.org/x/net v0.29.0 // indirect golang.org/x/oauth2 v0.23.0 // indirect golang.org/x/sys v0.26.0 // indirect golang.org/x/term v0.25.0 // indirect From a70ebefbfc5eec0df97234b5efd3e96930dcdc87 Mon Sep 17 00:00:00 2001 From: Shahram Kalantari Date: Wed, 23 Oct 2024 13:45:57 +1000 Subject: [PATCH 20/20] chore: remove commented code Signed-off-by: Shahram Kalantari --- pkg/common/oras/authprovider/azure/azureidentity.go | 1 - pkg/common/oras/authprovider/azure/azureworkloadidentity.go | 1 - 2 files changed, 2 deletions(-) diff --git a/pkg/common/oras/authprovider/azure/azureidentity.go b/pkg/common/oras/authprovider/azure/azureidentity.go index 97a687ee0..d0369c4dc 100644 --- a/pkg/common/oras/authprovider/azure/azureidentity.go +++ b/pkg/common/oras/authprovider/azure/azureidentity.go @@ -172,7 +172,6 @@ func (d *MIAuthProvider) Provide(ctx context.Context, artifact string) (provider var options *azcontainerregistry.AuthenticationClientOptions client, err := d.authClientFactory.CreateAuthClient(serverURL, options) if err != nil { - // return provider.AuthConfig{}, re.ErrorCodeAuthDenied.NewError(re.AuthProvider, "", re.AzureWorkloadIdentityLink, err, "failed to create authentication client for container registry by azure managed identity token", re.HideStackTrace) return provider.AuthConfig{}, re.ErrorCodeAuthDenied.WithError(err).WithDetail("failed to create authentication client for container registry by azure managed identity token") } diff --git a/pkg/common/oras/authprovider/azure/azureworkloadidentity.go b/pkg/common/oras/authprovider/azure/azureworkloadidentity.go index 1ef2c490d..31f45127d 100644 --- a/pkg/common/oras/authprovider/azure/azureworkloadidentity.go +++ b/pkg/common/oras/authprovider/azure/azureworkloadidentity.go @@ -174,7 +174,6 @@ func (d *WIAuthProvider) Provide(ctx context.Context, artifact string) (provider var options *azcontainerregistry.AuthenticationClientOptions client, err := d.authClientFactory.CreateAuthClient(serverURL, options) if err != nil { - // return provider.AuthConfig{}, re.ErrorCodeAuthDenied.NewError(re.AuthProvider, "", re.AzureWorkloadIdentityLink, err, "failed to create authentication client for container registry", re.HideStackTrace) return provider.AuthConfig{}, re.ErrorCodeAuthDenied.WithError(err).WithDetail("failed to create authentication client for container registry by azure managed identity token") }