|
|
@@ -4,7 +4,6 @@ from datetime import datetime, timedelta
|
|
|
from typing import Callable
|
|
|
from urllib.parse import urlparse
|
|
|
|
|
|
-import jwt
|
|
|
from fastapi import APIRouter, Request, status
|
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
|
from fastapi.responses import JSONResponse
|
|
|
@@ -13,8 +12,8 @@ from starlette.types import ASGIApp
|
|
|
|
|
|
from openhands.core.logger import openhands_logger as logger
|
|
|
from openhands.server.auth import get_sid_from_token
|
|
|
-from openhands.server.github_utils import UserVerifier
|
|
|
from openhands.server.shared import config, session_manager
|
|
|
+from openhands.server.types import SessionMiddlewareInterface
|
|
|
|
|
|
|
|
|
class LocalhostCORSMiddleware(CORSMiddleware):
|
|
|
@@ -109,53 +108,32 @@ class RateLimitMiddleware(BaseHTTPMiddleware):
|
|
|
return await call_next(request)
|
|
|
|
|
|
|
|
|
-class AttachSessionMiddleware:
|
|
|
+class AttachSessionMiddleware(SessionMiddlewareInterface):
|
|
|
def __init__(self, app, target_router: APIRouter):
|
|
|
self.app = app
|
|
|
self.target_router = target_router
|
|
|
self.target_paths = {route.path for route in target_router.routes}
|
|
|
|
|
|
- async def __call__(self, request: Request, call_next: Callable):
|
|
|
- do_attach = False
|
|
|
- if request.url.path in self.target_paths:
|
|
|
- do_attach = True
|
|
|
-
|
|
|
+ def _should_attach(self, request) -> bool:
|
|
|
+ """
|
|
|
+ Determine if the middleware should attach a session for the given request.
|
|
|
+ """
|
|
|
if request.method == 'OPTIONS':
|
|
|
- do_attach = False
|
|
|
-
|
|
|
- if not do_attach:
|
|
|
- return await call_next(request)
|
|
|
-
|
|
|
- user_verifier = UserVerifier()
|
|
|
- if user_verifier.is_active():
|
|
|
- signed_token = request.cookies.get('github_auth')
|
|
|
- if not signed_token:
|
|
|
- return JSONResponse(
|
|
|
- status_code=status.HTTP_401_UNAUTHORIZED,
|
|
|
- content={'error': 'Not authenticated'},
|
|
|
- )
|
|
|
- try:
|
|
|
- jwt.decode(signed_token, config.jwt_secret, algorithms=['HS256'])
|
|
|
- except Exception as e:
|
|
|
- logger.warning(f'Invalid token: {e}')
|
|
|
- return JSONResponse(
|
|
|
- status_code=status.HTTP_401_UNAUTHORIZED,
|
|
|
- content={'error': 'Invalid token'},
|
|
|
- )
|
|
|
-
|
|
|
- if not request.headers.get('Authorization'):
|
|
|
- logger.warning('Missing Authorization header')
|
|
|
- return JSONResponse(
|
|
|
- status_code=status.HTTP_401_UNAUTHORIZED,
|
|
|
- content={'error': 'Missing Authorization header'},
|
|
|
- )
|
|
|
+ return False
|
|
|
+ if request.url.path not in self.target_paths:
|
|
|
+ return False
|
|
|
+ return True
|
|
|
|
|
|
- auth_token = request.headers.get('Authorization')
|
|
|
+ async def _attach_session(self, request: Request) -> JSONResponse | None:
|
|
|
+ """
|
|
|
+ Attach the user's session based on the provided authentication token.
|
|
|
+ """
|
|
|
+ auth_token = request.headers.get('Authorization', '')
|
|
|
if 'Bearer' in auth_token:
|
|
|
auth_token = auth_token.split('Bearer')[1].strip()
|
|
|
|
|
|
request.state.sid = get_sid_from_token(auth_token, config.jwt_secret)
|
|
|
- if request.state.sid == '':
|
|
|
+ if not request.state.sid:
|
|
|
logger.warning('Invalid token')
|
|
|
return JSONResponse(
|
|
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
|
@@ -165,13 +143,32 @@ class AttachSessionMiddleware:
|
|
|
request.state.conversation = await session_manager.attach_to_conversation(
|
|
|
request.state.sid
|
|
|
)
|
|
|
- if request.state.conversation is None:
|
|
|
+ if not request.state.conversation:
|
|
|
return JSONResponse(
|
|
|
status_code=status.HTTP_404_NOT_FOUND,
|
|
|
content={'error': 'Session not found'},
|
|
|
)
|
|
|
+ return None
|
|
|
+
|
|
|
+ async def _detach_session(self, request: Request) -> None:
|
|
|
+ """
|
|
|
+ Detach the user's session.
|
|
|
+ """
|
|
|
+ await session_manager.detach_from_conversation(request.state.conversation)
|
|
|
+
|
|
|
+ async def __call__(self, request: Request, call_next: Callable):
|
|
|
+ if not self._should_attach(request):
|
|
|
+ return await call_next(request)
|
|
|
+
|
|
|
+ response = await self._attach_session(request)
|
|
|
+ if response:
|
|
|
+ return response
|
|
|
+
|
|
|
try:
|
|
|
+ # Continue processing the request
|
|
|
response = await call_next(request)
|
|
|
finally:
|
|
|
- await session_manager.detach_from_conversation(request.state.conversation)
|
|
|
+ # Ensure the session is detached
|
|
|
+ await self._detach_session(request)
|
|
|
+
|
|
|
return response
|