test_llm.py
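"""Unit tests for openhands.llm.llm.LLM.

Covers LLM construction and config handling, model-info lookup, metrics and
response-latency tracking, retry behaviour on transient litellm errors,
cancellation/keyboard-interrupt handling, CloudFlare blockage detection, and
token counting.

Rough usage sketch of the API under test (illustrative only; the names mirror
the imports below, and the exact signatures are whatever the real classes
define):

    config = LLMConfig(model='gpt-4o', api_key='...')
    llm = LLM(config)
    response = llm.completion(messages=[{'role': 'user', 'content': 'Hi'}])
    print(response['choices'][0]['message']['content'])
"""
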
import copy
from unittest.mock import MagicMock, patch

import pytest
from litellm.exceptions import (
    APIConnectionError,
    InternalServerError,
    RateLimitError,
    ServiceUnavailableError,
)

from openhands.core.config import LLMConfig
from openhands.core.exceptions import OperationCancelled
from openhands.core.message import Message, TextContent
from openhands.llm.llm import LLM
from openhands.llm.metrics import Metrics


@pytest.fixture(autouse=True)
def mock_logger(monkeypatch):
    # suppress logging of completion data to file
    mock_logger = MagicMock()
    monkeypatch.setattr('openhands.llm.debug_mixin.llm_prompt_logger', mock_logger)
    monkeypatch.setattr('openhands.llm.debug_mixin.llm_response_logger', mock_logger)
    monkeypatch.setattr('openhands.llm.llm.logger', mock_logger)
    return mock_logger


@pytest.fixture
def default_config():
    return LLMConfig(
        model='gpt-4o',
        api_key='test_key',
        num_retries=2,
        retry_min_wait=1,
        retry_max_wait=2,
    )
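
# The retry settings above are kept small (2 attempts, 1-2 second waits) so the
# retry tests further down can exercise the retry path quickly.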


def test_llm_init_with_default_config(default_config):
    llm = LLM(default_config)
    assert llm.config.model == 'gpt-4o'
    assert llm.config.api_key == 'test_key'
    assert isinstance(llm.metrics, Metrics)
    assert llm.metrics.model_name == 'gpt-4o'


@patch('openhands.llm.llm.litellm.get_model_info')
def test_llm_init_with_model_info(mock_get_model_info, default_config):
    mock_get_model_info.return_value = {
        'max_input_tokens': 8000,
        'max_output_tokens': 2000,
    }
    llm = LLM(default_config)
    llm.init_model_info()
    assert llm.config.max_input_tokens == 8000
    assert llm.config.max_output_tokens == 2000


@patch('openhands.llm.llm.litellm.get_model_info')
def test_llm_init_without_model_info(mock_get_model_info, default_config):
    mock_get_model_info.side_effect = Exception('Model info not available')
    llm = LLM(default_config)
    llm.init_model_info()
    assert llm.config.max_input_tokens == 4096
    assert llm.config.max_output_tokens == 4096


def test_llm_init_with_custom_config():
    custom_config = LLMConfig(
        model='custom-model',
        api_key='custom_key',
        max_input_tokens=5000,
        max_output_tokens=1500,
        temperature=0.8,
        top_p=0.9,
    )
    llm = LLM(custom_config)
    assert llm.config.model == 'custom-model'
    assert llm.config.api_key == 'custom_key'
    assert llm.config.max_input_tokens == 5000
    assert llm.config.max_output_tokens == 1500
    assert llm.config.temperature == 0.8
    assert llm.config.top_p == 0.9


def test_llm_init_with_metrics():
    config = LLMConfig(model='gpt-4o', api_key='test_key')
    metrics = Metrics()
    llm = LLM(config, metrics=metrics)
    assert llm.metrics is metrics
    assert (
        llm.metrics.model_name == 'default'
    )  # because we didn't specify model_name in Metrics init


@patch('openhands.llm.llm.litellm_completion')
@patch('time.time')
def test_response_latency_tracking(mock_time, mock_litellm_completion):
    # Mock time.time() to return controlled values
    mock_time.side_effect = [1000.0, 1002.5]  # Start time, end time (2.5s difference)

    # Mock the completion response with a specific ID
    mock_response = {
        'id': 'test-response-123',
        'choices': [{'message': {'content': 'Test response'}}],
    }
    mock_litellm_completion.return_value = mock_response

    # Create LLM instance and make a completion call
    config = LLMConfig(model='gpt-4o', api_key='test_key')
    llm = LLM(config)
    response = llm.completion(messages=[{'role': 'user', 'content': 'Hello!'}])

    # Verify the response latency was tracked correctly
    assert len(llm.metrics.response_latencies) == 1
    latency_record = llm.metrics.response_latencies[0]
    assert latency_record.model == 'gpt-4o'
    assert (
        latency_record.latency == 2.5
    )  # Should be the difference between our mocked times
    assert latency_record.response_id == 'test-response-123'

    # Verify the completion response was returned correctly
    assert response['id'] == 'test-response-123'
    assert response['choices'][0]['message']['content'] == 'Test response'


def test_llm_reset():
    llm = LLM(LLMConfig(model='gpt-4o-mini', api_key='test_key'))
    initial_metrics = copy.deepcopy(llm.metrics)
    initial_metrics.add_cost(1.0)
    initial_metrics.add_response_latency(0.5, 'test-id')
    llm.reset()
    assert llm.metrics._accumulated_cost != initial_metrics._accumulated_cost
    assert llm.metrics._costs != initial_metrics._costs
    assert llm.metrics._response_latencies != initial_metrics._response_latencies
    assert isinstance(llm.metrics, Metrics)
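
# Note: the deep copy (not llm.metrics itself) is the object mutated above, so
# after reset() the live metrics should still be empty and therefore differ
# from the mutated snapshot.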


@patch('openhands.llm.llm.litellm.get_model_info')
def test_llm_init_with_openrouter_model(mock_get_model_info, default_config):
    default_config.model = 'openrouter:gpt-4o-mini'
    mock_get_model_info.return_value = {
        'max_input_tokens': 7000,
        'max_output_tokens': 1500,
    }
    llm = LLM(default_config)
    llm.init_model_info()
    assert llm.config.max_input_tokens == 7000
    assert llm.config.max_output_tokens == 1500
    mock_get_model_info.assert_called_once_with('openrouter:gpt-4o-mini')


# Tests involving completion and retries


@patch('openhands.llm.llm.litellm_completion')
def test_completion_with_mocked_logger(
    mock_litellm_completion, default_config, mock_logger
):
    mock_litellm_completion.return_value = {
        'choices': [{'message': {'content': 'Test response'}}]
    }
    llm = LLM(config=default_config)
    response = llm.completion(
        messages=[{'role': 'user', 'content': 'Hello!'}],
        stream=False,
    )
    assert response['choices'][0]['message']['content'] == 'Test response'
    assert mock_litellm_completion.call_count == 1
    mock_logger.debug.assert_called()


@pytest.mark.parametrize(
    'exception_class,extra_args,expected_retries',
    [
        (
            APIConnectionError,
            {'llm_provider': 'test_provider', 'model': 'test_model'},
            2,
        ),
        (
            InternalServerError,
            {'llm_provider': 'test_provider', 'model': 'test_model'},
            2,
        ),
        (
            ServiceUnavailableError,
            {'llm_provider': 'test_provider', 'model': 'test_model'},
            2,
        ),
        (RateLimitError, {'llm_provider': 'test_provider', 'model': 'test_model'}, 2),
    ],
)
@patch('openhands.llm.llm.litellm_completion')
def test_completion_retries(
    mock_litellm_completion,
    default_config,
    exception_class,
    extra_args,
    expected_retries,
):
    mock_litellm_completion.side_effect = [
        exception_class('Test error message', **extra_args),
        {'choices': [{'message': {'content': 'Retry successful'}}]},
    ]
    llm = LLM(config=default_config)
    response = llm.completion(
        messages=[{'role': 'user', 'content': 'Hello!'}],
        stream=False,
    )
    assert response['choices'][0]['message']['content'] == 'Retry successful'
    assert mock_litellm_completion.call_count == expected_retries
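
# expected_retries == 2 here means "one failed attempt plus one successful
# retry", matching num_retries=2 in default_config: the first call raises a
# transient litellm error and the second returns the mocked success payload.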


@patch('openhands.llm.llm.litellm_completion')
def test_completion_rate_limit_wait_time(mock_litellm_completion, default_config):
    with patch('time.sleep') as mock_sleep:
        mock_litellm_completion.side_effect = [
            RateLimitError(
                'Rate limit exceeded', llm_provider='test_provider', model='test_model'
            ),
            {'choices': [{'message': {'content': 'Retry successful'}}]},
        ]
        llm = LLM(config=default_config)
        response = llm.completion(
            messages=[{'role': 'user', 'content': 'Hello!'}],
            stream=False,
        )

        assert response['choices'][0]['message']['content'] == 'Retry successful'
        assert mock_litellm_completion.call_count == 2

        mock_sleep.assert_called_once()
        wait_time = mock_sleep.call_args[0][0]
        assert (
            default_config.retry_min_wait <= wait_time <= default_config.retry_max_wait
        ), f'Expected wait time between {default_config.retry_min_wait} and {default_config.retry_max_wait} seconds, but got {wait_time}'
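
# The single time.sleep call is the backoff between the failed attempt and the
# retry; the test only pins it to the [retry_min_wait, retry_max_wait] window
# rather than an exact value, since the backoff may include jitter.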


@patch('openhands.llm.llm.litellm_completion')
def test_completion_exhausts_retries(mock_litellm_completion, default_config):
    mock_litellm_completion.side_effect = APIConnectionError(
        'Persistent error', llm_provider='test_provider', model='test_model'
    )
    llm = LLM(config=default_config)
    with pytest.raises(APIConnectionError):
        llm.completion(
            messages=[{'role': 'user', 'content': 'Hello!'}],
            stream=False,
        )
    assert mock_litellm_completion.call_count == llm.config.num_retries
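
# Once every attempt has failed, the original exception is re-raised to the
# caller; the call count equals num_retries, i.e. no attempts are made beyond
# the configured budget.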


@patch('openhands.llm.llm.litellm_completion')
def test_completion_operation_cancelled(mock_litellm_completion, default_config):
    mock_litellm_completion.side_effect = OperationCancelled('Operation cancelled')
    llm = LLM(config=default_config)
    with pytest.raises(OperationCancelled):
        llm.completion(
            messages=[{'role': 'user', 'content': 'Hello!'}],
            stream=False,
        )
    assert mock_litellm_completion.call_count == 1


@patch('openhands.llm.llm.litellm_completion')
def test_completion_keyboard_interrupt(mock_litellm_completion, default_config):
    def side_effect(*args, **kwargs):
        raise KeyboardInterrupt('Simulated KeyboardInterrupt')

    mock_litellm_completion.side_effect = side_effect
    llm = LLM(config=default_config)
    with pytest.raises(OperationCancelled):
        try:
            llm.completion(
                messages=[{'role': 'user', 'content': 'Hello!'}],
                stream=False,
            )
        except KeyboardInterrupt:
            raise OperationCancelled('Operation cancelled due to KeyboardInterrupt')

    assert mock_litellm_completion.call_count == 1
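
# The try/except inside pytest.raises converts the simulated KeyboardInterrupt
# into OperationCancelled within the test itself, mirroring (as far as this
# test is concerned) how a caller is expected to translate Ctrl-C into a
# cancellation; neither exception triggers a retry.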


@patch('openhands.llm.llm.litellm_completion')
def test_completion_keyboard_interrupt_handler(mock_litellm_completion, default_config):
    global _should_exit

    def side_effect(*args, **kwargs):
        global _should_exit
        _should_exit = True
        return {'choices': [{'message': {'content': 'Simulated interrupt response'}}]}

    mock_litellm_completion.side_effect = side_effect
    llm = LLM(config=default_config)
    result = llm.completion(
        messages=[{'role': 'user', 'content': 'Hello!'}],
        stream=False,
    )

    assert mock_litellm_completion.call_count == 1
    assert result['choices'][0]['message']['content'] == 'Simulated interrupt response'
    assert _should_exit

    _should_exit = False


@patch('openhands.llm.llm.litellm_completion')
def test_completion_with_litellm_mock(mock_litellm_completion, default_config):
    mock_response = {
        'choices': [{'message': {'content': 'This is a mocked response.'}}]
    }
    mock_litellm_completion.return_value = mock_response
    test_llm = LLM(config=default_config)
    response = test_llm.completion(
        messages=[{'role': 'user', 'content': 'Hello!'}],
        stream=False,
        drop_params=True,
    )

    # Assertions
    assert response['choices'][0]['message']['content'] == 'This is a mocked response.'
    mock_litellm_completion.assert_called_once()

    # Check if the correct arguments were passed to litellm_completion
    call_args = mock_litellm_completion.call_args[1]  # Get keyword arguments
    assert call_args['model'] == default_config.model
    assert call_args['messages'] == [{'role': 'user', 'content': 'Hello!'}]
    assert not call_args['stream']


@patch('openhands.llm.llm.litellm_completion')
def test_completion_with_two_positional_args(mock_litellm_completion, default_config):
    mock_response = {
        'choices': [{'message': {'content': 'Response to positional args.'}}]
    }
    mock_litellm_completion.return_value = mock_response
    test_llm = LLM(config=default_config)
    response = test_llm.completion(
        'some-model-to-be-ignored',
        [{'role': 'user', 'content': 'Hello from positional args!'}],
        stream=False,
    )

    # Assertions
    assert (
        response['choices'][0]['message']['content'] == 'Response to positional args.'
    )
    mock_litellm_completion.assert_called_once()

    # Check if the correct arguments were passed to litellm_completion
    call_args, call_kwargs = mock_litellm_completion.call_args
    assert (
        call_kwargs['model'] == default_config.model
    )  # Should use the model from config, not the first arg
    assert call_kwargs['messages'] == [
        {'role': 'user', 'content': 'Hello from positional args!'}
    ]
    assert not call_kwargs['stream']

    # Ensure the first positional argument (model) was ignored
    assert (
        len(call_args) == 0
    )  # No positional args should be passed to litellm_completion here
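
# Positional (model, messages) arguments are accepted for convenience, but by
# the time the call reaches litellm_completion everything has been remapped to
# keyword arguments and the configured model wins over the positional one.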


@patch('openhands.llm.llm.litellm_completion')
def test_llm_cloudflare_blockage(mock_litellm_completion, default_config):
    from litellm.exceptions import APIError
    from openhands.core.exceptions import CloudFlareBlockageError

    llm = LLM(default_config)
    mock_litellm_completion.side_effect = APIError(
        message='Attention Required! | Cloudflare',
        llm_provider='test_provider',
        model='test_model',
        status_code=403,
    )
    with pytest.raises(CloudFlareBlockageError, match='Request blocked by CloudFlare'):
        llm.completion(messages=[{'role': 'user', 'content': 'Hello'}])

    # Ensure the completion was called
    mock_litellm_completion.assert_called_once()
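
# An APIError whose message looks like a Cloudflare challenge page ("Attention
# Required! | Cloudflare", HTTP 403) is surfaced as CloudFlareBlockageError
# and, per the assert_called_once above, is not retried.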


@patch('openhands.llm.llm.litellm.token_counter')
def test_get_token_count_with_dict_messages(mock_token_counter, default_config):
    mock_token_counter.return_value = 42
    llm = LLM(default_config)
    messages = [{'role': 'user', 'content': 'Hello!'}]

    token_count = llm.get_token_count(messages)

    assert token_count == 42
    mock_token_counter.assert_called_once_with(
        model=default_config.model, messages=messages, custom_tokenizer=None
    )


@patch('openhands.llm.llm.litellm.token_counter')
def test_get_token_count_with_message_objects(
    mock_token_counter, default_config, mock_logger
):
    llm = LLM(default_config)

    # Create a Message object and its equivalent dict
    message_obj = Message(role='user', content=[TextContent(text='Hello!')])
    message_dict = {'role': 'user', 'content': 'Hello!'}

    # Mock the token counter to return the same value for both calls
    mock_token_counter.side_effect = [42, 42]

    # Get token counts for both formats
    token_count_obj = llm.get_token_count([message_obj])
    token_count_dict = llm.get_token_count([message_dict])

    # Verify both formats get the same token count
    assert token_count_obj == token_count_dict
    assert mock_token_counter.call_count == 2


@patch('openhands.llm.llm.litellm.token_counter')
@patch('openhands.llm.llm.create_pretrained_tokenizer')
def test_get_token_count_with_custom_tokenizer(
    mock_create_tokenizer, mock_token_counter, default_config
):
    mock_tokenizer = MagicMock()
    mock_create_tokenizer.return_value = mock_tokenizer
    mock_token_counter.return_value = 42

    config = copy.deepcopy(default_config)
    config.custom_tokenizer = 'custom/tokenizer'
    llm = LLM(config)
    messages = [{'role': 'user', 'content': 'Hello!'}]

    token_count = llm.get_token_count(messages)

    assert token_count == 42
    mock_create_tokenizer.assert_called_once_with('custom/tokenizer')
    mock_token_counter.assert_called_once_with(
        model=config.model, messages=messages, custom_tokenizer=mock_tokenizer
    )


@patch('openhands.llm.llm.litellm.token_counter')
def test_get_token_count_error_handling(
    mock_token_counter, default_config, mock_logger
):
    mock_token_counter.side_effect = Exception('Token counting failed')
    llm = LLM(default_config)
    messages = [{'role': 'user', 'content': 'Hello!'}]

    token_count = llm.get_token_count(messages)

    assert token_count == 0
    mock_token_counter.assert_called_once()
    mock_logger.error.assert_called_once_with(
        'Error getting token count for\n model gpt-4o\nToken counting failed'
    )
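
# Token-counting failures are swallowed rather than propagated: get_token_count
# falls back to 0 and logs the error (via the mocked logger above) instead of
# raising.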