Skip to content

Commit

Permalink
add tests for ChatMemory
Browse files Browse the repository at this point in the history
  • Loading branch information
pkelaita committed May 5, 2024
1 parent 9336db0 commit d111f15
Show file tree
Hide file tree
Showing 6 changed files with 157 additions and 19 deletions.
24 changes: 13 additions & 11 deletions l2m2/client/llm_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,6 @@ def __init__(
self.add_provider(provider, api_key)

if enable_memory:
if memory_window_size is None:
memory_window_size = DEFAULT_WINDOW_SIZE
if not memory_window_size > 0:
raise ValueError("Memory window size must be a positive integer.")

self.memory = ChatMemory(memory_window_size)

@staticmethod
Expand All @@ -69,7 +64,6 @@ def get_available_providers() -> Set[str]:
Returns:
Set[str]: A set of available providers.
"""
# return set([str(info["provider"]) for info in MODEL_INFO.values()])
return set(PROVIDER_INFO.keys())

@staticmethod
Expand All @@ -81,7 +75,6 @@ def get_available_models() -> Set[str]:
Returns:
Set[str]: A set of available models.
"""
# return set(MODEL_INFO.keys())
return set(MODEL_INFO.keys())

def get_active_providers(self) -> Set[str]:
Expand Down Expand Up @@ -209,10 +202,19 @@ def clear_memory(self) -> None:
)
self.memory.clear()

def enable_memory(self) -> None:
"""Enable memory if it is not already enabled."""
if self.memory is None:
self.memory = ChatMemory()
def enable_memory(self, window_size: int = DEFAULT_WINDOW_SIZE) -> None:
    """Turn on conversation memory for this client.

    Args:
        window_size (int, optional): Maximum number of messages the memory
            retains. Defaults to `l2m2.memory.DEFAULT_WINDOW_SIZE`.

    Raises:
        ValueError: If memory is already enabled.
    """
    # Guard against double-enabling, which would silently drop the
    # existing conversation history.
    if self.memory is None:
        self.memory = ChatMemory(window_size)
    else:
        raise ValueError("Memory is already enabled.")

