session.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190
  1. import asyncio
  2. from copy import deepcopy
  3. import time
  4. import socketio
  5. from openhands.controller.agent import Agent
  6. from openhands.core.config import AppConfig
  7. from openhands.core.const.guide_url import TROUBLESHOOTING_URL
  8. from openhands.core.logger import openhands_logger as logger
  9. from openhands.core.schema import AgentState
  10. from openhands.core.schema.config import ConfigType
  11. from openhands.events.action import MessageAction, NullAction
  12. from openhands.events.event import Event, EventSource
  13. from openhands.events.observation import (
  14. AgentStateChangedObservation,
  15. CmdOutputObservation,
  16. NullObservation,
  17. )
  18. from openhands.events.observation.error import ErrorObservation
  19. from openhands.events.serialization import event_from_dict, event_to_dict
  20. from openhands.events.stream import EventStreamSubscriber
  21. from openhands.llm.llm import LLM
  22. from openhands.server.session.agent_session import AgentSession
  23. from openhands.server.session.session_init_data import SessionInitData
  24. from openhands.storage.files import FileStore
  25. ROOM_KEY = 'room:{sid}'
  26. class Session:
  27. sid: str
  28. sio: socketio.AsyncServer | None
  29. last_active_ts: int = 0
  30. is_alive: bool = True
  31. agent_session: AgentSession
  32. loop: asyncio.AbstractEventLoop
  33. config: AppConfig
  34. def __init__(
  35. self,
  36. sid: str,
  37. config: AppConfig,
  38. file_store: FileStore,
  39. sio: socketio.AsyncServer | None,
  40. ):
  41. self.sid = sid
  42. self.sio = sio
  43. self.last_active_ts = int(time.time())
  44. self.agent_session = AgentSession(
  45. sid, file_store, status_callback=self.queue_status_message
  46. )
  47. self.agent_session.event_stream.subscribe(
  48. EventStreamSubscriber.SERVER, self.on_event, self.sid
  49. )
  50. # Copying this means that when we update variables they are not applied to the shared global configuration!
  51. self.config = deepcopy(config)
  52. self.loop = asyncio.get_event_loop()
  53. def close(self):
  54. self.is_alive = False
  55. self.agent_session.close()
  56. async def initialize_agent(self, session_init_data: SessionInitData):
  57. self.agent_session.event_stream.add_event(
  58. AgentStateChangedObservation('', AgentState.LOADING),
  59. EventSource.ENVIRONMENT,
  60. )
  61. # Extract the agent-relevant arguments from the request
  62. agent_cls = session_init_data.agent or self.config.default_agent
  63. self.config.security.confirmation_mode = self.config.security.confirmation_mode if session_init_data.confirmation_mode is None else session_init_data.confirmation_mode
  64. self.config.security.security_analyzer = session_init_data.security_analyzer or self.config.security.security_analyzer
  65. max_iterations = session_init_data.max_iterations or self.config.max_iterations
  66. # override default LLM config
  67. default_llm_config = self.config.get_llm_config()
  68. default_llm_config.model = session_init_data.llm_model or default_llm_config.model
  69. default_llm_config.api_key = session_init_data.llm_api_key or default_llm_config.api_key
  70. default_llm_config.base_url = session_init_data.llm_base_url or default_llm_config.base_url
  71. # TODO: override other LLM config & agent config groups (#2075)
  72. llm = LLM(config=self.config.get_llm_config_from_agent(agent_cls))
  73. agent_config = self.config.get_agent_config(agent_cls)
  74. agent = Agent.get_cls(agent_cls)(llm, agent_config)
  75. try:
  76. await self.agent_session.start(
  77. runtime_name=self.config.runtime,
  78. config=self.config,
  79. agent=agent,
  80. max_iterations=max_iterations,
  81. max_budget_per_task=self.config.max_budget_per_task,
  82. agent_to_llm_config=self.config.get_agent_to_llm_config_map(),
  83. agent_configs=self.config.get_agent_configs(),
  84. github_token=session_init_data.github_token,
  85. selected_repository=session_init_data.selected_repository,
  86. )
  87. except Exception as e:
  88. logger.exception(f'Error creating controller: {e}')
  89. await self.send_error(
  90. f'Error creating controller. Please check Docker is running and visit `{TROUBLESHOOTING_URL}` for more debugging information..'
  91. )
  92. return
  93. async def on_event(self, event: Event):
  94. """Callback function for events that mainly come from the agent.
  95. Event is the base class for any agent action and observation.
  96. Args:
  97. event: The agent event (Observation or Action).
  98. """
  99. if isinstance(event, NullAction):
  100. return
  101. if isinstance(event, NullObservation):
  102. return
  103. if event.source == EventSource.AGENT:
  104. await self.send(event_to_dict(event))
  105. elif event.source == EventSource.USER:
  106. await self.send(event_to_dict(event))
  107. # NOTE: ipython observations are not sent here currently
  108. elif event.source == EventSource.ENVIRONMENT and isinstance(
  109. event, (CmdOutputObservation, AgentStateChangedObservation)
  110. ):
  111. # feedback from the environment to agent actions is understood as agent events by the UI
  112. event_dict = event_to_dict(event)
  113. event_dict['source'] = EventSource.AGENT
  114. await self.send(event_dict)
  115. elif isinstance(event, ErrorObservation):
  116. # send error events as agent events to the UI
  117. event_dict = event_to_dict(event)
  118. event_dict['source'] = EventSource.AGENT
  119. await self.send(event_dict)
  120. async def dispatch(self, data: dict):
  121. event = event_from_dict(data.copy())
  122. # This checks if the model supports images
  123. if isinstance(event, MessageAction) and event.image_urls:
  124. controller = self.agent_session.controller
  125. if controller:
  126. if controller.agent.llm.config.disable_vision:
  127. await self.send_error(
  128. 'Support for images is disabled for this model, try without an image.'
  129. )
  130. return
  131. if not controller.agent.llm.vision_is_active():
  132. await self.send_error(
  133. 'Model does not support image upload, change to a different model or try without an image.'
  134. )
  135. return
  136. self.agent_session.event_stream.add_event(event, EventSource.USER)
  137. async def send(self, data: dict[str, object]):
  138. if asyncio.get_running_loop() != self.loop:
  139. self.loop.create_task(self._send(data))
  140. return
  141. await self._send(data)
  142. async def _send(self, data: dict[str, object]) -> bool:
  143. try:
  144. if not self.is_alive:
  145. return False
  146. if self.sio:
  147. await self.sio.emit('oh_event', data, to=ROOM_KEY.format(sid=self.sid))
  148. await asyncio.sleep(0.001) # This flushes the data to the client
  149. self.last_active_ts = int(time.time())
  150. return True
  151. except RuntimeError:
  152. logger.error('Error sending', stack_info=True, exc_info=True)
  153. self.is_alive = False
  154. return False
  155. async def send_error(self, message: str):
  156. """Sends an error message to the client."""
  157. await self.send({'error': True, 'message': message})
  158. async def _send_status_message(self, msg_type: str, id: str, message: str):
  159. """Sends a status message to the client."""
  160. if msg_type == 'error':
  161. await self.agent_session.stop_agent_loop_for_error()
  162. await self.send(
  163. {'status_update': True, 'type': msg_type, 'id': id, 'message': message}
  164. )
  165. def queue_status_message(self, msg_type: str, id: str, message: str):
  166. """Queues a status message to be sent asynchronously."""
  167. asyncio.run_coroutine_threadsafe(
  168. self._send_status_message(msg_type, id, message), self.loop
  169. )