session.py 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235
  1. import asyncio
  2. import json
  3. import time
  4. from copy import deepcopy
  5. import socketio
  6. from openhands.controller.agent import Agent
  7. from openhands.core.config import AppConfig
  8. from openhands.core.const.guide_url import TROUBLESHOOTING_URL
  9. from openhands.core.logger import openhands_logger as logger
  10. from openhands.core.schema import AgentState
  11. from openhands.events.action import MessageAction, NullAction
  12. from openhands.events.event import Event, EventSource
  13. from openhands.events.observation import (
  14. AgentStateChangedObservation,
  15. CmdOutputObservation,
  16. NullObservation,
  17. )
  18. from openhands.events.observation.error import ErrorObservation
  19. from openhands.events.serialization import event_from_dict, event_to_dict
  20. from openhands.events.stream import EventStreamSubscriber
  21. from openhands.llm.llm import LLM
  22. from openhands.server.session.agent_session import AgentSession
  23. from openhands.server.session.conversation_init_data import ConversationInitData
  24. from openhands.storage.files import FileStore
  25. from openhands.storage.locations import get_conversation_init_data_filename
  26. from openhands.utils.async_utils import call_sync_from_async
# Template for the Socket.IO room name of a session; Session._send formats it
# with the session id and passes the result as the `to=` target of `sio.emit`.
ROOM_KEY = 'room:{sid}'
  28. class Session:
  29. sid: str
  30. sio: socketio.AsyncServer | None
  31. last_active_ts: int = 0
  32. is_alive: bool = True
  33. agent_session: AgentSession
  34. loop: asyncio.AbstractEventLoop
  35. config: AppConfig
  36. file_store: FileStore
  37. def __init__(
  38. self,
  39. sid: str,
  40. config: AppConfig,
  41. file_store: FileStore,
  42. sio: socketio.AsyncServer | None,
  43. ):
  44. self.sid = sid
  45. self.sio = sio
  46. self.last_active_ts = int(time.time())
  47. self.file_store = file_store
  48. self.agent_session = AgentSession(
  49. sid, file_store, status_callback=self.queue_status_message
  50. )
  51. self.agent_session.event_stream.subscribe(
  52. EventStreamSubscriber.SERVER, self.on_event, self.sid
  53. )
  54. # Copying this means that when we update variables they are not applied to the shared global configuration!
  55. self.config = deepcopy(config)
  56. self.loop = asyncio.get_event_loop()
  57. def close(self):
  58. self.is_alive = False
  59. self.agent_session.close()
  60. async def _restore_init_data(self, sid: str) -> ConversationInitData:
  61. # FIXME: we should not store/restore this data once we have server-side
  62. # LLM configs. Should be done by 1/1/2025
  63. json_str = await call_sync_from_async(
  64. self.file_store.read, get_conversation_init_data_filename(sid)
  65. )
  66. data = json.loads(json_str)
  67. return ConversationInitData(**data)
  68. async def _save_init_data(self, sid: str, init_data: ConversationInitData):
  69. # FIXME: we should not store/restore this data once we have server-side
  70. # LLM configs. Should be done by 1/1/2025
  71. json_str = json.dumps(init_data.__dict__)
  72. await call_sync_from_async(
  73. self.file_store.write, get_conversation_init_data_filename(sid), json_str
  74. )
  75. async def initialize_agent(
  76. self, conversation_init_data: ConversationInitData | None = None
  77. ):
  78. self.agent_session.event_stream.add_event(
  79. AgentStateChangedObservation('', AgentState.LOADING),
  80. EventSource.ENVIRONMENT,
  81. )
  82. if conversation_init_data is None:
  83. try:
  84. conversation_init_data = await self._restore_init_data(self.sid)
  85. except FileNotFoundError:
  86. logger.error(f'User settings not found for session {self.sid}')
  87. raise RuntimeError('User settings not found')
  88. agent_cls = conversation_init_data.agent or self.config.default_agent
  89. self.config.security.confirmation_mode = (
  90. self.config.security.confirmation_mode
  91. if conversation_init_data.confirmation_mode is None
  92. else conversation_init_data.confirmation_mode
  93. )
  94. self.config.security.security_analyzer = (
  95. conversation_init_data.security_analyzer
  96. or self.config.security.security_analyzer
  97. )
  98. max_iterations = (
  99. conversation_init_data.max_iterations or self.config.max_iterations
  100. )
  101. # override default LLM config
  102. default_llm_config = self.config.get_llm_config()
  103. default_llm_config.model = (
  104. conversation_init_data.llm_model or default_llm_config.model
  105. )
  106. default_llm_config.api_key = (
  107. conversation_init_data.llm_api_key or default_llm_config.api_key
  108. )
  109. default_llm_config.base_url = (
  110. conversation_init_data.llm_base_url or default_llm_config.base_url
  111. )
  112. await self._save_init_data(self.sid, conversation_init_data)
  113. # TODO: override other LLM config & agent config groups (#2075)
  114. llm = LLM(config=self.config.get_llm_config_from_agent(agent_cls))
  115. agent_config = self.config.get_agent_config(agent_cls)
  116. agent = Agent.get_cls(agent_cls)(llm, agent_config)
  117. try:
  118. await self.agent_session.start(
  119. runtime_name=self.config.runtime,
  120. config=self.config,
  121. agent=agent,
  122. max_iterations=max_iterations,
  123. max_budget_per_task=self.config.max_budget_per_task,
  124. agent_to_llm_config=self.config.get_agent_to_llm_config_map(),
  125. agent_configs=self.config.get_agent_configs(),
  126. github_token=conversation_init_data.github_token,
  127. selected_repository=conversation_init_data.selected_repository,
  128. )
  129. except Exception as e:
  130. logger.exception(f'Error creating controller: {e}')
  131. await self.send_error(
  132. f'Error creating controller. Please check Docker is running and visit `{TROUBLESHOOTING_URL}` for more debugging information..'
  133. )
  134. return
  135. async def on_event(self, event: Event):
  136. """Callback function for events that mainly come from the agent.
  137. Event is the base class for any agent action and observation.
  138. Args:
  139. event: The agent event (Observation or Action).
  140. """
  141. if isinstance(event, NullAction):
  142. return
  143. if isinstance(event, NullObservation):
  144. return
  145. if event.source == EventSource.AGENT:
  146. await self.send(event_to_dict(event))
  147. elif event.source == EventSource.USER:
  148. await self.send(event_to_dict(event))
  149. # NOTE: ipython observations are not sent here currently
  150. elif event.source == EventSource.ENVIRONMENT and isinstance(
  151. event, (CmdOutputObservation, AgentStateChangedObservation)
  152. ):
  153. # feedback from the environment to agent actions is understood as agent events by the UI
  154. event_dict = event_to_dict(event)
  155. event_dict['source'] = EventSource.AGENT
  156. await self.send(event_dict)
  157. elif isinstance(event, ErrorObservation):
  158. # send error events as agent events to the UI
  159. event_dict = event_to_dict(event)
  160. event_dict['source'] = EventSource.AGENT
  161. await self.send(event_dict)
  162. async def dispatch(self, data: dict):
  163. event = event_from_dict(data.copy())
  164. # This checks if the model supports images
  165. if isinstance(event, MessageAction) and event.image_urls:
  166. controller = self.agent_session.controller
  167. if controller:
  168. if controller.agent.llm.config.disable_vision:
  169. await self.send_error(
  170. 'Support for images is disabled for this model, try without an image.'
  171. )
  172. return
  173. if not controller.agent.llm.vision_is_active():
  174. await self.send_error(
  175. 'Model does not support image upload, change to a different model or try without an image.'
  176. )
  177. return
  178. self.agent_session.event_stream.add_event(event, EventSource.USER)
  179. async def send(self, data: dict[str, object]):
  180. if asyncio.get_running_loop() != self.loop:
  181. self.loop.create_task(self._send(data))
  182. return
  183. await self._send(data)
  184. async def _send(self, data: dict[str, object]) -> bool:
  185. try:
  186. if not self.is_alive:
  187. return False
  188. if self.sio:
  189. await self.sio.emit('oh_event', data, to=ROOM_KEY.format(sid=self.sid))
  190. await asyncio.sleep(0.001) # This flushes the data to the client
  191. self.last_active_ts = int(time.time())
  192. return True
  193. except RuntimeError:
  194. logger.error('Error sending', stack_info=True, exc_info=True)
  195. self.is_alive = False
  196. return False
  197. async def send_error(self, message: str):
  198. """Sends an error message to the client."""
  199. await self.send({'error': True, 'message': message})
  200. async def _send_status_message(self, msg_type: str, id: str, message: str):
  201. """Sends a status message to the client."""
  202. if msg_type == 'error':
  203. await self.agent_session.stop_agent_loop_for_error()
  204. await self.send(
  205. {'status_update': True, 'type': msg_type, 'id': id, 'message': message}
  206. )
  207. def queue_status_message(self, msg_type: str, id: str, message: str):
  208. """Queues a status message to be sent asynchronously."""
  209. asyncio.run_coroutine_threadsafe(
  210. self._send_status_message(msg_type, id, message), self.loop
  211. )