From 2102e334cce10037f328210ae874e76e7d990fc6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Micha=C5=82=20Krassowski?= <5832902+krassowski@users.noreply.github.com>
Date: Tue, 10 Sep 2024 21:37:13 +0100
Subject: [PATCH] Allow unlimited LLM memory through traitlets configuration
 (#986)

---
 packages/jupyter-ai/jupyter_ai/extension.py   |  6 +++
 packages/jupyter-ai/jupyter_ai/history.py     |  8 ++--
 .../jupyter_ai/tests/test_extension.py        | 48 +++++++++++++++++++
 3 files changed, 59 insertions(+), 3 deletions(-)

diff --git a/packages/jupyter-ai/jupyter_ai/extension.py b/packages/jupyter-ai/jupyter_ai/extension.py
index a57fb0eb9..78c4dfefa 100644
--- a/packages/jupyter-ai/jupyter_ai/extension.py
+++ b/packages/jupyter-ai/jupyter_ai/extension.py
@@ -188,7 +188,13 @@ class AiExtension(ExtensionApp):
         default_value=2,
         help="""
         Number of chat interactions to keep in the conversational memory object.
+
+        An interaction is defined as an exchange between a human and the AI,
+        thus comprising one or two messages.
+
+        Set to `None` to keep all interactions.
         """,
+        allow_none=True,
         config=True,
     )
 
diff --git a/packages/jupyter-ai/jupyter_ai/history.py b/packages/jupyter-ai/jupyter_ai/history.py
index 9e1064194..c857b4486 100644
--- a/packages/jupyter-ai/jupyter_ai/history.py
+++ b/packages/jupyter-ai/jupyter_ai/history.py
@@ -1,5 +1,5 @@
 import time
-from typing import List, Optional, Sequence, Set
+from typing import List, Optional, Sequence, Set, Union
 
 from langchain_core.chat_history import BaseChatMessageHistory
 from langchain_core.messages import BaseMessage
@@ -16,16 +16,18 @@ class BoundedChatHistory(BaseChatMessageHistory, BaseModel):
     `k` exchanges between a user and an LLM.
 
     For example, when `k=2`, `BoundedChatHistory` will store up to 2 human
-    messages and 2 AI messages.
+    messages and 2 AI messages. If `k` is set to `None`, all messages are kept.
     """
 
-    k: int
+    k: Union[int, None]
     clear_time: float = 0.0
     cleared_msgs: Set[str] = set()
     _all_messages: List[BaseMessage] = PrivateAttr(default_factory=list)
 
     @property
     def messages(self) -> List[BaseMessage]:
+        if self.k is None:
+            return self._all_messages
         return self._all_messages[-self.k * 2 :]
 
     async def aget_messages(self) -> List[BaseMessage]:
diff --git a/packages/jupyter-ai/jupyter_ai/tests/test_extension.py b/packages/jupyter-ai/jupyter_ai/tests/test_extension.py
index b46c148c3..9ae52d8a0 100644
--- a/packages/jupyter-ai/jupyter_ai/tests/test_extension.py
+++ b/packages/jupyter-ai/jupyter_ai/tests/test_extension.py
@@ -1,7 +1,12 @@
 # Copyright (c) Jupyter Development Team.
 # Distributed under the terms of the Modified BSD License.
+from unittest import mock
+
 import pytest
 from jupyter_ai.extension import AiExtension
+from jupyter_ai.history import HUMAN_MSG_ID_KEY
+from jupyter_ai_magics import BaseProvider
+from langchain_core.messages import BaseMessage
 
 pytest_plugins = ["pytest_jupyter.jupyter_server"]
 
@@ -43,3 +48,46 @@ def test_blocks_providers(argv, jp_configurable_serverapp):
     ai._link_jupyter_server_extension(server)
     ai.initialize_settings()
     assert KNOWN_LM_A not in ai.settings["lm_providers"]
+
+
+@pytest.fixture
+def jp_server_config(jp_server_config):
+    # Disable the extension during server startup to avoid double initialization
+    return {"ServerApp": {"jpserver_extensions": {"jupyter_ai": False}}}
+
+
+@pytest.fixture
+def ai_extension(jp_serverapp):
+    ai = AiExtension()
+    # `BaseProvider.server_settings` can only be initialized once; however, the
+    # tests may run in parallel, racing to set it. Because we are not testing
+    # `BaseProvider.server_settings` here, we can simply mock the setter.
+    settings_mock = mock.PropertyMock()
+    with mock.patch.object(BaseProvider.__class__, "server_settings", settings_mock):
+        yield ai
+
+
+@pytest.mark.parametrize(
+    "max_history,messages_to_add,expected_size",
+    [
+        # for max_history = 1, we expect up to 2 messages (1 human and 1 AI message)
+        (1, 4, 2),
+        # if fewer messages than the limit have been added, all should be returned
+        (1, 1, 1),
+        # if no limit is set, all messages should be returned
+        (None, 9, 9),
+    ],
+)
+def test_max_chat_history(ai_extension, max_history, messages_to_add, expected_size):
+    ai = ai_extension
+    ai.default_max_chat_history = max_history
+    ai.initialize_settings()
+    for i in range(messages_to_add):
+        message = BaseMessage(
+            content=f"Test message {i}",
+            type="test",
+            additional_kwargs={HUMAN_MSG_ID_KEY: f"message-{i}"},
+        )
+        ai.settings["llm_chat_memory"].add_message(message)
+
+    assert len(ai.settings["llm_chat_memory"].messages) == expected_size
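
Usage note: with `allow_none=True` on the trait, unlimited memory can be
switched on through the standard traitlets configuration channels. A minimal
sketch, assuming the usual Jupyter Server config file (any traitlets config
file that Jupyter Server loads should work the same way):

    # jupyter_server_config.py
    c = get_config()  # provided by the traitlets config loader
    c.AiExtension.default_max_chat_history = None  # keep all interactions
    # or restore a bounded window, e.g. the last 4 exchanges:
    # c.AiExtension.default_max_chat_history = 4

The equivalent command-line override would presumably be
`jupyter lab --AiExtension.default_max_chat_history=None`; the CLI form is an
assumption based on standard traitlets behaviour, not something this patch
exercises.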
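For reference, the windowing rule that the new `messages` property implements:
with an integer `k`, only the last `k * 2` messages (that is, `k` human/AI
exchanges) are visible; with `k=None`, the full history is visible. A
standalone sketch of that rule (plain Python, no `jupyter_ai` imports; the
`window` helper is illustrative, not part of the patch):

    from typing import List, Optional

    def window(messages: List[str], k: Optional[int]) -> List[str]:
        # Return the last k exchanges (2 messages each), or everything if k is None.
        if k is None:
            return messages
        return messages[-k * 2:]

    history = ["human-1", "ai-1", "human-2", "ai-2"]
    assert window(history, k=1) == ["human-2", "ai-2"]  # bounded memory
    assert window(history, k=None) == history           # unlimited memory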