From 4df67bae7967b8481a91d234f29a32f10e1903af Mon Sep 17 00:00:00 2001 From: Blaine Bublitz Date: Tue, 13 Feb 2024 14:22:01 -0700 Subject: [PATCH 1/3] feat: Separate ArcjetRequest and ArcjetRequestDetails types to accept record of headers --- arcjet/index.ts | 34 +++- arcjet/test/index.node.test.ts | 345 ++++++++++++++++++--------------- protocol/index.ts | 5 +- 3 files changed, 220 insertions(+), 164 deletions(-) diff --git a/arcjet/index.ts b/arcjet/index.ts index 6865ac7e7..3ae57a876 100644 --- a/arcjet/index.ts +++ b/arcjet/index.ts @@ -260,7 +260,9 @@ function toString(value: unknown) { return ""; } -function extraProps(details: ArcjetRequestDetails): Record { +function extraProps( + details: ArcjetRequest, +): Record { const extra: Map = new Map(); for (const [key, value] of Object.entries(details)) { if (isUnknownRequestProperty(key)) { @@ -315,7 +317,7 @@ export function createRemoteClient( query: details.query, // TODO(#208): Re-add body // body: details.body, - extra: extraProps(details), + extra: details.extra, email: typeof details.email === "string" ? details.email : undefined, }, rules: rules.map(ArcjetRuleToProtocol), @@ -364,7 +366,7 @@ export function createRemoteClient( headers: Object.fromEntries(details.headers.entries()), // TODO(#208): Re-add body // body: details.body, - extra: extraProps(details), + extra: details.extra, email: typeof details.email === "string" ? details.email : undefined, }, decision: ArcjetDecisionToProtocol(decision), @@ -625,7 +627,17 @@ export type ExtraProps = Rules extends [] * @property ...extra - Extra data that might be useful for Arcjet. For example, requested tokens are specified as the `requested` property. */ export type ArcjetRequest = Simplify< - Partial & Props + { + [key: string]: unknown; + ip?: string; + method?: string; + protocol?: string; + host?: string; + path?: string; + headers?: Headers | Record; + cookies?: string; + query?: string; + } & Props >; function isLocalRule( @@ -1052,9 +1064,19 @@ export default function arcjet< request = {} as typeof request; } - const details = Object.freeze({ - ...request, + const details: Partial = Object.freeze({ + ip: request.ip, + method: request.method, + protocol: request.protocol, + host: request.host, + path: request.path, headers: new ArcjetHeaders(request.headers), + cookies: request.cookies, + query: request.query, + // TODO(#208): Re-add body + // body: request.body, + extra: extraProps(request), + email: typeof request.email === "string" ? request.email : undefined, }); log.time("local"); diff --git a/arcjet/test/index.node.test.ts b/arcjet/test/index.node.test.ts index e3ae8794d..5e18e3c46 100644 --- a/arcjet/test/index.node.test.ts +++ b/arcjet/test/index.node.test.ts @@ -254,7 +254,9 @@ describe("createRemoteClient", () => { host: "example.com", path: "/", headers: new Headers([["User-Agent", "curl/8.1.2"]]), - "extra-test": "extra-test-value", + extra: { + "extra-test": "extra-test-value", + }, }; const router = { @@ -279,12 +281,7 @@ describe("createRemoteClient", () => { expect(router.decide).toHaveBeenCalledWith( new DecideRequest({ details: { - ip: details.ip, - method: details.method, - protocol: details.protocol, - host: details.host, - path: details.path, - extra: { "extra-test": details["extra-test"] }, + ...details, headers: { "user-agent": "curl/8.1.2" }, }, fingerprint, @@ -312,7 +309,9 @@ describe("createRemoteClient", () => { host: "example.com", path: "/", headers: new Headers([["User-Agent", "curl/8.1.2"]]), - "extra-test": "extra-test-value", + extra: { + "extra-test": "extra-test-value", + }, }; const router = { @@ -338,12 +337,7 @@ describe("createRemoteClient", () => { expect(router.decide).toHaveBeenCalledWith( new DecideRequest({ details: { - ip: details.ip, - method: details.method, - protocol: details.protocol, - host: details.host, - path: details.path, - extra: { "extra-test": details["extra-test"] }, + ...details, headers: { "user-agent": "curl/8.1.2" }, }, fingerprint, @@ -371,7 +365,9 @@ describe("createRemoteClient", () => { host: "example.com", path: "/", headers: new Headers([["User-Agent", "curl/8.1.2"]]), - "extra-test": "extra-test-value", + extra: { + "extra-test": "extra-test-value", + }, }; const router = { @@ -395,12 +391,7 @@ describe("createRemoteClient", () => { expect(router.decide).toHaveBeenCalledWith( new DecideRequest({ details: { - ip: details.ip, - method: details.method, - protocol: details.protocol, - host: details.host, - path: details.path, - extra: { "extra-test": details["extra-test"] }, + ...details, headers: { "user-agent": "curl/8.1.2" }, }, fingerprint, @@ -428,7 +419,9 @@ describe("createRemoteClient", () => { host: "example.com", path: "/", headers: new Headers([["User-Agent", "curl/8.1.2"]]), - "extra-test": "extra-test-value", + extra: { + "extra-test": "extra-test-value", + }, email: "abc@example.com", }; @@ -453,14 +446,8 @@ describe("createRemoteClient", () => { expect(router.decide).toHaveBeenCalledWith( new DecideRequest({ details: { - ip: details.ip, - method: details.method, - protocol: details.protocol, - host: details.host, - path: details.path, - extra: { "extra-test": details["extra-test"] }, + ...details, headers: { "user-agent": "curl/8.1.2" }, - email: details.email, }, fingerprint, rules: [], @@ -487,7 +474,9 @@ describe("createRemoteClient", () => { host: "example.com", path: "/", headers: new Headers([["User-Agent", "curl/8.1.2"]]), - "extra-test": "extra-test-value", + extra: { + "extra-test": "extra-test-value", + }, email: "abc@example.com", }; @@ -517,14 +506,8 @@ describe("createRemoteClient", () => { expect(router.decide).toHaveBeenCalledWith( new DecideRequest({ details: { - ip: details.ip, - method: details.method, - protocol: details.protocol, - host: details.host, - path: details.path, - extra: { "extra-test": details["extra-test"] }, + ...details, headers: { "user-agent": "curl/8.1.2" }, - email: details.email, }, fingerprint, rules: [new Rule()], @@ -551,7 +534,9 @@ describe("createRemoteClient", () => { host: "example.com", path: "/", headers: new Headers([["User-Agent", "curl/8.1.2"]]), - "extra-test": "extra-test-value", + extra: { + "extra-test": "extra-test-value", + }, }; const router = { @@ -591,7 +576,9 @@ describe("createRemoteClient", () => { host: "example.com", path: "/", headers: new Headers([["User-Agent", "curl/8.1.2"]]), - "extra-test": "extra-test-value", + extra: { + "extra-test": "extra-test-value", + }, }; const router = { @@ -630,7 +617,9 @@ describe("createRemoteClient", () => { host: "example.com", path: "/", headers: new Headers([["User-Agent", "curl/8.1.2"]]), - "extra-test": "extra-test-value", + extra: { + "extra-test": "extra-test-value", + }, }; const router = { @@ -669,7 +658,9 @@ describe("createRemoteClient", () => { host: "example.com", path: "/", headers: new Headers([["User-Agent", "curl/8.1.2"]]), - "extra-test": "extra-test-value", + extra: { + "extra-test": "extra-test-value", + }, }; const router = { @@ -711,7 +702,9 @@ describe("createRemoteClient", () => { host: "example.com", path: "/", headers: new Headers([["User-Agent", "curl/8.1.2"]]), - "extra-test": "extra-test-value", + extra: { + "extra-test": "extra-test-value", + }, }; const router = { @@ -759,7 +752,9 @@ describe("createRemoteClient", () => { host: "example.com", path: "/", headers: new Headers([["User-Agent", "curl/8.1.2"]]), - "extra-test": "extra-test-value", + extra: { + "extra-test": "extra-test-value", + }, }; const router = { @@ -800,7 +795,9 @@ describe("createRemoteClient", () => { host: "example.com", path: "/", headers: new Headers([["User-Agent", "curl/8.1.2"]]), - "extra-test": "extra-test-value", + extra: { + "extra-test": "extra-test-value", + }, email: "test@example.com", }; @@ -834,14 +831,8 @@ describe("createRemoteClient", () => { sdkVersion: "__ARCJET_SDK_VERSION__", fingerprint, details: { - ip: details.ip, - method: details.method, - protocol: details.protocol, - host: details.host, - path: details.path, - extra: { "extra-test": details["extra-test"] }, + ...details, headers: { "user-agent": "curl/8.1.2" }, - email: details.email, }, decision: { id: decision.id, @@ -872,7 +863,9 @@ describe("createRemoteClient", () => { host: "example.com", path: "/", headers: new Headers([["User-Agent", "curl/8.1.2"]]), - "extra-test": "extra-test-value", + extra: { + "extra-test": "extra-test-value", + }, }; const [promise, resolve] = deferred(); @@ -905,12 +898,7 @@ describe("createRemoteClient", () => { sdkStack: SDKStack.SDK_STACK_NODEJS, sdkVersion: "__ARCJET_SDK_VERSION__", details: { - ip: details.ip, - method: details.method, - protocol: details.protocol, - host: details.host, - path: details.path, - extra: { "extra-test": details["extra-test"] }, + ...details, headers: { "user-agent": "curl/8.1.2" }, }, decision: { @@ -942,7 +930,9 @@ describe("createRemoteClient", () => { host: "example.com", path: "/", headers: new Headers([["User-Agent", "curl/8.1.2"]]), - "extra-test": "extra-test-value", + extra: { + "extra-test": "extra-test-value", + }, }; const [promise, resolve] = deferred(); @@ -975,12 +965,7 @@ describe("createRemoteClient", () => { sdkVersion: "__ARCJET_SDK_VERSION__", fingerprint, details: { - ip: details.ip, - method: details.method, - protocol: details.protocol, - host: details.host, - path: details.path, - extra: { "extra-test": details["extra-test"] }, + ...details, headers: { "user-agent": "curl/8.1.2" }, }, decision: { @@ -1019,7 +1004,9 @@ describe("createRemoteClient", () => { host: "example.com", path: "/", headers: new Headers([["User-Agent", "curl/8.1.2"]]), - "extra-test": "extra-test-value", + extra: { + "extra-test": "extra-test-value", + }, }; const [promise, resolve] = deferred(); @@ -1052,12 +1039,7 @@ describe("createRemoteClient", () => { sdkVersion: "__ARCJET_SDK_VERSION__", fingerprint, details: { - ip: details.ip, - method: details.method, - protocol: details.protocol, - host: details.host, - path: details.path, - extra: { "extra-test": details["extra-test"] }, + ...details, headers: { "user-agent": "curl/8.1.2" }, }, decision: { @@ -1089,7 +1071,9 @@ describe("createRemoteClient", () => { host: "example.com", path: "/", headers: new Headers([["User-Agent", "curl/8.1.2"]]), - "extra-test": "extra-test-value", + extra: { + "extra-test": "extra-test-value", + }, }; const [promise, resolve] = deferred(); @@ -1118,12 +1102,7 @@ describe("createRemoteClient", () => { sdkVersion: "__ARCJET_SDK_VERSION__", fingerprint, details: { - ip: details.ip, - method: details.method, - protocol: details.protocol, - host: details.host, - path: details.path, - extra: { "extra-test": details["extra-test"] }, + ...details, headers: { "user-agent": "curl/8.1.2" }, }, decision: { @@ -1155,7 +1134,9 @@ describe("createRemoteClient", () => { host: "example.com", path: "/", headers: new Headers([["User-Agent", "curl/8.1.2"]]), - "extra-test": "extra-test-value", + extra: { + "extra-test": "extra-test-value", + }, email: "abc@example.com", }; @@ -1202,14 +1183,8 @@ describe("createRemoteClient", () => { sdkVersion: "__ARCJET_SDK_VERSION__", fingerprint, details: { - ip: details.ip, - method: details.method, - protocol: details.protocol, - host: details.host, - path: details.path, - extra: { "extra-test": details["extra-test"] }, + ...details, headers: { "user-agent": "curl/8.1.2" }, - email: details.email, }, decision: { id: decision.id, @@ -1247,7 +1222,9 @@ describe("createRemoteClient", () => { host: "example.com", path: "/", headers: new Headers([["User-Agent", "curl/8.1.2"]]), - "extra-test": "extra-test-value", + extra: { + "extra-test": "extra-test-value", + }, }; const [promise, resolve] = deferred(); @@ -1662,7 +1639,9 @@ describe("Primitive > detectBot", () => { ]), cookies: "", query: "", - "extra-test": "extra-test-value", + extra: { + "extra-test": "extra-test-value", + }, }; const [rule] = detectBot(options); @@ -1714,7 +1693,9 @@ describe("Primitive > detectBot", () => { ]), cookies: "", query: "", - "extra-test": "extra-test-value", + extra: { + "extra-test": "extra-test-value", + }, }; const [rule] = detectBot(options); @@ -1766,7 +1747,9 @@ describe("Primitive > detectBot", () => { ]), cookies: "", query: "", - "extra-test": "extra-test-value", + extra: { + "extra-test": "extra-test-value", + }, }; const [rule] = detectBot(options); @@ -1805,7 +1788,9 @@ describe("Primitive > detectBot", () => { ]), cookies: "", query: "", - "extra-test": "extra-test-value", + extra: { + "extra-test": "extra-test-value", + }, }; const [rule] = detectBot({ @@ -1861,7 +1846,9 @@ describe("Primitive > detectBot", () => { headers: new Headers([["User-Agent", "curl/8.1.2"]]), cookies: "", query: "", - "extra-test": "extra-test-value", + extra: { + "extra-test": "extra-test-value", + }, }; const [rule] = detectBot(options); @@ -1912,7 +1899,9 @@ describe("Primitive > detectBot", () => { ]), cookies: "", query: "", - "extra-test": "extra-test-value", + extra: { + "extra-test": "extra-test-value", + }, }; const [rule] = detectBot(options); @@ -1952,7 +1941,9 @@ describe("Primitive > detectBot", () => { headers: new Headers([["User-Agent", "curl/8.1.2"]]), cookies: "", query: "", - "extra-test": "extra-test-value", + extra: { + "extra-test": "extra-test-value", + }, }; const [rule] = detectBot(options); @@ -2802,6 +2793,7 @@ describe("Primitive > validateEmail", () => { cookies: "", query: "", email: "foobarbaz@example.com", + extra: {}, }; const [rule] = validateEmail(); @@ -2833,6 +2825,7 @@ describe("Primitive > validateEmail", () => { cookies: "", query: "", email: "foobarbaz", + extra: {}, }; const [rule] = validateEmail(); @@ -2864,6 +2857,7 @@ describe("Primitive > validateEmail", () => { cookies: "", query: "", email: "foobarbaz@localhost", + extra: {}, }; const [rule] = validateEmail(); @@ -2895,6 +2889,7 @@ describe("Primitive > validateEmail", () => { cookies: "", query: "", email: "foobarbaz@localhost", + extra: {}, }; const [rule] = validateEmail({ @@ -2928,6 +2923,7 @@ describe("Primitive > validateEmail", () => { cookies: "", query: "", email: "@example.com", + extra: {}, }; const [rule] = validateEmail(); @@ -2959,6 +2955,7 @@ describe("Primitive > validateEmail", () => { cookies: "", query: "", email: "foobarbaz@[127.0.0.1]", + extra: {}, }; const [rule] = validateEmail(); @@ -2990,6 +2987,7 @@ describe("Primitive > validateEmail", () => { cookies: "", query: "", email: "foobarbaz@localhost", + extra: {}, }; const [rule] = validateEmail({ @@ -3023,6 +3021,7 @@ describe("Primitive > validateEmail", () => { cookies: "", query: "", email: "foobarbaz@[127.0.0.1]", + extra: {}, }; const [rule] = validateEmail({ @@ -3343,7 +3342,7 @@ describe("SDK", () => { report: jest.fn(), }; - const details = { + const request = { ip: "172.100.1.1", method: "GET", protocol: "http", @@ -3361,7 +3360,7 @@ describe("SDK", () => { client, }); - const decision = await aj.protect(details); + const decision = await aj.protect(request); expect(decision.conclusion).toEqual("DENY"); expect(allowed.validate).toHaveBeenCalledTimes(1); @@ -3459,7 +3458,7 @@ describe("SDK", () => { report: jest.fn(), }; - const details = { + const request = { ip: "172.100.1.1", method: "GET", protocol: "http", @@ -3477,7 +3476,7 @@ describe("SDK", () => { client, }); - const decision = await aj.protect(details); + const decision = await aj.protect(request); expect(decision.conclusion).toEqual("DENY"); expect(denied.validate).toHaveBeenCalledTimes(1); @@ -3498,13 +3497,7 @@ describe("SDK", () => { report: jest.fn(), }; - const key = "test-key"; - const context = { - key, - fingerprint: - "fp_1_ac8547705f1f45c5050f1424700dfa3f6f2f681b550ca4f3c19571585aea7a2c", - }; - const details = { + const request = { ip: "172.100.1.1", method: "GET", protocol: "http", @@ -3516,12 +3509,12 @@ describe("SDK", () => { const allowed = testRuleLocalAllowed(); const aj = arcjet({ - key, + key: "test-key", rules: [[allowed]], client, }); - const _ = await aj.protect(details); + const _ = await aj.protect(request); expect(client.report).toHaveBeenCalledTimes(0); expect(client.decide).toHaveBeenCalledTimes(1); // TODO: Validate correct `ruleResults` are sent with `decide` when available @@ -3545,7 +3538,7 @@ describe("SDK", () => { fingerprint: "fp_1_ac8547705f1f45c5050f1424700dfa3f6f2f681b550ca4f3c19571585aea7a2c", }; - const details = { + const request = { ip: "172.100.1.1", method: "GET", protocol: "http", @@ -3562,11 +3555,21 @@ describe("SDK", () => { client, }); - const decision = await aj.protect(details); + const decision = await aj.protect(request); expect(client.decide).toHaveBeenCalledTimes(1); expect(client.decide).toHaveBeenCalledWith( expect.objectContaining(context), - expect.objectContaining(details), + expect.objectContaining({ + ip: request.ip, + method: request.method, + protocol: request.protocol, + host: request.host, + path: request.path, + headers: request.headers, + extra: { + "extra-test": "extra-test-value", + }, + }), [rule], ); }); @@ -3589,7 +3592,7 @@ describe("SDK", () => { fingerprint: "fp_1_ac8547705f1f45c5050f1424700dfa3f6f2f681b550ca4f3c19571585aea7a2c", }; - const details = { + const request = { ip: "172.100.1.1", method: "GET", protocol: "http", @@ -3606,11 +3609,21 @@ describe("SDK", () => { client, }); - const _ = await aj.protect(details); + const _ = await aj.protect(request); expect(client.report).toHaveBeenCalledTimes(1); expect(client.report).toHaveBeenCalledWith( expect.objectContaining(context), - expect.objectContaining(details), + expect.objectContaining({ + ip: request.ip, + method: request.method, + protocol: request.protocol, + host: request.host, + path: request.path, + headers: request.headers, + extra: { + "extra-test": "extra-test-value", + }, + }), expect.objectContaining({ conclusion: "DENY", }), @@ -3630,8 +3643,7 @@ describe("SDK", () => { report: jest.fn(), }; - const key = "test-key"; - const details = { + const request = { ip: "172.100.1.1", method: "GET", protocol: "http", @@ -3643,12 +3655,12 @@ describe("SDK", () => { const denied = testRuleLocalDenied(); const aj = arcjet({ - key, + key: "test-key", rules: [[denied]], client, }); - const _ = await aj.protect(details); + const _ = await aj.protect(request); expect(client.decide).toHaveBeenCalledTimes(0); }); @@ -3670,7 +3682,7 @@ describe("SDK", () => { fingerprint: "fp_1_ac8547705f1f45c5050f1424700dfa3f6f2f681b550ca4f3c19571585aea7a2c", }; - const details = { + const request = { ip: "172.100.1.1", method: "GET", protocol: "http", @@ -3686,13 +3698,23 @@ describe("SDK", () => { client, }); - const _ = await aj.protect(details); + const _ = await aj.protect(request); expect(client.report).toHaveBeenCalledTimes(0); expect(client.decide).toHaveBeenCalledTimes(1); expect(client.decide).toHaveBeenCalledWith( expect.objectContaining(context), - expect.objectContaining(details), + expect.objectContaining({ + ip: request.ip, + method: request.method, + protocol: request.protocol, + host: request.host, + path: request.path, + headers: request.headers, + extra: { + "extra-test": "extra-test-value", + }, + }), [], ); }); @@ -3709,8 +3731,7 @@ describe("SDK", () => { report: jest.fn(), }; - const key = "test-key"; - const details = { + const request = { ip: "172.100.1.1", method: "GET", protocol: "http", @@ -3721,12 +3742,12 @@ describe("SDK", () => { }; const aj = arcjet({ - key, + key: "test-key", rules: [], client, }); - const decision = await aj.protect(details); + const decision = await aj.protect(request); expect(decision.isErrored()).toBe(false); @@ -3735,7 +3756,7 @@ describe("SDK", () => { expect(decision.conclusion).toEqual("DENY"); - const decision2 = await aj.protect(details); + const decision2 = await aj.protect(request); expect(decision2.isErrored()).toBe(false); expect(client.decide).toHaveBeenCalledTimes(1); @@ -3777,13 +3798,7 @@ describe("SDK", () => { report: jest.fn(), }; - const key = "test-key"; - const context = { - key, - fingerprint: - "fp_1_ac8547705f1f45c5050f1424700dfa3f6f2f681b550ca4f3c19571585aea7a2c", - }; - const details = { + const request = { ip: "172.100.1.1", method: "GET", protocol: "http", @@ -3794,12 +3809,12 @@ describe("SDK", () => { }; const aj = arcjet({ - key, + key: "test-key", rules: [[testRuleLocalThrow()]], client, }); - const _ = await aj.protect(details); + const _ = await aj.protect(request); expect(client.report).toHaveBeenCalledTimes(0); expect(client.decide).toHaveBeenCalledTimes(1); @@ -3818,8 +3833,7 @@ describe("SDK", () => { report: jest.fn(), }; - const key = "test-key"; - const details = { + const request = { ip: "172.100.1.1", method: "GET", protocol: "http", @@ -3837,7 +3851,7 @@ describe("SDK", () => { type: "TEST_RULE_LOCAL_THROW_STRING", priority: 1, validate: jest.fn(), - async protect(context, req) { + async protect(context, details) { errorLogSpy = jest.spyOn(context.log, "error"); throw "Local rule protect failed"; }, @@ -3845,12 +3859,12 @@ describe("SDK", () => { } const aj = arcjet({ - key, + key: "test-key", rules: [[testRuleLocalThrowString()]], client, }); - const _ = await aj.protect(details); + const _ = await aj.protect(request); expect(errorLogSpy).toHaveBeenCalledTimes(1); expect(errorLogSpy).toHaveBeenCalledWith( @@ -3872,8 +3886,7 @@ describe("SDK", () => { report: jest.fn(), }; - const key = "test-key"; - const details = { + const request = { ip: "172.100.1.1", method: "GET", protocol: "http", @@ -3891,7 +3904,7 @@ describe("SDK", () => { type: "TEST_RULE_LOCAL_THROW_NULL", priority: 1, validate: jest.fn(), - async protect(context, req) { + async protect(context, details) { errorLogSpy = jest.spyOn(context.log, "error"); throw null; }, @@ -3899,12 +3912,12 @@ describe("SDK", () => { } const aj = arcjet({ - key, + key: "test-key", rules: [[testRuleLocalThrowNull()]], client, }); - const _ = await aj.protect(details); + const _ = await aj.protect(request); expect(errorLogSpy).toHaveBeenCalledTimes(1); expect(errorLogSpy).toHaveBeenCalledWith( @@ -3926,8 +3939,7 @@ describe("SDK", () => { report: jest.fn(), }; - const key = "test-key"; - const details = { + const request = { ip: "172.100.1.1", method: "GET", protocol: "http", @@ -3938,19 +3950,19 @@ describe("SDK", () => { }; const aj = arcjet({ - key, + key: "test-key", rules: [[testRuleLocalDryRun()]], client, }); - const decision = await aj.protect(details); + const decision = await aj.protect(request); expect(decision.isDenied()).toBe(false); expect(client.decide).toBeCalledTimes(1); expect(client.report).toBeCalledTimes(1); - const decision2 = await aj.protect(details); + const decision2 = await aj.protect(request); expect(decision2.isDenied()).toBe(false); @@ -3976,7 +3988,7 @@ describe("SDK", () => { fingerprint: "fp_1_ac8547705f1f45c5050f1424700dfa3f6f2f681b550ca4f3c19571585aea7a2c", }; - const details = { + const request = { ip: "172.100.1.1", method: "GET", protocol: "http", @@ -3994,14 +4006,24 @@ describe("SDK", () => { client, }); - const decision = await aj.protect(details); + const decision = await aj.protect(request); expect(decision.isErrored()).toBe(false); expect(client.decide).toHaveBeenCalledTimes(1); expect(client.decide).toHaveBeenCalledWith( expect.objectContaining(context), - expect.objectContaining(details), + expect.objectContaining({ + ip: request.ip, + method: request.method, + protocol: request.protocol, + host: request.host, + path: request.path, + headers: request.headers, + extra: { + "extra-test": "extra-test-value", + }, + }), [rule], ); }); @@ -4020,7 +4042,7 @@ describe("SDK", () => { fingerprint: "fp_1_ac8547705f1f45c5050f1424700dfa3f6f2f681b550ca4f3c19571585aea7a2c", }; - const details = { + const request = { ip: "172.100.1.1", method: "GET", protocol: "http", @@ -4036,7 +4058,7 @@ describe("SDK", () => { client, }); - const decision = await aj.protect(details); + const decision = await aj.protect(request); expect(decision.isErrored()).toBe(true); @@ -4044,7 +4066,18 @@ describe("SDK", () => { expect(client.report).toHaveBeenCalledTimes(1); expect(client.report).toHaveBeenCalledWith( expect.objectContaining(context), - expect.objectContaining(details), + expect.objectContaining({ + ip: request.ip, + method: request.method, + protocol: request.protocol, + host: request.host, + path: request.path, + headers: request.headers, + extra: { + "extra-test": "extra-test-value", + + } + }), expect.objectContaining({ conclusion: "ERROR", }), diff --git a/protocol/index.ts b/protocol/index.ts index cfdbc79dc..7c799935a 100644 --- a/protocol/index.ts +++ b/protocol/index.ts @@ -371,16 +371,17 @@ export class ArcjetErrorDecision extends ArcjetDecision { } export interface ArcjetRequestDetails { - [key: string]: unknown; ip: string; method: string; protocol: string; host: string; path: string; - // TODO(#215): Allow `Record` and `Record`? headers: Headers; cookies: string; query: string; + extra: { [key: string]: string }; + // TODO: Consider moving email to `extra` map + email?: string; } export type ArcjetRule = { From c86cd83957e153476893df0728e1fbbf3787f7a3 Mon Sep 17 00:00:00 2001 From: Blaine Bublitz Date: Tue, 20 Feb 2024 13:24:35 -0700 Subject: [PATCH 2/3] add tests --- arcjet/index.ts | 2 +- arcjet/test/index.node.test.ts | 197 ++++++++++++++++++++++++++++++--- 2 files changed, 183 insertions(+), 16 deletions(-) diff --git a/arcjet/index.ts b/arcjet/index.ts index 3ae57a876..f2f7425de 100644 --- a/arcjet/index.ts +++ b/arcjet/index.ts @@ -257,7 +257,7 @@ function toString(value: unknown) { return value ? "true" : "false"; } - return ""; + return ""; } function extraProps( diff --git a/arcjet/test/index.node.test.ts b/arcjet/test/index.node.test.ts index 5e18e3c46..b1865c592 100644 --- a/arcjet/test/index.node.test.ts +++ b/arcjet/test/index.node.test.ts @@ -3369,7 +3369,7 @@ describe("SDK", () => { expect(denied.protect).toHaveBeenCalledTimes(1); }); - test("works with an empty details object", async () => { + test("works with an empty request object", async () => { const client = { decide: jest.fn(async () => { return new ArcjetAllowDecision({ @@ -3381,7 +3381,7 @@ describe("SDK", () => { report: jest.fn(), }; - const details = {}; + const request = {}; const aj = arcjet({ key: "test-key", @@ -3389,11 +3389,11 @@ describe("SDK", () => { client, }); - const decision = await aj.protect(details); + const decision = await aj.protect(request); expect(decision.conclusion).toEqual("ALLOW"); }); - test("does not crash with no details object", async () => { + test("does not crash with no request object", async () => { const client = { decide: jest.fn(async () => { return new ArcjetAllowDecision({ @@ -3428,7 +3428,7 @@ describe("SDK", () => { report: jest.fn(), }; - const details = {}; + const request = {}; const rules: ArcjetRule[][] = []; // We only iterate 4 times because `testRuleMultiple` generates 3 rules @@ -3442,7 +3442,7 @@ describe("SDK", () => { client, }); - const decision = await aj.protect(details); + const decision = await aj.protect(request); expect(decision.conclusion).toEqual("ERROR"); }); @@ -3485,6 +3485,174 @@ describe("SDK", () => { expect(allowed.protect).toHaveBeenCalledTimes(0); }); + test("accepts plain object of headers", async () => { + const client = { + decide: jest.fn(async () => { + return new ArcjetAllowDecision({ + ttl: 0, + reason: new ArcjetTestReason(), + results: [], + }); + }), + report: jest.fn(), + }; + + const key = "test-key"; + const context = { + key, + fingerprint: + "fp_1_ac8547705f1f45c5050f1424700dfa3f6f2f681b550ca4f3c19571585aea7a2c", + }; + const request = { + ip: "172.100.1.1", + method: "GET", + protocol: "http", + host: "example.com", + path: "/", + headers: { "User-Agent": "curl/8.1.2" }, + "extra-test": "extra-test-value", + }; + + const aj = arcjet({ + key: "test-key", + rules: [], + client, + }); + + const decision = await aj.protect(request); + expect(client.decide).toHaveBeenCalledTimes(1); + expect(client.decide).toHaveBeenCalledWith( + expect.objectContaining(context), + expect.objectContaining({ + ip: request.ip, + method: request.method, + protocol: request.protocol, + host: request.host, + path: request.path, + headers: new Headers(Object.entries(request.headers)), + extra: { + "extra-test": "extra-test-value", + }, + }), + [], + ); + }); + + test("accepts plain object of `raw` headers", async () => { + const client = { + decide: jest.fn(async () => { + return new ArcjetAllowDecision({ + ttl: 0, + reason: new ArcjetTestReason(), + results: [], + }); + }), + report: jest.fn(), + }; + + const key = "test-key"; + const context = { + key, + fingerprint: + "fp_1_ac8547705f1f45c5050f1424700dfa3f6f2f681b550ca4f3c19571585aea7a2c", + }; + const request = { + ip: "172.100.1.1", + method: "GET", + protocol: "http", + host: "example.com", + path: "/", + headers: { "User-Agent": ["curl/8.1.2", "something"] }, + "extra-test": "extra-test-value", + }; + + const aj = arcjet({ + key: "test-key", + rules: [], + client, + }); + + const decision = await aj.protect(request); + expect(client.decide).toHaveBeenCalledTimes(1); + expect(client.decide).toHaveBeenCalledWith( + expect.objectContaining(context), + expect.objectContaining({ + ip: request.ip, + method: request.method, + protocol: request.protocol, + host: request.host, + path: request.path, + headers: new Headers([ + ["user-agent", "curl/8.1.2"], + ["user-agent", "something"], + ]), + extra: { + "extra-test": "extra-test-value", + }, + }), + [], + ); + }); + + test("converts extra keys with non-string values to string values", async () => { + const client = { + decide: jest.fn(async () => { + return new ArcjetAllowDecision({ + ttl: 0, + reason: new ArcjetTestReason(), + results: [], + }); + }), + report: jest.fn(), + }; + + const key = "test-key"; + const context = { + key, + fingerprint: + "fp_1_ac8547705f1f45c5050f1424700dfa3f6f2f681b550ca4f3c19571585aea7a2c", + }; + const request = { + ip: "172.100.1.1", + method: "GET", + protocol: "http", + host: "example.com", + path: "/", + headers: { "User-Agent": "curl/8.1.2" }, + "extra-number": 123, + "extra-false": false, + "extra-true": true, + "extra-unsupported": new Date(), + }; + + const aj = arcjet({ + key: "test-key", + rules: [], + client, + }); + + const decision = await aj.protect(request); + expect(client.decide).toHaveBeenCalledTimes(1); + expect(client.decide).toHaveBeenCalledWith( + expect.objectContaining(context), + expect.objectContaining({ + ip: request.ip, + method: request.method, + protocol: request.protocol, + host: request.host, + path: request.path, + headers: new Headers(Object.entries(request.headers)), + extra: { + "extra-number": "123", + "extra-false": "false", + "extra-true": "true", + "extra-unsupported": "", + }, + }), + [], + ); + }); + test("does not call `client.report()` if the local decision is ALLOW", async () => { const client = { decide: jest.fn(async () => { @@ -4068,15 +4236,14 @@ describe("SDK", () => { expect.objectContaining(context), expect.objectContaining({ ip: request.ip, - method: request.method, - protocol: request.protocol, - host: request.host, - path: request.path, - headers: request.headers, - extra: { - "extra-test": "extra-test-value", - - } + method: request.method, + protocol: request.protocol, + host: request.host, + path: request.path, + headers: request.headers, + extra: { + "extra-test": "extra-test-value", + }, }), expect.objectContaining({ conclusion: "ERROR", From 82498766cecdc8697f4187bc87cfbad467295656 Mon Sep 17 00:00:00 2001 From: Blaine Bublitz Date: Tue, 20 Feb 2024 13:26:29 -0700 Subject: [PATCH 3/3] fmt --- arcjet/index.ts | 44 ++++++++++++++++++---------------- arcjet/test/index.node.test.ts | 14 +++++------ 2 files changed, 29 insertions(+), 29 deletions(-) diff --git a/arcjet/index.ts b/arcjet/index.ts index f2f7425de..36e0b8f55 100644 --- a/arcjet/index.ts +++ b/arcjet/index.ts @@ -168,13 +168,14 @@ type LiteralCheck< | boolean | symbol | bigint, -> = IsNever extends false // Must be wider than `never` - ? [T] extends [LiteralType] // Must be narrower than `LiteralType` - ? [LiteralType] extends [T] // Cannot be wider than `LiteralType` - ? false - : true - : false - : false; +> = + IsNever extends false // Must be wider than `never` + ? [T] extends [LiteralType] // Must be narrower than `LiteralType` + ? [LiteralType] extends [T] // Cannot be wider than `LiteralType` + ? false + : true + : false + : false; type IsStringLiteral = LiteralCheck; export interface RemoteClient { @@ -586,20 +587,21 @@ export type Product = ArcjetRule[]; // Note: If a user doesn't provide the object literal to our primitives // directly, we fallback to no required props. They can opt-in by adding the // `as const` suffix to the characteristics array. -type PropsForCharacteristic = IsStringLiteral extends true - ? T extends - | "ip.src" - | "http.host" - | "http.method" - | "http.request.uri.path" - | `http.request.headers["${string}"]` - | `http.request.cookie["${string}"]` - | `http.request.uri.args["${string}"]` - ? {} - : T extends string - ? Record - : never - : {}; +type PropsForCharacteristic = + IsStringLiteral extends true + ? T extends + | "ip.src" + | "http.host" + | "http.method" + | "http.request.uri.path" + | `http.request.headers["${string}"]` + | `http.request.cookie["${string}"]` + | `http.request.uri.args["${string}"]` + ? {} + : T extends string + ? Record + : never + : {}; // Rules can specify they require specific props on an ArcjetRequest type PropsForRule = R extends ArcjetRule ? Props : {}; // We theoretically support an arbitrary amount of rule flattening, diff --git a/arcjet/test/index.node.test.ts b/arcjet/test/index.node.test.ts index b1865c592..8a8ba2de5 100644 --- a/arcjet/test/index.node.test.ts +++ b/arcjet/test/index.node.test.ts @@ -83,17 +83,15 @@ import arcjet, { // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE // SOFTWARE. -type IsEqual = (() => G extends A ? 1 : 2) extends () => G extends B - ? 1 - : 2 - ? true - : false; +type IsEqual = + (() => G extends A ? 1 : 2) extends () => G extends B ? 1 : 2 + ? true + : false; // Type testing utilities type Assert = T; -type Props

= P extends Primitive - ? Props - : never; +type Props

= + P extends Primitive ? Props : never; type RequiredProps

= IsEqual, E>; // Instances of Headers contain symbols that may be different depending