Commit: add standard tests

bracesproul committed Jul 7, 2024
1 parent 641690d commit 4d186b9

Showing 6 changed files with 169 additions and 56 deletions.
5 changes: 3 additions & 2 deletions libs/langchain-ollama/package.json
@@ -1,5 +1,5 @@
{
"name": "langchain-ollama",
"name": "@langchain/ollama",
"version": "0.0.0",
"description": "Ollama integration for LangChain.js",
"type": "module",
@@ -35,12 +35,13 @@
"author": "LangChain",
"license": "MIT",
"dependencies": {
"@langchain/core": ">0.1.0 <0.3.0",
"@langchain/core": ">0.2.14 <0.3.0",
"ollama": "^0.5.2"
},
"devDependencies": {
"@jest/globals": "^29.5.0",
"@langchain/scripts": "~0.0.14",
"@langchain/standard-tests": "0.0.0",
"@swc/core": "^1.3.90",
"@swc/jest": "^0.2.29",
"@tsconfig/recommended": "^1.0.3",
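Note: with the rename from `langchain-ollama` to the scoped `@langchain/ollama`, downstream imports change accordingly. A minimal sketch of usage after the rename (the model name and prompt are illustrative):

```typescript
// After this commit, the integration is imported from the scoped package.
import { ChatOllama } from "@langchain/ollama";

const model = new ChatOllama({
  model: "llama3", // assumes a model already pulled locally via Ollama
});

const result = await model.invoke("Why is the sky blue?");
console.log(result.content);
```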
86 changes: 69 additions & 17 deletions libs/langchain-ollama/src/chat_models.ts
@@ -1,4 +1,9 @@
import { MessageContentText, type BaseMessage } from "@langchain/core/messages";
import {
AIMessage,
MessageContentText,
UsageMetadata,
type BaseMessage,
} from "@langchain/core/messages";
import { type BaseLanguageModelCallOptions } from "@langchain/core/language_models/base";

import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager";
@@ -12,6 +17,7 @@ import { ChatGenerationChunk, ChatResult } from "@langchain/core/outputs";
import { AIMessageChunk } from "@langchain/core/messages";
import type {
ChatRequest as OllamaChatRequest,
ChatResponse as OllamaChatResponse,
Message as OllamaMessage,
} from "ollama";

@@ -144,7 +150,7 @@ export interface ChatOllamaInput extends BaseChatModelParams {
* exist.
* @default false
*/
checkModelExists?: boolean;
checkOrPullModel?: boolean;
streaming?: boolean;
numa?: boolean;
numCtx?: number;
@@ -183,7 +189,22 @@ export interface ChatOllamaInput extends BaseChatModelParams {
}

/**
* Integration with a chat model.
* Integration with the Ollama SDK.
*
* @example
* ```typescript
* import { ChatOllama } from "@langchain/ollama";
*
* const model = new ChatOllama({
* model: "llama3", // Default model.
* });
*
* const result = await model.invoke([
* "human",
* "What is a good name for a company that makes colorful socks?",
* ]);
* console.log(result);
* ```
*/
export class ChatOllama
extends BaseChatModel<ChatOllamaCallOptions, AIMessageChunk>
@@ -262,7 +283,7 @@ export class ChatOllama

client: Ollama;

checkModelExists = false;
checkOrPullModel = false;

constructor(fields?: ChatOllamaInput) {
super(fields ?? {});
@@ -304,7 +325,7 @@ export class ChatOllama
this.streaming = fields?.streaming;
this.format = fields?.format;
this.keepAlive = fields?.keepAlive ?? this.keepAlive;
this.checkModelExists = fields?.checkModelExists ?? this.checkModelExists;
this.checkOrPullModel = fields?.checkOrPullModel ?? this.checkOrPullModel;
}

// Replace
@@ -415,31 +436,41 @@ export class ChatOllama
options: this["ParsedCallOptions"],
runManager?: CallbackManagerForLLMRun
): Promise<ChatResult> {
if (this.checkModelExists) {
if (this.checkOrPullModel) {
if (!(await this.checkModelExistsOnMachine(this.model))) {
await this.pull(this.model, {
logProgress: true,
});
}
}

let finalChunk: ChatGenerationChunk | undefined;
let finalChunk: AIMessageChunk | undefined;
for await (const chunk of this._streamResponseChunks(
messages,
options,
runManager
)) {
if (!finalChunk) {
finalChunk = chunk;
finalChunk = chunk.message;
} else {
finalChunk = finalChunk.concat(chunk);
finalChunk = finalChunk.concat(chunk.message);
}
}

// Convert from AIMessageChunk to AIMessage since `generate` expects AIMessage.
const nonChunkMessage = new AIMessage({
content: finalChunk?.content ?? "",
response_metadata: finalChunk?.response_metadata,
usage_metadata: finalChunk?.usage_metadata,
});
return {
generations: [
{
text: finalChunk?.text ?? "",
message: finalChunk?.message as AIMessageChunk,
text:
typeof nonChunkMessage.content === "string"
? nonChunkMessage.content
: "",
message: nonChunkMessage,
},
],
};
@@ -454,7 +485,7 @@ export class ChatOllama
options: this["ParsedCallOptions"],
runManager?: CallbackManagerForLLMRun
): AsyncGenerator<ChatGenerationChunk> {
if (this.checkModelExists) {
if (this.checkOrPullModel) {
if (!(await this.checkModelExistsOnMachine(this.model))) {
await this.pull(this.model, {
logProgress: true,
@@ -465,26 +496,47 @@ export class ChatOllama
const params = this.invocationParams(options);
const ollamaMessages = convertToOllamaMessages(messages);

const stream = this.client.chat({
const stream = await this.client.chat({
...params,
messages: ollamaMessages,
stream: true,
});
for await (const chunk of await stream) {

let lastMetadata: Omit<OllamaChatResponse, "message"> | undefined;
const usageMetadata: UsageMetadata = {
input_tokens: 0,
output_tokens: 0,
total_tokens: 0,
};

for await (const chunk of stream) {
if (options.signal?.aborted) {
this.client.abort();
}
const { message: responseMessage, ...rest } = chunk;
usageMetadata.input_tokens += rest.prompt_eval_count ?? 0;
usageMetadata.output_tokens += rest.eval_count ?? 0;
usageMetadata.total_tokens =
usageMetadata.input_tokens + usageMetadata.output_tokens;
lastMetadata = rest;

yield new ChatGenerationChunk({
text: responseMessage.content,
message: new AIMessageChunk({
content: responseMessage.content,
response_metadata: {
...rest,
},
}),
});
await runManager?.handleLLMNewToken(responseMessage.content);
}

// Yield the `response_metadata` as the final chunk.
yield new ChatGenerationChunk({
text: "",
message: new AIMessageChunk({
content: "",
response_metadata: lastMetadata,
usage_metadata: usageMetadata,
}),
});
}
}
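The two behavioral changes above — the `checkModelExists` → `checkOrPullModel` rename and the new per-stream usage accounting — surface to callers roughly as follows. A sketch, assuming a locally running Ollama server; the model name and prompt are illustrative:

```typescript
import { ChatOllama } from "@langchain/ollama";
import { AIMessageChunk } from "@langchain/core/messages";

const model = new ChatOllama({
  model: "llama3", // illustrative; any locally available Ollama model
  // Renamed from `checkModelExists`: when true, a model missing from the
  // local machine is pulled (with progress logging) before generating.
  checkOrPullModel: true,
});

// Stream the response and fold the chunks together. The final chunk
// emitted by the integration carries the aggregated usage metadata.
let finalChunk: AIMessageChunk | undefined;
for await (const chunk of await model.stream("Why is the sky blue?")) {
  finalChunk = finalChunk ? finalChunk.concat(chunk) : chunk;
}
console.log(finalChunk?.usage_metadata);
// e.g. { input_tokens: 12, output_tokens: 42, total_tokens: 54 }
```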
11 changes: 5 additions & 6 deletions libs/langchain-ollama/src/tests/chat_models.int.test.ts
@@ -11,10 +11,11 @@ import {
import { ChatOllama } from "../chat_models.js";

test("test invoke", async () => {
const ollama = new ChatOllama({});
const result = await ollama.invoke(
"What is a good name for a company that makes colorful socks?"
);
const ollama = new ChatOllama();
const result = await ollama.invoke([
"human",
"What is a good name for a company that makes colorful socks?",
]);
expect(result).toBeDefined();
expect(typeof result.content).toBe("string");
expect(result.content.length).toBeGreaterThan(1);
@@ -127,11 +128,9 @@ AI:`;
test("Test ChatOllama with an image", async () => {
const __filename = fileURLToPath(import.meta.url);
const __dirname = path.dirname(__filename);
console.log("__dirname", __dirname);
const imageData = await fs.readFile(path.join(__dirname, "/data/hotdog.jpg"));
const chat = new ChatOllama({
model: "llava",
checkModelExists: true,
});
const res = await chat.invoke([
new HumanMessage({
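The body of the image test is truncated in the diff above. For reference, the multimodal call follows LangChain's content-block format, roughly as sketched below — this assumes the base64 data-URL form accepted elsewhere in LangChain's multimodal integrations, and the file path is illustrative:

```typescript
import * as fs from "node:fs/promises";
import { HumanMessage } from "@langchain/core/messages";
import { ChatOllama } from "@langchain/ollama";

// Path is illustrative; the test reads a local fixture image.
const imageData = await fs.readFile("./data/hotdog.jpg");

const chat = new ChatOllama({ model: "llava" });
const res = await chat.invoke([
  new HumanMessage({
    content: [
      { type: "text", text: "What is in this image?" },
      {
        // Images are passed as base64 data URLs inside content blocks.
        type: "image_url",
        image_url: `data:image/jpeg;base64,${imageData.toString("base64")}`,
      },
    ],
  }),
]);
console.log(res.content);
```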
26 changes: 26 additions & 0 deletions libs/langchain-ollama/src/tests/chat_models.standard.int.test.ts
@@ -0,0 +1,26 @@
/* eslint-disable no-process-env */
import { test, expect } from "@jest/globals";
import { ChatModelIntegrationTests } from "@langchain/standard-tests";
import { AIMessageChunk } from "@langchain/core/messages";
import { ChatOllama, ChatOllamaCallOptions } from "../chat_models.js";

class ChatOllamaStandardIntegrationTests extends ChatModelIntegrationTests<
ChatOllamaCallOptions,
AIMessageChunk
> {
constructor() {
super({
Cls: ChatOllama,
chatModelHasToolCalling: false,
chatModelHasStructuredOutput: false,
constructorArgs: {},
});
}
}

const testClass = new ChatOllamaStandardIntegrationTests();

test("ChatOllamaStandardIntegrationTests", async () => {
const testResults = await testClass.runTests();
expect(testResults).toBe(true);
});
34 changes: 34 additions & 0 deletions libs/langchain-ollama/src/tests/chat_models.standard.test.ts
@@ -0,0 +1,34 @@
/* eslint-disable no-process-env */
import { test, expect } from "@jest/globals";
import { ChatModelUnitTests } from "@langchain/standard-tests";
import { AIMessageChunk } from "@langchain/core/messages";
import { ChatOllama, ChatOllamaCallOptions } from "../chat_models.js";

class ChatOllamaStandardUnitTests extends ChatModelUnitTests<
ChatOllamaCallOptions,
AIMessageChunk
> {
constructor() {
super({
Cls: ChatOllama,
chatModelHasToolCalling: false,
chatModelHasStructuredOutput: false,
constructorArgs: {},
});
}

testChatModelInitApiKey() {
this.skipTestMessage(
"testChatModelInitApiKey",
"ChatOllama",
"API key is not required for ChatOllama"
);
}
}

const testClass = new ChatOllamaStandardUnitTests();

test("ChatOllamaStandardUnitTests", () => {
const testResults = testClass.runTests();
expect(testResults).toBe(true);
});
63 changes: 32 additions & 31 deletions yarn.lock
@@ -11462,6 +11462,38 @@ __metadata:
languageName: unknown
linkType: soft

"@langchain/ollama@workspace:libs/langchain-ollama":
version: 0.0.0-use.local
resolution: "@langchain/ollama@workspace:libs/langchain-ollama"
dependencies:
"@jest/globals": ^29.5.0
"@langchain/core": ">0.2.14 <0.3.0"
"@langchain/scripts": ~0.0.14
"@langchain/standard-tests": 0.0.0
"@swc/core": ^1.3.90
"@swc/jest": ^0.2.29
"@tsconfig/recommended": ^1.0.3
"@typescript-eslint/eslint-plugin": ^6.12.0
"@typescript-eslint/parser": ^6.12.0
dotenv: ^16.3.1
dpdm: ^3.12.0
eslint: ^8.33.0
eslint-config-airbnb-base: ^15.0.0
eslint-config-prettier: ^8.6.0
eslint-plugin-import: ^2.27.5
eslint-plugin-no-instanceof: ^1.0.1
eslint-plugin-prettier: ^4.2.1
jest: ^29.5.0
jest-environment-node: ^29.6.4
ollama: ^0.5.2
prettier: ^2.8.3
release-it: ^15.10.1
rollup: ^4.5.2
ts-jest: ^29.1.0
typescript: <5.2.0
languageName: unknown
linkType: soft

"@langchain/openai@>=0.1.0 <0.3.0, @langchain/openai@workspace:*, @langchain/openai@workspace:^, @langchain/openai@workspace:libs/langchain-openai":
version: 0.0.0-use.local
resolution: "@langchain/openai@workspace:libs/langchain-openai"
@@ -29870,37 +29902,6 @@ __metadata:
languageName: node
linkType: hard

"langchain-ollama@workspace:libs/langchain-ollama":
version: 0.0.0-use.local
resolution: "langchain-ollama@workspace:libs/langchain-ollama"
dependencies:
"@jest/globals": ^29.5.0
"@langchain/core": ">0.1.0 <0.3.0"
"@langchain/scripts": ~0.0.14
"@swc/core": ^1.3.90
"@swc/jest": ^0.2.29
"@tsconfig/recommended": ^1.0.3
"@typescript-eslint/eslint-plugin": ^6.12.0
"@typescript-eslint/parser": ^6.12.0
dotenv: ^16.3.1
dpdm: ^3.12.0
eslint: ^8.33.0
eslint-config-airbnb-base: ^15.0.0
eslint-config-prettier: ^8.6.0
eslint-plugin-import: ^2.27.5
eslint-plugin-no-instanceof: ^1.0.1
eslint-plugin-prettier: ^4.2.1
jest: ^29.5.0
jest-environment-node: ^29.6.4
ollama: ^0.5.2
prettier: ^2.8.3
release-it: ^15.10.1
rollup: ^4.5.2
ts-jest: ^29.1.0
typescript: <5.2.0
languageName: unknown
linkType: soft

"langchain@npm:0.2.3":
version: 0.2.3
resolution: "langchain@npm:0.2.3"
