session.py 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198
  1. import asyncio
  2. import time
  3. import socketio
  4. from openhands.controller.agent import Agent
  5. from openhands.core.config import AppConfig
  6. from openhands.core.const.guide_url import TROUBLESHOOTING_URL
  7. from openhands.core.logger import openhands_logger as logger
  8. from openhands.core.schema import AgentState
  9. from openhands.core.schema.config import ConfigType
  10. from openhands.events.action import MessageAction, NullAction
  11. from openhands.events.event import Event, EventSource
  12. from openhands.events.observation import (
  13. AgentStateChangedObservation,
  14. CmdOutputObservation,
  15. NullObservation,
  16. )
  17. from openhands.events.observation.error import ErrorObservation
  18. from openhands.events.serialization import event_from_dict, event_to_dict
  19. from openhands.events.stream import EventStreamSubscriber
  20. from openhands.llm.llm import LLM
  21. from openhands.server.session.agent_session import AgentSession
  22. from openhands.storage.files import FileStore
  # Template for the socket.io room name that receives a session's events.
  ROOM_KEY = 'room:{sid}'


  class Session:
      """Server-side state for a single client session.

      Owns an ``AgentSession``, subscribes to its event stream and relays
      selected events to the client by emitting ``oh_event`` to the session's
      socket.io room. Events coming from the client enter through
      ``dispatch``.
      """

      # Session identifier; also used to build the socket.io room name.
      sid: str
      # socket.io server used for emitting; nothing is emitted when it is None.
      sio: socketio.AsyncServer | None
      # Unix time (whole seconds) of creation or of the last completed _send().
      last_active_ts: int = 0
      # Set to False by close() or after a RuntimeError while sending.
      is_alive: bool = True
      agent_session: AgentSession
      # Event loop captured in __init__; sends from other loops are handed to it.
      loop: asyncio.AbstractEventLoop
      config: AppConfig
      # Payload most recently passed to initialize_agent(), or None before that.
      settings: dict | None

      def __init__(
          self,
          sid: str,
          config: AppConfig,
          file_store: FileStore,
          sio: socketio.AsyncServer | None,
      ):
          """Create the session and subscribe it to its agent's event stream.

          Args:
              sid: Session identifier.
              config: Application configuration; ``initialize_agent`` later
                  writes client overrides into it.
              file_store: Passed through to ``AgentSession``.
              sio: socket.io server to emit on, or ``None`` to disable emitting.
          """
          self.sid = sid
          self.sio = sio
          self.last_active_ts = int(time.time())
          # Populate everything the callbacks below read (notably ``loop``)
          # *before* handing them out, so a callback that fires early can never
          # observe a half-initialised instance.
          self.config = config
          self.loop = asyncio.get_event_loop()
          self.settings = None
          self.agent_session = AgentSession(
              sid, file_store, status_callback=self.queue_status_message
          )
          self.agent_session.event_stream.subscribe(
              EventStreamSubscriber.SERVER, self.on_event, self.sid
          )

      def close(self):
          """Mark the session dead and close the underlying agent session."""
          self.is_alive = False
          self.agent_session.close()

      async def initialize_agent(self, data: dict):
          """Apply client-supplied settings and start the agent session.

          Publishes a LOADING state change, writes the overrides found under
          ``data['args']`` into ``self.config`` (falling back to the current
          config values), builds the agent and starts the agent session. If
          starting fails, the exception is logged and an error message is sent
          to the client instead of being raised.

          Args:
              data: Settings payload. Its optional ``'args'`` mapping may carry
                  agent, confirmation mode, security analyzer, max iterations
                  and LLM model / API key / base URL overrides.
          """
          self.settings = data
          self.agent_session.event_stream.add_event(
              AgentStateChangedObservation('', AgentState.LOADING),
              EventSource.ENVIRONMENT,
          )

          # Extract the agent-relevant arguments from the request. A missing or
          # null 'args' entry is treated as "no overrides".
          args = data.get('args') or {}

          agent_cls = args.get(ConfigType.AGENT, self.config.default_agent)
          security = self.config.security
          security.confirmation_mode = args.get(
              ConfigType.CONFIRMATION_MODE, security.confirmation_mode
          )
          security.security_analyzer = args.get(
              ConfigType.SECURITY_ANALYZER, security.security_analyzer
          )
          max_iterations = args.get(
              ConfigType.MAX_ITERATIONS, self.config.max_iterations
          )

          # override default LLM config
          default_llm_config = self.config.get_llm_config()
          default_llm_config.model = args.get(
              ConfigType.LLM_MODEL, default_llm_config.model
          )
          default_llm_config.api_key = args.get(
              ConfigType.LLM_API_KEY, default_llm_config.api_key
          )
          default_llm_config.base_url = args.get(
              ConfigType.LLM_BASE_URL, default_llm_config.base_url
          )

          # TODO: override other LLM config & agent config groups (#2075)
          llm = LLM(config=self.config.get_llm_config_from_agent(agent_cls))
          agent_config = self.config.get_agent_config(agent_cls)
          agent = Agent.get_cls(agent_cls)(llm, agent_config)

          try:
              await self.agent_session.start(
                  runtime_name=self.config.runtime,
                  config=self.config,
                  agent=agent,
                  max_iterations=max_iterations,
                  max_budget_per_task=self.config.max_budget_per_task,
                  agent_to_llm_config=self.config.get_agent_to_llm_config_map(),
                  agent_configs=self.config.get_agent_configs(),
              )
          except Exception as e:
              # Boundary handler: report the failure to the client rather than
              # propagating it to the caller.
              logger.exception(f'Error creating controller: {e}')
              await self.send_error(
                  f'Error creating controller. Please check Docker is running and visit `{TROUBLESHOOTING_URL}` for more debugging information..'
              )

      async def on_event(self, event: Event):
          """Callback function for events that mainly come from the agent.

          Event is the base class for any agent action and observation.
          Null actions/observations are dropped. Agent and user events are
          forwarded unchanged. Command output and agent-state changes from the
          environment, and any remaining error observation, are forwarded with
          their ``source`` rewritten to the agent. Everything else is ignored.

          Args:
              event: The agent event (Observation or Action).
          """
          if isinstance(event, (NullAction, NullObservation)):
              return

          if event.source in (EventSource.AGENT, EventSource.USER):
              await self.send(event_to_dict(event))
              return

          # NOTE: ipython observations are not sent here currently
          is_environment_feedback = (
              event.source == EventSource.ENVIRONMENT
              and isinstance(
                  event, (CmdOutputObservation, AgentStateChangedObservation)
              )
          )
          if is_environment_feedback or isinstance(event, ErrorObservation):
              # feedback from the environment to agent actions (and errors) is
              # understood as agent events by the UI
              event_dict = event_to_dict(event)
              event_dict['source'] = EventSource.AGENT
              await self.send(event_dict)

      async def dispatch(self, data: dict):
          """Turn a client payload into an event and add it to the event stream.

          A message carrying images is rejected with an error sent to the
          client (and not added to the stream) when a controller exists and
          its LLM has vision disabled or inactive.

          Args:
              data: Serialized event received from the client; a copy is
                  passed to ``event_from_dict``.
          """
          event = event_from_dict(data.copy())

          # This checks if the model supports images
          if isinstance(event, MessageAction) and event.image_urls:
              controller = self.agent_session.controller
              if controller:
                  llm = controller.agent.llm
                  if llm.config.disable_vision:
                      await self.send_error(
                          'Support for images is disabled for this model, try without an image.'
                      )
                      return
                  if not llm.vision_is_active():
                      await self.send_error(
                          'Model does not support image upload, change to a different model or try without an image.'
                      )
                      return

          self.agent_session.event_stream.add_event(event, EventSource.USER)

      async def send(self, data: dict[str, object]):
          """Send ``data`` to the client, always on the session's own loop.

          When called from a coroutine running on a different event loop the
          send is scheduled on ``self.loop`` and this method returns without
          waiting for it; otherwise the send is awaited directly.

          Args:
              data: Payload to emit.
          """
          if asyncio.get_running_loop() != self.loop:
              # A different running loop means we are on another thread, and
              # loop.create_task() is not thread-safe; hand the coroutine over
              # with the thread-safe scheduling API instead.
              asyncio.run_coroutine_threadsafe(self._send(data), self.loop)
              return
          await self._send(data)

      async def _send(self, data: dict[str, object]) -> bool:
          """Emit ``data`` as an ``oh_event`` to the session's room.

          Args:
              data: Payload to emit.

          Returns:
              True when the send completed (including when there is no
              socket.io server and nothing was emitted); False when the
              session is not alive or a RuntimeError occurred, in which case
              the session is also marked as not alive.
          """
          try:
              if not self.is_alive:
                  return False
              if self.sio:
                  room = ROOM_KEY.format(sid=self.sid)
                  await self.sio.emit('oh_event', data, to=room)
              await asyncio.sleep(0.001)  # This flushes the data to the client
              self.last_active_ts = int(time.time())
              return True
          except RuntimeError:
              logger.error('Error sending', stack_info=True, exc_info=True)
              self.is_alive = False
              return False

      async def send_error(self, message: str):
          """Sends an error message to the client."""
          await self.send({'error': True, 'message': message})

      async def _send_status_message(self, msg_type: str, id: str, message: str):
          """Sends a status message to the client.

          For ``msg_type == 'error'`` the agent loop is stopped first.

          Args:
              msg_type: Status category, e.g. ``'error'``.
              id: Identifier forwarded to the client under ``'id'``.
              message: Human-readable status text.
          """
          if msg_type == 'error':
              await self.agent_session.stop_agent_loop_for_error()
          await self.send(
              {'status_update': True, 'type': msg_type, 'id': id, 'message': message}
          )

      def queue_status_message(self, msg_type: str, id: str, message: str):
          """Queues a status message to be sent asynchronously.

          Synchronous entry point (used as the agent session's status
          callback): schedules ``_send_status_message`` on the session's loop
          and returns without waiting for it.
          """
          asyncio.run_coroutine_threadsafe(
              self._send_status_message(msg_type, id, message), self.loop
          )