diff --git a/analyze/edge-light.ts b/analyze/edge-light.ts index ebd82d755..1c302f2f9 100644 --- a/analyze/edge-light.ts +++ b/analyze/edge-light.ts @@ -1,4 +1,4 @@ -import type { ArcjetLogger } from "@arcjet/protocol"; +import type { ArcjetLogger, ArcjetRequestDetails } from "@arcjet/protocol"; import * as core from "./wasm/arcjet_analyze_js_req.component.js"; import type { @@ -14,6 +14,7 @@ import componentCore3Wasm from "./wasm/arcjet_analyze_js_req.component.core3.was interface AnalyzeContext { log: ArcjetLogger; + characteristics: string[]; } async function moduleFromPath(path: string): Promise { @@ -72,45 +73,21 @@ export { /** * Generate a fingerprint for the client. This is used to identify the client * across multiple requests. - * @param ip - The IP address of the client. + * @param context - The Arcjet Analyze context. + * @param request - The request to fingerprint. * @returns A SHA-256 string fingerprint. */ export async function generateFingerprint( context: AnalyzeContext, - ip: string, + request: Partial, ): Promise { - if (ip == "") { - return ""; - } - const analyze = await init(context); if (typeof analyze !== "undefined") { - return analyze.generateFingerprint(ip); - } - - if (hasSubtleCryptoDigest()) { - // Fingerprint v1 is just the IP address - const fingerprintRaw = `fp_1_${ip}`; - - // Based on MDN example at - // https://developer.mozilla.org/en-US/docs/Web/API/SubtleCrypto/digest#converting_a_digest_to_a_hex_string - - // Encode the raw fingerprint into a utf-8 Uint8Array - const fingerprintUint8 = new TextEncoder().encode(fingerprintRaw); - // Hash the message with SHA-256 - const fingerprintArrayBuffer = await crypto.subtle.digest( - "SHA-256", - fingerprintUint8, + return analyze.generateFingerprint( + JSON.stringify(request), + context.characteristics, ); - // Convert the ArrayBuffer to a byte array - const fingerprintArray = Array.from(new Uint8Array(fingerprintArrayBuffer)); - // Convert the bytes to a hex string - const fingerprint = fingerprintArray - .map((b) => b.toString(16).padStart(2, "0")) - .join(""); - - return fingerprint; } return ""; @@ -149,24 +126,3 @@ export async function detectBot( }; } } - -function hasSubtleCryptoDigest() { - if (typeof crypto === "undefined") { - return false; - } - - if (!("subtle" in crypto)) { - return false; - } - if (typeof crypto.subtle === "undefined") { - return false; - } - if (!("digest" in crypto.subtle)) { - return false; - } - if (typeof crypto.subtle.digest !== "function") { - return false; - } - - return true; -} diff --git a/analyze/index.ts b/analyze/index.ts index 4196f605c..212318da6 100644 --- a/analyze/index.ts +++ b/analyze/index.ts @@ -1,4 +1,4 @@ -import type { ArcjetLogger } from "@arcjet/protocol"; +import type { ArcjetLogger, ArcjetRequestDetails } from "@arcjet/protocol"; import * as core from "./wasm/arcjet_analyze_js_req.component.js"; import type { @@ -14,6 +14,7 @@ import { wasm as componentCore3Wasm } from "./wasm/arcjet_analyze_js_req.compone interface AnalyzeContext { log: ArcjetLogger; + characteristics: string[]; } // TODO: Do we actually need this wasmCache or does `import` cache correctly? @@ -86,45 +87,21 @@ export { /** * Generate a fingerprint for the client. This is used to identify the client * across multiple requests. - * @param ip - The IP address of the client. + * @param context - The Arcjet Analyze context. + * @param request - The request to fingerprint. * @returns A SHA-256 string fingerprint. */ export async function generateFingerprint( context: AnalyzeContext, - ip: string, + request: Partial, ): Promise { - if (ip == "") { - return ""; - } - const analyze = await init(context); if (typeof analyze !== "undefined") { - return analyze.generateFingerprint(ip); - } - - if (hasSubtleCryptoDigest()) { - // Fingerprint v1 is just the IP address - const fingerprintRaw = `fp_1_${ip}`; - - // Based on MDN example at - // https://developer.mozilla.org/en-US/docs/Web/API/SubtleCrypto/digest#converting_a_digest_to_a_hex_string - - // Encode the raw fingerprint into a utf-8 Uint8Array - const fingerprintUint8 = new TextEncoder().encode(fingerprintRaw); - // Hash the message with SHA-256 - const fingerprintArrayBuffer = await crypto.subtle.digest( - "SHA-256", - fingerprintUint8, + return analyze.generateFingerprint( + JSON.stringify(request), + context.characteristics, ); - // Convert the ArrayBuffer to a byte array - const fingerprintArray = Array.from(new Uint8Array(fingerprintArrayBuffer)); - // Convert the bytes to a hex string - const fingerprint = fingerprintArray - .map((b) => b.toString(16).padStart(2, "0")) - .join(""); - - return fingerprint; } return ""; @@ -163,24 +140,3 @@ export async function detectBot( }; } } - -function hasSubtleCryptoDigest() { - if (typeof crypto === "undefined") { - return false; - } - - if (!("subtle" in crypto)) { - return false; - } - if (typeof crypto.subtle === "undefined") { - return false; - } - if (!("digest" in crypto.subtle)) { - return false; - } - if (typeof crypto.subtle.digest !== "function") { - return false; - } - - return true; -} diff --git a/analyze/wasm/arcjet_analyze_js_req.component.core.wasm b/analyze/wasm/arcjet_analyze_js_req.component.core.wasm index 0c106fbca..4c41b3d36 100644 Binary files a/analyze/wasm/arcjet_analyze_js_req.component.core.wasm and b/analyze/wasm/arcjet_analyze_js_req.component.core.wasm differ diff --git a/analyze/wasm/arcjet_analyze_js_req.component.core2.wasm b/analyze/wasm/arcjet_analyze_js_req.component.core2.wasm index 73347ee86..4affd384a 100644 Binary files a/analyze/wasm/arcjet_analyze_js_req.component.core2.wasm and b/analyze/wasm/arcjet_analyze_js_req.component.core2.wasm differ diff --git a/analyze/wasm/arcjet_analyze_js_req.component.core3.wasm b/analyze/wasm/arcjet_analyze_js_req.component.core3.wasm index 784cc286e..4c4bf1373 100644 Binary files a/analyze/wasm/arcjet_analyze_js_req.component.core3.wasm and b/analyze/wasm/arcjet_analyze_js_req.component.core3.wasm differ diff --git a/analyze/wasm/arcjet_analyze_js_req.component.d.ts b/analyze/wasm/arcjet_analyze_js_req.component.d.ts index 774ec062f..e10a583d8 100644 --- a/analyze/wasm/arcjet_analyze_js_req.component.d.ts +++ b/analyze/wasm/arcjet_analyze_js_req.component.d.ts @@ -28,7 +28,7 @@ export interface ImportObject { } export interface Root { detectBot(headers: string, patternsAdd: string, patternsRemove: string): BotDetectionResult, - generateFingerprint(ip: string): string, + generateFingerprint(request: string, characteristics: string[]): string, isValidEmail(candidate: string, options: EmailValidationConfig | undefined): boolean, } diff --git a/analyze/wasm/arcjet_analyze_js_req.component.js b/analyze/wasm/arcjet_analyze_js_req.component.js index 943482e8d..02e0566c0 100644 --- a/analyze/wasm/arcjet_analyze_js_req.component.js +++ b/analyze/wasm/arcjet_analyze_js_req.component.js @@ -163,15 +163,25 @@ async function instantiate(getCoreModule, imports, instantiateCore = WebAssembly return variant5.val; } - function generateFingerprint(arg0) { + function generateFingerprint(arg0, arg1) { var ptr0 = utf8Encode(arg0, realloc0, memory0); var len0 = utf8EncodedLen; - const ret = exports1['generate-fingerprint'](ptr0, len0); - var ptr1 = dataView(memory0).getInt32(ret + 0, true); - var len1 = dataView(memory0).getInt32(ret + 4, true); - var result1 = utf8Decoder.decode(new Uint8Array(memory0.buffer, ptr1, len1)); + var vec2 = arg1; + var len2 = vec2.length; + var result2 = realloc0(0, 0, 4, len2 * 8); + for (let i = 0; i < vec2.length; i++) { + const e = vec2[i]; + const base = result2 + i * 8;var ptr1 = utf8Encode(e, realloc0, memory0); + var len1 = utf8EncodedLen; + dataView(memory0).setInt32(base + 4, len1, true); + dataView(memory0).setInt32(base + 0, ptr1, true); + } + const ret = exports1['generate-fingerprint'](ptr0, len0, result2, len2); + var ptr3 = dataView(memory0).getInt32(ret + 0, true); + var len3 = dataView(memory0).getInt32(ret + 4, true); + var result3 = utf8Decoder.decode(new Uint8Array(memory0.buffer, ptr3, len3)); postReturn1(ret); - return result1; + return result3; } function isValidEmail(arg0, arg1) { diff --git a/analyze/wasm/arcjet_analyze_js_req.component.wasm b/analyze/wasm/arcjet_analyze_js_req.component.wasm index 126ae656f..3da0ecca2 100644 Binary files a/analyze/wasm/arcjet_analyze_js_req.component.wasm and b/analyze/wasm/arcjet_analyze_js_req.component.wasm differ diff --git a/analyze/workerd.ts b/analyze/workerd.ts index 080f4305c..b7073c86f 100644 --- a/analyze/workerd.ts +++ b/analyze/workerd.ts @@ -1,4 +1,4 @@ -import type { ArcjetLogger } from "@arcjet/protocol"; +import type { ArcjetLogger, ArcjetRequestDetails } from "@arcjet/protocol"; import * as core from "./wasm/arcjet_analyze_js_req.component.js"; import type { @@ -14,6 +14,7 @@ import componentCore3Wasm from "./wasm/arcjet_analyze_js_req.component.core3.was interface AnalyzeContext { log: ArcjetLogger; + characteristics: string[]; } async function moduleFromPath(path: string): Promise { @@ -72,45 +73,21 @@ export { /** * Generate a fingerprint for the client. This is used to identify the client * across multiple requests. - * @param ip - The IP address of the client. + * @param context - The Arcjet Analyze context. + * @param request - The request to fingerprint. * @returns A SHA-256 string fingerprint. */ export async function generateFingerprint( context: AnalyzeContext, - ip: string, + request: Partial, ): Promise { - if (ip == "") { - return ""; - } - const analyze = await init(context); if (typeof analyze !== "undefined") { - return analyze.generateFingerprint(ip); - } - - if (hasSubtleCryptoDigest()) { - // Fingerprint v1 is just the IP address - const fingerprintRaw = `fp_1_${ip}`; - - // Based on MDN example at - // https://developer.mozilla.org/en-US/docs/Web/API/SubtleCrypto/digest#converting_a_digest_to_a_hex_string - - // Encode the raw fingerprint into a utf-8 Uint8Array - const fingerprintUint8 = new TextEncoder().encode(fingerprintRaw); - // Hash the message with SHA-256 - const fingerprintArrayBuffer = await crypto.subtle.digest( - "SHA-256", - fingerprintUint8, + return analyze.generateFingerprint( + JSON.stringify(request), + context.characteristics, ); - // Convert the ArrayBuffer to a byte array - const fingerprintArray = Array.from(new Uint8Array(fingerprintArrayBuffer)); - // Convert the bytes to a hex string - const fingerprint = fingerprintArray - .map((b) => b.toString(16).padStart(2, "0")) - .join(""); - - return fingerprint; } return ""; @@ -149,24 +126,3 @@ export async function detectBot( }; } } - -function hasSubtleCryptoDigest() { - if (typeof crypto === "undefined") { - return false; - } - - if (!("subtle" in crypto)) { - return false; - } - if (typeof crypto.subtle === "undefined") { - return false; - } - if (!("digest" in crypto.subtle)) { - return false; - } - if (typeof crypto.subtle.digest !== "function") { - return false; - } - - return true; -} diff --git a/arcjet/index.ts b/arcjet/index.ts index 4703d30b4..edb84ced2 100644 --- a/arcjet/index.ts +++ b/arcjet/index.ts @@ -21,8 +21,12 @@ import { ArcjetSlidingWindowRateLimitRule, ArcjetShieldRule, ArcjetLogger, + ArcjetRateLimitRule, } from "@arcjet/protocol"; -import { ArcjetBotTypeToProtocol } from "@arcjet/protocol/convert.js"; +import { + ArcjetBotTypeToProtocol, + isRateLimitRule, +} from "@arcjet/protocol/convert.js"; import { Client } from "@arcjet/protocol/client.js"; import * as analyze from "@arcjet/analyze"; import * as duration from "@arcjet/duration"; @@ -784,6 +788,10 @@ export interface ArcjetOptions { * Rules to apply when protecting a request. */ rules: readonly [...Rules]; + /** + * Characteristics to be used to uniquely identify clients. + */ + characteristics?: string[]; /** * The client used to make requests to the Arcjet API. This must be set * when creating the SDK, such as inside @arcjet/next or mocked in tests. @@ -890,21 +898,19 @@ export default function arcjet< log.time?.("local"); log.time?.("fingerprint"); - let ip = ""; - if (typeof details.ip === "string") { - ip = details.ip; - } - if (details.ip === "") { - log.warn("generateFingerprint: ip is empty"); - } + + const characteristics = options.characteristics + ? options.characteristics + : []; const baseContext = { key, log, + characteristics, ...ctx, }; - const fingerprint = await analyze.generateFingerprint(baseContext, ip); + const fingerprint = await analyze.generateFingerprint(baseContext, details); log.debug("fingerprint (%s): %s", rt, fingerprint); log.timeEnd?.("fingerprint"); @@ -945,14 +951,24 @@ export default function arcjet< } const results: ArcjetRuleResult[] = []; - // Default all rules to NOT_RUN/ALLOW before doing anything for (let idx = 0; idx < rules.length; idx++) { + // Default all rules to NOT_RUN/ALLOW before doing anything results[idx] = new ArcjetRuleResult({ ttl: 0, state: "NOT_RUN", conclusion: "ALLOW", reason: new ArcjetReason(), }); + + // Add top-level characteristics to all Rate Limit rules that don't already have + // their own set of characteristics. + const candidate_rule = rules[idx]; + if (isRateLimitRule(candidate_rule)) { + if (typeof candidate_rule.characteristics === "undefined") { + candidate_rule.characteristics = characteristics; + rules[idx] = candidate_rule; + } + } } // We have our own local cache which we check first. This doesn't work in diff --git a/arcjet/test/index.node.test.ts b/arcjet/test/index.node.test.ts index 75ef64eb7..11b9cc028 100644 --- a/arcjet/test/index.node.test.ts +++ b/arcjet/test/index.node.test.ts @@ -377,6 +377,7 @@ describe("Primitive > detectBot", () => { fingerprint: "test-fingerprint", runtime: "test", log, + characteristics: [], }; const details = { headers: new Headers(), @@ -396,6 +397,7 @@ describe("Primitive > detectBot", () => { fingerprint: "test-fingerprint", runtime: "test", log, + characteristics: [], }; const details = { headers: undefined, @@ -415,6 +417,7 @@ describe("Primitive > detectBot", () => { fingerprint: "test-fingerprint", runtime: "test", log, + characteristics: [], }; const details = { ip: "172.100.1.1", @@ -460,6 +463,7 @@ describe("Primitive > detectBot", () => { fingerprint: "test-fingerprint", runtime: "test", log, + characteristics: [], }; const details = { ip: "172.100.1.1", @@ -515,6 +519,7 @@ describe("Primitive > detectBot", () => { fingerprint: "test-fingerprint", runtime: "test", log, + characteristics: [], }; const details = { ip: "172.100.1.1", @@ -570,6 +575,7 @@ describe("Primitive > detectBot", () => { fingerprint: "test-fingerprint", runtime: "test", log, + characteristics: [], }; const details = { ip: "172.100.1.1", @@ -612,6 +618,7 @@ describe("Primitive > detectBot", () => { fingerprint: "test-fingerprint", runtime: "test", log, + characteristics: [], }; const details = { ip: "172.100.1.1", @@ -676,6 +683,7 @@ describe("Primitive > detectBot", () => { fingerprint: "test-fingerprint", runtime: "test", log, + characteristics: [], }; const details = { ip: "172.100.1.1", @@ -725,6 +733,7 @@ describe("Primitive > detectBot", () => { fingerprint: "test-fingerprint", runtime: "test", log, + characteristics: [], }; const details = { ip: "172.100.1.1", @@ -773,6 +782,7 @@ describe("Primitive > detectBot", () => { fingerprint: "test-fingerprint", runtime: "test", log, + characteristics: [], }; const details = { ip: "172.100.1.1", @@ -1448,6 +1458,7 @@ describe("Primitive > validateEmail", () => { fingerprint: "test-fingerprint", runtime: "test", log, + characteristics: [], }; const details = { email: "abc@example.com", @@ -1467,6 +1478,7 @@ describe("Primitive > validateEmail", () => { fingerprint: "test-fingerprint", runtime: "test", log, + characteristics: [], }; const details = { email: undefined, @@ -1486,6 +1498,7 @@ describe("Primitive > validateEmail", () => { fingerprint: "test-fingerprint", runtime: "test", log, + characteristics: [], }; const details = { ip: "172.100.1.1", @@ -1519,6 +1532,7 @@ describe("Primitive > validateEmail", () => { fingerprint: "test-fingerprint", runtime: "test", log, + characteristics: [], }; const details = { ip: "172.100.1.1", @@ -1552,6 +1566,7 @@ describe("Primitive > validateEmail", () => { fingerprint: "test-fingerprint", runtime: "test", log, + characteristics: [], }; const details = { ip: "172.100.1.1", @@ -1585,6 +1600,7 @@ describe("Primitive > validateEmail", () => { fingerprint: "test-fingerprint", runtime: "test", log, + characteristics: [], }; const details = { ip: "172.100.1.1", @@ -1620,6 +1636,7 @@ describe("Primitive > validateEmail", () => { fingerprint: "test-fingerprint", runtime: "test", log, + characteristics: [], }; const details = { ip: "172.100.1.1", @@ -1653,6 +1670,7 @@ describe("Primitive > validateEmail", () => { fingerprint: "test-fingerprint", runtime: "test", log, + characteristics: [], }; const details = { ip: "172.100.1.1", @@ -1686,6 +1704,7 @@ describe("Primitive > validateEmail", () => { fingerprint: "test-fingerprint", runtime: "test", log, + characteristics: [], }; const details = { ip: "172.100.1.1", @@ -1721,6 +1740,7 @@ describe("Primitive > validateEmail", () => { fingerprint: "test-fingerprint", runtime: "test", log, + characteristics: [], }; const details = { ip: "172.100.1.1", @@ -2425,7 +2445,7 @@ describe("SDK", () => { const context = { key, fingerprint: - "fp_1_ac8547705f1f45c5050f1424700dfa3f6f2f681b550ca4f3c19571585aea7a2c", + "fp::2::516289fae7993d35ffb6e76883e09b475bbc7a622a378f3b430f35e8c657687e", }; const request = { ip: "172.100.1.1", @@ -2479,7 +2499,7 @@ describe("SDK", () => { const context = { key, fingerprint: - "fp_1_ac8547705f1f45c5050f1424700dfa3f6f2f681b550ca4f3c19571585aea7a2c", + "fp::2::516289fae7993d35ffb6e76883e09b475bbc7a622a378f3b430f35e8c657687e", }; const request = { ip: "172.100.1.1", @@ -2536,7 +2556,7 @@ describe("SDK", () => { const context = { key, fingerprint: - "fp_1_ac8547705f1f45c5050f1424700dfa3f6f2f681b550ca4f3c19571585aea7a2c", + "fp::2::516289fae7993d35ffb6e76883e09b475bbc7a622a378f3b430f35e8c657687e", }; const request = { ip: "172.100.1.1", @@ -2632,7 +2652,7 @@ describe("SDK", () => { const context = { key, fingerprint: - "fp_1_ac8547705f1f45c5050f1424700dfa3f6f2f681b550ca4f3c19571585aea7a2c", + "fp::2::516289fae7993d35ffb6e76883e09b475bbc7a622a378f3b430f35e8c657687e", }; const request = { ip: "172.100.1.1", @@ -2687,7 +2707,7 @@ describe("SDK", () => { const context = { key, fingerprint: - "fp_1_ac8547705f1f45c5050f1424700dfa3f6f2f681b550ca4f3c19571585aea7a2c", + "fp::2::516289fae7993d35ffb6e76883e09b475bbc7a622a378f3b430f35e8c657687e", }; const request = { ip: "172.100.1.1", @@ -2779,7 +2799,7 @@ describe("SDK", () => { const context = { key, fingerprint: - "fp_1_ac8547705f1f45c5050f1424700dfa3f6f2f681b550ca4f3c19571585aea7a2c", + "fp::2::516289fae7993d35ffb6e76883e09b475bbc7a622a378f3b430f35e8c657687e", }; const request = { ip: "172.100.1.1", @@ -3086,7 +3106,7 @@ describe("SDK", () => { const context = { key, fingerprint: - "fp_1_ac8547705f1f45c5050f1424700dfa3f6f2f681b550ca4f3c19571585aea7a2c", + "fp::2::516289fae7993d35ffb6e76883e09b475bbc7a622a378f3b430f35e8c657687e", }; const request = { ip: "172.100.1.1", @@ -3145,7 +3165,7 @@ describe("SDK", () => { const context = { key, fingerprint: - "fp_1_ac8547705f1f45c5050f1424700dfa3f6f2f681b550ca4f3c19571585aea7a2c", + "fp::2::516289fae7993d35ffb6e76883e09b475bbc7a622a378f3b430f35e8c657687e", }; const request = { ip: "172.100.1.1", @@ -3200,7 +3220,7 @@ describe("SDK", () => { const context = { key, fingerprint: - "fp_1_ac8547705f1f45c5050f1424700dfa3f6f2f681b550ca4f3c19571585aea7a2c", + "fp::2::516289fae7993d35ffb6e76883e09b475bbc7a622a378f3b430f35e8c657687e", }; const request = { ip: "172.100.1.1", @@ -3244,4 +3264,317 @@ describe("SDK", () => { [], ); }); + + test("additional characteristics are propagated to fixedWindow if they aren't separately specified in fixedWindow", async () => { + const client = { + decide: jest.fn(async () => { + return new ArcjetAllowDecision({ + ttl: 0, + reason: new ArcjetTestReason(), + results: [], + }); + }), + report: jest.fn(), + }; + + const rateLimitRule = fixedWindow({ + mode: "LIVE", + window: "1h", + max: 60, + }); + + const localCharacteristics = ["someAdditionalCharacteristic"]; + const aj = arcjet({ + key: "test-key", + characteristics: localCharacteristics, + rules: [rateLimitRule], + client, + log, + }); + + const request = { + ip: "172.100.1.1", + method: "GET", + protocol: "http", + host: "example.com", + path: "/", + headers: new Headers(), + }; + + const _ = await aj.protect({}, request); + + expect(client.decide).toHaveBeenCalledTimes(1); + expect(client.decide).toHaveBeenCalledWith( + expect.anything(), + expect.anything(), + [ + { + characteristics: localCharacteristics, + ...rateLimitRule[0], + }, + ], + ); + }); + + test("Additional characteristics aren't propagated to fixedWindow if they are separately specified in fixedWindow", async () => { + const client = { + decide: jest.fn(async () => { + return new ArcjetAllowDecision({ + ttl: 0, + reason: new ArcjetTestReason(), + results: [], + }); + }), + report: jest.fn(), + }; + + const localCharacteristics = ["someLocalCharacteristic"]; + const rateLimitRule = fixedWindow({ + mode: "LIVE", + window: "1h", + max: 60, + characteristics: localCharacteristics, + }); + + const aj = arcjet({ + key: "test-key", + characteristics: ["someAdditionalCharacteristic"], + rules: [rateLimitRule], + client, + log, + }); + + const request = { + ip: "172.100.1.1", + method: "GET", + protocol: "http", + host: "example.com", + path: "/", + headers: new Headers(), + }; + + const _ = await aj.protect({}, request); + + expect(client.decide).toHaveBeenCalledTimes(1); + expect(client.decide).toHaveBeenCalledWith( + expect.anything(), + expect.anything(), + [ + { + characteristics: localCharacteristics, + ...rateLimitRule[0], + }, + ], + ); + }); + + test("Additional characteristics are propagated to slidingWindow if they aren't separately specified in slidingWindow", async () => { + const client = { + decide: jest.fn(async () => { + return new ArcjetAllowDecision({ + ttl: 0, + reason: new ArcjetTestReason(), + results: [], + }); + }), + report: jest.fn(), + }; + + const rateLimitRule = slidingWindow({ + mode: "LIVE", + interval: "1h", + max: 60, + }); + + const globalCharacteristics = ["someAdditionalCharacteristic"]; + const aj = arcjet({ + key: "test-key", + characteristics: globalCharacteristics, + rules: [rateLimitRule], + client, + log, + }); + + const request = { + ip: "172.100.1.1", + method: "GET", + protocol: "http", + host: "example.com", + path: "/", + headers: new Headers(), + }; + + const _ = await aj.protect({}, request); + + expect(client.decide).toHaveBeenCalledTimes(1); + expect(client.decide).toHaveBeenCalledWith( + expect.anything(), + expect.anything(), + [ + { + characteristics: globalCharacteristics, + ...rateLimitRule[0], + }, + ], + ); + }); + + test("Additional characteristics aren't propagated to slidingWindow if they are separately specified in slidingWindow", async () => { + const client = { + decide: jest.fn(async () => { + return new ArcjetAllowDecision({ + ttl: 0, + reason: new ArcjetTestReason(), + results: [], + }); + }), + report: jest.fn(), + }; + + const localCharacteristics = ["someLocalCharacteristic"]; + const rateLimitRule = slidingWindow({ + mode: "LIVE", + interval: "1h", + max: 60, + characteristics: localCharacteristics, + }); + + const aj = arcjet({ + key: "test-key", + characteristics: ["someAdditionalCharacteristic"], + rules: [rateLimitRule], + client, + log, + }); + + const request = { + ip: "172.100.1.1", + method: "GET", + protocol: "http", + host: "example.com", + path: "/", + headers: new Headers(), + }; + + const _ = await aj.protect({}, request); + + expect(client.decide).toHaveBeenCalledTimes(1); + expect(client.decide).toHaveBeenCalledWith( + expect.anything(), + expect.anything(), + [ + { + characteristics: localCharacteristics, + ...rateLimitRule[0], + }, + ], + ); + }); + + test("Additional characteristics are propagated to tokenBucket if they aren't separately specified in tokenBucket", async () => { + const client = { + decide: jest.fn(async () => { + return new ArcjetAllowDecision({ + ttl: 0, + reason: new ArcjetTestReason(), + results: [], + }); + }), + report: jest.fn(), + }; + + const rateLimitRule = tokenBucket({ + mode: "LIVE", + interval: "1h", + refillRate: 1, + capacity: 10, + }); + + const globalCharacteristics = ["someAdditionalCharacteristic"]; + const aj = arcjet({ + key: "test-key", + characteristics: globalCharacteristics, + rules: [rateLimitRule], + client, + log, + }); + + const request = { + ip: "172.100.1.1", + method: "GET", + protocol: "http", + host: "example.com", + path: "/", + headers: new Headers(), + requested: 1, + }; + + const _ = await aj.protect({}, request); + + expect(client.decide).toHaveBeenCalledTimes(1); + expect(client.decide).toHaveBeenCalledWith( + expect.anything(), + expect.anything(), + [ + { + characteristics: globalCharacteristics, + ...rateLimitRule[0], + }, + ], + ); + }); + + test("additional characteristics aren't propagated to tokenBucket if they are separately specified in tokenBucket", async () => { + const client = { + decide: jest.fn(async () => { + return new ArcjetAllowDecision({ + ttl: 0, + reason: new ArcjetTestReason(), + results: [], + }); + }), + report: jest.fn(), + }; + + const localCharacteristics = ["someLocalCharacteristic"]; + const rateLimitRule = tokenBucket({ + mode: "LIVE", + interval: "1h", + refillRate: 1, + capacity: 10, + characteristics: localCharacteristics, + }); + + const aj = arcjet({ + key: "test-key", + characteristics: ["someAdditionalCharacteristic"], + rules: [rateLimitRule], + client, + log, + }); + + const request = { + ip: "172.100.1.1", + method: "GET", + protocol: "http", + host: "example.com", + path: "/", + headers: new Headers(), + requested: 1, + }; + + const _ = await aj.protect({}, request); + + expect(client.decide).toHaveBeenCalledTimes(1); + expect(client.decide).toHaveBeenCalledWith( + expect.anything(), + expect.anything(), + [ + { + characteristics: localCharacteristics, + ...rateLimitRule[0], + }, + ], + ); + }); }); diff --git a/protocol/client.ts b/protocol/client.ts index 924d9cef9..5f10cedb1 100644 --- a/protocol/client.ts +++ b/protocol/client.ts @@ -81,6 +81,7 @@ export function createClient(options: ClientOptions): Client { const decideRequest = new DecideRequest({ sdkStack, sdkVersion, + characteristics: context.characteristics, details: { ip: details.ip, method: details.method, @@ -136,6 +137,7 @@ export function createClient(options: ClientOptions): Client { const reportRequest = new ReportRequest({ sdkStack, sdkVersion, + characteristics: context.characteristics, details: { ip: details.ip, method: details.method, diff --git a/protocol/convert.ts b/protocol/convert.ts index 1ee680b81..869fea8b3 100644 --- a/protocol/convert.ts +++ b/protocol/convert.ts @@ -540,7 +540,7 @@ export function ArcjetDecisionFromProtocol( } } -function isRateLimitRule( +export function isRateLimitRule( rule: ArcjetRule, ): rule is ArcjetRateLimitRule { return rule.type === "RATE_LIMIT"; diff --git a/protocol/gen/es/decide/v1alpha1/decide_connect.d.ts b/protocol/gen/es/decide/v1alpha1/decide_connect.d.ts index 0812c343b..d50f8a2f4 100644 --- a/protocol/gen/es/decide/v1alpha1/decide_connect.d.ts +++ b/protocol/gen/es/decide/v1alpha1/decide_connect.d.ts @@ -1,5 +1,5 @@ // @generated by protoc-gen-connect-es v1.4.0 -// @generated from file decide/v1alpha1/decide.proto (package proto.decide.v1alpha1, syntax proto3) +// @generated from file proto/decide/v1alpha1/decide.proto (package proto.decide.v1alpha1, syntax proto3) /* eslint-disable */ // @ts-nocheck diff --git a/protocol/gen/es/decide/v1alpha1/decide_connect.js b/protocol/gen/es/decide/v1alpha1/decide_connect.js index de3322251..fc8f0a5b7 100644 --- a/protocol/gen/es/decide/v1alpha1/decide_connect.js +++ b/protocol/gen/es/decide/v1alpha1/decide_connect.js @@ -1,5 +1,5 @@ // @generated by protoc-gen-connect-es v1.4.0 -// @generated from file decide/v1alpha1/decide.proto (package proto.decide.v1alpha1, syntax proto3) +// @generated from file proto/decide/v1alpha1/decide.proto (package proto.decide.v1alpha1, syntax proto3) /* eslint-disable */ // @ts-nocheck diff --git a/protocol/gen/es/decide/v1alpha1/decide_pb.d.ts b/protocol/gen/es/decide/v1alpha1/decide_pb.d.ts index ece60fa15..66b85896d 100644 --- a/protocol/gen/es/decide/v1alpha1/decide_pb.d.ts +++ b/protocol/gen/es/decide/v1alpha1/decide_pb.d.ts @@ -1,5 +1,5 @@ // @generated by protoc-gen-es v1.8.0 -// @generated from file decide/v1alpha1/decide.proto (package proto.decide.v1alpha1, syntax proto3) +// @generated from file proto/decide/v1alpha1/decide.proto (package proto.decide.v1alpha1, syntax proto3) /* eslint-disable */ // @ts-nocheck @@ -1418,6 +1418,13 @@ export declare class DecideRequest extends Message { */ rules: Rule[]; + /** + * The characteristics that should be used for fingerprinting. + * + * @generated from field: repeated string characteristics = 6; + */ + characteristics: string[]; + constructor(data?: PartialMessage); static readonly runtime: typeof proto3; @@ -1510,6 +1517,13 @@ export declare class ReportRequest extends Message { */ receivedAt?: Timestamp; + /** + * The characteristics that should be used for fingerprinting. + * + * @generated from field: repeated string characteristics = 8; + */ + characteristics: string[]; + constructor(data?: PartialMessage); static readonly runtime: typeof proto3; diff --git a/protocol/gen/es/decide/v1alpha1/decide_pb.js b/protocol/gen/es/decide/v1alpha1/decide_pb.js index 9b9572fdd..724335c97 100644 --- a/protocol/gen/es/decide/v1alpha1/decide_pb.js +++ b/protocol/gen/es/decide/v1alpha1/decide_pb.js @@ -1,5 +1,5 @@ // @generated by protoc-gen-es v1.8.0 -// @generated from file decide/v1alpha1/decide.proto (package proto.decide.v1alpha1, syntax proto3) +// @generated from file proto/decide/v1alpha1/decide.proto (package proto.decide.v1alpha1, syntax proto3) /* eslint-disable */ // @ts-nocheck @@ -416,6 +416,7 @@ export const DecideRequest = /*@__PURE__*/ proto3.makeMessageType( { no: 2, name: "sdk_version", kind: "scalar", T: 9 /* ScalarType.STRING */ }, { no: 4, name: "details", kind: "message", T: RequestDetails }, { no: 5, name: "rules", kind: "message", T: Rule, repeated: true }, + { no: 6, name: "characteristics", kind: "scalar", T: 9 /* ScalarType.STRING */, repeated: true }, ], ); @@ -446,6 +447,7 @@ export const ReportRequest = /*@__PURE__*/ proto3.makeMessageType( { no: 5, name: "decision", kind: "message", T: Decision }, { no: 6, name: "rules", kind: "message", T: Rule, repeated: true }, { no: 7, name: "received_at", kind: "message", T: Timestamp }, + { no: 8, name: "characteristics", kind: "scalar", T: 9 /* ScalarType.STRING */, repeated: true }, ], ); diff --git a/protocol/index.ts b/protocol/index.ts index 356c8dbdd..4e809daa0 100644 --- a/protocol/index.ts +++ b/protocol/index.ts @@ -689,6 +689,7 @@ export interface ArcjetRateLimitRule extends ArcjetRule { type: "RATE_LIMIT"; algorithm: ArcjetRateLimitAlgorithm; + characteristics?: string[]; } export interface ArcjetTokenBucketRateLimitRule @@ -696,7 +697,6 @@ export interface ArcjetTokenBucketRateLimitRule algorithm: "TOKEN_BUCKET"; match?: string; - characteristics?: string[]; refillRate: number; interval: number; capacity: number; @@ -707,7 +707,6 @@ export interface ArcjetFixedWindowRateLimitRule algorithm: "FIXED_WINDOW"; match?: string; - characteristics?: string[]; max: number; window: number; } @@ -717,7 +716,6 @@ export interface ArcjetSlidingWindowRateLimitRule algorithm: "SLIDING_WINDOW"; match?: string; - characteristics?: string[]; max: number; interval: number; } @@ -766,4 +764,5 @@ export type ArcjetContext = { fingerprint: string; runtime: string; log: ArcjetLogger; + characteristics: string[]; }; diff --git a/protocol/test/client.test.ts b/protocol/test/client.test.ts index 5c81137d2..a366719e6 100644 --- a/protocol/test/client.test.ts +++ b/protocol/test/client.test.ts @@ -118,6 +118,7 @@ describe("createClient", () => { fingerprint, runtime: "test", log, + characteristics: [], }; const details = { ip: "172.100.1.1", @@ -174,6 +175,7 @@ describe("createClient", () => { fingerprint, runtime: "test", log, + characteristics: [], }; const details = { ip: "172.100.1.1", @@ -231,6 +233,7 @@ describe("createClient", () => { fingerprint, runtime: "test", log, + characteristics: [], }; const details = { ip: "172.100.1.1", @@ -286,6 +289,7 @@ describe("createClient", () => { fingerprint, runtime: "test", log, + characteristics: [], }; const details = { ip: "172.100.1.1", @@ -342,6 +346,7 @@ describe("createClient", () => { fingerprint, runtime: "test", log, + characteristics: [], }; const details = { ip: "172.100.1.1", @@ -403,6 +408,7 @@ describe("createClient", () => { fingerprint, runtime: "test", log, + characteristics: [], }; const details = { ip: "172.100.1.1", @@ -447,6 +453,7 @@ describe("createClient", () => { fingerprint, runtime: "test", log, + characteristics: [], }; const details = { ip: "172.100.1.1", @@ -490,6 +497,7 @@ describe("createClient", () => { fingerprint, runtime: "test", log, + characteristics: [], }; const details = { ip: "172.100.1.1", @@ -533,6 +541,7 @@ describe("createClient", () => { fingerprint, runtime: "test", log, + characteristics: [], }; const details = { ip: "172.100.1.1", @@ -579,6 +588,7 @@ describe("createClient", () => { fingerprint, runtime: "test", log, + characteristics: [], }; const details = { ip: "172.100.1.1", @@ -631,6 +641,7 @@ describe("createClient", () => { fingerprint, runtime: "test", log, + characteristics: [], }; const details = { ip: "172.100.1.1", @@ -675,6 +686,7 @@ describe("createClient", () => { fingerprint, runtime: "test", log, + characteristics: [], }; const receivedAt = Timestamp.now(); const details = { @@ -744,6 +756,7 @@ describe("createClient", () => { fingerprint, runtime: "test", log, + characteristics: [], }; const receivedAt = Timestamp.now(); const details = { @@ -812,6 +825,7 @@ describe("createClient", () => { fingerprint, runtime: "test", log, + characteristics: [], }; const receivedAt = Timestamp.now(); const details = { @@ -887,6 +901,7 @@ describe("createClient", () => { fingerprint, runtime: "test", log, + characteristics: [], }; const receivedAt = Timestamp.now(); const details = { @@ -955,6 +970,7 @@ describe("createClient", () => { fingerprint, runtime: "test", log, + characteristics: [], }; const receivedAt = Timestamp.now(); const details = { @@ -1019,6 +1035,7 @@ describe("createClient", () => { fingerprint, runtime: "test", log, + characteristics: [], }; const receivedAt = Timestamp.now(); const details = { @@ -1109,6 +1126,7 @@ describe("createClient", () => { fingerprint, runtime: "test", log, + characteristics: [], }; const details = { ip: "172.100.1.1", @@ -1145,4 +1163,149 @@ describe("createClient", () => { expect(logSpy).toHaveBeenCalledTimes(1); }); + + test("calling `decide` will make RPC with top level characteristics included", async () => { + const key = "test-key"; + const fingerprint = + "fp_1_ac8547705f1f45c5050f1424700dfa3f6f2f681b550ca4f3c19571585aea7a2c"; + const context = { + key, + fingerprint, + runtime: "test", + log, + characteristics: ["src.ip"], + }; + const details = { + ip: "172.100.1.1", + method: "GET", + protocol: "http", + host: "example.com", + path: "/", + headers: new Headers([["User-Agent", "curl/8.1.2"]]), + extra: { + "extra-test": "extra-test-value", + }, + email: "abc@example.com", + }; + + const router = { + decide: jest.fn((args) => { + return new DecideResponse({ + decision: { + conclusion: Conclusion.ALLOW, + }, + }); + }), + }; + + const client = createClient({ + ...defaultRemoteClientOptions, + transport: createRouterTransport(({ service }) => { + service(DecideService, router); + }), + }); + const _ = await client.decide(context, details, []); + + expect(router.decide).toHaveBeenCalledTimes(1); + expect(router.decide).toHaveBeenCalledWith( + new DecideRequest({ + details: { + ...details, + headers: { "user-agent": "curl/8.1.2" }, + }, + characteristics: ["src.ip"], + rules: [], + sdkStack: SDKStack.SDK_STACK_NODEJS, + sdkVersion: "__ARCJET_SDK_VERSION__", + }), + expect.anything(), + ); + }); + + test("calling `report` will make RPC with top level characteristics included", async () => { + const key = "test-key"; + const fingerprint = + "fp_1_ac8547705f1f45c5050f1424700dfa3f6f2f681b550ca4f3c19571585aea7a2c"; + const context = { + key, + fingerprint, + runtime: "test", + log, + characteristics: ["ip.src"], + }; + const receivedAt = Timestamp.now(); + const details = { + ip: "172.100.1.1", + method: "GET", + protocol: "http", + host: "example.com", + path: "/", + headers: new Headers([["User-Agent", "curl/8.1.2"]]), + extra: { + "extra-test": "extra-test-value", + }, + email: "abc@example.com", + }; + + const [promise, resolve] = deferred(); + + const router = { + report: jest.fn((args) => { + resolve(); + return new ReportResponse({}); + }), + }; + + const client = createClient({ + ...defaultRemoteClientOptions, + transport: createRouterTransport(({ service }) => { + service(DecideService, router); + }), + }); + + const decision = new ArcjetDenyDecision({ + ttl: 0, + reason: new ArcjetTestReason(), + results: [ + new ArcjetRuleResult({ + ttl: 0, + state: "RUN", + conclusion: "DENY", + reason: new ArcjetReason(), + }), + ], + }); + client.report(context, details, decision, []); + + await promise; + + expect(router.report).toHaveBeenCalledTimes(1); + expect(router.report).toHaveBeenCalledWith( + new ReportRequest({ + sdkStack: SDKStack.SDK_STACK_NODEJS, + sdkVersion: "__ARCJET_SDK_VERSION__", + details: { + ...details, + headers: { "user-agent": "curl/8.1.2" }, + }, + decision: { + id: decision.id, + conclusion: Conclusion.DENY, + reason: new Reason(), + ruleResults: [ + new RuleResult({ + ruleId: "", + state: RuleState.RUN, + conclusion: Conclusion.DENY, + reason: new Reason(), + }), + ], + }, + rules: [], + receivedAt, + characteristics: ["ip.src"], + }), + expect.anything(), + ); + }); });