
Commit 641690d
implemented tests
bracesproul committed Jul 5, 2024
1 parent 0836874 commit 641690d
Showing 3 changed files with 275 additions and 65 deletions.
168 changes: 106 additions & 62 deletions libs/langchain-ollama/src/chat_models.ts
@@ -1,18 +1,24 @@
import {
AIMessage,
MessageContentText,
type BaseMessage,
} from "@langchain/core/messages";
import { MessageContentText, type BaseMessage } from "@langchain/core/messages";
import { type BaseLanguageModelCallOptions } from "@langchain/core/language_models/base";

import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager";
import {
type BaseChatModelParams,
BaseChatModel,
LangSmithParams,
} from "@langchain/core/language_models/chat_models";
import ollama from "ollama/browser";
import { Ollama } from "ollama/browser";
import { ChatGenerationChunk, ChatResult } from "@langchain/core/outputs";
import { AIMessageChunk } from "@langchain/core/messages";
import type {
ChatRequest as OllamaChatRequest,
Message as OllamaMessage,
} from "ollama";

function extractBase64FromDataUrl(dataUrl: string): string {
const match = dataUrl.match(/^data:.*?;base64,(.*)$/);
return match ? match[1] : "";
}

function convertToOllamaMessages(messages: BaseMessage[]): OllamaMessage[] {
return messages.flatMap((msg) => {
@@ -34,13 +40,13 @@ function convertToOllamaMessages(messages: BaseMessage[]): OllamaMessage[] {
return {
role: "user",
content: "",
images: [c.image_url],
images: [extractBase64FromDataUrl(c.image_url)],
};
} else if (c.image_url.url && typeof c.image_url.url === "string") {
return {
role: "user",
content: "",
images: [c.image_url.url],
images: [extractBase64FromDataUrl(c.image_url.url)],
};
}
}
@@ -96,14 +102,13 @@ function convertToOllamaMessages(messages: BaseMessage[]): OllamaMessage[] {
});
}

export interface OllamaMessage {
role: string;
content: string;
images?: Uint8Array[] | string[];
export interface ChatOllamaCallOptions extends BaseLanguageModelCallOptions {
/**
* An array of strings to stop on.
*/
stop?: string[];
}

export interface ChatOllamaCallOptions extends BaseLanguageModelCallOptions {}

export interface PullModelOptions {
/**
* Whether or not to stream the download.
@@ -129,6 +134,17 @@ export interface ChatOllamaInput extends BaseChatModelParams {
* @default "llama3"
*/
model?: string;
/**
* The host URL of the Ollama server.
*/
baseUrl?: string;
/**
* Whether or not to check the model exists on the local machine before
* invoking it. If set to `true`, the model will be pulled if it does not
* exist.
* @default false
*/
checkModelExists?: boolean;
streaming?: boolean;
numa?: boolean;
numCtx?: number;
@@ -160,6 +176,9 @@ export interface ChatOllamaInput extends BaseChatModelParams {
mirostatEta?: number;
penalizeNewline?: boolean;
format?: string;
/**
* @default "5m"
*/
keepAlive?: string | number;
}

@@ -239,10 +258,19 @@ export class ChatOllama

format?: string;

keepAlive?: string | number;
keepAlive?: string | number = "5m";

client: Ollama;

checkModelExists = false;

constructor(fields?: ChatOllamaInput) {
super(fields ?? {});

this.client = new Ollama({
host: fields?.baseUrl,
});

this.model = fields?.model ?? this.model;
this.numa = fields?.numa;
this.numCtx = fields?.numCtx;
@@ -275,7 +303,8 @@ export class ChatOllama
this.penalizeNewline = fields?.penalizeNewline;
this.streaming = fields?.streaming;
this.format = fields?.format;
this.keepAlive = fields?.keepAlive;
this.keepAlive = fields?.keepAlive ?? this.keepAlive;
this.checkModelExists = fields?.checkModelExists ?? this.checkModelExists;
}

// Replace
@@ -297,7 +326,7 @@ export class ChatOllama
};

if (stream) {
for await (const chunk of await ollama.pull({
for await (const chunk of await this.client.pull({
model,
insecure,
stream,
@@ -307,14 +336,28 @@ export class ChatOllama
}
}
} else {
const response = await ollama.pull({ model, insecure });
const response = await this.client.pull({ model, insecure });
if (logProgress) {
console.log(response);
}
}
}

invocationParams(_options?: this["ParsedCallOptions"]) {
getLsParams(options: this["ParsedCallOptions"]): LangSmithParams {
const params = this.invocationParams(options);
return {
ls_provider: "ollama",
ls_model_name: this.model,
ls_model_type: "chat",
ls_temperature: params.options?.temperature ?? undefined,
ls_max_tokens: params.options?.num_predict ?? undefined,
ls_stop: options.stop,
};
}

invocationParams(
options?: this["ParsedCallOptions"]
): Omit<OllamaChatRequest, "messages"> {
return {
model: this.model,
format: this.format,
@@ -349,64 +392,54 @@ export class ChatOllama
mirostat_tau: this.mirostatTau,
mirostat_eta: this.mirostatEta,
penalize_newline: this.penalizeNewline,
stop: options?.stop,
},
};
}

/**
* Check if a model exists on the local machine.
*
* @param {string} model The name of the model to check.
* @returns {Promise<boolean>} Whether or not the model exists.
*/
private async checkModelExistsOnMachine(model: string): Promise<boolean> {
const { models } = await this.client.list();
return !!models.find(
(m) => m.name === model || m.name === `${model}:latest`
);
}

async _generate(
messages: BaseMessage[],
options: this["ParsedCallOptions"],
runManager?: CallbackManagerForLLMRun
): Promise<ChatResult> {
const { models } = await ollama.list();
// By default, pull the model if it does not exist.
if (!models.find((model) => model.name === this.model)) {
await this.pull(this.model);
if (this.checkModelExists) {
if (!(await this.checkModelExistsOnMachine(this.model))) {
await this.pull(this.model, {
logProgress: true,
});
}
}

if (this.streaming) {
let finalChunk: ChatGenerationChunk | undefined;
for await (const chunk of this._streamResponseChunks(
messages,
options,
runManager
)) {
if (!finalChunk) {
finalChunk = chunk;
} else {
finalChunk = finalChunk.concat(chunk);
}
let finalChunk: ChatGenerationChunk | undefined;
for await (const chunk of this._streamResponseChunks(
messages,
options,
runManager
)) {
if (!finalChunk) {
finalChunk = chunk;
} else {
finalChunk = finalChunk.concat(chunk);
}
return {
generations: [
{
text: finalChunk?.text ?? "",
message: finalChunk?.message as AIMessageChunk,
},
],
};
}
const params = this.invocationParams(options);
const ollamaMessages = convertToOllamaMessages(messages);

const response = await ollama.chat({
...params,
messages: ollamaMessages,
stream: false,
});
const { message: responseMessage, ...rest } = response;

runManager?.handleLLMNewToken(responseMessage.content);
return {
generations: [
{
text: responseMessage.content,
message: new AIMessage({
content: responseMessage.content,
response_metadata: {
...rest,
},
}),
text: finalChunk?.text ?? "",
message: finalChunk?.message as AIMessageChunk,
},
],
};
@@ -421,15 +454,26 @@ export class ChatOllama
options: this["ParsedCallOptions"],
runManager?: CallbackManagerForLLMRun
): AsyncGenerator<ChatGenerationChunk> {
if (this.checkModelExists) {
if (!(await this.checkModelExistsOnMachine(this.model))) {
await this.pull(this.model, {
logProgress: true,
});
}
}

const params = this.invocationParams(options);
const ollamaMessages = convertToOllamaMessages(messages);

const stream = ollama.chat({
const stream = this.client.chat({
...params,
messages: ollamaMessages,
stream: true,
});
for await (const chunk of await stream) {
if (options.signal?.aborted) {
this.client.abort();
}
const { message: responseMessage, ...rest } = chunk;
yield new ChatGenerationChunk({
text: responseMessage.content,

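For reviewers, a minimal usage sketch of the class as it stands after this commit, assuming the package entrypoint is `@langchain/ollama`, an Ollama server is running locally, and the model names and prompts are purely illustrative:

```typescript
import { ChatOllama } from "@langchain/ollama";
import { HumanMessage } from "@langchain/core/messages";

// Options touched by this commit: `baseUrl` is forwarded to the Ollama client
// as `host`, `checkModelExists` pulls the model first if it is missing, and
// `keepAlive` now defaults to "5m".
const model = new ChatOllama({
  model: "llama3",
  baseUrl: "http://localhost:11434",
  checkModelExists: true,
});

// `stop` is now part of ChatOllamaCallOptions and is forwarded through
// invocationParams() into the request's `options.stop`.
const result = await model.invoke(
  [new HumanMessage("Say hello in one word.")],
  { stop: ["\n"] }
);
console.log(result.content);

// Streaming goes through _streamResponseChunks and yields message chunks.
for await (const chunk of await model.stream("Tell me a short joke.")) {
  process.stdout.write(chunk.content as string);
}
```

Leaving `checkModelExists` at its default of `false` skips the `list()`/`pull()` round-trip entirely, which the previous code performed unconditionally in `_generate`.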
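The message-conversion change can be exercised by passing an image as a data URL: `extractBase64FromDataUrl` strips the `data:...;base64,` prefix so only the raw payload is sent in `images`. A sketch under the same assumptions, with the file path and the `llava` model name hypothetical:

```typescript
import { readFileSync } from "node:fs";
import { ChatOllama } from "@langchain/ollama";
import { HumanMessage } from "@langchain/core/messages";

// Build a data URL; only the portion after "base64," reaches the Ollama API.
const base64 = readFileSync("./photo.jpg").toString("base64");

const visionModel = new ChatOllama({ model: "llava" }); // illustrative model name

const response = await visionModel.invoke([
  new HumanMessage({
    content: [
      { type: "text", text: "Describe this image in one sentence." },
      // Either a plain string or an { url } object is accepted for image_url.
      { type: "image_url", image_url: `data:image/jpeg;base64,${base64}` },
    ],
  }),
]);
console.log(response.content);
```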