agent_controller.py 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254
  1. import asyncio
  2. import traceback
  3. from typing import Callable, List
  4. from openai import AuthenticationError, APIConnectionError
  5. from litellm import ContextWindowExceededError
  6. from opendevin import config
  7. from opendevin.action import (
  8. Action,
  9. AgentFinishAction,
  10. NullAction,
  11. )
  12. from opendevin.agent import Agent
  13. from opendevin.exceptions import AgentNoActionError, MaxCharsExceedError
  14. from opendevin.logger import opendevin_logger as logger
  15. from opendevin.observation import AgentErrorObservation, NullObservation, Observation
  16. from opendevin.plan import Plan
  17. from opendevin.state import State
  18. from opendevin.action.tasks import TaskStateChangedAction
  19. from opendevin.schema import TaskState
  20. from opendevin.controller.action_manager import ActionManager
  21. MAX_ITERATIONS = config.get('MAX_ITERATIONS')
  22. MAX_CHARS = config.get('MAX_CHARS')
  23. class AgentController:
  24. id: str
  25. agent: Agent
  26. max_iterations: int
  27. action_manager: ActionManager
  28. callbacks: List[Callable]
  29. state: State | None = None
  30. _task_state: TaskState = TaskState.INIT
  31. _cur_step: int = 0
  32. def __init__(
  33. self,
  34. agent: Agent,
  35. sid: str = '',
  36. max_iterations: int = MAX_ITERATIONS,
  37. max_chars: int = MAX_CHARS,
  38. container_image: str | None = None,
  39. callbacks: List[Callable] = [],
  40. ):
  41. self.id = sid
  42. self.agent = agent
  43. self.max_iterations = max_iterations
  44. self.action_manager = ActionManager(self.id, container_image)
  45. self.max_chars = max_chars
  46. self.callbacks = callbacks
  47. # Initialize agent-required plugins for sandbox (if any)
  48. self.action_manager.init_sandbox_plugins(agent.sandbox_plugins)
  49. def update_state_for_step(self, i):
  50. if self.state is None:
  51. return
  52. self.state.iteration = i
  53. self.state.background_commands_obs = self.action_manager.get_background_obs()
  54. def update_state_after_step(self):
  55. if self.state is None:
  56. return
  57. self.state.updated_info = []
  58. def add_history(self, action: Action, observation: Observation):
  59. if self.state is None:
  60. return
  61. if not isinstance(action, Action):
  62. raise TypeError(
  63. f'action must be an instance of Action, got {type(action).__name__} instead'
  64. )
  65. if not isinstance(observation, Observation):
  66. raise TypeError(
  67. f'observation must be an instance of Observation, got {type(observation).__name__} instead'
  68. )
  69. self.state.history.append((action, observation))
  70. self.state.updated_info.append((action, observation))
  71. async def _run(self):
  72. if self.state is None:
  73. return
  74. if self._task_state != TaskState.RUNNING:
  75. raise ValueError('Task is not in running state')
  76. for i in range(self._cur_step, self.max_iterations):
  77. self._cur_step = i
  78. try:
  79. finished = await self.step(i)
  80. if finished:
  81. self._task_state = TaskState.FINISHED
  82. except Exception as e:
  83. logger.error('Error in loop', exc_info=True)
  84. raise e
  85. if self._task_state == TaskState.FINISHED:
  86. logger.info('Task finished by agent')
  87. await self.reset_task()
  88. break
  89. elif self._task_state == TaskState.STOPPED:
  90. logger.info('Task stopped by user')
  91. await self.reset_task()
  92. break
  93. elif self._task_state == TaskState.PAUSED:
  94. logger.info('Task paused')
  95. self._cur_step = i + 1
  96. await self.notify_task_state_changed()
  97. break
  98. if self._is_stuck():
  99. logger.info('Loop detected, stopping task')
  100. observation = AgentErrorObservation('I got stuck into a loop, the task has stopped.')
  101. await self._run_callbacks(observation)
  102. await self.set_task_state_to(TaskState.STOPPED)
  103. break
  104. async def start(self, task: str):
  105. """Starts the agent controller with a task.
  106. If task already run before, it will continue from the last step.
  107. """
  108. self._task_state = TaskState.RUNNING
  109. await self.notify_task_state_changed()
  110. self.state = State(Plan(task))
  111. await self._run()
  112. async def resume(self):
  113. if self.state is None:
  114. raise ValueError('No task to resume')
  115. self._task_state = TaskState.RUNNING
  116. await self.notify_task_state_changed()
  117. await self._run()
  118. async def reset_task(self):
  119. self.state = None
  120. self._cur_step = 0
  121. self._task_state = TaskState.INIT
  122. self.agent.reset()
  123. await self.notify_task_state_changed()
  124. async def set_task_state_to(self, state: TaskState):
  125. self._task_state = state
  126. if state == TaskState.STOPPED:
  127. await self.reset_task()
  128. logger.info(f'Task state set to {state}')
  129. def get_task_state(self):
  130. """Returns the current state of the agent task."""
  131. return self._task_state
  132. async def notify_task_state_changed(self):
  133. await self._run_callbacks(TaskStateChangedAction(self._task_state))
  134. async def step(self, i: int):
  135. if self.state is None:
  136. return
  137. logger.info(f'STEP {i}', extra={'msg_type': 'STEP'})
  138. logger.info(self.state.plan.main_goal, extra={'msg_type': 'PLAN'})
  139. if self.state.num_of_chars > self.max_chars:
  140. raise MaxCharsExceedError(self.state.num_of_chars, self.max_chars)
  141. log_obs = self.action_manager.get_background_obs()
  142. for obs in log_obs:
  143. self.add_history(NullAction(), obs)
  144. await self._run_callbacks(obs)
  145. logger.info(obs, extra={'msg_type': 'BACKGROUND LOG'})
  146. self.update_state_for_step(i)
  147. action: Action = NullAction()
  148. observation: Observation = NullObservation('')
  149. try:
  150. action = self.agent.step(self.state)
  151. if action is None:
  152. raise AgentNoActionError()
  153. logger.info(action, extra={'msg_type': 'ACTION'})
  154. except Exception as e:
  155. observation = AgentErrorObservation(str(e))
  156. logger.error(e)
  157. logger.debug(traceback.format_exc())
  158. # raise specific exceptions that need to be handled outside
  159. # note: we are using classes from openai rather than litellm because:
  160. # 1) litellm.exceptions.AuthenticationError is a subclass of openai.AuthenticationError
  161. # 2) embeddings call, initiated by llama-index, has no wrapper for errors.
  162. # This means we have to catch individual authentication errors
  163. # from different providers, and OpenAI is one of these.
  164. if isinstance(e, (AuthenticationError, ContextWindowExceededError, APIConnectionError)):
  165. raise
  166. self.update_state_after_step()
  167. await self._run_callbacks(action)
  168. finished = isinstance(action, AgentFinishAction)
  169. if finished:
  170. logger.info(action, extra={'msg_type': 'INFO'})
  171. return True
  172. if isinstance(observation, NullObservation):
  173. observation = await self.action_manager.run_action(action, self)
  174. if not isinstance(observation, NullObservation):
  175. logger.info(observation, extra={'msg_type': 'OBSERVATION'})
  176. self.add_history(action, observation)
  177. await self._run_callbacks(observation)
  178. async def _run_callbacks(self, event):
  179. if event is None:
  180. return
  181. for callback in self.callbacks:
  182. idx = self.callbacks.index(callback)
  183. try:
  184. await callback(event)
  185. except Exception as e:
  186. logger.exception(f'Callback error: {e}, idx: {idx}')
  187. await asyncio.sleep(
  188. 0.001
  189. ) # Give back control for a tick, so we can await in callbacks
  190. def get_state(self):
  191. return self.state
  192. def _is_stuck(self):
  193. if self.state is None or self.state.history is None or len(self.state.history) < 3:
  194. return False
  195. # if the last three (Action, Observation) tuples are too repetitive
  196. # the agent got stuck in a loop
  197. if all(
  198. [self.state.history[-i][0] == self.state.history[-3][0] for i in range(1, 3)]
  199. ):
  200. # it repeats same action, give it a chance, but not if:
  201. if (all
  202. (isinstance(self.state.history[-i][1], NullObservation) for i in range(1, 4))):
  203. # same (Action, NullObservation): like 'think' the same thought over and over
  204. logger.debug('Action, NullObservation loop detected')
  205. return True
  206. elif (all
  207. (isinstance(self.state.history[-i][1], AgentErrorObservation) for i in range(1, 4))):
  208. # (NullAction, AgentErrorObservation): errors coming from an exception
  209. # (Action, AgentErrorObservation): the same action getting an error, even if not necessarily the same error
  210. logger.debug('Action, AgentErrorObservation loop detected')
  211. return True
  212. return False