DX-932: Add Custom LLM Interface
fahreddinozcan committed May 20, 2024
1 parent ef23aab commit ebaba3b
Showing 2 changed files with 161 additions and 0 deletions.
116 changes: 116 additions & 0 deletions src/custom-llm.ts
@@ -0,0 +1,116 @@
/* eslint-disable @typescript-eslint/no-magic-numbers */
import { LLM, type BaseLLMParams } from "@langchain/core/language_models/llms";
import { type CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager";
import { GenerationChunk } from "@langchain/core/outputs";
import { type OpenAIClient } from "@langchain/openai";

// eslint-disable-next-line @typescript-eslint/consistent-type-definitions
export interface UpstashLLMParameters extends BaseLLMParams {
/** Upstash API key */
apiKey?: string;

/** Model to use */
model?: string;

/** Sampling temperature to use */
temperature?: number;

/** Minimum number of tokens to generate. */
minTokens?: number;

/** Maximum number of tokens to generate in the completion. */
maxTokens?: number;

/** Generates this many completions server-side and returns the "best". */
bestOf?: number;

/** Penalizes repeated tokens according to frequency. */
frequencyPenalty?: number;

/** Penalizes repeated tokens regardless of frequency. */
presencePenalty?: number;

/** Total probability mass of tokens to consider at each step. */
topP?: number;
}

export default class UpstashLLM extends LLM {
temperature = 0.7;

maxTokens = 256;

topP = 1;

frequencyPenalty = 0;

presencePenalty = 0;

n = 1;

model = "mistralai/Mistral-7B-Instruct-v0.2";

batchSize = 20;

apiKey: string;

constructor(fields: UpstashLLMParameters) {
super(fields);
if (!fields.apiKey) {
throw new Error("apiKey is required");
}

this.topP = fields.topP ?? this.topP;
this.temperature = fields.temperature ?? this.temperature;
this.maxTokens = fields.maxTokens ?? this.maxTokens;
this.frequencyPenalty = fields.frequencyPenalty ?? this.frequencyPenalty;
this.presencePenalty = fields.presencePenalty ?? this.presencePenalty;
this.model = fields.model ?? this.model;
this.apiKey = fields.apiKey;
}

_llmType() {
return "Upstash LLM";
}

async _call(prompt: string) {
const url = `${process.env.UPSTASH_MODELS_BACKEND_URL}/v1/completions`;
const data = {
prompt,
model: this.model,
max_tokens: this.maxTokens,
top_p: this.topP,
temperature: this.temperature,
frequency_penalty: this.frequencyPenalty,
presence_penalty: this.presencePenalty,
};

const response = await fetch(url, {
method: "POST",
body: JSON.stringify(data),
headers: {
Authorization: `Bearer ${this.apiKey}`,
Accept: "application/json",
"Content-Type": "application/json",
},
});

if (!response.ok) {
throw new Error(`Upstash LLM request failed: ${response.status} ${response.statusText}`);
}

const result = (await response.json()) as OpenAIClient.Completions.Completion;

return result.choices[0].text;
}

async *_streamResponseChunks(
prompt: string,
options: this["ParsedCallOptions"],
runManager?: CallbackManagerForLLMRun
): AsyncGenerator<GenerationChunk> {
// Placeholder streaming: fetch the completion in full and emit it as a
// single chunk. True token-level streaming would require parsing the
// backend's server-sent-events stream.
const text = await this._call(prompt);
yield new GenerationChunk({
text,
});

await runManager?.handleLLMNewToken(text);
}
}
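
For context, a minimal usage sketch of the new class (not part of the commit; the model name and env var mirror the tests below, and `invoke` is inherited from the LangChain LLM base class, routing through `_call` above):

import UpstashLLM from "./custom-llm";

// Hypothetical usage; assumes UPSTASH_MODELS_REST_TOKEN and
// UPSTASH_MODELS_BACKEND_URL are set, as in the tests below.
const llm = new UpstashLLM({
  apiKey: process.env.UPSTASH_MODELS_REST_TOKEN!,
  model: "mistralai/Mistral-7B-Instruct-v0.2",
  maxTokens: 128,
  temperature: 0.5,
});

// `invoke` resolves to the completion text returned by `_call`.
const answer = await llm.invoke("What is the capital of France?");
console.log(answer);
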
45 changes: 45 additions & 0 deletions src/rag-chat.test.ts
@@ -10,12 +10,14 @@ import { DEFAULT_REDIS_DB_NAME, DEFAULT_VECTOR_DB_NAME } from "./constants";
import { RatelimitUpstashError } from "./error";
import { PromptTemplate } from "@langchain/core/prompts";
import { sleep } from "bun";
import UpstashLLM from "./custom-llm";

describe("RAG Chat with advance configs and direct instances", async () => {
const vector = new Index({
token: process.env.UPSTASH_VECTOR_REST_TOKEN!,
url: process.env.UPSTASH_VECTOR_REST_URL!,
});

const ragChat = await RAGChat.initialize({
email: process.env.UPSTASH_EMAIL!,
token: process.env.UPSTASH_TOKEN!,
@@ -49,13 +51,15 @@ describe("RAG Chat with advance configs and direct instances", async () => {
"What year was the construction of the Eiffel Tower completed, and what is its height?",
{ stream: false }
)) as AIMessage;

expect(result.content).toContain("330");
});

test("should get result with streaming", async () => {
const result = (await ragChat.chat("Which famous artworks can be found in the Louvre Museum?", {
stream: true,
})) as StreamingTextResponse;

expect(result).toBeTruthy();
});
});
@@ -252,3 +256,44 @@ describe("RAG Chat with custom template", async () => {
{ timeout: 30_000 }
);
});

describe("RAG Chat with custom LLM", async () => {
const ragChat = await RAGChat.initialize({
email: process.env.UPSTASH_EMAIL!,
token: process.env.UPSTASH_TOKEN!,
model: new UpstashLLM({
apiKey: process.env.UPSTASH_MODELS_REST_TOKEN,
model: "mistralai/Mistral-7B-Instruct-v0.2",
}),
redis: new Redis({
token: process.env.UPSTASH_REDIS_REST_TOKEN!,
url: process.env.UPSTASH_REDIS_REST_URL!,
}),
});

beforeAll(async () => {
await ragChat.addContext(
"Paris, the capital of France, is renowned for its iconic landmark, the Eiffel Tower, which was completed in 1889 and stands at 330 meters tall.",
"text"
);
//eslint-disable-next-line @typescript-eslint/no-magic-numbers
await sleep(3000);
});

test("should get result without streaming", async () => {
const result = (await ragChat.chat(
"What year was the construction of the Eiffel Tower completed, and what is its height?",
{ stream: false }
)) as AIMessage;

expect(result.content).toContain("330");
});

test("should get result with streaming", async () => {
const result = (await ragChat.chat("Which famous artworks can be found in the Louvre Museum?", {
stream: true,
})) as StreamingTextResponse;

expect(result).toBeTruthy();
});
});
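
A sketch of how the streamed result could be consumed (not part of this commit; it assumes `StreamingTextResponse` from the Vercel AI SDK, which is a standard web `Response` whose body is a readable byte stream):

// Hypothetical consumer; assumes ragChat is initialized as in the tests above.
const response = (await ragChat.chat("Tell me about the Eiffel Tower.", {
  stream: true,
})) as StreamingTextResponse;

const reader = response.body!.getReader();
const decoder = new TextDecoder();
for (;;) {
  const { done, value } = await reader.read();
  if (done) break;
  process.stdout.write(decoder.decode(value));
}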
