Skip to content

Commit

Permalink
fix: added ARN support
Browse files Browse the repository at this point in the history
  • Loading branch information
crazywolf132 committed Dec 2, 2024
1 parent 35d8c62 commit d6a14ff
Showing 1 changed file with 155 additions and 47 deletions.
202 changes: 155 additions & 47 deletions secret.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,10 @@ type Options struct {
Transformers map[string]TransformFunc
// CacheDuration specifies how long to cache values for
CacheDuration time.Duration
cacheMu sync.RWMutex
cache map[string]*cachedValue
// PreloadARNs indicates whether to preload secrets from ARNs
PreloadARNs bool
cacheMu sync.RWMutex
cache map[string]*cachedValue
}

// ValidationFunc is a function type for custom validation
Expand Down Expand Up @@ -67,6 +69,7 @@ type secret struct {
awsKey string
mu sync.RWMutex
cache *cachedValue
required bool
}

type cachedValue struct {
Expand All @@ -75,7 +78,7 @@ type cachedValue struct {
}

var (
secretsCache map[string]string
secretsCache map[string]string = make(map[string]string)
secretsMu sync.RWMutex
secretsOnce sync.Once
)
Expand Down Expand Up @@ -197,7 +200,7 @@ func parseTag(field reflect.StructField, opts *Options) (*secret, error) {
} else {
switch strings.TrimSpace(part) {
case "required":
// Required is handled during Get
s.required = true
case "base64":
s.isBase64 = true
case "json":
Expand Down Expand Up @@ -273,6 +276,13 @@ func Fetch(ctx context.Context, v interface{}, opts *Options) error {
opts.cache = make(map[string]*cachedValue)
}

// Preload secrets from ARNs if enabled
if opts.PreloadARNs {
if err := preloadSecretsFromARNs(ctx, opts); err != nil {
return fmt.Errorf("failed to preload secrets from ARNs: %w", err)
}
}

value := reflect.ValueOf(v)
if value.Kind() != reflect.Ptr {
return fmt.Errorf("v must be a pointer to a struct")
Expand Down Expand Up @@ -346,57 +356,88 @@ func (s *secret) cacheKey() string {
// The retrieved value is then processed (validated, transformed) and cached if caching is enabled.
func (s *secret) Get(ctx context.Context, opts *Options) (string, error) {
s.mu.RLock()
defer s.mu.RUnlock()

// Generate a unique cache key for this secret
cacheKey := s.cacheKey()

// Check if value exists in cache and is not expired
opts.cacheMu.RLock()
cached, ok := opts.cache[cacheKey]
opts.cacheMu.RUnlock()
if ok && time.Now().Before(cached.expiration) {
value := cached.value
opts.cacheMu.RUnlock()
return value, nil
return cached.value, nil
}
opts.cacheMu.RUnlock()

var lastErr error

// Try AWS first if enabled
if opts != nil && opts.AWS != nil && s.awsKey != "" {
awsValue, err := s.getFromAWS(ctx, opts.AWS)
if err != nil {
return "", fmt.Errorf("failed to get value from AWS: %w", err)
}
if awsValue != "" {
if s.required {
return "", fmt.Errorf("failed to get value from AWS for required field %s: %w", s.field.Name, err)
}
lastErr = fmt.Errorf("failed to get value from AWS: %w", err)
} else {
// Process and validate the value
processedValue, err := s.processValue(awsValue)
if err != nil {
return "", err
}

// Cache the processed value if caching is enabled
if opts.CacheDuration > 0 {
opts.cacheMu.Lock()
opts.cache[cacheKey] = &cachedValue{
value: processedValue,
expiration: time.Now().Add(opts.CacheDuration),
if s.required {
return "", err
}
opts.cacheMu.Unlock()
lastErr = err
} else {
// Cache the processed value if caching is enabled
if opts.CacheDuration > 0 {
opts.cacheMu.Lock()
opts.cache[cacheKey] = &cachedValue{
value: processedValue,
expiration: time.Now().Add(opts.CacheDuration),
}
opts.cacheMu.Unlock()
}
return processedValue, nil
}

return processedValue, nil
}
}

// Try environment variable if AWS lookup failed or was disabled
if s.envKey != "" {
if value := os.Getenv(s.envKey); value != "" {
if value, ok := os.LookupEnv(s.envKey); ok {
// Process and validate the value
processedValue, err := s.processValue(value)
if err != nil {
return "", err
if s.required {
return "", err
}
lastErr = err
} else {
// Cache the processed value if caching is enabled
if opts.CacheDuration > 0 {
opts.cacheMu.Lock()
opts.cache[cacheKey] = &cachedValue{
value: processedValue,
expiration: time.Now().Add(opts.CacheDuration),
}
opts.cacheMu.Unlock()
}
return processedValue, nil
}
}
}

// Cache the processed value if caching is enabled
// Use fallback value if no other source provided a value
if s.fallback != "" {
// Process and validate the fallback value
processedValue, err := s.processValue(s.fallback)
if err != nil {
if s.required {
return "", err
}
lastErr = err
} else {
// Cache the processed fallback value if caching is enabled
if opts.CacheDuration > 0 {
opts.cacheMu.Lock()
opts.cache[cacheKey] = &cachedValue{
Expand All @@ -405,33 +446,23 @@ func (s *secret) Get(ctx context.Context, opts *Options) (string, error) {
}
opts.cacheMu.Unlock()
}

return processedValue, nil
}
}

// Use fallback value if no other source provided a value
if s.fallback != "" {
// Process and validate the fallback value
processedValue, err := s.processValue(s.fallback)
if err != nil {
return "", err
// If we reach here, no value was found
if s.required {
if lastErr != nil {
return "", fmt.Errorf("no value found for required secret %s: %w", s.field.Name, lastErr)
}

// Cache the processed fallback value if caching is enabled
if opts.CacheDuration > 0 {
opts.cacheMu.Lock()
opts.cache[cacheKey] = &cachedValue{
value: processedValue,
expiration: time.Now().Add(opts.CacheDuration),
}
opts.cacheMu.Unlock()
}

return processedValue, nil
return "", fmt.Errorf("no value found for required secret %s", s.field.Name)
}

return "", fmt.Errorf("no value found for secret %s", s.field.Name)
// For non-required fields, return empty string or last error
if lastErr != nil {
return "", lastErr
}
return "", nil
}

func (s *secret) getFromAWS(ctx context.Context, awsConfig *aws.Config) (string, error) {
Expand Down Expand Up @@ -488,7 +519,7 @@ func (s *secret) getFromAWS(ctx context.Context, awsConfig *aws.Config) (string,
secretsCache[k] = v
}
secretsMu.Unlock()

// Return the specific key if it exists
if value, ok := secretMap[s.awsKey]; ok {
return value, nil
Expand All @@ -508,3 +539,80 @@ func (s *secret) getFromAWS(ctx context.Context, awsConfig *aws.Config) (string,
func FetchAndValidate(ctx context.Context, v interface{}) error {
return Fetch(ctx, v, nil)
}

// preloadSecretsFromARNs fetches secrets from AWS Secrets Manager and caches them
func preloadSecretsFromARNs(ctx context.Context, opts *Options) error {
secretArns := getSecretARNs()
if len(secretArns) == 0 {
return fmt.Errorf("no secret ARNs found in environment variables SECRET_ARNS or SECRET_ARN")
}

if opts.AWS == nil {
// Load default AWS config if not provided
cfg, err := config.LoadDefaultConfig(ctx)
if err != nil {
return fmt.Errorf("unable to load AWS config: %w", err)
}
opts.AWS = &cfg
}

client := secretsmanager.NewFromConfig(*opts.AWS)

var wg sync.WaitGroup
errorsCh := make(chan error, len(secretArns))

for _, arn := range secretArns {
wg.Add(1)
go func(arn string) {
defer wg.Done()
output, err := client.GetSecretValue(ctx, &secretsmanager.GetSecretValueInput{
SecretId: aws.String(arn),
})
if err != nil {
errorsCh <- fmt.Errorf("error fetching secret %s: %w", arn, err)
return
}

var secretPairs map[string]string
if err := json.Unmarshal([]byte(aws.ToString(output.SecretString)), &secretPairs); err != nil {
// If it's not JSON, store the raw secret string
secretPairs = map[string]string{
arn: aws.ToString(output.SecretString),
}
}

// Cache the secrets
secretsMu.Lock()
for k, v := range secretPairs {
secretsCache[k] = v
}
secretsMu.Unlock()
}(arn)
}

wg.Wait()
close(errorsCh)

// Check for errors
if len(errorsCh) > 0 {
return <-errorsCh // Return the first error encountered
}

return nil
}

// getSecretARNs returns a list of secret ARNs from environment variables
func getSecretARNs() []string {
arns := os.Getenv("SECRET_ARNS")
if arns == "" {
arns = os.Getenv("SECRET_ARN")
}
if arns == "" {
return nil
}
arnsList := strings.Split(arns, ",")
for i := range arnsList {
arnsList[i] = strings.TrimSpace(arnsList[i])
}
return arnsList
}

0 comments on commit d6a14ff

Please sign in to comment.