Skip to content

Commit

Permalink
final
Browse files Browse the repository at this point in the history
  • Loading branch information
bracesproul committed Jun 27, 2024
1 parent b103443 commit 49fe576
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 40 deletions.
24 changes: 21 additions & 3 deletions libs/langchain-aws/src/chat_models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -111,11 +111,21 @@ export interface ChatBedrockConverseInput
* @link https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html
*/
additionalModelRequestFields?: __DocumentType;
/**
* Whether or not to include usage data, like token counts
* in the streamed response chunks. Passing as a call option will
* take precedence over the class-level setting.
* @default true
*/
streamUsage?: boolean;
}

export interface ChatBedrockConverseCallOptions
extends BaseLanguageModelCallOptions,
Pick<ChatBedrockConverseInput, "additionalModelRequestFields"> {
Pick<
ChatBedrockConverseInput,
"additionalModelRequestFields" | "streamUsage"
> {
/**
* A list of stop sequences. A stop sequence is a sequence of characters that causes
* the model to stop generating the response.
Expand Down Expand Up @@ -181,6 +191,8 @@ export class ChatBedrockConverse

additionalModelRequestFields?: __DocumentType;

streamUsage = true;

client: BedrockRuntimeClient;

constructor(fields?: ChatBedrockConverseInput) {
Expand Down Expand Up @@ -231,6 +243,7 @@ export class ChatBedrockConverse
this.endpointHost = rest?.endpointHost;
this.topP = rest?.topP;
this.additionalModelRequestFields = rest?.additionalModelRequestFields;
this.streamUsage = rest?.streamUsage ?? this.streamUsage;
}

getLsParams(options: this["ParsedCallOptions"]): LangSmithParams {
Expand Down Expand Up @@ -369,7 +382,10 @@ export class ChatBedrockConverse
const { converseMessages, converseSystem } =
convertToConverseMessages(messages);
const params = this.invocationParams(options);

let { streamUsage } = this;
if (options.streamUsage !== undefined) {
streamUsage = options.streamUsage;
}
const command = new ConverseStreamCommand({
modelId: this.model,
messages: converseMessages,
Expand All @@ -388,7 +404,9 @@ export class ChatBedrockConverse
yield textChatGeneration;
await runManager?.handleLLMNewToken(textChatGeneration.text);
} else if (chunk.metadata) {
yield handleConverseStreamMetadata(chunk.metadata);
yield handleConverseStreamMetadata(chunk.metadata, {
streamUsage,
});
} else {
yield new ChatGenerationChunk({
text: "",
Expand Down
18 changes: 16 additions & 2 deletions libs/langchain-aws/src/common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,15 @@ export function convertConverseMessageToLangChainMessage(
`Unsupported message role received in ChatBedrockConverse response: ${message.role}`
);
}
let requestId: string | undefined;
if (
"$metadata" in responseMetadata &&
responseMetadata.$metadata &&
typeof responseMetadata.$metadata === "object" &&
"requestId" in responseMetadata.$metadata
) {
requestId = responseMetadata.$metadata.requestId as string;
}
let tokenUsage: UsageMetadata | undefined;
if (responseMetadata.usage) {
const input_tokens = responseMetadata.usage.inputTokens ?? 0;
Expand All @@ -305,6 +314,7 @@ export function convertConverseMessageToLangChainMessage(
content: message.content[0].text,
response_metadata: responseMetadata,
usage_metadata: tokenUsage,
id: requestId,
});
} else {
const toolCalls: ToolCall[] = [];
Expand Down Expand Up @@ -333,6 +343,7 @@ export function convertConverseMessageToLangChainMessage(
tool_calls: toolCalls.length ? toolCalls : undefined,
response_metadata: responseMetadata,
usage_metadata: tokenUsage,
id: requestId,
});
}
}
Expand Down Expand Up @@ -397,7 +408,10 @@ export function handleConverseStreamContentBlockStart(
}

export function handleConverseStreamMetadata(
metadata: ConverseStreamMetadataEvent
metadata: ConverseStreamMetadataEvent,
extra: {
streamUsage: boolean;
}
): ChatGenerationChunk {
const inputTokens = metadata.usage?.inputTokens ?? 0;
const outputTokens = metadata.usage?.outputTokens ?? 0;
Expand All @@ -410,7 +424,7 @@ export function handleConverseStreamMetadata(
text: "",
message: new AIMessageChunk({
content: "",
usage_metadata,
usage_metadata: extra.streamUsage ? usage_metadata : undefined,
response_metadata: {
// Use the same key as returned from the Converse API
metadata,
Expand Down
61 changes: 26 additions & 35 deletions libs/langchain-aws/src/tests/chat_models.int.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ const baseConstructorArgs: Partial<
secretAccessKey: process.env.BEDROCK_AWS_SECRET_ACCESS_KEY!,
accessKeyId: process.env.BEDROCK_AWS_ACCESS_KEY_ID!,
},
maxRetries: 1,
};

test("Test ChatBedrockConverse can invoke", async () => {
Expand Down Expand Up @@ -81,25 +82,6 @@ test("Test ChatBedrockConverse with stop", async () => {
expect(res.content).not.toContain("world");
});

// AbortSignal not implemented yet.
test.skip("Test ChatBedrockConverse stream method with abort", async () => {
  // Streaming with an already-expiring AbortSignal should reject once
  // the timeout fires mid-stream.
  const consumeAbortedStream = async () => {
    const model = new ChatBedrockConverse({
      ...baseConstructorArgs,
      maxTokens: 100,
    });
    const stream = await model.stream(
      "How is your day going? Be extremely verbose.",
      {
        signal: AbortSignal.timeout(500),
      }
    );
    for await (const chunk of stream) {
      console.log(chunk);
    }
  };
  await expect(consumeAbortedStream).rejects.toThrow();
});

test("Test ChatBedrockConverse stream method with early break", async () => {
const model = new ChatBedrockConverse({
...baseConstructorArgs,
Expand All @@ -119,7 +101,10 @@ test("Test ChatBedrockConverse stream method with early break", async () => {
});

test("Streaming tokens can be found in usage_metadata field", async () => {
const model = new ChatBedrockConverse();
const model = new ChatBedrockConverse({
...baseConstructorArgs,
maxTokens: 5,
});
const response = await model.stream("Hello, how are you?");
let finalResult: AIMessageChunk | undefined;
for await (const chunk of response) {
Expand All @@ -140,28 +125,34 @@ test("Streaming tokens can be found in usage_metadata field", async () => {
});

test("populates ID field on AIMessage", async () => {
  // invoke() should surface the Bedrock request ID on the returned message.
  // Note: Bedrock IDs are UUIDs, NOT OpenAI-style "chatcmpl-" prefixed
  // strings, so we only assert that a non-trivial ID is present.
  const model = new ChatBedrockConverse({
    ...baseConstructorArgs,
    maxTokens: 5,
  });
  const response = await model.invoke("Hell");
  console.log({
    invokeId: response.id,
  });
  expect(response.id?.length).toBeGreaterThan(1);

  /**
   * Bedrock Converse does not include an ID in
   * the response of a streaming call, so the streaming
   * assertions below are intentionally disabled.
   */

  // Streaming
  // let finalChunk: AIMessageChunk | undefined;
  // for await (const chunk of await model.stream("Hell")) {
  //   if (!finalChunk) {
  //     finalChunk = chunk;
  //   } else {
  //     finalChunk = finalChunk.concat(chunk);
  //   }
  // }
  // console.log({
  //   streamId: finalChunk?.id,
  // });
  // expect(finalChunk?.id?.length).toBeGreaterThan(1);
});

test("Test ChatBedrockConverse can invoke tools", async () => {
Expand Down

0 comments on commit 49fe576

Please sign in to comment.