Skip to content

Commit

Permalink
Merge branch 'main' into brace/tools-return-tool-message
Browse files Browse the repository at this point in the history
  • Loading branch information
bracesproul authored Jul 8, 2024
2 parents 6b8e1a1 + e6616a5 commit eca01c1
Showing 1 changed file with 37 additions and 17 deletions.
54 changes: 37 additions & 17 deletions libs/langchain-standard-tests/src/integration_tests/chat_models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import {
import { z } from "zod";
import { StructuredTool } from "@langchain/core/tools";
import { zodToJsonSchema } from "zod-to-json-schema";
import { ChatPromptTemplate } from "@langchain/core/prompts";
import {
BaseChatModelsTests,
BaseChatModelsTestsFields,
Expand All @@ -37,6 +38,14 @@ class AdderTool extends StructuredTool {
}
}

const MATH_ADDITION_PROMPT = /* #__PURE__ */ ChatPromptTemplate.fromMessages([
[
"system",
"You are bad at math and must ALWAYS call the {toolName} function.",
],
["human", "What is the sum of 1836281973 and 19973286?"],
]);

interface ChatModelIntegrationTestsFields<
CallOptions extends BaseChatModelCallOptions = BaseChatModelCallOptions,
OutputMessageType extends BaseMessageChunk = BaseMessageChunk,
Expand Down Expand Up @@ -228,11 +237,11 @@ export abstract class ChatModelIntegrationTests<
new ToolMessage(functionResult, functionId, functionName),
];

const resultStringContent = await modelWithTools.invoke(
const result = await modelWithTools.invoke(
messagesStringContent,
callOptions
);
expect(resultStringContent).toBeInstanceOf(this.invokeResponseType);
expect(result).toBeInstanceOf(this.invokeResponseType);
}

/**
Expand Down Expand Up @@ -334,11 +343,11 @@ export abstract class ChatModelIntegrationTests<
new HumanMessage("What is 3 + 4"),
];

const resultStringContent = await modelWithTools.invoke(
const result = await modelWithTools.invoke(
messagesStringContent,
callOptions
);
expect(resultStringContent).toBeInstanceOf(this.invokeResponseType);
expect(result).toBeInstanceOf(this.invokeResponseType);
}

async testWithStructuredOutput() {
Expand All @@ -353,13 +362,17 @@ export abstract class ChatModelIntegrationTests<
"withStructuredOutput undefined. Cannot test tool message histories."
);
}
const modelWithTools = model.withStructuredOutput(adderSchema);
const modelWithTools = model.withStructuredOutput(adderSchema, {
name: "math_addition",
});

const resultStringContent = await modelWithTools.invoke("What is 1 + 2");
expect(resultStringContent.a).toBeDefined();
expect([1, 2].includes(resultStringContent.a)).toBeTruthy();
expect(resultStringContent.b).toBeDefined();
expect([1, 2].includes(resultStringContent.b)).toBeTruthy();
const result = await MATH_ADDITION_PROMPT.pipe(modelWithTools).invoke({
toolName: "math_addition",
});
expect(result.a).toBeDefined();
expect(typeof result.a).toBe("number");
expect(result.b).toBeDefined();
expect(typeof result.b).toBe("number");
}

async testWithStructuredOutputIncludeRaw() {
Expand All @@ -376,14 +389,17 @@ export abstract class ChatModelIntegrationTests<
}
const modelWithTools = model.withStructuredOutput(adderSchema, {
includeRaw: true,
name: "math_addition",
});

const resultStringContent = await modelWithTools.invoke("What is 1 + 2");
expect(resultStringContent.raw).toBeInstanceOf(this.invokeResponseType);
expect(resultStringContent.parsed.a).toBeDefined();
expect([1, 2].includes(resultStringContent.parsed.a)).toBeTruthy();
expect(resultStringContent.parsed.b).toBeDefined();
expect([1, 2].includes(resultStringContent.parsed.b)).toBeTruthy();
const result = await MATH_ADDITION_PROMPT.pipe(modelWithTools).invoke({
toolName: "math_addition",
});
expect(result.raw).toBeInstanceOf(this.invokeResponseType);
expect(result.parsed.a).toBeDefined();
expect(typeof result.parsed.a).toBe("number");
expect(result.parsed.b).toBeDefined();
expect(typeof result.parsed.b).toBe("number");
}

async testBindToolsWithOpenAIFormattedTools() {
Expand All @@ -409,7 +425,11 @@ export abstract class ChatModelIntegrationTests<
},
]);

const result: AIMessage = await modelWithTools.invoke("What is 1 + 2");
const result: AIMessage = await MATH_ADDITION_PROMPT.pipe(
modelWithTools
).invoke({
toolName: "math_addition",
});
expect(result.tool_calls).toHaveLength(1);
if (!result.tool_calls) {
throw new Error("result.tool_calls is undefined");
Expand Down

0 comments on commit eca01c1

Please sign in to comment.