Add ability to delete messages + start new chat session #951

Merged
merged 26 commits on Sep 9, 2024
Changes from 19 commits
Commits (26)
67191c1
add ui components
michaelchia Aug 12, 2024
348d857
temp add help message to new chat
michaelchia Aug 12, 2024
b80de74
at to target
michaelchia Aug 13, 2024
45bb813
Merge branch 'main' into delete-msg-button
michaelchia Aug 13, 2024
cf30c7b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 13, 2024
de04bf7
Merge branch 'main' into delete-msg-button
michaelchia Aug 17, 2024
a5b52db
Merge branch 'main' into delete-msg-button
michaelchia Aug 27, 2024
33b6237
broadcast ClearMessage sends help message
michaelchia Aug 17, 2024
d866e52
clear llm_chat_memory
michaelchia Aug 27, 2024
c20d1af
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 27, 2024
d13345e
typo
michaelchia Aug 27, 2024
425cd34
Update chat welcome message
srdas Aug 28, 2024
d6d7273
improve docstring
michaelchia Aug 30, 2024
ada86bf
type typo
michaelchia Aug 30, 2024
341b109
use tooltippedbutton + do not show new chat button on welcome
michaelchia Aug 30, 2024
6b123bb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 30, 2024
483f1c5
Merge branch 'main' into delete-msg-button
michaelchia Aug 30, 2024
4663466
Update Playwright Snapshots
github-actions[bot] Aug 30, 2024
a42f582
fix not adding uncleared pending messages to memory
michaelchia Aug 30, 2024
9bc935f
reimplement to delete specific message exchange
michaelchia Sep 4, 2024
f41afd8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 4, 2024
5b574e4
fix lint
michaelchia Sep 4, 2024
be19f1d
refine docs
michaelchia Sep 6, 2024
7b185a1
keep list of cleared messages
michaelchia Sep 6, 2024
cdf864b
update doc
michaelchia Sep 6, 2024
f4d996e
support clearing all subsequent exchanges
michaelchia Sep 8, 2024
5 changes: 0 additions & 5 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/clear.py
@@ -23,9 +23,4 @@ async def process_message(self, _):
continue

handler.broadcast_message(ClearMessage())
self._chat_history.clear()
self.llm_chat_memory.clear()
break

# re-send help message
self.send_help_message()
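
With this change the handler no longer touches chat state directly: it only broadcasts a ClearMessage, and broadcast_message in handlers.py (below) clears the history, pending messages, and LLM memory and re-sends the help message. A minimal sketch of the resulting method, assuming the loop over the root chat handlers implied by the surrounding context (the _root_chat_handlers name is an assumption, not part of this diff):

    async def process_message(self, _):
        # Find any connected websocket handler and let it broadcast the clear;
        # all state clearing now happens centrally in broadcast_message.
        for handler in self._root_chat_handlers.values():
            if not handler:
                continue
            handler.broadcast_message(ClearMessage())
            break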
42 changes: 41 additions & 1 deletion packages/jupyter-ai/jupyter_ai/handlers.py
@@ -24,6 +24,8 @@
ChatMessage,
ChatRequest,
ChatUser,
ClearMessage,
ClearRequest,
ClosePendingMessage,
ConnectionMessage,
HumanChatMessage,
@@ -40,6 +42,8 @@
from jupyter_ai_magics.embedding_providers import BaseEmbeddingsProvider
from jupyter_ai_magics.providers import BaseProvider

from .history import BoundedChatHistory


class ChatHistoryHandler(BaseAPIHandler):
"""Handler to return message history"""
@@ -98,6 +102,10 @@ def chat_history(self) -> List[ChatMessage]:
def chat_history(self, new_history):
self.settings["chat_history"] = new_history

@property
def llm_chat_memory(self) -> "BoundedChatHistory":
return self.settings["llm_chat_memory"]

@property
def loop(self) -> AbstractEventLoop:
return self.settings["jai_event_loop"]
@@ -246,17 +254,33 @@ def broadcast_message(self, message: Message):
self.pending_messages = list(
filter(lambda m: m.id != message.id, self.pending_messages)
)
elif isinstance(message, ClearMessage):
if message.target:
self._clear_chat_history_at(message.target)
else:
self.chat_history.clear()
self.pending_messages.clear()
self.llm_chat_memory.clear()
self.settings["jai_chat_handlers"]["default"].send_help_message()

async def on_message(self, message):
self.log.debug("Message received: %s", message)

try:
message = json.loads(message)
chat_request = ChatRequest(**message)
if message.get("type") == "clear":
request = ClearRequest(**message)
else:
request = ChatRequest(**message)
except ValidationError as e:
self.log.error(e)
return

if isinstance(request, ClearRequest):
self.broadcast_message(ClearMessage(target=request.target))
return

chat_request = request
message_body = chat_request.prompt
if chat_request.selection:
message_body += f"\n\n```\n{chat_request.selection.source}\n```\n"
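
For reference, the two request shapes that on_message now distinguishes look roughly like this; the field names come from the ClearRequest and ChatRequest models further below, and the message ID is a placeholder:

    # Hypothetical payloads sent by the frontend over the chat WebSocket.
    clear_all  = {"type": "clear"}                          # ClearRequest: clear the whole chat
    clear_from = {"type": "clear", "target": "<msg-id>"}    # ClearRequest: clear that exchange onward
    chat       = {"prompt": "Explain this traceback", "selection": None}  # ChatRequest, unchanged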
@@ -302,6 +326,22 @@ async def _route(self, message):
command_readable = "Default" if command == "default" else command
self.log.info(f"{command_readable} chat handler resolved in {latency_ms} ms.")

def _clear_chat_history_at(self, msg_id: str):
"""Clears the chat history at a specific message ID."""
target_msg = None
for msg in self.chat_history:
if msg.id == msg_id:
target_msg = msg

if target_msg is not None:
self.chat_history[:] = [
msg for msg in self.chat_history if msg.time < target_msg.time
]
self.pending_messages[:] = [
msg for msg in self.pending_messages if msg.time < target_msg.time
]
self.llm_chat_memory.clear(target_msg.time)

def on_close(self):
self.log.debug("Disconnecting client with user %s", self.client_id)

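The targeted clear in _clear_chat_history_at keeps only entries strictly older than the target message. A standalone illustration of that cutoff rule (plain Python, not the handler code):

    # (id, time) pairs standing in for ChatMessage objects.
    history = [("m1", 1.0), ("m2", 2.0), ("m3", 3.0)]
    target_time = 2.0  # timestamp of the message being deleted
    kept = [m for m in history if m[1] < target_time]
    print(kept)  # [('m1', 1.0)] -- the target and everything after it are dropped
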
35 changes: 31 additions & 4 deletions packages/jupyter-ai/jupyter_ai/history.py
@@ -7,6 +7,8 @@

from .models import HumanChatMessage

MESSAGE_TIME_KEY = "_jupyter_ai_msg_time"


class BoundedChatHistory(BaseChatMessageHistory, BaseModel):
"""
@@ -19,6 +21,7 @@ class BoundedChatHistory(BaseChatMessageHistory, BaseModel):

k: int
clear_time: float = 0.0
clear_after: float = 0.0
_all_messages: List[BaseMessage] = PrivateAttr(default_factory=list)

@property
@@ -30,15 +33,33 @@ async def aget_messages(self) -> List[BaseMessage]:

def add_message(self, message: BaseMessage) -> None:
"""Add a self-created message to the store"""
# Adds a timestamp to the message as a fallback if the message was not
# added via WrappedBoundedChatHistory.
# In that case, this message may be cleared even though the targeted clear
# message comes after it: that happens when the current time is greater than
# the last_human_msg time of a later message that was added via
# WrappedBoundedChatHistory.
message.additional_kwargs[MESSAGE_TIME_KEY] = message.additional_kwargs.get(
MESSAGE_TIME_KEY, time.time()
)
self._all_messages.append(message)
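
A short sketch of the fallback path, assuming this module is importable as jupyter_ai.history and that AIMessage comes from langchain_core (both assumptions, not part of this diff):

    import time
    from langchain_core.messages import AIMessage
    from jupyter_ai.history import MESSAGE_TIME_KEY, BoundedChatHistory

    history = BoundedChatHistory(k=2)
    msg = AIMessage(content="added without WrappedBoundedChatHistory")
    history.add_message(msg)
    # The fallback stamp is roughly the wall-clock time at insertion.
    assert abs(msg.additional_kwargs[MESSAGE_TIME_KEY] - time.time()) < 1.0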

async def aadd_messages(self, messages: Sequence[BaseMessage]) -> None:
"""Add messages to the store"""
self.add_messages(messages)

def clear(self) -> None:
self._all_messages = []
def clear(self, after: float = 0.0) -> None:
"""Clear all messages after the given time"""
if after:
self._all_messages = [
m
for m in self._all_messages
if m.additional_kwargs[MESSAGE_TIME_KEY] < after
]
else:
self._all_messages = []
self.clear_time = time.time()
self.clear_after = after

async def aclear(self) -> None:
self.clear()
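
The two clearing modes in practice, under the same import assumptions as the sketch above; timestamps are set explicitly so the cutoff is visible:

    from langchain_core.messages import AIMessage
    from jupyter_ai.history import MESSAGE_TIME_KEY, BoundedChatHistory

    h = BoundedChatHistory(k=2)
    h.add_message(AIMessage(content="old", additional_kwargs={MESSAGE_TIME_KEY: 1.0}))
    h.add_message(AIMessage(content="new", additional_kwargs={MESSAGE_TIME_KEY: 2.0}))

    h.clear(after=2.0)  # partial clear: drops "new" (time >= 2.0), keeps "old"
    assert [m.content for m in h.messages] == ["old"]

    h.clear()           # full clear: drops everything and records clear_time
    assert h.messages == []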
@@ -73,8 +94,14 @@ def messages(self) -> List[BaseMessage]:
return self.history.messages

def add_message(self, message: BaseMessage) -> None:
"""Prevent adding messages to the store if clear was triggered."""
if self.last_human_msg.time > self.history.clear_time:
# Prevent adding pending messages to the store if a clear was triggered.
# If the clear was partial, still allow pending messages that were not cleared.
if (
self.last_human_msg.time > self.history.clear_time
or self.last_human_msg.time < self.history.clear_after
):
message.additional_kwargs[MESSAGE_TIME_KEY] = self.last_human_msg.time
self.history.add_message(message)

async def aadd_messages(self, messages: Sequence[BaseMessage]) -> None:
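Put differently, a pending response is stored only if its originating human message postdates the last full clear, or predates the cut-off point of a partial clear. A standalone illustration of the guard above:

    def should_store(last_human_msg_time: float, clear_time: float, clear_after: float) -> bool:
        # Mirrors the condition in WrappedBoundedChatHistory.add_message.
        return last_human_msg_time > clear_time or last_human_msg_time < clear_after

    print(should_store(10.0, clear_time=5.0, clear_after=0.0))  # True: exchange started after a full clear
    print(should_store(3.0, clear_time=5.0, clear_after=0.0))   # False: superseded by a full clear
    print(should_store(3.0, clear_time=6.0, clear_after=4.0))   # True: earlier than a partial-clear cutoff
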
14 changes: 14 additions & 0 deletions packages/jupyter-ai/jupyter_ai/models.py
@@ -39,6 +39,15 @@ class ChatRequest(BaseModel):
selection: Optional[Selection]


class ClearRequest(BaseModel):
type: Literal["clear"]
target: Optional[str]
"""
Message ID of the ChatMessage to clear, along with all messages after it.
If not provided, this requests the backend to clear all messages.
"""


class ChatUser(BaseModel):
# User ID assigned by IdentityProvider.
username: str
@@ -105,6 +114,11 @@ class HumanChatMessage(BaseModel):

class ClearMessage(BaseModel):
type: Literal["clear"] = "clear"
target: Optional[str] = None
"""
Message ID of the ChatMessage to clear, along with all messages after it.
If not provided, this instructs the frontend to clear all messages.
"""


class PendingMessage(BaseModel):
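A quick sketch of how the new models serialize, assuming the package still exposes Pydantic v1-style .dict() (an assumption; under Pydantic v2 this would be model_dump()):

    from jupyter_ai.models import ClearMessage, ClearRequest

    # Broadcast by the backend to every connected frontend.
    ClearMessage().dict()                    # {'type': 'clear', 'target': None}
    ClearMessage(target="<msg-id>").dict()   # {'type': 'clear', 'target': '<msg-id>'}

    # Sent by the frontend; parsed in on_message above.
    ClearRequest(type="clear", target=None).dict()
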
20 changes: 17 additions & 3 deletions packages/jupyter-ai/src/chat_handler.ts
@@ -39,7 +39,7 @@ export class ChatHandler implements IDisposable {
* Sends a message across the WebSocket. Promise resolves to the message ID
* when the server sends the same message back, acknowledging receipt.
*/
public sendMessage(message: AiService.ChatRequest): Promise<string> {
public sendMessage(message: AiService.Request): Promise<string> {
return new Promise(resolve => {
this._socket?.send(JSON.stringify(message));
this._sendResolverQueue.push(resolve);
@@ -132,8 +132,22 @@
case 'connection':
break;
case 'clear':
this._messages = [];
this._pendingMessages = [];
if (newMessage.target) {
const targetMsg = this._messages.find(
m => m.id === newMessage.target
);
if (targetMsg) {
this._messages = this._messages.filter(
msg => msg.time < targetMsg.time
);
this._pendingMessages = this._pendingMessages.filter(
msg => msg.time < targetMsg.time
);
}
} else {
this._messages = [];
this._pendingMessages = [];
}
break;
case 'pending':
this._pendingMessages = [...this._pendingMessages, newMessage];
13 changes: 13 additions & 0 deletions packages/jupyter-ai/src/components/chat-messages.tsx
@@ -10,16 +10,20 @@ import { AiService } from '../handler';
import { RendermimeMarkdown } from './rendermime-markdown';
import { useCollaboratorsContext } from '../contexts/collaborators-context';
import { ChatMessageMenu } from './chat-messages/chat-message-menu';
import { ChatMessageDelete } from './chat-messages/chat-message-delete';
import { ChatHandler } from '../chat_handler';
import { IJaiMessageFooter } from '../tokens';

type ChatMessagesProps = {
rmRegistry: IRenderMimeRegistry;
messages: AiService.ChatMessage[];
chatHandler: ChatHandler;
messageFooter: IJaiMessageFooter | null;
};

type ChatMessageHeaderProps = {
message: AiService.ChatMessage;
chatHandler: ChatHandler;
timestamp: string;
sx?: SxProps<Theme>;
};
@@ -113,6 +117,7 @@ export function ChatMessageHeader(props: ChatMessageHeaderProps): JSX.Element {
const shouldShowMenu =
props.message.type === 'agent' ||
(props.message.type === 'agent-stream' && props.message.complete);
const shouldShowDelete = props.message.type === 'human';

return (
<Box
@@ -154,6 +159,13 @@
sx={{ marginLeft: '4px', marginRight: '-8px' }}
/>
)}
{shouldShowDelete && (
<ChatMessageDelete
message={props.message}
chatHandler={props.chatHandler}
sx={{ marginLeft: '4px', marginRight: '-8px' }}
/>
)}
</Box>
</Box>
</Box>
@@ -208,6 +220,7 @@ export function ChatMessages(props: ChatMessagesProps): JSX.Element {
<ChatMessageHeader
message={message}
timestamp={timestamps[message.id]}
chatHandler={props.chatHandler}
sx={{ marginBottom: 3 }}
/>
<RendermimeMarkdown
31 changes: 31 additions & 0 deletions packages/jupyter-ai/src/components/chat-messages/chat-message-delete.tsx
@@ -0,0 +1,31 @@
import React from 'react';
import { SxProps } from '@mui/material';
import { Close } from '@mui/icons-material';

import { AiService } from '../../handler';
import { ChatHandler } from '../../chat_handler';
import { TooltippedIconButton } from '../mui-extras/tooltipped-icon-button';

type DeleteButtonProps = {
message: AiService.ChatMessage;
chatHandler: ChatHandler;
sx?: SxProps;
};

export function ChatMessageDelete(props: DeleteButtonProps): JSX.Element {
const request: AiService.ClearRequest = {
type: 'clear',
target: props.message.id
};
return (
<TooltippedIconButton
onClick={() => props.chatHandler.sendMessage(request)}
sx={props.sx}
tooltip="Delete this and all future messages"
>
<Close />
</TooltippedIconButton>
);
}

export default ChatMessageDelete;