Skip to content

Commit

Permalink
Revert "langgraph[patch]: Relax tool node generic requirement (#404)"
Browse files Browse the repository at this point in the history
This reverts commit e298739.
  • Loading branch information
bracesproul authored Aug 28, 2024
1 parent e298739 commit 3340441
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 109 deletions.
2 changes: 1 addition & 1 deletion examples/quickstart.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@
"// Define the tools for the agent to use\n",
"const tools = [new TavilySearchResults({ maxResults: 3 })];\n",
"\n",
"const toolNode = new ToolNode(tools);\n",
"const toolNode = new ToolNode<typeof GraphAnnotation.State>(tools);\n",
"\n",
"const model = new ChatOpenAI({ temperature: 0 }).bindTools(tools);\n",
"\n",
Expand Down
26 changes: 15 additions & 11 deletions libs/langgraph/src/prebuilt/tool_node.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,16 @@ export type ToolNodeOptions = {
handleToolErrors?: boolean;
};

/**
* A node that runs the tools requested in the last AIMessage. It can be used
* either in StateGraph with a "messages" key or in MessageGraph. If multiple
* tool calls are requested, they will be run in parallel. The output will be
* a list of ToolMessages, one for each tool call.
*/
export class ToolNode<
T extends BaseMessage[] | Partial<typeof MessagesAnnotation.State>
T extends BaseMessage[] | typeof MessagesAnnotation.State
> extends RunnableCallable<T, T> {
/**
A node that runs the tools requested in the last AIMessage. It can be used
either in StateGraph with a "messages" key or in MessageGraph. If multiple
tool calls are requested, they will be run in parallel. The output will be
a list of ToolMessages, one for each tool call.
*/

tools: (StructuredToolInterface | RunnableToolLike)[];

handleToolErrors = true;
Expand All @@ -39,12 +40,15 @@ export class ToolNode<
this.handleToolErrors = handleToolErrors ?? this.handleToolErrors;
}

private async run(input: T, config: RunnableConfig): Promise<T> {
private async run(
input: BaseMessage[] | typeof MessagesAnnotation.State,
config: RunnableConfig
): Promise<BaseMessage[] | typeof MessagesAnnotation.State> {
const message = Array.isArray(input)
? input[input.length - 1]
: input.messages?.[input.messages.length - 1];
: input.messages[input.messages.length - 1];

if (message?._getType() !== "ai") {
if (message._getType() !== "ai") {
throw new Error("ToolNode only accepts AIMessages as input.");
}

Expand Down Expand Up @@ -83,7 +87,7 @@ export class ToolNode<
}) ?? []
);

return (Array.isArray(input) ? outputs : { messages: outputs }) as T;
return Array.isArray(input) ? outputs : { messages: outputs };
}
}

Expand Down
98 changes: 1 addition & 97 deletions libs/langgraph/src/tests/prebuilt.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@
/* eslint-disable no-param-reassign */
import { beforeAll, describe, expect, it } from "@jest/globals";
import { PromptTemplate } from "@langchain/core/prompts";
import { StructuredTool, tool, Tool } from "@langchain/core/tools";
import { StructuredTool, Tool } from "@langchain/core/tools";
import { FakeStreamingLLM } from "@langchain/core/utils/testing";

import {
AIMessage,
BaseMessage,
HumanMessage,
SystemMessage,
ToolMessage,
Expand All @@ -20,7 +19,6 @@ import {
createAgentExecutor,
createReactAgent,
} from "../prebuilt/index.js";
import { Annotation, messagesStateReducer, StateGraph } from "../web.js";

// Tracing slows down the tests
beforeAll(() => {
Expand Down Expand Up @@ -494,98 +492,4 @@ describe("ToolNode", () => {
`Error: Tool "badtool" not found.\n Please fix your mistakes.`
);
});

// Integration-style test: verifies that a bare `new ToolNode([...])` (no explicit
// state generic) can be wired into a StateGraph built from Annotation.Root, and
// that a full agent -> tools -> agent round trip produces the expected messages.
// NOTE(review): this is the test removed by the revert commit shown in this diff.
it("Should work in a state graph", async () => {
  // Agent state: a reduced "messages" channel plus an unrelated string channel
  // (prop2) to confirm ToolNode works against a state wider than just messages.
  const AgentAnnotation = Annotation.Root({
    messages: Annotation<BaseMessage[]>({
      reducer: messagesStateReducer,
      default: () => [],
    }),
    prop2: Annotation<string>,
  });

  // Stub tool: deterministic canned answers keyed off the query text.
  const weatherTool = tool(
    async ({ query }) => {
      // This is a placeholder for the actual implementation
      if (
        query.toLowerCase().includes("sf") ||
        query.toLowerCase().includes("san francisco")
      ) {
        return "It's 60 degrees and foggy.";
      }
      return "It's 90 degrees and sunny.";
    },
    {
      name: "weather",
      description: "Call to get the current weather for a location.",
      schema: z.object({
        query: z.string().describe("The query to use in your search."),
      }),
    }
  );

  // First scripted model turn: an AIMessage carrying a single tool call,
  // which shouldContinue() below routes to the "tools" node.
  const aiMessage = new AIMessage({
    content: "",
    tool_calls: [
      {
        id: "call_1234",
        args: {
          query: "SF",
        },
        name: "weather",
        type: "tool_call",
      },
    ],
  });

  // Second scripted model turn: no tool calls, so the graph terminates.
  const aiMessage2 = new AIMessage({
    content: "FOO",
  });

  // Fake model node: emits aiMessage on the first visit, aiMessage2 once the
  // first message is already in state. Relies on reference identity via
  // Array.prototype.includes, which the final expect() also depends on.
  async function callModel(state: typeof AgentAnnotation.State) {
    // We return a list, because this will get added to the existing list
    if (state.messages.includes(aiMessage)) {
      return { messages: [aiMessage2] };
    }
    return { messages: [aiMessage] };
  }

  // Conditional edge: route to "tools" iff the last AI message requested
  // at least one tool call; otherwise end the run.
  function shouldContinue({
    messages,
  }: typeof AgentAnnotation.State): "tools" | "__end__" {
    const lastMessage: AIMessage = messages[messages.length - 1];

    // If the LLM makes a tool call, then we route to the "tools" node
    if ((lastMessage.tool_calls?.length ?? 0) > 0) {
      return "tools";
    }
    // Otherwise, we stop (reply to the user)
    return "__end__";
  }

  // The point under test: ToolNode is constructed WITHOUT the state type
  // parameter and passed straight to addNode.
  const graph = new StateGraph(AgentAnnotation)
    .addNode("agent", callModel)
    .addNode("tools", new ToolNode([weatherTool]))
    .addEdge("__start__", "agent")
    .addConditionalEdges("agent", shouldContinue)
    .addEdge("tools", "agent")
    .compile();
  const res = await graph.invoke({
    messages: [],
  });
  // The ToolMessage id is runtime-generated, so capture it before comparing.
  const toolMessageId = res.messages[1].id;
  // Expected trace: AI tool-call message, the tool's result for "SF",
  // then the final AI reply with no tool calls.
  expect(res).toEqual({
    messages: [
      aiMessage,
      expect.objectContaining({
        id: toolMessageId,
        name: "weather",
        artifact: undefined,
        content: "It's 60 degrees and foggy.",
        tool_call_id: "call_1234",
      }),
      aiMessage2,
    ],
  });
});
});

0 comments on commit 3340441

Please sign in to comment.