Skip to content

Commit

Permalink
aws[patch]: Fix empty content bug (#6043)
Browse files Browse the repository at this point in the history
* aws[patch]: Fix empty content bug

* fix empty string issue

* chore: lint files

* moved non int test to non int test file
  • Loading branch information
bracesproul authored Jul 11, 2024
1 parent 58da6a9 commit 6111718
Show file tree
Hide file tree
Showing 3 changed files with 154 additions and 14 deletions.
44 changes: 31 additions & 13 deletions libs/langchain-aws/src/common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -67,18 +67,21 @@ export function convertToConverseMessages(messages: BaseMessage[]): {
const converseSystem: BedrockSystemContentBlock[] = messages
.filter((msg) => msg._getType() === "system")
.map((msg) => {
const text = msg.content;
if (typeof text !== "string") {
throw new Error("System message content must be a string.");
if (typeof msg.content === "string") {
return { text: msg.content };
} else if (msg.content.length === 1 && msg.content[0].type === "text") {
return { text: msg.content[0].text };
}
return { text };
throw new Error(
"System message content must be either a string, or a content array containing a single text object."
);
});
const converseMessages: BedrockMessage[] = messages
.filter((msg) => !["system", "tool", "function"].includes(msg._getType()))
.filter((msg) => msg._getType() !== "system")
.map((msg) => {
if (msg._getType() === "ai") {
const castMsg = msg as AIMessage;
if (typeof castMsg.content === "string") {
if (typeof castMsg.content === "string" && castMsg.content !== "") {
return {
role: "assistant",
content: [
Expand All @@ -99,16 +102,21 @@ export function convertToConverseMessages(messages: BaseMessage[]): {
},
})),
};
} else {
} else if (Array.isArray(castMsg.content)) {
const contentBlocks: ContentBlock[] = castMsg.content.map(
(block) => {
if (block.type === "text") {
if (block.type === "text" && block.text !== "") {
return {
text: block.text,
};
} else {
const blockValues = Object.fromEntries(
Object.values(block).filter(([key]) => key !== "type")
);
throw new Error(
`Unsupported content block type: ${block.type}`
`Unsupported content block type: ${
block.type
} with content of ${JSON.stringify(blockValues, null, 2)}`
);
}
}
Expand All @@ -117,10 +125,14 @@ export function convertToConverseMessages(messages: BaseMessage[]): {
role: "assistant",
content: contentBlocks,
};
} else {
throw new Error(
`Invalid message content: empty string. '${msg._getType()}' must contain non-empty content.`
);
}
}
} else if (msg._getType() === "human" || msg._getType() === "generic") {
if (typeof msg.content === "string") {
if (typeof msg.content === "string" && msg.content !== "") {
return {
role: "user",
content: [
Expand All @@ -129,7 +141,7 @@ export function convertToConverseMessages(messages: BaseMessage[]): {
},
],
};
} else {
} else if (Array.isArray(msg.content)) {
const contentBlocks: ContentBlock[] = msg.content.flatMap((block) => {
if (block.type === "image_url") {
const base64: string =
Expand All @@ -149,12 +161,17 @@ export function convertToConverseMessages(messages: BaseMessage[]): {
role: "user",
content: contentBlocks,
};
} else {
throw new Error(
`Invalid message content: empty string. '${msg._getType()}' must contain non-empty content.`
);
}
} else if (msg._getType() === "tool") {
const castMsg = msg as ToolMessage;
if (typeof castMsg.content === "string") {
return {
role: undefined,
// Tool use messages are always from the user
role: "user",
content: [
{
toolResult: {
Expand All @@ -170,7 +187,8 @@ export function convertToConverseMessages(messages: BaseMessage[]): {
};
} else {
return {
role: undefined,
// Tool use messages are always from the user
role: "user",
content: [
{
toolResult: {
Expand Down
51 changes: 50 additions & 1 deletion libs/langchain-aws/src/tests/chat_models.int.test.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
/* eslint-disable no-process-env */

import { test, expect } from "@jest/globals";
import { AIMessageChunk, HumanMessage } from "@langchain/core/messages";
import {
AIMessage,
AIMessageChunk,
HumanMessage,
SystemMessage,
ToolMessage,
} from "@langchain/core/messages";
import { tool } from "@langchain/core/tools";
import { z } from "zod";
import { ChatBedrockConverse } from "../chat_models.js";
Expand Down Expand Up @@ -319,3 +325,46 @@ test("Test ChatBedrockConverse tool_choice works", async () => {
expect(result.tool_calls?.[0].name).toBe("get_weather");
expect(result.tool_calls?.[0].id).toBeDefined();
});

test("Model can handle empty content messages", async () => {
const model = new ChatBedrockConverse({
...baseConstructorArgs,
});

const retrieverTool = tool((_) => "Success", {
name: "retrieverTool",
schema: z.object({
url: z.string().describe("The URL to fetch"),
}),
description: "A tool to fetch data from a URL",
});

const messages = [
new SystemMessage("You're an advanced AI assistant."),
new HumanMessage(
"What's the weather like today in Berkeley, CA? Use weather.com to check."
),
new AIMessage({
content: "",
tool_calls: [
{
name: "retrieverTool",
args: {
url: "https://weather.com",
},
id: "123_retriever_tool",
},
],
}),
new ToolMessage({
tool_call_id: "123_retriever_tool",
content: "The weather in Berkeley, CA is 70 degrees and sunny.",
}),
];

const result = await model.bindTools([retrieverTool]).invoke(messages);

expect(result.content).toBeDefined();
expect(typeof result.content).toBe("string");
expect(result.content.length).toBeGreaterThan(1);
});
73 changes: 73 additions & 0 deletions libs/langchain-aws/src/tests/chat_models.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import {
SystemMessage,
HumanMessage,
AIMessage,
ToolMessage,
} from "@langchain/core/messages";
import { convertToConverseMessages } from "../common.js";

test("convertToConverseMessages works", () => {
const messages = [
new SystemMessage("You're an advanced AI assistant."),
new HumanMessage(
"What's the weather like today in Berkeley, CA? Use weather.com to check."
),
new AIMessage({
content: "",
tool_calls: [
{
name: "retrieverTool",
args: {
url: "https://weather.com",
},
id: "123_retriever_tool",
},
],
}),
new ToolMessage({
tool_call_id: "123_retriever_tool",
content: "The weather in Berkeley, CA is 70 degrees and sunny.",
}),
];

const { converseMessages, converseSystem } =
convertToConverseMessages(messages);

expect(converseSystem).toHaveLength(1);
expect(converseSystem[0].text).toBe("You're an advanced AI assistant.");

expect(converseMessages).toHaveLength(3);

const userMsgs = converseMessages.filter((msg) => msg.role === "user");
// Length of two because of the first user question, and tool use
// messages will have the user role.
expect(userMsgs).toHaveLength(2);
const textUserMsg = userMsgs.find((msg) => msg.content?.[0].text);
expect(textUserMsg?.content?.[0].text).toBe(
"What's the weather like today in Berkeley, CA? Use weather.com to check."
);

const toolUseUserMsg = userMsgs.find((msg) => msg.content?.[0].toolResult);
expect(toolUseUserMsg).toBeDefined();
expect(toolUseUserMsg?.content).toHaveLength(1);
if (!toolUseUserMsg?.content?.length) return;

const toolResultContent = toolUseUserMsg.content[0];
expect(toolResultContent).toBeDefined();
expect(toolResultContent.toolResult?.toolUseId).toBe("123_retriever_tool");
expect(toolResultContent.toolResult?.content?.[0].text).toBe(
"The weather in Berkeley, CA is 70 degrees and sunny."
);

const assistantMsg = converseMessages.find((msg) => msg.role === "assistant");
expect(assistantMsg).toBeDefined();
if (!assistantMsg) return;

const toolUseContent = assistantMsg.content?.find((c) => "toolUse" in c);
expect(toolUseContent).toBeDefined();
expect(toolUseContent?.toolUse?.name).toBe("retrieverTool");
expect(toolUseContent?.toolUse?.toolUseId).toBe("123_retriever_tool");
expect(toolUseContent?.toolUse?.input).toEqual({
url: "https://weather.com",
});
});

0 comments on commit 6111718

Please sign in to comment.