فهرست منبع

Logger fixes for openhands-resolver (#4710)

Co-authored-by: Graham Neubig <neubig@gmail.com>
Rohit Malhotra 1 سال پیش
والد
کامیت
436ecb80a3

+ 14 - 2
openhands/controller/agent_controller.py

@@ -108,7 +108,7 @@ class AgentController:
         # subscribe to the event stream
         self.event_stream = event_stream
         self.event_stream.subscribe(
-            EventStreamSubscriber.AGENT_CONTROLLER, self.on_event, append=is_delegate
+            EventStreamSubscriber.AGENT_CONTROLLER, self.on_event, self.id
         )
 
         # state from the previous session, state from a parent agent, or a fresh state
@@ -156,7 +156,7 @@ class AgentController:
         )
 
         # unsubscribe from the event stream
-        self.event_stream.unsubscribe(EventStreamSubscriber.AGENT_CONTROLLER)
+        self.event_stream.unsubscribe(EventStreamSubscriber.AGENT_CONTROLLER, self.id)
 
     def log(self, level: str, message: str, extra: dict | None = None):
         """Logs a message to the agent controller's logger.
@@ -403,6 +403,8 @@ class AgentController:
             'debug',
             f'start delegate, creating agent {delegate_agent.name} using LLM {llm}',
         )
+
+        self.event_stream.unsubscribe(EventStreamSubscriber.AGENT_CONTROLLER, self.id)
         self.delegate = AgentController(
             sid=self.id + '-delegate',
             agent=delegate_agent,
@@ -519,6 +521,11 @@ class AgentController:
 
             # close the delegate upon error
             await self.delegate.close()
+
+            # resubscribe parent when delegate is finished
+            self.event_stream.subscribe(
+                EventStreamSubscriber.AGENT_CONTROLLER, self.on_event, self.id
+            )
             self.delegate = None
             self.delegateAction = None
 
@@ -533,6 +540,11 @@ class AgentController:
             # close delegate controller: we must close the delegate controller before adding new events
             await self.delegate.close()
 
+            # resubscribe parent when delegate is finished
+            self.event_stream.subscribe(
+                EventStreamSubscriber.AGENT_CONTROLLER, self.on_event, self.id
+            )
+
             # update delegate result observation
             # TODO: replace this with AI-generated summary (#2395)
             formatted_output = ', '.join(

+ 2 - 1
openhands/core/cli.py

@@ -2,6 +2,7 @@ import asyncio
 import logging
 import sys
 from typing import Type
+from uuid import uuid4
 
 from termcolor import colored
 
@@ -150,7 +151,7 @@ async def main():
             ]:
                 await prompt_for_next_task()
 
-    event_stream.subscribe(EventStreamSubscriber.MAIN, on_event)
+    event_stream.subscribe(EventStreamSubscriber.MAIN, on_event, str(uuid4()))
 
     await runtime.connect()
 

+ 1 - 1
openhands/core/main.py

@@ -186,7 +186,7 @@ async def run_controller(
                 action = MessageAction(content=message)
                 event_stream.add_event(action, EventSource.USER)
 
-    event_stream.subscribe(EventStreamSubscriber.MAIN, on_event)
+    event_stream.subscribe(EventStreamSubscriber.MAIN, on_event, sid)
 
     await runtime.connect()
 

+ 30 - 21
openhands/events/stream.py

@@ -17,6 +17,7 @@ from openhands.utils.async_utils import call_sync_from_async
 class EventStreamSubscriber(str, Enum):
     AGENT_CONTROLLER = 'agent_controller'
     SECURITY_ANALYZER = 'security_analyzer'
+    RESOLVER = 'openhands_resolver'
     SERVER = 'server'
     RUNTIME = 'runtime'
     MAIN = 'main'
@@ -50,9 +51,9 @@ class AsyncEventStreamWrapper:
 class EventStream:
     sid: str
     file_store: FileStore
-    # For each subscriber ID, there is a stack of callback functions - useful
-    # when there are agent delegates
-    _subscribers: dict[str, list[Callable]] = field(default_factory=dict)
+    # For each subscriber ID, there is a map of callback functions - useful
+    # when there are multiple listeners
+    _subscribers: dict[str, dict[str, Callable]] = field(default_factory=dict)
     _cur_id: int = 0
     _lock: threading.Lock = field(default_factory=threading.Lock)
 
@@ -148,22 +149,29 @@ class EventStream:
     def get_latest_event_id(self) -> int:
         return self._cur_id - 1
 
-    def subscribe(self, id: EventStreamSubscriber, callback: Callable, append=False):
-        if id in self._subscribers:
-            if append:
-                self._subscribers[id].append(callback)
-            else:
-                raise ValueError('Subscriber already exists: ' + id)
-        else:
-            self._subscribers[id] = [callback]
+    def subscribe(
+        self, subscriber_id: EventStreamSubscriber, callback: Callable, callback_id: str
+    ):
+        if subscriber_id not in self._subscribers:
+            self._subscribers[subscriber_id] = {}
 
-    def unsubscribe(self, id: EventStreamSubscriber):
-        if id not in self._subscribers:
-            logger.warning('Subscriber not found during unsubscribe: ' + id)
-        else:
-            self._subscribers[id].pop()
-            if len(self._subscribers[id]) == 0:
-                del self._subscribers[id]
+        if callback_id in self._subscribers[subscriber_id]:
+            raise ValueError(
+                f'Callback ID on subscriber {subscriber_id} already exists: {callback_id}'
+            )
+
+        self._subscribers[subscriber_id][callback_id] = callback
+
+    def unsubscribe(self, subscriber_id: EventStreamSubscriber, callback_id: str):
+        if subscriber_id not in self._subscribers:
+            logger.warning(f'Subscriber not found during unsubscribe: {subscriber_id}')
+            return
+
+        if callback_id not in self._subscribers[subscriber_id]:
+            logger.warning(f'Callback not found during unsubscribe: {callback_id}')
+            return
+
+        del self._subscribers[subscriber_id][callback_id]
 
     def add_event(self, event: Event, source: EventSource):
         try:
@@ -188,9 +196,10 @@ class EventStream:
             self.file_store.write(self._get_filename_for_id(event.id), json.dumps(data))
         tasks = []
         for key in sorted(self._subscribers.keys()):
-            stack = self._subscribers[key]
-            callback = stack[-1]
-            tasks.append(asyncio.create_task(callback(event)))
+            callbacks = self._subscribers[key]
+            for callback_id in callbacks:
+                callback = callbacks[callback_id]
+                tasks.append(asyncio.create_task(callback(event)))
         if tasks:
             await asyncio.wait(tasks)
 

+ 3 - 1
openhands/runtime/base.py

@@ -86,7 +86,9 @@ class Runtime(FileEditRuntimeMixin):
     ):
         self.sid = sid
         self.event_stream = event_stream
-        self.event_stream.subscribe(EventStreamSubscriber.RUNTIME, self.on_event)
+        self.event_stream.subscribe(
+            EventStreamSubscriber.RUNTIME, self.on_event, self.sid
+        )
         self.plugins = plugins if plugins is not None and len(plugins) > 0 else []
         self.status_callback = status_callback
         self.attach_to_existing = attach_to_existing

+ 2 - 1
openhands/security/analyzer.py

@@ -1,4 +1,5 @@
 from typing import Any
+from uuid import uuid4
 
 from fastapi import Request
 
@@ -19,7 +20,7 @@ class SecurityAnalyzer:
         """
         self.event_stream = event_stream
         self.event_stream.subscribe(
-            EventStreamSubscriber.SECURITY_ANALYZER, self.on_event
+            EventStreamSubscriber.SECURITY_ANALYZER, self.on_event, str(uuid4())
         )
 
     async def on_event(self, event: Event) -> None:

