Просмотр исходного кода

[arch] Implement `copy_to` for Runtime (#3211)

* add copy to

* implement for ServerRuntime

* implement copyto for runtime (required by eval);
add tests for copy to

* fix exist file check

* unify copy_to_behavior and fix stuff
Xingyao Wang 1 год назад
Родитель
Сommit
286f10053e

+ 1 - 1
evaluation/mint/run_infer.py

@@ -112,7 +112,7 @@ def process_instance(
     )
     )
 
 
     requirements_host_src = 'evaluation/mint/requirements.txt'
     requirements_host_src = 'evaluation/mint/requirements.txt'
-    requirements_sandbox_dest = '/opendevin/plugins/mint/requirements.txt'
+    requirements_sandbox_dest = '/opendevin/plugins/mint/'
     sandbox.copy_to(
     sandbox.copy_to(
         host_src=requirements_host_src,
         host_src=requirements_host_src,
         sandbox_dest=requirements_sandbox_dest,
         sandbox_dest=requirements_sandbox_dest,

+ 57 - 1
opendevin/runtime/client/client.py

@@ -13,12 +13,14 @@ import argparse
 import asyncio
 import asyncio
 import os
 import os
 import re
 import re
+import shutil
 import subprocess
 import subprocess
 from contextlib import asynccontextmanager
 from contextlib import asynccontextmanager
 from pathlib import Path
 from pathlib import Path
 
 
 import pexpect
 import pexpect
-from fastapi import FastAPI, HTTPException, Request
+from fastapi import FastAPI, HTTPException, Request, UploadFile
+from fastapi.responses import JSONResponse
 from pydantic import BaseModel
 from pydantic import BaseModel
 from uvicorn import run
 from uvicorn import run
 
 
@@ -407,6 +409,60 @@ if __name__ == '__main__':
             logger.error(f'Error processing command: {str(e)}')
             logger.error(f'Error processing command: {str(e)}')
             raise HTTPException(status_code=500, detail=str(e))
             raise HTTPException(status_code=500, detail=str(e))
 
 
+    @app.post('/upload_file')
+    async def upload_file(
+        file: UploadFile, destination: str = '/', recursive: bool = False
+    ):
+        assert client is not None
+
+        try:
+            # Ensure the destination directory exists
+            if not os.path.isabs(destination):
+                raise HTTPException(
+                    status_code=400, detail='Destination must be an absolute path'
+                )
+
+            full_dest_path = destination
+            if not os.path.exists(full_dest_path):
+                os.makedirs(full_dest_path, exist_ok=True)
+
+            if recursive:
+                # For recursive uploads, we expect a zip file
+                if not file.filename.endswith('.zip'):
+                    raise HTTPException(
+                        status_code=400, detail='Recursive uploads must be zip files'
+                    )
+
+                zip_path = os.path.join(full_dest_path, file.filename)
+                with open(zip_path, 'wb') as buffer:
+                    shutil.copyfileobj(file.file, buffer)
+
+                # Extract the zip file
+                shutil.unpack_archive(zip_path, full_dest_path)
+                os.remove(zip_path)  # Remove the zip file after extraction
+
+                logger.info(
+                    f'Uploaded file {file.filename} and extracted to {destination}'
+                )
+            else:
+                # For single file uploads
+                file_path = os.path.join(full_dest_path, file.filename)
+                with open(file_path, 'wb') as buffer:
+                    shutil.copyfileobj(file.file, buffer)
+                logger.info(f'Uploaded file {file.filename} to {destination}')
+
+            return JSONResponse(
+                content={
+                    'filename': file.filename,
+                    'destination': destination,
+                    'recursive': recursive,
+                },
+                status_code=200,
+            )
+
+        except Exception as e:
+            raise HTTPException(status_code=500, detail=str(e))
+
     @app.get('/alive')
     @app.get('/alive')
     async def alive():
     async def alive():
         return {'status': 'ok'}
         return {'status': 'ok'}

+ 53 - 1
opendevin/runtime/client/runtime.py

@@ -1,6 +1,9 @@
 import asyncio
 import asyncio
+import os
+import tempfile
 import uuid
 import uuid
-from typing import Optional
+from typing import Any, Optional
+from zipfile import ZipFile
 
 
 import aiohttp
 import aiohttp
 import docker
 import docker
@@ -205,6 +208,55 @@ class EventStreamRuntime(Runtime):
         if close_client:
         if close_client:
             self.docker_client.close()
             self.docker_client.close()
 
 
+    async def copy_to(
+        self, host_src: str, sandbox_dest: str, recursive: bool = False
+    ) -> dict[str, Any]:
+        if not os.path.exists(host_src):
+            raise FileNotFoundError(f'Source file {host_src} does not exist')
+
+        session = await self._ensure_session()
+        await self._wait_until_alive()
+        try:
+            if recursive:
+                # For recursive copy, create a zip file
+                with tempfile.NamedTemporaryFile(
+                    suffix='.zip', delete=False
+                ) as temp_zip:
+                    temp_zip_path = temp_zip.name
+
+                with ZipFile(temp_zip_path, 'w') as zipf:
+                    for root, _, files in os.walk(host_src):
+                        for file in files:
+                            file_path = os.path.join(root, file)
+                            arcname = os.path.relpath(
+                                file_path, os.path.dirname(host_src)
+                            )
+                            zipf.write(file_path, arcname)
+
+                upload_data = {'file': open(temp_zip_path, 'rb')}
+            else:
+                # For single file copy
+                upload_data = {'file': open(host_src, 'rb')}
+
+            params = {'destination': sandbox_dest, 'recursive': str(recursive).lower()}
+
+            async with session.post(
+                f'{self.api_url}/upload_file', data=upload_data, params=params
+            ) as response:
+                if response.status == 200:
+                    return await response.json()
+                else:
+                    error_message = await response.text()
+                    raise Exception(f'Copy operation failed: {error_message}')
+
+        except asyncio.TimeoutError:
+            raise TimeoutError('Copy operation timed out')
+        except Exception as e:
+            raise RuntimeError(f'Copy operation failed: {str(e)}')
+        finally:
+            if recursive:
+                os.unlink(temp_zip_path)
+
     async def on_event(self, event: Event) -> None:
     async def on_event(self, event: Event) -> None:
         logger.info(f'EventStreamRuntime: on_event triggered: {event}')
         logger.info(f'EventStreamRuntime: on_event triggered: {event}')
         if isinstance(event, Action):
         if isinstance(event, Action):

+ 5 - 1
opendevin/runtime/docker/ssh_box.py

@@ -456,6 +456,9 @@ class DockerSSHBox(Sandbox):
         return exit_code, command_output
         return exit_code, command_output
 
 
     def copy_to(self, host_src: str, sandbox_dest: str, recursive: bool = False):
     def copy_to(self, host_src: str, sandbox_dest: str, recursive: bool = False):
+        if not os.path.exists(host_src):
+            raise FileNotFoundError(f'Source file {host_src} does not exist')
+
         # mkdir -p sandbox_dest if it doesn't exist
         # mkdir -p sandbox_dest if it doesn't exist
         exit_code, logs = self.container.exec_run(
         exit_code, logs = self.container.exec_run(
             ['/bin/bash', '-c', f'mkdir -p {sandbox_dest}'],
             ['/bin/bash', '-c', f'mkdir -p {sandbox_dest}'],
@@ -494,7 +497,8 @@ class DockerSSHBox(Sandbox):
 
 
             with open(tar_filename, 'rb') as f:
             with open(tar_filename, 'rb') as f:
                 data = f.read()
                 data = f.read()
-            self.container.put_archive(os.path.dirname(sandbox_dest), data)
+
+            self.container.put_archive(sandbox_dest, data)
 
 
     def start_docker_container(self):
     def start_docker_container(self):
         try:
         try:

+ 1 - 1
opendevin/runtime/plugins/agent_skills/__init__.py

@@ -11,7 +11,7 @@ class AgentSkillsRequirement(PluginRequirement):
     host_src: str = os.path.dirname(
     host_src: str = os.path.dirname(
         os.path.abspath(__file__)
         os.path.abspath(__file__)
     )  # The directory of this file (opendevin/runtime/plugins/agent_skills)
     )  # The directory of this file (opendevin/runtime/plugins/agent_skills)
-    sandbox_dest: str = '/opendevin/plugins/agent_skills'
+    sandbox_dest: str = '/opendevin/plugins/'
     bash_script_path: str = 'setup.sh'
     bash_script_path: str = 'setup.sh'
     documentation: str = DOCUMENTATION
     documentation: str = DOCUMENTATION
 
 

+ 1 - 1
opendevin/runtime/plugins/jupyter/__init__.py

@@ -18,7 +18,7 @@ class JupyterRequirement(PluginRequirement):
     host_src: str = os.path.dirname(
     host_src: str = os.path.dirname(
         os.path.abspath(__file__)
         os.path.abspath(__file__)
     )  # The directory of this file (opendevin/runtime/plugins/jupyter)
     )  # The directory of this file (opendevin/runtime/plugins/jupyter)
-    sandbox_dest: str = '/opendevin/plugins/jupyter'
+    sandbox_dest: str = '/opendevin/plugins/'
     bash_script_path: str = 'setup.sh'
     bash_script_path: str = 'setup.sh'
 
 
     # ================================================================
     # ================================================================

+ 3 - 1
opendevin/runtime/plugins/mixin.py

@@ -62,7 +62,9 @@ class PluginMixin:
 
 
                 # Execute the bash script
                 # Execute the bash script
                 abs_path_to_bash_script = os.path.join(
                 abs_path_to_bash_script = os.path.join(
-                    requirement.sandbox_dest, requirement.bash_script_path
+                    requirement.sandbox_dest,
+                    requirement.name,
+                    requirement.bash_script_path,
                 )
                 )
                 logger.info(
                 logger.info(
                     f'Initializing plugin [{requirement.name}] by executing [{abs_path_to_bash_script}] in the sandbox.'
                     f'Initializing plugin [{requirement.name}] by executing [{abs_path_to_bash_script}] in the sandbox.'

+ 1 - 1
opendevin/runtime/plugins/swe_agent_commands/__init__.py

@@ -35,7 +35,7 @@ DEFAULT_DOCUMENTATION = ''.join(
 class SWEAgentCommandsRequirement(PluginRequirement):
 class SWEAgentCommandsRequirement(PluginRequirement):
     name: str = 'swe_agent_commands'
     name: str = 'swe_agent_commands'
     host_src: str = os.path.dirname(os.path.abspath(__file__))
     host_src: str = os.path.dirname(os.path.abspath(__file__))
-    sandbox_dest: str = '/opendevin/plugins/swe_agent_commands'
+    sandbox_dest: str = '/opendevin/plugins/'
     bash_script_path: str = 'setup_default.sh'
     bash_script_path: str = 'setup_default.sh'
 
 
     scripts_filepaths: list[str | None] = field(
     scripts_filepaths: list[str | None] = field(

+ 4 - 0
opendevin/runtime/runtime.py

@@ -176,6 +176,10 @@ class Runtime:
         observation = await getattr(self, action_type)(action)
         observation = await getattr(self, action_type)(action)
         return observation
         return observation
 
 
+    @abstractmethod
+    async def copy_to(self, host_src: str, sandbox_dest: str, recursive: bool = False):
+        raise NotImplementedError('This method is not implemented in the base class.')
+
     # ====================================================================
     # ====================================================================
     # Implement these methods in the subclass
     # Implement these methods in the subclass
     # ====================================================================
     # ====================================================================

+ 3 - 0
opendevin/runtime/server/runtime.py

@@ -121,6 +121,9 @@ class ServerRuntime(Runtime):
                     'Failed to start browser environment, web browsing functionality will not work'
                     'Failed to start browser environment, web browsing functionality will not work'
                 )
                 )
 
 
+    async def copy_to(self, host_src: str, sandbox_dest: str, recursive: bool = False):
+        self.sandbox.copy_to(host_src, sandbox_dest, recursive)
+
     async def run(self, action: CmdRunAction) -> Observation:
     async def run(self, action: CmdRunAction) -> Observation:
         return self._run_command(action.command)
         return self._run_command(action.command)
 
 

+ 153 - 0
tests/unit/test_runtime.py

@@ -2,6 +2,7 @@
 
 
 import asyncio
 import asyncio
 import os
 import os
+import tempfile
 import time
 import time
 from unittest.mock import patch
 from unittest.mock import patch
 
 
@@ -937,3 +938,155 @@ async def test_ipython_agentskills_fileop_pwd_agnostic_sandbox(
     await _test_ipython_agentskills_fileop_pwd_impl(runtime, enable_auto_lint)
     await _test_ipython_agentskills_fileop_pwd_impl(runtime, enable_auto_lint)
     await runtime.close()
     await runtime.close()
     await asyncio.sleep(1)
     await asyncio.sleep(1)
+
+
+def _create_test_file(host_temp_dir):
+    # Single file
+    with open(os.path.join(host_temp_dir, 'test_file.txt'), 'w') as f:
+        f.write('Hello, World!')
+
+
+@pytest.mark.asyncio
+async def test_copy_single_file(temp_dir, box_class):
+    runtime = await _load_runtime(temp_dir, box_class)
+
+    with tempfile.TemporaryDirectory() as host_temp_dir:
+        _create_test_file(host_temp_dir)
+        await runtime.copy_to(
+            os.path.join(host_temp_dir, 'test_file.txt'), '/workspace'
+        )
+
+    action = CmdRunAction(command='ls -alh /workspace')
+    logger.info(action, extra={'msg_type': 'ACTION'})
+    obs = await runtime.run_action(action)
+    logger.info(obs, extra={'msg_type': 'OBSERVATION'})
+    assert isinstance(obs, CmdOutputObservation)
+    assert obs.exit_code == 0
+    assert 'test_file.txt' in obs.content
+
+    action = CmdRunAction(command='cat /workspace/test_file.txt')
+    logger.info(action, extra={'msg_type': 'ACTION'})
+    obs = await runtime.run_action(action)
+    logger.info(obs, extra={'msg_type': 'OBSERVATION'})
+    assert isinstance(obs, CmdOutputObservation)
+    assert obs.exit_code == 0
+    assert 'Hello, World!' in obs.content
+
+
+def _create_test_dir_with_files(host_temp_dir):
+    os.mkdir(os.path.join(host_temp_dir, 'test_dir'))
+    with open(os.path.join(host_temp_dir, 'test_dir', 'file1.txt'), 'w') as f:
+        f.write('File 1 content')
+    with open(os.path.join(host_temp_dir, 'test_dir', 'file2.txt'), 'w') as f:
+        f.write('File 2 content')
+
+
+@pytest.mark.asyncio
+async def test_copy_directory_recursively(temp_dir, box_class):
+    runtime = await _load_runtime(temp_dir, box_class)
+
+    with tempfile.TemporaryDirectory() as host_temp_dir:
+        # We need a separate directory, since temp_dir is mounted to /workspace
+        _create_test_dir_with_files(host_temp_dir)
+        await runtime.copy_to(
+            os.path.join(host_temp_dir, 'test_dir'), '/workspace', recursive=True
+        )
+
+    action = CmdRunAction(command='ls -alh /workspace')
+    logger.info(action, extra={'msg_type': 'ACTION'})
+    obs = await runtime.run_action(action)
+    logger.info(obs, extra={'msg_type': 'OBSERVATION'})
+    assert isinstance(obs, CmdOutputObservation)
+    assert obs.exit_code == 0
+    assert 'test_dir' in obs.content
+    assert 'file1.txt' not in obs.content
+    assert 'file2.txt' not in obs.content
+
+    action = CmdRunAction(command='ls -alh /workspace/test_dir')
+    logger.info(action, extra={'msg_type': 'ACTION'})
+    obs = await runtime.run_action(action)
+    logger.info(obs, extra={'msg_type': 'OBSERVATION'})
+    assert isinstance(obs, CmdOutputObservation)
+    assert obs.exit_code == 0
+    assert 'file1.txt' in obs.content
+    assert 'file2.txt' in obs.content
+
+    action = CmdRunAction(command='cat /workspace/test_dir/file1.txt')
+    logger.info(action, extra={'msg_type': 'ACTION'})
+    obs = await runtime.run_action(action)
+    logger.info(obs, extra={'msg_type': 'OBSERVATION'})
+    assert isinstance(obs, CmdOutputObservation)
+    assert obs.exit_code == 0
+    assert 'File 1 content' in obs.content
+
+
+@pytest.mark.asyncio
+async def test_copy_to_non_existent_directory(temp_dir, box_class):
+    runtime = await _load_runtime(temp_dir, box_class)
+
+    with tempfile.TemporaryDirectory() as host_temp_dir:
+        _create_test_file(host_temp_dir)
+        await runtime.copy_to(
+            os.path.join(host_temp_dir, 'test_file.txt'), '/workspace/new_dir'
+        )
+
+    action = CmdRunAction(command='cat /workspace/new_dir/test_file.txt')
+    logger.info(action, extra={'msg_type': 'ACTION'})
+    obs = await runtime.run_action(action)
+    logger.info(obs, extra={'msg_type': 'OBSERVATION'})
+    assert isinstance(obs, CmdOutputObservation)
+    assert obs.exit_code == 0
+    assert 'Hello, World!' in obs.content
+
+
+@pytest.mark.asyncio
+async def test_overwrite_existing_file(temp_dir, box_class):
+    runtime = await _load_runtime(temp_dir, box_class)
+
+    # touch a file in /workspace
+    action = CmdRunAction(command='touch /workspace/test_file.txt')
+    logger.info(action, extra={'msg_type': 'ACTION'})
+    obs = await runtime.run_action(action)
+    logger.info(obs, extra={'msg_type': 'OBSERVATION'})
+    assert isinstance(obs, CmdOutputObservation)
+    assert obs.exit_code == 0
+
+    action = CmdRunAction(command='cat /workspace/test_file.txt')
+    logger.info(action, extra={'msg_type': 'ACTION'})
+    obs = await runtime.run_action(action)
+    logger.info(obs, extra={'msg_type': 'OBSERVATION'})
+    assert isinstance(obs, CmdOutputObservation)
+    assert obs.exit_code == 0
+    assert 'Hello, World!' not in obs.content
+
+    with tempfile.TemporaryDirectory() as host_temp_dir:
+        _create_test_file(host_temp_dir)
+        await runtime.copy_to(
+            os.path.join(host_temp_dir, 'test_file.txt'), '/workspace'
+        )
+
+    action = CmdRunAction(command='cat /workspace/test_file.txt')
+    logger.info(action, extra={'msg_type': 'ACTION'})
+    obs = await runtime.run_action(action)
+    logger.info(obs, extra={'msg_type': 'OBSERVATION'})
+    assert isinstance(obs, CmdOutputObservation)
+    assert obs.exit_code == 0
+    assert 'Hello, World!' in obs.content
+
+
+@pytest.mark.asyncio
+async def test_copy_non_existent_file(temp_dir, box_class):
+    runtime = await _load_runtime(temp_dir, box_class)
+
+    with pytest.raises(FileNotFoundError):
+        await runtime.copy_to(
+            os.path.join(temp_dir, 'non_existent_file.txt'),
+            '/workspace/should_not_exist.txt',
+        )
+
+    action = CmdRunAction(command='ls /workspace/should_not_exist.txt')
+    logger.info(action, extra={'msg_type': 'ACTION'})
+    obs = await runtime.run_action(action)
+    logger.info(obs, extra={'msg_type': 'OBSERVATION'})
+    assert isinstance(obs, CmdOutputObservation)
+    assert obs.exit_code != 0  # File should not exist