env.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131
  1. import re
  2. import traceback
  3. from typing import Dict, Optional
  4. from datatypes import ParseError, StepOutput, TaskState
  5. from tasks.base import Task
  6. from opendevin.controller.state.state import State
  7. class SimplifiedEnv:
  8. INVALID_INPUT_MESSAGE = (
  9. "I don't understand your input. \n"
  10. 'If you want to execute code, please use <execute_ipython> YOUR_CODE_HERE </execute_ipython>.\n'
  11. 'If you want to give me an answer, please use <solution> YOUR_SOLUTION_HERE </solution>.\n'
  12. 'For example: The answer to the question is <solution> 42 </solution>. \n'
  13. )
  14. def __init__(self, agent_state: State, task: Task, task_config: Dict[str, int]):
  15. self.agent_state = agent_state
  16. self.task = task
  17. agent_action_count = {
  18. 'propose_solution': 0,
  19. 'use_tool': 0,
  20. 'invalid_action': 0,
  21. }
  22. # check if agent_state has attribute turn_info set
  23. if hasattr(self.agent_state, 'propose_solution_count'):
  24. agent_action_count['propose_solution'] = (
  25. self.agent_state.propose_solution_count
  26. )
  27. self.task_state = TaskState(agent_action_count=agent_action_count)
  28. self.task_config = task_config
  29. def step(self, lm_message: str):
  30. observation = self.handle_propose_solution(lm_message)
  31. self.check_max_iteration()
  32. turn_info = (
  33. self.task_config['max_iterations'] - self.agent_state.iteration,
  34. self.task_config['max_propose_solution']
  35. - self.task_state.agent_action_count['propose_solution'],
  36. )
  37. output = StepOutput(
  38. observation=observation,
  39. success=self.task_state.success,
  40. turn_info=turn_info,
  41. )
  42. self.agent_state.propose_solution_count = self.task_state.agent_action_count[
  43. 'propose_solution'
  44. ]
  45. self.log_output(output)
  46. return self.task_state
  47. def handle_propose_solution(self, lm_message) -> Optional[str]:
  48. """Propose answer to check the task success.
  49. It might set self.state.finished = True if the task is successful.
  50. """
  51. self.task_state.agent_action_count['propose_solution'] += 1
  52. try:
  53. parsed = self.parse_propose_solution(lm_message)
  54. task_success = self.check_task_success(parsed['answer'])
  55. if task_success:
  56. self.task_state.finished = True
  57. self.task_state.success = True
  58. self.task_state.terminate_reason = 'task_success'
  59. # NOTE: should not return the function now, because we need to log the output
  60. # Set state.finished = True will terminate the episode
  61. except ParseError:
  62. return SimplifiedEnv.INVALID_INPUT_MESSAGE
  63. except Exception:
  64. error_traceback = traceback.format_exc()
  65. return f'{error_traceback}'
  66. def parse_propose_solution(self, lm_message: str) -> dict:
  67. """Define the parsing logic."""
  68. lm_output = '\n' + lm_message + '\n'
  69. answer = '\n'.join(
  70. [
  71. i.strip()
  72. for i in re.findall(r'<solution>(.*?)</solution>', lm_output, re.DOTALL)
  73. ]
  74. )
  75. if answer == '':
  76. raise ParseError('No answer found.')
  77. return {'answer': answer}
  78. def log_output(self, output: StepOutput) -> None:
  79. if self.task_state.finished:
  80. return
  81. content = output.to_str()
  82. # self.state.history.append({"role": "user", "content": content})
  83. self.task_state.latest_output = output.to_dict()
  84. self.task_state.latest_output['content'] = content
  85. def check_task_success(self, answer: str) -> bool:
  86. # log_message.info(f"STUDENT ANSWER: [{answer}]")
  87. # log_message.info(f"REFERENCE ANSWER: [{self.task.reference}]")
  88. return self.task.success(answer)
  89. def check_max_iteration(self):
  90. """Check if the agent has reached the max iteration limit.
  91. It might set self.state.finished = True if the agent has reached the max iteration limit.
  92. """
  93. if self.task_state.finished:
  94. # ignore if the episode is already finished (e.g., task success)
  95. return
  96. if (
  97. # propose solution > max output solution
  98. self.task_state.agent_action_count['propose_solution']
  99. >= self.task_config['max_propose_solution']
  100. ):
  101. self.task_state.finished = True
  102. self.task_state.success = False
  103. self.task_state.terminate_reason = 'max_propose_steps'
  104. elif self.agent_state.iteration >= self.task_config['max_iterations']:
  105. self.task_state.finished = True
  106. self.task_state.success = False
  107. self.task_state.terminate_reason = 'max_iterations'