Skip to content

Commit

Permalink
chore: move common code to helper.go
Browse files Browse the repository at this point in the history
Signed-off-by: Shahram Kalantari <[email protected]>
  • Loading branch information
shahramk64 committed Oct 11, 2024
1 parent 8c005b1 commit 13de70a
Show file tree
Hide file tree
Showing 6 changed files with 196 additions and 91 deletions.
34 changes: 17 additions & 17 deletions pkg/common/oras/authprovider/azure/azureidentity.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,20 +32,34 @@ import (
"github.com/Azure/azure-sdk-for-go/sdk/containers/azcontainerregistry"
)

type azureManagedIdentityProviderFactory struct{}

// ManagedIdentityTokenGetter defines an interface for getting a managed identity token.
type ManagedIdentityTokenGetter interface {
GetManagedIdentityToken(ctx context.Context, clientID string) (azcore.AccessToken, error)
}

// DefaultManagedIdentityTokenGetterImpl is the default implementation of AADAccessTokenGetter.
// DefaultManagedIdentityTokenGetterImpl is the default implementation of getManagedIdentityToken.
type DefaultManagedIdentityTokenGetterImpl struct{}

func (g *DefaultManagedIdentityTokenGetterImpl) GetManagedIdentityToken(ctx context.Context, clientID string) (azcore.AccessToken, error) {
return getManagedIdentityToken(ctx, clientID)

Check warning on line 44 in pkg/common/oras/authprovider/azure/azureidentity.go

View check run for this annotation

Codecov / codecov/patch

pkg/common/oras/authprovider/azure/azureidentity.go#L43-L44

Added lines #L43 - L44 were not covered by tests
}

func getManagedIdentityToken(ctx context.Context, clientID string) (azcore.AccessToken, error) {
id := azidentity.ClientID(clientID)
opts := azidentity.ManagedIdentityCredentialOptions{ID: id}
cred, err := azidentity.NewManagedIdentityCredential(&opts)
if err != nil {
return azcore.AccessToken{}, err

Check warning on line 52 in pkg/common/oras/authprovider/azure/azureidentity.go

View check run for this annotation

Codecov / codecov/patch

pkg/common/oras/authprovider/azure/azureidentity.go#L47-L52

Added lines #L47 - L52 were not covered by tests
}
scopes := []string{AADResource}
if cred != nil {
return cred.GetToken(ctx, policy.TokenRequestOptions{Scopes: scopes})

Check warning on line 56 in pkg/common/oras/authprovider/azure/azureidentity.go

View check run for this annotation

Codecov / codecov/patch

pkg/common/oras/authprovider/azure/azureidentity.go#L54-L56

Added lines #L54 - L56 were not covered by tests
}
return azcore.AccessToken{}, re.ErrorCodeConfigInvalid.WithComponentType(re.AuthProvider).WithDetail("config is nil pointer for GetServicePrincipalToken")

Check warning on line 58 in pkg/common/oras/authprovider/azure/azureidentity.go

View check run for this annotation

Codecov / codecov/patch

pkg/common/oras/authprovider/azure/azureidentity.go#L58

Added line #L58 was not covered by tests
}

type azureManagedIdentityProviderFactory struct{}

