Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

aws[patch]: Fix empty content bug #6043

Merged
merged 4 commits into from
Jul 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 31 additions & 13 deletions libs/langchain-aws/src/common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -67,18 +67,21 @@ export function convertToConverseMessages(messages: BaseMessage[]): {
const converseSystem: BedrockSystemContentBlock[] = messages
.filter((msg) => msg._getType() === "system")
.map((msg) => {
const text = msg.content;
if (typeof text !== "string") {
throw new Error("System message content must be a string.");
if (typeof msg.content === "string") {
return { text: msg.content };
} else if (msg.content.length === 1 && msg.content[0].type === "text") {
return { text: msg.content[0].text };
}
return { text };
throw new Error(
"System message content must be either a string, or a content array containing a single text object."
);
});
const converseMessages: BedrockMessage[] = messages
.filter((msg) => !["system", "tool", "function"].includes(msg._getType()))
.filter((msg) => msg._getType() !== "system")
.map((msg) => {
if (msg._getType() === "ai") {
const castMsg = msg as AIMessage;
if (typeof castMsg.content === "string") {
if (typeof castMsg.content === "string" && castMsg.content !== "") {
return {
role: "assistant",
content: [
Expand All @@ -99,16 +102,21 @@ export function convertToConverseMessages(messages: BaseMessage[]): {
},
})),
};
} else {
} else if (Array.isArray(castMsg.content)) {
const contentBlocks: ContentBlock[] = castMsg.content.map(
(block) => {
if (block.type === "text") {
if (block.type === "text" && block.text !== "") {
return {
text: block.text,
};
} else {
const blockValues = Object.fromEntries(
Object.values(block).filter(([key]) => key !== "type")
);
throw new Error(
`Unsupported content block type: ${block.type}`
`Unsupported content block type: ${
block.type
} with content of ${JSON.stringify(blockValues, null, 2)}`
);
}
}
Expand All @@ -117,10 +125,14 @@ export function convertToConverseMessages(messages: BaseMessage[]): {
role: "assistant",
content: contentBlocks,
};
} else {
throw new Error(
`Invalid message content: empty string. '${msg._getType()}' must contain non-empty content.`
);
}
}
} else if (msg._getType() === "human" || msg._getType() === "generic") {
if (typeof msg.content === "string") {
if (typeof msg.content === "string" && msg.content !== "") {
return {
role: "user",
content: [
Expand All @@ -129,7 +141,7 @@ export function convertToConverseMessages(messages: BaseMessage[]): {
},
],
};
} else {
} else if (Array.isArray(msg.content)) {
const contentBlocks: ContentBlock[] = msg.content.flatMap((block) => {
if (block.type === "image_url") {
const base64: string =
Expand All @@ -149,12 +161,17 @@ export function convertToConverseMessages(messages: BaseMessage[]): {
role: "user",
content: contentBlocks,
};
} else {
throw new Error(
`Invalid message content: empty string. '${msg._getType()}' must contain non-empty content.`
);
}
} else if (msg._getType() === "tool") {
const castMsg = msg as ToolMessage;
if (typeof castMsg.content === "string") {
return {
role: undefined,
// Tool use messages are always from the user
role: "user",
content: [
{
toolResult: {
Expand All @@ -170,7 +187,8 @@ export function convertToConverseMessages(messages: BaseMessage[]): {
};
} else {
return {
role: undefined,
// Tool use messages are always from the user
role: "user",
content: [
{
toolResult: {
Expand Down
51 changes: 50 additions & 1 deletion libs/langchain-aws/src/tests/chat_models.int.test.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
/* eslint-disable no-process-env */

import { test, expect } from "@jest/globals";
import { AIMessageChunk, HumanMessage } from "@langchain/core/messages";
import {
AIMessage,
AIMessageChunk,
HumanMessage,
SystemMessage,
ToolMessage,
} from "@langchain/core/messages";
import { tool } from "@langchain/core/tools";
import { z } from "zod";
import { ChatBedrockConverse } from "../chat_models.js";
Expand Down Expand Up @@ -319,3 +325,46 @@ test("Test ChatBedrockConverse tool_choice works", async () => {
expect(result.tool_calls?.[0].name).toBe("get_weather");
expect(result.tool_calls?.[0].id).toBeDefined();
});

test("Model can handle empty content messages", async () => {
const model = new ChatBedrockConverse({
...baseConstructorArgs,
});

const retrieverTool = tool((_) => "Success", {
name: "retrieverTool",
schema: z.object({
url: z.string().describe("The URL to fetch"),
}),
description: "A tool to fetch data from a URL",
});

const messages = [
new SystemMessage("You're an advanced AI assistant."),
new HumanMessage(
"What's the weather like today in Berkeley, CA? Use weather.com to check."
),
new AIMessage({
content: "",
tool_calls: [
{
name: "retrieverTool",
args: {
url: "https://weather.com",
},
id: "123_retriever_tool",
},
],
}),
new ToolMessage({
tool_call_id: "123_retriever_tool",
content: "The weather in Berkeley, CA is 70 degrees and sunny.",
}),
];

const result = await model.bindTools([retrieverTool]).invoke(messages);

expect(result.content).toBeDefined();
expect(typeof result.content).toBe("string");
expect(result.content.length).toBeGreaterThan(1);
});
73 changes: 73 additions & 0 deletions libs/langchain-aws/src/tests/chat_models.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import {
SystemMessage,
HumanMessage,
AIMessage,
ToolMessage,
} from "@langchain/core/messages";
import { convertToConverseMessages } from "../common.js";

test("convertToConverseMessages works", () => {
const messages = [
new SystemMessage("You're an advanced AI assistant."),
new HumanMessage(
"What's the weather like today in Berkeley, CA? Use weather.com to check."
),
new AIMessage({
content: "",
tool_calls: [
{
name: "retrieverTool",
args: {
url: "https://weather.com",
},
id: "123_retriever_tool",
},
],
}),
new ToolMessage({
tool_call_id: "123_retriever_tool",
content: "The weather in Berkeley, CA is 70 degrees and sunny.",
}),
];

const { converseMessages, converseSystem } =
convertToConverseMessages(messages);

expect(converseSystem).toHaveLength(1);
expect(converseSystem[0].text).toBe("You're an advanced AI assistant.");

expect(converseMessages).toHaveLength(3);

const userMsgs = converseMessages.filter((msg) => msg.role === "user");
// Length of two because of the first user question, and tool use
// messages will have the user role.
expect(userMsgs).toHaveLength(2);
const textUserMsg = userMsgs.find((msg) => msg.content?.[0].text);
expect(textUserMsg?.content?.[0].text).toBe(
"What's the weather like today in Berkeley, CA? Use weather.com to check."
);

const toolUseUserMsg = userMsgs.find((msg) => msg.content?.[0].toolResult);
expect(toolUseUserMsg).toBeDefined();
expect(toolUseUserMsg?.content).toHaveLength(1);
if (!toolUseUserMsg?.content?.length) return;

const toolResultContent = toolUseUserMsg.content[0];
expect(toolResultContent).toBeDefined();
expect(toolResultContent.toolResult?.toolUseId).toBe("123_retriever_tool");
expect(toolResultContent.toolResult?.content?.[0].text).toBe(
"The weather in Berkeley, CA is 70 degrees and sunny."
);

const assistantMsg = converseMessages.find((msg) => msg.role === "assistant");
expect(assistantMsg).toBeDefined();
if (!assistantMsg) return;

const toolUseContent = assistantMsg.content?.find((c) => "toolUse" in c);
expect(toolUseContent).toBeDefined();
expect(toolUseContent?.toolUse?.name).toBe("retrieverTool");
expect(toolUseContent?.toolUse?.toolUseId).toBe("123_retriever_tool");
expect(toolUseContent?.toolUse?.input).toEqual({
url: "https://weather.com",
});
});
Loading