agent_controller.py 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276
  1. import asyncio
  2. from typing import Callable, List, Type
  3. from opendevin import config
  4. from opendevin.action import (
  5. Action,
  6. AgentFinishAction,
  7. AgentDelegateAction,
  8. NullAction,
  9. )
  10. from opendevin.observation import (
  11. Observation,
  12. AgentErrorObservation,
  13. AgentDelegateObservation,
  14. NullObservation,
  15. )
  16. from opendevin.agent import Agent
  17. from opendevin.exceptions import AgentMalformedActionError, AgentNoActionError, MaxCharsExceedError
  18. from opendevin.logger import opendevin_logger as logger
  19. from opendevin.plan import Plan
  20. from opendevin.state import State
  21. from opendevin.action.tasks import TaskStateChangedAction
  22. from opendevin.schema import TaskState
  23. from opendevin.controller.action_manager import ActionManager
  24. MAX_ITERATIONS = config.get('MAX_ITERATIONS')
  25. MAX_CHARS = config.get('MAX_CHARS')
  26. class AgentController:
  27. id: str
  28. agent: Agent
  29. max_iterations: int
  30. action_manager: ActionManager
  31. callbacks: List[Callable]
  32. delegate: 'AgentController | None' = None
  33. state: State | None = None
  34. _task_state: TaskState = TaskState.INIT
  35. _cur_step: int = 0
  36. def __init__(
  37. self,
  38. agent: Agent,
  39. inputs: dict = {},
  40. sid: str = 'default',
  41. max_iterations: int = MAX_ITERATIONS,
  42. max_chars: int = MAX_CHARS,
  43. callbacks: List[Callable] = [],
  44. ):
  45. self.id = sid
  46. self.agent = agent
  47. self.max_iterations = max_iterations
  48. self.action_manager = ActionManager(self.id)
  49. self.max_chars = max_chars
  50. self.callbacks = callbacks
  51. # Initialize agent-required plugins for sandbox (if any)
  52. self.action_manager.init_sandbox_plugins(agent.sandbox_plugins)
  53. def update_state_for_step(self, i):
  54. if self.state is None:
  55. return
  56. self.state.iteration = i
  57. self.state.background_commands_obs = self.action_manager.get_background_obs()
  58. def update_state_after_step(self):
  59. if self.state is None:
  60. return
  61. self.state.updated_info = []
  62. def add_history(self, action: Action, observation: Observation):
  63. if self.state is None:
  64. return
  65. if not isinstance(action, Action):
  66. raise TypeError(
  67. f'action must be an instance of Action, got {type(action).__name__} instead'
  68. )
  69. if not isinstance(observation, Observation):
  70. raise TypeError(
  71. f'observation must be an instance of Observation, got {type(observation).__name__} instead'
  72. )
  73. self.state.history.append((action, observation))
  74. self.state.updated_info.append((action, observation))
  75. async def _run(self):
  76. if self.state is None:
  77. return
  78. if self._task_state != TaskState.RUNNING:
  79. raise ValueError('Task is not in running state')
  80. for i in range(self._cur_step, self.max_iterations):
  81. self._cur_step = i
  82. try:
  83. finished = await self.step(i)
  84. if finished:
  85. self._task_state = TaskState.FINISHED
  86. except Exception as e:
  87. logger.error('Error in loop', exc_info=True)
  88. raise e
  89. if self._task_state == TaskState.FINISHED:
  90. logger.info('Task finished by agent')
  91. await self.reset_task()
  92. break
  93. elif self._task_state == TaskState.STOPPED:
  94. logger.info('Task stopped by user')
  95. await self.reset_task()
  96. break
  97. elif self._task_state == TaskState.PAUSED:
  98. logger.info('Task paused')
  99. self._cur_step = i + 1
  100. await self.notify_task_state_changed()
  101. break
  102. if self._is_stuck():
  103. logger.info('Loop detected, stopping task')
  104. observation = AgentErrorObservation('I got stuck into a loop, the task has stopped.')
  105. await self._run_callbacks(observation)
  106. await self.set_task_state_to(TaskState.STOPPED)
  107. break
  108. async def setup_task(self, task: str, inputs: dict = {}):
  109. """Sets up the agent controller with a task.
  110. """
  111. self._task_state = TaskState.RUNNING
  112. await self.notify_task_state_changed()
  113. self.state = State(Plan(task))
  114. self.state.inputs = inputs
  115. async def start(self, task: str):
  116. """Starts the agent controller with a task.
  117. If task already run before, it will continue from the last step.
  118. """
  119. await self.setup_task(task)
  120. await self._run()
  121. async def resume(self):
  122. if self.state is None:
  123. raise ValueError('No task to resume')
  124. self._task_state = TaskState.RUNNING
  125. await self.notify_task_state_changed()
  126. await self._run()
  127. async def reset_task(self):
  128. self.state = None
  129. self._cur_step = 0
  130. self._task_state = TaskState.INIT
  131. self.agent.reset()
  132. await self.notify_task_state_changed()
  133. async def set_task_state_to(self, state: TaskState):
  134. self._task_state = state
  135. if state == TaskState.STOPPED:
  136. await self.reset_task()
  137. logger.info(f'Task state set to {state}')
  138. def get_task_state(self):
  139. """Returns the current state of the agent task."""
  140. return self._task_state
  141. async def notify_task_state_changed(self):
  142. await self._run_callbacks(TaskStateChangedAction(self._task_state))
  143. async def start_delegate(self, action: AgentDelegateAction):
  144. AgentCls: Type[Agent] = Agent.get_cls(action.agent)
  145. agent = AgentCls(llm=self.agent.llm)
  146. self.delegate = AgentController(
  147. sid=self.id + '-delegate',
  148. agent=agent,
  149. max_iterations=self.max_iterations,
  150. max_chars=self.max_chars,
  151. callbacks=self.callbacks,
  152. )
  153. task = action.inputs.get('task') or ''
  154. await self.delegate.setup_task(task, action.inputs)
  155. async def step(self, i: int) -> bool:
  156. if self.state is None:
  157. raise ValueError('No task to run')
  158. if self.delegate is not None:
  159. delegate_done = await self.delegate.step(i)
  160. if delegate_done:
  161. outputs = self.delegate.state.outputs if self.delegate.state else {}
  162. obs: Observation = AgentDelegateObservation(content='', outputs=outputs)
  163. self.add_history(NullAction(), obs)
  164. self.delegate = None
  165. self.delegateAction = None
  166. return False
  167. logger.info(f'STEP {i}', extra={'msg_type': 'STEP'})
  168. logger.info(self.state.plan.main_goal, extra={'msg_type': 'PLAN'})
  169. if self.state.num_of_chars > self.max_chars:
  170. raise MaxCharsExceedError(self.state.num_of_chars, self.max_chars)
  171. log_obs = self.action_manager.get_background_obs()
  172. for obs in log_obs:
  173. self.add_history(NullAction(), obs)
  174. await self._run_callbacks(obs)
  175. logger.info(obs, extra={'msg_type': 'BACKGROUND LOG'})
  176. self.update_state_for_step(i)
  177. action: Action = NullAction()
  178. observation: Observation = NullObservation('')
  179. try:
  180. action = self.agent.step(self.state)
  181. if action is None:
  182. raise AgentNoActionError('No action was returned')
  183. except (AgentMalformedActionError, AgentNoActionError) as e:
  184. observation = AgentErrorObservation(str(e))
  185. logger.info(action, extra={'msg_type': 'ACTION'})
  186. self.update_state_after_step()
  187. await self._run_callbacks(action)
  188. finished = isinstance(action, AgentFinishAction)
  189. if finished:
  190. self.state.outputs = action.outputs # type: ignore[attr-defined]
  191. logger.info(action, extra={'msg_type': 'INFO'})
  192. return True
  193. if isinstance(observation, NullObservation):
  194. observation = await self.action_manager.run_action(action, self)
  195. if not isinstance(observation, NullObservation):
  196. logger.info(observation, extra={'msg_type': 'OBSERVATION'})
  197. self.add_history(action, observation)
  198. await self._run_callbacks(observation)
  199. return False
  200. async def _run_callbacks(self, event):
  201. if event is None:
  202. return
  203. for callback in self.callbacks:
  204. idx = self.callbacks.index(callback)
  205. try:
  206. await callback(event)
  207. except Exception as e:
  208. logger.exception(f'Callback error: {e}, idx: {idx}')
  209. await asyncio.sleep(
  210. 0.001
  211. ) # Give back control for a tick, so we can await in callbacks
  212. def get_state(self):
  213. return self.state
  214. def _is_stuck(self):
  215. if self.state is None or self.state.history is None or len(self.state.history) < 3:
  216. return False
  217. # if the last three (Action, Observation) tuples are too repetitive
  218. # the agent got stuck in a loop
  219. if all(
  220. [self.state.history[-i][0] == self.state.history[-3][0] for i in range(1, 3)]
  221. ):
  222. # it repeats same action, give it a chance, but not if:
  223. if (all
  224. (isinstance(self.state.history[-i][1], NullObservation) for i in range(1, 4))):
  225. # same (Action, NullObservation): like 'think' the same thought over and over
  226. logger.debug('Action, NullObservation loop detected')
  227. return True
  228. elif (all
  229. (isinstance(self.state.history[-i][1], AgentErrorObservation) for i in range(1, 4))):
  230. # (NullAction, AgentErrorObservation): errors coming from an exception
  231. # (Action, AgentErrorObservation): the same action getting an error, even if not necessarily the same error
  232. logger.debug('Action, AgentErrorObservation loop detected')
  233. return True
  234. return False