# test_llm.py — unit tests for the LLM wrapper in openhands.llm.llm.
  1. from unittest.mock import MagicMock, patch
  2. import pytest
  3. from litellm.exceptions import (
  4. APIConnectionError,
  5. InternalServerError,
  6. RateLimitError,
  7. ServiceUnavailableError,
  8. )
  9. from openhands.core.config import LLMConfig
  10. from openhands.core.exceptions import OperationCancelled
  11. from openhands.core.metrics import Metrics
  12. from openhands.llm.llm import LLM
  13. @pytest.fixture(autouse=True)
  14. def mock_logger(monkeypatch):
  15. # suppress logging of completion data to file
  16. mock_logger = MagicMock()
  17. monkeypatch.setattr('openhands.llm.debug_mixin.llm_prompt_logger', mock_logger)
  18. monkeypatch.setattr('openhands.llm.debug_mixin.llm_response_logger', mock_logger)
  19. return mock_logger
  20. @pytest.fixture
  21. def default_config():
  22. return LLMConfig(
  23. model='gpt-4o',
  24. api_key='test_key',
  25. num_retries=2,
  26. retry_min_wait=1,
  27. retry_max_wait=2,
  28. )
  29. def test_llm_init_with_default_config(default_config):
  30. llm = LLM(default_config)
  31. assert llm.config.model == 'gpt-4o'
  32. assert llm.config.api_key == 'test_key'
  33. assert isinstance(llm.metrics, Metrics)
  34. @patch('openhands.llm.llm.litellm.get_model_info')
  35. def test_llm_init_with_model_info(mock_get_model_info, default_config):
  36. mock_get_model_info.return_value = {
  37. 'max_input_tokens': 8000,
  38. 'max_output_tokens': 2000,
  39. }
  40. llm = LLM(default_config)
  41. assert llm.config.max_input_tokens == 8000
  42. assert llm.config.max_output_tokens == 2000
  43. @patch('openhands.llm.llm.litellm.get_model_info')
  44. def test_llm_init_without_model_info(mock_get_model_info, default_config):
  45. mock_get_model_info.side_effect = Exception('Model info not available')
  46. llm = LLM(default_config)
  47. assert llm.config.max_input_tokens == 4096
  48. assert llm.config.max_output_tokens == 4096
  49. def test_llm_init_with_custom_config():
  50. custom_config = LLMConfig(
  51. model='custom-model',
  52. api_key='custom_key',
  53. max_input_tokens=5000,
  54. max_output_tokens=1500,
  55. temperature=0.8,
  56. top_p=0.9,
  57. )
  58. llm = LLM(custom_config)
  59. assert llm.config.model == 'custom-model'
  60. assert llm.config.api_key == 'custom_key'
  61. assert llm.config.max_input_tokens == 5000
  62. assert llm.config.max_output_tokens == 1500
  63. assert llm.config.temperature == 0.8
  64. assert llm.config.top_p == 0.9
  65. def test_llm_init_with_metrics():
  66. config = LLMConfig(model='gpt-4o', api_key='test_key')
  67. metrics = Metrics()
  68. llm = LLM(config, metrics=metrics)
  69. assert llm.metrics is metrics
  70. def test_llm_reset():
  71. llm = LLM(LLMConfig(model='gpt-4o-mini', api_key='test_key'))
  72. initial_metrics = llm.metrics
  73. llm.reset()
  74. assert llm.metrics is not initial_metrics
  75. assert isinstance(llm.metrics, Metrics)
  76. @patch('openhands.llm.llm.litellm.get_model_info')
  77. def test_llm_init_with_openrouter_model(mock_get_model_info, default_config):
  78. default_config.model = 'openrouter:gpt-4o-mini'
  79. mock_get_model_info.return_value = {
  80. 'max_input_tokens': 7000,
  81. 'max_output_tokens': 1500,
  82. }
  83. llm = LLM(default_config)
  84. assert llm.config.max_input_tokens == 7000
  85. assert llm.config.max_output_tokens == 1500
  86. mock_get_model_info.assert_called_once_with('openrouter:gpt-4o-mini')
  87. # Tests involving completion and retries
  88. @patch('openhands.llm.llm.litellm_completion')
  89. def test_completion_with_mocked_logger(
  90. mock_litellm_completion, default_config, mock_logger
  91. ):
  92. mock_litellm_completion.return_value = {
  93. 'choices': [{'message': {'content': 'Test response'}}]
  94. }
  95. llm = LLM(config=default_config)
  96. response = llm.completion(
  97. messages=[{'role': 'user', 'content': 'Hello!'}],
  98. stream=False,
  99. )
  100. assert response['choices'][0]['message']['content'] == 'Test response'
  101. assert mock_litellm_completion.call_count == 1
  102. mock_logger.debug.assert_called()
  103. @pytest.mark.parametrize(
  104. 'exception_class,extra_args,expected_retries',
  105. [
  106. (
  107. APIConnectionError,
  108. {'llm_provider': 'test_provider', 'model': 'test_model'},
  109. 2,
  110. ),
  111. (
  112. InternalServerError,
  113. {'llm_provider': 'test_provider', 'model': 'test_model'},
  114. 2,
  115. ),
  116. (
  117. ServiceUnavailableError,
  118. {'llm_provider': 'test_provider', 'model': 'test_model'},
  119. 2,
  120. ),
  121. (RateLimitError, {'llm_provider': 'test_provider', 'model': 'test_model'}, 2),
  122. ],
  123. )
  124. @patch('openhands.llm.llm.litellm_completion')
  125. def test_completion_retries(
  126. mock_litellm_completion,
  127. default_config,
  128. exception_class,
  129. extra_args,
  130. expected_retries,
  131. ):
  132. mock_litellm_completion.side_effect = [
  133. exception_class('Test error message', **extra_args),
  134. {'choices': [{'message': {'content': 'Retry successful'}}]},
  135. ]
  136. llm = LLM(config=default_config)
  137. response = llm.completion(
  138. messages=[{'role': 'user', 'content': 'Hello!'}],
  139. stream=False,
  140. )
  141. assert response['choices'][0]['message']['content'] == 'Retry successful'
  142. assert mock_litellm_completion.call_count == expected_retries
  143. @patch('openhands.llm.llm.litellm_completion')
  144. def test_completion_rate_limit_wait_time(mock_litellm_completion, default_config):
  145. with patch('time.sleep') as mock_sleep:
  146. mock_litellm_completion.side_effect = [
  147. RateLimitError(
  148. 'Rate limit exceeded', llm_provider='test_provider', model='test_model'
  149. ),
  150. {'choices': [{'message': {'content': 'Retry successful'}}]},
  151. ]
  152. llm = LLM(config=default_config)
  153. response = llm.completion(
  154. messages=[{'role': 'user', 'content': 'Hello!'}],
  155. stream=False,
  156. )
  157. assert response['choices'][0]['message']['content'] == 'Retry successful'
  158. assert mock_litellm_completion.call_count == 2
  159. mock_sleep.assert_called_once()
  160. wait_time = mock_sleep.call_args[0][0]
  161. assert (
  162. default_config.retry_min_wait <= wait_time <= default_config.retry_max_wait
  163. ), f'Expected wait time between {default_config.retry_min_wait} and {default_config.retry_max_wait} seconds, but got {wait_time}'
  164. @patch('openhands.llm.llm.litellm_completion')
  165. def test_completion_exhausts_retries(mock_litellm_completion, default_config):
  166. mock_litellm_completion.side_effect = APIConnectionError(
  167. 'Persistent error', llm_provider='test_provider', model='test_model'
  168. )
  169. llm = LLM(config=default_config)
  170. with pytest.raises(APIConnectionError):
  171. llm.completion(
  172. messages=[{'role': 'user', 'content': 'Hello!'}],
  173. stream=False,
  174. )
  175. assert mock_litellm_completion.call_count == llm.config.num_retries
  176. @patch('openhands.llm.llm.litellm_completion')
  177. def test_completion_operation_cancelled(mock_litellm_completion, default_config):
  178. mock_litellm_completion.side_effect = OperationCancelled('Operation cancelled')
  179. llm = LLM(config=default_config)
  180. with pytest.raises(OperationCancelled):
  181. llm.completion(
  182. messages=[{'role': 'user', 'content': 'Hello!'}],
  183. stream=False,
  184. )
  185. assert mock_litellm_completion.call_count == 1
  186. @patch('openhands.llm.llm.litellm_completion')
  187. def test_completion_keyboard_interrupt(mock_litellm_completion, default_config):
  188. def side_effect(*args, **kwargs):
  189. raise KeyboardInterrupt('Simulated KeyboardInterrupt')
  190. mock_litellm_completion.side_effect = side_effect
  191. llm = LLM(config=default_config)
  192. with pytest.raises(OperationCancelled):
  193. try:
  194. llm.completion(
  195. messages=[{'role': 'user', 'content': 'Hello!'}],
  196. stream=False,
  197. )
  198. except KeyboardInterrupt:
  199. raise OperationCancelled('Operation cancelled due to KeyboardInterrupt')
  200. assert mock_litellm_completion.call_count == 1
  201. @patch('openhands.llm.llm.litellm_completion')
  202. def test_completion_keyboard_interrupt_handler(mock_litellm_completion, default_config):
  203. global _should_exit
  204. def side_effect(*args, **kwargs):
  205. global _should_exit
  206. _should_exit = True
  207. return {'choices': [{'message': {'content': 'Simulated interrupt response'}}]}
  208. mock_litellm_completion.side_effect = side_effect
  209. llm = LLM(config=default_config)
  210. result = llm.completion(
  211. messages=[{'role': 'user', 'content': 'Hello!'}],
  212. stream=False,
  213. )
  214. assert mock_litellm_completion.call_count == 1
  215. assert result['choices'][0]['message']['content'] == 'Simulated interrupt response'
  216. assert _should_exit
  217. _should_exit = False
  218. @patch('openhands.llm.llm.litellm_completion')
  219. def test_completion_with_litellm_mock(mock_litellm_completion, default_config):
  220. mock_response = {
  221. 'choices': [{'message': {'content': 'This is a mocked response.'}}]
  222. }
  223. mock_litellm_completion.return_value = mock_response
  224. test_llm = LLM(config=default_config)
  225. response = test_llm.completion(
  226. messages=[{'role': 'user', 'content': 'Hello!'}],
  227. stream=False,
  228. drop_params=True,
  229. )
  230. # Assertions
  231. assert response['choices'][0]['message']['content'] == 'This is a mocked response.'
  232. mock_litellm_completion.assert_called_once()
  233. # Check if the correct arguments were passed to litellm_completion
  234. call_args = mock_litellm_completion.call_args[1] # Get keyword arguments
  235. assert call_args['model'] == default_config.model
  236. assert call_args['messages'] == [{'role': 'user', 'content': 'Hello!'}]
  237. assert not call_args['stream']
  238. @patch('openhands.llm.llm.litellm_completion')
  239. def test_completion_with_two_positional_args(mock_litellm_completion, default_config):
  240. mock_response = {
  241. 'choices': [{'message': {'content': 'Response to positional args.'}}]
  242. }
  243. mock_litellm_completion.return_value = mock_response
  244. test_llm = LLM(config=default_config)
  245. response = test_llm.completion(
  246. 'some-model-to-be-ignored',
  247. [{'role': 'user', 'content': 'Hello from positional args!'}],
  248. stream=False,
  249. )
  250. # Assertions
  251. assert (
  252. response['choices'][0]['message']['content'] == 'Response to positional args.'
  253. )
  254. mock_litellm_completion.assert_called_once()
  255. # Check if the correct arguments were passed to litellm_completion
  256. call_args, call_kwargs = mock_litellm_completion.call_args
  257. assert (
  258. call_kwargs['model'] == default_config.model
  259. ) # Should use the model from config, not the first arg
  260. assert call_kwargs['messages'] == [
  261. {'role': 'user', 'content': 'Hello from positional args!'}
  262. ]
  263. assert not call_kwargs['stream']
  264. # Ensure the first positional argument (model) was ignored
  265. assert (
  266. len(call_args) == 0
  267. ) # No positional args should be passed to litellm_completion here