Преглед изворни кода

Feat config server side store (#5594)

Co-authored-by: openhands <openhands@all-hands.dev>
tofarr пре 1 година
родитељ
комит
fe1bb1c233

+ 3 - 0
openhands/core/config/app_config.py

@@ -66,6 +66,9 @@ class AppConfig:
     modal_api_token_secret: str = ''
     disable_color: bool = False
     jwt_secret: str = ''
+    settings_store_class: str = (
+        'openhands.storage.file_settings_store.FileSettingsStore'
+    )
     debug: bool = False
     file_uploads_max_file_size_mb: int = 0
     file_uploads_restrict_file_types: bool = False

+ 2 - 0
openhands/server/app.py

@@ -22,6 +22,7 @@ from openhands.server.routes.files import app as files_api_router
 from openhands.server.routes.github import app as github_api_router
 from openhands.server.routes.public import app as public_api_router
 from openhands.server.routes.security import app as security_api_router
+from openhands.server.routes.settings import app as settings_router
 from openhands.server.shared import openhands_config, session_manager
 from openhands.utils.import_utils import get_impl
 
@@ -56,6 +57,7 @@ app.include_router(files_api_router)
 app.include_router(conversation_api_router)
 app.include_router(security_api_router)
 app.include_router(feedback_api_router)
+app.include_router(settings_router)
 app.include_router(github_api_router)
 
 

+ 23 - 11
openhands/server/listen_socket.py

@@ -10,8 +10,9 @@ from openhands.events.observation.agent import AgentStateChangedObservation
 from openhands.events.serialization import event_to_dict
 from openhands.events.stream import AsyncEventStreamWrapper
 from openhands.server.auth import get_sid_from_token, sign_token
+from openhands.server.routes.settings import SettingsStoreImpl
 from openhands.server.session.session_init_data import SessionInitData
-from openhands.server.shared import config, openhands_config, session_manager, sio
+from openhands.server.shared import config, session_manager, sio
 
 
 @sio.event
@@ -24,15 +25,16 @@ async def oh_action(connection_id: str, data: dict):
     # If it's an init, we do it here.
     action = data.get('action', '')
     if action == ActionType.INIT:
-        await openhands_config.github_auth(data)
-        github_token = data.pop('github_token', None)
-        token = data.pop('token', None)
-        latest_event_id = int(data.pop('latest_event_id', -1))
-        kwargs = {k.lower(): v for k, v in (data.get('args') or {}).items()}
-        session_init_data = SessionInitData(**kwargs)
-        session_init_data.github_token = github_token
-        session_init_data.selected_repository = data.get('selected_repository', None)
-        await init_connection(connection_id, token, session_init_data, latest_event_id)
+        await init_connection(
+            connection_id=connection_id,
+            token=data.get('token', None),
+            github_token=data.get('github_token', None),
+            session_init_args={
+                k.lower(): v for k, v in (data.get('args') or {}).items()
+            },
+            latest_event_id=int(data.get('latest_event_id', -1)),
+            selected_repository=data.get('selected_repository'),
+        )
         return
 
     logger.info(f'sio:oh_action:{connection_id}')
@@ -42,9 +44,19 @@ async def oh_action(connection_id: str, data: dict):
 async def init_connection(
     connection_id: str,
     token: str | None,
-    session_init_data: SessionInitData,
+    github_token: str | None,
+    session_init_args: dict,
     latest_event_id: int,
+    selected_repository: str | None,
 ):
+    settings_store = await SettingsStoreImpl.get_instance(config, github_token)
+    settings = await settings_store.load()
+    if settings:
+        session_init_args = {**settings.__dict__, **session_init_args}
+    session_init_args['github_token'] = github_token
+    session_init_args['selected_repository'] = selected_repository
+    session_init_data = SessionInitData(**session_init_args)
+
     if token:
         sid = get_sid_from_token(token, config.jwt_secret)
         if sid == '':

+ 47 - 0
openhands/server/routes/settings.py

@@ -0,0 +1,47 @@
+from typing import Annotated
+
+from fastapi import APIRouter, Header, status
+from fastapi.responses import JSONResponse
+
+from openhands.core.logger import openhands_logger as logger
+from openhands.server.settings import Settings
+from openhands.server.shared import config
+from openhands.storage.settings_store import SettingsStore
+from openhands.utils.import_utils import get_impl
+
+app = APIRouter(prefix='/api')
+
+SettingsStoreImpl = get_impl(SettingsStore, config.settings_store_class)  # type: ignore
+
+
+@app.get('/settings')
+async def load_settings(
+    github_auth: Annotated[str | None, Header()] = None,
+) -> Settings | None:
+    try:
+        settings_store = await SettingsStoreImpl.get_instance(config, github_auth)
+        settings = await settings_store.load()
+        return settings
+    except Exception as e:
+        logger.warning(f'Invalid token: {e}')
+        return JSONResponse(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            content={'error': 'Invalid token'},
+        )
+
+
+@app.post('/settings')
+async def store_settings(
+    settings: Settings,
+    github_auth: Annotated[str | None, Header()] = None,
+) -> bool:
+    try:
+        settings_store = await SettingsStoreImpl.get_instance(config, github_auth)
+        settings = await settings_store.store(settings)
+        return True
+    except Exception as e:
+        logger.warning(f'Invalid token: {e}')
+        return JSONResponse(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            content={'error': 'Invalid token'},
+        )

+ 3 - 9
openhands/server/session/session_init_data.py

@@ -1,19 +1,13 @@
 from dataclasses import dataclass
 
+from openhands.server.settings import Settings
+
 
 @dataclass
-class SessionInitData:
+class SessionInitData(Settings):
     """
     Session initialization data for the web environment - a deep copy of the global config is made and then overridden with this data.
     """
 
-    language: str | None = None
-    agent: str | None = None
-    max_iterations: int | None = None
-    security_analyzer: str | None = None
-    confirmation_mode: bool | None = None
-    llm_model: str | None = None
-    llm_api_key: str | None = None
-    llm_base_url: str | None = None
     github_token: str | None = None
     selected_repository: str | None = None

+ 17 - 0
openhands/server/settings.py

@@ -0,0 +1,17 @@
+from dataclasses import dataclass
+
+
+@dataclass
+class Settings:
+    """
+    Persisted settings for OpenHands sessions
+    """
+
+    language: str | None = None
+    agent: str | None = None
+    max_iterations: int | None = None
+    security_analyzer: str | None = None
+    confirmation_mode: bool | None = None
+    llm_model: str | None = None
+    llm_api_key: str | None = None
+    llm_base_url: str | None = None

+ 34 - 0
openhands/storage/file_settings_store.py

@@ -0,0 +1,34 @@
+from __future__ import annotations
+
+import json
+from dataclasses import dataclass
+
+from openhands.core.config.app_config import AppConfig
+from openhands.server.settings import Settings
+from openhands.storage import get_file_store
+from openhands.storage.files import FileStore
+from openhands.storage.settings_store import SettingsStore
+
+
+@dataclass
+class FileSettingsStore(SettingsStore):
+    file_store: FileStore
+    path: str = 'settings.json'
+
+    async def load(self) -> Settings | None:
+        try:
+            json_str = self.file_store.read(self.path)
+            kwargs = json.loads(json_str)
+            settings = Settings(**kwargs)
+            return settings
+        except FileNotFoundError:
+            return None
+
+    async def store(self, settings: Settings):
+        json_str = json.dumps(settings.__dict__)
+        self.file_store.write(self.path, json_str)
+
+    @classmethod
+    async def get_instance(cls, config: AppConfig, token: str | None):
+        file_store = get_file_store(config.file_store, config.file_store_path)
+        return FileSettingsStore(file_store)

+ 25 - 0
openhands/storage/settings_store.py

@@ -0,0 +1,25 @@
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+
+from openhands.core.config.app_config import AppConfig
+from openhands.server.settings import Settings
+
+
+class SettingsStore(ABC):
+    """
+    Storage for SessionInitData. May or may not support multiple users depending on the environment
+    """
+
+    @abstractmethod
+    async def load(self) -> Settings | None:
+        """Load session init data"""
+
+    @abstractmethod
+    async def store(self, settings: Settings):
+        """Store session init data"""
+
+    @classmethod
+    @abstractmethod
+    async def get_instance(cls, config: AppConfig, token: str | None) -> SettingsStore:
+        """Get a store for the user represented by the token given"""

+ 81 - 0
tests/unit/test_file_settings_store.py

@@ -0,0 +1,81 @@
+import json
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from openhands.core.config.app_config import AppConfig
+from openhands.server.settings import Settings
+from openhands.storage.file_settings_store import FileSettingsStore
+from openhands.storage.files import FileStore
+
+
+@pytest.fixture
+def mock_file_store():
+    return MagicMock(spec=FileStore)
+
+
+@pytest.fixture
+def session_init_store(mock_file_store):
+    return FileSettingsStore(mock_file_store)
+
+
+@pytest.mark.asyncio
+async def test_load_nonexistent_data(session_init_store):
+    session_init_store.file_store.read.side_effect = FileNotFoundError()
+    assert await session_init_store.load() is None
+
+
+@pytest.mark.asyncio
+async def test_store_and_load_data(session_init_store):
+    # Test data
+    init_data = Settings(
+        language='python',
+        agent='test-agent',
+        max_iterations=100,
+        security_analyzer='default',
+        confirmation_mode=True,
+        llm_model='test-model',
+        llm_api_key='test-key',
+        llm_base_url='https://test.com',
+    )
+
+    # Store data
+    await session_init_store.store(init_data)
+
+    # Verify store called with correct JSON
+    expected_json = json.dumps(init_data.__dict__)
+    session_init_store.file_store.write.assert_called_once_with(
+        'settings.json', expected_json
+    )
+
+    # Setup mock for load
+    session_init_store.file_store.read.return_value = expected_json
+
+    # Load and verify data
+    loaded_data = await session_init_store.load()
+    assert loaded_data is not None
+    assert loaded_data.language == init_data.language
+    assert loaded_data.agent == init_data.agent
+    assert loaded_data.max_iterations == init_data.max_iterations
+    assert loaded_data.security_analyzer == init_data.security_analyzer
+    assert loaded_data.confirmation_mode == init_data.confirmation_mode
+    assert loaded_data.llm_model == init_data.llm_model
+    assert loaded_data.llm_api_key == init_data.llm_api_key
+    assert loaded_data.llm_base_url == init_data.llm_base_url
+
+
+@pytest.mark.asyncio
+async def test_get_instance():
+    config = AppConfig(file_store='local', file_store_path='/test/path')
+
+    with patch(
+        'openhands.storage.file_settings_store.get_file_store'
+    ) as mock_get_store:
+        mock_store = MagicMock(spec=FileStore)
+        mock_get_store.return_value = mock_store
+
+        store = await FileSettingsStore.get_instance(config, None)
+
+        assert isinstance(store, FileSettingsStore)
+        assert store.file_store == mock_store
+        mock_get_store.assert_called_once_with('local', '/test/path')