diff --git a/libs/langchain-ollama/package.json b/libs/langchain-ollama/package.json
index 46a81f9dd5d1..7b855707b65d 100644
--- a/libs/langchain-ollama/package.json
+++ b/libs/langchain-ollama/package.json
@@ -1,5 +1,5 @@
 {
-  "name": "langchain-ollama",
+  "name": "@langchain/ollama",
   "version": "0.0.0",
   "description": "Ollama integration for LangChain.js",
   "type": "module",
@@ -35,12 +35,13 @@
   "author": "LangChain",
   "license": "MIT",
   "dependencies": {
-    "@langchain/core": ">0.1.0 <0.3.0",
+    "@langchain/core": ">0.2.14 <0.3.0",
     "ollama": "^0.5.2"
   },
   "devDependencies": {
     "@jest/globals": "^29.5.0",
     "@langchain/scripts": "~0.0.14",
+    "@langchain/standard-tests": "0.0.0",
     "@swc/core": "^1.3.90",
     "@swc/jest": "^0.2.29",
     "@tsconfig/recommended": "^1.0.3",
diff --git a/libs/langchain-ollama/src/chat_models.ts b/libs/langchain-ollama/src/chat_models.ts
index d3e7fba69157..d4764b0d2f7a 100644
--- a/libs/langchain-ollama/src/chat_models.ts
+++ b/libs/langchain-ollama/src/chat_models.ts
@@ -1,4 +1,9 @@
-import { MessageContentText, type BaseMessage } from "@langchain/core/messages";
+import {
+  AIMessage,
+  MessageContentText,
+  UsageMetadata,
+  type BaseMessage,
+} from "@langchain/core/messages";
 import { type BaseLanguageModelCallOptions } from "@langchain/core/language_models/base";
 import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager";
@@ -12,6 +17,7 @@ import { ChatGenerationChunk, ChatResult } from "@langchain/core/outputs";
 import { AIMessageChunk } from "@langchain/core/messages";
 import type {
   ChatRequest as OllamaChatRequest,
+  ChatResponse as OllamaChatResponse,
   Message as OllamaMessage,
 } from "ollama";
@@ -144,7 +150,7 @@ export interface ChatOllamaInput extends BaseChatModelParams {
    * exist.
    * @default false
    */
-  checkModelExists?: boolean;
+  checkOrPullModel?: boolean;
   streaming?: boolean;
   numa?: boolean;
   numCtx?: number;
@@ -183,7 +189,22 @@
 }
 
 /**
- * Integration with a chat model.
+ * Integration with the Ollama SDK.
+ *
+ * @example
+ * ```typescript
+ * import { ChatOllama } from "@langchain/ollama";
+ *
+ * const model = new ChatOllama({
+ *   model: "llama3", // Default model.
+ * });
+ *
+ * const result = await model.invoke([
+ *   "human",
+ *   "What is a good name for a company that makes colorful socks?",
+ * ]);
+ * console.log(result);
+ * ```
  */
 export class ChatOllama
   extends BaseChatModel<ChatOllamaCallOptions, AIMessageChunk>
@@ -262,7 +283,7 @@ export class ChatOllama
 
   client: Ollama;
 
-  checkModelExists = false;
+  checkOrPullModel = false;
 
   constructor(fields?: ChatOllamaInput) {
     super(fields ?? {});
@@ -304,7 +325,7 @@ export class ChatOllama
     this.streaming = fields?.streaming;
     this.format = fields?.format;
     this.keepAlive = fields?.keepAlive ?? this.keepAlive;
-    this.checkModelExists = fields?.checkModelExists ?? this.checkModelExists;
+    this.checkOrPullModel = fields?.checkOrPullModel ?? this.checkOrPullModel;
   }
 
   // Replace
@@ -415,7 +436,7 @@ export class ChatOllama
     options: this["ParsedCallOptions"],
     runManager?: CallbackManagerForLLMRun
   ): Promise<ChatResult> {
-    if (this.checkModelExists) {
+    if (this.checkOrPullModel) {
       if (!(await this.checkModelExistsOnMachine(this.model))) {
         await this.pull(this.model, {
           logProgress: true,
@@ -423,23 +444,33 @@
       }
     }
 
-    let finalChunk: ChatGenerationChunk | undefined;
+    let finalChunk: AIMessageChunk | undefined;
     for await (const chunk of this._streamResponseChunks(
       messages,
       options,
       runManager
     )) {
       if (!finalChunk) {
-        finalChunk = chunk;
+        finalChunk = chunk.message;
       } else {
-        finalChunk = finalChunk.concat(chunk);
+        finalChunk = finalChunk.concat(chunk.message);
       }
     }
+
+    // Convert from AIMessageChunk to AIMessage since `generate` expects AIMessage.
+    const nonChunkMessage = new AIMessage({
+      content: finalChunk?.content ?? "",
+      response_metadata: finalChunk?.response_metadata,
+      usage_metadata: finalChunk?.usage_metadata,
+    });
     return {
       generations: [
         {
-          text: finalChunk?.text ?? "",
-          message: finalChunk?.message as AIMessageChunk,
+          text:
+            typeof nonChunkMessage.content === "string"
+              ? nonChunkMessage.content
+              : "",
+          message: nonChunkMessage,
         },
       ],
     };
@@ -454,7 +485,7 @@ export class ChatOllama
     options: this["ParsedCallOptions"],
     runManager?: CallbackManagerForLLMRun
   ): AsyncGenerator<ChatGenerationChunk> {
-    if (this.checkModelExists) {
+    if (this.checkOrPullModel) {
       if (!(await this.checkModelExistsOnMachine(this.model))) {
         await this.pull(this.model, {
           logProgress: true,
@@ -465,26 +496,47 @@
       }
     }
     const params = this.invocationParams(options);
     const ollamaMessages = convertToOllamaMessages(messages);
-    const stream = this.client.chat({
+    const stream = await this.client.chat({
       ...params,
       messages: ollamaMessages,
       stream: true,
     });
-    for await (const chunk of await stream) {
+
+    let lastMetadata: Omit<OllamaChatResponse, "message"> | undefined;
+    const usageMetadata: UsageMetadata = {
+      input_tokens: 0,
+      output_tokens: 0,
+      total_tokens: 0,
+    };
+
+    for await (const chunk of stream) {
       if (options.signal?.aborted) {
         this.client.abort();
       }
       const { message: responseMessage, ...rest } = chunk;
+      usageMetadata.input_tokens += rest.prompt_eval_count ?? 0;
+      usageMetadata.output_tokens += rest.eval_count ?? 0;
+      usageMetadata.total_tokens =
+        usageMetadata.input_tokens + usageMetadata.output_tokens;
+      lastMetadata = rest;
+
       yield new ChatGenerationChunk({
         text: responseMessage.content,
         message: new AIMessageChunk({
           content: responseMessage.content,
-          response_metadata: {
-            ...rest,
-          },
         }),
       });
       await runManager?.handleLLMNewToken(responseMessage.content);
     }
+
+    // Yield the `response_metadata` as the final chunk.
+    yield new ChatGenerationChunk({
+      text: "",
+      message: new AIMessageChunk({
+        content: "",
+        response_metadata: lastMetadata,
+        usage_metadata: usageMetadata,
+      }),
+    });
   }
 }
diff --git a/libs/langchain-ollama/src/tests/chat_models.int.test.ts b/libs/langchain-ollama/src/tests/chat_models.int.test.ts
index 57cab8749d18..47be20413724 100644
--- a/libs/langchain-ollama/src/tests/chat_models.int.test.ts
+++ b/libs/langchain-ollama/src/tests/chat_models.int.test.ts
@@ -11,10 +11,11 @@
 import { ChatOllama } from "../chat_models.js";
 
 test("test invoke", async () => {
-  const ollama = new ChatOllama({});
-  const result = await ollama.invoke(
-    "What is a good name for a company that makes colorful socks?"
-  );
+  const ollama = new ChatOllama();
+  const result = await ollama.invoke([
+    "human",
+    "What is a good name for a company that makes colorful socks?",
+  ]);
   expect(result).toBeDefined();
   expect(typeof result.content).toBe("string");
   expect(result.content.length).toBeGreaterThan(1);
@@ -127,11 +128,9 @@ AI:`;
 test("Test ChatOllama with an image", async () => {
   const __filename = fileURLToPath(import.meta.url);
   const __dirname = path.dirname(__filename);
-  console.log("__dirname", __dirname);
   const imageData = await fs.readFile(path.join(__dirname, "/data/hotdog.jpg"));
   const chat = new ChatOllama({
     model: "llava",
-    checkModelExists: true,
   });
   const res = await chat.invoke([
     new HumanMessage({
diff --git a/libs/langchain-ollama/src/tests/chat_models.standard.int.test.ts b/libs/langchain-ollama/src/tests/chat_models.standard.int.test.ts
new file mode 100644
index 000000000000..c09c6a88ad6f
--- /dev/null
+++ b/libs/langchain-ollama/src/tests/chat_models.standard.int.test.ts
@@ -0,0 +1,26 @@
+/* eslint-disable no-process-env */
+import { test, expect } from "@jest/globals";
+import { ChatModelIntegrationTests } from "@langchain/standard-tests";
+import { AIMessageChunk } from "@langchain/core/messages";
+import { ChatOllama, ChatOllamaCallOptions } from "../chat_models.js";
+
+class ChatOllamaStandardIntegrationTests extends ChatModelIntegrationTests<
+  ChatOllamaCallOptions,
+  AIMessageChunk
+> {
+  constructor() {
+    super({
+      Cls: ChatOllama,
+      chatModelHasToolCalling: false,
+      chatModelHasStructuredOutput: false,
+      constructorArgs: {},
+    });
+  }
+}
+
+const testClass = new ChatOllamaStandardIntegrationTests();
+
+test("ChatOllamaStandardIntegrationTests", async () => {
+  const testResults = await testClass.runTests();
+  expect(testResults).toBe(true);
+});
diff --git a/libs/langchain-ollama/src/tests/chat_models.standard.test.ts b/libs/langchain-ollama/src/tests/chat_models.standard.test.ts
new file mode 100644
index 000000000000..aa4e1e156790
--- /dev/null
+++ b/libs/langchain-ollama/src/tests/chat_models.standard.test.ts
@@ -0,0 +1,34 @@
+/* eslint-disable no-process-env */
+import { test, expect } from "@jest/globals";
+import { ChatModelUnitTests } from "@langchain/standard-tests";
+import { AIMessageChunk } from "@langchain/core/messages";
+import { ChatOllama, ChatOllamaCallOptions } from "../chat_models.js";
+
+class ChatOllamaStandardUnitTests extends ChatModelUnitTests<
+  ChatOllamaCallOptions,
+  AIMessageChunk
+> {
+  constructor() {
+    super({
+      Cls: ChatOllama,
+      chatModelHasToolCalling: false,
+      chatModelHasStructuredOutput: false,
+      constructorArgs: {},
+    });
+  }
+
+  testChatModelInitApiKey() {
+    this.skipTestMessage(
+      "testChatModelInitApiKey",
+      "ChatOllama",
+      "API key is not required for ChatOllama"
+    );
+  }
+}
+
+const testClass = new ChatOllamaStandardUnitTests();
+
+test("ChatOllamaStandardUnitTests", () => {
+  const testResults = testClass.runTests();
+  expect(testResults).toBe(true);
+});
diff --git a/yarn.lock b/yarn.lock
index 633c5eae894a..1b70f5010c7e 100644
--- a/yarn.lock
+++ b/yarn.lock
@@ -11462,6 +11462,38 @@ __metadata:
   languageName: unknown
   linkType: soft
 
+"@langchain/ollama@workspace:libs/langchain-ollama":
+  version: 0.0.0-use.local
+  resolution: "@langchain/ollama@workspace:libs/langchain-ollama"
+  dependencies:
+    "@jest/globals": ^29.5.0
+    "@langchain/core": ">0.2.14 <0.3.0"
+    "@langchain/scripts": ~0.0.14
+    "@langchain/standard-tests": 0.0.0
+    "@swc/core": ^1.3.90
+    "@swc/jest": ^0.2.29
+    "@tsconfig/recommended": ^1.0.3
"@typescript-eslint/eslint-plugin": ^6.12.0 + "@typescript-eslint/parser": ^6.12.0 + dotenv: ^16.3.1 + dpdm: ^3.12.0 + eslint: ^8.33.0 + eslint-config-airbnb-base: ^15.0.0 + eslint-config-prettier: ^8.6.0 + eslint-plugin-import: ^2.27.5 + eslint-plugin-no-instanceof: ^1.0.1 + eslint-plugin-prettier: ^4.2.1 + jest: ^29.5.0 + jest-environment-node: ^29.6.4 + ollama: ^0.5.2 + prettier: ^2.8.3 + release-it: ^15.10.1 + rollup: ^4.5.2 + ts-jest: ^29.1.0 + typescript: <5.2.0 + languageName: unknown + linkType: soft + "@langchain/openai@>=0.1.0 <0.3.0, @langchain/openai@workspace:*, @langchain/openai@workspace:^, @langchain/openai@workspace:libs/langchain-openai": version: 0.0.0-use.local resolution: "@langchain/openai@workspace:libs/langchain-openai" @@ -29870,37 +29902,6 @@ __metadata: languageName: node linkType: hard -"langchain-ollama@workspace:libs/langchain-ollama": - version: 0.0.0-use.local - resolution: "langchain-ollama@workspace:libs/langchain-ollama" - dependencies: - "@jest/globals": ^29.5.0 - "@langchain/core": ">0.1.0 <0.3.0" - "@langchain/scripts": ~0.0.14 - "@swc/core": ^1.3.90 - "@swc/jest": ^0.2.29 - "@tsconfig/recommended": ^1.0.3 - "@typescript-eslint/eslint-plugin": ^6.12.0 - "@typescript-eslint/parser": ^6.12.0 - dotenv: ^16.3.1 - dpdm: ^3.12.0 - eslint: ^8.33.0 - eslint-config-airbnb-base: ^15.0.0 - eslint-config-prettier: ^8.6.0 - eslint-plugin-import: ^2.27.5 - eslint-plugin-no-instanceof: ^1.0.1 - eslint-plugin-prettier: ^4.2.1 - jest: ^29.5.0 - jest-environment-node: ^29.6.4 - ollama: ^0.5.2 - prettier: ^2.8.3 - release-it: ^15.10.1 - rollup: ^4.5.2 - ts-jest: ^29.1.0 - typescript: <5.2.0 - languageName: unknown - linkType: soft - "langchain@npm:0.2.3": version: 0.2.3 resolution: "langchain@npm:0.2.3"