From 448511e49a308f4a2a9acf9260b9ceb9b400df4a Mon Sep 17 00:00:00 2001
From: bracesproul
Date: Tue, 11 Jun 2024 13:19:58 -0700
Subject: [PATCH] add streamUsage

---
 libs/langchain-cohere/src/chat_models.ts       | 26 +++++++++----
 .../src/tests/chat_models.int.test.ts          | 39 ++++++++++++++++++-
 2 files changed, 57 insertions(+), 8 deletions(-)

diff --git a/libs/langchain-cohere/src/chat_models.ts b/libs/langchain-cohere/src/chat_models.ts
index 62ee87ca952e..73bcde88ebdc 100644
--- a/libs/langchain-cohere/src/chat_models.ts
+++ b/libs/langchain-cohere/src/chat_models.ts
@@ -47,6 +47,14 @@ export interface ChatCohereInput extends BaseChatModelParams {
    * @default {false}
    */
   streaming?: boolean;
+  /**
+   * Whether or not to include token usage when streaming.
+   * This will include an extra chunk at the end of the stream
+   * with `eventType: "stream-end"` and the token usage in
+   * `usage_metadata`.
+   * @default {true}
+   */
+  streamUsage?: boolean;
 }
 
 interface TokenUsage {
@@ -58,7 +66,8 @@
 export interface CohereChatCallOptions
   extends BaseLanguageModelCallOptions,
     Partial<Cohere.ChatRequest>,
-    Partial<Cohere.ChatStreamRequest> {}
+    Partial<Cohere.ChatStreamRequest>,
+    Pick<ChatCohereInput, "streamUsage"> {}
 
 function convertMessagesToCohereMessages(
   messages: Array<BaseMessage>
@@ -130,6 +139,8 @@ export class ChatCohere<
 
   streaming = false;
 
+  streamUsage: boolean = true;
+
   constructor(fields?: ChatCohereInput) {
     super(fields ?? {});
 
@@ -144,6 +155,7 @@
     this.model = fields?.model ?? this.model;
     this.temperature = fields?.temperature ?? this.temperature;
     this.streaming = fields?.streaming ?? this.streaming;
+    this.streamUsage = fields?.streamUsage ?? this.streamUsage;
   }
 
   getLsParams(options: this["ParsedCallOptions"]): LangSmithParams {
@@ -331,7 +343,9 @@ export class ChatCohere<
       if (chunk.eventType === "text-generation") {
         yield new ChatGenerationChunk({
           text: chunk.text,
-          message: new AIMessageChunk({ content: chunk.text }),
+          message: new AIMessageChunk({
+            content: chunk.text,
+          }),
         });
         await runManager?.handleLLMNewToken(chunk.text);
       } else if (chunk.eventType !== "stream-end") {
@@ -351,13 +365,11 @@
         });
       } else if (
         chunk.eventType === "stream-end" &&
-        chunk.response.meta?.tokens &&
-        (chunk.response.meta.tokens.inputTokens ||
-          chunk.response.meta.tokens.outputTokens)
+        (this.streamUsage || options.streamUsage)
       ) {
         // stream-end events contain the final token count
-        const input_tokens = chunk.response.meta.tokens.inputTokens ?? 0;
-        const output_tokens = chunk.response.meta.tokens.outputTokens ?? 0;
+        const input_tokens = chunk.response.meta?.tokens?.inputTokens ?? 0;
+        const output_tokens = chunk.response.meta?.tokens?.outputTokens ?? 0;
         yield new ChatGenerationChunk({
           text: "",
           message: new AIMessageChunk({
diff --git a/libs/langchain-cohere/src/tests/chat_models.int.test.ts b/libs/langchain-cohere/src/tests/chat_models.int.test.ts
index 9c55abbd919c..6e87f9ba367e 100644
--- a/libs/langchain-cohere/src/tests/chat_models.int.test.ts
+++ b/libs/langchain-cohere/src/tests/chat_models.int.test.ts
@@ -59,12 +59,13 @@ test("should abort the request", async () => {
   }).rejects.toThrow("AbortError");
 });
 
-test("Stream token count usage_metadata", async () => {
+test.only("Stream token count usage_metadata", async () => {
   const model = new ChatCohere({
     model: "command-light",
     temperature: 0,
   });
   let res: AIMessageChunk | null = null;
+  let lastRes: AIMessageChunk | null = null;
   for await (const chunk of await model.stream(
     "Why is the sky blue? Be concise."
   )) {
@@ -73,6 +74,7 @@ test("Stream token count usage_metadata", async () => {
     } else {
       res = res.concat(chunk);
     }
+    lastRes = chunk;
   }
   console.log(res);
   expect(res?.usage_metadata).toBeDefined();
@@ -84,6 +86,41 @@ test("Stream token count usage_metadata", async () => {
   expect(res.usage_metadata.total_tokens).toBe(
     res.usage_metadata.input_tokens + res.usage_metadata.output_tokens
   );
+  expect(lastRes?.additional_kwargs).toBeDefined();
+  if (!lastRes?.additional_kwargs) {
+    return;
+  }
+  expect(lastRes.additional_kwargs.eventType).toBe("stream-end");
+});
+
+test.only("streamUsage excludes token usage", async () => {
+  const model = new ChatCohere({
+    model: "command-light",
+    temperature: 0,
+    streamUsage: false,
+  });
+  let res: AIMessageChunk | null = null;
+  let lastRes: AIMessageChunk | null = null;
+  for await (const chunk of await model.stream(
+    "Why is the sky blue? Be concise."
+  )) {
+    if (!res) {
+      res = chunk;
+    } else {
+      res = res.concat(chunk);
+    }
+    lastRes = chunk;
+  }
+  console.log(res);
+  expect(res?.usage_metadata).not.toBeDefined();
+  if (res?.usage_metadata) {
+    return;
+  }
+  expect(lastRes?.additional_kwargs).toBeDefined();
+  if (!lastRes?.additional_kwargs) {
+    return;
+  }
+  expect(lastRes.additional_kwargs.eventType).not.toBe("stream-end");
 });
 
 test("Invoke token count usage_metadata", async () => {
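For context, a minimal sketch of how the new option would be used from application code. This mirrors the integration tests above; the model name and prompt are placeholders, and the import paths assume the standard `@langchain/cohere` and `@langchain/core` entrypoints rather than anything introduced by this patch:

import { ChatCohere } from "@langchain/cohere";
import { AIMessageChunk } from "@langchain/core/messages";

async function main() {
  // streamUsage defaults to true after this patch; set it to false to
  // suppress the extra "stream-end" chunk that carries token usage.
  const model = new ChatCohere({
    model: "command-light", // placeholder model, as in the tests above
    temperature: 0,
    streamUsage: true,
  });

  let final: AIMessageChunk | null = null;
  for await (const chunk of await model.stream(
    "Why is the sky blue? Be concise."
  )) {
    // The tests above rely on concat() carrying usage_metadata through
    // to the accumulated message.
    final = final ? final.concat(chunk) : chunk;
  }

  // Per the patch, usage_metadata holds input_tokens, output_tokens and
  // total_tokens taken from the Cohere "stream-end" event.
  console.log(final?.usage_metadata);
}

main();

With `streamUsage: false` (as exercised in the second test), the "stream-end" chunk is not emitted, so the accumulated message's `usage_metadata` stays undefined.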