manager.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329
  1. import asyncio
  2. import json
  3. import time
  4. from dataclasses import dataclass, field
  5. import socketio
  6. from openhands.core.config import AppConfig
  7. from openhands.core.logger import openhands_logger as logger
  8. from openhands.events.stream import EventStream, session_exists
  9. from openhands.runtime.base import RuntimeUnavailableError
  10. from openhands.server.session.conversation import Conversation
  11. from openhands.server.session.session import ROOM_KEY, Session
  12. from openhands.storage.files import FileStore
  13. from openhands.utils.shutdown_listener import should_continue
  14. _REDIS_POLL_TIMEOUT = 1.5
  15. _CHECK_ALIVE_INTERVAL = 15
  16. @dataclass
  17. class SessionManager:
  18. sio: socketio.AsyncServer
  19. config: AppConfig
  20. file_store: FileStore
  21. local_sessions_by_sid: dict[str, Session] = field(default_factory=dict)
  22. local_connection_id_to_session_id: dict[str, str] = field(default_factory=dict)
  23. _last_alive_timestamps: dict[str, float] = field(default_factory=dict)
  24. _redis_listen_task: asyncio.Task | None = None
  25. _session_is_running_flags: dict[str, asyncio.Event] = field(default_factory=dict)
  26. _has_remote_connections_flags: dict[str, asyncio.Event] = field(
  27. default_factory=dict
  28. )
  29. async def __aenter__(self):
  30. redis_client = self._get_redis_client()
  31. if redis_client:
  32. self._redis_listen_task = asyncio.create_task(self._redis_subscribe())
  33. return self
  34. async def __aexit__(self, exc_type, exc_value, traceback):
  35. if self._redis_listen_task:
  36. self._redis_listen_task.cancel()
  37. self._redis_listen_task = None
  38. def _get_redis_client(self):
  39. redis_client = getattr(self.sio.manager, 'redis', None)
  40. return redis_client
  41. async def _redis_subscribe(self):
  42. """
  43. We use a redis backchannel to send actions between server nodes
  44. """
  45. redis_client = self._get_redis_client()
  46. pubsub = redis_client.pubsub()
  47. await pubsub.subscribe('oh_event')
  48. while should_continue():
  49. try:
  50. message = await pubsub.get_message(
  51. ignore_subscribe_messages=True, timeout=5
  52. )
  53. if message:
  54. await self._process_message(message)
  55. except asyncio.CancelledError:
  56. return
  57. except Exception:
  58. try:
  59. asyncio.get_running_loop()
  60. logger.warning(
  61. 'error_reading_from_redis', exc_info=True, stack_info=True
  62. )
  63. except RuntimeError:
  64. return # Loop has been shut down
  65. async def _process_message(self, message: dict):
  66. data = json.loads(message['data'])
  67. logger.info(f'got_published_message:{message}')
  68. sid = data['sid']
  69. message_type = data['message_type']
  70. if message_type == 'event':
  71. session = self.local_sessions_by_sid.get(sid)
  72. if session:
  73. await session.dispatch(data['data'])
  74. elif message_type == 'is_session_running':
  75. # Another node in the cluster is asking if the current node is running the session given.
  76. session = self.local_sessions_by_sid.get(sid)
  77. if session:
  78. await self._get_redis_client().publish(
  79. 'oh_event',
  80. json.dumps({'sid': sid, 'message_type': 'session_is_running'}),
  81. )
  82. elif message_type == 'session_is_running':
  83. self._last_alive_timestamps[sid] = time.time()
  84. flag = self._session_is_running_flags.get(sid)
  85. if flag:
  86. flag.set()
  87. elif message_type == 'has_remote_connections_query':
  88. # Another node in the cluster is asking if the current node is connected to a session
  89. required = sid in self.local_connection_id_to_session_id.values()
  90. if required:
  91. await self._get_redis_client().publish(
  92. 'oh_event',
  93. json.dumps(
  94. {'sid': sid, 'message_type': 'has_remote_connections_response'}
  95. ),
  96. )
  97. elif message_type == 'has_remote_connections_response':
  98. flag = self._has_remote_connections_flags.get(sid)
  99. if flag:
  100. flag.set()
  101. elif message_type == 'session_closing':
  102. # Session closing event - We only get this in the event of graceful shutdown,
  103. # which can't be guaranteed - nodes can simply vanish unexpectedly!
  104. logger.info(f'session_closing:{sid}')
  105. for (
  106. connection_id,
  107. local_sid,
  108. ) in self.local_connection_id_to_session_id.items():
  109. if sid == local_sid:
  110. logger.warning(
  111. 'local_connection_to_closing_session:{connection_id}:{sid}'
  112. )
  113. await self.sio.disconnect(connection_id)
  114. async def attach_to_conversation(self, sid: str) -> Conversation | None:
  115. start_time = time.time()
  116. if not await session_exists(sid, self.file_store):
  117. return None
  118. c = Conversation(sid, file_store=self.file_store, config=self.config)
  119. try:
  120. await c.connect()
  121. except RuntimeUnavailableError as e:
  122. logger.error(f'Error connecting to conversation {c.sid}: {e}')
  123. return None
  124. end_time = time.time()
  125. logger.info(
  126. f'Conversation {c.sid} connected in {end_time - start_time} seconds'
  127. )
  128. return c
  129. async def detach_from_conversation(self, conversation: Conversation):
  130. await conversation.disconnect()
  131. async def init_or_join_session(self, sid: str, connection_id: str, data: dict):
  132. await self.sio.enter_room(connection_id, ROOM_KEY.format(sid=sid))
  133. self.local_connection_id_to_session_id[connection_id] = sid
  134. # If we have a local session running, use that
  135. session = self.local_sessions_by_sid.get(sid)
  136. if session:
  137. logger.info(f'found_local_session:{sid}')
  138. return session.agent_session.event_stream
  139. # If there is a remote session running, retrieve existing events for that
  140. redis_client = self._get_redis_client()
  141. if redis_client and await self._is_session_running_in_cluster(sid):
  142. return EventStream(sid, self.file_store)
  143. return await self.start_local_session(sid, data)
  144. async def _is_session_running_in_cluster(self, sid: str) -> bool:
  145. """As the rest of the cluster if a session is running. Wait a for a short timeout for a reply"""
  146. # Create a flag for the callback
  147. flag = asyncio.Event()
  148. self._session_is_running_flags[sid] = flag
  149. try:
  150. await self._get_redis_client().publish(
  151. 'oh_event',
  152. json.dumps(
  153. {
  154. 'sid': sid,
  155. 'message_type': 'is_session_running',
  156. }
  157. ),
  158. )
  159. async with asyncio.timeout(_REDIS_POLL_TIMEOUT):
  160. await flag.wait()
  161. result = flag.is_set()
  162. return result
  163. except TimeoutError:
  164. # Nobody replied in time
  165. return False
  166. finally:
  167. self._session_is_running_flags.pop(sid)
  168. async def _has_remote_connections(self, sid: str) -> bool:
  169. """As the rest of the cluster if they still want this session running. Wait a for a short timeout for a reply"""
  170. # Create a flag for the callback
  171. flag = asyncio.Event()
  172. self._has_remote_connections_flags[sid] = flag
  173. try:
  174. await self._get_redis_client().publish(
  175. 'oh_event',
  176. json.dumps(
  177. {
  178. 'sid': sid,
  179. 'message_type': 'has_remote_connections_query',
  180. }
  181. ),
  182. )
  183. async with asyncio.timeout(_REDIS_POLL_TIMEOUT):
  184. await flag.wait()
  185. result = flag.is_set()
  186. return result
  187. except TimeoutError:
  188. # Nobody replied in time
  189. return False
  190. finally:
  191. self._has_remote_connections_flags.pop(sid)
  192. async def start_local_session(self, sid: str, data: dict):
  193. # Start a new local session
  194. logger.info(f'start_new_local_session:{sid}')
  195. session = Session(
  196. sid=sid, file_store=self.file_store, config=self.config, sio=self.sio
  197. )
  198. self.local_sessions_by_sid[sid] = session
  199. await session.initialize_agent(data)
  200. return session.agent_session.event_stream
  201. async def send_to_event_stream(self, connection_id: str, data: dict):
  202. # If there is a local session running, send to that
  203. sid = self.local_connection_id_to_session_id.get(connection_id)
  204. if not sid:
  205. raise RuntimeError(f'no_connected_session:{connection_id}')
  206. session = self.local_sessions_by_sid.get(sid)
  207. if session:
  208. await session.dispatch(data)
  209. return
  210. redis_client = self._get_redis_client()
  211. if redis_client:
  212. # If we have a recent report that the session is alive in another pod
  213. last_alive_at = self._last_alive_timestamps.get(sid) or 0
  214. next_alive_check = last_alive_at + _CHECK_ALIVE_INTERVAL
  215. if next_alive_check > time.time() or self._is_session_running_in_cluster(
  216. sid
  217. ):
  218. # Send the event to the other pod
  219. await redis_client.publish(
  220. 'oh_event',
  221. json.dumps(
  222. {
  223. 'sid': sid,
  224. 'message_type': 'event',
  225. 'data': data,
  226. }
  227. ),
  228. )
  229. return
  230. raise RuntimeError(f'no_connected_session:{connection_id}:{sid}')
  231. async def disconnect_from_session(self, connection_id: str):
  232. sid = self.local_connection_id_to_session_id.pop(connection_id, None)
  233. if not sid:
  234. # This can occur if the init action was never run.
  235. logger.warning(f'disconnect_from_uninitialized_session:{connection_id}')
  236. return
  237. session = self.local_sessions_by_sid.get(sid)
  238. if session:
  239. logger.info(f'close_session:{connection_id}:{sid}')
  240. if should_continue():
  241. asyncio.create_task(self._cleanup_session_later(session))
  242. else:
  243. await self._close_session(session)
  244. async def _cleanup_session_later(self, session: Session):
  245. # Once there have been no connections to a session for a reasonable period, we close it
  246. try:
  247. await asyncio.sleep(self.config.sandbox.close_delay)
  248. finally:
  249. # If the sleep was cancelled, we still want to close these
  250. await self._cleanup_session(session)
  251. async def _cleanup_session(self, session: Session):
  252. # Get local connections
  253. has_local_connections = next(
  254. (
  255. True
  256. for v in self.local_connection_id_to_session_id.values()
  257. if v == session.sid
  258. ),
  259. False,
  260. )
  261. if has_local_connections:
  262. return False
  263. # If no local connections, get connections through redis
  264. redis_client = self._get_redis_client()
  265. if redis_client and await self._has_remote_connections(session.sid):
  266. return False
  267. # We alert the cluster in case they are interested
  268. if redis_client:
  269. await redis_client.publish(
  270. 'oh_event',
  271. json.dumps({'sid': session.sid, 'message_type': 'session_closing'}),
  272. )
  273. await self._close_session(session)
  274. async def _close_session(self, session: Session):
  275. logger.info(f'_close_session:{session.sid}')
  276. # Clear up local variables
  277. connection_ids_to_remove = list(
  278. connection_id
  279. for connection_id, sid in self.local_connection_id_to_session_id.items()
  280. if sid == session.sid
  281. )
  282. for connnnection_id in connection_ids_to_remove:
  283. self.local_connection_id_to_session_id.pop(connnnection_id, None)
  284. self.local_sessions_by_sid.pop(session.sid, None)
  285. # We alert the cluster in case they are interested
  286. redis_client = self._get_redis_client()
  287. if redis_client:
  288. await redis_client.publish(
  289. 'oh_event',
  290. json.dumps({'sid': session.sid, 'message_type': 'session_closing'}),
  291. )
  292. session.close()