Browse Source

Fix CLI and headless after changes to eventstream (#5949)

Co-authored-by: Engel Nyst <enyst@users.noreply.github.com>
Robert Brennan 1 năm trước cách đây
mục cha
commit
f3885cadc1

+ 7 - 6
openhands/core/cli.py

@@ -91,7 +91,7 @@ def display_event(event: Event, config: AppConfig):
         display_confirmation(event.confirmation_state)
 
 
-async def main():
+async def main(loop):
     """Runs the agent in CLI mode"""
 
     parser = get_parser()
@@ -112,7 +112,7 @@ async def main():
 
     logger.setLevel(logging.WARNING)
     config = load_app_config(config_file=args.config_file)
-    sid = 'cli'
+    sid = str(uuid4())
 
     agent_cls: Type[Agent] = Agent.get_cls(config.default_agent)
     agent_config = config.get_agent_config(config.default_agent)
@@ -150,7 +150,6 @@ async def main():
 
     async def prompt_for_next_task():
         # Run input() in a thread pool to avoid blocking the event loop
-        loop = asyncio.get_event_loop()
         next_message = await loop.run_in_executor(
             None, lambda: input('How can I help? >> ')
         )
@@ -165,13 +164,12 @@ async def main():
         event_stream.add_event(action, EventSource.USER)
 
     async def prompt_for_user_confirmation():
-        loop = asyncio.get_event_loop()
         user_confirmation = await loop.run_in_executor(
             None, lambda: input('Confirm action (possible security risk)? (y/n) >> ')
         )
         return user_confirmation.lower() == 'y'
 
-    async def on_event(event: Event):
+    async def on_event_async(event: Event):
         display_event(event, config)
         if isinstance(event, AgentStateChangedObservation):
             if event.agent_state in [
@@ -193,6 +191,9 @@ async def main():
                     ChangeAgentStateAction(AgentState.USER_REJECTED), EventSource.USER
                 )
 
+    def on_event(event: Event) -> None:
+        loop.create_task(on_event_async(event))
+
     event_stream.subscribe(EventStreamSubscriber.MAIN, on_event, str(uuid4()))
 
     await runtime.connect()
@@ -208,7 +209,7 @@ if __name__ == '__main__':
     loop = asyncio.new_event_loop()
     asyncio.set_event_loop(loop)
     try:
-        loop.run_until_complete(main())
+        loop.run_until_complete(main(loop))
     except KeyboardInterrupt:
         print('Received keyboard interrupt, shutting down...')
     except ConnectionRefusedError as e:

+ 1 - 1
openhands/core/config/utils.py

@@ -385,7 +385,7 @@ def get_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         '-n',
         '--name',
-        default='default',
+        default='',
         type=str,
         help='Name for the session',
     )

+ 1 - 1
openhands/core/main.py

@@ -182,7 +182,7 @@ async def run_controller(
         # init with the provided actions
         event_stream.add_event(initial_user_action, EventSource.USER)
 
-    async def on_event(event: Event):
+    def on_event(event: Event):
         if isinstance(event, AgentStateChangedObservation):
             if event.agent_state == AgentState.AWAITING_USER_INPUT:
                 if exit_on_message:

+ 18 - 4
openhands/events/stream.py

@@ -229,10 +229,24 @@ class EventStream:
                 for callback_id in callbacks:
                     callback = callbacks[callback_id]
                     pool = self._thread_pools[key][callback_id]
-                    pool.submit(callback, event)
-
-    def _callback(self, callback: Callable, event: Event):
-        asyncio.run(callback(event))
+                    future = pool.submit(callback, event)
+                    future.add_done_callback(self._make_error_handler(callback_id, key))
+
+    def _make_error_handler(self, callback_id: str, subscriber_id: str):
+        def _handle_callback_error(fut):
+            try:
+                # This will raise any exception that occurred during callback execution
+                fut.result()
+            except Exception as e:
+                logger.error(
+                    f'Error in event callback {callback_id} for subscriber {subscriber_id}: {str(e)}',
+                    exc_info=True,
+                    stack_info=True,
+                )
+                # Re-raise in the main thread so the error is not swallowed
+                raise e
+
+        return _handle_callback_error
 
     def filtered_events_by_source(self, source: EventSource):
         for event in self.get_events():

+ 1 - 1
openhands/resolver/resolve_issue.py

@@ -202,7 +202,7 @@ async def process_issue(
     runtime = create_runtime(config)
     await runtime.connect()
 
-    async def on_event(evt):
+    def on_event(evt):
         logger.info(evt)
 
     runtime.event_stream.subscribe(EventStreamSubscriber.MAIN, on_event, str(uuid4()))

+ 1 - 1
tests/unit/test_arg_parser.py

@@ -18,7 +18,7 @@ def test_parser_default_values():
     assert args.eval_num_workers == 4
     assert args.eval_note is None
     assert args.llm_config is None
-    assert args.name == 'default'
+    assert args.name == ''
     assert not args.no_auto_continue