Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(community) : Upgrade node-llama-cpp to be compatible with version 3 #7135

Merged
merged 16 commits into from
Nov 12, 2024
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/src/embeddings/llama_cpp_basic.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import { LlamaCppEmbeddings } from "@langchain/community/embeddings/llama_cpp";

const llamaPath = "/Replace/with/path/to/your/model/gguf-llama2-q4_0.bin";

const embeddings = new LlamaCppEmbeddings({
const embeddings = await LlamaCppEmbeddings.initialize({
modelPath: llamaPath,
});

Expand Down
2 changes: 1 addition & 1 deletion examples/src/embeddings/llama_cpp_docs.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ const llamaPath = "/Replace/with/path/to/your/model/gguf-llama2-q4_0.bin";

const documents = ["Hello World!", "Bye Bye!"];

const embeddings = new LlamaCppEmbeddings({
const embeddings = await LlamaCppEmbeddings.initialize({
modelPath: llamaPath,
});

Expand Down
2 changes: 1 addition & 1 deletion examples/src/models/chat/integration_llama_cpp.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import { HumanMessage } from "@langchain/core/messages";

const llamaPath = "/Replace/with/path/to/your/model/gguf-llama2-q4_0.bin";

const model = new ChatLlamaCpp({ modelPath: llamaPath });
const model = await ChatLlamaCpp.initialize({ modelPath: llamaPath });

const response = await model.invoke([
new HumanMessage({ content: "My name is John." }),
Expand Down
5 changes: 4 additions & 1 deletion examples/src/models/chat/integration_llama_cpp_chain.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@ import { PromptTemplate } from "@langchain/core/prompts";

const llamaPath = "/Replace/with/path/to/your/model/gguf-llama2-q4_0.bin";

const model = new ChatLlamaCpp({ modelPath: llamaPath, temperature: 0.5 });
const model = await ChatLlamaCpp.initialize({
modelPath: llamaPath,
temperature: 0.5,
});

const prompt = PromptTemplate.fromTemplate(
"What is a good name for a company that makes {product}?"
Expand Down
5 changes: 4 additions & 1 deletion examples/src/models/chat/integration_llama_cpp_stream.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@ import { ChatLlamaCpp } from "@langchain/community/chat_models/llama_cpp";

const llamaPath = "/Replace/with/path/to/your/model/gguf-llama2-q4_0.bin";

const model = new ChatLlamaCpp({ modelPath: llamaPath, temperature: 0.7 });
const model = await ChatLlamaCpp.initialize({
modelPath: llamaPath,
temperature: 0.7,
});

const stream = await model.stream("Tell me a short story about a happy Llama.");

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@ import { SystemMessage, HumanMessage } from "@langchain/core/messages";

const llamaPath = "/Replace/with/path/to/your/model/gguf-llama2-q4_0.bin";

const model = new ChatLlamaCpp({ modelPath: llamaPath, temperature: 0.7 });
const model = await ChatLlamaCpp.initialize({
modelPath: llamaPath,
temperature: 0.7,
});

const controller = new AbortController();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@ import { SystemMessage, HumanMessage } from "@langchain/core/messages";

const llamaPath = "/Replace/with/path/to/your/model/gguf-llama2-q4_0.bin";

const llamaCpp = new ChatLlamaCpp({ modelPath: llamaPath, temperature: 0.7 });
const llamaCpp = await ChatLlamaCpp.initialize({
modelPath: llamaPath,
temperature: 0.7,
});

const stream = await llamaCpp.stream([
new SystemMessage(
Expand Down
2 changes: 1 addition & 1 deletion examples/src/models/chat/integration_llama_cpp_system.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import { SystemMessage, HumanMessage } from "@langchain/core/messages";

const llamaPath = "/Replace/with/path/to/your/model/gguf-llama2-q4_0.bin";

const model = new ChatLlamaCpp({ modelPath: llamaPath });
const model = await ChatLlamaCpp.initialize({ modelPath: llamaPath });

const response = await model.invoke([
new SystemMessage(
Expand Down
2 changes: 1 addition & 1 deletion examples/src/models/llm/llama_cpp.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import { LlamaCpp } from "@langchain/community/llms/llama_cpp";
const llamaPath = "/Replace/with/path/to/your/model/gguf-llama2-q4_0.bin";
const question = "Where do Llamas come from?";

const model = new LlamaCpp({ modelPath: llamaPath });
const model = await LlamaCpp.initialize({ modelPath: llamaPath });

console.log(`You: ${question}`);
const response = await model.invoke(question);
Expand Down
5 changes: 4 additions & 1 deletion examples/src/models/llm/llama_cpp_stream.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@ import { LlamaCpp } from "@langchain/community/llms/llama_cpp";

const llamaPath = "/Replace/with/path/to/your/model/gguf-llama2-q4_0.bin";

const model = new LlamaCpp({ modelPath: llamaPath, temperature: 0.7 });
const model = await LlamaCpp.initialize({
modelPath: llamaPath,
temperature: 0.7,
});

const prompt = "Tell me a short story about a happy Llama.";

Expand Down
2 changes: 1 addition & 1 deletion libs/langchain-community/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@
"mongodb": "^5.2.0",
"mysql2": "^3.9.8",
"neo4j-driver": "^5.17.0",
"node-llama-cpp": "^2",
"node-llama-cpp": "3.1.1",
"notion-to-md": "^3.1.0",
"officeparser": "^4.0.4",
"pdf-parse": "1.1.1",
Expand Down
93 changes: 59 additions & 34 deletions libs/langchain-community/src/chat_models/llama_cpp.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@ import {
LlamaModel,
LlamaContext,
LlamaChatSession,
type ConversationInteraction,
type Token,
ChatUserMessage,
ChatModelResponse,
ChatHistoryItem,
getLlama,
} from "node-llama-cpp";

import {
Expand Down Expand Up @@ -47,7 +51,7 @@ export interface LlamaCppCallOptions extends BaseLanguageModelCallOptions {
* @example
* ```typescript
* // Initialize the ChatLlamaCpp model with the path to the model binary file.
* const model = new ChatLlamaCpp({
* const model = await ChatLlamaCpp.initialize({
* modelPath: "/Replace/with/path/to/your/model/gguf-llama2-q4_0.bin",
* temperature: 0.5,
* });
Expand Down Expand Up @@ -87,20 +91,35 @@ export class ChatLlamaCpp extends SimpleChatModel<LlamaCppCallOptions> {
return "ChatLlamaCpp";
}

constructor(inputs: LlamaCppInputs) {
public constructor(inputs: LlamaCppInputs) {
super(inputs);
this.maxTokens = inputs?.maxTokens;
this.temperature = inputs?.temperature;
this.topK = inputs?.topK;
this.topP = inputs?.topP;
this.trimWhitespaceSuffix = inputs?.trimWhitespaceSuffix;
this._model = createLlamaModel(inputs);
this._context = createLlamaContext(this._model, inputs);
this._session = null;
}

/**
* Initializes the llama_cpp model for usage in the chat models wrapper.
* @param inputs - the inputs passed onto the model.
* @returns A Promise that resolves to the ChatLlamaCpp type class.
*/
public static async initialize(
inputs: LlamaBaseCppInputs
): Promise<ChatLlamaCpp> {
const instance = new ChatLlamaCpp(inputs);
const llama = await getLlama();

instance._model = await createLlamaModel(inputs, llama);
instance._context = await createLlamaContext(instance._model, inputs);

return instance;
}

_llmType() {
return "llama2_cpp";
return "llama_cpp";
}

/** @ignore */
Expand Down Expand Up @@ -146,7 +165,9 @@ export class ChatLlamaCpp extends SimpleChatModel<LlamaCppCallOptions> {
signal: options.signal,
onToken: async (tokens: number[]) => {
options.onToken?.(tokens);
await runManager?.handleLLMNewToken(this._context.decode(tokens));
await runManager?.handleLLMNewToken(
this._model.detokenize(tokens.map((num) => num as Token))
);
},
maxTokens: this?.maxTokens,
temperature: this?.temperature,
Expand Down Expand Up @@ -180,20 +201,23 @@ export class ChatLlamaCpp extends SimpleChatModel<LlamaCppCallOptions> {
};

const prompt = this._buildPrompt(input);
const sequence = this._context.getSequence();

const stream = await this.caller.call(async () =>
this._context.evaluate(this._context.encode(prompt), promptOptions)
sequence.evaluate(this._model.tokenize(prompt), promptOptions)
);

for await (const chunk of stream) {
yield new ChatGenerationChunk({
text: this._context.decode([chunk]),
text: this._model.detokenize([chunk]),
message: new AIMessageChunk({
content: this._context.decode([chunk]),
content: this._model.detokenize([chunk]),
}),
generationInfo: {},
});
await runManager?.handleLLMNewToken(this._context.decode([chunk]) ?? "");
await runManager?.handleLLMNewToken(
this._model.detokenize([chunk]) ?? ""
);
}
}

Expand All @@ -202,12 +226,12 @@ export class ChatLlamaCpp extends SimpleChatModel<LlamaCppCallOptions> {
let prompt = "";
let sysMessage = "";
let noSystemMessages: BaseMessage[] = [];
let interactions: ConversationInteraction[] = [];
let interactions: ChatHistoryItem[] = [];

// Let's see if we have a system message
if (messages.findIndex((msg) => msg._getType() === "system") !== -1) {
if (messages.findIndex((msg) => msg.getType() === "system") !== -1) {
const sysMessages = messages.filter(
(message) => message._getType() === "system"
(message) => message.getType() === "system"
);

const systemMessageContent = sysMessages[sysMessages.length - 1].content;
Expand All @@ -222,7 +246,7 @@ export class ChatLlamaCpp extends SimpleChatModel<LlamaCppCallOptions> {

// Now filter out the system messages
noSystemMessages = messages.filter(
(message) => message._getType() !== "system"
(message) => message.getType() !== "system"
);
} else {
noSystemMessages = messages;
Expand All @@ -231,9 +255,7 @@ export class ChatLlamaCpp extends SimpleChatModel<LlamaCppCallOptions> {
// Lets see if we just have a prompt left or are their previous interactions?
if (noSystemMessages.length > 1) {
// Is the last message a prompt?
if (
noSystemMessages[noSystemMessages.length - 1]._getType() === "human"
) {
if (noSystemMessages[noSystemMessages.length - 1].getType() === "human") {
const finalMessageContent =
noSystemMessages[noSystemMessages.length - 1].content;
if (typeof finalMessageContent !== "string") {
Expand Down Expand Up @@ -261,23 +283,23 @@ export class ChatLlamaCpp extends SimpleChatModel<LlamaCppCallOptions> {
// Now lets construct a session according to what we got
if (sysMessage !== "" && interactions.length > 0) {
this._session = new LlamaChatSession({
context: this._context,
conversationHistory: interactions,
contextSequence: this._context.getSequence(),
systemPrompt: sysMessage,
});
this._session.setChatHistory(interactions);
} else if (sysMessage !== "" && interactions.length === 0) {
this._session = new LlamaChatSession({
context: this._context,
contextSequence: this._context.getSequence(),
systemPrompt: sysMessage,
});
} else if (sysMessage === "" && interactions.length > 0) {
this._session = new LlamaChatSession({
context: this._context,
conversationHistory: interactions,
contextSequence: this._context.getSequence(),
});
this._session.setChatHistory(interactions);
} else {
this._session = new LlamaChatSession({
context: this._context,
contextSequence: this._context.getSequence(),
});
}

Expand All @@ -287,8 +309,8 @@ export class ChatLlamaCpp extends SimpleChatModel<LlamaCppCallOptions> {
// This builds a an array of interactions
protected _convertMessagesToInteractions(
messages: BaseMessage[]
): ConversationInteraction[] {
const result: ConversationInteraction[] = [];
): ChatHistoryItem[] {
const result: ChatHistoryItem[] = [];

for (let i = 0; i < messages.length; i += 2) {
if (i + 1 < messages.length) {
Expand All @@ -299,10 +321,13 @@ export class ChatLlamaCpp extends SimpleChatModel<LlamaCppCallOptions> {
"ChatLlamaCpp does not support non-string message content."
);
}
result.push({
prompt,
response,
});
const llamaPrompt: ChatUserMessage = { type: "user", text: prompt };
const llamaResponse: ChatModelResponse = {
type: "model",
response: [response],
};
result.push(llamaPrompt);
result.push(llamaResponse);
}
}

Expand All @@ -313,19 +338,19 @@ export class ChatLlamaCpp extends SimpleChatModel<LlamaCppCallOptions> {
const prompt = input
.map((message) => {
let messageText;
if (message._getType() === "human") {
if (message.getType() === "human") {
messageText = `[INST] ${message.content} [/INST]`;
} else if (message._getType() === "ai") {
} else if (message.getType() === "ai") {
messageText = message.content;
} else if (message._getType() === "system") {
} else if (message.getType() === "system") {
messageText = `<<SYS>> ${message.content} <</SYS>>`;
} else if (ChatMessage.isInstance(message)) {
messageText = `\n\n${message.role[0].toUpperCase()}${message.role.slice(
1
)}: ${message.content}`;
} else {
console.warn(
`Unsupported message type passed to llama_cpp: "${message._getType()}"`
`Unsupported message type passed to llama_cpp: "${message.getType()}"`
);
messageText = "";
}
Expand Down
Loading
Loading