Skip to content

Commit

Permalink
fix: asyncio issues with security analyzer + enable security analyzer…
Browse files Browse the repository at this point in the history
… in cli (#5356)
  • Loading branch information
mbalunovic authored Dec 2, 2024
1 parent 92b38dc commit 871c544
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 3 deletions.
3 changes: 3 additions & 0 deletions openhands/controller/agent_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,8 @@ async def _handle_observation(self, observation: Observation) -> None:
self.agent.llm.metrics.merge(observation.llm_metrics)

if self._pending_action and self._pending_action.id == observation.cause:
if self.state.agent_state == AgentState.AWAITING_USER_CONFIRMATION:
return
self._pending_action = None
if self.state.agent_state == AgentState.USER_CONFIRMED:
await self.set_agent_state_to(AgentState.RUNNING)
Expand Down Expand Up @@ -369,6 +371,7 @@ async def set_agent_state_to(self, new_state: AgentState) -> None:
else:
confirmation_state = ActionConfirmationStatus.REJECTED
self._pending_action.confirmation_state = confirmation_state # type: ignore[attr-defined]
self._pending_action._id = None # type: ignore[attr-defined]
self.event_stream.add_event(self._pending_action, EventSource.AGENT)

self.state.agent_state = new_state
Expand Down
45 changes: 43 additions & 2 deletions openhands/core/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from openhands.controller import AgentController
from openhands.controller.agent import Agent
from openhands.core.config import (
AppConfig,
get_parser,
load_app_config,
)
Expand All @@ -20,6 +21,7 @@
from openhands.events import EventSource, EventStream, EventStreamSubscriber
from openhands.events.action import (
Action,
ActionConfirmationStatus,
ChangeAgentStateAction,
CmdRunAction,
FileEditAction,
Expand All @@ -30,10 +32,12 @@
AgentStateChangedObservation,
CmdOutputObservation,
FileEditObservation,
NullObservation,
)
from openhands.llm.llm import LLM
from openhands.runtime import get_runtime_cls
from openhands.runtime.base import Runtime
from openhands.security import SecurityAnalyzer, options
from openhands.storage import get_file_store


Expand All @@ -45,6 +49,15 @@ def display_command(command: str):
print('❯ ' + colored(command + '\n', 'green'))


def display_confirmation(confirmation_state: ActionConfirmationStatus):
if confirmation_state == ActionConfirmationStatus.CONFIRMED:
print(colored('✅ ' + confirmation_state + '\n', 'green'))
elif confirmation_state == ActionConfirmationStatus.REJECTED:
print(colored('❌ ' + confirmation_state + '\n', 'red'))
else:
print(colored('⏳ ' + confirmation_state + '\n', 'yellow'))


def display_command_output(output: str):
lines = output.split('\n')
for line in lines:
Expand All @@ -59,7 +72,7 @@ def display_file_edit(event: FileEditAction | FileEditObservation):
print(colored(str(event), 'green'))


def display_event(event: Event):
def display_event(event: Event, config: AppConfig):
if isinstance(event, Action):
if hasattr(event, 'thought'):
display_message(event.thought)
Expand All @@ -74,6 +87,8 @@ def display_event(event: Event):
display_file_edit(event)
if isinstance(event, FileEditObservation):
display_file_edit(event)
if hasattr(event, 'confirmation_state') and config.security.confirmation_mode:
display_confirmation(event.confirmation_state)


async def main():
Expand Down Expand Up @@ -119,12 +134,18 @@ async def main():
headless_mode=True,
)

if config.security.security_analyzer:
options.SecurityAnalyzers.get(
config.security.security_analyzer, SecurityAnalyzer
)(event_stream)

controller = AgentController(
agent=agent,
max_iterations=config.max_iterations,
max_budget_per_task=config.max_budget_per_task,
agent_to_llm_config=config.get_agent_to_llm_config_map(),
event_stream=event_stream,
confirmation_mode=config.security.confirmation_mode,
)

async def prompt_for_next_task():
Expand All @@ -143,14 +164,34 @@ async def prompt_for_next_task():
action = MessageAction(content=next_message)
event_stream.add_event(action, EventSource.USER)

async def prompt_for_user_confirmation():
loop = asyncio.get_event_loop()
user_confirmation = await loop.run_in_executor(
None, lambda: input('Confirm action (possible security risk)? (y/n) >> ')
)
return user_confirmation.lower() == 'y'

async def on_event(event: Event):
display_event(event)
display_event(event, config)
if isinstance(event, AgentStateChangedObservation):
if event.agent_state in [
AgentState.AWAITING_USER_INPUT,
AgentState.FINISHED,
]:
await prompt_for_next_task()
if (
isinstance(event, NullObservation)
and controller.state.agent_state == AgentState.AWAITING_USER_CONFIRMATION
):
user_confirmed = await prompt_for_user_confirmation()
if user_confirmed:
event_stream.add_event(
ChangeAgentStateAction(AgentState.USER_CONFIRMED), EventSource.USER
)
else:
event_stream.add_event(
ChangeAgentStateAction(AgentState.USER_REJECTED), EventSource.USER
)

event_stream.subscribe(EventStreamSubscriber.MAIN, on_event, str(uuid4()))

Expand Down
4 changes: 4 additions & 0 deletions openhands/core/config/security_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,5 +32,9 @@ def __str__(self):

return f"SecurityConfig({', '.join(attr_str)})"

@classmethod
def from_dict(cls, security_config_dict: dict) -> 'SecurityConfig':
return cls(**security_config_dict)

def __repr__(self):
return self.__str__()
7 changes: 7 additions & 0 deletions openhands/core/config/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
)
from openhands.core.config.llm_config import LLMConfig
from openhands.core.config.sandbox_config import SandboxConfig
from openhands.core.config.security_config import SecurityConfig

load_dotenv()

Expand Down Expand Up @@ -144,6 +145,12 @@ def load_from_toml(cfg: AppConfig, toml_file: str = 'config.toml'):
)
llm_config = LLMConfig.from_dict(nested_value)
cfg.set_llm_config(llm_config, nested_key)
elif key is not None and key.lower() == 'security':
logger.openhands_logger.debug(
'Attempt to load security config from config toml'
)
security_config = SecurityConfig.from_dict(value)
cfg.security = security_config
elif not key.startswith('sandbox') and key.lower() != 'core':
logger.openhands_logger.warning(
f'Unknown key in {toml_file}: "{key}"'
Expand Down
2 changes: 1 addition & 1 deletion openhands/security/invariant/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ async def confirm(self, event: Event) -> None:
)
# we should confirm only on agent actions
event_source = event.source if event.source else EventSource.AGENT
await call_sync_from_async(self.event_stream.add_event, new_event, event_source)
self.event_stream.add_event(new_event, event_source)

async def security_risk(self, event: Action) -> ActionSecurityRisk:
logger.debug('Calling security_risk on InvariantAnalyzer')
Expand Down

0 comments on commit 871c544

Please sign in to comment.