agent_controller.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419
  1. import asyncio
  2. import traceback
  3. from typing import Optional, Type
  4. from opendevin.controller.agent import Agent
  5. from opendevin.controller.state.state import State
  6. from opendevin.core.config import config
  7. from opendevin.core.exceptions import (
  8. AgentMalformedActionError,
  9. AgentNoActionError,
  10. LLMOutputError,
  11. MaxCharsExceedError,
  12. )
  13. from opendevin.core.logger import opendevin_logger as logger
  14. from opendevin.core.schema import AgentState
  15. from opendevin.events import EventSource, EventStream, EventStreamSubscriber
  16. from opendevin.events.action import (
  17. Action,
  18. AddTaskAction,
  19. AgentDelegateAction,
  20. AgentFinishAction,
  21. AgentRejectAction,
  22. ChangeAgentStateAction,
  23. MessageAction,
  24. ModifyTaskAction,
  25. NullAction,
  26. )
  27. from opendevin.events.action.commands import CmdKillAction
  28. from opendevin.events.event import Event
  29. from opendevin.events.observation import (
  30. AgentDelegateObservation,
  31. AgentStateChangedObservation,
  32. CmdOutputObservation,
  33. ErrorObservation,
  34. NullObservation,
  35. Observation,
  36. )
  37. MAX_ITERATIONS = config.max_iterations
  38. MAX_CHARS = config.llm.max_chars
  39. MAX_BUDGET_PER_TASK = config.max_budget_per_task
  40. class AgentController:
  41. id: str
  42. agent: Agent
  43. max_iterations: int
  44. event_stream: EventStream
  45. state: State
  46. agent_task: Optional[asyncio.Task] = None
  47. parent: 'AgentController | None' = None
  48. delegate: 'AgentController | None' = None
  49. _pending_action: Action | None = None
  50. def __init__(
  51. self,
  52. agent: Agent,
  53. event_stream: EventStream,
  54. sid: str = 'default',
  55. max_iterations: int = MAX_ITERATIONS,
  56. max_chars: int = MAX_CHARS,
  57. max_budget_per_task: float | None = MAX_BUDGET_PER_TASK,
  58. initial_state: State | None = None,
  59. is_delegate: bool = False,
  60. ):
  61. """Initializes a new instance of the AgentController class.
  62. Args:
  63. agent: The agent instance to control.
  64. event_stream: The event stream to publish events to.
  65. sid: The session ID of the agent.
  66. max_iterations: The maximum number of iterations the agent can run.
  67. max_chars: The maximum number of characters the agent can output.
  68. max_budget_per_task: The maximum budget (in USD) allowed per task, beyond which the agent will stop.
  69. initial_state: The initial state of the controller.
  70. is_delegate: Whether this controller is a delegate.
  71. """
  72. self._step_lock = asyncio.Lock()
  73. self.id = sid
  74. self.agent = agent
  75. self.max_chars = max_chars
  76. if initial_state is None:
  77. self.state = State(inputs={}, max_iterations=max_iterations)
  78. else:
  79. self.state = initial_state
  80. self.event_stream = event_stream
  81. self.event_stream.subscribe(
  82. EventStreamSubscriber.AGENT_CONTROLLER, self.on_event, append=is_delegate
  83. )
  84. self.max_budget_per_task = max_budget_per_task
  85. if not is_delegate:
  86. self.agent_task = asyncio.create_task(self._start_step_loop())
  87. async def close(self):
  88. if self.agent_task is not None:
  89. self.agent_task.cancel()
  90. await self.set_agent_state_to(AgentState.STOPPED)
  91. self.event_stream.unsubscribe(EventStreamSubscriber.AGENT_CONTROLLER)
  92. def update_state_before_step(self):
  93. self.state.iteration += 1
  94. async def update_state_after_step(self):
  95. self.state.updated_info = []
  96. # update metrics especially for cost
  97. self.state.metrics = self.agent.llm.metrics
  98. if self.max_budget_per_task is not None:
  99. current_cost = self.state.metrics.accumulated_cost
  100. if current_cost > self.max_budget_per_task:
  101. await self.report_error(
  102. f'Task budget exceeded. Current cost: {current_cost}, Max budget: {self.max_budget_per_task}'
  103. )
  104. await self.set_agent_state_to(AgentState.ERROR)
  105. async def report_error(self, message: str, exception: Exception | None = None):
  106. self.state.error = message
  107. if exception:
  108. self.state.error += f': {str(exception)}'
  109. await self.event_stream.add_event(ErrorObservation(message), EventSource.AGENT)
  110. async def add_history(self, action: Action, observation: Observation):
  111. if isinstance(action, NullAction) and isinstance(observation, NullObservation):
  112. return
  113. self.state.history.append((action, observation))
  114. self.state.updated_info.append((action, observation))
  115. async def _start_step_loop(self):
  116. logger.info(f'[Agent Controller {self.id}] Starting step loop...')
  117. while True:
  118. try:
  119. await self._step()
  120. except asyncio.CancelledError:
  121. logger.info('AgentController task was cancelled')
  122. break
  123. except Exception as e:
  124. traceback.print_exc()
  125. logger.error(f'Error while running the agent: {e}')
  126. logger.error(traceback.format_exc())
  127. await self.report_error(
  128. 'There was an unexpected error while running the agent', exception=e
  129. )
  130. await self.set_agent_state_to(AgentState.ERROR)
  131. break
  132. await asyncio.sleep(0.1)
  133. async def on_event(self, event: Event):
  134. if isinstance(event, ChangeAgentStateAction):
  135. await self.set_agent_state_to(event.agent_state) # type: ignore
  136. elif isinstance(event, MessageAction):
  137. if event.source == EventSource.USER:
  138. logger.info(event, extra={'msg_type': 'OBSERVATION'})
  139. await self.add_history(event, NullObservation(''))
  140. if self.get_agent_state() != AgentState.RUNNING:
  141. await self.set_agent_state_to(AgentState.RUNNING)
  142. elif event.source == EventSource.AGENT and event.wait_for_response:
  143. logger.info(event, extra={'msg_type': 'ACTION'})
  144. await self.set_agent_state_to(AgentState.AWAITING_USER_INPUT)
  145. elif isinstance(event, AgentDelegateAction):
  146. await self.start_delegate(event)
  147. elif isinstance(event, AddTaskAction):
  148. self.state.root_task.add_subtask(event.parent, event.goal, event.subtasks)
  149. elif isinstance(event, ModifyTaskAction):
  150. self.state.root_task.set_subtask_state(event.task_id, event.state)
  151. elif isinstance(event, AgentFinishAction):
  152. self.state.outputs = event.outputs # type: ignore[attr-defined]
  153. await self.set_agent_state_to(AgentState.FINISHED)
  154. elif isinstance(event, AgentRejectAction):
  155. self.state.outputs = event.outputs # type: ignore[attr-defined]
  156. await self.set_agent_state_to(AgentState.REJECTED)
  157. elif isinstance(event, Observation):
  158. if self._pending_action and self._pending_action.id == event.cause:
  159. await self.add_history(self._pending_action, event)
  160. self._pending_action = None
  161. logger.info(event, extra={'msg_type': 'OBSERVATION'})
  162. elif isinstance(event, CmdOutputObservation):
  163. await self.add_history(NullAction(), event)
  164. logger.info(event, extra={'msg_type': 'OBSERVATION'})
  165. elif isinstance(event, AgentDelegateObservation):
  166. await self.add_history(NullAction(), event)
  167. logger.info(event, extra={'msg_type': 'OBSERVATION'})
  168. def reset_task(self):
  169. self.agent.reset()
  170. async def set_agent_state_to(self, new_state: AgentState):
  171. logger.info(
  172. f'[Agent Controller {self.id}] Setting agent({type(self.agent).__name__}) state from {self.state.agent_state} to {new_state}'
  173. )
  174. if new_state == self.state.agent_state:
  175. return
  176. self.state.agent_state = new_state
  177. if new_state == AgentState.STOPPED or new_state == AgentState.ERROR:
  178. self.reset_task()
  179. await self.event_stream.add_event(
  180. AgentStateChangedObservation('', self.state.agent_state), EventSource.AGENT
  181. )
  182. if new_state == AgentState.INIT and self.state.resume_state:
  183. await self.set_agent_state_to(self.state.resume_state)
  184. self.state.resume_state = None
  185. def get_agent_state(self):
  186. """Returns the current state of the agent task."""
  187. return self.state.agent_state
  188. async def start_delegate(self, action: AgentDelegateAction):
  189. AgentCls: Type[Agent] = Agent.get_cls(action.agent)
  190. agent = AgentCls(llm=self.agent.llm)
  191. state = State(
  192. inputs=action.inputs or {},
  193. iteration=0,
  194. max_iterations=self.state.max_iterations,
  195. num_of_chars=self.state.num_of_chars,
  196. delegate_level=self.state.delegate_level + 1,
  197. )
  198. logger.info(f'[Agent Controller {self.id}]: start delegate')
  199. self.delegate = AgentController(
  200. sid=self.id + '-delegate',
  201. agent=agent,
  202. event_stream=self.event_stream,
  203. max_iterations=self.state.max_iterations,
  204. max_chars=self.max_chars,
  205. initial_state=state,
  206. is_delegate=True,
  207. )
  208. await self.delegate.set_agent_state_to(AgentState.RUNNING)
  209. async def _step(self):
  210. logger.debug(f'[Agent Controller {self.id}] Entering step method')
  211. if self.get_agent_state() != AgentState.RUNNING:
  212. await asyncio.sleep(1)
  213. return
  214. if self._pending_action:
  215. logger.info(
  216. f'[Agent Controller {self.id}] waiting for pending action: {self._pending_action}'
  217. )
  218. await asyncio.sleep(1)
  219. return
  220. if self.delegate is not None:
  221. logger.debug(f'[Agent Controller {self.id}] Delegate not none, awaiting...')
  222. assert self.delegate != self
  223. await self.delegate._step()
  224. logger.debug(f'[Agent Controller {self.id}] Delegate step done')
  225. assert self.delegate is not None
  226. delegate_state = self.delegate.get_agent_state()
  227. if delegate_state == AgentState.ERROR:
  228. # close the delegate upon error
  229. await self.delegate.close()
  230. await self.report_error('Delegator agent encounters an error')
  231. # propagate error state until an agent or user can handle it
  232. await self.set_agent_state_to(AgentState.ERROR)
  233. return
  234. delegate_done = delegate_state in (AgentState.FINISHED, AgentState.REJECTED)
  235. if delegate_done:
  236. logger.info(
  237. f'[Agent Controller {self.id}] Delegate agent has finished execution'
  238. )
  239. # retrieve delegate result
  240. outputs = self.delegate.state.outputs if self.delegate.state else {}
  241. # close delegate controller: we must close the delegate controller before adding new events
  242. await self.delegate.close()
  243. # clean up delegate status
  244. self.delegate = None
  245. self.delegateAction = None
  246. # update delegate result observation
  247. obs: Observation = AgentDelegateObservation(outputs=outputs, content='')
  248. await self.event_stream.add_event(obs, EventSource.AGENT)
  249. return
  250. if self.state.num_of_chars > self.max_chars:
  251. raise MaxCharsExceedError(self.state.num_of_chars, self.max_chars)
  252. logger.info(
  253. f'{type(self.agent).__name__} LEVEL {self.state.delegate_level} STEP {self.state.iteration}',
  254. extra={'msg_type': 'STEP'},
  255. )
  256. if self.state.iteration >= self.state.max_iterations:
  257. await self.report_error('Agent reached maximum number of iterations')
  258. await self.set_agent_state_to(AgentState.ERROR)
  259. return
  260. self.update_state_before_step()
  261. action: Action = NullAction()
  262. try:
  263. action = self.agent.step(self.state)
  264. if action is None:
  265. raise AgentNoActionError('No action was returned')
  266. except (AgentMalformedActionError, AgentNoActionError, LLMOutputError) as e:
  267. await self.report_error(str(e))
  268. return
  269. logger.info(action, extra={'msg_type': 'ACTION'})
  270. await self.update_state_after_step()
  271. if action.runnable:
  272. self._pending_action = action
  273. else:
  274. await self.add_history(action, NullObservation(''))
  275. if not isinstance(action, NullAction):
  276. await self.event_stream.add_event(action, EventSource.AGENT)
  277. if self._is_stuck():
  278. await self.report_error('Agent got stuck in a loop')
  279. await self.set_agent_state_to(AgentState.ERROR)
  280. def get_state(self):
  281. return self.state
  282. def set_state(self, state: State):
  283. self.state = state
  284. def _is_stuck(self):
  285. # check if delegate stuck
  286. if self.delegate and self.delegate._is_stuck():
  287. return True
  288. # filter out MessageAction with source='user' from history
  289. filtered_history = [
  290. _tuple
  291. for _tuple in self.state.history
  292. if not (
  293. isinstance(_tuple[0], MessageAction)
  294. and _tuple[0].source == EventSource.USER
  295. )
  296. ]
  297. if len(filtered_history) < 3:
  298. return False
  299. # FIXME rewrite this to be more readable
  300. # Scenario 1: the same (Action, Observation) loop
  301. # 3 pairs of (action, observation) to stop the agent
  302. last_three_tuples = filtered_history[-3:]
  303. if all(
  304. # (Action, Observation) tuples
  305. # compare the last action to the last three actions
  306. self._eq_no_pid(last_three_tuples[-1][0], _tuple[0])
  307. for _tuple in last_three_tuples
  308. ) and all(
  309. # compare the last observation to the last three observations
  310. self._eq_no_pid(last_three_tuples[-1][1], _tuple[1])
  311. for _tuple in last_three_tuples
  312. ):
  313. logger.warning('Action, Observation loop detected')
  314. return True
  315. if len(filtered_history) < 4:
  316. return False
  317. last_four_tuples = filtered_history[-4:]
  318. # Scenario 2: (action, error) pattern, not necessary identical error
  319. # 4 pairs of (action, error) to stop the agent
  320. if all(
  321. self._eq_no_pid(last_four_tuples[-1][0], _tuple[0])
  322. for _tuple in last_four_tuples
  323. ):
  324. # It repeats the same action, give it a chance, but not if:
  325. if all(
  326. isinstance(_tuple[1], ErrorObservation) for _tuple in last_four_tuples
  327. ):
  328. logger.warning('Action, ErrorObservation loop detected')
  329. return True
  330. # check if the agent repeats the same (Action, Observation)
  331. # every other step in the last six tuples
  332. # step1 = step3 = step5
  333. # step2 = step4 = step6
  334. if len(filtered_history) >= 6:
  335. last_six_tuples = filtered_history[-6:]
  336. if (
  337. # this pattern is every other step, like:
  338. # (action_1, obs_1), (action_2, obs_2), (action_1, obs_1), (action_2, obs_2),...
  339. self._eq_no_pid(last_six_tuples[-1][0], last_six_tuples[-3][0])
  340. and self._eq_no_pid(last_six_tuples[-1][0], last_six_tuples[-5][0])
  341. and self._eq_no_pid(last_six_tuples[-2][0], last_six_tuples[-4][0])
  342. and self._eq_no_pid(last_six_tuples[-2][0], last_six_tuples[-6][0])
  343. and self._eq_no_pid(last_six_tuples[-1][1], last_six_tuples[-3][1])
  344. and self._eq_no_pid(last_six_tuples[-1][1], last_six_tuples[-5][1])
  345. and self._eq_no_pid(last_six_tuples[-2][1], last_six_tuples[-4][1])
  346. and self._eq_no_pid(last_six_tuples[-2][1], last_six_tuples[-6][1])
  347. ):
  348. logger.warning('Action, Observation pattern detected')
  349. return True
  350. return False
  351. def __repr__(self):
  352. return (
  353. f'AgentController(id={self.id}, agent={self.agent!r}, '
  354. f'event_stream={self.event_stream!r}, '
  355. f'state={self.state!r}, agent_task={self.agent_task!r}, '
  356. f'delegate={self.delegate!r}, _pending_action={self._pending_action!r})'
  357. )
  358. def _eq_no_pid(self, obj1, obj2):
  359. if isinstance(obj1, CmdOutputObservation) and isinstance(
  360. obj2, CmdOutputObservation
  361. ):
  362. # for loop detection, ignore command_id, which is the pid
  363. return obj1.command == obj2.command and obj1.exit_code == obj2.exit_code
  364. elif isinstance(obj1, CmdKillAction) and isinstance(obj2, CmdKillAction):
  365. # for loop detection, ignore command_id, which is the pid
  366. return obj1.thought == obj2.thought
  367. else:
  368. # this is the default comparison
  369. return obj1 == obj2