Very first version of the AI working in jupyterlab_collaborative_chat #984

Closed
wants to merge 1 commit
5 changes: 5 additions & 0 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/__init__.py
@@ -1,3 +1,8 @@
# The following import makes sure jupyter_ydoc is imported before
# jupyterlab_collaborative_chat; otherwise a circular import occurs, because
# YChat relies on YBaseDoc while jupyter_ydoc registers YChat from its entry point.
import jupyter_ydoc

from .ask import AskChatHandler
from .base import BaseChatHandler, SlashCommandRoutingType
from .clear import ClearChatHandler
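The comment in this hunk describes an entry-point ordering fix: `jupyter_ydoc` must be imported first so that the entry point registering `YChat` is resolved before `jupyterlab_collaborative_chat` itself is imported, breaking the `YChat`/`YBaseDoc` cycle. As a rough illustration of the general pattern (module names invented, not the real packages):

```python
# bad.py -- triggers the cycle: importing the plugin package first makes the
# registry import the plugin again before it has finished executing.
#
#     from plugin_pkg.plugin import Plugin   # ImportError (partial module)
#
# good.py -- the workaround used in this PR: import the registry package
# first so its entry points are fully resolved, then import the plugin.
import importlib


def import_in_safe_order(registry_pkg: str, plugin_module: str):
    importlib.import_module(registry_pkg)          # resolve entry points first
    return importlib.import_module(plugin_module)  # no longer circular
```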
11 changes: 6 additions & 5 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py
@@ -1,5 +1,6 @@
import argparse
from typing import Dict, Type
from jupyterlab_collaborative_chat.ychat import YChat

from jupyter_ai.models import HumanChatMessage
from jupyter_ai_magics.providers import BaseProvider
@@ -59,13 +60,13 @@ def create_llm_chain(
verbose=False,
)

async def process_message(self, message: HumanChatMessage):
args = self.parse_args(message)
async def process_message(self, message: HumanChatMessage, chat: YChat):
args = self.parse_args(message, chat)
if args is None:
return
query = " ".join(args.query)
if not query:
self.reply(f"{self.parser.format_usage()}", message)
self.reply(f"{self.parser.format_usage()}", chat, message)
return

self.get_llm_chain()
@@ -74,12 +75,12 @@ async def process_message(self, message: HumanChatMessage):
with self.pending("Searching learned documents"):
result = await self.llm_chain.acall({"question": query})
response = result["answer"]
self.reply(response, message)
self.reply(response, chat, message)
except AssertionError as e:
self.log.error(e)
response = """Sorry, an error occurred while reading from the learned documents.
If you have changed the embedding provider, try deleting the existing index by running
the `/learn -d` command and then re-submitting `/learn <directory>` to learn the documents,
and then asking the question again.
"""
self.reply(response, message)
self.reply(response, chat, message)
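For context, the `/ask` flow above splits the message body, parses the remainder with argparse, and replies with the usage string into the passed-in `chat` when the query is empty. A self-contained sketch of that guard (the list-based `chat` is a stand-in for the real `YChat`):

```python
import argparse

parser = argparse.ArgumentParser(prog="/ask", add_help=False)
parser.add_argument("query", nargs="*")


def handle_ask(body: str, chat: list) -> None:
    # message.body looks like "/ask why is the sky blue"
    args = parser.parse_args(body.split(" ")[1:])
    query = " ".join(args.query)
    if not query:
        chat.append(parser.format_usage())          # reply(usage, chat, message)
        return
    chat.append(f"(RAG answer for: {query})")       # reply(response, chat, message)


chat: list = []
handle_ask("/ask", chat)                      # empty query -> usage message
handle_ask("/ask why is the sky blue", chat)  # -> answer
```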
51 changes: 30 additions & 21 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/base.py
@@ -6,6 +6,7 @@
from typing import (
TYPE_CHECKING,
Awaitable,
Callable,
ClassVar,
Dict,
List,
@@ -15,6 +16,7 @@
Union,
)
from uuid import uuid4
from jupyterlab_collaborative_chat.ychat import YChat

from dask.distributed import Client as DaskClient
from jupyter_ai.config_manager import ConfigManager, Logger
@@ -132,6 +134,7 @@ def __init__(
dask_client_future: Awaitable[DaskClient],
help_message_template: str,
chat_handlers: Dict[str, "BaseChatHandler"],
write_message: Callable[[YChat, str], None]
):
self.log = log
self.config_manager = config_manager
@@ -157,13 +160,16 @@ def __init__(
self.llm_params = None
self.llm_chain = None

async def on_message(self, message: HumanChatMessage):
self.write_message = write_message

async def on_message(self, message: HumanChatMessage, chat: YChat):
"""
Method which receives a human message, calls `self.get_llm_chain()`, and
processes the message via `self.process_message()`, calling
`self.handle_exc()` when an exception is raised. This method is called
by RootChatHandler when it routes a human message to this chat handler.
"""
self.log.warn(f"MESSAGE SENT {message.body}")
lm_provider_klass = self.config_manager.lm_provider

# ensure the current slash command is supported
@@ -173,7 +179,8 @@ async def on_message(self, message: HumanChatMessage):
)
if slash_command in lm_provider_klass.unsupported_slash_commands:
self.reply(
"Sorry, the selected language model does not support this slash command."
"Sorry, the selected language model does not support this slash command.",
chat
)
return

@@ -185,50 +192,51 @@ async def on_message(self, message: HumanChatMessage, chat: YChat):
if not lm_provider.allows_concurrency:
self.reply(
"The currently selected language model can process only one request at a time. Please wait for me to reply before sending another question.",
chat,
message,
)
return

BaseChatHandler._requests_count += 1

if self.__class__.supports_help:
args = self.parse_args(message, silent=True)
args = self.parse_args(message, chat, silent=True)
if args and args.help:
self.reply(self.parser.format_help(), message)
self.reply(self.parser.format_help(), chat, message)
return

try:
await self.process_message(message)
await self.process_message(message, chat)
except Exception as e:
try:
# we try/except `handle_exc()` in case it was overridden and
# raises an exception by accident.
await self.handle_exc(e, message)
await self.handle_exc(e, message, chat)
except Exception as e:
await self._default_handle_exc(e, message)
finally:
BaseChatHandler._requests_count -= 1

async def process_message(self, message: HumanChatMessage):
async def process_message(self, message: HumanChatMessage, chat: YChat):
"""
Processes a human message routed to this chat handler. Chat handlers
(subclasses) must implement this method. Don't forget to call
`self.reply(<response>, message)` at the end!
`self.reply(<response>, chat, message)` at the end!

The method definition does not need to be wrapped in a try/except block;
any exceptions raised here are caught by `self.handle_exc()`.
"""
raise NotImplementedError("Should be implemented by subclasses.")

async def handle_exc(self, e: Exception, message: HumanChatMessage):
async def handle_exc(self, e: Exception, message: HumanChatMessage, chat: YChat):
"""
Handles an exception raised by `self.process_message()`. A default
implementation is provided, however chat handlers (subclasses) should
implement this method to provide a more helpful error response.
"""
await self._default_handle_exc(e, message)
await self._default_handle_exc(e, message, chat)

async def _default_handle_exc(self, e: Exception, message: HumanChatMessage):
async def _default_handle_exc(self, e: Exception, message: HumanChatMessage, chat: YChat):
"""
The default definition of `handle_exc()`. This is the default used when
the `handle_exc()` excepts.
@@ -238,15 +246,15 @@ async def _default_handle_exc(self, e: Exception, message: HumanChatMessage):
if lm_provider and lm_provider.is_api_key_exc(e):
provider_name = getattr(self.config_manager.lm_provider, "name", "")
response = f"Oops! There's a problem connecting to {provider_name}. Please update your {provider_name} API key in the chat settings."
self.reply(response, message)
self.reply(response, chat, message)
return
formatted_e = traceback.format_exc()
response = (
f"Sorry, an error occurred. Details below:\n\n```\n{formatted_e}\n```"
)
self.reply(response, message)
self.reply(response, chat, message)

def reply(self, response: str, human_msg: Optional[HumanChatMessage] = None):
def reply(self, response: str, chat: YChat, human_msg: Optional[HumanChatMessage] = None):
"""
Sends an agent message, usually in response to a received
`HumanChatMessage`.
@@ -259,12 +267,13 @@ def reply(self, response: str, human_msg: Optional[HumanChatMessage] = None):
persona=self.persona,
)

for handler in self._root_chat_handlers.values():
if not handler:
continue
self.write_message(chat, response)
# for handler in self._root_chat_handlers.values():
# if not handler:
# continue

handler.broadcast_message(agent_msg)
break
# handler.broadcast_message(agent_msg)
# break

@property
def persona(self):
@@ -367,14 +376,14 @@ def create_llm_chain(
):
raise NotImplementedError("Should be implemented by subclasses")

def parse_args(self, message, silent=False):
def parse_args(self, message, chat, silent=False):
args = message.body.split(" ")
try:
args = self.parser.parse_args(args[1:])
except (argparse.ArgumentError, SystemExit) as e:
if not silent:
response = f"{self.parser.format_usage()}"
self.reply(response, message)
self.reply(response, chat, message)
return None
return args

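To make the new contract concrete: `reply()` no longer broadcasts through the root chat handlers, it delegates to the injected `write_message: Callable[[YChat, str], None]`. A minimal, runnable sketch of that wiring is below; `FakeYChat` and its `add_message` method are assumptions for illustration, since the real `YChat` API is not shown in this diff.

```python
from typing import Callable


class FakeYChat:
    """Stand-in for jupyterlab_collaborative_chat.ychat.YChat; the real class
    is a shared (Yjs-backed) chat document, and `add_message` is an assumed API."""

    def __init__(self) -> None:
        self.messages = []

    def add_message(self, message: dict) -> None:
        self.messages.append(message)


# Matches the `write_message: Callable[[YChat, str], None]` parameter that
# this diff adds to BaseChatHandler.__init__.
WriteMessage = Callable[[FakeYChat, str], None]


def write_message(chat: FakeYChat, body: str) -> None:
    chat.add_message({"sender": "jupyter-ai", "body": body})


chat = FakeYChat()
write_message(chat, "Hello from the AI!")
assert chat.messages == [{"sender": "jupyter-ai", "body": "Hello from the AI!"}]
```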
3 changes: 2 additions & 1 deletion packages/jupyter-ai/jupyter_ai/chat_handlers/clear.py
@@ -1,4 +1,5 @@
from jupyter_ai.models import ClearMessage
from jupyterlab_collaborative_chat.ychat import YChat

from .base import BaseChatHandler, SlashCommandRoutingType

@@ -16,7 +17,7 @@ class ClearChatHandler(BaseChatHandler):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

async def process_message(self, _):
async def process_message(self, _, chat: YChat):
# Clear chat
for handler in self._root_chat_handlers.values():
if not handler:
22 changes: 12 additions & 10 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/default.py
@@ -1,6 +1,7 @@
import time
from typing import Dict, Type
from uuid import uuid4
from jupyterlab_collaborative_chat.ychat import YChat

from jupyter_ai.models import (
AgentStreamChunkMessage,
@@ -80,23 +81,24 @@ def _start_stream(self, human_msg: HumanChatMessage) -> str:

return stream_id

def _send_stream_chunk(self, stream_id: str, content: str, complete: bool = False):
def _send_stream_chunk(self, stream_id: str, content: str, chat: YChat, complete: bool = False):
"""
Sends an `agent-stream-chunk` message containing content that should be
appended to an existing `agent-stream` message with ID `stream_id`.
"""
stream_chunk_msg = AgentStreamChunkMessage(
id=stream_id, content=content, stream_complete=complete
)
self.write_message(chat, stream_chunk_msg.content)
# for handler in self._root_chat_handlers.values():
# if not handler:
# continue

for handler in self._root_chat_handlers.values():
if not handler:
continue

handler.broadcast_message(stream_chunk_msg)
break
# handler.broadcast_message(stream_chunk_msg)
# break

async def process_message(self, message: HumanChatMessage):
async def process_message(self, message: HumanChatMessage, chat: YChat):
self.log.warning("PROCESS IN DEFAULT HANDLER")
self.get_llm_chain()
received_first_chunk = False

@@ -119,10 +121,10 @@ async def process_message(self, message: HumanChatMessage):
if isinstance(chunk, AIMessageChunk):
self._send_stream_chunk(stream_id, chunk.content, chat)
elif isinstance(chunk, str):
self._send_stream_chunk(stream_id, chunk)
self._send_stream_chunk(stream_id, chunk, chat)
else:
self.log.error(f"Unrecognized type of chunk yielded: {type(chunk)}")
break

# complete stream after all chunks have been streamed
self._send_stream_chunk(stream_id, "", complete=True)
self._send_stream_chunk(stream_id, "", chat, complete=True)
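`_send_stream_chunk` now writes each chunk into the chat via `write_message` instead of broadcasting to websocket handlers. A runnable sketch of the chunk loop above, with the model stream and the chunk writer stubbed out (both are assumptions, not the real APIs):

```python
import asyncio


async def fake_llm_stream():
    # Stands in for the LLM chain's async stream of chunks.
    for piece in ["Hel", "lo ", "world"]:
        yield piece


async def stream_reply(write_chunk) -> None:
    async for chunk in fake_llm_stream():
        write_chunk(chunk, complete=False)  # append chunk to the open stream
    write_chunk("", complete=True)          # close the stream, as in the diff


received: list = []
asyncio.run(stream_reply(lambda content, complete: received.append(content)))
print("".join(received))  # -> "Hello world"
```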
7 changes: 4 additions & 3 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/export.py
@@ -2,6 +2,7 @@
import os
from datetime import datetime
from typing import List
from jupyterlab_collaborative_chat.ychat import YChat

from jupyter_ai.models import AgentChatMessage, HumanChatMessage

@@ -31,11 +32,11 @@ def chat_message_to_markdown(self, message):
return ""

# Write the chat history to a markdown file with a timestamp
async def process_message(self, message: HumanChatMessage):
async def process_message(self, message: HumanChatMessage, chat: YChat):
markdown_content = "\n\n".join(
self.chat_message_to_markdown(msg) for msg in self._chat_history
)
args = self.parse_args(message)
args = self.parse_args(message, chat)
chat_filename = ( # if no filename, use "chat_history" + timestamp
args.path[0]
if (args.path and args.path[0] != "")
@@ -46,4 +47,4 @@ async def process_message(self, message: HumanChatMessage):
) # Do not use timestamp if filename is entered as argument
with open(chat_file, "w") as chat_history:
chat_history.write(markdown_content)
self.reply(f"File saved to `{chat_file}`")
self.reply(f"File saved to `{chat_file}`", chat)
6 changes: 4 additions & 2 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/fix.py
@@ -1,4 +1,5 @@
from typing import Dict, Type
from jupyterlab_collaborative_chat.ychat import YChat

from jupyter_ai.models import CellWithErrorSelection, HumanChatMessage
from jupyter_ai_magics.providers import BaseProvider
@@ -77,10 +78,11 @@ def create_llm_chain(
self.llm = llm
self.llm_chain = LLMChain(llm=llm, prompt=FIX_PROMPT_TEMPLATE, verbose=True)

async def process_message(self, message: HumanChatMessage):
async def process_message(self, message: HumanChatMessage, chat: YChat):
if not (message.selection and message.selection.type == "cell-with-error"):
self.reply(
"`/fix` requires an active code cell with error output. Please click on a cell with error output and retry.",
chat,
message,
)
return
@@ -101,4 +103,4 @@ async def process_message(self, message: HumanChatMessage):
error_value=selection.error.value,
traceback="\n".join(selection.error.traceback),
)
self.reply(response, message)
self.reply(response, chat, message)
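`/fix` first checks for a cell-with-error selection and otherwise replies with usage guidance into the chat; when the guard passes, the error's name, value, and traceback are formatted into the prompt. A simplified stand-in for that flow (the `Selection` dataclass is invented for the sketch):

```python
from dataclasses import dataclass, field


@dataclass
class Selection:
    """Simplified stand-in for CellWithErrorSelection."""
    type: str = "cell-with-error"
    error_name: str = "ZeroDivisionError"
    error_value: str = "division by zero"
    traceback: list = field(default_factory=list)


def handle_fix(selection, chat: list) -> None:
    if not (selection and selection.type == "cell-with-error"):
        chat.append("`/fix` requires an active code cell with error output.")
        return
    prompt = (
        f"{selection.error_name}: {selection.error_value}\n"
        + "\n".join(selection.traceback)
    )
    chat.append(f"(model fix for: {prompt.splitlines()[0]})")


chat: list = []
handle_fix(None, chat)         # guard trips -> guidance message
handle_fix(Selection(), chat)  # guard passes -> model reply
```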
11 changes: 6 additions & 5 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py
Expand Up @@ -4,6 +4,7 @@
import traceback
from pathlib import Path
from typing import Dict, List, Optional, Type
from jupyterlab_collaborative_chat.ychat import YChat

import nbformat
from jupyter_ai.chat_handlers import BaseChatHandler, SlashCommandRoutingType
@@ -261,18 +262,18 @@ async def _generate_notebook(self, prompt: str):
nbformat.write(notebook, final_path)
return final_path

async def process_message(self, message: HumanChatMessage):
async def process_message(self, message: HumanChatMessage, chat: YChat):
self.get_llm_chain()

# first send a verification message to user
response = "👍 Great, I will get started on your notebook. It may take a few minutes, but I will reply here when the notebook is ready. In the meantime, you can continue to ask me other questions."
self.reply(response, message)
self.reply(response, chat, message)

final_path = await self._generate_notebook(prompt=message.body)
response = f"""🎉 I have created your notebook and saved it to the location {final_path}. I am still learning how to create notebooks, so please review all code before running it."""
self.reply(response, message)
self.reply(response, chat, message)

async def handle_exc(self, e: Exception, message: HumanChatMessage):
async def handle_exc(self, e: Exception, message: HumanChatMessage, chat: YChat):
timestamp = time.strftime("%Y-%m-%d-%H.%M.%S")
default_log_dir = Path(self.output_dir) / "jupyter-ai-logs"
log_dir = self.log_dir or default_log_dir
@@ -282,4 +283,4 @@ async def handle_exc(self, e: Exception, message: HumanChatMessage):
traceback.print_exc(file=log)

response = f"An error occurred while generating the notebook. The error details have been saved to `./{log_path}`.\n\nTry running `/generate` again, as some language models require multiple attempts before a notebook is generated."
self.reply(response, message)
self.reply(response, chat, message)
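Threading `chat` through `process_message` is what lets a long-running handler like `/generate` acknowledge immediately and post a second reply into the same chat once the notebook exists. A toy sketch of that two-phase pattern (the sleep stands in for the `_generate_notebook` call):

```python
import asyncio


async def process_generate(prompt: str, chat: list) -> None:
    def reply(response: str) -> None:
        chat.append(response)  # self.reply(response, chat, message)

    # Phase 1: acknowledge right away so the user is not left waiting.
    reply("👍 Great, I will get started on your notebook.")

    await asyncio.sleep(0.01)  # stands in for the long notebook generation

    # Phase 2: a second reply lands in the *same* chat when the work is done.
    reply("🎉 I have created your notebook.")


chat: list = []
asyncio.run(process_generate("a demo notebook", chat))
assert len(chat) == 2
```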
3 changes: 2 additions & 1 deletion packages/jupyter-ai/jupyter_ai/chat_handlers/help.py
@@ -1,4 +1,5 @@
from jupyter_ai.models import HumanChatMessage
from jupyterlab_collaborative_chat.ychat import YChat

from .base import BaseChatHandler, SlashCommandRoutingType

@@ -15,5 +16,5 @@ class HelpChatHandler(BaseChatHandler):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

async def process_message(self, message: HumanChatMessage):
async def process_message(self, message: HumanChatMessage, chat: YChat):
self.send_help_message(message)