agent_controller.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200
  1. import asyncio
  2. import inspect
  3. import traceback
  4. from typing import List, Callable, Literal, Mapping, Awaitable, Any, cast
  5. from termcolor import colored
  6. from opendevin.plan import Plan
  7. from opendevin.state import State
  8. from opendevin.agent import Agent
  9. from opendevin.action import (
  10. Action,
  11. NullAction,
  12. AgentFinishAction,
  13. AddTaskAction,
  14. ModifyTaskAction,
  15. )
  16. from opendevin.observation import Observation, AgentErrorObservation, NullObservation
  17. from opendevin import config
  18. from opendevin.logging import opendevin_logger as logger
  19. from .command_manager import CommandManager
  20. ColorType = Literal[
  21. 'red',
  22. 'green',
  23. 'yellow',
  24. 'blue',
  25. 'magenta',
  26. 'cyan',
  27. 'light_grey',
  28. 'dark_grey',
  29. 'light_red',
  30. 'light_green',
  31. 'light_yellow',
  32. 'light_blue',
  33. 'light_magenta',
  34. 'light_cyan',
  35. 'white',
  36. ]
  37. DISABLE_COLOR_PRINTING = (
  38. config.get_or_default('DISABLE_COLOR', 'false').lower() == 'true'
  39. )
  40. MAX_ITERATIONS = config.get('MAX_ITERATIONS')
  41. def print_with_color(text: Any, print_type: str = 'INFO'):
  42. TYPE_TO_COLOR: Mapping[str, ColorType] = {
  43. 'BACKGROUND LOG': 'blue',
  44. 'ACTION': 'green',
  45. 'OBSERVATION': 'yellow',
  46. 'INFO': 'cyan',
  47. 'ERROR': 'red',
  48. 'PLAN': 'light_magenta',
  49. }
  50. color = TYPE_TO_COLOR.get(print_type.upper(), TYPE_TO_COLOR['INFO'])
  51. if DISABLE_COLOR_PRINTING:
  52. print(f"\n{print_type.upper()}:\n{str(text)}", flush=True)
  53. else:
  54. print(
  55. colored(f"\n{print_type.upper()}:\n", color, attrs=['bold'])
  56. + colored(str(text), color),
  57. flush=True,
  58. )
  59. class AgentController:
  60. id: str
  61. def __init__(
  62. self,
  63. agent: Agent,
  64. workdir: str,
  65. id: str = '',
  66. max_iterations: int = MAX_ITERATIONS,
  67. container_image: str | None = None,
  68. callbacks: List[Callable] = [],
  69. ):
  70. self.id = id
  71. self.agent = agent
  72. self.max_iterations = max_iterations
  73. self.workdir = workdir
  74. self.command_manager = CommandManager(
  75. self.id, workdir, container_image)
  76. self.callbacks = callbacks
  77. def update_state_for_step(self, i):
  78. self.state.iteration = i
  79. self.state.background_commands_obs = self.command_manager.get_background_obs()
  80. def update_state_after_step(self):
  81. self.state.updated_info = []
  82. def add_history(self, action: Action, observation: Observation):
  83. if not isinstance(action, Action):
  84. raise ValueError('action must be an instance of Action')
  85. if not isinstance(observation, Observation):
  86. raise ValueError('observation must be an instance of Observation')
  87. self.state.history.append((action, observation))
  88. self.state.updated_info.append((action, observation))
  89. async def start_loop(self, task: str):
  90. finished = False
  91. plan = Plan(task)
  92. self.state = State(plan)
  93. for i in range(self.max_iterations):
  94. try:
  95. finished = await self.step(i)
  96. except Exception as e:
  97. logger.error('Error in loop', exc_info=True)
  98. raise e
  99. if finished:
  100. break
  101. if not finished:
  102. logger.info('Exited before finishing the task.')
  103. async def step(self, i: int):
  104. print('\n\n==============', flush=True)
  105. print('STEP', i, flush=True)
  106. print_with_color(self.state.plan.main_goal, 'PLAN')
  107. log_obs = self.command_manager.get_background_obs()
  108. for obs in log_obs:
  109. self.add_history(NullAction(), obs)
  110. await self._run_callbacks(obs)
  111. print_with_color(obs, 'BACKGROUND LOG')
  112. self.update_state_for_step(i)
  113. action: Action = NullAction()
  114. observation: Observation = NullObservation('')
  115. try:
  116. action = self.agent.step(self.state)
  117. if action is None:
  118. raise ValueError('Agent must return an action')
  119. print_with_color(action, 'ACTION')
  120. except Exception as e:
  121. observation = AgentErrorObservation(str(e))
  122. print_with_color(observation, 'ERROR')
  123. traceback.print_exc()
  124. # TODO Change to more robust error handling
  125. if 'The api_key client option must be set' or 'Incorrect API key provided:' in observation.content:
  126. raise
  127. self.update_state_after_step()
  128. await self._run_callbacks(action)
  129. finished = isinstance(action, AgentFinishAction)
  130. if finished:
  131. print_with_color(action, 'INFO')
  132. return True
  133. if isinstance(action, AddTaskAction):
  134. try:
  135. self.state.plan.add_subtask(
  136. action.parent, action.goal, action.subtasks)
  137. except Exception as e:
  138. observation = AgentErrorObservation(str(e))
  139. print_with_color(observation, 'ERROR')
  140. traceback.print_exc()
  141. elif isinstance(action, ModifyTaskAction):
  142. try:
  143. self.state.plan.set_subtask_state(action.id, action.state)
  144. except Exception as e:
  145. observation = AgentErrorObservation(str(e))
  146. print_with_color(observation, 'ERROR')
  147. traceback.print_exc()
  148. if action.executable:
  149. try:
  150. if inspect.isawaitable(action.run(self)):
  151. observation = await cast(Awaitable[Observation], action.run(self))
  152. else:
  153. observation = action.run(self)
  154. except Exception as e:
  155. observation = AgentErrorObservation(str(e))
  156. print_with_color(observation, 'ERROR')
  157. traceback.print_exc()
  158. if not isinstance(observation, NullObservation):
  159. print_with_color(observation, 'OBSERVATION')
  160. self.add_history(action, observation)
  161. await self._run_callbacks(observation)
  162. async def _run_callbacks(self, event):
  163. if event is None:
  164. return
  165. for callback in self.callbacks:
  166. idx = self.callbacks.index(callback)
  167. try:
  168. callback(event)
  169. except Exception:
  170. logger.exception('Callback error: %s', idx)
  171. pass
  172. await asyncio.sleep(
  173. 0.001
  174. ) # Give back control for a tick, so we can await in callbacks