Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(azure-cosmosdb): add AzureCosmosDBMongoChatMessageHistory #7305

Merged
merged 6 commits into from
Dec 16, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
157 changes: 157 additions & 0 deletions libs/langchain-azure-cosmosdb/src/chat_histories/mongodb.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
import {
Collection,
Document as AzureCosmosMongoDBDocument,
PushOperator,
Db,
MongoClient,
} from "mongodb";
import { BaseListChatMessageHistory } from "@langchain/core/chat_history";
import {
BaseMessage,
mapChatMessagesToStoredMessages,
mapStoredMessagesToChatMessages,
} from "@langchain/core/messages";
import { getEnvironmentVariable } from "@langchain/core/utils/env";

export interface AzureCosmosDBMongoChatHistoryDBConfig {
readonly client?: MongoClient;
readonly connectionString?: string;
readonly databaseName?: string;
readonly collectionName?: string;
}

export class AzureCosmosDBMongoChatMessageHistory extends BaseListChatMessageHistory {
lc_namespace = ["langchain", "stores", "message", "azurecosmosdb"];

get lc_secrets(): { [key: string]: string } {
return {
connectionString: "AZURE_COSMOSDB_MONGODB_CONNECTION_STRING",
};
}

private initPromise?: Promise<void>;

private readonly client: MongoClient | undefined;

private database: Db;

private collection: Collection<AzureCosmosMongoDBDocument>;

private sessionId: string;

private idKey = "sessionId";
fatmelon marked this conversation as resolved.
Show resolved Hide resolved

initialize: () => Promise<void>;

constructor(
dbConfig: AzureCosmosDBMongoChatHistoryDBConfig,
sessionId: string
) {
super();

const connectionString =
dbConfig.connectionString ??
getEnvironmentVariable("AZURE_COSMOSDB_MONGODB_CONNECTION_STRING");

if (!dbConfig.client && !connectionString) {
throw new Error(
"Mongo client or connection string must be set."
);
}

if (!dbConfig.client) {
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
this.client = new MongoClient(connectionString!, {
appName: "langchainjs",
});
}

// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
const client = dbConfig.client || this.client!;
const databaseName = dbConfig.databaseName ?? "documentsDB";
fatmelon marked this conversation as resolved.
Show resolved Hide resolved
const collectionName = dbConfig.collectionName ?? "documents";
fatmelon marked this conversation as resolved.
Show resolved Hide resolved

this.sessionId = sessionId;

// Deferring initialization to the first call to `initialize`
this.initialize = () => {
if (this.initPromise === undefined) {
this.initPromise = this.init(
client,
databaseName,
collectionName
).catch((error) => {
console.error(
"Error during AzureCosmosDBMongoChatMessageHistory initialization: ",
error
);
});
}

return this.initPromise;
};
}

/**
* Initializes the AzureCosmosDBMongoChatMessageHistory by connecting to the database.
* @param client The MongoClient to use for connecting to the database.
* @param databaseName The name of the database to use.
* @param collectionName The name of the collection to use.
* @returns A promise that resolves when the AzureCosmosDBMongoChatMessageHistory has been initialized.
*/
private async init(
client: MongoClient,
databaseName: string,
collectionName: string
): Promise<void> {
this.initPromise = (async () => {
await client.connect();
this.database = client.db(databaseName);
this.collection = this.database.collection(collectionName);
})();

return this.initPromise;
}

/**
* Retrieves the messages stored in the history.
* @returns A promise that resolves with the messages stored in the history.
*/
async getMessages(): Promise<BaseMessage[]> {
await this.initialize();

const document = await this.collection.findOne({
[this.idKey]: this.sessionId,
});
const messages = document?.messages || [];
return mapStoredMessagesToChatMessages(messages);
}

/**
* Adds a message to the history.
* @param message The message to add to the history.
* @returns A promise that resolves when the message has been added to the history.
*/
async addMessage(message: BaseMessage): Promise<void> {
await this.initialize();

const messages = mapChatMessagesToStoredMessages([message]);
await this.collection.updateOne(
{ [this.idKey]: this.sessionId },
{
$push: { messages: { $each: messages } } as PushOperator<Document>,
},
{ upsert: true }
);
}

/**
* Clear the history.
* @returns A promise that resolves when the history has been cleared.
*/
async clear(): Promise<void> {
await this.initialize();

await this.collection.deleteOne({ [this.idKey]: this.sessionId });
}
}
3 changes: 2 additions & 1 deletion libs/langchain-azure-cosmosdb/src/index.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
export * from "./azure_cosmosdb_mongodb.js";
export * from "./azure_cosmosdb_nosql.js";
export * from "./caches.js";
export * from "./chat_histories.js";
export * from "./chat_histories/nosql.js";
export * from "./chat_histories/mongodb.js";
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import { HumanMessage, AIMessage } from "@langchain/core/messages";
import { CosmosClient } from "@azure/cosmos";
import { DefaultAzureCredential } from "@azure/identity";
import { ObjectId } from "mongodb";
import { AzureCosmsosDBNoSQLChatMessageHistory } from "../chat_histories.js";
import { AzureCosmsosDBNoSQLChatMessageHistory } from "../chat_histories/nosql.js";

