stream.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176
  1. import asyncio
  2. import threading
  3. from datetime import datetime
  4. from enum import Enum
  5. from typing import Callable, Iterable
  6. from openhands.core.logger import openhands_logger as logger
  7. from openhands.core.utils import json
  8. from openhands.events.event import Event, EventSource
  9. from openhands.events.serialization.event import event_from_dict, event_to_dict
  10. from openhands.runtime.utils.shutdown_listener import should_continue
  11. from openhands.storage import FileStore
  12. class EventStreamSubscriber(str, Enum):
  13. AGENT_CONTROLLER = 'agent_controller'
  14. SECURITY_ANALYZER = 'security_analyzer'
  15. SERVER = 'server'
  16. RUNTIME = 'runtime'
  17. MAIN = 'main'
  18. TEST = 'test'
  19. def session_exists(sid: str, file_store: FileStore) -> bool:
  20. try:
  21. file_store.list(f'sessions/{sid}')
  22. return True
  23. except FileNotFoundError:
  24. return False
  25. class EventStream:
  26. sid: str
  27. file_store: FileStore
  28. # For each subscriber ID, there is a stack of callback functions - useful
  29. # when there are agent delegates
  30. _subscribers: dict[str, list[Callable]]
  31. _cur_id: int
  32. _lock: threading.Lock
  33. def __init__(self, sid: str, file_store: FileStore):
  34. self.sid = sid
  35. self.file_store = file_store
  36. self._subscribers = {}
  37. self._cur_id = 0
  38. self._lock = threading.Lock()
  39. self._reinitialize_from_file_store()
  40. def _reinitialize_from_file_store(self) -> None:
  41. try:
  42. events = self.file_store.list(f'sessions/{self.sid}/events')
  43. except FileNotFoundError:
  44. logger.debug(f'No events found for session {self.sid}')
  45. self._cur_id = 0
  46. return
  47. # if we have events, we need to find the highest id to prepare for new events
  48. for event_str in events:
  49. id = self._get_id_from_filename(event_str)
  50. if id >= self._cur_id:
  51. self._cur_id = id + 1
  52. def _get_filename_for_id(self, id: int) -> str:
  53. return f'sessions/{self.sid}/events/{id}.json'
  54. @staticmethod
  55. def _get_id_from_filename(filename: str) -> int:
  56. try:
  57. return int(filename.split('/')[-1].split('.')[0])
  58. except ValueError:
  59. logger.warning(f'get id from filename ({filename}) failed.')
  60. return -1
  61. def get_events(
  62. self,
  63. start_id=0,
  64. end_id=None,
  65. reverse=False,
  66. filter_out_type: tuple[type[Event], ...] | None = None,
  67. ) -> Iterable[Event]:
  68. if reverse:
  69. if end_id is None:
  70. end_id = self._cur_id - 1
  71. event_id = end_id
  72. while event_id >= start_id:
  73. try:
  74. event = self.get_event(event_id)
  75. if filter_out_type is None or not isinstance(
  76. event, filter_out_type
  77. ):
  78. yield event
  79. except FileNotFoundError:
  80. logger.debug(f'No event found for ID {event_id}')
  81. event_id -= 1
  82. else:
  83. event_id = start_id
  84. while should_continue():
  85. if end_id is not None and event_id > end_id:
  86. break
  87. try:
  88. event = self.get_event(event_id)
  89. if filter_out_type is None or not isinstance(
  90. event, filter_out_type
  91. ):
  92. yield event
  93. except FileNotFoundError:
  94. break
  95. event_id += 1
  96. def get_event(self, id: int) -> Event:
  97. filename = self._get_filename_for_id(id)
  98. content = self.file_store.read(filename)
  99. data = json.loads(content)
  100. return event_from_dict(data)
  101. def get_latest_event(self) -> Event:
  102. return self.get_event(self._cur_id - 1)
  103. def get_latest_event_id(self) -> int:
  104. return self._cur_id - 1
  105. def subscribe(self, id: EventStreamSubscriber, callback: Callable, append=False):
  106. if id in self._subscribers:
  107. if append:
  108. self._subscribers[id].append(callback)
  109. else:
  110. raise ValueError('Subscriber already exists: ' + id)
  111. else:
  112. self._subscribers[id] = [callback]
  113. def unsubscribe(self, id: EventStreamSubscriber):
  114. if id not in self._subscribers:
  115. logger.warning('Subscriber not found during unsubscribe: ' + id)
  116. else:
  117. self._subscribers[id].pop()
  118. if len(self._subscribers[id]) == 0:
  119. del self._subscribers[id]
  120. def add_event(self, event: Event, source: EventSource):
  121. try:
  122. asyncio.get_running_loop().create_task(self.async_add_event(event, source))
  123. except RuntimeError:
  124. # No event loop running...
  125. asyncio.run(self.async_add_event(event, source))
  126. async def async_add_event(self, event: Event, source: EventSource):
  127. with self._lock:
  128. event._id = self._cur_id # type: ignore [attr-defined]
  129. self._cur_id += 1
  130. logger.debug(f'Adding {type(event).__name__} id={event.id} from {source.name}')
  131. event._timestamp = datetime.now().isoformat()
  132. event._source = source # type: ignore [attr-defined]
  133. data = event_to_dict(event)
  134. if event.id is not None:
  135. self.file_store.write(self._get_filename_for_id(event.id), json.dumps(data))
  136. tasks = []
  137. for key in sorted(self._subscribers.keys()):
  138. stack = self._subscribers[key]
  139. callback = stack[-1]
  140. tasks.append(asyncio.create_task(callback(event)))
  141. if tasks:
  142. await asyncio.wait(tasks)
  143. def _callback(self, callback: Callable, event: Event):
  144. asyncio.run(callback(event))
  145. def filtered_events_by_source(self, source: EventSource):
  146. for event in self.get_events():
  147. if event.source == source:
  148. yield event
  149. def clear(self):
  150. self.file_store.delete(f'sessions/{self.sid}')
  151. self._cur_id = 0
  152. # self._subscribers = {}
  153. self._reinitialize_from_file_store()