manager.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446
  1. import asyncio
  2. import json
  3. import time
  4. from dataclasses import dataclass, field
  5. import socketio
  6. from openhands.core.config import AppConfig
  7. from openhands.core.exceptions import AgentRuntimeUnavailableError
  8. from openhands.core.logger import openhands_logger as logger
  9. from openhands.events.stream import EventStream, session_exists
  10. from openhands.server.session.conversation import Conversation
  11. from openhands.server.session.conversation_init_data import ConversationInitData
  12. from openhands.server.session.session import ROOM_KEY, Session
  13. from openhands.storage.files import FileStore
  14. from openhands.utils.async_utils import call_sync_from_async
  15. from openhands.utils.shutdown_listener import should_continue
  16. _REDIS_POLL_TIMEOUT = 1.5
  17. _CHECK_ALIVE_INTERVAL = 15
  18. _CLEANUP_INTERVAL = 15
  19. _CLEANUP_EXCEPTION_WAIT_TIME = 15
  20. class ConversationDoesNotExistError(Exception):
  21. pass
  22. @dataclass
  23. class SessionManager:
  24. sio: socketio.AsyncServer
  25. config: AppConfig
  26. file_store: FileStore
  27. _local_agent_loops_by_sid: dict[str, Session] = field(default_factory=dict)
  28. local_connection_id_to_session_id: dict[str, str] = field(default_factory=dict)
  29. _last_alive_timestamps: dict[str, float] = field(default_factory=dict)
  30. _redis_listen_task: asyncio.Task | None = None
  31. _session_is_running_flags: dict[str, asyncio.Event] = field(default_factory=dict)
  32. _active_conversations: dict[str, tuple[Conversation, int]] = field(
  33. default_factory=dict
  34. )
  35. _detached_conversations: dict[str, tuple[Conversation, float]] = field(
  36. default_factory=dict
  37. )
  38. _conversations_lock: asyncio.Lock = field(default_factory=asyncio.Lock)
  39. _cleanup_task: asyncio.Task | None = None
  40. _has_remote_connections_flags: dict[str, asyncio.Event] = field(
  41. default_factory=dict
  42. )
  43. async def __aenter__(self):
  44. redis_client = self._get_redis_client()
  45. if redis_client:
  46. self._redis_listen_task = asyncio.create_task(self._redis_subscribe())
  47. self._cleanup_task = asyncio.create_task(self._cleanup_detached_conversations())
  48. return self
  49. async def __aexit__(self, exc_type, exc_value, traceback):
  50. if self._redis_listen_task:
  51. self._redis_listen_task.cancel()
  52. self._redis_listen_task = None
  53. if self._cleanup_task:
  54. self._cleanup_task.cancel()
  55. self._cleanup_task = None
  56. def _get_redis_client(self):
  57. redis_client = getattr(self.sio.manager, 'redis', None)
  58. return redis_client
  59. async def _redis_subscribe(self):
  60. """
  61. We use a redis backchannel to send actions between server nodes
  62. """
  63. logger.debug('_redis_subscribe')
  64. redis_client = self._get_redis_client()
  65. pubsub = redis_client.pubsub()
  66. await pubsub.subscribe('oh_event')
  67. while should_continue():
  68. try:
  69. message = await pubsub.get_message(
  70. ignore_subscribe_messages=True, timeout=5
  71. )
  72. if message:
  73. await self._process_message(message)
  74. except asyncio.CancelledError:
  75. return
  76. except Exception:
  77. try:
  78. asyncio.get_running_loop()
  79. logger.warning(
  80. 'error_reading_from_redis', exc_info=True, stack_info=True
  81. )
  82. except RuntimeError:
  83. return # Loop has been shut down
  84. async def _process_message(self, message: dict):
  85. data = json.loads(message['data'])
  86. logger.debug(f'got_published_message:{message}')
  87. sid = data['sid']
  88. message_type = data['message_type']
  89. if message_type == 'event':
  90. session = self._local_agent_loops_by_sid.get(sid)
  91. if session:
  92. await session.dispatch(data['data'])
  93. elif message_type == 'is_session_running':
  94. # Another node in the cluster is asking if the current node is running the session given.
  95. session = self._local_agent_loops_by_sid.get(sid)
  96. if session:
  97. await self._get_redis_client().publish(
  98. 'oh_event',
  99. json.dumps({'sid': sid, 'message_type': 'session_is_running'}),
  100. )
  101. elif message_type == 'session_is_running':
  102. self._last_alive_timestamps[sid] = time.time()
  103. flag = self._session_is_running_flags.get(sid)
  104. if flag:
  105. flag.set()
  106. elif message_type == 'has_remote_connections_query':
  107. # Another node in the cluster is asking if the current node is connected to a session
  108. required = sid in self.local_connection_id_to_session_id.values()
  109. if required:
  110. await self._get_redis_client().publish(
  111. 'oh_event',
  112. json.dumps(
  113. {'sid': sid, 'message_type': 'has_remote_connections_response'}
  114. ),
  115. )
  116. elif message_type == 'has_remote_connections_response':
  117. flag = self._has_remote_connections_flags.get(sid)
  118. if flag:
  119. flag.set()
  120. elif message_type == 'session_closing':
  121. # Session closing event - We only get this in the event of graceful shutdown,
  122. # which can't be guaranteed - nodes can simply vanish unexpectedly!
  123. logger.debug(f'session_closing:{sid}')
  124. for (
  125. connection_id,
  126. local_sid,
  127. ) in self.local_connection_id_to_session_id.items():
  128. if sid == local_sid:
  129. logger.warning(
  130. 'local_connection_to_closing_session:{connection_id}:{sid}'
  131. )
  132. await self.sio.disconnect(connection_id)
  133. async def attach_to_conversation(self, sid: str) -> Conversation | None:
  134. start_time = time.time()
  135. if not await session_exists(sid, self.file_store):
  136. return None
  137. async with self._conversations_lock:
  138. # Check if we have an active conversation we can reuse
  139. if sid in self._active_conversations:
  140. conversation, count = self._active_conversations[sid]
  141. self._active_conversations[sid] = (conversation, count + 1)
  142. logger.info(f'Reusing active conversation {sid}')
  143. return conversation
  144. # Check if we have a detached conversation we can reuse
  145. if sid in self._detached_conversations:
  146. conversation, _ = self._detached_conversations.pop(sid)
  147. self._active_conversations[sid] = (conversation, 1)
  148. logger.info(f'Reusing detached conversation {sid}')
  149. return conversation
  150. # Create new conversation if none exists
  151. c = Conversation(sid, file_store=self.file_store, config=self.config)
  152. try:
  153. await c.connect()
  154. except AgentRuntimeUnavailableError as e:
  155. logger.error(f'Error connecting to conversation {c.sid}: {e}')
  156. return None
  157. end_time = time.time()
  158. logger.info(
  159. f'Conversation {c.sid} connected in {end_time - start_time} seconds'
  160. )
  161. self._active_conversations[sid] = (c, 1)
  162. return c
  163. async def join_conversation(self, sid: str, connection_id: str) -> EventStream:
  164. logger.info(f'join_conversation:{sid}:{connection_id}')
  165. await self.sio.enter_room(connection_id, ROOM_KEY.format(sid=sid))
  166. self.local_connection_id_to_session_id[connection_id] = sid
  167. event_stream = await self._get_event_stream(sid)
  168. if not event_stream:
  169. return await self.maybe_start_agent_loop(sid)
  170. return event_stream
  171. async def detach_from_conversation(self, conversation: Conversation):
  172. sid = conversation.sid
  173. async with self._conversations_lock:
  174. if sid in self._active_conversations:
  175. conv, count = self._active_conversations[sid]
  176. if count > 1:
  177. self._active_conversations[sid] = (conv, count - 1)
  178. return
  179. else:
  180. self._active_conversations.pop(sid)
  181. self._detached_conversations[sid] = (conversation, time.time())
  182. async def _cleanup_detached_conversations(self):
  183. while should_continue():
  184. if self._get_redis_client():
  185. # Debug info for HA envs
  186. logger.info(
  187. f'Attached conversations: {len(self._active_conversations)}'
  188. )
  189. logger.info(
  190. f'Detached conversations: {len(self._detached_conversations)}'
  191. )
  192. logger.info(
  193. f'Running agent loops: {len(self._local_agent_loops_by_sid)}'
  194. )
  195. logger.info(
  196. f'Local connections: {len(self.local_connection_id_to_session_id)}'
  197. )
  198. try:
  199. async with self._conversations_lock:
  200. # Create a list of items to process to avoid modifying dict during iteration
  201. items = list(self._detached_conversations.items())
  202. for sid, (conversation, detach_time) in items:
  203. await conversation.disconnect()
  204. self._detached_conversations.pop(sid, None)
  205. await asyncio.sleep(_CLEANUP_INTERVAL)
  206. except asyncio.CancelledError:
  207. async with self._conversations_lock:
  208. for conversation, _ in self._detached_conversations.values():
  209. await conversation.disconnect()
  210. self._detached_conversations.clear()
  211. return
  212. except Exception:
  213. logger.warning('error_cleaning_detached_conversations', exc_info=True)
  214. await asyncio.sleep(_CLEANUP_EXCEPTION_WAIT_TIME)
  215. async def _is_agent_loop_running(self, sid: str) -> bool:
  216. if await self._is_agent_loop_running_locally(sid):
  217. return True
  218. if await self._is_agent_loop_running_in_cluster(sid):
  219. return True
  220. return False
  221. async def _is_agent_loop_running_locally(self, sid: str) -> bool:
  222. if self._local_agent_loops_by_sid.get(sid, None):
  223. return True
  224. return False
  225. async def _is_agent_loop_running_in_cluster(self, sid: str) -> bool:
  226. """As the rest of the cluster if a session is running. Wait a for a short timeout for a reply"""
  227. redis_client = self._get_redis_client()
  228. if not redis_client:
  229. return False
  230. flag = asyncio.Event()
  231. self._session_is_running_flags[sid] = flag
  232. try:
  233. logger.debug(f'publish:is_session_running:{sid}')
  234. await redis_client.publish(
  235. 'oh_event',
  236. json.dumps(
  237. {
  238. 'sid': sid,
  239. 'message_type': 'is_session_running',
  240. }
  241. ),
  242. )
  243. async with asyncio.timeout(_REDIS_POLL_TIMEOUT):
  244. await flag.wait()
  245. result = flag.is_set()
  246. return result
  247. except TimeoutError:
  248. # Nobody replied in time
  249. return False
  250. finally:
  251. self._session_is_running_flags.pop(sid, None)
  252. async def _has_remote_connections(self, sid: str) -> bool:
  253. """As the rest of the cluster if they still want this session running. Wait a for a short timeout for a reply"""
  254. # Create a flag for the callback
  255. flag = asyncio.Event()
  256. self._has_remote_connections_flags[sid] = flag
  257. try:
  258. await self._get_redis_client().publish(
  259. 'oh_event',
  260. json.dumps(
  261. {
  262. 'sid': sid,
  263. 'message_type': 'has_remote_connections_query',
  264. }
  265. ),
  266. )
  267. async with asyncio.timeout(_REDIS_POLL_TIMEOUT):
  268. await flag.wait()
  269. result = flag.is_set()
  270. return result
  271. except TimeoutError:
  272. # Nobody replied in time
  273. return False
  274. finally:
  275. self._has_remote_connections_flags.pop(sid, None)
  276. async def maybe_start_agent_loop(
  277. self, sid: str, conversation_init_data: ConversationInitData | None = None
  278. ) -> EventStream:
  279. logger.info(f'maybe_start_agent_loop:{sid}')
  280. session: Session | None = None
  281. if not await self._is_agent_loop_running(sid):
  282. logger.info(f'start_agent_loop:{sid}')
  283. session = Session(
  284. sid=sid, file_store=self.file_store, config=self.config, sio=self.sio
  285. )
  286. self._local_agent_loops_by_sid[sid] = session
  287. await session.initialize_agent(conversation_init_data)
  288. event_stream = await self._get_event_stream(sid)
  289. if not event_stream:
  290. logger.error(f'No event stream after starting agent loop: {sid}')
  291. raise RuntimeError(f'no_event_stream:{sid}')
  292. return event_stream
  293. async def _get_event_stream(self, sid: str) -> EventStream | None:
  294. logger.info(f'_get_event_stream:{sid}')
  295. session = self._local_agent_loops_by_sid.get(sid)
  296. if session:
  297. logger.info(f'found_local_agent_loop:{sid}')
  298. return session.agent_session.event_stream
  299. if await self._is_agent_loop_running_in_cluster(sid):
  300. logger.info(f'found_remote_agent_loop:{sid}')
  301. return EventStream(sid, self.file_store)
  302. return None
  303. async def send_to_event_stream(self, connection_id: str, data: dict):
  304. # If there is a local session running, send to that
  305. sid = self.local_connection_id_to_session_id.get(connection_id)
  306. if not sid:
  307. raise RuntimeError(f'no_connected_session:{connection_id}')
  308. session = self._local_agent_loops_by_sid.get(sid)
  309. if session:
  310. await session.dispatch(data)
  311. return
  312. redis_client = self._get_redis_client()
  313. if redis_client:
  314. # If we have a recent report that the session is alive in another pod
  315. last_alive_at = self._last_alive_timestamps.get(sid) or 0
  316. next_alive_check = last_alive_at + _CHECK_ALIVE_INTERVAL
  317. if (
  318. next_alive_check > time.time()
  319. or await self._is_agent_loop_running_in_cluster(sid)
  320. ):
  321. # Send the event to the other pod
  322. await redis_client.publish(
  323. 'oh_event',
  324. json.dumps(
  325. {
  326. 'sid': sid,
  327. 'message_type': 'event',
  328. 'data': data,
  329. }
  330. ),
  331. )
  332. return
  333. raise RuntimeError(f'no_connected_session:{connection_id}:{sid}')
  334. async def disconnect_from_session(self, connection_id: str):
  335. sid = self.local_connection_id_to_session_id.pop(connection_id, None)
  336. logger.info(f'disconnect_from_session:{connection_id}:{sid}')
  337. if not sid:
  338. # This can occur if the init action was never run.
  339. logger.warning(f'disconnect_from_uninitialized_session:{connection_id}')
  340. return
  341. if should_continue():
  342. asyncio.create_task(self._cleanup_session_later(sid))
  343. else:
  344. await self._close_session(sid)
  345. async def _cleanup_session_later(self, sid: str):
  346. # Once there have been no connections to a session for a reasonable period, we close it
  347. try:
  348. await asyncio.sleep(self.config.sandbox.close_delay)
  349. finally:
  350. # If the sleep was cancelled, we still want to close these
  351. await self._cleanup_session(sid)
  352. async def _cleanup_session(self, sid: str) -> bool:
  353. # Get local connections
  354. logger.info(f'_cleanup_session:{sid}')
  355. has_local_connections = next(
  356. (True for v in self.local_connection_id_to_session_id.values() if v == sid),
  357. False,
  358. )
  359. if has_local_connections:
  360. return False
  361. # If no local connections, get connections through redis
  362. redis_client = self._get_redis_client()
  363. if redis_client and await self._has_remote_connections(sid):
  364. return False
  365. # We alert the cluster in case they are interested
  366. if redis_client:
  367. await redis_client.publish(
  368. 'oh_event',
  369. json.dumps({'sid': sid, 'message_type': 'session_closing'}),
  370. )
  371. await self._close_session(sid)
  372. return True
  373. async def _close_session(self, sid: str):
  374. logger.info(f'_close_session:{sid}')
  375. # Clear up local variables
  376. connection_ids_to_remove = list(
  377. connection_id
  378. for connection_id, conn_sid in self.local_connection_id_to_session_id.items()
  379. if sid == conn_sid
  380. )
  381. logger.info(f'removing connections: {connection_ids_to_remove}')
  382. for connnnection_id in connection_ids_to_remove:
  383. self.local_connection_id_to_session_id.pop(connnnection_id, None)
  384. session = self._local_agent_loops_by_sid.pop(sid, None)
  385. if not session:
  386. logger.warning(f'no_session_to_close:{sid}')
  387. return
  388. logger.info(f'closing_session:{session.sid}')
  389. # We alert the cluster in case they are interested
  390. redis_client = self._get_redis_client()
  391. if redis_client:
  392. await redis_client.publish(
  393. 'oh_event',
  394. json.dumps({'sid': session.sid, 'message_type': 'session_closing'}),
  395. )
  396. await call_sync_from_async(session.close)
  397. logger.info(f'closed_session:{session.sid}')