Skip to content

Commit

Permalink
Improve localCounter performance and memory footprint (#34)
Browse files Browse the repository at this point in the history
  • Loading branch information
VojtechVitek authored Jul 25, 2024
1 parent c6b43ff commit bff9ca6
Show file tree
Hide file tree
Showing 4 changed files with 136 additions and 46 deletions.
15 changes: 15 additions & 0 deletions limit_key.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package httprate

import (
"fmt"
"time"

"github.com/cespare/xxhash/v2"
)

func LimitCounterKey(key string, window time.Time) uint64 {
h := xxhash.New()
h.WriteString(key)
h.WriteString(fmt.Sprintf("%d", window.Unix()))
return h.Sum64()
}
12 changes: 7 additions & 5 deletions limiter.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,10 @@ func newRateLimiter(requestLimit int, windowLength time.Duration, options ...Opt

if rl.limitCounter == nil {
rl.limitCounter = &localCounter{
counters: make(map[uint64]*count),
windowLength: windowLength,
latestWindow: time.Now().UTC().Truncate(windowLength),
latestCounters: make(map[uint64]int),
previousCounters: make(map[uint64]int),
windowLength: windowLength,
}
}
rl.limitCounter.Config(requestLimit, windowLength)
Expand Down Expand Up @@ -133,16 +135,16 @@ func (l *rateLimiter) Handler(next http.Handler) http.Handler {
}

func (l *rateLimiter) calculateRate(key string, requestLimit int) (bool, float64, error) {
t := time.Now().UTC()
currentWindow := t.Truncate(l.windowLength)
now := time.Now().UTC()
currentWindow := now.Truncate(l.windowLength)
previousWindow := currentWindow.Add(-l.windowLength)

currCount, prevCount, err := l.limitCounter.Get(key, currentWindow, previousWindow)
if err != nil {
return false, 0, err
}

diff := t.Sub(currentWindow)
diff := now.Sub(currentWindow)
rate := float64(prevCount)*(float64(l.windowLength)-float64(diff))/float64(l.windowLength) + float64(currCount)
if rate > float64(requestLimit) {
return false, rate, nil
Expand Down
67 changes: 30 additions & 37 deletions local_counter.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package httprate

import (
"fmt"
"sync"
"time"

Expand All @@ -11,15 +10,11 @@ import (
var _ LimitCounter = &localCounter{}

type localCounter struct {
counters map[uint64]*count
windowLength time.Duration
lastEvict time.Time
mu sync.RWMutex
}

type count struct {
value int
updatedAt time.Time
latestWindow time.Time
previousCounters map[uint64]int
latestCounters map[uint64]int
windowLength time.Duration
mu sync.RWMutex
}

func (c *localCounter) Config(requestLimit int, windowLength time.Duration) {
Expand All @@ -37,17 +32,12 @@ func (c *localCounter) IncrementBy(key string, currentWindow time.Time, amount i
c.mu.Lock()
defer c.mu.Unlock()

c.evict()
c.evict(currentWindow)

hkey := LimitCounterKey(key, currentWindow)
hkey := limitCounterKey(key, currentWindow)

v, ok := c.counters[hkey]
if !ok {
v = &count{}
c.counters[hkey] = v
}
v.value += amount
v.updatedAt = time.Now()
count, _ := c.latestCounters[hkey]
c.latestCounters[hkey] = count + amount

return nil
}
Expand All @@ -56,36 +46,39 @@ func (c *localCounter) Get(key string, currentWindow, previousWindow time.Time)
c.mu.RLock()
defer c.mu.RUnlock()

curr, ok := c.counters[LimitCounterKey(key, currentWindow)]
if !ok {
curr = &count{value: 0, updatedAt: time.Now()}
if c.latestWindow == currentWindow {
curr, _ := c.latestCounters[limitCounterKey(key, currentWindow)]
prev, _ := c.previousCounters[limitCounterKey(key, previousWindow)]
return curr, prev, nil
}
prev, ok := c.counters[LimitCounterKey(key, previousWindow)]
if !ok {
prev = &count{value: 0, updatedAt: time.Now()}

if c.latestWindow == previousWindow {
prev, _ := c.latestCounters[limitCounterKey(key, previousWindow)]
return 0, prev, nil
}

return curr.value, prev.value, nil
return 0, 0, nil
}

func (c *localCounter) evict() {
d := c.windowLength * 3

if time.Since(c.lastEvict) < d {
func (c *localCounter) evict(currentWindow time.Time) {
if c.latestWindow == currentWindow {
return
}
c.lastEvict = time.Now()

for k, v := range c.counters {
if time.Since(v.updatedAt) >= d {
delete(c.counters, k)
}
previousWindow := currentWindow.Add(-c.windowLength)
if c.latestWindow == previousWindow {
c.latestWindow = currentWindow
c.latestCounters, c.previousCounters = make(map[uint64]int), c.latestCounters
return
}

c.latestWindow = currentWindow
// NOTE: Don't use clear() to keep backward-compatibility.
c.previousCounters, c.latestCounters = make(map[uint64]int), make(map[uint64]int)
}

func LimitCounterKey(key string, window time.Time) uint64 {
func limitCounterKey(key string, window time.Time) uint64 {
h := xxhash.New()
h.WriteString(key)
h.WriteString(fmt.Sprintf("%d", window.Unix()))
return h.Sum64()
}
88 changes: 84 additions & 4 deletions local_counter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,94 @@ import (
"time"
)

func TestLocalCounter(t *testing.T) {
limitCounter := &localCounter{
latestWindow: time.Now().UTC().Truncate(time.Second),
latestCounters: make(map[uint64]int),
previousCounters: make(map[uint64]int),
windowLength: time.Second,
}

// Time = NOW()
currentWindow := time.Now().UTC().Truncate(time.Second)
previousWindow := currentWindow.Add(-time.Second)

for i := 0; i < 5; i++ {
curr, prev, _ := limitCounter.Get(fmt.Sprintf("key-%v", i), currentWindow, previousWindow)
if curr != 0 {
t.Errorf("unexpected curr = %v, expected %v", curr, 0)
}
if prev != 0 {
t.Errorf("unexpected prev = %v, expected %v", prev, 0)
}

_ = limitCounter.IncrementBy(fmt.Sprintf("key-%v", i), currentWindow, 1)
_ = limitCounter.IncrementBy(fmt.Sprintf("key-%v", i), currentWindow, 99)

curr, prev, _ = limitCounter.Get(fmt.Sprintf("key-%v", i), currentWindow, previousWindow)
if curr != 100 {
t.Errorf("unexpected curr = %v, expected %v", curr, 100)
}
if prev != 0 {
t.Errorf("unexpected prev = %v, expected %v", prev, 0)
}
}

// Time++
currentWindow = currentWindow.Add(time.Second)
previousWindow = previousWindow.Add(time.Second)

for i := 0; i < 5; i++ {
curr, prev, _ := limitCounter.Get(fmt.Sprintf("key-%v", i), currentWindow, previousWindow)
if curr != 0 {
t.Errorf("unexpected curr = %v, expected %v", curr, 0)
}
if prev != 100 {
t.Errorf("unexpected prev = %v, expected %v", prev, 100)
}
_ = limitCounter.IncrementBy(fmt.Sprintf("key-%v", i), currentWindow, 50)
}

// Time++
currentWindow = currentWindow.Add(time.Second)
previousWindow = previousWindow.Add(time.Second)

for i := 0; i < 5; i++ {
curr, prev, _ := limitCounter.Get(fmt.Sprintf("key-%v", i), currentWindow, previousWindow)
if curr != 0 {
t.Errorf("unexpected curr = %v, expected %v", curr, 0)
}
if prev != 50 {
t.Errorf("unexpected prev = %v, expected %v", prev, 50)
}
_ = limitCounter.IncrementBy(fmt.Sprintf("key-%v", i), currentWindow, 99)
}

// Time += 10
currentWindow = currentWindow.Add(10 * time.Second)
previousWindow = previousWindow.Add(10 * time.Second)

for i := 0; i < 5; i++ {
curr, prev, _ := limitCounter.Get(fmt.Sprintf("key-%v", i), currentWindow, previousWindow)
if curr != 0 {
t.Errorf("unexpected curr = %v, expected %v", curr, 0)
}
if prev != 0 {
t.Errorf("unexpected prev = %v, expected %v", prev, 0)
}
_ = limitCounter.IncrementBy(fmt.Sprintf("key-%v", i), currentWindow, 99)
}
}

func BenchmarkLocalCounter(b *testing.B) {
limitCounter := &localCounter{
counters: make(map[uint64]*count),
windowLength: time.Second,
latestWindow: time.Now().UTC().Truncate(time.Second),
latestCounters: make(map[uint64]int),
previousCounters: make(map[uint64]int),
windowLength: time.Second,
}

t := time.Now().UTC()
currentWindow := t.Truncate(time.Second)
currentWindow := time.Now().UTC().Truncate(time.Second)
previousWindow := currentWindow.Add(-time.Second)

b.ResetTimer()
Expand Down

0 comments on commit bff9ca6

Please sign in to comment.