From ab47fc77fd817d454886f8e8b2258f9085c1554e Mon Sep 17 00:00:00 2001 From: Blaine Bublitz Date: Wed, 21 Feb 2024 12:26:02 -0700 Subject: [PATCH 1/8] feat: Add `withRule` API to adding adhoc rules --- arcjet/index.ts | 472 ++++++++++++++++++--------------- arcjet/test/index.edge.test.ts | 6 +- 2 files changed, 262 insertions(+), 216 deletions(-) diff --git a/arcjet/index.ts b/arcjet/index.ts index 36e0b8f55..3f1ebe18f 100644 --- a/arcjet/index.ts +++ b/arcjet/index.ts @@ -1025,6 +1025,10 @@ export interface Arcjet { * @returns An {@link ArcjetDecision} indicating Arcjet's decision about the request. */ protect(request: ArcjetRequest): Promise; + + withRule( + rule: Rule, + ): Arcjet>>; } /** @@ -1045,261 +1049,301 @@ export default function arcjet< if (typeof client === "undefined") { throw new Error("Client is required"); } + // This is reassigned to help TypeScript's type inference, as it loses the + // type narrowing of the above `if` statement when using from inside `protect` + const remoteClient = client; // A local cache of block decisions. Might be emphemeral per request, // depending on the way the runtime works, but it's worth a try. // TODO(#132): Support configurable caching const blockCache = new Cache(); - const flatSortedRules = rules.flat(1).sort((a, b) => a.priority - b.priority); + const rootRules: ArcjetRule[] = rules + .flat(1) + .sort((a, b) => a.priority - b.priority); - return Object.freeze({ - get runtime() { - return runtime(); - }, - async protect( - request: ArcjetRequest>, - ): Promise { - // This goes against the type definition above, but users might call - // `protect()` with no value and we don't want to crash - if (typeof request === "undefined") { - request = {} as typeof request; - } + async function protect(rules: ArcjetRule[], request: ArcjetRequest) { + // This goes against the type definition above, but users might call + // `protect()` with no value and we don't want to crash + if (typeof request === "undefined") { + request = {} as typeof request; + } + + const details: Partial = Object.freeze({ + ip: request.ip, + method: request.method, + protocol: request.protocol, + host: request.host, + path: request.path, + headers: new ArcjetHeaders(request.headers), + cookies: request.cookies, + query: request.query, + // TODO(#208): Re-add body + // body: request.body, + extra: extraProps(request), + email: typeof request.email === "string" ? request.email : undefined, + }); + + log.time("local"); + + log.time("fingerprint"); + let ip = ""; + if (typeof details.ip === "string") { + ip = details.ip; + } + if (details.ip === "") { + log.warn("generateFingerprint: ip is empty"); + } + const fingerprint = await analyze.generateFingerprint(ip); + log.debug("fingerprint (%s): %s", runtime(), fingerprint); + log.timeEnd("fingerprint"); - const details: Partial = Object.freeze({ - ip: request.ip, - method: request.method, - protocol: request.protocol, - host: request.host, - path: request.path, - headers: new ArcjetHeaders(request.headers), - cookies: request.cookies, - query: request.query, - // TODO(#208): Re-add body - // body: request.body, - extra: extraProps(request), - email: typeof request.email === "string" ? request.email : undefined, + const context: ArcjetContext = { key, fingerprint, log }; + + if (rules.length > 10) { + log.error("Failure running rules. Only 10 rules may be specified."); + + const decision = new ArcjetErrorDecision({ + ttl: 0, + reason: new ArcjetErrorReason("Only 10 rules may be specified"), + // No results because the sorted rules were too long and we don't want + // to instantiate a ton of NOT_RUN results + results: [], }); - log.time("local"); + remoteClient.report( + context, + details, + decision, + // No rules because we've determined they were too long and we don't + // want to try to send them to the server + [], + ); - log.time("fingerprint"); - let ip = ""; - if (typeof details.ip === "string") { - ip = details.ip; - } - if (details.ip === "") { - log.warn("generateFingerprint: ip is empty"); - } - const fingerprint = await analyze.generateFingerprint(ip); - log.debug("fingerprint (%s): %s", runtime(), fingerprint); - log.timeEnd("fingerprint"); + return decision; + } - const context: ArcjetContext = { key, fingerprint, log }; + const results: ArcjetRuleResult[] = []; + // Default all rules to NOT_RUN/ALLOW before doing anything + for (let idx = 0; idx < rules.length; idx++) { + results[idx] = new ArcjetRuleResult({ + ttl: 0, + state: "NOT_RUN", + conclusion: "ALLOW", + reason: new ArcjetReason(), + }); + } - if (flatSortedRules.length > 10) { - log.error("Failure running rules. Only 10 rules may be specified."); + // We have our own local cache which we check first. This doesn't work in + // serverless environments where every request is isolated, but there may be + // some instances where the instance is not recycled immediately. If so, we + // can take advantage of that. + log.time("cache"); + const existingBlockReason = blockCache.get(fingerprint); + log.timeEnd("cache"); + + // If already blocked then we can async log to the API and return the + // decision immediately. + if (existingBlockReason) { + log.timeEnd("local"); + log.debug("decide: alreadyBlocked", { + fingerprint, + existingBlockReason, + }); + const decision = new ArcjetDenyDecision({ + ttl: blockCache.ttl(fingerprint), + reason: existingBlockReason, + // All results will be NOT_RUN because we used a cached decision + results, + }); - const decision = new ArcjetErrorDecision({ - ttl: 0, - reason: new ArcjetErrorReason("Only 10 rules may be specified"), - // No results because the sorted rules were too long and we don't want - // to instantiate a ton of NOT_RUN results - results: [], - }); + remoteClient.report(context, details, decision, rules); - client.report( - context, - details, - decision, - // No rules because we've determined they were too long and we don't - // want to try to send them to the server - [], - ); + log.debug("decide: already blocked", { + id: decision.id, + conclusion: decision.conclusion, + fingerprint, + reason: existingBlockReason, + runtime: runtime(), + }); - return decision; - } + return decision; + } - const results: ArcjetRuleResult[] = []; - // Default all rules to NOT_RUN/ALLOW before doing anything - for (let idx = 0; idx < flatSortedRules.length; idx++) { - results[idx] = new ArcjetRuleResult({ - ttl: 0, - state: "NOT_RUN", - conclusion: "ALLOW", - reason: new ArcjetReason(), - }); + for (const [idx, rule] of rules.entries()) { + // This re-assignment is a workaround to a TypeScript error with + // assertions where the name was introduced via a destructure + let localRule: ArcjetLocalRule; + if (isLocalRule(rule)) { + localRule = rule; + } else { + continue; } - // We have our own local cache which we check first. This doesn't work in - // serverless environments where every request is isolated, but there may be - // some instances where the instance is not recycled immediately. If so, we - // can take advantage of that. - log.time("cache"); - const existingBlockReason = blockCache.get(fingerprint); - log.timeEnd("cache"); - - // If already blocked then we can async log to the API and return the - // decision immediately. - if (existingBlockReason) { - log.timeEnd("local"); - log.debug("decide: alreadyBlocked", { - fingerprint, - existingBlockReason, - }); - const decision = new ArcjetDenyDecision({ - ttl: blockCache.ttl(fingerprint), - reason: existingBlockReason, - // All results will be NOT_RUN because we used a cached decision - results, - }); + log.time(rule.type); - client.report(context, details, decision, flatSortedRules); + try { + localRule.validate(context, details); + results[idx] = await localRule.protect(context, details); - log.debug("decide: already blocked", { - id: decision.id, - conclusion: decision.conclusion, + log.debug("Local rule result:", { + id: results[idx].ruleId, + rule: rule.type, fingerprint, - reason: existingBlockReason, + path: details.path, runtime: runtime(), + ttl: results[idx].ttl, + conclusion: results[idx].conclusion, + reason: results[idx].reason, }); + } catch (err) { + log.error( + "Failure running rule: %s due to %s", + rule.type, + errorMessage(err), + ); - return decision; + results[idx] = new ArcjetRuleResult({ + ttl: 0, + state: "RUN", + conclusion: "ERROR", + reason: new ArcjetErrorReason(err), + }); } - for (const [idx, rule] of flatSortedRules.entries()) { - // This re-assignment is a workaround to a TypeScript error with - // assertions where the name was introduced via a destructure - let localRule: ArcjetLocalRule; - if (isLocalRule(rule)) { - localRule = rule; - } else { - continue; - } + log.timeEnd(rule.type); - log.time(rule.type); + if (results[idx].isDenied()) { + log.timeEnd("local"); - try { - localRule.validate(context, details); - results[idx] = await localRule.protect(context, details); + const decision = new ArcjetDenyDecision({ + ttl: results[idx].ttl, + reason: results[idx].reason, + results, + }); - log.debug("Local rule result:", { - id: results[idx].ruleId, - rule: rule.type, - fingerprint, - path: details.path, - runtime: runtime(), - ttl: results[idx].ttl, - conclusion: results[idx].conclusion, - reason: results[idx].reason, - }); - } catch (err) { - log.error( - "Failure running rule: %s due to %s", - rule.type, - errorMessage(err), - ); + // Only a DENY decision is reported to avoid creating 2 entries for a + // request. Upon ALLOW, the `decide` call will create an entry for the + // request. + remoteClient.report(context, details, decision, rules); + + // If we're not in DRY_RUN mode, we want to cache non-zero TTL results + // and return this DENY decision. + if (rule.mode !== "DRY_RUN") { + if (results[idx].ttl > 0) { + log.debug("Caching decision for %d seconds", decision.ttl, { + fingerprint, + conclusion: decision.conclusion, + reason: decision.reason, + }); + + blockCache.set( + fingerprint, + decision.reason, + nowInSeconds() + decision.ttl, + ); + } - results[idx] = new ArcjetRuleResult({ - ttl: 0, - state: "RUN", - conclusion: "ERROR", - reason: new ArcjetErrorReason(err), - }); + return decision; } - log.timeEnd(rule.type); - - if (results[idx].isDenied()) { - log.timeEnd("local"); - - const decision = new ArcjetDenyDecision({ - ttl: results[idx].ttl, - reason: results[idx].reason, - results, - }); - - // Only a DENY decision is reported to avoid creating 2 entries for a - // request. Upon ALLOW, the `decide` call will create an entry for the - // request. - client.report(context, details, decision, flatSortedRules); - - // If we're not in DRY_RUN mode, we want to cache non-zero TTL results - // and return this DENY decision. - if (rule.mode !== "DRY_RUN") { - if (results[idx].ttl > 0) { - log.debug("Caching decision for %d seconds", decision.ttl, { - fingerprint, - conclusion: decision.conclusion, - reason: decision.reason, - }); - - blockCache.set( - fingerprint, - decision.reason, - nowInSeconds() + decision.ttl, - ); - } - - return decision; - } + log.warn( + `Dry run mode is enabled for "%s" rule. Overriding decision. Decision was: %s`, + rule.type, + decision.conclusion, + ); + } + } - log.warn( - `Dry run mode is enabled for "%s" rule. Overriding decision. Decision was: %s`, - rule.type, - decision.conclusion, - ); - } + log.timeEnd("local"); + log.time("remote"); + + // With no cached values, we take a decision remotely. We use a timeout to + // fail open. + try { + log.time("decideApi"); + const decision = await remoteClient.decide( + context, + details, + rules, + ); + log.timeEnd("decideApi"); + + // If the decision is to block and we have a non-zero TTL, we cache the + // block locally + if (decision.isDenied() && decision.ttl > 0) { + log.debug("decide: Caching block locally for %d seconds", decision.ttl); + + blockCache.set( + fingerprint, + decision.reason, + nowInSeconds() + decision.ttl, + ); } - log.timeEnd("local"); - log.time("remote"); + return decision; + } catch (err) { + log.error( + "Encountered problem getting remote decision: %s", + errorMessage(err), + ); + const decision = new ArcjetErrorDecision({ + ttl: 0, + reason: new ArcjetErrorReason(err), + results, + }); - // With no cached values, we take a decision remotely. We use a timeout to - // fail open. - try { - log.time("decideApi"); - const decision = await client.decide(context, details, flatSortedRules); - log.timeEnd("decideApi"); - - // If the decision is to block and we have a non-zero TTL, we cache the - // block locally - if (decision.isDenied() && decision.ttl > 0) { - log.debug( - "decide: Caching block locally for %d seconds", - decision.ttl, - ); + remoteClient.report( + { key, fingerprint, log }, + details, + decision, + rules, + ); - blockCache.set( - fingerprint, - decision.reason, - nowInSeconds() + decision.ttl, - ); - } + return decision; + } finally { + log.timeEnd("remote"); + } + } - return decision; - } catch (err) { - log.error( - "Encountered problem getting remote decision: %s", - errorMessage(err), - ); - const decision = new ArcjetErrorDecision({ - ttl: 0, - reason: new ArcjetErrorReason(err), - results, - }); + // This is a separate function so it can be called recursively + function withRule(rule: Rule) { + // TODO(#207): Remove this when we can default the transport so client is not required + // It is currently optional in the options so the Next SDK can override it for the user + if (typeof client === "undefined") { + throw new Error("Client is required"); + } - client.report( - { key, fingerprint, log }, - details, - decision, - flatSortedRules, - ); + const rules = [...rootRules, rule] + .flat(1) + .sort((a, b) => a.priority - b.priority); - return decision; - } finally { - log.timeEnd("remote"); - } + return Object.freeze({ + get runtime() { + return runtime(); + }, + withRule(rule: Primitive | Product) { + return withRule(rule); + }, + async protect( + request: ArcjetRequest>, + ): Promise { + return protect(rules, request); + }, + }); + } + + return Object.freeze({ + get runtime() { + return runtime(); + }, + withRule(rule: Primitive | Product) { + return withRule(rule); + }, + async protect( + request: ArcjetRequest>, + ): Promise { + return protect(rootRules, request); }, }); } diff --git a/arcjet/test/index.edge.test.ts b/arcjet/test/index.edge.test.ts index 46df15fa4..25fb19f23 100644 --- a/arcjet/test/index.edge.test.ts +++ b/arcjet/test/index.edge.test.ts @@ -35,7 +35,7 @@ describe("Arcjet: Env = Edge runtime", () => { key: "1234", rules: [ // Test rules - foobarbaz(), + // foobarbaz(), tokenBucket( { characteristics: [ @@ -67,7 +67,9 @@ describe("Arcjet: Env = Edge runtime", () => { client, }); - const decision = await aj.protect({ + const aj2 = aj.withRule(foobarbaz()); + + const decision = await aj2.protect({ abc: 123, requested: 1, email: "", From 17f54c1533593e1d6d8f78db75efa149cc0c5595 Mon Sep 17 00:00:00 2001 From: Blaine Bublitz Date: Wed, 21 Feb 2024 13:44:58 -0700 Subject: [PATCH 2/8] cleanup --- arcjet/index.ts | 24 ++++++------------------ 1 file changed, 6 insertions(+), 18 deletions(-) diff --git a/arcjet/index.ts b/arcjet/index.ts index 3f1ebe18f..458fa355c 100644 --- a/arcjet/index.ts +++ b/arcjet/index.ts @@ -1062,7 +1062,10 @@ export default function arcjet< .flat(1) .sort((a, b) => a.priority - b.priority); - async function protect(rules: ArcjetRule[], request: ArcjetRequest) { + async function protect( + rules: ArcjetRule[], + request: ArcjetRequest, + ) { // This goes against the type definition above, but users might call // `protect()` with no value and we don't want to crash if (typeof request === "undefined") { @@ -1262,11 +1265,7 @@ export default function arcjet< // fail open. try { log.time("decideApi"); - const decision = await remoteClient.decide( - context, - details, - rules, - ); + const decision = await remoteClient.decide(context, details, rules); log.timeEnd("decideApi"); // If the decision is to block and we have a non-zero TTL, we cache the @@ -1293,12 +1292,7 @@ export default function arcjet< results, }); - remoteClient.report( - { key, fingerprint, log }, - details, - decision, - rules, - ); + remoteClient.report({ key, fingerprint, log }, details, decision, rules); return decision; } finally { @@ -1308,12 +1302,6 @@ export default function arcjet< // This is a separate function so it can be called recursively function withRule(rule: Rule) { - // TODO(#207): Remove this when we can default the transport so client is not required - // It is currently optional in the options so the Next SDK can override it for the user - if (typeof client === "undefined") { - throw new Error("Client is required"); - } - const rules = [...rootRules, rule] .flat(1) .sort((a, b) => a.priority - b.priority); From 216d9c47fa212bbdd23960b69241395f51a28a3a Mon Sep 17 00:00:00 2001 From: Blaine Bublitz Date: Thu, 22 Feb 2024 10:24:38 -0700 Subject: [PATCH 3/8] withRule in next sdk and docs --- arcjet-next/index.ts | 229 ++++++++++++++++++++++++------------------- arcjet/index.ts | 7 ++ 2 files changed, 135 insertions(+), 101 deletions(-) diff --git a/arcjet-next/index.ts b/arcjet-next/index.ts index 48d5b5025..0fcf9df8f 100644 --- a/arcjet-next/index.ts +++ b/arcjet-next/index.ts @@ -21,6 +21,7 @@ import arcjet, { RemoteClientOptions, defaultBaseUrl, createRemoteClient, + Arcjet, } from "arcjet"; import findIP from "@arcjet/ip"; @@ -160,6 +161,10 @@ function cookiesToString(cookies?: ArcjetNextRequest["cookies"]): string { .join("; "); } +/** + * The ArcjetNext client provides a public `protect()` method to + * make a decision about how a Next.js request should be handled. + */ export interface ArcjetNext { get runtime(): Runtime; /** @@ -178,124 +183,146 @@ export interface ArcjetNext { // that is required if the ExtraProps aren't strictly an empty object ...props: Props extends WithoutCustomProps ? [] : [Props] ): Promise; -} -/** - * This is the main class for Arcjet when using Next.js. It provides several - * methods for protecting Next.js routes depending on whether they are using the - * Edge or Serverless Functions runtime. - */ -/** - * Create a new Arcjet Next client. If possible, call this outside of the - * request context so it persists across requests. - * - * @param key - The key to identify the site in Arcjet. - * @param options - Arcjet configuration options to apply to all requests. - * These can be overriden on a per-request basis by providing them to the - * `protect()` or `protectApi` methods. - */ -export default function arcjetNext( - options: ArcjetOptions, -): ArcjetNext>> { - const client = options.client ?? createNextRemoteClient(); + /** + * Augments the client with another rule. Useful for varying rules based on + * criteria in your handler—e.g. different rate limit for logged in users. + * + * @param rule The rule to add to this execution. + * @returns An augmented {@link ArcjetNext} client. + */ + withRule( + rule: Rule, + ): ArcjetNext>>; +} - const aj = arcjet({ ...options, client }); +function toArcjetRequest( + request: ArcjetNextRequest, + props: Props, +): ArcjetRequest { + // We construct an ArcjetHeaders to normalize over Headers + const headers = new ArcjetHeaders(request.headers); + + const ip = findIP(request, headers); + const method = request.method ?? ""; + const host = headers.get("host") ?? ""; + let path = ""; + let query = ""; + let protocol = ""; + // TODO(#36): nextUrl has formatting logic when you `toString` but + // we don't account for that here + if (typeof request.nextUrl !== "undefined") { + path = request.nextUrl.pathname ?? ""; + if (typeof request.nextUrl.search !== "undefined") { + query = request.nextUrl.search; + } + if (typeof request.nextUrl.protocol !== "undefined") { + protocol = request.nextUrl.protocol; + } + } else { + if (typeof request.socket?.encrypted !== "undefined") { + protocol = request.socket.encrypted ? "https:" : "http:"; + } else { + protocol = "http:"; + } + // Do some very simple validation, but also try/catch around URL parsing + if ( + typeof request.url !== "undefined" && + request.url !== "" && + host !== "" + ) { + try { + const url = new URL(request.url, `${protocol}//${host}`); + path = url.pathname; + query = url.search; + protocol = url.protocol; + } catch { + // If the parsing above fails, just set the path as whatever url we + // received. + // TODO(#216): Add logging to arcjet-next + path = request.url ?? ""; + } + } else { + path = request.url ?? ""; + } + } + const cookies = cookiesToString(request.cookies); + + const extra: { [key: string]: string } = {}; + + // If we're running on Vercel, we can add some extra information + if (process.env["VERCEL"]) { + // Vercel ID https://vercel.com/docs/concepts/edge-network/headers + extra["vercel-id"] = headers.get("x-vercel-id") ?? ""; + // Vercel deployment URL + // https://vercel.com/docs/concepts/edge-network/headers + extra["vercel-deployment-url"] = + headers.get("x-vercel-deployment-url") ?? ""; + // Vercel git commit SHA + // https://vercel.com/docs/concepts/projects/environment-variables/system-environment-variables + extra["vercel-git-commit-sha"] = process.env["VERCEL_GIT_COMMIT_SHA"] ?? ""; + extra["vercel-git-commit-sha"] = process.env["VERCEL_GIT_COMMIT_SHA"] ?? ""; + } + return { + ...props, + ...extra, + ip, + method, + protocol, + host, + path, + headers, + cookies, + query, + }; +} +function withClient( + aj: Arcjet>, +): ArcjetNext> { return Object.freeze({ get runtime() { return aj.runtime; }, + withRule(rule: Primitive | Product) { + const client = aj.withRule(rule); + return withClient(client); + }, async protect( request: ArcjetNextRequest, ...[props]: ExtraProps extends WithoutCustomProps ? [] : [ExtraProps] ): Promise { - // We construct an ArcjetHeaders to normalize over Headers - const headers = new ArcjetHeaders(request.headers); - - const ip = findIP(request, headers); - const method = request.method ?? ""; - const host = headers.get("host") ?? ""; - let path = ""; - let query = ""; - let protocol = ""; - // TODO(#36): nextUrl has formatting logic when you `toString` but - // we don't account for that here - if (typeof request.nextUrl !== "undefined") { - path = request.nextUrl.pathname ?? ""; - if (typeof request.nextUrl.search !== "undefined") { - query = request.nextUrl.search; - } - if (typeof request.nextUrl.protocol !== "undefined") { - protocol = request.nextUrl.protocol; - } - } else { - if (typeof request.socket?.encrypted !== "undefined") { - protocol = request.socket.encrypted ? "https:" : "http:"; - } else { - protocol = "http:"; - } - // Do some very simple validation, but also try/catch around URL parsing - if ( - typeof request.url !== "undefined" && - request.url !== "" && - host !== "" - ) { - try { - const url = new URL(request.url, `${protocol}//${host}`); - path = url.pathname; - query = url.search; - protocol = url.protocol; - } catch { - // If the parsing above fails, just set the path as whatever url we - // received. - // TODO(#216): Add logging to arcjet-next - path = request.url ?? ""; - } - } else { - path = request.url ?? ""; - } - } - const cookies = cookiesToString(request.cookies); - - const extra: { [key: string]: string } = {}; - - // If we're running on Vercel, we can add some extra information - if (process.env["VERCEL"]) { - // Vercel ID https://vercel.com/docs/concepts/edge-network/headers - extra["vercel-id"] = headers.get("x-vercel-id") ?? ""; - // Vercel deployment URL - // https://vercel.com/docs/concepts/edge-network/headers - extra["vercel-deployment-url"] = - headers.get("x-vercel-deployment-url") ?? ""; - // Vercel git commit SHA - // https://vercel.com/docs/concepts/projects/environment-variables/system-environment-variables - extra["vercel-git-commit-sha"] = - process.env["VERCEL_GIT_COMMIT_SHA"] ?? ""; - extra["vercel-git-commit-sha"] = - process.env["VERCEL_GIT_COMMIT_SHA"] ?? ""; - } - - const decision = await aj.protect({ - ...props, - ip, - method, - protocol, - host, - path, - headers, - cookies, - query, - ...extra, - // TODO(#220): The generic manipulations get really mad here, so we just cast it - } as ArcjetRequest>); - - return decision; + // TODO(#220): The generic manipulations get really mad here, so we cast + // Further investigation makes it seem like it has something to do with + // the definition of `props` in the signature but it's hard to track down + const req = toArcjetRequest(request, props ?? {}) as ArcjetRequest< + ExtraProps + >; + + return aj.protect(req); }, }); } +/** + * Create a new Arcjet Next client. Always build your initial client outside of + * a request handler so it persists across requests. If you need to augment a + * client inside a handler, call the `withRule()` function on the base client. + * + * @param options - Arcjet configuration options to apply to all requests. + */ +export default function arcjetNext( + options: ArcjetOptions, +): ArcjetNext>> { + const client = options.client ?? createNextRemoteClient(); + + const aj = arcjet({ ...options, client }); + + return withClient(aj); +} + /** * Protects your Next.js application using Arcjet middleware. * diff --git a/arcjet/index.ts b/arcjet/index.ts index 458fa355c..c21ba2183 100644 --- a/arcjet/index.ts +++ b/arcjet/index.ts @@ -1026,6 +1026,13 @@ export interface Arcjet { */ protect(request: ArcjetRequest): Promise; + /** + * Augments the client with another rule. Useful for varying rules based on + * criteria in your handler—e.g. different rate limit for logged in users. + * + * @param rule The rule to add to this execution. + * @returns An augmented {@link Arcjet} client. + */ withRule( rule: Rule, ): Arcjet>>; From 2cdec6eb8919d8c0b11fac7f9a748c14820e1f66 Mon Sep 17 00:00:00 2001 From: Blaine Bublitz Date: Thu, 22 Feb 2024 10:25:49 -0700 Subject: [PATCH 4/8] doc cleanup --- arcjet-next/index.ts | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/arcjet-next/index.ts b/arcjet-next/index.ts index 0fcf9df8f..e5c268b9f 100644 --- a/arcjet-next/index.ts +++ b/arcjet-next/index.ts @@ -307,9 +307,10 @@ function withClient( } /** - * Create a new Arcjet Next client. Always build your initial client outside of - * a request handler so it persists across requests. If you need to augment a - * client inside a handler, call the `withRule()` function on the base client. + * Create a new {@link ArcjetNext} client. Always build your initial client + * outside of a request handler so it persists across requests. If you need to + * augment a client inside a handler, call the `withRule()` function on the base + * client. * * @param options - Arcjet configuration options to apply to all requests. */ From a6459b98b4d1fb4e5996aad8ba4ff15a9a3c5424 Mon Sep 17 00:00:00 2001 From: Blaine Bublitz Date: Thu, 22 Feb 2024 11:53:52 -0700 Subject: [PATCH 5/8] tests and cleanup --- arcjet/index.ts | 3 +- arcjet/test/index.node.test.ts | 171 +++++++++++++++++++++++++++++++-- 2 files changed, 165 insertions(+), 9 deletions(-) diff --git a/arcjet/index.ts b/arcjet/index.ts index c21ba2183..ffe802315 100644 --- a/arcjet/index.ts +++ b/arcjet/index.ts @@ -1309,8 +1309,7 @@ export default function arcjet< // This is a separate function so it can be called recursively function withRule(rule: Rule) { - const rules = [...rootRules, rule] - .flat(1) + const rules = [...rootRules, ...rule] .sort((a, b) => a.priority - b.priority); return Object.freeze({ diff --git a/arcjet/test/index.node.test.ts b/arcjet/test/index.node.test.ts index 8a8ba2de5..dc4fca103 100644 --- a/arcjet/test/index.node.test.ts +++ b/arcjet/test/index.node.test.ts @@ -56,6 +56,7 @@ import arcjet, { tokenBucket, slidingWindow, Primitive, + Arcjet, } from "../index"; // Type helpers from https://github.com/sindresorhus/type-fest but adjusted for @@ -92,7 +93,8 @@ type IsEqual = type Assert = T; type Props

= P extends Primitive ? Props : never; -type RequiredProps

= IsEqual, E>; +type RuleProps

= IsEqual, E>; +type SDKProps = IsEqual ? P : never, E>; // Instances of Headers contain symbols that may be different depending // on if they have been iterated or not, so we need this equality tester @@ -2030,7 +2032,7 @@ describe("Primitive > tokenBucket", () => { capacity: 120, }); type Test = Assert< - RequiredProps< + RuleProps< typeof rules, { requested: number; userId: string | number | boolean } > @@ -2052,7 +2054,7 @@ describe("Primitive > tokenBucket", () => { interval: 60, capacity: 120, }); - type Test = Assert>; + type Test = Assert>; }); test("produces a rules based on single `limit` specified", async () => { @@ -2236,7 +2238,7 @@ describe("Primitive > fixedWindow", () => { max: 1, }); type Test = Assert< - RequiredProps + RuleProps >; }); @@ -2254,7 +2256,7 @@ describe("Primitive > fixedWindow", () => { window: "1h", max: 1, }); - type Test = Assert>; + type Test = Assert>; }); test("produces a rules based on single `limit` specified", async () => { @@ -2428,7 +2430,7 @@ describe("Primitive > slidingWindow", () => { max: 1, }); type Test = Assert< - RequiredProps + RuleProps >; }); @@ -2446,7 +2448,7 @@ describe("Primitive > slidingWindow", () => { interval: "1h", max: 1, }); - type Test = Assert>; + type Test = Assert>; }); test("produces a rules based on single `limit` specified", async () => { @@ -3197,6 +3199,10 @@ describe("SDK", () => { }; } + function testRuleProps(): Primitive<{abc: number}> { + return []; + } + test("creates a new Arcjet SDK with no rules", () => { const client = { decide: jest.fn(async () => { @@ -3253,6 +3259,157 @@ describe("SDK", () => { }).toThrow(); }); + test("can augment rules via `withRule` API", async () => { + const client = { + decide: jest.fn(async () => { + return new ArcjetAllowDecision({ + ttl: 0, + reason: new ArcjetTestReason(), + results: [], + }); + }), + report: jest.fn(), + }; + + const aj = arcjet({ + key: "test-key", + rules: [], + client, + }); + type WithoutRuleTest = Assert>; + + const aj2 = aj.withRule( + tokenBucket({ + characteristics: ["userId"], + refillRate: 60, + interval: 60, + capacity: 120, + }), + ); + type WithRuleTest = Assert< + SDKProps< + typeof aj2, + { requested: number; userId: string | number | boolean } + > + >; + }); + + test("can chain new rules via multiple `withRule` calls", async () => { + const client = { + decide: jest.fn(async () => { + return new ArcjetAllowDecision({ + ttl: 0, + reason: new ArcjetTestReason(), + results: [], + }); + }), + report: jest.fn(), + }; + + const aj = arcjet({ + key: "test-key", + rules: [], + client, + }); + type WithoutRuleTest = Assert>; + + const aj2 = aj.withRule( + tokenBucket({ + characteristics: ["userId"], + refillRate: 60, + interval: 60, + capacity: 120, + }), + ); + type WithRuleTestOne = Assert< + SDKProps< + typeof aj2, + { requested: number; userId: string | number | boolean } + > + >; + + const aj3 = aj2.withRule(testRuleProps()) + type WithRuleTestTwo = Assert< + SDKProps< + typeof aj3, + { requested: number; userId: string | number | boolean, abc: number } + > + >; + }); + + test("creates different augmented clients when `withRule` not chained", async () => { + const client = { + decide: jest.fn(async () => { + return new ArcjetAllowDecision({ + ttl: 0, + reason: new ArcjetTestReason(), + results: [], + }); + }), + report: jest.fn(), + }; + + const aj = arcjet({ + key: "test-key", + rules: [], + client, + }); + type WithoutRuleTest = Assert>; + + const aj2 = aj.withRule( + tokenBucket({ + characteristics: ["userId"], + refillRate: 60, + interval: 60, + capacity: 120, + }), + ); + type WithRuleTestOne = Assert< + SDKProps< + typeof aj2, + { requested: number; userId: string | number | boolean } + > + >; + + const aj3 = aj.withRule(testRuleProps()) + type WithRuleTestTwo = Assert< + SDKProps< + typeof aj3, + { abc: number } + > + >; + }); + + test("augment SDK still has the `runtime` property", async () => { + const client = { + decide: jest.fn(async () => { + return new ArcjetAllowDecision({ + ttl: 0, + reason: new ArcjetTestReason(), + results: [], + }); + }), + report: jest.fn(), + }; + + const aj = arcjet({ + key: "test-key", + rules: [], + client, + }); + + const aj2 = aj.withRule( + tokenBucket({ + characteristics: ["userId"], + refillRate: 60, + interval: 60, + capacity: 120, + }), + ); + + expect(aj2).toHaveProperty("runtime", Runtime.Node); + }); + test("creates a new Arcjet SDK with only local rules", () => { const client = { decide: jest.fn(async () => { From 9ad5b89b54b780b66a8eeac245dfe1d9b7fdb1e8 Mon Sep 17 00:00:00 2001 From: Blaine Bublitz Date: Thu, 22 Feb 2024 12:19:11 -0700 Subject: [PATCH 6/8] rework clerk example --- examples/nextjs-14-clerk-rl/README.md | 9 ++- .../app/api/arcjet/route.ts | 72 +++++++++++++++++++ .../app/api/private/route.ts | 59 --------------- .../app/api/public/route.ts | 42 ----------- examples/nextjs-14-clerk-rl/app/page.tsx | 25 +++---- 5 files changed, 87 insertions(+), 120 deletions(-) create mode 100644 examples/nextjs-14-clerk-rl/app/api/arcjet/route.ts delete mode 100644 examples/nextjs-14-clerk-rl/app/api/private/route.ts delete mode 100644 examples/nextjs-14-clerk-rl/app/api/public/route.ts diff --git a/examples/nextjs-14-clerk-rl/README.md b/examples/nextjs-14-clerk-rl/README.md index bd919a75b..2309bfc64 100644 --- a/examples/nextjs-14-clerk-rl/README.md +++ b/examples/nextjs-14-clerk-rl/README.md @@ -10,12 +10,11 @@ This example shows how to use an Arcjet rate limit with a user ID from [Clerk authentication and Next.js](https://clerk.com/docs/quickstarts/nextjs). -It sets up 2 API routes: +It sets up the `/api/arcjet` route. -* `/api/public` does not require authentication and has a low rate limit based - on the user IP address. -* `/api/private` uses Clerk authentication and has a higher rate limit based on - the Clerk user ID. +* Unauthenticated users receive a low rate limit based on the user IP address. +* Users authenticated with Clerk have a higher rate limit based on the Clerk + user ID. ## How to use diff --git a/examples/nextjs-14-clerk-rl/app/api/arcjet/route.ts b/examples/nextjs-14-clerk-rl/app/api/arcjet/route.ts new file mode 100644 index 000000000..95c70601d --- /dev/null +++ b/examples/nextjs-14-clerk-rl/app/api/arcjet/route.ts @@ -0,0 +1,72 @@ +import arcjet, { ArcjetDecision, tokenBucket } from "@arcjet/next"; +import { NextResponse } from "next/server"; +import { currentUser } from "@clerk/nextjs"; + +// The root Arcjet client is created outside of the handler. +const aj = arcjet({ + // Get your site key from https://app.arcjet.com + // and set it as an environment variable rather than hard coding. + // See: https://nextjs.org/docs/app/building-your-application/configuring/environment-variables + key: process.env.ARCJET_KEY!, + rules: [], +}); + +export async function GET(req: Request) { + // Get the current user from Clerk + // See https://clerk.com/docs/references/nextjs/current-user + const user = await currentUser(); + + let decision: ArcjetDecision; + if (user) { + // Allow higher limits for signed in users. + const rl = aj.withRule( + // Create a token bucket rate limit. Fixed and sliding window rate limits + // are also supported. See https://docs.arcjet.com/rate-limiting/algorithms + tokenBucket({ + mode: "LIVE", // will block requests at the limit. Use "DRY_RUN" to log only + // Rate limit based on the Clerk userId + // See https://clerk.com/docs/references/nextjs/authentication-object + // See https://docs.arcjet.com/rate-limiting/configuration#characteristics + characteristics: ["userId"], + refillRate: 20, // refill 20 tokens per interval + interval: 10, // refill every 10 seconds + capacity: 100, // bucket maximum capacity of 100 tokens + }) + ); + + // Deduct 5 tokens from the token bucket + decision = await rl.protect(req, { userId: user.id, requested: 5 } ); + } else { + // Limit the amount of requests for anonymous users. + const rl = aj.withRule( + // Create a token bucket rate limit. Fixed and sliding window rate limits + // are also supported. See https://docs.arcjet.com/rate-limiting/algorithms + tokenBucket({ + mode: "LIVE", // will block requests at the limit. Use "DRY_RUN" to log only + // Use the built in ip.src characteristic + // See https://docs.arcjet.com/rate-limiting/configuration#characteristics + characteristics: ["ip.src"], + refillRate: 5, // refill 5 tokens per interval + interval: 10, // refill every 10 seconds + capacity: 10, // bucket maximum capacity of 10 tokens + }) + ); + + // Deduct 5 tokens from the token bucket + decision = await rl.protect(req, { requested: 5 }) + } + + if (decision.isDenied()) { + return NextResponse.json( + { + error: "Too Many Requests", + reason: decision.reason, + }, + { + status: 429, + } + ); + } + + return NextResponse.json({ message: "Hello World" }); +} diff --git a/examples/nextjs-14-clerk-rl/app/api/private/route.ts b/examples/nextjs-14-clerk-rl/app/api/private/route.ts deleted file mode 100644 index f049f6092..000000000 --- a/examples/nextjs-14-clerk-rl/app/api/private/route.ts +++ /dev/null @@ -1,59 +0,0 @@ -/** - * Testing this route requires a Clerk user JWT token passed in the - * Authorization header. - * - * `curl -v http://localhost:3000/api/private -H "Authorization: Bearer TOKENHERE"` - * - * Get the token from the /api/token route. - */ -import arcjet, { tokenBucket } from "@arcjet/next"; -import { NextResponse } from "next/server"; -import { currentUser } from "@clerk/nextjs"; - -// The arcjet instance is created outside of the handler -const aj = arcjet({ - // Get your site key from https://app.arcjet.com - // and set it as an environment variable rather than hard coding. - // See: https://nextjs.org/docs/app/building-your-application/configuring/environment-variables - key: process.env.ARCJET_KEY!, - rules: [ - // Create a token bucket rate limit. Fixed and sliding window rate limits - // are also supported. See https://docs.arcjet.com/rate-limiting/algorithms - tokenBucket({ - mode: "LIVE", // will block requests at the limit. Use "DRY_RUN" to log only - // Rate limit based on the Clerk userId - // See https://clerk.com/docs/references/nextjs/authentication-object - // See https://docs.arcjet.com/rate-limiting/configuration#characteristics - characteristics: ["userId"], - refillRate: 5, // refill 5 tokens per interval - interval: 10, // refill every 10 seconds - capacity: 10, // bucket maximum capacity of 10 tokens - }), - ], -}); - -export async function GET(req: Request) { - // Get the current user from Clerk - // See https://clerk.com/docs/references/nextjs/current-user - const user = await currentUser(); - if (!user) { - return NextResponse.json({ error: "Unauthorized" }, { status: 401 }); - } - - // Deduct 5 tokens from the user's bucket - const decision = await aj.protect(req, { userId: user.id, requested: 5 }); - - if (decision.isDenied()) { - return NextResponse.json( - { - error: "Too Many Requests", - reason: decision.reason, - }, - { - status: 429, - } - ); - } - - return NextResponse.json({ message: "Hello World" }); -} \ No newline at end of file diff --git a/examples/nextjs-14-clerk-rl/app/api/public/route.ts b/examples/nextjs-14-clerk-rl/app/api/public/route.ts deleted file mode 100644 index a04aecdee..000000000 --- a/examples/nextjs-14-clerk-rl/app/api/public/route.ts +++ /dev/null @@ -1,42 +0,0 @@ -import arcjet, { tokenBucket } from "@arcjet/next"; -import { NextResponse } from "next/server"; - -// The arcjet instance is created outside of the handler -const aj = arcjet({ - // Get your site key from https://app.arcjet.com - // and set it as an environment variable rather than hard coding. - // See: https://nextjs.org/docs/app/building-your-application/configuring/environment-variables - key: process.env.ARCJET_KEY!, - rules: [ - // Create a token bucket rate limit. Fixed and sliding window rate limits - // are also supported. See https://docs.arcjet.com/rate-limiting/algorithms - tokenBucket({ - mode: "LIVE", // will block requests at the limit. Use "DRY_RUN" to log only - // Use the built in ip.src characteristic - // See https://docs.arcjet.com/rate-limiting/configuration#characteristics - characteristics: ["ip.src"], - refillRate: 5, // refill 5 tokens per interval - interval: 10, // refill every 10 seconds - capacity: 10, // bucket maximum capacity of 10 tokens - }), - ], -}); - -export async function GET(req: Request) { - // Deduct 5 tokens from the bucket - const decision = await aj.protect(req, { requested: 5 }); - - if (decision.isDenied()) { - return NextResponse.json( - { - error: "Too Many Requests", - reason: decision.reason, - }, - { - status: 429, - } - ); - } - - return NextResponse.json({ message: "Hello World" }); -} \ No newline at end of file diff --git a/examples/nextjs-14-clerk-rl/app/page.tsx b/examples/nextjs-14-clerk-rl/app/page.tsx index 70629aac7..ff062c898 100644 --- a/examples/nextjs-14-clerk-rl/app/page.tsx +++ b/examples/nextjs-14-clerk-rl/app/page.tsx @@ -5,22 +5,19 @@ export default function Home() {

Arcjet Rate Limit / Clerk Authentication Example

- These two API routes are both protected with an Arcjet rate limit: + This API route is protected with an Arcjet rate limit. + + /api/arcjet +

  • - - /api/public - {" "} - does not require authentication and has a low rate limit based on - the user IP address. + Unauthenticated users receive a low rate limit based on the user + IP address.
  • - - /api/private - {" "} - uses Clerk authentication and has a higher rate limit based on the - Clerk user ID. + Users authenticated with Clerk have a higher rate limit based on + the Clerk user ID.
@@ -39,8 +36,8 @@ export default function Home() {
  • Visit{" "} - - /api/private + + /api/arcjet {" "} in your browser or use the token to send several curl{" "} requests to /api/private @@ -48,7 +45,7 @@ export default function Home() {
    -            curl -v http://localhost:3000/api/private -H "Authorization: Bearer
    +            curl -v http://localhost:3000/api/arcjet -H "Authorization: Bearer
                 TOKENHERE"
               
  • From a844cc3970c8b8cda4583514b14b21479c62892d Mon Sep 17 00:00:00 2001 From: Blaine Bublitz Date: Thu, 22 Feb 2024 12:20:33 -0700 Subject: [PATCH 7/8] fmt --- arcjet/index.ts | 5 +++-- arcjet/test/index.node.test.ts | 15 +++++---------- 2 files changed, 8 insertions(+), 12 deletions(-) diff --git a/arcjet/index.ts b/arcjet/index.ts index ffe802315..0ab303cf3 100644 --- a/arcjet/index.ts +++ b/arcjet/index.ts @@ -1309,8 +1309,9 @@ export default function arcjet< // This is a separate function so it can be called recursively function withRule(rule: Rule) { - const rules = [...rootRules, ...rule] - .sort((a, b) => a.priority - b.priority); + const rules = [...rootRules, ...rule].sort( + (a, b) => a.priority - b.priority, + ); return Object.freeze({ get runtime() { diff --git a/arcjet/test/index.node.test.ts b/arcjet/test/index.node.test.ts index dc4fca103..726253879 100644 --- a/arcjet/test/index.node.test.ts +++ b/arcjet/test/index.node.test.ts @@ -3199,7 +3199,7 @@ describe("SDK", () => { }; } - function testRuleProps(): Primitive<{abc: number}> { + function testRuleProps(): Primitive<{ abc: number }> { return []; } @@ -3328,11 +3328,11 @@ describe("SDK", () => { > >; - const aj3 = aj2.withRule(testRuleProps()) + const aj3 = aj2.withRule(testRuleProps()); type WithRuleTestTwo = Assert< SDKProps< typeof aj3, - { requested: number; userId: string | number | boolean, abc: number } + { requested: number; userId: string | number | boolean; abc: number } > >; }); @@ -3371,13 +3371,8 @@ describe("SDK", () => { > >; - const aj3 = aj.withRule(testRuleProps()) - type WithRuleTestTwo = Assert< - SDKProps< - typeof aj3, - { abc: number } - > - >; + const aj3 = aj.withRule(testRuleProps()); + type WithRuleTestTwo = Assert>; }); test("augment SDK still has the `runtime` property", async () => { From 1229513f0a6adb667fbf63288a16e692b503f119 Mon Sep 17 00:00:00 2001 From: David Mytton Date: Fri, 23 Feb 2024 09:31:16 +0000 Subject: [PATCH 8/8] Slimmed comments & add bot rule --- examples/nextjs-14-clerk-rl/README.md | 1 + .../app/api/arcjet/route.ts | 70 +++++++++++-------- 2 files changed, 40 insertions(+), 31 deletions(-) diff --git a/examples/nextjs-14-clerk-rl/README.md b/examples/nextjs-14-clerk-rl/README.md index 2309bfc64..a9115d6e3 100644 --- a/examples/nextjs-14-clerk-rl/README.md +++ b/examples/nextjs-14-clerk-rl/README.md @@ -15,6 +15,7 @@ It sets up the `/api/arcjet` route. * Unauthenticated users receive a low rate limit based on the user IP address. * Users authenticated with Clerk have a higher rate limit based on the Clerk user ID. +* A bot detection rule is also added to check all requests. ## How to use diff --git a/examples/nextjs-14-clerk-rl/app/api/arcjet/route.ts b/examples/nextjs-14-clerk-rl/app/api/arcjet/route.ts index 95c70601d..844c52a52 100644 --- a/examples/nextjs-14-clerk-rl/app/api/arcjet/route.ts +++ b/examples/nextjs-14-clerk-rl/app/api/arcjet/route.ts @@ -1,14 +1,16 @@ -import arcjet, { ArcjetDecision, tokenBucket } from "@arcjet/next"; +import arcjet, { ArcjetDecision, tokenBucket, detectBot } from "@arcjet/next"; import { NextResponse } from "next/server"; import { currentUser } from "@clerk/nextjs"; // The root Arcjet client is created outside of the handler. const aj = arcjet({ - // Get your site key from https://app.arcjet.com - // and set it as an environment variable rather than hard coding. - // See: https://nextjs.org/docs/app/building-your-application/configuring/environment-variables - key: process.env.ARCJET_KEY!, - rules: [], + key: process.env.ARCJET_KEY!, // Get your site key from https://app.arcjet.com + rules: [ + detectBot({ + mode: "LIVE", // will block requests. Use "DRY_RUN" to log only + block: ["AUTOMATED"], // blocks all automated clients + }), + ], }); export async function GET(req: Request) { @@ -20,14 +22,10 @@ export async function GET(req: Request) { if (user) { // Allow higher limits for signed in users. const rl = aj.withRule( - // Create a token bucket rate limit. Fixed and sliding window rate limits - // are also supported. See https://docs.arcjet.com/rate-limiting/algorithms + // Create a token bucket rate limit. Other algorithms are supported. tokenBucket({ - mode: "LIVE", // will block requests at the limit. Use "DRY_RUN" to log only - // Rate limit based on the Clerk userId - // See https://clerk.com/docs/references/nextjs/authentication-object - // See https://docs.arcjet.com/rate-limiting/configuration#characteristics - characteristics: ["userId"], + mode: "LIVE", // will block requests. Use "DRY_RUN" to log only + characteristics: ["userId"], // Rate limit based on the Clerk userId refillRate: 20, // refill 20 tokens per interval interval: 10, // refill every 10 seconds capacity: 100, // bucket maximum capacity of 100 tokens @@ -35,17 +33,14 @@ export async function GET(req: Request) { ); // Deduct 5 tokens from the token bucket - decision = await rl.protect(req, { userId: user.id, requested: 5 } ); + decision = await rl.protect(req, { userId: user.id, requested: 5 }); } else { // Limit the amount of requests for anonymous users. const rl = aj.withRule( - // Create a token bucket rate limit. Fixed and sliding window rate limits - // are also supported. See https://docs.arcjet.com/rate-limiting/algorithms + // Create a token bucket rate limit. Other algorithms are supported. tokenBucket({ - mode: "LIVE", // will block requests at the limit. Use "DRY_RUN" to log only - // Use the built in ip.src characteristic - // See https://docs.arcjet.com/rate-limiting/configuration#characteristics - characteristics: ["ip.src"], + mode: "LIVE", // will block requests. Use "DRY_RUN" to log only + characteristics: ["ip.src"], // Use the built in ip.src characteristic refillRate: 5, // refill 5 tokens per interval interval: 10, // refill every 10 seconds capacity: 10, // bucket maximum capacity of 10 tokens @@ -53,20 +48,33 @@ export async function GET(req: Request) { ); // Deduct 5 tokens from the token bucket - decision = await rl.protect(req, { requested: 5 }) + decision = await rl.protect(req, { requested: 5 }); } if (decision.isDenied()) { - return NextResponse.json( - { - error: "Too Many Requests", - reason: decision.reason, - }, - { - status: 429, - } - ); + if (decision.reason.isRateLimit()) { + return NextResponse.json( + { + error: "Too Many Requests", + reason: decision.reason, + }, + { + status: 429, + } + ); + } else { + // Detected a bot + return NextResponse.json( + { + error: "Forbidden", + reason: decision.reason, + }, + { + status: 403, + } + ); + } } return NextResponse.json({ message: "Hello World" }); -} +} \ No newline at end of file