diff --git a/libs/langgraph/src/prebuilt/tool_node.ts b/libs/langgraph/src/prebuilt/tool_node.ts index 9cb6bf7e6..5acb7c2e9 100644 --- a/libs/langgraph/src/prebuilt/tool_node.ts +++ b/libs/langgraph/src/prebuilt/tool_node.ts @@ -16,16 +16,14 @@ export type ToolNodeOptions = { handleToolErrors?: boolean; }; -export class ToolNode< - T extends BaseMessage[] | typeof MessagesAnnotation.State -> extends RunnableCallable { - /** - A node that runs the tools requested in the last AIMessage. It can be used - either in StateGraph with a "messages" key or in MessageGraph. If multiple - tool calls are requested, they will be run in parallel. The output will be - a list of ToolMessages, one for each tool call. - */ - +/** + * A node that runs the tools requested in the last AIMessage. It can be used + * either in StateGraph with a "messages" key or in MessageGraph. If multiple + * tool calls are requested, they will be run in parallel. The output will be + * a list of ToolMessages, one for each tool call. + */ +// eslint-disable-next-line @typescript-eslint/no-explicit-any +export class ToolNode extends RunnableCallable { tools: (StructuredToolInterface | RunnableToolLike)[]; handleToolErrors = true; diff --git a/libs/langgraph/src/tests/prebuilt.test.ts b/libs/langgraph/src/tests/prebuilt.test.ts index ccfd3cd9c..2c77cc9f9 100644 --- a/libs/langgraph/src/tests/prebuilt.test.ts +++ b/libs/langgraph/src/tests/prebuilt.test.ts @@ -2,11 +2,12 @@ /* eslint-disable no-param-reassign */ import { beforeAll, describe, expect, it } from "@jest/globals"; import { PromptTemplate } from "@langchain/core/prompts"; -import { StructuredTool, Tool } from "@langchain/core/tools"; +import { StructuredTool, tool, Tool } from "@langchain/core/tools"; import { FakeStreamingLLM } from "@langchain/core/utils/testing"; import { AIMessage, + BaseMessage, HumanMessage, SystemMessage, ToolMessage, @@ -19,6 +20,7 @@ import { createAgentExecutor, createReactAgent, } from "../prebuilt/index.js"; +import { Annotation, messagesStateReducer, StateGraph } from "../web.js"; // Tracing slows down the tests beforeAll(() => { @@ -492,4 +494,69 @@ describe("ToolNode", () => { `Error: Tool "badtool" not found.\n Please fix your mistakes.` ); }); + + it("Should work in a state graph", async () => { + const AgentAnnotation = Annotation.Root({ + messages: Annotation({ + reducer: messagesStateReducer, + default: () => [], + }), + prop2: Annotation, + }); + + const weatherTool = tool( + async ({ query }) => { + // This is a placeholder for the actual implementation + if ( + query.toLowerCase().includes("sf") || + query.toLowerCase().includes("san francisco") + ) { + return "It's 60 degrees and foggy."; + } + return "It's 90 degrees and sunny."; + }, + { + name: "weather", + description: "Call to get the current weather for a location.", + schema: z.object({ + query: z.string().describe("The query to use in your search."), + }), + } + ); + + const graph = new StateGraph(AgentAnnotation) + .addNode("tools", new ToolNode([weatherTool])) + .addEdge("__start__", "tools") + .addEdge("tools", "__end__") + .compile(); + const aiMessage = new AIMessage({ + content: "", + tool_calls: [ + { + id: "call_1234", + args: { + query: "SF", + }, + name: "weather", + type: "tool_call", + }, + ], + }); + const res = await graph.invoke({ + messages: [aiMessage], + }); + const toolMessageId = res.messages[1].id; + expect(res).toEqual({ + messages: [ + aiMessage, + expect.objectContaining({ + id: toolMessageId, + name: "weather", + artifact: undefined, + content: "It's 60 degrees and foggy.", + tool_call_id: "call_1234", + }), + ], + }); + }); });