Stream messages #3

Merged (2 commits) on Sep 19, 2024
43 changes: 23 additions & 20 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/default.py
@@ -11,6 +11,7 @@
 from langchain_core.messages import AIMessageChunk
 from langchain_core.runnables import ConfigurableFieldSpec
 from langchain_core.runnables.history import RunnableWithMessageHistory
+from langchain_core.outputs.generation import GenerationChunk

 try:
     from jupyterlab_collaborative_chat.ychat import YChat
@@ -62,27 +63,29 @@ def create_llm_chain(
         )
         self.llm_chain = runnable

-    def _start_stream(self, human_msg: HumanChatMessage) -> str:
+    def _start_stream(self, human_msg: HumanChatMessage, chat: YChat | None) -> str:
         """
         Sends an `agent-stream` message to indicate the start of a response
         stream. Returns the ID of the message, denoted as the `stream_id`.
         """
-        stream_id = uuid4().hex
-        stream_msg = AgentStreamMessage(
-            id=stream_id,
-            time=time.time(),
-            body="",
-            reply_to=human_msg.id,
-            persona=self.persona,
-            complete=False,
-        )
-
-        for handler in self._root_chat_handlers.values():
-            if not handler:
-                continue
-
-            handler.broadcast_message(stream_msg)
-            break
+        if chat is not None:
+            stream_id = self.write_message(chat, "")
+        else:
+            stream_id = uuid4().hex
+            stream_msg = AgentStreamMessage(
+                id=stream_id,
+                time=time.time(),
+                body="",
+                reply_to=human_msg.id,
+                persona=self.persona,
+                complete=False,
+            )
+            for handler in self._root_chat_handlers.values():
+                if not handler:
+                    continue
+
+                handler.broadcast_message(stream_msg)
+                break

         return stream_id

@@ -92,7 +95,7 @@ def _send_stream_chunk(self, stream_id: str, content: str, chat: YChat | None, c
         appended to an existing `agent-stream` message with ID `stream_id`.
         """
         if chat is not None:
-            self.write_message(chat, content)
+            self.write_message(chat, content, stream_id)
         else:
             stream_chunk_msg = AgentStreamChunkMessage(
                 id=stream_id, content=content, stream_complete=complete
@@ -121,11 +124,11 @@ async def process_message(self, message: HumanChatMessage, chat: YChat | None):
                     # when receiving the first chunk, close the pending message and
                     # start the stream.
                     self.close_pending(pending_message)
-                    stream_id = self._start_stream(human_msg=message)
+                    stream_id = self._start_stream(message, chat)
                     received_first_chunk = True

                 if isinstance(chunk, AIMessageChunk):
-                    self._send_stream_chunk(stream_id, chunk.content)
+                    self._send_stream_chunk(stream_id, chunk.content, chat)
                 elif isinstance(chunk, str):
                     self._send_stream_chunk(stream_id, chunk, chat)
                 else:
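Note on the default.py changes: they thread the optional YChat document through the entire streaming path, so a single id identifies the streamed reply in both backends. A minimal sketch of the resulting flow, relying on the module's imports shown above and on a hypothetical `astream_chunks` iterator in place of the real chain invocation (which is collapsed in this diff):

```python
# Sketch only: `astream_chunks` is a hypothetical stand-in for the LLM chain's
# async stream; the handler methods and signatures are the ones in the diff.
async def stream_reply(self, message: HumanChatMessage, chat: YChat | None):
    stream_id = None
    async for chunk in self.astream_chunks(message):
        if stream_id is None:
            # First chunk: create an empty streamed message and remember its id.
            # With a collaborative chat attached, the id comes from write_message()
            # (the YChat message id); otherwise it is a fresh uuid4 hex.
            stream_id = self._start_stream(message, chat)
        content = chunk.content if isinstance(chunk, AIMessageChunk) else chunk
        # Later chunks reuse the same id, so the existing message is updated
        # rather than a new one being created.
        self._send_stream_chunk(stream_id, content, chat)
```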
19 changes: 15 additions & 4 deletions packages/jupyter-ai/jupyter_ai/extension.py
@@ -229,6 +229,10 @@ def initialize(self):
             listener=self.connect_chat
         )

+        # Keep the message indexes to avoid extra computation looking for a message when
+        # updating it.
+        self.messages_indexes = {}
+
     async def connect_chat(self, logger: EventLogger, schema_id: str, data: dict) -> None:
         if data["room"].startswith("text:chat:") \
             and data["action"] == "initialize"\
@@ -299,7 +303,7 @@ async def _route(self, message: HumanChatMessage, chat: YChat):
         command_readable = "Default" if command == "default" else command
         self.log.info(f"{command_readable} chat handler resolved in {latency_ms} ms.")

-    def write_message(self, chat: YChat, body: str) -> None:
+    def write_message(self, chat: YChat, body: str, id: str | None = None) -> str:
         BOT["avatar_url"] = url_path_join(
             self.settings.get("base_url", "/"),
             "api/ai/static/jupyternaut.svg"
@@ -310,14 +314,21 @@ def write_message(self, chat: YChat, body: str) -> None:
         else:
             BOT["username"] = bot["username"]

-        chat.add_message({
+        index = self.messages_indexes[id] if id else None
+        id = id if id else str(uuid.uuid4())
+        new_index = chat.set_message({
             "type": "msg",
             "body": body,
-            "id": str(uuid.uuid4()),
+            "id": id if id else str(uuid.uuid4()),
             "time": time.time(),
             "sender": BOT["username"],
             "raw_time": False
-        })
+        }, index, True)
+
+        if new_index != index:
+            self.messages_indexes[id] = new_index
+
+        return id

     def initialize_settings(self):
         start = time.time()
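Note on the extension.py changes: write_message() is now an upsert keyed by message id, and self.messages_indexes caches each id's position in the YChat document so an update does not have to search the whole message list. A standalone sketch of that pattern, with a hypothetical FakeChat standing in for YChat (its set_message only mimics the call shape used above, including the trailing True argument):

```python
import time
import uuid

class FakeChat:
    """Hypothetical stand-in for YChat; mimics the set_message call shape only."""

    def __init__(self):
        self._messages = []

    def set_message(self, message, index=None, append_if_missing=False):
        # Update in place when a known index is given, otherwise append;
        # always return the index where the message now lives.
        if index is not None and 0 <= index < len(self._messages):
            self._messages[index] = message
            return index
        self._messages.append(message)
        return len(self._messages) - 1

messages_indexes = {}  # message id -> index in the chat document

def write_message(chat, body, id=None):
    index = messages_indexes[id] if id else None
    id = id or str(uuid.uuid4())
    new_index = chat.set_message(
        {
            "type": "msg",
            "body": body,
            "id": id,
            "time": time.time(),
            "sender": "jupyternaut",  # the real code uses BOT["username"]
            "raw_time": False,
        },
        index,
        True,
    )
    if new_index != index:
        messages_indexes[id] = new_index
    return id

chat = FakeChat()
msg_id = write_message(chat, "")             # first chunk: creates the message
write_message(chat, "Hello", id=msg_id)      # later chunks: update it in place
assert len(chat._messages) == 1
```

This is why _start_stream() above returns the id of the first, empty message: every subsequent chunk is written to the same document entry instead of appending a new one.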
2 changes: 1 addition & 1 deletion packages/jupyter-ai/package.json
@@ -61,7 +61,7 @@
   "dependencies": {
     "@emotion/react": "^11.10.5",
     "@emotion/styled": "^11.10.5",
-    "@jupyter/chat": "^0.3.1",
+    "@jupyter/chat": "^0.4.0",
     "@jupyter/collaboration": "^1",
     "@jupyterlab/application": "^4.2.0",
     "@jupyterlab/apputils": "^4.2.0",
3 changes: 1 addition & 2 deletions packages/jupyter-ai/pyproject.toml
@@ -33,6 +33,7 @@ dependencies = [
    "typing_extensions>=4.5.0",
    "traitlets>=5.0",
    "deepmerge>=1.0",
+    "jupyterlab-collaborative-chat",
 ]

 dynamic = ["version", "description", "authors", "urls", "keywords"]
@@ -56,8 +57,6 @@ dev = ["jupyter_ai_magics[dev]"]

 all = ["jupyter_ai_magics[all]", "pypdf", "arxiv"]

-collaborative = ["jupyterlab-collaborative-chat"]
-
 [tool.hatch.version]
 source = "nodejs"

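With jupyterlab-collaborative-chat promoted from the collaborative extra to a required dependency, the guarded import at the top of default.py should normally succeed. A sketch of that guard; the except branch is an assumption, since that part of the file is collapsed in the diff:

```python
try:
    from jupyterlab_collaborative_chat.ychat import YChat
except ImportError:
    # Assumed fallback; the actual handling is not shown in this diff.
    YChat = None
```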
10 changes: 5 additions & 5 deletions yarn.lock
@@ -2220,7 +2220,7 @@ __metadata:
     "@babel/preset-env": ^7.0.0
     "@emotion/react": ^11.10.5
     "@emotion/styled": ^11.10.5
-    "@jupyter/chat": ^0.3.1
+    "@jupyter/chat": ^0.4.0
     "@jupyter/collaboration": ^1
     "@jupyterlab/application": ^4.2.0
     "@jupyterlab/apputils": ^4.2.0
@@ -2286,9 +2286,9 @@
   languageName: unknown
   linkType: soft

-"@jupyter/chat@npm:^0.3.1":
-  version: 0.3.1
-  resolution: "@jupyter/chat@npm:0.3.1"
+"@jupyter/chat@npm:^0.4.0":
+  version: 0.4.0
+  resolution: "@jupyter/chat@npm:0.4.0"
   dependencies:
     "@emotion/react": ^11.10.5
     "@emotion/styled": ^11.10.5
@@ -2306,7 +2306,7 @@
     clsx: ^2.1.0
     react: ^18.2.0
     react-dom: ^18.2.0
-  checksum: 92d1f6d6d3083be2a8c2309fc37eb3bb39bc184e1c41d18f79d51f87e54d7badba1495c5943916e7a21fc65414d6100860b044a14554add8192b545fca2748dd
+  checksum: 6e309c8e70cf480103eb26f3109da417c58d2e6844d5e56e63feabf71926f9dba6f9bc85caff765dfc574a8fd7ed803a8c03e5d812c28568dcf6ec918bbd2e66
   languageName: node
   linkType: hard
