Browse Source

[Mint evaluation] Fix bug in stopping when the agent reaches max steps or solution proposals (#2268)

* fix: bug in stopping when the agent reaches max steps or solution proposals

* remove --eval-num-workers

* update env.py
Ryan H. Tran 1 year ago
parent
commit
0584e428b2

+ 9 - 5
evaluation/mint/datatypes.py

@@ -13,11 +13,15 @@ class TaskState:
     ):
         self.finished = finished
         self.success = success
-        self.agent_action_count: Dict[str, int] = agent_action_count or {
-            'propose_solution': 0,
-            'use_tool': 0,
-            'invalid_action': 0,
-        }
+        self.agent_action_count: Dict[str, int] = (
+            agent_action_count
+            if agent_action_count
+            else {
+                'propose_solution': 0,
+                'use_tool': 0,
+                'invalid_action': 0,
+            }
+        )
         self.terminate_reason = terminate_reason
         self.latest_output = latest_output
 

+ 18 - 6
evaluation/mint/env.py

@@ -19,7 +19,20 @@ class SimplifiedEnv:
     def __init__(self, agent_state: State, task: Task, task_config: Dict[str, int]):
         self.agent_state = agent_state
         self.task = task
-        self.task_state = TaskState()
+
+        agent_action_count = {
+            'propose_solution': 0,
+            'use_tool': 0,
+            'invalid_action': 0,
+        }
+        # check if agent_state has attribute propose_solution_count set
+        if hasattr(self.agent_state, 'propose_solution_count'):
+            agent_action_count['propose_solution'] = (
+                self.agent_state.propose_solution_count
+            )
+
+        self.task_state = TaskState(agent_action_count=agent_action_count)
+
         self.task_config = task_config
 
     def step(self, lm_message: str):
@@ -39,6 +52,9 @@ class SimplifiedEnv:
             turn_info=turn_info,
         )
 
+        self.agent_state.propose_solution_count = self.task_state.agent_action_count[
+            'propose_solution'
+        ]
         self.log_output(output)
         return self.task_state
 
@@ -109,11 +125,7 @@ class SimplifiedEnv:
             self.task_state.finished = True
             self.task_state.success = False
             self.task_state.terminate_reason = 'max_propose_steps'
-        elif (
-            # (propose_solution + use_tool) > max iteration limit
-            sum(self.task_state.agent_action_count.values())
-            >= self.task_config['max_iterations']
-        ):
+        elif self.agent_state.iteration >= self.task_config['max_iterations']:
             self.task_state.finished = True
             self.task_state.success = False
             self.task_state.terminate_reason = 'max_iterations'

+ 2 - 4
evaluation/mint/run_infer.py

@@ -51,10 +51,8 @@ def codeact_user_response(state: State, task: Task, task_config: Dict[str, int])
     state.task_state = result_state
 
     if not result_state.latest_output:
-        if result_state.success:
-            msg = '/exit'
-        else:
-            msg = 'Something went wrong! No output from the model.'
+        # Task is finished
+        msg = '/exit'
     else:
         msg = result_state.latest_output['content']
 

+ 1 - 0
evaluation/mint/scripts/run_infer.sh

@@ -15,6 +15,7 @@ echo "AGENT_VERSION: $AGENT_VERSION"
 export PYTHONPATH=$(pwd)
 
 COMMAND="poetry run python ./evaluation/mint/run_infer.py \
+    --llm-config $MODEL_CONFIG \
     --max-iterations 5 \
     --max-propose-solution 2 \
     --eval-note $AGENT_VERSION"