stream.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218
  1. import asyncio
  2. import threading
  3. from dataclasses import dataclass, field
  4. from datetime import datetime
  5. from enum import Enum
  6. from typing import Callable, Iterable
  7. from openhands.core.logger import openhands_logger as logger
  8. from openhands.core.utils import json
  9. from openhands.events.event import Event, EventSource
  10. from openhands.events.serialization.event import event_from_dict, event_to_dict
  11. from openhands.storage import FileStore
  12. from openhands.utils.async_utils import call_sync_from_async
  13. from openhands.utils.shutdown_listener import should_continue
  14. class EventStreamSubscriber(str, Enum):
  15. AGENT_CONTROLLER = 'agent_controller'
  16. SECURITY_ANALYZER = 'security_analyzer'
  17. RESOLVER = 'openhands_resolver'
  18. SERVER = 'server'
  19. RUNTIME = 'runtime'
  20. MAIN = 'main'
  21. TEST = 'test'
  22. async def session_exists(sid: str, file_store: FileStore) -> bool:
  23. try:
  24. await call_sync_from_async(file_store.list, f'sessions/{sid}')
  25. return True
  26. except FileNotFoundError:
  27. return False
  28. class AsyncEventStreamWrapper:
  29. def __init__(self, event_stream, *args, **kwargs):
  30. self.event_stream = event_stream
  31. self.args = args
  32. self.kwargs = kwargs
  33. async def __aiter__(self):
  34. loop = asyncio.get_running_loop()
  35. # Create an async generator that yields events
  36. for event in self.event_stream.get_events(*self.args, **self.kwargs):
  37. # Run the blocking get_events() in a thread pool
  38. yield await loop.run_in_executor(None, lambda e=event: e) # type: ignore
  39. @dataclass
  40. class EventStream:
  41. sid: str
  42. file_store: FileStore
  43. # For each subscriber ID, there is a map of callback functions - useful
  44. # when there are multiple listeners
  45. _subscribers: dict[str, dict[str, Callable]] = field(default_factory=dict)
  46. _cur_id: int = 0
  47. _lock: threading.Lock = field(default_factory=threading.Lock)
  48. def __post_init__(self) -> None:
  49. try:
  50. events = self.file_store.list(f'sessions/{self.sid}/events')
  51. except FileNotFoundError:
  52. logger.debug(f'No events found for session {self.sid}')
  53. self._cur_id = 0
  54. return
  55. # if we have events, we need to find the highest id to prepare for new events
  56. for event_str in events:
  57. id = self._get_id_from_filename(event_str)
  58. if id >= self._cur_id:
  59. self._cur_id = id + 1
  60. def _get_filename_for_id(self, id: int) -> str:
  61. return f'sessions/{self.sid}/events/{id}.json'
  62. @staticmethod
  63. def _get_id_from_filename(filename: str) -> int:
  64. try:
  65. return int(filename.split('/')[-1].split('.')[0])
  66. except ValueError:
  67. logger.warning(f'get id from filename ({filename}) failed.')
  68. return -1
  69. def get_events(
  70. self,
  71. start_id: int = 0,
  72. end_id: int | None = None,
  73. reverse: bool = False,
  74. filter_out_type: tuple[type[Event], ...] | None = None,
  75. filter_hidden=False,
  76. ) -> Iterable[Event]:
  77. """
  78. Retrieve events from the event stream, optionally filtering out events of a given type
  79. and events marked as hidden.
  80. Args:
  81. start_id: The ID of the first event to retrieve. Defaults to 0.
  82. end_id: The ID of the last event to retrieve. Defaults to the last event in the stream.
  83. reverse: Whether to retrieve events in reverse order. Defaults to False.
  84. filter_out_type: A tuple of event types to filter out. Typically used to filter out backend events from the agent.
  85. filter_hidden: If True, filters out events with the 'hidden' attribute set to True.
  86. Yields:
  87. Events from the stream that match the criteria.
  88. """
  89. def should_filter(event: Event):
  90. if filter_hidden and hasattr(event, 'hidden') and event.hidden:
  91. return True
  92. if filter_out_type is not None and isinstance(event, filter_out_type):
  93. return True
  94. return False
  95. if reverse:
  96. if end_id is None:
  97. end_id = self._cur_id - 1
  98. event_id = end_id
  99. while event_id >= start_id:
  100. try:
  101. event = self.get_event(event_id)
  102. if not should_filter(event):
  103. yield event
  104. except FileNotFoundError:
  105. logger.debug(f'No event found for ID {event_id}')
  106. event_id -= 1
  107. else:
  108. event_id = start_id
  109. while should_continue():
  110. if end_id is not None and event_id > end_id:
  111. break
  112. try:
  113. event = self.get_event(event_id)
  114. if not should_filter(event):
  115. yield event
  116. except FileNotFoundError:
  117. break
  118. event_id += 1
  119. def get_event(self, id: int) -> Event:
  120. filename = self._get_filename_for_id(id)
  121. content = self.file_store.read(filename)
  122. data = json.loads(content)
  123. return event_from_dict(data)
  124. def get_latest_event(self) -> Event:
  125. return self.get_event(self._cur_id - 1)
  126. def get_latest_event_id(self) -> int:
  127. return self._cur_id - 1
  128. def subscribe(
  129. self, subscriber_id: EventStreamSubscriber, callback: Callable, callback_id: str
  130. ):
  131. if subscriber_id not in self._subscribers:
  132. self._subscribers[subscriber_id] = {}
  133. if callback_id in self._subscribers[subscriber_id]:
  134. raise ValueError(
  135. f'Callback ID on subscriber {subscriber_id} already exists: {callback_id}'
  136. )
  137. self._subscribers[subscriber_id][callback_id] = callback
  138. def unsubscribe(self, subscriber_id: EventStreamSubscriber, callback_id: str):
  139. if subscriber_id not in self._subscribers:
  140. logger.warning(f'Subscriber not found during unsubscribe: {subscriber_id}')
  141. return
  142. if callback_id not in self._subscribers[subscriber_id]:
  143. logger.warning(f'Callback not found during unsubscribe: {callback_id}')
  144. return
  145. del self._subscribers[subscriber_id][callback_id]
  146. def add_event(self, event: Event, source: EventSource):
  147. try:
  148. asyncio.get_running_loop().create_task(self._async_add_event(event, source))
  149. except RuntimeError:
  150. # No event loop running...
  151. asyncio.run(self._async_add_event(event, source))
  152. async def _async_add_event(self, event: Event, source: EventSource):
  153. if hasattr(event, '_id') and event.id is not None:
  154. raise ValueError(
  155. 'Event already has an ID. It was probably added back to the EventStream from inside a handler, trigging a loop.'
  156. )
  157. with self._lock:
  158. event._id = self._cur_id # type: ignore [attr-defined]
  159. self._cur_id += 1
  160. logger.debug(f'Adding {type(event).__name__} id={event.id} from {source.name}')
  161. event._timestamp = datetime.now().isoformat()
  162. event._source = source # type: ignore [attr-defined]
  163. data = event_to_dict(event)
  164. if event.id is not None:
  165. self.file_store.write(self._get_filename_for_id(event.id), json.dumps(data))
  166. tasks = []
  167. for key in sorted(self._subscribers.keys()):
  168. callbacks = self._subscribers[key]
  169. for callback_id in callbacks:
  170. callback = callbacks[callback_id]
  171. tasks.append(asyncio.create_task(callback(event)))
  172. if tasks:
  173. await asyncio.wait(tasks)
  174. def _callback(self, callback: Callable, event: Event):
  175. asyncio.run(callback(event))
  176. def filtered_events_by_source(self, source: EventSource):
  177. for event in self.get_events():
  178. if event.source == source:
  179. yield event
  180. def clear(self):
  181. self.file_store.delete(f'sessions/{self.sid}')
  182. self._cur_id = 0
  183. # self._subscribers = {}
  184. self.__post_init__()