# agent_controller.py
  1. import asyncio
  2. from typing import Optional, Type
  3. from agenthub.codeact_agent.codeact_agent import CodeActAgent
  4. from opendevin.controller.agent import Agent
  5. from opendevin.controller.state.state import State
  6. from opendevin.core.config import config
  7. from opendevin.core.exceptions import (
  8. AgentMalformedActionError,
  9. AgentNoActionError,
  10. LLMOutputError,
  11. MaxCharsExceedError,
  12. )
  13. from opendevin.core.logger import opendevin_logger as logger
  14. from opendevin.core.schema import AgentState
  15. from opendevin.events.action import (
  16. Action,
  17. AddTaskAction,
  18. AgentDelegateAction,
  19. AgentFinishAction,
  20. ChangeAgentStateAction,
  21. MessageAction,
  22. ModifyTaskAction,
  23. NullAction,
  24. )
  25. from opendevin.events.event import Event
  26. from opendevin.events.observation import (
  27. AgentDelegateObservation,
  28. AgentStateChangedObservation,
  29. ErrorObservation,
  30. NullObservation,
  31. Observation,
  32. )
  33. from opendevin.events.stream import EventSource, EventStream, EventStreamSubscriber
  34. from opendevin.runtime import DockerSSHBox, Sandbox
  35. from opendevin.runtime.runtime import Runtime
  36. from opendevin.runtime.server.runtime import ServerRuntime
  37. MAX_ITERATIONS = config.max_iterations
  38. MAX_CHARS = config.llm.max_chars
  39. class AgentController:
  40. id: str
  41. agent: Agent
  42. max_iterations: int
  43. runtime: Runtime
  44. event_stream: EventStream
  45. state: State
  46. agent_task: Optional[asyncio.Task] = None
  47. delegate: 'AgentController | None' = None
  48. _agent_state: AgentState = AgentState.LOADING
  49. _cur_step: int = 0
  50. def __init__(
  51. self,
  52. agent: Agent,
  53. event_stream: EventStream,
  54. sid: str = 'default',
  55. max_iterations: int = MAX_ITERATIONS,
  56. max_chars: int = MAX_CHARS,
  57. inputs: dict | None = None,
  58. sandbox: Optional[Sandbox] = None,
  59. ):
  60. """Initializes a new instance of the AgentController class.
  61. Args:
  62. agent: The agent instance to control.
  63. event_stream: The event stream to publish events to.
  64. sid: The session ID of the agent.
  65. max_iterations: The maximum number of iterations the agent can run.
  66. max_chars: The maximum number of characters the agent can output.
  67. inputs: The initial inputs to the agent.
  68. sandbox: An optional initialized sandbox to run the agent in. If not provided, a default sandbox will be created based on config.
  69. """
  70. self.id = sid
  71. self.agent = agent
  72. self.state = State(inputs=inputs or {})
  73. self.event_stream = event_stream
  74. self.event_stream.subscribe(
  75. EventStreamSubscriber.AGENT_CONTROLLER, self.on_event
  76. )
  77. self.max_iterations = max_iterations
  78. self.runtime = ServerRuntime(sandbox=sandbox, sid=self.id)
  79. self.max_chars = max_chars
  80. # Initialize agent-required plugins for sandbox (if any)
  81. self.runtime.init_sandbox_plugins(agent.sandbox_plugins)
  82. if isinstance(agent, CodeActAgent) and not isinstance(
  83. self.runtime.sandbox, DockerSSHBox
  84. ):
  85. logger.warning(
  86. 'CodeActAgent requires DockerSSHBox as sandbox! Using other sandbox that are not stateful (LocalBox, DockerExecBox) will not work properly.'
  87. )
  88. async def close(self):
  89. if self.agent_task is not None:
  90. self.agent_task.cancel()
  91. self.event_stream.unsubscribe(EventStreamSubscriber.AGENT_CONTROLLER)
  92. self.runtime.sandbox.close()
  93. self.runtime.browser.close()
  94. await self.set_agent_state_to(AgentState.STOPPED)
  95. def update_state_for_step(self, i):
  96. self.state.iteration = i
  97. self.state.background_commands_obs = self.runtime.get_background_obs()
  98. def update_state_after_step(self):
  99. self.state.updated_info = []
  100. async def add_error_to_history(self, message: str):
  101. await self.add_history(NullAction(), ErrorObservation(message))
  102. async def add_history(
  103. self, action: Action, observation: Observation, add_to_stream=True
  104. ):
  105. if not isinstance(action, Action):
  106. raise TypeError(
  107. f'action must be an instance of Action, got {type(action).__name__} instead'
  108. )
  109. if not isinstance(observation, Observation):
  110. raise TypeError(
  111. f'observation must be an instance of Observation, got {type(observation).__name__} instead'
  112. )
  113. self.state.history.append((action, observation))
  114. self.state.updated_info.append((action, observation))
  115. if add_to_stream:
  116. await self.event_stream.add_event(action, EventSource.AGENT)
  117. await self.event_stream.add_event(observation, EventSource.AGENT)
  118. async def _run(self):
  119. if self._agent_state != AgentState.RUNNING:
  120. raise ValueError('Task is not in running state')
  121. for i in range(self._cur_step, self.max_iterations):
  122. self._cur_step = i
  123. try:
  124. finished = await self.step(i)
  125. if finished:
  126. await self.set_agent_state_to(AgentState.FINISHED)
  127. break
  128. except Exception:
  129. logger.error('Error in loop', exc_info=True)
  130. await self.set_agent_state_to(AgentState.ERROR)
  131. await self.add_error_to_history(
  132. 'Oops! Something went wrong while completing your task. You can check the logs for more info.'
  133. )
  134. break
  135. if self._is_stuck():
  136. logger.info('Loop detected, stopping task')
  137. await self.set_agent_state_to(AgentState.ERROR)
  138. await self.add_error_to_history(
  139. 'I got stuck into a loop, the task has stopped.'
  140. )
  141. break
  142. await asyncio.sleep(
  143. 0.001
  144. ) # Give back control for a tick, so other async stuff can run
  145. final_state = self.get_agent_state()
  146. if final_state == AgentState.RUNNING:
  147. await self.set_agent_state_to(AgentState.PAUSED)
  148. async def on_event(self, event: Event):
  149. if isinstance(event, ChangeAgentStateAction):
  150. await self.set_agent_state_to(event.agent_state) # type: ignore
  151. elif isinstance(event, MessageAction) and event.source == EventSource.USER:
  152. await self.add_history(event, NullObservation(''), add_to_stream=False)
  153. if self.get_agent_state() != AgentState.RUNNING:
  154. await self.set_agent_state_to(AgentState.RUNNING)
  155. async def reset_task(self):
  156. if self.agent_task is not None:
  157. self.agent_task.cancel()
  158. self.state = State()
  159. self._cur_step = 0
  160. self.agent.reset()
  161. async def set_agent_state_to(self, new_state: AgentState):
  162. logger.info(
  163. f'Setting agent({type(self.agent).__name__}) state from {self._agent_state} to {new_state}'
  164. )
  165. if new_state == self._agent_state:
  166. return
  167. self._agent_state = new_state
  168. if new_state == AgentState.RUNNING:
  169. self.agent_task = asyncio.create_task(self._run())
  170. elif (
  171. new_state == AgentState.PAUSED
  172. or new_state == AgentState.AWAITING_USER_INPUT
  173. ):
  174. self._cur_step += 1
  175. if self.agent_task is not None:
  176. self.agent_task.cancel()
  177. elif new_state == AgentState.STOPPED or new_state == AgentState.ERROR:
  178. await self.reset_task()
  179. await self.event_stream.add_event(
  180. AgentStateChangedObservation('', self._agent_state), EventSource.AGENT
  181. )
  182. def get_agent_state(self):
  183. """Returns the current state of the agent task."""
  184. return self._agent_state
  185. async def start_delegate(self, action: AgentDelegateAction):
  186. AgentCls: Type[Agent] = Agent.get_cls(action.agent)
  187. agent = AgentCls(llm=self.agent.llm)
  188. self.delegate = AgentController(
  189. sid=self.id + '-delegate',
  190. agent=agent,
  191. event_stream=self.event_stream,
  192. max_iterations=self.max_iterations,
  193. max_chars=self.max_chars,
  194. inputs=action.inputs,
  195. )
  196. async def step(self, i: int) -> bool:
  197. if self.delegate is not None:
  198. delegate_done = await self.delegate.step(i)
  199. if delegate_done:
  200. outputs = self.delegate.state.outputs if self.delegate.state else {}
  201. obs: Observation = AgentDelegateObservation(content='', outputs=outputs)
  202. await self.add_history(NullAction(), obs)
  203. self.delegate = None
  204. self.delegateAction = None
  205. return False
  206. logger.info(f'STEP {i}', extra={'msg_type': 'STEP'})
  207. if self.state.num_of_chars > self.max_chars:
  208. raise MaxCharsExceedError(self.state.num_of_chars, self.max_chars)
  209. log_obs = self.runtime.get_background_obs()
  210. for obs in log_obs:
  211. await self.add_history(NullAction(), obs)
  212. logger.info(obs, extra={'msg_type': 'BACKGROUND LOG'})
  213. self.update_state_for_step(i)
  214. action: Action = NullAction()
  215. observation: Observation = NullObservation('')
  216. try:
  217. action = self.agent.step(self.state)
  218. if action is None:
  219. raise AgentNoActionError('No action was returned')
  220. except (AgentMalformedActionError, AgentNoActionError, LLMOutputError) as e:
  221. observation = ErrorObservation(str(e))
  222. logger.info(action, extra={'msg_type': 'ACTION'})
  223. self.update_state_after_step()
  224. if isinstance(action, AgentFinishAction):
  225. self.state.outputs = action.outputs # type: ignore[attr-defined]
  226. logger.info(action, extra={'msg_type': 'INFO'})
  227. await self.add_history(action, NullObservation(''))
  228. return True
  229. elif isinstance(action, MessageAction) and action.wait_for_response:
  230. # FIXME: remove this once history is managed outside the agent controller
  231. await self.add_history(action, NullObservation(''))
  232. await self.set_agent_state_to(AgentState.AWAITING_USER_INPUT)
  233. return False
  234. elif isinstance(action, AgentDelegateAction):
  235. await self.start_delegate(action)
  236. elif isinstance(action, AddTaskAction):
  237. self.state.root_task.add_subtask(
  238. action.parent, action.goal, action.subtasks
  239. )
  240. elif isinstance(action, ModifyTaskAction):
  241. self.state.root_task.set_subtask_state(action.id, action.state)
  242. elif not isinstance(observation, ErrorObservation):
  243. observation = await self.runtime.run_action(action)
  244. if not isinstance(observation, NullObservation):
  245. logger.info(observation, extra={'msg_type': 'OBSERVATION'})
  246. await self.add_history(action, observation)
  247. return False
  248. def get_state(self):
  249. return self.state
  250. def _is_stuck(self):
  251. # check if delegate stuck
  252. if self.delegate and self.delegate._is_stuck():
  253. return True
  254. if len(self.state.history) < 3:
  255. return False
  256. # if the last three (Action, Observation) tuples are too repetitive
  257. # the agent got stuck in a loop
  258. if all(
  259. [
  260. self.state.history[-i][0] == self.state.history[-3][0]
  261. for i in range(1, 3)
  262. ]
  263. ):
  264. # it repeats same action, give it a chance, but not if:
  265. if all(
  266. isinstance(self.state.history[-i][1], NullObservation)
  267. for i in range(1, 4)
  268. ):
  269. # same (Action, NullObservation): like 'think' the same thought over and over
  270. logger.debug('Action, NullObservation loop detected')
  271. return True
  272. elif all(
  273. isinstance(self.state.history[-i][1], ErrorObservation)
  274. for i in range(1, 4)
  275. ):
  276. # (NullAction, ErrorObservation): errors coming from an exception
  277. # (Action, ErrorObservation): the same action getting an error, even if not necessarily the same error
  278. logger.debug('Action, ErrorObservation loop detected')
  279. return True
  280. return False