Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add context hook #21

Merged
merged 2 commits into from
Jul 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/database.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,6 @@ describe("Database", () => {
similarityThreshold: 0.5,
namespace: "",
});
expect(result).toContain("330");
expect(result.map(({ data }) => data).join(" ")).toContain("330");
});
});
13 changes: 8 additions & 5 deletions src/database.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import { nanoid } from "nanoid";
import { DEFAULT_SIMILARITY_THRESHOLD, DEFAULT_TOP_K } from "./constants";
import { FileDataLoader } from "./file-loader";
import type { AddContextOptions } from "./types";
import { formatFacts } from "./utils";

export type IndexUpsertPayload = { input: number[]; id?: string | number; metadata?: string };
export type FilePath = string;
Expand Down Expand Up @@ -89,7 +88,7 @@ export class Database {
similarityThreshold = DEFAULT_SIMILARITY_THRESHOLD,
topK = DEFAULT_TOP_K,
namespace,
}: VectorPayload): Promise<string> {
}: VectorPayload): Promise<{ data: string; id: string }[]> {
const index = this.index;
const result = await index.query<Record<string, string>>(
{
Expand All @@ -103,13 +102,17 @@ export class Database {

if (allValuesUndefined) {
console.error("There is no answer for this question in the provided context.");
return formatFacts(["There is no answer for this question in the provided context."]);

return [
{ data: " There is no answer for this question in the provided context.", id: "error" },
];
}

const facts = result
.filter((x) => x.score >= similarityThreshold)
.map((embedding) => `- ${embedding.data ?? ""}`);
return formatFacts(facts);
.map((embedding) => ({ data: `- ${embedding.data ?? ""}`, id: embedding.id.toString() }));

return facts;
}

/**
Expand Down
10 changes: 8 additions & 2 deletions src/rag-chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import type { CustomPrompt } from "./rag-chat-base";
import { RAGChatBase } from "./rag-chat-base";
import { RateLimitService } from "./ratelimit-service";
import type { ChatOptions, RAGChatConfig } from "./types";
import { appendDefaultsIfNeeded } from "./utils";
import { appendDefaultsIfNeeded, formatFacts } from "./utils";
import { RatelimitUpstashError } from "./error";

type ChatReturnType<T extends Partial<ChatOptions>> = Promise<
Expand Down Expand Up @@ -88,13 +88,19 @@ export class RAGChat extends RAGChatBase {
});

// Sanitizes the given input by stripping all the newline chars. Then, queries vector db with sanitized question.
const { question, context } = await this.prepareChat({
const { question, context: originalContext } = await this.prepareChat({
question: input,
similarityThreshold: options_.similarityThreshold,
topK: options_.topK,
namespace: options_.namespace,
});

// clone context to avoid mutation issues
const clonedContext = structuredClone(originalContext);
const modifiedContext = await options?.onContextFetched?.(clonedContext);

const context = formatFacts((modifiedContext ?? originalContext).map(({ data }) => data));

// Gets the chat history from redis or in-memory store.
const chatHistory = await this.history.getMessages({
sessionId: options_.sessionId,
Expand Down
10 changes: 9 additions & 1 deletion src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import type { LLMClient } from "./custom-llm-client";
declare const __brand: unique symbol;
type Brand<B> = { [__brand]: B };
export type Branded<T, B> = T & Brand<B>;
type OptionalAsync<T> = T | Promise<T>;

export type ChatOptions = {
/** Set to `true` if working with web apps and you want to be interactive without stalling users.
Expand Down Expand Up @@ -75,11 +76,18 @@ export type ChatOptions = {
content: string;
rawContent: string;
}) => void;

/**
* Hook to access the retrieved context and modify as you wish.
*/
onContextFetched?: (
context: PrepareChatResult["context"]
) => OptionalAsync<PrepareChatResult["context"]> | OptionalAsync<undefined>;
};

export type PrepareChatResult = {
question: string;
context: string;
context: { data: string; id: string }[];
};

type RAGChatConfigCommon = {
Expand Down
Loading