diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py b/packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py index f81ea11c8..7c609a606 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py @@ -12,6 +12,10 @@ from .exception import store_exception from .magics import AiMagics +# expose JupyternautPersona on the package root +# required by `jupyter-ai`. +from .models.persona import JupyternautPersona, Persona + # expose model providers on the package root from .providers import ( AI21Provider, diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/models/persona.py b/packages/jupyter-ai-magics/jupyter_ai_magics/models/persona.py new file mode 100644 index 000000000..fe25397b0 --- /dev/null +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/models/persona.py @@ -0,0 +1,26 @@ +from langchain.pydantic_v1 import BaseModel + + +class Persona(BaseModel): + """ + Model of an **agent persona**, a struct that includes the name & avatar + shown on agent replies in the chat UI. + + Each persona is specific to a single provider, set on the `persona` field. + """ + + name: str = ... + """ + Name of the persona, e.g. "Jupyternaut". This is used to render the name + shown on agent replies in the chat UI. + """ + + avatar_route: str = ... + """ + The server route that should be used the avatar of this persona. This is + used to render the avatar shown on agent replies in the chat UI. + """ + + +JUPYTERNAUT_AVATAR_ROUTE = "api/ai/static/jupyternaut.svg" +JupyternautPersona = Persona(name="Jupyternaut", avatar_route=JUPYTERNAUT_AVATAR_ROUTE) diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py index 3fcdf9abc..c07061b1d 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py @@ -46,6 +46,7 @@ except: from pydantic.main import ModelMetaclass +from .models.persona import Persona CHAT_SYSTEM_PROMPT = """ You are Jupyternaut, a conversational assistant living in JupyterLab to help users. @@ -214,6 +215,30 @@ class Config: """User inputs expected by this provider when initializing it. Each `Field` `f` should be passed in the constructor as a keyword argument, keyed by `f.key`.""" + manages_history: ClassVar[bool] = False + """Whether this provider manages its own conversation history upstream. If + set to `True`, Jupyter AI will not pass the chat history to this provider + when invoked.""" + + persona: ClassVar[Optional[Persona]] = None + """ + The **persona** of this provider, a struct that defines the name and avatar + shown on agent replies in the chat UI. When set to `None`, `jupyter-ai` will + choose a default persona when rendering agent messages by this provider. + + Because this field is set to `None` by default, `jupyter-ai` will render a + default persona for all providers that are included natively with the + `jupyter-ai` package. This field is reserved for Jupyter AI modules that + serve a custom provider and want to distinguish it in the chat UI. + """ + + unsupported_slash_commands: ClassVar[set] = {} + """ + A set of slash commands unsupported by this provider. Unsupported slash + commands are not shown in the help message, and cannot be used while this + provider is selected. + """ + # # instance attrs # diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py index bcbba00ba..1ae80c5c5 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py @@ -17,6 +17,7 @@ from dask.distributed import Client as DaskClient from jupyter_ai.config_manager import ConfigManager, Logger from jupyter_ai.models import AgentChatMessage, ChatMessage, HumanChatMessage +from jupyter_ai_magics import Persona from jupyter_ai_magics.providers import BaseProvider from langchain.pydantic_v1 import BaseModel @@ -94,10 +95,21 @@ async def on_message(self, message: HumanChatMessage): `self.handle_exc()` when an exception is raised. This method is called by RootChatHandler when it routes a human message to this chat handler. """ + lm_provider_klass = self.config_manager.lm_provider + + # ensure the current slash command is supported + if self.routing_type.routing_method == "slash_command": + slash_command = ( + "/" + self.routing_type.slash_id if self.routing_type.slash_id else "" + ) + if slash_command in lm_provider_klass.unsupported_slash_commands: + self.reply( + "Sorry, the selected language model does not support this slash command." + ) + return # check whether the configured LLM can support a request at this time. if self.uses_llm and BaseChatHandler._requests_count > 0: - lm_provider_klass = self.config_manager.lm_provider lm_provider_params = self.config_manager.lm_provider_params lm_provider = lm_provider_klass(**lm_provider_params) @@ -159,11 +171,18 @@ async def _default_handle_exc(self, e: Exception, message: HumanChatMessage): self.reply(response, message) def reply(self, response: str, human_msg: Optional[HumanChatMessage] = None): + """ + Sends an agent message, usually in response to a received + `HumanChatMessage`. + """ + persona = self.config_manager.persona + agent_msg = AgentChatMessage( id=uuid4().hex, time=time.time(), body=response, reply_to=human_msg.id if human_msg else "", + persona=Persona(name=persona.name, avatar_route=persona.avatar_route), ) for handler in self._root_chat_handlers.values(): diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py index 8352a8f8d..df288d409 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py @@ -2,7 +2,7 @@ from jupyter_ai.models import ChatMessage, ClearMessage, HumanChatMessage from jupyter_ai_magics.providers import BaseProvider -from langchain.chains import ConversationChain +from langchain.chains import ConversationChain, LLMChain from langchain.memory import ConversationBufferWindowMemory from .base import BaseChatHandler, SlashCommandRoutingType @@ -30,14 +30,18 @@ def create_llm_chain( llm = provider(**unified_parameters) prompt_template = llm.get_chat_prompt_template() + self.llm = llm self.memory = ConversationBufferWindowMemory( return_messages=llm.is_chat_provider, k=2 ) - self.llm = llm - self.llm_chain = ConversationChain( - llm=llm, prompt=prompt_template, verbose=True, memory=self.memory - ) + if llm.manages_history: + self.llm_chain = LLMChain(llm=llm, prompt=prompt_template, verbose=True) + + else: + self.llm_chain = ConversationChain( + llm=llm, prompt=prompt_template, verbose=True, memory=self.memory + ) def clear_memory(self): # clear chain memory diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/help.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/help.py index ebb8f0383..e46038da5 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/help.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/help.py @@ -3,10 +3,11 @@ from uuid import uuid4 from jupyter_ai.models import AgentChatMessage, HumanChatMessage +from jupyter_ai_magics import Persona from .base import BaseChatHandler, SlashCommandRoutingType -HELP_MESSAGE = """Hi there! I'm Jupyternaut, your programming assistant. +HELP_MESSAGE = """Hi there! I'm {persona_name}, your programming assistant. You can ask me a question using the text box below. You can also use these commands: {commands} @@ -15,7 +16,15 @@ """ -def _format_help_message(chat_handlers: Dict[str, BaseChatHandler]): +def _format_help_message( + chat_handlers: Dict[str, BaseChatHandler], + persona: Persona, + unsupported_slash_commands: set, +): + if unsupported_slash_commands: + keys = set(chat_handlers.keys()) - unsupported_slash_commands + chat_handlers = {key: chat_handlers[key] for key in keys} + commands = "\n".join( [ f"* `{command_name}` — {handler.help}" @@ -23,15 +32,20 @@ def _format_help_message(chat_handlers: Dict[str, BaseChatHandler]): if command_name != "default" ] ) - return HELP_MESSAGE.format(commands=commands) + return HELP_MESSAGE.format(commands=commands, persona_name=persona.name) -def HelpMessage(chat_handlers: Dict[str, BaseChatHandler]): +def build_help_message( + chat_handlers: Dict[str, BaseChatHandler], + persona: Persona, + unsupported_slash_commands: set, +): return AgentChatMessage( id=uuid4().hex, time=time.time(), - body=_format_help_message(chat_handlers), + body=_format_help_message(chat_handlers, persona, unsupported_slash_commands), reply_to="", + persona=Persona(name=persona.name, avatar_route=persona.avatar_route), ) diff --git a/packages/jupyter-ai/jupyter_ai/config_manager.py b/packages/jupyter-ai/jupyter_ai/config_manager.py index 01d3fe766..392e44601 100644 --- a/packages/jupyter-ai/jupyter_ai/config_manager.py +++ b/packages/jupyter-ai/jupyter_ai/config_manager.py @@ -8,6 +8,7 @@ from deepmerge import always_merger as Merger from jsonschema import Draft202012Validator as Validator from jupyter_ai.models import DescribeConfigResponse, GlobalConfig, UpdateConfigRequest +from jupyter_ai_magics import JupyternautPersona, Persona from jupyter_ai_magics.utils import ( AnyProvider, EmProvidersDict, @@ -452,3 +453,14 @@ def em_provider_params(self): "model_id": em_lid, **authn_fields, } + + @property + def persona(self) -> Persona: + """ + The current agent persona, set by the selected LM provider. If the + selected LM provider is `None`, this property returns + `JupyternautPersona` by default. + """ + lm_provider = self.lm_provider + persona = getattr(lm_provider, "persona", None) or JupyternautPersona + return persona diff --git a/packages/jupyter-ai/jupyter_ai/extension.py b/packages/jupyter-ai/jupyter_ai/extension.py index 245a1c957..0a66a8b1b 100644 --- a/packages/jupyter-ai/jupyter_ai/extension.py +++ b/packages/jupyter-ai/jupyter_ai/extension.py @@ -1,12 +1,14 @@ -import logging +import os import re import time from dask.distributed import Client as DaskClient from importlib_metadata import entry_points from jupyter_ai.chat_handlers.learn import Retriever +from jupyter_ai_magics import JupyternautPersona from jupyter_ai_magics.utils import get_em_providers, get_lm_providers from jupyter_server.extension.application import ExtensionApp +from tornado.web import StaticFileHandler from traitlets import Dict, List, Unicode from .chat_handlers import ( @@ -18,7 +20,7 @@ HelpChatHandler, LearnChatHandler, ) -from .chat_handlers.help import HelpMessage +from .chat_handlers.help import build_help_message from .completions.handlers import DefaultInlineCompletionHandler from .config_manager import ConfigManager from .handlers import ( @@ -30,6 +32,11 @@ RootChatHandler, ) +JUPYTERNAUT_AVATAR_ROUTE = JupyternautPersona.avatar_route +JUPYTERNAUT_AVATAR_PATH = str( + os.path.join(os.path.dirname(__file__), "static", "jupyternaut.svg") +) + class AiExtension(ExtensionApp): name = "jupyter_ai" @@ -41,6 +48,14 @@ class AiExtension(ExtensionApp): (r"api/ai/providers?", ModelProviderHandler), (r"api/ai/providers/embeddings?", EmbeddingsModelProviderHandler), (r"api/ai/completion/inline/?", DefaultInlineCompletionHandler), + # serve the default persona avatar at this path. + # the `()` at the end of the URL denotes an empty regex capture group, + # required by Tornado. + ( + rf"{JUPYTERNAUT_AVATAR_ROUTE}()", + StaticFileHandler, + {"path": JUPYTERNAUT_AVATAR_PATH}, + ), ] allowed_providers = List( @@ -303,14 +318,36 @@ def initialize_settings(self): # Make help always appear as the last command jai_chat_handlers["/help"] = help_chat_handler - self.settings["chat_history"].append( - HelpMessage(chat_handlers=jai_chat_handlers) - ) + # bind chat handlers to settings self.settings["jai_chat_handlers"] = jai_chat_handlers + # show help message at server start + self._show_help_message() + latency_ms = round((time.time() - start) * 1000) self.log.info(f"Initialized Jupyter AI server extension in {latency_ms} ms.") + def _show_help_message(self): + """ + Method that ensures a dynamically-generated help message is included in + the chat history shown to users. + """ + chat_handlers = self.settings["jai_chat_handlers"] + config_manager: ConfigManager = self.settings["jai_config_manager"] + lm_provider = config_manager.lm_provider + + if not lm_provider: + return + + persona = config_manager.persona + unsupported_slash_commands = ( + lm_provider.unsupported_slash_commands if lm_provider else set() + ) + help_message = build_help_message( + chat_handlers, persona, unsupported_slash_commands + ) + self.settings["chat_history"].append(help_message) + async def _get_dask_client(self): return DaskClient(processes=False, asynchronous=True) diff --git a/packages/jupyter-ai/jupyter_ai/models.py b/packages/jupyter-ai/jupyter_ai/models.py index 41509a74d..32353a694 100644 --- a/packages/jupyter-ai/jupyter_ai/models.py +++ b/packages/jupyter-ai/jupyter_ai/models.py @@ -1,5 +1,6 @@ from typing import Any, Dict, List, Literal, Optional, Union +from jupyter_ai_magics import Persona from jupyter_ai_magics.providers import AuthStrategy, Field from langchain.pydantic_v1 import BaseModel, validator @@ -34,8 +35,18 @@ class AgentChatMessage(BaseModel): id: str time: float body: str - # message ID of the HumanChatMessage it is replying to + reply_to: str + """ + Message ID of the HumanChatMessage being replied to. This is set to an empty + string if not applicable. + """ + + persona: Persona + """ + The persona of the selected provider. If the selected provider is `None`, + this defaults to a description of `JupyternautPersona`. + """ class HumanChatMessage(BaseModel): diff --git a/packages/jupyter-ai/jupyter_ai/static/jupyternaut.svg b/packages/jupyter-ai/jupyter_ai/static/jupyternaut.svg new file mode 100644 index 000000000..dd800d538 --- /dev/null +++ b/packages/jupyter-ai/jupyter_ai/static/jupyternaut.svg @@ -0,0 +1,9 @@ + + + + + + diff --git a/packages/jupyter-ai/src/components/chat-messages.tsx b/packages/jupyter-ai/src/components/chat-messages.tsx index dd889cf78..8a2e4b658 100644 --- a/packages/jupyter-ai/src/components/chat-messages.tsx +++ b/packages/jupyter-ai/src/components/chat-messages.tsx @@ -2,10 +2,11 @@ import React, { useState, useEffect } from 'react'; import { Avatar, Box, Typography } from '@mui/material'; import type { SxProps, Theme } from '@mui/material'; +import { IRenderMimeRegistry } from '@jupyterlab/rendermime'; +import { ServerConnection } from '@jupyterlab/services'; +// TODO: delete jupyternaut from frontend package import { AiService } from '../handler'; -import { IRenderMimeRegistry } from '@jupyterlab/rendermime'; -import { Jupyternaut } from '../icons'; import { RendermimeMarkdown } from './rendermime-markdown'; import { useCollaboratorsContext } from '../contexts/collaborators-context'; @@ -49,9 +50,11 @@ export function ChatMessageHeader(props: ChatMessageHeaderProps): JSX.Element { ); } else { + const baseUrl = ServerConnection.makeSettings().baseUrl; + const avatar_url = baseUrl + props.message.persona.avatar_route; avatar = ( - - + + ); } @@ -59,7 +62,7 @@ export function ChatMessageHeader(props: ChatMessageHeaderProps): JSX.Element { const name = props.message.type === 'human' ? props.message.client.display_name - : 'Jupyternaut'; + : props.message.persona.name; return (