From cb6d5b589c6d6e1aac0a45a937c93d58eced529d Mon Sep 17 00:00:00 2001
From: michael
Date: Fri, 7 Jun 2024 18:51:58 +0800
Subject: [PATCH] support pending message draft
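
Chat handlers can now broadcast a transient "pending" message while a
long-running step executes, and close it once the reply is ready. A
minimal sketch of the intended usage, assuming a hypothetical custom
handler (`MyHandler` and its chain setup are illustrative only; the
`pending()` context manager is the API added by this patch):

    class MyHandler(BaseChatHandler):
        async def process_message(self, message: HumanChatMessage):
            self.get_llm_chain()
            # Shows "Working on it..." in the chat until the block exits,
            # even if the LLM call raises.
            with self.pending("Working on it"):
                response = await self.llm_chain.apredict(input=message.body)
            self.reply(response, message)
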
+ """ + close_pending_msg = ClosePendingMessage( + id=pending_msg.id, + ) + + for handler in self._root_chat_handlers.values(): + if not handler: + continue + + handler.broadcast_message(close_pending_msg) + break + + @contextlib.contextmanager + def pending(self, text: str, ellipsis: bool = True): + """ + Context manager that sends a pending message to the client, and closes + it after the block is executed. + """ + pending_msg = self.start_pending(text, ellipsis=ellipsis) + try: + yield + finally: + self.close_pending(pending_msg) def get_llm_chain(self): lm_provider = self.config_manager.lm_provider diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py index 3f936a142..a09660837 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py @@ -7,6 +7,8 @@ from .base import BaseChatHandler, SlashCommandRoutingType +PENDING_MESSAGE = "Thinking" + class DefaultChatHandler(BaseChatHandler): id = "default" @@ -45,5 +47,6 @@ def create_llm_chain( async def process_message(self, message: HumanChatMessage): self.get_llm_chain() - response = await self.llm_chain.apredict(input=message.body, stop=["\nHuman:"]) + with self.pending(PENDING_MESSAGE): + response = await self.llm_chain.apredict(input=message.body, stop=["\nHuman:"]) self.reply(response, message) diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py index 0f10b0147..e7ebc7394 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py @@ -151,19 +151,19 @@ async def process_message(self, message: HumanChatMessage): # delete and relearn index if embedding model was changed await self.delete_and_relearn() - if args.verbose: - self.reply(f"Loading and splitting files for {load_path}", message) - - try: - await self.learn_dir( - load_path, args.chunk_size, args.chunk_overlap, args.all_files - ) - except Exception as e: - response = f"""Learn documents in **{load_path}** failed. {str(e)}.""" - else: - self.save() - response = f"""🎉 I have learned documents at **{load_path}** and I am ready to answer questions about them. - You can ask questions about these docs by prefixing your message with **/ask**.""" + # if args.verbose: + # self.reply(f"Loading and splitting files for {load_path}", message) + with self.pending(f"Loading and splitting files for {load_path}"): + try: + await self.learn_dir( + load_path, args.chunk_size, args.chunk_overlap, args.all_files + ) + except Exception as e: + response = f"""Learn documents in **{load_path}** failed. {str(e)}.""" + else: + self.save() + response = f"""🎉 I have learned documents at **{load_path}** and I am ready to answer questions about them. 
diff --git a/packages/jupyter-ai/jupyter_ai/models.py b/packages/jupyter-ai/jupyter_ai/models.py
index 84e2a524b..9e269a223 100644
--- a/packages/jupyter-ai/jupyter_ai/models.py
+++ b/packages/jupyter-ai/jupyter_ai/models.py
@@ -87,13 +87,34 @@ class ClearMessage(BaseModel):
     type: Literal["clear"] = "clear"
 
 
+class PendingMessage(BaseModel):
+    type: Literal["pending"] = "pending"
+    id: str
+    time: float
+    body: str
+    persona: Persona
+    ellipsis: bool = True
+
+
+class ClosePendingMessage(BaseModel):
+    type: Literal["close-pending"] = "close-pending"
+    id: str
+
+
 # the type of messages being broadcast to clients
 ChatMessage = Union[
     AgentChatMessage,
     HumanChatMessage,
 ]
 
-Message = Union[AgentChatMessage, HumanChatMessage, ConnectionMessage, ClearMessage]
+Message = Union[
+    AgentChatMessage,
+    HumanChatMessage,
+    ConnectionMessage,
+    ClearMessage,
+    PendingMessage,
+    ClosePendingMessage,
+]
 
 
 class ChatHistory(BaseModel):
diff --git a/packages/jupyter-ai/src/components/chat.tsx b/packages/jupyter-ai/src/components/chat.tsx
index abf974054..4c232ed5b 100644
--- a/packages/jupyter-ai/src/components/chat.tsx
+++ b/packages/jupyter-ai/src/components/chat.tsx
@@ -9,6 +9,7 @@ import { IRenderMimeRegistry } from '@jupyterlab/rendermime';
 
 import { JlThemeProvider } from './jl-theme-provider';
 import { ChatMessages } from './chat-messages';
+import { PendingMessages } from './pending-messages';
 import { ChatInput } from './chat-input';
 import { ChatSettings } from './chat-settings';
 import { AiService } from '../handler';
@@ -38,6 +39,9 @@ function ChatBody({
   rmRegistry: renderMimeRegistry
 }: ChatBodyProps): JSX.Element {
   const [messages, setMessages] = useState<AiService.ChatMessage[]>([]);
+  const [pendingMessages, setPendingMessages] = useState<
+    AiService.PendingMessage[]
+  >([]);
   const [showWelcomeMessage, setShowWelcomeMessage] = useState(false);
   const [includeSelection, setIncludeSelection] = useState(true);
   const [replaceSelection, setReplaceSelection] = useState(false);
@@ -73,14 +77,24 @@ function ChatBody({
    */
   useEffect(() => {
     function handleChatEvents(message: AiService.Message) {
-      if (message.type === 'connection') {
-        return;
-      } else if (message.type === 'clear') {
-        setMessages([]);
-        return;
+      switch (message.type) {
+        case 'connection':
+          return;
+        case 'clear':
+          setMessages([]);
+          return;
+        case 'pending':
+          setPendingMessages(pendingMessages => [...pendingMessages, message]);
+          return;
+        case 'close-pending':
+          setPendingMessages(pendingMessages =>
+            pendingMessages.filter(p => p.id !== message.id)
+          );
+          return;
+        default:
+          setMessages(messageGroups => [...messageGroups, message]);
+          return;
       }
-      }
-
-      setMessages(messageGroups => [...messageGroups, message]);
     }
 
     chatHandler.addListener(handleChatEvents);
@@ -157,6 +171,7 @@ function ChatBody({
   return (
     <>
       <ChatMessages messages={messages} rmRegistry={renderMimeRegistry} />
+      <PendingMessages messages={pendingMessages} />
       <ChatInput
         value={input}
         onChange={setInput}
diff --git a/packages/jupyter-ai/src/components/pending-messages.tsx b/packages/jupyter-ai/src/components/pending-messages.tsx
new file mode 100644
--- /dev/null
+++ b/packages/jupyter-ai/src/components/pending-messages.tsx
@@ -0,0 +1,128 @@
+import React, { useState, useEffect } from 'react';
+
+import { Box } from '@mui/material';
+
+import { AiService } from '../handler';
+import { ChatMessageHeader } from './chat-messages';
+
+type PendingMessagesProps = {
+  messages: AiService.PendingMessage[];
+};
+
+type PendingMessageElementProps = {
+  text: string;
+  ellipsis: boolean;
+};
+
+type PendingMessageGroup = {
+  lastMessage: AiService.AgentChatMessage;
+  messages: AiService.PendingMessage[];
+};
+
+function PendingMessageElement(props: PendingMessageElementProps): JSX.Element {
+  const [dots, setDots] = useState('');
+
+  useEffect(() => {
+    const interval = setInterval(() => {
+      setDots(dots => (dots.length < 3 ? dots + '.' : ''));
+    }, 500);
+
+    return () => clearInterval(interval);
+  }, []);
+
+  // Hooks must run unconditionally, so the ellipsis check comes after them.
+  if (!props.ellipsis) {
+    return <span>{props.text}</span>;
+  }
+  return <span>{props.text + dots}</span>;
+}
+
+export function PendingMessages(props: PendingMessagesProps): JSX.Element {
+  const [timestamps, setTimestamps] = useState<Record<string, string>>({});
+  const personaGroups = groupMessages(props.messages);
+
+  /**
+   * Effect: update cached timestamp strings upon receiving a new message.
+   */
+  useEffect(() => {
+    const newTimestamps: Record<string, string> = { ...timestamps };
+    let timestampAdded = false;
+
+    for (const message of props.messages) {
+      if (!(message.id in newTimestamps)) {
+        // Use the browser's default locale
+        newTimestamps[message.id] = new Date(message.time * 1000) // Convert message time to milliseconds
+          .toLocaleTimeString([], {
+            hour: 'numeric', // Avoid leading zero for hours; we don't want "03:15 PM"
+            minute: '2-digit'
+          });
+
+        timestampAdded = true;
+      }
+    }
+    if (timestampAdded) {
+      setTimestamps(newTimestamps);
+    }
+  }, [props.messages]);
+
+  if (props.messages.length === 0) {
+    return <></>;
+  }
+
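+  // Render one section per persona: the persona header appears once,
+  // followed by each of that persona's pending messages. Sections are
+  // separated by a bottom border.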
+ */ + useEffect(() => { + const newTimestamps: Record = { ...timestamps }; + let timestampAdded = false; + + for (const message of props.messages) { + if (!(message.id in newTimestamps)) { + // Use the browser's default locale + newTimestamps[message.id] = new Date(message.time * 1000) // Convert message time to milliseconds + .toLocaleTimeString([], { + hour: 'numeric', // Avoid leading zero for hours; we don't want "03:15 PM" + minute: '2-digit' + }); + + timestampAdded = true; + } + } + if (timestampAdded) { + setTimestamps(newTimestamps); + } + }, [personaGroups.map(group => group.lastMessage)]); + + return ( + :not(:last-child)': { + borderBottom: '1px solid var(--jp-border-color2)' + } + }} + > + {personaGroups.map((group, i) => ( + + + {group.messages.map((message, j) => ( + + + + ))} + + ))} + + ); +} + +function groupMessages( + messages: AiService.PendingMessage[] +): PendingMessageGroup[] { + const groups: PendingMessageGroup[] = []; + const personaMap = new Map(); + for (const message of messages) { + if (!personaMap.has(message.persona.name)) { + personaMap.set(message.persona.name, []); + } + personaMap.get(message.persona.name)?.push(message); + } + // create a dummy agent message for each persona group + for (const messages of personaMap.values()) { + const lastMessage = messages[messages.length - 1]; + groups.push({ + lastMessage: { + type: 'agent', + id: lastMessage.id, + time: lastMessage.time, + body: '', + reply_to: '', + persona: lastMessage.persona + }, + messages + }); + } + return groups; +} diff --git a/packages/jupyter-ai/src/handler.ts b/packages/jupyter-ai/src/handler.ts index 5d06691fe..c8f457fe2 100644 --- a/packages/jupyter-ai/src/handler.ts +++ b/packages/jupyter-ai/src/handler.ts @@ -115,12 +115,28 @@ export namespace AiService { type: 'clear'; }; + export type PendingMessage = { + type: 'pending'; + id: string; + time: number; + body: string; + persona: Persona; + ellipsis: boolean; + }; + + export type ClosePendingMessage = { + type: 'close-pending'; + id: string; + }; + export type ChatMessage = AgentChatMessage | HumanChatMessage; export type Message = | AgentChatMessage | HumanChatMessage | ConnectionMessage - | ClearMessage; + | ClearMessage + | PendingMessage + | ClosePendingMessage; export type ChatHistory = { messages: ChatMessage[];