From 6701b02e8425c25953f103add46d7e850aa7d0b4 Mon Sep 17 00:00:00 2001 From: blaine-arcjet <146491715+blaine-arcjet@users.noreply.github.com> Date: Tue, 6 Feb 2024 15:00:37 -0700 Subject: [PATCH] feat!: Add fixedWindow, tokenBucket, and slidingWindow primitives (#184) Replaces #171 This adds the `fixedWindow`, `tokenBucket` and `slidingWindow` primitives. It keeps the `rateLimit()` primitive for backwards compatibility, but it is breaking because `protectSignup` was altered to require sliding window configurations. The changes were streamlined by putting `requested` into the `extra` field of the RequestDetails. This allowed us to keep the same structure for specifying our Rules. Draft because I need to add tests to maintain full test coverage and because it relies on #179, #180, #181, #182, and #183 (after which, I'll rebase all the cherry-picked commits out of this PR). --- arcjet/index.ts | 129 ++++- arcjet/test/index.edge.test.ts | 12 +- arcjet/test/index.node.test.ts | 459 +++++++++++++++++- .../pages/api/arcjet-edge.ts | 4 +- .../nextjs-13-pages-wrap/pages/api/arcjet.ts | 4 +- .../app/api/arcjet/route.ts | 4 +- .../nextjs-14-openai/app/api/chat/route.ts | 29 +- .../pages/api/arcjet-edge.ts | 4 +- .../nextjs-14-pages-wrap/pages/api/arcjet.ts | 4 +- protocol/convert.ts | 56 ++- protocol/index.ts | 41 ++ protocol/test/convert.test.ts | 44 +- 12 files changed, 744 insertions(+), 46 deletions(-) diff --git a/arcjet/index.ts b/arcjet/index.ts index ee14740fd..874f631bd 100644 --- a/arcjet/index.ts +++ b/arcjet/index.ts @@ -13,11 +13,13 @@ import { ArcjetDecision, ArcjetDenyDecision, ArcjetErrorDecision, - ArcjetRateLimitRule, ArcjetBotRule, ArcjetRule, ArcjetLocalRule, ArcjetRequestDetails, + ArcjetTokenBucketRateLimitRule, + ArcjetFixedWindowRateLimitRule, + ArcjetSlidingWindowRateLimitRule, } from "@arcjet/protocol"; import { ArcjetBotTypeToProtocol, @@ -414,7 +416,16 @@ function runtime(): Runtime { } } -export type RateLimitOptions = { +type TokenBucketRateLimitOptions = { + mode?: ArcjetMode; + match?: string; + characteristics?: string[]; + refillRate: number; + interval: number; + capacity: number; +}; + +type FixedWindowRateLimitOptions = { mode?: ArcjetMode; match?: string; characteristics?: string[]; @@ -422,6 +433,14 @@ export type RateLimitOptions = { max: number; }; +type SlidingWindowRateLimitOptions = { + mode?: ArcjetMode; + match?: string; + characteristics?: string[]; + interval: number; + max: number; +}; + /** * Bot detection is disabled by default. The `bots` configuration block allows * you to enable or disable it and configure additional rules. @@ -570,14 +589,90 @@ function isLocalRule( ); } +export function tokenBucket( + options?: TokenBucketRateLimitOptions, + ...additionalOptions: TokenBucketRateLimitOptions[] +): Primitive<{ requested: number }> { + const rules: ArcjetTokenBucketRateLimitRule<{ requested: number }>[] = []; + + if (typeof options === "undefined") { + return rules; + } + + for (const opt of [options, ...additionalOptions]) { + const mode = opt.mode === "LIVE" ? "LIVE" : "DRY_RUN"; + const match = opt.match; + const characteristics = opt.characteristics; + + const refillRate = opt.refillRate; + const interval = opt.interval; + const capacity = opt.capacity; + + rules.push({ + type: "RATE_LIMIT", + priority: Priority.RateLimit, + mode, + match, + characteristics, + algorithm: "TOKEN_BUCKET", + refillRate, + interval, + capacity, + }); + } + + return rules; +} + +export function fixedWindow( + options?: FixedWindowRateLimitOptions, + ...additionalOptions: FixedWindowRateLimitOptions[] +): Primitive { + const rules: ArcjetFixedWindowRateLimitRule<{}>[] = []; + + if (typeof options === "undefined") { + return rules; + } + + for (const opt of [options, ...additionalOptions]) { + const mode = opt.mode === "LIVE" ? "LIVE" : "DRY_RUN"; + const match = opt.match; + const characteristics = opt.characteristics; + + const max = opt.max; + const window = opt.window; + + rules.push({ + type: "RATE_LIMIT", + priority: Priority.RateLimit, + mode, + match, + characteristics, + algorithm: "FIXED_WINDOW", + max, + window, + }); + } + + return rules; +} + +// This is currently kept for backwards compatibility but should be removed in +// favor of the fixedWindow primitive. export function rateLimit( - options?: RateLimitOptions, - ...additionalOptions: RateLimitOptions[] + options?: FixedWindowRateLimitOptions, + ...additionalOptions: FixedWindowRateLimitOptions[] ): Primitive { // TODO(#195): We should also have a local rate limit using an in-memory data // structure if the environment supports it + return fixedWindow(options, ...additionalOptions); +} - const rules: ArcjetRateLimitRule<{}>[] = []; +export function slidingWindow( + options?: SlidingWindowRateLimitOptions, + ...additionalOptions: SlidingWindowRateLimitOptions[] +): Primitive { + const rules: ArcjetSlidingWindowRateLimitRule<{}>[] = []; if (typeof options === "undefined") { return rules; @@ -585,15 +680,21 @@ export function rateLimit( for (const opt of [options, ...additionalOptions]) { const mode = opt.mode === "LIVE" ? "LIVE" : "DRY_RUN"; + const match = opt.match; + const characteristics = opt.characteristics; + + const max = opt.max; + const interval = opt.interval; rules.push({ type: "RATE_LIMIT", priority: Priority.RateLimit, mode, - match: opt.match, - characteristics: opt.characteristics, - window: opt.window, - max: opt.max, + match, + characteristics, + algorithm: "SLIDING_WINDOW", + max, + interval, }); } @@ -674,7 +775,7 @@ export function detectBot( // Always create at least one BOT rule for (const opt of [options ?? {}, ...additionalOptions]) { const mode = opt.mode === "LIVE" ? "LIVE" : "DRY_RUN"; - // TODO: Filter invalid email types (or error??) + // TODO: Filter invalid bot types (or error??) const block = Array.isArray(opt.block) ? opt.block : [ArcjetBotType.AUTOMATED]; @@ -766,7 +867,7 @@ export function detectBot( } export type ProtectSignupOptions = { - rateLimit?: RateLimitOptions | RateLimitOptions[]; + rateLimit?: SlidingWindowRateLimitOptions | SlidingWindowRateLimitOptions[]; bots?: BotOptions | BotOptions[]; email?: EmailOptions | EmailOptions[]; }; @@ -776,9 +877,9 @@ export function protectSignup( ): Product<{ email: string }> { let rateLimitRules: Primitive<{}> = []; if (Array.isArray(options?.rateLimit)) { - rateLimitRules = rateLimit(...options.rateLimit); + rateLimitRules = slidingWindow(...options.rateLimit); } else { - rateLimitRules = rateLimit(options?.rateLimit); + rateLimitRules = slidingWindow(options?.rateLimit); } let botRules: Primitive<{}> = []; @@ -788,7 +889,7 @@ export function protectSignup( botRules = detectBot(options?.bots); } - let emailRules: Primitive<{}> = []; + let emailRules: Primitive<{ email: string }> = []; if (Array.isArray(options?.email)) { emailRules = validateEmail(...options.email); } else { diff --git a/arcjet/test/index.edge.test.ts b/arcjet/test/index.edge.test.ts index 9ec353aef..406cadff9 100644 --- a/arcjet/test/index.edge.test.ts +++ b/arcjet/test/index.edge.test.ts @@ -5,6 +5,7 @@ import { describe, expect, test, jest } from "@jest/globals"; import arcjet, { rateLimit, + tokenBucket, protectSignup, Primitive, ArcjetReason, @@ -35,7 +36,15 @@ describe("Arcjet: Env = Edge runtime", () => { rules: [ // Test rules foobarbaz(), - rateLimit(), + tokenBucket({ + refillRate: 1, + interval: 1, + capacity: 1, + }), + rateLimit({ + max: 1, + window: "60s", + }), protectSignup(), ], client, @@ -43,6 +52,7 @@ describe("Arcjet: Env = Edge runtime", () => { const decision = await aj.protect({ abc: 123, + requested: 1, email: "", ip: "", method: "", diff --git a/arcjet/test/index.node.test.ts b/arcjet/test/index.node.test.ts index 3f8511924..904d6cb3a 100644 --- a/arcjet/test/index.node.test.ts +++ b/arcjet/test/index.node.test.ts @@ -52,6 +52,9 @@ import arcjet, { ArcjetBotReason, ArcjetRateLimitReason, ArcjetLocalRule, + fixedWindow, + tokenBucket, + slidingWindow, } from "../index"; // Instances of Headers contain symbols that may be different depending @@ -1455,7 +1458,7 @@ describe("ArcjetDecision", () => { }); }); -describe("Primitives > detectBot", () => { +describe("Primitive > detectBot", () => { test("provides a default rule with no options specified", async () => { const [rule] = detectBot(); expect(rule.type).toEqual("BOT"); @@ -1912,6 +1915,435 @@ describe("Primitives > detectBot", () => { }); }); +describe("Primitive > tokenBucket", () => { + test("provides no rules if no `options` specified", () => { + const rules = tokenBucket(); + expect(rules).toHaveLength(0); + }); + + test("sets mode as `DRY_RUN` if not 'LIVE' or 'DRY_RUN'", async () => { + const [rule] = tokenBucket({ + // @ts-expect-error + mode: "INVALID", + match: "/test", + characteristics: ["ip.src"], + refillRate: 1, + interval: 1, + capacity: 1, + }); + expect(rule.type).toEqual("RATE_LIMIT"); + expect(rule).toHaveProperty("mode", "DRY_RUN"); + }); + + test("sets mode as `LIVE` if specified", async () => { + const [rule] = tokenBucket({ + mode: "LIVE", + match: "/test", + characteristics: ["ip.src"], + refillRate: 1, + interval: 1, + capacity: 1, + }); + expect(rule.type).toEqual("RATE_LIMIT"); + expect(rule).toHaveProperty("mode", "LIVE"); + }); + + test("produces a rules based on single `limit` specified", async () => { + const options = { + match: "/test", + characteristics: ["ip.src"], + refillRate: 1, + interval: 1, + capacity: 1, + }; + + const rules = tokenBucket(options); + expect(rules).toHaveLength(1); + expect(rules[0].type).toEqual("RATE_LIMIT"); + expect(rules[0]).toHaveProperty("mode", "DRY_RUN"); + expect(rules[0]).toHaveProperty("match", "/test"); + expect(rules[0]).toHaveProperty("characteristics", ["ip.src"]); + expect(rules[0]).toHaveProperty("algorithm", "TOKEN_BUCKET"); + expect(rules[0]).toHaveProperty("refillRate", 1); + expect(rules[0]).toHaveProperty("interval", 1); + expect(rules[0]).toHaveProperty("capacity", 1); + }); + + test("produces a multiple rules based on multiple `limit` specified", async () => { + const options = [ + { + match: "/test", + characteristics: ["ip.src"], + refillRate: 1, + interval: 1, + capacity: 1, + }, + { + match: "/test-double", + characteristics: ["ip.src"], + refillRate: 2, + interval: 2, + capacity: 2, + }, + ]; + + const rules = tokenBucket(...options); + expect(rules).toHaveLength(2); + expect(rules).toEqual([ + expect.objectContaining({ + type: "RATE_LIMIT", + mode: "DRY_RUN", + match: "/test", + characteristics: ["ip.src"], + algorithm: "TOKEN_BUCKET", + refillRate: 1, + interval: 1, + capacity: 1, + }), + expect.objectContaining({ + type: "RATE_LIMIT", + mode: "DRY_RUN", + match: "/test-double", + characteristics: ["ip.src"], + algorithm: "TOKEN_BUCKET", + refillRate: 2, + interval: 2, + capacity: 2, + }), + ]); + }); + + test("does not default `match` and `characteristics` if not specified in single `limit`", async () => { + const options = { + refillRate: 1, + interval: 1, + capacity: 1, + }; + + const [rule] = tokenBucket(options); + expect(rule.type).toEqual("RATE_LIMIT"); + expect(rule).toHaveProperty("match", undefined); + expect(rule).toHaveProperty("characteristics", undefined); + }); + + test("does not default `match` or `characteristics` if not specified in array `limit`", async () => { + const options = [ + { + refillRate: 1, + interval: 1, + capacity: 1, + }, + { + refillRate: 2, + interval: 2, + capacity: 2, + }, + ]; + + const rules = tokenBucket(...options); + expect(rules).toEqual([ + expect.objectContaining({ + type: "RATE_LIMIT", + mode: "DRY_RUN", + match: undefined, + characteristics: undefined, + algorithm: "TOKEN_BUCKET", + refillRate: 1, + interval: 1, + capacity: 1, + }), + expect.objectContaining({ + type: "RATE_LIMIT", + mode: "DRY_RUN", + match: undefined, + characteristics: undefined, + refillRate: 2, + interval: 2, + capacity: 2, + }), + ]); + }); +}); + +describe("Primitive > fixedWindow", () => { + test("provides no rules if no `options` specified", () => { + const rules = fixedWindow(); + expect(rules).toHaveLength(0); + }); + + test("sets mode as `DRY_RUN` if not 'LIVE' or 'DRY_RUN'", async () => { + const [rule] = fixedWindow({ + // @ts-expect-error + mode: "INVALID", + match: "/test", + characteristics: ["ip.src"], + window: "1h", + max: 1, + }); + expect(rule.type).toEqual("RATE_LIMIT"); + expect(rule).toHaveProperty("mode", "DRY_RUN"); + }); + + test("sets mode as `LIVE` if specified", async () => { + const [rule] = fixedWindow({ + mode: "LIVE", + match: "/test", + characteristics: ["ip.src"], + window: "1h", + max: 1, + }); + expect(rule.type).toEqual("RATE_LIMIT"); + expect(rule).toHaveProperty("mode", "LIVE"); + }); + + test("produces a rules based on single `limit` specified", async () => { + const options = { + match: "/test", + characteristics: ["ip.src"], + window: "1h", + max: 1, + }; + + const rules = fixedWindow(options); + expect(rules).toHaveLength(1); + expect(rules[0].type).toEqual("RATE_LIMIT"); + expect(rules[0]).toHaveProperty("mode", "DRY_RUN"); + expect(rules[0]).toHaveProperty("match", "/test"); + expect(rules[0]).toHaveProperty("characteristics", ["ip.src"]); + expect(rules[0]).toHaveProperty("algorithm", "FIXED_WINDOW"); + expect(rules[0]).toHaveProperty("window", "1h"); + expect(rules[0]).toHaveProperty("max", 1); + }); + + test("produces a multiple rules based on multiple `limit` specified", async () => { + const options = [ + { + match: "/test", + characteristics: ["ip.src"], + window: "1h", + max: 1, + }, + { + match: "/test-double", + characteristics: ["ip.src"], + window: "2h", + max: 2, + }, + ]; + + const rules = fixedWindow(...options); + expect(rules).toHaveLength(2); + expect(rules).toEqual([ + expect.objectContaining({ + type: "RATE_LIMIT", + mode: "DRY_RUN", + match: "/test", + characteristics: ["ip.src"], + algorithm: "FIXED_WINDOW", + window: "1h", + max: 1, + }), + expect.objectContaining({ + type: "RATE_LIMIT", + mode: "DRY_RUN", + match: "/test-double", + characteristics: ["ip.src"], + algorithm: "FIXED_WINDOW", + window: "2h", + max: 2, + }), + ]); + }); + + test("does not default `match` and `characteristics` if not specified in single `limit`", async () => { + const options = { + window: "1h", + max: 1, + }; + + const [rule] = fixedWindow(options); + expect(rule.type).toEqual("RATE_LIMIT"); + expect(rule).toHaveProperty("match", undefined); + expect(rule).toHaveProperty("characteristics", undefined); + }); + + test("does not default `match` or `characteristics` if not specified in array `limit`", async () => { + const options = [ + { + window: "1h", + max: 1, + }, + { + window: "2h", + max: 2, + }, + ]; + + const rules = fixedWindow(...options); + expect(rules).toEqual([ + expect.objectContaining({ + type: "RATE_LIMIT", + mode: "DRY_RUN", + match: undefined, + characteristics: undefined, + algorithm: "FIXED_WINDOW", + window: "1h", + max: 1, + }), + expect.objectContaining({ + type: "RATE_LIMIT", + mode: "DRY_RUN", + match: undefined, + characteristics: undefined, + algorithm: "FIXED_WINDOW", + window: "2h", + max: 2, + }), + ]); + }); +}); + +describe("Primitive > slidingWindow", () => { + test("provides no rules if no `options` specified", () => { + const rules = slidingWindow(); + expect(rules).toHaveLength(0); + }); + + test("sets mode as `DRY_RUN` if not 'LIVE' or 'DRY_RUN'", async () => { + const [rule] = slidingWindow({ + // @ts-expect-error + mode: "INVALID", + match: "/test", + characteristics: ["ip.src"], + interval: 3600, + max: 1, + }); + expect(rule.type).toEqual("RATE_LIMIT"); + expect(rule).toHaveProperty("mode", "DRY_RUN"); + }); + + test("sets mode as `LIVE` if specified", async () => { + const [rule] = slidingWindow({ + mode: "LIVE", + match: "/test", + characteristics: ["ip.src"], + interval: 3600, + max: 1, + }); + expect(rule.type).toEqual("RATE_LIMIT"); + expect(rule).toHaveProperty("mode", "LIVE"); + }); + + test("produces a rules based on single `limit` specified", async () => { + const options = { + match: "/test", + characteristics: ["ip.src"], + interval: 3600, + max: 1, + }; + + const rules = slidingWindow(options); + expect(rules).toHaveLength(1); + expect(rules[0].type).toEqual("RATE_LIMIT"); + expect(rules[0]).toHaveProperty("mode", "DRY_RUN"); + expect(rules[0]).toHaveProperty("match", "/test"); + expect(rules[0]).toHaveProperty("characteristics", ["ip.src"]); + expect(rules[0]).toHaveProperty("algorithm", "SLIDING_WINDOW"); + expect(rules[0]).toHaveProperty("interval", 3600); + expect(rules[0]).toHaveProperty("max", 1); + }); + + test("produces a multiple rules based on multiple `limit` specified", async () => { + const options = [ + { + match: "/test", + characteristics: ["ip.src"], + interval: 3600, + max: 1, + }, + { + match: "/test-double", + characteristics: ["ip.src"], + interval: 7200, + max: 2, + }, + ]; + + const rules = slidingWindow(...options); + expect(rules).toHaveLength(2); + expect(rules).toEqual([ + expect.objectContaining({ + type: "RATE_LIMIT", + mode: "DRY_RUN", + match: "/test", + characteristics: ["ip.src"], + algorithm: "SLIDING_WINDOW", + interval: 3600, + max: 1, + }), + expect.objectContaining({ + type: "RATE_LIMIT", + mode: "DRY_RUN", + match: "/test-double", + characteristics: ["ip.src"], + algorithm: "SLIDING_WINDOW", + interval: 7200, + max: 2, + }), + ]); + }); + + test("does not default `match` and `characteristics` if not specified in single `limit`", async () => { + const options = { + interval: 3600, + max: 1, + }; + + const [rule] = slidingWindow(options); + expect(rule.type).toEqual("RATE_LIMIT"); + expect(rule).toHaveProperty("match", undefined); + expect(rule).toHaveProperty("characteristics", undefined); + }); + + test("does not default `match` or `characteristics` if not specified in array `limit`", async () => { + const options = [ + { + interval: 3600, + max: 1, + }, + { + interval: 7200, + max: 2, + }, + ]; + + const rules = slidingWindow(...options); + expect(rules).toEqual([ + expect.objectContaining({ + type: "RATE_LIMIT", + mode: "DRY_RUN", + match: undefined, + characteristics: undefined, + algorithm: "SLIDING_WINDOW", + interval: 3600, + max: 1, + }), + expect.objectContaining({ + type: "RATE_LIMIT", + mode: "DRY_RUN", + match: undefined, + characteristics: undefined, + algorithm: "SLIDING_WINDOW", + interval: 7200, + max: 2, + }), + ]); + }); +}); + +// The `rateLimit` primitive just proxies to `fixedWindow` and is available for +// backwards compatibility. +// TODO: Remove these tests once `rateLimit` is removed describe("Primitive > rateLimit", () => { test("provides no rules if no `options` specified", () => { const rules = rateLimit(); @@ -1931,6 +2363,18 @@ describe("Primitive > rateLimit", () => { expect(rule).toHaveProperty("mode", "DRY_RUN"); }); + test("sets mode as `LIVE` if specified", async () => { + const [rule] = rateLimit({ + mode: "LIVE", + match: "/test", + characteristics: ["ip.src"], + window: "1h", + max: 1, + }); + expect(rule.type).toEqual("RATE_LIMIT"); + expect(rule).toHaveProperty("mode", "LIVE"); + }); + test("produces a rules based on single `limit` specified", async () => { const options = { match: "/test", @@ -1945,6 +2389,7 @@ describe("Primitive > rateLimit", () => { expect(rules[0]).toHaveProperty("mode", "DRY_RUN"); expect(rules[0]).toHaveProperty("match", "/test"); expect(rules[0]).toHaveProperty("characteristics", ["ip.src"]); + expect(rules[0]).toHaveProperty("algorithm", "FIXED_WINDOW"); expect(rules[0]).toHaveProperty("window", "1h"); expect(rules[0]).toHaveProperty("max", 1); }); @@ -1973,6 +2418,7 @@ describe("Primitive > rateLimit", () => { mode: "DRY_RUN", match: "/test", characteristics: ["ip.src"], + algorithm: "FIXED_WINDOW", window: "1h", max: 1, }), @@ -1981,6 +2427,7 @@ describe("Primitive > rateLimit", () => { mode: "DRY_RUN", match: "/test-double", characteristics: ["ip.src"], + algorithm: "FIXED_WINDOW", window: "2h", max: 2, }), @@ -2018,6 +2465,7 @@ describe("Primitive > rateLimit", () => { mode: "DRY_RUN", match: undefined, characteristics: undefined, + algorithm: "FIXED_WINDOW", window: "1h", max: 1, }), @@ -2026,6 +2474,7 @@ describe("Primitive > rateLimit", () => { mode: "DRY_RUN", match: undefined, characteristics: undefined, + algorithm: "FIXED_WINDOW", window: "2h", max: 2, }), @@ -2033,7 +2482,7 @@ describe("Primitive > rateLimit", () => { }); }); -describe("Primitives > validateEmail", () => { +describe("Primitive > validateEmail", () => { test("provides a default rule with no options specified", async () => { const [rule] = validateEmail(); expect(rule.type).toEqual("EMAIL"); @@ -2365,7 +2814,7 @@ describe("Products > protectSignup", () => { mode: ArcjetMode.DRY_RUN, match: "/test", characteristics: ["ip.src"], - window: "1h", + interval: 60 /* minutes */ * 60 /* seconds */, max: 1, }, bots: { @@ -2385,13 +2834,13 @@ describe("Products > protectSignup", () => { mode: ArcjetMode.DRY_RUN, match: "/test", characteristics: ["ip.src"], - window: "1h", + interval: 60 /* minutes */ * 60 /* seconds */, max: 1, }, { match: "/test", characteristics: ["ip.src"], - window: "2h", + interval: 2 /* hours */ * 60 /* minutes */ * 60 /* seconds */, max: 2, }, ], diff --git a/examples/nextjs-13-pages-wrap/pages/api/arcjet-edge.ts b/examples/nextjs-13-pages-wrap/pages/api/arcjet-edge.ts index bc2019b17..2b23e11e9 100644 --- a/examples/nextjs-13-pages-wrap/pages/api/arcjet-edge.ts +++ b/examples/nextjs-13-pages-wrap/pages/api/arcjet-edge.ts @@ -1,5 +1,5 @@ // Next.js API route support: https://nextjs.org/docs/api-routes/introduction -import arcjet, { rateLimit, withArcjet } from "@arcjet/next"; +import arcjet, { fixedWindow, withArcjet } from "@arcjet/next"; import { NextRequest, NextResponse } from "next/server"; export const config = { @@ -12,7 +12,7 @@ const aj = arcjet({ // See: https://nextjs.org/docs/pages/building-your-application/configuring/environment-variables key: process.env.AJ_KEY!, rules: [ - rateLimit({ + fixedWindow({ mode: "LIVE", // Limiting by ip.src is the default if not specified //characteristics: ["ip.src"], diff --git a/examples/nextjs-13-pages-wrap/pages/api/arcjet.ts b/examples/nextjs-13-pages-wrap/pages/api/arcjet.ts index 76873e31b..f55c97352 100644 --- a/examples/nextjs-13-pages-wrap/pages/api/arcjet.ts +++ b/examples/nextjs-13-pages-wrap/pages/api/arcjet.ts @@ -1,5 +1,5 @@ // Next.js API route support: https://nextjs.org/docs/api-routes/introduction -import arcjet, { rateLimit, withArcjet } from "@arcjet/next"; +import arcjet, { fixedWindow, withArcjet } from "@arcjet/next"; import type { NextApiRequest, NextApiResponse } from "next"; const aj = arcjet({ @@ -8,7 +8,7 @@ const aj = arcjet({ // See: https://nextjs.org/docs/pages/building-your-application/configuring/environment-variables key: process.env.AJ_KEY!, rules: [ - rateLimit({ + fixedWindow({ mode: "LIVE", // Limiting by ip.src is the default if not specified //characteristics: ["ip.src"], diff --git a/examples/nextjs-14-app-dir-rl/app/api/arcjet/route.ts b/examples/nextjs-14-app-dir-rl/app/api/arcjet/route.ts index b625c301f..194ca8c71 100644 --- a/examples/nextjs-14-app-dir-rl/app/api/arcjet/route.ts +++ b/examples/nextjs-14-app-dir-rl/app/api/arcjet/route.ts @@ -1,4 +1,4 @@ -import arcjet, { rateLimit } from "@arcjet/next"; +import arcjet, { fixedWindow } from "@arcjet/next"; import { NextResponse } from "next/server"; const aj = arcjet({ @@ -7,7 +7,7 @@ const aj = arcjet({ // See: https://nextjs.org/docs/app/building-your-application/configuring/environment-variables key: process.env.AJ_KEY!, rules: [ - rateLimit({ + fixedWindow({ mode: "LIVE", characteristics: ["ip.src"], window: "1h", diff --git a/examples/nextjs-14-openai/app/api/chat/route.ts b/examples/nextjs-14-openai/app/api/chat/route.ts index b89c93201..8605554a4 100644 --- a/examples/nextjs-14-openai/app/api/chat/route.ts +++ b/examples/nextjs-14-openai/app/api/chat/route.ts @@ -1,5 +1,5 @@ // This example is adapted from https://sdk.vercel.ai/docs/guides/frameworks/nextjs-app -import arcjet, { rateLimit } from "@arcjet/next"; +import arcjet, { tokenBucket } from "@arcjet/next"; import { OpenAIStream, StreamingTextResponse } from "ai"; import OpenAI from "openai"; import { promptTokensEstimate } from "openai-chat-tokens"; @@ -11,11 +11,12 @@ const aj = arcjet({ // See: https://nextjs.org/docs/app/building-your-application/configuring/environment-variables key: process.env.AJ_KEY!, rules: [ - rateLimit({ + tokenBucket({ mode: "LIVE", characteristics: ["ip.src"], - window: "1m", - max: 60, + refillRate: 40_000, + interval: 1 /* day */ * 24 /* hours */ * 60 /* minutes */ * 60 /* seconds */, + capacity: 40_000, }), ], }); @@ -28,8 +29,16 @@ const openai = new OpenAI({ export const runtime = "edge"; export async function POST(req: Request) { + const { messages } = await req.json(); + + const estimate = promptTokensEstimate({ + messages, + }); + + console.log("Token estimate", estimate); + // Protect the route with Arcjet - const decision = await aj.protect(req); + const decision = await aj.protect(req, { requested: estimate }); console.log("Arcjet decision", decision.conclusion); if (decision.reason.isRateLimit()) { @@ -50,14 +59,6 @@ export async function POST(req: Request) { } // If the request is allowed, continue to use OpenAI - const { messages } = await req.json(); - - const estimate = promptTokensEstimate({ - messages, - }); - - console.log("Token estimate", estimate); - // Ask OpenAI for a streaming chat completion given the prompt const response = await openai.chat.completions.create({ model: "gpt-3.5-turbo", @@ -69,4 +70,4 @@ export async function POST(req: Request) { const stream = OpenAIStream(response); // Respond with the stream return new StreamingTextResponse(stream); -} \ No newline at end of file +} diff --git a/examples/nextjs-14-pages-wrap/pages/api/arcjet-edge.ts b/examples/nextjs-14-pages-wrap/pages/api/arcjet-edge.ts index bc2019b17..2b23e11e9 100644 --- a/examples/nextjs-14-pages-wrap/pages/api/arcjet-edge.ts +++ b/examples/nextjs-14-pages-wrap/pages/api/arcjet-edge.ts @@ -1,5 +1,5 @@ // Next.js API route support: https://nextjs.org/docs/api-routes/introduction -import arcjet, { rateLimit, withArcjet } from "@arcjet/next"; +import arcjet, { fixedWindow, withArcjet } from "@arcjet/next"; import { NextRequest, NextResponse } from "next/server"; export const config = { @@ -12,7 +12,7 @@ const aj = arcjet({ // See: https://nextjs.org/docs/pages/building-your-application/configuring/environment-variables key: process.env.AJ_KEY!, rules: [ - rateLimit({ + fixedWindow({ mode: "LIVE", // Limiting by ip.src is the default if not specified //characteristics: ["ip.src"], diff --git a/examples/nextjs-14-pages-wrap/pages/api/arcjet.ts b/examples/nextjs-14-pages-wrap/pages/api/arcjet.ts index 76873e31b..f55c97352 100644 --- a/examples/nextjs-14-pages-wrap/pages/api/arcjet.ts +++ b/examples/nextjs-14-pages-wrap/pages/api/arcjet.ts @@ -1,5 +1,5 @@ // Next.js API route support: https://nextjs.org/docs/api-routes/introduction -import arcjet, { rateLimit, withArcjet } from "@arcjet/next"; +import arcjet, { fixedWindow, withArcjet } from "@arcjet/next"; import type { NextApiRequest, NextApiResponse } from "next"; const aj = arcjet({ @@ -8,7 +8,7 @@ const aj = arcjet({ // See: https://nextjs.org/docs/pages/building-your-application/configuring/environment-variables key: process.env.AJ_KEY!, rules: [ - rateLimit({ + fixedWindow({ mode: "LIVE", // Limiting by ip.src is the default if not specified //characteristics: ["ip.src"], diff --git a/protocol/convert.ts b/protocol/convert.ts index 8fbcb0af1..4725adf17 100644 --- a/protocol/convert.ts +++ b/protocol/convert.ts @@ -23,6 +23,9 @@ import { ArcjetRateLimitRule, ArcjetBotRule, ArcjetEmailRule, + ArcjetTokenBucketRateLimitRule, + ArcjetFixedWindowRateLimitRule, + ArcjetSlidingWindowRateLimitRule, } from "./index"; import { BotReason, @@ -34,6 +37,7 @@ import { EmailType, ErrorReason, Mode, + RateLimitAlgorithm, RateLimitReason, Reason, Rule, @@ -460,6 +464,22 @@ function isRateLimitRule( return rule.type === "RATE_LIMIT"; } +function isTokenBucketRule( + rule: ArcjetRule, +): rule is ArcjetTokenBucketRateLimitRule { + return isRateLimitRule(rule) && rule.algorithm === "TOKEN_BUCKET"; +} +function isFixedWindowRule( + rule: ArcjetRule, +): rule is ArcjetFixedWindowRateLimitRule { + return isRateLimitRule(rule) && rule.algorithm === "FIXED_WINDOW"; +} +function isSlidingWindowRule( + rule: ArcjetRule, +): rule is ArcjetSlidingWindowRateLimitRule { + return isRateLimitRule(rule) && rule.algorithm === "SLIDING_WINDOW"; +} + function isBotRule( rule: ArcjetRule, ): rule is ArcjetBotRule { @@ -475,7 +495,24 @@ function isEmailRule( export function ArcjetRuleToProtocol( rule: ArcjetRule, ): Rule { - if (isRateLimitRule(rule)) { + if (isTokenBucketRule(rule)) { + return new Rule({ + rule: { + case: "rateLimit", + value: { + mode: ArcjetModeToProtocol(rule.mode), + match: rule.match, + characteristics: rule.characteristics, + algorithm: RateLimitAlgorithm.TOKEN_BUCKET, + refillRate: rule.refillRate, + interval: rule.interval, + capacity: rule.capacity, + }, + }, + }); + } + + if (isFixedWindowRule(rule)) { return new Rule({ rule: { case: "rateLimit", @@ -483,8 +520,25 @@ export function ArcjetRuleToProtocol( mode: ArcjetModeToProtocol(rule.mode), match: rule.match, characteristics: rule.characteristics, + algorithm: RateLimitAlgorithm.FIXED_WINDOW, + max: rule.max, window: rule.window, + }, + }, + }); + } + + if (isSlidingWindowRule(rule)) { + return new Rule({ + rule: { + case: "rateLimit", + value: { + mode: ArcjetModeToProtocol(rule.mode), + match: rule.match, + characteristics: rule.characteristics, + algorithm: RateLimitAlgorithm.SLIDING_WINDOW, max: rule.max, + interval: rule.interval, }, }, }); diff --git a/protocol/index.ts b/protocol/index.ts index c31f3aa48..b806da4e8 100644 --- a/protocol/index.ts +++ b/protocol/index.ts @@ -10,6 +10,17 @@ export const ArcjetMode: ArcjetEnum = Object.freeze({ DRY_RUN: "DRY_RUN", }); +export type ArcjetRateLimitAlgorithm = + | "TOKEN_BUCKET" + | "FIXED_WINDOW" + | "SLIDING_WINDOW"; +export const ArcjetRateLimitAlgorithm: ArcjetEnum = + Object.freeze({ + TOKEN_BUCKET: "TOKEN_BUCKET", + FIXED_WINDOW: "FIXED_WINDOW", + SLIDING_WINDOW: "SLIDING_WINDOW", + }); + export type ArcjetBotType = | "NOT_ANALYZED" | "AUTOMATED" @@ -376,6 +387,9 @@ export type ArcjetRule = { priority: number; }; +// An ArcjetLocalRule provides additional `validate` and `protect` functions +// which are used to provide local protections before requesting protections +// from the Arcjet service. export interface ArcjetLocalRule extends ArcjetRule { validate( @@ -391,11 +405,38 @@ export interface ArcjetLocalRule export interface ArcjetRateLimitRule extends ArcjetRule { type: "RATE_LIMIT"; + algorithm: ArcjetRateLimitAlgorithm; +} + +export interface ArcjetTokenBucketRateLimitRule + extends ArcjetRateLimitRule { + algorithm: "TOKEN_BUCKET"; + + match?: string; + characteristics?: string[]; + refillRate: number; + interval: number; + capacity: number; +} + +export interface ArcjetFixedWindowRateLimitRule + extends ArcjetRateLimitRule { + algorithm: "FIXED_WINDOW"; match?: string; characteristics?: string[]; + max: number; window: string; +} + +export interface ArcjetSlidingWindowRateLimitRule + extends ArcjetRateLimitRule { + algorithm: "SLIDING_WINDOW"; + + match?: string; + characteristics?: string[]; max: number; + interval: number; } export interface ArcjetEmailRule diff --git a/protocol/test/convert.test.ts b/protocol/test/convert.test.ts index 107bd76ce..e7453e010 100644 --- a/protocol/test/convert.test.ts +++ b/protocol/test/convert.test.ts @@ -27,6 +27,7 @@ import { Decision, EmailType, Mode, + RateLimitAlgorithm, Reason, Rule, RuleResult, @@ -48,6 +49,9 @@ import { ArcjetReason, ArcjetRuleResult, ArcjetShieldReason, + ArcjetTokenBucketRateLimitRule, + ArcjetFixedWindowRateLimitRule, + ArcjetSlidingWindowRateLimitRule, } from "../index.js"; import { Timestamp } from "@bufbuild/protobuf"; @@ -583,10 +587,47 @@ describe("convert", () => { }), ).toEqual(new Rule({})); expect( - ArcjetRuleToProtocol({ + ArcjetRuleToProtocol(>{ + type: "RATE_LIMIT", + mode: "DRY_RUN", + priority: 1, + algorithm: "TOKEN_BUCKET", + }), + ).toEqual( + new Rule({ + rule: { + case: "rateLimit", + value: { + mode: Mode.DRY_RUN, + algorithm: RateLimitAlgorithm.TOKEN_BUCKET, + }, + }, + }), + ); + expect( + ArcjetRuleToProtocol(>{ + type: "RATE_LIMIT", + mode: "DRY_RUN", + priority: 1, + algorithm: "FIXED_WINDOW", + }), + ).toEqual( + new Rule({ + rule: { + case: "rateLimit", + value: { + mode: Mode.DRY_RUN, + algorithm: RateLimitAlgorithm.FIXED_WINDOW, + }, + }, + }), + ); + expect( + ArcjetRuleToProtocol(>{ type: "RATE_LIMIT", mode: "DRY_RUN", priority: 1, + algorithm: "SLIDING_WINDOW", }), ).toEqual( new Rule({ @@ -594,6 +635,7 @@ describe("convert", () => { case: "rateLimit", value: { mode: Mode.DRY_RUN, + algorithm: RateLimitAlgorithm.SLIDING_WINDOW, }, }, }),