Skip to content

Commit

Permalink
Split context providers into base and base-command classes; add prompt replacement
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelchia committed Sep 19, 2024
1 parent 8c4380d commit 6e79d8b
Show file tree
Hide file tree
Showing 9 changed files with 97 additions and 64 deletions.
9 changes: 8 additions & 1 deletion packages/jupyter-ai/jupyter_ai/chat_handlers/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ async def process_message(self, message: HumanChatMessage):
self.get_llm_chain()
received_first_chunk = False

inputs = {"input": message.body}
inputs = {"input": self.replace_prompt(message.body)}
if "context" in self.prompt_template.input_variables:
# include context from context providers.
try:
Expand Down Expand Up @@ -152,3 +152,10 @@ async def make_context_prompt(self, human_msg: HumanChatMessage) -> str:
]
)
)

def replace_prompt(self, prompt: str) -> str:
# modifies prompt by the context providers.
# some providers may modify or remove their '@' commands from the prompt.
for provider in self.context_providers.values():
prompt = provider.replace_prompt(prompt)
return prompt
6 changes: 5 additions & 1 deletion packages/jupyter-ai/jupyter_ai/context_providers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
from .base import BaseContextProvider, ContextProviderException
from .base import (
BaseContextProvider,
BaseCommandContextProvider,
ContextProviderException,
)
from .file import FileContextProvider
from .learned import LearnedContextProvider
6 changes: 3 additions & 3 deletions packages/jupyter-ai/jupyter_ai/context_providers/_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser

from .base import BaseContextProvider
from .base import BaseContextProvider, BaseCommandContextProvider


# Examples of the ease of implementing retriever based context providers
Expand All @@ -19,7 +19,7 @@
""".strip()


class ArxivContextProvider(BaseContextProvider):
class ArxivContextProvider(BaseCommandContextProvider):
id = "arvix"
description = "Include papers from Arxiv"
remove_from_prompt = True
Expand Down Expand Up @@ -61,7 +61,7 @@ async def make_context_prompt(self, message: HumanChatMessage) -> str:
{x} Answer:"""


class WikiContextProvider(BaseContextProvider):
class WikiContextProvider(BaseCommandContextProvider):
id = "wiki"
description = "Include knowledge from Wikipedia"
remove_from_prompt = True
Expand Down
108 changes: 67 additions & 41 deletions packages/jupyter-ai/jupyter_ai/context_providers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,6 @@
class BaseContextProvider(abc.ABC):
id: ClassVar[str]
description: ClassVar[str]
requires_arg: ClassVar[bool] = False
is_command: ClassVar[bool] = (
True # whether the context provider can be invoked from chat
)
remove_from_prompt: ClassVar[bool] = (
False # whether the command should be removed from prompt
)

