manager.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334
  1. import asyncio
  2. import json
  3. import time
  4. from dataclasses import dataclass, field
  5. import socketio
  6. from openhands.core.config import AppConfig
  7. from openhands.core.logger import openhands_logger as logger
  8. from openhands.events.stream import EventStream, session_exists
  9. from openhands.runtime.base import RuntimeUnavailableError
  10. from openhands.server.session.conversation import Conversation
  11. from openhands.server.session.session import ROOM_KEY, Session
  12. from openhands.server.session.session_init_data import SessionInitData
  13. from openhands.storage.files import FileStore
  14. from openhands.utils.shutdown_listener import should_continue
  15. _REDIS_POLL_TIMEOUT = 1.5
  16. _CHECK_ALIVE_INTERVAL = 15
  17. @dataclass
  18. class SessionManager:
  19. sio: socketio.AsyncServer
  20. config: AppConfig
  21. file_store: FileStore
  22. local_sessions_by_sid: dict[str, Session] = field(default_factory=dict)
  23. local_connection_id_to_session_id: dict[str, str] = field(default_factory=dict)
  24. _last_alive_timestamps: dict[str, float] = field(default_factory=dict)
  25. _redis_listen_task: asyncio.Task | None = None
  26. _session_is_running_flags: dict[str, asyncio.Event] = field(default_factory=dict)
  27. _has_remote_connections_flags: dict[str, asyncio.Event] = field(
  28. default_factory=dict
  29. )
  30. async def __aenter__(self):
  31. redis_client = self._get_redis_client()
  32. if redis_client:
  33. self._redis_listen_task = asyncio.create_task(self._redis_subscribe())
  34. return self
  35. async def __aexit__(self, exc_type, exc_value, traceback):
  36. if self._redis_listen_task:
  37. self._redis_listen_task.cancel()
  38. self._redis_listen_task = None
  39. def _get_redis_client(self):
  40. redis_client = getattr(self.sio.manager, 'redis', None)
  41. return redis_client
  42. async def _redis_subscribe(self):
  43. """
  44. We use a redis backchannel to send actions between server nodes
  45. """
  46. logger.debug('_redis_subscribe')
  47. redis_client = self._get_redis_client()
  48. pubsub = redis_client.pubsub()
  49. await pubsub.subscribe('oh_event')
  50. while should_continue():
  51. try:
  52. message = await pubsub.get_message(
  53. ignore_subscribe_messages=True, timeout=5
  54. )
  55. if message:
  56. await self._process_message(message)
  57. except asyncio.CancelledError:
  58. return
  59. except Exception:
  60. try:
  61. asyncio.get_running_loop()
  62. logger.warning(
  63. 'error_reading_from_redis', exc_info=True, stack_info=True
  64. )
  65. except RuntimeError:
  66. return # Loop has been shut down
  67. async def _process_message(self, message: dict):
  68. data = json.loads(message['data'])
  69. logger.debug(f'got_published_message:{message}')
  70. sid = data['sid']
  71. message_type = data['message_type']
  72. if message_type == 'event':
  73. session = self.local_sessions_by_sid.get(sid)
  74. if session:
  75. await session.dispatch(data['data'])
  76. elif message_type == 'is_session_running':
  77. # Another node in the cluster is asking if the current node is running the session given.
  78. session = self.local_sessions_by_sid.get(sid)
  79. if session:
  80. await self._get_redis_client().publish(
  81. 'oh_event',
  82. json.dumps({'sid': sid, 'message_type': 'session_is_running'}),
  83. )
  84. elif message_type == 'session_is_running':
  85. self._last_alive_timestamps[sid] = time.time()
  86. flag = self._session_is_running_flags.get(sid)
  87. if flag:
  88. flag.set()
  89. elif message_type == 'has_remote_connections_query':
  90. # Another node in the cluster is asking if the current node is connected to a session
  91. required = sid in self.local_connection_id_to_session_id.values()
  92. if required:
  93. await self._get_redis_client().publish(
  94. 'oh_event',
  95. json.dumps(
  96. {'sid': sid, 'message_type': 'has_remote_connections_response'}
  97. ),
  98. )
  99. elif message_type == 'has_remote_connections_response':
  100. flag = self._has_remote_connections_flags.get(sid)
  101. if flag:
  102. flag.set()
  103. elif message_type == 'session_closing':
  104. # Session closing event - We only get this in the event of graceful shutdown,
  105. # which can't be guaranteed - nodes can simply vanish unexpectedly!
  106. logger.debug(f'session_closing:{sid}')
  107. for (
  108. connection_id,
  109. local_sid,
  110. ) in self.local_connection_id_to_session_id.items():
  111. if sid == local_sid:
  112. logger.warning(
  113. 'local_connection_to_closing_session:{connection_id}:{sid}'
  114. )
  115. await self.sio.disconnect(connection_id)
  116. async def attach_to_conversation(self, sid: str) -> Conversation | None:
  117. start_time = time.time()
  118. if not await session_exists(sid, self.file_store):
  119. return None
  120. c = Conversation(sid, file_store=self.file_store, config=self.config)
  121. try:
  122. await c.connect()
  123. except RuntimeUnavailableError as e:
  124. logger.error(f'Error connecting to conversation {c.sid}: {e}')
  125. return None
  126. end_time = time.time()
  127. logger.info(
  128. f'Conversation {c.sid} connected in {end_time - start_time} seconds'
  129. )
  130. return c
  131. async def detach_from_conversation(self, conversation: Conversation):
  132. await conversation.disconnect()
  133. async def init_or_join_session(
  134. self, sid: str, connection_id: str, session_init_data: SessionInitData
  135. ):
  136. await self.sio.enter_room(connection_id, ROOM_KEY.format(sid=sid))
  137. self.local_connection_id_to_session_id[connection_id] = sid
  138. # If we have a local session running, use that
  139. session = self.local_sessions_by_sid.get(sid)
  140. if session:
  141. logger.info(f'found_local_session:{sid}')
  142. return session.agent_session.event_stream
  143. # If there is a remote session running, retrieve existing events for that
  144. redis_client = self._get_redis_client()
  145. if redis_client and await self._is_session_running_in_cluster(sid):
  146. return EventStream(sid, self.file_store)
  147. return await self.start_local_session(sid, session_init_data)
  148. async def _is_session_running_in_cluster(self, sid: str) -> bool:
  149. """As the rest of the cluster if a session is running. Wait a for a short timeout for a reply"""
  150. # Create a flag for the callback
  151. flag = asyncio.Event()
  152. self._session_is_running_flags[sid] = flag
  153. try:
  154. logger.debug(f'publish:is_session_running:{sid}')
  155. await self._get_redis_client().publish(
  156. 'oh_event',
  157. json.dumps(
  158. {
  159. 'sid': sid,
  160. 'message_type': 'is_session_running',
  161. }
  162. ),
  163. )
  164. async with asyncio.timeout(_REDIS_POLL_TIMEOUT):
  165. await flag.wait()
  166. result = flag.is_set()
  167. return result
  168. except TimeoutError:
  169. # Nobody replied in time
  170. return False
  171. finally:
  172. self._session_is_running_flags.pop(sid)
  173. async def _has_remote_connections(self, sid: str) -> bool:
  174. """As the rest of the cluster if they still want this session running. Wait a for a short timeout for a reply"""
  175. # Create a flag for the callback
  176. flag = asyncio.Event()
  177. self._has_remote_connections_flags[sid] = flag
  178. try:
  179. await self._get_redis_client().publish(
  180. 'oh_event',
  181. json.dumps(
  182. {
  183. 'sid': sid,
  184. 'message_type': 'has_remote_connections_query',
  185. }
  186. ),
  187. )
  188. async with asyncio.timeout(_REDIS_POLL_TIMEOUT):
  189. await flag.wait()
  190. result = flag.is_set()
  191. return result
  192. except TimeoutError:
  193. # Nobody replied in time
  194. return False
  195. finally:
  196. self._has_remote_connections_flags.pop(sid)
  197. async def start_local_session(self, sid: str, session_init_data: SessionInitData):
  198. # Start a new local session
  199. logger.info(f'start_new_local_session:{sid}')
  200. session = Session(
  201. sid=sid, file_store=self.file_store, config=self.config, sio=self.sio
  202. )
  203. self.local_sessions_by_sid[sid] = session
  204. await session.initialize_agent(session_init_data)
  205. return session.agent_session.event_stream
  206. async def send_to_event_stream(self, connection_id: str, data: dict):
  207. # If there is a local session running, send to that
  208. sid = self.local_connection_id_to_session_id.get(connection_id)
  209. if not sid:
  210. raise RuntimeError(f'no_connected_session:{connection_id}')
  211. session = self.local_sessions_by_sid.get(sid)
  212. if session:
  213. await session.dispatch(data)
  214. return
  215. redis_client = self._get_redis_client()
  216. if redis_client:
  217. # If we have a recent report that the session is alive in another pod
  218. last_alive_at = self._last_alive_timestamps.get(sid) or 0
  219. next_alive_check = last_alive_at + _CHECK_ALIVE_INTERVAL
  220. if next_alive_check > time.time() or self._is_session_running_in_cluster(
  221. sid
  222. ):
  223. # Send the event to the other pod
  224. await redis_client.publish(
  225. 'oh_event',
  226. json.dumps(
  227. {
  228. 'sid': sid,
  229. 'message_type': 'event',
  230. 'data': data,
  231. }
  232. ),
  233. )
  234. return
  235. raise RuntimeError(f'no_connected_session:{connection_id}:{sid}')
  236. async def disconnect_from_session(self, connection_id: str):
  237. sid = self.local_connection_id_to_session_id.pop(connection_id, None)
  238. if not sid:
  239. # This can occur if the init action was never run.
  240. logger.warning(f'disconnect_from_uninitialized_session:{connection_id}')
  241. return
  242. session = self.local_sessions_by_sid.get(sid)
  243. if session:
  244. logger.info(f'close_session:{connection_id}:{sid}')
  245. if should_continue():
  246. asyncio.create_task(self._cleanup_session_later(session))
  247. else:
  248. await self._close_session(session)
  249. async def _cleanup_session_later(self, session: Session):
  250. # Once there have been no connections to a session for a reasonable period, we close it
  251. try:
  252. await asyncio.sleep(self.config.sandbox.close_delay)
  253. finally:
  254. # If the sleep was cancelled, we still want to close these
  255. await self._cleanup_session(session)
  256. async def _cleanup_session(self, session: Session):
  257. # Get local connections
  258. has_local_connections = next(
  259. (
  260. True
  261. for v in self.local_connection_id_to_session_id.values()
  262. if v == session.sid
  263. ),
  264. False,
  265. )
  266. if has_local_connections:
  267. return False
  268. # If no local connections, get connections through redis
  269. redis_client = self._get_redis_client()
  270. if redis_client and await self._has_remote_connections(session.sid):
  271. return False
  272. # We alert the cluster in case they are interested
  273. if redis_client:
  274. await redis_client.publish(
  275. 'oh_event',
  276. json.dumps({'sid': session.sid, 'message_type': 'session_closing'}),
  277. )
  278. await self._close_session(session)
  279. async def _close_session(self, session: Session):
  280. logger.info(f'_close_session:{session.sid}')
  281. # Clear up local variables
  282. connection_ids_to_remove = list(
  283. connection_id
  284. for connection_id, sid in self.local_connection_id_to_session_id.items()
  285. if sid == session.sid
  286. )
  287. for connnnection_id in connection_ids_to_remove:
  288. self.local_connection_id_to_session_id.pop(connnnection_id, None)
  289. self.local_sessions_by_sid.pop(session.sid, None)
  290. # We alert the cluster in case they are interested
  291. redis_client = self._get_redis_client()
  292. if redis_client:
  293. await redis_client.publish(
  294. 'oh_event',
  295. json.dumps({'sid': session.sid, 'message_type': 'session_closing'}),
  296. )
  297. session.close()