| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177 |
- import asyncio
- from collections import defaultdict
- from datetime import datetime, timedelta
- from typing import Callable
- from urllib.parse import urlparse
- import jwt
- from fastapi import APIRouter, Request, status
- from fastapi.middleware.cors import CORSMiddleware
- from fastapi.responses import JSONResponse
- from starlette.middleware.base import BaseHTTPMiddleware
- from starlette.types import ASGIApp
- from openhands.core.logger import openhands_logger as logger
- from openhands.server.auth import get_sid_from_token
- from openhands.server.github_utils import UserVerifier
- from openhands.server.shared import config, session_manager
class LocalhostCORSMiddleware(CORSMiddleware):
    """CORS middleware that trusts every ``localhost`` / ``127.0.0.1`` origin.

    An origin whose parsed hostname is ``localhost`` or ``127.0.0.1`` is accepted
    regardless of port or scheme. Any other origin, and an empty/missing one, is
    judged by the standard :class:`CORSMiddleware` rules configured via ``kwargs``.
    """

    def __init__(self, app: ASGIApp, **kwargs) -> None:
        super().__init__(app, **kwargs)

    def is_allowed_origin(self, origin: str) -> bool:
        # Only parse when there is something to parse; a falsy origin goes
        # straight to the parent's logic.
        local_hostnames = ('localhost', '127.0.0.1')
        hostname = urlparse(origin).hostname if origin else None
        if hostname in local_hostnames:
            return True
        return super().is_allowed_origin(origin)
class NoCacheMiddleware(BaseHTTPMiddleware):
    """Add no-cache headers to every response whose path is outside ``/assets``.

    Responses for paths starting with ``/assets`` are returned untouched; all
    others get ``Cache-Control``, ``Pragma`` and ``Expires`` headers that tell
    clients not to cache them.
    """

    async def dispatch(self, request, call_next):
        response = await call_next(request)
        # NOTE: plain prefix test, so e.g. '/assetsfoo' is also exempt.
        if request.url.path.startswith('/assets'):
            return response
        no_cache_headers = {
            'Cache-Control': 'no-cache, no-store, must-revalidate, max-age=0',
            'Pragma': 'no-cache',
            'Expires': '0',
        }
        for header_name, header_value in no_cache_headers.items():
            response.headers[header_name] = header_value
        return response
