diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
index 551db8bbc..3b6c672a1 100644
--- a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
+++ b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
@@ -67,11 +67,26 @@
The following is a friendly conversation between you and a human.
""".strip()
-CHAT_DEFAULT_TEMPLATE = """Current conversation:
-{history}
-Human: {input}
+CHAT_DEFAULT_TEMPLATE = """
+{% if context %}
+Context:
+{{context}}
+
+{% endif %}
+Current conversation:
+{{history}}
+Human: {{input}}
AI:"""
+HUMAN_MESSAGE_TEMPLATE = """
+{% if context %}
+
+{{context}}
+
+
+{% endif %}
+{{input}}
+"""
COMPLETION_SYSTEM_PROMPT = """
You are an application built to provide helpful code completion suggestions.
@@ -400,17 +415,21 @@ def get_chat_prompt_template(self) -> PromptTemplate:
CHAT_SYSTEM_PROMPT
).format(provider_name=name, local_model_id=self.model_id),
MessagesPlaceholder(variable_name="history"),
- HumanMessagePromptTemplate.from_template("{input}"),
+ HumanMessagePromptTemplate.from_template(
+ HUMAN_MESSAGE_TEMPLATE,
+ template_format="jinja2",
+ ),
]
)
else:
return PromptTemplate(
- input_variables=["history", "input"],
+ input_variables=["history", "input", "context"],
template=CHAT_SYSTEM_PROMPT.format(
provider_name=name, local_model_id=self.model_id
)
+                + "\n\n"
+                + CHAT_DEFAULT_TEMPLATE,
+                template_format="jinja2",
)
def get_completion_prompt_template(self) -> PromptTemplate:
diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py
index 347bfbf83..2e5786f57 100644
--- a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py
+++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py
@@ -35,6 +35,7 @@
if TYPE_CHECKING:
from jupyter_ai.handlers import RootChatHandler
from jupyter_ai.history import BoundedChatHistory
+ from jupyter_ai.context_providers import BaseContextProvider
from langchain_core.chat_history import BaseChatMessageHistory
@@ -121,6 +122,10 @@ class BaseChatHandler:
chat handlers, which is necessary for some use-cases like printing the help
message."""
+    context_providers: Dict[str, "BaseContextProvider"]
+    """Dictionary of context provider instances, keyed by ID. Allows chat
+    handlers to include additional context in the prompt sent to the LLM."""
+
def __init__(
self,
log: Logger,
@@ -134,6 +139,7 @@ def __init__(
dask_client_future: Awaitable[DaskClient],
help_message_template: str,
chat_handlers: Dict[str, "BaseChatHandler"],
+        context_providers: Dict[str, "BaseContextProvider"],
):
self.log = log
self.config_manager = config_manager
@@ -154,6 +160,7 @@ def __init__(
self.dask_client_future = dask_client_future
self.help_message_template = help_message_template
self.chat_handlers = chat_handlers
+ self.context_providers = context_providers
self.llm: Optional[BaseProvider] = None
self.llm_params: Optional[dict] = None
diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py
index 2bf88bafe..cc905446b 100644
--- a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py
+++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py
@@ -1,3 +1,4 @@
+import asyncio
import time
from typing import Dict, Type
from uuid import uuid4
@@ -13,6 +14,7 @@
from langchain_core.runnables.history import RunnableWithMessageHistory
from ..models import HumanChatMessage
+from ..context_providers import ContextProviderException
from .base import BaseChatHandler, SlashCommandRoutingType
@@ -27,6 +29,7 @@ class DefaultChatHandler(BaseChatHandler):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
+ self.prompt_template = None
def create_llm_chain(
self, provider: Type[BaseProvider], provider_params: Dict[str, str]
@@ -40,6 +43,7 @@ def create_llm_chain(
prompt_template = llm.get_chat_prompt_template()
self.llm = llm
+ self.prompt_template = prompt_template
runnable = prompt_template | llm # type:ignore
if not llm.manages_history:
@@ -101,6 +105,16 @@ async def process_message(self, message: HumanChatMessage):
self.get_llm_chain()
received_first_chunk = False
+ inputs = {"input": message.body}
+ if "context" in self.prompt_template.input_variables:
+ # include context from context providers.
+ try:
+ context_prompt = await self.make_context_prompt(message)
+ except ContextProviderException as e:
+ self.reply(str(e), message)
+ return
+ inputs["context"] = context_prompt
+
# start with a pending message
with self.pending("Generating response", message) as pending_message:
# stream response in chunks. this works even if a provider does not
@@ -108,7 +122,7 @@ async def process_message(self, message: HumanChatMessage):
# when `_stream()` is not implemented on the LLM class.
assert self.llm_chain
async for chunk in self.llm_chain.astream(
- {"input": message.body},
+ inputs,
config={"configurable": {"last_human_msg": message}},
):
if not received_first_chunk:
@@ -128,3 +142,13 @@ async def process_message(self, message: HumanChatMessage):
# complete stream after all chunks have been streamed
self._send_stream_chunk(stream_id, "", complete=True)
+
+ async def make_context_prompt(self, human_msg: HumanChatMessage) -> str:
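+        # invoke every registered context provider concurrently and join their
+        # (possibly empty) context strings into a single block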
+ return "\n\n".join(
+ await asyncio.gather(
+ *[
+ provider.make_context_prompt(human_msg)
+ for provider in self.context_providers.values()
+ ]
+ )
+ )
diff --git a/packages/jupyter-ai/jupyter_ai/context_providers/__init__.py b/packages/jupyter-ai/jupyter_ai/context_providers/__init__.py
new file mode 100644
index 000000000..880faf28b
--- /dev/null
+++ b/packages/jupyter-ai/jupyter_ai/context_providers/__init__.py
@@ -0,0 +1,3 @@
+from .base import BaseContextProvider, ContextProviderException
+from .file import FileContextProvider
+from .learned import LearnedContextProvider
diff --git a/packages/jupyter-ai/jupyter_ai/context_providers/_examples.py b/packages/jupyter-ai/jupyter_ai/context_providers/_examples.py
new file mode 100644
index 000000000..6772d7a3e
--- /dev/null
+++ b/packages/jupyter-ai/jupyter_ai/context_providers/_examples.py
@@ -0,0 +1,133 @@
+# This file is for illustrative purposes
+# It is to be deleted before merging
+from jupyter_ai.models import HumanChatMessage
+from langchain_community.retrievers import WikipediaRetriever
+from langchain_community.retrievers import ArxivRetriever
+from langchain_core.prompts import ChatPromptTemplate
+from langchain_core.output_parsers import StrOutputParser
+
+from .base import BaseContextProvider
+
+
+# Examples of the ease of implementing retriever based context providers
+ARXIV_TEMPLATE = """
+Title: {title}
+Publish Date: {publish_date}
+'''
+{content}
+'''
+""".strip()
+
+
+class ArxivContextProvider(BaseContextProvider):
+ id = "arvix"
+ description = "Include papers from Arxiv"
+ remove_from_prompt = True
+ header = "Following are snippets of research papers:"
+
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ self.retriever = ArxivRetriever()
+
+ async def make_context_prompt(self, message: HumanChatMessage) -> str:
+ if not self._find_instances(message.prompt):
+ return ""
+ query = self._clean_prompt(message.body)
+ docs = await self.retriever.ainvoke(query)
+ context = "\n\n".join(
+ [
+ ARXIV_TEMPLATE.format(
+ content=d.page_content,
+ title=d.metadata["Title"],
+ publish_date=d.metadata["Published"],
+ )
+ for d in docs
+ ]
+ )
+ return self.header + "\n" + context
+
+
+# Another retriever based context provider with a rewrite step using LLM
+WIKI_TEMPLATE = """
+Title: {title}
+'''
+{content}
+'''
+""".strip()
+
+REWRITE_TEMPLATE = """Provide a better search query for \
+web search engine to answer the given question, end \
+the queries with '**'. Question: \
+{x} Answer:"""
+
+
+class WikiContextProvider(BaseContextProvider):
+ id = "wiki"
+ description = "Include knowledge from Wikipedia"
+ remove_from_prompt = True
+ header = "Following are information from wikipedia:"
+
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ self.retriever = WikipediaRetriever()
+
+ async def make_context_prompt(self, message: HumanChatMessage) -> str:
+ if not self._find_instances(message.prompt):
+ return ""
+ prompt = self._clean_prompt(message.body)
+ search_query = await self._rewrite_prompt(prompt)
+ docs = await self.retriever.ainvoke(search_query)
+ context = "\n\n".join(
+ [
+ WIKI_TEMPLATE.format(
+ content=d.page_content,
+ title=d.metadata["title"],
+ )
+ for d in docs
+ ]
+ )
+ return self.header + "\n" + context
+
+ async def _rewrite_prompt(self, prompt: str) -> str:
+ return await self.get_llm_chain().ainvoke(prompt)
+
+ def get_llm_chain(self):
+ # from https://github.com/langchain-ai/langchain/blob/master/cookbook/rewrite.ipynb
+ llm = self.get_llm()
+ rewrite_prompt = ChatPromptTemplate.from_template(REWRITE_TEMPLATE)
+
+ def _parse(text):
+ return text.strip('"').strip("**")
+
+ return rewrite_prompt | llm | StrOutputParser() | _parse
+
+
+# Partial example of non-command context provider for errors.
+# Assuming there is an option in UI to add cell errors to messages,
+# default chat will automatically invoke this context provider to add
+# solutions retrieved from a custom error database or a stackoverflow / google
+# retriever pipeline to find solutions for errors.
+class ErrorContextProvider(BaseContextProvider):
+ id = "error"
+ description = "Include custom error context"
+ remove_from_prompt = True
+ header = "Following are potential solutions for the error:"
+ is_command = False # will not show up in autocomplete
+
+ async def make_context_prompt(self, message: HumanChatMessage) -> str:
+ # will run for every message with a cell error since it does not
+ # use _find_instances to check for the presence of the command in
+ # the message.
+ if not (message.selection and message.selection.type == "cell-with-error"):
+ return ""
+ docs = await self.solution_retriever.ainvoke(message.selection)
+ if not docs:
+ return ""
+ context = "\n\n".join([d.page_content for d in docs])
+ return self.header + "\n" + context
+
+ @property
+ def solution_retriever(self):
+        # retriever that takes an error and returns solutions from a database
+ # of error messages.
+ raise NotImplementedError("Error retriever not implemented")
diff --git a/packages/jupyter-ai/jupyter_ai/context_providers/base.py b/packages/jupyter-ai/jupyter_ai/context_providers/base.py
new file mode 100644
index 000000000..83f970e55
--- /dev/null
+++ b/packages/jupyter-ai/jupyter_ai/context_providers/base.py
@@ -0,0 +1,157 @@
+import abc
+import os
+import re
+from typing import (
+ TYPE_CHECKING,
+ Awaitable,
+ ClassVar,
+ Dict,
+ List,
+ Optional,
+)
+
+from dask.distributed import Client as DaskClient
+from jupyter_ai.config_manager import ConfigManager, Logger
+from jupyter_ai.chat_handlers.base import get_preferred_dir
+from jupyter_ai.models import (
+    ChatMessage,
+    HumanChatMessage,
+    ListOptionsEntry,
+)
+
+if TYPE_CHECKING:
+ from jupyter_ai.history import BoundedChatHistory
+ from jupyter_ai.chat_handlers import BaseChatHandler
+
+
+class BaseContextProvider(abc.ABC):
+ id: ClassVar[str]
+ description: ClassVar[str]
+ requires_arg: ClassVar[bool] = False
+    # whether the context provider can be invoked directly from the chat input
+    is_command: ClassVar[bool] = True
+    # whether the command should be removed from the prompt before it is sent to the LLM
+    remove_from_prompt: ClassVar[bool] = False
+
+ def __init__(
+ self,
+ *,
+ log: Logger,
+ config_manager: ConfigManager,
+ model_parameters: Dict[str, Dict],
+ chat_history: List[ChatMessage],
+ llm_chat_memory: "BoundedChatHistory",
+ root_dir: str,
+ preferred_dir: Optional[str],
+ dask_client_future: Awaitable[DaskClient],
+ chat_handlers: Dict[str, "BaseChatHandler"],
+ context_providers: Dict[str, "BaseContextProvider"],
+ ):
+ preferred_dir = preferred_dir or ""
+ self.log = log
+ self.config_manager = config_manager
+ self.model_parameters = model_parameters
+ self._chat_history = chat_history
+ self.llm_chat_memory = llm_chat_memory
+ self.root_dir = os.path.abspath(os.path.expanduser(root_dir))
+ self.preferred_dir = get_preferred_dir(self.root_dir, preferred_dir)
+ self.dask_client_future = dask_client_future
+ self.chat_handlers = chat_handlers
+ self.context_providers = context_providers
+
+ self.llm = None
+
+ @property
+ def pattern(self) -> str:
+ return (
+            rf"(?<![^\s.])@{self.id}:[^\s]+"
+            if self.requires_arg
+            else rf"(?<![^\s.])@{self.id}(?![^\s.])"
+        )
+
+    @abc.abstractmethod
+    async def make_context_prompt(self, message: HumanChatMessage) -> str:
+ """Returns a context prompt for all instances of the context provider
+ command.
+ """
+ pass
+
+ def get_arg_options(self, arg_prefix: str) -> List[ListOptionsEntry]:
+ """Returns a list of autocomplete options for arguments to the command
+ based on the prefix.
+ Only triggered if ':' is present after the command id (e.g. '@file:').
+ """
+ return []
+
+ def replace_prompt(self, prompt: str) -> str:
+ """Cleans up instances of the command from the prompt before
+ sending it to the LLM
+ """
+ if self.remove_from_prompt:
+ return re.sub(self.pattern, "", prompt)
+ return prompt
+
+ def _find_instances(self, text: str) -> List[str]:
+ # finds instances of the context provider command in the text
+ matches = re.finditer(self.pattern, text)
+ results = []
+ for match in matches:
+ start, end = match.span()
+ before = text[:start]
+ after = text[end:]
+ # Check if the match is within backticks
+ if before.count("`") % 2 == 0 and after.count("`") % 2 == 0:
+ results.append(match.group())
+ return results
+
+ def _clean_prompt(self, text: str) -> str:
+ # useful for cleaning up the prompt before sending it to a retriever
+ for provider in self.context_providers.values():
+ text = provider.replace_prompt(text)
+ return text
+
+ @property
+ def base_dir(self) -> str:
+ # same as BaseChatHandler.output_dir
+ if self.preferred_dir and os.path.exists(self.preferred_dir):
+ return self.preferred_dir
+ else:
+ return self.root_dir
+
+ def get_llm(self):
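+        # build (or reuse) an LLM instance from the currently configured language
+        # model provider, mirroring the provider handling in the chat handlers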
+ lm_provider = self.config_manager.lm_provider
+ lm_provider_params = self.config_manager.lm_provider_params
+
+ curr_lm_id = (
+ f'{self.llm.id}:{lm_provider_params["model_id"]}' if self.llm else None
+ )
+ next_lm_id = (
+ f'{lm_provider.id}:{lm_provider_params["model_id"]}'
+ if lm_provider
+ else None
+ )
+
+ if not lm_provider or not lm_provider_params:
+ return None
+
+ if curr_lm_id != next_lm_id:
+ model_parameters = self.model_parameters.get(
+ f"{lm_provider.id}:{lm_provider_params['model_id']}", {}
+ )
+ unified_parameters = {
+ "verbose": True,
+ **lm_provider_params,
+ **model_parameters,
+ }
+ llm = lm_provider(**unified_parameters)
+ self.llm = llm
+ return self.llm
+
+
+class ContextProviderException(Exception):
+ # Used to generate a response when a context provider fails
+ pass
diff --git a/packages/jupyter-ai/jupyter_ai/context_providers/file.py b/packages/jupyter-ai/jupyter_ai/context_providers/file.py
new file mode 100644
index 000000000..2edf1750c
--- /dev/null
+++ b/packages/jupyter-ai/jupyter_ai/context_providers/file.py
@@ -0,0 +1,116 @@
+import os
+import glob
+import re
+from typing import List
+
+import nbformat
+from jupyter_ai.models import ListOptionsEntry, HumanChatMessage
+from jupyter_ai.document_loaders.directory import SUPPORTED_EXTS
+
+from .base import BaseContextProvider, ContextProviderException
+
+FILE_CONTEXT_TEMPLATE = """
+File: {filepath}
+```
+{content}
+```
+""".strip()
+
+
+class FileContextProvider(BaseContextProvider):
+ id = "file"
+ description = "Include file contents"
+ requires_arg = True
+ header = "Following are contents of files referenced:"
+
+ def replace_prompt(self, prompt: str) -> str:
+ # replaces instances of @file: with ''
+ def substitute(match):
+ filepath = match.group(0).partition(":")[2]
+ return f"'{filepath}'"
+
+ return re.sub(self.pattern, substitute, prompt)
+
+ def get_arg_options(self, arg_prefix: str) -> List[ListOptionsEntry]:
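+        # complete both directories and supported file types; a relative prefix
+        # is resolved against self.base_dir before globbing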
+        is_abs = os.path.isabs(arg_prefix)
+ path_prefix = arg_prefix if is_abs else os.path.join(self.base_dir, arg_prefix)
+ return [
+ self._make_option(path, is_abs, is_dir)
+ for path in glob.glob(path_prefix + "*")
+ if (
+ (is_dir := os.path.isdir(path))
+ or os.path.splitext(path)[1] in SUPPORTED_EXTS
+ )
+ ]
+
+ def _make_option(self, path: str, is_abs: bool, is_dir: bool) -> ListOptionsEntry:
+ if not is_abs:
+ path = os.path.relpath(path, self.base_dir)
+ if is_dir:
+ path += "/"
+ return ListOptionsEntry.from_arg(
+ type="@",
+ id=self.id,
+ description="Directory" if is_dir else "File",
+ arg=path,
+ is_complete=not is_dir,
+ )
+
+ async def make_context_prompt(self, message: HumanChatMessage) -> str:
+ instances = set(self._find_instances(message.prompt))
+ if not instances:
+ return ""
+ context = "\n\n".join(
+ [context for i in instances if (context := self._make_instance_context(i))]
+ )
+ if not context:
+ return ""
+ return self.header + "\n" + context
+
+ def _make_instance_context(self, instance: str) -> str:
+ filepath = instance.partition(":")[2]
+ if not os.path.isabs(filepath):
+ filepath = os.path.join(self.base_dir, filepath)
+
+ if not os.path.exists(filepath):
+ raise ContextProviderException(
+ f"File not found while trying to read '{filepath}' "
+ f"triggered by `{instance}`."
+ )
+ if os.path.isdir(filepath):
+ raise ContextProviderException(
+ f"Cannot read directory '{filepath}' triggered by `{instance}`. "
+ f"Only files are supported."
+ )
+ if os.path.splitext(filepath)[1] not in SUPPORTED_EXTS:
+ raise ContextProviderException(
+ f"Cannot read unsupported file type '{filepath}' triggered by `{instance}`. "
+ f"Supported file extensions are: {', '.join(SUPPORTED_EXTS)}."
+ )
+ try:
+ with open(filepath) as f:
+ content = f.read()
+ except PermissionError:
+ raise ContextProviderException(
+ f"Permission denied while trying to read '{filepath}' "
+ f"triggered by `{instance}`."
+ )
+ return FILE_CONTEXT_TEMPLATE.format(
+ filepath=filepath,
+ content=self._process_file(content, filepath),
+ )
+
+ def _process_file(self, content: str, filepath: str):
+ if filepath.endswith(".ipynb"):
+ nb = nbformat.reads(content, as_version=4)
+ return "\n\n".join([cell.source for cell in nb.cells])
+ return content
+
+ def get_filepaths(self, message: HumanChatMessage) -> List[str]:
+ filepaths = []
+ for instance in self._find_instances(message.prompt):
+ filepath = instance.partition(":")[2]
+ if not os.path.isabs(filepath):
+ filepath = os.path.join(self.base_dir, filepath)
+ filepaths.append(filepath)
+ return filepaths
diff --git a/packages/jupyter-ai/jupyter_ai/context_providers/learned.py b/packages/jupyter-ai/jupyter_ai/context_providers/learned.py
new file mode 100644
index 000000000..72f81ef41
--- /dev/null
+++ b/packages/jupyter-ai/jupyter_ai/context_providers/learned.py
@@ -0,0 +1,50 @@
+from typing import List
+
+from jupyter_ai.models import HumanChatMessage
+from jupyter_ai.chat_handlers.learn import Retriever
+
+from .base import BaseContextProvider
+from .file import FileContextProvider
+
+
+FILE_CHUNK_TEMPLATE = """
+Snippet from file: {filepath}
+```
+{content}
+```
+""".strip()
+
+
+class LearnedContextProvider(BaseContextProvider):
+ id = "learned"
+ description = "Include learned context"
+ remove_from_prompt = True
+ header = "Following are snippets from potentially relevant files:"
+
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ self.retriever = Retriever(learn_chat_handler=self.chat_handlers["/learn"])
+
+ async def make_context_prompt(self, message: HumanChatMessage) -> str:
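+        # retrieve relevant chunks from the /learn index, skipping files already
+        # included verbatim by the @file context provider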
+ if not self.retriever or not self._find_instances(message.prompt):
+ return ""
+ query = self._clean_prompt(message.body)
+ docs = await self.retriever.ainvoke(query)
+ excluded = self._get_repeated_files(message)
+ context = "\n\n".join(
+ [
+ FILE_CHUNK_TEMPLATE.format(
+ filepath=d.metadata["path"], content=d.page_content
+ )
+ for d in docs
+ if d.metadata["path"] not in excluded and d.page_content
+ ]
+ )
+ return self.header + "\n" + context
+
+ def _get_repeated_files(self, message: HumanChatMessage) -> List[str]:
+ # don't include files that are already provided by the file context provider
+ file_context_provider = self.context_providers.get("file")
+ if isinstance(file_context_provider, FileContextProvider):
+ return file_context_provider.get_filepaths(message)
+ return []
diff --git a/packages/jupyter-ai/jupyter_ai/extension.py b/packages/jupyter-ai/jupyter_ai/extension.py
index 34b484546..c76ebc412 100644
--- a/packages/jupyter-ai/jupyter_ai/extension.py
+++ b/packages/jupyter-ai/jupyter_ai/extension.py
@@ -22,6 +22,7 @@
HelpChatHandler,
LearnChatHandler,
)
+from .context_providers import FileContextProvider, LearnedContextProvider
from .completions.handlers import DefaultInlineCompletionHandler
from .config_manager import ConfigManager
from .handlers import (
@@ -32,6 +33,7 @@
ModelProviderHandler,
RootChatHandler,
SlashCommandsInfoHandler,
+ AutocompleteOptionsHandler,
)
from .history import BoundedChatHistory
@@ -58,6 +60,7 @@ class AiExtension(ExtensionApp):
(r"api/ai/chats/?", RootChatHandler),
(r"api/ai/chats/history?", ChatHistoryHandler),
(r"api/ai/chats/slash_commands?", SlashCommandsInfoHandler),
+ (r"api/ai/chats/autocomplete_options?", AutocompleteOptionsHandler),
(r"api/ai/providers?", ModelProviderHandler),
(r"api/ai/providers/embeddings?", EmbeddingsModelProviderHandler),
(r"api/ai/completion/inline/?", DefaultInlineCompletionHandler),
@@ -285,6 +288,10 @@ def initialize_settings(self):
eps = entry_points()
+ # Create empty context providers dict to be filled later.
+ # This is created early to use as kwargs for chat handlers.
+ self.settings["jai_context_providers"] = {}
+
# initialize chat handlers
chat_handler_eps = eps.select(group="jupyter_ai.chat_handlers")
chat_handlers = {}
@@ -301,6 +308,7 @@ def initialize_settings(self):
"preferred_dir": self.serverapp.contents_manager.preferred_dir,
"help_message_template": self.help_message_template,
"chat_handlers": chat_handlers,
+ "context_providers": self.settings["jai_context_providers"],
}
default_chat_handler = DefaultChatHandler(**chat_handler_kwargs)
clear_chat_handler = ClearChatHandler(**chat_handler_kwargs)
@@ -376,6 +384,58 @@ def initialize_settings(self):
# bind chat handlers to settings
self.settings["jai_chat_handlers"] = chat_handlers
+ # initialize context providers
+ context_providers = self.settings["jai_context_providers"]
+ context_providers_kwargs = {
+ "log": self.log,
+ "config_manager": self.settings["jai_config_manager"],
+ "model_parameters": self.settings["model_parameters"],
+ "chat_history": self.settings["chat_history"],
+ "llm_chat_memory": self.settings["llm_chat_memory"],
+ "root_dir": self.serverapp.root_dir,
+ "dask_client_future": self.settings["dask_client_future"],
+ "preferred_dir": self.serverapp.contents_manager.preferred_dir,
+ "chat_handlers": self.settings["jai_chat_handlers"],
+ "context_providers": self.settings["jai_context_providers"],
+ }
+ context_providers_clses = [
+ FileContextProvider,
+ LearnedContextProvider,
+ ]
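+    # Third-party packages can contribute additional providers through the
+    # "jupyter_ai.context_providers" entry point group, e.g. in pyproject.toml
+    # (package and class names below are hypothetical):
+    #
+    #   [project.entry-points."jupyter_ai.context_providers"]
+    #   wiki = "my_package.providers:WikiContextProvider"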
+ context_providers_eps = eps.select(group="jupyter_ai.context_providers")
+ for context_provider_ep in context_providers_eps:
+ try:
+ context_provider = context_provider_ep.load()
+ except Exception as err:
+ self.log.error(
+ f"Unable to load context provider class from entry point `{context_provider_ep.name}`: "
+ + f"Unexpected {err=}, {type(err)=}"
+ )
+ continue
+ context_providers_clses.append(context_provider)
+
+ for context_provider in context_providers_clses:
+ if context_provider.id in context_providers:
+ self.log.error(
+ f"Unable to register context provider `{context_provider.id}` because it already exists"
+ )
+ continue
+
+ if context_provider.is_command and not re.match(
+ r"^[a-zA-Z0-9_]+$", context_provider.id
+ ):
+ self.log.error(
+                    f"Unable to register context provider `{context_provider.id}`; "
+                    + "IDs must contain only letters, numbers, and underscores"
+ )
+ continue
+
+ context_providers[context_provider.id] = context_provider(
+ **context_providers_kwargs
+ )
+ self.log.info(f"Registered context provider `{context_provider.id}`.")
+
# show help message at server start
self._show_help_message()
diff --git a/packages/jupyter-ai/jupyter_ai/handlers.py b/packages/jupyter-ai/jupyter_ai/handlers.py
index 7c4f16e63..6a1df7c03 100644
--- a/packages/jupyter-ai/jupyter_ai/handlers.py
+++ b/packages/jupyter-ai/jupyter_ai/handlers.py
@@ -33,6 +33,8 @@
ListProvidersResponse,
ListSlashCommandsEntry,
ListSlashCommandsResponse,
+ ListOptionsEntry,
+ ListOptionsResponse,
Message,
PendingMessage,
UpdateConfigRequest,
@@ -43,6 +45,7 @@
from jupyter_ai_magics.providers import BaseProvider
from .history import BoundedChatHistory
+ from .context_providers import BaseContextProvider
class ChatHistoryHandler(BaseAPIHandler):
@@ -571,3 +574,101 @@ def get(self):
# sort slash commands by slash id and deliver the response
response.slash_commands.sort(key=lambda sc: sc.slash_id)
self.finish(response.json())
+
+
+class AutocompleteOptionsHandler(BaseAPIHandler):
+ """List context that are currently available to the user."""
+
+ @property
+ def config_manager(self) -> ConfigManager:
+ return self.settings["jai_config_manager"]
+
+ @property
+ def context_providers(self) -> Dict[str, "BaseContextProvider"]:
+ return self.settings["jai_context_providers"]
+
+ @property
+ def chat_handlers(self) -> Dict[str, "BaseChatHandler"]:
+ return self.settings["jai_chat_handlers"]
+
+ @web.authenticated
+ def get(self):
+ response = ListOptionsResponse()
+
+ # if no selected LLM, return an empty response
+ if not self.config_manager.lm_provider:
+ self.finish(response.json())
+ return
+
+ response.options = (
+ self._get_slash_command_options() + self._get_context_provider_options()
+ )
+ self.finish(response.json())
+
+ @web.authenticated
+ def post(self):
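+        """Returns autocomplete options for a context provider argument,
+        e.g. file path completions for an `@file:` prefix."""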
+ try:
+ data = self.get_json_body()
+ context_provider = self.context_providers.get(data["id"])
+ arg_prefix = data["arg_prefix"]
+ response = ListOptionsResponse()
+
+ if not context_provider:
+ self.finish(response.json())
+ return
+
+ response.options = context_provider.get_arg_options(arg_prefix)
+ self.finish(response.json())
+ except (ValidationError, WriteConflictError, KeyEmptyError) as e:
+ self.log.exception(e)
+ raise HTTPError(500, str(e)) from e
+ except ValueError as e:
+ self.log.exception(e)
+ raise HTTPError(500, str(e.cause) if hasattr(e, "cause") else str(e))
+ except Exception as e:
+ self.log.exception(e)
+ raise HTTPError(
+                500, "Unexpected error occurred while fetching context provider options."
+ ) from e
+
+ def _get_slash_command_options(self) -> List[ListOptionsEntry]:
+ options = []
+ for id, chat_handler in self.chat_handlers.items():
+ # filter out any chat handler that is not a slash command
+ if (
+ id == "default"
+ or chat_handler.routing_type.routing_method != "slash_command"
+ ):
+ continue
+
+ # hint the type of this attribute
+ routing_type: SlashCommandRoutingType = chat_handler.routing_type
+
+ # filter out any chat handler that is unsupported by the current LLM
+ if (
+ "/" + routing_type.slash_id
+ in self.config_manager.lm_provider.unsupported_slash_commands
+ ):
+ continue
+
+ options.append(
+ ListOptionsEntry.from_command(
+ type="/", id=routing_type.slash_id, description=chat_handler.help
+ )
+ )
+ options.sort(key=lambda opt: opt.id)
+ return options
+
+ def _get_context_provider_options(self) -> List[ListOptionsEntry]:
+ options = [
+ ListOptionsEntry.from_command(
+ type="@",
+ id=context_provider.id,
+ description=context_provider.description,
+ requires_arg=context_provider.requires_arg,
+ )
+ for context_provider in self.context_providers.values()
+ if context_provider.is_command
+ ]
+ options.sort(key=lambda opt: opt.id)
+ return options
diff --git a/packages/jupyter-ai/jupyter_ai/models.py b/packages/jupyter-ai/jupyter_ai/models.py
index bda4d3421..0a3476bb4 100644
--- a/packages/jupyter-ai/jupyter_ai/models.py
+++ b/packages/jupyter-ai/jupyter_ai/models.py
@@ -263,3 +263,37 @@ class ListSlashCommandsEntry(BaseModel):
class ListSlashCommandsResponse(BaseModel):
slash_commands: List[ListSlashCommandsEntry] = []
+
+
+class ListOptionsEntry(BaseModel):
+ type: Literal["/", "@"]
+ id: str
+ label: str
+ description: str
+
+ @classmethod
+ def from_command(
+ cls,
+ type: Literal["/", "@"],
+ id: str,
+ description: str,
+ requires_arg: bool = False,
+ ):
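+        # a trailing ":" indicates that an argument must follow (e.g. "@file:"),
+        # while a trailing space marks the command itself as complete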
+ label = type + id + (":" if requires_arg else " ")
+ return cls(type=type, id=id, description=description, label=label)
+
+ @classmethod
+ def from_arg(
+ cls,
+ type: Literal["/", "@"],
+ id: str,
+ description: str,
+ arg: str,
+ is_complete: bool = True,
+ ):
+ label = type + id + ":" + arg + (" " if is_complete else "")
+ return cls(type=type, id=id, description=description, label=label)
+
+
+class ListOptionsResponse(BaseModel):
+ options: List[ListOptionsEntry] = []
diff --git a/packages/jupyter-ai/jupyter_ai/tests/test_context_providers.py b/packages/jupyter-ai/jupyter_ai/tests/test_context_providers.py
new file mode 100644
index 000000000..b7c97596e
--- /dev/null
+++ b/packages/jupyter-ai/jupyter_ai/tests/test_context_providers.py
@@ -0,0 +1,63 @@
+import logging
+from unittest import mock
+
+import pytest
+from jupyter_ai.context_providers import FileContextProvider
+from jupyter_ai.config_manager import ConfigManager
+from jupyter_ai.history import BoundedChatHistory
+from jupyter_ai.models import (
+ ChatClient,
+ HumanChatMessage,
+ Persona,
+)
+
+
+@pytest.fixture
+def human_chat_message() -> HumanChatMessage:
+ chat_client = ChatClient(
+ id=0, username="test", initials="test", name="test", display_name="test"
+ )
+ prompt = (
+ "@file:test1.py @file @file:dir/test2.md test test\n"
+ "@file:/dir/test3.png test@file:test4.py"
+ )
+ return HumanChatMessage(
+ id="test",
+ time=0,
+ body=prompt,
+ prompt=prompt,
+ client=chat_client,
+ )
+
+
+@pytest.fixture
+def file_context_provider() -> FileContextProvider:
+ config_manager = mock.create_autospec(ConfigManager)
+ config_manager.persona = Persona(name="test", avatar_route="test")
+ return FileContextProvider(
+ log=logging.getLogger(__name__),
+ config_manager=config_manager,
+ model_parameters={},
+ chat_history=[],
+ llm_chat_memory=BoundedChatHistory(k=2),
+ root_dir="",
+ preferred_dir="",
+ dask_client_future=None,
+ chat_handlers={},
+ context_providers={},
+ )
+
+
+def test_find_instances(file_context_provider, human_chat_message):
+ expected = ["@file:test1.py", "@file:dir/test2.md", "@file:/dir/test3.png"]
+ instances = file_context_provider._find_instances(human_chat_message.prompt)
+ assert instances == expected
+
+
+def test_replace_prompt(file_context_provider, human_chat_message):
+ expected = (
+ "'test1.py' @file 'dir/test2.md' test test\n"
+ "'/dir/test3.png' test@file:test4.py"
+ )
+ prompt = file_context_provider.replace_prompt(human_chat_message.prompt)
+ assert prompt == expected
diff --git a/packages/jupyter-ai/jupyter_ai/tests/test_handlers.py b/packages/jupyter-ai/jupyter_ai/tests/test_handlers.py
index c1ca7b098..a94c3fbf8 100644
--- a/packages/jupyter-ai/jupyter_ai/tests/test_handlers.py
+++ b/packages/jupyter-ai/jupyter_ai/tests/test_handlers.py
@@ -76,6 +76,7 @@ def broadcast_message(message: Message) -> None:
dask_client_future=None,
help_message_template=DEFAULT_HELP_MESSAGE_TEMPLATE,
chat_handlers={},
+ context_providers={},
)
diff --git a/packages/jupyter-ai/src/components/chat-input.tsx b/packages/jupyter-ai/src/components/chat-input.tsx
index ddc2d4209..b341653e7 100644
--- a/packages/jupyter-ai/src/components/chat-input.tsx
+++ b/packages/jupyter-ai/src/components/chat-input.tsx
@@ -38,12 +38,6 @@ type ChatInputProps = {
personaName: string;
};
-type SlashCommandOption = {
- id: string;
- label: string;
- description: string;
-};
-
/**
* List of icons per slash command, shown in the autocomplete popup.
*
@@ -65,9 +59,9 @@ const DEFAULT_SLASH_COMMAND_ICONS: Record<string, JSX.Element> = {
/**
* Renders an option shown in the slash command autocomplete.
*/
-function renderSlashCommandOption(
+function renderAutocompleteOption(
optionProps: React.HTMLAttributes<HTMLLIElement>,
- option: SlashCommandOption
+ option: AiService.AutocompleteOption
): JSX.Element {
const icon =
option.id in DEFAULT_SLASH_COMMAND_ICONS
@@ -99,8 +93,14 @@ function renderSlashCommandOption(
export function ChatInput(props: ChatInputProps): JSX.Element {
const [input, setInput] = useState('');
- const [slashCommandOptions, setSlashCommandOptions] = useState<
- SlashCommandOption[]
+ const [autocompleteOptions, setAutocompleteOptions] = useState<
+ AiService.AutocompleteOption[]
+ >([]);
+ const [autocompleteCommandOptions, setAutocompleteCommandOptions] = useState<
+ AiService.AutocompleteOption[]
+ >([]);
+ const [autocompleteArgOptions, setAutocompleteArgOptions] = useState<
+ AiService.AutocompleteOption[]
>([]);
const [currSlashCommand, setCurrSlashCommand] = useState<string | null>(null);
const activeCell = useActiveCellContext();
@@ -110,20 +110,45 @@ export function ChatInput(props: ChatInputProps): JSX.Element {
* initial mount to populate the slash command autocomplete.
*/
useEffect(() => {
- async function getSlashCommands() {
- const slashCommands = (await AiService.listSlashCommands())
- .slash_commands;
- setSlashCommandOptions(
- slashCommands.map(slashCommand => ({
- id: slashCommand.slash_id,
- label: '/' + slashCommand.slash_id + ' ',
- description: slashCommand.description
- }))
- );
+ async function getAutocompleteCommandOptions() {
+ const response = await AiService.listAutocompleteOptions();
+ setAutocompleteCommandOptions(response.options);
}
- getSlashCommands();
+ getAutocompleteCommandOptions();
}, []);
+ useEffect(() => {
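+  // when the last word is an argument-style reference (e.g. "@file:src/"), ask
+  // the server for argument completions matching the typed prefix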
+ async function getAutocompleteArgOptions() {
+ let options: AiService.AutocompleteOption[] = [];
+ const lastWord = input.split(/\s+/).pop() || '';
+ if (lastWord.startsWith('@') && lastWord.includes(':')) {
+ const [id, argPrefix] = lastWord.split(':', 2);
+ // get option that matches the command
+ const option = autocompleteCommandOptions.find(
+ option => option.id === id.slice(1) && option.type === '@'
+ );
+ if (option) {
+ const response = await AiService.listAutocompleteArgOptions({
+ id: option.id,
+ arg_prefix: argPrefix
+ });
+ options = response.options;
+ }
+ }
+ setAutocompleteArgOptions(options);
+ }
+ getAutocompleteArgOptions();
+ }, [autocompleteCommandOptions, input]);
+
+ // Combine the fixed options with the argument options
+ useEffect(() => {
+ if (autocompleteArgOptions.length > 0) {
+ setAutocompleteOptions(autocompleteArgOptions);
+ } else {
+ setAutocompleteOptions(autocompleteCommandOptions);
+ }
+ }, [autocompleteCommandOptions, autocompleteArgOptions]);
+
// whether any option is highlighted in the slash command autocomplete
const [highlighted, setHighlighted] = useState(false);
@@ -153,7 +178,7 @@ export function ChatInput(props: ChatInputProps): JSX.Element {
* chat input. Close the autocomplete when the user clears the chat input.
*/
useEffect(() => {
- if (input === '/') {
+ if (input === '/' || input.endsWith('@')) {
setOpen(true);
return;
}
@@ -255,12 +280,39 @@ export function ChatInput(props: ChatInputProps): JSX.Element {
currSlashCommand
};
+ function filterAutocompleteOptions(
+ options: AiService.AutocompleteOption[],
+ inputValue: string
+ ): AiService.AutocompleteOption[] {
+ const lastWord = inputValue.split(/\s+/).pop() || '';
+ if (
+ (lastWord.startsWith('/') && lastWord === inputValue) ||
+ lastWord.startsWith('@')
+ ) {
+ return options.filter(option => option.label.startsWith(lastWord));
+ }
+ return [];
+ }
+
return (
<Autocomplete
+        filterOptions={(options, { inputValue }) => {
+          return filterAutocompleteOptions(options, inputValue);
+        }}
+        onChange={(_, option) => {
+          const value = typeof option === 'string' ? option : option.label;
+          let matchLength = 0;
+          for (let i = 1; i <= value.length; i++) {
+            if (input.endsWith(value.slice(0, i))) {
+              matchLength = i;
+            }
+          }
+          setInput(input + value.slice(matchLength));
+        }}
onInputChange={(_, newValue: string) => {
setInput(newValue);
}}
@@ -273,12 +325,16 @@ export function ChatInput(props: ChatInputProps): JSX.Element {
setHighlighted(!!highlightedOption);
}
}
- onClose={() => setOpen(false)}
+ onClose={(_, reason) => {
+ if (reason !== 'selectOption' || input.endsWith(' ')) {
+ setOpen(false);
+ }
+ }}
// set this to an empty string to prevent the last selected slash
// command from being shown in blue
value=""
open={open}
- options={slashCommandOptions}
+ options={autocompleteOptions}
// hide default extra right padding in the text field
disableClearable
// ensure the autocomplete popup always renders on top
@@ -292,7 +348,7 @@ export function ChatInput(props: ChatInputProps): JSX.Element {
}
}
}}
- renderOption={renderSlashCommandOption}
+ renderOption={renderAutocompleteOption}
ListboxProps={{
sx: {
'& .MuiAutocomplete-option': {
diff --git a/packages/jupyter-ai/src/handler.ts b/packages/jupyter-ai/src/handler.ts
index b2f4efc16..64771e158 100644
--- a/packages/jupyter-ai/src/handler.ts
+++ b/packages/jupyter-ai/src/handler.ts
@@ -315,4 +315,38 @@ export namespace AiService {
export async function listSlashCommands(): Promise<ListSlashCommandsResponse> {
return requestAPI<ListSlashCommandsResponse>('chats/slash_commands');
}
+
+ export type AutocompleteOption = {
+ type: '/' | '@';
+ id: string;
+ description: string;
+ label: string;
+ };
+
+ export type ListAutocompleteOptionsResponse = {
+ options: AutocompleteOption[];
+ };
+
+ export type AutocompleteArgOptionsRequest = {
+ id: string;
+ arg_prefix: string;
+ };
+
+  export async function listAutocompleteOptions(): Promise<ListAutocompleteOptionsResponse> {
+    return requestAPI<ListAutocompleteOptionsResponse>(
+ 'chats/autocomplete_options'
+ );
+ }
+
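+  // fetches argument completions (e.g. file paths for "@file:") for a single
+  // context provider by POSTing its id and the current argument prefix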
+ export async function listAutocompleteArgOptions(
+ request: AutocompleteArgOptionsRequest
+  ): Promise<ListAutocompleteOptionsResponse> {
+    return requestAPI<ListAutocompleteOptionsResponse>(
+ 'chats/autocomplete_options',
+ {
+ method: 'POST',
+ body: JSON.stringify(request)
+ }
+ );
+ }
}