diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py index 551db8bbc..3b6c672a1 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py @@ -67,11 +67,26 @@ The following is a friendly conversation between you and a human. """.strip() -CHAT_DEFAULT_TEMPLATE = """Current conversation: -{history} -Human: {input} +CHAT_DEFAULT_TEMPLATE = """ +{% if context %} +Context: +{{context}} + +{% endif %} +Current conversation: +{{history}} +Human: {{input}} AI:""" +HUMAN_MESSAGE_TEMPLATE = """ +{% if context %} + +{{context}} + + +{% endif %} +{{input}} +""" COMPLETION_SYSTEM_PROMPT = """ You are an application built to provide helpful code completion suggestions. @@ -400,17 +415,21 @@ def get_chat_prompt_template(self) -> PromptTemplate: CHAT_SYSTEM_PROMPT ).format(provider_name=name, local_model_id=self.model_id), MessagesPlaceholder(variable_name="history"), - HumanMessagePromptTemplate.from_template("{input}"), + HumanMessagePromptTemplate.from_template( + HUMAN_MESSAGE_TEMPLATE, + template_format="jinja2", + ), ] ) else: return PromptTemplate( - input_variables=["history", "input"], + input_variables=["history", "input", "context"], template=CHAT_SYSTEM_PROMPT.format( provider_name=name, local_model_id=self.model_id ) + "\n\n" + CHAT_DEFAULT_TEMPLATE, + template_format="jinja2", ) def get_completion_prompt_template(self) -> PromptTemplate: diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py index 347bfbf83..2e5786f57 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py @@ -35,6 +35,7 @@ if TYPE_CHECKING: from jupyter_ai.handlers import RootChatHandler from jupyter_ai.history import BoundedChatHistory + from jupyter_ai.context_providers import BaseContextProvider from langchain_core.chat_history import BaseChatMessageHistory @@ -121,6 +122,10 @@ class BaseChatHandler: chat handlers, which is necessary for some use-cases like printing the help message.""" + context_providers: Dict[str, Type["BaseContextProvider"]] + """Dictionary of context providers. Allows chat handlers to reference + context providers, which can be used to provide context to the LLM.""" + def __init__( self, log: Logger, @@ -134,6 +139,7 @@ def __init__( dask_client_future: Awaitable[DaskClient], help_message_template: str, chat_handlers: Dict[str, "BaseChatHandler"], + context_providers: Dict[str, Type["BaseContextProvider"]], ): self.log = log self.config_manager = config_manager @@ -154,6 +160,7 @@ def __init__( self.dask_client_future = dask_client_future self.help_message_template = help_message_template self.chat_handlers = chat_handlers + self.context_providers = context_providers self.llm: Optional[BaseProvider] = None self.llm_params: Optional[dict] = None diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py index 2bf88bafe..cc905446b 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py @@ -1,3 +1,4 @@ +import asyncio import time from typing import Dict, Type from uuid import uuid4 @@ -13,6 +14,7 @@ from langchain_core.runnables.history import RunnableWithMessageHistory from ..models import HumanChatMessage +from ..context_providers import ContextProviderException from .base import BaseChatHandler, SlashCommandRoutingType @@ -27,6 +29,7 @@ class DefaultChatHandler(BaseChatHandler): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + self.prompt_template = None def create_llm_chain( self, provider: Type[BaseProvider], provider_params: Dict[str, str] @@ -40,6 +43,7 @@ def create_llm_chain( prompt_template = llm.get_chat_prompt_template() self.llm = llm + self.prompt_template = prompt_template runnable = prompt_template | llm # type:ignore if not llm.manages_history: @@ -101,6 +105,16 @@ async def process_message(self, message: HumanChatMessage): self.get_llm_chain() received_first_chunk = False + inputs = {"input": message.body} + if "context" in self.prompt_template.input_variables: + # include context from context providers. + try: + context_prompt = await self.make_context_prompt(message) + except ContextProviderException as e: + self.reply(str(e), message) + return + inputs["context"] = context_prompt + # start with a pending message with self.pending("Generating response", message) as pending_message: # stream response in chunks. this works even if a provider does not @@ -108,7 +122,7 @@ async def process_message(self, message: HumanChatMessage): # when `_stream()` is not implemented on the LLM class. assert self.llm_chain async for chunk in self.llm_chain.astream( - {"input": message.body}, + inputs, config={"configurable": {"last_human_msg": message}}, ): if not received_first_chunk: @@ -128,3 +142,13 @@ async def process_message(self, message: HumanChatMessage): # complete stream after all chunks have been streamed self._send_stream_chunk(stream_id, "", complete=True) + + async def make_context_prompt(self, human_msg: HumanChatMessage) -> str: + return "\n\n".join( + await asyncio.gather( + *[ + provider.make_context_prompt(human_msg) + for provider in self.context_providers.values() + ] + ) + ) diff --git a/packages/jupyter-ai/jupyter_ai/context_providers/__init__.py b/packages/jupyter-ai/jupyter_ai/context_providers/__init__.py new file mode 100644 index 000000000..880faf28b --- /dev/null +++ b/packages/jupyter-ai/jupyter_ai/context_providers/__init__.py @@ -0,0 +1,3 @@ +from .base import BaseContextProvider, ContextProviderException +from .file import FileContextProvider +from .learned import LearnedContextProvider diff --git a/packages/jupyter-ai/jupyter_ai/context_providers/_examples.py b/packages/jupyter-ai/jupyter_ai/context_providers/_examples.py new file mode 100644 index 000000000..6772d7a3e --- /dev/null +++ b/packages/jupyter-ai/jupyter_ai/context_providers/_examples.py @@ -0,0 +1,133 @@ +# This file is for illustrative purposes +# It is to be deleted before merging +from jupyter_ai.models import HumanChatMessage +from langchain_community.retrievers import WikipediaRetriever +from langchain_community.retrievers import ArxivRetriever +from langchain_core.prompts import ChatPromptTemplate +from langchain_core.output_parsers import StrOutputParser + +from .base import BaseContextProvider + + +# Examples of the ease of implementing retriever based context providers +ARXIV_TEMPLATE = """ +Title: {title} +Publish Date: {publish_date} +''' +{content} +''' +""".strip() + + +class ArxivContextProvider(BaseContextProvider): + id = "arvix" + description = "Include papers from Arxiv" + remove_from_prompt = True + header = "Following are snippets of research papers:" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.retriever = ArxivRetriever() + + async def make_context_prompt(self, message: HumanChatMessage) -> str: + if not self._find_instances(message.prompt): + return "" + query = self._clean_prompt(message.body) + docs = await self.retriever.ainvoke(query) + context = "\n\n".join( + [ + ARXIV_TEMPLATE.format( + content=d.page_content, + title=d.metadata["Title"], + publish_date=d.metadata["Published"], + ) + for d in docs + ] + ) + return self.header + "\n" + context + + +# Another retriever based context provider with a rewrite step using LLM +WIKI_TEMPLATE = """ +Title: {title} +''' +{content} +''' +""".strip() + +REWRITE_TEMPLATE = """Provide a better search query for \ +web search engine to answer the given question, end \ +the queries with ’**’. Question: \ +{x} Answer:""" + + +class WikiContextProvider(BaseContextProvider): + id = "wiki" + description = "Include knowledge from Wikipedia" + remove_from_prompt = True + header = "Following are information from wikipedia:" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.retriever = WikipediaRetriever() + + async def make_context_prompt(self, message: HumanChatMessage) -> str: + if not self._find_instances(message.prompt): + return "" + prompt = self._clean_prompt(message.body) + search_query = await self._rewrite_prompt(prompt) + docs = await self.retriever.ainvoke(search_query) + context = "\n\n".join( + [ + WIKI_TEMPLATE.format( + content=d.page_content, + title=d.metadata["title"], + ) + for d in docs + ] + ) + return self.header + "\n" + context + + async def _rewrite_prompt(self, prompt: str) -> str: + return await self.get_llm_chain().ainvoke(prompt) + + def get_llm_chain(self): + # from https://github.com/langchain-ai/langchain/blob/master/cookbook/rewrite.ipynb + llm = self.get_llm() + rewrite_prompt = ChatPromptTemplate.from_template(REWRITE_TEMPLATE) + + def _parse(text): + return text.strip('"').strip("**") + + return rewrite_prompt | llm | StrOutputParser() | _parse + + +# Partial example of non-command context provider for errors. +# Assuming there is an option in UI to add cell errors to messages, +# default chat will automatically invoke this context provider to add +# solutions retrieved from a custom error database or a stackoverflow / google +# retriever pipeline to find solutions for errors. +class ErrorContextProvider(BaseContextProvider): + id = "error" + description = "Include custom error context" + remove_from_prompt = True + header = "Following are potential solutions for the error:" + is_command = False # will not show up in autocomplete + + async def make_context_prompt(self, message: HumanChatMessage) -> str: + # will run for every message with a cell error since it does not + # use _find_instances to check for the presence of the command in + # the message. + if not (message.selection and message.selection.type == "cell-with-error"): + return "" + docs = await self.solution_retriever.ainvoke(message.selection) + if not docs: + return "" + context = "\n\n".join([d.page_content for d in docs]) + return self.header + "\n" + context + + @property + def solution_retriever(self): + # retriever that takes an error and returns a solutions from a database + # of error messages. + raise NotImplementedError("Error retriever not implemented") diff --git a/packages/jupyter-ai/jupyter_ai/context_providers/base.py b/packages/jupyter-ai/jupyter_ai/context_providers/base.py new file mode 100644 index 000000000..83f970e55 --- /dev/null +++ b/packages/jupyter-ai/jupyter_ai/context_providers/base.py @@ -0,0 +1,157 @@ +import abc +import re +from typing import ClassVar, List + +import os +from typing import ( + TYPE_CHECKING, + Awaitable, + ClassVar, + Dict, + List, + Optional, +) + +from dask.distributed import Client as DaskClient +from jupyter_ai.config_manager import ConfigManager, Logger +from jupyter_ai.models import ( + ChatMessage, + HumanChatMessage, +) +from jupyter_ai.chat_handlers.base import get_preferred_dir +from jupyter_ai.models import ListOptionsEntry, HumanChatMessage + +if TYPE_CHECKING: + from jupyter_ai.history import BoundedChatHistory + from jupyter_ai.chat_handlers import BaseChatHandler + + +class BaseContextProvider(abc.ABC): + id: ClassVar[str] + description: ClassVar[str] + requires_arg: ClassVar[bool] = False + is_command: ClassVar[bool] = ( + True # whether the context provider can be invoked from chat + ) + remove_from_prompt: ClassVar[bool] = ( + False # whether the command should be removed from prompt + ) + + def __init__( + self, + *, + log: Logger, + config_manager: ConfigManager, + model_parameters: Dict[str, Dict], + chat_history: List[ChatMessage], + llm_chat_memory: "BoundedChatHistory", + root_dir: str, + preferred_dir: Optional[str], + dask_client_future: Awaitable[DaskClient], + chat_handlers: Dict[str, "BaseChatHandler"], + context_providers: Dict[str, "BaseContextProvider"], + ): + preferred_dir = preferred_dir or "" + self.log = log + self.config_manager = config_manager + self.model_parameters = model_parameters + self._chat_history = chat_history + self.llm_chat_memory = llm_chat_memory + self.root_dir = os.path.abspath(os.path.expanduser(root_dir)) + self.preferred_dir = get_preferred_dir(self.root_dir, preferred_dir) + self.dask_client_future = dask_client_future + self.chat_handlers = chat_handlers + self.context_providers = context_providers + + self.llm = None + + @property + def pattern(self) -> str: + return ( + rf"(? str: + """Returns a context prompt for all instances of the context provider + command. + """ + pass + + def get_arg_options(self, arg_prefix: str) -> List[ListOptionsEntry]: + """Returns a list of autocomplete options for arguments to the command + based on the prefix. + Only triggered if ':' is present after the command id (e.g. '@file:'). + """ + return [] + + def replace_prompt(self, prompt: str) -> str: + """Cleans up instances of the command from the prompt before + sending it to the LLM + """ + if self.remove_from_prompt: + return re.sub(self.pattern, "", prompt) + return prompt + + def _find_instances(self, text: str) -> List[str]: + # finds instances of the context provider command in the text + matches = re.finditer(self.pattern, text) + results = [] + for match in matches: + start, end = match.span() + before = text[:start] + after = text[end:] + # Check if the match is within backticks + if before.count("`") % 2 == 0 and after.count("`") % 2 == 0: + results.append(match.group()) + return results + + def _clean_prompt(self, text: str) -> str: + # useful for cleaning up the prompt before sending it to a retriever + for provider in self.context_providers.values(): + text = provider.replace_prompt(text) + return text + + @property + def base_dir(self) -> str: + # same as BaseChatHandler.output_dir + if self.preferred_dir and os.path.exists(self.preferred_dir): + return self.preferred_dir + else: + return self.root_dir + + def get_llm(self): + lm_provider = self.config_manager.lm_provider + lm_provider_params = self.config_manager.lm_provider_params + + curr_lm_id = ( + f'{self.llm.id}:{lm_provider_params["model_id"]}' if self.llm else None + ) + next_lm_id = ( + f'{lm_provider.id}:{lm_provider_params["model_id"]}' + if lm_provider + else None + ) + + if not lm_provider or not lm_provider_params: + return None + + if curr_lm_id != next_lm_id: + model_parameters = self.model_parameters.get( + f"{lm_provider.id}:{lm_provider_params['model_id']}", {} + ) + unified_parameters = { + "verbose": True, + **lm_provider_params, + **model_parameters, + } + llm = lm_provider(**unified_parameters) + self.llm = llm + return self.llm + + +class ContextProviderException(Exception): + # Used to generate a response when a context provider fails + pass diff --git a/packages/jupyter-ai/jupyter_ai/context_providers/file.py b/packages/jupyter-ai/jupyter_ai/context_providers/file.py new file mode 100644 index 000000000..2edf1750c --- /dev/null +++ b/packages/jupyter-ai/jupyter_ai/context_providers/file.py @@ -0,0 +1,116 @@ +import os +import glob +import re +from typing import List + +import nbformat +from jupyter_ai.models import ListOptionsEntry, HumanChatMessage +from jupyter_ai.document_loaders.directory import SUPPORTED_EXTS + +from .base import BaseContextProvider, ContextProviderException + +FILE_CONTEXT_TEMPLATE = """ +File: {filepath} +``` +{content} +``` +""".strip() + + +class FileContextProvider(BaseContextProvider): + id = "file" + description = "Include file contents" + requires_arg = True + header = "Following are contents of files referenced:" + + def replace_prompt(self, prompt: str) -> str: + # replaces instances of @file: with '' + def substitute(match): + filepath = match.group(0).partition(":")[2] + return f"'{filepath}'" + + return re.sub(self.pattern, substitute, prompt) + + def get_arg_options(self, arg_prefix: str) -> List[ListOptionsEntry]: + is_abs = not os.path.isabs(arg_prefix) + path_prefix = arg_prefix if is_abs else os.path.join(self.base_dir, arg_prefix) + return [ + self._make_option(path, is_abs, is_dir) + for path in glob.glob(path_prefix + "*") + if ( + (is_dir := os.path.isdir(path)) + or os.path.splitext(path)[1] in SUPPORTED_EXTS + ) + ] + + def _make_option(self, path: str, is_abs: bool, is_dir: bool) -> ListOptionsEntry: + if not is_abs: + path = os.path.relpath(path, self.base_dir) + if is_dir: + path += "/" + return ListOptionsEntry.from_arg( + type="@", + id=self.id, + description="Directory" if is_dir else "File", + arg=path, + is_complete=not is_dir, + ) + + async def make_context_prompt(self, message: HumanChatMessage) -> str: + instances = set(self._find_instances(message.prompt)) + if not instances: + return "" + context = "\n\n".join( + [context for i in instances if (context := self._make_instance_context(i))] + ) + if not context: + return "" + return self.header + "\n" + context + + def _make_instance_context(self, instance: str) -> str: + filepath = instance.partition(":")[2] + if not os.path.isabs(filepath): + filepath = os.path.join(self.base_dir, filepath) + + if not os.path.exists(filepath): + raise ContextProviderException( + f"File not found while trying to read '{filepath}' " + f"triggered by `{instance}`." + ) + if os.path.isdir(filepath): + raise ContextProviderException( + f"Cannot read directory '{filepath}' triggered by `{instance}`. " + f"Only files are supported." + ) + if os.path.splitext(filepath)[1] not in SUPPORTED_EXTS: + raise ContextProviderException( + f"Cannot read unsupported file type '{filepath}' triggered by `{instance}`. " + f"Supported file extensions are: {', '.join(SUPPORTED_EXTS)}." + ) + try: + with open(filepath) as f: + content = f.read() + except PermissionError: + raise ContextProviderException( + f"Permission denied while trying to read '{filepath}' " + f"triggered by `{instance}`." + ) + return FILE_CONTEXT_TEMPLATE.format( + filepath=filepath, + content=self._process_file(content, filepath), + ) + + def _process_file(self, content: str, filepath: str): + if filepath.endswith(".ipynb"): + nb = nbformat.reads(content, as_version=4) + return "\n\n".join([cell.source for cell in nb.cells]) + return content + + def get_filepaths(self, message: HumanChatMessage) -> List[str]: + filepaths = [] + for instance in self._find_instances(message.prompt): + filepath = instance.partition(":")[2] + if not os.path.isabs(filepath): + filepath = os.path.join(self.base_dir, filepath) + filepaths.append(filepath) + return filepaths diff --git a/packages/jupyter-ai/jupyter_ai/context_providers/learned.py b/packages/jupyter-ai/jupyter_ai/context_providers/learned.py new file mode 100644 index 000000000..72f81ef41 --- /dev/null +++ b/packages/jupyter-ai/jupyter_ai/context_providers/learned.py @@ -0,0 +1,50 @@ +from typing import List + +from jupyter_ai.models import HumanChatMessage +from jupyter_ai.chat_handlers.learn import Retriever + +from .base import BaseContextProvider +from .file import FileContextProvider + + +FILE_CHUNK_TEMPLATE = """ +Snippet from file: {filepath} +``` +{content} +``` +""".strip() + + +class LearnedContextProvider(BaseContextProvider): + id = "learned" + description = "Include learned context" + remove_from_prompt = True + header = "Following are snippets from potentially relevant files:" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.retriever = Retriever(learn_chat_handler=self.chat_handlers["/learn"]) + + async def make_context_prompt(self, message: HumanChatMessage) -> str: + if not self.retriever or not self._find_instances(message.prompt): + return "" + query = self._clean_prompt(message.body) + docs = await self.retriever.ainvoke(query) + excluded = self._get_repeated_files(message) + context = "\n\n".join( + [ + FILE_CHUNK_TEMPLATE.format( + filepath=d.metadata["path"], content=d.page_content + ) + for d in docs + if d.metadata["path"] not in excluded and d.page_content + ] + ) + return self.header + "\n" + context + + def _get_repeated_files(self, message: HumanChatMessage) -> List[str]: + # don't include files that are already provided by the file context provider + file_context_provider = self.context_providers.get("file") + if isinstance(file_context_provider, FileContextProvider): + return file_context_provider.get_filepaths(message) + return [] diff --git a/packages/jupyter-ai/jupyter_ai/extension.py b/packages/jupyter-ai/jupyter_ai/extension.py index 34b484546..c76ebc412 100644 --- a/packages/jupyter-ai/jupyter_ai/extension.py +++ b/packages/jupyter-ai/jupyter_ai/extension.py @@ -22,6 +22,7 @@ HelpChatHandler, LearnChatHandler, ) +from .context_providers import FileContextProvider, LearnedContextProvider from .completions.handlers import DefaultInlineCompletionHandler from .config_manager import ConfigManager from .handlers import ( @@ -32,6 +33,7 @@ ModelProviderHandler, RootChatHandler, SlashCommandsInfoHandler, + AutocompleteOptionsHandler, ) from .history import BoundedChatHistory @@ -58,6 +60,7 @@ class AiExtension(ExtensionApp): (r"api/ai/chats/?", RootChatHandler), (r"api/ai/chats/history?", ChatHistoryHandler), (r"api/ai/chats/slash_commands?", SlashCommandsInfoHandler), + (r"api/ai/chats/autocomplete_options?", AutocompleteOptionsHandler), (r"api/ai/providers?", ModelProviderHandler), (r"api/ai/providers/embeddings?", EmbeddingsModelProviderHandler), (r"api/ai/completion/inline/?", DefaultInlineCompletionHandler), @@ -285,6 +288,10 @@ def initialize_settings(self): eps = entry_points() + # Create empty context providers dict to be filled later. + # This is created early to use as kwargs for chat handlers. + self.settings["jai_context_providers"] = {} + # initialize chat handlers chat_handler_eps = eps.select(group="jupyter_ai.chat_handlers") chat_handlers = {} @@ -301,6 +308,7 @@ def initialize_settings(self): "preferred_dir": self.serverapp.contents_manager.preferred_dir, "help_message_template": self.help_message_template, "chat_handlers": chat_handlers, + "context_providers": self.settings["jai_context_providers"], } default_chat_handler = DefaultChatHandler(**chat_handler_kwargs) clear_chat_handler = ClearChatHandler(**chat_handler_kwargs) @@ -376,6 +384,58 @@ def initialize_settings(self): # bind chat handlers to settings self.settings["jai_chat_handlers"] = chat_handlers + # initialize context providers + context_providers = self.settings["jai_context_providers"] + context_providers_kwargs = { + "log": self.log, + "config_manager": self.settings["jai_config_manager"], + "model_parameters": self.settings["model_parameters"], + "chat_history": self.settings["chat_history"], + "llm_chat_memory": self.settings["llm_chat_memory"], + "root_dir": self.serverapp.root_dir, + "dask_client_future": self.settings["dask_client_future"], + "model_parameters": self.settings["model_parameters"], + "preferred_dir": self.serverapp.contents_manager.preferred_dir, + "chat_handlers": self.settings["jai_chat_handlers"], + "context_providers": self.settings["jai_context_providers"], + } + context_providers_clses = [ + FileContextProvider, + LearnedContextProvider, + ] + context_providers_eps = eps.select(group="jupyter_ai.context_providers") + for context_provider_ep in context_providers_eps: + try: + context_provider = context_provider_ep.load() + except Exception as err: + self.log.error( + f"Unable to load context provider class from entry point `{context_provider_ep.name}`: " + + f"Unexpected {err=}, {type(err)=}" + ) + continue + context_providers_clses.append(context_provider) + + for context_provider in context_providers_clses: + if context_provider.id in context_providers: + self.log.error( + f"Unable to register context provider `{context_provider.id}` because it already exists" + ) + continue + + if context_provider.is_command and not re.match( + r"^[a-zA-Z0-9_]+$", context_provider.id + ): + self.log.error( + f"Context provider `{context_provider.id}` is an invalid ID; " + + f"must contain only letters, numbers, and underscores" + ) + continue + + context_providers[context_provider.id] = context_provider( + **context_providers_kwargs + ) + self.log.info(f"Registered context provider `{context_provider.id}`.") + # show help message at server start self._show_help_message() diff --git a/packages/jupyter-ai/jupyter_ai/handlers.py b/packages/jupyter-ai/jupyter_ai/handlers.py index 7c4f16e63..6a1df7c03 100644 --- a/packages/jupyter-ai/jupyter_ai/handlers.py +++ b/packages/jupyter-ai/jupyter_ai/handlers.py @@ -33,6 +33,8 @@ ListProvidersResponse, ListSlashCommandsEntry, ListSlashCommandsResponse, + ListOptionsEntry, + ListOptionsResponse, Message, PendingMessage, UpdateConfigRequest, @@ -43,6 +45,7 @@ from jupyter_ai_magics.providers import BaseProvider from .history import BoundedChatHistory + from .context_providers import BaseContextProvider class ChatHistoryHandler(BaseAPIHandler): @@ -571,3 +574,101 @@ def get(self): # sort slash commands by slash id and deliver the response response.slash_commands.sort(key=lambda sc: sc.slash_id) self.finish(response.json()) + + +class AutocompleteOptionsHandler(BaseAPIHandler): + """List context that are currently available to the user.""" + + @property + def config_manager(self) -> ConfigManager: + return self.settings["jai_config_manager"] + + @property + def context_providers(self) -> Dict[str, "BaseContextProvider"]: + return self.settings["jai_context_providers"] + + @property + def chat_handlers(self) -> Dict[str, "BaseChatHandler"]: + return self.settings["jai_chat_handlers"] + + @web.authenticated + def get(self): + response = ListOptionsResponse() + + # if no selected LLM, return an empty response + if not self.config_manager.lm_provider: + self.finish(response.json()) + return + + response.options = ( + self._get_slash_command_options() + self._get_context_provider_options() + ) + self.finish(response.json()) + + @web.authenticated + def post(self): + try: + data = self.get_json_body() + context_provider = self.context_providers.get(data["id"]) + arg_prefix = data["arg_prefix"] + response = ListOptionsResponse() + + if not context_provider: + self.finish(response.json()) + return + + response.options = context_provider.get_arg_options(arg_prefix) + self.finish(response.json()) + except (ValidationError, WriteConflictError, KeyEmptyError) as e: + self.log.exception(e) + raise HTTPError(500, str(e)) from e + except ValueError as e: + self.log.exception(e) + raise HTTPError(500, str(e.cause) if hasattr(e, "cause") else str(e)) + except Exception as e: + self.log.exception(e) + raise HTTPError( + 500, "Unexpected error occurred while updating the context provider." + ) from e + + def _get_slash_command_options(self) -> List[ListOptionsEntry]: + options = [] + for id, chat_handler in self.chat_handlers.items(): + # filter out any chat handler that is not a slash command + if ( + id == "default" + or chat_handler.routing_type.routing_method != "slash_command" + ): + continue + + # hint the type of this attribute + routing_type: SlashCommandRoutingType = chat_handler.routing_type + + # filter out any chat handler that is unsupported by the current LLM + if ( + "/" + routing_type.slash_id + in self.config_manager.lm_provider.unsupported_slash_commands + ): + continue + + options.append( + ListOptionsEntry.from_command( + type="/", id=routing_type.slash_id, description=chat_handler.help + ) + ) + options.sort(key=lambda opt: opt.id) + return options + + def _get_context_provider_options(self) -> List[ListOptionsEntry]: + options = [ + ListOptionsEntry.from_command( + type="@", + id=context_provider.id, + description=context_provider.description, + requires_arg=context_provider.requires_arg, + ) + for context_provider in self.context_providers.values() + if context_provider.is_command + ] + options.sort(key=lambda opt: opt.id) + return options diff --git a/packages/jupyter-ai/jupyter_ai/models.py b/packages/jupyter-ai/jupyter_ai/models.py index bda4d3421..0a3476bb4 100644 --- a/packages/jupyter-ai/jupyter_ai/models.py +++ b/packages/jupyter-ai/jupyter_ai/models.py @@ -263,3 +263,37 @@ class ListSlashCommandsEntry(BaseModel): class ListSlashCommandsResponse(BaseModel): slash_commands: List[ListSlashCommandsEntry] = [] + + +class ListOptionsEntry(BaseModel): + type: Literal["/", "@"] + id: str + label: str + description: str + + @classmethod + def from_command( + cls, + type: Literal["/", "@"], + id: str, + description: str, + requires_arg: bool = False, + ): + label = type + id + (":" if requires_arg else " ") + return cls(type=type, id=id, description=description, label=label) + + @classmethod + def from_arg( + cls, + type: Literal["/", "@"], + id: str, + description: str, + arg: str, + is_complete: bool = True, + ): + label = type + id + ":" + arg + (" " if is_complete else "") + return cls(type=type, id=id, description=description, label=label) + + +class ListOptionsResponse(BaseModel): + options: List[ListOptionsEntry] = [] diff --git a/packages/jupyter-ai/jupyter_ai/tests/test_context_providers.py b/packages/jupyter-ai/jupyter_ai/tests/test_context_providers.py new file mode 100644 index 000000000..b7c97596e --- /dev/null +++ b/packages/jupyter-ai/jupyter_ai/tests/test_context_providers.py @@ -0,0 +1,63 @@ +import logging +from unittest import mock + +import pytest +from jupyter_ai.context_providers import FileContextProvider +from jupyter_ai.config_manager import ConfigManager +from jupyter_ai.history import BoundedChatHistory +from jupyter_ai.models import ( + ChatClient, + HumanChatMessage, + Persona, +) + + +@pytest.fixture +def human_chat_message() -> HumanChatMessage: + chat_client = ChatClient( + id=0, username="test", initials="test", name="test", display_name="test" + ) + prompt = ( + "@file:test1.py @file @file:dir/test2.md test test\n" + "@file:/dir/test3.png test@file:test4.py" + ) + return HumanChatMessage( + id="test", + time=0, + body=prompt, + prompt=prompt, + client=chat_client, + ) + + +@pytest.fixture +def file_context_provider() -> FileContextProvider: + config_manager = mock.create_autospec(ConfigManager) + config_manager.persona = Persona(name="test", avatar_route="test") + return FileContextProvider( + log=logging.getLogger(__name__), + config_manager=config_manager, + model_parameters={}, + chat_history=[], + llm_chat_memory=BoundedChatHistory(k=2), + root_dir="", + preferred_dir="", + dask_client_future=None, + chat_handlers={}, + context_providers={}, + ) + + +def test_find_instances(file_context_provider, human_chat_message): + expected = ["@file:test1.py", "@file:dir/test2.md", "@file:/dir/test3.png"] + instances = file_context_provider._find_instances(human_chat_message.prompt) + assert instances == expected + + +def test_replace_prompt(file_context_provider, human_chat_message): + expected = ( + "'test1.py' @file 'dir/test2.md' test test\n" + "'/dir/test3.png' test@file:test4.py" + ) + prompt = file_context_provider.replace_prompt(human_chat_message.prompt) + assert prompt == expected diff --git a/packages/jupyter-ai/jupyter_ai/tests/test_handlers.py b/packages/jupyter-ai/jupyter_ai/tests/test_handlers.py index c1ca7b098..a94c3fbf8 100644 --- a/packages/jupyter-ai/jupyter_ai/tests/test_handlers.py +++ b/packages/jupyter-ai/jupyter_ai/tests/test_handlers.py @@ -76,6 +76,7 @@ def broadcast_message(message: Message) -> None: dask_client_future=None, help_message_template=DEFAULT_HELP_MESSAGE_TEMPLATE, chat_handlers={}, + context_providers={}, ) diff --git a/packages/jupyter-ai/src/components/chat-input.tsx b/packages/jupyter-ai/src/components/chat-input.tsx index ddc2d4209..b341653e7 100644 --- a/packages/jupyter-ai/src/components/chat-input.tsx +++ b/packages/jupyter-ai/src/components/chat-input.tsx @@ -38,12 +38,6 @@ type ChatInputProps = { personaName: string; }; -type SlashCommandOption = { - id: string; - label: string; - description: string; -}; - /** * List of icons per slash command, shown in the autocomplete popup. * @@ -65,9 +59,9 @@ const DEFAULT_SLASH_COMMAND_ICONS: Record = { /** * Renders an option shown in the slash command autocomplete. */ -function renderSlashCommandOption( +function renderAutocompleteOption( optionProps: React.HTMLAttributes, - option: SlashCommandOption + option: AiService.AutocompleteOption ): JSX.Element { const icon = option.id in DEFAULT_SLASH_COMMAND_ICONS @@ -99,8 +93,14 @@ function renderSlashCommandOption( export function ChatInput(props: ChatInputProps): JSX.Element { const [input, setInput] = useState(''); - const [slashCommandOptions, setSlashCommandOptions] = useState< - SlashCommandOption[] + const [autocompleteOptions, setAutocompleteOptions] = useState< + AiService.AutocompleteOption[] + >([]); + const [autocompleteCommandOptions, setAutocompleteCommandOptions] = useState< + AiService.AutocompleteOption[] + >([]); + const [autocompleteArgOptions, setAutocompleteArgOptions] = useState< + AiService.AutocompleteOption[] >([]); const [currSlashCommand, setCurrSlashCommand] = useState(null); const activeCell = useActiveCellContext(); @@ -110,20 +110,45 @@ export function ChatInput(props: ChatInputProps): JSX.Element { * initial mount to populate the slash command autocomplete. */ useEffect(() => { - async function getSlashCommands() { - const slashCommands = (await AiService.listSlashCommands()) - .slash_commands; - setSlashCommandOptions( - slashCommands.map(slashCommand => ({ - id: slashCommand.slash_id, - label: '/' + slashCommand.slash_id + ' ', - description: slashCommand.description - })) - ); + async function getAutocompleteCommandOptions() { + const response = await AiService.listAutocompleteOptions(); + setAutocompleteCommandOptions(response.options); } - getSlashCommands(); + getAutocompleteCommandOptions(); }, []); + useEffect(() => { + async function getAutocompleteArgOptions() { + let options: AiService.AutocompleteOption[] = []; + const lastWord = input.split(/\s+/).pop() || ''; + if (lastWord.startsWith('@') && lastWord.includes(':')) { + const [id, argPrefix] = lastWord.split(':', 2); + // get option that matches the command + const option = autocompleteCommandOptions.find( + option => option.id === id.slice(1) && option.type === '@' + ); + if (option) { + const response = await AiService.listAutocompleteArgOptions({ + id: option.id, + arg_prefix: argPrefix + }); + options = response.options; + } + } + setAutocompleteArgOptions(options); + } + getAutocompleteArgOptions(); + }, [autocompleteCommandOptions, input]); + + // Combine the fixed options with the argument options + useEffect(() => { + if (autocompleteArgOptions.length > 0) { + setAutocompleteOptions(autocompleteArgOptions); + } else { + setAutocompleteOptions(autocompleteCommandOptions); + } + }, [autocompleteCommandOptions, autocompleteArgOptions]); + // whether any option is highlighted in the slash command autocomplete const [highlighted, setHighlighted] = useState(false); @@ -153,7 +178,7 @@ export function ChatInput(props: ChatInputProps): JSX.Element { * chat input. Close the autocomplete when the user clears the chat input. */ useEffect(() => { - if (input === '/') { + if (input === '/' || input.endsWith('@')) { setOpen(true); return; } @@ -255,12 +280,39 @@ export function ChatInput(props: ChatInputProps): JSX.Element { currSlashCommand }; + function filterAutocompleteOptions( + options: AiService.AutocompleteOption[], + inputValue: string + ): AiService.AutocompleteOption[] { + const lastWord = inputValue.split(/\s+/).pop() || ''; + if ( + (lastWord.startsWith('/') && lastWord === inputValue) || + lastWord.startsWith('@') + ) { + return options.filter(option => option.label.startsWith(lastWord)); + } + return []; + } + return ( { + return filterAutocompleteOptions(options, inputValue); + }} + onChange={(_, option) => { + const value = typeof option === 'string' ? option : option.label; + let matchLength = 0; + for (let i = 1; i <= value.length; i++) { + if (input.endsWith(value.slice(0, i))) { + matchLength = i; + } + } + setInput(input + value.slice(matchLength)); + }} onInputChange={(_, newValue: string) => { setInput(newValue); }} @@ -273,12 +325,16 @@ export function ChatInput(props: ChatInputProps): JSX.Element { setHighlighted(!!highlightedOption); } } - onClose={() => setOpen(false)} + onClose={(_, reason) => { + if (reason !== 'selectOption' || input.endsWith(' ')) { + setOpen(false); + } + }} // set this to an empty string to prevent the last selected slash // command from being shown in blue value="" open={open} - options={slashCommandOptions} + options={autocompleteOptions} // hide default extra right padding in the text field disableClearable // ensure the autocomplete popup always renders on top @@ -292,7 +348,7 @@ export function ChatInput(props: ChatInputProps): JSX.Element { } } }} - renderOption={renderSlashCommandOption} + renderOption={renderAutocompleteOption} ListboxProps={{ sx: { '& .MuiAutocomplete-option': { diff --git a/packages/jupyter-ai/src/handler.ts b/packages/jupyter-ai/src/handler.ts index b2f4efc16..64771e158 100644 --- a/packages/jupyter-ai/src/handler.ts +++ b/packages/jupyter-ai/src/handler.ts @@ -315,4 +315,38 @@ export namespace AiService { export async function listSlashCommands(): Promise { return requestAPI('chats/slash_commands'); } + + export type AutocompleteOption = { + type: '/' | '@'; + id: string; + description: string; + label: string; + }; + + export type ListAutocompleteOptionsResponse = { + options: AutocompleteOption[]; + }; + + export type AutocompleteArgOptionsRequest = { + id: string; + arg_prefix: string; + }; + + export async function listAutocompleteOptions(): Promise { + return requestAPI( + 'chats/autocomplete_options' + ); + } + + export async function listAutocompleteArgOptions( + request: AutocompleteArgOptionsRequest + ): Promise { + return requestAPI( + 'chats/autocomplete_options', + { + method: 'POST', + body: JSON.stringify(request) + } + ); + } }