llm.py 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261
  1. import copy
  2. import warnings
  3. from functools import partial
  4. from opendevin.core.config import LLMConfig
  5. with warnings.catch_warnings():
  6. warnings.simplefilter('ignore')
  7. import litellm
  8. from litellm import completion as litellm_completion
  9. from litellm import completion_cost as litellm_completion_cost
  10. from litellm.exceptions import (
  11. APIConnectionError,
  12. ContentPolicyViolationError,
  13. InternalServerError,
  14. RateLimitError,
  15. ServiceUnavailableError,
  16. )
  17. from litellm.types.utils import CostPerToken
  18. from tenacity import (
  19. retry,
  20. retry_if_exception_type,
  21. stop_after_attempt,
  22. wait_random_exponential,
  23. )
  24. from opendevin.core.logger import llm_prompt_logger, llm_response_logger
  25. from opendevin.core.logger import opendevin_logger as logger
  26. from opendevin.core.metrics import Metrics
# Public API of this module.
__all__ = ['LLM']

# Separator inserted between messages when logging multi-message prompts.
message_separator = '\n\n----------\n\n'
  29. class LLM:
  30. """The LLM class represents a Language Model instance.
  31. Attributes:
  32. config: an LLMConfig object specifying the configuration of the LLM.
  33. """
  34. def __init__(
  35. self,
  36. config: LLMConfig,
  37. metrics: Metrics | None = None,
  38. ):
  39. """Initializes the LLM. If LLMConfig is passed, its values will be the fallback.
  40. Passing simple parameters always overrides config.
  41. Args:
  42. config: The LLM configuration
  43. """
  44. self.config = copy.deepcopy(config)
  45. self.metrics = metrics if metrics is not None else Metrics()
  46. self.cost_metric_supported = True
  47. # litellm actually uses base Exception here for unknown model
  48. self.model_info = None
  49. try:
  50. if self.config.model.startswith('openrouter'):
  51. self.model_info = litellm.get_model_info(self.config.model)
  52. else:
  53. self.model_info = litellm.get_model_info(
  54. self.config.model.split(':')[0]
  55. )
  56. # noinspection PyBroadException
  57. except Exception:
  58. logger.warning(f'Could not get model info for {config.model}')
  59. # Set the max tokens in an LM-specific way if not set
  60. if config.max_input_tokens is None:
  61. if (
  62. self.model_info is not None
  63. and 'max_input_tokens' in self.model_info
  64. and isinstance(self.model_info['max_input_tokens'], int)
  65. ):
  66. self.config.max_input_tokens = self.model_info['max_input_tokens']
  67. else:
  68. # Max input tokens for gpt3.5, so this is a safe fallback for any potentially viable model
  69. self.config.max_input_tokens = 4096
  70. if config.max_output_tokens is None:
  71. if (
  72. self.model_info is not None
  73. and 'max_output_tokens' in self.model_info
  74. and isinstance(self.model_info['max_output_tokens'], int)
  75. ):
  76. self.config.max_output_tokens = self.model_info['max_output_tokens']
  77. else:
  78. # Max output tokens for gpt3.5, so this is a safe fallback for any potentially viable model
  79. self.config.max_output_tokens = 1024
  80. if self.config.drop_params:
  81. litellm.drop_params = self.config.drop_params
  82. self._completion = partial(
  83. litellm_completion,
  84. model=self.config.model,
  85. api_key=self.config.api_key,
  86. base_url=self.config.base_url,
  87. api_version=self.config.api_version,
  88. custom_llm_provider=self.config.custom_llm_provider,
  89. max_tokens=self.config.max_output_tokens,
  90. timeout=self.config.timeout,
  91. temperature=self.config.temperature,
  92. top_p=self.config.top_p,
  93. )
  94. completion_unwrapped = self._completion
  95. def attempt_on_error(retry_state):
  96. logger.error(
  97. f'{retry_state.outcome.exception()}. Attempt #{retry_state.attempt_number} | You can customize these settings in the configuration.',
  98. exc_info=False,
  99. )
  100. return None
  101. @retry(
  102. reraise=True,
  103. stop=stop_after_attempt(config.num_retries),
  104. wait=wait_random_exponential(
  105. multiplier=config.retry_multiplier,
  106. min=config.retry_min_wait,
  107. max=config.retry_max_wait,
  108. ),
  109. retry=retry_if_exception_type(
  110. (
  111. RateLimitError,
  112. APIConnectionError,
  113. ServiceUnavailableError,
  114. InternalServerError,
  115. ContentPolicyViolationError,
  116. )
  117. ),
  118. after=attempt_on_error,
  119. )
  120. def wrapper(*args, **kwargs):
  121. """Wrapper for the litellm completion function. Logs the input and output of the completion function."""
  122. # some callers might just send the messages directly
  123. if 'messages' in kwargs:
  124. messages = kwargs['messages']
  125. else:
  126. messages = args[1]
  127. # log the prompt
  128. debug_message = ''
  129. for message in messages:
  130. debug_message += message_separator + message['content']
  131. llm_prompt_logger.debug(debug_message)
  132. # call the completion function
  133. resp = completion_unwrapped(*args, **kwargs)
  134. # log the response
  135. message_back = resp['choices'][0]['message']['content']
  136. llm_response_logger.debug(message_back)
  137. # post-process to log costs
  138. self._post_completion(resp)
  139. return resp
  140. self._completion = wrapper # type: ignore
  141. @property
  142. def completion(self):
  143. """Decorator for the litellm completion function.
  144. Check the complete documentation at https://litellm.vercel.app/docs/completion
  145. """
  146. return self._completion
  147. def _post_completion(self, response: str) -> None:
  148. """Post-process the completion response."""
  149. try:
  150. cur_cost = self.completion_cost(response)
  151. except Exception:
  152. cur_cost = 0
  153. if self.cost_metric_supported:
  154. logger.info(
  155. 'Cost: %.2f USD | Accumulated Cost: %.2f USD',
  156. cur_cost,
  157. self.metrics.accumulated_cost,
  158. )
  159. def get_token_count(self, messages):
  160. """Get the number of tokens in a list of messages.
  161. Args:
  162. messages (list): A list of messages.
  163. Returns:
  164. int: The number of tokens.
  165. """
  166. return litellm.token_counter(model=self.config.model, messages=messages)
  167. def is_local(self):
  168. """Determines if the system is using a locally running LLM.
  169. Returns:
  170. boolean: True if executing a local model.
  171. """
  172. if self.config.base_url is not None:
  173. for substring in ['localhost', '127.0.0.1' '0.0.0.0']:
  174. if substring in self.config.base_url:
  175. return True
  176. elif self.config.model is not None:
  177. if self.config.model.startswith('ollama'):
  178. return True
  179. return False
  180. def completion_cost(self, response):
  181. """Calculate the cost of a completion response based on the model. Local models are treated as free.
  182. Add the current cost into total cost in metrics.
  183. Args:
  184. response: A response from a model invocation.
  185. Returns:
  186. number: The cost of the response.
  187. """
  188. if not self.cost_metric_supported:
  189. return 0.0
  190. extra_kwargs = {}
  191. if (
  192. self.config.input_cost_per_token is not None
  193. and self.config.output_cost_per_token is not None
  194. ):
  195. cost_per_token = CostPerToken(
  196. input_cost_per_token=self.config.input_cost_per_token,
  197. output_cost_per_token=self.config.output_cost_per_token,
  198. )
  199. logger.info(f'Using custom cost per token: {cost_per_token}')
  200. extra_kwargs['custom_cost_per_token'] = cost_per_token
  201. if not self.is_local():
  202. try:
  203. cost = litellm_completion_cost(
  204. completion_response=response, **extra_kwargs
  205. )
  206. self.metrics.add_cost(cost)
  207. return cost
  208. except Exception:
  209. self.cost_metric_supported = False
  210. logger.warning('Cost calculation not supported for this model.')
  211. return 0.0
  212. def __str__(self):
  213. if self.config.api_version:
  214. return f'LLM(model={self.config.model}, api_version={self.config.api_version}, base_url={self.config.base_url})'
  215. elif self.config.base_url:
  216. return f'LLM(model={self.config.model}, base_url={self.config.base_url})'
  217. return f'LLM(model={self.config.model})'
  218. def __repr__(self):
  219. return str(self)
  220. def reset(self):
  221. self.metrics = Metrics()