From 4a918362761a2bcc09ab3be670e8ef039358ab71 Mon Sep 17 00:00:00 2001
From: bracesproul
Date: Wed, 29 May 2024 14:39:49 -0700
Subject: [PATCH] cr

---
 libs/langchain-mistralai/src/llms.ts           | 44 +++++++++++----
 .../src/tests/llms.int.test.ts                 | 53 ++++---------------
 2 files changed, 45 insertions(+), 52 deletions(-)

diff --git a/libs/langchain-mistralai/src/llms.ts b/libs/langchain-mistralai/src/llms.ts
index 8d95b1f2ff4d..944c18ea67b6 100644
--- a/libs/langchain-mistralai/src/llms.ts
+++ b/libs/langchain-mistralai/src/llms.ts
@@ -10,6 +10,7 @@ import {
 } from "@mistralai/mistralai";
 import { getEnvironmentVariable } from "@langchain/core/utils/env";
 import { chunkArray } from "@langchain/core/utils/chunk_array";
+import { AsyncCaller } from "@langchain/core/utils/async_caller";
 
 export interface MistralAICallOptions extends BaseLanguageModelCallOptions {
   /**
@@ -99,6 +100,10 @@ export class MistralAI
 
   endpoint?: string;
 
+  maxRetries?: number;
+
+  maxConcurrency?: number;
+
   constructor(fields?: MistralAIInput) {
     super(fields ?? {});
 
@@ -110,6 +115,8 @@
     this.batchSize = fields?.batchSize ?? this.batchSize;
     this.streaming = fields?.streaming ?? this.streaming;
     this.endpoint = fields?.endpoint;
+    this.maxRetries = fields?.maxRetries;
+    this.maxConcurrency = fields?.maxConcurrency;
 
     const apiKey = fields?.apiKey ?? getEnvironmentVariable("MISTRAL_API_KEY");
     if (!apiKey) {
@@ -159,7 +166,7 @@ Either provide one via the "apiKey" field in the constructor, or set the "MISTRA
       ...this.invocationParams(options),
       prompt,
     };
-    const result = await this.completionWithRetry(params, false);
+    const result = await this.completionWithRetry(params, options, false);
     return result.choices[0].message.content ?? "";
   }
 
@@ -191,6 +198,7 @@ Either provide one via the "apiKey" field in the constructor, or set the "MISTRA
                 ...params,
                 prompt: subPrompts[i][x],
               },
+              options,
               true
             );
             for await (const message of stream) {
@@ -244,6 +252,7 @@ Either provide one via the "apiKey" field in the constructor, or set the "MISTRA
                 ...params,
                 prompt: subPrompts[i][x],
               },
+              options,
               false
             );
             responseData.push(res);
@@ -270,30 +279,47 @@ Either provide one via the "apiKey" field in the constructor, or set the "MISTRA
 
   async completionWithRetry(
     request: CompletionRequest,
+    options: this["ParsedCallOptions"],
     stream: false
   ): Promise<ChatCompletionResponse>;
 
   async completionWithRetry(
     request: CompletionRequest,
+    options: this["ParsedCallOptions"],
     stream: true
   ): Promise<AsyncGenerator<ChatCompletionResponse>>;
 
   async completionWithRetry(
     request: CompletionRequest,
+    options: this["ParsedCallOptions"],
     stream: boolean
   ): Promise<
     | ChatCompletionResponse
     | AsyncGenerator<ChatCompletionResponse>
   > {
     const { MistralClient } = await this.imports();
-    const client = new MistralClient(this.apiKey, this.endpoint);
-    return this.caller.call(async () => {
-      if (stream) {
-        return client.completionStream(request);
-      } else {
-        return client.completion(request);
-      }
+    const caller = new AsyncCaller({
+      maxConcurrency: options.maxConcurrency || this.maxConcurrency,
+      maxRetries: this.maxRetries,
     });
+    const client = new MistralClient(
+      this.apiKey,
+      this.endpoint,
+      this.maxRetries,
+      options.timeout
+    );
+    return caller.callWithOptions(
+      {
+        signal: options.signal,
+      },
+      async () => {
+        if (stream) {
+          return client.completionStream(request);
+        } else {
+          return client.completion(request);
+        }
+      }
+    );
   }
 
   async *_streamResponseChunks(
@@ -305,7 +331,7 @@ Either provide one via the "apiKey" field in the constructor, or set the "MISTRA
       ...this.invocationParams(options),
       prompt,
     };
-    const stream = await this.completionWithRetry(params, true);
+    const stream = await this.completionWithRetry(params, options, true);
     for await (const data of stream) {
       const choice = data?.choices[0];
       if (!choice) {
diff --git a/libs/langchain-mistralai/src/tests/llms.int.test.ts b/libs/langchain-mistralai/src/tests/llms.int.test.ts
index cff7363dc2e2..5a0aa418fcaf 100644
--- a/libs/langchain-mistralai/src/tests/llms.int.test.ts
+++ b/libs/langchain-mistralai/src/tests/llms.int.test.ts
@@ -1,6 +1,5 @@
 import { test, expect } from "@jest/globals";
 import { CallbackManager } from "@langchain/core/callbacks/manager";
-import { NewTokenIndices } from "@langchain/core/callbacks/base";
 import { MistralAI } from "../llms.js";
 
 test("Test MistralAI", async () => {
@@ -20,10 +19,9 @@ test("Test MistralAI with stop in object", async () => {
     maxTokens: 5,
     model: "codestral-latest",
   });
-  const res = await model.invoke(
-    "Log 'Hello world' to the console in javascript: ",
-    { stop: ["world"] }
-  );
+  const res = await model.invoke("console.log 'Hello world' in javascript:", {
+    stop: ["world"],
+  });
   console.log({ res }, "Test MistralAI with stop in object");
 });
 
@@ -59,15 +57,18 @@ test("Test MistralAI with signal in call options", async () => {
     model: "codestral-latest",
   });
   const controller = new AbortController();
-  await expect(() => {
-    const ret = model.invoke(
-      "Log 'Hello world' to the console in javascript: ",
+  await expect(async () => {
+    const ret = await model.stream(
+      "Log 'Hello world' to the console in javascript 100 times: ",
       {
         signal: controller.signal,
       }
     );
 
-    controller.abort();
+    for await (const chunk of ret) {
+      console.log({ chunk }, "Test MistralAI with signal in call options");
+      controller.abort();
+    }
     return ret;
   }).rejects.toThrow();
 
@@ -97,40 +98,6 @@ test("Test MistralAI in streaming mode", async () => {
   expect(res).toBe(streamedCompletion);
 });
 
-test.skip("Test MistralAI in streaming mode with multiple prompts", async () => {
-  let nrNewTokens = 0;
-  const completions = [
-    ["", ""],
-    ["", ""],
-  ];
-
-  const model = new MistralAI({
-    maxTokens: 5,
-    model: "codestral-latest",
-    streaming: true,
-    callbacks: CallbackManager.fromHandlers({
-      async handleLLMNewToken(token: string, idx: NewTokenIndices) {
-        nrNewTokens += 1;
-        completions[idx.prompt][idx.completion] += token;
-      },
-    }),
-  });
-  const res = await model.generate([
-    "Log 'Hello world' to the console in javascript: ",
-    "print hello sea",
-  ]);
-  console.log(
-    res.generations,
-    res.generations.map((g) => g[0].generationInfo)
-  );
-
-  expect(nrNewTokens > 0).toBe(true);
-  expect(res.generations.length).toBe(2);
-  expect(res.generations.map((g) => g.map((gg) => gg.text))).toEqual(
-    completions
-  );
-});
-
 test("Test MistralAI stream method", async () => {
   const model = new MistralAI({
     maxTokens: 50,
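
Note (not part of the patch): a minimal usage sketch of the knobs this change wires through, the maxRetries and maxConcurrency constructor fields plus the per-call signal and timeout options that now reach completionWithRetry. The model name, prompt, and option values below are illustrative assumptions, not taken from the diff.

// Usage sketch, assuming @langchain/mistralai with this patch applied.
// Model name, prompt, and numeric values are illustrative only.
import { MistralAI } from "@langchain/mistralai";

async function main() {
  const model = new MistralAI({
    model: "codestral-latest",
    maxTokens: 50,
    // Constructor fields surfaced by this patch; both are optional.
    maxRetries: 2,
    maxConcurrency: 5,
  });

  const controller = new AbortController();

  // Per-call options now reach completionWithRetry: the signal is honored via
  // AsyncCaller.callWithOptions and the timeout is forwarded to MistralClient.
  const res = await model.invoke("console.log 'Hello world' in javascript:", {
    signal: controller.signal,
    timeout: 10_000,
  });
  console.log(res);
}

main().catch(console.error);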