From e6616a546a6c1f841ad5d3aebce1246e6f827cc6 Mon Sep 17 00:00:00 2001 From: Brace Sproul Date: Mon, 8 Jul 2024 15:13:27 -0700 Subject: [PATCH] standard-tests[minor]: Improve prompting to force model to call tool (#6004) --- .../src/integration_tests/chat_models.ts | 54 +++++++++++++------ 1 file changed, 37 insertions(+), 17 deletions(-) diff --git a/libs/langchain-standard-tests/src/integration_tests/chat_models.ts b/libs/langchain-standard-tests/src/integration_tests/chat_models.ts index f0e5bd7ca5cb..ce8ff4c3a96b 100644 --- a/libs/langchain-standard-tests/src/integration_tests/chat_models.ts +++ b/libs/langchain-standard-tests/src/integration_tests/chat_models.ts @@ -11,6 +11,7 @@ import { import { z } from "zod"; import { StructuredTool } from "@langchain/core/tools"; import { zodToJsonSchema } from "zod-to-json-schema"; +import { ChatPromptTemplate } from "@langchain/core/prompts"; import { BaseChatModelsTests, BaseChatModelsTestsFields, @@ -37,6 +38,14 @@ class AdderTool extends StructuredTool { } } +const MATH_ADDITION_PROMPT = /* #__PURE__ */ ChatPromptTemplate.fromMessages([ + [ + "system", + "You are bad at math and must ALWAYS call the {toolName} function.", + ], + ["human", "What is the sum of 1836281973 and 19973286?"], +]); + interface ChatModelIntegrationTestsFields< CallOptions extends BaseChatModelCallOptions = BaseChatModelCallOptions, OutputMessageType extends BaseMessageChunk = BaseMessageChunk, @@ -228,11 +237,11 @@ export abstract class ChatModelIntegrationTests< new ToolMessage(functionResult, functionId, functionName), ]; - const resultStringContent = await modelWithTools.invoke( + const result = await modelWithTools.invoke( messagesStringContent, callOptions ); - expect(resultStringContent).toBeInstanceOf(this.invokeResponseType); + expect(result).toBeInstanceOf(this.invokeResponseType); } /** @@ -334,11 +343,11 @@ export abstract class ChatModelIntegrationTests< new HumanMessage("What is 3 + 4"), ]; - const resultStringContent = await modelWithTools.invoke( + const result = await modelWithTools.invoke( messagesStringContent, callOptions ); - expect(resultStringContent).toBeInstanceOf(this.invokeResponseType); + expect(result).toBeInstanceOf(this.invokeResponseType); } async testWithStructuredOutput() { @@ -353,13 +362,17 @@ export abstract class ChatModelIntegrationTests< "withStructuredOutput undefined. Cannot test tool message histories." ); } - const modelWithTools = model.withStructuredOutput(adderSchema); + const modelWithTools = model.withStructuredOutput(adderSchema, { + name: "math_addition", + }); - const resultStringContent = await modelWithTools.invoke("What is 1 + 2"); - expect(resultStringContent.a).toBeDefined(); - expect([1, 2].includes(resultStringContent.a)).toBeTruthy(); - expect(resultStringContent.b).toBeDefined(); - expect([1, 2].includes(resultStringContent.b)).toBeTruthy(); + const result = await MATH_ADDITION_PROMPT.pipe(modelWithTools).invoke({ + toolName: "math_addition", + }); + expect(result.a).toBeDefined(); + expect(typeof result.a).toBe("number"); + expect(result.b).toBeDefined(); + expect(typeof result.b).toBe("number"); } async testWithStructuredOutputIncludeRaw() { @@ -376,14 +389,17 @@ export abstract class ChatModelIntegrationTests< } const modelWithTools = model.withStructuredOutput(adderSchema, { includeRaw: true, + name: "math_addition", }); - const resultStringContent = await modelWithTools.invoke("What is 1 + 2"); - expect(resultStringContent.raw).toBeInstanceOf(this.invokeResponseType); - expect(resultStringContent.parsed.a).toBeDefined(); - expect([1, 2].includes(resultStringContent.parsed.a)).toBeTruthy(); - expect(resultStringContent.parsed.b).toBeDefined(); - expect([1, 2].includes(resultStringContent.parsed.b)).toBeTruthy(); + const result = await MATH_ADDITION_PROMPT.pipe(modelWithTools).invoke({ + toolName: "math_addition", + }); + expect(result.raw).toBeInstanceOf(this.invokeResponseType); + expect(result.parsed.a).toBeDefined(); + expect(typeof result.parsed.a).toBe("number"); + expect(result.parsed.b).toBeDefined(); + expect(typeof result.parsed.b).toBe("number"); } async testBindToolsWithOpenAIFormattedTools() { @@ -409,7 +425,11 @@ export abstract class ChatModelIntegrationTests< }, ]); - const result: AIMessage = await modelWithTools.invoke("What is 1 + 2"); + const result: AIMessage = await MATH_ADDITION_PROMPT.pipe( + modelWithTools + ).invoke({ + toolName: "math_addition", + }); expect(result.tool_calls).toHaveLength(1); if (!result.tool_calls) { throw new Error("result.tool_calls is undefined");