Skip to content

Commit

Permalink
Relax tool node generic requiremenet
Browse files Browse the repository at this point in the history
  • Loading branch information
jacoblee93 committed Aug 28, 2024
1 parent 120c048 commit dd240cf
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 11 deletions.
18 changes: 8 additions & 10 deletions libs/langgraph/src/prebuilt/tool_node.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,14 @@ export type ToolNodeOptions = {
handleToolErrors?: boolean;
};

export class ToolNode<
T extends BaseMessage[] | typeof MessagesAnnotation.State
> extends RunnableCallable<T, T> {
/**
A node that runs the tools requested in the last AIMessage. It can be used
either in StateGraph with a "messages" key or in MessageGraph. If multiple
tool calls are requested, they will be run in parallel. The output will be
a list of ToolMessages, one for each tool call.
*/

/**
* A node that runs the tools requested in the last AIMessage. It can be used
* either in StateGraph with a "messages" key or in MessageGraph. If multiple
* tool calls are requested, they will be run in parallel. The output will be
* a list of ToolMessages, one for each tool call.
*/
// eslint-disable-next-line @typescript-eslint/no-explicit-any
export class ToolNode<T = any> extends RunnableCallable<T, T> {
tools: (StructuredToolInterface | RunnableToolLike)[];

handleToolErrors = true;
Expand Down
69 changes: 68 additions & 1 deletion libs/langgraph/src/tests/prebuilt.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@
/* eslint-disable no-param-reassign */
import { beforeAll, describe, expect, it } from "@jest/globals";
import { PromptTemplate } from "@langchain/core/prompts";
import { StructuredTool, Tool } from "@langchain/core/tools";
import { StructuredTool, tool, Tool } from "@langchain/core/tools";
import { FakeStreamingLLM } from "@langchain/core/utils/testing";

import {
AIMessage,
BaseMessage,
HumanMessage,
SystemMessage,
ToolMessage,
Expand All @@ -19,6 +20,7 @@ import {
createAgentExecutor,
createReactAgent,
} from "../prebuilt/index.js";
import { Annotation, messagesStateReducer, StateGraph } from "../web.js";

// Tracing slows down the tests
beforeAll(() => {
Expand Down Expand Up @@ -492,4 +494,69 @@ describe("ToolNode", () => {
`Error: Tool "badtool" not found.\n Please fix your mistakes.`
);
});

it("Should work in a state graph", async () => {
const AgentAnnotation = Annotation.Root({
messages: Annotation<BaseMessage[]>({
reducer: messagesStateReducer,
default: () => [],
}),
prop2: Annotation<string>,
});

const weatherTool = tool(
async ({ query }) => {
// This is a placeholder for the actual implementation
if (
query.toLowerCase().includes("sf") ||
query.toLowerCase().includes("san francisco")
) {
return "It's 60 degrees and foggy.";
}
return "It's 90 degrees and sunny.";
},
{
name: "weather",
description: "Call to get the current weather for a location.",
schema: z.object({
query: z.string().describe("The query to use in your search."),
}),
}
);

const graph = new StateGraph(AgentAnnotation)
.addNode("tools", new ToolNode([weatherTool]))
.addEdge("__start__", "tools")
.addEdge("tools", "__end__")
.compile();
const aiMessage = new AIMessage({
content: "",
tool_calls: [
{
id: "call_1234",
args: {
query: "SF",
},
name: "weather",
type: "tool_call",
},
],
});
const res = await graph.invoke({
messages: [aiMessage],
});
const toolMessageId = res.messages[1].id;
expect(res).toEqual({
messages: [
aiMessage,
expect.objectContaining({
id: toolMessageId,
name: "weather",
artifact: undefined,
content: "It's 60 degrees and foggy.",
tool_call_id: "call_1234",
}),
],
});
});
});

0 comments on commit dd240cf

Please sign in to comment.