From 061de00b26c5387675f434c59ae0ee2fd875e70a Mon Sep 17 00:00:00 2001 From: Vojtech Vitek Date: Wed, 24 Jul 2024 22:04:20 +0200 Subject: [PATCH] Improve localCounter performance and memory footprint --- limit_key.go | 15 ++++++++ limiter.go | 12 +++--- local_counter.go | 67 +++++++++++++++----------------- local_counter_test.go | 88 +++++++++++++++++++++++++++++++++++++++++-- 4 files changed, 136 insertions(+), 46 deletions(-) create mode 100644 limit_key.go diff --git a/limit_key.go b/limit_key.go new file mode 100644 index 0000000..dbc6966 --- /dev/null +++ b/limit_key.go @@ -0,0 +1,15 @@ +package httprate + +import ( + "fmt" + "time" + + "github.com/cespare/xxhash/v2" +) + +func LimitCounterKey(key string, window time.Time) uint64 { + h := xxhash.New() + h.WriteString(key) + h.WriteString(fmt.Sprintf("%d", window.Unix())) + return h.Sum64() +} diff --git a/limiter.go b/limiter.go index 0fccef8..12dad3d 100644 --- a/limiter.go +++ b/limiter.go @@ -44,8 +44,10 @@ func newRateLimiter(requestLimit int, windowLength time.Duration, options ...Opt if rl.limitCounter == nil { rl.limitCounter = &localCounter{ - counters: make(map[uint64]*count), - windowLength: windowLength, + latestWindow: time.Now().UTC().Truncate(windowLength), + latestCounters: make(map[uint64]int), + previousCounters: make(map[uint64]int), + windowLength: windowLength, } } rl.limitCounter.Config(requestLimit, windowLength) @@ -133,8 +135,8 @@ func (l *rateLimiter) Handler(next http.Handler) http.Handler { } func (l *rateLimiter) calculateRate(key string, requestLimit int) (bool, float64, error) { - t := time.Now().UTC() - currentWindow := t.Truncate(l.windowLength) + now := time.Now().UTC() + currentWindow := now.Truncate(l.windowLength) previousWindow := currentWindow.Add(-l.windowLength) currCount, prevCount, err := l.limitCounter.Get(key, currentWindow, previousWindow) @@ -142,7 +144,7 @@ func (l *rateLimiter) calculateRate(key string, requestLimit int) (bool, float64 return false, 0, err } - diff := t.Sub(currentWindow) + diff := now.Sub(currentWindow) rate := float64(prevCount)*(float64(l.windowLength)-float64(diff))/float64(l.windowLength) + float64(currCount) if rate > float64(requestLimit) { return false, rate, nil diff --git a/local_counter.go b/local_counter.go index 7c49858..c071011 100644 --- a/local_counter.go +++ b/local_counter.go @@ -1,7 +1,6 @@ package httprate import ( - "fmt" "sync" "time" @@ -11,15 +10,11 @@ import ( var _ LimitCounter = &localCounter{} type localCounter struct { - counters map[uint64]*count - windowLength time.Duration - lastEvict time.Time - mu sync.RWMutex -} - -type count struct { - value int - updatedAt time.Time + latestWindow time.Time + previousCounters map[uint64]int + latestCounters map[uint64]int + windowLength time.Duration + mu sync.RWMutex } func (c *localCounter) Config(requestLimit int, windowLength time.Duration) { @@ -37,17 +32,12 @@ func (c *localCounter) IncrementBy(key string, currentWindow time.Time, amount i c.mu.Lock() defer c.mu.Unlock() - c.evict() + c.evict(currentWindow) - hkey := LimitCounterKey(key, currentWindow) + hkey := limitCounterKey(key, currentWindow) - v, ok := c.counters[hkey] - if !ok { - v = &count{} - c.counters[hkey] = v - } - v.value += amount - v.updatedAt = time.Now() + count, _ := c.latestCounters[hkey] + c.latestCounters[hkey] = count + amount return nil } @@ -56,36 +46,39 @@ func (c *localCounter) Get(key string, currentWindow, previousWindow time.Time) c.mu.RLock() defer c.mu.RUnlock() - curr, ok := c.counters[LimitCounterKey(key, currentWindow)] - if !ok { - curr = &count{value: 0, updatedAt: time.Now()} + if c.latestWindow == currentWindow { + curr, _ := c.latestCounters[limitCounterKey(key, currentWindow)] + prev, _ := c.previousCounters[limitCounterKey(key, previousWindow)] + return curr, prev, nil } - prev, ok := c.counters[LimitCounterKey(key, previousWindow)] - if !ok { - prev = &count{value: 0, updatedAt: time.Now()} + + if c.latestWindow == previousWindow { + prev, _ := c.latestCounters[limitCounterKey(key, previousWindow)] + return 0, prev, nil } - return curr.value, prev.value, nil + return 0, 0, nil } -func (c *localCounter) evict() { - d := c.windowLength * 3 - - if time.Since(c.lastEvict) < d { +func (c *localCounter) evict(currentWindow time.Time) { + if c.latestWindow == currentWindow { return } - c.lastEvict = time.Now() - for k, v := range c.counters { - if time.Since(v.updatedAt) >= d { - delete(c.counters, k) - } + previousWindow := currentWindow.Add(-c.windowLength) + if c.latestWindow == previousWindow { + c.latestWindow = currentWindow + c.latestCounters, c.previousCounters = make(map[uint64]int), c.latestCounters + return } + + c.latestWindow = currentWindow + // NOTE: Don't use clear() to keep backward-compatibility. + c.previousCounters, c.latestCounters = make(map[uint64]int), make(map[uint64]int) } -func LimitCounterKey(key string, window time.Time) uint64 { +func limitCounterKey(key string, window time.Time) uint64 { h := xxhash.New() h.WriteString(key) - h.WriteString(fmt.Sprintf("%d", window.Unix())) return h.Sum64() } diff --git a/local_counter_test.go b/local_counter_test.go index 4d9a71e..e19ce2f 100644 --- a/local_counter_test.go +++ b/local_counter_test.go @@ -8,14 +8,94 @@ import ( "time" ) +func TestLocalCounter(t *testing.T) { + limitCounter := &localCounter{ + latestWindow: time.Now().UTC().Truncate(time.Second), + latestCounters: make(map[uint64]int), + previousCounters: make(map[uint64]int), + windowLength: time.Second, + } + + // Time = NOW() + currentWindow := time.Now().UTC().Truncate(time.Second) + previousWindow := currentWindow.Add(-time.Second) + + for i := 0; i < 5; i++ { + curr, prev, _ := limitCounter.Get(fmt.Sprintf("key-%v", i), currentWindow, previousWindow) + if curr != 0 { + t.Errorf("unexpected curr = %v, expected %v", curr, 0) + } + if prev != 0 { + t.Errorf("unexpected prev = %v, expected %v", prev, 0) + } + + _ = limitCounter.IncrementBy(fmt.Sprintf("key-%v", i), currentWindow, 1) + _ = limitCounter.IncrementBy(fmt.Sprintf("key-%v", i), currentWindow, 99) + + curr, prev, _ = limitCounter.Get(fmt.Sprintf("key-%v", i), currentWindow, previousWindow) + if curr != 100 { + t.Errorf("unexpected curr = %v, expected %v", curr, 100) + } + if prev != 0 { + t.Errorf("unexpected prev = %v, expected %v", prev, 0) + } + } + + // Time++ + currentWindow = currentWindow.Add(time.Second) + previousWindow = previousWindow.Add(time.Second) + + for i := 0; i < 5; i++ { + curr, prev, _ := limitCounter.Get(fmt.Sprintf("key-%v", i), currentWindow, previousWindow) + if curr != 0 { + t.Errorf("unexpected curr = %v, expected %v", curr, 0) + } + if prev != 100 { + t.Errorf("unexpected prev = %v, expected %v", prev, 100) + } + _ = limitCounter.IncrementBy(fmt.Sprintf("key-%v", i), currentWindow, 50) + } + + // Time++ + currentWindow = currentWindow.Add(time.Second) + previousWindow = previousWindow.Add(time.Second) + + for i := 0; i < 5; i++ { + curr, prev, _ := limitCounter.Get(fmt.Sprintf("key-%v", i), currentWindow, previousWindow) + if curr != 0 { + t.Errorf("unexpected curr = %v, expected %v", curr, 0) + } + if prev != 50 { + t.Errorf("unexpected prev = %v, expected %v", prev, 50) + } + _ = limitCounter.IncrementBy(fmt.Sprintf("key-%v", i), currentWindow, 99) + } + + // Time += 10 + currentWindow = currentWindow.Add(10 * time.Second) + previousWindow = previousWindow.Add(10 * time.Second) + + for i := 0; i < 5; i++ { + curr, prev, _ := limitCounter.Get(fmt.Sprintf("key-%v", i), currentWindow, previousWindow) + if curr != 0 { + t.Errorf("unexpected curr = %v, expected %v", curr, 0) + } + if prev != 0 { + t.Errorf("unexpected prev = %v, expected %v", prev, 0) + } + _ = limitCounter.IncrementBy(fmt.Sprintf("key-%v", i), currentWindow, 99) + } +} + func BenchmarkLocalCounter(b *testing.B) { limitCounter := &localCounter{ - counters: make(map[uint64]*count), - windowLength: time.Second, + latestWindow: time.Now().UTC().Truncate(time.Second), + latestCounters: make(map[uint64]int), + previousCounters: make(map[uint64]int), + windowLength: time.Second, } - t := time.Now().UTC() - currentWindow := t.Truncate(time.Second) + currentWindow := time.Now().UTC().Truncate(time.Second) previousWindow := currentWindow.Add(-time.Second) b.ResetTimer()