| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154 |
- import asyncio
- import threading
- from datetime import datetime
- from enum import Enum
- from typing import Callable, Iterable
- from openhands.core.logger import openhands_logger as logger
- from openhands.core.utils import json
- from openhands.events.event import Event, EventSource
- from openhands.events.serialization.event import event_from_dict, event_to_dict
- from openhands.storage import FileStore
class EventStreamSubscriber(str, Enum):
    """Identifiers under which a component subscribes to an ``EventStream``.

    Because the enum mixes in ``str``, every member compares equal to its
    string value, e.g. ``EventStreamSubscriber.SERVER == 'server'``.
    """

    # Declaration order below is the enum's iteration order; keep it stable.
    AGENT_CONTROLLER = 'agent_controller'
    SECURITY_ANALYZER = 'security_analyzer'
    SERVER = 'server'
    RUNTIME = 'runtime'
    MAIN = 'main'
    TEST = 'test'
class EventStream:
    """File-backed, append-only log of events for a single session.

    Each event is serialized to JSON at ``sessions/<sid>/events/<id>.json`` in
    ``file_store``. Ids are assigned sequentially, continuing after the highest
    id already present in the store. Subscribers register callbacks that are
    invoked as ``callback(event)`` and scheduled with ``asyncio.create_task``
    whenever an event is added, so each callback must return a coroutine.
    """

    sid: str
    file_store: FileStore
    # For each subscriber ID, there is a stack of callback functions - useful
    # when there are agent delegates
    _subscribers: dict[str, list[Callable]]
    _cur_id: int
    _lock: threading.Lock
    # Strong references to in-flight subscriber tasks. The asyncio event loop
    # only keeps weak references to tasks, so a task nobody references can be
    # garbage-collected before it finishes.
    _background_tasks: set[asyncio.Task]

    def __init__(self, sid: str, file_store: FileStore):
        self.sid = sid
        self.file_store = file_store
        self._subscribers = {}
        self._cur_id = 0
        self._lock = threading.Lock()
        self._background_tasks = set()
        self._reinitialize_from_file_store()

    def _reinitialize_from_file_store(self) -> None:
        """Advance ``_cur_id`` past the highest event id already persisted.

        If the session's events directory cannot be listed
        (``FileNotFoundError``), the stream starts from id 0.
        """
        try:
            events = self.file_store.list(f'sessions/{self.sid}/events')
        except FileNotFoundError:
            logger.debug(f'No events found for session {self.sid}')
            self._cur_id = 0
            return
        # Find the highest existing id so new events never overwrite old ones.
        # Unparseable filenames map to -1 and therefore never raise _cur_id.
        for event_str in events:
            event_id = self._get_id_from_filename(event_str)
            if event_id >= self._cur_id:
                self._cur_id = event_id + 1

    def _get_filename_for_id(self, id: int) -> str:
        """Return the file-store path under which event ``id`` is stored."""
        return f'sessions/{self.sid}/events/{id}.json'

    @staticmethod
    def _get_id_from_filename(filename: str) -> int:
        """Parse the integer id from a path like ``.../events/<id>.json``.

        Returns -1 (and logs a warning) when the basename does not start with
        an integer.
        """
        try:
            return int(filename.split('/')[-1].split('.')[0])
        except ValueError:
            logger.warning(f'get id from filename ({filename}) failed.')
            return -1

    def get_events(
        self,
        start_id=0,
        end_id=None,
        reverse=False,
        filter_out_type: tuple[type[Event], ...] | None = None,
    ) -> Iterable[Event]:
        """Lazily yield stored events with ids in ``[start_id, end_id]``.

        Args:
            start_id: Lowest id to consider (inclusive).
            end_id: Highest id to consider (inclusive). When ``None``, reverse
                iteration starts at the latest assigned id, and forward
                iteration runs until the first missing event.
            reverse: Yield from ``end_id`` down to ``start_id`` instead of
                upward.
            filter_out_type: Event classes to skip (checked with
                ``isinstance``).

        Note:
            Missing ids are handled differently by direction. Reverse
            iteration logs and skips them. Forward iteration stops at the
            first missing id, even if ``end_id`` lies beyond it.
        """

        def _keep(event: Event) -> bool:
            return filter_out_type is None or not isinstance(event, filter_out_type)

        if reverse:
            if end_id is None:
                end_id = self._cur_id - 1
            event_id = end_id
            while event_id >= start_id:
                try:
                    event = self.get_event(event_id)
                    if _keep(event):
                        yield event
                except FileNotFoundError:
                    logger.debug(f'No event found for ID {event_id}')
                event_id -= 1
        else:
            event_id = start_id
            while True:
                if end_id is not None and event_id > end_id:
                    break
                try:
                    event = self.get_event(event_id)
                    if _keep(event):
                        yield event
                except FileNotFoundError:
                    # No stored event at this id: treat it as the end of the stream.
                    break
                event_id += 1

    def get_event(self, id: int) -> Event:
        """Read, parse and deserialize the event stored under ``id``.

        Whatever ``file_store.read`` raises for a missing file propagates;
        ``get_events`` expects this to be ``FileNotFoundError``.
        """
        filename = self._get_filename_for_id(id)
        content = self.file_store.read(filename)
        data = json.loads(content)
        return event_from_dict(data)

    def get_latest_event(self) -> Event:
        """Return the event with the most recently assigned id."""
        return self.get_event(self._cur_id - 1)

    def get_latest_event_id(self) -> int:
        """Return the most recently assigned id (-1 if none has been assigned)."""
        return self._cur_id - 1

    def subscribe(self, id: EventStreamSubscriber, callback: Callable, append=False):
        """Register ``callback`` under subscriber ``id``.

        With ``append=True`` the callback is pushed on top of an existing
        stack for ``id``. Only the top of each stack receives events.

        Raises:
            ValueError: ``id`` is already subscribed and ``append`` is False.
        """
        if id in self._subscribers:
            if append:
                self._subscribers[id].append(callback)
            else:
                raise ValueError('Subscriber already exists: ' + id)
        else:
            self._subscribers[id] = [callback]

    def unsubscribe(self, id: EventStreamSubscriber):
        """Pop the top callback for ``id``, removing ``id`` when its stack empties.

        An unknown ``id`` is logged as a warning rather than raised.
        """
        if id not in self._subscribers:
            logger.warning('Subscriber not found during unsubscribe: ' + id)
        else:
            self._subscribers[id].pop()
            if len(self._subscribers[id]) == 0:
                del self._subscribers[id]

    def add_event(self, event: Event, source: EventSource):
        """Assign the next id to ``event``, persist it, and notify subscribers.

        Only id assignment is done under the lock. Each subscriber's top
        callback is scheduled with ``asyncio.create_task``, which requires a
        running event loop in the calling thread.
        """
        with self._lock:
            event._id = self._cur_id  # type: ignore [attr-defined]
            self._cur_id += 1
        logger.debug(f'Adding {type(event).__name__} id={event.id} from {source.name}')
        event._timestamp = datetime.now()  # type: ignore [attr-defined]
        event._source = source  # type: ignore [attr-defined]
        data = event_to_dict(event)
        if event.id is not None:
            self.file_store.write(self._get_filename_for_id(event.id), json.dumps(data))
        for key in sorted(self._subscribers.keys()):
            callback = self._subscribers[key][-1]
            task = asyncio.create_task(callback(event))
            # Hold a strong reference until the task completes so it cannot be
            # garbage-collected mid-execution; drop it once done.
            self._background_tasks.add(task)
            task.add_done_callback(self._background_tasks.discard)

    def filtered_events_by_source(self, source: EventSource):
        """Yield stored events (forward order) whose ``source`` equals ``source``."""
        for event in self.get_events():
            if event.source == source:
                yield event

    def clear(self):
        """Delete all persisted data for this session and reset the id counter."""
        self.file_store.delete(f'sessions/{self.sid}')
        self._cur_id = 0
        # Registered subscribers are deliberately left in place; only the stored
        # events and the id counter are reset.
        self._reinitialize_from_file_store()
|