Skip to content

Commit

Permalink
standard-tests[minor]: Add tests for parallel tool calls (#6258)
Browse files Browse the repository at this point in the history
* standard-tests[minor]: Add tests for parallel tool calls

* run tests

* cr

* chore: lint files

* cr

* refactored test

* cr

* add missing import
  • Loading branch information
bracesproul authored Jul 29, 2024
1 parent a240cd7 commit 023a581
Show file tree
Hide file tree
Showing 7 changed files with 263 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,24 @@ class ChatAnthropicStandardIntegrationTests extends ChatModelIntegrationTests<
Cls: ChatAnthropic,
chatModelHasToolCalling: true,
chatModelHasStructuredOutput: true,
supportsParallelToolCalls: true,
constructorArgs: {
model: "claude-3-haiku-20240307",
},
});
}

/**
 * Runs the shared parallel tool calling test against a stronger model.
 *
 * The default model configured for this suite (haiku) struggles with
 * parallel tool calling, so the constructor args are temporarily overridden
 * for the duration of this test only.
 */
async testParallelToolCalling() {
  // Override constructor args to use a better model for this test.
  // I found that haiku struggles with parallel tool calling.
  const constructorArgsCopy = { ...this.constructorArgs };
  this.constructorArgs = {
    ...this.constructorArgs,
    model: "claude-3-5-sonnet-20240620",
  };
  try {
    await super.testParallelToolCalling();
  } finally {
    // Restore the original args even when the test throws, so the model
    // override never leaks into other tests run on this instance.
    this.constructorArgs = constructorArgsCopy;
  }
}
}

const testClass = new ChatAnthropicStandardIntegrationTests();
Expand Down
7 changes: 7 additions & 0 deletions libs/langchain-aws/src/tests/chat_models.standard.int.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class ChatBedrockConverseStandardIntegrationTests extends ChatModelIntegrationTe
Cls: ChatBedrockConverse,
chatModelHasToolCalling: true,
chatModelHasStructuredOutput: true,
supportsParallelToolCalls: true,
constructorArgs: {
region,
model: "anthropic.claude-3-sonnet-20240229-v1:0",
Expand Down Expand Up @@ -51,6 +52,12 @@ class ChatBedrockConverseStandardIntegrationTests extends ChatModelIntegrationTe
"Not properly implemented."
);
}

/**
 * Runs only the message-history portion of the shared parallel tool calling
 * test, because this model struggles to actually emit parallel tool calls.
 */
async testParallelToolCalling() {
  // The second argument of the base test restricts it to verifying that
  // parallel tool calls are supported in the message history.
  const onlyVerifyHistory = true;
  await super.testParallelToolCalling(undefined, onlyVerifyHistory);
}
}

const testClass = new ChatBedrockConverseStandardIntegrationTests();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@ class ChatGoogleGenerativeAIStandardIntegrationTests extends ChatModelIntegratio
Cls: ChatGoogleGenerativeAI,
chatModelHasToolCalling: true,
chatModelHasStructuredOutput: true,
supportsParallelToolCalls: true,
constructorArgs: {
maxRetries: 1,
model: "gemini-1.5-pro",
},
});
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class ChatVertexAIStandardIntegrationTests extends ChatModelIntegrationTests<
Cls: ChatVertexAI,
chatModelHasToolCalling: true,
chatModelHasStructuredOutput: true,
supportsParallelToolCalls: true,
invokeResponseType: AIMessageChunk,
constructorArgs: {
model: "gemini-1.5-pro",
Expand All @@ -42,6 +43,12 @@ class ChatVertexAIStandardIntegrationTests extends ChatModelIntegrationTests<
"Google VertexAI only supports objects in schemas when the parameters are defined."
);
}

/**
 * Runs only the message-history portion of the shared parallel tool calling
 * test, because this model struggles to actually emit parallel tool calls.
 */
async testParallelToolCalling() {
  // The second argument of the base test restricts it to verifying that
  // parallel tool calls are supported in the message history.
  const onlyVerifyHistory = true;
  await super.testParallelToolCalling(undefined, onlyVerifyHistory);
}
}

