Fix removal of pending messages on error (#888)
* Add a test case for closing pending messages

* Remove pending message on error

* Add a test for not sending closing message twice

* Review: make `close_pending` idempotent
krassowski authored Jul 12, 2024
1 parent 8b6c82d commit 8ca4854
Showing 4 changed files with 173 additions and 29 deletions.
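For orientation, here is a minimal, self-contained sketch of the pattern the diffs below implement. The `PendingMessage` and `Handler` classes here are simplified stand-ins, not the actual Jupyter AI implementations: `pending()` yields the pending message and always closes it in a `finally` block, while `close_pending()` tracks a `closed` flag so a repeated call does not broadcast a second `ClosePendingMessage`.

import contextlib
from dataclasses import dataclass


@dataclass
class PendingMessage:
    id: str
    body: str
    closed: bool = False  # flag added by this commit


class Handler:
    def __init__(self):
        self.broadcasts = []  # stand-in for RootChatHandler.broadcast_message

    def start_pending(self, text: str) -> PendingMessage:
        msg = PendingMessage(id="p1", body=text)
        self.broadcasts.append(("pending", msg.id))
        return msg

    def close_pending(self, msg: PendingMessage) -> None:
        if msg.closed:  # idempotent: a second call is a no-op
            return
        self.broadcasts.append(("close_pending", msg.id))
        msg.closed = True

    @contextlib.contextmanager
    def pending(self, text: str):
        msg = self.start_pending(text)
        try:
            yield msg  # callers may close the message early themselves
        finally:
            if not msg.closed:  # closed even if the body raised
                self.close_pending(msg)


handler = Handler()
try:
    with handler.pending("Generating response"):
        raise RuntimeError("provider error")
except RuntimeError:
    pass

# exactly one pending/close pair was broadcast, despite the error
assert handler.broadcasts == [("pending", "p1"), ("close_pending", "p1")]

Running the sketch, the broadcast log contains exactly one pending/close pair even though the body raised, which is the behavior the tests below assert against the real handlers.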
10 changes: 8 additions & 2 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/base.py
@@ -278,6 +278,9 @@ def close_pending(self, pending_msg: PendingMessage):
         """
         Closes a pending message.
         """
+        if pending_msg.closed:
+            return
+
         close_pending_msg = ClosePendingMessage(
             id=pending_msg.id,
         )
@@ -289,6 +292,8 @@ def close_pending(self, pending_msg: PendingMessage):
             handler.broadcast_message(close_pending_msg)
             break

+        pending_msg.closed = True
+
     @contextlib.contextmanager
     def pending(self, text: str, ellipsis: bool = True):
         """
@@ -297,9 +302,10 @@ def pending(self, text: str, ellipsis: bool = True):
         """
         pending_msg = self.start_pending(text, ellipsis=ellipsis)
         try:
-            yield
+            yield pending_msg
         finally:
-            self.close_pending(pending_msg)
+            if not pending_msg.closed:
+                self.close_pending(pending_msg)

     def get_llm_chain(self):
         lm_provider = self.config_manager.lm_provider
51 changes: 25 additions & 26 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/default.py
@@ -97,29 +97,28 @@ async def process_message(self, message: HumanChatMessage):
         received_first_chunk = False

         # start with a pending message
-        pending_message = self.start_pending("Generating response")
-
-        # stream response in chunks. this works even if a provider does not
-        # implement streaming, as `astream()` defaults to yielding `_call()`
-        # when `_stream()` is not implemented on the LLM class.
-        async for chunk in self.llm_chain.astream(
-            {"input": message.body},
-            config={"configurable": {"session_id": "static_session"}},
-        ):
-            if not received_first_chunk:
-                # when receiving the first chunk, close the pending message and
-                # start the stream.
-                self.close_pending(pending_message)
-                stream_id = self._start_stream(human_msg=message)
-                received_first_chunk = True
-
-            if isinstance(chunk, AIMessageChunk):
-                self._send_stream_chunk(stream_id, chunk.content)
-            elif isinstance(chunk, str):
-                self._send_stream_chunk(stream_id, chunk)
-            else:
-                self.log.error(f"Unrecognized type of chunk yielded: {type(chunk)}")
-                break
-
-        # complete stream after all chunks have been streamed
-        self._send_stream_chunk(stream_id, "", complete=True)
+        with self.pending("Generating response") as pending_message:
+            # stream response in chunks. this works even if a provider does not
+            # implement streaming, as `astream()` defaults to yielding `_call()`
+            # when `_stream()` is not implemented on the LLM class.
+            async for chunk in self.llm_chain.astream(
+                {"input": message.body},
+                config={"configurable": {"session_id": "static_session"}},
+            ):
+                if not received_first_chunk:
+                    # when receiving the first chunk, close the pending message and
+                    # start the stream.
+                    self.close_pending(pending_message)
+                    stream_id = self._start_stream(human_msg=message)
+                    received_first_chunk = True
+
+                if isinstance(chunk, AIMessageChunk):
+                    self._send_stream_chunk(stream_id, chunk.content)
+                elif isinstance(chunk, str):
+                    self._send_stream_chunk(stream_id, chunk)
+                else:
+                    self.log.error(f"Unrecognized type of chunk yielded: {type(chunk)}")
+                    break
+
+            # complete stream after all chunks have been streamed
+            self._send_stream_chunk(stream_id, "", complete=True)
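In this streaming path, `close_pending()` already runs inside the loop when the first chunk arrives, and the `pending()` context manager calls it again on exit; the new `closed` flag is what keeps that second call from broadcasting a duplicate `ClosePendingMessage`. Continuing the hypothetical sketch above, the early-close path looks like this:

handler = Handler()
with handler.pending("Generating response") as msg:
    handler.close_pending(msg)  # e.g. closed early when the first chunk arrives
    # ... stream the remaining chunks ...

# the context manager's finally block does not broadcast a second close
assert handler.broadcasts == [("pending", "p1"), ("close_pending", "p1")]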
1 change: 1 addition & 0 deletions packages/jupyter-ai/jupyter_ai/models.py
@@ -103,6 +103,7 @@ class PendingMessage(BaseModel):
     body: str
     persona: Persona
     ellipsis: bool = True
+    closed: bool = False


 class ClosePendingMessage(BaseModel):
140 changes: 139 additions & 1 deletion packages/jupyter-ai/jupyter_ai/tests/test_handlers.py
@@ -1,15 +1,97 @@
+import logging
 import os
 import stat
+from typing import Optional
 from unittest import mock

-from jupyter_ai.chat_handlers import learn
+import pytest
+from jupyter_ai.chat_handlers import DefaultChatHandler, learn
+from jupyter_ai.config_manager import ConfigManager
+from jupyter_ai.handlers import RootChatHandler
+from jupyter_ai.models import (
+    AgentStreamChunkMessage,
+    AgentStreamMessage,
+    ChatClient,
+    ClosePendingMessage,
+    HumanChatMessage,
+    Message,
+    PendingMessage,
+    Persona,
+)
+from jupyter_ai_magics import BaseProvider
+from langchain_community.llms import FakeListLLM
+from tornado.httputil import HTTPServerRequest
+from tornado.web import Application


 class MockLearnHandler(learn.LearnChatHandler):
     def __init__(self):
         pass


+class MockProvider(BaseProvider, FakeListLLM):
+    id = "my_provider"
+    name = "My Provider"
+    model_id_key = "model"
+    models = ["model"]
+    should_raise: Optional[bool]
+
+    def __init__(self, **kwargs):
+        if "responses" not in kwargs:
+            kwargs["responses"] = ["Test response"]
+        super().__init__(**kwargs)
+
+    def astream(self, *args, **kwargs):
+        if self.should_raise:
+            raise TestException()
+        return super().astream(*args, **kwargs)
+
+
+class TestDefaultChatHandler(DefaultChatHandler):
+    def __init__(self, lm_provider=None, lm_provider_params=None):
+        self.request = HTTPServerRequest()
+        self.application = Application()
+        self.messages = []
+        self.tasks = []
+        config_manager = mock.create_autospec(ConfigManager)
+        config_manager.lm_provider = lm_provider or MockProvider
+        config_manager.lm_provider_params = lm_provider_params or {"model_id": "model"}
+        config_manager.persona = Persona(name="test", avatar_route="test")
+
+        def broadcast_message(message: Message) -> None:
+            self.messages.append(message)
+
+        root_handler = mock.create_autospec(RootChatHandler)
+        root_handler.broadcast_message = broadcast_message
+
+        super().__init__(
+            log=logging.getLogger(__name__),
+            config_manager=config_manager,
+            root_chat_handlers={"root": root_handler},
+            model_parameters={},
+            chat_history=[],
+            root_dir="",
+            preferred_dir="",
+            dask_client_future=None,
+        )
+
+
+class TestException(Exception):
+    pass
+
+
+@pytest.fixture
+def chat_client():
+    return ChatClient(
+        id=0, username="test", initials="test", name="test", display_name="test"
+    )
+
+
+@pytest.fixture
+def human_chat_message(chat_client):
+    return HumanChatMessage(id="test", time=0, body="test message", client=chat_client)
+
+
 def test_learn_index_permissions(tmp_path):
     test_dir = tmp_path / "test"
     with mock.patch.object(learn, "INDEX_SAVE_DIR", new=test_dir):
@@ -19,6 +101,62 @@ def test_learn_index_permissions(tmp_path):
         assert stat.filemode(mode) == "drwx------"


+async def test_default_closes_pending_on_success(human_chat_message):
+    handler = TestDefaultChatHandler(
+        lm_provider=MockProvider,
+        lm_provider_params={
+            "model_id": "model",
+            "should_raise": False,
+        },
+    )
+    await handler.process_message(human_chat_message)
+
+    # >=2 because there are additional stream messages that follow
+    assert len(handler.messages) >= 2
+    assert isinstance(handler.messages[0], PendingMessage)
+    assert isinstance(handler.messages[1], ClosePendingMessage)
+
+
+async def test_default_closes_pending_on_error(human_chat_message):
+    handler = TestDefaultChatHandler(
+        lm_provider=MockProvider,
+        lm_provider_params={
+            "model_id": "model",
+            "should_raise": True,
+        },
+    )
+    with pytest.raises(TestException):
+        await handler.process_message(human_chat_message)
+
+    assert len(handler.messages) == 2
+    assert isinstance(handler.messages[0], PendingMessage)
+    assert isinstance(handler.messages[1], ClosePendingMessage)
+
+
+async def test_sends_closing_message_at_most_once(human_chat_message):
+    handler = TestDefaultChatHandler(
+        lm_provider=MockProvider,
+        lm_provider_params={
+            "model_id": "model",
+            "should_raise": False,
+        },
+    )
+    message = handler.start_pending("Flushing Pipe Network")
+    assert len(handler.messages) == 1
+    assert isinstance(handler.messages[0], PendingMessage)
+    assert not message.closed
+
+    handler.close_pending(message)
+    assert len(handler.messages) == 2
+    assert isinstance(handler.messages[1], ClosePendingMessage)
+    assert message.closed
+
+    # closing an already closed message does not lead to another broadcast
+    handler.close_pending(message)
+    assert len(handler.messages) == 2
+    assert message.closed
+
+
 # TODO
 # import json