Browse Source

fix: asyncio issues with security analyzer + enable security analyzer in cli (#5356)

Mislav Balunovic 1 year ago
parent
commit
871c544b74

+ 3 - 0
openhands/controller/agent_controller.py

@@ -284,6 +284,8 @@ class AgentController:
             self.agent.llm.metrics.merge(observation.llm_metrics)
 
         if self._pending_action and self._pending_action.id == observation.cause:
+            if self.state.agent_state == AgentState.AWAITING_USER_CONFIRMATION:
+                return
             self._pending_action = None
             if self.state.agent_state == AgentState.USER_CONFIRMED:
                 await self.set_agent_state_to(AgentState.RUNNING)
@@ -369,6 +371,7 @@ class AgentController:
             else:
                 confirmation_state = ActionConfirmationStatus.REJECTED
             self._pending_action.confirmation_state = confirmation_state  # type: ignore[attr-defined]
+            self._pending_action._id = None  # type: ignore[attr-defined]
             self.event_stream.add_event(self._pending_action, EventSource.AGENT)
 
         self.state.agent_state = new_state

+ 43 - 2
openhands/core/cli.py

@@ -11,6 +11,7 @@ from openhands import __version__
 from openhands.controller import AgentController
 from openhands.controller.agent import Agent
 from openhands.core.config import (
+    AppConfig,
     get_parser,
     load_app_config,
 )
@@ -20,6 +21,7 @@ from openhands.core.schema import AgentState
 from openhands.events import EventSource, EventStream, EventStreamSubscriber
 from openhands.events.action import (
     Action,
+    ActionConfirmationStatus,
     ChangeAgentStateAction,
     CmdRunAction,
     FileEditAction,
@@ -30,10 +32,12 @@ from openhands.events.observation import (
     AgentStateChangedObservation,
     CmdOutputObservation,
     FileEditObservation,
+    NullObservation,
 )
 from openhands.llm.llm import LLM
 from openhands.runtime import get_runtime_cls
 from openhands.runtime.base import Runtime
+from openhands.security import SecurityAnalyzer, options
 from openhands.storage import get_file_store
 
 
@@ -45,6 +49,15 @@ def display_command(command: str):
     print('❯ ' + colored(command + '\n', 'green'))
 
 
+def display_confirmation(confirmation_state: ActionConfirmationStatus):
+    if confirmation_state == ActionConfirmationStatus.CONFIRMED:
+        print(colored('✅ ' + confirmation_state + '\n', 'green'))
+    elif confirmation_state == ActionConfirmationStatus.REJECTED:
+        print(colored('❌ ' + confirmation_state + '\n', 'red'))
+    else:
+        print(colored('⏳ ' + confirmation_state + '\n', 'yellow'))
+
+
 def display_command_output(output: str):
     lines = output.split('\n')
     for line in lines:
@@ -59,7 +72,7 @@ def display_file_edit(event: FileEditAction | FileEditObservation):
     print(colored(str(event), 'green'))
 
 
-def display_event(event: Event):
+def display_event(event: Event, config: AppConfig):
     if isinstance(event, Action):
         if hasattr(event, 'thought'):
             display_message(event.thought)
@@ -74,6 +87,8 @@ def display_event(event: Event):
         display_file_edit(event)
     if isinstance(event, FileEditObservation):
         display_file_edit(event)
+    if hasattr(event, 'confirmation_state') and config.security.confirmation_mode:
+        display_confirmation(event.confirmation_state)
 
 
 async def main():
@@ -119,12 +134,18 @@ async def main():
         headless_mode=True,
     )
 
+    if config.security.security_analyzer:
+        options.SecurityAnalyzers.get(
+            config.security.security_analyzer, SecurityAnalyzer
+        )(event_stream)
+
     controller = AgentController(
         agent=agent,
         max_iterations=config.max_iterations,
         max_budget_per_task=config.max_budget_per_task,
         agent_to_llm_config=config.get_agent_to_llm_config_map(),
         event_stream=event_stream,
+        confirmation_mode=config.security.confirmation_mode,
     )
 
     async def prompt_for_next_task():
@@ -143,14 +164,34 @@ async def main():
         action = MessageAction(content=next_message)
         event_stream.add_event(action, EventSource.USER)
 
+    async def prompt_for_user_confirmation():
+        loop = asyncio.get_event_loop()
+        user_confirmation = await loop.run_in_executor(
+            None, lambda: input('Confirm action (possible security risk)? (y/n) >> ')
+        )
+        return user_confirmation.lower() == 'y'
+
     async def on_event(event: Event):
-        display_event(event)
+        display_event(event, config)
         if isinstance(event, AgentStateChangedObservation):
             if event.agent_state in [
                 AgentState.AWAITING_USER_INPUT,
                 AgentState.FINISHED,
             ]:
                 await prompt_for_next_task()
+        if (
+            isinstance(event, NullObservation)
+            and controller.state.agent_state == AgentState.AWAITING_USER_CONFIRMATION
+        ):
+            user_confirmed = await prompt_for_user_confirmation()
+            if user_confirmed:
+                event_stream.add_event(
+                    ChangeAgentStateAction(AgentState.USER_CONFIRMED), EventSource.USER
+                )
+            else:
+                event_stream.add_event(
+                    ChangeAgentStateAction(AgentState.USER_REJECTED), EventSource.USER
+                )
 
     event_stream.subscribe(EventStreamSubscriber.MAIN, on_event, str(uuid4()))
 

+ 4 - 0
openhands/core/config/security_config.py

@@ -32,5 +32,9 @@ class SecurityConfig:
 
         return f"SecurityConfig({', '.join(attr_str)})"
 
+    @classmethod
+    def from_dict(cls, security_config_dict: dict) -> 'SecurityConfig':
+        return cls(**security_config_dict)
+
     def __repr__(self):
         return self.__str__()

+ 7 - 0
openhands/core/config/utils.py

@@ -18,6 +18,7 @@ from openhands.core.config.config_utils import (
 )
 from openhands.core.config.llm_config import LLMConfig
 from openhands.core.config.sandbox_config import SandboxConfig
+from openhands.core.config.security_config import SecurityConfig
 
 load_dotenv()
 
@@ -144,6 +145,12 @@ def load_from_toml(cfg: AppConfig, toml_file: str = 'config.toml'):
                             )
                             llm_config = LLMConfig.from_dict(nested_value)
                             cfg.set_llm_config(llm_config, nested_key)
+                elif key is not None and key.lower() == 'security':
+                    logger.openhands_logger.debug(
+                        'Attempt to load security config from config toml'
+                    )
+                    security_config = SecurityConfig.from_dict(value)
+                    cfg.security = security_config
                 elif not key.startswith('sandbox') and key.lower() != 'core':
                     logger.openhands_logger.warning(
                         f'Unknown key in {toml_file}: "{key}"'

+ 1 - 1
openhands/security/invariant/analyzer.py

@@ -300,7 +300,7 @@ class InvariantAnalyzer(SecurityAnalyzer):
         )
         # we should confirm only on agent actions
         event_source = event.source if event.source else EventSource.AGENT
-        await call_sync_from_async(self.event_stream.add_event, new_event, event_source)
+        self.event_stream.add_event(new_event, event_source)
 
     async def security_risk(self, event: Action) -> ActionSecurityRisk:
         logger.debug('Calling security_risk on InvariantAnalyzer')