Skip to content

Commit

Permalink
AutoGPT: Move all the Agent's prompt building code into OneShotAgentPromptStrategy
Browse files Browse the repository at this point in the history
  • Loading branch information
Pwuts committed Sep 27, 2023
1 parent 6336f8f commit e37a7a4
Show file tree
Hide file tree
Showing 25 changed files with 965 additions and 755 deletions.
264 changes: 81 additions & 183 deletions autogpts/autogpt/autogpt/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,21 @@
import logging
import time
from datetime import datetime
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
from autogpt.config import Config
from autogpt.memory.vector import VectorMemory
from autogpt.models.command_registry import CommandRegistry

from autogpt.config.ai_config import AIConfig
from autogpt.config import AIConfig
from autogpt.core.configuration import Configurable
from autogpt.core.prompting import ChatPrompt
from autogpt.core.resource.model_providers import (
ChatMessage,
ChatModelProvider,
ChatModelResponse,
)
from autogpt.core.utils.json_schema import JSONSchema
from autogpt.json_utils.utilities import extract_dict_from_response
from autogpt.llm.api_manager import ApiManager
from autogpt.logs.log_cycle import (
CURRENT_CONTEXT_FILE_NAME,
Expand All @@ -42,12 +40,11 @@
from .features.context import ContextMixin
from .features.watchdog import WatchdogMixin
from .features.workspace import WorkspaceMixin
from .utils.exceptions import (
AgentException,
CommandExecutionError,
InvalidAgentResponseError,
UnknownCommandError,
from .prompt_strategies.one_shot import (
OneShotAgentPromptConfiguration,
OneShotAgentPromptStrategy,
)
from .utils.exceptions import AgentException, CommandExecutionError, UnknownCommandError

logger = logging.getLogger(__name__)

Expand All @@ -58,6 +55,7 @@ class AgentConfiguration(BaseAgentConfiguration):

class AgentSettings(BaseAgentSettings):
    # Agent-specific runtime configuration (extends the base agent settings).
    config: AgentConfiguration
    # Configuration consumed by OneShotAgentPromptStrategy when building prompts.
    prompt_config: OneShotAgentPromptConfiguration


class Agent(
Expand All @@ -69,11 +67,12 @@ class Agent(
):
"""AutoGPT's primary Agent; uses one-shot prompting."""

default_settings = AgentSettings(
default_settings: AgentSettings = AgentSettings(
name="Agent",
description=__doc__,
ai_config=AIConfig(ai_name="AutoGPT"),
config=AgentConfiguration(),
prompt_config=OneShotAgentPromptStrategy.default_configuration,
history=BaseAgent.default_settings.history,
)

Expand All @@ -85,9 +84,14 @@ def __init__(
memory: VectorMemory,
legacy_config: Config,
):
prompt_strategy = OneShotAgentPromptStrategy(
configuration=settings.prompt_config,
logger=logger,
)
super().__init__(
settings=settings,
llm_provider=llm_provider,
prompt_strategy=prompt_strategy,
command_registry=command_registry,
legacy_config=legacy_config,
)
Expand All @@ -101,12 +105,15 @@ def __init__(
self.log_cycle_handler = LogCycleHandler()
"""LogCycleHandler for structured debug logging."""

def construct_base_prompt(self, *args, **kwargs) -> ChatPrompt:
if kwargs.get("prepend_messages") is None:
kwargs["prepend_messages"] = []

def build_prompt(
self,
*args,
extra_messages: list[ChatMessage] = [],
include_os_info: Optional[bool] = None,
**kwargs,
) -> ChatPrompt:
# Clock
kwargs["prepend_messages"].append(
extra_messages.append(
ChatMessage.system(f"The current time and date is {time.strftime('%c')}"),
)

Expand All @@ -132,12 +139,17 @@ def construct_base_prompt(self, *args, **kwargs) -> ChatPrompt:
),
)
logger.debug(budget_msg)
extra_messages.append(budget_msg)

if kwargs.get("append_messages") is None:
kwargs["append_messages"] = []
kwargs["append_messages"].append(budget_msg)
if include_os_info is None:
include_os_info = self.legacy_config.execute_local_commands

return super().construct_base_prompt(*args, **kwargs)
return super().build_prompt(
*args,
extra_messages=extra_messages,
include_os_info=include_os_info,
**kwargs,
)

def on_before_think(self, *args, **kwargs) -> ChatPrompt:
prompt = super().on_before_think(*args, **kwargs)
Expand All @@ -152,6 +164,40 @@ def on_before_think(self, *args, **kwargs) -> ChatPrompt:
)
return prompt

def parse_and_process_response(
    self, llm_response: ChatModelResponse, *args, **kwargs
) -> Agent.ThoughtProcessOutput:
    """Turn the raw LLM reply into a (command_name, arguments, thoughts) triple.

    Runs post-planning plugins over the reply content, delegates parsing to
    the prompt strategy, logs the parsed reply for this cycle, and records
    the chosen action in the event history.
    """
    response = llm_response.response

    # Let every capable plugin rewrite the reply content before parsing.
    for plugin in self.config.plugins:
        if plugin.can_handle_post_planning():
            response["content"] = plugin.post_planning(response.get("content", ""))

    # The prompt strategy owns the response format, so it does the parsing.
    parsed = self.prompt_strategy.parse_response_content(response)
    command_name, arguments, assistant_reply_dict = parsed

    self.log_cycle_handler.log_cycle(
        self.ai_config.ai_name,
        self.created_at,
        self.config.cycle_count,
        assistant_reply_dict,
        NEXT_ACTION_FILE_NAME,
    )

    self.event_history.register_action(
        Action(
            name=command_name,
            args=arguments,
            reasoning=assistant_reply_dict["thoughts"]["reasoning"],
        )
    )

    return command_name, arguments, assistant_reply_dict

async def execute(
self,
command_name: str,
Expand Down Expand Up @@ -201,10 +247,7 @@ async def execute(
result = ActionErrorResult(reason=e.message, error=e)

result_tlength = self.llm_provider.count_tokens(str(result), self.llm.name)
history_tlength = self.llm_provider.count_tokens(
self.event_history.fmt_paragraph(), self.llm.name
)
if result_tlength + history_tlength > self.send_token_limit:
if result_tlength > self.send_token_limit // 3:
result = ActionErrorResult(
reason=f"Command {command_name} returned too much output. "
"Do not execute this command again with the same arguments."
Expand All @@ -223,159 +266,10 @@ async def execute(

return result

def parse_and_process_response(
    self, llm_response: ChatModelResponse, *args, **kwargs
) -> Agent.ThoughtProcessOutput:
    """Validate the raw LLM reply and extract the next command to run.

    Applies post-planning plugins, parses the reply into a dict, validates
    it against RESPONSE_SCHEMA, extracts the command, logs the cycle, and
    records the action in the event history.

    Raises:
        InvalidAgentResponseError: if the reply has no text content or fails
            schema validation.
    """
    raw = llm_response.response
    if "content" not in raw:
        raise InvalidAgentResponseError("Assistant response has no text content")

    # Let every capable plugin rewrite the reply text before parsing.
    content = raw["content"]
    for plugin in self.config.plugins:
        if plugin.can_handle_post_planning():
            content = plugin.post_planning(content)

    assistant_reply_dict = extract_dict_from_response(content)

    _, errors = RESPONSE_SCHEMA.validate_object(assistant_reply_dict, logger)
    if errors:
        details = ";\n ".join(str(e) for e in errors)
        raise InvalidAgentResponseError(
            "Validation of response failed:\n " + details
        )

    # Get command name and arguments
    command_name, arguments = extract_command(
        assistant_reply_dict, llm_response, self.config.use_functions_api
    )

    self.log_cycle_handler.log_cycle(
        self.ai_config.ai_name,
        self.created_at,
        self.config.cycle_count,
        assistant_reply_dict,
        NEXT_ACTION_FILE_NAME,
    )

    self.event_history.register_action(
        Action(
            name=command_name,
            args=arguments,
            reasoning=assistant_reply_dict["thoughts"]["reasoning"],
        )
    )

    return command_name, arguments, assistant_reply_dict


# Schema the assistant's JSON reply must conform to: a required "thoughts"
# object (text, reasoning, plan, criticism, speak — all required strings)
# plus a required "command" object naming the command and its args.
RESPONSE_SCHEMA = JSONSchema(
    type=JSONSchema.Type.OBJECT,
    properties={
        "thoughts": JSONSchema(
            type=JSONSchema.Type.OBJECT,
            required=True,
            properties={
                "text": JSONSchema(
                    description="thoughts",
                    type=JSONSchema.Type.STRING,
                    required=True,
                ),
                "reasoning": JSONSchema(
                    type=JSONSchema.Type.STRING,
                    required=True,
                ),
                "plan": JSONSchema(
                    description="- short bulleted\n- list that conveys\n- long-term plan",
                    type=JSONSchema.Type.STRING,
                    required=True,
                ),
                "criticism": JSONSchema(
                    description="constructive self-criticism",
                    type=JSONSchema.Type.STRING,
                    required=True,
                ),
                "speak": JSONSchema(
                    description="thoughts summary to say to user",
                    type=JSONSchema.Type.STRING,
                    required=True,
                ),
            },
        ),
        "command": JSONSchema(
            type=JSONSchema.Type.OBJECT,
            required=True,
            properties={
                "name": JSONSchema(
                    type=JSONSchema.Type.STRING,
                    required=True,
                ),
                "args": JSONSchema(
                    type=JSONSchema.Type.OBJECT,
                    required=True,
                ),
            },
        ),
    },
)


def extract_command(
    assistant_reply_json: dict,
    assistant_reply: ChatModelResponse,
    use_openai_functions_api: bool,
) -> tuple[str, dict[str, str]]:
    """Parse the response and return the command name and arguments

    Args:
        assistant_reply_json (dict): The parsed JSON object from the AI's reply
        assistant_reply (ChatModelResponse): The model response from the AI;
            only its `function_call` is read, and only when
            `use_openai_functions_api` is enabled
        use_openai_functions_api (bool): Whether the command is conveyed via an
            OpenAI `function_call` rather than the reply's "command" object

    Returns:
        tuple: The command name and arguments

    Raises:
        InvalidAgentResponseError: If the reply does not contain a usable
            command, or its function-call arguments are not valid JSON
    """
    if use_openai_functions_api:
        if "function_call" not in assistant_reply.response:
            raise InvalidAgentResponseError("No 'function_call' in assistant reply")
        # NOTE: mutates the caller's dict so downstream consumers of the
        # parsed reply also see the command taken from the function call.
        assistant_reply_json["command"] = {
            "name": assistant_reply.response["function_call"]["name"],
            "args": json.loads(assistant_reply.response["function_call"]["arguments"]),
        }
    try:
        if not isinstance(assistant_reply_json, dict):
            raise InvalidAgentResponseError(
                f"The previous message sent was not a dictionary {assistant_reply_json}"
            )

        if "command" not in assistant_reply_json:
            raise InvalidAgentResponseError("Missing 'command' object in JSON")

        command = assistant_reply_json["command"]
        if not isinstance(command, dict):
            raise InvalidAgentResponseError("'command' object is not a dictionary")

        if "name" not in command:
            raise InvalidAgentResponseError("Missing 'name' field in 'command' object")

        command_name = command["name"]

        # Use an empty dictionary if 'args' field is not present in 'command' object
        arguments = command.get("args", {})

        return command_name, arguments

    except json.decoder.JSONDecodeError as e:
        # Chain the cause so the original parse error stays in the traceback.
        raise InvalidAgentResponseError("Invalid JSON") from e

    except Exception as e:
        raise InvalidAgentResponseError(str(e)) from e
#############
# Utilities #
#############


async def execute_command(
Expand Down Expand Up @@ -406,14 +300,18 @@ async def execute_command(
raise CommandExecutionError(str(e))

# Handle non-native commands (e.g. from plugins)
for name, command in agent.prompt_generator.commands.items():
if command_name == name or command_name.lower() == command.description.lower():
try:
return command.function(**arguments)
except AgentException:
raise
except Exception as e:
raise CommandExecutionError(str(e))
if agent._prompt_scratchpad:
for name, command in agent._prompt_scratchpad.commands.items():
if (
command_name == name
or command_name.lower() == command.description.lower()
):
try:
return command.method(**arguments)
except AgentException:
raise
except Exception as e:
raise CommandExecutionError(str(e))

raise UnknownCommandError(
f"Cannot execute command '{command_name}': unknown command."
Expand Down
Loading

0 comments on commit e37a7a4

Please sign in to comment.