Fixed git cherry-pick mistakes in chat_models
BaNg-W committed Nov 16, 2024
1 parent b702fbe commit 9213bfd
Showing 1 changed file with 35 additions and 53 deletions.
88 changes: 35 additions & 53 deletions libs/langchain-mistralai/src/chat_models.ts
@@ -84,14 +84,7 @@ interface TokenUsage {
   totalTokens?: number;
 }
 
-export type MistralAIToolChoice = "auto" | "any" | "none";
-
-type MistralAIToolInput = { type: string; function: MistralAIFunction };
-
-type ChatMistralAIToolType =
-  | MistralAIToolInput
-  | MistralAITool
-  | BindToolsInput;
+type ChatMistralAIToolType = MistralAIToolCall | MistralAITool | BindToolsInput;
 
 export interface ChatMistralAICallOptions
   extends Omit<BaseLanguageModelCallOptions, "stop"> {
@@ -316,14 +309,14 @@ function convertMessagesToMistralMessages(
   );
 };
 
-const getTools = (message: BaseMessage): MistralAIToolCalls[] | undefined => {
+const getTools = (message: BaseMessage): MistralAIToolCall[] | undefined => {
   if (isAIMessage(message) && !!message.tool_calls?.length) {
     return message.tool_calls
       .map((toolCall) => ({
         ...toolCall,
         id: _convertToolCallIdToMistralCompatible(toolCall.id ?? ""),
       }))
-      .map(convertLangChainToolCallToOpenAI) as MistralAIToolCalls[];
+      .map(convertLangChainToolCallToOpenAI) as MistralAIToolCall[];
   }
   return undefined;
 };
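
Note on the hunk above: only the SDK type name changes (MistralAIToolCalls to MistralAIToolCall); the conversion itself still gives each LangChain tool call a Mistral-compatible id and then maps it to the OpenAI wire shape via convertLangChainToolCallToOpenAI from @langchain/core. A rough standalone sketch of what that mapping produces (the example tool call is invented):

import { convertLangChainToolCallToOpenAI } from "@langchain/core/output_parsers/openai_tools";
import type { ToolCall } from "@langchain/core/messages/tool";

// LangChain keeps tool-call args as a parsed object and allows arbitrary ids;
// Mistral is stricter about id format, hence _convertToolCallIdToMistralCompatible.
const toolCall: ToolCall = {
  name: "get_weather",
  args: { city: "Paris" },
  id: "abc123XYZ", // assumed already shortened to a Mistral-friendly alphanumeric id
};

// The OpenAI-style wire shape carries args as a JSON string, roughly:
// { id, type: "function", function: { name: "get_weather", arguments: '{"city":"Paris"}' } }
const wireToolCall = convertLangChainToolCallToOpenAI(toolCall);
console.log(wireToolCall);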
@@ -398,19 +391,12 @@ function mistralAIResponseToChatMessage(
       content,
       tool_calls: toolCalls,
       invalid_tool_calls: invalidToolCalls,
-      additional_kwargs: {
-        tool_calls: rawToolCalls.length
-          ? rawToolCalls.map((toolCall) => ({
-              ...toolCall,
-              type: "function",
-            }))
-          : undefined,
-      },
+      additional_kwargs: {},
       usage_metadata: usage
         ? {
-            input_tokens: usage.prompt_tokens,
-            output_tokens: usage.completion_tokens,
-            total_tokens: usage.total_tokens,
+            input_tokens: usage.promptTokens,
+            output_tokens: usage.completionTokens,
+            total_tokens: usage.totalTokens,
           }
         : undefined,
     });
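
The usage_metadata change here (and its twins in the two hunks below) tracks the v1 @mistralai/mistralai SDK, which reports token counts in camelCase where the v0 client used snake_case. A minimal sketch of the mapping, assuming a v1-style usage object:

// Assumed shape of the v1 SDK's usage object (camelCase fields).
interface UsageInfoLike {
  promptTokens?: number;
  completionTokens?: number;
  totalTokens?: number;
}

// Map SDK usage onto LangChain's snake_case usage_metadata fields,
// defaulting missing counts to 0 for simplicity.
function toUsageMetadata(usage: UsageInfoLike) {
  return {
    input_tokens: usage.promptTokens ?? 0,
    output_tokens: usage.completionTokens ?? 0,
    total_tokens: usage.totalTokens ?? 0,
  };
}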
@@ -434,9 +420,9 @@ function _convertDeltaToMessageChunk(
       content: "",
       usage_metadata: usage
         ? {
-            input_tokens: usage.prompt_tokens,
-            output_tokens: usage.completion_tokens,
-            total_tokens: usage.total_tokens,
+            input_tokens: usage.promptTokens,
+            output_tokens: usage.completionTokens,
+            total_tokens: usage.totalTokens,
           }
         : undefined,
     });
@@ -466,13 +452,15 @@ function _convertDeltaToMessageChunk(
   let additional_kwargs;
   const toolCallChunks: ToolCallChunk[] = [];
   if (rawToolCallChunksWithIndex !== undefined) {
-    additional_kwargs = {
-      tool_calls: rawToolCallChunksWithIndex,
-    };
     for (const rawToolCallChunk of rawToolCallChunksWithIndex) {
+      const rawArgs = rawToolCallChunk.function?.arguments;
+      const args =
+        rawArgs === undefined || typeof rawArgs === "string"
+          ? rawArgs
+          : JSON.stringify(rawArgs);
       toolCallChunks.push({
         name: rawToolCallChunk.function?.name,
-        args: rawToolCallChunk.function?.arguments,
+        args,
         id: rawToolCallChunk.id,
         index: rawToolCallChunk.index,
         type: "tool_call_chunk",
@@ -491,9 +479,9 @@
       additional_kwargs,
       usage_metadata: usage
         ? {
-            input_tokens: usage.prompt_tokens,
-            output_tokens: usage.completion_tokens,
-            total_tokens: usage.total_tokens,
+            input_tokens: usage.promptTokens,
+            output_tokens: usage.completionTokens,
+            total_tokens: usage.totalTokens,
           }
         : undefined,
     });
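
The rawArgs guard added two hunks up is there because LangChain's tool_call_chunk carries arguments as a string, while the v1 SDK can hand back either a string or an already-parsed object. Pulled out as a standalone helper, the new logic is just:

// Normalizes v1 SDK tool-call arguments (string | object | undefined)
// to the string form LangChain's tool_call_chunk expects.
function normalizeToolCallArgs(
  rawArgs: string | Record<string, unknown> | undefined
): string | undefined {
  return rawArgs === undefined || typeof rawArgs === "string"
    ? rawArgs
    : JSON.stringify(rawArgs);
}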
@@ -938,7 +926,6 @@ export class ChatMistralAI<
     this.temperature = fields?.temperature ?? this.temperature;
     this.topP = fields?.topP ?? this.topP;
     this.maxTokens = fields?.maxTokens ?? this.maxTokens;
-    this.safeMode = fields?.safeMode ?? this.safeMode;
     this.safePrompt = fields?.safePrompt ?? this.safePrompt;
     this.randomSeed = fields?.seed ?? fields?.randomSeed ?? this.seed;
     this.seed = this.randomSeed;
@@ -999,14 +986,13 @@ export class ChatMistralAI<
     const mistralAITools: Array<MistralAITool> | undefined = tools?.length
       ? _convertToolToMistralTool(tools)
       : undefined;
-    const params: Omit<ChatRequest, "messages"> = {
+    const params: Omit<MistralAIChatCompletionRequest, "messages"> = {
       model: this.model,
       tools: mistralAITools,
       temperature: this.temperature,
       maxTokens: this.maxTokens,
       topP: this.topP,
       randomSeed: this.seed,
-      safeMode: this.safeMode,
       safePrompt: this.safePrompt,
       toolChoice: tool_choice,
       responseFormat: response_format,
@@ -1035,10 +1021,10 @@ export class ChatMistralAI<
   async completionWithRetry(
     input: MistralAIChatCompletionStreamRequest,
     streaming: true
-  ): Promise<AsyncGenerator<ChatCompletionResponseChunk>>;
+  ): Promise<AsyncIterable<MistralAIChatCompletionEvent>>;
 
   async completionWithRetry(
-    input: ChatRequest,
+    input: MistralAIChatCompletionRequest,
     streaming: false
   ): Promise<MistralAIChatCompletionResponse>;
 
@@ -1067,14 +1053,19 @@ export class ChatMistralAI<
         | MistralAIChatCompletionResponse
         | AsyncIterable<MistralAIChatCompletionEvent>;
       if (streaming) {
-        res = client.chatStream(input);
+        res = await client.chat.stream(input);
       } else {
-        res = await client.chat(input);
+        res = await client.chat.complete(input);
       }
       return res;
       // eslint-disable-next-line @typescript-eslint/no-explicit-any
     } catch (e: any) {
-      if (e.message?.includes("status: 400")) {
+      console.log(e, e.status, e.code, e.statusCode, e.message);
+      if (
+        e.message?.includes("status: 400") ||
+        e.message?.toLowerCase().includes("status 400") ||
+        e.message?.includes("validation failed")
+      ) {
         e.status = 400;
       }
       throw e;
// Not streaming, so we can just call the API once.
const response = await this.completionWithRetry(input, false);

const {
completion_tokens: completionTokens,
prompt_tokens: promptTokens,
total_tokens: totalTokens,
} = response?.usage ?? {};
const { completionTokens, promptTokens, totalTokens } =
response?.usage ?? {};

if (completionTokens) {
tokenUsage.completionTokens =
@@ -1158,8 +1146,8 @@ export class ChatMistralAI<
         text,
         message: mistralAIResponseToChatMessage(part, response?.usage),
       };
-      if (part.finish_reason) {
-        generation.generationInfo = { finish_reason: part.finish_reason };
+      if (part.finishReason) {
+        generation.generationInfo = { finishReason: part.finishReason };
       }
       generations.push(generation);
     }
@@ -1182,7 +1170,7 @@ export class ChatMistralAI<
     };
 
     const streamIterable = await this.completionWithRetry(input, true);
-    for await (const data of streamIterable) {
+    for await (const { data } of streamIterable) {
       if (options.signal?.aborted) {
         throw new Error("AbortError");
       }
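
The { data } destructuring above matches the v1 SDK's stream shape: iterating the stream yields event envelopes, and the actual completion chunk sits on their data field. A type-level sketch of that unwrapping (the envelope interface is an assumption for illustration, not the SDK's real type):

// Assumed v1-style event envelope: the chunk lives under `data`.
interface CompletionEventLike<TChunk> {
  data: TChunk;
}

// Generic helper that unwraps an event stream into a stream of chunks.
async function* unwrapEvents<TChunk>(
  events: AsyncIterable<CompletionEventLike<TChunk>>
): AsyncGenerator<TChunk> {
  for await (const { data } of events) {
    yield data;
  }
}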
@@ -1445,12 +1433,6 @@ export class ChatMistralAI<
       parsedWithFallback,
     ]);
   }
-
-  /** @ignore */
-  private async imports() {
-    const { default: MistralClient } = await import("@mistralai/mistralai");
-    return { MistralClient };
-  }
 }

function isZodSchema<
