Skip to content

Commit

Permalink
openai[minor],core[minor]: Add support for passing strict in openai t…
Browse files Browse the repository at this point in the history
…ools (#6418)

* openai[minor],core[minor]: Add support for passing strict in openai tools

* add integration test

* chore: lint files

* Cr

* cr

* fix

* fixed all tests

* docs

* fix build errors

* fix more type errors

* cr

* cr

* chore: lint files
  • Loading branch information
bracesproul authored Aug 6, 2024
1 parent b79ea85 commit 1296334
Show file tree
Hide file tree
Showing 14 changed files with 617 additions and 49 deletions.
78 changes: 76 additions & 2 deletions docs/core_docs/docs/integrations/chat/openai.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,7 @@
},
{
"cell_type": "markdown",
"id": "3a5bb5ca-c3ae-4a58-be67-2cd18574b9a3",
"id": "bc5ecebd",
"metadata": {},
"source": [
"## Tool calling\n",
Expand All @@ -420,8 +420,82 @@
"\n",
"- [How to: disable parallel tool calling](/docs/how_to/tool_calling_parallel/)\n",
"- [How to: force a tool call](/docs/how_to/tool_choice/)\n",
"- [How to: bind model-specific tool formats to a model](/docs/how_to/tool_calling#binding-model-specific-formats-advanced).\n",
"- [How to: bind model-specific tool formats to a model](/docs/how_to/tool_calling#binding-model-specific-formats-advanced)."
]
},
{
"cell_type": "markdown",
"id": "3392390e",
"metadata": {},
"source": [
"### ``strict: true``\n",
"\n",
"```{=mdx}\n",
"\n",
":::info Requires ``@langchain/openai >= 0.2.6``\n",
"\n",
"As of Aug 6, 2024, OpenAI supports a `strict` argument when calling tools that will enforce that the tool argument schema is respected by the model. See more here: https://platform.openai.com/docs/guides/function-calling\n",
"\n",
"**Note**: If ``strict: true`` the tool definition will also be validated, and a subset of JSON schema are accepted. Crucially, schema cannot have optional args (those with default values). Read the full docs on what types of schema are supported here: https://platform.openai.com/docs/guides/structured-outputs/supported-schemas. \n",
":::\n",
"\n",
"\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "90f0d465",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[\n",
" {\n",
" name: 'get_current_weather',\n",
" args: { location: 'Hanoi' },\n",
" type: 'tool_call',\n",
" id: 'call_aB85ybkLCoccpzqHquuJGH3d'\n",
" }\n",
"]\n"
]
}
],
"source": [
"import { ChatOpenAI } from \"@langchain/openai\";\n",
"import { tool } from \"@langchain/core/tools\";\n",
"import { z } from \"zod\";\n",
"\n",
"const weatherTool = tool((_) => \"no-op\", {\n",
" name: \"get_current_weather\",\n",
" description: \"Get the current weather\",\n",
" schema: z.object({\n",
" location: z.string(),\n",
" }),\n",
"})\n",
"\n",
"const llmWithStrictTrue = new ChatOpenAI({\n",
" model: \"gpt-4o\",\n",
"}).bindTools([weatherTool], {\n",
" strict: true,\n",
" tool_choice: weatherTool.name,\n",
"});\n",
"\n",
"// Although the question is not about the weather, it will call the tool with the correct arguments\n",
"// because we passed `tool_choice` and `strict: true`.\n",
"const strictTrueResult = await llmWithStrictTrue.invoke(\"What is 127862 times 12898 divided by 2?\");\n",
"\n",
"console.dir(strictTrueResult.tool_calls, { depth: null });"
]
},
{
"cell_type": "markdown",
"id": "3a5bb5ca-c3ae-4a58-be67-2cd18574b9a3",
"metadata": {},
"source": [
"## API reference\n",
"\n",
"For detailed documentation of all ChatOpenAI features and configurations head to the API reference: https://api.js.langchain.com/classes/langchain_openai.ChatOpenAI.html"
Expand Down
9 changes: 9 additions & 0 deletions langchain-core/src/language_models/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,15 @@ export interface FunctionDefinition {
* how to call the function.
*/
description?: string;

/**
* Whether to enable strict schema adherence when generating the function call. If
* set to true, the model will follow the exact schema defined in the `parameters`
* field. Only a subset of JSON Schema is supported when `strict` is `true`. Learn
* more about Structured Outputs in the
* [function calling guide](https://platform.openai.com/docs/guides/function-calling).
*/
strict?: boolean;
}

export interface ToolDefinition {
Expand Down
62 changes: 57 additions & 5 deletions langchain-core/src/utils/function_calling.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,26 @@ import { Runnable, RunnableToolLike } from "../runnables/base.js";
* @returns {FunctionDefinition} The inputted tool in OpenAI function format.
*/
export function convertToOpenAIFunction(
tool: StructuredToolInterface | RunnableToolLike
tool: StructuredToolInterface | RunnableToolLike,
fields?:
| {
/**
* If `true`, model output is guaranteed to exactly match the JSON Schema
* provided in the function definition.
*/
strict?: boolean;
}
| number
): FunctionDefinition {
// @TODO 0.3.0 Remove the `number` typing
const fieldsCopy = typeof fields === "number" ? undefined : fields;

return {
name: tool.name,
description: tool.description,
parameters: zodToJsonSchema(tool.schema),
// Do not include the `strict` field if it is `undefined`.
...(fieldsCopy?.strict !== undefined ? { strict: fieldsCopy.strict } : {}),
};
}

Expand All @@ -34,15 +48,35 @@ export function convertToOpenAIFunction(
*/
export function convertToOpenAITool(
// eslint-disable-next-line @typescript-eslint/no-explicit-any
tool: StructuredToolInterface | Record<string, any> | RunnableToolLike
tool: StructuredToolInterface | Record<string, any> | RunnableToolLike,
fields?:
| {
/**
* If `true`, model output is guaranteed to exactly match the JSON Schema
* provided in the function definition.
*/
strict?: boolean;
}
| number
): ToolDefinition {
if (isStructuredTool(tool) || isRunnableToolLike(tool)) {
return {
// @TODO 0.3.0 Remove the `number` typing
const fieldsCopy = typeof fields === "number" ? undefined : fields;

let toolDef: ToolDefinition | undefined;
if (isLangChainTool(tool)) {
toolDef = {
type: "function",
function: convertToOpenAIFunction(tool),
};
} else {
toolDef = tool as ToolDefinition;
}

if (fieldsCopy?.strict !== undefined) {
toolDef.function.strict = fieldsCopy.strict;
}
return tool as ToolDefinition;

return toolDef;
}

/**
Expand Down Expand Up @@ -76,3 +110,21 @@ export function isRunnableToolLike(tool?: unknown): tool is RunnableToolLike {
tool.constructor.lc_name() === "RunnableToolLike"
);
}

/**
* Whether or not the tool is one of StructuredTool, RunnableTool or StructuredToolParams.
* It returns `is StructuredToolParams` since that is the most minimal interface of the three,
* while still containing the necessary properties to be passed to a LLM for tool calling.
*
* @param {unknown | undefined} tool The tool to check if it is a LangChain tool.
* @returns {tool is StructuredToolParams} Whether the inputted tool is a LangChain tool.
*/
export function isLangChainTool(
tool?: unknown
): tool is StructuredToolInterface {
return (
isRunnableToolLike(tool) ||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
isStructuredTool(tool as any)
);
}
4 changes: 3 additions & 1 deletion langchain/src/agents/openai_tools/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,9 @@ export async function createOpenAIToolsAgent({
].join("\n")
);
}
const modelWithTools = llm.bind({ tools: tools.map(convertToOpenAITool) });
const modelWithTools = llm.bind({
tools: tools.map((tool) => convertToOpenAITool(tool)),
});
const agent = AgentRunnableSequence.fromRunnables(
[
RunnablePassthrough.assign({
Expand Down
2 changes: 1 addition & 1 deletion langchain/src/agents/openai_tools/output_parser.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ export type { ToolsAgentAction, ToolsAgentStep };
* new ChatOpenAI({
* modelName: "gpt-3.5-turbo-1106",
* temperature: 0,
* }).bind({ tools: tools.map(convertToOpenAITool) }),
* }).bind({ tools: tools.map((tool) => convertToOpenAITool(tool)) }),
* new OpenAIToolsAgentOutputParser(),
* ]).withConfig({ runName: "OpenAIToolsAgent" });
*
Expand Down
2 changes: 1 addition & 1 deletion libs/langchain-groq/src/chat_models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,7 @@ export class ChatGroq extends BaseChatModel<
kwargs?: Partial<ChatGroqCallOptions>
): Runnable<BaseLanguageModelInput, AIMessageChunk, ChatGroqCallOptions> {
return this.bind({
tools: tools.map(convertToOpenAITool),
tools: tools.map((tool) => convertToOpenAITool(tool)),
...kwargs,
});
}
Expand Down
6 changes: 4 additions & 2 deletions libs/langchain-ollama/src/chat_models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ export class ChatOllama
kwargs?: Partial<this["ParsedCallOptions"]>
): Runnable<BaseLanguageModelInput, AIMessageChunk, ChatOllamaCallOptions> {
return this.bind({
tools: tools.map(convertToOpenAITool),
tools: tools.map((tool) => convertToOpenAITool(tool)),
...kwargs,
});
}
Expand Down Expand Up @@ -359,7 +359,9 @@ export class ChatOllama
stop: options?.stop,
},
tools: options?.tools?.length
? (options.tools.map(convertToOpenAITool) as OllamaTool[])
? (options.tools.map((tool) =>
convertToOpenAITool(tool)
) as OllamaTool[])
: undefined,
};
}
Expand Down
2 changes: 1 addition & 1 deletion libs/langchain-openai/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
"dependencies": {
"@langchain/core": ">=0.2.16 <0.3.0",
"js-tiktoken": "^1.0.12",
"openai": "^4.49.1",
"openai": "^4.55.0",
"zod": "^3.22.4",
"zod-to-json-schema": "^3.22.3"
},
Expand Down
Loading

0 comments on commit 1296334

Please sign in to comment.