llm.py

import asyncio
import copy
import os
import time
import warnings
from functools import partial
from typing import Any

from openhands.core.config import LLMConfig
from openhands.runtime.utils.shutdown_listener import should_continue

with warnings.catch_warnings():
    warnings.simplefilter('ignore')
    import litellm

from litellm import completion as litellm_completion
from litellm import completion_cost as litellm_completion_cost
from litellm.exceptions import (
    APIConnectionError,
    ContentPolicyViolationError,
    InternalServerError,
    NotFoundError,
    OpenAIError,
    RateLimitError,
    ServiceUnavailableError,
)
from litellm.types.utils import CostPerToken
from tenacity import (
    retry,
    retry_if_exception_type,
    retry_if_not_exception_type,
    stop_after_attempt,
    wait_exponential,
)

from openhands.core.exceptions import (
    OperationCancelled,
    UserCancelledError,
)
from openhands.core.logger import llm_prompt_logger, llm_response_logger
from openhands.core.logger import openhands_logger as logger
from openhands.core.message import Message
from openhands.core.metrics import Metrics
from openhands.runtime.utils.shutdown_listener import should_exit

__all__ = ['LLM']

message_separator = '\n\n----------\n\n'

cache_prompting_supported_models = [
    'claude-3-5-sonnet-20240620',
    'claude-3-haiku-20240307',
]


class LLM:
    """The LLM class represents a Language Model instance.

    Attributes:
        config: an LLMConfig object specifying the configuration of the LLM.
    """

    def __init__(
        self,
        config: LLMConfig,
        metrics: Metrics | None = None,
    ):
        """Initializes the LLM. If LLMConfig is passed, its values will be the fallback.

        Passing simple parameters always overrides config.

        Args:
            config: The LLM configuration.
            metrics: The metrics object to record costs into; a fresh Metrics instance is created if None.
        """
        self.metrics = metrics if metrics is not None else Metrics()
        self.cost_metric_supported = True
        self.config = copy.deepcopy(config)

        os.environ['OR_SITE_URL'] = self.config.openrouter_site_url
        os.environ['OR_APP_NAME'] = self.config.openrouter_app_name

        # list of LLM completions (for logging purposes). Each completion is a dict with the following keys:
        # - 'messages': list of messages
        # - 'response': response from the LLM
        self.llm_completions: list[dict[str, Any]] = []

        # Set up config attributes with default values to prevent AttributeError
        LLMConfig.set_missing_attributes(self.config)

        # litellm actually uses base Exception here for unknown model
        self.model_info = None
        try:
            if self.config.model.startswith('openrouter'):
                self.model_info = litellm.get_model_info(self.config.model)
            else:
                self.model_info = litellm.get_model_info(
                    self.config.model.split(':')[0]
                )
        # noinspection PyBroadException
        except Exception as e:
            logger.warning(f'Could not get model info for {config.model}:\n{e}')

        # Tuple of exceptions to retry on
        self.retry_exceptions = (
            APIConnectionError,
            ContentPolicyViolationError,
            InternalServerError,
            OpenAIError,
            RateLimitError,
        )

        # Set the max tokens in an LM-specific way if not set
        if self.config.max_input_tokens is None:
            if (
                self.model_info is not None
                and 'max_input_tokens' in self.model_info
                and isinstance(self.model_info['max_input_tokens'], int)
            ):
                self.config.max_input_tokens = self.model_info['max_input_tokens']
            else:
                # Safe fallback for any potentially viable model
                self.config.max_input_tokens = 4096

        if self.config.max_output_tokens is None:
            # Safe default for any potentially viable model
            self.config.max_output_tokens = 4096
            if self.model_info is not None:
                # max_output_tokens has precedence over max_tokens, if either exists.
                # litellm has models with both, one or none of these 2 parameters!
                if 'max_output_tokens' in self.model_info and isinstance(
                    self.model_info['max_output_tokens'], int
                ):
                    self.config.max_output_tokens = self.model_info['max_output_tokens']
                elif 'max_tokens' in self.model_info and isinstance(
                    self.model_info['max_tokens'], int
                ):
                    self.config.max_output_tokens = self.model_info['max_tokens']

        # This only seems to work with Google as the provider, not with OpenRouter!
        gemini_safety_settings = (
            [
                {
                    'category': 'HARM_CATEGORY_HARASSMENT',
                    'threshold': 'BLOCK_NONE',
                },
                {
                    'category': 'HARM_CATEGORY_HATE_SPEECH',
                    'threshold': 'BLOCK_NONE',
                },
                {
                    'category': 'HARM_CATEGORY_SEXUALLY_EXPLICIT',
                    'threshold': 'BLOCK_NONE',
                },
                {
                    'category': 'HARM_CATEGORY_DANGEROUS_CONTENT',
                    'threshold': 'BLOCK_NONE',
                },
            ]
            if self.config.model.lower().startswith('gemini')
            else None
        )
        self._completion = partial(
            litellm_completion,
            model=self.config.model,
            api_key=self.config.api_key,
            base_url=self.config.base_url,
            api_version=self.config.api_version,
            custom_llm_provider=self.config.custom_llm_provider,
            max_tokens=self.config.max_output_tokens,
            timeout=self.config.timeout,
            temperature=self.config.temperature,
            top_p=self.config.top_p,
            drop_params=self.config.drop_params,
            **(
                {'safety_settings': gemini_safety_settings}
                if gemini_safety_settings is not None
                else {}
            ),
        )

        if self.vision_is_active():
            logger.debug('LLM: model has vision enabled')
        completion_unwrapped = self._completion

        def log_retry_attempt(retry_state):
            """With before_sleep, this is called before `custom_completion_wait` and
            ONLY if the retry is triggered by an exception."""
            if should_exit():
                raise OperationCancelled(
                    'Operation cancelled.'
                )  # exits the @retry loop
            exception = retry_state.outcome.exception()
            logger.error(
                f'{exception}. Attempt #{retry_state.attempt_number} | You can customize retry values in the configuration.',
                exc_info=False,
            )

        def custom_completion_wait(retry_state):
            """Custom wait function for litellm completion."""
            if not retry_state:
                return 0
            exception = retry_state.outcome.exception() if retry_state.outcome else None
            if exception is None:
                return 0

            min_wait_time = self.config.retry_min_wait
            max_wait_time = self.config.retry_max_wait

            # for rate limit errors, wait 1 minute by default, max 4 minutes between retries
            exception_type = type(exception).__name__
            logger.error(f'\nexception_type: {exception_type}\n')

            if exception_type == 'RateLimitError':
                min_wait_time = 60
                max_wait_time = 240
            elif exception_type == 'BadRequestError' and exception.response:
                # this should give us the buried, actual error message from the LLM model
                logger.error(f'\n\nBadRequestError: {exception.response}\n\n')

            # Return the wait time using exponential backoff
            exponential_wait = wait_exponential(
                multiplier=self.config.retry_multiplier,
                min=min_wait_time,
                max=max_wait_time,
            )
            # Call the exponential wait function with retry_state to get the actual wait time
            return exponential_wait(retry_state)
        @retry(
            before_sleep=log_retry_attempt,
            stop=stop_after_attempt(self.config.num_retries),
            reraise=True,
            retry=(
                retry_if_exception_type(self.retry_exceptions)
                & retry_if_not_exception_type(OperationCancelled)
            ),
            wait=custom_completion_wait,
        )
        def wrapper(*args, **kwargs):
            """Wrapper for the litellm completion function. Logs the input and output of the completion function."""
            # some callers might just send the messages directly
            if 'messages' in kwargs:
                messages = kwargs['messages']
            else:
                messages = args[1] if len(args) > 1 else []

            # this serves to prevent empty messages and logging the messages
            debug_message = self._get_debug_message(messages)

            if self.is_caching_prompt_active():
                # Anthropic-specific prompt caching
                if 'claude-3' in self.config.model:
                    kwargs['extra_headers'] = {
                        'anthropic-beta': 'prompt-caching-2024-07-31',
                    }

            # skip if messages is empty (thus debug_message is empty)
            if debug_message:
                llm_prompt_logger.debug(debug_message)
                resp = completion_unwrapped(*args, **kwargs)
            else:
                logger.debug('No completion messages!')
                resp = {'choices': [{'message': {'content': ''}}]}

            if self.config.log_completions:
                self.llm_completions.append(
                    {
                        'messages': messages,
                        'response': resp,
                        'timestamp': time.time(),
                        'cost': self.completion_cost(resp),
                    }
                )

            # log the response
            message_back = resp['choices'][0]['message']['content']
            if message_back:
                llm_response_logger.debug(message_back)

            # post-process to log costs
            self._post_completion(resp)
            return resp

        self._completion = wrapper  # type: ignore
        # Async version
        self._async_completion = partial(
            self._call_acompletion,
            model=self.config.model,
            api_key=self.config.api_key,
            base_url=self.config.base_url,
            api_version=self.config.api_version,
            custom_llm_provider=self.config.custom_llm_provider,
            max_tokens=self.config.max_output_tokens,
            timeout=self.config.timeout,
            temperature=self.config.temperature,
            top_p=self.config.top_p,
            drop_params=self.config.drop_params,
            **(
                {'safety_settings': gemini_safety_settings}
                if gemini_safety_settings is not None
                else {}
            ),
        )

        async_completion_unwrapped = self._async_completion
        @retry(
            before_sleep=log_retry_attempt,
            stop=stop_after_attempt(self.config.num_retries),
            reraise=True,
            retry=(
                retry_if_exception_type(self.retry_exceptions)
                & retry_if_not_exception_type(OperationCancelled)
            ),
            wait=custom_completion_wait,
        )
        async def async_completion_wrapper(*args, **kwargs):
            """Async wrapper for the litellm acompletion function."""
            # some callers might just send the messages directly
            if 'messages' in kwargs:
                messages = kwargs['messages']
            else:
                messages = args[1] if len(args) > 1 else []

            # this serves to prevent empty messages and logging the messages
            debug_message = self._get_debug_message(messages)

            async def check_stopped():
                while should_continue():
                    if (
                        hasattr(self.config, 'on_cancel_requested_fn')
                        and self.config.on_cancel_requested_fn is not None
                        and await self.config.on_cancel_requested_fn()
                    ):
                        raise UserCancelledError('LLM request cancelled by user')
                    await asyncio.sleep(0.1)

            stop_check_task = asyncio.create_task(check_stopped())

            try:
                # Directly call and await litellm_acompletion
                if debug_message:
                    llm_prompt_logger.debug(debug_message)
                    resp = await async_completion_unwrapped(*args, **kwargs)
                else:
                    logger.debug('No completion messages!')
                    resp = {'choices': [{'message': {'content': ''}}]}

                # skip if messages is empty (thus debug_message is empty)
                if debug_message:
                    message_back = resp['choices'][0]['message']['content']
                    llm_response_logger.debug(message_back)
                else:
                    resp = {'choices': [{'message': {'content': ''}}]}
                self._post_completion(resp)

                # We do not support streaming in this method, thus return resp
                return resp
            except UserCancelledError:
                logger.info('LLM request cancelled by user.')
                raise
            except (
                APIConnectionError,
                ContentPolicyViolationError,
                InternalServerError,
                NotFoundError,
                OpenAIError,
                RateLimitError,
                ServiceUnavailableError,
            ) as e:
                logger.error(f'Completion Error occurred:\n{e}')
                raise
            finally:
                await asyncio.sleep(0.1)
                stop_check_task.cancel()
                try:
                    await stop_check_task
                except asyncio.CancelledError:
                    pass
        @retry(
            before_sleep=log_retry_attempt,
            stop=stop_after_attempt(self.config.num_retries),
            reraise=True,
            retry=(
                retry_if_exception_type(self.retry_exceptions)
                & retry_if_not_exception_type(OperationCancelled)
            ),
            wait=custom_completion_wait,
        )
        async def async_acompletion_stream_wrapper(*args, **kwargs):
            """Async wrapper for the litellm acompletion with streaming function."""
            # some callers might just send the messages directly
            if 'messages' in kwargs:
                messages = kwargs['messages']
            else:
                messages = args[1] if len(args) > 1 else []

            # log the prompt
            debug_message = ''
            for message in messages:
                debug_message += message_separator + message['content']
            llm_prompt_logger.debug(debug_message)

            try:
                # Directly call and await litellm_acompletion
                resp = await async_completion_unwrapped(*args, **kwargs)

                # For streaming we iterate over the chunks
                async for chunk in resp:
                    # Check for cancellation before yielding the chunk
                    if (
                        hasattr(self.config, 'on_cancel_requested_fn')
                        and self.config.on_cancel_requested_fn is not None
                        and await self.config.on_cancel_requested_fn()
                    ):
                        raise UserCancelledError(
                            'LLM request cancelled due to CANCELLED state'
                        )
                    # with streaming, it is "delta", not "message"!
                    message_back = chunk['choices'][0]['delta']['content']
                    llm_response_logger.debug(message_back)
                    self._post_completion(chunk)
                    yield chunk
            except UserCancelledError:
                logger.info('LLM request cancelled by user.')
                raise
            except (
                APIConnectionError,
                ContentPolicyViolationError,
                InternalServerError,
                NotFoundError,
                OpenAIError,
                RateLimitError,
                ServiceUnavailableError,
            ) as e:
                logger.error(f'Completion Error occurred:\n{e}')
                raise
            finally:
                if kwargs.get('stream', False):
                    await asyncio.sleep(0.1)

        self._async_completion = async_completion_wrapper  # type: ignore
        self._async_streaming_completion = async_acompletion_stream_wrapper  # type: ignore
    def _get_debug_message(self, messages):
        if not messages:
            return ''
        messages = messages if isinstance(messages, list) else [messages]
        return message_separator.join(
            self._format_message_content(msg) for msg in messages if msg['content']
        )

    def _format_message_content(self, message):
        content = message['content']
        if isinstance(content, list):
            return self._format_list_content(content)
        return str(content)

    def _format_list_content(self, content_list):
        return '\n'.join(
            self._format_content_element(element) for element in content_list
        )

    def _format_content_element(self, element):
        if isinstance(element, dict):
            if 'text' in element:
                return element['text']
            if (
                self.vision_is_active()
                and 'image_url' in element
                and 'url' in element['image_url']
            ):
                return element['image_url']['url']
        return str(element)

    async def _call_acompletion(self, *args, **kwargs):
        """This is a wrapper for the litellm acompletion function which
        makes it mockable for testing.
        """
        return await litellm.acompletion(*args, **kwargs)
    @property
    def completion(self):
        """Decorator for the litellm completion function.

        Check the complete documentation at https://litellm.vercel.app/docs/completion
        """
        return self._completion

    @property
    def async_completion(self):
        """Decorator for the async litellm acompletion function.

        Check the complete documentation at https://litellm.vercel.app/docs/providers/ollama#example-usage---streaming--acompletion
        """
        return self._async_completion

    @property
    def async_streaming_completion(self):
        """Decorator for the async litellm acompletion function with streaming.

        Check the complete documentation at https://litellm.vercel.app/docs/providers/ollama#example-usage---streaming--acompletion
        """
        return self._async_streaming_completion
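
    # Illustrative consumption sketch for the streaming property above (not part
    # of the original module): the wrapped function is an async generator, so an
    # async caller iterates it chunk by chunk, e.g.
    #
    #     async for chunk in llm.async_streaming_completion(
    #         messages=[{'role': 'user', 'content': 'hi'}], stream=True
    #     ):
    #         print(chunk['choices'][0]['delta']['content'])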
    def vision_is_active(self):
        return not self.config.disable_vision and self._supports_vision()

    def _supports_vision(self):
        """Acquire from litellm if model is vision capable.

        Returns:
            bool: True if model is vision capable. If model is not supported by litellm, it will return False.
        """
        try:
            return litellm.supports_vision(self.config.model)
        except Exception:
            return False

    def is_caching_prompt_active(self) -> bool:
        """Check if prompt caching is enabled and supported for the current model.

        Returns:
            boolean: True if prompt caching is active for the given model.
        """
        return self.config.caching_prompt is True and any(
            model in self.config.model for model in cache_prompting_supported_models
        )
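
    # Worked example (illustrative): with config.caching_prompt=True and
    # config.model='claude-3-5-sonnet-20240620', this returns True, and the sync
    # wrapper defined in __init__ then attaches the
    # 'anthropic-beta: prompt-caching-2024-07-31' header to the request.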
    def _post_completion(self, response) -> None:
        """Post-process the completion response."""
        try:
            cur_cost = self.completion_cost(response)
        except Exception:
            cur_cost = 0

        stats = ''
        if self.cost_metric_supported:
            stats = 'Cost: %.2f USD | Accumulated Cost: %.2f USD\n' % (
                cur_cost,
                self.metrics.accumulated_cost,
            )

        usage = response.get('usage')
        if usage:
            input_tokens = usage.get('prompt_tokens')
            output_tokens = usage.get('completion_tokens')

            if input_tokens:
                stats += 'Input tokens: ' + str(input_tokens)

            if output_tokens:
                stats += (
                    (' | ' if input_tokens else '')
                    + 'Output tokens: '
                    + str(output_tokens)
                    + '\n'
                )

            model_extra = usage.get('model_extra', {})

            cache_creation_input_tokens = model_extra.get('cache_creation_input_tokens')
            if cache_creation_input_tokens:
                stats += (
                    'Input tokens (cache write): '
                    + str(cache_creation_input_tokens)
                    + '\n'
                )

            cache_read_input_tokens = model_extra.get('cache_read_input_tokens')
            if cache_read_input_tokens:
                stats += (
                    'Input tokens (cache read): ' + str(cache_read_input_tokens) + '\n'
                )

        if stats:
            logger.info(stats)
    def get_token_count(self, messages):
        """Get the number of tokens in a list of messages.

        Args:
            messages (list): A list of messages.

        Returns:
            int: The number of tokens.
        """
        try:
            return litellm.token_counter(model=self.config.model, messages=messages)
        except Exception:
            # TODO: this is to limit logspam in case token count is not supported
            return 0

    def is_local(self):
        """Determines if the system is using a locally running LLM.

        Returns:
            boolean: True if executing a local model.
        """
        if self.config.base_url is not None:
            for substring in ['localhost', '127.0.0.1', '0.0.0.0']:
                if substring in self.config.base_url:
                    return True
        elif self.config.model is not None:
            if self.config.model.startswith('ollama'):
                return True
        return False
    def completion_cost(self, response):
        """Calculate the cost of a completion response based on the model. Local models are treated as free.

        Add the current cost into total cost in metrics.

        Args:
            response: A response from a model invocation.

        Returns:
            number: The cost of the response.
        """
        if not self.cost_metric_supported:
            return 0.0

        extra_kwargs = {}
        if (
            self.config.input_cost_per_token is not None
            and self.config.output_cost_per_token is not None
        ):
            cost_per_token = CostPerToken(
                input_cost_per_token=self.config.input_cost_per_token,
                output_cost_per_token=self.config.output_cost_per_token,
            )
            logger.info(f'Using custom cost per token: {cost_per_token}')
            extra_kwargs['custom_cost_per_token'] = cost_per_token

        if not self.is_local():
            try:
                cost = litellm_completion_cost(
                    completion_response=response, **extra_kwargs
                )
                self.metrics.add_cost(cost)
                return cost
            except Exception:
                self.cost_metric_supported = False
                logger.warning('Cost calculation not supported for this model.')
        return 0.0

    def __str__(self):
        if self.config.api_version:
            return f'LLM(model={self.config.model}, api_version={self.config.api_version}, base_url={self.config.base_url})'
        elif self.config.base_url:
            return f'LLM(model={self.config.model}, base_url={self.config.base_url})'
        return f'LLM(model={self.config.model})'

    def __repr__(self):
        return str(self)

    def reset(self):
        self.metrics = Metrics()
        self.llm_completions = []

    def format_messages_for_llm(self, messages: Message | list[Message]) -> list[dict]:
        if isinstance(messages, Message):
            return [messages.model_dump()]
        return [message.model_dump() for message in messages]
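

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# Assumes a model name and API key that litellm can route; the LLMConfig field
# names below are the ones this file already reads (config.model, config.api_key).
# ---------------------------------------------------------------------------
# config = LLMConfig(model='gpt-4o', api_key='...')
# llm = LLM(config)
# resp = llm.completion(
#     messages=[{'role': 'user', 'content': 'Say hello in one short sentence.'}]
# )
# print(resp['choices'][0]['message']['content'])
# print(llm.metrics.accumulated_cost)  # cost accumulated via _post_completion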