diff --git a/pkg/common/oras/authprovider/azure/azureidentity.go b/pkg/common/oras/authprovider/azure/azureidentity.go index fd8466762d..77aa6a272e 100644 --- a/pkg/common/oras/authprovider/azure/azureidentity.go +++ b/pkg/common/oras/authprovider/azure/azureidentity.go @@ -40,14 +40,32 @@ type ManagedIdentityTokenGetter interface { // DefaultManagedIdentityTokenGetterImpl is the default implementation of getManagedIdentityToken. type DefaultManagedIdentityTokenGetterImpl struct{} +// func (g *DefaultManagedIdentityTokenGetterImpl) GetManagedIdentityToken(ctx context.Context, clientID string) (azcore.AccessToken, error) { +// return getManagedIdentityToken(ctx, clientID) +// } + func (g *DefaultManagedIdentityTokenGetterImpl) GetManagedIdentityToken(ctx context.Context, clientID string) (azcore.AccessToken, error) { - return getManagedIdentityToken(ctx, clientID) + return getManagedIdentityToken(ctx, clientID, azidentity.NewManagedIdentityCredential) } -func getManagedIdentityToken(ctx context.Context, clientID string) (azcore.AccessToken, error) { +// func getManagedIdentityToken(ctx context.Context, clientID string) (azcore.AccessToken, error) { +// id := azidentity.ClientID(clientID) +// opts := azidentity.ManagedIdentityCredentialOptions{ID: id} +// cred, err := azidentity.NewManagedIdentityCredential(&opts) +// if err != nil { +// return azcore.AccessToken{}, err +// } +// scopes := []string{AADResource} +// if cred != nil { +// return cred.GetToken(ctx, policy.TokenRequestOptions{Scopes: scopes}) +// } +// return azcore.AccessToken{}, re.ErrorCodeConfigInvalid.WithComponentType(re.AuthProvider).WithDetail("config is nil pointer for GetServicePrincipalToken") +// } + +func getManagedIdentityToken(ctx context.Context, clientID string, newCredentialFunc func(opts *azidentity.ManagedIdentityCredentialOptions) (*azidentity.ManagedIdentityCredential, error)) (azcore.AccessToken, error) { id := azidentity.ClientID(clientID) opts := azidentity.ManagedIdentityCredentialOptions{ID: id} - cred, err := azidentity.NewManagedIdentityCredential(&opts) + cred, err := newCredentialFunc(&opts) if err != nil { return azcore.AccessToken{}, err } @@ -110,7 +128,7 @@ func (s *azureManagedIdentityProviderFactory) Create(authProviderConfig provider return nil, err } // retrieve an AAD Access token - token, err := getManagedIdentityToken(context.Background(), client) + token, err := getManagedIdentityToken(context.Background(), client, azidentity.NewManagedIdentityCredential) if err != nil { return nil, re.ErrorCodeAuthDenied.NewError(re.AuthProvider, "", re.AzureManagedIdentityLink, err, "", re.HideStackTrace) } @@ -177,7 +195,7 @@ func (d *MIAuthProvider) Provide(ctx context.Context, artifact string) (provider response, err := client.ExchangeAADAccessTokenForACRRefreshToken( ctx, - "access_token", + azcontainerregistry.PostContentSchemaGrantType("access_token"), artifactHostName, &azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenOptions{ AccessToken: &d.identityToken.Token, diff --git a/pkg/common/oras/authprovider/azure/azureidentity_test.go b/pkg/common/oras/authprovider/azure/azureidentity_test.go index a47854252a..1ab6ff820f 100644 --- a/pkg/common/oras/authprovider/azure/azureidentity_test.go +++ b/pkg/common/oras/authprovider/azure/azureidentity_test.go @@ -23,6 +23,7 @@ import ( "time" "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azidentity" azcontainerregistry "github.com/Azure/azure-sdk-for-go/sdk/containers/azcontainerregistry" ratifyerrors "github.com/ratify-project/ratify/errors" "github.com/ratify-project/ratify/pkg/common/oras/authprovider" @@ -155,7 +156,7 @@ func TestMIAuthProvider_Provide_TokenRefreshSuccess(t *testing.T) { // Setup mock expectations mockRegistryHostGetter.On("GetRegistryHost", "artifact_name").Return("example.azurecr.io", nil) mockAuthClientFactory.On("CreateAuthClient", "https://example.azurecr.io", mock.Anything).Return(mockAuthClient, nil) - mockAuthClient.On("ExchangeAADAccessTokenForACRRefreshToken", mock.Anything, "access_token", "example.azurecr.io", mock.Anything).Return(refreshToken, nil) + mockAuthClient.On("ExchangeAADAccessTokenForACRRefreshToken", mock.Anything, azcontainerregistry.PostContentSchemaGrantType("access_token"), "example.azurecr.io", mock.Anything).Return(refreshToken, nil) mockManagedIdentityTokenGetter.On("GetManagedIdentityToken", mock.Anything, "clientID").Return(newAADToken, nil) // Initialize provider with expired token @@ -241,3 +242,31 @@ func TestMIAuthProvider_Provide_InvalidHostName(t *testing.T) { assert.Error(t, err) assert.Contains(t, err.Error(), "HOST_NAME_INVALID") } + +// Unit tests +func TestGetManagedIdentityToken(t *testing.T) { + ctx := context.Background() + clientID := "test-client-id" + expectedToken := azcore.AccessToken{Token: "test-token", ExpiresOn: time.Now().Add(time.Hour)} + + mockGetter := new(MockManagedIdentityTokenGetter) + mockGetter.On("GetManagedIdentityToken", ctx, clientID).Return(expectedToken, nil) + + token, err := mockGetter.GetManagedIdentityToken(ctx, clientID) + assert.Nil(t, err) + assert.Equal(t, expectedToken, token) +} + +func TestGetManagedIdentityToken_Error(t *testing.T) { + ctx := context.Background() + clientID := "test-client-id" + + // Mock the newCredentialFunc to return an error + mockNewCredentialFunc := func(_ *azidentity.ManagedIdentityCredentialOptions) (*azidentity.ManagedIdentityCredential, error) { + return nil, assert.AnError + } + + token, err := getManagedIdentityToken(ctx, clientID, mockNewCredentialFunc) + assert.NotNil(t, err) + assert.Equal(t, azcore.AccessToken{}, token) +} diff --git a/pkg/common/oras/authprovider/azure/azureworkloadidentity.go b/pkg/common/oras/authprovider/azure/azureworkloadidentity.go index bad1ca3157..88c67f7d71 100644 --- a/pkg/common/oras/authprovider/azure/azureworkloadidentity.go +++ b/pkg/common/oras/authprovider/azure/azureworkloadidentity.go @@ -180,7 +180,7 @@ func (d *WIAuthProvider) Provide(ctx context.Context, artifact string) (provider startTime := time.Now() response, err := client.ExchangeAADAccessTokenForACRRefreshToken( ctx, - "access_token", + azcontainerregistry.PostContentSchemaGrantType("access_token"), artifactHostName, &azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenOptions{ AccessToken: &d.aadToken.AccessToken, diff --git a/pkg/common/oras/authprovider/azure/azureworkloadidentity_test.go b/pkg/common/oras/authprovider/azure/azureworkloadidentity_test.go index 7924600bab..4a7ccdf7ef 100644 --- a/pkg/common/oras/authprovider/azure/azureworkloadidentity_test.go +++ b/pkg/common/oras/authprovider/azure/azureworkloadidentity_test.go @@ -69,7 +69,7 @@ func TestWIAuthProvider_Provide_Success(t *testing.T) { // Set expectations for mocked functions mockRegistryHostGetter.On("GetRegistryHost", "artifact_name").Return("example.azurecr.io", nil) mockAuthClientFactory.On("CreateAuthClient", "https://example.azurecr.io", mock.Anything).Return(mockAuthClient, nil) - mockAuthClient.On("ExchangeAADAccessTokenForACRRefreshToken", mock.Anything, "access_token", "example.azurecr.io", mock.Anything).Return(refreshToken, nil) + mockAuthClient.On("ExchangeAADAccessTokenForACRRefreshToken", mock.Anything, azcontainerregistry.PostContentSchemaGrantType("access_token"), "example.azurecr.io", mock.Anything).Return(refreshToken, nil) mockAADAccessTokenGetter.On("GetAADAccessToken", mock.Anything, "tenantID", "clientID", mock.Anything).Return(initialToken, nil) mockMetricsReporter.On("ReportMetrics", mock.Anything, mock.Anything, "example.azurecr.io").Return() @@ -113,7 +113,7 @@ func TestWIAuthProvider_Provide_RefreshToken(t *testing.T) { // Set expectations for mocked functions mockRegistryHostGetter.On("GetRegistryHost", "artifact_name").Return("example.azurecr.io", nil) mockAuthClientFactory.On("CreateAuthClient", "https://example.azurecr.io", mock.Anything).Return(mockAuthClient, nil) - mockAuthClient.On("ExchangeAADAccessTokenForACRRefreshToken", mock.Anything, "access_token", "example.azurecr.io", mock.Anything).Return(refreshToken, nil) + mockAuthClient.On("ExchangeAADAccessTokenForACRRefreshToken", mock.Anything, azcontainerregistry.PostContentSchemaGrantType("access_token"), "example.azurecr.io", mock.Anything).Return(refreshToken, nil) mockAADAccessTokenGetter.On("GetAADAccessToken", mock.Anything, "tenantID", "clientID", mock.Anything).Return(newToken, nil) mockMetricsReporter.On("ReportMetrics", mock.Anything, mock.Anything, "example.azurecr.io").Return() @@ -225,7 +225,7 @@ func TestWIAuthProvider_Provide_TokenRefresh_Success(t *testing.T) { // Set expectations mockRegistryHostGetter.On("GetRegistryHost", "artifact_name").Return("example.azurecr.io", nil) mockAuthClientFactory.On("CreateAuthClient", "https://example.azurecr.io", mock.Anything).Return(mockAuthClient, nil) - mockAuthClient.On("ExchangeAADAccessTokenForACRRefreshToken", mock.Anything, "access_token", "example.azurecr.io", mock.Anything).Return(refreshToken, nil) + mockAuthClient.On("ExchangeAADAccessTokenForACRRefreshToken", mock.Anything, azcontainerregistry.PostContentSchemaGrantType("access_token"), "example.azurecr.io", mock.Anything).Return(refreshToken, nil) mockAADAccessTokenGetter.On("GetAADAccessToken", mock.Anything, "tenantID", "clientID", mock.Anything).Return(newToken, nil) mockMetricsReporter.On("ReportMetrics", mock.Anything, mock.Anything, "example.azurecr.io").Return() diff --git a/pkg/common/oras/authprovider/azure/helper.go b/pkg/common/oras/authprovider/azure/helper.go index 2e9da285fb..3ca75a5a63 100644 --- a/pkg/common/oras/authprovider/azure/helper.go +++ b/pkg/common/oras/authprovider/azure/helper.go @@ -42,16 +42,21 @@ func DefaultAuthClientFactory(serverURL string, options *azcontainerregistry.Aut return &AuthenticationClientWrapper{client: client}, nil } +// Define the interface for azcontainerregistry.AuthenticationClient methods used +type AuthenticationClientInterface interface { + ExchangeAADAccessTokenForACRRefreshToken(ctx context.Context, grantType azcontainerregistry.PostContentSchemaGrantType, service string, options *azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenOptions) (azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse, error) +} + type AuthenticationClientWrapper struct { - client *azcontainerregistry.AuthenticationClient + client AuthenticationClientInterface } -func (w *AuthenticationClientWrapper) ExchangeAADAccessTokenForACRRefreshToken(ctx context.Context, grantType, service string, options *azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenOptions) (azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse, error) { - return w.client.ExchangeAADAccessTokenForACRRefreshToken(ctx, azcontainerregistry.PostContentSchemaGrantType(grantType), service, options) +func (w *AuthenticationClientWrapper) ExchangeAADAccessTokenForACRRefreshToken(ctx context.Context, grantType azcontainerregistry.PostContentSchemaGrantType, service string, options *azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenOptions) (azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse, error) { + return w.client.ExchangeAADAccessTokenForACRRefreshToken(ctx, grantType, service, options) } type AuthClient interface { - ExchangeAADAccessTokenForACRRefreshToken(ctx context.Context, grantType, service string, options *azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenOptions) (azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse, error) + ExchangeAADAccessTokenForACRRefreshToken(ctx context.Context, grantType azcontainerregistry.PostContentSchemaGrantType, service string, options *azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenOptions) (azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse, error) } // RegistryHostGetter defines an interface for getting the registry host. diff --git a/pkg/common/oras/authprovider/azure/helper_test.go b/pkg/common/oras/authprovider/azure/helper_test.go index d7d9b330f2..365c681c3f 100644 --- a/pkg/common/oras/authprovider/azure/helper_test.go +++ b/pkg/common/oras/authprovider/azure/helper_test.go @@ -17,8 +17,10 @@ package azure import ( "context" + "testing" "github.com/Azure/azure-sdk-for-go/sdk/containers/azcontainerregistry" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" ) @@ -28,7 +30,7 @@ type MockAuthClient struct { } // Mock method for ExchangeAADAccessTokenForACRRefreshToken -func (m *MockAuthClient) ExchangeAADAccessTokenForACRRefreshToken(ctx context.Context, grantType, service string, options *azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenOptions) (azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse, error) { +func (m *MockAuthClient) ExchangeAADAccessTokenForACRRefreshToken(ctx context.Context, grantType azcontainerregistry.PostContentSchemaGrantType, service string, options *azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenOptions) (azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse, error) { args := m.Called(ctx, grantType, service, options) return args.Get(0).(azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse), args.Error(1) } @@ -55,82 +57,44 @@ func (m *MockRegistryHostGetter) GetRegistryHost(artifact string) (string, error return args.String(0), args.Error(1) } -// // TestDefaultAuthClientFactoryImpl tests the default factory implementation. -// func TestDefaultAuthClientFactoryImpl(t *testing.T) { -// mockFactory := new(MockAuthClientFactory) -// mockAuthClient := new(MockAuthClient) +func TestDefaultAuthClientFactoryImpl_CreateAuthClient(t *testing.T) { + factory := &DefaultAuthClientFactoryImpl{} + serverURL := "https://example.com" + options := &azcontainerregistry.AuthenticationClientOptions{} -// serverURL := "https://example.azurecr.io" -// options := &azcontainerregistry.AuthenticationClientOptions{} - -// // Set up expectations -// mockFactory.On("CreateAuthClient", serverURL, options).Return(mockAuthClient, nil) - -// factory := &DefaultAuthClientFactoryImpl{} -// client, err := factory.CreateAuthClient(serverURL, options) - -// // Verify expectations -// mockFactory.AssertCalled(t, "CreateAuthClient", serverURL, options) -// assert.NoError(t, err) -// assert.NotNil(t, client) -// } - -// // TestDefaultAuthClientFactory_Error tests error handling during client creation. -// func TestDefaultAuthClientFactory_Error(t *testing.T) { -// mockFactory := new(MockAuthClientFactory) - -// serverURL := "https://example.azurecr.io" -// options := &azcontainerregistry.AuthenticationClientOptions{} -// expectedError := errors.New("failed to create client") - -// // Set up expectations -// mockFactory.On("CreateAuthClient", serverURL, options).Return(nil, expectedError) - -// factory := &DefaultAuthClientFactoryImpl{} -// client, err := factory.CreateAuthClient(serverURL, options) - -// // Verify expectations -// mockFactory.AssertCalled(t, "CreateAuthClient", serverURL, options) -// assert.Error(t, err) -// assert.Nil(t, client) -// assert.Equal(t, expectedError, err) -// } - -// // TestGetRegistryHost tests the GetRegistryHost function. -// func TestGetRegistryHost(t *testing.T) { -// mockGetter := new(MockRegistryHostGetter) - -// artifact := "test/artifact" -// expectedHost := "example.azurecr.io" - -// // Set up expectations -// mockGetter.On("GetRegistryHost", artifact).Return(expectedHost, nil) + client, err := factory.CreateAuthClient(serverURL, options) + assert.Nil(t, err) + assert.NotNil(t, client) +} -// getter := &DefaultRegistryHostGetterImpl{} -// host, err := getter.GetRegistryHost(artifact) +func TestDefaultAuthClientFactory(t *testing.T) { + serverURL := "https://example.com" + options := &azcontainerregistry.AuthenticationClientOptions{} -// // Verify expectations -// mockGetter.AssertCalled(t, "GetRegistryHost", artifact) -// assert.NoError(t, err) -// assert.Equal(t, expectedHost, host) -// } + client, err := DefaultAuthClientFactory(serverURL, options) + assert.Nil(t, err) + assert.NotNil(t, client) +} -// // TestGetRegistryHost_Error tests error handling in GetRegistryHost. -// func TestGetRegistryHost_Error(t *testing.T) { -// mockGetter := new(MockRegistryHostGetter) +func TestDefaultRegistryHostGetterImpl_GetRegistryHost(t *testing.T) { + getter := &DefaultRegistryHostGetterImpl{} + artifact := "example.azurecr.io/myArtifact" -// artifact := "test/artifact" -// expectedError := errors.New("failed to get registry host") + host, err := getter.GetRegistryHost(artifact) + assert.Nil(t, err) + assert.Equal(t, "example.azurecr.io", host) +} -// // Set up expectations -// mockGetter.On("GetRegistryHost", artifact).Return("", expectedError) +func TestAuthenticationClientWrapper_ExchangeAADAccessTokenForACRRefreshToken(t *testing.T) { + mockClient := new(MockAuthClient) + wrapper := &AuthenticationClientWrapper{client: mockClient} + ctx := context.Background() + grantType := azcontainerregistry.PostContentSchemaGrantType("grantType") + service := "service" + options := &azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenOptions{} -// getter := &DefaultRegistryHostGetterImpl{} -// host, err := getter.GetRegistryHost(artifact) + mockClient.On("ExchangeAADAccessTokenForACRRefreshToken", ctx, grantType, service, options).Return(azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse{}, nil) -// // Verify expectations -// mockGetter.AssertCalled(t, "GetRegistryHost", artifact) -// assert.Error(t, err) -// assert.Empty(t, host) -// assert.Equal(t, expectedError, err) -// } + _, err := wrapper.ExchangeAADAccessTokenForACRRefreshToken(ctx, grantType, service, options) + assert.Nil(t, err) +}