session.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202
  1. import asyncio
  2. import time
  3. from copy import deepcopy
  4. import socketio
  5. from openhands.controller.agent import Agent
  6. from openhands.core.config import AppConfig
  7. from openhands.core.const.guide_url import TROUBLESHOOTING_URL
  8. from openhands.core.logger import openhands_logger as logger
  9. from openhands.core.schema import AgentState
  10. from openhands.events.action import MessageAction, NullAction
  11. from openhands.events.event import Event, EventSource
  12. from openhands.events.observation import (
  13. AgentStateChangedObservation,
  14. CmdOutputObservation,
  15. NullObservation,
  16. )
  17. from openhands.events.observation.error import ErrorObservation
  18. from openhands.events.serialization import event_from_dict, event_to_dict
  19. from openhands.events.stream import EventStreamSubscriber
  20. from openhands.llm.llm import LLM
  21. from openhands.server.session.agent_session import AgentSession
  22. from openhands.server.session.session_init_data import SessionInitData
  23. from openhands.storage.files import FileStore
  # Template for a session's Socket.IO room name; use ROOM_KEY.format(sid=...).
  ROOM_KEY = "room:{sid}"
  class Session:
      """Server-side state for a single client session.

      A ``Session`` owns an ``AgentSession`` created for ``sid``, relays events
      from that agent session's event stream to the client over Socket.IO (the
      room named ``ROOM_KEY.format(sid=sid)``), and injects client-sent events
      back into the stream as USER events.
      """

      sid: str
      sio: socketio.AsyncServer | None
      # Whole-second Unix timestamp, set at construction and after each send.
      last_active_ts: int = 0
      # Cleared by close() or by a RuntimeError while sending; _send is a no-op once False.
      is_alive: bool = True
      agent_session: AgentSession
      # Loop captured at construction; all emits are funnelled onto it.
      loop: asyncio.AbstractEventLoop
      config: AppConfig

      def __init__(
          self,
          sid: str,
          config: AppConfig,
          file_store: FileStore,
          sio: socketio.AsyncServer | None,
      ):
          """Create the session and subscribe it to its agent's event stream.

          Args:
              sid: Session id. Also used as the event-stream subscriber id and
                  to build the Socket.IO room name.
              config: Application config. A deep copy is stored, so per-session
                  overrides do not touch the caller's object.
              file_store: Storage backend handed to the ``AgentSession``.
              sio: Socket.IO server used for emitting. When ``None``, emits are
                  skipped.
          """
          self.sid = sid
          self.sio = sio
          self.last_active_ts = int(time.time())
          # Capture the loop before AgentSession exists: its status_callback
          # (queue_status_message) schedules coroutines onto self.loop.
          self.loop = asyncio.get_event_loop()
          # Copying this means that when we update variables they are not applied to the shared global configuration!
          self.config = deepcopy(config)
          self.agent_session = AgentSession(
              sid, file_store, status_callback=self.queue_status_message
          )
          self.agent_session.event_stream.subscribe(
              EventStreamSubscriber.SERVER, self.on_event, self.sid
          )

      def close(self):
          """Mark the session dead (suppressing further sends) and close the agent session."""
          self.is_alive = False
          self.agent_session.close()

      async def initialize_agent(self, session_init_data: SessionInitData):
          """Configure and start the agent for this session.

          Steps:

          1. Publish a LOADING state change.
          2. Apply the overrides carried by ``session_init_data`` to this
             session's private config copy: agent class, confirmation mode,
             security analyzer, iteration limit, and LLM model / API key /
             base URL.
          3. Build the agent and start the agent session.

          Any exception raised while building the agent or starting the session
          is logged and reported to the client via ``send_error`` rather than
          propagated.

          Args:
              session_init_data: Client-supplied settings. Falsy fields fall
                  back to the config, except ``confirmation_mode``, which only
                  falls back when it is ``None``.
          """
          self.agent_session.event_stream.add_event(
              AgentStateChangedObservation('', AgentState.LOADING),
              EventSource.ENVIRONMENT,
          )

          # Extract the agent-relevant arguments from the request
          agent_cls = session_init_data.agent or self.config.default_agent
          # Compared against None (not truthiness) so an explicit falsy value still overrides.
          self.config.security.confirmation_mode = (
              self.config.security.confirmation_mode
              if session_init_data.confirmation_mode is None
              else session_init_data.confirmation_mode
          )
          self.config.security.security_analyzer = (
              session_init_data.security_analyzer
              or self.config.security.security_analyzer
          )
          max_iterations = session_init_data.max_iterations or self.config.max_iterations

          # override default LLM config
          # NOTE(review): assumes get_llm_config() returns the object stored in
          # self.config, so get_llm_config_from_agent() below sees these
          # overrides; confirm.
          default_llm_config = self.config.get_llm_config()
          default_llm_config.model = (
              session_init_data.llm_model or default_llm_config.model
          )
          default_llm_config.api_key = (
              session_init_data.llm_api_key or default_llm_config.api_key
          )
          default_llm_config.base_url = (
              session_init_data.llm_base_url or default_llm_config.base_url
          )

          # TODO: override other LLM config & agent config groups (#2075)
          try:
              # Built inside the try so that a failure here is reported to the
              # client instead of escaping after LOADING was already published.
              llm = LLM(config=self.config.get_llm_config_from_agent(agent_cls))
              agent_config = self.config.get_agent_config(agent_cls)
              agent = Agent.get_cls(agent_cls)(llm, agent_config)
              await self.agent_session.start(
                  runtime_name=self.config.runtime,
                  config=self.config,
                  agent=agent,
                  max_iterations=max_iterations,
                  max_budget_per_task=self.config.max_budget_per_task,
                  agent_to_llm_config=self.config.get_agent_to_llm_config_map(),
                  agent_configs=self.config.get_agent_configs(),
                  github_token=session_init_data.github_token,
                  selected_repository=session_init_data.selected_repository,
              )
          except Exception as e:
              logger.exception(f'Error creating controller: {e}')
              await self.send_error(
                  f'Error creating controller. Please check Docker is running and visit `{TROUBLESHOOTING_URL}` for more debugging information..'
              )

      async def on_event(self, event: Event):
          """Forward an event from the agent's event stream to the client.

          Forwarding rules:

          - Null actions and observations are dropped.
          - AGENT- and USER-sourced events are sent unchanged.
          - ENVIRONMENT command output and agent-state changes, and error
            observations from any other source, are sent with their source
            rewritten to AGENT.
          - Anything else is not sent.

          Args:
              event: The agent event (Observation or Action).
          """
          if isinstance(event, (NullAction, NullObservation)):
              return
          if event.source in (EventSource.AGENT, EventSource.USER):
              await self.send(event_to_dict(event))
              return
          # NOTE: ipython observations are not sent here currently
          relabel_as_agent = isinstance(event, ErrorObservation) or (
              event.source == EventSource.ENVIRONMENT
              and isinstance(event, (CmdOutputObservation, AgentStateChangedObservation))
          )
          if relabel_as_agent:
              # feedback from the environment to agent actions (and errors) is
              # understood as agent events by the UI
              event_dict = event_to_dict(event)
              event_dict['source'] = EventSource.AGENT
              await self.send(event_dict)

      async def dispatch(self, data: dict):
          """Deserialize a client event and add it to the event stream as a USER event.

          A ``MessageAction`` carrying images is rejected with an error message,
          and not added, when an agent controller exists and its LLM either has
          vision disabled in config or reports vision as inactive.

          Args:
              data: Serialized event from the client. A shallow copy is passed
                  to ``event_from_dict``.
          """
          event = event_from_dict(data.copy())
          # This checks if the model supports images
          if isinstance(event, MessageAction) and event.image_urls:
              controller = self.agent_session.controller
              if controller:
                  if controller.agent.llm.config.disable_vision:
                      await self.send_error(
                          'Support for images is disabled for this model, try without an image.'
                      )
                      return
                  if not controller.agent.llm.vision_is_active():
                      await self.send_error(
                          'Model does not support image upload, change to a different model or try without an image.'
                      )
                      return
          self.agent_session.event_stream.add_event(event, EventSource.USER)

      async def send(self, data: dict[str, object]):
          """Emit ``data`` to the client, whichever event loop this is awaited on.

          When awaited on a loop other than ``self.loop``, the emit is handed to
          ``self.loop`` via ``asyncio.run_coroutine_threadsafe`` and this returns
          without waiting for it. ``loop.create_task`` must not be used here: it
          is not thread-safe and may leave the owning loop unaware of the new
          task.
          """
          if asyncio.get_running_loop() is not self.loop:
              asyncio.run_coroutine_threadsafe(self._send(data), self.loop)
              return
          await self._send(data)

      async def _send(self, data: dict[str, object]) -> bool:
          """Emit ``data`` as an ``oh_event`` to this session's room.

          Returns:
              ``False`` if the session is dead, or if sending raised
              ``RuntimeError`` (which also marks the session dead). ``True``
              otherwise, including when there is no Socket.IO server.
          """
          try:
              if not self.is_alive:
                  return False
              if self.sio:
                  await self.sio.emit('oh_event', data, to=ROOM_KEY.format(sid=self.sid))
              await asyncio.sleep(0.001)  # This flushes the data to the client
              self.last_active_ts = int(time.time())
              return True
          except RuntimeError:
              logger.error('Error sending', stack_info=True, exc_info=True)
              self.is_alive = False
              return False

      async def send_error(self, message: str):
          """Sends an error message to the client."""
          await self.send({'error': True, 'message': message})

      async def _send_status_message(self, msg_type: str, id: str, message: str):
          """Send a status update to the client.

          An ``'error'`` status first stops the agent loop via
          ``agent_session.stop_agent_loop_for_error()``.
          """
          if msg_type == 'error':
              await self.agent_session.stop_agent_loop_for_error()
          await self.send(
              {'status_update': True, 'type': msg_type, 'id': id, 'message': message}
          )

      def queue_status_message(self, msg_type: str, id: str, message: str):
          """Schedule ``_send_status_message`` on ``self.loop`` without waiting.

          Uses ``asyncio.run_coroutine_threadsafe``, so it may be called from
          threads other than the one running ``self.loop``.
          """
          asyncio.run_coroutine_threadsafe(
              self._send_status_message(msg_type, id, message), self.loop
          )