Skip to content

Commit

Permalink
Check X-Amz header for unencrypted content length when using KMS
Browse files Browse the repository at this point in the history
  • Loading branch information
jonyoder committed Dec 12, 2023
1 parent 182af43 commit fcc63fc
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 2 deletions.
4 changes: 4 additions & 0 deletions pkg/rsstorage/servers/s3server/s3_encrypted_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,3 +115,7 @@ func (s *encryptedS3Service) GetObject(input *s3.GetObjectInput) (*s3.GetObjectO
}
return out, err
}

func (s *encryptedS3Service) KmsEncrypted() bool {
return true
}
5 changes: 5 additions & 0 deletions pkg/rsstorage/servers/s3server/s3_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ type S3Wrapper interface {
CopyObject(bucket, key, newBucket, newKey string) (*s3.CopyObjectOutput, error)
MoveObject(bucket, key, newBucket, newKey string) (*s3.CopyObjectOutput, error)
ListObjects(bucket, prefix string) ([]string, error)
KmsEncrypted() bool
}

type defaultS3Wrapper struct {
Expand Down Expand Up @@ -57,6 +58,10 @@ func NewS3Wrapper(configInput *rsstorage.ConfigS3, keyID string) (S3Wrapper, err
}, nil
}

func (s *defaultS3Wrapper) KmsEncrypted() bool {
return false
}

func (s *defaultS3Wrapper) CreateBucket(input *s3.CreateBucketInput) (*s3.CreateBucketOutput, error) {
svc := s3.New(s.session)
out, err := svc.CreateBucket(input)
Expand Down
27 changes: 25 additions & 2 deletions pkg/rsstorage/servers/s3server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"io"
"net/url"
"path/filepath"
"strconv"
"strings"
"time"

Expand All @@ -23,6 +24,8 @@ import (
"github.com/rstudio/platform-lib/pkg/rsstorage/types"
)

const AmzUnencryptedContentLengthHeader = "X-Amz-Unencrypted-Content-Length"

type moveOrCopyFn func(bucket, key, newBucket, newKey string) (*s3.CopyObjectOutput, error)

type StorageServer struct {
Expand Down Expand Up @@ -107,6 +110,7 @@ func (s *StorageServer) Validate() error {

func (s *StorageServer) Check(dir, address string) (bool, *types.ChunksInfo, int64, time.Time, error) {
var chunked bool
var contentLength int64
addr := internal.NotEmptyJoin([]string{s.prefix, dir, address}, "/")
infoAddr := filepath.Join(addr, "info.json")
resp, err := s.svc.HeadObject(&s3.HeadObjectInput{Bucket: aws.String(s.bucket), Key: aws.String(addr)})
Expand Down Expand Up @@ -152,8 +156,17 @@ func (s *StorageServer) Check(dir, address string) (bool, *types.ChunksInfo, int
}
return true, &info, int64(info.FileSize), info.ModTime, nil
} else {
// Check some headers for the unencrypted content length for KMS encrypted objects.
if s.svc.KmsEncrypted() {
if cl, ok := resp.Metadata[AmzUnencryptedContentLengthHeader]; ok {
contentLength, _ = strconv.ParseInt(*cl, 10, 64)
}
} else {
contentLength = *resp.ContentLength
}

// For standard assets, the HeadObject response has the information we need.
return true, nil, *resp.ContentLength, *resp.LastModified, nil
return true, nil, contentLength, *resp.LastModified, nil
}
}

Expand All @@ -172,6 +185,7 @@ func (s *StorageServer) CalculateUsage() (types.Usage, error) {

func (s *StorageServer) Get(dir, address string) (io.ReadCloser, *types.ChunksInfo, int64, time.Time, bool, error) {
var chunked bool
var contentLength int64
addr := internal.NotEmptyJoin([]string{s.prefix, dir, address}, "/")
infoAddr := filepath.Join(addr, "info.json")
resp, err := s.svc.GetObject(&s3.GetObjectInput{Bucket: aws.String(s.bucket), Key: aws.String(addr)})
Expand Down Expand Up @@ -208,8 +222,17 @@ func (s *StorageServer) Get(dir, address string) (io.ReadCloser, *types.ChunksIn
}
return r, c, sz, mod, true, nil
} else {
// Check some headers for the unencrypted content length for KMS encrypted objects.
if s.svc.KmsEncrypted() {
if cl, ok := resp.Metadata[AmzUnencryptedContentLengthHeader]; ok {
contentLength, _ = strconv.ParseInt(*cl, 10, 64)
}
} else {
contentLength = *resp.ContentLength
}

// For standard assets, the GetObject response has the information we need.
return resp.Body, nil, *resp.ContentLength, *resp.LastModified, true, nil
return resp.Body, nil, contentLength, *resp.LastModified, true, nil
}
}

Expand Down
4 changes: 4 additions & 0 deletions pkg/rsstorage/servers/s3server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,10 @@ type fakeS3 struct {
delBucketOut *s3.DeleteBucketOutput
}

func (s *fakeS3) KmsEncrypted() bool {
return false
}

func (s *fakeS3) CreateBucket(input *s3.CreateBucketInput) (*s3.CreateBucketOutput, error) {
if s.bucketErr == nil {
s.bucketIn = input
Expand Down

0 comments on commit fcc63fc

Please sign in to comment.