# llm.py
  1. import asyncio
  2. import copy
  3. import warnings
  4. from functools import partial
  5. from typing import Union
  6. from openhands.core.config import LLMConfig
  7. with warnings.catch_warnings():
  8. warnings.simplefilter('ignore')
  9. import litellm
  10. from litellm import completion as litellm_completion
  11. from litellm import completion_cost as litellm_completion_cost
  12. from litellm.exceptions import (
  13. APIConnectionError,
  14. ContentPolicyViolationError,
  15. InternalServerError,
  16. NotFoundError,
  17. OpenAIError,
  18. RateLimitError,
  19. ServiceUnavailableError,
  20. )
  21. from litellm.types.utils import CostPerToken
  22. from tenacity import (
  23. retry,
  24. retry_if_exception_type,
  25. stop_after_attempt,
  26. wait_random_exponential,
  27. )
  28. from openhands.core.exceptions import LLMResponseError, UserCancelledError
  29. from openhands.core.logger import llm_prompt_logger, llm_response_logger
  30. from openhands.core.logger import openhands_logger as logger
  31. from openhands.core.message import Message, format_messages
  32. from openhands.core.metrics import Metrics
# Public API of this module.
__all__ = ['LLM']

# Separator inserted between individual messages in the prompt debug log.
message_separator = '\n\n----------\n\n'

