# llm.py

  1. import copy
  2. import time
  3. import warnings
  4. from functools import partial
  5. from typing import Any
  6. from openhands.core.config import LLMConfig
  7. with warnings.catch_warnings():
  8. warnings.simplefilter('ignore')
  9. import litellm
  10. from litellm import ModelInfo, PromptTokensDetails
  11. from litellm import completion as litellm_completion
  12. from litellm import completion_cost as litellm_completion_cost
  13. from litellm.exceptions import (
  14. APIConnectionError,
  15. APIError,
  16. InternalServerError,
  17. RateLimitError,
  18. ServiceUnavailableError,
  19. )
  20. from litellm.types.utils import CostPerToken, ModelResponse, Usage
  21. from openhands.core.exceptions import CloudFlareBlockageError
  22. from openhands.core.logger import openhands_logger as logger
  23. from openhands.core.message import Message
  24. from openhands.llm.debug_mixin import DebugMixin
  25. from openhands.llm.metrics import Metrics
  26. from openhands.llm.retry_mixin import RetryMixin
__all__ = ['LLM']

# tuple of exceptions to retry on
LLM_RETRY_EXCEPTIONS: tuple[type[Exception], ...] = (
    APIConnectionError,
    # FIXME: APIError is useful on 502 from a proxy for example,
    # but it also retries on other errors that are permanent
    APIError,
    InternalServerError,
    RateLimitError,
    ServiceUnavailableError,
)

