Commit

core[minor],openai[patch]: Add usage metadata to AIMessage/Chunk (#5586)

* core[minor],openai[patch]: Add usage metadata to base ai message

* chore: lint files

* Only set stream_options if the model is OpenAI

* add test confirming last finish reason is stop

* chore: lint files

* docs

* lint n format

* cr

* cr

* Remove Azure specific check

* Remove comment

---------

Co-authored-by: jacoblee93 <[email protected]>
bracesproul and jacoblee93 authored May 31, 2024
1 parent fc67984 commit 7d8fa2e
Showing 5 changed files with 180 additions and 3 deletions.
13 changes: 13 additions & 0 deletions docs/core_docs/docs/integrations/chat/openai.mdx
@@ -125,3 +125,16 @@ You can also use the callbacks system:
### With `.generate()`

<CodeBlock language="typescript">{OpenAIGenerationInfo}</CodeBlock>

### Streaming tokens

OpenAI supports streaming token counts via an opt-in call option. This can be set by passing `{ stream_options: { include_usage: true } }`.
Setting this call option causes the model to emit one additional chunk at the end of the stream, containing the token usage for the whole request.

import OpenAIStreamTokens from "@examples/models/chat/integration_openai_stream_tokens.ts";

<CodeBlock language="typescript">{OpenAIStreamTokens}</CodeBlock>

:::tip
See the LangSmith trace [here](https://smith.langchain.com/public/66bf7377-cc69-4676-91b6-25929a05e8b7/r).
:::
30 changes: 30 additions & 0 deletions examples/src/models/chat/integration_openai_stream_tokens.ts
@@ -0,0 +1,30 @@
import { AIMessageChunk } from "@langchain/core/messages";
import { ChatOpenAI } from "@langchain/openai";

// Instantiate the model
const model = new ChatOpenAI();

const response = await model.stream("Hello, how are you?", {
// Pass the stream options
stream_options: {
include_usage: true,
},
});

// Iterate over the response, concatenating the chunks into a single aggregated message
let finalResult: AIMessageChunk | undefined;
for await (const chunk of response) {
if (finalResult) {
finalResult = finalResult.concat(chunk);
} else {
finalResult = chunk;
}
}

console.log(finalResult?.usage_metadata);

/*
{ input_tokens: 13, output_tokens: 30, total_tokens: 43 }
*/
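
The constructor-level `streaming: true` flag surfaces the same counts through `.invoke()`, as exercised by the integration tests further down. A minimal sketch (assuming an `OPENAI_API_KEY` in the environment, as in the example above):

import { ChatOpenAI } from "@langchain/openai";

// With streaming: true, invoke() streams under the hood but still
// resolves to a single aggregated AIMessage.
const streamingModel = new ChatOpenAI({ streaming: true });

const result = await streamingModel.invoke("Hello, how are you?", {
  stream_options: {
    include_usage: true,
  },
});

// The summed token counts ride along on the returned message.
console.log(result.usage_metadata);
// e.g. { input_tokens: 13, output_tokens: 30, total_tokens: 43 }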
58 changes: 55 additions & 3 deletions langchain-core/src/messages/ai.ts
@@ -18,6 +18,25 @@ import {
export type AIMessageFields = BaseMessageFields & {
tool_calls?: ToolCall[];
invalid_tool_calls?: InvalidToolCall[];
usage_metadata?: UsageMetadata;
};

/**
* Usage metadata for a message, such as token counts.
*/
export type UsageMetadata = {
/**
* The count of input (or prompt) tokens.
*/
input_tokens: number;
/**
* The count of output (or completion) tokens.
*/
output_tokens: number;
/**
* The total token count.
*/
total_tokens: number;
};

/**
@@ -30,6 +49,11 @@ export class AIMessage extends BaseMessage {

invalid_tool_calls?: InvalidToolCall[] = [];

/**
* If provided, token usage information associated with the message.
*/
usage_metadata?: UsageMetadata;

get lc_aliases(): Record<string, string> {
// exclude snake case conversion to pascal case
return {
@@ -94,6 +118,7 @@ export class AIMessage extends BaseMessage {
this.invalid_tool_calls =
initParams.invalid_tool_calls ?? this.invalid_tool_calls;
}
this.usage_metadata = initParams.usage_metadata;
}

static lc_name() {
@@ -127,6 +152,11 @@ export class AIMessageChunk extends BaseMessageChunk {

tool_call_chunks?: ToolCallChunk[] = [];

/**
* If provided, token usage information associated with the message.
*/
usage_metadata?: UsageMetadata;

constructor(fields: string | AIMessageChunkFields) {
let initParams: AIMessageChunkFields;
if (typeof fields === "string") {
@@ -177,10 +207,11 @@ export class AIMessageChunk extends BaseMessageChunk {
// properties with initializers, so we have to check types twice.
super(initParams);
    this.tool_call_chunks =
-      initParams?.tool_call_chunks ?? this.tool_call_chunks;
-    this.tool_calls = initParams?.tool_calls ?? this.tool_calls;
+      initParams.tool_call_chunks ?? this.tool_call_chunks;
+    this.tool_calls = initParams.tool_calls ?? this.tool_calls;
    this.invalid_tool_calls =
-      initParams?.invalid_tool_calls ?? this.invalid_tool_calls;
+      initParams.invalid_tool_calls ?? this.invalid_tool_calls;
+    this.usage_metadata = initParams.usage_metadata;
}

get lc_aliases(): Record<string, string> {
@@ -226,6 +257,27 @@ export class AIMessageChunk extends BaseMessageChunk {
combinedFields.tool_call_chunks = rawToolCalls;
}
}
if (
this.usage_metadata !== undefined ||
chunk.usage_metadata !== undefined
) {
const left: UsageMetadata = this.usage_metadata ?? {
input_tokens: 0,
output_tokens: 0,
total_tokens: 0,
};
const right: UsageMetadata = chunk.usage_metadata ?? {
input_tokens: 0,
output_tokens: 0,
total_tokens: 0,
};
const usage_metadata: UsageMetadata = {
input_tokens: left.input_tokens + right.input_tokens,
output_tokens: left.output_tokens + right.output_tokens,
total_tokens: left.total_tokens + right.total_tokens,
};
combinedFields.usage_metadata = usage_metadata;
}
return new AIMessageChunk(combinedFields);
}
}
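
To make the summing behavior above concrete, a small sketch (the field names come straight from the new `UsageMetadata` type; the token numbers are invented for illustration):

import { AIMessageChunk } from "@langchain/core/messages";

// Two partial chunks, each carrying its own (hypothetical) usage numbers.
const left = new AIMessageChunk({
  content: "Hello",
  usage_metadata: { input_tokens: 13, output_tokens: 1, total_tokens: 14 },
});
const right = new AIMessageChunk({
  content: " world",
  usage_metadata: { input_tokens: 0, output_tokens: 1, total_tokens: 1 },
});

// concat() sums the counts field by field, per the branch above.
const merged = left.concat(right);
console.log(merged.usage_metadata);
// { input_tokens: 13, output_tokens: 2, total_tokens: 15 }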
22 changes: 22 additions & 0 deletions libs/langchain-openai/src/chat_models.ts
@@ -253,6 +253,7 @@ export interface ChatOpenAICallOptions
promptIndex?: number;
response_format?: { type: "json_object" };
seed?: number;
stream_options?: { include_usage: boolean };
}

/**
@@ -553,6 +554,9 @@ export class ChatOpenAI<
tool_choice: options?.tool_choice,
response_format: options?.response_format,
seed: options?.seed,
...(options?.stream_options !== undefined
? { stream_options: options.stream_options }
: {}),
...this.modelKwargs,
};
return params;
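
A quick way to observe the conditional spread — a sketch assuming this hunk lives in the public `invocationParams` method (suggested by the `return params;` context) and an `OPENAI_API_KEY` in the environment:

import { ChatOpenAI } from "@langchain/openai";

const model = new ChatOpenAI();

// Without the call option, no stream_options key is included at all,
// so existing request payloads are unchanged when the option is unset.
console.log("stream_options" in model.invocationParams({}));
// false

// With the call option, it is forwarded verbatim to the OpenAI API.
const params = model.invocationParams({
  stream_options: { include_usage: true },
});
console.log(params.stream_options);
// { include_usage: true }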
@@ -586,8 +590,12 @@ export class ChatOpenAI<
};
let defaultRole: OpenAIRoleEnum | undefined;
const streamIterable = await this.completionWithRetry(params, options);
let usage: OpenAIClient.Completions.CompletionUsage | undefined;
for await (const data of streamIterable) {
const choice = data?.choices[0];
if (data.usage) {
usage = data.usage;
}
if (!choice) {
continue;
}
@@ -632,6 +640,20 @@ export class ChatOpenAI<
{ chunk: generationChunk }
);
}
if (usage) {
const generationChunk = new ChatGenerationChunk({
message: new AIMessageChunk({
content: "",
usage_metadata: {
input_tokens: usage.prompt_tokens,
output_tokens: usage.completion_tokens,
total_tokens: usage.total_tokens,
},
}),
text: "",
});
yield generationChunk;
}
if (options.signal?.aborted) {
throw new Error("AbortError");
}
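
Because the chunk yielded by this branch has empty content and only carries `usage_metadata`, a caller can watch for it directly instead of concatenating every chunk — a minimal sketch:

import { ChatOpenAI } from "@langchain/openai";

const chatModel = new ChatOpenAI();
const stream = await chatModel.stream("Hello, how are you?", {
  stream_options: { include_usage: true },
});

for await (const chunk of stream) {
  if (chunk.usage_metadata !== undefined) {
    // The trailing, content-less chunk added by the branch above.
    console.log("usage:", chunk.usage_metadata);
  } else if (typeof chunk.content === "string") {
    process.stdout.write(chunk.content);
  }
}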
60 changes: 60 additions & 0 deletions libs/langchain-openai/src/tests/chat_models.int.test.ts
@@ -1,5 +1,6 @@
import { test, jest, expect } from "@jest/globals";
import {
AIMessageChunk,
BaseMessage,
ChatMessage,
HumanMessage,
@@ -767,3 +768,62 @@ test("Test ChatOpenAI token usage reporting for streaming calls", async () => {
expect(streamingTokenUsed).toEqual(nonStreamingTokenUsed);
}
});

test("Finish reason is 'stop'", async () => {
const model = new ChatOpenAI();
const response = await model.stream("Hello, how are you?");
let finalResult: AIMessageChunk | undefined;
for await (const chunk of response) {
if (finalResult) {
finalResult = finalResult.concat(chunk);
} else {
finalResult = chunk;
}
}
expect(finalResult).toBeTruthy();
expect(finalResult?.response_metadata?.finish_reason).toBe("stop");
});

test("Streaming tokens can be found in usage_metadata field", async () => {
const model = new ChatOpenAI();
const response = await model.stream("Hello, how are you?", {
stream_options: {
include_usage: true,
},
});
let finalResult: AIMessageChunk | undefined;
for await (const chunk of response) {
if (finalResult) {
finalResult = finalResult.concat(chunk);
} else {
finalResult = chunk;
}
}
console.log({
usage_metadata: finalResult?.usage_metadata,
});
expect(finalResult).toBeTruthy();
expect(finalResult?.usage_metadata).toBeTruthy();
expect(finalResult?.usage_metadata?.input_tokens).toBeGreaterThan(0);
expect(finalResult?.usage_metadata?.output_tokens).toBeGreaterThan(0);
expect(finalResult?.usage_metadata?.total_tokens).toBeGreaterThan(0);
});

test("streaming: true tokens can be found in usage_metadata field", async () => {
const model = new ChatOpenAI({
streaming: true,
});
const response = await model.invoke("Hello, how are you?", {
stream_options: {
include_usage: true,
},
});
console.log({
usage_metadata: response?.usage_metadata,
});
expect(response).toBeTruthy();
expect(response?.usage_metadata).toBeTruthy();
expect(response?.usage_metadata?.input_tokens).toBeGreaterThan(0);
expect(response?.usage_metadata?.output_tokens).toBeGreaterThan(0);
expect(response?.usage_metadata?.total_tokens).toBeGreaterThan(0);
});
