stream.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. import asyncio
  2. import json
  3. from datetime import datetime
  4. from enum import Enum
  5. from typing import Callable, Iterable
  6. from opendevin.core.logger import opendevin_logger as logger
  7. from opendevin.events.serialization.event import event_from_dict, event_to_dict
  8. from opendevin.storage import FileStore, get_file_store
  9. from .event import Event, EventSource
  10. class EventStreamSubscriber(str, Enum):
  11. AGENT_CONTROLLER = 'agent_controller'
  12. SERVER = 'server'
  13. RUNTIME = 'runtime'
  14. MAIN = 'main'
  15. TEST = 'test'
  16. class EventStream:
  17. sid: str
  18. # For each subscriber ID, there is a stack of callback functions - useful
  19. # when there are agent delegates
  20. _subscribers: dict[str, list[Callable]]
  21. _cur_id: int
  22. _lock: asyncio.Lock
  23. _file_store: FileStore
  24. def __init__(self, sid: str):
  25. self.sid = sid
  26. self._file_store = get_file_store()
  27. self._subscribers = {}
  28. self._cur_id = 0
  29. self._lock = asyncio.Lock()
  30. self._reinitialize_from_file_store()
  31. def _reinitialize_from_file_store(self):
  32. try:
  33. events = self._file_store.list(f'sessions/{self.sid}/events')
  34. except FileNotFoundError:
  35. logger.warning(f'No events found for session {self.sid}')
  36. return
  37. for event_str in events:
  38. id = self._get_id_from_filename(event_str)
  39. if id >= self._cur_id:
  40. self._cur_id = id + 1
  41. def _get_filename_for_id(self, id: int) -> str:
  42. return f'sessions/{self.sid}/events/{id}.json'
  43. @staticmethod
  44. def _get_id_from_filename(filename: str) -> int:
  45. try:
  46. return int(filename.split('/')[-1].split('.')[0])
  47. except ValueError:
  48. logger.warning(f'get id from filename ({filename}) failed.')
  49. return -1
  50. def get_events(self, start_id=0, end_id=None) -> Iterable[Event]:
  51. event_id = start_id
  52. while True:
  53. if end_id is not None and event_id > end_id:
  54. break
  55. try:
  56. event = self.get_event(event_id)
  57. except FileNotFoundError:
  58. break
  59. yield event
  60. event_id += 1
  61. def get_event(self, id: int) -> Event:
  62. filename = self._get_filename_for_id(id)
  63. content = self._file_store.read(filename)
  64. data = json.loads(content)
  65. return event_from_dict(data)
  66. def subscribe(self, id: EventStreamSubscriber, callback: Callable, append=False):
  67. if id in self._subscribers:
  68. if append:
  69. self._subscribers[id].append(callback)
  70. else:
  71. raise ValueError('Subscriber already exists: ' + id)
  72. else:
  73. self._subscribers[id] = [callback]
  74. def unsubscribe(self, id: EventStreamSubscriber):
  75. if id not in self._subscribers:
  76. logger.warning('Subscriber not found during unsubscribe: ' + id)
  77. else:
  78. self._subscribers[id].pop()
  79. if len(self._subscribers[id]) == 0:
  80. del self._subscribers[id]
  81. # TODO: make this not async
  82. async def add_event(self, event: Event, source: EventSource):
  83. async with self._lock:
  84. event._id = self._cur_id # type: ignore [attr-defined]
  85. self._cur_id += 1
  86. event._timestamp = datetime.now() # type: ignore [attr-defined]
  87. event._source = source # type: ignore [attr-defined]
  88. data = event_to_dict(event)
  89. if event.id is not None:
  90. self._file_store.write(
  91. self._get_filename_for_id(event.id), json.dumps(data)
  92. )
  93. for key, stack in self._subscribers.items():
  94. callback = stack[-1]
  95. await callback(event)