Skip to content

Commit

Permalink
google-common [patch], google-* [tests]: Fix streaming false (#5571)
Browse files Browse the repository at this point in the history
* Fix for issue #5475 - no tools_call in message.
Natively support non-streaming mode.

* Test for structuredOutput (see #5475).
Fixes for other tests.

* Fixes for tests.

* Formatting.
  • Loading branch information
afirstenberg authored May 28, 2024
1 parent c3ce775 commit 4b4f611
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 14 deletions.
3 changes: 1 addition & 2 deletions libs/langchain-google-common/src/connection.ts
Original file line number Diff line number Diff line change
Expand Up @@ -254,8 +254,7 @@ export abstract class AbstractGoogleLLMConnection<
AuthOptions
> {
async buildUrlMethodGemini(): Promise<string> {
// Vertex AI only handles streamedGenerateContent
return "streamGenerateContent";
return this.streaming ? "streamGenerateContent" : "generateContent";
}

async buildUrlMethod(): Promise<string> {
Expand Down
18 changes: 14 additions & 4 deletions libs/langchain-google-gauth/src/tests/chat_models.int.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import {
BaseMessageChunk,
BaseMessageLike,
HumanMessage,
MessageContentComplex,
// MessageContentComplex,
SystemMessage,
ToolMessage,
} from "@langchain/core/messages";
Expand All @@ -31,12 +31,22 @@ describe("GAuth Chat", () => {

const aiMessage = res as AIMessageChunk;
expect(aiMessage.content).toBeDefined();

expect(typeof aiMessage.content).toBe("string");
const text = aiMessage.content as string;
expect(text).toMatch(/(1 + 1 (equals|is|=) )?2.? ?/);

/*
expect(aiMessage.content.length).toBeGreaterThan(0);
expect(aiMessage.content[0]).toBeDefined();

const content = aiMessage.content[0] as MessageContentComplex;
expect(typeof content).toBe("string");
expect(content).toBe("2");
expect(content).toHaveProperty("type");
expect(content.type).toEqual("text");
const textContent = content as MessageContentText;
expect(textContent.text).toBeDefined();
expect(textContent.text).toEqual("2");
*/
} catch (e) {
console.error(e);
throw e;
Expand Down
48 changes: 45 additions & 3 deletions libs/langchain-google-vertexai/src/tests/chat_models.int.test.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import { test } from "@jest/globals";
// eslint-disable-next-line import/no-extraneous-dependencies
import { z } from "zod";
import { BaseLanguageModelInput } from "@langchain/core/language_models/base";
import { ChatPromptValue } from "@langchain/core/prompt_values";
import {
Expand All @@ -7,10 +9,11 @@ import {
BaseMessage,
BaseMessageChunk,
HumanMessage,
MessageContentComplex,
MessageContentText,
// MessageContentComplex,
// MessageContentText,
SystemMessage,
} from "@langchain/core/messages";
import { ConsoleCallbackHandler } from "@langchain/core/tracers/console";
import { ChatVertexAI } from "../chat_models.js";
import { VertexAI } from "../llms.js";

Expand All @@ -29,16 +32,22 @@ describe("GAuth Chat", () => {

const aiMessage = res as AIMessageChunk;
expect(aiMessage.content).toBeDefined();

expect(typeof aiMessage.content).toBe("string");
const text = aiMessage.content as string;
expect(text).toMatch(/(1 + 1 (equals|is|=) )?2.? ?/);

/*
expect(aiMessage.content.length).toBeGreaterThan(0);
expect(aiMessage.content[0]).toBeDefined();

const content = aiMessage.content[0] as MessageContentComplex;
expect(content).toHaveProperty("type");
expect(content.type).toEqual("text");
const textContent = content as MessageContentText;
expect(textContent.text).toBeDefined();
expect(textContent.text).toEqual("2");
*/
} catch (e) {
console.error(e);
throw e;
Expand All @@ -62,6 +71,12 @@ describe("GAuth Chat", () => {

const aiMessage = res as AIMessageChunk;
expect(aiMessage.content).toBeDefined();

expect(typeof aiMessage.content).toBe("string");
const text = aiMessage.content as string;
expect(["H", "T"]).toContainEqual(text);

/*
expect(aiMessage.content.length).toBeGreaterThan(0);
expect(aiMessage.content[0]).toBeDefined();
Expand All @@ -72,6 +87,7 @@ describe("GAuth Chat", () => {
const textContent = content as MessageContentText;
expect(textContent.text).toBeDefined();
expect(["H", "T"]).toContainEqual(textContent.text);
*/
} catch (e) {
console.error(e);
throw e;
Expand Down Expand Up @@ -109,4 +125,30 @@ describe("GAuth Chat", () => {
throw e;
}
});

test("structuredOutput", async () => {
const handler = new ConsoleCallbackHandler();

const calculatorSchema = z.object({
operation: z
.enum(["add", "subtract", "multiply", "divide"])
.describe("The type of operation to execute"),
number1: z.number().describe("The first number to operate on."),
number2: z.number().describe("The second number to operate on."),
});

const model = new ChatVertexAI({
temperature: 0.7,
model: "gemini-1.0-pro",
callbacks: [handler],
}).withStructuredOutput(calculatorSchema);

const response = await model.invoke("What is 1628253239 times 81623836?");
expect(response).toHaveProperty("operation");
expect(response.operation).toEqual("multiply");
expect(response).toHaveProperty("number1");
expect(response.number1).toEqual(1628253239);
expect(response).toHaveProperty("number2");
expect(response.number2).toEqual(81623836);
});
});
21 changes: 16 additions & 5 deletions libs/langchain-google-webauth/src/tests/chat_models.int.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ import {
BaseMessageChunk,
BaseMessageLike,
HumanMessage,
MessageContentComplex,
MessageContentText,
// MessageContentComplex,
// MessageContentText,
SystemMessage,
ToolMessage,
} from "@langchain/core/messages";
Expand Down Expand Up @@ -253,16 +253,22 @@ describe("Google Webauth Chat", () => {

const aiMessage = res as AIMessageChunk;
expect(aiMessage.content).toBeDefined();

expect(typeof aiMessage.content).toBe("string");
const text = aiMessage.content as string;
expect(text).toMatch(/(1 + 1 (equals|is|=) )?2.? ?/);

/*
expect(aiMessage.content.length).toBeGreaterThan(0);
expect(aiMessage.content[0]).toBeDefined();

const content = aiMessage.content[0] as MessageContentComplex;
expect(content).toHaveProperty("type");
expect(content.type).toEqual("text");
const textContent = content as MessageContentText;
expect(textContent.text).toBeDefined();
expect(textContent.text).toEqual("2");
*/
} catch (e) {
console.error(e);
throw e;
Expand All @@ -286,6 +292,12 @@ describe("Google Webauth Chat", () => {

const aiMessage = res as AIMessageChunk;
expect(aiMessage.content).toBeDefined();

expect(typeof aiMessage.content).toBe("string");
const text = aiMessage.content as string;
expect(["H", "T"]).toContainEqual(text);

/*
expect(aiMessage.content.length).toBeGreaterThan(0);
expect(aiMessage.content[0]).toBeDefined();
Expand All @@ -296,6 +308,7 @@ describe("Google Webauth Chat", () => {
const textContent = content as MessageContentText;
expect(textContent.text).toBeDefined();
expect(["H", "T"]).toContainEqual(textContent.text);
*/
} catch (e) {
console.error(e);
throw e;
Expand Down Expand Up @@ -361,8 +374,6 @@ describe("Google Webauth Chat", () => {
});
const result = await model.invoke("Run a test on the cobalt project");
expect(result).toHaveProperty("content");
expect(Array.isArray(result.content)).toBeTruthy();
expect(result.content).toHaveLength(0);
const args = result?.lc_kwargs?.additional_kwargs;
expect(args).toBeDefined();
expect(args).toHaveProperty("tool_calls");
Expand Down

0 comments on commit 4b4f611

Please sign in to comment.