Skip to content

Commit

Permalink
fix: improved way of handling aws secrets
Browse files Browse the repository at this point in the history
  • Loading branch information
crazywolf132 committed Nov 29, 2024
1 parent 154446f commit 35d8c62
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 31 deletions.
42 changes: 17 additions & 25 deletions example/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import (
"context"
"fmt"
"log"
"strings"
"os"
"time"

"github.com/aws/aws-sdk-go-v2/aws"
Expand Down Expand Up @@ -48,40 +48,32 @@ type Config struct {
}

func main() {
// Set up environment variables for testing
os.Setenv("SECRET_ARN", "arn:aws:secretsmanager:region:account:secret:name")
os.Setenv("DB_HOST", "localhost")
os.Setenv("DB_PORT", "5432")
os.Setenv("DB_USER", "admin")

// Create an empty config struct
cfg := &Config{}
cfg := &DatabaseConfig{}

// Create options for customized behavior
// Create options with AWS configuration
opts := &secretfetch.Options{
AWS: &aws.Config{
Region: "us-west-2",
},
Validators: map[string]secretfetch.ValidationFunc{
"api_key": func(s string) error {
if len(s) != 32 {
return fmt.Errorf("API key must be 32 characters long")
}
return nil
},
},
Transformers: map[string]secretfetch.TransformFunc{
"uppercase": func(s string) (string, error) {
return strings.ToUpper(s), nil
},
Region: "us-west-2", // This can also be set via AWS_REGION environment variable
},
}

// Populate all secrets with advanced features
// Fetch all secrets
ctx := context.Background()
if err := secretfetch.Fetch(ctx, cfg, opts); err != nil {
log.Fatalf("Failed to fetch secrets: %v", err)
}

// Use your fully populated config!
fmt.Printf("Environment: %s\n", cfg.Environment)
fmt.Printf("Database Config: %+v\n", cfg.Database)
fmt.Printf("Session Timeout: %v\n", cfg.SessionTimeout)
fmt.Printf("Max Connections: %d\n", cfg.MaxConnections)
fmt.Printf("Certificate Length: %d bytes\n", len(cfg.Certificate))
fmt.Printf("Raw Data Length: %d bytes\n", len(cfg.RawData))
// Use the configuration
fmt.Printf("Database Configuration:\n")
fmt.Printf("Host: %s\n", cfg.Host)
fmt.Printf("Port: %d\n", cfg.Port)
fmt.Printf("Username: %s\n", cfg.Username)
fmt.Printf("Password: [REDACTED]\n")
}
69 changes: 63 additions & 6 deletions secret.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package secretfetch
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"os"
"reflect"
Expand Down Expand Up @@ -73,6 +74,12 @@ type cachedValue struct {
expiration time.Time
}

var (
secretsCache map[string]string
secretsMu sync.RWMutex
secretsOnce sync.Once
)

func validatePattern(value, pattern string) error {
re, err := regexp.Compile(pattern)
if err != nil {
Expand Down Expand Up @@ -428,23 +435,73 @@ func (s *secret) Get(ctx context.Context, opts *Options) (string, error) {
}

func (s *secret) getFromAWS(ctx context.Context, awsConfig *aws.Config) (string, error) {
if s.awsKey == "" {
return "", nil
}

// Try to get from cache first
secretsMu.RLock()
if value, ok := secretsCache[s.awsKey]; ok {
secretsMu.RUnlock()
return value, nil
}
secretsMu.RUnlock()

// Initialize cache if needed
secretsOnce.Do(func() {
secretsCache = make(map[string]string)
})

// Load AWS config
cfg, err := config.LoadDefaultConfig(ctx, func(o *config.LoadOptions) error {
o.Region = awsConfig.Region
o.Credentials = awsConfig.Credentials
if awsConfig != nil {
o.Region = awsConfig.Region
o.Credentials = awsConfig.Credentials
}
return nil
})
if err != nil {
return "", err
return "", fmt.Errorf("unable to load AWS config: %w", err)
}

// Create client and fetch secret
client := secretsmanager.NewFromConfig(cfg)
input := &secretsmanager.GetSecretValueInput{
SecretId: aws.String(s.awsKey),
}
if result, err := client.GetSecretValue(ctx, input); err == nil {
return *result.SecretString, nil

result, err := client.GetSecretValue(ctx, input)
if err != nil {
return "", fmt.Errorf("failed to get AWS secret %s: %w", s.awsKey, err)
}

if result.SecretString == nil {
return "", fmt.Errorf("no secret string found for %s", s.awsKey)
}

// Try to parse as JSON first
var secretMap map[string]string
if err := json.Unmarshal([]byte(*result.SecretString), &secretMap); err == nil {
// If successful, cache all values
secretsMu.Lock()
for k, v := range secretMap {
secretsCache[k] = v
}
secretsMu.Unlock()

// Return the specific key if it exists
if value, ok := secretMap[s.awsKey]; ok {
return value, nil
}
// If key not found in JSON, use the entire string
}
return "", err

// Cache and return the raw string
secretsMu.Lock()
secretsCache[s.awsKey] = *result.SecretString
secretsMu.Unlock()

return *result.SecretString, nil
}

// FetchAndValidate is an alias for Fetch to maintain backward compatibility
Expand Down

0 comments on commit 35d8c62

Please sign in to comment.