-
Notifications
You must be signed in to change notification settings - Fork 0
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
ollama[minor]: Port embeddings to ollama package #5
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,151 @@ | ||
import { Embeddings, EmbeddingsParams } from "@langchain/core/embeddings"; | ||
import { Ollama } from "ollama/browser"; | ||
import type { Options as OllamaOptions } from "ollama"; | ||
import { OllamaCamelCaseOptions } from "./types.js"; | ||
|
||
/**
 * Interface for OllamaEmbeddings parameters. Extends EmbeddingsParams and
 * defines additional parameters specific to the OllamaEmbeddings class.
 */
interface OllamaEmbeddingsParams extends EmbeddingsParams {
  /**
   * The Ollama model to use for embeddings.
   * @default "mxbai-embed-large"
   */
  model?: string;

  /**
   * Base URL of the Ollama server.
   * @default "http://localhost:11434"
   */
  baseUrl?: string;

  /**
   * Extra headers to include in the Ollama API request.
   */
  headers?: Record<string, string>;

  /**
   * How long the model stays loaded after the request (passed through to
   * the Ollama API as `keep_alive`).
   * @default "5m"
   */
  keepAlive?: string;

  /**
   * Whether or not to truncate the input text to fit inside the model's
   * context window.
   * @default false
   */
  truncate?: boolean;

