From 92d07b858553ae9e38016e6f06c0df80a67ca20b Mon Sep 17 00:00:00 2001
From: "David L. Qiu"
Date: Tue, 25 Jun 2024 15:49:39 -0700
Subject: [PATCH] improve chat history handling

- ensures users never miss streamed chunks when joining
- also removes temporary print/log statements introduced in the previous commit
---
 .../jupyter_ai/chat_handlers/base.py        |   2 +-
 .../jupyter_ai/chat_handlers/default.py     |  36 ++++--
 packages/jupyter-ai/jupyter_ai/extension.py |   3 +
 packages/jupyter-ai/jupyter_ai/handlers.py  |  59 ++++++++--
 packages/jupyter-ai/jupyter_ai/models.py    |  26 ++--
 packages/jupyter-ai/src/chat_handler.ts     | 111 ++++++++++++++----
 packages/jupyter-ai/src/components/chat.tsx |  73 +++---------
 packages/jupyter-ai/src/handler.ts          |  11 +-
 8 files changed, 199 insertions(+), 122 deletions(-)

diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py
index 5509f7924..6bee9ec20 100644
--- a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py
+++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py
@@ -202,7 +202,7 @@ def reply(self, response: str, human_msg: Optional[HumanChatMessage] = None):
     def persona(self):
         return self.config_manager.persona
 
-    def start_pending(self, text: str, ellipsis: bool = True) -> str:
+    def start_pending(self, text: str, ellipsis: bool = True) -> PendingMessage:
         """
         Sends a pending message to the client.
 
diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py
index 346b4708c..055df1afb 100644
--- a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py
+++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py
@@ -4,12 +4,12 @@
 from jupyter_ai.models import HumanChatMessage, AgentStreamMessage, AgentStreamChunkMessage
 from jupyter_ai_magics.providers import BaseProvider
-from langchain.chains import ConversationChain, LLMChain
 from langchain.memory import ConversationBufferWindowMemory
+from langchain_core.runnables.history import RunnableWithMessageHistory
+from langchain_core.messages import AIMessageChunk
 
 from .base import BaseChatHandler, SlashCommandRoutingType
 from ..history import BoundedChatHistory
-from langchain_core.runnables.history import RunnableWithMessageHistory
 
 
 class DefaultChatHandler(BaseChatHandler):
@@ -64,6 +64,7 @@ def _start_stream(self, human_msg: HumanChatMessage) -> str:
             body="",
             reply_to=human_msg.id,
             persona=self.persona,
+            complete=False
         )
 
         for handler in self._root_chat_handlers.values():
@@ -96,16 +97,29 @@ def _send_stream_chunk(self, stream_id: str, content: str, complete: bool = False):
     async def process_message(self, message: HumanChatMessage):
         self.get_llm_chain()
+        received_first_chunk = False
 
-        stream_id = self._start_stream(human_msg=message)
+        # start with a pending message
+        pending_message = self.start_pending("Generating response")
+
+        # stream response in chunks. this works even if a provider does not
+        # implement streaming, as `astream()` defaults to yielding `_call()`
+        # when `_stream()` is not implemented on the LLM class.
         async for chunk in self.llm_chain.astream({ "input": message.body }, config={"configurable": {"session_id": "static_session"}}):
-            self.log.error(chunk.content)
-            self._send_stream_chunk(stream_id, chunk.content)
+            if not received_first_chunk:
+                # when receiving the first chunk, close the pending message and
+                # start the stream.
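+                # `close_pending()` also broadcasts a `ClosePendingMessage`,
+                # which removes the entry from `settings["pending_messages"]`,
+                # so late-joining clients never see a stale pending indicator.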
+ self.close_pending(pending_message) + stream_id = self._start_stream(human_msg=message) + received_first_chunk = True + + if isinstance(chunk, AIMessageChunk): + self._send_stream_chunk(stream_id, chunk.content) + elif isinstance(chunk, str): + self._send_stream_chunk(stream_id, chunk) + else: + self.log.error(f"Unrecognized type of chunk yielded: {type(chunk)}") + break + # complete stream after all chunks have been streamed self._send_stream_chunk(stream_id, "", complete=True) - - # with self.pending("Generating response"): - # response = await self.llm_chain.apredict( - # input=message.body, stop=["\nHuman:"] - # ) - # self.reply(response, message) diff --git a/packages/jupyter-ai/jupyter_ai/extension.py b/packages/jupyter-ai/jupyter_ai/extension.py index 848a6eed7..f97ff9ee9 100644 --- a/packages/jupyter-ai/jupyter_ai/extension.py +++ b/packages/jupyter-ai/jupyter_ai/extension.py @@ -222,6 +222,9 @@ def initialize_settings(self): # memory object used by the LM chain. self.settings["chat_history"] = [] + # list of pending messages + self.settings["pending_messages"] = [] + # get reference to event loop # `asyncio.get_event_loop()` is deprecated in Python 3.11+, in favor of # the more readable `asyncio.get_event_loop_policy().get_event_loop()`. diff --git a/packages/jupyter-ai/jupyter_ai/handlers.py b/packages/jupyter-ai/jupyter_ai/handlers.py index 339863c3f..6ba3560f5 100644 --- a/packages/jupyter-ai/jupyter_ai/handlers.py +++ b/packages/jupyter-ai/jupyter_ai/handlers.py @@ -17,6 +17,8 @@ from .models import ( AgentChatMessage, + AgentStreamMessage, + AgentStreamChunkMessage, ChatClient, ChatHistory, ChatMessage, @@ -29,6 +31,8 @@ ListSlashCommandsEntry, ListSlashCommandsResponse, Message, + PendingMessage, + ClosePendingMessage, UpdateConfigRequest, ) @@ -43,16 +47,19 @@ class ChatHistoryHandler(BaseAPIHandler): _messages = [] @property - def chat_history(self): + def chat_history(self) -> List[ChatMessage]: return self.settings["chat_history"] - - @chat_history.setter - def _chat_history_setter(self, new_history): - self.settings["chat_history"] = new_history + + @property + def pending_messages(self) -> List[PendingMessage]: + return self.settings["pending_messages"] @tornado.web.authenticated async def get(self): - history = ChatHistory(messages=self.chat_history) + history = ChatHistory( + messages=self.chat_history, + pending_messages=self.pending_messages + ) self.finish(history.json()) @@ -88,10 +95,22 @@ def chat_client(self) -> ChatClient: def chat_history(self) -> List[ChatMessage]: return self.settings["chat_history"] + @chat_history.setter + def chat_history(self, new_history): + self.settings["chat_history"] = new_history + @property def loop(self) -> AbstractEventLoop: return self.settings["jai_event_loop"] + @property + def pending_messages(self) -> List[PendingMessage]: + return self.settings["pending_messages"] + + @pending_messages.setter + def pending_messages(self, new_pending_messages): + self.settings["pending_messages"] = new_pending_messages + def initialize(self): self.log.debug("Initializing websocket connection %s", self.request.path) @@ -167,7 +186,10 @@ def open(self): self.root_chat_handlers[client_id] = self self.chat_clients[client_id] = ChatClient(**current_user, id=client_id) self.client_id = client_id - self.write_message(ConnectionMessage(client_id=client_id).dict()) + self.write_message(ConnectionMessage( + client_id=client_id, + history=ChatHistory(messages=self.chat_history, pending_messages=self.pending_messages) + ).dict()) 
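+        # sending the history inside `ConnectionMessage` (rather than having
+        # the client fetch it over REST after connecting) closes the race in
+        # which chunks streamed between that fetch and the socket opening
+        # were dropped.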
         self.log.info(f"Client connected. ID: {client_id}")
         self.log.debug("Clients are : %s", self.root_chat_handlers.keys())
 
@@ -185,11 +207,26 @@ def broadcast_message(self, message: Message):
             if client:
                 client.write_message(message.dict())
 
-        # Only append ChatMessage instances to history, not control messages
-        if isinstance(message, HumanChatMessage) or isinstance(
-            message, AgentChatMessage
-        ):
+        # append all messages of type `ChatMessage` directly to the chat history
+        if isinstance(message, (HumanChatMessage, AgentChatMessage, AgentStreamMessage)):
             self.chat_history.append(message)
+        elif isinstance(message, AgentStreamChunkMessage):
+            # for stream chunks, modify the corresponding `AgentStreamMessage`
+            # by appending its content and potentially marking it as complete.
+            chunk: AgentStreamChunkMessage = message
+
+            # iterate backwards from the end of the list
+            for i in range(len(self.chat_history) - 1, -1, -1):
+                if self.chat_history[i].type == 'agent-stream' and self.chat_history[i].id == chunk.id:
+                    stream_message: AgentStreamMessage = self.chat_history[i]
+                    stream_message.body += chunk.content
+                    stream_message.complete = chunk.stream_complete
+                    break
+        elif isinstance(message, PendingMessage):
+            self.pending_messages.append(message)
+        elif isinstance(message, ClosePendingMessage):
+            self.pending_messages = list(filter(lambda m: m.id != message.id, self.pending_messages))
+
     async def on_message(self, message):
         self.log.debug("Message received: %s", message)
 
diff --git a/packages/jupyter-ai/jupyter_ai/models.py b/packages/jupyter-ai/jupyter_ai/models.py
index 37cc35058..d59451d76 100644
--- a/packages/jupyter-ai/jupyter_ai/models.py
+++ b/packages/jupyter-ai/jupyter_ai/models.py
@@ -70,6 +70,7 @@ class AgentChatMessage(BaseModel):
 
 class AgentStreamMessage(AgentChatMessage):
     type: Literal['agent-stream'] = 'agent-stream'
+    complete: bool
     # other attrs inherited from `AgentChatMessage`
 
 class AgentStreamChunkMessage(BaseModel):
@@ -89,11 +90,6 @@ class HumanChatMessage(BaseModel):
     selection: Optional[Selection]
 
 
-class ConnectionMessage(BaseModel):
-    type: Literal["connection"] = "connection"
-    client_id: str
-
-
 class ClearMessage(BaseModel):
     type: Literal["clear"] = "clear"
 
@@ -116,8 +112,22 @@ class ClosePendingMessage(BaseModel):
 ChatMessage = Union[
     AgentChatMessage,
     HumanChatMessage,
+    AgentStreamMessage,
 ]
 
+
+class ChatHistory(BaseModel):
+    """History of chat messages"""
+    messages: List[ChatMessage]
+    pending_messages: List[PendingMessage]
+
+
+class ConnectionMessage(BaseModel):
+    type: Literal["connection"] = "connection"
+    client_id: str
+    history: ChatHistory
+
+
 Message = Union[
     AgentChatMessage,
     HumanChatMessage,
@@ -128,12 +138,6 @@ class ClosePendingMessage(BaseModel):
 ]
 
 
-class ChatHistory(BaseModel):
-    """History of chat messages"""
-
-    messages: List[ChatMessage]
-
-
 class ListProvidersEntry(BaseModel):
     """Model provider with supported models and provider's authentication strategy
 
diff --git a/packages/jupyter-ai/src/chat_handler.ts b/packages/jupyter-ai/src/chat_handler.ts
index fb54df519..f1b131dcf 100644
--- a/packages/jupyter-ai/src/chat_handler.ts
+++ b/packages/jupyter-ai/src/chat_handler.ts
@@ -1,7 +1,9 @@
 import { IDisposable } from '@lumino/disposable';
 import { ServerConnection } from '@jupyterlab/services';
 import { URLExt } from '@jupyterlab/coreutils';
-import { AiService, requestAPI } from './handler';
+import { Signal } from '@lumino/signaling';
+
+import { AiService } from './handler';
 
 const CHAT_SERVICE_URL = 'api/ai/chats';
 
@@ -65,18 +67,6 @@ export class ChatHandler implements IDisposable {
     }
   }
 
-  public async getHistory(): Promise<AiService.ChatHistory> {
-    let data: AiService.ChatHistory = { messages: [] };
-    try {
-      data = await requestAPI<AiService.ChatHistory>('chats/history', {
-        method: 'GET'
-      });
-    } catch (e) {
-      return Promise.reject(e);
-    }
-    return data;
-  }
-
   /**
    * Whether the chat handler is disposed.
    */
@@ -106,23 +96,84 @@ export class ChatHandler implements IDisposable {
     }
   }
 
-  private _onMessage(message: AiService.Message): void {
+  get history(): AiService.ChatHistory {
+    return {
+      messages: this._messages,
+      pending_messages: this._pendingMessages
+    };
+  }
+
+  get historyChanged(): Signal<ChatHandler, AiService.ChatHistory> {
+    return this._historyChanged;
+  }
+
+  private _onMessage(newMessage: AiService.Message): void {
     // resolve promise from `sendMessage()`
-    if (message.type === 'human' && message.client.id === this.id) {
-      this._sendResolverQueue.shift()?.(message.id);
+    if (newMessage.type === 'human' && newMessage.client.id === this.id) {
+      this._sendResolverQueue.shift()?.(newMessage.id);
     }
 
     // resolve promise from `replyFor()` if it exists
     if (
-      message.type === 'agent' &&
-      message.reply_to in this._replyForResolverDict
+      newMessage.type === 'agent' &&
+      newMessage.reply_to in this._replyForResolverDict
     ) {
-      this._replyForResolverDict[message.reply_to](message);
-      delete this._replyForResolverDict[message.reply_to];
+      this._replyForResolverDict[newMessage.reply_to](newMessage);
+      delete this._replyForResolverDict[newMessage.reply_to];
     }
 
     // call listeners in serial
-    this._listeners.forEach(listener => listener(message));
+    this._listeners.forEach(listener => listener(newMessage));
+
+    // append message to chat history. this block should always set `_messages`
+    // or `_pendingMessages` to a new array instance rather than modifying them
+    // in place, so that consumer React components re-render.
+    switch (newMessage.type) {
+      case 'connection':
+        break;
+      case 'clear':
+        this._messages = [];
+        break;
+      case 'pending':
+        this._pendingMessages = [...this._pendingMessages, newMessage];
+        break;
+      case 'close-pending':
+        this._pendingMessages = this._pendingMessages.filter(
+          p => p.id !== newMessage.id
+        );
+        break;
+      case 'agent-stream-chunk': {
+        const target = newMessage.id;
+        const streamMessage = this._messages.find(
+          (m): m is AiService.AgentStreamMessage =>
+            m.type === 'agent-stream' && m.id === target
+        );
+        if (!streamMessage) {
+          console.error(
+            `Received stream chunk with ID ${target}, but no agent-stream message with that ID exists. ` +
+              'Ignoring this stream chunk.'
+ ); + break; + } + + streamMessage.body += newMessage.content; + if (newMessage.stream_complete) { + streamMessage.complete = true; + } + this._messages = [...this._messages]; + break; + } + default: + // human or agent chat message + this._messages = [...this._messages, newMessage]; + break; + } + + // finally, trigger `historyChanged` signal + this._historyChanged.emit({ + messages: this._messages, + pending_messages: this._pendingMessages + }); } /** @@ -173,6 +224,11 @@ export class ChatHandler implements IDisposable { return; } this.id = message.client_id; + + // initialize chat history from `ConnectionMessage` + this._messages = message.history.messages; + this._pendingMessages = message.history.pending_messages; + resolve(); this.removeListener(listenForConnection); }; @@ -184,4 +240,17 @@ export class ChatHandler implements IDisposable { private _isDisposed = false; private _socket: WebSocket | null = null; private _listeners: ((msg: any) => void)[] = []; + + /** + * The list of chat messages + */ + private _messages: AiService.ChatMessage[] = []; + private _pendingMessages: AiService.PendingMessage[] = []; + + /** + * Signal for when the chat history is changed. Components rendering the chat + * history should subscribe to this signal and update their state when this + * signal is triggered. + */ + private _historyChanged = new Signal(this); } diff --git a/packages/jupyter-ai/src/components/chat.tsx b/packages/jupyter-ai/src/components/chat.tsx index be186aee2..c84ae022b 100644 --- a/packages/jupyter-ai/src/components/chat.tsx +++ b/packages/jupyter-ai/src/components/chat.tsx @@ -38,10 +38,12 @@ function ChatBody({ setChatView: chatViewHandler, rmRegistry: renderMimeRegistry }: ChatBodyProps): JSX.Element { - const [messages, setMessages] = useState([]); + const [messages, setMessages] = useState([ + ...chatHandler.history.messages + ]); const [pendingMessages, setPendingMessages] = useState< AiService.PendingMessage[] - >([]); + >([...chatHandler.history.pending_messages]); const [showWelcomeMessage, setShowWelcomeMessage] = useState(false); const [includeSelection, setIncludeSelection] = useState(true); const [replaceSelection, setReplaceSelection] = useState(false); @@ -50,17 +52,13 @@ function ChatBody({ const [sendWithShiftEnter, setSendWithShiftEnter] = useState(true); /** - * Effect: fetch history and config on initial render + * Effect: fetch config on initial render */ useEffect(() => { - async function fetchHistory() { + async function fetchConfig() { try { - const [history, config] = await Promise.all([ - chatHandler.getHistory(), - AiService.getConfig() - ]); + const config = await AiService.getConfig(); setSendWithShiftEnter(config.send_with_shift_enter ?? 
-        setMessages(history.messages);
         if (!config.model_provider_id) {
           setShowWelcomeMessage(true);
         }
@@ -69,65 +67,22 @@
       }
     }
 
-    fetchHistory();
+    fetchConfig();
   }, [chatHandler]);
 
   /**
    * Effect: listen to chat messages
    */
   useEffect(() => {
-    function handleChatEvents(newMessage: AiService.Message) {
-      switch (newMessage.type) {
-        case 'connection':
-          return;
-        case 'clear':
-          setMessages([]);
-          return;
-        case 'pending':
-          setPendingMessages(pendingMessages => [
-            ...pendingMessages,
-            newMessage
-          ]);
-          return;
-        case 'close-pending':
-          setPendingMessages(pendingMessages =>
-            pendingMessages.filter(p => p.id !== newMessage.id)
-          );
-          return;
-        case 'agent-stream-chunk':
-          setMessages(prevMessages => {
-            const target = newMessage.id;
-            const streamMessage =
-              prevMessages.find(
-                (m): m is AiService.AgentStreamMessage =>
-                  m.type === 'agent-stream' && m.id === target
-              );
-            if (!streamMessage) {
-              console.error(
-                `Received stream chunk with ID ${target}, but no agent-stream message with that ID exists. ` +
-                  'Ignoring this stream chunk.'
-              );
-              return prevMessages;
-            }
-
-            streamMessage.body += newMessage.content;
-            if (newMessage.stream_complete) {
-              console.log('COMPLETE SET');
-              streamMessage.complete = true;
-            }
-            return [...prevMessages];
-          });
-          return;
-        default:
-          // human or agent chat message
-          setMessages(prevMessages => [...prevMessages, newMessage]);
-          return;
-      }
+    function onHistoryChange(_: unknown, history: AiService.ChatHistory) {
+      setMessages([...history.messages]);
+      setPendingMessages([...history.pending_messages]);
     }
 
-    chatHandler.addListener(handleChatEvents);
+    chatHandler.historyChanged.connect(onHistoryChange);
+
     return function cleanup() {
-      chatHandler.removeListener(handleChatEvents);
+      chatHandler.historyChanged.disconnect(onHistoryChange);
     };
   }, [chatHandler]);
 
diff --git a/packages/jupyter-ai/src/handler.ts b/packages/jupyter-ai/src/handler.ts
index deda9d3c0..b93a4571a 100644
--- a/packages/jupyter-ai/src/handler.ts
+++ b/packages/jupyter-ai/src/handler.ts
@@ -109,6 +109,7 @@ export namespace AiService {
   export type ConnectionMessage = {
     type: 'connection';
     client_id: string;
+    history: ChatHistory;
   };
 
   export type ClearMessage = {
@@ -131,14 +132,7 @@ export namespace AiService {
 
   export type AgentStreamMessage = Omit<AgentChatMessage, 'type'> & {
     type: 'agent-stream';
-
-    /**
-     * This field only exists in the frontend model to indicate whether the
-     * stream is complete. When an `AgentStreamChunkMessage` is received with
-     * `stream_complete=True`, this property is set to `True` to indicate to
-     * React components that the stream is complete.
-     */
-    complete?: boolean;
+    complete: boolean;
   };
 
   export type AgentStreamChunkMessage = {
@@ -165,6 +159,7 @@ export namespace AiService {
 
   export type ChatHistory = {
     messages: ChatMessage[];
+    pending_messages: PendingMessage[];
   };
 
   export type DescribeConfigResponse = {
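
Reviewer note: the sketch below shows how a frontend consumer other than
`ChatBody` would use the new `history` getter and `historyChanged` signal.
It is a minimal illustration only; `logHistory` and the `chatHandler`
declaration are hypothetical and not part of this patch:

    import { ChatHandler } from './chat_handler';
    import { AiService } from './handler';

    // assume an already-connected handler, e.g. the one constructed in the
    // extension's activate() function
    declare const chatHandler: ChatHandler;

    // `_onMessage()` always replaces `messages` / `pending_messages` with
    // new array instances, so consumers may rely on reference equality to
    // detect changes (as React state setters do).
    function logHistory(_: unknown, history: AiService.ChatHistory): void {
      console.log(`${history.messages.length} message(s)`);
      console.log(`${history.pending_messages.length} pending message(s)`);
    }

    // read the current snapshot once, then subscribe to updates
    logHistory(chatHandler, chatHandler.history);
    chatHandler.historyChanged.connect(logHistory);

    // later, when the consumer is disposed:
    chatHandler.historyChanged.disconnect(logHistory);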