session.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131
  1. import asyncio
  2. import time
  3. from fastapi import WebSocket, WebSocketDisconnect
  4. from opendevin.core.const.guide_url import TROUBLESHOOTING_URL
  5. from opendevin.core.logger import opendevin_logger as logger
  6. from opendevin.core.schema import AgentState
  7. from opendevin.core.schema.action import ActionType
  8. from opendevin.events.action import ChangeAgentStateAction, NullAction
  9. from opendevin.events.event import Event, EventSource
  10. from opendevin.events.observation import AgentStateChangedObservation, NullObservation
  11. from opendevin.events.serialization import event_from_dict, event_to_dict
  12. from opendevin.events.stream import EventStreamSubscriber
  13. from .agent import AgentSession
  14. DEL_DELT_SEC = 60 * 60 * 5
  15. class Session:
  16. sid: str
  17. websocket: WebSocket | None
  18. last_active_ts: int = 0
  19. is_alive: bool = True
  20. agent_session: AgentSession
  21. def __init__(self, sid: str, ws: WebSocket | None):
  22. self.sid = sid
  23. self.websocket = ws
  24. self.last_active_ts = int(time.time())
  25. self.agent_session = AgentSession(sid)
  26. self.agent_session.event_stream.subscribe(
  27. EventStreamSubscriber.SERVER, self.on_event
  28. )
  29. async def close(self):
  30. self.is_alive = False
  31. await self.agent_session.close()
  32. async def loop_recv(self):
  33. try:
  34. if self.websocket is None:
  35. return
  36. while True:
  37. try:
  38. data = await self.websocket.receive_json()
  39. except ValueError:
  40. await self.send_error('Invalid JSON')
  41. continue
  42. await self.dispatch(data)
  43. except WebSocketDisconnect:
  44. await self.close()
  45. logger.info('WebSocket disconnected, sid: %s', self.sid)
  46. except RuntimeError as e:
  47. await self.close()
  48. logger.exception('Error in loop_recv: %s', e)
  49. async def _initialize_agent(self, data: dict):
  50. await self.agent_session.event_stream.add_event(
  51. ChangeAgentStateAction(AgentState.LOADING), EventSource.USER
  52. )
  53. await self.agent_session.event_stream.add_event(
  54. AgentStateChangedObservation('', AgentState.LOADING), EventSource.AGENT
  55. )
  56. try:
  57. await self.agent_session.start(data)
  58. except Exception as e:
  59. logger.exception(f'Error creating controller: {e}')
  60. await self.send_error(
  61. f'Error creating controller. Please check Docker is running and visit `{TROUBLESHOOTING_URL}` for more debugging information..'
  62. )
  63. return
  64. await self.agent_session.event_stream.add_event(
  65. ChangeAgentStateAction(AgentState.INIT), EventSource.USER
  66. )
  67. async def on_event(self, event: Event):
  68. """Callback function for agent events.
  69. Args:
  70. event: The agent event (Observation or Action).
  71. """
  72. if isinstance(event, NullAction):
  73. return
  74. if isinstance(event, NullObservation):
  75. return
  76. if event.source == EventSource.AGENT and not isinstance(
  77. event, (NullAction, NullObservation)
  78. ):
  79. await self.send(event_to_dict(event))
  80. async def dispatch(self, data: dict):
  81. action = data.get('action', '')
  82. if action == ActionType.INIT:
  83. await self._initialize_agent(data)
  84. return
  85. event = event_from_dict(data.copy())
  86. await self.agent_session.event_stream.add_event(event, EventSource.USER)
  87. async def send(self, data: dict[str, object]) -> bool:
  88. try:
  89. if self.websocket is None or not self.is_alive:
  90. return False
  91. await self.websocket.send_json(data)
  92. await asyncio.sleep(0.001) # This flushes the data to the client
  93. self.last_active_ts = int(time.time())
  94. return True
  95. except WebSocketDisconnect:
  96. self.is_alive = False
  97. return False
  98. async def send_error(self, message: str) -> bool:
  99. """Sends an error message to the client."""
  100. return await self.send({'error': True, 'message': message})
  101. async def send_message(self, message: str) -> bool:
  102. """Sends a message to the client."""
  103. return await self.send({'message': message})
  104. def update_connection(self, ws: WebSocket):
  105. self.websocket = ws
  106. self.is_alive = True
  107. self.last_active_ts = int(time.time())
  108. def load_from_data(self, data: dict) -> bool:
  109. self.last_active_ts = data.get('last_active_ts', 0)
  110. if self.last_active_ts < int(time.time()) - DEL_DELT_SEC:
  111. return False
  112. self.is_alive = data.get('is_alive', False)
  113. return True