Skip to content

Commit

Permalink
core[minor]: Add raw_output field to ToolMessage (#6007)
Browse files Browse the repository at this point in the history
* core[minor]: Add raw_output field to ToolMessage

* add generics

* tests

* made both be able to be undefined

* cr

* remove generics, replace with unknown
  • Loading branch information
bracesproul authored Jul 11, 2024
1 parent 1be142a commit d0d0c2f
Show file tree
Hide file tree
Showing 3 changed files with 126 additions and 1 deletion.
29 changes: 29 additions & 0 deletions langchain-core/src/messages/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,35 @@ export function _mergeLists(left?: any[], right?: any[]) {
}
}

// eslint-disable-next-line @typescript-eslint/no-explicit-any
export function _mergeObj<T = any>(
left: T | undefined,
right: T | undefined
): T {
if (!left && !right) {
throw new Error("Cannot merge two undefined objects.");
}
if (!left || !right) {
return left || (right as T);
} else if (typeof left !== typeof right) {
throw new Error(
`Cannot merge objects of different types.\nLeft ${typeof left}\nRight ${typeof right}`
);
} else if (typeof left === "string" && typeof right === "string") {
return (left + right) as T;
} else if (Array.isArray(left) && Array.isArray(right)) {
return _mergeLists(left, right) as T;
} else if (typeof left === "object" && typeof right === "object") {
return _mergeDicts(left, right) as T;
} else if (left === right) {
return left;
} else {
throw new Error(
`Can not merge objects of different types.\nLeft ${left}\nRight ${right}`
);
}
}

/**
* Represents a chunk of a message, which can be concatenated with other
* message chunks. It includes a method `_merge_kwargs_dict()` for merging
Expand Down
68 changes: 67 additions & 1 deletion langchain-core/src/messages/tests/base_message.test.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
import { test } from "@jest/globals";
import { ChatPromptTemplate } from "../../prompts/chat.js";
import { HumanMessage, AIMessage, ToolMessage } from "../index.js";
import {
HumanMessage,
AIMessage,
ToolMessage,
ToolMessageChunk,
} from "../index.js";
import { load } from "../../load/index.js";

test("Test ChatPromptTemplate can format OpenAI content image messages", async () => {
Expand Down Expand Up @@ -127,3 +132,64 @@ test("Deserialisation and serialisation of messages with ID", async () => {
expect(deserialized).toEqual(message);
expect(deserialized.id).toBe(messageId);
});

test("Can concat raw_output (string) of ToolMessageChunk", () => {
const rawOutputOne = "Hello";
const rawOutputTwo = " world";
const chunk1 = new ToolMessageChunk({
content: "Hello",
tool_call_id: "1",
raw_output: rawOutputOne,
});
const chunk2 = new ToolMessageChunk({
content: " world",
tool_call_id: "1",
raw_output: rawOutputTwo,
});

const concated = chunk1.concat(chunk2);
expect(concated.raw_output).toBe(`${rawOutputOne}${rawOutputTwo}`);
});

test("Can concat raw_output (array) of ToolMessageChunk", () => {
const rawOutputOne = ["Hello", " world"];
const rawOutputTwo = ["!!"];
const chunk1 = new ToolMessageChunk({
content: "Hello",
tool_call_id: "1",
raw_output: rawOutputOne,
});
const chunk2 = new ToolMessageChunk({
content: " world",
tool_call_id: "1",
raw_output: rawOutputTwo,
});

const concated = chunk1.concat(chunk2);
expect(concated.raw_output).toEqual(["Hello", " world", "!!"]);
});

test("Can concat raw_output (object) of ToolMessageChunk", () => {
const rawOutputOne = {
foo: "bar",
};
const rawOutputTwo = {
bar: "baz",
};
const chunk1 = new ToolMessageChunk({
content: "Hello",
tool_call_id: "1",
raw_output: rawOutputOne,
});
const chunk2 = new ToolMessageChunk({
content: " world",
tool_call_id: "1",
raw_output: rawOutputTwo,
});

const concated = chunk1.concat(chunk2);
expect(concated.raw_output).toEqual({
foo: "bar",
bar: "baz",
});
});
30 changes: 30 additions & 0 deletions langchain-core/src/messages/tool.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,18 @@ import {
mergeContent,
_mergeDicts,
type MessageType,
_mergeObj,
} from "./base.js";

export interface ToolMessageFieldsWithToolCallId extends BaseMessageFields {
/**
* The raw output of the tool.
*
* **Not part of the payload sent to the model.** Should only be specified if it is
* different from the message content, i.e. if only a subset of the full tool output
* is being passed as message content.
*/
raw_output?: unknown;
tool_call_id: string;
}

Expand All @@ -26,6 +35,15 @@ export class ToolMessage extends BaseMessage {

tool_call_id: string;

/**
* The raw output of the tool.
*
* **Not part of the payload sent to the model.** Should only be specified if it is
* different from the message content, i.e. if only a subset of the full tool output
* is being passed as message content.
*/
raw_output?: unknown;

constructor(fields: ToolMessageFieldsWithToolCallId);

constructor(
Expand All @@ -45,6 +63,7 @@ export class ToolMessage extends BaseMessage {
}
super(fields);
this.tool_call_id = fields.tool_call_id;
this.raw_output = fields.raw_output;
}

_getType(): MessageType {
Expand All @@ -63,9 +82,19 @@ export class ToolMessage extends BaseMessage {
export class ToolMessageChunk extends BaseMessageChunk {
tool_call_id: string;

/**
* The raw output of the tool.
*
* **Not part of the payload sent to the model.** Should only be specified if it is
* different from the message content, i.e. if only a subset of the full tool output
* is being passed as message content.
*/
raw_output?: unknown;

constructor(fields: ToolMessageFieldsWithToolCallId) {
super(fields);
this.tool_call_id = fields.tool_call_id;
this.raw_output = fields.raw_output;
}

static lc_name() {
Expand All @@ -87,6 +116,7 @@ export class ToolMessageChunk extends BaseMessageChunk {
this.response_metadata,
chunk.response_metadata
),
raw_output: _mergeObj(this.raw_output, chunk.raw_output),
tool_call_id: this.tool_call_id,
id: this.id ?? chunk.id,
});
Expand Down

0 comments on commit d0d0c2f

Please sign in to comment.