diff --git a/go.mod b/go.mod index e296bd3..998cbf5 100644 --- a/go.mod +++ b/go.mod @@ -3,3 +3,5 @@ module github.com/go-chi/httprate go 1.17 require github.com/cespare/xxhash/v2 v2.3.0 + +require golang.org/x/sync v0.7.0 // indirect diff --git a/go.sum b/go.sum index eb7f94a..09aebbf 100644 --- a/go.sum +++ b/go.sum @@ -2,3 +2,5 @@ github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cb github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= +golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= diff --git a/local_counter_test.go b/local_counter_test.go index e19ce2f..c6fa7a1 100644 --- a/local_counter_test.go +++ b/local_counter_test.go @@ -6,6 +6,8 @@ import ( "sync" "testing" "time" + + "golang.org/x/sync/errgroup" ) func TestLocalCounter(t *testing.T) { @@ -16,74 +18,128 @@ func TestLocalCounter(t *testing.T) { windowLength: time.Second, } - // Time = NOW() currentWindow := time.Now().UTC().Truncate(time.Second) previousWindow := currentWindow.Add(-time.Second) - for i := 0; i < 5; i++ { - curr, prev, _ := limitCounter.Get(fmt.Sprintf("key-%v", i), currentWindow, previousWindow) - if curr != 0 { - t.Errorf("unexpected curr = %v, expected %v", curr, 0) - } - if prev != 0 { - t.Errorf("unexpected prev = %v, expected %v", prev, 0) - } - - _ = limitCounter.IncrementBy(fmt.Sprintf("key-%v", i), currentWindow, 1) - _ = limitCounter.IncrementBy(fmt.Sprintf("key-%v", i), currentWindow, 99) - - curr, prev, _ = limitCounter.Get(fmt.Sprintf("key-%v", i), currentWindow, previousWindow) - if curr != 100 { - t.Errorf("unexpected curr = %v, expected %v", curr, 100) - } - if prev != 0 { - t.Errorf("unexpected prev = %v, expected %v", prev, 0) - } + type test struct { + name string // In each test do the following: + advanceTime time.Duration // 1. advance time + incrBy int // 2. increase counter + prev int // 3. check previous window counter + curr int // and current window counter } - // Time++ - currentWindow = currentWindow.Add(time.Second) - previousWindow = previousWindow.Add(time.Second) - - for i := 0; i < 5; i++ { - curr, prev, _ := limitCounter.Get(fmt.Sprintf("key-%v", i), currentWindow, previousWindow) - if curr != 0 { - t.Errorf("unexpected curr = %v, expected %v", curr, 0) - } - if prev != 100 { - t.Errorf("unexpected prev = %v, expected %v", prev, 100) - } - _ = limitCounter.IncrementBy(fmt.Sprintf("key-%v", i), currentWindow, 50) + tests := []test{ + { + name: "t=0s: init", + prev: 0, + curr: 0, + }, + { + name: "t=0s: increment 1", + incrBy: 1, + prev: 0, + curr: 1, + }, + { + name: "t=0s: increment by 99", + incrBy: 99, + prev: 0, + curr: 100, + }, + { + name: "t=1s: move clock by 1s", + advanceTime: time.Second, + prev: 100, + curr: 0, + }, + { + name: "t=1s: increment by 20", + incrBy: 20, + prev: 100, + curr: 20, + }, + { + name: "t=1s: increment by 20", + incrBy: 20, + prev: 100, + curr: 40, + }, + { + name: "t=2s: move clock by 1s", + advanceTime: time.Second, + prev: 40, + curr: 0, + }, + { + name: "t=2s: incr++", + incrBy: 1, + prev: 40, + curr: 1, + }, + { + name: "t=2s: incr+=9", + incrBy: 9, + prev: 40, + curr: 10, + }, + { + name: "t=2s: incr+=20", + incrBy: 20, + prev: 40, + curr: 30, + }, + { + name: "t=4s: move clock by 2s", + advanceTime: 2 * time.Second, + prev: 0, + curr: 0, + }, } - // Time++ - currentWindow = currentWindow.Add(time.Second) - previousWindow = previousWindow.Add(time.Second) + concurrentRequests := 1000 - for i := 0; i < 5; i++ { - curr, prev, _ := limitCounter.Get(fmt.Sprintf("key-%v", i), currentWindow, previousWindow) - if curr != 0 { - t.Errorf("unexpected curr = %v, expected %v", curr, 0) + for _, tt := range tests { + if tt.advanceTime > 0 { + currentWindow = currentWindow.Add(tt.advanceTime) + previousWindow = previousWindow.Add(tt.advanceTime) } - if prev != 50 { - t.Errorf("unexpected prev = %v, expected %v", prev, 50) - } - _ = limitCounter.IncrementBy(fmt.Sprintf("key-%v", i), currentWindow, 99) - } - // Time += 10 - currentWindow = currentWindow.Add(10 * time.Second) - previousWindow = previousWindow.Add(10 * time.Second) + if tt.incrBy > 0 { + var g errgroup.Group + for i := 0; i < concurrentRequests; i++ { + i := i + g.Go(func() error { + key := fmt.Sprintf("key:%v", i) + return limitCounter.IncrementBy(key, currentWindow, tt.incrBy) + }) + } + if err := g.Wait(); err != nil { + t.Errorf("%s: %v", tt.name, err) + } + } - for i := 0; i < 5; i++ { - curr, prev, _ := limitCounter.Get(fmt.Sprintf("key-%v", i), currentWindow, previousWindow) - if curr != 0 { - t.Errorf("unexpected curr = %v, expected %v", curr, 0) + var g errgroup.Group + for i := 0; i < concurrentRequests; i++ { + i := i + g.Go(func() error { + key := fmt.Sprintf("key:%v", i) + curr, prev, err := limitCounter.Get(key, currentWindow, previousWindow) + if err != nil { + return fmt.Errorf("%q: %w", key, err) + } + if curr != tt.curr { + return fmt.Errorf("%q: unexpected curr = %v, expected %v", key, curr, tt.curr) + } + if prev != tt.prev { + return fmt.Errorf("%q: unexpected prev = %v, expected %v", key, prev, tt.prev) + } + return nil + }) } - if prev != 0 { - t.Errorf("unexpected prev = %v, expected %v", prev, 0) + if err := g.Wait(); err != nil { + t.Errorf("%s: %v", tt.name, err) } - _ = limitCounter.IncrementBy(fmt.Sprintf("key-%v", i), currentWindow, 99) } }