-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactor chat memory, add docstrings
- Loading branch information
Showing
4 changed files
with
123 additions
and
22 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,3 @@ | ||
from .chat_memory import ChatMemory, MessagePair, DEFAULT_WINDOW_SIZE | ||
from .chat_memory import ChatMemory, Message, DEFAULT_WINDOW_SIZE | ||
|
||
__all__ = ["ChatMemory", "MessagePair", "DEFAULT_WINDOW_SIZE"] | ||
__all__ = ["ChatMemory", "Message", "DEFAULT_WINDOW_SIZE"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,37 +1,97 @@ | ||
from collections import deque | ||
from typing import Deque, Iterator, List, Dict | ||
from enum import Enum | ||
|
||
DEFAULT_WINDOW_SIZE = 40 | ||
|
||
|
||
class MessagePair: | ||
def __init__(self, user: str, agent: str): | ||
self.user = user | ||
self.agent = agent | ||
class Message: | ||
class Role(Enum): | ||
USER = "user" | ||
AGENT = "agent" | ||
|
||
def __init__(self, text: str, role: Role): | ||
self.text = text | ||
self.role = role | ||
|
||
|
||
class ChatMemory: | ||
"""Represents a sliding-window conversation memory between a user and an agent. `ChatMemory` is | ||
the most basic type of memory and is designed to be passed directly to chat-based models such | ||
as `llama3-70b-instruct`. | ||
""" | ||
|
||
def __init__(self, window_size: int = DEFAULT_WINDOW_SIZE) -> None: | ||
"""Create a new ChatMemory object. | ||
Args: | ||
window_size (int, optional): The maximum number of messages to store. | ||
Defaults to DEFAULT_WINDOW_SIZE. | ||
""" | ||
self.window_size: int = window_size | ||
self.memory: Deque[MessagePair] = deque(maxlen=window_size) | ||
self.mem_window: Deque[Message] = deque(maxlen=window_size) | ||
|
||
def add_user_message(self, text: str) -> None: | ||
"""Adds a user message to the memory. | ||
Args: | ||
text (str): The user message to add. | ||
""" | ||
self.mem_window.append(Message(text, Message.Role.USER)) | ||
|
||
def add(self, user: str, agent: str) -> None: | ||
self.memory.append(MessagePair(user, agent)) | ||
def add_agent_message(self, text: str) -> None: | ||
"""Adds an agent message to the memory. | ||
def get(self, index: int) -> MessagePair: | ||
return self.memory[index] | ||
Args: | ||
text (str): The agent message to add. | ||
""" | ||
self.mem_window.append(Message(text, Message.Role.AGENT)) | ||
|
||
def unpack( | ||
self, role_key: str, message_key: str, user_key: str, agent_key: str | ||
) -> List[Dict[str, str]]: | ||
"""Gets a representation of the memory as a list of objects designed to | ||
be passed directly into LLM provider APIs as JSON. | ||
For example, with the following memory: | ||
``` | ||
memory = ChatMemory() | ||
memory.add_user_message("Hello") | ||
memory.add_agent_message("Hi!") | ||
memory.add_user_message("How are you?") | ||
memory.add_agent_message("I'm good, how are you?") | ||
``` | ||
`memory.unpack("role", "content", "user", "assistant")` would return: | ||
``` | ||
[ | ||
{"role": "user", "content": "Hello"}, | ||
{"role": "assistant", "content": "Hi!"}, | ||
{"role": "user", "content": "How are you?"}, | ||
{"role": "assistant", "content": "I'm good, how are you?"} | ||
] | ||
``` | ||
Args: | ||
role_key (str): The key to use to denote role. For example, "role". | ||
message_key (str): The key to use to denote the message. For example, "content". | ||
user_key (str): The key to denote the user's message. For example, "user". | ||
agent_key (str): The key to denote the agent or model's message. For example, "assistant". | ||
Returns: | ||
List[Dict[str, str]]: The representation of the memory as a list of objects. | ||
""" | ||
res = [] | ||
for pair in self.memory: | ||
res.append({role_key: user_key, message_key: pair.user}) | ||
res.append({role_key: agent_key, message_key: pair.agent}) | ||
for message in self.mem_window: | ||
role_value = user_key if message.role == Message.Role.USER else agent_key | ||
res.append({role_key: role_value, message_key: message.text}) | ||
return res | ||
|
||
def clear(self) -> None: | ||
"""Clears the memory.""" | ||
self.mem_window.clear() | ||
|
||
def __len__(self) -> int: | ||
return len(self.memory) | ||
return len(self.mem_window) | ||
|
||
def __iter__(self) -> Iterator[MessagePair]: | ||
return iter(self.memory) | ||
def __iter__(self) -> Iterator[Message]: | ||
return iter(self.mem_window) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters