diff --git a/bun.lockb b/bun.lockb index 85c537f..3f1cc92 100755 Binary files a/bun.lockb and b/bun.lockb differ diff --git a/data/the_wonderful_wizard_of_oz.pdf b/data/the_wonderful_wizard_of_oz.pdf new file mode 100644 index 0000000..359b531 Binary files /dev/null and b/data/the_wonderful_wizard_of_oz.pdf differ diff --git a/package.json b/package.json index e1d5b5b..a146c91 100644 --- a/package.json +++ b/package.json @@ -47,13 +47,15 @@ "vitest": "latest" }, "dependencies": { - "@langchain/community": "^0.0.50", + "@langchain/community": "^0.2.1", "@langchain/core": "^0.1.58", "@langchain/openai": "^0.0.28", "@upstash/ratelimit": "^1.1.3", "@upstash/redis": "^1.31.1", "@upstash/vector": "^1.1.1", "ai": "^3.1.1", - "nanoid": "^5.0.7" + "langchain": "^0.2.0", + "nanoid": "^5.0.7", + "pdf-parse": "^1.1.1" } } diff --git a/src/rag-chat.test.ts b/src/rag-chat.test.ts index 5603970..5459068 100644 --- a/src/rag-chat.test.ts +++ b/src/rag-chat.test.ts @@ -1,16 +1,15 @@ /* eslint-disable @typescript-eslint/no-non-null-assertion */ import type { AIMessage } from "@langchain/core/messages"; +import { PromptTemplate } from "@langchain/core/prompts"; import { ChatOpenAI } from "@langchain/openai"; +import { Ratelimit } from "@upstash/ratelimit"; +import { Redis } from "@upstash/redis"; +import { Index } from "@upstash/vector"; import type { StreamingTextResponse } from "ai"; -import { sleep } from "bun"; import { afterAll, beforeAll, describe, expect, test } from "bun:test"; -import { RAGChat } from "./rag-chat"; -import { Index } from "@upstash/vector"; -import { Redis } from "@upstash/redis"; -import { Ratelimit } from "@upstash/ratelimit"; import { RatelimitUpstashError } from "./error/ratelimit"; -import { PromptTemplate } from "@langchain/core/prompts"; -import { delay } from "./utils"; +import { RAGChat } from "./rag-chat"; +import { awaitUntilIndexed } from "./test-utils"; describe("RAG Chat with advance configs and direct instances", () => { const vector = new 
Index({ @@ -34,12 +33,11 @@ describe("RAG Chat with advance configs and direct instances", () => { }); beforeAll(async () => { - await ragChat.addContext( - "Paris, the capital of France, is renowned for its iconic landmark, the Eiffel Tower, which was completed in 1889 and stands at 330 meters tall.", - "text" - ); - //eslint-disable-next-line @typescript-eslint/no-magic-numbers - await sleep(3000); + await ragChat.addContext({ + dataType: "text", + data: "Paris, the capital of France, is renowned for its iconic landmark, the Eiffel Tower, which was completed in 1889 and stands at 330 meters tall.", + }); + await awaitUntilIndexed(vector); }); afterAll(async () => await vector.reset()); @@ -98,11 +96,13 @@ describe("RAG Chat with ratelimit", () => { "should throw ratelimit error", async () => { await ragChat.addContext( - "Paris, the capital of France, is renowned for its iconic landmark, the Eiffel Tower, which was completed in 1889 and stands at 330 meters tall.", - "text" + { + dataType: "text", + data: "Paris, the capital of France, is renowned for its iconic landmark, the Eiffel Tower, which was completed in 1889 and stands at 330 meters tall.", + }, + { metadataKey: "text" } ); - //eslint-disable-next-line @typescript-eslint/no-magic-numbers - await sleep(3000); + await awaitUntilIndexed(vector); await ragChat.chat( "What year was the construction of the Eiffel Tower completed, and what is its height?", @@ -120,11 +120,12 @@ describe("RAG Chat with ratelimit", () => { }); describe("RAG Chat with custom template", () => { + const vector = new Index({ + token: process.env.UPSTASH_VECTOR_REST_TOKEN!, + url: process.env.UPSTASH_VECTOR_REST_URL!, + }); const ragChat = new RAGChat({ - vector: new Index({ - token: process.env.UPSTASH_VECTOR_REST_TOKEN!, - url: process.env.UPSTASH_VECTOR_REST_URL!, - }), + vector, redis: new Redis({ token: process.env.UPSTASH_REDIS_REST_TOKEN!, url: process.env.UPSTASH_REDIS_REST_URL!, @@ -142,11 +143,14 @@ describe("RAG Chat with 
custom template", () => { test( "should get result without streaming", async () => { - await ragChat.addContext("Ankara is the capital of Turkiye."); + await ragChat.addContext( + { dataType: "text", data: "Ankara is the capital of Turkiye." }, + { metadataKey: "text" } + ); // Wait for it to be indexed // eslint-disable-next-line @typescript-eslint/no-magic-numbers - await delay(3000); + await awaitUntilIndexed(vector); const result = (await ragChat.chat("Where is the capital of Turkiye?", { stream: false, @@ -157,3 +161,46 @@ describe("RAG Chat with custom template", () => { { timeout: 30_000 } ); }); + +describe("RAG Chat addContext using PDF", () => { + const vector = new Index({ + token: process.env.UPSTASH_VECTOR_REST_TOKEN!, + url: process.env.UPSTASH_VECTOR_REST_URL!, + }); + const redis = new Redis({ + token: process.env.UPSTASH_REDIS_REST_TOKEN!, + url: process.env.UPSTASH_REDIS_REST_URL!, + }); + const ragChat = new RAGChat({ + redis, + vector, + model: new ChatOpenAI({ + modelName: "gpt-3.5-turbo", + streaming: false, + verbose: false, + temperature: 0, + apiKey: process.env.OPENAI_API_KEY, + }), + }); + + afterAll(async () => { + await vector.reset(); + }); + + test( + "should be able to successfully query embedded book", + async () => { + await ragChat.addContext({ + dataType: "pdf", + fileSource: "./data/the_wonderful_wizard_of_oz.pdf", + opts: { chunkSize: 500, chunkOverlap: 50 }, + }); + await awaitUntilIndexed(vector); + const result = (await ragChat.chat("Whats the author of The Wonderful Wizard of Oz?", { + stream: false, + })) as AIMessage; + expect(result.content).toContain("Frank"); + }, + { timeout: 30_000 } + ); +}); diff --git a/src/rag-chat.ts b/src/rag-chat.ts index 637cf40..95e7cc4 100644 --- a/src/rag-chat.ts +++ b/src/rag-chat.ts @@ -8,7 +8,7 @@ import { RatelimitUpstashError } from "./error/ratelimit"; import type { Config } from "./config"; import { RAGChatBase } from "./rag-chat-base"; -import type { AddContextPayload } from 
"./services"; +import type { AddContextOptions, AddContextPayload } from "./services"; import { HistoryService, RetrievalService } from "./services"; import { RateLimitService } from "./services/ratelimit"; import type { ChatOptions } from "./types"; @@ -34,8 +34,18 @@ export class RAGChat extends RAGChatBase { this.#ratelimitService = ratelimitService; } + /** + * A method that allows you to chat LLM using Vector DB as your knowledge store and Redis - optional - as a chat history. + * + * @example + * ```typescript + * await ragChat.chat("Where is the capital of Turkiye?", { + * stream: false, + * }) + * ``` + */ async chat(input: string, options: ChatOptions): Promise { - // Adds chat session id and ratelimit session id if not provided. + // Adds all the necessary default options that users can skip in the options parameter above. const options_ = appendDefaultsIfNeeded(options); // Checks ratelimit of the user. If not enabled `success` will be always true. @@ -65,12 +75,26 @@ export class RAGChat extends RAGChatBase { : this.chainCall(options_, question, facts); } - /** Context can be either plain text or embeddings */ - async addContext(context: AddContextPayload[] | string, metadataKey = "text") { - const retrievalServiceStatus = await this.retrievalService.addEmbeddingOrTextToVectorDb( - context, - metadataKey - ); + /** + * A method that allows you to add various data types into a vector database. + * It supports plain text, embeddings, PDF, and CSV. Additionally, it handles text-splitting for CSV and PDF. 
+ * + * @example + * ```typescript + * await ragChat.addContext({ + * dataType: "pdf", + * fileSource: "./data/the_wonderful_wizard_of_oz.pdf", + * opts: { chunkSize: 500, chunkOverlap: 50 }, + * }); + * // OR + * await ragChat.addContext({ + * dataType: "text", + * data: "Paris, the capital of France, is renowned for its iconic landmark, the Eiffel Tower, which was completed in 1889 and stands at 330 meters tall.", + * }); + * ``` + */ + async addContext(context: AddContextPayload, options?: AddContextOptions) { + const retrievalServiceStatus = await this.retrievalService.addDataToVectorDb(context, options); + return retrievalServiceStatus === "Success" ? "OK" : "NOT-OK"; } diff --git a/src/services/retrieval.ts b/src/services/retrieval.ts index d0a0224..d4d473c 100644 --- a/src/services/retrieval.ts +++ b/src/services/retrieval.ts @@ -2,8 +2,26 @@ import { nanoid } from "nanoid"; import { DEFAULT_METADATA_KEY, DEFAULT_SIMILARITY_THRESHOLD, DEFAULT_TOP_K } from "../constants"; import { formatFacts } from "../utils"; import type { Index } from "@upstash/vector"; +import { PDFLoader } from "@langchain/community/document_loaders/fs/pdf"; +import type { RecursiveCharacterTextSplitterParams } from "langchain/text_splitter"; +import { RecursiveCharacterTextSplitter } from "langchain/text_splitter"; -export type AddContextPayload = { input: string | number[]; id?: string; metadata?: string }; +type IndexUpsertPayload = { input: string | number[]; id?: string; metadata?: string }; +type FilePath = string; + +export type AddContextPayload = + | { dataType: "text"; data: string } + | { dataType: "embedding"; data: IndexUpsertPayload[] } + | { + dataType: "pdf"; + fileSource: FilePath | Blob; + opts?: Partial<RecursiveCharacterTextSplitterParams>; + } + | { dataType: "csv"; fileSource: FilePath | Blob }; + +export type AddContextOptions = { + metadataKey?: string; +}; export type RetrievePayload = { question: string; @@ -18,6 +36,11 @@ export class RetrievalService { this.index = index; } + /** + * A method that 
allows you to query the vector database with plain text. + * It takes care of the text-to-embedding conversion by itself. + * Additionally, it lets consumers pass various options to tweak the output. + */ async retrieveFromVectorDb({ question, similarityThreshold = DEFAULT_SIMILARITY_THRESHOLD, @@ -32,17 +55,6 @@ export class RetrievalService { includeVectors: false, }); - const allValuesUndefined = result.every( - (embedding) => embedding.metadata?.[metadataKey] === undefined - ); - - if (allValuesUndefined) { - throw new TypeError(` - Query to the vector store returned ${result.length} vectors but none had "${metadataKey}" field in their metadata. - Text of your vectors should be in the "${metadataKey}" field in the metadata for the RAG Chat. - `); - } - const facts = result .filter((x) => x.score >= similarityThreshold) .map( @@ -51,32 +63,58 @@ export class RetrievalService { return formatFacts(facts); } - async addEmbeddingOrTextToVectorDb( - input: AddContextPayload[] | string, - metadataKey = "text" - ): Promise { - if (typeof input === "string") { - return this.index.upsert({ - data: input, - id: nanoid(), - metadata: { [metadataKey]: input }, - }); - } - const items = input.map((context) => { - const isText = typeof context.input === "string"; - const metadata = context.metadata - ? { [metadataKey]: context.metadata } - : isText - ? { [metadataKey]: context.input } - : {}; + /** + * A method that allows you to add various data types into a vector database. + * It supports plain text, embeddings, PDF, and CSV. Additionally, it handles text-splitting for CSV and PDF. + */ + async addDataToVectorDb( + input: AddContextPayload, + options?: AddContextOptions + ): Promise { + const { metadataKey = "text" } = options ?? {}; - return { - [isText ? "data" : "vector"]: context.input, - id: context.id ?? 
nanoid(), - metadata, - }; - }); + switch (input.dataType) { + case "text": { + return this.index.upsert({ + data: input.data, + id: nanoid(), + metadata: { [metadataKey]: input.data }, + }); + } + case "embedding": { + const items = input.data.map((context) => { + const isText = typeof context.input === "string"; + const metadata = context.metadata + ? { [metadataKey]: context.metadata } + : isText + ? { [metadataKey]: context.input } + : {}; + + return { + [isText ? "data" : "vector"]: context.input, + id: context.id ?? nanoid(), + metadata, + }; + }); + + return this.index.upsert(items); + } + case "pdf": { + const loader = new PDFLoader(input.fileSource); + const documents = await loader.load(); - return this.index.upsert(items); + // Users will be able to pass options like chunkSize,chunkOverlap when calling addContext from RAGChat instance directly. + const splitter = new RecursiveCharacterTextSplitter(input.opts); + + const splittedDocuments = await splitter.splitDocuments(documents); + const upsertPayload = splittedDocuments.map((document) => ({ + data: document.pageContent, + metadata: { [metadataKey]: document.pageContent }, + id: nanoid(), + })); + + return this.index.upsert(upsertPayload); + } + } } } diff --git a/src/test-utils.ts b/src/test-utils.ts new file mode 100644 index 0000000..d70d9d5 --- /dev/null +++ b/src/test-utils.ts @@ -0,0 +1,24 @@ +/* eslint-disable @typescript-eslint/no-magic-numbers */ +import type { Index } from "@upstash/vector"; +import { sleep } from "bun"; + +export const awaitUntilIndexed = async (client: Index, timeoutMillis = 10_000) => { + const start = performance.now(); + + const getInfo = async () => { + return await client.info(); + }; + + do { + const info = await getInfo(); + if (info.pendingVectorCount === 0) { + // OK, nothing more to index. + return; + } + + // Not indexed yet, sleep a bit and check again if the timeout is not passed. 
+ await sleep(1000); + } while (performance.now() < start + timeoutMillis); + + throw new Error(`Indexing is not completed in ${timeoutMillis} ms.`); +};