agent_controller.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331
  1. import asyncio
  2. from typing import Callable, List, Type
  3. from agenthub.codeact_agent.codeact_agent import CodeActAgent
  4. from opendevin import config
  5. from opendevin.action import (
  6. Action,
  7. AgentDelegateAction,
  8. AgentFinishAction,
  9. AgentTalkAction,
  10. NullAction,
  11. )
  12. from opendevin.action.tasks import TaskStateChangedAction
  13. from opendevin.agent import Agent
  14. from opendevin.controller.action_manager import ActionManager
  15. from opendevin.exceptions import (
  16. AgentMalformedActionError,
  17. AgentNoActionError,
  18. LLMOutputError,
  19. MaxCharsExceedError,
  20. )
  21. from opendevin.logger import opendevin_logger as logger
  22. from opendevin.observation import (
  23. AgentDelegateObservation,
  24. AgentErrorObservation,
  25. NullObservation,
  26. Observation,
  27. UserMessageObservation,
  28. )
  29. from opendevin.plan import Plan
  30. from opendevin.sandbox import DockerSSHBox
  31. from opendevin.schema import TaskState
  32. from opendevin.schema.config import ConfigType
  33. from opendevin.state import State
  34. MAX_ITERATIONS = config.get(ConfigType.MAX_ITERATIONS)
  35. MAX_CHARS = config.get(ConfigType.MAX_CHARS)
  36. class AgentController:
  37. id: str
  38. agent: Agent
  39. max_iterations: int
  40. action_manager: ActionManager
  41. callbacks: List[Callable]
  42. delegate: 'AgentController | None' = None
  43. state: State | None = None
  44. _task_state: TaskState = TaskState.INIT
  45. _cur_step: int = 0
  46. def __init__(
  47. self,
  48. agent: Agent,
  49. inputs: dict = {},
  50. sid: str = 'default',
  51. max_iterations: int = MAX_ITERATIONS,
  52. max_chars: int = MAX_CHARS,
  53. callbacks: List[Callable] = [],
  54. ):
  55. self.id = sid
  56. self.agent = agent
  57. self.max_iterations = max_iterations
  58. self.action_manager = ActionManager(self.id)
  59. self.max_chars = max_chars
  60. self.callbacks = callbacks
  61. # Initialize agent-required plugins for sandbox (if any)
  62. self.action_manager.init_sandbox_plugins(agent.sandbox_plugins)
  63. if isinstance(agent, CodeActAgent) and not isinstance(self.action_manager.sandbox, DockerSSHBox):
  64. logger.warning('CodeActAgent requires DockerSSHBox as sandbox! Using other sandbox that are not stateful (LocalBox, DockerExecBox) will not work properly.')
  65. self._await_user_message_queue: asyncio.Queue = asyncio.Queue()
  66. def update_state_for_step(self, i):
  67. if self.state is None:
  68. return
  69. self.state.iteration = i
  70. self.state.background_commands_obs = self.action_manager.get_background_obs()
  71. def update_state_after_step(self):
  72. if self.state is None:
  73. return
  74. self.state.updated_info = []
  75. def add_history(self, action: Action, observation: Observation):
  76. if self.state is None:
  77. return
  78. if not isinstance(action, Action):
  79. raise TypeError(
  80. f'action must be an instance of Action, got {type(action).__name__} instead'
  81. )
  82. if not isinstance(observation, Observation):
  83. raise TypeError(
  84. f'observation must be an instance of Observation, got {type(observation).__name__} instead'
  85. )
  86. self.state.history.append((action, observation))
  87. self.state.updated_info.append((action, observation))
  88. async def _run(self):
  89. if self.state is None:
  90. return
  91. if self._task_state != TaskState.RUNNING:
  92. raise ValueError('Task is not in running state')
  93. for i in range(self._cur_step, self.max_iterations):
  94. self._cur_step = i
  95. try:
  96. finished = await self.step(i)
  97. if finished:
  98. self._task_state = TaskState.FINISHED
  99. except Exception:
  100. logger.error('Error in loop', exc_info=True)
  101. await self._run_callbacks(
  102. AgentErrorObservation('Oops! Something went wrong while completing your task. You can check the logs for more info.'))
  103. await self.set_task_state_to(TaskState.STOPPED)
  104. break
  105. if self._task_state == TaskState.FINISHED:
  106. logger.info('Task finished by agent')
  107. await self.reset_task()
  108. break
  109. elif self._task_state == TaskState.STOPPED:
  110. logger.info('Task stopped by user')
  111. await self.reset_task()
  112. break
  113. elif self._task_state == TaskState.PAUSED:
  114. logger.info('Task paused')
  115. self._cur_step = i + 1
  116. await self.notify_task_state_changed()
  117. break
  118. if self._is_stuck():
  119. logger.info('Loop detected, stopping task')
  120. observation = AgentErrorObservation('I got stuck into a loop, the task has stopped.')
  121. await self._run_callbacks(observation)
  122. await self.set_task_state_to(TaskState.STOPPED)
  123. break
  124. async def setup_task(self, task: str, inputs: dict = {}):
  125. """Sets up the agent controller with a task.
  126. """
  127. self._task_state = TaskState.RUNNING
  128. await self.notify_task_state_changed()
  129. self.state = State(Plan(task))
  130. self.state.inputs = inputs
  131. async def start(self, task: str):
  132. """Starts the agent controller with a task.
  133. If task already run before, it will continue from the last step.
  134. """
  135. await self.setup_task(task)
  136. await self._run()
  137. async def resume(self):
  138. if self.state is None:
  139. raise ValueError('No task to resume')
  140. self._task_state = TaskState.RUNNING
  141. await self.notify_task_state_changed()
  142. await self._run()
  143. async def reset_task(self):
  144. self.state = None
  145. self._cur_step = 0
  146. self._task_state = TaskState.INIT
  147. self.agent.reset()
  148. await self.notify_task_state_changed()
  149. async def set_task_state_to(self, state: TaskState):
  150. self._task_state = state
  151. if state == TaskState.STOPPED:
  152. await self.reset_task()
  153. logger.info(f'Task state set to {state}')
  154. def get_task_state(self):
  155. """Returns the current state of the agent task."""
  156. return self._task_state
  157. async def notify_task_state_changed(self):
  158. await self._run_callbacks(TaskStateChangedAction(self._task_state))
  159. async def add_user_message(self, message: UserMessageObservation):
  160. if self.state is None:
  161. return
  162. if self._task_state == TaskState.AWAITING_USER_INPUT:
  163. self._await_user_message_queue.put_nowait(message)
  164. # set the task state to running
  165. self._task_state = TaskState.RUNNING
  166. await self.notify_task_state_changed()
  167. elif self._task_state == TaskState.RUNNING:
  168. self.add_history(NullAction(), message)
  169. else:
  170. raise ValueError(f'Task (state: {self._task_state}) is not in a state to add user message')
  171. async def wait_for_user_input(self) -> UserMessageObservation:
  172. self._task_state = TaskState.AWAITING_USER_INPUT
  173. await self.notify_task_state_changed()
  174. # wait for the next user message
  175. if len(self.callbacks) == 0:
  176. logger.info('Use STDIN to request user message as no callbacks are registered', extra={'msg_type': 'INFO'})
  177. message = input('Request user input [type /exit to stop interaction] >> ')
  178. user_message_observation = UserMessageObservation(message)
  179. else:
  180. user_message_observation = await self._await_user_message_queue.get()
  181. self._await_user_message_queue.task_done()
  182. return user_message_observation
  183. async def start_delegate(self, action: AgentDelegateAction):
  184. AgentCls: Type[Agent] = Agent.get_cls(action.agent)
  185. agent = AgentCls(llm=self.agent.llm)
  186. self.delegate = AgentController(
  187. sid=self.id + '-delegate',
  188. agent=agent,
  189. max_iterations=self.max_iterations,
  190. max_chars=self.max_chars,
  191. callbacks=self.callbacks,
  192. )
  193. task = action.inputs.get('task') or ''
  194. await self.delegate.setup_task(task, action.inputs)
  195. async def step(self, i: int) -> bool:
  196. if self.state is None:
  197. raise ValueError('No task to run')
  198. if self.delegate is not None:
  199. delegate_done = await self.delegate.step(i)
  200. if delegate_done:
  201. outputs = self.delegate.state.outputs if self.delegate.state else {}
  202. obs: Observation = AgentDelegateObservation(content='', outputs=outputs)
  203. self.add_history(NullAction(), obs)
  204. self.delegate = None
  205. self.delegateAction = None
  206. return False
  207. logger.info(f'STEP {i}', extra={'msg_type': 'STEP'})
  208. if i == 0:
  209. logger.info(self.state.plan.main_goal, extra={'msg_type': 'PLAN'})
  210. if self.state.num_of_chars > self.max_chars:
  211. raise MaxCharsExceedError(self.state.num_of_chars, self.max_chars)
  212. log_obs = self.action_manager.get_background_obs()
  213. for obs in log_obs:
  214. self.add_history(NullAction(), obs)
  215. await self._run_callbacks(obs)
  216. logger.info(obs, extra={'msg_type': 'BACKGROUND LOG'})
  217. self.update_state_for_step(i)
  218. action: Action = NullAction()
  219. observation: Observation = NullObservation('')
  220. try:
  221. action = self.agent.step(self.state)
  222. if action is None:
  223. raise AgentNoActionError('No action was returned')
  224. except (AgentMalformedActionError, AgentNoActionError, LLMOutputError) as e:
  225. observation = AgentErrorObservation(str(e))
  226. logger.info(action, extra={'msg_type': 'ACTION'})
  227. self.update_state_after_step()
  228. await self._run_callbacks(action)
  229. # whether to await for user messages
  230. if isinstance(action, AgentTalkAction):
  231. # await for the next user messages
  232. user_message_observation = await self.wait_for_user_input()
  233. logger.info(user_message_observation, extra={'msg_type': 'OBSERVATION'})
  234. self.add_history(action, user_message_observation)
  235. return False
  236. finished = isinstance(action, AgentFinishAction)
  237. if finished:
  238. self.state.outputs = action.outputs # type: ignore[attr-defined]
  239. logger.info(action, extra={'msg_type': 'INFO'})
  240. return True
  241. if isinstance(observation, NullObservation):
  242. observation = await self.action_manager.run_action(action, self)
  243. if not isinstance(observation, NullObservation):
  244. logger.info(observation, extra={'msg_type': 'OBSERVATION'})
  245. self.add_history(action, observation)
  246. await self._run_callbacks(observation)
  247. return False
  248. async def _run_callbacks(self, event):
  249. if event is None:
  250. return
  251. for callback in self.callbacks:
  252. idx = self.callbacks.index(callback)
  253. try:
  254. await callback(event)
  255. except Exception as e:
  256. logger.exception(f'Callback error: {e}, idx: {idx}')
  257. await asyncio.sleep(
  258. 0.001
  259. ) # Give back control for a tick, so we can await in callbacks
  260. def get_state(self):
  261. return self.state
  262. def _is_stuck(self):
  263. if self.state is None or self.state.history is None or len(self.state.history) < 3:
  264. return False
  265. # if the last three (Action, Observation) tuples are too repetitive
  266. # the agent got stuck in a loop
  267. if all(
  268. [self.state.history[-i][0] == self.state.history[-3][0] for i in range(1, 3)]
  269. ):
  270. # it repeats same action, give it a chance, but not if:
  271. if (all
  272. (isinstance(self.state.history[-i][1], NullObservation) for i in range(1, 4))):
  273. # same (Action, NullObservation): like 'think' the same thought over and over
  274. logger.debug('Action, NullObservation loop detected')
  275. return True
  276. elif (all
  277. (isinstance(self.state.history[-i][1], AgentErrorObservation) for i in range(1, 4))):
  278. # (NullAction, AgentErrorObservation): errors coming from an exception
  279. # (Action, AgentErrorObservation): the same action getting an error, even if not necessarily the same error
  280. logger.debug('Action, AgentErrorObservation loop detected')
  281. return True
  282. return False