refactor chat memory, add docstrings
pkelaita committed May 5, 2024
1 parent a1b013c commit 9336db0
Showing 4 changed files with 123 additions and 22 deletions.
49 changes: 44 additions & 5 deletions l2m2/client/llm_client.py
@@ -1,4 +1,4 @@
-from typing import Any, Set, Dict, Optional, List
+from typing import Any, Set, Dict, Optional
 
 import google.generativeai as google
 from cohere import Client as CohereClient
@@ -180,6 +180,40 @@ def set_preferred_providers(self, preferred_providers: Dict[str, str]) -> None:
 
         self.preferred_providers.update(preferred_providers)
 
+    def get_memory(self) -> ChatMemory:
+        """Get the memory object, if memory is enabled.
+
+        Returns:
+            ChatMemory: The memory object.
+
+        Raises:
+            ValueError: If memory is not enabled.
+        """
+        if self.memory is None:
+            raise ValueError(
+                "Client memory is not enabled. Instantiate the LLM client with enable_memory=True"
+                + " to enable memory."
+            )
+        return self.memory
+
+    def clear_memory(self) -> None:
+        """Clear the memory, if memory is enabled.
+
+        Raises:
+            ValueError: If memory is not enabled.
+        """
+        if self.memory is None:
+            raise ValueError(
+                "Client memory is not enabled. Instantiate the LLM client with enable_memory=True"
+                + " to enable memory."
+            )
+        self.memory.clear()
+
+    def enable_memory(self) -> None:
+        """Enable memory if it is not already enabled."""
+        if self.memory is None:
+            self.memory = ChatMemory()
+
     def call(
         self,
         *,
@@ -249,15 +283,14 @@ def call(
                 + " default provider for the model with set_preferred_providers."
             )
 
-        result = self._call_impl(
+        return self._call_impl(
             MODEL_INFO[model][provider],
             provider,
             prompt,
             system_prompt,
             temperature,
             max_tokens,
         )
-        return result
 
     def call_custom(
         self,
@@ -310,7 +343,12 @@ def call_custom(
         }
 
         return self._call_impl(
-            model_info, provider, prompt, system_prompt, temperature, max_tokens
+            model_info,
+            provider,
+            prompt,
+            system_prompt,
+            temperature,
+            max_tokens,
         )
 
     def _call_impl(
@@ -347,7 +385,8 @@ def add_param(name: str, value: Any) -> None:
         )
         assert isinstance(result, str), "This should never happen."
         if self.memory is not None:
-            self.memory.add(prompt, result)
+            self.memory.add_user_message(prompt)
+            self.memory.add_agent_message(result)
         return result
 
     def _call_openai(
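Taken together, the new client surface is get_memory(), clear_memory(), and enable_memory(), with call() now recording each exchange as two separate messages. A minimal usage sketch, assuming the class in this file is LLMClient, the import path follows the repository layout, a provider API key is configured, and the model name is purely illustrative:

```
from l2m2.client import LLMClient

# enable_memory=True mirrors the flag named in the error messages above.
client = LLMClient(enable_memory=True)

# call() stores the exchange via add_user_message() and add_agent_message().
client.call(model="llama3-70b-instruct", prompt="Hello!")

memory = client.get_memory()  # raises ValueError if memory is disabled
print(len(memory))            # 2: one user message, one agent message
client.clear_memory()         # empties the sliding window
```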
4 changes: 2 additions & 2 deletions l2m2/memory/__init__.py
@@ -1,3 +1,3 @@
-from .chat_memory import ChatMemory, MessagePair, DEFAULT_WINDOW_SIZE
+from .chat_memory import ChatMemory, Message, DEFAULT_WINDOW_SIZE
 
-__all__ = ["ChatMemory", "MessagePair", "DEFAULT_WINDOW_SIZE"]
+__all__ = ["ChatMemory", "Message", "DEFAULT_WINDOW_SIZE"]
90 changes: 75 additions & 15 deletions l2m2/memory/chat_memory.py
@@ -1,37 +1,97 @@
 from collections import deque
 from typing import Deque, Iterator, List, Dict
+from enum import Enum
 
 DEFAULT_WINDOW_SIZE = 40
 
 
-class MessagePair:
-    def __init__(self, user: str, agent: str):
-        self.user = user
-        self.agent = agent
+class Message:
+    class Role(Enum):
+        USER = "user"
+        AGENT = "agent"
+
+    def __init__(self, text: str, role: Role):
+        self.text = text
+        self.role = role
 
 
 class ChatMemory:
+    """Represents a sliding-window conversation memory between a user and an agent. `ChatMemory` is
+    the most basic type of memory and is designed to be passed directly to chat-based models such
+    as `llama3-70b-instruct`.
+    """
+
     def __init__(self, window_size: int = DEFAULT_WINDOW_SIZE) -> None:
+        """Create a new ChatMemory object.
+
+        Args:
+            window_size (int, optional): The maximum number of messages to store.
+                Defaults to DEFAULT_WINDOW_SIZE.
+        """
         self.window_size: int = window_size
-        self.memory: Deque[MessagePair] = deque(maxlen=window_size)
+        self.mem_window: Deque[Message] = deque(maxlen=window_size)
 
-    def add(self, user: str, agent: str) -> None:
-        self.memory.append(MessagePair(user, agent))
+    def add_user_message(self, text: str) -> None:
+        """Adds a user message to the memory.
+
+        Args:
+            text (str): The user message to add.
+        """
+        self.mem_window.append(Message(text, Message.Role.USER))
 
-    def get(self, index: int) -> MessagePair:
-        return self.memory[index]
+    def add_agent_message(self, text: str) -> None:
+        """Adds an agent message to the memory.
+
+        Args:
+            text (str): The agent message to add.
+        """
+        self.mem_window.append(Message(text, Message.Role.AGENT))
 
     def unpack(
         self, role_key: str, message_key: str, user_key: str, agent_key: str
     ) -> List[Dict[str, str]]:
+        """Gets a representation of the memory as a list of objects designed to
+        be passed directly into LLM provider APIs as JSON.
+
+        For example, with the following memory:
+
+        ```
+        memory = ChatMemory()
+        memory.add_user_message("Hello")
+        memory.add_agent_message("Hi!")
+        memory.add_user_message("How are you?")
+        memory.add_agent_message("I'm good, how are you?")
+        ```
+
+        `memory.unpack("role", "content", "user", "assistant")` would return:
+
+        ```
+        [
+            {"role": "user", "content": "Hello"},
+            {"role": "assistant", "content": "Hi!"},
+            {"role": "user", "content": "How are you?"},
+            {"role": "assistant", "content": "I'm good, how are you?"}
+        ]
+        ```
+
+        Args:
+            role_key (str): The key to use to denote role. For example, "role".
+            message_key (str): The key to use to denote the message. For example, "content".
+            user_key (str): The key to denote the user's message. For example, "user".
+            agent_key (str): The key to denote the agent or model's message. For example, "assistant".
+
+        Returns:
+            List[Dict[str, str]]: The representation of the memory as a list of objects.
+        """
         res = []
-        for pair in self.memory:
-            res.append({role_key: user_key, message_key: pair.user})
-            res.append({role_key: agent_key, message_key: pair.agent})
+        for message in self.mem_window:
+            role_value = user_key if message.role == Message.Role.USER else agent_key
+            res.append({role_key: role_value, message_key: message.text})
         return res
 
+    def clear(self) -> None:
+        """Clears the memory."""
+        self.mem_window.clear()
+
     def __len__(self) -> int:
-        return len(self.memory)
+        return len(self.mem_window)
 
-    def __iter__(self) -> Iterator[MessagePair]:
-        return iter(self.memory)
+    def __iter__(self) -> Iterator[Message]:
+        return iter(self.mem_window)
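Because the deque now holds individual Message objects rather than MessagePair objects, window_size counts single messages, so a window of 40 retains 20 user/agent exchanges. A short sketch of the eviction behavior, using only the classes defined above (window_size=2 is chosen purely to force eviction):

```
from l2m2.memory import ChatMemory

mem = ChatMemory(window_size=2)  # backed by deque(maxlen=2)
mem.add_user_message("Hello")
mem.add_agent_message("Hi!")
mem.add_user_message("How are you?")  # evicts "Hello" from the window

print(mem.unpack("role", "content", "user", "assistant"))
# [{'role': 'assistant', 'content': 'Hi!'},
#  {'role': 'user', 'content': 'How are you?'}]
```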
2 changes: 2 additions & 0 deletions tests/l2m2/client/test_llm_client.py
@@ -148,6 +148,8 @@ def _generic_test_call(
     model_name,
 ):
     mock_client = Mock()
+    if provider_key != "replicate":
+        llm_client.enable_memory()
 
     # Dynamically get the mock call and response objects based on the delimited paths
     mock_call = get_nested_attribute(mock_client, call_path)
