diff --git a/packages/jupyter-ai/jupyter_ai/completions/handlers/base.py b/packages/jupyter-ai/jupyter_ai/completions/handlers/base.py
index 3c8daa3f5..a1cdeb0d1 100644
--- a/packages/jupyter-ai/jupyter_ai/completions/handlers/base.py
+++ b/packages/jupyter-ai/jupyter_ai/completions/handlers/base.py
@@ -1,13 +1,14 @@
 import traceback

 # necessary to prevent circular import
-from typing import TYPE_CHECKING, Dict
+from typing import TYPE_CHECKING, AsyncIterator, Dict

 from jupyter_ai.completions.models import (
     CompletionError,
     InlineCompletionList,
     InlineCompletionReply,
     InlineCompletionRequest,
+    InlineCompletionStreamChunk,
     ModelChangedNotification,
 )
 from jupyter_ai.config_manager import ConfigManager, Logger
@@ -51,6 +52,15 @@ async def process_message(
         """
         raise NotImplementedError("Should be implemented by subclasses.")

+    async def stream(
+        self, message: InlineCompletionRequest
+    ) -> AsyncIterator[InlineCompletionStreamChunk]:
+        """
+        Stream the inline completion as it is generated. Completion handlers
+        (subclasses) can implement this method.
+        """
+        raise NotImplementedError()
+
     async def _handle_exc(self, e: Exception, message: InlineCompletionRequest):
         error = CompletionError(
             type=e.__class__.__name__,
diff --git a/packages/jupyter-ai/jupyter_ai/completions/handlers/default.py b/packages/jupyter-ai/jupyter_ai/completions/handlers/default.py
index 0267717cb..7cf771342 100644
--- a/packages/jupyter-ai/jupyter_ai/completions/handlers/default.py
+++ b/packages/jupyter-ai/jupyter_ai/completions/handlers/default.py
@@ -1,18 +1,20 @@
 from typing import Dict, Type

 from jupyter_ai_magics.providers import BaseProvider
-from langchain.chains import LLMChain
 from langchain.prompts import (
     ChatPromptTemplate,
     HumanMessagePromptTemplate,
     PromptTemplate,
     SystemMessagePromptTemplate,
 )
+from langchain.schema.output_parser import StrOutputParser
+from langchain.schema.runnable import Runnable

 from ..models import (
     InlineCompletionList,
     InlineCompletionReply,
     InlineCompletionRequest,
+    InlineCompletionStreamChunk,
     ModelChangedNotification,
 )
 from .base import BaseInlineCompletionHandler
@@ -45,7 +47,7 @@


 class DefaultInlineCompletionHandler(BaseInlineCompletionHandler):
-    llm_chain: LLMChain
+    llm_chain: Runnable

     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -79,28 +81,82 @@ def create_llm_chain(
             )

         self.llm = llm
-        self.llm_chain = LLMChain(llm=llm, prompt=prompt_template, verbose=True)
+        self.llm_chain = prompt_template | llm | StrOutputParser()

     async def process_message(
         self, request: InlineCompletionRequest
     ) -> InlineCompletionReply:
-        self.get_llm_chain()
-        suffix = request.suffix.strip()
-        prediction = await self.llm_chain.apredict(
-            prefix=request.prefix,
-            # only add the suffix template if the suffix is there to save input tokens/computation time
-            after=AFTER_TEMPLATE.format(suffix=suffix) if suffix else "",
-            language=request.language,
-            filename=request.path.split("/")[-1] if request.path else "untitled",
-            stop=["\n```"],
-        )
-        prediction = self._post_process_suggestion(prediction, request)
+        if request.stream:
+            token = self._token_from_request(request, 0)
+            return InlineCompletionReply(
+                list=InlineCompletionList(
+                    items=[
+                        {
+                            # insert text starts empty as we do not pre-generate any part
+                            "insertText": "",
+                            "isIncomplete": True,
+                            "token": token,
+                        }
+                    ]
+                ),
+                reply_to=request.number,
+            )
+        else:
+            self.get_llm_chain()
+            model_arguments = self._template_inputs_from_request(request)
+            suggestion = await self.llm_chain.ainvoke(input=model_arguments)
+            suggestion = self._post_process_suggestion(suggestion, request)
+            return InlineCompletionReply(
+                list=InlineCompletionList(items=[{"insertText": suggestion}]),
+                reply_to=request.number,
+            )

-        return InlineCompletionReply(
-            list=InlineCompletionList(items=[{"insertText": prediction}]),
+    async def stream(self, request: InlineCompletionRequest):
+        self.get_llm_chain()
+        token = self._token_from_request(request, 0)
+        model_arguments = self._template_inputs_from_request(request)
+        suggestion = ""
+        async for fragment in self.llm_chain.astream(input=model_arguments):
+            suggestion += fragment
+            if suggestion.startswith("```"):
+                if "\n" not in suggestion:
+                    # we are not ready to apply post-processing
+                    continue
+                else:
+                    suggestion = self._post_process_suggestion(suggestion, request)
+            yield InlineCompletionStreamChunk(
+                type="stream",
+                response={"insertText": suggestion, "token": token},
+                reply_to=request.number,
+                done=False,
+            )
+        # at the end send a message confirming that we are done
+        yield InlineCompletionStreamChunk(
+            type="stream",
+            response={"insertText": suggestion, "token": token},
             reply_to=request.number,
+            done=True,
         )

+    def _token_from_request(self, request: InlineCompletionRequest, suggestion: int):
+        """Generate a deterministic token (for matching streamed messages)
+        using request number and suggestion number"""
+        return f"t{request.number}s{suggestion}"
+
+    def _template_inputs_from_request(self, request: InlineCompletionRequest) -> Dict:
+        suffix = request.suffix.strip()
+        # only add the suffix template if the suffix is there to save input tokens/computation time
+        after = AFTER_TEMPLATE.format(suffix=suffix) if suffix else ""
+        filename = request.path.split("/")[-1] if request.path else "untitled"
+
+        return {
+            "prefix": request.prefix,
+            "after": after,
+            "language": request.language,
+            "filename": filename,
+            "stop": ["\n```"],
+        }
+
     def _post_process_suggestion(
         self, suggestion: str, request: InlineCompletionRequest
     ) -> str:
diff --git a/packages/jupyter-ai/jupyter_ai/completions/models.py b/packages/jupyter-ai/jupyter_ai/completions/models.py
index 6f60ddbc5..a64610036 100644
--- a/packages/jupyter-ai/jupyter_ai/completions/models.py
+++ b/packages/jupyter-ai/jupyter_ai/completions/models.py
@@ -18,6 +18,8 @@ class InlineCompletionRequest(BaseModel):
     suffix: str
     # media type for the current language, e.g. `text/x-python`
     mime: str
+    # whether to stream the response (if supported by the model)
+    stream: bool
     # path to the notebook of file for which the completions are generated
     path: Optional[str]
     # language inferred from the document mime type (if possible)
@@ -36,7 +38,7 @@ class InlineCompletionItem(BaseModel):
     insertText: str
     filterText: Optional[str]
     isIncomplete: Optional[bool]
-    token: Optional[bool]
+    token: Optional[str]


 class CompletionError(BaseModel):
@@ -59,6 +61,16 @@
     error: Optional[CompletionError]


+class InlineCompletionStreamChunk(BaseModel):
+    """Message sent from model to client with the infill suggestions"""
+
+    type: Literal["stream"] = "stream"
+    response: InlineCompletionItem
+    reply_to: int
+    done: bool
+    error: Optional[CompletionError]
+
+
 class ModelChangedNotification(BaseModel):
     type: Literal["model_changed"] = "model_changed"
     model: Optional[str]
diff --git a/packages/jupyter-ai/jupyter_ai/handlers.py b/packages/jupyter-ai/jupyter_ai/handlers.py
index 0161709b6..bbf86d0a5 100644
--- a/packages/jupyter-ai/jupyter_ai/handlers.py
+++ b/packages/jupyter-ai/jupyter_ai/handlers.py
@@ -305,10 +305,17 @@ async def on_message(self, message):
             self.log.error(e)
             return

-        # do not await this, as it blocks the parent task responsible for
-        # handling messages from a websocket. instead, process each message
-        # as a distinct concurrent task.
-        self.loop.create_task(self._complete(request))
+        if request.stream:
+            try:
+                stream_coroutine = self._stream(request)
+            except NotImplementedError:
+                self.log.info(
+                    "Not streaming as handler does not implement stream() method"
+                )
+                stream_coroutine = self._complete(request)
+            self.loop.create_task(stream_coroutine)
+        else:
+            self.loop.create_task(self._complete(request))

     async def _complete(self, request: InlineCompletionRequest):
         start = time.time()
@@ -317,6 +324,14 @@
         self.log.info(f"Inline completion handler resolved in {latency_ms} ms.")
         self.write_message(reply.dict())

+    async def _stream(self, request: InlineCompletionRequest):
+        start = time.time()
+        handler = self.settings["jai_inline_completion_handler"]
+        async for chunk in handler.stream(request):
+            self.write_message(chunk.dict())
+        latency_ms = round((time.time() - start) * 1000)
+        self.log.info(f"Inline completion streaming completed in {latency_ms} ms.")
+
     def on_close(self):
         self.log.debug(f"Disconnecting client {self.client_id}")
         self.websocket_sessions.pop(self.client_id, None)
diff --git a/packages/jupyter-ai/src/handler.ts b/packages/jupyter-ai/src/handler.ts
index cf3ca08d4..5758c7275 100644
--- a/packages/jupyter-ai/src/handler.ts
+++ b/packages/jupyter-ai/src/handler.ts
@@ -2,7 +2,10 @@
 import { URLExt } from '@jupyterlab/coreutils';

 import { ServerConnection } from '@jupyterlab/services';
-import type { IInlineCompletionList } from '@jupyterlab/completer';
+import type {
+  IInlineCompletionList,
+  IInlineCompletionItem
+} from '@jupyterlab/completer';

 const API_NAMESPACE = 'api/ai';

@@ -65,6 +68,8 @@
     /* The model may consider the following suffix */
     suffix: string;
     mime: string;
+    /* Whether to stream the response (if streaming is supported by the model) */
+    stream: boolean;
     language?: string;
     cell_id?: string;
   };
@@ -84,6 +89,13 @@
     error?: CompletionError;
   };

+  export type InlineCompletionStreamChunk = {
+    type: 'stream';
+    response: IInlineCompletionItem;
+    reply_to: number;
+    done: boolean;
+  };
+
   export type InlineCompletionModelChanged = {
     type: 'model_changed';
     model: string;
@@ -137,6 +149,7 @@
   export type CompleterMessage =
     | InlineCompletionReply
     | ConnectionMessage
+    | InlineCompletionStreamChunk
     | InlineCompletionModelChanged;

   export type ChatHistory = {
diff --git a/packages/jupyter-ai/src/inline-completions.ts b/packages/jupyter-ai/src/inline-completions.ts
index 75c14bb34..6f5ff4660 100644
--- a/packages/jupyter-ai/src/inline-completions.ts
+++ b/packages/jupyter-ai/src/inline-completions.ts
@@ -4,6 +4,7 @@
 } from '@jupyterlab/application';
 import {
   ICompletionProviderManager,
+  InlineCompletionTriggerKind,
   IInlineCompletionProvider,
   IInlineCompletionContext,
   CompletionHandler
@@ -28,6 +29,8 @@ import { IJupyternautStatus } from './tokens';

 const SERVICE_URL = 'api/ai/completion/inline';

+type StreamChunk = AiService.InlineCompletionStreamChunk;
+
 export class CompletionWebsocketHandler implements IDisposable {
   /**
    * The server settings used to make API requests.
@@ -64,10 +67,20 @@
     });
   }

+  /**
+   * Signal emitted when completion AI model changes.
+   */
   get modelChanged(): ISignal<CompletionWebsocketHandler, string> {
     return this._modelChanged;
   }

+  /**
+   * Signal emitted when a new chunk of completion is streamed.
+   */
+  get streamed(): ISignal<CompletionWebsocketHandler, StreamChunk> {
+    return this._streamed;
+  }
+
   /**
    * Whether the completion handler is disposed.
    */
@@ -106,6 +119,10 @@
         this._initialized.resolve();
         break;
       }
+      case 'stream': {
+        this._streamed.emit(message);
+        break;
+      }
       default: {
         if (message.reply_to in this._replyForResolver) {
           this._replyForResolver[message.reply_to](message);
@@ -161,6 +178,7 @@
   private _isDisposed = false;
   private _socket: WebSocket | null = null;
   private _modelChanged = new Signal<CompletionWebsocketHandler, string>(this);
+  private _streamed = new Signal<CompletionWebsocketHandler, StreamChunk>(this);
   private _initialized: PromiseDelegate<void> = new PromiseDelegate<void>();
 }

@@ -187,6 +205,7 @@
         this._currentModel = model;
       }
     );
+    options.completionHandler.streamed.connect(this._receiveStreamChunk, this);
   }

   get name() {
@@ -226,6 +245,19 @@
       path = context.widget.context.path;
     }
     const number = ++this._counter;
+
+    const streamPreference = this._settings.streaming;
+    const stream =
+      streamPreference === 'always'
+        ? true
+        : streamPreference === 'never'
+        ? false
+        : context.triggerKind === InlineCompletionTriggerKind.Invoke;
+
+    if (stream) {
+      // Reset stream promises from the previous request
+      this._streamPromises.clear();
+    }
     const result = await this.options.completionHandler.sendMessage({
       path: context.session?.path,
       mime,
@@ -233,8 +265,10 @@
       suffix: this._suffixFromRequest(request),
       language: this._resolveLanguage(language),
       number,
+      stream,
       cell_id: cellId
     });
+
     const error = result.error;
     if (error) {
       Notification.emit(`Inline completion failed: ${error.type}`, 'error', {
@@ -243,10 +277,9 @@
           {
             label: 'Show Traceback',
             callback: () => {
-              showErrorMessage(
-                'Inline completion failed on the server side',
-                error.traceback
-              );
+              showErrorMessage('Inline completion failed on the server side', {
+                message: error.traceback
+              });
             }
           }
         ]
@@ -258,6 +291,20 @@
     return result.list;
   }

+  /**
+   * Stream a reply for completion identified by given `token`.
+   */
+  async *stream(token: string) {
+    let done = false;
+    while (!done) {
+      const delegate = new PromiseDelegate<StreamChunk>();
+      this._streamPromises.set(token, delegate);
+      const promise = delegate.promise;
+      yield promise;
+      done = (await promise).done;
+    }
+  }
+
   get schema(): ISettingRegistry.IProperty {
     const knownLanguages = this.options.languageRegistry.getLanguages();
     return {
@@ -287,6 +334,16 @@
           },
           description:
             'Languages for which the completions should not be shown.'
+        },
+        streaming: {
+          title: 'Streaming',
+          type: 'string',
+          oneOf: [
+            { const: 'always', title: 'Always' },
+            { const: 'manual', title: 'When invoked manually' },
+            { const: 'never', title: 'Never' }
+          ],
+          description: 'Whether to show suggestions as they are generated'
         }
       },
       default: JupyterAIInlineProvider.DEFAULT_SETTINGS as any
@@ -305,6 +362,28 @@
     return !this._settings.disabledLanguages.includes(language);
   }

+  /**
+   * Process the stream chunk to make it available in the awaiting generator.
+   */
+  private _receiveStreamChunk(
+    _emitter: CompletionWebsocketHandler,
+    chunk: StreamChunk
+  ) {
+    const token = chunk.response.token;
+    if (!token) {
+      throw Error('Stream chunks must define `token` in `response`');
+    }
+    const delegate = this._streamPromises.get(token);
+    if (!delegate) {
+      console.warn('Unhandled stream chunk');
+    } else {
+      delegate.resolve(chunk);
+      if (chunk.done) {
+        this._streamPromises.delete(token);
+      }
+    }
+  }
+
   /**
    * Extract prefix from request, accounting for context window limit.
    */
@@ -341,6 +420,8 @@

   private _settings: JupyterAIInlineProvider.ISettings =
     JupyterAIInlineProvider.DEFAULT_SETTINGS;
+  private _streamPromises: Map<string, PromiseDelegate<StreamChunk>> =
+    new Map();
   private _currentModel = '';
   private _counter = 0;
 }
@@ -356,6 +437,7 @@ namespace JupyterAIInlineProvider {
     debouncerDelay: number;
     enabled: boolean;
     disabledLanguages: string[];
+    streaming: 'always' | 'manual' | 'never';
   }
   export const DEFAULT_SETTINGS: ISettings = {
     maxPrefix: 10000,
@@ -366,7 +448,8 @@
     debouncerDelay: 250,
     enabled: true,
     // ipythongfm means "IPython GitHub Flavoured Markdown"
-    disabledLanguages: ['ipythongfm']
+    disabledLanguages: ['ipythongfm'],
+    streaming: 'manual'
   };
 }
