Skip to content

Commit

Permalink
feat!: Add fixedWindow, tokenBucket, and slidingWindow primitives
Browse files Browse the repository at this point in the history
  • Loading branch information
blaine-arcjet committed Feb 6, 2024
1 parent 807e8de commit 171cbcd
Show file tree
Hide file tree
Showing 7 changed files with 284 additions and 35 deletions.
131 changes: 116 additions & 15 deletions arcjet/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@ import {
ArcjetDecision,
ArcjetDenyDecision,
ArcjetErrorDecision,
ArcjetRateLimitRule,
ArcjetBotRule,
ArcjetRule,
ArcjetLocalRule,
ArcjetRequestDetails,
ArcjetTokenBucketRateLimitRule,
ArcjetFixedWindowRateLimitRule,
ArcjetSlidingWindowRateLimitRule,
} from "@arcjet/protocol";
import {
ArcjetBotTypeToProtocol,
Expand Down Expand Up @@ -414,14 +416,31 @@ function runtime(): Runtime {
}
}

export type RateLimitOptions = {
type TokenBucketRateLimitOptions = {
mode?: ArcjetMode;
match?: string;
characteristics?: string[];
refillRate: number;
interval: number;
capacity: number;
};

type FixedWindowRateLimitOptions = {
mode?: ArcjetMode;
match?: string;
characteristics?: string[];
window: string;
max: number;
};

type SlidingWindowRateLimitOptions = {
mode?: ArcjetMode;
match?: string;
characteristics?: string[];
interval: number;
max: number;
};

/**
* Bot detection is disabled by default. The `bots` configuration block allows
* you to enable or disable it and configure additional rules.
Expand Down Expand Up @@ -570,30 +589,112 @@ function isLocalRule<Props extends PlainObject>(
);
}

export function tokenBucket(
options?: TokenBucketRateLimitOptions,
...additionalOptions: TokenBucketRateLimitOptions[]
): Primitive<{ requested: number }> {
const rules: ArcjetTokenBucketRateLimitRule<{ requested: number }>[] = [];

if (typeof options === "undefined") {
return rules;
}

for (const opt of [options, ...additionalOptions]) {
const mode = opt.mode === "LIVE" ? "LIVE" : "DRY_RUN";
const match = opt.match;
const characteristics = opt.characteristics;

const refillRate = opt.refillRate;
const interval = opt.interval;
const capacity = opt.capacity;

rules.push({
type: "RATE_LIMIT",
priority: Priority.RateLimit,
mode,
match,
characteristics,
algorithm: "TOKEN_BUCKET",
refillRate,
interval,
capacity,
});
}

return rules;
}

export function fixedWindow(
options?: FixedWindowRateLimitOptions,
...additionalOptions: FixedWindowRateLimitOptions[]
): Primitive {
const rules: ArcjetFixedWindowRateLimitRule<{}>[] = [];

if (typeof options === "undefined") {
return rules;
}

for (const opt of [options, ...additionalOptions]) {
const mode = opt.mode === "LIVE" ? "LIVE" : "DRY_RUN";
const match = opt.match;
const characteristics = opt.characteristics;

const max = opt.max;
const window = opt.window;

rules.push({
type: "RATE_LIMIT",
priority: Priority.RateLimit,
mode,
match,
characteristics,
algorithm: "FIXED_WINDOW",
max,
window,
});
}

return rules;
}

// This is currently kept for backwards compatibility but should be removed in
// favor of the fixedWindow primitive.
export function rateLimit(
options?: RateLimitOptions,
...additionalOptions: RateLimitOptions[]
options?: FixedWindowRateLimitOptions,
...additionalOptions: FixedWindowRateLimitOptions[]
): Primitive {
// TODO(#195): We should also have a local rate limit using an in-memory data
// structure if the environment supports it
return fixedWindow(options, ...additionalOptions);
}

const rules: ArcjetRateLimitRule<{}>[] = [];
export function slidingWindow(
options?: SlidingWindowRateLimitOptions,
...additionalOptions: SlidingWindowRateLimitOptions[]
): Primitive {
const rules: ArcjetSlidingWindowRateLimitRule<{}>[] = [];

if (typeof options === "undefined") {
return rules;
}

for (const opt of [options, ...additionalOptions]) {
const mode = opt.mode === "LIVE" ? "LIVE" : "DRY_RUN";
const mode = options.mode === "LIVE" ? "LIVE" : "DRY_RUN";
const match = options.match;
const characteristics = options.characteristics;

const max = options.max;
const interval = options.interval;

rules.push({
type: "RATE_LIMIT",
priority: Priority.RateLimit,
mode,
match: opt.match,
characteristics: opt.characteristics,
window: opt.window,
max: opt.max,
match,
characteristics,
algorithm: "SLIDING_WINDOW",
max,
interval,
});
}

Expand Down Expand Up @@ -674,7 +775,7 @@ export function detectBot(
// Always create at least one BOT rule
for (const opt of [options ?? {}, ...additionalOptions]) {
const mode = opt.mode === "LIVE" ? "LIVE" : "DRY_RUN";
// TODO: Filter invalid email types (or error??)
// TODO: Filter invalid bot types (or error??)
const block = Array.isArray(opt.block)
? opt.block
: [ArcjetBotType.AUTOMATED];
Expand Down Expand Up @@ -766,7 +867,7 @@ export function detectBot(
}

export type ProtectSignupOptions = {
rateLimit?: RateLimitOptions | RateLimitOptions[];
rateLimit?: SlidingWindowRateLimitOptions | SlidingWindowRateLimitOptions[];
bots?: BotOptions | BotOptions[];
email?: EmailOptions | EmailOptions[];
};
Expand All @@ -776,9 +877,9 @@ export function protectSignup(
): Product<{ email: string }> {
let rateLimitRules: Primitive<{}> = [];
if (Array.isArray(options?.rateLimit)) {
rateLimitRules = rateLimit(...options.rateLimit);
rateLimitRules = slidingWindow(...options.rateLimit);
} else {
rateLimitRules = rateLimit(options?.rateLimit);
rateLimitRules = slidingWindow(options?.rateLimit);
}

let botRules: Primitive<{}> = [];
Expand All @@ -788,7 +889,7 @@ export function protectSignup(
botRules = detectBot(options?.bots);
}

let emailRules: Primitive<{}> = [];
let emailRules: Primitive<{ email: string }> = [];
if (Array.isArray(options?.email)) {
emailRules = validateEmail(...options.email);
} else {
Expand Down
12 changes: 11 additions & 1 deletion arcjet/test/index.edge.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import { describe, expect, test, jest } from "@jest/globals";

import arcjet, {
rateLimit,
tokenBucket,
protectSignup,
Primitive,
ArcjetReason,
Expand Down Expand Up @@ -35,14 +36,23 @@ describe("Arcjet: Env = Edge runtime", () => {
rules: [
// Test rules
foobarbaz(),
rateLimit(),
tokenBucket({
refillRate: 1,
interval: 1,
capacity: 1,
}),
rateLimit({
max: 1,
window: "60s",
}),
protectSignup(),
],
client,
});

const decision = await aj.protect({
abc: 123,
requested: 1,
email: "",
ip: "",
method: "",
Expand Down
6 changes: 3 additions & 3 deletions arcjet/test/index.node.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2365,7 +2365,7 @@ describe("Products > protectSignup", () => {
mode: ArcjetMode.DRY_RUN,
match: "/test",
characteristics: ["ip.src"],
window: "1h",
interval: 60 /* minutes */ * 60 /* seconds */,
max: 1,
},
bots: {
Expand All @@ -2385,13 +2385,13 @@ describe("Products > protectSignup", () => {
mode: ArcjetMode.DRY_RUN,
match: "/test",
characteristics: ["ip.src"],
window: "1h",
interval: 60 /* minutes */ * 60 /* seconds */,
max: 1,
},
{
match: "/test",
characteristics: ["ip.src"],
window: "2h",
interval: 2 /* hours */ * 60 /* minutes */ * 60 /* seconds */,
max: 2,
},
],
Expand Down
29 changes: 15 additions & 14 deletions examples/nextjs-14-openai/app/api/chat/route.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// This example is adapted from https://sdk.vercel.ai/docs/guides/frameworks/nextjs-app
import arcjet, { rateLimit } from "@arcjet/next";
import arcjet, { tokenBucket } from "@arcjet/next";
import { OpenAIStream, StreamingTextResponse } from "ai";
import OpenAI from "openai";
import { promptTokensEstimate } from "openai-chat-tokens";
Expand All @@ -11,11 +11,12 @@ const aj = arcjet({
// See: https://nextjs.org/docs/app/building-your-application/configuring/environment-variables
key: process.env.AJ_KEY!,
rules: [
rateLimit({
tokenBucket({
mode: "LIVE",
characteristics: ["ip.src"],
window: "1m",
max: 60,
refillRate: 1,
interval: 60,
capacity: 1,
}),
],
});
Expand All @@ -28,8 +29,16 @@ const openai = new OpenAI({
export const runtime = "edge";

export async function POST(req: Request) {
const { messages } = await req.json();

const estimate = promptTokensEstimate({
messages,
});

console.log("Token estimate", estimate);

// Protect the route with Arcjet
const decision = await aj.protect(req);
const decision = await aj.protect(req, { requested: estimate });
console.log("Arcjet decision", decision.conclusion);

if (decision.reason.isRateLimit()) {
Expand All @@ -50,14 +59,6 @@ export async function POST(req: Request) {
}

// If the request is allowed, continue to use OpenAI
const { messages } = await req.json();

const estimate = promptTokensEstimate({
messages,
});

console.log("Token estimate", estimate);

// Ask OpenAI for a streaming chat completion given the prompt
const response = await openai.chat.completions.create({
model: "gpt-3.5-turbo",
Expand All @@ -69,4 +70,4 @@ export async function POST(req: Request) {
const stream = OpenAIStream(response);
// Respond with the stream
return new StreamingTextResponse(stream);
}
}
Loading

0 comments on commit 171cbcd

Please sign in to comment.