diff --git a/README.md b/README.md
index 3e2ff24..57506e2 100644
--- a/README.md
+++ b/README.md
@@ -83,7 +83,7 @@ Note that `agentai` automatically parses the Python Enum type (TemperatureUnit)
 3. **Create a Conversation object and add messages**
 
 ```python
-conversation = Conversation()
+conversation = Conversation(model=GPT_MODEL)
 conversation.add_message("user", "what is the weather like today?")
 ```
 
@@ -152,7 +152,7 @@ agent_system_message = """You are ChinookGPT, a helpful assistant who gets answe
 Provide as many details as possible to your users
 
 Begin!"""
-sql_conversation = Conversation()
+sql_conversation = Conversation(model=GPT_MODEL)
 sql_conversation.add_message(role="system", content=agent_system_message)
 sql_conversation.add_message("user", "Hi, who are the top 5 artists by number of tracks")
 assistant_message = chat_complete_execute_fn(
diff --git a/agentai/conversation.py b/agentai/conversation.py
index c1caa1a..f7d0315 100644
--- a/agentai/conversation.py
+++ b/agentai/conversation.py
@@ -13,8 +13,15 @@ class Message(BaseModel):
 
 
 class Conversation:
-    def __init__(self, history: List[Message] = [], id: Optional[str] = None, max_history_tokens: int = 200):
-        self.history: List[Message] = history
+    def __init__(
+        self,
+        history: Optional[List[Message]] = None,
+        id: Optional[str] = None,
+        max_history_tokens: Optional[int] = None,
+        model: Optional[str] = None,
+    ) -> None:
+        self.history: List[Message] = history or []  # None default avoids a shared mutable list
+        self.trimmed_history: List[Message] = []
         self.role_to_color = {
             "system": "red",
             "user": "green",
@@ -23,6 +30,7 @@ def __init__(self, history: List[Message] = [], id: Optional[str] = None, max_hi
         }
         self.id = id
         self.max_history_tokens = max_history_tokens
+        self.model = model
 
     def add_message(self, role: str, content: str, name: Optional[str] = None) -> None:
         message_dict = {"role": role, "content": content}
@@ -30,6 +38,8 @@ def add_message(self, role: str, content: str, name: Optional[str] = None) -> No
             message_dict["name"] = name
         message = Message(**message_dict)
         self.history.append(message)
+        if self.max_history_tokens and self.model:
+            self.trim_history()
 
     def display_conversation(self) -> None:
         for message in self.history:
@@ -40,25 +50,42 @@ def display_conversation(self) -> None:
             )
         )
 
-    def get_history(self) -> List[Message]:
-        """Function to get the conversation history based on the number of tokens"""
-        local = threading.local()
+    def trim_history(self) -> None:
+        """Build a token-limited view of the history and cache it in trimmed_history"""
+
+        # Raise an error if max_history_tokens or model is not set
+        if not self.max_history_tokens:
+            raise ValueError("max_history_tokens is not set in Conversation")
+        if not self.model:
+            raise ValueError("model is not set in Conversation")
+
         try:
-            enc = local.gpt2enc
+            # Cache the tokenizer on the instance; a per-call threading.local() never persists
+            enc = self._encoder
         except AttributeError:
-            enc = tiktoken.get_encoding("gpt2")
-            local.gpt2enc = enc
+            enc = tiktoken.encoding_for_model(self.model)
+            self._encoder = enc
 
         total_tokens = 0
         # Iterate 2 at a time to avoid cutting in between a (prompt, response) pair
-        for i in range(len(self.history) -1, -1, -2):
+        for i in range(len(self.history) - 1, -1, -2):
             # Iterate over the messages in reverse order - from the latest to the oldest messages
             message = self.history[i]  # Message(role='User', content='I appreciate that. Take care too!', name=None)
             content = message.content
             tokens = len(enc.encode(content))
             total_tokens += tokens
             if total_tokens > self.max_history_tokens:
-                # Trim the history inplace to keep the total tokens under max_tokens
-                self.history = self.history[i + 1 :]
+                # Keep the most recent messages that fit the token budget and
+                # cache them in self.trimmed_history so that
+                # get_history(trimmed=True) can return them without re-encoding
+                self.trimmed_history = self.history[i + 1 :]
                 break
+
+    def get_history(self, trimmed: bool = False) -> List[Message]:
+        # Return the token-limited view when requested
+        if trimmed:
+            if not self.trimmed_history:
+                self.trim_history()
+            # Fall back to the full history if nothing needed trimming
+            return self.trimmed_history or self.history
         return self.history
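Taken together, the `agentai/conversation.py` changes are exercised like this (a minimal sketch, assuming `agentai` and `tiktoken` are installed; the model name `gpt-3.5-turbo` is illustrative):

```python
from agentai.conversation import Conversation

# Trimming only kicks in when both max_history_tokens and model are set;
# model selects the tiktoken tokenizer used to count tokens.
conversation = Conversation(max_history_tokens=100, model="gpt-3.5-turbo")
conversation.add_message("user", "Hello!")
conversation.add_message("assistant", "Hi there! How can I help you today?")

full_history = conversation.get_history()                 # every message
trimmed_history = conversation.get_history(trimmed=True)  # token-limited view
```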
diff --git a/docs/04_Evaluating_OpenAI_Functions.ipynb b/docs/04_Evaluating_OpenAI_Functions.ipynb
index 8b0e97d..8de2d64 100644
--- a/docs/04_Evaluating_OpenAI_Functions.ipynb
+++ b/docs/04_Evaluating_OpenAI_Functions.ipynb
@@ -23667,6 +23667,7 @@
    ],
    "source": [
     "import json\n",
+    "\n",
     "with open(\"evaluation_resuls.json\", \"w\") as f:\n",
     "    json.dump(evaluation_results, f, indent=4)\n",
     "\n",
diff --git a/docs/05_Conversation_History.ipynb b/docs/05_Conversation_History.ipynb
index 3c0d2a1..e5c5d3d 100644
--- a/docs/05_Conversation_History.ipynb
+++ b/docs/05_Conversation_History.ipynb
@@ -9,7 +9,10 @@
     "from agentai.conversation import Conversation\n",
     "\n",
     "# Create a conversation object with a maximum history length of 100 tokens\n",
-    "conversation = Conversation(max_history_tokens=100)"
+    "conversation = Conversation(max_history_tokens=100, model=\"gpt-3.5-turbo\")\n",
+    "# Note: max_history_tokens caps how many tokens of conversation history are kept in the prompt.\n",
+    "# 'model' must be set whenever 'max_history_tokens' is used:\n",
+    "# it selects the tokenizer that counts tokens when trimming the history."
    ]
   },
   {
@@ -18,7 +21,6 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "\n",
     "# Chatgpt Simulated conversations of different length\n",
     "conversation.add_message(\"User\", \"Hello!\")\n",
     "conversation.add_message(\"Assistant\", \"Hi there!\")\n",
@@ -77,7 +79,7 @@
     "    \"You're welcome! I hope these tips help you find some relief from stress. If you have any more questions or need further assistance, don't hesitate to ask. Take care!\",\n",
     ")\n",
     "conversation.add_message(\"User\", \"I appreciate that. Take care too!\")\n",
-    "conversation.add_message(\"Assistant\", \"Thank you! Have a fantastic day!\")\n"
+    "conversation.add_message(\"Assistant\", \"Thank you! Have a fantastic day!\")"
    ]
   },
   {
@@ -115,7 +117,9 @@
    "execution_count": null,
    "metadata": {},
    "outputs": [],
-   "source": []
+   "source": [
+    "trimmed_history = conversation.get_history(trimmed=True)"
+   ]
   }
  ],
 "metadata": {
@@ -134,7 +138,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.8.16"
+   "version": "3.11.4"
   },
 "orig_nbformat": 4
},
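For reviewers who want to sanity-check the trimming budget outside the class, here is a standalone sketch of the token counting the patch relies on (the `cl100k_base` fallback for model names `tiktoken` does not recognize is an assumption, not part of this patch):

```python
import tiktoken

def count_tokens(text: str, model: str = "gpt-3.5-turbo") -> int:
    """Count tokens the same way Conversation.trim_history does."""
    try:
        enc = tiktoken.encoding_for_model(model)  # raises KeyError for unknown models
    except KeyError:
        enc = tiktoken.get_encoding("cl100k_base")  # assumed fallback
    return len(enc.encode(text))

print(count_tokens("I appreciate that. Take care too!"))  # a short sentence is only a few tokens
```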