diff --git a/throttle.go b/throttle.go index ebfc8a2..b98352a 100644 --- a/throttle.go +++ b/throttle.go @@ -55,6 +55,19 @@ type Options struct { // If the throttle is disabled or not // defaults to false Disabled bool + + // If this function returns true, the request will not be counted towards the access count. + // You can set it to provide your own conditions for a request to be counted based on the request or the response, + // for example to exclude success responses from the count. + SkipRegisterFunction func(resp http.ResponseWriter, req *http.Request) bool + + // If this function returns true, the request will not be checked for access, the policy will be ignored. + // You can set it to provide your own conditions for a request or a response to be allowed, for example to skip + // throttling on an IP allowlist. + // Note: You can't delay processing here with something like c.Next() until after the request, because that will + // make the access check to happen after executing the controller handler. Because of this, be aware that resp might + // not contain what you want yet. + SkipAccessCheckFunction func(resp http.ResponseWriter, req *http.Request) bool } // KeyValueStorer is the required interface for the Store Option @@ -224,6 +237,14 @@ func (o *Options) Identify(req *http.Request) string { return o.IdentificationFunction(req) } +func (o *Options) SkipRegister(resp http.ResponseWriter, req *http.Request) bool { + return o.SkipRegisterFunction(resp, req) +} + +func (o *Options) SkipAccessCheck(resp http.ResponseWriter, req *http.Request) bool { + return o.SkipAccessCheckFunction(resp, req) +} + // A throttling Policy // Takes two arguments, one required: // First is a Quota (A Limit with an associated time). When the given Limit @@ -242,17 +263,28 @@ func Policy(quota *Quota, options ...*Options) func(resp http.ResponseWriter, re return func(resp http.ResponseWriter, req *http.Request) { id := makeKey(o.KeyPrefix, quota.KeyId(), o.Identify(req)) + // Already set rate limit headers in case the SkipRegister method calls some delay method like c.Next() and we + // might not be able to set the headers again in that case, because the response has already been written. + setRateLimitHeaders(resp, controller, id) + + if o.SkipAccessCheck(resp, req) { + return + } + if controller.DeniesAccess(id) { msg := newAccessMessage(o.StatusCode, o.Message) - setRateLimitHeaders(resp, controller, id) resp.WriteHeader(msg.StatusCode) resp.Write([]byte(msg.Message)) return - } else { + } + + if !o.SkipRegister(resp, req) { controller.RegisterAccess(id) + + // Set the headers again because the rate limit values have been changed at this point due to calling + // RegisterAccess. setRateLimitHeaders(resp, controller, id) } - } } @@ -279,6 +311,14 @@ func defaultIdentify(req *http.Request) string { return ip } +func defaultSkipRegister(http.ResponseWriter, *http.Request) bool { + return false +} + +func defaultSkipAccess(http.ResponseWriter, *http.Request) bool { + return false +} + // Make a key from various parts for use in the key value store func makeKey(parts ...string) string { return strings.Join(parts, "_") @@ -287,12 +327,14 @@ func makeKey(parts ...string) string { // Creates new default options and assigns any given options func newOptions(options []*Options) *Options { o := Options{ - StatusCode: defaultStatusCode, - Message: defaultMessage, - IdentificationFunction: defaultIdentify, - KeyPrefix: defaultKeyPrefix, - Store: nil, - Disabled: defaultDisabled, + StatusCode: defaultStatusCode, + Message: defaultMessage, + IdentificationFunction: defaultIdentify, + KeyPrefix: defaultKeyPrefix, + Store: nil, + Disabled: defaultDisabled, + SkipRegisterFunction: defaultSkipRegister, + SkipAccessCheckFunction: defaultSkipAccess, } // when all defaults, return it