diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py index 106c3225d..fc09d8f19 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py @@ -16,6 +16,7 @@ Union, cast, ) +from typing import get_args as get_type_args from uuid import uuid4 from dask.distributed import Client as DaskClient @@ -26,6 +27,7 @@ ChatMessage, ClosePendingMessage, HumanChatMessage, + Message, PendingMessage, ) from jupyter_ai_magics import Persona @@ -261,6 +263,26 @@ async def _default_handle_exc(self, e: Exception, message: HumanChatMessage): ) self.reply(response, message) + def broadcast_message(self, message: Message): + """ + Broadcasts a message to all WebSocket connections. If there are no + WebSocket connections and the message is a chat message, this method + directly appends to `self.chat_history`. + """ + broadcast = False + for websocket in self._root_chat_handlers.values(): + if not websocket: + continue + + websocket.broadcast_message(message) + broadcast = True + break + + if not broadcast: + if isinstance(message, get_type_args(ChatMessage)): + cast(ChatMessage, message) + self._chat_history.append(message) + def reply(self, response: str, human_msg: Optional[HumanChatMessage] = None): """ Sends an agent message, usually in response to a received @@ -274,12 +296,7 @@ def reply(self, response: str, human_msg: Optional[HumanChatMessage] = None): persona=self.persona, ) - for handler in self._root_chat_handlers.values(): - if not handler: - continue - - handler.broadcast_message(agent_msg) - break + self.broadcast_message(agent_msg) @property def persona(self): @@ -308,12 +325,7 @@ def start_pending( ellipsis=ellipsis, ) - for handler in self._root_chat_handlers.values(): - if not handler: - continue - - handler.broadcast_message(pending_msg) - break + self.broadcast_message(pending_msg) return pending_msg def close_pending(self, pending_msg: PendingMessage): @@ -327,13 +339,7 @@ def close_pending(self, pending_msg: PendingMessage): id=pending_msg.id, ) - for handler in self._root_chat_handlers.values(): - if not handler: - continue - - handler.broadcast_message(close_pending_msg) - break - + self.broadcast_message(close_pending_msg) pending_msg.closed = True @contextlib.contextmanager @@ -464,6 +470,4 @@ def send_help_message(self, human_msg: Optional[HumanChatMessage] = None) -> Non persona=self.persona, ) - self._chat_history.append(help_message) - for websocket in self._root_chat_handlers.values(): - websocket.write_message(help_message.json()) + self.broadcast_message(help_message) diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/clear.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/clear.py index a05bc3e57..d5b0ab6c7 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/clear.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/clear.py @@ -1,4 +1,4 @@ -from jupyter_ai.models import ClearMessage +from jupyter_ai.models import ClearRequest from .base import BaseChatHandler, SlashCommandRoutingType @@ -17,10 +17,10 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) async def process_message(self, _): - # Clear chat + # Clear chat by triggering `RootChatHandler.on_clear_request()`. for handler in self._root_chat_handlers.values(): if not handler: continue - handler.broadcast_message(ClearMessage()) + handler.on_clear_request(ClearRequest()) break diff --git a/packages/jupyter-ai/jupyter_ai/handlers.py b/packages/jupyter-ai/jupyter_ai/handlers.py index ed0e50a24..e22d2ae3b 100644 --- a/packages/jupyter-ai/jupyter_ai/handlers.py +++ b/packages/jupyter-ai/jupyter_ai/handlers.py @@ -38,7 +38,6 @@ ListSlashCommandsResponse, Message, PendingMessage, - StopMessage, StopRequest, UpdateConfigRequest, ) @@ -126,6 +125,13 @@ def pending_messages(self) -> List[PendingMessage]: def pending_messages(self, new_pending_messages): self.settings["pending_messages"] = new_pending_messages + @property + def cleared_message_ids(self) -> Set[str]: + """Set of `HumanChatMessage.id` that were cleared via `ClearRequest`.""" + if "cleared_message_ids" not in self.settings: + self.settings["cleared_message_ids"] = set() + return self.settings["cleared_message_ids"] + def initialize(self): self.log.debug("Initializing websocket connection %s", self.request.path) @@ -227,12 +233,9 @@ def broadcast_message(self, message: Message): # do not broadcast agent messages that are replying to cleared human message if ( isinstance(message, (AgentChatMessage, AgentStreamMessage)) - and message.reply_to + and message.reply_to in self.cleared_message_ids ): - if message.reply_to not in [ - m.id for m in self.chat_history if isinstance(m, HumanChatMessage) - ]: - return + return self.log.debug("Broadcasting message: %s to all clients...", message) client_ids = self.root_chat_handlers.keys() @@ -269,14 +272,6 @@ def broadcast_message(self, message: Message): self.pending_messages = list( filter(lambda m: m.id != message.id, self.pending_messages) ) - elif isinstance(message, ClearMessage): - if message.targets: - self._clear_chat_history_at(message.targets) - else: - self.chat_history.clear() - self.pending_messages.clear() - self.llm_chat_memory.clear() - self.settings["jai_chat_handlers"]["default"].send_help_message() async def on_message(self, message): self.log.debug("Message received: %s", message) @@ -294,22 +289,7 @@ async def on_message(self, message): return if isinstance(request, ClearRequest): - if not request.target: - targets = None - elif request.after: - target_msg = None - for msg in self.chat_history: - if msg.id == request.target: - target_msg = msg - if target_msg: - targets = [ - msg.id - for msg in self.chat_history - if msg.time >= target_msg.time and msg.type == "human" - ] - else: - targets = [request.target] - self.broadcast_message(ClearMessage(targets=targets)) + self.on_clear_request(request) return if isinstance(request, StopRequest): @@ -340,6 +320,46 @@ async def on_message(self, message): # as a distinct concurrent task. self.loop.create_task(self._route(chat_message)) + def on_clear_request(self, request: ClearRequest): + target = request.target + + # if no target, clear all messages + if not target: + for msg in self.chat_history: + if msg.type == "human": + self.cleared_message_ids.add(msg.id) + + self.chat_history.clear() + self.pending_messages.clear() + self.llm_chat_memory.clear() + self.broadcast_message(ClearMessage()) + self.settings["jai_chat_handlers"]["default"].send_help_message() + return + + # otherwise, clear a single message + self.cleared_message_ids.add(target) + for msg in self.chat_history[::-1]: + # interrupt the single message + if msg.type == "agent-stream" and getattr(msg, "reply_to", None) == target: + try: + self.message_interrupted[msg.id].set() + except KeyError: + # do nothing if the message was already interrupted + # or stream got completed (thread-safe way!) + pass + break + + self.chat_history[:] = [ + msg + for msg in self.chat_history + if msg.id != target and getattr(msg, "reply_to", None) != target + ] + self.pending_messages[:] = [ + msg for msg in self.pending_messages if msg.reply_to != target + ] + self.llm_chat_memory.clear([target]) + self.broadcast_message(ClearMessage(targets=[target])) + def on_stop_request(self): # set of message IDs that were submitted by this user, determined by the # username associated with this WebSocket connection. @@ -390,37 +410,6 @@ async def _route(self, message): command_readable = "Default" if command == "default" else command self.log.info(f"{command_readable} chat handler resolved in {latency_ms} ms.") - def _clear_chat_history_at(self, msg_ids: List[str]): - """ - Clears conversation exchanges associated with list of human message IDs. - """ - messages_to_interrupt = [ - msg - for msg in self.chat_history - if ( - msg.type == "agent-stream" - and getattr(msg, "reply_to", None) in msg_ids - and not msg.complete - ) - ] - for msg in messages_to_interrupt: - try: - self.message_interrupted[msg.id].set() - except KeyError: - # do nothing if the message was already interrupted - # or stream got completed (thread-safe way!) - pass - - self.chat_history[:] = [ - msg - for msg in self.chat_history - if msg.id not in msg_ids and getattr(msg, "reply_to", None) not in msg_ids - ] - self.pending_messages[:] = [ - msg for msg in self.pending_messages if msg.reply_to not in msg_ids - ] - self.llm_chat_memory.clear(msg_ids) - def on_close(self): self.log.debug("Disconnecting client with user %s", self.client_id) diff --git a/packages/jupyter-ai/jupyter_ai/models.py b/packages/jupyter-ai/jupyter_ai/models.py index ba292e43a..48dbe6193 100644 --- a/packages/jupyter-ai/jupyter_ai/models.py +++ b/packages/jupyter-ai/jupyter_ai/models.py @@ -52,18 +52,13 @@ class StopRequest(BaseModel): class ClearRequest(BaseModel): - type: Literal["clear"] + type: Literal["clear"] = "clear" target: Optional[str] """ Message ID of the HumanChatMessage to delete an exchange at. If not provided, this requests the backend to clear all messages. """ - after: Optional[bool] - """ - Whether to clear target and all subsequent exchanges. - """ - class ChatUser(BaseModel): # User ID assigned by IdentityProvider. @@ -148,13 +143,6 @@ class HumanChatMessage(BaseModel): client: ChatClient -class StopMessage(BaseModel): - """Message broadcast to clients after receiving a request to stop stop streaming or generating response""" - - type: Literal["stop"] = "stop" - target: str - - class ClearMessage(BaseModel): type: Literal["clear"] = "clear" targets: Optional[List[str]] = None diff --git a/packages/jupyter-ai/src/components/chat-messages/chat-message-delete.tsx b/packages/jupyter-ai/src/components/chat-messages/chat-message-delete.tsx index b91e15b93..d6fc691bd 100644 --- a/packages/jupyter-ai/src/components/chat-messages/chat-message-delete.tsx +++ b/packages/jupyter-ai/src/components/chat-messages/chat-message-delete.tsx @@ -15,8 +15,7 @@ type DeleteButtonProps = { export function ChatMessageDelete(props: DeleteButtonProps): JSX.Element { const request: AiService.ClearRequest = { type: 'clear', - target: props.message.id, - after: false + target: props.message.id }; return (