const DATABASE_NAME = "langchainTestDB";
const CONTAINER_NAME = "testContainer";
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
/* eslint-disable no-process-env */

import { MongoClient, ObjectId } from "mongodb";
import { AIMessage, HumanMessage } from "@langchain/core/messages";
import {
AzureCosmosDBMongoChatMessageHistory,
AzureCosmosDBMongoChatHistoryDBConfig,
} from "../chat_histories/mongodb.js";

afterAll(async () => {
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
const client = new MongoClient(
process.env.AZURE_COSMOSDB_MONGODB_CONNECTION_STRING!
);
await client.connect();
await client.db("langchain").dropDatabase();
await client.close();
});

test("Test Azure Cosmos MongoDB history store", async () => {
expect(process.env.AZURE_COSMOSDB_MONGODB_CONNECTION_STRING).toBeDefined();

// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
const mongoClient = new MongoClient(
process.env.AZURE_COSMOSDB_MONGODB_CONNECTION_STRING!
);
const dbcfg: AzureCosmosDBMongoChatHistoryDBConfig = {
client: mongoClient,
connectionString: process.env.AZURE_COSMOSDB_MONGODB_CONNECTION_STRING,
databaseName: "langchain",
collectionName: "chathistory",
};

const sessionId = new ObjectId().toString();
const chatHistory = new AzureCosmosDBMongoChatMessageHistory(
dbcfg,
sessionId
);

const blankResult = await chatHistory.getMessages();
expect(blankResult).toStrictEqual([]);

await chatHistory.addUserMessage("Who is the best vocalist?");
await chatHistory.addAIChatMessage("Ozzy Osbourne");

const expectedMessages = [
new HumanMessage("Who is the best vocalist?"),
new AIMessage("Ozzy Osbourne"),
];

const resultWithHistory = await chatHistory.getMessages();
console.log(resultWithHistory);
expect(resultWithHistory).toEqual(expectedMessages);

await mongoClient.close();
});

test("Test clear Azure Cosmos MongoDB history store", async () => {
expect(process.env.AZURE_COSMOSDB_MONGODB_CONNECTION_STRING).toBeDefined();

// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
const mongoClient = new MongoClient(
process.env.AZURE_COSMOSDB_MONGODB_CONNECTION_STRING!
);
const dbcfg: AzureCosmosDBMongoChatHistoryDBConfig = {
client: mongoClient,
connectionString: process.env.AZURE_COSMOSDB_MONGODB_CONNECTION_STRING,
databaseName: "langchain",
collectionName: "chathistory",
};

const sessionId = new ObjectId().toString();
const chatHistory = new AzureCosmosDBMongoChatMessageHistory(
dbcfg,
sessionId
);

await chatHistory.addUserMessage("Who is the best vocalist?");
await chatHistory.addAIChatMessage("Ozzy Osbourne");

const expectedMessages = [
new HumanMessage("Who is the best vocalist?"),
new AIMessage("Ozzy Osbourne"),
];

const resultWithHistory = await chatHistory.getMessages();
expect(resultWithHistory).toEqual(expectedMessages);

await chatHistory.clear();

const blankResult = await chatHistory.getMessages();
expect(blankResult).toStrictEqual([]);

await mongoClient.close();
});
32 changes: 16 additions & 16 deletions yarn.lock
Original file line number Diff line number Diff line change
Expand Up @@ -13491,12 +13491,12 @@ __metadata:
languageName: node
linkType: hard

