Skip to content

Commit

Permalink
feat(azure-cosmosdb): add session context and retrieve all sessions f…
Browse files Browse the repository at this point in the history
…or a user (#7242)

Co-authored-by: Jacob Lee <[email protected]>
  • Loading branch information
sinedied and jacoblee93 authored Dec 3, 2024
1 parent 8eadded commit e8822ad
Show file tree
Hide file tree
Showing 7 changed files with 236 additions and 65 deletions.
22 changes: 19 additions & 3 deletions examples/src/memory/azure_cosmosdb_nosql.ts
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,31 @@ const res1 = await chainWithHistory.invoke(
);
console.log({ res1 });
/*
{ res1: 'Hi Jim! How can I assist you today?' }
{ res1: 'Hi Jim! How can I assist you today?' }
*/

const res2 = await chainWithHistory.invoke(
{ input: "What did I just say my name was?" },
{ configurable: { sessionId: "langchain-test-session" } }
);
console.log({ res2 });

/*
{ res2: { response: 'You said your name was Jim.' }
*/
*/

// Give this session a title
const chatHistory = (await chainWithHistory.getMessageHistory(
"langchain-test-session"
)) as AzureCosmsosDBNoSQLChatMessageHistory;

await chatHistory.setContext({ title: "Introducing Jim" });

// List all session for the user
const sessions = await chatHistory.getAllSessions();

console.log(sessions);
/*
[
{ sessionId: 'langchain-test-session', context: { title: "Introducing Jim" } }
]
*/
6 changes: 3 additions & 3 deletions libs/langchain-azure-cosmosdb/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@
"author": "LangChain",
"license": "MIT",
"dependencies": {
"@azure/cosmos": "4.0.1-beta.3",
"@azure/identity": "^4.2.0",
"mongodb": "^6.8.0"
"@azure/cosmos": "^4.2.0",
"@azure/identity": "^4.5.0",
"mongodb": "^6.10.0"
},
"peerDependencies": {
"@langchain/core": ">=0.2.21 <0.4.0"
Expand Down
6 changes: 4 additions & 2 deletions libs/langchain-azure-cosmosdb/src/azure_cosmosdb_nosql.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@ import {
IndexingPolicy,
SqlParameter,
SqlQuerySpec,
VectorEmbedding,
VectorEmbeddingPolicy,
VectorIndex,
} from "@azure/cosmos";
import { DefaultAzureCredential, TokenCredential } from "@azure/identity";

Expand Down Expand Up @@ -186,7 +188,7 @@ export class AzureCosmosDBNoSQLVectorStore extends VectorStore {
distanceFunction: "cosine",
// Will be determined automatically during initialization
dimensions: 0,
},
} as VectorEmbedding,
];
}

Expand All @@ -195,7 +197,7 @@ export class AzureCosmosDBNoSQLVectorStore extends VectorStore {
{
path: "/vector",
type: "quantizedFlat",
},
} as VectorIndex,
];
}

