|
|
@@ -13,6 +13,11 @@ from pathspec.patterns import GitWildMatchPattern
|
|
|
|
|
|
from openhands.security.options import SecurityAnalyzers
|
|
|
from openhands.server.data_models.feedback import FeedbackDataModel, store_feedback
|
|
|
+from openhands.server.github import (
|
|
|
+ GITHUB_CLIENT_ID,
|
|
|
+ GITHUB_CLIENT_SECRET,
|
|
|
+ authenticate_github_user,
|
|
|
+)
|
|
|
from openhands.storage import get_file_store
|
|
|
from openhands.utils.async_utils import call_sync_from_async
|
|
|
|
|
|
@@ -64,24 +69,6 @@ config = load_app_config()
|
|
|
file_store = get_file_store(config.file_store, config.file_store_path)
|
|
|
session_manager = SessionManager(config, file_store)
|
|
|
|
|
|
-GITHUB_CLIENT_ID = os.getenv('GITHUB_CLIENT_ID', '').strip()
|
|
|
-GITHUB_CLIENT_SECRET = os.getenv('GITHUB_CLIENT_SECRET', '').strip()
|
|
|
-
|
|
|
-# New global variable to store the user list
|
|
|
-GITHUB_USER_LIST = None
|
|
|
-
|
|
|
-
|
|
|
-# New function to load the user list
|
|
|
-def load_github_user_list():
|
|
|
- global GITHUB_USER_LIST
|
|
|
- waitlist = os.getenv('GITHUB_USER_LIST_FILE')
|
|
|
- if waitlist:
|
|
|
- with open(waitlist, 'r') as f:
|
|
|
- GITHUB_USER_LIST = [line.strip() for line in f if line.strip()]
|
|
|
-
|
|
|
-
|
|
|
-load_github_user_list()
|
|
|
-
|
|
|
|
|
|
@asynccontextmanager
|
|
|
async def lifespan(app: FastAPI):
|
|
|
@@ -216,7 +203,13 @@ async def attach_session(request: Request, call_next):
|
|
|
response = await call_next(request)
|
|
|
return response
|
|
|
|
|
|
- # For all other methods, validate the Authorization header
|
|
|
+ github_token = request.headers.get('X-GitHub-Token')
|
|
|
+ if not await authenticate_github_user(github_token):
|
|
|
+ return JSONResponse(
|
|
|
+ status_code=status.HTTP_401_UNAUTHORIZED,
|
|
|
+ content={'error': 'Not authenticated'},
|
|
|
+ )
|
|
|
+
|
|
|
if not request.headers.get('Authorization'):
|
|
|
logger.warning('Missing Authorization header')
|
|
|
return JSONResponse(
|
|
|
@@ -308,11 +301,28 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
|
{"action": "finish", "args": {}}
|
|
|
```
|
|
|
"""
|
|
|
- await asyncio.wait_for(websocket.accept(), 10)
|
|
|
+ # Get protocols from Sec-WebSocket-Protocol header
|
|
|
+ protocols = websocket.headers.get('sec-websocket-protocol', '').split(', ')
|
|
|
+
|
|
|
+ # The first protocol should be our real protocol (e.g. 'openhands')
|
|
|
+ # The second protocol should contain our auth token
|
|
|
+ if len(protocols) < 3:
|
|
|
+ logger.error('Expected 3 websocket protocols, got %d', len(protocols))
|
|
|
+ await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
|
|
|
+ return
|
|
|
+
|
|
|
+ real_protocol = protocols[0]
|
|
|
+ jwt_token = protocols[1] if protocols[1] != 'NO_JWT' else ''
|
|
|
+ github_token = protocols[2] if protocols[2] != 'NO_GITHUB' else ''
|
|
|
|
|
|
- if websocket.query_params.get('token'):
|
|
|
- token = websocket.query_params.get('token')
|
|
|
- sid = get_sid_from_token(token, config.jwt_secret)
|
|
|
+ if not await authenticate_github_user(github_token):
|
|
|
+ await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
|
|
|
+ return
|
|
|
+
|
|
|
+ await asyncio.wait_for(websocket.accept(subprotocol=real_protocol), 10)
|
|
|
+
|
|
|
+ if jwt_token:
|
|
|
+ sid = get_sid_from_token(jwt_token, config.jwt_secret)
|
|
|
|
|
|
if sid == '':
|
|
|
await websocket.send_json({'error': 'Invalid token', 'error_code': 401})
|
|
|
@@ -320,11 +330,11 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
|
return
|
|
|
else:
|
|
|
sid = str(uuid.uuid4())
|
|
|
- token = sign_token({'sid': sid}, config.jwt_secret)
|
|
|
+ jwt_token = sign_token({'sid': sid}, config.jwt_secret)
|
|
|
|
|
|
logger.info(f'New session: {sid}')
|
|
|
session = session_manager.add_or_restart_session(sid, websocket)
|
|
|
- await websocket.send_json({'token': token, 'status': 'ok'})
|
|
|
+ await websocket.send_json({'token': jwt_token, 'status': 'ok'})
|
|
|
|
|
|
latest_event_id = -1
|
|
|
if websocket.query_params.get('latest_event_id'):
|
|
|
@@ -840,26 +850,21 @@ def github_callback(auth_code: AuthCode):
|
|
|
)
|
|
|
|
|
|
|
|
|
-class User(BaseModel):
|
|
|
- login: str # GitHub login handle
|
|
|
-
|
|
|
-
|
|
|
@app.post('/api/authenticate')
|
|
|
-def authenticate(user: User | None = None):
|
|
|
- global GITHUB_USER_LIST
|
|
|
-
|
|
|
- # Only check if waitlist is provided
|
|
|
- if GITHUB_USER_LIST:
|
|
|
- if user is None or user.login not in GITHUB_USER_LIST:
|
|
|
- return JSONResponse(
|
|
|
- status_code=status.HTTP_403_FORBIDDEN,
|
|
|
- content={'error': 'User not on waitlist'},
|
|
|
- )
|
|
|
+async def authenticate(request: Request):
|
|
|
+ token = request.headers.get('X-GitHub-Token')
|
|
|
+ if not await authenticate_github_user(token):
|
|
|
+ return JSONResponse(
|
|
|
+ status_code=status.HTTP_401_UNAUTHORIZED,
|
|
|
+ content={'error': 'Not authorized via GitHub waitlist'},
|
|
|
+ )
|
|
|
|
|
|
- return JSONResponse(
|
|
|
+ response = JSONResponse(
|
|
|
status_code=status.HTTP_200_OK, content={'message': 'User authenticated'}
|
|
|
)
|
|
|
|
|
|
+ return response
|
|
|
+
|
|
|
|
|
|
class SPAStaticFiles(StaticFiles):
|
|
|
async def get_response(self, path: str, scope):
|