stream.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174
  1. import asyncio
  2. import threading
  3. from dataclasses import dataclass, field
  4. from datetime import datetime
  5. from enum import Enum
  6. from typing import Callable, Iterable
  7. from openhands.core.logger import openhands_logger as logger
  8. from openhands.core.utils import json
  9. from openhands.events.event import Event, EventSource
  10. from openhands.events.serialization.event import event_from_dict, event_to_dict
  11. from openhands.runtime.utils.shutdown_listener import should_continue
  12. from openhands.storage import FileStore
  class EventStreamSubscriber(str, Enum):
      """Identifiers under which callbacks are registered on an event stream.

      Members are also ``str`` instances, so they compare equal to (and can be
      concatenated with) their plain string values.
      """

      AGENT_CONTROLLER = 'agent_controller'
      SECURITY_ANALYZER = 'security_analyzer'
      SERVER = 'server'
      RUNTIME = 'runtime'
      MAIN = 'main'
      TEST = 'test'
  def session_exists(sid: str, file_store: FileStore) -> bool:
      """Report whether ``file_store`` can list the directory of session ``sid``.

      Args:
          sid: Session identifier; the path probed is ``sessions/<sid>``.
          file_store: Store whose ``list`` method is called with that path.

      Returns:
          ``False`` when the listing raises ``FileNotFoundError``, ``True`` when
          it completes. Any other exception propagates to the caller.
      """
      session_path = f'sessions/{sid}'
      try:
          file_store.list(session_path)
      except FileNotFoundError:
          return False
      else:
          return True
  @dataclass
  class EventStream:
      """File-backed, ordered stream of events for a single session.

      Each event is stored as ``sessions/<sid>/events/<id>.json`` in
      ``file_store``. Ids are assigned sequentially from ``_cur_id``, which is
      derived on construction from the event files that already exist.
      """

      sid: str
      file_store: FileStore
      # For each subscriber ID, there is a stack of callback functions - useful
      # when there are agent delegates
      _subscribers: dict[str, list[Callable]] = field(default_factory=dict)
      _cur_id: int = 0
      _lock: threading.Lock = field(default_factory=threading.Lock)
      # Strong references to the fire-and-forget tasks started by add_event().
      # The event loop only keeps weak references to tasks, so without this a
      # pending task could be garbage-collected before it finishes.
      _background_tasks: set[asyncio.Task] = field(
          default_factory=set, repr=False, compare=False
      )

      def __post_init__(self) -> None:
          """Set ``_cur_id`` to one past the highest event id already stored."""
          try:
              events = self.file_store.list(f'sessions/{self.sid}/events')
          except FileNotFoundError:
              logger.debug(f'No events found for session {self.sid}')
              self._cur_id = 0
              return

          # if we have events, we need to find the highest id to prepare for new events
          for event_str in events:
              # Unparseable names come back as -1 and therefore never raise _cur_id.
              event_id = self._get_id_from_filename(event_str)
              if event_id >= self._cur_id:
                  self._cur_id = event_id + 1

      def _get_filename_for_id(self, id: int) -> str:
          """Return the storage path of the event with the given id."""
          return f'sessions/{self.sid}/events/{id}.json'

      @staticmethod
      def _get_id_from_filename(filename: str) -> int:
          """Parse the numeric id out of an event path such as ``.../12.json``.

          Returns -1 (after logging a warning) when the last path component does
          not start with an integer.
          """
          try:
              return int(filename.split('/')[-1].split('.')[0])
          except ValueError:
              logger.warning(f'get id from filename ({filename}) failed.')
              return -1

      def get_events(
          self,
          start_id=0,
          end_id=None,
          reverse=False,
          filter_out_type: tuple[type[Event], ...] | None = None,
          filter_hidden=False,
      ) -> Iterable[Event]:
          """Lazily yield stored events with ids from ``start_id`` to ``end_id``.

          Args:
              start_id: First id to consider (inclusive).
              end_id: Last id to consider (inclusive). ``None`` means "no upper
                  bound" going forward and "the latest id" going in reverse.
              reverse: Iterate from ``end_id`` down to ``start_id`` when true.
              filter_out_type: Event classes to skip, if any.
              filter_hidden: Skip events whose ``hidden`` attribute is truthy.

          Forward iteration stops at the first missing event file or as soon as
          ``should_continue()`` returns false. Reverse iteration logs and skips
          missing ids instead of stopping.
          """

          def should_filter(event: Event):
              if filter_hidden and getattr(event, 'hidden', False):
                  return True
              if filter_out_type is not None and isinstance(event, filter_out_type):
                  return True
              return False

          if reverse:
              if end_id is None:
                  end_id = self._cur_id - 1
              event_id = end_id
              while event_id >= start_id:
                  try:
                      event = self.get_event(event_id)
                      if not should_filter(event):
                          yield event
                  except FileNotFoundError:
                      logger.debug(f'No event found for ID {event_id}')
                  event_id -= 1
          else:
              event_id = start_id
              while should_continue():
                  if end_id is not None and event_id > end_id:
                      break
                  try:
                      event = self.get_event(event_id)
                      if not should_filter(event):
                          yield event
                  except FileNotFoundError:
                      break
                  event_id += 1

      def get_event(self, id: int) -> Event:
          """Read, parse and deserialize the stored event with the given id.

          Exceptions from ``file_store.read`` (``get_events`` expects
          ``FileNotFoundError`` for a missing id) propagate to the caller.
          """
          filename = self._get_filename_for_id(id)
          content = self.file_store.read(filename)
          data = json.loads(content)
          return event_from_dict(data)

      def get_latest_event(self) -> Event:
          """Return the event with id ``_cur_id - 1`` (id -1 if the stream is empty)."""
          return self.get_event(self._cur_id - 1)

      def get_latest_event_id(self) -> int:
          """Return the most recently assigned id, or -1 if none was assigned."""
          return self._cur_id - 1

      def subscribe(self, id: EventStreamSubscriber, callback: Callable, append=False):
          """Register ``callback`` under subscriber ``id``.

          If ``id`` already has callbacks, the new one is pushed on top of its
          stack only when ``append`` is true.

          Raises:
              ValueError: ``id`` is already subscribed and ``append`` is false.
          """
          if id in self._subscribers:
              if append:
                  self._subscribers[id].append(callback)
              else:
                  raise ValueError('Subscriber already exists: ' + id)
          else:
              self._subscribers[id] = [callback]

      def unsubscribe(self, id: EventStreamSubscriber):
          """Pop the top callback of ``id``; drop ``id`` once its stack is empty.

          An unknown ``id`` is only logged as a warning.
          """
          if id not in self._subscribers:
              logger.warning('Subscriber not found during unsubscribe: ' + id)
          else:
              self._subscribers[id].pop()
              if len(self._subscribers[id]) == 0:
                  del self._subscribers[id]

      def add_event(self, event: Event, source: EventSource):
          """Add ``event`` to the stream from synchronous code.

          With an event loop running in the current thread, the work is
          scheduled on it as a background task and this method returns
          immediately. Otherwise the coroutine is run to completion with
          ``asyncio.run``.
          """
          try:
              loop = asyncio.get_running_loop()
          except RuntimeError:
              # No event loop running...
              asyncio.run(self.async_add_event(event, source))
              return

          task = loop.create_task(self.async_add_event(event, source))
          # Hold a strong reference until the task completes; see the comment
          # on _background_tasks.
          self._background_tasks.add(task)
          task.add_done_callback(self._background_tasks.discard)

      async def async_add_event(self, event: Event, source: EventSource):
          """Assign an id to ``event``, persist it and notify subscribers.

          Only the id assignment happens under ``_lock``. After the event is
          written to ``file_store``, the top callback of every subscriber stack
          is started as a task (in sorted subscriber-key order) and all of them
          are awaited together.
          """
          with self._lock:
              event._id = self._cur_id  # type: ignore [attr-defined]
              self._cur_id += 1
          logger.debug(f'Adding {type(event).__name__} id={event.id} from {source.name}')
          event._timestamp = datetime.now().isoformat()
          event._source = source  # type: ignore [attr-defined]
          data = event_to_dict(event)
          if event.id is not None:
              self.file_store.write(self._get_filename_for_id(event.id), json.dumps(data))
          tasks = []
          for key in sorted(self._subscribers):
              stack = self._subscribers[key]
              callback = stack[-1]
              # callback(event) must return a coroutine for create_task.
              tasks.append(asyncio.create_task(callback(event)))
          if tasks:
              await asyncio.wait(tasks)

      def _callback(self, callback: Callable, event: Event):
          """Run ``callback(event)`` to completion in a fresh event loop."""
          asyncio.run(callback(event))

      def filtered_events_by_source(self, source: EventSource):
          """Yield the events from ``get_events()`` whose ``source`` equals ``source``."""
          for event in self.get_events():
              if event.source == source:
                  yield event

      def clear(self):
          """Delete the session's stored data and recompute the next event id.

          Subscribers are intentionally left registered.
          """
          self.file_store.delete(f'sessions/{self.sid}')
          self._cur_id = 0
          # self._subscribers = {}
          self.__post_init__()