# manager.py
import asyncio
import json
import time
from dataclasses import dataclass, field
from uuid import uuid4

import socketio

from openhands.core.config import AppConfig
from openhands.core.exceptions import AgentRuntimeUnavailableError
from openhands.core.logger import openhands_logger as logger
from openhands.events.stream import EventStream, session_exists
from openhands.server.session.conversation import Conversation
from openhands.server.session.session import ROOM_KEY, Session
from openhands.server.settings import Settings
from openhands.storage.files import FileStore
from openhands.utils.async_utils import call_sync_from_async
from openhands.utils.shutdown_listener import should_continue
  17. _REDIS_POLL_TIMEOUT = 1.5
  18. _CHECK_ALIVE_INTERVAL = 15
  19. _CLEANUP_INTERVAL = 15
  20. _CLEANUP_EXCEPTION_WAIT_TIME = 15
  21. class ConversationDoesNotExistError(Exception):
  22. pass
  23. @dataclass
  24. class _SessionIsRunningCheck:
  25. request_id: str
  26. request_sids: list[str]
  27. running_sids: set[str] = field(default_factory=set)
  28. flag: asyncio.Event = field(default_factory=asyncio.Event)
  29. @dataclass
  30. class SessionManager:
  31. sio: socketio.AsyncServer
  32. config: AppConfig
  33. file_store: FileStore
  34. _local_agent_loops_by_sid: dict[str, Session] = field(default_factory=dict)
  35. local_connection_id_to_session_id: dict[str, str] = field(default_factory=dict)
  36. _last_alive_timestamps: dict[str, float] = field(default_factory=dict)
  37. _redis_listen_task: asyncio.Task | None = None
  38. _session_is_running_checks: dict[str, _SessionIsRunningCheck] = field(
  39. default_factory=dict
  40. )
  41. _active_conversations: dict[str, tuple[Conversation, int]] = field(
  42. default_factory=dict
  43. )
  44. _detached_conversations: dict[str, tuple[Conversation, float]] = field(
  45. default_factory=dict
  46. )
  47. _conversations_lock: asyncio.Lock = field(default_factory=asyncio.Lock)
  48. _cleanup_task: asyncio.Task | None = None
  49. _has_remote_connections_flags: dict[str, asyncio.Event] = field(
  50. default_factory=dict
  51. )
  52. async def __aenter__(self):
  53. redis_client = self._get_redis_client()
  54. if redis_client:
  55. self._redis_listen_task = asyncio.create_task(self._redis_subscribe())
  56. self._cleanup_task = asyncio.create_task(self._cleanup_detached_conversations())
  57. return self
  58. async def __aexit__(self, exc_type, exc_value, traceback):
  59. if self._redis_listen_task:
  60. self._redis_listen_task.cancel()
  61. self._redis_listen_task = None
  62. if self._cleanup_task:
  63. self._cleanup_task.cancel()
  64. self._cleanup_task = None
  65. def _get_redis_client(self):
  66. redis_client = getattr(self.sio.manager, 'redis', None)
  67. return redis_client
  68. async def _redis_subscribe(self):
  69. """
  70. We use a redis backchannel to send actions between server nodes
  71. """
  72. logger.debug('_redis_subscribe')
  73. redis_client = self._get_redis_client()
  74. pubsub = redis_client.pubsub()
  75. await pubsub.subscribe('oh_event')
  76. while should_continue():
  77. try:
  78. message = await pubsub.get_message(
  79. ignore_subscribe_messages=True, timeout=5
  80. )
  81. if message:
  82. await self._process_message(message)
  83. except asyncio.CancelledError:
  84. return
  85. except Exception:
  86. try:
  87. asyncio.get_running_loop()
  88. logger.warning(
  89. 'error_reading_from_redis', exc_info=True, stack_info=True
  90. )
  91. except RuntimeError:
  92. return # Loop has been shut down
  93. async def _process_message(self, message: dict):
  94. data = json.loads(message['data'])
  95. logger.debug(f'got_published_message:{message}')
  96. message_type = data['message_type']
  97. if message_type == 'event':
  98. sid = data['sid']
  99. session = self._local_agent_loops_by_sid.get(sid)
  100. if session:
  101. await session.dispatch(data['data'])
  102. elif message_type == 'is_session_running':
  103. # Another node in the cluster is asking if the current node is running the session given.
  104. request_id = data['request_id']
  105. sids = [
  106. sid for sid in data['sids'] if sid in self._local_agent_loops_by_sid
  107. ]
  108. if sids:
  109. await self._get_redis_client().publish(
  110. 'oh_event',
  111. json.dumps(
  112. {
  113. 'request_id': request_id,
  114. 'sids': sids,
  115. 'message_type': 'session_is_running',
  116. }
  117. ),
  118. )
  119. elif message_type == 'session_is_running':
  120. request_id = data['request_id']
  121. for sid in data['sids']:
  122. self._last_alive_timestamps[sid] = time.time()
  123. check = self._session_is_running_checks.get(request_id)
  124. if check:
  125. check.running_sids.update(data['sids'])
  126. if len(check.request_sids) == len(check.running_sids):
  127. check.flag.set()
  128. elif message_type == 'has_remote_connections_query':
  129. # Another node in the cluster is asking if the current node is connected to a session
  130. sid = data['sid']
  131. required = sid in self.local_connection_id_to_session_id.values()
  132. if required:
  133. await self._get_redis_client().publish(
  134. 'oh_event',
  135. json.dumps(
  136. {'sid': sid, 'message_type': 'has_remote_connections_response'}
  137. ),
  138. )
  139. elif message_type == 'has_remote_connections_response':
  140. sid = data['sid']
  141. flag = self._has_remote_connections_flags.get(sid)
  142. if flag:
  143. flag.set()
  144. elif message_type == 'session_closing':
  145. # Session closing event - We only get this in the event of graceful shutdown,
  146. # which can't be guaranteed - nodes can simply vanish unexpectedly!
  147. sid = data['sid']
  148. logger.debug(f'session_closing:{sid}')
  149. for (
  150. connection_id,
  151. local_sid,
  152. ) in self.local_connection_id_to_session_id.items():
  153. if sid == local_sid:
  154. logger.warning(
  155. 'local_connection_to_closing_session:{connection_id}:{sid}'
  156. )
  157. await self.sio.disconnect(connection_id)
  158. async def attach_to_conversation(self, sid: str) -> Conversation | None:
  159. start_time = time.time()
  160. if not await session_exists(sid, self.file_store):
  161. return None
  162. async with self._conversations_lock:
  163. # Check if we have an active conversation we can reuse
  164. if sid in self._active_conversations:
  165. conversation, count = self._active_conversations[sid]
  166. self._active_conversations[sid] = (conversation, count + 1)
  167. logger.info(f'Reusing active conversation {sid}')
  168. return conversation
  169. # Check if we have a detached conversation we can reuse
  170. if sid in self._detached_conversations:
  171. conversation, _ = self._detached_conversations.pop(sid)
  172. self._active_conversations[sid] = (conversation, 1)
  173. logger.info(f'Reusing detached conversation {sid}')
  174. return conversation
  175. # Create new conversation if none exists
  176. c = Conversation(sid, file_store=self.file_store, config=self.config)
  177. try:
  178. await c.connect()
  179. except AgentRuntimeUnavailableError as e:
  180. logger.error(f'Error connecting to conversation {c.sid}: {e}')
  181. return None
  182. end_time = time.time()
  183. logger.info(
  184. f'Conversation {c.sid} connected in {end_time - start_time} seconds'
  185. )
  186. self._active_conversations[sid] = (c, 1)
  187. return c
  188. async def join_conversation(self, sid: str, connection_id: str, settings: Settings):
  189. logger.info(f'join_conversation:{sid}:{connection_id}')
  190. await self.sio.enter_room(connection_id, ROOM_KEY.format(sid=sid))
  191. self.local_connection_id_to_session_id[connection_id] = sid
  192. event_stream = await self._get_event_stream(sid)
  193. if not event_stream:
  194. return await self.maybe_start_agent_loop(sid, settings)
  195. return event_stream
  196. async def detach_from_conversation(self, conversation: Conversation):
  197. sid = conversation.sid
  198. async with self._conversations_lock:
  199. if sid in self._active_conversations:
  200. conv, count = self._active_conversations[sid]
  201. if count > 1:
  202. self._active_conversations[sid] = (conv, count - 1)
  203. return
  204. else:
  205. self._active_conversations.pop(sid)
  206. self._detached_conversations[sid] = (conversation, time.time())
  207. async def _cleanup_detached_conversations(self):
  208. while should_continue():
  209. if self._get_redis_client():
  210. # Debug info for HA envs
  211. logger.info(
  212. f'Attached conversations: {len(self._active_conversations)}'
  213. )
  214. logger.info(
  215. f'Detached conversations: {len(self._detached_conversations)}'
  216. )
  217. logger.info(
  218. f'Running agent loops: {len(self._local_agent_loops_by_sid)}'
  219. )
  220. logger.info(
  221. f'Local connections: {len(self.local_connection_id_to_session_id)}'
  222. )
  223. try:
  224. async with self._conversations_lock:
  225. # Create a list of items to process to avoid modifying dict during iteration
  226. items = list(self._detached_conversations.items())
  227. for sid, (conversation, detach_time) in items:
  228. await conversation.disconnect()
  229. self._detached_conversations.pop(sid, None)
  230. await asyncio.sleep(_CLEANUP_INTERVAL)
  231. except asyncio.CancelledError:
  232. async with self._conversations_lock:
  233. for conversation, _ in self._detached_conversations.values():
  234. await conversation.disconnect()
  235. self._detached_conversations.clear()
  236. return
  237. except Exception:
  238. logger.warning('error_cleaning_detached_conversations', exc_info=True)
  239. await asyncio.sleep(_CLEANUP_EXCEPTION_WAIT_TIME)
  240. async def get_agent_loop_running(self, sids: set[str]) -> set[str]:
  241. running_sids = set(sid for sid in sids if sid in self._local_agent_loops_by_sid)
  242. check_cluster_sids = [sid for sid in sids if sid not in running_sids]
  243. running_cluster_sids = await self.get_agent_loop_running_in_cluster(
  244. check_cluster_sids
  245. )
  246. running_sids.union(running_cluster_sids)
  247. return running_sids
  248. async def is_agent_loop_running(self, sid: str) -> bool:
  249. if await self.is_agent_loop_running_locally(sid):
  250. return True
  251. if await self.is_agent_loop_running_in_cluster(sid):
  252. return True
  253. return False
  254. async def is_agent_loop_running_locally(self, sid: str) -> bool:
  255. return sid in self._local_agent_loops_by_sid
  256. async def is_agent_loop_running_in_cluster(self, sid: str) -> bool:
  257. running_sids = await self.get_agent_loop_running_in_cluster([sid])
  258. return bool(running_sids)
  259. async def get_agent_loop_running_in_cluster(self, sids: list[str]) -> set[str]:
  260. """As the rest of the cluster if a session is running. Wait a for a short timeout for a reply"""
  261. redis_client = self._get_redis_client()
  262. if not redis_client:
  263. return set()
  264. flag = asyncio.Event()
  265. request_id = str(uuid4())
  266. check = _SessionIsRunningCheck(request_id=request_id, request_sids=sids)
  267. self._session_is_running_checks[request_id] = check
  268. try:
  269. logger.debug(f'publish:is_session_running:{sids}')
  270. await redis_client.publish(
  271. 'oh_event',
  272. json.dumps(
  273. {
  274. 'request_id': request_id,
  275. 'sids': sids,
  276. 'message_type': 'is_session_running',
  277. }
  278. ),
  279. )
  280. async with asyncio.timeout(_REDIS_POLL_TIMEOUT):
  281. await flag.wait()
  282. return check.running_sids
  283. except TimeoutError:
  284. # Nobody replied in time
  285. return check.running_sids
  286. finally:
  287. self._session_is_running_checks.pop(request_id, None)
  288. async def _has_remote_connections(self, sid: str) -> bool:
  289. """As the rest of the cluster if they still want this session running. Wait a for a short timeout for a reply"""
  290. # Create a flag for the callback
  291. flag = asyncio.Event()
  292. self._has_remote_connections_flags[sid] = flag
  293. try:
  294. await self._get_redis_client().publish(
  295. 'oh_event',
  296. json.dumps(
  297. {
  298. 'sid': sid,
  299. 'message_type': 'has_remote_connections_query',
  300. }
  301. ),
  302. )
  303. async with asyncio.timeout(_REDIS_POLL_TIMEOUT):
  304. await flag.wait()
  305. result = flag.is_set()
  306. return result
  307. except TimeoutError:
  308. # Nobody replied in time
  309. return False
  310. finally:
  311. self._has_remote_connections_flags.pop(sid, None)
  312. async def maybe_start_agent_loop(self, sid: str, settings: Settings) -> EventStream:
  313. logger.info(f'maybe_start_agent_loop:{sid}')
  314. session: Session | None = None
  315. if not await self.is_agent_loop_running(sid):
  316. logger.info(f'start_agent_loop:{sid}')
  317. session = Session(
  318. sid=sid, file_store=self.file_store, config=self.config, sio=self.sio
  319. )
  320. self._local_agent_loops_by_sid[sid] = session
  321. await session.initialize_agent(settings)
  322. event_stream = await self._get_event_stream(sid)
  323. if not event_stream:
  324. logger.error(f'No event stream after starting agent loop: {sid}')
  325. raise RuntimeError(f'no_event_stream:{sid}')
  326. return event_stream
  327. async def _get_event_stream(self, sid: str) -> EventStream | None:
  328. logger.info(f'_get_event_stream:{sid}')
  329. session = self._local_agent_loops_by_sid.get(sid)
  330. if session:
  331. logger.info(f'found_local_agent_loop:{sid}')
  332. return session.agent_session.event_stream
  333. if await self.is_agent_loop_running_in_cluster(sid):
  334. logger.info(f'found_remote_agent_loop:{sid}')
  335. return EventStream(sid, self.file_store)
  336. return None
  337. async def send_to_event_stream(self, connection_id: str, data: dict):
  338. # If there is a local session running, send to that
  339. sid = self.local_connection_id_to_session_id.get(connection_id)
  340. if not sid:
  341. raise RuntimeError(f'no_connected_session:{connection_id}')
  342. session = self._local_agent_loops_by_sid.get(sid)
  343. if session:
  344. await session.dispatch(data)
  345. return
  346. redis_client = self._get_redis_client()
  347. if redis_client:
  348. # If we have a recent report that the session is alive in another pod
  349. last_alive_at = self._last_alive_timestamps.get(sid) or 0
  350. next_alive_check = last_alive_at + _CHECK_ALIVE_INTERVAL
  351. if (
  352. next_alive_check > time.time()
  353. or await self.is_agent_loop_running_in_cluster(sid)
  354. ):
  355. # Send the event to the other pod
  356. await redis_client.publish(
  357. 'oh_event',
  358. json.dumps(
  359. {
  360. 'sid': sid,
  361. 'message_type': 'event',
  362. 'data': data,
  363. }
  364. ),
  365. )
  366. return
  367. raise RuntimeError(f'no_connected_session:{connection_id}:{sid}')
  368. async def disconnect_from_session(self, connection_id: str):
  369. sid = self.local_connection_id_to_session_id.pop(connection_id, None)
  370. logger.info(f'disconnect_from_session:{connection_id}:{sid}')
  371. if not sid:
  372. # This can occur if the init action was never run.
  373. logger.warning(f'disconnect_from_uninitialized_session:{connection_id}')
  374. return
  375. if should_continue():
  376. asyncio.create_task(self._cleanup_session_later(sid))
  377. else:
  378. await self._close_session(sid)
  379. async def _cleanup_session_later(self, sid: str):
  380. # Once there have been no connections to a session for a reasonable period, we close it
  381. try:
  382. await asyncio.sleep(self.config.sandbox.close_delay)
  383. finally:
  384. # If the sleep was cancelled, we still want to close these
  385. await self._cleanup_session(sid)
  386. async def _cleanup_session(self, sid: str) -> bool:
  387. # Get local connections
  388. logger.info(f'_cleanup_session:{sid}')
  389. has_local_connections = next(
  390. (True for v in self.local_connection_id_to_session_id.values() if v == sid),
  391. False,
  392. )
  393. if has_local_connections:
  394. return False
  395. # If no local connections, get connections through redis
  396. redis_client = self._get_redis_client()
  397. if redis_client and await self._has_remote_connections(sid):
  398. return False
  399. # We alert the cluster in case they are interested
  400. if redis_client:
  401. await redis_client.publish(
  402. 'oh_event',
  403. json.dumps({'sid': sid, 'message_type': 'session_closing'}),
  404. )
  405. await self._close_session(sid)
  406. return True
  407. async def _close_session(self, sid: str):
  408. logger.info(f'_close_session:{sid}')
  409. # Clear up local variables
  410. connection_ids_to_remove = list(
  411. connection_id
  412. for connection_id, conn_sid in self.local_connection_id_to_session_id.items()
  413. if sid == conn_sid
  414. )
  415. logger.info(f'removing connections: {connection_ids_to_remove}')
  416. for connnnection_id in connection_ids_to_remove:
  417. self.local_connection_id_to_session_id.pop(connnnection_id, None)
  418. session = self._local_agent_loops_by_sid.pop(sid, None)
  419. if not session:
  420. logger.warning(f'no_session_to_close:{sid}')
  421. return
  422. logger.info(f'closing_session:{session.sid}')
  423. # We alert the cluster in case they are interested
  424. redis_client = self._get_redis_client()
  425. if redis_client:
  426. await redis_client.publish(
  427. 'oh_event',
  428. json.dumps({'sid': session.sid, 'message_type': 'session_closing'}),
  429. )
  430. await call_sync_from_async(session.close)
  431. logger.info(f'closed_session:{session.sid}')