test_manager.py 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244
  1. import asyncio
  2. import json
  3. from dataclasses import dataclass
  4. from unittest.mock import AsyncMock, MagicMock, patch
  5. import pytest
  6. from openhands.core.config.app_config import AppConfig
  7. from openhands.server.session.manager import SessionManager
  8. from openhands.storage.memory import InMemoryFileStore
  9. @dataclass
  10. class GetMessageMock:
  11. message: dict | None
  12. sleep_time: int = 0.01
  13. async def get_message(self, **kwargs):
  14. await asyncio.sleep(self.sleep_time)
  15. return {'data': json.dumps(self.message)}
  16. def get_mock_sio(get_message: GetMessageMock | None = None):
  17. sio = MagicMock()
  18. sio.enter_room = AsyncMock()
  19. sio.manager.redis = MagicMock()
  20. sio.manager.redis.publish = AsyncMock()
  21. pubsub = AsyncMock()
  22. pubsub.get_message = (get_message or GetMessageMock(None)).get_message
  23. sio.manager.redis.pubsub.return_value = pubsub
  24. return sio
  25. @pytest.mark.asyncio
  26. async def test_session_not_running_in_cluster():
  27. sio = get_mock_sio()
  28. with (
  29. patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.01),
  30. ):
  31. async with SessionManager(
  32. sio, AppConfig(), InMemoryFileStore()
  33. ) as session_manager:
  34. result = await session_manager._is_session_running_in_cluster(
  35. 'non-existant-session'
  36. )
  37. assert result is False
  38. assert sio.manager.redis.publish.await_count == 1
  39. sio.manager.redis.publish.assert_called_once_with(
  40. 'oh_event',
  41. '{"sid": "non-existant-session", "message_type": "is_session_running"}',
  42. )
  43. @pytest.mark.asyncio
  44. async def test_session_is_running_in_cluster():
  45. sio = get_mock_sio(
  46. GetMessageMock(
  47. {'sid': 'existing-session', 'message_type': 'session_is_running'}
  48. )
  49. )
  50. with (
  51. patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.05),
  52. ):
  53. async with SessionManager(
  54. sio, AppConfig(), InMemoryFileStore()
  55. ) as session_manager:
  56. result = await session_manager._is_session_running_in_cluster(
  57. 'existing-session'
  58. )
  59. assert result is True
  60. assert sio.manager.redis.publish.await_count == 1
  61. sio.manager.redis.publish.assert_called_once_with(
  62. 'oh_event',
  63. '{"sid": "existing-session", "message_type": "is_session_running"}',
  64. )
  65. @pytest.mark.asyncio
  66. async def test_init_new_local_session():
  67. session_instance = AsyncMock()
  68. session_instance.agent_session = MagicMock()
  69. mock_session = MagicMock()
  70. mock_session.return_value = session_instance
  71. sio = get_mock_sio()
  72. is_session_running_in_cluster_mock = AsyncMock()
  73. is_session_running_in_cluster_mock.return_value = False
  74. with (
  75. patch('openhands.server.session.manager.Session', mock_session),
  76. patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.01),
  77. patch(
  78. 'openhands.server.session.manager.SessionManager._redis_subscribe',
  79. AsyncMock(),
  80. ),
  81. patch(
  82. 'openhands.server.session.manager.SessionManager._is_session_running_in_cluster',
  83. is_session_running_in_cluster_mock,
  84. ),
  85. ):
  86. async with SessionManager(
  87. sio, AppConfig(), InMemoryFileStore()
  88. ) as session_manager:
  89. await session_manager.init_or_join_session(
  90. 'new-session-id', 'new-session-id', {'type': 'mock-settings'}
  91. )
  92. assert session_instance.initialize_agent.call_count == 1
  93. assert sio.enter_room.await_count == 1
  94. @pytest.mark.asyncio
  95. async def test_join_local_session():
  96. session_instance = AsyncMock()
  97. session_instance.agent_session = MagicMock()
  98. mock_session = MagicMock()
  99. mock_session.return_value = session_instance
  100. sio = get_mock_sio()
  101. is_session_running_in_cluster_mock = AsyncMock()
  102. is_session_running_in_cluster_mock.return_value = False
  103. with (
  104. patch('openhands.server.session.manager.Session', mock_session),
  105. patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.01),
  106. patch(
  107. 'openhands.server.session.manager.SessionManager._redis_subscribe',
  108. AsyncMock(),
  109. ),
  110. patch(
  111. 'openhands.server.session.manager.SessionManager._is_session_running_in_cluster',
  112. is_session_running_in_cluster_mock,
  113. ),
  114. ):
  115. async with SessionManager(
  116. sio, AppConfig(), InMemoryFileStore()
  117. ) as session_manager:
  118. # First call initializes
  119. await session_manager.init_or_join_session(
  120. 'new-session-id', 'new-session-id', {'type': 'mock-settings'}
  121. )
  122. # Second call joins
  123. await session_manager.init_or_join_session(
  124. 'new-session-id', 'extra-connection-id', {'type': 'mock-settings'}
  125. )
  126. assert session_instance.initialize_agent.call_count == 1
  127. assert sio.enter_room.await_count == 2
  128. @pytest.mark.asyncio
  129. async def test_join_cluster_session():
  130. session_instance = AsyncMock()
  131. session_instance.agent_session = MagicMock()
  132. mock_session = MagicMock()
  133. mock_session.return_value = session_instance
  134. sio = get_mock_sio()
  135. is_session_running_in_cluster_mock = AsyncMock()
  136. is_session_running_in_cluster_mock.return_value = True
  137. with (
  138. patch('openhands.server.session.manager.Session', mock_session),
  139. patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.01),
  140. patch(
  141. 'openhands.server.session.manager.SessionManager._redis_subscribe',
  142. AsyncMock(),
  143. ),
  144. patch(
  145. 'openhands.server.session.manager.SessionManager._is_session_running_in_cluster',
  146. is_session_running_in_cluster_mock,
  147. ),
  148. ):
  149. async with SessionManager(
  150. sio, AppConfig(), InMemoryFileStore()
  151. ) as session_manager:
  152. # First call initializes
  153. await session_manager.init_or_join_session(
  154. 'new-session-id', 'new-session-id', {'type': 'mock-settings'}
  155. )
  156. assert session_instance.initialize_agent.call_count == 0
  157. assert sio.enter_room.await_count == 1
  158. @pytest.mark.asyncio
  159. async def test_add_to_local_event_stream():
  160. session_instance = AsyncMock()
  161. session_instance.agent_session = MagicMock()
  162. mock_session = MagicMock()
  163. mock_session.return_value = session_instance
  164. sio = get_mock_sio()
  165. is_session_running_in_cluster_mock = AsyncMock()
  166. is_session_running_in_cluster_mock.return_value = False
  167. with (
  168. patch('openhands.server.session.manager.Session', mock_session),
  169. patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.01),
  170. patch(
  171. 'openhands.server.session.manager.SessionManager._redis_subscribe',
  172. AsyncMock(),
  173. ),
  174. patch(
  175. 'openhands.server.session.manager.SessionManager._is_session_running_in_cluster',
  176. is_session_running_in_cluster_mock,
  177. ),
  178. ):
  179. async with SessionManager(
  180. sio, AppConfig(), InMemoryFileStore()
  181. ) as session_manager:
  182. await session_manager.init_or_join_session(
  183. 'new-session-id', 'connection-id', {'type': 'mock-settings'}
  184. )
  185. await session_manager.send_to_event_stream(
  186. 'connection-id', {'event_type': 'some_event'}
  187. )
  188. session_instance.dispatch.assert_called_once_with({'event_type': 'some_event'})
  189. @pytest.mark.asyncio
  190. async def test_add_to_cluster_event_stream():
  191. session_instance = AsyncMock()
  192. session_instance.agent_session = MagicMock()
  193. mock_session = MagicMock()
  194. mock_session.return_value = session_instance
  195. sio = get_mock_sio()
  196. is_session_running_in_cluster_mock = AsyncMock()
  197. is_session_running_in_cluster_mock.return_value = True
  198. with (
  199. patch('openhands.server.session.manager.Session', mock_session),
  200. patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.01),
  201. patch(
  202. 'openhands.server.session.manager.SessionManager._redis_subscribe',
  203. AsyncMock(),
  204. ),
  205. patch(
  206. 'openhands.server.session.manager.SessionManager._is_session_running_in_cluster',
  207. is_session_running_in_cluster_mock,
  208. ),
  209. ):
  210. async with SessionManager(
  211. sio, AppConfig(), InMemoryFileStore()
  212. ) as session_manager:
  213. await session_manager.init_or_join_session(
  214. 'new-session-id', 'connection-id', {'type': 'mock-settings'}
  215. )
  216. await session_manager.send_to_event_stream(
  217. 'connection-id', {'event_type': 'some_event'}
  218. )
  219. assert sio.manager.redis.publish.await_count == 1
  220. sio.manager.redis.publish.assert_called_once_with(
  221. 'oh_event',
  222. '{"sid": "new-session-id", "message_type": "event", "data": {"event_type": "some_event"}}',
  223. )