session.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141
  1. import asyncio
  2. import time
  3. from fastapi import WebSocket, WebSocketDisconnect
  4. from opendevin.core.const.guide_url import TROUBLESHOOTING_URL
  5. from opendevin.core.logger import opendevin_logger as logger
  6. from opendevin.core.schema import AgentState
  7. from opendevin.core.schema.action import ActionType
  8. from opendevin.events.action import Action, ChangeAgentStateAction, NullAction
  9. from opendevin.events.event import Event, EventSource
  10. from opendevin.events.observation import (
  11. AgentStateChangedObservation,
  12. CmdOutputObservation,
  13. NullObservation,
  14. )
  15. from opendevin.events.serialization import event_from_dict, event_to_dict
  16. from opendevin.events.stream import EventStreamSubscriber
  17. from .agent import AgentSession
  18. DEL_DELT_SEC = 60 * 60 * 5
  19. class Session:
  20. sid: str
  21. websocket: WebSocket | None
  22. last_active_ts: int = 0
  23. is_alive: bool = True
  24. agent_session: AgentSession
  25. def __init__(self, sid: str, ws: WebSocket | None):
  26. self.sid = sid
  27. self.websocket = ws
  28. self.last_active_ts = int(time.time())
  29. self.agent_session = AgentSession(sid)
  30. self.agent_session.event_stream.subscribe(
  31. EventStreamSubscriber.SERVER, self.on_event
  32. )
  33. async def close(self):
  34. self.is_alive = False
  35. await self.agent_session.close()
  36. async def loop_recv(self):
  37. try:
  38. if self.websocket is None:
  39. return
  40. while True:
  41. try:
  42. data = await self.websocket.receive_json()
  43. except ValueError:
  44. await self.send_error('Invalid JSON')
  45. continue
  46. await self.dispatch(data)
  47. except WebSocketDisconnect:
  48. await self.close()
  49. logger.info('WebSocket disconnected, sid: %s', self.sid)
  50. except RuntimeError as e:
  51. await self.close()
  52. logger.exception('Error in loop_recv: %s', e)
  53. async def _initialize_agent(self, data: dict):
  54. self.agent_session.event_stream.add_event(
  55. ChangeAgentStateAction(AgentState.LOADING), EventSource.USER
  56. )
  57. self.agent_session.event_stream.add_event(
  58. AgentStateChangedObservation('', AgentState.LOADING), EventSource.AGENT
  59. )
  60. try:
  61. await self.agent_session.start(data)
  62. except Exception as e:
  63. logger.exception(f'Error creating controller: {e}')
  64. await self.send_error(
  65. f'Error creating controller. Please check Docker is running and visit `{TROUBLESHOOTING_URL}` for more debugging information..'
  66. )
  67. return
  68. self.agent_session.event_stream.add_event(
  69. ChangeAgentStateAction(AgentState.INIT), EventSource.USER
  70. )
  71. async def on_event(self, event: Event):
  72. """Callback function for agent events.
  73. Args:
  74. event: The agent event (Observation or Action).
  75. """
  76. if isinstance(event, NullAction):
  77. return
  78. if isinstance(event, NullObservation):
  79. return
  80. if event.source == EventSource.AGENT:
  81. await self.send(event_to_dict(event))
  82. elif event.source == EventSource.USER and isinstance(
  83. event, CmdOutputObservation
  84. ):
  85. await self.send(event_to_dict(event))
  86. async def dispatch(self, data: dict):
  87. action = data.get('action', '')
  88. if action == ActionType.INIT:
  89. await self._initialize_agent(data)
  90. return
  91. event = event_from_dict(data.copy())
  92. self.agent_session.event_stream.add_event(event, EventSource.USER)
  93. if isinstance(event, Action):
  94. logger.info(
  95. event, extra={'msg_type': 'ACTION', 'event_source': EventSource.USER}
  96. )
  97. async def send(self, data: dict[str, object]) -> bool:
  98. try:
  99. if self.websocket is None or not self.is_alive:
  100. return False
  101. await self.websocket.send_json(data)
  102. await asyncio.sleep(0.001) # This flushes the data to the client
  103. self.last_active_ts = int(time.time())
  104. return True
  105. except WebSocketDisconnect:
  106. self.is_alive = False
  107. return False
  108. async def send_error(self, message: str) -> bool:
  109. """Sends an error message to the client."""
  110. return await self.send({'error': True, 'message': message})
  111. async def send_message(self, message: str) -> bool:
  112. """Sends a message to the client."""
  113. return await self.send({'message': message})
  114. def update_connection(self, ws: WebSocket):
  115. self.websocket = ws
  116. self.is_alive = True
  117. self.last_active_ts = int(time.time())
  118. def load_from_data(self, data: dict) -> bool:
  119. self.last_active_ts = data.get('last_active_ts', 0)
  120. if self.last_active_ts < int(time.time()) - DEL_DELT_SEC:
  121. return False
  122. self.is_alive = data.get('is_alive', False)
  123. return True