浏览代码

Refactor CORS middleware and enhance localhost handling (#4624)

Co-authored-by: openhands <openhands@all-hands.dev>
tofarr 1 年之前
父节点
当前提交
05645d1bbd
共有 2 个文件被更改,包括 45 次插入20 次删除
  1. 2 20
      openhands/server/listen.py
  2. 43 0
      openhands/server/middleware.py

+ 2 - 20
openhands/server/listen.py

@@ -29,12 +29,10 @@ from fastapi import (
     WebSocket,
     status,
 )
-from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse, StreamingResponse
 from fastapi.security import HTTPBearer
 from fastapi.staticfiles import StaticFiles
 from pydantic import BaseModel
-from starlette.middleware.base import BaseHTTPMiddleware
 
 import openhands.agenthub  # noqa F401 (we import this to get the agents registered)
 from openhands.controller.agent import Agent
@@ -57,6 +55,7 @@ from openhands.events.serialization import event_to_dict
 from openhands.llm import bedrock
 from openhands.runtime.base import Runtime
 from openhands.server.auth import get_sid_from_token, sign_token
+from openhands.server.middleware import LocalhostCORSMiddleware, NoCacheMiddleware
 from openhands.server.session import SessionManager
 
 load_dotenv()
@@ -93,30 +92,13 @@ async def lifespan(app: FastAPI):
 
 app = FastAPI(lifespan=lifespan)
 app.add_middleware(
-    CORSMiddleware,
-    allow_origins=['http://localhost:3001', 'http://127.0.0.1:3001'],
+    LocalhostCORSMiddleware,
     allow_credentials=True,
     allow_methods=['*'],
     allow_headers=['*'],
 )
 
 
-class NoCacheMiddleware(BaseHTTPMiddleware):
-    """
-    Middleware to disable caching for all routes by adding appropriate headers
-    """
-
-    async def dispatch(self, request, call_next):
-        response = await call_next(request)
-        if not request.url.path.startswith('/assets'):
-            response.headers['Cache-Control'] = (
-                'no-cache, no-store, must-revalidate, max-age=0'
-            )
-            response.headers['Pragma'] = 'no-cache'
-            response.headers['Expires'] = '0'
-        return response
-
-
 app.add_middleware(NoCacheMiddleware)
 
 security_scheme = HTTPBearer()

+ 43 - 0
openhands/server/middleware.py

@@ -0,0 +1,43 @@
+from urllib.parse import urlparse
+
+from fastapi.middleware.cors import CORSMiddleware
+from starlette.middleware.base import BaseHTTPMiddleware
+from starlette.types import ASGIApp
+
+
+class LocalhostCORSMiddleware(CORSMiddleware):
+    """
+    Custom CORS middleware that allows any request from localhost/127.0.0.1 domains,
+    while using standard CORS rules for other origins.
+    """
+
+    def __init__(self, app: ASGIApp, **kwargs) -> None:
+        super().__init__(app, **kwargs)
+
+    async def is_allowed_origin(self, origin: str) -> bool:
+        if origin:
+            parsed = urlparse(origin)
+            hostname = parsed.hostname or ''
+
+            # Allow any localhost/127.0.0.1 origin regardless of port
+            if hostname in ['localhost', '127.0.0.1']:
+                return True
+
+        # For missing origin or other origins, use the parent class's logic
+        return await super().is_allowed_origin(origin)
+
+
+class NoCacheMiddleware(BaseHTTPMiddleware):
+    """
+    Middleware to disable caching for all routes by adding appropriate headers
+    """
+
+    async def dispatch(self, request, call_next):
+        response = await call_next(request)
+        if not request.url.path.startswith('/assets'):
+            response.headers['Cache-Control'] = (
+                'no-cache, no-store, must-revalidate, max-age=0'
+            )
+            response.headers['Pragma'] = 'no-cache'
+            response.headers['Expires'] = '0'
+        return response