Skip to content

Commit

Permalink
Add context methods to Clock interface
Browse files Browse the repository at this point in the history
This allows for using timeout/deadline functionality built in to
context.Context with a custom clock implementation.

Module Go version bumped to 1.19 due to use of atomic.Bool
  • Loading branch information
justinruggles committed Oct 27, 2023
1 parent e868797 commit 71e0810
Show file tree
Hide file tree
Showing 16 changed files with 592 additions and 1 deletion.
23 changes: 23 additions & 0 deletions clock.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,14 @@ func (c defaultClock) AfterFunc(d time.Duration, f func()) StopTimer {
return time.AfterFunc(d, f)
}

func (c defaultClock) ContextWithDeadline(ctx context.Context, t time.Time) (context.Context, context.CancelFunc) {
return context.WithDeadline(ctx, t)
}

func (c defaultClock) ContextWithTimeout(ctx context.Context, d time.Duration) (context.Context, context.CancelFunc) {
return context.WithTimeout(ctx, d)
}

// DefaultClock returns a clock that minimally wraps the `time` package
func DefaultClock() Clock {
return defaultClock{}
Expand All @@ -80,4 +88,19 @@ type Clock interface {
// The callback function f will be executed after the interval d has
// elapsed, unless the returned timer's Stop() method is called first.
AfterFunc(d time.Duration, f func()) StopTimer

// ContextWithDeadline behaves like context.WithDeadline, but it uses the
// clock to determine the when the deadline has expired.
ContextWithDeadline(ctx context.Context, t time.Time) (context.Context, context.CancelFunc)
// ContextWithDeadlineCause behaves like context.WithDeadlineCause, but it
// uses the clock to determine the when the deadline has expired. Cause is
// ignored in Go 1.20 and earlier.
ContextWithDeadlineCause(ctx context.Context, t time.Time, cause error) (context.Context, context.CancelFunc)
// ContextWithTimeout behaves like context.WithTimeout, but it uses the
// clock to determine the when the timeout has elapsed.
ContextWithTimeout(ctx context.Context, d time.Duration) (context.Context, context.CancelFunc)
// ContextWithTimeoutCause behaves like context.WithTimeoutCause, but it
// uses the clock to determine the when the timeout has elapsed. Cause is
// ignored in Go 1.20 and earlier.
ContextWithTimeoutCause(ctx context.Context, d time.Duration, cause error) (context.Context, context.CancelFunc)
}
16 changes: 16 additions & 0 deletions clock_121.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
//go:build go1.21

package clocks

import (
"context"
"time"
)

func (c defaultClock) ContextWithDeadlineCause(ctx context.Context, t time.Time, cause error) (context.Context, context.CancelFunc) {
return context.WithDeadlineCause(ctx, t, cause)
}

func (c defaultClock) ContextWithTimeoutCause(ctx context.Context, d time.Duration, cause error) (context.Context, context.CancelFunc) {
return context.WithTimeoutCause(ctx, d, cause)
}
48 changes: 48 additions & 0 deletions clock_121_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
//go:build go1.21

package clocks

import (
"context"
"errors"
"testing"
"time"
)

func TestDefaultClockContext121(t *testing.T) {
c := DefaultClock()

t.Run("ContextWithDeadlineCause", func(t *testing.T) {
base := c.Now()

ctx, cancel := c.ContextWithDeadlineCause(context.Background(), base.Add(time.Millisecond), errors.New("test"))
t.Cleanup(cancel)

if v := c.SleepUntil(ctx, base.Add(time.Second)); v {
t.Errorf("unexpected return value: %t; expected false", v)
} else {
if ctx.Err() != context.DeadlineExceeded {
t.Errorf("unexpected error: %v; expected %v", ctx.Err(), context.DeadlineExceeded)
}
if context.Cause(ctx) == nil || context.Cause(ctx).Error() != "test" {
t.Errorf("unexpected cause: %v; expected %v", context.Cause(ctx), "test")
}
}
})

t.Run("ContextWithTimeoutCause", func(t *testing.T) {
ctx, cancel := c.ContextWithTimeoutCause(context.Background(), time.Millisecond, errors.New("test"))
t.Cleanup(cancel)

if v := c.SleepFor(ctx, time.Second); v {
t.Errorf("unexpected return value: %t; expected false", v)
} else {
if ctx.Err() != context.DeadlineExceeded {
t.Errorf("unexpected error: %v; expected %v", ctx.Err(), context.DeadlineExceeded)
}
if context.Cause(ctx).Error() != "test" {
t.Errorf("unexpected cause: %v; expected %v", context.Cause(ctx), "test")
}
}
})
}
16 changes: 16 additions & 0 deletions clock_pre121.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
//go:build !go1.21

package clocks

import (
"context"
"time"
)

func (c defaultClock) ContextWithDeadlineCause(ctx context.Context, t time.Time, cause error) (context.Context, context.CancelFunc) {
return context.WithDeadline(ctx, t)
}

func (c defaultClock) ContextWithTimeoutCause(ctx context.Context, d time.Duration, cause error) (context.Context, context.CancelFunc) {
return context.WithTimeout(ctx, d)
}
60 changes: 60 additions & 0 deletions clock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,63 @@ func TestDefaultClock(t *testing.T) {
<-afCh
}
}

