Add openai chat tool
yxjiang committed Mar 15, 2024
1 parent abd8ad9 commit 55ac05d
Showing 4 changed files with 121 additions and 2 deletions.
4 changes: 4 additions & 0 deletions polymind/__init__.py
@@ -1,4 +1,8 @@
# Expose the core classes
from .core.tool import BaseTool
from .core.message import Message
from .core.task import BaseTask, CompositeTask, SequentialTask
from .core.agent import Agent, ThoughtProcess

# Expose the tools
from .tools.oai_tools import OpenAIChatTool
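
With these exports, the new tool can be imported from the package root alongside the core classes. A minimal sketch (import paths taken from the diff; the short-form imports are the only addition here):

# Package-root imports made possible by this change:
from polymind import Message, OpenAIChatTool

# Equivalent to the longer submodule paths:
# from polymind.core.message import Message
# from polymind.tools.oai_tools import OpenAIChatTool
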
4 changes: 2 additions & 2 deletions polymind/core/tool.py
@@ -13,8 +13,8 @@ class BaseTool(BaseModel, ABC):

tool_name: str

-    def __init__(self, **data):
-        super().__init__(**data)
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
load_dotenv(override=True)

def __str__(self):
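
The only change here renames the constructor's **data to **kwargs; the existing load_dotenv(override=True) call means every tool reloads the .env file when it is instantiated, so subclasses can read API keys from the environment in their own __init__. A minimal sketch of that subclass pattern (EchoTool and EXAMPLE_API_KEY are hypothetical names, not part of this commit):

import os

from pydantic import Field

from polymind.core.message import Message
from polymind.core.tool import BaseTool


class EchoTool(BaseTool):
    """Hypothetical tool illustrating the BaseTool constructor pattern."""

    tool_name: str = "echo"
    api_key: str = Field(default="")

    def __init__(self, **kwargs):
        # BaseTool.__init__ calls load_dotenv(override=True), so .env values are available below.
        super().__init__(**kwargs)
        self.api_key = os.environ.get("EXAMPLE_API_KEY", "")

    async def _execute(self, input: Message) -> Message:
        # Echo the prompt back, following the Message-in / Message-out contract.
        return Message(content={"response": input.get("prompt", "")})
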
59 changes: 59 additions & 0 deletions polymind/tools/oai_tools.py
@@ -0,0 +1,59 @@
"""
This file contains the tools necessary for using OpenAI models.
"""

from pydantic import Field
from polymind.core.tool import BaseTool
from polymind.core.message import Message
from openai import AsyncOpenAI
import os
from dotenv import load_dotenv


class OpenAIChatTool(BaseTool):
"""OpenAITool is a bridge to OpenAI APIs.
The tool can be initialized with llm_name, system_prompt, max_tokens, and temperature.
The input message of this tool should contain a "prompt", and optionally a "system_prompt".
The "system_prompt" in the input message will override the default system_prompt.
The tool will return a message with the response from the OpenAI chat API.
"""

class Config:
        arbitrary_types_allowed: bool = True  # Allow non-Pydantic types such as AsyncOpenAI

tool_name: str = "open-ai-chat"
client: AsyncOpenAI = Field(default=None)
llm_name: str = Field(default="gpt-3.5-turbo")
system_prompt: str = Field(default="You are a helpful AI assistant.")
max_tokens: int = Field(default=1500)
temperature: float = Field(default=0.7)

def __init__(self, **kwargs):
super().__init__(**kwargs)
self.client = AsyncOpenAI(api_key=os.environ.get("OPENAI_API_KEY"))

async def _execute(self, input: Message) -> Message:
"""Execute the tool and return the result.
        It sends the prompt to the OpenAI chat completion API and wraps the reply in a Message.
Args:
input (Message): The input to the tool carried in a message.
Returns:
Message: The result of the tool carried in a message.
"""
prompt = input.get("prompt", "")
system_prompt = input.get("system_prompt", self.system_prompt)
if not prompt:
raise ValueError("Prompt cannot be empty.")

response = await self.client.chat.completions.create(
model=self.llm_name,
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": prompt},
],
)
content = response.choices[0].message.content
response_message = Message(content={"response": content})
return response_message
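
A short usage sketch of the tool above (OPENAI_API_KEY must be set in the environment or a .env file; the tests below call _execute directly, so the same entry point is used here, though BaseTool may expose a public wrapper not shown in this diff):

import asyncio

from polymind.core.message import Message
from polymind.tools.oai_tools import OpenAIChatTool


async def main():
    tool = OpenAIChatTool(llm_name="gpt-3.5-turbo")
    # "system_prompt" is optional and overrides the default configured on the tool.
    request = Message(content={"prompt": "Name three uses of embeddings."})
    response = await tool._execute(request)
    print(response.content["response"])


asyncio.run(main())
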
56 changes: 56 additions & 0 deletions tests/polymind/tools/test_oai_tools.py
@@ -0,0 +1,56 @@
"""Test cases for OpenAIChatTool.
Run the test with the following command:
poetry run pytest tests/polymind/tools/test_oai_tools.py
"""

"""Test cases for OpenAIChatTool."""

import pytest
from unittest.mock import AsyncMock, patch
from polymind.core.message import Message
from polymind.tools.oai_tools import OpenAIChatTool


@pytest.fixture
def mock_env_vars(monkeypatch):
"""Fixture to mock environment variables."""
monkeypatch.setenv("OPENAI_API_KEY", "test_key")


@pytest.fixture
def tool(mock_env_vars):
"""Fixture to create an instance of OpenAIChatTool with mocked environment variables."""
llm_name = "gpt-4-turbo"
system_prompt = "You are an orchestrator"
return OpenAIChatTool(llm_name=llm_name, system_prompt=system_prompt)


@pytest.mark.asyncio
async def test_execute_success(tool):
"""Test _execute method of OpenAIChatTool for successful API call."""
prompt = "How are you?"
system_prompt = "You are a helpful AI assistant."
expected_response_content = "I'm doing great, thanks for asking!"

# Patch the specific instance of AsyncOpenAI used by our tool instance
with patch.object(
tool.client.chat.completions, "create", new_callable=AsyncMock
) as mock_create:
mock_create.return_value = AsyncMock(
choices=[AsyncMock(message=AsyncMock(content=expected_response_content))]
)
input_message = Message(
content={"prompt": prompt, "system_prompt": system_prompt}
)
response_message = await tool._execute(input_message)

assert response_message.content["response"] == expected_response_content
mock_create.assert_called_once()


@pytest.mark.asyncio
async def test_execute_failure_empty_prompt(tool):
"""Test _execute method of OpenAIChatTool raises ValueError with empty prompt."""
with pytest.raises(ValueError) as excinfo:
await tool._execute(Message(content={"prompt": ""}))
assert "Prompt cannot be empty." in str(excinfo.value)
