Skip to content

Commit

Permalink
rm anthropic test, implement standard tests
Browse files Browse the repository at this point in the history
  • Loading branch information
bracesproul committed Jul 24, 2024
1 parent dc574f6 commit a5f41a6
Show file tree
Hide file tree
Showing 4 changed files with 197 additions and 66 deletions.
50 changes: 27 additions & 23 deletions libs/langchain-anthropic/src/chat_models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -473,37 +473,40 @@ function _formatContent(content: MessageContent) {
type: "image" as const, // Explicitly setting the type as "image"
source,
};
} else if (textTypes.find((t) => t === contentPart.type) && "text" in contentPart) {
} else if (
textTypes.find((t) => t === contentPart.type) &&
"text" in contentPart
) {
// Assuming contentPart is of type MessageContentText here
return {
type: "text" as const, // Explicitly setting the type as "text"
text: contentPart.text,
};
} else if (
toolTypes.find((t) => t === contentPart.type)
) {
if ("index" in contentPart) {
// Anthropic does not support passing the index field here, so we remove it
delete contentPart.index;
} else if (toolTypes.find((t) => t === contentPart.type)) {
const contentPartCopy = { ...contentPart };
if ("index" in contentPartCopy) {
// Anthropic does not support passing the index field here, so we remove it.
delete contentPartCopy.index;
}

if (contentPart.type === "input_json_delta") {
// If type is `input_json_delta`, rename to `tool_use` for Anthropic
contentPart.type = "tool_use";

if (contentPartCopy.type === "input_json_delta") {
// `input_json_delta` type only represents yielding partial tool inputs
// and is not a valid type for Anthropic messages.
contentPartCopy.type = "tool_use";
}

if ("input" in contentPart) {
// If the input is a JSON string, attempt to parse it
if ("input" in contentPartCopy) {
// Anthropic tool use inputs should be valid objects, when applicable.
try {
contentPart.input = JSON.parse(contentPart.input);
contentPartCopy.input = JSON.parse(contentPartCopy.input);
} catch {
// no-op
}
}

// TODO: Fix when SDK types are fixed
return {
...contentPart,
...contentPartCopy,
// eslint-disable-next-line @typescript-eslint/no-explicit-any
} as any;
} else {
Expand Down Expand Up @@ -636,13 +639,15 @@ function extractToolCallChunk(
if (typeof inputJsonDeltaChunks.input === "string") {
newToolCallChunk = {
id: inputJsonDeltaChunks.id,
name: inputJsonDeltaChunks.name,
args: inputJsonDeltaChunks.input,
index: inputJsonDeltaChunks.index,
type: "tool_call_chunk",
};
} else {
newToolCallChunk = {
id: inputJsonDeltaChunks.id,
name: inputJsonDeltaChunks.name,
args: JSON.stringify(inputJsonDeltaChunks.input, null, 2),
index: inputJsonDeltaChunks.index,
type: "tool_call_chunk",
Expand Down Expand Up @@ -994,10 +999,12 @@ export class ChatAnthropicMessages<
streamUsage: !!(this.streamUsage || options.streamUsage),
coerceContentToString,
usageData,
toolUse: toolUse ? {
id: toolUse.id,
name: toolUse.name,
} : undefined,
toolUse: toolUse
? {
id: toolUse.id,
name: toolUse.name,
}
: undefined,
});
if (!result) continue;

Expand Down Expand Up @@ -1092,14 +1099,11 @@ export class ChatAnthropicMessages<
},
}
: requestOptions;
const formattedMsgs = _formatMessagesForAnthropic(messages);
console.log("formattedMsgs");
console.dir(formattedMsgs, { depth: null });
const response = await this.completionWithRetry(
{
...params,
stream: false,
...formattedMsgs,
..._formatMessagesForAnthropic(messages),
},
options
);
Expand Down
41 changes: 0 additions & 41 deletions libs/langchain-anthropic/src/tests/chat_models-tools.int.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -440,44 +440,3 @@ test("llm token callbacks can handle tool calls", async () => {
if (!args) return;
expect(args).toEqual(JSON.parse(tokens));
});

test.only("Anthropic can stream tool calls, and invoke again with that tool call", async () => {
const input = [
new HumanMessage("What is the weather in SF?"),
];

const weatherTool = tool(
(_) => "The weather in San Francisco is 25°C",
{
name: "get_weather",
description: zodSchema.description,
schema: zodSchema,
}
);

const modelWithTools = model.bindTools([weatherTool]);

const stream = await modelWithTools.stream(input);

let finalChunk: AIMessageChunk | undefined;
for await (const chunk of stream) {
finalChunk = !finalChunk ? chunk : concat(finalChunk, chunk);
}
if (!finalChunk) {
throw new Error("chunk not defined");
}
// Push the AI message with the tool call to the input array.
input.push(finalChunk);
// Push a ToolMessage to the input array to represent the tool call response.
input.push(
new ToolMessage({
tool_call_id: finalChunk.tool_calls?.[0].id ?? "",
content:
"The weather in San Francisco is currently 25 degrees and sunny.",
name: "get_weather",
})
);
// Invoke again to ensure Anthropic can handle it's own tool call.
const finalResult = await modelWithTools.invoke(input);
console.dir(finalResult, { depth: null });
});
2 changes: 1 addition & 1 deletion libs/langchain-aws/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -97,4 +97,4 @@
"index.d.ts",
"index.d.cts"
]
}
}
170 changes: 169 additions & 1 deletion libs/langchain-standard-tests/src/integration_tests/chat_models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@ import {
getBufferString,
} from "@langchain/core/messages";
import { z } from "zod";
import { StructuredTool } from "@langchain/core/tools";
import { StructuredTool, tool } from "@langchain/core/tools";
import { zodToJsonSchema } from "zod-to-json-schema";
import { ChatPromptTemplate } from "@langchain/core/prompts";
import { RunnableLambda } from "@langchain/core/runnables";
import { concat } from "@langchain/core/utils/stream";
import {
BaseChatModelsTests,
BaseChatModelsTestsFields,
Expand Down Expand Up @@ -522,6 +523,159 @@ export abstract class ChatModelIntegrationTests<
expect(cacheValue2).toEqual(cacheValue);
}

