diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py
index 79736f2cb..5c3026685 100644
--- a/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py
+++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py
@@ -71,7 +71,7 @@ async def process_message(self, message: HumanChatMessage):
         self.get_llm_chain()
 
         try:
-            with self.pending("Searching learned documents"):
+            with self.pending("Searching learned documents", message):
                 result = await self.llm_chain.acall({"question": query})
                 response = result["answer"]
             self.reply(response, message)
diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py
index 516afce6c..b97015518 100644
--- a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py
+++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py
@@ -270,7 +270,13 @@ def reply(self, response: str, human_msg: Optional[HumanChatMessage] = None):
     def persona(self):
         return self.config_manager.persona
 
-    def start_pending(self, text: str, ellipsis: bool = True) -> PendingMessage:
+    def start_pending(
+        self,
+        text: str,
+        human_msg: Optional[HumanChatMessage] = None,
+        *,
+        ellipsis: bool = True,
+    ) -> PendingMessage:
         """
         Sends a pending message to the client.
 
@@ -282,6 +288,7 @@ def start_pending(self, text: str, ellipsis: bool = True) -> PendingMessage:
             id=uuid4().hex,
             time=time.time(),
             body=text,
+            reply_to=human_msg.id if human_msg else "",
             persona=Persona(name=persona.name, avatar_route=persona.avatar_route),
             ellipsis=ellipsis,
         )
@@ -315,12 +322,18 @@ def close_pending(self, pending_msg: PendingMessage):
         pending_msg.closed = True
 
     @contextlib.contextmanager
-    def pending(self, text: str, ellipsis: bool = True):
+    def pending(
+        self,
+        text: str,
+        human_msg: Optional[HumanChatMessage] = None,
+        *,
+        ellipsis: bool = True,
+    ):
         """
         Context manager that sends a pending message to the client, and
         closes it after the block is executed.
         """
-        pending_msg = self.start_pending(text, ellipsis=ellipsis)
+        pending_msg = self.start_pending(text, human_msg=human_msg, ellipsis=ellipsis)
         try:
             yield pending_msg
         finally:
@@ -378,17 +391,15 @@ def parse_args(self, message, silent=False):
             return None
         return args
 
-    def get_llm_chat_history(
+    def get_llm_chat_memory(
         self,
-        last_human_msg: Optional[HumanChatMessage] = None,
+        last_human_msg: HumanChatMessage,
         **kwargs,
     ) -> "BaseChatMessageHistory":
-        if last_human_msg:
-            return WrappedBoundedChatHistory(
-                history=self.llm_chat_memory,
-                last_human_msg=last_human_msg,
-            )
-        return self.llm_chat_memory
+        return WrappedBoundedChatHistory(
+            history=self.llm_chat_memory,
+            last_human_msg=last_human_msg,
+        )
 
     @property
     def output_dir(self) -> str:
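The `reply_to` field stamped onto `PendingMessage` above is what lets a targeted clear remove a pending indicator together with its exchange. A minimal sketch of a handler using the updated API; the handler class itself is hypothetical, and only `pending(text, message)` and `reply()` come from this diff:

```python
from jupyter_ai.chat_handlers.base import BaseChatHandler
from jupyter_ai.models import HumanChatMessage


class EchoChatHandler(BaseChatHandler):
    """Hypothetical handler illustrating the new `pending()` signature."""

    id = "echo"
    name = "Echo"
    help = "Replies with the prompt itself"

    async def process_message(self, message: HumanChatMessage):
        # Passing `message` sets `reply_to` on the PendingMessage, so the
        # frontend can drop this indicator if the exchange is deleted.
        with self.pending("Echoing", message):
            self.reply(message.body, message)
```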
""" - pending_msg = self.start_pending(text, ellipsis=ellipsis) + pending_msg = self.start_pending(text, human_msg=human_msg, ellipsis=ellipsis) try: yield pending_msg finally: @@ -378,17 +391,15 @@ def parse_args(self, message, silent=False): return None return args - def get_llm_chat_history( + def get_llm_chat_memory( self, - last_human_msg: Optional[HumanChatMessage] = None, + last_human_msg: HumanChatMessage, **kwargs, ) -> "BaseChatMessageHistory": - if last_human_msg: - return WrappedBoundedChatHistory( - history=self.llm_chat_memory, - last_human_msg=last_human_msg, - ) - return self.llm_chat_memory + return WrappedBoundedChatHistory( + history=self.llm_chat_memory, + last_human_msg=last_human_msg, + ) @property def output_dir(self) -> str: diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/clear.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/clear.py index 22a7c83e9..a05bc3e57 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/clear.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/clear.py @@ -23,9 +23,4 @@ async def process_message(self, _): continue handler.broadcast_message(ClearMessage()) - self._chat_history.clear() - self.llm_chat_memory.clear() break - - # re-send help message - self.send_help_message() diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py index 02256f4eb..a51ef29e3 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py @@ -44,7 +44,7 @@ def create_llm_chain( if not llm.manages_history: runnable = RunnableWithMessageHistory( runnable=runnable, - get_session_history=self.get_llm_chat_history, + get_session_history=self.get_llm_chat_memory, input_messages_key="input", history_messages_key="history", history_factory_config=[ @@ -101,7 +101,7 @@ async def process_message(self, message: HumanChatMessage): received_first_chunk = False # start with a pending message - with self.pending("Generating response") as pending_message: + with self.pending("Generating response", message) as pending_message: # stream response in chunks. this works even if a provider does not # implement streaming, as `astream()` defaults to yielding `_call()` # when `_stream()` is not implemented on the LLM class. diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/fix.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/fix.py index 318f9a5dd..d6ecc6d81 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/fix.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/fix.py @@ -92,7 +92,7 @@ async def process_message(self, message: HumanChatMessage): extra_instructions = message.prompt[4:].strip() or "None." 
diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py
index 8d6fb09aa..29e147f22 100644
--- a/packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py
+++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py
@@ -187,7 +187,7 @@ async def process_message(self, message: HumanChatMessage):
         # delete and relearn index if embedding model was changed
         await self.delete_and_relearn()
 
-        with self.pending(f"Loading and splitting files for {load_path}"):
+        with self.pending(f"Loading and splitting files for {load_path}", message):
             try:
                 await self.learn_dir(
                     load_path, args.chunk_size, args.chunk_overlap, args.all_files
""" - self.log.debug("Broadcasting message: %s to all clients...", message) - client_ids = self.root_chat_handlers.keys() - - for client_id in client_ids: - client = self.root_chat_handlers[client_id] - if client: - client.write_message(message.dict()) - # do not broadcast agent messages that are replying to cleared human message if ( isinstance(message, (AgentChatMessage, AgentStreamMessage)) @@ -220,6 +220,14 @@ def broadcast_message(self, message: Message): ]: return + self.log.debug("Broadcasting message: %s to all clients...", message) + client_ids = self.root_chat_handlers.keys() + + for client_id in client_ids: + client = self.root_chat_handlers[client_id] + if client: + client.write_message(message.dict()) + # append all messages of type `ChatMessage` directly to the chat history if isinstance( message, (HumanChatMessage, AgentChatMessage, AgentStreamMessage) @@ -246,17 +254,48 @@ def broadcast_message(self, message: Message): self.pending_messages = list( filter(lambda m: m.id != message.id, self.pending_messages) ) + elif isinstance(message, ClearMessage): + if message.targets: + self._clear_chat_history_at(message.targets) + else: + self.chat_history.clear() + self.pending_messages.clear() + self.llm_chat_memory.clear() + self.settings["jai_chat_handlers"]["default"].send_help_message() async def on_message(self, message): self.log.debug("Message received: %s", message) try: message = json.loads(message) - chat_request = ChatRequest(**message) + if message.get("type") == "clear": + request = ClearRequest(**message) + else: + request = ChatRequest(**message) except ValidationError as e: self.log.error(e) return + if isinstance(request, ClearRequest): + if not request.target: + targets = None + elif request.after: + target_msg = None + for msg in self.chat_history: + if msg.id == request.target: + target_msg = msg + if target_msg: + targets = [ + msg.id + for msg in self.chat_history + if msg.time >= target_msg.time and msg.type == "human" + ] + else: + targets = [request.target] + self.broadcast_message(ClearMessage(targets=targets)) + return + + chat_request = request message_body = chat_request.prompt if chat_request.selection: message_body += f"\n\n```\n{chat_request.selection.source}\n```\n" @@ -302,6 +341,20 @@ async def _route(self, message): command_readable = "Default" if command == "default" else command self.log.info(f"{command_readable} chat handler resolved in {latency_ms} ms.") + def _clear_chat_history_at(self, msg_ids: List[str]): + """ + Clears conversation exchanges associated with list of human message IDs. 
+ """ + self.chat_history[:] = [ + msg + for msg in self.chat_history + if msg.id not in msg_ids and getattr(msg, "reply_to", None) not in msg_ids + ] + self.pending_messages[:] = [ + msg for msg in self.pending_messages if msg.reply_to not in msg_ids + ] + self.llm_chat_memory.clear(msg_ids) + def on_close(self): self.log.debug("Disconnecting client with user %s", self.client_id) diff --git a/packages/jupyter-ai/jupyter_ai/history.py b/packages/jupyter-ai/jupyter_ai/history.py index 8216fbcaf..9e1064194 100644 --- a/packages/jupyter-ai/jupyter_ai/history.py +++ b/packages/jupyter-ai/jupyter_ai/history.py @@ -1,5 +1,5 @@ import time -from typing import List, Sequence +from typing import List, Optional, Sequence, Set from langchain_core.chat_history import BaseChatMessageHistory from langchain_core.messages import BaseMessage @@ -7,6 +7,8 @@ from .models import HumanChatMessage +HUMAN_MSG_ID_KEY = "_jupyter_ai_human_msg_id" + class BoundedChatHistory(BaseChatMessageHistory, BaseModel): """ @@ -19,6 +21,7 @@ class BoundedChatHistory(BaseChatMessageHistory, BaseModel): k: int clear_time: float = 0.0 + cleared_msgs: Set[str] = set() _all_messages: List[BaseMessage] = PrivateAttr(default_factory=list) @property @@ -30,15 +33,33 @@ async def aget_messages(self) -> List[BaseMessage]: def add_message(self, message: BaseMessage) -> None: """Add a self-created message to the store""" + if HUMAN_MSG_ID_KEY not in message.additional_kwargs: + # human message id must be added to allow for targeted clearing of messages. + # `WrappedBoundedChatHistory` should be used instead to add messages. + raise ValueError( + "Message must have a human message ID to be added to the store." + ) self._all_messages.append(message) async def aadd_messages(self, messages: Sequence[BaseMessage]) -> None: """Add messages to the store""" self.add_messages(messages) - def clear(self) -> None: - self._all_messages = [] - self.clear_time = time.time() + def clear(self, human_msg_ids: Optional[List[str]] = None) -> None: + """Clears conversation exchanges. If `human_msg_id` is provided, only + clears the respective human message and its reply. Otherwise, clears + all messages.""" + if human_msg_ids: + self._all_messages = [ + m + for m in self._all_messages + if m.additional_kwargs[HUMAN_MSG_ID_KEY] not in human_msg_ids + ] + self.cleared_msgs.update(human_msg_ids) + else: + self._all_messages = [] + self.cleared_msgs = set() + self.clear_time = time.time() async def aclear(self) -> None: self.clear() @@ -73,8 +94,12 @@ def messages(self) -> List[BaseMessage]: return self.history.messages def add_message(self, message: BaseMessage) -> None: - """Prevent adding messages to the store if clear was triggered.""" - if self.last_human_msg.time > self.history.clear_time: + # prevent adding pending messages to the store if clear was triggered. 
diff --git a/packages/jupyter-ai/jupyter_ai/models.py b/packages/jupyter-ai/jupyter_ai/models.py
index f2fa098bb..f9098a12a 100644
--- a/packages/jupyter-ai/jupyter_ai/models.py
+++ b/packages/jupyter-ai/jupyter_ai/models.py
@@ -39,6 +39,20 @@ class ChatRequest(BaseModel):
     selection: Optional[Selection]
 
 
+class ClearRequest(BaseModel):
+    type: Literal["clear"]
+    target: Optional[str]
+    """
+    Message ID of the HumanChatMessage to delete an exchange at.
+    If not provided, this requests the backend to clear all messages.
+    """
+
+    after: Optional[bool]
+    """
+    Whether to clear the target and all subsequent exchanges.
+    """
+
+
 class ChatUser(BaseModel):
     # User ID assigned by IdentityProvider.
     username: str
@@ -105,6 +119,11 @@ class HumanChatMessage(BaseModel):
 
 class ClearMessage(BaseModel):
     type: Literal["clear"] = "clear"
+    targets: Optional[List[str]] = None
+    """
+    Message IDs of the HumanChatMessages whose exchanges should be deleted.
+    If not provided, this instructs the frontend to clear all messages.
+    """
 
 
 class PendingMessage(BaseModel):
@@ -112,6 +131,7 @@ class PendingMessage(BaseModel):
     id: str
     time: float
     body: str
+    reply_to: str
     persona: Persona
     ellipsis: bool = True
     closed: bool = False
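On the wire, these models reduce to small JSON payloads. The three variants of a clear request the frontend can now send, written here as Python dicts (the message IDs are placeholders):

```python
# Delete a single exchange (one human message and its replies):
{"type": "clear", "target": "<human-msg-id>", "after": False}

# Delete an exchange and every exchange after it:
{"type": "clear", "target": "<human-msg-id>", "after": True}

# Clear the whole chat; this is what the "New chat" button sends:
{"type": "clear"}
```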
diff --git a/packages/jupyter-ai/src/chat_handler.ts b/packages/jupyter-ai/src/chat_handler.ts
index bdce19458..76c93a851 100644
--- a/packages/jupyter-ai/src/chat_handler.ts
+++ b/packages/jupyter-ai/src/chat_handler.ts
@@ -39,7 +39,7 @@ export class ChatHandler implements IDisposable {
    * Sends a message across the WebSocket. Promise resolves to the message ID
    * when the server sends the same message back, acknowledging receipt.
    */
-  public sendMessage(message: AiService.ChatRequest): Promise<string> {
+  public sendMessage(message: AiService.Request): Promise<string> {
     return new Promise(resolve => {
       this._socket?.send(JSON.stringify(message));
       this._sendResolverQueue.push(resolve);
@@ -132,8 +132,20 @@ export class ChatHandler implements IDisposable {
       case 'connection':
         break;
       case 'clear':
-        this._messages = [];
-        this._pendingMessages = [];
+        if (newMessage.targets) {
+          const targets = newMessage.targets;
+          this._messages = this._messages.filter(
+            msg =>
+              !targets.includes(msg.id) &&
+              !('reply_to' in msg && targets.includes(msg.reply_to))
+          );
+          this._pendingMessages = this._pendingMessages.filter(
+            msg => !targets.includes(msg.reply_to)
+          );
+        } else {
+          this._messages = [];
+          this._pendingMessages = [];
+        }
         break;
       case 'pending':
         this._pendingMessages = [...this._pendingMessages, newMessage];
diff --git a/packages/jupyter-ai/src/components/chat-messages.tsx b/packages/jupyter-ai/src/components/chat-messages.tsx
index 86b6793d9..c3fc0921a 100644
--- a/packages/jupyter-ai/src/components/chat-messages.tsx
+++ b/packages/jupyter-ai/src/components/chat-messages.tsx
@@ -10,16 +10,20 @@ import { AiService } from '../handler';
 import { RendermimeMarkdown } from './rendermime-markdown';
 import { useCollaboratorsContext } from '../contexts/collaborators-context';
 import { ChatMessageMenu } from './chat-messages/chat-message-menu';
+import { ChatMessageDelete } from './chat-messages/chat-message-delete';
+import { ChatHandler } from '../chat_handler';
 import { IJaiMessageFooter } from '../tokens';
 
 type ChatMessagesProps = {
   rmRegistry: IRenderMimeRegistry;
   messages: AiService.ChatMessage[];
+  chatHandler: ChatHandler;
   messageFooter: IJaiMessageFooter | null;
 };
 
 type ChatMessageHeaderProps = {
   message: AiService.ChatMessage;
+  chatHandler: ChatHandler;
   timestamp: string;
   sx?: SxProps;
 };
@@ -113,6 +117,7 @@ export function ChatMessageHeader(props: ChatMessageHeaderProps): JSX.Element {
   const shouldShowMenu =
     props.message.type === 'agent' ||
     (props.message.type === 'agent-stream' && props.message.complete);
+  const shouldShowDelete = props.message.type === 'human';
 
   return (
     ...
         )}
+        {shouldShowDelete && (
+          <ChatMessageDelete
+            message={props.message}
+            chatHandler={props.chatHandler}
+          />
+        )}
@@ -208,6 +220,7 @@ export function ChatMessages(props: ChatMessagesProps): JSX.Element {
             <ChatMessageHeader
               message={message}
+              chatHandler={props.chatHandler}
               timestamp={timestamp[message.id]}
               sx={{ marginBottom: 3 }}
             />
diff --git a/packages/jupyter-ai/src/components/chat-messages/chat-message-delete.tsx b/packages/jupyter-ai/src/components/chat-messages/chat-message-delete.tsx
new file mode 100644
--- /dev/null
+++ b/packages/jupyter-ai/src/components/chat-messages/chat-message-delete.tsx
@@ -0,0 +1,31 @@
+import React from 'react';
+import { SxProps } from '@mui/material';
+import CloseIcon from '@mui/icons-material/Close';
+
+import { AiService } from '../../handler';
+import { ChatHandler } from '../../chat_handler';
+import { TooltippedIconButton } from '../mui-extras/tooltipped-icon-button';
+
+type DeleteButtonProps = {
+  message: AiService.ChatMessage;
+  chatHandler: ChatHandler;
+  sx?: SxProps;
+};
+
+export function ChatMessageDelete(props: DeleteButtonProps): JSX.Element {
+  const request: AiService.ClearRequest = {
+    type: 'clear',
+    target: props.message.id
+  };
+  return (
+    <TooltippedIconButton
+      onClick={() => props.chatHandler.sendMessage(request)}
+      sx={props.sx}
+      tooltip="Delete this exchange"
+    >
+      <CloseIcon />
+    </TooltippedIconButton>
+  );
+}
+
+export default ChatMessageDelete;
diff --git a/packages/jupyter-ai/src/components/chat.tsx b/packages/jupyter-ai/src/components/chat.tsx
index 0915d6ca6..dbff8b65f 100644
--- a/packages/jupyter-ai/src/components/chat.tsx
+++ b/packages/jupyter-ai/src/components/chat.tsx
@@ -3,6 +3,7 @@ import { Box } from '@mui/system';
 import { Button, IconButton, Stack } from '@mui/material';
 import SettingsIcon from '@mui/icons-material/Settings';
 import ArrowBackIcon from '@mui/icons-material/ArrowBack';
+import AddIcon from '@mui/icons-material/Add';
 import type { Awareness } from 'y-protocols/awareness';
 import type { IThemeManager } from '@jupyterlab/apputils';
 import { IRenderMimeRegistry } from '@jupyterlab/rendermime';
@@ -24,10 +25,13 @@ import { ActiveCellManager } from '../contexts/active-cell-context';
 import { ScrollContainer } from './scroll-container';
+import { TooltippedIconButton } from './mui-extras/tooltipped-icon-button';
 
 type ChatBodyProps = {
   chatHandler: ChatHandler;
-  setChatView: (view: ChatView) => void;
+  openSettingsView: () => void;
+  showWelcomeMessage: boolean;
+  setShowWelcomeMessage: (show: boolean) => void;
   rmRegistry: IRenderMimeRegistry;
   focusInputSignal: ISignal<unknown, void>;
   messageFooter: IJaiMessageFooter | null;
@@ -51,7 +55,9 @@ function getPersonaName(messages: AiService.ChatMessage[]): string {
 function ChatBody({
   chatHandler,
   focusInputSignal,
-  setChatView: chatViewHandler,
+  openSettingsView,
+  showWelcomeMessage,
+  setShowWelcomeMessage,
   rmRegistry: renderMimeRegistry,
   messageFooter
 }: ChatBodyProps): JSX.Element {
@@ -64,7 +70,6 @@ function ChatBody({
   const [personaName, setPersonaName] = useState<string>(
     getPersonaName(messages)
   );
-  const [showWelcomeMessage, setShowWelcomeMessage] = useState(false);
   const [sendWithShiftEnter, setSendWithShiftEnter] = useState(true);
 
   /**
@@ -103,11 +108,6 @@ function ChatBody({
     };
   }, [chatHandler]);
 
-  const openSettingsView = () => {
-    setShowWelcomeMessage(false);
-    chatViewHandler(ChatView.Settings);
-  };
-
   if (showWelcomeMessage) {
     return (
       ...
@@ ... @@ export function Chat(props: ChatProps): JSX.Element {
   const [view, setView] = useState<ChatView>(props.chatView || ChatView.Chat);
+  const [showWelcomeMessage, setShowWelcomeMessage] = useState(false);
+
+  const openSettingsView = () => {
+    setShowWelcomeMessage(false);
+    setView(ChatView.Settings);
+  };
 
   return (
@@ -216,9 +223,21 @@ export function Chat(props: ChatProps): JSX.Element {
       )}
       {view === ChatView.Chat ? (
-        <IconButton onClick={() => setView(ChatView.Settings)}>
-          <SettingsIcon />
-        </IconButton>
+        <Box sx={{ display: 'flex' }}>
+          {!showWelcomeMessage && (
+            <TooltippedIconButton
+              onClick={() =>
+                props.chatHandler.sendMessage({ type: 'clear' })
+              }
+              tooltip="New chat"
+            >
+              <AddIcon />
+            </TooltippedIconButton>
+          )}
+          <IconButton onClick={() => openSettingsView()}>
+            <SettingsIcon />
+          </IconButton>
+        </Box>
       ) : (
         ...
       )}
@@ -227,7 +246,9 @@ export function Chat(props: ChatProps): JSX.Element {
       {view === ChatView.Chat && (