From 39e8d2bf3eb27fec34170f46891855f2290fa529 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Micha=C5=82=20Krassowski?= <5832902+krassowski@users.noreply.github.com>
Date: Tue, 23 Apr 2024 21:23:49 +0100
Subject: [PATCH] Move methods generating completion replies to the provider
 (#717)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Move methods generating completion replies to provider

Add an example completion provider

Adjust name of variable to reflect new implementation

* Rename `_completion.py` → `completion_utils.py`
---
 docs/source/developers/index.md               |  78 +++++++++
 .../jupyter_ai_magics/completion_utils.py     |  52 ++++++
 .../jupyter_ai_magics/models/completion.py    |  81 +++++++++
 .../jupyter_ai_magics/providers.py            |  86 +++++++++-
 .../jupyter_ai/completions/handlers/base.py   |  30 ++--
 .../completions/handlers/default.py           | 161 ++----------------
 .../completions/handlers/llm_mixin.py         |  51 +++---
 .../jupyter_ai/completions/models.py          |  88 ++--------
 .../tests/completions/test_handlers.py        |  43 +++--
 9 files changed, 390 insertions(+), 280 deletions(-)
 create mode 100644 packages/jupyter-ai-magics/jupyter_ai_magics/completion_utils.py
 create mode 100644 packages/jupyter-ai-magics/jupyter_ai_magics/models/completion.py

diff --git a/docs/source/developers/index.md b/docs/source/developers/index.md
index c9062ad1d..aac923285 100644
--- a/docs/source/developers/index.md
+++ b/docs/source/developers/index.md
@@ -150,6 +150,84 @@ my-provider = "my_provider:MyEmbeddingsProvider"
 
 [Embeddings]: https://api.python.langchain.com/en/stable/embeddings/langchain_core.embeddings.Embeddings.html
 
+
+### Custom completion providers
+
+Any model provider derived from `BaseProvider` can be used as a completion provider.
+However, some providers may benefit from customizing the handling of completion requests.
+
+There are two asynchronous methods which can be overridden in subclasses of `BaseProvider`:
+- `generate_inline_completions`: takes a request (`InlineCompletionRequest`) and returns an `InlineCompletionReply`
+- `stream_inline_completions`: takes a request and yields an initiating reply (`InlineCompletionReply`) with `isIncomplete` set to `True`, followed by subsequent chunks (`InlineCompletionStreamChunk`)
+
+When streaming, all replies and chunks for a given invocation of the `stream_inline_completions()` method should include a constant and unique string token identifying the stream. All chunks except the last chunk for a given item should have the `done` value set to `False`.
+
+The following example demonstrates a custom completion provider that both returns multiple completions in a single reply and streams multiple completions concurrently.
+The implementation and explanation of the `merge_iterators` function used in this example can be found [here](https://stackoverflow.com/q/72445371/4877269).
+
+```python
+class MyCompletionProvider(BaseProvider, FakeListLLM):
+    id = "my_provider"
+    name = "My Provider"
+    model_id_key = "model"
+    models = ["model_a"]
+
+    def __init__(self, **kwargs):
+        kwargs["responses"] = ["This fake response will not be used for completion"]
+        super().__init__(**kwargs)
+
+    async def generate_inline_completions(self, request: InlineCompletionRequest):
+        return InlineCompletionReply(
+            list=InlineCompletionList(items=[
+                {"insertText": "An ant minding its own business"},
+                {"insertText": "A bug searching for a snack"}
+            ]),
+            reply_to=request.number,
+        )
+
+    async def stream_inline_completions(self, request: InlineCompletionRequest):
+        token_1 = f"t{request.number}s0"
+        token_2 = f"t{request.number}s1"
+
+        yield InlineCompletionReply(
+            list=InlineCompletionList(
+                items=[
+                    {"insertText": "An ", "isIncomplete": True, "token": token_1},
+                    {"insertText": "", "isIncomplete": True, "token": token_2}
+                ]
+            ),
+            reply_to=request.number,
+        )
+
+        # stream both suggestions concurrently using the `merge_iterators` helper linked above
+        async for reply in merge_iterators([
+            self._stream("elephant dancing in the rain", request.number, token_1, start_with="An"),
+            self._stream("A flock of birds flying around a mountain", request.number, token_2)
+        ]):
+            yield reply
+
+    async def _stream(self, sentence, request_number, token, start_with=""):
+        suggestion = start_with
+
+        for fragment in sentence.split():
+            await asyncio.sleep(0.75)
+            suggestion += " " + fragment
+            yield InlineCompletionStreamChunk(
+                type="stream",
+                response={"insertText": suggestion, "token": token},
+                reply_to=request_number,
+                done=False
+            )
+
+        # finally, send a message confirming that we are done
+        yield InlineCompletionStreamChunk(
+            type="stream",
+            response={"insertText": suggestion, "token": token},
+            reply_to=request_number,
+            done=True,
+        )
+```
+
 ## Prompt templates
 
 Each provider can define **prompt templates** for each supported format. A prompt
diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/completion_utils.py b/packages/jupyter-ai-magics/jupyter_ai_magics/completion_utils.py
new file mode 100644
index 000000000..204da5e7b
--- /dev/null
+++ b/packages/jupyter-ai-magics/jupyter_ai_magics/completion_utils.py
@@ -0,0 +1,52 @@
+from typing import Dict
+
+from .models.completion import InlineCompletionRequest
+
+
+def token_from_request(request: InlineCompletionRequest, suggestion: int):
+    """Generate a deterministic token (for matching streamed messages)
+    using the request number and suggestion number."""
+    return f"t{request.number}s{suggestion}"
+
+
+def template_inputs_from_request(request: InlineCompletionRequest) -> Dict:
+    suffix = request.suffix.strip()
+    filename = request.path.split("/")[-1] if request.path else "untitled"
+
+    return {
+        "prefix": request.prefix,
+        "suffix": suffix,
+        "language": request.language,
+        "filename": filename,
+        "stop": ["\n```"],
+    }
+
+
+def post_process_suggestion(suggestion: str, request: InlineCompletionRequest) -> str:
+    """Remove spurious fragments from the suggestion.
+
+    While most models (especially instruct and infill models) do not require
+    any pre-processing, some models, such as gpt-4, which only have chat APIs,
+    may require removing spurious fragments. This function uses heuristics
+    and request data to remove such fragments.
+ """ + # gpt-4 tends to add "```python" or similar + language = request.language or "python" + markdown_identifiers = {"ipython": ["ipython", "python", "py"]} + bad_openings = [ + f"```{identifier}" + for identifier in markdown_identifiers.get(language, [language]) + ] + ["```"] + for opening in bad_openings: + if suggestion.startswith(opening): + suggestion = suggestion[len(opening) :].lstrip() + # check for the prefix inclusion (only if there was a bad opening) + if suggestion.startswith(request.prefix): + suggestion = suggestion[len(request.prefix) :] + break + + # check if the suggestion ends with a closing markdown identifier and remove it + if suggestion.rstrip().endswith("```"): + suggestion = suggestion.rstrip()[:-3].rstrip() + + return suggestion diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/models/completion.py b/packages/jupyter-ai-magics/jupyter_ai_magics/models/completion.py new file mode 100644 index 000000000..147f6ceec --- /dev/null +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/models/completion.py @@ -0,0 +1,81 @@ +from typing import List, Literal, Optional + +from langchain.pydantic_v1 import BaseModel + + +class InlineCompletionRequest(BaseModel): + """Message send by client to request inline completions. + + Prefix/suffix implementation is used to avoid the need for synchronising + the notebook state at every key press (subject to change in future).""" + + # unique message ID generated by the client used to identify replies and + # to easily discard replies for older requests + number: int + # prefix should include full text of the current cell preceding the cursor + prefix: str + # suffix should include full text of the current cell preceding the cursor + suffix: str + # media type for the current language, e.g. `text/x-python` + mime: str + # whether to stream the response (if supported by the model) + stream: bool + # path to the notebook of file for which the completions are generated + path: Optional[str] + # language inferred from the document mime type (if possible) + language: Optional[str] + # identifier of the cell for which the completions are generated if in a notebook + # previous cells and following cells can be used to learn the wider context + cell_id: Optional[str] + + +class InlineCompletionItem(BaseModel): + """The inline completion suggestion to be displayed on the frontend. + + See JupyterLab `InlineCompletionItem` documentation for the details. 
+ """ + + insertText: str + filterText: Optional[str] + isIncomplete: Optional[bool] + token: Optional[str] + + +class CompletionError(BaseModel): + type: str + traceback: str + + +class InlineCompletionList(BaseModel): + """Reflection of JupyterLab's `IInlineCompletionList`.""" + + items: List[InlineCompletionItem] + + +class InlineCompletionReply(BaseModel): + """Message sent from model to client with the infill suggestions""" + + list: InlineCompletionList + # number of request for which we are replying + reply_to: int + error: Optional[CompletionError] + + +class InlineCompletionStreamChunk(BaseModel): + """Message sent from model to client with the infill suggestions""" + + type: Literal["stream"] = "stream" + response: InlineCompletionItem + reply_to: int + done: bool + error: Optional[CompletionError] + + +__all__ = [ + "InlineCompletionRequest", + "InlineCompletionItem", + "CompletionError", + "InlineCompletionList", + "InlineCompletionReply", + "InlineCompletionStreamChunk", +] diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py index c07061b1d..40764f232 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py @@ -5,7 +5,17 @@ import io import json from concurrent.futures import ThreadPoolExecutor -from typing import Any, ClassVar, Coroutine, Dict, List, Literal, Optional, Union +from typing import ( + Any, + AsyncIterator, + ClassVar, + Coroutine, + Dict, + List, + Literal, + Optional, + Union, +) from jsonpath_ng import parse from langchain.chat_models.base import BaseChatModel @@ -20,6 +30,8 @@ ) from langchain.pydantic_v1 import BaseModel, Extra, root_validator from langchain.schema import LLMResult +from langchain.schema.output_parser import StrOutputParser +from langchain.schema.runnable import Runnable from langchain.utils import get_from_dict_or_env from langchain_community.chat_models import ( BedrockChat, @@ -46,6 +58,13 @@ except: from pydantic.main import ModelMetaclass +from . import completion_utils as completion +from .models.completion import ( + InlineCompletionList, + InlineCompletionReply, + InlineCompletionRequest, + InlineCompletionStreamChunk, +) from .models.persona import Persona CHAT_SYSTEM_PROMPT = """ @@ -395,6 +414,71 @@ def is_chat_provider(self): def allows_concurrency(self): return True + async def generate_inline_completions( + self, request: InlineCompletionRequest + ) -> InlineCompletionReply: + chain = self._create_completion_chain() + model_arguments = completion.template_inputs_from_request(request) + suggestion = await chain.ainvoke(input=model_arguments) + suggestion = completion.post_process_suggestion(suggestion, request) + return InlineCompletionReply( + list=InlineCompletionList(items=[{"insertText": suggestion}]), + reply_to=request.number, + ) + + async def stream_inline_completions( + self, request: InlineCompletionRequest + ) -> AsyncIterator[InlineCompletionStreamChunk]: + chain = self._create_completion_chain() + token = completion.token_from_request(request, 0) + model_arguments = completion.template_inputs_from_request(request) + suggestion = "" + + # send an incomplete `InlineCompletionReply`, indicating to the + # client that LLM output is about to streamed across this connection. 
+ yield InlineCompletionReply( + list=InlineCompletionList( + items=[ + { + # insert text starts empty as we do not pre-generate any part + "insertText": "", + "isIncomplete": True, + "token": token, + } + ] + ), + reply_to=request.number, + ) + + async for fragment in chain.astream(input=model_arguments): + suggestion += fragment + if suggestion.startswith("```"): + if "\n" not in suggestion: + # we are not ready to apply post-processing + continue + else: + suggestion = completion.post_process_suggestion(suggestion, request) + elif suggestion.rstrip().endswith("```"): + suggestion = completion.post_process_suggestion(suggestion, request) + yield InlineCompletionStreamChunk( + type="stream", + response={"insertText": suggestion, "token": token}, + reply_to=request.number, + done=False, + ) + + # finally, send a message confirming that we are done + yield InlineCompletionStreamChunk( + type="stream", + response={"insertText": suggestion, "token": token}, + reply_to=request.number, + done=True, + ) + + def _create_completion_chain(self) -> Runnable: + prompt_template = self.get_completion_prompt_template() + return prompt_template | self | StrOutputParser() + class AI21Provider(BaseProvider, AI21): id = "ai21" diff --git a/packages/jupyter-ai/jupyter_ai/completions/handlers/base.py b/packages/jupyter-ai/jupyter_ai/completions/handlers/base.py index c52c308db..9eb4f845a 100644 --- a/packages/jupyter-ai/jupyter_ai/completions/handlers/base.py +++ b/packages/jupyter-ai/jupyter_ai/completions/handlers/base.py @@ -2,7 +2,7 @@ import time import traceback from asyncio import AbstractEventLoop -from typing import Any, AsyncIterator, Dict, Union +from typing import Union import tornado from jupyter_ai.completions.handlers.llm_mixin import LLMHandlerMixin @@ -14,7 +14,7 @@ InlineCompletionStreamChunk, ) from jupyter_server.base.handlers import JupyterHandler -from langchain.pydantic_v1 import BaseModel, ValidationError +from langchain.pydantic_v1 import ValidationError class BaseInlineCompletionHandler( @@ -27,12 +27,10 @@ class BaseInlineCompletionHandler( ## # Interface for subclasses ## - async def handle_request( - self, message: InlineCompletionRequest - ) -> InlineCompletionReply: + async def handle_request(self, message: InlineCompletionRequest) -> None: """ Handles an inline completion request, without streaming. Subclasses - must define this method and write a reply via `self.write_message()`. + must define this method and write a reply via `self.reply()`. The method definition does not need to be wrapped in a try/except block. """ @@ -40,14 +38,11 @@ async def handle_request( "The required method `self.handle_request()` is not defined by this subclass." ) - async def handle_stream_request( - self, message: InlineCompletionRequest - ) -> AsyncIterator[InlineCompletionStreamChunk]: + async def handle_stream_request(self, message: InlineCompletionRequest) -> None: """ Handles an inline completion request, **with streaming**. Implementations may optionally define this method. Implementations that - do so should stream replies via successive calls to - `self.write_message()`. + do so should stream replies via successive calls to `self.reply()`. The method definition does not need to be wrapped in a try/except block. 
""" @@ -64,14 +59,9 @@ async def handle_stream_request( def loop(self) -> AbstractEventLoop: return self.settings["jai_event_loop"] - def write_message(self, message: Union[bytes, str, Dict[str, Any], BaseModel]): - """ - Write a bytes, string, dict, or Pydantic model object to the WebSocket - connection. The base definition of this method is provided by Tornado. - """ - if isinstance(message, BaseModel): - message = message.dict() - + def reply(self, reply: Union[InlineCompletionReply, InlineCompletionStreamChunk]): + """Write a reply object to the WebSocket connection.""" + message = reply.dict() super().write_message(message) def initialize(self): @@ -144,7 +134,7 @@ async def handle_exc(self, e: Exception, request: InlineCompletionRequest): title=e.args[0] if e.args else "Exception", traceback=traceback.format_exc(), ) - self.write_message( + self.reply( InlineCompletionReply( list=InlineCompletionList(items=[]), error=error, diff --git a/packages/jupyter-ai/jupyter_ai/completions/handlers/default.py b/packages/jupyter-ai/jupyter_ai/completions/handlers/default.py index 9d7e7915c..38676b998 100644 --- a/packages/jupyter-ai/jupyter_ai/completions/handlers/default.py +++ b/packages/jupyter-ai/jupyter_ai/completions/handlers/default.py @@ -1,161 +1,24 @@ -from typing import Dict, Type - -from jupyter_ai_magics.providers import BaseProvider -from langchain.prompts import ( - ChatPromptTemplate, - HumanMessagePromptTemplate, - PromptTemplate, - SystemMessagePromptTemplate, -) -from langchain.schema.output_parser import StrOutputParser -from langchain.schema.runnable import Runnable - -from ..models import ( - InlineCompletionList, - InlineCompletionReply, - InlineCompletionRequest, - InlineCompletionStreamChunk, -) +from ..models import InlineCompletionRequest from .base import BaseInlineCompletionHandler class DefaultInlineCompletionHandler(BaseInlineCompletionHandler): - llm_chain: Runnable - def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - def create_llm_chain( - self, provider: Type[BaseProvider], provider_params: Dict[str, str] - ): - unified_parameters = { - **provider_params, - **(self.get_model_parameters(provider, provider_params)), - } - llm = provider(**unified_parameters) - - prompt_template = llm.get_completion_prompt_template() - - self.llm = llm - self.llm_chain = prompt_template | llm | StrOutputParser() - - async def handle_request(self, request: InlineCompletionRequest) -> None: + async def handle_request(self, request: InlineCompletionRequest): """Handles an inline completion request without streaming.""" - self.get_llm_chain() - model_arguments = self._template_inputs_from_request(request) - suggestion = await self.llm_chain.ainvoke(input=model_arguments) - suggestion = self._post_process_suggestion(suggestion, request) - self.write_message( - InlineCompletionReply( - list=InlineCompletionList(items=[{"insertText": suggestion}]), - reply_to=request.number, - ) - ) + llm = self.get_llm() + if not llm: + raise ValueError("Please select a model for inline completion.") - def _write_incomplete_reply(self, request: InlineCompletionRequest): - """Writes an incomplete `InlineCompletionReply`, indicating to the - client that LLM output is about to streamed across this connection. 
- Should be called first in `self.handle_stream_request()`.""" - - token = self._token_from_request(request, 0) - reply = InlineCompletionReply( - list=InlineCompletionList( - items=[ - { - # insert text starts empty as we do not pre-generate any part - "insertText": "", - "isIncomplete": True, - "token": token, - } - ] - ), - reply_to=request.number, - ) - self.write_message(reply) + reply = await llm.generate_inline_completions(request) + self.reply(reply) async def handle_stream_request(self, request: InlineCompletionRequest): - # first, send empty initial reply. - self._write_incomplete_reply(request) - - # then, generate and stream LLM output over this connection. - self.get_llm_chain() - token = self._token_from_request(request, 0) - model_arguments = self._template_inputs_from_request(request) - suggestion = "" - - async for fragment in self.llm_chain.astream(input=model_arguments): - suggestion += fragment - if suggestion.startswith("```"): - if "\n" not in suggestion: - # we are not ready to apply post-processing - continue - else: - suggestion = self._post_process_suggestion(suggestion, request) - elif suggestion.rstrip().endswith("```"): - suggestion = self._post_process_suggestion(suggestion, request) - self.write_message( - InlineCompletionStreamChunk( - type="stream", - response={"insertText": suggestion, "token": token}, - reply_to=request.number, - done=False, - ) - ) - - # finally, send a message confirming that we are done - self.write_message( - InlineCompletionStreamChunk( - type="stream", - response={"insertText": suggestion, "token": token}, - reply_to=request.number, - done=True, - ) - ) - - def _token_from_request(self, request: InlineCompletionRequest, suggestion: int): - """Generate a deterministic token (for matching streamed messages) - using request number and suggestion number""" - return f"t{request.number}s{suggestion}" - - def _template_inputs_from_request(self, request: InlineCompletionRequest) -> Dict: - suffix = request.suffix.strip() - filename = request.path.split("/")[-1] if request.path else "untitled" - - return { - "prefix": request.prefix, - "suffix": suffix, - "language": request.language, - "filename": filename, - "stop": ["\n```"], - } - - def _post_process_suggestion( - self, suggestion: str, request: InlineCompletionRequest - ) -> str: - """Remove spurious fragments from the suggestion. - - While most models (especially instruct and infill models do not require - any pre-processing, some models such as gpt-4 which only have chat APIs - may require removing spurious fragments. This function uses heuristics - and request data to remove such fragments. 
- """ - # gpt-4 tends to add "```python" or similar - language = request.language or "python" - markdown_identifiers = {"ipython": ["ipython", "python", "py"]} - bad_openings = [ - f"```{identifier}" - for identifier in markdown_identifiers.get(language, [language]) - ] + ["```"] - for opening in bad_openings: - if suggestion.startswith(opening): - suggestion = suggestion[len(opening) :].lstrip() - # check for the prefix inclusion (only if there was a bad opening) - if suggestion.startswith(request.prefix): - suggestion = suggestion[len(request.prefix) :] - break - - # check if the suggestion ends with a closing markdown identifier and remove it - if suggestion.rstrip().endswith("```"): - suggestion = suggestion.rstrip()[:-3].rstrip() + llm = self.get_llm() + if not llm: + raise ValueError("Please select a model for inline completion.") - return suggestion + async for reply in llm.stream_inline_completions(request): + self.reply(reply) diff --git a/packages/jupyter-ai/jupyter_ai/completions/handlers/llm_mixin.py b/packages/jupyter-ai/jupyter_ai/completions/handlers/llm_mixin.py index 1371e3cbf..ed454fff6 100644 --- a/packages/jupyter-ai/jupyter_ai/completions/handlers/llm_mixin.py +++ b/packages/jupyter-ai/jupyter_ai/completions/handlers/llm_mixin.py @@ -1,4 +1,5 @@ -from typing import Any, Dict, Type +from logging import Logger +from typing import Any, Dict, Optional, Type from jupyter_ai.config_manager import ConfigManager from jupyter_ai_magics.providers import BaseProvider @@ -7,34 +8,32 @@ class LLMHandlerMixin: """Base class containing shared methods and attributes used by LLM handler classes.""" - # This could be used to derive `BaseChatHandler` too (there is a lot of duplication!), - # but it was decided against it to avoid introducing conflicts for backports against 1.x - handler_kind: str + settings: dict + log: Logger @property - def config_manager(self) -> ConfigManager: + def jai_config_manager(self) -> ConfigManager: return self.settings["jai_config_manager"] @property def model_parameters(self) -> Dict[str, Dict[str, Any]]: return self.settings["model_parameters"] - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) - self.llm = None - self.llm_params = None - self.llm_chain = None + self._llm: Optional[BaseProvider] = None + self._llm_params = None - def get_llm_chain(self): - lm_provider = self.config_manager.lm_provider - lm_provider_params = self.config_manager.lm_provider_params + def get_llm(self) -> Optional[BaseProvider]: + lm_provider = self.jai_config_manager.lm_provider + lm_provider_params = self.jai_config_manager.lm_provider_params if not lm_provider or not lm_provider_params: return None curr_lm_id = ( - f'{self.llm.id}:{lm_provider_params["model_id"]}' if self.llm else None + f'{self._llm.id}:{lm_provider_params["model_id"]}' if self._llm else None ) next_lm_id = ( f'{lm_provider.id}:{lm_provider_params["model_id"]}' @@ -42,19 +41,23 @@ def get_llm_chain(self): else None ) + should_recreate_llm = False if curr_lm_id != next_lm_id: self.log.info( f"Switching {self.handler_kind} language model from {curr_lm_id} to {next_lm_id}." ) - self.create_llm_chain(lm_provider, lm_provider_params) - elif self.llm_params != lm_provider_params: + should_recreate_llm = True + elif self._llm_params != lm_provider_params: self.log.info( f"{self.handler_kind} model params changed, updating the llm chain." 
) - self.create_llm_chain(lm_provider, lm_provider_params) + should_recreate_llm = True + + if should_recreate_llm: + self._llm = self.create_llm(lm_provider, lm_provider_params) + self._llm_params = lm_provider_params - self.llm_params = lm_provider_params - return self.llm_chain + return self._llm def get_model_parameters( self, provider: Type[BaseProvider], provider_params: Dict[str, str] @@ -63,7 +66,13 @@ def get_model_parameters( f"{provider.id}:{provider_params['model_id']}", {} ) - def create_llm_chain( + def create_llm( self, provider: Type[BaseProvider], provider_params: Dict[str, str] - ): - raise NotImplementedError("Should be implemented by subclasses") + ) -> BaseProvider: + unified_parameters = { + **provider_params, + **(self.get_model_parameters(provider, provider_params)), + } + llm = provider(**unified_parameters) + + return llm diff --git a/packages/jupyter-ai/jupyter_ai/completions/models.py b/packages/jupyter-ai/jupyter_ai/completions/models.py index 507365408..e9679379e 100644 --- a/packages/jupyter-ai/jupyter_ai/completions/models.py +++ b/packages/jupyter-ai/jupyter_ai/completions/models.py @@ -1,71 +1,17 @@ -from typing import List, Literal, Optional - -from langchain.pydantic_v1 import BaseModel - - -class InlineCompletionRequest(BaseModel): - """Message send by client to request inline completions. - - Prefix/suffix implementation is used to avoid the need for synchronising - the notebook state at every key press (subject to change in future).""" - - # unique message ID generated by the client used to identify replies and - # to easily discard replies for older requests - number: int - # prefix should include full text of the current cell preceding the cursor - prefix: str - # suffix should include full text of the current cell preceding the cursor - suffix: str - # media type for the current language, e.g. `text/x-python` - mime: str - # whether to stream the response (if supported by the model) - stream: bool - # path to the notebook of file for which the completions are generated - path: Optional[str] - # language inferred from the document mime type (if possible) - language: Optional[str] - # identifier of the cell for which the completions are generated if in a notebook - # previous cells and following cells can be used to learn the wider context - cell_id: Optional[str] - - -class InlineCompletionItem(BaseModel): - """The inline completion suggestion to be displayed on the frontend. - - See JuptyerLab `InlineCompletionItem` documentation for the details. 
- """ - - insertText: str - filterText: Optional[str] - isIncomplete: Optional[bool] - token: Optional[str] - - -class CompletionError(BaseModel): - type: str - traceback: str - - -class InlineCompletionList(BaseModel): - """Reflection of JupyterLab's `IInlineCompletionList`.""" - - items: List[InlineCompletionItem] - - -class InlineCompletionReply(BaseModel): - """Message sent from model to client with the infill suggestions""" - - list: InlineCompletionList - # number of request for which we are replying - reply_to: int - error: Optional[CompletionError] - - -class InlineCompletionStreamChunk(BaseModel): - """Message sent from model to client with the infill suggestions""" - - type: Literal["stream"] = "stream" - response: InlineCompletionItem - reply_to: int - done: bool - error: Optional[CompletionError] +from jupyter_ai_magics.models.completion import ( + CompletionError, + InlineCompletionItem, + InlineCompletionList, + InlineCompletionReply, + InlineCompletionRequest, + InlineCompletionStreamChunk, +) + +__all__ = [ + "InlineCompletionRequest", + "InlineCompletionItem", + "CompletionError", + "InlineCompletionList", + "InlineCompletionReply", + "InlineCompletionStreamChunk", +] diff --git a/packages/jupyter-ai/jupyter_ai/tests/completions/test_handlers.py b/packages/jupyter-ai/jupyter_ai/tests/completions/test_handlers.py index fd2b2666c..8597463f2 100644 --- a/packages/jupyter-ai/jupyter_ai/tests/completions/test_handlers.py +++ b/packages/jupyter-ai/jupyter_ai/tests/completions/test_handlers.py @@ -1,9 +1,14 @@ import json from types import SimpleNamespace +from typing import Union import pytest from jupyter_ai.completions.handlers.default import DefaultInlineCompletionHandler -from jupyter_ai.completions.models import InlineCompletionRequest +from jupyter_ai.completions.models import ( + InlineCompletionReply, + InlineCompletionRequest, + InlineCompletionStreamChunk, +) from jupyter_ai_magics import BaseProvider from langchain_community.llms import FakeListLLM from pytest import fixture @@ -18,28 +23,31 @@ class MockProvider(BaseProvider, FakeListLLM): models = ["model"] def __init__(self, **kwargs): - if not "responses" in kwargs: + if "responses" not in kwargs: kwargs["responses"] = ["Test response"] super().__init__(**kwargs) class MockCompletionHandler(DefaultInlineCompletionHandler): - def __init__(self): + def __init__(self, lm_provider=None, lm_provider_params=None): self.request = HTTPServerRequest() self.application = Application() self.messages = [] self.tasks = [] self.settings["jai_config_manager"] = SimpleNamespace( - lm_provider=MockProvider, lm_provider_params={"model_id": "model"} + lm_provider=lm_provider or MockProvider, + lm_provider_params=lm_provider_params or {"model_id": "model"}, ) self.settings["jai_event_loop"] = SimpleNamespace( create_task=lambda x: self.tasks.append(x) ) self.settings["model_parameters"] = {} - self.llm_params = {"model_id": "model"} - self.create_llm_chain(MockProvider, {"model_id": "model"}) + self._llm_params = {} + self._llm = None - def write_message(self, message: str) -> None: # type: ignore + def reply( + self, message: Union[InlineCompletionReply, InlineCompletionStreamChunk] + ) -> None: self.messages.append(message) async def handle_exc(self, e: Exception, _request: InlineCompletionRequest): @@ -99,10 +107,9 @@ async def test_handle_request(inline_handler): ], ) async def test_handle_request_with_spurious_fragments(response, expected_suggestion): - inline_handler = MockCompletionHandler() - inline_handler.create_llm_chain( 
- MockProvider, - { + inline_handler = MockCompletionHandler( + lm_provider=MockProvider, + lm_provider_params={ "model_id": "model", "responses": [response], }, @@ -121,10 +128,10 @@ async def test_handle_request_with_spurious_fragments(response, expected_suggest assert suggestions[0].insertText == expected_suggestion -async def test_handle_stream_request(inline_handler): - inline_handler.create_llm_chain( - MockProvider, - { +async def test_handle_stream_request(): + inline_handler = MockCompletionHandler( + lm_provider=MockProvider, + lm_provider_params={ "model_id": "model", "responses": ["test"], }, @@ -140,16 +147,16 @@ async def test_handle_stream_request(inline_handler): # first reply should be empty to start the stream first = inline_handler.messages[0].list.items[0] assert first.insertText == "" - assert first.isIncomplete == True + assert first.isIncomplete is True # second reply should be a chunk containing the token second = inline_handler.messages[1] assert second.type == "stream" assert second.response.insertText == "test" - assert second.done == False + assert second.done is False # third reply should be a closing chunk third = inline_handler.messages[2] assert third.type == "stream" assert third.response.insertText == "test" - assert third.done == True + assert third.done is True
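
For reference, the snippet below is a minimal sketch of how the provider-level API introduced by this patch can be exercised directly, outside of the WebSocket handlers. It assumes the `MyCompletionProvider` class from the documentation example above is importable (the `my_provider` module name is hypothetical), that the provider accepts a `model_id` keyword as other `BaseProvider` subclasses do, and the request field values are purely illustrative.

```python
import asyncio

# `InlineCompletionRequest` comes from the module added by this patch.
from jupyter_ai_magics.models.completion import InlineCompletionRequest

# Hypothetical import: place the documentation example's provider in a module of your choice.
from my_provider import MyCompletionProvider


async def main():
    # Assumes the provider is constructed with a `model_id`, as other providers are.
    provider = MyCompletionProvider(model_id="model_a")

    # Illustrative request; prefix/suffix carry the text around the cursor.
    request = InlineCompletionRequest(
        number=1,
        prefix="def fib(n):\n    ",
        suffix="",
        mime="text/x-python",
        stream=False,
        path="example.ipynb",
        language="python",
        cell_id=None,
    )

    # Non-streaming: a single reply containing one or more completion items.
    reply = await provider.generate_inline_completions(request)
    for item in reply.list.items:
        print(item.insertText)

    # Streaming: an initiating reply followed by chunks until `done=True`.
    async for message in provider.stream_inline_completions(request):
        print(type(message).__name__, getattr(message, "done", None))


asyncio.run(main())
```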