# middleware.py
import asyncio
import math
import time
from collections import defaultdict
from datetime import datetime, timedelta
from typing import Callable
from urllib.parse import urlparse

from fastapi import APIRouter, Request, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp

from openhands.core.logger import openhands_logger as logger
from openhands.server.auth import get_sid_from_token
from openhands.server.shared import config, session_manager
from openhands.server.types import SessionMiddlewareInterface
  15. class LocalhostCORSMiddleware(CORSMiddleware):
  16. """
  17. Custom CORS middleware that allows any request from localhost/127.0.0.1 domains,
  18. while using standard CORS rules for other origins.
  19. """
  20. def __init__(self, app: ASGIApp, **kwargs) -> None:
  21. super().__init__(app, **kwargs)
  22. def is_allowed_origin(self, origin: str) -> bool:
  23. if origin:
  24. parsed = urlparse(origin)
  25. hostname = parsed.hostname or ''
  26. # Allow any localhost/127.0.0.1 origin regardless of port
  27. if hostname in ['localhost', '127.0.0.1']:
  28. return True
  29. # For missing origin or other origins, use the parent class's logic
  30. return super().is_allowed_origin(origin)
  31. class NoCacheMiddleware(BaseHTTPMiddleware):
  32. """
  33. Middleware to disable caching for all routes by adding appropriate headers
  34. """
  35. async def dispatch(self, request, call_next):
  36. response = await call_next(request)
  37. if not request.url.path.startswith('/assets'):
  38. response.headers['Cache-Control'] = (
  39. 'no-cache, no-store, must-revalidate, max-age=0'
  40. )
  41. response.headers['Pragma'] = 'no-cache'
  42. response.headers['Expires'] = '0'
  43. return response
  44. class InMemoryRateLimiter:
  45. history: dict
  46. requests: int
  47. seconds: int
  48. sleep_seconds: int
  49. def __init__(self, requests: int = 2, seconds: int = 1, sleep_seconds: int = 1):
  50. self.requests = requests
  51. self.seconds = seconds
  52. self.sleep_seconds = sleep_seconds
  53. self.history = defaultdict(list)
  54. self.sleep_seconds = sleep_seconds
  55. def _clean_old_requests(self, key: str) -> None:
  56. now = datetime.now()
  57. cutoff = now - timedelta(seconds=self.seconds)
  58. self.history[key] = [ts for ts in self.history[key] if ts > cutoff]
  59. async def __call__(self, request: Request) -> bool:
  60. key = request.client.host
  61. now = datetime.now()
  62. self._clean_old_requests(key)
  63. self.history[key].append(now)
  64. if len(self.history[key]) > self.requests * 2:
  65. return False
  66. elif len(self.history[key]) > self.requests:
  67. if self.sleep_seconds > 0:
  68. await asyncio.sleep(self.sleep_seconds)
  69. return True
  70. else:
  71. return False
  72. return True
  73. class RateLimitMiddleware(BaseHTTPMiddleware):
  74. def __init__(self, app: ASGIApp, rate_limiter: InMemoryRateLimiter):
  75. super().__init__(app)
  76. self.rate_limiter = rate_limiter
  77. async def dispatch(self, request, call_next):
  78. ok = await self.rate_limiter(request)
  79. if not ok:
  80. return JSONResponse(
  81. status_code=429,
  82. content={'message': 'Too many requests'},
  83. headers={'Retry-After': '1'},
  84. )
  85. return await call_next(request)
  86. class AttachSessionMiddleware(SessionMiddlewareInterface):
  87. def __init__(self, app, target_router: APIRouter):
  88. self.app = app
  89. self.target_router = target_router
  90. self.target_paths = {route.path for route in target_router.routes}
  91. def _should_attach(self, request) -> bool:
  92. """
  93. Determine if the middleware should attach a session for the given request.
  94. """
  95. if request.method == 'OPTIONS':
  96. return False
  97. if request.url.path not in self.target_paths:
  98. return False
  99. return True
  100. async def _attach_session(self, request: Request) -> JSONResponse | None:
  101. """
  102. Attach the user's session based on the provided authentication token.
  103. """
  104. auth_token = request.headers.get('Authorization', '')
  105. if 'Bearer' in auth_token:
  106. auth_token = auth_token.split('Bearer')[1].strip()
  107. request.state.sid = get_sid_from_token(auth_token, config.jwt_secret)
  108. if not request.state.sid:
  109. logger.warning('Invalid token')
  110. return JSONResponse(
  111. status_code=status.HTTP_401_UNAUTHORIZED,
  112. content={'error': 'Invalid token'},
  113. )
  114. request.state.conversation = await session_manager.attach_to_conversation(
  115. request.state.sid
  116. )
  117. if not request.state.conversation:
  118. return JSONResponse(
  119. status_code=status.HTTP_404_NOT_FOUND,
  120. content={'error': 'Session not found'},
  121. )
  122. return None
  123. async def _detach_session(self, request: Request) -> None:
  124. """
  125. Detach the user's session.
  126. """
  127. await session_manager.detach_from_conversation(request.state.conversation)
  128. async def __call__(self, request: Request, call_next: Callable):
  129. if not self._should_attach(request):
  130. return await call_next(request)
  131. response = await self._attach_session(request)
  132. if response:
  133. return response
  134. try:
  135. # Continue processing the request
  136. response = await call_next(request)
  137. finally:
  138. # Ensure the session is detached
  139. await self._detach_session(request)
  140. return response