Expand Down
62 changes: 57 additions & 5 deletions libs/langchain-azure-cosmosdb/src/chat_histories.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
import { Container, CosmosClient, CosmosClientOptions } from "@azure/cosmos";
import {
Container,
CosmosClient,
CosmosClientOptions,
ErrorResponse,
} from "@azure/cosmos";
import { DefaultAzureCredential, TokenCredential } from "@azure/identity";
import { BaseListChatMessageHistory } from "@langchain/core/chat_history";
import {
Expand All @@ -12,6 +17,14 @@ const USER_AGENT_SUFFIX = "langchainjs-cdbnosql-chathistory-javascript";
const DEFAULT_DATABASE_NAME = "chatHistoryDB";
const DEFAULT_CONTAINER_NAME = "chatHistoryContainer";

/**
* Lightweight type for listing chat sessions.
*/
export type ChatSession = {
id: string;
context: Record<string, unknown>;
};

/**
* Type for the input to the `AzureCosmosDBNoSQLChatMessageHistory` constructor.
*/
Expand Down Expand Up @@ -68,7 +81,6 @@ export interface AzureCosmosDBNoSQLChatMessageHistoryInput {
* );
* ```
*/

export class AzureCosmsosDBNoSQLChatMessageHistory extends BaseListChatMessageHistory {
lc_namespace = ["langchain", "stores", "message", "azurecosmosdb"];

Expand All @@ -90,6 +102,8 @@ export class AzureCosmsosDBNoSQLChatMessageHistory extends BaseListChatMessageHi

private initPromise?: Promise<void>;

private context: Record<string, unknown> = {};

constructor(chatHistoryInput: AzureCosmosDBNoSQLChatMessageHistoryInput) {
super();

Expand Down Expand Up @@ -175,9 +189,11 @@ export class AzureCosmsosDBNoSQLChatMessageHistory extends BaseListChatMessageHi
this.messageList = await this.getMessages();
this.messageList.push(message);
const messages = mapChatMessagesToStoredMessages(this.messageList);
const context = await this.getContext();
await this.container.items.upsert({
id: this.sessionId,
userId: this.userId,
context,
messages,
});
}
Expand All @@ -188,17 +204,53 @@ export class AzureCosmsosDBNoSQLChatMessageHistory extends BaseListChatMessageHi
await this.container.item(this.sessionId, this.userId).delete();
}

async clearAllSessionsForUser(userId: string) {
async clearAllSessions() {
await this.initializeContainer();
const query = {
query: "SELECT c.id FROM c WHERE c.userId = @userId",
parameters: [{ name: "@userId", value: userId }],
parameters: [{ name: "@userId", value: this.userId }],
};
const { resources: userSessions } = await this.container.items
.query(query)
.fetchAll();
for (const userSession of userSessions) {
await this.container.item(userSession.id, userId).delete();
await this.container.item(userSession.id, this.userId).delete();
}
}

async getAllSessions(): Promise<ChatSession[]> {
await this.initializeContainer();
const query = {
query: "SELECT c.id, c.context FROM c WHERE c.userId = @userId",
parameters: [{ name: "@userId", value: this.userId }],
};
const { resources: userSessions } = await this.container.items
.query(query)
.fetchAll();
return userSessions ?? [];
}

async getContext(): Promise<Record<string, unknown>> {
const document = await this.container
.item(this.sessionId, this.userId)
.read();
this.context = document.resource?.context || this.context;
return this.context;
}

async setContext(context: Record<string, unknown>): Promise<void> {
await this.initializeContainer();
this.context = context || {};
try {
await this.container
.item(this.sessionId, this.userId)
.patch([{ op: "replace", path: "/context", value: this.context }]);
} catch (_error: unknown) {
const error = _error as ErrorResponse;
// If document does not exist yet, context will be set when adding the first message
if (error?.code !== 404) {
throw error;
}
}
}
}
3 changes: 2 additions & 1 deletion libs/langchain-azure-cosmosdb/src/tests/caches.int.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import {
CosmosClient,
IndexingMode,
VectorEmbedding,
VectorEmbeddingPolicy,
} from "@azure/cosmos";
import { DefaultAzureCredential } from "@azure/identity";
Expand Down Expand Up @@ -33,7 +34,7 @@ function vectorEmbeddingPolicy(
dataType: "float32",
distanceFunction,
dimensions: dimension,
},
} as VectorEmbedding,
],
};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,10 +159,41 @@ test("Test clear all sessions for a user", async () => {
const result2 = await chatHistory1.getMessages();
expect(result2).toEqual(expectedMessages);

await chatHistory1.clearAllSessionsForUser("user1");
await chatHistory1.clearAllSessions();

const deletedResult1 = await chatHistory1.getMessages();
const deletedResult2 = await chatHistory2.getMessages();
expect(deletedResult1).toStrictEqual([]);
expect(deletedResult2).toStrictEqual([]);
});

test("Test set context and get all sessions for a user", async () => {
const session1 = {
userId: "user1",
databaseName: DATABASE_NAME,
containerName: CONTAINER_NAME,
sessionId: new ObjectId().toString(),
};
const context1 = { title: "Best vocalist" };
const chatHistory1 = new AzureCosmsosDBNoSQLChatMessageHistory(session1);

await chatHistory1.setContext(context1);
await chatHistory1.addUserMessage("Who is the best vocalist?");
await chatHistory1.addAIMessage("Ozzy Osbourne");

const chatHistory2 = new AzureCosmsosDBNoSQLChatMessageHistory({
...session1,
sessionId: new ObjectId().toString(),
});
const context2 = { title: "Best guitarist" };

await chatHistory2.addUserMessage("Who is the best guitarist?");
await chatHistory2.addAIMessage("Jimi Hendrix");
await chatHistory2.setContext(context2);

const sessions = await chatHistory1.getAllSessions();

expect(sessions.length).toBe(2);
expect(sessions[0].context).toEqual(context1);
expect(sessions[1].context).toEqual(context2);
});
Loading

0 comments on commit e8822ad

Please sign in to comment.