diff --git a/examples/quickstart.ipynb b/examples/quickstart.ipynb
index 879c0223a..95c88ab0b 100644
--- a/examples/quickstart.ipynb
+++ b/examples/quickstart.ipynb
@@ -187,7 +187,7 @@
     "// Define the tools for the agent to use\n",
     "const tools = [new TavilySearchResults({ maxResults: 3 })];\n",
     "\n",
-    "const toolNode = new ToolNode<typeof MessagesAnnotation.State>(tools);\n",
+    "const toolNode = new ToolNode(tools);\n",
     "\n",
     "const model = new ChatOpenAI({ temperature: 0 }).bindTools(tools);\n",
     "\n",
diff --git a/libs/langgraph/src/prebuilt/tool_node.ts b/libs/langgraph/src/prebuilt/tool_node.ts
index 9cb6bf7e6..edd1505d4 100644
--- a/libs/langgraph/src/prebuilt/tool_node.ts
+++ b/libs/langgraph/src/prebuilt/tool_node.ts
@@ -16,16 +16,15 @@ export type ToolNodeOptions = {
   handleToolErrors?: boolean;
 };
 
+/**
+ * A node that runs the tools requested in the last AIMessage. It can be used
+ * either in StateGraph with a "messages" key or in MessageGraph. If multiple
+ * tool calls are requested, they will be run in parallel. The output will be
+ * a list of ToolMessages, one for each tool call.
+ */
 export class ToolNode<
-  T extends BaseMessage[] | typeof MessagesAnnotation.State
+  T extends BaseMessage[] | Partial<typeof MessagesAnnotation.State>
 > extends RunnableCallable<T, T> {
-  /**
-    A node that runs the tools requested in the last AIMessage. It can be used
-    either in StateGraph with a "messages" key or in MessageGraph. If multiple
-    tool calls are requested, they will be run in parallel. The output will be
-    a list of ToolMessages, one for each tool call.
-  */
-
   tools: (StructuredToolInterface | RunnableToolLike)[];
 
   handleToolErrors = true;
@@ -40,15 +39,12 @@ export class ToolNode<
     this.handleToolErrors = handleToolErrors ?? this.handleToolErrors;
   }
 
-  private async run(
-    input: BaseMessage[] | typeof MessagesAnnotation.State,
-    config: RunnableConfig
-  ): Promise<BaseMessage[] | typeof MessagesAnnotation.State> {
+  private async run(input: T, config: RunnableConfig): Promise<T> {
     const message = Array.isArray(input)
       ? input[input.length - 1]
-      : input.messages[input.messages.length - 1];
+      : input.messages?.[input.messages.length - 1];
 
-    if (message._getType() !== "ai") {
+    if (message?._getType() !== "ai") {
       throw new Error("ToolNode only accepts AIMessages as input.");
     }
 
@@ -87,7 +83,7 @@
       }) ?? []
     );
 
-    return Array.isArray(input) ? outputs : { messages: outputs };
+    return (Array.isArray(input) ? outputs : { messages: outputs }) as T;
   }
 }
diff --git a/libs/langgraph/src/tests/prebuilt.test.ts b/libs/langgraph/src/tests/prebuilt.test.ts
index ccfd3cd9c..9de83c83a 100644
--- a/libs/langgraph/src/tests/prebuilt.test.ts
+++ b/libs/langgraph/src/tests/prebuilt.test.ts
@@ -2,11 +2,12 @@
 /* eslint-disable no-param-reassign */
 import { beforeAll, describe, expect, it } from "@jest/globals";
 import { PromptTemplate } from "@langchain/core/prompts";
-import { StructuredTool, Tool } from "@langchain/core/tools";
+import { StructuredTool, tool, Tool } from "@langchain/core/tools";
 import { FakeStreamingLLM } from "@langchain/core/utils/testing";
 import {
   AIMessage,
+  BaseMessage,
   HumanMessage,
   SystemMessage,
   ToolMessage,
@@ -19,6 +20,7 @@ import {
   createAgentExecutor,
   createReactAgent,
 } from "../prebuilt/index.js";
+import { Annotation, messagesStateReducer, StateGraph } from "../web.js";
 
 // Tracing slows down the tests
 beforeAll(() => {
@@ -492,4 +494,98 @@ describe("ToolNode", () => {
       `Error: Tool "badtool" not found.\n Please fix your mistakes.`
     );
   });
+
+  it("Should work in a state graph", async () => {
+    const AgentAnnotation = Annotation.Root({
+      messages: Annotation<BaseMessage[]>({
+        reducer: messagesStateReducer,
+        default: () => [],
+      }),
+      prop2: Annotation<string>,
+    });
+
+    const weatherTool = tool(
+      async ({ query }) => {
+        // This is a placeholder for the actual implementation
+        if (
+          query.toLowerCase().includes("sf") ||
+          query.toLowerCase().includes("san francisco")
+        ) {
+          return "It's 60 degrees and foggy.";
+        }
+        return "It's 90 degrees and sunny.";
+      },
+      {
+        name: "weather",
+        description: "Call to get the current weather for a location.",
+        schema: z.object({
+          query: z.string().describe("The query to use in your search."),
+        }),
+      }
+    );
+
+    const aiMessage = new AIMessage({
+      content: "",
+      tool_calls: [
+        {
+          id: "call_1234",
+          args: {
+            query: "SF",
+          },
+          name: "weather",
+          type: "tool_call",
+        },
+      ],
+    });
+
+    const aiMessage2 = new AIMessage({
+      content: "FOO",
+    });
+
+    async function callModel(state: typeof AgentAnnotation.State) {
+      // We return a list, because this will get added to the existing list
+      if (state.messages.includes(aiMessage)) {
+        return { messages: [aiMessage2] };
+      }
+      return { messages: [aiMessage] };
+    }
+
+    function shouldContinue({
+      messages,
+    }: typeof AgentAnnotation.State): "tools" | "__end__" {
+      const lastMessage: AIMessage = messages[messages.length - 1];
+
+      // If the LLM makes a tool call, then we route to the "tools" node
+      if ((lastMessage.tool_calls?.length ?? 0) > 0) {
+        return "tools";
+      }
+      // Otherwise, we stop (reply to the user)
+      return "__end__";
+    }
+
+    const graph = new StateGraph(AgentAnnotation)
+      .addNode("agent", callModel)
+      .addNode("tools", new ToolNode([weatherTool]))
+      .addEdge("__start__", "agent")
+      .addConditionalEdges("agent", shouldContinue)
+      .addEdge("tools", "agent")
+      .compile();
+    const res = await graph.invoke({
+      messages: [],
+    });
+    const toolMessageId = res.messages[1].id;
+    expect(res).toEqual({
+      messages: [
+        aiMessage,
+        expect.objectContaining({
+          id: toolMessageId,
+          name: "weather",
+          artifact: undefined,
+          content: "It's 60 degrees and foggy.",
+          tool_call_id: "call_1234",
+        }),
+        aiMessage2,
+      ],
+    });
+  });
 });
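For reviewers who want to see the change from the caller's side, here is a minimal standalone sketch (not part of the patch) of what the relaxed generic enables: `ToolNode` can now be constructed without an explicit type argument and invoked with a plain partial state object. The `echo` tool and the `@langchain/langgraph` / `@langchain/langgraph/prebuilt` package entrypoints are illustrative assumptions, not code from this PR.

```typescript
// Illustrative sketch only — not part of this diff. Assumes the published
// "@langchain/core" and "@langchain/langgraph/prebuilt" entrypoints; the
// "echo" tool below is a made-up stub.
import { AIMessage } from "@langchain/core/messages";
import { tool } from "@langchain/core/tools";
import { z } from "zod";
import { ToolNode } from "@langchain/langgraph/prebuilt";

// A stub tool standing in for a real search/weather tool.
const echoTool = tool(async ({ text }) => `echo: ${text}`, {
  name: "echo",
  description: "Echo the input text back.",
  schema: z.object({ text: z.string() }),
});

// Before this change, callers typically wrote
// `new ToolNode<typeof MessagesAnnotation.State>(tools)`. The widened bound
// `T extends BaseMessage[] | Partial<typeof MessagesAnnotation.State>`
// lets inference do the work instead.
const toolNode = new ToolNode([echoTool]);

// Because `run` now reads `input.messages?.[...]` and casts its result
// `as T`, a partial state object carrying only a `messages` key round-trips:
const result = await toolNode.invoke({
  messages: [
    new AIMessage({
      content: "",
      tool_calls: [
        { id: "call_1", name: "echo", args: { text: "hi" }, type: "tool_call" },
      ],
    }),
  ],
});
// result.messages is a list of ToolMessages, one per tool call, e.g.
// [ToolMessage { content: "echo: hi", tool_call_id: "call_1", ... }]
```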