From 12a6e8ecc519b5896fe262844697bbec20877a5a Mon Sep 17 00:00:00 2001 From: "David L. Qiu" Date: Thu, 17 Oct 2024 08:03:09 -0700 Subject: [PATCH 1/8] remove unused `StopMessage` backend model --- packages/jupyter-ai/jupyter_ai/models.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/packages/jupyter-ai/jupyter_ai/models.py b/packages/jupyter-ai/jupyter_ai/models.py index ba292e43a..1b53066e0 100644 --- a/packages/jupyter-ai/jupyter_ai/models.py +++ b/packages/jupyter-ai/jupyter_ai/models.py @@ -148,13 +148,6 @@ class HumanChatMessage(BaseModel): client: ChatClient -class StopMessage(BaseModel): - """Message broadcast to clients after receiving a request to stop stop streaming or generating response""" - - type: Literal["stop"] = "stop" - target: str - - class ClearMessage(BaseModel): type: Literal["clear"] = "clear" targets: Optional[List[str]] = None From 226acafb469f43dcbaa9f7ab5ab487fff7d9878a Mon Sep 17 00:00:00 2001 From: "David L. Qiu" Date: Thu, 17 Oct 2024 08:04:39 -0700 Subject: [PATCH 2/8] fixup --- packages/jupyter-ai/jupyter_ai/handlers.py | 1 - 1 file changed, 1 deletion(-) diff --git a/packages/jupyter-ai/jupyter_ai/handlers.py b/packages/jupyter-ai/jupyter_ai/handlers.py index ed0e50a24..cb03176e0 100644 --- a/packages/jupyter-ai/jupyter_ai/handlers.py +++ b/packages/jupyter-ai/jupyter_ai/handlers.py @@ -38,7 +38,6 @@ ListSlashCommandsResponse, Message, PendingMessage, - StopMessage, StopRequest, UpdateConfigRequest, ) From 0b7afbed7572d9477303cd376ec03ded379af34f Mon Sep 17 00:00:00 2001 From: "David L. Qiu" Date: Fri, 18 Oct 2024 07:30:39 -0700 Subject: [PATCH 3/8] remove unused `after` attr on `ClearRequest` --- packages/jupyter-ai/jupyter_ai/handlers.py | 17 +++-------------- packages/jupyter-ai/jupyter_ai/models.py | 6 ------ .../chat-messages/chat-message-delete.tsx | 3 +-- packages/jupyter-ai/src/handler.ts | 1 - 4 files changed, 4 insertions(+), 23 deletions(-) diff --git a/packages/jupyter-ai/jupyter_ai/handlers.py b/packages/jupyter-ai/jupyter_ai/handlers.py index cb03176e0..f5e83e71c 100644 --- a/packages/jupyter-ai/jupyter_ai/handlers.py +++ b/packages/jupyter-ai/jupyter_ai/handlers.py @@ -293,21 +293,10 @@ async def on_message(self, message): return if isinstance(request, ClearRequest): - if not request.target: - targets = None - elif request.after: - target_msg = None - for msg in self.chat_history: - if msg.id == request.target: - target_msg = msg - if target_msg: - targets = [ - msg.id - for msg in self.chat_history - if msg.time >= target_msg.time and msg.type == "human" - ] - else: + if request.target: targets = [request.target] + else: + targets = None self.broadcast_message(ClearMessage(targets=targets)) return diff --git a/packages/jupyter-ai/jupyter_ai/models.py b/packages/jupyter-ai/jupyter_ai/models.py index 1b53066e0..d0e17a240 100644 --- a/packages/jupyter-ai/jupyter_ai/models.py +++ b/packages/jupyter-ai/jupyter_ai/models.py @@ -59,12 +59,6 @@ class ClearRequest(BaseModel): If not provided, this requests the backend to clear all messages. """ - after: Optional[bool] - """ - Whether to clear target and all subsequent exchanges. - """ - - class ChatUser(BaseModel): # User ID assigned by IdentityProvider. username: str diff --git a/packages/jupyter-ai/src/components/chat-messages/chat-message-delete.tsx b/packages/jupyter-ai/src/components/chat-messages/chat-message-delete.tsx index b91e15b93..d6fc691bd 100644 --- a/packages/jupyter-ai/src/components/chat-messages/chat-message-delete.tsx +++ b/packages/jupyter-ai/src/components/chat-messages/chat-message-delete.tsx @@ -15,8 +15,7 @@ type DeleteButtonProps = { export function ChatMessageDelete(props: DeleteButtonProps): JSX.Element { const request: AiService.ClearRequest = { type: 'clear', - target: props.message.id, - after: false + target: props.message.id }; return ( Date: Fri, 18 Oct 2024 08:48:09 -0700 Subject: [PATCH 4/8] unify message clearing logic into `on_clear_request()` --- .../jupyter_ai/chat_handlers/clear.py | 7 +- packages/jupyter-ai/jupyter_ai/handlers.py | 98 ++++++++++--------- packages/jupyter-ai/jupyter_ai/models.py | 2 +- 3 files changed, 54 insertions(+), 53 deletions(-) diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/clear.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/clear.py index a05bc3e57..16eb67c29 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/clear.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/clear.py @@ -1,5 +1,4 @@ -from jupyter_ai.models import ClearMessage - +from jupyter_ai.models import ClearRequest from .base import BaseChatHandler, SlashCommandRoutingType @@ -17,10 +16,10 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) async def process_message(self, _): - # Clear chat + # Clear chat by triggering `RootChatHandler.on_clear_request()`. for handler in self._root_chat_handlers.values(): if not handler: continue - handler.broadcast_message(ClearMessage()) + handler.on_clear_request(ClearRequest()) break diff --git a/packages/jupyter-ai/jupyter_ai/handlers.py b/packages/jupyter-ai/jupyter_ai/handlers.py index f5e83e71c..e5c191967 100644 --- a/packages/jupyter-ai/jupyter_ai/handlers.py +++ b/packages/jupyter-ai/jupyter_ai/handlers.py @@ -121,6 +121,13 @@ def loop(self) -> AbstractEventLoop: def pending_messages(self) -> List[PendingMessage]: return self.settings["pending_messages"] + @property + def cleared_message_ids(self) -> Set[str]: + """Set of `HumanChatMessage.id` that were cleared via `ClearRequest`.""" + if "cleared_message_ids" not in self.settings: + self.settings["cleared_message_ids"] = set() + return self.settings["cleared_message_ids"] + @pending_messages.setter def pending_messages(self, new_pending_messages): self.settings["pending_messages"] = new_pending_messages @@ -226,12 +233,9 @@ def broadcast_message(self, message: Message): # do not broadcast agent messages that are replying to cleared human message if ( isinstance(message, (AgentChatMessage, AgentStreamMessage)) - and message.reply_to + and message.reply_to in self.cleared_message_ids ): - if message.reply_to not in [ - m.id for m in self.chat_history if isinstance(m, HumanChatMessage) - ]: - return + return self.log.debug("Broadcasting message: %s to all clients...", message) client_ids = self.root_chat_handlers.keys() @@ -268,14 +272,6 @@ def broadcast_message(self, message: Message): self.pending_messages = list( filter(lambda m: m.id != message.id, self.pending_messages) ) - elif isinstance(message, ClearMessage): - if message.targets: - self._clear_chat_history_at(message.targets) - else: - self.chat_history.clear() - self.pending_messages.clear() - self.llm_chat_memory.clear() - self.settings["jai_chat_handlers"]["default"].send_help_message() async def on_message(self, message): self.log.debug("Message received: %s", message) @@ -293,11 +289,7 @@ async def on_message(self, message): return if isinstance(request, ClearRequest): - if request.target: - targets = [request.target] - else: - targets = None - self.broadcast_message(ClearMessage(targets=targets)) + self.on_clear_request(request) return if isinstance(request, StopRequest): @@ -327,6 +319,46 @@ async def on_message(self, message): # handling messages from a websocket. instead, process each message # as a distinct concurrent task. self.loop.create_task(self._route(chat_message)) + + def on_clear_request(self, request: ClearRequest): + target = request.target + + # if no target, clear all messages + if not target: + for msg in self.chat_history: + if msg.type == "human": + self.cleared_message_ids.add(msg.id) + + self.chat_history.clear() + self.pending_messages.clear() + self.llm_chat_memory.clear() + self.broadcast_message(ClearMessage()) + self.settings["jai_chat_handlers"]["default"].send_help_message() + return + + # otherwise, clear a single message + self.cleared_message_ids.add(target) + for msg in self.chat_history[::-1]: + # interrupt the single message + if (msg.type == "agent-stream" and getattr(msg, "reply_to", None) == target): + try: + self.message_interrupted[msg.id].set() + except KeyError: + # do nothing if the message was already interrupted + # or stream got completed (thread-safe way!) + pass + break + + self.chat_history[:] = [ + msg + for msg in self.chat_history + if msg.id != target and getattr(msg, "reply_to", None) != target + ] + self.pending_messages[:] = [ + msg for msg in self.pending_messages if msg.reply_to != target + ] + self.llm_chat_memory.clear([target]) + self.broadcast_message(ClearMessage(targets=[target])) def on_stop_request(self): # set of message IDs that were submitted by this user, determined by the @@ -378,36 +410,6 @@ async def _route(self, message): command_readable = "Default" if command == "default" else command self.log.info(f"{command_readable} chat handler resolved in {latency_ms} ms.") - def _clear_chat_history_at(self, msg_ids: List[str]): - """ - Clears conversation exchanges associated with list of human message IDs. - """ - messages_to_interrupt = [ - msg - for msg in self.chat_history - if ( - msg.type == "agent-stream" - and getattr(msg, "reply_to", None) in msg_ids - and not msg.complete - ) - ] - for msg in messages_to_interrupt: - try: - self.message_interrupted[msg.id].set() - except KeyError: - # do nothing if the message was already interrupted - # or stream got completed (thread-safe way!) - pass - - self.chat_history[:] = [ - msg - for msg in self.chat_history - if msg.id not in msg_ids and getattr(msg, "reply_to", None) not in msg_ids - ] - self.pending_messages[:] = [ - msg for msg in self.pending_messages if msg.reply_to not in msg_ids - ] - self.llm_chat_memory.clear(msg_ids) def on_close(self): self.log.debug("Disconnecting client with user %s", self.client_id) diff --git a/packages/jupyter-ai/jupyter_ai/models.py b/packages/jupyter-ai/jupyter_ai/models.py index d0e17a240..73026f736 100644 --- a/packages/jupyter-ai/jupyter_ai/models.py +++ b/packages/jupyter-ai/jupyter_ai/models.py @@ -52,7 +52,7 @@ class StopRequest(BaseModel): class ClearRequest(BaseModel): - type: Literal["clear"] + type: Literal["clear"] = "clear" target: Optional[str] """ Message ID of the HumanChatMessage to delete an exchange at. From 31b8bced6cb324a130b4409db19e093539d0b164 Mon Sep 17 00:00:00 2001 From: "David L. Qiu" Date: Fri, 18 Oct 2024 08:48:29 -0700 Subject: [PATCH 5/8] unify message broadcast logic into `broadcast_message()` --- .../jupyter_ai/chat_handlers/base.py | 45 ++++++++++--------- 1 file changed, 24 insertions(+), 21 deletions(-) diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py index 106c3225d..47dff1c07 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py @@ -27,6 +27,7 @@ ClosePendingMessage, HumanChatMessage, PendingMessage, + Message, ) from jupyter_ai_magics import Persona from jupyter_ai_magics.providers import BaseProvider @@ -260,6 +261,25 @@ async def _default_handle_exc(self, e: Exception, message: HumanChatMessage): f"Sorry, an error occurred. Details below:\n\n```\n{formatted_e}\n```" ) self.reply(response, message) + + def broadcast_message(self, message: Message): + """ + Broadcasts a message to all WebSocket connections. If there are no + WebSocket connections, this method directly appends to + `self.chat_history`. + """ + broadcast = False + for websocket in self._root_chat_handlers.values(): + if not websocket: + continue + + websocket.broadcast_message(message) + broadcast = True + break + + if not broadcast: + self._chat_history.append(message) + def reply(self, response: str, human_msg: Optional[HumanChatMessage] = None): """ @@ -274,12 +294,8 @@ def reply(self, response: str, human_msg: Optional[HumanChatMessage] = None): persona=self.persona, ) - for handler in self._root_chat_handlers.values(): - if not handler: - continue + self.broadcast_message(agent_msg) - handler.broadcast_message(agent_msg) - break @property def persona(self): @@ -308,12 +324,7 @@ def start_pending( ellipsis=ellipsis, ) - for handler in self._root_chat_handlers.values(): - if not handler: - continue - - handler.broadcast_message(pending_msg) - break + self.broadcast_message(pending_msg) return pending_msg def close_pending(self, pending_msg: PendingMessage): @@ -327,13 +338,7 @@ def close_pending(self, pending_msg: PendingMessage): id=pending_msg.id, ) - for handler in self._root_chat_handlers.values(): - if not handler: - continue - - handler.broadcast_message(close_pending_msg) - break - + self.broadcast_message(close_pending_msg) pending_msg.closed = True @contextlib.contextmanager @@ -464,6 +469,4 @@ def send_help_message(self, human_msg: Optional[HumanChatMessage] = None) -> Non persona=self.persona, ) - self._chat_history.append(help_message) - for websocket in self._root_chat_handlers.values(): - websocket.write_message(help_message.json()) + self.broadcast_message(help_message) \ No newline at end of file From be2fde59fb8c291bbbf507f98df223599f6710d1 Mon Sep 17 00:00:00 2001 From: "David L. Qiu" Date: Fri, 18 Oct 2024 09:02:00 -0700 Subject: [PATCH 6/8] pre-commit --- packages/jupyter-ai/jupyter_ai/chat_handlers/base.py | 10 ++++------ packages/jupyter-ai/jupyter_ai/chat_handlers/clear.py | 1 + packages/jupyter-ai/jupyter_ai/handlers.py | 5 ++--- packages/jupyter-ai/jupyter_ai/models.py | 1 + 4 files changed, 8 insertions(+), 9 deletions(-) diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py index 47dff1c07..af7ca559a 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py @@ -26,8 +26,8 @@ ChatMessage, ClosePendingMessage, HumanChatMessage, - PendingMessage, Message, + PendingMessage, ) from jupyter_ai_magics import Persona from jupyter_ai_magics.providers import BaseProvider @@ -261,7 +261,7 @@ async def _default_handle_exc(self, e: Exception, message: HumanChatMessage): f"Sorry, an error occurred. Details below:\n\n```\n{formatted_e}\n```" ) self.reply(response, message) - + def broadcast_message(self, message: Message): """ Broadcasts a message to all WebSocket connections. If there are no @@ -276,10 +276,9 @@ def broadcast_message(self, message: Message): websocket.broadcast_message(message) broadcast = True break - + if not broadcast: self._chat_history.append(message) - def reply(self, response: str, human_msg: Optional[HumanChatMessage] = None): """ @@ -296,7 +295,6 @@ def reply(self, response: str, human_msg: Optional[HumanChatMessage] = None): self.broadcast_message(agent_msg) - @property def persona(self): return self.config_manager.persona @@ -469,4 +467,4 @@ def send_help_message(self, human_msg: Optional[HumanChatMessage] = None) -> Non persona=self.persona, ) - self.broadcast_message(help_message) \ No newline at end of file + self.broadcast_message(help_message) diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/clear.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/clear.py index 16eb67c29..d5b0ab6c7 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/clear.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/clear.py @@ -1,4 +1,5 @@ from jupyter_ai.models import ClearRequest + from .base import BaseChatHandler, SlashCommandRoutingType diff --git a/packages/jupyter-ai/jupyter_ai/handlers.py b/packages/jupyter-ai/jupyter_ai/handlers.py index e5c191967..06175961b 100644 --- a/packages/jupyter-ai/jupyter_ai/handlers.py +++ b/packages/jupyter-ai/jupyter_ai/handlers.py @@ -319,7 +319,7 @@ async def on_message(self, message): # handling messages from a websocket. instead, process each message # as a distinct concurrent task. self.loop.create_task(self._route(chat_message)) - + def on_clear_request(self, request: ClearRequest): target = request.target @@ -340,7 +340,7 @@ def on_clear_request(self, request: ClearRequest): self.cleared_message_ids.add(target) for msg in self.chat_history[::-1]: # interrupt the single message - if (msg.type == "agent-stream" and getattr(msg, "reply_to", None) == target): + if msg.type == "agent-stream" and getattr(msg, "reply_to", None) == target: try: self.message_interrupted[msg.id].set() except KeyError: @@ -410,7 +410,6 @@ async def _route(self, message): command_readable = "Default" if command == "default" else command self.log.info(f"{command_readable} chat handler resolved in {latency_ms} ms.") - def on_close(self): self.log.debug("Disconnecting client with user %s", self.client_id) diff --git a/packages/jupyter-ai/jupyter_ai/models.py b/packages/jupyter-ai/jupyter_ai/models.py index 73026f736..48dbe6193 100644 --- a/packages/jupyter-ai/jupyter_ai/models.py +++ b/packages/jupyter-ai/jupyter_ai/models.py @@ -59,6 +59,7 @@ class ClearRequest(BaseModel): If not provided, this requests the backend to clear all messages. """ + class ChatUser(BaseModel): # User ID assigned by IdentityProvider. username: str From 6738a22ea9835f517a723b99cb3c616bca7a6662 Mon Sep 17 00:00:00 2001 From: "David L. Qiu" Date: Fri, 18 Oct 2024 09:50:12 -0700 Subject: [PATCH 7/8] fix typing issues raised by mypy --- packages/jupyter-ai/jupyter_ai/chat_handlers/base.py | 9 ++++++--- packages/jupyter-ai/jupyter_ai/handlers.py | 8 ++++---- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py index af7ca559a..5bbd91add 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py @@ -9,6 +9,7 @@ Awaitable, ClassVar, Dict, + get_args as get_type_args, List, Literal, Optional, @@ -265,8 +266,8 @@ async def _default_handle_exc(self, e: Exception, message: HumanChatMessage): def broadcast_message(self, message: Message): """ Broadcasts a message to all WebSocket connections. If there are no - WebSocket connections, this method directly appends to - `self.chat_history`. + WebSocket connections and the message is a chat message, this method + directly appends to `self.chat_history`. """ broadcast = False for websocket in self._root_chat_handlers.values(): @@ -278,7 +279,9 @@ def broadcast_message(self, message: Message): break if not broadcast: - self._chat_history.append(message) + if isinstance(message, get_type_args(ChatMessage)): + cast(ChatMessage, message) + self._chat_history.append(message) def reply(self, response: str, human_msg: Optional[HumanChatMessage] = None): """ diff --git a/packages/jupyter-ai/jupyter_ai/handlers.py b/packages/jupyter-ai/jupyter_ai/handlers.py index 06175961b..e22d2ae3b 100644 --- a/packages/jupyter-ai/jupyter_ai/handlers.py +++ b/packages/jupyter-ai/jupyter_ai/handlers.py @@ -121,6 +121,10 @@ def loop(self) -> AbstractEventLoop: def pending_messages(self) -> List[PendingMessage]: return self.settings["pending_messages"] + @pending_messages.setter + def pending_messages(self, new_pending_messages): + self.settings["pending_messages"] = new_pending_messages + @property def cleared_message_ids(self) -> Set[str]: """Set of `HumanChatMessage.id` that were cleared via `ClearRequest`.""" @@ -128,10 +132,6 @@ def cleared_message_ids(self) -> Set[str]: self.settings["cleared_message_ids"] = set() return self.settings["cleared_message_ids"] - @pending_messages.setter - def pending_messages(self, new_pending_messages): - self.settings["pending_messages"] = new_pending_messages - def initialize(self): self.log.debug("Initializing websocket connection %s", self.request.path) From dc7e345e60db5169b8f8e4be213231ed3cebd708 Mon Sep 17 00:00:00 2001 From: "David L. Qiu" Date: Fri, 18 Oct 2024 09:50:41 -0700 Subject: [PATCH 8/8] pre-commit --- packages/jupyter-ai/jupyter_ai/chat_handlers/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py index 5bbd91add..fc09d8f19 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py @@ -9,7 +9,6 @@ Awaitable, ClassVar, Dict, - get_args as get_type_args, List, Literal, Optional, @@ -17,6 +16,7 @@ Union, cast, ) +from typing import get_args as get_type_args from uuid import uuid4 from dask.distributed import Client as DaskClient