Просмотр исходного кода

Refactor ConversationStore to follow SettingsStore pattern (#5881)

Co-authored-by: openhands <openhands@all-hands.dev>
tofarr 1 год назад
Родитель
Сommit
4dd40049ab

+ 4 - 1
openhands/server/config/openhands_config.py

@@ -16,7 +16,10 @@ class OpenhandsConfig(OpenhandsConfigInterface):
         'openhands.server.middleware.AttachConversationMiddleware'
     )
     settings_store_class: str = (
-        'openhands.storage.file_settings_store.FileSettingsStore'
+        'openhands.storage.settings.file_settings_store.FileSettingsStore'
+    )
+    conversation_store_class: str = (
+        'openhands.storage.conversation.file_conversation_store.FileConversationStore'
     )
 
     def verify_config(self):

+ 8 - 0
openhands/server/data_models/conversation_metadata.py

@@ -0,0 +1,8 @@
+from dataclasses import dataclass
+
+
+@dataclass
+class ConversationMetadata:
+    conversation_id: str
+    github_user_id: str
+    selected_repository: str | None

+ 4 - 5
openhands/server/listen_socket.py

@@ -13,13 +13,10 @@ from openhands.events.observation import (
 from openhands.events.observation.agent import AgentStateChangedObservation
 from openhands.events.serialization import event_to_dict
 from openhands.events.stream import AsyncEventStreamWrapper
-from openhands.server.routes.settings import SettingsStoreImpl
+from openhands.server.routes.settings import ConversationStoreImpl, SettingsStoreImpl
 from openhands.server.session.manager import ConversationDoesNotExistError
 from openhands.server.shared import config, openhands_config, session_manager, sio
 from openhands.server.types import AppMode
-from openhands.storage.conversation.conversation_store import (
-    ConversationStore,
-)
 from openhands.utils.async_utils import call_sync_from_async
 
 
@@ -44,7 +41,9 @@ async def connect(connection_id: str, environ, auth):
 
         logger.info(f'User {user_id} is connecting to conversation {conversation_id}')
 
-        conversation_store = await ConversationStore.get_instance(config)
+        conversation_store = await ConversationStoreImpl.get_instance(
+            config, github_token
+        )
         metadata = await conversation_store.get_metadata(conversation_id)
         if metadata.github_user_id != user_id:
             logger.error(

+ 3 - 6
openhands/server/routes/new_conversation.py

@@ -6,13 +6,10 @@ from github import Github
 from pydantic import BaseModel
 
 from openhands.core.logger import openhands_logger as logger
-from openhands.server.routes.settings import SettingsStoreImpl
+from openhands.server.routes.settings import ConversationStoreImpl, SettingsStoreImpl
 from openhands.server.session.conversation_init_data import ConversationInitData
 from openhands.server.shared import config, session_manager
-from openhands.storage.conversation.conversation_store import (
-    ConversationMetadata,
-    ConversationStore,
-)
+from openhands.server.data_models.conversation_metadata import ConversationMetadata
 from openhands.utils.async_utils import call_sync_from_async
 
 app = APIRouter(prefix='/api')
@@ -49,7 +46,7 @@ async def new_conversation(request: Request, data: InitSessionRequest):
     session_init_args['selected_repository'] = data.selected_repository
     conversation_init_data = ConversationInitData(**session_init_args)
 
-    conversation_store = await ConversationStore.get_instance(config)
+    conversation_store = await ConversationStoreImpl.get_instance(config, github_token)
 
     conversation_id = uuid.uuid4().hex
     while await conversation_store.exists(conversation_id):

+ 5 - 1
openhands/server/routes/settings.py

@@ -4,12 +4,16 @@ from fastapi.responses import JSONResponse
 from openhands.core.logger import openhands_logger as logger
 from openhands.server.settings import Settings
 from openhands.server.shared import config, openhands_config
-from openhands.storage.settings_store import SettingsStore
+from openhands.storage.conversation.conversation_store import ConversationStore
+from openhands.storage.settings.settings_store import SettingsStore
 from openhands.utils.import_utils import get_impl
 
 app = APIRouter(prefix='/api')
 
 SettingsStoreImpl = get_impl(SettingsStore, openhands_config.settings_store_class)  # type: ignore
+ConversationStoreImpl = get_impl(
+    ConversationStore, openhands_config.conversation_store_class  # type: ignore
+)
 
 
 @app.get('/settings')

+ 19 - 31
openhands/storage/conversation/conversation_store.py

@@ -1,43 +1,31 @@
-import json
-from dataclasses import dataclass
-
-from openhands.core.config.app_config import AppConfig
-from openhands.storage import get_file_store
-from openhands.storage.files import FileStore
-from openhands.storage.locations import get_conversation_metadata_filename
-from openhands.utils.async_utils import call_sync_from_async
+from __future__ import annotations
 
+from abc import ABC, abstractmethod
 
-@dataclass
-class ConversationMetadata:
-    conversation_id: str
-    github_user_id: str
-    selected_repository: str | None
+from openhands.core.config.app_config import AppConfig
+from openhands.server.data_models.conversation_metadata import ConversationMetadata
 
 
-@dataclass
-class ConversationStore:
-    file_store: FileStore
+class ConversationStore(ABC):
+    """
+    Storage for conversation metadata. May or may not support multiple users depending on the environment
+    """
 
+    @abstractmethod
     async def save_metadata(self, metadata: ConversationMetadata):
-        json_str = json.dumps(metadata.__dict__)
-        path = get_conversation_metadata_filename(metadata.conversation_id)
-        await call_sync_from_async(self.file_store.write, path, json_str)
+        """Store conversation metadata"""
 
+    @abstractmethod
     async def get_metadata(self, conversation_id: str) -> ConversationMetadata:
-        path = get_conversation_metadata_filename(conversation_id)
-        json_str = await call_sync_from_async(self.file_store.read, path)
-        return ConversationMetadata(**json.loads(json_str))
+        """Load conversation metadata"""
 
+    @abstractmethod
     async def exists(self, conversation_id: str) -> bool:
-        path = get_conversation_metadata_filename(conversation_id)
-        try:
-            await call_sync_from_async(self.file_store.read, path)
-            return True
-        except FileNotFoundError:
-            return False
+        """Check if conversation exists"""
 
     @classmethod
-    async def get_instance(cls, config: AppConfig):
-        file_store = get_file_store(config.file_store, config.file_store_path)
-        return ConversationStore(file_store)
+    @abstractmethod
+    async def get_instance(
+        cls, config: AppConfig, token: str | None
+    ) -> ConversationStore:
+        """Get a store for the user represented by the token given"""

+ 40 - 0
openhands/storage/conversation/file_conversation_store.py

@@ -0,0 +1,40 @@
+from __future__ import annotations
+
+import json
+from dataclasses import dataclass
+
+from openhands.core.config.app_config import AppConfig
+from openhands.storage import get_file_store
+from openhands.storage.conversation.conversation_store import ConversationStore
+from openhands.server.data_models.conversation_metadata import ConversationMetadata
+from openhands.storage.files import FileStore
+from openhands.storage.locations import get_conversation_metadata_filename
+from openhands.utils.async_utils import call_sync_from_async
+
+
+@dataclass
+class FileConversationStore(ConversationStore):
+    file_store: FileStore
+
+    async def save_metadata(self, metadata: ConversationMetadata):
+        json_str = json.dumps(metadata.__dict__)
+        path = get_conversation_metadata_filename(metadata.conversation_id)
+        await call_sync_from_async(self.file_store.write, path, json_str)
+
+    async def get_metadata(self, conversation_id: str) -> ConversationMetadata:
+        path = get_conversation_metadata_filename(conversation_id)
+        json_str = await call_sync_from_async(self.file_store.read, path)
+        return ConversationMetadata(**json.loads(json_str))
+
+    async def exists(self, conversation_id: str) -> bool:
+        path = get_conversation_metadata_filename(conversation_id)
+        try:
+            await call_sync_from_async(self.file_store.read, path)
+            return True
+        except FileNotFoundError:
+            return False
+
+    @classmethod
+    async def get_instance(cls, config: AppConfig, token: str | None):
+        file_store = get_file_store(config.file_store, config.file_store_path)
+        return FileConversationStore(file_store)

+ 1 - 1
openhands/storage/file_settings_store.py → openhands/storage/settings/file_settings_store.py

@@ -7,7 +7,7 @@ from openhands.core.config.app_config import AppConfig
 from openhands.server.settings import Settings
 from openhands.storage import get_file_store
 from openhands.storage.files import FileStore
-from openhands.storage.settings_store import SettingsStore
+from openhands.storage.settings.settings_store import SettingsStore
 from openhands.utils.async_utils import call_sync_from_async
 
 

+ 0 - 0
openhands/storage/settings_store.py → openhands/storage/settings/settings_store.py


+ 2 - 0
pyproject.toml

@@ -100,6 +100,7 @@ reportlab = "*"
 [tool.coverage.run]
 concurrency = ["gevent"]
 
+
 [tool.poetry.group.runtime.dependencies]
 jupyterlab = "*"
 notebook = "*"
@@ -129,6 +130,7 @@ ignore = ["D1"]
 [tool.ruff.lint.pydocstyle]
 convention = "google"
 
+
 [tool.poetry.group.evaluation.dependencies]
 streamlit = "*"
 whatthepatch = "*"

+ 11 - 11
tests/unit/test_file_settings_store.py

@@ -5,8 +5,8 @@ import pytest
 
 from openhands.core.config.app_config import AppConfig
 from openhands.server.settings import Settings
-from openhands.storage.file_settings_store import FileSettingsStore
 from openhands.storage.files import FileStore
+from openhands.storage.settings.file_settings_store import FileSettingsStore
 
 
 @pytest.fixture
@@ -15,18 +15,18 @@ def mock_file_store():
 
 
 @pytest.fixture
-def session_init_store(mock_file_store):
+def file_settings_store(mock_file_store):
     return FileSettingsStore(mock_file_store)
 
 
 @pytest.mark.asyncio
-async def test_load_nonexistent_data(session_init_store):
-    session_init_store.file_store.read.side_effect = FileNotFoundError()
-    assert await session_init_store.load() is None
+async def test_load_nonexistent_data(file_settings_store):
+    file_settings_store.file_store.read.side_effect = FileNotFoundError()
+    assert await file_settings_store.load() is None
 
 
 @pytest.mark.asyncio
-async def test_store_and_load_data(session_init_store):
+async def test_store_and_load_data(file_settings_store):
     # Test data
     init_data = Settings(
         language='python',
@@ -40,19 +40,19 @@ async def test_store_and_load_data(session_init_store):
     )
 
     # Store data
-    await session_init_store.store(init_data)
+    await file_settings_store.store(init_data)
 
     # Verify store called with correct JSON
     expected_json = json.dumps(init_data.__dict__)
-    session_init_store.file_store.write.assert_called_once_with(
+    file_settings_store.file_store.write.assert_called_once_with(
         'settings.json', expected_json
     )
 
     # Setup mock for load
-    session_init_store.file_store.read.return_value = expected_json
+    file_settings_store.file_store.read.return_value = expected_json
 
     # Load and verify data
-    loaded_data = await session_init_store.load()
+    loaded_data = await file_settings_store.load()
     assert loaded_data is not None
     assert loaded_data.language == init_data.language
     assert loaded_data.agent == init_data.agent
@@ -69,7 +69,7 @@ async def test_get_instance():
     config = AppConfig(file_store='local', file_store_path='/test/path')
 
     with patch(
-        'openhands.storage.file_settings_store.get_file_store'
+        'openhands.storage.settings.file_settings_store.get_file_store'
     ) as mock_get_store:
         mock_store = MagicMock(spec=FileStore)
         mock_get_store.return_value = mock_store