/**
* This test verifies models can invoke a tool, and use the AIMessage
* with the tool call in a followup request. This is useful when building
* agents, or other pipelines that invoke tools.
*/
async testModelCanUseToolUseAIMessage() {
  if (!this.chatModelHasToolCalling) {
    console.log("Test requires tool calling. Skipping...");
    return;
  }

  const chatModel = new this.Cls(this.constructorArgs);
  if (!chatModel.bindTools) {
    throw new Error(
      "bindTools undefined. Cannot test OpenAI formatted tool calls."
    );
  }

  const getWeatherSchema = z.object({
    location: z.string().describe("The location to get the weather for."),
  });

  // A stub weather tool whose output is fed back to the model in a
  // followup request, simulating a single agent step.
  const getWeatherTool = tool(
    (_) => "The weather in San Francisco is 70 degrees and sunny.",
    {
      name: "get_current_weather",
      schema: getWeatherSchema,
      description: "Get the current weather for a location.",
    }
  );

  const boundModel = chatModel.bindTools([getWeatherTool]);

  // Conversation history: starts with the user question, then grows with
  // the AI tool call message and the tool's response.
  const history = [
    new HumanMessage(
      "What's the weather like in San Francisco right now? Use the 'get_current_weather' tool to find the answer."
    ),
  ];

  const aiMessage: AIMessage = await boundModel.invoke(history);

  expect(aiMessage.tool_calls?.[0]).toBeDefined();
  if (!aiMessage.tool_calls?.[0]) {
    throw new Error("result.tool_calls is undefined");
  }
  const firstToolCall = aiMessage.tool_calls[0];
  expect(firstToolCall.name).toBe("get_current_weather");

  // Append the tool-calling AI message, then a ToolMessage carrying the
  // actual tool output, so the followup request contains the full exchange.
  history.push(aiMessage);
  history.push(
    new ToolMessage({
      tool_call_id: firstToolCall.id ?? "",
      name: firstToolCall.name,
      content: await getWeatherTool.invoke(
        firstToolCall.args as z.infer<typeof getWeatherSchema>
      ),
    })
  );

  const followup = await boundModel.invoke(history);

  // The model should produce a non-empty answer from the tool output.
  expect(followup.content).not.toBe("");
}

