From 1296334a33f84d55dd9a3c59a07bc6f4b87b5b18 Mon Sep 17 00:00:00 2001 From: Brace Sproul Date: Tue, 6 Aug 2024 15:47:34 -0700 Subject: [PATCH] openai[minor],core[minor]: Add support for passing strict in openai tools (#6418) * openai[minor],core[minor]: Add support for passing strict in openai tools * add integration test * chore: lint files * Cr * cr * fix * fixed all tests * docs * fix build errors * fix more type errors * cr * cr * chore: lint files --- .../docs/integrations/chat/openai.ipynb | 78 +++++- langchain-core/src/language_models/base.ts | 9 + langchain-core/src/utils/function_calling.ts | 62 ++++- langchain/src/agents/openai_tools/index.ts | 4 +- .../src/agents/openai_tools/output_parser.ts | 2 +- libs/langchain-groq/src/chat_models.ts | 2 +- libs/langchain-ollama/src/chat_models.ts | 6 +- libs/langchain-openai/package.json | 2 +- libs/langchain-openai/src/chat_models.ts | 158 ++++++++--- .../azure/chat_models.standard.int.test.ts | 19 +- .../src/tests/chat_models.test.ts | 259 ++++++++++++++++++ .../chat_models_structured_output.int.test.ts | 35 +++ libs/langchain-openai/src/types.ts | 6 + yarn.lock | 24 +- 14 files changed, 617 insertions(+), 49 deletions(-) create mode 100644 libs/langchain-openai/src/tests/chat_models.test.ts diff --git a/docs/core_docs/docs/integrations/chat/openai.ipynb b/docs/core_docs/docs/integrations/chat/openai.ipynb index 096cbb3afd36..3b155a178d77 100644 --- a/docs/core_docs/docs/integrations/chat/openai.ipynb +++ b/docs/core_docs/docs/integrations/chat/openai.ipynb @@ -411,7 +411,7 @@ }, { "cell_type": "markdown", - "id": "3a5bb5ca-c3ae-4a58-be67-2cd18574b9a3", + "id": "bc5ecebd", "metadata": {}, "source": [ "## Tool calling\n", @@ -420,8 +420,82 @@ "\n", "- [How to: disable parallel tool calling](/docs/how_to/tool_calling_parallel/)\n", "- [How to: force a tool call](/docs/how_to/tool_choice/)\n", - "- [How to: bind model-specific tool formats to a model](/docs/how_to/tool_calling#binding-model-specific-formats-advanced).\n", + "- [How to: bind model-specific tool formats to a model](/docs/how_to/tool_calling#binding-model-specific-formats-advanced)." + ] + }, + { + "cell_type": "markdown", + "id": "3392390e", + "metadata": {}, + "source": [ + "### ``strict: true``\n", + "\n", + "```{=mdx}\n", + "\n", + ":::info Requires ``@langchain/openai >= 0.2.6``\n", + "\n", + "As of Aug 6, 2024, OpenAI supports a `strict` argument when calling tools that will enforce that the tool argument schema is respected by the model. See more here: https://platform.openai.com/docs/guides/function-calling\n", + "\n", + "**Note**: If ``strict: true`` the tool definition will also be validated, and a subset of JSON schema are accepted. Crucially, schema cannot have optional args (those with default values). Read the full docs on what types of schema are supported here: https://platform.openai.com/docs/guides/structured-outputs/supported-schemas. \n", + ":::\n", + "\n", + "\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "90f0d465", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[\n", + " {\n", + " name: 'get_current_weather',\n", + " args: { location: 'Hanoi' },\n", + " type: 'tool_call',\n", + " id: 'call_aB85ybkLCoccpzqHquuJGH3d'\n", + " }\n", + "]\n" + ] + } + ], + "source": [ + "import { ChatOpenAI } from \"@langchain/openai\";\n", + "import { tool } from \"@langchain/core/tools\";\n", + "import { z } from \"zod\";\n", + "\n", + "const weatherTool = tool((_) => \"no-op\", {\n", + " name: \"get_current_weather\",\n", + " description: \"Get the current weather\",\n", + " schema: z.object({\n", + " location: z.string(),\n", + " }),\n", + "})\n", + "\n", + "const llmWithStrictTrue = new ChatOpenAI({\n", + " model: \"gpt-4o\",\n", + "}).bindTools([weatherTool], {\n", + " strict: true,\n", + " tool_choice: weatherTool.name,\n", + "});\n", + "\n", + "// Although the question is not about the weather, it will call the tool with the correct arguments\n", + "// because we passed `tool_choice` and `strict: true`.\n", + "const strictTrueResult = await llmWithStrictTrue.invoke(\"What is 127862 times 12898 divided by 2?\");\n", "\n", + "console.dir(strictTrueResult.tool_calls, { depth: null });" + ] + }, + { + "cell_type": "markdown", + "id": "3a5bb5ca-c3ae-4a58-be67-2cd18574b9a3", + "metadata": {}, + "source": [ "## API reference\n", "\n", "For detailed documentation of all ChatOpenAI features and configurations head to the API reference: https://api.js.langchain.com/classes/langchain_openai.ChatOpenAI.html" diff --git a/langchain-core/src/language_models/base.ts b/langchain-core/src/language_models/base.ts index 0e8af1bc32bf..cea8ca2f9ae3 100644 --- a/langchain-core/src/language_models/base.ts +++ b/langchain-core/src/language_models/base.ts @@ -233,6 +233,15 @@ export interface FunctionDefinition { * how to call the function. */ description?: string; + + /** + * Whether to enable strict schema adherence when generating the function call. If + * set to true, the model will follow the exact schema defined in the `parameters` + * field. Only a subset of JSON Schema is supported when `strict` is `true`. Learn + * more about Structured Outputs in the + * [function calling guide](https://platform.openai.com/docs/guides/function-calling). + */ + strict?: boolean; } export interface ToolDefinition { diff --git a/langchain-core/src/utils/function_calling.ts b/langchain-core/src/utils/function_calling.ts index 3871ffc4453d..38a976f75d7b 100644 --- a/langchain-core/src/utils/function_calling.ts +++ b/langchain-core/src/utils/function_calling.ts @@ -13,12 +13,26 @@ import { Runnable, RunnableToolLike } from "../runnables/base.js"; * @returns {FunctionDefinition} The inputted tool in OpenAI function format. */ export function convertToOpenAIFunction( - tool: StructuredToolInterface | RunnableToolLike + tool: StructuredToolInterface | RunnableToolLike, + fields?: + | { + /** + * If `true`, model output is guaranteed to exactly match the JSON Schema + * provided in the function definition. + */ + strict?: boolean; + } + | number ): FunctionDefinition { + // @TODO 0.3.0 Remove the `number` typing + const fieldsCopy = typeof fields === "number" ? undefined : fields; + return { name: tool.name, description: tool.description, parameters: zodToJsonSchema(tool.schema), + // Do not include the `strict` field if it is `undefined`. + ...(fieldsCopy?.strict !== undefined ? { strict: fieldsCopy.strict } : {}), }; } @@ -34,15 +48,35 @@ export function convertToOpenAIFunction( */ export function convertToOpenAITool( // eslint-disable-next-line @typescript-eslint/no-explicit-any - tool: StructuredToolInterface | Record | RunnableToolLike + tool: StructuredToolInterface | Record | RunnableToolLike, + fields?: + | { + /** + * If `true`, model output is guaranteed to exactly match the JSON Schema + * provided in the function definition. + */ + strict?: boolean; + } + | number ): ToolDefinition { - if (isStructuredTool(tool) || isRunnableToolLike(tool)) { - return { + // @TODO 0.3.0 Remove the `number` typing + const fieldsCopy = typeof fields === "number" ? undefined : fields; + + let toolDef: ToolDefinition | undefined; + if (isLangChainTool(tool)) { + toolDef = { type: "function", function: convertToOpenAIFunction(tool), }; + } else { + toolDef = tool as ToolDefinition; + } + + if (fieldsCopy?.strict !== undefined) { + toolDef.function.strict = fieldsCopy.strict; } - return tool as ToolDefinition; + + return toolDef; } /** @@ -76,3 +110,21 @@ export function isRunnableToolLike(tool?: unknown): tool is RunnableToolLike { tool.constructor.lc_name() === "RunnableToolLike" ); } + +/** + * Whether or not the tool is one of StructuredTool, RunnableTool or StructuredToolParams. + * It returns `is StructuredToolParams` since that is the most minimal interface of the three, + * while still containing the necessary properties to be passed to a LLM for tool calling. + * + * @param {unknown | undefined} tool The tool to check if it is a LangChain tool. + * @returns {tool is StructuredToolParams} Whether the inputted tool is a LangChain tool. + */ +export function isLangChainTool( + tool?: unknown +): tool is StructuredToolInterface { + return ( + isRunnableToolLike(tool) || + // eslint-disable-next-line @typescript-eslint/no-explicit-any + isStructuredTool(tool as any) + ); +} diff --git a/langchain/src/agents/openai_tools/index.ts b/langchain/src/agents/openai_tools/index.ts index ae071993224e..fe13da61f844 100644 --- a/langchain/src/agents/openai_tools/index.ts +++ b/langchain/src/agents/openai_tools/index.ts @@ -116,7 +116,9 @@ export async function createOpenAIToolsAgent({ ].join("\n") ); } - const modelWithTools = llm.bind({ tools: tools.map(convertToOpenAITool) }); + const modelWithTools = llm.bind({ + tools: tools.map((tool) => convertToOpenAITool(tool)), + }); const agent = AgentRunnableSequence.fromRunnables( [ RunnablePassthrough.assign({ diff --git a/langchain/src/agents/openai_tools/output_parser.ts b/langchain/src/agents/openai_tools/output_parser.ts index dbaa15d8ad27..c18d6a1ff2ab 100644 --- a/langchain/src/agents/openai_tools/output_parser.ts +++ b/langchain/src/agents/openai_tools/output_parser.ts @@ -30,7 +30,7 @@ export type { ToolsAgentAction, ToolsAgentStep }; * new ChatOpenAI({ * modelName: "gpt-3.5-turbo-1106", * temperature: 0, - * }).bind({ tools: tools.map(convertToOpenAITool) }), + * }).bind({ tools: tools.map((tool) => convertToOpenAITool(tool)) }), * new OpenAIToolsAgentOutputParser(), * ]).withConfig({ runName: "OpenAIToolsAgent" }); * diff --git a/libs/langchain-groq/src/chat_models.ts b/libs/langchain-groq/src/chat_models.ts index b2291dc552ce..413e8803fdff 100644 --- a/libs/langchain-groq/src/chat_models.ts +++ b/libs/langchain-groq/src/chat_models.ts @@ -437,7 +437,7 @@ export class ChatGroq extends BaseChatModel< kwargs?: Partial ): Runnable { return this.bind({ - tools: tools.map(convertToOpenAITool), + tools: tools.map((tool) => convertToOpenAITool(tool)), ...kwargs, }); } diff --git a/libs/langchain-ollama/src/chat_models.ts b/libs/langchain-ollama/src/chat_models.ts index 9f70a9e0e0b0..15c7ca31897e 100644 --- a/libs/langchain-ollama/src/chat_models.ts +++ b/libs/langchain-ollama/src/chat_models.ts @@ -298,7 +298,7 @@ export class ChatOllama kwargs?: Partial ): Runnable { return this.bind({ - tools: tools.map(convertToOpenAITool), + tools: tools.map((tool) => convertToOpenAITool(tool)), ...kwargs, }); } @@ -359,7 +359,9 @@ export class ChatOllama stop: options?.stop, }, tools: options?.tools?.length - ? (options.tools.map(convertToOpenAITool) as OllamaTool[]) + ? (options.tools.map((tool) => + convertToOpenAITool(tool) + ) as OllamaTool[]) : undefined, }; } diff --git a/libs/langchain-openai/package.json b/libs/langchain-openai/package.json index 3115ef248c48..7a565308410c 100644 --- a/libs/langchain-openai/package.json +++ b/libs/langchain-openai/package.json @@ -37,7 +37,7 @@ "dependencies": { "@langchain/core": ">=0.2.16 <0.3.0", "js-tiktoken": "^1.0.12", - "openai": "^4.49.1", + "openai": "^4.55.0", "zod": "^3.22.4", "zod-to-json-schema": "^3.22.3" }, diff --git a/libs/langchain-openai/src/chat_models.ts b/libs/langchain-openai/src/chat_models.ts index db86e6e91940..52d2eeb52bf4 100644 --- a/libs/langchain-openai/src/chat_models.ts +++ b/libs/langchain-openai/src/chat_models.ts @@ -27,12 +27,13 @@ import { LangSmithParams, type BaseChatModelParams, } from "@langchain/core/language_models/chat_models"; -import type { - BaseFunctionCallOptions, - BaseLanguageModelInput, - FunctionDefinition, - StructuredOutputMethodOptions, - StructuredOutputMethodParams, +import { + isOpenAITool, + type BaseFunctionCallOptions, + type BaseLanguageModelInput, + type FunctionDefinition, + type StructuredOutputMethodOptions, + type StructuredOutputMethodParams, } from "@langchain/core/language_models/base"; import { NewTokenIndices } from "@langchain/core/callbacks/base"; import { convertToOpenAITool } from "@langchain/core/utils/function_calling"; @@ -274,10 +275,59 @@ function convertMessagesToOpenAIParams(messages: BaseMessage[]) { }); } +type ChatOpenAIToolType = + | StructuredToolInterface + | OpenAIClient.ChatCompletionTool + | RunnableToolLike + // eslint-disable-next-line @typescript-eslint/no-explicit-any + | Record; + +function _convertChatOpenAIToolTypeToOpenAITool( + tool: ChatOpenAIToolType, + fields?: { + strict?: boolean; + } +): OpenAIClient.ChatCompletionTool { + if (isOpenAITool(tool)) { + if (fields?.strict !== undefined) { + return { + ...tool, + function: { + ...tool.function, + strict: fields.strict, + }, + }; + } + + return tool; + } + return convertToOpenAITool(tool, fields); +} + +export interface ChatOpenAIStructuredOutputMethodOptions< + IncludeRaw extends boolean +> extends StructuredOutputMethodOptions { + /** + * strict: If `true` and `method` = "function_calling", model output is + * guaranteed to exactly match the schema. If `true`, the input schema + * will also be validated according to + * https://platform.openai.com/docs/guides/structured-outputs/supported-schemas. + * If `false`, input schema will not be validated and model output will not + * be validated. + * If `undefined`, `strict` argument will not be passed to the model. + * + * @version 0.2.6 + * @note Planned breaking change in version `0.3.0`: + * `strict` will default to `true` when `method` is + * "function_calling" as of version `0.3.0`. + */ + strict?: boolean; +} + export interface ChatOpenAICallOptions extends OpenAICallOptions, BaseFunctionCallOptions { - tools?: StructuredToolInterface[] | OpenAIClient.ChatCompletionTool[]; + tools?: ChatOpenAIToolType[]; tool_choice?: OpenAIToolChoice; promptIndex?: number; response_format?: { type: "json_object" }; @@ -299,6 +349,27 @@ export interface ChatOpenAICallOptions * call multiple tools in one response. */ parallel_tool_calls?: boolean; + /** + * If `true`, model output is guaranteed to exactly match the JSON Schema + * provided in the tool definition. If `true`, the input schema will also be + * validated according to + * https://platform.openai.com/docs/guides/structured-outputs/supported-schemas. + * + * If `false`, input schema will not be validated and model output will not + * be validated. + * + * If `undefined`, `strict` argument will not be passed to the model. + * + * @version 0.2.6 + */ + strict?: boolean; +} + +export interface ChatOpenAIFields + extends Partial, + Partial, + BaseChatModelParams { + configuration?: ClientOptions & LegacyOpenAIInput; } /** @@ -441,12 +512,14 @@ export class ChatOpenAI< protected clientConfig: ClientOptions; + /** + * Whether the model supports the `strict` argument when passing in tools. + * If `undefined` the `strict` argument will not be passed to OpenAI. + */ + supportsStrictToolCalling?: boolean; + constructor( - fields?: Partial & - Partial & - BaseChatModelParams & { - configuration?: ClientOptions & LegacyOpenAIInput; - }, + fields?: ChatOpenAIFields, /** @deprecated */ configuration?: ClientOptions & LegacyOpenAIInput ) { @@ -541,6 +614,12 @@ export class ChatOpenAI< ...configuration, ...fields?.configuration, }; + + // If `supportsStrictToolCalling` is explicitly set, use that value. + // Else leave undefined so it's not passed to OpenAI. + if (fields?.supportsStrictToolCalling !== undefined) { + this.supportsStrictToolCalling = fields.supportsStrictToolCalling; + } } getLsParams(options: this["ParsedCallOptions"]): LangSmithParams { @@ -556,15 +635,19 @@ export class ChatOpenAI< } override bindTools( - tools: ( - | Record - | StructuredToolInterface - | RunnableToolLike - )[], + tools: ChatOpenAIToolType[], kwargs?: Partial ): Runnable { + let strict: boolean | undefined; + if (kwargs?.strict !== undefined) { + strict = kwargs.strict; + } else if (this.supportsStrictToolCalling !== undefined) { + strict = this.supportsStrictToolCalling; + } return this.bind({ - tools: tools.map(convertToOpenAITool), + tools: tools.map((tool) => + _convertChatOpenAIToolTypeToOpenAITool(tool, { strict }) + ), ...kwargs, } as Partial); } @@ -578,16 +661,13 @@ export class ChatOpenAI< streaming?: boolean; } ): Omit { - function isStructuredToolArray( - tools?: unknown[] - ): tools is StructuredToolInterface[] { - return ( - tools !== undefined && - tools.every((tool) => - Array.isArray((tool as StructuredToolInterface).lc_namespace) - ) - ); + let strict: boolean | undefined; + if (options?.strict !== undefined) { + strict = options.strict; + } else if (this.supportsStrictToolCalling !== undefined) { + strict = this.supportsStrictToolCalling; } + let streamOptionsConfig = {}; if (options?.stream_options !== undefined) { streamOptionsConfig = { stream_options: options.stream_options }; @@ -614,9 +694,11 @@ export class ChatOpenAI< stream: this.streaming, functions: options?.functions, function_call: options?.function_call, - tools: isStructuredToolArray(options?.tools) - ? options?.tools.map(convertToOpenAITool) - : options?.tools, + tools: options?.tools?.length + ? options.tools.map((tool) => + _convertChatOpenAIToolTypeToOpenAITool(tool, { strict }) + ) + : undefined, tool_choice: formatToOpenAIToolChoice(options?.tool_choice), response_format: options?.response_format, seed: options?.seed, @@ -1098,7 +1180,7 @@ export class ChatOpenAI< | z.ZodType // eslint-disable-next-line @typescript-eslint/no-explicit-any | Record, - config?: StructuredOutputMethodOptions + config?: ChatOpenAIStructuredOutputMethodOptions ): Runnable; withStructuredOutput< @@ -1110,7 +1192,7 @@ export class ChatOpenAI< | z.ZodType // eslint-disable-next-line @typescript-eslint/no-explicit-any | Record, - config?: StructuredOutputMethodOptions + config?: ChatOpenAIStructuredOutputMethodOptions ): Runnable; withStructuredOutput< @@ -1122,7 +1204,7 @@ export class ChatOpenAI< | z.ZodType // eslint-disable-next-line @typescript-eslint/no-explicit-any | Record, - config?: StructuredOutputMethodOptions + config?: ChatOpenAIStructuredOutputMethodOptions ): | Runnable | Runnable< @@ -1148,6 +1230,12 @@ export class ChatOpenAI< let llm: Runnable; let outputParser: BaseLLMOutputParser; + if (config?.strict !== undefined && method === "jsonMode") { + throw new Error( + "Argument `strict` is only supported for `method` = 'function_calling'" + ); + } + if (method === "jsonMode") { llm = this.bind({ response_format: { type: "json_object" }, @@ -1179,6 +1267,8 @@ export class ChatOpenAI< name: functionName, }, }, + // Do not pass `strict` argument to OpenAI if `config.strict` is undefined + ...(config?.strict !== undefined ? { strict: config.strict } : {}), } as Partial); outputParser = new JsonOutputKeyToolsParser({ returnSingle: true, @@ -1215,6 +1305,8 @@ export class ChatOpenAI< name: functionName, }, }, + // Do not pass `strict` argument to OpenAI if `config.strict` is undefined + ...(config?.strict !== undefined ? { strict: config.strict } : {}), } as Partial); outputParser = new JsonOutputKeyToolsParser({ returnSingle: true, diff --git a/libs/langchain-openai/src/tests/azure/chat_models.standard.int.test.ts b/libs/langchain-openai/src/tests/azure/chat_models.standard.int.test.ts index 8146f04d0f88..64052685d6c2 100644 --- a/libs/langchain-openai/src/tests/azure/chat_models.standard.int.test.ts +++ b/libs/langchain-openai/src/tests/azure/chat_models.standard.int.test.ts @@ -1,17 +1,25 @@ /* eslint-disable no-process-env */ -import { test, expect } from "@jest/globals"; +import { test, expect, beforeAll, afterAll } from "@jest/globals"; import { ChatModelIntegrationTests } from "@langchain/standard-tests"; import { AIMessageChunk } from "@langchain/core/messages"; import { AzureChatOpenAI } from "../../azure/chat_models.js"; import { ChatOpenAICallOptions } from "../../chat_models.js"; +let openAIAPIKey: string | undefined; + beforeAll(() => { + if (process.env.OPENAI_API_KEY) { + openAIAPIKey = process.env.OPENAI_API_KEY; + process.env.OPENAI_API_KEY = ""; + } + if (!process.env.AZURE_OPENAI_API_KEY) { process.env.AZURE_OPENAI_API_KEY = process.env.TEST_AZURE_OPENAI_API_KEY; } if (!process.env.AZURE_OPENAI_API_DEPLOYMENT_NAME) { process.env.AZURE_OPENAI_API_DEPLOYMENT_NAME = - process.env.TEST_AZURE_OPENAI_API_DEPLOYMENT_NAME; + process.env.TEST_AZURE_OPENAI_API_DEPLOYMENT_NAME ?? + process.env.AZURE_OPENAI_CHAT_DEPLOYMENT_NAME; } if (!process.env.AZURE_OPENAI_BASE_PATH) { process.env.AZURE_OPENAI_BASE_PATH = @@ -23,6 +31,12 @@ beforeAll(() => { } }); +afterAll(() => { + if (openAIAPIKey) { + process.env.OPENAI_API_KEY = openAIAPIKey; + } +}); + class AzureChatOpenAIStandardIntegrationTests extends ChatModelIntegrationTests< ChatOpenAICallOptions, AIMessageChunk @@ -35,6 +49,7 @@ class AzureChatOpenAIStandardIntegrationTests extends ChatModelIntegrationTests< supportsParallelToolCalls: true, constructorArgs: { model: "gpt-3.5-turbo", + maxRetries: 0, }, }); } diff --git a/libs/langchain-openai/src/tests/chat_models.test.ts b/libs/langchain-openai/src/tests/chat_models.test.ts new file mode 100644 index 000000000000..a24c180ff1d0 --- /dev/null +++ b/libs/langchain-openai/src/tests/chat_models.test.ts @@ -0,0 +1,259 @@ +/* eslint-disable @typescript-eslint/no-explicit-any, no-process-env */ +import { z } from "zod"; +import { zodToJsonSchema } from "zod-to-json-schema"; +import { it, expect, describe, beforeAll, afterAll, jest } from "@jest/globals"; +import { ChatOpenAI } from "../chat_models.js"; + +describe("strict tool calling", () => { + const weatherTool = { + type: "function" as const, + function: { + name: "get_current_weather", + description: "Get the current weather in a location", + parameters: zodToJsonSchema( + z.object({ + location: z.string().describe("The location to get the weather for"), + }) + ), + }, + }; + + // Store the original value of LANGCHAIN_TRACING_V2 + let oldLangChainTracingValue: string | undefined; + // Before all tests, save the current LANGCHAIN_TRACING_V2 value + beforeAll(() => { + oldLangChainTracingValue = process.env.LANGCHAIN_TRACING_V2; + }); + // After all tests, restore the original LANGCHAIN_TRACING_V2 value + afterAll(() => { + if (oldLangChainTracingValue !== undefined) { + process.env.LANGCHAIN_TRACING_V2 = oldLangChainTracingValue; + } else { + // If it was undefined, remove the environment variable + delete process.env.LANGCHAIN_TRACING_V2; + } + }); + + it("Can accept strict as a call arg via .bindTools", async () => { + const mockFetch = jest.fn<(url: any, options?: any) => Promise>(); + mockFetch.mockImplementation((url, options) => { + // Store the request details for later inspection + mockFetch.mock.calls.push([url, options]); + + // Return a mock response + return Promise.resolve({ + ok: true, + json: () => Promise.resolve({}), + }); + }); + + const model = new ChatOpenAI({ + model: "gpt-4", + apiKey: "test-key", + configuration: { + fetch: mockFetch, + }, + maxRetries: 0, + }); + + const modelWithTools = model.bindTools([weatherTool], { strict: true }); + + // This will fail since we're not returning a valid response in our mocked fetch function. + await expect( + modelWithTools.invoke("What's the weather like?") + ).rejects.toThrow(); + + expect(mockFetch).toHaveBeenCalled(); + const [_url, options] = mockFetch.mock.calls[0]; + + if (options && options.body) { + expect(JSON.parse(options.body).tools[0].function).toHaveProperty( + "strict", + true + ); + } else { + throw new Error("Body not found in request."); + } + }); + + it("Can accept strict as a call arg via .bind", async () => { + const mockFetch = jest.fn<(url: any, options?: any) => Promise>(); + mockFetch.mockImplementation((url, options) => { + // Store the request details for later inspection + mockFetch.mock.calls.push([url, options]); + + // Return a mock response + return Promise.resolve({ + ok: true, + json: () => Promise.resolve({}), + }); + }); + + const model = new ChatOpenAI({ + model: "gpt-4", + apiKey: "test-key", + configuration: { + fetch: mockFetch, + }, + maxRetries: 0, + }); + + const modelWithTools = model.bind({ + tools: [weatherTool], + strict: true, + }); + + // This will fail since we're not returning a valid response in our mocked fetch function. + await expect( + modelWithTools.invoke("What's the weather like?") + ).rejects.toThrow(); + + expect(mockFetch).toHaveBeenCalled(); + const [_url, options] = mockFetch.mock.calls[0]; + + if (options && options.body) { + expect(JSON.parse(options.body).tools[0].function).toHaveProperty( + "strict", + true + ); + } else { + throw new Error("Body not found in request."); + } + }); + + it("Strict is false if supportsStrictToolCalling is false", async () => { + const mockFetch = jest.fn<(url: any, options?: any) => Promise>(); + mockFetch.mockImplementation((url, options) => { + // Store the request details for later inspection + mockFetch.mock.calls.push([url, options]); + + // Return a mock response + return Promise.resolve({ + ok: true, + json: () => Promise.resolve({}), + }); + }); + + const model = new ChatOpenAI({ + model: "gpt-4", + apiKey: "test-key", + configuration: { + fetch: mockFetch, + }, + maxRetries: 0, + supportsStrictToolCalling: false, + }); + + // Do NOT pass `strict` here since we're checking that it's set to true by default + const modelWithTools = model.bindTools([weatherTool]); + + // This will fail since we're not returning a valid response in our mocked fetch function. + await expect( + modelWithTools.invoke("What's the weather like?") + ).rejects.toThrow(); + + expect(mockFetch).toHaveBeenCalled(); + const [_url, options] = mockFetch.mock.calls[0]; + + if (options && options.body) { + expect(JSON.parse(options.body).tools[0].function).toHaveProperty( + "strict", + false + ); + } else { + throw new Error("Body not found in request."); + } + }); + + it("Strict is set to true if passed in .withStructuredOutput", async () => { + const mockFetch = jest.fn<(url: any, options?: any) => Promise>(); + mockFetch.mockImplementation((url, options) => { + // Store the request details for later inspection + mockFetch.mock.calls.push([url, options]); + + // Return a mock response + return Promise.resolve({ + ok: true, + json: () => Promise.resolve({}), + }); + }); + + const model = new ChatOpenAI({ + model: "doesnt-start-with-gpt-4", + apiKey: "test-key", + configuration: { + fetch: mockFetch, + }, + maxRetries: 0, + supportsStrictToolCalling: true, + }); + + const modelWithTools = model.withStructuredOutput( + z.object({ + location: z.string().describe("The location to get the weather for"), + }), + { + strict: true, + } + ); + + // This will fail since we're not returning a valid response in our mocked fetch function. + await expect( + modelWithTools.invoke("What's the weather like?") + ).rejects.toThrow(); + + expect(mockFetch).toHaveBeenCalled(); + const [_url, options] = mockFetch.mock.calls[0]; + + if (options && options.body) { + const body = JSON.parse(options.body); + expect(body.tools[0].function).toHaveProperty("strict", true); + } else { + throw new Error("Body not found in request."); + } + }); + + it("Strict is NOT passed to OpenAI if NOT passed in .withStructuredOutput", async () => { + const mockFetch = jest.fn<(url: any, options?: any) => Promise>(); + mockFetch.mockImplementation((url, options) => { + // Store the request details for later inspection + mockFetch.mock.calls.push([url, options]); + + // Return a mock response + return Promise.resolve({ + ok: true, + json: () => Promise.resolve({}), + }); + }); + + const model = new ChatOpenAI({ + model: "doesnt-start-with-gpt-4", + apiKey: "test-key", + configuration: { + fetch: mockFetch, + }, + maxRetries: 0, + }); + + const modelWithTools = model.withStructuredOutput( + z.object({ + location: z.string().describe("The location to get the weather for"), + }) + ); + + // This will fail since we're not returning a valid response in our mocked fetch function. + await expect( + modelWithTools.invoke("What's the weather like?") + ).rejects.toThrow(); + + expect(mockFetch).toHaveBeenCalled(); + const [_url, options] = mockFetch.mock.calls[0]; + + if (options && options.body) { + const body = JSON.parse(options.body); + expect(body.tools[0].function).not.toHaveProperty("strict"); + } else { + throw new Error("Body not found in request."); + } + }); +}); diff --git a/libs/langchain-openai/src/tests/chat_models_structured_output.int.test.ts b/libs/langchain-openai/src/tests/chat_models_structured_output.int.test.ts index 86bf0247bd49..bc0328357a73 100644 --- a/libs/langchain-openai/src/tests/chat_models_structured_output.int.test.ts +++ b/libs/langchain-openai/src/tests/chat_models_structured_output.int.test.ts @@ -2,6 +2,7 @@ import { z } from "zod"; import { zodToJsonSchema } from "zod-to-json-schema"; import { ChatPromptTemplate } from "@langchain/core/prompts"; import { AIMessage } from "@langchain/core/messages"; +import { test, expect } from "@jest/globals"; import { ChatOpenAI } from "../chat_models.js"; test("withStructuredOutput zod schema function calling", async () => { @@ -320,3 +321,37 @@ test("parallelToolCalls param", async () => { // console.log(response.tool_calls); expect(response.tool_calls?.length).toBe(1); }); + +test("Passing strict true forces the model to conform to the schema", async () => { + const model = new ChatOpenAI({ + model: "gpt-4o", + temperature: 0, + maxRetries: 0, + }); + + const weatherTool = { + type: "function" as const, + function: { + name: "get_current_weather", + description: "Get the current weather in a location", + parameters: zodToJsonSchema( + z.object({ + location: z.string().describe("The location to get the weather for"), + }) + ), + }, + }; + const modelWithTools = model.bindTools([weatherTool], { + strict: true, + tool_choice: "get_current_weather", + }); + + const result = await modelWithTools.invoke( + "Whats the result of 173827 times 287326 divided by 2?" + ); + // Expect at least one tool call, allow multiple + expect(result.tool_calls?.length).toBeGreaterThanOrEqual(1); + expect(result.tool_calls?.[0].name).toBe("get_current_weather"); + expect(result.tool_calls?.[0].args).toHaveProperty("location"); + console.log(result.tool_calls?.[0].args); +}); diff --git a/libs/langchain-openai/src/types.ts b/libs/langchain-openai/src/types.ts index 19e6af483d7d..0d93089619e2 100644 --- a/libs/langchain-openai/src/types.ts +++ b/libs/langchain-openai/src/types.ts @@ -155,6 +155,12 @@ export interface OpenAIChatInput extends OpenAIBaseInput { * Currently in experimental beta. */ __includeRawResponse?: boolean; + + /** + * Whether the model supports the `strict` argument when passing in tools. + * If `undefined` the `strict` argument will not be passed to OpenAI. + */ + supportsStrictToolCalling?: boolean; } export declare interface AzureOpenAIInput { diff --git a/yarn.lock b/yarn.lock index c96e97605695..075329f22ffb 100644 --- a/yarn.lock +++ b/yarn.lock @@ -12199,7 +12199,7 @@ __metadata: jest: ^29.5.0 jest-environment-node: ^29.6.4 js-tiktoken: ^1.0.12 - openai: ^4.49.1 + openai: ^4.55.0 prettier: ^2.8.3 release-it: ^17.6.0 rimraf: ^5.0.1 @@ -34040,6 +34040,28 @@ __metadata: languageName: node linkType: hard +"openai@npm:^4.55.0": + version: 4.55.0 + resolution: "openai@npm:4.55.0" + dependencies: + "@types/node": ^18.11.18 + "@types/node-fetch": ^2.6.4 + abort-controller: ^3.0.0 + agentkeepalive: ^4.2.1 + form-data-encoder: 1.7.2 + formdata-node: ^4.3.2 + node-fetch: ^2.6.7 + peerDependencies: + zod: ^3.23.8 + peerDependenciesMeta: + zod: + optional: true + bin: + openai: bin/cli + checksum: b2b1daa976516262e08e182ee982976a1dc615eebd250bbd71f4122740ebeeb207a20af6d35c718b67f1c3457196b524667a0c7fa417ab4e119020b5c1f5cd74 + languageName: node + linkType: hard + "openapi-types@npm:^12.1.3": version: 12.1.3 resolution: "openapi-types@npm:12.1.3"