stream.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111
  1. import asyncio
  2. import json
  3. import threading
  4. from datetime import datetime
  5. from enum import Enum
  6. from typing import Callable, Iterable
  7. from opendevin.core.logger import opendevin_logger as logger
  8. from opendevin.events.serialization.event import event_from_dict, event_to_dict
  9. from opendevin.storage import FileStore, get_file_store
  10. from .event import Event, EventSource
  11. class EventStreamSubscriber(str, Enum):
  12. AGENT_CONTROLLER = 'agent_controller'
  13. SERVER = 'server'
  14. RUNTIME = 'runtime'
  15. MAIN = 'main'
  16. TEST = 'test'
  17. class EventStream:
  18. sid: str
  19. # For each subscriber ID, there is a stack of callback functions - useful
  20. # when there are agent delegates
  21. _subscribers: dict[str, list[Callable]]
  22. _cur_id: int
  23. _lock: threading.Lock
  24. _file_store: FileStore
  25. def __init__(self, sid: str):
  26. self.sid = sid
  27. self._file_store = get_file_store()
  28. self._subscribers = {}
  29. self._cur_id = 0
  30. self._lock = threading.Lock()
  31. self._reinitialize_from_file_store()
  32. def _reinitialize_from_file_store(self):
  33. try:
  34. events = self._file_store.list(f'sessions/{self.sid}/events')
  35. except FileNotFoundError:
  36. logger.warning(f'No events found for session {self.sid}')
  37. return
  38. for event_str in events:
  39. id = self._get_id_from_filename(event_str)
  40. if id >= self._cur_id:
  41. self._cur_id = id + 1
  42. def _get_filename_for_id(self, id: int) -> str:
  43. return f'sessions/{self.sid}/events/{id}.json'
  44. @staticmethod
  45. def _get_id_from_filename(filename: str) -> int:
  46. try:
  47. return int(filename.split('/')[-1].split('.')[0])
  48. except ValueError:
  49. logger.warning(f'get id from filename ({filename}) failed.')
  50. return -1
  51. def get_events(self, start_id=0, end_id=None) -> Iterable[Event]:
  52. event_id = start_id
  53. while True:
  54. if end_id is not None and event_id > end_id:
  55. break
  56. try:
  57. event = self.get_event(event_id)
  58. except FileNotFoundError:
  59. break
  60. yield event
  61. event_id += 1
  62. def get_event(self, id: int) -> Event:
  63. filename = self._get_filename_for_id(id)
  64. content = self._file_store.read(filename)
  65. data = json.loads(content)
  66. return event_from_dict(data)
  67. def subscribe(self, id: EventStreamSubscriber, callback: Callable, append=False):
  68. if id in self._subscribers:
  69. if append:
  70. self._subscribers[id].append(callback)
  71. else:
  72. raise ValueError('Subscriber already exists: ' + id)
  73. else:
  74. self._subscribers[id] = [callback]
  75. def unsubscribe(self, id: EventStreamSubscriber):
  76. if id not in self._subscribers:
  77. logger.warning('Subscriber not found during unsubscribe: ' + id)
  78. else:
  79. self._subscribers[id].pop()
  80. if len(self._subscribers[id]) == 0:
  81. del self._subscribers[id]
  82. def add_event(self, event: Event, source: EventSource):
  83. with self._lock:
  84. event._id = self._cur_id # type: ignore [attr-defined]
  85. self._cur_id += 1
  86. logger.debug(f'Adding {type(event).__name__} id={event.id} from {source.name}')
  87. event._timestamp = datetime.now() # type: ignore[attr-defined]
  88. event._source = source # type: ignore[attr-defined]
  89. data = event_to_dict(event)
  90. if event.id is not None:
  91. self._file_store.write(
  92. self._get_filename_for_id(event.id), json.dumps(data)
  93. )
  94. for stack in self._subscribers.values():
  95. callback = stack[-1]
  96. asyncio.create_task(callback(event))