From 4fc252e84fb2a4f56706f5015c24777649e59542 Mon Sep 17 00:00:00 2001 From: jacoblee93 Date: Tue, 27 Aug 2024 18:11:31 -0700 Subject: [PATCH] Narrow type --- libs/langgraph/src/prebuilt/tool_node.ts | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/libs/langgraph/src/prebuilt/tool_node.ts b/libs/langgraph/src/prebuilt/tool_node.ts index 5acb7c2e9..edd1505d4 100644 --- a/libs/langgraph/src/prebuilt/tool_node.ts +++ b/libs/langgraph/src/prebuilt/tool_node.ts @@ -22,8 +22,9 @@ export type ToolNodeOptions = { * tool calls are requested, they will be run in parallel. The output will be * a list of ToolMessages, one for each tool call. */ -// eslint-disable-next-line @typescript-eslint/no-explicit-any -export class ToolNode extends RunnableCallable { +export class ToolNode< + T extends BaseMessage[] | Partial +> extends RunnableCallable { tools: (StructuredToolInterface | RunnableToolLike)[]; handleToolErrors = true; @@ -38,15 +39,12 @@ export class ToolNode extends RunnableCallable { this.handleToolErrors = handleToolErrors ?? this.handleToolErrors; } - private async run( - input: BaseMessage[] | typeof MessagesAnnotation.State, - config: RunnableConfig - ): Promise { + private async run(input: T, config: RunnableConfig): Promise { const message = Array.isArray(input) ? input[input.length - 1] - : input.messages[input.messages.length - 1]; + : input.messages?.[input.messages.length - 1]; - if (message._getType() !== "ai") { + if (message?._getType() !== "ai") { throw new Error("ToolNode only accepts AIMessages as input."); } @@ -85,7 +83,7 @@ export class ToolNode extends RunnableCallable { }) ?? [] ); - return Array.isArray(input) ? outputs : { messages: outputs }; + return (Array.isArray(input) ? outputs : { messages: outputs }) as T; } }