agent_controller.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175
  1. import asyncio
  2. import inspect
  3. import traceback
  4. from typing import List, Callable, Literal, Mapping, Awaitable, Any, cast
  5. from termcolor import colored
  6. from opendevin.plan import Plan
  7. from opendevin.state import State
  8. from opendevin.agent import Agent
  9. from opendevin.action import (
  10. Action,
  11. NullAction,
  12. AgentFinishAction,
  13. AddTaskAction,
  14. ModifyTaskAction
  15. )
  16. from opendevin.observation import (
  17. Observation,
  18. AgentErrorObservation,
  19. NullObservation
  20. )
  21. from opendevin import config
  22. from .command_manager import CommandManager
  23. ColorType = Literal['red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'light_grey', 'dark_grey', 'light_red', 'light_green', 'light_yellow', 'light_blue', 'light_magenta', 'light_cyan', 'white']
  24. DISABLE_COLOR_PRINTING = config.get_or_default("DISABLE_COLOR", "false").lower() == "true"
  25. MAX_ITERATIONS = config.get("MAX_ITERATIONS")
  26. def print_with_color(text: Any, print_type: str = "INFO"):
  27. TYPE_TO_COLOR: Mapping[str, ColorType] = {
  28. "BACKGROUND LOG": "blue",
  29. "ACTION": "green",
  30. "OBSERVATION": "yellow",
  31. "INFO": "cyan",
  32. "ERROR": "red",
  33. "PLAN": "light_magenta",
  34. }
  35. color = TYPE_TO_COLOR.get(print_type.upper(), TYPE_TO_COLOR["INFO"])
  36. if DISABLE_COLOR_PRINTING:
  37. print(f"\n{print_type.upper()}:\n{str(text)}", flush=True)
  38. else:
  39. print(
  40. colored(f"\n{print_type.upper()}:\n", color, attrs=["bold"])
  41. + colored(str(text), color),
  42. flush=True,
  43. )
  44. class AgentController:
  45. def __init__(
  46. self,
  47. agent: Agent,
  48. workdir: str,
  49. max_iterations: int = MAX_ITERATIONS,
  50. container_image: str | None = None,
  51. callbacks: List[Callable] = [],
  52. ):
  53. self.agent = agent
  54. self.max_iterations = max_iterations
  55. self.workdir = workdir
  56. self.command_manager = CommandManager(workdir,container_image)
  57. self.callbacks = callbacks
  58. def update_state_for_step(self, i):
  59. self.state.iteration = i
  60. self.state.background_commands_obs = self.command_manager.get_background_obs()
  61. def update_state_after_step(self):
  62. self.state.updated_info = []
  63. def add_history(self, action: Action, observation: Observation):
  64. if not isinstance(action, Action):
  65. raise ValueError("action must be an instance of Action")
  66. if not isinstance(observation, Observation):
  67. raise ValueError("observation must be an instance of Observation")
  68. self.state.history.append((action, observation))
  69. self.state.updated_info.append((action, observation))
  70. async def start_loop(self, task: str):
  71. finished = False
  72. plan = Plan(task)
  73. self.state = State(plan)
  74. for i in range(self.max_iterations):
  75. try:
  76. finished = await self.step(i)
  77. except Exception as e:
  78. print("Error in loop", e, flush=True)
  79. raise e
  80. if finished:
  81. break
  82. if not finished:
  83. print("Exited before finishing", flush=True)
  84. async def step(self, i: int):
  85. print("\n\n==============", flush=True)
  86. print("STEP", i, flush=True)
  87. print_with_color(self.state.plan.main_goal, "PLAN")
  88. log_obs = self.command_manager.get_background_obs()
  89. for obs in log_obs:
  90. self.add_history(NullAction(), obs)
  91. await self._run_callbacks(obs)
  92. print_with_color(obs, "BACKGROUND LOG")
  93. self.update_state_for_step(i)
  94. action: Action = NullAction()
  95. observation: Observation = NullObservation("")
  96. try:
  97. action = self.agent.step(self.state)
  98. if action is None:
  99. raise ValueError("Agent must return an action")
  100. print_with_color(action, "ACTION")
  101. except Exception as e:
  102. observation = AgentErrorObservation(str(e))
  103. print_with_color(observation, "ERROR")
  104. traceback.print_exc()
  105. # TODO Change to more robust error handling
  106. if "The api_key client option must be set" in observation.content:
  107. raise
  108. self.update_state_after_step()
  109. await self._run_callbacks(action)
  110. finished = isinstance(action, AgentFinishAction)
  111. if finished:
  112. print_with_color(action, "INFO")
  113. return True
  114. if isinstance(action, AddTaskAction):
  115. try:
  116. self.state.plan.add_subtask(action.parent, action.goal, action.subtasks)
  117. except Exception as e:
  118. observation = AgentErrorObservation(str(e))
  119. print_with_color(observation, "ERROR")
  120. traceback.print_exc()
  121. elif isinstance(action, ModifyTaskAction):
  122. try:
  123. self.state.plan.set_subtask_state(action.id, action.state)
  124. except Exception as e:
  125. observation = AgentErrorObservation(str(e))
  126. print_with_color(observation, "ERROR")
  127. traceback.print_exc()
  128. if action.executable:
  129. try:
  130. if inspect.isawaitable(action.run(self)):
  131. observation = await cast(Awaitable[Observation], action.run(self))
  132. else:
  133. observation = action.run(self)
  134. except Exception as e:
  135. observation = AgentErrorObservation(str(e))
  136. print_with_color(observation, "ERROR")
  137. traceback.print_exc()
  138. if not isinstance(observation, NullObservation):
  139. print_with_color(observation, "OBSERVATION")
  140. self.add_history(action, observation)
  141. await self._run_callbacks(observation)
  142. async def _run_callbacks(self, event):
  143. if event is None:
  144. return
  145. for callback in self.callbacks:
  146. idx = self.callbacks.index(callback)
  147. try:
  148. callback(event)
  149. except Exception as e:
  150. print("Callback error:" + str(idx), e, flush=True)
  151. pass
  152. await asyncio.sleep(0.001) # Give back control for a tick, so we can await in callbacks