Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

core[minor]: Add base implementation of withStructuredOutput #5752

Merged
merged 8 commits into from
Jun 21, 2024
Merged
21 changes: 21 additions & 0 deletions langchain-core/src/language_models/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,27 @@ export const getModelContextSize = (modelName: string): number => {
}
};

/**
 * Type guard for OpenAI-style tool definitions.
 *
 * A value passes when it is a non-null object whose `type` is exactly
 * `"function"` and whose `function` member is itself a non-null object
 * carrying both a `name` and a `parameters` key. Only the presence of those
 * two keys is checked, not the types of their values.
 *
 * @param tool The value to inspect.
 * @returns `true` if the value has the shape of an OpenAI tool definition.
 */
export function isOpenAITool(tool: unknown): tool is ToolDefinition {
  // Primitives, null and undefined can never be tool definitions.
  if (tool === null || typeof tool !== "object") {
    return false;
  }
  // The discriminator must be present and equal to "function".
  if (!("type" in tool) || tool.type !== "function") {
    return false;
  }
  if (!("function" in tool)) {
    return false;
  }
  const fn = tool.function;
  // `typeof null === "object"`, so null has to be ruled out separately.
  if (typeof fn !== "object" || fn === null) {
    return false;
  }
  return "name" in fn && "parameters" in fn;
}

