From dc574f6914ddf599d7e63bef2b1c092af197603e Mon Sep 17 00:00:00 2001 From: bracesproul Date: Wed, 24 Jul 2024 15:20:30 -0700 Subject: [PATCH 1/2] anthropic[patch]: Fix passing streamed tool calls back to anthropic --- libs/langchain-anthropic/src/chat_models.ts | 92 +++++++++++++++++-- .../src/tests/chat_models-tools.int.test.ts | 41 +++++++++ 2 files changed, 127 insertions(+), 6 deletions(-) diff --git a/libs/langchain-anthropic/src/chat_models.ts b/libs/langchain-anthropic/src/chat_models.ts index a0a0aae6333c..6cb6bbde625b 100644 --- a/libs/langchain-anthropic/src/chat_models.ts +++ b/libs/langchain-anthropic/src/chat_models.ts @@ -160,10 +160,18 @@ function _makeMessageChunkFromAnthropicEvent( streamUsage: boolean; coerceContentToString: boolean; usageData: { input_tokens: number; output_tokens: number }; + toolUse?: { + id: string; + name: string; + }; } ): { chunk: AIMessageChunk; usageData: { input_tokens: number; output_tokens: number }; + toolUse?: { + id: string; + name: string; + }; } | null { let usageDataCopy = { ...fields.usageData }; @@ -233,6 +241,10 @@ function _makeMessageChunkFromAnthropicEvent( additional_kwargs: {}, }), usageData: usageDataCopy, + toolUse: { + id: data.content_block.id, + name: data.content_block.name, + }, }; } else if ( data.type === "content_block_delta" && @@ -274,6 +286,25 @@ function _makeMessageChunkFromAnthropicEvent( }), usageData: usageDataCopy, }; + } else if (data.type === "content_block_stop" && fields.toolUse) { + // Only yield the ID & name when the tool_use block is complete. + // This is so the names & IDs do not get concatenated. + return { + chunk: new AIMessageChunk({ + content: fields.coerceContentToString + ? "" + : [ + { + id: fields.toolUse.id, + name: fields.toolUse.name, + index: data.index, + type: "input_json_delta", + }, + ], + additional_kwargs: {}, + }), + usageData: usageDataCopy, + }; } return null; @@ -424,6 +455,9 @@ export function _convertLangChainToolCallToAnthropic( } function _formatContent(content: MessageContent) { + const toolTypes = ["tool_use", "tool_result", "input_json_delta"]; + const textTypes = ["text", "text_delta"]; + if (typeof content === "string") { return content; } else { @@ -439,16 +473,34 @@ function _formatContent(content: MessageContent) { type: "image" as const, // Explicitly setting the type as "image" source, }; - } else if (contentPart.type === "text") { + } else if (textTypes.find((t) => t === contentPart.type) && "text" in contentPart) { // Assuming contentPart is of type MessageContentText here return { type: "text" as const, // Explicitly setting the type as "text" text: contentPart.text, }; } else if ( - contentPart.type === "tool_use" || - contentPart.type === "tool_result" + toolTypes.find((t) => t === contentPart.type) ) { + if ("index" in contentPart) { + // Anthropic does not support passing the index field here, so we remove it + delete contentPart.index; + } + + if (contentPart.type === "input_json_delta") { + // If type is `input_json_delta`, rename to `tool_use` for Anthropic + contentPart.type = "tool_use"; + } + + if ("input" in contentPart) { + // If the input is a JSON string, attempt to parse it + try { + contentPart.input = JSON.parse(contentPart.input); + } catch { + // no-op + } + } + // TODO: Fix when SDK types are fixed return { ...contentPart, @@ -519,7 +571,9 @@ function _formatMessagesForAnthropic(messages: BaseMessage[]): { const hasMismatchedToolCalls = !message.tool_calls.every((toolCall) => content.find( (contentPart) => - contentPart.type === "tool_use" && contentPart.id === toolCall.id + (contentPart.type === "tool_use" || + contentPart.type === "input_json_delta") && + contentPart.id === toolCall.id ) ); if (hasMismatchedToolCalls) { @@ -581,12 +635,14 @@ function extractToolCallChunk( ) { if (typeof inputJsonDeltaChunks.input === "string") { newToolCallChunk = { + id: inputJsonDeltaChunks.id, args: inputJsonDeltaChunks.input, index: inputJsonDeltaChunks.index, type: "tool_call_chunk", }; } else { newToolCallChunk = { + id: inputJsonDeltaChunks.id, args: JSON.stringify(inputJsonDeltaChunks.input, null, 2), index: inputJsonDeltaChunks.index, type: "tool_call_chunk", @@ -919,6 +975,14 @@ export class ChatAnthropicMessages< let usageData = { input_tokens: 0, output_tokens: 0 }; let concatenatedChunks: AIMessageChunk | undefined; + // Anthropic only yields the tool name and id once, so we need to save those + // so we can yield them with the rest of the tool_use content. + let toolUse: + | { + id: string; + name: string; + } + | undefined; for await (const data of stream) { if (options.signal?.aborted) { @@ -930,12 +994,25 @@ export class ChatAnthropicMessages< streamUsage: !!(this.streamUsage || options.streamUsage), coerceContentToString, usageData, + toolUse: toolUse ? { + id: toolUse.id, + name: toolUse.name, + } : undefined, }); if (!result) continue; - const { chunk, usageData: updatedUsageData } = result; + const { + chunk, + usageData: updatedUsageData, + toolUse: updatedToolUse, + } = result; + usageData = updatedUsageData; + if (updatedToolUse) { + toolUse = updatedToolUse; + } + const newToolCallChunk = extractToolCallChunk(chunk); // Maintain concatenatedChunks for accessing the complete `tool_use` content block. concatenatedChunks = concatenatedChunks @@ -1015,11 +1092,14 @@ export class ChatAnthropicMessages< }, } : requestOptions; + const formattedMsgs = _formatMessagesForAnthropic(messages); + console.log("formattedMsgs"); + console.dir(formattedMsgs, { depth: null }); const response = await this.completionWithRetry( { ...params, stream: false, - ..._formatMessagesForAnthropic(messages), + ...formattedMsgs, }, options ); diff --git a/libs/langchain-anthropic/src/tests/chat_models-tools.int.test.ts b/libs/langchain-anthropic/src/tests/chat_models-tools.int.test.ts index 3af61e38f04e..f3566a657f21 100644 --- a/libs/langchain-anthropic/src/tests/chat_models-tools.int.test.ts +++ b/libs/langchain-anthropic/src/tests/chat_models-tools.int.test.ts @@ -440,3 +440,44 @@ test("llm token callbacks can handle tool calls", async () => { if (!args) return; expect(args).toEqual(JSON.parse(tokens)); }); + +test.only("Anthropic can stream tool calls, and invoke again with that tool call", async () => { + const input = [ + new HumanMessage("What is the weather in SF?"), + ]; + + const weatherTool = tool( + (_) => "The weather in San Francisco is 25°C", + { + name: "get_weather", + description: zodSchema.description, + schema: zodSchema, + } + ); + + const modelWithTools = model.bindTools([weatherTool]); + + const stream = await modelWithTools.stream(input); + + let finalChunk: AIMessageChunk | undefined; + for await (const chunk of stream) { + finalChunk = !finalChunk ? chunk : concat(finalChunk, chunk); + } + if (!finalChunk) { + throw new Error("chunk not defined"); + } + // Push the AI message with the tool call to the input array. + input.push(finalChunk); + // Push a ToolMessage to the input array to represent the tool call response. + input.push( + new ToolMessage({ + tool_call_id: finalChunk.tool_calls?.[0].id ?? "", + content: + "The weather in San Francisco is currently 25 degrees and sunny.", + name: "get_weather", + }) + ); + // Invoke again to ensure Anthropic can handle it's own tool call. + const finalResult = await modelWithTools.invoke(input); + console.dir(finalResult, { depth: null }); +}); From a5f41a63fafe532c4224f12b832148ed2828c803 Mon Sep 17 00:00:00 2001 From: bracesproul Date: Wed, 24 Jul 2024 15:43:21 -0700 Subject: [PATCH 2/2] rm anthropic test, implement standard tests --- libs/langchain-anthropic/src/chat_models.ts | 50 +++--- .../src/tests/chat_models-tools.int.test.ts | 41 ----- libs/langchain-aws/package.json | 2 +- .../src/integration_tests/chat_models.ts | 170 +++++++++++++++++- 4 files changed, 197 insertions(+), 66 deletions(-) diff --git a/libs/langchain-anthropic/src/chat_models.ts b/libs/langchain-anthropic/src/chat_models.ts index 6cb6bbde625b..a9fcbb118944 100644 --- a/libs/langchain-anthropic/src/chat_models.ts +++ b/libs/langchain-anthropic/src/chat_models.ts @@ -473,29 +473,32 @@ function _formatContent(content: MessageContent) { type: "image" as const, // Explicitly setting the type as "image" source, }; - } else if (textTypes.find((t) => t === contentPart.type) && "text" in contentPart) { + } else if ( + textTypes.find((t) => t === contentPart.type) && + "text" in contentPart + ) { // Assuming contentPart is of type MessageContentText here return { type: "text" as const, // Explicitly setting the type as "text" text: contentPart.text, }; - } else if ( - toolTypes.find((t) => t === contentPart.type) - ) { - if ("index" in contentPart) { - // Anthropic does not support passing the index field here, so we remove it - delete contentPart.index; + } else if (toolTypes.find((t) => t === contentPart.type)) { + const contentPartCopy = { ...contentPart }; + if ("index" in contentPartCopy) { + // Anthropic does not support passing the index field here, so we remove it. + delete contentPartCopy.index; } - - if (contentPart.type === "input_json_delta") { - // If type is `input_json_delta`, rename to `tool_use` for Anthropic - contentPart.type = "tool_use"; + + if (contentPartCopy.type === "input_json_delta") { + // `input_json_delta` type only represents yielding partial tool inputs + // and is not a valid type for Anthropic messages. + contentPartCopy.type = "tool_use"; } - if ("input" in contentPart) { - // If the input is a JSON string, attempt to parse it + if ("input" in contentPartCopy) { + // Anthropic tool use inputs should be valid objects, when applicable. try { - contentPart.input = JSON.parse(contentPart.input); + contentPartCopy.input = JSON.parse(contentPartCopy.input); } catch { // no-op } @@ -503,7 +506,7 @@ function _formatContent(content: MessageContent) { // TODO: Fix when SDK types are fixed return { - ...contentPart, + ...contentPartCopy, // eslint-disable-next-line @typescript-eslint/no-explicit-any } as any; } else { @@ -636,6 +639,7 @@ function extractToolCallChunk( if (typeof inputJsonDeltaChunks.input === "string") { newToolCallChunk = { id: inputJsonDeltaChunks.id, + name: inputJsonDeltaChunks.name, args: inputJsonDeltaChunks.input, index: inputJsonDeltaChunks.index, type: "tool_call_chunk", @@ -643,6 +647,7 @@ function extractToolCallChunk( } else { newToolCallChunk = { id: inputJsonDeltaChunks.id, + name: inputJsonDeltaChunks.name, args: JSON.stringify(inputJsonDeltaChunks.input, null, 2), index: inputJsonDeltaChunks.index, type: "tool_call_chunk", @@ -994,10 +999,12 @@ export class ChatAnthropicMessages< streamUsage: !!(this.streamUsage || options.streamUsage), coerceContentToString, usageData, - toolUse: toolUse ? { - id: toolUse.id, - name: toolUse.name, - } : undefined, + toolUse: toolUse + ? { + id: toolUse.id, + name: toolUse.name, + } + : undefined, }); if (!result) continue; @@ -1092,14 +1099,11 @@ export class ChatAnthropicMessages< }, } : requestOptions; - const formattedMsgs = _formatMessagesForAnthropic(messages); - console.log("formattedMsgs"); - console.dir(formattedMsgs, { depth: null }); const response = await this.completionWithRetry( { ...params, stream: false, - ...formattedMsgs, + ..._formatMessagesForAnthropic(messages), }, options ); diff --git a/libs/langchain-anthropic/src/tests/chat_models-tools.int.test.ts b/libs/langchain-anthropic/src/tests/chat_models-tools.int.test.ts index f3566a657f21..3af61e38f04e 100644 --- a/libs/langchain-anthropic/src/tests/chat_models-tools.int.test.ts +++ b/libs/langchain-anthropic/src/tests/chat_models-tools.int.test.ts @@ -440,44 +440,3 @@ test("llm token callbacks can handle tool calls", async () => { if (!args) return; expect(args).toEqual(JSON.parse(tokens)); }); - -test.only("Anthropic can stream tool calls, and invoke again with that tool call", async () => { - const input = [ - new HumanMessage("What is the weather in SF?"), - ]; - - const weatherTool = tool( - (_) => "The weather in San Francisco is 25°C", - { - name: "get_weather", - description: zodSchema.description, - schema: zodSchema, - } - ); - - const modelWithTools = model.bindTools([weatherTool]); - - const stream = await modelWithTools.stream(input); - - let finalChunk: AIMessageChunk | undefined; - for await (const chunk of stream) { - finalChunk = !finalChunk ? chunk : concat(finalChunk, chunk); - } - if (!finalChunk) { - throw new Error("chunk not defined"); - } - // Push the AI message with the tool call to the input array. - input.push(finalChunk); - // Push a ToolMessage to the input array to represent the tool call response. - input.push( - new ToolMessage({ - tool_call_id: finalChunk.tool_calls?.[0].id ?? "", - content: - "The weather in San Francisco is currently 25 degrees and sunny.", - name: "get_weather", - }) - ); - // Invoke again to ensure Anthropic can handle it's own tool call. - const finalResult = await modelWithTools.invoke(input); - console.dir(finalResult, { depth: null }); -}); diff --git a/libs/langchain-aws/package.json b/libs/langchain-aws/package.json index 5bfac18e5270..679c204ac2d4 100644 --- a/libs/langchain-aws/package.json +++ b/libs/langchain-aws/package.json @@ -97,4 +97,4 @@ "index.d.ts", "index.d.cts" ] -} \ No newline at end of file +} diff --git a/libs/langchain-standard-tests/src/integration_tests/chat_models.ts b/libs/langchain-standard-tests/src/integration_tests/chat_models.ts index f3151ed32b59..ac73ef8631a6 100644 --- a/libs/langchain-standard-tests/src/integration_tests/chat_models.ts +++ b/libs/langchain-standard-tests/src/integration_tests/chat_models.ts @@ -10,10 +10,11 @@ import { getBufferString, } from "@langchain/core/messages"; import { z } from "zod"; -import { StructuredTool } from "@langchain/core/tools"; +import { StructuredTool, tool } from "@langchain/core/tools"; import { zodToJsonSchema } from "zod-to-json-schema"; import { ChatPromptTemplate } from "@langchain/core/prompts"; import { RunnableLambda } from "@langchain/core/runnables"; +import { concat } from "@langchain/core/utils/stream"; import { BaseChatModelsTests, BaseChatModelsTestsFields, @@ -522,6 +523,159 @@ export abstract class ChatModelIntegrationTests< expect(cacheValue2).toEqual(cacheValue); } + /** + * This test verifies models can invoke a tool, and use the AIMessage + * with the tool call in a followup request. This is useful when building + * agents, or other pipelines that invoke tools. + */ + async testModelCanUseToolUseAIMessage() { + if (!this.chatModelHasToolCalling) { + console.log("Test requires tool calling. Skipping..."); + return; + } + + const model = new this.Cls(this.constructorArgs); + if (!model.bindTools) { + throw new Error( + "bindTools undefined. Cannot test OpenAI formatted tool calls." + ); + } + + const weatherSchema = z.object({ + location: z.string().describe("The location to get the weather for."), + }); + + // Define the tool + const weatherTool = tool( + (_) => "The weather in San Francisco is 70 degrees and sunny.", + { + name: "get_current_weather", + schema: weatherSchema, + description: "Get the current weather for a location.", + } + ); + + const modelWithTools = model.bindTools([weatherTool]); + + // List of messages to initially invoke the model with, and to hold + // followup messages to invoke the model with. + const messages = [ + new HumanMessage( + "What's the weather like in San Francisco right now? Use the 'get_current_weather' tool to find the answer." + ), + ]; + + const result: AIMessage = await modelWithTools.invoke(messages); + + expect(result.tool_calls?.[0]).toBeDefined(); + if (!result.tool_calls?.[0]) { + throw new Error("result.tool_calls is undefined"); + } + const { tool_calls } = result; + expect(tool_calls[0].name).toBe("get_current_weather"); + + // Push the result of the tool call into the messages array so we can + // confirm in the followup request the model can use the tool call. + messages.push(result); + + // Create a dummy ToolMessage representing the output of the tool call. + const toolMessage = new ToolMessage({ + tool_call_id: tool_calls[0].id ?? "", + name: tool_calls[0].name, + content: await weatherTool.invoke( + tool_calls[0].args as z.infer + ), + }); + messages.push(toolMessage); + + const finalResult = await modelWithTools.invoke(messages); + + expect(finalResult.content).not.toBe(""); + } + + /** + * Same as the above test, but streaming both model invocations. + */ + async testModelCanUseToolUseAIMessageWithStreaming() { + if (!this.chatModelHasToolCalling) { + console.log("Test requires tool calling. Skipping..."); + return; + } + + const model = new this.Cls(this.constructorArgs); + if (!model.bindTools) { + throw new Error( + "bindTools undefined. Cannot test OpenAI formatted tool calls." + ); + } + + const weatherSchema = z.object({ + location: z.string().describe("The location to get the weather for."), + }); + + // Define the tool + const weatherTool = tool( + (_) => "The weather in San Francisco is 70 degrees and sunny.", + { + name: "get_current_weather", + schema: weatherSchema, + description: "Get the current weather for a location.", + } + ); + + const modelWithTools = model.bindTools([weatherTool]); + + // List of messages to initially invoke the model with, and to hold + // followup messages to invoke the model with. + const messages = [ + new HumanMessage( + "What's the weather like in San Francisco right now? Use the 'get_current_weather' tool to find the answer." + ), + ]; + + const stream = await modelWithTools.stream(messages); + let result: AIMessageChunk | undefined; + for await (const chunk of stream) { + result = !result ? chunk : concat(result, chunk); + } + + expect(result).toBeDefined(); + if (!result) return; + + expect(result.tool_calls?.[0]).toBeDefined(); + if (!result.tool_calls?.[0]) { + throw new Error("result.tool_calls is undefined"); + } + + const { tool_calls } = result; + expect(tool_calls[0].name).toBe("get_current_weather"); + + // Push the result of the tool call into the messages array so we can + // confirm in the followup request the model can use the tool call. + messages.push(result); + + // Create a dummy ToolMessage representing the output of the tool call. + const toolMessage = new ToolMessage({ + tool_call_id: tool_calls[0].id ?? "", + name: tool_calls[0].name, + content: await weatherTool.invoke( + tool_calls[0].args as z.infer + ), + }); + messages.push(toolMessage); + + const finalStream = await modelWithTools.stream(messages); + let finalResult: AIMessageChunk | undefined; + for await (const chunk of finalStream) { + finalResult = !finalResult ? chunk : concat(finalResult, chunk); + } + + expect(finalResult).toBeDefined(); + if (!finalResult) return; + + expect(finalResult.content).not.toBe(""); + } + /** * Run all unit tests for the chat model. * Each test is wrapped in a try/catch block to prevent the entire test suite from failing. @@ -629,6 +783,20 @@ export abstract class ChatModelIntegrationTests< console.error("testCacheComplexMessageTypes failed", e); } + try { + await this.testModelCanUseToolUseAIMessage(); + } catch (e: any) { + allTestsPassed = false; + console.error("testModelCanUseToolUseAIMessage failed", e); + } + + try { + await this.testModelCanUseToolUseAIMessageWithStreaming(); + } catch (e: any) { + allTestsPassed = false; + console.error("testModelCanUseToolUseAIMessageWithStreaming failed", e); + } + return allTestsPassed; } }