Răsfoiți Sursa

feat: add max_budget_per_task configuration to control task cost (#2070)

* feat: add max_budget_per_task configuration to control task cost

* Fix test_arg_parser.py

* Use the config.max_budget_per_task as default value

* Add max_budget_per_task to core/main.py as well

* Update opendevin/controller/agent_controller.py

Co-authored-by: Boxuan Li <liboxuan@connect.hku.hk>

---------

Co-authored-by: Boxuan Li <liboxuan@connect.hku.hk>
Aleksandar 1 an în urmă
părinte
comite
18d07bda89

+ 13 - 2
opendevin/controller/agent_controller.py

@@ -37,6 +37,7 @@ from opendevin.events.observation import (
 
 MAX_ITERATIONS = config.max_iterations
 MAX_CHARS = config.llm.max_chars
+MAX_BUDGET_PER_TASK = config.max_budget_per_task
 
 
 class AgentController:
@@ -56,6 +57,7 @@ class AgentController:
         sid: str = 'default',
         max_iterations: int = MAX_ITERATIONS,
         max_chars: int = MAX_CHARS,
+        max_budget_per_task: float | None = MAX_BUDGET_PER_TASK,
         inputs: dict | None = None,
     ):
         """Initializes a new instance of the AgentController class.
@@ -66,6 +68,7 @@ class AgentController:
             sid: The session ID of the agent.
             max_iterations: The maximum number of iterations the agent can run.
             max_chars: The maximum number of characters the agent can output.
+            max_budget_per_task: The maximum budget (in USD) allowed per task, beyond which the agent will stop. If None, no budget limit is enforced.
             inputs: The initial inputs to the agent.
         """
         self.id = sid
@@ -77,6 +80,7 @@ class AgentController:
         )
         self.max_iterations = max_iterations
         self.max_chars = max_chars
+        self.max_budget_per_task = max_budget_per_task
         self.agent_task = asyncio.create_task(self._start_step_loop())
 
     async def close(self):
@@ -88,10 +92,17 @@ class AgentController:
     def update_state_before_step(self):
         self.state.iteration += 1
 
-    def update_state_after_step(self):
+    async def update_state_after_step(self):
         self.state.updated_info = []
         # update metrics especially for cost
         self.state.metrics = self.agent.llm.metrics
+        if self.max_budget_per_task is not None:
+            current_cost = self.state.metrics.accumulated_cost
+            if current_cost > self.max_budget_per_task:
+                await self.report_error(
+                    f'Task budget exceeded. Current cost: {current_cost}, Max budget: {self.max_budget_per_task}'
+                )
+                await self.set_agent_state_to(AgentState.ERROR)
 
     async def report_error(self, message: str, exception: Exception | None = None):
         self.state.error = message
@@ -235,7 +246,7 @@ class AgentController:
 
         logger.info(action, extra={'msg_type': 'ACTION'})
 
-        self.update_state_after_step()
+        await self.update_state_after_step()
         if action.runnable:
             self._pending_action = action
         else:

+ 9 - 0
opendevin/core/config.py

@@ -137,6 +137,7 @@ class AppConfig(metaclass=Singleton):
         sandbox_container_image: The container image to use for the sandbox.
         run_as_devin: Whether to run as devin.
         max_iterations: The maximum number of iterations.
+        max_budget_per_task: The maximum budget (in USD) allowed per task, beyond which the agent will stop. Defaults to None (no limit).
         e2b_api_key: The E2B API key.
         sandbox_type: The type of sandbox to use. Options are: ssh, exec, e2b, local.
         use_host_network: Whether to use the host network.
@@ -166,6 +167,7 @@ class AppConfig(metaclass=Singleton):
     )
     run_as_devin: bool = True
     max_iterations: int = 100
+    max_budget_per_task: float | None = None
     e2b_api_key: str = ''
     sandbox_type: str = 'ssh'  # Can be 'ssh', 'exec', or 'e2b'
     use_host_network: bool = False
@@ -490,6 +492,13 @@ def get_parser():
         type=int,
         help='The maximum number of iterations to run the agent',
     )
+    parser.add_argument(
+        '-b',
+        '--max-budget-per-task',
+        default=config.max_budget_per_task,
+        type=float,
+        help='The maximum budget allowed per task, beyond which the agent will stop.',
+    )
     parser.add_argument(
         '-n',
         '--max-chars',

+ 1 - 0
opendevin/core/main.py

@@ -86,6 +86,7 @@ async def main(
     controller = AgentController(
         agent=agent,
         max_iterations=args.max_iterations,
+        max_budget_per_task=args.max_budget_per_task,
         max_chars=args.max_chars,
         event_stream=event_stream,
     )

+ 5 - 2
tests/unit/test_arg_parser.py

@@ -10,8 +10,8 @@ def test_help_message(capsys):
     captured = capsys.readouterr()
     expected_help_message = """
 usage: pytest [-h] [-d DIRECTORY] [-t TASK] [-f FILE] [-c AGENT_CLS]
-              [-m MODEL_NAME] [-i MAX_ITERATIONS] [-n MAX_CHARS]
-              [--eval-output-dir EVAL_OUTPUT_DIR]
+              [-m MODEL_NAME] [-i MAX_ITERATIONS] [-b MAX_BUDGET_PER_TASK]
+              [-n MAX_CHARS] [--eval-output-dir EVAL_OUTPUT_DIR]
               [--eval-n-limit EVAL_N_LIMIT]
               [--eval-num-workers EVAL_NUM_WORKERS] [--eval-note EVAL_NOTE]
               [-l LLM_CONFIG]
@@ -31,6 +31,9 @@ options:
                         The (litellm) model name to use
   -i MAX_ITERATIONS, --max-iterations MAX_ITERATIONS
                         The maximum number of iterations to run the agent
+  -b MAX_BUDGET_PER_TASK, --max-budget-per-task MAX_BUDGET_PER_TASK
+                        The maximum budget allowed per task, beyond which the
+                        agent will stop.
   -n MAX_CHARS, --max-chars MAX_CHARS
                         The maximum number of characters to send to and
                         receive from LLM per task

+ 17 - 0
tests/unit/test_config.py

@@ -285,3 +285,20 @@ def test_api_keys_repr_str():
             assert (
                 'token' not in attr_name.lower() or 'tokens' in attr_name.lower()
             ), f"Unexpected attribute '{attr_name}' contains 'token' in AppConfig"
+
+
+def test_max_iterations_and_max_budget_per_task_from_toml(temp_toml_file):
+    temp_toml = """
+[core]
+max_iterations = 100
+max_budget_per_task = 4.0
+"""
+
+    config = AppConfig()
+    with open(temp_toml_file, 'w') as f:
+        f.write(temp_toml)
+
+    load_from_toml(config, temp_toml_file)
+
+    assert config.max_iterations == 100
+    assert config.max_budget_per_task == 4.0