test_manager.py 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239
  1. import asyncio
  2. import json
  3. from dataclasses import dataclass
  4. from unittest.mock import AsyncMock, MagicMock, patch
  5. import pytest
  6. from openhands.core.config.app_config import AppConfig
  7. from openhands.server.session.conversation_init_data import ConversationInitData
  8. from openhands.server.session.manager import SessionManager
  9. from openhands.storage.memory import InMemoryFileStore
  10. @dataclass
  11. class GetMessageMock:
  12. message: dict | None
  13. sleep_time: int = 0.01
  14. async def get_message(self, **kwargs):
  15. await asyncio.sleep(self.sleep_time)
  16. return {'data': json.dumps(self.message)}
  17. def get_mock_sio(get_message: GetMessageMock | None = None):
  18. sio = MagicMock()
  19. sio.enter_room = AsyncMock()
  20. sio.manager.redis = MagicMock()
  21. sio.manager.redis.publish = AsyncMock()
  22. pubsub = AsyncMock()
  23. pubsub.get_message = (get_message or GetMessageMock(None)).get_message
  24. sio.manager.redis.pubsub.return_value = pubsub
  25. return sio
  26. @pytest.mark.asyncio
  27. async def test_session_not_running_in_cluster():
  28. sio = get_mock_sio()
  29. with (
  30. patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.01),
  31. ):
  32. async with SessionManager(
  33. sio, AppConfig(), InMemoryFileStore()
  34. ) as session_manager:
  35. result = await session_manager._is_agent_loop_running_in_cluster(
  36. 'non-existant-session'
  37. )
  38. assert result is False
  39. assert sio.manager.redis.publish.await_count == 1
  40. sio.manager.redis.publish.assert_called_once_with(
  41. 'oh_event',
  42. '{"sid": "non-existant-session", "message_type": "is_session_running"}',
  43. )
  44. @pytest.mark.asyncio
  45. async def test_session_is_running_in_cluster():
  46. sio = get_mock_sio(
  47. GetMessageMock(
  48. {'sid': 'existing-session', 'message_type': 'session_is_running'}
  49. )
  50. )
  51. with (
  52. patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.1),
  53. ):
  54. async with SessionManager(
  55. sio, AppConfig(), InMemoryFileStore()
  56. ) as session_manager:
  57. result = await session_manager._is_agent_loop_running_in_cluster(
  58. 'existing-session'
  59. )
  60. assert result is True
  61. assert sio.manager.redis.publish.await_count == 1
  62. sio.manager.redis.publish.assert_called_once_with(
  63. 'oh_event',
  64. '{"sid": "existing-session", "message_type": "is_session_running"}',
  65. )
  66. @pytest.mark.asyncio
  67. async def test_init_new_local_session():
  68. session_instance = AsyncMock()
  69. session_instance.agent_session = MagicMock()
  70. mock_session = MagicMock()
  71. mock_session.return_value = session_instance
  72. sio = get_mock_sio()
  73. is_agent_loop_running_in_cluster_mock = AsyncMock()
  74. is_agent_loop_running_in_cluster_mock.return_value = False
  75. with (
  76. patch('openhands.server.session.manager.Session', mock_session),
  77. patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.1),
  78. patch(
  79. 'openhands.server.session.manager.SessionManager._redis_subscribe',
  80. AsyncMock(),
  81. ),
  82. patch(
  83. 'openhands.server.session.manager.SessionManager._is_agent_loop_running_in_cluster',
  84. is_agent_loop_running_in_cluster_mock,
  85. ),
  86. ):
  87. async with SessionManager(
  88. sio, AppConfig(), InMemoryFileStore()
  89. ) as session_manager:
  90. await session_manager.maybe_start_agent_loop(
  91. 'new-session-id', ConversationInitData()
  92. )
  93. await session_manager.join_conversation('new-session-id', 'new-session-id')
  94. assert session_instance.initialize_agent.call_count == 1
  95. assert sio.enter_room.await_count == 1
  96. @pytest.mark.asyncio
  97. async def test_join_local_session():
  98. session_instance = AsyncMock()
  99. session_instance.agent_session = MagicMock()
  100. mock_session = MagicMock()
  101. mock_session.return_value = session_instance
  102. sio = get_mock_sio()
  103. is_agent_loop_running_in_cluster_mock = AsyncMock()
  104. is_agent_loop_running_in_cluster_mock.return_value = False
  105. with (
  106. patch('openhands.server.session.manager.Session', mock_session),
  107. patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.01),
  108. patch(
  109. 'openhands.server.session.manager.SessionManager._redis_subscribe',
  110. AsyncMock(),
  111. ),
  112. patch(
  113. 'openhands.server.session.manager.SessionManager._is_agent_loop_running_in_cluster',
  114. is_agent_loop_running_in_cluster_mock,
  115. ),
  116. ):
  117. async with SessionManager(
  118. sio, AppConfig(), InMemoryFileStore()
  119. ) as session_manager:
  120. await session_manager.maybe_start_agent_loop(
  121. 'new-session-id', ConversationInitData()
  122. )
  123. await session_manager.join_conversation('new-session-id', 'new-session-id')
  124. await session_manager.join_conversation('new-session-id', 'new-session-id')
  125. assert session_instance.initialize_agent.call_count == 1
  126. assert sio.enter_room.await_count == 2
  127. @pytest.mark.asyncio
  128. async def test_join_cluster_session():
  129. session_instance = AsyncMock()
  130. session_instance.agent_session = MagicMock()
  131. mock_session = MagicMock()
  132. mock_session.return_value = session_instance
  133. sio = get_mock_sio()
  134. is_agent_loop_running_in_cluster_mock = AsyncMock()
  135. is_agent_loop_running_in_cluster_mock.return_value = True
  136. with (
  137. patch('openhands.server.session.manager.Session', mock_session),
  138. patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.01),
  139. patch(
  140. 'openhands.server.session.manager.SessionManager._redis_subscribe',
  141. AsyncMock(),
  142. ),
  143. patch(
  144. 'openhands.server.session.manager.SessionManager._is_agent_loop_running_in_cluster',
  145. is_agent_loop_running_in_cluster_mock,
  146. ),
  147. ):
  148. async with SessionManager(
  149. sio, AppConfig(), InMemoryFileStore()
  150. ) as session_manager:
  151. await session_manager.join_conversation('new-session-id', 'new-session-id')
  152. assert session_instance.initialize_agent.call_count == 0
  153. assert sio.enter_room.await_count == 1
  154. @pytest.mark.asyncio
  155. async def test_add_to_local_event_stream():
  156. session_instance = AsyncMock()
  157. session_instance.agent_session = MagicMock()
  158. mock_session = MagicMock()
  159. mock_session.return_value = session_instance
  160. sio = get_mock_sio()
  161. is_agent_loop_running_in_cluster_mock = AsyncMock()
  162. is_agent_loop_running_in_cluster_mock.return_value = False
  163. with (
  164. patch('openhands.server.session.manager.Session', mock_session),
  165. patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.01),
  166. patch(
  167. 'openhands.server.session.manager.SessionManager._redis_subscribe',
  168. AsyncMock(),
  169. ),
  170. patch(
  171. 'openhands.server.session.manager.SessionManager._is_agent_loop_running_in_cluster',
  172. is_agent_loop_running_in_cluster_mock,
  173. ),
  174. ):
  175. async with SessionManager(
  176. sio, AppConfig(), InMemoryFileStore()
  177. ) as session_manager:
  178. await session_manager.maybe_start_agent_loop(
  179. 'new-session-id', ConversationInitData()
  180. )
  181. await session_manager.join_conversation('new-session-id', 'connection-id')
  182. await session_manager.send_to_event_stream(
  183. 'connection-id', {'event_type': 'some_event'}
  184. )
  185. session_instance.dispatch.assert_called_once_with({'event_type': 'some_event'})
  186. @pytest.mark.asyncio
  187. async def test_add_to_cluster_event_stream():
  188. session_instance = AsyncMock()
  189. session_instance.agent_session = MagicMock()
  190. mock_session = MagicMock()
  191. mock_session.return_value = session_instance
  192. sio = get_mock_sio()
  193. is_agent_loop_running_in_cluster_mock = AsyncMock()
  194. is_agent_loop_running_in_cluster_mock.return_value = True
  195. with (
  196. patch('openhands.server.session.manager.Session', mock_session),
  197. patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.01),
  198. patch(
  199. 'openhands.server.session.manager.SessionManager._redis_subscribe',
  200. AsyncMock(),
  201. ),
  202. patch(
  203. 'openhands.server.session.manager.SessionManager._is_agent_loop_running_in_cluster',
  204. is_agent_loop_running_in_cluster_mock,
  205. ),
  206. ):
  207. async with SessionManager(
  208. sio, AppConfig(), InMemoryFileStore()
  209. ) as session_manager:
  210. await session_manager.join_conversation('new-session-id', 'connection-id')
  211. await session_manager.send_to_event_stream(
  212. 'connection-id', {'event_type': 'some_event'}
  213. )
  214. assert sio.manager.redis.publish.await_count == 1
  215. sio.manager.redis.publish.assert_called_once_with(
  216. 'oh_event',
  217. '{"sid": "new-session-id", "message_type": "event", "data": {"event_type": "some_event"}}',
  218. )