func TestDefaultClockContext(t *testing.T) {
c := DefaultClock()

t.Run("ContextWithDeadlineExceeded", func(t *testing.T) {
base := c.Now()

ctx, cancel := c.ContextWithDeadline(context.Background(), base.Add(time.Millisecond))
t.Cleanup(cancel)

if v := c.SleepUntil(ctx, base.Add(time.Second)); v {
t.Errorf("unexpected return value: %t; expected false", v)
} else {
if ctx.Err() != context.DeadlineExceeded {
t.Errorf("unexpected error: %v; expected %v", ctx.Err(), context.DeadlineExceeded)
}
}
})

t.Run("ContextWithDeadlineNotExceeded", func(t *testing.T) {
base := c.Now()

ctx, cancel := c.ContextWithDeadline(context.Background(), base.Add(3*time.Second))
t.Cleanup(cancel)

if v := c.SleepUntil(ctx, base.Add(time.Millisecond)); !v {
t.Errorf("unexpected return value: %t; expected true", v)
} else {
if ctx.Err() != nil {
t.Errorf("unexpected error: %v; expected nil", ctx.Err())
}
}
})

t.Run("ContextWithTimeoutExceeded", func(t *testing.T) {
ctx, cancel := c.ContextWithTimeout(context.Background(), time.Millisecond)
t.Cleanup(cancel)

if v := c.SleepFor(ctx, time.Second); v {
t.Errorf("unexpected return value: %t; expected false", v)
} else {
if ctx.Err() != context.DeadlineExceeded {
t.Errorf("unexpected error: %v; expected %v", ctx.Err(), context.DeadlineExceeded)
}
}
})

t.Run("ContextWithTimeoutNotExceeded", func(t *testing.T) {
ctx, cancel := c.ContextWithTimeout(context.Background(), 3*time.Second)
t.Cleanup(cancel)

if v := c.SleepFor(ctx, time.Millisecond); !v {
t.Errorf("unexpected return value: %t; expected false", v)
} else {
if ctx.Err() != nil {
t.Errorf("unexpected error: %v; expected nil", ctx.Err())
}
}
})
}
37 changes: 37 additions & 0 deletions fake/fake_clock.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package fake
import (
"context"
"sync"
"sync/atomic"
"time"

clocks "github.com/vimeo/go-clocks"
Expand Down Expand Up @@ -410,3 +411,39 @@ func (f *Clock) AwaitTimerAborts(n int) {
func (f *Clock) WaitAfterFuncs() {
f.cbsWG.Wait()
}

type deadlineContext struct {
context.Context
timedOut atomic.Bool
deadline time.Time
}

func (d *deadlineContext) Deadline() (time.Time, bool) {
return d.deadline, true
}

func (d *deadlineContext) Err() error {
if d.timedOut.Load() {
return context.DeadlineExceeded
}
return d.Context.Err()
}

// ContextWithDeadline behaves like context.WithDeadline, but it uses the
// clock to determine the when the deadline has expired.
func (c *Clock) ContextWithDeadline(ctx context.Context, t time.Time) (context.Context, context.CancelFunc) {
return c.ContextWithDeadlineCause(ctx, t, nil)
}

// ContextWithTimeout behaves like context.WithTimeout, but it uses the
// clock to determine the when the timeout has elapsed.
func (c *Clock) ContextWithTimeout(ctx context.Context, d time.Duration) (context.Context, context.CancelFunc) {
return c.ContextWithDeadlineCause(ctx, c.Now().Add(d), nil)
}

// ContextWithTimeoutCause behaves like context.WithTimeoutCause, but it
// uses the clock to determine the when the timeout has elapsed. Cause is
// ignored in Go 1.20 and earlier.
func (c *Clock) ContextWithTimeoutCause(ctx context.Context, d time.Duration, cause error) (context.Context, context.CancelFunc) {
return c.ContextWithDeadlineCause(ctx, c.Now().Add(d), cause)
}
36 changes: 36 additions & 0 deletions fake/fake_clock_121.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
//go:build go1.21

package fake

import (
"context"
"time"
)

// ContextWithDeadlineCause behaves like context.WithDeadlineCause, but it
// uses the clock to determine the when the deadline has expired. Cause is
// ignored in Go 1.20 and earlier.
func (f *Clock) ContextWithDeadlineCause(ctx context.Context, t time.Time, cause error) (context.Context, context.CancelFunc) {
cctx, cancelCause := context.WithCancelCause(ctx)
dctx := &deadlineContext{
Context: cctx,
deadline: t,
}
dur := f.Until(t)
if dur <= 0 {
dctx.timedOut.Store(true)
cancelCause(cause)
return dctx, func() {}
}
stop := f.AfterFunc(dur, func() {
if cctx.Err() == nil {
dctx.timedOut.Store(true)
}
cancelCause(cause)
})
cancel := func() {
cancelCause(context.Canceled)
stop.Stop()
}
return dctx, cancel
}
52 changes: 52 additions & 0 deletions fake/fake_clock_121_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package fake

import (
"context"
"errors"
"testing"
"time"
)

func TestFakeClockContext121(t *testing.T) {
t.Run("ContextWithDeadlineCause", func(t *testing.T) {
base := time.Now()
c := NewClock(base)

ctx, cancel := c.ContextWithDeadlineCause(context.Background(), base.Add(1), errors.New("test"))
t.Cleanup(cancel)

c.Advance(1)

select {
case <-ctx.Done():
if ctx.Err() != context.DeadlineExceeded {
t.Errorf("unexpected error: %v; expected %v", ctx.Err(), context.DeadlineExceeded)
}
if context.Cause(ctx) == nil || context.Cause(ctx).Error() != "test" {
t.Errorf("unexpected cause: %v; expected %v", context.Cause(ctx), "test")
}
case <-time.After(time.Second):
t.Errorf("context not done after 1 second")
}
})

t.Run("ContextWithTimeoutCause", func(t *testing.T) {
c := NewClock(time.Now())
ctx, cancel := c.ContextWithTimeoutCause(context.Background(), 1, errors.New("test"))
t.Cleanup(cancel)

c.Advance(1)

select {
case <-ctx.Done():
if ctx.Err() != context.DeadlineExceeded {
t.Errorf("unexpected error: %v; expected %v", ctx.Err(), context.DeadlineExceeded)
}
if context.Cause(ctx) == nil || context.Cause(ctx).Error() != "test" {
t.Errorf("unexpected cause: %v; expected %v", context.Cause(ctx), "test")
}
case <-time.After(time.Second):
t.Errorf("context not done after 1 second")
}
})
}
36 changes: 36 additions & 0 deletions fake/fake_clock_pre121.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
//go:build !go1.21

package fake

import (
"context"
"time"
)

// ContextWithDeadlineCause behaves like context.WithDeadlineCause, but it
// uses the clock to determine the when the deadline has expired. Cause is
// ignored in Go 1.20 and earlier.
func (f *Clock) ContextWithDeadlineCause(ctx context.Context, t time.Time, cause error) (context.Context, context.CancelFunc) {
cctx, cancel := context.WithCancel(ctx)
dctx := &deadlineContext{
Context: cctx,
deadline: t,
}
dur := f.Until(t)
if dur <= 0 {
dctx.timedOut.Store(true)
cancel()
return dctx, func() {}
}
stop := f.AfterFunc(dur, func() {
if cctx.Err() == nil {
dctx.timedOut.Store(true)
}
cancel()
})
cancelStop := func() {
cancel()
stop.Stop()
}
return dctx, cancelStop
}
Loading

0 comments on commit 71e0810

Please sign in to comment.