Skip to content

Commit

Permalink
Allow unlimited LLM memory through traitlets configuration (jupyterla…
Browse files Browse the repository at this point in the history
  • Loading branch information
krassowski authored Sep 10, 2024
1 parent 43e6acc commit 83cbd8e
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 3 deletions.
6 changes: 6 additions & 0 deletions packages/jupyter-ai/jupyter_ai/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,13 @@ class AiExtension(ExtensionApp):
default_value=2,
help="""
Number of chat interactions to keep in the conversational memory object.
An interaction is defined as an exchange between a human and AI, thus
comprising of one or two messages.
Set to `None` to keep all interactions.
""",
allow_none=True,
config=True,
)

Expand Down
8 changes: 5 additions & 3 deletions packages/jupyter-ai/jupyter_ai/history.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import time
from typing import List, Optional, Sequence, Set
from typing import List, Optional, Sequence, Set, Union

from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import BaseMessage
Expand All @@ -16,16 +16,18 @@ class BoundedChatHistory(BaseChatMessageHistory, BaseModel):
`k` exchanges between a user and an LLM.
For example, when `k=2`, `BoundedChatHistory` will store up to 2 human
messages and 2 AI messages.
messages and 2 AI messages. If `k` is set to `None` all messages are kept.
"""

k: int
k: Union[int, None]
clear_time: float = 0.0
cleared_msgs: Set[str] = set()
_all_messages: List[BaseMessage] = PrivateAttr(default_factory=list)

@property
def messages(self) -> List[BaseMessage]:
if self.k is None:
return self._all_messages
return self._all_messages[-self.k * 2 :]

async def aget_messages(self) -> List[BaseMessage]:
Expand Down
48 changes: 48 additions & 0 deletions packages/jupyter-ai/jupyter_ai/tests/test_extension.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
from unittest import mock

import pytest
from jupyter_ai.extension import AiExtension
from jupyter_ai.history import HUMAN_MSG_ID_KEY
from jupyter_ai_magics import BaseProvider
from langchain_core.messages import BaseMessage

pytest_plugins = ["pytest_jupyter.jupyter_server"]

Expand Down Expand Up @@ -43,3 +48,46 @@ def test_blocks_providers(argv, jp_configurable_serverapp):
ai._link_jupyter_server_extension(server)
ai.initialize_settings()
assert KNOWN_LM_A not in ai.settings["lm_providers"]


@pytest.fixture
def jp_server_config(jp_server_config):
# Disable the extension during server startup to avoid double initialization
return {"ServerApp": {"jpserver_extensions": {"jupyter_ai": False}}}


@pytest.fixture
def ai_extension(jp_serverapp):
ai = AiExtension()
# `BaseProvider.server_settings` can be only initialized once; however, the tests
# may run in parallel setting it with race condition; because we are not testing
# the `BaseProvider.server_settings` here, we can just mock the setter
settings_mock = mock.PropertyMock()
with mock.patch.object(BaseProvider.__class__, "server_settings", settings_mock):
yield ai


@pytest.mark.parametrize(
"max_history,messages_to_add,expected_size",
[
# for max_history = 1 we expect to see up to 2 messages (1 human and 1 AI message)
(1, 4, 2),
# if there is less than `max_history` messages, all should be returned
(1, 1, 1),
# if no limit is set, all messages should be returned
(None, 9, 9),
],
)
def test_max_chat_history(ai_extension, max_history, messages_to_add, expected_size):
ai = ai_extension
ai.default_max_chat_history = max_history
ai.initialize_settings()
for i in range(messages_to_add):
message = BaseMessage(
content=f"Test message {i}",
type="test",
additional_kwargs={HUMAN_MSG_ID_KEY: f"message-{i}"},
)
ai.settings["llm_chat_memory"].add_message(message)

assert len(ai.settings["llm_chat_memory"].messages) == expected_size

0 comments on commit 83cbd8e

Please sign in to comment.