agent_controller.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169
  1. import asyncio
  2. import traceback
  3. from typing import List, Callable, Literal, Mapping, Any
  4. from termcolor import colored
  5. from opendevin.plan import Plan
  6. from opendevin.state import State
  7. from opendevin.agent import Agent
  8. from opendevin.action import (
  9. Action,
  10. NullAction,
  11. AgentFinishAction,
  12. AddTaskAction,
  13. ModifyTaskAction
  14. )
  15. from opendevin.observation import (
  16. Observation,
  17. AgentErrorObservation,
  18. NullObservation
  19. )
  20. from opendevin import config
  21. from .command_manager import CommandManager
  22. ColorType = Literal['red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'light_grey', 'dark_grey', 'light_red', 'light_green', 'light_yellow', 'light_blue', 'light_magenta', 'light_cyan', 'white']
  23. DISABLE_COLOR_PRINTING = config.get_or_default("DISABLE_COLOR", "false").lower() == "true"
  24. def print_with_color(text: Any, print_type: str = "INFO"):
  25. TYPE_TO_COLOR: Mapping[str, ColorType] = {
  26. "BACKGROUND LOG": "blue",
  27. "ACTION": "green",
  28. "OBSERVATION": "yellow",
  29. "INFO": "cyan",
  30. "ERROR": "red",
  31. "PLAN": "light_magenta",
  32. }
  33. color = TYPE_TO_COLOR.get(print_type.upper(), TYPE_TO_COLOR["INFO"])
  34. if DISABLE_COLOR_PRINTING:
  35. print(f"\n{print_type.upper()}:\n{str(text)}", flush=True)
  36. else:
  37. print(
  38. colored(f"\n{print_type.upper()}:\n", color, attrs=["bold"])
  39. + colored(str(text), color),
  40. flush=True,
  41. )
  42. class AgentController:
  43. def __init__(
  44. self,
  45. agent: Agent,
  46. workdir: str,
  47. max_iterations: int = 100,
  48. container_image: str | None = None,
  49. callbacks: List[Callable] = [],
  50. ):
  51. self.agent = agent
  52. self.max_iterations = max_iterations
  53. self.workdir = workdir
  54. self.command_manager = CommandManager(workdir,container_image)
  55. self.callbacks = callbacks
  56. def update_state_for_step(self, i):
  57. self.state.iteration = i
  58. self.state.background_commands_obs = self.command_manager.get_background_obs()
  59. def update_state_after_step(self):
  60. self.state.updated_info = []
  61. def add_history(self, action: Action, observation: Observation):
  62. if not isinstance(action, Action):
  63. raise ValueError("action must be an instance of Action")
  64. if not isinstance(observation, Observation):
  65. raise ValueError("observation must be an instance of Observation")
  66. self.state.history.append((action, observation))
  67. self.state.updated_info.append((action, observation))
  68. async def start_loop(self, task: str):
  69. finished = False
  70. plan = Plan(task)
  71. self.state = State(plan)
  72. for i in range(self.max_iterations):
  73. try:
  74. finished = await self.step(i)
  75. except Exception as e:
  76. print("Error in loop", e, flush=True)
  77. raise e
  78. if finished:
  79. break
  80. if not finished:
  81. print("Exited before finishing", flush=True)
  82. async def step(self, i: int):
  83. print("\n\n==============", flush=True)
  84. print("STEP", i, flush=True)
  85. print_with_color(self.state.plan.main_goal, "PLAN")
  86. log_obs = self.command_manager.get_background_obs()
  87. for obs in log_obs:
  88. self.add_history(NullAction(), obs)
  89. await self._run_callbacks(obs)
  90. print_with_color(obs, "BACKGROUND LOG")
  91. self.update_state_for_step(i)
  92. action: Action = NullAction()
  93. observation: Observation = NullObservation("")
  94. try:
  95. action = self.agent.step(self.state)
  96. if action is None:
  97. raise ValueError("Agent must return an action")
  98. print_with_color(action, "ACTION")
  99. except Exception as e:
  100. observation = AgentErrorObservation(str(e))
  101. print_with_color(observation, "ERROR")
  102. traceback.print_exc()
  103. # TODO Change to more robust error handling
  104. if "The api_key client option must be set" in observation.content:
  105. raise
  106. self.update_state_after_step()
  107. await self._run_callbacks(action)
  108. finished = isinstance(action, AgentFinishAction)
  109. if finished:
  110. print_with_color(action, "INFO")
  111. return True
  112. if isinstance(action, AddTaskAction):
  113. try:
  114. self.state.plan.add_subtask(action.parent, action.goal, action.subtasks)
  115. except Exception as e:
  116. observation = AgentErrorObservation(str(e))
  117. print_with_color(observation, "ERROR")
  118. traceback.print_exc()
  119. elif isinstance(action, ModifyTaskAction):
  120. try:
  121. self.state.plan.set_subtask_state(action.id, action.state)
  122. except Exception as e:
  123. observation = AgentErrorObservation(str(e))
  124. print_with_color(observation, "ERROR")
  125. traceback.print_exc()
  126. if action.executable:
  127. try:
  128. observation = action.run(self)
  129. except Exception as e:
  130. observation = AgentErrorObservation(str(e))
  131. print_with_color(observation, "ERROR")
  132. traceback.print_exc()
  133. if not isinstance(observation, NullObservation):
  134. print_with_color(observation, "OBSERVATION")
  135. self.add_history(action, observation)
  136. await self._run_callbacks(observation)
  137. async def _run_callbacks(self, event):
  138. if event is None:
  139. return
  140. for callback in self.callbacks:
  141. idx = self.callbacks.index(callback)
  142. try:
  143. callback(event)
  144. except Exception as e:
  145. print("Callback error:" + str(idx), e, flush=True)
  146. pass
  147. await asyncio.sleep(0.001) # Give back control for a tick, so we can await in callbacks