const testClass = new ChatVertexAIStandardIntegrationTests();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class AzureChatOpenAIStandardIntegrationTests extends ChatModelIntegrationTests<
Cls: AzureChatOpenAI,
chatModelHasToolCalling: true,
chatModelHasStructuredOutput: true,
supportsParallelToolCalls: true,
constructorArgs: {
model: "gpt-3.5-turbo",
},
Expand Down Expand Up @@ -62,6 +63,12 @@ class AzureChatOpenAIStandardIntegrationTests extends ChatModelIntegrationTests<
"AzureChatOpenAI only supports objects in schemas when the parameters are defined."
);
}

/**
 * Runs only the message-history portion of the shared parallel tool calling
 * test, because this model struggles to actually emit parallel tool calls.
 */
async testParallelToolCalling() {
  // The second argument of the base test restricts it to verifying that
  // parallel tool calls are supported in the message history.
  const onlyVerifyHistory = true;
  await super.testParallelToolCalling(undefined, onlyVerifyHistory);
}
}

const testClass = new AzureChatOpenAIStandardIntegrationTests();
Expand Down
13 changes: 13 additions & 0 deletions libs/langchain-openai/src/tests/chat_models.standard.int.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class ChatOpenAIStandardIntegrationTests extends ChatModelIntegrationTests<
Cls: ChatOpenAI,
chatModelHasToolCalling: true,
chatModelHasStructuredOutput: true,
supportsParallelToolCalls: true,
constructorArgs: {
model: "gpt-3.5-turbo",
},
Expand All @@ -44,6 +45,18 @@ class ChatOpenAIStandardIntegrationTests extends ChatModelIntegrationTests<
"\nOpenAI only supports objects in schemas when the parameters are defined."
);
}

/**
 * Runs the shared parallel tool calling test against a stronger model.
 *
 * The default model configured for this suite (GPT 3.5) struggles with
 * parallel tool calling, so the constructor args are temporarily overridden
 * for the duration of this test only.
 */
async testParallelToolCalling() {
  // Override constructor args to use a better model for this test.
  // I found that GPT 3.5 struggles with parallel tool calling.
  const constructorArgsCopy = { ...this.constructorArgs };
  this.constructorArgs = {
    ...this.constructorArgs,
    model: "gpt-4o",
  };
  try {
    await super.testParallelToolCalling();
  } finally {
    // Restore the original args even when the test throws, so the model
    // override never leaks into other tests run on this instance.
    this.constructorArgs = constructorArgsCopy;
  }
}
}

const testClass = new ChatOpenAIStandardIntegrationTests();
Expand Down
214 changes: 214 additions & 0 deletions libs/langchain-standard-tests/src/integration_tests/chat_models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,11 @@ interface ChatModelIntegrationTestsFields<
* @default "abc123"
*/
functionId?: string;
/**
* Whether or not the model supports parallel tool calling.
* @default false
*/
supportsParallelToolCalls?: boolean;
}

export abstract class ChatModelIntegrationTests<
Expand All @@ -82,6 +87,8 @@ export abstract class ChatModelIntegrationTests<

invokeResponseType: typeof AIMessage | typeof AIMessageChunk = AIMessage;

supportsParallelToolCalls = false;

constructor(
fields: ChatModelIntegrationTestsFields<
CallOptions,
Expand All @@ -93,6 +100,8 @@ export abstract class ChatModelIntegrationTests<
this.functionId = fields.functionId ?? this.functionId;
this.invokeResponseType =
fields.invokeResponseType ?? this.invokeResponseType;
this.supportsParallelToolCalls =
fields.supportsParallelToolCalls ?? this.supportsParallelToolCalls;
}

