Skip to content

Commit

Permalink
anthropic[patch]: Stream tokens (#5730)
Browse files Browse the repository at this point in the history
* anthropic[patch]: Stream tokens

* chore: lint files

* added normal test

* add streamUsage field to control streaming token counts
  • Loading branch information
bracesproul authored Jun 11, 2024
1 parent de3e618 commit 9a0675c
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 10 deletions.
43 changes: 41 additions & 2 deletions libs/langchain-anthropic/src/chat_models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import {
ToolMessage,
isAIMessage,
MessageContent,
UsageMetadata,
} from "@langchain/core/messages";
import {
ChatGeneration,
Expand Down Expand Up @@ -67,7 +68,9 @@ type AnthropicToolChoice =
}
| "any"
| "auto";
export interface ChatAnthropicCallOptions extends BaseLanguageModelCallOptions {
export interface ChatAnthropicCallOptions
extends BaseLanguageModelCallOptions,
Pick<AnthropicInput, "streamUsage"> {
tools?: (StructuredToolInterface | AnthropicTool)[];
/**
* Whether or not to specify what tool the model should use
Expand Down Expand Up @@ -211,6 +214,12 @@ export interface AnthropicInput {
* `anthropic.messages`} that are not explicitly specified on this class.
*/
invocationKwargs?: Kwargs;

/**
* Whether or not to include token usage data in streamed chunks.
* @default true
*/
streamUsage?: boolean;
}

/**
Expand Down Expand Up @@ -485,6 +494,8 @@ export class ChatAnthropicMessages<
// Used for streaming requests
protected streamingClient: Anthropic;

streamUsage = true;

constructor(fields?: Partial<AnthropicInput> & BaseChatModelParams) {
super(fields ?? {});

Expand Down Expand Up @@ -516,12 +527,13 @@ export class ChatAnthropicMessages<

this.streaming = fields?.streaming ?? false;
this.clientOptions = fields?.clientOptions ?? {};
this.streamUsage = fields?.streamUsage ?? this.streamUsage;
}

getLsParams(options: this["ParsedCallOptions"]): LangSmithParams {
const params = this.invocationParams(options);
return {
ls_provider: "openai",
ls_provider: "anthropic",
ls_model_name: this.model,
ls_model_type: "chat",
ls_temperature: params.temperature ?? undefined,
Expand Down Expand Up @@ -691,18 +703,36 @@ export class ChatAnthropicMessages<
}
}
usageData = usage;
let usageMetadata: UsageMetadata | undefined;
if (this.streamUsage || options.streamUsage) {
usageMetadata = {
input_tokens: usage.input_tokens,
output_tokens: usage.output_tokens,
total_tokens: usage.input_tokens + usage.output_tokens,
};
}
yield new ChatGenerationChunk({
message: new AIMessageChunk({
content: "",
additional_kwargs: filteredAdditionalKwargs,
usage_metadata: usageMetadata,
}),
text: "",
});
} else if (data.type === "message_delta") {
let usageMetadata: UsageMetadata | undefined;
if (this.streamUsage || options.streamUsage) {
usageMetadata = {
input_tokens: data.usage.output_tokens,
output_tokens: 0,
total_tokens: data.usage.output_tokens,
};
}
yield new ChatGenerationChunk({
message: new AIMessageChunk({
content: "",
additional_kwargs: { ...data.delta },
usage_metadata: usageMetadata,
}),
text: "",
});
Expand All @@ -723,10 +753,19 @@ export class ChatAnthropicMessages<
}
}
}
let usageMetadata: UsageMetadata | undefined;
if (this.streamUsage || options.streamUsage) {
usageMetadata = {
input_tokens: usageData.input_tokens,
output_tokens: usageData.output_tokens,
total_tokens: usageData.input_tokens + usageData.output_tokens,
};
}
yield new ChatGenerationChunk({
message: new AIMessageChunk({
content: "",
additional_kwargs: { usage: usageData },
usage_metadata: usageMetadata,
}),
text: "",
});
Expand Down
29 changes: 28 additions & 1 deletion libs/langchain-anthropic/src/tests/chat_models.int.test.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/* eslint-disable no-process-env */

import { expect, test } from "@jest/globals";
import { HumanMessage } from "@langchain/core/messages";
import { AIMessageChunk, HumanMessage } from "@langchain/core/messages";
import { ChatPromptValue } from "@langchain/core/prompt_values";
import {
PromptTemplate,
Expand Down Expand Up @@ -318,3 +318,30 @@ test("Test ChatAnthropic multimodal", async () => {
]);
console.log(res);
});

test("Stream tokens", async () => {
const model = new ChatAnthropic({
model: "claude-3-haiku-20240307",
temperature: 0,
});
let res: AIMessageChunk | null = null;
for await (const chunk of await model.stream(
"Why is the sky blue? Be concise."
)) {
if (!res) {
res = chunk;
} else {
res = res.concat(chunk);
}
}
console.log(res);
expect(res?.usage_metadata).toBeDefined();
if (!res?.usage_metadata) {
return;
}
expect(res.usage_metadata.input_tokens).toBe(34);
expect(res.usage_metadata.output_tokens).toBeGreaterThan(10);
expect(res.usage_metadata.total_tokens).toBe(
res.usage_metadata.input_tokens + res.usage_metadata.output_tokens
);
});
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,6 @@ class ChatAnthropicStandardIntegrationTests extends ChatModelIntegrationTests<
},
});
}

async testUsageMetadataStreaming() {
console.warn(
"Skipping testUsageMetadataStreaming, not implemented in ChatAnthropic."
);
}
}

const testClass = new ChatAnthropicStandardIntegrationTests();
Expand Down
2 changes: 1 addition & 1 deletion libs/langchain-cloudflare/src/chat_models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ export class ChatCloudflareWorkersAI

getLsParams(options: this["ParsedCallOptions"]): LangSmithParams {
return {
ls_provider: "openai",
ls_provider: "cloudflare",
ls_model_name: this.model,
ls_model_type: "chat",
ls_stop: options.stop,
Expand Down

0 comments on commit 9a0675c

Please sign in to comment.