Selaa lähdekoodia

Bug fix: Metrics not accumulated across agent delegation (#3012)

* Add test to reproduce cost miscalculation bug

* Fix metrics bug

* Copy metrics upon AgentRejectAction
Boxuan Li 1 vuosi sitten
vanhempi
sitoutus
be6e6e3add

+ 4 - 2
opendevin/controller/agent_controller.py

@@ -123,7 +123,7 @@ class AgentController:
 
     async def update_state_after_step(self):
         # update metrics especially for cost
-        self.state.metrics = self.agent.llm.metrics
+        self.state.local_metrics = self.agent.llm.metrics
 
     async def report_error(self, message: str, exception: Exception | None = None):
         """This error will be reported to the user and sent to the LLM next step, in the hope it can self-correct.
@@ -174,9 +174,11 @@ class AgentController:
             self.state.root_task.set_subtask_state(event.task_id, event.state)
         elif isinstance(event, AgentFinishAction):
             self.state.outputs = event.outputs  # type: ignore[attr-defined]
+            self.state.metrics.merge(self.state.local_metrics)
             await self.set_agent_state_to(AgentState.FINISHED)
         elif isinstance(event, AgentRejectAction):
             self.state.outputs = event.outputs  # type: ignore[attr-defined]
+            self.state.metrics.merge(self.state.local_metrics)
             await self.set_agent_state_to(AgentState.REJECTED)
         elif isinstance(event, Observation):
             if (
@@ -260,7 +262,7 @@ class AgentController:
             iteration=self.state.iteration,
             max_iterations=self.state.max_iterations,
             delegate_level=self.state.delegate_level + 1,
-            # metrics should be shared between parent and child
+            # global metrics should be shared between parent and child
             metrics=self.state.metrics,
         )
         logger.info(

+ 2 - 0
opendevin/controller/state/state.py

@@ -98,6 +98,8 @@ class State:
     traffic_control_state: TrafficControlState = TrafficControlState.NORMAL
     # global metrics for the current task
     metrics: Metrics = Metrics()
+    # local metrics for the current subtask
+    local_metrics: Metrics = Metrics()
     # root agent has level 0, and every delegate increases the level by one
     delegate_level: int = 0
     # start_id and end_id track the range of events in history

+ 4 - 0
opendevin/core/metrics.py

@@ -28,6 +28,10 @@ class Metrics:
         self._accumulated_cost += value
         self._costs.append(value)
 
+    def merge(self, other: 'Metrics') -> None:
+        self._accumulated_cost += other.accumulated_cost
+        self._costs += other._costs
+
     def get(self):
         """Return the metrics in a dictionary."""
         return {'accumulated_cost': self._accumulated_cost, 'costs': self._costs}

+ 11 - 0
tests/integration/conftest.py

@@ -173,6 +173,11 @@ def mock_completion(*args, test_name, **kwargs):
     return response
 
 
+@pytest.fixture
+def current_test_name(request):
+    return request.node.name
+
+
 @pytest.fixture(autouse=True)
 def patch_completion(monkeypatch, request):
     test_name = request.node.name
@@ -182,6 +187,12 @@ def patch_completion(monkeypatch, request):
         partial(mock_completion, test_name=test_name),
     )
 
+    # Mock LLM completion cost (1 USD per conversation)
+    monkeypatch.setattr(
+        'opendevin.llm.llm.litellm_completion_cost',
+        lambda completion_response, **extra_kwargs: 1,
+    )
+
     # Mock user input (only for tests that have user_responses.log)
     user_responses_str = mock_user_response(test_name=test_name)
     if user_responses_str:

+ 28 - 15
tests/integration/test_agent.py

@@ -28,10 +28,25 @@ print(f'workspace_mount_path: {workspace_mount_path}')
 print(f'workspace_mount_path_in_sandbox: {workspace_mount_path_in_sandbox}')
 
 
-def validate_final_state(final_state: State | None):
+def get_number_of_prompts(test_name: str):
+    mock_dir = os.path.join(
+        os.environ.get('SCRIPT_DIR'), 'mock', os.environ.get('DEFAULT_AGENT'), test_name
+    )
+    prompt_files = [file for file in os.listdir(mock_dir) if file.startswith('prompt_')]
+    return len(prompt_files)
+
+
+def validate_final_state(final_state: State | None, test_name: str):
     assert final_state is not None
     assert final_state.agent_state == AgentState.STOPPED
     assert final_state.last_error is None
+    # number of LLM conversations should be the same as number of prompt/response
+    # log files under mock/[agent]/[test_name] folder. If not, it means there are
+    # redundant prompt/response log files checked into the repository.
+    num_of_conversations = get_number_of_prompts(test_name)
+    assert num_of_conversations > 0
+    # we mock the cost of every conversation to be 1 USD
+    assert final_state.metrics.accumulated_cost == num_of_conversations
     if final_state.history.has_delegation():
         assert final_state.iteration > final_state.local_iteration
     else:
@@ -55,7 +70,7 @@ def validate_final_state(final_state: State | None):
     os.getenv('DEFAULT_AGENT') == 'ManagerAgent',
     reason='Manager agent is not capable of finishing this in reasonable steps yet',
 )
-def test_write_simple_script() -> None:
+def test_write_simple_script(current_test_name) -> None:
     task = "Write a shell script 'hello.sh' that prints 'hello'. Do not ask me for confirmation at any point."
     args = parse_arguments()
 
@@ -65,9 +80,7 @@ def test_write_simple_script() -> None:
     final_state: State | None = asyncio.run(
         run_agent_controller(agent, task, exit_on_message=True)
     )
-    assert final_state is not None
-    assert final_state.agent_state == AgentState.STOPPED
-    assert final_state.last_error is None
+    validate_final_state(final_state, current_test_name)
 
     # Verify the script file exists
     assert workspace_base is not None
@@ -103,7 +116,7 @@ def test_write_simple_script() -> None:
     os.getenv('SANDBOX_BOX_TYPE') == 'local',
     reason='local sandbox shows environment-dependent absolute path for pwd command',
 )
-def test_edits():
+def test_edits(current_test_name):
     args = parse_arguments()
     # Copy workspace artifacts to workspace_base location
     source_dir = os.path.join(os.path.dirname(__file__), 'workspace/test_edits/')
@@ -122,7 +135,7 @@ def test_edits():
     final_state: State | None = asyncio.run(
         run_agent_controller(agent, task, exit_on_message=True)
     )
-    validate_final_state(final_state)
+    validate_final_state(final_state, current_test_name)
 
     # Verify bad.txt has been fixed
     text = """This is a stupid typo.
@@ -144,7 +157,7 @@ Enjoy!
     os.getenv('SANDBOX_BOX_TYPE') != 'ssh',
     reason='Currently, only ssh sandbox supports stateful tasks',
 )
