From 40ab189c22e90eed36e6622ad079a64a6d033a04 Mon Sep 17 00:00:00 2001
From: michaelchia
Date: Wed, 19 Jun 2024 00:08:33 +0800
Subject: [PATCH] Support pending/loading message while waiting for response
 (#821)

* support pending message draft

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* styling + pending message for /fix

* change default pending message

* remove persona groups

* inline styling

* single timestamp

* use message id as component key

Co-authored-by: david qiu

* fix conditional useEffect

* prefer MUI Typography in PendingMessageElement to match font size

* merge 2 outer div elements into 1

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: david qiu
---
 .../jupyter_ai/chat_handlers/ask.py           |   5 +-
 .../jupyter_ai/chat_handlers/base.py          |  60 ++++++++-
 .../jupyter_ai/chat_handlers/default.py       |   5 +-
 .../jupyter_ai/chat_handlers/fix.py           |  17 +--
 .../jupyter_ai/chat_handlers/learn.py         |  24 ++--
 packages/jupyter-ai/jupyter_ai/models.py      |  23 +++-
 packages/jupyter-ai/src/components/chat.tsx   |  29 +++--
 .../src/components/pending-messages.tsx       | 115 ++++++++++++++++++
 packages/jupyter-ai/src/handler.ts            |  18 ++-
 9 files changed, 262 insertions(+), 34 deletions(-)
 create mode 100644 packages/jupyter-ai/src/components/pending-messages.tsx

diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py
index d34007a00..79736f2cb 100644
--- a/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py
+++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py
@@ -71,8 +71,9 @@ async def process_message(self, message: HumanChatMessage):
         self.get_llm_chain()

         try:
-            result = await self.llm_chain.acall({"question": query})
-            response = result["answer"]
+            with self.pending("Searching learned documents"):
+                result = await self.llm_chain.acall({"question": query})
+                response = result["answer"]
             self.reply(response, message)
         except AssertionError as e:
             self.log.error(e)
diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py
index 83636776e..97392168a 100644
--- a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py
+++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py
@@ -1,4 +1,5 @@
 import argparse
+import contextlib
 import os
 import time
 import traceback
@@ -17,7 +18,13 @@
 from dask.distributed import Client as DaskClient
 from jupyter_ai.config_manager import ConfigManager, Logger
-from jupyter_ai.models import AgentChatMessage, ChatMessage, HumanChatMessage
+from jupyter_ai.models import (
+    AgentChatMessage,
+    ChatMessage,
+    ClosePendingMessage,
+    HumanChatMessage,
+    PendingMessage,
+)
 from jupyter_ai_magics import Persona
 from jupyter_ai_magics.providers import BaseProvider
 from langchain.pydantic_v1 import BaseModel
@@ -193,6 +200,57 @@ def reply(self, response: str, human_msg: Optional[HumanChatMessage] = None):
                 handler.broadcast_message(agent_msg)
                 break

+    def start_pending(self, text: str, ellipsis: bool = True) -> PendingMessage:
+        """
+        Sends a pending message to the client.
+
+        Returns the pending message.
+ """ + persona = self.config_manager.persona + + pending_msg = PendingMessage( + id=uuid4().hex, + time=time.time(), + body=text, + persona=Persona(name=persona.name, avatar_route=persona.avatar_route), + ellipsis=ellipsis, + ) + + for handler in self._root_chat_handlers.values(): + if not handler: + continue + + handler.broadcast_message(pending_msg) + break + return pending_msg + + def close_pending(self, pending_msg: PendingMessage): + """ + Closes a pending message. + """ + close_pending_msg = ClosePendingMessage( + id=pending_msg.id, + ) + + for handler in self._root_chat_handlers.values(): + if not handler: + continue + + handler.broadcast_message(close_pending_msg) + break + + @contextlib.contextmanager + def pending(self, text: str, ellipsis: bool = True): + """ + Context manager that sends a pending message to the client, and closes + it after the block is executed. + """ + pending_msg = self.start_pending(text, ellipsis=ellipsis) + try: + yield + finally: + self.close_pending(pending_msg) + def get_llm_chain(self): lm_provider = self.config_manager.lm_provider lm_provider_params = self.config_manager.lm_provider_params diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py index 3f936a142..75c5e6023 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py @@ -45,5 +45,8 @@ def create_llm_chain( async def process_message(self, message: HumanChatMessage): self.get_llm_chain() - response = await self.llm_chain.apredict(input=message.body, stop=["\nHuman:"]) + with self.pending("Generating response"): + response = await self.llm_chain.apredict( + input=message.body, stop=["\nHuman:"] + ) self.reply(response, message) diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/fix.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/fix.py index 0f62e5681..f8c9f6f6b 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/fix.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/fix.py @@ -92,12 +92,13 @@ async def process_message(self, message: HumanChatMessage): extra_instructions = message.body[4:].strip() or "None." self.get_llm_chain() - response = await self.llm_chain.apredict( - extra_instructions=extra_instructions, - stop=["\nHuman:"], - cell_content=selection.source, - error_name=selection.error.name, - error_value=selection.error.value, - traceback="\n".join(selection.error.traceback), - ) + with self.pending("Analyzing error"): + response = await self.llm_chain.apredict( + extra_instructions=extra_instructions, + stop=["\nHuman:"], + cell_content=selection.source, + error_name=selection.error.name, + error_value=selection.error.value, + traceback="\n".join(selection.error.traceback), + ) self.reply(response, message) diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py index 0f10b0147..e8ca6bddc 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py @@ -151,19 +151,17 @@ async def process_message(self, message: HumanChatMessage): # delete and relearn index if embedding model was changed await self.delete_and_relearn() - if args.verbose: - self.reply(f"Loading and splitting files for {load_path}", message) - - try: - await self.learn_dir( - load_path, args.chunk_size, args.chunk_overlap, args.all_files - ) - except Exception as e: - response = f"""Learn documents in **{load_path}** failed. 
{str(e)}.""" - else: - self.save() - response = f"""🎉 I have learned documents at **{load_path}** and I am ready to answer questions about them. - You can ask questions about these docs by prefixing your message with **/ask**.""" + with self.pending(f"Loading and splitting files for {load_path}"): + try: + await self.learn_dir( + load_path, args.chunk_size, args.chunk_overlap, args.all_files + ) + except Exception as e: + response = f"""Learn documents in **{load_path}** failed. {str(e)}.""" + else: + self.save() + response = f"""🎉 I have learned documents at **{load_path}** and I am ready to answer questions about them. + You can ask questions about these docs by prefixing your message with **/ask**.""" self.reply(response, message) def _build_list_response(self): diff --git a/packages/jupyter-ai/jupyter_ai/models.py b/packages/jupyter-ai/jupyter_ai/models.py index 84e2a524b..9e269a223 100644 --- a/packages/jupyter-ai/jupyter_ai/models.py +++ b/packages/jupyter-ai/jupyter_ai/models.py @@ -87,13 +87,34 @@ class ClearMessage(BaseModel): type: Literal["clear"] = "clear" +class PendingMessage(BaseModel): + type: Literal["pending"] = "pending" + id: str + time: float + body: str + persona: Persona + ellipsis: bool = True + + +class ClosePendingMessage(BaseModel): + type: Literal["pending"] = "close-pending" + id: str + + # the type of messages being broadcast to clients ChatMessage = Union[ AgentChatMessage, HumanChatMessage, ] -Message = Union[AgentChatMessage, HumanChatMessage, ConnectionMessage, ClearMessage] +Message = Union[ + AgentChatMessage, + HumanChatMessage, + ConnectionMessage, + ClearMessage, + PendingMessage, + ClosePendingMessage, +] class ChatHistory(BaseModel): diff --git a/packages/jupyter-ai/src/components/chat.tsx b/packages/jupyter-ai/src/components/chat.tsx index abf974054..4c232ed5b 100644 --- a/packages/jupyter-ai/src/components/chat.tsx +++ b/packages/jupyter-ai/src/components/chat.tsx @@ -9,6 +9,7 @@ import { IRenderMimeRegistry } from '@jupyterlab/rendermime'; import { JlThemeProvider } from './jl-theme-provider'; import { ChatMessages } from './chat-messages'; +import { PendingMessages } from './pending-messages'; import { ChatInput } from './chat-input'; import { ChatSettings } from './chat-settings'; import { AiService } from '../handler'; @@ -38,6 +39,9 @@ function ChatBody({ rmRegistry: renderMimeRegistry }: ChatBodyProps): JSX.Element { const [messages, setMessages] = useState([]); + const [pendingMessages, setPendingMessages] = useState< + AiService.PendingMessage[] + >([]); const [showWelcomeMessage, setShowWelcomeMessage] = useState(false); const [includeSelection, setIncludeSelection] = useState(true); const [replaceSelection, setReplaceSelection] = useState(false); @@ -73,14 +77,24 @@ function ChatBody({ */ useEffect(() => { function handleChatEvents(message: AiService.Message) { - if (message.type === 'connection') { - return; - } else if (message.type === 'clear') { - setMessages([]); - return; + switch (message.type) { + case 'connection': + return; + case 'clear': + setMessages([]); + return; + case 'pending': + setPendingMessages(pendingMessages => [...pendingMessages, message]); + return; + case 'close-pending': + setPendingMessages(pendingMessages => + pendingMessages.filter(p => p.id !== message.id) + ); + return; + default: + setMessages(messageGroups => [...messageGroups, message]); + return; } - - setMessages(messageGroups => [...messageGroups, message]); } chatHandler.addListener(handleChatEvents); @@ -157,6 +171,7 @@ function ChatBody({ <> + { 
+    const interval = setInterval(() => {
+      setDots(dots => (dots.length < 3 ? dots + '.' : ''));
+    }, 500);
+
+    return () => clearInterval(interval);
+  }, []);
+
+  let text = props.text;
+  if (props.ellipsis) {
+    text = props.text + dots;
+  }
+
+  return (
+    <Typography>
+      {text.split('\n').map((line, index) => (
+        <span key={index}>
+          {line}
+          <br />
+        </span>
+      ))}
+    </Typography>
+  );
+}
+
+export function PendingMessages(
+  props: PendingMessagesProps
+): JSX.Element | null {
+  const [timestamp, setTimestamp] = useState('');
+  const [agentMessage, setAgentMessage] =
+    useState<AiService.AgentChatMessage | null>(null);
+
+  useEffect(() => {
+    if (props.messages.length === 0) {
+      setAgentMessage(null);
+      setTimestamp('');
+      return;
+    }
+    const lastMessage = props.messages[props.messages.length - 1];
+    setAgentMessage({
+      type: 'agent',
+      id: lastMessage.id,
+      time: lastMessage.time,
+      body: '',
+      reply_to: '',
+      persona: lastMessage.persona
+    });
+
+    // timestamp format copied from ChatMessage
+    const newTimestamp = new Date(lastMessage.time * 1000).toLocaleTimeString(
+      [],
+      {
+        hour: 'numeric',
+        minute: '2-digit'
+      }
+    );
+    setTimestamp(newTimestamp);
+  }, [props.messages]);
+
+  if (!agentMessage) {
+    return null;
+  }
+
+  return (
+    <Box sx={{ padding: 4 }}>
+      <ChatMessageHeader
+        message={agentMessage}
+        timestamp={timestamp}
+        sx={{ marginBottom: 3 }}
+      />
+      <Box
+        sx={{
+          '& > :not(:last-child)': {
+            marginBottom: '2em'
+          }
+        }}
+      >
+        {props.messages.map(message => (
+          <PendingMessageElement
+            key={message.id}
+            text={message.body}
+            ellipsis={message.ellipsis}
+          />
+        ))}
+      </Box>
+    </Box>
+  );
+}
diff --git a/packages/jupyter-ai/src/handler.ts b/packages/jupyter-ai/src/handler.ts
index 5d06691fe..c8f457fe2 100644
--- a/packages/jupyter-ai/src/handler.ts
+++ b/packages/jupyter-ai/src/handler.ts
@@ -115,12 +115,28 @@ export namespace AiService {
     type: 'clear';
   };

+  export type PendingMessage = {
+    type: 'pending';
+    id: string;
+    time: number;
+    body: string;
+    persona: Persona;
+    ellipsis: boolean;
+  };
+
+  export type ClosePendingMessage = {
+    type: 'close-pending';
+    id: string;
+  };
+
   export type ChatMessage = AgentChatMessage | HumanChatMessage;

   export type Message =
     | AgentChatMessage
     | HumanChatMessage
     | ConnectionMessage
-    | ClearMessage;
+    | ClearMessage
+    | PendingMessage
+    | ClosePendingMessage;

   export type ChatHistory = {
     messages: ChatMessage[];
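Usage note (not part of the patch): the sketch below shows how a downstream
chat handler might adopt the pending-message API added to BaseChatHandler in
this patch. `MyHandler`, its prompt strings, and `process_message_manual` are
hypothetical; `pending()`, `start_pending()`, `close_pending()`,
`get_llm_chain()`, and `reply()` are the methods defined or already present
in base.py above.

    from jupyter_ai.chat_handlers.base import BaseChatHandler
    from jupyter_ai.models import HumanChatMessage


    class MyHandler(BaseChatHandler):
        """Hypothetical handler demonstrating the pending-message API."""

        async def process_message(self, message: HumanChatMessage):
            self.get_llm_chain()

            # Preferred form: the context manager broadcasts a PendingMessage
            # on entry and the matching ClosePendingMessage on exit, even if
            # the LLM call raises.
            with self.pending("Thinking"):
                response = await self.llm_chain.apredict(input=message.body)
            self.reply(response, message)

        async def process_message_manual(self, message: HumanChatMessage):
            self.get_llm_chain()

            # Equivalent manual form, for when the open and close sites are
            # in different scopes.
            pending_msg = self.start_pending("Thinking", ellipsis=True)
            try:
                response = await self.llm_chain.apredict(input=message.body)
            finally:
                self.close_pending(pending_msg)
            self.reply(response, message)

Because ClosePendingMessage carries only the id of the PendingMessage it
closes, the frontend can keep pending messages in a flat list and filter on
'close-pending', which is exactly what the switch statement added to chat.tsx
does.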