stream.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. import asyncio
  2. import threading
  3. from datetime import datetime
  4. from enum import Enum
  5. from typing import Callable, Iterable
  6. from opendevin.core.logger import opendevin_logger as logger
  7. from opendevin.core.utils import json
  8. from opendevin.events.serialization.event import event_from_dict, event_to_dict
  9. from opendevin.storage import FileStore, get_file_store
  10. from .event import Event, EventSource
  11. class EventStreamSubscriber(str, Enum):
  12. AGENT_CONTROLLER = 'agent_controller'
  13. SERVER = 'server'
  14. RUNTIME = 'runtime'
  15. MAIN = 'main'
  16. TEST = 'test'
  17. class EventStream:
  18. sid: str
  19. # For each subscriber ID, there is a stack of callback functions - useful
  20. # when there are agent delegates
  21. _subscribers: dict[str, list[Callable]]
  22. _cur_id: int
  23. _lock: threading.Lock
  24. _file_store: FileStore
  25. def __init__(self, sid: str):
  26. self.sid = sid
  27. self._file_store = get_file_store()
  28. self._subscribers = {}
  29. self._cur_id = 0
  30. self._lock = threading.Lock()
  31. self._reinitialize_from_file_store()
  32. def _reinitialize_from_file_store(self):
  33. try:
  34. events = self._file_store.list(f'sessions/{self.sid}/events')
  35. except FileNotFoundError:
  36. logger.debug(f'No events found for session {self.sid}')
  37. self._cur_id = 0
  38. return
  39. # if we have events, we need to find the highest id to prepare for new events
  40. for event_str in events:
  41. id = self._get_id_from_filename(event_str)
  42. if id >= self._cur_id:
  43. self._cur_id = id + 1
  44. def _get_filename_for_id(self, id: int) -> str:
  45. return f'sessions/{self.sid}/events/{id}.json'
  46. @staticmethod
  47. def _get_id_from_filename(filename: str) -> int:
  48. try:
  49. return int(filename.split('/')[-1].split('.')[0])
  50. except ValueError:
  51. logger.warning(f'get id from filename ({filename}) failed.')
  52. return -1
  53. def get_events(
  54. self,
  55. start_id=0,
  56. end_id=None,
  57. reverse=False,
  58. filter_out_type: tuple[type[Event], ...] | None = None,
  59. ) -> Iterable[Event]:
  60. if reverse:
  61. if end_id is None:
  62. end_id = self._cur_id - 1
  63. event_id = end_id
  64. while event_id >= start_id:
  65. try:
  66. event = self.get_event(event_id)
  67. if filter_out_type is None or not isinstance(
  68. event, filter_out_type
  69. ):
  70. yield event
  71. except FileNotFoundError:
  72. logger.debug(f'No event found for ID {event_id}')
  73. event_id -= 1
  74. else:
  75. event_id = start_id
  76. while True:
  77. if end_id is not None and event_id > end_id:
  78. break
  79. try:
  80. event = self.get_event(event_id)
  81. if filter_out_type is None or not isinstance(
  82. event, filter_out_type
  83. ):
  84. yield event
  85. except FileNotFoundError:
  86. break
  87. event_id += 1
  88. def get_event(self, id: int) -> Event:
  89. filename = self._get_filename_for_id(id)
  90. content = self._file_store.read(filename)
  91. data = json.loads(content)
  92. return event_from_dict(data)
  93. def get_latest_event(self) -> Event:
  94. return self.get_event(self._cur_id - 1)
  95. def get_latest_event_id(self) -> int:
  96. return self._cur_id - 1
  97. def subscribe(self, id: EventStreamSubscriber, callback: Callable, append=False):
  98. if id in self._subscribers:
  99. if append:
  100. self._subscribers[id].append(callback)
  101. else:
  102. raise ValueError('Subscriber already exists: ' + id)
  103. else:
  104. self._subscribers[id] = [callback]
  105. def unsubscribe(self, id: EventStreamSubscriber):
  106. if id not in self._subscribers:
  107. logger.warning('Subscriber not found during unsubscribe: ' + id)
  108. else:
  109. self._subscribers[id].pop()
  110. if len(self._subscribers[id]) == 0:
  111. del self._subscribers[id]
  112. def add_event(self, event: Event, source: EventSource):
  113. with self._lock:
  114. event._id = self._cur_id # type: ignore [attr-defined]
  115. self._cur_id += 1
  116. logger.debug(f'Adding {type(event).__name__} id={event.id} from {source.name}')
  117. event._timestamp = datetime.now() # type: ignore [attr-defined]
  118. event._source = source # type: ignore [attr-defined]
  119. data = event_to_dict(event)
  120. if event.id is not None:
  121. self._file_store.write(
  122. self._get_filename_for_id(event.id), json.dumps(data)
  123. )
  124. for stack in self._subscribers.values():
  125. callback = stack[-1]
  126. asyncio.create_task(callback(event))
  127. def filtered_events_by_source(self, source: EventSource):
  128. for event in self.get_events():
  129. if event.source == source:
  130. yield event
  131. def clear(self):
  132. self._file_store.delete(f'sessions/{self.sid}')
  133. self._cur_id = 0
  134. # self._subscribers = {}
  135. self._reinitialize_from_file_store()