Просмотр исходного кода

Add rate limiting to server endpoints (#4867)

Co-authored-by: openhands <openhands@all-hands.dev>
Robert Brennan 1 год назад
Родитель
Сommit
79492b6551
3 измененных файлов с 95 добавлено и 1 удалено
  1. 40 0
      openhands/server/listen.py
  2. 52 1
      poetry.lock
  3. 3 0
      pyproject.toml

+ 40 - 0
openhands/server/listen.py

@@ -11,6 +11,9 @@ import jwt
 import requests
 from pathspec import PathSpec
 from pathspec.patterns import GitWildMatchPattern
+from slowapi import Limiter, _rate_limit_exceeded_handler
+from slowapi.errors import RateLimitExceeded
+from slowapi.util import get_remote_address
 
 from openhands.security.options import SecurityAnalyzers
 from openhands.server.data_models.feedback import FeedbackDataModel, store_feedback
@@ -94,6 +97,36 @@ app.add_middleware(NoCacheMiddleware)
 
 security_scheme = HTTPBearer()
 
+# Initialize rate limiter
+limiter = Limiter(
+    key_func=get_remote_address,
+    default_limits=['5 per second'],
+    strategy='moving-window',  # Use a sliding window for more accurate rate limiting
+)
+app.state.limiter = limiter
+app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
+
+
+# Apply stricter limits to auth endpoints
+def get_path_limits(request: Request):
+    path = request.url.path
+    if path == '/ws' or path in ['/api/github/callback', '/api/authenticate']:
+        return ['1 per second']
+    return ['5 per second']
+
+
+@app.middleware('http')
+async def rate_limit_middleware(request: Request, call_next):
+    limits = get_path_limits(request)
+    try:
+        await limiter.check_request_limit(request, limits=limits)
+    except RateLimitExceeded:
+        return JSONResponse(
+            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
+            content={'error': 'Too many requests'},
+        )
+    return await call_next(request)
+
 
 def load_file_upload_config() -> tuple[int, bool, list[str]]:
     """Load file upload configuration from the config object.
@@ -260,6 +293,13 @@ async def attach_session(request: Request, call_next):
 
 @app.websocket('/ws')
 async def websocket_endpoint(websocket: WebSocket):
+    try:
+        # Create a mock request object for rate limiting
+        mock_request = Request(scope={'type': 'http', 'client': websocket.client})
+        await limiter.check_request_limit(mock_request, limits=['1 per second'])
+    except RateLimitExceeded:
+        await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
+        return
     """WebSocket endpoint for receiving events from the client (i.e., the browser).
     Once connected, the client can send various actions:
     - Initialize the agent:

+ 52 - 1
poetry.lock

@@ -3932,6 +3932,35 @@ tiktoken = "*"
 transformers = "*"
 types-tqdm = "*"
 
+[[package]]
+name = "limits"
+version = "3.13.0"
+description = "Rate limiting utilities"
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "limits-3.13.0-py3-none-any.whl", hash = "sha256:9767f7233da4255e9904b79908a728e8ec0984c0b086058b4cbbd309aea553f6"},
+    {file = "limits-3.13.0.tar.gz", hash = "sha256:6571b0c567bfa175a35fed9f8a954c0c92f1c3200804282f1b8f1de4ad98a953"},
+]
+
+[package.dependencies]
+deprecated = ">=1.2"
+importlib-resources = ">=1.3"
+packaging = ">=21,<25"
+typing-extensions = "*"
+
+[package.extras]
+all = ["aetcd", "coredis (>=3.4.0,<5)", "emcache (>=0.6.1)", "emcache (>=1)", "etcd3", "motor (>=3,<4)", "pymemcache (>3,<5.0.0)", "pymongo (>4.1,<5)", "redis (>3,!=4.5.2,!=4.5.3,<6.0.0)", "redis (>=4.2.0,!=4.5.2,!=4.5.3)"]
+async-etcd = ["aetcd"]
+async-memcached = ["emcache (>=0.6.1)", "emcache (>=1)"]
+async-mongodb = ["motor (>=3,<4)"]
+async-redis = ["coredis (>=3.4.0,<5)"]
+etcd = ["etcd3"]
+memcached = ["pymemcache (>3,<5.0.0)"]
+mongodb = ["pymongo (>4.1,<5)"]
+redis = ["redis (>3,!=4.5.2,!=4.5.3,<6.0.0)"]
+rediscluster = ["redis (>=4.2.0,!=4.5.2,!=4.5.3)"]
+
 [[package]]
 name = "litellm"
 version = "1.52.3"
