test_manager.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285
  1. import asyncio
  2. import json
  3. from dataclasses import dataclass
  4. from unittest.mock import AsyncMock, MagicMock, patch
  5. from uuid import uuid4
  6. import pytest
  7. from openhands.core.config.app_config import AppConfig
  8. from openhands.server.session.conversation_init_data import ConversationInitData
  9. from openhands.server.session.manager import SessionManager
  10. from openhands.storage.memory import InMemoryFileStore
  11. @dataclass
  12. class GetMessageMock:
  13. message: dict | None
  14. sleep_time: int = 0.01
  15. async def get_message(self, **kwargs):
  16. await asyncio.sleep(self.sleep_time)
  17. return {'data': json.dumps(self.message)}
  18. def get_mock_sio(get_message: GetMessageMock | None = None):
  19. sio = MagicMock()
  20. sio.enter_room = AsyncMock()
  21. sio.manager.redis = MagicMock()
  22. sio.manager.redis.publish = AsyncMock()
  23. pubsub = AsyncMock()
  24. pubsub.get_message = (get_message or GetMessageMock(None)).get_message
  25. sio.manager.redis.pubsub.return_value = pubsub
  26. return sio
  27. @pytest.mark.asyncio
  28. async def test_session_not_running_in_cluster():
  29. sio = get_mock_sio()
  30. id = uuid4()
  31. with (
  32. patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.01),
  33. patch('openhands.server.session.manager.uuid4', MagicMock(return_value=id)),
  34. ):
  35. async with SessionManager(
  36. sio, AppConfig(), InMemoryFileStore()
  37. ) as session_manager:
  38. result = await session_manager.is_agent_loop_running_in_cluster(
  39. 'non-existant-session'
  40. )
  41. assert result is False
  42. assert sio.manager.redis.publish.await_count == 1
  43. sio.manager.redis.publish.assert_called_once_with(
  44. 'oh_event',
  45. '{"request_id": "'
  46. + str(id)
  47. + '", "sids": ["non-existant-session"], "message_type": "is_session_running"}',
  48. )
  49. @pytest.mark.asyncio
  50. async def test_session_is_running_in_cluster():
  51. id = uuid4()
  52. sio = get_mock_sio(
  53. GetMessageMock(
  54. {
  55. 'request_id': str(id),
  56. 'sids': ['existing-session'],
  57. 'message_type': 'session_is_running',
  58. }
  59. )
  60. )
  61. with (
  62. patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.1),
  63. patch('openhands.server.session.manager.uuid4', MagicMock(return_value=id)),
  64. ):
  65. async with SessionManager(
  66. sio, AppConfig(), InMemoryFileStore()
  67. ) as session_manager:
  68. result = await session_manager.is_agent_loop_running_in_cluster(
  69. 'existing-session'
  70. )
  71. assert result is True
  72. assert sio.manager.redis.publish.await_count == 1
  73. sio.manager.redis.publish.assert_called_once_with(
  74. 'oh_event',
  75. '{"request_id": "'
  76. + str(id)
  77. + '", "sids": ["existing-session"], "message_type": "is_session_running"}',
  78. )
  79. @pytest.mark.asyncio
  80. async def test_init_new_local_session():
  81. session_instance = AsyncMock()
  82. session_instance.agent_session = MagicMock()
  83. mock_session = MagicMock()
  84. mock_session.return_value = session_instance
  85. sio = get_mock_sio()
  86. is_agent_loop_running_in_cluster_mock = AsyncMock()
  87. is_agent_loop_running_in_cluster_mock.return_value = False
  88. with (
  89. patch('openhands.server.session.manager.Session', mock_session),
  90. patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.1),
  91. patch(
  92. 'openhands.server.session.manager.SessionManager._redis_subscribe',
  93. AsyncMock(),
  94. ),
  95. patch(
  96. 'openhands.server.session.manager.SessionManager.is_agent_loop_running_in_cluster',
  97. is_agent_loop_running_in_cluster_mock,
  98. ),
  99. ):
  100. async with SessionManager(
  101. sio, AppConfig(), InMemoryFileStore()
  102. ) as session_manager:
  103. await session_manager.maybe_start_agent_loop(
  104. 'new-session-id', ConversationInitData()
  105. )
  106. await session_manager.join_conversation('new-session-id', 'new-session-id')
  107. assert session_instance.initialize_agent.call_count == 1
  108. assert sio.enter_room.await_count == 1
  109. @pytest.mark.asyncio
  110. async def test_join_local_session():
  111. session_instance = AsyncMock()
  112. session_instance.agent_session = MagicMock()
  113. mock_session = MagicMock()
  114. mock_session.return_value = session_instance
  115. sio = get_mock_sio()
  116. is_agent_loop_running_in_cluster_mock = AsyncMock()
  117. is_agent_loop_running_in_cluster_mock.return_value = False
  118. with (
  119. patch('openhands.server.session.manager.Session', mock_session),
  120. patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.01),
  121. patch(
  122. 'openhands.server.session.manager.SessionManager._redis_subscribe',
  123. AsyncMock(),
  124. ),
  125. patch(
  126. 'openhands.server.session.manager.SessionManager.is_agent_loop_running_in_cluster',
  127. is_agent_loop_running_in_cluster_mock,
  128. ),
  129. ):
  130. async with SessionManager(
  131. sio, AppConfig(), InMemoryFileStore()
  132. ) as session_manager:
  133. await session_manager.maybe_start_agent_loop(
  134. 'new-session-id', ConversationInitData()
  135. )
  136. await session_manager.join_conversation('new-session-id', 'new-session-id')
  137. await session_manager.join_conversation('new-session-id', 'new-session-id')
  138. assert session_instance.initialize_agent.call_count == 1
  139. assert sio.enter_room.await_count == 2
  140. @pytest.mark.asyncio
  141. async def test_join_cluster_session():
  142. session_instance = AsyncMock()
  143. session_instance.agent_session = MagicMock()
  144. mock_session = MagicMock()
  145. mock_session.return_value = session_instance
  146. sio = get_mock_sio()
  147. is_agent_loop_running_in_cluster_mock = AsyncMock()
  148. is_agent_loop_running_in_cluster_mock.return_value = True
  149. with (
  150. patch('openhands.server.session.manager.Session', mock_session),
  151. patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.01),
  152. patch(
  153. 'openhands.server.session.manager.SessionManager._redis_subscribe',
  154. AsyncMock(),
  155. ),
  156. patch(
  157. 'openhands.server.session.manager.SessionManager.is_agent_loop_running_in_cluster',
  158. is_agent_loop_running_in_cluster_mock,
  159. ),
  160. ):
  161. async with SessionManager(
  162. sio, AppConfig(), InMemoryFileStore()
  163. ) as session_manager:
  164. await session_manager.join_conversation('new-session-id', 'new-session-id')
  165. assert session_instance.initialize_agent.call_count == 0
  166. assert sio.enter_room.await_count == 1
  167. @pytest.mark.asyncio
  168. async def test_add_to_local_event_stream():
  169. session_instance = AsyncMock()
  170. session_instance.agent_session = MagicMock()
  171. mock_session = MagicMock()
  172. mock_session.return_value = session_instance
  173. sio = get_mock_sio()
  174. is_agent_loop_running_in_cluster_mock = AsyncMock()
  175. is_agent_loop_running_in_cluster_mock.return_value = False
  176. with (
  177. patch('openhands.server.session.manager.Session', mock_session),
  178. patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.01),
  179. patch(
  180. 'openhands.server.session.manager.SessionManager._redis_subscribe',
  181. AsyncMock(),
  182. ),
  183. patch(
  184. 'openhands.server.session.manager.SessionManager.is_agent_loop_running_in_cluster',
  185. is_agent_loop_running_in_cluster_mock,
  186. ),
  187. ):
  188. async with SessionManager(
  189. sio, AppConfig(), InMemoryFileStore()
  190. ) as session_manager:
  191. await session_manager.maybe_start_agent_loop(
  192. 'new-session-id', ConversationInitData()
  193. )
  194. await session_manager.join_conversation('new-session-id', 'connection-id')
  195. await session_manager.send_to_event_stream(
  196. 'connection-id', {'event_type': 'some_event'}
  197. )
  198. session_instance.dispatch.assert_called_once_with({'event_type': 'some_event'})
  199. @pytest.mark.asyncio
  200. async def test_add_to_cluster_event_stream():
  201. session_instance = AsyncMock()
  202. session_instance.agent_session = MagicMock()
  203. mock_session = MagicMock()
  204. mock_session.return_value = session_instance
  205. sio = get_mock_sio()
  206. is_agent_loop_running_in_cluster_mock = AsyncMock()
  207. is_agent_loop_running_in_cluster_mock.return_value = True
  208. with (
  209. patch('openhands.server.session.manager.Session', mock_session),
  210. patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.01),
  211. patch(
  212. 'openhands.server.session.manager.SessionManager._redis_subscribe',
  213. AsyncMock(),
  214. ),
  215. patch(
  216. 'openhands.server.session.manager.SessionManager.is_agent_loop_running_in_cluster',
  217. is_agent_loop_running_in_cluster_mock,
  218. ),
  219. ):
  220. async with SessionManager(
  221. sio, AppConfig(), InMemoryFileStore()
  222. ) as session_manager:
  223. await session_manager.join_conversation('new-session-id', 'connection-id')
  224. await session_manager.send_to_event_stream(
  225. 'connection-id', {'event_type': 'some_event'}
  226. )
  227. assert sio.manager.redis.publish.await_count == 1
  228. sio.manager.redis.publish.assert_called_once_with(
  229. 'oh_event',
  230. '{"sid": "new-session-id", "message_type": "event", "data": {"event_type": "some_event"}}',
  231. )
  232. @pytest.mark.asyncio
  233. async def test_cleanup_session_connections():
  234. sio = get_mock_sio()
  235. with (
  236. patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.01),
  237. patch(
  238. 'openhands.server.session.manager.SessionManager._redis_subscribe',
  239. AsyncMock(),
  240. ),
  241. ):
  242. async with SessionManager(
  243. sio, AppConfig(), InMemoryFileStore()
  244. ) as session_manager:
  245. session_manager.local_connection_id_to_session_id.update(
  246. {
  247. 'conn1': 'session1',
  248. 'conn2': 'session1',
  249. 'conn3': 'session2',
  250. 'conn4': 'session2',
  251. }
  252. )
  253. await session_manager._close_session('session1')
  254. remaining_connections = session_manager.local_connection_id_to_session_id
  255. assert 'conn1' not in remaining_connections
  256. assert 'conn2' not in remaining_connections
  257. assert 'conn3' in remaining_connections
  258. assert 'conn4' in remaining_connections
  259. assert remaining_connections['conn3'] == 'session2'
  260. assert remaining_connections['conn4'] == 'session2'