diff --git a/storage/s3.go b/storage/s3.go index 1b07f3ced..d6bd0e3d0 100644 --- a/storage/s3.go +++ b/storage/s3.go @@ -62,29 +62,43 @@ type S3 struct { partSize int64 } -type s3Opt func(s3 *S3) +type s3Opt func(s3 *S3) error // WithRetryPolicy is an option to constructor NewS3 to add a Retry Policy // impacting GET operations func WithRetryPolicy(policy RetryPolicy) s3Opt { - return s3Opt(func(s3 *S3) { + return s3Opt(func(s3 *S3) error { s3.retryPolicy = policy + return nil }) } func WithPartSize(size int64) s3Opt { - return s3Opt(func(s3 *S3) { + return s3Opt(func(s3 *S3) error { s3.partSize = size + return nil }) } func WithUploadConcurrency(concurrency int) s3Opt { - return s3Opt(func(s3 *S3) { + return s3Opt(func(s3 *S3) error { s3.uploadConcurrency = concurrency + return nil + }) +} + +func WithAllowList(l []string) s3Opt { + return s3Opt(func(s3 *S3) error { + for _, url := range l { + if strings.HasPrefix(s3.cfg.Endpoint, url) { + return nil + } + } + return errors.New("endpoint is not in allow list") }) } -func NewS3(cfg S3Config, opts ...s3Opt) *S3 { +func NewS3(cfg S3Config, opts ...s3Opt) (*S3, error) { s3config := s3Config(cfg) s3client := s3.NewFromConfig(s3config) s3 := &S3{ @@ -99,7 +113,10 @@ func NewS3(cfg S3Config, opts ...s3Opt) *S3 { }, } for _, opt := range opts { - opt(s3) + err := opt(s3) + if err != nil { + return nil, err + } } partSize := DefaultPartSize @@ -114,7 +131,7 @@ func NewS3(cfg S3Config, opts ...s3Opt) *S3 { u.PartSize = partSize u.Concurrency = concurrency }) - return s3 + return s3, nil } func (s *S3) Get(ctx context.Context, path string) (io.ReadCloser, error) {