manager.py 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
  1. import os
  2. import json
  3. import atexit
  4. import signal
  5. from typing import Dict, Callable
  6. from fastapi import WebSocket
  7. from .session import Session
  8. from .msg_stack import message_stack
  9. CACHE_DIR = os.getenv("CACHE_DIR", "cache")
  10. SESSION_CACHE_FILE = os.path.join(CACHE_DIR, "sessions.json")
  11. class SessionManager:
  12. _sessions: Dict[str, Session] = {}
  13. def __init__(self):
  14. self._load_sessions()
  15. atexit.register(self.close)
  16. signal.signal(signal.SIGINT, self.handle_signal)
  17. signal.signal(signal.SIGTERM, self.handle_signal)
  18. def add_session(self, sid: str, ws_conn: WebSocket):
  19. if sid not in self._sessions:
  20. self._sessions[sid] = Session(sid=sid, ws=ws_conn)
  21. return
  22. self._sessions[sid].update_connection(ws_conn)
  23. async def loop_recv(self, sid: str, dispatch: Callable):
  24. print(f"Starting loop_recv for sid: {sid}, {sid not in self._sessions}")
  25. """Starts listening for messages from the client."""
  26. if sid not in self._sessions:
  27. return
  28. await self._sessions[sid].loop_recv(dispatch)
  29. def close(self):
  30. self._save_sessions()
  31. def handle_signal(self, signum, _):
  32. print(f"Received signal {signum}, exiting...")
  33. self.close()
  34. exit(0)
  35. async def send(self, sid: str, data: Dict[str, object]) -> bool:
  36. """Sends data to the client."""
  37. message_stack.add_message(sid, "assistant", data)
  38. if sid not in self._sessions:
  39. return False
  40. return await self._sessions[sid].send(data)
  41. async def send_error(self, sid: str, message: str) -> bool:
  42. """Sends an error message to the client."""
  43. return await self.send(sid, {"error": True, "message": message})
  44. async def send_message(self, sid: str, message: str) -> bool:
  45. """Sends a message to the client."""
  46. return await self.send(sid, {"message": message})
  47. def _save_sessions(self):
  48. data = {}
  49. for sid, conn in self._sessions.items():
  50. data[sid] = {
  51. "sid": conn.sid,
  52. "last_active_ts": conn.last_active_ts,
  53. "is_alive": conn.is_alive,
  54. }
  55. if not os.path.exists(CACHE_DIR):
  56. os.makedirs(CACHE_DIR)
  57. with open(SESSION_CACHE_FILE, "w+") as file:
  58. json.dump(data, file)
  59. def _load_sessions(self):
  60. try:
  61. with open(SESSION_CACHE_FILE, "r") as file:
  62. data = json.load(file)
  63. for sid, sdata in data.items():
  64. conn = Session(sid, None)
  65. ok = conn.load_from_data(sdata)
  66. if ok:
  67. self._sessions[sid] = conn
  68. except FileNotFoundError:
  69. pass
  70. except json.decoder.JSONDecodeError:
  71. pass
  72. session_manager = SessionManager()