agent_controller.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357
  1. import asyncio
  2. from typing import Optional, Type
  3. from agenthub.codeact_agent.codeact_agent import CodeActAgent
  4. from opendevin.controller.agent import Agent
  5. from opendevin.controller.state.plan import Plan
  6. from opendevin.controller.state.state import State
  7. from opendevin.core.config import config
  8. from opendevin.core.exceptions import (
  9. AgentMalformedActionError,
  10. AgentNoActionError,
  11. LLMOutputError,
  12. MaxCharsExceedError,
  13. )
  14. from opendevin.core.logger import opendevin_logger as logger
  15. from opendevin.core.schema import AgentState
  16. from opendevin.events.action import (
  17. Action,
  18. AddTaskAction,
  19. AgentDelegateAction,
  20. AgentFinishAction,
  21. ChangeAgentStateAction,
  22. MessageAction,
  23. ModifyTaskAction,
  24. NullAction,
  25. )
  26. from opendevin.events.event import Event
  27. from opendevin.events.observation import (
  28. AgentDelegateObservation,
  29. AgentStateChangedObservation,
  30. ErrorObservation,
  31. NullObservation,
  32. Observation,
  33. )
  34. from opendevin.events.stream import EventSource, EventStream, EventStreamSubscriber
  35. from opendevin.runtime import DockerSSHBox, Sandbox
  36. from opendevin.runtime.runtime import Runtime
  37. from opendevin.runtime.server.runtime import ServerRuntime
  # Default step budget for a controller run, taken from the global config.
  MAX_ITERATIONS = config.max_iterations
  # Default character budget, taken from the LLM section of the global config.
  MAX_CHARS = config.llm.max_chars
  class AgentController:
      """Runs an agent step by step and tracks its lifecycle state.

      The controller owns a runtime used to execute actions, holds the task
      ``State`` (plan, history, outputs) and publishes every recorded
      (action, observation) pair and every agent-state change on the event
      stream it was given.
      """

      id: str
      agent: Agent
      max_iterations: int
      runtime: Runtime
      event_stream: EventStream
      # Background task executing ``_run``; created when entering RUNNING.
      agent_task: Optional[asyncio.Task] = None
      # Child controller while a delegated sub-task is in progress.
      delegate: 'AgentController | None' = None
      # Task state; None until ``setup_task`` runs and again after ``reset_task``.
      state: State | None = None
      _agent_state: AgentState = AgentState.LOADING
      # Iteration index at which ``_run`` starts (or resumes).
      _cur_step: int = 0

      def __init__(
          self,
          agent: Agent,
          event_stream: EventStream,
          sid: str = 'default',
          max_iterations: int = MAX_ITERATIONS,
          max_chars: int = MAX_CHARS,
          sandbox: Optional[Sandbox] = None,
          remind_iterations: bool = config.remind_iterations,
      ):
          """Initializes a new instance of the AgentController class.

          Args:
              agent: The agent instance to control.
              event_stream: The stream this controller subscribes to and
                  publishes actions, observations and state changes on.
              sid: The session ID of the agent.
              max_iterations: The maximum number of iterations the agent can run.
              max_chars: Character budget; ``step`` raises ``MaxCharsExceedError``
                  once ``state.num_of_chars`` exceeds it.
              sandbox: An optional initialized sandbox, handed to ``ServerRuntime``
                  as-is (``None`` leaves sandbox selection to the runtime).
              remind_iterations: Whether to append the remaining turn budget to
                  each observation.
          """
          self.id = sid
          self.agent = agent
          self.event_stream = event_stream
          self.event_stream.subscribe(
              EventStreamSubscriber.AGENT_CONTROLLER, self.on_event
          )
          self.max_iterations = max_iterations
          self.remind_iterations = remind_iterations
          if self.remind_iterations:
              logger.info(
                  'Iteration reminder is ENABLED: agent will be reminded of remaining turns.'
              )
          self.runtime = ServerRuntime(sandbox=sandbox, sid=self.id)
          self.max_chars = max_chars
          # Initialize agent-required plugins for sandbox (if any)
          self.runtime.init_sandbox_plugins(agent.sandbox_plugins)
          if isinstance(agent, CodeActAgent) and not isinstance(
              self.runtime.sandbox, DockerSSHBox
          ):
              logger.warning(
                  'CodeActAgent requires DockerSSHBox as sandbox! Using other sandbox that are not stateful (LocalBox, DockerExecBox) will not work properly.'
              )

      async def close(self):
          """Cancel the running task, detach from the stream, close the runtime's
          sandbox and browser, and move the agent to STOPPED."""
          if self.agent_task is not None:
              self.agent_task.cancel()
          self.event_stream.unsubscribe(EventStreamSubscriber.AGENT_CONTROLLER)
          self.runtime.sandbox.close()
          self.runtime.browser.close()
          await self.set_agent_state_to(AgentState.STOPPED)

      def update_state_for_step(self, i):
          """Store the iteration index and the runtime's background observations
          on the state before the agent is asked for its next action."""
          if self.state is None:
              return
          self.state.iteration = i
          self.state.background_commands_obs = self.runtime.get_background_obs()

      def update_state_after_step(self):
          """Clear ``state.updated_info`` once the agent has produced an action."""
          if self.state is None:
              return
          self.state.updated_info = []

      async def add_error_to_history(self, message: str):
          """Record ``message`` as a (NullAction, ErrorObservation) history pair."""
          await self.add_history(NullAction(), ErrorObservation(message))

      async def add_history(
          self, action: Action, observation: Observation, add_to_stream=True
      ):
          """Append an (action, observation) pair to the state.

          Args:
              action: The action taken.
              observation: The observation it produced.
              add_to_stream: When true, both events are also published on the
                  event stream with ``EventSource.AGENT``.

          Raises:
              ValueError: If no task state exists.
              TypeError: If either argument has the wrong type.
          """
          if self.state is None:
              raise ValueError('Added history while state was None')
          if not isinstance(action, Action):
              raise TypeError(
                  f'action must be an instance of Action, got {type(action).__name__} instead'
              )
          if not isinstance(observation, Observation):
              raise TypeError(
                  f'observation must be an instance of Observation, got {type(observation).__name__} instead'
              )
          self.state.history.append((action, observation))
          self.state.updated_info.append((action, observation))
          if add_to_stream:
              await self.event_stream.add_event(action, EventSource.AGENT)
              await self.event_stream.add_event(observation, EventSource.AGENT)

      async def _fail_task(self, message: str):
          """Record ``message`` as an error, then move the agent to ERROR.

          The order matters: entering ERROR calls ``reset_task``, which drops
          ``self.state``, after which ``add_history`` would raise instead of
          recording the message.
          """
          if self.state is not None:
              await self.add_error_to_history(message)
          await self.set_agent_state_to(AgentState.ERROR)

      async def _run(self):
          """Step the agent from ``_cur_step`` until it finishes, fails, gets
          stuck, or exhausts ``max_iterations`` (which pauses the agent).

          Raises:
              ValueError: If the agent is not in the RUNNING state.
          """
          if self.state is None:
              return
          if self._agent_state != AgentState.RUNNING:
              raise ValueError('Task is not in running state')
          for i in range(self._cur_step, self.max_iterations):
              self._cur_step = i
              try:
                  finished = await self.step(i)
                  if finished:
                      await self.set_agent_state_to(AgentState.FINISHED)
                      break
              except Exception:
                  logger.error('Error in loop', exc_info=True)
                  await self._fail_task(
                      'Oops! Something went wrong while completing your task. You can check the logs for more info.'
                  )
                  break
              if self._is_stuck():
                  logger.info('Loop detected, stopping task')
                  await self._fail_task(
                      'I got stuck into a loop, the task has stopped.'
                  )
                  break
              await asyncio.sleep(
                  0.001
              )  # Give back control for a tick, so other async stuff can run
          # Still RUNNING here means the loop ended without finishing or failing.
          final_state = self.get_agent_state()
          if final_state == AgentState.RUNNING:
              await self.set_agent_state_to(AgentState.PAUSED)

      async def setup_task(self, task: str, inputs: Optional[dict] = None):
          """Sets up the agent controller with a task.

          Args:
              task: The main goal used to build the plan.
              inputs: Optional inputs stored on the state. A fresh dict is used
                  when omitted, so separate tasks never share one default object.
          """
          await self.set_agent_state_to(AgentState.INIT)
          self.state = State(Plan(task))
          self.state.inputs = inputs if inputs is not None else {}

      async def on_event(self, event: Event):
          """Event-stream callback.

          Applies requested agent-state changes, and records user messages in
          history (without re-publishing them), resuming the agent if it was
          waiting for user input.
          """
          if isinstance(event, ChangeAgentStateAction):
              await self.set_agent_state_to(event.agent_state)  # type: ignore
          elif isinstance(event, MessageAction) and event.source == EventSource.USER:
              await self.add_history(event, NullObservation(''), add_to_stream=False)
              if self.get_agent_state() == AgentState.AWAITING_USER_INPUT:
                  await self.set_agent_state_to(AgentState.RUNNING)

      async def reset_task(self):
          """Cancel the running task, drop the task state and reset the agent."""
          if self.agent_task is not None:
              self.agent_task.cancel()
          self.state = None
          self._cur_step = 0
          self.agent.reset()

      async def set_agent_state_to(self, new_state: AgentState):
          """Move the agent to ``new_state`` and apply that state's side effects.

          RUNNING starts ``_run`` as a task; PAUSED and AWAITING_USER_INPUT
          advance the resume step and cancel the task; STOPPED, ERROR and
          FINISHED reset the task. An ``AgentStateChangedObservation`` is
          published afterwards. Requesting the current state is a no-op.
          """
          logger.info(
              f'Setting agent({type(self.agent).__name__}) state from {self._agent_state} to {new_state}'
          )
          if new_state == self._agent_state:
              return
          self._agent_state = new_state
          if new_state == AgentState.RUNNING:
              self.agent_task = asyncio.create_task(self._run())
          elif new_state in (AgentState.PAUSED, AgentState.AWAITING_USER_INPUT):
              # Resume from the step after the one that was just handled.
              self._cur_step += 1
              if self.agent_task is not None:
                  self.agent_task.cancel()
          elif new_state in (
              AgentState.STOPPED,
              AgentState.ERROR,
              AgentState.FINISHED,
          ):
              await self.reset_task()
          await self.event_stream.add_event(
              AgentStateChangedObservation('', self._agent_state), EventSource.AGENT
          )

      def get_agent_state(self):
          """Returns the current state of the agent task."""
          return self._agent_state

      async def start_delegate(self, action: AgentDelegateAction):
          """Create a child controller for the agent named in ``action`` and set
          up its task from ``action.inputs``.

          The child shares this controller's LLM, event stream and budgets.
          """
          # NOTE(review): the child's __init__ subscribes to the shared event
          # stream under the same AGENT_CONTROLLER key as this controller —
          # confirm EventStream.subscribe tolerates a second subscription.
          AgentCls: Type[Agent] = Agent.get_cls(action.agent)
          agent = AgentCls(llm=self.agent.llm)
          self.delegate = AgentController(
              sid=self.id + '-delegate',
              agent=agent,
              event_stream=self.event_stream,
              max_iterations=self.max_iterations,
              max_chars=self.max_chars,
          )
          task = action.inputs.get('task') or ''
          await self.delegate.setup_task(task, action.inputs)

      def add_iteration_reminder_when_needed(self, i: int, obs: Observation):
          """Add iteration reminder to the observation if needed.

          Args:
              i: The current iteration number (0-indexed).
              obs: The observation to add the reminder to.

          Returns:
              The same observation object, with the reminder appended to its
              content when ``remind_iterations`` is enabled.
          """
          if self.remind_iterations:
              obs.content += f'\n\nENVIRONMENT REMINDER: You have {self.max_iterations - i - 1} turns left to complete the task.'
          return obs

      async def step(self, i: int) -> bool:
          """Run one iteration of the agent.

          Args:
              i: The current iteration number (0-indexed).

          Returns:
              True when the agent emitted ``AgentFinishAction``, else False.

          Raises:
              ValueError: If no task has been set up.
              MaxCharsExceedError: If the state exceeded ``max_chars``.
          """
          if self.state is None:
              raise ValueError('No task to run')
          if self.delegate is not None:
              # While a delegate exists, this controller only forwards the step.
              delegate_done = await self.delegate.step(i)
              if delegate_done:
                  outputs = self.delegate.state.outputs if self.delegate.state else {}
                  obs: Observation = AgentDelegateObservation(content='', outputs=outputs)
                  await self.add_history(NullAction(), obs)
                  self.delegate = None
                  self.delegateAction = None
              return False
          logger.info(f'STEP {i}', extra={'msg_type': 'STEP'})
          if i == 0:
              logger.info(self.state.plan.main_goal, extra={'msg_type': 'PLAN'})
          if self.state.num_of_chars > self.max_chars:
              raise MaxCharsExceedError(self.state.num_of_chars, self.max_chars)
          log_obs = self.runtime.get_background_obs()
          for obs in log_obs:
              await self.add_history(NullAction(), obs)
              logger.info(obs, extra={'msg_type': 'BACKGROUND LOG'})
          self.update_state_for_step(i)
          action: Action = NullAction()
          observation: Observation = NullObservation('')
          try:
              proposed = self.agent.step(self.state)
              if proposed is None:
                  raise AgentNoActionError('No action was returned')
              # Assign only after validation so a failed step keeps the
              # NullAction placeholder and can still be recorded in history.
              action = proposed
          except (AgentMalformedActionError, AgentNoActionError, LLMOutputError) as e:
              observation = ErrorObservation(str(e))
          logger.info(action, extra={'msg_type': 'ACTION'})
          self.update_state_after_step()
          if isinstance(action, AgentFinishAction):
              self.state.outputs = action.outputs  # type: ignore[attr-defined]
              logger.info(action, extra={'msg_type': 'INFO'})
              await self.add_history(action, NullObservation(''))
              return True
          elif isinstance(action, MessageAction) and action.wait_for_response:
              # FIXME: remove this once history is managed outside the agent controller
              await self.add_history(action, NullObservation(''))
              await self.set_agent_state_to(AgentState.AWAITING_USER_INPUT)
              return False
          elif isinstance(action, AgentDelegateAction):
              await self.start_delegate(action)
          elif isinstance(action, AddTaskAction):
              self.state.plan.add_subtask(action.parent, action.goal, action.subtasks)
          elif isinstance(action, ModifyTaskAction):
              self.state.plan.set_subtask_state(action.id, action.state)
          elif not isinstance(observation, ErrorObservation):
              # Only run the action when the agent step itself did not fail.
              observation = await self.runtime.run_action(action)
          observation = self.add_iteration_reminder_when_needed(i, observation)
          if not isinstance(observation, NullObservation):
              logger.info(observation, extra={'msg_type': 'OBSERVATION'})
          await self.add_history(action, observation)
          return False

      def get_state(self):
          """Return the current task state, or None when no task is set up."""
          return self.state

      def _is_stuck(self):
          """Return True when this controller or its delegate looks stuck.

          The controller counts as stuck when the last three history entries
          share an equal action and their observations are either all
          ``NullObservation`` or all ``ErrorObservation``.
          """
          # check if delegate stuck
          if self.delegate and self.delegate._is_stuck():
              return True
          if (
              self.state is None
              or self.state.history is None
              or len(self.state.history) < 3
          ):
              return False
          history = self.state.history
          # if the last three (Action, Observation) tuples are too repetitive
          # the agent got stuck in a loop
          if all(history[-i][0] == history[-3][0] for i in range(1, 3)):
              # it repeats same action, give it a chance, but not if:
              if all(isinstance(history[-i][1], NullObservation) for i in range(1, 4)):
                  # same (Action, NullObservation): like 'think' the same thought over and over
                  logger.debug('Action, NullObservation loop detected')
                  return True
              elif all(
                  isinstance(history[-i][1], ErrorObservation) for i in range(1, 4)
              ):
                  # (NullAction, ErrorObservation): errors coming from an exception
                  # (Action, ErrorObservation): the same action getting an error, even if not necessarily the same error
                  logger.debug('Action, ErrorObservation loop detected')
                  return True
          return False