llm.py

import asyncio
import copy
import warnings
from functools import partial
from typing import Union

from openhands.core.config import LLMConfig

with warnings.catch_warnings():
    warnings.simplefilter('ignore')
    import litellm

from litellm import completion as litellm_completion
from litellm import completion_cost as litellm_completion_cost
from litellm.exceptions import (
    APIConnectionError,
    ContentPolicyViolationError,
    InternalServerError,
    NotFoundError,
    OpenAIError,
    RateLimitError,
    ServiceUnavailableError,
)
from litellm.types.utils import CostPerToken
from tenacity import (
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)

from openhands.core.exceptions import LLMResponseError, UserCancelledError
from openhands.core.logger import llm_prompt_logger, llm_response_logger
from openhands.core.logger import openhands_logger as logger
from openhands.core.message import Message, format_messages
from openhands.core.metrics import Metrics

__all__ = ['LLM']

message_separator = '\n\n----------\n\n'

cache_prompting_supported_models = [
    'claude-3-5-sonnet-20240620',
    'claude-3-haiku-20240307',
]


class LLM:
    """The LLM class represents a Language Model instance.

    Attributes:
        config: an LLMConfig object specifying the configuration of the LLM.
    """

    def __init__(
        self,
        config: LLMConfig,
        metrics: Metrics | None = None,
    ):
        """Initializes the LLM. If LLMConfig is passed, its values will be the fallback.

        Passing simple parameters always overrides config.

        Args:
            config: The LLM configuration
        """
        self.metrics = metrics if metrics is not None else Metrics()
        self.cost_metric_supported = True
        self.config = copy.deepcopy(config)

        # Set up config attributes with default values to prevent AttributeError
        LLMConfig.set_missing_attributes(self.config)

        # litellm actually uses base Exception here for unknown model
        self.model_info = None
        try:
            if self.config.model.startswith('openrouter'):
                self.model_info = litellm.get_model_info(self.config.model)
            else:
                self.model_info = litellm.get_model_info(
                    self.config.model.split(':')[0]
                )
        # noinspection PyBroadException
        except Exception as e:
            logger.warning(f'Could not get model info for {config.model}:\n{e}')

        # Tuple of exceptions to retry on
        self.retry_exceptions = (
            APIConnectionError,
            ContentPolicyViolationError,
            InternalServerError,
            OpenAIError,
            RateLimitError,
        )

        # Set the max tokens in an LM-specific way if not set
        if self.config.max_input_tokens is None:
            if (
                self.model_info is not None
                and 'max_input_tokens' in self.model_info
                and isinstance(self.model_info['max_input_tokens'], int)
            ):
                self.config.max_input_tokens = self.model_info['max_input_tokens']
            else:
                # Safe fallback for any potentially viable model
                self.config.max_input_tokens = 4096

        if self.config.max_output_tokens is None:
            # Safe default for any potentially viable model
            self.config.max_output_tokens = 4096
            if self.model_info is not None:
                # max_output_tokens has precedence over max_tokens, if either exists.
                # litellm has models with both, one or none of these 2 parameters!
                if 'max_output_tokens' in self.model_info and isinstance(
                    self.model_info['max_output_tokens'], int
                ):
                    self.config.max_output_tokens = self.model_info['max_output_tokens']
                elif 'max_tokens' in self.model_info and isinstance(
                    self.model_info['max_tokens'], int
                ):
                    self.config.max_output_tokens = self.model_info['max_tokens']

        if self.config.drop_params:
            litellm.drop_params = self.config.drop_params

        # This only seems to work with Google as the provider, not with OpenRouter!
        gemini_safety_settings = (
            [
                {
                    'category': 'HARM_CATEGORY_HARASSMENT',
                    'threshold': 'BLOCK_NONE',
                },
                {
                    'category': 'HARM_CATEGORY_HATE_SPEECH',
                    'threshold': 'BLOCK_NONE',
                },
                {
                    'category': 'HARM_CATEGORY_SEXUALLY_EXPLICIT',
                    'threshold': 'BLOCK_NONE',
                },
                {
                    'category': 'HARM_CATEGORY_DANGEROUS_CONTENT',
                    'threshold': 'BLOCK_NONE',
                },
            ]
            if self.config.model.lower().startswith('gemini')
            else None
        )

        self._completion = partial(
            litellm_completion,
            model=self.config.model,
            api_key=self.config.api_key,
            base_url=self.config.base_url,
            api_version=self.config.api_version,
            custom_llm_provider=self.config.custom_llm_provider,
            max_tokens=self.config.max_output_tokens,
            timeout=self.config.timeout,
            temperature=self.config.temperature,
            top_p=self.config.top_p,
            **(
                {'safety_settings': gemini_safety_settings}
                if gemini_safety_settings is not None
                else {}
            ),
        )

        if self.vision_is_active():
            logger.debug('LLM: model has vision enabled')

        completion_unwrapped = self._completion

        def attempt_on_error(retry_state):
            """Custom attempt function for litellm completion."""
            logger.error(
                f'{retry_state.outcome.exception()}. Attempt #{retry_state.attempt_number} | You can customize retry values in the configuration.',
                exc_info=False,
            )
            return None

        def custom_completion_wait(retry_state):
            """Custom wait function for litellm completion."""
            if not retry_state:
                return 0
            exception = retry_state.outcome.exception() if retry_state.outcome else None
            if exception is None:
                return 0

            min_wait_time = self.config.retry_min_wait
            max_wait_time = self.config.retry_max_wait

            # for rate limit errors, wait 1 minute by default, max 4 minutes between retries
            exception_type = type(exception).__name__
            logger.error(f'\nexception_type: {exception_type}\n')
            if exception_type == 'RateLimitError':
                min_wait_time = 60
                max_wait_time = 240
            elif exception_type == 'BadRequestError' and exception.response:
                # this should surface the buried, actual error message from the LLM
                logger.error(f'\n\nBadRequestError: {exception.response}\n\n')

            # Return the wait time using exponential backoff
            exponential_wait = wait_exponential(
                multiplier=self.config.retry_multiplier,
                min=min_wait_time,
                max=max_wait_time,
            )

            # Call the exponential wait function with retry_state to get the actual wait time
            return exponential_wait(retry_state)
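
        # Note (descriptive, not original to this file): the wrappers below retry
        # only on self.retry_exceptions. tenacity's wait_exponential scales the
        # delay by config.retry_multiplier and clamps it between the min/max wait
        # values; for RateLimitError the window above is widened to 60-240 seconds.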
        @retry(
            after=attempt_on_error,
            stop=stop_after_attempt(self.config.num_retries),
            reraise=True,
            retry=retry_if_exception_type(self.retry_exceptions),
            wait=custom_completion_wait,
        )
        def wrapper(*args, **kwargs):
            """Wrapper for the litellm completion function. Logs the input and output of the completion function."""
            # some callers might just send the messages directly
            if 'messages' in kwargs:
                messages = kwargs['messages']
            else:
                messages = args[1] if len(args) > 1 else []

            # this serves to prevent empty messages and logging the messages
            debug_message = self._get_debug_message(messages)

            if self.is_caching_prompt_active():
                # Anthropic-specific prompt caching
                if 'claude-3' in self.config.model:
                    kwargs['extra_headers'] = {
                        'anthropic-beta': 'prompt-caching-2024-07-31',
                    }

            # skip if messages is empty (thus debug_message is empty)
            if debug_message:
                llm_prompt_logger.debug(debug_message)
                resp = completion_unwrapped(*args, **kwargs)
            else:
                logger.debug('No completion messages!')
                resp = {'choices': [{'message': {'content': ''}}]}

            # log the response
            message_back = resp['choices'][0]['message']['content']
            if message_back:
                llm_response_logger.debug(message_back)

            # post-process to log costs
            self._post_completion(resp)
            return resp

        self._completion = wrapper  # type: ignore

        # Async version
        self._async_completion = partial(
            self._call_acompletion,
            model=self.config.model,
            api_key=self.config.api_key,
            base_url=self.config.base_url,
            api_version=self.config.api_version,
            custom_llm_provider=self.config.custom_llm_provider,
            max_tokens=self.config.max_output_tokens,
            timeout=self.config.timeout,
            temperature=self.config.temperature,
            top_p=self.config.top_p,
            drop_params=True,
            **(
                {'safety_settings': gemini_safety_settings}
                if gemini_safety_settings is not None
                else {}
            ),
        )

        async_completion_unwrapped = self._async_completion

        @retry(
            after=attempt_on_error,
            stop=stop_after_attempt(self.config.num_retries),
            reraise=True,
            retry=retry_if_exception_type(self.retry_exceptions),
            wait=custom_completion_wait,
        )
        async def async_completion_wrapper(*args, **kwargs):
            """Async wrapper for the litellm acompletion function."""
            # some callers might just send the messages directly
            if 'messages' in kwargs:
                messages = kwargs['messages']
            else:
                messages = args[1] if len(args) > 1 else []

            # this serves to prevent empty messages and logging the messages
            debug_message = self._get_debug_message(messages)

            async def check_stopped():
                while True:
                    if (
                        hasattr(self.config, 'on_cancel_requested_fn')
                        and self.config.on_cancel_requested_fn is not None
                        and await self.config.on_cancel_requested_fn()
                    ):
                        raise UserCancelledError('LLM request cancelled by user')
                    await asyncio.sleep(0.1)

            stop_check_task = asyncio.create_task(check_stopped())

            try:
                # Directly call and await litellm_acompletion
                if debug_message:
                    llm_prompt_logger.debug(debug_message)
                    resp = await async_completion_unwrapped(*args, **kwargs)
                else:
                    logger.debug('No completion messages!')
                    resp = {'choices': [{'message': {'content': ''}}]}

                # skip if messages is empty (thus debug_message is empty)
                if debug_message:
                    message_back = resp['choices'][0]['message']['content']
                    llm_response_logger.debug(message_back)
                else:
                    resp = {'choices': [{'message': {'content': ''}}]}

                self._post_completion(resp)

                # We do not support streaming in this method, thus return resp
                return resp
            except UserCancelledError:
                logger.info('LLM request cancelled by user.')
                raise
            except (
                APIConnectionError,
                ContentPolicyViolationError,
                InternalServerError,
                NotFoundError,
                OpenAIError,
                RateLimitError,
                ServiceUnavailableError,
            ) as e:
                logger.error(f'Completion Error occurred:\n{e}')
                raise
            finally:
                await asyncio.sleep(0.1)
                stop_check_task.cancel()
                try:
                    await stop_check_task
                except asyncio.CancelledError:
                    pass

        @retry(
            after=attempt_on_error,
            stop=stop_after_attempt(self.config.num_retries),
            reraise=True,
            retry=retry_if_exception_type(self.retry_exceptions),
            wait=custom_completion_wait,
        )
        async def async_acompletion_stream_wrapper(*args, **kwargs):
            """Async wrapper for the litellm acompletion with streaming function."""
            # some callers might just send the messages directly
            if 'messages' in kwargs:
                messages = kwargs['messages']
            else:
                messages = args[1] if len(args) > 1 else []

            # log the prompt
            debug_message = ''
            for message in messages:
                debug_message += message_separator + message['content']
            llm_prompt_logger.debug(debug_message)

            try:
                # Directly call and await litellm_acompletion
                resp = await async_completion_unwrapped(*args, **kwargs)

                # For streaming we iterate over the chunks
                async for chunk in resp:
                    # Check for cancellation before yielding the chunk
                    if (
                        hasattr(self.config, 'on_cancel_requested_fn')
                        and self.config.on_cancel_requested_fn is not None
                        and await self.config.on_cancel_requested_fn()
                    ):
                        raise UserCancelledError(
                            'LLM request cancelled due to CANCELLED state'
                        )
                    # with streaming, it is "delta", not "message"!
                    message_back = chunk['choices'][0]['delta']['content']
                    llm_response_logger.debug(message_back)

                    self._post_completion(chunk)
                    yield chunk
            except UserCancelledError:
                logger.info('LLM request cancelled by user.')
                raise
            except (
                APIConnectionError,
                ContentPolicyViolationError,
                InternalServerError,
                NotFoundError,
                OpenAIError,
                RateLimitError,
                ServiceUnavailableError,
            ) as e:
                logger.error(f'Completion Error occurred:\n{e}')
                raise
            finally:
                if kwargs.get('stream', False):
                    await asyncio.sleep(0.1)

        self._async_completion = async_completion_wrapper  # type: ignore
        self._async_streaming_completion = async_acompletion_stream_wrapper  # type: ignore

    def _get_debug_message(self, messages):
        if not messages:
            return ''

        messages = messages if isinstance(messages, list) else [messages]
        return message_separator.join(
            self._format_message_content(msg) for msg in messages if msg['content']
        )

    def _format_message_content(self, message):
        content = message['content']
        if isinstance(content, list):
            return self._format_list_content(content)
        return str(content)

    def _format_list_content(self, content_list):
        return '\n'.join(
            self._format_content_element(element) for element in content_list
        )

    def _format_content_element(self, element):
        if isinstance(element, dict):
            if 'text' in element:
                return element['text']
            if (
                self.vision_is_active()
                and 'image_url' in element
                and 'url' in element['image_url']
            ):
                return element['image_url']['url']
        return str(element)

    async def _call_acompletion(self, *args, **kwargs):
        return await litellm.acompletion(*args, **kwargs)

    @property
    def completion(self):
        """Decorator for the litellm completion function.

        Check the complete documentation at https://litellm.vercel.app/docs/completion
        """
        try:
            return self._completion
        except Exception as e:
            raise LLMResponseError(e)
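
    # Rough usage sketch (descriptive comment, not original to this file); message
    # dicts follow the litellm/OpenAI chat format already assumed by the wrapper:
    #
    #   resp = llm.completion(messages=[{'role': 'user', 'content': 'Hello'}])
    #   text = resp['choices'][0]['message']['content']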

    @property
    def async_completion(self):
        """Decorator for the async litellm acompletion function.

        Check the complete documentation at https://litellm.vercel.app/docs/providers/ollama#example-usage---streaming--acompletion
        """
        try:
            return self._async_completion
        except Exception as e:
            raise LLMResponseError(e)

    @property
    def async_streaming_completion(self):
        """Decorator for the async litellm acompletion function with streaming.

        Check the complete documentation at https://litellm.vercel.app/docs/providers/ollama#example-usage---streaming--acompletion
        """
        try:
            return self._async_streaming_completion
        except Exception as e:
            raise LLMResponseError(e)
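
    # Rough async usage sketch (descriptive comment, not original to this file);
    # the streaming variant is an async generator and reads 'delta' chunks, so it
    # is called with stream=True and iterated:
    #
    #   resp = await llm.async_completion(messages=messages)
    #   async for chunk in llm.async_streaming_completion(messages=messages, stream=True):
    #       print(chunk['choices'][0]['delta']['content'])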

    def vision_is_active(self):
        return not self.config.disable_vision and self._supports_vision()

    def _supports_vision(self):
        """Acquire from litellm if model is vision capable.

        Returns:
            bool: True if model is vision capable. If model is not supported by litellm, it will return False.
        """
        try:
            return litellm.supports_vision(self.config.model)
        except Exception:
            return False

    def is_caching_prompt_active(self) -> bool:
        """Check if prompt caching is enabled and supported for current model.

        Returns:
            boolean: True if prompt caching is active for the given model.
        """
        return self.config.caching_prompt is True and any(
            model in self.config.model for model in cache_prompting_supported_models
        )
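
    # Prompt caching is opt-in: it requires config.caching_prompt to be True and
    # the configured model to match one of cache_prompting_supported_models; the
    # sync wrapper then adds the Anthropic 'prompt-caching-2024-07-31' beta header.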

    def _post_completion(self, response) -> None:
        """Post-process the completion response."""
        try:
            cur_cost = self.completion_cost(response)
        except Exception:
            cur_cost = 0

        stats = ''
        if self.cost_metric_supported:
            stats = 'Cost: %.2f USD | Accumulated Cost: %.2f USD\n' % (
                cur_cost,
                self.metrics.accumulated_cost,
            )

        usage = response.get('usage')
        if usage:
            input_tokens = usage.get('prompt_tokens')
            output_tokens = usage.get('completion_tokens')
            if input_tokens:
                stats += 'Input tokens: ' + str(input_tokens) + '\n'
            if output_tokens:
                stats += 'Output tokens: ' + str(output_tokens) + '\n'

            model_extra = usage.get('model_extra', {})

            cache_creation_input_tokens = model_extra.get('cache_creation_input_tokens')
            if cache_creation_input_tokens:
                stats += (
                    'Input tokens (cache write): '
                    + str(cache_creation_input_tokens)
                    + '\n'
                )

            cache_read_input_tokens = model_extra.get('cache_read_input_tokens')
            if cache_read_input_tokens:
                stats += (
                    'Input tokens (cache read): ' + str(cache_read_input_tokens) + '\n'
                )

        if stats:
            logger.info(stats)

    def get_token_count(self, messages):
        """Get the number of tokens in a list of messages.

        Args:
            messages (list): A list of messages.

        Returns:
            int: The number of tokens.
        """
        try:
            return litellm.token_counter(model=self.config.model, messages=messages)
        except Exception:
            # TODO: this is to limit logspam in case token count is not supported
            return 0

    def is_local(self):
        """Determines if the system is using a locally running LLM.

        Returns:
            boolean: True if executing a local model.
        """
        if self.config.base_url is not None:
            for substring in ['localhost', '127.0.0.1', '0.0.0.0']:
                if substring in self.config.base_url:
                    return True
        elif self.config.model is not None:
            if self.config.model.startswith('ollama'):
                return True
        return False

    def completion_cost(self, response):
        """Calculate the cost of a completion response based on the model. Local models are treated as free.

        Add the current cost into total cost in metrics.

        Args:
            response: A response from a model invocation.

        Returns:
            number: The cost of the response.
        """
        if not self.cost_metric_supported:
            return 0.0

        extra_kwargs = {}
        if (
            self.config.input_cost_per_token is not None
            and self.config.output_cost_per_token is not None
        ):
            cost_per_token = CostPerToken(
                input_cost_per_token=self.config.input_cost_per_token,
                output_cost_per_token=self.config.output_cost_per_token,
            )
            logger.info(f'Using custom cost per token: {cost_per_token}')
            extra_kwargs['custom_cost_per_token'] = cost_per_token

        if not self.is_local():
            try:
                cost = litellm_completion_cost(
                    completion_response=response, **extra_kwargs
                )
                self.metrics.add_cost(cost)
                return cost
            except Exception:
                self.cost_metric_supported = False
                logger.warning('Cost calculation not supported for this model.')
        return 0.0
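
    # If litellm has no pricing for the model, per-token prices can be supplied
    # via config.input_cost_per_token / config.output_cost_per_token, which are
    # forwarded to litellm's completion_cost as custom_cost_per_token above.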

    def __str__(self):
        if self.config.api_version:
            return f'LLM(model={self.config.model}, api_version={self.config.api_version}, base_url={self.config.base_url})'
        elif self.config.base_url:
            return f'LLM(model={self.config.model}, base_url={self.config.base_url})'
        return f'LLM(model={self.config.model})'

    def __repr__(self):
        return str(self)

    def reset(self):
        self.metrics = Metrics()

    def format_messages_for_llm(
        self, messages: Union[Message, list[Message]]
    ) -> list[dict]:
        return format_messages(
            messages, self.vision_is_active(), self.is_caching_prompt_active()
        )
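

# Minimal usage sketch, not part of the module's API: it assumes LLMConfig accepts
# a `model` keyword (matching how the attribute is read throughout this class) and
# that provider credentials are already configured. The model name below is only a
# hypothetical example.
if __name__ == '__main__':
    llm = LLM(LLMConfig(model='gpt-4o-mini'))
    resp = llm.completion(messages=[{'role': 'user', 'content': 'Say hello.'}])
    print(resp['choices'][0]['message']['content'])
    # Metrics accumulate costs across calls when litellm can price the model.
    print(f'Accumulated cost: {llm.metrics.accumulated_cost}')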