From 6e426abc880148c565633d6add055b731cce92d1 Mon Sep 17 00:00:00 2001 From: michaelchia Date: Thu, 26 Sep 2024 01:16:33 +0800 Subject: [PATCH] Framework for adding context to LLM prompt (#993) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * context provider * split base and base command context providers + replacing prompt * comment * only replace prompt if context variable in template * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Run mypy on CI, fix or ignore typing issues (#987) * Run mypy on CI * Rename, add mypy to test deps * Fix typing jupyter-ai codebase (mostly) * Three more cases * update deepmerge version specifier --------- Co-authored-by: David L. Qiu * context provider * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * mypy * black * modify backtick logic * allow for spaces in filepath * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * refactor * fixes * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix test * refactor autocomplete to remove hardcoded '/' and '@' prefix * modify context prompt template Co-authored-by: david qiu * refactor * docstrings + refactor * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * mypy * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add context providers to help * remove _examples.py and remove @learned from defaults * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * make find_commands unoverridable --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Michał Krassowski <5832902+krassowski@users.noreply.github.com> Co-authored-by: David L. 
Qiu --- .../jupyter_ai_magics/providers.py | 28 +- .../jupyter_ai/chat_handlers/base.py | 18 +- .../jupyter_ai/chat_handlers/default.py | 35 ++- .../jupyter_ai/context_providers/__init__.py | 7 + .../jupyter_ai/context_providers/_learned.py | 53 ++++ .../jupyter_ai/context_providers/base.py | 243 ++++++++++++++++++ .../jupyter_ai/context_providers/file.py | 119 +++++++++ packages/jupyter-ai/jupyter_ai/extension.py | 156 ++++++++--- packages/jupyter-ai/jupyter_ai/handlers.py | 110 ++++++++ packages/jupyter-ai/jupyter_ai/models.py | 18 ++ .../tests/test_context_providers.py | 80 ++++++ .../jupyter_ai/tests/test_handlers.py | 1 + .../jupyter-ai/src/components/chat-input.tsx | 135 +++++++--- packages/jupyter-ai/src/handler.ts | 26 ++ 14 files changed, 942 insertions(+), 87 deletions(-) create mode 100644 packages/jupyter-ai/jupyter_ai/context_providers/__init__.py create mode 100644 packages/jupyter-ai/jupyter_ai/context_providers/_learned.py create mode 100644 packages/jupyter-ai/jupyter_ai/context_providers/base.py create mode 100644 packages/jupyter-ai/jupyter_ai/context_providers/file.py create mode 100644 packages/jupyter-ai/jupyter_ai/tests/test_context_providers.py diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py index 551db8bbc..023e51a62 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py @@ -67,11 +67,25 @@ The following is a friendly conversation between you and a human. """.strip() -CHAT_DEFAULT_TEMPLATE = """Current conversation: -{history} -Human: {input} +CHAT_DEFAULT_TEMPLATE = """ +{% if context %} +Context: +{{context}} + +{% endif %} +Current conversation: +{{history}} +Human: {{input}} AI:""" +HUMAN_MESSAGE_TEMPLATE = """ +{% if context %} +Context: +{{context}} + +{% endif %} +{{input}} +""" COMPLETION_SYSTEM_PROMPT = """ You are an application built to provide helpful code completion suggestions. @@ -400,17 +414,21 @@ def get_chat_prompt_template(self) -> PromptTemplate: CHAT_SYSTEM_PROMPT ).format(provider_name=name, local_model_id=self.model_id), MessagesPlaceholder(variable_name="history"), - HumanMessagePromptTemplate.from_template("{input}"), + HumanMessagePromptTemplate.from_template( + HUMAN_MESSAGE_TEMPLATE, + template_format="jinja2", + ), ] ) else: return PromptTemplate( - input_variables=["history", "input"], + input_variables=["history", "input", "context"], template=CHAT_SYSTEM_PROMPT.format( provider_name=name, local_model_id=self.model_id ) + "\n\n" + CHAT_DEFAULT_TEMPLATE, + template_format="jinja2", ) def get_completion_prompt_template(self) -> PromptTemplate: diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py index 347bfbf83..fb2559c30 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py @@ -33,6 +33,7 @@ from langchain.pydantic_v1 import BaseModel if TYPE_CHECKING: + from jupyter_ai.context_providers import BaseCommandContextProvider from jupyter_ai.handlers import RootChatHandler from jupyter_ai.history import BoundedChatHistory from langchain_core.chat_history import BaseChatMessageHistory @@ -121,6 +122,10 @@ class BaseChatHandler: chat handlers, which is necessary for some use-cases like printing the help message.""" + context_providers: Dict[str, "BaseCommandContextProvider"] + """Dictionary of context providers. 
Allows chat handlers to reference + context providers, which can be used to provide context to the LLM.""" + def __init__( self, log: Logger, @@ -134,6 +139,7 @@ def __init__( dask_client_future: Awaitable[DaskClient], help_message_template: str, chat_handlers: Dict[str, "BaseChatHandler"], + context_providers: Dict[str, "BaseCommandContextProvider"], ): self.log = log self.config_manager = config_manager @@ -154,6 +160,7 @@ def __init__( self.dask_client_future = dask_client_future self.help_message_template = help_message_template self.chat_handlers = chat_handlers + self.context_providers = context_providers self.llm: Optional[BaseProvider] = None self.llm_params: Optional[dict] = None @@ -430,8 +437,17 @@ def send_help_message(self, human_msg: Optional[HumanChatMessage] = None) -> Non ] ) + context_commands_list = "\n".join( + [ + f"* `{cp.command_id}` — {cp.help}" + for cp in self.context_providers.values() + ] + ) + help_message_body = self.help_message_template.format( - persona_name=self.persona.name, slash_commands_list=slash_commands_list + persona_name=self.persona.name, + slash_commands_list=slash_commands_list, + context_commands_list=context_commands_list, ) help_message = AgentChatMessage( id=uuid4().hex, diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py index 2bf88bafe..dc6753b58 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py @@ -1,3 +1,4 @@ +import asyncio import time from typing import Dict, Type from uuid import uuid4 @@ -12,6 +13,7 @@ from langchain_core.runnables import ConfigurableFieldSpec from langchain_core.runnables.history import RunnableWithMessageHistory +from ..context_providers import ContextProviderException, find_commands from ..models import HumanChatMessage from .base import BaseChatHandler, SlashCommandRoutingType @@ -27,6 +29,7 @@ class DefaultChatHandler(BaseChatHandler): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + self.prompt_template = None def create_llm_chain( self, provider: Type[BaseProvider], provider_params: Dict[str, str] @@ -40,6 +43,7 @@ def create_llm_chain( prompt_template = llm.get_chat_prompt_template() self.llm = llm + self.prompt_template = prompt_template runnable = prompt_template | llm # type:ignore if not llm.manages_history: @@ -101,6 +105,17 @@ async def process_message(self, message: HumanChatMessage): self.get_llm_chain() received_first_chunk = False + inputs = {"input": message.body} + if "context" in self.prompt_template.input_variables: + # include context from context providers. + try: + context_prompt = await self.make_context_prompt(message) + except ContextProviderException as e: + self.reply(str(e), message) + return + inputs["context"] = context_prompt + inputs["input"] = self.replace_prompt(inputs["input"]) + # start with a pending message with self.pending("Generating response", message) as pending_message: # stream response in chunks. this works even if a provider does not @@ -108,7 +123,7 @@ async def process_message(self, message: HumanChatMessage): # when `_stream()` is not implemented on the LLM class. 
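
A minimal sketch (not part of the patch) of how the new jinja2-format CHAT_DEFAULT_TEMPLATE behaves: when a context provider contributes text, a "Context:" block is rendered ahead of the conversation, and when `context` is empty the `{% if %}` guard drops the block entirely. The history/input strings below are made up for illustration.

```python
from langchain.prompts import PromptTemplate

CHAT_DEFAULT_TEMPLATE = """{% if context %}
Context:
{{context}}

{% endif %}
Current conversation:
{{history}}
Human: {{input}}
AI:"""

prompt = PromptTemplate(
    input_variables=["history", "input", "context"],
    template=CHAT_DEFAULT_TEMPLATE,
    template_format="jinja2",
)

# With context: a "Context:" block precedes the conversation.
print(prompt.format(
    history="Human: hi\nAI: hello",
    input="What does utils.py do?",
    context="File: utils.py\ndef add(a, b): return a + b",
))

# Without context: the block is omitted.
print(prompt.format(history="", input="hi", context=""))
```
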
assert self.llm_chain async for chunk in self.llm_chain.astream( - {"input": message.body}, + inputs, config={"configurable": {"last_human_msg": message}}, ): if not received_first_chunk: @@ -128,3 +143,21 @@ async def process_message(self, message: HumanChatMessage): # complete stream after all chunks have been streamed self._send_stream_chunk(stream_id, "", complete=True) + + async def make_context_prompt(self, human_msg: HumanChatMessage) -> str: + return "\n\n".join( + await asyncio.gather( + *[ + provider.make_context_prompt(human_msg) + for provider in self.context_providers.values() + if find_commands(provider, human_msg.prompt) + ] + ) + ) + + def replace_prompt(self, prompt: str) -> str: + # modifies prompt by the context providers. + # some providers may modify or remove their '@' commands from the prompt. + for provider in self.context_providers.values(): + prompt = provider.replace_prompt(prompt) + return prompt diff --git a/packages/jupyter-ai/jupyter_ai/context_providers/__init__.py b/packages/jupyter-ai/jupyter_ai/context_providers/__init__.py new file mode 100644 index 000000000..7c521d848 --- /dev/null +++ b/packages/jupyter-ai/jupyter_ai/context_providers/__init__.py @@ -0,0 +1,7 @@ +from .base import ( + BaseCommandContextProvider, + ContextCommand, + ContextProviderException, + find_commands, +) +from .file import FileContextProvider diff --git a/packages/jupyter-ai/jupyter_ai/context_providers/_learned.py b/packages/jupyter-ai/jupyter_ai/context_providers/_learned.py new file mode 100644 index 000000000..5128487de --- /dev/null +++ b/packages/jupyter-ai/jupyter_ai/context_providers/_learned.py @@ -0,0 +1,53 @@ +# Currently unused as it is duplicating the functionality of the /ask command. +# TODO: Rename "learned" to something better. 
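
The `make_context_prompt` helper in default.py above fans out to every provider whose command appears in the message and joins the resulting sections with blank lines. A toy illustration of that concurrency pattern (the two stub coroutines are stand-ins, not real providers):

```python
import asyncio


async def fake_file_context() -> str:
    # stand-in for FileContextProvider.make_context_prompt
    await asyncio.sleep(0.1)  # e.g. reading files from disk
    return "Following are contents of files referenced:\nFile: a.py\n..."


async def fake_learned_context() -> str:
    # stand-in for a provider backed by the /learn index
    await asyncio.sleep(0.1)  # e.g. querying a vector store
    return "Following are snippets from potentially relevant files:\n..."


async def make_context_prompt() -> str:
    # providers whose @-commands match the prompt are awaited concurrently
    sections = await asyncio.gather(fake_file_context(), fake_learned_context())
    return "\n\n".join(sections)


print(asyncio.run(make_context_prompt()))
```
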
+from typing import List + +from jupyter_ai.chat_handlers.learn import Retriever +from jupyter_ai.models import HumanChatMessage + +from .base import BaseCommandContextProvider, ContextCommand +from .file import FileContextProvider + +FILE_CHUNK_TEMPLATE = """ +Snippet from file: {filepath} +``` +{content} +``` +""".strip() + + +class LearnedContextProvider(BaseCommandContextProvider): + id = "learned" + help = "Include content indexed from `/learn`" + remove_from_prompt = True + header = "Following are snippets from potentially relevant files:" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.retriever = Retriever(learn_chat_handler=self.chat_handlers["/learn"]) + + async def _make_context_prompt( + self, message: HumanChatMessage, commands: List[ContextCommand] + ) -> str: + if not self.retriever: + return "" + query = self._clean_prompt(message.body) + docs = await self.retriever.ainvoke(query) + excluded = self._get_repeated_files(message) + context = "\n\n".join( + [ + FILE_CHUNK_TEMPLATE.format( + filepath=d.metadata["path"], content=d.page_content + ) + for d in docs + if d.metadata["path"] not in excluded and d.page_content + ] + ) + return self.header + "\n" + context + + def _get_repeated_files(self, message: HumanChatMessage) -> List[str]: + # don't include files that are already provided by the file context provider + file_context_provider = self.context_providers.get("file") + if isinstance(file_context_provider, FileContextProvider): + return file_context_provider.get_filepaths(message) + return [] diff --git a/packages/jupyter-ai/jupyter_ai/context_providers/base.py b/packages/jupyter-ai/jupyter_ai/context_providers/base.py new file mode 100644 index 000000000..1b0953e84 --- /dev/null +++ b/packages/jupyter-ai/jupyter_ai/context_providers/base.py @@ -0,0 +1,243 @@ +import abc +import os +import re +from typing import TYPE_CHECKING, Awaitable, ClassVar, Dict, List, Optional + +from dask.distributed import Client as DaskClient +from jupyter_ai.chat_handlers.base import get_preferred_dir +from jupyter_ai.config_manager import ConfigManager, Logger +from jupyter_ai.models import ChatMessage, HumanChatMessage, ListOptionsEntry +from langchain.pydantic_v1 import BaseModel + +if TYPE_CHECKING: + from jupyter_ai.chat_handlers import BaseChatHandler + from jupyter_ai.history import BoundedChatHistory + + +class _BaseContextProvider(abc.ABC): + id: ClassVar[str] + """Unique identifier for the context provider command.""" + help: ClassVar[str] + """What this chat handler does, which third-party models it contacts, + the data it returns to the user, and so on, for display in the UI.""" + + def __init__( + self, + *, + log: Logger, + config_manager: ConfigManager, + model_parameters: Dict[str, Dict], + chat_history: List[ChatMessage], + llm_chat_memory: "BoundedChatHistory", + root_dir: str, + preferred_dir: Optional[str], + dask_client_future: Awaitable[DaskClient], + chat_handlers: Dict[str, "BaseChatHandler"], + context_providers: Dict[str, "BaseCommandContextProvider"], + ): + preferred_dir = preferred_dir or "" + self.log = log + self.config_manager = config_manager + self.model_parameters = model_parameters + self._chat_history = chat_history + self.llm_chat_memory = llm_chat_memory + self.root_dir = os.path.abspath(os.path.expanduser(root_dir)) + self.preferred_dir = get_preferred_dir(self.root_dir, preferred_dir) + self.dask_client_future = dask_client_future + self.chat_handlers = chat_handlers + self.context_providers = context_providers + + self.llm = None 
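
For a feel of what LearnedContextProvider produces, here is a simplified sketch (not part of the patch): the dicts stand in for documents returned by the /learn retriever, and the template is abbreviated (the real FILE_CHUNK_TEMPLATE wraps each snippet in a fenced code block).

```python
SNIPPET_TEMPLATE = "Snippet from file: {filepath}\n{content}"
HEADER = "Following are snippets from potentially relevant files:"

# stand-ins for langchain Documents returned by the retriever
docs = [
    {"path": "src/utils.py", "content": "def add(a, b):\n    return a + b"},
    {"path": "notes/readme.md", "content": ""},        # empty content is skipped
    {"path": "src/app.py", "content": "print(add(1, 2))"},
]
excluded = {"src/app.py"}  # e.g. already injected by @file, see _get_repeated_files

context = "\n\n".join(
    SNIPPET_TEMPLATE.format(filepath=d["path"], content=d["content"])
    for d in docs
    if d["path"] not in excluded and d["content"]
)
print(HEADER + "\n" + context)
```
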
+ + @abc.abstractmethod + async def make_context_prompt(self, message: HumanChatMessage) -> str: + """Returns a context prompt for all commands of the context provider + command. + """ + pass + + def replace_prompt(self, prompt: str) -> str: + """Modifies the prompt before sending it to the LLM.""" + return prompt + + def _clean_prompt(self, text: str) -> str: + # util for cleaning up the prompt before sending it to a retriever + for provider in self.context_providers.values(): + text = provider.replace_prompt(text) + return text + + @property + def base_dir(self) -> str: + # same as BaseChatHandler.output_dir + if self.preferred_dir and os.path.exists(self.preferred_dir): + return self.preferred_dir + else: + return self.root_dir + + def get_llm(self): + lm_provider = self.config_manager.lm_provider + lm_provider_params = self.config_manager.lm_provider_params + + curr_lm_id = ( + f'{self.llm.id}:{lm_provider_params["model_id"]}' if self.llm else None + ) + next_lm_id = ( + f'{lm_provider.id}:{lm_provider_params["model_id"]}' + if lm_provider + else None + ) + + if not lm_provider or not lm_provider_params: + return None + + if curr_lm_id != next_lm_id: + model_parameters = self.model_parameters.get( + f"{lm_provider.id}:{lm_provider_params['model_id']}", {} + ) + unified_parameters = { + "verbose": True, + **lm_provider_params, + **model_parameters, + } + llm = lm_provider(**unified_parameters) + self.llm = llm + return self.llm + + +class ContextCommand(BaseModel): + cmd: str + + @property + def id(self) -> str: + return self.cmd.partition(":")[0] + + @property + def arg(self) -> Optional[str]: + if ":" not in self.cmd: + return None + return self.cmd.partition(":")[2].strip("'\"").replace("\\ ", " ") + + def __str__(self) -> str: + return self.cmd + + def __hash__(self) -> int: + return hash(self.cmd) + + +class BaseCommandContextProvider(_BaseContextProvider): + id_prefix: ClassVar[str] = "@" + """Prefix symbol for command. Generally should not be overridden.""" + + # Configuration + requires_arg: ClassVar[bool] = False + """Whether command has an argument. E.g. '@file:'.""" + remove_from_prompt: ClassVar[bool] = False + """Whether the command should be removed from prompt when passing to LLM.""" + only_start: ClassVar[bool] = False + """Whether to command can only be inserted at the start of the prompt.""" + + @property + def command_id(self) -> str: + return self.id_prefix + self.id + + @property + def pattern(self) -> str: + # arg pattern allows for arguments between quotes or spaces with escape character ('\ ') + return ( + rf"(? str: + """Returns a context prompt for all commands of the context provider + command. + """ + commands = find_commands(self, message.prompt) + if not commands: + return "" + return await self._make_context_prompt(message, commands) + + @abc.abstractmethod + async def _make_context_prompt( + self, message: HumanChatMessage, commands: List[ContextCommand] + ) -> str: + """Returns a context prompt for the given commands.""" + pass + + def replace_prompt(self, prompt: str) -> str: + """Cleans up commands from the prompt before sending it to the LLM""" + + def replace(match): + if _is_command_call(match, prompt): + return self._replace_command(ContextCommand(cmd=match.group())) + return match.group() + + return re.sub(self.pattern, replace, prompt) + + def get_arg_options(self, arg_prefix: str) -> List[ListOptionsEntry]: + """Returns a list of autocomplete options for arguments to the command + based on the prefix. 
+ Only triggered if ':' is present after the command id (e.g. '@file:'). + """ + if self.requires_arg: + # default implementation that should be modified if 'requires_arg' is True + return [self._make_arg_option(arg_prefix)] + return [] + + def _replace_command(self, command: ContextCommand) -> str: + if self.remove_from_prompt: + return "" + return command.cmd + + def _make_arg_option( + self, + arg: str, + *, + is_complete: bool = True, + description: Optional[str] = None, + ) -> ListOptionsEntry: + arg = arg.replace("\\ ", " ").replace(" ", "\\ ") # escape spaces + label = self.command_id + ":" + arg + (" " if is_complete else "") + return ListOptionsEntry( + id=self.command_id, + description=description or self.help, + label=label, + only_start=self.only_start, + ) + + +def find_commands( + context_provider: BaseCommandContextProvider, text: str +) -> List[ContextCommand]: + # finds commands of the context provider in the text + matches = list(re.finditer(context_provider.pattern, text)) + if context_provider.only_start: + matches = [match for match in matches if match.start() == 0] + results = [] + for match in matches: + if _is_command_call(match, text): + results.append(ContextCommand(cmd=match.group())) + return results + + +class ContextProviderException(Exception): + # Used to generate a response when a context provider fails + pass + + +def _is_command_call(match, text): + """Check if the match is a command call rather than a part of a code block. + This is done by checking if there is an even number of backticks before and + after the match. If there is an odd number of backticks, the match is likely + inside a code block. + """ + # potentially buggy if there is a stray backtick in text + # e.g. "help me count the backticks '`' ... ```\n...@cmd in code\n```". + # can be addressed by having selection in context rather than in prompt. + # more generally addressed by having a better command detection mechanism + # such as placing commands within special tags. 
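
The quoting and escaping rules encoded in `pattern` and `ContextCommand.arg` above are easiest to see by example. Assuming a jupyter_ai install that includes this patch (the class is re-exported from `jupyter_ai.context_providers`):

```python
from jupyter_ai.context_providers import ContextCommand

print(ContextCommand(cmd="@file").arg)                     # None -- no argument given
print(ContextCommand(cmd="@file:src/utils.py").id)         # "@file"
print(ContextCommand(cmd="@file:'my notes.md'").arg)       # "my notes.md" -- quotes stripped
print(ContextCommand(cmd=r"@file:dir\ a/test\ 1.py").arg)  # "dir a/test 1.py" -- escapes removed
```
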
+ start, end = match.span() + before = text[:start] + after = text[end:] + return before.count("`") % 2 == 0 or after.count("`") % 2 == 0 diff --git a/packages/jupyter-ai/jupyter_ai/context_providers/file.py b/packages/jupyter-ai/jupyter_ai/context_providers/file.py new file mode 100644 index 000000000..a45004d48 --- /dev/null +++ b/packages/jupyter-ai/jupyter_ai/context_providers/file.py @@ -0,0 +1,119 @@ +import glob +import os +from typing import List + +import nbformat +from jupyter_ai.document_loaders.directory import SUPPORTED_EXTS +from jupyter_ai.models import HumanChatMessage, ListOptionsEntry + +from .base import ( + BaseCommandContextProvider, + ContextCommand, + ContextProviderException, + find_commands, +) + +FILE_CONTEXT_TEMPLATE = """ +File: {filepath} +``` +{content} +``` +""".strip() + + +class FileContextProvider(BaseCommandContextProvider): + id = "file" + help = "Include selected file's contents" + requires_arg = True + header = "Following are contents of files referenced:" + + def get_arg_options(self, arg_prefix: str) -> List[ListOptionsEntry]: + is_abs = not os.path.isabs(arg_prefix) + path_prefix = arg_prefix if is_abs else os.path.join(self.base_dir, arg_prefix) + path_prefix = path_prefix + return [ + self._make_arg_option( + arg=self._make_path(path, is_abs, is_dir), + description="Directory" if is_dir else "File", + is_complete=not is_dir, + ) + for path in glob.glob(path_prefix + "*") + if ( + (is_dir := os.path.isdir(path)) + or os.path.splitext(path)[1] in SUPPORTED_EXTS + ) + ] + + def _make_path(self, path: str, is_abs: bool, is_dir: bool) -> str: + if not is_abs: + path = os.path.relpath(path, self.base_dir) + if is_dir: + path += "/" + return path + + async def _make_context_prompt( + self, message: HumanChatMessage, commands: List[ContextCommand] + ) -> str: + context = "\n\n".join( + [ + context + for i in set(commands) + if (context := self._make_command_context(i)) + ] + ) + if not context: + return "" + return self.header + "\n" + context + + def _make_command_context(self, command: ContextCommand) -> str: + filepath = command.arg or "" + if not os.path.isabs(filepath): + filepath = os.path.join(self.base_dir, filepath) + + if not os.path.exists(filepath): + raise ContextProviderException( + f"File not found while trying to read '{filepath}' " + f"triggered by `{command}`." + ) + if os.path.isdir(filepath): + raise ContextProviderException( + f"Cannot read directory '{filepath}' triggered by `{command}`. " + f"Only files are supported." + ) + if os.path.splitext(filepath)[1] not in SUPPORTED_EXTS: + raise ContextProviderException( + f"Cannot read unsupported file type '{filepath}' triggered by `{command}`. " + f"Supported file extensions are: {', '.join(SUPPORTED_EXTS)}." + ) + try: + with open(filepath) as f: + content = f.read() + except PermissionError: + raise ContextProviderException( + f"Permission denied while trying to read '{filepath}' " + f"triggered by `{command}`." 
+ ) + return FILE_CONTEXT_TEMPLATE.format( + filepath=filepath, + content=self._process_file(content, filepath), + ) + + def _process_file(self, content: str, filepath: str): + if filepath.endswith(".ipynb"): + nb = nbformat.reads(content, as_version=4) + return "\n\n".join([cell.source for cell in nb.cells]) + return content + + def _replace_command(self, command: ContextCommand) -> str: + # replaces commands of @file: with '' + filepath = command.arg or "" + return f"'{filepath}'" + + def get_filepaths(self, message: HumanChatMessage) -> List[str]: + filepaths = [] + for command in find_commands(self, message.prompt): + filepath = command.arg or "" + if not os.path.isabs(filepath): + filepath = os.path.join(self.base_dir, filepath) + filepaths.append(filepath) + return filepaths diff --git a/packages/jupyter-ai/jupyter_ai/extension.py b/packages/jupyter-ai/jupyter_ai/extension.py index 34b484546..7d93e2c90 100644 --- a/packages/jupyter-ai/jupyter_ai/extension.py +++ b/packages/jupyter-ai/jupyter_ai/extension.py @@ -24,8 +24,10 @@ ) from .completions.handlers import DefaultInlineCompletionHandler from .config_manager import ConfigManager +from .context_providers import BaseCommandContextProvider, FileContextProvider from .handlers import ( ApiKeysHandler, + AutocompleteOptionsHandler, ChatHistoryHandler, EmbeddingsModelProviderHandler, GlobalConfigHandler, @@ -45,6 +47,9 @@ You can ask me a question using the text box below. You can also use these commands: {slash_commands_list} +You can use the following commands to add context to your questions: +{context_commands_list} + Jupyter AI includes [magic commands](https://jupyter-ai.readthedocs.io/en/latest/users/index.html#the-ai-and-ai-magic-commands) that you can use in your notebooks. For more information, see the [documentation](https://jupyter-ai.readthedocs.io). """ @@ -58,6 +63,7 @@ class AiExtension(ExtensionApp): (r"api/ai/chats/?", RootChatHandler), (r"api/ai/chats/history?", ChatHistoryHandler), (r"api/ai/chats/slash_commands?", SlashCommandsInfoHandler), + (r"api/ai/chats/autocomplete_options?", AutocompleteOptionsHandler), (r"api/ai/providers?", ModelProviderHandler), (r"api/ai/providers/embeddings?", EmbeddingsModelProviderHandler), (r"api/ai/completion/inline/?", DefaultInlineCompletionHandler), @@ -283,9 +289,64 @@ def initialize_settings(self): # consumers a Future that resolves to the Dask client when awaited. self.settings["dask_client_future"] = loop.create_task(self._get_dask_client()) - eps = entry_points() + # Create empty context providers dict to be filled later. + # This is created early to use as kwargs for chat handlers. + self.settings["jai_context_providers"] = {} # initialize chat handlers + self._init_chat_handlers() + + # initialize context providers + self._init_context_provders() + + # show help message at server start + self._show_help_message() + + latency_ms = round((time.time() - start) * 1000) + self.log.info(f"Initialized Jupyter AI server extension in {latency_ms} ms.") + + def _show_help_message(self): + """ + Method that ensures a dynamically-generated help message is included in + the chat history shown to users. + """ + # call `send_help_message()` on any instance of `BaseChatHandler`. The + # `default` chat handler should always exist, so we reference that + # object when calling `send_help_message()`. 
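
FileContextProvider._process_file above flattens notebooks before they reach the prompt: the .ipynb JSON is parsed with nbformat and only the cell sources are kept. A small sketch of that step (not part of the patch), using a toy notebook built in memory:

```python
import nbformat

nb = nbformat.v4.new_notebook()
nb.cells = [
    nbformat.v4.new_markdown_cell("# Analysis"),
    nbformat.v4.new_code_cell("import pandas as pd\ndf = pd.DataFrame()"),
]
content = nbformat.writes(nb)  # what reading the .ipynb file from disk would yield

parsed = nbformat.reads(content, as_version=4)
print("\n\n".join(cell.source for cell in parsed.cells))
```
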
+ default_chat_handler: DefaultChatHandler = self.settings["jai_chat_handlers"][ + "default" + ] + default_chat_handler.send_help_message() + + async def _get_dask_client(self): + return DaskClient(processes=False, asynchronous=True) + + async def stop_extension(self): + """ + Public method called by Jupyter Server when the server is stopping. + This calls the cleanup code defined in `self._stop_exception()` inside + an exception handler, as the server halts if this method raises an + exception. + """ + try: + await self._stop_extension() + except Exception as e: + self.log.error("Jupyter AI raised an exception while stopping:") + self.log.exception(e) + + async def _stop_extension(self): + """ + Private method that defines the cleanup code to run when the server is + stopping. + """ + if "dask_client_future" in self.settings: + dask_client: DaskClient = await self.settings["dask_client_future"] + self.log.info("Closing Dask client.") + await dask_client.close() + self.log.debug("Closed Dask client.") + + def _init_chat_handlers(self): + eps = entry_points() chat_handler_eps = eps.select(group="jupyter_ai.chat_handlers") chat_handlers = {} chat_handler_kwargs = { @@ -301,6 +362,7 @@ def initialize_settings(self): "preferred_dir": self.serverapp.contents_manager.preferred_dir, "help_message_template": self.help_message_template, "chat_handlers": chat_handlers, + "context_providers": self.settings["jai_context_providers"], } default_chat_handler = DefaultChatHandler(**chat_handler_kwargs) clear_chat_handler = ClearChatHandler(**chat_handler_kwargs) @@ -376,48 +438,58 @@ def initialize_settings(self): # bind chat handlers to settings self.settings["jai_chat_handlers"] = chat_handlers - # show help message at server start - self._show_help_message() - - latency_ms = round((time.time() - start) * 1000) - self.log.info(f"Initialized Jupyter AI server extension in {latency_ms} ms.") - - def _show_help_message(self): - """ - Method that ensures a dynamically-generated help message is included in - the chat history shown to users. - """ - # call `send_help_message()` on any instance of `BaseChatHandler`. The - # `default` chat handler should always exist, so we reference that - # object when calling `send_help_message()`. 
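
Both `_init_chat_handlers` above and `_init_context_provders` below discover classes through entry points. A rough sketch of that discovery step; the `my_pkg` names in the comment are hypothetical, and `entry_points().select` is the importlib.metadata API on Python 3.10+:

```python
from importlib.metadata import entry_points

# A third-party package would expose its provider in pyproject.toml, e.g.
#   [project.entry-points."jupyter_ai.context_providers"]
#   my_provider = "my_pkg.providers:MyContextProvider"
for ep in entry_points().select(group="jupyter_ai.context_providers"):
    provider_cls = ep.load()  # must subclass BaseCommandContextProvider to be registered
    print(ep.name, provider_cls)
```
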
- default_chat_handler: DefaultChatHandler = self.settings["jai_chat_handlers"][ - "default" + def _init_context_provders(self): + eps = entry_points() + context_providers_eps = eps.select(group="jupyter_ai.context_providers") + context_providers = self.settings["jai_context_providers"] + context_providers_kwargs = { + "log": self.log, + "config_manager": self.settings["jai_config_manager"], + "model_parameters": self.settings["model_parameters"], + "chat_history": self.settings["chat_history"], + "llm_chat_memory": self.settings["llm_chat_memory"], + "root_dir": self.serverapp.root_dir, + "dask_client_future": self.settings["dask_client_future"], + "model_parameters": self.settings["model_parameters"], + "preferred_dir": self.serverapp.contents_manager.preferred_dir, + "chat_handlers": self.settings["jai_chat_handlers"], + "context_providers": self.settings["jai_context_providers"], + } + context_providers_clses = [ + FileContextProvider, ] - default_chat_handler.send_help_message() + for context_provider_ep in context_providers_eps: + try: + context_provider = context_provider_ep.load() + except Exception as err: + self.log.error( + f"Unable to load context provider class from entry point `{context_provider_ep.name}`: " + + f"Unexpected {err=}, {type(err)=}" + ) + continue + context_providers_clses.append(context_provider) - async def _get_dask_client(self): - return DaskClient(processes=False, asynchronous=True) + for context_provider in context_providers_clses: + if not issubclass(context_provider, BaseCommandContextProvider): + self.log.error( + f"Unable to register context provider `{context_provider.id}` because it does not inherit from `BaseCommandContextProvider`" + ) + continue - async def stop_extension(self): - """ - Public method called by Jupyter Server when the server is stopping. - This calls the cleanup code defined in `self._stop_exception()` inside - an exception handler, as the server halts if this method raises an - exception. - """ - try: - await self._stop_extension() - except Exception as e: - self.log.error("Jupyter AI raised an exception while stopping:") - self.log.exception(e) + if context_provider.id in context_providers: + self.log.error( + f"Unable to register context provider `{context_provider.id}` because it already exists" + ) + continue - async def _stop_extension(self): - """ - Private method that defines the cleanup code to run when the server is - stopping. 
- """ - if "dask_client_future" in self.settings: - dask_client: DaskClient = await self.settings["dask_client_future"] - self.log.info("Closing Dask client.") - await dask_client.close() - self.log.debug("Closed Dask client.") + if not re.match(r"^[a-zA-Z0-9_]+$", context_provider.id): + self.log.error( + f"Context provider `{context_provider.id}` is an invalid ID; " + + f"must contain only letters, numbers, and underscores" + ) + continue + + context_providers[context_provider.id] = context_provider( + **context_providers_kwargs + ) + self.log.info(f"Registered context provider `{context_provider.id}`.") diff --git a/packages/jupyter-ai/jupyter_ai/handlers.py b/packages/jupyter-ai/jupyter_ai/handlers.py index 7c4f16e63..614df557d 100644 --- a/packages/jupyter-ai/jupyter_ai/handlers.py +++ b/packages/jupyter-ai/jupyter_ai/handlers.py @@ -9,6 +9,7 @@ import tornado from jupyter_ai.chat_handlers import BaseChatHandler, SlashCommandRoutingType from jupyter_ai.config_manager import ConfigManager, KeyEmptyError, WriteConflictError +from jupyter_ai.context_providers import BaseCommandContextProvider, ContextCommand from jupyter_server.base.handlers import APIHandler as BaseAPIHandler from jupyter_server.base.handlers import JupyterHandler from langchain.pydantic_v1 import ValidationError @@ -29,6 +30,8 @@ ClosePendingMessage, ConnectionMessage, HumanChatMessage, + ListOptionsEntry, + ListOptionsResponse, ListProvidersEntry, ListProvidersResponse, ListSlashCommandsEntry, @@ -42,6 +45,7 @@ from jupyter_ai_magics.embedding_providers import BaseEmbeddingsProvider from jupyter_ai_magics.providers import BaseProvider + from .context_providers import BaseCommandContextProvider from .history import BoundedChatHistory @@ -571,3 +575,109 @@ def get(self): # sort slash commands by slash id and deliver the response response.slash_commands.sort(key=lambda sc: sc.slash_id) self.finish(response.json()) + + +class AutocompleteOptionsHandler(BaseAPIHandler): + """List context that are currently available to the user.""" + + @property + def config_manager(self) -> ConfigManager: # type:ignore[override] + return self.settings["jai_config_manager"] + + @property + def context_providers(self) -> Dict[str, "BaseCommandContextProvider"]: + return self.settings["jai_context_providers"] + + @property + def chat_handlers(self) -> Dict[str, "BaseChatHandler"]: + return self.settings["jai_chat_handlers"] + + @web.authenticated + def get(self): + response = ListOptionsResponse() + + # if no selected LLM, return an empty response + if not self.config_manager.lm_provider: + self.finish(response.json()) + return + + partial_cmd = self.get_query_argument("partialCommand", None) + if partial_cmd: + # if providing options for partial command argument + cmd = ContextCommand(cmd=partial_cmd) + context_provider = next( + ( + cp + for cp in self.context_providers.values() + if isinstance(cp, BaseCommandContextProvider) + and cp.command_id == cmd.id + ), + None, + ) + if ( + cmd.arg is not None + and context_provider + and isinstance(context_provider, BaseCommandContextProvider) + ): + response.options = context_provider.get_arg_options(cmd.arg) + else: + response.options = ( + self._get_slash_command_options() + self._get_context_provider_options() + ) + self.finish(response.json()) + + def _get_slash_command_options(self) -> List[ListOptionsEntry]: + options = [] + for id, chat_handler in self.chat_handlers.items(): + # filter out any chat handler that is not a slash command + if id == "default" or not isinstance( + 
chat_handler.routing_type, SlashCommandRoutingType + ): + continue + + routing_type = chat_handler.routing_type + + # filter out any chat handler that is unsupported by the current LLM + if ( + not routing_type.slash_id + or "/" + routing_type.slash_id + in self.config_manager.lm_provider.unsupported_slash_commands + ): + continue + + options.append( + self._make_autocomplete_option( + id="/" + routing_type.slash_id, + description=chat_handler.help, + only_start=True, + requires_arg=False, + ) + ) + options.sort(key=lambda opt: opt.id) + return options + + def _get_context_provider_options(self) -> List[ListOptionsEntry]: + options = [ + self._make_autocomplete_option( + id=context_provider.command_id, + description=context_provider.help, + only_start=context_provider.only_start, + requires_arg=context_provider.requires_arg, + ) + for context_provider in self.context_providers.values() + if isinstance(context_provider, BaseCommandContextProvider) + ] + options.sort(key=lambda opt: opt.id) + return options + + def _make_autocomplete_option( + self, + id: str, + description: str, + only_start: bool, + requires_arg: bool, + ): + label = id + (":" if requires_arg else " ") + return ListOptionsEntry( + id=id, description=description, label=label, only_start=only_start + ) diff --git a/packages/jupyter-ai/jupyter_ai/models.py b/packages/jupyter-ai/jupyter_ai/models.py index bda4d3421..e951ac6e8 100644 --- a/packages/jupyter-ai/jupyter_ai/models.py +++ b/packages/jupyter-ai/jupyter_ai/models.py @@ -263,3 +263,21 @@ class ListSlashCommandsEntry(BaseModel): class ListSlashCommandsResponse(BaseModel): slash_commands: List[ListSlashCommandsEntry] = [] + + +class ListOptionsEntry(BaseModel): + id: str + """ID of the autocomplete option. + Includes the command prefix. E.g. "/clear", "@file".""" + label: str + """Text that will be inserted into the prompt when the option is selected. + Includes a space at the end if the option is complete. 
+ Partial suggestions do not include the space and may trigger future suggestions.""" + description: str + """Text next to the option in the autocomplete list.""" + only_start: bool + """Whether to command can only be inserted at the start of the prompt.""" + + +class ListOptionsResponse(BaseModel): + options: List[ListOptionsEntry] = [] diff --git a/packages/jupyter-ai/jupyter_ai/tests/test_context_providers.py b/packages/jupyter-ai/jupyter_ai/tests/test_context_providers.py new file mode 100644 index 000000000..132dcf871 --- /dev/null +++ b/packages/jupyter-ai/jupyter_ai/tests/test_context_providers.py @@ -0,0 +1,80 @@ +import logging +from unittest import mock + +import pytest +from jupyter_ai.config_manager import ConfigManager +from jupyter_ai.context_providers import FileContextProvider, find_commands +from jupyter_ai.history import BoundedChatHistory +from jupyter_ai.models import ChatClient, HumanChatMessage, Persona + + +@pytest.fixture +def human_chat_message() -> HumanChatMessage: + chat_client = ChatClient( + id=0, username="test", initials="test", name="test", display_name="test" + ) + prompt = ( + "@file:test1.py @file @file:dir/test2.md test test\n" + "@file:/dir/test3.png\n" + "test@file:fail1.py\n" + "@file:dir\\ test\\ /test\\ 4.py\n" # spaces with escape + "@file:'test 5.py' @file:\"test6 .py\"\n" # quotes with spaces + "@file:'test7.py test\"\n" # do not allow for mixed quotes + "```\n@file:fail2.py\n```\n" # do not look within backticks + ) + return HumanChatMessage( + id="test", + time=0, + body=prompt, + prompt=prompt, + client=chat_client, + ) + + +@pytest.fixture +def file_context_provider() -> FileContextProvider: + config_manager = mock.create_autospec(ConfigManager) + config_manager.persona = Persona(name="test", avatar_route="test") + return FileContextProvider( + log=logging.getLogger(__name__), + config_manager=config_manager, + model_parameters={}, + chat_history=[], + llm_chat_memory=BoundedChatHistory(k=2), + root_dir="", + preferred_dir="", + dask_client_future=None, + chat_handlers={}, + context_providers={}, + ) + + +def test_find_instances(file_context_provider, human_chat_message): + expected = [ + "@file:test1.py", + "@file:dir/test2.md", + "@file:/dir/test3.png", + r"@file:dir\ test\ /test\ 4.py", + "@file:'test 5.py'", + '@file:"test6 .py"', + "@file:'test7.py", + ] + commands = [ + cmd.cmd + for cmd in find_commands(file_context_provider, human_chat_message.prompt) + ] + assert commands == expected + + +def test_replace_prompt(file_context_provider, human_chat_message): + expected = ( + "'test1.py' @file 'dir/test2.md' test test\n" + "'/dir/test3.png'\n" + "test@file:fail1.py\n" + "'dir test /test 4.py'\n" + "'test 5.py' 'test6 .py'\n" + "'test7.py' test\"\n" + "```\n@file:fail2.py\n```\n" # do not look within backticks + ) + prompt = file_context_provider.replace_prompt(human_chat_message.prompt) + assert prompt == expected diff --git a/packages/jupyter-ai/jupyter_ai/tests/test_handlers.py b/packages/jupyter-ai/jupyter_ai/tests/test_handlers.py index c1ca7b098..a94c3fbf8 100644 --- a/packages/jupyter-ai/jupyter_ai/tests/test_handlers.py +++ b/packages/jupyter-ai/jupyter_ai/tests/test_handlers.py @@ -76,6 +76,7 @@ def broadcast_message(message: Message) -> None: dask_client_future=None, help_message_template=DEFAULT_HELP_MESSAGE_TEMPLATE, chat_handlers={}, + context_providers={}, ) diff --git a/packages/jupyter-ai/src/components/chat-input.tsx b/packages/jupyter-ai/src/components/chat-input.tsx index ddc2d4209..8652307c2 100644 --- 
a/packages/jupyter-ai/src/components/chat-input.tsx +++ b/packages/jupyter-ai/src/components/chat-input.tsx @@ -38,12 +38,6 @@ type ChatInputProps = { personaName: string; }; -type SlashCommandOption = { - id: string; - label: string; - description: string; -}; - /** * List of icons per slash command, shown in the autocomplete popup. * @@ -51,28 +45,29 @@ type SlashCommandOption = { * unclear whether custom icons should be defined within a Lumino plugin (in the * frontend) or served from a static server route (in the backend). */ -const DEFAULT_SLASH_COMMAND_ICONS: Record = { - ask: , - clear: , - export: , - fix: , - generate: , - help: , - learn: , +const DEFAULT_COMMAND_ICONS: Record = { + '/ask': , + '/clear': , + '/export': , + '/fix': , + '/generate': , + '/help': , + '/learn': , + '@file': , unknown: }; /** * Renders an option shown in the slash command autocomplete. */ -function renderSlashCommandOption( +function renderAutocompleteOption( optionProps: React.HTMLAttributes, - option: SlashCommandOption + option: AiService.AutocompleteOption ): JSX.Element { const icon = - option.id in DEFAULT_SLASH_COMMAND_ICONS - ? DEFAULT_SLASH_COMMAND_ICONS[option.id] - : DEFAULT_SLASH_COMMAND_ICONS.unknown; + option.id in DEFAULT_COMMAND_ICONS + ? DEFAULT_COMMAND_ICONS[option.id] + : DEFAULT_COMMAND_ICONS.unknown; return (
  • @@ -99,8 +94,14 @@ function renderSlashCommandOption( export function ChatInput(props: ChatInputProps): JSX.Element { const [input, setInput] = useState(''); - const [slashCommandOptions, setSlashCommandOptions] = useState< - SlashCommandOption[] + const [autocompleteOptions, setAutocompleteOptions] = useState< + AiService.AutocompleteOption[] + >([]); + const [autocompleteCommandOptions, setAutocompleteCommandOptions] = useState< + AiService.AutocompleteOption[] + >([]); + const [autocompleteArgOptions, setAutocompleteArgOptions] = useState< + AiService.AutocompleteOption[] >([]); const [currSlashCommand, setCurrSlashCommand] = useState(null); const activeCell = useActiveCellContext(); @@ -110,24 +111,46 @@ export function ChatInput(props: ChatInputProps): JSX.Element { * initial mount to populate the slash command autocomplete. */ useEffect(() => { - async function getSlashCommands() { - const slashCommands = (await AiService.listSlashCommands()) - .slash_commands; - setSlashCommandOptions( - slashCommands.map(slashCommand => ({ - id: slashCommand.slash_id, - label: '/' + slashCommand.slash_id + ' ', - description: slashCommand.description - })) - ); + async function getAutocompleteCommandOptions() { + const response = await AiService.listAutocompleteOptions(); + setAutocompleteCommandOptions(response.options); } - getSlashCommands(); + getAutocompleteCommandOptions(); }, []); - // whether any option is highlighted in the slash command autocomplete + useEffect(() => { + async function getAutocompleteArgOptions() { + let options: AiService.AutocompleteOption[] = []; + const lastWord = getLastWord(input); + if (lastWord.includes(':')) { + const id = lastWord.split(':', 1)[0]; + // get option that matches the command + const option = autocompleteCommandOptions.find( + option => option.id === id + ); + if (option) { + const response = await AiService.listAutocompleteArgOptions(lastWord); + options = response.options; + } + } + setAutocompleteArgOptions(options); + } + getAutocompleteArgOptions(); + }, [autocompleteCommandOptions, input]); + + // Combine the fixed options with the argument options + useEffect(() => { + if (autocompleteArgOptions.length > 0) { + setAutocompleteOptions(autocompleteArgOptions); + } else { + setAutocompleteOptions(autocompleteCommandOptions); + } + }, [autocompleteCommandOptions, autocompleteArgOptions]); + + // whether any option is highlighted in the autocomplete const [highlighted, setHighlighted] = useState(false); - // controls whether the slash command autocomplete is open + // controls whether the autocomplete is open const [open, setOpen] = useState(false); // store reference to the input element to enable focusing it easily @@ -153,7 +176,7 @@ export function ChatInput(props: ChatInputProps): JSX.Element { * chat input. Close the autocomplete when the user clears the chat input. 
*/ useEffect(() => { - if (input === '/') { + if (filterAutocompleteOptions(autocompleteOptions, input).length > 0) { setOpen(true); return; } @@ -255,12 +278,40 @@ export function ChatInput(props: ChatInputProps): JSX.Element { currSlashCommand }; + function filterAutocompleteOptions( + options: AiService.AutocompleteOption[], + inputValue: string + ): AiService.AutocompleteOption[] { + const lastWord = getLastWord(inputValue); + if (lastWord === '') { + return []; + } + const isStart = lastWord === inputValue; + return options.filter( + option => + option.label.startsWith(lastWord) && (!option.only_start || isStart) + ); + } + return ( { + return filterAutocompleteOptions(options, inputValue); + }} + onChange={(_, option) => { + const value = typeof option === 'string' ? option : option.label; + let matchLength = 0; + for (let i = 1; i <= value.length; i++) { + if (input.endsWith(value.slice(0, i))) { + matchLength = i; + } + } + setInput(input + value.slice(matchLength)); + }} onInputChange={(_, newValue: string) => { setInput(newValue); }} @@ -273,12 +324,16 @@ export function ChatInput(props: ChatInputProps): JSX.Element { setHighlighted(!!highlightedOption); } } - onClose={() => setOpen(false)} + onClose={(_, reason) => { + if (reason !== 'selectOption' || input.endsWith(' ')) { + setOpen(false); + } + }} // set this to an empty string to prevent the last selected slash // command from being shown in blue value="" open={open} - options={slashCommandOptions} + options={autocompleteOptions} // hide default extra right padding in the text field disableClearable // ensure the autocomplete popup always renders on top @@ -292,7 +347,7 @@ export function ChatInput(props: ChatInputProps): JSX.Element { } } }} - renderOption={renderSlashCommandOption} + renderOption={renderAutocompleteOption} ListboxProps={{ sx: { '& .MuiAutocomplete-option': { @@ -331,3 +386,7 @@ export function ChatInput(props: ChatInputProps): JSX.Element { ); } + +function getLastWord(input: string): string { + return input.split(/(? { return requestAPI('chats/slash_commands'); } + + export type AutocompleteOption = { + id: string; + description: string; + label: string; + only_start: boolean; + }; + + export type ListAutocompleteOptionsResponse = { + options: AutocompleteOption[]; + }; + + export async function listAutocompleteOptions(): Promise { + return requestAPI( + 'chats/autocomplete_options' + ); + } + + export async function listAutocompleteArgOptions( + partialCommand: string + ): Promise { + return requestAPI( + 'chats/autocomplete_options?partialCommand=' + + encodeURIComponent(partialCommand) + ); + } }