Skip to content

Commit

Permalink
Fix the issue of sequence action
Browse files Browse the repository at this point in the history
  • Loading branch information
yxjiang committed Mar 10, 2024
1 parent f482769 commit 5084855
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 19 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -162,4 +162,5 @@ cython_debug/

logs
.env
scripts
scripts
.vscode
22 changes: 11 additions & 11 deletions polymind/core/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@

class BaseAction(BaseModel, ABC):
"""BaseAction is the base class of the action.
An action is an object that can leverage tools (an LLM is considered a tool) to perform a specific task.
An action is a stateful object that can leverage tools (an LLM is considered a tool) to perform a specific task.
In most cases, an action is a logically unit of to fulfill an atomic task.
But sometimes, a complex atomic task can be divided into multiple sub-actions.
"""

action_name: str
tools: Dict[str, BaseTool]
tool: BaseTool

async def __call__(self, input: Message) -> Message:
"""Makes the instance callable, delegating to the execute method.
Expand Down Expand Up @@ -78,29 +78,29 @@ async def _execute(self, input: Message) -> Message:
Returns:
Message: The result of the composite action carried in a message.
"""
message = input
self._update_context()
next_action = self._get_next_action(input)
while next_action:
message = await next_action(input)
action = self._get_next_action(message)
while action:
message = await action(message)
self._update_context()
next_action = self._get_next_action(input)
action = self._get_next_action(message)
return message


class SequentialAction(CompositeAction):

actions: List[BaseAction] = Field(default_factory=list)

def __init__(
self, action_name: str, tools: Dict[str, BaseTool], actions: List[BaseAction]
):
super().__init__(action_name=action_name, tools=tools)
def __init__(self, action_name: str, tool: BaseTool, actions: List[BaseAction]):
super().__init__(action_name=action_name, tool=tool)
self.actions = actions

def _update_context(self) -> None:
if not bool(self.context.content):
self.context = Message(content={"idx": 0})
self.context.content["idx"] += 1
else:
self.context.content["idx"] += 1

def _get_next_action(self, input: Message) -> BaseAction:
if self.context.content["idx"] < len(self.actions):
Expand Down
32 changes: 25 additions & 7 deletions tests/polymind/core/test_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,15 @@
import pytest
from polymind.core.action import BaseAction, SequentialAction
from polymind.core.message import Message
from polymind.core.tool import BaseTool


class MockTool(BaseTool):
async def _execute(self, input: Message) -> Message:
content = input.content.copy()
tool_name = self.tool_name
content.setdefault("tools_executed", []).append(tool_name)
return Message(content=content)


class MockAction(BaseAction):
Expand All @@ -19,19 +28,28 @@ async def _execute(self, input: Message) -> Message:
class TestSequentialAction:
async def test_sequential_action_execution(self):
# Create mock actions with different names
action1 = MockAction(action_name="Action1", tools={})
action2 = MockAction(action_name="Action2", tools={})
actions = []
num_actions = 5
for i in range(num_actions):
action = MockAction(
action_name=f"Action{i}", tool=MockTool(tool_name=f"Tool{i}")
)
actions.append(action)

# Initialize SequentialAction with the mock actions
sequential_action = SequentialAction(
action_name="test_seq_action", tools={}, actions=[action1, action2]
action_name="test_seq_action",
tool=MockTool(tool_name="Primary"),
actions=actions,
)

input_message = Message(content={})
result_message = await sequential_action(input_message)

# # Check if both actions were executed in the correct order
# assert result_message.content["actions_executed"] == ["Action1", "Action2"]
# Check if both actions were executed in the correct order
assert result_message.content.get("actions_executed", []) == [
"Action{i}".format(i=i) for i in range(num_actions)
], result_message.content["actions_executed"]

# # Check if the context was updated correctly
# assert sequential_action.context.content["idx"] == 2
# Check if the context was updated correctly
assert sequential_action.context.content["idx"] == num_actions

0 comments on commit 5084855

Please sign in to comment.