  /**
   * Advanced Ollama API request parameters in camelCase, see
   * https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
   * for details of the available parameters.
   */
  requestOptions?: OllamaCamelCaseOptions;
}
|
||
export class OllamaEmbeddings extends Embeddings { | ||
model = "mxbai-embed-large"; | ||
|
||
baseUrl = "http://localhost:11434"; | ||
|
||
headers?: Record<string, string>; | ||
|
||
keepAlive = "5m"; | ||
|
||
requestOptions?: Partial<OllamaOptions>; | ||
|
||
client: Ollama; | ||
|
||
truncate = false; | ||
|
||
constructor(fields?: OllamaEmbeddingsParams) { | ||
super({ maxConcurrency: 1, ...fields }); | ||
|
||
this.client = new Ollama({ | ||
host: fields?.baseUrl, | ||
}); | ||
this.baseUrl = fields?.baseUrl ?? this.baseUrl; | ||
|
||
this.model = fields?.model ?? this.model; | ||
this.headers = fields?.headers; | ||
this.keepAlive = fields?.keepAlive ?? this.keepAlive; | ||
this.truncate = fields?.truncate ?? this.truncate; | ||
this.requestOptions = fields?.requestOptions | ||
? this._convertOptions(fields?.requestOptions) | ||
: undefined; | ||
} | ||
|
||
/** convert camelCased Ollama request options like "useMMap" to | ||
* the snake_cased equivalent which the ollama API actually uses. | ||
* Used only for consistency with the llms/Ollama and chatModels/Ollama classes | ||
*/ | ||
_convertOptions( | ||
requestOptions: OllamaCamelCaseOptions | ||
): Partial<OllamaOptions> { | ||
const snakeCasedOptions: Partial<OllamaOptions> = {}; | ||
const mapping: Record<keyof OllamaCamelCaseOptions, string> = { | ||
embeddingOnly: "embedding_only", | ||
frequencyPenalty: "frequency_penalty", | ||
keepAlive: "keep_alive", | ||
logitsAll: "logits_all", | ||
lowVram: "low_vram", | ||
mainGpu: "main_gpu", | ||
mirostat: "mirostat", | ||
mirostatEta: "mirostat_eta", | ||
mirostatTau: "mirostat_tau", | ||
numBatch: "num_batch", | ||
numCtx: "num_ctx", | ||
numGpu: "num_gpu", | ||
numKeep: "num_keep", | ||
numPredict: "num_predict", | ||
numThread: "num_thread", | ||
penalizeNewline: "penalize_newline", | ||
presencePenalty: "presence_penalty", | ||
repeatLastN: "repeat_last_n", | ||
repeatPenalty: "repeat_penalty", | ||
temperature: "temperature", | ||
stop: "stop", | ||
tfsZ: "tfs_z", | ||
topK: "top_k", | ||
topP: "top_p", | ||
typicalP: "typical_p", | ||
useMlock: "use_mlock", | ||
useMmap: "use_mmap", | ||
vocabOnly: "vocab_only", | ||
f16Kv: "f16_kv", | ||
numa: "numa", | ||
seed: "seed", | ||
}; | ||
|
||
for (const [key, value] of Object.entries(requestOptions)) { | ||
const snakeCasedOption = mapping[key as keyof OllamaCamelCaseOptions]; | ||
if (snakeCasedOption) { | ||
snakeCasedOptions[snakeCasedOption as keyof OllamaOptions] = value; | ||
} | ||
} | ||
return snakeCasedOptions; | ||
} | ||
|
||
async embedDocuments(texts: string[]): Promise<number[][]> { | ||
return this.embeddingWithRetry(texts); | ||
} | ||
|
||
async embedQuery(text: string) { | ||
return (await this.embeddingWithRetry([text]))[0]; | ||
} | ||
|
||
private async embeddingWithRetry(texts: string[]): Promise<number[][]> { | ||
const res = await this.caller.call(() => | ||
this.client.embed({ | ||
model: this.model, | ||
input: texts, | ||
keep_alive: this.keepAlive, | ||
options: this.requestOptions, | ||
truncate: this.truncate, | ||
}) | ||
); | ||
return res.embeddings; | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,3 @@ | ||
export * from "./chat_models.js"; | ||
export * from "./embeddings.js"; | ||
export * from "./types.js"; |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
import { test, expect } from "@jest/globals"; | ||
import { OllamaEmbeddings } from "../embeddings.js"; | ||
|
||
test("Test OllamaEmbeddings.embedQuery", async () => { | ||
const embeddings = new OllamaEmbeddings(); | ||
const res = await embeddings.embedQuery("Hello world"); | ||
expect(res).toHaveLength(1024); | ||
expect(typeof res[0]).toBe("number"); | ||
}); | ||
Comment on lines
+4
to
+9
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The current tests for OllamaEmbeddings look good for basic functionality, but they don't cover error handling or edge cases. Consider adding tests for:
This will ensure the OllamaEmbeddings class behaves correctly under various conditions and improves the overall reliability of the package.
|
||
|
||
test("Test OllamaEmbeddings.embedDocuments", async () => { | ||
const embeddings = new OllamaEmbeddings(); | ||
const res = await embeddings.embedDocuments(["Hello world", "Bye bye"]); | ||
expect(res).toHaveLength(2); | ||
expect(res[0]).toHaveLength(1024); | ||
expect(typeof res[0][0]).toBe("number"); | ||
expect(res[1]).toHaveLength(1024); | ||
expect(typeof res[1][0]).toBe("number"); | ||
}); |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
export interface OllamaCamelCaseOptions { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The
|
||
numa?: boolean; | ||
numCtx?: number; | ||
numBatch?: number; | ||
numGpu?: number; | ||
mainGpu?: number; | ||
lowVram?: boolean; | ||
f16Kv?: boolean; | ||
logitsAll?: boolean; | ||
vocabOnly?: boolean; | ||
useMmap?: boolean; | ||
useMlock?: boolean; | ||
embeddingOnly?: boolean; | ||
numThread?: number; | ||
numKeep?: number; | ||
seed?: number; | ||
numPredict?: number; | ||
topK?: number; | ||
topP?: number; | ||
tfsZ?: number; | ||
typicalP?: number; | ||
repeatLastN?: number; | ||
temperature?: number; | ||
repeatPenalty?: number; | ||
presencePenalty?: number; | ||
frequencyPenalty?: number; | ||
mirostat?: number; | ||
mirostatTau?: number; | ||
mirostatEta?: number; | ||
penalizeNewline?: boolean; | ||
/** | ||
* @default "5m" | ||
*/ | ||
keepAlive?: string | number; | ||
stop?: string[]; | ||
} | ||
Comment on lines
+1
to
+36
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The `headers` property is defined and set in the constructor, but it is not being used in the API calls. To ensure custom headers are applied, you should pass the `headers` to the Ollama client when it is initialized. Update the constructor to include `headers` in the Ollama client initialization.