Skip to content

Commit

Permalink
Merge pull request #7 from upstash/dx-973-upstash-llm-client-for-rag-chat
Browse files Browse the repository at this point in the history

DX-973: Upstash LLM Client
  • Loading branch information
fahreddinozcan authored Jun 4, 2024
2 parents de612b5 + a11786b commit 262fcb6
Show file tree
Hide file tree
Showing 3 changed files with 207 additions and 0 deletions.
Binary file modified bun.lockb
Binary file not shown.
127 changes: 127 additions & 0 deletions src/upstash-llm-client.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
/* eslint-disable unicorn/numeric-separators-style */
/* eslint-disable @typescript-eslint/no-non-null-assertion */
import type { AIMessage } from "@langchain/core/messages";
import { Redis } from "@upstash/redis";
import { Index } from "@upstash/vector";
import type { StreamingTextResponse } from "ai";
import { afterAll, beforeAll, describe, expect, test } from "bun:test";

import { RAGChat } from "./rag-chat";
import { awaitUntilIndexed } from "./test-utils";
import { UpstashLLMClient } from "./upstash-llm-client";

describe("RAG Chat with Upstash LLM Client", () => {
  const vector = new Index({
    token: process.env.UPSTASH_VECTOR_REST_TOKEN!,
    url: process.env.UPSTASH_VECTOR_REST_URL!,
  });
  afterAll(async () => await vector.reset());

  // Both models run exactly the same scenarios, so register them in a loop
  // instead of duplicating the whole suite per model (the two copies had
  // already drifted: one contained a dead, commented-out afterAll).
  const models = [
    "meta-llama/Meta-Llama-3-8B-Instruct",
    "mistralai/Mistral-7B-Instruct-v0.2",
  ] as const;

  for (const model of models) {
    describe(model, () => {
      const client = new UpstashLLMClient({
        model,
        apiKey: process.env.UPSTASH_LLM_REST_TOKEN!,
        streaming: true,
      });

      const ragChat = new RAGChat({
        model: client,
        vector,
        redis: new Redis({
          token: process.env.UPSTASH_REDIS_REST_TOKEN!,
          url: process.env.UPSTASH_REDIS_REST_URL!,
        }),
      });

      // Seed the vector index with a fact the questions below can retrieve,
      // and wait until indexing completes so the first chat call can see it.
      beforeAll(async () => {
        await ragChat.addContext({
          dataType: "text",
          data: "Paris, the capital of France, is renowned for its iconic landmark, the Eiffel Tower, which was completed in 1889 and stands at 330 meters tall.",
        });
        await awaitUntilIndexed(vector);
      });

      test(
        "should get result without streaming",
        async () => {
          const result = (await ragChat.chat(
            "What year was the construction of the Eiffel Tower completed, and what is its height?",
            { stream: false }
          )) as AIMessage;

          // "330" comes from the seeded context, proving retrieval worked.
          expect(result.content).toContain("330");
        },
        { timeout: 10000 }
      );

      test(
        "should get result with streaming",
        async () => {
          const result = (await ragChat.chat(
            "Which famous artworks can be found in the Louvre Museum?",
            {
              stream: true,
            }
          )) as StreamingTextResponse;

          // Streamed responses can only be sanity-checked for existence here.
          expect(result).toBeTruthy();
        },
        { timeout: 10000 }
      );
    });
  }
});
80 changes: 80 additions & 0 deletions src/upstash-llm-client.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import { ChatOpenAI } from "@langchain/openai";
import { type BaseMessage } from "@langchain/core/messages";
import { type ChatGeneration } from "@langchain/core/outputs";

/** Chat models available on the Upstash-hosted LLM endpoint. */
export type Model = "mistralai/Mistral-7B-Instruct-v0.2" | "meta-llama/Meta-Llama-3-8B-Instruct";

/**
 * Configuration for {@link UpstashLLMClient}. Mirrors the OpenAI-compatible
 * sampling parameters accepted by the Upstash LLM endpoint; optional fields
 * fall back to the client's defaults when omitted.
 */
