From 4b4423a591a6918bf6156c732dc20aaedec0dcb2 Mon Sep 17 00:00:00 2001 From: Brace Sproul Date: Thu, 13 Jun 2024 15:46:02 -0700 Subject: [PATCH] google-genai[minor]: Add support for token counting via usage_metadata (#5757) * google-genai[minor]: Add support for token counting via usage_metadata * jsdoc nits * pass entire request obj when getting input tok * fix and stop making api calls --- .../langchain-google-genai/src/chat_models.ts | 127 ++++++++++++++---- .../src/tests/chat_models.int.test.ts | 64 ++++++++- .../tests/chat_models.standard.int.test.ts | 24 ++-- .../src/utils/common.ts | 13 +- 4 files changed, 189 insertions(+), 39 deletions(-) diff --git a/libs/langchain-google-genai/src/chat_models.ts b/libs/langchain-google-genai/src/chat_models.ts index 73b54d22bfab..226ceeed0a4c 100644 --- a/libs/langchain-google-genai/src/chat_models.ts +++ b/libs/langchain-google-genai/src/chat_models.ts @@ -6,9 +6,14 @@ import { type FunctionDeclarationSchema as GenerativeAIFunctionDeclarationSchema, GenerateContentRequest, SafetySetting, + Part as GenerativeAIPart, } from "@google/generative-ai"; import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager"; -import { AIMessageChunk, BaseMessage } from "@langchain/core/messages"; +import { + AIMessageChunk, + BaseMessage, + UsageMetadata, +} from "@langchain/core/messages"; import { ChatGenerationChunk, ChatResult } from "@langchain/core/outputs"; import { getEnvironmentVariable } from "@langchain/core/utils/env"; import { @@ -56,12 +61,20 @@ export interface GoogleGenerativeAIChatCallOptions tools?: | StructuredToolInterface[] | GoogleGenerativeAIFunctionDeclarationsTool[]; + /** + * Whether or not to include usage data, like token counts + * in the streamed response chunks. + * @default true + */ + streamUsage?: boolean; } /** * An interface defining the input to the ChatGoogleGenerativeAI class. */ -export interface GoogleGenerativeAIChatInput extends BaseChatModelParams { +export interface GoogleGenerativeAIChatInput + extends BaseChatModelParams, + Pick { /** * Model Name to use * @@ -222,6 +235,8 @@ export class ChatGoogleGenerativeAI streaming = false; + streamUsage = true; + private client: GenerativeModel; get _isMultimodalModel() { @@ -306,6 +321,7 @@ export class ChatGoogleGenerativeAI baseUrl: fields?.baseUrl, } ); + this.streamUsage = fields?.streamUsage ?? this.streamUsage; } getLsParams(options: this["ParsedCallOptions"]): LangSmithParams { @@ -398,27 +414,31 @@ export class ChatGoogleGenerativeAI return { generations, llmOutput: { estimatedTokenUsage: tokenUsage } }; } - const res = await this.caller.callWithOptions( - { signal: options?.signal }, - async () => { - let output; - try { - output = await this.client.generateContent({ - ...parameters, - contents: prompt, - }); - // eslint-disable-next-line @typescript-eslint/no-explicit-any - } catch (e: any) { - // TODO: Improve error handling - if (e.message?.includes("400 Bad Request")) { - e.status = 400; - } - throw e; - } - return output; + const res = await this.completionWithRetry({ + ...parameters, + contents: prompt, + }); + + let usageMetadata: UsageMetadata | undefined; + if ("usageMetadata" in res.response) { + const genAIUsageMetadata = res.response.usageMetadata as { + promptTokenCount: number | undefined; + candidatesTokenCount: number | undefined; + totalTokenCount: number | undefined; + }; + usageMetadata = { + input_tokens: genAIUsageMetadata.promptTokenCount ?? 0, + output_tokens: genAIUsageMetadata.candidatesTokenCount ?? 
0, + total_tokens: genAIUsageMetadata.totalTokenCount ?? 0, + }; + } + + const generationResult = mapGenerateContentResultToChatResult( + res.response, + { + usageMetadata, } ); - const generationResult = mapGenerateContentResultToChatResult(res.response); await runManager?.handleLLMNewToken( generationResult.generations[0].text ?? "" ); @@ -435,19 +455,53 @@ export class ChatGoogleGenerativeAI this._isMultimodalModel ); const parameters = this.invocationParams(options); + const request = { + ...parameters, + contents: prompt, + }; const stream = await this.caller.callWithOptions( { signal: options?.signal }, async () => { - const { stream } = await this.client.generateContentStream({ - ...parameters, - contents: prompt, - }); + const { stream } = await this.client.generateContentStream(request); return stream; } ); + let usageMetadata: UsageMetadata | undefined; for await (const response of stream) { - const chunk = convertResponseContentToChatGenerationChunk(response); + if ( + "usageMetadata" in response && + this.streamUsage !== false && + options.streamUsage !== false + ) { + const genAIUsageMetadata = response.usageMetadata as { + promptTokenCount: number; + candidatesTokenCount: number; + totalTokenCount: number; + }; + if (!usageMetadata) { + usageMetadata = { + input_tokens: genAIUsageMetadata.promptTokenCount, + output_tokens: genAIUsageMetadata.candidatesTokenCount, + total_tokens: genAIUsageMetadata.totalTokenCount, + }; + } else { + // Under the hood, LangChain combines the prompt tokens. Google returns the updated + // total each time, so we need to find the difference between the tokens. + const outputTokenDiff = + genAIUsageMetadata.candidatesTokenCount - + usageMetadata.output_tokens; + usageMetadata = { + input_tokens: 0, + output_tokens: outputTokenDiff, + total_tokens: outputTokenDiff, + }; + } + } + + const chunk = convertResponseContentToChatGenerationChunk(response, { + usageMetadata, + }); if (!chunk) { continue; } @@ -457,6 +511,27 @@ export class ChatGoogleGenerativeAI } } + async completionWithRetry( + request: string | GenerateContentRequest | (string | GenerativeAIPart)[], + options?: this["ParsedCallOptions"] + ) { + return this.caller.callWithOptions( + { signal: options?.signal }, + async () => { + try { + return this.client.generateContent(request); + // eslint-disable-next-line @typescript-eslint/no-explicit-any + } catch (e: any) { + // TODO: Improve error handling + if (e.message?.includes("400 Bad Request")) { + e.status = 400; + } + throw e; + } + } + ); + } + withStructuredOutput< // eslint-disable-next-line @typescript-eslint/no-explicit-any RunOutput extends Record = Record diff --git a/libs/langchain-google-genai/src/tests/chat_models.int.test.ts b/libs/langchain-google-genai/src/tests/chat_models.int.test.ts index ef80c693303c..4030413a8963 100644 --- a/libs/langchain-google-genai/src/tests/chat_models.int.test.ts +++ b/libs/langchain-google-genai/src/tests/chat_models.int.test.ts @@ -2,7 +2,7 @@ import { test } from "@jest/globals"; import * as fs from "node:fs/promises"; import { fileURLToPath } from "node:url"; import * as path from "node:path"; -import { HumanMessage } from "@langchain/core/messages"; +import { AIMessageChunk, HumanMessage } from "@langchain/core/messages"; import { ChatPromptTemplate, MessagesPlaceholder, @@ -320,3 +320,65 @@ test("ChatGoogleGenerativeAI can call withStructuredOutput genai tools and invok console.log(res); expect(typeof res.url === "string").toBe(true); }); + +test("Stream token count usage_metadata", async 
() => { + const model = new ChatGoogleGenerativeAI({ + temperature: 0, + }); + let res: AIMessageChunk | null = null; + for await (const chunk of await model.stream( + "Why is the sky blue? Be concise." + )) { + if (!res) { + res = chunk; + } else { + res = res.concat(chunk); + } + } + console.log(res); + expect(res?.usage_metadata).toBeDefined(); + if (!res?.usage_metadata) { + return; + } + expect(res.usage_metadata.input_tokens).toBe(10); + expect(res.usage_metadata.output_tokens).toBeGreaterThan(10); + expect(res.usage_metadata.total_tokens).toBe( + res.usage_metadata.input_tokens + res.usage_metadata.output_tokens + ); +}); + +test("streamUsage excludes token usage", async () => { + const model = new ChatGoogleGenerativeAI({ + temperature: 0, + streamUsage: false, + }); + let res: AIMessageChunk | null = null; + for await (const chunk of await model.stream( + "Why is the sky blue? Be concise." + )) { + if (!res) { + res = chunk; + } else { + res = res.concat(chunk); + } + } + console.log(res); + expect(res?.usage_metadata).not.toBeDefined(); +}); + +test("Invoke token count usage_metadata", async () => { + const model = new ChatGoogleGenerativeAI({ + temperature: 0, + }); + const res = await model.invoke("Why is the sky blue? Be concise."); + console.log(res); + expect(res?.usage_metadata).toBeDefined(); + if (!res?.usage_metadata) { + return; + } + expect(res.usage_metadata.input_tokens).toBe(10); + expect(res.usage_metadata.output_tokens).toBeGreaterThan(10); + expect(res.usage_metadata.total_tokens).toBe( + res.usage_metadata.input_tokens + res.usage_metadata.output_tokens + ); +}); diff --git a/libs/langchain-google-genai/src/tests/chat_models.standard.int.test.ts b/libs/langchain-google-genai/src/tests/chat_models.standard.int.test.ts index 4f9909358165..335798acb4d5 100644 --- a/libs/langchain-google-genai/src/tests/chat_models.standard.int.test.ts +++ b/libs/langchain-google-genai/src/tests/chat_models.standard.int.test.ts @@ -28,19 +28,23 @@ class ChatGoogleGenerativeAIStandardIntegrationTests extends ChatModelIntegratio } async testUsageMetadataStreaming() { - this.skipTestMessage( - "testUsageMetadataStreaming", - "ChatGoogleGenerativeAI", - "Streaming tokens is not currently supported." - ); + // ChatGoogleGenerativeAI does not support streaming tokens by + // default, so we must pass in a call option to + // enable streaming tokens. + const callOptions: ChatGoogleGenerativeAI["ParsedCallOptions"] = { + streamUsage: true, + }; + await super.testUsageMetadataStreaming(callOptions); } async testUsageMetadata() { - this.skipTestMessage( - "testUsageMetadata", - "ChatGoogleGenerativeAI", - "Usage metadata tokens is not currently supported." - ); + // ChatGoogleGenerativeAI does not support counting tokens + // by default, so we must pass in a call option to enable + // streaming tokens. 
+ const callOptions: ChatGoogleGenerativeAI["ParsedCallOptions"] = { + streamUsage: true, + }; + await super.testUsageMetadata(callOptions); } async testToolMessageHistoriesStringContent() { diff --git a/libs/langchain-google-genai/src/utils/common.ts b/libs/langchain-google-genai/src/utils/common.ts index ad384d33a7a1..641d103dfb1b 100644 --- a/libs/langchain-google-genai/src/utils/common.ts +++ b/libs/langchain-google-genai/src/utils/common.ts @@ -12,6 +12,7 @@ import { ChatMessage, MessageContent, MessageContentComplex, + UsageMetadata, isBaseMessage, } from "@langchain/core/messages"; import { @@ -179,7 +180,10 @@ export function convertBaseMessagesToContent( } export function mapGenerateContentResultToChatResult( - response: EnhancedGenerateContentResponse + response: EnhancedGenerateContentResponse, + extra?: { + usageMetadata: UsageMetadata | undefined; + } ): ChatResult { // if rejected or error, return empty generations with reason in filters if ( @@ -208,6 +212,7 @@ export function mapGenerateContentResultToChatResult( additional_kwargs: { ...generationInfo, }, + usage_metadata: extra?.usageMetadata, }), generationInfo, }; @@ -218,7 +223,10 @@ export function mapGenerateContentResultToChatResult( } export function convertResponseContentToChatGenerationChunk( - response: EnhancedGenerateContentResponse + response: EnhancedGenerateContentResponse, + extra?: { + usageMetadata: UsageMetadata | undefined; + } ): ChatGenerationChunk | null { if (!response.candidates || response.candidates.length === 0) { return null; @@ -235,6 +243,7 @@ export function convertResponseContentToChatGenerationChunk( // Each chunk can have unique "generationInfo", and merging strategy is unclear, // so leave blank for now. additional_kwargs: {}, + usage_metadata: extra?.usageMetadata, }), generationInfo, });
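
Reviewer note (not part of the diff): a minimal consumer-side sketch of the behaviour this patch adds, mirroring the integration tests above. It assumes `GOOGLE_API_KEY` is set in the environment, the default model, and an ESM context for top-level `await`; the prompt text and logged values are illustrative only.

```ts
import { ChatGoogleGenerativeAI } from "@langchain/google-genai";
import { AIMessageChunk } from "@langchain/core/messages";

const model = new ChatGoogleGenerativeAI({ temperature: 0 });

// invoke(): usage_metadata is now populated from the response's usageMetadata.
const res = await model.invoke("Why is the sky blue? Be concise.");
console.log(res.usage_metadata);
// -> { input_tokens: ..., output_tokens: ..., total_tokens: ... }

// stream(): chunks carry usage_metadata by default; concatenate them to get totals.
let aggregate: AIMessageChunk | undefined;
for await (const chunk of await model.stream("Why is the sky blue? Be concise.")) {
  aggregate = aggregate === undefined ? chunk : aggregate.concat(chunk);
}
console.log(aggregate?.usage_metadata);

// Opting out per call (streamUsage: false also works as a constructor field).
await model.stream("Why is the sky blue? Be concise.", { streamUsage: false });
```

Per the `this.streamUsage !== false && options.streamUsage !== false` check in `_streamResponseChunks`, usage is attached to streamed chunks only when neither the constructor field nor the per-call option is set to `false`.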
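A second sketch (also not part of the diff) of why `_streamResponseChunks` emits only the output-token difference after the first chunk: Google reports cumulative counts on each streamed response, while concatenating `AIMessageChunk`s sums their `usage_metadata`, which is what the "LangChain combines the prompt tokens" comment above relies on. The token numbers below are made up for illustration.

```ts
import { AIMessageChunk } from "@langchain/core/messages";

// The first chunk carries the full counts reported so far...
const first = new AIMessageChunk({
  content: "The sky is blue",
  usage_metadata: { input_tokens: 10, output_tokens: 5, total_tokens: 15 },
});

// ...later chunks carry only the newly generated output tokens, with
// input_tokens set to 0 so the prompt is not double counted.
const second = new AIMessageChunk({
  content: " because of Rayleigh scattering.",
  usage_metadata: { input_tokens: 0, output_tokens: 7, total_tokens: 7 },
});

// concat() adds the usage_metadata fields together, giving correct totals.
const combined = first.concat(second);
console.log(combined.usage_metadata);
// -> { input_tokens: 10, output_tokens: 12, total_tokens: 22 }
```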