# agent_controller.py
  1. import asyncio
  2. from typing import Callable, List, Type
  3. from opendevin import config
  4. from opendevin.schema.config import ConfigType
  5. from opendevin.action import (
  6. Action,
  7. AgentFinishAction,
  8. AgentDelegateAction,
  9. NullAction,
  10. )
  11. from opendevin.observation import (
  12. Observation,
  13. AgentErrorObservation,
  14. AgentDelegateObservation,
  15. NullObservation,
  16. )
  17. from opendevin.agent import Agent
  18. from opendevin.exceptions import AgentMalformedActionError, AgentNoActionError, MaxCharsExceedError, LLMOutputError
  19. from opendevin.logger import opendevin_logger as logger
  20. from opendevin.plan import Plan
  21. from opendevin.state import State
  22. from opendevin.action.tasks import TaskStateChangedAction
  23. from opendevin.schema import TaskState
  24. from opendevin.controller.action_manager import ActionManager
  25. MAX_ITERATIONS = config.get(ConfigType.MAX_ITERATIONS)
  26. MAX_CHARS = config.get(ConfigType.MAX_CHARS)
  27. class AgentController:
  28. id: str
  29. agent: Agent
  30. max_iterations: int
  31. action_manager: ActionManager
  32. callbacks: List[Callable]
  33. delegate: 'AgentController | None' = None
  34. state: State | None = None
  35. _task_state: TaskState = TaskState.INIT
  36. _cur_step: int = 0
  37. def __init__(
  38. self,
  39. agent: Agent,
  40. inputs: dict = {},
  41. sid: str = 'default',
  42. max_iterations: int = MAX_ITERATIONS,
  43. max_chars: int = MAX_CHARS,
  44. callbacks: List[Callable] = [],
  45. ):
  46. self.id = sid
  47. self.agent = agent
  48. self.max_iterations = max_iterations
  49. self.action_manager = ActionManager(self.id)
  50. self.max_chars = max_chars
  51. self.callbacks = callbacks
  52. # Initialize agent-required plugins for sandbox (if any)
  53. self.action_manager.init_sandbox_plugins(agent.sandbox_plugins)
  54. def update_state_for_step(self, i):
  55. if self.state is None:
  56. return
  57. self.state.iteration = i
  58. self.state.background_commands_obs = self.action_manager.get_background_obs()
  59. def update_state_after_step(self):
  60. if self.state is None:
  61. return
  62. self.state.updated_info = []
  63. def add_history(self, action: Action, observation: Observation):
  64. if self.state is None:
  65. return
  66. if not isinstance(action, Action):
  67. raise TypeError(
  68. f'action must be an instance of Action, got {type(action).__name__} instead'
  69. )
  70. if not isinstance(observation, Observation):
  71. raise TypeError(
  72. f'observation must be an instance of Observation, got {type(observation).__name__} instead'
  73. )
  74. self.state.history.append((action, observation))
  75. self.state.updated_info.append((action, observation))
  76. async def _run(self):
  77. if self.state is None:
  78. return
  79. if self._task_state != TaskState.RUNNING:
  80. raise ValueError('Task is not in running state')
  81. for i in range(self._cur_step, self.max_iterations):
  82. self._cur_step = i
  83. try:
  84. finished = await self.step(i)
  85. if finished:
  86. self._task_state = TaskState.FINISHED
  87. except Exception:
  88. logger.error('Error in loop', exc_info=True)
  89. await self._run_callbacks(
  90. AgentErrorObservation('Oops! Something went wrong while completing your task. You can check the logs for more info.'))
  91. await self.set_task_state_to(TaskState.STOPPED)
  92. break
  93. if self._task_state == TaskState.FINISHED:
  94. logger.info('Task finished by agent')
  95. await self.reset_task()
  96. break
  97. elif self._task_state == TaskState.STOPPED:
  98. logger.info('Task stopped by user')
  99. await self.reset_task()
  100. break
  101. elif self._task_state == TaskState.PAUSED:
  102. logger.info('Task paused')
  103. self._cur_step = i + 1
  104. await self.notify_task_state_changed()
  105. break
  106. if self._is_stuck():
  107. logger.info('Loop detected, stopping task')
  108. observation = AgentErrorObservation('I got stuck into a loop, the task has stopped.')
  109. await self._run_callbacks(observation)
  110. await self.set_task_state_to(TaskState.STOPPED)
  111. break
  112. async def setup_task(self, task: str, inputs: dict = {}):
  113. """Sets up the agent controller with a task.
  114. """
  115. self._task_state = TaskState.RUNNING
  116. await self.notify_task_state_changed()
  117. self.state = State(Plan(task))
  118. self.state.inputs = inputs
  119. async def start(self, task: str):
  120. """Starts the agent controller with a task.
  121. If task already run before, it will continue from the last step.
  122. """
  123. await self.setup_task(task)
  124. await self._run()
  125. async def resume(self):
  126. if self.state is None:
  127. raise ValueError('No task to resume')
  128. self._task_state = TaskState.RUNNING
  129. await self.notify_task_state_changed()
  130. await self._run()
  131. async def reset_task(self):
  132. self.state = None
  133. self._cur_step = 0
  134. self._task_state = TaskState.INIT
  135. self.agent.reset()
  136. await self.notify_task_state_changed()
  137. async def set_task_state_to(self, state: TaskState):
  138. self._task_state = state
  139. if state == TaskState.STOPPED:
  140. await self.reset_task()
  141. logger.info(f'Task state set to {state}')
  142. def get_task_state(self):
  143. """Returns the current state of the agent task."""
  144. return self._task_state
  145. async def notify_task_state_changed(self):
  146. await self._run_callbacks(TaskStateChangedAction(self._task_state))
  147. async def start_delegate(self, action: AgentDelegateAction):
  148. AgentCls: Type[Agent] = Agent.get_cls(action.agent)
  149. agent = AgentCls(llm=self.agent.llm)
  150. self.delegate = AgentController(
  151. sid=self.id + '-delegate',
  152. agent=agent,
  153. max_iterations=self.max_iterations,
  154. max_chars=self.max_chars,
  155. callbacks=self.callbacks,
  156. )
  157. task = action.inputs.get('task') or ''
  158. await self.delegate.setup_task(task, action.inputs)
  159. async def step(self, i: int) -> bool:
  160. if self.state is None:
  161. raise ValueError('No task to run')
  162. if self.delegate is not None:
  163. delegate_done = await self.delegate.step(i)
  164. if delegate_done:
  165. outputs = self.delegate.state.outputs if self.delegate.state else {}
  166. obs: Observation = AgentDelegateObservation(content='', outputs=outputs)
  167. self.add_history(NullAction(), obs)
  168. self.delegate = None
  169. self.delegateAction = None
  170. return False
  171. logger.info(f'STEP {i}', extra={'msg_type': 'STEP'})
  172. logger.info(self.state.plan.main_goal, extra={'msg_type': 'PLAN'})
  173. if self.state.num_of_chars > self.max_chars:
  174. raise MaxCharsExceedError(self.state.num_of_chars, self.max_chars)
  175. log_obs = self.action_manager.get_background_obs()
  176. for obs in log_obs:
  177. self.add_history(NullAction(), obs)
  178. await self._run_callbacks(obs)
  179. logger.info(obs, extra={'msg_type': 'BACKGROUND LOG'})
  180. self.update_state_for_step(i)
  181. action: Action = NullAction()
  182. observation: Observation = NullObservation('')
  183. try:
  184. action = self.agent.step(self.state)
  185. if action is None:
  186. raise AgentNoActionError('No action was returned')
  187. except (AgentMalformedActionError, AgentNoActionError, LLMOutputError) as e:
  188. observation = AgentErrorObservation(str(e))
  189. logger.info(action, extra={'msg_type': 'ACTION'})
  190. self.update_state_after_step()
  191. await self._run_callbacks(action)
  192. finished = isinstance(action, AgentFinishAction)
  193. if finished:
  194. self.state.outputs = action.outputs # type: ignore[attr-defined]
  195. logger.info(action, extra={'msg_type': 'INFO'})
  196. return True
  197. if isinstance(observation, NullObservation):
  198. observation = await self.action_manager.run_action(action, self)
  199. if not isinstance(observation, NullObservation):
  200. logger.info(observation, extra={'msg_type': 'OBSERVATION'})
  201. self.add_history(action, observation)
  202. await self._run_callbacks(observation)
  203. return False
  204. async def _run_callbacks(self, event):
  205. if event is None:
  206. return
  207. for callback in self.callbacks:
  208. idx = self.callbacks.index(callback)
  209. try:
  210. await callback(event)
  211. except Exception as e:
  212. logger.exception(f'Callback error: {e}, idx: {idx}')
  213. await asyncio.sleep(
  214. 0.001
  215. ) # Give back control for a tick, so we can await in callbacks
  216. def get_state(self):
  217. return self.state
  218. def _is_stuck(self):
  219. if self.state is None or self.state.history is None or len(self.state.history) < 3:
  220. return False
  221. # if the last three (Action, Observation) tuples are too repetitive
  222. # the agent got stuck in a loop
  223. if all(
  224. [self.state.history[-i][0] == self.state.history[-3][0] for i in range(1, 3)]
  225. ):
  226. # it repeats same action, give it a chance, but not if:
  227. if (all
  228. (isinstance(self.state.history[-i][1], NullObservation) for i in range(1, 4))):
  229. # same (Action, NullObservation): like 'think' the same thought over and over
  230. logger.debug('Action, NullObservation loop detected')
  231. return True
  232. elif (all
  233. (isinstance(self.state.history[-i][1], AgentErrorObservation) for i in range(1, 4))):
  234. # (NullAction, AgentErrorObservation): errors coming from an exception
  235. # (Action, AgentErrorObservation): the same action getting an error, even if not necessarily the same error
  236. logger.debug('Action, AgentErrorObservation loop detected')
  237. return True
  238. return False