Jelajahi Sumber

Fix: Buffering zip downloads to files rather than holding in memory (#4802)

tofarr 1 tahun lalu
induk
melakukan
932de79154

+ 3 - 2
openhands/runtime/base.py

@@ -3,6 +3,7 @@ import copy
 import json
 import os
 from abc import abstractmethod
+from pathlib import Path
 from typing import Callable
 
 from requests.exceptions import ConnectionError
@@ -274,6 +275,6 @@ class Runtime(FileEditRuntimeMixin):
         raise NotImplementedError('This method is not implemented in the base class.')
 
     @abstractmethod
-    def copy_from(self, path: str) -> bytes:
-        """Zip all files in the sandbox and return as a stream of bytes."""
+    def copy_from(self, path: str) -> Path:
+        """Zip all files in the sandbox and return a path in the local filesystem."""
         raise NotImplementedError('This method is not implemented in the base class.')

+ 7 - 3
openhands/runtime/impl/eventstream/eventstream_runtime.py

@@ -1,4 +1,5 @@
 import os
+from pathlib import Path
 import tempfile
 import threading
 from functools import lru_cache
@@ -604,7 +605,7 @@ class EventStreamRuntime(Runtime):
         except requests.Timeout:
             raise TimeoutError('List files operation timed out')
 
-    def copy_from(self, path: str) -> bytes:
+    def copy_from(self, path: str) -> Path:
         """Zip all files in the sandbox and return a path in the local filesystem."""
         self._refresh_logs()
         try:
@@ -617,8 +618,11 @@ class EventStreamRuntime(Runtime):
                 stream=True,
                 timeout=30,
             )
-            data = response.content
-            return data
+            temp_file = tempfile.NamedTemporaryFile(delete=False)
+            for chunk in response.iter_content(chunk_size=8192):
+                if chunk:  # filter out keep-alive new chunks
+                    temp_file.write(chunk)
+            return Path(temp_file.name)
         except requests.Timeout:
             raise TimeoutError('Copy operation timed out')
 

+ 8 - 2
openhands/runtime/impl/remote/remote_runtime.py

@@ -1,4 +1,5 @@
 import os
+from pathlib import Path
 import tempfile
 import threading
 from typing import Callable, Optional
@@ -467,13 +468,18 @@ class RemoteRuntime(Runtime):
         assert isinstance(response_json, list)
         return response_json
 
-    def copy_from(self, path: str) -> bytes:
+    def copy_from(self, path: str) -> Path:
         """Zip all files in the sandbox and return a path in the local filesystem."""
         params = {'path': path}
         response = self._send_request(
             'GET',
             f'{self.runtime_url}/download_files',
             params=params,
+            stream=True,
             timeout=30,
         )
-        return response.content
+        temp_file = tempfile.NamedTemporaryFile(delete=False)
+        for chunk in response.iter_content(chunk_size=8192):
+            if chunk:  # filter out keep-alive new chunks
+                temp_file.write(chunk)
+        return Path(temp_file.name)

+ 10 - 9
openhands/server/listen.py

@@ -1,5 +1,4 @@
 import asyncio
-import io
 import os
 import re
 import tempfile
@@ -27,6 +26,7 @@ with warnings.catch_warnings():
 
 from dotenv import load_dotenv
 from fastapi import (
+    BackgroundTasks,
     FastAPI,
     HTTPException,
     Request,
@@ -34,7 +34,7 @@ from fastapi import (
     WebSocket,
     status,
 )
-from fastapi.responses import JSONResponse, StreamingResponse
+from fastapi.responses import FileResponse, JSONResponse
 from fastapi.security import HTTPBearer
 from fastapi.staticfiles import StaticFiles
 from pydantic import BaseModel
@@ -790,20 +790,21 @@ async def security_api(request: Request):
 
 
 @app.get('/api/zip-directory')
-async def zip_current_workspace(request: Request):
+async def zip_current_workspace(request: Request, background_tasks: BackgroundTasks):
     try:
         logger.debug('Zipping workspace')
         runtime: Runtime = request.state.conversation.runtime
-
         path = runtime.config.workspace_mount_path_in_sandbox
-        zip_file_bytes = await call_sync_from_async(runtime.copy_from, path)
-        zip_stream = io.BytesIO(zip_file_bytes)  # Wrap to behave like a file stream
-        response = StreamingResponse(
-            zip_stream,
+        zip_file = await call_sync_from_async(runtime.copy_from, path)
+        response = FileResponse(
+            path=zip_file,
+            filename='workspace.zip',
             media_type='application/x-zip-compressed',
-            headers={'Content-Disposition': 'attachment; filename=workspace.zip'},
         )
 
+        # Runs after the response has been sent, so the file is not deleted before it is served.
+        background_tasks.add_task(zip_file.unlink)
+
         return response
     except Exception as e:
         logger.error(f'Error zipping workspace: {e}', exc_info=True)

+ 5 - 2
tests/runtime/test_bash.py

@@ -1,6 +1,7 @@
 """Bash-related tests for the EventStreamRuntime, which connects to the ActionExecutor running in the sandbox."""
 
 import os
+from pathlib import Path
 
 import pytest
 from conftest import (
@@ -586,8 +587,10 @@ def test_copy_from_directory(temp_dir, runtime_cls):
         path_to_copy_from = f'{sandbox_dir}/test_dir'
         result = runtime.copy_from(path=path_to_copy_from)
 
-        # Result is returned in bytes
-        assert isinstance(result, bytes)
+        # Result is returned as a path
+        assert isinstance(result, Path)
+
+        result.unlink()
     finally:
         _close_test_runtime(runtime)