diff --git a/libs/langchain-aws/src/common.ts b/libs/langchain-aws/src/common.ts index 945fa932f43f..9fc7fd22a6aa 100644 --- a/libs/langchain-aws/src/common.ts +++ b/libs/langchain-aws/src/common.ts @@ -67,18 +67,21 @@ export function convertToConverseMessages(messages: BaseMessage[]): { const converseSystem: BedrockSystemContentBlock[] = messages .filter((msg) => msg._getType() === "system") .map((msg) => { - const text = msg.content; - if (typeof text !== "string") { - throw new Error("System message content must be a string."); + if (typeof msg.content === "string") { + return { text: msg.content }; + } else if (msg.content.length === 1 && msg.content[0].type === "text") { + return { text: msg.content[0].text }; } - return { text }; + throw new Error( + "System message content must be either a string, or a content array containing a single text object." + ); }); const converseMessages: BedrockMessage[] = messages - .filter((msg) => !["system", "tool", "function"].includes(msg._getType())) + .filter((msg) => msg._getType() !== "system") .map((msg) => { if (msg._getType() === "ai") { const castMsg = msg as AIMessage; - if (typeof castMsg.content === "string") { + if (typeof castMsg.content === "string" && castMsg.content !== "") { return { role: "assistant", content: [ @@ -99,16 +102,21 @@ export function convertToConverseMessages(messages: BaseMessage[]): { }, })), }; - } else { + } else if (Array.isArray(castMsg.content)) { const contentBlocks: ContentBlock[] = castMsg.content.map( (block) => { - if (block.type === "text") { + if (block.type === "text" && block.text !== "") { return { text: block.text, }; } else { + const blockValues = Object.fromEntries( + Object.values(block).filter(([key]) => key !== "type") + ); throw new Error( - `Unsupported content block type: ${block.type}` + `Unsupported content block type: ${ + block.type + } with content of ${JSON.stringify(blockValues, null, 2)}` ); } } @@ -117,10 +125,14 @@ export function convertToConverseMessages(messages: BaseMessage[]): { role: "assistant", content: contentBlocks, }; + } else { + throw new Error( + `Invalid message content: empty string. '${msg._getType()}' must contain non-empty content.` + ); } } } else if (msg._getType() === "human" || msg._getType() === "generic") { - if (typeof msg.content === "string") { + if (typeof msg.content === "string" && msg.content !== "") { return { role: "user", content: [ @@ -129,7 +141,7 @@ export function convertToConverseMessages(messages: BaseMessage[]): { }, ], }; - } else { + } else if (Array.isArray(msg.content)) { const contentBlocks: ContentBlock[] = msg.content.flatMap((block) => { if (block.type === "image_url") { const base64: string = @@ -149,12 +161,17 @@ export function convertToConverseMessages(messages: BaseMessage[]): { role: "user", content: contentBlocks, }; + } else { + throw new Error( + `Invalid message content: empty string. '${msg._getType()}' must contain non-empty content.` + ); } } else if (msg._getType() === "tool") { const castMsg = msg as ToolMessage; if (typeof castMsg.content === "string") { return { - role: undefined, + // Tool use messages are always from the user + role: "user", content: [ { toolResult: { @@ -170,7 +187,8 @@ export function convertToConverseMessages(messages: BaseMessage[]): { }; } else { return { - role: undefined, + // Tool use messages are always from the user + role: "user", content: [ { toolResult: { diff --git a/libs/langchain-aws/src/tests/chat_models.int.test.ts b/libs/langchain-aws/src/tests/chat_models.int.test.ts index fda040eb56ff..78652f848f9c 100644 --- a/libs/langchain-aws/src/tests/chat_models.int.test.ts +++ b/libs/langchain-aws/src/tests/chat_models.int.test.ts @@ -1,7 +1,13 @@ /* eslint-disable no-process-env */ import { test, expect } from "@jest/globals"; -import { AIMessageChunk, HumanMessage } from "@langchain/core/messages"; +import { + AIMessage, + AIMessageChunk, + HumanMessage, + SystemMessage, + ToolMessage, +} from "@langchain/core/messages"; import { tool } from "@langchain/core/tools"; import { z } from "zod"; import { ChatBedrockConverse } from "../chat_models.js"; @@ -319,3 +325,46 @@ test("Test ChatBedrockConverse tool_choice works", async () => { expect(result.tool_calls?.[0].name).toBe("get_weather"); expect(result.tool_calls?.[0].id).toBeDefined(); }); + +test("Model can handle empty content messages", async () => { + const model = new ChatBedrockConverse({ + ...baseConstructorArgs, + }); + + const retrieverTool = tool((_) => "Success", { + name: "retrieverTool", + schema: z.object({ + url: z.string().describe("The URL to fetch"), + }), + description: "A tool to fetch data from a URL", + }); + + const messages = [ + new SystemMessage("You're an advanced AI assistant."), + new HumanMessage( + "What's the weather like today in Berkeley, CA? Use weather.com to check." + ), + new AIMessage({ + content: "", + tool_calls: [ + { + name: "retrieverTool", + args: { + url: "https://weather.com", + }, + id: "123_retriever_tool", + }, + ], + }), + new ToolMessage({ + tool_call_id: "123_retriever_tool", + content: "The weather in Berkeley, CA is 70 degrees and sunny.", + }), + ]; + + const result = await model.bindTools([retrieverTool]).invoke(messages); + + expect(result.content).toBeDefined(); + expect(typeof result.content).toBe("string"); + expect(result.content.length).toBeGreaterThan(1); +}); diff --git a/libs/langchain-aws/src/tests/chat_models.test.ts b/libs/langchain-aws/src/tests/chat_models.test.ts new file mode 100644 index 000000000000..55c9bdc9713a --- /dev/null +++ b/libs/langchain-aws/src/tests/chat_models.test.ts @@ -0,0 +1,73 @@ +import { + SystemMessage, + HumanMessage, + AIMessage, + ToolMessage, +} from "@langchain/core/messages"; +import { convertToConverseMessages } from "../common.js"; + +test("convertToConverseMessages works", () => { + const messages = [ + new SystemMessage("You're an advanced AI assistant."), + new HumanMessage( + "What's the weather like today in Berkeley, CA? Use weather.com to check." + ), + new AIMessage({ + content: "", + tool_calls: [ + { + name: "retrieverTool", + args: { + url: "https://weather.com", + }, + id: "123_retriever_tool", + }, + ], + }), + new ToolMessage({ + tool_call_id: "123_retriever_tool", + content: "The weather in Berkeley, CA is 70 degrees and sunny.", + }), + ]; + + const { converseMessages, converseSystem } = + convertToConverseMessages(messages); + + expect(converseSystem).toHaveLength(1); + expect(converseSystem[0].text).toBe("You're an advanced AI assistant."); + + expect(converseMessages).toHaveLength(3); + + const userMsgs = converseMessages.filter((msg) => msg.role === "user"); + // Length of two because of the first user question, and tool use + // messages will have the user role. + expect(userMsgs).toHaveLength(2); + const textUserMsg = userMsgs.find((msg) => msg.content?.[0].text); + expect(textUserMsg?.content?.[0].text).toBe( + "What's the weather like today in Berkeley, CA? Use weather.com to check." + ); + + const toolUseUserMsg = userMsgs.find((msg) => msg.content?.[0].toolResult); + expect(toolUseUserMsg).toBeDefined(); + expect(toolUseUserMsg?.content).toHaveLength(1); + if (!toolUseUserMsg?.content?.length) return; + + const toolResultContent = toolUseUserMsg.content[0]; + expect(toolResultContent).toBeDefined(); + expect(toolResultContent.toolResult?.toolUseId).toBe("123_retriever_tool"); + expect(toolResultContent.toolResult?.content?.[0].text).toBe( + "The weather in Berkeley, CA is 70 degrees and sunny." + ); + + const assistantMsg = converseMessages.find((msg) => msg.role === "assistant"); + expect(assistantMsg).toBeDefined(); + if (!assistantMsg) return; + + const toolUseContent = assistantMsg.content?.find((c) => "toolUse" in c); + expect(toolUseContent).toBeDefined(); + expect(toolUseContent?.toolUse?.name).toBe("retrieverTool"); + expect(toolUseContent?.toolUse?.toolUseId).toBe("123_retriever_tool"); + expect(toolUseContent?.toolUse?.input).toEqual({ + url: "https://weather.com", + }); +});