From 4a4292200117022303e47d3686aca60208e6173a Mon Sep 17 00:00:00 2001
From: Brace Sproul
Date: Thu, 13 Jun 2024 09:42:04 -0700
Subject: [PATCH] mistral[minor]: Populate usage metadata for mistral (#5751)

* mistral[minor]: Populate usage metadata for mistral

* chore: lint files

* chore: lint files

* fix build issue

* bump min core dep
---
 libs/langchain-mistralai/package.json         |  2 +-
 libs/langchain-mistralai/src/chat_models.ts   | 63 +++++++++++++++---
 .../src/tests/chat_models.int.test.ts         | 66 +++++++++++++++++++
 .../tests/chat_models.standard.int.test.ts    | 16 -----
 yarn.lock                                     |  2 +-
 5 files changed, 122 insertions(+), 27 deletions(-)

diff --git a/libs/langchain-mistralai/package.json b/libs/langchain-mistralai/package.json
index 59f8750e145a..c089b7dc552c 100644
--- a/libs/langchain-mistralai/package.json
+++ b/libs/langchain-mistralai/package.json
@@ -35,7 +35,7 @@
   "author": "LangChain",
   "license": "MIT",
   "dependencies": {
-    "@langchain/core": ">0.1.56 <0.3.0",
+    "@langchain/core": ">=0.2.5 <0.3.0",
     "@mistralai/mistralai": "^0.4.0",
     "uuid": "^9.0.0",
     "zod": "^3.22.4",
diff --git a/libs/langchain-mistralai/src/chat_models.ts b/libs/langchain-mistralai/src/chat_models.ts
index a86589bd8dfd..9ea8803a343e 100644
--- a/libs/langchain-mistralai/src/chat_models.ts
+++ b/libs/langchain-mistralai/src/chat_models.ts
@@ -8,6 +8,7 @@ import {
   ChatRequest,
   Tool as MistralAITool,
   Message as MistralAIMessage,
+  TokenUsage as MistralAITokenUsage,
 } from "@mistralai/mistralai";
 import {
   MessageType,
@@ -80,6 +81,11 @@ interface MistralAICallOptions
   };
   tools: StructuredToolInterface[] | MistralAIToolInput[] | MistralAITool[];
   tool_choice?: MistralAIToolChoice;
+  /**
+   * Whether or not to include token usage in the stream.
+   * @default {true}
+   */
+  streamUsage?: boolean;
 }
 
 export interface ChatMistralAICallOptions extends MistralAICallOptions {}
@@ -87,7 +93,9 @@ export interface ChatMistralAICallOptions extends MistralAICallOptions {}
 /**
  * Input to chat model class.
  */
-export interface ChatMistralAIInput extends BaseChatModelParams {
+export interface ChatMistralAIInput
+  extends BaseChatModelParams,
+    Pick<ChatMistralAICallOptions, "streamUsage"> {
   /**
    * The API key to use.
    * @default {process.env.MISTRAL_API_KEY}
@@ -216,7 +224,8 @@ function convertMessagesToMistralMessages(
 }
 
 function mistralAIResponseToChatMessage(
-  choice: ChatCompletionResponse["choices"][0]
+  choice: ChatCompletionResponse["choices"][0],
+  usage?: MistralAITokenUsage
 ): BaseMessage {
   const { message } = choice;
   // MistralAI SDK does not include tool_calls in the non
@@ -254,6 +263,13 @@ function mistralAIResponseToChatMessage(
               }))
             : undefined,
         },
+        usage_metadata: usage
+          ? {
+              input_tokens: usage.prompt_tokens,
+              output_tokens: usage.completion_tokens,
+              total_tokens: usage.total_tokens,
+            }
+          : undefined,
       });
     }
     default:
@@ -261,12 +277,27 @@ function mistralAIResponseToChatMessage(
   }
 }
 
-function _convertDeltaToMessageChunk(delta: {
-  role?: string | undefined;
-  content?: string | undefined;
-  tool_calls?: MistralAIToolCalls[] | undefined;
-}) {
+function _convertDeltaToMessageChunk(
+  delta: {
+    role?: string | undefined;
+    content?: string | undefined;
+    tool_calls?: MistralAIToolCalls[] | undefined;
+  },
+  usage?: MistralAITokenUsage | null
+) {
   if (!delta.content && !delta.tool_calls) {
+    if (usage) {
+      return new AIMessageChunk({
+        content: "",
+        usage_metadata: usage
+          ? {
+              input_tokens: usage.prompt_tokens,
+              output_tokens: usage.completion_tokens,
+              total_tokens: usage.total_tokens,
+            }
+          : undefined,
+      });
+    }
     return null;
   }
   // Our merge additional kwargs util function will throw unless there
@@ -313,6 +344,13 @@ function _convertDeltaToMessageChunk(
       content,
       tool_call_chunks: toolCallChunks,
       additional_kwargs,
+      usage_metadata: usage
+        ? {
+            input_tokens: usage.prompt_tokens,
+            output_tokens: usage.completion_tokens,
+            total_tokens: usage.total_tokens,
+          }
+        : undefined,
     });
   } else if (role === "tool") {
     return new ToolMessageChunk({
@@ -389,6 +427,8 @@ export class ChatMistralAI<
 
   lc_serializable = true;
 
+  streamUsage = true;
+
   constructor(fields?: ChatMistralAIInput) {
     super(fields ?? {});
     const apiKey = fields?.apiKey ?? getEnvironmentVariable("MISTRAL_API_KEY");
@@ -409,6 +449,7 @@ export class ChatMistralAI<
     this.seed = this.randomSeed;
     this.modelName = fields?.model ?? fields?.modelName ?? this.model;
     this.model = this.modelName;
+    this.streamUsage = fields?.streamUsage ?? this.streamUsage;
   }
 
   getLsParams(options: this["ParsedCallOptions"]): LangSmithParams {
@@ -600,7 +641,7 @@ export class ChatMistralAI<
       const text = part.message?.content ?? "";
       const generation: ChatGeneration = {
         text,
-        message: mistralAIResponseToChatMessage(part),
+        message: mistralAIResponseToChatMessage(part, response?.usage),
       };
       if (part.finish_reason) {
         generation.generationInfo = { finish_reason: part.finish_reason };
@@ -643,7 +684,11 @@ export class ChatMistralAI<
           prompt: 0,
           completion: choice.index ?? 0,
         };
-        const message = _convertDeltaToMessageChunk(delta);
+        const shouldStreamUsage = this.streamUsage || options.streamUsage;
+        const message = _convertDeltaToMessageChunk(
+          delta,
+          shouldStreamUsage ? data.usage : null
+        );
         if (message === null) {
           // Do not yield a chunk if the message is empty
           continue;
diff --git a/libs/langchain-mistralai/src/tests/chat_models.int.test.ts b/libs/langchain-mistralai/src/tests/chat_models.int.test.ts
index 827eaa08db03..b052f2c3e30d 100644
--- a/libs/langchain-mistralai/src/tests/chat_models.int.test.ts
+++ b/libs/langchain-mistralai/src/tests/chat_models.int.test.ts
@@ -11,6 +11,7 @@ import { DynamicStructuredTool, StructuredTool } from "@langchain/core/tools";
 import { z } from "zod";
 import {
   AIMessage,
+  AIMessageChunk,
   BaseMessage,
   HumanMessage,
   ToolMessage,
@@ -916,3 +917,68 @@ describe("codestral-latest", () => {
     console.log(parsedArgs.code);
   });
 });
+
+test("Stream token count usage_metadata", async () => {
+  const model = new ChatMistralAI({
+    model: "codestral-latest",
+    temperature: 0,
+  });
+  let res: AIMessageChunk | null = null;
+  for await (const chunk of await model.stream(
+    "Why is the sky blue? Be concise."
+  )) {
+    if (!res) {
+      res = chunk;
+    } else {
+      res = res.concat(chunk);
+    }
+  }
+  console.log(res);
+  expect(res?.usage_metadata).toBeDefined();
+  if (!res?.usage_metadata) {
+    return;
+  }
+  expect(res.usage_metadata.input_tokens).toBe(13);
+  expect(res.usage_metadata.output_tokens).toBeGreaterThan(10);
+  expect(res.usage_metadata.total_tokens).toBe(
+    res.usage_metadata.input_tokens + res.usage_metadata.output_tokens
+  );
+});
+
+test("streamUsage excludes token usage", async () => {
+  const model = new ChatMistralAI({
+    model: "codestral-latest",
+    temperature: 0,
+    streamUsage: false,
+  });
+  let res: AIMessageChunk | null = null;
+  for await (const chunk of await model.stream(
+    "Why is the sky blue? Be concise."
+  )) {
+    if (!res) {
+      res = chunk;
+    } else {
+      res = res.concat(chunk);
+    }
+  }
+  console.log(res);
+  expect(res?.usage_metadata).not.toBeDefined();
+});
+
+test("Invoke token count usage_metadata", async () => {
+  const model = new ChatMistralAI({
+    model: "codestral-latest",
+    temperature: 0,
+  });
+  const res = await model.invoke("Why is the sky blue? Be concise.");
+  console.log(res);
+  expect(res?.usage_metadata).toBeDefined();
+  if (!res?.usage_metadata) {
+    return;
+  }
+  expect(res.usage_metadata.input_tokens).toBe(13);
+  expect(res.usage_metadata.output_tokens).toBeGreaterThan(10);
+  expect(res.usage_metadata.total_tokens).toBe(
+    res.usage_metadata.input_tokens + res.usage_metadata.output_tokens
+  );
+});
diff --git a/libs/langchain-mistralai/src/tests/chat_models.standard.int.test.ts b/libs/langchain-mistralai/src/tests/chat_models.standard.int.test.ts
index 0d80e46fccdb..248b3892f727 100644
--- a/libs/langchain-mistralai/src/tests/chat_models.standard.int.test.ts
+++ b/libs/langchain-mistralai/src/tests/chat_models.standard.int.test.ts
@@ -23,22 +23,6 @@ class ChatMistralAIStandardIntegrationTests extends ChatModelIntegrationTests<
       functionId: "123456789",
     });
   }
-
-  async testUsageMetadataStreaming() {
-    this.skipTestMessage(
-      "testUsageMetadataStreaming",
-      "ChatMistralAI",
-      "Streaming tokens is not currently supported."
-    );
-  }
-
-  async testUsageMetadata() {
-    this.skipTestMessage(
-      "testUsageMetadata",
-      "ChatMistralAI",
-      "Usage metadata tokens is not currently supported."
-    );
-  }
 }
 
 const testClass = new ChatMistralAIStandardIntegrationTests();
diff --git a/yarn.lock b/yarn.lock
index 2d33838f76e7..0c64972dcfd1 100644
--- a/yarn.lock
+++ b/yarn.lock
@@ -10444,7 +10444,7 @@ __metadata:
   resolution: "@langchain/mistralai@workspace:libs/langchain-mistralai"
   dependencies:
     "@jest/globals": ^29.5.0
-    "@langchain/core": ">0.1.56 <0.3.0"
+    "@langchain/core": ">=0.2.5 <0.3.0"
     "@langchain/scripts": ~0.0.14
     "@langchain/standard-tests": 0.0.0
    "@mistralai/mistralai": ^0.4.0
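
A minimal usage sketch of the behavior this patch adds (illustrative, not part
of the diff; it assumes @langchain/mistralai is built from this branch with
MISTRAL_API_KEY set, and it reuses the codestral-latest model from the
integration tests above):

  import { ChatMistralAI } from "@langchain/mistralai";
  import { AIMessageChunk } from "@langchain/core/messages";

  const model = new ChatMistralAI({
    model: "codestral-latest",
    temperature: 0,
    // streamUsage: false, // would omit usage_metadata from streamed chunks
  });

  // invoke(): the returned AIMessage now carries usage_metadata populated
  // from the Mistral API's token usage.
  const res = await model.invoke("Why is the sky blue? Be concise.");
  console.log(res.usage_metadata);
  // { input_tokens: ..., output_tokens: ..., total_tokens: ... }

  // stream(): usage arrives on the final (otherwise empty) chunk when
  // streamUsage is left enabled, so concatenating chunks accumulates it.
  let final: AIMessageChunk | null = null;
  for await (const chunk of await model.stream("Why is the sky blue?")) {
    final = final ? final.concat(chunk) : chunk;
  }
  console.log(final?.usage_metadata);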