From f151a96826c2d65737ec780c4c51e3cbaef6815c Mon Sep 17 00:00:00 2001 From: Blaine Bublitz Date: Fri, 2 Feb 2024 15:06:05 -0700 Subject: [PATCH 1/3] feat!: Add fixedWindow, tokenBucket, and slidingWindow primitives feat!: Rework primitives to build rules with config & request details --- arcjet/index.ts | 451 ++++++++++++----- arcjet/test/index.edge.test.ts | 34 +- arcjet/test/index.node.test.ts | 461 +++++++++--------- .../nextjs-14-openai/app/api/chat/route.ts | 32 +- protocol/convert.ts | 62 ++- .../gen/es/decide/v1alpha1/decide_pb.d.ts | 103 +++- protocol/gen/es/decide/v1alpha1/decide_pb.js | 26 +- protocol/index.ts | 61 ++- protocol/test/convert.test.ts | 41 +- 9 files changed, 855 insertions(+), 416 deletions(-) diff --git a/arcjet/index.ts b/arcjet/index.ts index 5ef873eea..34d96b6ec 100644 --- a/arcjet/index.ts +++ b/arcjet/index.ts @@ -18,6 +18,11 @@ import { ArcjetRule, ArcjetLocalRule, ArcjetRequestDetails, + ArcjetPrimitive, + ArcjetProduct, + ArcjetTokenBucketRateLimitRule, + ArcjetFixedWindowRateLimitRule, + ArcjetSlidingWindowRateLimitRule, } from "@arcjet/protocol"; import { ArcjetBotTypeToProtocol, @@ -371,13 +376,29 @@ function runtime(): Runtime { } } -export type RateLimitOptions = { +type TokenBucketRateLimitOptions = { + mode?: ArcjetMode; + match?: string; + characteristics?: string[]; + refillRate: number; + interval: number; + capacity: number; +}; + +type FixedWindowRateLimitOptions = { mode?: ArcjetMode; match?: string; characteristics?: string[]; window: string; max: number; - timeout: string; +}; + +type SlidingWindowRateLimitOptions = { + mode?: ArcjetMode; + match?: string; + characteristics?: string[]; + interval: number; + max: number; }; /** @@ -482,28 +503,33 @@ const Priority = { type PlainObject = { [key: string]: unknown }; -type PropsForRule = R extends ArcjetRule ? Props : {}; +// Primitives and Products can be specified in a variety of ways and are +// externally grouped as `rules` +// See ExtraRules below for further explanation on why we define them like this. +type PrimitivesOrProduct = + | ArcjetPrimitive + | ArcjetPrimitive[] + | ArcjetProduct; + +type PropsForRule = R extends PrimitivesOrProduct[] + ? Props + : R extends PrimitivesOrProduct + ? Props + : {}; // We theoretically support an arbitrary amount of rule flattening, // but one level seems to be easiest; however, this puts a constraint of // the definition of `Product` such that they need to spread each `Primitive` // they are re-exporting. export type ExtraProps = Rules extends [] ? {} - : Rules extends ArcjetRule[][] - ? UnionToIntersection> - : Rules extends ArcjetRule[] - ? UnionToIntersection> - : never; + : Rules extends PrimitivesOrProduct[] + ? UnionToIntersection> + : never; export type ArcjetRequest = Simplify< - Partial + Partial & Props >; -// Primitives and Products are the external names for Rules even though they are defined the same -// See ArcjetRequest above for the explanation on why we define them like this. -export type Primitive = ArcjetRule[]; -export type Product = ArcjetRule[]; - function isLocalRule( rule: ArcjetRule, ): rule is ArcjetLocalRule { @@ -515,63 +541,227 @@ function isLocalRule( ); } -export function rateLimit( - options?: RateLimitOptions, - ...additionalOptions: RateLimitOptions[] -): Primitive { - // TODO(#195): We should also have a local rate limit using an in-memory data - // structure if the environment supports it +class ArcjetTokenBucketRateLimitPrimitive extends ArcjetPrimitive<{ + requested: number; +}> { + priority = Priority.RateLimit; + + mode: ArcjetMode; + match?: string; + characteristics?: string[]; + refillRate: number; + interval: number; + capacity: number; + + constructor(options: TokenBucketRateLimitOptions) { + super(); + + this.mode = options.mode === "LIVE" ? "LIVE" : "DRY_RUN"; + this.match = options.match; + this.characteristics = options.characteristics; + + this.refillRate = options.refillRate; + this.interval = options.interval; + this.capacity = options.capacity; + } + + rule( + context: ArcjetContext, + details: Partial, + ): ArcjetTokenBucketRateLimitRule<{ requested: number }> { + return { + type: "RATE_LIMIT", + mode: this.mode, + match: this.match, + characteristics: this.characteristics, + algorithm: "TOKEN_BUCKET", + refillRate: this.refillRate, + interval: this.interval, + capacity: this.capacity, + requested: typeof details.requested === "number" ? details.requested : 1, + }; + } +} + +class ArcjetFixedWindowRateLimitPrimitive extends ArcjetPrimitive { + priority = Priority.RateLimit; + + mode: ArcjetMode; + match?: string; + characteristics?: string[]; - const rules: ArcjetRateLimitRule<{}>[] = []; + max: number; + window: string; + + constructor(options: FixedWindowRateLimitOptions) { + super(); + + this.mode = options.mode === "LIVE" ? "LIVE" : "DRY_RUN"; + this.match = options.match; + this.characteristics = options.characteristics; + + this.max = options.max; + this.window = options.window; + } + + rule( + context: ArcjetContext, + details: Partial, + ): ArcjetFixedWindowRateLimitRule<{}> { + return { + type: "RATE_LIMIT", + mode: this.mode, + match: this.match, + characteristics: this.characteristics, + algorithm: "FIXED_WINDOW", + max: this.max, + window: this.window, + }; + } +} + +class ArcjetSlidingWindowRateLimitPrimitive extends ArcjetPrimitive { + priority = Priority.RateLimit; + + mode: ArcjetMode; + match?: string; + characteristics?: string[]; + + max: number; + interval: number; + + constructor(options: SlidingWindowRateLimitOptions) { + super(); + + this.mode = options.mode === "LIVE" ? "LIVE" : "DRY_RUN"; + this.match = options.match; + this.characteristics = options.characteristics; + + this.max = options.max; + this.interval = options.interval; + } + + rule( + context: ArcjetContext, + details: Partial, + ): ArcjetSlidingWindowRateLimitRule<{}> { + return { + type: "RATE_LIMIT", + mode: this.mode, + match: this.match, + characteristics: this.characteristics, + algorithm: "SLIDING_WINDOW", + max: this.max, + interval: this.interval, + }; + } +} + +export function tokenBucket( + options?: TokenBucketRateLimitOptions, + ...additionalOptions: TokenBucketRateLimitOptions[] +) { + const primitives: ArcjetTokenBucketRateLimitPrimitive[] = []; if (typeof options === "undefined") { - return rules; + return primitives; } for (const opt of [options, ...additionalOptions]) { - const mode = opt.mode === "LIVE" ? "LIVE" : "DRY_RUN"; + primitives.push(new ArcjetTokenBucketRateLimitPrimitive(opt)); + } - rules.push({ - type: "RATE_LIMIT", - priority: Priority.RateLimit, - mode, - match: opt.match, - characteristics: opt.characteristics, - window: opt.window, - max: opt.max, - timeout: opt.timeout, - }); + return primitives; +} + +export function fixedWindow( + options?: FixedWindowRateLimitOptions, + ...additionalOptions: FixedWindowRateLimitOptions[] +) { + const primitives: ArcjetFixedWindowRateLimitPrimitive[] = []; + + if (typeof options === "undefined") { + return primitives; } - return rules; + for (const opt of [options, ...additionalOptions]) { + primitives.push(new ArcjetFixedWindowRateLimitPrimitive(opt)); + } + + return primitives; } -export function validateEmail( - options?: EmailOptions, - ...additionalOptions: EmailOptions[] -): Primitive<{ email: string }> { - const rules: ArcjetEmailRule<{ email: string }>[] = []; +// This is currently kept for backwards compatibility but should be removed in +// favor of the fixedWindow primitive. +export function rateLimit( + options?: FixedWindowRateLimitOptions, + ...additionalOptions: FixedWindowRateLimitOptions[] +) { + const primitives: ArcjetFixedWindowRateLimitPrimitive[] = []; - // Always create at least one EMAIL rule - for (const opt of [options ?? {}, ...additionalOptions]) { - const mode = opt.mode === "LIVE" ? "LIVE" : "DRY_RUN"; + if (typeof options === "undefined") { + return primitives; + } + + for (const opt of [options, ...additionalOptions]) { + primitives.push(new ArcjetFixedWindowRateLimitPrimitive(opt)); + } + + return primitives; +} + +export function slidingWindow( + options?: SlidingWindowRateLimitOptions, + ...additionalOptions: SlidingWindowRateLimitOptions[] +) { + const primitives: ArcjetSlidingWindowRateLimitPrimitive[] = []; + + if (typeof options === "undefined") { + return primitives; + } + + for (const opt of [options, ...additionalOptions]) { + primitives.push(new ArcjetSlidingWindowRateLimitPrimitive(opt)); + } + + return primitives; +} + +class ArcjetValidateEmailPrimitive extends ArcjetPrimitive<{ email: string }> { + priority = Priority.EmailValidation; + + mode: ArcjetMode; + block: ArcjetEmailType[]; + + requireTopLevelDomain: boolean; + allowDomainLiteral: boolean; + + constructor(options: EmailOptions) { + super(); + + this.mode = options.mode === "LIVE" ? "LIVE" : "DRY_RUN"; // TODO: Filter invalid email types (or error??) - const block = opt.block ?? []; - const requireTopLevelDomain = opt.requireTopLevelDomain ?? true; - const allowDomainLiteral = opt.allowDomainLiteral ?? false; + this.block = options.block ?? []; + this.requireTopLevelDomain = options.requireTopLevelDomain ?? true; + this.allowDomainLiteral = options.allowDomainLiteral ?? false; + } + + rule( + context: ArcjetContext, + details: Partial, + ): ArcjetEmailRule<{ email: string }> { const analyzeOpts = { - requireTopLevelDomain, - allowDomainLiteral, + requireTopLevelDomain: this.requireTopLevelDomain, + allowDomainLiteral: this.allowDomainLiteral, }; - rules.push({ + return { type: "EMAIL", - priority: Priority.EmailValidation, - mode, - block, - requireTopLevelDomain, - allowDomainLiteral, + mode: this.mode, + block: this.block, + requireTopLevelDomain: this.requireTopLevelDomain, + allowDomainLiteral: this.allowDomainLiteral, validate( context: ArcjetContext, @@ -605,41 +795,60 @@ export function validateEmail( }); } }, - }); + }; } - - return rules; } -export function detectBot( - options?: BotOptions, - ...additionalOptions: BotOptions[] -): Primitive { - const rules: ArcjetBotRule<{}>[] = []; +export function validateEmail( + options?: EmailOptions, + ...additionalOptions: EmailOptions[] +) { + const primitives: ArcjetValidateEmailPrimitive[] = []; - // Always create at least one BOT rule + // Always create at least one EMAIL rule for (const opt of [options ?? {}, ...additionalOptions]) { - const mode = opt.mode === "LIVE" ? "LIVE" : "DRY_RUN"; + primitives.push(new ArcjetValidateEmailPrimitive(opt)); + } + + return primitives; +} + +class ArcjetDetectBotPrimitive extends ArcjetPrimitive { + priority = Priority.BotDetection; + + mode: ArcjetMode; + block: ArcjetBotType[]; + add: [string, ArcjetBotType][]; + remove: string[]; + + constructor(options: BotOptions) { + super(); + + this.mode = options.mode === "LIVE" ? "LIVE" : "DRY_RUN"; // TODO: Filter invalid email types (or error??) - const block = Array.isArray(opt.block) - ? opt.block + this.block = Array.isArray(options.block) + ? options.block : [ArcjetBotType.AUTOMATED]; // TODO: Does this avoid prototype pollution by putting in a Map first? - const addMap = new Map(); - for (const [key, value] of Object.entries(opt.patterns?.add ?? {})) { + const addMap = new Map(); + for (const [key, value] of Object.entries(options.patterns?.add ?? {})) { addMap.set(key, value); } // TODO(#217): Additional validation on these `patterns` options - const add = Array.from(addMap.entries()); - const remove = opt.patterns?.remove ?? []; + this.add = Array.from(addMap.entries()); + this.remove = options.patterns?.remove ?? []; + } - rules.push({ + rule( + context: ArcjetContext, + details: Partial, + ): ArcjetBotRule<{}> { + return { type: "BOT", - priority: Priority.BotDetection, - mode, - block, - add, - remove, + mode: this.mode, + block: this.block, + add: this.add, + remove: this.remove, validate( context: ArcjetContext, @@ -668,19 +877,19 @@ export function detectBot( JSON.stringify(headersKV), JSON.stringify( Object.fromEntries( - add.map(([key, botType]) => [ + this.add.map(([key, botType]) => [ key, ArcjetBotTypeToProtocol(botType), ]), ), ), - JSON.stringify(remove), + JSON.stringify(this.remove), ); // If this is a bot and of a type that we want to block, then block! if ( botResult.bot_score !== 0 && - block.includes(BotType[botResult.bot_type] as ArcjetBotType) + this.block.includes(BotType[botResult.bot_type] as ArcjetBotType) ) { return new ArcjetRuleResult({ ttl: 60000, @@ -705,46 +914,62 @@ export function detectBot( }); } }, - }); + }; } +} - return rules; +export function detectBot( + options?: BotOptions, + ...additionalOptions: BotOptions[] +) { + const primitives: ArcjetDetectBotPrimitive[] = []; + + // Always create at least one BOT rule + for (const opt of [options ?? {}, ...additionalOptions]) { + primitives.push(new ArcjetDetectBotPrimitive(opt)); + } + + return primitives; } export type ProtectSignupOptions = { - rateLimit?: RateLimitOptions | RateLimitOptions[]; + rateLimit?: SlidingWindowRateLimitOptions | SlidingWindowRateLimitOptions[]; bots?: BotOptions | BotOptions[]; email?: EmailOptions | EmailOptions[]; }; export function protectSignup( options?: ProtectSignupOptions, -): Product<{ email: string }> { - let rateLimitRules: Primitive<{}> = []; +): ArcjetProduct<{ email: string }> { + let slidingWindowPrimitives: ArcjetSlidingWindowRateLimitPrimitive[] = []; if (Array.isArray(options?.rateLimit)) { - rateLimitRules = rateLimit(...options.rateLimit); + slidingWindowPrimitives = slidingWindow(...options.rateLimit); } else { - rateLimitRules = rateLimit(options?.rateLimit); + slidingWindowPrimitives = slidingWindow(options?.rateLimit); } - let botRules: Primitive<{}> = []; + let detectBotPrimitives: ArcjetDetectBotPrimitive[] = []; if (Array.isArray(options?.bots)) { - botRules = detectBot(...options.bots); + detectBotPrimitives = detectBot(...options.bots); } else { - botRules = detectBot(options?.bots); + detectBotPrimitives = detectBot(options?.bots); } - let emailRules: Primitive<{}> = []; + let emailPrimitives: ArcjetValidateEmailPrimitive[] = []; if (Array.isArray(options?.email)) { - emailRules = validateEmail(...options.email); + emailPrimitives = validateEmail(...options.email); } else { - emailRules = validateEmail(options?.email); + emailPrimitives = validateEmail(options?.email); } - return [...rateLimitRules, ...botRules, ...emailRules]; + return [ + ...slidingWindowPrimitives, + ...detectBotPrimitives, + ...emailPrimitives, + ]; } -export interface ArcjetOptions { +export interface ArcjetOptions { /** * The API key to identify the site in Arcjet. */ @@ -789,9 +1014,9 @@ export interface Arcjet { * * @param options {ArcjetOptions} Arcjet configuration options. */ -export default function arcjet< - const Rules extends [...(Primitive | Product)[]] = [], ->(options: ArcjetOptions): Arcjet>> { +export default function arcjet( + options: ArcjetOptions

, +): Arcjet>> { const log = new Logger(); // We destructure here to make the function signature neat when viewed by consumers @@ -808,14 +1033,20 @@ export default function arcjet< // TODO(#132): Support configurable caching const blockCache = new Cache(); - const flatSortedRules = rules.flat(1).sort((a, b) => a.priority - b.priority); + // However, we like the user-facing concept of `rules` specified as options to + // avoid the need for a distinction between primitives and products. However, + // we map the `rules` option to primitives internally because we still need to + // call `primitive.rule(context, details)` to produce the actual rule. + const flatSortedPrimitives = rules + .flat(1) + .sort((a, b) => a.priority - b.priority); return Object.freeze({ get runtime() { return runtime(); }, async protect( - request: ArcjetRequest>, + request: ArcjetRequest>, ): Promise { // This goes against the type definition above, but users might call // `protect()` with no value and we don't want to crash @@ -844,7 +1075,7 @@ export default function arcjet< const context: ArcjetContext = { key, fingerprint, log }; - if (flatSortedRules.length > 10) { + if (flatSortedPrimitives.length > 10) { log.error("Failure running rules. Only 10 rules may be specified."); const decision = new ArcjetErrorDecision({ @@ -867,9 +1098,14 @@ export default function arcjet< return decision; } + const rules: ArcjetRule[] = []; + for (const primitive of flatSortedPrimitives) { + rules.push(primitive.rule(context, details)); + } + const results: ArcjetRuleResult[] = []; // Default all rules to NOT_RUN/ALLOW before doing anything - for (let idx = 0; idx < flatSortedRules.length; idx++) { + for (let idx = 0; idx < rules.length; idx++) { results[idx] = new ArcjetRuleResult({ ttl: 0, state: "NOT_RUN", @@ -901,7 +1137,7 @@ export default function arcjet< results, }); - client.report(context, details, decision, flatSortedRules); + client.report(context, details, decision, rules); log.debug("decide: already blocked", { id: decision.id, @@ -914,7 +1150,7 @@ export default function arcjet< return decision; } - for (const [idx, rule] of flatSortedRules.entries()) { + for (const [idx, rule] of rules.entries()) { // This re-assignment is a workaround to a TypeScript error with // assertions where the name was introduced via a destructure let localRule: ArcjetLocalRule; @@ -968,7 +1204,7 @@ export default function arcjet< // Only a DENY decision is reported to avoid creating 2 entries for a // request. Upon ALLOW, the `decide` call will create an entry for the // request. - client.report(context, details, decision, flatSortedRules); + client.report(context, details, decision, rules); // If we're not in DRY_RUN mode, we want to cache non-zero TTL results // and return this DENY decision. @@ -1001,7 +1237,7 @@ export default function arcjet< // fail open. try { log.time("decideApi"); - const decision = await client.decide(context, details, flatSortedRules); + const decision = await client.decide(context, details, rules); log.timeEnd("decideApi"); // If the decision is to block and we have a non-zero TTL, we cache the @@ -1027,12 +1263,7 @@ export default function arcjet< results, }); - client.report( - { key, fingerprint, log }, - details, - decision, - flatSortedRules, - ); + client.report({ key, fingerprint, log }, details, decision, rules); return decision; } finally { diff --git a/arcjet/test/index.edge.test.ts b/arcjet/test/index.edge.test.ts index 9ec353aef..add06cecb 100644 --- a/arcjet/test/index.edge.test.ts +++ b/arcjet/test/index.edge.test.ts @@ -4,11 +4,16 @@ import { describe, expect, test, jest } from "@jest/globals"; import arcjet, { - rateLimit, - protectSignup, - Primitive, + tokenBucket, + // protectSignup, + // Primitive, ArcjetReason, ArcjetAllowDecision, + ArcjetPrimitive, + ArcjetRule, + fixedWindow, + validateEmail, + protectSignup, } from "../index"; class ArcjetTestReason extends ArcjetReason {} @@ -26,8 +31,15 @@ describe("Arcjet: Env = Edge runtime", () => { report: jest.fn(), }; - function foobarbaz(): Primitive<{ abc: number }> { - return []; + function foobarbaz(): ArcjetPrimitive<{ abc: number }> { + const testRule = class extends ArcjetPrimitive { + priority = 1; + rule(): ArcjetRule<{ abc: number }> { + return { type: "abc", mode: "DRY_RUN" }; + } + }; + + return new testRule(); } const aj = arcjet({ @@ -35,7 +47,16 @@ describe("Arcjet: Env = Edge runtime", () => { rules: [ // Test rules foobarbaz(), - rateLimit(), + tokenBucket({ + refillRate: 1, + interval: 1, + capacity: 1, + }), + fixedWindow({ + max: 1, + window: "60s", + }), + // validateEmail(), protectSignup(), ], client, @@ -43,6 +64,7 @@ describe("Arcjet: Env = Edge runtime", () => { const decision = await aj.protect({ abc: 123, + requested: 1, email: "", ip: "", method: "", diff --git a/arcjet/test/index.node.test.ts b/arcjet/test/index.node.test.ts index a7dbd7304..bc0c37e25 100644 --- a/arcjet/test/index.node.test.ts +++ b/arcjet/test/index.node.test.ts @@ -51,7 +51,10 @@ import arcjet, { ArcjetEmailReason, ArcjetBotReason, ArcjetRateLimitReason, + ArcjetPrimitive, ArcjetLocalRule, + ArcjetContext, + ArcjetRequestDetails, } from "../index"; // Instances of Headers contain symbols that may be different depending @@ -444,7 +447,6 @@ describe("createRemoteClient", () => { const rule: ArcjetRule = { type: "TEST_RULE", mode: "DRY_RUN", - priority: 1, }; const _ = await client.decide(context, details, [rule]); @@ -1117,7 +1119,6 @@ describe("createRemoteClient", () => { const rule: ArcjetRule = { type: "TEST_RULE", mode: "LIVE", - priority: 1, }; client.report(context, details, decision, [rule]); @@ -1398,7 +1399,6 @@ describe("ArcjetDecision", () => { test("`isRateLimit()` returns true when reason is RATE_LIMIT", () => { const reason = new ArcjetRateLimitReason({ max: 0, - count: 0, remaining: 0, }); expect(reason.isRateLimit()).toEqual(true); @@ -1423,22 +1423,22 @@ describe("ArcjetDecision", () => { }); describe("Primitives > detectBot", () => { - test("provides a default rule with no options specified", async () => { - const [rule] = detectBot(); - expect(rule.type).toEqual("BOT"); - expect(rule).toHaveProperty("mode", "DRY_RUN"); - expect(rule).toHaveProperty("block", ["AUTOMATED"]); - expect(rule).toHaveProperty("add", []); - expect(rule).toHaveProperty("remove", []); + test("provides a default primitive with no options specified", async () => { + const [primitive] = detectBot(); + expect(primitive).toHaveProperty("priority", 2); + expect(primitive).toHaveProperty("rule"); + expect(primitive).toHaveProperty("mode", "DRY_RUN"); + expect(primitive).toHaveProperty("block", ["AUTOMATED"]); + expect(primitive).toHaveProperty("add", []); + expect(primitive).toHaveProperty("remove", []); }); test("sets mode as 'DRY_RUN' if not 'LIVE' or 'DRY_RUN'", async () => { - const [rule] = detectBot({ + const [primitive] = detectBot({ // @ts-expect-error mode: "INVALID", }); - expect(rule.type).toEqual("BOT"); - expect(rule).toHaveProperty("mode", "DRY_RUN"); + expect(primitive).toHaveProperty("mode", "DRY_RUN"); }); test("allows specifying BotTypes to block", async () => { @@ -1451,9 +1451,8 @@ describe("Primitives > detectBot", () => { ], }; - const [rule] = detectBot(options); - expect(rule.type).toEqual("BOT"); - expect(rule).toHaveProperty("block", [ + const [primitive] = detectBot(options); + expect(primitive).toHaveProperty("block", [ "LIKELY_AUTOMATED", "LIKELY_NOT_A_BOT", "NOT_ANALYZED", @@ -1470,9 +1469,8 @@ describe("Primitives > detectBot", () => { }, }; - const [rule] = detectBot(options); - expect(rule.type).toEqual("BOT"); - expect(rule).toHaveProperty("add", [["safari", "LIKELY_AUTOMATED"]]); + const [primitive] = detectBot(options); + expect(primitive).toHaveProperty("add", [["safari", "LIKELY_AUTOMATED"]]); }); test("allows specifying `remove` patterns", async () => { @@ -1482,9 +1480,8 @@ describe("Primitives > detectBot", () => { }, }; - const [rule] = detectBot(options); - expect(rule.type).toEqual("BOT"); - expect(rule).toHaveProperty("remove", ["^curl"]); + const [primitive] = detectBot(options); + expect(primitive).toHaveProperty("remove", ["^curl"]); }); test("validates that headers is defined", () => { @@ -1497,8 +1494,8 @@ describe("Primitives > detectBot", () => { headers: new Headers(), }; - const [rule] = detectBot(); - expect(rule.type).toEqual("BOT"); + const [primitive] = detectBot(); + const rule = primitive.rule(context, details); assertIsLocalRule(rule); expect(() => { const _ = rule.validate(context, details); @@ -1515,8 +1512,8 @@ describe("Primitives > detectBot", () => { headers: undefined, }; - const [rule] = detectBot(); - expect(rule.type).toEqual("BOT"); + const [primitive] = detectBot(); + const rule = primitive.rule(context, details); assertIsLocalRule(rule); expect(() => { const _ = rule.validate(context, details); @@ -1539,8 +1536,8 @@ describe("Primitives > detectBot", () => { extra: {}, }; - const [rule] = detectBot(); - expect(rule.type).toEqual("BOT"); + const [primitive] = detectBot(); + const rule = primitive.rule(context, details); assertIsLocalRule(rule); const result = await rule.protect(context, details); expect(result).toMatchObject({ @@ -1588,8 +1585,8 @@ describe("Primitives > detectBot", () => { }, }; - const [rule] = detectBot(options); - expect(rule.type).toEqual("BOT"); + const [primitive] = detectBot(options); + const rule = primitive.rule(context, details); assertIsLocalRule(rule); const result = await rule.protect(context, details); expect(result).toMatchObject({ @@ -1640,8 +1637,8 @@ describe("Primitives > detectBot", () => { }, }; - const [rule] = detectBot(options); - expect(rule.type).toEqual("BOT"); + const [primitive] = detectBot(options); + const rule = primitive.rule(context, details); assertIsLocalRule(rule); const result = await rule.protect(context, details); expect(result).toMatchObject({ @@ -1692,8 +1689,8 @@ describe("Primitives > detectBot", () => { }, }; - const [rule] = detectBot(options); - expect(rule.type).toEqual("BOT"); + const [primitive] = detectBot(options); + const rule = primitive.rule(context, details); assertIsLocalRule(rule); const result = await rule.protect(context, details); expect(result).toMatchObject({ @@ -1731,7 +1728,7 @@ describe("Primitives > detectBot", () => { }, }; - const [rule] = detectBot({ + const [primitive] = detectBot({ mode: ArcjetMode.LIVE, block: [ // TODO: Fix this in the analyze code so it returns the BotType specified via `add` @@ -1746,7 +1743,7 @@ describe("Primitives > detectBot", () => { }, }, }); - expect(rule.type).toEqual("BOT"); + const rule = primitive.rule(context, details); assertIsLocalRule(rule); const result = await rule.protect(context, details); expect(result).toMatchObject({ @@ -1787,8 +1784,8 @@ describe("Primitives > detectBot", () => { }, }; - const [rule] = detectBot(options); - expect(rule.type).toEqual("BOT"); + const [primitive] = detectBot(options); + const rule = primitive.rule(context, details); assertIsLocalRule(rule); const result = await rule.protect(context, details); expect(result).toMatchObject({ @@ -1838,8 +1835,8 @@ describe("Primitives > detectBot", () => { }, }; - const [rule] = detectBot(options); - expect(rule.type).toEqual("BOT"); + const [primitive] = detectBot(options); + const rule = primitive.rule(context, details); assertIsLocalRule(rule); const result = await rule.protect(context, details); expect(result).toMatchObject({ @@ -1878,8 +1875,8 @@ describe("Primitives > detectBot", () => { }, }; - const [rule] = detectBot(options); - expect(rule.type).toEqual("BOT"); + const [primitive] = detectBot(options); + const rule = primitive.rule(context, details); assertIsLocalRule(rule); const result = await rule.protect(context, details); expect(result).toMatchObject({ @@ -1894,13 +1891,13 @@ describe("Primitives > detectBot", () => { }); describe("Primitive > rateLimit", () => { - test("provides no rules if no `options` specified", () => { - const rules = rateLimit(); - expect(rules).toHaveLength(0); + test("provides no primitives if no `options` specified", () => { + const primitives = rateLimit(); + expect(primitives).toHaveLength(0); }); test("sets mode as `DRY_RUN` if not 'LIVE' or 'DRY_RUN'", async () => { - const [rule] = rateLimit({ + const [primitive] = rateLimit({ // @ts-expect-error mode: "INVALID", match: "/test", @@ -1909,68 +1906,60 @@ describe("Primitive > rateLimit", () => { max: 1, timeout: "10m", }); - expect(rule.type).toEqual("RATE_LIMIT"); - expect(rule).toHaveProperty("mode", "DRY_RUN"); + expect(primitive).toHaveProperty("mode", "DRY_RUN"); }); - test("produces a rules based on single `limit` specified", async () => { + test("produces a primitive based on single `limit` specified", async () => { const options = { match: "/test", characteristics: ["ip.src"], window: "1h", max: 1, - timeout: "10m", }; - const rules = rateLimit(options); - expect(rules).toHaveLength(1); - expect(rules[0].type).toEqual("RATE_LIMIT"); - expect(rules[0]).toHaveProperty("mode", "DRY_RUN"); - expect(rules[0]).toHaveProperty("match", "/test"); - expect(rules[0]).toHaveProperty("characteristics", ["ip.src"]); - expect(rules[0]).toHaveProperty("window", "1h"); - expect(rules[0]).toHaveProperty("max", 1); - expect(rules[0]).toHaveProperty("timeout", "10m"); + const primitives = rateLimit(options); + expect(primitives).toHaveLength(1); + expect(primitives[0]).toHaveProperty("mode", "DRY_RUN"); + expect(primitives[0]).toHaveProperty("match", "/test"); + expect(primitives[0]).toHaveProperty("characteristics", ["ip.src"]); + expect(primitives[0]).toHaveProperty("window", "1h"); + expect(primitives[0]).toHaveProperty("max", 1); }); - test("produces a multiple rules based on multiple `limit` specified", async () => { + test("produces a multiple primitives based on multiple `limit` specified", async () => { const options = [ { match: "/test", characteristics: ["ip.src"], window: "1h", max: 1, - timeout: "10m", }, { match: "/test-double", characteristics: ["ip.src"], window: "2h", max: 2, - timeout: "20m", }, ]; - const rules = rateLimit(...options); - expect(rules).toHaveLength(2); - expect(rules).toEqual([ + const primitives = rateLimit(...options); + expect(primitives).toHaveLength(2); + expect(primitives).toEqual([ expect.objectContaining({ - type: "RATE_LIMIT", + priority: 1, mode: "DRY_RUN", match: "/test", characteristics: ["ip.src"], window: "1h", max: 1, - timeout: "10m", }), expect.objectContaining({ - type: "RATE_LIMIT", + priority: 1, mode: "DRY_RUN", match: "/test-double", characteristics: ["ip.src"], window: "2h", max: 2, - timeout: "20m", }), ]); }); @@ -1982,10 +1971,9 @@ describe("Primitive > rateLimit", () => { timeout: "10m", }; - const [rule] = rateLimit(options); - expect(rule.type).toEqual("RATE_LIMIT"); - expect(rule).toHaveProperty("match", undefined); - expect(rule).toHaveProperty("characteristics", undefined); + const [primitives] = rateLimit(options); + expect(primitives).toHaveProperty("match", undefined); + expect(primitives).toHaveProperty("characteristics", undefined); }); test("does not default `match` or `characteristics` if not specified in array `limit`", async () => { @@ -2002,48 +1990,43 @@ describe("Primitive > rateLimit", () => { }, ]; - const rules = rateLimit(...options); - expect(rules).toEqual([ + const primitives = rateLimit(...options); + expect(primitives).toEqual([ expect.objectContaining({ - type: "RATE_LIMIT", + priority: 1, mode: "DRY_RUN", match: undefined, characteristics: undefined, window: "1h", max: 1, - timeout: "10m", }), expect.objectContaining({ - type: "RATE_LIMIT", + priority: 1, mode: "DRY_RUN", match: undefined, characteristics: undefined, window: "2h", max: 2, - timeout: "20m", }), ]); }); }); describe("Primitives > validateEmail", () => { - test("provides a default rule with no options specified", async () => { - const [rule] = validateEmail(); - expect(rule.type).toEqual("EMAIL"); - expect(rule).toHaveProperty("mode", "DRY_RUN"); - expect(rule).toHaveProperty("block", []); - expect(rule).toHaveProperty("requireTopLevelDomain", true); - expect(rule).toHaveProperty("allowDomainLiteral", false); - assertIsLocalRule(rule); + test("provides a default primitive with no options specified", async () => { + const [primitive] = validateEmail(); + expect(primitive).toHaveProperty("mode", "DRY_RUN"); + expect(primitive).toHaveProperty("block", []); + expect(primitive).toHaveProperty("requireTopLevelDomain", true); + expect(primitive).toHaveProperty("allowDomainLiteral", false); }); test("sets mode as 'DRY_RUN' if not 'LIVE' or 'DRY_RUN'", async () => { - const [rule] = validateEmail({ + const [primitive] = validateEmail({ // @ts-expect-error mode: "INVALID", }); - expect(rule.type).toEqual("EMAIL"); - expect(rule).toHaveProperty("mode", "DRY_RUN"); + expect(primitive).toHaveProperty("mode", "DRY_RUN"); }); test("allows specifying EmailTypes to block", async () => { @@ -2057,9 +2040,8 @@ describe("Primitives > validateEmail", () => { ], }; - const [rule] = validateEmail(options); - expect(rule.type).toEqual("EMAIL"); - expect(rule).toHaveProperty("block", [ + const [primitive] = validateEmail(options); + expect(primitive).toHaveProperty("block", [ "DISPOSABLE", "FREE", "NO_GRAVATAR", @@ -2078,8 +2060,8 @@ describe("Primitives > validateEmail", () => { email: "abc@example.com", }; - const [rule] = validateEmail(); - expect(rule.type).toEqual("EMAIL"); + const [primitive] = validateEmail(); + const rule = primitive.rule(context, details); assertIsLocalRule(rule); expect(() => { const _ = rule.validate(context, details); @@ -2096,8 +2078,8 @@ describe("Primitives > validateEmail", () => { email: undefined, }; - const [rule] = validateEmail(); - expect(rule.type).toEqual("EMAIL"); + const [primitive] = validateEmail(); + const rule = primitive.rule(context, details); assertIsLocalRule(rule); expect(() => { const _ = rule.validate(context, details); @@ -2121,8 +2103,8 @@ describe("Primitives > validateEmail", () => { email: "foobarbaz@example.com", }; - const [rule] = validateEmail(); - expect(rule.type).toEqual("EMAIL"); + const [primitive] = validateEmail(); + const rule = primitive.rule(context, details); assertIsLocalRule(rule); const result = await rule.protect(context, details); expect(result).toMatchObject({ @@ -2151,8 +2133,8 @@ describe("Primitives > validateEmail", () => { email: "foobarbaz", }; - const [rule] = validateEmail(); - expect(rule.type).toEqual("EMAIL"); + const [primitive] = validateEmail(); + const rule = primitive.rule(context, details); assertIsLocalRule(rule); const result = await rule.protect(context, details); expect(result).toMatchObject({ @@ -2181,8 +2163,8 @@ describe("Primitives > validateEmail", () => { email: "foobarbaz@localhost", }; - const [rule] = validateEmail(); - expect(rule.type).toEqual("EMAIL"); + const [primitive] = validateEmail(); + const rule = primitive.rule(context, details); assertIsLocalRule(rule); const result = await rule.protect(context, details); expect(result).toMatchObject({ @@ -2211,10 +2193,10 @@ describe("Primitives > validateEmail", () => { email: "foobarbaz@localhost", }; - const [rule] = validateEmail({ + const [primitive] = validateEmail({ block: [], }); - expect(rule.type).toEqual("EMAIL"); + const rule = primitive.rule(context, details); assertIsLocalRule(rule); const result = await rule.protect(context, details); expect(result).toMatchObject({ @@ -2243,8 +2225,8 @@ describe("Primitives > validateEmail", () => { email: "@example.com", }; - const [rule] = validateEmail(); - expect(rule.type).toEqual("EMAIL"); + const [primitive] = validateEmail(); + const rule = primitive.rule(context, details); assertIsLocalRule(rule); const result = await rule.protect(context, details); expect(result).toMatchObject({ @@ -2273,8 +2255,8 @@ describe("Primitives > validateEmail", () => { email: "foobarbaz@[127.0.0.1]", }; - const [rule] = validateEmail(); - expect(rule.type).toEqual("EMAIL"); + const [primitive] = validateEmail(); + const rule = primitive.rule(context, details); assertIsLocalRule(rule); const result = await rule.protect(context, details); expect(result).toMatchObject({ @@ -2303,10 +2285,10 @@ describe("Primitives > validateEmail", () => { email: "foobarbaz@localhost", }; - const [rule] = validateEmail({ + const [primitive] = validateEmail({ requireTopLevelDomain: false, }); - expect(rule.type).toEqual("EMAIL"); + const rule = primitive.rule(context, details); assertIsLocalRule(rule); const result = await rule.protect(context, details); expect(result).toMatchObject({ @@ -2335,10 +2317,10 @@ describe("Primitives > validateEmail", () => { email: "foobarbaz@[127.0.0.1]", }; - const [rule] = validateEmail({ + const [primitive] = validateEmail({ allowDomainLiteral: true, }); - expect(rule.type).toEqual("EMAIL"); + const rule = primitive.rule(context, details); assertIsLocalRule(rule); const result = await rule.protect(context, details); expect(result).toMatchObject({ @@ -2358,9 +2340,8 @@ describe("Products > protectSignup", () => { mode: ArcjetMode.DRY_RUN, match: "/test", characteristics: ["ip.src"], - window: "1h", + interval: 60 /* minutes */ * 60 /* seconds */, max: 1, - timeout: "10m", }, bots: { mode: ArcjetMode.DRY_RUN, @@ -2379,16 +2360,14 @@ describe("Products > protectSignup", () => { mode: ArcjetMode.DRY_RUN, match: "/test", characteristics: ["ip.src"], - window: "1h", + interval: 60 /* minutes */ * 60 /* seconds */, max: 1, - timeout: "10m", }, { match: "/test", characteristics: ["ip.src"], - window: "2h", + interval: 2 /* hours */ * 60 /* minutes */ * 60 /* seconds */, max: 2, - timeout: "20m", }, ], }); @@ -2425,92 +2404,91 @@ describe("Products > protectSignup", () => { }); describe("SDK", () => { - function testRuleLocalAllowed(): ArcjetLocalRule { - return { - mode: ArcjetMode.LIVE, - type: "TEST_RULE_LOCAL_ALLOWED", - priority: 1, - validate: jest.fn(), - protect: jest.fn( - async () => - new ArcjetRuleResult({ - ttl: 0, - state: "RUN", - conclusion: "ALLOW", - reason: new ArcjetTestReason(), - }), - ), - }; - } - function testRuleLocalDenied(): ArcjetLocalRule { - return { - mode: ArcjetMode.LIVE, - type: "TEST_RULE_LOCAL_DENIED", - priority: 1, - validate: jest.fn(), - protect: jest.fn( - async () => - new ArcjetRuleResult({ - ttl: 5000, - state: "RUN", - conclusion: "DENY", - reason: new ArcjetTestReason(), - }), - ), - }; + class TestPrimitiveLocalAllowed extends ArcjetPrimitive { + priority = 1; + type = "TEST_RULE_LOCAL_ALLOWED"; + mode = ArcjetMode.LIVE; + validate = jest.fn(); + protect = jest.fn( + async () => + new ArcjetRuleResult({ + ttl: 0, + state: "RUN", + conclusion: "ALLOW", + reason: new ArcjetTestReason(), + }), + ); + rule(): ArcjetRule<{}> { + return this; + } } - function testRuleRemote(): ArcjetRule { - return { - mode: "LIVE", - type: "TEST_RULE_REMOTE", - priority: 1, - }; + class TestPrimitiveLocalDenied extends ArcjetPrimitive { + priority = 1; + type = "TEST_RULE_LOCAL_DENIED"; + mode = ArcjetMode.LIVE; + validate = jest.fn(); + protect = jest.fn( + async () => + new ArcjetRuleResult({ + ttl: 5000, + state: "RUN", + conclusion: "DENY", + reason: new ArcjetTestReason(), + }), + ); + rule() { + return this; + } } - function testRuleMultiple(): ArcjetRule[] { - return [ - { mode: "LIVE", type: "TEST_RULE_MULTIPLE", priority: 1 }, - { mode: "LIVE", type: "TEST_RULE_MULTIPLE", priority: 1 }, - { mode: "LIVE", type: "TEST_RULE_MULTIPLE", priority: 1 }, - ]; + class TestPrimitiveRemote extends ArcjetPrimitive { + priority = 1; + type = "TEST_RULE_REMOTE"; + mode = ArcjetMode.LIVE; + rule() { + return this; + } } - function testRuleInvalidType(): ArcjetRule { - return { - mode: ArcjetMode.LIVE, - type: "TEST_RULE_INVALID_TYPE", - priority: 1, - }; + class TestPrimitiveInvalidType extends ArcjetPrimitive { + priority = 1; + type = "TEST_RULE_INVALID_TYPE"; + mode = ArcjetMode.LIVE; + rule() { + return this; + } } - function testRuleLocalThrow(): ArcjetLocalRule { - return { - mode: ArcjetMode.LIVE, - type: "TEST_RULE_LOCAL_THROW", - priority: 1, - validate: jest.fn(), - protect: jest.fn(async () => { - throw new Error("Local rule protect failed"); - }), - }; + class TestPrimitiveLocalThrow extends ArcjetPrimitive { + priority = 1; + type = "TEST_RULE_LOCAL_THROW"; + mode = ArcjetMode.LIVE; + validate = jest.fn(); + protect = jest.fn(async () => { + throw new Error("Local rule protect failed"); + }); + rule() { + return this; + } } - function testRuleLocalDryRun(): ArcjetLocalRule { - return { - mode: ArcjetMode.DRY_RUN, - type: "TEST_RULE_LOCAL_DRY_RUN", - priority: 1, - validate: jest.fn(), - protect: jest.fn(async () => { - return new ArcjetRuleResult({ - ttl: 0, - state: "RUN", - conclusion: "DENY", - reason: new ArcjetTestReason(), - }); - }), - }; + class TestPrimitiveLocalDryRun extends ArcjetPrimitive { + priority = 1; + mode = ArcjetMode.DRY_RUN; + type = "TEST_RULE_LOCAL_DRY_RUN"; + validate = jest.fn(); + protect = jest.fn(async () => { + return new ArcjetRuleResult({ + ttl: 0, + state: "RUN", + conclusion: "DENY", + reason: new ArcjetTestReason(), + }); + }); + rule() { + return this; + } } test("creates a new Arcjet SDK with no rules", () => { @@ -2583,7 +2561,7 @@ describe("SDK", () => { const aj = arcjet({ key: "test-key", - rules: [[testRuleLocalAllowed(), testRuleLocalDenied()]], + rules: [new TestPrimitiveLocalAllowed(), new TestPrimitiveLocalDenied()], client, }); expect(aj).toHaveProperty("protect"); @@ -2604,7 +2582,7 @@ describe("SDK", () => { const aj = arcjet({ key: "test-key", - rules: [[testRuleRemote()]], + rules: [new TestPrimitiveRemote()], client, }); expect(aj).toHaveProperty("protect"); @@ -2626,7 +2604,9 @@ describe("SDK", () => { const aj = arcjet({ key: "test-key", rules: [ - [testRuleLocalAllowed(), testRuleLocalDenied(), testRuleRemote()], + new TestPrimitiveLocalAllowed(), + new TestPrimitiveLocalDenied(), + new TestPrimitiveRemote(), ], client, }); @@ -2667,12 +2647,12 @@ describe("SDK", () => { "extra-test": "extra-test-value", }, }; - const allowed = testRuleLocalAllowed(); - const denied = testRuleLocalDenied(); + const allowed = new TestPrimitiveLocalAllowed(); + const denied = new TestPrimitiveLocalDenied(); const aj = arcjet({ key: "test-key", - rules: [[allowed, denied]], + rules: [allowed, denied], client, }); @@ -2746,10 +2726,9 @@ describe("SDK", () => { const details = {}; - const rules: ArcjetRule[][] = []; - // We only iterate 4 times because `testRuleMultiple` generates 3 rules - for (let idx = 0; idx < 4; idx++) { - rules.push(testRuleMultiple()); + const rules: ArcjetPrimitive[] = []; + for (let idx = 0; idx < 11; idx++) { + rules.push(new TestPrimitiveRemote()); } const aj = arcjet({ @@ -2785,12 +2764,12 @@ describe("SDK", () => { "extra-test": "extra-test-value", }, }; - const allowed = testRuleLocalAllowed(); - const denied = testRuleLocalDenied(); + const allowed = new TestPrimitiveLocalAllowed(); + const denied = new TestPrimitiveLocalDenied(); const aj = arcjet({ key: "test-key", - rules: [[denied, allowed]], + rules: [denied, allowed], client, }); @@ -2832,7 +2811,7 @@ describe("SDK", () => { "extra-test": "extra-test-value", }, }; - const allowed = testRuleLocalAllowed(); + const allowed = new TestPrimitiveLocalAllowed(); const aj = arcjet({ key, @@ -2875,7 +2854,7 @@ describe("SDK", () => { "extra-test": "extra-test-value", }, }; - const rule = testRuleLocalAllowed(); + const rule = new TestPrimitiveLocalAllowed(); const aj = arcjet({ key, @@ -2921,7 +2900,7 @@ describe("SDK", () => { "extra-test": "extra-test-value", }, }; - const rule = testRuleLocalDenied(); + const rule = new TestPrimitiveLocalDenied(); const aj = arcjet({ key, @@ -2965,7 +2944,7 @@ describe("SDK", () => { "extra-test": "extra-test-value", }, }; - const denied = testRuleLocalDenied(); + const denied = new TestPrimitiveLocalDenied(); const aj = arcjet({ key, @@ -3088,7 +3067,7 @@ describe("SDK", () => { expect(() => { const aj = arcjet({ key: "test-key", - rules: [[testRuleInvalidType()]], + rules: [new TestPrimitiveInvalidType()], client, }); }).not.toThrow("Unknown Rule type"); @@ -3126,7 +3105,7 @@ describe("SDK", () => { const aj = arcjet({ key, - rules: [[testRuleLocalThrow()]], + rules: [new TestPrimitiveLocalThrow()], client, }); @@ -3164,22 +3143,23 @@ describe("SDK", () => { let errorLogSpy; - function testRuleLocalThrowString(): ArcjetLocalRule { - return { - mode: ArcjetMode.LIVE, - type: "TEST_RULE_LOCAL_THROW_STRING", - priority: 1, - validate: jest.fn(), - async protect(context, req) { - errorLogSpy = jest.spyOn(context.log, "error"); - throw "Local rule protect failed"; - }, - }; + class TestPrimitiveLocalThrowString extends ArcjetPrimitive { + mode = ArcjetMode.LIVE; + type = "TEST_RULE_LOCAL_THROW_STRING"; + priority = 1; + validate = jest.fn(); + async protect(context: ArcjetContext) { + errorLogSpy = jest.spyOn(context.log, "error"); + throw "Local rule protect failed"; + } + rule() { + return this; + } } const aj = arcjet({ key, - rules: [[testRuleLocalThrowString()]], + rules: [new TestPrimitiveLocalThrowString()], client, }); @@ -3220,22 +3200,23 @@ describe("SDK", () => { let errorLogSpy; - function testRuleLocalThrowNull(): ArcjetLocalRule { - return { - mode: ArcjetMode.LIVE, - type: "TEST_RULE_LOCAL_THROW_NULL", - priority: 1, - validate: jest.fn(), - async protect(context, req) { - errorLogSpy = jest.spyOn(context.log, "error"); - throw null; - }, - }; + class TestPrimitiveLocalThrowNull extends ArcjetPrimitive { + mode = ArcjetMode.LIVE; + type = "TEST_RULE_LOCAL_THROW_NULL"; + priority = 1; + validate = jest.fn(); + async protect(context: ArcjetContext) { + errorLogSpy = jest.spyOn(context.log, "error"); + throw null; + } + rule() { + return this; + } } const aj = arcjet({ key, - rules: [[testRuleLocalThrowNull()]], + rules: [new TestPrimitiveLocalThrowNull()], client, }); @@ -3276,7 +3257,7 @@ describe("SDK", () => { const aj = arcjet({ key, - rules: [[testRuleLocalDryRun()]], + rules: [new TestPrimitiveLocalDryRun()], client, }); @@ -3325,7 +3306,7 @@ describe("SDK", () => { }, }; - const rule = testRuleRemote(); + const rule = new TestPrimitiveRemote(); const aj = arcjet({ key, diff --git a/examples/nextjs-14-openai/app/api/chat/route.ts b/examples/nextjs-14-openai/app/api/chat/route.ts index 3682f01b6..c5daa778d 100644 --- a/examples/nextjs-14-openai/app/api/chat/route.ts +++ b/examples/nextjs-14-openai/app/api/chat/route.ts @@ -1,5 +1,5 @@ // This example is adapted from https://sdk.vercel.ai/docs/guides/frameworks/nextjs-app -import arcjet, { rateLimit } from "@arcjet/next"; +import arcjet, { rateLimit, tokenBucket } from "@arcjet/next"; import { OpenAIStream, StreamingTextResponse } from "ai"; import OpenAI from "openai"; import { promptTokensEstimate } from "openai-chat-tokens"; @@ -11,12 +11,12 @@ const aj = arcjet({ // See: https://nextjs.org/docs/app/building-your-application/configuring/environment-variables key: process.env.AJ_KEY!, rules: [ - rateLimit({ + tokenBucket({ mode: "LIVE", characteristics: ["ip.src"], - window: "1m", - max: 60, - timeout: "10m", + refillRate: 1, + interval: 60, + capacity: 1, }), ], }); @@ -29,12 +29,20 @@ const openai = new OpenAI({ export const runtime = "edge"; export async function POST(req: Request) { + const { messages } = await req.json(); + + const estimate = promptTokensEstimate({ + messages, + }); + + console.log("Token estimate", estimate); + // Protect the route with Arcjet - const decision = await aj.protect(req); + const decision = await aj.protect(req, { requested: estimate }); console.log("Arcjet decision", decision.conclusion); if (decision.reason.isRateLimit()) { - console.log("Request count", decision.reason.count); + // console.log("Request count", decision.reason.count); console.log("Requests remaining", decision.reason.remaining); } @@ -52,14 +60,6 @@ export async function POST(req: Request) { } // If the request is allowed, continue to use OpenAI - const { messages } = await req.json(); - - const estimate = promptTokensEstimate({ - messages, - }); - - console.log("Token estimate", estimate); - // Ask OpenAI for a streaming chat completion given the prompt const response = await openai.chat.completions.create({ model: "gpt-3.5-turbo", @@ -71,4 +71,4 @@ export async function POST(req: Request) { const stream = OpenAIStream(response); // Respond with the stream return new StreamingTextResponse(stream); -} \ No newline at end of file +} diff --git a/protocol/convert.ts b/protocol/convert.ts index a3d47350e..0f0c5a45f 100644 --- a/protocol/convert.ts +++ b/protocol/convert.ts @@ -23,6 +23,9 @@ import { ArcjetRateLimitRule, ArcjetBotRule, ArcjetEmailRule, + ArcjetTokenBucketRateLimitRule, + ArcjetFixedWindowRateLimitRule, + ArcjetSlidingWindowRateLimitRule, } from "./index"; import { BotReason, @@ -34,6 +37,7 @@ import { EmailType, ErrorReason, Mode, + RateLimitAlgorithm, RateLimitReason, Reason, Rule, @@ -235,7 +239,7 @@ export function ArcjetReasonFromProtocol(proto?: Reason) { const reason = proto.reason.value; return new ArcjetRateLimitReason({ max: reason.max, - count: reason.count, + // count: reason.count, remaining: reason.remaining, resetTime: reason.resetTime?.toDate(), }); @@ -289,7 +293,7 @@ export function ArcjetReasonToProtocol(reason: ArcjetReason): Reason { case: "rateLimit", value: new RateLimitReason({ max: reason.max, - count: reason.count, + // count: reason.count, remaining: reason.remaining, resetTime: reason.resetTime ? Timestamp.fromDate(reason.resetTime) @@ -462,6 +466,22 @@ function isRateLimitRule( return rule.type === "RATE_LIMIT"; } +function isTokenBucketRule( + rule: ArcjetRule, +): rule is ArcjetTokenBucketRateLimitRule { + return isRateLimitRule(rule) && rule.algorithm === "TOKEN_BUCKET"; +} +function isFixedWindowRule( + rule: ArcjetRule, +): rule is ArcjetFixedWindowRateLimitRule { + return isRateLimitRule(rule) && rule.algorithm === "FIXED_WINDOW"; +} +function isSlidingWindowRule( + rule: ArcjetRule, +): rule is ArcjetSlidingWindowRateLimitRule { + return isRateLimitRule(rule) && rule.algorithm === "SLIDING_WINDOW"; +} + function isBotRule( rule: ArcjetRule, ): rule is ArcjetBotRule { @@ -477,7 +497,25 @@ function isEmailRule( export function ArcjetRuleToProtocol( rule: ArcjetRule, ): Rule { - if (isRateLimitRule(rule)) { + if (isTokenBucketRule(rule)) { + return new Rule({ + rule: { + case: "rateLimit", + value: { + mode: ArcjetModeToProtocol(rule.mode), + match: rule.match, + characteristics: rule.characteristics, + algorithm: RateLimitAlgorithm.TOKEN_BUCKET, + refillRate: rule.refillRate, + interval: rule.interval, + capacity: rule.capacity, + requested: rule.requested, + }, + }, + }); + } + + if (isFixedWindowRule(rule)) { return new Rule({ rule: { case: "rateLimit", @@ -485,9 +523,25 @@ export function ArcjetRuleToProtocol( mode: ArcjetModeToProtocol(rule.mode), match: rule.match, characteristics: rule.characteristics, + algorithm: RateLimitAlgorithm.FIXED_WINDOW, + max: rule.max, window: rule.window, + }, + }, + }); + } + + if (isSlidingWindowRule(rule)) { + return new Rule({ + rule: { + case: "rateLimit", + value: { + mode: ArcjetModeToProtocol(rule.mode), + match: rule.match, + characteristics: rule.characteristics, + algorithm: RateLimitAlgorithm.SLIDING_WINDOW, max: rule.max, - timeout: rule.timeout, + interval: rule.interval, }, }, }); diff --git a/protocol/gen/es/decide/v1alpha1/decide_pb.d.ts b/protocol/gen/es/decide/v1alpha1/decide_pb.d.ts index 2e98489c0..d2ea6be26 100644 --- a/protocol/gen/es/decide/v1alpha1/decide_pb.d.ts +++ b/protocol/gen/es/decide/v1alpha1/decide_pb.d.ts @@ -286,6 +286,31 @@ export declare enum SDKStack { SDK_STACK_DJANGO = 4, } +/** + * @generated from enum proto.decide.v1alpha1.RateLimitAlgorithm + */ +export declare enum RateLimitAlgorithm { + /** + * @generated from enum value: RATE_LIMIT_ALGORITHM_UNSPECIFIED = 0; + */ + UNSPECIFIED = 0, + + /** + * @generated from enum value: RATE_LIMIT_ALGORITHM_TOKEN_BUCKET = 1; + */ + TOKEN_BUCKET = 1, + + /** + * @generated from enum value: RATE_LIMIT_ALGORITHM_FIXED_WINDOW = 2; + */ + FIXED_WINDOW = 2, + + /** + * @generated from enum value: RATE_LIMIT_ALGORITHM_SLIDING_WINDOW = 3; + */ + SLIDING_WINDOW = 3, +} + /** * The reason for the decision. This is populated based on the selected rules * for deny or challenge responses. Additional details can be found in the @@ -377,21 +402,23 @@ export declare class RateLimitReason extends Message { /** * The configured maximum number of requests allowed in the current window. * - * @generated from field: int32 max = 1; + * @generated from field: uint32 max = 1; */ max: number; /** - * The number of requests which have been made in the current window. + * Deprecated: Always empty. Previously, the number of requests which have + * been made in the current window. * - * @generated from field: int32 count = 2; + * @generated from field: int32 count = 2 [deprecated = true]; + * @deprecated */ count: number; /** * The number of requests remaining in the current window. * - * @generated from field: int32 remaining = 3; + * @generated from field: uint32 remaining = 3; */ remaining: number; @@ -646,10 +673,12 @@ export declare class RateLimitRule extends Message { window: string; /** - * The maximum number of requests allowed in the time period. This is an - * integer value e.g. 100. + * The maximum number of requests allowed in the time period. This is a + * positive integer value e.g. 100. * - * @generated from field: int32 max = 5; + * Required by "fixed window", "sliding window", and unspecified algorithms. + * + * @generated from field: uint32 max = 5; */ max: number; @@ -666,6 +695,51 @@ export declare class RateLimitRule extends Message { */ timeout: string; + /** + * The algorithm to use for rate limiting a request. If unspecified, we will + * fallback to the "fixed window" algorithm. The chosen algorithm will + * affect which other fields must be specified to be a valid configuration. + * + * @generated from field: proto.decide.v1alpha1.RateLimitAlgorithm algorithm = 7; + */ + algorithm: RateLimitAlgorithm; + + /** + * The amount of tokens that are refilled at the provided interval. + * + * Required by "token bucket" algorithm. + * + * @generated from field: uint32 refill_rate = 8; + */ + refillRate: number; + + /** + * The interval in which a rate limit is applied or tokens refilled. + * + * Required by "token bucket" and "sliding window" algorithms. + * + * @generated from field: uint32 interval = 9; + */ + interval: number; + + /** + * The maximum number of tokens that can exist in a token bucket. + * + * Required by "token bucket" algorithm. + * + * @generated from field: uint32 capacity = 10; + */ + capacity: number; + + /** + * The number of tokens to attempt to consume from a token bucket. + * + * Required by "token bucket" algorithm. + * + * @generated from field: uint32 requested = 11; + */ + requested: number; + constructor(data?: PartialMessage); static readonly runtime: typeof proto3; @@ -947,6 +1021,21 @@ export declare class RequestDetails extends Message { */ email: string; + /** + * The string representing semicolon-separated Cookies for a request. + * + * @generated from field: string cookies = 10; + */ + cookies: string; + + /** + * The `?`-prefixed string representing the Query for a request. Commonly + * referred to as a "querystring". + * + * @generated from field: string query = 11; + */ + query: string; + constructor(data?: PartialMessage); static readonly runtime: typeof proto3; diff --git a/protocol/gen/es/decide/v1alpha1/decide_pb.js b/protocol/gen/es/decide/v1alpha1/decide_pb.js index 634f4a527..7e4d539cc 100644 --- a/protocol/gen/es/decide/v1alpha1/decide_pb.js +++ b/protocol/gen/es/decide/v1alpha1/decide_pb.js @@ -105,6 +105,19 @@ export const SDKStack = proto3.makeEnum( ], ); +/** + * @generated from enum proto.decide.v1alpha1.RateLimitAlgorithm + */ +export const RateLimitAlgorithm = proto3.makeEnum( + "proto.decide.v1alpha1.RateLimitAlgorithm", + [ + {no: 0, name: "RATE_LIMIT_ALGORITHM_UNSPECIFIED", localName: "UNSPECIFIED"}, + {no: 1, name: "RATE_LIMIT_ALGORITHM_TOKEN_BUCKET", localName: "TOKEN_BUCKET"}, + {no: 2, name: "RATE_LIMIT_ALGORITHM_FIXED_WINDOW", localName: "FIXED_WINDOW"}, + {no: 3, name: "RATE_LIMIT_ALGORITHM_SLIDING_WINDOW", localName: "SLIDING_WINDOW"}, + ], +); + /** * The reason for the decision. This is populated based on the selected rules * for deny or challenge responses. Additional details can be found in the @@ -133,9 +146,9 @@ export const Reason = proto3.makeMessageType( export const RateLimitReason = proto3.makeMessageType( "proto.decide.v1alpha1.RateLimitReason", () => [ - { no: 1, name: "max", kind: "scalar", T: 5 /* ScalarType.INT32 */ }, + { no: 1, name: "max", kind: "scalar", T: 13 /* ScalarType.UINT32 */ }, { no: 2, name: "count", kind: "scalar", T: 5 /* ScalarType.INT32 */ }, - { no: 3, name: "remaining", kind: "scalar", T: 5 /* ScalarType.INT32 */ }, + { no: 3, name: "remaining", kind: "scalar", T: 13 /* ScalarType.UINT32 */ }, { no: 4, name: "reset_time", kind: "message", T: Timestamp }, ], ); @@ -217,8 +230,13 @@ export const RateLimitRule = proto3.makeMessageType( { no: 2, name: "match", kind: "scalar", T: 9 /* ScalarType.STRING */ }, { no: 3, name: "characteristics", kind: "scalar", T: 9 /* ScalarType.STRING */, repeated: true }, { no: 4, name: "window", kind: "scalar", T: 9 /* ScalarType.STRING */ }, - { no: 5, name: "max", kind: "scalar", T: 5 /* ScalarType.INT32 */ }, + { no: 5, name: "max", kind: "scalar", T: 13 /* ScalarType.UINT32 */ }, { no: 6, name: "timeout", kind: "scalar", T: 9 /* ScalarType.STRING */ }, + { no: 7, name: "algorithm", kind: "enum", T: proto3.getEnumType(RateLimitAlgorithm) }, + { no: 8, name: "refill_rate", kind: "scalar", T: 13 /* ScalarType.UINT32 */ }, + { no: 9, name: "interval", kind: "scalar", T: 13 /* ScalarType.UINT32 */ }, + { no: 10, name: "capacity", kind: "scalar", T: 13 /* ScalarType.UINT32 */ }, + { no: 11, name: "requested", kind: "scalar", T: 13 /* ScalarType.UINT32 */ }, ], ); @@ -308,6 +326,8 @@ export const RequestDetails = proto3.makeMessageType( { no: 7, name: "body", kind: "scalar", T: 12 /* ScalarType.BYTES */ }, { no: 8, name: "extra", kind: "map", K: 9 /* ScalarType.STRING */, V: {kind: "scalar", T: 9 /* ScalarType.STRING */} }, { no: 9, name: "email", kind: "scalar", T: 9 /* ScalarType.STRING */ }, + { no: 10, name: "cookies", kind: "scalar", T: 9 /* ScalarType.STRING */ }, + { no: 11, name: "query", kind: "scalar", T: 9 /* ScalarType.STRING */ }, ], ); diff --git a/protocol/index.ts b/protocol/index.ts index dcb12ad74..8de03a961 100644 --- a/protocol/index.ts +++ b/protocol/index.ts @@ -10,6 +10,17 @@ export const ArcjetMode: ArcjetEnum = Object.freeze({ DRY_RUN: "DRY_RUN", }); +export type ArcjetRateLimitAlgorithm = + | "TOKEN_BUCKET" + | "FIXED_WINDOW" + | "SLIDING_WINDOW"; +export const ArcjetRateLimitAlgorithm: ArcjetEnum = + Object.freeze({ + TOKEN_BUCKET: "TOKEN_BUCKET", + FIXED_WINDOW: "FIXED_WINDOW", + SLIDING_WINDOW: "SLIDING_WINDOW", + }); + export type ArcjetBotType = | "NOT_ANALYZED" | "AUTOMATED" @@ -98,20 +109,20 @@ export class ArcjetRateLimitReason extends ArcjetReason { type: "RATE_LIMIT" = "RATE_LIMIT"; max: number; - count: number; + // count: number; remaining: number; resetTime?: Date; constructor(init: { max: number; - count: number; + // count: number; remaining: number; resetTime?: Date; }) { super(); this.max = init.max; - this.count = init.count; + // this.count = init.count; this.remaining = init.remaining; this.resetTime = init.resetTime; } @@ -378,12 +389,25 @@ export interface ArcjetRequestDetails { extra: Record; } +export abstract class ArcjetPrimitive { + abstract priority: number; + + abstract rule( + context: ArcjetContext, + details: Partial, + ): ArcjetRule; +} + +export type ArcjetProduct = ArcjetPrimitive[]; + export type ArcjetRule = { type: "RATE_LIMIT" | "BOT" | "EMAIL" | string; mode: ArcjetMode; - priority: number; }; +// An ArcjetLocalRule provides additional `validate` and `protect` functions +// which are used to provide local protections before requesting protections +// from the Arcjet service. export interface ArcjetLocalRule extends ArcjetRule { validate( @@ -399,12 +423,39 @@ export interface ArcjetLocalRule export interface ArcjetRateLimitRule extends ArcjetRule { type: "RATE_LIMIT"; + algorithm: ArcjetRateLimitAlgorithm; +} + +export interface ArcjetTokenBucketRateLimitRule + extends ArcjetRateLimitRule { + algorithm: "TOKEN_BUCKET"; + + match?: string; + characteristics?: string[]; + refillRate: number; + interval: number; + capacity: number; + requested: number; +} + +export interface ArcjetFixedWindowRateLimitRule + extends ArcjetRateLimitRule { + algorithm: "FIXED_WINDOW"; match?: string; characteristics?: string[]; + max: number; window: string; +} + +export interface ArcjetSlidingWindowRateLimitRule + extends ArcjetRateLimitRule { + algorithm: "SLIDING_WINDOW"; + + match?: string; + characteristics?: string[]; max: number; - timeout: string; + interval: number; } export interface ArcjetEmailRule diff --git a/protocol/test/convert.test.ts b/protocol/test/convert.test.ts index 62e53b6fd..d4e4ad66c 100644 --- a/protocol/test/convert.test.ts +++ b/protocol/test/convert.test.ts @@ -354,7 +354,6 @@ describe("convert", () => { ArcjetReasonToProtocol( new ArcjetRateLimitReason({ max: 1, - count: 2, remaining: -1, }), ), @@ -364,7 +363,6 @@ describe("convert", () => { case: "rateLimit", value: { max: 1, - count: 2, remaining: -1, }, }, @@ -375,7 +373,6 @@ describe("convert", () => { ArcjetReasonToProtocol( new ArcjetRateLimitReason({ max: 1, - count: 2, remaining: -1, resetTime, }), @@ -386,7 +383,6 @@ describe("convert", () => { case: "rateLimit", value: { max: 1, - count: 2, remaining: -1, resetTime: Timestamp.fromDate(resetTime), }, @@ -585,30 +581,28 @@ describe("convert", () => { ArcjetRuleToProtocol({ type: "UNKNOWN", mode: "DRY_RUN", - priority: 1, }), ).toEqual(new Rule({})); - expect( - ArcjetRuleToProtocol({ - type: "RATE_LIMIT", - mode: "DRY_RUN", - priority: 1, - }), - ).toEqual( - new Rule({ - rule: { - case: "rateLimit", - value: { - mode: Mode.DRY_RUN, - }, - }, - }), - ); + // TODO: Figure out how to make TypeScript allow specifying the algorithm + // expect( + // ArcjetRuleToProtocol({ + // type: "RATE_LIMIT", + // mode: "DRY_RUN", + // }), + // ).toEqual( + // new Rule({ + // rule: { + // case: "rateLimit", + // value: { + // mode: Mode.DRY_RUN, + // }, + // }, + // }), + // ); expect( ArcjetRuleToProtocol({ type: "EMAIL", mode: "DRY_RUN", - priority: 1, }), ).toEqual( new Rule({ @@ -624,7 +618,6 @@ describe("convert", () => { ArcjetRuleToProtocol(>{ type: "EMAIL", mode: "DRY_RUN", - priority: 1, block: ["INVALID"], }), ).toEqual( @@ -642,7 +635,6 @@ describe("convert", () => { ArcjetRuleToProtocol({ type: "BOT", mode: "DRY_RUN", - priority: 1, }), ).toEqual( new Rule({ @@ -662,7 +654,6 @@ describe("convert", () => { ArcjetRuleToProtocol(>{ type: "BOT", mode: "DRY_RUN", - priority: 1, block: ["AUTOMATED"], add: [["chrome", "LIKELY_NOT_A_BOT"]], }), From d2b55ab0adf206a057c09589b78b2123a72d84a5 Mon Sep 17 00:00:00 2001 From: Blaine Bublitz Date: Fri, 2 Feb 2024 15:43:15 -0700 Subject: [PATCH 2/3] fixup tests and next bindings --- arcjet-next/index.ts | 14 +++++++++++--- examples/nextjs-13-pages-wrap/pages/api/arcjet.ts | 1 - .../nextjs-14-app-dir-rl/app/api/arcjet/route.ts | 1 - 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/arcjet-next/index.ts b/arcjet-next/index.ts index d86514755..287640bae 100644 --- a/arcjet-next/index.ts +++ b/arcjet-next/index.ts @@ -11,8 +11,6 @@ import type { NextMiddlewareResult } from "next/dist/server/web/types.js"; import arcjet, { ArcjetDecision, ArcjetOptions, - Primitive, - Product, ArcjetHeaders, Runtime, ArcjetRequest, @@ -21,6 +19,8 @@ import arcjet, { RemoteClientOptions, defaultBaseUrl, createRemoteClient, + ArcjetPrimitive, + ArcjetProduct, } from "arcjet"; import findIP from "@arcjet/ip"; @@ -64,6 +64,14 @@ type PlainObject = { [key: string]: unknown; }; +// Primitives and Products can be specified in a variety of ways and are +// externally grouped as `rules` +// See ExtraRules below for further explanation on why we define them like this. +type PrimitivesOrProduct = + | ArcjetPrimitive + | ArcjetPrimitive[] + | ArcjetProduct; + /** * Ensures redirects are followed to properly support the Next.js/Vercel Edge * Runtime. @@ -157,7 +165,7 @@ export interface ArcjetNext { * These can be overriden on a per-request basis by providing them to the * `protect()` or `protectApi` methods. */ -export default function arcjetNext( +export default function arcjetNext( options: ArcjetOptions, ): ArcjetNext>> { const client = options.client ?? createNextRemoteClient(); diff --git a/examples/nextjs-13-pages-wrap/pages/api/arcjet.ts b/examples/nextjs-13-pages-wrap/pages/api/arcjet.ts index c1e9f5b53..76873e31b 100644 --- a/examples/nextjs-13-pages-wrap/pages/api/arcjet.ts +++ b/examples/nextjs-13-pages-wrap/pages/api/arcjet.ts @@ -14,7 +14,6 @@ const aj = arcjet({ //characteristics: ["ip.src"], window: "1m", max: 1, - timeout: "10m", }), ], }); diff --git a/examples/nextjs-14-app-dir-rl/app/api/arcjet/route.ts b/examples/nextjs-14-app-dir-rl/app/api/arcjet/route.ts index 9e7add54f..b625c301f 100644 --- a/examples/nextjs-14-app-dir-rl/app/api/arcjet/route.ts +++ b/examples/nextjs-14-app-dir-rl/app/api/arcjet/route.ts @@ -12,7 +12,6 @@ const aj = arcjet({ characteristics: ["ip.src"], window: "1h", max: 1, - timeout: "10m", }), ], }); From 0719f06d87dde35dddfa773f4c8d7e12ea2114d3 Mon Sep 17 00:00:00 2001 From: Blaine Bublitz Date: Fri, 2 Feb 2024 15:49:58 -0700 Subject: [PATCH 3/3] more timeout references --- examples/nextjs-13-pages-wrap/pages/api/arcjet-edge.ts | 1 - examples/nextjs-14-pages-wrap/pages/api/arcjet-edge.ts | 1 - examples/nextjs-14-pages-wrap/pages/api/arcjet.ts | 1 - 3 files changed, 3 deletions(-) diff --git a/examples/nextjs-13-pages-wrap/pages/api/arcjet-edge.ts b/examples/nextjs-13-pages-wrap/pages/api/arcjet-edge.ts index 332002cc5..bc2019b17 100644 --- a/examples/nextjs-13-pages-wrap/pages/api/arcjet-edge.ts +++ b/examples/nextjs-13-pages-wrap/pages/api/arcjet-edge.ts @@ -18,7 +18,6 @@ const aj = arcjet({ //characteristics: ["ip.src"], window: "1m", max: 1, - timeout: "10m", }), ], }); diff --git a/examples/nextjs-14-pages-wrap/pages/api/arcjet-edge.ts b/examples/nextjs-14-pages-wrap/pages/api/arcjet-edge.ts index 332002cc5..bc2019b17 100644 --- a/examples/nextjs-14-pages-wrap/pages/api/arcjet-edge.ts +++ b/examples/nextjs-14-pages-wrap/pages/api/arcjet-edge.ts @@ -18,7 +18,6 @@ const aj = arcjet({ //characteristics: ["ip.src"], window: "1m", max: 1, - timeout: "10m", }), ], }); diff --git a/examples/nextjs-14-pages-wrap/pages/api/arcjet.ts b/examples/nextjs-14-pages-wrap/pages/api/arcjet.ts index c1e9f5b53..76873e31b 100644 --- a/examples/nextjs-14-pages-wrap/pages/api/arcjet.ts +++ b/examples/nextjs-14-pages-wrap/pages/api/arcjet.ts @@ -14,7 +14,6 @@ const aj = arcjet({ //characteristics: ["ip.src"], window: "1m", max: 1, - timeout: "10m", }), ], });