Skip to content

Commit

Permalink
add unpack for memory
Browse files Browse the repository at this point in the history
  • Loading branch information
pkelaita committed May 5, 2024
1 parent e9583e0 commit a1b013c
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 23 deletions.
34 changes: 12 additions & 22 deletions l2m2/client/llm_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,13 +358,11 @@ def _call_openai(
params: Dict[str, Any],
) -> str:
oai = OpenAI(api_key=self.api_keys["openai"])
messages: List[Dict[str, str]] = []
messages = []
if system_prompt is not None:
messages.insert(0, {"role": "system", "content": system_prompt})
messages.append({"role": "system", "content": system_prompt})
if self.memory is not None:
for pair in self.memory:
messages.append({"role": "user", "content": pair.user})
messages.append({"role": "assistant", "content": pair.agent})
messages.extend(self.memory.unpack("role", "content", "user", "assistant"))
messages.append({"role": "user", "content": prompt})
result = oai.chat.completions.create(
model=model_id,
Expand All @@ -383,11 +381,9 @@ def _call_anthropic(
anthr = Anthropic(api_key=self.api_keys["anthropic"])
if system_prompt is not None:
params["system"] = system_prompt
messages: List[Dict[str, str]] = []
messages = []
if self.memory is not None:
for pair in self.memory:
messages.append({"role": "user", "content": pair.user})
messages.append({"role": "assistant", "content": pair.agent})
messages.extend(self.memory.unpack("role", "content", "user", "assistant"))
messages.append({"role": "user", "content": prompt})
result = anthr.messages.create(
model=model_id,
Expand All @@ -407,11 +403,9 @@ def _call_cohere(
if system_prompt is not None:
params["preamble"] = system_prompt
if self.memory is not None:
chat_history = []
for pair in self.memory:
chat_history.append({"role": "USER", "message": pair.user})
chat_history.append({"role": "CHATBOT", "message": pair.agent})
params["chat_history"] = chat_history
params["chat_history"] = self.memory.unpack(
"role", "message", "USER", "CHATBOT"
)
result = cohere.chat(
model=model_id,
message=prompt,
Expand All @@ -431,9 +425,7 @@ def _call_groq(
if system_prompt is not None:
messages.append({"role": "system", "content": system_prompt})
if self.memory is not None:
for pair in self.memory:
messages.append({"role": "user", "content": pair.user})
messages.append({"role": "assistant", "content": pair.agent})
messages.extend(self.memory.unpack("role", "content", "user", "assistant"))
messages.append({"role": "user", "content": prompt})
result = groq.chat.completions.create(
model=model_id,
Expand All @@ -458,15 +450,13 @@ def _call_google(
prompt = f"{system_prompt}\n{prompt}"
else:
model_params["system_instruction"] = system_prompt
model = google.GenerativeModel(**model_params)

messages = []
if self.memory is not None:
for pair in self.memory:
messages.append({"role": "user", "parts": [pair.user]})
messages.append({"role": "model", "parts": [pair.agent]})
messages.append({"role": "user", "parts": [prompt]})
messages.extend(self.memory.unpack("role", "parts", "user", "model"))
messages.append({"role": "user", "parts": prompt})

model = google.GenerativeModel(**model_params)
result = model.generate_content(messages, generation_config=params)
result = result.candidates[0]

Expand Down
11 changes: 10 additions & 1 deletion l2m2/memory/chat_memory.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections import deque
from typing import Deque, Iterator
from typing import Deque, Iterator, List, Dict

DEFAULT_WINDOW_SIZE = 40

Expand All @@ -21,6 +21,15 @@ def add(self, user: str, agent: str) -> None:
def get(self, index: int) -> MessagePair:
    """Return the MessagePair stored at position *index*.

    Supports negative indices per normal Python sequence semantics;
    raises IndexError when *index* is out of range.
    """
    return self.memory[index]

def unpack(
    self, role_key: str, message_key: str, user_key: str, agent_key: str
) -> List[Dict[str, str]]:
    """Flatten the stored conversation into a list of message dicts.

    Each stored pair yields two entries in chronological order — the user
    turn followed by the agent turn. The dict keys and speaker labels are
    parameterized so each provider's expected schema can be produced, e.g.
    ``unpack("role", "content", "user", "assistant")`` for OpenAI-style
    APIs or ``unpack("role", "message", "USER", "CHATBOT")`` for Cohere.

    Args:
        role_key: Dict key under which the speaker label is stored.
        message_key: Dict key under which the message text is stored.
        user_key: Speaker label used for user turns.
        agent_key: Speaker label used for agent turns.

    Returns:
        A flat list of message dicts, two per stored pair, oldest first.
        Empty list when no history is stored.
    """
    # Comprehension replaces the manual append loop (idiomatic, and
    # inlined at C speed on modern CPython); order is preserved.
    return [
        {role_key: speaker, message_key: text}
        for pair in self.memory
        for speaker, text in ((user_key, pair.user), (agent_key, pair.agent))
    ]

def __len__(self) -> int:
    """Return the number of stored user/agent message pairs.

    Note: this counts pairs, not individual messages — unpack() returns
    twice this many entries.
    """
    return len(self.memory)

Expand Down

0 comments on commit a1b013c

Please sign in to comment.