session.py 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217
  1. import asyncio
  2. import time
  3. from fastapi import WebSocket, WebSocketDisconnect
  4. from openhands.controller.agent import Agent
  5. from openhands.core.config import AppConfig
  6. from openhands.core.const.guide_url import TROUBLESHOOTING_URL
  7. from openhands.core.logger import openhands_logger as logger
  8. from openhands.core.schema import AgentState
  9. from openhands.core.schema.action import ActionType
  10. from openhands.core.schema.config import ConfigType
  11. from openhands.events.action import ChangeAgentStateAction, MessageAction, NullAction
  12. from openhands.events.event import Event, EventSource
  13. from openhands.events.observation import (
  14. AgentStateChangedObservation,
  15. CmdOutputObservation,
  16. NullObservation,
  17. )
  18. from openhands.events.observation.error import ErrorObservation
  19. from openhands.events.serialization import event_from_dict, event_to_dict
  20. from openhands.events.stream import EventStreamSubscriber
  21. from openhands.llm.llm import LLM
  22. from openhands.server.session.agent_session import AgentSession
  23. from openhands.storage.files import FileStore
  24. from openhands.utils.shutdown_listener import should_continue
  25. class Session:
  26. sid: str
  27. websocket: WebSocket | None
  28. last_active_ts: int = 0
  29. is_alive: bool = True
  30. agent_session: AgentSession
  31. loop: asyncio.AbstractEventLoop
  32. def __init__(
  33. self, sid: str, ws: WebSocket | None, config: AppConfig, file_store: FileStore
  34. ):
  35. self.sid = sid
  36. self.websocket = ws
  37. self.last_active_ts = int(time.time())
  38. self.agent_session = AgentSession(
  39. sid, file_store, status_callback=self.queue_status_message
  40. )
  41. self.agent_session.event_stream.subscribe(
  42. EventStreamSubscriber.SERVER, self.on_event, self.sid
  43. )
  44. self.config = config
  45. self.loop = asyncio.get_event_loop()
  46. def close(self):
  47. self.is_alive = False
  48. try:
  49. if self.websocket is not None:
  50. asyncio.run_coroutine_threadsafe(self.websocket.close(), self.loop)
  51. self.websocket = None
  52. finally:
  53. self.agent_session.close()
  54. del (
  55. self.agent_session
  56. ) # FIXME: this should not be necessary but it mitigates a memory leak
  57. async def loop_recv(self):
  58. try:
  59. if self.websocket is None:
  60. return
  61. while should_continue():
  62. try:
  63. data = await self.websocket.receive_json()
  64. except ValueError:
  65. await self.send_error('Invalid JSON')
  66. continue
  67. await self.dispatch(data)
  68. except WebSocketDisconnect:
  69. logger.info('WebSocket disconnected, sid: %s', self.sid)
  70. self.close()
  71. except RuntimeError as e:
  72. logger.exception('Error in loop_recv: %s', e)
  73. self.close()
  74. async def _initialize_agent(self, data: dict):
  75. self.agent_session.event_stream.add_event(
  76. ChangeAgentStateAction(AgentState.LOADING), EventSource.ENVIRONMENT
  77. )
  78. self.agent_session.event_stream.add_event(
  79. AgentStateChangedObservation('', AgentState.LOADING),
  80. EventSource.ENVIRONMENT,
  81. )
  82. # Extract the agent-relevant arguments from the request
  83. args = {key: value for key, value in data.get('args', {}).items()}
  84. agent_cls = args.get(ConfigType.AGENT, self.config.default_agent)
  85. self.config.security.confirmation_mode = args.get(
  86. ConfigType.CONFIRMATION_MODE, self.config.security.confirmation_mode
  87. )
  88. self.config.security.security_analyzer = data.get('args', {}).get(
  89. ConfigType.SECURITY_ANALYZER, self.config.security.security_analyzer
  90. )
  91. max_iterations = args.get(ConfigType.MAX_ITERATIONS, self.config.max_iterations)
  92. # override default LLM config
  93. default_llm_config = self.config.get_llm_config()
  94. default_llm_config.model = args.get(
  95. ConfigType.LLM_MODEL, default_llm_config.model
  96. )
  97. default_llm_config.api_key = args.get(
  98. ConfigType.LLM_API_KEY, default_llm_config.api_key
  99. )
  100. default_llm_config.base_url = args.get(
  101. ConfigType.LLM_BASE_URL, default_llm_config.base_url
  102. )
  103. # TODO: override other LLM config & agent config groups (#2075)
  104. llm = LLM(config=self.config.get_llm_config_from_agent(agent_cls))
  105. agent_config = self.config.get_agent_config(agent_cls)
  106. agent = Agent.get_cls(agent_cls)(llm, agent_config)
  107. try:
  108. await self.agent_session.start(
  109. runtime_name=self.config.runtime,
  110. config=self.config,
  111. agent=agent,
  112. max_iterations=max_iterations,
  113. max_budget_per_task=self.config.max_budget_per_task,
  114. agent_to_llm_config=self.config.get_agent_to_llm_config_map(),
  115. agent_configs=self.config.get_agent_configs(),
  116. )
  117. except Exception as e:
  118. logger.exception(f'Error creating controller: {e}')
  119. await self.send_error(
  120. f'Error creating controller. Please check Docker is running and visit `{TROUBLESHOOTING_URL}` for more debugging information..'
  121. )
  122. return
  123. async def on_event(self, event: Event):
  124. """Callback function for events that mainly come from the agent.
  125. Event is the base class for any agent action and observation.
  126. Args:
  127. event: The agent event (Observation or Action).
  128. """
  129. if isinstance(event, NullAction):
  130. return
  131. if isinstance(event, NullObservation):
  132. return
  133. if event.source == EventSource.AGENT:
  134. await self.send(event_to_dict(event))
  135. elif event.source == EventSource.USER and isinstance(
  136. event, CmdOutputObservation
  137. ):
  138. await self.send(event_to_dict(event))
  139. # NOTE: ipython observations are not sent here currently
  140. elif event.source == EventSource.ENVIRONMENT and isinstance(
  141. event, (CmdOutputObservation, AgentStateChangedObservation)
  142. ):
  143. # feedback from the environment to agent actions is understood as agent events by the UI
  144. event_dict = event_to_dict(event)
  145. event_dict['source'] = EventSource.AGENT
  146. await self.send(event_dict)
  147. elif isinstance(event, ErrorObservation):
  148. # send error events as agent events to the UI
  149. event_dict = event_to_dict(event)
  150. event_dict['source'] = EventSource.AGENT
  151. await self.send(event_dict)
  152. async def dispatch(self, data: dict):
  153. action = data.get('action', '')
  154. if action == ActionType.INIT:
  155. await self._initialize_agent(data)
  156. return
  157. event = event_from_dict(data.copy())
  158. # This checks if the model supports images
  159. if isinstance(event, MessageAction) and event.image_urls:
  160. controller = self.agent_session.controller
  161. if controller:
  162. if controller.agent.llm.config.disable_vision:
  163. await self.send_error(
  164. 'Support for images is disabled for this model, try without an image.'
  165. )
  166. return
  167. if not controller.agent.llm.vision_is_active():
  168. await self.send_error(
  169. 'Model does not support image upload, change to a different model or try without an image.'
  170. )
  171. return
  172. self.agent_session.event_stream.add_event(event, EventSource.USER)
  173. async def send(self, data: dict[str, object]) -> bool:
  174. try:
  175. if self.websocket is None or not self.is_alive:
  176. return False
  177. await self.websocket.send_json(data)
  178. await asyncio.sleep(0.001) # This flushes the data to the client
  179. self.last_active_ts = int(time.time())
  180. return True
  181. except (RuntimeError, WebSocketDisconnect):
  182. self.is_alive = False
  183. return False
  184. async def send_error(self, message: str) -> bool:
  185. """Sends an error message to the client."""
  186. return await self.send({'error': True, 'message': message})
  187. async def _send_status_message(self, msg_type: str, id: str, message: str) -> bool:
  188. """Sends a status message to the client."""
  189. if msg_type == 'error':
  190. await self.agent_session.stop_agent_loop_for_error()
  191. return await self.send(
  192. {'status_update': True, 'type': msg_type, 'id': id, 'message': message}
  193. )
  194. def queue_status_message(self, msg_type: str, id: str, message: str):
  195. """Queues a status message to be sent asynchronously."""
  196. asyncio.run_coroutine_threadsafe(
  197. self._send_status_message(msg_type, id, message), self.loop
  198. )