diff --git a/langgraph/src/graph/index.ts b/langgraph/src/graph/index.ts index 9292db7a..05cea1b5 100644 --- a/langgraph/src/graph/index.ts +++ b/langgraph/src/graph/index.ts @@ -5,4 +5,8 @@ export { StateGraph, type CompiledStateGraph, } from "./state.js"; -export { MessageGraph, messagesStateReducer } from "./message.js"; +export { + MessageGraph, + messagesStateReducer, + MessagesState, +} from "./message.js"; diff --git a/langgraph/src/graph/message.ts b/langgraph/src/graph/message.ts index 1c0393c6..10e9438f 100644 --- a/langgraph/src/graph/message.ts +++ b/langgraph/src/graph/message.ts @@ -5,6 +5,7 @@ import { } from "@langchain/core/messages"; import { v4 } from "uuid"; import { StateGraph } from "./state.js"; +import { Annotation } from "./annotation.js"; type Messages = | Array @@ -76,6 +77,9 @@ export class MessageGraph extends StateGraph< } } -export interface MessagesState { - messages: BaseMessage[]; -} +export const MessagesState = Annotation.Root({ + messages: Annotation({ + reducer: messagesStateReducer, + default: () => [], + }), +}); diff --git a/langgraph/src/prebuilt/react_agent_executor.ts b/langgraph/src/prebuilt/react_agent_executor.ts index 4860bac0..62bcea7b 100644 --- a/langgraph/src/prebuilt/react_agent_executor.ts +++ b/langgraph/src/prebuilt/react_agent_executor.ts @@ -40,7 +40,7 @@ export type N = typeof START | "agent" | "tools"; export type CreateReactAgentParams = { llm: BaseChatModel; tools: - | ToolNode + | ToolNode | (StructuredToolInterface | RunnableToolLike)[]; messageModifier?: | SystemMessage diff --git a/langgraph/src/prebuilt/tool_node.ts b/langgraph/src/prebuilt/tool_node.ts index fdd71090..6378eac3 100644 --- a/langgraph/src/prebuilt/tool_node.ts +++ b/langgraph/src/prebuilt/tool_node.ts @@ -17,7 +17,7 @@ export type ToolNodeOptions = { }; export class ToolNode< - T extends BaseMessage[] | MessagesState + T extends BaseMessage[] | typeof MessagesState.State > extends RunnableCallable { /** A node that runs the tools requested in the last AIMessage. It can be used @@ -41,9 +41,9 @@ export class ToolNode< } private async run( - input: BaseMessage[] | MessagesState, + input: BaseMessage[] | typeof MessagesState.State, config: RunnableConfig - ): Promise { + ): Promise { const message = Array.isArray(input) ? input[input.length - 1] : input.messages[input.messages.length - 1]; @@ -92,7 +92,7 @@ export class ToolNode< } export function toolsCondition( - state: BaseMessage[] | MessagesState + state: BaseMessage[] | typeof MessagesState.State ): "tools" | typeof END { const message = Array.isArray(state) ? state[state.length - 1] diff --git a/langgraph/src/web.ts b/langgraph/src/web.ts index 1716eb8d..af1529e2 100644 --- a/langgraph/src/web.ts +++ b/langgraph/src/web.ts @@ -8,6 +8,7 @@ export { MessageGraph, messagesStateReducer, Annotation, + MessagesState, type StateType, type UpdateType, type CompiledGraph,