Browse Source

Remove global config from auth (#2962)

Graham Neubig 1 năm trước cách đây
mục cha
commit
01ce1e35b5
2 tập tin đã thay đổi với 7 bổ sung8 xóa
  1. 4 5
      opendevin/server/auth/auth.py
  2. 3 3
      opendevin/server/listen.py

+ 4 - 5
opendevin/server/auth/auth.py

@@ -1,11 +1,10 @@
 import jwt
 from jwt.exceptions import InvalidTokenError
 
-from opendevin.core.config import config
 from opendevin.core.logger import opendevin_logger as logger
 
 
-def get_sid_from_token(token: str) -> str:
+def get_sid_from_token(token: str, jwt_secret: str) -> str:
     """Retrieves the session id from a JWT token.
 
     Parameters:
@@ -16,7 +15,7 @@ def get_sid_from_token(token: str) -> str:
     """
     try:
         # Decode the JWT using the specified secret and algorithm
-        payload = jwt.decode(token, config.jwt_secret, algorithms=['HS256'])
+        payload = jwt.decode(token, jwt_secret, algorithms=['HS256'])
 
         # Ensure the payload contains 'sid'
         if 'sid' in payload:
@@ -31,10 +30,10 @@ def get_sid_from_token(token: str) -> str:
     return ''
 
 
-def sign_token(payload: dict[str, object]) -> str:
+def sign_token(payload: dict[str, object], jwt_secret: str) -> str:
     """Signs a JWT token."""
     # payload = {
     #     "sid": sid,
     #     # "exp": datetime.now(timezone.utc) + timedelta(minutes=15),
     # }
-    return jwt.encode(payload, config.jwt_secret, algorithm='HS256')
+    return jwt.encode(payload, jwt_secret, algorithm='HS256')

+ 3 - 3
opendevin/server/listen.py

@@ -166,7 +166,7 @@ async def attach_session(request: Request, call_next):
     if 'Bearer' in auth_token:
         auth_token = auth_token.split('Bearer')[1].strip()
 
-    request.state.sid = get_sid_from_token(auth_token)
+    request.state.sid = get_sid_from_token(auth_token, config.jwt_secret)
     if request.state.sid == '':
         return JSONResponse(
             status_code=status.HTTP_401_UNAUTHORIZED,
@@ -245,7 +245,7 @@ async def websocket_endpoint(websocket: WebSocket):
 
     if websocket.query_params.get('token'):
         token = websocket.query_params.get('token')
-        sid = get_sid_from_token(token)
+        sid = get_sid_from_token(token, config.jwt_secret)
 
         if sid == '':
             await websocket.send_json({'error': 'Invalid token', 'error_code': 401})
@@ -253,7 +253,7 @@ async def websocket_endpoint(websocket: WebSocket):
             return
     else:
         sid = str(uuid.uuid4())
-        token = sign_token({'sid': sid})
+        token = sign_token({'sid': sid}, config.jwt_secret)
 
     session = session_manager.add_or_restart_session(sid, websocket)
     await websocket.send_json({'token': token, 'status': 'ok'})