Skip to content

Commit

Permalink
Implement inline completion provider (front)
Browse files Browse the repository at this point in the history
Also, remove unused `@jupyterlab/collaboration` which was
pulling old lumino versions causing problems for tokens.
  • Loading branch information
krassowski committed Nov 18, 2023
1 parent 7f32acc commit 68609ea
Show file tree
Hide file tree
Showing 13 changed files with 824 additions and 1,704 deletions.
2 changes: 2 additions & 0 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ class BaseChatHandler(BaseLLMHandler):
"""Base ChatHandler class containing shared methods and attributes used by
multiple chat handler classes."""

handler_kind = "chat"

def __init__(
self,
log: Logger,
Expand Down
7 changes: 4 additions & 3 deletions packages/jupyter-ai/jupyter_ai/completions/handlers/base.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import time
import traceback

# necessary to prevent circular import
from typing import TYPE_CHECKING, Dict
from uuid import uuid4

from jupyter_ai.completions.models import (
CompletionError,
InlineCompletionList,
InlineCompletionReply,
InlineCompletionRequest,
ModelChangedNotification,
Expand All @@ -21,6 +20,8 @@
class BaseInlineCompletionHandler(BaseLLMHandler):
"""Class implementing completion handling."""

handler_kind = "completion"

def __init__(
self,
log: Logger,
Expand Down Expand Up @@ -56,7 +57,7 @@ async def _handle_exc(self, e: Exception, message: InlineCompletionRequest):
title=e.args[0] if e.args else "Exception",
traceback=traceback.format_exc(),
)
return InlineCompletionReply(items=[], error=error)
return InlineCompletionReply(list=InlineCompletionList(items=[]), error=error)

def broadcast(self, message: ModelChangedNotification):
for session in self.ws_sessions.values():
Expand Down
26 changes: 15 additions & 11 deletions packages/jupyter-ai/jupyter_ai/completions/handlers/default.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Dict, Type

from jupyter_ai_magics.providers import BaseProvider
from langchain.chains import ConversationChain
from langchain.chains import LLMChain
from langchain.prompts import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
Expand All @@ -10,6 +10,7 @@
)

from ..models import (
InlineCompletionList,
InlineCompletionReply,
InlineCompletionRequest,
ModelChangedNotification,
Expand Down Expand Up @@ -40,16 +41,22 @@


class DefaultInlineCompletionHandler(BaseInlineCompletionHandler):
llm_chain: LLMChain

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def create_llm_chain(
self, provider: Type[BaseProvider], provider_params: Dict[str, str]
):
curr_lm_id = (
f'{self.llm.id}:{provider_params["model_id"]}' if self.llm else None
lm_provider = self.config_manager.lm_provider
lm_provider_params = self.config_manager.lm_provider_params
next_lm_id = (
f'{lm_provider.id}:{lm_provider_params["model_id"]}'
if lm_provider
else None
)
self.broadcast(ModelChangedNotification(model=curr_lm_id))
self.broadcast(ModelChangedNotification(model=next_lm_id))

model_parameters = self.get_model_parameters(provider, provider_params)
llm = provider(**provider_params, **model_parameters)
Expand All @@ -74,9 +81,7 @@ def create_llm_chain(
)

self.llm = llm
self.llm_chain = ConversationChain(
llm=llm, prompt=prompt_template, verbose=True
)
self.llm_chain = LLMChain(llm=llm, prompt=prompt_template, verbose=True)

async def process_message(
self, request: InlineCompletionRequest
Expand All @@ -86,11 +91,10 @@ async def process_message(
prefix=request.prefix,
suffix=request.suffix,
language=request.language,
filename=request.path.split("/")[-1],
filename=request.path.split("/")[-1] if request.path else "untitled",
stop=["\n```"],
)
reply = InlineCompletionReply(
items=[prediction],
return InlineCompletionReply(
list=InlineCompletionList(items=[{"insertText": prediction}]),
reply_to=request.number,
)
return reply
20 changes: 13 additions & 7 deletions packages/jupyter-ai/jupyter_ai/completions/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@ class InlineCompletionRequest(BaseModel):
# unique message ID generated by the client used to identify replies and
# to easily discard replies for older requests
number: int
# path to the notebook or file for which the completions are generated
path: str
# prefix should include full text of the current cell preceding the cursor
prefix: str
# suffix should include full text of the current cell following the cursor
suffix: str
# media type for the current language, e.g. `text/x-python`
mime: str
# path to the notebook or file for which the completions are generated
path: Optional[str]
# language inferred from the document mime type (if possible)
language: Optional[str]
# identifier of the cell for which the completions are generated if in a notebook
Expand All @@ -33,9 +33,9 @@ class InlineCompletionItem(BaseModel):
See JupyterLab `InlineCompletionItem` documentation for the details.
"""

insert_text: str
filter_text: Optional[str]
is_incomplete: Optional[bool]
insertText: str
filterText: Optional[str]
isIncomplete: Optional[bool]
token: Optional[bool]


Expand All @@ -44,15 +44,21 @@ class CompletionError(BaseModel):
traceback: str


class InlineCompletionList(BaseModel):
"""Reflection of JupyterLab's `IInlineCompletionList`."""

items: List[InlineCompletionItem]


class InlineCompletionReply(BaseModel):
"""Message sent from model to client with the infill suggestions"""

items: List[InlineCompletionItem]
list: InlineCompletionList
# number of request for which we are replying
reply_to: int
error: Optional[CompletionError]


class ModelChangedNotification(BaseModel):
type: Literal["model_changed"] = "model_changed"
model: str
model: Optional[str]
2 changes: 1 addition & 1 deletion packages/jupyter-ai/jupyter_ai/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class AiExtension(ExtensionApp):
(r"api/ai/chats/history?", ChatHistoryHandler),
(r"api/ai/providers?", ModelProviderHandler),
(r"api/ai/providers/embeddings?", EmbeddingsModelProviderHandler),
(r"api/ai/completion/inline?", InlineCompletionHandler),
(r"api/ai/completion/inline/?", InlineCompletionHandler),
]

allowed_providers = List(
Expand Down
2 changes: 1 addition & 1 deletion packages/jupyter-ai/jupyter_ai/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ def open(self):

self.log.info(f"Inline completion connected. ID: {client_id}")
self.log.debug(
f"Inline completion sessions are: {self.root_chat_handlers.keys()}"
f"Inline completion sessions are: {self.websocket_sessions.keys()}"
)

async def on_message(self, message):
Expand Down
12 changes: 10 additions & 2 deletions packages/jupyter-ai/jupyter_ai/llm_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
class BaseLLMHandler:
"""Base class containing shared methods and attributes used by LLM handler classes."""

handler_kind: str

def __init__(
self,
log: Logger,
Expand All @@ -20,6 +22,10 @@ def __init__(
self.llm_params = None
self.llm_chain = None

def model_changed_callback(self):
"""Method which can be overridden in sub-classes to listen to model change."""
pass

def get_llm_chain(self):
lm_provider = self.config_manager.lm_provider
lm_provider_params = self.config_manager.lm_provider_params
Expand All @@ -38,12 +44,14 @@ def get_llm_chain(self):

if curr_lm_id != next_lm_id:
self.log.info(
f"Switching chat language model from {curr_lm_id} to {next_lm_id}."
f"Switching {self.handler_kind} language model from {curr_lm_id} to {next_lm_id}."
)
self.create_llm_chain(lm_provider, lm_provider_params)
self.model_changed_callback()
elif self.llm_params != lm_provider_params:
self.log.info("Chat model params changed, updating the llm chain.")
self.log.info(
f"{self.handler_kind} model params changed, updating the llm chain."
)
self.create_llm_chain(lm_provider, lm_provider_params)

self.llm_params = lm_provider_params
Expand Down
2 changes: 1 addition & 1 deletion packages/jupyter-ai/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
"@jupyterlab/cells": "^4",
"@jupyterlab/codeeditor": "^4",
"@jupyterlab/codemirror": "^4",
"@jupyterlab/collaboration": "^3",
"@jupyterlab/completer": "^4.1.0-alpha.3",
"@jupyterlab/coreutils": "^6",
"@jupyterlab/fileeditor": "^4",
"@jupyterlab/notebook": "^4",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ function ExistingApiKey(props: ExistingApiKeyProps) {
}, [input]);

const onError = useCallback(
emsg => {
(emsg: any) => {
props.alert.show('error', emsg);
},
[props.alert]
Expand Down
33 changes: 33 additions & 0 deletions packages/jupyter-ai/src/handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ import { URLExt } from '@jupyterlab/coreutils';

import { ServerConnection } from '@jupyterlab/services';

import type { IInlineCompletionList } from '@jupyterlab/completer';

const API_NAMESPACE = 'api/ai';

/**
Expand Down Expand Up @@ -55,6 +57,32 @@ export namespace AiService {
prompt: string;
};

export type InlineCompletionRequest = {
number: number;
path?: string;
/* The model has to complete given prefix */
prefix: string;
/* The model may consider the following suffix */
suffix: string;
mime: string;
language?: string;
cell_id?: string;
};

export type InlineCompletionReply = {
/**
* Type for this message can be skipped (`inline_completion` is presumed default).
**/
type?: 'inline_completion';
list: IInlineCompletionList;
reply_to: number;
};

export type InlineCompletionModelChanged = {
type: 'model_changed';
model: string;
};

export type Collaborator = {
username: string;
initials: string;
Expand Down Expand Up @@ -100,6 +128,11 @@ export namespace AiService {
| ConnectionMessage
| ClearMessage;

export type CompleterMessage =
| InlineCompletionReply
| ConnectionMessage
| InlineCompletionModelChanged;

export type ChatHistory = {
messages: ChatMessage[];
};
Expand Down
3 changes: 2 additions & 1 deletion packages/jupyter-ai/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import { buildChatSidebar } from './widgets/chat-sidebar';
import { SelectionWatcher } from './selection-watcher';
import { ChatHandler } from './chat_handler';
import { buildErrorWidget } from './widgets/chat-error';
import { inlineCompletionProvider } from './inline-completions';

export type DocumentTracker = IWidgetTracker<IDocumentWidget>;

Expand Down Expand Up @@ -60,4 +61,4 @@ const plugin: JupyterFrontEndPlugin<void> = {
}
};

export default plugin;
export default [plugin, inlineCompletionProvider];
Loading

0 comments on commit 68609ea

Please sign in to comment.