manager.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426
  1. import asyncio
  2. import json
  3. import time
  4. from dataclasses import dataclass, field
  5. import socketio
  6. from openhands.core.config import AppConfig
  7. from openhands.core.exceptions import AgentRuntimeUnavailableError
  8. from openhands.core.logger import openhands_logger as logger
  9. from openhands.events.stream import EventStream, session_exists
  10. from openhands.server.session.conversation import Conversation
  11. from openhands.server.session.conversation_init_data import ConversationInitData
  12. from openhands.server.session.session import ROOM_KEY, Session
  13. from openhands.storage.files import FileStore
  14. from openhands.utils.async_utils import call_sync_from_async
  15. from openhands.utils.shutdown_listener import should_continue
# Seconds to wait for another node to answer a redis query
# (used as the asyncio.timeout in the cluster-wide checks).
_REDIS_POLL_TIMEOUT = 1.5
# Seconds a 'session_is_running' report is trusted before
# send_to_event_stream asks the cluster again.
_CHECK_ALIVE_INTERVAL = 15
# Seconds slept between passes of the detached-conversation cleanup loop.
_CLEANUP_INTERVAL = 15
# Seconds slept after a cleanup pass raised, before the loop tries again.
_CLEANUP_EXCEPTION_WAIT_TIME = 15
class ConversationDoesNotExistError(Exception):
    """Error type for a conversation that cannot be found.

    NOTE(review): nothing in this part of the file raises it; presumably it
    is raised or caught by callers elsewhere - confirm before relying on it.
    """

    pass
