-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
121 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,8 @@ | ||
# Expose the core classes | ||
from .core.tool import BaseTool | ||
from .core.message import Message | ||
from .core.task import BaseTask, CompositeTask, SequentialTask | ||
from .core.agent import Agent, ThoughtProcess | ||
|
||
# Expose the tools | ||
from .tools.oai_tools import OpenAIChatTool |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
""" | ||
This file contains the necessary tools of using OpenAI models. | ||
""" | ||
|
||
from pydantic import Field | ||
from polymind.core.tool import BaseTool | ||
from polymind.core.message import Message | ||
from openai import AsyncOpenAI | ||
import os | ||
from dotenv import load_dotenv | ||
|
||
|
||
class OpenAIChatTool(BaseTool): | ||
"""OpenAITool is a bridge to OpenAI APIs. | ||
The tool can be initialized with llm_name, system_prompt, max_tokens, and temperature. | ||
The input message of this tool should contain a "prompt", and optionally a "system_prompt". | ||
The "system_prompt" in the input message will override the default system_prompt. | ||
The tool will return a message with the response from the OpenAI chat API. | ||
""" | ||
|
||
class Config: | ||
arbitrary_types_allowed: bool = True # Allow arbitrary types | ||
|
||
tool_name: str = "open-ai-chat" | ||
client: AsyncOpenAI = Field(default=None) | ||
llm_name: str = Field(default="gpt-3.5-turbo") | ||
system_prompt: str = Field(default="You are a helpful AI assistant.") | ||
max_tokens: int = Field(default=1500) | ||
temperature: float = Field(default=0.7) | ||
|
||
def __init__(self, **kwargs): | ||
super().__init__(**kwargs) | ||
self.client = AsyncOpenAI(api_key=os.environ.get("OPENAI_API_KEY")) | ||
|
||
async def _execute(self, input: Message) -> Message: | ||
"""Execute the tool and return the result. | ||
The derived class must implement this method to define the behavior of the tool. | ||
Args: | ||
input (Message): The input to the tool carried in a message. | ||
Returns: | ||
Message: The result of the tool carried in a message. | ||
""" | ||
prompt = input.get("prompt", "") | ||
system_prompt = input.get("system_prompt", self.system_prompt) | ||
if not prompt: | ||
raise ValueError("Prompt cannot be empty.") | ||
|
||
response = await self.client.chat.completions.create( | ||
model=self.llm_name, | ||
messages=[ | ||
{"role": "system", "content": system_prompt}, | ||
{"role": "user", "content": prompt}, | ||
], | ||
) | ||
content = response.choices[0].message.content | ||
response_message = Message(content={"response": content}) | ||
return response_message |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
"""Test cases for OpenAIChatTool. | ||
Run the test with the following command: | ||
poetry run pytest tests/polymind/tools/test_oai_tools.py | ||
""" | ||
|
||
"""Test cases for OpenAIChatTool.""" | ||
|
||
import pytest | ||
from unittest.mock import AsyncMock, patch | ||
from polymind.core.message import Message | ||
from polymind.tools.oai_tools import OpenAIChatTool | ||
|
||
|
||
@pytest.fixture | ||
def mock_env_vars(monkeypatch): | ||
"""Fixture to mock environment variables.""" | ||
monkeypatch.setenv("OPENAI_API_KEY", "test_key") | ||
|
||
|
||
@pytest.fixture | ||
def tool(mock_env_vars): | ||
"""Fixture to create an instance of OpenAIChatTool with mocked environment variables.""" | ||
llm_name = "gpt-4-turbo" | ||
system_prompt = "You are an orchestrator" | ||
return OpenAIChatTool(llm_name=llm_name, system_prompt=system_prompt) | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_execute_success(tool): | ||
"""Test _execute method of OpenAIChatTool for successful API call.""" | ||
prompt = "How are you?" | ||
system_prompt = "You are a helpful AI assistant." | ||
expected_response_content = "I'm doing great, thanks for asking!" | ||
|
||
# Patch the specific instance of AsyncOpenAI used by our tool instance | ||
with patch.object( | ||
tool.client.chat.completions, "create", new_callable=AsyncMock | ||
) as mock_create: | ||
mock_create.return_value = AsyncMock( | ||
choices=[AsyncMock(message=AsyncMock(content=expected_response_content))] | ||
) | ||
input_message = Message( | ||
content={"prompt": prompt, "system_prompt": system_prompt} | ||
) | ||
response_message = await tool._execute(input_message) | ||
|
||
assert response_message.content["response"] == expected_response_content | ||
mock_create.assert_called_once() | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_execute_failure_empty_prompt(tool): | ||
"""Test _execute method of OpenAIChatTool raises ValueError with empty prompt.""" | ||
with pytest.raises(ValueError) as excinfo: | ||
await tool._execute(Message(content={"prompt": ""})) | ||
assert "Prompt cannot be empty." in str(excinfo.value) |