Просмотр исходного кода

History clean up (#2849)

* clean up add_history

* refactor last agent message
Engel Nyst 1 год назад
Родитель
Commit
2df1d67007

+ 2 - 10
evaluation/EDA/run_infer.py

@@ -24,7 +24,6 @@ from opendevin.core.config import config, get_llm_config_arg, get_parser
 from opendevin.core.logger import get_console_handler
 from opendevin.core.logger import opendevin_logger as logger
 from opendevin.core.main import run_agent_controller
-from opendevin.events.action import MessageAction
 from opendevin.llm.llm import LLM
 
 game = None
@@ -44,10 +43,7 @@ def codeact_user_response_eda(state: State) -> str:
 
     # retrieve the latest model message from history
     if state.history:
-        for event in state.history.get_events(reverse=True):
-            if isinstance(event, MessageAction) and event.source == 'agent':
-                model_guess = event.content
-                break
+        model_guess = state.history.get_last_agent_message()
 
     assert game is not None, 'Game is not initialized.'
     msg = game.generate_user_response(model_guess)
@@ -150,11 +146,7 @@ def process_instance(
     if state is None:
         raise ValueError('State should not be None.')
 
-    final_message = ''
-    for event in state.history.get_events(reverse=True):
-        if isinstance(event, MessageAction) and event.source == 'agent':
-            final_message = event.content
-            break
+    final_message = state.history.get_last_agent_message()
 
     logger.info(f'Final message: {final_message} | Ground truth: {instance["text"]}')
     test_result = game.reward()

+ 1 - 5
evaluation/gorilla/run_infer.py

@@ -131,12 +131,8 @@ def process_instance(agent, question_id, question, metadata, reset_logger: bool
         if state is None:
             raise ValueError('State should not be None.')
 
-        model_answer_raw = ''
-
         # retrieve the last message from the agent
-        for event in state.history.get_events(reverse=True):
-            if isinstance(event, MessageAction) and event.source == 'agent':
-                model_answer_raw = event
+        model_answer_raw = state.history.get_last_agent_message()
 
         # attempt to parse model_answer
         _, _, ast_eval = get_data(metadata['hub'])

+ 2 - 10
evaluation/gpqa/run_infer.py

@@ -41,7 +41,6 @@ from opendevin.core.config import config, get_llm_config_arg, get_parser
 from opendevin.core.logger import get_console_handler
 from opendevin.core.logger import opendevin_logger as logger
 from opendevin.core.main import run_agent_controller
-from opendevin.events.action import MessageAction
 from opendevin.llm.llm import LLM
 
 AGENT_CLS_TO_FAKE_USER_RESPONSE_FN = {
@@ -209,15 +208,8 @@ def process_instance(
         assert state is not None, 'State should not be None.'
 
         # ======= Attempt to evaluate the agent's edits =======
-        # get the final message from the state history (default to None if not found)
-        final_message = next(
-            (
-                act.content
-                for act in state.history.get_events(reverse=True)
-                if isinstance(act, MessageAction)
-            ),
-            None,
-        )
+        # get the final message from the state history (default to empty if not found)
+        final_message = state.history.get_last_agent_message()
 
         logger.info(f'Final message generated by the agent: {final_message}')
 

+ 1 - 7
evaluation/toolqa/run_infer.py

@@ -20,7 +20,6 @@ from opendevin.core.config import config, get_llm_config_arg, get_parser
 from opendevin.core.logger import get_console_handler
 from opendevin.core.logger import opendevin_logger as logger
 from opendevin.core.main import run_agent_controller
-from opendevin.events.action import MessageAction
 from opendevin.llm.llm import LLM
 
 from .utils import download_data, download_tools, encode_question, eval_answer, get_data
@@ -95,13 +94,8 @@ def process_instance(instance: Any, metadata: EvalMetadata, reset_logger: bool =
     if state is None:
         raise ValueError('State should not be None.')
 
-    model_answer_raw = ''
-
     # retrieve the last message from the agent
-    for event in state.history.get_events(reverse=True):
-        if isinstance(event, MessageAction) and event.source == 'agent':
-            model_answer_raw = event.content
-            break
+    model_answer_raw = state.history.get_last_agent_message()
 
     # attempt to parse model_answer
     correct = eval_answer(str(model_answer_raw), str(answer))

+ 0 - 16
opendevin/controller/agent_controller.py

@@ -31,7 +31,6 @@ from opendevin.events.observation import (
     AgentStateChangedObservation,
     CmdOutputObservation,
     ErrorObservation,
-    NullObservation,
     Observation,
 )
 
@@ -128,13 +127,6 @@ class AgentController:
             self.state.last_error += f': {exception}'
         self.event_stream.add_event(ErrorObservation(message), EventSource.AGENT)
 
-    async def add_history(self, action: Action, observation: Observation):
-        if isinstance(action, NullAction) and isinstance(observation, NullObservation):
-            return
-        logger.debug(
-            f'Adding history ({type(action).__name__} with id={action.id}, {type(observation).__name__} with id={observation.id})'
-        )
-
     async def _start_step_loop(self):
         logger.info(f'[Agent Controller {self.id}] Starting step loop...')
         while True:
@@ -160,7 +152,6 @@ class AgentController:
             await self.set_agent_state_to(event.agent_state)  # type: ignore
         elif isinstance(event, MessageAction):
             if event.source == EventSource.USER:
-                await self.add_history(event, NullObservation(''))
                 if self.get_agent_state() != AgentState.RUNNING:
                     await self.set_agent_state_to(AgentState.RUNNING)
             elif event.source == EventSource.AGENT and event.wait_for_response:
@@ -179,18 +170,14 @@ class AgentController:
             await self.set_agent_state_to(AgentState.REJECTED)
         elif isinstance(event, Observation):
             if self._pending_action and self._pending_action.id == event.cause:
-                await self.add_history(self._pending_action, event)
                 self._pending_action = None
                 logger.info(event, extra={'msg_type': 'OBSERVATION'})
             elif isinstance(event, CmdOutputObservation):
-                await self.add_history(NullAction(), event)
                 logger.info(event, extra={'msg_type': 'OBSERVATION'})
             elif isinstance(event, AgentDelegateObservation):
-                await self.add_history(NullAction(), event)
                 self.state.history.on_event(event)
                 logger.info(event, extra={'msg_type': 'OBSERVATION'})
             elif isinstance(event, ErrorObservation):
-                await self.add_history(NullAction(), event)
                 logger.info(event, extra={'msg_type': 'OBSERVATION'})
 
     def reset_task(self):
@@ -359,9 +346,6 @@ class AgentController:
         if not isinstance(action, NullAction):
             self.event_stream.add_event(action, EventSource.AGENT)
 
-        if not action.runnable:
-            await self.add_history(action, NullObservation(''))
-
         await self.update_state_after_step()
         logger.info(action, extra={'msg_type': 'ACTION'})
 

+ 18 - 1
opendevin/memory/history.py

@@ -127,7 +127,7 @@ class ShortTermHistory(list[Event]):
 
     def get_last_user_message(self) -> str:
         """
-        Return the latest user message from the event stream.
+        Return the content of the last user message from the event stream.
         """
 
         last_user_message = next(
@@ -141,6 +141,23 @@ class ShortTermHistory(list[Event]):
 
         return last_user_message if last_user_message is not None else ''
 
+    def get_last_agent_message(self) -> str:
+        """
+        Return the content of the last agent message from the event stream.
+        """
+
+        last_agent_message = next(
+            (
+                event.content
+                for event in self._event_stream.get_events(reverse=True)
+                if isinstance(event, MessageAction)
+                and event.source == EventSource.AGENT
+            ),
+            None,
+        )
+
+        return last_agent_message if last_agent_message is not None else ''
+
     def get_last_events(self, n: int) -> list[Event]:
         """
         Return the last n events from the event stream.