瀏覽代碼

feat: clean-up retries RemoteRuntime & add FatalErrorObservation (#4485)

Xingyao Wang 1 年之前
父節點
當前提交
91308ba4dc

+ 8 - 0
evaluation/swe_bench/run_infer.py

@@ -11,6 +11,7 @@ from datasets import load_dataset
 import openhands.agenthub
 from evaluation.swe_bench.prompt import CODEACT_SWE_PROMPT
 from evaluation.utils.shared import (
+    EvalException,
     EvalMetadata,
     EvalOutput,
     assert_and_raise,
@@ -384,6 +385,13 @@ def process_instance(
             )
         )
 
+        # if fatal error, throw EvalError to trigger re-run
+        if (
+            state.last_error
+            and 'fatal error during agent execution' in state.last_error
+        ):
+            raise EvalException('Fatal error detected: ' + state.last_error)
+
         # ======= THIS IS SWE-Bench specific =======
         # Get git patch
         return_val = complete_runtime(runtime, instance)

+ 7 - 0
openhands/controller/agent_controller.py

@@ -35,6 +35,7 @@ from openhands.events.observation import (
     AgentStateChangedObservation,
     CmdOutputObservation,
     ErrorObservation,
+    FatalErrorObservation,
     Observation,
 )
 from openhands.events.serialization.event import truncate_content
@@ -249,6 +250,12 @@ class AgentController:
         elif isinstance(observation, ErrorObservation):
             if self.state.agent_state == AgentState.ERROR:
                 self.state.metrics.merge(self.state.local_metrics)
+        elif isinstance(observation, FatalErrorObservation):
+            await self.report_error(
+                'There was a fatal error during agent execution: ' + str(observation)
+            )
+            await self.set_agent_state_to(AgentState.ERROR)
+            self.state.metrics.merge(self.state.local_metrics)
 
     async def _handle_message_action(self, action: MessageAction):
         """Handles message actions from the event stream.

+ 6 - 2
openhands/events/observation/__init__.py

@@ -6,8 +6,11 @@ from openhands.events.observation.commands import (
 )
 from openhands.events.observation.delegate import AgentDelegateObservation
 from openhands.events.observation.empty import NullObservation
-from openhands.events.observation.error import ErrorObservation
-from openhands.events.observation.files import FileReadObservation, FileWriteObservation
+from openhands.events.observation.error import ErrorObservation, FatalErrorObservation
+from openhands.events.observation.files import (
+    FileReadObservation,
+    FileWriteObservation,
+)
 from openhands.events.observation.observation import Observation
 from openhands.events.observation.reject import UserRejectObservation
 from openhands.events.observation.success import SuccessObservation
@@ -21,6 +24,7 @@ __all__ = [
     'FileReadObservation',
     'FileWriteObservation',
     'ErrorObservation',
+    'FatalErrorObservation',
     'AgentStateChangedObservation',
     'AgentDelegateObservation',
     'SuccessObservation',

+ 22 - 1
openhands/events/observation/error.py

@@ -6,10 +6,31 @@ from openhands.events.observation.observation import Observation
 
 @dataclass
 class ErrorObservation(Observation):
-    """This data class represents an error encountered by the agent."""
+    """This data class represents an error encountered by the agent.
+
+    This is the type of error that LLM can recover from.
+    E.g., Linter error after editing a file.
+    """
 
     observation: str = ObservationType.ERROR
 
     @property
     def message(self) -> str:
         return self.content
+
+    def __str__(self) -> str:
+        return f'**ErrorObservation**\n{self.content}'
+
+
+@dataclass
+class FatalErrorObservation(Observation):
+    """This data class represents a fatal error encountered by the agent.
+
+    This is the type of error that LLM CANNOT recover from, and the agent controller should stop the execution and report the error to the user.
+    E.g., Remote runtime action execution failure: 503 Server Error: Service Unavailable for url OR 404 Not Found.
+    """
+
+    observation: str = ObservationType.ERROR
+
+    def __str__(self) -> str:
+        return f'**FatalErrorObservation**\n{self.content}'

+ 22 - 8
openhands/runtime/client/runtime.py

@@ -23,7 +23,7 @@ from openhands.events.action import (
 )
 from openhands.events.action.action import Action
 from openhands.events.observation import (
-    ErrorObservation,
+    FatalErrorObservation,
     NullObservation,
     Observation,
     UserRejectObservation,
@@ -126,7 +126,13 @@ class EventStreamRuntime(Runtime):
         attach_to_existing: bool = False,
     ):
         super().__init__(
-            config, event_stream, sid, plugins, env_vars, status_message_callback, attach_to_existing
+            config,
+            event_stream,
+            sid,
+            plugins,
+            env_vars,
+            status_message_callback,
+            attach_to_existing,
         )
 
     def __init__(
@@ -192,7 +198,13 @@ class EventStreamRuntime(Runtime):
 
         # Will initialize both the event stream and the env vars
         self.init_base_runtime(
-            config, event_stream, sid, plugins, env_vars, status_message_callback, attach_to_existing
+            config,
+            event_stream,
+            sid,
+            plugins,
+            env_vars,
+            status_message_callback,
+            attach_to_existing,
         )
 
         logger.info('Waiting for client to become ready...')
@@ -431,9 +443,9 @@ class EventStreamRuntime(Runtime):
                 return NullObservation('')
             action_type = action.action  # type: ignore[attr-defined]
             if action_type not in ACTION_TYPE_TO_CLASS:
-                return ErrorObservation(f'Action {action_type} does not exist.')
+                return FatalErrorObservation(f'Action {action_type} does not exist.')
             if not hasattr(self, action_type):
-                return ErrorObservation(
+                return FatalErrorObservation(
                     f'Action {action_type} is not supported in the current runtime.'
                 )
             if (
@@ -465,15 +477,17 @@ class EventStreamRuntime(Runtime):
                     logger.debug(f'response: {response}')
                     error_message = response.text
                     logger.error(f'Error from server: {error_message}')
-                    obs = ErrorObservation(f'Action execution failed: {error_message}')
+                    obs = FatalErrorObservation(
+                        f'Action execution failed: {error_message}'
+                    )
             except requests.Timeout:
                 logger.error('No response received within the timeout period.')
-                obs = ErrorObservation(
+                obs = FatalErrorObservation(
                     f'Action execution timed out after {action.timeout} seconds.'
                 )
             except Exception as e:
                 logger.error(f'Error during action execution: {e}')
-                obs = ErrorObservation(f'Action execution failed: {str(e)}')
+                obs = FatalErrorObservation(f'Action execution failed: {str(e)}')
             self._refresh_logs()
             return obs
 

+ 59 - 46
openhands/runtime/remote/runtime.py

@@ -21,7 +21,7 @@ from openhands.events.action import (
 )
 from openhands.events.action.action import Action
 from openhands.events.observation import (
-    ErrorObservation,
+    FatalErrorObservation,
     NullObservation,
     Observation,
 )
@@ -31,8 +31,8 @@ from openhands.runtime.builder.remote import RemoteRuntimeBuilder
 from openhands.runtime.plugins import PluginRequirement
 from openhands.runtime.runtime import Runtime
 from openhands.runtime.utils.request import (
-    DEFAULT_RETRY_EXCEPTIONS,
     is_404_error,
+    is_503_error,
     send_request_with_retry,
 )
 from openhands.runtime.utils.runtime_build import build_runtime_image
@@ -90,7 +90,6 @@ class RemoteRuntime(Runtime):
             status_message_callback,
             attach_to_existing,
         )
-        self._wait_until_alive()
         self.setup_initial_env()
 
     def _start_or_attach_to_runtime(
@@ -232,10 +231,12 @@ class RemoteRuntime(Runtime):
             timeout=300,
         )
         if response.status_code != 201:
-            raise RuntimeError(f'Failed to start sandbox: {response.text}')
+            raise RuntimeError(
+                f'[Runtime (ID={self.runtime_id})] Failed to start runtime: {response.text}'
+            )
         self._parse_runtime_response(response)
         logger.info(
-            f'Sandbox started. Runtime ID: {self.runtime_id}, URL: {self.runtime_url}'
+            f'[Runtime (ID={self.runtime_id})] Runtime started. URL: {self.runtime_url}'
         )
 
     def _resume_runtime(self):
@@ -247,8 +248,10 @@ class RemoteRuntime(Runtime):
             timeout=30,
         )
         if response.status_code != 200:
-            raise RuntimeError(f'Failed to resume sandbox: {response.text}')
-        logger.info(f'Sandbox resumed. Runtime ID: {self.runtime_id}')
+            raise RuntimeError(
+                f'[Runtime (ID={self.runtime_id})] Failed to resume runtime: {response.text}'
+            )
+        logger.info(f'[Runtime (ID={self.runtime_id})] Runtime resumed.')
 
     def _parse_runtime_response(self, response: requests.Response):
         start_response = response.json()
@@ -298,7 +301,7 @@ class RemoteRuntime(Runtime):
                 # clean up the runtime
                 self.close()
                 raise RuntimeError(
-                    f'Runtime pod failed to start. Current status: {pod_status}'
+                    f'Runtime (ID={self.runtime_id}) failed to start. Current status: {pod_status}'
                 )
             # Pending otherwise - add proper sleep
             time.sleep(10)
@@ -307,15 +310,15 @@ class RemoteRuntime(Runtime):
             self.session,
             'GET',
             f'{self.runtime_url}/alive',
-            # Retry 404 errors for the /alive endpoint
+            # Retry 404 & 503 errors for the /alive endpoint
             # because the runtime might just be starting up
             # and have not registered the endpoint yet
-            retry_fns=[is_404_error],
+            retry_fns=[is_404_error, is_503_error],
             # leave enough time for the runtime to start up
             timeout=600,
         )
         if response.status_code != 200:
-            msg = f'Runtime is not alive yet (id={self.runtime_id}). Status: {response.status_code}.'
+            msg = f'Runtime (ID={self.runtime_id}) is not alive yet. Status: {response.status_code}.'
             logger.warning(msg)
             raise RuntimeError(msg)
 
@@ -333,9 +336,11 @@ class RemoteRuntime(Runtime):
                     timeout=timeout,
                 )
                 if response.status_code != 200:
-                    logger.error(f'Failed to stop sandbox: {response.text}')
+                    logger.error(
+                        f'[Runtime (ID={self.runtime_id})] Failed to stop runtime: {response.text}'
+                    )
                 else:
-                    logger.info(f'Sandbox stopped. Runtime ID: {self.runtime_id}')
+                    logger.info(f'[Runtime (ID={self.runtime_id})] Runtime stopped.')
             except Exception as e:
                 raise e
             finally:
@@ -349,16 +354,17 @@ class RemoteRuntime(Runtime):
                 return NullObservation('')
             action_type = action.action  # type: ignore[attr-defined]
             if action_type not in ACTION_TYPE_TO_CLASS:
-                return ErrorObservation(f'Action {action_type} does not exist.')
+                return FatalErrorObservation(
+                    f'[Runtime (ID={self.runtime_id})] Action {action_type} does not exist.'
+                )
             if not hasattr(self, action_type):
-                return ErrorObservation(
-                    f'Action {action_type} is not supported in the current runtime.'
+                return FatalErrorObservation(
+                    f'[Runtime (ID={self.runtime_id})] Action {action_type} is not supported in the current runtime.'
                 )
 
             assert action.timeout is not None
 
             try:
-                logger.info('Executing action')
                 request_body = {'action': event_to_dict(action)}
                 logger.debug(f'Request body: {request_body}')
                 response = send_request_with_retry(
@@ -367,13 +373,6 @@ class RemoteRuntime(Runtime):
                     f'{self.runtime_url}/execute_action',
                     json=request_body,
                     timeout=action.timeout,
-                    retry_exceptions=list(
-                        filter(lambda e: e != TimeoutError, DEFAULT_RETRY_EXCEPTIONS)
-                    ),
-                    # Retry 404 errors for the /execute_action endpoint
-                    # because the runtime might just be starting up
-                    # and have not registered the endpoint yet
-                    retry_fns=[is_404_error],
                 )
                 if response.status_code == 200:
                     output = response.json()
@@ -383,13 +382,19 @@ class RemoteRuntime(Runtime):
                 else:
                     error_message = response.text
                     logger.error(f'Error from server: {error_message}')
-                    obs = ErrorObservation(f'Action execution failed: {error_message}')
+                    obs = FatalErrorObservation(
+                        f'Action execution failed: {error_message}'
+                    )
             except Timeout:
                 logger.error('No response received within the timeout period.')
-                obs = ErrorObservation('Action execution timed out')
+                obs = FatalErrorObservation(
+                    f'[Runtime (ID={self.runtime_id})] Action execution timed out'
+                )
             except Exception as e:
                 logger.error(f'Error during action execution: {e}')
-                obs = ErrorObservation(f'Action execution failed: {str(e)}')
+                obs = FatalErrorObservation(
+                    f'[Runtime (ID={self.runtime_id})] Action execution failed: {str(e)}'
+                )
             return obs
 
     def run(self, action: CmdRunAction) -> Observation:
@@ -444,9 +449,6 @@ class RemoteRuntime(Runtime):
                 f'{self.runtime_url}/upload_file',
                 files=upload_data,
                 params=params,
-                retry_exceptions=list(
-                    filter(lambda e: e != TimeoutError, DEFAULT_RETRY_EXCEPTIONS)
-                ),
                 timeout=300,
             )
             if response.status_code == 200:
@@ -456,11 +458,17 @@ class RemoteRuntime(Runtime):
                 return
             else:
                 error_message = response.text
-                raise Exception(f'Copy operation failed: {error_message}')
+                raise Exception(
+                    f'[Runtime (ID={self.runtime_id})] Copy operation failed: {error_message}'
+                )
         except TimeoutError:
-            raise TimeoutError('Copy operation timed out')
+            raise TimeoutError(
+                f'[Runtime (ID={self.runtime_id})] Copy operation timed out'
+            )
         except Exception as e:
-            raise RuntimeError(f'Copy operation failed: {str(e)}')
+            raise RuntimeError(
+                f'[Runtime (ID={self.runtime_id})] Copy operation failed: {str(e)}'
+            )
         finally:
             if recursive:
                 os.unlink(temp_zip_path)
@@ -477,9 +485,6 @@ class RemoteRuntime(Runtime):
                 'POST',
                 f'{self.runtime_url}/list_files',
                 json=data,
-                retry_exceptions=list(
-                    filter(lambda e: e != TimeoutError, DEFAULT_RETRY_EXCEPTIONS)
-                ),
                 timeout=30,
             )
             if response.status_code == 200:
@@ -488,15 +493,20 @@ class RemoteRuntime(Runtime):
                 return response_json
             else:
                 error_message = response.text
-                raise Exception(f'List files operation failed: {error_message}')
+                raise Exception(
+                    f'[Runtime (ID={self.runtime_id})] List files operation failed: {error_message}'
+                )
         except TimeoutError:
-            raise TimeoutError('List files operation timed out')
+            raise TimeoutError(
+                f'[Runtime (ID={self.runtime_id})] List files operation timed out'
+            )
         except Exception as e:
-            raise RuntimeError(f'List files operation failed: {str(e)}')
+            raise RuntimeError(
+                f'[Runtime (ID={self.runtime_id})] List files operation failed: {str(e)}'
+            )
 
     def copy_from(self, path: str) -> bytes:
         """Zip all files in the sandbox and return as a stream of bytes."""
-        self._wait_until_alive()
         try:
             params = {'path': path}
             response = send_request_with_retry(
@@ -505,19 +515,22 @@ class RemoteRuntime(Runtime):
                 f'{self.runtime_url}/download_files',
                 params=params,
                 timeout=30,
-                retry_exceptions=list(
-                    filter(lambda e: e != TimeoutError, DEFAULT_RETRY_EXCEPTIONS)
-                ),
             )
             if response.status_code == 200:
                 return response.content
             else:
                 error_message = response.text
-                raise Exception(f'Copy operation failed: {error_message}')
+                raise Exception(
+                    f'[Runtime (ID={self.runtime_id})] Copy operation failed: {error_message}'
+                )
         except requests.Timeout:
-            raise TimeoutError('Copy operation timed out')
+            raise TimeoutError(
+                f'[Runtime (ID={self.runtime_id})] Copy operation timed out'
+            )
         except Exception as e:
-            raise RuntimeError(f'Copy operation failed: {str(e)}')
+            raise RuntimeError(
+                f'[Runtime (ID={self.runtime_id})] Copy operation failed: {str(e)}'
+            )
 
     def send_status_message(self, message: str):
         """Sends a status message if the callback function was provided."""

+ 22 - 3
openhands/runtime/utils/request.py

@@ -1,7 +1,10 @@
 from typing import Any, Callable, Type
 
 import requests
-from requests.exceptions import ConnectionError, Timeout
+from requests.exceptions import (
+    ChunkedEncodingError,
+    ConnectionError,
+)
 from tenacity import (
     retry,
     retry_if_exception,
@@ -9,6 +12,7 @@ from tenacity import (
     stop_after_delay,
     wait_exponential,
 )
+from urllib3.exceptions import IncompleteRead
 
 from openhands.utils.tenacity_stop import stop_if_should_exit
 
@@ -27,9 +31,24 @@ def is_404_error(exception):
     )
 
 
+def is_503_error(exception):
+    return (
+        isinstance(exception, requests.HTTPError)
+        and exception.response.status_code == 503
+    )
+
+
+def is_502_error(exception):
+    return (
+        isinstance(exception, requests.HTTPError)
+        and exception.response.status_code == 502
+    )
+
+
 DEFAULT_RETRY_EXCEPTIONS = [
     ConnectionError,
-    Timeout,
+    IncompleteRead,
+    ChunkedEncodingError,
 ]
 
 
@@ -45,7 +64,7 @@ def send_request_with_retry(
     exceptions_to_catch = retry_exceptions or DEFAULT_RETRY_EXCEPTIONS
     retry_condition = retry_if_exception_type(
         tuple(exceptions_to_catch)
-    ) | retry_if_exception(is_server_error)
+    ) | retry_if_exception(is_502_error)
     if retry_fns is not None:
         for fn in retry_fns:
             retry_condition |= retry_if_exception(fn)