diff --git a/src/beyondllm/generator/generate.py b/src/beyondllm/generator/generate.py
index 66a9535..b9bf127 100644
--- a/src/beyondllm/generator/generate.py
+++ b/src/beyondllm/generator/generate.py
@@ -7,6 +7,7 @@
 import pysbd
 from .base import BaseGenerator,GeneratorConfig
 from dataclasses import dataclass,field
+from beyondllm.memory.base import BaseMemory
 
 def default_llm():
     api_key = os.getenv('GOOGLE_API_KEY')
@@ -52,7 +53,8 @@ class Generate:
     system_prompt: str = None
     retriever:str = ''
     llm: GeminiModel = field(default_factory=default_llm)
-    
+    memory: BaseMemory = None
+    
     def __post_init__(self):
         self.pipeline()
 
@@ -63,20 +65,30 @@ def pipeline(self):
 
         if self.system_prompt is None:
             self.system_prompt = """
            You are an AI assistant who always answer to the user QUERY within the given CONTEXT \
-            You are only job here it to act as knowledge transfer and given accurate response to QUERY by looking in depth into CONTEXT \
-            If QUERY is not with the context, YOU MUST tell `I don't know. The query is not in the given context` \
+            Your only job here is to act as a knowledge-transfer agent and give an accurate response to the QUERY by looking in depth into the CONTEXT. \
+            Use both the CONTEXT and the CHAT HISTORY to answer the QUERY. \
+            If the QUERY is not within the context or chat history, YOU MUST reply `I don't know. The query is not in the given context` \
            YOU MUST not hallucinate. You are best when it comes to answering from CONTEXT \
-            If you FAIL to execute this task, you will be fired and you will suffer \
+            If you FAIL to execute this task, you will be fired and you will suffer
            """
+        memory_content = ""
+        if self.memory is not None:
+            memory_content = self.memory.get_memory()
+
         template = f"""
        {self.system_prompt}
 
        CONTEXT: {temp}
+        --------------------
+        CHAT HISTORY: {memory_content}
 
        QUERY: {self.question}
        """
         self.RESPONSE = self.llm.predict(template)
+        # Store the question and response in memory
+        if self.memory is not None:
+            self.memory.add_to_memory(question=self.question, response=self.RESPONSE)
         return self.CONTEXT,self.RESPONSE
 
     def call(self):
diff --git a/src/beyondllm/memory/__init__.py b/src/beyondllm/memory/__init__.py
new file mode 100644
index 0000000..39af6f5
--- /dev/null
+++ b/src/beyondllm/memory/__init__.py
@@ -0,0 +1 @@
+from .chatBufferMemory import ChatBufferMemory
\ No newline at end of file
diff --git a/src/beyondllm/memory/base.py b/src/beyondllm/memory/base.py
new file mode 100644
index 0000000..99a49a6
--- /dev/null
+++ b/src/beyondllm/memory/base.py
@@ -0,0 +1,35 @@
+from typing import Any, Dict, List, Optional
+from abc import ABC, abstractmethod
+from pydantic import BaseModel
+
+class MemoryConfig(BaseModel):
+    """Base configuration model for all memory components.
+
+    This class can be extended to include more fields specific to certain memory components.
+    """
+    pass
+
+class BaseMemory(ABC):
+    """
+    Base class for memory components.
+    """
+
+    @abstractmethod
+    def load_memory(self, **kwargs) -> None:
+        """Loads the memory component."""
+        raise NotImplementedError
+
+    @abstractmethod
+    def add_to_memory(self, question: str, response: str) -> None:
+        """Adds the conversation question and response to memory."""
+        raise NotImplementedError
+
+    @abstractmethod
+    def get_memory(self) -> Any:
+        """Returns the current memory."""
+        raise NotImplementedError
+
+    @abstractmethod
+    def clear_memory(self) -> None:
+        """Clears the memory."""
+        raise NotImplementedError
diff --git a/src/beyondllm/memory/chatBufferMemory.py b/src/beyondllm/memory/chatBufferMemory.py
new file mode 100644
index 0000000..a9c091d
--- /dev/null
+++ b/src/beyondllm/memory/chatBufferMemory.py
@@ -0,0 +1,39 @@
+from dataclasses import dataclass, field
+from typing import List, Optional, Dict
+from .base import BaseMemory, MemoryConfig
+
+@dataclass
+class ChatBufferMemory(BaseMemory):
+    """
+    ChatBufferMemory stores the recent conversation history as a buffer.
+
+    Args:
+        window_size (int): The maximum number of question/response pairs to keep in the buffer.
+        memory_key (str): The key to identify the memory.
+    """
+
+    window_size: int = 5
+    memory_key: str = "chat_buffer"
+    _buffer: List[Dict[str, str]] = field(default_factory=list)
+    config: MemoryConfig = field(default_factory=MemoryConfig)
+
+    def load_memory(self, **kwargs) -> None:
+        self._buffer = []
+        self.config = MemoryConfig(**kwargs)
+
+    def add_to_memory(self, question: str, response: str) -> None:
+        """Adds the question and response to the buffer, maintaining the window size."""
+        self._buffer.append(
+            {"question": question, "response": response}
+        )
+
+        if len(self._buffer) > self.window_size:
+            self._buffer.pop(0)
+
+    def get_memory(self) -> List[Dict[str, str]]:
+        """Returns the current buffer."""
+        return self._buffer
+
+    def clear_memory(self) -> None:
+        """Clears the buffer."""
+        self._buffer = []
\ No newline at end of file
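
Usage sketch (not part of the patch): a minimal example of how the new memory hook could be exercised, assuming Generate keeps its existing question/retriever arguments and that a retriever has been built beforehand with beyondllm's usual retriever utilities; only the ChatBufferMemory import and the memory= argument come from this change.

    from beyondllm.memory import ChatBufferMemory
    from beyondllm.generator.generate import Generate

    # `retriever` is assumed to have been built elsewhere with beyondllm's
    # retriever helpers; it is unrelated to this patch and only a placeholder here.
    memory = ChatBufferMemory(window_size=5)  # keep the last 5 question/response pairs

    # First turn: CHAT HISTORY is still empty; pipeline() stores the pair
    # via memory.add_to_memory(question=..., response=...).
    pipeline = Generate(question="What is retrieval-augmented generation?",
                        retriever=retriever, memory=memory)
    print(pipeline.call())

    # Follow-up turn with the same memory object: the buffered pairs are
    # injected into the prompt under CHAT HISTORY.
    pipeline = Generate(question="Summarize your previous answer in one line.",
                        retriever=retriever, memory=memory)
    print(pipeline.call())

Because the buffer lives on the ChatBufferMemory instance rather than on Generate, the same memory object can be shared across successive Generate calls to carry a conversation forward, and reset with memory.clear_memory() when a new session starts.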