Skip to content

Commit

Permalink
feat: add a way to identify different message creators
Browse files Browse the repository at this point in the history
  • Loading branch information
ogzhanolguncu committed May 30, 2024
1 parent 055f821 commit ca2144a
Show file tree
Hide file tree
Showing 9 changed files with 66 additions and 14 deletions.
Binary file modified bun.lockb
Binary file not shown.
1 change: 1 addition & 0 deletions index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ export * from "./src/database";
export * from "./src/ratelimit";
export * from "./src/error";
export * from "./src/types";
export { MODEL_NAME_WITH_PROVIDER_SPLITTER } from "./src/constants";
12 changes: 7 additions & 5 deletions package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "@upstash/rag-chat",
"version": "0.0.25-alpha",
"version": "0.0.27-alpha",
"main": "./dist/index.js",
"module": "./dist/index.mjs",
"types": "./dist/index.d.ts",
Expand Down Expand Up @@ -49,16 +49,18 @@
"dependencies": {
"@langchain/community": "^0.2.1",
"@langchain/core": "^0.1.58",
"@langchain/openai": "^0.0.28",
"@upstash/ratelimit": "^1.1.3",
"@upstash/redis": "^1.31.1",
"@upstash/vector": "^1.1.1",
"ai": "^3.1.1",
"cheerio": "^1.0.0-rc.12",
"d3-dsv": "^3.0.1",
"html-to-text": "^9.0.5",
"langchain": "^0.2.0",
"nanoid": "^5.0.7",
"pdf-parse": "^1.1.1"
},
"peerDependencies": {
"@upstash/redis": "^1.31.3",
"@upstash/vector": "^1.1.1",
"@upstash/ratelimit": "^1.1.3",
"@langchain/openai": "^0.0.34"
}
}
3 changes: 3 additions & 0 deletions src/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,6 @@ export const DEFAULT_METADATA_KEY = "text";
//History related default options
export const DEFAULT_HISTORY_TTL = 86_400;
export const DEFAULT_HISTORY_LENGTH = 5;

//We need that constant to split creator LLM such as `ChatOpenAI_gpt-3.5-turbo`. Format is `provider_modelName`.
export const MODEL_NAME_WITH_PROVIDER_SPLITTER = "_";
11 changes: 9 additions & 2 deletions src/history/in-memory-custom-history.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@ import { CustomInMemoryChatMessageHistory } from "./in-memory-custom-history";

