# agent_controller.py

  1. import asyncio
  2. import time
  3. from typing import List, Callable
  4. from opendevin.plan import Plan
  5. from opendevin.state import State
  6. from opendevin.agent import Agent
  7. from opendevin.observation import Observation, AgentErrorObservation, NullObservation
  8. from litellm.exceptions import APIConnectionError
  9. from openai import AuthenticationError
  10. from opendevin import config
  11. from opendevin.logger import opendevin_logger as logger
  12. from opendevin.exceptions import MaxCharsExceedError
  13. from .action_manager import ActionManager
  14. from opendevin.action import (
  15. Action,
  16. NullAction,
  17. AgentFinishAction,
  18. )
  19. from opendevin.exceptions import AgentNoActionError
  20. from ..action.tasks import TaskStateChangedAction
  21. from ..schema import TaskState
# Default upper bound on the number of steps the controller loop will run;
# used as the default for AgentController's ``max_iterations`` argument.
MAX_ITERATIONS = config.get('MAX_ITERATIONS')
# Default character budget; used as the default for AgentController's
# ``max_chars`` argument, which each step compares with ``state.num_of_chars``.
MAX_CHARS = config.get('MAX_CHARS')
  24. class AgentController:
  25. id: str
  26. agent: Agent
  27. max_iterations: int
  28. action_manager: ActionManager
  29. callbacks: List[Callable]
  30. state: State | None = None
  31. _task_state: TaskState = TaskState.INIT
  32. _cur_step: int = 0
  33. def __init__(
  34. self,
  35. agent: Agent,
  36. sid: str = '',
  37. max_iterations: int = MAX_ITERATIONS,
  38. max_chars: int = MAX_CHARS,
  39. container_image: str | None = None,
  40. callbacks: List[Callable] = [],
  41. ):
  42. self.id = sid
  43. self.agent = agent
  44. self.max_iterations = max_iterations
  45. self.action_manager = ActionManager(self.id, container_image)
  46. self.max_chars = max_chars
  47. self.callbacks = callbacks
  48. def update_state_for_step(self, i):
  49. if self.state is None:
  50. return
  51. self.state.iteration = i
  52. self.state.background_commands_obs = self.action_manager.get_background_obs()
  53. def update_state_after_step(self):
  54. if self.state is None:
  55. return
  56. self.state.updated_info = []
  57. def add_history(self, action: Action, observation: Observation):
  58. if self.state is None:
  59. return
  60. if not isinstance(action, Action):
  61. raise TypeError(
  62. f'action must be an instance of Action, got {type(action).__name__} instead'
  63. )
  64. if not isinstance(observation, Observation):
  65. raise TypeError(
  66. f'observation must be an instance of Observation, got {type(observation).__name__} instead'
  67. )
  68. self.state.history.append((action, observation))
  69. self.state.updated_info.append((action, observation))
  70. async def _run(self):
  71. if self.state is None:
  72. return
  73. if self._task_state != TaskState.RUNNING:
  74. raise ValueError('Task is not in running state')
  75. for i in range(self._cur_step, self.max_iterations):
  76. self._cur_step = i
  77. try:
  78. finished = await self.step(i)
  79. if finished:
  80. self._task_state = TaskState.FINISHED
  81. except Exception as e:
  82. logger.error('Error in loop', exc_info=True)
  83. raise e
  84. if self._task_state == TaskState.FINISHED:
  85. logger.info('Task finished by agent')
  86. await self.reset_task()
  87. break
  88. elif self._task_state == TaskState.STOPPED:
  89. logger.info('Task stopped by user')
  90. await self.reset_task()
  91. break
  92. elif self._task_state == TaskState.PAUSED:
  93. logger.info('Task paused')
  94. self._cur_step = i + 1
  95. await self.notify_task_state_changed()
  96. break
  97. async def start(self, task: str):
  98. """Starts the agent controller with a task.
  99. If task already run before, it will continue from the last step.
  100. """
  101. self._task_state = TaskState.RUNNING
  102. await self.notify_task_state_changed()
  103. self.state = State(Plan(task))
  104. await self._run()
  105. async def resume(self):
  106. if self.state is None:
  107. raise ValueError('No task to resume')
  108. self._task_state = TaskState.RUNNING
  109. await self.notify_task_state_changed()
  110. await self._run()
  111. async def reset_task(self):
  112. self.state = None
  113. self._cur_step = 0
  114. self._task_state = TaskState.INIT
  115. self.agent.reset()
  116. await self.notify_task_state_changed()
  117. async def set_task_state_to(self, state: TaskState):
  118. self._task_state = state
  119. if state == TaskState.STOPPED:
  120. await self.reset_task()
  121. logger.info(f'Task state set to {state}')
  122. def get_task_state(self):
  123. """Returns the current state of the agent task."""
  124. return self._task_state
  125. async def notify_task_state_changed(self):
  126. await self._run_callbacks(TaskStateChangedAction(self._task_state))
  127. async def step(self, i: int):
  128. if self.state is None:
  129. return
  130. logger.info(f'STEP {i}', extra={'msg_type': 'STEP'})
  131. logger.info(self.state.plan.main_goal, extra={'msg_type': 'PLAN'})
  132. if self.state.num_of_chars > self.max_chars:
  133. raise MaxCharsExceedError(self.state.num_of_chars, self.max_chars)
  134. log_obs = self.action_manager.get_background_obs()
  135. for obs in log_obs:
  136. self.add_history(NullAction(), obs)
  137. await self._run_callbacks(obs)
  138. logger.info(obs, extra={'msg_type': 'BACKGROUND LOG'})
  139. self.update_state_for_step(i)
  140. action: Action = NullAction()
  141. observation: Observation = NullObservation('')
  142. try:
  143. action = self.agent.step(self.state)
  144. if action is None:
  145. raise AgentNoActionError()
  146. logger.info(action, extra={'msg_type': 'ACTION'})
  147. except Exception as e:
  148. observation = AgentErrorObservation(str(e))
  149. logger.error(e)
  150. if isinstance(e, APIConnectionError):
  151. time.sleep(3)
  152. # raise specific exceptions that need to be handled outside
  153. # note: we are using AuthenticationError class from openai rather than
  154. # litellm because:
  155. # 1) litellm.exceptions.AuthenticationError is a subclass of openai.AuthenticationError
  156. # 2) embeddings call, initiated by llama-index, has no wrapper for authentication
  157. # errors. This means we have to catch individual authentication errors
  158. # from different providers, and OpenAI is one of these.
  159. if isinstance(e, (AuthenticationError, AgentNoActionError)):
  160. raise
  161. self.update_state_after_step()
  162. await self._run_callbacks(action)
  163. finished = isinstance(action, AgentFinishAction)
  164. if finished:
  165. logger.info(action, extra={'msg_type': 'INFO'})
  166. return True
  167. if isinstance(observation, NullObservation):
  168. observation = await self.action_manager.run_action(action, self)
  169. if not isinstance(observation, NullObservation):
  170. logger.info(observation, extra={'msg_type': 'OBSERVATION'})
  171. self.add_history(action, observation)
  172. await self._run_callbacks(observation)
  173. async def _run_callbacks(self, event):
  174. if event is None:
  175. return
  176. for callback in self.callbacks:
  177. idx = self.callbacks.index(callback)
  178. try:
  179. await callback(event)
  180. except Exception as e:
  181. logger.exception(f'Callback error: {e}, idx: {idx}')
  182. await asyncio.sleep(
  183. 0.001
  184. ) # Give back control for a tick, so we can await in callbacks