Procházet zdrojové kódy

Add cookie-based authentication to all routes (#4642)

Co-authored-by: openhands <openhands@all-hands.dev>
Co-authored-by: sp.wack <83104063+amanape@users.noreply.github.com>
Robert Brennan před 1 rokem
rodič
revize
a812e2b5f1

+ 6 - 8
frontend/src/api/open-hands.ts

@@ -213,20 +213,18 @@ class OpenHands {
   }
 
   /**
-   * Check if the user is authenticated
-   * @param login The user's GitHub login handle
-   * @returns Whether the user is authenticated
+   * Authenticate with GitHub token
+   * @param token The GitHub access token
+   * @returns Response with authentication status and user info if successful
    */
-  static async isAuthenticated(login: string): Promise<boolean> {
-    const response = await fetch(`${OpenHands.BASE_URL}/api/authenticate`, {
+  static async authenticate(token: string): Promise<Response> {
+    return fetch(`${OpenHands.BASE_URL}/api/authenticate`, {
       method: "POST",
-      body: JSON.stringify({ login }),
       headers: {
         "Content-Type": "application/json",
+        "X-GitHub-Token": token,
       },
     });
-
-    return response.status === 200;
   }
 }
 

+ 5 - 0
frontend/src/api/open-hands.types.ts

@@ -27,6 +27,11 @@ export interface GitHubAccessTokenResponse {
   access_token: string;
 }
 
+export interface AuthenticationResponse {
+  message: string;
+  login?: string; // Only present when allow list is enabled
+}
+
 export interface Feedback {
   version: string;
   email: string;

+ 8 - 3
frontend/src/context/socket.tsx

@@ -49,9 +49,14 @@ function SocketProvider({ children }: SocketProviderProps) {
     const fallback = getValidFallbackHost();
     const baseUrl = import.meta.env.VITE_BACKEND_BASE_URL || fallback;
     const protocol = window.location.protocol === "https:" ? "wss:" : "ws:";
-    const ws = new WebSocket(
-      `${protocol}//${baseUrl}/ws${options?.token ? `?token=${options.token}` : ""}`,
-    );
+    const sessionToken = options?.token || "NO_JWT"; // not allowed to be empty or duplicated
+    const ghToken = localStorage.getItem("ghToken") || "NO_GITHUB";
+
+    const ws = new WebSocket(`${protocol}//${baseUrl}/ws`, [
+      "openhands",
+      sessionToken,
+      ghToken,
+    ]);
 
     ws.addEventListener("open", (event) => {
       setIsConnected(true);

+ 9 - 2
frontend/src/routes/oauth.github.callback.tsx

@@ -11,10 +11,17 @@ export const clientLoader = async ({ request }: ClientLoaderFunctionArgs) => {
   const code = url.searchParams.get("code");
 
   if (code) {
-    // request to the server to exchange the code for a token
     const { access_token: accessToken } =
       await OpenHands.getGitHubAccessToken(code);
-    // set the token in local storage
+
+    const authResponse = await OpenHands.authenticate(accessToken);
+    if (!authResponse.ok) {
+      return json(
+        { error: "Failed to authenticate with GitHub" },
+        { status: authResponse.status },
+      );
+    }
+
     localStorage.setItem("ghToken", accessToken);
     return redirect("/");
   }

+ 9 - 1
frontend/src/services/api.ts

@@ -1,4 +1,4 @@
-import { getToken } from "./auth";
+import { getToken, getGitHubToken } from "./auth";
 import toast from "#/utils/toast";
 
 const WAIT_FOR_AUTH_DELAY_MS = 500;
@@ -18,6 +18,7 @@ export async function request(
 
   const needsAuth = !url.startsWith("/api/options/");
   const token = getToken();
+  const githubToken = getGitHubToken();
   if (!token && needsAuth) {
     return new Promise((resolve) => {
       setTimeout(() => {
@@ -32,6 +33,13 @@ export async function request(
       Authorization: `Bearer ${token}`,
     };
   }
+  if (githubToken) {
+    // eslint-disable-next-line no-param-reassign
+    options.headers = {
+      ...(options.headers || {}),
+      "X-GitHub-Token": githubToken,
+    };
+  }
 
   let response = null;
   try {

+ 20 - 1
frontend/src/services/auth.ts

@@ -1,4 +1,5 @@
 const TOKEN_KEY = "token";
+const GITHUB_TOKEN_KEY = "ghToken";
 
 const getToken = (): string => localStorage.getItem(TOKEN_KEY) ?? "";
 
@@ -10,4 +11,22 @@ const setToken = (token: string): void => {
   localStorage.setItem(TOKEN_KEY, token);
 };
 
-export { getToken, setToken, clearToken };
+const getGitHubToken = (): string =>
+  localStorage.getItem(GITHUB_TOKEN_KEY) ?? "";
+
+const setGitHubToken = (token: string): void => {
+  localStorage.setItem(GITHUB_TOKEN_KEY, token);
+};
+
+const clearGitHubToken = (): void => {
+  localStorage.removeItem(GITHUB_TOKEN_KEY);
+};
+
+export {
+  getToken,
+  setToken,
+  clearToken,
+  getGitHubToken,
+  setGitHubToken,
+  clearGitHubToken,
+};

+ 3 - 10
frontend/src/utils/user-is-authenticated.ts

@@ -1,16 +1,9 @@
-import { retrieveGitHubUser, isGitHubErrorReponse } from "#/api/github";
 import OpenHands from "#/api/open-hands";
 
 export const userIsAuthenticated = async (ghToken: string | null) => {
   if (window.__APP_MODE__ !== "saas") return true;
+  if (!ghToken) return false;
 
-  let user: GitHubUser | GitHubErrorReponse | null = null;
-  if (ghToken) user = await retrieveGitHubUser(ghToken);
-
-  if (user && !isGitHubErrorReponse(user)) {
-    const isAuthed = await OpenHands.isAuthenticated(user.login);
-    return isAuthed;
-  }
-
-  return false;
+  const authResponse = await OpenHands.authenticate(ghToken);
+  return authResponse.ok;
 };

+ 1 - 1
openhands/agenthub/codeact_agent/codeact_agent.py

@@ -109,7 +109,7 @@ class CodeActAgent(Agent):
                 codeact_enable_jupyter=self.config.codeact_enable_jupyter,
                 codeact_enable_llm_editor=self.config.codeact_enable_llm_editor,
             )
-            logger.info(
+            logger.debug(
                 f'TOOLS loaded for CodeActAgent: {json.dumps(self.tools, indent=2)}'
             )
             self.system_prompt = codeact_function_calling.SYSTEM_PROMPT

+ 72 - 0
openhands/server/github.py

@@ -0,0 +1,72 @@
+import os
+
+import httpx
+
+from openhands.core.logger import openhands_logger as logger
+
+GITHUB_CLIENT_ID = os.getenv('GITHUB_CLIENT_ID', '').strip()
+GITHUB_CLIENT_SECRET = os.getenv('GITHUB_CLIENT_SECRET', '').strip()
+GITHUB_USER_LIST = None
+
+
+def load_github_user_list():
+    global GITHUB_USER_LIST
+    waitlist = os.getenv('GITHUB_USER_LIST_FILE')
+    if waitlist:
+        with open(waitlist, 'r') as f:
+            GITHUB_USER_LIST = [line.strip() for line in f if line.strip()]
+
+
+load_github_user_list()
+
+
+async def authenticate_github_user(auth_token) -> bool:
+    logger.info('Checking GitHub token')
+    if not GITHUB_USER_LIST:
+        return True
+
+    if not auth_token:
+        logger.warning('No GitHub token provided')
+        return False
+
+    login, error = await get_github_user(auth_token)
+    if error:
+        logger.warning(f'Invalid GitHub token: {error}')
+        return False
+    if login not in GITHUB_USER_LIST:
+        logger.warning(f'GitHub user {login} not in allow list')
+        return False
+
+    logger.info(f'GitHub user {login} authenticated')
+    return True
+
+
+async def get_github_user(token: str) -> tuple[str | None, str | None]:
+    """Get GitHub user info from token.
+
+    Args:
+        token: GitHub access token
+
+    Returns:
+        Tuple of (login, error_message)
+        If successful, error_message is None
+        If failed, login is None and error_message contains the error
+    """
+    headers = {
+        'Accept': 'application/vnd.github+json',
+        'Authorization': f'Bearer {token}',
+        'X-GitHub-Api-Version': '2022-11-28',
+    }
+    try:
+        async with httpx.AsyncClient() as client:
+            response = await client.get('https://api.github.com/user', headers=headers)
+            if response.status_code == 200:
+                user_data = response.json()
+                return user_data.get('login'), None
+            else:
+                return (
+                    None,
+                    f'GitHub API error: {response.status_code} - {response.text}',
+                )
+    except Exception as e:
+        return None, f'Error connecting to GitHub: {str(e)}'

+ 45 - 40
openhands/server/listen.py

@@ -13,6 +13,11 @@ from pathspec.patterns import GitWildMatchPattern
 
 from openhands.security.options import SecurityAnalyzers
 from openhands.server.data_models.feedback import FeedbackDataModel, store_feedback
+from openhands.server.github import (
+    GITHUB_CLIENT_ID,
+    GITHUB_CLIENT_SECRET,
+    authenticate_github_user,
+)
 from openhands.storage import get_file_store
 from openhands.utils.async_utils import call_sync_from_async
 
@@ -64,24 +69,6 @@ config = load_app_config()
 file_store = get_file_store(config.file_store, config.file_store_path)
 session_manager = SessionManager(config, file_store)
 
-GITHUB_CLIENT_ID = os.getenv('GITHUB_CLIENT_ID', '').strip()
-GITHUB_CLIENT_SECRET = os.getenv('GITHUB_CLIENT_SECRET', '').strip()
-
-# New global variable to store the user list
-GITHUB_USER_LIST = None
-
-
-# New function to load the user list
-def load_github_user_list():
-    global GITHUB_USER_LIST
-    waitlist = os.getenv('GITHUB_USER_LIST_FILE')
-    if waitlist:
-        with open(waitlist, 'r') as f:
-            GITHUB_USER_LIST = [line.strip() for line in f if line.strip()]
-
-
-load_github_user_list()
-
 
 @asynccontextmanager
 async def lifespan(app: FastAPI):
@@ -216,7 +203,13 @@ async def attach_session(request: Request, call_next):
         response = await call_next(request)
         return response
 
-    # For all other methods, validate the Authorization header
+    github_token = request.headers.get('X-GitHub-Token')
+    if not await authenticate_github_user(github_token):
+        return JSONResponse(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            content={'error': 'Not authenticated'},
+        )
+
     if not request.headers.get('Authorization'):
         logger.warning('Missing Authorization header')
         return JSONResponse(
@@ -308,11 +301,28 @@ async def websocket_endpoint(websocket: WebSocket):
         {"action": "finish", "args": {}}
         ```
     """
-    await asyncio.wait_for(websocket.accept(), 10)
+    # Get protocols from Sec-WebSocket-Protocol header
+    protocols = websocket.headers.get('sec-websocket-protocol', '').split(', ')
+
+    # The first protocol should be our real protocol (e.g. 'openhands')
+    # The second protocol should contain our auth token
+    if len(protocols) < 3:
+        logger.error('Expected 3 websocket protocols, got %d', len(protocols))
+        await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
+        return
+
+    real_protocol = protocols[0]
+    jwt_token = protocols[1] if protocols[1] != 'NO_JWT' else ''
+    github_token = protocols[2] if protocols[2] != 'NO_GITHUB' else ''
 
-    if websocket.query_params.get('token'):
-        token = websocket.query_params.get('token')
-        sid = get_sid_from_token(token, config.jwt_secret)
+    if not await authenticate_github_user(github_token):
+        await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
+        return
+
+    await asyncio.wait_for(websocket.accept(subprotocol=real_protocol), 10)
+
+    if jwt_token:
+        sid = get_sid_from_token(jwt_token, config.jwt_secret)
 
         if sid == '':
             await websocket.send_json({'error': 'Invalid token', 'error_code': 401})
@@ -320,11 +330,11 @@ async def websocket_endpoint(websocket: WebSocket):
             return
     else:
         sid = str(uuid.uuid4())
-        token = sign_token({'sid': sid}, config.jwt_secret)
+        jwt_token = sign_token({'sid': sid}, config.jwt_secret)
 
     logger.info(f'New session: {sid}')
     session = session_manager.add_or_restart_session(sid, websocket)
-    await websocket.send_json({'token': token, 'status': 'ok'})
+    await websocket.send_json({'token': jwt_token, 'status': 'ok'})
 
     latest_event_id = -1
     if websocket.query_params.get('latest_event_id'):
@@ -840,26 +850,21 @@ def github_callback(auth_code: AuthCode):
     )
 
 
-class User(BaseModel):
-    login: str  # GitHub login handle
-
-
 @app.post('/api/authenticate')
-def authenticate(user: User | None = None):
-    global GITHUB_USER_LIST
-
-    # Only check if waitlist is provided
-    if GITHUB_USER_LIST:
-        if user is None or user.login not in GITHUB_USER_LIST:
-            return JSONResponse(
-                status_code=status.HTTP_403_FORBIDDEN,
-                content={'error': 'User not on waitlist'},
-            )
+async def authenticate(request: Request):
+    token = request.headers.get('X-GitHub-Token')
+    if not await authenticate_github_user(token):
+        return JSONResponse(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            content={'error': 'Not authorized via GitHub waitlist'},
+        )
 
-    return JSONResponse(
+    response = JSONResponse(
         status_code=status.HTTP_200_OK, content={'message': 'User authenticated'}
     )
 
+    return response
+
 
 class SPAStaticFiles(StaticFiles):
     async def get_response(self, path: str, scope):