agent.py 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241
  1. import asyncio
  2. from typing import Optional, Dict, List
  3. from opendevin import config
  4. from opendevin.action import (
  5. Action,
  6. NullAction,
  7. )
  8. from opendevin.agent import Agent
  9. from opendevin.controller import AgentController
  10. from opendevin.llm.llm import LLM
  11. from opendevin.logger import opendevin_logger as logger
  12. from opendevin.observation import NullObservation, Observation, UserMessageObservation
  13. from opendevin.schema import ActionType, ConfigType, TaskState, TaskStateAction
  14. from opendevin.server.session import session_manager
  # Lookup tables consulted by AgentUnit.set_task_state.
  #
  # For a requested action, the current task state is looked up first in
  # VALID_TASK_STATE_MAP (the transition is applied) and then in
  # IGNORED_TASK_STATE_MAP (nothing changes; the client is only re-notified).

  # Requested action -> current states from which the action takes effect.
  VALID_TASK_STATE_MAP: Dict[TaskStateAction, List[TaskState]] = {
      TaskStateAction.PAUSE: [
          TaskState.RUNNING,
      ],
      TaskStateAction.RESUME: [
          TaskState.PAUSED,
      ],
      TaskStateAction.STOP: [
          TaskState.RUNNING,
          TaskState.PAUSED,
      ],
  }

  # Requested action -> current states in which the action is a no-op.
  IGNORED_TASK_STATE_MAP: Dict[TaskStateAction, List[TaskState]] = {
      TaskStateAction.PAUSE: [
          TaskState.INIT,
          TaskState.PAUSED,
          TaskState.STOPPED,
          TaskState.FINISHED,
      ],
      TaskStateAction.RESUME: [
          TaskState.INIT,
          TaskState.RUNNING,
          TaskState.STOPPED,
          TaskState.FINISHED,
      ],
      TaskStateAction.STOP: [
          TaskState.INIT,
          TaskState.STOPPED,
          TaskState.FINISHED,
      ],
  }

  # Requested action -> the task state the controller is moved to.
  TASK_STATE_ACTION_MAP: Dict[TaskStateAction, TaskState] = {
      TaskStateAction.START: TaskState.RUNNING,
      TaskStateAction.PAUSE: TaskState.PAUSED,
      TaskStateAction.RESUME: TaskState.RUNNING,
      TaskStateAction.STOP: TaskState.STOPPED,
  }
  42. class AgentUnit:
  43. """Represents a session with an agent.
  44. Attributes:
  45. controller: The AgentController instance for controlling the agent.
  46. agent_task: The task representing the agent's execution.
  47. """
  48. sid: str
  49. controller: Optional[AgentController] = None
  50. agent_task: Optional[asyncio.Task] = None
  51. def __init__(self, sid):
  52. """Initializes a new instance of the Session class."""
  53. self.sid = sid
  54. async def send_error(self, message):
  55. """Sends an error message to the client.
  56. Args:
  57. message: The error message to send.
  58. """
  59. await session_manager.send_error(self.sid, message)
  60. async def send_message(self, message):
  61. """Sends a message to the client.
  62. Args:
  63. message: The message to send.
  64. """
  65. await session_manager.send_message(self.sid, message)
  66. async def send(self, data):
  67. """Sends data to the client.
  68. Args:
  69. data: The data to send.
  70. """
  71. await session_manager.send(self.sid, data)
  72. async def dispatch(self, action: str | None, data: dict):
  73. """Dispatches actions to the agent from the client."""
  74. if action is None:
  75. await self.send_error('Invalid action')
  76. return
  77. match action:
  78. case ActionType.INIT:
  79. await self.create_controller(data)
  80. case ActionType.START:
  81. await self.start_task(data)
  82. case ActionType.CHANGE_TASK_STATE:
  83. task_state_action = data.get('args', {}).get('task_state_action', None)
  84. if task_state_action is None:
  85. await self.send_error('No task state action specified.')
  86. return
  87. await self.set_task_state(TaskStateAction(task_state_action))
  88. case ActionType.CHAT:
  89. if self.controller is None:
  90. await self.send_error('No agent started. Please wait a second...')
  91. return
  92. self.controller.add_history(
  93. NullAction(), UserMessageObservation(data['message'])
  94. )
  95. case _:
  96. await self.send_error("I didn't recognize this action:" + action)
  97. def get_arg_or_default(self, _args: dict, key: ConfigType) -> str:
  98. """Gets an argument from the args dictionary or the default value.
  99. Args:
  100. _args: The args dictionary.
  101. key: The key to get.
  102. Returns:
  103. The value of the key or the default value.
  104. """
  105. return _args.get(key, config.get(key))
  106. async def create_controller(self, start_event: dict):
  107. """Creates an AgentController instance.
  108. Args:
  109. start_event: The start event data (optional).
  110. """
  111. args = {
  112. key: value
  113. for key, value in start_event.get('args', {}).items()
  114. if value != ''
  115. } # remove empty values, prevent FE from sending empty strings
  116. agent_cls = self.get_arg_or_default(args, ConfigType.AGENT)
  117. model = self.get_arg_or_default(args, ConfigType.LLM_MODEL)
  118. api_key = config.get(ConfigType.LLM_API_KEY)
  119. api_base = config.get(ConfigType.LLM_BASE_URL)
  120. max_iterations = self.get_arg_or_default(args, ConfigType.MAX_ITERATIONS)
  121. max_chars = self.get_arg_or_default(args, ConfigType.MAX_CHARS)
  122. logger.info(f'Creating agent {agent_cls} using LLM {model}')
  123. llm = LLM(model=model, api_key=api_key, base_url=api_base)
  124. try:
  125. self.controller = AgentController(
  126. sid=self.sid,
  127. agent=Agent.get_cls(agent_cls)(llm),
  128. max_iterations=int(max_iterations),
  129. max_chars=int(max_chars),
  130. callbacks=[self.on_agent_event],
  131. )
  132. except Exception as e:
  133. logger.exception(f'Error creating controller: {e}')
  134. await self.send_error(
  135. 'Error creating controller. Please check Docker is running using `docker ps`.'
  136. )
  137. return
  138. await self.init_done()
  139. async def init_done(self):
  140. if self.controller is None:
  141. await self.send_error('No agent started.')
  142. return
  143. await self.send(
  144. {
  145. 'action': ActionType.INIT,
  146. 'message': 'Control loop started.',
  147. }
  148. )
  149. await self.controller.notify_task_state_changed()
  150. async def start_task(self, start_event):
  151. """Starts a task for the agent.
  152. Args:
  153. start_event: The start event data.
  154. """
  155. if 'task' not in start_event['args']:
  156. await self.send_error('No task specified')
  157. return
  158. await self.send_message('Starting new task...')
  159. task = start_event['args']['task']
  160. if self.controller is None:
  161. await self.send_error('No agent started. Please wait a second...')
  162. return
  163. try:
  164. if self.agent_task:
  165. self.agent_task.cancel()
  166. self.agent_task = asyncio.create_task(
  167. self.controller.start(task), name='agent start task loop'
  168. )
  169. except Exception as e:
  170. await self.send_error(f'Error during task loop: {e}')
  171. async def set_task_state(self, new_state_action: TaskStateAction):
  172. """Sets the state of the agent task."""
  173. if self.controller is None:
  174. await self.send_error('No agent started.')
  175. return
  176. cur_state = self.controller.get_task_state()
  177. new_state = TASK_STATE_ACTION_MAP.get(new_state_action)
  178. if new_state is None:
  179. await self.send_error('Invalid task state action.')
  180. return
  181. if cur_state in VALID_TASK_STATE_MAP.get(new_state_action, []):
  182. await self.controller.set_task_state_to(new_state)
  183. elif cur_state in IGNORED_TASK_STATE_MAP.get(new_state_action, []):
  184. # notify once again.
  185. await self.controller.notify_task_state_changed()
  186. return
  187. else:
  188. await self.send_error('Current task state not recognized.')
  189. return
  190. if new_state_action == TaskStateAction.RESUME:
  191. if self.agent_task:
  192. self.agent_task.cancel()
  193. self.agent_task = asyncio.create_task(
  194. self.controller.resume(), name='agent resume task loop'
  195. )
  196. async def on_agent_event(self, event: Observation | Action):
  197. """Callback function for agent events.
  198. Args:
  199. event: The agent event (Observation or Action).
  200. """
  201. if isinstance(event, NullAction):
  202. return
  203. if isinstance(event, NullObservation):
  204. return
  205. await self.send(event.to_dict())
  206. def close(self):
  207. if self.agent_task:
  208. self.agent_task.cancel()
  209. if self.controller is not None:
  210. self.controller.action_manager.sandbox.close()