From 609805da97ebbdf434c08d6d454293a2b35b91b1 Mon Sep 17 00:00:00 2001
From: bracesproul
Date: Thu, 8 Aug 2024 12:16:23 -0700
Subject: [PATCH 1/2] ollama[minor]: Port embeddings to ollama package

---
 libs/langchain-ollama/src/chat_models.ts |  38 +----
 libs/langchain-ollama/src/embeddings.ts  | 151 ++++++++++++++++++
 libs/langchain-ollama/src/index.ts       |   2 +
 .../src/tests/embeddings.int.test.ts     |  19 +++
 libs/langchain-ollama/src/types.ts       |  36 +++++
 5 files changed, 212 insertions(+), 34 deletions(-)
 create mode 100644 libs/langchain-ollama/src/embeddings.ts
 create mode 100644 libs/langchain-ollama/src/tests/embeddings.int.test.ts
 create mode 100644 libs/langchain-ollama/src/types.ts

diff --git a/libs/langchain-ollama/src/chat_models.ts b/libs/langchain-ollama/src/chat_models.ts
index 52469ac8e3..7c42b37ea1 100644
--- a/libs/langchain-ollama/src/chat_models.ts
+++ b/libs/langchain-ollama/src/chat_models.ts
@@ -28,6 +28,7 @@ import {
   convertOllamaMessagesToLangChain,
   convertToOllamaMessages,
 } from "./utils.js";
+import { OllamaCamelCaseOptions } from "./types.js";
 
 export interface ChatOllamaCallOptions extends BaseChatModelCallOptions {
   /**
@@ -55,7 +56,9 @@ export interface PullModelOptions {
 /**
  * Input to chat model class.
  */
-export interface ChatOllamaInput extends BaseChatModelParams {
+export interface ChatOllamaInput
+  extends BaseChatModelParams,
+    OllamaCamelCaseOptions {
   /**
    * The model to invoke. If the model does not exist, it
    * will be pulled.
@@ -75,40 +78,7 @@ export interface ChatOllamaInput extends BaseChatModelParams {
    */
   checkOrPullModel?: boolean;
   streaming?: boolean;
-  numa?: boolean;
-  numCtx?: number;
-  numBatch?: number;
-  numGpu?: number;
-  mainGpu?: number;
-  lowVram?: boolean;
-  f16Kv?: boolean;
-  logitsAll?: boolean;
-  vocabOnly?: boolean;
-  useMmap?: boolean;
-  useMlock?: boolean;
-  embeddingOnly?: boolean;
-  numThread?: number;
-  numKeep?: number;
-  seed?: number;
-  numPredict?: number;
-  topK?: number;
-  topP?: number;
-  tfsZ?: number;
-  typicalP?: number;
-  repeatLastN?: number;
-  temperature?: number;
-  repeatPenalty?: number;
-  presencePenalty?: number;
-  frequencyPenalty?: number;
-  mirostat?: number;
-  mirostatTau?: number;
-  mirostatEta?: number;
-  penalizeNewline?: boolean;
   format?: string;
-  /**
-   * @default "5m"
-   */
-  keepAlive?: string | number;
 }
 
 /**
diff --git a/libs/langchain-ollama/src/embeddings.ts b/libs/langchain-ollama/src/embeddings.ts
new file mode 100644
index 0000000000..af188a8505
--- /dev/null
+++ b/libs/langchain-ollama/src/embeddings.ts
@@ -0,0 +1,151 @@
+import { Embeddings, EmbeddingsParams } from "@langchain/core/embeddings";
+import { Ollama } from "ollama/browser";
+import type { Options as OllamaOptions } from "ollama";
+import { OllamaCamelCaseOptions } from "./types.js";
+
+/**
+ * Interface for OllamaEmbeddings parameters. Extends EmbeddingsParams and
+ * defines additional parameters specific to the OllamaEmbeddings class.
+ */
+interface OllamaEmbeddingsParams extends EmbeddingsParams {
+  /**
+   * The Ollama model to use for embeddings.
+   * @default "mxbai-embed-large"
+   */
+  model?: string;
+
+  /**
+   * Base URL of the Ollama server
+   * @default "http://localhost:11434"
+   */
+  baseUrl?: string;
+
+  /**
+   * Extra headers to include in the Ollama API request
+   */
+  headers?: Record<string, string>;
+
+  /**
+   * Defaults to "5m"
+   */
+  keepAlive?: string;
+
+  /**
+   * Whether or not to truncate the input text to fit inside the model's
+   * context window.
+   * @default false
+   */
+  truncate?: boolean;
+
+  /**
+   * Advanced Ollama API request parameters in camelCase, see
+   * https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
+   * for details of the available parameters.
+   */
+  requestOptions?: OllamaCamelCaseOptions;
+}
+
+export class OllamaEmbeddings extends Embeddings {
+  model = "mxbai-embed-large";
+
+  baseUrl = "http://localhost:11434";
+
+  headers?: Record<string, string>;
+
+  keepAlive = "5m";
+
+  requestOptions?: Partial<OllamaOptions>;
+
+  client: Ollama;
+
+  truncate = false;
+
+  constructor(fields?: OllamaEmbeddingsParams) {
+    super({ maxConcurrency: 1, ...fields });
+
+    this.client = new Ollama({
+      host: fields?.baseUrl,
+    });
+    this.baseUrl = fields?.baseUrl ?? this.baseUrl;
+
+    this.model = fields?.model ?? this.model;
+    this.headers = fields?.headers;
+    this.keepAlive = fields?.keepAlive ?? this.keepAlive;
+    this.truncate = fields?.truncate ?? this.truncate;
+    this.requestOptions = fields?.requestOptions
+      ? this._convertOptions(fields?.requestOptions)
+      : undefined;
+  }
+
+  /** convert camelCased Ollama request options like "useMmap" to
+   * the snake_cased equivalent which the ollama API actually uses.
+   * Used only for consistency with the llms/Ollama and chatModels/Ollama classes
+   */
+  _convertOptions(
+    requestOptions: OllamaCamelCaseOptions
+  ): Partial<OllamaOptions> {
+    const snakeCasedOptions: Partial<OllamaOptions> = {};
+    const mapping: Record<keyof OllamaCamelCaseOptions, string> = {
+      embeddingOnly: "embedding_only",
+      frequencyPenalty: "frequency_penalty",
+      keepAlive: "keep_alive",
+      logitsAll: "logits_all",
+      lowVram: "low_vram",
+      mainGpu: "main_gpu",
+      mirostat: "mirostat",
+      mirostatEta: "mirostat_eta",
+      mirostatTau: "mirostat_tau",
+      numBatch: "num_batch",
+      numCtx: "num_ctx",
+      numGpu: "num_gpu",
+      numKeep: "num_keep",
+      numPredict: "num_predict",
+      numThread: "num_thread",
+      penalizeNewline: "penalize_newline",
+      presencePenalty: "presence_penalty",
+      repeatLastN: "repeat_last_n",
+      repeatPenalty: "repeat_penalty",
+      temperature: "temperature",
+      stop: "stop",
+      tfsZ: "tfs_z",
+      topK: "top_k",
+      topP: "top_p",
+      typicalP: "typical_p",
+      useMlock: "use_mlock",
+      useMmap: "use_mmap",
+      vocabOnly: "vocab_only",
+      f16Kv: "f16_kv",
+      numa: "numa",
+      seed: "seed",
+    };
+
+    for (const [key, value] of Object.entries(requestOptions)) {
+      const snakeCasedOption = mapping[key as keyof OllamaCamelCaseOptions];
+      if (snakeCasedOption) {
+        snakeCasedOptions[snakeCasedOption as keyof OllamaOptions] = value;
+      }
+    }
+    return snakeCasedOptions;
+  }
+
+  async embedDocuments(texts: string[]): Promise<number[][]> {
+    return this.embeddingWithRetry(texts);
+  }
+
+  async embedQuery(text: string) {
+    return (await this.embeddingWithRetry([text]))[0];
+  }
+
+  private async embeddingWithRetry(texts: string[]): Promise<number[][]> {
+    const res = await this.caller.call(() =>
+      this.client.embed({
+        model: this.model,
+        input: texts,
+        keep_alive: this.keepAlive,
+        options: this.requestOptions,
+        truncate: this.truncate,
+      })
+    );
+    return res.embeddings;
+  }
+}
diff --git a/libs/langchain-ollama/src/index.ts b/libs/langchain-ollama/src/index.ts
index 38c7cea7f4..d431a0dfa4 100644
--- a/libs/langchain-ollama/src/index.ts
+++ b/libs/langchain-ollama/src/index.ts
@@ -1 +1,3 @@
 export * from "./chat_models.js";
+export * from "./embeddings.js";
+export * from "./types.js";
diff --git a/libs/langchain-ollama/src/tests/embeddings.int.test.ts b/libs/langchain-ollama/src/tests/embeddings.int.test.ts
new file mode 100644
index 0000000000..866d68e871
--- /dev/null
+++ b/libs/langchain-ollama/src/tests/embeddings.int.test.ts
@@ -0,0 +1,19 @@
+import { test, expect } from "@jest/globals";
+import { OllamaEmbeddings } from "../embeddings.js";
+
+test("Test OllamaEmbeddings.embedQuery", async () => {
+  const embeddings = new OllamaEmbeddings();
+  const res = await embeddings.embedQuery("Hello world");
+  expect(res).toHaveLength(1024);
+  expect(typeof res[0]).toBe("number");
+});
+
+test("Test OllamaEmbeddings.embedDocuments", async () => {
+  const embeddings = new OllamaEmbeddings();
+  const res = await embeddings.embedDocuments(["Hello world", "Bye bye"]);
+  expect(res).toHaveLength(2);
+  expect(res[0]).toHaveLength(1024);
+  expect(typeof res[0][0]).toBe("number");
+  expect(res[1]).toHaveLength(1024);
+  expect(typeof res[1][0]).toBe("number");
+});
diff --git a/libs/langchain-ollama/src/types.ts b/libs/langchain-ollama/src/types.ts
new file mode 100644
index 0000000000..6b3aaf21c2
--- /dev/null
+++ b/libs/langchain-ollama/src/types.ts
@@ -0,0 +1,36 @@
+export interface OllamaCamelCaseOptions {
+  numa?: boolean;
+  numCtx?: number;
+  numBatch?: number;
+  numGpu?: number;
+  mainGpu?: number;
+  lowVram?: boolean;
+  f16Kv?: boolean;
+  logitsAll?: boolean;
+  vocabOnly?: boolean;
+  useMmap?: boolean;
+  useMlock?: boolean;
+  embeddingOnly?: boolean;
+  numThread?: number;
+  numKeep?: number;
+  seed?: number;
+  numPredict?: number;
+  topK?: number;
+  topP?: number;
+  tfsZ?: number;
+  typicalP?: number;
+  repeatLastN?: number;
+  temperature?: number;
+  repeatPenalty?: number;
+  presencePenalty?: number;
+  frequencyPenalty?: number;
+  mirostat?: number;
+  mirostatTau?: number;
+  mirostatEta?: number;
+  penalizeNewline?: boolean;
+  /**
+   * @default "5m"
+   */
+  keepAlive?: string | number;
+  stop?: string[];
+}

From 677f974ebe057eec3e64f185e016f042e5f92c2e Mon Sep 17 00:00:00 2001
From: "korbit-ai[bot]" <131444098+korbit-ai[bot]@users.noreply.github.com>
Date: Thu, 15 Aug 2024 16:19:11 +0000
Subject: [PATCH 2/2] [skip ci]
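
For reviewers, a minimal usage sketch of the OllamaEmbeddings class added in
PATCH 1/2 (not part of the diff itself). It assumes the package is consumed as
"@langchain/ollama" and that a local Ollama server is running at the default
http://localhost:11434 with the default mxbai-embed-large model pulled; for
that model each vector has 1024 dimensions, which is what the integration test
above asserts:

import { OllamaEmbeddings } from "@langchain/ollama";

// Requires a running Ollama server with the embedding model pulled first:
//   ollama pull mxbai-embed-large
const embeddings = new OllamaEmbeddings({
  model: "mxbai-embed-large", // the class default, shown here for clarity
  truncate: true, // clip inputs that overflow the model's context window
  requestOptions: {
    numCtx: 2048, // camelCase; _convertOptions rewrites this to num_ctx
  },
});

// One query string -> one embedding vector (1024 numbers for mxbai-embed-large).
const queryVector = await embeddings.embedQuery("Hello world");

// A batch of documents -> one vector per input text, in order.
const docVectors = await embeddings.embedDocuments(["Hello world", "Bye bye"]);
console.log(queryVector.length, docVectors.length); // 1024 2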
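
Design note: the camelCase request options are deliberately shared between the
chat model and the embeddings class. PATCH 1/2 moves the option fields that
previously lived inline on ChatOllamaInput into the OllamaCamelCaseOptions
interface in types.ts, and OllamaEmbeddings._convertOptions translates them to
the snake_cased keys the Ollama HTTP API expects (numCtx -> num_ctx, useMmap ->
use_mmap, and so on) just before client.embed is called. Keys absent from the
mapping table are dropped rather than forwarded, so a new Ollama option needs a
corresponding mapping entry to take effect.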