Просмотр исходного кода

arch: refractor eventstream into async (#2907)

* deprecating recall action

* fix integration tests

* fix integration tests

* refractor runtime to use async

* remove search memory

* rename .initialize to .ainit
Xingyao Wang 1 год назад
Родитель
Сommit
7e68de746d

+ 2 - 1
opendevin/core/main.py

@@ -81,6 +81,7 @@ async def run_agent_controller(
     # runtime and tools
     runtime_cls = get_runtime_cls(config.runtime)
     runtime = runtime_cls(event_stream=event_stream, sandbox=sandbox)
+    await runtime.ainit()
     runtime.init_sandbox_plugins(controller.agent.sandbox_plugins)
     runtime.init_runtime_tools(
         controller.agent.runtime_tools,
@@ -140,7 +141,7 @@ async def run_agent_controller(
 
     # close when done
     await controller.close()
-    runtime.close()
+    await runtime.close()
     return controller.get_state()
 
 

+ 69 - 81
opendevin/runtime/client/runtime.py

@@ -1,5 +1,4 @@
 import asyncio
-import atexit
 import uuid
 from typing import Optional
 
@@ -9,7 +8,7 @@ import tenacity
 
 from opendevin.core.config import config
 from opendevin.core.logger import opendevin_logger as logger
-from opendevin.events import EventSource, EventStream, EventStreamSubscriber
+from opendevin.events import EventSource, EventStream
 from opendevin.events.action import (
     BrowseInteractiveAction,
     BrowseURLAction,
@@ -38,8 +37,9 @@ from opendevin.runtime.utils.image_agnostic import get_od_sandbox_image
 
 
 class EventStreamRuntime(Runtime):
-    # This runtime will subscribe the event stream
-    # When receive an event, it will send the event to od-runtime-client which run inside the docker environment
+    """This runtime will subscribe the event stream.
+    When receive an event, it will send the event to od-runtime-client which run inside the docker environment.
+    """
 
     container_name_prefix = 'opendevin-sandbox-'
 
@@ -50,6 +50,7 @@ class EventStreamRuntime(Runtime):
         container_image: str | None = None,
         plugins: list[PluginRequirement] | None = None,
     ):
+        super().__init__(event_stream, sid)  # will initialize the event stream
         self._port = find_available_tcp_port()
         self.api_url = f'http://localhost:{self._port}'
         self.session: Optional[aiohttp.ClientSession] = None
@@ -57,25 +58,27 @@ class EventStreamRuntime(Runtime):
         self.instance_id = (
             sid + str(uuid.uuid4()) if sid is not None else str(uuid.uuid4())
         )
+        # TODO: We can switch to aiodocker when `get_od_sandbox_image` is updated to use aiodocker
         self.docker_client: docker.DockerClient = self._init_docker_client()
         self.container_image = (
             config.sandbox.container_image
             if container_image is None
             else container_image
         )
-        self.container_image = get_od_sandbox_image(
-            self.container_image, self.docker_client, is_eventstream_runtime=True
-        )
         self.container_name = self.container_name_prefix + self.instance_id
-        atexit.register(self.close)
 
-        # We don't need sandbox in this runtime, because it's equal to a websocket sandbox
-        self._init_event_stream(event_stream)
         self.plugins = plugins if plugins is not None else []
-        self.container = self._init_container(
+        self.container = None
+        self.action_semaphore = asyncio.Semaphore(1)  # Ensure one action at a time
+
+    async def ainit(self):
+        self.container_image = get_od_sandbox_image(
+            self.container_image, self.docker_client, is_eventstream_runtime=True
+        )
+        self.container = await self._init_container(
             self.sandbox_workspace_dir,
             mount_dir=config.workspace_mount_path,
-            plugins=plugins,
+            plugins=self.plugins,
         )
 
     @staticmethod
@@ -92,21 +95,13 @@ class EventStreamRuntime(Runtime):
         stop=tenacity.stop_after_attempt(5),
         wait=tenacity.wait_exponential(multiplier=1, min=4, max=60),
     )
-    def _init_container(
+    async def _init_container(
         self,
         sandbox_workspace_dir: str,
         mount_dir: str = config.workspace_mount_path,
         plugins: list[PluginRequirement] | None = None,
     ):
-        """Start a container and return the container object.
-
-        Args:
-            mount_dir: str: The directory (on host machine) to mount to the container
-            sandbox_workspace_dir: str: working directory in the container, also the target directory for the mount
-        """
-
         try:
-            # start the container
             logger.info(
                 f'Starting container with image: {self.container_image} and name: {self.container_name}'
             )
@@ -120,9 +115,8 @@ class EventStreamRuntime(Runtime):
                     'PYTHONUNBUFFERED=1 poetry run '
                     f'python -u -m opendevin.runtime.client.client {self._port} '
                     f'--working-dir {sandbox_workspace_dir} '
-                    f'--plugins {plugin_names} '
+                    f'--plugins {plugin_names}'
                 ),
-                # TODO: test it in mac and linux
                 network_mode='host',
                 working_dir='/opendevin/code/',
                 name=self.container_name,
@@ -134,13 +128,9 @@ class EventStreamRuntime(Runtime):
         except Exception as e:
             logger.error('Failed to start container')
             logger.exception(e)
-            self.close(close_client=False)
+            await self.close(close_client=False)
             raise e
 
-    def _init_event_stream(self, event_stream: EventStream):
-        self.event_stream = event_stream
-        self.event_stream.subscribe(EventStreamSubscriber.RUNTIME, self.on_event)
-
     async def _ensure_session(self):
         if self.session is None or self.session.closed:
             self.session = aiohttp.ClientSession()
@@ -167,14 +157,16 @@ class EventStreamRuntime(Runtime):
     def sandbox_workspace_dir(self):
         return config.workspace_mount_path_in_sandbox
 
-    def close(self, close_client: bool = True):
+    async def close(self, close_client: bool = True):
+        if self.session is not None and not self.session.closed:
+            await self.session.close()
+
         containers = self.docker_client.containers.list(all=True)
         for container in containers:
             try:
                 if container.name.startswith(self.container_name_prefix):
-                    # tail the logs before removing the container
                     logs = container.logs(tail=1000).decode('utf-8')
-                    logger.info(
+                    logger.debug(
                         f'==== Container logs ====\n{logs}\n==== End of container logs ===='
                     )
                     container.remove(force=True)
@@ -188,55 +180,50 @@ class EventStreamRuntime(Runtime):
         if isinstance(event, Action):
             logger.info(event, extra={'msg_type': 'ACTION'})
             observation = await self.run_action(event)
-            logger.info(observation, extra={'msg_type': 'OBSERVATION'})
             # observation._cause = event.id  # type: ignore[attr-defined]
+            logger.info(observation, extra={'msg_type': 'OBSERVATION'})
             source = event.source if event.source else EventSource.AGENT
             await self.event_stream.add_event(observation, source)
 
     async def run_action(self, action: Action, timeout: int = 600) -> Observation:
-        """
-        Run an action and return the resulting observation.
-        If the action is not runnable in any runtime, a NullObservation is returned.
-        If the action is not supported by the current runtime, an ErrorObservation is returned.
-        We will filter some action and execute in runtime. Pass others into od-runtime-client
-        """
-        if not action.runnable:
-            return NullObservation('')
-        action_type = action.action  # type: ignore[attr-defined]
-        if action_type not in ACTION_TYPE_TO_CLASS:
-            return ErrorObservation(f'Action {action_type} does not exist.')
-        if not hasattr(self, action_type):
-            return ErrorObservation(
-                f'Action {action_type} is not supported in the current runtime.'
-            )
-
-        # Run action in od-runtime-client
-        session = await self._ensure_session()
-        await self._wait_until_alive()
-        try:
-            async with session.post(
-                f'{self.api_url}/execute_action',
-                json={'action': event_to_dict(action)},
-                timeout=timeout,
-            ) as response:
-                if response.status == 200:
-                    output = await response.json()
-                    obs = observation_from_dict(output)
-                    obs._cause = action.id  # type: ignore[attr-defined]
-                    return obs
-                else:
-                    error_message = await response.text()
-                    logger.error(f'Error from server: {error_message}')
-                    obs = ErrorObservation(f'Command execution failed: {error_message}')
-        except asyncio.TimeoutError:
-            logger.error('No response received within the timeout period.')
-            obs = ErrorObservation('Command execution timed out')
-        except Exception as e:
-            logger.error(f'Error during command execution: {e}')
-            obs = ErrorObservation(f'Command execution failed: {str(e)}')
-        # TODO: fix ID problem, see comments https://github.com/OpenDevin/OpenDevin/pull/2603#discussion_r1668994137
-        obs._parent = action.id  # type: ignore[attr-defined]
-        return obs
+        async with self.action_semaphore:
+            if not action.runnable:
+                return NullObservation('')
+            action_type = action.action  # type: ignore[attr-defined]
+            if action_type not in ACTION_TYPE_TO_CLASS:
+                return ErrorObservation(f'Action {action_type} does not exist.')
+            if not hasattr(self, action_type):
+                return ErrorObservation(
+                    f'Action {action_type} is not supported in the current runtime.'
+                )
+
+            session = await self._ensure_session()
+            await self._wait_until_alive()
+            try:
+                async with session.post(
+                    f'{self.api_url}/execute_action',
+                    json={'action': event_to_dict(action)},
+                    timeout=timeout,
+                ) as response:
+                    if response.status == 200:
+                        output = await response.json()
+                        obs = observation_from_dict(output)
+                        obs._cause = action.id  # type: ignore[attr-defined]
+                        return obs
+                    else:
+                        error_message = await response.text()
+                        logger.error(f'Error from server: {error_message}')
+                        obs = ErrorObservation(
+                            f'Command execution failed: {error_message}'
+                        )
+            except asyncio.TimeoutError:
+                logger.error('No response received within the timeout period.')
+                obs = ErrorObservation('Command execution timed out')
+            except Exception as e:
+                logger.error(f'Error during command execution: {e}')
+                obs = ErrorObservation(f'Command execution failed: {str(e)}')
+            obs._parent = action.id  # type: ignore[attr-defined]
+            return obs
 
     async def run(self, action: CmdRunAction) -> Observation:
         return await self.run_action(action)
@@ -261,10 +248,6 @@ class EventStreamRuntime(Runtime):
     ############################################################################
 
     def get_working_directory(self):
-        # FIXME: this is not needed for the agent - we keep this
-        # method to be consistent with the other runtimes
-        # but eventually we will remove this method across all runtimes
-        # when we use EventStreamRuntime to replace the other sandbox-based runtime
         raise NotImplementedError(
             'This method is not implemented in the runtime client.'
         )
@@ -281,12 +264,13 @@ class EventStreamRuntime(Runtime):
         pass
 
 
-def test_run_command():
+async def test_run_command():
     sid = 'test'
     cli_session = 'main' + ('_' + sid if sid else '')
     event_stream = EventStream(cli_session)
     runtime = EventStreamRuntime(event_stream)
-    asyncio.run(runtime.run_action(CmdRunAction('ls -l')))
+    await runtime.ainit()
+    await runtime.run_action(CmdRunAction('ls -l'))
 
 
 async def test_event_stream():
@@ -299,6 +283,8 @@ async def test_event_stream():
         'ubuntu:22.04',
         plugins=[JupyterRequirement(), AgentSkillsRequirement()],
     )
+    await runtime.ainit()
+
     # Test run command
     action_cmd = CmdRunAction(command='ls -l')
     logger.info(action_cmd, extra={'msg_type': 'ACTION'})
@@ -340,6 +326,8 @@ async def test_event_stream():
         await runtime.run_action(action_browse), extra={'msg_type': 'OBSERVATION'}
     )
 
+    await runtime.close()
+
 
 if __name__ == '__main__':
     asyncio.run(test_event_stream())

+ 41 - 53
opendevin/runtime/runtime.py

@@ -1,9 +1,8 @@
+import asyncio
+import atexit
 from abc import abstractmethod
 from typing import Any, Optional
 
-from opendevin.core.config import config
-from opendevin.core.exceptions import BrowserInitException
-from opendevin.core.logger import opendevin_logger as logger
 from opendevin.events import EventStream, EventStreamSubscriber
 from opendevin.events.action import (
     Action,
@@ -23,27 +22,9 @@ from opendevin.events.observation import (
     RejectObservation,
 )
 from opendevin.events.serialization.action import ACTION_TYPE_TO_CLASS
-from opendevin.runtime import (
-    DockerSSHBox,
-    E2BBox,
-    LocalBox,
-    Sandbox,
-)
-from opendevin.runtime.browser.browser_env import BrowserEnv
 from opendevin.runtime.plugins import PluginRequirement
 from opendevin.runtime.tools import RuntimeTool
-from opendevin.storage import FileStore, InMemoryFileStore
-
-
-def create_sandbox(sid: str = 'default', box_type: str = 'ssh') -> Sandbox:
-    if box_type == 'local':
-        return LocalBox()
-    elif box_type == 'ssh':
-        return DockerSSHBox(sid=sid)
-    elif box_type == 'e2b':
-        return E2BBox()
-    else:
-        raise ValueError(f'Invalid sandbox type: {box_type}')
+from opendevin.storage import FileStore
 
 
 class Runtime:
@@ -57,32 +38,42 @@ class Runtime:
     sid: str
     file_store: FileStore
 
-    def __init__(
-        self,
-        event_stream: EventStream,
-        sid: str = 'default',
-        sandbox: Sandbox | None = None,
-    ):
+    def __init__(self, event_stream: EventStream, sid: str = 'default'):
         self.sid = sid
-        if sandbox is None:
-            self.sandbox = create_sandbox(sid, config.sandbox.box_type)
-            self._is_external_sandbox = False
-        else:
-            self.sandbox = sandbox
-            self._is_external_sandbox = True
-        self.browser: BrowserEnv | None = None
-        self.file_store = InMemoryFileStore()
         self.event_stream = event_stream
         self.event_stream.subscribe(EventStreamSubscriber.RUNTIME, self.on_event)
+        atexit.register(self.close_sync)
 
-    def close(self):
-        if not self._is_external_sandbox:
-            self.sandbox.close()
-        if self.browser is not None:
-            self.browser.close()
+    async def ainit(self) -> None:
+        """
+        Initialize the runtime (asynchronously).
+        This method should be called after the runtime's constructor.
+        """
+        pass
+
+    async def close(self) -> None:
+        pass
+
+    def close_sync(self) -> None:
+        try:
+            loop = asyncio.get_running_loop()
+        except RuntimeError:
+            # No running event loop, use asyncio.run()
+            asyncio.run(self.close())
+        else:
+            # There is a running event loop, create a task
+            if loop.is_running():
+                loop.create_task(self.close())
+            else:
+                loop.run_until_complete(self.close())
+
+    # ====================================================================
+    # Methods we plan to deprecate when we move to new EventStreamRuntime
+    # ====================================================================
 
     def init_sandbox_plugins(self, plugins: list[PluginRequirement]) -> None:
-        self.sandbox.init_plugins(plugins)
+        # TODO: deprecate this method when we move to the new EventStreamRuntime
+        raise NotImplementedError('This method is not implemented in the base class.')
 
     def init_runtime_tools(
         self,
@@ -90,17 +81,10 @@ class Runtime:
         runtime_tools_config: Optional[dict[RuntimeTool, Any]] = None,
         is_async: bool = True,
     ) -> None:
-        # if browser in runtime_tools, init it
-        if RuntimeTool.BROWSER in runtime_tools:
-            if runtime_tools_config is None:
-                runtime_tools_config = {}
-            browser_env_config = runtime_tools_config.get(RuntimeTool.BROWSER, {})
-            try:
-                self.browser = BrowserEnv(is_async=is_async, **browser_env_config)
-            except BrowserInitException:
-                logger.warn(
-                    'Failed to start browser environment, web browsing functionality will not work'
-                )
+        # TODO: deprecate this method when we move to the new EventStreamRuntime
+        raise NotImplementedError('This method is not implemented in the base class.')
+
+    # ====================================================================
 
     async def on_event(self, event: Event) -> None:
         if isinstance(event, Action):
@@ -139,6 +123,10 @@ class Runtime:
         observation._parent = action.id  # type: ignore[attr-defined]
         return observation
 
+    # ====================================================================
+    # Implement these methods in the subclass
+    # ====================================================================
+
     @abstractmethod
     async def run(self, action: CmdRunAction) -> Observation:
         pass

+ 62 - 2
opendevin/runtime/server/runtime.py

@@ -1,4 +1,8 @@
+from typing import Any, Optional
+
 from opendevin.core.config import config
+from opendevin.core.exceptions import BrowserInitException
+from opendevin.core.logger import opendevin_logger as logger
 from opendevin.events.action import (
     BrowseInteractiveAction,
     BrowseURLAction,
@@ -14,14 +18,33 @@ from opendevin.events.observation import (
     Observation,
 )
 from opendevin.events.stream import EventStream
-from opendevin.runtime import Sandbox
+from opendevin.runtime import (
+    DockerSSHBox,
+    E2BBox,
+    LocalBox,
+    Sandbox,
+)
+from opendevin.runtime.browser.browser_env import BrowserEnv
+from opendevin.runtime.plugins import PluginRequirement
 from opendevin.runtime.runtime import Runtime
+from opendevin.runtime.tools import RuntimeTool
 from opendevin.storage.local import LocalFileStore
 
 from ..browser import browse
 from .files import read_file, write_file
 
 
+def create_sandbox(sid: str = 'default', box_type: str = 'ssh') -> Sandbox:
+    if box_type == 'local':
+        return LocalBox()
+    elif box_type == 'ssh':
+        return DockerSSHBox(sid=sid)
+    elif box_type == 'e2b':
+        return E2BBox()
+    else:
+        raise ValueError(f'Invalid sandbox type: {box_type}')
+
+
 class ServerRuntime(Runtime):
     def __init__(
         self,
@@ -29,8 +52,45 @@ class ServerRuntime(Runtime):
         sid: str = 'default',
         sandbox: Sandbox | None = None,
     ):
-        super().__init__(event_stream, sid, sandbox)
+        super().__init__(event_stream, sid)
         self.file_store = LocalFileStore(config.workspace_base)
+        if sandbox is None:
+            self.sandbox = create_sandbox(sid, config.sandbox.box_type)
+            self._is_external_sandbox = False
+        else:
+            self.sandbox = sandbox
+            self._is_external_sandbox = True
+        self.browser: BrowserEnv | None = None
+
+    async def ainit(self) -> None:
+        pass
+
+    async def close(self):
+        if not self._is_external_sandbox:
+            self.sandbox.close()
+        if self.browser is not None:
+            self.browser.close()
+
+    def init_sandbox_plugins(self, plugins: list[PluginRequirement]) -> None:
+        self.sandbox.init_plugins(plugins)
+
+    def init_runtime_tools(
+        self,
+        runtime_tools: list[RuntimeTool],
+        runtime_tools_config: Optional[dict[RuntimeTool, Any]] = None,
+        is_async: bool = True,
+    ) -> None:
+        # if browser in runtime_tools, init it
+        if RuntimeTool.BROWSER in runtime_tools:
+            if runtime_tools_config is None:
+                runtime_tools_config = {}
+            browser_env_config = runtime_tools_config.get(RuntimeTool.BROWSER, {})
+            try:
+                self.browser = BrowserEnv(is_async=is_async, **browser_env_config)
+            except BrowserInitException:
+                logger.warn(
+                    'Failed to start browser environment, web browsing functionality will not work'
+                )
 
     async def run(self, action: CmdRunAction) -> Observation:
         return self._run_command(action.command)

+ 7 - 2
opendevin/server/session/agent.py

@@ -11,6 +11,7 @@ from opendevin.events.stream import EventStream
 from opendevin.llm.llm import LLM
 from opendevin.runtime import DockerSSHBox, get_runtime_cls
 from opendevin.runtime.runtime import Runtime
+from opendevin.runtime.server.runtime import ServerRuntime
 
 
 class AgentSession:
@@ -52,7 +53,7 @@ class AgentSession:
             end_state.save_to_session(self.sid)
             await self.controller.close()
         if self.runtime is not None:
-            self.runtime.close()
+            await self.runtime.close()
         self._closed = True
 
     async def _create_runtime(self):
@@ -62,6 +63,7 @@ class AgentSession:
         logger.info(f'Using runtime: {config.runtime}')
         runtime_cls = get_runtime_cls(config.runtime)
         self.runtime = runtime_cls(self.event_stream, self.sid)
+        await self.runtime.ainit()
 
     async def _create_controller(self, start_event: dict):
         """Creates an AgentController instance.
@@ -92,7 +94,10 @@ class AgentSession:
         llm = LLM(model=model, api_key=api_key, base_url=api_base)
         agent = Agent.get_cls(agent_cls)(llm)
         if isinstance(agent, CodeActAgent):
-            if not self.runtime or not isinstance(self.runtime.sandbox, DockerSSHBox):
+            if not self.runtime or not (
+                isinstance(self.runtime, ServerRuntime)
+                and isinstance(self.runtime.sandbox, DockerSSHBox)
+            ):
                 logger.warning(
                     'CodeActAgent requires DockerSSHBox as sandbox! Using other sandbox that are not stateful'
                     ' LocalBox will not work properly.'