support pending message draft

michaelchia committed Jun 13, 2024
1 parent a3b3ce0 commit cb6d5b5
Showing 8 changed files with 269 additions and 26 deletions.
6 changes: 4 additions & 2 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py
@@ -16,6 +16,7 @@
 Follow Up Input: {question}
 Standalone question:"""
 CONDENSE_PROMPT = PromptTemplate.from_template(PROMPT_TEMPLATE)
+PENDING_MESSAGE = "Searching learned documents"
 
 
 class AskChatHandler(BaseChatHandler):
@@ -71,8 +72,9 @@ async def process_message(self, message: HumanChatMessage):
         self.get_llm_chain()
 
         try:
-            result = await self.llm_chain.acall({"question": query})
-            response = result["answer"]
+            with self.pending(PENDING_MESSAGE):
+                result = await self.llm_chain.acall({"question": query})
+                response = result["answer"]
             self.reply(response, message)
         except AssertionError as e:
             self.log.error(e)
60 changes: 59 additions & 1 deletion packages/jupyter-ai/jupyter_ai/chat_handlers/base.py
@@ -2,6 +2,7 @@
 import os
 import time
 import traceback
+import contextlib
 from typing import (
     TYPE_CHECKING,
     Awaitable,
@@ -17,7 +18,13 @@
 
 from dask.distributed import Client as DaskClient
 from jupyter_ai.config_manager import ConfigManager, Logger
-from jupyter_ai.models import AgentChatMessage, ChatMessage, HumanChatMessage
+from jupyter_ai.models import (
+    AgentChatMessage,
+    ChatMessage,
+    HumanChatMessage,
+    PendingMessage,
+    ClosePendingMessage,
+)
 from jupyter_ai_magics import Persona
 from jupyter_ai_magics.providers import BaseProvider
 from langchain.pydantic_v1 import BaseModel
@@ -192,6 +199,57 @@ def reply(self, response: str, human_msg: Optional[HumanChatMessage] = None):
 
             handler.broadcast_message(agent_msg)
             break
+
+    def start_pending(self, text: str, ellipsis: bool = True) -> PendingMessage:
+        """
+        Sends a pending message to the client.
+        Returns the pending message object.
+        """
+        persona = self.config_manager.persona
+
+        pending_msg = PendingMessage(
+            id=uuid4().hex,
+            time=time.time(),
+            body=text,
+            persona=Persona(name=persona.name, avatar_route=persona.avatar_route),
+            ellipsis=ellipsis,
+        )
+
+        for handler in self._root_chat_handlers.values():
+            if not handler:
+                continue
+
+            handler.broadcast_message(pending_msg)
+            break
+        return pending_msg
+
+    def close_pending(self, pending_msg: PendingMessage):
+        """
+        Closes a pending message.
+        """
+        close_pending_msg = ClosePendingMessage(
+            id=pending_msg.id,
+        )
+
+        for handler in self._root_chat_handlers.values():
+            if not handler:
+                continue
+
+            handler.broadcast_message(close_pending_msg)
+            break
+
+    @contextlib.contextmanager
+    def pending(self, text: str, ellipsis: bool = True):
+        """
+        Context manager that sends a pending message to the client, and closes
+        it after the block is executed.
+        """
+        pending_msg = self.start_pending(text, ellipsis=ellipsis)
+        try:
+            yield
+        finally:
+            self.close_pending(pending_msg)
 
     def get_llm_chain(self):
         lm_provider = self.config_manager.lm_provider
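For handler authors, the new API can be used either through the `pending()` context manager or via `start_pending()`/`close_pending()` directly. A minimal sketch of both forms, assuming a `BaseChatHandler` subclass as in the handlers above; the `SlowChatHandler` name and `do_slow_work` helper are illustrative, not part of this commit, and the handler metadata the real subclasses declare is omitted:

```python
from jupyter_ai.chat_handlers.base import BaseChatHandler
from jupyter_ai.models import HumanChatMessage


class SlowChatHandler(BaseChatHandler):
    """Hypothetical handler showing both forms of the pending API."""

    async def do_slow_work(self, prompt: str) -> str:
        # stand-in for a slow LLM call
        return f"echo: {prompt}"

    async def process_message(self, message: HumanChatMessage):
        # Preferred form: the context manager closes the indicator
        # even if the wrapped call raises.
        with self.pending("Crunching numbers"):
            response = await self.do_slow_work(message.body)
        self.reply(response, message)

    async def process_message_manual(self, message: HumanChatMessage):
        # Equivalent manual form; the try/finally mirrors what pending() does.
        pending_msg = self.start_pending("Crunching numbers", ellipsis=False)
        try:
            response = await self.do_slow_work(message.body)
        finally:
            self.close_pending(pending_msg)
        self.reply(response, message)
```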
5 changes: 4 additions & 1 deletion packages/jupyter-ai/jupyter_ai/chat_handlers/default.py
@@ -7,6 +7,8 @@
 
 from .base import BaseChatHandler, SlashCommandRoutingType
 
+PENDING_MESSAGE = "Thinking"
+
 
 class DefaultChatHandler(BaseChatHandler):
     id = "default"
@@ -45,5 +47,6 @@ def create_llm_chain(
 
     async def process_message(self, message: HumanChatMessage):
         self.get_llm_chain()
-        response = await self.llm_chain.apredict(input=message.body, stop=["\nHuman:"])
+        with self.pending(PENDING_MESSAGE):
+            response = await self.llm_chain.apredict(input=message.body, stop=["\nHuman:"])
         self.reply(response, message)
26 changes: 13 additions & 13 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py
@@ -151,19 +151,19 @@ async def process_message(self, message: HumanChatMessage):
         # delete and relearn index if embedding model was changed
         await self.delete_and_relearn()
 
-        if args.verbose:
-            self.reply(f"Loading and splitting files for {load_path}", message)
-
-        try:
-            await self.learn_dir(
-                load_path, args.chunk_size, args.chunk_overlap, args.all_files
-            )
-        except Exception as e:
-            response = f"""Learn documents in **{load_path}** failed. {str(e)}."""
-        else:
-            self.save()
-            response = f"""🎉 I have learned documents at **{load_path}** and I am ready to answer questions about them.
-        You can ask questions about these docs by prefixing your message with **/ask**."""
+        # if args.verbose:
+        #     self.reply(f"Loading and splitting files for {load_path}", message)
+        with self.pending(f"Loading and splitting files for {load_path}"):
+            try:
+                await self.learn_dir(
+                    load_path, args.chunk_size, args.chunk_overlap, args.all_files
+                )
+            except Exception as e:
+                response = f"""Learn documents in **{load_path}** failed. {str(e)}."""
+            else:
+                self.save()
+                response = f"""🎉 I have learned documents at **{load_path}** and I am ready to answer questions about them.
+        You can ask questions about these docs by prefixing your message with **/ask**."""
         self.reply(response, message)
 
     def _build_list_response(self):
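Because `pending()` closes the indicator in a `finally` block, the "Loading and splitting files…" message is cleared whether the wrapped body succeeds or raises. A minimal standalone sketch of that guarantee; the `print` calls stand in for the websocket broadcasts and all names here are illustrative:

```python
import contextlib


@contextlib.contextmanager
def pending(text: str):
    # stand-in for BaseChatHandler.pending(): broadcast open, then close
    print(f"open pending: {text}")
    try:
        yield
    finally:
        print(f"close pending: {text}")


try:
    with pending("Loading and splitting files for /docs"):
        raise RuntimeError("boom")  # simulates learn_dir() failing
except RuntimeError:
    print("failure reply sent")

# Output:
#   open pending: Loading and splitting files for /docs
#   close pending: Loading and splitting files for /docs
#   failure reply sent
```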
23 changes: 22 additions & 1 deletion packages/jupyter-ai/jupyter_ai/models.py
@@ -87,13 +87,34 @@ class ClearMessage(BaseModel):
     type: Literal["clear"] = "clear"
 
 
+class PendingMessage(BaseModel):
+    type: Literal["pending"] = "pending"
+    id: str
+    time: float
+    body: str
+    persona: Persona
+    ellipsis: bool = True
+
+
+class ClosePendingMessage(BaseModel):
+    type: Literal["close-pending"] = "close-pending"
+    id: str
+
+
 # the type of messages being broadcast to clients
 ChatMessage = Union[
     AgentChatMessage,
     HumanChatMessage,
 ]
 
-Message = Union[AgentChatMessage, HumanChatMessage, ConnectionMessage, ClearMessage]
+Message = Union[
+    AgentChatMessage,
+    HumanChatMessage,
+    ConnectionMessage,
+    ClearMessage,
+    PendingMessage,
+    ClosePendingMessage,
+]
 
 
 class ChatHistory(BaseModel):
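On the wire, the two new models serialize to JSON payloads that the frontend switches on by `type`. A sketch using pydantic v1's `.json()`; the ID, timestamp, and avatar route below are made-up values (the real code generates the ID with `uuid4().hex`):

```python
from jupyter_ai.models import ClosePendingMessage, PendingMessage
from jupyter_ai_magics import Persona

pending = PendingMessage(
    id="a1b2c3d4",  # made-up ID
    time=1718236800.0,
    body="Thinking",
    persona=Persona(name="Jupyternaut", avatar_route="api/ai/static/jupyternaut.svg"),
)
print(pending.json())
# {"type": "pending", "id": "a1b2c3d4", "time": 1718236800.0, "body": "Thinking",
#  "persona": {"name": "Jupyternaut", "avatar_route": "api/ai/static/jupyternaut.svg"},
#  "ellipsis": true}

print(ClosePendingMessage(id="a1b2c3d4").json())
# {"type": "close-pending", "id": "a1b2c3d4"}
```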
29 changes: 22 additions & 7 deletions packages/jupyter-ai/src/components/chat.tsx
@@ -9,6 +9,7 @@ import { IRenderMimeRegistry } from '@jupyterlab/rendermime';
 
 import { JlThemeProvider } from './jl-theme-provider';
 import { ChatMessages } from './chat-messages';
+import { PendingMessages } from './pending-messages';
 import { ChatInput } from './chat-input';
 import { ChatSettings } from './chat-settings';
 import { AiService } from '../handler';
@@ -38,6 +39,9 @@ function ChatBody({
   rmRegistry: renderMimeRegistry
 }: ChatBodyProps): JSX.Element {
   const [messages, setMessages] = useState<AiService.ChatMessage[]>([]);
+  const [pendingMessages, setPendingMessages] = useState<
+    AiService.PendingMessage[]
+  >([]);
   const [showWelcomeMessage, setShowWelcomeMessage] = useState<boolean>(false);
   const [includeSelection, setIncludeSelection] = useState(true);
   const [replaceSelection, setReplaceSelection] = useState(false);
@@ -73,14 +77,24 @@
    */
   useEffect(() => {
     function handleChatEvents(message: AiService.Message) {
-      if (message.type === 'connection') {
-        return;
-      } else if (message.type === 'clear') {
-        setMessages([]);
-        return;
+      switch (message.type) {
+        case 'connection':
+          return;
+        case 'clear':
+          setMessages([]);
+          return;
+        case 'pending':
+          setPendingMessages(pendingMessages => [...pendingMessages, message]);
+          return;
+        case 'close-pending':
+          setPendingMessages(pendingMessages =>
+            pendingMessages.filter(p => p.id !== message.id)
+          );
+          return;
+        default:
+          setMessages(messageGroups => [...messageGroups, message]);
+          return;
       }
-
-      setMessages(messageGroups => [...messageGroups, message]);
     }
 
     chatHandler.addListener(handleChatEvents);
@@ -157,6 +171,7 @@
     <>
       <ScrollContainer sx={{ flexGrow: 1 }}>
         <ChatMessages messages={messages} rmRegistry={renderMimeRegistry} />
+        <PendingMessages messages={pendingMessages} />
       </ScrollContainer>
       <ChatInput
         value={input}
128 changes: 128 additions & 0 deletions packages/jupyter-ai/src/components/pending-messages.tsx
@@ -0,0 +1,128 @@
import React, { useState, useEffect } from 'react';

import { Box } from '@mui/material';
import { AiService } from '../handler';
import { ChatMessageHeader } from './chat-messages';

type PendingMessagesProps = {
messages: AiService.PendingMessage[];
};

type PendingMessageElementProps = {
text: string;
ellipsis: boolean;
};

type PendingMessageGroup = {
// Creating lastMessage as an AgentChatMessage
// as a hacky way to reuse ChatMessageHeader
// TODO: Refactor ChatMessageHeader to accept PendingMessage
lastMessage: AiService.AgentChatMessage;
messages: AiService.PendingMessage[];
};

function PendingMessageElement(props: PendingMessageElementProps): JSX.Element {
  // Hooks must be called unconditionally, so the non-ellipsis case is
  // handled inside the effect and the render expression rather than via
  // an early return above the hooks.
  const [dots, setDots] = useState('');

  useEffect(() => {
    if (!props.ellipsis) {
      return;
    }
    const interval = setInterval(() => {
      setDots(dots => (dots.length < 3 ? dots + '.' : ''));
    }, 500);

    return () => clearInterval(interval);
  }, [props.ellipsis]);

  return <span>{props.ellipsis ? props.text + dots : props.text}</span>;
}

export function PendingMessages(props: PendingMessagesProps): JSX.Element {
  // Hooks must run unconditionally, so the empty-state early return is
  // deferred until after the hooks below.
  const [timestamps, setTimestamps] = useState<Record<string, string>>({});
  const personaGroups = groupMessages(props.messages);
/**
* Effect: update cached timestamp strings upon receiving a new message.
*/
useEffect(() => {
const newTimestamps: Record<string, string> = { ...timestamps };
let timestampAdded = false;

for (const message of props.messages) {
if (!(message.id in newTimestamps)) {
// Use the browser's default locale
newTimestamps[message.id] = new Date(message.time * 1000) // Convert message time to milliseconds
.toLocaleTimeString([], {
hour: 'numeric', // Avoid leading zero for hours; we don't want "03:15 PM"
minute: '2-digit'
});

timestampAdded = true;
}
}
if (timestampAdded) {
setTimestamps(newTimestamps);
}
  }, [personaGroups.map(group => group.lastMessage)]);

  if (props.messages.length === 0) {
    return <></>;
  }

  return (
<Box
sx={{
borderTop: '1px solid var(--jp-border-color2)',
'& > :not(:last-child)': {
borderBottom: '1px solid var(--jp-border-color2)'
}
}}
>
{personaGroups.map((group, i) => (
<Box key={i} sx={{ padding: 4 }}>
<ChatMessageHeader
message={group.lastMessage}
timestamp={timestamps[group.lastMessage.id]}
sx={{ marginBottom: 3 }}
/>
{group.messages.map((message, j) => (
<Box key={j} sx={{ padding: 2 }}>
<PendingMessageElement
text={message.body}
ellipsis={message.ellipsis}
/>
</Box>
))}
</Box>
))}
</Box>
);
}

function groupMessages(
messages: AiService.PendingMessage[]
): PendingMessageGroup[] {
const groups: PendingMessageGroup[] = [];
const personaMap = new Map<string, AiService.PendingMessage[]>();
for (const message of messages) {
if (!personaMap.has(message.persona.name)) {
personaMap.set(message.persona.name, []);
}
personaMap.get(message.persona.name)?.push(message);
}
// create a dummy agent message for each persona group
for (const messages of personaMap.values()) {
const lastMessage = messages[messages.length - 1];
groups.push({
lastMessage: {
type: 'agent',
id: lastMessage.id,
time: lastMessage.time,
body: '',
reply_to: '',
persona: lastMessage.persona
},
messages
});
}
return groups;
}