Commit fdf660d: Remove closing markdown identifiers (#686)

bartleusink committed Apr 11, 2024
1 parent c12b7fd
Showing 2 changed files with 41 additions and 5 deletions.

@@ -91,6 +91,8 @@ async def handle_stream_request(self, request: InlineCompletionRequest):
continue
else:
suggestion = self._post_process_suggestion(suggestion, request)
+ elif suggestion.endswith("```"):
+     suggestion = self._post_process_suggestion(suggestion, request)
self.write_message(
InlineCompletionStreamChunk(
type="stream",
@@ -151,4 +153,9 @@ def _post_process_suggestion(
if suggestion.startswith(request.prefix):
suggestion = suggestion[len(request.prefix) :]
break

+ # check if the suggestion ends with a closing markdown identifier and remove it
+ if suggestion.endswith("```"):
+     suggestion = suggestion[:-3].rstrip()

return suggestion
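
Taken together, these changes mean that a completion which the LLM wraps in a Markdown code block now has its closing ``` stripped before being sent to the client, in both the streaming and non-streaming paths. A minimal standalone sketch of that stripping behaviour (not the actual handler code; the helper name and example strings below are illustrative only):

def strip_closing_fence(suggestion: str) -> str:
    # mirror the post-processing added in this commit: drop a trailing
    # closing Markdown fence plus any whitespace left in front of it
    if suggestion.endswith("```"):
        suggestion = suggestion[:-3].rstrip()
    return suggestion

print(strip_closing_fence("Test python code\n```"))  # -> "Test python code"
print(strip_closing_fence("print('hi')"))            # unchanged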
packages/jupyter-ai/jupyter_ai/tests/completions/test_handlers.py (39 changes: 34 additions & 5 deletions)

@@ -17,7 +17,8 @@ class MockProvider(BaseProvider, FakeListLLM):
models = ["model"]

def __init__(self, **kwargs):
- kwargs["responses"] = ["Test response"]
+ if not "responses" in kwargs:
+     kwargs["responses"] = ["Test response"]
super().__init__(**kwargs)


@@ -34,7 +35,7 @@ def __init__(self):
create_task=lambda x: self.tasks.append(x)
)
self.settings["model_parameters"] = {}
- self.llm_params = {}
+ self.llm_params = {"model_id": "model"}
self.create_llm_chain(MockProvider, {"model_id": "model"})

def write_message(self, message: str) -> None: # type: ignore
@@ -88,8 +89,36 @@ async def test_handle_request(inline_handler):
assert suggestions[0].insertText == "Test response"


+ async def test_handle_request_with_spurious_fragments(inline_handler):
+     inline_handler.create_llm_chain(
+         MockProvider,
+         {
+             "model_id": "model",
+             "responses": ["```python\nTest python code\n```"],
+         },
+     )
+     dummy_request = InlineCompletionRequest(
+         number=1, prefix="", suffix="", mime="", stream=False
+     )
+
+     await inline_handler.handle_request(dummy_request)
+     # should write a single reply
+     assert len(inline_handler.messages) == 1
+     # reply should contain a single suggestion
+     suggestions = inline_handler.messages[0].list.items
+     assert len(suggestions) == 1
+     # the suggestion should include insert text from LLM without spurious fragments
+     assert suggestions[0].insertText == "Test python code"


async def test_handle_stream_request(inline_handler):
- inline_handler.llm_chain = FakeListLLM(responses=["test"])
+ inline_handler.create_llm_chain(
+     MockProvider,
+     {
+         "model_id": "model",
+         "responses": ["test"],
+     },
+ )
dummy_request = InlineCompletionRequest(
number=1, prefix="", suffix="", mime="", stream=True
)
@@ -106,11 +135,11 @@ async def test_handle_stream_request(inline_handler):
# second reply should be a chunk containing the token
second = inline_handler.messages[1]
assert second.type == "stream"
- assert second.response.insertText == "Test response"
+ assert second.response.insertText == "test"
assert second.done == False

# third reply should be a closing chunk
third = inline_handler.messages[2]
assert third.type == "stream"
- assert third.response.insertText == "Test response"
+ assert third.response.insertText == "test"
assert third.done == True
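
With the MockProvider change above, individual tests can now supply their own canned responses while existing tests keep the previous default. A rough usage sketch (assuming MockProvider accepts these keyword arguments the same way create_llm_chain forwards them; not taken from the repository):

# keeps the original default of ["Test response"]
provider = MockProvider(model_id="model")

# overrides the canned responses for a specific test, e.g. a completion
# wrapped in a Markdown code fence
provider = MockProvider(
    model_id="model",
    responses=["```python\nTest python code\n```"],
)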
