diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py index 8c8a1f917..39b1b8adf 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py @@ -485,7 +485,7 @@ async def stream_inline_completions( chain = self._create_completion_chain() token = completion.token_from_request(request, 0) model_arguments = completion.template_inputs_from_request(request) - suggestion = "" + suggestion = processed_suggestion = "" # send an incomplete `InlineCompletionReply`, indicating to the # client that LLM output is about to streamed across this connection. @@ -505,17 +505,12 @@ async def stream_inline_completions( async for fragment in chain.astream(input=model_arguments): suggestion += fragment - if suggestion.startswith("```"): - if "\n" not in suggestion: - # we are not ready to apply post-processing - continue - else: - suggestion = completion.post_process_suggestion(suggestion, request) - elif suggestion.rstrip().endswith("```"): - suggestion = completion.post_process_suggestion(suggestion, request) + processed_suggestion = completion.post_process_suggestion( + suggestion, request + ) yield InlineCompletionStreamChunk( type="stream", - response={"insertText": suggestion, "token": token}, + response={"insertText": processed_suggestion, "token": token}, reply_to=request.number, done=False, ) @@ -523,7 +518,7 @@ async def stream_inline_completions( # finally, send a message confirming that we are done yield InlineCompletionStreamChunk( type="stream", - response={"insertText": suggestion, "token": token}, + response={"insertText": processed_suggestion, "token": token}, reply_to=request.number, done=True, ) diff --git a/packages/jupyter-ai/jupyter_ai/tests/completions/test_handlers.py b/packages/jupyter-ai/jupyter_ai/tests/completions/test_handlers.py index c5b5d1eea..d87b75b9d 100644 --- a/packages/jupyter-ai/jupyter_ai/tests/completions/test_handlers.py +++ b/packages/jupyter-ai/jupyter_ai/tests/completions/test_handlers.py @@ -98,13 +98,16 @@ async def test_handle_request(inline_handler): assert suggestions[0].insertText == "Test response" +expected_suggestions_cases = [ + ("```python\nTest python code\n```", "Test python code"), + ("```\ntest\n```\n \n", "test"), + ("```hello```world```", "hello```world"), +] + + @pytest.mark.parametrize( "response,expected_suggestion", - [ - ("```python\nTest python code\n```", "Test python code"), - ("```\ntest\n```\n \n", "test"), - ("```hello```world```", "hello```world"), - ], + expected_suggestions_cases, ) async def test_handle_request_with_spurious_fragments(response, expected_suggestion): inline_handler = MockCompletionHandler( @@ -128,6 +131,32 @@ async def test_handle_request_with_spurious_fragments(response, expected_suggest assert suggestions[0].insertText == expected_suggestion +@pytest.mark.parametrize( + "response,expected_suggestion", + expected_suggestions_cases, +) +async def test_handle_request_with_spurious_fragments_stream( + response, expected_suggestion +): + inline_handler = MockCompletionHandler( + lm_provider=MockProvider, + lm_provider_params={ + "model_id": "model", + "responses": [response], + }, + ) + dummy_request = InlineCompletionRequest( + number=1, prefix="", suffix="", mime="", stream=True + ) + + await inline_handler.handle_stream_request(dummy_request) + assert len(inline_handler.messages) == 3 + # the streamed fragment should not include spurious fragments + assert inline_handler.messages[1].response.insertText == expected_suggestion + # the final state should not include spurious fragments either + assert inline_handler.messages[2].response.insertText == expected_suggestion + + async def test_handle_stream_request(): inline_handler = MockCompletionHandler( lm_provider=MockProvider, diff --git a/packages/jupyter-ai/schema/plugin.json b/packages/jupyter-ai/schema/plugin.json index 57e8ad8ea..78804b5c6 100644 --- a/packages/jupyter-ai/schema/plugin.json +++ b/packages/jupyter-ai/schema/plugin.json @@ -4,6 +4,14 @@ "description": "JupyterLab generative artificial intelligence integration.", "jupyter.lab.setting-icon": "jupyter-ai::chat", "jupyter.lab.setting-icon-label": "Jupyter AI Chat", + "jupyter.lab.shortcuts": [ + { + "command": "jupyter-ai:focus-chat-input", + "keys": ["Accel Shift 1"], + "selector": "body", + "preventDefault": false + } + ], "additionalProperties": false, "type": "object" } diff --git a/packages/jupyter-ai/src/components/chat-input.tsx b/packages/jupyter-ai/src/components/chat-input.tsx index 6e13b9679..7da7fc778 100644 --- a/packages/jupyter-ai/src/components/chat-input.tsx +++ b/packages/jupyter-ai/src/components/chat-input.tsx @@ -1,4 +1,4 @@ -import React, { useEffect, useState } from 'react'; +import React, { useEffect, useRef, useState } from 'react'; import { Autocomplete, @@ -22,6 +22,7 @@ import { HideSource, AutoFixNormal } from '@mui/icons-material'; +import { ISignal } from '@lumino/signaling'; import { AiService } from '../handler'; import { SendButton, SendButtonProps } from './chat-input/send-button'; @@ -33,6 +34,7 @@ type ChatInputProps = { onSend: (selection?: AiService.Selection) => unknown; hasSelection: boolean; includeSelection: boolean; + focusInputSignal: ISignal; toggleIncludeSelection: () => unknown; replaceSelection: boolean; toggleReplaceSelection: () => unknown; @@ -131,6 +133,24 @@ export function ChatInput(props: ChatInputProps): JSX.Element { // controls whether the slash command autocomplete is open const [open, setOpen] = useState(false); + // store reference to the input element to enable focusing it easily + const inputRef = useRef(); + + /** + * Effect: connect the signal emitted on input focus request. + */ + useEffect(() => { + const focusInputElement = () => { + if (inputRef.current) { + inputRef.current.focus(); + } + }; + props.focusInputSignal.connect(focusInputElement); + return () => { + props.focusInputSignal.disconnect(focusInputElement); + }; + }, []); + /** * Effect: Open the autocomplete when the user types a slash into an empty * chat input. Close the autocomplete when the user clears the chat input. @@ -284,6 +304,7 @@ export function ChatInput(props: ChatInputProps): JSX.Element { multiline placeholder="Ask Jupyternaut" onKeyDown={handleKeyDown} + inputRef={inputRef} InputProps={{ ...params.InputProps, endAdornment: ( diff --git a/packages/jupyter-ai/src/components/chat.tsx b/packages/jupyter-ai/src/components/chat.tsx index c84ae022b..da63ff39f 100644 --- a/packages/jupyter-ai/src/components/chat.tsx +++ b/packages/jupyter-ai/src/components/chat.tsx @@ -6,6 +6,7 @@ import ArrowBackIcon from '@mui/icons-material/ArrowBack'; import type { Awareness } from 'y-protocols/awareness'; import type { IThemeManager } from '@jupyterlab/apputils'; import { IRenderMimeRegistry } from '@jupyterlab/rendermime'; +import { ISignal } from '@lumino/signaling'; import { JlThemeProvider } from './jl-theme-provider'; import { ChatMessages } from './chat-messages'; @@ -31,10 +32,12 @@ type ChatBodyProps = { chatHandler: ChatHandler; setChatView: (view: ChatView) => void; rmRegistry: IRenderMimeRegistry; + focusInputSignal: ISignal; }; function ChatBody({ chatHandler, + focusInputSignal, setChatView: chatViewHandler, rmRegistry: renderMimeRegistry }: ChatBodyProps): JSX.Element { @@ -162,6 +165,7 @@ function ChatBody({ onSend={onSend} hasSelection={!!textSelection?.text} includeSelection={includeSelection} + focusInputSignal={focusInputSignal} toggleIncludeSelection={() => setIncludeSelection(includeSelection => !includeSelection) } @@ -192,6 +196,7 @@ export type ChatProps = { completionProvider: IJaiCompletionProvider | null; openInlineCompleterSettings: () => void; activeCellManager: ActiveCellManager; + focusInputSignal: ISignal; }; enum ChatView { @@ -244,6 +249,7 @@ export function Chat(props: ChatProps): JSX.Element { chatHandler={props.chatHandler} setChatView={setView} rmRegistry={props.rmRegistry} + focusInputSignal={props.focusInputSignal} /> )} {view === ChatView.Settings && ( diff --git a/packages/jupyter-ai/src/index.ts b/packages/jupyter-ai/src/index.ts index 8a0298f40..cd9d8b322 100644 --- a/packages/jupyter-ai/src/index.ts +++ b/packages/jupyter-ai/src/index.ts @@ -21,9 +21,17 @@ import { statusItemPlugin } from './status'; import { IJaiCompletionProvider } from './tokens'; import { IRenderMimeRegistry } from '@jupyterlab/rendermime'; import { ActiveCellManager } from './contexts/active-cell-context'; +import { Signal } from '@lumino/signaling'; export type DocumentTracker = IWidgetTracker; +export namespace CommandIDs { + /** + * Command to focus the input. + */ + export const focusChatInput = 'jupyter-ai:focus-chat-input'; +} + /** * Initialization data for the jupyter_ai extension. */ @@ -66,7 +74,9 @@ const plugin: JupyterFrontEndPlugin = { }); }; - let chatWidget: ReactWidget | null = null; + const focusInputSignal = new Signal({}); + + let chatWidget: ReactWidget; try { await chatHandler.initialize(); chatWidget = buildChatSidebar( @@ -77,7 +87,8 @@ const plugin: JupyterFrontEndPlugin = { rmRegistry, completionProvider, openInlineCompleterSettings, - activeCellManager + activeCellManager, + focusInputSignal ); } catch (e) { chatWidget = buildErrorWidget(themeManager); @@ -91,6 +102,15 @@ const plugin: JupyterFrontEndPlugin = { if (restorer) { restorer.add(chatWidget, 'jupyter-ai-chat'); } + + // Define jupyter-ai commands + app.commands.addCommand(CommandIDs.focusChatInput, { + execute: () => { + app.shell.activateById(chatWidget.id); + focusInputSignal.emit(); + }, + label: 'Focus the jupyter-ai chat' + }); } }; diff --git a/packages/jupyter-ai/src/widgets/chat-sidebar.tsx b/packages/jupyter-ai/src/widgets/chat-sidebar.tsx index d6a3d9cdb..40cf5945f 100644 --- a/packages/jupyter-ai/src/widgets/chat-sidebar.tsx +++ b/packages/jupyter-ai/src/widgets/chat-sidebar.tsx @@ -1,4 +1,5 @@ import React from 'react'; +import { ISignal } from '@lumino/signaling'; import { ReactWidget } from '@jupyterlab/apputils'; import type { IThemeManager } from '@jupyterlab/apputils'; import type { Awareness } from 'y-protocols/awareness'; @@ -19,7 +20,8 @@ export function buildChatSidebar( rmRegistry: IRenderMimeRegistry, completionProvider: IJaiCompletionProvider | null, openInlineCompleterSettings: () => void, - activeCellManager: ActiveCellManager + activeCellManager: ActiveCellManager, + focusInputSignal: ISignal ): ReactWidget { const ChatWidget = ReactWidget.create( ); ChatWidget.id = 'jupyter-ai::chat';