|
|
@@ -17,6 +17,7 @@ from openhands.utils.async_utils import call_sync_from_async
|
|
|
class EventStreamSubscriber(str, Enum):
|
|
|
AGENT_CONTROLLER = 'agent_controller'
|
|
|
SECURITY_ANALYZER = 'security_analyzer'
|
|
|
+ RESOLVER = 'openhands_resolver'
|
|
|
SERVER = 'server'
|
|
|
RUNTIME = 'runtime'
|
|
|
MAIN = 'main'
|
|
|
@@ -50,9 +51,9 @@ class AsyncEventStreamWrapper:
|
|
|
class EventStream:
|
|
|
sid: str
|
|
|
file_store: FileStore
|
|
|
- # For each subscriber ID, there is a stack of callback functions - useful
|
|
|
- # when there are agent delegates
|
|
|
- _subscribers: dict[str, list[Callable]] = field(default_factory=dict)
|
|
|
+ # For each subscriber ID, there is a map of callback functions - useful
|
|
|
+ # when there are multiple listeners
|
|
|
+ _subscribers: dict[str, dict[str, Callable]] = field(default_factory=dict)
|
|
|
_cur_id: int = 0
|
|
|
_lock: threading.Lock = field(default_factory=threading.Lock)
|
|
|
|
|
|
@@ -148,22 +149,29 @@ class EventStream:
|
|
|
def get_latest_event_id(self) -> int:
|
|
|
return self._cur_id - 1
|
|
|
|
|
|
- def subscribe(self, id: EventStreamSubscriber, callback: Callable, append=False):
|
|
|
- if id in self._subscribers:
|
|
|
- if append:
|
|
|
- self._subscribers[id].append(callback)
|
|
|
- else:
|
|
|
- raise ValueError('Subscriber already exists: ' + id)
|
|
|
- else:
|
|
|
- self._subscribers[id] = [callback]
|
|
|
+ def subscribe(
|
|
|
+ self, subscriber_id: EventStreamSubscriber, callback: Callable, callback_id: str
|
|
|
+ ):
|
|
|
+ if subscriber_id not in self._subscribers:
|
|
|
+ self._subscribers[subscriber_id] = {}
|
|
|
|
|
|
- def unsubscribe(self, id: EventStreamSubscriber):
|
|
|
- if id not in self._subscribers:
|
|
|
- logger.warning('Subscriber not found during unsubscribe: ' + id)
|
|
|
- else:
|
|
|
- self._subscribers[id].pop()
|
|
|
- if len(self._subscribers[id]) == 0:
|
|
|
- del self._subscribers[id]
|
|
|
+ if callback_id in self._subscribers[subscriber_id]:
|
|
|
+ raise ValueError(
|
|
|
+ f'Callback ID on subscriber {subscriber_id} already exists: {callback_id}'
|
|
|
+ )
|
|
|
+
|
|
|
+ self._subscribers[subscriber_id][callback_id] = callback
|
|
|
+
|
|
|
+ def unsubscribe(self, subscriber_id: EventStreamSubscriber, callback_id: str):
|
|
|
+ if subscriber_id not in self._subscribers:
|
|
|
+ logger.warning(f'Subscriber not found during unsubscribe: {subscriber_id}')
|
|
|
+ return
|
|
|
+
|
|
|
+ if callback_id not in self._subscribers[subscriber_id]:
|
|
|
+ logger.warning(f'Callback not found during unsubscribe: {callback_id}')
|
|
|
+ return
|
|
|
+
|
|
|
+ del self._subscribers[subscriber_id][callback_id]
|
|
|
|
|
|
def add_event(self, event: Event, source: EventSource):
|
|
|
try:
|
|
|
@@ -188,9 +196,10 @@ class EventStream:
|
|
|
self.file_store.write(self._get_filename_for_id(event.id), json.dumps(data))
|
|
|
tasks = []
|
|
|
for key in sorted(self._subscribers.keys()):
|
|
|
- stack = self._subscribers[key]
|
|
|
- callback = stack[-1]
|
|
|
- tasks.append(asyncio.create_task(callback(event)))
|
|
|
+ callbacks = self._subscribers[key]
|
|
|
+ for callback_id in callbacks:
|
|
|
+ callback = callbacks[callback_id]
|
|
|
+ tasks.append(asyncio.create_task(callback(event)))
|
|
|
if tasks:
|
|
|
await asyncio.wait(tasks)
|
|
|
|