diff --git a/CHANGELOG.md b/CHANGELOG.md index 79485a0b2..8bdf898ce 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -76,6 +76,7 @@ * [CHANGE] Removed unused `time.Duration` parameter from `ShouldLog()` function in `middleware.OptionalLogging` interface. #513 * [CHANGE] Changed `ShouldLog()` function signature in `middleware.OptionalLogging` interface to `ShouldLog(context.Context) (bool, string)`: the returned `string` contains an optional reason. When reason is valued, `GRPCServerLog` adds `()` suffix to the error. #514 * [CHANGE] Cache: Remove superfluous `cache.RemoteCacheClient` interface and unify all caches using the `cache.Cache` interface. #520 +* [CHANGE] Backoff: `Backoff.Err()` now returns the context cancellation cause when provided to the context passed to `Backoff`. #538 * [FEATURE] Cache: Add support for configuring a Redis cache backend. #268 #271 #276 * [FEATURE] Add support for waiting on the rate limiter using the new `WaitN` method. #279 * [FEATURE] Add `log.BufferedLogger` type. #338 diff --git a/backoff/backoff.go b/backoff/backoff.go index 7ce556472..5468929bc 100644 --- a/backoff/backoff.go +++ b/backoff/backoff.go @@ -54,10 +54,12 @@ func (b *Backoff) Ongoing() bool { return b.ctx.Err() == nil && (b.cfg.MaxRetries == 0 || b.numRetries < b.cfg.MaxRetries) } -// Err returns the reason for terminating the backoff, or nil if it didn't terminate +// Err returns the reason for terminating the backoff, or nil if it didn't terminate. +// If backoff is terminated because the context has been canceled, then this function +// returns the context cancellation cause. func (b *Backoff) Err() error { if b.ctx.Err() != nil { - return b.ctx.Err() + return context.Cause(b.ctx) } if b.cfg.MaxRetries != 0 && b.numRetries >= b.cfg.MaxRetries { return fmt.Errorf("terminated after %d retries", b.numRetries) diff --git a/backoff/backoff_test.go b/backoff/backoff_test.go index dff6432c0..09379908d 100644 --- a/backoff/backoff_test.go +++ b/backoff/backoff_test.go @@ -4,6 +4,9 @@ import ( "context" "testing" "time" + + "github.com/pkg/errors" + "github.com/stretchr/testify/require" ) func TestBackoff_NextDelay(t *testing.T) { @@ -101,3 +104,62 @@ func TestBackoff_NextDelay(t *testing.T) { }) } } + +func TestBackoff_Err(t *testing.T) { + cause := errors.New("my cause") + + tests := map[string]struct { + ctx func(*testing.T) context.Context + expectedErr error + }{ + "should return context.DeadlineExceeded when context deadline exceeded without cause": { + ctx: func(t *testing.T) context.Context { + ctx, cancel := context.WithDeadline(context.Background(), time.Now()) + t.Cleanup(cancel) + + return ctx + }, + expectedErr: context.DeadlineExceeded, + }, + "should return cause when context deadline exceeded with cause": { + ctx: func(t *testing.T) context.Context { + ctx, cancel := context.WithDeadlineCause(context.Background(), time.Now(), cause) + t.Cleanup(cancel) + + return ctx + }, + expectedErr: cause, + }, + "should return context.Canceled when context is canceled without cause": { + ctx: func(_ *testing.T) context.Context { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + return ctx + }, + expectedErr: context.Canceled, + }, + "should return cause when context is canceled with cause": { + ctx: func(_ *testing.T) context.Context { + ctx, cancel := context.WithCancelCause(context.Background()) + cancel(cause) + + return ctx + }, + expectedErr: cause, + }, + } + + for testName, testData := range tests { + t.Run(testName, func(t *testing.T) { + b := New(testData.ctx(t), Config{}) + + // Wait until the backoff returns error. + require.Eventually(t, func() bool { + return b.Err() != nil + }, time.Second, 10*time.Millisecond) + + require.Equal(t, testData.expectedErr, b.Err()) + }) + } +}