From 023a581125f3701f04790ca99d793b1ecb638ef6 Mon Sep 17 00:00:00 2001 From: Brace Sproul Date: Mon, 29 Jul 2024 15:30:01 -0700 Subject: [PATCH] standard-tests[minor]: Add tests for parallel tool calls (#6258) * standard-tests[minor]: Add tests for parallel tool calls * run tests * cr * chore: lint files * cr * refactored test * cr * add missing import --- .../tests/chat_models.standard.int.test.ts | 13 ++ .../tests/chat_models.standard.int.test.ts | 7 + .../tests/chat_models.standard.int.test.ts | 2 + .../tests/chat_models.standard.int.test.ts | 7 + .../azure/chat_models.standard.int.test.ts | 7 + .../tests/chat_models.standard.int.test.ts | 13 ++ .../src/integration_tests/chat_models.ts | 214 ++++++++++++++++++ 7 files changed, 263 insertions(+) diff --git a/libs/langchain-anthropic/src/tests/chat_models.standard.int.test.ts b/libs/langchain-anthropic/src/tests/chat_models.standard.int.test.ts index 1980680bf019..a0fd9a0f7420 100644 --- a/libs/langchain-anthropic/src/tests/chat_models.standard.int.test.ts +++ b/libs/langchain-anthropic/src/tests/chat_models.standard.int.test.ts @@ -18,11 +18,24 @@ class ChatAnthropicStandardIntegrationTests extends ChatModelIntegrationTests< Cls: ChatAnthropic, chatModelHasToolCalling: true, chatModelHasStructuredOutput: true, + supportsParallelToolCalls: true, constructorArgs: { model: "claude-3-haiku-20240307", }, }); } + + async testParallelToolCalling() { + // Override constructor args to use a better model for this test. + // I found that haiku struggles with parallel tool calling. + const constructorArgsCopy = { ...this.constructorArgs }; + this.constructorArgs = { + ...this.constructorArgs, + model: "claude-3-5-sonnet-20240620", + }; + await super.testParallelToolCalling(); + this.constructorArgs = constructorArgsCopy; + } } const testClass = new ChatAnthropicStandardIntegrationTests(); diff --git a/libs/langchain-aws/src/tests/chat_models.standard.int.test.ts b/libs/langchain-aws/src/tests/chat_models.standard.int.test.ts index bf222ea61ba3..bfc099864956 100644 --- a/libs/langchain-aws/src/tests/chat_models.standard.int.test.ts +++ b/libs/langchain-aws/src/tests/chat_models.standard.int.test.ts @@ -17,6 +17,7 @@ class ChatBedrockConverseStandardIntegrationTests extends ChatModelIntegrationTe Cls: ChatBedrockConverse, chatModelHasToolCalling: true, chatModelHasStructuredOutput: true, + supportsParallelToolCalls: true, constructorArgs: { region, model: "anthropic.claude-3-sonnet-20240229-v1:0", @@ -51,6 +52,12 @@ class ChatBedrockConverseStandardIntegrationTests extends ChatModelIntegrationTe "Not properly implemented." ); } + + async testParallelToolCalling() { + // Pass `true` in the second argument to only verify it can support parallel tool calls in the message history. + // This is because the model struggles to actually call parallel tools. + await super.testParallelToolCalling(undefined, true); + } } const testClass = new ChatBedrockConverseStandardIntegrationTests(); diff --git a/libs/langchain-google-genai/src/tests/chat_models.standard.int.test.ts b/libs/langchain-google-genai/src/tests/chat_models.standard.int.test.ts index 384d04660e4b..835e91644eb3 100644 --- a/libs/langchain-google-genai/src/tests/chat_models.standard.int.test.ts +++ b/libs/langchain-google-genai/src/tests/chat_models.standard.int.test.ts @@ -21,8 +21,10 @@ class ChatGoogleGenerativeAIStandardIntegrationTests extends ChatModelIntegratio Cls: ChatGoogleGenerativeAI, chatModelHasToolCalling: true, chatModelHasStructuredOutput: true, + supportsParallelToolCalls: true, constructorArgs: { maxRetries: 1, + model: "gemini-1.5-pro", }, }); } diff --git a/libs/langchain-google-vertexai/src/tests/chat_models.standard.int.test.ts b/libs/langchain-google-vertexai/src/tests/chat_models.standard.int.test.ts index 7e02632bb658..fbcd8a713659 100644 --- a/libs/langchain-google-vertexai/src/tests/chat_models.standard.int.test.ts +++ b/libs/langchain-google-vertexai/src/tests/chat_models.standard.int.test.ts @@ -19,6 +19,7 @@ class ChatVertexAIStandardIntegrationTests extends ChatModelIntegrationTests< Cls: ChatVertexAI, chatModelHasToolCalling: true, chatModelHasStructuredOutput: true, + supportsParallelToolCalls: true, invokeResponseType: AIMessageChunk, constructorArgs: { model: "gemini-1.5-pro", @@ -42,6 +43,12 @@ class ChatVertexAIStandardIntegrationTests extends ChatModelIntegrationTests< "Google VertexAI only supports objects in schemas when the parameters are defined." ); } + + async testParallelToolCalling() { + // Pass `true` in the second argument to only verify it can support parallel tool calls in the message history. + // This is because the model struggles to actually call parallel tools. + await super.testParallelToolCalling(undefined, true); + } } const testClass = new ChatVertexAIStandardIntegrationTests(); diff --git a/libs/langchain-openai/src/tests/azure/chat_models.standard.int.test.ts b/libs/langchain-openai/src/tests/azure/chat_models.standard.int.test.ts index 641f257b3e1d..8146f04d0f88 100644 --- a/libs/langchain-openai/src/tests/azure/chat_models.standard.int.test.ts +++ b/libs/langchain-openai/src/tests/azure/chat_models.standard.int.test.ts @@ -32,6 +32,7 @@ class AzureChatOpenAIStandardIntegrationTests extends ChatModelIntegrationTests< Cls: AzureChatOpenAI, chatModelHasToolCalling: true, chatModelHasStructuredOutput: true, + supportsParallelToolCalls: true, constructorArgs: { model: "gpt-3.5-turbo", }, @@ -62,6 +63,12 @@ class AzureChatOpenAIStandardIntegrationTests extends ChatModelIntegrationTests< "AzureChatOpenAI only supports objects in schemas when the parameters are defined." ); } + + async testParallelToolCalling() { + // Pass `true` in the second argument to only verify it can support parallel tool calls in the message history. + // This is because the model struggles to actually call parallel tools. + await super.testParallelToolCalling(undefined, true); + } } const testClass = new AzureChatOpenAIStandardIntegrationTests(); diff --git a/libs/langchain-openai/src/tests/chat_models.standard.int.test.ts b/libs/langchain-openai/src/tests/chat_models.standard.int.test.ts index 8151611adf1a..f16536d91997 100644 --- a/libs/langchain-openai/src/tests/chat_models.standard.int.test.ts +++ b/libs/langchain-openai/src/tests/chat_models.standard.int.test.ts @@ -18,6 +18,7 @@ class ChatOpenAIStandardIntegrationTests extends ChatModelIntegrationTests< Cls: ChatOpenAI, chatModelHasToolCalling: true, chatModelHasStructuredOutput: true, + supportsParallelToolCalls: true, constructorArgs: { model: "gpt-3.5-turbo", }, @@ -44,6 +45,18 @@ class ChatOpenAIStandardIntegrationTests extends ChatModelIntegrationTests< "\nOpenAI only supports objects in schemas when the parameters are defined." ); } + + async testParallelToolCalling() { + // Override constructor args to use a better model for this test. + // I found that GPT 3.5 struggles with parallel tool calling. + const constructorArgsCopy = { ...this.constructorArgs }; + this.constructorArgs = { + ...this.constructorArgs, + model: "gpt-4o", + }; + await super.testParallelToolCalling(); + this.constructorArgs = constructorArgsCopy; + } } const testClass = new ChatOpenAIStandardIntegrationTests(); diff --git a/libs/langchain-standard-tests/src/integration_tests/chat_models.ts b/libs/langchain-standard-tests/src/integration_tests/chat_models.ts index ce23b9f6869b..8fff150f1cf5 100644 --- a/libs/langchain-standard-tests/src/integration_tests/chat_models.ts +++ b/libs/langchain-standard-tests/src/integration_tests/chat_models.ts @@ -71,6 +71,11 @@ interface ChatModelIntegrationTestsFields< * @default "abc123" */ functionId?: string; + /** + * Whether or not the model supports parallel tool calling. + * @default false + */ + supportsParallelToolCalls?: boolean; } export abstract class ChatModelIntegrationTests< @@ -82,6 +87,8 @@ export abstract class ChatModelIntegrationTests< invokeResponseType: typeof AIMessage | typeof AIMessageChunk = AIMessage; + supportsParallelToolCalls = false; + constructor( fields: ChatModelIntegrationTestsFields< CallOptions, @@ -93,6 +100,8 @@ export abstract class ChatModelIntegrationTests< this.functionId = fields.functionId ?? this.functionId; this.invokeResponseType = fields.invokeResponseType ?? this.invokeResponseType; + this.supportsParallelToolCalls = + fields.supportsParallelToolCalls ?? this.supportsParallelToolCalls; } /** @@ -1313,6 +1322,204 @@ Extraction path: {extractionPath}`, expect(typeof result.apiDetails === "object").toBeTruthy(); } + /** + * Tests the chat model's ability to handle parallel tool calls in various scenarios. + * This comprehensive test covers three aspects of parallel tool calling: + * 1. Invoking multiple tools simultaneously + * 2. Streaming responses with parallel tool calls + * 3. Processing message histories containing parallel tool calls + * + * The test uses a weather tool and a current time tool to simulate complex, multi-tool scenarios. + * It ensures that the model can correctly process and respond to prompts requiring multiple tool calls, + * both in streaming and non-streaming contexts, and can handle message histories with parallel tool calls. + * + * @param {InstanceType["ParsedCallOptions"] | undefined} callOptions Optional call options to pass to the model. + * @param {boolean} onlyVerifyHistory If true, only verifies the message history test. + */ + async testParallelToolCalling( + callOptions?: InstanceType["ParsedCallOptions"], + onlyVerifyHistory = false + ) { + // Skip the test if the model doesn't support tool calling + if (!this.chatModelHasToolCalling) { + console.log("Test requires tool calling. Skipping..."); + return; + } + // Skip the test if the model doesn't support parallel tool calls + if (!this.supportsParallelToolCalls) { + console.log("Test requires parallel tool calls. Skipping..."); + return; + } + const model = new this.Cls(this.constructorArgs); + if (!model.bindTools) { + throw new Error( + "bindTools undefined. Cannot test OpenAI formatted tool calls." + ); + } + + const weatherTool = tool((_) => "no-op", { + name: "get_current_weather", + description: "Get the current weather in a given location", + schema: z.object({ + location: z.string().describe("The city name, e.g. San Francisco"), + }), + }); + const currentTimeTool = tool((_) => "no-op", { + name: "get_current_time", + description: "Get the current time in a given location", + schema: z.object({ + location: z.string().describe("The city name, e.g. San Francisco"), + }), + }); + + const modelWithTools = model.bindTools([weatherTool, currentTimeTool]); + + const callParallelToolsPrompt = + "What's the weather and current time in San Francisco?\n" + + "Ensure you ALWAYS call the 'get_current_weather' tool for weather and 'get_current_time' tool for time."; + + // Save the result of the parallel tool calls for the history test. + let parallelToolCallsMessage: AIMessage | undefined; + + /** + * Tests the basic functionality of invoking multiple tools in parallel. + * Verifies that the model can call both the weather and current time tools simultaneously. + */ + const invokeParallelTools = async () => { + const result: AIMessage = await modelWithTools.invoke( + callParallelToolsPrompt, + callOptions + ); + // Model should call at least two tools. Using greater than or equal since it might call the current time tool multiple times. + expect(result.tool_calls?.length).toBeGreaterThanOrEqual(2); + if (!result.tool_calls?.length) return; + + const weatherToolCalls = result.tool_calls.find( + (tc) => tc.name === weatherTool.name + ); + const currentTimeToolCalls = result.tool_calls.find( + (tc) => tc.name === currentTimeTool.name + ); + + expect(weatherToolCalls).toBeDefined(); + expect(currentTimeToolCalls).toBeDefined(); + parallelToolCallsMessage = result; + }; + + /** + * Tests the model's ability to stream responses while making parallel tool calls. + * Ensures that the streamed result contains calls to both the weather and current time tools. + */ + const streamParallelTools = async () => { + const stream = await modelWithTools.stream( + callParallelToolsPrompt, + callOptions + ); + let finalChunk: AIMessageChunk | undefined; + for await (const chunk of stream) { + finalChunk = !finalChunk ? chunk : concat(finalChunk, chunk); + } + + expect(finalChunk).toBeDefined(); + if (!finalChunk) return; + + // Model should call at least two tools. Do not penalize for calling more than two tools, as + // long as it calls both the weather and current time tools. + expect(finalChunk.tool_calls?.length).toBeGreaterThanOrEqual(2); + if (!finalChunk.tool_calls?.length) return; + + const weatherToolCalls = finalChunk.tool_calls.find( + (tc) => tc.name === weatherTool.name + ); + const currentTimeToolCalls = finalChunk.tool_calls.find( + (tc) => tc.name === currentTimeTool.name + ); + + expect(weatherToolCalls).toBeDefined(); + expect(currentTimeToolCalls).toBeDefined(); + }; + + /** + * Tests the model's ability to process a message history containing parallel tool calls. + * Verifies that the model can generate a response based on previous tool calls without making unnecessary additional tool calls. + */ + const invokeParallelToolCallResultsInHistory = async () => { + const defaultAIMessageWithParallelTools = new AIMessage({ + content: "", + tool_calls: [ + { + name: weatherTool.name, + id: "get_current_weather_id", + args: { location: "San Francisco" }, + }, + { + name: currentTimeTool.name, + id: "get_current_time_id", + args: { location: "San Francisco" }, + }, + ], + }); + if (!parallelToolCallsMessage) { + // Allow this variable to be assigned in the first test, or if only run histories + // is passed, assign it here since the first test will not run. + parallelToolCallsMessage = defaultAIMessageWithParallelTools; + } + // Find the tool calls for the weather and current time tools so we can re-use the IDs in the message history. + const parallelToolCallWeather = parallelToolCallsMessage.tool_calls?.find( + (tc) => tc.name === weatherTool.name + ); + const parallelToolCallCurrentTime = + parallelToolCallsMessage.tool_calls?.find( + (tc) => tc.name === currentTimeTool.name + ); + if (!parallelToolCallWeather?.id || !parallelToolCallCurrentTime?.id) { + throw new Error( + `IDs not found in one of both of parallel tool calls:\nWeather ID: ${parallelToolCallWeather?.id}\nCurrent Time ID: ${parallelToolCallCurrentTime?.id}` + ); + } + + const messageHistory = [ + new HumanMessage(callParallelToolsPrompt), + // The saved message from earlier when we called the model to generate the parallel tool calls. + parallelToolCallsMessage, + new ToolMessage({ + name: weatherTool.name, + tool_call_id: parallelToolCallWeather.id, + content: "It is currently 24 degrees with hail in San Francisco.", + }), + new ToolMessage({ + name: currentTimeTool.name, + tool_call_id: parallelToolCallCurrentTime.id, + content: "The current time in San Francisco is 12:02 PM.", + }), + ]; + + const result: AIMessage = await modelWithTools.invoke( + messageHistory, + callOptions + ); + // The model should NOT call a tool given this message history. + expect(result.tool_calls ?? []).toHaveLength(0); + + if (typeof result.content === "string") { + expect(result.content).not.toBe(""); + } else { + expect(result.content.length).toBeGreaterThan(0); + const textOrTextDeltaContent = result.content.find( + (c) => c.type === "text" || c.type === "text_delta" + ); + expect(textOrTextDeltaContent).toBeDefined(); + } + }; + + // Now we can invoke each of our tests synchronously, as the last test requires the result of the first test. + if (!onlyVerifyHistory) { + await invokeParallelTools(); + await streamParallelTools(); + } + await invokeParallelToolCallResultsInHistory(); + } + /** * Run all unit tests for the chat model. * Each test is wrapped in a try/catch block to prevent the entire test suite from failing. @@ -1451,6 +1658,13 @@ Extraction path: {extractionPath}`, console.error("testInvokeMoreComplexTools failed", e.message); } + try { + await this.testParallelToolCalling(); + } catch (e: any) { + allTestsPassed = false; + console.error("testParallelToolCalling failed", e.message); + } + return allTestsPassed; } }