middleware.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177
  1. import asyncio
  2. from collections import defaultdict
  3. from datetime import datetime, timedelta
  4. from typing import Callable
  5. from urllib.parse import urlparse
  6. import jwt
  7. from fastapi import APIRouter, Request, status
  8. from fastapi.middleware.cors import CORSMiddleware
  9. from fastapi.responses import JSONResponse
  10. from starlette.middleware.base import BaseHTTPMiddleware
  11. from starlette.types import ASGIApp
  12. from openhands.core.logger import openhands_logger as logger
  13. from openhands.server.auth import get_sid_from_token
  14. from openhands.server.github_utils import UserVerifier
  15. from openhands.server.shared import config, session_manager
  16. class LocalhostCORSMiddleware(CORSMiddleware):
  17. """
  18. Custom CORS middleware that allows any request from localhost/127.0.0.1 domains,
  19. while using standard CORS rules for other origins.
  20. """
  21. def __init__(self, app: ASGIApp, **kwargs) -> None:
  22. super().__init__(app, **kwargs)
  23. def is_allowed_origin(self, origin: str) -> bool:
  24. if origin:
  25. parsed = urlparse(origin)
  26. hostname = parsed.hostname or ''
  27. # Allow any localhost/127.0.0.1 origin regardless of port
  28. if hostname in ['localhost', '127.0.0.1']:
  29. return True
  30. # For missing origin or other origins, use the parent class's logic
  31. return super().is_allowed_origin(origin)
  32. class NoCacheMiddleware(BaseHTTPMiddleware):
  33. """
  34. Middleware to disable caching for all routes by adding appropriate headers
  35. """
  36. async def dispatch(self, request, call_next):
  37. response = await call_next(request)
  38. if not request.url.path.startswith('/assets'):
  39. response.headers['Cache-Control'] = (
  40. 'no-cache, no-store, must-revalidate, max-age=0'
  41. )
  42. response.headers['Pragma'] = 'no-cache'
  43. response.headers['Expires'] = '0'
  44. return response
  45. class InMemoryRateLimiter:
  46. history: dict
  47. requests: int
  48. seconds: int
  49. sleep_seconds: int
  50. def __init__(self, requests: int = 2, seconds: int = 1, sleep_seconds: int = 1):
  51. self.requests = requests
  52. self.seconds = seconds
  53. self.sleep_seconds = sleep_seconds
  54. self.history = defaultdict(list)
  55. self.sleep_seconds = sleep_seconds
  56. def _clean_old_requests(self, key: str) -> None:
  57. now = datetime.now()
  58. cutoff = now - timedelta(seconds=self.seconds)
  59. self.history[key] = [ts for ts in self.history[key] if ts > cutoff]
  60. async def __call__(self, request: Request) -> bool:
  61. key = request.client.host
  62. now = datetime.now()
  63. self._clean_old_requests(key)
  64. self.history[key].append(now)
  65. if len(self.history[key]) > self.requests * 2:
  66. return False
  67. elif len(self.history[key]) > self.requests:
  68. if self.sleep_seconds > 0:
  69. await asyncio.sleep(self.sleep_seconds)
  70. return True
  71. else:
  72. return False
  73. return True
  74. class RateLimitMiddleware(BaseHTTPMiddleware):
  75. def __init__(self, app: ASGIApp, rate_limiter: InMemoryRateLimiter):
  76. super().__init__(app)
  77. self.rate_limiter = rate_limiter
  78. async def dispatch(self, request, call_next):
  79. ok = await self.rate_limiter(request)
  80. if not ok:
  81. return JSONResponse(
  82. status_code=429,
  83. content={'message': 'Too many requests'},
  84. headers={'Retry-After': '1'},
  85. )
  86. return await call_next(request)
  87. class AttachSessionMiddleware:
  88. def __init__(self, app, target_router: APIRouter):
  89. self.app = app
  90. self.target_router = target_router
  91. self.target_paths = {route.path for route in target_router.routes}
  92. async def __call__(self, request: Request, call_next: Callable):
  93. do_attach = False
  94. if request.url.path in self.target_paths:
  95. do_attach = True
  96. if request.method == 'OPTIONS':
  97. do_attach = False
  98. if not do_attach:
  99. return await call_next(request)
  100. user_verifier = UserVerifier()
  101. if user_verifier.is_active():
  102. signed_token = request.cookies.get('github_auth')
  103. if not signed_token:
  104. return JSONResponse(
  105. status_code=status.HTTP_401_UNAUTHORIZED,
  106. content={'error': 'Not authenticated'},
  107. )
  108. try:
  109. jwt.decode(signed_token, config.jwt_secret, algorithms=['HS256'])
  110. except Exception as e:
  111. logger.warning(f'Invalid token: {e}')
  112. return JSONResponse(
  113. status_code=status.HTTP_401_UNAUTHORIZED,
  114. content={'error': 'Invalid token'},
  115. )
  116. if not request.headers.get('Authorization'):
  117. logger.warning('Missing Authorization header')
  118. return JSONResponse(
  119. status_code=status.HTTP_401_UNAUTHORIZED,
  120. content={'error': 'Missing Authorization header'},
  121. )
  122. auth_token = request.headers.get('Authorization')
  123. if 'Bearer' in auth_token:
  124. auth_token = auth_token.split('Bearer')[1].strip()
  125. request.state.sid = get_sid_from_token(auth_token, config.jwt_secret)
  126. if request.state.sid == '':
  127. logger.warning('Invalid token')
  128. return JSONResponse(
  129. status_code=status.HTTP_401_UNAUTHORIZED,
  130. content={'error': 'Invalid token'},
  131. )
  132. request.state.conversation = await session_manager.attach_to_conversation(
  133. request.state.sid
  134. )
  135. if request.state.conversation is None:
  136. return JSONResponse(
  137. status_code=status.HTTP_404_NOT_FOUND,
  138. content={'error': 'Session not found'},
  139. )
  140. try:
  141. response = await call_next(request)
  142. finally:
  143. await session_manager.detach_from_conversation(request.state.conversation)
  144. return response