From e8822ad2d75a7701d4ce75c2495c9aa6dd39ab5f Mon Sep 17 00:00:00 2001 From: Yohan Lasorsa Date: Tue, 3 Dec 2024 23:41:19 +0100 Subject: [PATCH] feat(azure-cosmosdb): add session context and retrieve all sessions for a user (#7242) Co-authored-by: Jacob Lee --- examples/src/memory/azure_cosmosdb_nosql.ts | 22 ++- libs/langchain-azure-cosmosdb/package.json | 6 +- .../src/azure_cosmosdb_nosql.ts | 6 +- .../src/chat_histories.ts | 62 ++++++- .../src/tests/caches.int.test.ts | 3 +- .../src/tests/chat_histories.int.test.ts | 33 +++- yarn.lock | 169 ++++++++++++------ 7 files changed, 236 insertions(+), 65 deletions(-) diff --git a/examples/src/memory/azure_cosmosdb_nosql.ts b/examples/src/memory/azure_cosmosdb_nosql.ts index 2f3cddf4460f..415a64b91f94 100644 --- a/examples/src/memory/azure_cosmosdb_nosql.ts +++ b/examples/src/memory/azure_cosmosdb_nosql.ts @@ -44,7 +44,7 @@ const res1 = await chainWithHistory.invoke( ); console.log({ res1 }); /* -{ res1: 'Hi Jim! How can I assist you today?' } + { res1: 'Hi Jim! How can I assist you today?' } */ const res2 = await chainWithHistory.invoke( @@ -52,7 +52,23 @@ const res2 = await chainWithHistory.invoke( { configurable: { sessionId: "langchain-test-session" } } ); console.log({ res2 }); - /* { res2: { response: 'You said your name was Jim.' } - */ + */ + +// Give this session a title +const chatHistory = (await chainWithHistory.getMessageHistory( + "langchain-test-session" +)) as AzureCosmsosDBNoSQLChatMessageHistory; + +await chatHistory.setContext({ title: "Introducing Jim" }); + +// List all session for the user +const sessions = await chatHistory.getAllSessions(); + +console.log(sessions); +/* + [ + { sessionId: 'langchain-test-session', context: { title: "Introducing Jim" } } + ] + */ diff --git a/libs/langchain-azure-cosmosdb/package.json b/libs/langchain-azure-cosmosdb/package.json index 7cd8a1cd4101..c43248a3647c 100644 --- a/libs/langchain-azure-cosmosdb/package.json +++ b/libs/langchain-azure-cosmosdb/package.json @@ -32,9 +32,9 @@ "author": "LangChain", "license": "MIT", "dependencies": { - "@azure/cosmos": "4.0.1-beta.3", - "@azure/identity": "^4.2.0", - "mongodb": "^6.8.0" + "@azure/cosmos": "^4.2.0", + "@azure/identity": "^4.5.0", + "mongodb": "^6.10.0" }, "peerDependencies": { "@langchain/core": ">=0.2.21 <0.4.0" diff --git a/libs/langchain-azure-cosmosdb/src/azure_cosmosdb_nosql.ts b/libs/langchain-azure-cosmosdb/src/azure_cosmosdb_nosql.ts index 618d43ab64c9..3e4acb259c77 100644 --- a/libs/langchain-azure-cosmosdb/src/azure_cosmosdb_nosql.ts +++ b/libs/langchain-azure-cosmosdb/src/azure_cosmosdb_nosql.ts @@ -15,7 +15,9 @@ import { IndexingPolicy, SqlParameter, SqlQuerySpec, + VectorEmbedding, VectorEmbeddingPolicy, + VectorIndex, } from "@azure/cosmos"; import { DefaultAzureCredential, TokenCredential } from "@azure/identity"; @@ -186,7 +188,7 @@ export class AzureCosmosDBNoSQLVectorStore extends VectorStore { distanceFunction: "cosine", // Will be determined automatically during initialization dimensions: 0, - }, + } as VectorEmbedding, ]; } @@ -195,7 +197,7 @@ export class AzureCosmosDBNoSQLVectorStore extends VectorStore { { path: "/vector", type: "quantizedFlat", - }, + } as VectorIndex, ]; } diff --git a/libs/langchain-azure-cosmosdb/src/chat_histories.ts b/libs/langchain-azure-cosmosdb/src/chat_histories.ts index 033acc521334..24e98fd3b074 100644 --- a/libs/langchain-azure-cosmosdb/src/chat_histories.ts +++ b/libs/langchain-azure-cosmosdb/src/chat_histories.ts @@ -1,4 +1,9 @@ -import { Container, CosmosClient, CosmosClientOptions } from "@azure/cosmos"; +import { + Container, + CosmosClient, + CosmosClientOptions, + ErrorResponse, +} from "@azure/cosmos"; import { DefaultAzureCredential, TokenCredential } from "@azure/identity"; import { BaseListChatMessageHistory } from "@langchain/core/chat_history"; import { @@ -12,6 +17,14 @@ const USER_AGENT_SUFFIX = "langchainjs-cdbnosql-chathistory-javascript"; const DEFAULT_DATABASE_NAME = "chatHistoryDB"; const DEFAULT_CONTAINER_NAME = "chatHistoryContainer"; +/** + * Lightweight type for listing chat sessions. + */ +export type ChatSession = { + id: string; + context: Record; +}; + /** * Type for the input to the `AzureCosmosDBNoSQLChatMessageHistory` constructor. */ @@ -68,7 +81,6 @@ export interface AzureCosmosDBNoSQLChatMessageHistoryInput { * ); * ``` */ - export class AzureCosmsosDBNoSQLChatMessageHistory extends BaseListChatMessageHistory { lc_namespace = ["langchain", "stores", "message", "azurecosmosdb"]; @@ -90,6 +102,8 @@ export class AzureCosmsosDBNoSQLChatMessageHistory extends BaseListChatMessageHi private initPromise?: Promise; + private context: Record = {}; + constructor(chatHistoryInput: AzureCosmosDBNoSQLChatMessageHistoryInput) { super(); @@ -175,9 +189,11 @@ export class AzureCosmsosDBNoSQLChatMessageHistory extends BaseListChatMessageHi this.messageList = await this.getMessages(); this.messageList.push(message); const messages = mapChatMessagesToStoredMessages(this.messageList); + const context = await this.getContext(); await this.container.items.upsert({ id: this.sessionId, userId: this.userId, + context, messages, }); } @@ -188,17 +204,53 @@ export class AzureCosmsosDBNoSQLChatMessageHistory extends BaseListChatMessageHi await this.container.item(this.sessionId, this.userId).delete(); } - async clearAllSessionsForUser(userId: string) { + async clearAllSessions() { await this.initializeContainer(); const query = { query: "SELECT c.id FROM c WHERE c.userId = @userId", - parameters: [{ name: "@userId", value: userId }], + parameters: [{ name: "@userId", value: this.userId }], }; const { resources: userSessions } = await this.container.items .query(query) .fetchAll(); for (const userSession of userSessions) { - await this.container.item(userSession.id, userId).delete(); + await this.container.item(userSession.id, this.userId).delete(); + } + } + + async getAllSessions(): Promise { + await this.initializeContainer(); + const query = { + query: "SELECT c.id, c.context FROM c WHERE c.userId = @userId", + parameters: [{ name: "@userId", value: this.userId }], + }; + const { resources: userSessions } = await this.container.items + .query(query) + .fetchAll(); + return userSessions ?? []; + } + + async getContext(): Promise> { + const document = await this.container + .item(this.sessionId, this.userId) + .read(); + this.context = document.resource?.context || this.context; + return this.context; + } + + async setContext(context: Record): Promise { + await this.initializeContainer(); + this.context = context || {}; + try { + await this.container + .item(this.sessionId, this.userId) + .patch([{ op: "replace", path: "/context", value: this.context }]); + } catch (_error: unknown) { + const error = _error as ErrorResponse; + // If document does not exist yet, context will be set when adding the first message + if (error?.code !== 404) { + throw error; + } } } } diff --git a/libs/langchain-azure-cosmosdb/src/tests/caches.int.test.ts b/libs/langchain-azure-cosmosdb/src/tests/caches.int.test.ts index d6b66ddaac05..c7acb92f7c86 100644 --- a/libs/langchain-azure-cosmosdb/src/tests/caches.int.test.ts +++ b/libs/langchain-azure-cosmosdb/src/tests/caches.int.test.ts @@ -4,6 +4,7 @@ import { CosmosClient, IndexingMode, + VectorEmbedding, VectorEmbeddingPolicy, } from "@azure/cosmos"; import { DefaultAzureCredential } from "@azure/identity"; @@ -33,7 +34,7 @@ function vectorEmbeddingPolicy( dataType: "float32", distanceFunction, dimensions: dimension, - }, + } as VectorEmbedding, ], }; } diff --git a/libs/langchain-azure-cosmosdb/src/tests/chat_histories.int.test.ts b/libs/langchain-azure-cosmosdb/src/tests/chat_histories.int.test.ts index 81f2070ceb81..76da66d7f805 100644 --- a/libs/langchain-azure-cosmosdb/src/tests/chat_histories.int.test.ts +++ b/libs/langchain-azure-cosmosdb/src/tests/chat_histories.int.test.ts @@ -159,10 +159,41 @@ test("Test clear all sessions for a user", async () => { const result2 = await chatHistory1.getMessages(); expect(result2).toEqual(expectedMessages); - await chatHistory1.clearAllSessionsForUser("user1"); + await chatHistory1.clearAllSessions(); const deletedResult1 = await chatHistory1.getMessages(); const deletedResult2 = await chatHistory2.getMessages(); expect(deletedResult1).toStrictEqual([]); expect(deletedResult2).toStrictEqual([]); }); + +test("Test set context and get all sessions for a user", async () => { + const session1 = { + userId: "user1", + databaseName: DATABASE_NAME, + containerName: CONTAINER_NAME, + sessionId: new ObjectId().toString(), + }; + const context1 = { title: "Best vocalist" }; + const chatHistory1 = new AzureCosmsosDBNoSQLChatMessageHistory(session1); + + await chatHistory1.setContext(context1); + await chatHistory1.addUserMessage("Who is the best vocalist?"); + await chatHistory1.addAIMessage("Ozzy Osbourne"); + + const chatHistory2 = new AzureCosmsosDBNoSQLChatMessageHistory({ + ...session1, + sessionId: new ObjectId().toString(), + }); + const context2 = { title: "Best guitarist" }; + + await chatHistory2.addUserMessage("Who is the best guitarist?"); + await chatHistory2.addAIMessage("Jimi Hendrix"); + await chatHistory2.setContext(context2); + + const sessions = await chatHistory1.getAllSessions(); + + expect(sessions.length).toBe(2); + expect(sessions[0].context).toEqual(context1); + expect(sessions[1].context).toEqual(context2); +}); diff --git a/yarn.lock b/yarn.lock index d99b52dd6f78..1a6472c33e59 100644 --- a/yarn.lock +++ b/yarn.lock @@ -6032,6 +6032,17 @@ __metadata: languageName: node linkType: hard +"@azure/core-auth@npm:^1.7.1, @azure/core-auth@npm:^1.8.0, @azure/core-auth@npm:^1.9.0": + version: 1.9.0 + resolution: "@azure/core-auth@npm:1.9.0" + dependencies: + "@azure/abort-controller": ^2.0.0 + "@azure/core-util": ^1.11.0 + tslib: ^2.6.2 + checksum: 4050112188db093c5e01caca0175708c767054c0cea4202430ff43ee42a16430235752ccc0002caea1796c8f01b4f6369c878762bf4c1b2f61af1b7ac13182fc + languageName: node + linkType: hard + "@azure/core-client@npm:^1.3.0": version: 1.7.3 resolution: "@azure/core-client@npm:1.7.3" @@ -6132,19 +6143,19 @@ __metadata: languageName: node linkType: hard -"@azure/core-rest-pipeline@npm:^1.2.0": - version: 1.16.2 - resolution: "@azure/core-rest-pipeline@npm:1.16.2" +"@azure/core-rest-pipeline@npm:^1.15.1, @azure/core-rest-pipeline@npm:^1.17.0": + version: 1.18.0 + resolution: "@azure/core-rest-pipeline@npm:1.18.0" dependencies: "@azure/abort-controller": ^2.0.0 - "@azure/core-auth": ^1.4.0 + "@azure/core-auth": ^1.8.0 "@azure/core-tracing": ^1.0.1 - "@azure/core-util": ^1.9.0 + "@azure/core-util": ^1.11.0 "@azure/logger": ^1.0.0 http-proxy-agent: ^7.0.0 https-proxy-agent: ^7.0.0 tslib: ^2.6.2 - checksum: b30bfdb7c49435c5f7c9493be8cd39d2d7a5bc24de4e7a772336f14f607517866d4bd0c97d15672f6aa2b630d27bd22b3561933cde1c415aa2e63ba6c18289b1 + checksum: 4c8e6572938fd693494ec44477b58afa7c16aed7ea8ef061fcc0cf8a8e602d7ea07676f46b8c850d38e04e5ac4ab10888f88bce8ffac6db1bd3b77bf07a07f29 languageName: node linkType: hard @@ -6176,6 +6187,15 @@ __metadata: languageName: node linkType: hard +"@azure/core-tracing@npm:^1.1.1": + version: 1.2.0 + resolution: "@azure/core-tracing@npm:1.2.0" + dependencies: + tslib: ^2.6.2 + checksum: 202ebf411a3076bd2c48b7a4c1b63335f53be6dd97f7d53500e3191b7ed0fdad25de219f422e777fde824031fd5c67087654de0304a5c0cd67c38cdcab96117c + languageName: node + linkType: hard + "@azure/core-util@npm:^1.0.0, @azure/core-util@npm:^1.1.0, @azure/core-util@npm:^1.3.0, @azure/core-util@npm:^1.4.0": version: 1.6.1 resolution: "@azure/core-util@npm:1.6.1" @@ -6196,7 +6216,17 @@ __metadata: languageName: node linkType: hard -"@azure/core-util@npm:^1.6.1, @azure/core-util@npm:^1.9.0": +"@azure/core-util@npm:^1.11.0, @azure/core-util@npm:^1.8.1": + version: 1.11.0 + resolution: "@azure/core-util@npm:1.11.0" + dependencies: + "@azure/abort-controller": ^2.0.0 + tslib: ^2.6.2 + checksum: 91e3ec329d9eddaa66be5efb1785dad68dcb48dd779fca36e39db041673230510158ff5ca9ccef9f19c3e4d8e9af29f66a367cfc31a7b94d2541f80ef94ec797 + languageName: node + linkType: hard + +"@azure/core-util@npm:^1.6.1": version: 1.9.2 resolution: "@azure/core-util@npm:1.9.2" dependencies: @@ -6206,28 +6236,25 @@ __metadata: languageName: node linkType: hard -"@azure/cosmos@npm:4.0.1-beta.3": - version: 4.0.1-beta.3 - resolution: "@azure/cosmos@npm:4.0.1-beta.3" +"@azure/cosmos@npm:^4.2.0": + version: 4.2.0 + resolution: "@azure/cosmos@npm:4.2.0" dependencies: - "@azure/abort-controller": ^1.0.0 - "@azure/core-auth": ^1.3.0 - "@azure/core-rest-pipeline": ^1.2.0 - "@azure/core-tracing": ^1.0.0 - debug: ^4.1.1 + "@azure/abort-controller": ^2.0.0 + "@azure/core-auth": ^1.7.1 + "@azure/core-rest-pipeline": ^1.15.1 + "@azure/core-tracing": ^1.1.1 + "@azure/core-util": ^1.8.1 fast-json-stable-stringify: ^2.1.0 - jsbi: ^3.1.3 - node-abort-controller: ^3.0.0 + jsbi: ^4.3.0 priorityqueuejs: ^2.0.0 - semaphore: ^1.0.5 - tslib: ^2.2.0 - universal-user-agent: ^6.0.0 - uuid: ^8.3.0 - checksum: 5223ba77195030898a3aa201f7dbf2c5d99be4f63cefa93c3542c4122d1ad36f3bab22a4113dba961b3c878d7b2b63ee52a269ada35473ebcd2c42c7643ca5a8 + semaphore: ^1.1.0 + tslib: ^2.6.2 + checksum: b571f5a99b12520a2128b8ed0eb61cd66c432e21f533e778cd54a508e89b8bd57e8e05eedc1dcfdb4417c91a675bdb63d6c1cfcd9a21895d444e51de80288f33 languageName: node linkType: hard -"@azure/identity@npm:^4.2.0, @azure/identity@npm:^4.2.1": +"@azure/identity@npm:^4.2.1": version: 4.4.1 resolution: "@azure/identity@npm:4.4.1" dependencies: @@ -6249,6 +6276,28 @@ __metadata: languageName: node linkType: hard +"@azure/identity@npm:^4.5.0": + version: 4.5.0 + resolution: "@azure/identity@npm:4.5.0" + dependencies: + "@azure/abort-controller": ^2.0.0 + "@azure/core-auth": ^1.9.0 + "@azure/core-client": ^1.9.2 + "@azure/core-rest-pipeline": ^1.17.0 + "@azure/core-tracing": ^1.0.0 + "@azure/core-util": ^1.11.0 + "@azure/logger": ^1.0.0 + "@azure/msal-browser": ^3.26.1 + "@azure/msal-node": ^2.15.0 + events: ^3.0.0 + jws: ^4.0.0 + open: ^8.0.0 + stoppable: ^1.1.0 + tslib: ^2.2.0 + checksum: 07d15898f194a220376d8d9c0ee891c93c6da188e44e76810fb781bf3bb7424498a6c1fa5b92c5a4d31f62b7398953f8a5bcf0f0ed57ed72239ce1c4f594b355 + languageName: node + linkType: hard + "@azure/logger@npm:^1.0.0, @azure/logger@npm:^1.0.3": version: 1.0.4 resolution: "@azure/logger@npm:1.0.4" @@ -6267,6 +6316,15 @@ __metadata: languageName: node linkType: hard +"@azure/msal-browser@npm:^3.26.1": + version: 3.27.0 + resolution: "@azure/msal-browser@npm:3.27.0" + dependencies: + "@azure/msal-common": 14.16.0 + checksum: 22c7d087380405f87139a7dfa579b8a49a17d5493e748e1e609f5733bb7549dd5b8558d709f81500f8faa3feebbc2245f8978adc96dc2ce84c54825b37301465 + languageName: node + linkType: hard + "@azure/msal-common@npm:14.14.0": version: 14.14.0 resolution: "@azure/msal-common@npm:14.14.0" @@ -6274,6 +6332,24 @@ __metadata: languageName: node linkType: hard +"@azure/msal-common@npm:14.16.0": + version: 14.16.0 + resolution: "@azure/msal-common@npm:14.16.0" + checksum: 01ec26e22243c5c435b97db085e96f5488733336c142b65a118ee6e523a548d3f17d013147810948cceaee7bdc339362bb9b2799fc9ea53c9d4c9aa10d8987e3 + languageName: node + linkType: hard + +"@azure/msal-node@npm:^2.15.0": + version: 2.16.2 + resolution: "@azure/msal-node@npm:2.16.2" + dependencies: + "@azure/msal-common": 14.16.0 + jsonwebtoken: ^9.0.0 + uuid: ^8.3.0 + checksum: 3676972cf7e1e91ea60773d7054275534239d209989da4c4c1aa790790ba309a2da58d6c593b6465feb1c7028772fce77757227e7ac9631b3a79e4f5a0a81aab + languageName: node + linkType: hard + "@azure/msal-node@npm:^2.9.2": version: 2.12.0 resolution: "@azure/msal-node@npm:2.12.0" @@ -11471,8 +11547,8 @@ __metadata: version: 0.0.0-use.local resolution: "@langchain/azure-cosmosdb@workspace:libs/langchain-azure-cosmosdb" dependencies: - "@azure/cosmos": 4.0.1-beta.3 - "@azure/identity": ^4.2.0 + "@azure/cosmos": ^4.2.0 + "@azure/identity": ^4.5.0 "@jest/globals": ^29.5.0 "@langchain/core": "workspace:*" "@langchain/openai": "workspace:^" @@ -11492,7 +11568,7 @@ __metadata: eslint-plugin-prettier: ^4.2.1 jest: ^29.5.0 jest-environment-node: ^29.6.4 - mongodb: ^6.8.0 + mongodb: ^6.10.0 prettier: ^2.8.3 release-it: ^15.10.1 rollup: ^4.5.2 @@ -32857,10 +32933,10 @@ __metadata: languageName: node linkType: hard -"jsbi@npm:^3.1.3": - version: 3.2.5 - resolution: "jsbi@npm:3.2.5" - checksum: 642d1bb139ad1c1e96c4907eb159565e980a0d168487626b493d0d0b7b341da0e43001089d3b21703fe17b18a7a6c0f42c92026f71d54471ed0a0d1b3015ec0f +"jsbi@npm:^4.3.0": + version: 4.3.0 + resolution: "jsbi@npm:4.3.0" + checksum: 27c4f178eb7fd9d1756144066fdebc62f4a0176e877f55e646e8ce84075c13551bd575a316b9959ccdcca9d5dc05a81c9907cfa09f0cfeb43c9777797e36b0e9 languageName: node linkType: hard @@ -34972,12 +35048,12 @@ __metadata: languageName: node linkType: hard -"mongodb@npm:^6.3.0": - version: 6.3.0 - resolution: "mongodb@npm:6.3.0" +"mongodb@npm:^6.10.0": + version: 6.10.0 + resolution: "mongodb@npm:6.10.0" dependencies: - "@mongodb-js/saslprep": ^1.1.0 - bson: ^6.2.0 + "@mongodb-js/saslprep": ^1.1.5 + bson: ^6.7.0 mongodb-connection-string-url: ^3.0.0 peerDependencies: "@aws-sdk/credential-providers": ^3.188.0 @@ -35002,16 +35078,16 @@ __metadata: optional: true socks: optional: true - checksum: ebc5d9dbd1299321b6873e86eb4ea635316f97450644811db24ce2b01432b1c641def864facf2eab6f0c0c5c360c318108ea5555142f55177ca4c33991c6d7c4 + checksum: b8e7ab9fb84181cb020b5fef5fedd90a5fc12140e688fa12ba588d523a958bb9f8790bfaceeca9f594171794eda0f56be855d7d0588705db82b3de7bf5e2352c languageName: node linkType: hard -"mongodb@npm:^6.8.0": - version: 6.8.0 - resolution: "mongodb@npm:6.8.0" +"mongodb@npm:^6.3.0": + version: 6.3.0 + resolution: "mongodb@npm:6.3.0" dependencies: - "@mongodb-js/saslprep": ^1.1.5 - bson: ^6.7.0 + "@mongodb-js/saslprep": ^1.1.0 + bson: ^6.2.0 mongodb-connection-string-url: ^3.0.0 peerDependencies: "@aws-sdk/credential-providers": ^3.188.0 @@ -35036,7 +35112,7 @@ __metadata: optional: true socks: optional: true - checksum: 5a744e9bf0f21a6f639d935b807ea4c4502f6c38719413e7c6dbed2323786c347a877e905bfd711259f552b21774a5d9d8a9271c97ed1634804f97f10addd440 + checksum: ebc5d9dbd1299321b6873e86eb4ea635316f97450644811db24ce2b01432b1c641def864facf2eab6f0c0c5c360c318108ea5555142f55177ca4c33991c6d7c4 languageName: node linkType: hard @@ -35331,13 +35407,6 @@ __metadata: languageName: node linkType: hard -"node-abort-controller@npm:^3.0.0": - version: 3.1.1 - resolution: "node-abort-controller@npm:3.1.1" - checksum: 2c340916af9710328b11c0828223fc65ba320e0d082214a211311bf64c2891028e42ef276b9799188c4ada9e6e1c54cf7a0b7c05dd9d59fcdc8cd633304c8047 - languageName: node - linkType: hard - "node-addon-api@npm:^3.0.0": version: 3.2.1 resolution: "node-addon-api@npm:3.2.1" @@ -40028,7 +40097,7 @@ __metadata: languageName: node linkType: hard -"semaphore@npm:^1.0.5": +"semaphore@npm:^1.1.0": version: 1.1.0 resolution: "semaphore@npm:1.1.0" checksum: d2445d232ad9959048d4748ef54eb01bc7b60436be2b42fb7de20c4cffacf70eafeeecd3772c1baf408cfdce3805fa6618a4389590335671f18cde54ef3cfae4