Skip to content

Commit

Permalink
Fix tasks tests
Browse files Browse the repository at this point in the history
  • Loading branch information
yxjiang committed Mar 16, 2024
1 parent f2292e3 commit 93e2a88
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 29 deletions.
8 changes: 4 additions & 4 deletions polymind/core/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@ class BaseTask(BaseModel, ABC):
But sometimes, a complex task can be divided into multiple sub-tasks.
"""

task_name: str
tool: BaseTool
task_name: str = Field(description="The name of the task.")
tool: BaseTool = Field(description="The tool to use for the task.")

def __init__(self, **kwargs):
super().__init__(**kwargs)
def __init__(self, task_name: str, tool: BaseTool, **kwargs):
load_dotenv(override=True)
super().__init__(task_name=task_name, tool=tool, **kwargs)

async def __call__(self, input: Message) -> Message:
"""Makes the instance callable, delegating to the execute method.
Expand Down
61 changes: 36 additions & 25 deletions tests/polymind/core/test_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,42 +23,53 @@ def load_env_vars():


class MockTool(BaseTool):
"""The role of MockTool is to reverse the query and append the tool_name to the "tools_executed" list."""

def input_spec(self) -> list[Param]:
return [Param(name="query", type="str", description="The query to reverse")]
return [
Param(
name="query", type="str", example="", description="The query to reverse"
)
]

def output_spec(self) -> list[Param]:
return [Param(name="result", type="str", description="The reversed query")]
return [
Param(
name="result", type="str", example="", description="The reversed query"
)
]

async def _execute(self, input: Message) -> Message:
# Get the environment variable or use a default value
some_variable = os.getenv("SOME_TOOL_VARIABLE", "default_value")
some_variable = os.getenv("SOME_TOOL_VARIABLE", "")
if some_variable == "":
raise ValueError("Environment variable not loaded correctly")
# Ensure the content dictionary initializes "tools_executed" and "env_tool" as lists if they don't exist
content = input.content.copy()
tool_name = self.tool_name
# Append the tool_name to the "tools_executed" list
content.setdefault("tools_executed", []).append(tool_name)
# Ensure "env_tool" is initialized as a list and append the environment variable
if "env_tool" not in content:
content["env_tool"] = []
content["env_tool"].append(some_variable)
content.setdefault("env_tool", []).append(some_variable)
return Message(content=content)


class MockTask(BaseTask):
"""The role of MockTask is to append the task_name to the "tasks_executed"
list and append the environment variable to the "env_task" list.
"""

async def _execute(self, input: Message) -> Message:
"""Use the corresponding tool to reverse the query."""
# Get the environment variable or use a default value
some_variable = os.getenv("SOME_TASK_VARIABLE", "default_value")
some_variable = os.getenv("SOME_TASK_VARIABLE", "")
if some_variable == "":
raise ValueError("Environment variable not loaded correctly")
# Ensure the content dictionary initializes "tasks_executed" and "env_task" as lists if they don't exist
content = input.content.copy()
task_name = self.task_name
# Append the task_name to the "tasks_executed" list
content.setdefault("tasks_executed", []).append(task_name)
# Ensure "env_task" is initialized as a list and append the environment variable
if "env_task" not in content:
content["env_task"] = []
content["env_task"].append(some_variable)
return Message(content=content)
result = await self.tool(input)
result.content.setdefault("tasks_executed", []).append(self.task_name)
result.content.setdefault("env_task", []).append(some_variable)
return result


@pytest.mark.asyncio
Expand All @@ -75,15 +86,15 @@ async def test_sequential_task_execution(self):
input_message = Message(content={})
result_message = await sequential_task(input_message)

assert result_message.content["tasks_executed"] == [
f"Task{i}" for i in range(num_tasks)
], "Tasks executed in incorrect order"
# assert all(
# env_value == "test_tool" for env_value in result_message.content["env_tool"]
# ), "Tool environment variable not loaded correctly"
# assert all(
# env_value == "test_task" for env_value in result_message.content["env_task"]
# ), "Task environment variable not loaded correctly"
assert result_message.content["tools_executed"] == [
f"Tool{i}" for i in range(num_tasks)
], "Tools executed in incorrect order"
assert all(
env_value == "test_tool" for env_value in result_message.content["env_tool"]
), "Tool environment variable not loaded correctly"
assert all(
env_value == "test_task" for env_value in result_message.content["env_task"]
), "Task environment variable not loaded correctly"
assert (
sequential_task.context.content["idx"] == num_tasks
), "Context index not updated correctly"

0 comments on commit 93e2a88

Please sign in to comment.