diff --git a/_example/go.mod b/_example/go.mod new file mode 100644 index 0000000..e24b12b --- /dev/null +++ b/_example/go.mod @@ -0,0 +1,12 @@ +module github.com/go-chi/httprate/_example + +go 1.22.5 + +replace github.com/go-chi/httprate => ../ + +require ( + github.com/go-chi/chi/v5 v5.1.0 + github.com/go-chi/httprate v0.0.0-00010101000000-000000000000 +) + +require github.com/cespare/xxhash/v2 v2.3.0 // indirect diff --git a/_example/go.sum b/_example/go.sum new file mode 100644 index 0000000..29685bd --- /dev/null +++ b/_example/go.sum @@ -0,0 +1,6 @@ +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/go-chi/chi/v5 v5.1.0 h1:acVI1TYaD+hhedDJ3r54HyA6sExp3HfXq7QWEEY/xMw= +github.com/go-chi/chi/v5 v5.1.0/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8= +golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= +golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= diff --git a/_example/main.go b/_example/main.go index e21afb4..70ebb8c 100644 --- a/_example/main.go +++ b/_example/main.go @@ -2,6 +2,7 @@ package main import ( "context" + "log" "net/http" "time" @@ -14,11 +15,6 @@ func main() { r := chi.NewRouter() r.Use(middleware.Logger) - // Overall rate-limiter, keyed by IP and URL path (aka endpoint). - // - // This means each user (by IP) will receive a unique limit counter per endpoint. - // r.Use(httprate.Limit(10, 10*time.Second, httprate.WithKeyFuncs(httprate.KeyByIP, httprate.KeyByEndpoint))) - r.Route("/admin", func(r chi.Router) { r.Use(func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -30,7 +26,7 @@ func main() { // Here we set a specific rate limit by ip address and userID r.Use(httprate.Limit( 10, - 10*time.Second, + time.Minute, httprate.WithKeyFuncs(httprate.KeyByIP, func(r *http.Request) (string, error) { token := r.Context().Value("userID").(string) return token, nil @@ -44,21 +40,27 @@ func main() { )) r.Get("/", func(w http.ResponseWriter, r *http.Request) { - w.Write([]byte("admin.")) + w.Write([]byte("10 req/min\n")) }) }) r.Group(func(r chi.Router) { - // Here we set another rate limit for a group of handlers. + // Here we set another rate limit (3 req/min) for a group of handlers. // // Note: in practice you don't need to have so many layered rate-limiters, // but the example here is to illustrate how to control the machinery. - r.Use(httprate.LimitByIP(3, 5*time.Second)) + r.Use(httprate.LimitByIP(3, time.Minute)) r.Get("/", func(w http.ResponseWriter, r *http.Request) { - w.Write([]byte(".")) + w.Write([]byte("3 req/min\n")) }) }) + log.Printf("Serving at localhost:3333") + log.Println() + log.Printf("Try running:") + log.Printf("curl -v http://localhost:3333") + log.Printf("curl -v http://localhost:3333/admin") + http.ListenAndServe(":3333", r) } diff --git a/local_counter_test.go b/local_counter_test.go index c6fa7a1..7f6fd32 100644 --- a/local_counter_test.go +++ b/local_counter_test.go @@ -12,14 +12,16 @@ import ( func TestLocalCounter(t *testing.T) { limitCounter := &localCounter{ - latestWindow: time.Now().UTC().Truncate(time.Second), + latestWindow: time.Now().UTC().Truncate(time.Minute), latestCounters: make(map[uint64]int), previousCounters: make(map[uint64]int), - windowLength: time.Second, + windowLength: time.Minute, } - currentWindow := time.Now().UTC().Truncate(time.Second) - previousWindow := currentWindow.Add(-time.Second) + limitCounter.Config(1000, time.Minute) + + currentWindow := time.Now().UTC().Truncate(time.Minute) + previousWindow := currentWindow.Add(-time.Minute) type test struct { name string // In each test do the following: @@ -31,67 +33,67 @@ func TestLocalCounter(t *testing.T) { tests := []test{ { - name: "t=0s: init", + name: "t=0m: init", prev: 0, curr: 0, }, { - name: "t=0s: increment 1", + name: "t=0m: increment 1", incrBy: 1, prev: 0, curr: 1, }, { - name: "t=0s: increment by 99", + name: "t=0m: increment by 99", incrBy: 99, prev: 0, curr: 100, }, { - name: "t=1s: move clock by 1s", - advanceTime: time.Second, + name: "t=1m: move clock by 1m", + advanceTime: time.Minute, prev: 100, curr: 0, }, { - name: "t=1s: increment by 20", + name: "t=1m: increment by 20", incrBy: 20, prev: 100, curr: 20, }, { - name: "t=1s: increment by 20", + name: "t=1m: increment by 20", incrBy: 20, prev: 100, curr: 40, }, { - name: "t=2s: move clock by 1s", - advanceTime: time.Second, + name: "t=2m: move clock by 1m", + advanceTime: time.Minute, prev: 40, curr: 0, }, { - name: "t=2s: incr++", + name: "t=2m: incr++", incrBy: 1, prev: 40, curr: 1, }, { - name: "t=2s: incr+=9", + name: "t=2m: incr+=9", incrBy: 9, prev: 40, curr: 10, }, { - name: "t=2s: incr+=20", + name: "t=2m: incr+=20", incrBy: 20, prev: 40, curr: 30, }, { - name: "t=4s: move clock by 2s", - advanceTime: 2 * time.Second, + name: "t=4m: move clock by 2m", + advanceTime: 2 * time.Minute, prev: 0, curr: 0, }, @@ -145,22 +147,22 @@ func TestLocalCounter(t *testing.T) { func BenchmarkLocalCounter(b *testing.B) { limitCounter := &localCounter{ - latestWindow: time.Now().UTC().Truncate(time.Second), + latestWindow: time.Now().UTC().Truncate(time.Minute), latestCounters: make(map[uint64]int), previousCounters: make(map[uint64]int), - windowLength: time.Second, + windowLength: time.Minute, } - currentWindow := time.Now().UTC().Truncate(time.Second) - previousWindow := currentWindow.Add(-time.Second) + currentWindow := time.Now().UTC().Truncate(time.Minute) + previousWindow := currentWindow.Add(-time.Minute) b.ResetTimer() for i := 0; i < b.N; i++ { for i := range []int{0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 3, 0, 0, 0, 0, 1, 0} { // Simulate time. - currentWindow.Add(time.Duration(i) * time.Second) - previousWindow.Add(time.Duration(i) * time.Second) + currentWindow.Add(time.Duration(i) * time.Minute) + previousWindow.Add(time.Duration(i) * time.Minute) wg := sync.WaitGroup{} wg.Add(1000)