diff --git a/langchain-core/src/language_models/base.ts b/langchain-core/src/language_models/base.ts
index 2f3e8ce1c5ee..56dd3b9aa80d 100644
--- a/langchain-core/src/language_models/base.ts
+++ b/langchain-core/src/language_models/base.ts
@@ -82,6 +82,27 @@ export const getModelContextSize = (modelName: string): number => {
   }
 };
 
+/**
+ * Whether or not the input matches the OpenAI tool definition.
+ * @param {unknown} tool The input to check.
+ * @returns {boolean} Whether the input is an OpenAI tool definition.
+ */
+export function isOpenAITool(tool: unknown): tool is ToolDefinition {
+  if (typeof tool !== "object" || !tool) return false;
+  if (
+    "type" in tool &&
+    tool.type === "function" &&
+    "function" in tool &&
+    typeof tool.function === "object" &&
+    tool.function &&
+    "name" in tool.function &&
+    "parameters" in tool.function
+  ) {
+    return true;
+  }
+  return false;
+}
+
 interface CalculateMaxTokenProps {
   prompt: string;
   modelName: TiktokenModel;
diff --git a/langchain-core/src/language_models/chat_models.ts b/langchain-core/src/language_models/chat_models.ts
index d78bb04b2e0e..f3b900427ba0 100644
--- a/langchain-core/src/language_models/chat_models.ts
+++ b/langchain-core/src/language_models/chat_models.ts
@@ -1,3 +1,5 @@
+import { z } from "zod";
+import { zodToJsonSchema } from "zod-to-json-schema";
 import {
   AIMessage,
   type BaseMessage,
@@ -5,6 +7,7 @@ import {
   type BaseMessageLike,
   HumanMessage,
   coerceMessageLikeToMessage,
+  AIMessageChunk,
 } from "../messages/index.js";
 import type { BasePromptValueInterface } from "../prompt_values.js";
 import {
@@ -17,6 +20,8 @@ import {
 } from "../outputs.js";
 import {
   BaseLanguageModel,
+  StructuredOutputMethodOptions,
+  ToolDefinition,
   type BaseLanguageModelCallOptions,
   type BaseLanguageModelInput,
   type BaseLanguageModelParams,
@@ -29,10 +34,16 @@ import {
 import type { RunnableConfig } from "../runnables/config.js";
 import type { BaseCache } from "../caches.js";
 import { StructuredToolInterface } from "../tools.js";
-import { Runnable } from "../runnables/base.js";
+import {
+  Runnable,
+  RunnableLambda,
+  RunnableSequence,
+} from "../runnables/base.js";
 import { isStreamEventsHandler } from "../tracers/event_stream.js";
 import { isLogStreamHandler } from "../tracers/log_stream.js";
 import { concat } from "../utils/stream.js";
+import { RunnablePassthrough } from "../runnables/passthrough.js";
+import { isZodSchema } from "../utils/types/is_zod_schema.js";
 
 /**
  * Represents a serialized chat model.
@@ -143,12 +154,16 @@ export abstract class BaseChatModel<
    * Bind tool-like objects to this chat model.
    *
    * @param tools A list of tool definitions to bind to this chat model.
-   * Can be a structured tool or an object matching the provider's
-   * specific tool schema.
+   * Can be a structured tool, an OpenAI formatted tool, or an object
+   * matching the provider's specific tool schema.
    * @param kwargs Any additional parameters to bind.
    */
   bindTools?(
-    tools: (StructuredToolInterface | Record<string, any>)[],
+    tools: (
+      | StructuredToolInterface
+      | Record<string, any>
+      | ToolDefinition
+    )[],
     kwargs?: Partial<CallOptions>
   ): Runnable<BaseLanguageModelInput, BaseMessageChunk, CallOptions>;
 
@@ -714,6 +729,138 @@ export abstract class BaseChatModel<
     }
     return result.content;
   }
+
+  withStructuredOutput<
+    // eslint-disable-next-line @typescript-eslint/no-explicit-any
+    RunOutput extends Record<string, any> = Record<string, any>
+    // eslint-disable-next-line @typescript-eslint/no-explicit-any
+  >(
+    outputSchema:
+      | z.ZodType<RunOutput>
+      // eslint-disable-next-line @typescript-eslint/no-explicit-any
+      | Record<string, any>,
+    config?: StructuredOutputMethodOptions<false>
+  ): Runnable<BaseLanguageModelInput, RunOutput>;
+
+  withStructuredOutput<
+    // eslint-disable-next-line @typescript-eslint/no-explicit-any
+    RunOutput extends Record<string, any> = Record<string, any>
+  >(
+    outputSchema:
+      | z.ZodType<RunOutput>
+      // eslint-disable-next-line @typescript-eslint/no-explicit-any
+      | Record<string, any>,
+    config?: StructuredOutputMethodOptions<true>
+  ): Runnable<BaseLanguageModelInput, { raw: BaseMessage; parsed: RunOutput }>;
+
+  withStructuredOutput<
+    // eslint-disable-next-line @typescript-eslint/no-explicit-any
+    RunOutput extends Record<string, any> = Record<string, any>
+  >(
+    outputSchema:
+      | z.ZodType<RunOutput>
+      // eslint-disable-next-line @typescript-eslint/no-explicit-any
+      | Record<string, any>,
+    config?: StructuredOutputMethodOptions<boolean>
+  ):
+    | Runnable<BaseLanguageModelInput, RunOutput>
+    | Runnable<
+        BaseLanguageModelInput,
+        {
+          raw: BaseMessage;
+          parsed: RunOutput;
+        }
+      > {
+    if (typeof this.bindTools !== "function") {
+      throw new Error(
+        `Chat model must implement ".bindTools()" to use withStructuredOutput.`
+      );
+    }
+    // eslint-disable-next-line @typescript-eslint/no-explicit-any
+    const schema: z.ZodType<RunOutput> | Record<string, any> = outputSchema;
+    const name = config?.name;
+    const description = schema.description ?? "A function available to call.";
+    const method = config?.method;
+    const includeRaw = config?.includeRaw;
+    if (method === "jsonMode") {
+      throw new Error(
+        `Base withStructuredOutput implementation only supports "functionCalling" as a method.`
+      );
+    }
+
+    let functionName = name ?? "extract";
+    let tools: ToolDefinition[];
+    if (isZodSchema(schema)) {
+      tools = [
+        {
+          type: "function",
+          function: {
+            name: functionName,
+            description,
+            parameters: zodToJsonSchema(schema),
+          },
+        },
+      ];
+    } else {
+      if ("name" in schema) {
+        functionName = schema.name;
+      }
+      tools = [
+        {
+          type: "function",
+          function: {
+            name: functionName,
+            description,
+            parameters: schema,
+          },
+        },
+      ];
+    }
+
+    const llm = this.bindTools(tools);
+    const outputParser = RunnableLambda.from<AIMessageChunk, RunOutput>(
+      (input: AIMessageChunk): RunOutput => {
+        if (!input.tool_calls || input.tool_calls.length === 0) {
+          throw new Error("No tool calls found in the response.");
+        }
+        const toolCall = input.tool_calls.find(
+          (tc) => tc.name === functionName
+        );
+        if (!toolCall) {
+          throw new Error(`No tool call found with name ${functionName}.`);
+        }
+        return toolCall.args as RunOutput;
+      }
+    );
+
+    if (!includeRaw) {
+      return llm.pipe(outputParser).withConfig({
+        runName: "StructuredOutput",
+      }) as Runnable<BaseLanguageModelInput, RunOutput>;
+    }
+
+    const parserAssign = RunnablePassthrough.assign({
+      // eslint-disable-next-line @typescript-eslint/no-explicit-any
+      parsed: (input: any, config) => outputParser.invoke(input.raw, config),
+    });
+    const parserNone = RunnablePassthrough.assign({
+      parsed: () => null,
+    });
+    const parsedWithFallback = parserAssign.withFallbacks({
+      fallbacks: [parserNone],
+    });
+    return RunnableSequence.from<
+      BaseLanguageModelInput,
+      { raw: BaseMessage; parsed: RunOutput }
+    >([
+      {
+        raw: llm,
+      },
+      parsedWithFallback,
+    ]).withConfig({
+      runName: "StructuredOutputRunnable",
+    });
+  }
 }
 
 /**
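
For reference, a minimal usage sketch of the two additions above (not part of the diff). The `ChatOpenAI` import, the `joke` schema, and the prompt strings are illustrative assumptions; any chat model that implements `.bindTools()` can use the base `withStructuredOutput`, and providers that ship their own override will bypass this default implementation.

```ts
// Usage sketch only: ChatOpenAI and the joke schema are illustrative
// assumptions, not part of this diff.
import { z } from "zod";
import { ChatOpenAI } from "@langchain/openai";
import { isOpenAITool } from "@langchain/core/language_models/base";

const joke = z.object({
  setup: z.string().describe("The setup of the joke"),
  punchline: z.string().describe("The punchline of the joke"),
});

const model = new ChatOpenAI({ temperature: 0 });

// Binds a single function-calling tool built from the Zod schema and parses
// the arguments of the matching tool call into the requested shape.
const structured = model.withStructuredOutput<z.infer<typeof joke>>(joke, {
  name: "joke",
});
const result = await structured.invoke("Tell me a joke about cats");
console.log(result.setup, result.punchline);

// With includeRaw, the raw message is returned alongside the parsed output,
// and a parsing failure falls back to parsed: null instead of throwing.
const withRaw = model.withStructuredOutput(joke, { includeRaw: true });
const { raw, parsed } = await withRaw.invoke("Tell me a joke about dogs");
console.log(parsed ?? raw);

// The new type guard narrows an unknown value to the OpenAI tool shape.
const maybeTool: unknown = {
  type: "function",
  function: { name: "joke", parameters: {} },
};
console.log(isOpenAITool(maybeTool)); // true
```

The `includeRaw` behavior in the sketch follows from the `RunnablePassthrough.assign` plus `withFallbacks` wiring in the diff, which is why parse failures surface as `parsed: null` rather than as an exception.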