Memory component (#60)
* Fixed webloader bug

* Added memory component and Chat Buffer memory
adithya-aiplanet authored Jul 16, 2024
1 parent a50f19b commit 75b9f2e
Showing 4 changed files with 93 additions and 4 deletions.
20 changes: 16 additions & 4 deletions src/beyondllm/generator/generate.py
@@ -7,6 +7,7 @@
 import pysbd
 from .base import BaseGenerator,GeneratorConfig
 from dataclasses import dataclass,field
+from beyondllm.memory.base import BaseMemory

 def default_llm():
     api_key = os.getenv('GOOGLE_API_KEY')
@@ -52,7 +53,8 @@ class Generate:
     system_prompt: str = None
     retriever:str = ''
     llm: GeminiModel = field(default_factory=default_llm)
-
+    memory: BaseMemory = None
+
     def __post_init__(self):
         self.pipeline()

@@ -63,20 +65,30 @@ def pipeline(self):
         if self.system_prompt is None:
             self.system_prompt = """
             You are an AI assistant who always answer to the user QUERY within the given CONTEXT \
-            You are only job here it to act as knowledge transfer and given accurate response to QUERY by looking in depth into CONTEXT \
-            If QUERY is not with the context, YOU MUST tell `I don't know. The query is not in the given context` \
+            You are only job here it to act as knowledge transfer and given accurate response to QUERY by looking in depth into CONTEXT. \
+            Use both the context and CHAT HISTORY to answer the query. \
+            If QUERY is not within the context or chat history, YOU MUST tell `I don't know. The query is not in the given context` \
             YOU MUST not hallucinate. You are best when it comes to answering from CONTEXT \
-            If you FAIL to execute this task, you will be fired and you will suffer \
+            If you FAIL to execute this task, you will be fired and you will suffer
             """

+        memory_content = ""
+        if self.memory is not None:
+            memory_content = self.memory.get_memory()
+
         template = f"""
         {self.system_prompt}
         CONTEXT: {temp}
         --------------------
+        CHAT HISTORY: {memory_content}
         QUERY: {self.question}
         """

         self.RESPONSE = self.llm.predict(template)
+        # Store the question and response in memory
+        if self.memory is not None:
+            self.memory.add_to_memory(question=self.question, response=self.RESPONSE)
         return self.CONTEXT,self.RESPONSE

     def call(self):
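Taken together, the generate.py changes make memory optional and self-contained: prior turns are injected into the prompt as CHAT HISTORY, and each question/response pair is written back after prediction. A minimal usage sketch, with caveats: `retriever` below is a placeholder for whatever retriever Generate already accepts, the `question` keyword is assumed from the `self.question` reference in pipeline(), and the default Gemini llm needs GOOGLE_API_KEY set in the environment.

    from beyondllm.generator.generate import Generate
    from beyondllm.memory import ChatBufferMemory

    memory = ChatBufferMemory(window_size=5)

    # First turn: the buffer is empty, so CHAT HISTORY renders blank.
    # Constructing Generate runs pipeline() immediately via __post_init__.
    first = Generate(
        question="What does BeyondLLM do?",
        retriever=retriever,  # placeholder: any retriever Generate accepts
        memory=memory,
    )

    # Second turn: reusing the same memory object lets the follow-up
    # resolve against the stored first exchange.
    second = Generate(
        question="How would I get started with it?",
        retriever=retriever,
        memory=memory,
    )

Because get_memory() returns the raw list of question/response dicts, that list is interpolated directly into the CHAT HISTORY slot of the prompt template.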
1 change: 1 addition & 0 deletions src/beyondllm/memory/__init__.py
@@ -0,0 +1 @@
from .chatBufferMemory import ChatBufferMemory
35 changes: 35 additions & 0 deletions src/beyondllm/memory/base.py
@@ -0,0 +1,35 @@
from typing import Any, Dict, List, Optional
from abc import ABC, abstractmethod
from pydantic import BaseModel

class MemoryConfig(BaseModel):
    """Base configuration model for all memory components.
    This class can be extended to include more fields specific to certain memory components.
    """
    pass

class BaseMemory(ABC):
    """
    Base class for memory components.
    """

    @abstractmethod
    def load_memory(self, **kwargs) -> None:
        """Loads the memory component."""
        raise NotImplementedError

    @abstractmethod
    def add_to_memory(self, context: str, response: str) -> None:
        """Adds the conversation context and response to memory."""
        raise NotImplementedError

    @abstractmethod
    def get_memory(self) -> Any:
        """Returns the current memory."""
        raise NotImplementedError

    @abstractmethod
    def clear_memory(self) -> None:
        """Clears the memory."""
        raise NotImplementedError
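A concrete memory only needs to fill in these four methods. One wrinkle worth noting: generate.py calls add_to_memory(question=..., response=...) by keyword, while the abstract method names its first parameter context, so subclasses should name the parameter question as ChatBufferMemory does. As an illustrative sketch, not part of this commit, an unbounded whole-history memory might look like:

    from typing import Any, Dict, List
    from beyondllm.memory.base import BaseMemory, MemoryConfig

    class FullHistoryMemory(BaseMemory):
        """Hypothetical memory that keeps every exchange, with no window."""

        def __init__(self) -> None:
            self._history: List[Dict[str, str]] = []
            self.config = MemoryConfig()

        def load_memory(self, **kwargs) -> None:
            # Reset the history and rebuild the config, mirroring ChatBufferMemory.
            self._history = []
            self.config = MemoryConfig(**kwargs)

        def add_to_memory(self, question: str, response: str) -> None:
            # Parameter named `question` to match the call site in generate.py.
            self._history.append({"question": question, "response": response})

        def get_memory(self) -> List[Dict[str, str]]:
            return self._history

        def clear_memory(self) -> None:
            self._history = []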
41 changes: 41 additions & 0 deletions src/beyondllm/memory/chatBufferMemory.py
@@ -0,0 +1,41 @@
from dataclasses import dataclass, field
from typing import List, Optional, Dict
from .base import BaseMemory, MemoryConfig
import re

@dataclass
class ChatBufferMemory(BaseMemory):
    """
    ChatBufferMemory stores the recent conversation history as a buffer.
    Args:
        window_size (int): The maximum number of messages to store in the buffer.
        memory_key (str): The key to identify the memory.
        max_tokens (int): The maximum number of tokens to store in the buffer. Defaults to None, which means no limit.
    """

    window_size: int = 5
    memory_key: str = "chat_buffer"
    _buffer: List[Dict[str, str]] = field(default_factory=list)
    config: MemoryConfig = field(default_factory=MemoryConfig)

    def load_memory(self, **kwargs) -> None:
        self._buffer = []
        self.config = MemoryConfig(**kwargs)

    def add_to_memory(self, question: str, response: str) -> None:
        """Adds the context and response to the buffer, maintaining the window size."""
        self._buffer.append(
            {"question": question, "response": response}
        )

        if len(self._buffer) > self.window_size+1:
            self._buffer.pop(0)

    def get_memory(self) -> List[Dict[str, str]]:
        """Returns the current buffer."""
        return self._buffer

    def clear_memory(self) -> None:
        """Clears the buffer."""
        self._buffer = []
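Used standalone, the buffer behaves as a FIFO over recent exchanges. Note that trimming only fires once the length exceeds window_size + 1, so up to window_size + 1 exchanges are actually retained (six with the default window_size of 5). A quick check of that behavior:

    from beyondllm.memory import ChatBufferMemory

    memory = ChatBufferMemory(window_size=2)
    for i in range(5):
        memory.add_to_memory(question=f"q{i}", response=f"r{i}")

    # len > window_size + 1 triggers pop(0), so three exchanges survive:
    # q0 and q1 have been evicted, q2..q4 remain.
    assert [m["question"] for m in memory.get_memory()] == ["q2", "q3", "q4"]

    memory.clear_memory()
    assert memory.get_memory() == []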
