Skip to content

Commit

Permalink
Merge pull request #64 from small-thinking/enhance-tool
Browse files Browse the repository at this point in the history
Add open function conversion for tools
  • Loading branch information
yxjiang authored Apr 2, 2024
2 parents fd84561 + 9698ea0 commit 90e91e6
Show file tree
Hide file tree
Showing 3 changed files with 190 additions and 2 deletions.
53 changes: 52 additions & 1 deletion polymind/core/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import re
from abc import ABC, abstractmethod
from collections.abc import Mapping, Sequence
from typing import Any, Dict, List, get_origin
from typing import Any, Dict, List, Union, get_origin

from dotenv import load_dotenv
from pydantic import BaseModel, Field, field_validator
Expand All @@ -24,6 +24,30 @@ class Param(BaseModel):
description: str = Field(description="A description of the parameter.")
example: str = Field(default="", description="An example value for the parameter.")

def to_open_function_format(self) -> Dict[str, Union[str, bool, Dict[str, Any]]]:
"""Convert the parameter to the Open Function format."""
# Remove the element type if is a list or dict, replace int to integer
type_str = self.type
if type_str.startswith("List[") or type_str.startswith("Dict["):
type_str = type_str.split("[")[0]
elif type_str == "int":
type_str = "integer"
elif type_str == "ndarray" or type_str == "np.ndarray" or type_str == "numpy.ndarray":
type_str = "object"
elif type_str == "pandas.DataFrame" or type_str == "pd.DataFrame" or type_str == "DataFrame":
type_str = "object"
elif type_str == "str":
type_str = "string"
property_dict = {
"type": type_str.lower(),
"description": self.description,
}

if self.example:
property_dict["example"] = str(self.example)

return {self.name: property_dict}

def to_json_obj(self) -> Dict[str, str]:
return {
"name": self.name,
Expand Down Expand Up @@ -150,6 +174,33 @@ def input_spec(self) -> List[Param]:
"""
pass

def to_open_function_format(self) -> Dict[str, Union[str, Dict[str, Any]]]:
"""Return the specification of the tool in the format expected by the open function."""
input_properties = {}
for param in self.input_spec():
input_properties.update(param.to_open_function_format())

output_properties = {}
for param in self.output_spec():
output_properties.update(param.to_open_function_format())

return {
"type": "function",
"function": {
"name": self.tool_name,
"description": self.descriptions[0], # Use the first description as the main description
"parameters": {
"type": "object",
"properties": input_properties,
"required": [param.name for param in self.input_spec() if param.required],
},
"responses": {
"type": "object",
"properties": output_properties,
},
},
}

def _validate_input_message(self, input_message: Message) -> None:
"""Validate the input message against the input spec.
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "polymind"
version = "0.0.35" # Update this version before publishing to PyPI
version = "0.0.36" # Update this version before publishing to PyPI
description = "PolyMind is a customizable collaborative multi-agent framework for collective intelligence and distributed problem solving."
authors = ["TechTao"]
license = "MIT License"
Expand Down
137 changes: 137 additions & 0 deletions tests/polymind/core/test_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,84 @@ def test_param_with_description_and_default_example(self):
assert param.description == description, "Param description should match the input description"
assert param.example == "", "Param example should use the default empty string if not specified"

@pytest.mark.parametrize(
"param_args, expected_output",
[
(
dict(
name="query",
type="str",
description="The query to search for",
example="hello world",
),
{"query": {"type": "string", "description": "The query to search for", "example": "hello world"}},
),
(
dict(
name="options",
type="List[Dict[str, int]]",
description="A list of option dictionaries",
example="[{'option1': 1}, {'option2': 2}]",
),
{
"options": {
"type": "list",
"description": "A list of option dictionaries",
"example": "[{'option1': 1}, {'option2': 2}]",
}
},
),
(
dict(
name="flag",
type="bool",
description="A boolean flag",
),
{"flag": {"type": "bool", "description": "A boolean flag"}},
),
],
)
def test_to_open_function_format(self, param_args, expected_output):
"""Test that to_open_function_format works correctly for different parameter types and examples."""
param = Param(**param_args)
assert param.to_open_function_format() == expected_output

@pytest.mark.parametrize(
"param_list, expected_output",
[
(
[
Param(
name="query",
type="str",
description="The query to search for",
example="hello world",
),
Param(
name="page",
type="int",
description="The page number to return",
example="1",
),
],
{
"query": {
"type": "string",
"description": "The query to search for",
"example": "hello world",
},
"page": {"type": "integer", "description": "The page number to return", "example": "1"},
},
),
],
)
def test_to_open_function_format_multiple_params(self, param_list, expected_output):
"""Test that multiple Param instances can be combined correctly."""
combined_output = {}
for param in param_list:
combined_output.update(param.to_open_function_format())
assert combined_output == expected_output


# Test tools that fails the validation
class NoNameTool(BaseTool):
Expand Down Expand Up @@ -236,6 +314,26 @@ async def _execute(self, input: Message) -> Message:
)


class ExampleTool(BaseTool):
tool_name: str = "example_tool"
descriptions: List[str] = ["Performs an example task", "Useful for testing", "Demonstrates open function format"]

def input_spec(self) -> List[Param]:
return [
Param(name="input1", type="str", required=True, description="First input parameter", example="example1"),
Param(name="input2", type="int", required=False, description="Second input parameter", example="2"),
]

def output_spec(self) -> List[Param]:
return [
Param(name="output", type="str", required=True, description="Output parameter", example="result"),
]

async def _execute(self, input: Message) -> Message:
message = Message(content={"output": "result"})
return message


@pytest.fixture(autouse=True)
def load_env_vars():
# Setup: Define environment variable before each test
Expand Down Expand Up @@ -303,3 +401,42 @@ def test_get_spec(self):
spec_json_obj = json.loads(spec_str)
expected_json_obj = json.loads(expected_json_str)
assert spec_json_obj == expected_json_obj, "The spec string should match the expected JSON string"

@pytest.mark.parametrize(
"tool_instance, expected_spec",
[
(
ExampleTool(),
{
"type": "function",
"function": {
"name": "example_tool",
"description": "Performs an example task",
"parameters": {
"type": "object",
"properties": {
"input1": {
"type": "string",
"example": "example1",
"description": "First input parameter",
},
"input2": {"type": "integer", "example": "2", "description": "Second input parameter"},
},
"required": ["input1"],
},
"responses": {
"type": "object",
"properties": {
"output": {"type": "string", "example": "result", "description": "Output parameter"}
},
},
},
},
),
],
)
def test_to_open_function_format(self, tool_instance, expected_spec):
spec = tool_instance.to_open_function_format()
assert (
spec == expected_spec
), "The generated open function format specification should match the expected specification"

0 comments on commit 90e91e6

Please sign in to comment.