diff --git a/src/custom-llm.ts b/src/custom-llm.ts
new file mode 100644
index 0000000..6266fc9
--- /dev/null
+++ b/src/custom-llm.ts
@@ -0,0 +1,118 @@
+/* eslint-disable @typescript-eslint/no-magic-numbers */
+import { LLM, type BaseLLMParams } from "@langchain/core/language_models/llms";
+import { type CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager";
+import { GenerationChunk } from "@langchain/core/outputs";
+import { type OpenAIClient } from "@langchain/openai";
+
+// eslint-disable-next-line @typescript-eslint/consistent-type-definitions
+export interface UpstashLLMParameters extends BaseLLMParams {
+  /** Upstash API key */
+  apiKey?: string;
+
+  /** Model to use */
+  model?: string;
+
+  /** Sampling temperature to use */
+  temperature?: number;
+
+  /** Minimum number of tokens to generate. */
+  minTokens?: number;
+
+  /** Maximum number of tokens to generate in the completion. */
+  maxTokens?: number;
+
+  /** Generates this many completions server-side and returns the "best". */
+  bestOf?: number;
+
+  /** Penalizes repeated tokens according to frequency. */
+  frequencyPenalty?: number;
+
+  /** Penalizes repeated tokens regardless of frequency. */
+  presencePenalty?: number;
+
+  /** Total probability mass of tokens to consider at each step. */
+  topP?: number;
+}
+
+export default class UpstashLLM extends LLM {
+  temperature = 0.7;
+
+  maxTokens = 256;
+
+  topP = 1;
+
+  frequencyPenalty = 0;
+
+  presencePenalty = 0;
+
+  n = 1;
+
+  model = "mistralai/Mistral-7B-Instruct-v0.2";
+
+  batchSize = 20;
+
+  apiKey: string;
+
+  constructor(fields: UpstashLLMParameters) {
+    super(fields);
+    if (!fields.apiKey) {
+      throw new Error("apiKey is required");
+    }
+
+    this.topP = fields.topP ?? this.topP;
+    this.temperature = fields.temperature ?? this.temperature;
+    this.maxTokens = fields.maxTokens ?? this.maxTokens;
+    this.frequencyPenalty = fields.frequencyPenalty ?? this.frequencyPenalty;
+    this.presencePenalty = fields.presencePenalty ?? this.presencePenalty;
+    this.model = fields.model ?? this.model;
+    this.apiKey = fields.apiKey;
+  }
+
+  _llmType() {
+    return "Upstash LLM";
+  }
+
+  async _call(prompt: string) {
+    const url = `${process.env.UPSTASH_MODELS_BACKEND_URL}/v1/completions`;
+    const data = {
+      prompt: prompt,
+      model: this.model,
+      max_tokens: this.maxTokens,
+      top_p: this.topP,
+      temperature: this.temperature,
+      frequency_penalty: this.frequencyPenalty,
+      presence_penalty: this.presencePenalty,
+    };
+
+    const response = await fetch(url, {
+      method: "POST",
+      body: JSON.stringify(data),
+      headers: {
+        Authorization: `Bearer ${this.apiKey}`,
+        Accept: "application/json",
+        "Content-Type": "application/json",
+      },
+    });
+
+    const object = await response.json();
+
+    const result = object as OpenAIClient.Completions.Completion;
+
+    return result.choices[0].text;
+  }
+
+  async *_streamResponseChunks(
+    prompt: string,
+    options: this["ParsedCallOptions"],
+    runManager?: CallbackManagerForLLMRun
+  ): AsyncGenerator {
+    // Placeholder: echoes the first `n` characters of the prompt instead of streaming model output.
+    for (const letter of prompt.slice(0, this.n)) {
+      yield new GenerationChunk({
+        text: letter,
+      });
+
+      await runManager?.handleLLMNewToken(letter);
+    }
+  }
+}
diff --git a/src/rag-chat.test.ts b/src/rag-chat.test.ts
index e23cb3e..85b1c96 100644
--- a/src/rag-chat.test.ts
+++ b/src/rag-chat.test.ts
@@ -10,12 +10,14 @@ import { DEFAULT_REDIS_DB_NAME, DEFAULT_VECTOR_DB_NAME } from "./constants";
 import { RatelimitUpstashError } from "./error";
 import { PromptTemplate } from "@langchain/core/prompts";
 import { sleep } from "bun";
+import UpstashLLM from "./custom-llm";
 
 describe("RAG Chat with advance configs and direct instances", async () => {
   const vector = new Index({
     token: process.env.UPSTASH_VECTOR_REST_TOKEN!,
     url: process.env.UPSTASH_VECTOR_REST_URL!,
   });
+
   const ragChat = await RAGChat.initialize({
     email: process.env.UPSTASH_EMAIL!,
     token: process.env.UPSTASH_TOKEN!,
@@ -49,6 +51,7 @@ describe("RAG Chat with advance configs and direct instances", async () => {
       "What year was the construction of the Eiffel Tower completed, and what is its height?",
       { stream: false }
     )) as AIMessage;
+
     expect(result.content).toContain("330");
   });
 
@@ -56,6 +59,7 @@ describe("RAG Chat with advance configs and direct instances", async () => {
     const result = (await ragChat.chat("Which famous artworks can be found in the Louvre Museum?", {
       stream: true,
     })) as StreamingTextResponse;
+
     expect(result).toBeTruthy();
   });
 });
@@ -252,3 +256,44 @@ describe("RAG Chat with custom template", async () => {
     { timeout: 30_000 }
   );
 });
+
+describe("RAG Chat with custom LLM", async () => {
+  const ragChat = await RAGChat.initialize({
+    email: process.env.UPSTASH_EMAIL!,
+    token: process.env.UPSTASH_TOKEN!,
+    model: new UpstashLLM({
+      apiKey: process.env.UPSTASH_MODELS_REST_TOKEN,
+      model: "mistralai/Mistral-7B-Instruct-v0.2",
+    }),
+    redis: new Redis({
+      token: process.env.UPSTASH_REDIS_REST_TOKEN!,
+      url: process.env.UPSTASH_REDIS_REST_URL!,
+    }),
+  });
+
+  beforeAll(async () => {
+    await ragChat.addContext(
+      "Paris, the capital of France, is renowned for its iconic landmark, the Eiffel Tower, which was completed in 1889 and stands at 330 meters tall.",
+      "text"
+    );
+    //eslint-disable-next-line @typescript-eslint/no-magic-numbers
+    await sleep(3000);
+  });
+
+  test("should get result without streaming", async () => {
+    const result = (await ragChat.chat(
+      "What year was the construction of the Eiffel Tower completed, and what is its height?",
+      { stream: false }
+    )) as AIMessage;
+
+    expect(result.content).toContain("330");
+  });
+
+  test("should get result with streaming", async () => {
+    const result = (await ragChat.chat("Which famous artworks can be found in the Louvre Museum?", {
+      stream: true,
+    })) as StreamingTextResponse;
+
+    expect(result).toBeTruthy();
+  });
+});
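
A note for reviewers: outside the test suite, the new class can be exercised directly. Below is a minimal usage sketch (not part of the diff); it assumes `UPSTASH_MODELS_REST_TOKEN` and `UPSTASH_MODELS_BACKEND_URL` are set and that the backend exposes an OpenAI-compatible `/v1/completions` endpoint, as `_call` implies.

```ts
import UpstashLLM from "./custom-llm";

// Same options the new test uses; remaining BaseLLMParams (callbacks,
// tags, ...) are forwarded to the LangChain base class by the constructor.
const llm = new UpstashLLM({
  apiKey: process.env.UPSTASH_MODELS_REST_TOKEN,
  model: "mistralai/Mistral-7B-Instruct-v0.2",
  temperature: 0.5,
  maxTokens: 128,
});

// LLM subclasses implement LangChain's Runnable interface, so the model
// can be invoked directly with a prompt string.
const answer = await llm.invoke("How tall is the Eiffel Tower?");
console.log(answer);
```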
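
Also worth flagging: `_streamResponseChunks` is currently a placeholder that echoes the first `n` characters of the prompt, so `{ stream: true }` does not yet stream model output. If the backend accepts `stream: true` and replies with Server-Sent Events in the OpenAI style (an assumption this diff does not confirm), a drop-in replacement for the method on `UpstashLLM` could look roughly like this sketch:

```ts
async *_streamResponseChunks(
  prompt: string,
  options: this["ParsedCallOptions"],
  runManager?: CallbackManagerForLLMRun
): AsyncGenerator<GenerationChunk> {
  const response = await fetch(`${process.env.UPSTASH_MODELS_BACKEND_URL}/v1/completions`, {
    method: "POST",
    headers: {
      Authorization: `Bearer ${this.apiKey}`,
      "Content-Type": "application/json",
    },
    body: JSON.stringify({
      prompt,
      model: this.model,
      max_tokens: this.maxTokens,
      temperature: this.temperature,
      stream: true, // assumed to switch the endpoint to SSE output
    }),
  });

  // Consume the SSE stream incrementally: events are separated by blank
  // lines, and each event body is `data: <json>` or the final `data: [DONE]`.
  const reader = response.body!.getReader();
  const decoder = new TextDecoder();
  let buffer = "";
  for (;;) {
    const { done, value } = await reader.read();
    if (done) break;
    buffer += decoder.decode(value, { stream: true });
    const events = buffer.split("\n\n");
    buffer = events.pop() ?? ""; // keep any incomplete event for the next read
    for (const event of events) {
      const payload = event.replace(/^data:\s*/, "").trim();
      if (!payload || payload === "[DONE]") continue;
      const text = JSON.parse(payload).choices?.[0]?.text ?? "";
      yield new GenerationChunk({ text });
      await runManager?.handleLLMNewToken(text);
    }
  }
}
```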