/**
Expand Down Expand Up @@ -1313,6 +1322,204 @@ Extraction path: {extractionPath}`,
expect(typeof result.apiDetails === "object").toBeTruthy();
}

/**
 * Tests the chat model's ability to handle parallel tool calls in various scenarios.
 * This comprehensive test covers three aspects of parallel tool calling:
 * 1. Invoking multiple tools simultaneously
 * 2. Streaming responses with parallel tool calls
 * 3. Processing message histories containing parallel tool calls
 *
 * The test uses a weather tool and a current time tool to simulate complex, multi-tool scenarios.
 * It ensures that the model can correctly process and respond to prompts requiring multiple tool calls,
 * both in streaming and non-streaming contexts, and can handle message histories with parallel tool calls.
 *
 * The test is skipped (with a log line) when either `chatModelHasToolCalling` or
 * `supportsParallelToolCalls` is false.
 *
 * @param {InstanceType<this["Cls"]>["ParsedCallOptions"] | undefined} callOptions Optional call options to pass to the model.
 * @param {boolean} onlyVerifyHistory If true, only verifies the message history test.
 * @throws {Error} If the model does not implement `bindTools`, or if the AI message used
 * for the history test lacks an ID on the weather or current time tool call.
 */
async testParallelToolCalling(
  callOptions?: InstanceType<this["Cls"]>["ParsedCallOptions"],
  onlyVerifyHistory = false
) {
  // Skip the test if the model doesn't support tool calling
  if (!this.chatModelHasToolCalling) {
    console.log("Test requires tool calling. Skipping...");
    return;
  }
  // Skip the test if the model doesn't support parallel tool calls
  if (!this.supportsParallelToolCalls) {
    console.log("Test requires parallel tool calls. Skipping...");
    return;
  }
  const model = new this.Cls(this.constructorArgs);
  if (!model.bindTools) {
    throw new Error(
      "bindTools undefined. Cannot test OpenAI formatted tool calls."
    );
  }

  const weatherTool = tool((_) => "no-op", {
    name: "get_current_weather",
    description: "Get the current weather in a given location",
    schema: z.object({
      location: z.string().describe("The city name, e.g. San Francisco"),
    }),
  });
  const currentTimeTool = tool((_) => "no-op", {
    name: "get_current_time",
    description: "Get the current time in a given location",
    schema: z.object({
      location: z.string().describe("The city name, e.g. San Francisco"),
    }),
  });

  const modelWithTools = model.bindTools([weatherTool, currentTimeTool]);

  const callParallelToolsPrompt =
    "What's the weather and current time in San Francisco?\n" +
    "Ensure you ALWAYS call the 'get_current_weather' tool for weather and 'get_current_time' tool for time.";

  // Save the result of the parallel tool calls for the history test.
  let parallelToolCallsMessage: AIMessage | undefined;

  /**
   * Shared assertions for the invoke and stream tests: the message must contain
   * at least two tool calls, including one for each of the two bound tools.
   * Extra tool calls are not penalized.
   */
  const expectBothToolsCalled = (message: AIMessage) => {
    expect(message.tool_calls?.length).toBeGreaterThanOrEqual(2);
    if (!message.tool_calls?.length) return;

    const weatherToolCall = message.tool_calls.find(
      (tc) => tc.name === weatherTool.name
    );
    const currentTimeToolCall = message.tool_calls.find(
      (tc) => tc.name === currentTimeTool.name
    );

    expect(weatherToolCall).toBeDefined();
    expect(currentTimeToolCall).toBeDefined();
  };

  /**
   * Returns the canned tool output used in the message history for a given tool name.
   */
  const toolResultContent = (toolName: string): string => {
    if (toolName === weatherTool.name) {
      return "It is currently 24 degrees with hail in San Francisco.";
    }
    if (toolName === currentTimeTool.name) {
      return "The current time in San Francisco is 12:02 PM.";
    }
    // Mirrors what the stub tools above return when executed.
    return "no-op";
  };

  /**
   * Tests the basic functionality of invoking multiple tools in parallel.
   * Verifies that the model can call both the weather and current time tools simultaneously.
   */
  const invokeParallelTools = async () => {
    const result: AIMessage = await modelWithTools.invoke(
      callParallelToolsPrompt,
      callOptions
    );
    expectBothToolsCalled(result);
    // Only reached when the assertions above passed.
    parallelToolCallsMessage = result;
  };

  /**
   * Tests the model's ability to stream responses while making parallel tool calls.
   * Ensures that the streamed result contains calls to both the weather and current time tools.
   */
  const streamParallelTools = async () => {
    const stream = await modelWithTools.stream(
      callParallelToolsPrompt,
      callOptions
    );
    let finalChunk: AIMessageChunk | undefined;
    for await (const chunk of stream) {
      finalChunk = !finalChunk ? chunk : concat(finalChunk, chunk);
    }

    expect(finalChunk).toBeDefined();
    if (!finalChunk) return;

    expectBothToolsCalled(finalChunk);
  };

  /**
   * Tests the model's ability to process a message history containing parallel tool calls.
   * Verifies that the model can generate a response based on previous tool calls without making unnecessary additional tool calls.
   */
  const invokeParallelToolCallResultsInHistory = async () => {
    // Re-use the message produced by the invoke test when it ran; otherwise
    // (history-only mode) fall back to a hand-written message.
    const aiMessageWithParallelTools =
      parallelToolCallsMessage ??
      new AIMessage({
        content: "",
        tool_calls: [
          {
            name: weatherTool.name,
            id: "get_current_weather_id",
            args: { location: "San Francisco" },
          },
          {
            name: currentTimeTool.name,
            id: "get_current_time_id",
            args: { location: "San Francisco" },
          },
        ],
      });

    const toolCalls = aiMessageWithParallelTools.tool_calls ?? [];
    // Both tools must be present with IDs so their results can be linked in the history.
    const parallelToolCallWeather = toolCalls.find(
      (tc) => tc.name === weatherTool.name
    );
    const parallelToolCallCurrentTime = toolCalls.find(
      (tc) => tc.name === currentTimeTool.name
    );
    if (!parallelToolCallWeather?.id || !parallelToolCallCurrentTime?.id) {
      throw new Error(
        `IDs not found in one or both of parallel tool calls:\nWeather ID: ${parallelToolCallWeather?.id}\nCurrent Time ID: ${parallelToolCallCurrentTime?.id}`
      );
    }

    // The invoke test allows the model to make MORE than two tool calls. Answer
    // every tool call that has an ID, not just the first weather/time pair, so the
    // history never contains a tool call without a matching tool result.
    const toolResultMessages: ToolMessage[] = [];
    for (const toolCall of toolCalls) {
      if (!toolCall.id) continue;
      toolResultMessages.push(
        new ToolMessage({
          name: toolCall.name,
          tool_call_id: toolCall.id,
          content: toolResultContent(toolCall.name),
        })
      );
    }

    const messageHistory = [
      new HumanMessage(callParallelToolsPrompt),
      // The AI message containing the parallel tool calls being answered.
      aiMessageWithParallelTools,
      ...toolResultMessages,
    ];

    const result: AIMessage = await modelWithTools.invoke(
      messageHistory,
      callOptions
    );
    // The model should NOT call a tool given this message history.
    expect(result.tool_calls ?? []).toHaveLength(0);

    if (typeof result.content === "string") {
      expect(result.content).not.toBe("");
    } else {
      expect(result.content.length).toBeGreaterThan(0);
      const textOrTextDeltaContent = result.content.find(
        (c) => c.type === "text" || c.type === "text_delta"
      );
      expect(textOrTextDeltaContent).toBeDefined();
    }
  };

  // Run the sub-tests sequentially: the history test re-uses the AI message
  // saved by the invoke test when that test runs.
  if (!onlyVerifyHistory) {
    await invokeParallelTools();
    await streamParallelTools();
  }
  await invokeParallelToolCallResultsInHistory();
}

/**
* Run all unit tests for the chat model.
* Each test is wrapped in a try/catch block to prevent the entire test suite from failing.
Expand Down Expand Up @@ -1451,6 +1658,13 @@ Extraction path: {extractionPath}`,
console.error("testInvokeMoreComplexTools failed", e.message);
}

try {
await this.testParallelToolCalling();
} catch (e: any) {
allTestsPassed = false;
console.error("testParallelToolCalling failed", e.message);
}

return allTestsPassed;
}
}

0 comments on commit 023a581

Please sign in to comment.