env.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
  1. import re
  2. import traceback
  3. from datatypes import ParseError, StepOutput, TaskState
  4. from tasks.base import Task
  5. from openhands.controller.state.state import State
  6. class SimplifiedEnv:
  7. INVALID_INPUT_MESSAGE = (
  8. "I don't understand your input. \n"
  9. 'If you want to execute code, please use <execute_ipython> YOUR_CODE_HERE </execute_ipython>.\n'
  10. 'If you want to give me an answer, please use <solution> YOUR_SOLUTION_HERE </solution>.\n'
  11. 'For example: The answer to the question is <solution> 42 </solution>. \n'
  12. )
  13. def __init__(self, agent_state: State, task: Task, task_config: dict[str, int]):
  14. self.agent_state = agent_state
  15. self.task = task
  16. agent_action_count = {
  17. 'propose_solution': 0,
  18. 'use_tool': 0,
  19. 'invalid_action': 0,
  20. }
  21. # check if agent_state has attribute turn_info set
  22. if hasattr(self.agent_state, 'propose_solution_count'):
  23. agent_action_count['propose_solution'] = (
  24. self.agent_state.propose_solution_count
  25. )
  26. self.task_state = TaskState(agent_action_count=agent_action_count)
  27. self.task_config = task_config
  28. def step(self, lm_message: str):
  29. observation = self.handle_propose_solution(lm_message)
  30. self.check_max_iteration()
  31. turn_info = (
  32. self.task_config['max_iterations'] - self.agent_state.iteration,
  33. self.task_config['max_propose_solution']
  34. - self.task_state.agent_action_count['propose_solution'],
  35. )
  36. output = StepOutput(
  37. observation=observation,
  38. success=self.task_state.success,
  39. turn_info=turn_info,
  40. )
  41. self.agent_state.propose_solution_count = self.task_state.agent_action_count[
  42. 'propose_solution'
  43. ]
  44. self.log_output(output)
  45. return self.task_state
  46. def handle_propose_solution(self, lm_message) -> str | None:
  47. """Propose answer to check the task success.
  48. It might set self.state.finished = True if the task is successful.
  49. """
  50. self.task_state.agent_action_count['propose_solution'] += 1
  51. try:
  52. parsed = self.parse_propose_solution(lm_message)
  53. task_success = self.check_task_success(parsed['answer'])
  54. if task_success:
  55. self.task_state.finished = True
  56. self.task_state.success = True
  57. self.task_state.terminate_reason = 'task_success'
  58. # NOTE: should not return the function now, because we need to log the output
  59. # Set state.finished = True will terminate the episode
  60. except ParseError:
  61. return SimplifiedEnv.INVALID_INPUT_MESSAGE
  62. except Exception:
  63. error_traceback = traceback.format_exc()
  64. return f'{error_traceback}'
  65. def parse_propose_solution(self, lm_message: str) -> dict:
  66. """Define the parsing logic."""
  67. lm_output = '\n' + lm_message + '\n'
  68. answer = '\n'.join(
  69. [
  70. i.strip()
  71. for i in re.findall(r'<solution>(.*?)</solution>', lm_output, re.DOTALL)
  72. ]
  73. )
  74. if answer == '':
  75. raise ParseError('No answer found.')
  76. return {'answer': answer}
  77. def log_output(self, output: StepOutput) -> None:
  78. if self.task_state.finished:
  79. return
  80. content = output.to_str()
  81. self.task_state.latest_output = output.to_dict()
  82. self.task_state.latest_output['content'] = content
  83. def check_task_success(self, answer: str) -> bool:
  84. # log_message.info(f"STUDENT ANSWER: [{answer}]")
  85. # log_message.info(f"REFERENCE ANSWER: [{self.task.reference}]")
  86. return self.task.success(answer)
  87. def check_max_iteration(self):
  88. """Check if the agent has reached the max iteration limit.
  89. It might set self.state.finished = True if the agent has reached the max iteration limit.
  90. """
  91. if self.task_state.finished:
  92. # ignore if the episode is already finished (e.g., task success)
  93. return
  94. if (
  95. # propose solution > max output solution
  96. self.task_state.agent_action_count['propose_solution']
  97. >= self.task_config['max_propose_solution']
  98. ):
  99. self.task_state.finished = True
  100. self.task_state.success = False
  101. self.task_state.terminate_reason = 'max_propose_steps'
  102. elif self.agent_state.iteration >= self.task_config['max_iterations']:
  103. self.task_state.finished = True
  104. self.task_state.success = False
  105. self.task_state.terminate_reason = 'max_iterations'