Răsfoiți Sursa

Simple initial rate limiting implementation (#4976)

Robert Brennan 1 an în urmă
părinte
comite
3c61a9521b
2 a modificat fișierele cu 67 adăugiri și 1 ștergeri
  1. 10 1
      openhands/server/listen.py
  2. 57 0
      openhands/server/middleware.py

+ 10 - 1
openhands/server/listen.py

@@ -64,7 +64,12 @@ from openhands.events.stream import AsyncEventStreamWrapper
 from openhands.llm import bedrock
 from openhands.runtime.base import Runtime
 from openhands.server.auth.auth import get_sid_from_token, sign_token
-from openhands.server.middleware import LocalhostCORSMiddleware, NoCacheMiddleware
+from openhands.server.middleware import (
+    InMemoryRateLimiter,
+    LocalhostCORSMiddleware,
+    NoCacheMiddleware,
+    RateLimitMiddleware,
+)
 from openhands.server.session import SessionManager
 
 load_dotenv()
@@ -84,6 +89,10 @@ app.add_middleware(
 
 
 app.add_middleware(NoCacheMiddleware)
+# Register rate limiting; the limiter is configured with requests=2 per seconds=1 window.
+app.add_middleware(
+    RateLimitMiddleware, rate_limiter=InMemoryRateLimiter(requests=2, seconds=1)
+)
+
 
 security_scheme = HTTPBearer()
 

+ 57 - 0
openhands/server/middleware.py

@@ -1,6 +1,11 @@
+import asyncio
+from collections import defaultdict
+from datetime import datetime, timedelta
 from urllib.parse import urlparse
 
+from fastapi import Request
 from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import JSONResponse
 from starlette.middleware.base import BaseHTTPMiddleware
 from starlette.types import ASGIApp
 
@@ -41,3 +46,55 @@ class NoCacheMiddleware(BaseHTTPMiddleware):
             response.headers['Pragma'] = 'no-cache'
             response.headers['Expires'] = '0'
         return response
+
+
+class InMemoryRateLimiter:
+    """Sliding-window, per-client request limiter kept in process memory.
+
+    Requests are grouped by ``request.client.host``. Within any ``seconds``
+    window a client may make up to ``requests`` calls without penalty. Calls
+    beyond that are delayed by ``sleep_seconds`` and then allowed; once more
+    than ``2 * requests`` calls fall inside the window they are rejected.
+    Every call, including a rejected one, is recorded in the window.
+    """
+
+    history: dict
+    requests: int
+    seconds: int
+    sleep_seconds: int
+
+    def __init__(self, requests: int = 2, seconds: int = 1, sleep_seconds: int = 1):
+        self.requests = requests
+        self.seconds = seconds
+        # This attribute was previously never stored, so the throttling branch
+        # of __call__ raised AttributeError when it read self.sleep_seconds.
+        self.sleep_seconds = sleep_seconds
+        self.history = defaultdict(list)
+
+    def _clean_old_requests(self, key: str) -> None:
+        """Drop timestamps for ``key`` that are older than the current window."""
+        now = datetime.now()
+        cutoff = now - timedelta(seconds=self.seconds)
+        self.history[key] = [ts for ts in self.history[key] if ts > cutoff]
+
+    async def __call__(self, request: Request) -> bool:
+        """Record ``request`` and return True to allow it, False to reject it."""
+        client = request.client
+        if client is None:
+            # No peer address is available, so there is nothing to key the
+            # limit on. Let the request through rather than failing with
+            # AttributeError on ``None.host``.
+            return True
+
+        key = client.host
+        now = datetime.now()
+
+        self._clean_old_requests(key)
+
+        self.history[key].append(now)
+
+        recent = len(self.history[key])
+        if recent > self.requests * 2:
+            # Hard limit: well over budget, reject outright.
+            return False
+        if recent > self.requests:
+            # Soft limit: slow the client down instead of rejecting, unless
+            # sleeping has been disabled.
+            if self.sleep_seconds > 0:
+                await asyncio.sleep(self.sleep_seconds)
+                return True
+            return False
+
+        return True
+
+
+class RateLimitMiddleware(BaseHTTPMiddleware):
+    """Middleware that answers 429 when the supplied rate limiter refuses a request."""
+
+    def __init__(self, app: ASGIApp, rate_limiter: InMemoryRateLimiter):
+        super().__init__(app)
+        self.rate_limiter = rate_limiter
+
+    async def dispatch(self, request, call_next):
+        """Forward the request if the limiter allows it; otherwise return a 429 JSON response."""
+        allowed = await self.rate_limiter(request)
+        if allowed:
+            return await call_next(request)
+
+        return JSONResponse(
+            status_code=429,
+            content={'message': 'Too many requests'},
+            headers={'Retry-After': '1'},
+        )