Procházet zdrojové kódy

fix: metric logging in agent controller (#4387)

Xingyao Wang před 1 rokem
rodič
revize
6bbd75c6e7

+ 14 - 7
openhands/controller/agent_controller.py

@@ -127,8 +127,8 @@ class AgentController:
         self.state.local_iteration += 1
 
     async def update_state_after_step(self):
-        # update metrics especially for cost
-        self.state.local_metrics = self.agent.llm.metrics
+        # update metrics especially for cost. Use deepcopy to avoid it being modified by agent.reset()
+        self.state.local_metrics = copy.deepcopy(self.agent.llm.metrics)
         if 'llm_completions' not in self.state.extra_data:
             self.state.extra_data['llm_completions'] = []
         self.state.extra_data['llm_completions'].extend(self.agent.llm.llm_completions)
@@ -139,12 +139,12 @@ class AgentController:
 
         This method should be called for a particular type of errors, which have:
         - a user-friendly message, which will be shown in the chat box. This should not be a raw exception message.
-        - an ErrorObservation that can be sent to the LLM by the agent, with the exception message, so it can self-correct next time.
+        - an ErrorObservation that can be sent to the LLM by the user role, with the exception message, so it can self-correct next time.
         """
         self.state.last_error = message
         if exception:
             self.state.last_error += f': {exception}'
-        self.event_stream.add_event(ErrorObservation(message), EventSource.AGENT)
+        self.event_stream.add_event(ErrorObservation(message), EventSource.USER)
 
     async def start_step_loop(self):
         """The main loop for the agent's step-by-step execution."""
@@ -229,6 +229,11 @@ class AgentController:
                 observation_to_print.content, self.agent.llm.config.max_message_chars
             )
         logger.info(observation_to_print, extra={'msg_type': 'OBSERVATION'})
+
+        # Merge with the metrics from the LLM - it will be synced to the controller's local metrics in update_state_after_step()
+        if observation.llm_metrics is not None:
+            self.agent.llm.metrics.merge(observation.llm_metrics)
+
         if self._pending_action and self._pending_action.id == observation.cause:
             self._pending_action = None
             if self.state.agent_state == AgentState.USER_CONFIRMED:
@@ -450,8 +455,9 @@ class AgentController:
         logger.info(action, extra={'msg_type': 'ACTION'})
 
         if self._is_stuck():
-            await self.report_error('Agent got stuck in a loop')
+            # This needs to go BEFORE report_error to sync metrics
             await self.set_agent_state_to(AgentState.ERROR)
+            await self.report_error('Agent got stuck in a loop')
 
     async def _delegate_step(self):
         """Executes a single step of the delegate agent."""
@@ -519,20 +525,21 @@ class AgentController:
         else:
             self.state.traffic_control_state = TrafficControlState.THROTTLING
             if self.headless_mode:
+                # This needs to go BEFORE report_error to sync metrics
+                await self.set_agent_state_to(AgentState.ERROR)
                 # set to ERROR state if running in headless mode
                 # since user cannot resume on the web interface
                 await self.report_error(
                     f'Agent reached maximum {limit_type} in headless mode, task stopped. '
                     f'Current {limit_type}: {current_value:.2f}, max {limit_type}: {max_value:.2f}'
                 )
-                await self.set_agent_state_to(AgentState.ERROR)
             else:
+                await self.set_agent_state_to(AgentState.PAUSED)
                 await self.report_error(
                     f'Agent reached maximum {limit_type}, task paused. '
                     f'Current {limit_type}: {current_value:.2f}, max {limit_type}: {max_value:.2f}. '
                     f'{TRAFFIC_CONTROL_REMINDER}'
                 )
-                await self.set_agent_state_to(AgentState.PAUSED)
             stop_step = True
         return stop_step
 

+ 13 - 0
openhands/events/event.py

@@ -2,6 +2,8 @@ from dataclasses import dataclass
 from datetime import datetime
 from enum import Enum
 
+from openhands.core.metrics import Metrics
+
 
 class EventSource(str, Enum):
     AGENT = 'agent'
@@ -58,3 +60,14 @@ class Event:
         if hasattr(self, 'blocking'):
             # .blocking needs to be set to True if .timeout is set
             self.blocking = True
+
+    # optional metadata, LLM call cost of the edit
+    @property
+    def llm_metrics(self) -> Metrics | None:
+        if hasattr(self, '_llm_metrics'):
+            return self._llm_metrics  # type: ignore[attr-defined]
+        return None
+
+    @llm_metrics.setter
+    def llm_metrics(self, value: Metrics) -> None:
+        self._llm_metrics = value