def __init__(
self,
Expand Down Expand Up @@ -65,51 +58,19 @@ def __init__(

self.llm = None

@property
def pattern(self) -> str:
return (
rf"(?<![^\s.])@{self.id}:[^\s]+"
if self.requires_arg
else rf"(?<![^\s.])@{self.id}(?![^\s.])"
)

@abc.abstractmethod
async def make_context_prompt(self, message: HumanChatMessage) -> str:
"""Returns a context prompt for all instances of the context provider
command.
"""
pass

def get_arg_options(self, arg_prefix: str) -> List[ListOptionsEntry]:
"""Returns a list of autocomplete options for arguments to the command
based on the prefix.
Only triggered if ':' is present after the command id (e.g. '@file:').
"""
return []

def replace_prompt(self, prompt: str) -> str:
"""Cleans up instances of the command from the prompt before
sending it to the LLM
"""
if self.remove_from_prompt:
return re.sub(self.pattern, "", prompt)
"""Modifies the prompt before sending it to the LLM."""
return prompt

def _find_instances(self, text: str) -> List[str]:
# finds instances of the context provider command in the text
matches = re.finditer(self.pattern, text)
results = []
for match in matches:
start, end = match.span()
before = text[:start]
after = text[end:]
# Check if the match is within backticks
if before.count("`") % 2 == 0 and after.count("`") % 2 == 0:
results.append(match.group())
return results

def _clean_prompt(self, text: str) -> str:
# useful for cleaning up the prompt before sending it to a retriever
# util for cleaning up the prompt before sending it to a retriever
for provider in self.context_providers.values():
text = provider.replace_prompt(text)
return text
Expand Down Expand Up @@ -152,6 +113,71 @@ def get_llm(self):
return self.llm


class BaseCommandContextProvider(BaseContextProvider):
    """Context provider invoked via an '@<id>' command typed into the chat."""

    # whether the command takes an argument after ':' (e.g. '@file:<path>')
    requires_arg: ClassVar[bool] = False
    # whether the command should be removed from prompt
    remove_from_prompt: ClassVar[bool] = False

    @property
    def pattern(self) -> str:
        """Regex matching occurrences of this command in a prompt."""
        if self.requires_arg:
            return rf"(?<![^\s.])@{self.id}:[^\s]+"
        return rf"(?<![^\s.])@{self.id}(?![^\s.])"

    def replace_prompt(self, prompt: str) -> str:
        """Cleans up instances of the command from the prompt before
        sending it to the LLM
        """

        def _substitute(match):
            # leave occurrences inside backticks (code spans) untouched
            if _is_within_backticks(match, prompt):
                return match.group()
            return self._replace_instance(match.group())

        return re.sub(self.pattern, _substitute, prompt)

    def get_arg_options(self, arg_prefix: str) -> List[ListOptionsEntry]:
        """Returns a list of autocomplete options for arguments to the command
        based on the prefix.
        Only triggered if ':' is present after the command id (e.g. '@file:').
        """
        if not self.requires_arg:
            return []
        # default implementation that should be overridden if 'requires_arg' is True
        return [
            ListOptionsEntry.from_arg(
                type="@",
                id=self.id,
                description=self.description,
                arg=arg_prefix,
                is_complete=True,
            )
        ]

    def _find_instances(self, text: str) -> List[str]:
        # all command occurrences in `text` that are not inside backticks
        return [
            match.group()
            for match in re.finditer(self.pattern, text)
            if not _is_within_backticks(match, text)
        ]

    def _replace_instance(self, instance: str) -> str:
        # the text a single occurrence is rewritten to in the prompt
        return "" if self.remove_from_prompt else instance


def _is_within_backticks(match, text):
start, _ = match.span()
before = text[:start]
return before.count("`") % 2 == 1


class ContextProviderException(Exception):
    """Used to generate a response when a context provider fails."""
17 changes: 7 additions & 10 deletions packages/jupyter-ai/jupyter_ai/context_providers/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from jupyter_ai.models import ListOptionsEntry, HumanChatMessage
from jupyter_ai.document_loaders.directory import SUPPORTED_EXTS

from .base import BaseContextProvider, ContextProviderException
from .base import BaseCommandContextProvider, ContextProviderException

FILE_CONTEXT_TEMPLATE = """
File: {filepath}
Expand All @@ -17,20 +17,12 @@
""".strip()


class FileContextProvider(BaseContextProvider):
class FileContextProvider(BaseCommandContextProvider):
id = "file"
description = "Include file contents"
requires_arg = True
header = "Following are contents of files referenced:"

def replace_prompt(self, prompt: str) -> str:
# replaces instances of @file:<filepath> with '<filepath>'
def substitute(match):
filepath = match.group(0).partition(":")[2]
return f"'{filepath}'"

return re.sub(self.pattern, substitute, prompt)

def get_arg_options(self, arg_prefix: str) -> List[ListOptionsEntry]:
is_abs = not os.path.isabs(arg_prefix)
path_prefix = arg_prefix if is_abs else os.path.join(self.base_dir, arg_prefix)
Expand Down Expand Up @@ -106,6 +98,11 @@ def _process_file(self, content: str, filepath: str):
return "\n\n".join([cell.source for cell in nb.cells])
return content

def _replace_instance(self, instance: str) -> str:
# replaces instances of @file:<filepath> with '<filepath>'
filepath = instance.partition(":")[2]
return f"'{filepath}'"

def get_filepaths(self, message: HumanChatMessage) -> List[str]:
filepaths = []
for instance in self._find_instances(message.prompt):
Expand Down
4 changes: 2 additions & 2 deletions packages/jupyter-ai/jupyter_ai/context_providers/learned.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from jupyter_ai.models import HumanChatMessage
from jupyter_ai.chat_handlers.learn import Retriever

from .base import BaseContextProvider
from .base import BaseCommandContextProvider
from .file import FileContextProvider


Expand All @@ -15,7 +15,7 @@
""".strip()


class LearnedContextProvider(BaseContextProvider):
class LearnedContextProvider(BaseCommandContextProvider):
id = "learned"
description = "Include learned context"
remove_from_prompt = True
Expand Down
4 changes: 1 addition & 3 deletions packages/jupyter-ai/jupyter_ai/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,9 +422,7 @@ def initialize_settings(self):
)
continue

if context_provider.is_command and not re.match(
r"^[a-zA-Z0-9_]+$", context_provider.id
):
if not re.match(r"^[a-zA-Z0-9_]+$", context_provider.id):
self.log.error(
f"Context provider `{context_provider.id}` is an invalid ID; "
+ f"must contain only letters, numbers, and underscores"
Expand Down
3 changes: 2 additions & 1 deletion packages/jupyter-ai/jupyter_ai/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import tornado
from jupyter_ai.chat_handlers import BaseChatHandler, SlashCommandRoutingType
from jupyter_ai.config_manager import ConfigManager, KeyEmptyError, WriteConflictError
from jupyter_ai.context_providers import BaseCommandContextProvider
from jupyter_server.base.handlers import APIHandler as BaseAPIHandler
from jupyter_server.base.handlers import JupyterHandler
from langchain.pydantic_v1 import ValidationError
Expand Down Expand Up @@ -668,7 +669,7 @@ def _get_context_provider_options(self) -> List[ListOptionsEntry]:
requires_arg=context_provider.requires_arg,
)
for context_provider in self.context_providers.values()
if context_provider.is_command
if isinstance(context_provider, BaseCommandContextProvider)
]
options.sort(key=lambda opt: opt.id)
return options
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def human_chat_message() -> HumanChatMessage:
)
prompt = (
"@file:test1.py @file @file:dir/test2.md test test\n"
"@file:/dir/test3.png test@file:test4.py"
"@file:/dir/test3.png test@file:test4.py ```\n@file:test5.py\n```"
)
return HumanChatMessage(
id="test",
Expand Down Expand Up @@ -57,7 +57,7 @@ def test_find_instances(file_context_provider, human_chat_message):
def test_replace_prompt(file_context_provider, human_chat_message):
    # '@file:<path>' instances are replaced with the quoted path; a bare
    # '@file' (no arg), a command not preceded by whitespace
    # ('test@file:test4.py'), and a command inside backticks are all left
    # untouched.
    expected = (
        "'test1.py' @file 'dir/test2.md' test test\n"
        "'/dir/test3.png' test@file:test4.py ```\n@file:test5.py\n```"
    )
    prompt = file_context_provider.replace_prompt(human_chat_message.prompt)
    assert prompt == expected

0 comments on commit 6e79d8b

Please sign in to comment.