diff --git a/libs/langchain-anthropic/src/chat_models.ts b/libs/langchain-anthropic/src/chat_models.ts index f5e788bb2e29..8aa8295f2676 100644 --- a/libs/langchain-anthropic/src/chat_models.ts +++ b/libs/langchain-anthropic/src/chat_models.ts @@ -41,12 +41,13 @@ import { RunnableToolLike, } from "@langchain/core/runnables"; import { isZodSchema } from "@langchain/core/utils/types"; -import { ToolCall } from "@langchain/core/messages/tool"; +import { ToolCall, ToolCallChunk } from "@langchain/core/messages/tool"; import { z } from "zod"; import type { MessageCreateParams, Tool as AnthropicTool, } from "@anthropic-ai/sdk/resources/index.mjs"; +import { concat } from "@langchain/core/utils/stream"; import { AnthropicToolsOutputParser, extractToolCalls, @@ -85,6 +86,10 @@ export interface ChatAnthropicCallOptions type AnthropicMessageResponse = Anthropic.ContentBlock | AnthropicToolResponse; +function _toolsInParams(params: AnthropicMessageCreateParams): boolean { + return !!(params.tools && params.tools.length > 0); +} + function _formatImage(imageUrl: string) { const regex = /^data:(image\/.+);base64,(.+)$/; const match = imageUrl.match(regex); @@ -154,6 +159,130 @@ function isAnthropicTool(tool: any): tool is AnthropicTool { return "input_schema" in tool; } +function _makeMessageChunkFromAnthropicEvent( + data: Anthropic.Messages.RawMessageStreamEvent, + fields: { + streamUsage: boolean; + coerceContentToString: boolean; + usageData: { input_tokens: number; output_tokens: number }; + } +): { + chunk: AIMessageChunk; + usageData: { input_tokens: number; output_tokens: number }; +} | null { + let usageDataCopy = { ...fields.usageData }; + + if (data.type === "message_start") { + // eslint-disable-next-line @typescript-eslint/no-unused-vars + const { content, usage, ...additionalKwargs } = data.message; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const filteredAdditionalKwargs: Record = {}; + for (const [key, value] of Object.entries(additionalKwargs)) { + if (value !== undefined && value !== null) { + filteredAdditionalKwargs[key] = value; + } + } + usageDataCopy = usage; + let usageMetadata: UsageMetadata | undefined; + if (fields.streamUsage) { + usageMetadata = { + input_tokens: usage.input_tokens, + output_tokens: usage.output_tokens, + total_tokens: usage.input_tokens + usage.output_tokens, + }; + } + return { + chunk: new AIMessageChunk({ + content: fields.coerceContentToString ? "" : [], + additional_kwargs: filteredAdditionalKwargs, + usage_metadata: usageMetadata, + }), + usageData: usageDataCopy, + }; + } else if (data.type === "message_delta") { + let usageMetadata: UsageMetadata | undefined; + if (fields.streamUsage) { + usageMetadata = { + input_tokens: data.usage.output_tokens, + output_tokens: 0, + total_tokens: data.usage.output_tokens, + }; + } + if (data?.usage !== undefined) { + usageDataCopy.output_tokens += data.usage.output_tokens; + } + return { + chunk: new AIMessageChunk({ + content: fields.coerceContentToString ? "" : [], + additional_kwargs: { ...data.delta }, + usage_metadata: usageMetadata, + }), + usageData: usageDataCopy, + }; + } else if ( + data.type === "content_block_start" && + data.content_block.type === "tool_use" + ) { + return { + chunk: new AIMessageChunk({ + content: fields.coerceContentToString + ? "" + : [ + { + index: data.index, + ...data.content_block, + input: "", + }, + ], + additional_kwargs: {}, + }), + usageData: usageDataCopy, + }; + } else if ( + data.type === "content_block_delta" && + data.delta.type === "text_delta" + ) { + const content = data.delta?.text; + if (content !== undefined) { + return { + chunk: new AIMessageChunk({ + content: fields.coerceContentToString + ? content + : [ + { + index: data.index, + ...data.delta, + }, + ], + additional_kwargs: {}, + }), + usageData: usageDataCopy, + }; + } + } else if ( + data.type === "content_block_delta" && + data.delta.type === "input_json_delta" + ) { + return { + chunk: new AIMessageChunk({ + content: fields.coerceContentToString + ? "" + : [ + { + index: data.index, + input: data.delta.partial_json, + type: data.delta.type, + }, + ], + additional_kwargs: {}, + }), + usageData: usageDataCopy, + }; + } + + return null; +} + /** * Input to AnthropicChat class. */ @@ -420,6 +549,116 @@ function _formatMessagesForAnthropic(messages: BaseMessage[]): { }; } +function extractToolCallChunk( + chunk: AIMessageChunk +): ToolCallChunk | undefined { + let newToolCallChunk: ToolCallChunk | undefined; + + // Initial chunk for tool calls from anthropic contains identifying information like ID and name. + // This chunk does not contain any input JSON. + const toolUseChunks = Array.isArray(chunk.content) + ? chunk.content.find((c) => c.type === "tool_use") + : undefined; + if ( + toolUseChunks && + "index" in toolUseChunks && + "name" in toolUseChunks && + "id" in toolUseChunks + ) { + newToolCallChunk = { + args: "", + id: toolUseChunks.id, + name: toolUseChunks.name, + index: toolUseChunks.index, + type: "tool_call_chunk", + }; + } + + // Chunks after the initial chunk only contain the index and partial JSON. + const inputJsonDeltaChunks = Array.isArray(chunk.content) + ? chunk.content.find((c) => c.type === "input_json_delta") + : undefined; + if ( + inputJsonDeltaChunks && + "index" in inputJsonDeltaChunks && + "input" in inputJsonDeltaChunks + ) { + if (typeof inputJsonDeltaChunks.input === "string") { + newToolCallChunk = { + args: inputJsonDeltaChunks.input, + index: inputJsonDeltaChunks.index, + type: "tool_call_chunk", + }; + } else { + newToolCallChunk = { + args: JSON.stringify(inputJsonDeltaChunks.input, null, 2), + index: inputJsonDeltaChunks.index, + type: "tool_call_chunk", + }; + } + } + + return newToolCallChunk; +} + +function extractToken(chunk: AIMessageChunk): string | undefined { + return typeof chunk.content === "string" && chunk.content !== "" + ? chunk.content + : undefined; +} + +function extractToolUseContent( + chunk: AIMessageChunk, + concatenatedChunks: AIMessageChunk | undefined +) { + let newConcatenatedChunks = concatenatedChunks; + // Remove `tool_use` content types until the last chunk. + let toolUseContent: + | { + id: string; + type: "tool_use"; + name: string; + input: Record; + } + | undefined; + if (!newConcatenatedChunks) { + newConcatenatedChunks = chunk; + } else { + newConcatenatedChunks = concat(newConcatenatedChunks, chunk); + } + if ( + Array.isArray(newConcatenatedChunks.content) && + newConcatenatedChunks.content.find((c) => c.type === "tool_use") + ) { + try { + const toolUseMsg = newConcatenatedChunks.content.find( + (c) => c.type === "tool_use" + ); + if ( + !toolUseMsg || + !("input" in toolUseMsg || "name" in toolUseMsg || "id" in toolUseMsg) + ) + return; + const parsedArgs = JSON.parse(toolUseMsg.input); + if (parsedArgs) { + toolUseContent = { + type: "tool_use", + id: toolUseMsg.id, + name: toolUseMsg.name, + input: parsedArgs, + }; + } + } catch (_) { + // no-op + } + } + + return { + toolUseContent, + concatenatedChunks: newConcatenatedChunks, + }; +} + /** * Wrapper around Anthropic large language models. * @@ -674,126 +913,92 @@ export class ChatAnthropicMessages< ): AsyncGenerator { const params = this.invocationParams(options); const formattedMessages = _formatMessagesForAnthropic(messages); - if (options.tools !== undefined && options.tools.length > 0) { - const { generations } = await this._generateNonStreaming( - messages, - params, - { - signal: options.signal, - } - ); - const result = generations[0].message as AIMessage; - const toolCallChunks = result.tool_calls?.map( - (toolCall: ToolCall, index: number) => ({ - name: toolCall.name, - args: JSON.stringify(toolCall.args), - id: toolCall.id, - index, - }) - ); - yield new ChatGenerationChunk({ - message: new AIMessageChunk({ - content: result.content, - additional_kwargs: result.additional_kwargs, - tool_call_chunks: toolCallChunks, - usage_metadata: result.usage_metadata, - response_metadata: result.response_metadata, - }), - text: generations[0].text, - }); - } else { - const stream = await this.createStreamWithRetry({ - ...params, - ...formattedMessages, - stream: true, + const coerceContentToString = !_toolsInParams({ + ...params, + ...formattedMessages, + stream: false, + }); + + const stream = await this.createStreamWithRetry({ + ...params, + ...formattedMessages, + stream: true, + }); + let usageData = { input_tokens: 0, output_tokens: 0 }; + + let concatenatedChunks: AIMessageChunk | undefined; + + for await (const data of stream) { + if (options.signal?.aborted) { + stream.controller.abort(); + throw new Error("AbortError: User aborted the request."); + } + + const result = _makeMessageChunkFromAnthropicEvent(data, { + streamUsage: !!(this.streamUsage || options.streamUsage), + coerceContentToString, + usageData, }); - let usageData = { input_tokens: 0, output_tokens: 0 }; - for await (const data of stream) { - if (options.signal?.aborted) { - stream.controller.abort(); - throw new Error("AbortError: User aborted the request."); - } - if (data.type === "message_start") { - // eslint-disable-next-line @typescript-eslint/no-unused-vars - const { content, usage, ...additionalKwargs } = data.message; - // eslint-disable-next-line @typescript-eslint/no-explicit-any - const filteredAdditionalKwargs: Record = {}; - for (const [key, value] of Object.entries(additionalKwargs)) { - if (value !== undefined && value !== null) { - filteredAdditionalKwargs[key] = value; - } - } - usageData = usage; - let usageMetadata: UsageMetadata | undefined; - if (this.streamUsage || options.streamUsage) { - usageMetadata = { - input_tokens: usage.input_tokens, - output_tokens: usage.output_tokens, - total_tokens: usage.input_tokens + usage.output_tokens, - }; - } - yield new ChatGenerationChunk({ - message: new AIMessageChunk({ - content: "", - additional_kwargs: filteredAdditionalKwargs, - usage_metadata: usageMetadata, - }), - text: "", - }); - } else if (data.type === "message_delta") { - let usageMetadata: UsageMetadata | undefined; - if (this.streamUsage || options.streamUsage) { - usageMetadata = { - input_tokens: data.usage.output_tokens, - output_tokens: 0, - total_tokens: data.usage.output_tokens, - }; - } - yield new ChatGenerationChunk({ - message: new AIMessageChunk({ - content: "", - additional_kwargs: { ...data.delta }, - usage_metadata: usageMetadata, - }), - text: "", - }); - if (data?.usage !== undefined) { - usageData.output_tokens += data.usage.output_tokens; - } - } else if ( - data.type === "content_block_delta" && - data.delta.type === "text_delta" - ) { - const content = data.delta?.text; - if (content !== undefined) { - yield new ChatGenerationChunk({ - message: new AIMessageChunk({ - content, - additional_kwargs: {}, - }), - text: content, - }); - await runManager?.handleLLMNewToken(content); - } - } + if (!result) continue; + + const { chunk, usageData: updatedUsageData } = result; + usageData = updatedUsageData; + + const newToolCallChunk = extractToolCallChunk(chunk); + // Maintain concatenatedChunks for accessing the complete `tool_use` content block. + concatenatedChunks = concatenatedChunks + ? concat(concatenatedChunks, chunk) + : chunk; + + let toolUseContent; + const extractedContent = extractToolUseContent(chunk, concatenatedChunks); + if (extractedContent) { + toolUseContent = extractedContent.toolUseContent; + concatenatedChunks = extractedContent.concatenatedChunks; } - let usageMetadata: UsageMetadata | undefined; - if (this.streamUsage || options.streamUsage) { - usageMetadata = { - input_tokens: usageData.input_tokens, - output_tokens: usageData.output_tokens, - total_tokens: usageData.input_tokens + usageData.output_tokens, - }; + + // Filter partial `tool_use` content, and only add `tool_use` chunks if complete JSON available. + const chunkContent = Array.isArray(chunk.content) + ? chunk.content.filter((c) => c.type !== "tool_use") + : chunk.content; + if (Array.isArray(chunkContent) && toolUseContent) { + chunkContent.push(toolUseContent); } + + // Extract the text content token for text field and runManager. + const token = extractToken(chunk); yield new ChatGenerationChunk({ message: new AIMessageChunk({ - content: "", - additional_kwargs: { usage: usageData }, - usage_metadata: usageMetadata, + content: chunkContent, + additional_kwargs: chunk.additional_kwargs, + tool_call_chunks: newToolCallChunk ? [newToolCallChunk] : undefined, + usage_metadata: chunk.usage_metadata, + response_metadata: chunk.response_metadata, }), - text: "", + text: token ?? "", }); + + if (token) { + await runManager?.handleLLMNewToken(token); + } + } + + let usageMetadata: UsageMetadata | undefined; + if (this.streamUsage || options.streamUsage) { + usageMetadata = { + input_tokens: usageData.input_tokens, + output_tokens: usageData.output_tokens, + total_tokens: usageData.input_tokens + usageData.output_tokens, + }; } + yield new ChatGenerationChunk({ + message: new AIMessageChunk({ + content: coerceContentToString ? "" : [], + additional_kwargs: { usage: usageData }, + usage_metadata: usageMetadata, + }), + text: "", + }); } /** @ignore */ diff --git a/libs/langchain-anthropic/src/output_parsers.ts b/libs/langchain-anthropic/src/output_parsers.ts index 1168b8d54d14..c5608900b4b9 100644 --- a/libs/langchain-anthropic/src/output_parsers.ts +++ b/libs/langchain-anthropic/src/output_parsers.ts @@ -82,7 +82,12 @@ export function extractToolCalls(content: Record[]) { const toolCalls: ToolCall[] = []; for (const block of content) { if (block.type === "tool_use") { - toolCalls.push({ name: block.name, args: block.input, id: block.id }); + toolCalls.push({ + name: block.name, + args: block.input, + id: block.id, + type: "tool_call", + }); } } return toolCalls; diff --git a/libs/langchain-anthropic/src/tests/chat_models-tools.int.test.ts b/libs/langchain-anthropic/src/tests/chat_models-tools.int.test.ts index a6cbdcbd956b..0bd1fe766875 100644 --- a/libs/langchain-anthropic/src/tests/chat_models-tools.int.test.ts +++ b/libs/langchain-anthropic/src/tests/chat_models-tools.int.test.ts @@ -1,8 +1,14 @@ /* eslint-disable no-process-env */ import { expect, test } from "@jest/globals"; -import { AIMessage, HumanMessage, ToolMessage } from "@langchain/core/messages"; -import { StructuredTool } from "@langchain/core/tools"; +import { + AIMessage, + AIMessageChunk, + HumanMessage, + ToolMessage, +} from "@langchain/core/messages"; +import { StructuredTool, tool } from "@langchain/core/tools"; +import { concat } from "@langchain/core/utils/stream"; import { z } from "zod"; import { zodToJsonSchema } from "zod-to-json-schema"; import { ChatAnthropic } from "../chat_models.js"; @@ -27,7 +33,6 @@ class WeatherTool extends StructuredTool { name = "get_weather"; async _call(input: z.infer) { - console.log(`WeatherTool called with input: ${input}`); return `The weather in ${input.location} is 25°C`; } } @@ -78,7 +83,6 @@ test("Few shotting with tool calls", async () => { ), new HumanMessage("What did you say the weather was?"), ]); - console.log(res); expect(res.content).toContain("24"); }); @@ -90,12 +94,7 @@ test("Can bind & invoke StructuredTools", async () => { const result = await modelWithTools.invoke( "What is the weather in SF today?" ); - console.log( - { - tool_calls: JSON.stringify(result.content, null, 2), - }, - "Can bind & invoke StructuredTools" - ); + expect(Array.isArray(result.content)).toBeTruthy(); if (!Array.isArray(result.content)) { throw new Error("Content is not an array"); @@ -111,7 +110,7 @@ test("Can bind & invoke StructuredTools", async () => { } expect(toolCall).toBeTruthy(); const { name, input } = toolCall; - expect(toolCall.input).toEqual(result.tool_calls?.[0].args); + expect(input).toEqual(result.tool_calls?.[0].args); expect(name).toBe("get_weather"); expect(input).toBeTruthy(); expect(input.location).toBeTruthy(); @@ -128,7 +127,6 @@ test("Can bind & invoke StructuredTools", async () => { ), new HumanMessage("What did you say the weather was?"), ]); - console.log(result2); // This should work, but Anthorpic is too skeptical expect(result2.content).toContain("59"); }); @@ -141,12 +139,7 @@ test("Can bind & invoke AnthropicTools", async () => { const result = await modelWithTools.invoke( "What is the weather in London today?" ); - console.log( - { - tool_calls: JSON.stringify(result.content, null, 2), - }, - "Can bind & invoke StructuredTools" - ); + expect(Array.isArray(result.content)).toBeTruthy(); if (!Array.isArray(result.content)) { throw new Error("Content is not an array"); @@ -170,45 +163,43 @@ test("Can bind & invoke AnthropicTools", async () => { test("Can bind & stream AnthropicTools", async () => { const modelWithTools = model.bind({ tools: [anthropicTool], + tool_choice: { + type: "tool", + name: "get_weather", + }, }); const result = await modelWithTools.stream( "What is the weather in London today?" ); - let finalMessage; + let finalMessage: AIMessageChunk | undefined; for await (const item of result) { - console.log("item", JSON.stringify(item, null, 2)); - finalMessage = item; + if (!finalMessage) { + finalMessage = item; + } else { + finalMessage = concat(finalMessage, item); + } } + expect(finalMessage).toBeDefined(); if (!finalMessage) { throw new Error("No final message returned"); } - console.log( - { - tool_calls: JSON.stringify(finalMessage.content, null, 2), - }, - "Can bind & invoke StructuredTools" - ); expect(Array.isArray(finalMessage.content)).toBeTruthy(); if (!Array.isArray(finalMessage.content)) { throw new Error("Content is not an array"); } - let toolCall: AnthropicToolResponse | undefined; - finalMessage.content.forEach((item) => { - if (item.type === "tool_use") { - toolCall = item as AnthropicToolResponse; - } - }); - if (!toolCall) { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const toolCall = finalMessage.tool_calls?.[0]; + if (toolCall === undefined) { throw new Error("No tool call found"); } expect(toolCall).toBeTruthy(); - const { name, input } = toolCall; + const { name, args } = toolCall; expect(name).toBe("get_weather"); - expect(input).toBeTruthy(); - expect(input.location).toBeTruthy(); + expect(args).toBeTruthy(); + expect(args.location).toBeTruthy(); }); test("withStructuredOutput with zod schema", async () => { @@ -222,12 +213,6 @@ test("withStructuredOutput with zod schema", async () => { const result = await modelWithTools.invoke( "What is the weather in London today?" ); - console.log( - { - result, - }, - "withStructuredOutput with zod schema" - ); expect(typeof result.location).toBe("string"); }); @@ -242,12 +227,7 @@ test("withStructuredOutput with AnthropicTool", async () => { const result = await modelWithTools.invoke( "What is the weather in London today?" ); - console.log( - { - result, - }, - "withStructuredOutput with AnthropicTool" - ); + expect(typeof result.location).toBe("string"); }); @@ -263,12 +243,7 @@ test("withStructuredOutput JSON Schema only", async () => { const result = await modelWithTools.invoke( "What is the weather in London today?" ); - console.log( - { - result, - }, - "withStructuredOutput JSON Schema only" - ); + expect(typeof result.location).toBe("string"); }); @@ -314,12 +289,7 @@ test("Can pass tool_choice", async () => { const result = await modelWithTools.invoke( "What is the sum of 272818 and 281818?" ); - console.log( - { - tool_calls: JSON.stringify(result.content, null, 2), - }, - "Can bind & invoke StructuredTools" - ); + expect(Array.isArray(result.content)).toBeTruthy(); if (!Array.isArray(result.content)) { throw new Error("Content is not an array"); @@ -335,7 +305,7 @@ test("Can pass tool_choice", async () => { } expect(toolCall).toBeTruthy(); const { name, input } = toolCall; - expect(toolCall.input).toEqual(result.tool_calls?.[0].args); + expect(input).toEqual(result.tool_calls?.[0].args); expect(name).toBe("get_weather"); expect(input).toBeTruthy(); expect(input.location).toBeTruthy(); @@ -386,3 +356,49 @@ test("withStructuredOutput will always force tool usage", async () => { expect(castMessage.tool_calls).toHaveLength(1); expect(castMessage.tool_calls?.[0].name).toBe("get_weather"); }); + +test("Can stream tool calls", async () => { + const weatherTool = tool((_) => "no-op", { + name: "get_weather", + description: zodSchema.description, + schema: zodSchema, + }); + + const modelWithTools = model.bindTools([weatherTool], { + tool_choice: { + type: "tool", + name: "get_weather", + }, + }); + const stream = await modelWithTools.stream( + "What is the weather in San Francisco CA?" + ); + + let realToolCallChunkStreams = 0; + let prevToolCallChunkArgs = ""; + let finalChunk: AIMessageChunk | undefined; + for await (const chunk of stream) { + if (!finalChunk) { + finalChunk = chunk; + } else { + finalChunk = concat(finalChunk, chunk); + } + if (chunk.tool_call_chunks?.[0]?.args) { + // Check if the args have changed since the last chunk. + // This helps count the number of unique arg updates in the stream, + // ensuring we're receiving multiple chunks with different arg content. + if ( + !prevToolCallChunkArgs || + prevToolCallChunkArgs !== chunk.tool_call_chunks[0].args + ) { + realToolCallChunkStreams += 1; + } + prevToolCallChunkArgs = chunk.tool_call_chunks[0].args; + } + } + + expect(finalChunk?.tool_calls?.[0]).toBeDefined(); + expect(finalChunk?.tool_calls?.[0].name).toBe("get_weather"); + expect(finalChunk?.tool_calls?.[0].args.location).toBeDefined(); + expect(realToolCallChunkStreams).toBeGreaterThan(1); +});