@@ -7995,6 +8024,11 @@ files = [
     {file = "scikit_learn-1.5.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f60021ec1574e56632be2a36b946f8143bf4e5e6af4a06d85281adc22938e0dd"},
     {file = "scikit_learn-1.5.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:394397841449853c2290a32050382edaec3da89e35b3e03d6cc966aebc6a8ae6"},
     {file = "scikit_learn-1.5.2-cp312-cp312-win_amd64.whl", hash = "sha256:57cc1786cfd6bd118220a92ede80270132aa353647684efa385a74244a41e3b1"},
+    {file = "scikit_learn-1.5.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:e9a702e2de732bbb20d3bad29ebd77fc05a6b427dc49964300340e4c9328b3f5"},
+    {file = "scikit_learn-1.5.2-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:b0768ad641981f5d3a198430a1d31c3e044ed2e8a6f22166b4d546a5116d7908"},
+    {file = "scikit_learn-1.5.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:178ddd0a5cb0044464fc1bfc4cca5b1833bfc7bb022d70b05db8530da4bb3dd3"},
+    {file = "scikit_learn-1.5.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f7284ade780084d94505632241bf78c44ab3b6f1e8ccab3d2af58e0e950f9c12"},
+    {file = "scikit_learn-1.5.2-cp313-cp313-win_amd64.whl", hash = "sha256:b7b0f9a0b1040830d38c39b91b3a44e1b643f4b36e36567b80b7c6bd2202a27f"},
     {file = "scikit_learn-1.5.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:757c7d514ddb00ae249832fe87100d9c73c6ea91423802872d9e74970a0e40b9"},
     {file = "scikit_learn-1.5.2-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:52788f48b5d8bca5c0736c175fa6bdaab2ef00a8f536cda698db61bd89c551c1"},
     {file = "scikit_learn-1.5.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:643964678f4b5fbdc95cbf8aec638acc7aa70f5f79ee2cdad1eec3df4ba6ead8"},
@@ -8250,6 +8284,23 @@ files = [
     {file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"},
 ]
 
+[[package]]
+name = "slowapi"
+version = "0.1.9"
+description = "A rate limiting extension for Starlette and Fastapi"
+optional = false
+python-versions = ">=3.7,<4.0"
+files = [
+    {file = "slowapi-0.1.9-py3-none-any.whl", hash = "sha256:cfad116cfb84ad9d763ee155c1e5c5cbf00b0d47399a769b227865f5df576e36"},
+    {file = "slowapi-0.1.9.tar.gz", hash = "sha256:639192d0f1ca01b1c6d95bf6c71d794c3a9ee189855337b4821f7f457dddad77"},
+]
+
+[package.dependencies]
+limits = ">=2.3"
+
+[package.extras]
+redis = ["redis (>=3.4.1,<4.0.0)"]
+
 [[package]]
 name = "smmap"
 version = "5.0.1"
@@ -10128,4 +10179,4 @@ testing = ["coverage[toml]", "zope.event", "zope.testing"]
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.12"
-content-hash = "8a34ef6158ca2a9fe3615fc362db3fd71bc43eabb57ffc2e2e14dfb658cf52c3"
+content-hash = "70158478f99a3e3e3356b9cfcd3bc5a156257ce7adebf6d2f5ffe8c9eff5d4a7"

+ 3 - 0
pyproject.toml

@@ -62,6 +62,7 @@ opentelemetry-api = "1.25.0"
 opentelemetry-exporter-otlp-proto-grpc = "1.25.0"
 modal = "^0.64.145"
 runloop-api-client = "0.7.0"
+slowapi = "^0.1.9"
 
 [tool.poetry.group.llama-index.dependencies]
 llama-index = "*"
@@ -93,6 +94,7 @@ reportlab = "*"
 [tool.coverage.run]
 concurrency = ["gevent"]
 
+
 [tool.poetry.group.runtime.dependencies]
 jupyterlab = "*"
 notebook = "*"
@@ -123,6 +125,7 @@ ignore = ["D1"]
 [tool.ruff.lint.pydocstyle]
 convention = "google"
 
+
 [tool.poetry.group.evaluation.dependencies]
 streamlit = "*"
 whatthepatch = "*"