def call(
self,
Expand Down
4 changes: 2 additions & 2 deletions l2m2/memory/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Public surface of the memory subpackage. `Message` was renamed to
# `ChatMemoryEntry`; the old export lines are gone in the final state.
from .chat_memory import ChatMemory, ChatMemoryEntry, DEFAULT_WINDOW_SIZE

__all__ = ["ChatMemory", "ChatMemoryEntry", "DEFAULT_WINDOW_SIZE"]
22 changes: 16 additions & 6 deletions l2m2/memory/chat_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
DEFAULT_WINDOW_SIZE = 40


class Message:
class ChatMemoryEntry:
"""Represents a message in a conversation memory."""

class Role(Enum):
USER = "user"
AGENT = "agent"
Expand All @@ -27,25 +29,31 @@ def __init__(self, window_size: int = DEFAULT_WINDOW_SIZE) -> None:
Args:
window_size (int, optional): The maximum number of messages to store.
Defaults to DEFAULT_WINDOW_SIZE.
Raises:
ValueError: If `window_size` is less than or equal to 0.
"""
if not window_size > 0:
raise ValueError("window_size must be a positive integer.")

self.window_size: int = window_size
self.mem_window: Deque[Message] = deque(maxlen=window_size)
self.mem_window: Deque[ChatMemoryEntry] = deque(maxlen=window_size)

def add_user_message(self, text: str) -> None:
    """Adds a user message to the memory.

    Args:
        text (str): The user message to add.
    """
    # Appending to the bounded deque evicts the oldest entry once
    # window_size is reached.
    self.mem_window.append(ChatMemoryEntry(text, ChatMemoryEntry.Role.USER))

def add_agent_message(self, text: str) -> None:
    """Adds an agent message to the memory.

    Args:
        text (str): The agent message to add.
    """
    # Appending to the bounded deque evicts the oldest entry once
    # window_size is reached.
    self.mem_window.append(ChatMemoryEntry(text, ChatMemoryEntry.Role.AGENT))

def unpack(
self, role_key: str, message_key: str, user_key: str, agent_key: str
Expand Down Expand Up @@ -82,7 +90,9 @@ def unpack(
"""
res = []
for message in self.mem_window:
role_value = user_key if message.role == Message.Role.USER else agent_key
role_value = (
user_key if message.role == ChatMemoryEntry.Role.USER else agent_key
)
res.append({role_key: role_value, message_key: message.text})
return res

Expand All @@ -93,5 +103,5 @@ def clear(self) -> None:
def __len__(self) -> int:
    # Number of messages currently held in the sliding window.
    return len(self.mem_window)

def __iter__(self) -> Iterator[ChatMemoryEntry]:
    """Iterate over stored entries, oldest first."""
    return iter(self.mem_window)
56 changes: 56 additions & 0 deletions tests/l2m2/client/test_llm_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,3 +416,59 @@ def test_multi_provider_pref_inactive(llm_client):
llm_client.add_provider("replicate", "test-key-replicate")
with pytest.raises(ValueError):
llm_client.call(prompt="Hello", model="llama3-70b", prefer_provider="openai")


# -- Tests for memory -- #


@patch(f"{MODULE_PATH}.OpenAI")
def test_memory(mock_openai):
    """Memory records user/agent turns plus call prompt/response, and clears."""
    fake_client = Mock()
    fake_client.chat.completions.create.return_value = construct_mock_from_path(
        "choices[0].message.content"
    )
    mock_openai.return_value = fake_client

    client = LLMClient(enable_memory=True)
    client.add_provider("openai", "fake-api-key")

    client.get_memory().add_user_message("A")
    client.get_memory().add_agent_message("B")

    assert client.call(prompt="C", model="gpt-4-turbo") == "response"

    expected = [
        {"role": "user", "content": "A"},
        {"role": "assistant", "content": "B"},
        {"role": "user", "content": "C"},
        {"role": "assistant", "content": "response"},
    ]
    assert client.get_memory().unpack("role", "content", "user", "assistant") == expected

    client.clear_memory()
    assert client.get_memory().unpack("role", "content", "user", "assistant") == []


def test_memory_errors(llm_client):
    """Accessing memory before enabling raises; enabling twice raises."""
    for operation in (llm_client.get_memory, llm_client.clear_memory):
        with pytest.raises(ValueError):
            operation()

    llm_client.enable_memory()

    # A second enable must fail rather than silently reset the memory.
    with pytest.raises(ValueError):
        llm_client.enable_memory()


def test_memory_unsupported_provider():
    """Calling a provider that lacks memory support raises ValueError."""
    unsupported = [("replicate", "llama3-8b")]
    for provider_name, model_name in unsupported:
        client = LLMClient(enable_memory=True)
        client.add_provider(provider_name, "fake-api-key")
        with pytest.raises(ValueError):
            client.call(prompt="Hello", model=model_name)
Empty file added tests/l2m2/memory/__init__.py
Empty file.
70 changes: 70 additions & 0 deletions tests/l2m2/memory/test_chat_memory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import pytest
from l2m2.memory.chat_memory import ChatMemory, DEFAULT_WINDOW_SIZE, ChatMemoryEntry


def test_chat_memory():
    """A fresh ChatMemory is empty, records entries in order, and clears."""
    memory = ChatMemory()
    assert memory.window_size == DEFAULT_WINDOW_SIZE
    assert len(memory) == 0

    memory.add_user_message("A")
    memory.add_agent_message("B")

    entries = list(memory)
    assert len(memory) == 2
    assert (entries[0].text, entries[0].role) == ("A", ChatMemoryEntry.Role.USER)
    assert (entries[1].text, entries[1].role) == ("B", ChatMemoryEntry.Role.AGENT)

    memory.clear()
    assert len(memory) == 0


def test_unpack():
    """unpack() maps each entry to a dict under the caller-supplied keys."""
    memory = ChatMemory(window_size=10)
    texts = ["A", "B", "C", "D", "E", "F"]
    # Alternate user/agent turns: even indices are user messages.
    for i, text in enumerate(texts):
        if i % 2 == 0:
            memory.add_user_message(text)
        else:
            memory.add_agent_message(text)

    expected = [
        {"role": "user" if i % 2 == 0 else "agent", "text": text}
        for i, text in enumerate(texts)
    ]
    result = memory.unpack("role", "text", "user", "agent")
    assert len(result) == 6
    assert result == expected


def test_sliding_window():
    """Only the most recent window_size entries survive further appends."""
    memory = ChatMemory(window_size=3)
    # Alternate user/agent turns A..F; even indices are user messages.
    for i, text in enumerate("ABCDEF"):
        if i % 2 == 0:
            memory.add_user_message(text)
        else:
            memory.add_agent_message(text)
    assert len(memory) == 3

    assert memory.unpack("role", "text", "user", "agent") == [
        {"role": "agent", "text": "D"},
        {"role": "user", "text": "E"},
        {"role": "agent", "text": "F"},
    ]

    memory.add_agent_message("G")

    # The window slides: "D" is evicted, "G" is appended.
    assert memory.unpack("role", "text", "user", "agent") == [
        {"role": "user", "text": "E"},
        {"role": "agent", "text": "F"},
        {"role": "agent", "text": "G"},
    ]


def test_bad_window_size():
    """ChatMemory rejects non-positive window sizes."""
    with pytest.raises(ValueError):
        ChatMemory(window_size=-1)
    # Boundary case: zero is also invalid (validation requires window_size > 0).
    with pytest.raises(ValueError):
        ChatMemory(window_size=0)

0 comments on commit d111f15

Please sign in to comment.