Просмотр исходного кода

Feat: Allow checking multiple conversations running at the same time (#5843)

tofarr 1 год назад
Родитель
Сommit
500598666e

+ 68 - 28
openhands/server/session/manager.py

@@ -2,6 +2,7 @@ import asyncio
 import json
 import time
 from dataclasses import dataclass, field
+from uuid import uuid4
 
 import socketio
 
@@ -27,6 +28,14 @@ class ConversationDoesNotExistError(Exception):
     pass
 
 
+@dataclass
+class _SessionIsRunningCheck:
+    request_id: str
+    request_sids: list[str]
+    running_sids: set[str] = field(default_factory=set)
+    flag: asyncio.Event = field(default_factory=asyncio.Event)
+
+
 @dataclass
 class SessionManager:
     sio: socketio.AsyncServer
@@ -36,7 +45,9 @@ class SessionManager:
     local_connection_id_to_session_id: dict[str, str] = field(default_factory=dict)
     _last_alive_timestamps: dict[str, float] = field(default_factory=dict)
     _redis_listen_task: asyncio.Task | None = None
-    _session_is_running_flags: dict[str, asyncio.Event] = field(default_factory=dict)
+    _session_is_running_checks: dict[str, _SessionIsRunningCheck] = field(
+        default_factory=dict
+    )
     _active_conversations: dict[str, tuple[Conversation, int]] = field(
         default_factory=dict
     )
@@ -97,27 +108,41 @@ class SessionManager:
     async def _process_message(self, message: dict):
         data = json.loads(message['data'])
         logger.debug(f'got_published_message:{message}')
-        sid = data['sid']
         message_type = data['message_type']
         if message_type == 'event':
+            sid = data['sid']
             session = self._local_agent_loops_by_sid.get(sid)
             if session:
                 await session.dispatch(data['data'])
         elif message_type == 'is_session_running':
             # Another node in the cluster is asking if the current node is running the session given.
-            session = self._local_agent_loops_by_sid.get(sid)
-            if session:
+            request_id = data['request_id']
+            sids = [
+                sid for sid in data['sids'] if sid in self._local_agent_loops_by_sid
+            ]
+            if sids:
                 await self._get_redis_client().publish(
                     'oh_event',
-                    json.dumps({'sid': sid, 'message_type': 'session_is_running'}),
+                    json.dumps(
+                        {
+                            'request_id': request_id,
+                            'sids': sids,
+                            'message_type': 'session_is_running',
+                        }
+                    ),
                 )
         elif message_type == 'session_is_running':
-            self._last_alive_timestamps[sid] = time.time()
-            flag = self._session_is_running_flags.get(sid)
-            if flag:
-                flag.set()
+            request_id = data['request_id']
+            for sid in data['sids']:
+                self._last_alive_timestamps[sid] = time.time()
+            check = self._session_is_running_checks.get(request_id)
+            if check:
+                check.running_sids.update(data['sids'])
+                if len(check.request_sids) == len(check.running_sids):
+                    check.flag.set()
         elif message_type == 'has_remote_connections_query':
             # Another node in the cluster is asking if the current node is connected to a session
+            sid = data['sid']
             required = sid in self.local_connection_id_to_session_id.values()
             if required:
                 await self._get_redis_client().publish(
@@ -127,12 +152,14 @@ class SessionManager:
                     ),
                 )
         elif message_type == 'has_remote_connections_response':
+            sid = data['sid']
             flag = self._has_remote_connections_flags.get(sid)
             if flag:
                 flag.set()
         elif message_type == 'session_closing':
             # Session closing event - We only get this in the event of graceful shutdown,
             # which can't be guaranteed - nodes can simply vanish unexpectedly!
