stream.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343
  1. import asyncio
  2. import threading
  3. from concurrent.futures import ThreadPoolExecutor
  4. from datetime import datetime
  5. from enum import Enum
  6. from queue import Queue
  7. from typing import Callable, Iterable
  8. from openhands.core.logger import openhands_logger as logger
  9. from openhands.core.utils import json
  10. from openhands.events.event import Event, EventSource
  11. from openhands.events.serialization.event import event_from_dict, event_to_dict
  12. from openhands.storage import FileStore
  13. from openhands.storage.locations import (
  14. get_conversation_dir,
  15. get_conversation_event_filename,
  16. get_conversation_events_dir,
  17. )
  18. from openhands.utils.async_utils import call_sync_from_async
  19. from openhands.utils.shutdown_listener import should_continue
  20. class EventStreamSubscriber(str, Enum):
  21. AGENT_CONTROLLER = 'agent_controller'
  22. SECURITY_ANALYZER = 'security_analyzer'
  23. RESOLVER = 'openhands_resolver'
  24. SERVER = 'server'
  25. RUNTIME = 'runtime'
  26. MAIN = 'main'
  27. TEST = 'test'
  28. async def session_exists(sid: str, file_store: FileStore) -> bool:
  29. try:
  30. await call_sync_from_async(file_store.list, get_conversation_dir(sid))
  31. return True
  32. except FileNotFoundError:
  33. return False
  34. class AsyncEventStreamWrapper:
  35. def __init__(self, event_stream, *args, **kwargs):
  36. self.event_stream = event_stream
  37. self.args = args
  38. self.kwargs = kwargs
  39. async def __aiter__(self):
  40. loop = asyncio.get_running_loop()
  41. # Create an async generator that yields events
  42. for event in self.event_stream.get_events(*self.args, **self.kwargs):
  43. # Run the blocking get_events() in a thread pool
  44. yield await loop.run_in_executor(None, lambda e=event: e) # type: ignore
  45. class EventStream:
  46. sid: str
  47. file_store: FileStore
  48. # For each subscriber ID, there is a map of callback functions - useful
  49. # when there are multiple listeners
  50. _subscribers: dict[str, dict[str, Callable]]
  51. _cur_id: int = 0
  52. _lock: threading.Lock
  53. def __init__(self, sid: str, file_store: FileStore, num_workers: int = 1):
  54. self.sid = sid
  55. self.file_store = file_store
  56. self._queue: Queue[Event] = Queue()
  57. self._thread_pools: dict[str, dict[str, ThreadPoolExecutor]] = {}
  58. self._queue_thread = threading.Thread(target=self._run_queue_loop)
  59. self._queue_thread.daemon = True
  60. self._queue_thread.start()
  61. self._subscribers = {}
  62. self._lock = threading.Lock()
  63. self._cur_id = 0
  64. # load the stream
  65. self.__post_init__()
  66. def __post_init__(self) -> None:
  67. try:
  68. events = self.file_store.list(get_conversation_events_dir(self.sid))
  69. except FileNotFoundError:
  70. logger.debug(f'No events found for session {self.sid}')
  71. self._cur_id = 0
  72. return
  73. # if we have events, we need to find the highest id to prepare for new events
  74. for event_str in events:
  75. id = self._get_id_from_filename(event_str)
  76. if id >= self._cur_id:
  77. self._cur_id = id + 1
  78. def _init_thread_loop(self):
  79. loop = asyncio.new_event_loop()
  80. asyncio.set_event_loop(loop)
  81. def _get_filename_for_id(self, id: int) -> str:
  82. return get_conversation_event_filename(self.sid, id)
  83. @staticmethod
  84. def _get_id_from_filename(filename: str) -> int:
  85. try:
  86. return int(filename.split('/')[-1].split('.')[0])
  87. except ValueError:
  88. logger.warning(f'get id from filename ({filename}) failed.')
  89. return -1
  90. def get_events(
  91. self,
  92. start_id: int = 0,
  93. end_id: int | None = None,
  94. reverse: bool = False,
  95. filter_out_type: tuple[type[Event], ...] | None = None,
  96. filter_hidden=False,
  97. ) -> Iterable[Event]:
  98. """
  99. Retrieve events from the event stream, optionally filtering out events of a given type
  100. and events marked as hidden.
  101. Args:
  102. start_id: The ID of the first event to retrieve. Defaults to 0.
  103. end_id: The ID of the last event to retrieve. Defaults to the last event in the stream.
  104. reverse: Whether to retrieve events in reverse order. Defaults to False.
  105. filter_out_type: A tuple of event types to filter out. Typically used to filter out backend events from the agent.
  106. filter_hidden: If True, filters out events with the 'hidden' attribute set to True.
  107. Yields:
  108. Events from the stream that match the criteria.
  109. """
  110. def should_filter(event: Event):
  111. if filter_hidden and hasattr(event, 'hidden') and event.hidden:
  112. return True
  113. if filter_out_type is not None and isinstance(event, filter_out_type):
  114. return True
  115. return False
  116. if reverse:
  117. if end_id is None:
  118. end_id = self._cur_id - 1
  119. event_id = end_id
  120. while event_id >= start_id:
  121. try:
  122. event = self.get_event(event_id)
  123. if not should_filter(event):
  124. yield event
  125. except FileNotFoundError:
  126. logger.debug(f'No event found for ID {event_id}')
  127. event_id -= 1
  128. else:
  129. event_id = start_id
  130. while should_continue():
  131. if end_id is not None and event_id > end_id:
  132. break
  133. try:
  134. event = self.get_event(event_id)
  135. if not should_filter(event):
  136. yield event
  137. except FileNotFoundError:
  138. break
  139. event_id += 1
  140. def get_event(self, id: int) -> Event:
  141. filename = self._get_filename_for_id(id)
  142. content = self.file_store.read(filename)
  143. data = json.loads(content)
  144. return event_from_dict(data)
  145. def get_latest_event(self) -> Event:
  146. return self.get_event(self._cur_id - 1)
  147. def get_latest_event_id(self) -> int:
  148. return self._cur_id - 1
  149. def subscribe(
  150. self, subscriber_id: EventStreamSubscriber, callback: Callable, callback_id: str
  151. ):
  152. pool = ThreadPoolExecutor(max_workers=1, initializer=self._init_thread_loop)
  153. if subscriber_id not in self._subscribers:
  154. self._subscribers[subscriber_id] = {}
  155. self._thread_pools[subscriber_id] = {}
  156. if callback_id in self._subscribers[subscriber_id]:
  157. raise ValueError(
  158. f'Callback ID on subscriber {subscriber_id} already exists: {callback_id}'
  159. )
  160. self._subscribers[subscriber_id][callback_id] = callback
  161. self._thread_pools[subscriber_id][callback_id] = pool
  162. def unsubscribe(self, subscriber_id: EventStreamSubscriber, callback_id: str):
  163. if subscriber_id not in self._subscribers:
  164. logger.warning(f'Subscriber not found during unsubscribe: {subscriber_id}')
  165. return
  166. if callback_id not in self._subscribers[subscriber_id]:
  167. logger.warning(f'Callback not found during unsubscribe: {callback_id}')
  168. return
  169. del self._subscribers[subscriber_id][callback_id]
  170. def add_event(self, event: Event, source: EventSource):
  171. if hasattr(event, '_id') and event.id is not None:
  172. raise ValueError(
  173. 'Event already has an ID. It was probably added back to the EventStream from inside a handler, trigging a loop.'
  174. )
  175. with self._lock:
  176. event._id = self._cur_id # type: ignore [attr-defined]
  177. self._cur_id += 1
  178. logger.debug(f'Adding {type(event).__name__} id={event.id} from {source.name}')
  179. event._timestamp = datetime.now().isoformat()
  180. event._source = source # type: ignore [attr-defined]
  181. data = event_to_dict(event)
  182. if event.id is not None:
  183. self.file_store.write(self._get_filename_for_id(event.id), json.dumps(data))
  184. self._queue.put(event)
  185. def _run_queue_loop(self):
  186. loop = asyncio.new_event_loop()
  187. asyncio.set_event_loop(loop)
  188. loop.run_until_complete(self._process_queue())
  189. async def _process_queue(self):
  190. while should_continue():
  191. event = self._queue.get()
  192. for key in sorted(self._subscribers.keys()):
  193. callbacks = self._subscribers[key]
  194. for callback_id in callbacks:
  195. callback = callbacks[callback_id]
  196. pool = self._thread_pools[key][callback_id]
  197. future = pool.submit(callback, event)
  198. future.add_done_callback(self._make_error_handler(callback_id, key))
  199. def _make_error_handler(self, callback_id: str, subscriber_id: str):
  200. def _handle_callback_error(fut):
  201. try:
  202. # This will raise any exception that occurred during callback execution
  203. fut.result()
  204. except Exception as e:
  205. logger.error(
  206. f'Error in event callback {callback_id} for subscriber {subscriber_id}: {str(e)}',
  207. exc_info=True,
  208. stack_info=True,
  209. )
  210. # Re-raise in the main thread so the error is not swallowed
  211. raise e
  212. return _handle_callback_error
  213. def filtered_events_by_source(self, source: EventSource):
  214. for event in self.get_events():
  215. if event.source == source:
  216. yield event
  217. def _should_filter_event(
  218. self,
  219. event,
  220. query: str | None = None,
  221. event_type: str | None = None,
  222. source: str | None = None,
  223. start_date: str | None = None,
  224. end_date: str | None = None,
  225. ) -> bool:
  226. """Check if an event should be filtered out based on the given criteria.
  227. Args:
  228. event: The event to check
  229. query (str, optional): Text to search for in event content
  230. event_type (str, optional): Filter by event type (e.g., "FileReadAction")
  231. source (str, optional): Filter by event source
  232. start_date (str, optional): Filter events after this date (ISO format)
  233. end_date (str, optional): Filter events before this date (ISO format)
  234. Returns:
  235. bool: True if the event should be filtered out, False if it matches all criteria
  236. """
  237. if event_type and not event.__class__.__name__ == event_type:
  238. return True
  239. if source and not event.source.value == source:
  240. return True
  241. if start_date and event.timestamp < start_date:
  242. return True
  243. if end_date and event.timestamp > end_date:
  244. return True
  245. # Text search in event content if query provided
  246. if query:
  247. event_dict = event_to_dict(event)
  248. event_str = str(event_dict).lower()
  249. if query.lower() not in event_str:
  250. return True
  251. return False
  252. def get_matching_events(
  253. self,
  254. query: str | None = None,
  255. event_type: str | None = None,
  256. source: str | None = None,
  257. start_date: str | None = None,
  258. end_date: str | None = None,
  259. start_id: int = 0,
  260. limit: int = 100,
  261. ) -> list:
  262. """Get matching events from the event stream based on filters.
  263. Args:
  264. query (str, optional): Text to search for in event content
  265. event_type (str, optional): Filter by event type (e.g., "FileReadAction")
  266. source (str, optional): Filter by event source
  267. start_date (str, optional): Filter events after this date (ISO format)
  268. end_date (str, optional): Filter events before this date (ISO format)
  269. start_id (int): Starting ID in the event stream. Defaults to 0
  270. limit (int): Maximum number of events to return. Must be between 1 and 100. Defaults to 100
  271. Returns:
  272. list: List of matching events (as dicts)
  273. Raises:
  274. ValueError: If limit is less than 1 or greater than 100
  275. """
  276. if limit < 1 or limit > 100:
  277. raise ValueError('Limit must be between 1 and 100')
  278. matching_events: list = []
  279. for event in self.get_events(start_id=start_id):
  280. if self._should_filter_event(
  281. event, query, event_type, source, start_date, end_date
  282. ):
  283. continue
  284. matching_events.append(event_to_dict(event))
  285. # Stop if we have enough events
  286. if len(matching_events) >= limit:
  287. break
  288. return matching_events