+ 1 - 1
openhands/server/session/session.py

@@ -44,7 +44,7 @@ class Session:
             sid, file_store, status_callback=self.queue_status_message
         )
         self.agent_session.event_stream.subscribe(
-            EventStreamSubscriber.SERVER, self.on_event
+            EventStreamSubscriber.SERVER, self.on_event, self.sid
         )
         self.config = config
         self.loop = asyncio.get_event_loop()

+ 3 - 2
tests/unit/test_agent_controller.py

@@ -1,5 +1,6 @@
 import asyncio
 from unittest.mock import AsyncMock, MagicMock, Mock
+from uuid import uuid4
 
 import pytest
 
@@ -143,7 +144,7 @@ async def test_run_controller_with_fatal_error(mock_agent, mock_event_stream):
             error_obs._cause = event.id
             event_stream.add_event(error_obs, EventSource.USER)
 
-    event_stream.subscribe(EventStreamSubscriber.RUNTIME, on_event)
+    event_stream.subscribe(EventStreamSubscriber.RUNTIME, on_event, str(uuid4()))
     runtime.event_stream = event_stream
 
     state = await run_controller(
@@ -188,7 +189,7 @@ async def test_run_controller_stop_with_stuck():
             non_fatal_error_obs._cause = event.id
             event_stream.add_event(non_fatal_error_obs, EventSource.ENVIRONMENT)
 
-    event_stream.subscribe(EventStreamSubscriber.RUNTIME, on_event)
+    event_stream.subscribe(EventStreamSubscriber.RUNTIME, on_event, str(uuid4()))
     runtime.event_stream = event_stream
 
     state = await run_controller(