|
|
@@ -1,6 +1,11 @@
|
|
|
+import asyncio
|
|
|
+from collections import defaultdict
|
|
|
+from datetime import datetime, timedelta
|
|
|
from urllib.parse import urlparse
|
|
|
|
|
|
+from fastapi import Request
|
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
|
+from fastapi.responses import JSONResponse
|
|
|
from starlette.middleware.base import BaseHTTPMiddleware
|
|
|
from starlette.types import ASGIApp
|
|
|
|
|
|
@@ -41,3 +46,55 @@ class NoCacheMiddleware(BaseHTTPMiddleware):
|
|
|
response.headers['Pragma'] = 'no-cache'
|
|
|
response.headers['Expires'] = '0'
|
|
|
return response
|
|
|
+
|
|
|
+
|
|
|
+class InMemoryRateLimiter:
|
|
|
+ history: dict
|
|
|
+ requests: int
|
|
|
+ seconds: int
|
|
|
+ sleep_seconds: int
|
|
|
+
|
|
|
+ def __init__(self, requests: int = 2, seconds: int = 1, sleep_seconds: int = 1):
|
|
|
+ self.requests = requests
|
|
|
+ self.seconds = seconds
|
|
|
+ self.history = defaultdict(list)
|
|
|
+
|
|
|
+ def _clean_old_requests(self, key: str) -> None:
|
|
|
+ now = datetime.now()
|
|
|
+ cutoff = now - timedelta(seconds=self.seconds)
|
|
|
+ self.history[key] = [ts for ts in self.history[key] if ts > cutoff]
|
|
|
+
|
|
|
+ async def __call__(self, request: Request) -> bool:
|
|
|
+ key = request.client.host
|
|
|
+ now = datetime.now()
|
|
|
+
|
|
|
+ self._clean_old_requests(key)
|
|
|
+
|
|
|
+ self.history[key].append(now)
|
|
|
+
|
|
|
+ if len(self.history[key]) > self.requests * 2:
|
|
|
+ return False
|
|
|
+ elif len(self.history[key]) > self.requests:
|
|
|
+ if self.sleep_seconds > 0:
|
|
|
+ await asyncio.sleep(self.sleep_seconds)
|
|
|
+ return True
|
|
|
+ else:
|
|
|
+ return False
|
|
|
+
|
|
|
+ return True
|
|
|
+
|
|
|
+
|
|
|
+class RateLimitMiddleware(BaseHTTPMiddleware):
|
|
|
+ def __init__(self, app: ASGIApp, rate_limiter: InMemoryRateLimiter):
|
|
|
+ super().__init__(app)
|
|
|
+ self.rate_limiter = rate_limiter
|
|
|
+
|
|
|
+ async def dispatch(self, request, call_next):
|
|
|
+ ok = await self.rate_limiter(request)
|
|
|
+ if not ok:
|
|
|
+ return JSONResponse(
|
|
|
+ status_code=429,
|
|
|
+ content={'message': 'Too many requests'},
|
|
|
+ headers={'Retry-After': '1'},
|
|
|
+ )
|
|
|
+ return await call_next(request)
|