| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245 |
- import asyncio
- import json
- from dataclasses import dataclass
- from unittest.mock import AsyncMock, MagicMock, patch
- import pytest
- from openhands.core.config.app_config import AppConfig
- from openhands.server.session.manager import SessionManager
- from openhands.server.session.session_init_data import SessionInitData
- from openhands.storage.memory import InMemoryFileStore
- @dataclass
- class GetMessageMock:
- message: dict | None
- sleep_time: int = 0.01
- async def get_message(self, **kwargs):
- await asyncio.sleep(self.sleep_time)
- return {'data': json.dumps(self.message)}
def get_mock_sio(get_message: GetMessageMock | None = None):
    """Build a mocked socket.io server backed by a fake Redis client.

    Args:
        get_message: Source of the pubsub ``get_message`` coroutine. When
            omitted, a ``GetMessageMock(None)`` is used instead.

    Returns:
        A ``MagicMock`` with an awaitable ``enter_room``, an awaitable
        ``manager.redis.publish`` and a ``manager.redis.pubsub()`` call that
        returns the prepared pubsub mock.
    """
    message_source = get_message or GetMessageMock(None)

    # The pubsub object replies through the supplied message source.
    fake_pubsub = AsyncMock()
    fake_pubsub.get_message = message_source.get_message

    fake_redis = MagicMock()
    fake_redis.publish = AsyncMock()
    fake_redis.pubsub.return_value = fake_pubsub

    sio = MagicMock()
    sio.enter_room = AsyncMock()
    sio.manager.redis = fake_redis
    return sio
@pytest.mark.asyncio
async def test_session_not_running_in_cluster():
    """An unknown session id is reported as not running in the cluster.

    Uses the default mocked pubsub and checks that the query returns
    ``False`` after publishing exactly one ``is_session_running`` request
    on the ``oh_event`` channel.
    """
    sio = get_mock_sio()
    publish = sio.manager.redis.publish
    # Shorten the Redis poll timeout so the negative answer comes back quickly.
    with patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.01):
        manager = SessionManager(sio, AppConfig(), InMemoryFileStore())
        async with manager as session_manager:
            result = await session_manager._is_session_running_in_cluster(
                'non-existant-session'
            )
            assert result is False
            assert publish.await_count == 1
            publish.assert_called_once_with(
                'oh_event',
                '{"sid": "non-existant-session", "message_type": "is_session_running"}',
            )
@pytest.mark.asyncio
async def test_session_is_running_in_cluster():
    """A ``session_is_running`` reply makes the cluster query return ``True``.

    The mocked pubsub answers with a ``session_is_running`` message for
    ``existing-session``; the test also checks that exactly one
    ``is_session_running`` request was published on ``oh_event``.
    """
    reply = GetMessageMock(
        {'sid': 'existing-session', 'message_type': 'session_is_running'}
    )
    sio = get_mock_sio(reply)
    publish = sio.manager.redis.publish
    with patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.05):
        manager = SessionManager(sio, AppConfig(), InMemoryFileStore())
        async with manager as session_manager:
            result = await session_manager._is_session_running_in_cluster(
                'existing-session'
            )
            assert result is True
            assert publish.await_count == 1
            publish.assert_called_once_with(
                'oh_event',
                '{"sid": "existing-session", "message_type": "is_session_running"}',
            )
@pytest.mark.asyncio
async def test_init_new_local_session():
    """A session not running in the cluster is initialised locally.

    With the cluster lookup patched to answer ``False``, one call to
    ``init_or_join_session`` must initialise the agent once and enter one
    socket.io room.
    """
    session_instance = AsyncMock(agent_session=MagicMock())
    session_factory = MagicMock(return_value=session_instance)
    sio = get_mock_sio()
    not_running = AsyncMock(return_value=False)

    session_patch = patch(
        'openhands.server.session.manager.Session', session_factory
    )
    timeout_patch = patch(
        'openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.01
    )
    subscribe_patch = patch(
        'openhands.server.session.manager.SessionManager._redis_subscribe',
        AsyncMock(),
    )
    cluster_patch = patch(
        'openhands.server.session.manager.SessionManager._is_session_running_in_cluster',
        not_running,
    )
    with session_patch, timeout_patch, subscribe_patch, cluster_patch:
        manager = SessionManager(sio, AppConfig(), InMemoryFileStore())
        async with manager as session_manager:
            await session_manager.init_or_join_session(
                'new-session-id', 'new-session-id', SessionInitData()
            )

    assert session_instance.initialize_agent.call_count == 1
    assert sio.enter_room.await_count == 1
