From 8666a69436069479bd5c0afbb62b1b31973dd2e8 Mon Sep 17 00:00:00 2001 From: Shahram Kalantari Date: Mon, 11 Nov 2024 12:49:42 +1000 Subject: [PATCH] chore: rebase off the dev branch Signed-off-by: Shahram Kalantari --- .../azurekeyvault/provider.go | 111 ++++--- .../azurekeyvault/provider_test.go | 304 ++++++++++++++++-- 2 files changed, 340 insertions(+), 75 deletions(-) diff --git a/pkg/keymanagementprovider/azurekeyvault/provider.go b/pkg/keymanagementprovider/azurekeyvault/provider.go index e9b668479..271734e98 100644 --- a/pkg/keymanagementprovider/azurekeyvault/provider.go +++ b/pkg/keymanagementprovider/azurekeyvault/provider.go @@ -45,8 +45,6 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/keyvault/azcertificates" "github.com/Azure/azure-sdk-for-go/sdk/keyvault/azkeys" "github.com/Azure/azure-sdk-for-go/sdk/keyvault/azsecrets" - "github.com/Azure/go-autorest/autorest" - "github.com/Azure/go-autorest/autorest/azure" ) const ( @@ -79,41 +77,51 @@ type akvKMProvider struct { resource string certificates []types.KeyVaultValue keys []types.KeyVaultValue - keyKVClient *azkeys.Client - secretKVClient *azsecrets.Client - certificateKVClient *azcertificates.Client + keyKVClient keyKVClient + secretKVClient secretKVClient + certificateKVClient certificateKVClient } type akvKMProviderFactory struct{} -// // kvClient is an interface to interact with the keyvault client used for mocking purposes -// type kvClient interface { -// // GetCertificate retrieves a certificate from the keyvault -// GetCertificate(ctx context.Context, vaultBaseURL string, certificateName string, certificateVersion string) (kv.CertificateBundle, error) -// // GetKey retrieves a key from the keyvault -// GetKey(ctx context.Context, vaultBaseURL string, keyName string, keyVersion string) (kv.KeyBundle, error) -// // GetSecret retrieves a secret from the keyvault -// GetSecret(ctx context.Context, vaultBaseURL string, secretName string, secretVersion string) (kv.SecretBundle, error) -// } - -// type kvClientImpl struct { -// kv.BaseClient -// } - -// // GetCertificate retrieves a certificate from the keyvault -// func (c *kvClientImpl) GetCertificate(ctx context.Context, vaultBaseURL string, certificateName string, certificateVersion string) (kv.CertificateBundle, error) { -// return c.BaseClient.GetCertificate(ctx, vaultBaseURL, certificateName, certificateVersion) -// } - -// // GetKey retrieves a key from the keyvault -// func (c *kvClientImpl) GetKey(ctx context.Context, vaultBaseURL string, keyName string, keyVersion string) (kv.KeyBundle, error) { -// return c.BaseClient.GetKey(ctx, vaultBaseURL, keyName, keyVersion) -// } - -// // GetSecret retrieves a secret from the keyvault -// func (c *kvClientImpl) GetSecret(ctx context.Context, vaultBaseURL string, secretName string, secretVersion string) (kv.SecretBundle, error) { -// return c.BaseClient.GetSecret(ctx, vaultBaseURL, secretName, secretVersion) -// } +// kvClient is an interface to interact with the keyvault client used for mocking purposes +type keyKVClient interface { + // GetKey retrieves a key from the keyvault + GetKey(ctx context.Context, keyName string, keyVersion string) (azkeys.GetKeyResponse, error) +} +type secretKVClient interface { + // GetSecret retrieves a secret from the keyvault + GetSecret(ctx context.Context, secretName string, secretVersion string) (azsecrets.GetSecretResponse, error) +} +type certificateKVClient interface { + // GetCertificate retrieves a certificate from the keyvault + GetCertificate(ctx context.Context, certificateName string, certificateVersion string) (azcertificates.GetCertificateResponse, error) +} + +type keyKVClientImpl struct { + azkeys.Client +} +type secretKVClientImpl struct { + azsecrets.Client +} +type certificateKVClientImpl struct { + azcertificates.Client +} + +// GetCertificate retrieves a certificate from the keyvault +func (c *certificateKVClientImpl) GetCertificate(ctx context.Context, certificateName string, certificateVersion string) (azcertificates.GetCertificateResponse, error) { + return c.Client.GetCertificate(ctx, certificateName, certificateVersion, nil) +} + +// GetKey retrieves a key from the keyvault +func (c *keyKVClientImpl) GetKey(ctx context.Context, keyName string, keyVersion string) (azkeys.GetKeyResponse, error) { + return c.Client.GetKey(ctx, keyName, keyVersion, nil) +} + +// GetSecret retrieves a secret from the keyvault +func (c *secretKVClientImpl) GetSecret(ctx context.Context, secretName string, secretVersion string) (azsecrets.GetSecretResponse, error) { + return c.Client.GetSecret(ctx, secretName, secretVersion, nil) +} // initKVClient is a function to initialize the keyvault client // used for mocking purposes @@ -163,9 +171,9 @@ func (f *akvKMProviderFactory) Create(_ string, keyManagementProviderConfig conf return nil, re.ErrorCodePluginInitFailure.NewError(re.KeyManagementProvider, ProviderName, re.AKVLink, err, "failed to create keyvault client", re.HideStackTrace) } - provider.keyKVClient = keyKVClient - provider.secretKVClient = secretKVClient - provider.certificateKVClient = certificateKVClient + provider.keyKVClient = &keyKVClientImpl{*keyKVClient} + provider.secretKVClient = &secretKVClientImpl{*secretKVClient} + provider.certificateKVClient = &certificateKVClientImpl{*certificateKVClient} return provider, nil } @@ -179,11 +187,16 @@ func (s *akvKMProvider) GetCertificates(ctx context.Context) (map[keymanagementp logger.GetLogger(ctx, logOpt).Debugf("fetching secret from key vault, certName %v, certVersion %v", keyVaultCert.Name) startTime := time.Now() - secretResponse, err := s.secretKVClient.GetSecret(ctx, keyVaultCert.Name, keyVaultCert.Version, nil) + secretResponse, err := s.secretKVClient.GetSecret(ctx, keyVaultCert.Name, keyVaultCert.Version) if err != nil { + // I am aware that there are so many logs here and inside isSecretDisabledError, but I am trying to understand the structure of the error + // I'll make sure to remove them before merging + logger.GetLogger(ctx, logOpt).Infof("s.secretKVClient.GetSecret errored:, err %v", err) if isSecretDisabledError(err) { // if secret is disabled, get the version of the certificate for status - certResponse, err := s.certificateKVClient.GetCertificate(ctx, keyVaultCert.Name, keyVaultCert.Version, nil) + logger.GetLogger(ctx, logOpt).Infof("calling s.certificateKVClient.GetCertificate:, keyVaultCert.Name %v, keyVaultCert.Version %v", keyVaultCert.Name, keyVaultCert.Version) + certResponse, err := s.certificateKVClient.GetCertificate(ctx, keyVaultCert.Name, keyVaultCert.Version) + logger.GetLogger(ctx, logOpt).Infof("s.certificateKVClient.GetCertificate was called,checking for possible errors") if err != nil { return nil, nil, fmt.Errorf("failed to get certificate objectName:%s, objectVersion:%s, error: %w", keyVaultCert.Name, keyVaultCert.Version, err) } @@ -192,16 +205,16 @@ func (s *akvKMProvider) GetCertificates(ctx context.Context) (map[keymanagementp isEnabled := *certBundle.Attributes.Enabled lastRefreshed := startTime.Format(time.RFC3339) certProperty := getStatusProperty(keyVaultCert.Name, keyVaultCert.Version, lastRefreshed, isEnabled) + logger.GetLogger(ctx, logOpt).Infof("certProperty %v", certProperty) certsStatus = append(certsStatus, certProperty) mapKey := keymanagementprovider.KMPMapKey{Name: keyVaultCert.Name, Version: keyVaultCert.Version, Enabled: isEnabled} keymanagementprovider.DeleteCertificateFromMap(s.resource, mapKey) continue } - return nil, nil, fmt.Errorf("failed to get secret objectName:%s, objectVersion:%s, error: %w", keyVaultCert.Name, keyVaultCert.Version, err) } - secretBundle := secretResponse.SecretBundle + secretBundle := secretResponse.SecretBundle isEnabled := *secretBundle.Attributes.Enabled certResult, certProperty, err := getCertsFromSecretBundle(ctx, secretBundle, keyVaultCert.Name, isEnabled) @@ -227,7 +240,7 @@ func (s *akvKMProvider) GetKeys(ctx context.Context) (map[keymanagementprovider. // fetch the key object from Key Vault startTime := time.Now() - keyResponse, err := s.keyKVClient.GetKey(ctx, keyVaultKey.Name, keyVaultKey.Version, nil) + keyResponse, err := s.keyKVClient.GetKey(ctx, keyVaultKey.Name, keyVaultKey.Version) if err != nil { return nil, nil, fmt.Errorf("failed to get key objectName:%s, objectVersion:%s, error: %w", keyVaultKey.Name, keyVaultKey.Version, err) } @@ -435,15 +448,17 @@ func getObjectVersion(id string) string { } func isSecretDisabledError(err error) bool { - var de autorest.DetailedError - if errors.As(err, &de) { - var re *azure.RequestError - if errors.As(de.Original, &re) { - if re.ServiceError.Code == "SecretDisabled" { - return true - } + logger.GetLogger(context.Background(), logOpt).Infof("Inside isSecretDisabledError ------------------") + var responseError *azcore.ResponseError + if errors.As(err, &responseError) { + // Is there a better way to check if the error is a secret disabled error? + // if responseError.ErrorCode == "Forbidden" && strings.Contains(responseError.Error(), "SecretDisabled") { + if strings.Contains(responseError.Error(), "Forbidden") && strings.Contains(responseError.Error(), "SecretDisabled") { + logger.GetLogger(context.Background(), logOpt).Infof("Leaving isSecretDisabledError returnning true ------------------") + return true } } + logger.GetLogger(context.Background(), logOpt).Infof("Leaving isSecretDisabledError returnning falses ------------------") return false } diff --git a/pkg/keymanagementprovider/azurekeyvault/provider_test.go b/pkg/keymanagementprovider/azurekeyvault/provider_test.go index e637982c6..d484934aa 100644 --- a/pkg/keymanagementprovider/azurekeyvault/provider_test.go +++ b/pkg/keymanagementprovider/azurekeyvault/provider_test.go @@ -29,30 +29,12 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/keyvault/azcertificates" "github.com/Azure/azure-sdk-for-go/sdk/keyvault/azkeys" "github.com/Azure/azure-sdk-for-go/sdk/keyvault/azsecrets" - "github.com/Azure/go-autorest/autorest/azure" "github.com/ratify-project/ratify/pkg/keymanagementprovider/azurekeyvault/types" "github.com/ratify-project/ratify/pkg/keymanagementprovider/config" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" ) -func SkipTestInitializeKVClient(t *testing.T) { - testEnvs := []azure.Environment{ - azure.PublicCloud, - azure.GermanCloud, - azure.ChinaCloud, - azure.USGovernmentCloud, - } - - for i := range testEnvs { - keyKVClient, secretKVClient, certificateKVClient, err := initializeKvClient(testEnvs[i].KeyVaultEndpoint, "", "", nil) - assert.NoError(t, err) - assert.NotNil(t, keyKVClient) - assert.NotNil(t, secretKVClient) - assert.NotNil(t, certificateKVClient) - } -} - // TestCreate tests the Create function func TestCreate(t *testing.T) { factory := &akvKMProviderFactory{} @@ -168,7 +150,7 @@ func TestCreate(t *testing.T) { } // TestGetCertificates tests the GetCertificates function -func TestGetCertificates(t *testing.T) { +func TestGetCertificates_original(t *testing.T) { factory := &akvKMProviderFactory{} config := config.KeyManagementProviderConfig{ "vaultUri": "https://testkv.vault.azure.net/", @@ -193,8 +175,281 @@ func TestGetCertificates(t *testing.T) { assert.Nil(t, certStatus) } +type MockKeyKVClient struct { + GetKeyFunc func(ctx context.Context, keyName string, keyVersion string) (azkeys.GetKeyResponse, error) +} +type MockSecretKVClient struct { + GetSecretFunc func(ctx context.Context, secretName string, secretVersion string) (azsecrets.GetSecretResponse, error) +} +type MockCertificateKVClient struct { + GetCertificateFunc func(ctx context.Context, certificateName string, certificateVersion string) (azcertificates.GetCertificateResponse, error) +} + +func (m *MockKeyKVClient) GetKey(ctx context.Context, keyName string, keyVersion string) (azkeys.GetKeyResponse, error) { + if m.GetKeyFunc != nil { + return m.GetKeyFunc(ctx, keyName, keyVersion) + } + return azkeys.GetKeyResponse{}, nil +} +func (m *MockSecretKVClient) GetSecret(ctx context.Context, secretName string, secretVersion string) (azsecrets.GetSecretResponse, error) { + if m.GetSecretFunc != nil { + return m.GetSecretFunc(ctx, secretName, secretVersion) + } + return azsecrets.GetSecretResponse{}, nil +} +func (m *MockCertificateKVClient) GetCertificate(ctx context.Context, certificateName string, certificateVersion string) (azcertificates.GetCertificateResponse, error) { + if m.GetCertificateFunc != nil { + return m.GetCertificateFunc(ctx, certificateName, certificateVersion) + } + return azcertificates.GetCertificateResponse{}, nil +} + +// stringPtr returns a pointer to the given string. +func stringPtr(s string) *string { + return &s +} + +// boolPtr returns a pointer to the given bool. +func boolPtr(b bool) *bool { + return &b +} + +// TestGetCertificates tests the GetCertificates function +func TestGetCertificates(t *testing.T) { + certID := azcertificates.ID("https://testkv.vault.azure.net/certificates/cert1") + secretID := azsecrets.ID("https://testkv.vault.azure.net/secrets/secret1") + testCases := []struct { + name string + mockKeyKVClient *MockKeyKVClient + mockSecretKVClient *MockSecretKVClient + mockCertificateKVClient *MockCertificateKVClient + expectedErr bool + }{ + { + name: "GetSecret error", + mockSecretKVClient: &MockSecretKVClient{ + GetSecretFunc: func(_ context.Context, _ string, _ string) (azsecrets.GetSecretResponse, error) { + return azsecrets.GetSecretResponse{}, errors.New("error") + }, + }, + expectedErr: true, + }, + { + name: "Certificate disabled", + mockCertificateKVClient: &MockCertificateKVClient{ + GetCertificateFunc: func(_ context.Context, _ string, _ string) (azcertificates.GetCertificateResponse, error) { + return azcertificates.GetCertificateResponse{ + CertificateBundle: azcertificates.CertificateBundle{ + ID: &certID, + KID: stringPtr("https://testkv.vault.azure.net/keys/key1"), + Attributes: &azcertificates.CertificateAttributes{ + Enabled: boolPtr(false), + }, + }, + }, nil + }, + }, + mockSecretKVClient: &MockSecretKVClient{ + GetSecretFunc: func(_ context.Context, _ string, _ string) (azsecrets.GetSecretResponse, error) { + err := azcore.ResponseError{ + ErrorCode: "Forbidden, SecretDisabled", + } + return azsecrets.GetSecretResponse{}, &err + }, + }, + expectedErr: false, + }, + { + name: "Certificate disabled error", + mockCertificateKVClient: &MockCertificateKVClient{ + GetCertificateFunc: func(_ context.Context, _ string, _ string) (azcertificates.GetCertificateResponse, error) { + return azcertificates.GetCertificateResponse{}, errors.New("error") + }, + }, + mockSecretKVClient: &MockSecretKVClient{ + GetSecretFunc: func(_ context.Context, _ string, _ string) (azsecrets.GetSecretResponse, error) { + err := azcore.ResponseError{ + ErrorCode: "SecretDisabled", + } + return azsecrets.GetSecretResponse{}, &err + }, + }, + expectedErr: true, + }, + { + name: "Certificate enabled", + mockCertificateKVClient: &MockCertificateKVClient{ + GetCertificateFunc: func(_ context.Context, _ string, _ string) (azcertificates.GetCertificateResponse, error) { + return azcertificates.GetCertificateResponse{ + CertificateBundle: azcertificates.CertificateBundle{ + ID: &certID, + KID: stringPtr("https://testkv.vault.azure.net/keys/key1"), + Attributes: &azcertificates.CertificateAttributes{ + Enabled: boolPtr(true), + }, + }, + }, nil + }, + }, + mockSecretKVClient: &MockSecretKVClient{ + GetSecretFunc: func(_ context.Context, _ string, _ string) (azsecrets.GetSecretResponse, error) { + return azsecrets.GetSecretResponse{ + SecretBundle: azsecrets.SecretBundle{ + ID: &secretID, + Kid: stringPtr("https://testkv.vault.azure.net/keys/key1"), + ContentType: stringPtr("application/x-pem-file"), + Attributes: &azsecrets.SecretAttributes{ + Enabled: boolPtr(true), + }, + Value: stringPtr("-----BEGIN CERTIFICATE-----\nMIIC8TCCAdmgAwIBAgIUaNrwbhs/I1ecqUYdzD2xuAVNdmowDQYJKoZIhvcNAQEL\nBQAwKjEPMA0GA1UECgwGUmF0aWZ5MRcwFQYDVQQDDA5SYXRpZnkgUm9vdCBDQTAe\nFw0yMzA2MjEwMTIyMzdaFw0yNDA2MjAwMTIyMzdaMBkxFzAVBgNVBAMMDnJhdGlm\neS5kZWZhdWx0MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAtskG1BUt\n4Fw2lbm53KbwZb1hnLmWdwRotZyznhhk/yrUDcq3uF6klwpk/E2IKfUKIo6doHSk\nXaEZXR68UtXygvA4wdg7xZ6kKpXy0gu+RxGE6CGtDHTyDDzITu+NBjo21ZSsyGpQ\nJeIKftUCHdwdygKf0CdJx8A29GBRpHGCmJadmt7tTzOnYjmbuPVLeqJo/Ex9qXcG\nZbxoxnxr5NCocFeKx+EbLo+k/KjdFB2PKnhgzxAaMMMP6eXPr8l5AlzkC83EmPvN\ntveuaBbamdlFkD+53TZeZlxt3GIdq93Iw/UpbQ/pvhbrztMT+UVEkm15sShfX8Xn\nL2st5A4n0V+66QIDAQABoyAwHjAMBgNVHRMBAf8EAjAAMA4GA1UdDwEB/wQEAwIH\ngDANBgkqhkiG9w0BAQsFAAOCAQEAGpOqozyfDSBjoTepsRroxxcZ4sq65gw45Bme\nm36BS6FG0WHIg3cMy6KIIBefTDSKrPkKNTtuF25AeGn9jM+26cnfDM78ZH0+Lnn7\n7hs0MA64WMPQaWs9/+89aM9NADV9vp2zdG4xMi6B7DruvKWyhJaNoRqK/qP6LdSQ\nw8M+21sAHvXgrRkQtJlVOzVhgwt36NOb1hzRlQiZB+nhv2Wbw7fbtAaADk3JAumf\nvM+YdPS1KfAFaYefm4yFd+9/C0KOkHico3LTbELO5hG0Mo/EYvtjM+Fljb42EweF\n3nAx1GSPe5Tn8p3h6RyJW5HIKozEKyfDuLS0ccB/nqT3oNjcTw==\n-----END CERTIFICATE-----\n-----BEGIN CERTIFICATE-----\nMIIDRTCCAi2gAwIBAgIUcC33VfaMhOnsl7avNTRVQozoVtUwDQYJKoZIhvcNAQEL\nBQAwKjEPMA0GA1UECgwGUmF0aWZ5MRcwFQYDVQQDDA5SYXRpZnkgUm9vdCBDQTAe\nFw0yMzA2MjEwMTIyMzZaFw0yMzA2MjIwMTIyMzZaMCoxDzANBgNVBAoMBlJhdGlm\neTEXMBUGA1UEAwwOUmF0aWZ5IFJvb3QgQ0EwggEiMA0GCSqGSIb3DQEBAQUAA4IB\nDwAwggEKAoIBAQDDFhDnyPrVDZaeRu6Tbg1a/iTwus+IuX+h8aKhKS1yHz4EF/Lz\nxCy7lNSQ9srGMMVumWuNom/ydIphff6PejZM1jFKPU6OQR/0JX5epcVIjbKa562T\nDguUxJ+h5V3EIyM4RqOWQ2g/xZo86x5TzyNJXiVdHHRvmDvUNwPpMeDjr/EHVAni\n5YQObxkJRiiZ7XOa5zz3YztVm8sSZAwPWroY1HIfvtP+KHpiNDIKSymmuJkH4SEr\nJn++iqN8na18a9DFBPTTrLPe3CxATGrMfosCMZ6LP3iFLLc/FaSpwcnugWdewsUK\nYs+sUY7jFWR7x7/1nyFWyRrQviM4f4TY+K7NAgMBAAGjYzBhMB0GA1UdDgQWBBQH\nYePW7QPP2p1utr3r6gqzEkKs+DAfBgNVHSMEGDAWgBQHYePW7QPP2p1utr3r6gqz\nEkKs+DAPBgNVHRMBAf8EBTADAQH/MA4GA1UdDwEB/wQEAwICBDANBgkqhkiG9w0B\nAQsFAAOCAQEAjKp4vx3bFaKVhAbQeTsDjWJgmXLK2vLgt74MiUwSF6t0wehlfszE\nIcJagGJsvs5wKFf91bnwiqwPjmpse/thPNBAxh1uEoh81tOklv0BN790vsVpq3t+\ncnUvWPiCZdRlAiGGFtRmKk3Keq4sM6UdiUki9s+wnxypHVb4wIpVxu5R271Lnp5I\n+rb2EQ48iblt4XZPczf/5QJdTgbItjBNbuO8WVPOqUIhCiFuAQziLtNUq3p81dHO\nQ2BPgmaitCpIUYHVYighLauBGCH8xOFzj4a4KbOxKdxyJTd0La/vRCKaUtJX67Lc\nfQYVR9HXQZ0YlmwPcmIG5v7wBfcW34NUvA==\n-----END CERTIFICATE-----\n"), + }, + }, nil + }, + }, + expectedErr: false, + }, + { + name: "getCertsFromSecretBundle error", + mockSecretKVClient: &MockSecretKVClient{ + GetSecretFunc: func(_ context.Context, _ string, _ string) (azsecrets.GetSecretResponse, error) { + return azsecrets.GetSecretResponse{ + SecretBundle: azsecrets.SecretBundle{ + ContentType: stringPtr("test"), + ID: &secretID, + Kid: stringPtr("https://testkv.vault.azure.net/keys/key1"), + Attributes: &azsecrets.SecretAttributes{ + Enabled: boolPtr(true), + }, + Value: stringPtr("-----BEGIN CERTIFICATE-----\nMIIC8TCCAdmgAwIBAgIUaNrwbhs/I1ecqUYdzD2xuAVNdmowDQYJKoZIhvcNAQEL\nBQAwKjEPMA0GA1UECgwGUmF0aWZ5MRcwFQYDVQQDDA5SYXRpZnkgUm9vdCBDQTAe\nFw0yMzA2MjEwMTIyMzdaFw0yNDA2MjAwMTIyMzdaMBkxFzAVBgNVBAMMDnJhdGlm\neS5kZWZhdWx0MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAtskG1BUt\n4Fw2lbm53KbwZb1hnLmWdwRotZyznhhk/yrUDcq3uF6klwpk/E2IKfUKIo6doHSk\nXaEZXR68UtXygvA4wdg7xZ6kKpXy0gu+RxGE6CGtDHTyDDzITu+NBjo21ZSsyGpQ\nJeIKftUCHdwdygKf0CdJx8A29GBRpHGCmJadmt7tTzOnYjmbuPVLeqJo/Ex9qXcG\nZbxoxnxr5NCocFeKx+EbLo+k/KjdFB2PKnhgzxAaMMMP6eXPr8l5AlzkC83EmPvN\ntveuaBbamdlFkD+53TZeZlxt3GIdq93Iw/UpbQ/pvhbrztMT+UVEkm15sShfX8Xn\nL2st5A4n0V+66QIDAQABoyAwHjAMBgNVHRMBAf8EAjAAMA4GA1UdDwEB/wQEAwIH\ngDANBgkqhkiG9w0BAQsFAAOCAQEAGpOqozyfDSBjoTepsRroxxcZ4sq65gw45Bme\nm36BS6FG0WHIg3cMy6KIIBefTDSKrPkKNTtuF25AeGn9jM+26cnfDM78ZH0+Lnn7\n7hs0MA64WMPQaWs9/+89aM9NADV9vp2zdG4xMi6B7DruvKWyhJaNoRqK/qP6LdSQ\nw8M+21sAHvXgrRkQtJlVOzVhgwt36NOb1hzRlQiZB+nhv2Wbw7fbtAaADk3JAumf\nvM+YdPS1KfAFaYefm4yFd+9/C0KOkHico3LTbELO5hG0Mo/EYvtjM+Fljb42EweF\n3nAx1GSPe5Tn8p3h6RyJW5HIKozEKyfDuLS0ccB/nqT3oNjcTw==\n-----END CERTIFICATE-----\n-----BEGIN CERTIFICATE-----\nMIIDRTCCAi2gAwIBAgIUcC33VfaMhOnsl7avNTRVQozoVtUwDQYJKoZIhvcNAQEL\nBQAwKjEPMA0GA1UECgwGUmF0aWZ5MRcwFQYDVQQDDA5SYXRpZnkgUm9vdCBDQTAe\nFw0yMzA2MjEwMTIyMzZaFw0yMzA2MjIwMTIyMzZaMCoxDzANBgNVBAoMBlJhdGlm\neTEXMBUGA1UEAwwOUmF0aWZ5IFJvb3QgQ0EwggEiMA0GCSqGSIb3DQEBAQUAA4IB\nDwAwggEKAoIBAQDDFhDnyPrVDZaeRu6Tbg1a/iTwus+IuX+h8aKhKS1yHz4EF/Lz\nxCy7lNSQ9srGMMVumWuNom/ydIphff6PejZM1jFKPU6OQR/0JX5epcVIjbKa562T\nDguUxJ+h5V3EIyM4RqOWQ2g/xZo86x5TzyNJXiVdHHRvmDvUNwPpMeDjr/EHVAni\n5YQObxkJRiiZ7XOa5zz3YztVm8sSZAwPWroY1HIfvtP+KHpiNDIKSymmuJkH4SEr\nJn++iqN8na18a9DFBPTTrLPe3CxATGrMfosCMZ6LP3iFLLc/FaSpwcnugWdewsUK\nYs+sUY7jFWR7x7/1nyFWyRrQviM4f4TY+K7NAgMBAAGjYzBhMB0GA1UdDgQWBBQH\nYePW7QPP2p1utr3r6gqzEkKs+DAfBgNVHSMEGDAWgBQHYePW7QPP2p1utr3r6gqz\nEkKs+DAPBgNVHRMBAf8EBTADAQH/MA4GA1UdDwEB/wQEAwICBDANBgkqhkiG9w0B\nAQsFAAOCAQEAjKp4vx3bFaKVhAbQeTsDjWJgmXLK2vLgt74MiUwSF6t0wehlfszE\nIcJagGJsvs5wKFf91bnwiqwPjmpse/thPNBAxh1uEoh81tOklv0BN790vsVpq3t+\ncnUvWPiCZdRlAiGGFtRmKk3Keq4sM6UdiUki9s+wnxypHVb4wIpVxu5R271Lnp5I\n+rb2EQ48iblt4XZPczf/5QJdTgbItjBNbuO8WVPOqUIhCiFuAQziLtNUq3p81dHO\nQ2BPgmaitCpIUYHVYighLauBGCH8xOFzj4a4KbOxKdxyJTd0La/vRCKaUtJX67Lc\nfQYVR9HXQZ0YlmwPcmIG5v7wBfcW34NUvA==\n-----END CERTIFICATE-----\n"), + }, + }, nil + }, + }, + expectedErr: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + provider := &akvKMProvider{ + certificates: []types.KeyVaultValue{ + { + Name: "cert1", + Version: "c1f03df1113d460491d970737dfdc35d", + }, + }, + keyKVClient: tc.mockKeyKVClient, + secretKVClient: tc.mockSecretKVClient, + certificateKVClient: tc.mockCertificateKVClient, + } + + _, _, err := provider.GetCertificates(context.Background()) + if tc.expectedErr != (err != nil) { + t.Fatalf("error = %v, expectedErr = %v", err, tc.expectedErr) + } + }) + } +} + // TestGetKeys tests the GetKeys function func TestGetKeys(t *testing.T) { + keyID := azkeys.ID("https://testkv.vault.azure.net/keys/key1") + keyTY := azkeys.JSONWebKeyTypeRSA + testCases := []struct { + name string + mockKeyKVClient *MockKeyKVClient + expectedErr bool + }{ + { + name: "GetKey error", + mockKeyKVClient: &MockKeyKVClient{ + GetKeyFunc: func(_ context.Context, _ string, _ string) (azkeys.GetKeyResponse, error) { + return azkeys.GetKeyResponse{}, errors.New("error") + }, + }, + expectedErr: true, + }, + { + name: "Key disabled", + mockKeyKVClient: &MockKeyKVClient{ + GetKeyFunc: func(_ context.Context, _ string, _ string) (azkeys.GetKeyResponse, error) { + return azkeys.GetKeyResponse{ + KeyBundle: azkeys.KeyBundle{ + Key: &azkeys.JSONWebKey{ + KID: &keyID, + }, + Attributes: &azkeys.KeyAttributes{ + Enabled: boolPtr(false), + }, + }, + }, nil + }, + }, + expectedErr: false, + }, + { + name: "getKeyFromKeyBundle error", + mockKeyKVClient: &MockKeyKVClient{ + GetKeyFunc: func(_ context.Context, _ string, _ string) (azkeys.GetKeyResponse, error) { + return azkeys.GetKeyResponse{ + KeyBundle: azkeys.KeyBundle{ + Key: &azkeys.JSONWebKey{ + KID: &keyID, + }, + Attributes: &azkeys.KeyAttributes{ + Enabled: boolPtr(true), + }, + }, + }, nil + }, + }, + expectedErr: true, + }, + { + name: "Key enabled", + mockKeyKVClient: &MockKeyKVClient{ + GetKeyFunc: func(_ context.Context, _ string, _ string) (azkeys.GetKeyResponse, error) { + return azkeys.GetKeyResponse{ + KeyBundle: azkeys.KeyBundle{ + Key: &azkeys.JSONWebKey{ + KID: &keyID, + Kty: &keyTY, + N: []byte("n"), + E: []byte("e"), + }, + Attributes: &azkeys.KeyAttributes{ + Enabled: boolPtr(true), + }, + }, + }, nil + }, + }, + expectedErr: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + provider := &akvKMProvider{ + keys: []types.KeyVaultValue{ + { + Name: "key1", + Version: "c1f03df1113d460491d970737dfdc35d", + }, + }, + keyKVClient: tc.mockKeyKVClient, + } + + _, _, err := provider.GetKeys(context.Background()) + if tc.expectedErr != (err != nil) { + t.Fatalf("error = %v, expectedErr = %v", err, tc.expectedErr) + } + }) + } +} + +// TestGetKeys tests the GetKeys function +func TestGetKeys_original(t *testing.T) { factory := &akvKMProviderFactory{} config := config.KeyManagementProviderConfig{ "vaultUri": "https://testkv.vault.azure.net/", @@ -668,11 +923,12 @@ func TestGetKeyFromKeyBundlex(t *testing.T) { } } +const tenantID = "tenant-id" +const clientID = "client-id" + func TestInitializeKvClient_Success(t *testing.T) { // Mock the context and input parameters keyVaultEndpoint := "https://myvault.vault.azure.net/" - tenantID := "tenant-id" - clientID := "client-id" // Create a mock credential provider mockCredential, err := azidentity.NewClientSecretCredential(tenantID, clientID, "fake-secret", nil) @@ -693,8 +949,6 @@ func TestInitializeKvClient_Success(t *testing.T) { func TestInitializeKvClient_FailureInAzKeysClient(t *testing.T) { // Mock the context and input parameters keyVaultEndpoint := "https://invalid-vault.vault.azure.net/" - tenantID := "mock_tenant-id" - clientID := "mock_client-id" // Run the function keysKVClient, secretsKVClient, certificatesKVClient, err := initializeKvClient(keyVaultEndpoint, tenantID, clientID, nil) @@ -710,8 +964,6 @@ func TestInitializeKvClient_FailureInAzKeysClient(t *testing.T) { func TestInitializeKvClient_FailureInAzSecretsClient(t *testing.T) { // Mock the context and input parameters keyVaultEndpoint := "https://valid-vault.vault.azure.net/" - tenantID := "tenant-id" - clientID := "client-id" // Modify the azsecrets.NewClient function to simulate failure // Run the function @@ -728,8 +980,6 @@ func TestInitializeKvClient_FailureInAzSecretsClient(t *testing.T) { func TestInitializeKvClient_FailureInAzCertificatesClient(t *testing.T) { // Mock the context and input parameters keyVaultEndpoint := "https://valid-vault.vault.azure.net/" - tenantID := "tenant-id" - clientID := "client-id" // Modify the azsecrets.NewClient function to simulate failure // Run the function