llm.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370
  1. import copy
  2. import os
  3. import time
  4. import warnings
  5. from functools import partial
  6. from typing import Any
  7. from openhands.core.config import LLMConfig
  8. with warnings.catch_warnings():
  9. warnings.simplefilter('ignore')
  10. import litellm
  11. from litellm import completion as litellm_completion
  12. from litellm import completion_cost as litellm_completion_cost
  13. from litellm.exceptions import (
  14. APIConnectionError,
  15. ContentPolicyViolationError,
  16. InternalServerError,
  17. OpenAIError,
  18. RateLimitError,
  19. )
  20. from litellm.types.utils import CostPerToken
  21. from openhands.core.logger import openhands_logger as logger
  22. from openhands.core.message import Message
  23. from openhands.core.metrics import Metrics
  24. from openhands.llm.debug_mixin import DebugMixin
  25. from openhands.llm.retry_mixin import RetryMixin
# Public API of this module: only the LLM class is exported.
__all__ = ['LLM']

# Anthropic model identifiers known to support beta prompt caching.
# Checked via substring match against the configured model name in
# LLM.is_caching_prompt_active().
cache_prompting_supported_models = [
    'claude-3-5-sonnet-20240620',
    'claude-3-haiku-20240307',
]
  31. class LLM(RetryMixin, DebugMixin):
  32. """The LLM class represents a Language Model instance.
  33. Attributes:
  34. config: an LLMConfig object specifying the configuration of the LLM.
  35. """
  36. def __init__(
  37. self,
  38. config: LLMConfig,
  39. metrics: Metrics | None = None,
  40. ):
  41. """Initializes the LLM. If LLMConfig is passed, its values will be the fallback.
  42. Passing simple parameters always overrides config.
  43. Args:
  44. config: The LLM configuration.
  45. metrics: The metrics to use.
  46. """
  47. self.metrics = metrics if metrics is not None else Metrics()
  48. self.cost_metric_supported = True
  49. self.config = copy.deepcopy(config)
  50. os.environ['OR_SITE_URL'] = self.config.openrouter_site_url
  51. os.environ['OR_APP_NAME'] = self.config.openrouter_app_name
  52. # list of LLM completions (for logging purposes). Each completion is a dict with the following keys:
  53. # - 'messages': list of messages
  54. # - 'response': response from the LLM
  55. self.llm_completions: list[dict[str, Any]] = []
  56. # Set up config attributes with default values to prevent AttributeError
  57. LLMConfig.set_missing_attributes(self.config)
  58. # litellm actually uses base Exception here for unknown model
  59. self.model_info = None
  60. try:
  61. if self.config.model.startswith('openrouter'):
  62. self.model_info = litellm.get_model_info(self.config.model)
  63. else:
  64. self.model_info = litellm.get_model_info(
  65. self.config.model.split(':')[0]
  66. )
  67. # noinspection PyBroadException
  68. except Exception as e:
  69. logger.warning(f'Could not get model info for {config.model}:\n{e}')
  70. # Tuple of exceptions to retry on
  71. self.retry_exceptions = (
  72. APIConnectionError,
  73. ContentPolicyViolationError,
  74. InternalServerError,
  75. OpenAIError,
  76. RateLimitError,
  77. )
  78. # Set the max tokens in an LM-specific way if not set
  79. if self.config.max_input_tokens is None:
  80. if (
  81. self.model_info is not None
  82. and 'max_input_tokens' in self.model_info
  83. and isinstance(self.model_info['max_input_tokens'], int)
  84. ):
  85. self.config.max_input_tokens = self.model_info['max_input_tokens']
  86. else:
  87. # Safe fallback for any potentially viable model
  88. self.config.max_input_tokens = 4096
  89. if self.config.max_output_tokens is None:
  90. # Safe default for any potentially viable model
  91. self.config.max_output_tokens = 4096
  92. if self.model_info is not None:
  93. # max_output_tokens has precedence over max_tokens, if either exists.
  94. # litellm has models with both, one or none of these 2 parameters!
  95. if 'max_output_tokens' in self.model_info and isinstance(
  96. self.model_info['max_output_tokens'], int
  97. ):
  98. self.config.max_output_tokens = self.model_info['max_output_tokens']
  99. elif 'max_tokens' in self.model_info and isinstance(
  100. self.model_info['max_tokens'], int
  101. ):
  102. self.config.max_output_tokens = self.model_info['max_tokens']
  103. self._completion = partial(
  104. litellm_completion,
  105. model=self.config.model,
  106. api_key=self.config.api_key,
  107. base_url=self.config.base_url,
  108. api_version=self.config.api_version,
  109. custom_llm_provider=self.config.custom_llm_provider,
  110. max_tokens=self.config.max_output_tokens,
  111. timeout=self.config.timeout,
  112. temperature=self.config.temperature,
  113. top_p=self.config.top_p,
  114. drop_params=self.config.drop_params,
  115. )
  116. if self.vision_is_active():
  117. logger.debug('LLM: model has vision enabled')
  118. completion_unwrapped = self._completion
  119. @self.retry_decorator(
  120. num_retries=self.config.num_retries,
  121. retry_exceptions=self.retry_exceptions,
  122. retry_min_wait=self.config.retry_min_wait,
  123. retry_max_wait=self.config.retry_max_wait,
  124. retry_multiplier=self.config.retry_multiplier,
  125. )
  126. def wrapper(*args, **kwargs):
  127. """Wrapper for the litellm completion function. Logs the input and output of the completion function."""
  128. # some callers might just send the messages directly
  129. if 'messages' in kwargs:
  130. messages = kwargs['messages']
  131. else:
  132. messages = args[1] if len(args) > 1 else []
  133. # if we have no messages, something went very wrong
  134. if not messages:
  135. raise ValueError(
  136. 'The messages list is empty. At least one message is required.'
  137. )
  138. # log the entire LLM prompt
  139. self.log_prompt(messages)
  140. if self.is_caching_prompt_active():
  141. # Anthropic-specific prompt caching
  142. if 'claude-3' in self.config.model:
  143. kwargs['extra_headers'] = {
  144. 'anthropic-beta': 'prompt-caching-2024-07-31',
  145. }
  146. resp = completion_unwrapped(*args, **kwargs)
  147. # log for evals or other scripts that need the raw completion
  148. if self.config.log_completions:
  149. self.llm_completions.append(
  150. {
  151. 'messages': messages,
  152. 'response': resp,
  153. 'timestamp': time.time(),
  154. 'cost': self._completion_cost(resp),
  155. }
  156. )
  157. message_back = resp['choices'][0]['message']['content']
  158. # log the LLM response
  159. self.log_response(message_back)
  160. # post-process the response
  161. self._post_completion(resp)
  162. return resp
  163. self._completion = wrapper
  164. @property
  165. def completion(self):
  166. """Decorator for the litellm completion function.
  167. Check the complete documentation at https://litellm.vercel.app/docs/completion
  168. """
  169. return self._completion
  170. def vision_is_active(self):
  171. return not self.config.disable_vision and self._supports_vision()
  172. def _supports_vision(self):
  173. """Acquire from litellm if model is vision capable.
  174. Returns:
  175. bool: True if model is vision capable. If model is not supported by litellm, it will return False.
  176. """
  177. try:
  178. return litellm.supports_vision(self.config.model)
  179. except Exception:
  180. return False
  181. def is_caching_prompt_active(self) -> bool:
  182. """Check if prompt caching is enabled and supported for current model.
  183. Returns:
  184. boolean: True if prompt caching is active for the given model.
  185. """
  186. return self.config.caching_prompt is True and any(
  187. model in self.config.model for model in cache_prompting_supported_models
  188. )
  189. def _post_completion(self, response) -> None:
  190. """Post-process the completion response.
  191. Logs the cost and usage stats of the completion call.
  192. """
  193. try:
  194. cur_cost = self._completion_cost(response)
  195. except Exception:
  196. cur_cost = 0
  197. stats = ''
  198. if self.cost_metric_supported:
  199. # keep track of the cost
  200. stats = 'Cost: %.2f USD | Accumulated Cost: %.2f USD\n' % (
  201. cur_cost,
  202. self.metrics.accumulated_cost,
  203. )
  204. usage = response.get('usage')
  205. if usage:
  206. # keep track of the input and output tokens
  207. input_tokens = usage.get('prompt_tokens')
  208. output_tokens = usage.get('completion_tokens')
  209. if input_tokens:
  210. stats += 'Input tokens: ' + str(input_tokens)
  211. if output_tokens:
  212. stats += (
  213. (' | ' if input_tokens else '')
  214. + 'Output tokens: '
  215. + str(output_tokens)
  216. + '\n'
  217. )
  218. # read the prompt caching status as received from the provider
  219. model_extra = usage.get('model_extra', {})
  220. cache_creation_input_tokens = model_extra.get('cache_creation_input_tokens')
  221. if cache_creation_input_tokens:
  222. stats += (
  223. 'Input tokens (cache write): '
  224. + str(cache_creation_input_tokens)
  225. + '\n'
  226. )
  227. cache_read_input_tokens = model_extra.get('cache_read_input_tokens')
  228. if cache_read_input_tokens:
  229. stats += (
  230. 'Input tokens (cache read): ' + str(cache_read_input_tokens) + '\n'
  231. )
  232. # log the stats
  233. if stats:
  234. logger.info(stats)
  235. def get_token_count(self, messages):
  236. """Get the number of tokens in a list of messages.
  237. Args:
  238. messages (list): A list of messages.
  239. Returns:
  240. int: The number of tokens.
  241. """
  242. try:
  243. return litellm.token_counter(model=self.config.model, messages=messages)
  244. except Exception:
  245. # TODO: this is to limit logspam in case token count is not supported
  246. return 0
  247. def _is_local(self):
  248. """Determines if the system is using a locally running LLM.
  249. Returns:
  250. boolean: True if executing a local model.
  251. """
  252. if self.config.base_url is not None:
  253. for substring in ['localhost', '127.0.0.1' '0.0.0.0']:
  254. if substring in self.config.base_url:
  255. return True
  256. elif self.config.model is not None:
  257. if self.config.model.startswith('ollama'):
  258. return True
  259. return False
  260. def _completion_cost(self, response):
  261. """Calculate the cost of a completion response based on the model. Local models are treated as free.
  262. Add the current cost into total cost in metrics.
  263. Args:
  264. response: A response from a model invocation.
  265. Returns:
  266. number: The cost of the response.
  267. """
  268. if not self.cost_metric_supported:
  269. return 0.0
  270. extra_kwargs = {}
  271. if (
  272. self.config.input_cost_per_token is not None
  273. and self.config.output_cost_per_token is not None
  274. ):
  275. cost_per_token = CostPerToken(
  276. input_cost_per_token=self.config.input_cost_per_token,
  277. output_cost_per_token=self.config.output_cost_per_token,
  278. )
  279. logger.info(f'Using custom cost per token: {cost_per_token}')
  280. extra_kwargs['custom_cost_per_token'] = cost_per_token
  281. if not self._is_local():
  282. try:
  283. cost = litellm_completion_cost(
  284. completion_response=response, **extra_kwargs
  285. )
  286. self.metrics.add_cost(cost)
  287. return cost
  288. except Exception:
  289. self.cost_metric_supported = False
  290. logger.warning('Cost calculation not supported for this model.')
  291. return 0.0
  292. def __str__(self):
  293. if self.config.api_version:
  294. return f'LLM(model={self.config.model}, api_version={self.config.api_version}, base_url={self.config.base_url})'
  295. elif self.config.base_url:
  296. return f'LLM(model={self.config.model}, base_url={self.config.base_url})'
  297. return f'LLM(model={self.config.model})'
  298. def __repr__(self):
  299. return str(self)
  300. def reset(self):
  301. self.metrics = Metrics()
  302. self.llm_completions = []
  303. def format_messages_for_llm(self, messages: Message | list[Message]) -> list[dict]:
  304. if isinstance(messages, Message):
  305. return [messages.model_dump()]
  306. return [message.model_dump() for message in messages]