Prechádzať zdrojové kódy

Lockup Resiliency and Asyncio Improvements (#4221)

tofarr 1 rok pred
rodič
commit
cdd05a98db

+ 1 - 1
Makefile

@@ -195,7 +195,7 @@ start-backend:
 # Start frontend
 start-frontend:
 	@echo "$(YELLOW)Starting frontend...$(RESET)"
-	@cd frontend && VITE_BACKEND_HOST=$(BACKEND_HOST_PORT) VITE_FRONTEND_PORT=$(FRONTEND_PORT) npm run start
+	@cd frontend && VITE_BACKEND_HOST=$(BACKEND_HOST_PORT) VITE_FRONTEND_PORT=$(FRONTEND_PORT) npm run start -- --port $(FRONTEND_PORT)
 
 # Common setup for running the app (non-callable)
 _run_setup:

+ 14 - 1
openhands/events/stream.py

@@ -129,6 +129,13 @@ class EventStream:
                 del self._subscribers[id]
 
     def add_event(self, event: Event, source: EventSource):
+        try:
+            asyncio.get_running_loop().create_task(self.async_add_event(event, source))
+        except RuntimeError:
+            # No event loop running...
+            asyncio.run(self.async_add_event(event, source))
+
+    async def async_add_event(self, event: Event, source: EventSource):
         with self._lock:
             event._id = self._cur_id  # type: ignore [attr-defined]
             self._cur_id += 1
@@ -138,10 +145,16 @@ class EventStream:
         data = event_to_dict(event)
         if event.id is not None:
             self.file_store.write(self._get_filename_for_id(event.id), json.dumps(data))
+        tasks = []
         for key in sorted(self._subscribers.keys()):
             stack = self._subscribers[key]
             callback = stack[-1]
-            asyncio.create_task(callback(event))
+            tasks.append(asyncio.create_task(callback(event)))
+        if tasks:
+            await asyncio.wait(tasks)
+
+    def _callback(self, callback: Callable, event: Event):
+        asyncio.run(callback(event))
 
     def filtered_events_by_source(self, source: EventSource):
         for event in self.get_events():

+ 1 - 1
openhands/llm/async_llm.py

@@ -73,7 +73,7 @@ class AsyncLLM(LLM):
                         and self.config.on_cancel_requested_fn is not None
                         and await self.config.on_cancel_requested_fn()
                     ):
-                        raise UserCancelledError('LLM request cancelled by user')
+                        return
                     await asyncio.sleep(0.1)
 
             stop_check_task = asyncio.create_task(check_stopped())

+ 5 - 1
openhands/runtime/remote/runtime.py

@@ -200,6 +200,9 @@ class RemoteRuntime(Runtime):
         assert (
             self.runtime_url is not None
         ), 'Runtime URL is not set. This should never happen.'
+
+        self._wait_until_alive()
+
         self.send_status_message(' ')
 
         self._wait_until_alive()
@@ -229,7 +232,7 @@ class RemoteRuntime(Runtime):
             logger.warning(msg)
             raise RuntimeError(msg)
 
-    def close(self):
+    def close(self, timeout: int = 10):
         if self.runtime_id:
             try:
                 response = send_request(
@@ -237,6 +240,7 @@ class RemoteRuntime(Runtime):
                     'POST',
                     f'{self.config.sandbox.remote_runtime_api_url}/stop',
                     json={'runtime_id': self.runtime_id},
+                    timeout=timeout,
                 )
                 if response.status_code != 200:
                     logger.error(f'Failed to stop sandbox: {response.text}')

+ 9 - 2
openhands/runtime/runtime.py

@@ -1,3 +1,4 @@
+import asyncio
 import atexit
 import copy
 import json
@@ -117,10 +118,10 @@ class Runtime:
             if event.timeout is None:
                 event.timeout = self.config.sandbox.timeout
             assert event.timeout is not None
-            observation = self.run_action(event)
+            observation = await self.async_run_action(event)
             observation._cause = event.id  # type: ignore[attr-defined]
             source = event.source if event.source else EventSource.AGENT
