# agent_controller.py (6.5 KB)
  1. import asyncio
  2. import inspect
  3. import traceback
  4. from typing import List, Callable, Literal, Mapping, Awaitable, Any, cast
  5. from termcolor import colored
  6. from opendevin import config
  7. from opendevin.action import (
  8. Action,
  9. NullAction,
  10. AgentFinishAction,
  11. AddTaskAction,
  12. ModifyTaskAction,
  13. )
  14. from opendevin.agent import Agent
  15. from opendevin.logger import opendevin_logger as logger
  16. from opendevin.observation import Observation, AgentErrorObservation, NullObservation
  17. from opendevin.plan import Plan
  18. from opendevin.state import State
  19. from .command_manager import CommandManager
  20. ColorType = Literal[
  21. "red",
  22. "green",
  23. "yellow",
  24. "blue",
  25. "magenta",
  26. "cyan",
  27. "light_grey",
  28. "dark_grey",
  29. "light_red",
  30. "light_green",
  31. "light_yellow",
  32. "light_blue",
  33. "light_magenta",
  34. "light_cyan",
  35. "white",
  36. ]
  37. DISABLE_COLOR_PRINTING = (
  38. config.get('DISABLE_COLOR').lower() == 'true'
  39. )
  40. MAX_ITERATIONS = config.get("MAX_ITERATIONS")
  41. def print_with_color(text: Any, print_type: str = "INFO"):
  42. TYPE_TO_COLOR: Mapping[str, ColorType] = {
  43. "BACKGROUND LOG": "blue",
  44. "ACTION": "green",
  45. "OBSERVATION": "yellow",
  46. "INFO": "cyan",
  47. "ERROR": "red",
  48. "PLAN": "light_magenta",
  49. }
  50. color = TYPE_TO_COLOR.get(print_type.upper(), TYPE_TO_COLOR["INFO"])
  51. if DISABLE_COLOR_PRINTING:
  52. print(f'\n{print_type.upper()}:\n{str(text)}', flush=True)
  53. else:
  54. print(
  55. colored(f'\n{print_type.upper()}:\n', color, attrs=['bold'])
  56. + colored(str(text), color),
  57. flush=True,
  58. )
  59. class AgentController:
  60. id: str
  61. agent: Agent
  62. max_iterations: int
  63. workdir: str
  64. command_manager: CommandManager
  65. callbacks: List[Callable]
  66. def __init__(
  67. self,
  68. agent: Agent,
  69. workdir: str,
  70. sid: str = "",
  71. max_iterations: int = MAX_ITERATIONS,
  72. container_image: str | None = None,
  73. callbacks: List[Callable] = [],
  74. ):
  75. self.id = sid
  76. self.agent = agent
  77. self.max_iterations = max_iterations
  78. self.workdir = workdir
  79. self.command_manager = CommandManager(self.id, workdir, container_image)
  80. self.callbacks = callbacks
  81. def update_state_for_step(self, i):
  82. self.state.iteration = i
  83. self.state.background_commands_obs = self.command_manager.get_background_obs()
  84. def update_state_after_step(self):
  85. self.state.updated_info = []
  86. def add_history(self, action: Action, observation: Observation):
  87. if not isinstance(action, Action):
  88. raise ValueError("action must be an instance of Action")
  89. if not isinstance(observation, Observation):
  90. raise ValueError("observation must be an instance of Observation")
  91. self.state.history.append((action, observation))
  92. self.state.updated_info.append((action, observation))
  93. async def start_loop(self, task: str):
  94. finished = False
  95. plan = Plan(task)
  96. self.state = State(plan)
  97. for i in range(self.max_iterations):
  98. try:
  99. finished = await self.step(i)
  100. except Exception as e:
  101. logger.error("Error in loop", exc_info=True)
  102. raise e
  103. if finished:
  104. break
  105. if not finished:
  106. logger.info("Exited before finishing the task.")
  107. async def step(self, i: int):
  108. print("\n\n==============", flush=True)
  109. print("STEP", i, flush=True)
  110. print_with_color(self.state.plan.main_goal, "PLAN")
  111. log_obs = self.command_manager.get_background_obs()
  112. for obs in log_obs:
  113. self.add_history(NullAction(), obs)
  114. await self._run_callbacks(obs)
  115. print_with_color(obs, "BACKGROUND LOG")
  116. self.update_state_for_step(i)
  117. action: Action = NullAction()
  118. observation: Observation = NullObservation("")
  119. try:
  120. action = self.agent.step(self.state)
  121. if action is None:
  122. raise ValueError("Agent must return an action")
  123. print_with_color(action, "ACTION")
  124. except Exception as e:
  125. observation = AgentErrorObservation(str(e))
  126. print_with_color(observation, "ERROR")
  127. traceback.print_exc()
  128. # TODO Change to more robust error handling
  129. if (
  130. "The api_key client option must be set" in observation.content
  131. or "Incorrect API key provided:" in observation.content
  132. ):
  133. raise
  134. self.update_state_after_step()
  135. await self._run_callbacks(action)
  136. finished = isinstance(action, AgentFinishAction)
  137. if finished:
  138. print_with_color(action, "INFO")
  139. return True
  140. if isinstance(action, AddTaskAction):
  141. try:
  142. self.state.plan.add_subtask(action.parent, action.goal, action.subtasks)
  143. except Exception as e:
  144. observation = AgentErrorObservation(str(e))
  145. print_with_color(observation, "ERROR")
  146. traceback.print_exc()
  147. elif isinstance(action, ModifyTaskAction):
  148. try:
  149. self.state.plan.set_subtask_state(action.id, action.state)
  150. except Exception as e:
  151. observation = AgentErrorObservation(str(e))
  152. print_with_color(observation, "ERROR")
  153. traceback.print_exc()
  154. if action.executable:
  155. try:
  156. if inspect.isawaitable(action.run(self)):
  157. observation = await cast(Awaitable[Observation], action.run(self))
  158. else:
  159. observation = action.run(self)
  160. except Exception as e:
  161. observation = AgentErrorObservation(str(e))
  162. print_with_color(observation, "ERROR")
  163. traceback.print_exc()
  164. if not isinstance(observation, NullObservation):
  165. print_with_color(observation, "OBSERVATION")
  166. self.add_history(action, observation)
  167. await self._run_callbacks(observation)
  168. async def _run_callbacks(self, event):
  169. if event is None:
  170. return
  171. for callback in self.callbacks:
  172. idx = self.callbacks.index(callback)
  173. try:
  174. callback(event)
  175. except Exception as e:
  176. logger.exception(f"Callback error: {e}, idx: {idx}")
  177. await asyncio.sleep(
  178. 0.001
  179. ) # Give back control for a tick, so we can await in callbacks