Skip to content

Commit

Permalink
feat: Allow user-defined characteristics on rate limit options
Browse files Browse the repository at this point in the history
  • Loading branch information
blaine-arcjet committed Feb 7, 2024
1 parent b173d83 commit 97bfdf7
Show file tree
Hide file tree
Showing 2 changed files with 132 additions and 40 deletions.
143 changes: 108 additions & 35 deletions arcjet/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,10 @@ function errorMessage(err: unknown): string {
// https://github.com/sindresorhus/type-fest/blob/964466c9d59c711da57a5297ad954c13132a0001/source/simplify.d.ts
// UnionToIntersection:
// https://github.com/sindresorhus/type-fest/blob/017bf38ebb52df37c297324d97bcc693ec22e920/source/union-to-intersection.d.ts
// IsNever:
// https://github.com/sindresorhus/type-fest/blob/e02f228f6391bb2b26c32a55dfe1e3aa2386d515/source/primitive.d.ts
// LiteralCheck & IsStringLiteral:
// https://github.com/sindresorhus/type-fest/blob/e02f228f6391bb2b26c32a55dfe1e3aa2386d515/source/is-literal.d.ts
//
// Licensed: MIT License Copyright (c) Sindre Sorhus <[email protected]>
// (https://sindresorhus.com)
Expand Down Expand Up @@ -149,6 +153,25 @@ type UnionToIntersection<Union> =
? // The `& Union` is to allow indexing by the resulting type
Intersection & Union
: never;
type IsNever<T> = [T] extends [never] ? true : false;
type LiteralCheck<
T,
LiteralType extends
| null
| undefined
| string
| number
| boolean
| symbol
| bigint,
> = IsNever<T> extends false // Must be wider than `never`
? [T] extends [LiteralType] // Must be narrower than `LiteralType`
? [LiteralType] extends [T] // Cannot be wider than `LiteralType`
? false
: true
: false
: false;
type IsStringLiteral<T> = LiteralCheck<T, string>;

export interface RemoteClient {
decide(
Expand Down Expand Up @@ -417,30 +440,31 @@ function runtime(): Runtime {
}
}

type TokenBucketRateLimitOptions = {
type TokenBucketRateLimitOptions<Characteristics extends readonly string[]> = {
mode?: ArcjetMode;
match?: string;
characteristics?: string[];
characteristics?: Characteristics;
refillRate: number;
interval: string | number;
capacity: number;
};

type FixedWindowRateLimitOptions = {
type FixedWindowRateLimitOptions<Characteristics extends readonly string[]> = {
mode?: ArcjetMode;
match?: string;
characteristics?: string[];
characteristics?: Characteristics;
window: string | number;
max: number;
};

type SlidingWindowRateLimitOptions = {
mode?: ArcjetMode;
match?: string;
characteristics?: string[];
interval: string | number;
max: number;
};
type SlidingWindowRateLimitOptions<Characteristics extends readonly string[]> =
{
mode?: ArcjetMode;
match?: string;
characteristics?: Characteristics;
interval: string | number;
max: number;
};

/**
* Bot detection is disabled by default. The `bots` configuration block allows
Expand Down Expand Up @@ -550,6 +574,25 @@ type PlainObject = { [key: string]: unknown };
export type Primitive<Props extends PlainObject = {}> = ArcjetRule<Props>[];
export type Product<Props extends PlainObject = {}> = ArcjetRule<Props>[];

// User-defined characteristics alter the required props of an ArcjetRequest
// Note: If a user doesn't provide the object literal to our primitives
// directly, we fallback to no required props. They can opt-in by adding the
// `as const` suffix to the characteristics array.
type PropsForCharacteristic<T> = IsStringLiteral<T> extends true
? T extends
| "ip.src"
| "http.host"
| "http.method"
| "http.request.uri.path"
| `http.request.headers["${string}"]`
| `http.request.cookie["${string}"]`
| `http.request.uri.args["${string}"]`
? {}
: T extends string
? Record<T, string | number | boolean>
: never
: {};
// Rules can specify they require specific props on an ArcjetRequest
type PropsForRule<R> = R extends ArcjetRule<infer Props> ? Props : {};
// We theoretically support an arbitrary amount of rule flattening,
// but one level seems to be easiest; however, this puts a constraint of
Expand Down Expand Up @@ -590,10 +633,16 @@ function isLocalRule<Props extends PlainObject>(
);
}

export function tokenBucket(
options?: TokenBucketRateLimitOptions,
...additionalOptions: TokenBucketRateLimitOptions[]
): Primitive<{ requested: number }> {
export function tokenBucket<
const Characteristics extends readonly string[] = [],
>(
options?: TokenBucketRateLimitOptions<Characteristics>,
...additionalOptions: TokenBucketRateLimitOptions<Characteristics>[]
): Primitive<
UnionToIntersection<
{ requested: number } | PropsForCharacteristic<Characteristics[number]>
>
> {
const rules: ArcjetTokenBucketRateLimitRule<{ requested: number }>[] = [];

if (typeof options === "undefined") {
Expand All @@ -603,7 +652,9 @@ export function tokenBucket(
for (const opt of [options, ...additionalOptions]) {
const mode = opt.mode === "LIVE" ? "LIVE" : "DRY_RUN";
const match = opt.match;
const characteristics = opt.characteristics;
const characteristics = Array.isArray(opt.characteristics)
? opt.characteristics
: undefined;

const refillRate = opt.refillRate;
const interval = duration.parse(opt.interval);
Expand All @@ -625,10 +676,14 @@ export function tokenBucket(
return rules;
}

export function fixedWindow(
options?: FixedWindowRateLimitOptions,
...additionalOptions: FixedWindowRateLimitOptions[]
): Primitive {
export function fixedWindow<
const Characteristics extends readonly string[] = [],
>(
options?: FixedWindowRateLimitOptions<Characteristics>,
...additionalOptions: FixedWindowRateLimitOptions<Characteristics>[]
): Primitive<
UnionToIntersection<PropsForCharacteristic<Characteristics[number]>>
> {
const rules: ArcjetFixedWindowRateLimitRule<{}>[] = [];

if (typeof options === "undefined") {
Expand All @@ -638,7 +693,9 @@ export function fixedWindow(
for (const opt of [options, ...additionalOptions]) {
const mode = opt.mode === "LIVE" ? "LIVE" : "DRY_RUN";
const match = opt.match;
const characteristics = opt.characteristics;
const characteristics = Array.isArray(opt.characteristics)
? opt.characteristics
: undefined;

const max = opt.max;
const window = duration.parse(opt.window);
Expand All @@ -660,19 +717,25 @@ export function fixedWindow(

// This is currently kept for backwards compatibility but should be removed in
// favor of the fixedWindow primitive.
export function rateLimit(
options?: FixedWindowRateLimitOptions,
...additionalOptions: FixedWindowRateLimitOptions[]
): Primitive {
export function rateLimit<const Characteristics extends readonly string[] = []>(
options?: FixedWindowRateLimitOptions<Characteristics>,
...additionalOptions: FixedWindowRateLimitOptions<Characteristics>[]
): Primitive<
UnionToIntersection<PropsForCharacteristic<Characteristics[number]>>
> {
// TODO(#195): We should also have a local rate limit using an in-memory data
// structure if the environment supports it
return fixedWindow(options, ...additionalOptions);
}

export function slidingWindow(
options?: SlidingWindowRateLimitOptions,
...additionalOptions: SlidingWindowRateLimitOptions[]
): Primitive {
export function slidingWindow<
const Characteristics extends readonly string[] = [],
>(
options?: SlidingWindowRateLimitOptions<Characteristics>,
...additionalOptions: SlidingWindowRateLimitOptions<Characteristics>[]
): Primitive<
UnionToIntersection<PropsForCharacteristic<Characteristics[number]>>
> {
const rules: ArcjetSlidingWindowRateLimitRule<{}>[] = [];

if (typeof options === "undefined") {
Expand All @@ -682,7 +745,9 @@ export function slidingWindow(
for (const opt of [options, ...additionalOptions]) {
const mode = opt.mode === "LIVE" ? "LIVE" : "DRY_RUN";
const match = opt.match;
const characteristics = opt.characteristics;
const characteristics = Array.isArray(opt.characteristics)
? opt.characteristics
: undefined;

const max = opt.max;
const interval = duration.parse(opt.interval);
Expand Down Expand Up @@ -867,15 +932,23 @@ export function detectBot(
return rules;
}

export type ProtectSignupOptions = {
rateLimit?: SlidingWindowRateLimitOptions | SlidingWindowRateLimitOptions[];
export type ProtectSignupOptions<Characteristics extends string[]> = {
rateLimit?:
| SlidingWindowRateLimitOptions<Characteristics>
| SlidingWindowRateLimitOptions<Characteristics>[];
bots?: BotOptions | BotOptions[];
email?: EmailOptions | EmailOptions[];
};

export function protectSignup(
options?: ProtectSignupOptions,
): Product<{ email: string }> {
export function protectSignup<const Characteristics extends string[] = []>(
options?: ProtectSignupOptions<Characteristics>,
): Product<
Simplify<
UnionToIntersection<
{ email: string } | PropsForCharacteristic<Characteristics[number]>
>
>
> {
let rateLimitRules: Primitive<{}> = [];
if (Array.isArray(options?.rateLimit)) {
rateLimitRules = slidingWindow(...options.rateLimit);
Expand Down
29 changes: 24 additions & 5 deletions arcjet/test/index.edge.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,28 @@ describe("Arcjet: Env = Edge runtime", () => {
rules: [
// Test rules
foobarbaz(),
tokenBucket({
refillRate: 1,
interval: 1,
capacity: 1,
}),
tokenBucket(
{
characteristics: [
"ip.src",
"http.host",
"http.method",
"http.request.uri.path",
`http.request.headers["abc"]`,
`http.request.cookie["xyz"]`,
`http.request.uri.args["foobar"]`,
],
refillRate: 1,
interval: 1,
capacity: 1,
},
{
characteristics: ["userId"],
refillRate: 1,
interval: 1,
capacity: 1,
},
),
rateLimit({
max: 1,
window: "60s",
Expand All @@ -61,6 +78,8 @@ describe("Arcjet: Env = Edge runtime", () => {
path: "",
headers: new Headers(),
extra: {},
userId: "user123",
foobar: 123,
});

expect(decision.isErrored()).toBe(false);
Expand Down

0 comments on commit 97bfdf7

Please sign in to comment.