| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168 |
- import asyncio
- from collections import defaultdict
- from datetime import datetime, timedelta
- from typing import Callable
- from urllib.parse import urlparse
- from fastapi import Request, status
- from fastapi.middleware.cors import CORSMiddleware
- from fastapi.responses import JSONResponse
- from starlette.middleware.base import BaseHTTPMiddleware
- from starlette.types import ASGIApp
- from openhands.server.shared import session_manager
- from openhands.server.types import SessionMiddlewareInterface
class LocalhostCORSMiddleware(CORSMiddleware):
    """
    CORS middleware that accepts every origin whose host is ``localhost`` or
    ``127.0.0.1`` (on any port) and applies the standard ``CORSMiddleware``
    rules to every other origin.
    """

    def __init__(self, app: ASGIApp, **kwargs) -> None:
        super().__init__(app, **kwargs)

    def is_allowed_origin(self, origin: str) -> bool:
        # A missing/empty origin has no host to inspect; let the parent decide.
        host = (urlparse(origin).hostname or '') if origin else ''

        # Port is deliberately ignored: any local dev server is trusted.
        if host in ('localhost', '127.0.0.1'):
            return True

        return super().is_allowed_origin(origin)
class NoCacheMiddleware(BaseHTTPMiddleware):
    """
    Middleware that adds cache-disabling headers (Cache-Control, Pragma,
    Expires) to every response, except responses for paths starting with
    '/assets', which are passed through untouched.
    """

    async def dispatch(self, request, call_next):
        response = await call_next(request)

        # Leave '/assets...' responses with whatever caching headers they have.
        if request.url.path.startswith('/assets'):
            return response

        headers = response.headers
        headers['Cache-Control'] = 'no-cache, no-store, must-revalidate, max-age=0'
        headers['Pragma'] = 'no-cache'
        headers['Expires'] = '0'
        return response
- class InMemoryRateLimiter:
- history: dict
- requests: int
- seconds: int
- sleep_seconds: int
- def __init__(self, requests: int = 2, seconds: int = 1, sleep_seconds: int = 1):
- self.requests = requests
- self.seconds = seconds
- self.sleep_seconds = sleep_seconds
- self.history = defaultdict(list)
- self.sleep_seconds = sleep_seconds
- def _clean_old_requests(self, key: str) -> None:
- now = datetime.now()
- cutoff = now - timedelta(seconds=self.seconds)
- self.history[key] = [ts for ts in self.history[key] if ts > cutoff]
- async def __call__(self, request: Request) -> bool:
- key = request.client.host
- now = datetime.now()
- self._clean_old_requests(key)
- self.history[key].append(now)
- if len(self.history[key]) > self.requests * 2:
- return False
- elif len(self.history[key]) > self.requests:
- if self.sleep_seconds > 0:
- await asyncio.sleep(self.sleep_seconds)
- return True
- else:
- return False
- return True
class RateLimitMiddleware(BaseHTTPMiddleware):
    """
    Middleware that consults ``rate_limiter`` for every request and answers
    with HTTP 429 (plus a Retry-After header) when the limiter returns False.
    """

    def __init__(self, app: ASGIApp, rate_limiter: InMemoryRateLimiter):
        super().__init__(app)
        self.rate_limiter = rate_limiter

    async def dispatch(self, request, call_next):
        ok = await self.rate_limiter(request)
        if not ok:
            return JSONResponse(
                status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                content={'message': 'Too many requests'},
                headers={'Retry-After': self._retry_after()},
            )
        return await call_next(request)

    def _retry_after(self) -> str:
        """
        Return the Retry-After value for a throttled client.

        Uses the limiter's window length (``seconds``) when it exposes one,
        falling back to the previous fixed value of 1. The value is rounded up
        to a positive integer because Retry-After only accepts whole seconds.
        """
        import math

        window = getattr(self.rate_limiter, 'seconds', 1)
        return str(max(1, math.ceil(window)))
class AttachConversationMiddleware(SessionMiddlewareInterface):
    """
    Attach a conversation to requests whose path addresses one, and detach it
    after the downstream handler finishes (even if it raises).

    The conversation id is the fourth '/'-separated segment of any path that
    starts with '/api/conversation' (e.g. '/api/conversations/<id>/...').
    """

    def __init__(self, app):
        self.app = app

    def _should_attach(self, request) -> bool:
        """
        Return True when ``request`` names a conversation, storing the id on
        ``request.state.sid``; OPTIONS requests are never attached.
        """
        if request.method == 'OPTIONS':
            return False

        path = request.url.path
        if not path.startswith('/api/conversation'):
            return False

        # FIXME: we should be able to use path_params
        segments = path.split('/')
        conversation_id = segments[3] if len(segments) > 3 else ''
        if not conversation_id:
            return False

        request.state.sid = conversation_id
        return True

    async def _attach_conversation(self, request: Request) -> JSONResponse | None:
        """
        Attach the conversation identified by ``request.state.sid`` through the
        session manager and store it on ``request.state.conversation``.

        Returns a 404 JSON response if the session manager yields nothing,
        otherwise None.
        """
        conversation = await session_manager.attach_to_conversation(
            request.state.sid
        )
        request.state.conversation = conversation
        if conversation:
            return None
        return JSONResponse(
            status_code=status.HTTP_404_NOT_FOUND,
            content={'error': 'Session not found'},
        )

    async def _detach_session(self, request: Request) -> None:
        """Detach ``request.state.conversation`` via the session manager."""
        await session_manager.detach_from_conversation(request.state.conversation)

    async def __call__(self, request: Request, call_next: Callable):
        if not self._should_attach(request):
            return await call_next(request)

        error_response = await self._attach_conversation(request)
        if error_response:
            return error_response

        try:
            return await call_next(request)
        finally:
            # Always release the conversation, even if the handler raised.
            await self._detach_session(request)
|