diff --git a/.yarn/cache/@mistralai-mistralai-npm-0.1.3-934551f985-3f8299811b.zip b/.yarn/cache/@mistralai-mistralai-npm-0.4.0-843348d71c-1b03fc0b55.zip similarity index 93% rename from .yarn/cache/@mistralai-mistralai-npm-0.1.3-934551f985-3f8299811b.zip rename to .yarn/cache/@mistralai-mistralai-npm-0.4.0-843348d71c-1b03fc0b55.zip index e786ddf800a6..4c7f5fabf34a 100644 Binary files a/.yarn/cache/@mistralai-mistralai-npm-0.1.3-934551f985-3f8299811b.zip and b/.yarn/cache/@mistralai-mistralai-npm-0.4.0-843348d71c-1b03fc0b55.zip differ diff --git a/docs/core_docs/docs/integrations/chat/mistral.mdx b/docs/core_docs/docs/integrations/chat/mistral.mdx index 645e0e56b012..2566ee06c223 100644 --- a/docs/core_docs/docs/integrations/chat/mistral.mdx +++ b/docs/core_docs/docs/integrations/chat/mistral.mdx @@ -7,21 +7,16 @@ import CodeBlock from "@theme/CodeBlock"; # ChatMistralAI [Mistral AI](https://mistral.ai/) is a research organization and hosting platform for LLMs. -They're most known for their family of 7B models ([`mistral7b` // `mistral-tiny`](https://mistral.ai/news/announcing-mistral-7b/), [`mixtral8x7b` // `mistral-small`](https://mistral.ai/news/mixtral-of-experts/)). - The LangChain implementation of Mistral's models uses their hosted generation API, making it easier to access their models without needing to run them locally. -## Models - -Mistral's API offers access to two of their open source, and proprietary models: +:::tip +Want to run Mistral's models locally? Check out our [Ollama integration](/docs/integrations/chat/ollama). +::: -- `open-mistral-7b` (aka `mistral-tiny-2312`) -- `open-mixtral-8x7b` (aka `mistral-small-2312`) -- `mistral-small-latest` (aka `mistral-small-2402`) (default) -- `mistral-medium-latest` (aka `mistral-medium-2312`) -- `mistral-large-latest` (aka `mistral-large-2402`) +## Models -See [this page](https://docs.mistral.ai/guides/model-selection/) for an up to date list. +Mistral's API offers access to two of their open source, and proprietary models. +See [this page](https://docs.mistral.ai/getting-started/models/) for an up to date list. ## Setup diff --git a/docs/core_docs/docs/integrations/llms/mistral.ipynb b/docs/core_docs/docs/integrations/llms/mistral.ipynb new file mode 100644 index 000000000000..6cd3bfec377a --- /dev/null +++ b/docs/core_docs/docs/integrations/llms/mistral.ipynb @@ -0,0 +1,160 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# MistralAI\n", + "\n", + "```{=mdx}\n", + ":::tip\n", + "Want to run Mistral's models locally? Check out our [Ollama integration](/docs/integrations/chat/ollama).\n", + ":::\n", + "```\n", + "\n", + "Here's how you can initialize an `MistralAI` LLM instance:\n", + "\n", + "```{=mdx}\n", + "import IntegrationInstallTooltip from \"@mdx_components/integration_install_tooltip.mdx\";\n", + "import Npm2Yarn from \"@theme/Npm2Yarn\";\n", + "\n", + "\n", + "\n", + "\n", + " @langchain/mistralai\n", + "\n", + "```\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "console.log('hello world');\n", + "```\n", + "This will output 'hello world' to the console.\n" + ] + } + ], + "source": [ + "import { MistralAI } from \"@langchain/mistralai\";\n", + "\n", + "const model = new MistralAI({\n", + " model: \"codestral-latest\", // Defaults to \"codestral-latest\" if no model provided.\n", + " temperature: 0,\n", + " apiKey: \"YOUR-API-KEY\", // In Node.js defaults to process.env.MISTRAL_API_KEY\n", + "});\n", + "const res = await model.invoke(\n", + " \"You can print 'hello world' to the console in javascript like this:\\n```javascript\"\n", + ");\n", + "console.log(res);" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Since the Mistral LLM is a completions model, they also allow you to insert a `suffix` to the prompt. Suffixes can be passed via the call options when invoking a model like so:" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "console.log('hello world');\n", + "```\n" + ] + } + ], + "source": [ + "const res = await model.invoke(\n", + " \"You can print 'hello world' to the console in javascript like this:\\n```javascript\", {\n", + " suffix: \"```\"\n", + " }\n", + ");\n", + "console.log(res);" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "As seen in the first example, the model generated the requested `console.log('hello world')` code snippet, but also included extra unwanted text. By adding a suffix, we can constrain the model to only complete the prompt up to the suffix (in this case, three backticks). This allows us to easily parse the completion and extract only the desired response without the suffix using a custom output parser." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "console.log('hello world');\n", + "\n" + ] + } + ], + "source": [ + "import { MistralAI } from \"@langchain/mistralai\";\n", + "\n", + "const model = new MistralAI({\n", + " model: \"codestral-latest\",\n", + " temperature: 0,\n", + " apiKey: \"YOUR-API-KEY\",\n", + "});\n", + "\n", + "const suffix = \"```\";\n", + "\n", + "const customOutputParser = (input: string) => {\n", + " if (input.includes(suffix)) {\n", + " return input.split(suffix)[0];\n", + " }\n", + " throw new Error(\"Input does not contain suffix.\")\n", + "};\n", + "\n", + "const res = await model.invoke(\n", + " \"You can print 'hello world' to the console in javascript like this:\\n```javascript\", {\n", + " suffix,\n", + " }\n", + ");\n", + "\n", + "console.log(customOutputParser(res));" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "TypeScript", + "language": "typescript", + "name": "tslab" + }, + "language_info": { + "codemirror_mode": { + "mode": "typescript", + "name": "javascript", + "typescript": true + }, + "file_extension": ".ts", + "mimetype": "text/typescript", + "name": "typescript", + "version": "3.7.2" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/docs/core_docs/vercel.json b/docs/core_docs/vercel.json index 818afaff2e0e..b808cbe3e215 100644 --- a/docs/core_docs/vercel.json +++ b/docs/core_docs/vercel.json @@ -33,6 +33,10 @@ "source": "/v0.2/docs(/?)", "destination": "/v0.2/docs/introduction/" }, + { + "source": "/docs/integrations/:path(.*/?)*", + "destination": "/v0.2/docs/integrations/:path*" + }, { "source": "/docs/:path(.*/?)*", "destination": "/v0.1/docs/:path*" diff --git a/examples/src/models/chat/chat_mistralai_tools.ts b/examples/src/models/chat/chat_mistralai_tools.ts index 7e6e9ebf3928..fce2c9c69a54 100644 --- a/examples/src/models/chat/chat_mistralai_tools.ts +++ b/examples/src/models/chat/chat_mistralai_tools.ts @@ -33,7 +33,7 @@ class CalculatorTool extends StructuredTool { const model = new ChatMistralAI({ apiKey: process.env.MISTRAL_API_KEY, - model: "mistral-large", + model: "mistral-large-latest", }); // Bind the tool to the model diff --git a/examples/src/models/chat/chat_mistralai_wsa.ts b/examples/src/models/chat/chat_mistralai_wsa.ts index 37d36fecca64..0e429fe7e364 100644 --- a/examples/src/models/chat/chat_mistralai_wsa.ts +++ b/examples/src/models/chat/chat_mistralai_wsa.ts @@ -14,7 +14,7 @@ const calculatorSchema = z const model = new ChatMistralAI({ apiKey: process.env.MISTRAL_API_KEY, - model: "mistral-large", + model: "mistral-large-latest", }); // Pass the schema and tool name to the withStructuredOutput method diff --git a/examples/src/models/chat/chat_mistralai_wsa_json.ts b/examples/src/models/chat/chat_mistralai_wsa_json.ts index 4a9553055edd..3dbe3508bb39 100644 --- a/examples/src/models/chat/chat_mistralai_wsa_json.ts +++ b/examples/src/models/chat/chat_mistralai_wsa_json.ts @@ -21,7 +21,7 @@ const calculatorJsonSchema = { const model = new ChatMistralAI({ apiKey: process.env.MISTRAL_API_KEY, - model: "mistral-large", + model: "mistral-large-latest", }); // Pass the schema and tool name to the withStructuredOutput method diff --git a/langchain-core/package.json b/langchain-core/package.json index e285b65521c8..565b068605da 100644 --- a/langchain-core/package.json +++ b/langchain-core/package.json @@ -1,6 +1,6 @@ { "name": "@langchain/core", - "version": "0.2.3", + "version": "0.2.4", "description": "Core LangChain.js abstractions and schemas", "type": "module", "engines": { diff --git a/libs/langchain-mistralai/package.json b/libs/langchain-mistralai/package.json index 56e13455425c..6429d3437040 100644 --- a/libs/langchain-mistralai/package.json +++ b/libs/langchain-mistralai/package.json @@ -1,6 +1,6 @@ { "name": "@langchain/mistralai", - "version": "0.0.22", + "version": "0.0.23", "description": "MistralAI integration for LangChain.js", "type": "module", "engines": { @@ -41,7 +41,7 @@ "license": "MIT", "dependencies": { "@langchain/core": ">0.1.56 <0.3.0", - "@mistralai/mistralai": "^0.1.3", + "@mistralai/mistralai": "^0.4.0", "uuid": "^9.0.0", "zod": "^3.22.4", "zod-to-json-schema": "^3.22.4" diff --git a/libs/langchain-mistralai/src/chat_models.ts b/libs/langchain-mistralai/src/chat_models.ts index 861d9df55c77..4762ddf75127 100644 --- a/libs/langchain-mistralai/src/chat_models.ts +++ b/libs/langchain-mistralai/src/chat_models.ts @@ -3,10 +3,11 @@ import { ChatCompletionResponse, Function as MistralAIFunction, ToolCalls as MistralAIToolCalls, - ToolChoice as MistralAIToolChoice, ResponseFormat, ChatCompletionResponseChunk, - ToolType, + ChatRequest, + Tool as MistralAITool, + Message as MistralAIMessage, } from "@mistralai/mistralai"; import { MessageType, @@ -44,7 +45,6 @@ import { import { getEnvironmentVariable } from "@langchain/core/utils/env"; import { NewTokenIndices } from "@langchain/core/callbacks/base"; import { StructuredTool, StructuredToolInterface } from "@langchain/core/tools"; -import { convertToOpenAITool } from "@langchain/core/utils/function_calling"; import { z } from "zod"; import { type BaseLLMOutputParser, @@ -70,40 +70,15 @@ interface TokenUsage { totalTokens?: number; } -type MistralAIInputMessage = { - role: string; - name?: string; - content: string | string[]; - tool_calls?: MistralAIToolCalls[]; -}; +export type MistralAIToolChoice = "auto" | "any" | "none"; type MistralAIToolInput = { type: string; function: MistralAIFunction }; - -type MistralAIChatCompletionOptions = { - model: string; - messages: Array<{ - role: string; - name?: string; - content: string | string[]; - tool_calls?: MistralAIToolCalls[]; - }>; - tools?: Array; - temperature?: number; - maxTokens?: number; - topP?: number; - randomSeed?: number; - safeMode?: boolean; - safePrompt?: boolean; - toolChoice?: MistralAIToolChoice; - responseFormat?: ResponseFormat; -}; - interface MistralAICallOptions extends Omit { response_format?: { type: "text" | "json_object"; }; - tools: StructuredToolInterface[] | MistralAIToolInput[]; + tools: StructuredToolInterface[] | MistralAIToolInput[] | MistralAITool[]; tool_choice?: MistralAIToolChoice; } @@ -178,7 +153,7 @@ export interface ChatMistralAIInput extends BaseChatModelParams { function convertMessagesToMistralMessages( messages: Array -): Array { +): Array { const getRole = (role: MessageType) => { switch (role) { case "human": @@ -212,7 +187,7 @@ function convertMessagesToMistralMessages( const getTools = (message: BaseMessage): MistralAIToolCalls[] | undefined => { if (isAIMessage(message) && !!message.tool_calls?.length) { return message.tool_calls - .map((toolCall) => ({ ...toolCall, id: "null" })) + .map((toolCall) => ({ ...toolCall, id: toolCall.id })) .map(convertLangChainToolCallToOpenAI) as MistralAIToolCalls[]; } if (!message.additional_kwargs.tool_calls?.length) { @@ -221,8 +196,8 @@ function convertMessagesToMistralMessages( const toolCalls: Omit[] = message.additional_kwargs.tool_calls; return toolCalls?.map((toolCall) => ({ - id: "null", - type: "function" as ToolType.function, + id: toolCall.id, + type: "function", function: toolCall.function, })); }; @@ -235,7 +210,7 @@ function convertMessagesToMistralMessages( content, tool_calls: toolCalls, }; - }); + }) as MistralAIMessage[]; } function mistralAIResponseToChatMessage( @@ -270,7 +245,12 @@ function mistralAIResponseToChatMessage( tool_calls: toolCalls, invalid_tool_calls: invalidToolCalls, additional_kwargs: { - tool_calls: rawToolCalls, + tool_calls: rawToolCalls.length + ? rawToolCalls.map((toolCall) => ({ + ...toolCall, + type: "function", + })) + : undefined, }, }); } @@ -350,8 +330,18 @@ function _convertDeltaToMessageChunk(delta: { function _convertStructuredToolToMistralTool( tools: StructuredToolInterface[] -): MistralAIToolInput[] { - return tools.map((tool) => convertToOpenAITool(tool) as MistralAIToolInput); +): MistralAITool[] { + return tools.map((tool) => { + const description = tool.description ?? `Tool: ${tool.name}`; + return { + type: "function", + function: { + name: tool.name, + description, + parameters: zodToJsonSchema(tool.schema), + }, + }; + }); } /** @@ -439,17 +429,27 @@ export class ChatMistralAI< */ invocationParams( options?: this["ParsedCallOptions"] - ): Omit { + ): Omit { const { response_format, tools, tool_choice } = options ?? {}; - const mistralAITools = tools + const mistralAITools: Array | undefined = tools ?.map((tool) => { if ("lc_namespace" in tool) { return _convertStructuredToolToMistralTool([tool]); } - return tool; + if (!tool.function.description) { + return { + type: "function", + function: { + name: tool.function.name, + description: `Tool: ${tool.function.name}`, + parameters: tool.function.parameters, + }, + } as MistralAITool; + } + return tool as MistralAITool; }) .flat(); - const params: Omit = { + const params: Omit = { model: this.model, tools: mistralAITools, temperature: this.temperature, @@ -484,21 +484,21 @@ export class ChatMistralAI< /** * Calls the MistralAI API with retry logic in case of failures. - * @param {MistralAIChatCompletionOptions} input The input to send to the MistralAI API. + * @param {ChatRequest} input The input to send to the MistralAI API. * @returns {Promise>} The response from the MistralAI API. */ async completionWithRetry( - input: MistralAIChatCompletionOptions, + input: ChatRequest, streaming: true ): Promise>; async completionWithRetry( - input: MistralAIChatCompletionOptions, + input: ChatRequest, streaming: false ): Promise; async completionWithRetry( - input: MistralAIChatCompletionOptions, + input: ChatRequest, streaming: boolean ): Promise< ChatCompletionResponse | AsyncGenerator diff --git a/libs/langchain-mistralai/src/index.ts b/libs/langchain-mistralai/src/index.ts index e080d1245956..569bb40d2982 100644 --- a/libs/langchain-mistralai/src/index.ts +++ b/libs/langchain-mistralai/src/index.ts @@ -1,2 +1,3 @@ export * from "./chat_models.js"; export * from "./embeddings.js"; +export * from "./llms.js"; diff --git a/libs/langchain-mistralai/src/llms.ts b/libs/langchain-mistralai/src/llms.ts new file mode 100644 index 000000000000..944c18ea67b6 --- /dev/null +++ b/libs/langchain-mistralai/src/llms.ts @@ -0,0 +1,361 @@ +import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager"; +import { BaseLLMParams, LLM } from "@langchain/core/language_models/llms"; +import { type BaseLanguageModelCallOptions } from "@langchain/core/language_models/base"; +import { GenerationChunk, LLMResult } from "@langchain/core/outputs"; +import { + ChatCompletionResponse, + ChatCompletionResponseChoice, + ChatCompletionResponseChunk, + type CompletionRequest, +} from "@mistralai/mistralai"; +import { getEnvironmentVariable } from "@langchain/core/utils/env"; +import { chunkArray } from "@langchain/core/utils/chunk_array"; +import { AsyncCaller } from "@langchain/core/utils/async_caller"; + +export interface MistralAICallOptions extends BaseLanguageModelCallOptions { + /** + * Optional text/code that adds more context for the model. + * When given a prompt and a suffix the model will fill what + * is between them. When suffix is not provided, the model + * will simply execute completion starting with prompt. + */ + suffix?: string; +} + +export interface MistralAIInput extends BaseLLMParams { + /** + * The name of the model to use. + * @default "codestral-latest" + */ + model?: string; + /** + * The API key to use. + * @default {process.env.MISTRAL_API_KEY} + */ + apiKey?: string; + /** + * Override the default endpoint. + */ + endpoint?: string; + /** + * What sampling temperature to use, between 0.0 and 2.0. + * Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. + * @default {0.7} + */ + temperature?: number; + /** + * Nucleus sampling, where the model considers the results of the tokens with `topP` probability mass. + * So 0.1 means only the tokens comprising the top 10% probability mass are considered. + * Should be between 0 and 1. + * @default {1} + */ + topP?: number; + /** + * The maximum number of tokens to generate in the completion. + * The token count of your prompt plus maxTokens cannot exceed the model's context length. + */ + maxTokens?: number; + /** + * Whether or not to stream the response. + * @default {false} + */ + streaming?: boolean; + /** + * The seed to use for random sampling. If set, different calls will generate deterministic results. + * Alias for `seed` + */ + randomSeed?: number; + /** + * Batch size to use when passing multiple documents to generate + */ + batchSize?: number; +} + +/** + * MistralAI completions LLM. + */ +export class MistralAI + extends LLM + implements MistralAIInput +{ + static lc_name() { + return "MistralAI"; + } + + model = "codestral-latest"; + + temperature = 0; + + topP?: number; + + maxTokens?: number | undefined; + + randomSeed?: number | undefined; + + streaming = false; + + batchSize = 20; + + apiKey: string; + + endpoint?: string; + + maxRetries?: number; + + maxConcurrency?: number; + + constructor(fields?: MistralAIInput) { + super(fields ?? {}); + + this.model = fields?.model ?? this.model; + this.temperature = fields?.temperature ?? this.temperature; + this.topP = fields?.topP ?? this.topP; + this.maxTokens = fields?.maxTokens ?? this.maxTokens; + this.randomSeed = fields?.randomSeed ?? this.randomSeed; + this.batchSize = fields?.batchSize ?? this.batchSize; + this.streaming = fields?.streaming ?? this.streaming; + this.endpoint = fields?.endpoint; + this.maxRetries = fields?.maxRetries; + this.maxConcurrency = fields?.maxConcurrency; + + const apiKey = fields?.apiKey ?? getEnvironmentVariable("MISTRAL_API_KEY"); + if (!apiKey) { + throw new Error( + `MistralAI requires an API key to be set. +Either provide one via the "apiKey" field in the constructor, or set the "MISTRAL_API_KEY" environment variable.` + ); + } + this.apiKey = apiKey; + } + + get lc_secrets(): { [key: string]: string } | undefined { + return { + apiKey: "MISTRAL_API_KEY", + }; + } + + _llmType() { + return "mistralai"; + } + + invocationParams( + options: this["ParsedCallOptions"] + ): Omit { + return { + model: this.model, + suffix: options.suffix, + temperature: this.temperature, + maxTokens: this.maxTokens, + topP: this.topP, + randomSeed: this.randomSeed, + stop: options.stop, + }; + } + + /** + * For some given input string and options, return a string output. + * + * Despite the fact that `invoke` is overridden below, we still need this + * in order to handle public APi calls to `generate()`. + */ + async _call( + prompt: string, + options: this["ParsedCallOptions"] + ): Promise { + const params = { + ...this.invocationParams(options), + prompt, + }; + const result = await this.completionWithRetry(params, options, false); + return result.choices[0].message.content ?? ""; + } + + async _generate( + prompts: string[], + options: this["ParsedCallOptions"], + runManager?: CallbackManagerForLLMRun + ): Promise { + const subPrompts = chunkArray(prompts, this.batchSize); + const choices: ChatCompletionResponseChoice[][] = []; + + const params = this.invocationParams(options); + + for (let i = 0; i < subPrompts.length; i += 1) { + const data = await (async () => { + if (this.streaming) { + const responseData: Array< + { choices: ChatCompletionResponseChoice[] } & Partial< + Omit + > + > = []; + for (let x = 0; x < subPrompts[i].length; x += 1) { + const choices: ChatCompletionResponseChoice[] = []; + let response: + | Omit + | undefined; + const stream = await this.completionWithRetry( + { + ...params, + prompt: subPrompts[i][x], + }, + options, + true + ); + for await (const message of stream) { + // on the first message set the response properties + if (!response) { + response = { + id: message.id, + object: "chat.completion", + created: message.created, + model: message.model, + }; + } + + // on all messages, update choice + for (const part of message.choices) { + if (!choices[part.index]) { + choices[part.index] = { + index: part.index, + message: { + role: part.delta.role ?? "assistant", + content: part.delta.content ?? "", + tool_calls: null, + }, + finish_reason: part.finish_reason, + }; + } else { + const choice = choices[part.index]; + choice.message.content += part.delta.content ?? ""; + choice.finish_reason = part.finish_reason; + } + void runManager?.handleLLMNewToken(part.delta.content ?? "", { + prompt: part.index, + completion: part.index, + }); + } + } + if (options.signal?.aborted) { + throw new Error("AbortError"); + } + responseData.push({ + ...response, + choices, + }); + } + return responseData; + } else { + const responseData: Array = []; + for (let x = 0; x < subPrompts[i].length; x += 1) { + const res = await this.completionWithRetry( + { + ...params, + prompt: subPrompts[i][x], + }, + options, + false + ); + responseData.push(res); + } + return responseData; + } + })(); + + choices.push(...data.map((d) => d.choices)); + } + + const generations = choices.map((promptChoices) => + promptChoices.map((choice) => ({ + text: choice.message.content ?? "", + generationInfo: { + finishReason: choice.finish_reason, + }, + })) + ); + return { + generations, + }; + } + + async completionWithRetry( + request: CompletionRequest, + options: this["ParsedCallOptions"], + stream: false + ): Promise; + + async completionWithRetry( + request: CompletionRequest, + options: this["ParsedCallOptions"], + stream: true + ): Promise>; + + async completionWithRetry( + request: CompletionRequest, + options: this["ParsedCallOptions"], + stream: boolean + ): Promise< + | ChatCompletionResponse + | AsyncGenerator + > { + const { MistralClient } = await this.imports(); + const caller = new AsyncCaller({ + maxConcurrency: options.maxConcurrency || this.maxConcurrency, + maxRetries: this.maxRetries, + }); + const client = new MistralClient( + this.apiKey, + this.endpoint, + this.maxRetries, + options.timeout + ); + return caller.callWithOptions( + { + signal: options.signal, + }, + async () => { + if (stream) { + return client.completionStream(request); + } else { + return client.completion(request); + } + } + ); + } + + async *_streamResponseChunks( + prompt: string, + options: this["ParsedCallOptions"], + runManager?: CallbackManagerForLLMRun + ): AsyncGenerator { + const params = { + ...this.invocationParams(options), + prompt, + }; + const stream = await this.completionWithRetry(params, options, true); + for await (const data of stream) { + const choice = data?.choices[0]; + if (!choice) { + continue; + } + const chunk = new GenerationChunk({ + text: choice.delta.content ?? "", + generationInfo: { + finishReason: choice.finish_reason, + tokenUsage: data.usage, + }, + }); + yield chunk; + // eslint-disable-next-line no-void + void runManager?.handleLLMNewToken(chunk.text ?? ""); + } + if (options.signal?.aborted) { + throw new Error("AbortError"); + } + } + + /** @ignore */ + private async imports() { + const { default: MistralClient } = await import("@mistralai/mistralai"); + return { MistralClient }; + } +} diff --git a/libs/langchain-mistralai/src/tests/agent.int.test.ts b/libs/langchain-mistralai/src/tests/agent.int.test.ts index 206347bb5e9a..156b773859f1 100644 --- a/libs/langchain-mistralai/src/tests/agent.int.test.ts +++ b/libs/langchain-mistralai/src/tests/agent.int.test.ts @@ -18,7 +18,7 @@ test("createToolCallingAgent works", async () => { // ["placeholder", "{agent_scratchpad}"], // ]); // const llm = new ChatMistralAI({ - // model: "mistral-large", + // model: "mistral-large-latest", // temperature: 0, // }); // const agent = await createToolCallingAgent({ diff --git a/libs/langchain-mistralai/src/tests/chat_models.int.test.ts b/libs/langchain-mistralai/src/tests/chat_models.int.test.ts index 17a3069a2c81..827eaa08db03 100644 --- a/libs/langchain-mistralai/src/tests/chat_models.int.test.ts +++ b/libs/langchain-mistralai/src/tests/chat_models.int.test.ts @@ -20,7 +20,7 @@ import { ChatMistralAI } from "../chat_models.js"; test("Test ChatMistralAI can invoke", async () => { const model = new ChatMistralAI({ - modelName: "mistral-tiny", + model: "mistral-tiny", }); const prompt = ChatPromptTemplate.fromMessages([ ["system", "You are a helpful assistant"], @@ -71,7 +71,7 @@ test("Can call tools using structured tools", async () => { } const model = new ChatMistralAI({ - modelName: "mistral-large", + model: "mistral-large-latest", }).bind({ tools: [new Calculator()], }); @@ -120,7 +120,7 @@ test("Can call tools", async () => { ]; const model = new ChatMistralAI({ - modelName: "mistral-large", + model: "mistral-large-latest", }).bind({ tools, }); @@ -170,7 +170,7 @@ test("Can call .stream with tool calling", async () => { } const model = new ChatMistralAI({ - modelName: "mistral-large", + model: "mistral-large-latest", }).bind({ tools: [new Calculator()], }); @@ -208,7 +208,7 @@ test("Can call .stream with tool calling", async () => { test("Can use json mode response format", async () => { const model = new ChatMistralAI({ - modelName: "mistral-large", + model: "mistral-large-latest", }).bind({ response_format: { type: "json_object", @@ -235,7 +235,7 @@ To use a calculator respond with valid JSON containing a single key: 'calculator test("Can call .stream with json mode", async () => { const model = new ChatMistralAI({ - modelName: "mistral-large", + model: "mistral-large-latest", }).bind({ response_format: { type: "json_object", @@ -294,7 +294,7 @@ test("Can stream and concat responses for a complex tool", async () => { } const model = new ChatMistralAI({ - modelName: "mistral-large", + model: "mistral-large-latest", }).bind({ tools: [new PersonTraits()], }); @@ -332,7 +332,7 @@ test("Can stream and concat responses for a complex tool", async () => { test("Few shotting with tool calls", async () => { const chat = new ChatMistralAI({ - modelName: "mistral-large", + model: "mistral-large-latest", temperature: 0, }).bind({ tools: [ @@ -385,7 +385,7 @@ describe("withStructuredOutput", () => { test("withStructuredOutput zod schema function calling", async () => { const model = new ChatMistralAI({ temperature: 0, - modelName: "mistral-large", + model: "mistral-large-latest", }); const calculatorSchema = z @@ -421,7 +421,7 @@ describe("withStructuredOutput", () => { test("withStructuredOutput zod schema JSON mode", async () => { const model = new ChatMistralAI({ temperature: 0, - modelName: "mistral-large", + model: "mistral-large-latest", }); const calculatorSchema = z.object({ @@ -459,7 +459,7 @@ describe("withStructuredOutput", () => { test("withStructuredOutput JSON schema function calling", async () => { const model = new ChatMistralAI({ temperature: 0, - modelName: "mistral-large", + model: "mistral-large-latest", }); const calculatorSchema = z @@ -496,7 +496,7 @@ describe("withStructuredOutput", () => { test("withStructuredOutput OpenAI function definition function calling", async () => { const model = new ChatMistralAI({ temperature: 0, - modelName: "mistral-large", + model: "mistral-large-latest", }); const calculatorSchema = z @@ -531,7 +531,7 @@ describe("withStructuredOutput", () => { test("withStructuredOutput JSON schema JSON mode", async () => { const model = new ChatMistralAI({ temperature: 0, - modelName: "mistral-large", + model: "mistral-large-latest", }); const calculatorSchema = z.object({ @@ -569,7 +569,7 @@ describe("withStructuredOutput", () => { test("withStructuredOutput includeRaw true", async () => { const model = new ChatMistralAI({ temperature: 0, - modelName: "mistral-large", + model: "mistral-large-latest", }); const calculatorSchema = z @@ -642,7 +642,7 @@ describe("withStructuredOutput", () => { test("Model is compatible with OpenAI tools agent and Agent Executor", async () => { const llm: BaseChatModel = new ChatMistralAI({ temperature: 0, - modelName: "mistral-large-latest", + model: "mistral-large-latest", }); const systemMessage = SystemMessagePromptTemplate.fromTemplate( @@ -823,3 +823,96 @@ describe("ChatMistralAI aborting", () => { expect(didError).toBeTruthy(); }); }); + +describe("codestral-latest", () => { + test("Test ChatMistralAI can invoke codestral-latest", async () => { + const model = new ChatMistralAI({ + model: "codestral-latest", + }); + const prompt = ChatPromptTemplate.fromMessages([ + ["system", "You are a helpful assistant"], + ["human", "{input}"], + ]); + const response = await prompt.pipe(model).invoke({ + input: "How can I log 'Hello, World!' in Python?", + }); + console.log("response", response); + expect(response.content.length).toBeGreaterThan(1); + expect((response.content as string).toLowerCase()).toContain("hello"); + expect((response.content as string).toLowerCase()).toContain("world"); + }); + + test("Test ChatMistralAI can stream codestral-latest", async () => { + const model = new ChatMistralAI({ + model: "codestral-latest", + }); + const prompt = ChatPromptTemplate.fromMessages([ + ["system", "You are a helpful assistant"], + ["human", "{input}"], + ]); + const response = await prompt.pipe(model).stream({ + input: "How can I log 'Hello, World!' in Python?", + }); + let itters = 0; + let fullMessage = ""; + for await (const item of response) { + console.log(item); + itters += 1; + fullMessage += item.content; + } + console.log("fullMessage", fullMessage); + expect(itters).toBeGreaterThan(1); + expect(fullMessage.toLowerCase()).toContain("hello"); + expect(fullMessage.toLowerCase()).toContain("world"); + }); + + test("Can call tools using structured tools codestral-latest", async () => { + class CodeSandbox extends StructuredTool { + name = "code_sandbox"; + + description = + "A tool which can run Python code in an isolated environment"; + + schema = z.object({ + code: z + .string() + .describe( + "The Python code to execute. Must only contain valid Python code." + ), + }); + + async _call(input: z.infer) { + return JSON.stringify(input, null, 2); + } + } + + const model = new ChatMistralAI({ + model: "codestral-latest", + }).bind({ + tools: [new CodeSandbox()], + tool_choice: "any", + }); + + const prompt = ChatPromptTemplate.fromMessages([ + ["system", "You are an excellent python engineer."], + ["human", "{input}"], + ]); + + const chain = prompt.pipe(model); + const response = await chain.invoke({ + input: + "Write a function that takes in a single argument and logs it to the console. Ensure the code is in Python.", + }); + console.log(response); + expect("tool_calls" in response.additional_kwargs).toBe(true); + console.log(response.additional_kwargs.tool_calls?.[0]); + if (!response.additional_kwargs.tool_calls?.[0]) { + throw new Error("No tool call found"); + } + const sandboxTool = response.additional_kwargs.tool_calls[0]; + expect(sandboxTool.function.name).toBe("code_sandbox"); + const parsedArgs = JSON.parse(sandboxTool.function.arguments); + expect(parsedArgs.code).toBeDefined(); + console.log(parsedArgs.code); + }); +}); diff --git a/libs/langchain-mistralai/src/tests/llms.int.test.ts b/libs/langchain-mistralai/src/tests/llms.int.test.ts new file mode 100644 index 000000000000..5a0aa418fcaf --- /dev/null +++ b/libs/langchain-mistralai/src/tests/llms.int.test.ts @@ -0,0 +1,152 @@ +import { test, expect } from "@jest/globals"; +import { CallbackManager } from "@langchain/core/callbacks/manager"; +import { MistralAI } from "../llms.js"; + +test("Test MistralAI", async () => { + const model = new MistralAI({ + maxTokens: 5, + model: "codestral-latest", + }); + const res = await model.invoke( + "Log 'Hello world' to the console in javascript: " + ); + console.log({ res }, "Test MistralAI"); + expect(res.length).toBeGreaterThan(1); +}); + +test("Test MistralAI with stop in object", async () => { + const model = new MistralAI({ + maxTokens: 5, + model: "codestral-latest", + }); + const res = await model.invoke("console.log 'Hello world' in javascript:", { + stop: ["world"], + }); + console.log({ res }, "Test MistralAI with stop in object"); +}); + +test("Test MistralAI with timeout in call options", async () => { + const model = new MistralAI({ + maxTokens: 5, + maxRetries: 0, + model: "codestral-latest", + }); + await expect(() => + model.invoke("Log 'Hello world' to the console in javascript: ", { + timeout: 10, + }) + ).rejects.toThrow(); +}, 5000); + +test("Test MistralAI with timeout in call options and node adapter", async () => { + const model = new MistralAI({ + maxTokens: 5, + maxRetries: 0, + model: "codestral-latest", + }); + await expect(() => + model.invoke("Log 'Hello world' to the console in javascript: ", { + timeout: 10, + }) + ).rejects.toThrow(); +}, 5000); + +test("Test MistralAI with signal in call options", async () => { + const model = new MistralAI({ + maxTokens: 5, + model: "codestral-latest", + }); + const controller = new AbortController(); + await expect(async () => { + const ret = await model.stream( + "Log 'Hello world' to the console in javascript 100 times: ", + { + signal: controller.signal, + } + ); + + for await (const chunk of ret) { + console.log({ chunk }, "Test MistralAI with signal in call options"); + controller.abort(); + } + + return ret; + }).rejects.toThrow(); +}, 5000); + +test("Test MistralAI in streaming mode", async () => { + let nrNewTokens = 0; + let streamedCompletion = ""; + + const model = new MistralAI({ + maxTokens: 5, + model: "codestral-latest", + streaming: true, + callbacks: CallbackManager.fromHandlers({ + async handleLLMNewToken(token: string) { + nrNewTokens += 1; + streamedCompletion += token; + }, + }), + }); + const res = await model.invoke( + "Log 'Hello world' to the console in javascript: " + ); + console.log({ res }, "Test MistralAI in streaming mode"); + + expect(nrNewTokens > 0).toBe(true); + expect(res).toBe(streamedCompletion); +}); + +test("Test MistralAI stream method", async () => { + const model = new MistralAI({ + maxTokens: 50, + model: "codestral-latest", + }); + const stream = await model.stream( + "Log 'Hello world' to the console in javascript: ." + ); + const chunks = []; + for await (const chunk of stream) { + chunks.push(chunk); + } + expect(chunks.length).toBeGreaterThan(1); +}); + +test("Test MistralAI stream method with abort", async () => { + await expect(async () => { + const model = new MistralAI({ + maxTokens: 250, + maxRetries: 0, + model: "codestral-latest", + }); + const stream = await model.stream( + "How is your day going? Be extremely verbose.", + { + signal: AbortSignal.timeout(1000), + } + ); + for await (const chunk of stream) { + console.log({ chunk }, "Test MistralAI stream method with abort"); + } + }).rejects.toThrow(); +}); + +test("Test MistralAI stream method with early break", async () => { + const model = new MistralAI({ + maxTokens: 50, + model: "codestral-latest", + }); + const stream = await model.stream( + "How is your day going? Be extremely verbose." + ); + let i = 0; + for await (const chunk of stream) { + console.log({ chunk }, "Test MistralAI stream method with early break"); + i += 1; + if (i > 5) { + break; + } + } + expect(i).toBeGreaterThan(5); +}); diff --git a/yarn.lock b/yarn.lock index 0dc77072b126..84fa81542ac2 100644 --- a/yarn.lock +++ b/yarn.lock @@ -9922,7 +9922,7 @@ __metadata: "@jest/globals": ^29.5.0 "@langchain/core": ">0.1.56 <0.3.0" "@langchain/scripts": ~0.0.14 - "@mistralai/mistralai": ^0.1.3 + "@mistralai/mistralai": ^0.4.0 "@swc/core": ^1.3.90 "@swc/jest": ^0.2.29 "@tsconfig/recommended": ^1.0.3 @@ -10418,12 +10418,12 @@ __metadata: languageName: node linkType: hard -"@mistralai/mistralai@npm:^0.1.3": - version: 0.1.3 - resolution: "@mistralai/mistralai@npm:0.1.3" +"@mistralai/mistralai@npm:^0.4.0": + version: 0.4.0 + resolution: "@mistralai/mistralai@npm:0.4.0" dependencies: node-fetch: ^2.6.7 - checksum: 3f8299811b06027dfbdae4fd86564ccda1a48ec4276940e1dcace4fa447cc4f7e61808a38d71ef73116fc6cf90b74257e6537e1c47c96d50ff52bfb8f62d2947 + checksum: 1b03fc0b55164c02e5fb29fb2d09ebe4ad44346fc313f7fb3ab09e48f73f975763d1ac9654098d433ea17d7caa20654b2b15510822276acc9fa46db461a254a6 languageName: node linkType: hard