Просмотр исходного кода

Fix for lockup - create the runtime in a background thread (#4412)

Co-authored-by: Robert Brennan <contact@rbren.io>
tofarr 1 год назад
Родитель
Сommit
8a93da51be

+ 2 - 2
openhands/runtime/runtime.py

@@ -28,7 +28,7 @@ from openhands.events.observation import (
 )
 from openhands.events.serialization.action import ACTION_TYPE_TO_CLASS
 from openhands.runtime.plugins import JupyterRequirement, PluginRequirement
-from openhands.utils.async_utils import sync_from_async
+from openhands.utils.async_utils import call_sync_from_async
 
 
 def _default_env_vars(sandbox_config: SandboxConfig) -> dict[str, str]:
@@ -123,7 +123,7 @@ class Runtime:
             if event.timeout is None:
                 event.timeout = self.config.sandbox.timeout
             assert event.timeout is not None
-            observation = await sync_from_async(self.run_action, event)
+            observation = await call_sync_from_async(self.run_action, event)
             observation._cause = event.id  # type: ignore[attr-defined]
             source = event.source if event.source else EventSource.AGENT
             await self.event_stream.async_add_event(observation, source)  # type: ignore[arg-type]

+ 2 - 2
openhands/security/invariant/analyzer.py

@@ -19,7 +19,7 @@ from openhands.runtime.utils import find_available_tcp_port
 from openhands.security.analyzer import SecurityAnalyzer
 from openhands.security.invariant.client import InvariantClient
 from openhands.security.invariant.parser import TraceElement, parse_element
-from openhands.utils.async_utils import sync_from_async
+from openhands.utils.async_utils import call_sync_from_async
 
 
 class InvariantAnalyzer(SecurityAnalyzer):
@@ -146,7 +146,7 @@ class InvariantAnalyzer(SecurityAnalyzer):
             {'action': 'change_agent_state', 'args': {'agent_state': 'user_confirmed'}}
         )
         event_source = event.source if event.source else EventSource.AGENT
-        await sync_from_async(self.event_stream.add_event, new_event, event_source)
+        await call_sync_from_async(self.event_stream.add_event, new_event, event_source)
 
     async def security_risk(self, event: Action) -> ActionSecurityRisk:
         logger.info('Calling security_risk on InvariantAnalyzer')

+ 8 - 6
openhands/server/listen.py

@@ -14,7 +14,7 @@ from pathspec.patterns import GitWildMatchPattern
 from openhands.security.options import SecurityAnalyzers
 from openhands.server.data_models.feedback import FeedbackDataModel, store_feedback
 from openhands.storage import get_file_store
-from openhands.utils.async_utils import sync_from_async
+from openhands.utils.async_utils import call_sync_from_async
 
 with warnings.catch_warnings():
     warnings.simplefilter('ignore')
@@ -211,8 +211,8 @@ async def attach_session(request: Request, call_next):
             content={'error': 'Invalid token'},
         )
 
-    request.state.conversation = session_manager.attach_to_conversation(
-        request.state.sid
+    request.state.conversation = await call_sync_from_async(
+        session_manager.attach_to_conversation, request.state.sid
     )
     if request.state.conversation is None:
         return JSONResponse(
@@ -441,7 +441,9 @@ async def list_files(request: Request, path: str | None = None):
         )
 
     runtime: Runtime = request.state.conversation.runtime
-    file_list = await sync_from_async(runtime.list_files, path)
+    file_list = await asyncio.create_task(
+        call_sync_from_async(runtime.list_files, path)
+    )
     if path:
         file_list = [os.path.join(path, f) for f in file_list]
 
@@ -490,7 +492,7 @@ async def select_file(file: str, request: Request):
 
     file = os.path.join(runtime.config.workspace_mount_path_in_sandbox, file)
     read_action = FileReadAction(file)
-    observation = await sync_from_async(runtime.run_action, read_action)
+    observation = await call_sync_from_async(runtime.run_action, read_action)
 
     if isinstance(observation, FileReadObservation):
         content = observation.content
@@ -687,7 +689,7 @@ async def save_file(request: Request):
             runtime.config.workspace_mount_path_in_sandbox, file_path
         )
         write_action = FileWriteAction(file_path, content)
-        observation = await sync_from_async(runtime.run_action, write_action)
+        observation = await call_sync_from_async(runtime.run_action, write_action)
 
         if isinstance(observation, FileWriteObservation):
             return JSONResponse(

+ 8 - 1
openhands/server/session/agent_session.py

@@ -14,6 +14,7 @@ from openhands.runtime import get_runtime_cls
 from openhands.runtime.runtime import Runtime
 from openhands.security import SecurityAnalyzer, options
 from openhands.storage.files import FileStore
+from openhands.utils.async_utils import call_sync_from_async
 
 
 class AgentSession:
@@ -102,7 +103,13 @@ class AgentSession:
     ):
         self.loop = asyncio.get_running_loop()
         self._create_security_analyzer(config.security.security_analyzer)
