agent_controller.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150
  1. import asyncio
  2. import time
  3. from typing import List, Callable
  4. from opendevin.plan import Plan
  5. from opendevin.state import State
  6. from opendevin.agent import Agent
  7. from opendevin.observation import Observation, AgentErrorObservation, NullObservation
  8. from litellm.exceptions import APIConnectionError
  9. from openai import AuthenticationError
  10. from opendevin import config
  11. from opendevin.logger import opendevin_logger as logger
  12. from opendevin.exceptions import MaxCharsExceedError
  13. from .action_manager import ActionManager
  14. from opendevin.action import (
  15. Action,
  16. NullAction,
  17. AgentFinishAction,
  18. )
  19. from opendevin.exceptions import AgentNoActionError
  # Default limits for AgentController.__init__, read once from the opendevin
  # config when this module is imported.
  MAX_ITERATIONS, MAX_CHARS = config.get('MAX_ITERATIONS'), config.get('MAX_CHARS')
  class AgentController:
      """Drive an ``Agent`` through a task, one action/observation step at a time.

      Each step asks the agent for an ``Action`` given the current ``State``,
      executes it through an ``ActionManager``, appends the resulting
      ``(action, observation)`` pair to the state's history, and forwards every
      event to the registered callbacks.
      """

      id: str
      agent: Agent
      max_iterations: int
      max_chars: int
      action_manager: ActionManager
      callbacks: List[Callable]
      # Created by start_loop(); the other methods assume it already exists.
      state: State

      def __init__(
          self,
          agent: Agent,
          sid: str = '',
          max_iterations: int = MAX_ITERATIONS,
          max_chars: int = MAX_CHARS,
          container_image: str | None = None,
          callbacks: List[Callable] | None = None,
      ):
          """Create a controller for ``agent``.

          Args:
              agent: The agent whose ``step`` method chooses each action.
              sid: Session id; stored as ``self.id`` and passed to the ActionManager.
              max_iterations: Maximum number of steps ``start_loop`` will run.
              max_chars: ``step`` raises ``MaxCharsExceedError`` once
                  ``state.num_of_chars`` exceeds this value.
              container_image: Forwarded to ``ActionManager``.
              callbacks: Callables invoked with every action/observation event.
                  When omitted, each controller gets its own empty list.
          """
          self.id = sid
          self.agent = agent
          self.max_iterations = max_iterations
          self.action_manager = ActionManager(self.id, container_image)
          self.max_chars = max_chars
          # Build a new list per instance: a literal ``[]`` default would be one
          # list object shared by every controller constructed without callbacks.
          self.callbacks = callbacks if callbacks is not None else []

      def update_state_for_step(self, i: int) -> None:
          """Store the iteration index and current background-command output on the state."""
          self.state.iteration = i
          self.state.background_commands_obs = self.action_manager.get_background_obs()

      def update_state_after_step(self) -> None:
          """Reset ``state.updated_info``; called after the agent has been stepped."""
          self.state.updated_info = []

      def add_history(self, action: Action, observation: Observation) -> None:
          """Append ``(action, observation)`` to both the full history and ``updated_info``.

          Raises:
              TypeError: if ``action`` is not an ``Action`` or ``observation`` is
                  not an ``Observation``.
          """
          if not isinstance(action, Action):
              raise TypeError(
                  f'action must be an instance of Action, got {type(action).__name__} instead')
          if not isinstance(observation, Observation):
              raise TypeError(
                  f'observation must be an instance of Observation, got {type(observation).__name__} instead')
          self.state.history.append((action, observation))
          self.state.updated_info.append((action, observation))

      async def start_loop(self, task: str) -> None:
          """Run the agent on ``task`` for at most ``max_iterations`` steps.

          A fresh ``Plan``/``State`` is created for the task, and the loop stops
          early as soon as a step reports that the agent finished. Exceptions from
          a step are logged and re-raised; in that case ``agent.reset()`` is not
          reached.
          """
          finished = False
          plan = Plan(task)
          self.state = State(plan)
          for i in range(self.max_iterations):
              try:
                  finished = await self.step(i)
              except Exception:
                  logger.error('Error in loop', exc_info=True)
                  raise
              if finished:
                  break
          if not finished:
              logger.info('Exited before finishing the task.')
          # NOTE(review): consider try/finally if the agent must also be reset
          # when a step raises.
          self.agent.reset()

      async def step(self, i: int) -> bool:
          """Execute iteration ``i``: ask the agent for an action and run it.

          Returns:
              True if the agent returned an ``AgentFinishAction``, otherwise False.

          Raises:
              MaxCharsExceedError: if ``state.num_of_chars`` exceeds ``max_chars``.
              AuthenticationError, AgentNoActionError: re-raised from the agent
                  step for the caller to handle. Any other exception from the
                  agent is recorded as an ``AgentErrorObservation`` instead.
          """
          logger.info(f'STEP {i}', extra={'msg_type': 'STEP'})
          logger.info(self.state.plan.main_goal, extra={'msg_type': 'PLAN'})
          if self.state.num_of_chars > self.max_chars:
              raise MaxCharsExceedError(self.state.num_of_chars, self.max_chars)

          # Record pending background-command output and notify callbacks.
          log_obs = self.action_manager.get_background_obs()
          for obs in log_obs:
              self.add_history(NullAction(), obs)
              await self._run_callbacks(obs)
              logger.info(obs, extra={'msg_type': 'BACKGROUND LOG'})

          self.update_state_for_step(i)
          action: Action = NullAction()
          observation: Observation = NullObservation('')
          try:
              action = self.agent.step(self.state)
              if action is None:
                  raise AgentNoActionError()
              logger.info(action, extra={'msg_type': 'ACTION'})
          except Exception as e:
              observation = AgentErrorObservation(str(e))
              logger.error(e)
              if isinstance(e, APIConnectionError):
                  # Back off before continuing without blocking the event loop;
                  # time.sleep() here would stall every other coroutine.
                  await asyncio.sleep(3)
              # raise specific exceptions that need to be handled outside
              # note: we are using AuthenticationError class from openai rather than
              # litellm because:
              # 1) litellm.exceptions.AuthenticationError is a subclass of openai.AuthenticationError
              # 2) embeddings call, initiated by llama-index, has no wrapper for authentication
              # errors. This means we have to catch individual authentication errors
              # from different providers, and OpenAI is one of these.
              if isinstance(e, (AuthenticationError, AgentNoActionError)):
                  raise
          self.update_state_after_step()

          await self._run_callbacks(action)

          if isinstance(action, AgentFinishAction):
              logger.info(action, extra={'msg_type': 'INFO'})
              return True

          # Still a NullObservation means the agent step did not error, so run it.
          if isinstance(observation, NullObservation):
              observation = await self.action_manager.run_action(action, self)
          if not isinstance(observation, NullObservation):
              logger.info(observation, extra={'msg_type': 'OBSERVATION'})
          self.add_history(action, observation)
          await self._run_callbacks(observation)
          return False

      async def _run_callbacks(self, event) -> None:
          """Call every registered callback with ``event``.

          ``None`` events are ignored. A callback that raises is logged together
          with its position in ``self.callbacks``, and the remaining callbacks
          still run. Afterwards the coroutine yields to the event loop briefly.
          """
          if event is None:
              return
          # enumerate gives the true position (list.index would be O(n) per call
          # and would report the first match for a callback registered twice).
          for idx, callback in enumerate(self.callbacks):
              try:
                  callback(event)
              except Exception as e:
                  logger.exception(f'Callback error: {e}, idx: {idx}')
          await asyncio.sleep(0.001)  # Give back control for a tick, so we can await in callbacks