# middleware.py (3.2 KB)

import asyncio
import math
import time
from collections import defaultdict
from datetime import datetime, timedelta
from urllib.parse import urlparse

from fastapi import Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp


  10. class LocalhostCORSMiddleware(CORSMiddleware):
  11. """
  12. Custom CORS middleware that allows any request from localhost/127.0.0.1 domains,
  13. while using standard CORS rules for other origins.
  14. """
  15. def __init__(self, app: ASGIApp, **kwargs) -> None:
  16. super().__init__(app, **kwargs)
  17. def is_allowed_origin(self, origin: str) -> bool:
  18. if origin:
  19. parsed = urlparse(origin)
  20. hostname = parsed.hostname or ''
  21. # Allow any localhost/127.0.0.1 origin regardless of port
  22. if hostname in ['localhost', '127.0.0.1']:
  23. return True
  24. # For missing origin or other origins, use the parent class's logic
  25. return super().is_allowed_origin(origin)
  26. class NoCacheMiddleware(BaseHTTPMiddleware):
  27. """
  28. Middleware to disable caching for all routes by adding appropriate headers
  29. """
  30. async def dispatch(self, request, call_next):
  31. response = await call_next(request)
  32. if not request.url.path.startswith('/assets'):
  33. response.headers['Cache-Control'] = (
  34. 'no-cache, no-store, must-revalidate, max-age=0'
  35. )
  36. response.headers['Pragma'] = 'no-cache'
  37. response.headers['Expires'] = '0'
  38. return response
  39. class InMemoryRateLimiter:
  40. history: dict
  41. requests: int
  42. seconds: int
  43. sleep_seconds: int
  44. def __init__(self, requests: int = 2, seconds: int = 1, sleep_seconds: int = 1):
  45. self.requests = requests
  46. self.seconds = seconds
  47. self.sleep_seconds = sleep_seconds
  48. self.history = defaultdict(list)
  49. def _clean_old_requests(self, key: str) -> None:
  50. now = datetime.now()
  51. cutoff = now - timedelta(seconds=self.seconds)
  52. self.history[key] = [ts for ts in self.history[key] if ts > cutoff]
  53. async def __call__(self, request: Request) -> bool:
  54. key = request.client.host
  55. now = datetime.now()
  56. self._clean_old_requests(key)
  57. self.history[key].append(now)
  58. if len(self.history[key]) > self.requests * 2:
  59. return False
  60. elif len(self.history[key]) > self.requests:
  61. if self.sleep_seconds > 0:
  62. await asyncio.sleep(self.sleep_seconds)
  63. return True
  64. else:
  65. return False
  66. return True
  67. class RateLimitMiddleware(BaseHTTPMiddleware):
  68. def __init__(self, app: ASGIApp, rate_limiter: InMemoryRateLimiter):
  69. super().__init__(app)
  70. self.rate_limiter = rate_limiter
  71. async def dispatch(self, request, call_next):
  72. ok = await self.rate_limiter(request)
  73. if not ok:
  74. return JSONResponse(
  75. status_code=429,
  76. content={'message': 'Too many requests'},
  77. headers={'Retry-After': '1'},
  78. )
  79. return await call_next(request)