test_manager.py 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245
  1. import asyncio
  2. import json
  3. from dataclasses import dataclass
  4. from unittest.mock import AsyncMock, MagicMock, patch
  5. import pytest
  6. from openhands.core.config.app_config import AppConfig
  7. from openhands.server.session.manager import SessionManager
  8. from openhands.server.session.session_init_data import SessionInitData
  9. from openhands.storage.memory import InMemoryFileStore
  10. @dataclass
  11. class GetMessageMock:
  12. message: dict | None
  13. sleep_time: int = 0.01
  14. async def get_message(self, **kwargs):
  15. await asyncio.sleep(self.sleep_time)
  16. return {'data': json.dumps(self.message)}
  17. def get_mock_sio(get_message: GetMessageMock | None = None):
  18. sio = MagicMock()
  19. sio.enter_room = AsyncMock()
  20. sio.manager.redis = MagicMock()
  21. sio.manager.redis.publish = AsyncMock()
  22. pubsub = AsyncMock()
  23. pubsub.get_message = (get_message or GetMessageMock(None)).get_message
  24. sio.manager.redis.pubsub.return_value = pubsub
  25. return sio
  26. @pytest.mark.asyncio
  27. async def test_session_not_running_in_cluster():
  28. sio = get_mock_sio()
  29. with (
  30. patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.01),
  31. ):
  32. async with SessionManager(
  33. sio, AppConfig(), InMemoryFileStore()
  34. ) as session_manager:
  35. result = await session_manager._is_session_running_in_cluster(
  36. 'non-existant-session'
  37. )
  38. assert result is False
  39. assert sio.manager.redis.publish.await_count == 1
  40. sio.manager.redis.publish.assert_called_once_with(
  41. 'oh_event',
  42. '{"sid": "non-existant-session", "message_type": "is_session_running"}',
  43. )
  44. @pytest.mark.asyncio
  45. async def test_session_is_running_in_cluster():
  46. sio = get_mock_sio(
  47. GetMessageMock(
  48. {'sid': 'existing-session', 'message_type': 'session_is_running'}
  49. )
  50. )
  51. with (
  52. patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.05),
  53. ):
  54. async with SessionManager(
  55. sio, AppConfig(), InMemoryFileStore()
  56. ) as session_manager:
  57. result = await session_manager._is_session_running_in_cluster(
  58. 'existing-session'
  59. )
  60. assert result is True
  61. assert sio.manager.redis.publish.await_count == 1
  62. sio.manager.redis.publish.assert_called_once_with(
  63. 'oh_event',
  64. '{"sid": "existing-session", "message_type": "is_session_running"}',
  65. )
  66. @pytest.mark.asyncio
  67. async def test_init_new_local_session():
  68. session_instance = AsyncMock()
  69. session_instance.agent_session = MagicMock()
  70. mock_session = MagicMock()
  71. mock_session.return_value = session_instance
  72. sio = get_mock_sio()
  73. is_session_running_in_cluster_mock = AsyncMock()
  74. is_session_running_in_cluster_mock.return_value = False
  75. with (
  76. patch('openhands.server.session.manager.Session', mock_session),
  77. patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.01),
  78. patch(
  79. 'openhands.server.session.manager.SessionManager._redis_subscribe',
  80. AsyncMock(),
  81. ),
  82. patch(
  83. 'openhands.server.session.manager.SessionManager._is_session_running_in_cluster',
  84. is_session_running_in_cluster_mock,
  85. ),
  86. ):
  87. async with SessionManager(
  88. sio, AppConfig(), InMemoryFileStore()
  89. ) as session_manager:
  90. await session_manager.init_or_join_session(
  91. 'new-session-id', 'new-session-id', SessionInitData()
  92. )
  93. assert session_instance.initialize_agent.call_count == 1
  94. assert sio.enter_room.await_count == 1
  95. @pytest.mark.asyncio
  96. async def test_join_local_session():
  97. session_instance = AsyncMock()
  98. session_instance.agent_session = MagicMock()
  99. mock_session = MagicMock()
  100. mock_session.return_value = session_instance
  101. sio = get_mock_sio()
  102. is_session_running_in_cluster_mock = AsyncMock()
  103. is_session_running_in_cluster_mock.return_value = False
  104. with (
  105. patch('openhands.server.session.manager.Session', mock_session),
  106. patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.01),
  107. patch(
  108. 'openhands.server.session.manager.SessionManager._redis_subscribe',
  109. AsyncMock(),
  110. ),
  111. patch(
  112. 'openhands.server.session.manager.SessionManager._is_session_running_in_cluster',
  113. is_session_running_in_cluster_mock,
  114. ),
  115. ):
  116. async with SessionManager(
  117. sio, AppConfig(), InMemoryFileStore()
  118. ) as session_manager:
  119. # First call initializes
  120. await session_manager.init_or_join_session(
  121. 'new-session-id', 'new-session-id', SessionInitData()
  122. )
  123. # Second call joins
  124. await session_manager.init_or_join_session(
  125. 'new-session-id', 'extra-connection-id', SessionInitData()
  126. )
  127. assert session_instance.initialize_agent.call_count == 1
  128. assert sio.enter_room.await_count == 2
  129. @pytest.mark.asyncio
  130. async def test_join_cluster_session():
  131. session_instance = AsyncMock()
  132. session_instance.agent_session = MagicMock()
  133. mock_session = MagicMock()
  134. mock_session.return_value = session_instance
  135. sio = get_mock_sio()
  136. is_session_running_in_cluster_mock = AsyncMock()
  137. is_session_running_in_cluster_mock.return_value = True
  138. with (
  139. patch('openhands.server.session.manager.Session', mock_session),
  140. patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.01),
  141. patch(
  142. 'openhands.server.session.manager.SessionManager._redis_subscribe',
  143. AsyncMock(),
  144. ),
  145. patch(
  146. 'openhands.server.session.manager.SessionManager._is_session_running_in_cluster',
  147. is_session_running_in_cluster_mock,
  148. ),
  149. ):
  150. async with SessionManager(
  151. sio, AppConfig(), InMemoryFileStore()
  152. ) as session_manager:
  153. # First call initializes
  154. await session_manager.init_or_join_session(
  155. 'new-session-id', 'new-session-id', SessionInitData()
  156. )
  157. assert session_instance.initialize_agent.call_count == 0
  158. assert sio.enter_room.await_count == 1
  159. @pytest.mark.asyncio
  160. async def test_add_to_local_event_stream():
  161. session_instance = AsyncMock()
  162. session_instance.agent_session = MagicMock()
  163. mock_session = MagicMock()
  164. mock_session.return_value = session_instance
  165. sio = get_mock_sio()
  166. is_session_running_in_cluster_mock = AsyncMock()
  167. is_session_running_in_cluster_mock.return_value = False
  168. with (
  169. patch('openhands.server.session.manager.Session', mock_session),
  170. patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.01),
  171. patch(
  172. 'openhands.server.session.manager.SessionManager._redis_subscribe',
  173. AsyncMock(),
  174. ),
  175. patch(
  176. 'openhands.server.session.manager.SessionManager._is_session_running_in_cluster',
  177. is_session_running_in_cluster_mock,
  178. ),
  179. ):
  180. async with SessionManager(
  181. sio, AppConfig(), InMemoryFileStore()
  182. ) as session_manager:
  183. await session_manager.init_or_join_session(
  184. 'new-session-id', 'connection-id', SessionInitData()
  185. )
  186. await session_manager.send_to_event_stream(
  187. 'connection-id', {'event_type': 'some_event'}
  188. )
  189. session_instance.dispatch.assert_called_once_with({'event_type': 'some_event'})
  190. @pytest.mark.asyncio
  191. async def test_add_to_cluster_event_stream():
  192. session_instance = AsyncMock()
  193. session_instance.agent_session = MagicMock()
  194. mock_session = MagicMock()
  195. mock_session.return_value = session_instance
  196. sio = get_mock_sio()
  197. is_session_running_in_cluster_mock = AsyncMock()
  198. is_session_running_in_cluster_mock.return_value = True
  199. with (
  200. patch('openhands.server.session.manager.Session', mock_session),
  201. patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.01),
  202. patch(
  203. 'openhands.server.session.manager.SessionManager._redis_subscribe',
  204. AsyncMock(),
  205. ),
  206. patch(
  207. 'openhands.server.session.manager.SessionManager._is_session_running_in_cluster',
  208. is_session_running_in_cluster_mock,
  209. ),
  210. ):
  211. async with SessionManager(
  212. sio, AppConfig(), InMemoryFileStore()
  213. ) as session_manager:
  214. await session_manager.init_or_join_session(
  215. 'new-session-id', 'connection-id', SessionInitData()
  216. )
  217. await session_manager.send_to_event_stream(
  218. 'connection-id', {'event_type': 'some_event'}
  219. )
  220. assert sio.manager.redis.publish.await_count == 1
  221. sio.manager.redis.publish.assert_called_once_with(
  222. 'oh_event',
  223. '{"sid": "new-session-id", "message_type": "event", "data": {"event_type": "some_event"}}',
  224. )