Commit: context provider

michaelchia committed Sep 19, 2024
1 parent fcb2d71 commit 8c4380d
Showing 15 changed files with 889 additions and 31 deletions.
29 changes: 24 additions & 5 deletions packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
@@ -67,11 +67,26 @@
The following is a friendly conversation between you and a human.
""".strip()

CHAT_DEFAULT_TEMPLATE = """Current conversation:
{history}
Human: {input}
CHAT_DEFAULT_TEMPLATE = """
{% if context %}
Context:
{{context}}
{% endif %}
Current conversation:
{{history}}
Human: {{input}}
AI:"""

HUMAN_MESSAGE_TEMPLATE = """
{% if context %}
<context>
{{context}}
</context>
{% endif %}
{{input}}
"""

COMPLETION_SYSTEM_PROMPT = """
You are an application built to provide helpful code completion suggestions.
@@ -400,17 +415,21 @@ def get_chat_prompt_template(self) -> PromptTemplate:
CHAT_SYSTEM_PROMPT
).format(provider_name=name, local_model_id=self.model_id),
MessagesPlaceholder(variable_name="history"),
HumanMessagePromptTemplate.from_template("{input}"),
HumanMessagePromptTemplate.from_template(
HUMAN_MESSAGE_TEMPLATE,
template_format="jinja2",
),
]
)
else:
return PromptTemplate(
input_variables=["history", "input"],
input_variables=["history", "input", "context"],
template=CHAT_SYSTEM_PROMPT.format(
provider_name=name, local_model_id=self.model_id
)
+ "\n\n"
+ CHAT_DEFAULT_TEMPLATE,
template_format="jinja2",
)

def get_completion_prompt_template(self) -> PromptTemplate:
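Note on the template changes above: both prompts opt into template_format="jinja2" so the {% if context %} guard can drop the context block entirely when no provider contributes anything. A standalone sketch of that rendering behavior (not part of this commit; the template body mirrors CHAT_DEFAULT_TEMPLATE above):

from langchain_core.prompts import PromptTemplate

# Reproduced locally so the snippet runs on its own.
TEMPLATE = """
{% if context %}
Context:
{{context}}
{% endif %}
Current conversation:
{{history}}
Human: {{input}}
AI:"""

prompt = PromptTemplate(
    input_variables=["history", "input", "context"],
    template=TEMPLATE,
    template_format="jinja2",
)

# With context, the rendered prompt gains a "Context:" section; with an
# empty string, the {% if %} guard removes the section entirely.
print(prompt.format(history="", input="Hi", context="<file>print(1)</file>"))
print(prompt.format(history="", input="Hi", context=""))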
7 changes: 7 additions & 0 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/base.py
@@ -35,6 +35,7 @@
if TYPE_CHECKING:
from jupyter_ai.handlers import RootChatHandler
from jupyter_ai.history import BoundedChatHistory
from jupyter_ai.context_providers import BaseContextProvider
from langchain_core.chat_history import BaseChatMessageHistory


@@ -121,6 +122,10 @@ class BaseChatHandler:
chat handlers, which is necessary for some use-cases like printing the help
message."""

    context_providers: Dict[str, "BaseContextProvider"]
"""Dictionary of context providers. Allows chat handlers to reference
context providers, which can be used to provide context to the LLM."""

def __init__(
self,
log: Logger,
@@ -134,6 +139,7 @@ def __init__(
dask_client_future: Awaitable[DaskClient],
help_message_template: str,
chat_handlers: Dict[str, "BaseChatHandler"],
        context_providers: Dict[str, "BaseContextProvider"],
):
self.log = log
self.config_manager = config_manager
@@ -154,6 +160,7 @@ def __init__(
self.dask_client_future = dask_client_future
self.help_message_template = help_message_template
self.chat_handlers = chat_handlers
self.context_providers = context_providers

self.llm: Optional[BaseProvider] = None
self.llm_params: Optional[dict] = None
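With the mapping stored on the base class, every handler can reach the registered providers. A hedged sketch of what a downstream handler could do with it (the handler is hypothetical, and it assumes FileContextProvider registers under the id "file"):

from jupyter_ai.chat_handlers.base import BaseChatHandler
from jupyter_ai.models import HumanChatMessage


class ContextEchoHandler(BaseChatHandler):  # hypothetical handler
    """Replies with whatever context the file provider gathers."""

    async def process_message(self, message: HumanChatMessage):
        # Look up a provider by id and ask it for a context block.
        provider = self.context_providers.get("file")
        context = await provider.make_context_prompt(message) if provider else ""
        self.reply(f"Gathered context:\n{context}", message)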
26 changes: 25 additions & 1 deletion packages/jupyter-ai/jupyter_ai/chat_handlers/default.py
@@ -1,3 +1,4 @@
import asyncio
import time
from typing import Dict, Type
from uuid import uuid4
@@ -13,6 +14,7 @@
from langchain_core.runnables.history import RunnableWithMessageHistory

from ..models import HumanChatMessage
from ..context_providers import ContextProviderException
from .base import BaseChatHandler, SlashCommandRoutingType


@@ -27,6 +29,7 @@ class DefaultChatHandler(BaseChatHandler):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.prompt_template = None

def create_llm_chain(
self, provider: Type[BaseProvider], provider_params: Dict[str, str]
@@ -40,6 +43,7 @@ def create_llm_chain(

prompt_template = llm.get_chat_prompt_template()
self.llm = llm
self.prompt_template = prompt_template

runnable = prompt_template | llm # type:ignore
if not llm.manages_history:
@@ -101,14 +105,24 @@ async def process_message(self, message: HumanChatMessage):
self.get_llm_chain()
received_first_chunk = False

inputs = {"input": message.body}
if "context" in self.prompt_template.input_variables:
# include context from context providers.
try:
context_prompt = await self.make_context_prompt(message)
except ContextProviderException as e:
self.reply(str(e), message)
return
inputs["context"] = context_prompt

# start with a pending message
with self.pending("Generating response", message) as pending_message:
# stream response in chunks. this works even if a provider does not
# implement streaming, as `astream()` defaults to yielding `_call()`
# when `_stream()` is not implemented on the LLM class.
assert self.llm_chain
async for chunk in self.llm_chain.astream(
{"input": message.body},
inputs,
config={"configurable": {"last_human_msg": message}},
):
if not received_first_chunk:
@@ -128,3 +142,13 @@ async def process_message(self, message: HumanChatMessage):

# complete stream after all chunks have been streamed
self._send_stream_chunk(stream_id, "", complete=True)

async def make_context_prompt(self, human_msg: HumanChatMessage) -> str:
return "\n\n".join(
await asyncio.gather(
*[
provider.make_context_prompt(human_msg)
for provider in self.context_providers.values()
]
)
)
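make_context_prompt fans out to all providers concurrently via asyncio.gather, and process_message only asks for context when the active prompt template declares a "context" input variable. A hedged sketch of how a provider can veto an exchange via ContextProviderException (the provider below is hypothetical):

from jupyter_ai.context_providers import (
    BaseContextProvider,
    ContextProviderException,
)
from jupyter_ai.models import HumanChatMessage


class SecretsContextProvider(BaseContextProvider):  # hypothetical provider
    id = "secrets"
    description = "Refuse to inline secret files"

    async def make_context_prompt(self, message: HumanChatMessage) -> str:
        if not self._find_instances(message.prompt):
            return ""
        # Raising here short-circuits process_message() above: the handler
        # replies with str(e) and never invokes the LLM chain.
        raise ContextProviderException("@secrets content cannot be shared in chat.")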
3 changes: 3 additions & 0 deletions packages/jupyter-ai/jupyter_ai/context_providers/__init__.py
@@ -0,0 +1,3 @@
from .base import BaseContextProvider, ContextProviderException
from .file import FileContextProvider
from .learned import LearnedContextProvider
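These re-exports give downstream code one stable import path for the base class, the exception, and the built-in providers:

# Downstream usage of the re-exports above.
from jupyter_ai.context_providers import (
    BaseContextProvider,
    ContextProviderException,
    FileContextProvider,
    LearnedContextProvider,
)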
133 changes: 133 additions & 0 deletions packages/jupyter-ai/jupyter_ai/context_providers/_examples.py
@@ -0,0 +1,133 @@
# This file is for illustrative purposes
# It is to be deleted before merging
from jupyter_ai.models import HumanChatMessage
from langchain_community.retrievers import WikipediaRetriever
from langchain_community.retrievers import ArxivRetriever
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser

from .base import BaseContextProvider


# Examples of the ease of implementing retriever-based context providers
ARXIV_TEMPLATE = """
Title: {title}
Publish Date: {publish_date}
'''
{content}
'''
""".strip()


class ArxivContextProvider(BaseContextProvider):
id = "arvix"
description = "Include papers from Arxiv"
remove_from_prompt = True
header = "Following are snippets of research papers:"

def __init__(self, **kwargs):
super().__init__(**kwargs)
self.retriever = ArxivRetriever()

async def make_context_prompt(self, message: HumanChatMessage) -> str:
if not self._find_instances(message.prompt):
return ""
query = self._clean_prompt(message.body)
docs = await self.retriever.ainvoke(query)
context = "\n\n".join(
[
ARXIV_TEMPLATE.format(
content=d.page_content,
title=d.metadata["Title"],
publish_date=d.metadata["Published"],
)
for d in docs
]
)
return self.header + "\n" + context


# Another retriever-based context provider, with a query-rewrite step using an LLM
WIKI_TEMPLATE = """
Title: {title}
'''
{content}
'''
""".strip()

REWRITE_TEMPLATE = """Provide a better search query for \
web search engine to answer the given question, end \
the queries with ’**’. Question: \
{x} Answer:"""


class WikiContextProvider(BaseContextProvider):
id = "wiki"
description = "Include knowledge from Wikipedia"
remove_from_prompt = True
header = "Following are information from wikipedia:"

def __init__(self, **kwargs):
super().__init__(**kwargs)
self.retriever = WikipediaRetriever()

async def make_context_prompt(self, message: HumanChatMessage) -> str:
if not self._find_instances(message.prompt):
return ""
prompt = self._clean_prompt(message.body)
search_query = await self._rewrite_prompt(prompt)
docs = await self.retriever.ainvoke(search_query)
context = "\n\n".join(
[
WIKI_TEMPLATE.format(
content=d.page_content,
title=d.metadata["title"],
)
for d in docs
]
)
return self.header + "\n" + context

async def _rewrite_prompt(self, prompt: str) -> str:
return await self.get_llm_chain().ainvoke(prompt)

def get_llm_chain(self):
# from https://github.com/langchain-ai/langchain/blob/master/cookbook/rewrite.ipynb
llm = self.get_llm()
rewrite_prompt = ChatPromptTemplate.from_template(REWRITE_TEMPLATE)

def _parse(text):
return text.strip('"').strip("**")

return rewrite_prompt | llm | StrOutputParser() | _parse


# Partial example of a non-command context provider for errors.
# Assuming the UI offers an option to attach cell errors to messages, the
# default chat handler will automatically invoke this context provider to add
# solutions retrieved from a custom error database, or from a Stack Overflow /
# Google retriever pipeline.
class ErrorContextProvider(BaseContextProvider):
id = "error"
description = "Include custom error context"
remove_from_prompt = True
header = "Following are potential solutions for the error:"
is_command = False # will not show up in autocomplete

async def make_context_prompt(self, message: HumanChatMessage) -> str:
# will run for every message with a cell error since it does not
# use _find_instances to check for the presence of the command in
# the message.
if not (message.selection and message.selection.type == "cell-with-error"):
return ""
docs = await self.solution_retriever.ainvoke(message.selection)
if not docs:
return ""
context = "\n\n".join([d.page_content for d in docs])
return self.header + "\n" + context

@property
def solution_retriever(self):
        # Retriever that takes an error and returns solutions from a database
        # of error messages.
raise NotImplementedError("Error retriever not implemented")
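solution_retriever is left abstract above. One way it could be filled in, sketched under loud assumptions (a prebuilt FAISS index at "error_solutions_index" and OpenAI embeddings, neither of which is part of this commit; a real implementation would also need to reduce message.selection to a plain-text query before retrieval):

from langchain_community.vectorstores import FAISS
from langchain_openai import OpenAIEmbeddings

# Assumes ErrorContextProvider from the example file above is importable.


class DatabaseErrorContextProvider(ErrorContextProvider):  # hypothetical subclass
    @property
    def solution_retriever(self):
        # Load a prebuilt index mapping known error messages to fixes.
        store = FAISS.load_local(
            "error_solutions_index",  # illustrative index path
            OpenAIEmbeddings(),  # illustrative embedding model
            allow_dangerous_deserialization=True,
        )
        return store.as_retriever(search_kwargs={"k": 3})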