Skip to content

Commit

Permalink
make help message template configurable
Browse files Browse the repository at this point in the history
  • Loading branch information
dlqqq committed Aug 6, 2024
1 parent 7531f42 commit f96b202
Show file tree
Hide file tree
Showing 6 changed files with 142 additions and 119 deletions.
37 changes: 37 additions & 0 deletions packages/jupyter-ai-test/jupyter_ai_test/test_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,40 @@ class TestProviderWithStreaming(BaseProvider, TestLLMWithStreaming):
fields: ClassVar[List[Field]] = []
"""User inputs expected by this provider when initializing it. Each `Field` `f`
should be passed in the constructor as a keyword argument, keyed by `f.key`."""

class TestProviderAskLearnUnsupported(BaseProvider, TestLLMWithStreaming):
id: ClassVar[str] = "test-provider-ask-learn-unsupported"
"""ID for this provider class."""

name: ClassVar[str] = "Test Provider (/learn and /ask unsupported)"
"""User-facing name of this provider."""

models: ClassVar[List[str]] = ["test"]
"""List of supported models by their IDs. For registry providers, this will
be just ["*"]."""

help: ClassVar[str] = None
"""Text to display in lieu of a model list for a registry provider that does
not provide a list of models."""

model_id_key: ClassVar[str] = "model_id"
"""Kwarg expected by the upstream LangChain provider."""

model_id_label: ClassVar[str] = "Model ID"
"""Human-readable label of the model ID."""

pypi_package_deps: ClassVar[List[str]] = []
"""List of PyPi package dependencies."""

auth_strategy: ClassVar[AuthStrategy] = None
"""Authentication/authorization strategy. Declares what credentials are
required to use this model provider. Generally should not be `None`."""

registry: ClassVar[bool] = False
"""Whether this provider is a registry provider."""

fields: ClassVar[List[Field]] = []
"""User inputs expected by this provider when initializing it. Each `Field` `f`
should be passed in the constructor as a keyword argument, keyed by `f.key`."""

unsupported_slash_commands = {"/learn", "/ask"}
1 change: 1 addition & 0 deletions packages/jupyter-ai-test/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ test = ["coverage", "pytest", "pytest-asyncio", "pytest-cov"]
[project.entry-points."jupyter_ai.model_providers"]
test-provider = "jupyter_ai_test.test_providers:TestProvider"
test-provider-with-streaming = "jupyter_ai_test.test_providers:TestProviderWithStreaming"
test-provider-ask-learn-unsupported = "jupyter_ai_test.test_providers:TestProviderAskLearnUnsupported"

[project.entry-points."jupyter_ai.chat_handlers"]
test-slash-command = "jupyter_ai_test.test_slash_commands:TestSlashCommand"
Expand Down
49 changes: 47 additions & 2 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ class BaseChatHandler:
"""What this chat handler does, which third-party models it contacts,
the data it returns to the user, and so on, for display in the UI."""

routing_type: HandlerRoutingType = ...
routing_type: ClassVar[HandlerRoutingType] = ...

uses_llm: ClassVar[bool] = True
"""Class attribute specifying whether this chat handler uses the LLM
Expand All @@ -102,10 +102,20 @@ class BaseChatHandler:
parse the arguments and display help when user queries with
`-h` or `--help`"""

_requests_count = 0
_requests_count: ClassVar[int] = 0
"""Class attribute set to the number of requests that Jupyternaut is
currently handling."""

# Instance attributes
help_message_template: str
"""Format string template that is used to build the help message. Specified
from traitlets configuration."""

chat_handlers: Dict[str, 'BaseChatHandler']
"""Dictionary of chat handlers. Allows one chat handler to reference other
chat handlers, which is necessary for some use-cases like printing the help
message."""

