diff --git a/libs/langchain-azure-cosmosdb/src/chat_histories/mongodb.ts b/libs/langchain-azure-cosmosdb/src/chat_histories/mongodb.ts new file mode 100644 index 000000000000..e7db4cb0be72 --- /dev/null +++ b/libs/langchain-azure-cosmosdb/src/chat_histories/mongodb.ts @@ -0,0 +1,157 @@ +import { + Collection, + Document as AzureCosmosMongoDBDocument, + PushOperator, + Db, + MongoClient, +} from "mongodb"; +import { BaseListChatMessageHistory } from "@langchain/core/chat_history"; +import { + BaseMessage, + mapChatMessagesToStoredMessages, + mapStoredMessagesToChatMessages, +} from "@langchain/core/messages"; +import { getEnvironmentVariable } from "@langchain/core/utils/env"; + +export interface AzureCosmosDBMongoChatHistoryDBConfig { + readonly client?: MongoClient; + readonly connectionString?: string; + readonly databaseName?: string; + readonly collectionName?: string; +} + +export class AzureCosmosDBMongoChatMessageHistory extends BaseListChatMessageHistory { + lc_namespace = ["langchain", "stores", "message", "azurecosmosdb"]; + + get lc_secrets(): { [key: string]: string } { + return { + connectionString: "AZURE_COSMOSDB_MONGODB_CONNECTION_STRING", + }; + } + + private initPromise?: Promise; + + private readonly client: MongoClient | undefined; + + private database: Db; + + private collection: Collection; + + private sessionId: string; + + private idKey = "sessionId"; + + initialize: () => Promise; + + constructor( + dbConfig: AzureCosmosDBMongoChatHistoryDBConfig, + sessionId: string + ) { + super(); + + const connectionString = + dbConfig.connectionString ?? + getEnvironmentVariable("AZURE_COSMOSDB_MONGODB_CONNECTION_STRING"); + + if (!dbConfig.client && !connectionString) { + throw new Error( + "Mongo client or connection string must be set." + ); + } + + if (!dbConfig.client) { + // eslint-disable-next-line @typescript-eslint/no-non-null-assertion + this.client = new MongoClient(connectionString!, { + appName: "langchainjs", + }); + } + + // eslint-disable-next-line @typescript-eslint/no-non-null-assertion + const client = dbConfig.client || this.client!; + const databaseName = dbConfig.databaseName ?? "documentsDB"; + const collectionName = dbConfig.collectionName ?? "documents"; + + this.sessionId = sessionId; + + // Deferring initialization to the first call to `initialize` + this.initialize = () => { + if (this.initPromise === undefined) { + this.initPromise = this.init( + client, + databaseName, + collectionName + ).catch((error) => { + console.error( + "Error during AzureCosmosDBMongoChatMessageHistory initialization: ", + error + ); + }); + } + + return this.initPromise; + }; + } + + /** + * Initializes the AzureCosmosDBMongoChatMessageHistory by connecting to the database. + * @param client The MongoClient to use for connecting to the database. + * @param databaseName The name of the database to use. + * @param collectionName The name of the collection to use. + * @returns A promise that resolves when the AzureCosmosDBMongoChatMessageHistory has been initialized. + */ + private async init( + client: MongoClient, + databaseName: string, + collectionName: string + ): Promise { + this.initPromise = (async () => { + await client.connect(); + this.database = client.db(databaseName); + this.collection = this.database.collection(collectionName); + })(); + + return this.initPromise; + } + + /** + * Retrieves the messages stored in the history. + * @returns A promise that resolves with the messages stored in the history. + */ + async getMessages(): Promise { + await this.initialize(); + + const document = await this.collection.findOne({ + [this.idKey]: this.sessionId, + }); + const messages = document?.messages || []; + return mapStoredMessagesToChatMessages(messages); + } + + /** + * Adds a message to the history. + * @param message The message to add to the history. + * @returns A promise that resolves when the message has been added to the history. + */ + async addMessage(message: BaseMessage): Promise { + await this.initialize(); + + const messages = mapChatMessagesToStoredMessages([message]); + await this.collection.updateOne( + { [this.idKey]: this.sessionId }, + { + $push: { messages: { $each: messages } } as PushOperator, + }, + { upsert: true } + ); + } + + /** + * Clear the history. + * @returns A promise that resolves when the history has been cleared. + */ + async clear(): Promise { + await this.initialize(); + + await this.collection.deleteOne({ [this.idKey]: this.sessionId }); + } +} diff --git a/libs/langchain-azure-cosmosdb/src/chat_histories.ts b/libs/langchain-azure-cosmosdb/src/chat_histories/nosql.ts similarity index 100% rename from libs/langchain-azure-cosmosdb/src/chat_histories.ts rename to libs/langchain-azure-cosmosdb/src/chat_histories/nosql.ts diff --git a/libs/langchain-azure-cosmosdb/src/chat_histories_azure_cosmosdb_mongodb.ts b/libs/langchain-azure-cosmosdb/src/chat_histories_azure_cosmosdb_mongodb.ts deleted file mode 100644 index ad85c3a8b502..000000000000 --- a/libs/langchain-azure-cosmosdb/src/chat_histories_azure_cosmosdb_mongodb.ts +++ /dev/null @@ -1,58 +0,0 @@ -import { - Collection, - Document as AzureCosmosMongoDBDocument, - PushOperator, -} from "mongodb"; -import { BaseListChatMessageHistory } from "@langchain/core/chat_history"; -import { - BaseMessage, - mapChatMessagesToStoredMessages, - mapStoredMessagesToChatMessages, -} from "@langchain/core/messages"; - -export interface AzureCosmosDBMongoChatMessageHistoryInput { - collection: Collection; - sessionId: string; -} - -export class AzureCosmosDBMongoChatMessageHistory extends BaseListChatMessageHistory { - lc_namespace = ["langchain", "stores", "message", "azurecosmosdb"]; - - private collection: Collection; - - private sessionId: string; - - private idKey = "sessionId"; - - constructor({ - collection, - sessionId, - }: AzureCosmosDBMongoChatMessageHistoryInput) { - super(); - this.collection = collection; - this.sessionId = sessionId; - } - - async getMessages(): Promise { - const document = await this.collection.findOne({ - [this.idKey]: this.sessionId, - }); - const messages = document?.messages || []; - return mapStoredMessagesToChatMessages(messages); - } - - async addMessage(message: BaseMessage): Promise { - const messages = mapChatMessagesToStoredMessages([message]); - await this.collection.updateOne( - { [this.idKey]: this.sessionId }, - { - $push: { messages: { $each: messages } } as PushOperator, - }, - { upsert: true } - ); - } - - async clear(): Promise { - await this.collection.deleteOne({ [this.idKey]: this.sessionId }); - } -} diff --git a/libs/langchain-azure-cosmosdb/src/index.ts b/libs/langchain-azure-cosmosdb/src/index.ts index c5160397b474..32e989fe55c0 100644 --- a/libs/langchain-azure-cosmosdb/src/index.ts +++ b/libs/langchain-azure-cosmosdb/src/index.ts @@ -1,4 +1,5 @@ export * from "./azure_cosmosdb_mongodb.js"; export * from "./azure_cosmosdb_nosql.js"; export * from "./caches.js"; -export * from "./chat_histories.js"; +export * from "./chat_histories/nosql.js"; +export * from "./chat_histories/mongodb.js"; \ No newline at end of file diff --git a/libs/langchain-azure-cosmosdb/src/tests/chat_histories.int.test.ts b/libs/langchain-azure-cosmosdb/src/tests/chat_histories.int.test.ts index 76da66d7f805..85c57cb64860 100644 --- a/libs/langchain-azure-cosmosdb/src/tests/chat_histories.int.test.ts +++ b/libs/langchain-azure-cosmosdb/src/tests/chat_histories.int.test.ts @@ -6,7 +6,7 @@ import { HumanMessage, AIMessage } from "@langchain/core/messages"; import { CosmosClient } from "@azure/cosmos"; import { DefaultAzureCredential } from "@azure/identity"; import { ObjectId } from "mongodb"; -import { AzureCosmsosDBNoSQLChatMessageHistory } from "../chat_histories.js"; +import { AzureCosmsosDBNoSQLChatMessageHistory } from "../chat_histories/nosql.js"; const DATABASE_NAME = "langchainTestDB"; const CONTAINER_NAME = "testContainer"; diff --git a/libs/langchain-azure-cosmosdb/src/tests/chat_histories_azure_cosmosdb_mongodb.int.test.ts b/libs/langchain-azure-cosmosdb/src/tests/chat_histories_azure_cosmosdb_mongodb.int.test.ts index fd454245b463..e3137ef2857a 100644 --- a/libs/langchain-azure-cosmosdb/src/tests/chat_histories_azure_cosmosdb_mongodb.int.test.ts +++ b/libs/langchain-azure-cosmosdb/src/tests/chat_histories_azure_cosmosdb_mongodb.int.test.ts @@ -2,7 +2,10 @@ import { MongoClient, ObjectId } from "mongodb"; import { AIMessage, HumanMessage } from "@langchain/core/messages"; -import { AzureCosmosDBMongoChatMessageHistory } from "../chat_histories_azure_cosmosdb_mongodb.js"; +import { + AzureCosmosDBMongoChatMessageHistory, + AzureCosmosDBMongoChatHistoryDBConfig, +} from "../chat_histories/mongodb.js"; afterAll(async () => { // eslint-disable-next-line @typescript-eslint/no-non-null-assertion @@ -18,17 +21,21 @@ test("Test Azure Cosmos MongoDB history store", async () => { expect(process.env.AZURE_COSMOSDB_MONGODB_CONNECTION_STRING).toBeDefined(); // eslint-disable-next-line @typescript-eslint/no-non-null-assertion - const client = new MongoClient( + const mongoClient = new MongoClient( process.env.AZURE_COSMOSDB_MONGODB_CONNECTION_STRING! ); - await client.connect(); - const collection = client.db("langchain").collection("memory"); + const dbcfg: AzureCosmosDBMongoChatHistoryDBConfig = { + client: mongoClient, + connectionString: process.env.AZURE_COSMOSDB_MONGODB_CONNECTION_STRING, + databaseName: "langchain", + collectionName: "chathistory", + }; const sessionId = new ObjectId().toString(); - const chatHistory = new AzureCosmosDBMongoChatMessageHistory({ - collection, - sessionId, - }); + const chatHistory = new AzureCosmosDBMongoChatMessageHistory( + dbcfg, + sessionId + ); const blankResult = await chatHistory.getMessages(); expect(blankResult).toStrictEqual([]); @@ -45,24 +52,28 @@ test("Test Azure Cosmos MongoDB history store", async () => { console.log(resultWithHistory); expect(resultWithHistory).toEqual(expectedMessages); - await client.close(); + await mongoClient.close(); }); test("Test clear Azure Cosmos MongoDB history store", async () => { expect(process.env.AZURE_COSMOSDB_MONGODB_CONNECTION_STRING).toBeDefined(); // eslint-disable-next-line @typescript-eslint/no-non-null-assertion - const client = new MongoClient( + const mongoClient = new MongoClient( process.env.AZURE_COSMOSDB_MONGODB_CONNECTION_STRING! ); - await client.connect(); - const collection = client.db("langchain").collection("memory"); + const dbcfg: AzureCosmosDBMongoChatHistoryDBConfig = { + client: mongoClient, + connectionString: process.env.AZURE_COSMOSDB_MONGODB_CONNECTION_STRING, + databaseName: "langchain", + collectionName: "chathistory", + }; const sessionId = new ObjectId().toString(); - const chatHistory = new AzureCosmosDBMongoChatMessageHistory({ - collection, - sessionId, - }); + const chatHistory = new AzureCosmosDBMongoChatMessageHistory( + dbcfg, + sessionId + ); await chatHistory.addUserMessage("Who is the best vocalist?"); await chatHistory.addAIChatMessage("Ozzy Osbourne"); @@ -80,5 +91,5 @@ test("Test clear Azure Cosmos MongoDB history store", async () => { const blankResult = await chatHistory.getMessages(); expect(blankResult).toStrictEqual([]); - await client.close(); + await mongoClient.close(); });