Skip to content

Commit

Permalink
revert obj return type changes
Browse files Browse the repository at this point in the history
  • Loading branch information
bracesproul committed Jun 17, 2024
1 parent 1ef2f5c commit fd9161c
Show file tree
Hide file tree
Showing 7 changed files with 39 additions and 136 deletions.
3 changes: 1 addition & 2 deletions langchain-core/src/callbacks/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -197,8 +197,7 @@ abstract class BaseCallbackHandlerMethodsClass {
* Called at the end of a Tool run, with the tool output and the run ID.
*/
handleToolEnd?(
// eslint-disable-next-line @typescript-eslint/no-explicit-any
output: string | Record<string, any>,
output: string,
runId: string,
parentRunId?: string,
tags?: string[]
Expand Down
3 changes: 1 addition & 2 deletions langchain-core/src/callbacks/manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -488,8 +488,7 @@ export class CallbackManagerForToolRun
);
}

// eslint-disable-next-line @typescript-eslint/no-explicit-any
async handleToolEnd(output: string | Record<string, any>): Promise<void> {
async handleToolEnd(output: string): Promise<void> {
await Promise.all(
this.handlers.map((handler) =>
consumeCallback(async () => {
Expand Down
6 changes: 2 additions & 4 deletions langchain-core/src/language_models/chat_models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -150,11 +150,9 @@ export abstract class BaseChatModel<
*/
bindTools?<
// eslint-disable-next-line @typescript-eslint/no-explicit-any
T extends z.ZodObject<any, any, any, any> = z.ZodObject<any, any, any, any>,
// eslint-disable-next-line @typescript-eslint/no-explicit-any
RunOutput extends string | Record<string, any> = string
T extends z.ZodObject<any, any, any, any> = z.ZodObject<any, any, any, any>
>(
tools: (StructuredToolInterface<T, RunOutput> | Record<string, unknown>)[],
tools: (StructuredToolInterface<T> | Record<string, unknown>)[],
kwargs?: Partial<CallOptions>
): Runnable<BaseLanguageModelInput, OutputMessageType, CallOptions>;

Expand Down
3 changes: 1 addition & 2 deletions langchain-core/src/messages/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,7 @@ export type MessageContentComplex =
// eslint-disable-next-line @typescript-eslint/no-explicit-any
| (Record<string, any> & { type?: "text" | "image_url" | "tool" | string })
// eslint-disable-next-line @typescript-eslint/no-explicit-any
| (Record<string, any> & { type?: never })
| Record<string, unknown>;
| (Record<string, any> & { type?: never });

export type MessageContent = string | MessageContentComplex[];

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1825,7 +1825,7 @@ test("Runnable streamEvents method with simple tools", async () => {

test("Runnable streamEvents method with tools that return objects", async () => {
const adderFunc = (_params: { x: number; y: number }) => {
return { sum: 3 };
return JSON.stringify({ sum: 3 });
};
const parameterlessTool = tool(adderFunc, {
name: "parameterless",
Expand All @@ -1847,9 +1847,7 @@ test("Runnable streamEvents method with tools that return objects", async () =>
},
{
data: {
output: {
sum: 3,
},
output: JSON.stringify({ sum: 3 }),
},
event: "on_tool_end",
metadata: {},
Expand Down Expand Up @@ -1885,7 +1883,7 @@ test("Runnable streamEvents method with tools that return objects", async () =>
tags: [],
},
{
data: { output: { sum: 3 } },
data: { output: JSON.stringify({ sum: 3 }) },
event: "on_tool_end",
metadata: {},
name: "with_parameters",
Expand Down
83 changes: 30 additions & 53 deletions langchain-core/src/tools.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,10 @@ export class ToolInputParsingException extends Error {

export interface StructuredToolInterface<
// eslint-disable-next-line @typescript-eslint/no-explicit-any
T extends ZodAny = ZodAny,
// eslint-disable-next-line @typescript-eslint/no-explicit-any
RunOutput extends string | Record<string, any> = string
T extends ZodAny = ZodAny
> extends RunnableInterface<
(z.output<T> extends string ? string : never) | z.input<T>,
RunOutput
string
> {
lc_namespace: string[];

Expand All @@ -63,7 +61,7 @@ export interface StructuredToolInterface<
configArg?: Callbacks | RunnableConfig,
/** @deprecated */
tags?: string[]
): Promise<RunOutput>;
): Promise<string>;

name: string;

Expand All @@ -77,12 +75,10 @@ export interface StructuredToolInterface<
*/
export abstract class StructuredTool<
// eslint-disable-next-line @typescript-eslint/no-explicit-any
T extends ZodAny = ZodAny,
// eslint-disable-next-line @typescript-eslint/no-explicit-any
RunOutput extends string | Record<string, any> = string
T extends ZodAny = ZodAny
> extends BaseLangChain<
(z.output<T> extends string ? string : never) | z.input<T>,
RunOutput
string
> {
abstract schema: T | z.ZodEffects<T>;

Expand All @@ -98,7 +94,7 @@ export abstract class StructuredTool<
arg: z.output<T>,
runManager?: CallbackManagerForToolRun,
config?: RunnableConfig
): Promise<RunOutput>;
): Promise<string>;

/**
* Invokes the tool with the provided input and configuration.
Expand All @@ -109,7 +105,7 @@ export abstract class StructuredTool<
async invoke(
input: (z.output<T> extends string ? string : never) | z.input<T>,
config?: RunnableConfig
): Promise<RunOutput> {
): Promise<string> {
return this.call(input, ensureConfig(config));
}

Expand All @@ -129,7 +125,7 @@ export abstract class StructuredTool<
configArg?: Callbacks | RunnableConfig,
/** @deprecated */
tags?: string[]
): Promise<RunOutput> {
): Promise<string> {
let parsed;
try {
parsed = await this.schema.parseAsync(arg);
Expand Down Expand Up @@ -196,11 +192,7 @@ export interface ToolInterface extends StructuredToolInterface {
/**
* Base class for Tools that accept input as a string.
*/
export abstract class Tool<
// eslint-disable-next-line @typescript-eslint/no-explicit-any
RunOutput extends string | Record<string, any> = string
// eslint-disable-next-line @typescript-eslint/no-explicit-any
> extends StructuredTool<any, RunOutput> {
export abstract class Tool extends StructuredTool {
schema = z
.object({ input: z.string().optional() })
.transform((obj) => obj.input);
Expand All @@ -221,7 +213,7 @@ export abstract class Tool<
call(
arg: string | undefined | z.input<this["schema"]>,
callbacks?: Callbacks | RunnableConfig
): Promise<RunOutput> {
): Promise<string> {
return super.call(
typeof arg === "string" || !arg ? { input: arg } : arg,
callbacks
Expand All @@ -238,41 +230,33 @@ export interface BaseDynamicToolInput extends ToolParams {
/**
* Interface for the input parameters of the DynamicTool class.
*/
export interface DynamicToolInput<
// eslint-disable-next-line @typescript-eslint/no-explicit-any
RunOutput extends string | Record<string, any> = string
> extends BaseDynamicToolInput {
export interface DynamicToolInput extends BaseDynamicToolInput {
func: (
input: string,
runManager?: CallbackManagerForToolRun,
config?: RunnableConfig
) => Promise<RunOutput>;
) => Promise<string>;
}

/**
* Interface for the input parameters of the DynamicStructuredTool class.
*/
export interface DynamicStructuredToolInput<
// eslint-disable-next-line @typescript-eslint/no-explicit-any
T extends ZodAny = ZodAny,
// eslint-disable-next-line @typescript-eslint/no-explicit-any
RunOutput extends string | Record<string, any> = string
T extends ZodAny = ZodAny
> extends BaseDynamicToolInput {
func: (
input: z.infer<T>,
runManager?: CallbackManagerForToolRun,
config?: RunnableConfig
) => Promise<RunOutput>;
) => Promise<string>;
schema: T;
}

/**
* A tool that can be created dynamically from a function, name, and description.
*/
export class DynamicTool<
// eslint-disable-next-line @typescript-eslint/no-explicit-any
RunOutput extends string | Record<string, any> = string
> extends Tool<RunOutput> {
export class DynamicTool extends Tool {
static lc_name() {
return "DynamicTool";
}
Expand All @@ -281,9 +265,9 @@ export class DynamicTool<

description: string;

func: DynamicToolInput<RunOutput>["func"];
func: DynamicToolInput["func"];

constructor(fields: DynamicToolInput<RunOutput>) {
constructor(fields: DynamicToolInput) {
super(fields);
this.name = fields.name;
this.description = fields.description;
Expand All @@ -297,7 +281,7 @@ export class DynamicTool<
async call(
arg: string | undefined | z.input<this["schema"]>,
configArg?: RunnableConfig | Callbacks
): Promise<RunOutput> {
): Promise<string> {
const config = parseCallbackConfigArg(configArg);
if (config.runName === undefined) {
config.runName = this.name;
Expand All @@ -310,7 +294,7 @@ export class DynamicTool<
input: string,
runManager?: CallbackManagerForToolRun,
config?: RunnableConfig
): Promise<RunOutput> {
): Promise<string> {
return this.func(input, runManager, config);
}
}
Expand All @@ -323,10 +307,8 @@ export class DynamicTool<
*/
export class DynamicStructuredTool<
// eslint-disable-next-line @typescript-eslint/no-explicit-any
T extends ZodAny = ZodAny,
// eslint-disable-next-line @typescript-eslint/no-explicit-any
RunOutput extends string | Record<string, any> = string
> extends StructuredTool<T, RunOutput> {
T extends ZodAny = ZodAny
> extends StructuredTool<T> {
static lc_name() {
return "DynamicStructuredTool";
}
Expand All @@ -335,11 +317,11 @@ export class DynamicStructuredTool<

description: string;

func: DynamicStructuredToolInput<T, RunOutput>["func"];
func: DynamicStructuredToolInput<T>["func"];

schema: T;

constructor(fields: DynamicStructuredToolInput<T, RunOutput>) {
constructor(fields: DynamicStructuredToolInput<T>) {
super(fields);
this.name = fields.name;
this.description = fields.description;
Expand All @@ -356,7 +338,7 @@ export class DynamicStructuredTool<
configArg?: RunnableConfig | Callbacks,
/** @deprecated */
tags?: string[]
): Promise<RunOutput> {
): Promise<string> {
const config = parseCallbackConfigArg(configArg);
if (config.runName === undefined) {
config.runName = this.name;
Expand All @@ -368,7 +350,7 @@ export class DynamicStructuredTool<
arg: z.output<T>,
runManager?: CallbackManagerForToolRun,
config?: RunnableConfig
): Promise<RunOutput> {
): Promise<string> {
return this.func(arg, runManager, config);
}
}
Expand Down Expand Up @@ -414,29 +396,24 @@ interface ToolWrapperParams<RunInput extends ZodAny = ZodAny>
* Creates a new StructuredTool instance with the provided function, name, description, and schema.
* @function
* @template {ZodAny} RunInput The input schema for the tool.
* @template {string | Record<string, any>} RunOutput The output schema for the tool.
*
* @param {RunnableFunc<RunInput, RunOutput>} func - The function to invoke when the tool is called.
* @param {RunnableFunc<RunInput, string>} func - The function to invoke when the tool is called.
* @param fields - An object containing the following properties:
* @param {string} fields.name The name of the tool.
* @param {string | undefined} fields.description The description of the tool. Defaults to `${fields.name} tool`.
* @param {z.ZodObject<any, any, any, any>} fields.schema The Zod schema defining the input for the tool.
*
* @returns {StructuredTool<RunInput, RunOutput>} A new StructuredTool instance.
* @returns {StructuredTool<RunInput, string>} A new StructuredTool instance.
*/
export function tool<
RunInput extends ZodAny = ZodAny,
// eslint-disable-next-line @typescript-eslint/no-explicit-any
RunOutput extends string | Record<string, any> = string
>(
func: RunnableFunc<z.infer<RunInput>, RunOutput>,
export function tool<RunInput extends ZodAny = ZodAny>(
func: RunnableFunc<z.infer<RunInput>, string>,
fields: ToolWrapperParams<RunInput>
) {
const schema =
fields.schema ??
z.object({ input: z.string().optional() }).transform((obj) => obj.input);

return new DynamicStructuredTool<RunInput, RunOutput>({
return new DynamicStructuredTool<RunInput>({
name: fields.name,
description: fields.description ?? `${fields.name} tool`,
schema: schema as RunInput,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { expect } from "@jest/globals";
ync testStructuredFewShotExamplesToolOimport { expect } from "@jest/globals";
import { BaseChatModelCallOptions } from "@langchain/core/language_models/chat_models";
import {
AIMessage,
Expand Down Expand Up @@ -340,73 +340,6 @@ export abstract class ChatModelIntegrationTests<
expect(resultStringContent).toBeInstanceOf(this.invokeResponseType);
}

/**
* Test that model can process few-shot examples with tool calls
* that return objects instead of strings.
* @returns {Promise<void>}
*/
async testStructuredFewShotExamplesToolObjectReturn(
callOptions?: InstanceType<this["Cls"]>["ParsedCallOptions"]
) {
if (!this.chatModelHasToolCalling) {
console.log("Test requires tool calling. Skipping...");
return;
}
const adderFunc = (params: { a: number; b: number }) => ({
sum: params.a + params.b,
});

const model = new this.Cls(this.constructorArgs);
const adderTool = tool(adderFunc, {
name: "AdderTool",
schema: z.object({
a: z.number().int(),
b: z.number().int(),
}),
description: "Add two numbers",
});
if (!model.bindTools) {
throw new Error("bindTools undefined. Cannot test few-shot examples.");
}
const modelWithTools = model.bindTools([adderTool]);
const functionName = adderTool.name;
const functionArgs = { a: 1, b: 2 };

const { functionId } = this;
const functionResult = await adderTool.invoke(functionArgs);

const messagesStringContent = [
new HumanMessage("What is 1 + 2"),
new AIMessage({
content: "",
tool_calls: [
{
name: functionName,
args: functionArgs,
id: functionId,
},
],
}),
new ToolMessage(
{
content: [functionResult],
},
functionId,
functionName
),
new AIMessage({
content: [functionResult],
}),
new HumanMessage("What is 3 + 4"),
];

const resultStringContent = await modelWithTools.invoke(
messagesStringContent,
callOptions
);
expect(resultStringContent).toBeInstanceOf(this.invokeResponseType);
}

async testWithStructuredOutput() {
if (!this.chatModelHasStructuredOutput) {
console.log("Test requires withStructuredOutput. Skipping...");
Expand Down

0 comments on commit fd9161c

Please sign in to comment.