manager.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330
  1. import asyncio
  2. import json
  3. import time
  4. from dataclasses import dataclass, field
  5. import socketio
  6. from openhands.core.config import AppConfig
  7. from openhands.core.logger import openhands_logger as logger
  8. from openhands.events.stream import EventStream, session_exists
  9. from openhands.runtime.base import RuntimeUnavailableError
  10. from openhands.server.session.conversation import Conversation
  11. from openhands.server.session.session import ROOM_KEY, Session
  12. from openhands.server.session.session_init_data import SessionInitData
  13. from openhands.storage.files import FileStore
  14. from openhands.utils.shutdown_listener import should_continue
  15. _REDIS_POLL_TIMEOUT = 1.5
  16. _CHECK_ALIVE_INTERVAL = 15
  17. @dataclass
  18. class SessionManager:
  19. sio: socketio.AsyncServer
  20. config: AppConfig
  21. file_store: FileStore
  22. local_sessions_by_sid: dict[str, Session] = field(default_factory=dict)
  23. local_connection_id_to_session_id: dict[str, str] = field(default_factory=dict)
  24. _last_alive_timestamps: dict[str, float] = field(default_factory=dict)
  25. _redis_listen_task: asyncio.Task | None = None
  26. _session_is_running_flags: dict[str, asyncio.Event] = field(default_factory=dict)
  27. _has_remote_connections_flags: dict[str, asyncio.Event] = field(
  28. default_factory=dict
  29. )
  30. async def __aenter__(self):
  31. redis_client = self._get_redis_client()
  32. if redis_client:
  33. self._redis_listen_task = asyncio.create_task(self._redis_subscribe())
  34. return self
  35. async def __aexit__(self, exc_type, exc_value, traceback):
  36. if self._redis_listen_task:
  37. self._redis_listen_task.cancel()
  38. self._redis_listen_task = None
  39. def _get_redis_client(self):
  40. redis_client = getattr(self.sio.manager, 'redis', None)
  41. return redis_client
  42. async def _redis_subscribe(self):
  43. """
  44. We use a redis backchannel to send actions between server nodes
  45. """
  46. redis_client = self._get_redis_client()
  47. pubsub = redis_client.pubsub()
  48. await pubsub.subscribe('oh_event')
  49. while should_continue():
  50. try:
  51. message = await pubsub.get_message(
  52. ignore_subscribe_messages=True, timeout=5
  53. )
  54. if message:
  55. await self._process_message(message)
  56. except asyncio.CancelledError:
  57. return
  58. except Exception:
  59. try:
  60. asyncio.get_running_loop()
  61. logger.warning(
  62. 'error_reading_from_redis', exc_info=True, stack_info=True
  63. )
  64. except RuntimeError:
  65. return # Loop has been shut down
  66. async def _process_message(self, message: dict):
  67. data = json.loads(message['data'])
  68. logger.info(f'got_published_message:{message}')
  69. sid = data['sid']
  70. message_type = data['message_type']
  71. if message_type == 'event':
  72. session = self.local_sessions_by_sid.get(sid)
  73. if session:
  74. await session.dispatch(data['data'])
  75. elif message_type == 'is_session_running':
  76. # Another node in the cluster is asking if the current node is running the session given.
  77. session = self.local_sessions_by_sid.get(sid)
  78. if session:
  79. await self._get_redis_client().publish(
  80. 'oh_event',
  81. json.dumps({'sid': sid, 'message_type': 'session_is_running'}),
  82. )
  83. elif message_type == 'session_is_running':
  84. self._last_alive_timestamps[sid] = time.time()
  85. flag = self._session_is_running_flags.get(sid)
  86. if flag:
  87. flag.set()
  88. elif message_type == 'has_remote_connections_query':
  89. # Another node in the cluster is asking if the current node is connected to a session
  90. required = sid in self.local_connection_id_to_session_id.values()
  91. if required:
  92. await self._get_redis_client().publish(
  93. 'oh_event',
  94. json.dumps(
  95. {'sid': sid, 'message_type': 'has_remote_connections_response'}
  96. ),
  97. )
  98. elif message_type == 'has_remote_connections_response':
  99. flag = self._has_remote_connections_flags.get(sid)
  100. if flag:
  101. flag.set()
  102. elif message_type == 'session_closing':
  103. # Session closing event - We only get this in the event of graceful shutdown,
  104. # which can't be guaranteed - nodes can simply vanish unexpectedly!
  105. logger.info(f'session_closing:{sid}')
  106. for (
  107. connection_id,
  108. local_sid,
  109. ) in self.local_connection_id_to_session_id.items():
  110. if sid == local_sid:
  111. logger.warning(
  112. 'local_connection_to_closing_session:{connection_id}:{sid}'
  113. )
  114. await self.sio.disconnect(connection_id)
  115. async def attach_to_conversation(self, sid: str) -> Conversation | None:
  116. start_time = time.time()
  117. if not await session_exists(sid, self.file_store):
  118. return None
  119. c = Conversation(sid, file_store=self.file_store, config=self.config)
  120. try:
  121. await c.connect()
  122. except RuntimeUnavailableError as e:
  123. logger.error(f'Error connecting to conversation {c.sid}: {e}')
  124. return None
  125. end_time = time.time()
  126. logger.info(
  127. f'Conversation {c.sid} connected in {end_time - start_time} seconds'
  128. )
  129. return c
  130. async def detach_from_conversation(self, conversation: Conversation):
  131. await conversation.disconnect()
  132. async def init_or_join_session(self, sid: str, connection_id: str, session_init_data: SessionInitData):
  133. await self.sio.enter_room(connection_id, ROOM_KEY.format(sid=sid))
  134. self.local_connection_id_to_session_id[connection_id] = sid
  135. # If we have a local session running, use that
  136. session = self.local_sessions_by_sid.get(sid)
  137. if session:
  138. logger.info(f'found_local_session:{sid}')
  139. return session.agent_session.event_stream
  140. # If there is a remote session running, retrieve existing events for that
  141. redis_client = self._get_redis_client()
  142. if redis_client and await self._is_session_running_in_cluster(sid):
  143. return EventStream(sid, self.file_store)
  144. return await self.start_local_session(sid, session_init_data)
  145. async def _is_session_running_in_cluster(self, sid: str) -> bool:
  146. """As the rest of the cluster if a session is running. Wait a for a short timeout for a reply"""
  147. # Create a flag for the callback
  148. flag = asyncio.Event()
  149. self._session_is_running_flags[sid] = flag
  150. try:
  151. await self._get_redis_client().publish(
  152. 'oh_event',
  153. json.dumps(
  154. {
  155. 'sid': sid,
  156. 'message_type': 'is_session_running',
  157. }
  158. ),
  159. )
  160. async with asyncio.timeout(_REDIS_POLL_TIMEOUT):
  161. await flag.wait()
  162. result = flag.is_set()
  163. return result
  164. except TimeoutError:
  165. # Nobody replied in time
  166. return False
  167. finally:
  168. self._session_is_running_flags.pop(sid)
  169. async def _has_remote_connections(self, sid: str) -> bool:
  170. """As the rest of the cluster if they still want this session running. Wait a for a short timeout for a reply"""
  171. # Create a flag for the callback
  172. flag = asyncio.Event()
  173. self._has_remote_connections_flags[sid] = flag
  174. try:
  175. await self._get_redis_client().publish(
  176. 'oh_event',
  177. json.dumps(
  178. {
  179. 'sid': sid,
  180. 'message_type': 'has_remote_connections_query',
  181. }
  182. ),
  183. )
  184. async with asyncio.timeout(_REDIS_POLL_TIMEOUT):
  185. await flag.wait()
  186. result = flag.is_set()
  187. return result
  188. except TimeoutError:
  189. # Nobody replied in time
  190. return False
  191. finally:
  192. self._has_remote_connections_flags.pop(sid)
  193. async def start_local_session(self, sid: str, session_init_data: SessionInitData):
  194. # Start a new local session
  195. logger.info(f'start_new_local_session:{sid}')
  196. session = Session(
  197. sid=sid, file_store=self.file_store, config=self.config, sio=self.sio
  198. )
  199. self.local_sessions_by_sid[sid] = session
  200. await session.initialize_agent(session_init_data)
  201. return session.agent_session.event_stream
  202. async def send_to_event_stream(self, connection_id: str, data: dict):
  203. # If there is a local session running, send to that
  204. sid = self.local_connection_id_to_session_id.get(connection_id)
  205. if not sid:
  206. raise RuntimeError(f'no_connected_session:{connection_id}')
  207. session = self.local_sessions_by_sid.get(sid)
  208. if session:
  209. await session.dispatch(data)
  210. return
  211. redis_client = self._get_redis_client()
  212. if redis_client:
  213. # If we have a recent report that the session is alive in another pod
  214. last_alive_at = self._last_alive_timestamps.get(sid) or 0
  215. next_alive_check = last_alive_at + _CHECK_ALIVE_INTERVAL
  216. if next_alive_check > time.time() or self._is_session_running_in_cluster(
  217. sid
  218. ):
  219. # Send the event to the other pod
  220. await redis_client.publish(
  221. 'oh_event',
  222. json.dumps(
  223. {
  224. 'sid': sid,
  225. 'message_type': 'event',
  226. 'data': data,
  227. }
  228. ),
  229. )
  230. return
  231. raise RuntimeError(f'no_connected_session:{connection_id}:{sid}')
  232. async def disconnect_from_session(self, connection_id: str):
  233. sid = self.local_connection_id_to_session_id.pop(connection_id, None)
  234. if not sid:
  235. # This can occur if the init action was never run.
  236. logger.warning(f'disconnect_from_uninitialized_session:{connection_id}')
  237. return
  238. session = self.local_sessions_by_sid.get(sid)
  239. if session:
  240. logger.info(f'close_session:{connection_id}:{sid}')
  241. if should_continue():
  242. asyncio.create_task(self._cleanup_session_later(session))
  243. else:
  244. await self._close_session(session)
  245. async def _cleanup_session_later(self, session: Session):
  246. # Once there have been no connections to a session for a reasonable period, we close it
  247. try:
  248. await asyncio.sleep(self.config.sandbox.close_delay)
  249. finally:
  250. # If the sleep was cancelled, we still want to close these
  251. await self._cleanup_session(session)
  252. async def _cleanup_session(self, session: Session):
  253. # Get local connections
  254. has_local_connections = next(
  255. (
  256. True
  257. for v in self.local_connection_id_to_session_id.values()
  258. if v == session.sid
  259. ),
  260. False,
  261. )
  262. if has_local_connections:
  263. return False
  264. # If no local connections, get connections through redis
  265. redis_client = self._get_redis_client()
  266. if redis_client and await self._has_remote_connections(session.sid):
  267. return False
  268. # We alert the cluster in case they are interested
  269. if redis_client:
  270. await redis_client.publish(
  271. 'oh_event',
  272. json.dumps({'sid': session.sid, 'message_type': 'session_closing'}),
  273. )
  274. await self._close_session(session)
  275. async def _close_session(self, session: Session):
  276. logger.info(f'_close_session:{session.sid}')
  277. # Clear up local variables
  278. connection_ids_to_remove = list(
  279. connection_id
  280. for connection_id, sid in self.local_connection_id_to_session_id.items()
  281. if sid == session.sid
  282. )
  283. for connnnection_id in connection_ids_to_remove:
  284. self.local_connection_id_to_session_id.pop(connnnection_id, None)
  285. self.local_sessions_by_sid.pop(session.sid, None)
  286. # We alert the cluster in case they are interested
  287. redis_client = self._get_redis_client()
  288. if redis_client:
  289. await redis_client.publish(
  290. 'oh_event',
  291. json.dumps({'sid': session.sid, 'message_type': 'session_closing'}),
  292. )
  293. session.close()