# llm.py

import asyncio
import copy
import os
import time
import warnings
from functools import partial
from typing import Any

from openhands.core.config import LLMConfig
from openhands.runtime.utils.shutdown_listener import should_continue

with warnings.catch_warnings():
    warnings.simplefilter('ignore')
    import litellm
from litellm import completion as litellm_completion
from litellm import completion_cost as litellm_completion_cost
from litellm.exceptions import (
    APIConnectionError,
    ContentPolicyViolationError,
    InternalServerError,
    NotFoundError,
    OpenAIError,
    RateLimitError,
    ServiceUnavailableError,
)
from litellm.types.utils import CostPerToken
from tenacity import (
    retry,
    retry_if_exception_type,
    retry_if_not_exception_type,
    stop_after_attempt,
    wait_exponential,
)

from openhands.core.exceptions import (
    LLMResponseError,
    OperationCancelled,
    UserCancelledError,
)
from openhands.core.logger import llm_prompt_logger, llm_response_logger
from openhands.core.logger import openhands_logger as logger
from openhands.core.message import Message
from openhands.core.metrics import Metrics
from openhands.runtime.utils.shutdown_listener import should_exit

__all__ = ['LLM']

message_separator = '\n\n----------\n\n'

cache_prompting_supported_models = [
    'claude-3-5-sonnet-20240620',
    'claude-3-haiku-20240307',
]


class LLM:
    """The LLM class represents a Language Model instance.

    Attributes:
        config: an LLMConfig object specifying the configuration of the LLM.
    """

    def __init__(
        self,
        config: LLMConfig,
        metrics: Metrics | None = None,
    ):
        """Initializes the LLM. If LLMConfig is passed, its values will be the fallback.

        Passing simple parameters always overrides config.

        Args:
            config: The LLM configuration.
            metrics: The metrics to use; a fresh Metrics instance is created if omitted.
        """
        self.metrics = metrics if metrics is not None else Metrics()
        self.cost_metric_supported = True
        self.config = copy.deepcopy(config)

        os.environ['OR_SITE_URL'] = self.config.openrouter_site_url
        os.environ['OR_APP_NAME'] = self.config.openrouter_app_name

        # list of LLM completions (for logging purposes). Each completion is a dict with the following keys:
        # - 'messages': list of messages
        # - 'response': response from the LLM
        self.llm_completions: list[dict[str, Any]] = []

        # Set up config attributes with default values to prevent AttributeError
        LLMConfig.set_missing_attributes(self.config)

        # litellm actually uses base Exception here for unknown model
        self.model_info = None
        try:
            if self.config.model.startswith('openrouter'):
                self.model_info = litellm.get_model_info(self.config.model)
            else:
                self.model_info = litellm.get_model_info(
                    self.config.model.split(':')[0]
                )
        # noinspection PyBroadException
        except Exception as e:
            logger.warning(f'Could not get model info for {config.model}:\n{e}')

        # Tuple of exceptions to retry on
        self.retry_exceptions = (
            APIConnectionError,
            ContentPolicyViolationError,
            InternalServerError,
            OpenAIError,
            RateLimitError,
        )

        # Set the max tokens in an LM-specific way if not set
        if self.config.max_input_tokens is None:
            if (
                self.model_info is not None
                and 'max_input_tokens' in self.model_info
                and isinstance(self.model_info['max_input_tokens'], int)
            ):
                self.config.max_input_tokens = self.model_info['max_input_tokens']
            else:
                # Safe fallback for any potentially viable model
                self.config.max_input_tokens = 4096

        if self.config.max_output_tokens is None:
            # Safe default for any potentially viable model
            self.config.max_output_tokens = 4096
            if self.model_info is not None:
                # max_output_tokens has precedence over max_tokens, if either exists.
                # litellm has models with both, one or none of these 2 parameters!
                if 'max_output_tokens' in self.model_info and isinstance(
                    self.model_info['max_output_tokens'], int
                ):
                    self.config.max_output_tokens = self.model_info['max_output_tokens']
                elif 'max_tokens' in self.model_info and isinstance(
                    self.model_info['max_tokens'], int
                ):
                    self.config.max_output_tokens = self.model_info['max_tokens']

        # This only seems to work with Google as the provider, not with OpenRouter!
        gemini_safety_settings = (
            [
                {
                    'category': 'HARM_CATEGORY_HARASSMENT',
                    'threshold': 'BLOCK_NONE',
                },
                {
                    'category': 'HARM_CATEGORY_HATE_SPEECH',
                    'threshold': 'BLOCK_NONE',
                },
                {
                    'category': 'HARM_CATEGORY_SEXUALLY_EXPLICIT',
                    'threshold': 'BLOCK_NONE',
                },
                {
                    'category': 'HARM_CATEGORY_DANGEROUS_CONTENT',
                    'threshold': 'BLOCK_NONE',
                },
            ]
            if self.config.model.lower().startswith('gemini')
            else None
        )

        self._completion = partial(
            litellm_completion,
            model=self.config.model,
            api_key=self.config.api_key,
            base_url=self.config.base_url,
            api_version=self.config.api_version,
            custom_llm_provider=self.config.custom_llm_provider,
            max_tokens=self.config.max_output_tokens,
            timeout=self.config.timeout,
            temperature=self.config.temperature,
            top_p=self.config.top_p,
            drop_params=self.config.drop_params,
            **(
                {'safety_settings': gemini_safety_settings}
                if gemini_safety_settings is not None
                else {}
            ),
        )

        if self.vision_is_active():
            logger.debug('LLM: model has vision enabled')

        completion_unwrapped = self._completion

        def log_retry_attempt(retry_state):
            """With before_sleep, this is called before `custom_completion_wait` and
            ONLY if the retry is triggered by an exception."""
            if should_exit():
                raise OperationCancelled(
                    'Operation cancelled.'
                )  # exits the @retry loop
            exception = retry_state.outcome.exception()
            logger.error(
                f'{exception}. Attempt #{retry_state.attempt_number} | You can customize retry values in the configuration.',
                exc_info=False,
            )

        def custom_completion_wait(retry_state):
            """Custom wait function for litellm completion."""
            if not retry_state:
                return 0
            exception = retry_state.outcome.exception() if retry_state.outcome else None
            if exception is None:
                return 0

            min_wait_time = self.config.retry_min_wait
            max_wait_time = self.config.retry_max_wait

            # for rate limit errors, wait 1 minute by default, max 4 minutes between retries
            exception_type = type(exception).__name__
            logger.error(f'\nexception_type: {exception_type}\n')
            if exception_type == 'RateLimitError':
                min_wait_time = 60
                max_wait_time = 240
            elif exception_type == 'BadRequestError' and exception.response:
                # this should give us the buried, actual error message from
                # the LLM model
                logger.error(f'\n\nBadRequestError: {exception.response}\n\n')

            # Return the wait time using exponential backoff
            exponential_wait = wait_exponential(
                multiplier=self.config.retry_multiplier,
                min=min_wait_time,
                max=max_wait_time,
            )

            # Call the exponential wait function with retry_state to get the actual wait time
            return exponential_wait(retry_state)

        @retry(
            before_sleep=log_retry_attempt,
            stop=stop_after_attempt(self.config.num_retries),
            reraise=True,
            retry=(
                retry_if_exception_type(self.retry_exceptions)
                & retry_if_not_exception_type(OperationCancelled)
            ),
            wait=custom_completion_wait,
        )
        def wrapper(*args, **kwargs):
            """Wrapper for the litellm completion function. Logs the input and output of the completion function."""
            # some callers might just send the messages directly
            if 'messages' in kwargs:
                messages = kwargs['messages']
            else:
                messages = args[1] if len(args) > 1 else []

            # this serves to prevent empty messages and logging the messages
            debug_message = self._get_debug_message(messages)

            if self.is_caching_prompt_active():
                # Anthropic-specific prompt caching
                if 'claude-3' in self.config.model:
                    kwargs['extra_headers'] = {
                        'anthropic-beta': 'prompt-caching-2024-07-31',
                    }

            # skip if messages is empty (thus debug_message is empty)
            if debug_message:
                llm_prompt_logger.debug(debug_message)
                resp = completion_unwrapped(*args, **kwargs)
            else:
                logger.debug('No completion messages!')
                resp = {'choices': [{'message': {'content': ''}}]}

            if self.config.log_completions:
                self.llm_completions.append(
                    {
                        'messages': messages,
                        'response': resp,
                        'timestamp': time.time(),
                        'cost': self.completion_cost(resp),
                    }
                )

            # log the response
            message_back = resp['choices'][0]['message']['content']
            if message_back:
                llm_response_logger.debug(message_back)

            # post-process to log costs
            self._post_completion(resp)
            return resp

        self._completion = wrapper  # type: ignore

        # Async version
        self._async_completion = partial(
            self._call_acompletion,
            model=self.config.model,
            api_key=self.config.api_key,
            base_url=self.config.base_url,
            api_version=self.config.api_version,
            custom_llm_provider=self.config.custom_llm_provider,
            max_tokens=self.config.max_output_tokens,
            timeout=self.config.timeout,
            temperature=self.config.temperature,
            top_p=self.config.top_p,
            drop_params=self.config.drop_params,
            **(
                {'safety_settings': gemini_safety_settings}
                if gemini_safety_settings is not None
                else {}
            ),
        )

        async_completion_unwrapped = self._async_completion

        @retry(
            before_sleep=log_retry_attempt,
            stop=stop_after_attempt(self.config.num_retries),
            reraise=True,
            retry=(
                retry_if_exception_type(self.retry_exceptions)
                & retry_if_not_exception_type(OperationCancelled)
            ),
            wait=custom_completion_wait,
        )
        async def async_completion_wrapper(*args, **kwargs):
            """Async wrapper for the litellm acompletion function."""
            # some callers might just send the messages directly
            if 'messages' in kwargs:
                messages = kwargs['messages']
            else:
                messages = args[1] if len(args) > 1 else []

            # this serves to prevent empty messages and logging the messages
            debug_message = self._get_debug_message(messages)

            async def check_stopped():
                while should_continue():
                    if (
                        hasattr(self.config, 'on_cancel_requested_fn')
                        and self.config.on_cancel_requested_fn is not None
                        and await self.config.on_cancel_requested_fn()
                    ):
                        raise UserCancelledError('LLM request cancelled by user')
                    await asyncio.sleep(0.1)

            stop_check_task = asyncio.create_task(check_stopped())

            try:
                # Directly call and await litellm_acompletion
                if debug_message:
                    llm_prompt_logger.debug(debug_message)
                    resp = await async_completion_unwrapped(*args, **kwargs)
                else:
                    logger.debug('No completion messages!')
                    resp = {'choices': [{'message': {'content': ''}}]}

                # skip if messages is empty (thus debug_message is empty)
                if debug_message:
                    message_back = resp['choices'][0]['message']['content']
                    llm_response_logger.debug(message_back)
                else:
                    resp = {'choices': [{'message': {'content': ''}}]}
                self._post_completion(resp)

                # We do not support streaming in this method, thus return resp
                return resp
            except UserCancelledError:
                logger.info('LLM request cancelled by user.')
                raise
            except (
                APIConnectionError,
                ContentPolicyViolationError,
                InternalServerError,
                NotFoundError,
                OpenAIError,
                RateLimitError,
                ServiceUnavailableError,
            ) as e:
                logger.error(f'Completion Error occurred:\n{e}')
                raise
            finally:
                await asyncio.sleep(0.1)
                stop_check_task.cancel()
                try:
                    await stop_check_task
                except asyncio.CancelledError:
                    pass

        @retry(
            before_sleep=log_retry_attempt,
            stop=stop_after_attempt(self.config.num_retries),
            reraise=True,
            retry=(
                retry_if_exception_type(self.retry_exceptions)
                & retry_if_not_exception_type(OperationCancelled)
            ),
            wait=custom_completion_wait,
        )
        async def async_acompletion_stream_wrapper(*args, **kwargs):
            """Async wrapper for the litellm acompletion with streaming function."""
            # some callers might just send the messages directly
            if 'messages' in kwargs:
                messages = kwargs['messages']
            else:
                messages = args[1] if len(args) > 1 else []

            # log the prompt
            debug_message = ''
            for message in messages:
                debug_message += message_separator + message['content']
            llm_prompt_logger.debug(debug_message)

            try:
                # Directly call and await litellm_acompletion
                resp = await async_completion_unwrapped(*args, **kwargs)

                # For streaming we iterate over the chunks
                async for chunk in resp:
                    # Check for cancellation before yielding the chunk
                    if (
                        hasattr(self.config, 'on_cancel_requested_fn')
                        and self.config.on_cancel_requested_fn is not None
                        and await self.config.on_cancel_requested_fn()
                    ):
                        raise UserCancelledError(
                            'LLM request cancelled due to CANCELLED state'
                        )
                    # with streaming, it is "delta", not "message"!
                    message_back = chunk['choices'][0]['delta']['content']
                    llm_response_logger.debug(message_back)
                    self._post_completion(chunk)
                    yield chunk
            except UserCancelledError:
                logger.info('LLM request cancelled by user.')
                raise
            except (
                APIConnectionError,
                ContentPolicyViolationError,
                InternalServerError,
                NotFoundError,
                OpenAIError,
                RateLimitError,
                ServiceUnavailableError,
            ) as e:
                logger.error(f'Completion Error occurred:\n{e}')
                raise
            finally:
                if kwargs.get('stream', False):
                    await asyncio.sleep(0.1)

        self._async_completion = async_completion_wrapper  # type: ignore
        self._async_streaming_completion = async_acompletion_stream_wrapper  # type: ignore

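    # Note on cancellation: the async wrappers above poll an optional async hook,
    # `config.on_cancel_requested_fn`. A caller could wire one up roughly like this
    # (a sketch only; `cancel_event` is a hypothetical asyncio.Event owned by the
    # caller, not part of this module):
    #
    #     cancel_event = asyncio.Event()
    #
    #     async def on_cancel_requested() -> bool:
    #         return cancel_event.is_set()
    #
    #     llm.config.on_cancel_requested_fn = on_cancel_requested
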
    def _get_debug_message(self, messages):
        if not messages:
            return ''

        messages = messages if isinstance(messages, list) else [messages]
        return message_separator.join(
            self._format_message_content(msg) for msg in messages if msg['content']
        )

    def _format_message_content(self, message):
        content = message['content']
        if isinstance(content, list):
            return self._format_list_content(content)
        return str(content)

    def _format_list_content(self, content_list):
        return '\n'.join(
            self._format_content_element(element) for element in content_list
        )

    def _format_content_element(self, element):
        if isinstance(element, dict):
            if 'text' in element:
                return element['text']
            if (
                self.vision_is_active()
                and 'image_url' in element
                and 'url' in element['image_url']
            ):
                return element['image_url']['url']
        return str(element)

    async def _call_acompletion(self, *args, **kwargs):
        """This is a wrapper for the litellm acompletion function which
        makes it mockable for testing.
        """
        return await litellm.acompletion(*args, **kwargs)

    @property
    def completion(self):
        """Decorator for the litellm completion function.

        Check the complete documentation at https://litellm.vercel.app/docs/completion
        """
        try:
            return self._completion
        except Exception as e:
            raise LLMResponseError(e)

    @property
    def async_completion(self):
        """Decorator for the async litellm acompletion function.

        Check the complete documentation at https://litellm.vercel.app/docs/providers/ollama#example-usage---streaming--acompletion
        """
        try:
            return self._async_completion
        except Exception as e:
            raise LLMResponseError(e)

    @property
    def async_streaming_completion(self):
        """Decorator for the async litellm acompletion function with streaming.

        Check the complete documentation at https://litellm.vercel.app/docs/providers/ollama#example-usage---streaming--acompletion
        """
        try:
            return self._async_streaming_completion
        except Exception as e:
            raise LLMResponseError(e)

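    # A minimal consumption sketch for the streaming variant (an illustration, not
    # part of the module; assumes a running event loop and a LiteLLM-routable model):
    #
    #     async def stream_reply(llm: 'LLM', messages: list[dict]) -> str:
    #         text = ''
    #         async for chunk in llm.async_streaming_completion(
    #             messages=messages, stream=True
    #         ):
    #             text += chunk['choices'][0]['delta']['content'] or ''
    #         return text
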
    def vision_is_active(self):
        return not self.config.disable_vision and self._supports_vision()

    def _supports_vision(self):
        """Acquire from litellm if model is vision capable.

        Returns:
            bool: True if model is vision capable. If model is not supported by litellm, it will return False.
        """
        try:
            return litellm.supports_vision(self.config.model)
        except Exception:
            return False

    def is_caching_prompt_active(self) -> bool:
        """Check if prompt caching is enabled and supported for the current model.

        Returns:
            boolean: True if prompt caching is active for the given model.
        """
        return self.config.caching_prompt is True and any(
            model in self.config.model for model in cache_prompting_supported_models
        )

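    # Prompt caching is opt-in: with `config.caching_prompt = True` and a model that
    # matches `cache_prompting_supported_models` (e.g. 'claude-3-5-sonnet-20240620'),
    # the sync wrapper above adds the 'anthropic-beta: prompt-caching-2024-07-31' header.
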
    def _post_completion(self, response) -> None:
        """Post-process the completion response."""
        try:
            cur_cost = self.completion_cost(response)
        except Exception:
            cur_cost = 0

        stats = ''
        if self.cost_metric_supported:
            stats = 'Cost: %.2f USD | Accumulated Cost: %.2f USD\n' % (
                cur_cost,
                self.metrics.accumulated_cost,
            )

        usage = response.get('usage')
        if usage:
            input_tokens = usage.get('prompt_tokens')
            output_tokens = usage.get('completion_tokens')

            if input_tokens:
                stats += 'Input tokens: ' + str(input_tokens)

            if output_tokens:
                stats += (
                    (' | ' if input_tokens else '')
                    + 'Output tokens: '
                    + str(output_tokens)
                    + '\n'
                )

            model_extra = usage.get('model_extra', {})

            cache_creation_input_tokens = model_extra.get('cache_creation_input_tokens')
            if cache_creation_input_tokens:
                stats += (
                    'Input tokens (cache write): '
                    + str(cache_creation_input_tokens)
                    + '\n'
                )

            cache_read_input_tokens = model_extra.get('cache_read_input_tokens')
            if cache_read_input_tokens:
                stats += (
                    'Input tokens (cache read): ' + str(cache_read_input_tokens) + '\n'
                )

        if stats:
            logger.info(stats)

    def get_token_count(self, messages):
        """Get the number of tokens in a list of messages.

        Args:
            messages (list): A list of messages.

        Returns:
            int: The number of tokens.
        """
        try:
            return litellm.token_counter(model=self.config.model, messages=messages)
        except Exception:
            # TODO: this is to limit logspam in case token count is not supported
            return 0

    def is_local(self):
        """Determines if the system is using a locally running LLM.

        Returns:
            boolean: True if executing a local model.
        """
        if self.config.base_url is not None:
            for substring in ['localhost', '127.0.0.1', '0.0.0.0']:
                if substring in self.config.base_url:
                    return True
        elif self.config.model is not None:
            if self.config.model.startswith('ollama'):
                return True
        return False

    def completion_cost(self, response):
        """Calculate the cost of a completion response based on the model. Local models are treated as free.

        Add the current cost into total cost in metrics.

        Args:
            response: A response from a model invocation.

        Returns:
            number: The cost of the response.
        """
        if not self.cost_metric_supported:
            return 0.0

        extra_kwargs = {}
        if (
            self.config.input_cost_per_token is not None
            and self.config.output_cost_per_token is not None
        ):
            cost_per_token = CostPerToken(
                input_cost_per_token=self.config.input_cost_per_token,
                output_cost_per_token=self.config.output_cost_per_token,
            )
            logger.info(f'Using custom cost per token: {cost_per_token}')
            extra_kwargs['custom_cost_per_token'] = cost_per_token

        if not self.is_local():
            try:
                cost = litellm_completion_cost(
                    completion_response=response, **extra_kwargs
                )
                self.metrics.add_cost(cost)
                return cost
            except Exception:
                self.cost_metric_supported = False
                logger.warning('Cost calculation not supported for this model.')
        return 0.0

    def __str__(self):
        if self.config.api_version:
            return f'LLM(model={self.config.model}, api_version={self.config.api_version}, base_url={self.config.base_url})'
        elif self.config.base_url:
            return f'LLM(model={self.config.model}, base_url={self.config.base_url})'
        return f'LLM(model={self.config.model})'

    def __repr__(self):
        return str(self)

    def reset(self):
        self.metrics = Metrics()
        self.llm_completions = []

    def format_messages_for_llm(self, messages: Message | list[Message]) -> list[dict]:
        if isinstance(messages, Message):
            return [messages.model_dump()]
        return [message.model_dump() for message in messages]
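

# Example usage (a minimal sketch, not part of this module): the model name and the
# environment variable below are assumptions; any LiteLLM-routable model with
# matching credentials should work the same way.
#
#     config = LLMConfig(model='gpt-4o', api_key=os.environ.get('OPENAI_API_KEY'))
#     llm = LLM(config)
#     messages = [{'role': 'user', 'content': 'Write a haiku about retries.'}]
#     resp = llm.completion(messages=messages)
#     print(resp['choices'][0]['message']['content'])
#     print(f'Accumulated cost: {llm.metrics.accumulated_cost}')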