test_manager.py 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233
  1. import asyncio
  2. import json
  3. from dataclasses import dataclass
  4. from unittest.mock import AsyncMock, MagicMock, patch
  5. import pytest
  6. from openhands.core.config.app_config import AppConfig
  7. from openhands.server.session.manager import SessionManager
  8. from openhands.server.session.session_init_data import SessionInitData
  9. from openhands.storage.memory import InMemoryFileStore
  10. @dataclass
  11. class GetMessageMock:
  12. message: dict | None
  13. sleep_time: int = 0.01
  14. async def get_message(self, **kwargs):
  15. await asyncio.sleep(self.sleep_time)
  16. return {'data': json.dumps(self.message)}
  17. def get_mock_sio(get_message: GetMessageMock | None = None):
  18. sio = MagicMock()
  19. sio.enter_room = AsyncMock()
  20. sio.manager.redis = MagicMock()
  21. sio.manager.redis.publish = AsyncMock()
  22. pubsub = AsyncMock()
  23. pubsub.get_message = (get_message or GetMessageMock(None)).get_message
  24. sio.manager.redis.pubsub.return_value = pubsub
  25. return sio
  26. @pytest.mark.asyncio
  27. async def test_session_not_running_in_cluster():
  28. sio = get_mock_sio()
  29. with (
  30. patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.01),
  31. ):
  32. async with SessionManager(
  33. sio, AppConfig(), InMemoryFileStore()
  34. ) as session_manager:
  35. result = await session_manager._is_session_running_in_cluster(
  36. 'non-existant-session'
  37. )
  38. assert result is False
  39. assert sio.manager.redis.publish.await_count == 1
  40. sio.manager.redis.publish.assert_called_once_with(
  41. 'oh_event',
  42. '{"sid": "non-existant-session", "message_type": "is_session_running"}',
  43. )
  44. @pytest.mark.asyncio
  45. async def test_session_is_running_in_cluster():
  46. sio = get_mock_sio(
  47. GetMessageMock(
  48. {'sid': 'existing-session', 'message_type': 'session_is_running'}
  49. )
  50. )
  51. with (
  52. patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.1),
  53. ):
  54. async with SessionManager(
  55. sio, AppConfig(), InMemoryFileStore()
  56. ) as session_manager:
  57. result = await session_manager._is_session_running_in_cluster(
  58. 'existing-session'
  59. )
  60. assert result is True
  61. assert sio.manager.redis.publish.await_count == 1
  62. sio.manager.redis.publish.assert_called_once_with(
  63. 'oh_event',
  64. '{"sid": "existing-session", "message_type": "is_session_running"}',
  65. )
  66. @pytest.mark.asyncio
  67. async def test_init_new_local_session():
  68. session_instance = AsyncMock()
  69. session_instance.agent_session = MagicMock()
  70. mock_session = MagicMock()
  71. mock_session.return_value = session_instance
  72. sio = get_mock_sio()
  73. is_session_running_in_cluster_mock = AsyncMock()
  74. is_session_running_in_cluster_mock.return_value = False
  75. with (
  76. patch('openhands.server.session.manager.Session', mock_session),
  77. patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.1),
  78. patch(
  79. 'openhands.server.session.manager.SessionManager._redis_subscribe',
  80. AsyncMock(),
  81. ),
  82. patch(
  83. 'openhands.server.session.manager.SessionManager._is_session_running_in_cluster',
  84. is_session_running_in_cluster_mock,
  85. ),
  86. ):
  87. async with SessionManager(
  88. sio, AppConfig(), InMemoryFileStore()
  89. ) as session_manager:
  90. await session_manager.start_agent_loop('new-session-id', SessionInitData())
  91. await session_manager.join_conversation('new-session-id', 'new-session-id')
  92. assert session_instance.initialize_agent.call_count == 1
  93. assert sio.enter_room.await_count == 1
  94. @pytest.mark.asyncio
  95. async def test_join_local_session():
  96. session_instance = AsyncMock()
  97. session_instance.agent_session = MagicMock()
  98. mock_session = MagicMock()
  99. mock_session.return_value = session_instance
  100. sio = get_mock_sio()
  101. is_session_running_in_cluster_mock = AsyncMock()
  102. is_session_running_in_cluster_mock.return_value = False
  103. with (
  104. patch('openhands.server.session.manager.Session', mock_session),
  105. patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.01),
  106. patch(
  107. 'openhands.server.session.manager.SessionManager._redis_subscribe',
  108. AsyncMock(),
  109. ),
  110. patch(
  111. 'openhands.server.session.manager.SessionManager._is_session_running_in_cluster',
  112. is_session_running_in_cluster_mock,
  113. ),
  114. ):
  115. async with SessionManager(
  116. sio, AppConfig(), InMemoryFileStore()
  117. ) as session_manager:
  118. await session_manager.start_agent_loop('new-session-id', SessionInitData())
  119. await session_manager.join_conversation('new-session-id', 'new-session-id')
  120. await session_manager.join_conversation('new-session-id', 'new-session-id')
  121. assert session_instance.initialize_agent.call_count == 1
  122. assert sio.enter_room.await_count == 2
  123. @pytest.mark.asyncio
  124. async def test_join_cluster_session():
  125. session_instance = AsyncMock()
  126. session_instance.agent_session = MagicMock()
  127. mock_session = MagicMock()
  128. mock_session.return_value = session_instance
  129. sio = get_mock_sio()
  130. is_session_running_in_cluster_mock = AsyncMock()
  131. is_session_running_in_cluster_mock.return_value = True
  132. with (
  133. patch('openhands.server.session.manager.Session', mock_session),
  134. patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.01),
  135. patch(
  136. 'openhands.server.session.manager.SessionManager._redis_subscribe',
  137. AsyncMock(),
  138. ),
  139. patch(
  140. 'openhands.server.session.manager.SessionManager._is_session_running_in_cluster',
  141. is_session_running_in_cluster_mock,
  142. ),
  143. ):
  144. async with SessionManager(
  145. sio, AppConfig(), InMemoryFileStore()
  146. ) as session_manager:
  147. await session_manager.join_conversation('new-session-id', 'new-session-id')
  148. assert session_instance.initialize_agent.call_count == 0
  149. assert sio.enter_room.await_count == 1
  150. @pytest.mark.asyncio
  151. async def test_add_to_local_event_stream():
  152. session_instance = AsyncMock()
  153. session_instance.agent_session = MagicMock()
  154. mock_session = MagicMock()
  155. mock_session.return_value = session_instance
  156. sio = get_mock_sio()
  157. is_session_running_in_cluster_mock = AsyncMock()
  158. is_session_running_in_cluster_mock.return_value = False
  159. with (
  160. patch('openhands.server.session.manager.Session', mock_session),
  161. patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.01),
  162. patch(
  163. 'openhands.server.session.manager.SessionManager._redis_subscribe',
  164. AsyncMock(),
  165. ),
  166. patch(
  167. 'openhands.server.session.manager.SessionManager._is_session_running_in_cluster',
  168. is_session_running_in_cluster_mock,
  169. ),
  170. ):
  171. async with SessionManager(
  172. sio, AppConfig(), InMemoryFileStore()
  173. ) as session_manager:
  174. await session_manager.start_agent_loop('new-session-id', SessionInitData())
  175. await session_manager.join_conversation('new-session-id', 'connection-id')
  176. await session_manager.send_to_event_stream(
  177. 'connection-id', {'event_type': 'some_event'}
  178. )
  179. session_instance.dispatch.assert_called_once_with({'event_type': 'some_event'})
  180. @pytest.mark.asyncio
  181. async def test_add_to_cluster_event_stream():
  182. session_instance = AsyncMock()
  183. session_instance.agent_session = MagicMock()
  184. mock_session = MagicMock()
  185. mock_session.return_value = session_instance
  186. sio = get_mock_sio()
  187. is_session_running_in_cluster_mock = AsyncMock()
  188. is_session_running_in_cluster_mock.return_value = True
  189. with (
  190. patch('openhands.server.session.manager.Session', mock_session),
  191. patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.01),
  192. patch(
  193. 'openhands.server.session.manager.SessionManager._redis_subscribe',
  194. AsyncMock(),
  195. ),
  196. patch(
  197. 'openhands.server.session.manager.SessionManager._is_session_running_in_cluster',
  198. is_session_running_in_cluster_mock,
  199. ),
  200. ):
  201. async with SessionManager(
  202. sio, AppConfig(), InMemoryFileStore()
  203. ) as session_manager:
  204. await session_manager.join_conversation('new-session-id', 'connection-id')
  205. await session_manager.send_to_event_stream(
  206. 'connection-id', {'event_type': 'some_event'}
  207. )
  208. assert sio.manager.redis.publish.await_count == 1
  209. sio.manager.redis.publish.assert_called_once_with(
  210. 'oh_event',
  211. '{"sid": "new-session-id", "message_type": "event", "data": {"event_type": "some_event"}}',
  212. )