From 41b5ee8fa7be242cc41d66b42a11cd2144b6c1a3 Mon Sep 17 00:00:00 2001 From: Blaine Bublitz Date: Tue, 11 Jun 2024 15:46:32 -0400 Subject: [PATCH 1/3] chore!: Move client into protocol and rename builders in adapters --- arcjet-bun/index.ts | 28 +- arcjet-bun/package.json | 1 + arcjet-next/index.ts | 54 +- arcjet-next/package.json | 1 + arcjet-node/index.ts | 28 +- arcjet-node/package.json | 1 + arcjet-sveltekit/index.ts | 19 +- arcjet-sveltekit/package.json | 1 + arcjet/index.ts | 162 +---- arcjet/test/index.node.test.ts | 1110 ------------------------------- protocol/.gitignore | 2 + protocol/client.ts | 182 +++++ protocol/proto.ts | 10 - protocol/test/client.test.ts | 1141 ++++++++++++++++++++++++++++++++ 14 files changed, 1394 insertions(+), 1346 deletions(-) create mode 100644 protocol/client.ts delete mode 100644 protocol/proto.ts create mode 100644 protocol/test/client.test.ts diff --git a/arcjet-bun/index.ts b/arcjet-bun/index.ts index df6e2a17e..1bfad933f 100644 --- a/arcjet-bun/index.ts +++ b/arcjet-bun/index.ts @@ -7,9 +7,6 @@ import core, { Product, ArcjetRequest, ExtraProps, - RemoteClient, - RemoteClientOptions, - createRemoteClient, Arcjet, } from "arcjet"; import findIP from "@arcjet/ip"; @@ -24,6 +21,7 @@ import { platform, } from "@arcjet/env"; import { Logger } from "@arcjet/logger"; +import { createClient } from "@arcjet/protocol/client.js"; // Re-export all named exports from the generic SDK export * from "arcjet"; @@ -65,9 +63,12 @@ type PlainObject = { [key: string]: unknown; }; -export function createBunRemoteClient( - options?: Partial, -): RemoteClient { +export type RemoteClientOptions = { + baseUrl?: string; + timeout?: number; +}; + +export function createRemoteClient(options?: RemoteClientOptions) { // The base URL for the Arcjet API. Will default to the standard production // API unless environment variable `ARCJET_BASE_URL` is set. const url = options?.baseUrl ?? baseUrl(env); @@ -77,18 +78,15 @@ export function createBunRemoteClient( const timeout = options?.timeout ?? (isProduction(env) ? 500 : 1000); // Transport is the HTTP client that the client uses to make requests. - const transport = - options?.transport ?? - createConnectTransport({ - baseUrl: url, - httpVersion: "1.1", - }); + const transport = createConnectTransport({ + baseUrl: url, + httpVersion: "1.1", + }); - // TODO(#223): Create separate options type to exclude these const sdkStack = "BUN"; const sdkVersion = "__ARCJET_SDK_VERSION__"; - return createRemoteClient({ + return createClient({ transport, baseUrl: url, timeout, @@ -255,7 +253,7 @@ function withClient( export default function arcjet( options: ArcjetOptions, ): ArcjetBun>> { - const client = options.client ?? createBunRemoteClient(); + const client = options.client ?? createRemoteClient(); const log = options.log ? options.log diff --git a/arcjet-bun/package.json b/arcjet-bun/package.json index 2364cb09d..5ccafcfb1 100644 --- a/arcjet-bun/package.json +++ b/arcjet-bun/package.json @@ -42,6 +42,7 @@ "@arcjet/headers": "1.0.0-alpha.14", "@arcjet/ip": "1.0.0-alpha.14", "@arcjet/logger": "1.0.0-alpha.14", + "@arcjet/protocol": "1.0.0-alpha.14", "@connectrpc/connect-node": "1.4.0", "arcjet": "1.0.0-alpha.14" }, diff --git a/arcjet-next/index.ts b/arcjet-next/index.ts index 49155da96..31344bcc5 100644 --- a/arcjet-next/index.ts +++ b/arcjet-next/index.ts @@ -14,9 +14,6 @@ import arcjet, { Product, ArcjetRequest, ExtraProps, - RemoteClient, - RemoteClientOptions, - createRemoteClient, Arcjet, } from "arcjet"; import findIP from "@arcjet/ip"; @@ -29,6 +26,7 @@ import { platform, } from "@arcjet/env"; import { Logger } from "@arcjet/logger"; +import { createClient } from "@arcjet/protocol/client.js"; // Re-export all named exports from the generic SDK export * from "arcjet"; @@ -70,9 +68,12 @@ type PlainObject = { [key: string]: unknown; }; -export function createNextRemoteClient( - options?: Partial, -): RemoteClient { +export type RemoteClientOptions = { + baseUrl?: string; + timeout?: number; +}; + +export function createRemoteClient(options?: RemoteClientOptions) { // The base URL for the Arcjet API. Will default to the standard production // API unless environment variable `ARCJET_BASE_URL` is set. const url = options?.baseUrl ?? baseUrl(process.env); @@ -84,30 +85,27 @@ export function createNextRemoteClient( // Transport is the HTTP client that the client uses to make requests. // The Connect Node client doesn't work on edge runtimes: https://github.com/bufbuild/connect-es/pull/589 // so set the transport using connect-web. The interceptor is required for it work in the edge runtime. - const transport = - options?.transport ?? - createConnectTransport({ - baseUrl: url, - interceptors: [ - /** - * Ensures redirects are followed to properly support the Next.js/Vercel Edge - * Runtime. - * @see - * https://github.com/connectrpc/connect-es/issues/749#issuecomment-1693507516 - */ - (next) => (req) => { - req.init.redirect = "follow"; - return next(req); - }, - ], - fetch, - }); - - // TODO(#223): Create separate options type to exclude these + const transport = createConnectTransport({ + baseUrl: url, + interceptors: [ + /** + * Ensures redirects are followed to properly support the Next.js/Vercel Edge + * Runtime. + * @see + * https://github.com/connectrpc/connect-es/issues/749#issuecomment-1693507516 + */ + (next) => (req) => { + req.init.redirect = "follow"; + return next(req); + }, + ], + fetch, + }); + const sdkStack = "NEXTJS"; const sdkVersion = "__ARCJET_SDK_VERSION__"; - return createRemoteClient({ + return createClient({ transport, baseUrl: url, timeout, @@ -337,7 +335,7 @@ function withClient( export default function arcjetNext( options: ArcjetOptions, ): ArcjetNext>> { - const client = options.client ?? createNextRemoteClient(); + const client = options.client ?? createRemoteClient(); const log = options.log ? options.log diff --git a/arcjet-next/package.json b/arcjet-next/package.json index 366bc1219..caaf6b08e 100644 --- a/arcjet-next/package.json +++ b/arcjet-next/package.json @@ -44,6 +44,7 @@ "@arcjet/headers": "1.0.0-alpha.14", "@arcjet/ip": "1.0.0-alpha.14", "@arcjet/logger": "1.0.0-alpha.14", + "@arcjet/protocol": "1.0.0-alpha.14", "@connectrpc/connect-web": "1.4.0", "arcjet": "1.0.0-alpha.14" }, diff --git a/arcjet-node/index.ts b/arcjet-node/index.ts index 3efaf7037..fe002f249 100644 --- a/arcjet-node/index.ts +++ b/arcjet-node/index.ts @@ -6,9 +6,6 @@ import core, { Product, ArcjetRequest, ExtraProps, - RemoteClient, - RemoteClientOptions, - createRemoteClient, Arcjet, } from "arcjet"; import findIP from "@arcjet/ip"; @@ -21,6 +18,7 @@ import { platform, } from "@arcjet/env"; import { Logger } from "@arcjet/logger"; +import { createClient } from "@arcjet/protocol/client.js"; // Re-export all named exports from the generic SDK export * from "arcjet"; @@ -62,9 +60,12 @@ type PlainObject = { [key: string]: unknown; }; -export function createNodeRemoteClient( - options?: Partial, -): RemoteClient { +export type RemoteClientOptions = { + baseUrl?: string; + timeout?: number; +}; + +export function createRemoteClient(options?: RemoteClientOptions) { // The base URL for the Arcjet API. Will default to the standard production // API unless environment variable `ARCJET_BASE_URL` is set. const url = options?.baseUrl ?? baseUrl(process.env); @@ -74,18 +75,15 @@ export function createNodeRemoteClient( const timeout = options?.timeout ?? (isProduction(process.env) ? 500 : 1000); // Transport is the HTTP client that the client uses to make requests. - const transport = - options?.transport ?? - createConnectTransport({ - baseUrl: url, - httpVersion: "2", - }); + const transport = createConnectTransport({ + baseUrl: url, + httpVersion: "2", + }); - // TODO(#223): Create separate options type to exclude these const sdkStack = "NODEJS"; const sdkVersion = "__ARCJET_SDK_VERSION__"; - return createRemoteClient({ + return createClient({ transport, baseUrl: url, timeout, @@ -249,7 +247,7 @@ function withClient( export default function arcjet( options: ArcjetOptions, ): ArcjetNode>> { - const client = options.client ?? createNodeRemoteClient(); + const client = options.client ?? createRemoteClient(); const log = options.log ? options.log diff --git a/arcjet-node/package.json b/arcjet-node/package.json index 7fbc896c8..4d1428047 100644 --- a/arcjet-node/package.json +++ b/arcjet-node/package.json @@ -44,6 +44,7 @@ "@arcjet/headers": "1.0.0-alpha.14", "@arcjet/ip": "1.0.0-alpha.14", "@arcjet/logger": "1.0.0-alpha.14", + "@arcjet/protocol": "1.0.0-alpha.14", "@connectrpc/connect-node": "1.4.0", "arcjet": "1.0.0-alpha.14" }, diff --git a/arcjet-sveltekit/index.ts b/arcjet-sveltekit/index.ts index f58549d53..371ee79b2 100644 --- a/arcjet-sveltekit/index.ts +++ b/arcjet-sveltekit/index.ts @@ -7,9 +7,6 @@ import core, { Product, ArcjetRequest, ExtraProps, - RemoteClient, - RemoteClientOptions, - createRemoteClient, Arcjet, } from "arcjet"; import findIP from "@arcjet/ip"; @@ -24,6 +21,7 @@ import { platform, } from "@arcjet/env"; import { Logger } from "@arcjet/logger"; +import { createClient } from "@arcjet/protocol/client.js"; // Re-export all named exports from the generic SDK export * from "arcjet"; @@ -93,9 +91,12 @@ function defaultTransport(baseUrl: string) { } } -export function createSvelteKitRemoteClient( - options?: Partial, -): RemoteClient { +export type RemoteClientOptions = { + baseUrl?: string; + timeout?: number; +}; + +export function createRemoteClient(options?: RemoteClientOptions) { // The base URL for the Arcjet API. Will default to the standard production // API unless environment variable `ARCJET_BASE_URL` is set. const url = options?.baseUrl ?? baseUrl(env); @@ -105,13 +106,13 @@ export function createSvelteKitRemoteClient( const timeout = options?.timeout ?? (isProduction(env) ? 500 : 1000); // Transport is the HTTP client that the client uses to make requests. - const transport = options?.transport ?? defaultTransport(url); + const transport = defaultTransport(url); // TODO(#223): Create separate options type to exclude these const sdkStack = "SVELTEKIT"; const sdkVersion = "__ARCJET_SDK_VERSION__"; - return createRemoteClient({ + return createClient({ transport, baseUrl: url, timeout, @@ -255,7 +256,7 @@ function withClient( export default function arcjet( options: ArcjetOptions, ): ArcjetSvelteKit>> { - const client = options.client ?? createSvelteKitRemoteClient(); + const client = options.client ?? createRemoteClient(); const log = options.log ? options.log diff --git a/arcjet-sveltekit/package.json b/arcjet-sveltekit/package.json index 81e3ad274..1374db5fb 100644 --- a/arcjet-sveltekit/package.json +++ b/arcjet-sveltekit/package.json @@ -44,6 +44,7 @@ "@arcjet/headers": "1.0.0-alpha.14", "@arcjet/ip": "1.0.0-alpha.14", "@arcjet/logger": "1.0.0-alpha.14", + "@arcjet/protocol": "1.0.0-alpha.14", "@arcjet/runtime": "1.0.0-alpha.14", "@connectrpc/connect-node": "1.4.0", "@connectrpc/connect-web": "1.4.0", diff --git a/arcjet/index.ts b/arcjet/index.ts index ba2451599..db02f1c82 100644 --- a/arcjet/index.ts +++ b/arcjet/index.ts @@ -9,7 +9,6 @@ import { ArcjetMode, ArcjetReason, ArcjetRuleResult, - ArcjetStack, ArcjetDecision, ArcjetDenyDecision, ArcjetErrorDecision, @@ -23,21 +22,8 @@ import { ArcjetShieldRule, ArcjetLogger, } from "@arcjet/protocol"; -import { - ArcjetBotTypeToProtocol, - ArcjetStackToProtocol, - ArcjetDecisionFromProtocol, - ArcjetDecisionToProtocol, - ArcjetRuleToProtocol, -} from "@arcjet/protocol/convert.js"; -import { - createPromiseClient, - Transport, - DecideRequest, - DecideService, - ReportRequest, - Timestamp, -} from "@arcjet/protocol/proto.js"; +import { ArcjetBotTypeToProtocol } from "@arcjet/protocol/convert.js"; +import { Client } from "@arcjet/protocol/client.js"; import * as analyze from "@arcjet/analyze"; import * as duration from "@arcjet/duration"; import ArcjetHeaders from "@arcjet/headers"; @@ -175,30 +161,6 @@ type LiteralCheck< : false; type IsStringLiteral = LiteralCheck; -export interface RemoteClient { - decide( - context: ArcjetContext, - details: Partial, - rules: ArcjetRule[], - ): Promise; - // Call the Arcjet Log Decision API with details of the request and decision - // made so we can log it. - report( - context: ArcjetContext, - request: Partial, - decision: ArcjetDecision, - rules: ArcjetRule[], - ): void; -} - -export type RemoteClientOptions = { - transport: Transport; - baseUrl: string; - timeout: number; - sdkStack: ArcjetStack; - sdkVersion: string; -}; - const knownFields = [ "ip", "method", @@ -244,124 +206,6 @@ function extraProps( return Object.fromEntries(extra.entries()); } -export function createRemoteClient(options: RemoteClientOptions): RemoteClient { - const { transport, sdkVersion, baseUrl, timeout } = options; - - const sdkStack = ArcjetStackToProtocol(options.sdkStack); - - const client = createPromiseClient(DecideService, transport); - - return Object.freeze({ - async decide( - context: ArcjetContext, - details: ArcjetRequestDetails, - rules: ArcjetRule[], - ): Promise { - const { log } = context; - - // Build the request object from the Protobuf generated class. - const decideRequest = new DecideRequest({ - sdkStack, - sdkVersion, - details: { - ip: details.ip, - method: details.method, - protocol: details.protocol, - host: details.host, - path: details.path, - headers: Object.fromEntries(details.headers.entries()), - cookies: details.cookies, - query: details.query, - // TODO(#208): Re-add body - // body: details.body, - extra: details.extra, - email: typeof details.email === "string" ? details.email : undefined, - }, - rules: rules.map(ArcjetRuleToProtocol), - }); - - log.debug("Decide request to %s", baseUrl); - - const response = await client.decide(decideRequest, { - headers: { Authorization: `Bearer ${context.key}` }, - timeoutMs: timeout, - }); - - const decision = ArcjetDecisionFromProtocol(response.decision); - - log.debug( - { - id: decision.id, - fingerprint: context.fingerprint, - path: details.path, - runtime: context.runtime, - ttl: decision.ttl, - conclusion: decision.conclusion, - reason: decision.reason, - ruleResults: decision.results, - }, - "Decide response", - ); - - return decision; - }, - - report( - context: ArcjetContext, - details: ArcjetRequestDetails, - decision: ArcjetDecision, - rules: ArcjetRule[], - ): void { - const { log } = context; - - // Build the request object from the Protobuf generated class. - const reportRequest = new ReportRequest({ - sdkStack, - sdkVersion, - details: { - ip: details.ip, - method: details.method, - protocol: details.protocol, - host: details.host, - path: details.path, - headers: Object.fromEntries(details.headers.entries()), - // TODO(#208): Re-add body - // body: details.body, - extra: details.extra, - email: typeof details.email === "string" ? details.email : undefined, - }, - decision: ArcjetDecisionToProtocol(decision), - rules: rules.map(ArcjetRuleToProtocol), - receivedAt: Timestamp.now(), - }); - - log.debug("Report request to %s", baseUrl); - - // We use the promise API directly to avoid returning a promise from this function so execution can't be paused with `await` - client - .report(reportRequest, { - headers: { Authorization: `Bearer ${context.key}` }, - timeoutMs: 2_000, // 2 seconds - }) - .then((response) => { - log.debug( - { - id: response.decision?.id, - fingerprint: context.fingerprint, - path: details.path, - runtime: context.runtime, - ttl: decision.ttl, - }, - "Report response", - ); - }) - .catch((err: unknown) => { - log.info("Encountered problem sending report: %s", errorMessage(err)); - }); - }, - }); -} - type TokenBucketRateLimitOptions = { mode?: ArcjetMode; match?: string; @@ -957,7 +801,7 @@ export interface ArcjetOptions { * The client used to make requests to the Arcjet API. This must be set * when creating the SDK, such as inside @arcjet/next or mocked in tests. */ - client?: RemoteClient; + client?: Client; /** * The logger used to emit useful information from the SDK. */ diff --git a/arcjet/test/index.node.test.ts b/arcjet/test/index.node.test.ts index 6f2fe4c52..649e1328b 100644 --- a/arcjet/test/index.node.test.ts +++ b/arcjet/test/index.node.test.ts @@ -9,32 +9,15 @@ import { jest, test, } from "@jest/globals"; -import { - createRouterTransport, - DecideRequest, - DecideResponse, - DecideService, - Conclusion, - ReportRequest, - ReportResponse, - Reason, - Rule, - SDKStack, - Timestamp, - RuleResult, - RuleState, -} from "@arcjet/protocol/proto"; import { Logger } from "@arcjet/logger"; import arcjet, { - ArcjetDecision, ArcjetMode, detectBot, rateLimit, ArcjetRule, validateEmail, protectSignup, - createRemoteClient, ArcjetBotType, ArcjetEmailType, ArcjetAllowDecision, @@ -43,7 +26,6 @@ import arcjet, { ArcjetChallengeDecision, ArcjetReason, ArcjetErrorReason, - ArcjetConclusion, ArcjetRuleResult, ArcjetEmailReason, ArcjetBotReason, @@ -138,1102 +120,10 @@ function assertIsLocalRule(rule: ArcjetRule): asserts rule is ArcjetLocalRule { expect("protect" in rule && typeof rule.protect === "function").toEqual(true); } -function deferred(): [Promise, () => void, (reason?: unknown) => void] { - let resolve: () => void; - let reject: (reason?: unknown) => void; - const promise = new Promise((res, rej) => { - resolve = res; - reject = rej; - }); - - // @ts-expect-error - return [promise, resolve, reject]; -} - class ArcjetTestReason extends ArcjetReason {} -class ArcjetInvalidDecision extends ArcjetDecision { - reason: ArcjetReason; - conclusion: ArcjetConclusion; - - constructor() { - super({ ttl: 0, results: [] }); - // @ts-expect-error - this.conclusion = "INVALID"; - this.reason = new ArcjetTestReason(); - } -} - const log = new Logger({ level: "info" }); -describe("createRemoteClient", () => { - const defaultRemoteClientOptions = { - baseUrl: "", - timeout: 0, - sdkStack: "NODEJS" as const, - sdkVersion: "__ARCJET_SDK_VERSION__", - }; - - test("can be called with only a transport", () => { - const client = createRemoteClient({ - ...defaultRemoteClientOptions, - transport: createRouterTransport(() => {}), - }); - expect(typeof client.decide).toEqual("function"); - expect(typeof client.report).toEqual("function"); - }); - - test("allows overriding the default timeout", async () => { - // TODO(#32): createRouterTransport doesn't seem to handle timeouts/promises correctly - const client = createRemoteClient({ - ...defaultRemoteClientOptions, - transport: createRouterTransport(({ service }) => { - service(DecideService, {}); - }), - timeout: 300, - }); - expect(typeof client.decide).toEqual("function"); - expect(typeof client.report).toEqual("function"); - }); - - test("allows overriding the sdkStack", async () => { - const key = "test-key"; - const fingerprint = - "fp_1_ac8547705f1f45c5050f1424700dfa3f6f2f681b550ca4f3c19571585aea7a2c"; - const context = { - key, - fingerprint, - runtime: "test", - log, - }; - const details = { - ip: "172.100.1.1", - method: "GET", - protocol: "http", - host: "example.com", - path: "/", - headers: new Headers([["User-Agent", "curl/8.1.2"]]), - extra: { - "extra-test": "extra-test-value", - }, - }; - - const router = { - decide: jest.fn((args) => { - return new DecideResponse({ - decision: { - conclusion: Conclusion.ALLOW, - }, - }); - }), - }; - - const client = createRemoteClient({ - ...defaultRemoteClientOptions, - transport: createRouterTransport(({ service }) => { - service(DecideService, router); - }), - sdkStack: "NEXTJS", - }); - const _ = await client.decide(context, details, []); - - expect(router.decide).toHaveBeenCalledTimes(1); - expect(router.decide).toHaveBeenCalledWith( - new DecideRequest({ - details: { - ...details, - headers: { "user-agent": "curl/8.1.2" }, - }, - rules: [], - sdkStack: SDKStack.SDK_STACK_NEXTJS, - sdkVersion: "__ARCJET_SDK_VERSION__", - }), - expect.anything(), - ); - }); - - test("sets the sdkStack as UNSPECIFIED if invalid", async () => { - const key = "test-key"; - const fingerprint = - "fp_1_ac8547705f1f45c5050f1424700dfa3f6f2f681b550ca4f3c19571585aea7a2c"; - const context = { - key, - fingerprint, - runtime: "test", - log, - }; - const details = { - ip: "172.100.1.1", - method: "GET", - protocol: "http", - host: "example.com", - path: "/", - headers: new Headers([["User-Agent", "curl/8.1.2"]]), - extra: { - "extra-test": "extra-test-value", - }, - }; - - const router = { - decide: jest.fn((args) => { - return new DecideResponse({ - decision: { - conclusion: Conclusion.ALLOW, - }, - }); - }), - }; - - const client = createRemoteClient({ - ...defaultRemoteClientOptions, - transport: createRouterTransport(({ service }) => { - service(DecideService, router); - }), - // @ts-expect-error - sdkStack: "SOMETHING_INVALID", - }); - const _ = await client.decide(context, details, []); - - expect(router.decide).toHaveBeenCalledTimes(1); - expect(router.decide).toHaveBeenCalledWith( - new DecideRequest({ - details: { - ...details, - headers: { "user-agent": "curl/8.1.2" }, - }, - rules: [], - sdkStack: SDKStack.SDK_STACK_UNSPECIFIED, - sdkVersion: "__ARCJET_SDK_VERSION__", - }), - expect.anything(), - ); - }); - - test("calling `decide` will make RPC call with correct message", async () => { - const key = "test-key"; - const fingerprint = - "fp_1_ac8547705f1f45c5050f1424700dfa3f6f2f681b550ca4f3c19571585aea7a2c"; - const context = { - key, - fingerprint, - runtime: "test", - log, - }; - const details = { - ip: "172.100.1.1", - method: "GET", - protocol: "http", - host: "example.com", - path: "/", - headers: new Headers([["User-Agent", "curl/8.1.2"]]), - extra: { - "extra-test": "extra-test-value", - }, - }; - - const router = { - decide: jest.fn((args) => { - return new DecideResponse({ - decision: { - conclusion: Conclusion.ALLOW, - }, - }); - }), - }; - - const client = createRemoteClient({ - ...defaultRemoteClientOptions, - transport: createRouterTransport(({ service }) => { - service(DecideService, router); - }), - }); - const _ = await client.decide(context, details, []); - - expect(router.decide).toHaveBeenCalledTimes(1); - expect(router.decide).toHaveBeenCalledWith( - new DecideRequest({ - details: { - ...details, - headers: { "user-agent": "curl/8.1.2" }, - }, - rules: [], - sdkStack: SDKStack.SDK_STACK_NODEJS, - sdkVersion: "__ARCJET_SDK_VERSION__", - }), - expect.anything(), - ); - }); - - test("calling `decide` will make RPC with email included", async () => { - const key = "test-key"; - const fingerprint = - "fp_1_ac8547705f1f45c5050f1424700dfa3f6f2f681b550ca4f3c19571585aea7a2c"; - const context = { - key, - fingerprint, - runtime: "test", - log, - }; - const details = { - ip: "172.100.1.1", - method: "GET", - protocol: "http", - host: "example.com", - path: "/", - headers: new Headers([["User-Agent", "curl/8.1.2"]]), - extra: { - "extra-test": "extra-test-value", - }, - email: "abc@example.com", - }; - - const router = { - decide: jest.fn((args) => { - return new DecideResponse({ - decision: { - conclusion: Conclusion.ALLOW, - }, - }); - }), - }; - - const client = createRemoteClient({ - ...defaultRemoteClientOptions, - transport: createRouterTransport(({ service }) => { - service(DecideService, router); - }), - }); - const _ = await client.decide(context, details, []); - - expect(router.decide).toHaveBeenCalledTimes(1); - expect(router.decide).toHaveBeenCalledWith( - new DecideRequest({ - details: { - ...details, - headers: { "user-agent": "curl/8.1.2" }, - }, - rules: [], - sdkStack: SDKStack.SDK_STACK_NODEJS, - sdkVersion: "__ARCJET_SDK_VERSION__", - }), - expect.anything(), - ); - }); - - test("calling `decide` will make RPC with rules included", async () => { - const key = "test-key"; - const fingerprint = - "fp_1_ac8547705f1f45c5050f1424700dfa3f6f2f681b550ca4f3c19571585aea7a2c"; - const context = { - key, - fingerprint, - runtime: "test", - log, - }; - const details = { - ip: "172.100.1.1", - method: "GET", - protocol: "http", - host: "example.com", - path: "/", - headers: new Headers([["User-Agent", "curl/8.1.2"]]), - extra: { - "extra-test": "extra-test-value", - }, - email: "abc@example.com", - }; - - const router = { - decide: jest.fn((args) => { - return new DecideResponse({ - decision: { - conclusion: Conclusion.ALLOW, - }, - }); - }), - }; - - const client = createRemoteClient({ - ...defaultRemoteClientOptions, - transport: createRouterTransport(({ service }) => { - service(DecideService, router); - }), - }); - const rule: ArcjetRule = { - type: "TEST_RULE", - mode: "DRY_RUN", - priority: 1, - }; - const _ = await client.decide(context, details, [rule]); - - expect(router.decide).toHaveBeenCalledTimes(1); - expect(router.decide).toHaveBeenCalledWith( - new DecideRequest({ - details: { - ...details, - headers: { "user-agent": "curl/8.1.2" }, - }, - rules: [new Rule()], - sdkStack: SDKStack.SDK_STACK_NODEJS, - sdkVersion: "__ARCJET_SDK_VERSION__", - }), - expect.anything(), - ); - }); - - test("calling `decide` creates an ALLOW ArcjetDecision if DecideResponse is allowed", async () => { - const key = "test-key"; - const fingerprint = - "fp_1_ac8547705f1f45c5050f1424700dfa3f6f2f681b550ca4f3c19571585aea7a2c"; - const context = { - key, - fingerprint, - runtime: "test", - log, - }; - const details = { - ip: "172.100.1.1", - method: "GET", - protocol: "http", - host: "example.com", - path: "/", - headers: new Headers([["User-Agent", "curl/8.1.2"]]), - extra: { - "extra-test": "extra-test-value", - }, - }; - - const router = { - decide: jest.fn((args) => { - return new DecideResponse({ - decision: { - conclusion: Conclusion.ALLOW, - }, - }); - }), - }; - - const client = createRemoteClient({ - ...defaultRemoteClientOptions, - transport: createRouterTransport(({ service }) => { - service(DecideService, router); - }), - }); - const decision = await client.decide(context, details, []); - - expect(decision.isErrored()).toBe(false); - expect(decision.isAllowed()).toBe(true); - }); - - test("calling `decide` creates a DENY ArcjetDecision if DecideResponse is denied", async () => { - const key = "test-key"; - const fingerprint = - "fp_1_ac8547705f1f45c5050f1424700dfa3f6f2f681b550ca4f3c19571585aea7a2c"; - const context = { - key, - fingerprint, - runtime: "test", - log, - }; - const details = { - ip: "172.100.1.1", - method: "GET", - protocol: "http", - host: "example.com", - path: "/", - headers: new Headers([["User-Agent", "curl/8.1.2"]]), - extra: { - "extra-test": "extra-test-value", - }, - }; - - const router = { - decide: jest.fn((args) => { - return new DecideResponse({ - decision: { - conclusion: Conclusion.DENY, - }, - }); - }), - }; - - const client = createRemoteClient({ - ...defaultRemoteClientOptions, - transport: createRouterTransport(({ service }) => { - service(DecideService, router); - }), - }); - const decision = await client.decide(context, details, []); - - expect(decision.isDenied()).toBe(true); - }); - - test("calling `decide` creates a CHALLENGE ArcjetDecision if DecideResponse is challenged", async () => { - const key = "test-key"; - const fingerprint = - "fp_1_ac8547705f1f45c5050f1424700dfa3f6f2f681b550ca4f3c19571585aea7a2c"; - const context = { - key, - fingerprint, - runtime: "test", - log, - }; - const details = { - ip: "172.100.1.1", - method: "GET", - protocol: "http", - host: "example.com", - path: "/", - headers: new Headers([["User-Agent", "curl/8.1.2"]]), - extra: { - "extra-test": "extra-test-value", - }, - }; - - const router = { - decide: jest.fn((args) => { - return new DecideResponse({ - decision: { - conclusion: Conclusion.CHALLENGE, - }, - }); - }), - }; - - const client = createRemoteClient({ - ...defaultRemoteClientOptions, - transport: createRouterTransport(({ service }) => { - service(DecideService, router); - }), - }); - const decision = await client.decide(context, details, []); - - expect(decision.isChallenged()).toBe(true); - }); - - test("calling `decide` creates an ERROR ArcjetDecision with default message if DecideResponse is error and no reason", async () => { - const key = "test-key"; - const fingerprint = - "fp_1_ac8547705f1f45c5050f1424700dfa3f6f2f681b550ca4f3c19571585aea7a2c"; - const context = { - key, - fingerprint, - runtime: "test", - log, - }; - const details = { - ip: "172.100.1.1", - method: "GET", - protocol: "http", - host: "example.com", - path: "/", - headers: new Headers([["User-Agent", "curl/8.1.2"]]), - extra: { - "extra-test": "extra-test-value", - }, - }; - - const router = { - decide: jest.fn((args) => { - return new DecideResponse({ - decision: { - conclusion: Conclusion.ERROR, - }, - }); - }), - }; - - const client = createRemoteClient({ - ...defaultRemoteClientOptions, - transport: createRouterTransport(({ service }) => { - service(DecideService, router); - }), - }); - const decision = await client.decide(context, details, []); - - expect(decision.isErrored()).toBe(true); - expect(decision.reason).toMatchObject({ - message: "Unknown error occurred", - }); - }); - - test("calling `decide` creates an ERROR ArcjetDecision with message if DecideResponse if error and reason available", async () => { - const key = "test-key"; - const fingerprint = - "fp_1_ac8547705f1f45c5050f1424700dfa3f6f2f681b550ca4f3c19571585aea7a2c"; - const context = { - key, - fingerprint, - runtime: "test", - log, - }; - const details = { - ip: "172.100.1.1", - method: "GET", - protocol: "http", - host: "example.com", - path: "/", - headers: new Headers([["User-Agent", "curl/8.1.2"]]), - extra: { - "extra-test": "extra-test-value", - }, - }; - - const router = { - decide: jest.fn((args) => { - return new DecideResponse({ - decision: { - conclusion: Conclusion.ERROR, - reason: { - reason: { - case: "error", - value: { message: "Boom!" }, - }, - }, - }, - }); - }), - }; - - const client = createRemoteClient({ - ...defaultRemoteClientOptions, - transport: createRouterTransport(({ service }) => { - service(DecideService, router); - }), - }); - const decision = await client.decide(context, details, []); - - expect(decision.isErrored()).toBe(true); - expect(decision.reason).toMatchObject({ - message: "Boom!", - }); - }); - - test("calling `decide` creates an ERROR ArcjetDecision if DecideResponse is unspecified", async () => { - const key = "test-key"; - const fingerprint = - "fp_1_ac8547705f1f45c5050f1424700dfa3f6f2f681b550ca4f3c19571585aea7a2c"; - const context = { - key, - fingerprint, - runtime: "test", - log, - }; - const details = { - ip: "172.100.1.1", - method: "GET", - protocol: "http", - host: "example.com", - path: "/", - headers: new Headers([["User-Agent", "curl/8.1.2"]]), - extra: { - "extra-test": "extra-test-value", - }, - }; - - const router = { - decide: jest.fn((args) => { - return new DecideResponse({ - decision: { - conclusion: Conclusion.UNSPECIFIED, - }, - }); - }), - }; - - const client = createRemoteClient({ - ...defaultRemoteClientOptions, - transport: createRouterTransport(({ service }) => { - service(DecideService, router); - }), - }); - const decision = await client.decide(context, details, []); - - expect(decision.isErrored()).toBe(true); - expect(decision.isAllowed()).toBe(true); - }); - - test("calling `report` will make RPC call with ALLOW decision", async () => { - const key = "test-key"; - const fingerprint = - "fp_1_ac8547705f1f45c5050f1424700dfa3f6f2f681b550ca4f3c19571585aea7a2c"; - const context = { - key, - fingerprint, - runtime: "test", - log, - }; - const receivedAt = Timestamp.now(); - const details = { - ip: "172.100.1.1", - method: "GET", - protocol: "http", - host: "example.com", - path: "/", - headers: new Headers([["User-Agent", "curl/8.1.2"]]), - extra: { - "extra-test": "extra-test-value", - }, - email: "test@example.com", - }; - - const [promise, resolve] = deferred(); - - const router = { - report: jest.fn((args) => { - resolve(); - return new ReportResponse({}); - }), - }; - - const client = createRemoteClient({ - ...defaultRemoteClientOptions, - transport: createRouterTransport(({ service }) => { - service(DecideService, router); - }), - }); - const decision = new ArcjetAllowDecision({ - ttl: 0, - reason: new ArcjetTestReason(), - results: [], - }); - client.report(context, details, decision, []); - - await promise; - - expect(router.report).toHaveBeenCalledTimes(1); - expect(router.report).toHaveBeenCalledWith( - new ReportRequest({ - sdkStack: SDKStack.SDK_STACK_NODEJS, - sdkVersion: "__ARCJET_SDK_VERSION__", - details: { - ...details, - headers: { "user-agent": "curl/8.1.2" }, - }, - decision: { - id: decision.id, - conclusion: Conclusion.ALLOW, - reason: new Reason(), - ruleResults: [], - }, - receivedAt, - }), - expect.anything(), - ); - }); - - test("calling `report` will make RPC call with DENY decision", async () => { - const key = "test-key"; - const fingerprint = - "fp_1_ac8547705f1f45c5050f1424700dfa3f6f2f681b550ca4f3c19571585aea7a2c"; - const context = { - key, - fingerprint, - runtime: "test", - log, - }; - const receivedAt = Timestamp.now(); - const details = { - ip: "172.100.1.1", - method: "GET", - protocol: "http", - host: "example.com", - path: "/", - headers: new Headers([["User-Agent", "curl/8.1.2"]]), - extra: { - "extra-test": "extra-test-value", - }, - }; - - const [promise, resolve] = deferred(); - - const router = { - report: jest.fn((args) => { - resolve(); - return new ReportResponse({}); - }), - }; - - const client = createRemoteClient({ - ...defaultRemoteClientOptions, - transport: createRouterTransport(({ service }) => { - service(DecideService, router); - }), - }); - const decision = new ArcjetDenyDecision({ - ttl: 0, - reason: new ArcjetTestReason(), - results: [], - }); - client.report(context, details, decision, []); - - await promise; - - expect(router.report).toHaveBeenCalledTimes(1); - expect(router.report).toHaveBeenCalledWith( - new ReportRequest({ - sdkStack: SDKStack.SDK_STACK_NODEJS, - sdkVersion: "__ARCJET_SDK_VERSION__", - details: { - ...details, - headers: { "user-agent": "curl/8.1.2" }, - }, - decision: { - id: decision.id, - conclusion: Conclusion.DENY, - reason: new Reason(), - ruleResults: [], - }, - receivedAt, - }), - expect.anything(), - ); - }); - - test("calling `report` will make RPC call with ERROR decision", async () => { - const key = "test-key"; - const fingerprint = - "fp_1_ac8547705f1f45c5050f1424700dfa3f6f2f681b550ca4f3c19571585aea7a2c"; - const context = { - key, - fingerprint, - runtime: "test", - log, - }; - const receivedAt = Timestamp.now(); - const details = { - ip: "172.100.1.1", - method: "GET", - protocol: "http", - host: "example.com", - path: "/", - headers: new Headers([["User-Agent", "curl/8.1.2"]]), - extra: { - "extra-test": "extra-test-value", - }, - }; - - const [promise, resolve] = deferred(); - - const router = { - report: jest.fn((args) => { - resolve(); - return new ReportResponse({}); - }), - }; - - const client = createRemoteClient({ - ...defaultRemoteClientOptions, - transport: createRouterTransport(({ service }) => { - service(DecideService, router); - }), - }); - const decision = new ArcjetErrorDecision({ - ttl: 0, - reason: new ArcjetErrorReason("Failure"), - results: [], - }); - client.report(context, details, decision, []); - - await promise; - - expect(router.report).toHaveBeenCalledTimes(1); - expect(router.report).toHaveBeenCalledWith( - new ReportRequest({ - sdkStack: SDKStack.SDK_STACK_NODEJS, - sdkVersion: "__ARCJET_SDK_VERSION__", - details: { - ...details, - headers: { "user-agent": "curl/8.1.2" }, - }, - decision: { - id: decision.id, - conclusion: Conclusion.ERROR, - reason: new Reason({ - reason: { - case: "error", - value: { - message: "Failure", - }, - }, - }), - ruleResults: [], - }, - receivedAt, - }), - expect.anything(), - ); - }); - - test("calling `report` will make RPC call with CHALLENGE decision", async () => { - const key = "test-key"; - const fingerprint = - "fp_1_ac8547705f1f45c5050f1424700dfa3f6f2f681b550ca4f3c19571585aea7a2c"; - const context = { - key, - fingerprint, - runtime: "test", - log, - }; - const receivedAt = Timestamp.now(); - const details = { - ip: "172.100.1.1", - method: "GET", - protocol: "http", - host: "example.com", - path: "/", - headers: new Headers([["User-Agent", "curl/8.1.2"]]), - extra: { - "extra-test": "extra-test-value", - }, - }; - - const [promise, resolve] = deferred(); - - const router = { - report: jest.fn((args) => { - resolve(); - return new ReportResponse({}); - }), - }; - - const client = createRemoteClient({ - ...defaultRemoteClientOptions, - transport: createRouterTransport(({ service }) => { - service(DecideService, router); - }), - }); - const decision = new ArcjetChallengeDecision({ - ttl: 0, - reason: new ArcjetTestReason(), - results: [], - }); - client.report(context, details, decision, []); - - await promise; - - expect(router.report).toHaveBeenCalledTimes(1); - expect(router.report).toHaveBeenCalledWith( - new ReportRequest({ - sdkStack: SDKStack.SDK_STACK_NODEJS, - sdkVersion: "__ARCJET_SDK_VERSION__", - details: { - ...details, - headers: { "user-agent": "curl/8.1.2" }, - }, - decision: { - id: decision.id, - conclusion: Conclusion.CHALLENGE, - reason: new Reason(), - ruleResults: [], - }, - receivedAt, - }), - expect.anything(), - ); - }); - - test("calling `report` will make RPC call with UNSPECIFIED decision if invalid", async () => { - const key = "test-key"; - const fingerprint = - "fp_1_ac8547705f1f45c5050f1424700dfa3f6f2f681b550ca4f3c19571585aea7a2c"; - const context = { - key, - fingerprint, - runtime: "test", - log, - }; - const receivedAt = Timestamp.now(); - const details = { - ip: "172.100.1.1", - method: "GET", - protocol: "http", - host: "example.com", - path: "/", - headers: new Headers([["User-Agent", "curl/8.1.2"]]), - extra: { - "extra-test": "extra-test-value", - }, - }; - - const [promise, resolve] = deferred(); - - const router = { - report: jest.fn((args) => { - resolve(); - return new ReportResponse({}); - }), - }; - - const client = createRemoteClient({ - ...defaultRemoteClientOptions, - transport: createRouterTransport(({ service }) => { - service(DecideService, router); - }), - }); - const decision = new ArcjetInvalidDecision(); - client.report(context, details, decision, []); - - await promise; - - expect(router.report).toHaveBeenCalledTimes(1); - expect(router.report).toHaveBeenCalledWith( - new ReportRequest({ - sdkStack: SDKStack.SDK_STACK_NODEJS, - sdkVersion: "__ARCJET_SDK_VERSION__", - details: { - ...details, - headers: { "user-agent": "curl/8.1.2" }, - }, - decision: { - id: decision.id, - conclusion: Conclusion.UNSPECIFIED, - reason: new Reason(), - ruleResults: [], - }, - receivedAt, - }), - expect.anything(), - ); - }); - - test("calling `report` will make RPC with rules included", async () => { - const key = "test-key"; - const fingerprint = - "fp_1_ac8547705f1f45c5050f1424700dfa3f6f2f681b550ca4f3c19571585aea7a2c"; - const context = { - key, - fingerprint, - runtime: "test", - log, - }; - const receivedAt = Timestamp.now(); - const details = { - ip: "172.100.1.1", - method: "GET", - protocol: "http", - host: "example.com", - path: "/", - headers: new Headers([["User-Agent", "curl/8.1.2"]]), - extra: { - "extra-test": "extra-test-value", - }, - email: "abc@example.com", - }; - - const [promise, resolve] = deferred(); - - const router = { - report: jest.fn((args) => { - resolve(); - return new ReportResponse({}); - }), - }; - - const client = createRemoteClient({ - ...defaultRemoteClientOptions, - transport: createRouterTransport(({ service }) => { - service(DecideService, router); - }), - }); - - const decision = new ArcjetDenyDecision({ - ttl: 0, - reason: new ArcjetTestReason(), - results: [ - new ArcjetRuleResult({ - ttl: 0, - state: "RUN", - conclusion: "DENY", - reason: new ArcjetReason(), - }), - ], - }); - const rule: ArcjetRule = { - type: "TEST_RULE", - mode: "LIVE", - priority: 1, - }; - client.report(context, details, decision, [rule]); - - await promise; - - expect(router.report).toHaveBeenCalledTimes(1); - expect(router.report).toHaveBeenCalledWith( - new ReportRequest({ - sdkStack: SDKStack.SDK_STACK_NODEJS, - sdkVersion: "__ARCJET_SDK_VERSION__", - details: { - ...details, - headers: { "user-agent": "curl/8.1.2" }, - }, - decision: { - id: decision.id, - conclusion: Conclusion.DENY, - reason: new Reason(), - ruleResults: [ - new RuleResult({ - ruleId: "", - state: RuleState.RUN, - conclusion: Conclusion.DENY, - reason: new Reason(), - }), - ], - }, - rules: [new Rule()], - receivedAt, - }), - expect.anything(), - ); - }); - - test("calling `report` only logs if it fails", async () => { - const key = "test-key"; - const fingerprint = - "fp_1_ac8547705f1f45c5050f1424700dfa3f6f2f681b550ca4f3c19571585aea7a2c"; - const context = { - key, - fingerprint, - runtime: "test", - log, - }; - const details = { - ip: "172.100.1.1", - method: "GET", - protocol: "http", - host: "example.com", - path: "/", - headers: new Headers([["User-Agent", "curl/8.1.2"]]), - extra: { - "extra-test": "extra-test-value", - }, - }; - - const [promise, resolve] = deferred(); - - const logSpy = jest.spyOn(log, "info").mockImplementation(() => { - resolve(); - }); - - const client = createRemoteClient({ - ...defaultRemoteClientOptions, - transport: createRouterTransport(({ service }) => { - service(DecideService, {}); - }), - }); - const decision = new ArcjetAllowDecision({ - ttl: 0, - reason: new ArcjetTestReason(), - results: [], - }); - client.report(context, details, decision, []); - - await promise; - - expect(logSpy).toHaveBeenCalledTimes(1); - }); -}); - describe("ArcjetDecision", () => { test("will default the `id` property if not specified", () => { const decision = new ArcjetAllowDecision({ diff --git a/protocol/.gitignore b/protocol/.gitignore index 5de693ee9..e58045176 100644 --- a/protocol/.gitignore +++ b/protocol/.gitignore @@ -132,6 +132,8 @@ dist # Generated files index.js index.d.ts +client.js +client.d.ts convert.js convert.d.ts proto.js diff --git a/protocol/client.ts b/protocol/client.ts new file mode 100644 index 000000000..924d9cef9 --- /dev/null +++ b/protocol/client.ts @@ -0,0 +1,182 @@ +import { Transport, createPromiseClient } from "@connectrpc/connect"; +import { Timestamp } from "@bufbuild/protobuf"; +import { + ArcjetDecisionFromProtocol, + ArcjetDecisionToProtocol, + ArcjetRuleToProtocol, + ArcjetStackToProtocol, +} from "./convert.js"; +import { + ArcjetContext, + ArcjetDecision, + ArcjetRequestDetails, + ArcjetRule, + ArcjetStack, +} from "./index.js"; +import { DecideService } from "./gen/es/decide/v1alpha1/decide_connect.js"; +import { + DecideRequest, + ReportRequest, +} from "./gen/es/decide/v1alpha1/decide_pb.js"; + +// TODO: Dedupe with `errorMessage` in core +function errorMessage(err: unknown): string { + if (err) { + if (typeof err === "string") { + return err; + } + + if ( + typeof err === "object" && + "message" in err && + typeof err.message === "string" + ) { + return err.message; + } + } + + return "Unknown problem"; +} + +export interface Client { + decide( + context: ArcjetContext, + details: Partial, + rules: ArcjetRule[], + ): Promise; + // Call the Arcjet Log Decision API with details of the request and decision + // made so we can log it. + report( + context: ArcjetContext, + request: Partial, + decision: ArcjetDecision, + rules: ArcjetRule[], + ): void; +} + +export type ClientOptions = { + transport: Transport; + baseUrl: string; + timeout: number; + sdkStack: ArcjetStack; + sdkVersion: string; +}; + +export function createClient(options: ClientOptions): Client { + const { transport, sdkVersion, baseUrl, timeout } = options; + + const sdkStack = ArcjetStackToProtocol(options.sdkStack); + + const client = createPromiseClient(DecideService, transport); + + return Object.freeze({ + async decide( + context: ArcjetContext, + details: ArcjetRequestDetails, + rules: ArcjetRule[], + ): Promise { + const { log } = context; + + // Build the request object from the Protobuf generated class. + const decideRequest = new DecideRequest({ + sdkStack, + sdkVersion, + details: { + ip: details.ip, + method: details.method, + protocol: details.protocol, + host: details.host, + path: details.path, + headers: Object.fromEntries(details.headers.entries()), + cookies: details.cookies, + query: details.query, + // TODO(#208): Re-add body + // body: details.body, + extra: details.extra, + email: typeof details.email === "string" ? details.email : undefined, + }, + rules: rules.map(ArcjetRuleToProtocol), + }); + + log.debug("Decide request to %s", baseUrl); + + const response = await client.decide(decideRequest, { + headers: { Authorization: `Bearer ${context.key}` }, + timeoutMs: timeout, + }); + + const decision = ArcjetDecisionFromProtocol(response.decision); + + log.debug( + { + id: decision.id, + fingerprint: context.fingerprint, + path: details.path, + runtime: context.runtime, + ttl: decision.ttl, + conclusion: decision.conclusion, + reason: decision.reason, + ruleResults: decision.results, + }, + "Decide response", + ); + + return decision; + }, + + report( + context: ArcjetContext, + details: ArcjetRequestDetails, + decision: ArcjetDecision, + rules: ArcjetRule[], + ): void { + const { log } = context; + + // Build the request object from the Protobuf generated class. + const reportRequest = new ReportRequest({ + sdkStack, + sdkVersion, + details: { + ip: details.ip, + method: details.method, + protocol: details.protocol, + host: details.host, + path: details.path, + headers: Object.fromEntries(details.headers.entries()), + // TODO(#208): Re-add body + // body: details.body, + extra: details.extra, + email: typeof details.email === "string" ? details.email : undefined, + }, + decision: ArcjetDecisionToProtocol(decision), + rules: rules.map(ArcjetRuleToProtocol), + receivedAt: Timestamp.now(), + }); + + log.debug("Report request to %s", baseUrl); + + // We use the promise API directly to avoid returning a promise from this function so execution can't be paused with `await` + // TODO(#884): Leverage `waitUntil` if the function is attached to the context + client + .report(reportRequest, { + headers: { Authorization: `Bearer ${context.key}` }, + timeoutMs: 2_000, // 2 seconds + }) + .then((response) => { + log.debug( + { + id: response.decision?.id, + fingerprint: context.fingerprint, + path: details.path, + runtime: context.runtime, + ttl: decision.ttl, + }, + "Report response", + ); + }) + .catch((err: unknown) => { + log.info("Encountered problem sending report: %s", errorMessage(err)); + }); + }, + }); +} diff --git a/protocol/proto.ts b/protocol/proto.ts deleted file mode 100644 index b281f3156..000000000 --- a/protocol/proto.ts +++ /dev/null @@ -1,10 +0,0 @@ -// TODO: Finish abstracting over protobuf and don't re-export -export * from "./gen/es/decide/v1alpha1/decide_pb.js"; -export * from "./gen/es/decide/v1alpha1/decide_connect.js"; - -export { Timestamp, proto3 } from "@bufbuild/protobuf"; -export { - createPromiseClient, - createRouterTransport, - type Transport, -} from "@connectrpc/connect"; diff --git a/protocol/test/client.test.ts b/protocol/test/client.test.ts new file mode 100644 index 000000000..1bf236b10 --- /dev/null +++ b/protocol/test/client.test.ts @@ -0,0 +1,1141 @@ +import { afterEach, beforeEach, describe, expect, jest, test } from "@jest/globals"; +import { createClient } from "../client.js"; +import { createRouterTransport } from "@connectrpc/connect"; +import { DecideService } from "../gen/es/decide/v1alpha1/decide_connect.js"; +import { + Conclusion, + DecideRequest, + DecideResponse, + Reason, + ReportRequest, + ReportResponse, + Rule, + RuleResult, + RuleState, + SDKStack, +} from "../gen/es/decide/v1alpha1/decide_pb.js"; +import { + ArcjetAllowDecision, + ArcjetChallengeDecision, + ArcjetConclusion, + ArcjetDecision, + ArcjetDenyDecision, + ArcjetErrorDecision, + ArcjetErrorReason, + ArcjetReason, + ArcjetRule, + ArcjetRuleResult, +} from "../index.js"; +import { Timestamp } from "@bufbuild/protobuf"; + +function deferred(): [Promise, () => void, (reason?: unknown) => void] { + let resolve: () => void; + let reject: (reason?: unknown) => void; + const promise = new Promise((res, rej) => { + resolve = res; + reject = rej; + }); + + // @ts-expect-error + return [promise, resolve, reject]; +} + +class ArcjetTestReason extends ArcjetReason {} + +class ArcjetInvalidDecision extends ArcjetDecision { + reason: ArcjetReason; + conclusion: ArcjetConclusion; + + constructor() { + super({ ttl: 0, results: [] }); + // @ts-expect-error + this.conclusion = "INVALID"; + this.reason = new ArcjetTestReason(); + } +} + +beforeEach(() => { + jest.useFakeTimers(); +}); + +afterEach(() => { + jest.useRealTimers(); + jest.clearAllTimers(); + jest.clearAllMocks(); + jest.restoreAllMocks(); +}); + +describe("createClient", () => { + const log = { + debug() {}, + info() {}, + warn() {}, + error() {}, + }; + + const defaultRemoteClientOptions = { + baseUrl: "", + timeout: 0, + sdkStack: "NODEJS" as const, + sdkVersion: "__ARCJET_SDK_VERSION__", + }; + + test("can be called with only a transport", () => { + const client = createClient({ + ...defaultRemoteClientOptions, + transport: createRouterTransport(() => {}), + }); + expect(typeof client.decide).toEqual("function"); + expect(typeof client.report).toEqual("function"); + }); + + test("allows overriding the default timeout", async () => { + // TODO(#32): createRouterTransport doesn't seem to handle timeouts/promises correctly + const client = createClient({ + ...defaultRemoteClientOptions, + transport: createRouterTransport(({ service }) => { + service(DecideService, {}); + }), + timeout: 300, + }); + expect(typeof client.decide).toEqual("function"); + expect(typeof client.report).toEqual("function"); + }); + + test("allows overriding the sdkStack", async () => { + const key = "test-key"; + const fingerprint = + "fp_1_ac8547705f1f45c5050f1424700dfa3f6f2f681b550ca4f3c19571585aea7a2c"; + const context = { + key, + fingerprint, + runtime: "test", + log, + }; + const details = { + ip: "172.100.1.1", + method: "GET", + protocol: "http", + host: "example.com", + path: "/", + headers: new Headers([["User-Agent", "curl/8.1.2"]]), + extra: { + "extra-test": "extra-test-value", + }, + }; + + const router = { + decide: jest.fn((args) => { + return new DecideResponse({ + decision: { + conclusion: Conclusion.ALLOW, + }, + }); + }), + }; + + const client = createClient({ + ...defaultRemoteClientOptions, + transport: createRouterTransport(({ service }) => { + service(DecideService, router); + }), + sdkStack: "NEXTJS", + }); + const _ = await client.decide(context, details, []); + + expect(router.decide).toHaveBeenCalledTimes(1); + expect(router.decide).toHaveBeenCalledWith( + new DecideRequest({ + details: { + ...details, + headers: { "user-agent": "curl/8.1.2" }, + }, + rules: [], + sdkStack: SDKStack.SDK_STACK_NEXTJS, + sdkVersion: "__ARCJET_SDK_VERSION__", + }), + expect.anything(), + ); + }); + + test("sets the sdkStack as UNSPECIFIED if invalid", async () => { + const key = "test-key"; + const fingerprint = + "fp_1_ac8547705f1f45c5050f1424700dfa3f6f2f681b550ca4f3c19571585aea7a2c"; + const context = { + key, + fingerprint, + runtime: "test", + log, + }; + const details = { + ip: "172.100.1.1", + method: "GET", + protocol: "http", + host: "example.com", + path: "/", + headers: new Headers([["User-Agent", "curl/8.1.2"]]), + extra: { + "extra-test": "extra-test-value", + }, + }; + + const router = { + decide: jest.fn((args) => { + return new DecideResponse({ + decision: { + conclusion: Conclusion.ALLOW, + }, + }); + }), + }; + + const client = createClient({ + ...defaultRemoteClientOptions, + transport: createRouterTransport(({ service }) => { + service(DecideService, router); + }), + // @ts-expect-error + sdkStack: "SOMETHING_INVALID", + }); + const _ = await client.decide(context, details, []); + + expect(router.decide).toHaveBeenCalledTimes(1); + expect(router.decide).toHaveBeenCalledWith( + new DecideRequest({ + details: { + ...details, + headers: { "user-agent": "curl/8.1.2" }, + }, + rules: [], + sdkStack: SDKStack.SDK_STACK_UNSPECIFIED, + sdkVersion: "__ARCJET_SDK_VERSION__", + }), + expect.anything(), + ); + }); + + test("calling `decide` will make RPC call with correct message", async () => { + const key = "test-key"; + const fingerprint = + "fp_1_ac8547705f1f45c5050f1424700dfa3f6f2f681b550ca4f3c19571585aea7a2c"; + const context = { + key, + fingerprint, + runtime: "test", + log, + }; + const details = { + ip: "172.100.1.1", + method: "GET", + protocol: "http", + host: "example.com", + path: "/", + headers: new Headers([["User-Agent", "curl/8.1.2"]]), + extra: { + "extra-test": "extra-test-value", + }, + }; + + const router = { + decide: jest.fn((args) => { + return new DecideResponse({ + decision: { + conclusion: Conclusion.ALLOW, + }, + }); + }), + }; + + const client = createClient({ + ...defaultRemoteClientOptions, + transport: createRouterTransport(({ service }) => { + service(DecideService, router); + }), + }); + const _ = await client.decide(context, details, []); + + expect(router.decide).toHaveBeenCalledTimes(1); + expect(router.decide).toHaveBeenCalledWith( + new DecideRequest({ + details: { + ...details, + headers: { "user-agent": "curl/8.1.2" }, + }, + rules: [], + sdkStack: SDKStack.SDK_STACK_NODEJS, + sdkVersion: "__ARCJET_SDK_VERSION__", + }), + expect.anything(), + ); + }); + + test("calling `decide` will make RPC with email included", async () => { + const key = "test-key"; + const fingerprint = + "fp_1_ac8547705f1f45c5050f1424700dfa3f6f2f681b550ca4f3c19571585aea7a2c"; + const context = { + key, + fingerprint, + runtime: "test", + log, + }; + const details = { + ip: "172.100.1.1", + method: "GET", + protocol: "http", + host: "example.com", + path: "/", + headers: new Headers([["User-Agent", "curl/8.1.2"]]), + extra: { + "extra-test": "extra-test-value", + }, + email: "abc@example.com", + }; + + const router = { + decide: jest.fn((args) => { + return new DecideResponse({ + decision: { + conclusion: Conclusion.ALLOW, + }, + }); + }), + }; + + const client = createClient({ + ...defaultRemoteClientOptions, + transport: createRouterTransport(({ service }) => { + service(DecideService, router); + }), + }); + const _ = await client.decide(context, details, []); + + expect(router.decide).toHaveBeenCalledTimes(1); + expect(router.decide).toHaveBeenCalledWith( + new DecideRequest({ + details: { + ...details, + headers: { "user-agent": "curl/8.1.2" }, + }, + rules: [], + sdkStack: SDKStack.SDK_STACK_NODEJS, + sdkVersion: "__ARCJET_SDK_VERSION__", + }), + expect.anything(), + ); + }); + + test("calling `decide` will make RPC with rules included", async () => { + const key = "test-key"; + const fingerprint = + "fp_1_ac8547705f1f45c5050f1424700dfa3f6f2f681b550ca4f3c19571585aea7a2c"; + const context = { + key, + fingerprint, + runtime: "test", + log, + }; + const details = { + ip: "172.100.1.1", + method: "GET", + protocol: "http", + host: "example.com", + path: "/", + headers: new Headers([["User-Agent", "curl/8.1.2"]]), + extra: { + "extra-test": "extra-test-value", + }, + email: "abc@example.com", + }; + + const router = { + decide: jest.fn((args) => { + return new DecideResponse({ + decision: { + conclusion: Conclusion.ALLOW, + }, + }); + }), + }; + + const client = createClient({ + ...defaultRemoteClientOptions, + transport: createRouterTransport(({ service }) => { + service(DecideService, router); + }), + }); + const rule: ArcjetRule = { + type: "TEST_RULE", + mode: "DRY_RUN", + priority: 1, + }; + const _ = await client.decide(context, details, [rule]); + + expect(router.decide).toHaveBeenCalledTimes(1); + expect(router.decide).toHaveBeenCalledWith( + new DecideRequest({ + details: { + ...details, + headers: { "user-agent": "curl/8.1.2" }, + }, + rules: [new Rule()], + sdkStack: SDKStack.SDK_STACK_NODEJS, + sdkVersion: "__ARCJET_SDK_VERSION__", + }), + expect.anything(), + ); + }); + + test("calling `decide` creates an ALLOW ArcjetDecision if DecideResponse is allowed", async () => { + const key = "test-key"; + const fingerprint = + "fp_1_ac8547705f1f45c5050f1424700dfa3f6f2f681b550ca4f3c19571585aea7a2c"; + const context = { + key, + fingerprint, + runtime: "test", + log, + }; + const details = { + ip: "172.100.1.1", + method: "GET", + protocol: "http", + host: "example.com", + path: "/", + headers: new Headers([["User-Agent", "curl/8.1.2"]]), + extra: { + "extra-test": "extra-test-value", + }, + }; + + const router = { + decide: jest.fn((args) => { + return new DecideResponse({ + decision: { + conclusion: Conclusion.ALLOW, + }, + }); + }), + }; + + const client = createClient({ + ...defaultRemoteClientOptions, + transport: createRouterTransport(({ service }) => { + service(DecideService, router); + }), + }); + const decision = await client.decide(context, details, []); + + expect(decision.isErrored()).toBe(false); + expect(decision.isAllowed()).toBe(true); + }); + + test("calling `decide` creates a DENY ArcjetDecision if DecideResponse is denied", async () => { + const key = "test-key"; + const fingerprint = + "fp_1_ac8547705f1f45c5050f1424700dfa3f6f2f681b550ca4f3c19571585aea7a2c"; + const context = { + key, + fingerprint, + runtime: "test", + log, + }; + const details = { + ip: "172.100.1.1", + method: "GET", + protocol: "http", + host: "example.com", + path: "/", + headers: new Headers([["User-Agent", "curl/8.1.2"]]), + extra: { + "extra-test": "extra-test-value", + }, + }; + + const router = { + decide: jest.fn((args) => { + return new DecideResponse({ + decision: { + conclusion: Conclusion.DENY, + }, + }); + }), + }; + + const client = createClient({ + ...defaultRemoteClientOptions, + transport: createRouterTransport(({ service }) => { + service(DecideService, router); + }), + }); + const decision = await client.decide(context, details, []); + + expect(decision.isDenied()).toBe(true); + }); + + test("calling `decide` creates a CHALLENGE ArcjetDecision if DecideResponse is challenged", async () => { + const key = "test-key"; + const fingerprint = + "fp_1_ac8547705f1f45c5050f1424700dfa3f6f2f681b550ca4f3c19571585aea7a2c"; + const context = { + key, + fingerprint, + runtime: "test", + log, + }; + const details = { + ip: "172.100.1.1", + method: "GET", + protocol: "http", + host: "example.com", + path: "/", + headers: new Headers([["User-Agent", "curl/8.1.2"]]), + extra: { + "extra-test": "extra-test-value", + }, + }; + + const router = { + decide: jest.fn((args) => { + return new DecideResponse({ + decision: { + conclusion: Conclusion.CHALLENGE, + }, + }); + }), + }; + + const client = createClient({ + ...defaultRemoteClientOptions, + transport: createRouterTransport(({ service }) => { + service(DecideService, router); + }), + }); + const decision = await client.decide(context, details, []); + + expect(decision.isChallenged()).toBe(true); + }); + + test("calling `decide` creates an ERROR ArcjetDecision with default message if DecideResponse is error and no reason", async () => { + const key = "test-key"; + const fingerprint = + "fp_1_ac8547705f1f45c5050f1424700dfa3f6f2f681b550ca4f3c19571585aea7a2c"; + const context = { + key, + fingerprint, + runtime: "test", + log, + }; + const details = { + ip: "172.100.1.1", + method: "GET", + protocol: "http", + host: "example.com", + path: "/", + headers: new Headers([["User-Agent", "curl/8.1.2"]]), + extra: { + "extra-test": "extra-test-value", + }, + }; + + const router = { + decide: jest.fn((args) => { + return new DecideResponse({ + decision: { + conclusion: Conclusion.ERROR, + }, + }); + }), + }; + + const client = createClient({ + ...defaultRemoteClientOptions, + transport: createRouterTransport(({ service }) => { + service(DecideService, router); + }), + }); + const decision = await client.decide(context, details, []); + + expect(decision.isErrored()).toBe(true); + expect(decision.reason).toMatchObject({ + message: "Unknown error occurred", + }); + }); + + test("calling `decide` creates an ERROR ArcjetDecision with message if DecideResponse if error and reason available", async () => { + const key = "test-key"; + const fingerprint = + "fp_1_ac8547705f1f45c5050f1424700dfa3f6f2f681b550ca4f3c19571585aea7a2c"; + const context = { + key, + fingerprint, + runtime: "test", + log, + }; + const details = { + ip: "172.100.1.1", + method: "GET", + protocol: "http", + host: "example.com", + path: "/", + headers: new Headers([["User-Agent", "curl/8.1.2"]]), + extra: { + "extra-test": "extra-test-value", + }, + }; + + const router = { + decide: jest.fn((args) => { + return new DecideResponse({ + decision: { + conclusion: Conclusion.ERROR, + reason: { + reason: { + case: "error", + value: { message: "Boom!" }, + }, + }, + }, + }); + }), + }; + + const client = createClient({ + ...defaultRemoteClientOptions, + transport: createRouterTransport(({ service }) => { + service(DecideService, router); + }), + }); + const decision = await client.decide(context, details, []); + + expect(decision.isErrored()).toBe(true); + expect(decision.reason).toMatchObject({ + message: "Boom!", + }); + }); + + test("calling `decide` creates an ERROR ArcjetDecision if DecideResponse is unspecified", async () => { + const key = "test-key"; + const fingerprint = + "fp_1_ac8547705f1f45c5050f1424700dfa3f6f2f681b550ca4f3c19571585aea7a2c"; + const context = { + key, + fingerprint, + runtime: "test", + log, + }; + const details = { + ip: "172.100.1.1", + method: "GET", + protocol: "http", + host: "example.com", + path: "/", + headers: new Headers([["User-Agent", "curl/8.1.2"]]), + extra: { + "extra-test": "extra-test-value", + }, + }; + + const router = { + decide: jest.fn((args) => { + return new DecideResponse({ + decision: { + conclusion: Conclusion.UNSPECIFIED, + }, + }); + }), + }; + + const client = createClient({ + ...defaultRemoteClientOptions, + transport: createRouterTransport(({ service }) => { + service(DecideService, router); + }), + }); + const decision = await client.decide(context, details, []); + + expect(decision.isErrored()).toBe(true); + expect(decision.isAllowed()).toBe(true); + }); + + test("calling `report` will make RPC call with ALLOW decision", async () => { + const key = "test-key"; + const fingerprint = + "fp_1_ac8547705f1f45c5050f1424700dfa3f6f2f681b550ca4f3c19571585aea7a2c"; + const context = { + key, + fingerprint, + runtime: "test", + log, + }; + const receivedAt = Timestamp.now(); + const details = { + ip: "172.100.1.1", + method: "GET", + protocol: "http", + host: "example.com", + path: "/", + headers: new Headers([["User-Agent", "curl/8.1.2"]]), + extra: { + "extra-test": "extra-test-value", + }, + email: "test@example.com", + }; + + const [promise, resolve] = deferred(); + + const router = { + report: jest.fn((args) => { + resolve(); + return new ReportResponse({}); + }), + }; + + const client = createClient({ + ...defaultRemoteClientOptions, + transport: createRouterTransport(({ service }) => { + service(DecideService, router); + }), + }); + const decision = new ArcjetAllowDecision({ + ttl: 0, + reason: new ArcjetTestReason(), + results: [], + }); + client.report(context, details, decision, []); + + await promise; + + expect(router.report).toHaveBeenCalledTimes(1); + expect(router.report).toHaveBeenCalledWith( + new ReportRequest({ + sdkStack: SDKStack.SDK_STACK_NODEJS, + sdkVersion: "__ARCJET_SDK_VERSION__", + details: { + ...details, + headers: { "user-agent": "curl/8.1.2" }, + }, + decision: { + id: decision.id, + conclusion: Conclusion.ALLOW, + reason: new Reason(), + ruleResults: [], + }, + receivedAt, + }), + expect.anything(), + ); + }); + + test("calling `report` will make RPC call with DENY decision", async () => { + const key = "test-key"; + const fingerprint = + "fp_1_ac8547705f1f45c5050f1424700dfa3f6f2f681b550ca4f3c19571585aea7a2c"; + const context = { + key, + fingerprint, + runtime: "test", + log, + }; + const receivedAt = Timestamp.now(); + const details = { + ip: "172.100.1.1", + method: "GET", + protocol: "http", + host: "example.com", + path: "/", + headers: new Headers([["User-Agent", "curl/8.1.2"]]), + extra: { + "extra-test": "extra-test-value", + }, + }; + + const [promise, resolve] = deferred(); + + const router = { + report: jest.fn((args) => { + resolve(); + return new ReportResponse({}); + }), + }; + + const client = createClient({ + ...defaultRemoteClientOptions, + transport: createRouterTransport(({ service }) => { + service(DecideService, router); + }), + }); + const decision = new ArcjetDenyDecision({ + ttl: 0, + reason: new ArcjetTestReason(), + results: [], + }); + client.report(context, details, decision, []); + + await promise; + + expect(router.report).toHaveBeenCalledTimes(1); + expect(router.report).toHaveBeenCalledWith( + new ReportRequest({ + sdkStack: SDKStack.SDK_STACK_NODEJS, + sdkVersion: "__ARCJET_SDK_VERSION__", + details: { + ...details, + headers: { "user-agent": "curl/8.1.2" }, + }, + decision: { + id: decision.id, + conclusion: Conclusion.DENY, + reason: new Reason(), + ruleResults: [], + }, + receivedAt, + }), + expect.anything(), + ); + }); + + test("calling `report` will make RPC call with ERROR decision", async () => { + const key = "test-key"; + const fingerprint = + "fp_1_ac8547705f1f45c5050f1424700dfa3f6f2f681b550ca4f3c19571585aea7a2c"; + const context = { + key, + fingerprint, + runtime: "test", + log, + }; + const receivedAt = Timestamp.now(); + const details = { + ip: "172.100.1.1", + method: "GET", + protocol: "http", + host: "example.com", + path: "/", + headers: new Headers([["User-Agent", "curl/8.1.2"]]), + extra: { + "extra-test": "extra-test-value", + }, + }; + + const [promise, resolve] = deferred(); + + const router = { + report: jest.fn((args) => { + resolve(); + return new ReportResponse({}); + }), + }; + + const client = createClient({ + ...defaultRemoteClientOptions, + transport: createRouterTransport(({ service }) => { + service(DecideService, router); + }), + }); + const decision = new ArcjetErrorDecision({ + ttl: 0, + reason: new ArcjetErrorReason("Failure"), + results: [], + }); + client.report(context, details, decision, []); + + await promise; + + expect(router.report).toHaveBeenCalledTimes(1); + expect(router.report).toHaveBeenCalledWith( + new ReportRequest({ + sdkStack: SDKStack.SDK_STACK_NODEJS, + sdkVersion: "__ARCJET_SDK_VERSION__", + details: { + ...details, + headers: { "user-agent": "curl/8.1.2" }, + }, + decision: { + id: decision.id, + conclusion: Conclusion.ERROR, + reason: new Reason({ + reason: { + case: "error", + value: { + message: "Failure", + }, + }, + }), + ruleResults: [], + }, + receivedAt, + }), + expect.anything(), + ); + }); + + test("calling `report` will make RPC call with CHALLENGE decision", async () => { + const key = "test-key"; + const fingerprint = + "fp_1_ac8547705f1f45c5050f1424700dfa3f6f2f681b550ca4f3c19571585aea7a2c"; + const context = { + key, + fingerprint, + runtime: "test", + log, + }; + const receivedAt = Timestamp.now(); + const details = { + ip: "172.100.1.1", + method: "GET", + protocol: "http", + host: "example.com", + path: "/", + headers: new Headers([["User-Agent", "curl/8.1.2"]]), + extra: { + "extra-test": "extra-test-value", + }, + }; + + const [promise, resolve] = deferred(); + + const router = { + report: jest.fn((args) => { + resolve(); + return new ReportResponse({}); + }), + }; + + const client = createClient({ + ...defaultRemoteClientOptions, + transport: createRouterTransport(({ service }) => { + service(DecideService, router); + }), + }); + const decision = new ArcjetChallengeDecision({ + ttl: 0, + reason: new ArcjetTestReason(), + results: [], + }); + client.report(context, details, decision, []); + + await promise; + + expect(router.report).toHaveBeenCalledTimes(1); + expect(router.report).toHaveBeenCalledWith( + new ReportRequest({ + sdkStack: SDKStack.SDK_STACK_NODEJS, + sdkVersion: "__ARCJET_SDK_VERSION__", + details: { + ...details, + headers: { "user-agent": "curl/8.1.2" }, + }, + decision: { + id: decision.id, + conclusion: Conclusion.CHALLENGE, + reason: new Reason(), + ruleResults: [], + }, + receivedAt, + }), + expect.anything(), + ); + }); + + test("calling `report` will make RPC call with UNSPECIFIED decision if invalid", async () => { + const key = "test-key"; + const fingerprint = + "fp_1_ac8547705f1f45c5050f1424700dfa3f6f2f681b550ca4f3c19571585aea7a2c"; + const context = { + key, + fingerprint, + runtime: "test", + log, + }; + const receivedAt = Timestamp.now(); + const details = { + ip: "172.100.1.1", + method: "GET", + protocol: "http", + host: "example.com", + path: "/", + headers: new Headers([["User-Agent", "curl/8.1.2"]]), + extra: { + "extra-test": "extra-test-value", + }, + }; + + const [promise, resolve] = deferred(); + + const router = { + report: jest.fn((args) => { + resolve(); + return new ReportResponse({}); + }), + }; + + const client = createClient({ + ...defaultRemoteClientOptions, + transport: createRouterTransport(({ service }) => { + service(DecideService, router); + }), + }); + const decision = new ArcjetInvalidDecision(); + client.report(context, details, decision, []); + + await promise; + + expect(router.report).toHaveBeenCalledTimes(1); + expect(router.report).toHaveBeenCalledWith( + new ReportRequest({ + sdkStack: SDKStack.SDK_STACK_NODEJS, + sdkVersion: "__ARCJET_SDK_VERSION__", + details: { + ...details, + headers: { "user-agent": "curl/8.1.2" }, + }, + decision: { + id: decision.id, + conclusion: Conclusion.UNSPECIFIED, + reason: new Reason(), + ruleResults: [], + }, + receivedAt, + }), + expect.anything(), + ); + }); + + test("calling `report` will make RPC with rules included", async () => { + const key = "test-key"; + const fingerprint = + "fp_1_ac8547705f1f45c5050f1424700dfa3f6f2f681b550ca4f3c19571585aea7a2c"; + const context = { + key, + fingerprint, + runtime: "test", + log, + }; + const receivedAt = Timestamp.now(); + const details = { + ip: "172.100.1.1", + method: "GET", + protocol: "http", + host: "example.com", + path: "/", + headers: new Headers([["User-Agent", "curl/8.1.2"]]), + extra: { + "extra-test": "extra-test-value", + }, + email: "abc@example.com", + }; + + const [promise, resolve] = deferred(); + + const router = { + report: jest.fn((args) => { + resolve(); + return new ReportResponse({}); + }), + }; + + const client = createClient({ + ...defaultRemoteClientOptions, + transport: createRouterTransport(({ service }) => { + service(DecideService, router); + }), + }); + + const decision = new ArcjetDenyDecision({ + ttl: 0, + reason: new ArcjetTestReason(), + results: [ + new ArcjetRuleResult({ + ttl: 0, + state: "RUN", + conclusion: "DENY", + reason: new ArcjetReason(), + }), + ], + }); + const rule: ArcjetRule = { + type: "TEST_RULE", + mode: "LIVE", + priority: 1, + }; + client.report(context, details, decision, [rule]); + + await promise; + + expect(router.report).toHaveBeenCalledTimes(1); + expect(router.report).toHaveBeenCalledWith( + new ReportRequest({ + sdkStack: SDKStack.SDK_STACK_NODEJS, + sdkVersion: "__ARCJET_SDK_VERSION__", + details: { + ...details, + headers: { "user-agent": "curl/8.1.2" }, + }, + decision: { + id: decision.id, + conclusion: Conclusion.DENY, + reason: new Reason(), + ruleResults: [ + new RuleResult({ + ruleId: "", + state: RuleState.RUN, + conclusion: Conclusion.DENY, + reason: new Reason(), + }), + ], + }, + rules: [new Rule()], + receivedAt, + }), + expect.anything(), + ); + }); + + test("calling `report` only logs if it fails", async () => { + const key = "test-key"; + const fingerprint = + "fp_1_ac8547705f1f45c5050f1424700dfa3f6f2f681b550ca4f3c19571585aea7a2c"; + const context = { + key, + fingerprint, + runtime: "test", + log, + }; + const details = { + ip: "172.100.1.1", + method: "GET", + protocol: "http", + host: "example.com", + path: "/", + headers: new Headers([["User-Agent", "curl/8.1.2"]]), + extra: { + "extra-test": "extra-test-value", + }, + }; + + const [promise, resolve] = deferred(); + + const logSpy = jest.spyOn(log, "info").mockImplementation(() => { + resolve(); + }); + + const client = createClient({ + ...defaultRemoteClientOptions, + transport: createRouterTransport(({ service }) => { + service(DecideService, {}); + }), + }); + const decision = new ArcjetAllowDecision({ + ttl: 0, + reason: new ArcjetTestReason(), + results: [], + }); + client.report(context, details, decision, []); + + await promise; + + expect(logSpy).toHaveBeenCalledTimes(1); + }); +}); From aeb2511f6990b0e5a1a570f671f6c6b228de99ff Mon Sep 17 00:00:00 2001 From: Blaine Bublitz Date: Tue, 11 Jun 2024 15:53:28 -0400 Subject: [PATCH 2/3] fix example import --- .../nextjs-14-app-dir-rl/app/api/custom_timeout/route.ts | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/examples/nextjs-14-app-dir-rl/app/api/custom_timeout/route.ts b/examples/nextjs-14-app-dir-rl/app/api/custom_timeout/route.ts index d73e6ee66..7f295b682 100644 --- a/examples/nextjs-14-app-dir-rl/app/api/custom_timeout/route.ts +++ b/examples/nextjs-14-app-dir-rl/app/api/custom_timeout/route.ts @@ -1,11 +1,8 @@ -import arcjet, { - validateEmail, - createNextRemoteClient, -} from "@arcjet/next"; +import arcjet, { validateEmail, createRemoteClient } from "@arcjet/next"; import { baseUrl } from "@arcjet/env"; import { NextResponse } from "next/server"; -const client = createNextRemoteClient({ +const client = createRemoteClient({ baseUrl: baseUrl(process.env), timeout: 10, }); From d78c4aba7ea7e41686032b41625015b905b28d0c Mon Sep 17 00:00:00 2001 From: Blaine Bublitz Date: Tue, 11 Jun 2024 15:58:43 -0400 Subject: [PATCH 3/3] fmt --- protocol/test/client.test.ts | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/protocol/test/client.test.ts b/protocol/test/client.test.ts index 1bf236b10..5c81137d2 100644 --- a/protocol/test/client.test.ts +++ b/protocol/test/client.test.ts @@ -1,4 +1,11 @@ -import { afterEach, beforeEach, describe, expect, jest, test } from "@jest/globals"; +import { + afterEach, + beforeEach, + describe, + expect, + jest, + test, +} from "@jest/globals"; import { createClient } from "../client.js"; import { createRouterTransport } from "@connectrpc/connect"; import { DecideService } from "../gen/es/decide/v1alpha1/decide_connect.js";