From 2359bb2d273d6a1ff63e1898ad4054fac0cb7ff6 Mon Sep 17 00:00:00 2001 From: luk3skyw4lker Date: Thu, 11 Jul 2024 16:26:44 -0300 Subject: [PATCH] feat: add max calculator to limiter middleware --- middleware/limiter/config.go | 12 ++++++++ middleware/limiter/limiter_fixed.go | 14 +++++---- middleware/limiter/limiter_test.go | 48 +++++++++++++++++++++++++++++ 3 files changed, 68 insertions(+), 6 deletions(-) diff --git a/middleware/limiter/config.go b/middleware/limiter/config.go index 3045282eec..e7ccaa7b13 100644 --- a/middleware/limiter/config.go +++ b/middleware/limiter/config.go @@ -18,6 +18,13 @@ type Config struct { // Default: 5 Max int + // A function to dynamically calculate the max requests supported by the rate limiter middleware + // + // Default: func(c fiber.Ctx) int { + // return c.Max + // } + MaxCalculator func(c fiber.Ctx) int + // KeyGenerator allows you to generate custom keys, by default c.IP() is used // // Default: func(c fiber.Ctx) string { @@ -102,5 +109,10 @@ func configDefault(config ...Config) Config { if cfg.LimiterMiddleware == nil { cfg.LimiterMiddleware = ConfigDefault.LimiterMiddleware } + if cfg.MaxCalculator == nil { + cfg.MaxCalculator = func(_ fiber.Ctx) int { + return cfg.Max + } + } return cfg } diff --git a/middleware/limiter/limiter_fixed.go b/middleware/limiter/limiter_fixed.go index 1e2a1aa0e5..23dd0ca974 100644 --- a/middleware/limiter/limiter_fixed.go +++ b/middleware/limiter/limiter_fixed.go @@ -15,7 +15,6 @@ func (FixedWindow) New(cfg Config) fiber.Handler { var ( // Limiter variables mux = &sync.RWMutex{} - max = strconv.Itoa(cfg.Max) expiration = uint64(cfg.Expiration.Seconds()) ) @@ -27,8 +26,11 @@ func (FixedWindow) New(cfg Config) fiber.Handler { // Return new handler return func(c fiber.Ctx) error { - // Don't execute middleware if Next returns true - if cfg.Next != nil && cfg.Next(c) { + // Generate max from generator, if no generator was provided the default value returned is 5 + max := cfg.MaxCalculator(c) + + // Don't execute middleware if Next returns true or if the max is 0 + if (cfg.Next != nil && cfg.Next(c)) || max == 0 { return c.Next() } @@ -60,7 +62,7 @@ func (FixedWindow) New(cfg Config) fiber.Handler { resetInSec := e.exp - ts // Set how many hits we have left - remaining := cfg.Max - e.currHits + remaining := max - e.currHits // Update storage manager.set(key, e, cfg.Expiration) @@ -68,7 +70,7 @@ func (FixedWindow) New(cfg Config) fiber.Handler { // Unlock entry mux.Unlock() - // Check if hits exceed the cfg.Max + // Check if hits exceed the max if remaining < 0 { // Return response with Retry-After header // https://tools.ietf.org/html/rfc6584 @@ -96,7 +98,7 @@ func (FixedWindow) New(cfg Config) fiber.Handler { } // We can continue, update RateLimit headers - c.Set(xRateLimitLimit, max) + c.Set(xRateLimitLimit, strconv.Itoa(max)) c.Set(xRateLimitRemaining, strconv.Itoa(remaining)) c.Set(xRateLimitReset, strconv.FormatUint(resetInSec, 10)) diff --git a/middleware/limiter/limiter_test.go b/middleware/limiter/limiter_test.go index ed4470e9a8..2f2b2836f3 100644 --- a/middleware/limiter/limiter_test.go +++ b/middleware/limiter/limiter_test.go @@ -2,6 +2,7 @@ package limiter import ( "io" + "math/rand" "net/http/httptest" "sync" "testing" @@ -14,6 +15,53 @@ import ( "github.com/valyala/fasthttp" ) +// go test -run Test_Limiter_Concurrency_Store -race -v +func Test_Limiter_With_Max_Calculator(t *testing.T) { + t.Parallel() + app := fiber.New() + max := rand.Intn(10) + + app.Use(New(Config{ + MaxCalculator: func(_ fiber.Ctx) int { + return max + }, + Expiration: 2 * time.Second, + Storage: memory.New(), + })) + + app.Get("/", func(c fiber.Ctx) error { + return c.SendString("Hello tester!") + }) + + var wg sync.WaitGroup + + for i := 0; i <= max-1; i++ { + wg.Add(1) + go func(wg *sync.WaitGroup) { + defer wg.Done() + resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil)) + assert.NoError(t, err) + assert.Equal(t, fiber.StatusOK, resp.StatusCode) + + body, err := io.ReadAll(resp.Body) + assert.NoError(t, err) + assert.Equal(t, "Hello tester!", string(body)) + }(&wg) + } + + wg.Wait() + + resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil)) + require.NoError(t, err) + require.Equal(t, 429, resp.StatusCode) + + time.Sleep(3 * time.Second) + + resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil)) + require.NoError(t, err) + require.Equal(t, 200, resp.StatusCode) +} + // go test -run Test_Limiter_Concurrency_Store -race -v func Test_Limiter_Concurrency_Store(t *testing.T) { t.Parallel()