diff --git a/clock.go b/clock.go index c5b5475..87953db 100644 --- a/clock.go +++ b/clock.go @@ -62,6 +62,11 @@ func (c defaultClock) ContextWithTimeout(ctx context.Context, d time.Duration) ( return context.WithTimeout(ctx, d) } +func (c defaultClock) NewTimer(d time.Duration) Timer { + t := time.NewTimer(d) + return &defaultTimer{Timer: t} +} + // DefaultClock returns a clock that minimally wraps the `time` package func DefaultClock() Clock { return defaultClock{} @@ -103,4 +108,14 @@ type Clock interface { // uses the clock to determine the when the timeout has elapsed. Cause is // ignored in Go 1.20 and earlier. ContextWithTimeoutCause(ctx context.Context, d time.Duration, cause error) (context.Context, context.CancelFunc) + + // NewTimer returns a Timer implementation which will fire after at + // least the specified duration [d]. The Ch() method returns a channel, + // and should be called inline with the receive or select case. + // + // Timers are most useful in select/case blocks. For simple cases, + // SleepFor should be preferred. + // + // Stop() is inherently racy. Be wary of the return value. + NewTimer(d time.Duration) Timer } diff --git a/fake/fake_clock.go b/fake/fake_clock.go index 093d666..89923eb 100644 --- a/fake/fake_clock.go +++ b/fake/fake_clock.go @@ -13,7 +13,6 @@ import ( // testing and skipping through timestamps without having to actually sleep in // the test. type Clock struct { - mu sync.Mutex current time.Time // sleepers contains a map from a channel on which that // sleeper is sleeping to a target-time. When time is advanced past a @@ -28,9 +27,8 @@ type Clock struct { // protection necessary). cbsWG sync.WaitGroup - // cond is broadcasted() upon any sleep or wakeup event (mutations to - // sleepers or cbs). - cond sync.Cond + // timer tracker + timerTrack timerTracker // counter tracking the number of wakeups (protected by mu). wakeups int @@ -51,6 +49,21 @@ type Clock struct { // counter tracking the number of callbacks that have ever been // registered (via AfterFunc) (protected by mu). callbacksAggregate int + + // counter tracking the number of extracted channels (protected by mu). + extractedChans int + + // counter tracking the aggregate number of extracted channels (protected by mu). + extractedChansAggregate int + + // counter tracking the number of number of aggregate signaled timer channels + signaledChans int + + // cond is broadcasted() upon any sleep or wakeup event (mutations to + // sleepers or cbs). + cond sync.Cond + + mu sync.Mutex } var _ clocks.Clock = (*Clock)(nil) @@ -62,7 +75,11 @@ func NewClock(initialTime time.Time) *Clock { sleepers: map[chan<- struct{}]time.Time{}, cbs: map[*stopTimer]time.Time{}, cond: sync.Cond{}, + timerTrack: timerTracker{ + timers: map[*fakeTimer]time.Time{}, + }, } + fc.timerTrack.fc = &fc fc.cond.L = &fc.mu return &fc } @@ -77,6 +94,10 @@ func (f *Clock) setClockLocked(t time.Time, cbRunningWG *sync.WaitGroup) int { awoken++ } } + + timerWakeRes := f.timerTrack.wakeup(t) + f.signaledChans += timerWakeRes.notified + cbsRun := 0 for s, target := range f.cbs { if target.Sub(t) <= 0 { @@ -95,7 +116,7 @@ func (f *Clock) setClockLocked(t time.Time, cbRunningWG *sync.WaitGroup) int { f.callbackExecs += cbsRun f.current = t f.cond.Broadcast() - return awoken + cbsRun + return awoken + cbsRun + timerWakeRes.awoken } // SetClock skips the FakeClock to the specified time (forward or backwards) The @@ -344,6 +365,22 @@ func (f *Clock) AfterFunc(d time.Duration, cb func()) clocks.StopTimer { return s } +// NewTimer creates a new Timer +func (f *Clock) NewTimer(d time.Duration) clocks.Timer { + target := f.Now().Add(d) + // Capacity 1 so sending never blocks + ch := make(chan time.Time, 1) + + ft := fakeTimer{ + ch: ch, + tracker: &f.timerTrack, + } + + f.timerTrack.registerTimer(&ft, target) + + return &ft +} + // NumCallbackExecs returns the number of registered callbacks that have been // executed due to time advancement. func (f *Clock) NumCallbackExecs() int { @@ -396,8 +433,8 @@ func (f *Clock) AwaitRegisteredCallbacks(n int) { } } -// AwaitTimerAborts waits until the aggregate number of registered callbacks -// (via AfterFunc) exceeds its argument. +// AwaitTimerAborts waits until the aggregate number of aborted callbacks +// (via AfterFunc) or timers exceeds its argument. func (f *Clock) AwaitTimerAborts(n int) { f.mu.Lock() defer f.mu.Unlock() @@ -406,6 +443,53 @@ func (f *Clock) AwaitTimerAborts(n int) { } } +// AwaitAggExtractedChans waits the aggregate number of calls to Ch() on +// timers to equal or exceed its argument. +// For this method to be most useful, users of timers should not store the +// value of .Ch(). Instead, call .Ch(), dereference the pointer, and attempt a +// receive immediately, as in case <-*timer.Ch(). +func (f *Clock) AwaitAggExtractedChans(n int) { + f.mu.Lock() + defer f.mu.Unlock() + for f.extractedChansAggregate < n { + f.cond.Wait() + } +} + +// NumAggExtractedChans returns the aggregate number of calls to Ch() on +// timers. +// For this method to be most useful, users of timers should not store the +// value of .Ch(). Instead, call .Ch(), dereference the pointer, and attempt a +// receive immediately, as in case <-*timer.Ch(). +func (f *Clock) NumAggExtractedChans() int { + f.mu.Lock() + defer f.mu.Unlock() + return f.extractedChansAggregate +} + +// numExtractedChans returns the aggregate number of calls to Ch() on +// timers. +func (f *Clock) numExtractedChans() int { + f.mu.Lock() + defer f.mu.Unlock() + return f.extractedChans +} + +// awaitExtractedChans waits the number of calls to Ch() on +// timers to equal or exceed its argument. +func (f *Clock) awaitExtractedChans(n int) { + f.mu.Lock() + defer f.mu.Unlock() + for f.extractedChans < n { + f.cond.Wait() + } +} + +// RegisteredTimers returns the execution-times of registered timers. +func (f *Clock) RegisteredTimers() []time.Time { + return f.timerTrack.registeredTimers() +} + // WaitAfterFuncs blocks until all currently running AfterFunc callbacks // return. func (f *Clock) WaitAfterFuncs() { diff --git a/fake/fake_clock_test.go b/fake/fake_clock_test.go index d0fb3fa..8a98b5c 100644 --- a/fake/fake_clock_test.go +++ b/fake/fake_clock_test.go @@ -224,6 +224,272 @@ func TestFakeClockWithRelativeWaiter(t *testing.T) { } } +func TestFakeClockWithTimer(t *testing.T) { + t.Parallel() + + baseTime := time.Now() + fc := NewClock(baseTime) + + expectedTime := baseTime + + if fn := fc.Now(); !fn.Equal(baseTime) { + t.Errorf("mismatched baseTime(%s) and unincremented Now()(%s)", baseTime, fn) + } + if wakers := fc.Advance(time.Minute); wakers != 0 { + t.Errorf("unexpected wakers from advancing 1 minute(%d); expected 0", wakers) + } + expectedTime = expectedTime.Add(time.Minute) + + if fn := fc.Now(); !fn.Equal(expectedTime) { + t.Errorf("mismatched baseTime(%s) and unincremented Now()(%s)", expectedTime, fn) + } + + sleeperWake := expectedTime.Add(time.Hour * 2) + timer := fc.NewTimer(time.Hour * 2) + ch := make(chan struct{}) + go func() { + <-*timer.Ch() + ch <- struct{}{} + }() + + fc.AwaitAggExtractedChans(1) + + if sl := fc.NumSleepers(); sl != 0 { + t.Errorf("unexpected sleeper-count: %d; expected 0", sl) + } + if sl := fc.Sleepers(); len(sl) != 0 { + t.Errorf("unexpected sleeper-count: %d; expected 0", len(sl)) + } + + if as := fc.NumAggExtractedChans(); as != 1 { + t.Errorf("unexpected number of aggregate aggregate extracted channels: %d; expected 1", as) + } + + if regTimers := fc.RegisteredTimers(); len(regTimers) != 1 { + t.Errorf("unexpected registered timer-count %d; %v", len(regTimers), regTimers) + } + + // make sure we're still sleeping + select { + case <-ch: + t.Errorf("sleeper finished unexpectedly early") + default: + } + + expectedTime = expectedTime.Add(time.Hour) + if wakers := fc.SetClock(expectedTime); wakers != 0 { + t.Errorf("unexpected wakers from advancing 1 hour(%d); expected 0", wakers) + } + + // make sure we're still sleeping after the SetClock call + select { + case <-ch: + t.Errorf("sleeper finished unexpectedly early") + default: + } + + fc.awaitExtractedChans(1) + // verify that our one sleeper is still sleeping + if sl := fc.numExtractedChans(); sl != 1 { + t.Errorf("unexpected extracted channel-count: %d; ", sl) + } + + if regTimers := fc.RegisteredTimers(); len(regTimers) != 1 { + t.Errorf("unexpected registered timer-count %d; %v", len(regTimers), regTimers) + } + + // advance to our wakeup point + expectedTime = sleeperWake + if wakers := fc.SetClock(sleeperWake); wakers != 1 { + t.Errorf("unexpected wakers from advancing 1 hour(%d); expected 1", wakers) + } + + // wait for our sleeper to wake and return (expected true) + <-ch + + if wu := fc.Wakeups(); wu != 0 { + t.Errorf("unexpected wakeup-count: %d; expected 0", wu) + } + if sa := fc.NumSleepAborts(); sa != 0 { + t.Errorf("unexpected sleep abort-count: %d; expected 0", sa) + } + + if sa := fc.NumTimerAborts(); sa != 0 { + t.Errorf("unexpected timer abort-count: %d; expected 0", sa) + } +} + +// Test that cancels a timer and advances the clock in parallel to guarantee that timer operations +// are correctly synchronized. (mostly useful when run with -race) +func TestFakeClockWithTimerStopRace(t *testing.T) { + t.Parallel() + + baseTime := time.Now() + fc := NewClock(baseTime) + // Setup a channel for us to close when we're done and wake the sleeping goroutine. + testWakeCh := make(chan struct{}) + + sleeperWake := baseTime.Add(time.Hour * 2) + timer := fc.NewTimer(time.Hour * 2) + ch := make(chan struct{}) + go func() { + select { + case <-*timer.Ch(): + case <-testWakeCh: + } + ch <- struct{}{} + }() + + fc.AwaitAggExtractedChans(1) + + go timer.Stop() + go fc.SetClock(sleeperWake) + + close(testWakeCh) + + <-ch + +} + +func TestFakeClockWithTimerWithStop(t *testing.T) { + t.Parallel() + + baseTime := time.Now() + fc := NewClock(baseTime) + + expectedTime := baseTime + + if fn := fc.Now(); !fn.Equal(baseTime) { + t.Errorf("mismatched baseTime(%s) and unincremented Now()(%s)", baseTime, fn) + } + if wakers := fc.Advance(time.Minute); wakers != 0 { + t.Errorf("unexpected wakers from advancing 1 minute(%d); expected 0", wakers) + } + expectedTime = expectedTime.Add(time.Minute) + + if fn := fc.Now(); !fn.Equal(expectedTime) { + t.Errorf("mismatched baseTime(%s) and unincremented Now()(%s)", expectedTime, fn) + } + + // Setup a channel for us to close when we're done and wake the sleeping goroutine. + testWakeCh := make(chan struct{}) + + sleeperWake := expectedTime.Add(time.Hour * 2) + timer := fc.NewTimer(time.Hour * 2) + ch := make(chan bool) + go func() { + select { + case <-*timer.Ch(): + ch <- true + case <-testWakeCh: + ch <- false + } + }() + + fc.AwaitAggExtractedChans(1) + + if sl := fc.NumSleepers(); sl != 0 { + t.Errorf("unexpected sleeper-count: %d; expected 1", sl) + } + if sl := fc.Sleepers(); len(sl) != 0 { + t.Errorf("unexpected sleeper-count: %d; expected 1", len(sl)) + } + + if as := fc.NumAggExtractedChans(); as != 1 { + t.Errorf("unexpected number of aggregate aggregate extracted channels: %d; expected 1", as) + } + + if regTimers := fc.RegisteredTimers(); len(regTimers) != 1 { + t.Errorf("unexpected registered timer-count %d; %v", len(regTimers), regTimers) + } + + // make sure we're still sleeping + select { + case <-ch: + t.Errorf("sleeper finished unexpectedly early") + default: + } + + // Set up a goroutine to awaken once we have an aborted timer (our call to .Stop() further down). + stopWaitRunning := make(chan struct{}) + stopWaitCh := make(chan struct{}) + go func() { + close(stopWaitRunning) + fc.AwaitTimerAborts(1) + stopWaitCh <- struct{}{} + }() + // Make sure that the goroutine calling AwaitTimerAborts is running before proceeding + <-stopWaitRunning + + expectedTime = expectedTime.Add(time.Hour) + if wakers := fc.SetClock(expectedTime); wakers != 0 { + t.Errorf("unexpected wakers from advancing 1 hour(%d); expected 0", wakers) + } + + // make sure we're still sleeping after the SetClock call + select { + case <-ch: + t.Errorf("sleeper finished unexpectedly early") + default: + } + select { + case <-stopWaitCh: + t.Errorf("timer abort watching goroutine awoke unexpectedly early (SetClock should not wake AwaitTimerAborts)") + default: + } + + // verify that our one sleeper is still sleeping + if sl := fc.numExtractedChans(); sl != 1 { + t.Errorf("unexpected extracted channel-count: %d; ", sl) + } + + if regTimers := fc.RegisteredTimers(); len(regTimers) != 1 { + t.Errorf("unexpected registered timer-count %d; %v", len(regTimers), regTimers) + } + + if !timer.Stop() { + t.Errorf("Stop indicated it didn't prevent firing (false), the clock hasn't advanced far enough") + } + + select { + case <-ch: + t.Errorf("sleeper finished unexpectedly early (Stop should not wake)") + default: + } + + // wait for the timer aborts to wake up and finish + <-stopWaitCh + + // advance to our wakeup point + expectedTime = sleeperWake + if wakers := fc.SetClock(sleeperWake); wakers != 0 { + t.Errorf("unexpected wakers from advancing 1 hour(%d); expected 0", wakers) + } + + select { + case <-ch: + t.Errorf("stopped timer-based sleeper finished unexpectedly early") + default: + } + + close(testWakeCh) + // wait for our sleeper to wake and return (expected true) + if wokeByTimer := <-ch; wokeByTimer { + t.Errorf("unexpected wake reason: timer fired, expected close of testWakeCh") + } + + if wu := fc.Wakeups(); wu != 0 { + t.Errorf("unexpected wakeup-count: %d; expected 0", wu) + } + if sa := fc.NumSleepAborts(); sa != 0 { + t.Errorf("unexpected sleep abort-count: %d; expected 0", sa) + } + + if sa := fc.NumTimerAborts(); sa != 1 { + t.Errorf("unexpected timer abort-count: %d; expected 1", sa) + } +} + func TestFakeClockWithRelativeWaiterWithCancel(t *testing.T) { t.Parallel() diff --git a/fake/fake_timer.go b/fake/fake_timer.go new file mode 100644 index 0000000..316d792 --- /dev/null +++ b/fake/fake_timer.go @@ -0,0 +1,181 @@ +package fake + +import ( + "runtime" + "sync" + "sync/atomic" + "time" +) + +type timerTracker struct { + // backpointer to the parent clock + fc *Clock + + timers map[*fakeTimer]time.Time + + mu sync.Mutex +} + +type wakeRes struct { + notified int + awoken int +} + +// Returns the number of timers that were notified, followed by how many were +// actually awoken (or at least a best-guess). +func (t *timerTracker) wakeup(now time.Time) wakeRes { + t.mu.Lock() + defer t.mu.Unlock() + out := wakeRes{} + for ft, ttim := range t.timers { + // use After to reflect <= rather than <. + if ttim.After(now) { + continue + } + wres := ft.wake(now) + if wres.wasStopped { + continue + } + if wres.signaled { + out.awoken++ + out.notified++ + } + delete(t.timers, ft) + } + return out +} + +func (t *timerTracker) registerTimer(ft *fakeTimer, wakeTime time.Time) { + t.mu.Lock() + defer t.mu.Unlock() + t.timers[ft] = wakeTime +} + +// returns true if the timer was previously present +func (t *timerTracker) remove(ft *fakeTimer) bool { + t.mu.Lock() + defer t.mu.Unlock() + _, present := t.timers[ft] + delete(t.timers, ft) + return present +} + +func (t *timerTracker) registeredTimers() []time.Time { + t.mu.Lock() + defer t.mu.Unlock() + out := make([]time.Time, 0, len(t.timers)) + for _, t := range t.timers { + out = append(out, t) + } + return out +} + +type refCnt struct { + i atomic.Int32 +} + +func (r *refCnt) inc() { + r.i.Add(1) +} + +func (r *refCnt) dec() { + if v := r.i.Add(-1); v < 0 { + panic("negative refcount") + } +} + +func (r *refCnt) val() int32 { + return r.i.Load() +} + +type fakeTimer struct { + // backreference to tracker + tracker *timerTracker + + ch chan time.Time + + fired atomic.Bool + stopped atomic.Bool + + ptrExtracted refCnt +} + +type ftWakeState struct { + ptrWasExtracted bool + wasStopped bool + signaled bool +} + +func (f *fakeTimer) wake(now time.Time) ftWakeState { + stopped := f.stopped.Load() + f.fired.Store(true) + if stopped { + return ftWakeState{ + ptrWasExtracted: false, + wasStopped: stopped, + signaled: false, + } + } + + select { + case f.ch <- now: + return ftWakeState{ + ptrWasExtracted: f.ptrExtracted.val() > 0, + wasStopped: false, + signaled: true, + } + default: + return ftWakeState{ + ptrWasExtracted: f.ptrExtracted.val() > 0, + wasStopped: false, + signaled: false, + } + } +} + +// Stop attempts to prevent a timer from firing, returning true if it +// succeeds in preventing the timer from firing, and false if it +// already fired. +func (f *fakeTimer) Stop() bool { + f.tracker.remove(f) + f.stopped.Store(true) + fired := f.fired.Load() + + f.tracker.fc.mu.Lock() + defer f.tracker.fc.mu.Unlock() + f.tracker.fc.timerAborts++ + f.tracker.fc.cond.Broadcast() + return !fired +} + +// Ch returns a pointer to the channel for the timer +func (f *fakeTimer) Ch() *<-chan time.Time { + f.ptrExtracted.inc() + + // Define this callback before chWrapper so it doesn't hold a reference + // and prevent the finalizer from running + chFin := func(any) { + f.ptrExtracted.dec() + f.tracker.fc.mu.Lock() + + defer f.tracker.fc.mu.Unlock() + f.tracker.fc.extractedChans-- + f.tracker.fc.cond.Broadcast() + } + + chWrapper := struct { + ch <-chan time.Time + }{ + ch: f.ch, + } + + runtime.SetFinalizer(&chWrapper, chFin) + + f.tracker.fc.mu.Lock() + defer f.tracker.fc.mu.Unlock() + f.tracker.fc.extractedChans++ + f.tracker.fc.extractedChansAggregate++ + + f.tracker.fc.cond.Broadcast() + return &chWrapper.ch +} diff --git a/offset/offset_clock.go b/offset/offset_clock.go index 9a4ffea..fe23517 100644 --- a/offset/offset_clock.go +++ b/offset/offset_clock.go @@ -49,6 +49,12 @@ func (o *Clock) AfterFunc(d time.Duration, f func()) clocks.StopTimer { return o.inner.AfterFunc(d, f) } +// NewTimer returns a timer from the wrapped clock's [clocks.Clock.NewTimer]. +func (o *Clock) NewTimer(d time.Duration) clocks.Timer { + // relative time, so nothing to do here, just delegate on down. + return o.inner.NewTimer(d) +} + // ContextWithDeadline behaves like context.WithDeadline, but it uses the // clock to determine the when the deadline has expired. func (o *Clock) ContextWithDeadline(ctx context.Context, t time.Time) (context.Context, context.CancelFunc) { @@ -69,3 +75,5 @@ func NewOffsetClock(inner clocks.Clock, offset time.Duration) *Clock { offset: offset, } } + +var _ clocks.Clock = (*Clock)(nil) diff --git a/timer.go b/timer.go index e882cfb..b7f4acc 100644 --- a/timer.go +++ b/timer.go @@ -1,5 +1,7 @@ package clocks +import "time" + // StopTimer exposes a `Stop()` method for an equivalent object to time.Timer // (in the defaultClock case, it may be an actual time.Timer). type StopTimer interface { @@ -8,3 +10,22 @@ type StopTimer interface { // already fired. Stop() bool } + +// Timer exposes methods for an equivalent object to time.Timer +// (in the defaultClock case, it may be an actual time.Timer) +type Timer interface { + StopTimer + // Ch returns the channel for the timer + // For the fake clock's tracking to work, one must always dereference + // the channel returned by this method directly, as it relies on a + // finalizer to figure out when a listening goroutine woke up. + Ch() *<-chan time.Time +} + +type defaultTimer struct { + *time.Timer +} + +func (d *defaultTimer) Ch() *<-chan time.Time { + return &d.C +}