session.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181
  1. import asyncio
  2. import time
  3. from fastapi import WebSocket, WebSocketDisconnect
  4. from opendevin.controller.agent import Agent
  5. from opendevin.core.config import AppConfig
  6. from opendevin.core.const.guide_url import TROUBLESHOOTING_URL
  7. from opendevin.core.logger import opendevin_logger as logger
  8. from opendevin.core.schema import AgentState
  9. from opendevin.core.schema.action import ActionType
  10. from opendevin.core.schema.config import ConfigType
  11. from opendevin.events.action import Action, ChangeAgentStateAction, NullAction
  12. from opendevin.events.event import Event, EventSource
  13. from opendevin.events.observation import (
  14. AgentStateChangedObservation,
  15. CmdOutputObservation,
  16. NullObservation,
  17. )
  18. from opendevin.events.serialization import event_from_dict, event_to_dict
  19. from opendevin.events.stream import EventStreamSubscriber
  20. from opendevin.llm.llm import LLM
  21. from .agent import AgentSession
  22. DEL_DELT_SEC = 60 * 60 * 5
  23. class Session:
  24. sid: str
  25. websocket: WebSocket | None
  26. last_active_ts: int = 0
  27. is_alive: bool = True
  28. agent_session: AgentSession
  29. def __init__(self, sid: str, ws: WebSocket | None, config: AppConfig):
  30. self.sid = sid
  31. self.websocket = ws
  32. self.last_active_ts = int(time.time())
  33. self.agent_session = AgentSession(sid)
  34. self.agent_session.event_stream.subscribe(
  35. EventStreamSubscriber.SERVER, self.on_event
  36. )
  37. self.config = config
  38. async def close(self):
  39. self.is_alive = False
  40. await self.agent_session.close()
  41. async def loop_recv(self):
  42. try:
  43. if self.websocket is None:
  44. return
  45. while True:
  46. try:
  47. data = await self.websocket.receive_json()
  48. except ValueError:
  49. await self.send_error('Invalid JSON')
  50. continue
  51. await self.dispatch(data)
  52. except WebSocketDisconnect:
  53. await self.close()
  54. logger.info('WebSocket disconnected, sid: %s', self.sid)
  55. except RuntimeError as e:
  56. await self.close()
  57. logger.exception('Error in loop_recv: %s', e)
  58. async def _initialize_agent(self, data: dict):
  59. self.agent_session.event_stream.add_event(
  60. ChangeAgentStateAction(AgentState.LOADING), EventSource.USER
  61. )
  62. self.agent_session.event_stream.add_event(
  63. AgentStateChangedObservation('', AgentState.LOADING), EventSource.AGENT
  64. )
  65. # Extract the agent-relevant arguments from the request
  66. args = {
  67. key: value for key, value in data.get('args', {}).items() if value != ''
  68. }
  69. agent_cls = args.get(ConfigType.AGENT, self.config.default_agent)
  70. confirmation_mode = args.get(
  71. ConfigType.CONFIRMATION_MODE, self.config.confirmation_mode
  72. )
  73. max_iterations = args.get(ConfigType.MAX_ITERATIONS, self.config.max_iterations)
  74. # override default LLM config
  75. default_llm_config = self.config.get_llm_config()
  76. default_llm_config.model = args.get(
  77. ConfigType.LLM_MODEL, default_llm_config.model
  78. )
  79. default_llm_config.api_key = args.get(
  80. ConfigType.LLM_API_KEY, default_llm_config.api_key
  81. )
  82. default_llm_config.base_url = args.get(
  83. ConfigType.LLM_BASE_URL, default_llm_config.base_url
  84. )
  85. # TODO: override other LLM config & agent config groups (#2075)
  86. llm = LLM(config=self.config.get_llm_config_from_agent(agent_cls))
  87. agent = Agent.get_cls(agent_cls)(llm)
  88. # Create the agent session
  89. try:
  90. await self.agent_session.start(
  91. runtime_name=self.config.runtime,
  92. sandbox_config=self.config.sandbox,
  93. agent=agent,
  94. confirmation_mode=confirmation_mode,
  95. max_iterations=max_iterations,
  96. max_budget_per_task=self.config.max_budget_per_task,
  97. agent_to_llm_config=self.config.get_agent_to_llm_config_map(),
  98. )
  99. except Exception as e:
  100. logger.exception(f'Error creating controller: {e}')
  101. await self.send_error(
  102. f'Error creating controller. Please check Docker is running and visit `{TROUBLESHOOTING_URL}` for more debugging information..'
  103. )
  104. return
  105. self.agent_session.event_stream.add_event(
  106. ChangeAgentStateAction(AgentState.INIT), EventSource.USER
  107. )
  108. async def on_event(self, event: Event):
  109. """Callback function for agent events.
  110. Args:
  111. event: The agent event (Observation or Action).
  112. """
  113. if isinstance(event, NullAction):
  114. return
  115. if isinstance(event, NullObservation):
  116. return
  117. if event.source == EventSource.AGENT:
  118. await self.send(event_to_dict(event))
  119. elif event.source == EventSource.USER and isinstance(
  120. event, CmdOutputObservation
  121. ):
  122. await self.send(event_to_dict(event))
  123. async def dispatch(self, data: dict):
  124. action = data.get('action', '')
  125. if action == ActionType.INIT:
  126. await self._initialize_agent(data)
  127. return
  128. event = event_from_dict(data.copy())
  129. self.agent_session.event_stream.add_event(event, EventSource.USER)
  130. if isinstance(event, Action):
  131. logger.info(
  132. event, extra={'msg_type': 'ACTION', 'event_source': EventSource.USER}
  133. )
  134. async def send(self, data: dict[str, object]) -> bool:
  135. try:
  136. if self.websocket is None or not self.is_alive:
  137. return False
  138. await self.websocket.send_json(data)
  139. await asyncio.sleep(0.001) # This flushes the data to the client
  140. self.last_active_ts = int(time.time())
  141. return True
  142. except WebSocketDisconnect:
  143. self.is_alive = False
  144. return False
  145. async def send_error(self, message: str) -> bool:
  146. """Sends an error message to the client."""
  147. return await self.send({'error': True, 'message': message})
  148. async def send_message(self, message: str) -> bool:
  149. """Sends a message to the client."""
  150. return await self.send({'message': message})
  151. def update_connection(self, ws: WebSocket):
  152. self.websocket = ws
  153. self.is_alive = True
  154. self.last_active_ts = int(time.time())
  155. def load_from_data(self, data: dict) -> bool:
  156. self.last_active_ts = data.get('last_active_ts', 0)
  157. if self.last_active_ts < int(time.time()) - DEL_DELT_SEC:
  158. return False
  159. self.is_alive = data.get('is_alive', False)
  160. return True