google-genai[minor]: Add support for token counting via usage_metadata (#5757)

* google-genai[minor]: Add support for token counting via usage_metadata

* jsdoc nits

* pass entire request obj when getting input tok

* fix and stop making api calls
bracesproul authored Jun 13, 2024
1 parent baab194 commit 4b4423a
Showing 4 changed files with 189 additions and 39 deletions.
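
For orientation, a minimal usage sketch of what this commit enables (not part of the diff; the constructor options and prompt mirror the integration tests below, and top-level await is assumed):

```ts
import { ChatGoogleGenerativeAI } from "@langchain/google-genai";
import type { AIMessageChunk } from "@langchain/core/messages";

const model = new ChatGoogleGenerativeAI({ temperature: 0 });

// invoke(): usage_metadata is populated from the API's usageMetadata field.
const msg = await model.invoke("Why is the sky blue? Be concise.");
console.log(msg.usage_metadata);
// e.g. { input_tokens: 10, output_tokens: 42, total_tokens: 52 }

// stream(): chunks carry incremental usage (streamUsage defaults to true);
// concatenating the chunks sums the counts back into final totals.
let merged: AIMessageChunk | undefined;
for await (const chunk of await model.stream("Why is the sky blue? Be concise.")) {
  merged = merged ? merged.concat(chunk) : chunk;
}
console.log(merged?.usage_metadata);
```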
127 changes: 101 additions & 26 deletions libs/langchain-google-genai/src/chat_models.ts
@@ -6,9 +6,14 @@ import {
type FunctionDeclarationSchema as GenerativeAIFunctionDeclarationSchema,
GenerateContentRequest,
SafetySetting,
Part as GenerativeAIPart,
} from "@google/generative-ai";
import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager";
import { AIMessageChunk, BaseMessage } from "@langchain/core/messages";
import {
AIMessageChunk,
BaseMessage,
UsageMetadata,
} from "@langchain/core/messages";
import { ChatGenerationChunk, ChatResult } from "@langchain/core/outputs";
import { getEnvironmentVariable } from "@langchain/core/utils/env";
import {
@@ -56,12 +61,20 @@ export interface GoogleGenerativeAIChatCallOptions
tools?:
| StructuredToolInterface[]
| GoogleGenerativeAIFunctionDeclarationsTool[];
/**
* Whether or not to include usage data, like token counts
* in the streamed response chunks.
* @default true
*/
streamUsage?: boolean;
}

/**
* An interface defining the input to the ChatGoogleGenerativeAI class.
*/
export interface GoogleGenerativeAIChatInput extends BaseChatModelParams {
export interface GoogleGenerativeAIChatInput
extends BaseChatModelParams,
Pick<GoogleGenerativeAIChatCallOptions, "streamUsage"> {
/**
* Model Name to use
*
@@ -222,6 +235,8 @@ export class ChatGoogleGenerativeAI

streaming = false;

streamUsage = true;

private client: GenerativeModel;

get _isMultimodalModel() {
@@ -306,6 +321,7 @@ export class ChatGoogleGenerativeAI
baseUrl: fields?.baseUrl,
}
);
this.streamUsage = fields?.streamUsage ?? this.streamUsage;
}

getLsParams(options: this["ParsedCallOptions"]): LangSmithParams {
@@ -398,27 +414,31 @@ export class ChatGoogleGenerativeAI
return { generations, llmOutput: { estimatedTokenUsage: tokenUsage } };
}

const res = await this.caller.callWithOptions(
{ signal: options?.signal },
async () => {
let output;
try {
output = await this.client.generateContent({
...parameters,
contents: prompt,
});
// eslint-disable-next-line @typescript-eslint/no-explicit-any
} catch (e: any) {
// TODO: Improve error handling
if (e.message?.includes("400 Bad Request")) {
e.status = 400;
}
throw e;
}
return output;
const res = await this.completionWithRetry({
...parameters,
contents: prompt,
});

let usageMetadata: UsageMetadata | undefined;
if ("usageMetadata" in res.response) {
const genAIUsageMetadata = res.response.usageMetadata as {
promptTokenCount: number | undefined;
candidatesTokenCount: number | undefined;
totalTokenCount: number | undefined;
};
usageMetadata = {
input_tokens: genAIUsageMetadata.promptTokenCount ?? 0,
output_tokens: genAIUsageMetadata.candidatesTokenCount ?? 0,
total_tokens: genAIUsageMetadata.totalTokenCount ?? 0,
};
}

const generationResult = mapGenerateContentResultToChatResult(
res.response,
{
usageMetadata,
}
);
const generationResult = mapGenerateContentResultToChatResult(res.response);
await runManager?.handleLLMNewToken(
generationResult.generations[0].text ?? ""
);
@@ -435,19 +455,53 @@ export class ChatGoogleGenerativeAI
this._isMultimodalModel
);
const parameters = this.invocationParams(options);
const request = {
...parameters,
contents: prompt,
};
const stream = await this.caller.callWithOptions(
{ signal: options?.signal },
async () => {
const { stream } = await this.client.generateContentStream({
...parameters,
contents: prompt,
});
const { stream } = await this.client.generateContentStream(request);
return stream;
}
);

let usageMetadata: UsageMetadata | undefined;
for await (const response of stream) {
const chunk = convertResponseContentToChatGenerationChunk(response);
if (
"usageMetadata" in response &&
this.streamUsage !== false &&
options.streamUsage !== false
) {
const genAIUsageMetadata = response.usageMetadata as {
promptTokenCount: number;
candidatesTokenCount: number;
totalTokenCount: number;
};
if (!usageMetadata) {
usageMetadata = {
input_tokens: genAIUsageMetadata.promptTokenCount,
output_tokens: genAIUsageMetadata.candidatesTokenCount,
total_tokens: genAIUsageMetadata.totalTokenCount,
};
} else {
// Under the hood, LangChain combines the prompt tokens. Google returns the updated
// total each time, so we need to find the difference between the tokens.
const outputTokenDiff =
genAIUsageMetadata.candidatesTokenCount -
usageMetadata.output_tokens;
usageMetadata = {
input_tokens: 0,
output_tokens: outputTokenDiff,
total_tokens: outputTokenDiff,
};
}
}

const chunk = convertResponseContentToChatGenerationChunk(response, {
usageMetadata,
});
if (!chunk) {
continue;
}
@@ -457,6 +511,27 @@ export class ChatGoogleGenerativeAI
}
}

async completionWithRetry(
request: string | GenerateContentRequest | (string | GenerativeAIPart)[],
options?: this["ParsedCallOptions"]
) {
return this.caller.callWithOptions(
{ signal: options?.signal },
async () => {
try {
return this.client.generateContent(request);
// eslint-disable-next-line @typescript-eslint/no-explicit-any
} catch (e: any) {
// TODO: Improve error handling
if (e.message?.includes("400 Bad Request")) {
e.status = 400;
}
throw e;
}
}
);
}

withStructuredOutput<
// eslint-disable-next-line @typescript-eslint/no-explicit-any
RunOutput extends Record<string, any> = Record<string, any>
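
A simplified illustration of the delta logic in _streamResponseChunks above (illustrative numbers, not the committed code): Google reports cumulative token counts on every stream event, while LangChain sums the usage_metadata attached to each chunk, so later chunks should carry only the newly generated output tokens.

```ts
// Hypothetical cumulative usageMetadata from three successive stream events.
const events = [
  { promptTokenCount: 10, candidatesTokenCount: 4, totalTokenCount: 14 },
  { promptTokenCount: 10, candidatesTokenCount: 9, totalTokenCount: 19 },
  { promptTokenCount: 10, candidatesTokenCount: 12, totalTokenCount: 22 },
];

let previousOutput = 0;
const perChunkUsage = events.map((u, i) => {
  const outputDelta = u.candidatesTokenCount - previousOutput;
  previousOutput = u.candidatesTokenCount;
  return i === 0
    ? // The first chunk carries the prompt tokens exactly once.
      {
        input_tokens: u.promptTokenCount,
        output_tokens: outputDelta,
        total_tokens: u.totalTokenCount,
      }
    : // Later chunks carry only the newly generated output tokens.
      { input_tokens: 0, output_tokens: outputDelta, total_tokens: outputDelta };
});

// Summing perChunkUsage gives input_tokens: 10, output_tokens: 12,
// total_tokens: 22, matching the final cumulative counts from the API.
console.log(perChunkUsage);
```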
64 changes: 63 additions & 1 deletion libs/langchain-google-genai/src/tests/chat_models.int.test.ts
@@ -2,7 +2,7 @@ import { test } from "@jest/globals";
import * as fs from "node:fs/promises";
import { fileURLToPath } from "node:url";
import * as path from "node:path";
import { HumanMessage } from "@langchain/core/messages";
import { AIMessageChunk, HumanMessage } from "@langchain/core/messages";
import {
ChatPromptTemplate,
MessagesPlaceholder,
@@ -320,3 +320,65 @@ test("ChatGoogleGenerativeAI can call withStructuredOutput genai tools and invok
console.log(res);
expect(typeof res.url === "string").toBe(true);
});

test("Stream token count usage_metadata", async () => {
const model = new ChatGoogleGenerativeAI({
temperature: 0,
});
let res: AIMessageChunk | null = null;
for await (const chunk of await model.stream(
"Why is the sky blue? Be concise."
)) {
if (!res) {
res = chunk;
} else {
res = res.concat(chunk);
}
}
console.log(res);
expect(res?.usage_metadata).toBeDefined();
if (!res?.usage_metadata) {
return;
}
expect(res.usage_metadata.input_tokens).toBe(10);
expect(res.usage_metadata.output_tokens).toBeGreaterThan(10);
expect(res.usage_metadata.total_tokens).toBe(
res.usage_metadata.input_tokens + res.usage_metadata.output_tokens
);
});

test("streamUsage excludes token usage", async () => {
const model = new ChatGoogleGenerativeAI({
temperature: 0,
streamUsage: false,
});
let res: AIMessageChunk | null = null;
for await (const chunk of await model.stream(
"Why is the sky blue? Be concise."
)) {
if (!res) {
res = chunk;
} else {
res = res.concat(chunk);
}
}
console.log(res);
expect(res?.usage_metadata).not.toBeDefined();
});

test("Invoke token count usage_metadata", async () => {
const model = new ChatGoogleGenerativeAI({
temperature: 0,
});
const res = await model.invoke("Why is the sky blue? Be concise.");
console.log(res);
expect(res?.usage_metadata).toBeDefined();
if (!res?.usage_metadata) {
return;
}
expect(res.usage_metadata.input_tokens).toBe(10);
expect(res.usage_metadata.output_tokens).toBeGreaterThan(10);
expect(res.usage_metadata.total_tokens).toBe(
res.usage_metadata.input_tokens + res.usage_metadata.output_tokens
);
});
@@ -28,19 +28,23 @@ class ChatGoogleGenerativeAIStandardIntegrationTests extends ChatModelIntegratio
}

async testUsageMetadataStreaming() {
this.skipTestMessage(
"testUsageMetadataStreaming",
"ChatGoogleGenerativeAI",
"Streaming tokens is not currently supported."
);
// ChatGoogleGenerativeAI does not support streaming tokens by
// default, so we must pass in a call option to
// enable streaming tokens.
const callOptions: ChatGoogleGenerativeAI["ParsedCallOptions"] = {
streamUsage: true,
};
await super.testUsageMetadataStreaming(callOptions);
}

async testUsageMetadata() {
this.skipTestMessage(
"testUsageMetadata",
"ChatGoogleGenerativeAI",
"Usage metadata tokens is not currently supported."
);
// ChatGoogleGenerativeAI does not support counting tokens
// by default, so we must pass in a call option to enable
// streaming tokens.
const callOptions: ChatGoogleGenerativeAI["ParsedCallOptions"] = {
streamUsage: true,
};
await super.testUsageMetadata(callOptions);
}

async testToolMessageHistoriesStringContent() {
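
The standard tests above pass streamUsage as a call option; a minimal sketch (prompt text is illustrative) of the same per-request toggle, here used to turn usage reporting off for a single streamed call:

```ts
import { ChatGoogleGenerativeAI } from "@langchain/google-genai";

const model = new ChatGoogleGenerativeAI({ temperature: 0 });

// The constructor default (streamUsage = true) can be overridden per call;
// with streamUsage: false the streamed chunks omit usage_metadata.
const stream = await model.stream("Why is the sky blue? Be concise.", {
  streamUsage: false,
});
for await (const chunk of stream) {
  console.log(chunk.usage_metadata); // expected to stay undefined
}
```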
13 changes: 11 additions & 2 deletions libs/langchain-google-genai/src/utils/common.ts
@@ -12,6 +12,7 @@ import {
ChatMessage,
MessageContent,
MessageContentComplex,
UsageMetadata,
isBaseMessage,
} from "@langchain/core/messages";
import {
@@ -179,7 +180,10 @@ export function convertBaseMessagesToContent(
}

export function mapGenerateContentResultToChatResult(
response: EnhancedGenerateContentResponse
response: EnhancedGenerateContentResponse,
extra?: {
usageMetadata: UsageMetadata | undefined;
}
): ChatResult {
// if rejected or error, return empty generations with reason in filters
if (
@@ -208,6 +212,7 @@ export function mapGenerateContentResultToChatResult(
additional_kwargs: {
...generationInfo,
},
usage_metadata: extra?.usageMetadata,
}),
generationInfo,
};
@@ -218,7 +223,10 @@ export function mapGenerateContentResultToChatResult(
}

export function convertResponseContentToChatGenerationChunk(
response: EnhancedGenerateContentResponse
response: EnhancedGenerateContentResponse,
extra?: {
usageMetadata: UsageMetadata | undefined;
}
): ChatGenerationChunk | null {
if (!response.candidates || response.candidates.length === 0) {
return null;
Expand All @@ -235,6 +243,7 @@ export function convertResponseContentToChatGenerationChunk(
// Each chunk can have unique "generationInfo", and merging strategy is unclear,
// so leave blank for now.
additional_kwargs: {},
usage_metadata: extra?.usageMetadata,
}),
generationInfo,
});
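
For reference, the UsageMetadata shape these helpers attach (field names from @langchain/core/messages; values illustrative) and how it surfaces on the resulting message:

```ts
import { AIMessage, type UsageMetadata } from "@langchain/core/messages";

// Both helpers above pass extra?.usageMetadata straight into the message
// constructor, so callers read it back as message.usage_metadata.
const usageMetadata: UsageMetadata = {
  input_tokens: 10,
  output_tokens: 42,
  total_tokens: 52,
};

const message = new AIMessage({
  content: "The sky is blue because of Rayleigh scattering.",
  usage_metadata: usageMetadata,
});
console.log(message.usage_metadata?.total_tokens); // 52
```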
