Explorar el Código

Feat Tightening up Timeouts and interrupt conditions. (#3926)

tofarr hace 1 año
padre
commit
ad0b549d8b

+ 3 - 3
evaluation/swe_bench/run_infer.py

@@ -2,7 +2,6 @@ import asyncio
 import json
 import os
 import tempfile
-import time
 from typing import Any
 
 import pandas as pd
@@ -32,6 +31,7 @@ from openhands.core.main import create_runtime, run_controller
 from openhands.events.action import CmdRunAction
 from openhands.events.observation import CmdOutputObservation, ErrorObservation
 from openhands.runtime.runtime import Runtime
+from openhands.runtime.utils.shutdown_listener import sleep_if_should_continue
 
 USE_HINT_TEXT = os.environ.get('USE_HINT_TEXT', 'false').lower() == 'true'
 USE_INSTANCE_IMAGE = os.environ.get('USE_INSTANCE_IMAGE', 'false').lower() == 'true'
@@ -316,10 +316,10 @@ def complete_runtime(
                 break
             else:
                 logger.info('Failed to get git diff, retrying...')
-                time.sleep(10)
+                sleep_if_should_continue(10)
         elif isinstance(obs, ErrorObservation):
             logger.error(f'Error occurred: {obs.content}. Retrying...')
-            time.sleep(10)
+            sleep_if_should_continue(10)
         else:
             raise ValueError(f'Unexpected observation type: {type(obs)}')
 

+ 2 - 1
openhands/controller/agent_controller.py

@@ -37,6 +37,7 @@ from openhands.events.observation import (
     Observation,
 )
 from openhands.llm.llm import LLM
+from openhands.runtime.utils.shutdown_listener import should_continue
 
 # note: RESUME is only available on web GUI
 TRAFFIC_CONTROL_REMINDER = (
@@ -148,7 +149,7 @@ class AgentController:
         """The main loop for the agent's step-by-step execution."""
 
         logger.info(f'[Agent Controller {self.id}] Starting step loop...')
-        while True:
+        while should_continue():
             try:
                 await self._step()
             except asyncio.CancelledError:

+ 2 - 1
openhands/events/stream.py

@@ -8,6 +8,7 @@ from openhands.core.logger import openhands_logger as logger
 from openhands.core.utils import json
 from openhands.events.event import Event, EventSource
 from openhands.events.serialization.event import event_from_dict, event_to_dict
+from openhands.runtime.utils.shutdown_listener import should_continue
 from openhands.storage import FileStore
 
 
@@ -85,7 +86,7 @@ class EventStream:
                 event_id -= 1
         else:
             event_id = start_id
-            while True:
+            while should_continue():
                 if end_id is not None and event_id > end_id:
                     break
                 try:

+ 2 - 1
openhands/llm/llm.py

@@ -5,6 +5,7 @@ from functools import partial
 from typing import Union
 
 from openhands.core.config import LLMConfig
+from openhands.runtime.utils.shutdown_listener import should_continue
 
 with warnings.catch_warnings():
     warnings.simplefilter('ignore')
@@ -296,7 +297,7 @@ class LLM:
             debug_message = self._get_debug_message(messages)
 
             async def check_stopped():
-                while True:
+                while should_continue():
                     if (
                         hasattr(self.config, 'on_cancel_requested_fn')
                         and self.config.on_cancel_requested_fn is not None

+ 3 - 2
openhands/runtime/browser/browser_env.py

@@ -16,6 +16,7 @@ from PIL import Image
 
 from openhands.core.exceptions import BrowserInitException
 from openhands.core.logger import openhands_logger as logger
+from openhands.runtime.utils.shutdown_listener import should_continue, should_exit
 
 BROWSER_EVAL_GET_GOAL_ACTION = 'GET_EVAL_GOAL'
 BROWSER_EVAL_GET_REWARDS_ACTION = 'GET_EVAL_REWARDS'
@@ -99,7 +100,7 @@ class BrowserEnv:
             self.eval_goal = obs['goal']
 
         logger.info('Browser env started.')
-        while True:
+        while should_continue():
             try:
                 if self.browser_side.poll(timeout=0.01):
                     unique_request_id, action_data = self.browser_side.recv()
@@ -157,7 +158,7 @@ class BrowserEnv:
         self.agent_side.send((unique_request_id, {'action': action_str}))
         start_time = time.time()
         while True:
-            if time.time() - start_time > timeout:
+            if should_exit() or time.time() - start_time > timeout:
                 raise TimeoutError('Browser environment took too long to respond.')
             if self.agent_side.poll(timeout=0.01):
                 response_id, obs = self.agent_side.recv()

+ 3 - 2
openhands/runtime/builder/remote.py

@@ -8,6 +8,7 @@ import requests
 from openhands.core.logger import openhands_logger as logger
 from openhands.runtime.builder import RuntimeBuilder
 from openhands.runtime.utils.request import send_request
+from openhands.runtime.utils.shutdown_listener import should_exit, sleep_if_should_continue
 
 
 class RemoteRuntimeBuilder(RuntimeBuilder):
@@ -57,7 +58,7 @@ class RemoteRuntimeBuilder(RuntimeBuilder):
         start_time = time.time()
         timeout = 30 * 60  # 20 minutes in seconds
         while True:
-            if time.time() - start_time > timeout:
+            if should_exit() or time.time() - start_time > timeout:
                 logger.error('Build timed out after 30 minutes')
                 raise RuntimeError('Build timed out after 30 minutes')
 
@@ -95,7 +96,7 @@ class RemoteRuntimeBuilder(RuntimeBuilder):
                 raise RuntimeError(error_message)
 
             # Wait before polling again
-            time.sleep(30)
+            sleep_if_should_continue(30)
 
     def image_exists(self, image_name: str) -> bool:
         """Checks if an image exists in the remote registry using the /image_exists endpoint."""

+ 2 - 1
openhands/runtime/plugins/jupyter/__init__.py

@@ -8,6 +8,7 @@ from openhands.events.observation import IPythonRunCellObservation
 from openhands.runtime.plugins.jupyter.execute_server import JupyterKernel
 from openhands.runtime.plugins.requirement import Plugin, PluginRequirement
 from openhands.runtime.utils import find_available_tcp_port
+from openhands.runtime.utils.shutdown_listener import should_continue
 
 
 @dataclass
@@ -38,7 +39,7 @@ class JupyterPlugin(Plugin):
         )
         # read stdout until the kernel gateway is ready
         output = ''
-        while True and self.gateway_process.stdout is not None:
+        while should_continue() and self.gateway_process.stdout is not None:
             line = self.gateway_process.stdout.readline().decode('utf-8')
             output += line
             if 'at' in line:

+ 1 - 0
openhands/runtime/utils/request.py

@@ -47,6 +47,7 @@ def send_request(
     if retry_fns is not None:
         for fn in retry_fns:
             retry_condition |= retry_if_exception(fn)
+    kwargs["timeout"] = timeout
 
     @retry(
         stop=stop_after_delay(timeout),

+ 60 - 0
openhands/runtime/utils/shutdown_listener.py

@@ -0,0 +1,60 @@
+"""
+This module monitors the app for shutdown signals
+"""
+import asyncio
+import signal
+import time
+from types import FrameType
+
+from uvicorn.server import HANDLED_SIGNALS
+
+_should_exit = None
+
+
+def _register_signal_handler(sig: signal.Signals):
+    original_handler = None
+
+    def handler(sig_: int, frame: FrameType | None):
+        global _should_exit
+        _should_exit = True    
+        if original_handler:
+            original_handler(sig_, frame)  # type: ignore[unreachable]
+
+    original_handler = signal.signal(sig, handler)
+
+
+def _register_signal_handlers():
+    global _should_exit
+    if _should_exit is not None:
+        return
+    _should_exit = False
+    for sig in HANDLED_SIGNALS:
+        _register_signal_handler(sig)
+
+
+def should_exit() -> bool:
+    _register_signal_handlers()
+    return bool(_should_exit)
+
+
+def should_continue() -> bool:
+    _register_signal_handlers()
+    return not _should_exit
+
+
+def sleep_if_should_continue(timeout: float):
+    if(timeout <= 1):
+        time.sleep(timeout)
+        return
+    start_time = time.time()
+    while (time.time() - start_time) < timeout and should_continue():
+        time.sleep(1)
+
+
+async def async_sleep_if_should_continue(timeout: float):
+    if(timeout <= 1):
+        await asyncio.sleep(timeout)
+        return
+    start_time = time.time()
+    while time.time() - start_time < timeout and should_continue():
+        await asyncio.sleep(1)

+ 2 - 1
openhands/server/mock/listen.py

@@ -2,6 +2,7 @@ import uvicorn
 from fastapi import FastAPI, WebSocket
 
 from openhands.core.schema import ActionType
+from openhands.runtime.utils.shutdown_listener import should_continue
 
 app = FastAPI()
 
@@ -15,7 +16,7 @@ async def websocket_endpoint(websocket: WebSocket):
     )
 
     try:
-        while True:
+        while should_continue():
             # receive message
             data = await websocket.receive_json()
             print(f'Received message: {data}')

+ 2 - 1
openhands/server/session/manager.py

@@ -5,6 +5,7 @@ from fastapi import WebSocket
 
 from openhands.core.config import AppConfig
 from openhands.core.logger import openhands_logger as logger
+from openhands.runtime.utils.shutdown_listener import should_continue
 from openhands.server.session.session import Session
 from openhands.storage.files import FileStore
 
@@ -47,7 +48,7 @@ class SessionManager:
         return await self.send(sid, {'message': message})
 
     async def _cleanup_sessions(self):
-        while True:
+        while should_continue():
             current_time = time.time()
             session_ids_to_remove = []
             for sid, session in list(self._sessions.items()):

+ 2 - 1
openhands/server/session/session.py

@@ -20,6 +20,7 @@ from openhands.events.observation import (
 from openhands.events.serialization import event_from_dict, event_to_dict
 from openhands.events.stream import EventStreamSubscriber
 from openhands.llm.llm import LLM
+from openhands.runtime.utils.shutdown_listener import should_continue
 from openhands.server.session.agent import AgentSession
 from openhands.storage.files import FileStore
 
@@ -53,7 +54,7 @@ class Session:
         try:
             if self.websocket is None:
                 return
-            while True:
+            while should_continue():
                 try:
                     data = await self.websocket.receive_json()
                 except ValueError: