groq[minor]: Fix streaming metadata back to client (#6573)
* groq[minor]: Fix streaming metadata back to client

* chore: lint files

* chore: lint files

* implemented usage metadata for invoke too
bracesproul authored Aug 20, 2024
1 parent 6786832 commit a402534
Showing 2 changed files with 48 additions and 20 deletions.
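
In practice, the fix means token usage now reaches the caller on both code paths: `invoke()` attaches `usage_metadata` to the returned `AIMessage`, and `stream()` attaches it to chunks, plus a trailing empty chunk carrying the raw `response_metadata`. A minimal sketch of the post-fix behavior; the model name and `GROQ_API_KEY` setup are illustrative assumptions, not part of this diff:

```ts
import { ChatGroq } from "@langchain/groq";
import { AIMessageChunk } from "@langchain/core/messages";

// Assumes GROQ_API_KEY is set; "llama3-8b-8192" is an illustrative model name.
const model = new ChatGroq({ model: "llama3-8b-8192" });

// invoke(): the returned AIMessage now carries usage_metadata.
const result = await model.invoke("Hello!");
console.log(result.usage_metadata);
// => { input_tokens: ..., output_tokens: ..., total_tokens: ... }

// stream(): chunks carry usage_metadata when Groq's x_groq.usage is present,
// and a final empty chunk carries the raw response_metadata.
let merged: AIMessageChunk | undefined;
for await (const chunk of await model.stream("Hello!")) {
  merged = merged ? merged.concat(chunk) : chunk;
}
console.log(merged?.usage_metadata);
console.log(merged?.response_metadata);
```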
52 changes: 48 additions & 4 deletions libs/langchain-groq/src/chat_models.ts
@@ -23,6 +23,7 @@ import {
   OpenAIToolCall,
   isAIMessage,
   BaseMessageChunk,
+  UsageMetadata,
 } from "@langchain/core/messages";
 import {
   ChatGeneration,
@@ -179,7 +180,8 @@ function convertMessagesToGroqParams(
 }
 
 function groqResponseToChatMessage(
-  message: ChatCompletionsAPI.ChatCompletionMessage
+  message: ChatCompletionsAPI.ChatCompletionMessage,
+  usageMetadata?: UsageMetadata
 ): BaseMessage {
   const rawToolCalls: OpenAIToolCall[] | undefined = message.tool_calls as
     | OpenAIToolCall[]
@@ -201,6 +203,7 @@ function groqResponseToChatMessage(
         additional_kwargs: { tool_calls: rawToolCalls },
         tool_calls: toolCalls,
         invalid_tool_calls: invalidToolCalls,
+        usage_metadata: usageMetadata,
       });
     }
     default:
@@ -226,7 +229,8 @@ function _convertDeltaToolCallToToolCallChunk(
 function _convertDeltaToMessageChunk(
   // eslint-disable-next-line @typescript-eslint/no-explicit-any
   delta: Record<string, any>,
-  index: number
+  index: number,
+  xGroq?: ChatCompletionsAPI.ChatCompletionChunk.XGroq
 ): {
   message: BaseMessageChunk;
   toolCallData?: {
@@ -250,6 +254,18 @@ function _convertDeltaToMessageChunk(
   } else {
     additional_kwargs = {};
   }
+
+  let usageMetadata: UsageMetadata | undefined;
+  let groqMessageId: string | undefined;
+  if (xGroq?.usage) {
+    usageMetadata = {
+      input_tokens: xGroq.usage.prompt_tokens,
+      output_tokens: xGroq.usage.completion_tokens,
+      total_tokens: xGroq.usage.total_tokens,
+    };
+    groqMessageId = xGroq.id;
+  }
+
   if (role === "user") {
     return {
       message: new HumanMessageChunk({ content }),
@@ -270,6 +286,8 @@ function _convertDeltaToMessageChunk(
             index: tc.index,
           }))
         : undefined,
+      usage_metadata: usageMetadata,
+      id: groqMessageId,
     }),
     toolCallData: toolCallChunks
       ? toolCallChunks.map((tc) => ({
@@ -771,7 +789,10 @@ export class ChatGroq extends BaseChatModel<
       index: number;
       type: "tool_call_chunk";
     }[] = [];
+    // eslint-disable-next-line @typescript-eslint/no-explicit-any
+    let responseMetadata: Record<string, any> | undefined;
     for await (const data of response) {
+      responseMetadata = data;
       const choice = data?.choices[0];
       if (!choice) {
         continue;
@@ -787,7 +808,8 @@
           ...choice.delta,
           role,
         } ?? {},
-        choice.index
+        choice.index,
+        data.x_groq
       );
 
       if (toolCallData) {
@@ -818,6 +840,19 @@
       void runManager?.handleLLMNewToken(chunk.text ?? "");
     }
 
+    if (responseMetadata) {
+      if ("choices" in responseMetadata) {
+        delete responseMetadata.choices;
+      }
+      yield new ChatGenerationChunk({
+        message: new AIMessageChunk({
+          content: "",
+          response_metadata: responseMetadata,
+        }),
+        text: "",
+      });
+    }
+
     if (options.signal?.aborted) {
       throw new Error("AbortError");
     }
@@ -898,10 +933,19 @@
     if ("choices" in data && data.choices) {
       for (const part of (data as ChatCompletion).choices) {
         const text = part.message?.content ?? "";
+        let usageMetadata: UsageMetadata | undefined;
+        if (tokenUsage.totalTokens !== undefined) {
+          usageMetadata = {
+            input_tokens: tokenUsage.promptTokens ?? 0,
+            output_tokens: tokenUsage.completionTokens ?? 0,
+            total_tokens: tokenUsage.totalTokens,
+          };
+        }
         const generation: ChatGeneration = {
           text,
           message: groqResponseToChatMessage(
-            part.message ?? { role: "assistant" }
+            part.message ?? { role: "assistant" },
+            usageMetadata
           ),
         };
         generation.generationInfo = {
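
The heart of the streaming fix is the rename from Groq's vendor-specific `x_groq.usage` fields to LangChain's provider-agnostic `UsageMetadata`. A standalone sketch of that mapping, with the types inlined for illustration (the real code imports `ChatCompletionsAPI.ChatCompletionChunk.XGroq` from groq-sdk and `UsageMetadata` from `@langchain/core/messages`):

```ts
// Shapes inlined for illustration only; see the imports in the diff above.
interface XGroq {
  id?: string;
  usage?: {
    prompt_tokens: number;
    completion_tokens: number;
    total_tokens: number;
  };
}

interface UsageMetadata {
  input_tokens: number;
  output_tokens: number;
  total_tokens: number;
}

// Groq reports usage under x_groq (typically on the last stream chunk, an
// assumption based on this diff's handling); the converter renames the
// prompt/completion token counts to LangChain's input/output convention.
function xGroqToUsageMetadata(xGroq?: XGroq): UsageMetadata | undefined {
  if (!xGroq?.usage) return undefined;
  return {
    input_tokens: xGroq.usage.prompt_tokens,
    output_tokens: xGroq.usage.completion_tokens,
    total_tokens: xGroq.usage.total_tokens,
  };
}
```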
16 changes: 0 additions & 16 deletions libs/langchain-groq/src/tests/chat_models.standard.int.test.ts
@@ -25,22 +25,6 @@ class ChatGroqStandardIntegrationTests extends ChatModelIntegrationTests<
     });
   }
 
-  async testUsageMetadataStreaming() {
-    this.skipTestMessage(
-      "testUsageMetadataStreaming",
-      "ChatGroq",
-      "Streaming tokens is not currently supported."
-    );
-  }
-
-  async testUsageMetadata() {
-    this.skipTestMessage(
-      "testUsageMetadata",
-      "ChatGroq",
-      "Usage metadata tokens is not currently supported."
-    );
-  }
-
   async testToolMessageHistoriesListContent() {
     this.skipTestMessage(
       "testToolMessageHistoriesListContent",
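
Deleting these two overrides means the inherited `testUsageMetadata` and `testUsageMetadataStreaming` standard tests now run against ChatGroq instead of being skipped. In spirit they assert something like the following (a sketch; the actual checks live in the shared `ChatModelIntegrationTests` base class, and `model` stands in for the suite's chat model instance):

```ts
// Sketch only; the real assertions are implemented by ChatModelIntegrationTests.
const msg = await model.invoke("Hello");
expect(msg.usage_metadata).toBeDefined();
expect(typeof msg.usage_metadata!.input_tokens).toBe("number");
expect(typeof msg.usage_metadata!.output_tokens).toBe("number");
expect(typeof msg.usage_metadata!.total_tokens).toBe("number");
```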
