# session.py (7.0 KB)

# (Line-number gutter 1-188 from the web file viewer; not part of the source.)
  1. import asyncio
  2. import time
  3. from fastapi import WebSocket, WebSocketDisconnect
  4. from opendevin.controller.agent import Agent
  5. from opendevin.core.config import AppConfig
  6. from opendevin.core.const.guide_url import TROUBLESHOOTING_URL
  7. from opendevin.core.logger import opendevin_logger as logger
  8. from opendevin.core.schema import AgentState
  9. from opendevin.core.schema.action import ActionType
  10. from opendevin.core.schema.config import ConfigType
  11. from opendevin.events.action import ChangeAgentStateAction, MessageAction, NullAction
  12. from opendevin.events.event import Event, EventSource
  13. from opendevin.events.observation import (
  14. AgentStateChangedObservation,
  15. CmdOutputObservation,
  16. NullObservation,
  17. )
  18. from opendevin.events.serialization import event_from_dict, event_to_dict
  19. from opendevin.events.stream import EventStreamSubscriber
  20. from opendevin.llm.llm import LLM
  21. from opendevin.storage.files import FileStore
  22. from .agent import AgentSession
  # Maximum session idle time, in seconds (5 hours). Session.load_from_data()
  # rejects data whose `last_active_ts` is older than this many seconds.
  DEL_DELT_SEC = 60 * 60 * 5
  class Session:
      """Server-side state for one client session.

      Couples a (possibly absent) WebSocket connection with an ``AgentSession``:
      messages received from the client are dispatched into the agent session's
      event stream, and events coming back from that stream are forwarded to
      the client as JSON.

      Attributes:
          sid: Session identifier.
          websocket: Current client connection, or ``None`` when there is none.
          last_active_ts: Unix time (whole seconds) of the last activity; set on
              construction, after each successful send and on reconnection.
          is_alive: Whether the connection is considered usable for sending.
          agent_session: The agent session driven by this connection.
      """

      sid: str
      websocket: WebSocket | None
      last_active_ts: int = 0
      is_alive: bool = True
      agent_session: AgentSession

      def __init__(
          self, sid: str, ws: WebSocket | None, config: AppConfig, file_store: FileStore
      ):
          """Create the session and subscribe to its agent session's event stream.

          Args:
              sid: Session identifier, also passed to the ``AgentSession``.
              ws: Client WebSocket, or ``None`` if no client is attached yet.
              config: Application configuration used when the agent is initialized.
              file_store: File store handed to the ``AgentSession``.
          """
          self.sid = sid
          self.websocket = ws
          self.last_active_ts = int(time.time())
          self.agent_session = AgentSession(sid, file_store)
          # Route every event published on the stream through on_event().
          self.agent_session.event_stream.subscribe(
              EventStreamSubscriber.SERVER, self.on_event
          )
          self.config = config

      async def close(self):
          """Mark the session as dead and close the underlying agent session."""
          self.is_alive = False
          await self.agent_session.close()

      async def loop_recv(self):
          """Receive JSON messages from the client until the connection ends.

          Returns immediately when no WebSocket is attached. A payload that is
          not valid JSON is reported to the client with an error message and the
          loop keeps running. On ``WebSocketDisconnect`` or ``RuntimeError`` the
          session is closed and the event is logged.
          """
          try:
              if self.websocket is None:
                  return
              while True:
                  try:
                      data = await self.websocket.receive_json()
                  except ValueError:
                      await self.send_error('Invalid JSON')
                      continue
                  await self.dispatch(data)
          except WebSocketDisconnect:
              await self.close()
              logger.info('WebSocket disconnected, sid: %s', self.sid)
          except RuntimeError as e:
              await self.close()
              logger.exception('Error in loop_recv: %s', e)

      async def _initialize_agent(self, data: dict):
          """Build the agent and start the agent session for an INIT request.

          Publishes LOADING state events, reads optional overrides from
          ``data['args']`` (falling back to ``self.config``), constructs the LLM
          and agent, and starts the agent session. If starting fails, the error
          is logged and reported to the client, and the INIT state change is not
          published.

          Args:
              data: The decoded client message; ``args`` may be missing or null.
          """
          self.agent_session.event_stream.add_event(
              ChangeAgentStateAction(AgentState.LOADING), EventSource.USER
          )
          self.agent_session.event_stream.add_event(
              AgentStateChangedObservation('', AgentState.LOADING), EventSource.AGENT
          )

          # Extract the agent-relevant arguments from the request. Empty-string
          # values are dropped so the configured defaults apply; `or {}` keeps a
          # client-sent `"args": null` from raising AttributeError.
          args = {
              key: value
              for key, value in (data.get('args') or {}).items()
              if value != ''
          }
          agent_cls = args.get(ConfigType.AGENT, self.config.default_agent)
          confirmation_mode = args.get(
              ConfigType.CONFIRMATION_MODE, self.config.confirmation_mode
          )
          max_iterations = args.get(ConfigType.MAX_ITERATIONS, self.config.max_iterations)

          # Override default LLM config.
          # NOTE(review): the overrides are written onto the object returned by
          # self.config.get_llm_config(), while the LLM below is built from
          # get_llm_config_from_agent(agent_cls). Whether these are the same
          # object, and whether this mutates configuration shared with other
          # sessions (including the API key), is not visible here - confirm.
          default_llm_config = self.config.get_llm_config()
          default_llm_config.model = args.get(
              ConfigType.LLM_MODEL, default_llm_config.model
          )
          default_llm_config.api_key = args.get(
              ConfigType.LLM_API_KEY, default_llm_config.api_key
          )
          default_llm_config.base_url = args.get(
              ConfigType.LLM_BASE_URL, default_llm_config.base_url
          )

          # TODO: override other LLM config & agent config groups (#2075)
          llm = LLM(config=self.config.get_llm_config_from_agent(agent_cls))
          agent = Agent.get_cls(agent_cls)(llm)

          # Create the agent session
          try:
              await self.agent_session.start(
                  runtime_name=self.config.runtime,
                  config=self.config,
                  agent=agent,
                  confirmation_mode=confirmation_mode,
                  max_iterations=max_iterations,
                  max_budget_per_task=self.config.max_budget_per_task,
                  agent_to_llm_config=self.config.get_agent_to_llm_config_map(),
              )
          except Exception as e:
              logger.exception(f'Error creating controller: {e}')
              await self.send_error(
                  f'Error creating controller. Please check Docker is running and visit `{TROUBLESHOOTING_URL}` for more debugging information..'
              )
              return
          self.agent_session.event_stream.add_event(
              ChangeAgentStateAction(AgentState.INIT), EventSource.USER
          )

      async def on_event(self, event: Event):
          """Callback function for agent events.

          Null actions/observations are ignored. Events whose source is the
          agent, and command-output observations whose source is the user, are
          serialized and sent to the client; everything else is dropped.

          Args:
              event: The agent event (Observation or Action).
          """
          if isinstance(event, (NullAction, NullObservation)):
              return
          if event.source == EventSource.AGENT:
              await self.send(event_to_dict(event))
          elif event.source == EventSource.USER and isinstance(
              event, CmdOutputObservation
          ):
              await self.send(event_to_dict(event))

      async def dispatch(self, data: dict):
          """Handle one decoded client message.

          An INIT action triggers agent initialization. Any other message is
          deserialized into an event and published on the event stream with the
          USER source, unless it is a message carrying images and the current
          controller's LLM reports no vision support, in which case an error is
          sent to the client instead.

          Args:
              data: The decoded JSON message from the client.
          """
          action = data.get('action', '')
          if action == ActionType.INIT:
              await self._initialize_agent(data)
              return
          # Deserialize from a copy so the caller's dict is not handed over.
          event = event_from_dict(data.copy())
          # This checks if the model supports images
          if isinstance(event, MessageAction) and event.images_urls:
              controller = self.agent_session.controller
              if controller and not controller.agent.llm.supports_vision():
                  await self.send_error(
                      'Model does not support image upload, change to a different model or try without an image.'
                  )
                  return
          self.agent_session.event_stream.add_event(event, EventSource.USER)

      async def send(self, data: dict[str, object]) -> bool:
          """Send a JSON payload to the client.

          Args:
              data: JSON-serializable payload.

          Returns:
              ``True`` if the payload was sent; ``False`` if there is no
              WebSocket, the session is not alive, or the connection turned out
              to be closed (in which case the session is marked not alive).
          """
          if self.websocket is None or not self.is_alive:
              return False
          try:
              await self.websocket.send_json(data)
              await asyncio.sleep(0.001)  # This flushes the data to the client
              self.last_active_ts = int(time.time())
              return True
          except RuntimeError:
              # Starlette raises RuntimeError when sending on a socket that has
              # already been closed; treat it like a disconnect instead of
              # letting it escape into whoever published the event.
              self.is_alive = False
              return False
          except WebSocketDisconnect:
              self.is_alive = False
              return False

      async def send_error(self, message: str) -> bool:
          """Sends an error message to the client."""
          return await self.send({'error': True, 'message': message})

      async def send_message(self, message: str) -> bool:
          """Sends a message to the client."""
          return await self.send({'message': message})

      def update_connection(self, ws: WebSocket):
          """Attach a new WebSocket, mark the session alive and refresh activity time."""
          self.websocket = ws
          self.is_alive = True
          self.last_active_ts = int(time.time())

      def load_from_data(self, data: dict) -> bool:
          """Load activity/liveness state from a data dict.

          ``last_active_ts`` is always overwritten from ``data`` (default 0).
          ``is_alive`` is only loaded when the data is recent enough.

          Args:
              data: Mapping with optional ``last_active_ts`` and ``is_alive`` keys.

          Returns:
              ``False`` if ``last_active_ts`` is more than ``DEL_DELT_SEC``
              seconds in the past, ``True`` otherwise.
          """
          self.last_active_ts = data.get('last_active_ts', 0)
          if self.last_active_ts < int(time.time()) - DEL_DELT_SEC:
              return False
          self.is_alive = data.get('is_alive', False)
          return True
  167. return True