agent_controller.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477
  1. import asyncio
  2. import traceback
  3. from typing import Optional, Type
  4. from opendevin.controller.agent import Agent
  5. from opendevin.controller.state.state import State, TrafficControlState
  6. from opendevin.core.config import config
  7. from opendevin.core.exceptions import (
  8. LLMMalformedActionError,
  9. LLMNoActionError,
  10. LLMResponseError,
  11. )
  12. from opendevin.core.logger import opendevin_logger as logger
  13. from opendevin.core.schema import AgentState
  14. from opendevin.events import EventSource, EventStream, EventStreamSubscriber
  15. from opendevin.events.action import (
  16. Action,
  17. AddTaskAction,
  18. AgentDelegateAction,
  19. AgentFinishAction,
  20. AgentRejectAction,
  21. ChangeAgentStateAction,
  22. MessageAction,
  23. ModifyTaskAction,
  24. NullAction,
  25. )
  26. from opendevin.events.action.commands import CmdKillAction
  27. from opendevin.events.event import Event
  28. from opendevin.events.observation import (
  29. AgentDelegateObservation,
  30. AgentStateChangedObservation,
  31. CmdOutputObservation,
  32. ErrorObservation,
  33. NullObservation,
  34. Observation,
  35. )
  36. MAX_ITERATIONS = config.max_iterations
  37. MAX_BUDGET_PER_TASK = config.max_budget_per_task
  38. # note: RESUME is only available on web GUI
  39. TRAFFIC_CONTROL_REMINDER = (
  40. "Please click on resume button if you'd like to continue, or start a new task."
  41. )
  42. class AgentController:
  43. id: str
  44. agent: Agent
  45. max_iterations: int
  46. event_stream: EventStream
  47. state: State
  48. agent_task: Optional[asyncio.Task] = None
  49. parent: 'AgentController | None' = None
  50. delegate: 'AgentController | None' = None
  51. _pending_action: Action | None = None
  52. def __init__(
  53. self,
  54. agent: Agent,
  55. event_stream: EventStream,
  56. sid: str = 'default',
  57. max_iterations: int | None = MAX_ITERATIONS,
  58. max_budget_per_task: float | None = MAX_BUDGET_PER_TASK,
  59. initial_state: State | None = None,
  60. is_delegate: bool = False,
  61. ):
  62. """Initializes a new instance of the AgentController class.
  63. Args:
  64. agent: The agent instance to control.
  65. event_stream: The event stream to publish events to.
  66. sid: The session ID of the agent.
  67. max_iterations: The maximum number of iterations the agent can run.
  68. max_budget_per_task: The maximum budget (in USD) allowed per task, beyond which the agent will stop.
  69. initial_state: The initial state of the controller.
  70. is_delegate: Whether this controller is a delegate.
  71. """
  72. self._step_lock = asyncio.Lock()
  73. self.id = sid
  74. self.agent = agent
  75. # subscribe to the event stream
  76. self.event_stream = event_stream
  77. self.event_stream.subscribe(
  78. EventStreamSubscriber.AGENT_CONTROLLER, self.on_event, append=is_delegate
  79. )
  80. # state from the previous session, state from a parent agent, or a fresh state
  81. max_iterations = (
  82. max_iterations if max_iterations is not None else MAX_ITERATIONS
  83. )
  84. self.set_initial_state(
  85. state=initial_state,
  86. max_iterations=max_iterations,
  87. )
  88. self.max_budget_per_task = max_budget_per_task
  89. if not is_delegate:
  90. self.agent_task = asyncio.create_task(self._start_step_loop())
  91. async def close(self):
  92. if self.agent_task is not None:
  93. self.agent_task.cancel()
  94. await self.set_agent_state_to(AgentState.STOPPED)
  95. self.event_stream.unsubscribe(EventStreamSubscriber.AGENT_CONTROLLER)
  96. def update_state_before_step(self):
  97. self.state.iteration += 1
  98. async def update_state_after_step(self):
  99. # update metrics especially for cost
  100. self.state.metrics = self.agent.llm.metrics
  101. async def report_error(self, message: str, exception: Exception | None = None):
  102. """
  103. This error will be reported to the user and sent to the LLM next step, in the hope it can self-correct.
  104. This method should be called for a particular type of errors:
  105. - the string message should be user-friendly, it will be shown in the UI
  106. - an ErrorObservation can be sent to the LLM by the agent, with the exception message, so it can self-correct next time
  107. """
  108. self.state.last_error = message
  109. if exception:
  110. self.state.last_error += f': {exception}'
  111. self.event_stream.add_event(ErrorObservation(message), EventSource.AGENT)
  112. async def add_history(self, action: Action, observation: Observation):
  113. if isinstance(action, NullAction) and isinstance(observation, NullObservation):
  114. return
  115. self.state.history.append((action, observation))
  116. async def _start_step_loop(self):
  117. logger.info(f'[Agent Controller {self.id}] Starting step loop...')
  118. while True:
  119. try:
  120. await self._step()
  121. except asyncio.CancelledError:
  122. logger.info('AgentController task was cancelled')
  123. break
  124. except Exception as e:
  125. traceback.print_exc()
  126. logger.error(f'Error while running the agent: {e}')
  127. logger.error(traceback.format_exc())
  128. await self.report_error(
  129. 'There was an unexpected error while running the agent', exception=e
  130. )
  131. await self.set_agent_state_to(AgentState.ERROR)
  132. break
  133. await asyncio.sleep(0.1)
  134. async def on_event(self, event: Event):
  135. if isinstance(event, ChangeAgentStateAction):
  136. await self.set_agent_state_to(event.agent_state) # type: ignore
  137. elif isinstance(event, MessageAction):
  138. if event.source == EventSource.USER:
  139. logger.info(event, extra={'msg_type': 'OBSERVATION'})
  140. await self.add_history(event, NullObservation(''))
  141. if self.get_agent_state() != AgentState.RUNNING:
  142. await self.set_agent_state_to(AgentState.RUNNING)
  143. elif event.source == EventSource.AGENT and event.wait_for_response:
  144. logger.info(event, extra={'msg_type': 'ACTION'})
  145. await self.set_agent_state_to(AgentState.AWAITING_USER_INPUT)
  146. elif isinstance(event, AgentDelegateAction):
  147. await self.start_delegate(event)
  148. elif isinstance(event, AddTaskAction):
  149. self.state.root_task.add_subtask(event.parent, event.goal, event.subtasks)
  150. elif isinstance(event, ModifyTaskAction):
  151. self.state.root_task.set_subtask_state(event.task_id, event.state)
  152. elif isinstance(event, AgentFinishAction):
  153. self.state.outputs = event.outputs # type: ignore[attr-defined]
  154. await self.set_agent_state_to(AgentState.FINISHED)
  155. elif isinstance(event, AgentRejectAction):
  156. self.state.outputs = event.outputs # type: ignore[attr-defined]
  157. await self.set_agent_state_to(AgentState.REJECTED)
  158. elif isinstance(event, Observation):
  159. if self._pending_action and self._pending_action.id == event.cause:
  160. await self.add_history(self._pending_action, event)
  161. self._pending_action = None
  162. logger.info(event, extra={'msg_type': 'OBSERVATION'})
  163. elif isinstance(event, CmdOutputObservation):
  164. await self.add_history(NullAction(), event)
  165. logger.info(event, extra={'msg_type': 'OBSERVATION'})
  166. elif isinstance(event, AgentDelegateObservation):
  167. await self.add_history(NullAction(), event)
  168. logger.info(event, extra={'msg_type': 'OBSERVATION'})
  169. elif isinstance(event, ErrorObservation):
  170. await self.add_history(NullAction(), event)
  171. logger.info(event, extra={'msg_type': 'OBSERVATION'})
  172. def reset_task(self):
  173. self.agent.reset()
  174. async def set_agent_state_to(self, new_state: AgentState):
  175. logger.debug(
  176. f'[Agent Controller {self.id}] Setting agent({self.agent.name}) state from {self.state.agent_state} to {new_state}'
  177. )
  178. if new_state == self.state.agent_state:
  179. return
  180. if (
  181. self.state.agent_state == AgentState.PAUSED
  182. and new_state == AgentState.RUNNING
  183. and self.state.traffic_control_state == TrafficControlState.THROTTLING
  184. ):
  185. # user intends to interrupt traffic control and let the task resume temporarily
  186. self.state.traffic_control_state = TrafficControlState.PAUSED
  187. self.state.agent_state = new_state
  188. if new_state == AgentState.STOPPED or new_state == AgentState.ERROR:
  189. self.reset_task()
  190. self.event_stream.add_event(
  191. AgentStateChangedObservation('', self.state.agent_state), EventSource.AGENT
  192. )
  193. if new_state == AgentState.INIT and self.state.resume_state:
  194. await self.set_agent_state_to(self.state.resume_state)
  195. self.state.resume_state = None
  196. def get_agent_state(self):
  197. """Returns the current state of the agent task."""
  198. return self.state.agent_state
  199. async def start_delegate(self, action: AgentDelegateAction):
  200. agent_cls: Type[Agent] = Agent.get_cls(action.agent)
  201. agent = agent_cls(llm=self.agent.llm)
  202. state = State(
  203. inputs=action.inputs or {},
  204. iteration=0,
  205. max_iterations=self.state.max_iterations,
  206. delegate_level=self.state.delegate_level + 1,
  207. # metrics should be shared between parent and child
  208. metrics=self.state.metrics,
  209. )
  210. logger.info(f'[Agent Controller {self.id}]: start delegate')
  211. self.delegate = AgentController(
  212. sid=self.id + '-delegate',
  213. agent=agent,
  214. event_stream=self.event_stream,
  215. max_iterations=self.state.max_iterations,
  216. max_budget_per_task=self.max_budget_per_task,
  217. initial_state=state,
  218. is_delegate=True,
  219. )
  220. await self.delegate.set_agent_state_to(AgentState.RUNNING)
  221. async def _step(self):
  222. logger.debug(f'[Agent Controller {self.id}] Entering step method')
  223. if self.get_agent_state() != AgentState.RUNNING:
  224. await asyncio.sleep(1)
  225. return
  226. if self._pending_action:
  227. logger.debug(
  228. f'[Agent Controller {self.id}] waiting for pending action: {self._pending_action}'
  229. )
  230. await asyncio.sleep(1)
  231. return
  232. if self.delegate is not None:
  233. logger.debug(f'[Agent Controller {self.id}] Delegate not none, awaiting...')
  234. assert self.delegate != self
  235. await self.delegate._step()
  236. logger.debug(f'[Agent Controller {self.id}] Delegate step done')
  237. assert self.delegate is not None
  238. delegate_state = self.delegate.get_agent_state()
  239. if delegate_state == AgentState.ERROR:
  240. # close the delegate upon error
  241. await self.delegate.close()
  242. self.delegate = None
  243. self.delegateAction = None
  244. await self.report_error('Delegator agent encounters an error')
  245. return
  246. delegate_done = delegate_state in (AgentState.FINISHED, AgentState.REJECTED)
  247. if delegate_done:
  248. logger.info(
  249. f'[Agent Controller {self.id}] Delegate agent has finished execution'
  250. )
  251. # retrieve delegate result
  252. outputs = self.delegate.state.outputs if self.delegate.state else {}
  253. # close delegate controller: we must close the delegate controller before adding new events
  254. await self.delegate.close()
  255. # update delegate result observation
  256. # TODO: replace this with AI-generated summary (#2395)
  257. formatted_output = ', '.join(
  258. f'{key}: {value}' for key, value in outputs.items()
  259. )
  260. content = (
  261. f'{self.delegate.agent.name} finishes task with {formatted_output}'
  262. )
  263. obs: Observation = AgentDelegateObservation(
  264. outputs=outputs, content=content
  265. )
  266. # clean up delegate status
  267. self.delegate = None
  268. self.delegateAction = None
  269. self.event_stream.add_event(obs, EventSource.AGENT)
  270. return
  271. logger.info(
  272. f'{self.agent.name} LEVEL {self.state.delegate_level} STEP {self.state.iteration}',
  273. extra={'msg_type': 'STEP'},
  274. )
  275. if self.state.iteration >= self.state.max_iterations:
  276. if self.state.traffic_control_state == TrafficControlState.PAUSED:
  277. logger.info(
  278. 'Hitting traffic control, temporarily resume upon user request'
  279. )
  280. self.state.traffic_control_state = TrafficControlState.NORMAL
  281. else:
  282. self.state.traffic_control_state = TrafficControlState.THROTTLING
  283. await self.report_error(
  284. f'Agent reached maximum number of iterations, task paused. {TRAFFIC_CONTROL_REMINDER}'
  285. )
  286. await self.set_agent_state_to(AgentState.PAUSED)
  287. return
  288. elif self.max_budget_per_task is not None:
  289. current_cost = self.state.metrics.accumulated_cost
  290. if current_cost > self.max_budget_per_task:
  291. if self.state.traffic_control_state == TrafficControlState.PAUSED:
  292. logger.info(
  293. 'Hitting traffic control, temporarily resume upon user request'
  294. )
  295. self.state.traffic_control_state = TrafficControlState.NORMAL
  296. else:
  297. self.state.traffic_control_state = TrafficControlState.THROTTLING
  298. await self.report_error(
  299. f'Task budget exceeded. Current cost: {current_cost:.2f}, Max budget: {self.max_budget_per_task:.2f}, task paused. {TRAFFIC_CONTROL_REMINDER}'
  300. )
  301. await self.set_agent_state_to(AgentState.PAUSED)
  302. return
  303. self.update_state_before_step()
  304. action: Action = NullAction()
  305. try:
  306. action = self.agent.step(self.state)
  307. if action is None:
  308. raise LLMNoActionError('No action was returned')
  309. except (LLMMalformedActionError, LLMNoActionError, LLMResponseError) as e:
  310. # report to the user
  311. # and send the underlying exception to the LLM for self-correction
  312. await self.report_error(str(e))
  313. return
  314. logger.info(action, extra={'msg_type': 'ACTION'})
  315. if action.runnable:
  316. self._pending_action = action
  317. else:
  318. await self.add_history(action, NullObservation(''))
  319. if not isinstance(action, NullAction):
  320. self.event_stream.add_event(action, EventSource.AGENT)
  321. await self.update_state_after_step()
  322. if self._is_stuck():
  323. await self.report_error('Agent got stuck in a loop')
  324. await self.set_agent_state_to(AgentState.ERROR)
  325. def get_state(self):
  326. return self.state
  327. def set_initial_state(
  328. self, state: State | None, max_iterations: int = MAX_ITERATIONS
  329. ):
  330. # state from the previous session, state from a parent agent, or a new state
  331. # note that this is called twice when restoring a previous session, first with state=None
  332. if state is None:
  333. self.state = State(inputs={}, max_iterations=max_iterations)
  334. else:
  335. self.state = state
  336. def _is_stuck(self):
  337. # check if delegate stuck
  338. if self.delegate and self.delegate._is_stuck():
  339. return True
  340. # filter out MessageAction with source='user' from history
  341. filtered_history = [
  342. _tuple
  343. for _tuple in self.state.history
  344. if not (
  345. isinstance(_tuple[0], MessageAction)
  346. and _tuple[0].source == EventSource.USER
  347. )
  348. ]
  349. if len(filtered_history) < 3:
  350. return False
  351. # FIXME rewrite this to be more readable
  352. # Scenario 1: the same (Action, Observation) loop
  353. # 3 pairs of (action, observation) to stop the agent
  354. last_three_tuples = filtered_history[-3:]
  355. if all(
  356. # (Action, Observation) tuples
  357. # compare the last action to the last three actions
  358. self._eq_no_pid(last_three_tuples[-1][0], _tuple[0])
  359. for _tuple in last_three_tuples
  360. ) and all(
  361. # compare the last observation to the last three observations
  362. self._eq_no_pid(last_three_tuples[-1][1], _tuple[1])
  363. for _tuple in last_three_tuples
  364. ):
  365. logger.warning('Action, Observation loop detected')
  366. return True
  367. if len(filtered_history) < 4:
  368. return False
  369. last_four_tuples = filtered_history[-4:]
  370. # Scenario 2: (action, error) pattern, not necessary identical error
  371. # 4 pairs of (action, error) to stop the agent
  372. if all(
  373. self._eq_no_pid(last_four_tuples[-1][0], _tuple[0])
  374. for _tuple in last_four_tuples
  375. ):
  376. # It repeats the same action, give it a chance, but not if:
  377. if all(
  378. isinstance(_tuple[1], ErrorObservation) for _tuple in last_four_tuples
  379. ):
  380. logger.warning('Action, ErrorObservation loop detected')
  381. return True
  382. # check if the agent repeats the same (Action, Observation)
  383. # every other step in the last six tuples
  384. # step1 = step3 = step5
  385. # step2 = step4 = step6
  386. if len(filtered_history) >= 6:
  387. last_six_tuples = filtered_history[-6:]
  388. if (
  389. # this pattern is every other step, like:
  390. # (action_1, obs_1), (action_2, obs_2), (action_1, obs_1), (action_2, obs_2),...
  391. self._eq_no_pid(last_six_tuples[-1][0], last_six_tuples[-3][0])
  392. and self._eq_no_pid(last_six_tuples[-1][0], last_six_tuples[-5][0])
  393. and self._eq_no_pid(last_six_tuples[-2][0], last_six_tuples[-4][0])
  394. and self._eq_no_pid(last_six_tuples[-2][0], last_six_tuples[-6][0])
  395. and self._eq_no_pid(last_six_tuples[-1][1], last_six_tuples[-3][1])
  396. and self._eq_no_pid(last_six_tuples[-1][1], last_six_tuples[-5][1])
  397. and self._eq_no_pid(last_six_tuples[-2][1], last_six_tuples[-4][1])
  398. and self._eq_no_pid(last_six_tuples[-2][1], last_six_tuples[-6][1])
  399. ):
  400. logger.warning('Action, Observation pattern detected')
  401. return True
  402. return False
  403. def __repr__(self):
  404. return (
  405. f'AgentController(id={self.id}, agent={self.agent!r}, '
  406. f'event_stream={self.event_stream!r}, '
  407. f'state={self.state!r}, agent_task={self.agent_task!r}, '
  408. f'delegate={self.delegate!r}, _pending_action={self._pending_action!r})'
  409. )
  410. def _eq_no_pid(self, obj1, obj2):
  411. if isinstance(obj1, CmdOutputObservation) and isinstance(
  412. obj2, CmdOutputObservation
  413. ):
  414. # for loop detection, ignore command_id, which is the pid
  415. return obj1.command == obj2.command and obj1.exit_code == obj2.exit_code
  416. elif isinstance(obj1, CmdKillAction) and isinstance(obj2, CmdKillAction):
  417. # for loop detection, ignore command_id, which is the pid
  418. return obj1.thought == obj2.thought
  419. else:
  420. # this is the default comparison
  421. return obj1 == obj2