export type UpstashLLMClientConfig = {
  /** Which Upstash-hosted model to call. */
  model: Model;
  /** Upstash LLM REST token used as the bearer API key. */
  apiKey: string;
  /** Whether responses are streamed token-by-token. */
  streaming: boolean;
  /** Maximum number of tokens to generate. */
  maxTokens?: number;
  /** Sequences at which generation stops. */
  stop?: string[];
  /** Nucleus-sampling probability mass (0–1). */
  topP?: number;
  /** Sampling temperature; higher is more random. */
  temperature?: number;
  /** Penalty for token frequency repetition. */
  frequencyPenalty?: number;
  /** Penalty for token presence repetition. */
  presencePenalty?: number;
  /** Number of completions to generate per prompt. */
  n?: number;
  /** Per-token-id bias applied to the sampling logits. */
  logitBias?: Record<string, number>;
  // NOTE(review): logProbs is accepted here but its forwarding to the
  // underlying client is not visible in this type — confirm it is honored.
  logProbs?: number;
  /** How many top log-probabilities to return per token. */
  topLogprobs?: number;
};

/**
 * ChatOpenAI subclass pointed at the Upstash-hosted, OpenAI-compatible
 * LLM endpoint. Token counting is stubbed to zero because the endpoint
 * provides no tokenizer-backed accounting.
 */
export class UpstashLLMClient extends ChatOpenAI {
  modelName: Model;
  apiKey: string;
  maxTokens?: number;
  stop?: string[];
  temperature: number;
  n: number;
  streaming: boolean;
  topP: number;
  frequencyPenalty: number;
  presencePenalty: number;
  logitBias?: Record<string, number>;
  logProbs?: number;
  topLogprobs?: number;

  constructor(config: UpstashLLMClientConfig) {
    super(
      {
        modelName: config.model,
        apiKey: config.apiKey,
        maxTokens: config.maxTokens,
        streaming: config.streaming,
        topP: config.topP,
        temperature: config.temperature,
        n: config.n,
        frequencyPenalty: config.frequencyPenalty,
        presencePenalty: config.presencePenalty,
        logitBias: config.logitBias,
        topLogprobs: config.topLogprobs,
        stop: config.stop,
      },
      {
        // Route all requests to Upstash's OpenAI-compatible endpoint.
        baseURL: "https://qstash.upstash.io/llm/v1",
      }
    );

    // BUG FIX: class field initializers execute *after* super() returns, so
    // the previous defaults (`temperature = 1`, `topP = 1`, `n = 1`, ...)
    // silently clobbered whatever the super-constructor had set from config —
    // e.g. a caller passing `temperature: 0.2` ended up with 1. Assign every
    // field here, from config, applying the intended defaults only as
    // fallbacks.
    this.modelName = config.model;
    this.apiKey = config.apiKey;
    this.maxTokens = config.maxTokens;
    this.streaming = config.streaming;
    this.temperature = config.temperature ?? 1;
    this.n = config.n ?? 1;
    this.topP = config.topP ?? 1;
    this.frequencyPenalty = config.frequencyPenalty ?? 0;
    this.presencePenalty = config.presencePenalty ?? 0;
    this.logitBias = config.logitBias;
    // BUG FIX: config.logProbs was previously accepted but never stored.
    this.logProbs = config.logProbs;
    this.topLogprobs = config.topLogprobs;
    this.stop = config.stop;

    // The Upstash endpoint exposes no token accounting, so report zero tokens
    // to keep LangChain's usage bookkeeping from throwing.
    // @ts-expect-error This is overriding the method
    this.getNumTokensFromGenerations = (_generations: ChatGeneration[]): Promise<number> =>
      Promise.resolve(0);

    this.getNumTokensFromMessages = (
      _messages: BaseMessage[]
    ): Promise<{ totalCount: number; countPerMessage: number[] }> =>
      Promise.resolve({ totalCount: 0, countPerMessage: [0] });
  }
}

0 comments on commit 262fcb6

Please sign in to comment.