Skip to content

Commit

Permalink
core[minor]: Allow tool functions to return ToolMessage
Browse files Browse the repository at this point in the history
  • Loading branch information
bracesproul committed Jul 8, 2024
1 parent 3d52258 commit 6b8e1a1
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 35 deletions.
3 changes: 2 additions & 1 deletion langchain-core/src/callbacks/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import {
import type { SerializedFields } from "../load/map_keys.js";
import type { DocumentInterface } from "../documents/document.js";
import { getEnvironmentVariable } from "../utils/env.js";
import { ToolMessage } from "../messages/tool.js";

// eslint-disable-next-line @typescript-eslint/no-explicit-any
type Error = any;
Expand Down Expand Up @@ -197,7 +198,7 @@ abstract class BaseCallbackHandlerMethodsClass {
* Called at the end of a Tool run, with the tool output and the run ID.
*/
handleToolEnd?(
output: string,
output: string | ToolMessage,
runId: string,
parentRunId?: string,
tags?: string[]
Expand Down
3 changes: 2 additions & 1 deletion langchain-core/src/callbacks/manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import { consumeCallback } from "./promises.js";
import { Serialized } from "../load/serializable.js";
import type { DocumentInterface } from "../documents/document.js";
import { isTracingEnabled } from "../utils/callbacks.js";
import { ToolMessage } from "../messages/tool.js";

if (
/* #__PURE__ */ getEnvironmentVariable("LANGCHAIN_TRACING_V2") === "true" &&
Expand Down Expand Up @@ -493,7 +494,7 @@ export class CallbackManagerForToolRun
);
}

async handleToolEnd(output: string): Promise<void> {
async handleToolEnd(output: string | ToolMessage): Promise<void> {
await Promise.all(
this.handlers.map((handler) =>
consumeCallback(async () => {
Expand Down
87 changes: 54 additions & 33 deletions langchain-core/src/tools.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import {
} from "./language_models/base.js";
import { ensureConfig, type RunnableConfig } from "./runnables/config.js";
import type { RunnableFunc, RunnableInterface } from "./runnables/base.js";
import type { ToolMessage } from "./messages/tool.js";

// eslint-disable-next-line @typescript-eslint/no-explicit-any
type ZodAny = z.ZodObject<any, any, any, any>;
Expand All @@ -34,10 +35,12 @@ export class ToolInputParsingException extends Error {
}
}

export interface StructuredToolInterface<T extends ZodAny = ZodAny>
extends RunnableInterface<
export interface StructuredToolInterface<
T extends ZodAny = ZodAny,
RunOutput extends string | ToolMessage = string
> extends RunnableInterface<
(z.output<T> extends string ? string : never) | z.input<T>,
string
RunOutput
> {
lc_namespace: string[];

Expand All @@ -59,7 +62,7 @@ export interface StructuredToolInterface<T extends ZodAny = ZodAny>
configArg?: Callbacks | RunnableConfig,
/** @deprecated */
tags?: string[]
): Promise<string>;
): Promise<RunOutput>;

name: string;

Expand All @@ -72,10 +75,11 @@ export interface StructuredToolInterface<T extends ZodAny = ZodAny>
* Base class for Tools that accept input of any shape defined by a Zod schema.
*/
export abstract class StructuredTool<
T extends ZodAny = ZodAny
T extends ZodAny = ZodAny,
RunOutput extends string | ToolMessage = string
> extends BaseLangChain<
(z.output<T> extends string ? string : never) | z.input<T>,
string
RunOutput
> {
abstract schema: T | z.ZodEffects<T>;

Expand All @@ -91,7 +95,7 @@ export abstract class StructuredTool<
arg: z.output<T>,
runManager?: CallbackManagerForToolRun,
config?: RunnableConfig
): Promise<string>;
): Promise<RunOutput>;

/**
* Invokes the tool with the provided input and configuration.
Expand All @@ -102,8 +106,8 @@ export abstract class StructuredTool<
async invoke(
input: (z.output<T> extends string ? string : never) | z.input<T>,
config?: RunnableConfig
): Promise<string> {
return this.call(input, ensureConfig(config));
): Promise<RunOutput> {
return this.call(input, ensureConfig(config)) as Promise<RunOutput>;
}

/**
Expand All @@ -122,7 +126,7 @@ export abstract class StructuredTool<
configArg?: Callbacks | RunnableConfig,
/** @deprecated */
tags?: string[]
): Promise<string> {
): Promise<RunOutput> {
let parsed;
try {
parsed = await this.schema.parseAsync(arg);
Expand Down Expand Up @@ -170,7 +174,10 @@ export abstract class StructuredTool<
returnDirect = false;
}

export interface ToolInterface extends StructuredToolInterface {
export interface ToolInterface<
T extends ZodAny = ZodAny,
RunOutput extends string | ToolMessage = string
> extends StructuredToolInterface<T, RunOutput> {
/**
* @deprecated Use .invoke() instead. Will be removed in 0.3.0.
*
Expand All @@ -183,13 +190,15 @@ export interface ToolInterface extends StructuredToolInterface {
call(
arg: string | undefined | z.input<this["schema"]>,
callbacks?: Callbacks | RunnableConfig
): Promise<string>;
): Promise<RunOutput>;
}

/**
* Base class for Tools that accept input as a string.
*/
export abstract class Tool extends StructuredTool {
export abstract class Tool<
RunOutput extends string | ToolMessage = string
> extends StructuredTool<ZodAny, RunOutput> {
schema = z
.object({ input: z.string().optional() })
.transform((obj) => obj.input);
Expand All @@ -210,7 +219,7 @@ export abstract class Tool extends StructuredTool {
call(
arg: string | undefined | z.input<this["schema"]>,
callbacks?: Callbacks | RunnableConfig
): Promise<string> {
): Promise<RunOutput> {
return super.call(
typeof arg === "string" || !arg ? { input: arg } : arg,
callbacks
Expand All @@ -227,31 +236,37 @@ export interface BaseDynamicToolInput extends ToolParams {
/**
* Interface for the input parameters of the DynamicTool class.
*/
export interface DynamicToolInput extends BaseDynamicToolInput {
export interface DynamicToolInput<
RunOutput extends string | ToolMessage = string
> extends BaseDynamicToolInput {
func: (
input: string,
runManager?: CallbackManagerForToolRun,
config?: RunnableConfig
) => Promise<string>;
) => Promise<RunOutput>;
}

/**
* Interface for the input parameters of the DynamicStructuredTool class.
*/
export interface DynamicStructuredToolInput<T extends ZodAny = ZodAny>
extends BaseDynamicToolInput {
export interface DynamicStructuredToolInput<
T extends ZodAny = ZodAny,
RunOutput extends string | ToolMessage = string
> extends BaseDynamicToolInput {
func: (
input: z.infer<T>,
runManager?: CallbackManagerForToolRun,
config?: RunnableConfig
) => Promise<string>;
) => Promise<RunOutput>;
schema: T;
}

/**
* A tool that can be created dynamically from a function, name, and description.
*/
export class DynamicTool extends Tool {
export class DynamicTool<
RunOutput extends string | ToolMessage = string
> extends Tool<RunOutput> {
static lc_name() {
return "DynamicTool";
}
Expand All @@ -260,9 +275,9 @@ export class DynamicTool extends Tool {

description: string;

func: DynamicToolInput["func"];
func: DynamicToolInput<RunOutput>["func"];

constructor(fields: DynamicToolInput) {
constructor(fields: DynamicToolInput<RunOutput>) {
super(fields);
this.name = fields.name;
this.description = fields.description;
Expand All @@ -276,7 +291,7 @@ export class DynamicTool extends Tool {
async call(
arg: string | undefined | z.input<this["schema"]>,
configArg?: RunnableConfig | Callbacks
): Promise<string> {
): Promise<RunOutput> {
const config = parseCallbackConfigArg(configArg);
if (config.runName === undefined) {
config.runName = this.name;
Expand All @@ -289,7 +304,7 @@ export class DynamicTool extends Tool {
input: string,
runManager?: CallbackManagerForToolRun,
config?: RunnableConfig
): Promise<string> {
): Promise<RunOutput> {
return this.func(input, runManager, config);
}
}
Expand All @@ -301,8 +316,9 @@ export class DynamicTool extends Tool {
* provided function when the tool is called.
*/
export class DynamicStructuredTool<
T extends ZodAny = ZodAny
> extends StructuredTool<T> {
T extends ZodAny = ZodAny,
RunOutput extends string | ToolMessage = string
> extends StructuredTool<T, RunOutput> {
static lc_name() {
return "DynamicStructuredTool";
}
Expand All @@ -311,11 +327,11 @@ export class DynamicStructuredTool<

description: string;

func: DynamicStructuredToolInput<T>["func"];
func: DynamicStructuredToolInput<T, RunOutput>["func"];

schema: T;

constructor(fields: DynamicStructuredToolInput<T>) {
constructor(fields: DynamicStructuredToolInput<T, RunOutput>) {
super(fields);
this.name = fields.name;
this.description = fields.description;
Expand All @@ -332,7 +348,7 @@ export class DynamicStructuredTool<
configArg?: RunnableConfig | Callbacks,
/** @deprecated */
tags?: string[]
): Promise<string> {
): Promise<RunOutput> {
const config = parseCallbackConfigArg(configArg);
if (config.runName === undefined) {
config.runName = this.name;
Expand All @@ -344,7 +360,7 @@ export class DynamicStructuredTool<
arg: z.output<T>,
runManager?: CallbackManagerForToolRun,
config?: RunnableConfig
): Promise<string> {
): Promise<RunOutput> {
return this.func(arg, runManager, config);
}
}
Expand All @@ -365,6 +381,7 @@ export abstract class BaseToolkit {
/**
* Parameters for the tool function.
* @template {ZodAny} RunInput The input schema for the tool.
* @template {string | ToolMessage} RunOutput The output type for the tool.
*/
interface ToolWrapperParams<RunInput extends ZodAny = ZodAny>
extends ToolParams {
Expand All @@ -390,6 +407,7 @@ interface ToolWrapperParams<RunInput extends ZodAny = ZodAny>
* Creates a new StructuredTool instance with the provided function, name, description, and schema.
* @function
* @template {ZodAny} RunInput The input schema for the tool.
* @template {string | ToolMessage} RunOutput The output type for the tool.
*
* @param {RunnableFunc<RunInput, string>} func - The function to invoke when the tool is called.
* @param fields - An object containing the following properties:
Expand All @@ -399,8 +417,11 @@ interface ToolWrapperParams<RunInput extends ZodAny = ZodAny>
*
* @returns {StructuredTool<RunInput, string>} A new StructuredTool instance.
*/
export function tool<RunInput extends ZodAny = ZodAny>(
func: RunnableFunc<z.infer<RunInput>, string>,
export function tool<
RunInput extends ZodAny = ZodAny,
RunOutput extends string | ToolMessage = string
>(
func: RunnableFunc<z.infer<RunInput>, RunOutput>,
fields: ToolWrapperParams<RunInput>
) {
const schema =
Expand All @@ -409,7 +430,7 @@ export function tool<RunInput extends ZodAny = ZodAny>(

const description =
fields.description ?? schema.description ?? `${fields.name} tool`;
return new DynamicStructuredTool<RunInput>({
return new DynamicStructuredTool<RunInput, RunOutput>({
name: fields.name,
description,
schema: schema as RunInput,
Expand Down

0 comments on commit 6b8e1a1

Please sign in to comment.