浏览代码

Feat: Introduce class for SessionInitData rather than using a dict (#5406)

tofarr 1 年之前
父节点
当前提交
de81020a8d

+ 5 - 5
config.template.toml

@@ -95,10 +95,10 @@ workspace_base = "./workspace"
 # AWS secret access key
 #aws_secret_access_key = ""
 
-# API key to use
+# API key to use (For Headless / CLI only - In Web this is overridden by Session Init)
 api_key = "your-api-key"
 
-# API base URL
+# API base URL (For Headless / CLI only - In Web this is overridden by Session Init)
 #base_url = ""
 
 # API version
@@ -131,7 +131,7 @@ embedding_model = "local"
 # Maximum number of output tokens
 #max_output_tokens = 0
 
-# Model to use
+# Model to use (For Headless / CLI only - In Web this is overridden by Session Init)
 model = "gpt-4o"
 
 # Number of retries to attempt when an operation fails with the LLM.
@@ -237,10 +237,10 @@ llm_config = 'gpt3'
 ##############################################################################
 [security]
 
-# Enable confirmation mode
+# Enable confirmation mode (For Headless / CLI only - In Web this is overridden by Session Init)
 #confirmation_mode = false
 
-# The security analyzer to use
+# The security analyzer to use (For Headless / CLI only - In Web this is overridden by Session Init)
 #security_analyzer = ""
 
 #################################### Eval ####################################

+ 5 - 4
openhands/server/session/manager.py

@@ -11,6 +11,7 @@ from openhands.events.stream import EventStream, session_exists
 from openhands.runtime.base import RuntimeUnavailableError
 from openhands.server.session.conversation import Conversation
 from openhands.server.session.session import ROOM_KEY, Session
+from openhands.server.session.session_init_data import SessionInitData
 from openhands.storage.files import FileStore
 from openhands.utils.shutdown_listener import should_continue
 
@@ -141,7 +142,7 @@ class SessionManager:
     async def detach_from_conversation(self, conversation: Conversation):
         await conversation.disconnect()
 
-    async def init_or_join_session(self, sid: str, connection_id: str, data: dict):
+    async def init_or_join_session(self, sid: str, connection_id: str, session_init_data: SessionInitData):
         await self.sio.enter_room(connection_id, ROOM_KEY.format(sid=sid))
         self.local_connection_id_to_session_id[connection_id] = sid
 
@@ -156,7 +157,7 @@ class SessionManager:
         if redis_client and await self._is_session_running_in_cluster(sid):
             return EventStream(sid, self.file_store)
 
-        return await self.start_local_session(sid, data)
+        return await self.start_local_session(sid, session_init_data)
 
     async def _is_session_running_in_cluster(self, sid: str) -> bool:
         """Ask the rest of the cluster if a session is running. Wait for a short timeout for a reply"""
@@ -210,14 +211,14 @@ class SessionManager:
         finally:
             self._has_remote_connections_flags.pop(sid)
 
-    async def start_local_session(self, sid: str, data: dict):
+    async def start_local_session(self, sid: str, session_init_data: SessionInitData):
         # Start a new local session
         logger.info(f'start_new_local_session:{sid}')
         session = Session(
             sid=sid, file_store=self.file_store, config=self.config, sio=self.sio
         )
         self.local_sessions_by_sid[sid] = session
-        await session.initialize_agent(data)
+        await session.initialize_agent(session_init_data)
         return session.agent_session.event_stream
 
     async def send_to_event_stream(self, connection_id: str, data: dict):

+ 14 - 23
openhands/server/session/session.py

@@ -1,4 +1,5 @@
 import asyncio
+from copy import deepcopy
 import time
 
 import socketio
@@ -21,6 +22,7 @@ from openhands.events.serialization import event_from_dict, event_to_dict
 from openhands.events.stream import EventStreamSubscriber
 from openhands.llm.llm import LLM
 from openhands.server.session.agent_session import AgentSession
+from openhands.server.session.session_init_data import SessionInitData
 from openhands.storage.files import FileStore
 
 ROOM_KEY = 'room:{sid}'
@@ -34,7 +36,6 @@ class Session:
     agent_session: AgentSession
     loop: asyncio.AbstractEventLoop
     config: AppConfig
-    settings: dict | None
 
     def __init__(
         self,
@@ -52,41 +53,31 @@ class Session:
         self.agent_session.event_stream.subscribe(
             EventStreamSubscriber.SERVER, self.on_event, self.sid
         )
-        self.config = config
+        # Copying this means that when we update variables they are not applied to the shared global configuration!
+        self.config = deepcopy(config)
         self.loop = asyncio.get_event_loop()
-        self.settings = None
 
     def close(self):
         self.is_alive = False
         self.agent_session.close()
 
-    async def initialize_agent(self, data: dict):
-        self.settings = data
+    async def initialize_agent(self, session_init_data: SessionInitData):
         self.agent_session.event_stream.add_event(
             AgentStateChangedObservation('', AgentState.LOADING),
             EventSource.ENVIRONMENT,
         )
         # Extract the agent-relevant arguments from the request
-        args = {key: value for key, value in data.get('args', {}).items()}
-        agent_cls = args.get(ConfigType.AGENT, self.config.default_agent)
-        self.config.security.confirmation_mode = args.get(
-            ConfigType.CONFIRMATION_MODE, self.config.security.confirmation_mode
-        )
-        self.config.security.security_analyzer = data.get('args', {}).get(
-            ConfigType.SECURITY_ANALYZER, self.config.security.security_analyzer
-        )
-        max_iterations = args.get(ConfigType.MAX_ITERATIONS, self.config.max_iterations)
+        agent_cls = session_init_data.agent or self.config.default_agent
+        self.config.security.confirmation_mode = self.config.security.confirmation_mode if session_init_data.confirmation_mode is None else session_init_data.confirmation_mode
+        self.config.security.security_analyzer = session_init_data.security_analyzer or self.config.security.security_analyzer
+        max_iterations = session_init_data.max_iterations or self.config.max_iterations
         # override default LLM config
+        
+
         default_llm_config = self.config.get_llm_config()
-        default_llm_config.model = args.get(
-            ConfigType.LLM_MODEL, default_llm_config.model
-        )
-        default_llm_config.api_key = args.get(
-            ConfigType.LLM_API_KEY, default_llm_config.api_key
-        )
-        default_llm_config.base_url = args.get(
-            ConfigType.LLM_BASE_URL, default_llm_config.base_url
-        )
+        default_llm_config.model = session_init_data.llm_model or default_llm_config.model
+        default_llm_config.api_key = session_init_data.llm_api_key or default_llm_config.api_key
+        default_llm_config.base_url = session_init_data.llm_base_url or default_llm_config.base_url
 
         # TODO: override other LLM config & agent config groups (#2075)
 

+ 18 - 0
openhands/server/session/session_init_data.py

@@ -0,0 +1,18 @@
+
+
+from dataclasses import dataclass
+
+
+@dataclass
+class SessionInitData:
+    """
+    Session initialization data for the web environment - a deep copy of the global config is made and then overridden with this data.
+    """
+    language: str | None = None
+    agent: str | None = None
+    max_iterations: int | None = None
+    security_analyzer: str | None = None
+    confirmation_mode: bool | None = None
+    llm_model: str | None = None
+    llm_api_key: str | None = None
+    llm_base_url: str | None = None

+ 19 - 7
openhands/server/socket.py

@@ -13,6 +13,7 @@ from openhands.events.serialization import event_to_dict
 from openhands.events.stream import AsyncEventStreamWrapper
 from openhands.server.auth import get_sid_from_token, sign_token
 from openhands.server.github_utils import authenticate_github_user
+from openhands.server.session.session_init_data import SessionInitData
 from openhands.server.shared import config, session_manager, sio
 
 
@@ -26,19 +27,30 @@ async def oh_action(connection_id: str, data: dict):
     # If it's an init, we do it here.
     action = data.get('action', '')
     if action == ActionType.INIT:
-        await init_connection(connection_id, data)
+        token = data.pop('token', None)
+        github_token = data.pop('github_token', None)
+        latest_event_id = int(data.pop('latest_event_id', -1))
+        kwargs = {k.lower(): v for k, v in (data.get('args') or {}).items()}
+        session_init_data = SessionInitData(**kwargs)
+        await init_connection(
+            connection_id, token, github_token, session_init_data, latest_event_id
+        )
         return
 
     logger.info(f'sio:oh_action:{connection_id}')
     await session_manager.send_to_event_stream(connection_id, data)
 
 
-async def init_connection(connection_id: str, data: dict):
-    gh_token = data.pop('github_token', None)
+async def init_connection(
+    connection_id: str,
+    token: str | None,
+    gh_token: str | None,
+    session_init_data: SessionInitData,
+    latest_event_id: int,
+):
     if not await authenticate_github_user(gh_token):
         raise RuntimeError(status.WS_1008_POLICY_VIOLATION)
 
-    token = data.pop('token', None)
     if token:
         sid = get_sid_from_token(token, config.jwt_secret)
         if sid == '':
@@ -52,10 +64,10 @@ async def init_connection(connection_id: str, data: dict):
     token = sign_token({'sid': sid}, config.jwt_secret)
     await sio.emit('oh_event', {'token': token, 'status': 'ok'}, to=connection_id)
 
-    latest_event_id = int(data.pop('latest_event_id', -1))
-
     # The session in question should exist, but may not actually be running locally...
-    event_stream = await session_manager.init_or_join_session(sid, connection_id, data)
+    event_stream = await session_manager.init_or_join_session(
+        sid, connection_id, session_init_data
+    )
 
     # Send events
     agent_state_changed = None

+ 7 - 6
tests/unit/test_manager.py

@@ -7,6 +7,7 @@ import pytest
 
 from openhands.core.config.app_config import AppConfig
 from openhands.server.session.manager import SessionManager
+from openhands.server.session.session_init_data import SessionInitData
 from openhands.storage.memory import InMemoryFileStore
 
 
@@ -100,7 +101,7 @@ async def test_init_new_local_session():
             sio, AppConfig(), InMemoryFileStore()
         ) as session_manager:
             await session_manager.init_or_join_session(
-                'new-session-id', 'new-session-id', {'type': 'mock-settings'}
+                'new-session-id', 'new-session-id', SessionInitData()
             )
     assert session_instance.initialize_agent.call_count == 1
     assert sio.enter_room.await_count == 1
@@ -132,11 +133,11 @@ async def test_join_local_session():
         ) as session_manager:
             # First call initializes
             await session_manager.init_or_join_session(
-                'new-session-id', 'new-session-id', {'type': 'mock-settings'}
+                'new-session-id', 'new-session-id', SessionInitData()
             )
             # Second call joins
             await session_manager.init_or_join_session(
-                'new-session-id', 'extra-connection-id', {'type': 'mock-settings'}
+                'new-session-id', 'extra-connection-id', SessionInitData()
             )
     assert session_instance.initialize_agent.call_count == 1
     assert sio.enter_room.await_count == 2
