DX-932: Add Custom LLM Interface #2

Merged (2 commits) on May 20, 2024
116 changes: 116 additions & 0 deletions src/custom-llm.ts
@@ -0,0 +1,116 @@
/* eslint-disable @typescript-eslint/no-magic-numbers */
import { LLM, type BaseLLMParams } from "@langchain/core/language_models/llms";
import { type CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager";
import { GenerationChunk } from "@langchain/core/outputs";
import { type OpenAIClient } from "@langchain/openai";

// eslint-disable-next-line @typescript-eslint/consistent-type-definitions
export interface UpstashLLMParameters extends BaseLLMParams {
/** Upstash API key */
apiKey?: string;

/** Model to use */
model?: string;

/** Sampling temperature to use */
temperature?: number;

/** Minimum number of tokens to generate. */
minTokens?: number;

/** Maximum number of tokens to generate in the completion. */
maxTokens?: number;

/** Generates this many completions server-side and returns the "best". */
bestOf?: number;

/** Penalizes repeated tokens according to frequency. */
frequencyPenalty?: number;

/** Penalizes repeated tokens regardless of frequency. */
presencePenalty?: number;

/** Total probability mass of tokens to consider at each step. */
topP?: number;
}

export default class UpstashLLM extends LLM {
temperature = 0.7;

maxTokens = 256;

topP = 1;

frequencyPenalty = 0;

presencePenalty = 0;

n = 1;

model = "mistralai/Mistral-7B-Instruct-v0.2";

batchSize = 20;

apiKey: string;

constructor(fields: UpstashLLMParameters) {
super(fields);
if (!fields.apiKey) {
throw new Error("apiKey is required");
}

this.topP = fields.topP ?? this.topP;
this.temperature = fields.temperature ?? this.temperature;
this.maxTokens = fields.maxTokens ?? this.maxTokens;
this.frequencyPenalty = fields.frequencyPenalty ?? this.frequencyPenalty;
this.presencePenalty = fields.presencePenalty ?? this.presencePenalty;
this.model = fields.model ?? this.model;
this.apiKey = fields.apiKey;
}

_llmType() {
return "Upstash LLM";
}

async _call(prompt: string) {
const url = `${process.env.UPSTASH_MODELS_BACKEND_URL}/v1/completions`;
const data = {
prompt: prompt,
model: this.model,
max_tokens: this.maxTokens,
top_p: this.topP,
temperature: this.temperature,
frequency_penalty: this.frequencyPenalty,
presence_penalty: this.presencePenalty,
};

const response = await fetch(url, {
method: "POST",
body: JSON.stringify(data),
headers: {
Authorization: `Bearer ${this.apiKey}`,
Accept: "application/json",
"Content-Type": "application/json",
},
});

const object = await response.json();

const result = object as OpenAIClient.Completions.Completion;

return result.choices[0].text;
}

// NOTE: placeholder streaming support; this does not call the completions
// endpoint, it simply yields the first `n` characters of the prompt as chunks.
async *_streamResponseChunks(
prompt: string,
options: this["ParsedCallOptions"],
runManager?: CallbackManagerForLLMRun
): AsyncGenerator<GenerationChunk> {
for (const letter of prompt.slice(0, this.n)) {
yield new GenerationChunk({
text: letter,
});

await runManager?.handleLLMNewToken(letter);
}
}
}
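
For context, a minimal standalone usage sketch of the new class, not part of this PR's diff. It mirrors the constructor arguments used in the tests below and assumes UPSTASH_MODELS_REST_TOKEN and UPSTASH_MODELS_BACKEND_URL are set; the prompt text is illustrative only.

import UpstashLLM from "./custom-llm";

// Hypothetical standalone usage of the custom LLM outside of RAGChat.
const llm = new UpstashLLM({
  apiKey: process.env.UPSTASH_MODELS_REST_TOKEN,
  model: "mistralai/Mistral-7B-Instruct-v0.2",
  temperature: 0.5,
  maxTokens: 128,
});

// LLM extends LangChain's Runnable interface, so invoke() resolves to the completion text.
const completion = await llm.invoke("Explain what a vector database is in one sentence.");
console.log(completion);
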
45 changes: 45 additions & 0 deletions src/rag-chat.test.ts
@@ -10,12 +10,14 @@ import { DEFAULT_REDIS_DB_NAME, DEFAULT_VECTOR_DB_NAME } from "./constants";
import { RatelimitUpstashError } from "./error";
import { PromptTemplate } from "@langchain/core/prompts";
import { sleep } from "bun";
import UpstashLLM from "./custom-llm";

describe("RAG Chat with advance configs and direct instances", async () => {
const vector = new Index({
token: process.env.UPSTASH_VECTOR_REST_TOKEN!,
url: process.env.UPSTASH_VECTOR_REST_URL!,
});

const ragChat = await RAGChat.initialize({
email: process.env.UPSTASH_EMAIL!,
token: process.env.UPSTASH_TOKEN!,
@@ -49,13 +51,15 @@ describe("RAG Chat with advance configs and direct instances", async () => {
"What year was the construction of the Eiffel Tower completed, and what is its height?",
{ stream: false }
)) as AIMessage;

expect(result.content).toContain("330");
});

test("should get result with streaming", async () => {
const result = (await ragChat.chat("Which famous artworks can be found in the Louvre Museum?", {
stream: true,
})) as StreamingTextResponse;

expect(result).toBeTruthy();
});
});
@@ -252,3 +256,44 @@ describe("RAG Chat with custom template", async () => {
{ timeout: 30_000 }
);
});

describe("RAG Chat with custom LLM", async () => {
const ragChat = await RAGChat.initialize({
email: process.env.UPSTASH_EMAIL!,
token: process.env.UPSTASH_TOKEN!,
model: new UpstashLLM({
apiKey: process.env.UPSTASH_MODELS_REST_TOKEN,
model: "mistralai/Mistral-7B-Instruct-v0.2",
}),
redis: new Redis({
token: process.env.UPSTASH_REDIS_REST_TOKEN!,
url: process.env.UPSTASH_REDIS_REST_URL!,
}),
});

beforeAll(async () => {
await ragChat.addContext(
"Paris, the capital of France, is renowned for its iconic landmark, the Eiffel Tower, which was completed in 1889 and stands at 330 meters tall.",
"text"
);
//eslint-disable-next-line @typescript-eslint/no-magic-numbers
await sleep(3000);
});

test("should get result without streaming", async () => {
const result = (await ragChat.chat(
"What year was the construction of the Eiffel Tower completed, and what is its height?",
{ stream: false }
)) as AIMessage;

expect(result.content).toContain("330");
});

test("should get result with streaming", async () => {
const result = (await ragChat.chat("Which famous artworks can be found in the Louvre Museum?", {
stream: true,
})) as StreamingTextResponse;

expect(result).toBeTruthy();
});
});
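
Since RAGChat.initialize receives the custom model through the same model option, any class extending LangChain's LLM base should be usable the same way. A deliberately trivial, hypothetical sketch (not part of this PR) of the minimum surface such a class needs:

import { LLM, type BaseLLMParams } from "@langchain/core/language_models/llms";

// Hypothetical stand-in backend: a real implementation would call a completion
// API inside _call, the way UpstashLLM does above.
class EchoLLM extends LLM {
  constructor(fields: BaseLLMParams = {}) {
    super(fields);
  }

  _llmType() {
    return "echo";
  }

  async _call(prompt: string) {
    return `echo: ${prompt}`;
  }
}

// Usable anywhere a LangChain LLM is expected, e.g. as the model option above.
console.log(await new EchoLLM().invoke("hello"));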