middleware.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168
import asyncio
import time
from collections import defaultdict
from datetime import datetime, timedelta
from typing import Callable
from urllib.parse import urlparse

from fastapi import Request, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp

from openhands.server.shared import session_manager
from openhands.server.types import SessionMiddlewareInterface
  13. class LocalhostCORSMiddleware(CORSMiddleware):
  14. """
  15. Custom CORS middleware that allows any request from localhost/127.0.0.1 domains,
  16. while using standard CORS rules for other origins.
  17. """
  18. def __init__(self, app: ASGIApp, **kwargs) -> None:
  19. super().__init__(app, **kwargs)
  20. def is_allowed_origin(self, origin: str) -> bool:
  21. if origin:
  22. parsed = urlparse(origin)
  23. hostname = parsed.hostname or ''
  24. # Allow any localhost/127.0.0.1 origin regardless of port
  25. if hostname in ['localhost', '127.0.0.1']:
  26. return True
  27. # For missing origin or other origins, use the parent class's logic
  28. return super().is_allowed_origin(origin)
  29. class NoCacheMiddleware(BaseHTTPMiddleware):
  30. """
  31. Middleware to disable caching for all routes by adding appropriate headers
  32. """
  33. async def dispatch(self, request, call_next):
  34. response = await call_next(request)
  35. if not request.url.path.startswith('/assets'):
  36. response.headers['Cache-Control'] = (
  37. 'no-cache, no-store, must-revalidate, max-age=0'
  38. )
  39. response.headers['Pragma'] = 'no-cache'
  40. response.headers['Expires'] = '0'
  41. return response
  42. class InMemoryRateLimiter:
  43. history: dict
  44. requests: int
  45. seconds: int
  46. sleep_seconds: int
  47. def __init__(self, requests: int = 2, seconds: int = 1, sleep_seconds: int = 1):
  48. self.requests = requests
  49. self.seconds = seconds
  50. self.sleep_seconds = sleep_seconds
  51. self.history = defaultdict(list)
  52. self.sleep_seconds = sleep_seconds
  53. def _clean_old_requests(self, key: str) -> None:
  54. now = datetime.now()
  55. cutoff = now - timedelta(seconds=self.seconds)
  56. self.history[key] = [ts for ts in self.history[key] if ts > cutoff]
  57. async def __call__(self, request: Request) -> bool:
  58. key = request.client.host
  59. now = datetime.now()
  60. self._clean_old_requests(key)
  61. self.history[key].append(now)
  62. if len(self.history[key]) > self.requests * 2:
  63. return False
  64. elif len(self.history[key]) > self.requests:
  65. if self.sleep_seconds > 0:
  66. await asyncio.sleep(self.sleep_seconds)
  67. return True
  68. else:
  69. return False
  70. return True
  71. class RateLimitMiddleware(BaseHTTPMiddleware):
  72. def __init__(self, app: ASGIApp, rate_limiter: InMemoryRateLimiter):
  73. super().__init__(app)
  74. self.rate_limiter = rate_limiter
  75. async def dispatch(self, request, call_next):
  76. ok = await self.rate_limiter(request)
  77. if not ok:
  78. return JSONResponse(
  79. status_code=429,
  80. content={'message': 'Too many requests'},
  81. headers={'Retry-After': '1'},
  82. )
  83. return await call_next(request)
  84. class AttachConversationMiddleware(SessionMiddlewareInterface):
  85. def __init__(self, app):
  86. self.app = app
  87. def _should_attach(self, request) -> bool:
  88. """
  89. Determine if the middleware should attach a session for the given request.
  90. """
  91. if request.method == 'OPTIONS':
  92. return False
  93. conversation_id = ''
  94. if request.url.path.startswith('/api/conversation'):
  95. # FIXME: we should be able to use path_params
  96. path_parts = request.url.path.split('/')
  97. if len(path_parts) > 3:
  98. conversation_id = request.url.path.split('/')[3]
  99. if not conversation_id:
  100. return False
  101. request.state.sid = conversation_id
  102. return True
  103. async def _attach_conversation(self, request: Request) -> JSONResponse | None:
  104. """
  105. Attach the user's session based on the provided authentication token.
  106. """
  107. request.state.conversation = await session_manager.attach_to_conversation(
  108. request.state.sid
  109. )
  110. if not request.state.conversation:
  111. return JSONResponse(
  112. status_code=status.HTTP_404_NOT_FOUND,
  113. content={'error': 'Session not found'},
  114. )
  115. return None
  116. async def _detach_session(self, request: Request) -> None:
  117. """
  118. Detach the user's session.
  119. """
  120. await session_manager.detach_from_conversation(request.state.conversation)
  121. async def __call__(self, request: Request, call_next: Callable):
  122. if not self._should_attach(request):
  123. return await call_next(request)
  124. response = await self._attach_conversation(request)
  125. if response:
  126. return response
  127. try:
  128. # Continue processing the request
  129. response = await call_next(request)
  130. finally:
  131. # Ensure the session is detached
  132. await self._detach_session(request)
  133. return response