Skip to content

Commit

Permalink
Respect order, number to skip and number to take of chat messages at …
Browse files Browse the repository at this point in the history
…underlying DB level rather than higher (microsoft#902)

### Motivation and Context
As described in microsoft#718, we load ALL the messages from ALL of the chats of
the user in the frontend at log in.

Also, no matter how many messages we want, we read them ALL at the DB
level and then throw some away if the number read is too high at the
service level.

### Description
Now, with this change, we actually respect the order, the number to skip
and the number to take of chat messages at the underlying DB level
rather than at higher service level.

This enables us to save some DB activity and make our queries from the
frontend eventually a lot tighter.

This change was made in a simple manner which doesn't change the
underlying architecture.

### Contribution Checklist
- [ ] The code builds clean without any errors or warnings
- [ ] The PR follows the [Contribution
Guidelines](https://github.com/microsoft/chat-copilot/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/chat-copilot/blob/main/CONTRIBUTING.md#development-scripts)
raises no violations
- [ ] All unit tests pass, and I have added new tests where possible
- [ ] I didn't break anyone 😄
  • Loading branch information
glahaye authored Mar 27, 2024
1 parent d81e8b0 commit 5ea4120
Show file tree
Hide file tree
Showing 11 changed files with 144 additions and 30 deletions.
3 changes: 1 addition & 2 deletions webapi/Controllers/ChatArchiveController.cs
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,6 @@ private async Task<List<Citation>> GetMemoryRecordsAndAppendToEmbeddingsAsync(
/// <returns>The list of chat messages in descending order of the timestamp</returns>
private async Task<List<CopilotChatMessage>> GetAllChatMessagesAsync(string chatId)
{
return (await this._chatMessageRepository.FindByChatIdAsync(chatId))
.OrderByDescending(m => m.Timestamp).ToList();
return (await this._chatMessageRepository.FindByChatIdAsync(chatId)).ToList();
}
}
16 changes: 6 additions & 10 deletions webapi/Controllers/ChatHistoryController.cs
Original file line number Diff line number Diff line change
Expand Up @@ -169,12 +169,12 @@ public async Task<IActionResult> GetAllChatSessionsAsync()
}

/// <summary>
/// Get all chat messages for a chat session.
/// The list will be ordered with the first entry being the most recent message.
/// Get chat messages for a chat session.
/// Messages are returned ordered from most recent to oldest.
/// </summary>
/// <param name="chatId">The chat id.</param>
/// <param name="startIdx">The start index at which the first message will be returned.</param>
/// <param name="count">The number of messages to return. -1 will return all messages starting from startIdx.</param>
/// <param name="skip">Number of messages to skip before starting to return messages.</param>
/// <param name="count">The number of messages to return. -1 returns all messages.</param>
[HttpGet]
[Route("chats/{chatId:guid}/messages")]
[ProducesResponseType(StatusCodes.Status200OK)]
Expand All @@ -183,19 +183,15 @@ public async Task<IActionResult> GetAllChatSessionsAsync()
[Authorize(Policy = AuthPolicyName.RequireChatParticipant)]
public async Task<IActionResult> GetChatMessagesAsync(
[FromRoute] Guid chatId,
[FromQuery] int startIdx = 0,
[FromQuery] int skip = 0,
[FromQuery] int count = -1)
{
// TODO: [Issue #48] the code mixes strings and Guid without being explicit about the serialization format
var chatMessages = await this._messageRepository.FindByChatIdAsync(chatId.ToString());
var chatMessages = await this._messageRepository.FindByChatIdAsync(chatId.ToString(), skip, count);
if (!chatMessages.Any())
{
return this.NotFound($"No messages found for chat id '{chatId}'.");
}

chatMessages = chatMessages.OrderByDescending(m => m.Timestamp).Skip(startIdx);
if (count >= 0) { chatMessages = chatMessages.Take(count); }

return this.Ok(chatMessages);
}

Expand Down
8 changes: 4 additions & 4 deletions webapi/Extensions/ServiceExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ internal static IServiceCollection AddCorsPolicy(this IServiceCollection service
public static IServiceCollection AddPersistentChatStore(this IServiceCollection services)
{
IStorageContext<ChatSession> chatSessionStorageContext;
IStorageContext<CopilotChatMessage> chatMessageStorageContext;
ICopilotChatMessageStorageContext chatMessageStorageContext;
IStorageContext<MemorySource> chatMemorySourceStorageContext;
IStorageContext<ChatParticipant> chatParticipantStorageContext;

Expand All @@ -175,7 +175,7 @@ public static IServiceCollection AddPersistentChatStore(this IServiceCollection
case ChatStoreOptions.ChatStoreType.Volatile:
{
chatSessionStorageContext = new VolatileContext<ChatSession>();
chatMessageStorageContext = new VolatileContext<CopilotChatMessage>();
chatMessageStorageContext = new VolatileCopilotChatMessageContext();
chatMemorySourceStorageContext = new VolatileContext<MemorySource>();
chatParticipantStorageContext = new VolatileContext<ChatParticipant>();
break;
Expand All @@ -192,7 +192,7 @@ public static IServiceCollection AddPersistentChatStore(this IServiceCollection
string directory = Path.GetDirectoryName(fullPath) ?? string.Empty;
chatSessionStorageContext = new FileSystemContext<ChatSession>(
new FileInfo(Path.Combine(directory, $"{Path.GetFileNameWithoutExtension(fullPath)}_sessions{Path.GetExtension(fullPath)}")));
chatMessageStorageContext = new FileSystemContext<CopilotChatMessage>(
chatMessageStorageContext = new FileSystemCopilotChatMessageContext(
new FileInfo(Path.Combine(directory, $"{Path.GetFileNameWithoutExtension(fullPath)}_messages{Path.GetExtension(fullPath)}")));
chatMemorySourceStorageContext = new FileSystemContext<MemorySource>(
new FileInfo(Path.Combine(directory, $"{Path.GetFileNameWithoutExtension(fullPath)}_memorysources{Path.GetExtension(fullPath)}")));
Expand All @@ -210,7 +210,7 @@ public static IServiceCollection AddPersistentChatStore(this IServiceCollection
#pragma warning disable CA2000 // Dispose objects before losing scope - objects are singletons for the duration of the process and disposed when the process exits.
chatSessionStorageContext = new CosmosDbContext<ChatSession>(
chatStoreConfig.Cosmos.ConnectionString, chatStoreConfig.Cosmos.Database, chatStoreConfig.Cosmos.ChatSessionsContainer);
chatMessageStorageContext = new CosmosDbContext<CopilotChatMessage>(
chatMessageStorageContext = new CosmosDbCopilotChatMessageContext(
chatStoreConfig.Cosmos.ConnectionString, chatStoreConfig.Cosmos.Database, chatStoreConfig.Cosmos.ChatMessagesContainer);
chatMemorySourceStorageContext = new CosmosDbContext<MemorySource>(
chatStoreConfig.Cosmos.ConnectionString, chatStoreConfig.Cosmos.Database, chatStoreConfig.Cosmos.ChatMemorySourcesContainer);
Expand Down
3 changes: 1 addition & 2 deletions webapi/Plugins/Chat/ChatPlugin.cs
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,7 @@ private async Task<string> GetAllowedChatHistoryAsync(
ChatHistory? chatHistory = null,
CancellationToken cancellationToken = default)
{
var messages = await this._chatMessageRepository.FindByChatIdAsync(chatId);
var sortedMessages = messages.OrderByDescending(m => m.Timestamp);
var sortedMessages = await this._chatMessageRepository.FindByChatIdAsync(chatId, 0, 100);

ChatHistory allottedChatHistory = new();
var remainingToken = tokenLimit;
Expand Down
14 changes: 8 additions & 6 deletions webapi/Storage/ChatMessageRepository.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@ namespace CopilotChat.WebApi.Storage;
/// <summary>
/// A repository for chat messages.
/// </summary>
public class ChatMessageRepository : Repository<CopilotChatMessage>
public class ChatMessageRepository : CopilotChatMessageRepository
{
/// <summary>
/// Initializes a new instance of the ChatMessageRepository class.
/// </summary>
/// <param name="storageContext">The storage context.</param>
public ChatMessageRepository(IStorageContext<CopilotChatMessage> storageContext)
public ChatMessageRepository(ICopilotChatMessageStorageContext storageContext)
: base(storageContext)
{
}
Expand All @@ -25,10 +25,12 @@ public ChatMessageRepository(IStorageContext<CopilotChatMessage> storageContext)
/// Finds chat messages by chat id.
/// </summary>
/// <param name="chatId">The chat id.</param>
/// <returns>A list of ChatMessages matching the given chatId.</returns>
public Task<IEnumerable<CopilotChatMessage>> FindByChatIdAsync(string chatId)
/// <param name="skip">Number of messages to skip before starting to return messages.</param>
/// <param name="count">The number of messages to return. -1 returns all messages.</param>
/// <returns>A list of ChatMessages matching the given chatId sorted from most recent to oldest.</returns>
public Task<IEnumerable<CopilotChatMessage>> FindByChatIdAsync(string chatId, int skip = 0, int count = -1)
{
return base.StorageContext.QueryEntitiesAsync(e => e.ChatId == chatId);
return base.QueryEntitiesAsync(e => e.ChatId == chatId, skip, count);
}

/// <summary>
Expand All @@ -38,7 +40,7 @@ public Task<IEnumerable<CopilotChatMessage>> FindByChatIdAsync(string chatId)
/// <returns>The most recent ChatMessage matching the given chatId.</returns>
public async Task<CopilotChatMessage> FindLastByChatIdAsync(string chatId)
{
var chatMessages = await this.FindByChatIdAsync(chatId);
var chatMessages = await this.FindByChatIdAsync(chatId, 0, 1);
var first = chatMessages.MaxBy(e => e.Timestamp);
return first ?? throw new KeyNotFoundException($"No messages found for chat '{chatId}'.");
}
Expand Down
30 changes: 29 additions & 1 deletion webapi/Storage/CosmosDbContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System.Linq;
using System.Net;
using System.Threading.Tasks;
using CopilotChat.WebApi.Models.Storage;
using Microsoft.Azure.Cosmos;

namespace CopilotChat.WebApi.Storage;
Expand All @@ -22,7 +23,9 @@ public class CosmosDbContext<T> : IStorageContext<T>, IDisposable where T : ISto
/// <summary>
/// CosmosDB container.
/// </summary>
private readonly Container _container;
#pragma warning disable CA1051 // Do not declare visible instance fields
protected readonly Container _container;
#pragma warning restore CA1051 // Do not declare visible instance fields

/// <summary>
/// Initializes a new instance of the CosmosDbContext class.
Expand Down Expand Up @@ -117,3 +120,28 @@ protected virtual void Dispose(bool disposing)
}
}
}

/// <summary>
/// Specialization of <see cref="CosmosDbContext{T}"/> for <see cref="CopilotChatMessage"/>
/// that adds ordered, paged querying of chat messages.
/// </summary>
public class CosmosDbCopilotChatMessageContext : CosmosDbContext<CopilotChatMessage>, ICopilotChatMessageStorageContext
{
    /// <summary>
    /// Initializes a new instance of the CosmosDbCopilotChatMessageContext class.
    /// </summary>
    /// <param name="connectionString">The CosmosDB connection string.</param>
    /// <param name="database">The CosmosDB database name.</param>
    /// <param name="container">The CosmosDB container name.</param>
    public CosmosDbCopilotChatMessageContext(string connectionString, string database, string container) :
        base(connectionString, database, container)
    {
    }

    /// <summary>
    /// Query chat messages, ordered from most recent to oldest, with paging.
    /// </summary>
    /// <param name="predicate">Predicate that needs to evaluate to true for a particular entry to be returned.</param>
    /// <param name="skip">Number of messages to skip before starting to return messages.</param>
    /// <param name="count">The number of messages to return. A negative value (e.g. -1) returns all messages.</param>
    /// <returns>The matching messages sorted from most recent to oldest.</returns>
    public Task<IEnumerable<CopilotChatMessage>> QueryEntitiesAsync(Func<CopilotChatMessage, bool> predicate, int skip, int count)
    {
        return Task.Run<IEnumerable<CopilotChatMessage>>(() =>
        {
            // NOTE(review): because the predicate is a Func<> (not an Expression<Func<>>), Where binds to
            // Enumerable.Where, so filtering, ordering and paging run client-side as LINQ-to-Objects.
            // Pushing them into the Cosmos query would need an expression-based predicate on the interface.
            IEnumerable<CopilotChatMessage> messages = this._container.GetItemLinqQueryable<CopilotChatMessage>(true)
                .Where(predicate)
                .OrderByDescending(m => m.Timestamp)
                .Skip(skip);

            // Enumerable.Take returns an empty sequence for a non-positive count, so the
            // "return everything" sentinel (-1) must bypass Take entirely.
            return count < 0 ? messages : messages.Take(count);
        });
    }
}
32 changes: 30 additions & 2 deletions webapi/Storage/FileSystemContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using System.Linq;
using System.Text.Json;
using System.Threading.Tasks;
using CopilotChat.WebApi.Models.Storage;

namespace CopilotChat.WebApi.Storage;

Expand Down Expand Up @@ -99,14 +100,16 @@ public Task UpsertAsync(T entity)
/// <summary>
/// A concurrent dictionary to store entities in memory.
/// </summary>
private sealed class EntityDictionary : ConcurrentDictionary<string, T>
protected sealed class EntityDictionary : ConcurrentDictionary<string, T>
{
}

/// <summary>
/// Using a concurrent dictionary to store entities in memory.
/// </summary>
private readonly EntityDictionary _entities;
#pragma warning disable CA1051 // Do not declare visible instance fields
protected readonly EntityDictionary _entities;
#pragma warning restore CA1051 // Do not declare visible instance fields

/// <summary>
/// The file path to store entities on disk.
Expand Down Expand Up @@ -164,3 +167,28 @@ private EntityDictionary Load(FileInfo fileInfo)
}
}
}

/// <summary>
/// Specialization of <see cref="FileSystemContext{T}"/> for <see cref="CopilotChatMessage"/>
/// that adds ordered, paged querying of chat messages.
/// </summary>
public class FileSystemCopilotChatMessageContext : FileSystemContext<CopilotChatMessage>, ICopilotChatMessageStorageContext
{
    /// <summary>
    /// Initializes a new instance of the FileSystemCopilotChatMessageContext class.
    /// </summary>
    /// <param name="filePath">The file used to store entities on disk.</param>
    public FileSystemCopilotChatMessageContext(FileInfo filePath) :
        base(filePath)
    {
    }

    /// <summary>
    /// Query chat messages, ordered from most recent to oldest, with paging.
    /// </summary>
    /// <param name="predicate">Predicate that needs to evaluate to true for a particular entry to be returned.</param>
    /// <param name="skip">Number of messages to skip before starting to return messages.</param>
    /// <param name="count">The number of messages to return. A negative value (e.g. -1) returns all messages.</param>
    /// <returns>The matching messages sorted from most recent to oldest.</returns>
    public Task<IEnumerable<CopilotChatMessage>> QueryEntitiesAsync(Func<CopilotChatMessage, bool> predicate, int skip, int count)
    {
        return Task.Run<IEnumerable<CopilotChatMessage>>(() =>
        {
            IEnumerable<CopilotChatMessage> messages = this._entities.Values
                .Where(predicate)
                .OrderByDescending(m => m.Timestamp)
                .Skip(skip);

            // Enumerable.Take returns an empty sequence for a non-positive count, so the
            // "return everything" sentinel (-1) must bypass Take entirely.
            return count < 0 ? messages : messages.Take(count);
        });
    }
}
17 changes: 17 additions & 0 deletions webapi/Storage/IStorageContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
using System;
using System.Collections.Generic;
using System.Threading.Tasks;
using CopilotChat.WebApi.Models.Storage;

namespace CopilotChat.WebApi.Storage;

Expand All @@ -13,6 +14,7 @@ public interface IStorageContext<T> where T : IStorageEntity
{
/// <summary>
/// Query entities in the storage context.
/// </summary>
/// <param name="predicate">Predicate that needs to evaluate to true for a particular entry to be returned.</param>
Task<IEnumerable<T>> QueryEntitiesAsync(Func<T, bool> predicate);

Expand Down Expand Up @@ -42,3 +44,18 @@ public interface IStorageContext<T> where T : IStorageEntity
/// <param name="entity">The entity to be deleted from the context.</param>
Task DeleteAsync(T entity);
}

/// <summary>
/// Specialization of <see cref="IStorageContext{T}"/> for <see cref="CopilotChatMessage"/>
/// that adds a query overload with skip/count paging.
/// </summary>
public interface ICopilotChatMessageStorageContext : IStorageContext<CopilotChatMessage>
{
/// <summary>
/// Query chat messages in the storage context, with paging.
/// </summary>
/// <param name="predicate">Predicate that needs to evaluate to true for a particular entry to be returned.</param>
/// <param name="skip">Number of messages to skip before starting to return messages.</param>
/// <param name="count">The number of messages to return. -1 returns all messages.</param>
/// <returns>The messages matching the predicate, sorted from most recent to oldest.</returns>
Task<IEnumerable<CopilotChatMessage>> QueryEntitiesAsync(Func<CopilotChatMessage, bool> predicate, int skip = 0, int count = -1);
}
28 changes: 28 additions & 0 deletions webapi/Storage/Repository.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
using System;
using System.Collections.Generic;
using System.Threading.Tasks;
using CopilotChat.WebApi.Models.Storage;

namespace CopilotChat.WebApi.Storage;

Expand Down Expand Up @@ -70,3 +71,30 @@ public Task UpsertAsync(T entity)
return this.StorageContext.UpsertAsync(entity);
}
}

/// <summary>
/// Specialization of <see cref="Repository{T}"/> for <see cref="CopilotChatMessage"/>
/// that exposes the storage context's paged query.
/// </summary>
public class CopilotChatMessageRepository : Repository<CopilotChatMessage>
{
    /// <summary>
    /// The message-specific storage context, kept separately so the paged query overload is reachable.
    /// </summary>
    private readonly ICopilotChatMessageStorageContext _messageStorageContext;

    /// <summary>
    /// Initializes a new instance of the CopilotChatMessageRepository class.
    /// </summary>
    /// <param name="storageContext">The chat message storage context.</param>
    public CopilotChatMessageRepository(ICopilotChatMessageStorageContext storageContext)
        : base(storageContext)
    {
        this._messageStorageContext = storageContext;
    }

    /// <summary>
    /// Finds chat messages matching a predicate.
    /// </summary>
    /// <param name="predicate">Predicate that needs to evaluate to true for a particular entry to be returned.</param>
    /// <param name="skip">Number of messages to skip before starting to return messages.</param>
    /// <param name="count">The number of messages to return. -1 returns all messages.</param>
    /// <returns>The messages matching the predicate, sorted from most recent to oldest.</returns>
    public async Task<IEnumerable<CopilotChatMessage>> QueryEntitiesAsync(Func<CopilotChatMessage, bool> predicate, int skip = 0, int count = -1)
    {
        // The storage context call is already asynchronous; await it directly instead of
        // scheduling an extra thread-pool work item with Task.Run.
        return await this._messageStorageContext.QueryEntitiesAsync(predicate, skip, count);
    }
}
19 changes: 18 additions & 1 deletion webapi/Storage/VolatileContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using System.Diagnostics;
using System.Linq;
using System.Threading.Tasks;
using CopilotChat.WebApi.Models.Storage;

namespace CopilotChat.WebApi.Storage;

Expand All @@ -18,7 +19,9 @@ public class VolatileContext<T> : IStorageContext<T> where T : IStorageEntity
/// <summary>
/// Using a concurrent dictionary to store entities in memory.
/// </summary>
private readonly ConcurrentDictionary<string, T> _entities;
#pragma warning disable CA1051 // Do not declare visible instance fields
protected readonly ConcurrentDictionary<string, T> _entities;
#pragma warning restore CA1051 // Do not declare visible instance fields

/// <summary>
/// Initializes a new instance of the InMemoryContext class.
Expand Down Expand Up @@ -94,3 +97,17 @@ private string GetDebuggerDisplay()
return this.ToString() ?? string.Empty;
}
}

/// <summary>
/// Specialization of <see cref="VolatileContext{T}"/> for <see cref="CopilotChatMessage"/>
/// that adds ordered, paged querying of chat messages.
/// </summary>
public class VolatileCopilotChatMessageContext : VolatileContext<CopilotChatMessage>, ICopilotChatMessageStorageContext
{
    /// <summary>
    /// Query chat messages, ordered from most recent to oldest, with paging.
    /// </summary>
    /// <param name="predicate">Predicate that needs to evaluate to true for a particular entry to be returned.</param>
    /// <param name="skip">Number of messages to skip before starting to return messages.</param>
    /// <param name="count">The number of messages to return. A negative value (e.g. -1) returns all messages.</param>
    /// <returns>The matching messages sorted from most recent to oldest.</returns>
    public Task<IEnumerable<CopilotChatMessage>> QueryEntitiesAsync(Func<CopilotChatMessage, bool> predicate, int skip, int count)
    {
        return Task.Run<IEnumerable<CopilotChatMessage>>(() =>
        {
            IEnumerable<CopilotChatMessage> messages = this._entities.Values
                .Where(predicate)
                .OrderByDescending(m => m.Timestamp)
                .Skip(skip);

            // Enumerable.Take returns an empty sequence for a non-positive count, so the
            // "return everything" sentinel (-1) must bypass Take entirely.
            return count < 0 ? messages : messages.Take(count);
        });
    }
}
4 changes: 2 additions & 2 deletions webapp/src/libs/services/ChatService.ts
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,13 @@ export class ChatService extends BaseService {

public getChatMessagesAsync = async (
chatId: string,
startIdx: number,
skip: number,
count: number,
accessToken: string,
): Promise<IChatMessage[]> => {
const result = await this.getResponseAsync<IChatMessage[]>(
{
commandPath: `chats/${chatId}/messages?startIdx=${startIdx}&count=${count}`,
commandPath: `chats/${chatId}/messages?skip=${skip}&count=${count}`,
method: 'GET',
},
accessToken,
Expand Down

0 comments on commit 5ea4120

Please sign in to comment.