diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py b/packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py
index f81ea11c8..7c609a606 100644
--- a/packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py
+++ b/packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py
@@ -12,6 +12,10 @@
from .exception import store_exception
from .magics import AiMagics
+# expose JupyternautPersona on the package root
+# required by `jupyter-ai`.
+from .models.persona import JupyternautPersona, Persona
+
# expose model providers on the package root
from .providers import (
AI21Provider,
diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/models/persona.py b/packages/jupyter-ai-magics/jupyter_ai_magics/models/persona.py
new file mode 100644
index 000000000..fe25397b0
--- /dev/null
+++ b/packages/jupyter-ai-magics/jupyter_ai_magics/models/persona.py
@@ -0,0 +1,26 @@
+from langchain.pydantic_v1 import BaseModel
+
+
+class Persona(BaseModel):
+ """
+ Model of an **agent persona**, a struct that includes the name & avatar
+ shown on agent replies in the chat UI.
+
+ Each persona is specific to a single provider, set on the `persona` field.
+ """
+
+ name: str = ...
+ """
+ Name of the persona, e.g. "Jupyternaut". This is used to render the name
+ shown on agent replies in the chat UI.
+ """
+
+ avatar_route: str = ...
+ """
+ The server route that should be used the avatar of this persona. This is
+ used to render the avatar shown on agent replies in the chat UI.
+ """
+
+
+JUPYTERNAUT_AVATAR_ROUTE = "api/ai/static/jupyternaut.svg"
+JupyternautPersona = Persona(name="Jupyternaut", avatar_route=JUPYTERNAUT_AVATAR_ROUTE)
diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
index 3fcdf9abc..c07061b1d 100644
--- a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
+++ b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
@@ -46,6 +46,7 @@
except:
from pydantic.main import ModelMetaclass
+from .models.persona import Persona
CHAT_SYSTEM_PROMPT = """
You are Jupyternaut, a conversational assistant living in JupyterLab to help users.
@@ -214,6 +215,30 @@ class Config:
"""User inputs expected by this provider when initializing it. Each `Field` `f`
should be passed in the constructor as a keyword argument, keyed by `f.key`."""
+ manages_history: ClassVar[bool] = False
+ """Whether this provider manages its own conversation history upstream. If
+ set to `True`, Jupyter AI will not pass the chat history to this provider
+ when invoked."""
+
+ persona: ClassVar[Optional[Persona]] = None
+ """
+ The **persona** of this provider, a struct that defines the name and avatar
+ shown on agent replies in the chat UI. When set to `None`, `jupyter-ai` will
+ choose a default persona when rendering agent messages by this provider.
+
+ Because this field is set to `None` by default, `jupyter-ai` will render a
+ default persona for all providers that are included natively with the
+ `jupyter-ai` package. This field is reserved for Jupyter AI modules that
+ serve a custom provider and want to distinguish it in the chat UI.
+ """
+
+ unsupported_slash_commands: ClassVar[set] = {}
+ """
+ A set of slash commands unsupported by this provider. Unsupported slash
+ commands are not shown in the help message, and cannot be used while this
+ provider is selected.
+ """
+
#
# instance attrs
#
diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py
index bcbba00ba..1ae80c5c5 100644
--- a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py
+++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py
@@ -17,6 +17,7 @@
from dask.distributed import Client as DaskClient
from jupyter_ai.config_manager import ConfigManager, Logger
from jupyter_ai.models import AgentChatMessage, ChatMessage, HumanChatMessage
+from jupyter_ai_magics import Persona
from jupyter_ai_magics.providers import BaseProvider
from langchain.pydantic_v1 import BaseModel
@@ -94,10 +95,21 @@ async def on_message(self, message: HumanChatMessage):
`self.handle_exc()` when an exception is raised. This method is called
by RootChatHandler when it routes a human message to this chat handler.
"""
+ lm_provider_klass = self.config_manager.lm_provider
+
+ # ensure the current slash command is supported
+ if self.routing_type.routing_method == "slash_command":
+ slash_command = (
+ "/" + self.routing_type.slash_id if self.routing_type.slash_id else ""
+ )
+ if slash_command in lm_provider_klass.unsupported_slash_commands:
+ self.reply(
+ "Sorry, the selected language model does not support this slash command."
+ )
+ return
# check whether the configured LLM can support a request at this time.
if self.uses_llm and BaseChatHandler._requests_count > 0:
- lm_provider_klass = self.config_manager.lm_provider
lm_provider_params = self.config_manager.lm_provider_params
lm_provider = lm_provider_klass(**lm_provider_params)
@@ -159,11 +171,18 @@ async def _default_handle_exc(self, e: Exception, message: HumanChatMessage):
self.reply(response, message)
def reply(self, response: str, human_msg: Optional[HumanChatMessage] = None):
+ """
+ Sends an agent message, usually in response to a received
+ `HumanChatMessage`.
+ """
+ persona = self.config_manager.persona
+
agent_msg = AgentChatMessage(
id=uuid4().hex,
time=time.time(),
body=response,
reply_to=human_msg.id if human_msg else "",
+ persona=Persona(name=persona.name, avatar_route=persona.avatar_route),
)
for handler in self._root_chat_handlers.values():
diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py
index 8352a8f8d..df288d409 100644
--- a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py
+++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py
@@ -2,7 +2,7 @@
from jupyter_ai.models import ChatMessage, ClearMessage, HumanChatMessage
from jupyter_ai_magics.providers import BaseProvider
-from langchain.chains import ConversationChain
+from langchain.chains import ConversationChain, LLMChain
from langchain.memory import ConversationBufferWindowMemory
from .base import BaseChatHandler, SlashCommandRoutingType
@@ -30,14 +30,18 @@ def create_llm_chain(
llm = provider(**unified_parameters)
prompt_template = llm.get_chat_prompt_template()
+ self.llm = llm
self.memory = ConversationBufferWindowMemory(
return_messages=llm.is_chat_provider, k=2
)
- self.llm = llm
- self.llm_chain = ConversationChain(
- llm=llm, prompt=prompt_template, verbose=True, memory=self.memory
- )
+ if llm.manages_history:
+ self.llm_chain = LLMChain(llm=llm, prompt=prompt_template, verbose=True)
+
+ else:
+ self.llm_chain = ConversationChain(
+ llm=llm, prompt=prompt_template, verbose=True, memory=self.memory
+ )
def clear_memory(self):
# clear chain memory
diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/help.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/help.py
index ebb8f0383..e46038da5 100644
--- a/packages/jupyter-ai/jupyter_ai/chat_handlers/help.py
+++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/help.py
@@ -3,10 +3,11 @@
from uuid import uuid4
from jupyter_ai.models import AgentChatMessage, HumanChatMessage
+from jupyter_ai_magics import Persona
from .base import BaseChatHandler, SlashCommandRoutingType
-HELP_MESSAGE = """Hi there! I'm Jupyternaut, your programming assistant.
+HELP_MESSAGE = """Hi there! I'm {persona_name}, your programming assistant.
You can ask me a question using the text box below. You can also use these commands:
{commands}
@@ -15,7 +16,15 @@
"""
-def _format_help_message(chat_handlers: Dict[str, BaseChatHandler]):
+def _format_help_message(
+ chat_handlers: Dict[str, BaseChatHandler],
+ persona: Persona,
+ unsupported_slash_commands: set,
+):
+ if unsupported_slash_commands:
+ keys = set(chat_handlers.keys()) - unsupported_slash_commands
+ chat_handlers = {key: chat_handlers[key] for key in keys}
+
commands = "\n".join(
[
f"* `{command_name}` — {handler.help}"
@@ -23,15 +32,20 @@ def _format_help_message(chat_handlers: Dict[str, BaseChatHandler]):
if command_name != "default"
]
)
- return HELP_MESSAGE.format(commands=commands)
+ return HELP_MESSAGE.format(commands=commands, persona_name=persona.name)
-def HelpMessage(chat_handlers: Dict[str, BaseChatHandler]):
+def build_help_message(
+ chat_handlers: Dict[str, BaseChatHandler],
+ persona: Persona,
+ unsupported_slash_commands: set,
+):
return AgentChatMessage(
id=uuid4().hex,
time=time.time(),
- body=_format_help_message(chat_handlers),
+ body=_format_help_message(chat_handlers, persona, unsupported_slash_commands),
reply_to="",
+ persona=Persona(name=persona.name, avatar_route=persona.avatar_route),
)
diff --git a/packages/jupyter-ai/jupyter_ai/config_manager.py b/packages/jupyter-ai/jupyter_ai/config_manager.py
index 01d3fe766..392e44601 100644
--- a/packages/jupyter-ai/jupyter_ai/config_manager.py
+++ b/packages/jupyter-ai/jupyter_ai/config_manager.py
@@ -8,6 +8,7 @@
from deepmerge import always_merger as Merger
from jsonschema import Draft202012Validator as Validator
from jupyter_ai.models import DescribeConfigResponse, GlobalConfig, UpdateConfigRequest
+from jupyter_ai_magics import JupyternautPersona, Persona
from jupyter_ai_magics.utils import (
AnyProvider,
EmProvidersDict,
@@ -452,3 +453,14 @@ def em_provider_params(self):
"model_id": em_lid,
**authn_fields,
}
+
+ @property
+ def persona(self) -> Persona:
+ """
+ The current agent persona, set by the selected LM provider. If the
+ selected LM provider is `None`, this property returns
+ `JupyternautPersona` by default.
+ """
+ lm_provider = self.lm_provider
+ persona = getattr(lm_provider, "persona", None) or JupyternautPersona
+ return persona
diff --git a/packages/jupyter-ai/jupyter_ai/extension.py b/packages/jupyter-ai/jupyter_ai/extension.py
index 245a1c957..0a66a8b1b 100644
--- a/packages/jupyter-ai/jupyter_ai/extension.py
+++ b/packages/jupyter-ai/jupyter_ai/extension.py
@@ -1,12 +1,14 @@
-import logging
+import os
import re
import time
from dask.distributed import Client as DaskClient
from importlib_metadata import entry_points
from jupyter_ai.chat_handlers.learn import Retriever
+from jupyter_ai_magics import JupyternautPersona
from jupyter_ai_magics.utils import get_em_providers, get_lm_providers
from jupyter_server.extension.application import ExtensionApp
+from tornado.web import StaticFileHandler
from traitlets import Dict, List, Unicode
from .chat_handlers import (
@@ -18,7 +20,7 @@
HelpChatHandler,
LearnChatHandler,
)
-from .chat_handlers.help import HelpMessage
+from .chat_handlers.help import build_help_message
from .completions.handlers import DefaultInlineCompletionHandler
from .config_manager import ConfigManager
from .handlers import (
@@ -30,6 +32,11 @@
RootChatHandler,
)
+JUPYTERNAUT_AVATAR_ROUTE = JupyternautPersona.avatar_route
+JUPYTERNAUT_AVATAR_PATH = str(
+ os.path.join(os.path.dirname(__file__), "static", "jupyternaut.svg")
+)
+
class AiExtension(ExtensionApp):
name = "jupyter_ai"
@@ -41,6 +48,14 @@ class AiExtension(ExtensionApp):
(r"api/ai/providers?", ModelProviderHandler),
(r"api/ai/providers/embeddings?", EmbeddingsModelProviderHandler),
(r"api/ai/completion/inline/?", DefaultInlineCompletionHandler),
+ # serve the default persona avatar at this path.
+ # the `()` at the end of the URL denotes an empty regex capture group,
+ # required by Tornado.
+ (
+ rf"{JUPYTERNAUT_AVATAR_ROUTE}()",
+ StaticFileHandler,
+ {"path": JUPYTERNAUT_AVATAR_PATH},
+ ),
]
allowed_providers = List(
@@ -303,14 +318,36 @@ def initialize_settings(self):
# Make help always appear as the last command
jai_chat_handlers["/help"] = help_chat_handler
- self.settings["chat_history"].append(
- HelpMessage(chat_handlers=jai_chat_handlers)
- )
+ # bind chat handlers to settings
self.settings["jai_chat_handlers"] = jai_chat_handlers
+ # show help message at server start
+ self._show_help_message()
+
latency_ms = round((time.time() - start) * 1000)
self.log.info(f"Initialized Jupyter AI server extension in {latency_ms} ms.")
+ def _show_help_message(self):
+ """
+ Method that ensures a dynamically-generated help message is included in
+ the chat history shown to users.
+ """
+ chat_handlers = self.settings["jai_chat_handlers"]
+ config_manager: ConfigManager = self.settings["jai_config_manager"]
+ lm_provider = config_manager.lm_provider
+
+ if not lm_provider:
+ return
+
+ persona = config_manager.persona
+ unsupported_slash_commands = (
+ lm_provider.unsupported_slash_commands if lm_provider else set()
+ )
+ help_message = build_help_message(
+ chat_handlers, persona, unsupported_slash_commands
+ )
+ self.settings["chat_history"].append(help_message)
+
async def _get_dask_client(self):
return DaskClient(processes=False, asynchronous=True)
diff --git a/packages/jupyter-ai/jupyter_ai/models.py b/packages/jupyter-ai/jupyter_ai/models.py
index 41509a74d..32353a694 100644
--- a/packages/jupyter-ai/jupyter_ai/models.py
+++ b/packages/jupyter-ai/jupyter_ai/models.py
@@ -1,5 +1,6 @@
from typing import Any, Dict, List, Literal, Optional, Union
+from jupyter_ai_magics import Persona
from jupyter_ai_magics.providers import AuthStrategy, Field
from langchain.pydantic_v1 import BaseModel, validator
@@ -34,8 +35,18 @@ class AgentChatMessage(BaseModel):
id: str
time: float
body: str
- # message ID of the HumanChatMessage it is replying to
+
reply_to: str
+ """
+ Message ID of the HumanChatMessage being replied to. This is set to an empty
+ string if not applicable.
+ """
+
+ persona: Persona
+ """
+ The persona of the selected provider. If the selected provider is `None`,
+ this defaults to a description of `JupyternautPersona`.
+ """
class HumanChatMessage(BaseModel):
diff --git a/packages/jupyter-ai/jupyter_ai/static/jupyternaut.svg b/packages/jupyter-ai/jupyter_ai/static/jupyternaut.svg
new file mode 100644
index 000000000..dd800d538
--- /dev/null
+++ b/packages/jupyter-ai/jupyter_ai/static/jupyternaut.svg
@@ -0,0 +1,9 @@
+
diff --git a/packages/jupyter-ai/src/components/chat-messages.tsx b/packages/jupyter-ai/src/components/chat-messages.tsx
index dd889cf78..8a2e4b658 100644
--- a/packages/jupyter-ai/src/components/chat-messages.tsx
+++ b/packages/jupyter-ai/src/components/chat-messages.tsx
@@ -2,10 +2,11 @@ import React, { useState, useEffect } from 'react';
import { Avatar, Box, Typography } from '@mui/material';
import type { SxProps, Theme } from '@mui/material';
+import { IRenderMimeRegistry } from '@jupyterlab/rendermime';
+import { ServerConnection } from '@jupyterlab/services';
+// TODO: delete jupyternaut from frontend package
import { AiService } from '../handler';
-import { IRenderMimeRegistry } from '@jupyterlab/rendermime';
-import { Jupyternaut } from '../icons';
import { RendermimeMarkdown } from './rendermime-markdown';
import { useCollaboratorsContext } from '../contexts/collaborators-context';
@@ -49,9 +50,11 @@ export function ChatMessageHeader(props: ChatMessageHeaderProps): JSX.Element {
);
} else {
+ const baseUrl = ServerConnection.makeSettings().baseUrl;
+ const avatar_url = baseUrl + props.message.persona.avatar_route;
avatar = (
-
-
+
+
);
}
@@ -59,7 +62,7 @@ export function ChatMessageHeader(props: ChatMessageHeaderProps): JSX.Element {
const name =
props.message.type === 'human'
? props.message.client.display_name
- : 'Jupyternaut';
+ : props.message.persona.name;
return (