Skip to content

Commit

Permalink
Framework for adding context to LLM prompt (jupyterlab#993)
Browse files Browse the repository at this point in the history
* context provider

* split base and base command context providers + replacing prompt

* comment

* only replace prompt if context variable in template

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Run mypy on CI, fix or ignore typing issues (jupyterlab#987)

* Run mypy on CI

* Rename, add mypy to test deps

* Fix typing jupyter-ai codebase (mostly)

* Three more cases

* update deepmerge version specifier

---------

Co-authored-by: David L. Qiu <[email protected]>

* context provider

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* mypy

* black

* modify backtick logic

* allow for spaces in filepath

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* refactor

* fixes

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix test

* refactor autocomplete to remove hardcoded '/' and '@' prefix

* modify context prompt template

Co-authored-by: david qiu <[email protected]>

* refactor

* docstrings + refactor

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* mypy

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add context providers to help

* remove _examples.py and remove @learned from defaults

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* make find_commands unoverridable

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Michał Krassowski <[email protected]>
Co-authored-by: David L. Qiu <[email protected]>
  • Loading branch information
4 people authored and Marchlak committed Oct 28, 2024
1 parent 36faff3 commit d704b72
Show file tree
Hide file tree
Showing 14 changed files with 942 additions and 87 deletions.
28 changes: 23 additions & 5 deletions packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,25 @@
The following is a friendly conversation between you and a human.
""".strip()

CHAT_DEFAULT_TEMPLATE = """Current conversation:
{history}
Human: {input}
CHAT_DEFAULT_TEMPLATE = """
{% if context %}
Context:
{{context}}
{% endif %}
Current conversation:
{{history}}
Human: {{input}}
AI:"""

HUMAN_MESSAGE_TEMPLATE = """
{% if context %}
Context:
{{context}}
{% endif %}
{{input}}
"""

COMPLETION_SYSTEM_PROMPT = """
You are an application built to provide helpful code completion suggestions.
Expand Down Expand Up @@ -400,17 +414,21 @@ def get_chat_prompt_template(self) -> PromptTemplate:
CHAT_SYSTEM_PROMPT
).format(provider_name=name, local_model_id=self.model_id),
MessagesPlaceholder(variable_name="history"),
HumanMessagePromptTemplate.from_template("{input}"),
HumanMessagePromptTemplate.from_template(
HUMAN_MESSAGE_TEMPLATE,
template_format="jinja2",
),
]
)
else:
return PromptTemplate(
input_variables=["history", "input"],
input_variables=["history", "input", "context"],
template=CHAT_SYSTEM_PROMPT.format(
provider_name=name, local_model_id=self.model_id
)
+ "\n\n"
+ CHAT_DEFAULT_TEMPLATE,
template_format="jinja2",
)

def get_completion_prompt_template(self) -> PromptTemplate:
Expand Down
18 changes: 17 additions & 1 deletion packages/jupyter-ai/jupyter_ai/chat_handlers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from langchain.pydantic_v1 import BaseModel

if TYPE_CHECKING:
from jupyter_ai.context_providers import BaseCommandContextProvider
from jupyter_ai.handlers import RootChatHandler
from jupyter_ai.history import BoundedChatHistory
from langchain_core.chat_history import BaseChatMessageHistory
Expand Down Expand Up @@ -121,6 +122,10 @@ class BaseChatHandler:
chat handlers, which is necessary for some use-cases like printing the help
message."""

context_providers: Dict[str, "BaseCommandContextProvider"]
"""Dictionary of context providers. Allows chat handlers to reference
context providers, which can be used to provide context to the LLM."""

def __init__(
self,
log: Logger,
Expand All @@ -134,6 +139,7 @@ def __init__(
dask_client_future: Awaitable[DaskClient],
help_message_template: str,
chat_handlers: Dict[str, "BaseChatHandler"],
context_providers: Dict[str, "BaseCommandContextProvider"],
):
self.log = log
self.config_manager = config_manager
Expand All @@ -154,6 +160,7 @@ def __init__(
self.dask_client_future = dask_client_future
self.help_message_template = help_message_template
self.chat_handlers = chat_handlers
self.context_providers = context_providers

self.llm: Optional[BaseProvider] = None
self.llm_params: Optional[dict] = None
Expand Down Expand Up @@ -430,8 +437,17 @@ def send_help_message(self, human_msg: Optional[HumanChatMessage] = None) -> Non
]
)

context_commands_list = "\n".join(
[
f"* `{cp.command_id}` — {cp.help}"
for cp in self.context_providers.values()
]
)

help_message_body = self.help_message_template.format(
persona_name=self.persona.name, slash_commands_list=slash_commands_list
persona_name=self.persona.name,
slash_commands_list=slash_commands_list,
context_commands_list=context_commands_list,
)
help_message = AgentChatMessage(
id=uuid4().hex,
Expand Down
35 changes: 34 additions & 1 deletion packages/jupyter-ai/jupyter_ai/chat_handlers/default.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import time
from typing import Dict, Type
from uuid import uuid4
Expand All @@ -12,6 +13,7 @@
from langchain_core.runnables import ConfigurableFieldSpec
from langchain_core.runnables.history import RunnableWithMessageHistory

from ..context_providers import ContextProviderException, find_commands
from ..models import HumanChatMessage
from .base import BaseChatHandler, SlashCommandRoutingType

Expand All @@ -27,6 +29,7 @@ class DefaultChatHandler(BaseChatHandler):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.prompt_template = None

def create_llm_chain(
self, provider: Type[BaseProvider], provider_params: Dict[str, str]
Expand All @@ -40,6 +43,7 @@ def create_llm_chain(

prompt_template = llm.get_chat_prompt_template()
self.llm = llm
self.prompt_template = prompt_template

runnable = prompt_template | llm # type:ignore
if not llm.manages_history:
Expand Down Expand Up @@ -101,14 +105,25 @@ async def process_message(self, message: HumanChatMessage):
self.get_llm_chain()
received_first_chunk = False

inputs = {"input": message.body}
if "context" in self.prompt_template.input_variables:
# include context from context providers.
try:
context_prompt = await self.make_context_prompt(message)
except ContextProviderException as e:
self.reply(str(e), message)
return
inputs["context"] = context_prompt
inputs["input"] = self.replace_prompt(inputs["input"])

# start with a pending message
with self.pending("Generating response", message) as pending_message:
# stream response in chunks. this works even if a provider does not
# implement streaming, as `astream()` defaults to yielding `_call()`
# when `_stream()` is not implemented on the LLM class.
assert self.llm_chain
async for chunk in self.llm_chain.astream(
{"input": message.body},
inputs,
config={"configurable": {"last_human_msg": message}},
):
if not received_first_chunk:
Expand All @@ -128,3 +143,21 @@ async def process_message(self, message: HumanChatMessage):

# complete stream after all chunks have been streamed
self._send_stream_chunk(stream_id, "", complete=True)

async def make_context_prompt(self, human_msg: HumanChatMessage) -> str:
return "\n\n".join(
await asyncio.gather(
*[
provider.make_context_prompt(human_msg)
for provider in self.context_providers.values()
if find_commands(provider, human_msg.prompt)
]
)
)

def replace_prompt(self, prompt: str) -> str:
# modifies prompt by the context providers.
# some providers may modify or remove their '@' commands from the prompt.
for provider in self.context_providers.values():
prompt = provider.replace_prompt(prompt)
return prompt
7 changes: 7 additions & 0 deletions packages/jupyter-ai/jupyter_ai/context_providers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from .base import (
BaseCommandContextProvider,
ContextCommand,
ContextProviderException,
find_commands,
)
from .file import FileContextProvider
53 changes: 53 additions & 0 deletions packages/jupyter-ai/jupyter_ai/context_providers/_learned.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Currently unused as it is duplicating the functionality of the /ask command.
# TODO: Rename "learned" to something better.
from typing import List

from jupyter_ai.chat_handlers.learn import Retriever
from jupyter_ai.models import HumanChatMessage

from .base import BaseCommandContextProvider, ContextCommand
from .file import FileContextProvider

FILE_CHUNK_TEMPLATE = """
Snippet from file: {filepath}
```
{content}
```
""".strip()


class LearnedContextProvider(BaseCommandContextProvider):
id = "learned"
help = "Include content indexed from `/learn`"
remove_from_prompt = True
header = "Following are snippets from potentially relevant files:"

def __init__(self, **kwargs):
super().__init__(**kwargs)
self.retriever = Retriever(learn_chat_handler=self.chat_handlers["/learn"])

async def _make_context_prompt(
self, message: HumanChatMessage, commands: List[ContextCommand]
) -> str:
if not self.retriever:
return ""
query = self._clean_prompt(message.body)
docs = await self.retriever.ainvoke(query)
excluded = self._get_repeated_files(message)
context = "\n\n".join(
[
FILE_CHUNK_TEMPLATE.format(
filepath=d.metadata["path"], content=d.page_content
)
for d in docs
if d.metadata["path"] not in excluded and d.page_content
]
)
return self.header + "\n" + context

def _get_repeated_files(self, message: HumanChatMessage) -> List[str]:
# don't include files that are already provided by the file context provider
file_context_provider = self.context_providers.get("file")
if isinstance(file_context_provider, FileContextProvider):
return file_context_provider.get_filepaths(message)
return []
Loading

0 comments on commit d704b72

Please sign in to comment.