Browse Source

Refactor CORS middleware and enhance localhost handling (#4624)

Co-authored-by: openhands <openhands@all-hands.dev>
tofarr 1 year ago
parent
commit
05645d1bbd
2 changed files with 45 additions and 20 deletions
  1. 2 20
      openhands/server/listen.py
  2. 43 0
      openhands/server/middleware.py

+ 2 - 20
openhands/server/listen.py

@@ -29,12 +29,10 @@ from fastapi import (
     WebSocket,
     WebSocket,
     status,
     status,
 )
 )
-from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse, StreamingResponse
 from fastapi.responses import JSONResponse, StreamingResponse
 from fastapi.security import HTTPBearer
 from fastapi.security import HTTPBearer
 from fastapi.staticfiles import StaticFiles
 from fastapi.staticfiles import StaticFiles
 from pydantic import BaseModel
 from pydantic import BaseModel
-from starlette.middleware.base import BaseHTTPMiddleware
 
 
 import openhands.agenthub  # noqa F401 (we import this to get the agents registered)
 import openhands.agenthub  # noqa F401 (we import this to get the agents registered)
 from openhands.controller.agent import Agent
 from openhands.controller.agent import Agent
@@ -57,6 +55,7 @@ from openhands.events.serialization import event_to_dict
 from openhands.llm import bedrock
 from openhands.llm import bedrock
 from openhands.runtime.base import Runtime
 from openhands.runtime.base import Runtime
 from openhands.server.auth import get_sid_from_token, sign_token
 from openhands.server.auth import get_sid_from_token, sign_token
+from openhands.server.middleware import LocalhostCORSMiddleware, NoCacheMiddleware
 from openhands.server.session import SessionManager
 from openhands.server.session import SessionManager
 
 
 load_dotenv()
 load_dotenv()
@@ -93,30 +92,13 @@ async def lifespan(app: FastAPI):
 
 
 app = FastAPI(lifespan=lifespan)
 app = FastAPI(lifespan=lifespan)
 app.add_middleware(
 app.add_middleware(
-    CORSMiddleware,
-    allow_origins=['http://localhost:3001', 'http://127.0.0.1:3001'],
+    LocalhostCORSMiddleware,
     allow_credentials=True,
     allow_credentials=True,
     allow_methods=['*'],
     allow_methods=['*'],
     allow_headers=['*'],
     allow_headers=['*'],
 )
 )
 
 
 
 
-class NoCacheMiddleware(BaseHTTPMiddleware):
-    """
-    Middleware to disable caching for all routes by adding appropriate headers
-    """
-
-    async def dispatch(self, request, call_next):
-        response = await call_next(request)
-        if not request.url.path.startswith('/assets'):
-            response.headers['Cache-Control'] = (
-                'no-cache, no-store, must-revalidate, max-age=0'
-            )
-            response.headers['Pragma'] = 'no-cache'
-            response.headers['Expires'] = '0'
-        return response
-
-
 app.add_middleware(NoCacheMiddleware)
 app.add_middleware(NoCacheMiddleware)
 
 
 security_scheme = HTTPBearer()
 security_scheme = HTTPBearer()

+ 43 - 0
openhands/server/middleware.py

@@ -0,0 +1,43 @@
+from urllib.parse import urlparse
+
+from fastapi.middleware.cors import CORSMiddleware
+from starlette.middleware.base import BaseHTTPMiddleware
+from starlette.types import ASGIApp
+
+
+class LocalhostCORSMiddleware(CORSMiddleware):
+    """
+    Custom CORS middleware that allows any request from localhost/127.0.0.1 domains,
+    while using standard CORS rules for other origins.
+    """
+
+    def __init__(self, app: ASGIApp, **kwargs) -> None:
+        super().__init__(app, **kwargs)
+
+    async def is_allowed_origin(self, origin: str) -> bool:
+        if origin:
+            parsed = urlparse(origin)
+            hostname = parsed.hostname or ''
+
+            # Allow any localhost/127.0.0.1 origin regardless of port
+            if hostname in ['localhost', '127.0.0.1']:
+                return True
+
+        # For missing origin or other origins, use the parent class's logic
+        return await super().is_allowed_origin(origin)
+
+
+class NoCacheMiddleware(BaseHTTPMiddleware):
+    """
+    Middleware to disable caching for all routes by adding appropriate headers
+    """
+
+    async def dispatch(self, request, call_next):
+        response = await call_next(request)
+        if not request.url.path.startswith('/assets'):
+            response.headers['Cache-Control'] = (
+                'no-cache, no-store, must-revalidate, max-age=0'
+            )
+            response.headers['Pragma'] = 'no-cache'
+            response.headers['Expires'] = '0'
+        return response