# agent_controller.py
  1. import asyncio
  2. import traceback
  3. from typing import Optional, Type
  4. from opendevin.controller.agent import Agent
  5. from opendevin.controller.state.state import State
  6. from opendevin.core.config import config
  7. from opendevin.core.exceptions import (
  8. AgentMalformedActionError,
  9. AgentNoActionError,
  10. LLMOutputError,
  11. MaxCharsExceedError,
  12. )
  13. from opendevin.core.logger import opendevin_logger as logger
  14. from opendevin.core.schema import AgentState
  15. from opendevin.events import EventSource, EventStream, EventStreamSubscriber
  16. from opendevin.events.action import (
  17. Action,
  18. AddTaskAction,
  19. AgentDelegateAction,
  20. AgentFinishAction,
  21. ChangeAgentStateAction,
  22. MessageAction,
  23. ModifyTaskAction,
  24. NullAction,
  25. )
  26. from opendevin.events.action.commands import CmdKillAction
  27. from opendevin.events.event import Event
  28. from opendevin.events.observation import (
  29. AgentDelegateObservation,
  30. AgentStateChangedObservation,
  31. CmdOutputObservation,
  32. ErrorObservation,
  33. NullObservation,
  34. Observation,
  35. )
  36. MAX_ITERATIONS = config.max_iterations
  37. MAX_CHARS = config.llm.max_chars
  38. MAX_BUDGET_PER_TASK = config.max_budget_per_task
  39. class AgentController:
  40. id: str
  41. agent: Agent
  42. max_iterations: int
  43. event_stream: EventStream
  44. state: State
  45. agent_task: Optional[asyncio.Task] = None
  46. delegate: 'AgentController | None' = None
  47. _pending_action: Action | None = None
  48. def __init__(
  49. self,
  50. agent: Agent,
  51. event_stream: EventStream,
  52. sid: str = 'default',
  53. max_iterations: int = MAX_ITERATIONS,
  54. max_chars: int = MAX_CHARS,
  55. max_budget_per_task: float | None = MAX_BUDGET_PER_TASK,
  56. inputs: dict | None = None,
  57. ):
  58. """Initializes a new instance of the AgentController class.
  59. Args:
  60. agent: The agent instance to control.
  61. event_stream: The event stream to publish events to.
  62. sid: The session ID of the agent.
  63. max_iterations: The maximum number of iterations the agent can run.
  64. max_chars: The maximum number of characters the agent can output.
  65. max_budget_per_task: The maximum budget (in USD) allowed per task, beyond which the agent will stop.
  66. inputs: The initial inputs to the agent.
  67. """
  68. self.id = sid
  69. self.agent = agent
  70. self.state = State(inputs=inputs or {}, max_iterations=max_iterations)
  71. self.event_stream = event_stream
  72. self.event_stream.subscribe(
  73. EventStreamSubscriber.AGENT_CONTROLLER, self.on_event
  74. )
  75. self.max_iterations = max_iterations
  76. self.max_chars = max_chars
  77. self.max_budget_per_task = max_budget_per_task
  78. self.agent_task = asyncio.create_task(self._start_step_loop())
  79. async def close(self):
  80. if self.agent_task is not None:
  81. self.agent_task.cancel()
  82. self.event_stream.unsubscribe(EventStreamSubscriber.AGENT_CONTROLLER)
  83. await self.set_agent_state_to(AgentState.STOPPED)
  84. def update_state_before_step(self):
  85. self.state.iteration += 1
  86. async def update_state_after_step(self):
  87. self.state.updated_info = []
  88. # update metrics especially for cost
  89. self.state.metrics = self.agent.llm.metrics
  90. if self.max_budget_per_task is not None:
  91. current_cost = self.state.metrics.accumulated_cost
  92. if current_cost > self.max_budget_per_task:
  93. await self.report_error(
  94. f'Task budget exceeded. Current cost: {current_cost}, Max budget: {self.max_budget_per_task}'
  95. )
  96. await self.set_agent_state_to(AgentState.ERROR)
  97. async def report_error(self, message: str, exception: Exception | None = None):
  98. self.state.error = message
  99. if exception:
  100. self.state.error += f': {str(exception)}'
  101. await self.event_stream.add_event(ErrorObservation(message), EventSource.AGENT)
  102. async def add_history(self, action: Action, observation: Observation):
  103. if isinstance(action, NullAction) and isinstance(observation, NullObservation):
  104. return
  105. self.state.history.append((action, observation))
  106. self.state.updated_info.append((action, observation))
  107. async def _start_step_loop(self):
  108. while True:
  109. try:
  110. await self._step()
  111. except asyncio.CancelledError:
  112. logger.info('AgentController task was cancelled')
  113. break
  114. except Exception as e:
  115. traceback.print_exc()
  116. logger.error(f'Error while running the agent: {e}')
  117. logger.error(traceback.format_exc())
  118. await self.report_error(
  119. 'There was an unexpected error while running the agent', exception=e
  120. )
  121. await self.set_agent_state_to(AgentState.ERROR)
  122. break
  123. await asyncio.sleep(0.1)
  124. async def on_event(self, event: Event):
  125. if isinstance(event, ChangeAgentStateAction):
  126. await self.set_agent_state_to(event.agent_state) # type: ignore
  127. elif isinstance(event, MessageAction):
  128. if event.source == EventSource.USER:
  129. logger.info(event, extra={'msg_type': 'OBSERVATION'})
  130. await self.add_history(event, NullObservation(''))
  131. if self.get_agent_state() != AgentState.RUNNING:
  132. await self.set_agent_state_to(AgentState.RUNNING)
  133. elif event.source == EventSource.AGENT and event.wait_for_response:
  134. logger.info(event, extra={'msg_type': 'ACTION'})
  135. await self.set_agent_state_to(AgentState.AWAITING_USER_INPUT)
  136. elif isinstance(event, AgentDelegateAction):
  137. await self.start_delegate(event)
  138. elif isinstance(event, AddTaskAction):
  139. self.state.root_task.add_subtask(event.parent, event.goal, event.subtasks)
  140. elif isinstance(event, ModifyTaskAction):
  141. self.state.root_task.set_subtask_state(event.task_id, event.state)
  142. elif isinstance(event, AgentFinishAction):
  143. self.state.outputs = event.outputs # type: ignore[attr-defined]
  144. await self.set_agent_state_to(AgentState.FINISHED)
  145. elif isinstance(event, Observation):
  146. if self._pending_action and self._pending_action.id == event.cause:
  147. await self.add_history(self._pending_action, event)
  148. self._pending_action = None
  149. logger.info(event, extra={'msg_type': 'OBSERVATION'})
  150. elif isinstance(event, CmdOutputObservation):
  151. await self.add_history(NullAction(), event)
  152. logger.info(event, extra={'msg_type': 'OBSERVATION'})
  153. def reset_task(self):
  154. self.agent.reset()
  155. async def set_agent_state_to(self, new_state: AgentState):
  156. logger.info(
  157. f'Setting agent({type(self.agent).__name__}) state from {self.state.agent_state} to {new_state}'
  158. )
  159. if new_state == self.state.agent_state:
  160. return
  161. self.state.agent_state = new_state
  162. if new_state == AgentState.STOPPED or new_state == AgentState.ERROR:
  163. self.reset_task()
  164. await self.event_stream.add_event(
  165. AgentStateChangedObservation('', self.state.agent_state), EventSource.AGENT
  166. )
  167. if new_state == AgentState.INIT and self.state.resume_state:
  168. await self.set_agent_state_to(self.state.resume_state)
  169. self.state.resume_state = None
  170. def get_agent_state(self):
  171. """Returns the current state of the agent task."""
  172. return self.state.agent_state
  173. async def start_delegate(self, action: AgentDelegateAction):
  174. AgentCls: Type[Agent] = Agent.get_cls(action.agent)
  175. agent = AgentCls(llm=self.agent.llm)
  176. self.delegate = AgentController(
  177. sid=self.id + '-delegate',
  178. agent=agent,
  179. event_stream=self.event_stream,
  180. max_iterations=self.max_iterations,
  181. max_chars=self.max_chars,
  182. inputs=action.inputs,
  183. )
  184. async def _step(self):
  185. if self.get_agent_state() != AgentState.RUNNING:
  186. logger.debug('waiting for agent to run...')
  187. await asyncio.sleep(1)
  188. return
  189. if self._pending_action:
  190. logger.debug('waiting for pending action: ' + str(self._pending_action))
  191. await asyncio.sleep(1)
  192. return
  193. logger.info(f'STEP {self.state.iteration}', extra={'msg_type': 'STEP'})
  194. if self.state.iteration >= self.max_iterations:
  195. await self.report_error('Agent reached maximum number of iterations')
  196. await self.set_agent_state_to(AgentState.ERROR)
  197. return
  198. if self.delegate is not None:
  199. delegate_done = await self.delegate._step()
  200. if delegate_done:
  201. outputs = self.delegate.state.outputs if self.delegate.state else {}
  202. obs: Observation = AgentDelegateObservation(content='', outputs=outputs)
  203. await self.event_stream.add_event(obs, EventSource.AGENT)
  204. self.delegate = None
  205. self.delegateAction = None
  206. return
  207. if self.state.num_of_chars > self.max_chars:
  208. raise MaxCharsExceedError(self.state.num_of_chars, self.max_chars)
  209. self.update_state_before_step()
  210. action: Action = NullAction()
  211. try:
  212. action = self.agent.step(self.state)
  213. if action is None:
  214. raise AgentNoActionError('No action was returned')
  215. except (AgentMalformedActionError, AgentNoActionError, LLMOutputError) as e:
  216. await self.report_error(str(e))
  217. return
  218. logger.info(action, extra={'msg_type': 'ACTION'})
  219. await self.update_state_after_step()
  220. if action.runnable:
  221. self._pending_action = action
  222. else:
  223. await self.add_history(action, NullObservation(''))
  224. if not isinstance(action, NullAction):
  225. await self.event_stream.add_event(action, EventSource.AGENT)
  226. if self._is_stuck():
  227. await self.report_error('Agent got stuck in a loop')
  228. await self.set_agent_state_to(AgentState.ERROR)
  229. def get_state(self):
  230. return self.state
  231. def set_state(self, state: State):
  232. self.state = state
  233. def _is_stuck(self):
  234. # check if delegate stuck
  235. if self.delegate and self.delegate._is_stuck():
  236. return True
  237. # filter out MessageAction with source='user' from history
  238. filtered_history = [
  239. _tuple
  240. for _tuple in self.state.history
  241. if not (
  242. isinstance(_tuple[0], MessageAction)
  243. and _tuple[0].source == EventSource.USER
  244. )
  245. ]
  246. if len(filtered_history) < 4:
  247. return False
  248. # FIXME rewrite this to be more readable
  249. # Check if the last four (Action, Observation) tuples are too repetitive
  250. last_four_tuples = filtered_history[-4:]
  251. if all(
  252. # (Action, Observation) tuples
  253. # compare the last action to the last four actions
  254. self._eq_no_pid(last_four_tuples[-1][0], _tuple[0])
  255. for _tuple in last_four_tuples
  256. ) and all(
  257. # compare the last observation to the last four observations
  258. self._eq_no_pid(last_four_tuples[-1][1], _tuple[1])
  259. for _tuple in last_four_tuples
  260. ):
  261. logger.warning('Action, Observation loop detected')
  262. return True
  263. # (action, error) tuples
  264. if all(
  265. self._eq_no_pid(last_four_tuples[-1][0], _tuple[0])
  266. for _tuple in last_four_tuples
  267. ):
  268. # It repeats the same action, give it a chance, but not if:
  269. if all(
  270. isinstance(_tuple[1], ErrorObservation) for _tuple in last_four_tuples
  271. ):
  272. logger.warning('Action, ErrorObservation loop detected')
  273. return True
  274. # check if the agent repeats the same (Action, Observation)
  275. # every other step in the last six tuples
  276. if len(filtered_history) >= 6:
  277. last_six_tuples = filtered_history[-6:]
  278. if (
  279. # this pattern is every other step, like:
  280. # (action_1, obs_1), (action_2, obs_2), (action_1, obs_1), (action_2, obs_2),...
  281. self._eq_no_pid(last_six_tuples[-1][0], last_six_tuples[-3][0])
  282. and self._eq_no_pid(last_six_tuples[-1][0], last_six_tuples[-5][0])
  283. and self._eq_no_pid(last_six_tuples[-2][0], last_six_tuples[-4][0])
  284. and self._eq_no_pid(last_six_tuples[-2][0], last_six_tuples[-6][0])
  285. and self._eq_no_pid(last_six_tuples[-1][1], last_six_tuples[-3][1])
  286. and self._eq_no_pid(last_six_tuples[-1][1], last_six_tuples[-5][1])
  287. and self._eq_no_pid(last_six_tuples[-2][1], last_six_tuples[-4][1])
  288. and self._eq_no_pid(last_six_tuples[-2][1], last_six_tuples[-6][1])
  289. ):
  290. logger.warning('Action, Observation pattern detected')
  291. return True
  292. return False
  293. def _eq_no_pid(self, obj1, obj2):
  294. if isinstance(obj1, CmdOutputObservation) and isinstance(
  295. obj2, CmdOutputObservation
  296. ):
  297. # for loop detection, ignore command_id, which is the pid
  298. return obj1.command == obj2.command and obj1.exit_code == obj2.exit_code
  299. elif isinstance(obj1, CmdKillAction) and isinstance(obj2, CmdKillAction):
  300. # for loop detection, ignore command_id, which is the pid
  301. return obj1.thought == obj2.thought
  302. else:
  303. # this is the default comparison
  304. return obj1 == obj2