Przeglądaj źródła

Fix: Mocking LLM proxy in unit tests (#5639)

tofarr 1 rok temu
rodzic
commit
d76e83b55e
2 zmienionych plików z 27 dodań i 9 usunięć
  1. 25 7
      tests/unit/test_acompletion.py
  2. 2 2
      tests/unit/test_manager.py

+ 25 - 7
tests/unit/test_acompletion.py

@@ -1,5 +1,7 @@
 import asyncio
-from unittest.mock import AsyncMock, patch
+from contextlib import contextmanager
+from typing import Type
+from unittest.mock import AsyncMock, MagicMock, patch
 
 import pytest
 
@@ -14,8 +16,12 @@ config = load_app_config()
 
 @pytest.fixture
 def test_llm():
-    # Create a mock config for testing
-    return LLM(config=config.get_llm_config())
+    return _get_llm(LLM)
+
+
+def _get_llm(type_: Type[LLM]):
+    with _patch_http():
+        return type_(config=config.get_llm_config())
 
 
 @pytest.fixture
@@ -39,6 +45,18 @@ def mock_response():
     ]
 
 
+@contextmanager
+def _patch_http():
+    with patch('openhands.llm.llm.requests.get', MagicMock()) as mock_http:
+        mock_http.json.return_value = {
+            'data': [
+                {'model_name': 'some_model'},
+                {'model_name': 'another_model'},
+            ]
+        }
+        yield
+
+
 @pytest.mark.asyncio
 async def test_acompletion_non_streaming():
     with patch.object(AsyncLLM, '_call_acompletion') as mock_call_acompletion:
@@ -46,7 +64,7 @@ async def test_acompletion_non_streaming():
             'choices': [{'message': {'content': 'This is a test message.'}}]
         }
         mock_call_acompletion.return_value = mock_response
-        test_llm = AsyncLLM(config=config.get_llm_config())
+        test_llm = _get_llm(AsyncLLM)
         response = await test_llm.async_completion(
             messages=[{'role': 'user', 'content': 'Hello!'}],
             stream=False,
@@ -60,7 +78,7 @@ async def test_acompletion_non_streaming():
 async def test_acompletion_streaming(mock_response):
     with patch.object(StreamingLLM, '_call_acompletion') as mock_call_acompletion:
         mock_call_acompletion.return_value.__aiter__.return_value = iter(mock_response)
-        test_llm = StreamingLLM(config=config.get_llm_config())
+        test_llm = _get_llm(StreamingLLM)
         async for chunk in test_llm.async_streaming_completion(
             messages=[{'role': 'user', 'content': 'Hello!'}], stream=True
         ):
@@ -109,7 +127,7 @@ async def test_async_completion_with_user_cancellation(cancel_delay):
         AsyncLLM, '_call_acompletion', new_callable=AsyncMock
     ) as mock_call_acompletion:
         mock_call_acompletion.side_effect = mock_acompletion
-        test_llm = AsyncLLM(config=config.get_llm_config())
+        test_llm = _get_llm(AsyncLLM)
 
         async def cancel_after_delay():
             print(f'Starting cancel_after_delay with delay {cancel_delay}')
@@ -171,7 +189,7 @@ async def test_async_streaming_completion_with_user_cancellation(cancel_after_ch
         AsyncLLM, '_call_acompletion', new_callable=AsyncMock
     ) as mock_call_acompletion:
         mock_call_acompletion.return_value = mock_acompletion()
-        test_llm = StreamingLLM(config=config.get_llm_config())
+        test_llm = _get_llm(StreamingLLM)
 
         received_chunks = []
         with pytest.raises(UserCancelledError):

+ 2 - 2
tests/unit/test_manager.py

@@ -60,7 +60,7 @@ async def test_session_is_running_in_cluster():
         )
     )
     with (
-        patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.05),
+        patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.1),
     ):
         async with SessionManager(
             sio, AppConfig(), InMemoryFileStore()
@@ -87,7 +87,7 @@ async def test_init_new_local_session():
     is_session_running_in_cluster_mock.return_value = False
     with (
         patch('openhands.server.session.manager.Session', mock_session),
-        patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.01),
+        patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.1),
         patch(
             'openhands.server.session.manager.SessionManager._redis_subscribe',
             AsyncMock(),