manager.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. import asyncio
  2. import atexit
  3. import json
  4. import os
  5. import time
  6. from typing import Callable
  7. from fastapi import WebSocket
  8. from opendevin.core.logger import opendevin_logger as logger
  9. from .msg_stack import message_stack
  10. from .session import Session
  11. CACHE_DIR = os.getenv('CACHE_DIR', 'cache')
  12. SESSION_CACHE_FILE = os.path.join(CACHE_DIR, 'sessions.json')
  13. class SessionManager:
  14. _sessions: dict[str, Session] = {}
  15. cleanup_interval: int = 300
  16. session_timeout: int = 600
  17. def __init__(self):
  18. self._load_sessions()
  19. atexit.register(self.close)
  20. asyncio.create_task(self._cleanup_sessions())
  21. def add_session(self, sid: str, ws_conn: WebSocket):
  22. if sid not in self._sessions:
  23. self._sessions[sid] = Session(sid=sid, ws=ws_conn)
  24. return
  25. self._sessions[sid].update_connection(ws_conn)
  26. async def loop_recv(self, sid: str, dispatch: Callable):
  27. print(f'Starting loop_recv for sid: {sid}')
  28. """Starts listening for messages from the client."""
  29. if sid not in self._sessions:
  30. return
  31. await self._sessions[sid].loop_recv(dispatch)
  32. def close(self):
  33. logger.info('Saving sessions...')
  34. self._save_sessions()
  35. async def send(self, sid: str, data: dict[str, object]) -> bool:
  36. """Sends data to the client."""
  37. message_stack.add_message(sid, 'assistant', data)
  38. if sid not in self._sessions:
  39. return False
  40. return await self._sessions[sid].send(data)
  41. async def send_error(self, sid: str, message: str) -> bool:
  42. """Sends an error message to the client."""
  43. return await self.send(sid, {'error': True, 'message': message})
  44. async def send_message(self, sid: str, message: str) -> bool:
  45. """Sends a message to the client."""
  46. return await self.send(sid, {'message': message})
  47. def _save_sessions(self):
  48. data = {}
  49. for sid, conn in self._sessions.items():
  50. data[sid] = {
  51. 'sid': conn.sid,
  52. 'last_active_ts': conn.last_active_ts,
  53. 'is_alive': conn.is_alive,
  54. }
  55. if not os.path.exists(CACHE_DIR):
  56. os.makedirs(CACHE_DIR)
  57. with open(SESSION_CACHE_FILE, 'w+') as file:
  58. json.dump(data, file)
  59. def _load_sessions(self):
  60. try:
  61. with open(SESSION_CACHE_FILE, 'r') as file:
  62. data = json.load(file)
  63. for sid, sdata in data.items():
  64. conn = Session(sid, None)
  65. ok = conn.load_from_data(sdata)
  66. if ok:
  67. self._sessions[sid] = conn
  68. except FileNotFoundError:
  69. pass
  70. except json.decoder.JSONDecodeError:
  71. pass
  72. async def _cleanup_sessions(self):
  73. while True:
  74. current_time = time.time()
  75. session_ids_to_remove = []
  76. for sid, session in list(self._sessions.items()):
  77. # if session inactive for a long time, remove it
  78. if (
  79. not session.is_alive
  80. and current_time - session.last_active_ts > self.session_timeout
  81. ):
  82. session_ids_to_remove.append(sid)
  83. for sid in session_ids_to_remove:
  84. del self._sessions[sid]
  85. logger.info(f'Session {sid} has been removed due to inactivity.')
  86. await asyncio.sleep(self.cleanup_interval)