Skip to content

Commit

Permalink
Narrow type
Browse files Browse the repository at this point in the history
  • Loading branch information
jacoblee93 committed Aug 28, 2024
1 parent dd240cf commit 4fc252e
Showing 1 changed file with 7 additions and 9 deletions.
16 changes: 7 additions & 9 deletions libs/langgraph/src/prebuilt/tool_node.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@ export type ToolNodeOptions = {
* tool calls are requested, they will be run in parallel. The output will be
* a list of ToolMessages, one for each tool call.
*/
// eslint-disable-next-line @typescript-eslint/no-explicit-any
export class ToolNode<T = any> extends RunnableCallable<T, T> {
export class ToolNode<
T extends BaseMessage[] | Partial<typeof MessagesAnnotation.State>
> extends RunnableCallable<T, T> {
tools: (StructuredToolInterface | RunnableToolLike)[];

handleToolErrors = true;
Expand All @@ -38,15 +39,12 @@ export class ToolNode<T = any> extends RunnableCallable<T, T> {
this.handleToolErrors = handleToolErrors ?? this.handleToolErrors;
}

private async run(
input: BaseMessage[] | typeof MessagesAnnotation.State,
config: RunnableConfig
): Promise<BaseMessage[] | typeof MessagesAnnotation.State> {
private async run(input: T, config: RunnableConfig): Promise<T> {
const message = Array.isArray(input)
? input[input.length - 1]
: input.messages[input.messages.length - 1];
: input.messages?.[input.messages.length - 1];

if (message._getType() !== "ai") {
if (message?._getType() !== "ai") {
throw new Error("ToolNode only accepts AIMessages as input.");
}

Expand Down Expand Up @@ -85,7 +83,7 @@ export class ToolNode<T = any> extends RunnableCallable<T, T> {
}) ?? []
);

return Array.isArray(input) ? outputs : { messages: outputs };
return (Array.isArray(input) ? outputs : { messages: outputs }) as T;
}
}

Expand Down

0 comments on commit 4fc252e

Please sign in to comment.