From 7eaea8fb5c19385eb787171e586d116fd7034bb4 Mon Sep 17 00:00:00 2001 From: Alex Guerrieri Date: Thu, 19 Sep 2024 15:53:01 +0200 Subject: [PATCH] NoCost --- context.go | 6 ++++++ limiter.go | 20 +++++++++++++++----- limiter_test.go | 4 ++-- 3 files changed, 23 insertions(+), 7 deletions(-) diff --git a/context.go b/context.go index 3988dd6..93965e6 100644 --- a/context.go +++ b/context.go @@ -9,6 +9,8 @@ const ( requestLimitKey ) +const _NoLimit = -1 + func WithIncrement(ctx context.Context, value int) context.Context { return context.WithValue(ctx, incrementKey, value) } @@ -18,6 +20,10 @@ func getIncrement(ctx context.Context) (int, bool) { return value, ok } +func WithNoLimit(ctx context.Context) context.Context { + return context.WithValue(ctx, requestLimitKey, _NoLimit) +} + func WithRequestLimit(ctx context.Context, value int) context.Context { return context.WithValue(ctx, requestLimitKey, value) } diff --git a/limiter.go b/limiter.go index fe52545..011f586 100644 --- a/limiter.go +++ b/limiter.go @@ -72,22 +72,32 @@ func (l *RateLimiter) OnLimit(w http.ResponseWriter, r *http.Request, key string currentWindow := time.Now().UTC().Truncate(l.windowLength) ctx := r.Context() + setHeader(w, l.headers.Reset, fmt.Sprintf("%d", currentWindow.Add(l.windowLength).Unix())) + limit, ok := getRequestLimit(ctx) if !ok { limit = l.requestLimit } - - if limit <= 0 { + // If the limit is set to 0, we are always over limit + if limit == 0 { + return true + } + // If the limit is set to -1, we are never over limit + if limit == _NoLimit { return false } - setHeader(w, l.headers.Limit, fmt.Sprintf("%d", limit)) - setHeader(w, l.headers.Reset, fmt.Sprintf("%d", currentWindow.Add(l.windowLength).Unix())) - increment, ok := getIncrement(r.Context()) if !ok { increment = 1 } + // If the increment is 0, we are always on limit + if increment == 0 { + return false + } + + setHeader(w, l.headers.Limit, fmt.Sprintf("%d", limit)) + if increment > 1 { setHeader(w, l.headers.Increment, fmt.Sprintf("%d", increment)) } diff --git a/limiter_test.go b/limiter_test.go index 5ac41c1..cc4a249 100644 --- a/limiter_test.go +++ b/limiter_test.go @@ -141,8 +141,8 @@ func TestResponseHeaders(t *testing.T) { requestsLimit: 5, increments: []int{0, 0, 0, 0, 0, 0}, respCodes: []int{200, 200, 200, 200, 200, 200}, - respLimitHeader: []string{"5", "5", "5", "5", "5", "5"}, - respRemainingHeader: []string{"5", "5", "5", "5", "5", "5"}, + respLimitHeader: []string{"", "", "", "", "", ""}, + respRemainingHeader: []string{"", "", "", "", "", ""}, }, { name: "always block",