|
@@ -19,7 +19,20 @@ class SimplifiedEnv:
|
|
|
def __init__(self, agent_state: State, task: Task, task_config: Dict[str, int]):
|
|
def __init__(self, agent_state: State, task: Task, task_config: Dict[str, int]):
|
|
|
self.agent_state = agent_state
|
|
self.agent_state = agent_state
|
|
|
self.task = task
|
|
self.task = task
|
|
|
- self.task_state = TaskState()
|
|
|
|
|
|
|
+
|
|
|
|
|
+ agent_action_count = {
|
|
|
|
|
+ 'propose_solution': 0,
|
|
|
|
|
+ 'use_tool': 0,
|
|
|
|
|
+ 'invalid_action': 0,
|
|
|
|
|
+ }
|
|
|
|
|
+ # carry over propose_solution_count if agent_state already has one
|
|
|
|
|
+ if hasattr(self.agent_state, 'propose_solution_count'):
|
|
|
|
|
+ agent_action_count['propose_solution'] = (
|
|
|
|
|
+ self.agent_state.propose_solution_count
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ self.task_state = TaskState(agent_action_count=agent_action_count)
|
|
|
|
|
+
|
|
|
self.task_config = task_config
|
|
self.task_config = task_config
|
|
|
|
|
|
|
|
def step(self, lm_message: str):
|
|
def step(self, lm_message: str):
|
|
@@ -39,6 +52,9 @@ class SimplifiedEnv:
|
|
|
turn_info=turn_info,
|
|
turn_info=turn_info,
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
|
|
+ self.agent_state.propose_solution_count = self.task_state.agent_action_count[
|
|
|
|
|
+ 'propose_solution'
|
|
|
|
|
+ ]
|
|
|
self.log_output(output)
|
|
self.log_output(output)
|
|
|
return self.task_state
|
|
return self.task_state
|
|
|
|
|
|
|
@@ -109,11 +125,7 @@ class SimplifiedEnv:
|
|
|
self.task_state.finished = True
|
|
self.task_state.finished = True
|
|
|
self.task_state.success = False
|
|
self.task_state.success = False
|
|
|
self.task_state.terminate_reason = 'max_propose_steps'
|
|
self.task_state.terminate_reason = 'max_propose_steps'
|
|
|
- elif (
|
|
|
|
|
- # (propose_solution + use_tool) > max iteration limit
|
|
|
|
|
- sum(self.task_state.agent_action_count.values())
|
|
|
|
|
- >= self.task_config['max_iterations']
|
|
|
|
|
- ):
|
|
|
|
|
|
|
+ elif self.agent_state.iteration >= self.task_config['max_iterations']:
|
|
|
self.task_state.finished = True
|
|
self.task_state.finished = True
|
|
|
self.task_state.success = False
|
|
self.task_state.success = False
|
|
|
self.task_state.terminate_reason = 'max_iterations'
|
|
self.task_state.terminate_reason = 'max_iterations'
|