
Commit

Merge branch 'main' into brace/tool-runtime-doc
bracesproul authored Jun 15, 2024
2 parents 8dce5c3 + a2ded31 commit 3e72f3f
Showing 9 changed files with 105 additions and 45 deletions.
2 changes: 1 addition & 1 deletion libs/langchain-google-common/package.json
@@ -40,7 +40,7 @@
"author": "LangChain",
"license": "MIT",
"dependencies": {
"@langchain/core": ">0.1.56 <0.3.0",
"@langchain/core": ">=0.2.5 <0.3.0",
"uuid": "^9.0.0",
"zod-to-json-schema": "^3.22.4"
},
24 changes: 20 additions & 4 deletions libs/langchain-google-common/src/chat_models.ts
@@ -1,5 +1,5 @@
import { getEnvironmentVariable } from "@langchain/core/utils/env";
- import { type BaseMessage } from "@langchain/core/messages";
+ import { UsageMetadata, type BaseMessage } from "@langchain/core/messages";
import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager";

import {
@@ -150,7 +150,8 @@ export interface ChatGoogleBaseInput<AuthOptions>
extends BaseChatModelParams,
GoogleConnectionParams<AuthOptions>,
GoogleAIModelParams,
- GoogleAISafetyParams {}
+ GoogleAISafetyParams,
+ Pick<GoogleAIBaseLanguageModelCallOptions, "streamUsage"> {}

function convertToGeminiTools(
structuredTools: (StructuredToolInterface | Record<string, unknown>)[]
@@ -216,6 +217,8 @@ export abstract class ChatGoogleBase<AuthOptions>

safetyHandler: GoogleAISafetyHandler;

+ streamUsage = true;

protected connection: ChatConnection<AuthOptions>;

protected streamedConnection: ChatConnection<AuthOptions>;
@@ -226,7 +229,7 @@
copyAndValidateModelParamsInto(fields, this);
this.safetyHandler =
fields?.safetyHandler ?? new DefaultGeminiSafetyHandler();

+ this.streamUsage = fields?.streamUsage ?? this.streamUsage;
const client = this.buildClient(fields);
this.buildConnection(fields ?? {}, client);
}
@@ -342,12 +345,24 @@

// Get the streaming parser of the response
const stream = response.data as JsonStream;

+ let usageMetadata: UsageMetadata | undefined;
// Loop until the end of the stream
// During the loop, yield each time we get a chunk from the streaming parser
// that is either available or added to the queue
while (!stream.streamDone) {
const output = await stream.nextChunk();
+ if (
+   output &&
+   output.usageMetadata &&
+   this.streamUsage !== false &&
+   options.streamUsage !== false
+ ) {
+   usageMetadata = {
+     input_tokens: output.usageMetadata.promptTokenCount,
+     output_tokens: output.usageMetadata.candidatesTokenCount,
+     total_tokens: output.usageMetadata.totalTokenCount,
+   };
+ }
const chunk =
output !== null
? safeResponseToChatGeneration({ data: output }, this.safetyHandler)
@@ -356,6 +371,7 @@
generationInfo: { finishReason: "stop" },
message: new AIMessageChunk({
content: "",
+ usage_metadata: usageMetadata,
}),
});
yield chunk;
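Note: the streaming loop above captures the latest usageMetadata reported by the Gemini stream and attaches it to the final, empty AIMessageChunk that closes the generation, so callers see aggregate counts after concatenating chunks. A minimal consumption sketch, assuming the concrete ChatVertexAI subclass (it mirrors the integration tests added later in this commit):

import { AIMessageChunk } from "@langchain/core/messages";
import { ChatVertexAI } from "@langchain/google-vertexai";

async function main() {
  // streamUsage defaults to true, so no extra configuration is needed.
  const model = new ChatVertexAI({ temperature: 0 });

  let final: AIMessageChunk | undefined;
  for await (const chunk of await model.stream("Why is the sky blue?")) {
    // Concatenating chunks merges usage_metadata into the aggregate message.
    final = final ? final.concat(chunk) : chunk;
  }

  // e.g. { input_tokens: 9, output_tokens: 42, total_tokens: 51 }
  console.log(final?.usage_metadata);
}

main().catch(console.error);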
9 changes: 8 additions & 1 deletion libs/langchain-google-common/src/types.ts
@@ -122,7 +122,14 @@ export interface GoogleAIBaseLLMInput<AuthOptions>
export interface GoogleAIBaseLanguageModelCallOptions
extends BaseLanguageModelCallOptions,
GoogleAIModelRequestParams,
- GoogleAISafetyParams {}
+ GoogleAISafetyParams {
+   /**
+    * Whether or not to include usage data, like token counts
+    * in the streamed response chunks.
+    * @default true
+    */
+   streamUsage?: boolean;
+ }

/**
* Input to LLM class.
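Because streamUsage is defined here on the call options as well as on the constructor fields (via the Pick<> in ChatGoogleBaseInput above), usage reporting can be disabled per model instance or per request. A short sketch, again assuming the concrete ChatVertexAI subclass:

import { ChatVertexAI } from "@langchain/google-vertexai";

async function demo() {
  // Per-instance opt-out: no request from this model reports usage.
  const quiet = new ChatVertexAI({ temperature: 0, streamUsage: false });
  for await (const chunk of await quiet.stream("Hello!")) {
    console.log(chunk.usage_metadata); // always undefined
  }

  // Per-request opt-out on a model that otherwise reports usage by default.
  const model = new ChatVertexAI({ temperature: 0 });
  for await (const chunk of await model.stream("Hello!", { streamUsage: false })) {
    console.log(chunk.usage_metadata); // undefined for this call only
  }
}

demo().catch(console.error);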
11 changes: 11 additions & 0 deletions libs/langchain-google-common/src/utils/gemini.ts
@@ -12,6 +12,7 @@ import {
MessageContentText,
SystemMessage,
ToolMessage,
+ UsageMetadata,
isAIMessage,
} from "@langchain/core/messages";
import {
@@ -604,12 +605,22 @@ export function responseToChatGenerations(
id: toolCall.id,
index: i,
}));
+ let usageMetadata: UsageMetadata | undefined;
+ if ("usageMetadata" in response.data) {
+   usageMetadata = {
+     input_tokens: response.data.usageMetadata.promptTokenCount as number,
+     output_tokens: response.data.usageMetadata.candidatesTokenCount as number,
+     total_tokens: response.data.usageMetadata.totalTokenCount as number,
+   };
+ }
ret = [
new ChatGenerationChunk({
message: new AIMessageChunk({
content: combinedContent,
additional_kwargs: ret[ret.length - 1]?.message.additional_kwargs,
tool_call_chunks: toolCallChunks,
+ usage_metadata: usageMetadata,
}),
text: combinedText,
generationInfo: ret[ret.length - 1].generationInfo,
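This repeats the field mapping used in chat_models.ts: Gemini reports promptTokenCount, candidatesTokenCount, and totalTokenCount, and they become LangChain's input_tokens, output_tokens, and total_tokens. As a standalone sketch of just the conversion (the helper name is hypothetical, not part of this commit):

import type { UsageMetadata } from "@langchain/core/messages";

// Hypothetical standalone version of the mapping performed inline above.
function geminiUsageToUsageMetadata(raw: {
  promptTokenCount: number;
  candidatesTokenCount: number;
  totalTokenCount: number;
}): UsageMetadata {
  return {
    input_tokens: raw.promptTokenCount,
    output_tokens: raw.candidatesTokenCount,
    total_tokens: raw.totalTokenCount,
  };
}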
2 changes: 1 addition & 1 deletion libs/langchain-google-genai/package.json
@@ -36,7 +36,7 @@
"license": "MIT",
"dependencies": {
"@google/generative-ai": "^0.7.0",
"@langchain/core": ">0.1.5 <0.3.0",
"@langchain/core": ">=0.2.5 <0.3.0",
"zod-to-json-schema": "^3.22.4"
},
"devDependencies": {
@@ -27,26 +27,6 @@ class ChatGoogleGenerativeAIStandardIntegrationTests extends ChatModelIntegrationTests<
});
}

- async testUsageMetadataStreaming() {
-   // ChatGoogleGenerativeAI does not support streaming tokens by
-   // default, so we must pass in a call option to
-   // enable streaming tokens.
-   const callOptions: ChatGoogleGenerativeAI["ParsedCallOptions"] = {
-     streamUsage: true,
-   };
-   await super.testUsageMetadataStreaming(callOptions);
- }
-
- async testUsageMetadata() {
-   // ChatGoogleGenerativeAI does not support counting tokens
-   // by default, so we must pass in a call option to enable
-   // streaming tokens.
-   const callOptions: ChatGoogleGenerativeAI["ParsedCallOptions"] = {
-     streamUsage: true,
-   };
-   await super.testUsageMetadata(callOptions);
- }

async testToolMessageHistoriesStringContent() {
this.skipTestMessage(
"testToolMessageHistoriesStringContent",
62 changes: 62 additions & 0 deletions libs/langchain-google-vertexai/src/tests/chat_models.int.test.ts
@@ -233,3 +233,65 @@ describe("GAuth Chat", () => {
expect(result).toHaveProperty("location");
});
});

test("Stream token count usage_metadata", async () => {
const model = new ChatVertexAI({
temperature: 0,
});
let res: AIMessageChunk | null = null;
for await (const chunk of await model.stream(
"Why is the sky blue? Be concise."
)) {
if (!res) {
res = chunk;
} else {
res = res.concat(chunk);
}
}
console.log(res);
expect(res?.usage_metadata).toBeDefined();
if (!res?.usage_metadata) {
return;
}
expect(res.usage_metadata.input_tokens).toBe(9);
expect(res.usage_metadata.output_tokens).toBeGreaterThan(10);
expect(res.usage_metadata.total_tokens).toBe(
res.usage_metadata.input_tokens + res.usage_metadata.output_tokens
);
});

test("streamUsage excludes token usage", async () => {
const model = new ChatVertexAI({
temperature: 0,
streamUsage: false,
});
let res: AIMessageChunk | null = null;
for await (const chunk of await model.stream(
"Why is the sky blue? Be concise."
)) {
if (!res) {
res = chunk;
} else {
res = res.concat(chunk);
}
}
console.log(res);
expect(res?.usage_metadata).not.toBeDefined();
});

test("Invoke token count usage_metadata", async () => {
const model = new ChatVertexAI({
temperature: 0,
});
const res = await model.invoke("Why is the sky blue? Be concise.");
console.log(res);
expect(res?.usage_metadata).toBeDefined();
if (!res?.usage_metadata) {
return;
}
expect(res.usage_metadata.input_tokens).toBe(9);
expect(res.usage_metadata.output_tokens).toBeGreaterThan(10);
expect(res.usage_metadata.total_tokens).toBe(
res.usage_metadata.input_tokens + res.usage_metadata.output_tokens
);
});
@@ -25,22 +25,6 @@ class ChatVertexAIStandardIntegrationTests extends ChatModelIntegrationTests<
});
}

- async testUsageMetadataStreaming() {
-   this.skipTestMessage(
-     "testUsageMetadataStreaming",
-     "ChatVertexAI",
-     "Streaming tokens is not currently supported."
-   );
- }
-
- async testUsageMetadata() {
-   this.skipTestMessage(
-     "testUsageMetadata",
-     "ChatVertexAI",
-     "Usage metadata tokens is not currently supported."
-   );
- }

async testToolMessageHistoriesListContent() {
this.skipTestMessage(
"testToolMessageHistoriesListContent",
4 changes: 2 additions & 2 deletions yarn.lock
@@ -10194,7 +10194,7 @@ __metadata:
resolution: "@langchain/google-common@workspace:libs/langchain-google-common"
dependencies:
"@jest/globals": ^29.5.0
"@langchain/core": ">0.1.56 <0.3.0"
"@langchain/core": ">=0.2.5 <0.3.0"
"@langchain/scripts": ~0.0.14
"@swc/core": ^1.3.90
"@swc/jest": ^0.2.29
@@ -10261,7 +10261,7 @@
dependencies:
"@google/generative-ai": ^0.7.0
"@jest/globals": ^29.5.0
"@langchain/core": ">0.1.5 <0.3.0"
"@langchain/core": ">=0.2.5 <0.3.0"
"@langchain/scripts": ~0.0.14
"@langchain/standard-tests": 0.0.0
"@swc/core": ^1.3.90
