-
Notifications
You must be signed in to change notification settings - Fork 20
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Experimental] aact-based Based Agent (#221)
* add aact as a dependency * minimal demo example of running custom model * devcontainer setup and example * remove default_bad_process_model to allow using custom model entirely * improve the demo to show parallel execution * CI: update tests trigger from pull request target to pull request * fix mypy errors * adding stubs to pyproject.toml * poetry lock * install all extras in the devcontainer start script * add dev containers instruction * migration to uv * update mypy * Update index.mdx * update uv venv path in the devcontainer and contributor's guide * simple examples of using aact for multi-agent async communication * allowing agents' aact function to return None * import Self for 3.10 * Create readme.md * dockerfile * record node log * frequency -> interval * docker compose (it works) * use published images to speed up * add ci test with docker * use compose action github action * update docker compose file * update compose file path * use github-action-docker-compose-test-run * remove unused port binding in docker-compose * add quotes to docker compose command * test run * test run * write test script in tests.sh * use docker compose * test run * --rm * ./ -> . * test * change to arm64 * fix docker platform problem * change test os * fix some build bugs * fix runner dir * fix a test case for sample * update cli test to test_install * update test benchmark to improve coverage * remove unused and maintain structured output compatibility * fix evaluator bug * add a test script which contributors can run locally * bump the version to 0.1.1 * add langchain openai back * add langchain openai in uv lock * remove redundant cast * add test case * test base agent * more coverage for agent.py * add __init__ to sotopia.experimental * chore: Add experimental page and agents documentation * Agent Documentation
- Loading branch information
Showing
18 changed files
with
1,549 additions
and
520 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
{ | ||
"index": { | ||
"title": "Overview" | ||
}, | ||
"agents": { | ||
"title": "Agents" | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
import { Callout } from "nextra/components" | ||
|
||
<Callout type="warning"> | ||
This part of the documentation is for experimental features. The APIs and functionalities are subject to frequent change. | ||
</Callout> | ||
|
||
<Callout type="warning"> | ||
The Agent API implemented here conflicts with stable Agent API in Sotopia. | ||
</Callout> | ||
|
||
Agent is a concept in Sotopia to represent decision-making entities that can interact with each other in a social environment. Agents can be human participants, AI models, or other entities. | ||
No matter which type of agent, they have the same interface to interact with the environment: | ||
the input and output are of derived types of `aact.messages.DataModel`. | ||
|
||
### Creating your own agents | ||
To create your own agents, you need to subclass the `BaseAgent` class | ||
and implement the asynchronous `aact` method. | ||
The `aact` method takes an `Observation` object as input and returns an `AgentAction` object as output. Here is an example of a simple agent that always says "Hello, world!": | ||
|
||
```python | ||
from aact import NodeFactory | ||
from aact.messages import Text | ||
from sotopia.experimental import BaseAgent | ||
|
||
@NodeFactory.register("simple_echo_agent") # Register the agent so that it can be used in the dataflow | ||
class SimpleEchoAgent(BaseAgent[Text, Text]): | ||
def __init__(self, input_channel: str, output_channel: str, redis_url: str) -> None: | ||
super().__init__( # call the constructor of the base class | ||
input_channel_types=[(input_channel, Text)], | ||
output_channel_types=[(output_channel, Text)], | ||
) | ||
|
||
async def aact(self, observation: Text) -> Text: # major agent reactive function | ||
return Text(text=f"Hello, {observation.text}!") | ||
``` | ||
|
||
Let me break this down for you: | ||
1. `NodeFactory` is a decorator that registers the agent so that it can be used in the dataflow. Dataflow is a concept in `aact` that defines how `nodes` are interacting with each other. | ||
2. `channel` is a concept in `redis` pubsub and `aact`. A node can send messages to many channels, and receive messages many channels as well. To subclass `BaseAgent`, you will need to feed two lists of channel-message type pairs to `input_channel_types` and `output_channel_types` respectively. | ||
3. Inherit the `BaseAgent` class and specify the input and output channel types in the constructor. | ||
4. Implement the `aact` method that takes an `Observation` object as input and returns an `AgentAction` object as output. In this case, the agent always says "Hello, ..." | ||
|
||
For a running example, try out `examples/experimental/tick_and_echo_agents`. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
import { Callout } from "nextra/components" | ||
|
||
<Callout type="warning"> | ||
This part of the documentation is for experimental features. The APIs and functionalities are subject to frequent change. | ||
</Callout> | ||
|
||
The experimental APIs of Sotopia are intended for quickly prototyping and experimenting with new functionalities, | ||
without breaking the existing stable APIs. But we will still maintain the quality of the code for these features. | ||
Feel free to raise an issue if you find any bugs or wants more features in the experimental APIs. | ||
|
||
# Experimetal APIs | ||
The experimental APIs are in different states: | ||
|
||
- *scheduled*: the APIs will be merged into next minor releases. | ||
- *implemented*: the APIs are implemented and can be used, which might be merged into the stable APIs in the next few minor releases. | ||
- *planned*: the APIs are planned and will be implemented in the future. | ||
- *idealized*: the APIs are idealized and might be implemented in the future. | ||
|
||
Here are the experimental APIs: | ||
- [Agents](/experimental/agents) (*implemented*): aact-based asynchronous agents that don't follow OpenAI Gym's turn-based formulation. | ||
- Engines (*planned*): aact-based asynchronous environment engines. This would include | ||
- [Orchestrator](https://github.com/sotopia-lab/sotopia/issues/231): an engine base class for engines that dictates the orders and turns of the agents. | ||
- [Evaluator](https://github.com/sotopia-lab/sotopia/issues/232): an engine base class for engines that evaluates the agents' performance. | ||
- API Engine: an engine that interacts with REST APIs. | ||
- Generation APIs (*planned*): experimental generation APIs |
169 changes: 169 additions & 0 deletions
169
examples/experimental/group_discussion_agents/group_discussion_agents.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,169 @@ | ||
from typing import AsyncIterator | ||
from aact import Message, NodeFactory | ||
from aact.messages import Text, Tick, DataModel, DataModelFactory | ||
from sotopia.agents.llm_agent import ainput | ||
from sotopia.experimental.agents import BaseAgent | ||
|
||
from sotopia.generation_utils import agenerate | ||
from sotopia.generation_utils.generate import StrOutputParser | ||
from sotopia.messages import ActionType | ||
|
||
from pydantic import Field | ||
|
||
|
||
@DataModelFactory.register("agent_action") | ||
class AgentAction(DataModel): | ||
agent_name: str = Field(description="the name of the agent") | ||
action_type: ActionType = Field( | ||
description="whether to speak at this turn or choose to not do anything" | ||
) | ||
argument: str = Field( | ||
description="the utterance if choose to speak, the expression or gesture if choose non-verbal communication, or the physical action if choose action" | ||
) | ||
|
||
def to_natural_language(self) -> str: | ||
match self.action_type: | ||
case "none": | ||
return "did nothing" | ||
case "speak": | ||
return f'said: "{self.argument}"' | ||
case "non-verbal communication": | ||
return f"[{self.action_type}] {self.argument}" | ||
case "action": | ||
return f"[{self.action_type}] {self.argument}" | ||
case "leave": | ||
return "left the conversation" | ||
|
||
|
||
def _format_message_history(message_history: list[tuple[str, str]]) -> str: | ||
return "\n".join( | ||
(f"{speaker} said {message}") for speaker, message in message_history | ||
) | ||
|
||
|
||
@NodeFactory.register("llm_agent") | ||
class LLMAgent(BaseAgent[AgentAction | Tick, AgentAction]): | ||
def __init__( | ||
self, | ||
input_text_channels: list[str], | ||
input_tick_channel: str, | ||
output_channel: str, | ||
query_interval: int, | ||
agent_name: str, | ||
goal: str, | ||
model_name: str, | ||
redis_url: str, | ||
): | ||
super().__init__( | ||
[ | ||
(input_text_channel, AgentAction) | ||
for input_text_channel in input_text_channels | ||
] | ||
+ [ | ||
(input_tick_channel, Tick), | ||
], | ||
[(output_channel, AgentAction)], | ||
redis_url, | ||
) | ||
self.output_channel = output_channel | ||
self.query_interval = query_interval | ||
self.count_ticks = 0 | ||
self.message_history: list[tuple[str, str]] = [] | ||
self.name = agent_name | ||
self.model_name = model_name | ||
self.goal = goal | ||
|
||
async def send(self, message: AgentAction) -> None: | ||
if message.action_type == "speak": | ||
await self.r.publish( | ||
self.output_channel, | ||
Message[AgentAction](data=message).model_dump_json(), | ||
) | ||
|
||
async def aact(self, message: AgentAction | Tick) -> AgentAction: | ||
match message: | ||
case Tick(): | ||
self.count_ticks += 1 | ||
if self.count_ticks % self.query_interval == 0: | ||
agent_action: str = await agenerate( | ||
model_name=self.model_name, | ||
template="Imagine that you are a friend of the other persons. Here is the " | ||
"conversation between you and them.\n" | ||
"You are {agent_name} in the conversation.\n" | ||
"{message_history}\n" | ||
"and you plan to {goal}.\n" | ||
"You can choose to interrupt the other person " | ||
"by saying something or not to interrupt by outputting notiong. What would you say? " | ||
"Please only output a sentence or not outputting anything." | ||
"{format_instructions}", | ||
input_values={ | ||
"message_history": _format_message_history( | ||
self.message_history | ||
), | ||
"goal": self.goal, | ||
"agent_name": self.name, | ||
}, | ||
temperature=0.7, | ||
output_parser=StrOutputParser(), | ||
) | ||
if agent_action != "none" and agent_action != "": | ||
self.message_history.append((self.name, agent_action)) | ||
return AgentAction( | ||
agent_name=self.name, | ||
action_type="speak", | ||
argument=agent_action, | ||
) | ||
else: | ||
return AgentAction( | ||
agent_name=self.name, action_type="none", argument="" | ||
) | ||
else: | ||
return AgentAction( | ||
agent_name=self.name, action_type="none", argument="" | ||
) | ||
case AgentAction( | ||
agent_name=agent_name, action_type=action_type, argument=text | ||
): | ||
if action_type == "speak": | ||
self.message_history.append((agent_name, text)) | ||
return AgentAction( | ||
agent_name=self.name, action_type="none", argument="" | ||
) | ||
case _: | ||
raise ValueError(f"Unexpected message type: {type(message)}") | ||
|
||
|
||
@NodeFactory.register("input_node") | ||
class InputNode(BaseAgent[AgentAction, AgentAction]): | ||
def __init__( | ||
self, | ||
input_channel: str, | ||
output_channel: str, | ||
agent_name: str, | ||
redis_url: str = "redis://localhost:6379/0", | ||
): | ||
super().__init__( | ||
input_channel_types=[(input_channel, AgentAction)], | ||
output_channel_types=[(output_channel, AgentAction)], | ||
redis_url=redis_url, | ||
) | ||
self.input_channel = input_channel | ||
self.agent_name = agent_name | ||
|
||
async def event_handler( | ||
self, channel: str, message: Message[AgentAction] | ||
) -> AsyncIterator[tuple[str, Message[AgentAction]]]: | ||
if channel == self.input_channel: | ||
print(f"Received message: {message}") | ||
else: | ||
raise ValueError(f"Unexpected channel: {channel}") | ||
yield self.output_channel, Text(text=message.data.argument) | ||
|
||
async def _task_scheduler(self) -> None: | ||
while not self.shutdown_event.is_set(): | ||
text_input = await ainput() | ||
await self.send( | ||
AgentAction( | ||
agent_name=self.agent_name, action_type="speak", argument=text_input | ||
) | ||
) |
57 changes: 57 additions & 0 deletions
57
examples/experimental/group_discussion_agents/group_discussion_agents.toml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
redis_url = "redis://localhost:6379/0" | ||
extra_modules = ["examples.experimental.group_discussion_agents.group_discussion_agents"] | ||
|
||
[[nodes]] | ||
node_name = "Jack" | ||
node_class = "llm_agent" | ||
|
||
[nodes.node_args] | ||
query_interval = 5 | ||
output_channel = "Jack" | ||
input_text_channels = ["Jane", "John"] | ||
input_tick_channel = "tick/secs/1" | ||
goal = "want to play pocker with your friends tonight" | ||
model_name = "gpt-4o-mini" | ||
agent_name = "Jack" | ||
|
||
[[nodes]] | ||
node_name = "Jane" | ||
node_class = "llm_agent" | ||
|
||
[nodes.node_args] | ||
query_interval = 7 | ||
output_channel = "Jane" | ||
input_text_channels = ["Jack", "John"] | ||
input_tick_channel = "tick/secs/1" | ||
goal = "want to play soccer with your friends tonight" | ||
model_name = "gpt-4o-mini" | ||
agent_name = "Jane" | ||
|
||
[[nodes]] | ||
node_name = "John" | ||
node_class = "llm_agent" | ||
|
||
[nodes.node_args] | ||
query_interval = 10 | ||
output_channel = "John" | ||
input_text_channels = ["Jack", "Jane"] | ||
input_tick_channel = "tick/secs/1" | ||
goal = "want to go to concert with your friends tonight" | ||
model_name = "gpt-4o-mini" | ||
agent_name = "John" | ||
|
||
[[nodes]] | ||
node_name = "record" | ||
node_class = "record" | ||
|
||
[nodes.node_args] | ||
jsonl_file_path = "log.jsonl" | ||
|
||
[nodes.node_args.record_channel_types] | ||
"Jack" = "agent_action" | ||
"Jane" = "agent_action" | ||
"John" = "agent_action" | ||
|
||
[[nodes]] | ||
node_name = "tick" | ||
node_class = "tick" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
To run this example, please use aact to launch. | ||
|
||
```bash | ||
aact run-dataflow examples/experimental/group_discussion_agents/group_discussion_agents.toml | ||
``` |
Oops, something went wrong.