from unittest.mock import MagicMock, patch

import pytest
from litellm.exceptions import (
    APIConnectionError,
    ContentPolicyViolationError,
    InternalServerError,
    OpenAIError,
    RateLimitError,
)

from openhands.core.config import LLMConfig
from openhands.core.exceptions import OperationCancelled
from openhands.core.metrics import Metrics
from openhands.llm.llm import LLM


@pytest.fixture(autouse=True)
def mock_logger(monkeypatch):
    # suppress logging of completion data to file
    mock_logger = MagicMock()
    monkeypatch.setattr('openhands.llm.debug_mixin.llm_prompt_logger', mock_logger)
    monkeypatch.setattr('openhands.llm.debug_mixin.llm_response_logger', mock_logger)
    return mock_logger


@pytest.fixture
def default_config():
    return LLMConfig(
        model='gpt-4o',
        api_key='test_key',
        num_retries=2,
        retry_min_wait=1,
        retry_max_wait=2,
    )


def test_llm_init_with_default_config(default_config):
    llm = LLM(default_config)
    assert llm.config.model == 'gpt-4o'
    assert llm.config.api_key == 'test_key'
    assert isinstance(llm.metrics, Metrics)


@patch('openhands.llm.llm.litellm.get_model_info')
def test_llm_init_with_model_info(mock_get_model_info, default_config):
    mock_get_model_info.return_value = {
        'max_input_tokens': 8000,
        'max_output_tokens': 2000,
    }
    llm = LLM(default_config)
    assert llm.config.max_input_tokens == 8000
    assert llm.config.max_output_tokens == 2000


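# If litellm cannot provide model info, the token limits should fall back to
# the 4096-token defaults.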
@patch('openhands.llm.llm.litellm.get_model_info')
def test_llm_init_without_model_info(mock_get_model_info, default_config):
    mock_get_model_info.side_effect = Exception('Model info not available')
    llm = LLM(default_config)
    assert llm.config.max_input_tokens == 4096
    assert llm.config.max_output_tokens == 4096


def test_llm_init_with_custom_config():
    custom_config = LLMConfig(
        model='custom-model',
        api_key='custom_key',
        max_input_tokens=5000,
        max_output_tokens=1500,
        temperature=0.8,
        top_p=0.9,
    )
    llm = LLM(custom_config)
    assert llm.config.model == 'custom-model'
    assert llm.config.api_key == 'custom_key'
    assert llm.config.max_input_tokens == 5000
    assert llm.config.max_output_tokens == 1500
    assert llm.config.temperature == 0.8
    assert llm.config.top_p == 0.9


def test_llm_init_with_metrics():
    config = LLMConfig(model='gpt-4o', api_key='test_key')
    metrics = Metrics()
    llm = LLM(config, metrics=metrics)
    assert llm.metrics is metrics


def test_llm_reset():
    llm = LLM(LLMConfig(model='gpt-4o-mini', api_key='test_key'))
    initial_metrics = llm.metrics
    llm.reset()
    assert llm.metrics is not initial_metrics
    assert isinstance(llm.metrics, Metrics)


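# Provider-prefixed model names (e.g. 'openrouter:...') should be passed
# through to get_model_info unchanged.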
@patch('openhands.llm.llm.litellm.get_model_info')
def test_llm_init_with_openrouter_model(mock_get_model_info, default_config):
    default_config.model = 'openrouter:gpt-4o-mini'
    mock_get_model_info.return_value = {
        'max_input_tokens': 7000,
        'max_output_tokens': 1500,
    }
    llm = LLM(default_config)
    assert llm.config.max_input_tokens == 7000
    assert llm.config.max_output_tokens == 1500
    mock_get_model_info.assert_called_once_with('openrouter:gpt-4o-mini')


# Tests involving completion and retries


@patch('openhands.llm.llm.litellm_completion')
def test_completion_with_mocked_logger(
    mock_litellm_completion, default_config, mock_logger
):
    mock_litellm_completion.return_value = {
        'choices': [{'message': {'content': 'Test response'}}]
    }
    llm = LLM(config=default_config)
    response = llm.completion(
        messages=[{'role': 'user', 'content': 'Hello!'}],
        stream=False,
    )
    assert response['choices'][0]['message']['content'] == 'Test response'
    assert mock_litellm_completion.call_count == 1
    mock_logger.debug.assert_called()


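# Each retryable exception should trigger exactly one retry: the first call
# raises, the second succeeds, so litellm_completion is invoked twice in total.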
@pytest.mark.parametrize(
    'exception_class,extra_args,expected_retries',
    [
        (
            APIConnectionError,
            {'llm_provider': 'test_provider', 'model': 'test_model'},
            2,
        ),
        (
            ContentPolicyViolationError,
            {'model': 'test_model', 'llm_provider': 'test_provider'},
            2,
        ),
        (
            InternalServerError,
            {'llm_provider': 'test_provider', 'model': 'test_model'},
            2,
        ),
        (OpenAIError, {}, 2),
        (RateLimitError, {'llm_provider': 'test_provider', 'model': 'test_model'}, 2),
    ],
)
@patch('openhands.llm.llm.litellm_completion')
def test_completion_retries(
    mock_litellm_completion,
    default_config,
    exception_class,
    extra_args,
    expected_retries,
):
    mock_litellm_completion.side_effect = [
        exception_class('Test error message', **extra_args),
        {'choices': [{'message': {'content': 'Retry successful'}}]},
    ]
    llm = LLM(config=default_config)
    response = llm.completion(
        messages=[{'role': 'user', 'content': 'Hello!'}],
        stream=False,
    )
    assert response['choices'][0]['message']['content'] == 'Retry successful'
    assert mock_litellm_completion.call_count == expected_retries


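# The backoff sleep after a RateLimitError must fall within the configured
# retry_min_wait/retry_max_wait bounds.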
@patch('openhands.llm.llm.litellm_completion')
def test_completion_rate_limit_wait_time(mock_litellm_completion, default_config):
    with patch('time.sleep') as mock_sleep:
        mock_litellm_completion.side_effect = [
            RateLimitError(
                'Rate limit exceeded', llm_provider='test_provider', model='test_model'
            ),
            {'choices': [{'message': {'content': 'Retry successful'}}]},
        ]
        llm = LLM(config=default_config)
        response = llm.completion(
            messages=[{'role': 'user', 'content': 'Hello!'}],
            stream=False,
        )
        assert response['choices'][0]['message']['content'] == 'Retry successful'
        assert mock_litellm_completion.call_count == 2

        mock_sleep.assert_called_once()
        wait_time = mock_sleep.call_args[0][0]
        assert (
            default_config.retry_min_wait <= wait_time <= default_config.retry_max_wait
        ), f'Expected wait time between {default_config.retry_min_wait} and {default_config.retry_max_wait} seconds, but got {wait_time}'


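# A persistently failing call is retried up to num_retries times, after which
# the underlying exception propagates to the caller.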
@patch('openhands.llm.llm.litellm_completion')
def test_completion_exhausts_retries(mock_litellm_completion, default_config):
    mock_litellm_completion.side_effect = APIConnectionError(
        'Persistent error', llm_provider='test_provider', model='test_model'
    )
    llm = LLM(config=default_config)
    with pytest.raises(APIConnectionError):
        llm.completion(
            messages=[{'role': 'user', 'content': 'Hello!'}],
            stream=False,
        )
    assert mock_litellm_completion.call_count == llm.config.num_retries


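# OperationCancelled is not retryable: it should propagate after a single call.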
@patch('openhands.llm.llm.litellm_completion')
def test_completion_operation_cancelled(mock_litellm_completion, default_config):
    mock_litellm_completion.side_effect = OperationCancelled('Operation cancelled')
    llm = LLM(config=default_config)
    with pytest.raises(OperationCancelled):
        llm.completion(
            messages=[{'role': 'user', 'content': 'Hello!'}],
            stream=False,
        )
    assert mock_litellm_completion.call_count == 1


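# A KeyboardInterrupt raised during completion is converted to OperationCancelled
# without triggering a retry.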
@patch('openhands.llm.llm.litellm_completion')
def test_completion_keyboard_interrupt(mock_litellm_completion, default_config):
    def side_effect(*args, **kwargs):
        raise KeyboardInterrupt('Simulated KeyboardInterrupt')

    mock_litellm_completion.side_effect = side_effect
    llm = LLM(config=default_config)
    with pytest.raises(OperationCancelled):
        try:
            llm.completion(
                messages=[{'role': 'user', 'content': 'Hello!'}],
                stream=False,
            )
        except KeyboardInterrupt:
            raise OperationCancelled('Operation cancelled due to KeyboardInterrupt')

    assert mock_litellm_completion.call_count == 1


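# If the global exit flag is set while a completion is in flight, the in-flight
# response should still be returned from that single call.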
@patch('openhands.llm.llm.litellm_completion')
def test_completion_keyboard_interrupt_handler(mock_litellm_completion, default_config):
    global _should_exit

    def side_effect(*args, **kwargs):
        global _should_exit
        _should_exit = True
        return {'choices': [{'message': {'content': 'Simulated interrupt response'}}]}

    mock_litellm_completion.side_effect = side_effect
    llm = LLM(config=default_config)
    result = llm.completion(
        messages=[{'role': 'user', 'content': 'Hello!'}],
        stream=False,
    )

    assert mock_litellm_completion.call_count == 1
    assert result['choices'][0]['message']['content'] == 'Simulated interrupt response'
    assert _should_exit

    _should_exit = False


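# Keyword arguments given to completion() should be forwarded to litellm_completion.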
@patch('openhands.llm.llm.litellm_completion')
def test_completion_with_litellm_mock(mock_litellm_completion, default_config):
    mock_response = {
        'choices': [{'message': {'content': 'This is a mocked response.'}}]
    }
    mock_litellm_completion.return_value = mock_response
    test_llm = LLM(config=default_config)
    response = test_llm.completion(
        messages=[{'role': 'user', 'content': 'Hello!'}],
        stream=False,
        drop_params=True,
    )

    # Assertions
    assert response['choices'][0]['message']['content'] == 'This is a mocked response.'
    mock_litellm_completion.assert_called_once()

    # Check if the correct arguments were passed to litellm_completion
    call_args = mock_litellm_completion.call_args[1]  # Get keyword arguments
    assert call_args['model'] == default_config.model
    assert call_args['messages'] == [{'role': 'user', 'content': 'Hello!'}]
    assert not call_args['stream']