class InMemoryRateLimiter:
    """Sliding-window, per-client rate limiter that keeps its state in memory.

    Every call records a timestamp for the calling client (keyed by
    ``request.client.host``) and counts that client's requests within the last
    ``seconds`` seconds:

    * at most ``requests`` hits: allowed immediately;
    * more than ``requests`` but at most ``2 * requests`` hits: allowed after
      sleeping ``sleep_seconds``, or rejected if ``sleep_seconds`` is not
      positive;
    * more than ``2 * requests`` hits: rejected.

    Rejected requests are recorded as well, so a client that keeps sending stays
    throttled until it pauses for a whole window. State lives on the instance.
    """

    history: dict  # client key -> datetimes of that client's requests inside the window
    requests: int
    seconds: int
    sleep_seconds: int

    def __init__(self, requests: int = 2, seconds: int = 1, sleep_seconds: int = 1):
        self.requests = requests
        self.seconds = seconds
        self.sleep_seconds = sleep_seconds
        self.history = defaultdict(list)
        # When every client was last pruned; see _sweep_stale_clients.
        self._last_sweep = datetime.now()

    def _clean_old_requests(self, key: str) -> None:
        """Drop ``key``'s timestamps that fall outside the window.

        The key is removed entirely once it has no recent timestamps, so quiet
        clients do not leave empty lists behind.
        """
        cutoff = datetime.now() - timedelta(seconds=self.seconds)
        # .get() instead of [] so that pruning an unknown key does not create it.
        recent = [ts for ts in self.history.get(key, ()) if ts > cutoff]
        if recent:
            self.history[key] = recent
        else:
            self.history.pop(key, None)

    def _sweep_stale_clients(self, now: datetime) -> None:
        """Prune every client's history, at most once per window.

        ``__call__`` only prunes the current client, so without this sweep every
        client that ever made a request would keep an entry in ``history``.
        """
        if now - self._last_sweep < timedelta(seconds=self.seconds):
            return
        self._last_sweep = now
        for key in list(self.history):  # copy: _clean_old_requests may delete keys
            self._clean_old_requests(key)

    async def __call__(self, request: 'Request') -> bool:
        """Record one request from ``request``'s client and decide its fate.

        Returns True if the request may proceed (possibly after sleeping) and
        False if it should be rejected.
        """
        # request.client is None when the ASGI scope carries no client address;
        # bucket those requests together instead of raising AttributeError.
        key = request.client.host if request.client is not None else 'unknown'
        now = datetime.now()
        self._sweep_stale_clients(now)
        self._clean_old_requests(key)
        self.history[key].append(now)
        count = len(self.history[key])
        if count > self.requests * 2:
            return False
        if count > self.requests:
            if self.sleep_seconds > 0:
                # Soft limit: slow the client down instead of rejecting it.
                await asyncio.sleep(self.sleep_seconds)
                return True
            return False
        return True
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Answer with HTTP 429 whenever ``rate_limiter`` refuses a request.

    Allowed requests are passed on to the next handler unchanged.
    """

    def __init__(self, app: ASGIApp, rate_limiter: InMemoryRateLimiter):
        super().__init__(app)
        self.rate_limiter = rate_limiter

    async def dispatch(self, request, call_next):
        ok = await self.rate_limiter(request)
        if not ok:
            return JSONResponse(
                status_code=429,
                content={'message': 'Too many requests'},
                headers={'Retry-After': self._retry_after()},
            )
        return await call_next(request)

    def _retry_after(self) -> str:
        """Return the ``Retry-After`` value (whole seconds, at least 1)."""
        # After a full quiet window every recorded timestamp has expired, so the
        # limiter's window length is a safe wait. Limiters without a ``seconds``
        # attribute keep the previous fixed value of 1.
        window = getattr(self.rate_limiter, 'seconds', 1)
        return str(max(1, int(window)))
class AttachSessionMiddleware:
    """HTTP middleware that attaches a conversation to requests for selected routes.

    For a request whose path exactly matches one of ``target_router``'s route
    paths, and whose method is not ``OPTIONS`` (presumably CORS preflights), the
    middleware does the following:

    1. When ``UserVerifier().is_active()``, it requires a ``github_auth`` cookie
       that decodes as an HS256 JWT signed with ``config.jwt_secret``.
    2. It requires an ``Authorization`` header, optionally using the ``Bearer``
       scheme, and resolves it to a session id via ``get_sid_from_token``.
    3. It attaches to that session's conversation, stores it on
       ``request.state.conversation``, and detaches once the downstream handler
       finishes, even if the handler raises.

    Authentication failures short-circuit with a JSON 401. A missing
    conversation short-circuits with a JSON 404.
    """

    def __init__(self, app, target_router: 'APIRouter'):
        self.app = app
        self.target_router = target_router
        # Exact-match lookup. NOTE(review): a templated route path such as
        # '/x/{id}' never equals a concrete request path; confirm none are expected.
        self.target_paths = {route.path for route in target_router.routes}

    async def __call__(self, request: 'Request', call_next: Callable):
        if not self._should_attach(request):
            return await call_next(request)

        if UserVerifier().is_active():
            cookie_error = self._check_github_cookie(request)
            if cookie_error is not None:
                return cookie_error

        auth_header = request.headers.get('Authorization')
        if not auth_header:
            logger.warning('Missing Authorization header')
            return self._unauthorized('Missing Authorization header')

        auth_token = self._extract_bearer_token(auth_header)
        request.state.sid = get_sid_from_token(auth_token, config.jwt_secret)
        if request.state.sid == '':
            logger.warning('Invalid token')
            return self._unauthorized('Invalid token')

        request.state.conversation = await session_manager.attach_to_conversation(
            request.state.sid
        )
        if request.state.conversation is None:
            return JSONResponse(
                status_code=status.HTTP_404_NOT_FOUND,
                content={'error': 'Session not found'},
            )
        try:
            return await call_next(request)
        finally:
            await session_manager.detach_from_conversation(request.state.conversation)

    def _should_attach(self, request) -> bool:
        """True if ``request`` targets a watched path and is not an OPTIONS request."""
        return request.url.path in self.target_paths and request.method != 'OPTIONS'

    def _check_github_cookie(self, request):
        """Return a 401 response if the ``github_auth`` cookie is missing or invalid, else None."""
        signed_token = request.cookies.get('github_auth')
        if not signed_token:
            return self._unauthorized('Not authenticated')
        try:
            jwt.decode(signed_token, config.jwt_secret, algorithms=['HS256'])
        except Exception as e:
            # Deliberately broad: any failure to verify the cookie is treated as
            # an authentication failure rather than a server error.
            logger.warning(f'Invalid token: {e}')
            return self._unauthorized('Invalid token')
        return None

    @staticmethod
    def _extract_bearer_token(header_value: str) -> str:
        """Return the credentials from an ``Authorization`` header value.

        ``'Bearer <token>'`` (scheme matched case-insensitively) yields
        ``'<token>'``. A value without the Bearer scheme is returned stripped,
        so a raw token is still accepted.
        """
        parts = header_value.split(None, 1)
        if parts and parts[0].lower() == 'bearer':
            return parts[1].strip() if len(parts) > 1 else ''
        return header_value.strip()

    @staticmethod
    def _unauthorized(message: str):
        """Build the JSON 401 response used for every authentication failure."""
        return JSONResponse(
            status_code=status.HTTP_401_UNAUTHORIZED,
            content={'error': message},
        )
|