agent_controller.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368
  1. import asyncio
  2. from typing import Callable, List, Type
  3. from agenthub.codeact_agent.codeact_agent import CodeActAgent
  4. from opendevin import config
  5. from opendevin.action import (
  6. Action,
  7. AgentDelegateAction,
  8. AgentFinishAction,
  9. AgentTalkAction,
  10. NullAction,
  11. )
  12. from opendevin.action.tasks import TaskStateChangedAction
  13. from opendevin.agent import Agent
  14. from opendevin.browser.browser_env import BrowserEnv
  15. from opendevin.controller.action_manager import ActionManager
  16. from opendevin.exceptions import (
  17. AgentMalformedActionError,
  18. AgentNoActionError,
  19. LLMOutputError,
  20. MaxCharsExceedError,
  21. )
  22. from opendevin.logger import opendevin_logger as logger
  23. from opendevin.observation import (
  24. AgentDelegateObservation,
  25. AgentErrorObservation,
  26. NullObservation,
  27. Observation,
  28. UserMessageObservation,
  29. )
  30. from opendevin.plan import Plan
  31. from opendevin.sandbox import DockerSSHBox
  32. from opendevin.schema import TaskState
  33. from opendevin.schema.config import ConfigType
  34. from opendevin.state import State
  35. MAX_ITERATIONS = config.get(ConfigType.MAX_ITERATIONS)
  36. MAX_CHARS = config.get(ConfigType.MAX_CHARS)
  37. class AgentController:
  38. id: str
  39. agent: Agent
  40. max_iterations: int
  41. action_manager: ActionManager
  42. callbacks: List[Callable]
  43. browser: BrowserEnv
  44. delegate: 'AgentController | None' = None
  45. state: State | None = None
  46. _task_state: TaskState = TaskState.INIT
  47. _cur_step: int = 0
  48. def __init__(
  49. self,
  50. agent: Agent,
  51. sid: str = 'default',
  52. max_iterations: int = MAX_ITERATIONS,
  53. max_chars: int = MAX_CHARS,
  54. callbacks: List[Callable] = [],
  55. ):
  56. """Initializes a new instance of the AgentController class.
  57. Args:
  58. agent: The agent instance to control.
  59. sid: The session ID of the agent.
  60. max_iterations: The maximum number of iterations the agent can run.
  61. max_chars: The maximum number of characters the agent can output.
  62. callbacks: A list of callback functions to run after each action.
  63. """
  64. self.id = sid
  65. self.agent = agent
  66. self.max_iterations = max_iterations
  67. self.action_manager = ActionManager(self.id)
  68. self.max_chars = max_chars
  69. self.callbacks = callbacks
  70. # Initialize agent-required plugins for sandbox (if any)
  71. self.action_manager.init_sandbox_plugins(agent.sandbox_plugins)
  72. # Initialize browser environment
  73. self.browser = BrowserEnv()
  74. if isinstance(agent, CodeActAgent) and not isinstance(
  75. self.action_manager.sandbox, DockerSSHBox
  76. ):
  77. logger.warning(
  78. 'CodeActAgent requires DockerSSHBox as sandbox! Using other sandbox that are not stateful (LocalBox, DockerExecBox) will not work properly.'
  79. )
  80. self._await_user_message_queue: asyncio.Queue = asyncio.Queue()
  81. def update_state_for_step(self, i):
  82. if self.state is None:
  83. return
  84. self.state.iteration = i
  85. self.state.background_commands_obs = self.action_manager.get_background_obs()
  86. def update_state_after_step(self):
  87. if self.state is None:
  88. return
  89. self.state.updated_info = []
  90. def add_history(self, action: Action, observation: Observation):
  91. if self.state is None:
  92. return
  93. if not isinstance(action, Action):
  94. raise TypeError(
  95. f'action must be an instance of Action, got {type(action).__name__} instead'
  96. )
  97. if not isinstance(observation, Observation):
  98. raise TypeError(
  99. f'observation must be an instance of Observation, got {type(observation).__name__} instead'
  100. )
  101. self.state.history.append((action, observation))
  102. self.state.updated_info.append((action, observation))
  103. async def _run(self):
  104. if self.state is None:
  105. return
  106. if self._task_state != TaskState.RUNNING:
  107. raise ValueError('Task is not in running state')
  108. for i in range(self._cur_step, self.max_iterations):
  109. self._cur_step = i
  110. try:
  111. finished = await self.step(i)
  112. if finished:
  113. self._task_state = TaskState.FINISHED
  114. except Exception:
  115. logger.error('Error in loop', exc_info=True)
  116. await self._run_callbacks(
  117. AgentErrorObservation(
  118. 'Oops! Something went wrong while completing your task. You can check the logs for more info.'
  119. )
  120. )
  121. await self.set_task_state_to(TaskState.STOPPED)
  122. break
  123. if self._task_state == TaskState.FINISHED:
  124. logger.info('Task finished by agent')
  125. await self.reset_task()
  126. break
  127. elif self._task_state == TaskState.STOPPED:
  128. logger.info('Task stopped by user')
  129. await self.reset_task()
  130. break
  131. elif self._task_state == TaskState.PAUSED:
  132. logger.info('Task paused')
  133. self._cur_step = i + 1
  134. await self.notify_task_state_changed()
  135. break
  136. if self._is_stuck():
  137. logger.info('Loop detected, stopping task')
  138. observation = AgentErrorObservation(
  139. 'I got stuck into a loop, the task has stopped.'
  140. )
  141. await self._run_callbacks(observation)
  142. await self.set_task_state_to(TaskState.STOPPED)
  143. break
  144. async def setup_task(self, task: str, inputs: dict = {}):
  145. """Sets up the agent controller with a task."""
  146. self._task_state = TaskState.RUNNING
  147. await self.notify_task_state_changed()
  148. self.state = State(Plan(task))
  149. self.state.inputs = inputs
  150. async def start(self, task: str):
  151. """Starts the agent controller with a task.
  152. If task already run before, it will continue from the last step.
  153. """
  154. await self.setup_task(task)
  155. await self._run()
  156. async def resume(self):
  157. if self.state is None:
  158. raise ValueError('No task to resume')
  159. self._task_state = TaskState.RUNNING
  160. await self.notify_task_state_changed()
  161. await self._run()
  162. async def reset_task(self):
  163. self.state = None
  164. self._cur_step = 0
  165. self._task_state = TaskState.INIT
  166. self.agent.reset()
  167. await self.notify_task_state_changed()
  168. async def set_task_state_to(self, state: TaskState):
  169. self._task_state = state
  170. if state == TaskState.STOPPED:
  171. await self.reset_task()
  172. logger.info(f'Task state set to {state}')
  173. def get_task_state(self):
  174. """Returns the current state of the agent task."""
  175. return self._task_state
  176. async def notify_task_state_changed(self):
  177. await self._run_callbacks(TaskStateChangedAction(self._task_state))
  178. async def add_user_message(self, message: UserMessageObservation):
  179. if self.state is None:
  180. return
  181. if self._task_state == TaskState.AWAITING_USER_INPUT:
  182. self._await_user_message_queue.put_nowait(message)
  183. # set the task state to running
  184. self._task_state = TaskState.RUNNING
  185. await self.notify_task_state_changed()
  186. elif self._task_state == TaskState.RUNNING:
  187. self.add_history(NullAction(), message)
  188. else:
  189. raise ValueError(
  190. f'Task (state: {self._task_state}) is not in a state to add user message'
  191. )
  192. async def wait_for_user_input(self) -> UserMessageObservation:
  193. self._task_state = TaskState.AWAITING_USER_INPUT
  194. await self.notify_task_state_changed()
  195. # wait for the next user message
  196. if len(self.callbacks) == 0:
  197. logger.info(
  198. 'Use STDIN to request user message as no callbacks are registered',
  199. extra={'msg_type': 'INFO'},
  200. )
  201. message = input('Request user input [type /exit to stop interaction] >> ')
  202. user_message_observation = UserMessageObservation(message)
  203. else:
  204. user_message_observation = await self._await_user_message_queue.get()
  205. self._await_user_message_queue.task_done()
  206. return user_message_observation
  207. async def start_delegate(self, action: AgentDelegateAction):
  208. AgentCls: Type[Agent] = Agent.get_cls(action.agent)
  209. agent = AgentCls(llm=self.agent.llm)
  210. self.delegate = AgentController(
  211. sid=self.id + '-delegate',
  212. agent=agent,
  213. max_iterations=self.max_iterations,
  214. max_chars=self.max_chars,
  215. callbacks=self.callbacks,
  216. )
  217. task = action.inputs.get('task') or ''
  218. await self.delegate.setup_task(task, action.inputs)
  219. async def step(self, i: int) -> bool:
  220. if self.state is None:
  221. raise ValueError('No task to run')
  222. if self.delegate is not None:
  223. delegate_done = await self.delegate.step(i)
  224. if delegate_done:
  225. outputs = self.delegate.state.outputs if self.delegate.state else {}
  226. obs: Observation = AgentDelegateObservation(content='', outputs=outputs)
  227. self.add_history(NullAction(), obs)
  228. self.delegate = None
  229. self.delegateAction = None
  230. return False
  231. logger.info(f'STEP {i}', extra={'msg_type': 'STEP'})
  232. if i == 0:
  233. logger.info(self.state.plan.main_goal, extra={'msg_type': 'PLAN'})
  234. if self.state.num_of_chars > self.max_chars:
  235. raise MaxCharsExceedError(self.state.num_of_chars, self.max_chars)
  236. log_obs = self.action_manager.get_background_obs()
  237. for obs in log_obs:
  238. self.add_history(NullAction(), obs)
  239. await self._run_callbacks(obs)
  240. logger.info(obs, extra={'msg_type': 'BACKGROUND LOG'})
  241. self.update_state_for_step(i)
  242. action: Action = NullAction()
  243. observation: Observation = NullObservation('')
  244. try:
  245. action = self.agent.step(self.state)
  246. if action is None:
  247. raise AgentNoActionError('No action was returned')
  248. except (AgentMalformedActionError, AgentNoActionError, LLMOutputError) as e:
  249. observation = AgentErrorObservation(str(e))
  250. logger.info(action, extra={'msg_type': 'ACTION'})
  251. self.update_state_after_step()
  252. await self._run_callbacks(action)
  253. # whether to await for user messages
  254. if isinstance(action, AgentTalkAction):
  255. # await for the next user messages
  256. user_message_observation = await self.wait_for_user_input()
  257. logger.info(user_message_observation, extra={'msg_type': 'OBSERVATION'})
  258. self.add_history(action, user_message_observation)
  259. return False
  260. finished = isinstance(action, AgentFinishAction)
  261. if finished:
  262. self.state.outputs = action.outputs # type: ignore[attr-defined]
  263. logger.info(action, extra={'msg_type': 'INFO'})
  264. return True
  265. if isinstance(observation, NullObservation):
  266. observation = await self.action_manager.run_action(action, self)
  267. if not isinstance(observation, NullObservation):
  268. logger.info(observation, extra={'msg_type': 'OBSERVATION'})
  269. self.add_history(action, observation)
  270. await self._run_callbacks(observation)
  271. return False
  272. async def _run_callbacks(self, event):
  273. if event is None:
  274. return
  275. for callback in self.callbacks:
  276. idx = self.callbacks.index(callback)
  277. try:
  278. await callback(event)
  279. except Exception as e:
  280. logger.exception(f'Callback error: {e}, idx: {idx}')
  281. await asyncio.sleep(
  282. 0.001
  283. ) # Give back control for a tick, so we can await in callbacks
  284. def get_state(self):
  285. return self.state
  286. def _is_stuck(self):
  287. if (
  288. self.state is None
  289. or self.state.history is None
  290. or len(self.state.history) < 3
  291. ):
  292. return False
  293. # if the last three (Action, Observation) tuples are too repetitive
  294. # the agent got stuck in a loop
  295. if all(
  296. [
  297. self.state.history[-i][0] == self.state.history[-3][0]
  298. for i in range(1, 3)
  299. ]
  300. ):
  301. # it repeats same action, give it a chance, but not if:
  302. if all(
  303. isinstance(self.state.history[-i][1], NullObservation)
  304. for i in range(1, 4)
  305. ):
  306. # same (Action, NullObservation): like 'think' the same thought over and over
  307. logger.debug('Action, NullObservation loop detected')
  308. return True
  309. elif all(
  310. isinstance(self.state.history[-i][1], AgentErrorObservation)
  311. for i in range(1, 4)
  312. ):
  313. # (NullAction, AgentErrorObservation): errors coming from an exception
  314. # (Action, AgentErrorObservation): the same action getting an error, even if not necessarily the same error
  315. logger.debug('Action, AgentErrorObservation loop detected')
  316. return True
  317. return False