
Commit

improve chat history handling
- ensures users never miss streamed chunks when joining: the full chat history, including any pending messages, is now sent in the `ConnectionMessage` on connect (sketched below)
- also removes temporary print/log statements introduced in the previous commit
dlqqq committed Jun 25, 2024
1 parent f80e80e commit 92d07b8
Showing 8 changed files with 199 additions and 122 deletions.
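
The first bullet works by embedding the chat history in the connection handshake: `RootChatHandler.open()` (see `handlers.py` below) now replies with a `ConnectionMessage` that carries a `ChatHistory`, including any still-open pending messages, and the frontend seeds its local state from it. A minimal sketch of the new payload, assuming it runs inside this package so the imports resolve; the `client_id` value is a placeholder:

from jupyter_ai.models import ChatHistory, ConnectionMessage

# On connect, the server now replays the full history, including any still-open
# pending messages, so a late-joining client misses no streamed chunks.
connection_msg = ConnectionMessage(
    client_id="abc123",  # placeholder; the real ID is generated per websocket client
    history=ChatHistory(messages=[], pending_messages=[]),
)
print(connection_msg.json())
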
2 changes: 1 addition & 1 deletion packages/jupyter-ai/jupyter_ai/chat_handlers/base.py
@@ -202,7 +202,7 @@ def reply(self, response: str, human_msg: Optional[HumanChatMessage] = None):
def persona(self):
return self.config_manager.persona

def start_pending(self, text: str, ellipsis: bool = True) -> str:
def start_pending(self, text: str, ellipsis: bool = True) -> PendingMessage:
"""
Sends a pending message to the client.
36 changes: 25 additions & 11 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/default.py
@@ -4,12 +4,12 @@

from jupyter_ai.models import HumanChatMessage, AgentStreamMessage, AgentStreamChunkMessage
from jupyter_ai_magics.providers import BaseProvider
from langchain.chains import ConversationChain, LLMChain
from langchain.memory import ConversationBufferWindowMemory
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.messages import AIMessageChunk

from .base import BaseChatHandler, SlashCommandRoutingType
from ..history import BoundedChatHistory
from langchain_core.runnables.history import RunnableWithMessageHistory


class DefaultChatHandler(BaseChatHandler):
@@ -64,6 +64,7 @@ def _start_stream(self, human_msg: HumanChatMessage) -> str:
body="",
reply_to=human_msg.id,
persona=self.persona,
complete=False
)

for handler in self._root_chat_handlers.values():
@@ -96,16 +97,29 @@ def _send_stream_chunk(self, stream_id: str, content: str, complete: bool = Fals

async def process_message(self, message: HumanChatMessage):
self.get_llm_chain()
received_first_chunk = False

stream_id = self._start_stream(human_msg=message)
# start with a pending message
pending_message = self.start_pending("Generating response")

# stream response in chunks. this works even if a provider does not
# implement streaming, as `astream()` defaults to yielding `_call()`
# when `_stream()` is not implemented on the LLM class.
async for chunk in self.llm_chain.astream({ "input": message.body }, config={"configurable": {"session_id": "static_session"}}):
self.log.error(chunk.content)
self._send_stream_chunk(stream_id, chunk.content)
if not received_first_chunk:
# when receiving the first chunk, close the pending message and
# start the stream.
self.close_pending(pending_message)
stream_id = self._start_stream(human_msg=message)
received_first_chunk = True

if isinstance(chunk, AIMessageChunk):
self._send_stream_chunk(stream_id, chunk.content)
elif isinstance(chunk, str):
self._send_stream_chunk(stream_id, chunk)
else:
self.log.error(f"Unrecognized type of chunk yielded: {type(chunk)}")
break

# complete stream after all chunks have been streamed
self._send_stream_chunk(stream_id, "", complete=True)

# with self.pending("Generating response"):
# response = await self.llm_chain.apredict(
# input=message.body, stop=["\nHuman:"]
# )
# self.reply(response, message)
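
A note on the inline comment above about `astream()`: in LangChain, an LLM that implements only `_call()` still supports `astream()`; it falls back to a single invocation and yields the whole completion as one chunk, a plain `str`, which is why the `str` branch above exists (chat models yield `AIMessageChunk` instead). A rough, self-contained sketch of that fallback, using a hypothetical `EchoLLM` that is not part of this repository:

import asyncio
from typing import Any, List, Optional

from langchain_core.language_models.llms import LLM


class EchoLLM(LLM):
    """Toy LLM that only implements `_call()`; no `_stream()` override."""

    @property
    def _llm_type(self) -> str:
        return "echo"

    def _call(self, prompt: str, stop: Optional[List[str]] = None, **kwargs: Any) -> str:
        return f"echo: {prompt}"


async def main() -> None:
    llm = EchoLLM()
    # `astream()` falls back to `_call()` and yields the full response as one chunk.
    async for chunk in llm.astream("hello"):
        print(chunk)  # prints "echo: hello" once


asyncio.run(main())
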
3 changes: 3 additions & 0 deletions packages/jupyter-ai/jupyter_ai/extension.py
@@ -222,6 +222,9 @@ def initialize_settings(self):
# memory object used by the LM chain.
self.settings["chat_history"] = []

# list of pending messages
self.settings["pending_messages"] = []

# get reference to event loop
# `asyncio.get_event_loop()` is deprecated in Python 3.11+, in favor of
# the more readable `asyncio.get_event_loop_policy().get_event_loop()`.
59 changes: 48 additions & 11 deletions packages/jupyter-ai/jupyter_ai/handlers.py
@@ -17,6 +17,8 @@

from .models import (
AgentChatMessage,
AgentStreamMessage,
AgentStreamChunkMessage,
ChatClient,
ChatHistory,
ChatMessage,
@@ -29,6 +31,8 @@
ListSlashCommandsEntry,
ListSlashCommandsResponse,
Message,
PendingMessage,
ClosePendingMessage,
UpdateConfigRequest,
)

@@ -43,16 +47,19 @@ class ChatHistoryHandler(BaseAPIHandler):
_messages = []

@property
def chat_history(self):
def chat_history(self) -> List[ChatMessage]:
return self.settings["chat_history"]

@chat_history.setter
def _chat_history_setter(self, new_history):
self.settings["chat_history"] = new_history
@property
def pending_messages(self) -> List[PendingMessage]:
return self.settings["pending_messages"]

@tornado.web.authenticated
async def get(self):
history = ChatHistory(messages=self.chat_history)
history = ChatHistory(
messages=self.chat_history,
pending_messages=self.pending_messages
)
self.finish(history.json())


@@ -88,10 +95,22 @@ def chat_client(self) -> ChatClient:
def chat_history(self) -> List[ChatMessage]:
return self.settings["chat_history"]

@chat_history.setter
def chat_history(self, new_history):
self.settings["chat_history"] = new_history

@property
def loop(self) -> AbstractEventLoop:
return self.settings["jai_event_loop"]

@property
def pending_messages(self) -> List[PendingMessage]:
return self.settings["pending_messages"]

@pending_messages.setter
def pending_messages(self, new_pending_messages):
self.settings["pending_messages"] = new_pending_messages

def initialize(self):
self.log.debug("Initializing websocket connection %s", self.request.path)

@@ -167,7 +186,10 @@ def open(self):
self.root_chat_handlers[client_id] = self
self.chat_clients[client_id] = ChatClient(**current_user, id=client_id)
self.client_id = client_id
self.write_message(ConnectionMessage(client_id=client_id).dict())
self.write_message(ConnectionMessage(
client_id=client_id,
history=ChatHistory(messages=self.chat_history, pending_messages=self.pending_messages)
).dict())

self.log.info(f"Client connected. ID: {client_id}")
self.log.debug("Clients are : %s", self.root_chat_handlers.keys())
@@ -185,11 +207,26 @@ def broadcast_message(self, message: Message):
if client:
client.write_message(message.dict())

# Only append ChatMessage instances to history, not control messages
if isinstance(message, HumanChatMessage) or isinstance(
message, AgentChatMessage
):
# append all messages of type `ChatMessage` directly to the chat history
if isinstance(message, (HumanChatMessage, AgentChatMessage, AgentStreamMessage)):
self.chat_history.append(message)
elif isinstance(message, AgentStreamChunkMessage):
# for stream chunks, modify the corresponding `AgentStreamMessage`
# by appending its content and potentially marking it as complete.
chunk: AgentStreamChunkMessage = message

# iterate backwards from the end of the list
for i in range(len(self.chat_history) - 1, -1, -1):
if self.chat_history[i].type == 'agent-stream' and self.chat_history[i].id == chunk.id:
stream_message: AgentStreamMessage = self.chat_history[i]
stream_message.body += chunk.content
stream_message.complete = chunk.stream_complete
break
elif isinstance(message, PendingMessage):
self.pending_messages.append(message)
elif isinstance(message, ClosePendingMessage):
self.pending_messages = list(filter(lambda m: m.id != message.id, self.pending_messages))


async def on_message(self, message):
self.log.debug("Message received: %s", message)
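
To make the `broadcast_message()` bookkeeping above concrete: each `AgentStreamChunkMessage` is folded into the matching `agent-stream` entry already stored in the history, searched from the end of the list. A minimal standalone sketch using plain dicts rather than the actual pydantic models:

# toy history containing one in-progress stream message
history = [{"id": "msg-1", "type": "agent-stream", "body": "", "complete": False}]

def apply_chunk(history, chunk_id, content, stream_complete=False):
    # iterate backwards, since the target stream message is usually the most recent
    for i in range(len(history) - 1, -1, -1):
        msg = history[i]
        if msg["type"] == "agent-stream" and msg["id"] == chunk_id:
            msg["body"] += content
            msg["complete"] = stream_complete
            break

apply_chunk(history, "msg-1", "Hello")
apply_chunk(history, "msg-1", ", world", stream_complete=True)
print(history[0])  # body == "Hello, world", complete == True
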
26 changes: 15 additions & 11 deletions packages/jupyter-ai/jupyter_ai/models.py
@@ -70,6 +70,7 @@ class AgentChatMessage(BaseModel):

class AgentStreamMessage(AgentChatMessage):
type: Literal['agent-stream'] = 'agent-stream'
complete: bool
# other attrs inherited from `AgentChatMessage`

class AgentStreamChunkMessage(BaseModel):
@@ -89,11 +90,6 @@ class HumanChatMessage(BaseModel):
selection: Optional[Selection]


class ConnectionMessage(BaseModel):
type: Literal["connection"] = "connection"
client_id: str


class ClearMessage(BaseModel):
type: Literal["clear"] = "clear"

@@ -116,8 +112,22 @@ class ClosePendingMessage(BaseModel):
ChatMessage = Union[
AgentChatMessage,
HumanChatMessage,
AgentStreamMessage,
]


class ChatHistory(BaseModel):
"""History of chat messages"""
messages: List[ChatMessage]
pending_messages: List[PendingMessage]


class ConnectionMessage(BaseModel):
type: Literal["connection"] = "connection"
client_id: str
history: ChatHistory


Message = Union[
AgentChatMessage,
HumanChatMessage,
@@ -128,12 +138,6 @@ class ClosePendingMessage(BaseModel):
]


class ChatHistory(BaseModel):
"""History of chat messages"""

messages: List[ChatMessage]


class ListProvidersEntry(BaseModel):
"""Model provider with supported models
and provider's authentication strategy
111 changes: 90 additions & 21 deletions packages/jupyter-ai/src/chat_handler.ts
@@ -1,7 +1,9 @@
import { IDisposable } from '@lumino/disposable';
import { ServerConnection } from '@jupyterlab/services';
import { URLExt } from '@jupyterlab/coreutils';
import { AiService, requestAPI } from './handler';
import { Signal } from '@lumino/signaling';

import { AiService } from './handler';

const CHAT_SERVICE_URL = 'api/ai/chats';

@@ -65,18 +67,6 @@ export class ChatHandler implements IDisposable {
}
}

public async getHistory(): Promise<AiService.ChatHistory> {
let data: AiService.ChatHistory = { messages: [] };
try {
data = await requestAPI('chats/history', {
method: 'GET'
});
} catch (e) {
return Promise.reject(e);
}
return data;
}

/**
* Whether the chat handler is disposed.
*/
@@ -106,23 +96,84 @@ export class ChatHandler implements IDisposable {
}
}

private _onMessage(message: AiService.Message): void {
get history(): AiService.ChatHistory {
return {
messages: this._messages,
pending_messages: this._pendingMessages
};
}

get historyChanged(): Signal<this, AiService.ChatHistory> {
return this._historyChanged;
}

private _onMessage(newMessage: AiService.Message): void {
// resolve promise from `sendMessage()`
if (message.type === 'human' && message.client.id === this.id) {
this._sendResolverQueue.shift()?.(message.id);
if (newMessage.type === 'human' && newMessage.client.id === this.id) {
this._sendResolverQueue.shift()?.(newMessage.id);
}

// resolve promise from `replyFor()` if it exists
if (
message.type === 'agent' &&
message.reply_to in this._replyForResolverDict
newMessage.type === 'agent' &&
newMessage.reply_to in this._replyForResolverDict
) {
this._replyForResolverDict[message.reply_to](message);
delete this._replyForResolverDict[message.reply_to];
this._replyForResolverDict[newMessage.reply_to](newMessage);
delete this._replyForResolverDict[newMessage.reply_to];
}

// call listeners in serial
this._listeners.forEach(listener => listener(message));
this._listeners.forEach(listener => listener(newMessage));

// append message to chat history. this block should always set `_messages`
// or `_pendingMessages` to a new array instance rather than modifying
// in-place so consumer React components re-render.
switch (newMessage.type) {
case 'connection':
break;
case 'clear':
this._messages = [];
break;
case 'pending':
this._pendingMessages = [...this._pendingMessages, newMessage];
break;
case 'close-pending':
this._pendingMessages = this._pendingMessages.filter(
p => p.id !== newMessage.id
);
break;
case 'agent-stream-chunk': {
const target = newMessage.id;
const streamMessage = this._messages.find<AiService.AgentStreamMessage>(
(m): m is AiService.AgentStreamMessage =>
m.type === 'agent-stream' && m.id === target
);
if (!streamMessage) {
console.error(
`Received stream chunk with ID ${target}, but no agent-stream message with that ID exists. ` +
'Ignoring this stream chunk.'
);
break;
}

streamMessage.body += newMessage.content;
if (newMessage.stream_complete) {
streamMessage.complete = true;
}
this._messages = [...this._messages];
break;
}
default:
// human or agent chat message
this._messages = [...this._messages, newMessage];
break;
}

// finally, trigger `historyChanged` signal
this._historyChanged.emit({
messages: this._messages,
pending_messages: this._pendingMessages
});
}

/**
@@ -173,6 +224,11 @@ export class ChatHandler implements IDisposable {
return;
}
this.id = message.client_id;

// initialize chat history from `ConnectionMessage`
this._messages = message.history.messages;
this._pendingMessages = message.history.pending_messages;

resolve();
this.removeListener(listenForConnection);
};
@@ -184,4 +240,17 @@ export class ChatHandler implements IDisposable {
private _isDisposed = false;
private _socket: WebSocket | null = null;
private _listeners: ((msg: any) => void)[] = [];

/**
* The list of chat messages
*/
private _messages: AiService.ChatMessage[] = [];
private _pendingMessages: AiService.PendingMessage[] = [];

/**
* Signal for when the chat history is changed. Components rendering the chat
* history should subscribe to this signal and update their state when this
* signal is triggered.
*/
private _historyChanged = new Signal<this, AiService.ChatHistory>(this);
}
