Bläddra i källkod

Feat: Async Goodies for OpenHands (#4347)

tofarr 1 år sedan
förälder
incheckning
4c5e2a339f

+ 12 - 9
openhands/runtime/client/client.py

@@ -55,6 +55,7 @@ from openhands.runtime.plugins import (
 )
 from openhands.runtime.utils import split_bash_commands
 from openhands.runtime.utils.files import insert_lines, read_lines
+from openhands.utils.async_utils import wait_all
 
 
 class ActionRequest(BaseModel):
@@ -108,15 +109,7 @@ class RuntimeClient:
         return self._initial_pwd
 
     async def ainit(self):
-        for plugin in self.plugins_to_load:
-            await plugin.initialize(self.username)
-            self.plugins[plugin.name] = plugin
-            logger.info(f'Initializing plugin: {plugin.name}')
-
-            if isinstance(plugin, JupyterPlugin):
-                await self.run_ipython(
-                    IPythonRunCellAction(code=f'import os; os.chdir("{self.pwd}")')
-                )
+        await wait_all(self._init_plugin(plugin) for plugin in self.plugins_to_load)
 
         # This is a temporary workaround
         # TODO: refactor AgentSkills to be part of JupyterPlugin
@@ -132,6 +125,16 @@ class RuntimeClient:
         await self._init_bash_commands()
         logger.info('Runtime client initialized.')
 
+    async def _init_plugin(self, plugin: Plugin):
+        """Initialize a single plugin and register it in ``self.plugins`` by name.
+
+        When the plugin is a JupyterPlugin, an IPython cell is run afterwards
+        that calls ``os.chdir`` with ``self.pwd``.
+        """
+        await plugin.initialize(self.username)
+        self.plugins[plugin.name] = plugin
+        # NOTE: this message is emitted after initialize() has already returned.
+        logger.info(f'Initializing plugin: {plugin.name}')
+
+        if not isinstance(plugin, JupyterPlugin):
+            return
+
+        chdir_action = IPythonRunCellAction(
+            code=f'import os; os.chdir("{self.pwd}")'
+        )
+        await self.run_ipython(chdir_action)
+
     def _init_user(self, username: str, user_id: int) -> None:
         """Create working directory and user if not exists.
         It performs the following steps effectively:

+ 2 - 8
openhands/runtime/runtime.py

@@ -1,4 +1,3 @@
-import asyncio
 import atexit
 import copy
 import json
@@ -29,6 +28,7 @@ from openhands.events.observation import (
 )
 from openhands.events.serialization.action import ACTION_TYPE_TO_CLASS
 from openhands.runtime.plugins import JupyterRequirement, PluginRequirement
+from openhands.utils.async_utils import sync_from_async
 
 
 def _default_env_vars(sandbox_config: SandboxConfig) -> dict[str, str]:
@@ -118,7 +118,7 @@ class Runtime:
             if event.timeout is None:
                 event.timeout = self.config.sandbox.timeout
             assert event.timeout is not None
-            observation = await self.async_run_action(event)
+            observation = await sync_from_async(self.run_action, event)
             observation._cause = event.id  # type: ignore[attr-defined]
             source = event.source if event.source else EventSource.AGENT
             await self.event_stream.async_add_event(observation, source)  # type: ignore[arg-type]
@@ -152,12 +152,6 @@ class Runtime:
         observation = getattr(self, action_type)(action)
         return observation
 
-    async def async_run_action(self, action: Action) -> Observation:
-        observation = await asyncio.get_event_loop().run_in_executor(
-            None, self.run_action, action
-        )
-        return observation
-
     # ====================================================================
     # Context manager
     # ====================================================================

+ 2 - 2
openhands/security/invariant/analyzer.py

@@ -1,4 +1,3 @@
-import asyncio
 import re
 import uuid
 from typing import Any
@@ -20,6 +19,7 @@ from openhands.runtime.utils import find_available_tcp_port
 from openhands.security.analyzer import SecurityAnalyzer
 from openhands.security.invariant.client import InvariantClient
 from openhands.security.invariant.parser import TraceElement, parse_element
+from openhands.utils.async_utils import sync_from_async
 
 
 class InvariantAnalyzer(SecurityAnalyzer):
@@ -146,7 +146,7 @@ class InvariantAnalyzer(SecurityAnalyzer):
             {'action': 'change_agent_state', 'args': {'agent_state': 'user_confirmed'}}
         )
         event_source = event.source if event.source else EventSource.AGENT
-        await asyncio.get_event_loop().run_in_executor(None, self.event_stream.add_event, new_event, event_source)
+        await sync_from_async(self.event_stream.add_event, new_event, event_source)
 
     async def security_risk(self, event: Action) -> ActionSecurityRisk:
         logger.info('Calling security_risk on InvariantAnalyzer')

+ 4 - 5
openhands/server/listen.py

@@ -14,6 +14,7 @@ from pathspec.patterns import GitWildMatchPattern
 from openhands.security.options import SecurityAnalyzers
 from openhands.server.data_models.feedback import FeedbackDataModel, store_feedback
 from openhands.storage import get_file_store
+from openhands.utils.async_utils import sync_from_async
 
 with warnings.catch_warnings():
     warnings.simplefilter('ignore')
@@ -439,9 +440,7 @@ async def list_files(request: Request, path: str | None = None):
             content={'error': 'Runtime not yet initialized'},
         )
     runtime: Runtime = request.state.session.agent_session.runtime
-    file_list = await asyncio.get_event_loop().run_in_executor(
-        None, runtime.list_files, path
-    )
+    file_list = await sync_from_async(runtime.list_files, path)
     if path:
         file_list = [os.path.join(path, f) for f in file_list]
 
@@ -490,7 +489,7 @@ async def select_file(file: str, request: Request):
 
     file = os.path.join(runtime.config.workspace_mount_path_in_sandbox, file)
     read_action = FileReadAction(file)
-    observation = await runtime.async_run_action(read_action)
+    observation = await sync_from_async(runtime.run_action, read_action)
 
     if isinstance(observation, FileReadObservation):
         content = observation.content
@@ -732,7 +731,7 @@ async def save_file(request: Request):
             runtime.config.workspace_mount_path_in_sandbox, file_path
         )
         write_action = FileWriteAction(file_path, content)
-        observation = await runtime.async_run_action(write_action)
+        observation = await sync_from_async(runtime.run_action, write_action)
 
         if isinstance(observation, FileWriteObservation):
             return JSONResponse(

+ 85 - 0
openhands/utils/async_utils.py

@@ -0,0 +1,85 @@
+import asyncio
+from concurrent import futures
+from concurrent.futures import ThreadPoolExecutor
+from typing import Callable, Coroutine, Iterable, List
+
+GENERAL_TIMEOUT: int = 15
+EXECUTOR = ThreadPoolExecutor()
+
+
+async def sync_from_async(fn: Callable, *args, **kwargs):
+    """
+    Run a blocking function in the event loop's default thread pool executor and
+    await its result.
+
+    Args:
+        fn: synchronous callable, invoked as fn(*args, **kwargs) on a worker thread.
+
+    Returns:
+        The value returned by fn. An exception raised by fn propagates to the awaiter.
+
+    Cancelling the awaiting task does not stop fn: synchronous code that is already
+    running on a worker thread cannot be interrupted.
+    """
+    # A coroutine always executes inside a running loop; get_running_loop() states
+    # that explicitly and never creates a loop implicitly.
+    loop = asyncio.get_running_loop()
+    return await loop.run_in_executor(None, lambda: fn(*args, **kwargs))
+
+
+def async_from_sync(
+    corofn: Callable, timeout: float = GENERAL_TIMEOUT, *args, **kwargs
+):
+    """
+    Run a coroutine function to completion from synchronous code and return its result.
+
+    The coroutine is executed with asyncio.run on a worker thread of the module-level
+    EXECUTOR (so it gets an event loop of its own) while the calling thread blocks.
+
+    Args:
+        corofn: coroutine function, called as corofn(*args, **kwargs).
+        timeout: maximum number of seconds to wait for the result. A falsy value
+            (0 or None) waits indefinitely. Because of the parameter order,
+            positional arguments for corofn can only follow an explicit timeout.
+
+    Returns:
+        Whatever the coroutine returns.
+
+    Raises:
+        concurrent.futures.TimeoutError: no result became available within timeout.
+            A job that has already started cannot be interrupted and keeps running
+            on its worker thread.
+        Exception: anything raised by the coroutine is re-raised in the caller.
+    """
+
+    async def arun():
+        # Create and await the coroutine inside the worker thread's own loop
+        return await corofn(*args, **kwargs)
+
+    def run():
+        # asyncio.run creates, installs and closes a fresh event loop itself,
+        # so no manual loop management is needed here.
+        return asyncio.run(arun())
+
+    future = EXECUTOR.submit(run)
+    try:
+        # result() is the call that blocks, so the timeout has to be applied here
+        return future.result(timeout=timeout or None)
+    except futures.TimeoutError:
+        # Only has an effect while the job is still queued; harmless otherwise
+        future.cancel()
+        raise
+
+
+async def wait_all(
+    iterable: Iterable[Coroutine], timeout: float = GENERAL_TIMEOUT
+) -> List:
+    """
+    Run all coroutines from the iterable concurrently (one task each) and wait for them.
+
+    Args:
+        iterable: the coroutines to run.
+        timeout: maximum number of seconds to wait for all tasks together.
+
+    Returns:
+        The task results, in the same order as the coroutines were supplied.
+        An empty iterable yields an empty list.
+
+    Raises:
+        asyncio.TimeoutError: some tasks were still pending after timeout; those
+            tasks are cancelled before raising.
+        Exception: if exactly one task failed, its exception is raised unchanged.
+        AsyncException: if several tasks failed; it carries all their exceptions.
+    """
+    tasks = [asyncio.create_task(c) for c in iterable]
+    if not tasks:
+        return []
+    _, pending = await asyncio.wait(tasks, timeout=timeout)
+    if pending:
+        for task in pending:
+            task.cancel()
+        raise asyncio.TimeoutError()
+    # Every task is done at this point; walk them in submission order so the
+    # result list lines up with the input.
+    results = []
+    errors = []
+    for task in tasks:
+        try:
+            results.append(task.result())
+        except Exception as e:
+            errors.append(e)
+    if len(errors) == 1:
+        raise errors[0]
+    if errors:
+        raise AsyncException(errors)
+    return results
+
+
+class AsyncException(Exception):
+    """
+    Aggregate error holding several exceptions; wait_all raises it when more than
+    one task failed.
+
+    Attributes are documented inline: `exceptions` is the collection passed in.
+    """
+
+    def __init__(self, exceptions):
+        # Forward to Exception so that `args` (and with it repr and pickling)
+        # reflects the wrapped exceptions instead of being empty.
+        super().__init__(exceptions)
+        self.exceptions = exceptions
+
+    def __str__(self):
+        # One line per wrapped exception
+        return '\n'.join(str(e) for e in self.exceptions)

+ 140 - 0
tests/unit/test_async_utils.py

@@ -0,0 +1,140 @@
+import asyncio
+
+import pytest
+
+from openhands.utils.async_utils import (
+    AsyncException,
+    async_from_sync,
+    sync_from_async,
+    wait_all,
+)
+
+
+@pytest.mark.asyncio
+async def test_await_all():
+    """wait_all runs its coroutines concurrently and returns results in input order."""
+
+    # Mock function demonstrating some calculation - always takes a minimum of 0.1 seconds
+    async def dummy(value: int):
+        await asyncio.sleep(0.1)
+        return value * 2
+
+    # wait for 10 calculations - serially this would take 1 second
+    coro = wait_all(dummy(i) for i in range(10))
+
+    # give the task only 0.3 seconds to complete; if it is still pending afterwards
+    # the calculations cannot have run in parallel
+    task = asyncio.create_task(coro)
+    _, pending = await asyncio.wait([task], timeout=0.3)
+    assert not pending
+
+    # results are returned in the order the coroutines were supplied, so they can be
+    # compared directly without sorting
+    results = list(await task)
+    expected = [i * 2 for i in range(10)]
+    assert expected == results
+
+
+@pytest.mark.asyncio
+async def test_await_all_single_exception():
+    """When exactly one coroutine fails, wait_all raises that coroutine's exception."""
+    # Mock function demonstrating some calculation - always takes a minimum of 0.1 seconds
+    async def dummy(value: int):
+        await asyncio.sleep(0.1)
+        if value == 1:
+            raise ValueError('Invalid value 1')  # Only the input 1 raises
+        return value * 2
+
+    # expect the single ValueError itself to be raised (not wrapped).
+    with pytest.raises(ValueError, match='Invalid value 1'):
+        await wait_all(dummy(i) for i in range(10))
+
+
+@pytest.mark.asyncio
+async def test_await_all_multi_exception():
+    """When several coroutines fail, wait_all raises one AsyncException holding them all."""
+
+    # Mock function demonstrating some calculation - always takes a minimum of 0.1 seconds
+    async def dummy(value: int):
+        await asyncio.sleep(0.1)
+        if value & 1:
+            raise ValueError(
+                f'Invalid value {value}'
+            )  # Throw an exception on every odd value
+        return value * 2
+
+    # expect an aggregate exception to be raised.
+    with pytest.raises(AsyncException) as exc_info:
+        await wait_all(dummy(i) for i in range(10))
+
+    # range(10) contains five odd values (1, 3, 5, 7, 9); each failure must be kept
+    collected = exc_info.value.exceptions
+    assert len(collected) == 5
+    assert all(isinstance(e, ValueError) for e in collected)
+
+
+@pytest.mark.asyncio
+async def test_await_all_timeout():
+    """wait_all raises on timeout and cancels the tasks that were still pending."""
+    result = 0
+
+    # Mock function updates a nonlocal variable after a delay
+    async def dummy(value: int):
+        nonlocal result
+        await asyncio.sleep(0.2)
+        result += value
+
+    # the tasks need 0.2 seconds but only get 0.1, so a timeout is expected.
+    with pytest.raises(asyncio.TimeoutError):
+        await wait_all((dummy(i) for i in range(10)), 0.1)
+
+    # Actually wait (the sleep must be awaited to take effect) past the point where
+    # the tasks would have finished, then check the shared result - this makes sure
+    # that pending tasks were cancelled.
+    await asyncio.sleep(0.2)
+    assert result == 0
+
+
+@pytest.mark.asyncio
+async def test_sync_from_async():
+    """sync_from_async forwards default, positional and keyword arguments."""
+
+    def double(value: int = 2):
+        return value * 2
+
+    # no arguments -> the default value is used
+    assert await sync_from_async(double) == 4
+    # positional argument
+    assert await sync_from_async(double, 3) == 6
+    # keyword argument
+    assert await sync_from_async(double, value=5) == 10
+
+
+@pytest.mark.asyncio
+async def test_sync_from_async_error():
+    """An exception raised by the synchronous function reaches the awaiter."""
+
+    def always_fails():
+        raise ValueError()
+
+    failing_call = sync_from_async(always_fails)
+    with pytest.raises(ValueError):
+        await failing_call
+
+
+def test_async_from_sync():
+    """async_from_sync hands the coroutine's return value back to synchronous code."""
+
+    async def double(value: int):
+        return value * 2
+
+    # 0 is the timeout argument, 3 is the positional argument passed on to double
+    assert async_from_sync(double, 0, 3) == 6
+
+
+def test_async_from_sync_error():
+    """An exception raised inside the coroutine is re-raised in the synchronous caller."""
+
+    async def always_fails(value: int):
+        raise ValueError()
+
+    # 0 is the timeout argument, 3 is the positional argument for always_fails
+    timeout, argument = 0, 3
+    with pytest.raises(ValueError):
+        async_from_sync(always_fails, timeout, argument)
+
+
+def test_async_from_sync_background_tasks():
+    """Record the order of events when the coroutine run by async_from_sync spawns a task."""
+    events = []
+
+    async def bg_task():
+        # This background task should finish after the dummy task
+        events.append('bg_started')
+        # NOTE(review): this sleep is not awaited, so it only creates a coroutine
+        # object and the task never actually pauses here. Awaiting it would
+        # presumably require async_from_sync to wait for leftover tasks - confirm
+        # before changing.
+        asyncio.sleep(0.2)
+        events.append('bg_finished')
+
+    async def dummy(value: int):
+        events.append('dummy_started')
+        # This coroutine kicks off a background task
+        asyncio.create_task(bg_task())
+        # NOTE(review): 'dummy_started' is recorded a second time here; possibly
+        # 'dummy_finished' was intended. The expected list below matches the code
+        # as written.
+        events.append('dummy_started')
+
+    # 0 is the timeout argument, 3 is the positional argument passed to dummy
+    async_from_sync(dummy, 0, 3)
+
+    # We check that the function did not return until all coroutines completed
+    # (Even though some of these were started as background tasks)
+    expected = ['dummy_started', 'dummy_started', 'bg_started', 'bg_finished']
+    assert expected == events