# stream.py (6.6 KB)

  1. import asyncio
  2. import threading
  3. from dataclasses import dataclass, field
  4. from datetime import datetime
  5. from enum import Enum
  6. from typing import Callable, Iterable
  7. from openhands.core.logger import openhands_logger as logger
  8. from openhands.core.utils import json
  9. from openhands.events.event import Event, EventSource
  10. from openhands.events.serialization.event import event_from_dict, event_to_dict
  11. from openhands.runtime.utils.shutdown_listener import should_continue
  12. from openhands.storage import FileStore
  13. from openhands.utils.async_utils import call_sync_from_async
  14. class EventStreamSubscriber(str, Enum):
  15. AGENT_CONTROLLER = 'agent_controller'
  16. SECURITY_ANALYZER = 'security_analyzer'
  17. SERVER = 'server'
  18. RUNTIME = 'runtime'
  19. MAIN = 'main'
  20. TEST = 'test'
  21. async def session_exists(sid: str, file_store: FileStore) -> bool:
  22. try:
  23. await call_sync_from_async(file_store.list, f'sessions/{sid}')
  24. return True
  25. except FileNotFoundError:
  26. return False
  27. class AsyncEventStreamWrapper:
  28. def __init__(self, event_stream, *args, **kwargs):
  29. self.event_stream = event_stream
  30. self.args = args
  31. self.kwargs = kwargs
  32. async def __aiter__(self):
  33. loop = asyncio.get_running_loop()
  34. # Create an async generator that yields events
  35. for event in self.event_stream.get_events(*self.args, **self.kwargs):
  36. # Run the blocking get_events() in a thread pool
  37. yield await loop.run_in_executor(None, lambda e=event: e) # type: ignore
  38. @dataclass
  39. class EventStream:
  40. sid: str
  41. file_store: FileStore
  42. # For each subscriber ID, there is a stack of callback functions - useful
  43. # when there are agent delegates
  44. _subscribers: dict[str, list[Callable]] = field(default_factory=dict)
  45. _cur_id: int = 0
  46. _lock: threading.Lock = field(default_factory=threading.Lock)
  47. def __post_init__(self) -> None:
  48. try:
  49. events = self.file_store.list(f'sessions/{self.sid}/events')
  50. except FileNotFoundError:
  51. logger.debug(f'No events found for session {self.sid}')
  52. self._cur_id = 0
  53. return
  54. # if we have events, we need to find the highest id to prepare for new events
  55. for event_str in events:
  56. id = self._get_id_from_filename(event_str)
  57. if id >= self._cur_id:
  58. self._cur_id = id + 1
  59. def _get_filename_for_id(self, id: int) -> str:
  60. return f'sessions/{self.sid}/events/{id}.json'
  61. @staticmethod
  62. def _get_id_from_filename(filename: str) -> int:
  63. try:
  64. return int(filename.split('/')[-1].split('.')[0])
  65. except ValueError:
  66. logger.warning(f'get id from filename ({filename}) failed.')
  67. return -1
  68. def get_events(
  69. self,
  70. start_id=0,
  71. end_id=None,
  72. reverse=False,
  73. filter_out_type: tuple[type[Event], ...] | None = None,
  74. filter_hidden=False,
  75. ) -> Iterable[Event]:
  76. def should_filter(event: Event):
  77. if filter_hidden and hasattr(event, 'hidden') and event.hidden:
  78. return True
  79. if filter_out_type is not None and isinstance(event, filter_out_type):
  80. return True
  81. return False
  82. if reverse:
  83. if end_id is None:
  84. end_id = self._cur_id - 1
  85. event_id = end_id
  86. while event_id >= start_id:
  87. try:
  88. event = self.get_event(event_id)
  89. if not should_filter(event):
  90. yield event
  91. except FileNotFoundError:
  92. logger.debug(f'No event found for ID {event_id}')
  93. event_id -= 1
  94. else:
  95. event_id = start_id
  96. while should_continue():
  97. if end_id is not None and event_id > end_id:
  98. break
  99. try:
  100. event = self.get_event(event_id)
  101. if not should_filter(event):
  102. yield event
  103. except FileNotFoundError:
  104. break
  105. event_id += 1
  106. def get_event(self, id: int) -> Event:
  107. filename = self._get_filename_for_id(id)
  108. content = self.file_store.read(filename)
  109. data = json.loads(content)
  110. return event_from_dict(data)
  111. def get_latest_event(self) -> Event:
  112. return self.get_event(self._cur_id - 1)
  113. def get_latest_event_id(self) -> int:
  114. return self._cur_id - 1
  115. def subscribe(self, id: EventStreamSubscriber, callback: Callable, append=False):
  116. if id in self._subscribers:
  117. if append:
  118. self._subscribers[id].append(callback)
  119. else:
  120. raise ValueError('Subscriber already exists: ' + id)
  121. else:
  122. self._subscribers[id] = [callback]
  123. def unsubscribe(self, id: EventStreamSubscriber):
  124. if id not in self._subscribers:
  125. logger.warning('Subscriber not found during unsubscribe: ' + id)
  126. else:
  127. self._subscribers[id].pop()
  128. if len(self._subscribers[id]) == 0:
  129. del self._subscribers[id]
  130. def add_event(self, event: Event, source: EventSource):
  131. try:
  132. asyncio.get_running_loop().create_task(self.async_add_event(event, source))
  133. except RuntimeError:
  134. # No event loop running...
  135. asyncio.run(self.async_add_event(event, source))
  136. async def async_add_event(self, event: Event, source: EventSource):
  137. with self._lock:
  138. event._id = self._cur_id # type: ignore [attr-defined]
  139. self._cur_id += 1
  140. logger.debug(f'Adding {type(event).__name__} id={event.id} from {source.name}')
  141. event._timestamp = datetime.now().isoformat()
  142. event._source = source # type: ignore [attr-defined]
  143. data = event_to_dict(event)
  144. if event.id is not None:
  145. self.file_store.write(self._get_filename_for_id(event.id), json.dumps(data))
  146. tasks = []
  147. for key in sorted(self._subscribers.keys()):
  148. stack = self._subscribers[key]
  149. callback = stack[-1]
  150. tasks.append(asyncio.create_task(callback(event)))
  151. if tasks:
  152. await asyncio.wait(tasks)
  153. def _callback(self, callback: Callable, event: Event):
  154. asyncio.run(callback(event))
  155. def filtered_events_by_source(self, source: EventSource):
  156. for event in self.get_events():
  157. if event.source == source:
  158. yield event
  159. def clear(self):
  160. self.file_store.delete(f'sessions/{self.sid}')
  161. self._cur_id = 0
  162. # self._subscribers = {}
  163. self.__post_init__()