From 474994fa1c3f22eb9f3ebac7c6ea574b7baa188b Mon Sep 17 00:00:00 2001 From: Brace Sproul Date: Tue, 12 Nov 2024 17:48:22 -0800 Subject: [PATCH] fix(anthropic, bedrock): Remove message merging logic (#7196) --- .../src/utils/message_inputs.ts | 36 ++++----------- .../src/utils/bedrock/anthropic.ts | 46 ++++++++----------- 2 files changed, 26 insertions(+), 56 deletions(-) diff --git a/libs/langchain-anthropic/src/utils/message_inputs.ts b/libs/langchain-anthropic/src/utils/message_inputs.ts index 4082405de828..c8db15ba5b9e 100644 --- a/libs/langchain-anthropic/src/utils/message_inputs.ts +++ b/libs/langchain-anthropic/src/utils/message_inputs.ts @@ -35,15 +35,15 @@ function _formatImage(imageUrl: string) { } as any; } -function _mergeMessages( +function _ensureMessageContents( messages: BaseMessage[] ): (SystemMessage | HumanMessage | AIMessage)[] { // Merge runs of human/tool messages into single human messages with content blocks. - const merged = []; + const updatedMsgs = []; for (const message of messages) { if (message._getType() === "tool") { if (typeof message.content === "string") { - const previousMessage = merged[merged.length - 1]; + const previousMessage = updatedMsgs[updatedMsgs.length - 1]; if ( previousMessage?._getType() === "human" && Array.isArray(previousMessage.content) && @@ -58,7 +58,7 @@ function _mergeMessages( }); } else { // If not, we create a new human message with the tool result. - merged.push( + updatedMsgs.push( new HumanMessage({ content: [ { @@ -71,7 +71,7 @@ function _mergeMessages( ); } } else { - merged.push( + updatedMsgs.push( new HumanMessage({ content: [ { @@ -84,30 +84,10 @@ function _mergeMessages( ); } } else { - const previousMessage = merged[merged.length - 1]; - if ( - previousMessage?._getType() === "human" && - message._getType() === "human" - ) { - // eslint-disable-next-line @typescript-eslint/no-explicit-any - let combinedContent: Record[]; - if (typeof previousMessage.content === "string") { - combinedContent = [{ type: "text", text: previousMessage.content }]; - } else { - combinedContent = previousMessage.content; - } - if (typeof message.content === "string") { - combinedContent.push({ type: "text", text: message.content }); - } else { - combinedContent = combinedContent.concat(message.content); - } - previousMessage.content = combinedContent; - } else { - merged.push(message); - } + updatedMsgs.push(message); } } - return merged; + return updatedMsgs; } export function _convertLangChainToolCallToAnthropic( @@ -202,7 +182,7 @@ function _formatContent(content: MessageContent) { export function _convertMessagesToAnthropicPayload( messages: BaseMessage[] ): AnthropicMessageCreateParams { - const mergedMessages = _mergeMessages(messages); + const mergedMessages = _ensureMessageContents(messages); let system; if (mergedMessages.length > 0 && mergedMessages[0]._getType() === "system") { system = messages[0].content; diff --git a/libs/langchain-community/src/utils/bedrock/anthropic.ts b/libs/langchain-community/src/utils/bedrock/anthropic.ts index 4565a2f1615d..3f440bd2b014 100644 --- a/libs/langchain-community/src/utils/bedrock/anthropic.ts +++ b/libs/langchain-community/src/utils/bedrock/anthropic.ts @@ -47,15 +47,15 @@ function _formatImage(imageUrl: string) { } as any; } -function _mergeMessages( +function _ensureMessageContents( messages: BaseMessage[] ): (SystemMessage | HumanMessage | AIMessage)[] { // Merge runs of human/tool messages into single human messages with content blocks. - const merged = []; + const updatedMsgs = []; for (const message of messages) { if (message._getType() === "tool") { if (typeof message.content === "string") { - const previousMessage = merged[merged.length - 1]; + const previousMessage = updatedMsgs[updatedMsgs.length - 1]; if ( previousMessage?._getType() === "human" && Array.isArray(previousMessage.content) && @@ -70,7 +70,7 @@ function _mergeMessages( }); } else { // If not, we create a new human message with the tool result. - merged.push( + updatedMsgs.push( new HumanMessage({ content: [ { @@ -83,33 +83,23 @@ function _mergeMessages( ); } } else { - merged.push(new HumanMessage({ content: message.content })); + updatedMsgs.push( + new HumanMessage({ + content: [ + { + type: "tool_result", + content: _formatContent(message.content), + tool_use_id: (message as ToolMessage).tool_call_id, + }, + ], + }) + ); } } else { - const previousMessage = merged[merged.length - 1]; - if ( - previousMessage?._getType() === "human" && - message._getType() === "human" - ) { - // eslint-disable-next-line @typescript-eslint/no-explicit-any - let combinedContent: Record[]; - if (typeof previousMessage.content === "string") { - combinedContent = [{ type: "text", text: previousMessage.content }]; - } else { - combinedContent = previousMessage.content; - } - if (typeof message.content === "string") { - combinedContent.push({ type: "text", text: message.content }); - } else { - combinedContent = combinedContent.concat(message.content); - } - previousMessage.content = combinedContent; - } else { - merged.push(message); - } + updatedMsgs.push(message); } } - return merged; + return updatedMsgs; } export function _convertLangChainToolCallToAnthropic( @@ -170,7 +160,7 @@ export function formatMessagesForAnthropic(messages: BaseMessage[]): { system?: string; messages: Record[]; } { - const mergedMessages = _mergeMessages(messages); + const mergedMessages = _ensureMessageContents(messages); let system: string | undefined; if (mergedMessages.length > 0 && mergedMessages[0]._getType() === "system") { if (typeof messages[0].content !== "string") {