@pytest.mark.asyncio
async def test_join_local_session():
    """A second connection joins an existing local session.

    Two ``init_or_join_session`` calls for the same session id must
    initialise the agent only once while entering a room for each
    connection.
    """
    session_instance = AsyncMock(agent_session=MagicMock())
    session_factory = MagicMock(return_value=session_instance)
    sio = get_mock_sio()
    not_running = AsyncMock(return_value=False)

    session_patch = patch(
        'openhands.server.session.manager.Session', session_factory
    )
    timeout_patch = patch(
        'openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.01
    )
    subscribe_patch = patch(
        'openhands.server.session.manager.SessionManager._redis_subscribe',
        AsyncMock(),
    )
    cluster_patch = patch(
        'openhands.server.session.manager.SessionManager._is_session_running_in_cluster',
        not_running,
    )
    with session_patch, timeout_patch, subscribe_patch, cluster_patch:
        manager = SessionManager(sio, AppConfig(), InMemoryFileStore())
        async with manager as session_manager:
            # The first connection creates the session ...
            await session_manager.init_or_join_session(
                'new-session-id', 'new-session-id', SessionInitData()
            )
            # ... and the second connection attaches to it.
            await session_manager.init_or_join_session(
                'new-session-id', 'extra-connection-id', SessionInitData()
            )

    assert session_instance.initialize_agent.call_count == 1
    assert sio.enter_room.await_count == 2
@pytest.mark.asyncio
async def test_join_cluster_session():
    """A session already running in the cluster is joined, not started.

    With the cluster lookup patched to answer ``True``, the call must not
    initialise a local agent but must still enter one socket.io room.
    """
    session_instance = AsyncMock(agent_session=MagicMock())
    session_factory = MagicMock(return_value=session_instance)
    sio = get_mock_sio()
    already_running = AsyncMock(return_value=True)

    session_patch = patch(
        'openhands.server.session.manager.Session', session_factory
    )
    timeout_patch = patch(
        'openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.01
    )
    subscribe_patch = patch(
        'openhands.server.session.manager.SessionManager._redis_subscribe',
        AsyncMock(),
    )
    cluster_patch = patch(
        'openhands.server.session.manager.SessionManager._is_session_running_in_cluster',
        already_running,
    )
    with session_patch, timeout_patch, subscribe_patch, cluster_patch:
        manager = SessionManager(sio, AppConfig(), InMemoryFileStore())
        async with manager as session_manager:
            # The cluster lookup answers True, so no local agent is expected.
            await session_manager.init_or_join_session(
                'new-session-id', 'new-session-id', SessionInitData()
            )

    assert session_instance.initialize_agent.call_count == 0
    assert sio.enter_room.await_count == 1
@pytest.mark.asyncio
async def test_add_to_local_event_stream():
    """Events for a locally hosted session are dispatched to that session.

    After initialising a local session over ``connection-id``, sending an
    event through the same connection must call the session's ``dispatch``
    exactly once with the event payload.
    """
    session_instance = AsyncMock(agent_session=MagicMock())
    session_factory = MagicMock(return_value=session_instance)
    sio = get_mock_sio()
    not_running = AsyncMock(return_value=False)

    session_patch = patch(
        'openhands.server.session.manager.Session', session_factory
    )
    timeout_patch = patch(
        'openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.01
    )
    subscribe_patch = patch(
        'openhands.server.session.manager.SessionManager._redis_subscribe',
        AsyncMock(),
    )
    cluster_patch = patch(
        'openhands.server.session.manager.SessionManager._is_session_running_in_cluster',
        not_running,
    )
    with session_patch, timeout_patch, subscribe_patch, cluster_patch:
        manager = SessionManager(sio, AppConfig(), InMemoryFileStore())
        async with manager as session_manager:
            await session_manager.init_or_join_session(
                'new-session-id', 'connection-id', SessionInitData()
            )
            await session_manager.send_to_event_stream(
                'connection-id', {'event_type': 'some_event'}
            )

    session_instance.dispatch.assert_called_once_with({'event_type': 'some_event'})
@pytest.mark.asyncio
async def test_add_to_cluster_event_stream():
    """Events for a cluster-hosted session are published to Redis.

    With the cluster lookup patched to answer ``True``, sending an event
    must publish exactly one ``event`` message for the session on the
    ``oh_event`` channel.
    """
    session_instance = AsyncMock(agent_session=MagicMock())
    session_factory = MagicMock(return_value=session_instance)
    sio = get_mock_sio()
    publish = sio.manager.redis.publish
    already_running = AsyncMock(return_value=True)

    session_patch = patch(
        'openhands.server.session.manager.Session', session_factory
    )
    timeout_patch = patch(
        'openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.01
    )
    subscribe_patch = patch(
        'openhands.server.session.manager.SessionManager._redis_subscribe',
        AsyncMock(),
    )
    cluster_patch = patch(
        'openhands.server.session.manager.SessionManager._is_session_running_in_cluster',
        already_running,
    )
    with session_patch, timeout_patch, subscribe_patch, cluster_patch:
        manager = SessionManager(sio, AppConfig(), InMemoryFileStore())
        async with manager as session_manager:
            await session_manager.init_or_join_session(
                'new-session-id', 'connection-id', SessionInitData()
            )
            await session_manager.send_to_event_stream(
                'connection-id', {'event_type': 'some_event'}
            )

    assert publish.await_count == 1
    publish.assert_called_once_with(
        'oh_event',
        '{"sid": "new-session-id", "message_type": "event", "data": {"event_type": "some_event"}}',
    )
|