Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(core): Preserve direct tool outputs, pass raw tool call into tools if available #7340

Merged
merged 5 commits into from
Dec 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion langchain-core/src/messages/tool.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,30 @@ export interface ToolMessageFieldsWithToolCallId extends BaseMessageFields {
status?: "success" | "error";
}

/**
* Marker parameter for objects that tools can return directly.
*
* If a custom BaseTool is invoked with a ToolCall and the output of custom code is
* not an instance of DirectToolOutput, the output will automatically be coerced to
* a string and wrapped in a ToolMessage.
*/
export interface DirectToolOutput {
readonly lc_direct_tool_output: boolean;
}

export function isDirectToolOutput(x: unknown): x is DirectToolOutput {
return (
x != null &&
typeof x === "object" &&
"lc_direct_tool_output" in x &&
x.lc_direct_tool_output === true
);
}

/**
* Represents a tool message in a conversation.
*/
export class ToolMessage extends BaseMessage {
export class ToolMessage extends BaseMessage implements DirectToolOutput {
static lc_name() {
return "ToolMessage";
}
Expand All @@ -40,6 +60,8 @@ export class ToolMessage extends BaseMessage {
return { tool_call_id: "tool_call_id" };
}

lc_direct_tool_output = true;

/**
* Status of the tool invocation.
* @version 0.2.19
Expand Down
53 changes: 32 additions & 21 deletions langchain-core/src/tools/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import {
type RunnableConfig,
} from "../runnables/config.js";
import type { RunnableFunc, RunnableInterface } from "../runnables/base.js";
import { ToolCall, ToolMessage } from "../messages/tool.js";
import { isDirectToolOutput, ToolCall, ToolMessage } from "../messages/tool.js";
import { MessageContent } from "../messages/base.js";
import { AsyncLocalStorageProviderSingleton } from "../singletons/index.js";
import { _isToolCall, ToolInputParsingException } from "./utils.js";
Expand Down Expand Up @@ -57,6 +57,11 @@ export interface ToolParams extends BaseLangChainParams {
verboseParsingErrors?: boolean;
}

export type ToolRunnableConfig<
// eslint-disable-next-line @typescript-eslint/no-explicit-any
ConfigurableFieldType extends Record<string, any> = Record<string, any>
> = RunnableConfig<ConfigurableFieldType> & { toolCall?: ToolCall };

/**
* Schema for defining tools.
*
Expand Down Expand Up @@ -159,7 +164,7 @@ export abstract class StructuredTool<
protected abstract _call(
arg: z.output<T>,
runManager?: CallbackManagerForToolRun,
parentConfig?: RunnableConfig
parentConfig?: ToolRunnableConfig
): Promise<ToolReturnType>;

/**
Expand All @@ -182,21 +187,23 @@ export abstract class StructuredTool<
| ToolCall
| undefined;

let enrichedConfig: ToolRunnableConfig = ensureConfig(config);
if (_isToolCall(input)) {
tool_call_id = input.id;
toolInput = input.args;
enrichedConfig = {
...enrichedConfig,
toolCall: input,
configurable: {
...enrichedConfig.configurable,
tool_call_id,
},
};
} else {
toolInput = input;
}

const ensuredConfig = ensureConfig(config);
return this.call(toolInput, {
...ensuredConfig,
configurable: {
...ensuredConfig.configurable,
tool_call_id,
},
});
return this.call(toolInput, enrichedConfig);
}

/**
Expand All @@ -211,8 +218,8 @@ export abstract class StructuredTool<
* @returns A Promise that resolves with a string.
*/
async call(
arg: (z.output<T> extends string ? string : never) | z.input<T> | ToolCall,
configArg?: Callbacks | RunnableConfig,
arg: (z.output<T> extends string ? string : never) | z.input<T>,
configArg?: Callbacks | ToolRunnableConfig,
/** @deprecated */
tags?: string[]
): Promise<ToolReturnType> {
Expand All @@ -229,7 +236,7 @@ export abstract class StructuredTool<
}

const config = parseCallbackConfigArg(configArg);
const callbackManager_ = await CallbackManager.configure(
const callbackManager_ = CallbackManager.configure(
config.callbacks,
this.callbacks,
config.tags || tags,
Expand Down Expand Up @@ -350,7 +357,7 @@ export interface DynamicToolInput extends BaseDynamicToolInput {
func: (
input: string,
runManager?: CallbackManagerForToolRun,
config?: RunnableConfig
config?: ToolRunnableConfig
) => Promise<ToolReturnType>;
}

Expand Down Expand Up @@ -400,7 +407,7 @@ export class DynamicTool extends Tool {
*/
async call(
arg: string | undefined | z.input<this["schema"]> | ToolCall,
configArg?: RunnableConfig | Callbacks
configArg?: ToolRunnableConfig | Callbacks
): Promise<ToolReturnType> {
const config = parseCallbackConfigArg(configArg);
if (config.runName === undefined) {
Expand All @@ -413,7 +420,7 @@ export class DynamicTool extends Tool {
async _call(
input: string,
runManager?: CallbackManagerForToolRun,
parentConfig?: RunnableConfig
parentConfig?: ToolRunnableConfig
): Promise<ToolReturnType> {
return this.func(input, runManager, parentConfig);
}
Expand Down Expand Up @@ -553,26 +560,30 @@ interface ToolWrapperParams<
* @returns {DynamicStructuredTool<T>} A new StructuredTool instance.
*/
export function tool<T extends z.ZodString>(
func: RunnableFunc<z.output<T>, ToolReturnType>,
func: RunnableFunc<z.output<T>, ToolReturnType, ToolRunnableConfig>,
fields: ToolWrapperParams<T>
): DynamicTool;

export function tool<T extends ZodObjectAny>(
func: RunnableFunc<z.output<T>, ToolReturnType>,
func: RunnableFunc<z.output<T>, ToolReturnType, ToolRunnableConfig>,
fields: ToolWrapperParams<T>
): DynamicStructuredTool<T>;

// eslint-disable-next-line @typescript-eslint/no-explicit-any
export function tool<T extends Record<string, any>>(
func: RunnableFunc<T, ToolReturnType>,
func: RunnableFunc<T, ToolReturnType, ToolRunnableConfig>,
fields: ToolWrapperParams<T>
): DynamicStructuredTool<T>;

export function tool<
// eslint-disable-next-line @typescript-eslint/no-explicit-any
T extends ZodObjectAny | z.ZodString | Record<string, any> = ZodObjectAny
>(
func: RunnableFunc<T extends ZodObjectAny ? z.output<T> : T, ToolReturnType>,
func: RunnableFunc<
T extends ZodObjectAny ? z.output<T> : T,
ToolReturnType,
ToolRunnableConfig
>,
fields: ToolWrapperParams<T>
):
| DynamicStructuredTool<T extends ZodObjectAny ? T : ZodObjectAny>
Expand Down Expand Up @@ -649,7 +660,7 @@ function _formatToolOutput(params: {
toolCallId?: string;
}): ToolReturnType {
const { content, artifact, toolCallId } = params;
if (toolCallId) {
if (toolCallId && !isDirectToolOutput(content)) {
if (
typeof content === "string" ||
(Array.isArray(content) &&
Expand Down
86 changes: 73 additions & 13 deletions langchain-core/src/tools/tests/tools.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,17 @@ import { z } from "zod";

import { DynamicStructuredTool, tool } from "../index.js";
import { ToolMessage } from "../../messages/tool.js";
import { RunnableConfig } from "../../runnables/types.js";

test("Tool should error if responseFormat is content_and_artifact but the function doesn't return a tuple", async () => {
const weatherSchema = z.object({
location: z.string(),
});

const weatherTool = tool(
(_) => {
// Should be able to type this as base RunnableConfig without issue,
// though true type is more specific
(_, _config: RunnableConfig) => {
return "str";
},
{
Expand Down Expand Up @@ -51,9 +54,15 @@ test("Does not return tool message if responseFormat is content_and_artifact and
const weatherSchema = z.object({
location: z.string(),
});
const toolCall = {
args: { location: "San Francisco" },
name: "weather",
type: "tool_call",
} as const;

const weatherTool = tool(
(input) => {
(input, config) => {
expect(config.toolCall).toEqual(toolCall);
return ["msg_content", input];
},
{
Expand All @@ -63,11 +72,7 @@ test("Does not return tool message if responseFormat is content_and_artifact and
}
);

const toolResult = await weatherTool.invoke({
args: { location: "San Francisco" },
name: "weather",
type: "tool_call",
});
const toolResult = await weatherTool.invoke(toolCall);

expect(toolResult).toBe("msg_content");
});
Expand All @@ -77,8 +82,16 @@ test("Returns tool message if responseFormat is content_and_artifact and returns
location: z.string(),
});

const toolCall = {
id: "testid",
args: { location: "San Francisco" },
name: "weather",
type: "tool_call",
} as const;

const weatherTool = tool(
(input) => {
(input, config) => {
expect(config.toolCall).toEqual(toolCall);
return ["msg_content", input];
},
{
Expand All @@ -88,23 +101,63 @@ test("Returns tool message if responseFormat is content_and_artifact and returns
}
);

const toolResult = await weatherTool.invoke({
const toolResult = await weatherTool.invoke(toolCall);

expect(toolResult).toBeInstanceOf(ToolMessage);
expect(toolResult.content).toBe("msg_content");
expect(toolResult.artifact).toEqual({ location: "San Francisco" });
expect(toolResult.name).toBe("weather");
});

test("Does not double wrap a returned tool message even if a tool call with id is passed in", async () => {
const weatherSchema = z.object({
location: z.string(),
});

const toolCall = {
id: "testid",
args: { location: "San Francisco" },
name: "weather",
type: "tool_call",
});
} as const;

const weatherTool = tool(
(_, config) => {
expect(config.toolCall).toEqual(toolCall);
return new ToolMessage({
tool_call_id: "not_original",
content: "bar",
name: "baz",
});
},
{
name: "weather",
schema: weatherSchema,
}
);

const toolResult = await weatherTool.invoke(toolCall);

expect(toolResult).toBeInstanceOf(ToolMessage);
expect(toolResult.content).toBe("msg_content");
expect(toolResult.artifact).toEqual({ location: "San Francisco" });
expect(toolResult.name).toBe("weather");
expect(toolResult.tool_call_id).toBe("not_original");
expect(toolResult.content).toBe("bar");
expect(toolResult.name).toBe("baz");
});

test("Tool can accept single string input", async () => {
const toolCall = {
id: "testid",
args: { input: "b" },
name: "string_tool",
type: "tool_call",
} as const;

const stringTool = tool<z.ZodString>(
(input: string, config): string => {
expect(config).toMatchObject({ configurable: { foo: "bar" } });
if (config.configurable.usesToolCall) {
expect(config.toolCall).toEqual(toolCall);
}
return `${input}a`;
},
{
Expand All @@ -116,6 +169,13 @@ test("Tool can accept single string input", async () => {

const result = await stringTool.invoke("b", { configurable: { foo: "bar" } });
expect(result).toBe("ba");

const result2 = await stringTool.invoke(toolCall, {
configurable: { foo: "bar", usesToolCall: true },
});
expect(result2).toBeInstanceOf(ToolMessage);
expect(result2.content).toBe("ba");
expect(result2.name).toBe("string_tool");
});

test("Tool declared with JSON schema", async () => {
Expand Down
Loading