Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve support for custom providers #713

Merged
merged 5 commits into from
Apr 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@
from .exception import store_exception
from .magics import AiMagics

# expose JupyternautPersona on the package root
# required by `jupyter-ai`.
from .models.persona import JupyternautPersona, Persona

# expose model providers on the package root
from .providers import (
AI21Provider,
Expand Down
26 changes: 26 additions & 0 deletions packages/jupyter-ai-magics/jupyter_ai_magics/models/persona.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from langchain.pydantic_v1 import BaseModel


class Persona(BaseModel):
"""
Model of an **agent persona**, a struct that includes the name & avatar
shown on agent replies in the chat UI.

Each persona is specific to a single provider, set on the `persona` field.
"""

name: str = ...
"""
Name of the persona, e.g. "Jupyternaut". This is used to render the name
shown on agent replies in the chat UI.
"""

avatar_route: str = ...
"""
The server route that should be used the avatar of this persona. This is
used to render the avatar shown on agent replies in the chat UI.
"""


JUPYTERNAUT_AVATAR_ROUTE = "api/ai/static/jupyternaut.svg"
JupyternautPersona = Persona(name="Jupyternaut", avatar_route=JUPYTERNAUT_AVATAR_ROUTE)
25 changes: 25 additions & 0 deletions packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
except:
from pydantic.main import ModelMetaclass

from .models.persona import Persona

CHAT_SYSTEM_PROMPT = """
You are Jupyternaut, a conversational assistant living in JupyterLab to help users.
Expand Down Expand Up @@ -214,6 +215,30 @@ class Config:
"""User inputs expected by this provider when initializing it. Each `Field` `f`
should be passed in the constructor as a keyword argument, keyed by `f.key`."""

manages_history: ClassVar[bool] = False
"""Whether this provider manages its own conversation history upstream. If
set to `True`, Jupyter AI will not pass the chat history to this provider
when invoked."""

persona: ClassVar[Optional[Persona]] = None
"""
The **persona** of this provider, a struct that defines the name and avatar
shown on agent replies in the chat UI. When set to `None`, `jupyter-ai` will
choose a default persona when rendering agent messages by this provider.

Because this field is set to `None` by default, `jupyter-ai` will render a
default persona for all providers that are included natively with the
`jupyter-ai` package. This field is reserved for Jupyter AI modules that
dlqqq marked this conversation as resolved.
Show resolved Hide resolved
serve a custom provider and want to distinguish it in the chat UI.
"""

unsupported_slash_commands: ClassVar[set] = {}
"""
A set of slash commands unsupported by this provider. Unsupported slash
commands are not shown in the help message, and cannot be used while this
provider is selected.
"""

#
# instance attrs
#
Expand Down
21 changes: 20 additions & 1 deletion packages/jupyter-ai/jupyter_ai/chat_handlers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from dask.distributed import Client as DaskClient
from jupyter_ai.config_manager import ConfigManager, Logger
from jupyter_ai.models import AgentChatMessage, ChatMessage, HumanChatMessage
from jupyter_ai_magics import Persona
from jupyter_ai_magics.providers import BaseProvider
from langchain.pydantic_v1 import BaseModel

Expand Down Expand Up @@ -94,10 +95,21 @@ async def on_message(self, message: HumanChatMessage):
`self.handle_exc()` when an exception is raised. This method is called
by RootChatHandler when it routes a human message to this chat handler.
"""
lm_provider_klass = self.config_manager.lm_provider

# ensure the current slash command is supported
if self.routing_type.routing_method == "slash_command":
slash_command = (
"/" + self.routing_type.slash_id if self.routing_type.slash_id else ""
)
if slash_command in lm_provider_klass.unsupported_slash_commands:
self.reply(
"Sorry, the selected language model does not support this slash command."
)
return

# check whether the configured LLM can support a request at this time.
if self.uses_llm and BaseChatHandler._requests_count > 0:
lm_provider_klass = self.config_manager.lm_provider
lm_provider_params = self.config_manager.lm_provider_params
lm_provider = lm_provider_klass(**lm_provider_params)

Expand Down Expand Up @@ -159,11 +171,18 @@ async def _default_handle_exc(self, e: Exception, message: HumanChatMessage):
self.reply(response, message)

def reply(self, response: str, human_msg: Optional[HumanChatMessage] = None):
"""
Sends an agent message, usually in response to a received
`HumanChatMessage`.
"""
persona = self.config_manager.persona

agent_msg = AgentChatMessage(
id=uuid4().hex,
time=time.time(),
body=response,
reply_to=human_msg.id if human_msg else "",
persona=Persona(name=persona.name, avatar_route=persona.avatar_route),
)

for handler in self._root_chat_handlers.values():
Expand Down
14 changes: 9 additions & 5 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from jupyter_ai.models import ChatMessage, ClearMessage, HumanChatMessage
from jupyter_ai_magics.providers import BaseProvider
from langchain.chains import ConversationChain
from langchain.chains import ConversationChain, LLMChain
from langchain.memory import ConversationBufferWindowMemory

from .base import BaseChatHandler, SlashCommandRoutingType
Expand Down Expand Up @@ -30,14 +30,18 @@ def create_llm_chain(
llm = provider(**unified_parameters)

prompt_template = llm.get_chat_prompt_template()
self.llm = llm
self.memory = ConversationBufferWindowMemory(
return_messages=llm.is_chat_provider, k=2
)

self.llm = llm
self.llm_chain = ConversationChain(
llm=llm, prompt=prompt_template, verbose=True, memory=self.memory
)
if llm.manages_history:
self.llm_chain = LLMChain(llm=llm, prompt=prompt_template, verbose=True)

else:
self.llm_chain = ConversationChain(
llm=llm, prompt=prompt_template, verbose=True, memory=self.memory
)

def clear_memory(self):
# clear chain memory
Expand Down
24 changes: 19 additions & 5 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/help.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
from uuid import uuid4

from jupyter_ai.models import AgentChatMessage, HumanChatMessage
from jupyter_ai_magics import Persona

from .base import BaseChatHandler, SlashCommandRoutingType

HELP_MESSAGE = """Hi there! I'm Jupyternaut, your programming assistant.
HELP_MESSAGE = """Hi there! I'm {persona_name}, your programming assistant.
dlqqq marked this conversation as resolved.
Show resolved Hide resolved
You can ask me a question using the text box below. You can also use these commands:
{commands}

Expand All @@ -15,23 +16,36 @@
"""


def _format_help_message(chat_handlers: Dict[str, BaseChatHandler]):
def _format_help_message(
chat_handlers: Dict[str, BaseChatHandler],
persona: Persona,
unsupported_slash_commands: set,
):
if unsupported_slash_commands:
keys = set(chat_handlers.keys()) - unsupported_slash_commands
chat_handlers = {key: chat_handlers[key] for key in keys}

commands = "\n".join(
[
f"* `{command_name}` — {handler.help}"
for command_name, handler in chat_handlers.items()
if command_name != "default"
]
)
return HELP_MESSAGE.format(commands=commands)
return HELP_MESSAGE.format(commands=commands, persona_name=persona.name)


def HelpMessage(chat_handlers: Dict[str, BaseChatHandler]):
def build_help_message(
chat_handlers: Dict[str, BaseChatHandler],
persona: Persona,
unsupported_slash_commands: set,
):
return AgentChatMessage(
id=uuid4().hex,
time=time.time(),
body=_format_help_message(chat_handlers),
body=_format_help_message(chat_handlers, persona, unsupported_slash_commands),
reply_to="",
persona=Persona(name=persona.name, avatar_route=persona.avatar_route),
)


Expand Down
12 changes: 12 additions & 0 deletions packages/jupyter-ai/jupyter_ai/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from deepmerge import always_merger as Merger
from jsonschema import Draft202012Validator as Validator
from jupyter_ai.models import DescribeConfigResponse, GlobalConfig, UpdateConfigRequest
from jupyter_ai_magics import JupyternautPersona, Persona
from jupyter_ai_magics.utils import (
AnyProvider,
EmProvidersDict,
Expand Down Expand Up @@ -452,3 +453,14 @@ def em_provider_params(self):
"model_id": em_lid,
**authn_fields,
}

@property
def persona(self) -> Persona:
"""
The current agent persona, set by the selected LM provider. If the
selected LM provider is `None`, this property returns
`JupyternautPersona` by default.
"""
lm_provider = self.lm_provider
persona = getattr(lm_provider, "persona", None) or JupyternautPersona
return persona
47 changes: 42 additions & 5 deletions packages/jupyter-ai/jupyter_ai/extension.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import logging
import os
import re
import time

from dask.distributed import Client as DaskClient
from importlib_metadata import entry_points
from jupyter_ai.chat_handlers.learn import Retriever
from jupyter_ai_magics import JupyternautPersona
from jupyter_ai_magics.utils import get_em_providers, get_lm_providers
from jupyter_server.extension.application import ExtensionApp
from tornado.web import StaticFileHandler
from traitlets import Dict, List, Unicode

from .chat_handlers import (
Expand All @@ -18,7 +20,7 @@
HelpChatHandler,
LearnChatHandler,
)
from .chat_handlers.help import HelpMessage
from .chat_handlers.help import build_help_message
from .completions.handlers import DefaultInlineCompletionHandler
from .config_manager import ConfigManager
from .handlers import (
Expand All @@ -30,6 +32,11 @@
RootChatHandler,
)

JUPYTERNAUT_AVATAR_ROUTE = JupyternautPersona.avatar_route
JUPYTERNAUT_AVATAR_PATH = str(
os.path.join(os.path.dirname(__file__), "static", "jupyternaut.svg")
)


class AiExtension(ExtensionApp):
name = "jupyter_ai"
Expand All @@ -41,6 +48,14 @@ class AiExtension(ExtensionApp):
(r"api/ai/providers?", ModelProviderHandler),
(r"api/ai/providers/embeddings?", EmbeddingsModelProviderHandler),
(r"api/ai/completion/inline/?", DefaultInlineCompletionHandler),
# serve the default persona avatar at this path.
# the `()` at the end of the URL denotes an empty regex capture group,
# required by Tornado.
(
rf"{JUPYTERNAUT_AVATAR_ROUTE}()",
StaticFileHandler,
{"path": JUPYTERNAUT_AVATAR_PATH},
),
]

allowed_providers = List(
Expand Down Expand Up @@ -303,14 +318,36 @@ def initialize_settings(self):
# Make help always appear as the last command
jai_chat_handlers["/help"] = help_chat_handler

self.settings["chat_history"].append(
HelpMessage(chat_handlers=jai_chat_handlers)
)
# bind chat handlers to settings
self.settings["jai_chat_handlers"] = jai_chat_handlers

# show help message at server start
self._show_help_message()

latency_ms = round((time.time() - start) * 1000)
self.log.info(f"Initialized Jupyter AI server extension in {latency_ms} ms.")

def _show_help_message(self):
"""
Method that ensures a dynamically-generated help message is included in
the chat history shown to users.
"""
chat_handlers = self.settings["jai_chat_handlers"]
config_manager: ConfigManager = self.settings["jai_config_manager"]
lm_provider = config_manager.lm_provider

if not lm_provider:
return

persona = config_manager.persona
unsupported_slash_commands = (
lm_provider.unsupported_slash_commands if lm_provider else set()
)
help_message = build_help_message(
chat_handlers, persona, unsupported_slash_commands
)
self.settings["chat_history"].append(help_message)

async def _get_dask_client(self):
return DaskClient(processes=False, asynchronous=True)

Expand Down
13 changes: 12 additions & 1 deletion packages/jupyter-ai/jupyter_ai/models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Any, Dict, List, Literal, Optional, Union

from jupyter_ai_magics import Persona
from jupyter_ai_magics.providers import AuthStrategy, Field
from langchain.pydantic_v1 import BaseModel, validator

Expand Down Expand Up @@ -34,8 +35,18 @@ class AgentChatMessage(BaseModel):
id: str
time: float
body: str
# message ID of the HumanChatMessage it is replying to

reply_to: str
"""
Message ID of the HumanChatMessage being replied to. This is set to an empty
string if not applicable.
"""

persona: Persona
"""
The persona of the selected provider. If the selected provider is `None`,
this defaults to a description of `JupyternautPersona`.
"""


class HumanChatMessage(BaseModel):
Expand Down
9 changes: 9 additions & 0 deletions packages/jupyter-ai/jupyter_ai/static/jupyternaut.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Loading