mistral[minor]: Populate usage metadata for mistral (#5751)
* mistral[minor]: Populate usage metadata for mistral

* chore: lint files

* chore: lint files

* fix build issue

* bump min core dep
bracesproul authored Jun 13, 2024
1 parent ce4b8ef commit 4a42922
Showing 5 changed files with 122 additions and 27 deletions.
2 changes: 1 addition & 1 deletion libs/langchain-mistralai/package.json
@@ -35,7 +35,7 @@
"author": "LangChain",
"license": "MIT",
"dependencies": {
"@langchain/core": ">0.1.56 <0.3.0",
"@langchain/core": ">=0.2.5 <0.3.0",
"@mistralai/mistralai": "^0.4.0",
"uuid": "^9.0.0",
"zod": "^3.22.4",
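For context, the dependency floor moves to a core release that exposes usage metadata on messages. A minimal sketch of the shape this commit populates, based only on the fields used in the chat_models.ts diff below (see @langchain/core for the authoritative UsageMetadata type):

// Approximate shape of the usage_metadata this commit attaches to
// AIMessage / AIMessageChunk. Field names are taken from the diff below.
interface UsageMetadataSketch {
  input_tokens: number; // tokens consumed by the prompt
  output_tokens: number; // tokens generated in the completion
  total_tokens: number; // input_tokens + output_tokens
}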
63 changes: 54 additions & 9 deletions libs/langchain-mistralai/src/chat_models.ts
@@ -8,6 +8,7 @@ import {
ChatRequest,
Tool as MistralAITool,
Message as MistralAIMessage,
TokenUsage as MistralAITokenUsage,
} from "@mistralai/mistralai";
import {
MessageType,
@@ -80,14 +81,21 @@ interface MistralAICallOptions
};
tools: StructuredToolInterface[] | MistralAIToolInput[] | MistralAITool[];
tool_choice?: MistralAIToolChoice;
/**
* Whether or not to include token usage in the stream.
* @default {true}
*/
streamUsage?: boolean;
}

export interface ChatMistralAICallOptions extends MistralAICallOptions {}

/**
* Input to chat model class.
*/
export interface ChatMistralAIInput extends BaseChatModelParams {
export interface ChatMistralAIInput
extends BaseChatModelParams,
Pick<ChatMistralAICallOptions, "streamUsage"> {
/**
* The API key to use.
* @default {process.env.MISTRAL_API_KEY}
@@ -216,7 +224,8 @@ function convertMessagesToMistralMessages(
}

function mistralAIResponseToChatMessage(
choice: ChatCompletionResponse["choices"][0]
choice: ChatCompletionResponse["choices"][0],
usage?: MistralAITokenUsage
): BaseMessage {
const { message } = choice;
// MistralAI SDK does not include tool_calls in the non
@@ -254,19 +263,41 @@
}))
: undefined,
},
usage_metadata: usage
? {
input_tokens: usage.prompt_tokens,
output_tokens: usage.completion_tokens,
total_tokens: usage.total_tokens,
}
: undefined,
});
}
default:
return new HumanMessage(message.content ?? "");
}
}

function _convertDeltaToMessageChunk(delta: {
role?: string | undefined;
content?: string | undefined;
tool_calls?: MistralAIToolCalls[] | undefined;
}) {
function _convertDeltaToMessageChunk(
delta: {
role?: string | undefined;
content?: string | undefined;
tool_calls?: MistralAIToolCalls[] | undefined;
},
usage?: MistralAITokenUsage | null
) {
if (!delta.content && !delta.tool_calls) {
if (usage) {
return new AIMessageChunk({
content: "",
usage_metadata: usage
? {
input_tokens: usage.prompt_tokens,
output_tokens: usage.completion_tokens,
total_tokens: usage.total_tokens,
}
: undefined,
});
}
return null;
}
// Our merge additional kwargs util function will throw unless there
@@ -313,6 +344,13 @@ function _convertDeltaToMessageChunk(delta: {
content,
tool_call_chunks: toolCallChunks,
additional_kwargs,
usage_metadata: usage
? {
input_tokens: usage.prompt_tokens,
output_tokens: usage.completion_tokens,
total_tokens: usage.total_tokens,
}
: undefined,
});
} else if (role === "tool") {
return new ToolMessageChunk({
@@ -389,6 +427,8 @@ export class ChatMistralAI<

lc_serializable = true;

streamUsage = true;

constructor(fields?: ChatMistralAIInput) {
super(fields ?? {});
const apiKey = fields?.apiKey ?? getEnvironmentVariable("MISTRAL_API_KEY");
@@ -409,6 +449,7 @@
this.seed = this.randomSeed;
this.modelName = fields?.model ?? fields?.modelName ?? this.model;
this.model = this.modelName;
this.streamUsage = fields?.streamUsage ?? this.streamUsage;
}

getLsParams(options: this["ParsedCallOptions"]): LangSmithParams {
@@ -600,7 +641,7 @@ export class ChatMistralAI<
const text = part.message?.content ?? "";
const generation: ChatGeneration = {
text,
message: mistralAIResponseToChatMessage(part),
message: mistralAIResponseToChatMessage(part, response?.usage),
};
if (part.finish_reason) {
generation.generationInfo = { finish_reason: part.finish_reason };
@@ -643,7 +684,11 @@
prompt: 0,
completion: choice.index ?? 0,
};
const message = _convertDeltaToMessageChunk(delta);
const shouldStreamUsage = this.streamUsage || options.streamUsage;
const message = _convertDeltaToMessageChunk(
delta,
shouldStreamUsage ? data.usage : null
);
if (message === null) {
// Do not yield a chunk if the message is empty
continue;
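Taken together, the chat_models.ts changes thread Mistral's token usage through to LangChain's usage_metadata on both the invoke and streaming paths. A rough caller-side sketch of the resulting behavior; the demo wrapper is illustrative, and the model name and prompt are borrowed from the tests added below:

import { ChatMistralAI } from "@langchain/mistralai";
import { AIMessageChunk } from "@langchain/core/messages";

async function demo() {
  const model = new ChatMistralAI({ model: "codestral-latest", temperature: 0 });

  // Non-streaming: usage metadata is attached to the returned AIMessage.
  const res = await model.invoke("Why is the sky blue? Be concise.");
  console.log(res.usage_metadata);
  // e.g. { input_tokens: 13, output_tokens: ..., total_tokens: ... }

  // Streaming: usage is emitted alongside the stream (typically with the
  // final chunk); concatenating chunks yields an AIMessageChunk whose
  // usage_metadata covers the whole call.
  let aggregate: AIMessageChunk | undefined;
  for await (const chunk of await model.stream("Why is the sky blue? Be concise.")) {
    aggregate = aggregate ? aggregate.concat(chunk) : chunk;
  }
  console.log(aggregate?.usage_metadata);
}

demo().catch(console.error);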
66 changes: 66 additions & 0 deletions libs/langchain-mistralai/src/tests/chat_models.int.test.ts
@@ -11,6 +11,7 @@ import { DynamicStructuredTool, StructuredTool } from "@langchain/core/tools";
import { z } from "zod";
import {
AIMessage,
AIMessageChunk,
BaseMessage,
HumanMessage,
ToolMessage,
@@ -916,3 +917,68 @@ describe("codestral-latest", () => {
console.log(parsedArgs.code);
});
});

test("Stream token count usage_metadata", async () => {
const model = new ChatMistralAI({
model: "codestral-latest",
temperature: 0,
});
let res: AIMessageChunk | null = null;
for await (const chunk of await model.stream(
"Why is the sky blue? Be concise."
)) {
if (!res) {
res = chunk;
} else {
res = res.concat(chunk);
}
}
console.log(res);
expect(res?.usage_metadata).toBeDefined();
if (!res?.usage_metadata) {
return;
}
expect(res.usage_metadata.input_tokens).toBe(13);
expect(res.usage_metadata.output_tokens).toBeGreaterThan(10);
expect(res.usage_metadata.total_tokens).toBe(
res.usage_metadata.input_tokens + res.usage_metadata.output_tokens
);
});

test("streamUsage excludes token usage", async () => {
const model = new ChatMistralAI({
model: "codestral-latest",
temperature: 0,
streamUsage: false,
});
let res: AIMessageChunk | null = null;
for await (const chunk of await model.stream(
"Why is the sky blue? Be concise."
)) {
if (!res) {
res = chunk;
} else {
res = res.concat(chunk);
}
}
console.log(res);
expect(res?.usage_metadata).not.toBeDefined();
});

test("Invoke token count usage_metadata", async () => {
const model = new ChatMistralAI({
model: "codestral-latest",
temperature: 0,
});
const res = await model.invoke("Why is the sky blue? Be concise.");
console.log(res);
expect(res?.usage_metadata).toBeDefined();
if (!res?.usage_metadata) {
return;
}
expect(res.usage_metadata.input_tokens).toBe(13);
expect(res.usage_metadata.output_tokens).toBeGreaterThan(10);
expect(res.usage_metadata.total_tokens).toBe(
res.usage_metadata.input_tokens + res.usage_metadata.output_tokens
);
});
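The new integration tests pin the prompt's input token count and exercise the constructor-level opt-out. Since streamUsage is also declared on the call options, a per-call value type-checks as well; a hedged sketch (the wrapper function is illustrative), noting that the implementation combines the two flags with ||:

import { ChatMistralAI } from "@langchain/mistralai";

async function streamWithUsageOptions() {
  // Constructor-level default is true unless overridden.
  const model = new ChatMistralAI({ model: "codestral-latest", temperature: 0 });

  // Per-call option, as declared on ChatMistralAICallOptions. Because the
  // streaming code computes `this.streamUsage || options.streamUsage`,
  // passing true here enables usage even on a model constructed with
  // streamUsage: false, but passing false cannot disable a
  // constructor-level true.
  const stream = await model.stream("Why is the sky blue? Be concise.", {
    streamUsage: true,
  });
  for await (const chunk of stream) {
    if (chunk.usage_metadata) {
      console.log(chunk.usage_metadata);
    }
  }
}

streamWithUsageOptions().catch(console.error);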
16 changes: 0 additions & 16 deletions libs/langchain-mistralai/src/tests/chat_models.standard.int.test.ts
@@ -23,22 +23,6 @@ class ChatMistralAIStandardIntegrationTests extends ChatModelIntegrationTests<
functionId: "123456789",
});
}

async testUsageMetadataStreaming() {
this.skipTestMessage(
"testUsageMetadataStreaming",
"ChatMistralAI",
"Streaming tokens is not currently supported."
);
}

async testUsageMetadata() {
this.skipTestMessage(
"testUsageMetadata",
"ChatMistralAI",
"Usage metadata tokens is not currently supported."
);
}
}

const testClass = new ChatMistralAIStandardIntegrationTests();
2 changes: 1 addition & 1 deletion yarn.lock
@@ -10444,7 +10444,7 @@ __metadata:
resolution: "@langchain/mistralai@workspace:libs/langchain-mistralai"
dependencies:
"@jest/globals": ^29.5.0
"@langchain/core": ">0.1.56 <0.3.0"
"@langchain/core": ">=0.2.5 <0.3.0"
"@langchain/scripts": ~0.0.14
"@langchain/standard-tests": 0.0.0
"@mistralai/mistralai": ^0.4.0