-def test_ipython():
+def test_ipython(current_test_name):
     args = parse_arguments()
 
     # Create the agent
@@ -155,7 +168,7 @@ def test_ipython():
     final_state: State | None = asyncio.run(
         run_agent_controller(agent, task, exit_on_message=True)
     )
-    validate_final_state(final_state)
+    validate_final_state(final_state, current_test_name)
 
     # Verify the file exists
     file_path = os.path.join(workspace_base, 'test.txt')
@@ -177,7 +190,7 @@ def test_ipython():
     os.getenv('SANDBOX_BOX_TYPE') == 'local',
     reason='FIXME: local sandbox does not capture stderr',
 )
-def test_simple_task_rejection():
+def test_simple_task_rejection(current_test_name):
     args = parse_arguments()
 
     # Create the agent
@@ -187,7 +200,7 @@ def test_simple_task_rejection():
     # the workspace is not a git repo
     task = 'Write a git commit message for the current staging area. Do not ask me for confirmation at any point.'
     final_state: State | None = asyncio.run(run_agent_controller(agent, task))
-    validate_final_state(final_state)
+    validate_final_state(final_state, current_test_name)
     assert isinstance(final_state.history.get_last_action(), AgentRejectAction)
 
 
@@ -200,7 +213,7 @@ def test_simple_task_rejection():
     os.getenv('SANDBOX_BOX_TYPE') != 'ssh',
     reason='Currently, only ssh sandbox supports stateful tasks',
 )
-def test_ipython_module():
+def test_ipython_module(current_test_name):
     args = parse_arguments()
 
     # Create the agent
@@ -211,7 +224,7 @@ def test_ipython_module():
     final_state: State | None = asyncio.run(
         run_agent_controller(agent, task, exit_on_message=True)
     )
-    validate_final_state(final_state)
+    validate_final_state(final_state, current_test_name)
 
     # Verify the file exists
     file_path = os.path.join(workspace_base, 'test.txt')
@@ -239,7 +252,7 @@ def test_ipython_module():
     and os.getenv('SANDBOX_BOX_TYPE', '').lower() != 'ssh',
     reason='CodeActAgent/CodeActSWEAgent only supports ssh sandbox which is stateful',
 )
-def test_browse_internet(http_server):
+def test_browse_internet(http_server, current_test_name):
     args = parse_arguments()
 
     # Create the agent
@@ -250,7 +263,7 @@ def test_browse_internet(http_server):
     final_state: State | None = asyncio.run(
         run_agent_controller(agent, task, exit_on_message=True)
     )
-    validate_final_state(final_state)
+    validate_final_state(final_state, current_test_name)
 
     # last action
     last_action = final_state.history.get_last_action()