From dd240cf45ce73234ef04b66abe691c528c8d94b1 Mon Sep 17 00:00:00 2001 From: jacoblee93 Date: Tue, 27 Aug 2024 18:07:26 -0700 Subject: [PATCH 1/7] Relax tool node generic requiremenet --- libs/langgraph/src/prebuilt/tool_node.ts | 18 +++--- libs/langgraph/src/tests/prebuilt.test.ts | 69 ++++++++++++++++++++++- 2 files changed, 76 insertions(+), 11 deletions(-) diff --git a/libs/langgraph/src/prebuilt/tool_node.ts b/libs/langgraph/src/prebuilt/tool_node.ts index 9cb6bf7e6..5acb7c2e9 100644 --- a/libs/langgraph/src/prebuilt/tool_node.ts +++ b/libs/langgraph/src/prebuilt/tool_node.ts @@ -16,16 +16,14 @@ export type ToolNodeOptions = { handleToolErrors?: boolean; }; -export class ToolNode< - T extends BaseMessage[] | typeof MessagesAnnotation.State -> extends RunnableCallable { - /** - A node that runs the tools requested in the last AIMessage. It can be used - either in StateGraph with a "messages" key or in MessageGraph. If multiple - tool calls are requested, they will be run in parallel. The output will be - a list of ToolMessages, one for each tool call. - */ - +/** + * A node that runs the tools requested in the last AIMessage. It can be used + * either in StateGraph with a "messages" key or in MessageGraph. If multiple + * tool calls are requested, they will be run in parallel. The output will be + * a list of ToolMessages, one for each tool call. + */ +// eslint-disable-next-line @typescript-eslint/no-explicit-any +export class ToolNode extends RunnableCallable { tools: (StructuredToolInterface | RunnableToolLike)[]; handleToolErrors = true; diff --git a/libs/langgraph/src/tests/prebuilt.test.ts b/libs/langgraph/src/tests/prebuilt.test.ts index ccfd3cd9c..2c77cc9f9 100644 --- a/libs/langgraph/src/tests/prebuilt.test.ts +++ b/libs/langgraph/src/tests/prebuilt.test.ts @@ -2,11 +2,12 @@ /* eslint-disable no-param-reassign */ import { beforeAll, describe, expect, it } from "@jest/globals"; import { PromptTemplate } from "@langchain/core/prompts"; -import { StructuredTool, Tool } from "@langchain/core/tools"; +import { StructuredTool, tool, Tool } from "@langchain/core/tools"; import { FakeStreamingLLM } from "@langchain/core/utils/testing"; import { AIMessage, + BaseMessage, HumanMessage, SystemMessage, ToolMessage, @@ -19,6 +20,7 @@ import { createAgentExecutor, createReactAgent, } from "../prebuilt/index.js"; +import { Annotation, messagesStateReducer, StateGraph } from "../web.js"; // Tracing slows down the tests beforeAll(() => { @@ -492,4 +494,69 @@ describe("ToolNode", () => { `Error: Tool "badtool" not found.\n Please fix your mistakes.` ); }); + + it("Should work in a state graph", async () => { + const AgentAnnotation = Annotation.Root({ + messages: Annotation({ + reducer: messagesStateReducer, + default: () => [], + }), + prop2: Annotation, + }); + + const weatherTool = tool( + async ({ query }) => { + // This is a placeholder for the actual implementation + if ( + query.toLowerCase().includes("sf") || + query.toLowerCase().includes("san francisco") + ) { + return "It's 60 degrees and foggy."; + } + return "It's 90 degrees and sunny."; + }, + { + name: "weather", + description: "Call to get the current weather for a location.", + schema: z.object({ + query: z.string().describe("The query to use in your search."), + }), + } + ); + + const graph = new StateGraph(AgentAnnotation) + .addNode("tools", new ToolNode([weatherTool])) + .addEdge("__start__", "tools") + .addEdge("tools", "__end__") + .compile(); + const aiMessage = new AIMessage({ + content: "", + tool_calls: [ + { + id: "call_1234", + args: { + query: "SF", + }, + name: "weather", + type: "tool_call", + }, + ], + }); + const res = await graph.invoke({ + messages: [aiMessage], + }); + const toolMessageId = res.messages[1].id; + expect(res).toEqual({ + messages: [ + aiMessage, + expect.objectContaining({ + id: toolMessageId, + name: "weather", + artifact: undefined, + content: "It's 60 degrees and foggy.", + tool_call_id: "call_1234", + }), + ], + }); + }); }); From 4fc252e84fb2a4f56706f5015c24777649e59542 Mon Sep 17 00:00:00 2001 From: jacoblee93 Date: Tue, 27 Aug 2024 18:11:31 -0700 Subject: [PATCH 2/7] Narrow type --- libs/langgraph/src/prebuilt/tool_node.ts | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/libs/langgraph/src/prebuilt/tool_node.ts b/libs/langgraph/src/prebuilt/tool_node.ts index 5acb7c2e9..edd1505d4 100644 --- a/libs/langgraph/src/prebuilt/tool_node.ts +++ b/libs/langgraph/src/prebuilt/tool_node.ts @@ -22,8 +22,9 @@ export type ToolNodeOptions = { * tool calls are requested, they will be run in parallel. The output will be * a list of ToolMessages, one for each tool call. */ -// eslint-disable-next-line @typescript-eslint/no-explicit-any -export class ToolNode extends RunnableCallable { +export class ToolNode< + T extends BaseMessage[] | Partial +> extends RunnableCallable { tools: (StructuredToolInterface | RunnableToolLike)[]; handleToolErrors = true; @@ -38,15 +39,12 @@ export class ToolNode extends RunnableCallable { this.handleToolErrors = handleToolErrors ?? this.handleToolErrors; } - private async run( - input: BaseMessage[] | typeof MessagesAnnotation.State, - config: RunnableConfig - ): Promise { + private async run(input: T, config: RunnableConfig): Promise { const message = Array.isArray(input) ? input[input.length - 1] - : input.messages[input.messages.length - 1]; + : input.messages?.[input.messages.length - 1]; - if (message._getType() !== "ai") { + if (message?._getType() !== "ai") { throw new Error("ToolNode only accepts AIMessages as input."); } @@ -85,7 +83,7 @@ export class ToolNode extends RunnableCallable { }) ?? [] ); - return Array.isArray(input) ? outputs : { messages: outputs }; + return (Array.isArray(input) ? outputs : { messages: outputs }) as T; } } From 93997acb6f155ac97d00e267207e4a4704d842b2 Mon Sep 17 00:00:00 2001 From: jacoblee93 Date: Tue, 27 Aug 2024 18:27:54 -0700 Subject: [PATCH 3/7] Adds more advanced test --- examples/quickstart.ipynb | 4 +-- libs/langgraph/src/tests/prebuilt.test.ts | 41 +++++++++++++++++++---- 2 files changed, 37 insertions(+), 8 deletions(-) diff --git a/examples/quickstart.ipynb b/examples/quickstart.ipynb index c8f28d843..95c88ab0b 100644 --- a/examples/quickstart.ipynb +++ b/examples/quickstart.ipynb @@ -187,7 +187,7 @@ "// Define the tools for the agent to use\n", "const tools = [new TavilySearchResults({ maxResults: 3 })];\n", "\n", - "const toolNode = new ToolNode(tools);\n", + "const toolNode = new ToolNode(tools);\n", "\n", "const model = new ChatOpenAI({ temperature: 0 }).bindTools(tools);\n", "\n", @@ -282,4 +282,4 @@ }, "nbformat": 4, "nbformat_minor": 2 -} \ No newline at end of file +} diff --git a/libs/langgraph/src/tests/prebuilt.test.ts b/libs/langgraph/src/tests/prebuilt.test.ts index 2c77cc9f9..a15dddd2b 100644 --- a/libs/langgraph/src/tests/prebuilt.test.ts +++ b/libs/langgraph/src/tests/prebuilt.test.ts @@ -524,11 +524,6 @@ describe("ToolNode", () => { } ); - const graph = new StateGraph(AgentAnnotation) - .addNode("tools", new ToolNode([weatherTool])) - .addEdge("__start__", "tools") - .addEdge("tools", "__end__") - .compile(); const aiMessage = new AIMessage({ content: "", tool_calls: [ @@ -542,8 +537,41 @@ describe("ToolNode", () => { }, ], }); + + const aiMessage2 = new AIMessage({ + content: "FOO", + }); + + async function callModel(state: typeof AgentAnnotation.State) { + // We return a list, because this will get added to the existing list + if (state.messages.includes(aiMessage)) { + return { messages: [aiMessage2] }; + } + return { messages: [aiMessage] }; + } + + function shouldContinue( + { messages }: typeof AgentAnnotation.State + ): "tools" | "__end__" { + const lastMessage: AIMessage = messages[messages.length - 1]; + + // If the LLM makes a tool call, then we route to the "tools" node + if ((lastMessage.tool_calls?.length ?? 0) > 0) { + return "tools"; + } + // Otherwise, we stop (reply to the user) + return "__end__"; + } + + const graph = new StateGraph(AgentAnnotation) + .addNode("agent", callModel) + .addNode("tools", new ToolNode([weatherTool])) + .addEdge("__start__", "agent") + .addConditionalEdges("agent", shouldContinue) + .addEdge("tools", "agent") + .compile(); const res = await graph.invoke({ - messages: [aiMessage], + messages: [], }); const toolMessageId = res.messages[1].id; expect(res).toEqual({ @@ -556,6 +584,7 @@ describe("ToolNode", () => { content: "It's 60 degrees and foggy.", tool_call_id: "call_1234", }), + aiMessage2, ], }); }); From 425dab6fa2b21511085840f2108d5d4c85a34009 Mon Sep 17 00:00:00 2001 From: jacoblee93 Date: Tue, 27 Aug 2024 18:30:13 -0700 Subject: [PATCH 4/7] Format --- libs/langgraph/src/tests/prebuilt.test.ts | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/libs/langgraph/src/tests/prebuilt.test.ts b/libs/langgraph/src/tests/prebuilt.test.ts index a15dddd2b..9de83c83a 100644 --- a/libs/langgraph/src/tests/prebuilt.test.ts +++ b/libs/langgraph/src/tests/prebuilt.test.ts @@ -550,9 +550,9 @@ describe("ToolNode", () => { return { messages: [aiMessage] }; } - function shouldContinue( - { messages }: typeof AgentAnnotation.State - ): "tools" | "__end__" { + function shouldContinue({ + messages, + }: typeof AgentAnnotation.State): "tools" | "__end__" { const lastMessage: AIMessage = messages[messages.length - 1]; // If the LLM makes a tool call, then we route to the "tools" node From 476782ce371cf284ad1508140e047626f52126b4 Mon Sep 17 00:00:00 2001 From: bracesproul Date: Wed, 28 Aug 2024 09:50:56 -0700 Subject: [PATCH 5/7] Update validate ntbk --- .github/workflows/validate-new-notebooks.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/validate-new-notebooks.yml b/.github/workflows/validate-new-notebooks.yml index f551a1784..3375d7f88 100644 --- a/.github/workflows/validate-new-notebooks.yml +++ b/.github/workflows/validate-new-notebooks.yml @@ -20,6 +20,7 @@ on: paths: - 'docs/docs/**' - 'examples/**' + - '*.ipynb' - 'deno.json' workflow_dispatch: From c9001554cbd913df3a658f5bdd24427b9590fb5a Mon Sep 17 00:00:00 2001 From: bracesproul Date: Wed, 28 Aug 2024 09:51:39 -0700 Subject: [PATCH 6/7] drop branches main --- .github/workflows/validate-new-notebooks.yml | 2 -- 1 file changed, 2 deletions(-) diff --git a/.github/workflows/validate-new-notebooks.yml b/.github/workflows/validate-new-notebooks.yml index 3375d7f88..3a4883372 100644 --- a/.github/workflows/validate-new-notebooks.yml +++ b/.github/workflows/validate-new-notebooks.yml @@ -15,8 +15,6 @@ on: branches: - main pull_request: - branches: - - main paths: - 'docs/docs/**' - 'examples/**' From dc312dfe86a579faa4ffdde4e2b8a6ff1ce31863 Mon Sep 17 00:00:00 2001 From: bracesproul Date: Wed, 28 Aug 2024 09:54:15 -0700 Subject: [PATCH 7/7] update paths --- .github/workflows/validate-new-notebooks.yml | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/.github/workflows/validate-new-notebooks.yml b/.github/workflows/validate-new-notebooks.yml index 3a4883372..ddb9687a2 100644 --- a/.github/workflows/validate-new-notebooks.yml +++ b/.github/workflows/validate-new-notebooks.yml @@ -12,13 +12,12 @@ concurrency: on: push: - branches: - - main + branches: ["main"] pull_request: paths: - 'docs/docs/**' - 'examples/**' - - '*.ipynb' + - 'examples/*' - 'deno.json' workflow_dispatch: