diff --git a/pkg/rsstorage/servers/s3server/s3_encrypted_service.go b/pkg/rsstorage/servers/s3server/s3_encrypted_service.go index 83b0025..faec1d0 100644 --- a/pkg/rsstorage/servers/s3server/s3_encrypted_service.go +++ b/pkg/rsstorage/servers/s3server/s3_encrypted_service.go @@ -115,3 +115,7 @@ func (s *encryptedS3Service) GetObject(input *s3.GetObjectInput) (*s3.GetObjectO } return out, err } + +func (s *encryptedS3Service) KmsEncrypted() bool { + return true +} diff --git a/pkg/rsstorage/servers/s3server/s3_service.go b/pkg/rsstorage/servers/s3server/s3_service.go index c6bc6b7..ce4d324 100644 --- a/pkg/rsstorage/servers/s3server/s3_service.go +++ b/pkg/rsstorage/servers/s3server/s3_service.go @@ -30,6 +30,7 @@ type S3Wrapper interface { CopyObject(bucket, key, newBucket, newKey string) (*s3.CopyObjectOutput, error) MoveObject(bucket, key, newBucket, newKey string) (*s3.CopyObjectOutput, error) ListObjects(bucket, prefix string) ([]string, error) + KmsEncrypted() bool } type defaultS3Wrapper struct { @@ -57,6 +58,10 @@ func NewS3Wrapper(configInput *rsstorage.ConfigS3, keyID string) (S3Wrapper, err }, nil } +func (s *defaultS3Wrapper) KmsEncrypted() bool { + return false +} + func (s *defaultS3Wrapper) CreateBucket(input *s3.CreateBucketInput) (*s3.CreateBucketOutput, error) { svc := s3.New(s.session) out, err := svc.CreateBucket(input) diff --git a/pkg/rsstorage/servers/s3server/server.go b/pkg/rsstorage/servers/s3server/server.go index 409c377..4530f80 100644 --- a/pkg/rsstorage/servers/s3server/server.go +++ b/pkg/rsstorage/servers/s3server/server.go @@ -9,6 +9,7 @@ import ( "io" "net/url" "path/filepath" + "strconv" "strings" "time" @@ -23,6 +24,8 @@ import ( "github.com/rstudio/platform-lib/pkg/rsstorage/types" ) +const AmzUnencryptedContentLengthHeader = "X-Amz-Unencrypted-Content-Length" + type moveOrCopyFn func(bucket, key, newBucket, newKey string) (*s3.CopyObjectOutput, error) type StorageServer struct { @@ -107,6 +110,7 @@ func (s *StorageServer) Validate() error { func (s *StorageServer) Check(dir, address string) (bool, *types.ChunksInfo, int64, time.Time, error) { var chunked bool + var contentLength int64 addr := internal.NotEmptyJoin([]string{s.prefix, dir, address}, "/") infoAddr := filepath.Join(addr, "info.json") resp, err := s.svc.HeadObject(&s3.HeadObjectInput{Bucket: aws.String(s.bucket), Key: aws.String(addr)}) @@ -152,8 +156,17 @@ func (s *StorageServer) Check(dir, address string) (bool, *types.ChunksInfo, int } return true, &info, int64(info.FileSize), info.ModTime, nil } else { + // Check some headers for the unencrypted content length for KMS encrypted objects. + if s.svc.KmsEncrypted() { + if cl, ok := resp.Metadata[AmzUnencryptedContentLengthHeader]; ok { + contentLength, _ = strconv.ParseInt(*cl, 10, 64) + } + } else { + contentLength = *resp.ContentLength + } + // For standard assets, the HeadObject response has the information we need. - return true, nil, *resp.ContentLength, *resp.LastModified, nil + return true, nil, contentLength, *resp.LastModified, nil } } @@ -172,6 +185,7 @@ func (s *StorageServer) CalculateUsage() (types.Usage, error) { func (s *StorageServer) Get(dir, address string) (io.ReadCloser, *types.ChunksInfo, int64, time.Time, bool, error) { var chunked bool + var contentLength int64 addr := internal.NotEmptyJoin([]string{s.prefix, dir, address}, "/") infoAddr := filepath.Join(addr, "info.json") resp, err := s.svc.GetObject(&s3.GetObjectInput{Bucket: aws.String(s.bucket), Key: aws.String(addr)}) @@ -208,8 +222,17 @@ func (s *StorageServer) Get(dir, address string) (io.ReadCloser, *types.ChunksIn } return r, c, sz, mod, true, nil } else { + // Check some headers for the unencrypted content length for KMS encrypted objects. + if s.svc.KmsEncrypted() { + if cl, ok := resp.Metadata[AmzUnencryptedContentLengthHeader]; ok { + contentLength, _ = strconv.ParseInt(*cl, 10, 64) + } + } else { + contentLength = *resp.ContentLength + } + // For standard assets, the GetObject response has the information we need. - return resp.Body, nil, *resp.ContentLength, *resp.LastModified, true, nil + return resp.Body, nil, contentLength, *resp.LastModified, true, nil } } diff --git a/pkg/rsstorage/servers/s3server/server_test.go b/pkg/rsstorage/servers/s3server/server_test.go index 96a73cb..85b8cdd 100644 --- a/pkg/rsstorage/servers/s3server/server_test.go +++ b/pkg/rsstorage/servers/s3server/server_test.go @@ -78,6 +78,10 @@ type fakeS3 struct { delBucketOut *s3.DeleteBucketOutput } +func (s *fakeS3) KmsEncrypted() bool { + return false +} + func (s *fakeS3) CreateBucket(input *s3.CreateBucketInput) (*s3.CreateBucketOutput, error) { if s.bucketErr == nil { s.bucketIn = input