Move methods generating completion replies to the provider #717

Merged · 2 commits · Apr 23, 2024
78 changes: 78 additions & 0 deletions docs/source/developers/index.md
@@ -150,6 +150,84 @@ my-provider = "my_provider:MyEmbeddingsProvider"

[Embeddings]: https://api.python.langchain.com/en/stable/embeddings/langchain_core.embeddings.Embeddings.html


### Custom completion providers

Any model provider derived from `BaseProvider` can be used as a completion provider.
However, some providers may benefit from customizing how completion requests are handled.

There are two asynchronous methods which can be overridden in subclasses of `BaseProvider`:
- `generate_inline_completions`: takes a request (`InlineCompletionRequest`) and returns `InlineCompletionReply`
- `stream_inline_completions`: takes a request and yields an initiating reply (`InlineCompletionReply`) with `isIncomplete` set to `True`, followed by chunks (`InlineCompletionStreamChunk`)

When streaming, all replies and chunks for a given invocation of the `stream_inline_completions()` method should include a constant and unique string token identifying the stream. All chunks except the last chunk for a given item should have the `done` value set to `False`.

The following example demonstrates a custom completion provider that both returns multiple completions in one go and streams multiple completions concurrently.
The implementation and explanation of the `merge_iterators` function used in this example can be found [here](https://stackoverflow.com/q/72445371/4877269); a minimal sketch of one possible implementation is also shown after the example.

```python
import asyncio

from jupyter_ai_magics import BaseProvider
from jupyter_ai_magics.models.completion import (
    InlineCompletionList,
    InlineCompletionReply,
    InlineCompletionRequest,
    InlineCompletionStreamChunk,
)
from langchain_community.llms import FakeListLLM


class MyCompletionProvider(BaseProvider, FakeListLLM):
    id = "my_provider"
    name = "My Provider"
    model_id_key = "model"
    models = ["model_a"]

    def __init__(self, **kwargs):
        kwargs["responses"] = ["This fake response will not be used for completion"]
        super().__init__(**kwargs)

    async def generate_inline_completions(self, request: InlineCompletionRequest):
        return InlineCompletionReply(
            list=InlineCompletionList(items=[
                {"insertText": "An ant minding its own business"},
                {"insertText": "A bug searching for a snack"}
            ]),
            reply_to=request.number,
        )

    async def stream_inline_completions(self, request: InlineCompletionRequest):
        token_1 = f"t{request.number}s0"
        token_2 = f"t{request.number}s1"

        yield InlineCompletionReply(
            list=InlineCompletionList(
                items=[
                    {"insertText": "An ", "isIncomplete": True, "token": token_1},
                    {"insertText": "", "isIncomplete": True, "token": token_2}
                ]
            ),
            reply_to=request.number,
        )

        # `merge_iterators` interleaves chunks from both suggestion streams as they arrive
        # (a sketch of one possible implementation is shown after this example)
        async for reply in merge_iterators([
            self._stream("elephant dancing in the rain", request.number, token_1, start_with="An"),
            self._stream("A flock of birds flying around a mountain", request.number, token_2)
        ]):
            yield reply

    async def _stream(self, sentence, request_number, token, start_with=""):
        suggestion = start_with

        for fragment in sentence.split():
            await asyncio.sleep(0.75)
            suggestion += " " + fragment
            yield InlineCompletionStreamChunk(
                type="stream",
                response={"insertText": suggestion, "token": token},
                reply_to=request_number,
                done=False,
            )

        # finally, send a message confirming that we are done
        yield InlineCompletionStreamChunk(
            type="stream",
            response={"insertText": suggestion, "token": token},
            reply_to=request_number,
            done=True,
        )
```
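
The `merge_iterators` helper used above is not provided by `jupyter_ai_magics`. A minimal sketch of one possible queue-based implementation, along the lines of the linked Stack Overflow discussion (an illustrative sketch, not the only way to write it), could look like this:

```python
import asyncio
from typing import AsyncIterator, List, TypeVar

T = TypeVar("T")


async def merge_iterators(iterators: List[AsyncIterator[T]]) -> AsyncIterator[T]:
    """Yield items from all async iterators as soon as any of them produces one."""
    queue: asyncio.Queue = asyncio.Queue()
    sentinel = object()

    async def drain(iterator: AsyncIterator[T]) -> None:
        # forward every item to the shared queue, then signal completion
        async for item in iterator:
            await queue.put(item)
        await queue.put(sentinel)

    tasks = [asyncio.create_task(drain(iterator)) for iterator in iterators]
    finished = 0
    while finished < len(tasks):
        item = await queue.get()
        if item is sentinel:
            finished += 1
        else:
            yield item
```

A shared queue lets chunks from either suggestion be forwarded to the client as soon as they are produced, rather than waiting for the slower stream to finish.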

## Prompt templates

Each provider can define **prompt templates** for each supported format. A prompt
52 changes: 52 additions & 0 deletions packages/jupyter-ai-magics/jupyter_ai_magics/completion_utils.py
@@ -0,0 +1,52 @@
from typing import Dict

from .models.completion import InlineCompletionRequest


def token_from_request(request: InlineCompletionRequest, suggestion: int):
"""Generate a deterministic token (for matching streamed messages)
using request number and suggestion number"""
return f"t{request.number}s{suggestion}"


def template_inputs_from_request(request: InlineCompletionRequest) -> Dict:
suffix = request.suffix.strip()
filename = request.path.split("/")[-1] if request.path else "untitled"

return {
"prefix": request.prefix,
"suffix": suffix,
"language": request.language,
"filename": filename,
"stop": ["\n```"],
}


def post_process_suggestion(suggestion: str, request: InlineCompletionRequest) -> str:
"""Remove spurious fragments from the suggestion.

While most models (especially instruct and infill models do not require
any pre-processing, some models such as gpt-4 which only have chat APIs
may require removing spurious fragments. This function uses heuristics
and request data to remove such fragments.
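
    For example (illustrative): given ``language="python"`` and
    ``prefix="def add(a, b):\n    "``, a reply of
    ``"```python\ndef add(a, b):\n    return a + b\n```"`` would be reduced
    to ``"return a + b"``.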
"""
# gpt-4 tends to add "```python" or similar
language = request.language or "python"
markdown_identifiers = {"ipython": ["ipython", "python", "py"]}
bad_openings = [
f"```{identifier}"
for identifier in markdown_identifiers.get(language, [language])
] + ["```"]
for opening in bad_openings:
if suggestion.startswith(opening):
suggestion = suggestion[len(opening) :].lstrip()
# check for the prefix inclusion (only if there was a bad opening)
if suggestion.startswith(request.prefix):
suggestion = suggestion[len(request.prefix) :]
break

# check if the suggestion ends with a closing markdown identifier and remove it
if suggestion.rstrip().endswith("```"):
suggestion = suggestion.rstrip()[:-3].rstrip()

return suggestion
81 changes: 81 additions & 0 deletions packages/jupyter-ai-magics/jupyter_ai_magics/models/completion.py
@@ -0,0 +1,81 @@
from typing import List, Literal, Optional

from langchain.pydantic_v1 import BaseModel


class InlineCompletionRequest(BaseModel):
    """Message sent by the client to request inline completions.

    Prefix/suffix implementation is used to avoid the need for synchronising
    the notebook state at every key press (subject to change in future)."""

    # unique message ID generated by the client used to identify replies and
    # to easily discard replies for older requests
    number: int
    # prefix should include full text of the current cell preceding the cursor
    prefix: str
    # suffix should include full text of the current cell following the cursor
    suffix: str
    # media type for the current language, e.g. `text/x-python`
    mime: str
    # whether to stream the response (if supported by the model)
    stream: bool
    # path to the notebook or file for which the completions are generated
    path: Optional[str]
    # language inferred from the document mime type (if possible)
    language: Optional[str]
    # identifier of the cell for which the completions are generated if in a notebook
    # previous cells and following cells can be used to learn the wider context
    cell_id: Optional[str]


class InlineCompletionItem(BaseModel):
    """The inline completion suggestion to be displayed on the frontend.

    See JupyterLab `InlineCompletionItem` documentation for the details.
    """

    insertText: str
    filterText: Optional[str]
    isIncomplete: Optional[bool]
    token: Optional[str]


class CompletionError(BaseModel):
    type: str
    traceback: str


class InlineCompletionList(BaseModel):
    """Reflection of JupyterLab's `IInlineCompletionList`."""

    items: List[InlineCompletionItem]


class InlineCompletionReply(BaseModel):
    """Message sent from model to client with the infill suggestions."""

    list: InlineCompletionList
    # number of the request to which we are replying
    reply_to: int
    error: Optional[CompletionError]


class InlineCompletionStreamChunk(BaseModel):
    """Message sent from model to client with a chunk of a streamed infill suggestion."""

    type: Literal["stream"] = "stream"
    response: InlineCompletionItem
    reply_to: int
    done: bool
    error: Optional[CompletionError]


__all__ = [
    "InlineCompletionRequest",
    "InlineCompletionItem",
    "CompletionError",
    "InlineCompletionList",
    "InlineCompletionReply",
    "InlineCompletionStreamChunk",
]
86 changes: 85 additions & 1 deletion packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
@@ -5,7 +5,17 @@
import io
import json
from concurrent.futures import ThreadPoolExecutor
from typing import Any, ClassVar, Coroutine, Dict, List, Literal, Optional, Union
from typing import (
    Any,
    AsyncIterator,
    ClassVar,
    Coroutine,
    Dict,
    List,
    Literal,
    Optional,
    Union,
)

from jsonpath_ng import parse
from langchain.chat_models.base import BaseChatModel
@@ -20,6 +30,8 @@
)
from langchain.pydantic_v1 import BaseModel, Extra, root_validator
from langchain.schema import LLMResult
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import Runnable
from langchain.utils import get_from_dict_or_env
from langchain_community.chat_models import (
    BedrockChat,
@@ -46,6 +58,13 @@
except:
    from pydantic.main import ModelMetaclass

from . import completion_utils as completion
from .models.completion import (
    InlineCompletionList,
    InlineCompletionReply,
    InlineCompletionRequest,
    InlineCompletionStreamChunk,
)
from .models.persona import Persona

CHAT_SYSTEM_PROMPT = """
@@ -395,6 +414,71 @@ def is_chat_provider(self):
    def allows_concurrency(self):
        return True

    async def generate_inline_completions(
        self, request: InlineCompletionRequest
    ) -> InlineCompletionReply:
        chain = self._create_completion_chain()
        model_arguments = completion.template_inputs_from_request(request)
        suggestion = await chain.ainvoke(input=model_arguments)
        suggestion = completion.post_process_suggestion(suggestion, request)
        return InlineCompletionReply(
            list=InlineCompletionList(items=[{"insertText": suggestion}]),
            reply_to=request.number,
        )

    async def stream_inline_completions(
        self, request: InlineCompletionRequest
    ) -> AsyncIterator[InlineCompletionStreamChunk]:
        chain = self._create_completion_chain()
        token = completion.token_from_request(request, 0)
        model_arguments = completion.template_inputs_from_request(request)
        suggestion = ""

        # send an incomplete `InlineCompletionReply`, indicating to the
        # client that LLM output is about to be streamed across this connection.
        yield InlineCompletionReply(
            list=InlineCompletionList(
                items=[
                    {
                        # insert text starts empty as we do not pre-generate any part
                        "insertText": "",
                        "isIncomplete": True,
                        "token": token,
                    }
                ]
            ),
            reply_to=request.number,
        )

        async for fragment in chain.astream(input=model_arguments):
            suggestion += fragment
            if suggestion.startswith("```"):
                if "\n" not in suggestion:
                    # we are not ready to apply post-processing
                    continue
                else:
                    suggestion = completion.post_process_suggestion(suggestion, request)
            elif suggestion.rstrip().endswith("```"):
                suggestion = completion.post_process_suggestion(suggestion, request)
            yield InlineCompletionStreamChunk(
                type="stream",
                response={"insertText": suggestion, "token": token},
                reply_to=request.number,
                done=False,
            )

        # finally, send a message confirming that we are done
        yield InlineCompletionStreamChunk(
            type="stream",
            response={"insertText": suggestion, "token": token},
            reply_to=request.number,
            done=True,
        )

    def _create_completion_chain(self) -> Runnable:
        prompt_template = self.get_completion_prompt_template()
        return prompt_template | self | StrOutputParser()


class AI21Provider(BaseProvider, AI21):
id = "ai21"