stream.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. import asyncio
  2. import json
  3. from datetime import datetime
  4. from enum import Enum
  5. from typing import Callable, Iterable
  6. from opendevin.core.logger import opendevin_logger as logger
  7. from opendevin.events.serialization.event import event_from_dict, event_to_dict
  8. from opendevin.storage import FileStore, get_file_store
  9. from .event import Event
  10. from .serialization.event import EventSource
  11. class EventStreamSubscriber(str, Enum):
  12. AGENT_CONTROLLER = 'agent_controller'
  13. SERVER = 'server'
  14. RUNTIME = 'runtime'
  15. MAIN = 'main'
  16. TEST = 'test'
  17. class EventStream:
  18. sid: str
  19. _subscribers: dict[str, Callable]
  20. _cur_id: int
  21. _lock: asyncio.Lock
  22. _file_store: FileStore
  23. def __init__(self, sid: str):
  24. self.sid = sid
  25. self._file_store = get_file_store()
  26. self._subscribers = {}
  27. self._cur_id = 0
  28. self._lock = asyncio.Lock()
  29. self._reinitialize_from_file_store()
  30. def _reinitialize_from_file_store(self):
  31. events = self._file_store.list(f'sessions/{self.sid}/events')
  32. for event_str in events:
  33. id = self._get_id_from_filename(event_str)
  34. if id >= self._cur_id:
  35. self._cur_id = id + 1
  36. def _get_filename_for_id(self, id: int) -> str:
  37. return f'sessions/{self.sid}/events/{id}.json'
  38. def _get_id_from_filename(self, filename: str) -> int:
  39. return int(filename.split('/')[-1].split('.')[0])
  40. def get_events(self, start_id=0, end_id=None) -> Iterable[Event]:
  41. events = self._file_store.list(f'sessions/{self.sid}/events')
  42. for event_str in events:
  43. id = self._get_id_from_filename(event_str)
  44. if start_id <= id and (end_id is None or id <= end_id):
  45. event = self.get_event(id)
  46. yield event
  47. def get_event(self, id: int) -> Event:
  48. filename = self._get_filename_for_id(id)
  49. content = self._file_store.read(filename)
  50. data = json.loads(content)
  51. return event_from_dict(data)
  52. def subscribe(self, id: EventStreamSubscriber, callback: Callable):
  53. if id in self._subscribers:
  54. raise ValueError('Subscriber already exists: ' + id)
  55. else:
  56. self._subscribers[id] = callback
  57. def unsubscribe(self, id: EventStreamSubscriber):
  58. if id not in self._subscribers:
  59. logger.warning('Subscriber not found during unsubscribe: ' + id)
  60. else:
  61. del self._subscribers[id]
  62. # TODO: make this not async
  63. async def add_event(self, event: Event, source: EventSource):
  64. async with self._lock:
  65. event._id = self._cur_id # type: ignore [attr-defined]
  66. self._cur_id += 1
  67. event._timestamp = datetime.now() # type: ignore [attr-defined]
  68. event._source = source # type: ignore [attr-defined]
  69. data = event_to_dict(event)
  70. self._file_store.write(self._get_filename_for_id(event.id), json.dumps(data))
  71. for key, fn in self._subscribers.items():
  72. await fn(event)