Skip to content

Commit

Permalink
fix: renew the session token when the token expires
Browse files Browse the repository at this point in the history
  • Loading branch information
MqllR committed Dec 6, 2023
1 parent c1c7ee3 commit 3e46e7e
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 11 deletions.
28 changes: 22 additions & 6 deletions storage/s3.go
Original file line number Diff line number Diff line change
Expand Up @@ -1235,9 +1235,9 @@ func (sc *SessionCache) newSession(ctx context.Context, opts Options) (*session.
WithLogger(sdkLogger{})
}

awsCfg.Retryer = newCustomRetryer(opts.MaxRetries)
awsCfg.Retryer = newCustomRetryer(sc, opts.MaxRetries)

useSharedConfig := session.SharedConfigEnable
useSharedConfig := session.SharedConfigDisable
{
// Reverse of what the SDK does: if AWS_SDK_LOAD_CONFIG is 0 (or a
// falsy value) disable shared configs
Expand Down Expand Up @@ -1276,7 +1276,7 @@ func (sc *SessionCache) newSession(ctx context.Context, opts Options) (*session.
return sess, nil
}

func (sc *SessionCache) clear() {
func (sc *SessionCache) Clear() {
sc.Lock()
defer sc.Unlock()
sc.sessions = map[Options]*session.Session{}
Expand Down Expand Up @@ -1324,10 +1324,12 @@ func setSessionRegion(ctx context.Context, sess *session.Session, bucket string)
// error codes. Such as, retry for S3 InternalError code.
type customRetryer struct {
client.DefaultRetryer
sc *SessionCache
}

func newCustomRetryer(maxRetries int) *customRetryer {
func newCustomRetryer(sc *SessionCache, maxRetries int) *customRetryer {
return &customRetryer{
sc: sc,
DefaultRetryer: client.DefaultRetryer{
NumMaxRetries: maxRetries,
},
Expand All @@ -1337,13 +1339,27 @@ func newCustomRetryer(maxRetries int) *customRetryer {
// ShouldRetry overrides SDK's built in DefaultRetryer, adding custom retry
// logics that are not included in the SDK.
func (c *customRetryer) ShouldRetry(req *request.Request) bool {
shouldRetry := errHasCode(req.Error, "InternalError") || errHasCode(req.Error, "RequestTimeTooSkewed") || errHasCode(req.Error, "SlowDown") || strings.Contains(req.Error.Error(), "connection reset") || strings.Contains(req.Error.Error(), "connection timed out")
log.Error(log.ErrorMessage{
Command: "retrier",
Err: req.Error.Error(),
})

shouldRetry := errHasCode(req.Error, "InternalError") || errHasCode(req.Error, "RequestTimeTooSkewed") || errHasCode(req.Error, "SlowDown") || strings.Contains(req.Error.Error(), "connection reset") || strings.Contains(req.Error.Error(), "connection timed out") || errHasCode(req.Error, "ExpiredToken") || errHasCode(req.Error, "ExpiredTokenException")

if errHasCode(req.Error, "ExpiredToken") || errHasCode(req.Error, "ExpiredTokenException") {
log.Debug(log.DebugMessage{
Err: "Clearing the token",
})

c.sc.Clear()
}

if !shouldRetry {
shouldRetry = c.DefaultRetryer.ShouldRetry(req)
}

// Errors related to tokens
if errHasCode(req.Error, "ExpiredToken") || errHasCode(req.Error, "ExpiredTokenException") || errHasCode(req.Error, "InvalidToken") {
if errHasCode(req.Error, "InvalidToken") {
return false
}

Expand Down
10 changes: 5 additions & 5 deletions storage/s3_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ func TestNewSessionPathStyle(t *testing.T) {
}

func TestNewSessionWithRegionSetViaEnv(t *testing.T) {
globalSessionCache.clear()
globalSessionCache.Clear()

const expectedRegion = "us-west-2"

Expand All @@ -116,7 +116,7 @@ func TestNewSessionWithRegionSetViaEnv(t *testing.T) {
}

func TestNewSessionWithNoSignRequest(t *testing.T) {
globalSessionCache.clear()
globalSessionCache.Clear()

sess, err := globalSessionCache.newSession(context.Background(), Options{
NoSignRequest: true,
Expand Down Expand Up @@ -190,7 +190,7 @@ aws_secret_access_key = p2_profile_access_key`
}
for _, tc := range testcases {
t.Run(tc.name, func(t *testing.T) {
globalSessionCache.clear()
globalSessionCache.Clear()
sess, err := globalSessionCache.newSession(context.Background(), Options{
Profile: tc.profileName,
CredentialFile: tc.fileName,
Expand Down Expand Up @@ -1041,7 +1041,7 @@ func TestSessionRegionDetection(t *testing.T) {
opts.bucket = tc.bucket
}

globalSessionCache.clear()
globalSessionCache.Clear()

sess, err := globalSessionCache.newSession(context.Background(), opts)
if err != nil {
Expand Down Expand Up @@ -1241,7 +1241,7 @@ func TestAWSLogLevel(t *testing.T) {

for _, tc := range testcases {
t.Run(tc.name, func(t *testing.T) {
globalSessionCache.clear()
globalSessionCache.Clear()
sess, err := globalSessionCache.newSession(context.Background(), Options{
LogLevel: log.LevelFromString(tc.level),
})
Expand Down

0 comments on commit 3e46e7e

Please sign in to comment.