feat: support adding pdfs as context
ogzhanolguncu committed May 22, 2024
1 parent 4350067 commit b9cef53
Showing 7 changed files with 205 additions and 70 deletions.
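In practice, the change turns `addContext` from a plain-string API into a typed payload. A minimal sketch of the new usage, mirroring the added test and JSDoc below (the file path and chunking values are simply the ones used in the test; `ragChat` is an already-configured `RAGChat` instance):

```typescript
// New PDF flow: load the file, split it via RecursiveCharacterTextSplitter, and upsert into Upstash Vector.
await ragChat.addContext({
  dataType: "pdf",
  fileSource: "./data/the_wonderful_wizard_of_oz.pdf",
  opts: { chunkSize: 500, chunkOverlap: 50 }, // forwarded to the text splitter
});

// Plain text still works, but now as a typed payload instead of a bare string.
await ragChat.addContext({
  dataType: "text",
  data: "Ankara is the capital of Turkiye.",
});
```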
Binary file modified bun.lockb
Binary file added data/the_wonderful_wizard_of_oz.pdf
6 changes: 4 additions & 2 deletions package.json
@@ -47,13 +47,15 @@
"vitest": "latest"
},
"dependencies": {
"@langchain/community": "^0.0.50",
"@langchain/community": "^0.2.1",
"@langchain/core": "^0.1.58",
"@langchain/openai": "^0.0.28",
"@upstash/ratelimit": "^1.1.3",
"@upstash/redis": "^1.31.1",
"@upstash/vector": "^1.1.1",
"ai": "^3.1.1",
"nanoid": "^5.0.7"
"langchain": "^0.2.0",
"nanoid": "^5.0.7",
"pdf-parse": "^1.1.1"
}
}
93 changes: 70 additions & 23 deletions src/rag-chat.test.ts
@@ -1,16 +1,15 @@
/* eslint-disable @typescript-eslint/no-non-null-assertion */
import type { AIMessage } from "@langchain/core/messages";
import { PromptTemplate } from "@langchain/core/prompts";
import { ChatOpenAI } from "@langchain/openai";
import { Ratelimit } from "@upstash/ratelimit";
import { Redis } from "@upstash/redis";
import { Index } from "@upstash/vector";
import type { StreamingTextResponse } from "ai";
import { sleep } from "bun";
import { afterAll, beforeAll, describe, expect, test } from "bun:test";
import { RAGChat } from "./rag-chat";
import { Index } from "@upstash/vector";
import { Redis } from "@upstash/redis";
import { Ratelimit } from "@upstash/ratelimit";
import { RatelimitUpstashError } from "./error/ratelimit";
import { PromptTemplate } from "@langchain/core/prompts";
import { delay } from "./utils";
import { RAGChat } from "./rag-chat";
import { awaitUntilIndexed } from "./test-utils";

describe("RAG Chat with advance configs and direct instances", () => {
const vector = new Index({
@@ -34,12 +33,11 @@ describe("RAG Chat with advance configs and direct instances", () => {
});

beforeAll(async () => {
await ragChat.addContext(
"Paris, the capital of France, is renowned for its iconic landmark, the Eiffel Tower, which was completed in 1889 and stands at 330 meters tall.",
"text"
);
//eslint-disable-next-line @typescript-eslint/no-magic-numbers
await sleep(3000);
await ragChat.addContext({
dataType: "text",
data: "Paris, the capital of France, is renowned for its iconic landmark, the Eiffel Tower, which was completed in 1889 and stands at 330 meters tall.",
});
await awaitUntilIndexed(vector);
});

afterAll(async () => await vector.reset());
@@ -98,11 +96,13 @@ describe("RAG Chat with ratelimit", () => {
"should throw ratelimit error",
async () => {
await ragChat.addContext(
"Paris, the capital of France, is renowned for its iconic landmark, the Eiffel Tower, which was completed in 1889 and stands at 330 meters tall.",
"text"
{
dataType: "text",
data: "Paris, the capital of France, is renowned for its iconic landmark, the Eiffel Tower, which was completed in 1889 and stands at 330 meters tall.",
},
{ metadataKey: "text" }
);
//eslint-disable-next-line @typescript-eslint/no-magic-numbers
await sleep(3000);
await awaitUntilIndexed(vector);

await ragChat.chat(
"What year was the construction of the Eiffel Tower completed, and what is its height?",
Expand All @@ -120,11 +120,12 @@ describe("RAG Chat with ratelimit", () => {
});

describe("RAG Chat with custom template", () => {
const vector = new Index({
token: process.env.UPSTASH_VECTOR_REST_TOKEN!,
url: process.env.UPSTASH_VECTOR_REST_URL!,
});
const ragChat = new RAGChat({
vector: new Index({
token: process.env.UPSTASH_VECTOR_REST_TOKEN!,
url: process.env.UPSTASH_VECTOR_REST_URL!,
}),
vector,
redis: new Redis({
token: process.env.UPSTASH_REDIS_REST_TOKEN!,
url: process.env.UPSTASH_REDIS_REST_URL!,
@@ -142,11 +143,14 @@
test(
"should get result without streaming",
async () => {
await ragChat.addContext("Ankara is the capital of Turkiye.");
await ragChat.addContext(
{ dataType: "text", data: "Ankara is the capital of Turkiye." },
{ metadataKey: "text" }
);

// Wait for it to be indexed
// eslint-disable-next-line @typescript-eslint/no-magic-numbers
await delay(3000);
await awaitUntilIndexed(vector);

const result = (await ragChat.chat("Where is the capital of Turkiye?", {
stream: false,
@@ -157,3 +161,46 @@
{ timeout: 30_000 }
);
});

describe("RAG Chat addContext using PDF", () => {
const vector = new Index({
token: process.env.UPSTASH_VECTOR_REST_TOKEN!,
url: process.env.UPSTASH_VECTOR_REST_URL!,
});
const redis = new Redis({
token: process.env.UPSTASH_REDIS_REST_TOKEN!,
url: process.env.UPSTASH_REDIS_REST_URL!,
});
const ragChat = new RAGChat({
redis,
vector,
model: new ChatOpenAI({
modelName: "gpt-3.5-turbo",
streaming: false,
verbose: false,
temperature: 0,
apiKey: process.env.OPENAI_API_KEY,
}),
});

afterAll(async () => {
await vector.reset();
});

test(
"should be able to successfully query embedded book",
async () => {
await ragChat.addContext({
dataType: "pdf",
fileSource: "./data/the_wonderful_wizard_of_oz.pdf",
opts: { chunkSize: 500, chunkOverlap: 50 },
});
await awaitUntilIndexed(vector);
const result = (await ragChat.chat("Whats the author of The Wonderful Wizard of Oz?", {
stream: false,
})) as AIMessage;
expect(result.content).toContain("Frank");
},
{ timeout: 30_000 }
);
});
40 changes: 32 additions & 8 deletions src/rag-chat.ts
@@ -8,7 +8,7 @@ import { RatelimitUpstashError } from "./error/ratelimit";

import type { Config } from "./config";
import { RAGChatBase } from "./rag-chat-base";
import type { AddContextPayload } from "./services";
import type { AddContextOptions, AddContextPayload } from "./services";
import { HistoryService, RetrievalService } from "./services";
import { RateLimitService } from "./services/ratelimit";
import type { ChatOptions } from "./types";
@@ -34,8 +34,18 @@ export class RAGChat extends RAGChatBase {
this.#ratelimitService = ratelimitService;
}

/**
* A method that allows you to chat with an LLM, using Vector DB as your knowledge store and, optionally, Redis as chat history.
*
* @example
* ```typescript
* await ragChat.chat("Where is the capital of Turkiye?", {
* stream: false,
* })
* ```
*/
async chat(input: string, options: ChatOptions): Promise<StreamingTextResponse | AIMessage> {
// Adds chat session id and ratelimit session id if not provided.
// Adds all the necessary default options that users can skip in the options parameter above.
const options_ = appendDefaultsIfNeeded(options);

// Checks ratelimit of the user. If not enabled `success` will be always true.
@@ -65,12 +75,26 @@
: this.chainCall(options_, question, facts);
}

/** Context can be either plain text or embeddings */
async addContext(context: AddContextPayload[] | string, metadataKey = "text") {
const retrievalServiceStatus = await this.retrievalService.addEmbeddingOrTextToVectorDb(
context,
metadataKey
);
/**
* A method that allows you to add various data types into a vector database.
* It supports plain text, embeddings, PDF, and CSV. Additionally, it handles text-splitting for CSV and PDF.
*
* @example
* ```typescript
* await ragChat.addContext({
* dataType: "pdf",
* fileSource: "./data/the_wonderful_wizard_of_oz.pdf",
* opts: { chunkSize: 500, chunkOverlap: 50 },
* });
* // OR
* await ragChat.addContext({
* dataType: "text",
* data: "Paris, the capital of France, is renowned for its iconic landmark, the Eiffel Tower, which was completed in 1889 and stands at 330 meters tall.",
* });
* ```
*/
async addContext(context: AddContextPayload, options?: AddContextOptions) {
const retrievalServiceStatus = await this.retrievalService.addDataToVectorDb(context, options);
return retrievalServiceStatus === "Success" ? "OK" : "NOT-OK";
}

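The JSDoc above covers the `text` and `pdf` variants; for completeness, an `embedding` payload (see `AddContextPayload` in `src/services/retrieval.ts` below) would look roughly like this — a sketch only, with placeholder vector values, not code from the commit:

```typescript
// Upserting pre-computed embeddings; each item follows IndexUpsertPayload
// ({ input: string | number[]; id?: string; metadata?: string }).
await ragChat.addContext(
  {
    dataType: "embedding",
    data: [
      { input: [0.12, 0.48, -0.07], metadata: "Eiffel Tower fact" }, // raw vector plus metadata string
      { input: "Paris is the capital of France." }, // string input is upserted as `data` and embedded by the index
    ],
  },
  { metadataKey: "text" }
);
```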
112 changes: 75 additions & 37 deletions src/services/retrieval.ts
@@ -2,8 +2,26 @@ import { nanoid } from "nanoid";
import { DEFAULT_METADATA_KEY, DEFAULT_SIMILARITY_THRESHOLD, DEFAULT_TOP_K } from "../constants";
import { formatFacts } from "../utils";
import type { Index } from "@upstash/vector";
import { PDFLoader } from "@langchain/community/document_loaders/fs/pdf";
import type { RecursiveCharacterTextSplitterParams } from "langchain/text_splitter";
import { RecursiveCharacterTextSplitter } from "langchain/text_splitter";

export type AddContextPayload = { input: string | number[]; id?: string; metadata?: string };
type IndexUpsertPayload = { input: string | number[]; id?: string; metadata?: string };
type FilePath = string;

export type AddContextPayload =
| { dataType: "text"; data: string }
| { dataType: "embedding"; data: IndexUpsertPayload[] }
| {
dataType: "pdf";
fileSource: FilePath | Blob;
opts?: Partial<RecursiveCharacterTextSplitterParams>;
}
| { dataType: "csv"; fileSource: FilePath | Blob };

export type AddContextOptions = {
metadataKey?: string;
};

export type RetrievePayload = {
question: string;
@@ -18,6 +36,11 @@ export class RetrievalService {
this.index = index;
}

/**
* A method that allows you to query the vector database with plain text.
* It takes care of the text-to-embedding conversion by itself.
* Additionally, it lets consumers pass various options to tweak the output.
*/
async retrieveFromVectorDb({
question,
similarityThreshold = DEFAULT_SIMILARITY_THRESHOLD,
@@ -32,17 +55,6 @@
includeVectors: false,
});

const allValuesUndefined = result.every(
(embedding) => embedding.metadata?.[metadataKey] === undefined
);

if (allValuesUndefined) {
throw new TypeError(`
Query to the vector store returned ${result.length} vectors but none had "${metadataKey}" field in their metadata.
Text of your vectors should be in the "${metadataKey}" field in the metadata for the RAG Chat.
`);
}

const facts = result
.filter((x) => x.score >= similarityThreshold)
.map(
@@ -51,32 +63,58 @@
return formatFacts(facts);
}

async addEmbeddingOrTextToVectorDb(
input: AddContextPayload[] | string,
metadataKey = "text"
): Promise<string> {
if (typeof input === "string") {
return this.index.upsert({
data: input,
id: nanoid(),
metadata: { [metadataKey]: input },
});
}
const items = input.map((context) => {
const isText = typeof context.input === "string";
const metadata = context.metadata
? { [metadataKey]: context.metadata }
: isText
? { [metadataKey]: context.input }
: {};
/**
* A method that allows you to add various data types into a vector database.
* It supports plain text, embeddings, PDF, and CSV. Additionally, it handles text-splitting for CSV and PDF.
*/
async addDataToVectorDb(
input: AddContextPayload,
options?: AddContextOptions
): Promise<string | undefined> {
const { metadataKey = "text" } = options ?? {};

return {
[isText ? "data" : "vector"]: context.input,
id: context.id ?? nanoid(),
metadata,
};
});
switch (input.dataType) {
case "text": {
return this.index.upsert({
data: input.data,
id: nanoid(),
metadata: { [metadataKey]: input.data },
});
}
case "embedding": {
const items = input.data.map((context) => {
const isText = typeof context.input === "string";
const metadata = context.metadata
? { [metadataKey]: context.metadata }
: isText
? { [metadataKey]: context.input }
: {};

return {
[isText ? "data" : "vector"]: context.input,
id: context.id ?? nanoid(),
metadata,
};
});

return this.index.upsert(items);
}
case "pdf": {
const loader = new PDFLoader(input.fileSource);
const documents = await loader.load();

return this.index.upsert(items);
// Users will be able to pass options like chunkSize, chunkOverlap when calling addContext from the RAGChat instance directly.
const splitter = new RecursiveCharacterTextSplitter(input.opts);

const splittedDocuments = await splitter.splitDocuments(documents);
const upsertPayload = splittedDocuments.map((document) => ({
data: document.pageContent,
metadata: { [metadataKey]: document.pageContent },
id: nanoid(),
}));

return this.index.upsert(upsertPayload);
}
}
}
}
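One gap worth noting: `AddContextPayload` declares a `csv` variant, but the `switch` above has no matching case yet — which is also why `addDataToVectorDb` is typed as returning `string | undefined`. A possible shape for that branch, purely a sketch (it assumes `CSVLoader` from `@langchain/community`, which this commit does not wire up):

```typescript
// Hypothetical "csv" branch for the switch in addDataToVectorDb (not part of this commit).
// Requires: import { CSVLoader } from "@langchain/community/document_loaders/fs/csv";
case "csv": {
  const loader = new CSVLoader(input.fileSource);
  const documents = await loader.load(); // one document per CSV row

  const upsertPayload = documents.map((document) => ({
    data: document.pageContent,
    metadata: { [metadataKey]: document.pageContent },
    id: nanoid(),
  }));

  return this.index.upsert(upsertPayload);
}
```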
24 changes: 24 additions & 0 deletions src/test-utils.ts
@@ -0,0 +1,24 @@
/* eslint-disable @typescript-eslint/no-magic-numbers */
import type { Index } from "@upstash/vector";
import { sleep } from "bun";

export const awaitUntilIndexed = async (client: Index, timeoutMillis = 10_000) => {
const start = performance.now();

const getInfo = async () => {
return await client.info();
};

do {
const info = await getInfo();
if (info.pendingVectorCount === 0) {
// OK, nothing more to index.
return;
}

// Not indexed yet; sleep a bit and check again until the timeout passes.
await sleep(1000);
} while (performance.now() < start + timeoutMillis);

throw new Error(`Indexing is not completed in ${timeoutMillis} ms.`);
};
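This helper replaces the fixed `sleep(3000)` waits the tests used before; a short usage sketch (index construction as in the tests above):

```typescript
import { Index } from "@upstash/vector";
import { awaitUntilIndexed } from "./test-utils";

const vector = new Index({
  token: process.env.UPSTASH_VECTOR_REST_TOKEN!,
  url: process.env.UPSTASH_VECTOR_REST_URL!,
});

// After upserting data, block until the index reports no pending vectors,
// or throw once the default 10s timeout is exceeded.
await awaitUntilIndexed(vector);
```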
