From aa3da2d9b31bdd32a3e0af7535e7233f1df15181 Mon Sep 17 00:00:00 2001
From: krassowski <5832902+krassowski@users.noreply.github.com>
Date: Wed, 10 Jul 2024 16:39:38 +0100
Subject: [PATCH] Remove pending message on error

---
 .../jupyter_ai/chat_handlers/base.py     |  7 ++-
 .../jupyter_ai/chat_handlers/default.py  | 51 +++++++++----------
 packages/jupyter-ai/jupyter_ai/models.py |  1 +
 3 files changed, 31 insertions(+), 28 deletions(-)

diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py
index d3c4e497d..a60317209 100644
--- a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py
+++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py
@@ -289,6 +289,8 @@ def close_pending(self, pending_msg: PendingMessage):
             handler.broadcast_message(close_pending_msg)
             break
 
+        pending_msg.closed = True
+
     @contextlib.contextmanager
     def pending(self, text: str, ellipsis: bool = True):
         """
@@ -297,9 +299,10 @@ def pending(self, text: str, ellipsis: bool = True):
         """
         pending_msg = self.start_pending(text, ellipsis=ellipsis)
         try:
-            yield
+            yield pending_msg
         finally:
-            self.close_pending(pending_msg)
+            if not pending_msg.closed:
+                self.close_pending(pending_msg)
 
     def get_llm_chain(self):
         lm_provider = self.config_manager.lm_provider
diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py
index e0d923e77..4aebdde80 100644
--- a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py
+++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py
@@ -97,29 +97,28 @@ async def process_message(self, message: HumanChatMessage):
         received_first_chunk = False
 
         # start with a pending message
-        pending_message = self.start_pending("Generating response")
-
-        # stream response in chunks. this works even if a provider does not
-        # implement streaming, as `astream()` defaults to yielding `_call()`
-        # when `_stream()` is not implemented on the LLM class.
-        async for chunk in self.llm_chain.astream(
-            {"input": message.body},
-            config={"configurable": {"session_id": "static_session"}},
-        ):
-            if not received_first_chunk:
-                # when receiving the first chunk, close the pending message and
-                # start the stream.
-                self.close_pending(pending_message)
-                stream_id = self._start_stream(human_msg=message)
-                received_first_chunk = True
-
-            if isinstance(chunk, AIMessageChunk):
-                self._send_stream_chunk(stream_id, chunk.content)
-            elif isinstance(chunk, str):
-                self._send_stream_chunk(stream_id, chunk)
-            else:
-                self.log.error(f"Unrecognized type of chunk yielded: {type(chunk)}")
-                break
-
-        # complete stream after all chunks have been streamed
-        self._send_stream_chunk(stream_id, "", complete=True)
+        with self.pending("Generating response") as pending_message:
+            # stream response in chunks. this works even if a provider does not
+            # implement streaming, as `astream()` defaults to yielding `_call()`
+            # when `_stream()` is not implemented on the LLM class.
+            async for chunk in self.llm_chain.astream(
+                {"input": message.body},
+                config={"configurable": {"session_id": "static_session"}},
+            ):
+                if not received_first_chunk:
+                    # when receiving the first chunk, close the pending message and
+                    # start the stream.
+                    self.close_pending(pending_message)
+                    stream_id = self._start_stream(human_msg=message)
+                    received_first_chunk = True
+
+                if isinstance(chunk, AIMessageChunk):
+                    self._send_stream_chunk(stream_id, chunk.content)
+                elif isinstance(chunk, str):
+                    self._send_stream_chunk(stream_id, chunk)
+                else:
+                    self.log.error(f"Unrecognized type of chunk yielded: {type(chunk)}")
+                    break
+
+            # complete stream after all chunks have been streamed
+            self._send_stream_chunk(stream_id, "", complete=True)
diff --git a/packages/jupyter-ai/jupyter_ai/models.py b/packages/jupyter-ai/jupyter_ai/models.py
index ea541568b..331f4cd94 100644
--- a/packages/jupyter-ai/jupyter_ai/models.py
+++ b/packages/jupyter-ai/jupyter_ai/models.py
@@ -103,6 +103,7 @@ class PendingMessage(BaseModel):
     body: str
     persona: Persona
     ellipsis: bool = True
+    closed: bool = False
 
 
 class ClosePendingMessage(BaseModel):
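
Note (not part of the patch): the sketch below illustrates the lifecycle this change enables, using a hypothetical `FakeHandler` stand-in for `BaseChatHandler` rather than the real Jupyter AI classes. It shows why the new `closed` flag matters: the handler may call `close_pending()` itself as soon as the first streamed chunk arrives, and the `pending()` context manager then skips the redundant close, while an exception raised inside the block still removes the spinner on exit.

# Illustrative sketch only -- FakeHandler and the print() calls stand in for
# BaseChatHandler and its WebSocket broadcasts; the pending()/close_pending()
# logic mirrors the patched base.py above.
import contextlib
from dataclasses import dataclass


@dataclass
class PendingMessage:
    body: str
    ellipsis: bool = True
    closed: bool = False  # new field introduced by this patch


class FakeHandler:
    def start_pending(self, text: str, ellipsis: bool = True) -> PendingMessage:
        print(f"show spinner: {text!r}")
        return PendingMessage(body=text, ellipsis=ellipsis)

    def close_pending(self, pending_msg: PendingMessage) -> None:
        print(f"hide spinner: {pending_msg.body!r}")
        pending_msg.closed = True  # remember that it has been closed

    @contextlib.contextmanager
    def pending(self, text: str, ellipsis: bool = True):
        pending_msg = self.start_pending(text, ellipsis=ellipsis)
        try:
            yield pending_msg
        finally:
            if not pending_msg.closed:  # avoid closing twice
                self.close_pending(pending_msg)


handler = FakeHandler()

# Normal path: the handler closes the message on the first streamed chunk,
# and the context manager does not close it again on exit.
with handler.pending("Generating response") as msg:
    handler.close_pending(msg)

# Error path: the provider raises before any chunk arrives; the context
# manager still removes the pending message, which is the bug this PR fixes.
try:
    with handler.pending("Generating response"):
        raise RuntimeError("provider failed")
except RuntimeError:
    pass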