Implement streaming support
krassowski committed Nov 27, 2023
1 parent aa4c7ff commit 85ecc02
Showing 6 changed files with 217 additions and 28 deletions.
12 changes: 11 additions & 1 deletion packages/jupyter-ai/jupyter_ai/completions/handlers/base.py
@@ -1,13 +1,14 @@
import traceback

# necessary to prevent circular import
from typing import TYPE_CHECKING, Dict
from typing import TYPE_CHECKING, AsyncIterator, Dict

from jupyter_ai.completions.models import (
CompletionError,
InlineCompletionList,
InlineCompletionReply,
InlineCompletionRequest,
InlineCompletionStreamChunk,
ModelChangedNotification,
)
from jupyter_ai.config_manager import ConfigManager, Logger
@@ -51,6 +52,15 @@ async def process_message(
"""
raise NotImplementedError("Should be implemented by subclasses.")

async def stream(
self, message: InlineCompletionRequest
) -> AsyncIterator[InlineCompletionStreamChunk]:
""" "
Stream the inline completion as it is generated. Completion handlers
(subclasses) can implement this method.
"""
raise NotImplementedError()

async def _handle_exc(self, e: Exception, message: InlineCompletionRequest):
error = CompletionError(
type=e.__class__.__name__,
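
The new stream() hook on the base handler is an async iterator of InlineCompletionStreamChunk objects; subclasses that can stream override it. A minimal sketch, not part of this commit, of how a custom handler might satisfy the contract (the class name and the canned fragments are hypothetical; the real implementation is in default.py below):

```python
from jupyter_ai.completions.handlers.base import BaseInlineCompletionHandler
from jupyter_ai.completions.models import (
    InlineCompletionList,
    InlineCompletionReply,
    InlineCompletionRequest,
    InlineCompletionStreamChunk,
)


class CannedInlineCompletionHandler(BaseInlineCompletionHandler):
    """Hypothetical handler that streams a fixed suggestion fragment by fragment."""

    async def process_message(
        self, message: InlineCompletionRequest
    ) -> InlineCompletionReply:
        # non-streaming path: a single, complete suggestion
        return InlineCompletionReply(
            list=InlineCompletionList(items=[{"insertText": "pass"}]),
            reply_to=message.number,
        )

    async def stream(self, message: InlineCompletionRequest):
        token = f"t{message.number}s0"  # matches the item sent in the initial reply
        text = ""
        for fragment in ("pa", "ss"):  # stand-in for tokens arriving from a model
            text += fragment
            yield InlineCompletionStreamChunk(
                type="stream",
                response={"insertText": text, "token": token},
                reply_to=message.number,
                done=False,
            )
        # a last chunk with done=True tells the client the suggestion is final
        yield InlineCompletionStreamChunk(
            type="stream",
            response={"insertText": text, "token": token},
            reply_to=message.number,
            done=True,
        )
```
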
88 changes: 72 additions & 16 deletions packages/jupyter-ai/jupyter_ai/completions/handlers/default.py
@@ -1,18 +1,20 @@
from typing import Dict, Type

from jupyter_ai_magics.providers import BaseProvider
from langchain.chains import LLMChain
from langchain.prompts import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
PromptTemplate,
SystemMessagePromptTemplate,
)
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import Runnable

from ..models import (
InlineCompletionList,
InlineCompletionReply,
InlineCompletionRequest,
InlineCompletionStreamChunk,
ModelChangedNotification,
)
from .base import BaseInlineCompletionHandler
@@ -45,7 +47,7 @@


class DefaultInlineCompletionHandler(BaseInlineCompletionHandler):
llm_chain: LLMChain
llm_chain: Runnable

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -79,28 +81,82 @@ def create_llm_chain(
)

self.llm = llm
self.llm_chain = LLMChain(llm=llm, prompt=prompt_template, verbose=True)
self.llm_chain = prompt_template | llm | StrOutputParser()

async def process_message(
self, request: InlineCompletionRequest
) -> InlineCompletionReply:
self.get_llm_chain()
suffix = request.suffix.strip()
prediction = await self.llm_chain.apredict(
prefix=request.prefix,
# only add the suffix template if the suffix is there to save input tokens/computation time
after=AFTER_TEMPLATE.format(suffix=suffix) if suffix else "",
language=request.language,
filename=request.path.split("/")[-1] if request.path else "untitled",
stop=["\n```"],
)
prediction = self._post_process_suggestion(prediction, request)
if request.stream:
token = self._token_from_request(request, 0)
return InlineCompletionReply(
list=InlineCompletionList(
items=[
{
# insert text starts empty as we do not pre-generate any part
"insertText": "",
"isIncomplete": True,
"token": token,
}
]
),
reply_to=request.number,
)
else:
self.get_llm_chain()
model_arguments = self._template_inputs_from_request(request)
suggestion = await self.llm_chain.ainvoke(input=model_arguments)
suggestion = self._post_process_suggestion(suggestion, request)
return InlineCompletionReply(
list=InlineCompletionList(items=[{"insertText": suggestion}]),
reply_to=request.number,
)

return InlineCompletionReply(
list=InlineCompletionList(items=[{"insertText": prediction}]),
async def stream(self, request: InlineCompletionRequest):
self.get_llm_chain()
token = self._token_from_request(request, 0)
model_arguments = self._template_inputs_from_request(request)
suggestion = ""
async for fragment in self.llm_chain.astream(input=model_arguments):
suggestion += fragment
if suggestion.startswith("```"):
if "\n" not in suggestion:
# we are not ready to apply post-processing
continue
else:
suggestion = self._post_process_suggestion(suggestion, request)
yield InlineCompletionStreamChunk(
type="stream",
response={"insertText": suggestion, "token": token},
reply_to=request.number,
done=False,
)
# at the end send a message confirming that we are done
yield InlineCompletionStreamChunk(
type="stream",
response={"insertText": suggestion, "token": token},
reply_to=request.number,
done=True,
)

def _token_from_request(self, request: InlineCompletionRequest, suggestion: int):
"""Generate a deterministic token (for matching streamed messages)
using request number and suggestion number"""
return f"t{request.number}s{suggestion}"

def _template_inputs_from_request(self, request: InlineCompletionRequest) -> Dict:
suffix = request.suffix.strip()
# only add the suffix template if the suffix is there to save input tokens/computation time
after = AFTER_TEMPLATE.format(suffix=suffix) if suffix else ""
filename = request.path.split("/")[-1] if request.path else "untitled"

return {
"prefix": request.prefix,
"after": after,
"language": request.language,
"filename": filename,
"stop": ["\n```"],
}

def _post_process_suggestion(
self, suggestion: str, request: InlineCompletionRequest
) -> str:
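
Two things are worth calling out in the default handler: the chain is now a pipeline-style Runnable (prompt_template | llm | StrOutputParser()), which exposes ainvoke() and astream(), and stream() buffers the incoming fragments so that a leading markdown fence can be stripped once the first full line has arrived. Below is a standalone sketch of that buffering pattern, with a plain async generator standing in for llm_chain.astream() and a simplified strip_fence() in place of _post_process_suggestion(); the fragments are made up, and the real chain passes a "\n```" stop sequence so a closing fence is normally never emitted:

```python
import asyncio
from typing import AsyncIterator


async def fake_astream() -> AsyncIterator[str]:
    # stand-in for llm_chain.astream(...): the model wraps its answer in a fence
    for fragment in ("```", "python\n", "def greet():\n", "    return 'hi'"):
        yield fragment


def strip_fence(text: str) -> str:
    # simplified stand-in for _post_process_suggestion(): drop the opening fence line
    if text.startswith("```") and "\n" in text:
        text = text.split("\n", 1)[1]
    return text


async def main() -> None:
    suggestion = ""
    async for fragment in fake_astream():
        suggestion += fragment
        if suggestion.startswith("```"):
            if "\n" not in suggestion:
                # cannot tell yet what follows the fence; buffer until a newline arrives
                continue
            suggestion = strip_fence(suggestion)
        print("partial chunk:", repr(suggestion))  # done=False in the real handler
    print("final chunk:  ", repr(suggestion))      # done=True in the real handler


asyncio.run(main())
```
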
14 changes: 13 additions & 1 deletion packages/jupyter-ai/jupyter_ai/completions/models.py
@@ -18,6 +18,8 @@ class InlineCompletionRequest(BaseModel):
suffix: str
# media type for the current language, e.g. `text/x-python`
mime: str
# whether to stream the response (if supported by the model)
stream: bool
# path to the notebook or file for which the completions are generated
path: Optional[str]
# language inferred from the document mime type (if possible)
@@ -36,7 +38,7 @@ class InlineCompletionItem(BaseModel):
insertText: str
filterText: Optional[str]
isIncomplete: Optional[bool]
token: Optional[bool]
token: Optional[str]


class CompletionError(BaseModel):
@@ -59,6 +61,16 @@ class InlineCompletionReply(BaseModel):
error: Optional[CompletionError]


class InlineCompletionStreamChunk(BaseModel):
"""Message sent from model to client with the infill suggestions"""

type: Literal["stream"] = "stream"
response: InlineCompletionItem
reply_to: int
done: bool
error: Optional[CompletionError]


class ModelChangedNotification(BaseModel):
type: Literal["model_changed"] = "model_changed"
model: Optional[str]
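
Put together, a request with stream=True results in one InlineCompletionReply carrying an empty, incomplete item plus its token, followed by InlineCompletionStreamChunk messages that repeat the token, carry the growing text, and end with done=True. A rough sketch of that sequence using the models above (the suggestion text and request number are made up):

```python
from jupyter_ai.completions.models import (
    InlineCompletionList,
    InlineCompletionReply,
    InlineCompletionStreamChunk,
)

token = "t3s0"  # deterministic: request number 3, suggestion 0

# 1. immediate acknowledgement with an empty, incomplete item carrying the token
ack = InlineCompletionReply(
    list=InlineCompletionList(
        items=[{"insertText": "", "isIncomplete": True, "token": token}]
    ),
    reply_to=3,
)

# 2. zero or more partial updates, matched to the item via the token
partial = InlineCompletionStreamChunk(
    type="stream",
    response={"insertText": "def greet", "token": token},
    reply_to=3,
    done=False,
)

# 3. a final chunk repeating the full suggestion, with done=True
final = InlineCompletionStreamChunk(
    type="stream",
    response={"insertText": "def greet():\n    return 'hi'", "token": token},
    reply_to=3,
    done=True,
)

for message in (ack, partial, final):
    print(message.dict())  # the dict form is what the websocket handler writes out
```
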
23 changes: 19 additions & 4 deletions packages/jupyter-ai/jupyter_ai/handlers.py
@@ -305,10 +305,17 @@ async def on_message(self, message):
self.log.error(e)
return

# do not await this, as it blocks the parent task responsible for
# handling messages from a websocket. instead, process each message
# as a distinct concurrent task.
self.loop.create_task(self._complete(request))
if request.stream:
try:
stream_coroutine = self._stream(request)
except NotImplementedError:
self.log.info(
"Not streaming as handler does not implement stream() method"
)
await self._complete(request)
self.loop.create_task(stream_coroutine)
else:
self.loop.create_task(self._complete(request))

async def _complete(self, request: InlineCompletionRequest):
start = time.time()
@@ -317,6 +324,14 @@ async def _complete(self, request: InlineCompletionRequest):
self.log.info(f"Inline completion handler resolved in {latency_ms} ms.")
self.write_message(reply.dict())

async def _stream(self, request: InlineCompletionRequest):
start = time.time()
handler = self.settings["jai_inline_completion_handler"]
async for chunk in handler.stream(request):
self.write_message(chunk.dict())
latency_ms = round((time.time() - start) * 1000)
self.log.info(f"Inline completion streaming completed in {latency_ms} ms.")

def on_close(self):
self.log.debug(f"Disconnecting client {self.client_id}")
self.websocket_sessions.pop(self.client_id, None)
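
As with the non-streaming path, the completion work is scheduled with loop.create_task() rather than awaited, so a slow model cannot block the websocket task that keeps reading incoming messages. A toy illustration of the difference, unrelated to the Jupyter AI APIs (the sleep time stands in for model latency):

```python
import asyncio
import time


async def slow_completion(request_id: int, start: float) -> None:
    await asyncio.sleep(1.0)  # stand-in for model latency
    print(f"request {request_id} done after {time.perf_counter() - start:.1f}s")


async def message_loop() -> None:
    start = time.perf_counter()
    loop = asyncio.get_running_loop()
    for request_id in (1, 2, 3):
        # awaiting slow_completion() here would serialize the requests (~3 s total);
        # scheduling tasks lets them overlap (~1 s total) and keeps the loop free
        loop.create_task(slow_completion(request_id, start))
    await asyncio.sleep(1.2)  # keep the loop alive long enough for the tasks to finish


asyncio.run(message_loop())
```
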
15 changes: 14 additions & 1 deletion packages/jupyter-ai/src/handler.ts
@@ -2,7 +2,10 @@ import { URLExt } from '@jupyterlab/coreutils';

import { ServerConnection } from '@jupyterlab/services';

import type { IInlineCompletionList } from '@jupyterlab/completer';
import type {
IInlineCompletionList,
IInlineCompletionItem
} from '@jupyterlab/completer';

const API_NAMESPACE = 'api/ai';

@@ -65,6 +68,8 @@ export namespace AiService {
/* The model may consider the following suffix */
suffix: string;
mime: string;
/* Whether to stream the response (if streaming is supported by the model) */
stream: boolean;
language?: string;
cell_id?: string;
};
@@ -84,6 +89,13 @@ export namespace AiService {
error?: CompletionError;
};

export type InlineCompletionStreamChunk = {
type: 'stream';
response: IInlineCompletionItem;
reply_to: number;
done: boolean;
};

export type InlineCompletionModelChanged = {
type: 'model_changed';
model: string;
@@ -137,6 +149,7 @@ export namespace AiService {
export type CompleterMessage =
| InlineCompletionReply
| ConnectionMessage
| InlineCompletionStreamChunk
| InlineCompletionModelChanged;

export type ChatHistory = {