diff --git a/packages/jupyter-ai/jupyter_ai/completions/handlers/default.py b/packages/jupyter-ai/jupyter_ai/completions/handlers/default.py index eb03df156..9d7e7915c 100644 --- a/packages/jupyter-ai/jupyter_ai/completions/handlers/default.py +++ b/packages/jupyter-ai/jupyter_ai/completions/handlers/default.py @@ -91,6 +91,8 @@ async def handle_stream_request(self, request: InlineCompletionRequest): continue else: suggestion = self._post_process_suggestion(suggestion, request) + elif suggestion.rstrip().endswith("```"): + suggestion = self._post_process_suggestion(suggestion, request) self.write_message( InlineCompletionStreamChunk( type="stream", @@ -151,4 +153,9 @@ def _post_process_suggestion( if suggestion.startswith(request.prefix): suggestion = suggestion[len(request.prefix) :] break + + # check if the suggestion ends with a closing markdown identifier and remove it + if suggestion.rstrip().endswith("```"): + suggestion = suggestion.rstrip()[:-3].rstrip() + return suggestion diff --git a/packages/jupyter-ai/jupyter_ai/tests/completions/test_handlers.py b/packages/jupyter-ai/jupyter_ai/tests/completions/test_handlers.py index 1b950af74..fd2b2666c 100644 --- a/packages/jupyter-ai/jupyter_ai/tests/completions/test_handlers.py +++ b/packages/jupyter-ai/jupyter_ai/tests/completions/test_handlers.py @@ -1,6 +1,7 @@ import json from types import SimpleNamespace +import pytest from jupyter_ai.completions.handlers.default import DefaultInlineCompletionHandler from jupyter_ai.completions.models import InlineCompletionRequest from jupyter_ai_magics import BaseProvider @@ -17,7 +18,8 @@ class MockProvider(BaseProvider, FakeListLLM): models = ["model"] def __init__(self, **kwargs): - kwargs["responses"] = ["Test response"] + if not "responses" in kwargs: + kwargs["responses"] = ["Test response"] super().__init__(**kwargs) @@ -34,7 +36,7 @@ def __init__(self): create_task=lambda x: self.tasks.append(x) ) self.settings["model_parameters"] = {} - self.llm_params = {} + self.llm_params = {"model_id": "model"} self.create_llm_chain(MockProvider, {"model_id": "model"}) def write_message(self, message: str) -> None: # type: ignore @@ -88,8 +90,45 @@ async def test_handle_request(inline_handler): assert suggestions[0].insertText == "Test response" +@pytest.mark.parametrize( + "response,expected_suggestion", + [ + ("```python\nTest python code\n```", "Test python code"), + ("```\ntest\n```\n \n", "test"), + ("```hello```world```", "hello```world"), + ], +) +async def test_handle_request_with_spurious_fragments(response, expected_suggestion): + inline_handler = MockCompletionHandler() + inline_handler.create_llm_chain( + MockProvider, + { + "model_id": "model", + "responses": [response], + }, + ) + dummy_request = InlineCompletionRequest( + number=1, prefix="", suffix="", mime="", stream=False + ) + + await inline_handler.handle_request(dummy_request) + # should write a single reply + assert len(inline_handler.messages) == 1 + # reply should contain a single suggestion + suggestions = inline_handler.messages[0].list.items + assert len(suggestions) == 1 + # the suggestion should include insert text from LLM without spurious fragments + assert suggestions[0].insertText == expected_suggestion + + async def test_handle_stream_request(inline_handler): - inline_handler.llm_chain = FakeListLLM(responses=["test"]) + inline_handler.create_llm_chain( + MockProvider, + { + "model_id": "model", + "responses": ["test"], + }, + ) dummy_request = InlineCompletionRequest( number=1, prefix="", suffix="", mime="", stream=True ) @@ -106,11 +145,11 @@ async def test_handle_stream_request(inline_handler): # second reply should be a chunk containing the token second = inline_handler.messages[1] assert second.type == "stream" - assert second.response.insertText == "Test response" + assert second.response.insertText == "test" assert second.done == False # third reply should be a closing chunk third = inline_handler.messages[2] assert third.type == "stream" - assert third.response.insertText == "Test response" + assert third.response.insertText == "test" assert third.done == True