Просмотр исходного кода

Docs/improve agent controller docstrings (#5233)

Cheng Yang 1 год назад
Родитель
Сommit
3b18d77d31
1 измененных файлов с 22 добавлено и 20 удалено
  1. 22 20
      openhands/controller/agent_controller.py

+ 22 - 20
openhands/controller/agent_controller.py

@@ -102,9 +102,11 @@ class AgentController:
             agent_configs: A dictionary mapping agent names to agent configurations in the case that
                 we delegate to a different agent.
             sid: The session ID of the agent.
+            confirmation_mode: Whether to enable confirmation mode for agent actions.
             initial_state: The initial state of the controller.
             is_delegate: Whether this controller is a delegate.
             headless_mode: Whether the agent is run in headless mode.
+            status_callback: Optional callback function to handle status updates.
         """
         self._step_lock = asyncio.Lock()
         self.id = sid
@@ -133,10 +135,11 @@ class AgentController:
         self._stuck_detector = StuckDetector(self.state)
         self.status_callback = status_callback
 
-    async def close(self):
+    async def close(self) -> None:
         """Closes the agent controller, canceling any ongoing tasks and unsubscribing from the event stream.
 
-        Note that it's fairly important that this closes properly, otherwise the state is incomplete."""
+        Note that it's fairly important that this closes properly, otherwise the state is incomplete.
+        """
         await self.set_agent_state_to(AgentState.STOPPED)
 
         # we made history, now is the time to rewrite it!
@@ -165,11 +168,13 @@ class AgentController:
         self.event_stream.unsubscribe(EventStreamSubscriber.AGENT_CONTROLLER, self.id)
         self._closed = True
 
-    def log(self, level: str, message: str, extra: dict | None = None):
+    def log(self, level: str, message: str, extra: dict | None = None) -> None:
         """Logs a message to the agent controller's logger.
 
         Args:
+            level (str): The logging level to use (e.g., 'info', 'debug', 'error').
             message (str): The message to log.
+            extra (dict | None, optional): Additional fields to include in the log. Defaults to None.
         """
         message = f'[Agent Controller {self.id}] {message}'
         getattr(logger, level)(message, extra=extra, stacklevel=2)
@@ -195,7 +200,6 @@ class AgentController:
 
     async def start_step_loop(self):
         """The main loop for the agent's step-by-step execution."""
-
         self.log('info', 'Starting step loop...')
         while should_continue():
             if self._closed:
@@ -212,7 +216,7 @@ class AgentController:
 
             await asyncio.sleep(0.1)
 
-    async def on_event(self, event: Event):
+    async def on_event(self, event: Event) -> None:
         """Callback from the event stream. Notifies the controller of incoming events.
 
         Args:
@@ -230,7 +234,7 @@ class AgentController:
         elif isinstance(event, Observation):
             await self._handle_observation(event)
 
-    async def _handle_action(self, action: Action):
+    async def _handle_action(self, action: Action) -> None:
         """Handles actions from the event stream.
 
         Args:
@@ -257,7 +261,7 @@ class AgentController:
             self.state.metrics.merge(self.state.local_metrics)
             await self.set_agent_state_to(AgentState.REJECTED)
 
-    async def _handle_observation(self, observation: Observation):
+    async def _handle_observation(self, observation: Observation) -> None:
         """Handles observation from the event stream.
 
         Args:
@@ -288,7 +292,7 @@ class AgentController:
             if self.state.agent_state == AgentState.ERROR:
                 self.state.metrics.merge(self.state.local_metrics)
 
-    async def _handle_message_action(self, action: MessageAction):
+    async def _handle_message_action(self, action: MessageAction) -> None:
         """Handles message actions from the event stream.
 
         Args:
@@ -309,13 +313,12 @@ class AgentController:
         elif action.source == EventSource.AGENT and action.wait_for_response:
             await self.set_agent_state_to(AgentState.AWAITING_USER_INPUT)
 
-    def reset_task(self):
+    def reset_task(self) -> None:
         """Resets the agent's task."""
-
         self.almost_stuck = 0
         self.agent.reset()
 
-    async def set_agent_state_to(self, new_state: AgentState):
+    async def set_agent_state_to(self, new_state: AgentState) -> None:
         """Updates the agent's state and handles side effects. Can emit events to the event stream.
 
         Args:
@@ -376,7 +379,7 @@ class AgentController:
             await self.set_agent_state_to(self.state.resume_state)
             self.state.resume_state = None
 
-    def get_agent_state(self):
+    def get_agent_state(self) -> AgentState:
         """Returns the current state of the agent.
 
         Returns:
@@ -384,7 +387,7 @@ class AgentController:
         """
         return self.state.agent_state
 
-    async def start_delegate(self, action: AgentDelegateAction):
+    async def start_delegate(self, action: AgentDelegateAction) -> None:
         """Start a delegate agent to handle a subtask.
 
         OpenHands is a multi-agentic system. A `task` is a conversation between
@@ -532,7 +535,7 @@ class AgentController:
         log_level = 'info' if LOG_ALL_EVENTS else 'debug'
         self.log(log_level, str(action), extra={'msg_type': 'ACTION'})
 
-    async def _delegate_step(self):
+    async def _delegate_step(self) -> None:
         """Executes a single step of the delegate agent."""
         await self.delegate._step()  # type: ignore[union-attr]
         assert self.delegate is not None
@@ -596,7 +599,7 @@ class AgentController:
 
     async def _handle_traffic_control(
         self, limit_type: str, current_value: float, max_value: float
-    ):
+    ) -> bool:
         """Handles agent state after hitting the traffic control limit.
 
         Args:
@@ -628,7 +631,7 @@ class AgentController:
             stop_step = True
         return stop_step
 
-    def get_state(self):
+    def get_state(self) -> State:
         """Returns the current running state object.
 
         Returns:
@@ -641,7 +644,7 @@ class AgentController:
         state: State | None,
         max_iterations: int,
         confirmation_mode: bool = False,
-    ):
+    ) -> None:
         """Sets the initial state for the agent, either from the previous session, or from a parent agent, or by creating a new one.
 
         Args:
@@ -672,7 +675,7 @@ class AgentController:
 
             self._init_history()
 
-    def _init_history(self):
+    def _init_history(self) -> None:
         """Initializes the agent's history from the event stream.
 
         The history is a list of events that:
@@ -688,7 +691,6 @@ class AgentController:
 
         Otherwise loads normally from start_id.
         """
-
         # define range of events to fetch
         # delegates start with a start_id and initially won't find any events
         # otherwise we're restoring a previous session
@@ -884,7 +886,7 @@ class AgentController:
 
         return kept_events
 
-    def _is_stuck(self):
+    def _is_stuck(self) -> bool:
         """Checks if the agent or its delegate is stuck in a loop.
 
         Returns: