Skip to content

Commit

Permalink
chore: ensure the s3 object key don't have a prefix /
Browse files Browse the repository at this point in the history
  • Loading branch information
chenxin0723 committed Nov 22, 2024
1 parent ba74b80 commit 883ebc1
Showing 1 changed file with 73 additions and 32 deletions.
105 changes: 73 additions & 32 deletions s3/s3.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/aws/aws-sdk-go-v2/service/s3/types"
"github.com/aws/aws-sdk-go-v2/service/sts"
smithyendpoints "github.com/aws/smithy-go/endpoints"
"github.com/qor/oss"
)

Expand Down Expand Up @@ -58,6 +59,20 @@ func New(cfg *Config) *Client {

client := &Client{Config: cfg}

s3CfgOptions := []func(o *s3.Options){
func(o *s3.Options) {
o.Region = cfg.Region
o.UsePathStyle = cfg.S3ForcePathStyle
},
}

if cfg.S3Endpoint != "" {
s3CfgOptions = append(s3CfgOptions, s3.WithEndpointResolverV2(&endpointResolverV2{
Url: cfg.S3Endpoint,
}))

}

// use role ARN to fetch credentials
if cfg.RoleARN != "" {
awsCfg, err := config.LoadDefaultConfig(context.TODO())
Expand All @@ -68,59 +83,46 @@ func New(cfg *Config) *Client {
provider := stscreds.NewAssumeRoleProvider(sts.NewFromConfig(awsCfg), cfg.RoleARN)
creds := aws.NewCredentialsCache(provider)

s3Client := s3.NewFromConfig(awsCfg, func(o *s3.Options) {
o.Region = cfg.Region
o.BaseEndpoint = aws.String(cfg.S3Endpoint)
o.UsePathStyle = cfg.S3ForcePathStyle

s3CfgOptions = append(s3CfgOptions, func(o *s3.Options) {
o.Credentials = creds
})

s3Client := s3.NewFromConfig(awsCfg, s3CfgOptions...)
client.S3 = s3Client
return client
}

// use alreay configured aws config
if cfg.AwsConfig != nil {
s3Client := s3.NewFromConfig(*cfg.AwsConfig, func(o *s3.Options) {
o.Region = cfg.Region
o.BaseEndpoint = aws.String(cfg.S3Endpoint)
o.UsePathStyle = cfg.S3ForcePathStyle
})

s3Client := s3.NewFromConfig(*cfg.AwsConfig, s3CfgOptions...)
client.S3 = s3Client
return client
}

cfgOptions := []func(*config.LoadOptions) error{
aswCfgOptions := []func(*config.LoadOptions) error{
config.WithRegion(cfg.Region),
}

// use EC2 IAM role
if cfg.EnableEC2IAMRole {
cfgOptions = append(cfgOptions, config.WithCredentialsProvider(
aswCfgOptions = append(aswCfgOptions, config.WithCredentialsProvider(
ec2rolecreds.New(),
))
}

// use static credentials
if cfg.AccessID != "" && cfg.AccessKey != "" {
cfgOptions = append(cfgOptions, config.WithCredentialsProvider(
aswCfgOptions = append(aswCfgOptions, config.WithCredentialsProvider(
credentials.NewStaticCredentialsProvider(cfg.AccessID, cfg.AccessKey, cfg.SessionToken),
))
}

awsConfig, err := config.LoadDefaultConfig(context.TODO(), cfgOptions...)
awsConfig, err := config.LoadDefaultConfig(context.TODO(), aswCfgOptions...)
if err != nil {
panic(err)
}

s3Client := s3.NewFromConfig(awsConfig, func(o *s3.Options) {
o.Region = cfg.Region
o.BaseEndpoint = aws.String(cfg.S3Endpoint)
o.UsePathStyle = cfg.S3ForcePathStyle
})

s3Client := s3.NewFromConfig(awsConfig, s3CfgOptions...)
client.S3 = s3Client
return client
}
Expand All @@ -147,7 +149,7 @@ func (client Client) Get(path string) (file *os.File, err error) {
func (client Client) GetStream(path string) (io.ReadCloser, error) {
getResponse, err := client.S3.GetObject(context.TODO(), &s3.GetObjectInput{
Bucket: aws.String(client.Config.Bucket),
Key: aws.String(client.ToRelativePath(path)),
Key: aws.String(client.ToS3Key(path)),
})

return getResponse.Body, err
Expand All @@ -159,7 +161,7 @@ func (client Client) Put(urlPath string, reader io.Reader) (*oss.Object, error)
seeker.Seek(0, 0)
}

urlPath = client.ToRelativePath(urlPath)
key := client.ToS3Key(urlPath)
buffer, err := io.ReadAll(reader)

fileType := mime.TypeByExtension(path.Ext(urlPath))
Expand All @@ -169,7 +171,7 @@ func (client Client) Put(urlPath string, reader io.Reader) (*oss.Object, error)

params := &s3.PutObjectInput{
Bucket: aws.String(client.Config.Bucket), // required
Key: aws.String(urlPath), // required
Key: aws.String(key), // required
ACL: client.Config.ACL,
Body: bytes.NewReader(buffer),
ContentLength: aws.Int64(int64(len(buffer))),
Expand All @@ -194,7 +196,7 @@ func (client Client) Put(urlPath string, reader io.Reader) (*oss.Object, error)
func (client Client) Delete(path string) error {
_, err := client.S3.DeleteObject(context.Background(), &s3.DeleteObjectInput{
Bucket: aws.String(client.Config.Bucket),
Key: aws.String(client.ToRelativePath(path)),
Key: aws.String(client.ToS3Key(path)),
})
return err
}
Expand All @@ -204,7 +206,7 @@ func (client Client) DeleteObjects(paths []string) (err error) {
var objs []types.ObjectIdentifier
for _, v := range paths {
var obj types.ObjectIdentifier
obj.Key = aws.String(strings.TrimPrefix(client.ToRelativePath(v), "/"))
obj.Key = aws.String(strings.TrimPrefix(client.ToS3Key(v), "/"))
objs = append(objs, obj)
}
input := &s3.DeleteObjectsInput{
Expand Down Expand Up @@ -238,7 +240,7 @@ func (client Client) List(path string) ([]*oss.Object, error) {
if err == nil {
for _, content := range listObjectsResponse.Contents {
objects = append(objects, &oss.Object{
Path: client.ToRelativePath(*content.Key),
Path: client.ToS3Key(*content.Key),
Name: filepath.Base(*content.Key),
LastModified: content.LastModified,
StorageInterface: client,
Expand All @@ -255,16 +257,36 @@ func (client Client) GetEndpoint() string {
return client.Config.Endpoint
}

endpoint := *client.S3.Options().BaseEndpoint
for _, prefix := range []string{"https://", "http://"} {
endpoint = strings.TrimPrefix(endpoint, prefix)
if client.Config.S3Endpoint != "" {
return client.Config.S3Endpoint
}

if client.Config.S3ForcePathStyle {
return fmt.Sprintf("s3.%s.amazonaws.com/%s", client.Config.Region, client.Config.Bucket)
}

return client.Config.Bucket + "." + endpoint
return fmt.Sprintf("%s.s3.%s.amazonaws.com", client.Config.Bucket, client.Config.Region)
}

var urlRegexp = regexp.MustCompile(`(https?:)?//((\w+).)+(\w+)/`)

// ToS3Key convert URL path to S3 key
func (client Client) ToS3Key(urlPath string) string {
if urlRegexp.MatchString(urlPath) {
if u, err := url.Parse(urlPath); err == nil {
if client.Config.S3ForcePathStyle { // First part of path will be bucket name
return strings.TrimPrefix(u.Path, "/"+client.Config.Bucket)
}
return strings.TrimPrefix(u.Path, "/")
}
}

if client.Config.S3ForcePathStyle { // First part of path will be bucket name
return strings.TrimPrefix(urlPath, "/"+client.Config.Bucket+"/")
}
return strings.TrimPrefix(urlPath, "/")
}

// ToRelativePath process path to relative path
func (client Client) ToRelativePath(urlPath string) string {
if urlRegexp.MatchString(urlPath) {
Expand All @@ -291,7 +313,7 @@ func (client Client) GetURL(path string) (url string, err error) {
presignClient := s3.NewPresignClient(client.S3)
presignedGetURL, err := presignClient.PresignGetObject(context.TODO(), &s3.GetObjectInput{
Bucket: aws.String(client.Config.Bucket),
Key: aws.String(client.ToRelativePath(path)),
Key: aws.String(client.ToS3Key(path)),
}, func(opts *s3.PresignOptions) {
opts.Expires = 1 * time.Hour
})
Expand All @@ -314,3 +336,22 @@ func (client Client) Copy(from, to string) (err error) {
})
return
}

type endpointResolverV2 struct {
Url string
}

func (r *endpointResolverV2) ResolveEndpoint(
ctx context.Context, params s3.EndpointParameters,
) (
endpoint smithyendpoints.Endpoint, err error,
) {

u, err := url.Parse(r.Url)
if err != nil {
return smithyendpoints.Endpoint{}, err
}
return smithyendpoints.Endpoint{
URI: *u,
}, nil
}

0 comments on commit 883ebc1

Please sign in to comment.