Преглед изворни кода

Feat add agent manager (#904)

* feat: add agent manager to manage all agents;

* extract the host ssh port to prevent conflict.

* clean all containers with prefix is sandbox-

* merge from upstream/main

* merge from upstream/main

* Update frontend/src/state/settingsSlice.ts

* Update opendevin/sandbox/ssh_box.py

* Update opendevin/sandbox/exec_box.py

---------

Co-authored-by: Robert Brennan <accounts@rbren.io>
Leo пре 1 година
родитељ
комит
494a1b6872

+ 9 - 9
opendevin/agent.py

@@ -6,6 +6,7 @@ if TYPE_CHECKING:
     from opendevin.state import State
 from opendevin.llm.llm import LLM
 
+
 class Agent(ABC):
     """
     This abstract base class is an general interface for an agent dedicated to
@@ -14,11 +15,11 @@ class Agent(ABC):
     It tracks the execution status and maintains a history of interactions.
     """
 
-    _registry: Dict[str, Type["Agent"]] = {}
+    _registry: Dict[str, Type['Agent']] = {}
 
     def __init__(
-        self,
-        llm: LLM,
+            self,
+            llm: LLM,
     ):
         self.llm = llm
         self._complete = False
@@ -34,7 +35,7 @@ class Agent(ABC):
         return self._complete
 
     @abstractmethod
-    def step(self, state: "State") -> "Action":
+    def step(self, state: 'State') -> 'Action':
         """
         Starts the execution of the assigned instruction. This method should
         be implemented by subclasses to define the specific execution logic.
@@ -63,7 +64,7 @@ class Agent(ABC):
         self._complete = False
 
     @classmethod
-    def register(cls, name: str, agent_cls: Type["Agent"]):
+    def register(cls, name: str, agent_cls: Type['Agent']):
         """
         Registers an agent class in the registry.
 
@@ -76,7 +77,7 @@ class Agent(ABC):
         cls._registry[name] = agent_cls
 
     @classmethod
-    def get_cls(cls, name: str) -> Type["Agent"]:
+    def get_cls(cls, name: str) -> Type['Agent']:
         """
         Retrieves an agent class from the registry.
 
@@ -91,11 +92,10 @@ class Agent(ABC):
         return cls._registry[name]
 
     @classmethod
-    def listAgents(cls) -> list[str]:
+    def list_agents(cls) -> list[str]:
         """
         Retrieves the list of all agent names from the registry.
         """
         if not bool(cls._registry):
-            raise ValueError("No agent class registered.")
+            raise ValueError('No agent class registered.')
         return list(cls._registry.keys())
-        

+ 63 - 61
opendevin/controller/agent_controller.py

@@ -5,9 +5,7 @@ from typing import List, Callable, Literal, Mapping, Awaitable, Any, cast
 
 from termcolor import colored
 
-from opendevin.plan import Plan
-from opendevin.state import State
-from opendevin.agent import Agent
+from opendevin import config
 from opendevin.action import (
     Action,
     NullAction,
@@ -15,53 +13,52 @@ from opendevin.action import (
     AddTaskAction,
     ModifyTaskAction,
 )
-from opendevin.observation import Observation, AgentErrorObservation, NullObservation
-from opendevin import config
+from opendevin.agent import Agent
 from opendevin.logger import opendevin_logger as logger
-
+from opendevin.observation import Observation, AgentErrorObservation, NullObservation
+from opendevin.plan import Plan
+from opendevin.state import State
 from .command_manager import CommandManager
 
-
 ColorType = Literal[
-    'red',
-    'green',
-    'yellow',
-    'blue',
-    'magenta',
-    'cyan',
-    'light_grey',
-    'dark_grey',
-    'light_red',
-    'light_green',
-    'light_yellow',
-    'light_blue',
-    'light_magenta',
-    'light_cyan',
-    'white',
+    "red",
+    "green",
+    "yellow",
+    "blue",
+    "magenta",
+    "cyan",
+    "light_grey",
+    "dark_grey",
+    "light_red",
+    "light_green",
+    "light_yellow",
+    "light_blue",
+    "light_magenta",
+    "light_cyan",
+    "white",
 ]
 
-
 DISABLE_COLOR_PRINTING = (
-    config.get_or_default('DISABLE_COLOR', 'false').lower() == 'true'
+    config.get_or_default("DISABLE_COLOR", "false").lower() == "true"
 )
-MAX_ITERATIONS = config.get('MAX_ITERATIONS')
+MAX_ITERATIONS = config.get("MAX_ITERATIONS")
 
 
-def print_with_color(text: Any, print_type: str = 'INFO'):
+def print_with_color(text: Any, print_type: str = "INFO"):
     TYPE_TO_COLOR: Mapping[str, ColorType] = {
-        'BACKGROUND LOG': 'blue',
-        'ACTION': 'green',
-        'OBSERVATION': 'yellow',
-        'INFO': 'cyan',
-        'ERROR': 'red',
-        'PLAN': 'light_magenta',
+        "BACKGROUND LOG": "blue",
+        "ACTION": "green",
+        "OBSERVATION": "yellow",
+        "INFO": "cyan",
+        "ERROR": "red",
+        "PLAN": "light_magenta",
     }
-    color = TYPE_TO_COLOR.get(print_type.upper(), TYPE_TO_COLOR['INFO'])
+    color = TYPE_TO_COLOR.get(print_type.upper(), TYPE_TO_COLOR["INFO"])
     if DISABLE_COLOR_PRINTING:
         print(f"\n{print_type.upper()}:\n{str(text)}", flush=True)
     else:
         print(
-            colored(f"\n{print_type.upper()}:\n", color, attrs=['bold'])
+            colored(f"\n{print_type.upper()}:\n", color, attrs=["bold"])
             + colored(str(text), color),
             flush=True,
         )
@@ -69,22 +66,26 @@ def print_with_color(text: Any, print_type: str = 'INFO'):
 
 class AgentController:
     id: str
+    agent: Agent
+    max_iterations: int
+    workdir: str
+    command_manager: CommandManager
+    callbacks: List[Callable]
 
     def __init__(
         self,
         agent: Agent,
         workdir: str,
-        id: str = '',
+        sid: str = "",
         max_iterations: int = MAX_ITERATIONS,
         container_image: str | None = None,
         callbacks: List[Callable] = [],
     ):
-        self.id = id
+        self.id = sid
         self.agent = agent
         self.max_iterations = max_iterations
         self.workdir = workdir
-        self.command_manager = CommandManager(
-            self.id, workdir, container_image)
+        self.command_manager = CommandManager(self.id, workdir, container_image)
         self.callbacks = callbacks
 
     def update_state_for_step(self, i):
@@ -96,9 +97,9 @@ class AgentController:
 
     def add_history(self, action: Action, observation: Observation):
         if not isinstance(action, Action):
-            raise ValueError('action must be an instance of Action')
+            raise ValueError("action must be an instance of Action")
         if not isinstance(observation, Observation):
-            raise ValueError('observation must be an instance of Observation')
+            raise ValueError("observation must be an instance of Observation")
         self.state.history.append((action, observation))
         self.state.updated_info.append((action, observation))
 
@@ -110,38 +111,41 @@ class AgentController:
             try:
                 finished = await self.step(i)
             except Exception as e:
-                logger.error('Error in loop', exc_info=True)
+                logger.error("Error in loop", exc_info=True)
                 raise e
             if finished:
                 break
         if not finished:
-            logger.info('Exited before finishing the task.')
+            logger.info("Exited before finishing the task.")
 
     async def step(self, i: int):
-        print('\n\n==============', flush=True)
-        print('STEP', i, flush=True)
-        print_with_color(self.state.plan.main_goal, 'PLAN')
+        print("\n\n==============", flush=True)
+        print("STEP", i, flush=True)
+        print_with_color(self.state.plan.main_goal, "PLAN")
 
         log_obs = self.command_manager.get_background_obs()
         for obs in log_obs:
             self.add_history(NullAction(), obs)
             await self._run_callbacks(obs)
-            print_with_color(obs, 'BACKGROUND LOG')
+            print_with_color(obs, "BACKGROUND LOG")
 
         self.update_state_for_step(i)
         action: Action = NullAction()
-        observation: Observation = NullObservation('')
+        observation: Observation = NullObservation("")
         try:
             action = self.agent.step(self.state)
             if action is None:
-                raise ValueError('Agent must return an action')
-            print_with_color(action, 'ACTION')
+                raise ValueError("Agent must return an action")
+            print_with_color(action, "ACTION")
         except Exception as e:
             observation = AgentErrorObservation(str(e))
-            print_with_color(observation, 'ERROR')
+            print_with_color(observation, "ERROR")
             traceback.print_exc()
             # TODO Change to more robust error handling
-            if 'The api_key client option must be set' in observation.content or 'Incorrect API key provided:' in observation.content:
+            if (
+                "The api_key client option must be set" in observation.content
+                or "Incorrect API key provided:" in observation.content
+            ):
                 raise
         self.update_state_after_step()
 
@@ -149,23 +153,22 @@ class AgentController:
 
         finished = isinstance(action, AgentFinishAction)
         if finished:
-            print_with_color(action, 'INFO')
+            print_with_color(action, "INFO")
             return True
 
         if isinstance(action, AddTaskAction):
             try:
-                self.state.plan.add_subtask(
-                    action.parent, action.goal, action.subtasks)
+                self.state.plan.add_subtask(action.parent, action.goal, action.subtasks)
             except Exception as e:
                 observation = AgentErrorObservation(str(e))
-                print_with_color(observation, 'ERROR')
+                print_with_color(observation, "ERROR")
                 traceback.print_exc()
         elif isinstance(action, ModifyTaskAction):
             try:
                 self.state.plan.set_subtask_state(action.id, action.state)
             except Exception as e:
                 observation = AgentErrorObservation(str(e))
-                print_with_color(observation, 'ERROR')
+                print_with_color(observation, "ERROR")
                 traceback.print_exc()
 
         if action.executable:
@@ -176,11 +179,11 @@ class AgentController:
                     observation = action.run(self)
             except Exception as e:
                 observation = AgentErrorObservation(str(e))
-                print_with_color(observation, 'ERROR')
+                print_with_color(observation, "ERROR")
                 traceback.print_exc()
 
         if not isinstance(observation, NullObservation):
-            print_with_color(observation, 'OBSERVATION')
+            print_with_color(observation, "OBSERVATION")
 
         self.add_history(action, observation)
         await self._run_callbacks(observation)
@@ -192,9 +195,8 @@ class AgentController:
             idx = self.callbacks.index(callback)
             try:
                 callback(event)
-            except Exception:
-                logger.exception('Callback error: %s', idx)
-                pass
+            except Exception as e:
+                logger.exception(f"Callback error: {e}, idx: {idx}")
         await asyncio.sleep(
             0.001
         )  # Give back control for a tick, so we can await in callbacks

+ 13 - 9
opendevin/controller/command_manager.py

@@ -1,26 +1,30 @@
 from typing import List
+
+from opendevin import config
 from opendevin.observation import CmdOutputObservation
 from opendevin.sandbox import DockerExecBox, DockerSSHBox, Sandbox
-from opendevin import config
+from opendevin.schema import ConfigType
 
 
 class CommandManager:
+    id: str
+    directory: str
     shell: Sandbox
 
     def __init__(
-        self,
-        id: str,
-        dir: str,
-        container_image: str | None = None,
+            self,
+            sid: str,
+            directory: str,
+            container_image: str | None = None,
     ):
-        self.directory = dir
-        if config.get('SANDBOX_TYPE').lower() == 'exec':
+        self.directory = directory
+        if config.get(ConfigType.SANDBOX_TYPE).lower() == 'exec':
             self.shell = DockerExecBox(
-                id=(id or 'default'), workspace_dir=dir, container_image=container_image
+                sid=(sid or 'default'), workspace_dir=directory, container_image=container_image
             )
         else:
             self.shell = DockerSSHBox(
-                id=(id or 'default'), workspace_dir=dir, container_image=container_image
+                sid=(sid or 'default'), workspace_dir=directory, container_image=container_image
             )
 
     def run_command(self, command: str, background=False) -> CmdOutputObservation:

+ 48 - 56
opendevin/sandbox/exec_box.py

@@ -1,4 +1,5 @@
 import atexit
+import concurrent.futures
 import os
 import sys
 import time
@@ -7,19 +8,18 @@ from collections import namedtuple
 from typing import Dict, List, Tuple
 
 import docker
-import concurrent.futures
 
 from opendevin import config
 from opendevin.logger import opendevin_logger as logger
 from opendevin.sandbox.sandbox import Sandbox, BackgroundCommand
+from opendevin.schema import ConfigType
 
 InputType = namedtuple('InputType', ['content'])
 OutputType = namedtuple('OutputType', ['content'])
 
-DIRECTORY_REWRITE = config.get(
-    'DIRECTORY_REWRITE'
-)  # helpful for docker-in-docker scenarios
-CONTAINER_IMAGE = config.get('SANDBOX_CONTAINER_IMAGE')
+# helpful for docker-in-docker scenarios
+DIRECTORY_REWRITE = config.get(ConfigType.DIRECTORY_REWRITE)
+CONTAINER_IMAGE = config.get(ConfigType.SANDBOX_CONTAINER_IMAGE)
 
 # FIXME: On some containers, the devin user doesn't have enough permission, e.g. to install packages
 # How do we make this more flexible?
@@ -32,21 +32,32 @@ elif hasattr(os, 'getuid'):
 
 
 class DockerExecBox(Sandbox):
-    closed = False
+    instance_id: str
+    container_image: str
+    container_name_prefix = 'opendevin-sandbox-'
+    container_name: str
+    container: docker.models.containers.Container
+    docker_client: docker.DockerClient
+
     cur_background_id = 0
     background_commands: Dict[int, BackgroundCommand] = {}
 
     def __init__(
-        self,
-        workspace_dir: str | None = None,
-        container_image: str | None = None,
-        timeout: int = 120,
-        id: str | None = None,
+            self,
+            workspace_dir: str | None = None,
+            container_image: str | None = None,
+            timeout: int = 120,
+            sid: str | None = None,
     ):
-        if id is not None:
-            self.instance_id = id
-        else:
-            self.instance_id = str(uuid.uuid4())
+        # Initialize docker client. Throws an exception if Docker is not reachable.
+        try:
+            self.docker_client = docker.from_env()
+        except Exception as ex:
+            logger.exception(
+                'Please check Docker is running using `docker ps`.', exc_info=False)
+            raise ex
+
+        self.instance_id = sid if sid is not None else str(uuid.uuid4())
         if workspace_dir is not None:
             os.makedirs(workspace_dir, exist_ok=True)
             # expand to absolute path
@@ -67,20 +78,16 @@ class DockerExecBox(Sandbox):
         # if it is too short, the container may still waiting for previous
         # command to finish (e.g. apt-get update)
         # if it is too long, the user may have to wait for a unnecessary long time
-        self.timeout: int = timeout
-
-        if container_image is None:
-            self.container_image = CONTAINER_IMAGE
-        else:
-            self.container_image = container_image
+        self.timeout = timeout
+        self.container_image = CONTAINER_IMAGE if container_image is None else container_image
+        self.container_name = self.container_name_prefix + self.instance_id
 
-        self.container_name = f'sandbox-{self.instance_id}'
+        # always restart the container, cuz the initial be regarded as a new session
+        self.restart_docker_container()
 
-        if not self.is_container_running():
-            self.restart_docker_container()
         if RUN_AS_DEVIN:
             self.setup_devin_user()
-        atexit.register(self.cleanup)
+        atexit.register(self.close)
 
     def setup_devin_user(self):
         exit_code, logs = self.container.exec_run(
@@ -159,22 +166,9 @@ class DockerExecBox(Sandbox):
         self.background_commands.pop(id)
         return bg_cmd
 
-    def close(self):
-        self.stop_docker_container()
-        self.closed = True
-
     def stop_docker_container(self):
-
-        # Initialize docker client. Throws an exception if Docker is not reachable.
-        try:
-            docker_client = docker.from_env()
-        except docker.errors.DockerException as e:
-            logger.exception(
-                'Please check Docker is running using `docker ps`.', exc_info=False)
-            raise e
-
         try:
-            container = docker_client.containers.get(self.container_name)
+            container = self.docker_client.containers.get(self.container_name)
             container.stop()
             container.remove()
             elapsed = 0
@@ -183,14 +177,14 @@ class DockerExecBox(Sandbox):
                 elapsed += 1
                 if elapsed > self.timeout:
                     break
-                container = docker_client.containers.get(self.container_name)
+                container = self.docker_client.containers.get(
+                    self.container_name)
         except docker.errors.NotFound:
             pass
 
     def is_container_running(self):
         try:
-            docker_client = docker.from_env()
-            container = docker_client.containers.get(self.container_name)
+            container = self.docker_client.containers.get(self.container_name)
             if container.status == 'running':
                 self.container = container
                 return True
@@ -207,11 +201,8 @@ class DockerExecBox(Sandbox):
             raise e
 
         try:
-            # Initialize docker client. Throws an exception if Docker is not reachable.
-            docker_client = docker.from_env()
-
             # start the container
-            self.container = docker_client.containers.run(
+            self.container = self.docker_client.containers.run(
                 self.container_image,
                 command='tail -f /dev/null',
                 network_mode='host',
@@ -222,9 +213,9 @@ class DockerExecBox(Sandbox):
                     'bind': '/workspace', 'mode': 'rw'}},
             )
             logger.info('Container started')
-        except Exception as e:
+        except Exception as ex:
             logger.exception('Failed to start container', exc_info=False)
-            raise e
+            raise ex
 
         # wait for container to be ready
         elapsed = 0
@@ -236,20 +227,21 @@ class DockerExecBox(Sandbox):
                 break
             time.sleep(1)
             elapsed += 1
-            self.container = docker_client.containers.get(self.container_name)
+            self.container = self.docker_client.containers.get(self.container_name)
             if elapsed > self.timeout:
                 break
         if self.container.status != 'running':
             raise Exception('Failed to start container')
 
     # clean up the container, cannot do it in __del__ because the python interpreter is already shutting down
-    def cleanup(self):
-        if self.closed:
-            return
-        try:
-            self.container.remove(force=True)
-        except docker.errors.NotFound:
-            pass
+    def close(self):
+        containers = self.docker_client.containers.list(all=True)
+        for container in containers:
+            try:
+                if container.name.startswith(self.container_name_prefix):
+                    container.remove(force=True)
+            except docker.errors.NotFound:
+                pass
 
 
 if __name__ == '__main__':

+ 3 - 3
opendevin/sandbox/sandbox.py

@@ -1,8 +1,8 @@
 import select
 import sys
-from typing import Tuple
 from abc import ABC, abstractmethod
 from typing import Dict
+from typing import Tuple
 
 
 class BackgroundCommand:
@@ -28,8 +28,8 @@ class BackgroundCommand:
             msg_type = prefix[0:1]
             padding = prefix[1:4]
             if (
-                msg_type in [b'\x00', b'\x01', b'\x02', b'\x03']
-                and padding == b'\x00\x00\x00'
+                    msg_type in [b'\x00', b'\x01', b'\x02', b'\x03']
+                    and padding == b'\x00\x00\x00'
             ):
                 msg_length = int.from_bytes(prefix[4:8], byteorder=byte_order)
                 res += logs[i + 8: i + 8 + msg_length]

+ 65 - 66
opendevin/sandbox/ssh_box.py

@@ -1,26 +1,27 @@
 import atexit
 import os
+import platform
 import sys
 import time
 import uuid
-import platform
-from pexpect import pxssh
 from collections import namedtuple
 from typing import Dict, List, Tuple, Union
 
 import docker
+from pexpect import pxssh
 
 from opendevin import config
 from opendevin.logger import opendevin_logger as logger
 from opendevin.sandbox.sandbox import Sandbox, BackgroundCommand
+from opendevin.schema import ConfigType
+from opendevin.utils import find_available_tcp_port
 
 InputType = namedtuple('InputType', ['content'])
 OutputType = namedtuple('OutputType', ['content'])
 
-DIRECTORY_REWRITE = config.get(
-    'DIRECTORY_REWRITE'
-)  # helpful for docker-in-docker scenarios
-CONTAINER_IMAGE = config.get('SANDBOX_CONTAINER_IMAGE')
+# helpful for docker-in-docker scenarios
+DIRECTORY_REWRITE = config.get(ConfigType.DIRECTORY_REWRITE)
+CONTAINER_IMAGE = config.get(ConfigType.SANDBOX_CONTAINER_IMAGE)
 
 # FIXME: On some containers, the devin user doesn't have enough permission, e.g. to install packages
 # How do we make this more flexible?
@@ -33,21 +34,35 @@ elif hasattr(os, 'getuid'):
 
 
 class DockerSSHBox(Sandbox):
-    closed = False
+    instance_id: str
+    container_image: str
+    container_name_prefix = 'opendevin-sandbox-'
+    container_name: str
+    container: docker.models.containers.Container
+    docker_client: docker.DockerClient
+
+    _ssh_password: str
+    _ssh_port: int
+
     cur_background_id = 0
     background_commands: Dict[int, BackgroundCommand] = {}
 
     def __init__(
-        self,
-        workspace_dir: str | None = None,
-        container_image: str | None = None,
-        timeout: int = 120,
-        id: str | None = None,
+            self,
+            workspace_dir: str | None = None,
+            container_image: str | None = None,
+            timeout: int = 120,
+            sid: str | None = None,
     ):
-        if id is not None:
-            self.instance_id = id
-        else:
-            self.instance_id = str(uuid.uuid4())
+        # Initialize docker client. Throws an exception if Docker is not reachable.
+        try:
+            self.docker_client = docker.from_env()
+        except Exception as ex:
+            logger.exception(
+                'Please check Docker is running using `docker ps`.', exc_info=False)
+            raise ex
+
+        self.instance_id = sid if sid is not None else str(uuid.uuid4())
         if workspace_dir is not None:
             os.makedirs(workspace_dir, exist_ok=True)
             # expand to absolute path
@@ -68,22 +83,20 @@ class DockerSSHBox(Sandbox):
         # if it is too short, the container may still waiting for previous
         # command to finish (e.g. apt-get update)
         # if it is too long, the user may have to wait for a unnecessary long time
-        self.timeout: int = timeout
+        self.timeout = timeout
+        self.container_image = CONTAINER_IMAGE if container_image is None else container_image
+        self.container_name = self.container_name_prefix + self.instance_id
 
-        if container_image is None:
-            self.container_image = CONTAINER_IMAGE
-        else:
-            self.container_image = container_image
-
-        self.container_name = f'sandbox-{self.instance_id}'
-
-        if not self.is_container_running():
-            self.restart_docker_container()
         # set up random user password
         self._ssh_password = str(uuid.uuid4())
+        self._ssh_port = find_available_tcp_port()
+
+        # always restart the container, cuz the initial be regarded as a new session
+        self.restart_docker_container()
+
         self.setup_user()
         self.start_ssh_session()
-        atexit.register(self.cleanup)
+        atexit.register(self.close)
 
     def setup_user(self):
 
@@ -91,7 +104,7 @@ class DockerSSHBox(Sandbox):
         # TODO(sandbox): add this line in the Dockerfile for next minor version of docker image
         exit_code, logs = self.container.exec_run(
             ['/bin/bash', '-c',
-                r"echo '%sudo ALL=(ALL) NOPASSWD:ALL' >> /etc/sudoers"],
+             r"echo '%sudo ALL=(ALL) NOPASSWD:ALL' >> /etc/sudoers"],
             workdir='/workspace',
         )
         if exit_code != 0:
@@ -116,7 +129,7 @@ class DockerSSHBox(Sandbox):
         # Create the opendevin user
         exit_code, logs = self.container.exec_run(
             ['/bin/bash', '-c',
-                f'useradd -rm -d /home/opendevin -s /bin/bash -g root -G sudo -u {USER_ID} opendevin'],
+             f'useradd -rm -d /home/opendevin -s /bin/bash -g root -G sudo -u {USER_ID} opendevin'],
             workdir='/workspace',
         )
         if exit_code != 0:
@@ -124,7 +137,7 @@ class DockerSSHBox(Sandbox):
                 f'Failed to create opendevin user in sandbox: {logs}')
         exit_code, logs = self.container.exec_run(
             ['/bin/bash', '-c',
-                f"echo 'opendevin:{self._ssh_password}' | chpasswd"],
+             f"echo 'opendevin:{self._ssh_password}' | chpasswd"],
             workdir='/workspace',
         )
         if exit_code != 0:
@@ -134,7 +147,7 @@ class DockerSSHBox(Sandbox):
             exit_code, logs = self.container.exec_run(
                 # change password for root
                 ['/bin/bash', '-c',
-                    f"echo 'root:{self._ssh_password}' | chpasswd"],
+                 f"echo 'root:{self._ssh_password}' | chpasswd"],
                 workdir='/workspace',
             )
             if exit_code != 0:
@@ -158,7 +171,7 @@ class DockerSSHBox(Sandbox):
             # autopep8: off
             f"Connecting to {username}@{hostname} via ssh. If you encounter any issues, you can try `ssh -v -p 2222 {username}@{hostname}` with the password '{self._ssh_password}' and report the issue on GitHub."
         )
-        self.ssh.login(hostname, username, self._ssh_password, port=2222)
+        self.ssh.login(hostname, username, self._ssh_password, port=self._ssh_port)
 
         # Fix: https://github.com/pexpect/pexpect/issues/669
         self.ssh.sendline("bind 'set enable-bracketed-paste off'")
@@ -235,22 +248,9 @@ class DockerSSHBox(Sandbox):
         self.background_commands.pop(id)
         return bg_cmd
 
-    def close(self):
-        self.stop_docker_container()
-        self.closed = True
-
     def stop_docker_container(self):
-
-        # Initialize docker client. Throws an exception if Docker is not reachable.
         try:
-            docker_client = docker.from_env()
-        except docker.errors.DockerException as e:
-            logger.exception(
-                'Please check Docker is running using `docker ps`.', exc_info=False)
-            raise e
-
-        try:
-            container = docker_client.containers.get(self.container_name)
+            container = self.docker_client.containers.get(self.container_name)
             container.stop()
             container.remove()
             elapsed = 0
@@ -259,14 +259,14 @@ class DockerSSHBox(Sandbox):
                 elapsed += 1
                 if elapsed > self.timeout:
                     break
-                container = docker_client.containers.get(self.container_name)
+                container = self.docker_client.containers.get(
+                    self.container_name)
         except docker.errors.NotFound:
             pass
 
     def is_container_running(self):
         try:
-            docker_client = docker.from_env()
-            container = docker_client.containers.get(self.container_name)
+            container = self.docker_client.containers.get(self.container_name)
             if container.status == 'running':
                 self.container = container
                 return True
@@ -278,20 +278,17 @@ class DockerSSHBox(Sandbox):
         try:
             self.stop_docker_container()
             logger.info('Container stopped')
-        except docker.errors.DockerException as e:
+        except docker.errors.DockerException as ex:
             logger.exception('Failed to stop container', exc_info=False)
-            raise e
+            raise ex
 
         try:
-            # Initialize docker client. Throws an exception if Docker is not reachable.
-            docker_client = docker.from_env()
-
             network_kwargs: Dict[str, Union[str, Dict[str, int]]] = {}
             if platform.system() == 'Linux':
                 network_kwargs['network_mode'] = 'host'
             elif platform.system() == 'Darwin':
                 # FIXME: This is a temporary workaround for Mac OS
-                network_kwargs['ports'] = {'2222/tcp': 2222}
+                network_kwargs['ports'] = {'2222/tcp': self._ssh_port}
                 logger.warning(
                     ('Using port forwarding for Mac OS. '
                      'Server started by OpenDevin will not be accessible from the host machine at the moment. '
@@ -300,7 +297,7 @@ class DockerSSHBox(Sandbox):
                 )
 
             # start the container
-            self.container = docker_client.containers.run(
+            self.container = self.docker_client.containers.run(
                 self.container_image,
                 # allow root login
                 command="/usr/sbin/sshd -D -p 2222 -o 'PermitRootLogin=yes'",
@@ -313,9 +310,9 @@ class DockerSSHBox(Sandbox):
                     'bind': '/workspace', 'mode': 'rw'}},
             )
             logger.info('Container started')
-        except Exception as e:
+        except Exception as ex:
             logger.exception('Failed to start container', exc_info=False)
-            raise e
+            raise ex
 
         # wait for container to be ready
         elapsed = 0
@@ -327,7 +324,8 @@ class DockerSSHBox(Sandbox):
                 break
             time.sleep(1)
             elapsed += 1
-            self.container = docker_client.containers.get(self.container_name)
+            self.container = self.docker_client.containers.get(
+                self.container_name)
             logger.info(
                 f'waiting for container to start: {elapsed}, container status: {self.container.status}')
             if elapsed > self.timeout:
@@ -336,13 +334,14 @@ class DockerSSHBox(Sandbox):
             raise Exception('Failed to start container')
 
     # clean up the container, cannot do it in __del__ because the python interpreter is already shutting down
-    def cleanup(self):
-        if self.closed:
-            return
-        try:
-            self.container.remove(force=True)
-        except docker.errors.NotFound:
-            pass
+    def close(self):
+        containers = self.docker_client.containers.list(all=True)
+        for container in containers:
+            try:
+                if container.name.startswith(self.container_name_prefix):
+                    container.remove(force=True)
+            except docker.errors.NotFound:
+                pass
 
 
 if __name__ == '__main__':

+ 3 - 1
opendevin/server/agent/__init__.py

@@ -1,3 +1,5 @@
 from .manager import AgentManager
 
-__all__ = ["AgentManager"]
+agent_manager = AgentManager()
+
+__all__ = ['AgentManager', 'agent_manager']

+ 183 - 0
opendevin/server/agent/agent.py

@@ -0,0 +1,183 @@
+import asyncio
+import os
+from typing import Optional
+
+from opendevin import config
+from opendevin.action import (
+    Action,
+    NullAction,
+)
+from opendevin.agent import Agent
+from opendevin.controller import AgentController
+from opendevin.llm.llm import LLM
+from opendevin.logger import opendevin_logger as logger
+from opendevin.observation import NullObservation, Observation, UserMessageObservation
+from opendevin.schema import ActionType, ConfigType
+from opendevin.server.session import session_manager
+
+
+class AgentUnit:
+    """Represents a session with an agent.
+
+    Attributes:
+        controller: The AgentController instance for controlling the agent.
+        agent_task: The task representing the agent's execution.
+    """
+
+    sid: str
+    controller: Optional[AgentController] = None
+    agent_task: Optional[asyncio.Task] = None
+
+    def __init__(self, sid):
+        """Initializes a new instance of the Session class."""
+        self.sid = sid
+
+    async def send_error(self, message):
+        """Sends an error message to the client.
+
+        Args:
+            message: The error message to send.
+        """
+        await session_manager.send_error(self.sid, message)
+
+    async def send_message(self, message):
+        """Sends a message to the client.
+
+        Args:
+            message: The message to send.
+        """
+        await session_manager.send_message(self.sid, message)
+
+    async def send(self, data):
+        """Sends data to the client.
+
+        Args:
+            data: The data to send.
+        """
+        await session_manager.send(self.sid, data)
+
+    async def dispatch(self, action: str | None, data: dict):
+        """Dispatches actions to the agent from the client."""
+        if action is None:
+            await self.send_error('Invalid action')
+            return
+
+        match action:
+            case ActionType.INIT:
+                if self.controller is not None:
+                    # Agent already started, no need to create a new one
+                    await self.init_done()
+                    return
+                await self.create_controller(data)
+            case ActionType.START:
+                await self.start_task(data)
+            case ActionType.CHAT:
+                if self.controller is None:
+                    await self.send_error('No agent started. Please wait a second...')
+                    return
+                self.controller.add_history(
+                    NullAction(), UserMessageObservation(data['message'])
+                )
+            case _:
+                await self.send_error("I didn't recognize this action:" + action)
+
+    def get_arg_or_default(self, _args: dict, key: ConfigType) -> str:
+        """Gets an argument from the args dictionary or the default value.
+
+        Args:
+            _args: The args dictionary.
+            key: The key to get.
+
+        Returns:
+            The value of the key or the default value.
+        """
+
+        return _args.get(key, config.get(key))
+
+    async def create_controller(self, start_event: dict):
+        """Creates an AgentController instance.
+
+        Args:
+            start_event: The start event data (optional).
+        """
+        args = {
+            key: value
+            for key, value in start_event.get('args', {}).items()
+            if value != ''
+        }  # remove empty values, prevent FE from sending empty strings
+        directory = self.get_arg_or_default(args, ConfigType.WORKSPACE_DIR)
+        agent_cls = self.get_arg_or_default(args, ConfigType.AGENT)
+        model = self.get_arg_or_default(args, ConfigType.LLM_MODEL)
+        api_key = config.get(ConfigType.LLM_API_KEY)
+        api_base = config.get(ConfigType.LLM_BASE_URL)
+        container_image = config.get(ConfigType.SANDBOX_CONTAINER_IMAGE)
+        max_iterations = self.get_arg_or_default(
+            args, ConfigType.MAX_ITERATIONS)
+
+        if not os.path.exists(directory):
+            logger.info(
+                'Workspace directory %s does not exist. Creating it...', directory
+            )
+            os.makedirs(directory)
+        directory = os.path.relpath(directory, os.getcwd())
+        llm = LLM(model=model, api_key=api_key, base_url=api_base)
+        try:
+            self.controller = AgentController(
+                sid=self.sid,
+                agent=Agent.get_cls(agent_cls)(llm),
+                workdir=directory,
+                max_iterations=int(max_iterations),
+                container_image=container_image,
+                callbacks=[self.on_agent_event],
+            )
+        except Exception as e:
+            logger.exception(f'Error creating controller: {e}')
+            await self.send_error(
+                'Error creating controller. Please check Docker is running using `docker ps`.'
+            )
+            return
+        await self.init_done()
+
+    async def init_done(self):
+        await self.send({'action': ActionType.INIT, 'message': 'Control loop started.'})
+
+    async def start_task(self, start_event):
+        """Starts a task for the agent.
+
+        Args:
+            start_event: The start event data.
+        """
+        if 'task' not in start_event['args']:
+            await self.send_error('No task specified')
+            return
+        await self.send_message('Starting new task...')
+        task = start_event['args']['task']
+        if self.controller is None:
+            await self.send_error('No agent started. Please wait a second...')
+            return
+        try:
+            self.agent_task = await asyncio.create_task(
+                self.controller.start_loop(task), name='agent loop'
+            )
+        except Exception as e:
+            await self.send_error(f'Error during task loop: {e}')
+
+    def on_agent_event(self, event: Observation | Action):
+        """Callback function for agent events.
+
+        Args:
+            event: The agent event (Observation or Action).
+        """
+        if isinstance(event, NullAction):
+            return
+        if isinstance(event, NullObservation):
+            return
+        event_dict = event.to_dict()
+        asyncio.create_task(self.send(event_dict),
+                            name='send event in callback')
+
+    def close(self):
+        if self.agent_task:
+            self.agent_task.cancel()
+        if self.controller is not None:
+            self.controller.command_manager.shell.close()

+ 27 - 163
opendevin/server/agent/manager.py

@@ -1,180 +1,44 @@
-import asyncio
-import os
-from typing import Optional
+import atexit
+import signal
 
-from opendevin import config
-from opendevin.action import (
-    Action,
-    NullAction,
-)
-from opendevin.agent import Agent
-from opendevin.controller import AgentController
-from opendevin.llm.llm import LLM
-from opendevin.logger import opendevin_logger as logger
-from opendevin.observation import NullObservation, Observation, UserMessageObservation
-from opendevin.schema import ActionType, ConfigType
 from opendevin.server.session import session_manager
+from .agent import AgentUnit
 
 
 class AgentManager:
-    """Represents a session with an agent.
+    sid_to_agent: dict[str, 'AgentUnit'] = {}
 
-    Attributes:
-        controller: The AgentController instance for controlling the agent.
-        agent: The Agent instance representing the agent.
-        agent_task: The task representing the agent's execution.
-    """
+    def __init__(self):
+        atexit.register(self.close)
+        signal.signal(signal.SIGINT, self.handle_signal)
+        signal.signal(signal.SIGTERM, self.handle_signal)
 
-    sid: str
-
-    def __init__(self, sid):
-        """Initializes a new instance of the Session class."""
-        self.sid = sid
-        self.controller: Optional[AgentController] = None
-        self.agent: Optional[Agent] = None
-        self.agent_task = None
-
-    async def send_error(self, message):
-        """Sends an error message to the client.
+    def register_agent(self, sid: str):
+        """Registers a new agent.
 
         Args:
-            message: The error message to send.
+            sid: The session ID of the agent.
         """
-        await session_manager.send_error(self.sid, message)
-
-    async def send_message(self, message):
-        """Sends a message to the client.
-
-        Args:
-            message: The message to send.
-        """
-        await session_manager.send_message(self.sid, message)
-
-    async def send(self, data):
-        """Sends data to the client.
-
-        Args:
-            data: The data to send.
-        """
-        await session_manager.send(self.sid, data)
-
-    async def dispatch(self, action: str | None, data: dict):
-        """Dispatches actions to the agent from the client."""
-        if action is None:
-            await self.send_error('Invalid action')
+        if sid not in self.sid_to_agent:
+            self.sid_to_agent[sid] = AgentUnit(sid)
             return
 
-        if action == ActionType.INIT:
-            await self.create_controller(data)
-        elif action == ActionType.START:
-            await self.start_task(data)
-        else:
-            if self.controller is None:
-                await self.send_error('No agent started. Please wait a second...')
-            elif action == ActionType.CHAT:
-                self.controller.add_history(
-                    NullAction(), UserMessageObservation(data['message'])
-                )
-            else:
-                await self.send_error("I didn't recognize this action:" + action)
-
-    def get_arg_or_default(self, _args: dict, key: ConfigType) -> str:
-        """Gets an argument from the args dictionary or the default value.
-
-        Args:
-            _args: The args dictionary.
-            key: The key to get.
-
-        Returns:
-            The value of the key or the default value.
-        """
+        # TODO: confirm whether the agent is alive
 
-        return _args.get(key, config.get(key))
-
-    async def create_controller(self, start_event: dict):
-        """Creates an AgentController instance.
-
-        Args:
-            start_event: The start event data (optional).
-        """
-        args = {
-            key: value
-            for key, value in start_event.get('args', {}).items()
-            if value != ''
-        }  # remove empty values, prevent FE from sending empty strings
-        directory = self.get_arg_or_default(args, ConfigType.WORKSPACE_DIR)
-        agent_cls = self.get_arg_or_default(args, ConfigType.AGENT)
-        model = self.get_arg_or_default(args, ConfigType.LLM_MODEL)
-        api_key = config.get(ConfigType.LLM_API_KEY)
-        api_base = config.get(ConfigType.LLM_BASE_URL)
-        container_image = config.get(ConfigType.SANDBOX_CONTAINER_IMAGE)
-        max_iterations = self.get_arg_or_default(
-            args, ConfigType.MAX_ITERATIONS)
-
-        if not os.path.exists(directory):
-            logger.info(
-                'Workspace directory %s does not exist. Creating it...', directory
-            )
-            os.makedirs(directory)
-        directory = os.path.relpath(directory, os.getcwd())
-        llm = LLM(model=model, api_key=api_key, base_url=api_base)
-        AgentCls = Agent.get_cls(agent_cls)
-        self.agent = AgentCls(llm)
-        try:
-            self.controller = AgentController(
-                id=self.sid,
-                agent=self.agent,
-                workdir=directory,
-                max_iterations=int(max_iterations),
-                container_image=container_image,
-                callbacks=[self.on_agent_event],
-            )
-        except Exception:
-            logger.exception('Error creating controller.')
-            await self.send_error(
-                'Error creating controller. Please check Docker is running using `docker ps`.'
-            )
-            return
-        await self.send({'action': ActionType.INIT, 'message': 'Control loop started.'})
-
-    async def start_task(self, start_event):
-        """Starts a task for the agent.
-
-        Args:
-            start_event: The start event data.
-        """
-        if 'task' not in start_event['args']:
-            await self.send_error('No task specified')
-            return
-        await self.send_message('Starting new task...')
-        task = start_event['args']['task']
-        if self.controller is None:
-            await self.send_error('No agent started. Please wait a second...')
+    async def dispatch(self, sid: str, action: str | None, data: dict):
+        """Dispatches actions to the agent from the client."""
+        if sid not in self.sid_to_agent:
+            # self.register_agent(sid)  # auto-register agent, may be opened later
+            await session_manager.send_error(sid, 'Agent not registered')
             return
-        try:
-            self.agent_task = await asyncio.create_task(
-                self.controller.start_loop(task), name='agent loop'
-            )
-        except Exception:
-            await self.send_error('Error during task loop.')
 
-    def on_agent_event(self, event: Observation | Action):
-        """Callback function for agent events.
+        await self.sid_to_agent[sid].dispatch(action, data)
 
-        Args:
-            event: The agent event (Observation or Action).
-        """
-        if isinstance(event, NullAction):
-            return
-        if isinstance(event, NullObservation):
-            return
-        event_dict = event.to_dict()
-        asyncio.create_task(self.send(event_dict),
-                            name='send event in callback')
+    def handle_signal(self, signum, _):
+        print(f"Received signal {signum}, exiting...")
+        self.close()
+        exit(0)
 
-    def disconnect(self):
-        self.websocket = None
-        if self.agent_task:
-            self.agent_task.cancel()
-        if self.controller is not None:
-            self.controller.command_manager.shell.close()
+    def close(self):
+        for sid, agent in self.sid_to_agent.items():
+            agent.close()

+ 7 - 7
opendevin/server/listen.py

@@ -11,7 +11,7 @@ from starlette.responses import JSONResponse
 import agenthub  # noqa F401 (we import this to get the agents registered)
 from opendevin import config, files
 from opendevin.agent import Agent
-from opendevin.server.agent import AgentManager
+from opendevin.server.agent import agent_manager
 from opendevin.server.auth import get_sid_from_token, sign_token
 from opendevin.server.session import message_stack, session_manager
 
@@ -37,7 +37,7 @@ async def websocket_endpoint(websocket: WebSocket):
     session_manager.add_session(sid, websocket)
     # TODO: actually the agent_manager is created for each websocket connection, even if the session id is the same,
     # we need to manage the agent in memory for reconnecting the same session id to the same agent
-    agent_manager = AgentManager(sid)
+    agent_manager.register_agent(sid)
     await session_manager.loop_recv(sid, agent_manager.dispatch)
 
 
@@ -54,12 +54,12 @@ async def get_litellm_agents():
     """
     Get all agents supported by LiteLLM.
     """
-    return Agent.listAgents()
+    return Agent.list_agents()
 
 
 @app.get('/auth')
 async def get_token(
-    credentials: HTTPAuthorizationCredentials = Depends(security_scheme),
+        credentials: HTTPAuthorizationCredentials = Depends(security_scheme),
 ):
     """
     Get token for authentication when starts a websocket connection.
@@ -74,7 +74,7 @@ async def get_token(
 
 @app.get('/messages')
 async def get_messages(
-    credentials: HTTPAuthorizationCredentials = Depends(security_scheme),
+        credentials: HTTPAuthorizationCredentials = Depends(security_scheme),
 ):
     data = []
     sid = get_sid_from_token(credentials.credentials)
@@ -89,7 +89,7 @@ async def get_messages(
 
 @app.get('/messages/total')
 async def get_message_total(
-    credentials: HTTPAuthorizationCredentials = Depends(security_scheme),
+        credentials: HTTPAuthorizationCredentials = Depends(security_scheme),
 ):
     sid = get_sid_from_token(credentials.credentials)
     return JSONResponse(
@@ -100,7 +100,7 @@ async def get_message_total(
 
 @app.delete('/messages')
 async def del_messages(
-    credentials: HTTPAuthorizationCredentials = Depends(security_scheme),
+        credentials: HTTPAuthorizationCredentials = Depends(security_scheme),
 ):
     sid = get_sid_from_token(credentials.credentials)
     message_stack.del_messages(sid)

+ 4 - 3
opendevin/server/session/__init__.py

@@ -1,6 +1,7 @@
-from .session import Session
 from .manager import SessionManager
-from .manager import session_manager
 from .msg_stack import message_stack
+from .session import Session
+
+session_manager = SessionManager()
 
-__all__ = ["Session", "SessionManager", "session_manager", "message_stack"]
+__all__ = ['Session', 'SessionManager', 'session_manager', 'message_stack']

+ 14 - 18
opendevin/server/session/manager.py

@@ -1,17 +1,16 @@
-import os
-import json
 import atexit
+import json
+import os
 import signal
 from typing import Dict, Callable
 
 from fastapi import WebSocket
 
-from .session import Session
 from .msg_stack import message_stack
+from .session import Session
 
-
-CACHE_DIR = os.getenv("CACHE_DIR", "cache")
-SESSION_CACHE_FILE = os.path.join(CACHE_DIR, "sessions.json")
+CACHE_DIR = os.getenv('CACHE_DIR', 'cache')
+SESSION_CACHE_FILE = os.path.join(CACHE_DIR, 'sessions.json')
 
 
 class SessionManager:
@@ -30,7 +29,7 @@ class SessionManager:
         self._sessions[sid].update_connection(ws_conn)
 
     async def loop_recv(self, sid: str, dispatch: Callable):
-        print(f"Starting loop_recv for sid: {sid}, {sid not in self._sessions}")
+        print(f"Starting loop_recv for sid: {sid}")
         """Starts listening for messages from the client."""
         if sid not in self._sessions:
             return
@@ -46,35 +45,35 @@ class SessionManager:
 
     async def send(self, sid: str, data: Dict[str, object]) -> bool:
         """Sends data to the client."""
-        message_stack.add_message(sid, "assistant", data)
+        message_stack.add_message(sid, 'assistant', data)
         if sid not in self._sessions:
             return False
         return await self._sessions[sid].send(data)
 
     async def send_error(self, sid: str, message: str) -> bool:
         """Sends an error message to the client."""
-        return await self.send(sid, {"error": True, "message": message})
+        return await self.send(sid, {'error': True, 'message': message})
 
     async def send_message(self, sid: str, message: str) -> bool:
         """Sends a message to the client."""
-        return await self.send(sid, {"message": message})
+        return await self.send(sid, {'message': message})
 
     def _save_sessions(self):
         data = {}
         for sid, conn in self._sessions.items():
             data[sid] = {
-                "sid": conn.sid,
-                "last_active_ts": conn.last_active_ts,
-                "is_alive": conn.is_alive,
+                'sid': conn.sid,
+                'last_active_ts': conn.last_active_ts,
+                'is_alive': conn.is_alive,
             }
         if not os.path.exists(CACHE_DIR):
             os.makedirs(CACHE_DIR)
-        with open(SESSION_CACHE_FILE, "w+") as file:
+        with open(SESSION_CACHE_FILE, 'w+') as file:
             json.dump(data, file)
 
     def _load_sessions(self):
         try:
-            with open(SESSION_CACHE_FILE, "r") as file:
+            with open(SESSION_CACHE_FILE, 'r') as file:
                 data = json.load(file)
                 for sid, sdata in data.items():
                     conn = Session(sid, None)
@@ -85,6 +84,3 @@ class SessionManager:
             pass
         except json.decoder.JSONDecodeError:
             pass
-
-
-session_manager = SessionManager()

+ 3 - 2
opendevin/server/session/session.py

@@ -1,9 +1,10 @@
 import time
 from typing import Dict, Callable
+
 from fastapi import WebSocket, WebSocketDisconnect
-from .msg_stack import message_stack
 
 from opendevin.logger import opendevin_logger as logger
+from .msg_stack import message_stack
 
 DEL_DELT_SEC = 60 * 60 * 5
 
@@ -32,7 +33,7 @@ class Session:
 
                 message_stack.add_message(self.sid, 'user', data)
                 action = data.get('action', None)
-                await dispatch(action, data)
+                await dispatch(self.sid, action, data)
         except WebSocketDisconnect:
             self.is_alive = False
             logger.info('WebSocket disconnected, sid: %s', self.sid)

+ 3 - 0
opendevin/utils/__init__.py

@@ -0,0 +1,3 @@
+from .system import find_available_tcp_port
+
+__all__ = ['find_available_tcp_port']

+ 15 - 0
opendevin/utils/system.py

@@ -0,0 +1,15 @@
+import socket
+
+
+def find_available_tcp_port() -> int:
+    """Find an available TCP port, return -1 if none available.
+    """
+    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+    try:
+        sock.bind(('localhost', 0))
+        port = sock.getsockname()[1]
+        return port
+    except Exception:
+        return -1
+    finally:
+        sock.close()