manager.py 2.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. import atexit
  2. import json
  3. import os
  4. from typing import Dict, Callable
  5. from fastapi import WebSocket
  6. from opendevin.logger import opendevin_logger as logger
  7. from .msg_stack import message_stack
  8. from .session import Session
  9. CACHE_DIR = os.getenv('CACHE_DIR', 'cache')
  10. SESSION_CACHE_FILE = os.path.join(CACHE_DIR, 'sessions.json')
  11. class SessionManager:
  12. _sessions: Dict[str, Session] = {}
  13. def __init__(self):
  14. self._load_sessions()
  15. atexit.register(self.close)
  16. def add_session(self, sid: str, ws_conn: WebSocket):
  17. if sid not in self._sessions:
  18. self._sessions[sid] = Session(sid=sid, ws=ws_conn)
  19. return
  20. self._sessions[sid].update_connection(ws_conn)
  21. async def loop_recv(self, sid: str, dispatch: Callable):
  22. print(f'Starting loop_recv for sid: {sid}')
  23. """Starts listening for messages from the client."""
  24. if sid not in self._sessions:
  25. return
  26. await self._sessions[sid].loop_recv(dispatch)
  27. def close(self):
  28. logger.info('Saving sessions...')
  29. self._save_sessions()
  30. async def send(self, sid: str, data: Dict[str, object]) -> bool:
  31. """Sends data to the client."""
  32. message_stack.add_message(sid, 'assistant', data)
  33. if sid not in self._sessions:
  34. return False
  35. return await self._sessions[sid].send(data)
  36. async def send_error(self, sid: str, message: str) -> bool:
  37. """Sends an error message to the client."""
  38. return await self.send(sid, {'error': True, 'message': message})
  39. async def send_message(self, sid: str, message: str) -> bool:
  40. """Sends a message to the client."""
  41. return await self.send(sid, {'message': message})
  42. def _save_sessions(self):
  43. data = {}
  44. for sid, conn in self._sessions.items():
  45. data[sid] = {
  46. 'sid': conn.sid,
  47. 'last_active_ts': conn.last_active_ts,
  48. 'is_alive': conn.is_alive,
  49. }
  50. if not os.path.exists(CACHE_DIR):
  51. os.makedirs(CACHE_DIR)
  52. with open(SESSION_CACHE_FILE, 'w+') as file:
  53. json.dump(data, file)
  54. def _load_sessions(self):
  55. try:
  56. with open(SESSION_CACHE_FILE, 'r') as file:
  57. data = json.load(file)
  58. for sid, sdata in data.items():
  59. conn = Session(sid, None)
  60. ok = conn.load_from_data(sdata)
  61. if ok:
  62. self._sessions[sid] = conn
  63. except FileNotFoundError:
  64. pass
  65. except json.decoder.JSONDecodeError:
  66. pass