From e6b37b1ecd7eb7f7ffa79464c2577b8f6e7f121e Mon Sep 17 00:00:00 2001 From: krassowski <5832902+krassowski@users.noreply.github.com> Date: Wed, 10 Jul 2024 19:35:56 +0100 Subject: [PATCH] Add a test case for closing pending messages --- .../jupyter_ai/tests/test_handlers.py | 116 +++++++++++++++++- 1 file changed, 115 insertions(+), 1 deletion(-) diff --git a/packages/jupyter-ai/jupyter_ai/tests/test_handlers.py b/packages/jupyter-ai/jupyter_ai/tests/test_handlers.py index 496cde371..a38ea6a26 100644 --- a/packages/jupyter-ai/jupyter_ai/tests/test_handlers.py +++ b/packages/jupyter-ai/jupyter_ai/tests/test_handlers.py @@ -1,8 +1,27 @@ +import logging import os import stat +from typing import Optional from unittest import mock -from jupyter_ai.chat_handlers import learn +import pytest +from jupyter_ai.chat_handlers import DefaultChatHandler, learn +from jupyter_ai.config_manager import ConfigManager +from jupyter_ai.handlers import RootChatHandler +from jupyter_ai.models import ( + AgentStreamChunkMessage, + AgentStreamMessage, + ChatClient, + ClosePendingMessage, + HumanChatMessage, + Message, + PendingMessage, + Persona, +) +from jupyter_ai_magics import BaseProvider +from langchain_community.llms import FakeListLLM +from tornado.httputil import HTTPServerRequest +from tornado.web import Application class MockLearnHandler(learn.LearnChatHandler): @@ -10,6 +29,69 @@ def __init__(self): pass +class MockProvider(BaseProvider, FakeListLLM): + id = "my_provider" + name = "My Provider" + model_id_key = "model" + models = ["model"] + should_raise: Optional[bool] + + def __init__(self, **kwargs): + if "responses" not in kwargs: + kwargs["responses"] = ["Test response"] + super().__init__(**kwargs) + + def astream(self, *args, **kwargs): + if self.should_raise: + raise TestException() + return super().astream(*args, **kwargs) + + +class TestDefaultChatHandler(DefaultChatHandler): + def __init__(self, lm_provider=None, lm_provider_params=None): + self.request = HTTPServerRequest() + self.application = Application() + self.messages = [] + self.tasks = [] + config_manager = mock.create_autospec(ConfigManager) + config_manager.lm_provider = lm_provider or MockProvider + config_manager.lm_provider_params = lm_provider_params or {"model_id": "model"} + config_manager.persona = Persona(name="test", avatar_route="test") + + def broadcast_message(message: Message) -> None: + self.messages.append(message) + + root_handler = mock.create_autospec(RootChatHandler) + root_handler.broadcast_message = broadcast_message + + super().__init__( + log=logging.getLogger(__name__), + config_manager=config_manager, + root_chat_handlers={"root": root_handler}, + model_parameters={}, + chat_history=[], + root_dir="", + preferred_dir="", + dask_client_future=None, + ) + + +class TestException(Exception): + pass + + +@pytest.fixture +def chat_client(): + return ChatClient( + id=0, username="test", initials="test", name="test", display_name="test" + ) + + +@pytest.fixture +def human_chat_message(chat_client): + return HumanChatMessage(id="test", time=0, body="test message", client=chat_client) + + def test_learn_index_permissions(tmp_path): test_dir = tmp_path / "test" with mock.patch.object(learn, "INDEX_SAVE_DIR", new=test_dir): @@ -19,6 +101,38 @@ def test_learn_index_permissions(tmp_path): assert stat.filemode(mode) == "drwx------" +async def test_default_closes_pending_on_success(human_chat_message): + handler = TestDefaultChatHandler( + lm_provider=MockProvider, + lm_provider_params={ + "model_id": "model", + "should_raise": False, + }, + ) + await handler.process_message(human_chat_message) + + # >=2 because there are additional stream messages that follow + assert len(handler.messages) >= 2 + assert isinstance(handler.messages[0], PendingMessage) + assert isinstance(handler.messages[1], ClosePendingMessage) + + +async def test_default_closes_pending_on_error(human_chat_message): + handler = TestDefaultChatHandler( + lm_provider=MockProvider, + lm_provider_params={ + "model_id": "model", + "should_raise": True, + }, + ) + with pytest.raises(TestException): + await handler.process_message(human_chat_message) + + assert len(handler.messages) == 2 + assert isinstance(handler.messages[0], PendingMessage) + assert isinstance(handler.messages[1], ClosePendingMessage) + + # TODO # import json