Simplify inline completion backend #553

Merged
merged 2 commits on Jan 5, 2024
Changes from all commits
183 changes: 138 additions & 45 deletions packages/jupyter-ai/jupyter_ai/completions/handlers/base.py
@@ -1,76 +1,169 @@
+import json
+import time
 import traceback
+from asyncio import AbstractEventLoop
+from typing import Any, AsyncIterator, Dict, Union

-# necessary to prevent circular import
-from typing import TYPE_CHECKING, AsyncIterator, Dict

+import tornado
 from jupyter_ai.completions.handlers.llm_mixin import LLMHandlerMixin
 from jupyter_ai.completions.models import (
     CompletionError,
     InlineCompletionList,
     InlineCompletionReply,
     InlineCompletionRequest,
     InlineCompletionStreamChunk,
-    ModelChangedNotification,
 )
-from jupyter_ai.config_manager import ConfigManager, Logger

-if TYPE_CHECKING:
-    from jupyter_ai.handlers import InlineCompletionHandler
+from jupyter_server.base.handlers import JupyterHandler
+from langchain.pydantic_v1 import BaseModel, ValidationError


-class BaseInlineCompletionHandler(LLMHandlerMixin):
-    """Class implementing completion handling."""
+class BaseInlineCompletionHandler(
+    LLMHandlerMixin, JupyterHandler, tornado.websocket.WebSocketHandler
+):
+    """A Tornado WebSocket handler that receives inline completion requests and
+    fulfills them accordingly. This class is instantiated once per WebSocket
+    connection."""

-    handler_kind = "completion"
-
-    def __init__(
-        self,
-        log: Logger,
-        config_manager: ConfigManager,
-        model_parameters: Dict[str, Dict],
-        ws_sessions: Dict[str, "InlineCompletionHandler"],
-    ):
-        super().__init__(log, config_manager, model_parameters)
-        self.ws_sessions = ws_sessions
-
-    async def on_message(
-        self, message: InlineCompletionRequest
-    ) -> InlineCompletionReply:
-        try:
-            return await self.process_message(message)
-        except Exception as e:
-            return await self._handle_exc(e, message)
-
-    async def process_message(
+    ##
+    # Interface for subclasses
+    ##
+    async def handle_request(
         self, message: InlineCompletionRequest
     ) -> InlineCompletionReply:
         """
-        Processes an inline completion request. Completion handlers
-        (subclasses) must implement this method.
+        Handles an inline completion request, without streaming. Subclasses
+        must define this method and write a reply via `self.write_message()`.
+
+        The method definition does not need to be wrapped in a try/except block.
         """
-        raise NotImplementedError("Should be implemented by subclasses.")
+        raise NotImplementedError(
+            "The required method `self.handle_request()` is not defined by this subclass."
+        )

-    async def stream(
+    async def handle_stream_request(
         self, message: InlineCompletionRequest
     ) -> AsyncIterator[InlineCompletionStreamChunk]:
-        """ "
-        Stream the inline completion as it is generated. Completion handlers
-        (subclasses) can implement this method.
-        """
-        raise NotImplementedError()
+        """
+        Handles an inline completion request, **with streaming**.
+        Implementations may optionally define this method. Implementations that
+        do so should stream replies via successive calls to
+        `self.write_message()`.
+
+        The method definition does not need to be wrapped in a try/except block.
+        """
+        raise NotImplementedError(
+            "The optional method `self.handle_stream_request()` is not defined by this subclass."
+        )

+    ##
+    # Definition of base class
+    ##
+    handler_kind = "completion"
+
+    @property
+    def loop(self) -> AbstractEventLoop:
+        return self.settings["jai_event_loop"]
+
+    def write_message(self, message: Union[bytes, str, Dict[str, Any], BaseModel]):
+        """
+        Write a bytes, string, dict, or Pydantic model object to the WebSocket
+        connection. The base definition of this method is provided by Tornado.
+        """
+        if isinstance(message, BaseModel):
+            message = message.dict()
+
+        super().write_message(message)
+
+    def initialize(self):
+        self.log.debug("Initializing websocket connection %s", self.request.path)
+
+    def pre_get(self):
+        """Handles authentication/authorization."""
+        # authenticate the request before opening the websocket
+        user = self.current_user
+        if user is None:
+            self.log.warning("Couldn't authenticate WebSocket connection")
+            raise tornado.web.HTTPError(403)

-    async def _handle_exc(self, e: Exception, message: InlineCompletionRequest):
+        # authorize the user.
+        if not self.authorizer.is_authorized(self, user, "execute", "events"):
+            raise tornado.web.HTTPError(403)
+
+    async def get(self, *args, **kwargs):
+        """Get an event socket."""
+        self.pre_get()
+        res = super().get(*args, **kwargs)
+        await res
+
+    async def on_message(self, message):
+        """Public Tornado method that is called when the client sends a message
+        over this connection. This should **not** be overridden by subclasses."""
+
+        # first, verify that the message is an `InlineCompletionRequest`.
+        self.log.debug("Message received: %s", message)
+        try:
+            message = json.loads(message)
+            request = InlineCompletionRequest(**message)
+        except ValidationError as e:
+            self.log.error(e)
+            return
+
+        # next, dispatch the request to the correct handler and create the
+        # `handle_request` coroutine object
+        handle_request = None
+        if request.stream:
+            try:
+                handle_request = self._handle_stream_request(request)
+            except NotImplementedError:
+                self.log.error(
+                    "Unable to handle stream request. The current `InlineCompletionHandler` does not implement the `handle_stream_request()` method."
+                )
+                return
+
+        else:
+            handle_request = self._handle_request(request)
+
+        # finally, wrap `handle_request` in an exception handler, and start the
+        # task on the event loop.
+        async def handle_request_and_catch():
+            try:
+                await handle_request
+            except Exception as e:
+                await self.handle_exc(e, request)
+
+        self.loop.create_task(handle_request_and_catch())
+
+    async def handle_exc(self, e: Exception, request: InlineCompletionRequest):
+        """
+        Handles an exception raised in either `handle_request()` or
+        `handle_stream_request()`. This base class provides a default
+        implementation, which may be overridden by subclasses.
+        """
         error = CompletionError(
             type=e.__class__.__name__,
             title=e.args[0] if e.args else "Exception",
             traceback=traceback.format_exc(),
         )
-        return InlineCompletionReply(
-            list=InlineCompletionList(items=[]), error=error, reply_to=message.number
+        self.write_message(
+            InlineCompletionReply(
+                list=InlineCompletionList(items=[]),
+                error=error,
+                reply_to=request.number,
+            )
         )

-    def broadcast(self, message: ModelChangedNotification):
-        for session in self.ws_sessions.values():
-            session.write_message(message.dict())
+    async def _handle_request(self, request: InlineCompletionRequest):
+        """Private wrapper around `self.handle_request()`."""
+        start = time.time()
+        await self.handle_request(request)
+        latency_ms = round((time.time() - start) * 1000)
+        self.log.info(f"Inline completion handler resolved in {latency_ms} ms.")
+
+    async def _handle_stream_request(self, request: InlineCompletionRequest):
+        """Private wrapper around `self.handle_stream_request()`."""
+        start = time.time()
+        await self.handle_stream_request(request)
+        latency_ms = round((time.time() - start) * 1000)
+        self.log.info(f"Inline completion streaming completed in {latency_ms} ms.")
90 changes: 49 additions & 41 deletions packages/jupyter-ai/jupyter_ai/completions/handlers/default.py
@@ -15,7 +15,6 @@
     InlineCompletionReply,
     InlineCompletionRequest,
     InlineCompletionStreamChunk,
-    ModelChangedNotification,
 )
 from .base import BaseInlineCompletionHandler

@@ -55,15 +54,6 @@ def __init__(self, *args, **kwargs):
     def create_llm_chain(
         self, provider: Type[BaseProvider], provider_params: Dict[str, str]
     ):
-        lm_provider = self.config_manager.lm_provider
-        lm_provider_params = self.config_manager.lm_provider_params
-        next_lm_id = (
-            f'{lm_provider.id}:{lm_provider_params["model_id"]}'
-            if lm_provider
-            else None
-        )
-        self.broadcast(ModelChangedNotification(model=next_lm_id))
-
         model_parameters = self.get_model_parameters(provider, provider_params)
         llm = provider(**provider_params, **model_parameters)

@@ -83,39 +73,52 @@ def create_llm_chain(
         self.llm = llm
         self.llm_chain = prompt_template | llm | StrOutputParser()

-    async def process_message(
+    async def handle_request(
         self, request: InlineCompletionRequest
     ) -> InlineCompletionReply:
-        if request.stream:
-            token = self._token_from_request(request, 0)
-            return InlineCompletionReply(
-                list=InlineCompletionList(
-                    items=[
-                        {
-                            # insert text starts empty as we do not pre-generate any part
-                            "insertText": "",
-                            "isIncomplete": True,
-                            "token": token,
-                        }
-                    ]
-                ),
-                reply_to=request.number,
-            )
-        else:
-            self.get_llm_chain()
-            model_arguments = self._template_inputs_from_request(request)
-            suggestion = await self.llm_chain.ainvoke(input=model_arguments)
-            suggestion = self._post_process_suggestion(suggestion, request)
-            return InlineCompletionReply(
+        """Handles an inline completion request without streaming."""
+        self.get_llm_chain()
+        model_arguments = self._template_inputs_from_request(request)
+        suggestion = await self.llm_chain.ainvoke(input=model_arguments)
+        suggestion = self._post_process_suggestion(suggestion, request)
+        self.write_message(
+            InlineCompletionReply(
                 list=InlineCompletionList(items=[{"insertText": suggestion}]),
                 reply_to=request.number,
             )
+        )

-    async def stream(self, request: InlineCompletionRequest):
+    def _write_incomplete_reply(self, request: InlineCompletionRequest):
+        """Writes an incomplete `InlineCompletionReply`, indicating to the
+        client that LLM output is about to be streamed across this connection.
+        Should be called first in `self.handle_stream_request()`."""
+
+        token = self._token_from_request(request, 0)
+        reply = InlineCompletionReply(
+            list=InlineCompletionList(
+                items=[
+                    {
+                        # insert text starts empty as we do not pre-generate any part
+                        "insertText": "",
+                        "isIncomplete": True,
+                        "token": token,
+                    }
+                ]
+            ),
+            reply_to=request.number,
+        )
+        self.write_message(reply)
+
+    async def handle_stream_request(self, request: InlineCompletionRequest):
+        # first, send empty initial reply.
+        self._write_incomplete_reply(request)
+
+        # then, generate and stream LLM output over this connection.
         self.get_llm_chain()
         token = self._token_from_request(request, 0)
         model_arguments = self._template_inputs_from_request(request)
         suggestion = ""

         async for fragment in self.llm_chain.astream(input=model_arguments):
             suggestion += fragment
             if suggestion.startswith("```"):
Expand All @@ -124,18 +127,23 @@ async def stream(self, request: InlineCompletionRequest):
                 continue
             else:
                 suggestion = self._post_process_suggestion(suggestion, request)
-            yield InlineCompletionStreamChunk(
-                type="stream",
-                response={"insertText": suggestion, "token": token},
-                reply_to=request.number,
-                done=False,
-            )
+            self.write_message(
+                InlineCompletionStreamChunk(
+                    type="stream",
+                    response={"insertText": suggestion, "token": token},
+                    reply_to=request.number,
+                    done=False,
+                )
+            )

-        # at the end send a message confirming that we are done
-        yield InlineCompletionStreamChunk(
-            type="stream",
-            response={"insertText": suggestion, "token": token},
-            reply_to=request.number,
-            done=True,
-        )
+        # finally, send a message confirming that we are done
+        self.write_message(
+            InlineCompletionStreamChunk(
+                type="stream",
+                response={"insertText": suggestion, "token": token},
+                reply_to=request.number,
+                done=True,
+            )
+        )

     def _token_from_request(self, request: InlineCompletionRequest, suggestion: int):
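Taken together, `handle_stream_request()` produces a fixed message sequence over the WebSocket. A sketch of what a client might observe for one streaming request (the request number, token, and suggestion text below are illustrative, not taken from the PR):

# 1. Incomplete initial reply from _write_incomplete_reply(); "token" ties
#    later chunks back to this completion item.
{"list": {"items": [{"insertText": "", "isIncomplete": True, "token": "t3s0"}]},
 "reply_to": 3}

# 2. One chunk per model fragment, carrying the accumulated suggestion so far:
{"type": "stream",
 "response": {"insertText": "def add(a, b):", "token": "t3s0"},
 "reply_to": 3, "done": False}

# 3. A final chunk with done=True, confirming the stream is finished:
{"type": "stream",
 "response": {"insertText": "def add(a, b):\n    return a + b", "token": "t3s0"},
 "reply_to": 3, "done": True}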
28 changes: 12 additions & 16 deletions packages/jupyter-ai/jupyter_ai/completions/handlers/llm_mixin.py
@@ -1,6 +1,6 @@
-from typing import Dict, Type
+from typing import Any, Dict, Type

-from jupyter_ai.config_manager import ConfigManager, Logger
+from jupyter_ai.config_manager import ConfigManager
 from jupyter_ai_magics.providers import BaseProvider


@@ -12,23 +12,20 @@ class LLMHandlerMixin:

     handler_kind: str

-    def __init__(
-        self,
-        log: Logger,
-        config_manager: ConfigManager,
-        model_parameters: Dict[str, Dict],
-    ):
-        self.log = log
-        self.config_manager = config_manager
-        self.model_parameters = model_parameters
+    @property
+    def config_manager(self) -> ConfigManager:
+        return self.settings["jai_config_manager"]
+
+    @property
+    def model_parameters(self) -> Dict[str, Dict[str, Any]]:
+        return self.settings["model_parameters"]
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
         self.llm = None
         self.llm_params = None
         self.llm_chain = None

-    def model_changed_callback(self):
-        """Method which can be overridden in sub-classes to listen to model change."""
-        pass
-
     def get_llm_chain(self):
         lm_provider = self.config_manager.lm_provider
         lm_provider_params = self.config_manager.lm_provider_params
@@ -50,7 +47,6 @@ def get_llm_chain(self):
f"Switching {self.handler_kind} language model from {curr_lm_id} to {next_lm_id}."
)
self.create_llm_chain(lm_provider, lm_provider_params)
self.model_changed_callback()
elif self.llm_params != lm_provider_params:
self.log.info(
f"{self.handler_kind} model params changed, updating the llm chain."