Skip to content

Commit

Permalink
Add a test case for closing pending messages
Browse files Browse the repository at this point in the history
  • Loading branch information
krassowski committed Jul 12, 2024
1 parent 8b6c82d commit e6b37b1
Showing 1 changed file with 115 additions and 1 deletion.
116 changes: 115 additions & 1 deletion packages/jupyter-ai/jupyter_ai/tests/test_handlers.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,97 @@
import logging
import os
import stat
from typing import Optional
from unittest import mock

from jupyter_ai.chat_handlers import learn
import pytest
from jupyter_ai.chat_handlers import DefaultChatHandler, learn
from jupyter_ai.config_manager import ConfigManager
from jupyter_ai.handlers import RootChatHandler
from jupyter_ai.models import (
AgentStreamChunkMessage,
AgentStreamMessage,
ChatClient,
ClosePendingMessage,
HumanChatMessage,
Message,
PendingMessage,
Persona,
)
from jupyter_ai_magics import BaseProvider
from langchain_community.llms import FakeListLLM
from tornado.httputil import HTTPServerRequest
from tornado.web import Application


class MockLearnHandler(learn.LearnChatHandler):
    """Stub of :class:`learn.LearnChatHandler` with a no-op constructor.

    The parent constructor wires up real chat infrastructure; skipping it
    lets tests instantiate the handler in isolation.
    """

    def __init__(self):
        # Deliberately do not call super().__init__() — no setup wanted.
        pass


class MockProvider(BaseProvider, FakeListLLM):
    """Fake LLM provider used by the handler tests.

    Combines the project's ``BaseProvider`` interface with LangChain's
    ``FakeListLLM`` so tests can drive canned responses without a real model.
    Set ``should_raise=True`` to make streaming fail with ``TestException``.
    """

    id = "my_provider"
    name = "My Provider"
    model_id_key = "model"
    models = ["model"]
    # Pydantic field: when truthy, astream() raises instead of streaming.
    should_raise: Optional[bool]

    def __init__(self, **kwargs):
        # Supply a single default canned response unless the caller set one.
        if "responses" not in kwargs:
            kwargs["responses"] = ["Test response"]
        super().__init__(**kwargs)

    def astream(self, *args, **kwargs):
        # Simulate a streaming failure when configured; otherwise defer to
        # FakeListLLM's streaming implementation.
        if self.should_raise:
            raise TestException()
        return super().astream(*args, **kwargs)


class TestDefaultChatHandler(DefaultChatHandler):
    """``DefaultChatHandler`` wired to mocks for use as a test helper.

    Every message the handler broadcasts is appended to ``self.messages`` so
    tests can assert on the exact message sequence (e.g. pending open/close).

    Parameters:
        lm_provider: provider class to use; defaults to ``MockProvider``.
        lm_provider_params: provider constructor kwargs; defaults to a
            minimal ``{"model_id": "model"}``.
    """

    # Pytest collects classes whose names start with "Test"; this is a helper,
    # not a test class, so opt out of collection (avoids PytestCollectionWarning
    # caused by the custom __init__).
    __test__ = False

    def __init__(self, lm_provider=None, lm_provider_params=None):
        self.request = HTTPServerRequest()
        self.application = Application()
        self.messages = []
        self.tasks = []

        # Autospec the config manager and pin the attributes the handler reads.
        config_manager = mock.create_autospec(ConfigManager)
        config_manager.lm_provider = lm_provider or MockProvider
        config_manager.lm_provider_params = lm_provider_params or {"model_id": "model"}
        config_manager.persona = Persona(name="test", avatar_route="test")

        def broadcast_message(message: Message) -> None:
            # Capture instead of broadcasting over a real websocket.
            self.messages.append(message)

        root_handler = mock.create_autospec(RootChatHandler)
        root_handler.broadcast_message = broadcast_message

        super().__init__(
            log=logging.getLogger(__name__),
            config_manager=config_manager,
            root_chat_handlers={"root": root_handler},
            model_parameters={},
            chat_history=[],
            root_dir="",
            preferred_dir="",
            dask_client_future=None,
        )


class TestException(Exception):
    """Sentinel exception raised by the mock provider in failure scenarios."""


@pytest.fixture
def chat_client():
    """A minimal ``ChatClient`` for authoring test chat messages."""
    client_fields = {
        "id": 0,
        "username": "test",
        "initials": "test",
        "name": "test",
        "display_name": "test",
    }
    return ChatClient(**client_fields)


@pytest.fixture
def human_chat_message(chat_client):
    """A ``HumanChatMessage`` sent by the ``chat_client`` fixture."""
    message = HumanChatMessage(
        id="test",
        time=0,
        body="test message",
        client=chat_client,
    )
    return message


def test_learn_index_permissions(tmp_path):
test_dir = tmp_path / "test"
with mock.patch.object(learn, "INDEX_SAVE_DIR", new=test_dir):
Expand All @@ -19,6 +101,38 @@ def test_learn_index_permissions(tmp_path):
assert stat.filemode(mode) == "drwx------"


async def test_default_closes_pending_on_success(human_chat_message):
    """The handler opens a pending message and closes it after a reply."""
    provider_params = {
        "model_id": "model",
        "should_raise": False,
    }
    handler = TestDefaultChatHandler(
        lm_provider=MockProvider,
        lm_provider_params=provider_params,
    )

    await handler.process_message(human_chat_message)

    # Stream messages follow the pending pair, so only require >= 2 and
    # check the open/close ordering of the first two.
    assert len(handler.messages) >= 2
    pending_open, pending_close = handler.messages[0], handler.messages[1]
    assert isinstance(pending_open, PendingMessage)
    assert isinstance(pending_close, ClosePendingMessage)


async def test_default_closes_pending_on_error(human_chat_message):
    """The pending message is closed even when the provider raises."""
    provider_params = {
        "model_id": "model",
        "should_raise": True,
    }
    handler = TestDefaultChatHandler(
        lm_provider=MockProvider,
        lm_provider_params=provider_params,
    )

    with pytest.raises(TestException):
        await handler.process_message(human_chat_message)

    # Exactly the pending open/close pair — the error aborts streaming.
    assert len(handler.messages) == 2
    pending_open, pending_close = handler.messages[0], handler.messages[1]
    assert isinstance(pending_open, PendingMessage)
    assert isinstance(pending_close, ClosePendingMessage)


# TODO
# import json

Expand Down

0 comments on commit e6b37b1

Please sign in to comment.