@dataclass
class SessionManager:
    """Tracks the agent sessions and conversations handled by this server node.

    When the socket.io manager exposes a redis client, the 'oh_event' pub/sub
    channel is used to exchange messages with other nodes; without one, only
    local state is consulted. Use as an async context manager so the
    background tasks are started and cancelled.
    """

    # Socket.io server: used for rooms, disconnects and to find the redis client.
    sio: socketio.AsyncServer
    # Application config handed to new Session / Conversation objects.
    config: AppConfig
    # Storage handed to sessions, conversations and event streams.
    file_store: FileStore
    # Sessions whose agent loop was started on this node, keyed by sid.
    _local_agent_loops_by_sid: dict[str, Session] = field(default_factory=dict)
    # Connection id -> sid for connections that joined through this node.
    local_connection_id_to_session_id: dict[str, str] = field(default_factory=dict)
    # sid -> time.time() of the last 'session_is_running' message received.
    _last_alive_timestamps: dict[str, float] = field(default_factory=dict)
    # Task running _redis_subscribe, if a redis client was available.
    _redis_listen_task: asyncio.Task | None = None
    # sid -> event set when a 'session_is_running' reply arrives.
    _session_is_running_flags: dict[str, asyncio.Event] = field(default_factory=dict)
    # sid -> (conversation, number of current attachments).
    _active_conversations: dict[str, tuple[Conversation, int]] = field(
        default_factory=dict
    )
    # sid -> (conversation, time.time() at which it was detached).
    _detached_conversations: dict[str, tuple[Conversation, float]] = field(
        default_factory=dict
    )
    # Held while the two conversation dicts above are read or modified.
    _conversations_lock: asyncio.Lock = field(default_factory=asyncio.Lock)
    # Task running _cleanup_detached_conversations.
    _cleanup_task: asyncio.Task | None = None
    # sid -> event set when a 'has_remote_connections_response' arrives.
    _has_remote_connections_flags: dict[str, asyncio.Event] = field(
        default_factory=dict
    )
  43. async def __aenter__(self):
  44. redis_client = self._get_redis_client()
  45. if redis_client:
  46. self._redis_listen_task = asyncio.create_task(self._redis_subscribe())
  47. self._cleanup_task = asyncio.create_task(self._cleanup_detached_conversations())
  48. return self
  49. async def __aexit__(self, exc_type, exc_value, traceback):
  50. if self._redis_listen_task:
  51. self._redis_listen_task.cancel()
  52. self._redis_listen_task = None
  53. if self._cleanup_task:
  54. self._cleanup_task.cancel()
  55. self._cleanup_task = None
  56. def _get_redis_client(self):
  57. redis_client = getattr(self.sio.manager, 'redis', None)
  58. return redis_client
  59. async def _redis_subscribe(self):
  60. """
  61. We use a redis backchannel to send actions between server nodes
  62. """
  63. logger.debug('_redis_subscribe')
  64. redis_client = self._get_redis_client()
  65. pubsub = redis_client.pubsub()
  66. await pubsub.subscribe('oh_event')
  67. while should_continue():
  68. try:
  69. message = await pubsub.get_message(
  70. ignore_subscribe_messages=True, timeout=5
  71. )
  72. if message:
  73. await self._process_message(message)
  74. except asyncio.CancelledError:
  75. return
  76. except Exception:
  77. try:
  78. asyncio.get_running_loop()
  79. logger.warning(
  80. 'error_reading_from_redis', exc_info=True, stack_info=True
  81. )
  82. except RuntimeError:
  83. return # Loop has been shut down
  84. async def _process_message(self, message: dict):
  85. data = json.loads(message['data'])
  86. logger.debug(f'got_published_message:{message}')
  87. sid = data['sid']
  88. message_type = data['message_type']
  89. if message_type == 'event':
  90. session = self._local_agent_loops_by_sid.get(sid)
  91. if session:
  92. await session.dispatch(data['data'])
  93. elif message_type == 'is_session_running':
  94. # Another node in the cluster is asking if the current node is running the session given.
  95. session = self._local_agent_loops_by_sid.get(sid)
  96. if session:
  97. await self._get_redis_client().publish(
  98. 'oh_event',
  99. json.dumps({'sid': sid, 'message_type': 'session_is_running'}),
  100. )
  101. elif message_type == 'session_is_running':
  102. self._last_alive_timestamps[sid] = time.time()
  103. flag = self._session_is_running_flags.get(sid)
  104. if flag:
  105. flag.set()
  106. elif message_type == 'has_remote_connections_query':
  107. # Another node in the cluster is asking if the current node is connected to a session
  108. required = sid in self.local_connection_id_to_session_id.values()
  109. if required:
  110. await self._get_redis_client().publish(
  111. 'oh_event',
  112. json.dumps(
  113. {'sid': sid, 'message_type': 'has_remote_connections_response'}
  114. ),
  115. )
  116. elif message_type == 'has_remote_connections_response':
  117. flag = self._has_remote_connections_flags.get(sid)
  118. if flag:
  119. flag.set()
  120. elif message_type == 'session_closing':
  121. # Session closing event - We only get this in the event of graceful shutdown,
  122. # which can't be guaranteed - nodes can simply vanish unexpectedly!
  123. logger.debug(f'session_closing:{sid}')
  124. for (
  125. connection_id,
  126. local_sid,
  127. ) in self.local_connection_id_to_session_id.items():
  128. if sid == local_sid:
  129. logger.warning(
  130. 'local_connection_to_closing_session:{connection_id}:{sid}'
  131. )
  132. await self.sio.disconnect(connection_id)
  133. async def attach_to_conversation(self, sid: str) -> Conversation | None:
  134. start_time = time.time()
  135. if not await session_exists(sid, self.file_store):
  136. return None
  137. async with self._conversations_lock:
  138. # Check if we have an active conversation we can reuse
  139. if sid in self._active_conversations:
  140. conversation, count = self._active_conversations[sid]
  141. self._active_conversations[sid] = (conversation, count + 1)
  142. logger.info(f'Reusing active conversation {sid}')
  143. return conversation
  144. # Check if we have a detached conversation we can reuse
  145. if sid in self._detached_conversations:
  146. conversation, _ = self._detached_conversations.pop(sid)
  147. self._active_conversations[sid] = (conversation, 1)
  148. logger.info(f'Reusing detached conversation {sid}')
  149. return conversation
  150. # Create new conversation if none exists
  151. c = Conversation(sid, file_store=self.file_store, config=self.config)
  152. try:
  153. await c.connect()
  154. except AgentRuntimeUnavailableError as e:
  155. logger.error(f'Error connecting to conversation {c.sid}: {e}')
  156. return None
  157. end_time = time.time()
  158. logger.info(
  159. f'Conversation {c.sid} connected in {end_time - start_time} seconds'
  160. )
  161. self._active_conversations[sid] = (c, 1)
  162. return c
  163. async def join_conversation(self, sid: str, connection_id: str) -> EventStream:
  164. await self.sio.enter_room(connection_id, ROOM_KEY.format(sid=sid))
  165. self.local_connection_id_to_session_id[connection_id] = sid
  166. # If we have a local session running, use that
  167. session = self._local_agent_loops_by_sid.get(sid)
  168. if session:
  169. logger.info(f'found_local_session:{sid}')
  170. return session.agent_session.event_stream
  171. if await self._is_agent_loop_running_in_cluster(sid):
  172. return EventStream(sid, self.file_store)
  173. return await self.maybe_start_agent_loop(sid)
  174. async def detach_from_conversation(self, conversation: Conversation):
  175. sid = conversation.sid
  176. async with self._conversations_lock:
  177. if sid in self._active_conversations:
  178. conv, count = self._active_conversations[sid]
  179. if count > 1:
  180. self._active_conversations[sid] = (conv, count - 1)
  181. return
  182. else:
  183. self._active_conversations.pop(sid)
  184. self._detached_conversations[sid] = (conversation, time.time())
  185. async def _cleanup_detached_conversations(self):
  186. while should_continue():
  187. logger.info(f'Attached conversations: {len(self._active_conversations)}')
  188. logger.info(f'Detached conversations: {len(self._detached_conversations)}')
  189. try:
  190. async with self._conversations_lock:
  191. # Create a list of items to process to avoid modifying dict during iteration
  192. items = list(self._detached_conversations.items())
  193. for sid, (conversation, detach_time) in items:
  194. await conversation.disconnect()
  195. self._detached_conversations.pop(sid, None)
  196. await asyncio.sleep(_CLEANUP_INTERVAL)
  197. except asyncio.CancelledError:
  198. async with self._conversations_lock:
  199. for conversation, _ in self._detached_conversations.values():
  200. await conversation.disconnect()
  201. self._detached_conversations.clear()
  202. return
  203. except Exception:
  204. logger.warning('error_cleaning_detached_conversations', exc_info=True)
  205. await asyncio.sleep(_CLEANUP_EXCEPTION_WAIT_TIME)
  206. async def _is_agent_loop_running(self, sid: str) -> bool:
  207. if await self._is_agent_loop_running_locally(sid):
  208. return True
  209. if await self._is_agent_loop_running_in_cluster(sid):
  210. return True
  211. return False
  212. async def _is_agent_loop_running_locally(self, sid: str) -> bool:
  213. if self._local_agent_loops_by_sid.get(sid, None):
  214. return True
  215. return False
  216. async def _is_agent_loop_running_in_cluster(self, sid: str) -> bool:
  217. """As the rest of the cluster if a session is running. Wait a for a short timeout for a reply"""
  218. redis_client = self._get_redis_client()
  219. if not redis_client:
  220. return False
  221. flag = asyncio.Event()
  222. self._session_is_running_flags[sid] = flag
  223. try:
  224. logger.debug(f'publish:is_session_running:{sid}')
  225. await redis_client.publish(
  226. 'oh_event',
  227. json.dumps(
  228. {
  229. 'sid': sid,
  230. 'message_type': 'is_session_running',
  231. }
  232. ),
  233. )
  234. async with asyncio.timeout(_REDIS_POLL_TIMEOUT):
  235. await flag.wait()
  236. result = flag.is_set()
  237. return result
  238. except TimeoutError:
  239. # Nobody replied in time
  240. return False
  241. finally:
  242. self._session_is_running_flags.pop(sid, None)
  243. async def _has_remote_connections(self, sid: str) -> bool:
  244. """As the rest of the cluster if they still want this session running. Wait a for a short timeout for a reply"""
  245. # Create a flag for the callback
  246. flag = asyncio.Event()
  247. self._has_remote_connections_flags[sid] = flag
  248. try:
  249. await self._get_redis_client().publish(
  250. 'oh_event',
  251. json.dumps(
  252. {
  253. 'sid': sid,
  254. 'message_type': 'has_remote_connections_query',
  255. }
  256. ),
  257. )
  258. async with asyncio.timeout(_REDIS_POLL_TIMEOUT):
  259. await flag.wait()
  260. result = flag.is_set()
  261. return result
  262. except TimeoutError:
  263. # Nobody replied in time
  264. return False
  265. finally:
  266. self._has_remote_connections_flags.pop(sid, None)
    async def maybe_start_agent_loop(
        self, sid: str, conversation_init_data: ConversationInitData | None = None
    ) -> EventStream:
        """Make sure a local session exists for sid and return its event stream.

        Args:
            sid: session id.
            conversation_init_data: passed to Session.initialize_agent when
                the agent is initialized here.

        Returns:
            The event stream of the local session's agent session.

        Raises:
            RuntimeError: if no local session is registered for sid afterwards.
        """
        logger.info(f'maybe_start_agent_loop:{sid}')
        session: Session | None = None
        if not await self._is_agent_loop_running_locally(sid):
            session = Session(
                sid=sid, file_store=self.file_store, config=self.config, sio=self.sio
            )
            # The session is registered before the cluster is asked.
            # NOTE(review): _process_message answers 'is_session_running' for
            # any locally registered sid, so if this node receives its own
            # query the check below would report the session as running and
            # initialize_agent would be skipped - confirm pub/sub delivery.
            self._local_agent_loops_by_sid[sid] = session
            if not await self._is_agent_loop_running_in_cluster(sid):
                logger.info(f'start_agent_loop:{sid}')
                await session.initialize_agent(conversation_init_data)
        # Re-read from the registry so an already-running local session is used too.
        session = self._local_agent_loops_by_sid.get(sid)
        if session is not None:
            return session.agent_session.event_stream
        raise RuntimeError(f'no_session:{sid}')
  284. async def send_to_event_stream(self, connection_id: str, data: dict):
  285. # If there is a local session running, send to that
  286. sid = self.local_connection_id_to_session_id.get(connection_id)
  287. if not sid:
  288. raise RuntimeError(f'no_connected_session:{connection_id}')
  289. session = self._local_agent_loops_by_sid.get(sid)
  290. if session:
  291. await session.dispatch(data)
  292. return
  293. redis_client = self._get_redis_client()
  294. if redis_client:
  295. # If we have a recent report that the session is alive in another pod
  296. last_alive_at = self._last_alive_timestamps.get(sid) or 0
  297. next_alive_check = last_alive_at + _CHECK_ALIVE_INTERVAL
  298. if next_alive_check > time.time() or self._is_agent_loop_running_in_cluster(
  299. sid
  300. ):
  301. # Send the event to the other pod
  302. await redis_client.publish(
  303. 'oh_event',
  304. json.dumps(
  305. {
  306. 'sid': sid,
  307. 'message_type': 'event',
  308. 'data': data,
  309. }
  310. ),
  311. )
  312. return
  313. raise RuntimeError(f'no_connected_session:{connection_id}:{sid}')
  314. async def disconnect_from_session(self, connection_id: str):
  315. sid = self.local_connection_id_to_session_id.pop(connection_id, None)
  316. logger.info(f'disconnect_from_session:{connection_id}:{sid}')
  317. if not sid:
  318. # This can occur if the init action was never run.
  319. logger.warning(f'disconnect_from_uninitialized_session:{connection_id}')
  320. return
  321. if should_continue():
  322. asyncio.create_task(self._cleanup_session_later(sid))
  323. else:
  324. await self._close_session(sid)
  325. async def _cleanup_session_later(self, sid: str):
  326. # Once there have been no connections to a session for a reasonable period, we close it
  327. try:
  328. await asyncio.sleep(self.config.sandbox.close_delay)
  329. finally:
  330. # If the sleep was cancelled, we still want to close these
  331. await self._cleanup_session(sid)
  332. async def _cleanup_session(self, sid: str) -> bool:
  333. # Get local connections
  334. logger.info(f'_cleanup_session:{sid}')
  335. has_local_connections = next(
  336. (True for v in self.local_connection_id_to_session_id.values() if v == sid),
  337. False,
  338. )
  339. if has_local_connections:
  340. return False
  341. # If no local connections, get connections through redis
  342. redis_client = self._get_redis_client()
  343. if redis_client and await self._has_remote_connections(sid):
  344. return False
  345. # We alert the cluster in case they are interested
  346. if redis_client:
  347. await redis_client.publish(
  348. 'oh_event',
  349. json.dumps({'sid': sid, 'message_type': 'session_closing'}),
  350. )
  351. await self._close_session(sid)
  352. return True
  353. async def _close_session(self, sid: str):
  354. logger.info(f'_close_session:{sid}')
  355. # Clear up local variables
  356. connection_ids_to_remove = list(
  357. connection_id
  358. for connection_id, sid in self.local_connection_id_to_session_id.items()
  359. if sid == sid
  360. )
  361. logger.info(f'removing connections: {connection_ids_to_remove}')
  362. for connnnection_id in connection_ids_to_remove:
  363. self.local_connection_id_to_session_id.pop(connnnection_id, None)
  364. session = self._local_agent_loops_by_sid.pop(sid, None)
  365. if not session:
  366. logger.warning(f'no_session_to_close:{sid}')
  367. return
  368. logger.info(f'closing_session:{session.sid}')
  369. # We alert the cluster in case they are interested
  370. redis_client = self._get_redis_client()
  371. if redis_client:
  372. await redis_client.publish(
  373. 'oh_event',
  374. json.dumps({'sid': session.sid, 'message_type': 'session_closing'}),
  375. )
  376. await call_sync_from_async(session.close)
  377. logger.info(f'closed_session:{session.sid}')