From 297a7918f329f23f8c352db049dc820575f6eaaf Mon Sep 17 00:00:00 2001 From: jacoblee93 Date: Wed, 31 Jul 2024 19:06:42 -0700 Subject: [PATCH] Allow tools to be initialized with JSON schema --- langchain-core/src/tools/index.ts | 53 +++++++---- langchain-core/src/tools/tests/tools.test.ts | 99 +++++++++++++++++++- 2 files changed, 132 insertions(+), 20 deletions(-) diff --git a/langchain-core/src/tools/index.ts b/langchain-core/src/tools/index.ts index c9e2c98402ae..49fd8ddc8cfc 100644 --- a/langchain-core/src/tools/index.ts +++ b/langchain-core/src/tools/index.ts @@ -20,6 +20,7 @@ import { ZodObjectAny } from "../types/zod.js"; import { MessageContent } from "../messages/base.js"; import { AsyncLocalStorageProviderSingleton } from "../singletons/index.js"; import { _isToolCall, ToolInputParsingException } from "./utils.js"; +import { isZodSchema } from "../utils/types/is_zod_schema.js"; export { ToolInputParsingException }; @@ -319,16 +320,18 @@ export interface DynamicToolInput extends BaseDynamicToolInput { * Interface for the input parameters of the DynamicStructuredTool class. */ export interface DynamicStructuredToolInput< - T extends ZodObjectAny = ZodObjectAny + T extends ZodObjectAny | Record = ZodObjectAny > extends BaseDynamicToolInput { func: ( input: BaseDynamicToolInput["responseFormat"] extends "content_and_artifact" ? ToolCall - : z.infer, + : T extends ZodObjectAny + ? z.infer + : T, runManager?: CallbackManagerForToolRun, config?: RunnableConfig ) => Promise; - schema: T; + schema: T extends ZodObjectAny ? T : T; } /** @@ -384,8 +387,8 @@ export class DynamicTool extends Tool { * provided function when the tool is called. */ export class DynamicStructuredTool< - T extends ZodObjectAny = ZodObjectAny -> extends StructuredTool { + T extends ZodObjectAny | Record = ZodObjectAny +> extends StructuredTool { static lc_name() { return "DynamicStructuredTool"; } @@ -396,7 +399,7 @@ export class DynamicStructuredTool< func: DynamicStructuredToolInput["func"]; - schema: T; + schema: T extends ZodObjectAny ? T : ZodObjectAny; constructor(fields: DynamicStructuredToolInput) { super(fields); @@ -404,14 +407,16 @@ export class DynamicStructuredTool< this.description = fields.description; this.func = fields.func; this.returnDirect = fields.returnDirect ?? this.returnDirect; - this.schema = fields.schema; + this.schema = ( + isZodSchema(fields.schema) ? fields.schema : z.object({}) + ) as T extends ZodObjectAny ? T : ZodObjectAny; } /** * @deprecated Use .invoke() instead. Will be removed in 0.3.0. */ async call( - arg: z.output | ToolCall, + arg: (T extends ZodObjectAny ? z.output : T) | ToolCall, configArg?: RunnableConfig | Callbacks, /** @deprecated */ tags?: string[] @@ -424,11 +429,11 @@ export class DynamicStructuredTool< } protected _call( - arg: z.output | ToolCall, + arg: (T extends ZodObjectAny ? z.output : T) | ToolCall, runManager?: CallbackManagerForToolRun, parentConfig?: RunnableConfig ): Promise { - return this.func(arg, runManager, parentConfig); + return this.func(arg as any, runManager, parentConfig); } } @@ -447,10 +452,13 @@ export abstract class BaseToolkit { /** * Parameters for the tool function. - * @template {ZodObjectAny | z.ZodString = ZodObjectAny} RunInput The input schema for the tool. Either any Zod object, or a Zod string. + * @template {ZodObjectAny | z.ZodString | Record = ZodObjectAny} RunInput The input schema for the tool. Either any Zod object, a Zod string, or JSON schema. */ interface ToolWrapperParams< - RunInput extends ZodObjectAny | z.ZodString = ZodObjectAny + RunInput extends + | ZodObjectAny + | z.ZodString + | Record = ZodObjectAny > extends ToolParams { /** * The name of the tool. If using with an LLM, this @@ -494,18 +502,25 @@ interface ToolWrapperParams< * * @returns {DynamicStructuredTool} A new StructuredTool instance. */ -export function tool( +export function tool( func: RunnableFunc, ToolReturnType>, fields: ToolWrapperParams ): DynamicTool; -export function tool( +export function tool( func: RunnableFunc, ToolReturnType>, fields: ToolWrapperParams ): DynamicStructuredTool; -export function tool( - func: RunnableFunc, ToolReturnType>, +export function tool>( + func: RunnableFunc, + fields: ToolWrapperParams +): DynamicStructuredTool; + +export function tool< + T extends ZodObjectAny | z.ZodString | Record = ZodObjectAny +>( + func: RunnableFunc : T, ToolReturnType>, fields: ToolWrapperParams ): | DynamicStructuredTool @@ -518,7 +533,7 @@ export function tool( fields.description ?? fields.schema?.description ?? `${fields.name} tool`, - func, + func: func as any, }); } @@ -528,7 +543,7 @@ export function tool( return new DynamicStructuredTool({ ...fields, description, - schema: fields.schema as T extends ZodObjectAny ? T : ZodObjectAny, + schema: fields.schema as any, // TODO: Consider moving into DynamicStructuredTool constructor func: async (input, runManager, config) => { return new Promise((resolve, reject) => { @@ -539,7 +554,7 @@ export function tool( childConfig, async () => { try { - resolve(func(input, childConfig)); + resolve(func(input as any, childConfig)); } catch (e) { reject(e); } diff --git a/langchain-core/src/tools/tests/tools.test.ts b/langchain-core/src/tools/tests/tools.test.ts index bf577a4a1dc9..4c38800b3489 100644 --- a/langchain-core/src/tools/tests/tools.test.ts +++ b/langchain-core/src/tools/tests/tools.test.ts @@ -1,6 +1,6 @@ import { test, expect } from "@jest/globals"; import { z } from "zod"; -import { tool } from "../index.js"; +import { DynamicStructuredTool, tool } from "../index.js"; import { ToolMessage } from "../../messages/tool.js"; test("Tool should error if responseFormat is content_and_artifact but the function doesn't return a tuple", async () => { @@ -115,3 +115,100 @@ test("Tool can accept single string input", async () => { const result = await stringTool.invoke("b"); expect(result).toBe("ba"); }); + +test("Tool declared with JSON schema", async () => { + const weatherSchema = { + type: "object", + properties: { + location: { + type: "string", + description: "A place", + }, + }, + required: ["location"], + }; + const weatherTool = tool( + (_) => { + return "Sunny"; + }, + { + name: "weather", + schema: weatherSchema, + } + ); + + const weatherTool2 = new DynamicStructuredTool({ + name: "weather", + description: "get the weather", + func: async (_) => { + return "Sunny"; + }, + schema: weatherSchema, + }); + // No validation on JSON schema tools + await weatherTool.invoke({ + somethingSilly: true, + }); + await weatherTool2.invoke({ + somethingSilly: true, + }); +}); + +test("Tool input typing is enforced", async () => { + const weatherSchema = z.object({ + location: z.string(), + }); + + const weatherTool = tool( + (_) => { + return "Sunny"; + }, + { + name: "weather", + schema: weatherSchema, + } + ); + + const weatherTool2 = new DynamicStructuredTool({ + name: "weather", + description: "get the weather", + func: async (_) => { + return "Sunny"; + }, + schema: weatherSchema, + }); + + const weatherTool3 = tool( + async (_) => { + return "Sunny"; + }, + { + name: "weather", + description: "get the weather", + schema: z.string(), + } + ); + + await expect(async () => { + await weatherTool.invoke({ + // @ts-expect-error Invalid argument + badval: "someval", + }); + }).rejects.toThrow(); + const res = await weatherTool.invoke({ + location: "somewhere", + }); + expect(res).toEqual("Sunny"); + await expect(async () => { + await weatherTool2.invoke({ + // @ts-expect-error Invalid argument + badval: "someval", + }); + }).rejects.toThrow(); + const res2 = await weatherTool2.invoke({ + location: "someval", + }); + expect(res2).toEqual("Sunny"); + const res3 = await weatherTool3.invoke("blah"); + expect(res3).toEqual("Sunny"); +});