# agent_controller.py (7.3 KB)
  1. import asyncio
  2. import inspect
  3. import traceback
  4. import time
  5. from typing import List, Callable, Literal, Mapping, Awaitable, Any, cast
  6. from termcolor import colored
  7. from litellm.exceptions import APIConnectionError
  8. from openai import AuthenticationError
  9. from opendevin import config
  10. from opendevin.action import (
  11. Action,
  12. NullAction,
  13. AgentFinishAction,
  14. AddTaskAction,
  15. ModifyTaskAction,
  16. )
  17. from opendevin.agent import Agent
  18. from opendevin.logger import opendevin_logger as logger
  19. from opendevin.exceptions import MaxCharsExceedError, AgentNoActionError
  20. from opendevin.observation import Observation, AgentErrorObservation, NullObservation
  21. from opendevin.plan import Plan
  22. from opendevin.state import State
  23. from .command_manager import CommandManager
  24. ColorType = Literal[
  25. 'red',
  26. 'green',
  27. 'yellow',
  28. 'blue',
  29. 'magenta',
  30. 'cyan',
  31. 'light_grey',
  32. 'dark_grey',
  33. 'light_red',
  34. 'light_green',
  35. 'light_yellow',
  36. 'light_blue',
  37. 'light_magenta',
  38. 'light_cyan',
  39. 'white',
  40. ]
  41. DISABLE_COLOR_PRINTING = (
  42. config.get('DISABLE_COLOR').lower() == 'true'
  43. )
  44. MAX_ITERATIONS = config.get('MAX_ITERATIONS')
  45. MAX_CHARS = config.get('MAX_CHARS')
  46. def print_with_color(text: Any, print_type: str = 'INFO'):
  47. TYPE_TO_COLOR: Mapping[str, ColorType] = {
  48. 'BACKGROUND LOG': 'blue',
  49. 'ACTION': 'green',
  50. 'OBSERVATION': 'yellow',
  51. 'INFO': 'cyan',
  52. 'ERROR': 'red',
  53. 'PLAN': 'light_magenta',
  54. }
  55. color = TYPE_TO_COLOR.get(print_type.upper(), TYPE_TO_COLOR['INFO'])
  56. if DISABLE_COLOR_PRINTING:
  57. print(f'\n{print_type.upper()}:\n{str(text)}', flush=True)
  58. else:
  59. print(
  60. colored(f'\n{print_type.upper()}:\n', color, attrs=['bold'])
  61. + colored(str(text), color),
  62. flush=True,
  63. )
  64. class AgentController:
  65. id: str
  66. agent: Agent
  67. max_iterations: int
  68. command_manager: CommandManager
  69. callbacks: List[Callable]
  70. def __init__(
  71. self,
  72. agent: Agent,
  73. sid: str = '',
  74. max_iterations: int = MAX_ITERATIONS,
  75. max_chars: int = MAX_CHARS,
  76. container_image: str | None = None,
  77. callbacks: List[Callable] = [],
  78. ):
  79. self.id = sid
  80. self.agent = agent
  81. self.max_iterations = max_iterations
  82. self.command_manager = CommandManager(self.id, container_image)
  83. self.max_chars = max_chars
  84. self.callbacks = callbacks
  85. def update_state_for_step(self, i):
  86. self.state.iteration = i
  87. self.state.background_commands_obs = self.command_manager.get_background_obs()
  88. def update_state_after_step(self):
  89. self.state.updated_info = []
  90. def add_history(self, action: Action, observation: Observation):
  91. if not isinstance(action, Action):
  92. raise TypeError(
  93. f'action must be an instance of Action, got {type(action).__name__} instead')
  94. if not isinstance(observation, Observation):
  95. raise TypeError(
  96. f'observation must be an instance of Observation, got {type(observation).__name__} instead')
  97. self.state.history.append((action, observation))
  98. self.state.updated_info.append((action, observation))
  99. async def start_loop(self, task: str):
  100. finished = False
  101. plan = Plan(task)
  102. self.state = State(plan)
  103. for i in range(self.max_iterations):
  104. try:
  105. finished = await self.step(i)
  106. except Exception as e:
  107. logger.error('Error in loop', exc_info=True)
  108. raise e
  109. if finished:
  110. break
  111. if not finished:
  112. logger.info('Exited before finishing the task.')
  113. async def step(self, i: int):
  114. print('\n\n==============', flush=True)
  115. print('STEP', i, flush=True)
  116. print_with_color(self.state.plan.main_goal, 'PLAN')
  117. if self.state.num_of_chars > self.max_chars:
  118. raise MaxCharsExceedError(
  119. self.state.num_of_chars, self.max_chars)
  120. log_obs = self.command_manager.get_background_obs()
  121. for obs in log_obs:
  122. self.add_history(NullAction(), obs)
  123. await self._run_callbacks(obs)
  124. print_with_color(obs, 'BACKGROUND LOG')
  125. self.update_state_for_step(i)
  126. action: Action = NullAction()
  127. observation: Observation = NullObservation('')
  128. try:
  129. action = self.agent.step(self.state)
  130. if action is None:
  131. raise AgentNoActionError()
  132. print_with_color(action, 'ACTION')
  133. except Exception as e:
  134. observation = AgentErrorObservation(str(e))
  135. print_with_color(observation, 'ERROR')
  136. traceback.print_exc()
  137. if isinstance(e, APIConnectionError):
  138. time.sleep(3)
  139. # raise specific exceptions that need to be handled outside
  140. # note: we are using AuthenticationError class from openai rather than
  141. # litellm because:
  142. # 1) litellm.exceptions.AuthenticationError is a subclass of openai.AuthenticationError
  143. # 2) embeddings call, initiated by llama-index, has no wrapper for authentication
  144. # errors. This means we have to catch individual authentication errors
  145. # from different providers, and OpenAI is one of these.
  146. if isinstance(e, (AuthenticationError, AgentNoActionError)):
  147. raise
  148. self.update_state_after_step()
  149. await self._run_callbacks(action)
  150. finished = isinstance(action, AgentFinishAction)
  151. if finished:
  152. print_with_color(action, 'INFO')
  153. return True
  154. if isinstance(action, AddTaskAction):
  155. try:
  156. self.state.plan.add_subtask(
  157. action.parent, action.goal, action.subtasks)
  158. except Exception as e:
  159. observation = AgentErrorObservation(str(e))
  160. print_with_color(observation, 'ERROR')
  161. traceback.print_exc()
  162. elif isinstance(action, ModifyTaskAction):
  163. try:
  164. self.state.plan.set_subtask_state(action.id, action.state)
  165. except Exception as e:
  166. observation = AgentErrorObservation(str(e))
  167. print_with_color(observation, 'ERROR')
  168. traceback.print_exc()
  169. if action.executable:
  170. try:
  171. observation = action.run(self)
  172. if inspect.isawaitable(observation):
  173. observation = await cast(Awaitable[Observation], observation)
  174. except Exception as e:
  175. observation = AgentErrorObservation(str(e))
  176. print_with_color(observation, 'ERROR')
  177. traceback.print_exc()
  178. if not isinstance(observation, NullObservation):
  179. print_with_color(observation, 'OBSERVATION')
  180. self.add_history(action, observation)
  181. await self._run_callbacks(observation)
  182. async def _run_callbacks(self, event):
  183. if event is None:
  184. return
  185. for callback in self.callbacks:
  186. idx = self.callbacks.index(callback)
  187. try:
  188. callback(event)
  189. except Exception as e:
  190. logger.exception(f'Callback error: {e}, idx: {idx}')
  191. await asyncio.sleep(
  192. 0.001
  193. ) # Give back control for a tick, so we can await in callbacks