Skip to content

Commit

Permalink
support clearing all subsequent exchanges
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelchia committed Sep 8, 2024
1 parent cdf864b commit f4d996e
Show file tree
Hide file tree
Showing 6 changed files with 43 additions and 21 deletions.
31 changes: 23 additions & 8 deletions packages/jupyter-ai/jupyter_ai/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,8 +255,8 @@ def broadcast_message(self, message: Message):
filter(lambda m: m.id != message.id, self.pending_messages)
)
elif isinstance(message, ClearMessage):
if message.target:
self._clear_chat_history_at(message.target)
if message.targets:
self._clear_chat_history_at(message.targets)
else:
self.chat_history.clear()
self.pending_messages.clear()
Expand All @@ -277,7 +277,22 @@ async def on_message(self, message):
return

if isinstance(request, ClearRequest):
self.broadcast_message(ClearMessage(target=request.target))
if not request.target:
targets = None
elif request.after:
target_msg = None
for msg in self.chat_history:
if msg.id == request.target:
target_msg = msg
if target_msg:
targets = [
msg.id
for msg in self.chat_history
if msg.time >= target_msg.time and msg.type == "human"
]
else:
targets = [request.target]
self.broadcast_message(ClearMessage(targets=targets))
return

chat_request = request
Expand Down Expand Up @@ -326,19 +341,19 @@ async def _route(self, message):
command_readable = "Default" if command == "default" else command
self.log.info(f"{command_readable} chat handler resolved in {latency_ms} ms.")

def _clear_chat_history_at(self, msg_id: str):
def _clear_chat_history_at(self, msg_ids: List[str]):
"""
Clears a conversation exchange given a human message ID `msg_id`.
Clears conversation exchanges associated with list of human message IDs.
"""
self.chat_history[:] = [
msg
for msg in self.chat_history
if msg.id != msg_id and getattr(msg, "reply_to", None) != msg_id
if msg.id not in msg_ids and getattr(msg, "reply_to", None) not in msg_ids
]
self.pending_messages[:] = [
msg for msg in self.pending_messages if msg.reply_to != msg_id
msg for msg in self.pending_messages if msg.reply_to not in msg_ids
]
self.llm_chat_memory.clear(msg_id)
self.llm_chat_memory.clear(msg_ids)

def on_close(self):
self.log.debug("Disconnecting client with user %s", self.client_id)
Expand Down
9 changes: 4 additions & 5 deletions packages/jupyter-ai/jupyter_ai/history.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,17 +45,17 @@ async def aadd_messages(self, messages: Sequence[BaseMessage]) -> None:
"""Add messages to the store"""
self.add_messages(messages)

def clear(self, human_msg_id: Optional[str] = None) -> None:
def clear(self, human_msg_ids: Optional[List[str]] = None) -> None:
"""Clears conversation exchanges. If `human_msg_id` is provided, only
clears the respective human message and its reply. Otherwise, clears
all messages."""
if human_msg_id:
if human_msg_ids:
self._all_messages = [
m
for m in self._all_messages
if m.additional_kwargs[HUMAN_MSG_ID_KEY] != human_msg_id
if m.additional_kwargs[HUMAN_MSG_ID_KEY] not in human_msg_ids
]
self.cleared_msgs.add(human_msg_id)
self.cleared_msgs.update(human_msg_ids)
else:
self._all_messages = []
self.cleared_msgs = set()
Expand Down Expand Up @@ -95,7 +95,6 @@ def messages(self) -> List[BaseMessage]:

def add_message(self, message: BaseMessage) -> None:
# prevent adding pending messages to the store if clear was triggered.
# if targeted clearing, prevent adding target message if still pending.
if (
self.last_human_msg.time > self.history.clear_time
and self.last_human_msg.id not in self.history.cleared_msgs
Expand Down
9 changes: 7 additions & 2 deletions packages/jupyter-ai/jupyter_ai/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@ class ClearRequest(BaseModel):
If not provided, this requests the backend to clear all messages.
"""

after: Optional[bool]
"""
Whether to clear target and all subsequent exchanges.
"""


class ChatUser(BaseModel):
# User ID assigned by IdentityProvider.
Expand Down Expand Up @@ -114,9 +119,9 @@ class HumanChatMessage(BaseModel):

class ClearMessage(BaseModel):
type: Literal["clear"] = "clear"
target: Optional[str] = None
targets: Optional[List[str]] = None
"""
Message ID of the HumanChatMessage to delete an exchange at.
Message IDs of the HumanChatMessage to delete an exchange at.
If not provided, this instructs the frontend to clear all messages.
"""

Expand Down
9 changes: 5 additions & 4 deletions packages/jupyter-ai/src/chat_handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -132,14 +132,15 @@ export class ChatHandler implements IDisposable {
case 'connection':
break;
case 'clear':
if (newMessage.target) {
if (newMessage.targets) {
const targets = newMessage.targets;
this._messages = this._messages.filter(
msg =>
msg.id !== newMessage.target &&
!('reply_to' in msg && msg.reply_to === newMessage.target)
!targets.includes(msg.id) &&
!('reply_to' in msg && targets.includes(msg.reply_to))
);
this._pendingMessages = this._pendingMessages.filter(
msg => msg.reply_to !== newMessage.target
msg => !targets.includes(msg.reply_to)
);
} else {
this._messages = [];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ type DeleteButtonProps = {
export function ChatMessageDelete(props: DeleteButtonProps): JSX.Element {
const request: AiService.ClearRequest = {
type: 'clear',
target: props.message.id
target: props.message.id,
after: false
};
return (
<TooltippedIconButton
Expand Down
3 changes: 2 additions & 1 deletion packages/jupyter-ai/src/handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ export namespace AiService {
export type ClearRequest = {
type: 'clear';
target?: string;
after?: boolean;
};

export type Collaborator = {
Expand Down Expand Up @@ -143,7 +144,7 @@ export namespace AiService {

export type ClearMessage = {
type: 'clear';
target?: string;
targets?: string[];
};

export type PendingMessage = {
Expand Down

0 comments on commit f4d996e

Please sign in to comment.