Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

core[minor],openai[patch],langchain[patch]: Allow tool functions to input ToolCall / return ToolMessage #6005

Merged
merged 21 commits into from
Jul 15, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion langchain-core/src/callbacks/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import {
import type { SerializedFields } from "../load/map_keys.js";
import type { DocumentInterface } from "../documents/document.js";
import { getEnvironmentVariable } from "../utils/env.js";
import { ToolMessage } from "../messages/tool.js";

// eslint-disable-next-line @typescript-eslint/no-explicit-any
type Error = any;
Expand Down Expand Up @@ -197,7 +198,7 @@ abstract class BaseCallbackHandlerMethodsClass {
* Called at the end of a Tool run, with the tool output and the run ID.
*/
handleToolEnd?(
output: string,
output: string | ToolMessage,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this should be an any type -- a tool should be allowed to return anything, and we shouldn't be limiting the return type in anyway

Suggested change
output: string | ToolMessage,
output: any,

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How would we handle things that aren't a string or tool message? Type check and e.g. JSON.stringify?

runId: string,
parentRunId?: string,
tags?: string[]
Expand Down
3 changes: 2 additions & 1 deletion langchain-core/src/callbacks/manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import { consumeCallback } from "./promises.js";
import { Serialized } from "../load/serializable.js";
import type { DocumentInterface } from "../documents/document.js";
import { isTracingEnabled } from "../utils/callbacks.js";
import { ToolMessage } from "../messages/tool.js";

if (
/* #__PURE__ */ getEnvironmentVariable("LANGCHAIN_TRACING_V2") === "true" &&
Expand Down Expand Up @@ -493,7 +494,7 @@ export class CallbackManagerForToolRun
);
}

async handleToolEnd(output: string): Promise<void> {
async handleToolEnd(output: string | ToolMessage): Promise<void> {
await Promise.all(
this.handlers.map((handler) =>
consumeCallback(async () => {
Expand Down
87 changes: 54 additions & 33 deletions langchain-core/src/tools.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import {
} from "./language_models/base.js";
import { ensureConfig, type RunnableConfig } from "./runnables/config.js";
import type { RunnableFunc, RunnableInterface } from "./runnables/base.js";
import type { ToolMessage } from "./messages/tool.js";

// eslint-disable-next-line @typescript-eslint/no-explicit-any
type ZodAny = z.ZodObject<any, any, any, any>;
Expand All @@ -34,10 +35,12 @@ export class ToolInputParsingException extends Error {
}
}

export interface StructuredToolInterface<T extends ZodAny = ZodAny>
extends RunnableInterface<
export interface StructuredToolInterface<
T extends ZodAny = ZodAny,
RunOutput extends string | ToolMessage = string
> extends RunnableInterface<
(z.output<T> extends string ? string : never) | z.input<T>,
string
RunOutput
> {
lc_namespace: string[];

Expand All @@ -59,7 +62,7 @@ export interface StructuredToolInterface<T extends ZodAny = ZodAny>
configArg?: Callbacks | RunnableConfig,
/** @deprecated */
tags?: string[]
): Promise<string>;
): Promise<RunOutput>;

name: string;

Expand All @@ -72,10 +75,11 @@ export interface StructuredToolInterface<T extends ZodAny = ZodAny>
* Base class for Tools that accept input of any shape defined by a Zod schema.
*/
export abstract class StructuredTool<
T extends ZodAny = ZodAny
T extends ZodAny = ZodAny,
RunOutput extends string | ToolMessage = string
> extends BaseLangChain<
(z.output<T> extends string ? string : never) | z.input<T>,
string
RunOutput
> {
abstract schema: T | z.ZodEffects<T>;

Expand All @@ -91,7 +95,7 @@ export abstract class StructuredTool<
arg: z.output<T>,
runManager?: CallbackManagerForToolRun,
config?: RunnableConfig
): Promise<string>;
): Promise<RunOutput>;

/**
* Invokes the tool with the provided input and configuration.
Expand All @@ -102,8 +106,8 @@ export abstract class StructuredTool<
async invoke(
input: (z.output<T> extends string ? string : never) | z.input<T>,
config?: RunnableConfig
): Promise<string> {
return this.call(input, ensureConfig(config));
): Promise<RunOutput> {
return this.call(input, ensureConfig(config)) as Promise<RunOutput>;
}

/**
Expand All @@ -122,7 +126,7 @@ export abstract class StructuredTool<
configArg?: Callbacks | RunnableConfig,
/** @deprecated */
tags?: string[]
): Promise<string> {
): Promise<RunOutput> {
let parsed;
try {
parsed = await this.schema.parseAsync(arg);
Expand Down Expand Up @@ -170,7 +174,10 @@ export abstract class StructuredTool<
returnDirect = false;
}

export interface ToolInterface extends StructuredToolInterface {
export interface ToolInterface<
T extends ZodAny = ZodAny,
RunOutput extends string | ToolMessage = string
> extends StructuredToolInterface<T, RunOutput> {
/**
* @deprecated Use .invoke() instead. Will be removed in 0.3.0.
*
Expand All @@ -183,13 +190,15 @@ export interface ToolInterface extends StructuredToolInterface {
call(
arg: string | undefined | z.input<this["schema"]>,
callbacks?: Callbacks | RunnableConfig
): Promise<string>;
): Promise<RunOutput>;
}

/**
* Base class for Tools that accept input as a string.
*/
export abstract class Tool extends StructuredTool {
export abstract class Tool<
RunOutput extends string | ToolMessage = string
> extends StructuredTool<ZodAny, RunOutput> {
schema = z
.object({ input: z.string().optional() })
.transform((obj) => obj.input);
Expand All @@ -210,7 +219,7 @@ export abstract class Tool extends StructuredTool {
call(
arg: string | undefined | z.input<this["schema"]>,
callbacks?: Callbacks | RunnableConfig
): Promise<string> {
): Promise<RunOutput> {
return super.call(
typeof arg === "string" || !arg ? { input: arg } : arg,
callbacks
Expand All @@ -227,31 +236,37 @@ export interface BaseDynamicToolInput extends ToolParams {
/**
* Interface for the input parameters of the DynamicTool class.
*/
export interface DynamicToolInput extends BaseDynamicToolInput {
export interface DynamicToolInput<
RunOutput extends string | ToolMessage = string
> extends BaseDynamicToolInput {
func: (
input: string,
runManager?: CallbackManagerForToolRun,
config?: RunnableConfig
) => Promise<string>;
) => Promise<RunOutput>;
}

/**
* Interface for the input parameters of the DynamicStructuredTool class.
*/
export interface DynamicStructuredToolInput<T extends ZodAny = ZodAny>
extends BaseDynamicToolInput {
export interface DynamicStructuredToolInput<
T extends ZodAny = ZodAny,
RunOutput extends string | ToolMessage = string
> extends BaseDynamicToolInput {
func: (
input: z.infer<T>,
runManager?: CallbackManagerForToolRun,
config?: RunnableConfig
) => Promise<string>;
) => Promise<RunOutput>;
schema: T;
}

/**
* A tool that can be created dynamically from a function, name, and description.
*/
export class DynamicTool extends Tool {
export class DynamicTool<
RunOutput extends string | ToolMessage = string
> extends Tool<RunOutput> {
static lc_name() {
return "DynamicTool";
}
Expand All @@ -260,9 +275,9 @@ export class DynamicTool extends Tool {

description: string;

func: DynamicToolInput["func"];
func: DynamicToolInput<RunOutput>["func"];

constructor(fields: DynamicToolInput) {
constructor(fields: DynamicToolInput<RunOutput>) {
super(fields);
this.name = fields.name;
this.description = fields.description;
Expand All @@ -276,7 +291,7 @@ export class DynamicTool extends Tool {
async call(
arg: string | undefined | z.input<this["schema"]>,
configArg?: RunnableConfig | Callbacks
): Promise<string> {
): Promise<RunOutput> {
const config = parseCallbackConfigArg(configArg);
if (config.runName === undefined) {
config.runName = this.name;
Expand All @@ -289,7 +304,7 @@ export class DynamicTool extends Tool {
input: string,
runManager?: CallbackManagerForToolRun,
config?: RunnableConfig
): Promise<string> {
): Promise<RunOutput> {
return this.func(input, runManager, config);
}
}
Expand All @@ -301,8 +316,9 @@ export class DynamicTool extends Tool {
* provided function when the tool is called.
*/
export class DynamicStructuredTool<
T extends ZodAny = ZodAny
> extends StructuredTool<T> {
T extends ZodAny = ZodAny,
RunOutput extends string | ToolMessage = string
> extends StructuredTool<T, RunOutput> {
static lc_name() {
return "DynamicStructuredTool";
}
Expand All @@ -311,11 +327,11 @@ export class DynamicStructuredTool<

description: string;

func: DynamicStructuredToolInput<T>["func"];
func: DynamicStructuredToolInput<T, RunOutput>["func"];

schema: T;

constructor(fields: DynamicStructuredToolInput<T>) {
constructor(fields: DynamicStructuredToolInput<T, RunOutput>) {
super(fields);
this.name = fields.name;
this.description = fields.description;
Expand All @@ -332,7 +348,7 @@ export class DynamicStructuredTool<
configArg?: RunnableConfig | Callbacks,
/** @deprecated */
tags?: string[]
): Promise<string> {
): Promise<RunOutput> {
const config = parseCallbackConfigArg(configArg);
if (config.runName === undefined) {
config.runName = this.name;
Expand All @@ -344,7 +360,7 @@ export class DynamicStructuredTool<
arg: z.output<T>,
runManager?: CallbackManagerForToolRun,
config?: RunnableConfig
): Promise<string> {
): Promise<RunOutput> {
return this.func(arg, runManager, config);
}
}
Expand All @@ -365,6 +381,7 @@ export abstract class BaseToolkit {
/**
* Parameters for the tool function.
* @template {ZodAny} RunInput The input schema for the tool.
* @template {string | ToolMessage} RunOutput The output type for the tool.
*/
interface ToolWrapperParams<RunInput extends ZodAny = ZodAny>
extends ToolParams {
Expand All @@ -390,6 +407,7 @@ interface ToolWrapperParams<RunInput extends ZodAny = ZodAny>
* Creates a new StructuredTool instance with the provided function, name, description, and schema.
* @function
* @template {ZodAny} RunInput The input schema for the tool.
* @template {string | ToolMessage} RunOutput The output type for the tool.
*
* @param {RunnableFunc<RunInput, string>} func - The function to invoke when the tool is called.
* @param fields - An object containing the following properties:
Expand All @@ -399,8 +417,11 @@ interface ToolWrapperParams<RunInput extends ZodAny = ZodAny>
*
* @returns {StructuredTool<RunInput, string>} A new StructuredTool instance.
*/
export function tool<RunInput extends ZodAny = ZodAny>(
func: RunnableFunc<z.infer<RunInput>, string>,
export function tool<
RunInput extends ZodAny = ZodAny,
RunOutput extends string | ToolMessage = string
>(
func: RunnableFunc<z.infer<RunInput>, RunOutput>,
fields: ToolWrapperParams<RunInput>
) {
const schema =
Expand All @@ -409,7 +430,7 @@ export function tool<RunInput extends ZodAny = ZodAny>(

const description =
fields.description ?? schema.description ?? `${fields.name} tool`;
return new DynamicStructuredTool<RunInput>({
return new DynamicStructuredTool<RunInput, RunOutput>({
name: fields.name,
description,
schema: schema as RunInput,
Expand Down
Loading