From d111f15e2e59349a167043b23fb046eee80cb61b Mon Sep 17 00:00:00 2001 From: Pierce Kelaita Date: Sun, 5 May 2024 13:47:01 -0700 Subject: [PATCH] add tests for ChatMemory --- l2m2/client/llm_client.py | 24 ++++----- l2m2/memory/__init__.py | 4 +- l2m2/memory/chat_memory.py | 22 ++++++--- tests/l2m2/client/test_llm_client.py | 56 +++++++++++++++++++++ tests/l2m2/memory/__init__.py | 0 tests/l2m2/memory/test_chat_memory.py | 70 +++++++++++++++++++++++++++ 6 files changed, 157 insertions(+), 19 deletions(-) create mode 100644 tests/l2m2/memory/__init__.py create mode 100644 tests/l2m2/memory/test_chat_memory.py diff --git a/l2m2/client/llm_client.py b/l2m2/client/llm_client.py index 83943c9..f1eb5c0 100644 --- a/l2m2/client/llm_client.py +++ b/l2m2/client/llm_client.py @@ -53,11 +53,6 @@ def __init__( self.add_provider(provider, api_key) if enable_memory: - if memory_window_size is None: - memory_window_size = DEFAULT_WINDOW_SIZE - if not memory_window_size > 0: - raise ValueError("Memory window size must be a positive integer.") - self.memory = ChatMemory(memory_window_size) @staticmethod @@ -69,7 +64,6 @@ def get_available_providers() -> Set[str]: Returns: Set[str]: A set of available providers. """ - # return set([str(info["provider"]) for info in MODEL_INFO.values()]) return set(PROVIDER_INFO.keys()) @staticmethod @@ -81,7 +75,6 @@ def get_available_models() -> Set[str]: Returns: Set[str]: A set of available models. """ - # return set(MODEL_INFO.keys()) return set(MODEL_INFO.keys()) def get_active_providers(self) -> Set[str]: @@ -209,10 +202,19 @@ def clear_memory(self) -> None: ) self.memory.clear() - def enable_memory(self) -> None: - """Enable memory if it is not already enabled.""" - if self.memory is None: - self.memory = ChatMemory() + def enable_memory(self, window_size: int = DEFAULT_WINDOW_SIZE) -> None: + """Enable memory, with a specified window size. + + Args: + window_size (int, optional): The size of the memory window. Defaults to + `l2m2.memory.DEFAULT_WINDOW_SIZE`. + + Raises: + ValueError: If memory is already enabled. + """ + if self.memory is not None: + raise ValueError("Memory is already enabled.") + self.memory = ChatMemory(window_size) def call( self, diff --git a/l2m2/memory/__init__.py b/l2m2/memory/__init__.py index 3b577f8..c503309 100644 --- a/l2m2/memory/__init__.py +++ b/l2m2/memory/__init__.py @@ -1,3 +1,3 @@ -from .chat_memory import ChatMemory, Message, DEFAULT_WINDOW_SIZE +from .chat_memory import ChatMemory, ChatMemoryEntry, DEFAULT_WINDOW_SIZE -__all__ = ["ChatMemory", "Message", "DEFAULT_WINDOW_SIZE"] +__all__ = ["ChatMemory", "ChatMemoryEntry", "DEFAULT_WINDOW_SIZE"] diff --git a/l2m2/memory/chat_memory.py b/l2m2/memory/chat_memory.py index 491e869..e365536 100644 --- a/l2m2/memory/chat_memory.py +++ b/l2m2/memory/chat_memory.py @@ -5,7 +5,9 @@ DEFAULT_WINDOW_SIZE = 40 -class Message: +class ChatMemoryEntry: + """Represents a message in a conversation memory.""" + class Role(Enum): USER = "user" AGENT = "agent" @@ -27,9 +29,15 @@ def __init__(self, window_size: int = DEFAULT_WINDOW_SIZE) -> None: Args: window_size (int, optional): The maximum number of messages to store. Defaults to DEFAULT_WINDOW_SIZE. + + Raises: + ValueError: If `window_size` is less than or equal to 0. """ + if not window_size > 0: + raise ValueError("window_size must be a positive integer.") + self.window_size: int = window_size - self.mem_window: Deque[Message] = deque(maxlen=window_size) + self.mem_window: Deque[ChatMemoryEntry] = deque(maxlen=window_size) def add_user_message(self, text: str) -> None: """Adds a user message to the memory. @@ -37,7 +45,7 @@ def add_user_message(self, text: str) -> None: Args: text (str): The user message to add. """ - self.mem_window.append(Message(text, Message.Role.USER)) + self.mem_window.append(ChatMemoryEntry(text, ChatMemoryEntry.Role.USER)) def add_agent_message(self, text: str) -> None: """Adds an agent message to the memory. @@ -45,7 +53,7 @@ def add_agent_message(self, text: str) -> None: Args: text (str): The agent message to add. """ - self.mem_window.append(Message(text, Message.Role.AGENT)) + self.mem_window.append(ChatMemoryEntry(text, ChatMemoryEntry.Role.AGENT)) def unpack( self, role_key: str, message_key: str, user_key: str, agent_key: str @@ -82,7 +90,9 @@ def unpack( """ res = [] for message in self.mem_window: - role_value = user_key if message.role == Message.Role.USER else agent_key + role_value = ( + user_key if message.role == ChatMemoryEntry.Role.USER else agent_key + ) res.append({role_key: role_value, message_key: message.text}) return res @@ -93,5 +103,5 @@ def clear(self) -> None: def __len__(self) -> int: return len(self.mem_window) - def __iter__(self) -> Iterator[Message]: + def __iter__(self) -> Iterator[ChatMemoryEntry]: return iter(self.mem_window) diff --git a/tests/l2m2/client/test_llm_client.py b/tests/l2m2/client/test_llm_client.py index 4059ac6..0385d16 100644 --- a/tests/l2m2/client/test_llm_client.py +++ b/tests/l2m2/client/test_llm_client.py @@ -416,3 +416,59 @@ def test_multi_provider_pref_inactive(llm_client): llm_client.add_provider("replicate", "test-key-replicate") with pytest.raises(ValueError): llm_client.call(prompt="Hello", model="llama3-70b", prefer_provider="openai") + + +# -- Tests for memory -- # + + +@patch(f"{MODULE_PATH}.OpenAI") +def test_memory(mock_openai): + mock_client = Mock() + mock_call = mock_client.chat.completions.create + mock_response = construct_mock_from_path("choices[0].message.content") + mock_call.return_value = mock_response + mock_openai.return_value = mock_client + + llm_client = LLMClient(enable_memory=True) + llm_client.add_provider("openai", "fake-api-key") + + llm_client.get_memory().add_user_message("A") + llm_client.get_memory().add_agent_message("B") + + response = llm_client.call(prompt="C", model="gpt-4-turbo") + assert response == "response" + + assert llm_client.get_memory().unpack("role", "content", "user", "assistant") == [ + {"role": "user", "content": "A"}, + {"role": "assistant", "content": "B"}, + {"role": "user", "content": "C"}, + {"role": "assistant", "content": "response"}, + ] + + llm_client.clear_memory() + + assert llm_client.get_memory().unpack("role", "content", "user", "assistant") == [] + + +def test_memory_errors(llm_client): + with pytest.raises(ValueError): + llm_client.get_memory() + + with pytest.raises(ValueError): + llm_client.clear_memory() + + llm_client.enable_memory() + + with pytest.raises(ValueError): + llm_client.enable_memory() + + +def test_memory_unsupported_provider(): + unsupported_providers = { + "replicate": "llama3-8b", + } + for provider, model in unsupported_providers.items(): + llm_client = LLMClient(enable_memory=True) + llm_client.add_provider(provider, "fake-api-key") + with pytest.raises(ValueError): + llm_client.call(prompt="Hello", model=model) diff --git a/tests/l2m2/memory/__init__.py b/tests/l2m2/memory/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/l2m2/memory/test_chat_memory.py b/tests/l2m2/memory/test_chat_memory.py new file mode 100644 index 0000000..7df0cdc --- /dev/null +++ b/tests/l2m2/memory/test_chat_memory.py @@ -0,0 +1,70 @@ +import pytest +from l2m2.memory.chat_memory import ChatMemory, DEFAULT_WINDOW_SIZE, ChatMemoryEntry + + +def test_chat_memory(): + memory = ChatMemory() + assert memory.window_size == DEFAULT_WINDOW_SIZE + assert len(memory) == 0 + + memory.add_user_message("A") + memory.add_agent_message("B") + it = iter(memory) + e1 = next(it) + e2 = next(it) + assert len(memory) == 2 + assert e1.text == "A" + assert e1.role == ChatMemoryEntry.Role.USER + assert e2.text == "B" + assert e2.role == ChatMemoryEntry.Role.AGENT + + memory.clear() + assert len(memory) == 0 + + +def test_unpack(): + memory = ChatMemory(window_size=10) + memory.add_user_message("A") + memory.add_agent_message("B") + memory.add_user_message("C") + memory.add_agent_message("D") + memory.add_user_message("E") + memory.add_agent_message("F") + mem_arr = memory.unpack("role", "text", "user", "agent") + assert len(mem_arr) == 6 + assert mem_arr[0] == {"role": "user", "text": "A"} + assert mem_arr[1] == {"role": "agent", "text": "B"} + assert mem_arr[2] == {"role": "user", "text": "C"} + assert mem_arr[3] == {"role": "agent", "text": "D"} + assert mem_arr[4] == {"role": "user", "text": "E"} + assert mem_arr[5] == {"role": "agent", "text": "F"} + + +def test_sliding_window(): + memory = ChatMemory(window_size=3) + memory.add_user_message("A") + memory.add_agent_message("B") + memory.add_user_message("C") + memory.add_agent_message("D") + memory.add_user_message("E") + memory.add_agent_message("F") + assert len(memory) == 3 + + mem_arr = memory.unpack("role", "text", "user", "agent") + assert len(mem_arr) == 3 + assert mem_arr[0] == {"role": "agent", "text": "D"} + assert mem_arr[1] == {"role": "user", "text": "E"} + assert mem_arr[2] == {"role": "agent", "text": "F"} + + memory.add_agent_message("G") + + mem_arr = memory.unpack("role", "text", "user", "agent") + assert len(mem_arr) == 3 + assert mem_arr[0] == {"role": "user", "text": "E"} + assert mem_arr[1] == {"role": "agent", "text": "F"} + assert mem_arr[2] == {"role": "agent", "text": "G"} + + +def test_bad_window_size(): + with pytest.raises(ValueError): + ChatMemory(window_size=-1)