# test_llm.py

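"""Unit tests for the LLM wrapper in openhands.llm.llm.

Covers initialization and model-info handling, completion calls with retries,
response-latency metrics, and token counting.
"""
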
import copy
from unittest.mock import MagicMock, patch

import pytest
from litellm.exceptions import (
    APIConnectionError,
    InternalServerError,
    RateLimitError,
    ServiceUnavailableError,
)

from openhands.core.config import LLMConfig
from openhands.core.exceptions import OperationCancelled
from openhands.core.message import Message, TextContent
from openhands.llm.llm import LLM
from openhands.llm.metrics import Metrics


@pytest.fixture(autouse=True)
def mock_logger(monkeypatch):
    # suppress logging of completion data to file
    mock_logger = MagicMock()
    monkeypatch.setattr('openhands.llm.debug_mixin.llm_prompt_logger', mock_logger)
    monkeypatch.setattr('openhands.llm.debug_mixin.llm_response_logger', mock_logger)
    monkeypatch.setattr('openhands.llm.llm.logger', mock_logger)
    return mock_logger


@pytest.fixture
def default_config():
    return LLMConfig(
        model='gpt-4o',
        api_key='test_key',
        num_retries=2,
        retry_min_wait=1,
        retry_max_wait=2,
    )


def test_llm_init_with_default_config(default_config):
    llm = LLM(default_config)
    assert llm.config.model == 'gpt-4o'
    assert llm.config.api_key == 'test_key'
    assert isinstance(llm.metrics, Metrics)
    assert llm.metrics.model_name == 'gpt-4o'


@patch('openhands.llm.llm.litellm.get_model_info')
def test_llm_init_with_model_info(mock_get_model_info, default_config):
    mock_get_model_info.return_value = {
        'max_input_tokens': 8000,
        'max_output_tokens': 2000,
    }
    llm = LLM(default_config)
    llm.init_model_info()
    assert llm.config.max_input_tokens == 8000
    assert llm.config.max_output_tokens == 2000


@patch('openhands.llm.llm.litellm.get_model_info')
def test_llm_init_without_model_info(mock_get_model_info, default_config):
    mock_get_model_info.side_effect = Exception('Model info not available')
    llm = LLM(default_config)
    llm.init_model_info()
    assert llm.config.max_input_tokens == 4096
    assert llm.config.max_output_tokens == 4096


def test_llm_init_with_custom_config():
    custom_config = LLMConfig(
        model='custom-model',
        api_key='custom_key',
        max_input_tokens=5000,
        max_output_tokens=1500,
        temperature=0.8,
        top_p=0.9,
    )
    llm = LLM(custom_config)
    assert llm.config.model == 'custom-model'
    assert llm.config.api_key == 'custom_key'
    assert llm.config.max_input_tokens == 5000
    assert llm.config.max_output_tokens == 1500
    assert llm.config.temperature == 0.8
    assert llm.config.top_p == 0.9


def test_llm_init_with_metrics():
    config = LLMConfig(model='gpt-4o', api_key='test_key')
    metrics = Metrics()
    llm = LLM(config, metrics=metrics)
    assert llm.metrics is metrics
    assert (
        llm.metrics.model_name == 'default'
    )  # because we didn't specify model_name in Metrics init


@patch('openhands.llm.llm.litellm_completion')
@patch('time.time')
def test_response_latency_tracking(mock_time, mock_litellm_completion):
    # Mock time.time() to return controlled values
    mock_time.side_effect = [1000.0, 1002.5]  # Start time, end time (2.5s difference)

    # Mock the completion response with a specific ID
    mock_response = {
        'id': 'test-response-123',
        'choices': [{'message': {'content': 'Test response'}}],
    }
    mock_litellm_completion.return_value = mock_response

    # Create LLM instance and make a completion call
    config = LLMConfig(model='gpt-4o', api_key='test_key')
    llm = LLM(config)
    response = llm.completion(messages=[{'role': 'user', 'content': 'Hello!'}])

    # Verify the response latency was tracked correctly
    assert len(llm.metrics.response_latencies) == 1
    latency_record = llm.metrics.response_latencies[0]
    assert latency_record.model == 'gpt-4o'
    assert (
        latency_record.latency == 2.5
    )  # Should be the difference between our mocked times
    assert latency_record.response_id == 'test-response-123'

    # Verify the completion response was returned correctly
    assert response['id'] == 'test-response-123'
    assert response['choices'][0]['message']['content'] == 'Test response'

    # To make sure the metrics fail gracefully, set the start/end time to go backwards.
    mock_time.side_effect = [1000.0, 999.0]
    llm.completion(messages=[{'role': 'user', 'content': 'Hello!'}])

    # There should now be 2 latencies, the last of which has the value clipped to 0
    assert len(llm.metrics.response_latencies) == 2
    latency_record = llm.metrics.response_latencies[-1]
    assert latency_record.latency == 0.0  # Should be lifted to 0 instead of being -1!


def test_llm_reset():
    llm = LLM(LLMConfig(model='gpt-4o-mini', api_key='test_key'))
    initial_metrics = copy.deepcopy(llm.metrics)
    initial_metrics.add_cost(1.0)
    initial_metrics.add_response_latency(0.5, 'test-id')
    llm.reset()
    assert llm.metrics._accumulated_cost != initial_metrics._accumulated_cost
    assert llm.metrics._costs != initial_metrics._costs
    assert llm.metrics._response_latencies != initial_metrics._response_latencies
    assert isinstance(llm.metrics, Metrics)


@patch('openhands.llm.llm.litellm.get_model_info')
def test_llm_init_with_openrouter_model(mock_get_model_info, default_config):
    default_config.model = 'openrouter:gpt-4o-mini'
    mock_get_model_info.return_value = {
        'max_input_tokens': 7000,
        'max_output_tokens': 1500,
    }
    llm = LLM(default_config)
    llm.init_model_info()
    assert llm.config.max_input_tokens == 7000
    assert llm.config.max_output_tokens == 1500
    mock_get_model_info.assert_called_once_with('openrouter:gpt-4o-mini')


# Tests involving completion and retries


@patch('openhands.llm.llm.litellm_completion')
def test_completion_with_mocked_logger(
    mock_litellm_completion, default_config, mock_logger
):
    mock_litellm_completion.return_value = {
        'choices': [{'message': {'content': 'Test response'}}]
    }
    llm = LLM(config=default_config)
    response = llm.completion(
        messages=[{'role': 'user', 'content': 'Hello!'}],
        stream=False,
    )
    assert response['choices'][0]['message']['content'] == 'Test response'
    assert mock_litellm_completion.call_count == 1
    mock_logger.debug.assert_called()


@pytest.mark.parametrize(
    'exception_class,extra_args,expected_retries',
    [
        (
            APIConnectionError,
            {'llm_provider': 'test_provider', 'model': 'test_model'},
            2,
        ),
        (
            InternalServerError,
            {'llm_provider': 'test_provider', 'model': 'test_model'},
            2,
        ),
        (
            ServiceUnavailableError,
            {'llm_provider': 'test_provider', 'model': 'test_model'},
            2,
        ),
        (RateLimitError, {'llm_provider': 'test_provider', 'model': 'test_model'}, 2),
    ],
)
@patch('openhands.llm.llm.litellm_completion')
def test_completion_retries(
    mock_litellm_completion,
    default_config,
    exception_class,
    extra_args,
    expected_retries,
):
    mock_litellm_completion.side_effect = [
        exception_class('Test error message', **extra_args),
        {'choices': [{'message': {'content': 'Retry successful'}}]},
    ]
    llm = LLM(config=default_config)
    response = llm.completion(
        messages=[{'role': 'user', 'content': 'Hello!'}],
        stream=False,
    )
    assert response['choices'][0]['message']['content'] == 'Retry successful'
    assert mock_litellm_completion.call_count == expected_retries


@patch('openhands.llm.llm.litellm_completion')
def test_completion_rate_limit_wait_time(mock_litellm_completion, default_config):
    with patch('time.sleep') as mock_sleep:
        mock_litellm_completion.side_effect = [
            RateLimitError(
                'Rate limit exceeded', llm_provider='test_provider', model='test_model'
            ),
            {'choices': [{'message': {'content': 'Retry successful'}}]},
        ]
        llm = LLM(config=default_config)
        response = llm.completion(
            messages=[{'role': 'user', 'content': 'Hello!'}],
            stream=False,
        )
        assert response['choices'][0]['message']['content'] == 'Retry successful'
        assert mock_litellm_completion.call_count == 2

        mock_sleep.assert_called_once()
        wait_time = mock_sleep.call_args[0][0]
        assert (
            default_config.retry_min_wait <= wait_time <= default_config.retry_max_wait
        ), f'Expected wait time between {default_config.retry_min_wait} and {default_config.retry_max_wait} seconds, but got {wait_time}'


@patch('openhands.llm.llm.litellm_completion')
def test_completion_exhausts_retries(mock_litellm_completion, default_config):
    mock_litellm_completion.side_effect = APIConnectionError(
        'Persistent error', llm_provider='test_provider', model='test_model'
    )
    llm = LLM(config=default_config)
    with pytest.raises(APIConnectionError):
        llm.completion(
            messages=[{'role': 'user', 'content': 'Hello!'}],
            stream=False,
        )
    assert mock_litellm_completion.call_count == llm.config.num_retries


@patch('openhands.llm.llm.litellm_completion')
def test_completion_operation_cancelled(mock_litellm_completion, default_config):
    mock_litellm_completion.side_effect = OperationCancelled('Operation cancelled')
    llm = LLM(config=default_config)
    with pytest.raises(OperationCancelled):
        llm.completion(
            messages=[{'role': 'user', 'content': 'Hello!'}],
            stream=False,
        )
    assert mock_litellm_completion.call_count == 1


@patch('openhands.llm.llm.litellm_completion')
def test_completion_keyboard_interrupt(mock_litellm_completion, default_config):
    def side_effect(*args, **kwargs):
        raise KeyboardInterrupt('Simulated KeyboardInterrupt')

    mock_litellm_completion.side_effect = side_effect
    llm = LLM(config=default_config)
    with pytest.raises(OperationCancelled):
        try:
            llm.completion(
                messages=[{'role': 'user', 'content': 'Hello!'}],
                stream=False,
            )
        except KeyboardInterrupt:
            raise OperationCancelled('Operation cancelled due to KeyboardInterrupt')

    assert mock_litellm_completion.call_count == 1


@patch('openhands.llm.llm.litellm_completion')
def test_completion_keyboard_interrupt_handler(mock_litellm_completion, default_config):
    global _should_exit

    def side_effect(*args, **kwargs):
        global _should_exit
        _should_exit = True
        return {'choices': [{'message': {'content': 'Simulated interrupt response'}}]}

    mock_litellm_completion.side_effect = side_effect
    llm = LLM(config=default_config)
    result = llm.completion(
        messages=[{'role': 'user', 'content': 'Hello!'}],
        stream=False,
    )

    assert mock_litellm_completion.call_count == 1
    assert result['choices'][0]['message']['content'] == 'Simulated interrupt response'
    assert _should_exit

    _should_exit = False


@patch('openhands.llm.llm.litellm_completion')
def test_completion_with_litellm_mock(mock_litellm_completion, default_config):
    mock_response = {
        'choices': [{'message': {'content': 'This is a mocked response.'}}]
    }
    mock_litellm_completion.return_value = mock_response
    test_llm = LLM(config=default_config)
    response = test_llm.completion(
        messages=[{'role': 'user', 'content': 'Hello!'}],
        stream=False,
        drop_params=True,
    )

    # Assertions
    assert response['choices'][0]['message']['content'] == 'This is a mocked response.'
    mock_litellm_completion.assert_called_once()

    # Check if the correct arguments were passed to litellm_completion
    call_args = mock_litellm_completion.call_args[1]  # Get keyword arguments
    assert call_args['model'] == default_config.model
    assert call_args['messages'] == [{'role': 'user', 'content': 'Hello!'}]
    assert not call_args['stream']


@patch('openhands.llm.llm.litellm_completion')
def test_completion_with_two_positional_args(mock_litellm_completion, default_config):
    mock_response = {
        'choices': [{'message': {'content': 'Response to positional args.'}}]
    }
    mock_litellm_completion.return_value = mock_response
    test_llm = LLM(config=default_config)
    response = test_llm.completion(
        'some-model-to-be-ignored',
        [{'role': 'user', 'content': 'Hello from positional args!'}],
        stream=False,
    )

    # Assertions
    assert (
        response['choices'][0]['message']['content'] == 'Response to positional args.'
    )
    mock_litellm_completion.assert_called_once()

    # Check if the correct arguments were passed to litellm_completion
    call_args, call_kwargs = mock_litellm_completion.call_args
    assert (
        call_kwargs['model'] == default_config.model
    )  # Should use the model from config, not the first arg
    assert call_kwargs['messages'] == [
        {'role': 'user', 'content': 'Hello from positional args!'}
    ]
    assert not call_kwargs['stream']

    # Ensure the first positional argument (model) was ignored
    assert (
        len(call_args) == 0
    )  # No positional args should be passed to litellm_completion here


@patch('openhands.llm.llm.litellm_completion')
def test_llm_cloudflare_blockage(mock_litellm_completion, default_config):
    from litellm.exceptions import APIError

    from openhands.core.exceptions import CloudFlareBlockageError

    llm = LLM(default_config)
    mock_litellm_completion.side_effect = APIError(
        message='Attention Required! | Cloudflare',
        llm_provider='test_provider',
        model='test_model',
        status_code=403,
    )
    with pytest.raises(CloudFlareBlockageError, match='Request blocked by CloudFlare'):
        llm.completion(messages=[{'role': 'user', 'content': 'Hello'}])

    # Ensure the completion was called
    mock_litellm_completion.assert_called_once()
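

# Tests for token counting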


@patch('openhands.llm.llm.litellm.token_counter')
def test_get_token_count_with_dict_messages(mock_token_counter, default_config):
    mock_token_counter.return_value = 42
    llm = LLM(default_config)
    messages = [{'role': 'user', 'content': 'Hello!'}]

    token_count = llm.get_token_count(messages)

    assert token_count == 42
    mock_token_counter.assert_called_once_with(
        model=default_config.model, messages=messages, custom_tokenizer=None
    )


@patch('openhands.llm.llm.litellm.token_counter')
def test_get_token_count_with_message_objects(
    mock_token_counter, default_config, mock_logger
):
    llm = LLM(default_config)

    # Create a Message object and its equivalent dict
    message_obj = Message(role='user', content=[TextContent(text='Hello!')])
    message_dict = {'role': 'user', 'content': 'Hello!'}

    # Mock token counter to return different values for each call
    mock_token_counter.side_effect = [42, 42]  # Same value for both cases

    # Get token counts for both formats
    token_count_obj = llm.get_token_count([message_obj])
    token_count_dict = llm.get_token_count([message_dict])

    # Verify both formats get the same token count
    assert token_count_obj == token_count_dict
    assert mock_token_counter.call_count == 2


@patch('openhands.llm.llm.litellm.token_counter')
@patch('openhands.llm.llm.create_pretrained_tokenizer')
def test_get_token_count_with_custom_tokenizer(
    mock_create_tokenizer, mock_token_counter, default_config
):
    mock_tokenizer = MagicMock()
    mock_create_tokenizer.return_value = mock_tokenizer
    mock_token_counter.return_value = 42

    config = copy.deepcopy(default_config)
    config.custom_tokenizer = 'custom/tokenizer'
    llm = LLM(config)
    messages = [{'role': 'user', 'content': 'Hello!'}]

    token_count = llm.get_token_count(messages)

    assert token_count == 42
    mock_create_tokenizer.assert_called_once_with('custom/tokenizer')
    mock_token_counter.assert_called_once_with(
        model=config.model, messages=messages, custom_tokenizer=mock_tokenizer
    )


@patch('openhands.llm.llm.litellm.token_counter')
def test_get_token_count_error_handling(
    mock_token_counter, default_config, mock_logger
):
    mock_token_counter.side_effect = Exception('Token counting failed')
    llm = LLM(default_config)
    messages = [{'role': 'user', 'content': 'Hello!'}]

    token_count = llm.get_token_count(messages)

    assert token_count == 0
    mock_token_counter.assert_called_once()
    mock_logger.error.assert_called_once_with(
        'Error getting token count for\n model gpt-4o\nToken counting failed'
    )