-        self._create_runtime(runtime_name, config, agent, status_message_callback)
+        await call_sync_from_async(
+            self._create_runtime,
+            runtime_name=runtime_name,
+            config=config,
+            agent=agent,
+            status_message_callback=status_message_callback,
+        )
         self._create_controller(
             agent,
             config.security.confirmation_mode,

+ 14 - 2
openhands/utils/async_utils.py

@@ -7,7 +7,7 @@ GENERAL_TIMEOUT: int = 15
 EXECUTOR = ThreadPoolExecutor()
 
 
-async def sync_from_async(fn: Callable, *args, **kwargs):
+async def call_sync_from_async(fn: Callable, *args, **kwargs):
     """
     Shorthand for running a function in the default background thread pool executor
     and awaiting the result. The nature of synchronous code is that the future
@@ -19,7 +19,7 @@ async def sync_from_async(fn: Callable, *args, **kwargs):
     return result
 
 
-def async_from_sync(
+def call_async_from_sync(
     corofn: Callable, timeout: float = GENERAL_TIMEOUT, *args, **kwargs
 ):
     """
@@ -27,6 +27,11 @@ def async_from_sync(
     and awaiting the result
     """
 
+    if corofn is None:
+        raise ValueError('corofn is None')
+    if not asyncio.iscoroutinefunction(corofn):
+        raise ValueError('corofn is not a coroutine function')
+
     async def arun():
         coro = corofn(*args, **kwargs)
         result = await coro
@@ -46,6 +51,13 @@ def async_from_sync(
     return result
 
 
+async def call_coro_in_bg_thread(
+    corofn: Callable, timeout: float = GENERAL_TIMEOUT, *args, **kwargs
+):
+    """Function for running a coroutine in a background thread."""
+    await call_sync_from_async(call_async_from_sync, corofn, timeout, *args, **kwargs)
+
+
 async def wait_all(
     iterable: Iterable[Coroutine], timeout: int = GENERAL_TIMEOUT
 ) -> List:

+ 40 - 14
tests/unit/test_async_utils.py

@@ -1,11 +1,13 @@
 import asyncio
+import time
 
 import pytest
 
 from openhands.utils.async_utils import (
     AsyncException,
-    async_from_sync,
-    sync_from_async,
+    call_async_from_sync,
+    call_coro_in_bg_thread,
+    call_sync_from_async,
     wait_all,
 )
 
@@ -80,44 +82,44 @@ async def test_await_all_timeout():
 
 
 @pytest.mark.asyncio
-async def test_sync_from_async():
+async def test_call_sync_from_async():
     def dummy(value: int = 2):
         return value * 2
 
-    result = await sync_from_async(dummy)
+    result = await call_sync_from_async(dummy)
     assert result == 4
-    result = await sync_from_async(dummy, 3)
+    result = await call_sync_from_async(dummy, 3)
     assert result == 6
-    result = await sync_from_async(dummy, value=5)
+    result = await call_sync_from_async(dummy, value=5)
     assert result == 10
 
 
 @pytest.mark.asyncio
-async def test_sync_from_async_error():
+async def test_call_sync_from_async_error():
     def dummy():
         raise ValueError()
 
     with pytest.raises(ValueError):
-        await sync_from_async(dummy)
+        await call_sync_from_async(dummy)
 
 
-def test_async_from_sync():
+def test_call_async_from_sync():
     async def dummy(value: int):
         return value * 2
 
-    result = async_from_sync(dummy, 0, 3)
+    result = call_async_from_sync(dummy, 0, 3)
     assert result == 6
 
 
-def test_async_from_sync_error():
+def test_call_async_from_sync_error():
     async def dummy(value: int):
         raise ValueError()
 
     with pytest.raises(ValueError):
-        async_from_sync(dummy, 0, 3)
+        call_async_from_sync(dummy, 0, 3)
 
 
-def test_async_from_sync_background_tasks():
+def test_call_async_from_sync_background_tasks():
     events = []
 
     async def bg_task():
@@ -132,9 +134,33 @@ def test_async_from_sync_background_tasks():
         asyncio.create_task(bg_task())
         events.append('dummy_started')
 
-    async_from_sync(dummy, 0, 3)
+    call_async_from_sync(dummy, 0, 3)
 
     # We check that the function did not return until all coroutines completed
     # (Even though some of these were started as background tasks)
     expected = ['dummy_started', 'dummy_started', 'bg_started', 'bg_finished']
     assert expected == events
+
+
+@pytest.mark.asyncio
+async def test_call_coro_in_bg_thread():
+    times = {}
+
+    async def bad_async(id_):
+        # Dummy demonstrating some bad async function that does not cede control
+        time.sleep(0.1)
+        times[id_] = time.time()
+
+    async def curve_ball():
+        # A curve ball - an async function that wants to run while the bad async functions are in progress
+        await asyncio.sleep(0.05)
+        times['curve_ball'] = time.time()
+
+    start = time.time()
+    asyncio.create_task(curve_ball())
+    await wait_all(
+        call_coro_in_bg_thread(bad_async, id_=f'bad_async_{id_}') for id_ in range(5)
+    )
+    assert (times['curve_ball'] - start) == pytest.approx(0.05, abs=0.1)
+    for id_ in range(5):
+        assert (times[f'bad_async_{id_}'] - start) == pytest.approx(0.1, abs=0.1)