# Models known to support Anthropic prompt caching (checked in is_caching_prompt_active).
cache_prompting_supported_models = [
    'claude-3-5-sonnet-20240620',
    'claude-3-haiku-20240307',
]
class LLM:
    """The LLM class represents a Language Model instance.

    Attributes:
        config: an LLMConfig object specifying the configuration of the LLM.
    """

    def __init__(
        self,
        config: LLMConfig,
        metrics: Metrics | None = None,
    ):
        """Initializes the LLM. If LLMConfig is passed, its values will be the fallback.

        Passing simple parameters always overrides config.

        Args:
            config: The LLM configuration
            metrics: Optional shared Metrics accumulator; a fresh one is created when omitted.
        """
        self.metrics = metrics if metrics is not None else Metrics()
        self.cost_metric_supported = True
        # Deep-copy so the per-instance adjustments below never mutate the caller's config.
        self.config = copy.deepcopy(config)

        # Set up config attributes with default values to prevent AttributeError
        LLMConfig.set_missing_attributes(self.config)

        # litellm actually uses base Exception here for unknown model
        self.model_info = None
        try:
            if self.config.model.startswith('openrouter'):
                self.model_info = litellm.get_model_info(self.config.model)
            else:
                # Strip any ':tag' suffix from the model name before the lookup.
                self.model_info = litellm.get_model_info(
                    self.config.model.split(':')[0]
                )
        # noinspection PyBroadException
        except Exception as e:
            logger.warning(f'Could not get model info for {config.model}:\n{e}')

        # Set the max tokens in an LM-specific way if not set
        if self.config.max_input_tokens is None:
            if (
                self.model_info is not None
                and 'max_input_tokens' in self.model_info
                and isinstance(self.model_info['max_input_tokens'], int)
            ):
                self.config.max_input_tokens = self.model_info['max_input_tokens']
            else:
                # Max input tokens for gpt3.5, so this is a safe fallback for any potentially viable model
                self.config.max_input_tokens = 4096

        if self.config.max_output_tokens is None:
            if (
                self.model_info is not None
                and 'max_output_tokens' in self.model_info
                and isinstance(self.model_info['max_output_tokens'], int)
            ):
                self.config.max_output_tokens = self.model_info['max_output_tokens']
            else:
                # Max output tokens for gpt3.5, so this is a safe fallback for any potentially viable model
                self.config.max_output_tokens = 1024

        if self.config.drop_params:
            litellm.drop_params = self.config.drop_params

        # Base completion callable with all config values bound; it is wrapped below
        # with retry + logging and then re-assigned to self._completion.
        self._completion = partial(
            litellm_completion,
            model=self.config.model,
            api_key=self.config.api_key,
            base_url=self.config.base_url,
            api_version=self.config.api_version,
            custom_llm_provider=self.config.custom_llm_provider,
            max_tokens=self.config.max_output_tokens,
            timeout=self.config.timeout,
            temperature=self.config.temperature,
            top_p=self.config.top_p,
        )

        completion_unwrapped = self._completion

        def attempt_on_error(retry_state):
            # tenacity 'after' callback: log each failed attempt before the next retry.
            logger.error(
                f'{retry_state.outcome.exception()}. Attempt #{retry_state.attempt_number} | You can customize these settings in the configuration.',
                exc_info=False,
            )
            return None

        @retry(
            reraise=True,
            stop=stop_after_attempt(self.config.num_retries),
            wait=wait_random_exponential(
                multiplier=self.config.retry_multiplier,
                min=self.config.retry_min_wait,
                max=self.config.retry_max_wait,
            ),
            retry=retry_if_exception_type(
                (
                    APIConnectionError,
                    ContentPolicyViolationError,
                    InternalServerError,
                    OpenAIError,
                    RateLimitError,
                )
            ),
            after=attempt_on_error,
        )
        def wrapper(*args, **kwargs):
            """Wrapper for the litellm completion function. Logs the input and output of the completion function."""
            # some callers might just send the messages directly
            if 'messages' in kwargs:
                messages = kwargs['messages']
            else:
                messages = args[1] if len(args) > 1 else []

            # log the prompt
            debug_message = ''
            for message in messages:
                debug_str = ''  # helper to prevent empty messages
                content = message['content']
                if isinstance(content, list):
                    for element in content:
                        if isinstance(element, dict):
                            if 'text' in element:
                                debug_str = element['text'].strip()
                            elif (
                                self.vision_is_active()
                                and 'image_url' in element
                                and 'url' in element['image_url']
                            ):
                                debug_str = element['image_url']['url']
                            else:
                                debug_str = str(element)
                        else:
                            debug_str = str(element)
                else:
                    debug_str = str(content)
                # NOTE(review): for list content, debug_str is overwritten per element,
                # so only the LAST element of each message is logged here — confirm intended.
                if debug_str:
                    debug_message += message_separator + debug_str

            if self.is_caching_prompt_active():
                # Anthropic-specific prompt caching
                if 'claude-3' in self.config.model:
                    kwargs['extra_headers'] = {
                        'anthropic-beta': 'prompt-caching-2024-07-31',
                    }

            # skip if messages is empty (thus debug_message is empty)
            if debug_message:
                llm_prompt_logger.debug(debug_message)
                resp = completion_unwrapped(*args, **kwargs)
            else:
                logger.debug('No completion messages!')
                resp = {'choices': [{'message': {'content': ''}}]}

            # log the response
            message_back = resp['choices'][0]['message']['content']
            llm_response_logger.debug(message_back)

            # post-process to log costs
            self._post_completion(resp)
            return resp

        self._completion = wrapper  # type: ignore

        # Async version
        self._async_completion = partial(
            self._call_acompletion,
            model=self.config.model,
            api_key=self.config.api_key,
            base_url=self.config.base_url,
            api_version=self.config.api_version,
            custom_llm_provider=self.config.custom_llm_provider,
            max_tokens=self.config.max_output_tokens,
            timeout=self.config.timeout,
            temperature=self.config.temperature,
            top_p=self.config.top_p,
            drop_params=True,
        )

        async_completion_unwrapped = self._async_completion

        @retry(
            reraise=True,
            stop=stop_after_attempt(self.config.num_retries),
            wait=wait_random_exponential(
                multiplier=self.config.retry_multiplier,
                min=self.config.retry_min_wait,
                max=self.config.retry_max_wait,
            ),
            retry=retry_if_exception_type(
                (
                    APIConnectionError,
                    ContentPolicyViolationError,
                    InternalServerError,
                    OpenAIError,
                    RateLimitError,
                )
            ),
            after=attempt_on_error,
        )
        async def async_completion_wrapper(*args, **kwargs):
            """Async wrapper for the litellm acompletion function."""
            # some callers might just send the messages directly
            if 'messages' in kwargs:
                messages = kwargs['messages']
            else:
                # NOTE(review): unlike the sync wrapper there is no len(args) guard here;
                # a positional call with fewer than two args raises IndexError — confirm callers.
                messages = args[1]

            # log the prompt
            debug_message = ''
            for message in messages:
                content = message['content']
                if isinstance(content, list):
                    for element in content:
                        if isinstance(element, dict):
                            if 'text' in element:
                                debug_str = element['text']
                            elif (
                                self.vision_is_active()
                                and 'image_url' in element
                                and 'url' in element['image_url']
                            ):
                                debug_str = element['image_url']['url']
                            else:
                                debug_str = str(element)
                        else:
                            debug_str = str(element)
                        debug_message += message_separator + debug_str
                else:
                    debug_str = str(content)
                    debug_message += message_separator + debug_str
            llm_prompt_logger.debug(debug_message)

            async def check_stopped():
                # Poll the optional cancellation hook every 100ms and abort the
                # in-flight request by raising UserCancelledError when it fires.
                while True:
                    if (
                        hasattr(self.config, 'on_cancel_requested_fn')
                        and self.config.on_cancel_requested_fn is not None
                        and await self.config.on_cancel_requested_fn()
                    ):
                        raise UserCancelledError('LLM request cancelled by user')
                    await asyncio.sleep(0.1)

            stop_check_task = asyncio.create_task(check_stopped())

            try:
                # Directly call and await litellm_acompletion
                resp = await async_completion_unwrapped(*args, **kwargs)

                # skip if messages is empty (thus debug_message is empty)
                if debug_message:
                    message_back = resp['choices'][0]['message']['content']
                    llm_response_logger.debug(message_back)
                else:
                    # NOTE(review): the request above has already been awaited at this
                    # point; its response is discarded and replaced with an empty one
                    # when no messages were logged — confirm intended.
                    resp = {'choices': [{'message': {'content': ''}}]}
                self._post_completion(resp)

                # We do not support streaming in this method, thus return resp
                return resp
            except UserCancelledError:
                logger.info('LLM request cancelled by user.')
                raise
            except (
                APIConnectionError,
                ContentPolicyViolationError,
                InternalServerError,
                NotFoundError,
                OpenAIError,
                RateLimitError,
                ServiceUnavailableError,
            ) as e:
                logger.error(f'Completion Error occurred:\n{e}')
                raise
            finally:
                # Always tear down the cancellation poller, even on error/cancel.
                await asyncio.sleep(0.1)
                stop_check_task.cancel()
                try:
                    await stop_check_task
                except asyncio.CancelledError:
                    pass

        @retry(
            reraise=True,
            stop=stop_after_attempt(self.config.num_retries),
            wait=wait_random_exponential(
                multiplier=self.config.retry_multiplier,
                min=self.config.retry_min_wait,
                max=self.config.retry_max_wait,
            ),
            retry=retry_if_exception_type(
                (
                    APIConnectionError,
                    ContentPolicyViolationError,
                    InternalServerError,
                    OpenAIError,
                    RateLimitError,
                )
            ),
            after=attempt_on_error,
        )
        async def async_acompletion_stream_wrapper(*args, **kwargs):
            """Async wrapper for the litellm acompletion with streaming function."""
            # some callers might just send the messages directly
            if 'messages' in kwargs:
                messages = kwargs['messages']
            else:
                messages = args[1]

            # log the prompt
            debug_message = ''
            for message in messages:
                # NOTE(review): assumes message['content'] is a plain string here
                # (no list/vision handling like the other wrappers) — confirm.
                debug_message += message_separator + message['content']
            llm_prompt_logger.debug(debug_message)

            try:
                # Directly call and await litellm_acompletion
                resp = await async_completion_unwrapped(*args, **kwargs)

                # For streaming we iterate over the chunks
                async for chunk in resp:
                    # Check for cancellation before yielding the chunk
                    if (
                        hasattr(self.config, 'on_cancel_requested_fn')
                        and self.config.on_cancel_requested_fn is not None
                        and await self.config.on_cancel_requested_fn()
                    ):
                        raise UserCancelledError(
                            'LLM request cancelled due to CANCELLED state'
                        )
                    # with streaming, it is "delta", not "message"!
                    message_back = chunk['choices'][0]['delta']['content']
                    llm_response_logger.debug(message_back)
                    self._post_completion(chunk)
                    yield chunk
            except UserCancelledError:
                logger.info('LLM request cancelled by user.')
                raise
            except (
                APIConnectionError,
                ContentPolicyViolationError,
                InternalServerError,
                NotFoundError,
                OpenAIError,
                RateLimitError,
                ServiceUnavailableError,
            ) as e:
                logger.error(f'Completion Error occurred:\n{e}')
                raise
            finally:
                if kwargs.get('stream', False):
                    await asyncio.sleep(0.1)

        self._async_completion = async_completion_wrapper  # type: ignore
        self._async_streaming_completion = async_acompletion_stream_wrapper  # type: ignore
    async def _call_acompletion(self, *args, **kwargs):
        # Thin passthrough to litellm.acompletion; presumably kept as a seam for
        # overriding/mocking in subclasses or tests — confirm.
        return await litellm.acompletion(*args, **kwargs)
  362. @property
  363. def completion(self):
  364. """Decorator for the litellm completion function.
  365. Check the complete documentation at https://litellm.vercel.app/docs/completion
  366. """
  367. try:
  368. return self._completion
  369. except Exception as e:
  370. raise LLMResponseError(e)
  371. @property
  372. def async_completion(self):
  373. """Decorator for the async litellm acompletion function.
  374. Check the complete documentation at https://litellm.vercel.app/docs/providers/ollama#example-usage---streaming--acompletion
  375. """
  376. try:
  377. return self._async_completion
  378. except Exception as e:
  379. raise LLMResponseError(e)
  380. @property
  381. def async_streaming_completion(self):
  382. """Decorator for the async litellm acompletion function with streaming.
  383. Check the complete documentation at https://litellm.vercel.app/docs/providers/ollama#example-usage---streaming--acompletion
  384. """
  385. try:
  386. return self._async_streaming_completion
  387. except Exception as e:
  388. raise LLMResponseError(e)
  389. def vision_is_active(self):
  390. return not self.config.disable_vision and self._supports_vision()
  391. def _supports_vision(self):
  392. """Acquire from litellm if model is vision capable.
  393. Returns:
  394. bool: True if model is vision capable. If model is not supported by litellm, it will return False.
  395. """
  396. try:
  397. return litellm.supports_vision(self.config.model)
  398. except Exception:
  399. return False
  400. def is_caching_prompt_active(self) -> bool:
  401. """Check if prompt caching is enabled and supported for current model.
  402. Returns:
  403. boolean: True if prompt caching is active for the given model.
  404. """
  405. return (
  406. self.config.caching_prompt is True
  407. and self.config.model in cache_prompting_supported_models
  408. )
  409. def _post_completion(self, response) -> None:
  410. """Post-process the completion response."""
  411. try:
  412. cur_cost = self.completion_cost(response)
  413. except Exception:
  414. cur_cost = 0
  415. stats = ''
  416. if self.cost_metric_supported:
  417. stats = 'Cost: %.2f USD | Accumulated Cost: %.2f USD\n' % (
  418. cur_cost,
  419. self.metrics.accumulated_cost,
  420. )
  421. usage = response.get('usage')
  422. if usage:
  423. input_tokens = usage.get('prompt_tokens')
  424. output_tokens = usage.get('completion_tokens')
  425. if input_tokens:
  426. stats += 'Input tokens: ' + str(input_tokens) + '\n'
  427. if output_tokens:
  428. stats += 'Output tokens: ' + str(output_tokens) + '\n'
  429. model_extra = usage.get('model_extra', {})
  430. cache_creation_input_tokens = model_extra.get('cache_creation_input_tokens')
  431. if cache_creation_input_tokens:
  432. stats += (
  433. 'Input tokens (cache write): '
  434. + str(cache_creation_input_tokens)
  435. + '\n'
  436. )
  437. cache_read_input_tokens = model_extra.get('cache_read_input_tokens')
  438. if cache_read_input_tokens:
  439. stats += (
  440. 'Input tokens (cache read): ' + str(cache_read_input_tokens) + '\n'
  441. )
  442. if stats:
  443. logger.info(stats)
  444. def get_token_count(self, messages):
  445. """Get the number of tokens in a list of messages.
  446. Args:
  447. messages (list): A list of messages.
  448. Returns:
  449. int: The number of tokens.
  450. """
  451. try:
  452. return litellm.token_counter(model=self.config.model, messages=messages)
  453. except Exception:
  454. # TODO: this is to limit logspam in case token count is not supported
  455. return 0
  456. def is_local(self):
  457. """Determines if the system is using a locally running LLM.
  458. Returns:
  459. boolean: True if executing a local model.
  460. """
  461. if self.config.base_url is not None:
  462. for substring in ['localhost', '127.0.0.1' '0.0.0.0']:
  463. if substring in self.config.base_url:
  464. return True
  465. elif self.config.model is not None:
  466. if self.config.model.startswith('ollama'):
  467. return True
  468. return False
  469. def completion_cost(self, response):
  470. """Calculate the cost of a completion response based on the model. Local models are treated as free.
  471. Add the current cost into total cost in metrics.
  472. Args:
  473. response: A response from a model invocation.
  474. Returns:
  475. number: The cost of the response.
  476. """
  477. if not self.cost_metric_supported:
  478. return 0.0
  479. extra_kwargs = {}
  480. if (
  481. self.config.input_cost_per_token is not None
  482. and self.config.output_cost_per_token is not None
  483. ):
  484. cost_per_token = CostPerToken(
  485. input_cost_per_token=self.config.input_cost_per_token,
  486. output_cost_per_token=self.config.output_cost_per_token,
  487. )
  488. logger.info(f'Using custom cost per token: {cost_per_token}')
  489. extra_kwargs['custom_cost_per_token'] = cost_per_token
  490. if not self.is_local():
  491. try:
  492. cost = litellm_completion_cost(
  493. completion_response=response, **extra_kwargs
  494. )
  495. self.metrics.add_cost(cost)
  496. return cost
  497. except Exception:
  498. self.cost_metric_supported = False
  499. logger.warning('Cost calculation not supported for this model.')
  500. return 0.0
  501. def __str__(self):
  502. if self.config.api_version:
  503. return f'LLM(model={self.config.model}, api_version={self.config.api_version}, base_url={self.config.base_url})'
  504. elif self.config.base_url:
  505. return f'LLM(model={self.config.model}, base_url={self.config.base_url})'
  506. return f'LLM(model={self.config.model})'
  507. def __repr__(self):
  508. return str(self)
    def reset(self):
        """Reset usage tracking by replacing the metrics with a fresh Metrics instance."""
        self.metrics = Metrics()
  511. def format_messages_for_llm(
  512. self, messages: Union[Message, list[Message]]
  513. ) -> list[dict]:
  514. return format_messages(messages, self.vision_is_active())