Skip to content

Commit

Permalink
feat: add max calculator to limiter middleware
Browse files Browse the repository at this point in the history
  • Loading branch information
luk3skyw4lker committed Jul 11, 2024
1 parent c579a1a commit 2359bb2
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 6 deletions.
12 changes: 12 additions & 0 deletions middleware/limiter/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,13 @@ type Config struct {
// Default: 5
Max int

// A function to dynamically calculate the max requests supported by the rate limiter middleware
//
// Default: func(c fiber.Ctx) int {
// return c.Max
// }
MaxCalculator func(c fiber.Ctx) int

// KeyGenerator allows you to generate custom keys, by default c.IP() is used
//
// Default: func(c fiber.Ctx) string {
Expand Down Expand Up @@ -102,5 +109,10 @@ func configDefault(config ...Config) Config {
if cfg.LimiterMiddleware == nil {
cfg.LimiterMiddleware = ConfigDefault.LimiterMiddleware
}
if cfg.MaxCalculator == nil {
cfg.MaxCalculator = func(_ fiber.Ctx) int {
return cfg.Max
}
}
return cfg
}
14 changes: 8 additions & 6 deletions middleware/limiter/limiter_fixed.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ func (FixedWindow) New(cfg Config) fiber.Handler {
var (
// Limiter variables
mux = &sync.RWMutex{}
max = strconv.Itoa(cfg.Max)
expiration = uint64(cfg.Expiration.Seconds())
)

Expand All @@ -27,8 +26,11 @@ func (FixedWindow) New(cfg Config) fiber.Handler {

// Return new handler
return func(c fiber.Ctx) error {
// Don't execute middleware if Next returns true
if cfg.Next != nil && cfg.Next(c) {
// Generate max from generator, if no generator was provided the default value returned is 5
max := cfg.MaxCalculator(c)

// Don't execute middleware if Next returns true or if the max is 0
if (cfg.Next != nil && cfg.Next(c)) || max == 0 {
return c.Next()
}

Expand Down Expand Up @@ -60,15 +62,15 @@ func (FixedWindow) New(cfg Config) fiber.Handler {
resetInSec := e.exp - ts

// Set how many hits we have left
remaining := cfg.Max - e.currHits
remaining := max - e.currHits

// Update storage
manager.set(key, e, cfg.Expiration)

// Unlock entry
mux.Unlock()

// Check if hits exceed the cfg.Max
// Check if hits exceed the max
if remaining < 0 {
// Return response with Retry-After header
// https://tools.ietf.org/html/rfc6584
Expand Down Expand Up @@ -96,7 +98,7 @@ func (FixedWindow) New(cfg Config) fiber.Handler {
}

// We can continue, update RateLimit headers
c.Set(xRateLimitLimit, max)
c.Set(xRateLimitLimit, strconv.Itoa(max))
c.Set(xRateLimitRemaining, strconv.Itoa(remaining))
c.Set(xRateLimitReset, strconv.FormatUint(resetInSec, 10))

Expand Down
48 changes: 48 additions & 0 deletions middleware/limiter/limiter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package limiter

import (
"io"
"math/rand"
"net/http/httptest"
"sync"
"testing"
Expand All @@ -14,6 +15,53 @@ import (
"github.com/valyala/fasthttp"
)

// go test -run Test_Limiter_Concurrency_Store -race -v
func Test_Limiter_With_Max_Calculator(t *testing.T) {
t.Parallel()
app := fiber.New()
max := rand.Intn(10)

app.Use(New(Config{
MaxCalculator: func(_ fiber.Ctx) int {
return max
},
Expiration: 2 * time.Second,
Storage: memory.New(),
}))

app.Get("/", func(c fiber.Ctx) error {
return c.SendString("Hello tester!")
})

var wg sync.WaitGroup

for i := 0; i <= max-1; i++ {
wg.Add(1)
go func(wg *sync.WaitGroup) {
defer wg.Done()
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
assert.NoError(t, err)
assert.Equal(t, fiber.StatusOK, resp.StatusCode)

body, err := io.ReadAll(resp.Body)
assert.NoError(t, err)
assert.Equal(t, "Hello tester!", string(body))
}(&wg)
}

wg.Wait()

resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
require.NoError(t, err)
require.Equal(t, 429, resp.StatusCode)

time.Sleep(3 * time.Second)

resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
require.NoError(t, err)
require.Equal(t, 200, resp.StatusCode)
}

// go test -run Test_Limiter_Concurrency_Store -race -v
func Test_Limiter_Concurrency_Store(t *testing.T) {
t.Parallel()
Expand Down

0 comments on commit 2359bb2

Please sign in to comment.