diff --git a/langchain-core/src/messages/base.ts b/langchain-core/src/messages/base.ts index 264c4ffa52bb..41a7cb2c2f70 100644 --- a/langchain-core/src/messages/base.ts +++ b/langchain-core/src/messages/base.ts @@ -305,6 +305,35 @@ export function _mergeLists(left?: any[], right?: any[]) { } } +// eslint-disable-next-line @typescript-eslint/no-explicit-any +export function _mergeObj( + left: T | undefined, + right: T | undefined +): T { + if (!left && !right) { + throw new Error("Cannot merge two undefined objects."); + } + if (!left || !right) { + return left || (right as T); + } else if (typeof left !== typeof right) { + throw new Error( + `Cannot merge objects of different types.\nLeft ${typeof left}\nRight ${typeof right}` + ); + } else if (typeof left === "string" && typeof right === "string") { + return (left + right) as T; + } else if (Array.isArray(left) && Array.isArray(right)) { + return _mergeLists(left, right) as T; + } else if (typeof left === "object" && typeof right === "object") { + return _mergeDicts(left, right) as T; + } else if (left === right) { + return left; + } else { + throw new Error( + `Can not merge objects of different types.\nLeft ${left}\nRight ${right}` + ); + } +} + /** * Represents a chunk of a message, which can be concatenated with other * message chunks. It includes a method `_merge_kwargs_dict()` for merging diff --git a/langchain-core/src/messages/tests/base_message.test.ts b/langchain-core/src/messages/tests/base_message.test.ts index 257d1c4cebad..fb6f3797f31f 100644 --- a/langchain-core/src/messages/tests/base_message.test.ts +++ b/langchain-core/src/messages/tests/base_message.test.ts @@ -1,6 +1,11 @@ import { test } from "@jest/globals"; import { ChatPromptTemplate } from "../../prompts/chat.js"; -import { HumanMessage, AIMessage, ToolMessage } from "../index.js"; +import { + HumanMessage, + AIMessage, + ToolMessage, + ToolMessageChunk, +} from "../index.js"; import { load } from "../../load/index.js"; test("Test ChatPromptTemplate can format OpenAI content image messages", async () => { @@ -127,3 +132,64 @@ test("Deserialisation and serialisation of messages with ID", async () => { expect(deserialized).toEqual(message); expect(deserialized.id).toBe(messageId); }); + +test("Can concat raw_output (string) of ToolMessageChunk", () => { + const rawOutputOne = "Hello"; + const rawOutputTwo = " world"; + const chunk1 = new ToolMessageChunk({ + content: "Hello", + tool_call_id: "1", + raw_output: rawOutputOne, + }); + const chunk2 = new ToolMessageChunk({ + content: " world", + tool_call_id: "1", + raw_output: rawOutputTwo, + }); + + const concated = chunk1.concat(chunk2); + expect(concated.raw_output).toBe(`${rawOutputOne}${rawOutputTwo}`); +}); + +test("Can concat raw_output (array) of ToolMessageChunk", () => { + const rawOutputOne = ["Hello", " world"]; + const rawOutputTwo = ["!!"]; + const chunk1 = new ToolMessageChunk({ + content: "Hello", + tool_call_id: "1", + raw_output: rawOutputOne, + }); + const chunk2 = new ToolMessageChunk({ + content: " world", + tool_call_id: "1", + raw_output: rawOutputTwo, + }); + + const concated = chunk1.concat(chunk2); + expect(concated.raw_output).toEqual(["Hello", " world", "!!"]); +}); + +test("Can concat raw_output (object) of ToolMessageChunk", () => { + const rawOutputOne = { + foo: "bar", + }; + const rawOutputTwo = { + bar: "baz", + }; + const chunk1 = new ToolMessageChunk({ + content: "Hello", + tool_call_id: "1", + raw_output: rawOutputOne, + }); + const chunk2 = new ToolMessageChunk({ + content: " world", + tool_call_id: "1", + raw_output: rawOutputTwo, + }); + + const concated = chunk1.concat(chunk2); + expect(concated.raw_output).toEqual({ + foo: "bar", + bar: "baz", + }); +}); diff --git a/langchain-core/src/messages/tool.ts b/langchain-core/src/messages/tool.ts index 3375cfd8572d..996e2cdba8cc 100644 --- a/langchain-core/src/messages/tool.ts +++ b/langchain-core/src/messages/tool.ts @@ -5,9 +5,18 @@ import { mergeContent, _mergeDicts, type MessageType, + _mergeObj, } from "./base.js"; export interface ToolMessageFieldsWithToolCallId extends BaseMessageFields { + /** + * The raw output of the tool. + * + * **Not part of the payload sent to the model.** Should only be specified if it is + * different from the message content, i.e. if only a subset of the full tool output + * is being passed as message content. + */ + raw_output?: unknown; tool_call_id: string; } @@ -26,6 +35,15 @@ export class ToolMessage extends BaseMessage { tool_call_id: string; + /** + * The raw output of the tool. + * + * **Not part of the payload sent to the model.** Should only be specified if it is + * different from the message content, i.e. if only a subset of the full tool output + * is being passed as message content. + */ + raw_output?: unknown; + constructor(fields: ToolMessageFieldsWithToolCallId); constructor( @@ -45,6 +63,7 @@ export class ToolMessage extends BaseMessage { } super(fields); this.tool_call_id = fields.tool_call_id; + this.raw_output = fields.raw_output; } _getType(): MessageType { @@ -63,9 +82,19 @@ export class ToolMessage extends BaseMessage { export class ToolMessageChunk extends BaseMessageChunk { tool_call_id: string; + /** + * The raw output of the tool. + * + * **Not part of the payload sent to the model.** Should only be specified if it is + * different from the message content, i.e. if only a subset of the full tool output + * is being passed as message content. + */ + raw_output?: unknown; + constructor(fields: ToolMessageFieldsWithToolCallId) { super(fields); this.tool_call_id = fields.tool_call_id; + this.raw_output = fields.raw_output; } static lc_name() { @@ -87,6 +116,7 @@ export class ToolMessageChunk extends BaseMessageChunk { this.response_metadata, chunk.response_metadata ), + raw_output: _mergeObj(this.raw_output, chunk.raw_output), tool_call_id: this.tool_call_id, id: this.id ?? chunk.id, });