Skip to content

Commit

Permalink
[v3-dev] Initial migration to jupyterlab-chat (jupyterlab#1043)
Browse files Browse the repository at this point in the history
* Very first version of the AI working in jupyterlab_collaborative_chat

* Allows both collaborative and regular chat to work with AI

* handle the help message in the chat too

* Autocompletion (#2)

* Fix handler methods' parameters

* Add slash commands (autocompletion) to the chat input

* Stream messages (jupyterlab#3)

* Allow for stream messages

* update jupyter collaborative chat dependency

* AI settings (jupyterlab#4)

* Add a menu option to open the AI settings

* Remove the input option from the setting widget

* pre-commit

* linting

* Homogenize typing for optional arguments

* Fix import

* Showing that the bot is writing (answering) (jupyterlab#5)

* Show that the bot is writing (answering)

* Update jupyter chat dependency

* Some typing

* Update extension to jupyterlab_chat (0.6.0) (jupyterlab#8)

* Fix linting

* Remove try/except to import jupyterlab_chat (not optional anymore), and fix typing

* linter

* Python unit tests

* Fix typing

* lint

* Fix lint and mypy all together

* Fix web_app settings accessor

* Fix jupyter_collaboration version

Co-authored-by: david qiu <[email protected]>

* Remove unnecessary try/except

* Dedicate one set of chat handlers per room (jupyterlab#9)

* create new set of chat handlers per room

* make YChat an instance attribute on BaseChatHandler

* revert changes to chat handlers

* pre-commit

* use room_id local var

Co-authored-by: Nicolas Brichet <[email protected]>

---------

Co-authored-by: Nicolas Brichet <[email protected]>

---------

Co-authored-by: david qiu <[email protected]>
Co-authored-by: david qiu <[email protected]>
  • Loading branch information
3 people authored Dec 6, 2024
1 parent a31ec1f commit cb04fa4
Show file tree
Hide file tree
Showing 16 changed files with 849 additions and 96 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from jupyter_ai.chat_handlers.base import BaseChatHandler, SlashCommandRoutingType
from jupyter_ai.models import HumanChatMessage
from jupyterlab_chat.ychat import YChat


class TestSlashCommand(BaseChatHandler):
Expand All @@ -25,5 +26,5 @@ class TestSlashCommand(BaseChatHandler):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

async def process_message(self, message: HumanChatMessage):
async def process_message(self, message: HumanChatMessage, chat: YChat):
self.reply("This is the `/test` slash command.")
5 changes: 5 additions & 0 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
# The following import is to make sure jupyter_ydoc is imported before
# jupyterlab_chat, otherwise it leads to circular import because of the
# YChat relying on YBaseDoc, and jupyter_ydoc registering YChat from the entry point.
import jupyter_ydoc

from .ask import AskChatHandler
from .base import BaseChatHandler, SlashCommandRoutingType
from .clear import ClearChatHandler
Expand Down
147 changes: 108 additions & 39 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from dask.distributed import Client as DaskClient
from jupyter_ai.callback_handlers import MetadataCallbackHandler
from jupyter_ai.config_manager import ConfigManager, Logger
from jupyter_ai.constants import BOT
from jupyter_ai.history import WrappedBoundedChatHistory
from jupyter_ai.models import (
AgentChatMessage,
Expand All @@ -36,6 +37,7 @@
)
from jupyter_ai_magics import Persona
from jupyter_ai_magics.providers import BaseProvider
from jupyterlab_chat.ychat import YChat
from langchain.pydantic_v1 import BaseModel
from langchain_core.messages import AIMessageChunk
from langchain_core.runnables import Runnable
Expand Down Expand Up @@ -156,6 +158,7 @@ def __init__(
chat_handlers: Dict[str, "BaseChatHandler"],
context_providers: Dict[str, "BaseCommandContextProvider"],
message_interrupted: Dict[str, asyncio.Event],
ychat: Optional[YChat],
):
self.log = log
self.config_manager = config_manager
Expand All @@ -178,6 +181,14 @@ def __init__(
self.chat_handlers = chat_handlers
self.context_providers = context_providers
self.message_interrupted = message_interrupted
self.ychat = ychat
self.indexes_by_id: Dict[str, str] = {}
"""
Indexes of messages in the YChat document by message ID.
TODO: Remove this once `jupyterlab-chat` can update messages by ID
without an index.
"""

self.llm: Optional[BaseProvider] = None
self.llm_params: Optional[dict] = None
Expand All @@ -198,7 +209,7 @@ async def on_message(self, message: HumanChatMessage):
slash_command = "/" + routing_type.slash_id if routing_type.slash_id else ""
if slash_command in lm_provider_klass.unsupported_slash_commands:
self.reply(
"Sorry, the selected language model does not support this slash command."
"Sorry, the selected language model does not support this slash command.",
)
return

Expand Down Expand Up @@ -238,7 +249,7 @@ async def process_message(self, message: HumanChatMessage):
"""
Processes a human message routed to this chat handler. Chat handlers
(subclasses) must implement this method. Don't forget to call
`self.reply(<response>, message)` at the end!
`self.reply(<response>, message)` at the end!
The method definition does not need to be wrapped in a try/except block;
any exceptions raised here are caught by `self.handle_exc()`.
Expand Down Expand Up @@ -271,11 +282,41 @@ async def _default_handle_exc(self, e: Exception, message: HumanChatMessage):
)
self.reply(response, message)

def write_message(self, body: str, id: Optional[str] = None) -> Optional[str]:
    """[Jupyter Chat only] Writes a message to the YChat shared document
    that this chat handler is assigned to.

    Args:
        body: Text passed as the message body.
        id: ID of a message previously written through this method. When
            provided, the index recorded for that ID in
            `self.indexes_by_id` is forwarded to `YChat.set_message()` so
            the existing message is targeted. When omitted, a new ID is
            generated.

    Returns:
        The ID of the message that was written, or `None` if this handler
        has no YChat document (`self.ychat` is falsy) and nothing was
        written.
    """
    if not self.ychat:
        return None

    # Make sure the bot user exists in the document before writing as it.
    if not self.ychat.get_user(BOT["username"]):
        self.ychat.set_user(BOT)

    # Only a caller-supplied ID can have a recorded index; without one this
    # is a brand-new message and no lookup is needed.
    index = self.indexes_by_id.get(id) if id else None
    message_id = id if id else str(uuid4())

    new_index = self.ychat.set_message(
        {
            "type": "msg",
            "body": body,
            "id": message_id,
            "time": time.time(),
            "sender": BOT["username"],
            "raw_time": False,
        },
        index=index,
        append=True,
    )

    # Remember where the message lives so later calls with the same ID
    # (e.g. stream chunks) can address it.
    self.indexes_by_id[message_id] = new_index
    return message_id

def broadcast_message(self, message: Message):
"""
Broadcasts a message to all WebSocket connections. If there are no
WebSocket connections and the message is a chat message, this method
directly appends to `self.chat_history`.
TODO: Remove this after Jupyter Chat migration is complete.
"""
broadcast = False
for websocket in self._root_chat_handlers.values():
Expand All @@ -291,20 +332,26 @@ def broadcast_message(self, message: Message):
cast(ChatMessage, message)
self._chat_history.append(message)

def reply(self, response: str, human_msg: Optional[HumanChatMessage] = None):
def reply(
self,
response: str,
human_msg: Optional[HumanChatMessage] = None,
):
"""
Sends an agent message, usually in response to a received
`HumanChatMessage`.
"""
agent_msg = AgentChatMessage(
id=uuid4().hex,
time=time.time(),
body=response,
reply_to=human_msg.id if human_msg else "",
persona=self.persona,
)

self.broadcast_message(agent_msg)
if self.ychat is not None:
self.write_message(response, None)
else:
agent_msg = AgentChatMessage(
id=uuid4().hex,
time=time.time(),
body=response,
reply_to=human_msg.id if human_msg else "",
persona=self.persona,
)
self.broadcast_message(agent_msg)

@property
def persona(self):
Expand Down Expand Up @@ -333,7 +380,10 @@ def start_pending(
ellipsis=ellipsis,
)

self.broadcast_message(pending_msg)
if self.ychat is not None and self.ychat.awareness is not None:
self.ychat.awareness.set_local_state_field("isWriting", True)
else:
self.broadcast_message(pending_msg)
return pending_msg

def close_pending(self, pending_msg: PendingMessage):
Expand All @@ -347,7 +397,10 @@ def close_pending(self, pending_msg: PendingMessage):
id=pending_msg.id,
)

self.broadcast_message(close_pending_msg)
if self.ychat is not None and self.ychat.awareness is not None:
self.ychat.awareness.set_local_state_field("isWriting", False)
else:
self.broadcast_message(close_pending_msg)
pending_msg.closed = True

@contextlib.contextmanager
Expand All @@ -361,6 +414,9 @@ def pending(
"""
Context manager that sends a pending message to the client, and closes
it after the block is executed.
TODO: Simplify it by only modifying the awareness as soon as jupyterlab chat
is the only used chat.
"""
pending_msg = self.start_pending(text, human_msg=human_msg, ellipsis=ellipsis)
try:
Expand Down Expand Up @@ -470,32 +526,39 @@ def send_help_message(self, human_msg: Optional[HumanChatMessage] = None) -> Non
slash_commands_list=slash_commands_list,
context_commands_list=context_commands_list,
)
help_message = AgentChatMessage(
id=uuid4().hex,
time=time.time(),
body=help_message_body,
reply_to=human_msg.id if human_msg else "",
persona=self.persona,
)

self.broadcast_message(help_message)
if self.ychat is not None:
self.write_message(help_message_body, None)
else:
help_message = AgentChatMessage(
id=uuid4().hex,
time=time.time(),
body=help_message_body,
reply_to=human_msg.id if human_msg else "",
persona=self.persona,
)
self.broadcast_message(help_message)

def _start_stream(self, human_msg: HumanChatMessage) -> str:
"""
Sends an `agent-stream` message to indicate the start of a response
stream. Returns the ID of the message, denoted as the `stream_id`.
"""
stream_id = uuid4().hex
stream_msg = AgentStreamMessage(
id=stream_id,
time=time.time(),
body="",
reply_to=human_msg.id,
persona=self.persona,
complete=False,
)
if self.ychat is not None:
stream_id = self.write_message("", None)
else:
stream_id = uuid4().hex
stream_msg = AgentStreamMessage(
id=stream_id,
time=time.time(),
body="",
reply_to=human_msg.id,
persona=self.persona,
complete=False,
)

self.broadcast_message(stream_msg)

self.broadcast_message(stream_msg)
return stream_id

def _send_stream_chunk(
Expand All @@ -509,13 +572,19 @@ def _send_stream_chunk(
Sends an `agent-stream-chunk` message containing content that should be
appended to an existing `agent-stream` message with ID `stream_id`.
"""
if not metadata:
metadata = {}

stream_chunk_msg = AgentStreamChunkMessage(
id=stream_id, content=content, stream_complete=complete, metadata=metadata
)
self.broadcast_message(stream_chunk_msg)
if self.ychat is not None:
self.write_message(content, stream_id)
else:
if not metadata:
metadata = {}

stream_chunk_msg = AgentStreamChunkMessage(
id=stream_id,
content=content,
stream_complete=complete,
metadata=metadata,
)
self.broadcast_message(stream_chunk_msg)

async def stream_reply(
self,
Expand Down
7 changes: 4 additions & 3 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ async def process_message(self, message: HumanChatMessage):
em_provider_cls, em_provider_args = self.get_embedding_provider()
if not em_provider_cls:
self.reply(
"Sorry, please select an embedding provider before using the `/learn` command."
"Sorry, please select an embedding provider before using the `/learn` command.",
)
return

Expand Down Expand Up @@ -163,14 +163,15 @@ async def process_message(self, message: HumanChatMessage):
except ModuleNotFoundError as e:
self.log.error(e)
self.reply(
"No `arxiv` package found. " "Install with `pip install arxiv`."
"No `arxiv` package found. "
"Install with `pip install arxiv`.",
)
return
except Exception as e:
self.log.error(e)
self.reply(
"An error occurred while processing the arXiv file. "
f"Please verify that the arxiv id {id} is correct."
f"Please verify that the arxiv id {id} is correct.",
)
return

Expand Down
8 changes: 8 additions & 0 deletions packages/jupyter-ai/jupyter_ai/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# The BOT currently has a fixed username: because this username is used as a key in chats,
# it needs to be constant. Do we need to change it?
BOT = {
"username": "5f6a7570-7974-6572-6e61-75742d626f74",
"name": "Jupyternaut",
"display_name": "Jupyternaut",
"initials": "J",
}
Loading

0 comments on commit cb04fa4

Please sign in to comment.