Commit 448511e: add streamUsage

bracesproul committed Jun 11, 2024 · 1 parent b38a802
Showing 2 changed files with 57 additions and 8 deletions.
libs/langchain-cohere/src/chat_models.ts (19 additions, 7 deletions)
@@ -47,6 +47,14 @@ export interface ChatCohereInput extends BaseChatModelParams {
    * @default {false}
    */
   streaming?: boolean;
+  /**
+   * Whether or not to include token usage when streaming.
+   * This will include an extra chunk at the end of the stream
+   * with `eventType: "stream-end"` and the token usage in
+   * `usage_metadata`.
+   * @default {true}
+   */
+  streamUsage?: boolean;
 }
 
 interface TokenUsage {
@@ -58,7 +66,8 @@ interface TokenUsage {
 export interface CohereChatCallOptions
   extends BaseLanguageModelCallOptions,
     Partial<Omit<Cohere.ChatRequest, "message">>,
-    Partial<Omit<Cohere.ChatStreamRequest, "message">> {}
+    Partial<Omit<Cohere.ChatStreamRequest, "message">>,
+    Pick<ChatCohereInput, "streamUsage"> {}
 
 function convertMessagesToCohereMessages(
   messages: Array<BaseMessage>
@@ -130,6 +139,8 @@ export class ChatCohere<
 
   streaming = false;
 
+  streamUsage: boolean = true;
+
   constructor(fields?: ChatCohereInput) {
     super(fields ?? {});
 
@@ -144,6 +155,7 @@
     this.model = fields?.model ?? this.model;
     this.temperature = fields?.temperature ?? this.temperature;
     this.streaming = fields?.streaming ?? this.streaming;
+    this.streamUsage = fields?.streamUsage ?? this.streamUsage;
   }
 
   getLsParams(options: this["ParsedCallOptions"]): LangSmithParams {
@@ -331,7 +343,9 @@ export class ChatCohere<
       if (chunk.eventType === "text-generation") {
         yield new ChatGenerationChunk({
           text: chunk.text,
-          message: new AIMessageChunk({ content: chunk.text }),
+          message: new AIMessageChunk({
+            content: chunk.text,
+          }),
         });
         await runManager?.handleLLMNewToken(chunk.text);
       } else if (chunk.eventType !== "stream-end") {
@@ -351,13 +365,11 @@
         });
       } else if (
         chunk.eventType === "stream-end" &&
-        chunk.response.meta?.tokens &&
-        (chunk.response.meta.tokens.inputTokens ||
-          chunk.response.meta.tokens.outputTokens)
+        (this.streamUsage || options.streamUsage)
       ) {
         // stream-end events contain the final token count
-        const input_tokens = chunk.response.meta.tokens.inputTokens ?? 0;
-        const output_tokens = chunk.response.meta.tokens.outputTokens ?? 0;
+        const input_tokens = chunk.response.meta?.tokens?.inputTokens ?? 0;
+        const output_tokens = chunk.response.meta?.tokens?.outputTokens ?? 0;
         yield new ChatGenerationChunk({
           text: "",
           message: new AIMessageChunk({
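For context, a minimal consumption sketch (not part of the commit) of what the new option means for callers: with the default `streamUsage: true`, the final chunk of the stream should carry the token counts in `usage_metadata`. The model name and prompt are borrowed from the integration tests below; the imports assume the public entry points of `@langchain/cohere` and `@langchain/core`.

import { ChatCohere } from "@langchain/cohere";
import type { AIMessageChunk } from "@langchain/core/messages";

const model = new ChatCohere({
  model: "command-light", // model used in the integration tests below
  temperature: 0,
});

// Stream the response; the "stream-end" chunk arrives last and carries usage.
let lastChunk: AIMessageChunk | null = null;
for await (const chunk of await model.stream("Why is the sky blue? Be concise.")) {
  lastChunk = chunk;
}

// Expected shape: { input_tokens, output_tokens, total_tokens }
console.log(lastChunk?.usage_metadata);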
libs/langchain-cohere/src/tests/chat_models.int.test.ts (38 additions, 1 deletion)
@@ -59,12 +59,13 @@ test("should abort the request", async () => {
   }).rejects.toThrow("AbortError");
 });
 
-test("Stream token count usage_metadata", async () => {
+test.only("Stream token count usage_metadata", async () => {
   const model = new ChatCohere({
     model: "command-light",
     temperature: 0,
   });
   let res: AIMessageChunk | null = null;
+  let lastRes: AIMessageChunk | null = null;
   for await (const chunk of await model.stream(
     "Why is the sky blue? Be concise."
   )) {
@@ -73,6 +74,7 @@ test("Stream token count usage_metadata", async () => {
     } else {
       res = res.concat(chunk);
     }
+    lastRes = chunk;
   }
   console.log(res);
   expect(res?.usage_metadata).toBeDefined();
@@ -84,6 +86,41 @@
   expect(res.usage_metadata.total_tokens).toBe(
     res.usage_metadata.input_tokens + res.usage_metadata.output_tokens
   );
+  expect(lastRes?.additional_kwargs).toBeDefined();
+  if (!lastRes?.additional_kwargs) {
+    return;
+  }
+  expect(lastRes.additional_kwargs.eventType).toBe("stream-end");
 });
 
+test.only("streamUsage excludes token usage", async () => {
+  const model = new ChatCohere({
+    model: "command-light",
+    temperature: 0,
+    streamUsage: false,
+  });
+  let res: AIMessageChunk | null = null;
+  let lastRes: AIMessageChunk | null = null;
+  for await (const chunk of await model.stream(
+    "Why is the sky blue? Be concise."
+  )) {
+    if (!res) {
+      res = chunk;
+    } else {
+      res = res.concat(chunk);
+    }
+    lastRes = chunk;
+  }
+  console.log(res);
+  expect(res?.usage_metadata).not.toBeDefined();
+  if (res?.usage_metadata) {
+    return;
+  }
+  expect(lastRes?.additional_kwargs).toBeDefined();
+  if (!lastRes?.additional_kwargs) {
+    return;
+  }
+  expect(lastRes.additional_kwargs.eventType).not.toBe("stream-end");
+});
+
 test("Invoke token count usage_metadata", async () => {
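A nuance of the new guard, sketched below (again not part of the commit): because the condition added above is `this.streamUsage || options.streamUsage`, the call-time option can re-enable usage reporting on a model constructed with `streamUsage: false`, but it cannot turn usage off when the constructor-level default of `true` is in effect.

import { ChatCohere } from "@langchain/cohere";

const model = new ChatCohere({
  model: "command-light",
  temperature: 0,
  streamUsage: false, // suppress the stream-end usage chunk by default
});

// CohereChatCallOptions now picks up streamUsage, so it can also be passed
// per call; with the OR-guard above, `true` here re-enables the usage chunk.
const stream = await model.stream("Why is the sky blue? Be concise.", {
  streamUsage: true,
});

for await (const chunk of stream) {
  if (chunk.usage_metadata) {
    console.log(chunk.usage_metadata); // { input_tokens, output_tokens, total_tokens }
  }
}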
