Explorar o código

fix zip downloads (#5009)

Robert Brennan hai 1 ano
pai
achega
f3b35663e9

+ 35 - 33
openhands/runtime/impl/eventstream/eventstream_runtime.py

@@ -432,12 +432,13 @@ class EventStreamRuntime(Runtime):
         if not self.log_buffer:
             raise RuntimeError('Runtime client is not ready.')
 
-        send_request(
+        with send_request(
             self.session,
             'GET',
             f'{self.api_url}/alive',
             timeout=5,
-        )
+        ):
+            pass
 
     def close(self, rm_all_containers: bool = True):
         """Closes the EventStreamRuntime and associated objects
@@ -496,17 +497,17 @@ class EventStreamRuntime(Runtime):
             assert action.timeout is not None
 
             try:
-                response = send_request(
+                with send_request(
                     self.session,
                     'POST',
                     f'{self.api_url}/execute_action',
                     json={'action': event_to_dict(action)},
                     # wait a few more seconds to get the timeout error from client side
                     timeout=action.timeout + 5,
-                )
-                output = response.json()
-                obs = observation_from_dict(output)
-                obs._cause = action.id  # type: ignore[attr-defined]
+                ) as response:
+                    output = response.json()
+                    obs = observation_from_dict(output)
+                    obs._cause = action.id  # type: ignore[attr-defined]
             except requests.Timeout:
                 raise RuntimeError(
                     f'Runtime failed to return execute_action before the requested timeout of {action.timeout}s'
@@ -567,14 +568,15 @@ class EventStreamRuntime(Runtime):
 
             params = {'destination': sandbox_dest, 'recursive': str(recursive).lower()}
 
-            send_request(
+            with send_request(
                 self.session,
                 'POST',
                 f'{self.api_url}/upload_file',
                 files=upload_data,
                 params=params,
                 timeout=300,
-            )
+            ):
+                pass
 
         except requests.Timeout:
             raise TimeoutError('Copy operation timed out')
@@ -599,16 +601,16 @@ class EventStreamRuntime(Runtime):
             if path is not None:
                 data['path'] = path
 
-            response = send_request(
+            with send_request(
                 self.session,
                 'POST',
                 f'{self.api_url}/list_files',
                 json=data,
                 timeout=10,
-            )
-            response_json = response.json()
-            assert isinstance(response_json, list)
-            return response_json
+            ) as response:
+                response_json = response.json()
+                assert isinstance(response_json, list)
+                return response_json
         except requests.Timeout:
             raise TimeoutError('List files operation timed out')
 
@@ -617,19 +619,19 @@ class EventStreamRuntime(Runtime):
         self._refresh_logs()
         try:
             params = {'path': path}
-            response = send_request(
+            with send_request(
                 self.session,
                 'GET',
                 f'{self.api_url}/download_files',
                 params=params,
                 stream=True,
                 timeout=30,
-            )
-            temp_file = tempfile.NamedTemporaryFile(delete=False)
-            for chunk in response.iter_content(chunk_size=8192):
-                if chunk:  # filter out keep-alive new chunks
-                    temp_file.write(chunk)
-            return Path(temp_file.name)
+            ) as response:
+                temp_file = tempfile.NamedTemporaryFile(delete=False)
+                for chunk in response.iter_content(chunk_size=8192):
+                    if chunk:  # filter out keep-alive new chunks
+                        temp_file.write(chunk)
+                return Path(temp_file.name)
         except requests.Timeout:
             raise TimeoutError('Copy operation timed out')
 
@@ -658,21 +660,21 @@ class EventStreamRuntime(Runtime):
             ):  # cached value
                 return self._vscode_url
 
-            response = send_request(
+            with send_request(
                 self.session,
                 'GET',
                 f'{self.api_url}/vscode/connection_token',
                 timeout=10,
-            )
-            response_json = response.json()
-            assert isinstance(response_json, dict)
-            if response_json['token'] is None:
-                return None
-            self._vscode_url = f'http://localhost:{self._host_port + 1}/?tkn={response_json["token"]}&folder={self.config.workspace_mount_path_in_sandbox}'
-            self.log(
-                'debug',
-                f'VSCode URL: {self._vscode_url}',
-            )
-            return self._vscode_url
+            ) as response:
+                response_json = response.json()
+                assert isinstance(response_json, dict)
+                if response_json['token'] is None:
+                    return None
+                self._vscode_url = f'http://localhost:{self._host_port + 1}/?tkn={response_json["token"]}&folder={self.config.workspace_mount_path_in_sandbox}'
+                self.log(
+                    'debug',
+                    f'VSCode URL: {self._vscode_url}',
+                )
+                return self._vscode_url
         else:
             return None

+ 51 - 53
openhands/runtime/impl/remote/remote_runtime.py

@@ -141,29 +141,29 @@ class RemoteRuntime(Runtime):
 
     def _check_existing_runtime(self) -> bool:
         try:
-            response = self._send_request(
+            with self._send_request(
                 'GET',
                 f'{self.config.sandbox.remote_runtime_api_url}/sessions/{self.sid}',
                 is_retry=False,
                 timeout=5,
-            )
+            ) as response:
+                data = response.json()
+                status = data.get('status')
+                if status == 'running' or status == 'paused':
+                    self._parse_runtime_response(response)
         except requests.HTTPError as e:
             if e.response.status_code == 404:
                 return False
             self.log('debug', f'Error while looking for remote runtime: {e}')
             raise
 
-        data = response.json()
-        status = data.get('status')
         if status == 'running':
-            self._parse_runtime_response(response)
             return True
         elif status == 'stopped':
             self.log('debug', 'Found existing remote runtime, but it is stopped')
             return False
         elif status == 'paused':
             self.log('debug', 'Found existing remote runtime, but it is paused')
-            self._parse_runtime_response(response)
             self._resume_runtime()
             return True
         else:
@@ -172,13 +172,13 @@ class RemoteRuntime(Runtime):
 
     def _build_runtime(self):
         self.log('debug', f'Building RemoteRuntime config:\n{self.config}')
-        response = self._send_request(
+        with self._send_request(
             'GET',
             f'{self.config.sandbox.remote_runtime_api_url}/registry_prefix',
             is_retry=False,
             timeout=10,
-        )
-        response_json = response.json()
+        ) as response:
+            response_json = response.json()
         registry_prefix = response_json['registry_prefix']
         os.environ['OH_RUNTIME_RUNTIME_IMAGE_REPO'] = (
             registry_prefix.rstrip('/') + '/runtime'
@@ -203,15 +203,17 @@ class RemoteRuntime(Runtime):
             force_rebuild=self.config.sandbox.force_rebuild_runtime,
         )
 
-        response = self._send_request(
+        with self._send_request(
             'GET',
             f'{self.config.sandbox.remote_runtime_api_url}/image_exists',
             is_retry=False,
             params={'image': self.container_image},
             timeout=10,
-        )
-        if not response.json()['exists']:
-            raise RuntimeError(f'Container image {self.container_image} does not exist')
+        ) as response:
+            if not response.json()['exists']:
+                raise RuntimeError(
+                    f'Container image {self.container_image} does not exist'
+                )
 
     def _start_runtime(self):
         # Prepare the request body for the /start endpoint
@@ -240,26 +242,27 @@ class RemoteRuntime(Runtime):
         }
 
         # Start the sandbox using the /start endpoint
-        response = self._send_request(
+        with self._send_request(
             'POST',
             f'{self.config.sandbox.remote_runtime_api_url}/start',
             is_retry=False,
             json=start_request,
-        )
-        self._parse_runtime_response(response)
+        ) as response:
+            self._parse_runtime_response(response)
         self.log(
             'debug',
             f'Runtime started. URL: {self.runtime_url}',
         )
 
     def _resume_runtime(self):
-        self._send_request(
+        with self._send_request(
             'POST',
             f'{self.config.sandbox.remote_runtime_api_url}/resume',
             is_retry=False,
             json={'runtime_id': self.runtime_id},
             timeout=30,
-        )
+        ):
+            pass
         self.log('debug', 'Runtime resumed.')
 
     def _parse_runtime_response(self, response: requests.Response):
@@ -279,12 +282,12 @@ class RemoteRuntime(Runtime):
             ):  # cached value
                 return self._vscode_url
 
-            response = self._send_request(
+            with self._send_request(
                 'GET',
                 f'{self.runtime_url}/vscode/connection_token',
                 timeout=10,
-            )
-            response_json = response.json()
+            ) as response:
+                response_json = response.json()
             assert isinstance(response_json, dict)
             if response_json['token'] is None:
                 return None
@@ -316,11 +319,11 @@ class RemoteRuntime(Runtime):
 
     def _wait_until_alive_impl(self):
         self.log('debug', f'Waiting for runtime to be alive at url: {self.runtime_url}')
-        runtime_info_response = self._send_request(
+        with self._send_request(
             'GET',
             f'{self.config.sandbox.remote_runtime_api_url}/sessions/{self.sid}',
-        )
-        runtime_data = runtime_info_response.json()
+        ) as runtime_info_response:
+            runtime_data = runtime_info_response.json()
         assert 'runtime_id' in runtime_data
         assert runtime_data['runtime_id'] == self.runtime_id
         assert 'pod_status' in runtime_data
@@ -332,10 +335,11 @@ class RemoteRuntime(Runtime):
         # Retry a period of time to give the cluster time to start the pod
         if pod_status == 'Ready':
             try:
-                self._send_request(
+                with self._send_request(
                     'GET',
                     f'{self.runtime_url}/alive',
-                )  # will raise exception if we don't get 200 back.
+                ):  # will raise exception if we don't get 200 back.
+                    pass
             except requests.HTTPError as e:
                 self.log(
                     'warning', f"Runtime /alive failed, but pod says it's ready: {e}"
@@ -374,19 +378,13 @@ class RemoteRuntime(Runtime):
             return
         if self.runtime_id and self.session:
             try:
-                response = self._send_request(
+                with self._send_request(
                     'POST',
                     f'{self.config.sandbox.remote_runtime_api_url}/stop',
                     is_retry=False,
                     json={'runtime_id': self.runtime_id},
                     timeout=timeout,
-                )
-                if response.status_code != 200:
-                    self.log(
-                        'error',
-                        f'Failed to stop runtime: {response.text}',
-                    )
-                else:
+                ):
                     self.log('debug', 'Runtime stopped.')
             except Exception as e:
                 raise e
@@ -415,15 +413,15 @@ class RemoteRuntime(Runtime):
             try:
                 request_body = {'action': event_to_dict(action)}
                 self.log('debug', f'Request body: {request_body}')
-                response = self._send_request(
+                with self._send_request(
                     'POST',
                     f'{self.runtime_url}/execute_action',
                     is_retry=False,
                     json=request_body,
                     # wait a few more seconds to get the timeout error from client side
                     timeout=action.timeout + 5,
-                )
-                output = response.json()
+                ) as response:
+                    output = response.json()
                 obs = observation_from_dict(output)
                 obs._cause = action.id  # type: ignore[attr-defined]
             except requests.Timeout:
@@ -502,18 +500,18 @@ class RemoteRuntime(Runtime):
 
             params = {'destination': sandbox_dest, 'recursive': str(recursive).lower()}
 
-            response = self._send_request(
+            with self._send_request(
                 'POST',
                 f'{self.runtime_url}/upload_file',
                 is_retry=False,
                 files=upload_data,
                 params=params,
                 timeout=300,
-            )
-            self.log(
-                'debug',
-                f'Copy completed: host:{host_src} -> runtime:{sandbox_dest}. Response: {response.text}',
-            )
+            ) as response:
+                self.log(
+                    'debug',
+                    f'Copy completed: host:{host_src} -> runtime:{sandbox_dest}. Response: {response.text}',
+                )
         finally:
             if recursive:
                 os.unlink(temp_zip_path)
@@ -526,30 +524,30 @@ class RemoteRuntime(Runtime):
         if path is not None:
             data['path'] = path
 
-        response = self._send_request(
+        with self._send_request(
             'POST',
             f'{self.runtime_url}/list_files',
             is_retry=False,
             json=data,
             timeout=30,
-        )
-        response_json = response.json()
+        ) as response:
+            response_json = response.json()
         assert isinstance(response_json, list)
         return response_json
 
     def copy_from(self, path: str) -> Path:
         """Zip all files in the sandbox and return as a stream of bytes."""
         params = {'path': path}
-        response = self._send_request(
+        with self._send_request(
             'GET',
             f'{self.runtime_url}/download_files',
             is_retry=False,
             params=params,
             stream=True,
             timeout=30,
-        )
-        temp_file = tempfile.NamedTemporaryFile(delete=False)
-        for chunk in response.iter_content(chunk_size=8192):
-            if chunk:  # filter out keep-alive new chunks
-                temp_file.write(chunk)
-        return Path(temp_file.name)
+        ) as response:
+            temp_file = tempfile.NamedTemporaryFile(delete=False)
+            for chunk in response.iter_content(chunk_size=8192):
+                if chunk:  # filter out keep-alive new chunks
+                    temp_file.write(chunk)
+            return Path(temp_file.name)

+ 1 - 5
openhands/runtime/utils/request.py

@@ -58,9 +58,5 @@ def send_request(
     **kwargs: Any,
 ) -> requests.Response:
     response = session.request(method, url, **kwargs)
-    try:
-        response.raise_for_status()
-    finally:
-        response.close()
-
+    response.raise_for_status()
     return response