-            self.event_stream.add_event(observation, source)  # type: ignore[arg-type]
+            await self.event_stream.async_add_event(observation, source)  # type: ignore[arg-type]
 
     def run_action(self, action: Action) -> Observation:
         """Run an action and return the resulting observation.
@@ -151,6 +152,12 @@ class Runtime:
         observation = getattr(self, action_type)(action)
         return observation
 
+    async def async_run_action(self, action: Action) -> Observation:
+        observation = await asyncio.get_event_loop().run_in_executor(
+            None, self.run_action, action
+        )
+        return observation
+
     # ====================================================================
     # Context manager
     # ====================================================================

+ 3 - 4
openhands/security/invariant/analyzer.py

@@ -1,3 +1,4 @@
+import asyncio
 import re
 import uuid
 from typing import Any
@@ -144,10 +145,8 @@ class InvariantAnalyzer(SecurityAnalyzer):
         new_event = action_from_dict(
             {'action': 'change_agent_state', 'args': {'agent_state': 'user_confirmed'}}
         )
-        if event.source:
-            self.event_stream.add_event(new_event, event.source)
-        else:
-            self.event_stream.add_event(new_event, EventSource.AGENT)
+        event_source = event.source if event.source else EventSource.AGENT
+        await asyncio.get_event_loop().run_in_executor(None, self.event_stream.add_event, new_event, event_source)
 
     async def security_risk(self, event: Action) -> ActionSecurityRisk:
         logger.info('Calling security_risk on InvariantAnalyzer')

+ 6 - 3
openhands/server/listen.py

@@ -430,7 +430,9 @@ async def list_files(request: Request, path: str | None = None):
             content={'error': 'Runtime not yet initialized'},
         )
     runtime: Runtime = request.state.session.agent_session.runtime
-    file_list = runtime.list_files(path)
+    file_list = await asyncio.get_event_loop().run_in_executor(
+        None, runtime.list_files, path
+    )
     if path:
         file_list = [os.path.join(path, f) for f in file_list]
 
@@ -451,6 +453,7 @@ async def list_files(request: Request, path: str | None = None):
         return file_list
 
     file_list = filter_for_gitignore(file_list, '')
+
     return file_list
 
 
@@ -478,7 +481,7 @@ async def select_file(file: str, request: Request):
 
     file = os.path.join(runtime.config.workspace_mount_path_in_sandbox, file)
     read_action = FileReadAction(file)
-    observation = runtime.run_action(read_action)
+    observation = await runtime.async_run_action(read_action)
 
     if isinstance(observation, FileReadObservation):
         content = observation.content
@@ -720,7 +723,7 @@ async def save_file(request: Request):
             runtime.config.workspace_mount_path_in_sandbox, file_path
         )
         write_action = FileWriteAction(file_path, content)
-        observation = runtime.run_action(write_action)
+        observation = await runtime.async_run_action(write_action)
 
         if isinstance(observation, FileWriteObservation):
             return JSONResponse(

+ 13 - 26
openhands/server/session/agent_session.py

@@ -1,6 +1,4 @@
 import asyncio
-import concurrent.futures
-from threading import Thread
 from typing import Callable, Optional
 
 from openhands.controller import AgentController
@@ -32,7 +30,7 @@ class AgentSession:
     runtime: Runtime | None = None
     security_analyzer: SecurityAnalyzer | None = None
     _closed: bool = False
-    loop: asyncio.AbstractEventLoop
+    loop: asyncio.AbstractEventLoop | None = None
 
     def __init__(self, sid: str, file_store: FileStore):
         """Initializes a new instance of the Session class
@@ -45,7 +43,6 @@ class AgentSession:
         self.sid = sid
         self.event_stream = EventStream(sid, file_store)
         self.file_store = file_store
-        self.loop = asyncio.new_event_loop()
 
     async def start(
         self,
@@ -73,17 +70,9 @@ class AgentSession:
                 'Session already started. You need to close this session and start a new one.'
             )
 
