session.py 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199
  1. import asyncio
  2. import time
  3. from fastapi import WebSocket, WebSocketDisconnect
  4. from openhands.controller.agent import Agent
  5. from openhands.core.config import AppConfig
  6. from openhands.core.const.guide_url import TROUBLESHOOTING_URL
  7. from openhands.core.logger import openhands_logger as logger
  8. from openhands.core.schema import AgentState
  9. from openhands.core.schema.action import ActionType
  10. from openhands.core.schema.config import ConfigType
  11. from openhands.events.action import ChangeAgentStateAction, MessageAction, NullAction
  12. from openhands.events.event import Event, EventSource
  13. from openhands.events.observation import (
  14. AgentStateChangedObservation,
  15. CmdOutputObservation,
  16. NullObservation,
  17. )
  18. from openhands.events.serialization import event_from_dict, event_to_dict
  19. from openhands.events.stream import EventStreamSubscriber
  20. from openhands.llm.llm import LLM
  21. from openhands.runtime.utils.shutdown_listener import should_continue
  22. from openhands.server.session.agent import AgentSession
  23. from openhands.storage.files import FileStore
  24. DEL_DELT_SEC = 60 * 60 * 5
  25. class Session:
  26. sid: str
  27. websocket: WebSocket | None
  28. last_active_ts: int = 0
  29. is_alive: bool = True
  30. agent_session: AgentSession
  31. def __init__(
  32. self, sid: str, ws: WebSocket | None, config: AppConfig, file_store: FileStore
  33. ):
  34. self.sid = sid
  35. self.websocket = ws
  36. self.last_active_ts = int(time.time())
  37. self.agent_session = AgentSession(sid, file_store)
  38. self.agent_session.event_stream.subscribe(
  39. EventStreamSubscriber.SERVER, self.on_event
  40. )
  41. self.config = config
  42. async def close(self):
  43. self.is_alive = False
  44. await self.agent_session.close()
  45. async def loop_recv(self):
  46. try:
  47. if self.websocket is None:
  48. return
  49. while should_continue():
  50. try:
  51. data = await self.websocket.receive_json()
  52. except ValueError:
  53. await self.send_error('Invalid JSON')
  54. continue
  55. await self.dispatch(data)
  56. except WebSocketDisconnect:
  57. await self.close()
  58. logger.info('WebSocket disconnected, sid: %s', self.sid)
  59. except RuntimeError as e:
  60. await self.close()
  61. logger.exception('Error in loop_recv: %s', e)
  62. async def _initialize_agent(self, data: dict):
  63. self.agent_session.event_stream.add_event(
  64. ChangeAgentStateAction(AgentState.LOADING), EventSource.USER
  65. )
  66. self.agent_session.event_stream.add_event(
  67. AgentStateChangedObservation('', AgentState.LOADING), EventSource.AGENT
  68. )
  69. # Extract the agent-relevant arguments from the request
  70. args = {
  71. key: value for key, value in data.get('args', {}).items() if value != ''
  72. }
  73. agent_cls = args.get(ConfigType.AGENT, self.config.default_agent)
  74. self.config.security.confirmation_mode = args.get(
  75. ConfigType.CONFIRMATION_MODE, self.config.security.confirmation_mode
  76. )
  77. self.config.security.security_analyzer = data.get('args', {}).get(
  78. ConfigType.SECURITY_ANALYZER, self.config.security.security_analyzer
  79. )
  80. max_iterations = args.get(ConfigType.MAX_ITERATIONS, self.config.max_iterations)
  81. # override default LLM config
  82. default_llm_config = self.config.get_llm_config()
  83. default_llm_config.model = args.get(
  84. ConfigType.LLM_MODEL, default_llm_config.model
  85. )
  86. default_llm_config.api_key = args.get(
  87. ConfigType.LLM_API_KEY, default_llm_config.api_key
  88. )
  89. default_llm_config.base_url = args.get(
  90. ConfigType.LLM_BASE_URL, default_llm_config.base_url
  91. )
  92. # TODO: override other LLM config & agent config groups (#2075)
  93. llm = LLM(config=self.config.get_llm_config_from_agent(agent_cls))
  94. agent_config = self.config.get_agent_config(agent_cls)
  95. agent = Agent.get_cls(agent_cls)(llm, agent_config)
  96. # Create the agent session
  97. try:
  98. await self.agent_session.start(
  99. runtime_name=self.config.runtime,
  100. config=self.config,
  101. agent=agent,
  102. max_iterations=max_iterations,
  103. max_budget_per_task=self.config.max_budget_per_task,
  104. agent_to_llm_config=self.config.get_agent_to_llm_config_map(),
  105. agent_configs=self.config.get_agent_configs(),
  106. )
  107. except Exception as e:
  108. logger.exception(f'Error creating controller: {e}')
  109. await self.send_error(
  110. f'Error creating controller. Please check Docker is running and visit `{TROUBLESHOOTING_URL}` for more debugging information..'
  111. )
  112. return
  113. self.agent_session.event_stream.add_event(
  114. ChangeAgentStateAction(AgentState.INIT), EventSource.USER
  115. )
  116. async def on_event(self, event: Event):
  117. """Callback function for agent events.
  118. Args:
  119. event: The agent event (Observation or Action).
  120. """
  121. if isinstance(event, NullAction):
  122. return
  123. if isinstance(event, NullObservation):
  124. return
  125. if event.source == EventSource.AGENT:
  126. logger.info('Server event')
  127. await self.send(event_to_dict(event))
  128. elif event.source == EventSource.USER and isinstance(
  129. event, CmdOutputObservation
  130. ):
  131. await self.send(event_to_dict(event))
  132. async def dispatch(self, data: dict):
  133. action = data.get('action', '')
  134. if action == ActionType.INIT:
  135. await self._initialize_agent(data)
  136. return
  137. event = event_from_dict(data.copy())
  138. # This checks if the model supports images
  139. if isinstance(event, MessageAction) and event.images_urls:
  140. controller = self.agent_session.controller
  141. if controller:
  142. if controller.agent.llm.config.disable_vision:
  143. await self.send_error(
  144. 'Support for images is disabled for this model, try without an image.'
  145. )
  146. return
  147. if not controller.agent.llm.vision_is_active():
  148. await self.send_error(
  149. 'Model does not support image upload, change to a different model or try without an image.'
  150. )
  151. return
  152. self.agent_session.event_stream.add_event(event, EventSource.USER)
  153. async def send(self, data: dict[str, object]) -> bool:
  154. try:
  155. if self.websocket is None or not self.is_alive:
  156. return False
  157. await self.websocket.send_json(data)
  158. await asyncio.sleep(0.001) # This flushes the data to the client
  159. self.last_active_ts = int(time.time())
  160. return True
  161. except WebSocketDisconnect:
  162. self.is_alive = False
  163. return False
  164. async def send_error(self, message: str) -> bool:
  165. """Sends an error message to the client."""
  166. return await self.send({'error': True, 'message': message})
  167. async def send_message(self, message: str) -> bool:
  168. """Sends a message to the client."""
  169. return await self.send({'message': message})
  170. def update_connection(self, ws: WebSocket):
  171. self.websocket = ws
  172. self.is_alive = True
  173. self.last_active_ts = int(time.time())
  174. def load_from_data(self, data: dict) -> bool:
  175. self.last_active_ts = data.get('last_active_ts', 0)
  176. if self.last_active_ts < int(time.time()) - DEL_DELT_SEC:
  177. return False
  178. self.is_alive = data.get('is_alive', False)
  179. return True