Move methods generating completion replies to provider
krassowski committed Apr 5, 2024
1 parent 3bfce32 commit 180463d
Showing 7 changed files with 159 additions and 245 deletions.
12 changes: 12 additions & 0 deletions docs/source/developers/index.md
@@ -150,6 +150,18 @@ my-provider = "my_provider:MyEmbeddingsProvider"

[Embeddings]: https://api.python.langchain.com/en/stable/embeddings/langchain_core.embeddings.Embeddings.html


### Custom completion providers

Any model provider derived from `BaseProvider` can be used as a completion provider.
However, some providers may benefit from customizing the handling of completion requests.

There are two asynchronous methods which can be overridden in subclasses of `BaseProvider`:
- `generate_inline_completions`: takes a request (`InlineCompletionRequest`) and returns an `InlineCompletionReply`
- `stream_inline_completions`: takes a request and yields an initiating reply (`InlineCompletionReply`) with `isIncomplete` set to `True`, followed by subsequent chunks (`InlineCompletionStreamChunk`)

When streaming, all replies and chunks for a given invocation of the `stream_inline_completions()` method should include a constant and unique string token identifying the stream. All chunks except for the last chunk for a given item should have the `done` value set to `False`.
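
For illustration only, below is a minimal sketch of a provider overriding both methods. The `MyCompletionProvider` class, its attributes, and the `_fetch_suggestion` helper are hypothetical placeholders (a real provider would call its own model and satisfy the usual custom-provider requirements described earlier); the import path for the completion models is assumed to match this commit's `providers.py`.

```python
from typing import AsyncIterator

from jupyter_ai_magics.models.completion import (
    InlineCompletionList,
    InlineCompletionReply,
    InlineCompletionRequest,
    InlineCompletionStreamChunk,
)
from jupyter_ai_magics.providers import BaseProvider


class MyCompletionProvider(BaseProvider):
    id = "my-provider"
    name = "My Provider"
    model_id_key = "model"
    models = ["my-model"]

    async def _fetch_suggestion(self, request: InlineCompletionRequest) -> str:
        # Hypothetical helper: query your own backend with the request
        # prefix/suffix and return a plain-text suggestion.
        ...

    async def generate_inline_completions(
        self, request: InlineCompletionRequest
    ) -> InlineCompletionReply:
        suggestion = await self._fetch_suggestion(request)
        return InlineCompletionReply(
            list=InlineCompletionList(items=[{"insertText": suggestion}]),
            reply_to=request.number,
        )

    async def stream_inline_completions(
        self, request: InlineCompletionRequest
    ) -> AsyncIterator[InlineCompletionStreamChunk]:
        # A constant, unique token ties together all replies and chunks
        # belonging to this stream.
        token = f"t{request.number}s0"

        # Initiating reply: empty `insertText`, marked as incomplete.
        yield InlineCompletionReply(
            list=InlineCompletionList(
                items=[{"insertText": "", "isIncomplete": True, "token": token}]
            ),
            reply_to=request.number,
        )

        suggestion = await self._fetch_suggestion(request)

        # The last chunk for a given item sets `done=True`; any intermediate
        # chunks would set `done=False`.
        yield InlineCompletionStreamChunk(
            type="stream",
            response={"insertText": suggestion, "token": token},
            reply_to=request.number,
            done=True,
        )
```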

## Prompt templates

Each provider can define **prompt templates** for each supported format. A prompt
84 changes: 83 additions & 1 deletion packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
@@ -5,7 +5,17 @@
import io
import json
from concurrent.futures import ThreadPoolExecutor
from typing import Any, ClassVar, Coroutine, Dict, List, Literal, Optional, Union
from typing import (
Any,
AsyncIterator,
ClassVar,
Coroutine,
Dict,
List,
Literal,
Optional,
Union,
)

from jsonpath_ng import parse
from langchain.chat_models.base import BaseChatModel
@@ -20,6 +30,8 @@
)
from langchain.pydantic_v1 import BaseModel, Extra, root_validator
from langchain.schema import LLMResult
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import Runnable
from langchain.utils import get_from_dict_or_env
from langchain_community.chat_models import (
BedrockChat,
@@ -46,6 +58,13 @@
except:
from pydantic.main import ModelMetaclass

from . import _completion as completion
from .models.completion import (
InlineCompletionList,
InlineCompletionReply,
InlineCompletionRequest,
InlineCompletionStreamChunk,
)
from .models.persona import Persona

CHAT_SYSTEM_PROMPT = """
@@ -395,6 +414,69 @@ def is_chat_provider(self):
def allows_concurrency(self):
return True

async def generate_inline_completions(
self, request: InlineCompletionRequest
) -> InlineCompletionReply:
chain = self._create_completion_chain()
model_arguments = completion.template_inputs_from_request(request)
suggestion = await chain.ainvoke(input=model_arguments)
suggestion = completion.post_process_suggestion(suggestion, request)
return InlineCompletionReply(
list=InlineCompletionList(items=[{"insertText": suggestion}]),
reply_to=request.number,
)

async def stream_inline_completions(
self, request: InlineCompletionRequest
) -> AsyncIterator[InlineCompletionStreamChunk]:
chain = self._create_completion_chain()
token = completion.token_from_request(request, 0)
model_arguments = completion.template_inputs_from_request(request)
suggestion = ""

# send an incomplete `InlineCompletionReply`, indicating to the
# client that LLM output is about to be streamed across this connection.
yield InlineCompletionReply(
list=InlineCompletionList(
items=[
{
# insert text starts empty as we do not pre-generate any part
"insertText": "",
"isIncomplete": True,
"token": token,
}
]
),
reply_to=request.number,
)

async for fragment in chain.astream(input=model_arguments):
suggestion += fragment
if suggestion.startswith("```"):
if "\n" not in suggestion:
# we are not ready to apply post-processing
continue
else:
suggestion = completion.post_process_suggestion(suggestion, request)
yield InlineCompletionStreamChunk(
type="stream",
response={"insertText": suggestion, "token": token},
reply_to=request.number,
done=False,
)

# finally, send a message confirming that we are done
yield InlineCompletionStreamChunk(
type="stream",
response={"insertText": suggestion, "token": token},
reply_to=request.number,
done=True,
)

def _create_completion_chain(self) -> Runnable:
prompt_template = self.get_completion_prompt_template()
return prompt_template | self | StrOutputParser()


class AI21Provider(BaseProvider, AI21):
id = "ai21"
13 changes: 4 additions & 9 deletions packages/jupyter-ai/jupyter_ai/completions/handlers/base.py
@@ -2,7 +2,7 @@
import time
import traceback
from asyncio import AbstractEventLoop
from typing import Any, AsyncIterator, Dict, Union
from typing import Any, Dict, Union

import tornado
from jupyter_ai.completions.handlers.llm_mixin import LLMHandlerMixin
@@ -11,7 +11,6 @@
InlineCompletionList,
InlineCompletionReply,
InlineCompletionRequest,
InlineCompletionStreamChunk,
)
from jupyter_server.base.handlers import JupyterHandler
from langchain.pydantic_v1 import BaseModel, ValidationError
@@ -27,9 +26,7 @@ class BaseInlineCompletionHandler(
##
# Interface for subclasses
##
async def handle_request(
self, message: InlineCompletionRequest
) -> InlineCompletionReply:
async def handle_request(self, message: InlineCompletionRequest) -> None:
"""
Handles an inline completion request, without streaming. Subclasses
must define this method and write a reply via `self.write_message()`.
@@ -40,9 +37,7 @@ async def handle_request(
"The required method `self.handle_request()` is not defined by this subclass."
)

async def handle_stream_request(
self, message: InlineCompletionRequest
) -> AsyncIterator[InlineCompletionStreamChunk]:
async def handle_stream_request(self, message: InlineCompletionRequest) -> None:
"""
Handles an inline completion request, **with streaming**.
Implementations may optionally define this method. Implementations that
@@ -64,7 +59,7 @@ async def handle_stream_request(
def loop(self) -> AbstractEventLoop:
return self.settings["jai_event_loop"]

def write_message(self, message: Union[bytes, str, Dict[str, Any], BaseModel]):
def write(self, message: Union[bytes, str, Dict[str, Any], BaseModel]):
"""
Write a bytes, string, dict, or Pydantic model object to the WebSocket
connection. The base definition of this method is provided by Tornado.
154 changes: 12 additions & 142 deletions packages/jupyter-ai/jupyter_ai/completions/handlers/default.py
@@ -1,154 +1,24 @@
from typing import Dict, Type

from jupyter_ai_magics.providers import BaseProvider
from langchain.prompts import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
PromptTemplate,
SystemMessagePromptTemplate,
)
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import Runnable

from ..models import (
InlineCompletionList,
InlineCompletionReply,
InlineCompletionRequest,
InlineCompletionStreamChunk,
)
from ..models import InlineCompletionRequest
from .base import BaseInlineCompletionHandler


class DefaultInlineCompletionHandler(BaseInlineCompletionHandler):
llm_chain: Runnable

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def create_llm_chain(
self, provider: Type[BaseProvider], provider_params: Dict[str, str]
):
unified_parameters = {
**provider_params,
**(self.get_model_parameters(provider, provider_params)),
}
llm = provider(**unified_parameters)

prompt_template = llm.get_completion_prompt_template()

self.llm = llm
self.llm_chain = prompt_template | llm | StrOutputParser()

async def handle_request(self, request: InlineCompletionRequest) -> None:
async def handle_request(self, request: InlineCompletionRequest):
"""Handles an inline completion request without streaming."""
self.get_llm_chain()
model_arguments = self._template_inputs_from_request(request)
suggestion = await self.llm_chain.ainvoke(input=model_arguments)
suggestion = self._post_process_suggestion(suggestion, request)
self.write_message(
InlineCompletionReply(
list=InlineCompletionList(items=[{"insertText": suggestion}]),
reply_to=request.number,
)
)

def _write_incomplete_reply(self, request: InlineCompletionRequest):
"""Writes an incomplete `InlineCompletionReply`, indicating to the
client that LLM output is about to streamed across this connection.
Should be called first in `self.handle_stream_request()`."""
llm = self.get_llm()
if not llm:
raise ValueError("Please select a model for inline completion.")

token = self._token_from_request(request, 0)
reply = InlineCompletionReply(
list=InlineCompletionList(
items=[
{
# insert text starts empty as we do not pre-generate any part
"insertText": "",
"isIncomplete": True,
"token": token,
}
]
),
reply_to=request.number,
)
self.write_message(reply)
reply = await llm.generate_inline_completions(request)
self.write(reply)

async def handle_stream_request(self, request: InlineCompletionRequest):
# first, send empty initial reply.
self._write_incomplete_reply(request)

# then, generate and stream LLM output over this connection.
self.get_llm_chain()
token = self._token_from_request(request, 0)
model_arguments = self._template_inputs_from_request(request)
suggestion = ""

async for fragment in self.llm_chain.astream(input=model_arguments):
suggestion += fragment
if suggestion.startswith("```"):
if "\n" not in suggestion:
# we are not ready to apply post-processing
continue
else:
suggestion = self._post_process_suggestion(suggestion, request)
self.write_message(
InlineCompletionStreamChunk(
type="stream",
response={"insertText": suggestion, "token": token},
reply_to=request.number,
done=False,
)
)

# finally, send a message confirming that we are done
self.write_message(
InlineCompletionStreamChunk(
type="stream",
response={"insertText": suggestion, "token": token},
reply_to=request.number,
done=True,
)
)

def _token_from_request(self, request: InlineCompletionRequest, suggestion: int):
"""Generate a deterministic token (for matching streamed messages)
using request number and suggestion number"""
return f"t{request.number}s{suggestion}"

def _template_inputs_from_request(self, request: InlineCompletionRequest) -> Dict:
suffix = request.suffix.strip()
filename = request.path.split("/")[-1] if request.path else "untitled"

return {
"prefix": request.prefix,
"suffix": suffix,
"language": request.language,
"filename": filename,
"stop": ["\n```"],
}

def _post_process_suggestion(
self, suggestion: str, request: InlineCompletionRequest
) -> str:
"""Remove spurious fragments from the suggestion.
llm = self.get_llm()
if not llm:
raise ValueError("Please select a model for inline completion.")

While most models (especially instruct and infill models do not require
any pre-processing, some models such as gpt-4 which only have chat APIs
may require removing spurious fragments. This function uses heuristics
and request data to remove such fragments.
"""
# gpt-4 tends to add "```python" or similar
language = request.language or "python"
markdown_identifiers = {"ipython": ["ipython", "python", "py"]}
bad_openings = [
f"```{identifier}"
for identifier in markdown_identifiers.get(language, [language])
] + ["```"]
for opening in bad_openings:
if suggestion.startswith(opening):
suggestion = suggestion[len(opening) :].lstrip()
# check for the prefix inclusion (only if there was a bad opening)
if suggestion.startswith(request.prefix):
suggestion = suggestion[len(request.prefix) :]
break
return suggestion
async for reply in llm.stream_inline_completions(request):
self.write(reply)
