From a3fbf28118b6999da9b3a8d6f36a049e15556480 Mon Sep 17 00:00:00 2001 From: michael Date: Tue, 17 Sep 2024 16:28:21 +0800 Subject: [PATCH] refactor autocomplete to remove hardcoded '/' and '@' prefix --- .../jupyter_ai/context_providers/base.py | 37 +++++++++---- .../jupyter_ai/context_providers/file.py | 16 +++--- packages/jupyter-ai/jupyter_ai/handlers.py | 18 +++++-- packages/jupyter-ai/jupyter_ai/models.py | 21 +++++--- .../jupyter-ai/src/components/chat-input.tsx | 54 ++++++++++--------- packages/jupyter-ai/src/handler.ts | 2 +- 6 files changed, 92 insertions(+), 56 deletions(-) diff --git a/packages/jupyter-ai/jupyter_ai/context_providers/base.py b/packages/jupyter-ai/jupyter_ai/context_providers/base.py index 7f7adabfe..bd5e15720 100644 --- a/packages/jupyter-ai/jupyter_ai/context_providers/base.py +++ b/packages/jupyter-ai/jupyter_ai/context_providers/base.py @@ -106,7 +106,7 @@ class ContextCommand(BaseModel): @property def id(self) -> str: - return self.cmd.partition(":")[0][1:] + return self.cmd.partition(":")[0] @property def arg(self) -> Optional[str]: @@ -122,11 +122,17 @@ def __hash__(self) -> int: class BaseCommandContextProvider(BaseContextProvider): + id_prefix: ClassVar[str] = "@" + only_start: ClassVar[bool] = False requires_arg: ClassVar[bool] = False remove_from_prompt: ClassVar[bool] = ( False # whether the command should be removed from prompt ) + @property + def command_id(self) -> str: + return self.id_prefix + self.id + @property def pattern(self) -> str: # arg pattern allows for arguments between quotes or spaces with escape character ('\ ') @@ -153,20 +159,14 @@ def get_arg_options(self, arg_prefix: str) -> List[ListOptionsEntry]: """ if self.requires_arg: # default implementation that should be modified if 'requires_arg' is True - return [ - ListOptionsEntry.from_arg( - type="@", - id=self.id, - description=self.description, - arg=arg_prefix, - is_complete=True, - ) - ] + return [self._make_arg_option(arg_prefix)] return [] def _find_commands(self, text: str) -> List[ContextCommand]: # finds commands of the context provider in the text - matches = re.finditer(self.pattern, text) + matches = list(re.finditer(self.pattern, text)) + if self.only_start: + matches = [match for match in matches if match.start() == 0] results = [] for match in matches: if not _is_within_backticks(match, text): @@ -178,6 +178,21 @@ def _replace_command(self, command: ContextCommand) -> str: return "" return command.cmd + def _make_arg_option( + self, + arg: str, + *, + is_complete: bool = True, + description: Optional[str] = None, + ) -> ListOptionsEntry: + return ListOptionsEntry.from_arg( + id=self.command_id, + description=description or self.description, + only_start=self.only_start, + arg=arg, + is_complete=is_complete, + ) + def _is_within_backticks(match, text): # potentially buggy if there is a stray backtick in text diff --git a/packages/jupyter-ai/jupyter_ai/context_providers/file.py b/packages/jupyter-ai/jupyter_ai/context_providers/file.py index e3ae0adb0..a5f98aa53 100644 --- a/packages/jupyter-ai/jupyter_ai/context_providers/file.py +++ b/packages/jupyter-ai/jupyter_ai/context_providers/file.py @@ -27,7 +27,11 @@ def get_arg_options(self, arg_prefix: str) -> List[ListOptionsEntry]: path_prefix = arg_prefix if is_abs else os.path.join(self.base_dir, arg_prefix) path_prefix = path_prefix return [ - self._make_option(path, is_abs, is_dir) + self._make_arg_option( + arg=self._make_path(path, is_abs, is_dir), + description="Directory" if is_dir else "File", + is_complete=not is_dir, + ) for path in glob.glob(path_prefix + "*") if ( (is_dir := os.path.isdir(path)) @@ -35,18 +39,12 @@ def get_arg_options(self, arg_prefix: str) -> List[ListOptionsEntry]: ) ] - def _make_option(self, path: str, is_abs: bool, is_dir: bool) -> ListOptionsEntry: + def _make_path(self, path: str, is_abs: bool, is_dir: bool) -> str: if not is_abs: path = os.path.relpath(path, self.base_dir) if is_dir: path += "/" - return ListOptionsEntry.from_arg( - type="@", - id=self.id, - description="Directory" if is_dir else "File", - arg=path, - is_complete=not is_dir, - ) + return path async def make_context_prompt(self, message: HumanChatMessage) -> str: commands = set(self._find_commands(message.prompt)) diff --git a/packages/jupyter-ai/jupyter_ai/handlers.py b/packages/jupyter-ai/jupyter_ai/handlers.py index 02866a35e..c105ad384 100644 --- a/packages/jupyter-ai/jupyter_ai/handlers.py +++ b/packages/jupyter-ai/jupyter_ai/handlers.py @@ -610,7 +610,15 @@ def get(self): def post(self): try: data = self.get_json_body() - context_provider = self.context_providers.get(data["id"]) + context_provider = next( + ( + cp + for cp in self.context_providers.values() + if isinstance(cp, BaseCommandContextProvider) + and cp.command_id == data["id"] + ), + None, + ) cmd = data["cmd"] response = ListOptionsResponse() @@ -658,7 +666,9 @@ def _get_slash_command_options(self) -> List[ListOptionsEntry]: options.append( ListOptionsEntry.from_command( - type="/", id=routing_type.slash_id, description=chat_handler.help + id="/" + routing_type.slash_id, + description=chat_handler.help, + only_start=True, ) ) options.sort(key=lambda opt: opt.id) @@ -667,9 +677,9 @@ def _get_slash_command_options(self) -> List[ListOptionsEntry]: def _get_context_provider_options(self) -> List[ListOptionsEntry]: options = [ ListOptionsEntry.from_command( - type="@", - id=context_provider.id, + id=context_provider.command_id, description=context_provider.description, + only_start=context_provider.only_start, requires_arg=context_provider.requires_arg, ) for context_provider in self.context_providers.values() diff --git a/packages/jupyter-ai/jupyter_ai/models.py b/packages/jupyter-ai/jupyter_ai/models.py index 5d59bdcb7..3514104bd 100644 --- a/packages/jupyter-ai/jupyter_ai/models.py +++ b/packages/jupyter-ai/jupyter_ai/models.py @@ -266,34 +266,41 @@ class ListSlashCommandsResponse(BaseModel): class ListOptionsEntry(BaseModel): - type: Literal["/", "@"] id: str + # includes the command prefix. e.g. "/clear", "@file". label: str description: str + only_start: bool + # only allows autocomplete to be triggered if command is at start of input @classmethod def from_command( cls, - type: Literal["/", "@"], id: str, description: str, + only_start: bool = False, requires_arg: bool = False, ): - label = type + id + (":" if requires_arg else " ") - return cls(type=type, id=id, description=description, label=label) + label = id + (":" if requires_arg else " ") + return cls(id=id, description=description, label=label, only_start=only_start) @classmethod def from_arg( cls, - type: Literal["/", "@"], id: str, description: str, arg: str, + only_start: bool = False, is_complete: bool = True, ): arg = arg.replace("\\ ", " ").replace(" ", "\\ ") # escape spaces - label = type + id + ":" + arg + (" " if is_complete else "") - return cls(type=type, id=id, description=description, label=label) + label = id + ":" + arg + (" " if is_complete else "") + return cls( + id=id, + description=description, + label=label, + only_start=only_start, + ) class ListOptionsResponse(BaseModel): diff --git a/packages/jupyter-ai/src/components/chat-input.tsx b/packages/jupyter-ai/src/components/chat-input.tsx index 56c2410a6..df873e0db 100644 --- a/packages/jupyter-ai/src/components/chat-input.tsx +++ b/packages/jupyter-ai/src/components/chat-input.tsx @@ -45,14 +45,15 @@ type ChatInputProps = { * unclear whether custom icons should be defined within a Lumino plugin (in the * frontend) or served from a static server route (in the backend). */ -const DEFAULT_SLASH_COMMAND_ICONS: Record = { - ask: , - clear: , - export: , - fix: , - generate: , - help: , - learn: , +const DEFAULT_COMMAND_ICONS: Record = { + '/ask': , + '/clear': , + '/export': , + '/fix': , + '/generate': , + '/help': , + '/learn': , + '@file': , unknown: }; @@ -64,9 +65,9 @@ function renderAutocompleteOption( option: AiService.AutocompleteOption ): JSX.Element { const icon = - option.id in DEFAULT_SLASH_COMMAND_ICONS - ? DEFAULT_SLASH_COMMAND_ICONS[option.id] - : DEFAULT_SLASH_COMMAND_ICONS.unknown; + option.id in DEFAULT_COMMAND_ICONS + ? DEFAULT_COMMAND_ICONS[option.id] + : DEFAULT_COMMAND_ICONS.unknown; return (
  • @@ -120,12 +121,12 @@ export function ChatInput(props: ChatInputProps): JSX.Element { useEffect(() => { async function getAutocompleteArgOptions() { let options: AiService.AutocompleteOption[] = []; - const lastWord = input.split(/(? option.id === id.slice(1) && option.type === '@' + option => option.id === id ); if (option) { const response = await AiService.listAutocompleteArgOptions({ @@ -149,10 +150,10 @@ export function ChatInput(props: ChatInputProps): JSX.Element { } }, [autocompleteCommandOptions, autocompleteArgOptions]); - // whether any option is highlighted in the slash command autocomplete + // whether any option is highlighted in the autocomplete const [highlighted, setHighlighted] = useState(false); - // controls whether the slash command autocomplete is open + // controls whether the autocomplete is open const [open, setOpen] = useState(false); // store reference to the input element to enable focusing it easily @@ -178,7 +179,7 @@ export function ChatInput(props: ChatInputProps): JSX.Element { * chat input. Close the autocomplete when the user clears the chat input. */ useEffect(() => { - if (input === '/' || input.endsWith('@')) { + if (filterAutocompleteOptions(autocompleteOptions, input).length > 0) { setOpen(true); return; } @@ -284,14 +285,15 @@ export function ChatInput(props: ChatInputProps): JSX.Element { options: AiService.AutocompleteOption[], inputValue: string ): AiService.AutocompleteOption[] { - const lastWord = inputValue.split(/(? option.label.startsWith(lastWord)); + const lastWord = getLastWord(inputValue); + if (lastWord === '') { + return []; } - return []; + const isStart = lastWord === inputValue; + return options.filter( + option => + option.label.startsWith(lastWord) && (!option.only_start || isStart) + ); } return ( @@ -387,3 +389,7 @@ export function ChatInput(props: ChatInputProps): JSX.Element { ); } + +function getLastWord(input: string): string { + return input.split(/(?