From 55ac05dd6419a0bceb2702636cabe8b362f69138 Mon Sep 17 00:00:00 2001
From: Yx Jiang <2237303+yxjiang@users.noreply.github.com>
Date: Thu, 14 Mar 2024 21:42:40 -0700
Subject: [PATCH] Add openai chat tool

---
 polymind/__init__.py                   |  4 ++
 polymind/core/tool.py                  |  4 +-
 polymind/tools/oai_tools.py            | 59 ++++++++++++++++++++++++++
 tests/polymind/tools/test_oai_tools.py | 56 ++++++++++++++++++++++++
 4 files changed, 121 insertions(+), 2 deletions(-)
 create mode 100644 polymind/tools/oai_tools.py
 create mode 100644 tests/polymind/tools/test_oai_tools.py

diff --git a/polymind/__init__.py b/polymind/__init__.py
index affa643..2562c55 100644
--- a/polymind/__init__.py
+++ b/polymind/__init__.py
@@ -1,4 +1,8 @@
+# Expose the core classes
 from .core.tool import BaseTool
 from .core.message import Message
 from .core.task import BaseTask, CompositeTask, SequentialTask
 from .core.agent import Agent, ThoughtProcess
+
+# Expose the tools
+from .tools.oai_tools import OpenAIChatTool
diff --git a/polymind/core/tool.py b/polymind/core/tool.py
index 032419f..3b7a0bd 100644
--- a/polymind/core/tool.py
+++ b/polymind/core/tool.py
@@ -13,8 +13,8 @@ class BaseTool(BaseModel, ABC):
 
     tool_name: str
 
-    def __init__(self, **data):
-        super().__init__(**data)
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
         load_dotenv(override=True)
 
     def __str__(self):
diff --git a/polymind/tools/oai_tools.py b/polymind/tools/oai_tools.py
new file mode 100644
index 0000000..973cc3d
--- /dev/null
+++ b/polymind/tools/oai_tools.py
@@ -0,0 +1,59 @@
+"""
+This file contains the necessary tools for using OpenAI models.
+"""
+
+from pydantic import Field
+from polymind.core.tool import BaseTool
+from polymind.core.message import Message
+from openai import AsyncOpenAI
+import os
+from dotenv import load_dotenv
+
+
+class OpenAIChatTool(BaseTool):
+    """OpenAIChatTool is a bridge to the OpenAI APIs.
+    The tool can be initialized with llm_name, system_prompt, max_tokens, and temperature.
+    The input message of this tool should contain a "prompt", and optionally a "system_prompt".
+    The "system_prompt" in the input message will override the default system_prompt.
+    The tool will return a message with the response from the OpenAI chat API.
+    """
+
+    class Config:
+        arbitrary_types_allowed: bool = True  # Allow arbitrary types
+
+    tool_name: str = "open-ai-chat"
+    client: AsyncOpenAI = Field(default=None)
+    llm_name: str = Field(default="gpt-3.5-turbo")
+    system_prompt: str = Field(default="You are a helpful AI assistant.")
+    max_tokens: int = Field(default=1500)
+    temperature: float = Field(default=0.7)
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.client = AsyncOpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
+
+    async def _execute(self, input: Message) -> Message:
+        """Execute the tool and return the result.
+        The derived class must implement this method to define the behavior of the tool.
+
+        Args:
+            input (Message): The input to the tool carried in a message.
+
+        Returns:
+            Message: The result of the tool carried in a message.
+        """
+        prompt = input.get("prompt", "")
+        system_prompt = input.get("system_prompt", self.system_prompt)
+        if not prompt:
+            raise ValueError("Prompt cannot be empty.")
+
+        response = await self.client.chat.completions.create(
+            model=self.llm_name,
+            messages=[
+                {"role": "system", "content": system_prompt},
+                {"role": "user", "content": prompt},
+            ],
+        )
+        content = response.choices[0].message.content
+        response_message = Message(content={"response": content})
+        return response_message
diff --git a/tests/polymind/tools/test_oai_tools.py b/tests/polymind/tools/test_oai_tools.py
new file mode 100644
index 0000000..11d5c6c
--- /dev/null
+++ b/tests/polymind/tools/test_oai_tools.py
@@ -0,0 +1,56 @@
+"""Test cases for OpenAIChatTool.
+Run the test with the following command:
+    poetry run pytest tests/polymind/tools/test_oai_tools.py
+"""
+
+"""Test cases for OpenAIChatTool."""
+
+import pytest
+from unittest.mock import AsyncMock, patch
+from polymind.core.message import Message
+from polymind.tools.oai_tools import OpenAIChatTool
+
+
+@pytest.fixture
+def mock_env_vars(monkeypatch):
+    """Fixture to mock environment variables."""
+    monkeypatch.setenv("OPENAI_API_KEY", "test_key")
+
+
+@pytest.fixture
+def tool(mock_env_vars):
+    """Fixture to create an instance of OpenAIChatTool with mocked environment variables."""
+    llm_name = "gpt-4-turbo"
+    system_prompt = "You are an orchestrator"
+    return OpenAIChatTool(llm_name=llm_name, system_prompt=system_prompt)
+
+
+@pytest.mark.asyncio
+async def test_execute_success(tool):
+    """Test _execute method of OpenAIChatTool for successful API call."""
+    prompt = "How are you?"
+    system_prompt = "You are a helpful AI assistant."
+    expected_response_content = "I'm doing great, thanks for asking!"
+
+    # Patch the specific instance of AsyncOpenAI used by our tool instance
+    with patch.object(
+        tool.client.chat.completions, "create", new_callable=AsyncMock
+    ) as mock_create:
+        mock_create.return_value = AsyncMock(
+            choices=[AsyncMock(message=AsyncMock(content=expected_response_content))]
+        )
+        input_message = Message(
+            content={"prompt": prompt, "system_prompt": system_prompt}
+        )
+        response_message = await tool._execute(input_message)
+
+        assert response_message.content["response"] == expected_response_content
+        mock_create.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_execute_failure_empty_prompt(tool):
+    """Test _execute method of OpenAIChatTool raises ValueError with empty prompt."""
+    with pytest.raises(ValueError) as excinfo:
+        await tool._execute(Message(content={"prompt": ""}))
+    assert "Prompt cannot be empty." in str(excinfo.value)