Add ability to delete messages + start new chat session (#951)
* add ui components

* temp add help message to new chat

* at to target

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* broadcast ClearMessage sends help message

* clear llm_chat_memory

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* typo

* Update chat welcome message

* improve docstring

Co-authored-by: david qiu <[email protected]>

* type typo

* use tooltippedbutton + do not show new chat button on welcome

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update Playwright Snapshots

* fix not adding uncleared pending messages to memory

* reimplement to delete specific message exchange

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix lint

* refine docs

Co-authored-by: david qiu <[email protected]>

* keep list of cleared messages

* update doc

* support clearing all subsequent exchanges

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Sanjiv Das <[email protected]>
Co-authored-by: david qiu <[email protected]>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
5 people authored Sep 9, 2024
1 parent 4b45ec2 commit df67eb8
Showing 16 changed files with 247 additions and 52 deletions.
2 changes: 1 addition & 1 deletion packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py
@@ -71,7 +71,7 @@ async def process_message(self, message: HumanChatMessage):
         self.get_llm_chain()

         try:
-            with self.pending("Searching learned documents"):
+            with self.pending("Searching learned documents", message):
                 result = await self.llm_chain.acall({"question": query})
                 response = result["answer"]
             self.reply(response, message)
33 changes: 22 additions & 11 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/base.py
@@ -270,7 +270,13 @@ def reply(self, response: str, human_msg: Optional[HumanChatMessage] = None):
     def persona(self):
         return self.config_manager.persona

-    def start_pending(self, text: str, ellipsis: bool = True) -> PendingMessage:
+    def start_pending(
+        self,
+        text: str,
+        human_msg: Optional[HumanChatMessage] = None,
+        *,
+        ellipsis: bool = True,
+    ) -> PendingMessage:
         """
         Sends a pending message to the client.

@@ -282,6 +288,7 @@ def start_pending(self, text: str, ellipsis: bool = True) -> PendingMessage:
             id=uuid4().hex,
             time=time.time(),
             body=text,
+            reply_to=human_msg.id if human_msg else "",
             persona=Persona(name=persona.name, avatar_route=persona.avatar_route),
             ellipsis=ellipsis,
         )
@@ -315,12 +322,18 @@ def close_pending(self, pending_msg: PendingMessage):
         pending_msg.closed = True

     @contextlib.contextmanager
-    def pending(self, text: str, ellipsis: bool = True):
+    def pending(
+        self,
+        text: str,
+        human_msg: Optional[HumanChatMessage] = None,
+        *,
+        ellipsis: bool = True,
+    ):
         """
         Context manager that sends a pending message to the client, and closes
         it after the block is executed.
         """
-        pending_msg = self.start_pending(text, ellipsis=ellipsis)
+        pending_msg = self.start_pending(text, human_msg=human_msg, ellipsis=ellipsis)
         try:
             yield pending_msg
         finally:
@@ -378,17 +391,15 @@ def parse_args(self, message, silent=False):
             return None
         return args

-    def get_llm_chat_history(
+    def get_llm_chat_memory(
         self,
-        last_human_msg: Optional[HumanChatMessage] = None,
+        last_human_msg: HumanChatMessage,
         **kwargs,
     ) -> "BaseChatMessageHistory":
-        if last_human_msg:
-            return WrappedBoundedChatHistory(
-                history=self.llm_chat_memory,
-                last_human_msg=last_human_msg,
-            )
-        return self.llm_chat_memory
+        return WrappedBoundedChatHistory(
+            history=self.llm_chat_memory,
+            last_human_msg=last_human_msg,
+        )

     @property
     def output_dir(self) -> str:
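For illustration, a minimal sketch (not part of this diff) of a hypothetical handler built on BaseChatHandler using the new signatures; the class name, prompt text, and chain call below are invented:

from jupyter_ai.chat_handlers.base import BaseChatHandler
from jupyter_ai.models import HumanChatMessage


class ExampleChatHandler(BaseChatHandler):
    """Hypothetical handler demonstrating the new `pending()` signature."""

    async def process_message(self, message: HumanChatMessage):
        self.get_llm_chain()
        # Passing the triggering HumanChatMessage as the second argument sets
        # `reply_to` on the PendingMessage, letting the frontend tie the
        # pending indicator to this exchange. `ellipsis` is now keyword-only.
        with self.pending("Thinking", message, ellipsis=True):
            response = await self.llm_chain.apredict(input=message.body)
        self.reply(response, message)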
5 changes: 0 additions & 5 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/clear.py
@@ -23,9 +23,4 @@ async def process_message(self, _):
                 continue

             handler.broadcast_message(ClearMessage())
-            self._chat_history.clear()
-            self.llm_chat_memory.clear()
             break
-
-        # re-send help message
-        self.send_help_message()
4 changes: 2 additions & 2 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/default.py
@@ -44,7 +44,7 @@ def create_llm_chain(
         if not llm.manages_history:
             runnable = RunnableWithMessageHistory(
                 runnable=runnable,
-                get_session_history=self.get_llm_chat_history,
+                get_session_history=self.get_llm_chat_memory,
                 input_messages_key="input",
                 history_messages_key="history",
                 history_factory_config=[
@@ -101,7 +101,7 @@ async def process_message(self, message: HumanChatMessage):
         received_first_chunk = False

         # start with a pending message
-        with self.pending("Generating response") as pending_message:
+        with self.pending("Generating response", message) as pending_message:
             # stream response in chunks. this works even if a provider does not
             # implement streaming, as `astream()` defaults to yielding `_call()`
             # when `_stream()` is not implemented on the LLM class.
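Because RunnableWithMessageHistory resolves its history factory on every invocation, the renamed method determines which wrapper each run receives. A rough sketch of the per-request flow; the function and config keys below are assumptions based on the `history_factory_config` and `last_human_msg` parameter above:

from jupyter_ai.models import HumanChatMessage


async def stream_reply_sketch(handler, message: HumanChatMessage):
    # Illustrative only: the configurable field name mirrors the factory's
    # `last_human_msg` parameter.
    config = {"configurable": {"last_human_msg": message}}
    # For each run, RunnableWithMessageHistory calls the declared factory with
    # these configurable fields, roughly:
    #     history = handler.get_llm_chat_memory(last_human_msg=message)
    # so every message persisted during the run is tagged with `message.id`.
    async for chunk in handler.llm_chain.astream(
        {"input": message.body}, config=config
    ):
        print(chunk)  # stand-in for streaming the chunk to the frontend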
2 changes: 1 addition & 1 deletion packages/jupyter-ai/jupyter_ai/chat_handlers/fix.py
@@ -92,7 +92,7 @@ async def process_message(self, message: HumanChatMessage):
         extra_instructions = message.prompt[4:].strip() or "None."

         self.get_llm_chain()
-        with self.pending("Analyzing error"):
+        with self.pending("Analyzing error", message):
             response = await self.llm_chain.apredict(
                 extra_instructions=extra_instructions,
                 stop=["\nHuman:"],
2 changes: 1 addition & 1 deletion packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py
@@ -187,7 +187,7 @@ async def process_message(self, message: HumanChatMessage):
         # delete and relearn index if embedding model was changed
         await self.delete_and_relearn()

-        with self.pending(f"Loading and splitting files for {load_path}"):
+        with self.pending(f"Loading and splitting files for {load_path}", message):
             try:
                 await self.learn_dir(
                     load_path, args.chunk_size, args.chunk_overlap, args.all_files
71 changes: 62 additions & 9 deletions packages/jupyter-ai/jupyter_ai/handlers.py
@@ -24,6 +24,8 @@
     ChatMessage,
     ChatRequest,
     ChatUser,
+    ClearMessage,
+    ClearRequest,
     ClosePendingMessage,
     ConnectionMessage,
     HumanChatMessage,
@@ -40,6 +42,8 @@
     from jupyter_ai_magics.embedding_providers import BaseEmbeddingsProvider
     from jupyter_ai_magics.providers import BaseProvider

+    from .history import BoundChatHistory
+

class ChatHistoryHandler(BaseAPIHandler):
    """Handler to return message history"""
@@ -98,6 +102,10 @@ def chat_history(self) -> List[ChatMessage]:
     def chat_history(self, new_history):
         self.settings["chat_history"] = new_history

+    @property
+    def llm_chat_memory(self) -> "BoundChatHistory":
+        return self.settings["llm_chat_memory"]
+
     @property
     def loop(self) -> AbstractEventLoop:
         return self.settings["jai_event_loop"]
@@ -202,14 +210,6 @@ def broadcast_message(self, message: Message):
         Appends message to chat history.
         """

-        self.log.debug("Broadcasting message: %s to all clients...", message)
-        client_ids = self.root_chat_handlers.keys()
-
-        for client_id in client_ids:
-            client = self.root_chat_handlers[client_id]
-            if client:
-                client.write_message(message.dict())
-
         # do not broadcast agent messages that are replying to cleared human message
         if (
             isinstance(message, (AgentChatMessage, AgentStreamMessage))
@@ -220,6 +220,14 @@ def broadcast_message(self, message: Message):
         ]:
             return

+        self.log.debug("Broadcasting message: %s to all clients...", message)
+        client_ids = self.root_chat_handlers.keys()
+
+        for client_id in client_ids:
+            client = self.root_chat_handlers[client_id]
+            if client:
+                client.write_message(message.dict())
+
         # append all messages of type `ChatMessage` directly to the chat history
         if isinstance(
             message, (HumanChatMessage, AgentChatMessage, AgentStreamMessage)
@@ -246,17 +254,48 @@ def broadcast_message(self, message: Message):
             self.pending_messages = list(
                 filter(lambda m: m.id != message.id, self.pending_messages)
             )
+        elif isinstance(message, ClearMessage):
+            if message.targets:
+                self._clear_chat_history_at(message.targets)
+            else:
+                self.chat_history.clear()
+                self.pending_messages.clear()
+                self.llm_chat_memory.clear()
+                self.settings["jai_chat_handlers"]["default"].send_help_message()

     async def on_message(self, message):
         self.log.debug("Message received: %s", message)

         try:
             message = json.loads(message)
-            chat_request = ChatRequest(**message)
+            if message.get("type") == "clear":
+                request = ClearRequest(**message)
+            else:
+                request = ChatRequest(**message)
         except ValidationError as e:
             self.log.error(e)
             return

+        if isinstance(request, ClearRequest):
+            if not request.target:
+                targets = None
+            elif request.after:
+                target_msg = None
+                for msg in self.chat_history:
+                    if msg.id == request.target:
+                        target_msg = msg
+                if target_msg:
+                    targets = [
+                        msg.id
+                        for msg in self.chat_history
+                        if msg.time >= target_msg.time and msg.type == "human"
+                    ]
+            else:
+                targets = [request.target]
+            self.broadcast_message(ClearMessage(targets=targets))
+            return
+
+        chat_request = request
         message_body = chat_request.prompt
         if chat_request.selection:
             message_body += f"\n\n```\n{chat_request.selection.source}\n```\n"
@@ -302,6 +341,20 @@ async def _route(self, message):
         command_readable = "Default" if command == "default" else command
         self.log.info(f"{command_readable} chat handler resolved in {latency_ms} ms.")

+    def _clear_chat_history_at(self, msg_ids: List[str]):
+        """
+        Clears conversation exchanges associated with list of human message IDs.
+        """
+        self.chat_history[:] = [
+            msg
+            for msg in self.chat_history
+            if msg.id not in msg_ids and getattr(msg, "reply_to", None) not in msg_ids
+        ]
+        self.pending_messages[:] = [
+            msg for msg in self.pending_messages if msg.reply_to not in msg_ids
+        ]
+        self.llm_chat_memory.clear(msg_ids)
+
     def on_close(self):
         self.log.debug("Disconnecting client with user %s", self.client_id)
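On the wire, the three clear operations accepted by `on_message` look roughly like this; the message IDs below are invented for illustration:

# 1. Clear the entire chat: broadcasts ClearMessage(targets=None); the
#    handler resets chat history, pending messages, and LLM memory, then
#    re-sends the help message.
{"type": "clear"}

# 2. Delete a single exchange: the targeted HumanChatMessage plus every
#    message whose `reply_to` matches its ID.
{"type": "clear", "target": "<human-msg-id>"}

# 3. Delete an exchange and all that follow: every human message with
#    `time >= target.time` is collected into `ClearMessage.targets`.
{"type": "clear", "target": "<human-msg-id>", "after": True}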
37 changes: 31 additions & 6 deletions packages/jupyter-ai/jupyter_ai/history.py
@@ -1,12 +1,14 @@
 import time
-from typing import List, Sequence
+from typing import List, Optional, Sequence, Set

 from langchain_core.chat_history import BaseChatMessageHistory
 from langchain_core.messages import BaseMessage
 from langchain_core.pydantic_v1 import BaseModel, PrivateAttr

 from .models import HumanChatMessage

+HUMAN_MSG_ID_KEY = "_jupyter_ai_human_msg_id"
+

 class BoundedChatHistory(BaseChatMessageHistory, BaseModel):
     """
@@ -19,6 +21,7 @@ class BoundedChatHistory(BaseChatMessageHistory, BaseModel):

     k: int
     clear_time: float = 0.0
+    cleared_msgs: Set[str] = set()
     _all_messages: List[BaseMessage] = PrivateAttr(default_factory=list)

     @property
@@ -30,15 +33,33 @@ async def aget_messages(self) -> List[BaseMessage]:

     def add_message(self, message: BaseMessage) -> None:
         """Add a self-created message to the store"""
+        if HUMAN_MSG_ID_KEY not in message.additional_kwargs:
+            # human message id must be added to allow for targeted clearing of messages.
+            # `WrappedBoundedChatHistory` should be used instead to add messages.
+            raise ValueError(
+                "Message must have a human message ID to be added to the store."
+            )
         self._all_messages.append(message)

     async def aadd_messages(self, messages: Sequence[BaseMessage]) -> None:
         """Add messages to the store"""
         self.add_messages(messages)

-    def clear(self) -> None:
-        self._all_messages = []
-        self.clear_time = time.time()
+    def clear(self, human_msg_ids: Optional[List[str]] = None) -> None:
+        """Clears conversation exchanges. If `human_msg_id` is provided, only
+        clears the respective human message and its reply. Otherwise, clears
+        all messages."""
+        if human_msg_ids:
+            self._all_messages = [
+                m
+                for m in self._all_messages
+                if m.additional_kwargs[HUMAN_MSG_ID_KEY] not in human_msg_ids
+            ]
+            self.cleared_msgs.update(human_msg_ids)
+        else:
+            self._all_messages = []
+            self.cleared_msgs = set()
+        self.clear_time = time.time()

     async def aclear(self) -> None:
         self.clear()
@@ -73,8 +94,12 @@ def messages(self) -> List[BaseMessage]:
         return self.history.messages

     def add_message(self, message: BaseMessage) -> None:
-        """Prevent adding messages to the store if clear was triggered."""
-        if self.last_human_msg.time > self.history.clear_time:
+        # prevent adding pending messages to the store if clear was triggered.
+        if (
+            self.last_human_msg.time > self.history.clear_time
+            and self.last_human_msg.id not in self.history.cleared_msgs
+        ):
+            message.additional_kwargs[HUMAN_MSG_ID_KEY] = self.last_human_msg.id
             self.history.add_message(message)

     async def aadd_messages(self, messages: Sequence[BaseMessage]) -> None:
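A minimal sketch (message contents invented) of the targeted-clear flow across the two classes; `human_msg` stands in for the HumanChatMessage that triggered the exchange:

from langchain_core.messages import AIMessage, HumanMessage

from jupyter_ai.history import BoundedChatHistory, WrappedBoundedChatHistory
from jupyter_ai.models import HumanChatMessage


def targeted_clear_sketch(human_msg: HumanChatMessage):
    history = BoundedChatHistory(k=2)
    wrapped = WrappedBoundedChatHistory(history=history, last_human_msg=human_msg)

    # The wrapper stamps HUMAN_MSG_ID_KEY into additional_kwargs; adding to
    # `history` directly without that key now raises ValueError.
    wrapped.add_message(HumanMessage(content="What is a DataFrame?"))
    wrapped.add_message(AIMessage(content="A two-dimensional labelled table."))

    # Clear only this exchange. The ID is also recorded in `cleared_msgs`,
    # so a reply still streaming for it will be dropped by the wrapper.
    history.clear([human_msg.id])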
20 changes: 20 additions & 0 deletions packages/jupyter-ai/jupyter_ai/models.py
@@ -39,6 +39,20 @@ class ChatRequest(BaseModel):
     selection: Optional[Selection]


+class ClearRequest(BaseModel):
+    type: Literal["clear"]
+    target: Optional[str]
+    """
+    Message ID of the HumanChatMessage to delete an exchange at.
+    If not provided, this requests the backend to clear all messages.
+    """
+
+    after: Optional[bool]
+    """
+    Whether to clear target and all subsequent exchanges.
+    """
+
+
 class ChatUser(BaseModel):
     # User ID assigned by IdentityProvider.
     username: str
@@ -105,13 +119,19 @@ class HumanChatMessage(BaseModel):

 class ClearMessage(BaseModel):
     type: Literal["clear"] = "clear"
+    targets: Optional[List[str]] = None
+    """
+    Message IDs of the HumanChatMessage to delete an exchange at.
+    If not provided, this instructs the frontend to clear all messages.
+    """


 class PendingMessage(BaseModel):
     type: Literal["pending"] = "pending"
     id: str
     time: float
     body: str
+    reply_to: str
     persona: Persona
     ellipsis: bool = True
     closed: bool = False
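For illustration, a hedged sketch constructing the new and updated models directly; field values are invented and Persona's import path is an assumption:

from jupyter_ai.models import ClearMessage, ClearRequest, PendingMessage, Persona

# A frontend request to delete one exchange (IDs invented).
req = ClearRequest(type="clear", target="9f3a1c", after=False)

# The corresponding broadcast; `targets=None` would instead tell the
# frontend to clear everything.
msg = ClearMessage(targets=[req.target])

# Pending indicators now carry `reply_to`, so clearing an exchange can also
# remove any in-flight pending indicator attached to it.
pending = PendingMessage(
    id="1b2c3d",
    time=0.0,
    body="Generating response",
    reply_to="9f3a1c",
    persona=Persona(name="Jupyternaut", avatar_route="/avatar.svg"),
)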