langgraph[patch]: Relax tool node generic requirement #404

Merged · 8 commits · Aug 28, 2024
2 changes: 1 addition & 1 deletion examples/quickstart.ipynb
@@ -187,7 +187,7 @@
     "// Define the tools for the agent to use\n",
     "const tools = [new TavilySearchResults({ maxResults: 3 })];\n",
     "\n",
-    "const toolNode = new ToolNode<typeof GraphAnnotation.State>(tools);\n",
+    "const toolNode = new ToolNode(tools);\n",
     "\n",
     "const model = new ChatOpenAI({ temperature: 0 }).bindTools(tools);\n",
     "\n",
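For context, the change above is the headline of this PR: the explicit state generic on `ToolNode` becomes optional. A minimal sketch of both forms, assuming `@langchain/langgraph` and `@langchain/community` are installed and `GraphAnnotation` is the quickstart's state annotation:

```typescript
import { ToolNode } from "@langchain/langgraph/prebuilt";
import { TavilySearchResults } from "@langchain/community/tools/tavily_search";

const tools = [new TavilySearchResults({ maxResults: 3 })];

// After this change the type argument can be omitted: the generic falls
// back to its constraint, which accepts both supported input shapes.
const toolNode = new ToolNode(tools);

// The explicit form from the old quickstart should still compile for
// callers who want the state pinned:
// const typedToolNode = new ToolNode<typeof GraphAnnotation.State>(tools);
```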
26 changes: 11 additions & 15 deletions libs/langgraph/src/prebuilt/tool_node.ts
@@ -16,16 +16,15 @@ export type ToolNodeOptions = {
   handleToolErrors?: boolean;
 };
 
-/**
- * A node that runs the tools requested in the last AIMessage. It can be used
- * either in StateGraph with a "messages" key or in MessageGraph. If multiple
- * tool calls are requested, they will be run in parallel. The output will be
- * a list of ToolMessages, one for each tool call.
- */
 export class ToolNode<
-  T extends BaseMessage[] | typeof MessagesAnnotation.State
+  T extends BaseMessage[] | Partial<typeof MessagesAnnotation.State>
 > extends RunnableCallable<T, T> {
+  /**
+    A node that runs the tools requested in the last AIMessage. It can be used
+    either in StateGraph with a "messages" key or in MessageGraph. If multiple
+    tool calls are requested, they will be run in parallel. The output will be
+    a list of ToolMessages, one for each tool call.
+  */
+
   tools: (StructuredToolInterface | RunnableToolLike)[];
 
   handleToolErrors = true;
@@ -40,15 +39,12 @@ export class ToolNode<
     this.handleToolErrors = handleToolErrors ?? this.handleToolErrors;
   }
 
-  private async run(
-    input: BaseMessage[] | typeof MessagesAnnotation.State,
-    config: RunnableConfig
-  ): Promise<BaseMessage[] | typeof MessagesAnnotation.State> {
+  private async run(input: T, config: RunnableConfig): Promise<T> {
     const message = Array.isArray(input)
       ? input[input.length - 1]
-      : input.messages[input.messages.length - 1];
+      : input.messages?.[input.messages.length - 1];
 
-    if (message._getType() !== "ai") {
+    if (message?._getType() !== "ai") {
       throw new Error("ToolNode only accepts AIMessages as input.");
     }
@@ -87,7 +83,7 @@
       }) ?? []
     );
 
-    return Array.isArray(input) ? outputs : { messages: outputs };
+    return (Array.isArray(input) ? outputs : { messages: outputs }) as T;
   }
 }

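To illustrate what the relaxed constraint and the `run` changes above permit, here is a rough sketch of the two input shapes the docstring describes. The `calculator` tool is hypothetical, defined only for this example:

```typescript
import { z } from "zod";
import { tool } from "@langchain/core/tools";
import { AIMessage } from "@langchain/core/messages";
import { ToolNode } from "@langchain/langgraph/prebuilt";

// Hypothetical tool, used only to exercise ToolNode in this sketch.
const calculator = tool(async ({ a, b }) => String(a + b), {
  name: "calculator",
  description: "Add two numbers.",
  schema: z.object({ a: z.number(), b: z.number() }),
});

const node = new ToolNode([calculator]);

// An AIMessage carrying one tool call, shaped the way a model would emit it.
const aiMessage = new AIMessage({
  content: "",
  tool_calls: [
    { id: "call_1", name: "calculator", args: { a: 1, b: 2 }, type: "tool_call" },
  ],
});

// MessageGraph style: a bare message list in, a list of ToolMessages out.
const asList = await node.invoke([aiMessage]);

// StateGraph style: an object with a "messages" key in, { messages } out.
const asState = await node.invoke({ messages: [aiMessage] });
```

The move from `typeof MessagesAnnotation.State` to `Partial<typeof MessagesAnnotation.State>` appears to be the key relaxation: StateGraph nodes return partial state updates, so under the old constraint, wiring a `ToolNode` into a graph with a custom annotation (extra channels such as `prop2` in the test below) required spelling out the generic, as the quickstart previously did.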
98 changes: 97 additions & 1 deletion libs/langgraph/src/tests/prebuilt.test.ts
@@ -2,11 +2,12 @@
 /* eslint-disable no-param-reassign */
 import { beforeAll, describe, expect, it } from "@jest/globals";
 import { PromptTemplate } from "@langchain/core/prompts";
-import { StructuredTool, Tool } from "@langchain/core/tools";
+import { StructuredTool, tool, Tool } from "@langchain/core/tools";
 import { FakeStreamingLLM } from "@langchain/core/utils/testing";
 
 import {
   AIMessage,
+  BaseMessage,
   HumanMessage,
   SystemMessage,
   ToolMessage,
@@ -19,6 +20,7 @@ import {
   createAgentExecutor,
   createReactAgent,
 } from "../prebuilt/index.js";
+import { Annotation, messagesStateReducer, StateGraph } from "../web.js";
 
 // Tracing slows down the tests
 beforeAll(() => {
@@ -492,4 +494,98 @@ describe("ToolNode", () => {
       `Error: Tool "badtool" not found.\n Please fix your mistakes.`
     );
   });
+
+  it("Should work in a state graph", async () => {
+    const AgentAnnotation = Annotation.Root({
+      messages: Annotation<BaseMessage[]>({
+        reducer: messagesStateReducer,
+        default: () => [],
+      }),
+      prop2: Annotation<string>,
+    });
+
+    const weatherTool = tool(
+      async ({ query }) => {
+        // This is a placeholder for the actual implementation
+        if (
+          query.toLowerCase().includes("sf") ||
+          query.toLowerCase().includes("san francisco")
+        ) {
+          return "It's 60 degrees and foggy.";
+        }
+        return "It's 90 degrees and sunny.";
+      },
+      {
+        name: "weather",
+        description: "Call to get the current weather for a location.",
+        schema: z.object({
+          query: z.string().describe("The query to use in your search."),
+        }),
+      }
+    );
+
+    const aiMessage = new AIMessage({
+      content: "",
+      tool_calls: [
+        {
+          id: "call_1234",
+          args: {
+            query: "SF",
+          },
+          name: "weather",
+          type: "tool_call",
+        },
+      ],
+    });
+
+    const aiMessage2 = new AIMessage({
+      content: "FOO",
+    });
+
+    async function callModel(state: typeof AgentAnnotation.State) {
+      // We return a list, because this will get added to the existing list
+      if (state.messages.includes(aiMessage)) {
+        return { messages: [aiMessage2] };
+      }
+      return { messages: [aiMessage] };
+    }
+
+    function shouldContinue({
+      messages,
+    }: typeof AgentAnnotation.State): "tools" | "__end__" {
+      const lastMessage: AIMessage = messages[messages.length - 1];
+
+      // If the LLM makes a tool call, then we route to the "tools" node
+      if ((lastMessage.tool_calls?.length ?? 0) > 0) {
+        return "tools";
+      }
+      // Otherwise, we stop (reply to the user)
+      return "__end__";
+    }
+
+    const graph = new StateGraph(AgentAnnotation)
+      .addNode("agent", callModel)
+      .addNode("tools", new ToolNode([weatherTool]))
+      .addEdge("__start__", "agent")
+      .addConditionalEdges("agent", shouldContinue)
+      .addEdge("tools", "agent")
+      .compile();
+    const res = await graph.invoke({
+      messages: [],
+    });
+    const toolMessageId = res.messages[1].id;
+    expect(res).toEqual({
+      messages: [
+        aiMessage,
+        expect.objectContaining({
+          id: toolMessageId,
+          name: "weather",
+          artifact: undefined,
+          content: "It's 60 degrees and foggy.",
+          tool_call_id: "call_1234",
+        }),
+        aiMessage2,
+      ],
+    });
+  });
 });