manager.py 3.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  1. import asyncio
  2. import time
  3. from dataclasses import dataclass, field
  4. from typing import Optional
  5. from fastapi import WebSocket
  6. from openhands.core.config import AppConfig
  7. from openhands.core.logger import openhands_logger as logger
  8. from openhands.events.stream import session_exists
  9. from openhands.runtime.utils.shutdown_listener import should_continue
  10. from openhands.server.session.conversation import Conversation
  11. from openhands.server.session.session import Session
  12. from openhands.storage.files import FileStore
  @dataclass
  class SessionManager:
      """Registry of ``Session`` objects keyed by session id.

      Intended to be used as an async context manager: ``__aenter__`` starts a
      background task that, every ``cleanup_interval`` seconds, closes and drops
      sessions that are not alive and have been idle for longer than
      ``session_timeout`` seconds; ``__aexit__`` cancels that task and waits for
      it to stop.
      """

      config: AppConfig
      file_store: FileStore
      # Seconds to sleep between two passes of the inactive-session sweep.
      cleanup_interval: int = 300
      # Seconds a non-alive session may stay idle before the sweep closes it.
      # Compared against ``time.time() - session.last_active_ts``, so
      # last_active_ts is presumably a ``time.time()`` timestamp -- TODO confirm.
      session_timeout: int = 600
      _sessions: dict[str, Session] = field(default_factory=dict)
      _session_cleanup_task: Optional[asyncio.Task] = None
      # Strong references to fire-and-forget ``close()`` tasks started by
      # add_or_restart_session. The event loop only keeps weak references to
      # tasks, so without this set a task could be garbage-collected before it
      # finishes.
      _close_tasks: set[asyncio.Task] = field(
          default_factory=set, init=False, repr=False
      )

      async def __aenter__(self):
          """Start the periodic cleanup task if not already running; return self."""
          if not self._session_cleanup_task:
              self._session_cleanup_task = asyncio.create_task(
                  self._cleanup_sessions()
              )
          return self

      async def __aexit__(self, exc_type, exc_value, traceback):
          """Cancel the cleanup task and wait until it has actually stopped."""
          task = self._session_cleanup_task
          if task:
              self._session_cleanup_task = None
              task.cancel()
              # Wait for the task to unwind (e.g. an in-flight session.close())
              # before returning. return_exceptions=True keeps the task's own
              # CancelledError from being raised here, while a cancellation of
              # this coroutine itself still propagates.
              await asyncio.gather(task, return_exceptions=True)

      def add_or_restart_session(self, sid: str, ws_conn: WebSocket) -> Session:
          """Create and register a new ``Session`` for ``sid`` bound to ``ws_conn``.

          If a session with the same id is already registered, its ``close()``
          coroutine is scheduled as a background task and the entry is replaced
          immediately; this method does not wait for the old session to finish
          closing. Must be called while an asyncio event loop is running.

          Returns:
              The newly created session.
          """
          previous = self._sessions.get(sid)
          if previous is not None:
              close_task = asyncio.create_task(previous.close())
              # Hold a reference until the task completes (see _close_tasks).
              self._close_tasks.add(close_task)
              close_task.add_done_callback(self._on_close_task_done)
          session = Session(
              sid=sid, file_store=self.file_store, ws=ws_conn, config=self.config
          )
          self._sessions[sid] = session
          return session

      def _on_close_task_done(self, task: asyncio.Task) -> None:
          """Release a finished close task and log any exception it raised."""
          self._close_tasks.discard(task)
          if task.cancelled():
              return
          exc = task.exception()
          if exc is not None:
              # Retrieving and logging here; otherwise asyncio would only report
              # "Task exception was never retrieved" when the task is collected.
              logger.error(f'Error while closing replaced session: {exc!r}')

      def get_session(self, sid: str) -> Session | None:
          """Return the registered session for ``sid``, or ``None`` if absent."""
          return self._sessions.get(sid)

      async def attach_to_conversation(self, sid: str) -> Conversation | None:
          """Create and connect a ``Conversation`` for a stored session.

          Returns ``None`` when ``session_exists(sid, file_store)`` is false;
          otherwise returns the connected conversation. The conversation is not
          tracked by this manager -- the caller owns it.
          """
          if not session_exists(sid, self.file_store):
              return None
          conversation = Conversation(sid, file_store=self.file_store, config=self.config)
          await conversation.connect()
          return conversation

      async def send(self, sid: str, data: dict[str, object]) -> bool:
          """Sends data to the client.

          Returns ``False`` (and logs an error) when no session is registered
          for ``sid``; otherwise returns the result of ``Session.send``.
          """
          session = self.get_session(sid)
          if session is None:
              logger.error(f'*** No session found for {sid}, skipping message ***')
              return False
          return await session.send(data)

      async def send_error(self, sid: str, message: str) -> bool:
          """Sends an error message to the client (``{'error': True, 'message': ...}``)."""
          return await self.send(sid, {'error': True, 'message': message})

      async def send_message(self, sid: str, message: str) -> bool:
          """Sends a message to the client (``{'message': ...}``)."""
          return await self.send(sid, {'message': message})

      async def _cleanup_sessions(self):
          """Periodically close and drop sessions that are dead and idle.

          Runs until ``should_continue()`` returns false or the task is
          cancelled. A session is removed when it is not alive and more than
          ``session_timeout`` seconds have passed since ``last_active_ts``.
          A failure while closing one session is logged and does not stop the
          sweep.
          """
          while should_continue():
              current_time = time.time()
              # The comprehension runs without yielding to the event loop, so
              # the dict cannot change while it is being scanned.
              expired = [
                  (sid, session)
                  for sid, session in self._sessions.items()
                  if not session.is_alive
                  and current_time - session.last_active_ts > self.session_timeout
              ]
              for sid, session in expired:
                  # An earlier ``await session.close()`` in this loop may have let
                  # add_or_restart_session replace this id with a fresh session
                  # (which also schedules the old one's close). Only evict the
                  # exact object that was found to be expired.
                  if self._sessions.get(sid) is not session:
                      continue
                  del self._sessions[sid]
                  try:
                      await session.close()
                  except Exception:
                      # Keep the sweep alive; one bad session must not stop
                      # cleanup of all the others.
                      logger.exception(f'Error while closing inactive session {sid}')
                  else:
                      logger.info(
                          f'Session {sid} and related resource have been removed due to inactivity.'
                      )
              await asyncio.sleep(self.cleanup_interval)