Skip to content

Commit

Permalink
chore: rebase off the dev branch
Browse files Browse the repository at this point in the history
Signed-off-by: Shahram Kalantari <[email protected]>
  • Loading branch information
shahramk64 committed Nov 11, 2024
1 parent 170ef82 commit 8666a69
Show file tree
Hide file tree
Showing 2 changed files with 340 additions and 75 deletions.
111 changes: 63 additions & 48 deletions pkg/keymanagementprovider/azurekeyvault/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,6 @@ import (
"github.com/Azure/azure-sdk-for-go/sdk/keyvault/azcertificates"
"github.com/Azure/azure-sdk-for-go/sdk/keyvault/azkeys"
"github.com/Azure/azure-sdk-for-go/sdk/keyvault/azsecrets"
"github.com/Azure/go-autorest/autorest"
"github.com/Azure/go-autorest/autorest/azure"
)

const (
Expand Down Expand Up @@ -79,41 +77,51 @@ type akvKMProvider struct {
resource string
certificates []types.KeyVaultValue
keys []types.KeyVaultValue
keyKVClient *azkeys.Client
secretKVClient *azsecrets.Client
certificateKVClient *azcertificates.Client
keyKVClient keyKVClient
secretKVClient secretKVClient
certificateKVClient certificateKVClient
}

type akvKMProviderFactory struct{}

// // kvClient is an interface to interact with the keyvault client used for mocking purposes
// type kvClient interface {
// // GetCertificate retrieves a certificate from the keyvault
// GetCertificate(ctx context.Context, vaultBaseURL string, certificateName string, certificateVersion string) (kv.CertificateBundle, error)
// // GetKey retrieves a key from the keyvault
// GetKey(ctx context.Context, vaultBaseURL string, keyName string, keyVersion string) (kv.KeyBundle, error)
// // GetSecret retrieves a secret from the keyvault
// GetSecret(ctx context.Context, vaultBaseURL string, secretName string, secretVersion string) (kv.SecretBundle, error)
// }

// type kvClientImpl struct {
// kv.BaseClient
// }

// // GetCertificate retrieves a certificate from the keyvault
// func (c *kvClientImpl) GetCertificate(ctx context.Context, vaultBaseURL string, certificateName string, certificateVersion string) (kv.CertificateBundle, error) {
// return c.BaseClient.GetCertificate(ctx, vaultBaseURL, certificateName, certificateVersion)
// }

// // GetKey retrieves a key from the keyvault
// func (c *kvClientImpl) GetKey(ctx context.Context, vaultBaseURL string, keyName string, keyVersion string) (kv.KeyBundle, error) {
// return c.BaseClient.GetKey(ctx, vaultBaseURL, keyName, keyVersion)
// }

// // GetSecret retrieves a secret from the keyvault
// func (c *kvClientImpl) GetSecret(ctx context.Context, vaultBaseURL string, secretName string, secretVersion string) (kv.SecretBundle, error) {
// return c.BaseClient.GetSecret(ctx, vaultBaseURL, secretName, secretVersion)
// }
// kvClient is an interface to interact with the keyvault client used for mocking purposes
type keyKVClient interface {
// GetKey retrieves a key from the keyvault
GetKey(ctx context.Context, keyName string, keyVersion string) (azkeys.GetKeyResponse, error)
}
type secretKVClient interface {
// GetSecret retrieves a secret from the keyvault
GetSecret(ctx context.Context, secretName string, secretVersion string) (azsecrets.GetSecretResponse, error)
}
type certificateKVClient interface {
// GetCertificate retrieves a certificate from the keyvault
GetCertificate(ctx context.Context, certificateName string, certificateVersion string) (azcertificates.GetCertificateResponse, error)
}

type keyKVClientImpl struct {
azkeys.Client
}
type secretKVClientImpl struct {
azsecrets.Client
}
type certificateKVClientImpl struct {
azcertificates.Client
}

// GetCertificate retrieves a certificate from the keyvault
func (c *certificateKVClientImpl) GetCertificate(ctx context.Context, certificateName string, certificateVersion string) (azcertificates.GetCertificateResponse, error) {
return c.Client.GetCertificate(ctx, certificateName, certificateVersion, nil)
}

// GetKey retrieves a key from the keyvault
func (c *keyKVClientImpl) GetKey(ctx context.Context, keyName string, keyVersion string) (azkeys.GetKeyResponse, error) {
return c.Client.GetKey(ctx, keyName, keyVersion, nil)
}

// GetSecret retrieves a secret from the keyvault
func (c *secretKVClientImpl) GetSecret(ctx context.Context, secretName string, secretVersion string) (azsecrets.GetSecretResponse, error) {
return c.Client.GetSecret(ctx, secretName, secretVersion, nil)
}

// initKVClient is a function to initialize the keyvault client
// used for mocking purposes
Expand Down Expand Up @@ -163,9 +171,9 @@ func (f *akvKMProviderFactory) Create(_ string, keyManagementProviderConfig conf
return nil, re.ErrorCodePluginInitFailure.NewError(re.KeyManagementProvider, ProviderName, re.AKVLink, err, "failed to create keyvault client", re.HideStackTrace)
}

provider.keyKVClient = keyKVClient
provider.secretKVClient = secretKVClient
provider.certificateKVClient = certificateKVClient
provider.keyKVClient = &keyKVClientImpl{*keyKVClient}
provider.secretKVClient = &secretKVClientImpl{*secretKVClient}
provider.certificateKVClient = &certificateKVClientImpl{*certificateKVClient}

return provider, nil
}
Expand All @@ -179,11 +187,16 @@ func (s *akvKMProvider) GetCertificates(ctx context.Context) (map[keymanagementp
logger.GetLogger(ctx, logOpt).Debugf("fetching secret from key vault, certName %v, certVersion %v", keyVaultCert.Name)

startTime := time.Now()
secretResponse, err := s.secretKVClient.GetSecret(ctx, keyVaultCert.Name, keyVaultCert.Version, nil)
secretResponse, err := s.secretKVClient.GetSecret(ctx, keyVaultCert.Name, keyVaultCert.Version)
if err != nil {
// I am aware that there are so many logs here and inside isSecretDisabledError, but I am trying to understand the structure of the error
// I'll make sure to remove them before merging
logger.GetLogger(ctx, logOpt).Infof("s.secretKVClient.GetSecret errored:, err %v", err)
if isSecretDisabledError(err) {
// if secret is disabled, get the version of the certificate for status
certResponse, err := s.certificateKVClient.GetCertificate(ctx, keyVaultCert.Name, keyVaultCert.Version, nil)
logger.GetLogger(ctx, logOpt).Infof("calling s.certificateKVClient.GetCertificate:, keyVaultCert.Name %v, keyVaultCert.Version %v", keyVaultCert.Name, keyVaultCert.Version)
certResponse, err := s.certificateKVClient.GetCertificate(ctx, keyVaultCert.Name, keyVaultCert.Version)
logger.GetLogger(ctx, logOpt).Infof("s.certificateKVClient.GetCertificate was called,checking for possible errors")
if err != nil {
return nil, nil, fmt.Errorf("failed to get certificate objectName:%s, objectVersion:%s, error: %w", keyVaultCert.Name, keyVaultCert.Version, err)
}
Expand All @@ -192,16 +205,16 @@ func (s *akvKMProvider) GetCertificates(ctx context.Context) (map[keymanagementp
isEnabled := *certBundle.Attributes.Enabled
lastRefreshed := startTime.Format(time.RFC3339)
certProperty := getStatusProperty(keyVaultCert.Name, keyVaultCert.Version, lastRefreshed, isEnabled)
logger.GetLogger(ctx, logOpt).Infof("certProperty %v", certProperty)
certsStatus = append(certsStatus, certProperty)
mapKey := keymanagementprovider.KMPMapKey{Name: keyVaultCert.Name, Version: keyVaultCert.Version, Enabled: isEnabled}
keymanagementprovider.DeleteCertificateFromMap(s.resource, mapKey)
continue
}

return nil, nil, fmt.Errorf("failed to get secret objectName:%s, objectVersion:%s, error: %w", keyVaultCert.Name, keyVaultCert.Version, err)
}
secretBundle := secretResponse.SecretBundle

secretBundle := secretResponse.SecretBundle
isEnabled := *secretBundle.Attributes.Enabled

certResult, certProperty, err := getCertsFromSecretBundle(ctx, secretBundle, keyVaultCert.Name, isEnabled)
Expand All @@ -227,7 +240,7 @@ func (s *akvKMProvider) GetKeys(ctx context.Context) (map[keymanagementprovider.

// fetch the key object from Key Vault
startTime := time.Now()
keyResponse, err := s.keyKVClient.GetKey(ctx, keyVaultKey.Name, keyVaultKey.Version, nil)
keyResponse, err := s.keyKVClient.GetKey(ctx, keyVaultKey.Name, keyVaultKey.Version)
if err != nil {
return nil, nil, fmt.Errorf("failed to get key objectName:%s, objectVersion:%s, error: %w", keyVaultKey.Name, keyVaultKey.Version, err)
}
Expand Down Expand Up @@ -435,15 +448,17 @@ func getObjectVersion(id string) string {
}

func isSecretDisabledError(err error) bool {
var de autorest.DetailedError
if errors.As(err, &de) {
var re *azure.RequestError
if errors.As(de.Original, &re) {
if re.ServiceError.Code == "SecretDisabled" {
return true
}
logger.GetLogger(context.Background(), logOpt).Infof("Inside isSecretDisabledError ------------------")
var responseError *azcore.ResponseError
if errors.As(err, &responseError) {
// Is there a better way to check if the error is a secret disabled error?
// if responseError.ErrorCode == "Forbidden" && strings.Contains(responseError.Error(), "SecretDisabled") {
if strings.Contains(responseError.Error(), "Forbidden") && strings.Contains(responseError.Error(), "SecretDisabled") {
logger.GetLogger(context.Background(), logOpt).Infof("Leaving isSecretDisabledError returnning true ------------------")
return true
}
}
logger.GetLogger(context.Background(), logOpt).Infof("Leaving isSecretDisabledError returnning falses ------------------")
return false
}

Expand Down
Loading

0 comments on commit 8666a69

Please sign in to comment.