session.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193
  1. import asyncio
  2. import time
  3. from fastapi import WebSocket, WebSocketDisconnect
  4. from opendevin.controller.agent import Agent
  5. from opendevin.core.config import AppConfig
  6. from opendevin.core.const.guide_url import TROUBLESHOOTING_URL
  7. from opendevin.core.logger import opendevin_logger as logger
  8. from opendevin.core.schema import AgentState
  9. from opendevin.core.schema.action import ActionType
  10. from opendevin.core.schema.config import ConfigType
  11. from opendevin.events.action import ChangeAgentStateAction, MessageAction, NullAction
  12. from opendevin.events.event import Event, EventSource
  13. from opendevin.events.observation import (
  14. AgentStateChangedObservation,
  15. CmdOutputObservation,
  16. NullObservation,
  17. )
  18. from opendevin.events.serialization import event_from_dict, event_to_dict
  19. from opendevin.events.stream import EventStreamSubscriber
  20. from opendevin.llm.llm import LLM
  21. from opendevin.storage.files import FileStore
  22. from .agent import AgentSession
  23. DEL_DELT_SEC = 60 * 60 * 5
  24. class Session:
  25. sid: str
  26. websocket: WebSocket | None
  27. last_active_ts: int = 0
  28. is_alive: bool = True
  29. agent_session: AgentSession
  30. def __init__(
  31. self, sid: str, ws: WebSocket | None, config: AppConfig, file_store: FileStore
  32. ):
  33. self.sid = sid
  34. self.websocket = ws
  35. self.last_active_ts = int(time.time())
  36. self.agent_session = AgentSession(sid, file_store)
  37. self.agent_session.event_stream.subscribe(
  38. EventStreamSubscriber.SERVER, self.on_event
  39. )
  40. self.config = config
  41. async def close(self):
  42. self.is_alive = False
  43. await self.agent_session.close()
  44. async def loop_recv(self):
  45. try:
  46. if self.websocket is None:
  47. return
  48. while True:
  49. try:
  50. data = await self.websocket.receive_json()
  51. except ValueError:
  52. await self.send_error('Invalid JSON')
  53. continue
  54. await self.dispatch(data)
  55. except WebSocketDisconnect:
  56. await self.close()
  57. logger.info('WebSocket disconnected, sid: %s', self.sid)
  58. except RuntimeError as e:
  59. await self.close()
  60. logger.exception('Error in loop_recv: %s', e)
  61. async def _initialize_agent(self, data: dict):
  62. self.agent_session.event_stream.add_event(
  63. ChangeAgentStateAction(AgentState.LOADING), EventSource.USER
  64. )
  65. self.agent_session.event_stream.add_event(
  66. AgentStateChangedObservation('', AgentState.LOADING), EventSource.AGENT
  67. )
  68. # Extract the agent-relevant arguments from the request
  69. args = {
  70. key: value for key, value in data.get('args', {}).items() if value != ''
  71. }
  72. agent_cls = args.get(ConfigType.AGENT, self.config.default_agent)
  73. self.config.security.confirmation_mode = args.get(
  74. ConfigType.CONFIRMATION_MODE, self.config.security.confirmation_mode
  75. )
  76. self.config.security.security_analyzer = data.get('args', {}).get(
  77. ConfigType.SECURITY_ANALYZER, self.config.security.security_analyzer
  78. )
  79. max_iterations = args.get(ConfigType.MAX_ITERATIONS, self.config.max_iterations)
  80. # override default LLM config
  81. default_llm_config = self.config.get_llm_config()
  82. default_llm_config.model = args.get(
  83. ConfigType.LLM_MODEL, default_llm_config.model
  84. )
  85. default_llm_config.api_key = args.get(
  86. ConfigType.LLM_API_KEY, default_llm_config.api_key
  87. )
  88. default_llm_config.base_url = args.get(
  89. ConfigType.LLM_BASE_URL, default_llm_config.base_url
  90. )
  91. # TODO: override other LLM config & agent config groups (#2075)
  92. llm = LLM(config=self.config.get_llm_config_from_agent(agent_cls))
  93. agent_config = self.config.get_agent_config(agent_cls)
  94. agent = Agent.get_cls(agent_cls)(llm, agent_config)
  95. # Create the agent session
  96. try:
  97. await self.agent_session.start(
  98. runtime_name=self.config.runtime,
  99. config=self.config,
  100. agent=agent,
  101. max_iterations=max_iterations,
  102. max_budget_per_task=self.config.max_budget_per_task,
  103. agent_to_llm_config=self.config.get_agent_to_llm_config_map(),
  104. agent_configs=self.config.get_agent_configs(),
  105. )
  106. except Exception as e:
  107. logger.exception(f'Error creating controller: {e}')
  108. await self.send_error(
  109. f'Error creating controller. Please check Docker is running and visit `{TROUBLESHOOTING_URL}` for more debugging information..'
  110. )
  111. return
  112. self.agent_session.event_stream.add_event(
  113. ChangeAgentStateAction(AgentState.INIT), EventSource.USER
  114. )
  115. async def on_event(self, event: Event):
  116. """Callback function for agent events.
  117. Args:
  118. event: The agent event (Observation or Action).
  119. """
  120. if isinstance(event, NullAction):
  121. return
  122. if isinstance(event, NullObservation):
  123. return
  124. if event.source == EventSource.AGENT:
  125. logger.info('Server event')
  126. await self.send(event_to_dict(event))
  127. elif event.source == EventSource.USER and isinstance(
  128. event, CmdOutputObservation
  129. ):
  130. await self.send(event_to_dict(event))
  131. async def dispatch(self, data: dict):
  132. action = data.get('action', '')
  133. if action == ActionType.INIT:
  134. await self._initialize_agent(data)
  135. return
  136. event = event_from_dict(data.copy())
  137. # This checks if the model supports images
  138. if isinstance(event, MessageAction) and event.images_urls:
  139. controller = self.agent_session.controller
  140. if controller and not controller.agent.llm.supports_vision():
  141. await self.send_error(
  142. 'Model does not support image upload, change to a different model or try without an image.'
  143. )
  144. return
  145. self.agent_session.event_stream.add_event(event, EventSource.USER)
  146. async def send(self, data: dict[str, object]) -> bool:
  147. try:
  148. if self.websocket is None or not self.is_alive:
  149. return False
  150. await self.websocket.send_json(data)
  151. await asyncio.sleep(0.001) # This flushes the data to the client
  152. self.last_active_ts = int(time.time())
  153. return True
  154. except WebSocketDisconnect:
  155. self.is_alive = False
  156. return False
  157. async def send_error(self, message: str) -> bool:
  158. """Sends an error message to the client."""
  159. return await self.send({'error': True, 'message': message})
  160. async def send_message(self, message: str) -> bool:
  161. """Sends a message to the client."""
  162. return await self.send({'message': message})
  163. def update_connection(self, ws: WebSocket):
  164. self.websocket = ws
  165. self.is_alive = True
  166. self.last_active_ts = int(time.time())
  167. def load_from_data(self, data: dict) -> bool:
  168. self.last_active_ts = data.get('last_active_ts', 0)
  169. if self.last_active_ts < int(time.time()) - DEL_DELT_SEC:
  170. return False
  171. self.is_alive = data.get('is_alive', False)
  172. return True