diff --git a/clock.go b/clock.go index 2ff2c2a..c5b5475 100644 --- a/clock.go +++ b/clock.go @@ -54,6 +54,14 @@ func (c defaultClock) AfterFunc(d time.Duration, f func()) StopTimer { return time.AfterFunc(d, f) } +func (c defaultClock) ContextWithDeadline(ctx context.Context, t time.Time) (context.Context, context.CancelFunc) { + return context.WithDeadline(ctx, t) +} + +func (c defaultClock) ContextWithTimeout(ctx context.Context, d time.Duration) (context.Context, context.CancelFunc) { + return context.WithTimeout(ctx, d) +} + // DefaultClock returns a clock that minimally wraps the `time` package func DefaultClock() Clock { return defaultClock{} @@ -80,4 +88,19 @@ type Clock interface { // The callback function f will be executed after the interval d has // elapsed, unless the returned timer's Stop() method is called first. AfterFunc(d time.Duration, f func()) StopTimer + + // ContextWithDeadline behaves like context.WithDeadline, but it uses the + // clock to determine the when the deadline has expired. + ContextWithDeadline(ctx context.Context, t time.Time) (context.Context, context.CancelFunc) + // ContextWithDeadlineCause behaves like context.WithDeadlineCause, but it + // uses the clock to determine the when the deadline has expired. Cause is + // ignored in Go 1.20 and earlier. + ContextWithDeadlineCause(ctx context.Context, t time.Time, cause error) (context.Context, context.CancelFunc) + // ContextWithTimeout behaves like context.WithTimeout, but it uses the + // clock to determine the when the timeout has elapsed. + ContextWithTimeout(ctx context.Context, d time.Duration) (context.Context, context.CancelFunc) + // ContextWithTimeoutCause behaves like context.WithTimeoutCause, but it + // uses the clock to determine the when the timeout has elapsed. Cause is + // ignored in Go 1.20 and earlier. + ContextWithTimeoutCause(ctx context.Context, d time.Duration, cause error) (context.Context, context.CancelFunc) } diff --git a/clock_121.go b/clock_121.go new file mode 100644 index 0000000..d1939d3 --- /dev/null +++ b/clock_121.go @@ -0,0 +1,16 @@ +//go:build go1.21 + +package clocks + +import ( + "context" + "time" +) + +func (c defaultClock) ContextWithDeadlineCause(ctx context.Context, t time.Time, cause error) (context.Context, context.CancelFunc) { + return context.WithDeadlineCause(ctx, t, cause) +} + +func (c defaultClock) ContextWithTimeoutCause(ctx context.Context, d time.Duration, cause error) (context.Context, context.CancelFunc) { + return context.WithTimeoutCause(ctx, d, cause) +} diff --git a/clock_121_test.go b/clock_121_test.go new file mode 100644 index 0000000..78d3f40 --- /dev/null +++ b/clock_121_test.go @@ -0,0 +1,84 @@ +//go:build go1.21 + +package clocks + +import ( + "context" + "errors" + "testing" + "time" +) + +func TestDefaultClockContext121(t *testing.T) { + c := DefaultClock() + + t.Run("ContextWithDeadlineCause", func(t *testing.T) { + base := c.Now() + + ctx, cancel := c.ContextWithDeadlineCause(context.Background(), base.Add(time.Millisecond), errors.New("test")) + t.Cleanup(cancel) + + if v := c.SleepUntil(ctx, base.Add(time.Second)); v { + t.Errorf("unexpected return value: %t; expected false", v) + } else { + if ctx.Err() != context.DeadlineExceeded { + t.Errorf("unexpected error: %v; expected %v", ctx.Err(), context.DeadlineExceeded) + } + if context.Cause(ctx) == nil || context.Cause(ctx).Error() != "test" { + t.Errorf("unexpected cause: %v; expected %v", context.Cause(ctx), "test") + } + } + }) + + t.Run("ContextWithDeadlineCauseCanceled", func(t *testing.T) { + base := c.Now() + + ctx, cancel := c.ContextWithDeadlineCause(context.Background(), base.Add(500*time.Millisecond), errors.New("test")) + + cancel() + + if v := c.SleepUntil(ctx, base.Add(time.Second)); v { + t.Errorf("unexpected return value: %t; expected false", v) + } else { + if ctx.Err() != context.Canceled { + t.Errorf("unexpected error: %v; expected %v", ctx.Err(), context.Canceled) + } + if context.Cause(ctx) == nil || context.Cause(ctx) != context.Canceled { + t.Errorf("unexpected cause: %v; expected %v", context.Cause(ctx), context.Canceled) + } + } + }) + + t.Run("ContextWithTimeoutCause", func(t *testing.T) { + ctx, cancel := c.ContextWithTimeoutCause(context.Background(), time.Millisecond, errors.New("test")) + t.Cleanup(cancel) + + if v := c.SleepFor(ctx, time.Second); v { + t.Errorf("unexpected return value: %t; expected false", v) + } else { + if ctx.Err() != context.DeadlineExceeded { + t.Errorf("unexpected error: %v; expected %v", ctx.Err(), context.DeadlineExceeded) + } + if context.Cause(ctx) == nil || context.Cause(ctx).Error() != "test" { + t.Errorf("unexpected cause: %v; expected %v", context.Cause(ctx), "test") + } + } + }) + + t.Run("ContextWithTimeoutCauseCanceled", func(t *testing.T) { + ctx, cancel := c.ContextWithTimeoutCause(context.Background(), 500*time.Millisecond, errors.New("test")) + + cancel() + + if v := c.SleepFor(ctx, time.Second); v { + t.Errorf("unexpected return value: %t; expected false", v) + } else { + if ctx.Err() != context.Canceled { + t.Errorf("unexpected error: %v; expected %v", ctx.Err(), context.Canceled) + } + if context.Cause(ctx) == nil || context.Cause(ctx) != context.Canceled { + t.Errorf("unexpected cause: %v; expected %v", context.Cause(ctx), context.Canceled) + } + } + }) +} diff --git a/clock_pre121.go b/clock_pre121.go new file mode 100644 index 0000000..cc91707 --- /dev/null +++ b/clock_pre121.go @@ -0,0 +1,16 @@ +//go:build !go1.21 + +package clocks + +import ( + "context" + "time" +) + +func (c defaultClock) ContextWithDeadlineCause(ctx context.Context, t time.Time, cause error) (context.Context, context.CancelFunc) { + return context.WithDeadline(ctx, t) +} + +func (c defaultClock) ContextWithTimeoutCause(ctx context.Context, d time.Duration, cause error) (context.Context, context.CancelFunc) { + return context.WithTimeout(ctx, d) +} diff --git a/clock_test.go b/clock_test.go index 3fe214f..b9f6d57 100644 --- a/clock_test.go +++ b/clock_test.go @@ -49,3 +49,63 @@ func TestDefaultClock(t *testing.T) { <-afCh } } + +func TestDefaultClockContext(t *testing.T) { + c := DefaultClock() + + t.Run("ContextWithDeadlineExceeded", func(t *testing.T) { + base := c.Now() + + ctx, cancel := c.ContextWithDeadline(context.Background(), base.Add(time.Millisecond)) + t.Cleanup(cancel) + + if v := c.SleepUntil(ctx, base.Add(time.Second)); v { + t.Errorf("unexpected return value: %t; expected false", v) + } else { + if ctx.Err() != context.DeadlineExceeded { + t.Errorf("unexpected error: %v; expected %v", ctx.Err(), context.DeadlineExceeded) + } + } + }) + + t.Run("ContextWithDeadlineNotExceeded", func(t *testing.T) { + base := c.Now() + + ctx, cancel := c.ContextWithDeadline(context.Background(), base.Add(3*time.Second)) + t.Cleanup(cancel) + + if v := c.SleepUntil(ctx, base.Add(time.Millisecond)); !v { + t.Errorf("unexpected return value: %t; expected true", v) + } else { + if ctx.Err() != nil { + t.Errorf("unexpected error: %v; expected nil", ctx.Err()) + } + } + }) + + t.Run("ContextWithTimeoutExceeded", func(t *testing.T) { + ctx, cancel := c.ContextWithTimeout(context.Background(), time.Millisecond) + t.Cleanup(cancel) + + if v := c.SleepFor(ctx, time.Second); v { + t.Errorf("unexpected return value: %t; expected false", v) + } else { + if ctx.Err() != context.DeadlineExceeded { + t.Errorf("unexpected error: %v; expected %v", ctx.Err(), context.DeadlineExceeded) + } + } + }) + + t.Run("ContextWithTimeoutNotExceeded", func(t *testing.T) { + ctx, cancel := c.ContextWithTimeout(context.Background(), 3*time.Second) + t.Cleanup(cancel) + + if v := c.SleepFor(ctx, time.Millisecond); !v { + t.Errorf("unexpected return value: %t; expected false", v) + } else { + if ctx.Err() != nil { + t.Errorf("unexpected error: %v; expected nil", ctx.Err()) + } + } + }) +} diff --git a/fake/fake_clock.go b/fake/fake_clock.go index 2b25c6d..093d666 100644 --- a/fake/fake_clock.go +++ b/fake/fake_clock.go @@ -3,6 +3,7 @@ package fake import ( "context" "sync" + "sync/atomic" "time" clocks "github.com/vimeo/go-clocks" @@ -410,3 +411,39 @@ func (f *Clock) AwaitTimerAborts(n int) { func (f *Clock) WaitAfterFuncs() { f.cbsWG.Wait() } + +type deadlineContext struct { + context.Context + timedOut atomic.Bool + deadline time.Time +} + +func (d *deadlineContext) Deadline() (time.Time, bool) { + return d.deadline, true +} + +func (d *deadlineContext) Err() error { + if d.timedOut.Load() { + return context.DeadlineExceeded + } + return d.Context.Err() +} + +// ContextWithDeadline behaves like context.WithDeadline, but it uses the +// clock to determine the when the deadline has expired. +func (c *Clock) ContextWithDeadline(ctx context.Context, t time.Time) (context.Context, context.CancelFunc) { + return c.ContextWithDeadlineCause(ctx, t, nil) +} + +// ContextWithTimeout behaves like context.WithTimeout, but it uses the +// clock to determine the when the timeout has elapsed. +func (c *Clock) ContextWithTimeout(ctx context.Context, d time.Duration) (context.Context, context.CancelFunc) { + return c.ContextWithDeadlineCause(ctx, c.Now().Add(d), nil) +} + +// ContextWithTimeoutCause behaves like context.WithTimeoutCause, but it +// uses the clock to determine the when the timeout has elapsed. Cause is +// ignored in Go 1.20 and earlier. +func (c *Clock) ContextWithTimeoutCause(ctx context.Context, d time.Duration, cause error) (context.Context, context.CancelFunc) { + return c.ContextWithDeadlineCause(ctx, c.Now().Add(d), cause) +} diff --git a/fake/fake_clock_121.go b/fake/fake_clock_121.go new file mode 100644 index 0000000..e3cb3f1 --- /dev/null +++ b/fake/fake_clock_121.go @@ -0,0 +1,36 @@ +//go:build go1.21 + +package fake + +import ( + "context" + "time" +) + +// ContextWithDeadlineCause behaves like context.WithDeadlineCause, but it +// uses the clock to determine the when the deadline has expired. Cause is +// ignored in Go 1.20 and earlier. +func (f *Clock) ContextWithDeadlineCause(ctx context.Context, t time.Time, cause error) (context.Context, context.CancelFunc) { + cctx, cancelCause := context.WithCancelCause(ctx) + dctx := &deadlineContext{ + Context: cctx, + deadline: t, + } + dur := f.Until(t) + if dur <= 0 { + dctx.timedOut.Store(true) + cancelCause(cause) + return dctx, func() {} + } + stop := f.AfterFunc(dur, func() { + if cctx.Err() == nil { + dctx.timedOut.Store(true) + } + cancelCause(cause) + }) + cancel := func() { + cancelCause(context.Canceled) + stop.Stop() + } + return dctx, cancel +} diff --git a/fake/fake_clock_121_test.go b/fake/fake_clock_121_test.go new file mode 100644 index 0000000..28e5d47 --- /dev/null +++ b/fake/fake_clock_121_test.go @@ -0,0 +1,116 @@ +//go:build go1.21 + +package fake + +import ( + "context" + "errors" + "testing" + "time" +) + +func TestFakeClockContext121(t *testing.T) { + t.Run("ContextWithDeadlineCause", func(t *testing.T) { + base := time.Now() + c := NewClock(base) + + ctx, cancel := c.ContextWithDeadlineCause(context.Background(), base.Add(1), errors.New("test")) + t.Cleanup(cancel) + + c.Advance(1) + + select { + case <-ctx.Done(): + if ctx.Err() != context.DeadlineExceeded { + t.Errorf("unexpected error: %v; expected %v", ctx.Err(), context.DeadlineExceeded) + } + if context.Cause(ctx) == nil || context.Cause(ctx).Error() != "test" { + t.Errorf("unexpected cause: %v; expected %v", context.Cause(ctx), "test") + } + case <-time.After(time.Second): + t.Errorf("context not done after 1 second") + } + }) + + t.Run("ContextWithDeadlineCanceled", func(t *testing.T) { + base := time.Now() + c := NewClock(base) + + ctx, cancel := c.ContextWithDeadlineCause(context.Background(), base.Add(1), errors.New("test")) + t.Cleanup(cancel) + + cancel() + + select { + case <-ctx.Done(): + if ctx.Err() != context.Canceled { + t.Errorf("unexpected error: %v; expected %v", ctx.Err(), context.Canceled) + } + if context.Cause(ctx) == nil || context.Cause(ctx) != context.Canceled { + t.Errorf("unexpected cause: %v; expected %v", context.Cause(ctx), context.Canceled) + } + case <-time.After(time.Second): + t.Errorf("context not done after 1 second") + } + }) + + t.Run("ContextWithDeadlineCausePast", func(t *testing.T) { + base := time.Now() + c := NewClock(base) + + ctx, cancel := c.ContextWithDeadlineCause(context.Background(), base.Add(-1), errors.New("test")) + t.Cleanup(cancel) + + select { + case <-ctx.Done(): + if ctx.Err() != context.DeadlineExceeded { + t.Errorf("unexpected error: %v; expected %v", ctx.Err(), context.DeadlineExceeded) + } + if context.Cause(ctx) == nil || context.Cause(ctx).Error() != "test" { + t.Errorf("unexpected cause: %v; expected %v", context.Cause(ctx), "test") + } + case <-time.After(time.Second): + t.Errorf("context not done after 1 second") + } + }) + + t.Run("ContextWithTimeoutCause", func(t *testing.T) { + c := NewClock(time.Now()) + ctx, cancel := c.ContextWithTimeoutCause(context.Background(), 1, errors.New("test")) + t.Cleanup(cancel) + + c.Advance(1) + + select { + case <-ctx.Done(): + if ctx.Err() != context.DeadlineExceeded { + t.Errorf("unexpected error: %v; expected %v", ctx.Err(), context.DeadlineExceeded) + } + if context.Cause(ctx) == nil || context.Cause(ctx).Error() != "test" { + t.Errorf("unexpected cause: %v; expected %v", context.Cause(ctx), "test") + } + case <-time.After(time.Second): + t.Errorf("context not done after 1 second") + } + }) + + t.Run("ContextWithTimeoutCauseCanceled", func(t *testing.T) { + c := NewClock(time.Now()) + ctx, cancel := c.ContextWithTimeoutCause(context.Background(), 1, errors.New("test")) + t.Cleanup(cancel) + + cancel() + + select { + case <-ctx.Done(): + if ctx.Err() != context.Canceled { + t.Errorf("unexpected error: %v; expected %v", ctx.Err(), context.Canceled) + } + if context.Cause(ctx) == nil || context.Cause(ctx) != context.Canceled { + t.Errorf("unexpected cause: %v; expected %v", context.Cause(ctx), context.Canceled) + } + case <-time.After(time.Second): + t.Errorf("context not done after 1 second") + } + }) +} diff --git a/fake/fake_clock_pre121.go b/fake/fake_clock_pre121.go new file mode 100644 index 0000000..2d690e0 --- /dev/null +++ b/fake/fake_clock_pre121.go @@ -0,0 +1,36 @@ +//go:build !go1.21 + +package fake + +import ( + "context" + "time" +) + +// ContextWithDeadlineCause behaves like context.WithDeadlineCause, but it +// uses the clock to determine the when the deadline has expired. Cause is +// ignored in Go 1.20 and earlier. +func (f *Clock) ContextWithDeadlineCause(ctx context.Context, t time.Time, cause error) (context.Context, context.CancelFunc) { + cctx, cancel := context.WithCancel(ctx) + dctx := &deadlineContext{ + Context: cctx, + deadline: t, + } + dur := f.Until(t) + if dur <= 0 { + dctx.timedOut.Store(true) + cancel() + return dctx, func() {} + } + stop := f.AfterFunc(dur, func() { + if cctx.Err() == nil { + dctx.timedOut.Store(true) + } + cancel() + }) + cancelStop := func() { + cancel() + stop.Stop() + } + return dctx, cancelStop +} diff --git a/fake/fake_clock_test.go b/fake/fake_clock_test.go index 407ddb1..d0fb3fa 100644 --- a/fake/fake_clock_test.go +++ b/fake/fake_clock_test.go @@ -668,3 +668,90 @@ func TestFakeClockAfterFuncNegDur(t *testing.T) { } } + +func TestFakeClockContext(t *testing.T) { + t.Run("ContextDeadline", func(t *testing.T) { + base := time.Now() + c := NewClock(base) + + deadline := base.Add(1) + ctx, cancel := c.ContextWithDeadline(context.Background(), deadline) + t.Cleanup(cancel) + + ctxDeadline, isSet := ctx.Deadline() + if !isSet { + t.Errorf("context deadline not set") + } + if !ctxDeadline.Equal(deadline) { + t.Errorf("unexpected context deadline: %v; expected %v", ctxDeadline, deadline) + } + }) + + t.Run("ContextWithDeadlineExceeded", func(t *testing.T) { + base := time.Now() + c := NewClock(base) + + ctx, cancel := c.ContextWithDeadline(context.Background(), base.Add(1)) + t.Cleanup(cancel) + + c.Advance(1) + + select { + case <-ctx.Done(): + if ctx.Err() != context.DeadlineExceeded { + t.Errorf("unexpected error: %v; expected %v", ctx.Err(), context.DeadlineExceeded) + } + case <-time.After(time.Second): + t.Errorf("context not done after 1 second") + } + }) + + t.Run("ContextWithDeadlineNotExceeded", func(t *testing.T) { + base := time.Now() + c := NewClock(base) + + ctx, cancel := c.ContextWithDeadline(context.Background(), base.Add(1)) + t.Cleanup(cancel) + + select { + case <-ctx.Done(): + t.Errorf("context should not be done") + default: + if ctx.Err() != nil { + t.Errorf("unexpected error: %v; expected nil", ctx.Err()) + } + } + }) + + t.Run("ContextWithTimeoutExceeded", func(t *testing.T) { + c := NewClock(time.Now()) + ctx, cancel := c.ContextWithTimeout(context.Background(), 1) + t.Cleanup(cancel) + + c.Advance(1) + + select { + case <-ctx.Done(): + if ctx.Err() != context.DeadlineExceeded { + t.Errorf("unexpected error: %v; expected %v", ctx.Err(), context.DeadlineExceeded) + } + case <-time.After(time.Second): + t.Errorf("context not done after 1 second") + } + }) + + t.Run("ContextWithTimeouteNotExceeded", func(t *testing.T) { + c := NewClock(time.Now()) + ctx, cancel := c.ContextWithTimeout(context.Background(), 1) + t.Cleanup(cancel) + + select { + case <-ctx.Done(): + t.Errorf("context should not be done") + default: + if ctx.Err() != nil { + t.Errorf("unexpected error: %v; expected nil", ctx.Err()) + } + } + }) +} diff --git a/go.mod b/go.mod index 8790e9a..2bdc894 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,3 @@ module github.com/vimeo/go-clocks -go 1.14 +go 1.19 diff --git a/offset/offset_clock.go b/offset/offset_clock.go index a83375a..9a4ffea 100644 --- a/offset/offset_clock.go +++ b/offset/offset_clock.go @@ -49,6 +49,19 @@ func (o *Clock) AfterFunc(d time.Duration, f func()) clocks.StopTimer { return o.inner.AfterFunc(d, f) } +// ContextWithDeadline behaves like context.WithDeadline, but it uses the +// clock to determine the when the deadline has expired. +func (o *Clock) ContextWithDeadline(ctx context.Context, t time.Time) (context.Context, context.CancelFunc) { + return o.inner.ContextWithDeadline(ctx, t.Add(o.offset)) +} + +// ContextWithTimeout behaves like context.WithTimeout, but it uses the +// clock to determine the when the timeout has elapsed. +func (o *Clock) ContextWithTimeout(ctx context.Context, d time.Duration) (context.Context, context.CancelFunc) { + // timeout is relative, so it doesn't need any adjustment + return o.inner.ContextWithTimeout(ctx, d) +} + // NewOffsetClock creates an OffsetClock. offset is added to all absolute times. func NewOffsetClock(inner clocks.Clock, offset time.Duration) *Clock { return &Clock{ diff --git a/offset/offset_clock_121.go b/offset/offset_clock_121.go new file mode 100644 index 0000000..e116165 --- /dev/null +++ b/offset/offset_clock_121.go @@ -0,0 +1,22 @@ +//go:build go1.21 + +package offset + +import ( + "context" + "time" +) + +// ContextWithDeadlineCause behaves like context.WithDeadlineCause, but it +// uses the clock to determine the when the deadline has expired. Cause is +// ignored in Go 1.20 and earlier. +func (o *Clock) ContextWithDeadlineCause(ctx context.Context, t time.Time, cause error) (context.Context, context.CancelFunc) { + return o.inner.ContextWithDeadlineCause(ctx, t.Add(o.offset), cause) +} + +// ContextWithTimeoutCause behaves like context.WithTimeoutCause, but it +// uses the clock to determine the when the timeout has elapsed. Cause is +// ignored in Go 1.20 and earlier. +func (o *Clock) ContextWithTimeoutCause(ctx context.Context, d time.Duration, cause error) (context.Context, context.CancelFunc) { + return o.inner.ContextWithTimeoutCause(ctx, d, cause) +} diff --git a/offset/offset_clock_121_test.go b/offset/offset_clock_121_test.go new file mode 100644 index 0000000..0a70f62 --- /dev/null +++ b/offset/offset_clock_121_test.go @@ -0,0 +1,94 @@ +//go:build go1.21 + +package offset + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/vimeo/go-clocks/fake" +) + +func TestOffsetClockContext121(t *testing.T) { + t.Run("ContextWithDeadlineCause", func(t *testing.T) { + base := time.Now() + inner := fake.NewClock(base) + c := NewOffsetClock(inner, time.Hour) + + ctx, cancel := c.ContextWithDeadlineCause(context.Background(), inner.Now().Add(time.Hour), errors.New("test")) + t.Cleanup(cancel) + + awoken := inner.Advance(2 * time.Hour) + if awoken != 1 { + t.Errorf("unexpected number of awoken sleepers: %d; expected 1", awoken) + } + + <-ctx.Done() + if ctx.Err() != context.DeadlineExceeded { + t.Errorf("unexpected error: %v; expected %v", ctx.Err(), context.DeadlineExceeded) + } + if context.Cause(ctx) == nil || context.Cause(ctx).Error() != "test" { + t.Errorf("unexpected cause: %v; expected %v", context.Cause(ctx), "test") + } + }) + + t.Run("ContextWithDeadlineCauseCanceled", func(t *testing.T) { + base := time.Now() + inner := fake.NewClock(base) + c := NewOffsetClock(inner, time.Hour) + + ctx, cancel := c.ContextWithDeadlineCause(context.Background(), inner.Now().Add(time.Hour), errors.New("test")) + + cancel() + + <-ctx.Done() + if ctx.Err() != context.Canceled { + t.Errorf("unexpected error: %v; expected %v", ctx.Err(), context.Canceled) + } + if context.Cause(ctx) == nil || context.Cause(ctx) != context.Canceled { + t.Errorf("unexpected cause: %v; expected %v", context.Cause(ctx), context.Canceled) + } + }) + + t.Run("ContextWithTimeoutCause", func(t *testing.T) { + base := time.Now() + inner := fake.NewClock(base) + c := NewOffsetClock(inner, time.Hour) + + ctx, cancel := c.ContextWithTimeoutCause(context.Background(), time.Hour, errors.New("test")) + t.Cleanup(cancel) + + awoken := inner.Advance(time.Hour) + if awoken != 1 { + t.Errorf("unexpected number of awoken sleepers: %d; expected 1", awoken) + } + + <-ctx.Done() + if ctx.Err() != context.DeadlineExceeded { + t.Errorf("unexpected error: %v; expected %v", ctx.Err(), context.DeadlineExceeded) + } + if context.Cause(ctx) == nil || context.Cause(ctx).Error() != "test" { + t.Errorf("unexpected cause: %v; expected %v", context.Cause(ctx), "test") + } + }) + + t.Run("ContextWithTimeoutCauseCanceled", func(t *testing.T) { + base := time.Now() + inner := fake.NewClock(base) + c := NewOffsetClock(inner, time.Hour) + + ctx, cancel := c.ContextWithTimeoutCause(context.Background(), time.Hour, errors.New("test")) + + cancel() + + <-ctx.Done() + if ctx.Err() != context.Canceled { + t.Errorf("unexpected error: %v; expected %v", ctx.Err(), context.Canceled) + } + if context.Cause(ctx) == nil || context.Cause(ctx) != context.Canceled { + t.Errorf("unexpected cause: %v; expected %v", context.Cause(ctx), context.Canceled) + } + }) +} diff --git a/offset/offset_clock_pre121.go b/offset/offset_clock_pre121.go new file mode 100644 index 0000000..df0f120 --- /dev/null +++ b/offset/offset_clock_pre121.go @@ -0,0 +1,22 @@ +//go:build !go1.21 + +package offset + +import ( + "context" + "time" +) + +// ContextWithDeadlineCause behaves like context.WithDeadlineCause, but it +// uses the clock to determine the when the deadline has expired. Cause is +// ignored in Go 1.20 and earlier. +func (o *Clock) ContextWithDeadlineCause(ctx context.Context, t time.Time, cause error) (context.Context, context.CancelFunc) { + return o.inner.ContextWithDeadline(ctx, t.Add(o.offset)) +} + +// ContextWithTimeoutCause behaves like context.WithTimeoutCause, but it +// uses the clock to determine the when the timeout has elapsed. Cause is +// ignored in Go 1.20 and earlier. +func (o *Clock) ContextWithTimeoutCause(ctx context.Context, d time.Duration, cause error) (context.Context, context.CancelFunc) { + return o.inner.ContextWithTimeout(ctx, d+o.offset) +} diff --git a/offset/offset_clock_test.go b/offset/offset_clock_test.go index e14fe29..ffa11e5 100644 --- a/offset/offset_clock_test.go +++ b/offset/offset_clock_test.go @@ -104,3 +104,89 @@ func TestOffsetClock(t *testing.T) { <-ch } } + +func TestOffsetClockContext(t *testing.T) { + t.Run("ContextWithDeadlineExceeded", func(t *testing.T) { + base := time.Now() + inner := fake.NewClock(base) + c := NewOffsetClock(inner, time.Hour) + + ctx, cancel := c.ContextWithDeadline(context.Background(), inner.Now().Add(time.Hour)) + t.Cleanup(cancel) + + awoken := inner.Advance(2 * time.Hour) + if awoken != 1 { + t.Errorf("unexpected number of awoken sleepers: %d; expected 1", awoken) + } + + <-ctx.Done() + if ctx.Err() != context.DeadlineExceeded { + t.Errorf("unexpected error: %v; expected %v", ctx.Err(), context.DeadlineExceeded) + } + }) + + t.Run("ContextWithDeadlineNotExceeded", func(t *testing.T) { + base := time.Now() + inner := fake.NewClock(base) + c := NewOffsetClock(inner, time.Hour) + + ctx, cancel := c.ContextWithDeadline(context.Background(), inner.Now().Add(time.Hour)) + t.Cleanup(cancel) + + awoken := inner.Advance(2*time.Hour - 1*time.Nanosecond) + if awoken != 0 { + t.Errorf("unexpected number of awoken sleepers: %d; expected 0", awoken) + } + + select { + case <-ctx.Done(): + t.Errorf("context should not be done") + default: + if ctx.Err() != nil { + t.Errorf("unexpected error: %v; expected nil", ctx.Err()) + } + } + }) + + t.Run("ContextWithTimeoutExceeded", func(t *testing.T) { + base := time.Now() + inner := fake.NewClock(base) + c := NewOffsetClock(inner, time.Hour) + + ctx, cancel := c.ContextWithTimeout(context.Background(), time.Hour) + t.Cleanup(cancel) + + awoken := inner.Advance(time.Hour) + if awoken != 1 { + t.Errorf("unexpected number of awoken sleepers: %d; expected 1", awoken) + } + + <-ctx.Done() + if ctx.Err() != context.DeadlineExceeded { + t.Errorf("unexpected error: %v; expected %v", ctx.Err(), context.DeadlineExceeded) + } + }) + + t.Run("ContextWithTimeouteNotExceeded", func(t *testing.T) { + base := time.Now() + inner := fake.NewClock(base) + c := NewOffsetClock(inner, time.Hour) + + ctx, cancel := c.ContextWithTimeout(context.Background(), time.Hour) + t.Cleanup(cancel) + + awoken := inner.Advance(time.Hour - time.Nanosecond) + if awoken != 0 { + t.Errorf("unexpected number of awoken sleepers: %d; expected 0", awoken) + } + + select { + case <-ctx.Done(): + t.Errorf("context should not be done") + default: + if ctx.Err() != nil { + t.Errorf("unexpected error: %v; expected nil", ctx.Err()) + } + } + }) +}