@@ -1,5 +1,7 @@
 import asyncio

-from unittest.mock import AsyncMock, patch
+from contextlib import contextmanager
+from typing import Type
+from unittest.mock import AsyncMock, MagicMock, patch

 import pytest

@@ -14,8 +16,13 @@ config = load_app_config()

 @pytest.fixture
 def test_llm():
-    # Create a mock config for testing
-    return LLM(config=config.get_llm_config())
+    return _get_llm(LLM)
+
+
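+# Helper: build an LLM of the given type with outbound HTTP stubbed (see _patch_http).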
+def _get_llm(type_: Type[LLM]):
+    with _patch_http():
+        return type_(config=config.get_llm_config())


 @pytest.fixture
@@ -39,6 +46,19 @@ def mock_response():
     ]


+@contextmanager
+def _patch_http():
+    with patch('openhands.llm.llm.requests.get', MagicMock()) as mock_http:
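+        # No real network I/O: requests.get is mocked; give it a canned model list.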
+        mock_http.json.return_value = {
+            'data': [
+                {'model_name': 'some_model'},
+                {'model_name': 'another_model'},
+            ]
+        }
+        yield
+
+
 @pytest.mark.asyncio
 async def test_acompletion_non_streaming():
     with patch.object(AsyncLLM, '_call_acompletion') as mock_call_acompletion:
@@ -46,7 +66,7 @@ async def test_acompletion_non_streaming():
             'choices': [{'message': {'content': 'This is a test message.'}}]
         }
         mock_call_acompletion.return_value = mock_response
-        test_llm = AsyncLLM(config=config.get_llm_config())
+        test_llm = _get_llm(AsyncLLM)
         response = await test_llm.async_completion(
             messages=[{'role': 'user', 'content': 'Hello!'}],
             stream=False,
@@ -60,7 +80,7 @@ async def test_acompletion_non_streaming():
 async def test_acompletion_streaming(mock_response):
     with patch.object(StreamingLLM, '_call_acompletion') as mock_call_acompletion:
         mock_call_acompletion.return_value.__aiter__.return_value = iter(mock_response)
-        test_llm = StreamingLLM(config=config.get_llm_config())
+        test_llm = _get_llm(StreamingLLM)
         async for chunk in test_llm.async_streaming_completion(
             messages=[{'role': 'user', 'content': 'Hello!'}], stream=True
         ):
@@ -109,7 +129,7 @@ async def test_async_completion_with_user_cancellation(cancel_delay):
         AsyncLLM, '_call_acompletion', new_callable=AsyncMock
     ) as mock_call_acompletion:
         mock_call_acompletion.side_effect = mock_acompletion
-        test_llm = AsyncLLM(config=config.get_llm_config())
+        test_llm = _get_llm(AsyncLLM)

         async def cancel_after_delay():
             print(f'Starting cancel_after_delay with delay {cancel_delay}')
@@ -171,7 +191,7 @@ async def test_async_streaming_completion_with_user_cancellation(cancel_after_ch
         AsyncLLM, '_call_acompletion', new_callable=AsyncMock
     ) as mock_call_acompletion:
         mock_call_acompletion.return_value = mock_acompletion()
-        test_llm = StreamingLLM(config=config.get_llm_config())
+        test_llm = _get_llm(StreamingLLM)

         received_chunks = []
         with pytest.raises(UserCancelledError):