test("should give last 3 messages from in-memory", async () => {
const messageHistoryLength = 3;
const history = new CustomInMemoryChatMessageHistory([], messageHistoryLength);
const history = new CustomInMemoryChatMessageHistory({
messages: [],
topLevelChatHistoryLength: messageHistoryLength,
modelNameWithProvider: "",
});
await history.addUserMessage("Hello!");
await history.addAIMessage("Hello, human.");
await history.addUserMessage("Whats your name?");
Expand All @@ -16,7 +20,10 @@ test("should give last 3 messages from in-memory", async () => {
});

test("should give all the messages", async () => {
const history = new CustomInMemoryChatMessageHistory();
const history = new CustomInMemoryChatMessageHistory({
messages: [],
modelNameWithProvider: "",
});
await history.addUserMessage("Hello!");
await history.addAIMessage("Hello, human.");
await history.addUserMessage("Whats your name?");
Expand Down
14 changes: 12 additions & 2 deletions src/history/in-memory-custom-history.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,26 @@
import { BaseListChatMessageHistory } from "@langchain/core/chat_history";
import type { BaseMessage } from "@langchain/core/messages";

export type CustomInMemoryChatMessageHistoryInput = {
messages?: BaseMessage[];
topLevelChatHistoryLength?: number;
modelNameWithProvider: string;
};

export class CustomInMemoryChatMessageHistory extends BaseListChatMessageHistory {
lc_namespace = ["langchain", "stores", "message", "in_memory"];

private messages: BaseMessage[] = [];
private topLevelChatHistoryLength?: number;
private modelNameWithProvider: string;

constructor(messages?: BaseMessage[], topLevelChatHistoryLength?: number) {
constructor(fields: CustomInMemoryChatMessageHistoryInput) {
const { modelNameWithProvider, messages, topLevelChatHistoryLength } = fields;
// eslint-disable-next-line prefer-rest-params
super(...arguments);
this.messages = messages ?? [];
this.topLevelChatHistoryLength = topLevelChatHistoryLength;
this.modelNameWithProvider = modelNameWithProvider;
}

/**
Expand All @@ -32,7 +41,8 @@ export class CustomInMemoryChatMessageHistory extends BaseListChatMessageHistory
* @returns A promise that resolves when the message has been added.
*/
async addMessage(message: BaseMessage) {
this.messages.push(message);
//@ts-expect-error This our way of mutating Message object to store model name with providers.
this.messages.push({ ...message, modelNameWithProvider: this.modelNameWithProvider });
}

/**
Expand Down
14 changes: 12 additions & 2 deletions src/history/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,25 @@ import { CustomInMemoryChatMessageHistory } from "./in-memory-custom-history";
import { CustomUpstashRedisChatMessageHistory } from "./redis-custom-history";
import { InternalUpstashError } from "../error";

type HistoryConfig = {
redis?: Redis;
modelNameWithProvider: string;
};
type GetHistory = { sessionId: string; length?: number; sessionTTL?: number };

export class History {
private redis?: Redis;
private modelNameWithProvider: string;
private inMemoryChatHistory?: CustomInMemoryChatMessageHistory;

constructor(redis?: Redis) {
constructor(fields: HistoryConfig) {
const { modelNameWithProvider, redis } = fields;

this.redis = redis;
this.modelNameWithProvider = modelNameWithProvider;

if (!redis) {
this.inMemoryChatHistory = new CustomInMemoryChatMessageHistory();
this.inMemoryChatHistory = new CustomInMemoryChatMessageHistory({ modelNameWithProvider });
}
}

Expand All @@ -24,6 +33,7 @@ export class History {
sessionTTL,
topLevelChatHistoryLength: length,
client: this.redis,
modelNameWithProvider: this.modelNameWithProvider,
});
}
} catch (error) {
Expand Down
18 changes: 16 additions & 2 deletions src/history/redis-custom-history.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ export type CustomUpstashRedisChatMessageHistoryInput = {
config?: RedisConfigNodejs;
client?: Redis;
topLevelChatHistoryLength?: number;
modelNameWithProvider: string;
};

/**
Expand All @@ -38,13 +39,21 @@ export class CustomUpstashRedisChatMessageHistory extends BaseListChatMessageHis
public client: Redis;

private sessionId: string;
private modelNameWithProvider: string;

private sessionTTL?: number;
private topLevelChatHistoryLength?: number;

constructor(fields: CustomUpstashRedisChatMessageHistoryInput) {
super(fields);
const { sessionId, sessionTTL, config, client, topLevelChatHistoryLength } = fields;
const {
sessionId,
sessionTTL,
config,
client,
topLevelChatHistoryLength,
modelNameWithProvider,
} = fields;
if (client) {
this.client = client;
} else if (config) {
Expand All @@ -54,7 +63,9 @@ export class CustomUpstashRedisChatMessageHistory extends BaseListChatMessageHis
`Upstash Redis message stores require either a config object or a pre-configured client.`
);
}

this.sessionId = sessionId;
this.modelNameWithProvider = modelNameWithProvider;
this.sessionTTL = sessionTTL;
this.topLevelChatHistoryLength = topLevelChatHistoryLength;
}
Expand Down Expand Up @@ -86,7 +97,10 @@ export class CustomUpstashRedisChatMessageHistory extends BaseListChatMessageHis
*/
async addMessage(message: BaseMessage): Promise<void> {
const messageToAdd = mapChatMessagesToStoredMessages([message]);
await this.client.lpush(this.sessionId, JSON.stringify(messageToAdd[0]));
await this.client.lpush(
this.sessionId,
JSON.stringify({ ...messageToAdd[0], modelNameWithProvider: this.modelNameWithProvider })
);
if (this.sessionTTL) {
await this.client.expire(this.sessionId, this.sessionTTL);
}
Expand Down
7 changes: 6 additions & 1 deletion src/rag-chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,19 @@ import { appendDefaultsIfNeeded } from "./utils";
import type { AddContextOptions, AddContextPayload } from "./database";
import { Database } from "./database";
import { History } from "./history";
import { MODEL_NAME_WITH_PROVIDER_SPLITTER } from "./constants.ts";

export class RAGChat extends RAGChatBase {
#ratelimitService: RateLimitService;

constructor(config: RAGChatConfig) {
const { vector: index, redis } = new Config(config);

const historyService = new History(redis);
const historyService = new History({
redis,
//@ts-expect-error We need that private field to track message creator LLM such as `ChatOpenAI_gpt-3.5-turbo`. Format is `provider_modelName`.
modelNameWithProvider: `${config.model?.getName()}${MODEL_NAME_WITH_PROVIDER_SPLITTER}${config.model?.modelName}`,
});
const vectorService = new Database(index);
const ratelimitService = new RateLimitService(config.ratelimit);

Expand Down

0 comments on commit ca2144a

Please sign in to comment.