
Commit

cr
bracesproul committed May 29, 2024
1 parent 6a17555 commit 4a91836
Showing 2 changed files with 45 additions and 52 deletions.
44 changes: 35 additions & 9 deletions libs/langchain-mistralai/src/llms.ts
@@ -10,6 +10,7 @@ import {
} from "@mistralai/mistralai";
import { getEnvironmentVariable } from "@langchain/core/utils/env";
import { chunkArray } from "@langchain/core/utils/chunk_array";
import { AsyncCaller } from "@langchain/core/utils/async_caller";

export interface MistralAICallOptions extends BaseLanguageModelCallOptions {
/**
@@ -99,6 +100,10 @@ export class MistralAI

endpoint?: string;

maxRetries?: number;

maxConcurrency?: number;

constructor(fields?: MistralAIInput) {
super(fields ?? {});

@@ -110,6 +115,8 @@ export class MistralAI
this.batchSize = fields?.batchSize ?? this.batchSize;
this.streaming = fields?.streaming ?? this.streaming;
this.endpoint = fields?.endpoint;
this.maxRetries = fields?.maxRetries;
this.maxConcurrency = fields?.maxConcurrency;

const apiKey = fields?.apiKey ?? getEnvironmentVariable("MISTRAL_API_KEY");
if (!apiKey) {
@@ -159,7 +166,7 @@ Either provide one via the "apiKey" field in the constructor, or set the "MISTRA
...this.invocationParams(options),
prompt,
};
const result = await this.completionWithRetry(params, false);
const result = await this.completionWithRetry(params, options, false);
return result.choices[0].message.content ?? "";
}

@@ -191,6 +198,7 @@ Either provide one via the "apiKey" field in the constructor, or set the "MISTRA
...params,
prompt: subPrompts[i][x],
},
options,
true
);
for await (const message of stream) {
@@ -244,6 +252,7 @@ Either provide one via the "apiKey" field in the constructor, or set the "MISTRA
...params,
prompt: subPrompts[i][x],
},
options,
false
);
responseData.push(res);
@@ -270,30 +279,47 @@ Either provide one via the "apiKey" field in the constructor, or set the "MISTRA

async completionWithRetry(
request: CompletionRequest,
options: this["ParsedCallOptions"],
stream: false
): Promise<ChatCompletionResponse>;

async completionWithRetry(
request: CompletionRequest,
options: this["ParsedCallOptions"],
stream: true
): Promise<AsyncGenerator<ChatCompletionResponseChunk, void>>;

async completionWithRetry(
request: CompletionRequest,
options: this["ParsedCallOptions"],
stream: boolean
): Promise<
| ChatCompletionResponse
| AsyncGenerator<ChatCompletionResponseChunk, void, unknown>
> {
const { MistralClient } = await this.imports();
const client = new MistralClient(this.apiKey, this.endpoint);
return this.caller.call(async () => {
if (stream) {
return client.completionStream(request);
} else {
return client.completion(request);
}
const caller = new AsyncCaller({
maxConcurrency: options.maxConcurrency || this.maxConcurrency,
maxRetries: this.maxRetries,
});
const client = new MistralClient(
this.apiKey,
this.endpoint,
this.maxRetries,
options.timeout
);
return caller.callWithOptions(
{
signal: options.signal,
},
async () => {
if (stream) {
return client.completionStream(request);
} else {
return client.completion(request);
}
}
);
}

async *_streamResponseChunks(
@@ -305,7 +331,7 @@ Either provide one via the "apiKey" field in the constructor, or set the "MISTRA
...this.invocationParams(options),
prompt,
};
const stream = await this.completionWithRetry(params, true);
const stream = await this.completionWithRetry(params, options, true);
for await (const data of stream) {
const choice = data?.choices[0];
if (!choice) {
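To show how the new fields fit together, here is a minimal usage sketch (not part of the commit), assuming the package is consumed as `@langchain/mistralai`, that `MISTRAL_API_KEY` is set in the environment, and that the model name and numeric values are placeholders:

```ts
import { MistralAI } from "@langchain/mistralai";

async function main() {
  // maxRetries and maxConcurrency are the constructor fields added above;
  // completionWithRetry feeds them into the AsyncCaller it now builds per call.
  const model = new MistralAI({
    model: "codestral-latest",
    maxTokens: 50,
    maxRetries: 2,
    maxConcurrency: 5,
  });

  // Per-call options: the abort signal is honored via caller.callWithOptions,
  // and timeout is forwarded to the MistralClient constructor.
  const controller = new AbortController();
  const res = await model.invoke("console.log 'Hello world' in javascript:", {
    signal: controller.signal,
    timeout: 10_000,
  });
  console.log(res);
}

main().catch(console.error);
```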
53 changes: 10 additions & 43 deletions libs/langchain-mistralai/src/tests/llms.int.test.ts
@@ -1,6 +1,5 @@
import { test, expect } from "@jest/globals";
import { CallbackManager } from "@langchain/core/callbacks/manager";
import { NewTokenIndices } from "@langchain/core/callbacks/base";
import { MistralAI } from "../llms.js";

test("Test MistralAI", async () => {
@@ -20,10 +19,9 @@ test("Test MistralAI with stop in object", async () => {
maxTokens: 5,
model: "codestral-latest",
});
const res = await model.invoke(
"Log 'Hello world' to the console in javascript: ",
{ stop: ["world"] }
);
const res = await model.invoke("console.log 'Hello world' in javascript:", {
stop: ["world"],
});
console.log({ res }, "Test MistralAI with stop in object");
});

@@ -59,15 +57,18 @@ test("Test MistralAI with signal in call options", async () => {
model: "codestral-latest",
});
const controller = new AbortController();
await expect(() => {
const ret = model.invoke(
"Log 'Hello world' to the console in javascript: ",
await expect(async () => {
const ret = await model.stream(
"Log 'Hello world' to the console in javascript 100 times: ",
{
signal: controller.signal,
}
);

controller.abort();
for await (const chunk of ret) {
console.log({ chunk }, "Test MistralAI with signal in call options");
controller.abort();
}

return ret;
}).rejects.toThrow();
@@ -97,40 +98,6 @@ test("Test MistralAI in streaming mode", async () => {
expect(res).toBe(streamedCompletion);
});

test.skip("Test MistralAI in streaming mode with multiple prompts", async () => {
let nrNewTokens = 0;
const completions = [
["", ""],
["", ""],
];

const model = new MistralAI({
maxTokens: 5,
model: "codestral-latest",
streaming: true,
callbacks: CallbackManager.fromHandlers({
async handleLLMNewToken(token: string, idx: NewTokenIndices) {
nrNewTokens += 1;
completions[idx.prompt][idx.completion] += token;
},
}),
});
const res = await model.generate([
"Log 'Hello world' to the console in javascript: ",
"print hello sea",
]);
console.log(
res.generations,
res.generations.map((g) => g[0].generationInfo)
);

expect(nrNewTokens > 0).toBe(true);
expect(res.generations.length).toBe(2);
expect(res.generations.map((g) => g.map((gg) => gg.text))).toEqual(
completions
);
});

test("Test MistralAI stream method", async () => {
const model = new MistralAI({
maxTokens: 50,
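Mirroring the rewritten signal test, here is a sketch (again illustrative only, not part of the commit) of aborting a stream from application code; the prompt and model name are taken from the test above:

```ts
import { MistralAI } from "@langchain/mistralai";

async function abortMidStream() {
  const model = new MistralAI({ model: "codestral-latest", maxTokens: 50 });
  const controller = new AbortController();

  try {
    const stream = await model.stream(
      "Log 'Hello world' to the console in javascript 100 times: ",
      { signal: controller.signal }
    );
    for await (const chunk of stream) {
      console.log({ chunk });
      // Aborting while chunks are still arriving should make the iteration throw,
      // which is what the updated test asserts with rejects.toThrow().
      controller.abort();
    }
  } catch (err) {
    console.error("stream aborted:", err);
  }
}

abortMidStream();
```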
