From 3500df4d220b9206be3c6136f4d6ea46c42d20e0 Mon Sep 17 00:00:00 2001 From: Jacob Lee Date: Mon, 5 Aug 2024 11:32:46 -0700 Subject: [PATCH] Adds signal test, fix bug (#297) * Adds signal test, fix bug * Lint * Fix --- langgraph/src/pregel/index.ts | 2 +- langgraph/src/tests/prebuilt.test.ts | 36 ++++++++++++++++++++++++++++ langgraph/src/tests/utils.ts | 4 ++++ 3 files changed, 41 insertions(+), 1 deletion(-) diff --git a/langgraph/src/pregel/index.ts b/langgraph/src/pregel/index.ts index 666a0d11..5211b645 100644 --- a/langgraph/src/pregel/index.ts +++ b/langgraph/src/pregel/index.ts @@ -694,7 +694,7 @@ export class Pregel< proc.invoke(input, updatedConfig) ); - await executeTasks(tasks, this.stepTimeout); + await executeTasks(tasks, this.stepTimeout, config.signal); // combine pending writes from all tasks const pendingWrites: Array<[keyof Cc, unknown]> = []; diff --git a/langgraph/src/tests/prebuilt.test.ts b/langgraph/src/tests/prebuilt.test.ts index 6a5b49ec..c38d1309 100644 --- a/langgraph/src/tests/prebuilt.test.ts +++ b/langgraph/src/tests/prebuilt.test.ts @@ -330,6 +330,42 @@ describe("createReactAgent", () => { ]); }); + it("Should respect a passed signal", async () => { + const llm = new FakeToolCallingChatModel({ + responses: [ + new AIMessage({ + content: "result1", + tool_calls: [ + { name: "search_api", id: "tool_abcd123", args: { query: "foo" } }, + ], + }), + new AIMessage("result2"), + ], + sleep: 500, + }); + + const agent = createReactAgent({ + llm, + tools: [new SearchAPIWithArtifact()], + messageModifier: "You are a helpful assistant", + }); + + const controller = new AbortController(); + + setTimeout(() => controller.abort(), 100); + + await expect(async () => { + await agent.invoke( + { + messages: [new HumanMessage("Hello Input!")], + }, + { + signal: controller.signal, + } + ); + }).rejects.toThrowError(); + }); + it("Works with tools that return content_and_artifact response format", async () => { const llm = new FakeToolCallingChatModel({ responses: [ diff --git a/langgraph/src/tests/utils.ts b/langgraph/src/tests/utils.ts index 082083c4..43c24d06 100644 --- a/langgraph/src/tests/utils.ts +++ b/langgraph/src/tests/utils.ts @@ -1,3 +1,4 @@ +/* eslint-disable no-promise-executor-return */ import assert from "node:assert"; import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager"; import { @@ -97,6 +98,9 @@ export class FakeToolCallingChatModel extends BaseChatModel { if (this.thrownErrorString) { throw new Error(this.thrownErrorString); } + if (this.sleep !== undefined) { + await new Promise((resolve) => setTimeout(resolve, this.sleep)); + } const msg = this.responses?.[this.idx] ?? messages[this.idx]; const generation: ChatResult = { generations: [