# test_llm.py

import copy
from unittest.mock import MagicMock, patch

import pytest
from litellm.exceptions import (
    APIConnectionError,
    InternalServerError,
    RateLimitError,
    ServiceUnavailableError,
)

from openhands.core.config import LLMConfig
from openhands.core.exceptions import OperationCancelled
from openhands.llm.llm import LLM
from openhands.llm.metrics import Metrics


@pytest.fixture(autouse=True)
def mock_logger(monkeypatch):
    # suppress logging of completion data to file
    mock_logger = MagicMock()
    monkeypatch.setattr('openhands.llm.debug_mixin.llm_prompt_logger', mock_logger)
    monkeypatch.setattr('openhands.llm.debug_mixin.llm_response_logger', mock_logger)
    return mock_logger


@pytest.fixture
def default_config():
    return LLMConfig(
        model='gpt-4o',
        api_key='test_key',
        num_retries=2,
        retry_min_wait=1,
        retry_max_wait=2,
    )


def test_llm_init_with_default_config(default_config):
    llm = LLM(default_config)
    assert llm.config.model == 'gpt-4o'
    assert llm.config.api_key == 'test_key'
    assert isinstance(llm.metrics, Metrics)
    assert llm.metrics.model_name == 'gpt-4o'


@patch('openhands.llm.llm.litellm.get_model_info')
def test_llm_init_with_model_info(mock_get_model_info, default_config):
    mock_get_model_info.return_value = {
        'max_input_tokens': 8000,
        'max_output_tokens': 2000,
    }
    llm = LLM(default_config)
    llm.init_model_info()
    assert llm.config.max_input_tokens == 8000
    assert llm.config.max_output_tokens == 2000


@patch('openhands.llm.llm.litellm.get_model_info')
def test_llm_init_without_model_info(mock_get_model_info, default_config):
    mock_get_model_info.side_effect = Exception('Model info not available')
    llm = LLM(default_config)
    llm.init_model_info()
    assert llm.config.max_input_tokens == 4096
    assert llm.config.max_output_tokens == 4096


def test_llm_init_with_custom_config():
    custom_config = LLMConfig(
        model='custom-model',
        api_key='custom_key',
        max_input_tokens=5000,
        max_output_tokens=1500,
        temperature=0.8,
        top_p=0.9,
    )
    llm = LLM(custom_config)
    assert llm.config.model == 'custom-model'
    assert llm.config.api_key == 'custom_key'
    assert llm.config.max_input_tokens == 5000
    assert llm.config.max_output_tokens == 1500
    assert llm.config.temperature == 0.8
    assert llm.config.top_p == 0.9


def test_llm_init_with_metrics():
    config = LLMConfig(model='gpt-4o', api_key='test_key')
    metrics = Metrics()
    llm = LLM(config, metrics=metrics)
    assert llm.metrics is metrics
    assert (
        llm.metrics.model_name == 'default'
    )  # because we didn't specify model_name in Metrics init


def test_llm_reset():
    llm = LLM(LLMConfig(model='gpt-4o-mini', api_key='test_key'))
    initial_metrics = copy.deepcopy(llm.metrics)
    initial_metrics.add_cost(1.0)
    llm.reset()
    assert llm.metrics._accumulated_cost != initial_metrics._accumulated_cost
    assert llm.metrics._costs != initial_metrics._costs
    assert isinstance(llm.metrics, Metrics)
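

# A minimal sanity check on Metrics itself (a sketch, assuming Metrics.add_cost
# appends one entry to the private _costs list and adds the value into
# _accumulated_cost, as the reset test above implies).
def test_metrics_add_cost_accumulates():
    metrics = Metrics()
    metrics.add_cost(1.0)
    metrics.add_cost(2.0)
    # Assumption: costs are summed and recorded individually.
    assert metrics._accumulated_cost == 3.0
    assert len(metrics._costs) == 2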


@patch('openhands.llm.llm.litellm.get_model_info')
def test_llm_init_with_openrouter_model(mock_get_model_info, default_config):
    default_config.model = 'openrouter:gpt-4o-mini'
    mock_get_model_info.return_value = {
        'max_input_tokens': 7000,
        'max_output_tokens': 1500,
    }
    llm = LLM(default_config)
    llm.init_model_info()
    assert llm.config.max_input_tokens == 7000
    assert llm.config.max_output_tokens == 1500
    mock_get_model_info.assert_called_once_with('openrouter:gpt-4o-mini')


# Tests involving completion and retries


@patch('openhands.llm.llm.litellm_completion')
def test_completion_with_mocked_logger(
    mock_litellm_completion, default_config, mock_logger
):
    mock_litellm_completion.return_value = {
        'choices': [{'message': {'content': 'Test response'}}]
    }
    llm = LLM(config=default_config)
    response = llm.completion(
        messages=[{'role': 'user', 'content': 'Hello!'}],
        stream=False,
    )
    assert response['choices'][0]['message']['content'] == 'Test response'
    assert mock_litellm_completion.call_count == 1
    mock_logger.debug.assert_called()


@pytest.mark.parametrize(
    'exception_class,extra_args,expected_retries',
    [
        (
            APIConnectionError,
            {'llm_provider': 'test_provider', 'model': 'test_model'},
            2,
        ),
        (
            InternalServerError,
            {'llm_provider': 'test_provider', 'model': 'test_model'},
            2,
        ),
        (
            ServiceUnavailableError,
            {'llm_provider': 'test_provider', 'model': 'test_model'},
            2,
        ),
        (RateLimitError, {'llm_provider': 'test_provider', 'model': 'test_model'}, 2),
    ],
)
@patch('openhands.llm.llm.litellm_completion')
def test_completion_retries(
    mock_litellm_completion,
    default_config,
    exception_class,
    extra_args,
    expected_retries,
):
    mock_litellm_completion.side_effect = [
        exception_class('Test error message', **extra_args),
        {'choices': [{'message': {'content': 'Retry successful'}}]},
    ]
    llm = LLM(config=default_config)
    response = llm.completion(
        messages=[{'role': 'user', 'content': 'Hello!'}],
        stream=False,
    )
    assert response['choices'][0]['message']['content'] == 'Retry successful'
    assert mock_litellm_completion.call_count == expected_retries


@patch('openhands.llm.llm.litellm_completion')
def test_completion_rate_limit_wait_time(mock_litellm_completion, default_config):
    with patch('time.sleep') as mock_sleep:
        mock_litellm_completion.side_effect = [
            RateLimitError(
                'Rate limit exceeded', llm_provider='test_provider', model='test_model'
            ),
            {'choices': [{'message': {'content': 'Retry successful'}}]},
        ]
        llm = LLM(config=default_config)
        response = llm.completion(
            messages=[{'role': 'user', 'content': 'Hello!'}],
            stream=False,
        )
        assert response['choices'][0]['message']['content'] == 'Retry successful'
        assert mock_litellm_completion.call_count == 2
        mock_sleep.assert_called_once()
        wait_time = mock_sleep.call_args[0][0]
        assert (
            default_config.retry_min_wait <= wait_time <= default_config.retry_max_wait
        ), f'Expected wait time between {default_config.retry_min_wait} and {default_config.retry_max_wait} seconds, but got {wait_time}'


@patch('openhands.llm.llm.litellm_completion')
def test_completion_exhausts_retries(mock_litellm_completion, default_config):
    mock_litellm_completion.side_effect = APIConnectionError(
        'Persistent error', llm_provider='test_provider', model='test_model'
    )
    llm = LLM(config=default_config)
    with pytest.raises(APIConnectionError):
        llm.completion(
            messages=[{'role': 'user', 'content': 'Hello!'}],
            stream=False,
        )
    assert mock_litellm_completion.call_count == llm.config.num_retries


@patch('openhands.llm.llm.litellm_completion')
def test_completion_operation_cancelled(mock_litellm_completion, default_config):
    mock_litellm_completion.side_effect = OperationCancelled('Operation cancelled')
    llm = LLM(config=default_config)
    with pytest.raises(OperationCancelled):
        llm.completion(
            messages=[{'role': 'user', 'content': 'Hello!'}],
            stream=False,
        )
    assert mock_litellm_completion.call_count == 1


@patch('openhands.llm.llm.litellm_completion')
def test_completion_keyboard_interrupt(mock_litellm_completion, default_config):
    def side_effect(*args, **kwargs):
        raise KeyboardInterrupt('Simulated KeyboardInterrupt')

    mock_litellm_completion.side_effect = side_effect
    llm = LLM(config=default_config)
    with pytest.raises(OperationCancelled):
        try:
            llm.completion(
                messages=[{'role': 'user', 'content': 'Hello!'}],
                stream=False,
            )
        except KeyboardInterrupt:
            raise OperationCancelled('Operation cancelled due to KeyboardInterrupt')

    assert mock_litellm_completion.call_count == 1
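

# Note: the next test simulates an interrupt via a module-level `_should_exit` flag
# (presumably the same flag a keyboard-interrupt/signal handler would set). The mocked
# completion flips the flag but still returns normally, and the test checks that the
# response comes through and the flag was observed, then resets it.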


@patch('openhands.llm.llm.litellm_completion')
def test_completion_keyboard_interrupt_handler(mock_litellm_completion, default_config):
    global _should_exit

    def side_effect(*args, **kwargs):
        global _should_exit
        _should_exit = True
        return {'choices': [{'message': {'content': 'Simulated interrupt response'}}]}

    mock_litellm_completion.side_effect = side_effect
    llm = LLM(config=default_config)
    result = llm.completion(
        messages=[{'role': 'user', 'content': 'Hello!'}],
        stream=False,
    )
    assert mock_litellm_completion.call_count == 1
    assert result['choices'][0]['message']['content'] == 'Simulated interrupt response'
    assert _should_exit
    _should_exit = False


@patch('openhands.llm.llm.litellm_completion')
def test_completion_with_litellm_mock(mock_litellm_completion, default_config):
    mock_response = {
        'choices': [{'message': {'content': 'This is a mocked response.'}}]
    }
    mock_litellm_completion.return_value = mock_response
    test_llm = LLM(config=default_config)
    response = test_llm.completion(
        messages=[{'role': 'user', 'content': 'Hello!'}],
        stream=False,
        drop_params=True,
    )

    # Assertions
    assert response['choices'][0]['message']['content'] == 'This is a mocked response.'
    mock_litellm_completion.assert_called_once()

    # Check if the correct arguments were passed to litellm_completion
    call_args = mock_litellm_completion.call_args[1]  # Get keyword arguments
    assert call_args['model'] == default_config.model
    assert call_args['messages'] == [{'role': 'user', 'content': 'Hello!'}]
    assert not call_args['stream']


@patch('openhands.llm.llm.litellm_completion')
def test_completion_with_two_positional_args(mock_litellm_completion, default_config):
    mock_response = {
        'choices': [{'message': {'content': 'Response to positional args.'}}]
    }
    mock_litellm_completion.return_value = mock_response
    test_llm = LLM(config=default_config)
    response = test_llm.completion(
        'some-model-to-be-ignored',
        [{'role': 'user', 'content': 'Hello from positional args!'}],
        stream=False,
    )

    # Assertions
    assert (
        response['choices'][0]['message']['content'] == 'Response to positional args.'
    )
    mock_litellm_completion.assert_called_once()

    # Check if the correct arguments were passed to litellm_completion
    call_args, call_kwargs = mock_litellm_completion.call_args
    assert (
        call_kwargs['model'] == default_config.model
    )  # Should use the model from config, not the first arg
    assert call_kwargs['messages'] == [
        {'role': 'user', 'content': 'Hello from positional args!'}
    ]
    assert not call_kwargs['stream']

    # Ensure the first positional argument (model) was ignored
    assert (
        len(call_args) == 0
    )  # No positional args should be passed to litellm_completion here


@patch('openhands.llm.llm.litellm_completion')
def test_llm_cloudflare_blockage(mock_litellm_completion, default_config):
    from litellm.exceptions import APIError

    from openhands.core.exceptions import CloudFlareBlockageError

    llm = LLM(default_config)
    mock_litellm_completion.side_effect = APIError(
        message='Attention Required! | Cloudflare',
        llm_provider='test_provider',
        model='test_model',
        status_code=403,
    )

    with pytest.raises(CloudFlareBlockageError, match='Request blocked by CloudFlare'):
        llm.completion(messages=[{'role': 'user', 'content': 'Hello'}])

    # Ensure the completion was called
    mock_litellm_completion.assert_called_once()
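

# Additional sketch: the same exhaustion check as test_completion_exhausts_retries,
# but with a non-default num_retries. This assumes the attempt count asserted above
# scales directly with LLMConfig.num_retries; adjust if the retry policy differs.
@patch('openhands.llm.llm.litellm_completion')
def test_completion_exhausts_custom_num_retries(mock_litellm_completion):
    config = LLMConfig(
        model='gpt-4o',
        api_key='test_key',
        num_retries=3,
        retry_min_wait=1,
        retry_max_wait=2,
    )
    mock_litellm_completion.side_effect = APIConnectionError(
        'Persistent error', llm_provider='test_provider', model='test_model'
    )
    llm = LLM(config=config)
    with pytest.raises(APIConnectionError):
        llm.completion(
            messages=[{'role': 'user', 'content': 'Hello!'}],
            stream=False,
        )
    # Every attempt should hit the mocked litellm_completion before giving up.
    assert mock_litellm_completion.call_count == config.num_retries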