Skip to content

Commit

Permalink
fix(prebuilt): use messagesStateReducer to support more formats (#387)
Browse files Browse the repository at this point in the history
Co-authored-by: jacoblee93 <[email protected]>
  • Loading branch information
dqbd and jacoblee93 authored Aug 27, 2024
1 parent 1806051 commit e4a7858
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 12 deletions.
4 changes: 2 additions & 2 deletions libs/langgraph/src/prebuilt/chat_agent_executor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@ import {
} from "../graph/state.js";
import { END, START } from "../graph/index.js";

/** @ignore */
/** @deprecated Use {@link createReactAgent} instead with tool calling. */
export type FunctionCallingExecutorState = { messages: Array<BaseMessage> };

/** @ignore */
/** @deprecated Use {@link createReactAgent} instead with tool calling. */
export function createFunctionCallingExecutor<Model extends object>({
model,
tools,
Expand Down
9 changes: 7 additions & 2 deletions libs/langgraph/src/prebuilt/react_agent_executor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,12 @@ import {
} from "@langchain/core/language_models/base";
import { ChatPromptTemplate } from "@langchain/core/prompts";
import { BaseCheckpointSaver } from "@langchain/langgraph-checkpoint";
import { END, START, StateGraph } from "../graph/index.js";
import {
END,
messagesStateReducer,
START,
StateGraph,
} from "../graph/index.js";
import { MessagesAnnotation } from "../graph/messages_annotation.js";
import { CompiledStateGraph, StateGraphArgs } from "../graph/state.js";
import { All } from "../pregel/types.js";
Expand Down Expand Up @@ -81,7 +86,7 @@ export function createReactAgent(
} = props;
const schema: StateGraphArgs<AgentState>["channels"] = {
messages: {
value: (left: BaseMessage[], right: BaseMessage[]) => left.concat(right),
value: messagesStateReducer,
default: () => [],
},
};
Expand Down
35 changes: 27 additions & 8 deletions libs/langgraph/src/tests/prebuilt.test.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
/* eslint-disable no-process-env */
/* eslint-disable no-param-reassign */
import { beforeAll, describe, expect, it } from "@jest/globals";
import { PromptTemplate } from "@langchain/core/prompts";
import { StructuredTool, Tool } from "@langchain/core/tools";
Expand Down Expand Up @@ -275,7 +276,7 @@ describe("createReactAgent", () => {
messages: [new HumanMessage("Hello Input!")],
});

expect(result.messages).toEqual([
const expected = [
new HumanMessage("Hello Input!"),
new AIMessage({
content: "result1",
Expand All @@ -287,9 +288,14 @@ describe("createReactAgent", () => {
name: "search_api",
content: "result for foo",
tool_call_id: "tool_abcd123",
artifact: undefined,
}),
new AIMessage("result2"),
]);
].map((message, i) => {
message.id = result.messages[i].id;
return message;
});
expect(result.messages).toEqual(expected);
});

it("Can use SystemMessage message modifier", async () => {
Expand All @@ -314,7 +320,7 @@ describe("createReactAgent", () => {
const result = await agent.invoke({
messages: [],
});
expect(result.messages).toEqual([
const expected = [
new AIMessage({
content: "result1",
tool_calls: [
Expand All @@ -325,9 +331,14 @@ describe("createReactAgent", () => {
name: "search_api",
content: "result for foo",
tool_call_id: "tool_abcd123",
artifact: undefined,
}),
new AIMessage("result2"),
]);
].map((message, i) => {
message.id = result.messages[i].id;
return message;
});
expect(result.messages).toEqual(expected);
});

it("Should respect a passed signal", async () => {
Expand Down Expand Up @@ -389,7 +400,7 @@ describe("createReactAgent", () => {
messages: [new HumanMessage("Hello Input!")],
});

expect(result.messages).toEqual([
const expected = [
new HumanMessage("Hello Input!"),
new AIMessage({
content: "result1",
Expand All @@ -404,7 +415,11 @@ describe("createReactAgent", () => {
artifact: Buffer.from("123"),
}),
new AIMessage("result2"),
]);
].map((message, i) => {
message.id = result.messages[i].id;
return message;
});
expect(result.messages).toEqual(expected);
});

it("Can accept RunnableToolLike", async () => {
Expand Down Expand Up @@ -442,7 +457,7 @@ describe("createReactAgent", () => {
messages: [new HumanMessage("Hello Input!")],
});

expect(result.messages).toEqual([
const expected = [
new HumanMessage("Hello Input!"),
new AIMessage({
content: "result1",
Expand All @@ -456,7 +471,11 @@ describe("createReactAgent", () => {
tool_call_id: "tool_abcd123",
}),
new AIMessage("result2"),
]);
].map((message, i) => {
message.id = result.messages[i].id;
return message;
});
expect(result.messages).toEqual(expected);
});
});

Expand Down

0 comments on commit e4a7858

Please sign in to comment.