@@ -168,7 +169,7 @@ async def test_join_cluster_session():
         ) as session_manager:
             # First call initializes
             await session_manager.init_or_join_session(
-                'new-session-id', 'new-session-id', {'type': 'mock-settings'}
+                'new-session-id', 'new-session-id', SessionInitData()
             )
     assert session_instance.initialize_agent.call_count == 0
     assert sio.enter_room.await_count == 1
@@ -199,7 +200,7 @@ async def test_add_to_local_event_stream():
             sio, AppConfig(), InMemoryFileStore()
         ) as session_manager:
             await session_manager.init_or_join_session(
-                'new-session-id', 'connection-id', {'type': 'mock-settings'}
+                'new-session-id', 'connection-id', SessionInitData()
             )
             await session_manager.send_to_event_stream(
                 'connection-id', {'event_type': 'some_event'}
@@ -232,7 +233,7 @@ async def test_add_to_cluster_event_stream():
             sio, AppConfig(), InMemoryFileStore()
         ) as session_manager:
             await session_manager.init_or_join_session(
-                'new-session-id', 'connection-id', {'type': 'mock-settings'}
+                'new-session-id', 'connection-id', SessionInitData()
             )
             await session_manager.send_to_event_stream(
                 'connection-id', {'event_type': 'some_event'}