diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py index 480fd4565..7e90fa2f0 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py @@ -11,6 +11,7 @@ from langchain_core.messages import AIMessageChunk from langchain_core.runnables import ConfigurableFieldSpec from langchain_core.runnables.history import RunnableWithMessageHistory +from langchain_core.outputs.generation import GenerationChunk try: from jupyterlab_collaborative_chat.ychat import YChat @@ -62,27 +63,29 @@ def create_llm_chain( ) self.llm_chain = runnable - def _start_stream(self, human_msg: HumanChatMessage) -> str: + def _start_stream(self, human_msg: HumanChatMessage, chat: YChat | None) -> str: """ Sends an `agent-stream` message to indicate the start of a response stream. Returns the ID of the message, denoted as the `stream_id`. """ - stream_id = uuid4().hex - stream_msg = AgentStreamMessage( - id=stream_id, - time=time.time(), - body="", - reply_to=human_msg.id, - persona=self.persona, - complete=False, - ) - - for handler in self._root_chat_handlers.values(): - if not handler: - continue - - handler.broadcast_message(stream_msg) - break + if chat is not None: + stream_id = self.write_message(chat, "") + else: + stream_id = uuid4().hex + stream_msg = AgentStreamMessage( + id=stream_id, + time=time.time(), + body="", + reply_to=human_msg.id, + persona=self.persona, + complete=False, + ) + for handler in self._root_chat_handlers.values(): + if not handler: + continue + + handler.broadcast_message(stream_msg) + break return stream_id @@ -92,7 +95,7 @@ def _send_stream_chunk(self, stream_id: str, content: str, chat: YChat | None, c appended to an existing `agent-stream` message with ID `stream_id`. """ if chat is not None: - self.write_message(chat, content) + self.write_message(chat, content, stream_id) else: stream_chunk_msg = AgentStreamChunkMessage( id=stream_id, content=content, stream_complete=complete @@ -121,11 +124,11 @@ async def process_message(self, message: HumanChatMessage, chat: YChat | None): # when receiving the first chunk, close the pending message and # start the stream. self.close_pending(pending_message) - stream_id = self._start_stream(human_msg=message) + stream_id = self._start_stream(message, chat) received_first_chunk = True if isinstance(chunk, AIMessageChunk): - self._send_stream_chunk(stream_id, chunk.content) + self._send_stream_chunk(stream_id, chunk.content, chat) elif isinstance(chunk, str): self._send_stream_chunk(stream_id, chunk, chat) else: diff --git a/packages/jupyter-ai/jupyter_ai/extension.py b/packages/jupyter-ai/jupyter_ai/extension.py index 128263726..201291019 100644 --- a/packages/jupyter-ai/jupyter_ai/extension.py +++ b/packages/jupyter-ai/jupyter_ai/extension.py @@ -229,6 +229,10 @@ def initialize(self): listener=self.connect_chat ) + # Keep the message indexes to avoid extra computation looking for a message when + # updating it. + self.messages_indexes = {} + async def connect_chat(self, logger: EventLogger, schema_id: str, data: dict) -> None: if data["room"].startswith("text:chat:") \ and data["action"] == "initialize"\ @@ -299,7 +303,7 @@ async def _route(self, message: HumanChatMessage, chat: YChat): command_readable = "Default" if command == "default" else command self.log.info(f"{command_readable} chat handler resolved in {latency_ms} ms.") - def write_message(self, chat: YChat, body: str) -> None: + def write_message(self, chat: YChat, body: str, id: str | None = None) -> str: BOT["avatar_url"] = url_path_join( self.settings.get("base_url", "/"), "api/ai/static/jupyternaut.svg" @@ -310,14 +314,21 @@ def write_message(self, chat: YChat, body: str) -> None: else: BOT["username"] = bot["username"] - chat.add_message({ + index = self.messages_indexes[id] if id else None + id = id if id else str(uuid.uuid4()) + new_index = chat.set_message({ "type": "msg", "body": body, - "id": str(uuid.uuid4()), + "id": id if id else str(uuid.uuid4()), "time": time.time(), "sender": BOT["username"], "raw_time": False - }) + }, index, True) + + if new_index != index: + self.messages_indexes[id] = new_index + + return id def initialize_settings(self): start = time.time() diff --git a/packages/jupyter-ai/package.json b/packages/jupyter-ai/package.json index a1f8029da..cc404aaa3 100644 --- a/packages/jupyter-ai/package.json +++ b/packages/jupyter-ai/package.json @@ -61,7 +61,7 @@ "dependencies": { "@emotion/react": "^11.10.5", "@emotion/styled": "^11.10.5", - "@jupyter/chat": "^0.3.1", + "@jupyter/chat": "^0.4.0", "@jupyter/collaboration": "^1", "@jupyterlab/application": "^4.2.0", "@jupyterlab/apputils": "^4.2.0", diff --git a/packages/jupyter-ai/pyproject.toml b/packages/jupyter-ai/pyproject.toml index 9b1d3d996..2353343e3 100644 --- a/packages/jupyter-ai/pyproject.toml +++ b/packages/jupyter-ai/pyproject.toml @@ -33,6 +33,7 @@ dependencies = [ "typing_extensions>=4.5.0", "traitlets>=5.0", "deepmerge>=1.0", + "jupyterlab-collaborative-chat", ] dynamic = ["version", "description", "authors", "urls", "keywords"] @@ -56,8 +57,6 @@ dev = ["jupyter_ai_magics[dev]"] all = ["jupyter_ai_magics[all]", "pypdf", "arxiv"] -collaborative = ["jupyterlab-collaborative-chat"] - [tool.hatch.version] source = "nodejs" diff --git a/yarn.lock b/yarn.lock index 49d9b3304..5d45cf685 100644 --- a/yarn.lock +++ b/yarn.lock @@ -2220,7 +2220,7 @@ __metadata: "@babel/preset-env": ^7.0.0 "@emotion/react": ^11.10.5 "@emotion/styled": ^11.10.5 - "@jupyter/chat": ^0.3.1 + "@jupyter/chat": ^0.4.0 "@jupyter/collaboration": ^1 "@jupyterlab/application": ^4.2.0 "@jupyterlab/apputils": ^4.2.0 @@ -2286,9 +2286,9 @@ __metadata: languageName: unknown linkType: soft -"@jupyter/chat@npm:^0.3.1": - version: 0.3.1 - resolution: "@jupyter/chat@npm:0.3.1" +"@jupyter/chat@npm:^0.4.0": + version: 0.4.0 + resolution: "@jupyter/chat@npm:0.4.0" dependencies: "@emotion/react": ^11.10.5 "@emotion/styled": ^11.10.5 @@ -2306,7 +2306,7 @@ __metadata: clsx: ^2.1.0 react: ^18.2.0 react-dom: ^18.2.0 - checksum: 92d1f6d6d3083be2a8c2309fc37eb3bb39bc184e1c41d18f79d51f87e54d7badba1495c5943916e7a21fc65414d6100860b044a14554add8192b545fca2748dd + checksum: 6e309c8e70cf480103eb26f3109da417c58d2e6844d5e56e63feabf71926f9dba6f9bc85caff765dfc574a8fd7ed803a8c03e5d812c28568dcf6ec918bbd2e66 languageName: node linkType: hard