+            sid = data['sid']
             logger.debug(f'session_closing:{sid}')
             for (
                 connection_id,
@@ -234,33 +261,47 @@ class SessionManager:
                 logger.warning('error_cleaning_detached_conversations', exc_info=True)
                 await asyncio.sleep(_CLEANUP_EXCEPTION_WAIT_TIME)
 
-    async def _is_agent_loop_running(self, sid: str) -> bool:
-        if await self._is_agent_loop_running_locally(sid):
+    async def get_agent_loop_running(self, sids: set[str]) -> set[str]:
+        running_sids = set(sid for sid in sids if sid in self._local_agent_loops_by_sid)
+        check_cluster_sids = [sid for sid in sids if sid not in running_sids]
+        running_cluster_sids = await self.get_agent_loop_running_in_cluster(
+            check_cluster_sids
+        )
+        running_sids.union(running_cluster_sids)
+        return running_sids
+
+    async def is_agent_loop_running(self, sid: str) -> bool:
+        if await self.is_agent_loop_running_locally(sid):
             return True
-        if await self._is_agent_loop_running_in_cluster(sid):
+        if await self.is_agent_loop_running_in_cluster(sid):
             return True
         return False
 
-    async def _is_agent_loop_running_locally(self, sid: str) -> bool:
-        if self._local_agent_loops_by_sid.get(sid, None):
-            return True
-        return False
+    async def is_agent_loop_running_locally(self, sid: str) -> bool:
+        return sid in self._local_agent_loops_by_sid
+
+    async def is_agent_loop_running_in_cluster(self, sid: str) -> bool:
+        running_sids = await self.get_agent_loop_running_in_cluster([sid])
+        return bool(running_sids)
 
-    async def _is_agent_loop_running_in_cluster(self, sid: str) -> bool:
+    async def get_agent_loop_running_in_cluster(self, sids: list[str]) -> set[str]:
         """As the rest of the cluster if a session is running. Wait a for a short timeout for a reply"""
         redis_client = self._get_redis_client()
         if not redis_client:
-            return False
+            return set()
 
         flag = asyncio.Event()
-        self._session_is_running_flags[sid] = flag
+        request_id = str(uuid4())
+        check = _SessionIsRunningCheck(request_id=request_id, request_sids=sids)
+        self._session_is_running_checks[request_id] = check
         try:
-            logger.debug(f'publish:is_session_running:{sid}')
+            logger.debug(f'publish:is_session_running:{sids}')
             await redis_client.publish(
                 'oh_event',
                 json.dumps(
                     {
-                        'sid': sid,
+                        'request_id': request_id,
+                        'sids': sids,
                         'message_type': 'is_session_running',
                     }
                 ),
@@ -268,13 +309,12 @@ class SessionManager:
             async with asyncio.timeout(_REDIS_POLL_TIMEOUT):
                 await flag.wait()
 
-            result = flag.is_set()
-            return result
+            return check.running_sids
         except TimeoutError:
             # Nobody replied in time
-            return False
+            return check.running_sids
         finally:
-            self._session_is_running_flags.pop(sid, None)
+            self._session_is_running_checks.pop(request_id, None)
 
     async def _has_remote_connections(self, sid: str) -> bool:
         """As the rest of the cluster if they still want this session running. Wait a for a short timeout for a reply"""
@@ -307,7 +347,7 @@ class SessionManager:
     ) -> EventStream:
         logger.info(f'maybe_start_agent_loop:{sid}')
         session: Session | None = None
-        if not await self._is_agent_loop_running(sid):
+        if not await self.is_agent_loop_running(sid):
             logger.info(f'start_agent_loop:{sid}')
             session = Session(
                 sid=sid, file_store=self.file_store, config=self.config, sio=self.sio
@@ -328,7 +368,7 @@ class SessionManager:
             logger.info(f'found_local_agent_loop:{sid}')
             return session.agent_session.event_stream
 
-        if await self._is_agent_loop_running_in_cluster(sid):
+        if await self.is_agent_loop_running_in_cluster(sid):
             logger.info(f'found_remote_agent_loop:{sid}')
             return EventStream(sid, self.file_store)
 
@@ -352,7 +392,7 @@ class SessionManager:
             next_alive_check = last_alive_at + _CHECK_ALIVE_INTERVAL
             if (
                 next_alive_check > time.time()
-                or await self._is_agent_loop_running_in_cluster(sid)
+                or await self.is_agent_loop_running_in_cluster(sid)
             ):
                 # Send the event to the other pod
                 await redis_client.publish(

+ 2 - 2
openhands/storage/memory.py

@@ -9,8 +9,8 @@ IN_MEMORY_FILES: dict = {}
 class InMemoryFileStore(FileStore):
     files: dict[str, str]
 
-    def __init__(self):
-        self.files = IN_MEMORY_FILES
+    def __init__(self, files: dict[str, str] = IN_MEMORY_FILES):
+        self.files = files
 
     def write(self, path: str, contents: str) -> None:
         self.files[path] = contents

+ 2 - 1
tests/unit/test_agent_controller.py

@@ -20,6 +20,7 @@ from openhands.llm import LLM
 from openhands.llm.metrics import Metrics
 from openhands.runtime.base import Runtime
 from openhands.storage import get_file_store
+from openhands.storage.memory import InMemoryFileStore
 
 
 @pytest.fixture
@@ -168,7 +169,7 @@ async def test_run_controller_with_fatal_error(mock_agent, mock_event_stream):
 @pytest.mark.asyncio
 async def test_run_controller_stop_with_stuck():
     config = AppConfig()
-    file_store = get_file_store(config.file_store, config.file_store_path)
+    file_store = InMemoryFileStore({})
     event_stream = EventStream(sid='test', file_store=file_store)
 
     agent = MagicMock(spec=Agent)

+ 23 - 10
tests/unit/test_manager.py

@@ -2,6 +2,7 @@ import asyncio
 import json
 from dataclasses import dataclass
 from unittest.mock import AsyncMock, MagicMock, patch
+from uuid import uuid4
 
 import pytest
 
@@ -35,44 +36,56 @@ def get_mock_sio(get_message: GetMessageMock | None = None):
 @pytest.mark.asyncio
 async def test_session_not_running_in_cluster():
     sio = get_mock_sio()
+    id = uuid4()
     with (
         patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.01),
+        patch('openhands.server.session.manager.uuid4', MagicMock(return_value=id)),
     ):
         async with SessionManager(
             sio, AppConfig(), InMemoryFileStore()
         ) as session_manager:
-            result = await session_manager._is_agent_loop_running_in_cluster(
+            result = await session_manager.is_agent_loop_running_in_cluster(
                 'non-existant-session'
             )
             assert result is False
             assert sio.manager.redis.publish.await_count == 1
             sio.manager.redis.publish.assert_called_once_with(
                 'oh_event',
-                '{"sid": "non-existant-session", "message_type": "is_session_running"}',
+                '{"request_id": "'
+                + str(id)
+                + '", "sids": ["non-existant-session"], "message_type": "is_session_running"}',
             )
 
 
 @pytest.mark.asyncio
 async def test_session_is_running_in_cluster():
+    id = uuid4()
     sio = get_mock_sio(
         GetMessageMock(
-            {'sid': 'existing-session', 'message_type': 'session_is_running'}
+            {
+                'request_id': str(id),
+                'sids': ['existing-session'],
+                'message_type': 'session_is_running',
+            }
         )
     )
     with (
         patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.1),
+        patch('openhands.server.session.manager.uuid4', MagicMock(return_value=id)),
     ):
         async with SessionManager(
             sio, AppConfig(), InMemoryFileStore()
         ) as session_manager:
-            result = await session_manager._is_agent_loop_running_in_cluster(
+            result = await session_manager.is_agent_loop_running_in_cluster(
                 'existing-session'
             )
             assert result is True
             assert sio.manager.redis.publish.await_count == 1
             sio.manager.redis.publish.assert_called_once_with(
                 'oh_event',
-                '{"sid": "existing-session", "message_type": "is_session_running"}',
+                '{"request_id": "'
+                + str(id)
+                + '", "sids": ["existing-session"], "message_type": "is_session_running"}',
             )
 
 
@@ -93,7 +106,7 @@ async def test_init_new_local_session():
             AsyncMock(),
         ),
         patch(
-            'openhands.server.session.manager.SessionManager._is_agent_loop_running_in_cluster',
+            'openhands.server.session.manager.SessionManager.is_agent_loop_running_in_cluster',
             is_agent_loop_running_in_cluster_mock,
         ),
     ):
@@ -125,7 +138,7 @@ async def test_join_local_session():
             AsyncMock(),
         ),
         patch(
-            'openhands.server.session.manager.SessionManager._is_agent_loop_running_in_cluster',
+            'openhands.server.session.manager.SessionManager.is_agent_loop_running_in_cluster',
             is_agent_loop_running_in_cluster_mock,
         ),
     ):
@@ -158,7 +171,7 @@ async def test_join_cluster_session():
             AsyncMock(),
         ),
         patch(
-            'openhands.server.session.manager.SessionManager._is_agent_loop_running_in_cluster',
+            'openhands.server.session.manager.SessionManager.is_agent_loop_running_in_cluster',
             is_agent_loop_running_in_cluster_mock,
         ),
     ):
@@ -187,7 +200,7 @@ async def test_add_to_local_event_stream():
             AsyncMock(),
         ),
         patch(
-            'openhands.server.session.manager.SessionManager._is_agent_loop_running_in_cluster',
+            'openhands.server.session.manager.SessionManager.is_agent_loop_running_in_cluster',
             is_agent_loop_running_in_cluster_mock,
         ),
     ):
@@ -221,7 +234,7 @@ async def test_add_to_cluster_event_stream():
             AsyncMock(),
         ),
         patch(
-            'openhands.server.session.manager.SessionManager._is_agent_loop_running_in_cluster',
+            'openhands.server.session.manager.SessionManager.is_agent_loop_running_in_cluster',
             is_agent_loop_running_in_cluster_mock,
         ),
     ):