/**
* Same as the above test, but streaming both model invocations.
*/
async testModelCanUseToolUseAIMessageWithStreaming() {
  if (!this.chatModelHasToolCalling) {
    console.log("Test requires tool calling. Skipping...");
    return;
  }

  const chatModel = new this.Cls(this.constructorArgs);
  if (!chatModel.bindTools) {
    throw new Error(
      "bindTools undefined. Cannot test OpenAI formatted tool calls."
    );
  }

  const getWeatherSchema = z.object({
    location: z.string().describe("The location to get the weather for."),
  });

  // A stub weather tool whose output is streamed back to the model in a
  // followup request, simulating a single agent step over streaming.
  const getWeatherTool = tool(
    (_) => "The weather in San Francisco is 70 degrees and sunny.",
    {
      name: "get_current_weather",
      schema: getWeatherSchema,
      description: "Get the current weather for a location.",
    }
  );

  const boundModel = chatModel.bindTools([getWeatherTool]);

  // Conversation history: starts with the user question, then grows with
  // the aggregated AI tool call chunk and the tool's response.
  const history = [
    new HumanMessage(
      "What's the weather like in San Francisco right now? Use the 'get_current_weather' tool to find the answer."
    ),
  ];

  // Aggregate the first streamed response into a single chunk.
  let aggregated: AIMessageChunk | undefined;
  for await (const piece of await boundModel.stream(history)) {
    aggregated = aggregated === undefined ? piece : concat(aggregated, piece);
  }

  expect(aggregated).toBeDefined();
  if (!aggregated) return;

  expect(aggregated.tool_calls?.[0]).toBeDefined();
  if (!aggregated.tool_calls?.[0]) {
    throw new Error("result.tool_calls is undefined");
  }

  const firstToolCall = aggregated.tool_calls[0];
  expect(firstToolCall.name).toBe("get_current_weather");

  // Append the tool-calling AI chunk, then a ToolMessage carrying the
  // actual tool output, so the followup request contains the full exchange.
  history.push(aggregated);
  history.push(
    new ToolMessage({
      tool_call_id: firstToolCall.id ?? "",
      name: firstToolCall.name,
      content: await getWeatherTool.invoke(
        firstToolCall.args as z.infer<typeof getWeatherSchema>
      ),
    })
  );

  // Stream the followup as well, aggregating chunks the same way.
  let followup: AIMessageChunk | undefined;
  for await (const piece of await boundModel.stream(history)) {
    followup = followup === undefined ? piece : concat(followup, piece);
  }

  expect(followup).toBeDefined();
  if (!followup) return;

  // The model should produce a non-empty answer from the tool output.
  expect(followup.content).not.toBe("");
}

/**
* Run all unit tests for the chat model.
* Each test is wrapped in a try/catch block to prevent the entire test suite from failing.
Expand Down Expand Up @@ -629,6 +783,20 @@ export abstract class ChatModelIntegrationTests<
console.error("testCacheComplexMessageTypes failed", e);
}

try {
await this.testModelCanUseToolUseAIMessage();
} catch (e: any) {
allTestsPassed = false;
console.error("testModelCanUseToolUseAIMessage failed", e);
}

try {
await this.testModelCanUseToolUseAIMessageWithStreaming();
} catch (e: any) {
allTestsPassed = false;
console.error("testModelCanUseToolUseAIMessageWithStreaming failed", e);
}

return allTestsPassed;
}
}

0 comments on commit a5f41a6

Please sign in to comment.