diff --git a/example/main.go b/example/main.go index 846609a..b430f95 100644 --- a/example/main.go +++ b/example/main.go @@ -5,7 +5,7 @@ import ( "context" "fmt" "log" - "strings" + "os" "time" "github.com/aws/aws-sdk-go-v2/aws" @@ -48,40 +48,32 @@ type Config struct { } func main() { + // Set up environment variables for testing + os.Setenv("SECRET_ARN", "arn:aws:secretsmanager:region:account:secret:name") + os.Setenv("DB_HOST", "localhost") + os.Setenv("DB_PORT", "5432") + os.Setenv("DB_USER", "admin") + // Create an empty config struct - cfg := &Config{} + cfg := &DatabaseConfig{} - // Create options for customized behavior + // Create options with AWS configuration opts := &secretfetch.Options{ AWS: &aws.Config{ - Region: "us-west-2", - }, - Validators: map[string]secretfetch.ValidationFunc{ - "api_key": func(s string) error { - if len(s) != 32 { - return fmt.Errorf("API key must be 32 characters long") - } - return nil - }, - }, - Transformers: map[string]secretfetch.TransformFunc{ - "uppercase": func(s string) (string, error) { - return strings.ToUpper(s), nil - }, + Region: "us-west-2", // This can also be set via AWS_REGION environment variable }, } - // Populate all secrets with advanced features + // Fetch all secrets ctx := context.Background() if err := secretfetch.Fetch(ctx, cfg, opts); err != nil { log.Fatalf("Failed to fetch secrets: %v", err) } - // Use your fully populated config! - fmt.Printf("Environment: %s\n", cfg.Environment) - fmt.Printf("Database Config: %+v\n", cfg.Database) - fmt.Printf("Session Timeout: %v\n", cfg.SessionTimeout) - fmt.Printf("Max Connections: %d\n", cfg.MaxConnections) - fmt.Printf("Certificate Length: %d bytes\n", len(cfg.Certificate)) - fmt.Printf("Raw Data Length: %d bytes\n", len(cfg.RawData)) + // Use the configuration + fmt.Printf("Database Configuration:\n") + fmt.Printf("Host: %s\n", cfg.Host) + fmt.Printf("Port: %d\n", cfg.Port) + fmt.Printf("Username: %s\n", cfg.Username) + fmt.Printf("Password: [REDACTED]\n") } diff --git a/secret.go b/secret.go index bafce41..c3b99f4 100644 --- a/secret.go +++ b/secret.go @@ -6,6 +6,7 @@ package secretfetch import ( "context" "encoding/base64" + "encoding/json" "fmt" "os" "reflect" @@ -73,6 +74,12 @@ type cachedValue struct { expiration time.Time } +var ( + secretsCache map[string]string + secretsMu sync.RWMutex + secretsOnce sync.Once +) + func validatePattern(value, pattern string) error { re, err := regexp.Compile(pattern) if err != nil { @@ -428,23 +435,73 @@ func (s *secret) Get(ctx context.Context, opts *Options) (string, error) { } func (s *secret) getFromAWS(ctx context.Context, awsConfig *aws.Config) (string, error) { + if s.awsKey == "" { + return "", nil + } + + // Try to get from cache first + secretsMu.RLock() + if value, ok := secretsCache[s.awsKey]; ok { + secretsMu.RUnlock() + return value, nil + } + secretsMu.RUnlock() + + // Initialize cache if needed + secretsOnce.Do(func() { + secretsCache = make(map[string]string) + }) + + // Load AWS config cfg, err := config.LoadDefaultConfig(ctx, func(o *config.LoadOptions) error { - o.Region = awsConfig.Region - o.Credentials = awsConfig.Credentials + if awsConfig != nil { + o.Region = awsConfig.Region + o.Credentials = awsConfig.Credentials + } return nil }) if err != nil { - return "", err + return "", fmt.Errorf("unable to load AWS config: %w", err) } + // Create client and fetch secret client := secretsmanager.NewFromConfig(cfg) input := &secretsmanager.GetSecretValueInput{ SecretId: aws.String(s.awsKey), } - if result, err := client.GetSecretValue(ctx, input); err == nil { - return *result.SecretString, nil + + result, err := client.GetSecretValue(ctx, input) + if err != nil { + return "", fmt.Errorf("failed to get AWS secret %s: %w", s.awsKey, err) + } + + if result.SecretString == nil { + return "", fmt.Errorf("no secret string found for %s", s.awsKey) + } + + // Try to parse as JSON first + var secretMap map[string]string + if err := json.Unmarshal([]byte(*result.SecretString), &secretMap); err == nil { + // If successful, cache all values + secretsMu.Lock() + for k, v := range secretMap { + secretsCache[k] = v + } + secretsMu.Unlock() + + // Return the specific key if it exists + if value, ok := secretMap[s.awsKey]; ok { + return value, nil + } + // If key not found in JSON, use the entire string } - return "", err + + // Cache and return the raw string + secretsMu.Lock() + secretsCache[s.awsKey] = *result.SecretString + secretsMu.Unlock() + + return *result.SecretString, nil } // FetchAndValidate is an alias for Fetch to maintain backward compatibility