interface CalculateMaxTokenProps {
prompt: string;
modelName: TiktokenModel;
Expand Down
155 changes: 151 additions & 4 deletions langchain-core/src/language_models/chat_models.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import { z } from "zod";
import { zodToJsonSchema } from "zod-to-json-schema";
import {
AIMessage,
type BaseMessage,
BaseMessageChunk,
type BaseMessageLike,
HumanMessage,
coerceMessageLikeToMessage,
AIMessageChunk,
} from "../messages/index.js";
import type { BasePromptValueInterface } from "../prompt_values.js";
import {
Expand All @@ -17,6 +20,8 @@ import {
} from "../outputs.js";
import {
BaseLanguageModel,
StructuredOutputMethodOptions,
ToolDefinition,
type BaseLanguageModelCallOptions,
type BaseLanguageModelInput,
type BaseLanguageModelParams,
Expand All @@ -29,10 +34,16 @@ import {
import type { RunnableConfig } from "../runnables/config.js";
import type { BaseCache } from "../caches.js";
import { StructuredToolInterface } from "../tools.js";
import { Runnable } from "../runnables/base.js";
import {
Runnable,
RunnableLambda,
RunnableSequence,
} from "../runnables/base.js";
import { isStreamEventsHandler } from "../tracers/event_stream.js";
import { isLogStreamHandler } from "../tracers/log_stream.js";
import { concat } from "../utils/stream.js";
import { RunnablePassthrough } from "../runnables/passthrough.js";
import { isZodSchema } from "../utils/types/is_zod_schema.js";

/**
* Represents a serialized chat model.
Expand Down Expand Up @@ -143,12 +154,16 @@ export abstract class BaseChatModel<
* Bind tool-like objects to this chat model.
*
* @param tools A list of tool definitions to bind to this chat model.
* Can be a structured tool or an object matching the provider's
* specific tool schema.
* Can be a structured tool, an OpenAI formatted tool, or an object
* matching the provider's specific tool schema.
* @param kwargs Any additional parameters to bind.
*/
bindTools?(
tools: (StructuredToolInterface | Record<string, unknown>)[],
tools: (
| StructuredToolInterface
| Record<string, unknown>
| ToolDefinition
)[],
kwargs?: Partial<CallOptions>
): Runnable<BaseLanguageModelInput, OutputMessageType, CallOptions>;

Expand Down Expand Up @@ -714,6 +729,138 @@ export abstract class BaseChatModel<
}
return result.content;
}

/**
 * Wrap this chat model so that it returns output shaped by `outputSchema`.
 *
 * The schema is exposed to the model as a single function-type tool bound via
 * `bindTools`; the matching tool call's arguments are returned as the result.
 *
 * @param outputSchema A Zod schema (converted with `zodToJsonSchema`) or a
 *   plain object that is passed through as the tool's `parameters`.
 * @param config Optional settings read by this method: `name` (tool name,
 *   defaults to `"extract"`), `method` (`"jsonMode"` is rejected) and
 *   `includeRaw`.
 * @returns A runnable yielding the tool call arguments.
 * @throws {Error} If the model does not implement `bindTools`, or if
 *   `config.method` is `"jsonMode"`.
 */
withStructuredOutput<
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  RunOutput extends Record<string, any> = Record<string, any>
>(
  outputSchema:
    | z.ZodType<RunOutput>
    // eslint-disable-next-line @typescript-eslint/no-explicit-any
    | Record<string, any>,
  config?: StructuredOutputMethodOptions<false>
): Runnable<BaseLanguageModelInput, RunOutput>;

/**
 * Overload for `includeRaw: true`: the runnable yields both the raw model
 * message (`raw`) and the extracted arguments (`parsed`).
 */
withStructuredOutput<
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  RunOutput extends Record<string, any> = Record<string, any>
>(
  outputSchema:
    | z.ZodType<RunOutput>
    // eslint-disable-next-line @typescript-eslint/no-explicit-any
    | Record<string, any>,
  config?: StructuredOutputMethodOptions<true>
): Runnable<BaseLanguageModelInput, { raw: BaseMessage; parsed: RunOutput }>;

withStructuredOutput<
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  RunOutput extends Record<string, any> = Record<string, any>
>(
  outputSchema:
    | z.ZodType<RunOutput>
    // eslint-disable-next-line @typescript-eslint/no-explicit-any
    | Record<string, any>,
  config?: StructuredOutputMethodOptions<boolean>
):
  | Runnable<BaseLanguageModelInput, RunOutput>
  | Runnable<
      BaseLanguageModelInput,
      {
        raw: BaseMessage;
        parsed: RunOutput;
      }
    > {
  // `bindTools` is optional on the base class, so it must be checked at runtime.
  if (typeof this.bindTools !== "function") {
    throw new Error(
      `Chat model must implement ".bindTools()" to use withStructuredOutput.`
    );
  }
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  const schema: z.ZodType<RunOutput> | Record<string, any> = outputSchema;
  const name = config?.name;
  const description = schema.description ?? "A function available to call.";
  const method = config?.method;
  const includeRaw = config?.includeRaw;
  if (method === "jsonMode") {
    throw new Error(
      `Base withStructuredOutput implementation only supports "functionCalling" as a method.`
    );
  }

  let functionName = name ?? "extract";
  let tools: ToolDefinition[];
  if (isZodSchema(schema)) {
    tools = [
      {
        type: "function",
        function: {
          name: functionName,
          description,
          parameters: zodToJsonSchema(schema),
        },
      },
    ];
  } else {
    // A string `name` on a plain schema object takes precedence over
    // `config.name`. Non-string values are ignored: using one as the tool
    // name would make the name lookup in the parser below impossible to
    // satisfy.
    if (typeof schema.name === "string") {
      functionName = schema.name;
    }
    tools = [
      {
        type: "function",
        function: {
          name: functionName,
          description,
          parameters: schema,
        },
      },
    ];
  }

  const llm = this.bindTools(tools);
  // Extracts the arguments of the tool call whose name matches the bound tool.
  const outputParser = RunnableLambda.from<AIMessageChunk, RunOutput>(
    (input: AIMessageChunk): RunOutput => {
      if (!input.tool_calls || input.tool_calls.length === 0) {
        throw new Error("No tool calls found in the response.");
      }
      const toolCall = input.tool_calls.find(
        (tc) => tc.name === functionName
      );
      if (!toolCall) {
        throw new Error(`No tool call found with name ${functionName}.`);
      }
      return toolCall.args as RunOutput;
    }
  );

  if (!includeRaw) {
    return llm.pipe(outputParser).withConfig({
      runName: "StructuredOutput",
    }) as Runnable<BaseLanguageModelInput, RunOutput>;
  }

  // With `includeRaw`, the model message is kept under `raw` and the parser
  // result is added under `parsed`.
  const parserAssign = RunnablePassthrough.assign({
    // eslint-disable-next-line @typescript-eslint/no-explicit-any
    parsed: (input: any, config) => outputParser.invoke(input.raw, config),
  });
  // Fallback used when parsing throws: `parsed` is set to null instead of
  // failing the whole chain, so the caller still receives `raw`.
  const parserNone = RunnablePassthrough.assign({
    parsed: () => null,
  });
  const parsedWithFallback = parserAssign.withFallbacks({
    fallbacks: [parserNone],
  });
  return RunnableSequence.from<
    BaseLanguageModelInput,
    { raw: BaseMessage; parsed: RunOutput }
  >([
    {
      raw: llm,
    },
    parsedWithFallback,
  ]).withConfig({
    runName: "StructuredOutputRunnable",
  });
}
}

/**
Expand Down
Loading