type MIAuthProvider struct {
identityToken azcore.AccessToken
clientID string
Expand Down Expand Up @@ -185,17 +199,3 @@ func (d *MIAuthProvider) Provide(ctx context.Context, artifact string) (provider

return authConfig, nil
}

func getManagedIdentityToken(ctx context.Context, clientID string) (azcore.AccessToken, error) {
id := azidentity.ClientID(clientID)
opts := azidentity.ManagedIdentityCredentialOptions{ID: id}
cred, err := azidentity.NewManagedIdentityCredential(&opts)
if err != nil {
return azcore.AccessToken{}, err
}
scopes := []string{AADResource}
if cred != nil {
return cred.GetToken(ctx, policy.TokenRequestOptions{Scopes: scopes})
}
return azcore.AccessToken{}, re.ErrorCodeConfigInvalid.WithComponentType(re.AuthProvider).WithDetail("config is nil pointer for GetServicePrincipalToken")
}
4 changes: 2 additions & 2 deletions pkg/common/oras/authprovider/azure/azureidentity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ func TestMIAuthProvider_Provide_TokenRefreshSuccess(t *testing.T) {

// Define token values
expiredToken := azcore.AccessToken{Token: "expired_token", ExpiresOn: time.Now().Add(-10 * time.Minute)}
newTokenString := "refreshed_token"
newTokenString := "refreshed"
newAADToken := azcore.AccessToken{Token: "new_token", ExpiresOn: time.Now().Add(10 * time.Minute)}
refreshToken := azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse{
ACRRefreshToken: azcontainerregistry.ACRRefreshToken{RefreshToken: &newTokenString},
Expand Down Expand Up @@ -174,7 +174,7 @@ func TestMIAuthProvider_Provide_TokenRefreshSuccess(t *testing.T) {

// Validate success and token refresh
assert.NoError(t, err)
assert.Equal(t, "refreshed_token", authConfig.Password)
assert.Equal(t, "refreshed", authConfig.Password)
}

// Test failed token refresh
Expand Down
46 changes: 15 additions & 31 deletions pkg/common/oras/authprovider/azure/azureworkloadidentity.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,61 +25,45 @@ import (
re "github.com/ratify-project/ratify/errors"
"github.com/ratify-project/ratify/internal/logger"
provider "github.com/ratify-project/ratify/pkg/common/oras/authprovider"
"github.com/ratify-project/ratify/pkg/utils/azureauth"

"github.com/AzureAD/microsoft-authentication-library-for-go/apps/confidential"
)

// AuthClientFactory defines an interface for creating an authentication client.
type AuthClientFactory interface {
CreateAuthClient(serverURL string, options *azcontainerregistry.AuthenticationClientOptions) (AuthClient, error)
}

// RegistryHostGetter defines an interface for getting the registry host.
type RegistryHostGetter interface {
GetRegistryHost(artifact string) (string, error)
}

// AADAccessTokenGetter defines an interface for getting an AAD access token.
type AADAccessTokenGetter interface {
GetAADAccessToken(ctx context.Context, tenantID, clientID, resource string) (confidential.AuthResult, error)
}

// MetricsReporter defines an interface for reporting metrics.
type MetricsReporter interface {
ReportMetrics(ctx context.Context, duration int64, artifactHostName string)
}

// DefaultAuthClientFactoryImpl is the default implementation of AuthClientFactory.
type DefaultAuthClientFactoryImpl struct{}

func (f *DefaultAuthClientFactoryImpl) CreateAuthClient(serverURL string, options *azcontainerregistry.AuthenticationClientOptions) (AuthClient, error) {
return DefaultAuthClientFactory(serverURL, options)
}

// DefaultRegistryHostGetterImpl is the default implementation of RegistryHostGetter.
type DefaultRegistryHostGetterImpl struct{}

func (g *DefaultRegistryHostGetterImpl) GetRegistryHost(artifact string) (string, error) {
// Implement the logic to get the registry host
return provider.GetRegistryHostName(artifact)
// return artifactHost, nil // Replace with actual logic
}

// DefaultAADAccessTokenGetterImpl is the default implementation of AADAccessTokenGetter.
type DefaultAADAccessTokenGetterImpl struct{}

func (g *DefaultAADAccessTokenGetterImpl) GetAADAccessToken(ctx context.Context, tenantID, clientID, resource string) (confidential.AuthResult, error) {
return DefaultGetAADAccessToken(ctx, tenantID, clientID, resource)

Check warning on line 42 in pkg/common/oras/authprovider/azure/azureworkloadidentity.go

View check run for this annotation

Codecov / codecov/patch

pkg/common/oras/authprovider/azure/azureworkloadidentity.go#L41-L42

Added lines #L41 - L42 were not covered by tests
}

func DefaultGetAADAccessToken(ctx context.Context, tenantID, clientID, resource string) (confidential.AuthResult, error) {
return azureauth.GetAADAccessToken(ctx, tenantID, clientID, resource)
}

// MetricsReporter defines an interface for reporting metrics.
type MetricsReporter interface {
ReportMetrics(ctx context.Context, duration int64, artifactHostName string)
}

// DefaultMetricsReporterImpl is the default implementation of MetricsReporter.
type DefaultMetricsReporterImpl struct{}

func (r *DefaultMetricsReporterImpl) ReportMetrics(ctx context.Context, duration int64, artifactHostName string) {
DefaultReportMetrics(ctx, duration, artifactHostName)

Check warning on line 58 in pkg/common/oras/authprovider/azure/azureworkloadidentity.go

View check run for this annotation

Codecov / codecov/patch

pkg/common/oras/authprovider/azure/azureworkloadidentity.go#L57-L58

Added lines #L57 - L58 were not covered by tests
}

func DefaultReportMetrics(ctx context.Context, duration int64, artifactHostName string) {
logger.GetLogger(ctx, logOpt).Infof("Metrics Report: Duration=%dms, Host=%s", duration, artifactHostName)

Check warning on line 62 in pkg/common/oras/authprovider/azure/azureworkloadidentity.go

View check run for this annotation

Codecov / codecov/patch

pkg/common/oras/authprovider/azure/azureworkloadidentity.go#L61-L62

Added lines #L61 - L62 were not covered by tests
}

type AzureWIProviderFactory struct{} //nolint:revive // ignore linter to have unique type name

type WIAuthProvider struct {
aadToken confidential.AuthResult
tenantID string
Expand Down
30 changes: 0 additions & 30 deletions pkg/common/oras/authprovider/azure/azureworkloadidentity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,26 +31,6 @@ import (
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/confidential"
)

// MockAuthClientFactory for creating AuthClient
type MockAuthClientFactory struct {
mock.Mock
}

func (m *MockAuthClientFactory) CreateAuthClient(serverURL string, options *azcontainerregistry.AuthenticationClientOptions) (AuthClient, error) {
args := m.Called(serverURL, options)
return args.Get(0).(AuthClient), args.Error(1)
}

// MockRegistryHostGetter for retrieving registry host
type MockRegistryHostGetter struct {
mock.Mock
}

func (m *MockRegistryHostGetter) GetRegistryHost(artifact string) (string, error) {
args := m.Called(artifact)
return args.String(0), args.Error(1)
}

// MockAADAccessTokenGetter for retrieving AAD access token
type MockAADAccessTokenGetter struct {
mock.Mock
Expand All @@ -70,16 +50,6 @@ func (m *MockMetricsReporter) ReportMetrics(ctx context.Context, duration int64,
m.Called(ctx, duration, artifactHostName)
}

// MockAuthClient for the Azure auth client
type MockAuthClient struct {
mock.Mock
}

func (m *MockAuthClient) ExchangeAADAccessTokenForACRRefreshToken(ctx context.Context, grantType, service string, options *azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenOptions) (azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse, error) {
args := m.Called(ctx, grantType, service, options)
return args.Get(0).(azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse), args.Error(1)
}

// Test for successful Provide function
func TestWIAuthProvider_Provide_Success(t *testing.T) {
// Mock all dependencies
Expand Down
37 changes: 26 additions & 11 deletions pkg/common/oras/authprovider/azure/helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,21 @@ import (
"context"

"github.com/Azure/azure-sdk-for-go/sdk/containers/azcontainerregistry"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/confidential"
"github.com/ratify-project/ratify/internal/logger"
"github.com/ratify-project/ratify/pkg/utils/azureauth"
provider "github.com/ratify-project/ratify/pkg/common/oras/authprovider"
)

// AuthClientFactory defines an interface for creating an authentication client.
type AuthClientFactory interface {
CreateAuthClient(serverURL string, options *azcontainerregistry.AuthenticationClientOptions) (AuthClient, error)
}

// DefaultAuthClientFactoryImpl is the default implementation of AuthClientFactory.
type DefaultAuthClientFactoryImpl struct{}

func (f *DefaultAuthClientFactoryImpl) CreateAuthClient(serverURL string, options *azcontainerregistry.AuthenticationClientOptions) (AuthClient, error) {
return DefaultAuthClientFactory(serverURL, options)

Check warning on line 34 in pkg/common/oras/authprovider/azure/helper.go

View check run for this annotation

Codecov / codecov/patch

pkg/common/oras/authprovider/azure/helper.go#L33-L34

Added lines #L33 - L34 were not covered by tests
}

func DefaultAuthClientFactory(serverURL string, options *azcontainerregistry.AuthenticationClientOptions) (AuthClient, error) {
client, err := azcontainerregistry.NewAuthenticationClient(serverURL, options)
if err != nil {
Expand All @@ -32,14 +42,6 @@ func DefaultAuthClientFactory(serverURL string, options *azcontainerregistry.Aut
return &AuthenticationClientWrapper{client: client}, nil

Check warning on line 42 in pkg/common/oras/authprovider/azure/helper.go

View check run for this annotation

Codecov / codecov/patch

pkg/common/oras/authprovider/azure/helper.go#L42

Added line #L42 was not covered by tests
}

func DefaultGetAADAccessToken(ctx context.Context, tenantID, clientID, resource string) (confidential.AuthResult, error) {
return azureauth.GetAADAccessToken(ctx, tenantID, clientID, resource)
}

func DefaultReportMetrics(ctx context.Context, duration int64, artifactHostName string) {
logger.GetLogger(ctx, logOpt).Infof("Metrics Report: Duration=%dms, Host=%s", duration, artifactHostName)
}

type AuthenticationClientWrapper struct {
client *azcontainerregistry.AuthenticationClient
}
Expand All @@ -51,3 +53,16 @@ func (w *AuthenticationClientWrapper) ExchangeAADAccessTokenForACRRefreshToken(c
type AuthClient interface {
ExchangeAADAccessTokenForACRRefreshToken(ctx context.Context, grantType, service string, options *azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenOptions) (azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse, error)
}

// RegistryHostGetter defines an interface for getting the registry host.
type RegistryHostGetter interface {
GetRegistryHost(artifact string) (string, error)
}

// DefaultRegistryHostGetterImpl is the default implementation of RegistryHostGetter.
type DefaultRegistryHostGetterImpl struct{}

func (g *DefaultRegistryHostGetterImpl) GetRegistryHost(artifact string) (string, error) {

Check warning on line 65 in pkg/common/oras/authprovider/azure/helper.go

View check run for this annotation

Codecov / codecov/patch

pkg/common/oras/authprovider/azure/helper.go#L65

Added line #L65 was not covered by tests
// Implement the logic to get the registry host
return provider.GetRegistryHostName(artifact)

Check warning on line 67 in pkg/common/oras/authprovider/azure/helper.go

View check run for this annotation

Codecov / codecov/patch

pkg/common/oras/authprovider/azure/helper.go#L67

Added line #L67 was not covered by tests
}
136 changes: 136 additions & 0 deletions pkg/common/oras/authprovider/azure/helper_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
/*
Copyright The Ratify Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package azure

import (
"context"

"github.com/Azure/azure-sdk-for-go/sdk/containers/azcontainerregistry"
"github.com/stretchr/testify/mock"
)

// MockAuthClient is a mock implementation of AuthClient.
type MockAuthClient struct {
mock.Mock
}

// Mock method for ExchangeAADAccessTokenForACRRefreshToken
func (m *MockAuthClient) ExchangeAADAccessTokenForACRRefreshToken(ctx context.Context, grantType, service string, options *azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenOptions) (azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse, error) {
args := m.Called(ctx, grantType, service, options)
return args.Get(0).(azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse), args.Error(1)
}

// MockAuthClientFactory is a mock implementation of AuthClientFactory.
type MockAuthClientFactory struct {
mock.Mock
}

// Mock method for CreateAuthClient
func (m *MockAuthClientFactory) CreateAuthClient(serverURL string, options *azcontainerregistry.AuthenticationClientOptions) (AuthClient, error) {
args := m.Called(serverURL, options)
return args.Get(0).(AuthClient), args.Error(1)
}

// MockRegistryHostGetter is a mock implementation of RegistryHostGetter.
type MockRegistryHostGetter struct {
mock.Mock
}

// Mock method for GetRegistryHost
func (m *MockRegistryHostGetter) GetRegistryHost(artifact string) (string, error) {
args := m.Called(artifact)
return args.String(0), args.Error(1)
}

// // TestDefaultAuthClientFactoryImpl tests the default factory implementation.
// func TestDefaultAuthClientFactoryImpl(t *testing.T) {
// mockFactory := new(MockAuthClientFactory)
// mockAuthClient := new(MockAuthClient)

// serverURL := "https://example.azurecr.io"
// options := &azcontainerregistry.AuthenticationClientOptions{}

// // Set up expectations
// mockFactory.On("CreateAuthClient", serverURL, options).Return(mockAuthClient, nil)

// factory := &DefaultAuthClientFactoryImpl{}
// client, err := factory.CreateAuthClient(serverURL, options)

// // Verify expectations
// mockFactory.AssertCalled(t, "CreateAuthClient", serverURL, options)
// assert.NoError(t, err)
// assert.NotNil(t, client)
// }

// // TestDefaultAuthClientFactory_Error tests error handling during client creation.
// func TestDefaultAuthClientFactory_Error(t *testing.T) {
// mockFactory := new(MockAuthClientFactory)

// serverURL := "https://example.azurecr.io"
// options := &azcontainerregistry.AuthenticationClientOptions{}
// expectedError := errors.New("failed to create client")

// // Set up expectations
// mockFactory.On("CreateAuthClient", serverURL, options).Return(nil, expectedError)

// factory := &DefaultAuthClientFactoryImpl{}
// client, err := factory.CreateAuthClient(serverURL, options)

// // Verify expectations
// mockFactory.AssertCalled(t, "CreateAuthClient", serverURL, options)
// assert.Error(t, err)
// assert.Nil(t, client)
// assert.Equal(t, expectedError, err)
// }

// // TestGetRegistryHost tests the GetRegistryHost function.
// func TestGetRegistryHost(t *testing.T) {
// mockGetter := new(MockRegistryHostGetter)

// artifact := "test/artifact"
// expectedHost := "example.azurecr.io"

// // Set up expectations
// mockGetter.On("GetRegistryHost", artifact).Return(expectedHost, nil)

// getter := &DefaultRegistryHostGetterImpl{}
// host, err := getter.GetRegistryHost(artifact)

// // Verify expectations
// mockGetter.AssertCalled(t, "GetRegistryHost", artifact)
// assert.NoError(t, err)
// assert.Equal(t, expectedHost, host)
// }

// // TestGetRegistryHost_Error tests error handling in GetRegistryHost.
// func TestGetRegistryHost_Error(t *testing.T) {
// mockGetter := new(MockRegistryHostGetter)

// artifact := "test/artifact"
// expectedError := errors.New("failed to get registry host")

// // Set up expectations
// mockGetter.On("GetRegistryHost", artifact).Return("", expectedError)

// getter := &DefaultRegistryHostGetterImpl{}
// host, err := getter.GetRegistryHost(artifact)

// // Verify expectations
// mockGetter.AssertCalled(t, "GetRegistryHost", artifact)
// assert.Error(t, err)
// assert.Empty(t, host)
// assert.Equal(t, expectedError, err)
// }

0 comments on commit 13de70a

Please sign in to comment.