diff --git a/bun.lockb b/bun.lockb
index 2b5b12c..5fa88c7 100755
Binary files a/bun.lockb and b/bun.lockb differ
diff --git a/package.json b/package.json
index 8ee798d..6795117 100644
--- a/package.json
+++ b/package.json
@@ -1,6 +1,6 @@
 {
   "name": "@upstash/rag-chat",
-  "version": "0.0.20-alpha",
+  "version": "0.0.21-alpha",
   "main": "./dist/index.js",
   "module": "./dist/index.mjs",
   "types": "./dist/index.d.ts",
@@ -52,6 +52,7 @@
     "@langchain/core": "^0.1.58",
     "@langchain/openai": "^0.0.28",
     "@upstash/sdk": "0.0.30-alpha",
-    "ai": "^3.0.35"
+    "ai": "^3.1.1",
+    "nanoid": "^5.0.7"
   }
 }
diff --git a/src/client-factory.ts b/src/client-factory.ts
index 33cbcb8..499ee95 100644
--- a/src/client-factory.ts
+++ b/src/client-factory.ts
@@ -1,10 +1,10 @@
+/* eslint-disable unicorn/no-useless-undefined */
 import type { Index, Redis } from "@upstash/sdk";
 import type { Config } from "./config";
 import { Upstash } from "@upstash/sdk";

 import { RedisClient } from "./clients/redis";
 import { VectorClient } from "./clients/vector";
-import { InternalUpstashError } from "./error/internal";

 export type ClientFactoryConfig = Pick<Config, "email" | "token" | "region" | "redis" | "vector">;

 export class ClientFactory {
@@ -44,23 +44,19 @@ export class ClientFactory {
     redis: TInit["redis"] extends true ? Redis : undefined;
     vector: TInit["vector"] extends true ? Index : undefined;
   }> {
-    let redis: Redis | undefined;
-    let vector: Index | undefined;
+    let redisPromise: Promise<Redis | undefined> = Promise.resolve(undefined);
+    let vectorPromise: Promise<Index | undefined> = Promise.resolve(undefined);

     if (options.redis) {
-      redis = await this.createRedisClient();
-      if (!redis) {
-        throw new InternalUpstashError("Couldn't initialize Redis client");
-      }
+      redisPromise = this.createRedisClient();
     }

     if (options.vector) {
-      vector = await this.createVectorClient();
-      if (!vector) {
-        throw new InternalUpstashError("Couldn't initialize Vector client");
-      }
+      vectorPromise = this.createVectorClient();
     }

+    const [redis, vector] = await Promise.all([redisPromise, vectorPromise]);
+
     return { redis, vector } as {
       redis: TInit["redis"] extends true ? Redis : undefined;
       vector: TInit["vector"] extends true ? Index : undefined;
     };
diff --git a/src/clients/vector/index.test.ts b/src/clients/vector/index.test.ts
index baf9e06..837643d 100644
--- a/src/clients/vector/index.test.ts
+++ b/src/clients/vector/index.test.ts
@@ -22,7 +22,7 @@ describe("Vector Client", () => {

       await upstashSDK.deleteVectorIndex(DEFAULT_VECTOR_DB_NAME);
     },
-    { timeout: 20_000 }
+    { timeout: 30_000 }
   );

   test(
@@ -38,7 +38,7 @@ describe("Vector Client", () => {

       await upstashSDK.deleteVectorIndex("test-name");
     },
-    { timeout: 20_000 }
+    { timeout: 30_000 }
   );

   test(
@@ -62,6 +62,6 @@ describe("Vector Client", () => {

       await upstashSDK.deleteVectorIndex(indexName);
     },
-    { timeout: 20_000 }
+    { timeout: 30_000 }
   );
 });
diff --git a/src/clients/vector/index.ts b/src/clients/vector/index.ts
index b76802d..627f517 100644
--- a/src/clients/vector/index.ts
+++ b/src/clients/vector/index.ts
@@ -3,6 +3,7 @@ import { Index } from "@upstash/sdk";

 import type { PreferredRegions } from "../../types";
 import { DEFAULT_VECTOR_DB_NAME } from "../../constants";
+import { delay } from "../../utils";

 export const DEFAULT_VECTOR_CONFIG: CreateIndexPayload = {
   name: DEFAULT_VECTOR_DB_NAME,
@@ -94,6 +95,7 @@ export class VectorClient {
         });
       }
     }
+    await delay();

     if (index?.name) {
       const client = await this.upstashSDK.newVectorClient(index.name);
diff --git a/src/rag-chat-base.ts b/src/rag-chat-base.ts
new file mode 100644
index 0000000..7870114
--- /dev/null
+++ b/src/rag-chat-base.ts
@@ -0,0 +1,96 @@
+import type { Callbacks } from "@langchain/core/callbacks/manager";
+import type { AIMessage, BaseMessage } from "@langchain/core/messages";
+import { RunnableSequence, RunnableWithMessageHistory } from "@langchain/core/runnables";
+import { StreamingTextResponse, LangChainStream } from "ai";
+
+import type { PrepareChatResult, ChatOptions } from "./types";
+import { sanitizeQuestion, formatChatHistory } from "./utils";
+import type { BaseLanguageModelInterface } from "@langchain/core/language_models/base";
+import type { PromptTemplate } from "@langchain/core/prompts";
+import type { HistoryService, RetrievePayload } from "./services";
+import type { RetrievalService } from "./services/retrieval";
+
+type CustomInputValues = { chat_history?: BaseMessage[]; question: string; context: string };
+
+export class RAGChatBase {
+  protected retrievalService: RetrievalService;
+  protected historyService: HistoryService;
+
+  #model: BaseLanguageModelInterface;
+  #template: PromptTemplate;
+
+  constructor(
+    retrievalService: RetrievalService,
+    historyService: HistoryService,
+    config: { model: BaseLanguageModelInterface; template: PromptTemplate }
+  ) {
+    this.retrievalService = retrievalService;
+    this.historyService = historyService;
+
+    this.#model = config.model;
+    this.#template = config.template;
+  }
+
+  protected async prepareChat({
+    question: input,
+    similarityThreshold,
+    topK,
+  }: RetrievePayload): Promise<PrepareChatResult> {
+    const question = sanitizeQuestion(input);
+    const facts = await this.retrievalService.retrieveFromVectorDb({
+      question,
+      similarityThreshold,
+      topK,
+    });
+    return { question, facts };
+  }
+
+  protected streamingChainCall = (
+    chatOptions: ChatOptions,
+    question: string,
+    facts: string
+  ): StreamingTextResponse => {
+    const { stream, handlers } = LangChainStream();
+    void this.chainCall(chatOptions, question, facts, [handlers]);
+    return new StreamingTextResponse(stream, {});
+  };
+
+  protected chainCall(
+    chatOptions: ChatOptions,
+    question: string,
+    facts: string,
+    handlers?: Callbacks
+  ): Promise<AIMessage> {
+    const formattedHistoryChain = RunnableSequence.from<CustomInputValues>([
+      {
+        chat_history: (input) => formatChatHistory(input.chat_history ?? []),
+        question: (input) => input.question,
+        context: (input) => input.context,
+      },
+      this.#template,
+      this.#model,
+    ]);
+
+    const chainWithMessageHistory = new RunnableWithMessageHistory({
+      runnable: formattedHistoryChain,
+      getMessageHistory: (sessionId: string) =>
+        this.historyService.getMessageHistory({
+          sessionId,
+          length: chatOptions.includeHistory,
+        }),
+      inputMessagesKey: "question",
+      historyMessagesKey: "chat_history",
+    });
+
+    return chainWithMessageHistory.invoke(
+      {
+        question,
+        context: facts,
+      },
+      {
+        callbacks: handlers ?? undefined,
+        configurable: { sessionId: chatOptions.sessionId },
+      }
+    ) as Promise<AIMessage>;
+  }
+}
diff --git a/src/rag-chat.test.ts b/src/rag-chat.test.ts
new file mode 100644
index 0000000..f614530
--- /dev/null
+++ b/src/rag-chat.test.ts
@@ -0,0 +1,233 @@
+/* eslint-disable @typescript-eslint/no-non-null-assertion */
+import { ChatOpenAI } from "@langchain/openai";
+import { RAGChat } from "./rag-chat";
+import { afterAll, beforeAll, describe, expect, test } from "bun:test";
+import type { AIMessage } from "@langchain/core/messages";
+import { delay } from "./utils";
+import { Index, Ratelimit, Redis, Upstash } from "@upstash/sdk";
+import type { StreamingTextResponse } from "ai";
+import { DEFAULT_REDIS_DB_NAME, DEFAULT_VECTOR_DB_NAME } from "./constants";
+import { RatelimitUpstashError } from "./error";
+import { PromptTemplate } from "@langchain/core/prompts";
+
+describe("RAG Chat with advance configs and direct instances", async () => {
+  const ragChat = await RAGChat.initialize({
+    email: process.env.UPSTASH_EMAIL!,
+    token: process.env.UPSTASH_TOKEN!,
+    model: new ChatOpenAI({
+      modelName: "gpt-3.5-turbo",
+      streaming: true,
+      verbose: false,
+      temperature: 0,
+      apiKey: process.env.OPENAI_API_KEY,
+    }),
+    vector: new Index({
+      token: process.env.UPSTASH_VECTOR_REST_TOKEN!,
+      url: process.env.UPSTASH_VECTOR_REST_URL!,
+    }),
+    redis: new Redis({
+      token: process.env.UPSTASH_REDIS_REST_TOKEN!,
+      url: process.env.UPSTASH_REDIS_REST_URL!,
+    }),
+  });
+
+  beforeAll(async () => {
+    await ragChat.addContext(
+      "Paris, the capital of France, is renowned for its iconic landmark, the Eiffel Tower, which was completed in 1889 and stands at 330 meters tall."
+    );
+  });
+
+  test("should get result without streaming", async () => {
+    const result = (await ragChat.chat(
+      "What year was the construction of the Eiffel Tower completed, and what is its height?",
+      { stream: false }
+    )) as AIMessage;
+    expect(result.content).toContain("330");
+  });
+
+  test("should get result with streaming", async () => {
+    const result = (await ragChat.chat("Which famous artworks can be found in the Louvre Museum?", {
+      stream: true,
+    })) as StreamingTextResponse;
+    expect(result).toBeTruthy();
+  });
+});
+
+describe("RAG Chat with basic configs", async () => {
+  const ragChat = await RAGChat.initialize({
+    email: process.env.UPSTASH_EMAIL!,
+    token: process.env.UPSTASH_TOKEN!,
+    model: new ChatOpenAI({
+      modelName: "gpt-3.5-turbo",
+      streaming: false,
+      verbose: false,
+      temperature: 0,
+      apiKey: process.env.OPENAI_API_KEY,
+    }),
+    region: "eu-west-1",
+  });
+
+  const upstashSDK = new Upstash({
+    email: process.env.UPSTASH_EMAIL!,
+    token: process.env.UPSTASH_TOKEN!,
+  });
+
+  afterAll(async () => {
+    await upstashSDK.deleteRedisDatabase(DEFAULT_REDIS_DB_NAME);
+    await upstashSDK.deleteVectorIndex(DEFAULT_VECTOR_DB_NAME);
+  });
+
+  test(
+    "should get result without streaming",
+    async () => {
+      await ragChat.addContext(
+        "Paris, the capital of France, is renowned for its iconic landmark, the Eiffel Tower, which was completed in 1889 and stands at 330 meters tall."
+      );
+
+      // Wait for it to be indexed
+      // eslint-disable-next-line @typescript-eslint/no-magic-numbers
+      await delay(3000);
+
+      const result = (await ragChat.chat(
+        "What year was the construction of the Eiffel Tower completed, and what is its height?",
+        { stream: false }
+      )) as AIMessage;
+
+      expect(result.content).toContain("330");
+    },
+    { timeout: 30_000 }
+  );
+});
+
+describe("RAG Chat with ratelimit", async () => {
+  const redis = new Redis({
+    token: process.env.UPSTASH_REDIS_REST_TOKEN!,
+    url: process.env.UPSTASH_REDIS_REST_URL!,
+  });
+  const ragChat = await RAGChat.initialize({
+    email: process.env.UPSTASH_EMAIL!,
+    token: process.env.UPSTASH_TOKEN!,
+    model: new ChatOpenAI({
+      modelName: "gpt-3.5-turbo",
+      streaming: false,
+      verbose: false,
+      temperature: 0,
+      apiKey: process.env.OPENAI_API_KEY,
+    }),
+    vector: new Index({
+      token: process.env.UPSTASH_VECTOR_REST_TOKEN!,
+      url: process.env.UPSTASH_VECTOR_REST_URL!,
+    }),
+    redis,
+    ratelimit: new Ratelimit({
+      redis,
+      limiter: Ratelimit.tokenBucket(1, "1d", 1),
+      prefix: "@upstash/rag-chat-ratelimit",
+    }),
+  });
+
+  afterAll(async () => {
+    await redis.flushdb();
+  });
+
+  test("should throw ratelimit error", async () => {
+    await ragChat.chat(
+      "What year was the construction of the Eiffel Tower completed, and what is its height?",
+      { stream: false }
+    );
+
+    const throwable = async () => {
+      await ragChat.chat("You shall not pass", { stream: false });
+    };
+
+    expect(throwable).toThrowError(RatelimitUpstashError);
+  });
+});
+
+describe("RAG Chat with instance names", async () => {
+  const ragChat = await RAGChat.initialize({
+    email: process.env.UPSTASH_EMAIL!,
+    token: process.env.UPSTASH_TOKEN!,
+    redis: "my-fancy-redis-db",
+    vector: "my-fancy-vector-db",
+    model: new ChatOpenAI({
+      modelName: "gpt-3.5-turbo",
+      streaming: false,
+      verbose: false,
+      temperature: 0,
+      apiKey: process.env.OPENAI_API_KEY,
+    }),
+    region: "eu-west-1",
+  });
+
+  afterAll(async () => {
+    const upstashSDK = new Upstash({
+      email: process.env.UPSTASH_EMAIL!,
+      token: process.env.UPSTASH_TOKEN!,
+    });
+    await upstashSDK.deleteRedisDatabase("my-fancy-redis-db");
+    await upstashSDK.deleteVectorIndex("my-fancy-vector-db");
+  });
+
+  test(
+    "should get result without streaming",
+    async () => {
+      await ragChat.addContext(
+        "Paris, the capital of France, is renowned for its iconic landmark, the Eiffel Tower, which was completed in 1889 and stands at 330 meters tall."
+      );
+
+      // Wait for it to be indexed
+      // eslint-disable-next-line @typescript-eslint/no-magic-numbers
+      await delay(3000);
+
+      const result = (await ragChat.chat(
+        "What year was the construction of the Eiffel Tower completed, and what is its height?",
+        { stream: false }
+      )) as AIMessage;
+
+      expect(result.content).toContain("330");
+    },
+    { timeout: 30_000 }
+  );
+});
+
+describe("RAG Chat with custom template", async () => {
+  const ragChat = await RAGChat.initialize({
+    email: process.env.UPSTASH_EMAIL!,
+    token: process.env.UPSTASH_TOKEN!,
+    vector: new Index({
+      token: process.env.UPSTASH_VECTOR_REST_TOKEN!,
+      url: process.env.UPSTASH_VECTOR_REST_URL!,
+    }),
+    redis: new Redis({
+      token: process.env.UPSTASH_REDIS_REST_TOKEN!,
+      url: process.env.UPSTASH_REDIS_REST_URL!,
+    }),
+    template: PromptTemplate.fromTemplate("Just say `I'm a cookie monster`. Nothing else."),
+    model: new ChatOpenAI({
+      modelName: "gpt-3.5-turbo",
+      streaming: false,
+      verbose: false,
+      temperature: 0,
+      apiKey: process.env.OPENAI_API_KEY,
+    }),
+  });
+
+  test(
+    "should get result without streaming",
+    async () => {
+      await ragChat.addContext("Ankara is the capital of Turkiye.");
+
+      // Wait for it to be indexed
+      // eslint-disable-next-line @typescript-eslint/no-magic-numbers
+      await delay(3000);
+
+      const result = (await ragChat.chat("Where is the capital of Turkiye?", {
+        stream: false,
+      })) as AIMessage;
+
+      expect(result.content).toContain("I'm a cookie monster");
+    },
+    { timeout: 30_000 }
+  );
+});
diff --git a/src/rag-chat.ts b/src/rag-chat.ts
index dd5eb76..893fc8a 100644
--- a/src/rag-chat.ts
+++ b/src/rag-chat.ts
@@ -1,34 +1,25 @@
-import type { Callbacks } from "@langchain/core/callbacks/manager";
-import type { BaseMessage } from "@langchain/core/messages";
-import { RunnableSequence, RunnableWithMessageHistory } from "@langchain/core/runnables";
-import { LangChainStream, StreamingTextResponse } from "ai";
 import type { BaseLanguageModelInterface } from "@langchain/core/language_models/base";
+import type { AIMessage } from "@langchain/core/messages";
 import type { PromptTemplate } from "@langchain/core/prompts";
+import type { StreamingTextResponse } from "ai";

 import { HistoryService } from "./services/history";
-import { RetrievalService } from "./services/retrieval";
 import { RateLimitService } from "./services/ratelimit";
-import type { RetrievePayload } from "./services/retrieval";
+import { RetrievalService } from "./services/retrieval";

 import { QA_TEMPLATE } from "./prompts";

 import { UpstashModelError } from "./error/model";
 import { RatelimitUpstashError } from "./error/ratelimit";

-import type { ChatOptions, PrepareChatResult, RAGChatConfig } from "./types";
 import { ClientFactory } from "./client-factory";
 import { Config } from "./config";
-import { appendDefaultsIfNeeded, formatChatHistory, sanitizeQuestion } from "./utils";
-
-type CustomInputValues = { chat_history?: BaseMessage[]; question: string; context: string };
+import { RAGChatBase } from "./rag-chat-base";
+import type { ChatOptions, RAGChatConfig } from "./types";
+import { appendDefaultsIfNeeded } from "./utils";

-export class RAGChat {
-  private retrievalService: RetrievalService;
-  private historyService: HistoryService;
-  private ratelimitService: RateLimitService;
-
-  private model: BaseLanguageModelInterface;
-  private template: PromptTemplate;
+export class RAGChat extends RAGChatBase {
+  #ratelimitService: RateLimitService;

   constructor(
     retrievalService: RetrievalService,
@@ -36,37 +27,16 @@ export class RAGChat {
     ratelimitService: RateLimitService,
     config: { model: BaseLanguageModelInterface; template: PromptTemplate }
   ) {
-    this.retrievalService = retrievalService;
-    this.historyService = historyService;
-    this.ratelimitService = ratelimitService;
-
-    this.model = config.model;
-    this.template = config.template;
-  }
-
-  private async prepareChat({
-    question: input,
-    similarityThreshold,
-    topK,
-  }: RetrievePayload): Promise<PrepareChatResult> {
-    const question = sanitizeQuestion(input);
-    const facts = await this.retrievalService.retrieveFromVectorDb({
-      question,
-      similarityThreshold,
-      topK,
-    });
-    return { question, facts };
+    super(retrievalService, historyService, config);
+    this.#ratelimitService = ratelimitService;
   }

-  async chat(
-    input: string,
-    options: ChatOptions
-  ): Promise<StreamingTextResponse | Record<string, unknown>> {
+  async chat(input: string, options: ChatOptions): Promise<StreamingTextResponse | AIMessage> {
     // Adds chat session id and ratelimit session id if not provided.
     const options_ = appendDefaultsIfNeeded(options);

     //Checks ratelimit of the user. If not enabled `success` will be always true.
-    const { success, resetTime } = await this.ratelimitService.checkLimit(
+    const { success, resetTime } = await this.#ratelimitService.checkLimit(
       options_.ratelimitSessionId
     );

@@ -89,53 +59,11 @@ export class RAGChat {
       : this.chainCall(options_, question, facts);
   }

-  private streamingChainCall = (
-    chatOptions: ChatOptions,
-    question: string,
-    facts: string
-  ): StreamingTextResponse => {
-    const { stream, handlers } = LangChainStream();
-    void this.chainCall(chatOptions, question, facts, [handlers]);
-    return new StreamingTextResponse(stream, {});
-  };
-
-  private chainCall(
-    chatOptions: ChatOptions,
-    question: string,
-    facts: string,
-    handlers?: Callbacks
-  ) {
-    const formattedHistoryChain = RunnableSequence.from<CustomInputValues>([
-      {
-        chat_history: (input) => formatChatHistory(input.chat_history ?? []),
-        question: (input) => input.question,
-        context: (input) => input.context,
-      },
-      this.template,
-      this.model,
-    ]);
-
-    const chainWithMessageHistory = new RunnableWithMessageHistory({
-      runnable: formattedHistoryChain,
-      getMessageHistory: (sessionId: string) =>
-        this.historyService.getMessageHistory({
-          sessionId,
-          length: chatOptions.includeHistory,
-        }),
-      inputMessagesKey: "question",
-      historyMessagesKey: "chat_history",
-    });
-
-    return chainWithMessageHistory.invoke(
-      {
-        question,
-        context: facts,
-      },
-      {
-        callbacks: handlers ?? undefined,
-        configurable: { sessionId: chatOptions.sessionId },
-      }
-    );
+  /** Context can be either plain text or embeddings */
+  async addContext(context: string | number[]) {
+    const retrievalService = await this.retrievalService.addEmbeddingOrTextToVectorDb(context);
+    if (retrievalService === "Success") return "OK";
+    return "NOT-OK";
   }

   /**
diff --git a/src/services/retrieval.ts b/src/services/retrieval.ts
index 063daee..d6b860e 100644
--- a/src/services/retrieval.ts
+++ b/src/services/retrieval.ts
@@ -3,6 +3,7 @@ import { formatFacts } from "../utils";
 import type { RAGChatConfig } from "../types";
 import { ClientFactory } from "../client-factory";
 import { Config } from "../config";
+import { nanoid } from "nanoid";

 const SIMILARITY_THRESHOLD = 0.5;
 const TOP_K = 5;
@@ -51,6 +52,13 @@ export class RetrievalService {
     return formatFacts(facts);
   }

+  async addEmbeddingOrTextToVectorDb(input: string | number[]) {
+    if (typeof input === "string") {
+      return this.index.upsert({ data: input, id: nanoid(), metadata: { value: input } });
+    }
+    return this.index.upsert({ vector: input, id: nanoid(), metadata: { value: input } });
+  }
+
  public static async init(config: RetrievalInit) {
    const clientFactory = new ClientFactory(
      new Config(config.email, config.token, {
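
A note on the `ClientFactory.init` change: the two clients are now created concurrently rather than one `await` after the other, and a branch that isn't requested stays a pre-resolved `Promise.resolve(undefined)` so `Promise.all` never waits on it (hence the `unicorn/no-useless-undefined` disable). A minimal sketch of the pattern, with hypothetical `connectRedis`/`connectVector` stand-ins for `createRedisClient`/`createVectorClient`:

```ts
type FakeRedis = { kind: "redis" };
type FakeIndex = { kind: "vector" };

async function connectRedis(): Promise<FakeRedis> {
  return { kind: "redis" };
}

async function connectVector(): Promise<FakeIndex> {
  return { kind: "vector" };
}

async function init(options: { redis?: boolean; vector?: boolean }) {
  // Start from already-resolved promises so Promise.all never blocks on a
  // client that was not requested.
  let redisPromise: Promise<FakeRedis | undefined> = Promise.resolve(undefined);
  let vectorPromise: Promise<FakeIndex | undefined> = Promise.resolve(undefined);

  if (options.redis) redisPromise = connectRedis();
  if (options.vector) vectorPromise = connectVector();

  // Both clients are created concurrently instead of sequentially.
  const [redis, vector] = await Promise.all([redisPromise, vectorPromise]);
  return { redis, vector };
}

// Usage: only the vector client is created; `redis` resolves to undefined.
void init({ vector: true }).then(({ redis, vector }) => console.log(redis, vector));
```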
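The new `addEmbeddingOrTextToVectorDb` helper branches on the input type: a plain string is upserted under the `data` field so the index embeds it, while a `number[]` is stored as a raw `vector`; both get a fresh `nanoid()` id and the input echoed into `metadata`. A sketch against a hypothetical `VectorIndexLike` interface that captures just the upsert shape the diff relies on:

```ts
import { nanoid } from "nanoid";

// Hypothetical stand-in for the subset of the vector index API used above.
// The diff's addContext compares the result to "Success", so upsert is
// modeled here as resolving to a status string.
interface VectorIndexLike {
  upsert(payload: {
    id: string;
    data?: string;
    vector?: number[];
    metadata?: Record<string, unknown>;
  }): Promise<string>;
}

// Mirrors addEmbeddingOrTextToVectorDb: text goes through the embedding
// path (`data`), an embedding array is stored directly (`vector`).
async function addEmbeddingOrText(index: VectorIndexLike, input: string | number[]) {
  if (typeof input === "string") {
    return index.upsert({ data: input, id: nanoid(), metadata: { value: input } });
  }
  return index.upsert({ vector: input, id: nanoid(), metadata: { value: input } });
}
```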
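End to end, the new surface is exercised by the test suite above; a condensed usage sketch (assuming the same environment variables the tests read, and a module context with top-level await):

```ts
import { ChatOpenAI } from "@langchain/openai";
import { RAGChat } from "@upstash/rag-chat";

const ragChat = await RAGChat.initialize({
  email: process.env.UPSTASH_EMAIL!,
  token: process.env.UPSTASH_TOKEN!,
  model: new ChatOpenAI({
    modelName: "gpt-3.5-turbo",
    temperature: 0,
    apiKey: process.env.OPENAI_API_KEY,
  }),
  region: "eu-west-1",
});

// New in this release: seed the vector index directly from the client.
// Returns "OK" on success, "NOT-OK" otherwise.
await ragChat.addContext(
  "Paris, the capital of France, is renowned for its iconic landmark, the Eiffel Tower."
);

// `stream: false` resolves to an AIMessage; `stream: true` resolves to a
// StreamingTextResponse suitable for returning from a route handler.
const answer = await ragChat.chat("What is the capital of France?", { stream: false });
```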