def __init__(
self,
log: Logger,
Expand All @@ -116,6 +126,8 @@ def __init__(
root_dir: str,
preferred_dir: Optional[str],
dask_client_future: Awaitable[DaskClient],
help_message_template: str,
chat_handlers: Dict[str, 'BaseChatHandler'],
):
self.log = log
self.config_manager = config_manager
Expand All @@ -133,6 +145,9 @@ def __init__(
self.root_dir = os.path.abspath(os.path.expanduser(root_dir))
self.preferred_dir = get_preferred_dir(self.root_dir, preferred_dir)
self.dask_client_future = dask_client_future
self.help_message_template = help_message_template
self.chat_handlers = chat_handlers

self.llm = None
self.llm_params = None
self.llm_chain = None
Expand Down Expand Up @@ -366,3 +381,33 @@ def output_dir(self) -> str:
return self.preferred_dir
else:
return self.root_dir

def send_help_message(self, human_msg: Optional[HumanChatMessage] = None) -> None:
"""Sends a help message to all connected clients."""
lm_provider = self.config_manager.lm_provider
unsupported_slash_commands = lm_provider.unsupported_slash_commands if lm_provider else set()
chat_handlers = self.chat_handlers
slash_commands = {k: v for k, v in chat_handlers.items() if k != "default" }
for key in unsupported_slash_commands:
del slash_commands[key]

# markdown string that lists the slash commands
slash_commands_list = "\n".join(
[
f"* `{command_name}` — {handler.help}"
for command_name, handler in slash_commands.items()
]
)

help_message_body = self.help_message_template.format(persona_name=self.persona.name, slash_commands_list=slash_commands_list)
help_message = AgentChatMessage(
id=uuid4().hex,
time=time.time(),
body=help_message_body,
reply_to=human_msg.id if human_msg else "",
persona=self.persona,
)

self._chat_history.append(help_message)
for websocket in self._root_chat_handlers.values():
websocket.write_message(help_message.json())
21 changes: 5 additions & 16 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/clear.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
from typing import List

from jupyter_ai.chat_handlers.help import build_help_message
from jupyter_ai.models import ChatMessage, ClearMessage
from jupyter_ai.models import ClearMessage

from .base import BaseChatHandler, SlashCommandRoutingType

Expand All @@ -20,22 +17,14 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

async def process_message(self, _):
# Clear chat
for handler in self._root_chat_handlers.values():
if not handler:
continue

# Clear chat
handler.broadcast_message(ClearMessage())
self._chat_history.clear()

# Build /help message and reinstate it in chat
chat_handlers = handler.chat_handlers
persona = self.config_manager.persona
lm_provider = self.config_manager.lm_provider
unsupported_slash_commands = (
lm_provider.unsupported_slash_commands if lm_provider else set()
)
msg = build_help_message(chat_handlers, persona, unsupported_slash_commands)
self.reply(msg.body)

break

# re-send help message
self.send_help_message()
64 changes: 3 additions & 61 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/help.py
Original file line number Diff line number Diff line change
@@ -1,54 +1,7 @@
import time
from typing import Dict
from uuid import uuid4

from jupyter_ai.models import AgentChatMessage, HumanChatMessage
from jupyter_ai_magics import Persona
from jupyter_ai.models import HumanChatMessage

from .base import BaseChatHandler, SlashCommandRoutingType

HELP_MESSAGE = """Hi there! I'm {persona_name}, your programming assistant.
You can ask me a question using the text box below. You can also use these commands:
{commands}
Jupyter AI includes [magic commands](https://jupyter-ai.readthedocs.io/en/latest/users/index.html#the-ai-and-ai-magic-commands) that you can use in your notebooks.
For more information, see the [documentation](https://jupyter-ai.readthedocs.io).
"""


def _format_help_message(
chat_handlers: Dict[str, BaseChatHandler],
persona: Persona,
unsupported_slash_commands: set,
):
if unsupported_slash_commands:
keys = set(chat_handlers.keys()) - unsupported_slash_commands
chat_handlers = {key: chat_handlers[key] for key in keys}

commands = "\n".join(
[
f"* `{command_name}` — {handler.help}"
for command_name, handler in chat_handlers.items()
if command_name != "default"
]
)
return HELP_MESSAGE.format(commands=commands, persona_name=persona.name)


def build_help_message(
chat_handlers: Dict[str, BaseChatHandler],
persona: Persona,
unsupported_slash_commands: set,
):
return AgentChatMessage(
id=uuid4().hex,
time=time.time(),
body=_format_help_message(chat_handlers, persona, unsupported_slash_commands),
reply_to="",
persona=Persona(name=persona.name, avatar_route=persona.avatar_route),
)


class HelpChatHandler(BaseChatHandler):
id = "help"
name = "Help"
Expand All @@ -58,19 +11,8 @@ class HelpChatHandler(BaseChatHandler):

uses_llm = False

def __init__(self, *args, chat_handlers: Dict[str, BaseChatHandler], **kwargs):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._chat_handlers = chat_handlers

async def process_message(self, message: HumanChatMessage):
persona = self.config_manager.persona
lm_provider = self.config_manager.lm_provider
unsupported_slash_commands = (
lm_provider.unsupported_slash_commands if lm_provider else set()
)
self.reply(
_format_help_message(
self._chat_handlers, persona, unsupported_slash_commands
),
message,
)
self.send_help_message(message)
Loading

0 comments on commit f96b202

Please sign in to comment.