Skip to content

Commit

Permalink
implemented mistral
Browse files Browse the repository at this point in the history
  • Loading branch information
bracesproul committed Jun 12, 2024
1 parent 76fd7fe commit 8230a42
Showing 1 changed file with 32 additions and 35 deletions.
67 changes: 32 additions & 35 deletions libs/langchain-mistralai/src/chat_models.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import { v4 as uuidv4 } from "uuid";
import {
ChatCompletionResponse,
Function as MistralAIFunction,
ToolCalls as MistralAIToolCalls,
ResponseFormat,
ChatCompletionResponseChunk,
Expand Down Expand Up @@ -29,6 +28,7 @@ import type {
StructuredOutputMethodParams,
StructuredOutputMethodOptions,
FunctionDefinition,
ToolDefinition,
} from "@langchain/core/language_models/base";
import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager";
import {
Expand Down Expand Up @@ -63,6 +63,8 @@ import {
RunnableSequence,
} from "@langchain/core/runnables";
import { zodToJsonSchema } from "zod-to-json-schema";
import { isStructuredTool } from "@langchain/core/utils/function_calling";
import { isOpenAITool } from "@langchain/core/utils/is_openai_tool";

interface TokenUsage {
completionTokens?: number;
Expand All @@ -72,18 +74,14 @@ interface TokenUsage {

export type MistralAIToolChoice = "auto" | "any" | "none";

type MistralAIToolInput = { type: string; function: MistralAIFunction };
interface MistralAICallOptions
extends Omit<BaseLanguageModelCallOptions, "stop"> {
export interface ChatMistralAICallOptions extends Omit<BaseLanguageModelCallOptions, "stop"> {
response_format?: {
type: "text" | "json_object";
};
tools: StructuredToolInterface[] | MistralAIToolInput[] | MistralAITool[];
tools?: (StructuredToolInterface | MistralAITool | ToolDefinition | Record<string, unknown>)[];
tool_choice?: MistralAIToolChoice;
}

export interface ChatMistralAICallOptions extends MistralAICallOptions {}

/**
* Input to chat model class.
*/
Expand Down Expand Up @@ -346,11 +344,34 @@ function _convertStructuredToolToMistralTool(
});
}

function formatTools(tools: Required<ChatMistralAICallOptions["tools"]>): MistralAITool[] {
if (!tools || !tools.length) {
return [];
}
return tools.map((tool) => {
if (isStructuredTool(tool)) {
return _convertStructuredToolToMistralTool([tool as StructuredTool]);
}
if (isOpenAITool(tool)) {
return {
type: "function",
function: {
name: tool.function.name,
description: tool.function.description ?? `Tool: ${tool.function.name}`,
parameters: tool.function.parameters,
},
} as MistralAITool;
}
return tool as MistralAITool;
})
.flat();
}

/**
* Integration with a chat model.
*/
export class ChatMistralAI<
CallOptions extends MistralAICallOptions = MistralAICallOptions
CallOptions extends ChatMistralAICallOptions = ChatMistralAICallOptions
>
extends BaseChatModel<CallOptions, AIMessageChunk>
implements ChatMistralAIInput
Expand Down Expand Up @@ -433,24 +454,7 @@ export class ChatMistralAI<
options?: this["ParsedCallOptions"]
): Omit<ChatRequest, "messages"> {
const { response_format, tools, tool_choice } = options ?? {};
const mistralAITools: Array<MistralAITool> | undefined = tools
?.map((tool) => {
if ("lc_namespace" in tool) {
return _convertStructuredToolToMistralTool([tool]);
}
if (!tool.function.description) {
return {
type: "function",
function: {
name: tool.function.name,
description: `Tool: ${tool.function.name}`,
parameters: tool.function.parameters,
},
} as MistralAITool;
}
return tool as MistralAITool;
})
.flat();
const mistralAITools = formatTools(tools) ?? undefined;
const params: Omit<ChatRequest, "messages"> = {
model: this.model,
tools: mistralAITools,
Expand All @@ -467,17 +471,10 @@ export class ChatMistralAI<
}

override bindTools(
tools: (Record<string, unknown> | StructuredToolInterface)[],
tools: (StructuredToolInterface | MistralAITool | ToolDefinition | Record<string, unknown>)[],
kwargs?: Partial<CallOptions>
): Runnable<BaseLanguageModelInput, AIMessageChunk, CallOptions> {
const mistralAITools = tools
?.map((tool) => {
if ("lc_namespace" in tool) {
return _convertStructuredToolToMistralTool([tool as StructuredTool]);
}
return tool;
})
.flat();
const mistralAITools = formatTools(tools)
return this.bind({
tools: mistralAITools,
...kwargs,
Expand Down

0 comments on commit 8230a42

Please sign in to comment.