diff --git a/src/database.test.ts b/src/database.test.ts index 518316a..07354b6 100644 --- a/src/database.test.ts +++ b/src/database.test.ts @@ -32,6 +32,6 @@ describe("Database", () => { similarityThreshold: 0.5, namespace: "", }); - expect(result).toContain("330"); + expect(result.map(({ data }) => data).join(" ")).toContain("330"); }); }); diff --git a/src/database.ts b/src/database.ts index 8c88cdc..b938053 100644 --- a/src/database.ts +++ b/src/database.ts @@ -5,7 +5,6 @@ import { nanoid } from "nanoid"; import { DEFAULT_SIMILARITY_THRESHOLD, DEFAULT_TOP_K } from "./constants"; import { FileDataLoader } from "./file-loader"; import type { AddContextOptions } from "./types"; -import { formatFacts } from "./utils"; export type IndexUpsertPayload = { input: number[]; id?: string | number; metadata?: string }; export type FilePath = string; @@ -89,7 +88,7 @@ export class Database { similarityThreshold = DEFAULT_SIMILARITY_THRESHOLD, topK = DEFAULT_TOP_K, namespace, - }: VectorPayload): Promise<string> { + }: VectorPayload): Promise<{ data: string; id: string }[]> { const index = this.index; const result = await index.query<Record<string, string[]>>( { @@ -103,13 +102,17 @@ export class Database { if (allValuesUndefined) { console.error("There is no answer for this question in the provided context."); - return formatFacts(["There is no answer for this question in the provided context."]); + + return [ + { data: " There is no answer for this question in the provided context.", id: "error" }, + ]; } const facts = result .filter((x) => x.score >= similarityThreshold) - .map((embedding) => `- ${embedding.data ?? ""}`); - return formatFacts(facts); + .map((embedding) => ({ data: `- ${embedding.data ?? 
""}`, id: embedding.id.toString() })); + + return facts; } /** diff --git a/src/rag-chat.ts b/src/rag-chat.ts index a48e01f..25ca3ca 100644 --- a/src/rag-chat.ts +++ b/src/rag-chat.ts @@ -8,7 +8,7 @@ import type { CustomPrompt } from "./rag-chat-base"; import { RAGChatBase } from "./rag-chat-base"; import { RateLimitService } from "./ratelimit-service"; import type { ChatOptions, RAGChatConfig } from "./types"; -import { appendDefaultsIfNeeded } from "./utils"; +import { appendDefaultsIfNeeded, formatFacts } from "./utils"; import { RatelimitUpstashError } from "./error"; type ChatReturnType> = Promise< @@ -88,13 +88,19 @@ export class RAGChat extends RAGChatBase { }); // Sanitizes the given input by stripping all the newline chars. Then, queries vector db with sanitized question. - const { question, context } = await this.prepareChat({ + const { question, context: originalContext } = await this.prepareChat({ question: input, similarityThreshold: options_.similarityThreshold, topK: options_.topK, namespace: options_.namespace, }); + // clone context to avoid mutation issues + const clonedContext = structuredClone(originalContext); + const modifiedContext = await options?.onContextFetched?.(clonedContext); + + const context = formatFacts((modifiedContext ?? originalContext).map(({ data }) => data)); + // Gets the chat history from redis or in-memory store. const chatHistory = await this.history.getMessages({ sessionId: options_.sessionId, diff --git a/src/types.ts b/src/types.ts index 7b903ef..f49f0c7 100644 --- a/src/types.ts +++ b/src/types.ts @@ -8,6 +8,7 @@ import type { LLMClient } from "./custom-llm-client"; declare const __brand: unique symbol; type Brand = { [__brand]: B }; export type Branded = T & Brand; +type OptionalAsync = T | Promise; export type ChatOptions = { /** Set to `true` if working with web apps and you want to be interactive without stalling users. 
@@ -75,11 +76,18 @@ export type ChatOptions = { content: string; rawContent: string; }) => void; + + /** + * Hook to access the retrieved context and modify as you wish. + */ + onContextFetched?: ( + context: PrepareChatResult["context"] + ) => OptionalAsync<PrepareChatResult["context"]> | OptionalAsync<void>; }; export type PrepareChatResult = { question: string; - context: string; + context: { data: string; id: string }[]; }; type RAGChatConfigCommon = {