From 1c99b34c52e41c8cd0b1fa8c386d7a96a52d59ae Mon Sep 17 00:00:00 2001 From: michael Date: Fri, 9 Aug 2024 15:52:23 +0800 Subject: [PATCH 1/8] make chat memory a shared object --- .../jupyter_ai/chat_handlers/base.py | 3 +++ .../jupyter_ai/chat_handlers/clear.py | 1 + .../jupyter_ai/chat_handlers/default.py | 4 +--- packages/jupyter-ai/jupyter_ai/extension.py | 19 ++++++++++++++++++- 4 files changed, 23 insertions(+), 4 deletions(-) diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py index fbb4cdb31..6f336f1db 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py @@ -28,6 +28,7 @@ from jupyter_ai_magics import Persona from jupyter_ai_magics.providers import BaseProvider from langchain.pydantic_v1 import BaseModel +from langchain_core.chat_history import BaseChatMessageHistory if TYPE_CHECKING: from jupyter_ai.handlers import RootChatHandler @@ -113,6 +114,7 @@ def __init__( root_chat_handlers: Dict[str, "RootChatHandler"], model_parameters: Dict[str, Dict], chat_history: List[ChatMessage], + llm_chat_history: BaseChatMessageHistory, root_dir: str, preferred_dir: Optional[str], dask_client_future: Awaitable[DaskClient], @@ -122,6 +124,7 @@ def __init__( self._root_chat_handlers = root_chat_handlers self.model_parameters = model_parameters self._chat_history = chat_history + self.llm_chat_history = llm_chat_history self.parser = argparse.ArgumentParser( add_help=False, description=self.help, formatter_class=MarkdownHelpFormatter ) diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/clear.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/clear.py index 97cae4ab4..1462937a4 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/clear.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/clear.py @@ -27,6 +27,7 @@ async def process_message(self, _): # Clear chat handler.broadcast_message(ClearMessage()) self._chat_history.clear() + self.llm_chat_history.clear() # Build /help message and reinstate it in chat chat_handlers = handler.chat_handlers diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py index 4aebdde80..460b561a2 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py @@ -11,7 +11,6 @@ from langchain_core.messages import AIMessageChunk from langchain_core.runnables.history import RunnableWithMessageHistory -from ..history import BoundedChatHistory from .base import BaseChatHandler, SlashCommandRoutingType @@ -42,10 +41,9 @@ def create_llm_chain( runnable = prompt_template | llm if not llm.manages_history: - history = BoundedChatHistory(k=2) runnable = RunnableWithMessageHistory( runnable=runnable, - get_session_history=lambda *args: history, + get_session_history=lambda *args: self.llm_chat_history, input_messages_key="input", history_messages_key="history", ) diff --git a/packages/jupyter-ai/jupyter_ai/extension.py b/packages/jupyter-ai/jupyter_ai/extension.py index 9aab5c5ba..f5bf2346c 100644 --- a/packages/jupyter-ai/jupyter_ai/extension.py +++ b/packages/jupyter-ai/jupyter_ai/extension.py @@ -10,7 +10,7 @@ from jupyter_ai_magics.utils import get_em_providers, get_lm_providers from jupyter_server.extension.application import ExtensionApp from tornado.web import StaticFileHandler -from traitlets import Dict, List, Unicode +from traitlets import Dict, List, Unicode, Integer from .chat_handlers import ( AskChatHandler, @@ -34,6 +34,8 @@ RootChatHandler, SlashCommandsInfoHandler, ) +from .history import BoundedChatHistory + JUPYTERNAUT_AVATAR_ROUTE = JupyternautPersona.avatar_route JUPYTERNAUT_AVATAR_PATH = str( @@ -158,6 +160,15 @@ class AiExtension(ExtensionApp): config=True, ) + default_max_chat_history = Integer( + default_value=2, + allow_none=False, + help=""" + Number of chat interactions to keep in the conversational memory object. + """, + config=True, + ) + def initialize_settings(self): start = time.time() @@ -222,6 +233,11 @@ def initialize_settings(self): # memory object used by the LM chain. self.settings["chat_history"] = [] + # conversational memory object used by LM chain + self.settings["llm_chat_history"] = BoundedChatHistory( + k=self.default_max_chat_history + ) + # list of pending messages self.settings["pending_messages"] = [] @@ -252,6 +268,7 @@ def initialize_settings(self): **common_handler_kargs, "root_chat_handlers": self.settings["jai_root_chat_handlers"], "chat_history": self.settings["chat_history"], + "llm_chat_history": self.settings["llm_chat_history"], "root_dir": self.serverapp.root_dir, "dask_client_future": self.settings["dask_client_future"], "model_parameters": self.settings["model_parameters"], From 0595c7f091350fbc02f53fc01e29ac8f62bc1a7e Mon Sep 17 00:00:00 2001 From: michael Date: Fri, 9 Aug 2024 16:53:30 +0800 Subject: [PATCH 2/8] store all messages internally --- packages/jupyter-ai/jupyter_ai/history.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/packages/jupyter-ai/jupyter_ai/history.py b/packages/jupyter-ai/jupyter_ai/history.py index 02b77b911..2ee6647e1 100644 --- a/packages/jupyter-ai/jupyter_ai/history.py +++ b/packages/jupyter-ai/jupyter_ai/history.py @@ -14,27 +14,27 @@ class BoundedChatHistory(BaseChatMessageHistory, BaseModel): messages and 2 AI messages. """ - messages: List[BaseMessage] = Field(default_factory=list) + _messages: List[BaseMessage] = Field(default_factory=list) size: int = 0 k: int + @property + def messages(self) -> List[BaseMessage]: + return self._messages[-self.k * 2 :] + async def aget_messages(self) -> List[BaseMessage]: return self.messages def add_message(self, message: BaseMessage) -> None: """Add a self-created message to the store""" - self.messages.append(message) - self.size += 1 - - if self.size > self.k * 2: - self.messages.pop(0) + self._messages.append(message) async def aadd_messages(self, messages: Sequence[BaseMessage]) -> None: """Add messages to the store""" self.add_messages(messages) def clear(self) -> None: - self.messages = [] + self._messages = [] async def aclear(self) -> None: self.clear() From 0d3376ff0a3daf603f85105be4e913c21ab3de29 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 9 Aug 2024 09:01:12 +0000 Subject: [PATCH 3/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- packages/jupyter-ai/jupyter_ai/extension.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/packages/jupyter-ai/jupyter_ai/extension.py b/packages/jupyter-ai/jupyter_ai/extension.py index f5bf2346c..b9aa2c2c2 100644 --- a/packages/jupyter-ai/jupyter_ai/extension.py +++ b/packages/jupyter-ai/jupyter_ai/extension.py @@ -10,7 +10,7 @@ from jupyter_ai_magics.utils import get_em_providers, get_lm_providers from jupyter_server.extension.application import ExtensionApp from tornado.web import StaticFileHandler -from traitlets import Dict, List, Unicode, Integer +from traitlets import Dict, Integer, List, Unicode from .chat_handlers import ( AskChatHandler, @@ -36,7 +36,6 @@ ) from .history import BoundedChatHistory - JUPYTERNAUT_AVATAR_ROUTE = JupyternautPersona.avatar_route JUPYTERNAUT_AVATAR_PATH = str( os.path.join(os.path.dirname(__file__), "static", "jupyternaut.svg") From fc9cff7b7ecf5c65d1e2d906d3b80fc6a01343d1 Mon Sep 17 00:00:00 2001 From: michael Date: Fri, 9 Aug 2024 17:08:55 +0800 Subject: [PATCH 4/8] fix test --- packages/jupyter-ai/jupyter_ai/tests/test_handlers.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/packages/jupyter-ai/jupyter_ai/tests/test_handlers.py b/packages/jupyter-ai/jupyter_ai/tests/test_handlers.py index d2e73ce6c..457c79623 100644 --- a/packages/jupyter-ai/jupyter_ai/tests/test_handlers.py +++ b/packages/jupyter-ai/jupyter_ai/tests/test_handlers.py @@ -18,6 +18,7 @@ PendingMessage, Persona, ) +from jupyter_ai.history import BoundedChatHistory from jupyter_ai_magics import BaseProvider from langchain_community.llms import FakeListLLM from tornado.httputil import HTTPServerRequest @@ -70,6 +71,7 @@ def broadcast_message(message: Message) -> None: root_chat_handlers={"root": root_handler}, model_parameters={}, chat_history=[], + llm_chat_history=BoundedChatHistory(k=2), root_dir="", preferred_dir="", dask_client_future=None, From 115a6e74be511f00b632fe022fa93daf7b0487b5 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 9 Aug 2024 09:10:59 +0000 Subject: [PATCH 5/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- packages/jupyter-ai/jupyter_ai/tests/test_handlers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/jupyter-ai/jupyter_ai/tests/test_handlers.py b/packages/jupyter-ai/jupyter_ai/tests/test_handlers.py index 457c79623..16d439169 100644 --- a/packages/jupyter-ai/jupyter_ai/tests/test_handlers.py +++ b/packages/jupyter-ai/jupyter_ai/tests/test_handlers.py @@ -8,6 +8,7 @@ from jupyter_ai.chat_handlers import DefaultChatHandler, learn from jupyter_ai.config_manager import ConfigManager from jupyter_ai.handlers import RootChatHandler +from jupyter_ai.history import BoundedChatHistory from jupyter_ai.models import ( AgentStreamChunkMessage, AgentStreamMessage, @@ -18,7 +19,6 @@ PendingMessage, Persona, ) -from jupyter_ai.history import BoundedChatHistory from jupyter_ai_magics import BaseProvider from langchain_community.llms import FakeListLLM from tornado.httputil import HTTPServerRequest From 020b4b31d167b269b7176407b0d3b67701e81e52 Mon Sep 17 00:00:00 2001 From: michael Date: Fri, 9 Aug 2024 17:33:18 +0800 Subject: [PATCH 6/8] fix chat history class --- packages/jupyter-ai/jupyter_ai/history.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/packages/jupyter-ai/jupyter_ai/history.py b/packages/jupyter-ai/jupyter_ai/history.py index 2ee6647e1..68481893a 100644 --- a/packages/jupyter-ai/jupyter_ai/history.py +++ b/packages/jupyter-ai/jupyter_ai/history.py @@ -1,4 +1,4 @@ -from typing import List, Sequence +from typing import List, Sequence, Optional from langchain_core.chat_history import BaseChatMessageHistory from langchain_core.messages import BaseMessage @@ -14,27 +14,27 @@ class BoundedChatHistory(BaseChatMessageHistory, BaseModel): messages and 2 AI messages. """ - _messages: List[BaseMessage] = Field(default_factory=list) + all_messages: List[BaseMessage] = Field(default_factory=list, alias='messages') size: int = 0 k: int @property def messages(self) -> List[BaseMessage]: - return self._messages[-self.k * 2 :] + return self.all_messages[-self.k * 2 :] async def aget_messages(self) -> List[BaseMessage]: return self.messages def add_message(self, message: BaseMessage) -> None: """Add a self-created message to the store""" - self._messages.append(message) + self.all_messages.append(message) async def aadd_messages(self, messages: Sequence[BaseMessage]) -> None: """Add messages to the store""" self.add_messages(messages) def clear(self) -> None: - self._messages = [] + self.all_messages = [] async def aclear(self) -> None: self.clear() From 3afe37ce502491083af152c80f863e23fdc9c87b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 9 Aug 2024 09:33:35 +0000 Subject: [PATCH 7/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- packages/jupyter-ai/jupyter_ai/history.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/jupyter-ai/jupyter_ai/history.py b/packages/jupyter-ai/jupyter_ai/history.py index 68481893a..390ab0cef 100644 --- a/packages/jupyter-ai/jupyter_ai/history.py +++ b/packages/jupyter-ai/jupyter_ai/history.py @@ -1,4 +1,4 @@ -from typing import List, Sequence, Optional +from typing import List, Optional, Sequence from langchain_core.chat_history import BaseChatMessageHistory from langchain_core.messages import BaseMessage @@ -14,7 +14,7 @@ class BoundedChatHistory(BaseChatMessageHistory, BaseModel): messages and 2 AI messages. """ - all_messages: List[BaseMessage] = Field(default_factory=list, alias='messages') + all_messages: List[BaseMessage] = Field(default_factory=list, alias="messages") size: int = 0 k: int From 741589606553ffc091c0195c7a7a5de2892937af Mon Sep 17 00:00:00 2001 From: michael Date: Fri, 9 Aug 2024 23:49:12 +0800 Subject: [PATCH 8/8] prevent enter from sending empty message --- packages/jupyter-ai/src/components/chat-input.tsx | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/packages/jupyter-ai/src/components/chat-input.tsx b/packages/jupyter-ai/src/components/chat-input.tsx index 4ab9bcd36..ddc2d4209 100644 --- a/packages/jupyter-ai/src/components/chat-input.tsx +++ b/packages/jupyter-ai/src/components/chat-input.tsx @@ -207,6 +207,7 @@ export function ChatInput(props: ChatInputProps): JSX.Element { props.chatHandler.sendMessage({ prompt, selection }); } + const inputExists = !!input.trim(); function handleKeyDown(event: React.KeyboardEvent) { if (event.key !== 'Enter') { return; @@ -218,6 +219,12 @@ export function ChatInput(props: ChatInputProps): JSX.Element { return; } + if (!inputExists) { + event.stopPropagation(); + event.preventDefault(); + return; + } + if ( event.key === 'Enter' && ((props.sendWithShiftEnter && event.shiftKey) || @@ -240,7 +247,6 @@ export function ChatInput(props: ChatInputProps): JSX.Element { ); - const inputExists = !!input.trim(); const sendButtonProps: SendButtonProps = { onSend, sendWithShiftEnter: props.sendWithShiftEnter,