Skip to content

Commit

Permalink
feat: move core functions to base class
Browse files Browse the repository at this point in the history
  • Loading branch information
ogzhanolguncu committed May 7, 2024
1 parent 6aaaa32 commit dd58b7c
Show file tree
Hide file tree
Showing 9 changed files with 369 additions and 105 deletions.
Binary file modified bun.lockb
Binary file not shown.
5 changes: 3 additions & 2 deletions package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "@upstash/rag-chat",
"version": "0.0.20-alpha",
"version": "0.0.21-alpha",
"main": "./dist/index.js",
"module": "./dist/index.mjs",
"types": "./dist/index.d.ts",
Expand Down Expand Up @@ -52,6 +52,7 @@
"@langchain/core": "^0.1.58",
"@langchain/openai": "^0.0.28",
"@upstash/sdk": "0.0.30-alpha",
"ai": "^3.0.35"
"ai": "^3.1.1",
"nanoid": "^5.0.7"
}
}
18 changes: 7 additions & 11 deletions src/client-factory.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
/* eslint-disable unicorn/no-useless-undefined */
import type { Index, Redis } from "@upstash/sdk";
import type { Config } from "./config";

import { Upstash } from "@upstash/sdk";
import { RedisClient } from "./clients/redis";
import { VectorClient } from "./clients/vector";
import { InternalUpstashError } from "./error/internal";

export type ClientFactoryConfig = Pick<Config, "email" | "token" | "vector" | "redis" | "region">;
export class ClientFactory {
Expand Down Expand Up @@ -44,23 +44,19 @@ export class ClientFactory {
redis: TInit["redis"] extends true ? Redis : undefined;
vector: TInit["vector"] extends true ? Index : undefined;
}> {
let redis: Redis | undefined;
let vector: Index | undefined;
let redisPromise: Promise<Redis | undefined> = Promise.resolve(undefined);
let vectorPromise: Promise<Index | undefined> = Promise.resolve(undefined);

if (options.redis) {
redis = await this.createRedisClient();
if (!redis) {
throw new InternalUpstashError("Couldn't initialize Redis client");
}
redisPromise = this.createRedisClient();
}

if (options.vector) {
vector = await this.createVectorClient();
if (!vector) {
throw new InternalUpstashError("Couldn't initialize Vector client");
}
vectorPromise = this.createVectorClient();
}

const [redis, vector] = await Promise.all([redisPromise, vectorPromise]);

return { redis, vector } as {
redis: TInit["redis"] extends true ? Redis : undefined;
vector: TInit["vector"] extends true ? Index : undefined;
Expand Down
6 changes: 3 additions & 3 deletions src/clients/vector/index.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ describe("Vector Client", () => {

await upstashSDK.deleteVectorIndex(DEFAULT_VECTOR_DB_NAME);
},
{ timeout: 20_000 }
{ timeout: 30_000 }
);

test(
Expand All @@ -38,7 +38,7 @@ describe("Vector Client", () => {

await upstashSDK.deleteVectorIndex("test-name");
},
{ timeout: 20_000 }
{ timeout: 30_000 }
);

test(
Expand All @@ -62,6 +62,6 @@ describe("Vector Client", () => {

await upstashSDK.deleteVectorIndex(indexName);
},
{ timeout: 20_000 }
{ timeout: 30_000 }
);
});
2 changes: 2 additions & 0 deletions src/clients/vector/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import { Index } from "@upstash/sdk";

import type { PreferredRegions } from "../../types";
import { DEFAULT_VECTOR_DB_NAME } from "../../constants";
import { delay } from "../../utils";

export const DEFAULT_VECTOR_CONFIG: CreateIndexPayload = {
name: DEFAULT_VECTOR_DB_NAME,
Expand Down Expand Up @@ -94,6 +95,7 @@ export class VectorClient {
});
}
}
await delay();

if (index?.name) {
const client = await this.upstashSDK.newVectorClient(index.name);
Expand Down
96 changes: 96 additions & 0 deletions src/rag-chat-base.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import type { Callbacks } from "@langchain/core/callbacks/manager";
import type { AIMessage, BaseMessage } from "@langchain/core/messages";
import { RunnableSequence, RunnableWithMessageHistory } from "@langchain/core/runnables";
import { StreamingTextResponse, LangChainStream } from "ai";

import type { PrepareChatResult, ChatOptions } from "./types";
import { sanitizeQuestion, formatChatHistory } from "./utils";
import type { BaseLanguageModelInterface } from "@langchain/core/language_models/base";
import type { PromptTemplate } from "@langchain/core/prompts";
import type { HistoryService, RetrievePayload } from "./services";
import type { RetrievalService } from "./services/retrieval";

type CustomInputValues = { chat_history?: BaseMessage[]; question: string; context: string };

export class RAGChatBase {
protected retrievalService: RetrievalService;
protected historyService: HistoryService;

#model: BaseLanguageModelInterface;
#template: PromptTemplate;

constructor(
retrievalService: RetrievalService,
historyService: HistoryService,
config: { model: BaseLanguageModelInterface; template: PromptTemplate }
) {
this.retrievalService = retrievalService;
this.historyService = historyService;

this.#model = config.model;
this.#template = config.template;
}

protected async prepareChat({
question: input,
similarityThreshold,
topK,
}: RetrievePayload): Promise<PrepareChatResult> {
const question = sanitizeQuestion(input);
const facts = await this.retrievalService.retrieveFromVectorDb({
question,
similarityThreshold,
topK,
});
return { question, facts };
}

protected streamingChainCall = (
chatOptions: ChatOptions,
question: string,
facts: string
): StreamingTextResponse => {
const { stream, handlers } = LangChainStream();
void this.chainCall(chatOptions, question, facts, [handlers]);
return new StreamingTextResponse(stream, {});
};

protected chainCall(
chatOptions: ChatOptions,
question: string,
facts: string,
handlers?: Callbacks
): Promise<AIMessage> {
const formattedHistoryChain = RunnableSequence.from<CustomInputValues>([
{
chat_history: (input) => formatChatHistory(input.chat_history ?? []),
question: (input) => input.question,
context: (input) => input.context,
},
this.#template,
this.#model,
]);

const chainWithMessageHistory = new RunnableWithMessageHistory({
runnable: formattedHistoryChain,
getMessageHistory: (sessionId: string) =>
this.historyService.getMessageHistory({
sessionId,
length: chatOptions.includeHistory,
}),
inputMessagesKey: "question",
historyMessagesKey: "chat_history",
});

return chainWithMessageHistory.invoke(
{
question,
context: facts,
},
{
callbacks: handlers ?? undefined,
configurable: { sessionId: chatOptions.sessionId },
}
) as Promise<AIMessage>;
}
}
Loading

0 comments on commit dd58b7c

Please sign in to comment.