-        self.thread = Thread(target=self._run, daemon=True)
-        self.thread.start()
-
-        def coro_callback(task):
-            fut: concurrent.futures.Future = concurrent.futures.Future()
-            try:
-                fut.set_result(task.result())
-            except Exception as e:
-                logger.error(f'Error starting session: {e}')
-
-        coro = self._start(
+        asyncio.get_event_loop().run_in_executor(
+            None,
+            self._start_thread,
             runtime_name,
             config,
             agent,
@@ -93,9 +82,12 @@ class AgentSession:
             agent_configs,
             status_message_callback,
         )
-        asyncio.run_coroutine_threadsafe(coro, self.loop).add_done_callback(
-            coro_callback
-        )  # type: ignore
+
+    def _start_thread(self, *args):
+        try:
+            asyncio.run(self._start(*args), debug=True)
+        except RuntimeError:
+            logger.info('Session Finished')
 
     async def _start(
         self,
@@ -108,6 +100,7 @@ class AgentSession:
         agent_configs: dict[str, AgentConfig] | None = None,
         status_message_callback: Optional[Callable] = None,
     ):
+        self.loop = asyncio.get_running_loop()
         self._create_security_analyzer(config.security.security_analyzer)
         self._create_runtime(runtime_name, config, agent, status_message_callback)
         self._create_controller(
@@ -125,10 +118,6 @@ class AgentSession:
             self.controller.agent_task = self.controller.start_step_loop()
             await self.controller.agent_task  # type: ignore
 
-    def _run(self):
-        asyncio.set_event_loop(self.loop)
-        self.loop.run_forever()
-
     async def close(self):
         """Closes the Agent session"""
 
@@ -143,10 +132,8 @@ class AgentSession:
         if self.security_analyzer is not None:
             await self.security_analyzer.close()
 
-        self.loop.call_soon_threadsafe(self.loop.stop)
-        if self.thread:
-            # We may be closing an agent_session that was never actually started
-            self.thread.join()
+        if self.loop:
+            self.loop.call_soon_threadsafe(self.loop.stop)
 
         self._closed = True
 

+ 4 - 3
openhands/server/session/session.py

@@ -162,9 +162,10 @@ class Session:
                         'Model does not support image upload, change to a different model or try without an image.'
                     )
                     return
-        asyncio.run_coroutine_threadsafe(
-            self._add_event(event, EventSource.USER), self.agent_session.loop
-        )  # type: ignore
+        if self.agent_session.loop:
+            asyncio.run_coroutine_threadsafe(
+                self._add_event(event, EventSource.USER), self.agent_session.loop
+            )  # type: ignore
 
     async def _add_event(self, event, event_source):
         self.agent_session.event_stream.add_event(event, EventSource.USER)

+ 6 - 7
tests/unit/test_security.py

@@ -1,4 +1,3 @@
-import asyncio
 import pathlib
 import tempfile
 
@@ -42,7 +41,7 @@ def temp_dir(monkeypatch):
         yield temp_dir
 
 
-async def add_events(event_stream: EventStream, data: list[tuple[Event, EventSource]]):
+def add_events(event_stream: EventStream, data: list[tuple[Event, EventSource]]):
     for event, source in data:
         event_stream.add_event(event, source)
 
@@ -62,7 +61,7 @@ def test_msg(temp_dir: str):
         (MessageAction('Hello world!'), EventSource.USER),
         (MessageAction('ABC!'), EventSource.AGENT),
     ]
-    asyncio.run(add_events(event_stream, data))
+    add_events(event_stream, data)
     for i in range(3):
         assert data[i][0].security_risk == ActionSecurityRisk.LOW
     assert data[3][0].security_risk == ActionSecurityRisk.MEDIUM
@@ -86,7 +85,7 @@ def test_cmd(cmd, expected_risk, temp_dir: str):
         (MessageAction('Hello world!'), EventSource.USER),
         (CmdRunAction(cmd), EventSource.USER),
     ]
-    asyncio.run(add_events(event_stream, data))
+    add_events(event_stream, data)
     assert data[0][0].security_risk == ActionSecurityRisk.LOW
     assert data[1][0].security_risk == expected_risk
 
@@ -115,7 +114,7 @@ def test_leak_secrets(code, expected_risk, temp_dir: str):
         (IPythonRunCellAction(code), EventSource.AGENT),
         (IPythonRunCellAction('hello'), EventSource.AGENT),
     ]
-    asyncio.run(add_events(event_stream, data))
+    add_events(event_stream, data)
     assert data[0][0].security_risk == ActionSecurityRisk.LOW
     assert data[1][0].security_risk == expected_risk
     assert data[2][0].security_risk == ActionSecurityRisk.LOW
@@ -133,7 +132,7 @@ def test_unsafe_python_code(temp_dir: str):
         (MessageAction('Hello world!'), EventSource.USER),
         (IPythonRunCellAction(code), EventSource.AGENT),
     ]
-    asyncio.run(add_events(event_stream, data))
+    add_events(event_stream, data)
     assert data[0][0].security_risk == ActionSecurityRisk.LOW
     # TODO: this failed but idk why and seems not deterministic to me
     # assert data[1][0].security_risk == ActionSecurityRisk.MEDIUM
@@ -148,7 +147,7 @@ def test_unsafe_bash_command(temp_dir: str):
         (MessageAction('Hello world!'), EventSource.USER),
         (CmdRunAction(code), EventSource.AGENT),
     ]
-    asyncio.run(add_events(event_stream, data))
+    add_events(event_stream, data)
     assert data[0][0].security_risk == ActionSecurityRisk.LOW
     assert data[1][0].security_risk == ActionSecurityRisk.MEDIUM