From d0822d2d89ec4d4385254cf922cd1715f88d0763 Mon Sep 17 00:00:00 2001 From: Jason Weill <93281816+JasonWeill@users.noreply.github.com> Date: Thu, 7 Dec 2023 09:59:12 -0800 Subject: [PATCH] Base chat handler refactor for custom slash commands (#398) * Adds attributes, starts adding to subclasses * Consistent syntax * Help for all handlers * Fix slash ID error * Iterate through entry points * Fix typo in call to select() * Moves config to magics, modifies extensions to attempt to load classes * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Moves config to proper location, improves error logging * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * WIP: Updates per feedback, adds custom handler * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Removes redundant code, style fixes * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Removes unnecessary custom message * Instantiates class * Validates slash ID * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Consistent arguments to chat handlers * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Refactors to avoid intentionally unused params * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Updates docs, removes custom handler from source and config * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Renames process_message to match base class * Adds needed parameter that had been deleted * Joins lines in contributor doc * Removes natural language routing type, which is not yet used * Update docs/source/developers/index.md Co-authored-by: Piyush Jain * Update docs/source/developers/index.md Co-authored-by: Piyush Jain * Update docs/source/developers/index.md Co-authored-by: Piyush Jain * Revises per @3coins, avoids Latinism * Removes Configurable, since we do not yet have configurable traits * Uses Literal for validation * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Piyush Jain --- docs/source/developers/index.md | 42 ++++++++++ .../jupyter_ai/chat_handlers/__init__.py | 2 +- .../jupyter_ai/chat_handlers/ask.py | 7 +- .../jupyter_ai/chat_handlers/base.py | 54 +++++++++++- .../jupyter_ai/chat_handlers/clear.py | 10 ++- .../jupyter_ai/chat_handlers/default.py | 14 ++-- .../jupyter_ai/chat_handlers/generate.py | 10 ++- .../jupyter_ai/chat_handlers/help.py | 7 +- .../jupyter_ai/chat_handlers/learn.py | 13 +-- packages/jupyter-ai/jupyter_ai/extension.py | 84 +++++++++++++++---- packages/jupyter-ai/jupyter_ai/handlers.py | 2 +- 11 files changed, 201 insertions(+), 44 deletions(-) diff --git a/docs/source/developers/index.md b/docs/source/developers/index.md index 12e714048..ba3c969d7 100644 --- a/docs/source/developers/index.md +++ b/docs/source/developers/index.md @@ -120,3 +120,45 @@ class MyProvider(BaseProvider, FakeListLLM): ``` Please note that this will only work with Jupyter AI magics (the `%ai` and `%%ai` magic commands). Custom prompt templates are not used in the chat interface yet. + +## Custom slash commands in the chat UI + +You can add a custom slash command to the chat interface by +creating a new class that inherits from `BaseChatHandler`. Set +its `id`, `name`, `help` message for display in the user interface, +and `routing_type`. Each custom slash command must have a unique +slash command. Slash commands can only contain ASCII letters, numerals, +and underscores. Each slash command must be unique; custom slash +commands cannot replace built-in slash commands. + +Add your custom handler in Python code: + +```python +from jupyter_ai.chat_handlers.base import BaseChatHandler, SlashCommandRoutingType +from jupyter_ai.models import HumanChatMessage + +class CustomChatHandler(BaseChatHandler): + id = "custom" + name = "Custom" + help = "A chat handler that does something custom" + routing_type = SlashCommandRoutingType(slash_id="custom") + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + async def process_message(self, message: HumanChatMessage): + # Put your custom logic here + self.reply("", message) +``` + +Jupyter AI uses entry points to support custom slash commands. +In the `pyproject.toml` file, add your custom handler to the +`[project.entry-points."jupyter_ai.chat_handlers"]` section: + +``` +[project.entry-points."jupyter_ai.chat_handlers"] +custom = "custom_package:CustomChatHandler" +``` + +Then, install your package so that Jupyter AI adds custom chat handlers +to the existing chat handlers. diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/__init__.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/__init__.py index c3c64b789..e4c69f012 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/__init__.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/__init__.py @@ -1,5 +1,5 @@ from .ask import AskChatHandler -from .base import BaseChatHandler +from .base import BaseChatHandler, SlashCommandRoutingType from .clear import ClearChatHandler from .default import DefaultChatHandler from .generate import GenerateChatHandler diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py index e5c852051..bfb55ce21 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py @@ -7,7 +7,7 @@ from langchain.memory import ConversationBufferWindowMemory from langchain.prompts import PromptTemplate -from .base import BaseChatHandler +from .base import BaseChatHandler, SlashCommandRoutingType PROMPT_TEMPLATE = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question. @@ -26,6 +26,11 @@ class AskChatHandler(BaseChatHandler): to the LLM to generate the final reply. """ + id = "ask" + name = "Ask with Local Data" + help = "Asks a question with retrieval augmented generation (RAG)" + routing_type = SlashCommandRoutingType(slash_id="ask") + def __init__(self, retriever, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py index c7ca70f97..15aef6788 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py @@ -1,35 +1,81 @@ import argparse +import os import time import traceback - -# necessary to prevent circular import -from typing import TYPE_CHECKING, Any, Dict, Optional, Type +from typing import ( + TYPE_CHECKING, + Awaitable, + ClassVar, + Dict, + List, + Literal, + Optional, + Type, +) from uuid import uuid4 +from dask.distributed import Client as DaskClient from jupyter_ai.config_manager import ConfigManager, Logger -from jupyter_ai.models import AgentChatMessage, HumanChatMessage +from jupyter_ai.models import AgentChatMessage, ChatMessage, HumanChatMessage from jupyter_ai_magics.providers import BaseProvider +# necessary to prevent circular import +from pydantic import BaseModel + if TYPE_CHECKING: from jupyter_ai.handlers import RootChatHandler +# Chat handler type, with specific attributes for each +class HandlerRoutingType(BaseModel): + routing_method: ClassVar[str] = Literal["slash_command"] + """The routing method that sends commands to this handler.""" + + +class SlashCommandRoutingType(HandlerRoutingType): + routing_method = "slash_command" + + slash_id: Optional[str] + """Slash ID for routing a chat command to this handler. Only one handler + may declare a particular slash ID. Must contain only alphanumerics and + underscores.""" + + class BaseChatHandler: """Base ChatHandler class containing shared methods and attributes used by multiple chat handler classes.""" + # Class attributes + id: ClassVar[str] = ... + """ID for this chat handler; should be unique""" + + name: ClassVar[str] = ... + """User-facing name of this handler""" + + help: ClassVar[str] = ... + """What this chat handler does, which third-party models it contacts, + the data it returns to the user, and so on, for display in the UI.""" + + routing_type: HandlerRoutingType = ... + def __init__( self, log: Logger, config_manager: ConfigManager, root_chat_handlers: Dict[str, "RootChatHandler"], model_parameters: Dict[str, Dict], + chat_history: List[ChatMessage], + root_dir: str, + dask_client_future: Awaitable[DaskClient], ): self.log = log self.config_manager = config_manager self._root_chat_handlers = root_chat_handlers self.model_parameters = model_parameters + self._chat_history = chat_history self.parser = argparse.ArgumentParser() + self.root_dir = os.path.abspath(os.path.expanduser(root_dir)) + self.dask_client_future = dask_client_future self.llm = None self.llm_params = None self.llm_chain = None diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/clear.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/clear.py index a2a39bb00..7042c4632 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/clear.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/clear.py @@ -2,13 +2,17 @@ from jupyter_ai.models import ChatMessage, ClearMessage -from .base import BaseChatHandler +from .base import BaseChatHandler, SlashCommandRoutingType class ClearChatHandler(BaseChatHandler): - def __init__(self, chat_history: List[ChatMessage], *args, **kwargs): + id = "clear" + name = "Clear chat messages" + help = "Clears the displayed chat message history only; does not clear the context sent to chat providers" + routing_type = SlashCommandRoutingType(slash_id="clear") + + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self._chat_history = chat_history async def process_message(self, _): self._chat_history.clear() diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py index d329e05e2..5bd839ca5 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py @@ -12,7 +12,7 @@ SystemMessagePromptTemplate, ) -from .base import BaseChatHandler +from .base import BaseChatHandler, SlashCommandRoutingType SYSTEM_PROMPT = """ You are Jupyternaut, a conversational assistant living in JupyterLab to help users. @@ -32,10 +32,14 @@ class DefaultChatHandler(BaseChatHandler): - def __init__(self, chat_history: List[ChatMessage], *args, **kwargs): + id = "default" + name = "Default" + help = "Responds to prompts that are not otherwise handled by a chat handler" + routing_type = SlashCommandRoutingType(slash_id=None) + + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.memory = ConversationBufferWindowMemory(return_messages=True, k=2) - self.chat_history = chat_history def create_llm_chain( self, provider: Type[BaseProvider], provider_params: Dict[str, str] @@ -80,8 +84,8 @@ def clear_memory(self): self.reply(reply_message) # clear transcript for new chat clients - if self.chat_history: - self.chat_history.clear() + if self._chat_history: + self._chat_history.clear() async def process_message(self, message: HumanChatMessage): self.get_llm_chain() diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py index ca18becc2..b3d5212ae 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py @@ -6,7 +6,7 @@ from typing import Dict, List, Optional, Type import nbformat -from jupyter_ai.chat_handlers import BaseChatHandler +from jupyter_ai.chat_handlers import BaseChatHandler, SlashCommandRoutingType from jupyter_ai.models import HumanChatMessage from jupyter_ai_magics.providers import BaseProvider from langchain.chains import LLMChain @@ -216,11 +216,13 @@ def create_notebook(outline): class GenerateChatHandler(BaseChatHandler): - """Generates a Jupyter notebook given a description.""" + id = "generate" + name = "Generate Notebook" + help = "Generates a Jupyter notebook, including name, outline, and section contents" + routing_type = SlashCommandRoutingType(slash_id="generate") - def __init__(self, root_dir: str, *args, **kwargs): + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.root_dir = os.path.abspath(os.path.expanduser(root_dir)) self.llm = None def create_llm_chain( diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/help.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/help.py index be89d1165..cbf4c19c9 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/help.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/help.py @@ -4,7 +4,7 @@ from jupyter_ai.models import AgentChatMessage, HumanChatMessage -from .base import BaseChatHandler +from .base import BaseChatHandler, SlashCommandRoutingType HELP_MESSAGE = """Hi there! I'm Jupyternaut, your programming assistant. You can ask me a question using the text box below. You can also use these commands: @@ -29,6 +29,11 @@ def HelpMessage(): class HelpChatHandler(BaseChatHandler): + id = "help" + name = "Help" + help = "Displays a help message in the chat message area" + routing_type = SlashCommandRoutingType(slash_id="help") + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py index 40cae643b..825acf453 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py @@ -24,19 +24,20 @@ ) from langchain.vectorstores import FAISS -from .base import BaseChatHandler +from .base import BaseChatHandler, SlashCommandRoutingType INDEX_SAVE_DIR = os.path.join(jupyter_data_dir(), "jupyter_ai", "indices") METADATA_SAVE_PATH = os.path.join(INDEX_SAVE_DIR, "metadata.json") class LearnChatHandler(BaseChatHandler): - def __init__( - self, root_dir: str, dask_client_future: Awaitable[DaskClient], *args, **kwargs - ): + id = "learn" + name = "Learn Local Data" + help = "Pass a list of files and directories. Once converted to vector format, you can ask about them with /ask." + routing_type = SlashCommandRoutingType(slash_id="learn") + + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.root_dir = root_dir - self.dask_client_future = dask_client_future self.parser.prog = "/learn" self.parser.add_argument("-a", "--all-files", action="store_true") self.parser.add_argument("-v", "--verbose", action="store_true") diff --git a/packages/jupyter-ai/jupyter_ai/extension.py b/packages/jupyter-ai/jupyter_ai/extension.py index 8ab8c0cc6..ec08d9962 100644 --- a/packages/jupyter-ai/jupyter_ai/extension.py +++ b/packages/jupyter-ai/jupyter_ai/extension.py @@ -1,6 +1,9 @@ +import logging +import re import time from dask.distributed import Client as DaskClient +from importlib_metadata import entry_points from jupyter_ai.chat_handlers.learn import Retriever from jupyter_ai_magics.utils import get_em_providers, get_lm_providers from jupyter_server.extension.application import ExtensionApp @@ -40,7 +43,7 @@ class AiExtension(ExtensionApp): allowed_providers = List( Unicode(), default_value=None, - help="Identifiers of allow-listed providers. If `None`, all are allowed.", + help="Identifiers of allowlisted providers. If `None`, all are allowed.", allow_none=True, config=True, ) @@ -48,7 +51,7 @@ class AiExtension(ExtensionApp): blocked_providers = List( Unicode(), default_value=None, - help="Identifiers of block-listed providers. If `None`, none are blocked.", + help="Identifiers of blocklisted providers. If `None`, none are blocked.", allow_none=True, config=True, ) @@ -156,32 +159,29 @@ def initialize_settings(self): # consumers a Future that resolves to the Dask client when awaited. dask_client_future = loop.create_task(self._get_dask_client()) + eps = entry_points() # initialize chat handlers + chat_handler_eps = eps.select(group="jupyter_ai.chat_handlers") + chat_handler_kwargs = { "log": self.log, "config_manager": self.settings["jai_config_manager"], "root_chat_handlers": self.settings["jai_root_chat_handlers"], + "chat_history": self.settings["chat_history"], + "root_dir": self.serverapp.root_dir, + "dask_client_future": dask_client_future, "model_parameters": self.settings["model_parameters"], } - default_chat_handler = DefaultChatHandler( - **chat_handler_kwargs, chat_history=self.settings["chat_history"] - ) - clear_chat_handler = ClearChatHandler( - **chat_handler_kwargs, chat_history=self.settings["chat_history"] - ) - generate_chat_handler = GenerateChatHandler( - **chat_handler_kwargs, - root_dir=self.serverapp.root_dir, - ) - learn_chat_handler = LearnChatHandler( - **chat_handler_kwargs, - root_dir=self.serverapp.root_dir, - dask_client_future=dask_client_future, - ) + + default_chat_handler = DefaultChatHandler(**chat_handler_kwargs) + clear_chat_handler = ClearChatHandler(**chat_handler_kwargs) + generate_chat_handler = GenerateChatHandler(**chat_handler_kwargs) + learn_chat_handler = LearnChatHandler(**chat_handler_kwargs) help_chat_handler = HelpChatHandler(**chat_handler_kwargs) retriever = Retriever(learn_chat_handler=learn_chat_handler) ask_chat_handler = AskChatHandler(**chat_handler_kwargs, retriever=retriever) - self.settings["jai_chat_handlers"] = { + + jai_chat_handlers = { "default": default_chat_handler, "/ask": ask_chat_handler, "/clear": clear_chat_handler, @@ -190,6 +190,54 @@ def initialize_settings(self): "/help": help_chat_handler, } + slash_command_pattern = r"^[a-zA-Z0-9_]+$" + for chat_handler_ep in chat_handler_eps: + try: + chat_handler = chat_handler_ep.load() + except Exception as err: + self.log.error( + f"Unable to load chat handler class from entry point `{chat_handler_ep.name}`: " + + f"Unexpected {err=}, {type(err)=}" + ) + continue + + if chat_handler.routing_type.routing_method == "slash_command": + # Each slash ID must be used only once. + # Slash IDs may contain only alphanumerics and underscores. + slash_id = chat_handler.routing_type.slash_id + + if slash_id is None: + self.log.error( + f"Handler `{chat_handler_ep.name}` has an invalid slash command " + + f"`None`; only the default chat handler may use this" + ) + continue + + # Validate slash ID (/^[A-Za-z0-9_]+$/) + if re.match(slash_command_pattern, slash_id): + command_name = f"/{slash_id}" + else: + self.log.error( + f"Handler `{chat_handler_ep.name}` has an invalid slash command " + + f"`{slash_id}`; must contain only letters, numbers, " + + "and underscores" + ) + continue + + if command_name in jai_chat_handlers: + self.log.error( + f"Unable to register chat handler `{chat_handler.id}` because command `{command_name}` already has a handler" + ) + continue + + # The entry point is a class; we need to instantiate the class to send messages to it + jai_chat_handlers[command_name] = chat_handler(**chat_handler_kwargs) + self.log.info( + f"Registered chat handler `{chat_handler.id}` with command `{command_name}`." + ) + + self.settings["jai_chat_handlers"] = jai_chat_handlers + latency_ms = round((time.time() - start) * 1000) self.log.info(f"Initialized Jupyter AI server extension in {latency_ms} ms.") diff --git a/packages/jupyter-ai/jupyter_ai/handlers.py b/packages/jupyter-ai/jupyter_ai/handlers.py index 170ae4006..ae1498946 100644 --- a/packages/jupyter-ai/jupyter_ai/handlers.py +++ b/packages/jupyter-ai/jupyter_ai/handlers.py @@ -165,7 +165,7 @@ def open(self): def broadcast_message(self, message: Message): """Broadcasts message to all connected clients. - Appends message to `self.chat_history`. + Appends message to chat history. """ self.log.debug("Broadcasting message: %s to all clients...", message)