# manager.py
  1. import asyncio
  2. import json
  3. import time
  4. from dataclasses import dataclass, field
  5. import socketio
  6. from openhands.core.config import AppConfig
  7. from openhands.core.logger import openhands_logger as logger
  8. from openhands.events.stream import EventStream, session_exists
  9. from openhands.runtime.base import RuntimeUnavailableError
  10. from openhands.server.session.conversation import Conversation
  11. from openhands.server.session.session import ROOM_KEY, Session
  12. from openhands.server.session.session_init_data import SessionInitData
  13. from openhands.storage.files import FileStore
  14. from openhands.utils.shutdown_listener import should_continue
# How long (seconds) to wait for another cluster node to answer a redis
# backchannel query before giving up.
_REDIS_POLL_TIMEOUT = 1.5
# How long (seconds) a 'session_is_running' report from another node is
# trusted before the cluster is queried again.
_CHECK_ALIVE_INTERVAL = 15
  17. @dataclass
  18. class SessionManager:
  19. sio: socketio.AsyncServer
  20. config: AppConfig
  21. file_store: FileStore
  22. local_sessions_by_sid: dict[str, Session] = field(default_factory=dict)
  23. local_connection_id_to_session_id: dict[str, str] = field(default_factory=dict)
  24. _last_alive_timestamps: dict[str, float] = field(default_factory=dict)
  25. _redis_listen_task: asyncio.Task | None = None
  26. _session_is_running_flags: dict[str, asyncio.Event] = field(default_factory=dict)
  27. _active_conversations: dict[str, tuple[Conversation, int]] = field(
  28. default_factory=dict
  29. )
  30. _detached_conversations: dict[str, tuple[Conversation, float]] = field(
  31. default_factory=dict
  32. )
  33. _conversations_lock: asyncio.Lock = field(default_factory=asyncio.Lock)
  34. _cleanup_task: asyncio.Task | None = None
  35. _has_remote_connections_flags: dict[str, asyncio.Event] = field(
  36. default_factory=dict
  37. )
  38. async def __aenter__(self):
  39. redis_client = self._get_redis_client()
  40. if redis_client:
  41. self._redis_listen_task = asyncio.create_task(self._redis_subscribe())
  42. self._cleanup_task = asyncio.create_task(self._cleanup_detached_conversations())
  43. return self
  44. async def __aexit__(self, exc_type, exc_value, traceback):
  45. if self._redis_listen_task:
  46. self._redis_listen_task.cancel()
  47. self._redis_listen_task = None
  48. if self._cleanup_task:
  49. self._cleanup_task.cancel()
  50. self._cleanup_task = None
  51. def _get_redis_client(self):
  52. redis_client = getattr(self.sio.manager, 'redis', None)
  53. return redis_client
  54. async def _redis_subscribe(self):
  55. """
  56. We use a redis backchannel to send actions between server nodes
  57. """
  58. logger.debug('_redis_subscribe')
  59. redis_client = self._get_redis_client()
  60. pubsub = redis_client.pubsub()
  61. await pubsub.subscribe('oh_event')
  62. while should_continue():
  63. try:
  64. message = await pubsub.get_message(
  65. ignore_subscribe_messages=True, timeout=5
  66. )
  67. if message:
  68. await self._process_message(message)
  69. except asyncio.CancelledError:
  70. return
  71. except Exception:
  72. try:
  73. asyncio.get_running_loop()
  74. logger.warning(
  75. 'error_reading_from_redis', exc_info=True, stack_info=True
  76. )
  77. except RuntimeError:
  78. return # Loop has been shut down
  79. async def _process_message(self, message: dict):
  80. data = json.loads(message['data'])
  81. logger.debug(f'got_published_message:{message}')
  82. sid = data['sid']
  83. message_type = data['message_type']
  84. if message_type == 'event':
  85. session = self.local_sessions_by_sid.get(sid)
  86. if session:
  87. await session.dispatch(data['data'])
  88. elif message_type == 'is_session_running':
  89. # Another node in the cluster is asking if the current node is running the session given.
  90. session = self.local_sessions_by_sid.get(sid)
  91. if session:
  92. await self._get_redis_client().publish(
  93. 'oh_event',
  94. json.dumps({'sid': sid, 'message_type': 'session_is_running'}),
  95. )
  96. elif message_type == 'session_is_running':
  97. self._last_alive_timestamps[sid] = time.time()
  98. flag = self._session_is_running_flags.get(sid)
  99. if flag:
  100. flag.set()
  101. elif message_type == 'has_remote_connections_query':
  102. # Another node in the cluster is asking if the current node is connected to a session
  103. required = sid in self.local_connection_id_to_session_id.values()
  104. if required:
  105. await self._get_redis_client().publish(
  106. 'oh_event',
  107. json.dumps(
  108. {'sid': sid, 'message_type': 'has_remote_connections_response'}
  109. ),
  110. )
  111. elif message_type == 'has_remote_connections_response':
  112. flag = self._has_remote_connections_flags.get(sid)
  113. if flag:
  114. flag.set()
  115. elif message_type == 'session_closing':
  116. # Session closing event - We only get this in the event of graceful shutdown,
  117. # which can't be guaranteed - nodes can simply vanish unexpectedly!
  118. logger.debug(f'session_closing:{sid}')
  119. for (
  120. connection_id,
  121. local_sid,
  122. ) in self.local_connection_id_to_session_id.items():
  123. if sid == local_sid:
  124. logger.warning(
  125. 'local_connection_to_closing_session:{connection_id}:{sid}'
  126. )
  127. await self.sio.disconnect(connection_id)
  128. async def attach_to_conversation(self, sid: str) -> Conversation | None:
  129. start_time = time.time()
  130. if not await session_exists(sid, self.file_store):
  131. return None
  132. async with self._conversations_lock:
  133. # Check if we have an active conversation we can reuse
  134. if sid in self._active_conversations:
  135. conversation, count = self._active_conversations[sid]
  136. self._active_conversations[sid] = (conversation, count + 1)
  137. logger.info(f'Reusing active conversation {sid}')
  138. return conversation
  139. # Check if we have a detached conversation we can reuse
  140. if sid in self._detached_conversations:
  141. conversation, _ = self._detached_conversations.pop(sid)
  142. self._active_conversations[sid] = (conversation, 1)
  143. logger.info(f'Reusing detached conversation {sid}')
  144. return conversation
  145. # Create new conversation if none exists
  146. c = Conversation(sid, file_store=self.file_store, config=self.config)
  147. try:
  148. await c.connect()
  149. except RuntimeUnavailableError as e:
  150. logger.error(f'Error connecting to conversation {c.sid}: {e}')
  151. return None
  152. end_time = time.time()
  153. logger.info(
  154. f'Conversation {c.sid} connected in {end_time - start_time} seconds'
  155. )
  156. self._active_conversations[sid] = (c, 1)
  157. return c
  158. async def detach_from_conversation(self, conversation: Conversation):
  159. sid = conversation.sid
  160. async with self._conversations_lock:
  161. if sid in self._active_conversations:
  162. conv, count = self._active_conversations[sid]
  163. if count > 1:
  164. self._active_conversations[sid] = (conv, count - 1)
  165. return
  166. else:
  167. self._active_conversations.pop(sid)
  168. self._detached_conversations[sid] = (conversation, time.time())
  169. async def _cleanup_detached_conversations(self):
  170. while should_continue():
  171. try:
  172. async with self._conversations_lock:
  173. # Create a list of items to process to avoid modifying dict during iteration
  174. items = list(self._detached_conversations.items())
  175. for sid, (conversation, detach_time) in items:
  176. await conversation.disconnect()
  177. self._detached_conversations.pop(sid, None)
  178. await asyncio.sleep(60)
  179. except asyncio.CancelledError:
  180. async with self._conversations_lock:
  181. for conversation, _ in self._detached_conversations.values():
  182. await conversation.disconnect()
  183. self._detached_conversations.clear()
  184. return
  185. except Exception:
  186. logger.warning('error_cleaning_detached_conversations', exc_info=True)
  187. await asyncio.sleep(15)
  188. async def init_or_join_session(
  189. self, sid: str, connection_id: str, session_init_data: SessionInitData
  190. ):
  191. await self.sio.enter_room(connection_id, ROOM_KEY.format(sid=sid))
  192. self.local_connection_id_to_session_id[connection_id] = sid
  193. # If we have a local session running, use that
  194. session = self.local_sessions_by_sid.get(sid)
  195. if session:
  196. logger.info(f'found_local_session:{sid}')
  197. return session.agent_session.event_stream
  198. # If there is a remote session running, retrieve existing events for that
  199. redis_client = self._get_redis_client()
  200. if redis_client and await self._is_session_running_in_cluster(sid):
  201. return EventStream(sid, self.file_store)
  202. return await self.start_local_session(sid, session_init_data)
  203. async def _is_session_running_in_cluster(self, sid: str) -> bool:
  204. """As the rest of the cluster if a session is running. Wait a for a short timeout for a reply"""
  205. # Create a flag for the callback
  206. flag = asyncio.Event()
  207. self._session_is_running_flags[sid] = flag
  208. try:
  209. logger.debug(f'publish:is_session_running:{sid}')
  210. await self._get_redis_client().publish(
  211. 'oh_event',
  212. json.dumps(
  213. {
  214. 'sid': sid,
  215. 'message_type': 'is_session_running',
  216. }
  217. ),
  218. )
  219. async with asyncio.timeout(_REDIS_POLL_TIMEOUT):
  220. await flag.wait()
  221. result = flag.is_set()
  222. return result
  223. except TimeoutError:
  224. # Nobody replied in time
  225. return False
  226. finally:
  227. self._session_is_running_flags.pop(sid)
  228. async def _has_remote_connections(self, sid: str) -> bool:
  229. """As the rest of the cluster if they still want this session running. Wait a for a short timeout for a reply"""
  230. # Create a flag for the callback
  231. flag = asyncio.Event()
  232. self._has_remote_connections_flags[sid] = flag
  233. try:
  234. await self._get_redis_client().publish(
  235. 'oh_event',
  236. json.dumps(
  237. {
  238. 'sid': sid,
  239. 'message_type': 'has_remote_connections_query',
  240. }
  241. ),
  242. )
  243. async with asyncio.timeout(_REDIS_POLL_TIMEOUT):
  244. await flag.wait()
  245. result = flag.is_set()
  246. return result
  247. except TimeoutError:
  248. # Nobody replied in time
  249. return False
  250. finally:
  251. self._has_remote_connections_flags.pop(sid)
  252. async def start_local_session(self, sid: str, session_init_data: SessionInitData):
  253. # Start a new local session
  254. logger.info(f'start_new_local_session:{sid}')
  255. session = Session(
  256. sid=sid, file_store=self.file_store, config=self.config, sio=self.sio
  257. )
  258. self.local_sessions_by_sid[sid] = session
  259. await session.initialize_agent(session_init_data)
  260. return session.agent_session.event_stream
  261. async def send_to_event_stream(self, connection_id: str, data: dict):
  262. # If there is a local session running, send to that
  263. sid = self.local_connection_id_to_session_id.get(connection_id)
  264. if not sid:
  265. raise RuntimeError(f'no_connected_session:{connection_id}')
  266. session = self.local_sessions_by_sid.get(sid)
  267. if session:
  268. await session.dispatch(data)
  269. return
  270. redis_client = self._get_redis_client()
  271. if redis_client:
  272. # If we have a recent report that the session is alive in another pod
  273. last_alive_at = self._last_alive_timestamps.get(sid) or 0
  274. next_alive_check = last_alive_at + _CHECK_ALIVE_INTERVAL
  275. if next_alive_check > time.time() or self._is_session_running_in_cluster(
  276. sid
  277. ):
  278. # Send the event to the other pod
  279. await redis_client.publish(
  280. 'oh_event',
  281. json.dumps(
  282. {
  283. 'sid': sid,
  284. 'message_type': 'event',
  285. 'data': data,
  286. }
  287. ),
  288. )
  289. return
  290. raise RuntimeError(f'no_connected_session:{connection_id}:{sid}')
  291. async def disconnect_from_session(self, connection_id: str):
  292. sid = self.local_connection_id_to_session_id.pop(connection_id, None)
  293. if not sid:
  294. # This can occur if the init action was never run.
  295. logger.warning(f'disconnect_from_uninitialized_session:{connection_id}')
  296. return
  297. session = self.local_sessions_by_sid.get(sid)
  298. if session:
  299. logger.info(f'close_session:{connection_id}:{sid}')
  300. if should_continue():
  301. asyncio.create_task(self._cleanup_session_later(session))
  302. else:
  303. await self._close_session(session)
  304. async def _cleanup_session_later(self, session: Session):
  305. # Once there have been no connections to a session for a reasonable period, we close it
  306. try:
  307. await asyncio.sleep(self.config.sandbox.close_delay)
  308. finally:
  309. # If the sleep was cancelled, we still want to close these
  310. await self._cleanup_session(session)
  311. async def _cleanup_session(self, session: Session):
  312. # Get local connections
  313. has_local_connections = next(
  314. (
  315. True
  316. for v in self.local_connection_id_to_session_id.values()
  317. if v == session.sid
  318. ),
  319. False,
  320. )
  321. if has_local_connections:
  322. return False
  323. # If no local connections, get connections through redis
  324. redis_client = self._get_redis_client()
  325. if redis_client and await self._has_remote_connections(session.sid):
  326. return False
  327. # We alert the cluster in case they are interested
  328. if redis_client:
  329. await redis_client.publish(
  330. 'oh_event',
  331. json.dumps({'sid': session.sid, 'message_type': 'session_closing'}),
  332. )
  333. await self._close_session(session)
  334. return True
  335. async def _close_session(self, session: Session):
  336. logger.info(f'_close_session:{session.sid}')
  337. # Clear up local variables
  338. connection_ids_to_remove = list(
  339. connection_id
  340. for connection_id, sid in self.local_connection_id_to_session_id.items()
  341. if sid == session.sid
  342. )
  343. for connnnection_id in connection_ids_to_remove:
  344. self.local_connection_id_to_session_id.pop(connnnection_id, None)
  345. self.local_sessions_by_sid.pop(session.sid, None)
  346. # We alert the cluster in case they are interested
  347. redis_client = self._get_redis_client()
  348. if redis_client:
  349. await redis_client.publish(
  350. 'oh_event',
  351. json.dumps({'sid': session.sid, 'message_type': 'session_closing'}),
  352. )
  353. session.close()