From 93e2a888cc5ce033791690cbeae2eab0cf6fa80a Mon Sep 17 00:00:00 2001 From: Yx Jiang <2237303+yxjiang@users.noreply.github.com> Date: Fri, 15 Mar 2024 23:31:05 -0700 Subject: [PATCH] Fix tasks tests --- polymind/core/task.py | 8 ++--- tests/polymind/core/test_task.py | 61 +++++++++++++++++++------------- 2 files changed, 40 insertions(+), 29 deletions(-) diff --git a/polymind/core/task.py b/polymind/core/task.py index 0b5f5cf..25557ca 100644 --- a/polymind/core/task.py +++ b/polymind/core/task.py @@ -16,12 +16,12 @@ class BaseTask(BaseModel, ABC): But sometimes, a complex task can be divided into multiple sub-tasks. """ - task_name: str - tool: BaseTool + task_name: str = Field(description="The name of the task.") + tool: BaseTool = Field(description="The tool to use for the task.") - def __init__(self, **kwargs): - super().__init__(**kwargs) + def __init__(self, task_name: str, tool: BaseTool, **kwargs): load_dotenv(override=True) + super().__init__(task_name=task_name, tool=tool, **kwargs) async def __call__(self, input: Message) -> Message: """Makes the instance callable, delegating to the execute method. diff --git a/tests/polymind/core/test_task.py b/tests/polymind/core/test_task.py index 1bec0b5..1a4316c 100644 --- a/tests/polymind/core/test_task.py +++ b/tests/polymind/core/test_task.py @@ -23,42 +23,53 @@ def load_env_vars(): class MockTool(BaseTool): + """The role of MockTool is to reverse the query and append the tool_name to the "tools_executed" list.""" def input_spec(self) -> list[Param]: - return [Param(name="query", type="str", description="The query to reverse")] + return [ + Param( + name="query", type="str", example="", description="The query to reverse" + ) + ] def output_spec(self) -> list[Param]: - return [Param(name="result", type="str", description="The reversed query")] + return [ + Param( + name="result", type="str", example="", description="The reversed query" + ) + ] async def _execute(self, input: Message) -> Message: # Get the environment variable or use a default value - some_variable = os.getenv("SOME_TOOL_VARIABLE", "default_value") + some_variable = os.getenv("SOME_TOOL_VARIABLE", "") + if some_variable == "": + raise ValueError("Environment variable not loaded correctly") # Ensure the content dictionary initializes "tools_executed" and "env_tool" as lists if they don't exist content = input.content.copy() tool_name = self.tool_name # Append the tool_name to the "tools_executed" list content.setdefault("tools_executed", []).append(tool_name) # Ensure "env_tool" is initialized as a list and append the environment variable - if "env_tool" not in content: - content["env_tool"] = [] - content["env_tool"].append(some_variable) + content.setdefault("env_tool", []).append(some_variable) return Message(content=content) class MockTask(BaseTask): + """The role of MockTask is to append the task_name to the "tasks_executed" + list and append the environment variable to the "env_task" list. + """ + async def _execute(self, input: Message) -> Message: + """Use the corresponding tool to reverse the query.""" # Get the environment variable or use a default value - some_variable = os.getenv("SOME_TASK_VARIABLE", "default_value") + some_variable = os.getenv("SOME_TASK_VARIABLE", "") + if some_variable == "": + raise ValueError("Environment variable not loaded correctly") # Ensure the content dictionary initializes "tasks_executed" and "env_task" as lists if they don't exist - content = input.content.copy() - task_name = self.task_name - # Append the task_name to the "tasks_executed" list - content.setdefault("tasks_executed", []).append(task_name) - # Ensure "env_task" is initialized as a list and append the environment variable - if "env_task" not in content: - content["env_task"] = [] - content["env_task"].append(some_variable) - return Message(content=content) + result = await self.tool(input) + result.content.setdefault("tasks_executed", []).append(self.task_name) + result.content.setdefault("env_task", []).append(some_variable) + return result @pytest.mark.asyncio @@ -75,15 +86,15 @@ async def test_sequential_task_execution(self): input_message = Message(content={}) result_message = await sequential_task(input_message) - assert result_message.content["tasks_executed"] == [ - f"Task{i}" for i in range(num_tasks) - ], "Tasks executed in incorrect order" - # assert all( - # env_value == "test_tool" for env_value in result_message.content["env_tool"] - # ), "Tool environment variable not loaded correctly" - # assert all( - # env_value == "test_task" for env_value in result_message.content["env_task"] - # ), "Task environment variable not loaded correctly" + assert result_message.content["tools_executed"] == [ + f"Tool{i}" for i in range(num_tasks) + ], "Tools executed in incorrect order" + assert all( + env_value == "test_tool" for env_value in result_message.content["env_tool"] + ), "Tool environment variable not loaded correctly" + assert all( + env_value == "test_task" for env_value in result_message.content["env_task"] + ), "Task environment variable not loaded correctly" assert ( sequential_task.context.content["idx"] == num_tasks ), "Context index not updated correctly"