| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101 |
- import asyncio
- from collections import defaultdict
- from datetime import datetime, timedelta
- from urllib.parse import urlparse
- from fastapi import Request
- from fastapi.middleware.cors import CORSMiddleware
- from fastapi.responses import JSONResponse
- from starlette.middleware.base import BaseHTTPMiddleware
- from starlette.types import ASGIApp
class LocalhostCORSMiddleware(CORSMiddleware):
    """CORS middleware that always admits localhost / 127.0.0.1 origins.

    Any origin whose hostname is ``localhost`` or ``127.0.0.1`` is accepted
    regardless of scheme or port. Every other origin (including an empty
    one) is decided by the parent ``CORSMiddleware`` using the options
    passed to the constructor.
    """

    def __init__(self, app: ASGIApp, **kwargs) -> None:
        """Forward ``app`` and all keyword options unchanged to the parent."""
        super().__init__(app, **kwargs)

    def is_allowed_origin(self, origin: str) -> bool:
        """Return whether ``origin`` is local or permitted by the parent rules.

        Args:
            origin: The origin string to check; may be empty.

        Returns:
            True for localhost/127.0.0.1 origins, otherwise whatever the
            parent class's ``is_allowed_origin`` returns.
        """
        if origin:
            try:
                hostname = urlparse(origin).hostname or ''
            except ValueError:
                # urlparse rejects some malformed values (for example an
                # unbalanced IPv6 bracket). Such an origin is not a local
                # one, so hand it to the parent rules instead of letting
                # the exception escape.
                hostname = ''

            # Allow any localhost/127.0.0.1 origin regardless of port
            if hostname in ('localhost', '127.0.0.1'):
                return True

        # For missing origin or other origins, use the parent class's logic
        return super().is_allowed_origin(origin)
class NoCacheMiddleware(BaseHTTPMiddleware):
    """Middleware that adds no-cache headers to responses outside ``/assets``."""

    async def dispatch(self, request, call_next):
        """Run the downstream handler, then add no-cache headers if applicable.

        Responses for request paths starting with ``/assets`` are returned
        without any header changes.
        """
        response = await call_next(request)
        if request.url.path.startswith('/assets'):
            return response

        no_cache_headers = (
            ('Cache-Control', 'no-cache, no-store, must-revalidate, max-age=0'),
            ('Pragma', 'no-cache'),
            ('Expires', '0'),
        )
        for header_name, header_value in no_cache_headers:
            response.headers[header_name] = header_value
        return response
class InMemoryRateLimiter:
    """Per-client sliding-window rate limiter kept in process memory.

    Each call records a timestamp under the client's host and counts the
    timestamps that fall inside the last ``seconds`` seconds:

    * up to ``requests`` recorded calls: allowed immediately;
    * more than ``requests`` but at most ``2 * requests``: delayed by
      ``sleep_seconds`` and then allowed (refused if ``sleep_seconds`` is
      not positive);
    * more than ``2 * requests``: refused.

    NOTE: refused and delayed calls are recorded too, so they count against
    the window. Keys are never removed from ``history``, so it grows with
    the number of distinct client hosts seen.
    """

    history: dict
    requests: int
    seconds: int
    sleep_seconds: int

    # Bucket shared by requests whose peer address is not available.
    _UNKNOWN_CLIENT_KEY = 'unknown'

    def __init__(self, requests: int = 2, seconds: int = 1, sleep_seconds: int = 1):
        """Configure the limiter.

        Args:
            requests: Calls allowed per window before throttling starts.
            seconds: Length of the sliding window, in seconds.
            sleep_seconds: Delay applied to throttled calls; a value that is
                not positive turns throttling into refusal.
        """
        self.requests = requests
        self.seconds = seconds
        self.sleep_seconds = sleep_seconds
        self.history = defaultdict(list)

    def _clean_old_requests(self, key: str) -> None:
        """Drop timestamps for ``key`` that are older than the window."""
        cutoff = datetime.now() - timedelta(seconds=self.seconds)
        self.history[key] = [ts for ts in self.history[key] if ts > cutoff]

    async def __call__(self, request: Request) -> bool:
        """Record the request and decide whether it may proceed.

        Args:
            request: Incoming request; its ``client.host`` is the bucket key.

        Returns:
            True if the request is allowed (possibly after sleeping),
            False if it must be refused.
        """
        # ``request.client`` can be None when no peer address is known;
        # account for those requests in one shared bucket instead of
        # raising AttributeError.
        client = request.client
        key = client.host if client is not None else self._UNKNOWN_CLIENT_KEY

        now = datetime.now()
        self._clean_old_requests(key)
        self.history[key].append(now)

        count = len(self.history[key])
        if count <= self.requests:
            return True
        if count > self.requests * 2:
            return False

        # Between the soft and the hard limit: slow the caller down.
        if self.sleep_seconds > 0:
            await asyncio.sleep(self.sleep_seconds)
            return True
        return False
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Middleware that answers 429 when the supplied rate limiter refuses."""

    def __init__(self, app: ASGIApp, rate_limiter: InMemoryRateLimiter):
        """Wrap ``app`` and keep the limiter that is awaited on each request."""
        super().__init__(app)
        self.rate_limiter = rate_limiter

    async def dispatch(self, request, call_next):
        """Forward the request if the limiter allows it, else return 429."""
        allowed = await self.rate_limiter(request)
        if allowed:
            return await call_next(request)

        return JSONResponse(
            status_code=429,
            content={'message': 'Too many requests'},
            headers={'Retry-After': '1'},
        )
|