"@mongodb-js/saslprep@npm:^1.1.5":
version: 1.1.8
resolution: "@mongodb-js/saslprep@npm:1.1.8"
"@mongodb-js/saslprep@npm:^1.1.9":
version: 1.1.9
resolution: "@mongodb-js/saslprep@npm:1.1.9"
dependencies:
sparse-bitfield: ^3.0.3
checksum: 259fda7ec913b5e63f102ae18840ef0d811c4a50919bbc437bd0452980806d640cd06c36076ed1655f1581ef24cd7316be0671f4b7429e7c97c7066524d2dbee
checksum: 6f13983e41c9fbd5273eeae9135e47e5b7a19125a63287bea69e33a618f8e034cfcf2258c77d0f5d6dcf386dfe2bb520bc01613afd1528c52f82c71172629242
languageName: node
linkType: hard

Expand Down Expand Up @@ -22627,20 +22627,20 @@ __metadata:
languageName: node
linkType: hard

"bson@npm:^6.10.0":
version: 6.10.1
resolution: "bson@npm:6.10.1"
checksum: 7c85c8df309bbfd4d42fae54aa37112ee048a89457be908a0e53a01d077d548c94a5a6870dd725ef48130da935286edc8b9ce04830869446db22b8c13a370c42
languageName: node
linkType: hard

"bson@npm:^6.2.0":
version: 6.2.0
resolution: "bson@npm:6.2.0"
checksum: 950fccd2abd0ff5a1bd3637f4697631298f1538314994ab8c9e13f1c9851d0fd042b54fe8340e00151c2acee43917ea40e64b800ceeea811b00f2de3e900c77e
languageName: node
linkType: hard

"bson@npm:^6.7.0":
version: 6.8.0
resolution: "bson@npm:6.8.0"
checksum: 66076b04d7d54e7773d601a19b7c224bc5cff6b008efe102463fbc058879f2c84c0ed793b5b6ed12cc7616bbbe5e670db81cf7352e0ea947918119f8af704ba5
languageName: node
linkType: hard

"buffer-alloc-unsafe@npm:^1.1.0":
version: 1.1.0
resolution: "buffer-alloc-unsafe@npm:1.1.0"
Expand Down Expand Up @@ -35049,11 +35049,11 @@ __metadata:
linkType: hard

"mongodb@npm:^6.10.0":
version: 6.10.0
resolution: "mongodb@npm:6.10.0"
version: 6.11.0
resolution: "mongodb@npm:6.11.0"
dependencies:
"@mongodb-js/saslprep": ^1.1.5
bson: ^6.7.0
"@mongodb-js/saslprep": ^1.1.9
bson: ^6.10.0
mongodb-connection-string-url: ^3.0.0
peerDependencies:
"@aws-sdk/credential-providers": ^3.188.0
Expand All @@ -35078,7 +35078,7 @@ __metadata:
optional: true
socks:
optional: true
checksum: b8e7ab9fb84181cb020b5fef5fedd90a5fc12140e688fa12ba588d523a958bb9f8790bfaceeca9f594171794eda0f56be855d7d0588705db82b3de7bf5e2352c
checksum: cb677bdee565eb9e7cbc27e538d5fafe61312e8ccfc97d4e04fdbf282f03e566537d659394f3236b6f958392707924a7ecae86fcb038a2f8f47b8edafa6edf4d
languageName: node
linkType: hard

Expand Down