llm.py

import copy
import json
import os
import time
import warnings
from functools import partial
from typing import Any

from openhands.core.config import LLMConfig

with warnings.catch_warnings():
    warnings.simplefilter('ignore')
    import litellm

from litellm import ModelInfo, PromptTokensDetails
from litellm import completion as litellm_completion
from litellm import completion_cost as litellm_completion_cost
from litellm.exceptions import (
    APIConnectionError,
    APIError,
    InternalServerError,
    RateLimitError,
    ServiceUnavailableError,
)
from litellm.types.utils import CostPerToken, ModelResponse, Usage

from openhands.core.exceptions import CloudFlareBlockageError
from openhands.core.logger import openhands_logger as logger
from openhands.core.message import Message
from openhands.llm.debug_mixin import DebugMixin
from openhands.llm.metrics import Metrics
from openhands.llm.retry_mixin import RetryMixin

__all__ = ['LLM']

# tuple of exceptions to retry on
LLM_RETRY_EXCEPTIONS: tuple[type[Exception], ...] = (
    APIConnectionError,
    # FIXME: APIError is useful on 502 from a proxy for example,
    # but it also retries on other errors that are permanent
    APIError,
    InternalServerError,
    RateLimitError,
    ServiceUnavailableError,
)
# cache prompt supporting models
# remove this when gemini and deepseek are supported
CACHE_PROMPT_SUPPORTED_MODELS = [
    'claude-3-5-sonnet-20240620',
    'claude-3-5-sonnet-20241022',
    'claude-3-haiku-20240307',
    'claude-3-opus-20240229',
]

class LLM(RetryMixin, DebugMixin):
    """The LLM class represents a Language Model instance.

    Attributes:
        config: an LLMConfig object specifying the configuration of the LLM.
    """

    def __init__(
        self,
        config: LLMConfig,
        metrics: Metrics | None = None,
    ):
        """Initializes the LLM. If LLMConfig is passed, its values will be the fallback.

        Passing simple parameters always overrides config.

        Args:
            config: The LLM configuration.
            metrics: The metrics to use.
        """
        self.metrics: Metrics = (
            metrics if metrics is not None else Metrics(model_name=config.model)
        )
        self.cost_metric_supported: bool = True
        self.config: LLMConfig = copy.deepcopy(config)

        # litellm actually uses base Exception here for unknown model
        self.model_info: ModelInfo | None = None
        try:
            if self.config.model.startswith('openrouter'):
                self.model_info = litellm.get_model_info(self.config.model)
            else:
                self.model_info = litellm.get_model_info(
                    self.config.model.split(':')[0]
                )
        # noinspection PyBroadException
        except Exception as e:
            logger.warning(f'Could not get model info for {config.model}:\n{e}')

        if self.config.log_completions:
            if self.config.log_completions_folder is None:
                raise RuntimeError(
                    'log_completions_folder is required when log_completions is enabled'
                )
            os.makedirs(self.config.log_completions_folder, exist_ok=True)

        # Set the max tokens in an LM-specific way if not set
        if self.config.max_input_tokens is None:
            if (
                self.model_info is not None
                and 'max_input_tokens' in self.model_info
                and isinstance(self.model_info['max_input_tokens'], int)
            ):
                self.config.max_input_tokens = self.model_info['max_input_tokens']
            else:
                # Safe fallback for any potentially viable model
                self.config.max_input_tokens = 4096

        if self.config.max_output_tokens is None:
            # Safe default for any potentially viable model
            self.config.max_output_tokens = 4096
            if self.model_info is not None:
                # max_output_tokens has precedence over max_tokens, if either exists.
                # litellm has models with both, one or none of these 2 parameters!
                if 'max_output_tokens' in self.model_info and isinstance(
                    self.model_info['max_output_tokens'], int
                ):
                    self.config.max_output_tokens = self.model_info['max_output_tokens']
                elif 'max_tokens' in self.model_info and isinstance(
                    self.model_info['max_tokens'], int
                ):
                    self.config.max_output_tokens = self.model_info['max_tokens']

        self._completion = partial(
            litellm_completion,
            model=self.config.model,
            api_key=self.config.api_key,
            base_url=self.config.base_url,
            api_version=self.config.api_version,
            custom_llm_provider=self.config.custom_llm_provider,
            max_tokens=self.config.max_output_tokens,
            timeout=self.config.timeout,
            temperature=self.config.temperature,
            top_p=self.config.top_p,
            drop_params=self.config.drop_params,
        )

        if self.vision_is_active():
            logger.debug('LLM: model has vision enabled')
        if self.is_caching_prompt_active():
            logger.debug('LLM: caching prompt enabled')

        completion_unwrapped = self._completion

        @self.retry_decorator(
            num_retries=self.config.num_retries,
            retry_exceptions=LLM_RETRY_EXCEPTIONS,
            retry_min_wait=self.config.retry_min_wait,
            retry_max_wait=self.config.retry_max_wait,
            retry_multiplier=self.config.retry_multiplier,
        )
        def wrapper(*args, **kwargs):
            """Wrapper for the litellm completion function. Logs the input and output of the completion function."""
            messages: list[dict[str, Any]] | dict[str, Any] = []

            # some callers might send the model and messages directly
            # litellm allows positional args, like completion(model, messages, **kwargs)
            if len(args) > 1:
                # ignore the first argument if it's provided (it would be the model)
                # design wise: we don't allow overriding the configured values
                # implementation wise: the partial function sets the model as a kwarg already
                # as well as other kwargs
                messages = args[1] if len(args) > 1 else args[0]
                kwargs['messages'] = messages

                # remove the first args, they're sent in kwargs
                args = args[2:]
            elif 'messages' in kwargs:
                messages = kwargs['messages']

            # ensure we work with a list of messages
            messages = messages if isinstance(messages, list) else [messages]

            # if we have no messages, something went very wrong
            if not messages:
                raise ValueError(
                    'The messages list is empty. At least one message is required.'
                )

            # log the entire LLM prompt
            self.log_prompt(messages)

            if self.is_caching_prompt_active():
                # Anthropic-specific prompt caching
                if 'claude-3' in self.config.model:
                    kwargs['extra_headers'] = {
                        'anthropic-beta': 'prompt-caching-2024-07-31',
                    }

            try:
                # we don't support streaming here, thus we get a ModelResponse
                resp: ModelResponse = completion_unwrapped(*args, **kwargs)

                # log for evals or other scripts that need the raw completion
                if self.config.log_completions:
                    assert self.config.log_completions_folder is not None
                    log_file = os.path.join(
                        self.config.log_completions_folder,
                        # use the metric model name (for draft editor)
                        f'{self.metrics.model_name}-{time.time()}.json',
                    )
                    with open(log_file, 'w') as f:
                        json.dump(
                            {
                                'messages': messages,
                                'response': resp,
                                'args': args,
                                'kwargs': kwargs,
                                'timestamp': time.time(),
                                'cost': self._completion_cost(resp),
                            },
                            f,
                        )

                message_back: str = resp['choices'][0]['message']['content']

                # log the LLM response
                self.log_response(message_back)

                # post-process the response
                self._post_completion(resp)

                return resp
            except APIError as e:
                if 'Attention Required! | Cloudflare' in str(e):
                    raise CloudFlareBlockageError(
                        'Request blocked by CloudFlare'
                    ) from e
                raise

        self._completion = wrapper

    @property
    def completion(self):
        """Decorator for the litellm completion function.

        Check the complete documentation at https://litellm.vercel.app/docs/completion
        """
        return self._completion

    def vision_is_active(self):
        return not self.config.disable_vision and self._supports_vision()

    def _supports_vision(self):
        """Check with litellm whether the model is vision capable.

        Returns:
            bool: True if the model is vision capable. If the model is not supported by litellm, returns False.
        """
        # litellm.supports_vision currently returns False for 'openai/gpt-...' or 'anthropic/claude-...' (with prefixes)
        # but model_info will have the correct value for some reason.
        # we can go with it, but we will need to keep an eye if model_info is correct for Vertex or other providers
        # remove when litellm is updated to fix https://github.com/BerriAI/litellm/issues/5608
        return litellm.supports_vision(self.config.model) or (
            self.model_info is not None
            and self.model_info.get('supports_vision', False)
        )

    def is_caching_prompt_active(self) -> bool:
        """Check if prompt caching is supported and enabled for the current model.

        Returns:
            boolean: True if prompt caching is supported and enabled for the given model.
        """
        return (
            self.config.caching_prompt is True
            and self.model_info is not None
            and self.model_info.get('supports_prompt_caching', False)
            and (
                self.config.model in CACHE_PROMPT_SUPPORTED_MODELS
                or self.config.model.split('/')[-1] in CACHE_PROMPT_SUPPORTED_MODELS
            )
        )

    def _post_completion(self, response: ModelResponse) -> None:
        """Post-process the completion response.

        Logs the cost and usage stats of the completion call.
        """
        try:
            cur_cost = self._completion_cost(response)
        except Exception:
            cur_cost = 0

        stats = ''
        if self.cost_metric_supported:
            # keep track of the cost
            stats = 'Cost: %.2f USD | Accumulated Cost: %.2f USD\n' % (
                cur_cost,
                self.metrics.accumulated_cost,
            )

        usage: Usage | None = response.get('usage')
        if usage:
            # keep track of the input and output tokens
            input_tokens = usage.get('prompt_tokens')
            output_tokens = usage.get('completion_tokens')

            if input_tokens:
                stats += 'Input tokens: ' + str(input_tokens)

            if output_tokens:
                stats += (
                    (' | ' if input_tokens else '')
                    + 'Output tokens: '
                    + str(output_tokens)
                    + '\n'
                )

            # read the prompt cache hit, if any
            prompt_tokens_details: PromptTokensDetails = usage.get(
                'prompt_tokens_details'
            )
            cache_hit_tokens = (
                prompt_tokens_details.cached_tokens if prompt_tokens_details else None
            )
            if cache_hit_tokens:
                stats += 'Input tokens (cache hit): ' + str(cache_hit_tokens) + '\n'

            # For Anthropic, the cache writes have a different cost than regular input tokens
            # but litellm doesn't separate them in the usage stats
            # so we can read it from the provider-specific extra field
            model_extra = usage.get('model_extra', {})
            cache_write_tokens = model_extra.get('cache_creation_input_tokens')
            if cache_write_tokens:
                stats += 'Input tokens (cache write): ' + str(cache_write_tokens) + '\n'

        # log the stats
        if stats:
            logger.info(stats)

    def get_token_count(self, messages):
        """Get the number of tokens in a list of messages.

        Args:
            messages (list): A list of messages.

        Returns:
            int: The number of tokens.
        """
        try:
            return litellm.token_counter(model=self.config.model, messages=messages)
        except Exception:
            # TODO: this is to limit logspam in case token count is not supported
            return 0

    def _is_local(self):
        """Determines if the system is using a locally running LLM.

        Returns:
            boolean: True if executing a local model.
        """
        if self.config.base_url is not None:
            for substring in ['localhost', '127.0.0.1', '0.0.0.0']:
                if substring in self.config.base_url:
                    return True
        elif self.config.model is not None:
            if self.config.model.startswith('ollama'):
                return True
        return False

    def _completion_cost(self, response):
        """Calculate the cost of a completion response based on the model. Local models are treated as free.

        Add the current cost into total cost in metrics.

        Args:
            response: A response from a model invocation.

        Returns:
            number: The cost of the response.
        """
        if not self.cost_metric_supported:
            return 0.0

        extra_kwargs = {}
        if (
            self.config.input_cost_per_token is not None
            and self.config.output_cost_per_token is not None
        ):
            cost_per_token = CostPerToken(
                input_cost_per_token=self.config.input_cost_per_token,
                output_cost_per_token=self.config.output_cost_per_token,
            )
            logger.info(f'Using custom cost per token: {cost_per_token}')
            extra_kwargs['custom_cost_per_token'] = cost_per_token

        if not self._is_local():
            try:
                cost = litellm_completion_cost(
                    completion_response=response, **extra_kwargs
                )
                self.metrics.add_cost(cost)
                return cost
            except Exception:
                self.cost_metric_supported = False
                logger.warning('Cost calculation not supported for this model.')
        return 0.0

    def __str__(self):
        if self.config.api_version:
            return f'LLM(model={self.config.model}, api_version={self.config.api_version}, base_url={self.config.base_url})'
        elif self.config.base_url:
            return f'LLM(model={self.config.model}, base_url={self.config.base_url})'
        return f'LLM(model={self.config.model})'

    def __repr__(self):
        return str(self)

    def reset(self):
        self.metrics.reset()

    def format_messages_for_llm(self, messages: Message | list[Message]) -> list[dict]:
        if isinstance(messages, Message):
            messages = [messages]

        # set flags to know how to serialize the messages
        for message in messages:
            message.cache_enabled = self.is_caching_prompt_active()
            message.vision_enabled = self.vision_is_active()

        # let pydantic handle the serialization
        return [message.model_dump() for message in messages]
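

# Minimal usage sketch: it assumes an LLMConfig that can be constructed with just
# `model` and `api_key`, with the remaining fields (retries, token limits, etc.)
# taking sensible defaults. The plain OpenAI-style dict message is accepted by the
# litellm completion wrapper above; in OpenHands itself, messages are usually built
# as Message objects and serialized via format_messages_for_llm().
if __name__ == '__main__':
    example_config = LLMConfig(
        model='claude-3-5-sonnet-20240620',  # one of the cache-prompt capable models listed above
        api_key='sk-...',  # placeholder; supply a real key or provider-specific credentials
    )
    llm = LLM(example_config)

    response = llm.completion(
        messages=[{'role': 'user', 'content': 'Say hello in one short sentence.'}]
    )
    print(response['choices'][0]['message']['content'])
    print(f'Accumulated cost: {llm.metrics.accumulated_cost:.4f} USD')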