Skip to content

Commit

Permalink
feat!: Add fixedWindow, tokenBucket, and slidingWindow primitives (#184)
Browse files Browse the repository at this point in the history
Replaces #171 

This adds the `fixedWindow`, `tokenBucket` and `slidingWindow` primitives. It keeps the `rateLimit()` primitive for backwards compatibility, but it is breaking because `protectSignup` was altered to require sliding window configurations.

The changes were streamlined by putting `requested` into the `extra` field of the RequestDetails. This allowed us to keep the same structure for specifying our Rules.

Draft because I need to add tests to maintain full test coverage and because it relies on #179, #180, #181, #182, and #183 (after which, I'll rebase all the cherry-picked commits out of this PR).
  • Loading branch information
blaine-arcjet authored Feb 6, 2024
1 parent 9d42276 commit 6701b02
Show file tree
Hide file tree
Showing 12 changed files with 744 additions and 46 deletions.
129 changes: 115 additions & 14 deletions arcjet/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@ import {
ArcjetDecision,
ArcjetDenyDecision,
ArcjetErrorDecision,
ArcjetRateLimitRule,
ArcjetBotRule,
ArcjetRule,
ArcjetLocalRule,
ArcjetRequestDetails,
ArcjetTokenBucketRateLimitRule,
ArcjetFixedWindowRateLimitRule,
ArcjetSlidingWindowRateLimitRule,
} from "@arcjet/protocol";
import {
ArcjetBotTypeToProtocol,
Expand Down Expand Up @@ -414,14 +416,31 @@ function runtime(): Runtime {
}
}

export type RateLimitOptions = {
type TokenBucketRateLimitOptions = {
mode?: ArcjetMode;
match?: string;
characteristics?: string[];
refillRate: number;
interval: number;
capacity: number;
};

type FixedWindowRateLimitOptions = {
mode?: ArcjetMode;
match?: string;
characteristics?: string[];
window: string;
max: number;
};

type SlidingWindowRateLimitOptions = {
mode?: ArcjetMode;
match?: string;
characteristics?: string[];
interval: number;
max: number;
};

/**
* Bot detection is disabled by default. The `bots` configuration block allows
* you to enable or disable it and configure additional rules.
Expand Down Expand Up @@ -570,30 +589,112 @@ function isLocalRule<Props extends PlainObject>(
);
}

export function tokenBucket(
options?: TokenBucketRateLimitOptions,
...additionalOptions: TokenBucketRateLimitOptions[]
): Primitive<{ requested: number }> {
const rules: ArcjetTokenBucketRateLimitRule<{ requested: number }>[] = [];

if (typeof options === "undefined") {
return rules;
}

for (const opt of [options, ...additionalOptions]) {
const mode = opt.mode === "LIVE" ? "LIVE" : "DRY_RUN";
const match = opt.match;
const characteristics = opt.characteristics;

const refillRate = opt.refillRate;
const interval = opt.interval;
const capacity = opt.capacity;

rules.push({
type: "RATE_LIMIT",
priority: Priority.RateLimit,
mode,
match,
characteristics,
algorithm: "TOKEN_BUCKET",
refillRate,
interval,
capacity,
});
}

return rules;
}

export function fixedWindow(
options?: FixedWindowRateLimitOptions,
...additionalOptions: FixedWindowRateLimitOptions[]
): Primitive {
const rules: ArcjetFixedWindowRateLimitRule<{}>[] = [];

if (typeof options === "undefined") {
return rules;
}

for (const opt of [options, ...additionalOptions]) {
const mode = opt.mode === "LIVE" ? "LIVE" : "DRY_RUN";
const match = opt.match;
const characteristics = opt.characteristics;

const max = opt.max;
const window = opt.window;

rules.push({
type: "RATE_LIMIT",
priority: Priority.RateLimit,
mode,
match,
characteristics,
algorithm: "FIXED_WINDOW",
max,
window,
});
}

return rules;
}

// This is currently kept for backwards compatibility but should be removed in
// favor of the fixedWindow primitive.
export function rateLimit(
options?: RateLimitOptions,
...additionalOptions: RateLimitOptions[]
options?: FixedWindowRateLimitOptions,
...additionalOptions: FixedWindowRateLimitOptions[]
): Primitive {
// TODO(#195): We should also have a local rate limit using an in-memory data
// structure if the environment supports it
return fixedWindow(options, ...additionalOptions);
}

const rules: ArcjetRateLimitRule<{}>[] = [];
export function slidingWindow(
options?: SlidingWindowRateLimitOptions,
...additionalOptions: SlidingWindowRateLimitOptions[]
): Primitive {
const rules: ArcjetSlidingWindowRateLimitRule<{}>[] = [];

if (typeof options === "undefined") {
return rules;
}

for (const opt of [options, ...additionalOptions]) {
const mode = opt.mode === "LIVE" ? "LIVE" : "DRY_RUN";
const match = opt.match;
const characteristics = opt.characteristics;

const max = opt.max;
const interval = opt.interval;

rules.push({
type: "RATE_LIMIT",
priority: Priority.RateLimit,
mode,
match: opt.match,
characteristics: opt.characteristics,
window: opt.window,
max: opt.max,
match,
characteristics,
algorithm: "SLIDING_WINDOW",
max,
interval,
});
}

Expand Down Expand Up @@ -674,7 +775,7 @@ export function detectBot(
// Always create at least one BOT rule
for (const opt of [options ?? {}, ...additionalOptions]) {
const mode = opt.mode === "LIVE" ? "LIVE" : "DRY_RUN";
// TODO: Filter invalid email types (or error??)
// TODO: Filter invalid bot types (or error??)
const block = Array.isArray(opt.block)
? opt.block
: [ArcjetBotType.AUTOMATED];
Expand Down Expand Up @@ -766,7 +867,7 @@ export function detectBot(
}

export type ProtectSignupOptions = {
rateLimit?: RateLimitOptions | RateLimitOptions[];
rateLimit?: SlidingWindowRateLimitOptions | SlidingWindowRateLimitOptions[];
bots?: BotOptions | BotOptions[];
email?: EmailOptions | EmailOptions[];
};
Expand All @@ -776,9 +877,9 @@ export function protectSignup(
): Product<{ email: string }> {
let rateLimitRules: Primitive<{}> = [];
if (Array.isArray(options?.rateLimit)) {
rateLimitRules = rateLimit(...options.rateLimit);
rateLimitRules = slidingWindow(...options.rateLimit);
} else {
rateLimitRules = rateLimit(options?.rateLimit);
rateLimitRules = slidingWindow(options?.rateLimit);
}

let botRules: Primitive<{}> = [];
Expand All @@ -788,7 +889,7 @@ export function protectSignup(
botRules = detectBot(options?.bots);
}

let emailRules: Primitive<{}> = [];
let emailRules: Primitive<{ email: string }> = [];
if (Array.isArray(options?.email)) {
emailRules = validateEmail(...options.email);
} else {
Expand Down
12 changes: 11 additions & 1 deletion arcjet/test/index.edge.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import { describe, expect, test, jest } from "@jest/globals";

import arcjet, {
rateLimit,
tokenBucket,
protectSignup,
Primitive,
ArcjetReason,
Expand Down Expand Up @@ -35,14 +36,23 @@ describe("Arcjet: Env = Edge runtime", () => {
rules: [
// Test rules
foobarbaz(),
rateLimit(),
tokenBucket({
refillRate: 1,
interval: 1,
capacity: 1,
}),
rateLimit({
max: 1,
window: "60s",
}),
protectSignup(),
],
client,
});

const decision = await aj.protect({
abc: 123,
requested: 1,
email: "",
ip: "",
method: "",
Expand Down
Loading

0 comments on commit 6701b02

Please sign in to comment.