# models known to support Anthropic-style prompt caching
# remove this when gemini and deepseek are supported
CACHE_PROMPT_SUPPORTED_MODELS = [
    'claude-3-5-sonnet-20240620',
    'claude-3-haiku-20240307',
    'claude-3-opus-20240229',
    'anthropic/claude-3-opus-20240229',
    'anthropic/claude-3-haiku-20240307',
    'anthropic/claude-3-5-sonnet-20240620',
]
class LLM(RetryMixin, DebugMixin):
    """The LLM class represents a Language Model instance.

    Attributes:
        config: an LLMConfig object specifying the configuration of the LLM.
    """

    def __init__(
        self,
        config: LLMConfig,
        metrics: Metrics | None = None,
    ):
        """Initializes the LLM. If LLMConfig is passed, its values will be the fallback.

        Passing simple parameters always overrides config.

        Args:
            config: The LLM configuration.
            metrics: The metrics to use.
        """
        self.metrics: Metrics = (
            metrics if metrics is not None else Metrics(model_name=config.model)
        )
        # flipped to False once a cost lookup fails, to avoid repeated warnings
        self.cost_metric_supported: bool = True
        # deep-copy so the token-limit defaults set below don't mutate the caller's config
        self.config: LLMConfig = copy.deepcopy(config)

        # list of LLM completions (for logging purposes). Each completion is a dict with the following keys:
        # - 'messages': list of messages
        # - 'response': response from the LLM
        self.llm_completions: list[dict[str, Any]] = []

        # litellm actually uses base Exception here for unknown model
        self.model_info: ModelInfo | None = None
        try:
            if self.config.model.startswith('openrouter'):
                self.model_info = litellm.get_model_info(self.config.model)
            else:
                # strip any ':tag' suffix (e.g. ollama-style tags) before the lookup
                self.model_info = litellm.get_model_info(
                    self.config.model.split(':')[0]
                )
        # noinspection PyBroadException
        except Exception as e:
            logger.warning(f'Could not get model info for {config.model}:\n{e}')

        # Set the max tokens in an LM-specific way if not set
        if self.config.max_input_tokens is None:
            if (
                self.model_info is not None
                and 'max_input_tokens' in self.model_info
                and isinstance(self.model_info['max_input_tokens'], int)
            ):
                self.config.max_input_tokens = self.model_info['max_input_tokens']
            else:
                # Safe fallback for any potentially viable model
                self.config.max_input_tokens = 4096

        if self.config.max_output_tokens is None:
            # Safe default for any potentially viable model
            self.config.max_output_tokens = 4096
            if self.model_info is not None:
                # max_output_tokens has precedence over max_tokens, if either exists.
                # litellm has models with both, one or none of these 2 parameters!
                if 'max_output_tokens' in self.model_info and isinstance(
                    self.model_info['max_output_tokens'], int
                ):
                    self.config.max_output_tokens = self.model_info['max_output_tokens']
                elif 'max_tokens' in self.model_info and isinstance(
                    self.model_info['max_tokens'], int
                ):
                    self.config.max_output_tokens = self.model_info['max_tokens']

        # bind all configured parameters once; callers then only pass messages
        self._completion = partial(
            litellm_completion,
            model=self.config.model,
            api_key=self.config.api_key,
            base_url=self.config.base_url,
            api_version=self.config.api_version,
            custom_llm_provider=self.config.custom_llm_provider,
            max_tokens=self.config.max_output_tokens,
            timeout=self.config.timeout,
            temperature=self.config.temperature,
            top_p=self.config.top_p,
            drop_params=self.config.drop_params,
        )

        if self.vision_is_active():
            logger.debug('LLM: model has vision enabled')
        if self.is_caching_prompt_active():
            logger.debug('LLM: caching prompt enabled')

        # keep a reference to the raw partial; self._completion is rebound to the
        # retry-wrapped closure below
        completion_unwrapped = self._completion

        @self.retry_decorator(
            num_retries=self.config.num_retries,
            retry_exceptions=LLM_RETRY_EXCEPTIONS,
            retry_min_wait=self.config.retry_min_wait,
            retry_max_wait=self.config.retry_max_wait,
            retry_multiplier=self.config.retry_multiplier,
        )
        def wrapper(*args, **kwargs):
            """Wrapper for the litellm completion function. Logs the input and output of the completion function."""
            messages: list[dict[str, Any]] | dict[str, Any] = []

            # some callers might send the model and messages directly
            # litellm allows positional args, like completion(model, messages, **kwargs)
            if len(args) > 1:
                # ignore the first argument if it's provided (it would be the model)
                # design wise: we don't allow overriding the configured values
                # implementation wise: the partial function set the model as a kwarg already
                # as well as other kwargs
                messages = args[1] if len(args) > 1 else args[0]
                kwargs['messages'] = messages

                # remove the first args, they're sent in kwargs
                args = args[2:]
            elif 'messages' in kwargs:
                messages = kwargs['messages']

            # ensure we work with a list of messages
            messages = messages if isinstance(messages, list) else [messages]

            # if we have no messages, something went very wrong
            if not messages:
                raise ValueError(
                    'The messages list is empty. At least one message is required.'
                )

            # log the entire LLM prompt
            self.log_prompt(messages)

            if self.is_caching_prompt_active():
                # Anthropic-specific prompt caching (beta header)
                if 'claude-3' in self.config.model:
                    kwargs['extra_headers'] = {
                        'anthropic-beta': 'prompt-caching-2024-07-31',
                    }

            try:
                # we don't support streaming here, thus we get a ModelResponse
                resp: ModelResponse = completion_unwrapped(*args, **kwargs)

                # log for evals or other scripts that need the raw completion
                if self.config.log_completions:
                    self.llm_completions.append(
                        {
                            'messages': messages,
                            'response': resp,
                            'timestamp': time.time(),
                            'cost': self._completion_cost(resp),
                        }
                    )

                message_back: str = resp['choices'][0]['message']['content']

                # log the LLM response
                self.log_response(message_back)

                # post-process the response
                self._post_completion(resp)

                return resp
            except APIError as e:
                # CloudFlare challenge pages surface as APIError bodies; translate
                # them into a dedicated exception callers can handle
                if 'Attention Required! | Cloudflare' in str(e):
                    raise CloudFlareBlockageError(
                        'Request blocked by CloudFlare'
                    ) from e
                raise

        self._completion = wrapper
  192. @property
  193. def completion(self):
  194. """Decorator for the litellm completion function.
  195. Check the complete documentation at https://litellm.vercel.app/docs/completion
  196. """
  197. return self._completion
  198. def vision_is_active(self):
  199. return not self.config.disable_vision and self._supports_vision()
  200. def _supports_vision(self):
  201. """Acquire from litellm if model is vision capable.
  202. Returns:
  203. bool: True if model is vision capable. If model is not supported by litellm, it will return False.
  204. """
  205. # litellm.supports_vision currently returns False for 'openai/gpt-...' or 'anthropic/claude-...' (with prefixes)
  206. # but model_info will have the correct value for some reason.
  207. # we can go with it, but we will need to keep an eye if model_info is correct for Vertex or other providers
  208. # remove when litellm is updated to fix https://github.com/BerriAI/litellm/issues/5608
  209. return litellm.supports_vision(self.config.model) or (
  210. self.model_info is not None
  211. and self.model_info.get('supports_vision', False)
  212. )
  213. def is_caching_prompt_active(self) -> bool:
  214. """Check if prompt caching is supported and enabled for current model.
  215. Returns:
  216. boolean: True if prompt caching is supported and enabled for the given model.
  217. """
  218. return (
  219. self.config.caching_prompt is True
  220. and self.model_info is not None
  221. and self.model_info.get('supports_prompt_caching', False)
  222. and self.config.model in CACHE_PROMPT_SUPPORTED_MODELS
  223. )
    def _post_completion(self, response: ModelResponse) -> None:
        """Post-process the completion response.

        Logs the cost and usage stats of the completion call.
        """
        try:
            cur_cost = self._completion_cost(response)
        except Exception:
            # cost is best-effort; never let accounting break a completion
            cur_cost = 0

        stats = ''
        if self.cost_metric_supported:
            # keep track of the cost
            stats = 'Cost: %.2f USD | Accumulated Cost: %.2f USD\n' % (
                cur_cost,
                self.metrics.accumulated_cost,
            )

        usage: Usage | None = response.get('usage')

        if usage:
            # keep track of the input and output tokens
            input_tokens = usage.get('prompt_tokens')
            output_tokens = usage.get('completion_tokens')

            if input_tokens:
                # NOTE(review): no trailing '\n' here — if output_tokens is falsy the
                # next stat is appended on the same line; confirm this is intended
                stats += 'Input tokens: ' + str(input_tokens)

            if output_tokens:
                stats += (
                    (' | ' if input_tokens else '')
                    + 'Output tokens: '
                    + str(output_tokens)
                    + '\n'
                )

            # read the prompt cache hit, if any
            prompt_tokens_details: PromptTokensDetails = usage.get(
                'prompt_tokens_details'
            )
            cache_hit_tokens = (
                prompt_tokens_details.cached_tokens if prompt_tokens_details else None
            )
            if cache_hit_tokens:
                stats += 'Input tokens (cache hit): ' + str(cache_hit_tokens) + '\n'

            # For Anthropic, the cache writes have a different cost than regular input tokens
            # but litellm doesn't separate them in the usage stats
            # so we can read it from the provider-specific extra field
            model_extra = usage.get('model_extra', {})
            cache_write_tokens = model_extra.get('cache_creation_input_tokens')
            if cache_write_tokens:
                stats += 'Input tokens (cache write): ' + str(cache_write_tokens) + '\n'

        # log the stats
        if stats:
            logger.info(stats)
  272. def get_token_count(self, messages):
  273. """Get the number of tokens in a list of messages.
  274. Args:
  275. messages (list): A list of messages.
  276. Returns:
  277. int: The number of tokens.
  278. """
  279. try:
  280. return litellm.token_counter(model=self.config.model, messages=messages)
  281. except Exception:
  282. # TODO: this is to limit logspam in case token count is not supported
  283. return 0
  284. def _is_local(self):
  285. """Determines if the system is using a locally running LLM.
  286. Returns:
  287. boolean: True if executing a local model.
  288. """
  289. if self.config.base_url is not None:
  290. for substring in ['localhost', '127.0.0.1' '0.0.0.0']:
  291. if substring in self.config.base_url:
  292. return True
  293. elif self.config.model is not None:
  294. if self.config.model.startswith('ollama'):
  295. return True
  296. return False
  297. def _completion_cost(self, response):
  298. """Calculate the cost of a completion response based on the model. Local models are treated as free.
  299. Add the current cost into total cost in metrics.
  300. Args:
  301. response: A response from a model invocation.
  302. Returns:
  303. number: The cost of the response.
  304. """
  305. if not self.cost_metric_supported:
  306. return 0.0
  307. extra_kwargs = {}
  308. if (
  309. self.config.input_cost_per_token is not None
  310. and self.config.output_cost_per_token is not None
  311. ):
  312. cost_per_token = CostPerToken(
  313. input_cost_per_token=self.config.input_cost_per_token,
  314. output_cost_per_token=self.config.output_cost_per_token,
  315. )
  316. logger.info(f'Using custom cost per token: {cost_per_token}')
  317. extra_kwargs['custom_cost_per_token'] = cost_per_token
  318. if not self._is_local():
  319. try:
  320. cost = litellm_completion_cost(
  321. completion_response=response, **extra_kwargs
  322. )
  323. self.metrics.add_cost(cost)
  324. return cost
  325. except Exception:
  326. self.cost_metric_supported = False
  327. logger.warning('Cost calculation not supported for this model.')
  328. return 0.0
  329. def __str__(self):
  330. if self.config.api_version:
  331. return f'LLM(model={self.config.model}, api_version={self.config.api_version}, base_url={self.config.base_url})'
  332. elif self.config.base_url:
  333. return f'LLM(model={self.config.model}, base_url={self.config.base_url})'
  334. return f'LLM(model={self.config.model})'
  335. def __repr__(self):
  336. return str(self)
  337. def reset(self):
  338. self.metrics.reset()
  339. self.llm_completions = []
  340. def format_messages_for_llm(self, messages: Message | list[Message]) -> list[dict]:
  341. if isinstance(messages, Message):
  342. messages = [messages]
  343. # set flags to know how to serialize the messages
  344. for message in messages:
  345. message.cache_enabled = self.is_caching_prompt_active()
  346. message.vision_enabled = self.vision_is_active()
  347. # let pydantic handle the serialization
  348. return [message.model_dump() for message in messages]