From 93997acb6f155ac97d00e267207e4a4704d842b2 Mon Sep 17 00:00:00 2001 From: jacoblee93 Date: Tue, 27 Aug 2024 18:27:54 -0700 Subject: [PATCH] Adds more advanced test --- examples/quickstart.ipynb | 4 +-- libs/langgraph/src/tests/prebuilt.test.ts | 41 +++++++++++++++++++---- 2 files changed, 37 insertions(+), 8 deletions(-) diff --git a/examples/quickstart.ipynb b/examples/quickstart.ipynb index c8f28d843..95c88ab0b 100644 --- a/examples/quickstart.ipynb +++ b/examples/quickstart.ipynb @@ -187,7 +187,7 @@ "// Define the tools for the agent to use\n", "const tools = [new TavilySearchResults({ maxResults: 3 })];\n", "\n", - "const toolNode = new ToolNode(tools);\n", + "const toolNode = new ToolNode(tools);\n", "\n", "const model = new ChatOpenAI({ temperature: 0 }).bindTools(tools);\n", "\n", @@ -282,4 +282,4 @@ }, "nbformat": 4, "nbformat_minor": 2 -} \ No newline at end of file +} diff --git a/libs/langgraph/src/tests/prebuilt.test.ts b/libs/langgraph/src/tests/prebuilt.test.ts index 2c77cc9f9..a15dddd2b 100644 --- a/libs/langgraph/src/tests/prebuilt.test.ts +++ b/libs/langgraph/src/tests/prebuilt.test.ts @@ -524,11 +524,6 @@ describe("ToolNode", () => { } ); - const graph = new StateGraph(AgentAnnotation) - .addNode("tools", new ToolNode([weatherTool])) - .addEdge("__start__", "tools") - .addEdge("tools", "__end__") - .compile(); const aiMessage = new AIMessage({ content: "", tool_calls: [ @@ -542,8 +537,41 @@ describe("ToolNode", () => { }, ], }); + + const aiMessage2 = new AIMessage({ + content: "FOO", + }); + + async function callModel(state: typeof AgentAnnotation.State) { + // We return a list, because this will get added to the existing list + if (state.messages.includes(aiMessage)) { + return { messages: [aiMessage2] }; + } + return { messages: [aiMessage] }; + } + + function shouldContinue( + { messages }: typeof AgentAnnotation.State + ): "tools" | "__end__" { + const lastMessage: AIMessage = messages[messages.length - 1]; + + // If the LLM makes a tool call, then we route to the "tools" node + if ((lastMessage.tool_calls?.length ?? 0) > 0) { + return "tools"; + } + // Otherwise, we stop (reply to the user) + return "__end__"; + } + + const graph = new StateGraph(AgentAnnotation) + .addNode("agent", callModel) + .addNode("tools", new ToolNode([weatherTool])) + .addEdge("__start__", "agent") + .addConditionalEdges("agent", shouldContinue) + .addEdge("tools", "agent") + .compile(); const res = await graph.invoke({ - messages: [aiMessage], + messages: [], }); const toolMessageId = res.messages[1].id; expect(res).toEqual({ @@ -556,6 +584,7 @@ describe("ToolNode", () => { content: "It's 60 degrees and foggy.", tool_call_id: "call_1234", }), + aiMessage2, ], }); });