From 49fe57658588af522bf4a27a9b05b082b3ea3f09 Mon Sep 17 00:00:00 2001 From: bracesproul Date: Thu, 27 Jun 2024 13:42:07 -0700 Subject: [PATCH] final --- libs/langchain-aws/src/chat_models.ts | 24 +++++++- libs/langchain-aws/src/common.ts | 18 +++++- .../src/tests/chat_models.int.test.ts | 61 ++++++++----------- 3 files changed, 63 insertions(+), 40 deletions(-) diff --git a/libs/langchain-aws/src/chat_models.ts b/libs/langchain-aws/src/chat_models.ts index b36266aa5919..3df53aeefa23 100644 --- a/libs/langchain-aws/src/chat_models.ts +++ b/libs/langchain-aws/src/chat_models.ts @@ -111,11 +111,21 @@ export interface ChatBedrockConverseInput * @link https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html */ additionalModelRequestFields?: __DocumentType; + /** + * Whether or not to include usage data, like token counts + * in the streamed response chunks. Passing as a call option will + * take precedence over the class-level setting. + * @default true + */ + streamUsage?: boolean; } export interface ChatBedrockConverseCallOptions extends BaseLanguageModelCallOptions, - Pick { + Pick< + ChatBedrockConverseInput, + "additionalModelRequestFields" | "streamUsage" + > { /** * A list of stop sequences. A stop sequence is a sequence of characters that causes * the model to stop generating the response. @@ -181,6 +191,8 @@ export class ChatBedrockConverse additionalModelRequestFields?: __DocumentType; + streamUsage = true; + client: BedrockRuntimeClient; constructor(fields?: ChatBedrockConverseInput) { @@ -231,6 +243,7 @@ export class ChatBedrockConverse this.endpointHost = rest?.endpointHost; this.topP = rest?.topP; this.additionalModelRequestFields = rest?.additionalModelRequestFields; + this.streamUsage = rest?.streamUsage ?? this.streamUsage; } getLsParams(options: this["ParsedCallOptions"]): LangSmithParams { @@ -369,7 +382,10 @@ export class ChatBedrockConverse const { converseMessages, converseSystem } = convertToConverseMessages(messages); const params = this.invocationParams(options); - + let { streamUsage } = this; + if (options.streamUsage !== undefined) { + streamUsage = options.streamUsage; + } const command = new ConverseStreamCommand({ modelId: this.model, messages: converseMessages, @@ -388,7 +404,9 @@ export class ChatBedrockConverse yield textChatGeneration; await runManager?.handleLLMNewToken(textChatGeneration.text); } else if (chunk.metadata) { - yield handleConverseStreamMetadata(chunk.metadata); + yield handleConverseStreamMetadata(chunk.metadata, { + streamUsage, + }); } else { yield new ChatGenerationChunk({ text: "", diff --git a/libs/langchain-aws/src/common.ts b/libs/langchain-aws/src/common.ts index b440fdb9dc56..945fa932f43f 100644 --- a/libs/langchain-aws/src/common.ts +++ b/libs/langchain-aws/src/common.ts @@ -284,6 +284,15 @@ export function convertConverseMessageToLangChainMessage( `Unsupported message role received in ChatBedrockConverse response: ${message.role}` ); } + let requestId: string | undefined; + if ( + "$metadata" in responseMetadata && + responseMetadata.$metadata && + typeof responseMetadata.$metadata === "object" && + "requestId" in responseMetadata.$metadata + ) { + requestId = responseMetadata.$metadata.requestId as string; + } let tokenUsage: UsageMetadata | undefined; if (responseMetadata.usage) { const input_tokens = responseMetadata.usage.inputTokens ?? 0; @@ -305,6 +314,7 @@ export function convertConverseMessageToLangChainMessage( content: message.content[0].text, response_metadata: responseMetadata, usage_metadata: tokenUsage, + id: requestId, }); } else { const toolCalls: ToolCall[] = []; @@ -333,6 +343,7 @@ export function convertConverseMessageToLangChainMessage( tool_calls: toolCalls.length ? toolCalls : undefined, response_metadata: responseMetadata, usage_metadata: tokenUsage, + id: requestId, }); } } @@ -397,7 +408,10 @@ export function handleConverseStreamContentBlockStart( } export function handleConverseStreamMetadata( - metadata: ConverseStreamMetadataEvent + metadata: ConverseStreamMetadataEvent, + extra: { + streamUsage: boolean; + } ): ChatGenerationChunk { const inputTokens = metadata.usage?.inputTokens ?? 0; const outputTokens = metadata.usage?.outputTokens ?? 0; @@ -410,7 +424,7 @@ export function handleConverseStreamMetadata( text: "", message: new AIMessageChunk({ content: "", - usage_metadata, + usage_metadata: extra.streamUsage ? usage_metadata : undefined, response_metadata: { // Use the same key as returned from the Converse API metadata, diff --git a/libs/langchain-aws/src/tests/chat_models.int.test.ts b/libs/langchain-aws/src/tests/chat_models.int.test.ts index 17312c50322b..bd8d547b90c2 100644 --- a/libs/langchain-aws/src/tests/chat_models.int.test.ts +++ b/libs/langchain-aws/src/tests/chat_models.int.test.ts @@ -13,6 +13,7 @@ const baseConstructorArgs: Partial< secretAccessKey: process.env.BEDROCK_AWS_SECRET_ACCESS_KEY!, accessKeyId: process.env.BEDROCK_AWS_ACCESS_KEY_ID!, }, + maxRetries: 1, }; test("Test ChatBedrockConverse can invoke", async () => { @@ -81,25 +82,6 @@ test("Test ChatBedrockConverse with stop", async () => { expect(res.content).not.toContain("world"); }); -// AbortSignal not implemented yet. -test.skip("Test ChatBedrockConverse stream method with abort", async () => { - await expect(async () => { - const model = new ChatBedrockConverse({ - ...baseConstructorArgs, - maxTokens: 100, - }); - const stream = await model.stream( - "How is your day going? Be extremely verbose.", - { - signal: AbortSignal.timeout(500), - } - ); - for await (const chunk of stream) { - console.log(chunk); - } - }).rejects.toThrow(); -}); - test("Test ChatBedrockConverse stream method with early break", async () => { const model = new ChatBedrockConverse({ ...baseConstructorArgs, @@ -119,7 +101,10 @@ test("Test ChatBedrockConverse stream method with early break", async () => { }); test("Streaming tokens can be found in usage_metadata field", async () => { - const model = new ChatBedrockConverse(); + const model = new ChatBedrockConverse({ + ...baseConstructorArgs, + maxTokens: 5, + }); const response = await model.stream("Hello, how are you?"); let finalResult: AIMessageChunk | undefined; for await (const chunk of response) { @@ -140,28 +125,34 @@ test("Streaming tokens can be found in usage_metadata field", async () => { }); test("populates ID field on AIMessage", async () => { - const model = new ChatBedrockConverse(); + const model = new ChatBedrockConverse({ + ...baseConstructorArgs, + maxTokens: 5, + }); const response = await model.invoke("Hell"); console.log({ invokeId: response.id, }); expect(response.id?.length).toBeGreaterThan(1); - expect(response?.id?.startsWith("chatcmpl-")).toBe(true); + + /** + * Bedrock Converse does not include an ID in + * the response of a streaming call. + */ // Streaming - let finalChunk: AIMessageChunk | undefined; - for await (const chunk of await model.stream("Hell")) { - if (!finalChunk) { - finalChunk = chunk; - } else { - finalChunk = finalChunk.concat(chunk); - } - } - console.log({ - streamId: finalChunk?.id, - }); - expect(finalChunk?.id?.length).toBeGreaterThan(1); - expect(finalChunk?.id?.startsWith("chatcmpl-")).toBe(true); + // let finalChunk: AIMessageChunk | undefined; + // for await (const chunk of await model.stream("Hell")) { + // if (!finalChunk) { + // finalChunk = chunk; + // } else { + // finalChunk = finalChunk.concat(chunk); + // } + // } + // console.log({ + // streamId: finalChunk?.id, + // }); + // expect(finalChunk?.id?.length).toBeGreaterThan(1); }); test("Test ChatBedrockConverse can invoke tools", async () => {