From d6a14fffca7fc0dc4435c8128dde60f340358262 Mon Sep 17 00:00:00 2001 From: crazywolf132 Date: Mon, 2 Dec 2024 14:52:20 +1100 Subject: [PATCH] fix: added ARN support --- secret.go | 202 +++++++++++++++++++++++++++++++++++++++++------------- 1 file changed, 155 insertions(+), 47 deletions(-) diff --git a/secret.go b/secret.go index c3b99f4..00392c3 100644 --- a/secret.go +++ b/secret.go @@ -31,8 +31,10 @@ type Options struct { Transformers map[string]TransformFunc // CacheDuration specifies how long to cache values for CacheDuration time.Duration - cacheMu sync.RWMutex - cache map[string]*cachedValue + // PreloadARNs indicates whether to preload secrets from ARNs + PreloadARNs bool + cacheMu sync.RWMutex + cache map[string]*cachedValue } // ValidationFunc is a function type for custom validation @@ -67,6 +69,7 @@ type secret struct { awsKey string mu sync.RWMutex cache *cachedValue + required bool } type cachedValue struct { @@ -75,7 +78,7 @@ type cachedValue struct { } var ( - secretsCache map[string]string + secretsCache map[string]string = make(map[string]string) secretsMu sync.RWMutex secretsOnce sync.Once ) @@ -197,7 +200,7 @@ func parseTag(field reflect.StructField, opts *Options) (*secret, error) { } else { switch strings.TrimSpace(part) { case "required": - // Required is handled during Get + s.required = true case "base64": s.isBase64 = true case "json": @@ -273,6 +276,13 @@ func Fetch(ctx context.Context, v interface{}, opts *Options) error { opts.cache = make(map[string]*cachedValue) } + // Preload secrets from ARNs if enabled + if opts.PreloadARNs { + if err := preloadSecretsFromARNs(ctx, opts); err != nil { + return fmt.Errorf("failed to preload secrets from ARNs: %w", err) + } + } + value := reflect.ValueOf(v) if value.Kind() != reflect.Ptr { return fmt.Errorf("v must be a pointer to a struct") @@ -346,6 +356,7 @@ func (s *secret) cacheKey() string { // The retrieved value is then processed (validated, transformed) and cached if caching is enabled. func (s *secret) Get(ctx context.Context, opts *Options) (string, error) { s.mu.RLock() + defer s.mu.RUnlock() // Generate a unique cache key for this secret cacheKey := s.cacheKey() @@ -353,50 +364,80 @@ func (s *secret) Get(ctx context.Context, opts *Options) (string, error) { // Check if value exists in cache and is not expired opts.cacheMu.RLock() cached, ok := opts.cache[cacheKey] + opts.cacheMu.RUnlock() if ok && time.Now().Before(cached.expiration) { - value := cached.value - opts.cacheMu.RUnlock() - return value, nil + return cached.value, nil } - opts.cacheMu.RUnlock() + + var lastErr error // Try AWS first if enabled if opts != nil && opts.AWS != nil && s.awsKey != "" { awsValue, err := s.getFromAWS(ctx, opts.AWS) if err != nil { - return "", fmt.Errorf("failed to get value from AWS: %w", err) - } - if awsValue != "" { + if s.required { + return "", fmt.Errorf("failed to get value from AWS for required field %s: %w", s.field.Name, err) + } + lastErr = fmt.Errorf("failed to get value from AWS: %w", err) + } else { // Process and validate the value processedValue, err := s.processValue(awsValue) if err != nil { - return "", err - } - - // Cache the processed value if caching is enabled - if opts.CacheDuration > 0 { - opts.cacheMu.Lock() - opts.cache[cacheKey] = &cachedValue{ - value: processedValue, - expiration: time.Now().Add(opts.CacheDuration), + if s.required { + return "", err } - opts.cacheMu.Unlock() + lastErr = err + } else { + // Cache the processed value if caching is enabled + if opts.CacheDuration > 0 { + opts.cacheMu.Lock() + opts.cache[cacheKey] = &cachedValue{ + value: processedValue, + expiration: time.Now().Add(opts.CacheDuration), + } + opts.cacheMu.Unlock() + } + return processedValue, nil } - - return processedValue, nil } } // Try environment variable if AWS lookup failed or was disabled if s.envKey != "" { - if value := os.Getenv(s.envKey); value != "" { + if value, ok := os.LookupEnv(s.envKey); ok { // Process and validate the value processedValue, err := s.processValue(value) if err != nil { - return "", err + if s.required { + return "", err + } + lastErr = err + } else { + // Cache the processed value if caching is enabled + if opts.CacheDuration > 0 { + opts.cacheMu.Lock() + opts.cache[cacheKey] = &cachedValue{ + value: processedValue, + expiration: time.Now().Add(opts.CacheDuration), + } + opts.cacheMu.Unlock() + } + return processedValue, nil } + } + } - // Cache the processed value if caching is enabled + // Use fallback value if no other source provided a value + if s.fallback != "" { + // Process and validate the fallback value + processedValue, err := s.processValue(s.fallback) + if err != nil { + if s.required { + return "", err + } + lastErr = err + } else { + // Cache the processed fallback value if caching is enabled if opts.CacheDuration > 0 { opts.cacheMu.Lock() opts.cache[cacheKey] = &cachedValue{ @@ -405,33 +446,23 @@ func (s *secret) Get(ctx context.Context, opts *Options) (string, error) { } opts.cacheMu.Unlock() } - return processedValue, nil } } - // Use fallback value if no other source provided a value - if s.fallback != "" { - // Process and validate the fallback value - processedValue, err := s.processValue(s.fallback) - if err != nil { - return "", err + // If we reach here, no value was found + if s.required { + if lastErr != nil { + return "", fmt.Errorf("no value found for required secret %s: %w", s.field.Name, lastErr) } - - // Cache the processed fallback value if caching is enabled - if opts.CacheDuration > 0 { - opts.cacheMu.Lock() - opts.cache[cacheKey] = &cachedValue{ - value: processedValue, - expiration: time.Now().Add(opts.CacheDuration), - } - opts.cacheMu.Unlock() - } - - return processedValue, nil + return "", fmt.Errorf("no value found for required secret %s", s.field.Name) } - return "", fmt.Errorf("no value found for secret %s", s.field.Name) + // For non-required fields, return empty string or last error + if lastErr != nil { + return "", lastErr + } + return "", nil } func (s *secret) getFromAWS(ctx context.Context, awsConfig *aws.Config) (string, error) { @@ -488,7 +519,7 @@ func (s *secret) getFromAWS(ctx context.Context, awsConfig *aws.Config) (string, secretsCache[k] = v } secretsMu.Unlock() - + // Return the specific key if it exists if value, ok := secretMap[s.awsKey]; ok { return value, nil @@ -508,3 +539,80 @@ func (s *secret) getFromAWS(ctx context.Context, awsConfig *aws.Config) (string, func FetchAndValidate(ctx context.Context, v interface{}) error { return Fetch(ctx, v, nil) } + +// preloadSecretsFromARNs fetches secrets from AWS Secrets Manager and caches them +func preloadSecretsFromARNs(ctx context.Context, opts *Options) error { + secretArns := getSecretARNs() + if len(secretArns) == 0 { + return fmt.Errorf("no secret ARNs found in environment variables SECRET_ARNS or SECRET_ARN") + } + + if opts.AWS == nil { + // Load default AWS config if not provided + cfg, err := config.LoadDefaultConfig(ctx) + if err != nil { + return fmt.Errorf("unable to load AWS config: %w", err) + } + opts.AWS = &cfg + } + + client := secretsmanager.NewFromConfig(*opts.AWS) + + var wg sync.WaitGroup + errorsCh := make(chan error, len(secretArns)) + + for _, arn := range secretArns { + wg.Add(1) + go func(arn string) { + defer wg.Done() + output, err := client.GetSecretValue(ctx, &secretsmanager.GetSecretValueInput{ + SecretId: aws.String(arn), + }) + if err != nil { + errorsCh <- fmt.Errorf("error fetching secret %s: %w", arn, err) + return + } + + var secretPairs map[string]string + if err := json.Unmarshal([]byte(aws.ToString(output.SecretString)), &secretPairs); err != nil { + // If it's not JSON, store the raw secret string + secretPairs = map[string]string{ + arn: aws.ToString(output.SecretString), + } + } + + // Cache the secrets + secretsMu.Lock() + for k, v := range secretPairs { + secretsCache[k] = v + } + secretsMu.Unlock() + }(arn) + } + + wg.Wait() + close(errorsCh) + + // Check for errors + if len(errorsCh) > 0 { + return <-errorsCh // Return the first error encountered + } + + return nil +} + +// getSecretARNs returns a list of secret ARNs from environment variables +func getSecretARNs() []string { + arns := os.Getenv("SECRET_ARNS") + if arns == "" { + arns = os.Getenv("SECRET_ARN") + } + if arns == "" { + return nil + } + arnsList := strings.Split(arns, ",") + for i := range arnsList { + arnsList[i] = strings.TrimSpace(arnsList[i]) + } + return arnsList +}