diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py
index d3c4e497d..426c4348c 100644
--- a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py
+++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py
@@ -278,6 +278,9 @@ def close_pending(self, pending_msg: PendingMessage):
         """
         Closes a pending message.
         """
+        if pending_msg.closed:
+            return
+
         close_pending_msg = ClosePendingMessage(
             id=pending_msg.id,
         )
@@ -289,6 +292,8 @@ def close_pending(self, pending_msg: PendingMessage):
             handler.broadcast_message(close_pending_msg)
             break
 
+        pending_msg.closed = True
+
     @contextlib.contextmanager
     def pending(self, text: str, ellipsis: bool = True):
         """
@@ -297,9 +302,10 @@ def pending(self, text: str, ellipsis: bool = True):
         """
         pending_msg = self.start_pending(text, ellipsis=ellipsis)
         try:
-            yield
+            yield pending_msg
         finally:
-            self.close_pending(pending_msg)
+            if not pending_msg.closed:
+                self.close_pending(pending_msg)
 
     def get_llm_chain(self):
         lm_provider = self.config_manager.lm_provider
diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py
index e0d923e77..4aebdde80 100644
--- a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py
+++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py
@@ -97,29 +97,28 @@ async def process_message(self, message: HumanChatMessage):
         received_first_chunk = False
 
         # start with a pending message
-        pending_message = self.start_pending("Generating response")
-
-        # stream response in chunks. this works even if a provider does not
-        # implement streaming, as `astream()` defaults to yielding `_call()`
-        # when `_stream()` is not implemented on the LLM class.
-        async for chunk in self.llm_chain.astream(
-            {"input": message.body},
-            config={"configurable": {"session_id": "static_session"}},
-        ):
-            if not received_first_chunk:
-                # when receiving the first chunk, close the pending message and
-                # start the stream.
-                self.close_pending(pending_message)
-                stream_id = self._start_stream(human_msg=message)
-                received_first_chunk = True
-
-            if isinstance(chunk, AIMessageChunk):
-                self._send_stream_chunk(stream_id, chunk.content)
-            elif isinstance(chunk, str):
-                self._send_stream_chunk(stream_id, chunk)
-            else:
-                self.log.error(f"Unrecognized type of chunk yielded: {type(chunk)}")
-                break
-
-        # complete stream after all chunks have been streamed
-        self._send_stream_chunk(stream_id, "", complete=True)
+        with self.pending("Generating response") as pending_message:
+            # stream response in chunks. this works even if a provider does not
+            # implement streaming, as `astream()` defaults to yielding `_call()`
+            # when `_stream()` is not implemented on the LLM class.
+            async for chunk in self.llm_chain.astream(
+                {"input": message.body},
+                config={"configurable": {"session_id": "static_session"}},
+            ):
+                if not received_first_chunk:
+                    # when receiving the first chunk, close the pending message and
+                    # start the stream.
+                    self.close_pending(pending_message)
+                    stream_id = self._start_stream(human_msg=message)
+                    received_first_chunk = True
+
+                if isinstance(chunk, AIMessageChunk):
+                    self._send_stream_chunk(stream_id, chunk.content)
+                elif isinstance(chunk, str):
+                    self._send_stream_chunk(stream_id, chunk)
+                else:
+                    self.log.error(f"Unrecognized type of chunk yielded: {type(chunk)}")
+                    break
+
+            # complete stream after all chunks have been streamed
+            self._send_stream_chunk(stream_id, "", complete=True)
diff --git a/packages/jupyter-ai/jupyter_ai/models.py b/packages/jupyter-ai/jupyter_ai/models.py
index ea541568b..331f4cd94 100644
--- a/packages/jupyter-ai/jupyter_ai/models.py
+++ b/packages/jupyter-ai/jupyter_ai/models.py
@@ -103,6 +103,7 @@ class PendingMessage(BaseModel):
     body: str
     persona: Persona
     ellipsis: bool = True
+    closed: bool = False
 
 
 class ClosePendingMessage(BaseModel):
diff --git a/packages/jupyter-ai/jupyter_ai/tests/test_handlers.py b/packages/jupyter-ai/jupyter_ai/tests/test_handlers.py
index 496cde371..6fab65763 100644
--- a/packages/jupyter-ai/jupyter_ai/tests/test_handlers.py
+++ b/packages/jupyter-ai/jupyter_ai/tests/test_handlers.py
@@ -1,8 +1,27 @@
+import logging
 import os
 import stat
+from typing import Optional
 from unittest import mock
 
-from jupyter_ai.chat_handlers import learn
+import pytest
+from jupyter_ai.chat_handlers import DefaultChatHandler, learn
+from jupyter_ai.config_manager import ConfigManager
+from jupyter_ai.handlers import RootChatHandler
+from jupyter_ai.models import (
+    AgentStreamChunkMessage,
+    AgentStreamMessage,
+    ChatClient,
+    ClosePendingMessage,
+    HumanChatMessage,
+    Message,
+    PendingMessage,
+    Persona,
+)
+from jupyter_ai_magics import BaseProvider
+from langchain_community.llms import FakeListLLM
+from tornado.httputil import HTTPServerRequest
+from tornado.web import Application
 
 
 class MockLearnHandler(learn.LearnChatHandler):
@@ -10,6 +29,69 @@ def __init__(self):
         pass
 
 
+class MockProvider(BaseProvider, FakeListLLM):
+    id = "my_provider"
+    name = "My Provider"
+    model_id_key = "model"
+    models = ["model"]
+    should_raise: Optional[bool]
+
+    def __init__(self, **kwargs):
+        if "responses" not in kwargs:
+            kwargs["responses"] = ["Test response"]
+        super().__init__(**kwargs)
+
+    def astream(self, *args, **kwargs):
+        if self.should_raise:
+            raise TestException()
+        return super().astream(*args, **kwargs)
+
+
+class TestDefaultChatHandler(DefaultChatHandler):
+    def __init__(self, lm_provider=None, lm_provider_params=None):
+        self.request = HTTPServerRequest()
+        self.application = Application()
+        self.messages = []
+        self.tasks = []
+        config_manager = mock.create_autospec(ConfigManager)
+        config_manager.lm_provider = lm_provider or MockProvider
+        config_manager.lm_provider_params = lm_provider_params or {"model_id": "model"}
+        config_manager.persona = Persona(name="test", avatar_route="test")
+
+        def broadcast_message(message: Message) -> None:
+            self.messages.append(message)
+
+        root_handler = mock.create_autospec(RootChatHandler)
+        root_handler.broadcast_message = broadcast_message
+
+        super().__init__(
+            log=logging.getLogger(__name__),
+            config_manager=config_manager,
+            root_chat_handlers={"root": root_handler},
+            model_parameters={},
+            chat_history=[],
+            root_dir="",
+            preferred_dir="",
+            dask_client_future=None,
+        )
+
+
+class TestException(Exception):
+    pass
+
+
+@pytest.fixture
+def chat_client():
+    return ChatClient(
+        id=0, username="test", initials="test", name="test", display_name="test"
+    )
+
+
+@pytest.fixture
+def human_chat_message(chat_client):
+    return HumanChatMessage(id="test", time=0, body="test message", client=chat_client)
+
+
 def test_learn_index_permissions(tmp_path):
     test_dir = tmp_path / "test"
     with mock.patch.object(learn, "INDEX_SAVE_DIR", new=test_dir):
@@ -19,6 +101,62 @@ def test_learn_index_permissions(tmp_path):
         assert stat.filemode(mode) == "drwx------"
 
 
+async def test_default_closes_pending_on_success(human_chat_message):
+    handler = TestDefaultChatHandler(
+        lm_provider=MockProvider,
+        lm_provider_params={
+            "model_id": "model",
+            "should_raise": False,
+        },
+    )
+    await handler.process_message(human_chat_message)
+
+    # >=2 because there are additional stream messages that follow
+    assert len(handler.messages) >= 2
+    assert isinstance(handler.messages[0], PendingMessage)
+    assert isinstance(handler.messages[1], ClosePendingMessage)
+
+
+async def test_default_closes_pending_on_error(human_chat_message):
+    handler = TestDefaultChatHandler(
+        lm_provider=MockProvider,
+        lm_provider_params={
+            "model_id": "model",
+            "should_raise": True,
+        },
+    )
+    with pytest.raises(TestException):
+        await handler.process_message(human_chat_message)
+
+    assert len(handler.messages) == 2
+    assert isinstance(handler.messages[0], PendingMessage)
+    assert isinstance(handler.messages[1], ClosePendingMessage)
+
+
+async def test_sends_closing_message_at_most_once(human_chat_message):
+    handler = TestDefaultChatHandler(
+        lm_provider=MockProvider,
+        lm_provider_params={
+            "model_id": "model",
+            "should_raise": False,
+        },
+    )
+    message = handler.start_pending("Flushing Pipe Network")
+    assert len(handler.messages) == 1
+    assert isinstance(handler.messages[0], PendingMessage)
+    assert not message.closed
+
+    handler.close_pending(message)
+    assert len(handler.messages) == 2
+    assert isinstance(handler.messages[1], ClosePendingMessage)
+    assert message.closed
+
+    # closing an already closed message does not lead to another broadcast
+    handler.close_pending(message)
+    assert len(handler.messages) == 2
+    assert message.closed
+
+
 # TODO
 # import json
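
For readers skimming the patch, here is a small, self-contained sketch (not part of the changeset) of the pattern the base.py changes implement: pending() now yields the pending message, and close_pending() becomes idempotent via the new closed flag, so a handler can close the message early (as DefaultChatHandler does when the first streamed chunk arrives) without a duplicate ClosePendingMessage being broadcast when the context manager exits. FakeHandler and FakePendingMessage below are illustrative stand-ins, not jupyter-ai's real classes.

import contextlib
from dataclasses import dataclass, field


@dataclass
class FakePendingMessage:
    # Mirrors the new `closed` field added to PendingMessage in models.py.
    id: str
    closed: bool = False


@dataclass
class FakeHandler:
    broadcasts: list = field(default_factory=list)

    def start_pending(self, text: str) -> FakePendingMessage:
        msg = FakePendingMessage(id=text)
        self.broadcasts.append(("pending", msg.id))
        return msg

    def close_pending(self, msg: FakePendingMessage) -> None:
        # Idempotent close, mirroring the early-return guard added to
        # BaseChatHandler.close_pending().
        if msg.closed:
            return
        self.broadcasts.append(("close_pending", msg.id))
        msg.closed = True

    @contextlib.contextmanager
    def pending(self, text: str):
        # Yield the message so callers can close it early themselves.
        msg = self.start_pending(text)
        try:
            yield msg
        finally:
            if not msg.closed:
                self.close_pending(msg)


handler = FakeHandler()
with handler.pending("Generating response") as msg:
    handler.close_pending(msg)  # early close inside the block
# Exactly one closing broadcast, even though the context manager also exited.
assert handler.broadcasts == [
    ("pending", "Generating response"),
    ("close_pending", "Generating response"),
]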