agent_controller.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299
  1. import asyncio
  2. import traceback
  3. from typing import Optional, Type
  4. from opendevin.controller.agent import Agent
  5. from opendevin.controller.state.state import State
  6. from opendevin.core.config import config
  7. from opendevin.core.exceptions import (
  8. AgentMalformedActionError,
  9. AgentNoActionError,
  10. LLMOutputError,
  11. MaxCharsExceedError,
  12. )
  13. from opendevin.core.logger import opendevin_logger as logger
  14. from opendevin.core.schema import AgentState
  15. from opendevin.events import EventSource, EventStream, EventStreamSubscriber
  16. from opendevin.events.action import (
  17. Action,
  18. AddTaskAction,
  19. AgentDelegateAction,
  20. AgentFinishAction,
  21. ChangeAgentStateAction,
  22. MessageAction,
  23. ModifyTaskAction,
  24. NullAction,
  25. )
  26. from opendevin.events.event import Event
  27. from opendevin.events.observation import (
  28. AgentDelegateObservation,
  29. AgentStateChangedObservation,
  30. CmdOutputObservation,
  31. ErrorObservation,
  32. NullObservation,
  33. Observation,
  34. )
  35. MAX_ITERATIONS = config.max_iterations
  36. MAX_CHARS = config.llm.max_chars
  37. class AgentController:
  38. id: str
  39. agent: Agent
  40. max_iterations: int
  41. event_stream: EventStream
  42. state: State
  43. agent_task: Optional[asyncio.Task] = None
  44. delegate: 'AgentController | None' = None
  45. _pending_action: Action | None = None
  46. def __init__(
  47. self,
  48. agent: Agent,
  49. event_stream: EventStream,
  50. sid: str = 'default',
  51. max_iterations: int = MAX_ITERATIONS,
  52. max_chars: int = MAX_CHARS,
  53. inputs: dict | None = None,
  54. ):
  55. """Initializes a new instance of the AgentController class.
  56. Args:
  57. agent: The agent instance to control.
  58. event_stream: The event stream to publish events to.
  59. sid: The session ID of the agent.
  60. max_iterations: The maximum number of iterations the agent can run.
  61. max_chars: The maximum number of characters the agent can output.
  62. inputs: The initial inputs to the agent.
  63. """
  64. self.id = sid
  65. self.agent = agent
  66. self.state = State(inputs=inputs or {}, max_iterations=max_iterations)
  67. self.event_stream = event_stream
  68. self.event_stream.subscribe(
  69. EventStreamSubscriber.AGENT_CONTROLLER, self.on_event
  70. )
  71. self.max_iterations = max_iterations
  72. self.max_chars = max_chars
  73. self.agent_task = asyncio.create_task(self._start_step_loop())
  74. async def close(self):
  75. if self.agent_task is not None:
  76. self.agent_task.cancel()
  77. self.event_stream.unsubscribe(EventStreamSubscriber.AGENT_CONTROLLER)
  78. await self.set_agent_state_to(AgentState.STOPPED)
  79. def update_state_before_step(self):
  80. self.state.iteration += 1
  81. def update_state_after_step(self):
  82. self.state.updated_info = []
  83. # update metrics especially for cost
  84. self.state.metrics = self.agent.llm.metrics
  85. async def report_error(self, message: str, exception: Exception | None = None):
  86. self.state.error = message
  87. if exception:
  88. self.state.error += f': {str(exception)}'
  89. await self.event_stream.add_event(ErrorObservation(message), EventSource.AGENT)
  90. async def add_history(self, action: Action, observation: Observation):
  91. if isinstance(action, NullAction) and isinstance(observation, NullObservation):
  92. return
  93. self.state.history.append((action, observation))
  94. self.state.updated_info.append((action, observation))
  95. async def _start_step_loop(self):
  96. while True:
  97. try:
  98. await self._step()
  99. except asyncio.CancelledError:
  100. logger.info('AgentController task was cancelled')
  101. break
  102. except Exception as e:
  103. traceback.print_exc()
  104. logger.error(f'Error while running the agent: {e}')
  105. logger.error(traceback.format_exc())
  106. await self.report_error(
  107. 'There was an unexpected error while running the agent', exception=e
  108. )
  109. await self.set_agent_state_to(AgentState.ERROR)
  110. break
  111. await asyncio.sleep(0.1)
  112. async def on_event(self, event: Event):
  113. if isinstance(event, ChangeAgentStateAction):
  114. await self.set_agent_state_to(event.agent_state) # type: ignore
  115. elif isinstance(event, MessageAction):
  116. if event.source == EventSource.USER:
  117. logger.info(event, extra={'msg_type': 'OBSERVATION'})
  118. await self.add_history(event, NullObservation(''))
  119. if self.get_agent_state() != AgentState.RUNNING:
  120. await self.set_agent_state_to(AgentState.RUNNING)
  121. elif event.source == EventSource.AGENT and event.wait_for_response:
  122. logger.info(event, extra={'msg_type': 'ACTION'})
  123. await self.set_agent_state_to(AgentState.AWAITING_USER_INPUT)
  124. elif isinstance(event, AgentDelegateAction):
  125. await self.start_delegate(event)
  126. elif isinstance(event, AddTaskAction):
  127. self.state.root_task.add_subtask(event.parent, event.goal, event.subtasks)
  128. elif isinstance(event, ModifyTaskAction):
  129. self.state.root_task.set_subtask_state(event.task_id, event.state)
  130. elif isinstance(event, AgentFinishAction):
  131. self.state.outputs = event.outputs # type: ignore[attr-defined]
  132. await self.set_agent_state_to(AgentState.FINISHED)
  133. elif isinstance(event, Observation):
  134. if self._pending_action and self._pending_action.id == event.cause:
  135. await self.add_history(self._pending_action, event)
  136. self._pending_action = None
  137. logger.info(event, extra={'msg_type': 'OBSERVATION'})
  138. elif isinstance(event, CmdOutputObservation):
  139. await self.add_history(NullAction(), event)
  140. logger.info(event, extra={'msg_type': 'OBSERVATION'})
  141. def reset_task(self):
  142. self.agent.reset()
  143. async def set_agent_state_to(self, new_state: AgentState):
  144. logger.info(
  145. f'Setting agent({type(self.agent).__name__}) state from {self.state.agent_state} to {new_state}'
  146. )
  147. if new_state == self.state.agent_state:
  148. return
  149. self.state.agent_state = new_state
  150. if new_state == AgentState.STOPPED or new_state == AgentState.ERROR:
  151. self.reset_task()
  152. await self.event_stream.add_event(
  153. AgentStateChangedObservation('', self.state.agent_state), EventSource.AGENT
  154. )
  155. if new_state == AgentState.INIT and self.state.resume_state:
  156. await self.set_agent_state_to(self.state.resume_state)
  157. self.state.resume_state = None
  158. def get_agent_state(self):
  159. """Returns the current state of the agent task."""
  160. return self.state.agent_state
  161. async def start_delegate(self, action: AgentDelegateAction):
  162. AgentCls: Type[Agent] = Agent.get_cls(action.agent)
  163. agent = AgentCls(llm=self.agent.llm)
  164. self.delegate = AgentController(
  165. sid=self.id + '-delegate',
  166. agent=agent,
  167. event_stream=self.event_stream,
  168. max_iterations=self.max_iterations,
  169. max_chars=self.max_chars,
  170. inputs=action.inputs,
  171. )
  172. async def _step(self):
  173. if self.get_agent_state() != AgentState.RUNNING:
  174. logger.debug('waiting for agent to run...')
  175. await asyncio.sleep(1)
  176. return
  177. if self._pending_action:
  178. logger.debug('waiting for pending action: ' + str(self._pending_action))
  179. await asyncio.sleep(1)
  180. return
  181. logger.info(f'STEP {self.state.iteration}', extra={'msg_type': 'STEP'})
  182. if self.state.iteration >= self.max_iterations:
  183. await self.report_error('Agent reached maximum number of iterations')
  184. await self.set_agent_state_to(AgentState.ERROR)
  185. return
  186. if self.delegate is not None:
  187. delegate_done = await self.delegate._step()
  188. if delegate_done:
  189. outputs = self.delegate.state.outputs if self.delegate.state else {}
  190. obs: Observation = AgentDelegateObservation(content='', outputs=outputs)
  191. await self.event_stream.add_event(obs, EventSource.AGENT)
  192. self.delegate = None
  193. self.delegateAction = None
  194. return
  195. if self.state.num_of_chars > self.max_chars:
  196. raise MaxCharsExceedError(self.state.num_of_chars, self.max_chars)
  197. self.update_state_before_step()
  198. action: Action = NullAction()
  199. try:
  200. action = self.agent.step(self.state)
  201. if action is None:
  202. raise AgentNoActionError('No action was returned')
  203. except (AgentMalformedActionError, AgentNoActionError, LLMOutputError) as e:
  204. await self.report_error(str(e))
  205. return
  206. logger.info(action, extra={'msg_type': 'ACTION'})
  207. self.update_state_after_step()
  208. if action.runnable:
  209. self._pending_action = action
  210. else:
  211. await self.add_history(action, NullObservation(''))
  212. if not isinstance(action, NullAction):
  213. await self.event_stream.add_event(action, EventSource.AGENT)
  214. if self._is_stuck():
  215. await self.report_error('Agent got stuck in a loop')
  216. await self.set_agent_state_to(AgentState.ERROR)
  217. def get_state(self):
  218. return self.state
  219. def set_state(self, state: State):
  220. self.state = state
  221. def _is_stuck(self):
  222. # check if delegate stuck
  223. if self.delegate and self.delegate._is_stuck():
  224. return True
  225. # filter out MessageAction with source='user' from history
  226. filtered_history = [
  227. _tuple
  228. for _tuple in self.state.history
  229. if not (
  230. isinstance(_tuple[0], MessageAction)
  231. and _tuple[0].source == EventSource.USER
  232. )
  233. ]
  234. if len(filtered_history) < 4:
  235. return False
  236. # Check if the last four (Action, Observation) tuples are too repetitive
  237. last_four_tuples = filtered_history[-4:]
  238. if all(last_four_tuples[-1] == _tuple for _tuple in last_four_tuples):
  239. logger.warning('Action, Observation loop detected')
  240. return True
  241. if all(last_four_tuples[-1][0] == _tuple[0] for _tuple in last_four_tuples):
  242. # It repeats the same action, give it a chance, but not if:
  243. if all(
  244. isinstance(_tuple[1], ErrorObservation) for _tuple in last_four_tuples
  245. ):
  246. logger.warning('Action, ErrorObservation loop detected')
  247. return True
  248. # check if the agent repeats the same (Action, Observation)
  249. # every other step in the last six tuples
  250. if len(filtered_history) >= 6:
  251. last_six_tuples = filtered_history[-6:]
  252. if (
  253. last_six_tuples[-1] == last_six_tuples[-3] == last_six_tuples[-5]
  254. and last_six_tuples[-2] == last_six_tuples[-4] == last_six_tuples[-6]
  255. ):
  256. logger.warning('Action, Observation pattern detected')
  257. return True
  258. return False