Commit

Fix streaming, add minimal tests

krassowski committed Jan 21, 2024
1 parent 6de2047 commit 255c4df

Showing 4 changed files with 120 additions and 8 deletions.
6 changes: 2 additions & 4 deletions packages/jupyter-ai/jupyter_ai/completions/handlers/base.py
@@ -137,7 +137,7 @@ async def handle_exc(self, e: Exception, request: InlineCompletionRequest):
"""
Handles an exception raised in either `handle_request()` or
`handle_stream_request()`. This base class provides a default
implementation, which may be overriden by subclasses.
implementation, which may be overridden by subclasses.
"""
error = CompletionError(
type=e.__class__.__name__,
@@ -162,8 +162,6 @@ async def _handle_request(self, request: InlineCompletionRequest):
     async def _handle_stream_request(self, request: InlineCompletionRequest):
         """Private wrapper around `self.handle_stream_request()`."""
         start = time.time()
-        await self._handle_stream_request(request)
-        async for chunk in self.stream(request):
-            self.write_message(chunk.dict())
+        await self.handle_stream_request(request)
         latency_ms = round((time.time() - start) * 1000)
         self.log.info(f"Inline completion streaming completed in {latency_ms} ms.")
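
For context on the second hunk: before this commit the private wrapper awaited itself, recursing without bound, and then tried to drive the stream directly; the fix delegates to the public `handle_stream_request()`, which subclasses implement. A minimal, self-contained sketch of that delegation pattern (simplified stand-ins, not the actual jupyter-ai classes):

import asyncio
import time


class BaseSketch:
    """Stand-in for the base handler: times the run, delegates the work."""

    async def _handle_stream_request(self, request):
        start = time.time()
        # Pre-fix, this line awaited self._handle_stream_request(request),
        # so the wrapper recursed into itself instead of running the handler.
        await self.handle_stream_request(request)
        latency_ms = round((time.time() - start) * 1000)
        print(f"Inline completion streaming completed in {latency_ms} ms.")

    async def handle_stream_request(self, request):
        raise NotImplementedError


class StreamingSketch(BaseSketch):
    async def handle_stream_request(self, request):
        for chunk in ("Test ", "response"):  # stands in for streamed LLM tokens
            print("chunk:", chunk)


asyncio.run(StreamingSketch()._handle_stream_request(request=None))
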
6 changes: 2 additions & 4 deletions packages/jupyter-ai/jupyter_ai/completions/handlers/default.py
@@ -73,9 +73,7 @@ def create_llm_chain(
         self.llm = llm
         self.llm_chain = prompt_template | llm | StrOutputParser()
 
-    async def handle_request(
-        self, request: InlineCompletionRequest
-    ) -> InlineCompletionReply:
+    async def handle_request(self, request: InlineCompletionRequest) -> None:
         """Handles an inline completion request without streaming."""
         self.get_llm_chain()
         model_arguments = self._template_inputs_from_request(request)
@@ -111,7 +109,7 @@ def _write_incomplete_reply(self, request: InlineCompletionRequest):

     async def handle_stream_request(self, request: InlineCompletionRequest):
         # first, send empty initial reply.
-        self._write_incomplete_reply()
+        self._write_incomplete_reply(request)
 
         # then, generate and stream LLM output over this connection.
         self.get_llm_chain()
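
Both default.py hunks fix call contracts: `handle_request` now writes its reply over the connection and returns `None` instead of an `InlineCompletionReply`, and `handle_stream_request` passes the `request` argument that `_write_incomplete_reply` requires. A minimal reproduction of the missing-argument bug (a hypothetical stand-in, not the real handler):

class HandlerSketch:
    def _write_incomplete_reply(self, request):
        # The real method uses the request to address the reply it writes.
        print(f"opening stream for request #{request}")


handler = HandlerSketch()
handler._write_incomplete_reply(1)  # post-fix call: works
try:
    handler._write_incomplete_reply()  # pre-fix call
except TypeError as e:
    print(e)  # missing 1 required positional argument: 'request'
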
packages/jupyter-ai/jupyter_ai/tests/completions/__init__.py
Empty file.
116 changes: 116 additions & 0 deletions packages/jupyter-ai/jupyter_ai/tests/completions/test_handlers.py
@@ -0,0 +1,116 @@
import json
from types import SimpleNamespace

from jupyter_ai.completions.handlers.default import DefaultInlineCompletionHandler
from jupyter_ai.completions.models import InlineCompletionRequest
from jupyter_ai_magics import BaseProvider
from langchain_community.llms import FakeListLLM
from pytest import fixture
from tornado.httputil import HTTPServerRequest
from tornado.web import Application


class MockProvider(BaseProvider, FakeListLLM):
id = "my_provider"
name = "My Provider"
model_id_key = "model"
models = ["model"]

def __init__(self, **kwargs):
kwargs["responses"] = ["Test response"]
super().__init__(**kwargs)


class MockCompletionHandler(DefaultInlineCompletionHandler):
def __init__(self):
self.request = HTTPServerRequest()
self.application = Application()
self.messages = []
self.tasks = []
self.settings["jai_config_manager"] = SimpleNamespace(
lm_provider=MockProvider, lm_provider_params={"model_id": "model"}
)
self.settings["jai_event_loop"] = SimpleNamespace(
create_task=lambda x: self.tasks.append(x)
)
self.settings["model_parameters"] = {}
self.llm_params = {}
self.create_llm_chain(MockProvider, {"model_id": "model"})

def write_message(self, message: str) -> None: # type: ignore
self.messages.append(message)

async def handle_exc(self, e: Exception, _request: InlineCompletionRequest):
        # raise all exceptions during testing rather than replying with a CompletionError
raise e


@fixture
def inline_handler() -> MockCompletionHandler:
return MockCompletionHandler()


async def test_on_message(inline_handler):
request = InlineCompletionRequest(
number=1, prefix="", suffix="", mime="", stream=False
)
# Test end to end, without checking details of the replies,
# which are tested in appropriate method unit tests.
await inline_handler.on_message(json.dumps(dict(request)))
assert len(inline_handler.tasks) == 1
await inline_handler.tasks[0]
assert len(inline_handler.messages) == 1


async def test_on_message_stream(inline_handler):
stream_request = InlineCompletionRequest(
number=1, prefix="", suffix="", mime="", stream=True
)
# Test end to end, without checking details of the replies,
# which are tested in appropriate method unit tests.
await inline_handler.on_message(json.dumps(dict(stream_request)))
assert len(inline_handler.tasks) == 1
await inline_handler.tasks[0]
assert len(inline_handler.messages) == 3


async def test_handle_request(inline_handler):
dummy_request = InlineCompletionRequest(
number=1, prefix="", suffix="", mime="", stream=False
)
await inline_handler.handle_request(dummy_request)
# should write a single reply
assert len(inline_handler.messages) == 1
# reply should contain a single suggestion
suggestions = inline_handler.messages[0].list.items
assert len(suggestions) == 1
# the suggestion should include insert text from LLM
assert suggestions[0].insertText == "Test response"


async def test_handle_stream_request(inline_handler):
inline_handler.llm_chain = FakeListLLM(responses=["test"])
dummy_request = InlineCompletionRequest(
number=1, prefix="", suffix="", mime="", stream=True
)
await inline_handler.handle_stream_request(dummy_request)

# should write three replies
assert len(inline_handler.messages) == 3

# first reply should be empty to start the stream
first = inline_handler.messages[0].list.items[0]
assert first.insertText == ""
assert first.isIncomplete == True

# second reply should be a chunk containing the token
second = inline_handler.messages[1]
assert second.type == "stream"
assert second.response.insertText == "Test response"
assert second.done == False

# third reply should be a closing chunk
third = inline_handler.messages[2]
assert third.type == "stream"
assert third.response.insertText == "Test response"
assert third.done == True
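
Read together, the three messages asserted above trace the streaming protocol: an initial incomplete reply opens the stream, stream chunks carry the accumulated output, and the final chunk is marked done. Illustrative literals for what the mock's `messages` list holds (shapes inferred from the assertions; the real objects are pydantic models like `InlineCompletionRequest` above, not plain dicts):

messages = [
    # 1) initial reply: one empty suggestion, flagged incomplete to open the stream
    {"list": {"items": [{"insertText": "", "isIncomplete": True}]}},
    # 2) stream chunk carrying the accumulated LLM output
    {"type": "stream", "response": {"insertText": "Test response"}, "done": False},
    # 3) closing chunk: same text, with done=True to end the stream
    {"type": "stream", "response": {"insertText": "Test response"}, "done": True},
]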
